mirror of
https://github.com/codeflash-ai/codeflash.git
synced 2026-05-04 18:25:17 +00:00
Merge remote-tracking branch 'origin/main' into omni-java
# Conflicts: # .claude/rules/architecture.md # .claude/rules/code-style.md # .github/workflows/claude.yml # .github/workflows/duplicate-code-detector.yml # codeflash/api/aiservice.py # codeflash/cli_cmds/console.py # codeflash/cli_cmds/logging_config.py # codeflash/code_utils/deduplicate_code.py # codeflash/discovery/discover_unit_tests.py # codeflash/languages/base.py # codeflash/languages/code_replacer.py # codeflash/languages/javascript/mocha_runner.py # codeflash/languages/javascript/support.py # codeflash/languages/python/support.py # codeflash/optimization/function_optimizer.py # codeflash/verification/parse_test_output.py # codeflash/verification/verification_utils.py # codeflash/verification/verifier.py # packages/codeflash/package-lock.json # packages/codeflash/package.json # tests/languages/javascript/test_support_dispatch.py # tests/test_codeflash_capture.py # tests/test_languages/test_javascript_test_runner.py # tests/test_multi_file_code_replacement.py
This commit is contained in:
commit
eceac13fc3
85 changed files with 1230 additions and 2771 deletions
|
|
@ -9,3 +9,5 @@
|
|||
- **Naming**: NEVER use leading underscores (`_function_name`) - Python has no true private functions, use public names
|
||||
- **Paths**: Always use absolute paths
|
||||
- **Encoding**: Always pass `encoding="utf-8"` to `open()`, `read_text()`, `write_text()`, etc. in new or changed code — Windows defaults to `cp1252` which breaks on non-ASCII content. Don't flag pre-existing code that lacks it unless you're already modifying that line.
|
||||
- **Pre-commit**: Run `uv run prek` before committing — fix any issues before creating the commit
|
||||
- **Pre-push**: Before pushing, run `uv run prek run --from-ref origin/<base>` to check all changed files against the PR base — this matches CI behavior and catches issues that per-commit prek misses. To detect the base branch: `gh pr view --json baseRefName -q .baseRefName 2>/dev/null || echo main`
|
||||
|
|
|
|||
89
.github/workflows/claude.yml
vendored
89
.github/workflows/claude.yml
vendored
|
|
@ -1,8 +1,20 @@
|
|||
name: Claude Code
|
||||
|
||||
on:
|
||||
workflow_dispatch:
|
||||
pull_request:
|
||||
types: [opened, synchronize, ready_for_review, reopened]
|
||||
paths-ignore:
|
||||
- '.github/workflows/**'
|
||||
- '*.md'
|
||||
- 'docs/**'
|
||||
- 'demos/**'
|
||||
- 'experiments/**'
|
||||
- 'LICENSE'
|
||||
- '.tessl/**'
|
||||
- 'code_to_optimize/**'
|
||||
- 'codeflash.code-workspace'
|
||||
- 'uv.lock'
|
||||
issue_comment:
|
||||
types: [created]
|
||||
pull_request_review_comment:
|
||||
|
|
@ -16,10 +28,16 @@ jobs:
|
|||
# Automatic PR review (can fix linting issues and push)
|
||||
# Blocked for fork PRs to prevent malicious code execution
|
||||
pr-review:
|
||||
concurrency:
|
||||
group: pr-review-${{ github.head_ref || github.run_id }}
|
||||
cancel-in-progress: true
|
||||
if: |
|
||||
(
|
||||
github.event_name == 'pull_request' &&
|
||||
github.actor != 'claude[bot]' &&
|
||||
github.event.sender.login != 'claude[bot]' &&
|
||||
github.event.pull_request.head.repo.full_name == github.repository
|
||||
) ||
|
||||
github.event_name == 'workflow_dispatch'
|
||||
runs-on: ubuntu-latest
|
||||
permissions:
|
||||
contents: write
|
||||
|
|
@ -32,7 +50,7 @@ jobs:
|
|||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
ref: ${{ github.event.pull_request.head.ref }}
|
||||
ref: ${{ github.event.pull_request.head.ref || github.ref }}
|
||||
|
||||
- name: Install uv
|
||||
uses: astral-sh/setup-uv@v6
|
||||
|
|
@ -54,7 +72,9 @@ jobs:
|
|||
with:
|
||||
use_bedrock: "true"
|
||||
use_sticky_comment: true
|
||||
track_progress: true
|
||||
allowed_bots: "claude[bot],codeflash-ai[bot]"
|
||||
exclude_comments_by_actor: "*[bot]"
|
||||
prompt: |
|
||||
<context>
|
||||
repo: ${{ github.repository }}
|
||||
|
|
@ -68,6 +88,20 @@ jobs:
|
|||
Post all review findings in a single summary comment only — never as inline PR review comments.
|
||||
</commitment>
|
||||
|
||||
<step name="triage">
|
||||
Before doing any work, assess the PR scope:
|
||||
|
||||
1. Run `gh pr diff ${{ github.event.pull_request.number }} --name-only` to get changed files.
|
||||
2. Classify as TRIVIAL if ALL changed files are:
|
||||
- Config/CI files (.github/, .tessl/, *.toml, *.lock, *.json, *.yml, *.yaml)
|
||||
- Documentation (*.md, docs/)
|
||||
- Non-production code (demos/, experiments/, code_to_optimize/)
|
||||
- Only whitespace, formatting, or comment changes
|
||||
|
||||
If TRIVIAL: post a single comment "No substantive code changes to review." and stop — do not execute any further steps.
|
||||
Otherwise: continue with the full review below.
|
||||
</step>
|
||||
|
||||
<step name="lint_and_typecheck">
|
||||
Run checks on files changed in this PR and auto-fix what you can.
|
||||
|
||||
|
|
@ -109,6 +143,33 @@ jobs:
|
|||
Record findings for the summary comment. Refer to CLAUDE.md for project conventions.
|
||||
</step>
|
||||
|
||||
<step name="duplicate_detection">
|
||||
Check whether this PR introduces code that duplicates logic already present elsewhere in the repository — including across languages. Focus on finding true duplicates, not just similar-looking code.
|
||||
|
||||
1. Get changed source files (excluding tests and config):
|
||||
`git diff --name-only origin/main...HEAD -- '*.py' '*.js' '*.ts' '*.java' | grep -v -E '(test_|_test\.(py|js|ts)|\.test\.(js|ts)|\.spec\.(js|ts)|conftest\.py|/tests/|/test/|/__tests__/)' | grep -v -E '^(\.github/|code_to_optimize/|\.tessl/|node_modules/)'`
|
||||
|
||||
2. For each changed file, read it and identify functions/methods added or substantially modified (longer than 5 lines).
|
||||
|
||||
3. Search for duplicates using Grep:
|
||||
- Same function name defined elsewhere
|
||||
- 2-3 distinctive operations from the body (specific API calls, algorithm patterns, string literals)
|
||||
|
||||
4. Cross-module check: this codebase has parallel modules under `languages/python/`, `languages/javascript/`, and `languages/java/` plus runtimes under `packages/codeflash/runtime/` and `codeflash-java-runtime/`. When a changed file is under one of these areas, search the others for equivalent logic. Only flag cases where the logic is genuinely shared or one module could import from the other.
|
||||
|
||||
5. When a Grep hit looks promising, read the full function and compare semantics. Flag only:
|
||||
- Same function with same/very similar body in another module
|
||||
- Same helper logic repeated in sibling files
|
||||
- Same logic implemented inline across multiple classes
|
||||
- Same algorithm reimplemented across language modules (Python code, not target-language differences)
|
||||
|
||||
Report at most 5 findings with confidence (HIGH/MEDIUM), locations, what's duplicated, and suggestion.
|
||||
|
||||
DO NOT report: boilerplate, functions under 5 lines, config/setup, intentional polymorphism, test files, imports, code that must differ due to target-language semantics.
|
||||
|
||||
If no duplicates found, include "No duplicates detected" in the summary.
|
||||
</step>
|
||||
|
||||
<step name="coverage">
|
||||
Analyze test coverage for changed files:
|
||||
|
||||
|
|
@ -120,30 +181,17 @@ jobs:
|
|||
</step>
|
||||
|
||||
<step name="summary_comment">
|
||||
Post exactly one summary comment containing all results from previous steps.
|
||||
Post exactly one summary comment containing all results from previous steps using this format:
|
||||
|
||||
To ensure one comment: find an existing claude[bot] comment and update it, or create one if none exists.
|
||||
Delete any duplicate claude[bot] comments.
|
||||
|
||||
```
|
||||
gh api repos/${{ github.repository }}/issues/${{ github.event.pull_request.number }}/comments --jq '.[] | select(.user.login == "claude[bot]") | .id' | head -1
|
||||
```
|
||||
|
||||
Format:
|
||||
## PR Review Summary
|
||||
### Prek Checks
|
||||
### Code Review
|
||||
### Duplicate Detection
|
||||
### Test Coverage
|
||||
---
|
||||
*Last updated: <timestamp>*
|
||||
</step>
|
||||
|
||||
<step name="simplify">
|
||||
Run /simplify to review recently changed code for reuse, quality, and efficiency opportunities.
|
||||
If improvements are found, commit with "refactor: simplify <description>" and push.
|
||||
Only make behavior-preserving changes.
|
||||
</step>
|
||||
|
||||
<step name="merge_optimization_prs">
|
||||
Check for open PRs from codeflash-ai[bot]:
|
||||
`gh pr list --author "codeflash-ai[bot]" --state open --json number,title,headRefName,createdAt,mergeable`
|
||||
|
|
@ -165,12 +213,15 @@ jobs:
|
|||
- All findings are in a single summary comment (no inline review comments were created)
|
||||
- If fixes were made, they were verified with prek
|
||||
</verification>
|
||||
claude_args: '--model us.anthropic.claude-opus-4-6-v1 --allowedTools "Bash(gh pr comment:*),Bash(gh pr diff:*),Bash(gh pr view:*),Bash(gh pr list:*),Bash(gh pr checks:*),Bash(gh pr merge:*),Bash(gh issue view:*),Bash(gh issue list:*),Bash(gh api:*),Bash(uv run prek *),Bash(uv run mypy *),Bash(uv run coverage *),Bash(uv run pytest *),Bash(git status*),Bash(git add *),Bash(git commit *),Bash(git push*),Bash(git diff *),Bash(git checkout *),Read,Glob,Grep,Edit,Skill"'
|
||||
claude_args: '--model us.anthropic.claude-sonnet-4-6 --allowedTools "Bash(gh pr comment:*),Bash(gh pr diff:*),Bash(gh pr view:*),Bash(gh pr list:*),Bash(gh pr checks:*),Bash(gh pr merge:*),Bash(gh pr close:*),Bash(gh issue view:*),Bash(gh issue list:*),Bash(gh api:*),Bash(uv run prek *),Bash(uv run mypy *),Bash(uv run coverage *),Bash(uv run pytest *),Bash(git status*),Bash(git add *),Bash(git commit *),Bash(git push*),Bash(git diff *),Bash(git checkout *),Read,Glob,Grep,Edit"'
|
||||
additional_permissions: |
|
||||
actions: read
|
||||
|
||||
# @claude mentions (can edit and push) - restricted to maintainers only
|
||||
claude-mention:
|
||||
concurrency:
|
||||
group: claude-mention-${{ github.event.issue.number || github.event.pull_request.number || github.run_id }}
|
||||
cancel-in-progress: false
|
||||
if: |
|
||||
(
|
||||
github.event_name == 'issue_comment' &&
|
||||
|
|
@ -240,6 +291,6 @@ jobs:
|
|||
uses: anthropics/claude-code-action@v1
|
||||
with:
|
||||
use_bedrock: "true"
|
||||
claude_args: '--model us.anthropic.claude-opus-4-6-v1 --allowedTools "Read,Edit,Write,Glob,Grep,Bash(git status*),Bash(git diff*),Bash(git add *),Bash(git commit *),Bash(git push*),Bash(git log*),Bash(git merge*),Bash(git fetch*),Bash(git checkout*),Bash(git branch*),Bash(uv run prek *),Bash(prek *),Bash(uv run ruff *),Bash(uv run pytest *),Bash(uv run mypy *),Bash(uv run coverage *),Bash(gh pr comment*),Bash(gh pr view*),Bash(gh pr diff*),Bash(gh pr merge*),Bash(gh pr close*)"'
|
||||
claude_args: '--model us.anthropic.claude-sonnet-4-6 --allowedTools "Read,Edit,Write,Glob,Grep,Bash(git status*),Bash(git diff*),Bash(git add *),Bash(git commit *),Bash(git push*),Bash(git log*),Bash(git merge*),Bash(git fetch*),Bash(git checkout*),Bash(git branch*),Bash(uv run prek *),Bash(prek *),Bash(uv run ruff *),Bash(uv run pytest *),Bash(uv run mypy *),Bash(uv run coverage *),Bash(gh pr comment*),Bash(gh pr view*),Bash(gh pr diff*),Bash(gh pr merge*),Bash(gh pr close*)"'
|
||||
additional_permissions: |
|
||||
actions: read
|
||||
|
|
|
|||
119
.github/workflows/duplicate-code-detector.yml
vendored
119
.github/workflows/duplicate-code-detector.yml
vendored
|
|
@ -1,119 +0,0 @@
|
|||
name: Duplicate Code Detector
|
||||
|
||||
on:
|
||||
workflow_dispatch:
|
||||
pull_request:
|
||||
types: [opened, synchronize]
|
||||
|
||||
jobs:
|
||||
detect-duplicates:
|
||||
if: github.event.pull_request.head.repo.full_name == github.repository || github.event_name == 'workflow_dispatch'
|
||||
runs-on: ubuntu-latest
|
||||
permissions:
|
||||
contents: read
|
||||
pull-requests: write
|
||||
issues: write
|
||||
id-token: write
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
ref: ${{ github.event.pull_request.head.ref || github.ref }}
|
||||
|
||||
- name: Configure AWS Credentials
|
||||
uses: aws-actions/configure-aws-credentials@v4
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_ROLE_TO_ASSUME }}
|
||||
aws-region: ${{ secrets.AWS_REGION }}
|
||||
|
||||
- name: Get changed source files
|
||||
id: changed-files
|
||||
run: |
|
||||
FILES=$(git diff --name-only origin/main...HEAD -- '*.py' '*.js' '*.ts' '*.java' \
|
||||
| grep -v -E '(test_|_test\.(py|js|ts)|\.test\.(js|ts)|\.spec\.(js|ts)|conftest\.py|/tests/|/test/|/__tests__/)' \
|
||||
| grep -v -E '^(\.github/|code_to_optimize/|\.tessl/|node_modules/)' \
|
||||
|| true)
|
||||
if [ -z "$FILES" ]; then
|
||||
echo "files=" >> "$GITHUB_OUTPUT"
|
||||
echo "No changed source files to analyze."
|
||||
else
|
||||
echo "files<<EOF" >> "$GITHUB_OUTPUT"
|
||||
echo "$FILES" >> "$GITHUB_OUTPUT"
|
||||
echo "EOF" >> "$GITHUB_OUTPUT"
|
||||
echo "Changed files:"
|
||||
echo "$FILES"
|
||||
fi
|
||||
|
||||
- name: Run Claude Code
|
||||
if: steps.changed-files.outputs.files != ''
|
||||
uses: anthropics/claude-code-action@v1
|
||||
with:
|
||||
use_bedrock: "true"
|
||||
use_sticky_comment: true
|
||||
allowed_bots: "claude[bot],codeflash-ai[bot]"
|
||||
claude_args: '--allowedTools "Read,Glob,Grep,Bash(git diff:*),Bash(git log:*),Bash(git show:*),Bash(wc *),Bash(gh pr comment:*)"'
|
||||
prompt: |
|
||||
REPO: ${{ github.repository }}
|
||||
PR NUMBER: ${{ github.event.pull_request.number }}
|
||||
|
||||
You are a duplicate code detector for a multi-language codebase (Python, JavaScript, TypeScript, Java). Check whether this PR introduces code that duplicates logic already present elsewhere in the repository — including across languages. Focus on finding true duplicates, not just similar-looking code.
|
||||
|
||||
## Changed files
|
||||
|
||||
```
|
||||
${{ steps.changed-files.outputs.files }}
|
||||
```
|
||||
|
||||
## Steps
|
||||
|
||||
1. **Read changed files.** For each file above, read it and identify functions or methods that were added or substantially modified (longer than 5 lines).
|
||||
|
||||
2. **Search for duplicates.** For each function, use Grep to search the codebase for:
|
||||
- The same function name defined elsewhere (`def function_name` for Python, `function function_name` / `const function_name` / `module.exports` for the JS files under `packages/`)
|
||||
- 2-3 distinctive operations from the body (specific API calls, algorithm patterns, string literals, exception types) — this catches duplicates that have different names but implement the same logic
|
||||
|
||||
3. **Cross-module check.** This codebase has parallel Python modules under `languages/python/`, `languages/javascript/`, and `languages/java/` that handle the same concerns (parsing, code replacement, test running, etc.) for different target languages. It also has a JS runtime under `packages/codeflash/runtime/` and a Java runtime under `codeflash-java-runtime/`. When a changed file is under one of these areas, also search the others for equivalent logic. For example:
|
||||
- `languages/javascript/code_replacer.py` and `languages/python/static_analysis/code_replacer.py` both handle code replacement — shared logic should be extracted
|
||||
- Shared concepts (AST traversal, scope analysis, import resolution, test running) are prime candidates for duplication across these modules
|
||||
|
||||
4. **Compare candidates.** When a Grep hit looks promising (not just a shared import or call site), read the full function and compare semantics. Flag it only if it matches one of these patterns:
|
||||
- **Same function in two modules** — a function with the same or very similar body exists in another module. One should import from the other instead (within the same language).
|
||||
- **Shared logic across sibling files** — the same helper logic repeated in files within the same package. Should be extracted to a common module.
|
||||
- **Repeated pattern across classes** — multiple classes implement the same logic inline (e.g., identical traversal, identical validation). Should be a mixin or shared helper.
|
||||
- **Cross-module reimplementation** — the same algorithm or utility implemented in both `languages/python/` and `languages/javascript/` (both are Python) or between Python orchestration code and JS runtime code in `packages/`. Note: some duplication is unavoidable (each target language needs its own parser, for example). Only flag cases where the logic is genuinely shared or where one module could import from the other.
|
||||
|
||||
5. **Report findings.** Post a single PR comment. Report at most 5 findings.
|
||||
|
||||
**If duplicates found**, for each one:
|
||||
- **Confidence**: HIGH (identical or near-identical logic) / MEDIUM (same intent, minor differences worth reviewing)
|
||||
- **Locations**: `file_path:line_number` for both the new and existing code
|
||||
- **What's duplicated**: One sentence describing the shared logic
|
||||
- **Suggestion**: How to consolidate — import from canonical location, extract to shared module, create a mixin. For cross-module duplicates (between language directories or Python↔JS runtime), just flag it for a tech lead to review rather than prescribing a specific fix.
|
||||
|
||||
**If no duplicates found**, post a comment that just says "No duplicates detected." so the sticky comment gets updated.
|
||||
|
||||
## Examples (illustrative — these are past cases, some already resolved)
|
||||
|
||||
**IS a duplicate (HIGH):** A 12-line `is_build_output_dir()` function was defined identically in two modules (`setup/detector.py` and `code_utils/config_js.py`). Fix: delete one, import from the other.
|
||||
|
||||
**IS a duplicate (MEDIUM):** `is_assignment_used()` was implemented separately in two context files with the same logic. Fix: move to a shared module, import from both call sites.
|
||||
|
||||
**IS a duplicate (MEDIUM, cross-module):** `normalize_path()` implemented in both `languages/python/support.py` and `languages/javascript/support.py` with identical logic. Flagging for tech lead review — should likely be extracted to `languages/base.py` or a shared utility.
|
||||
|
||||
**NOT a duplicate:** Two classes each define a `visit()` method that traverses an AST, but they handle different node types and produce different outputs. This is intentional polymorphism.
|
||||
|
||||
**NOT a duplicate (cross-module):** `languages/python/static_analysis/code_extractor.py` and `languages/javascript/parse.py` both extract functions from source code, but they use fundamentally different parsing strategies (Python AST vs tree-sitter). The logic is necessarily different.
|
||||
|
||||
## DO NOT report
|
||||
|
||||
- Standard boilerplate (`__init__`, `__repr__`, `__str__`, `__eq__`, simple property accessors, constructors)
|
||||
- Functions under 5 lines
|
||||
- Config/setup code that naturally has similar structure
|
||||
- Intentional polymorphism (same method name, genuinely different behavior)
|
||||
- Test files, conftest files, spec files
|
||||
- Import statements and logging setup
|
||||
- Files under `.github/`, `code_to_optimize/`, `.tessl/`
|
||||
- Code across language modules that must differ due to target-language semantics (parsers, AST node types, runtime-specific APIs)
|
||||
|
||||
Do NOT create issues or edit any files. Only post a PR comment.
|
||||
|
|
@ -1,126 +0,0 @@
|
|||
"""Code deduplication utilities using language-specific normalizers.
|
||||
|
||||
This module provides functions to normalize code, generate fingerprints,
|
||||
and detect duplicate code segments across different programming languages.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import re
|
||||
|
||||
from codeflash.code_utils.normalizers import get_normalizer
|
||||
from codeflash.languages import current_language
|
||||
|
||||
|
||||
def normalize_code(
|
||||
code: str, remove_docstrings: bool = True, return_ast_dump: bool = False, language: str | None = None
|
||||
) -> str:
|
||||
"""Normalize code by parsing, cleaning, and normalizing variable names.
|
||||
|
||||
Function names, class names, and parameters are preserved.
|
||||
|
||||
Args:
|
||||
code: Source code as string
|
||||
remove_docstrings: Whether to remove docstrings (Python only)
|
||||
return_ast_dump: Return AST dump instead of unparsed code (Python only)
|
||||
language: Language of the code. If None, uses the current session language.
|
||||
|
||||
Returns:
|
||||
Normalized code as string
|
||||
|
||||
"""
|
||||
if language is None:
|
||||
language = current_language().value
|
||||
|
||||
try:
|
||||
normalizer = get_normalizer(language)
|
||||
|
||||
if return_ast_dump:
|
||||
return normalizer.normalize_for_hash(code)
|
||||
# Only Python normalizer accepts remove_docstrings; pass it via **kwargs
|
||||
# so non-Python normalizers (which don't accept it) still work
|
||||
try:
|
||||
return normalizer.normalize(code, remove_docstrings=remove_docstrings)
|
||||
except TypeError:
|
||||
return normalizer.normalize(code)
|
||||
except ValueError:
|
||||
# Unknown language - fall back to basic normalization
|
||||
return _basic_normalize(code)
|
||||
except Exception:
|
||||
# Parsing error - try other languages or fall back
|
||||
if language == "python":
|
||||
# Try JavaScript as fallback
|
||||
try:
|
||||
js_normalizer = get_normalizer("javascript")
|
||||
js_result = js_normalizer.normalize(code)
|
||||
if js_result != _basic_normalize(code):
|
||||
return js_result
|
||||
except Exception:
|
||||
pass
|
||||
return _basic_normalize(code)
|
||||
|
||||
|
||||
def _basic_normalize(code: str) -> str:
|
||||
"""Basic normalization: remove comments and normalize whitespace."""
|
||||
# Remove single-line comments (// and #)
|
||||
code = re.sub(r"//.*$", "", code, flags=re.MULTILINE)
|
||||
code = re.sub(r"#.*$", "", code, flags=re.MULTILINE)
|
||||
# Remove multi-line comments
|
||||
code = re.sub(r"/\*.*?\*/", "", code, flags=re.DOTALL)
|
||||
code = re.sub(r'""".*?"""', "", code, flags=re.DOTALL)
|
||||
code = re.sub(r"'''.*?'''", "", code, flags=re.DOTALL)
|
||||
# Normalize whitespace
|
||||
return " ".join(code.split())
|
||||
|
||||
|
||||
def get_code_fingerprint(code: str, language: str | None = None) -> str:
|
||||
"""Generate a fingerprint for normalized code.
|
||||
|
||||
Args:
|
||||
code: Source code
|
||||
language: Language of the code. If None, uses the current session language.
|
||||
|
||||
Returns:
|
||||
SHA-256 hash of normalized code
|
||||
|
||||
"""
|
||||
if language is None:
|
||||
language = current_language().value
|
||||
|
||||
try:
|
||||
normalizer = get_normalizer(language)
|
||||
return normalizer.get_fingerprint(code)
|
||||
except ValueError:
|
||||
# Unknown language - use basic normalization
|
||||
normalized = _basic_normalize(code)
|
||||
return hashlib.sha256(normalized.encode()).hexdigest()
|
||||
|
||||
|
||||
def are_codes_duplicate(code1: str, code2: str, language: str | None = None) -> bool:
|
||||
"""Check if two code segments are duplicates after normalization.
|
||||
|
||||
Args:
|
||||
code1: First code segment
|
||||
code2: Second code segment
|
||||
language: Language of the code. If None, uses the current session language.
|
||||
|
||||
Returns:
|
||||
True if codes are structurally identical (ignoring local variable names)
|
||||
|
||||
"""
|
||||
if language is None:
|
||||
language = current_language().value
|
||||
|
||||
try:
|
||||
normalizer = get_normalizer(language)
|
||||
return normalizer.are_duplicates(code1, code2)
|
||||
except ValueError:
|
||||
# Unknown language - use basic comparison
|
||||
return _basic_normalize(code1) == _basic_normalize(code2)
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
# Re-export for backward compatibility
|
||||
__all__ = ["are_codes_duplicate", "get_code_fingerprint", "normalize_code"]
|
||||
|
|
@ -1,106 +0,0 @@
|
|||
"""Code normalizers for different programming languages.
|
||||
|
||||
This module provides language-specific code normalizers that transform source code
|
||||
into canonical forms for duplicate detection. The normalizers:
|
||||
- Replace local variable names with canonical forms (var_0, var_1, etc.)
|
||||
- Preserve function names, class names, parameters, and imports
|
||||
- Remove or normalize comments and docstrings
|
||||
- Produce consistent output for structurally identical code
|
||||
|
||||
Usage:
|
||||
>>> normalizer = get_normalizer("python")
|
||||
>>> normalized = normalizer.normalize(code)
|
||||
>>> fingerprint = normalizer.get_fingerprint(code)
|
||||
>>> are_same = normalizer.are_duplicates(code1, code2)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from codeflash.code_utils.normalizers.base import CodeNormalizer
|
||||
from codeflash.code_utils.normalizers.javascript import JavaScriptNormalizer, TypeScriptNormalizer
|
||||
from codeflash.code_utils.normalizers.python import PythonNormalizer
|
||||
|
||||
__all__ = [
|
||||
"CodeNormalizer",
|
||||
"JavaScriptNormalizer",
|
||||
"PythonNormalizer",
|
||||
"TypeScriptNormalizer",
|
||||
"get_normalizer",
|
||||
"get_normalizer_for_extension",
|
||||
]
|
||||
|
||||
# Registry of normalizers by language
|
||||
_NORMALIZERS: dict[str, type[CodeNormalizer]] = {
|
||||
"python": PythonNormalizer,
|
||||
"javascript": JavaScriptNormalizer,
|
||||
"typescript": TypeScriptNormalizer,
|
||||
}
|
||||
|
||||
# Singleton cache for normalizer instances
|
||||
_normalizer_instances: dict[str, CodeNormalizer] = {}
|
||||
|
||||
|
||||
def get_normalizer(language: str) -> CodeNormalizer:
|
||||
"""Get a code normalizer for the specified language.
|
||||
|
||||
Args:
|
||||
language: Language name ('python', 'javascript', 'typescript')
|
||||
|
||||
Returns:
|
||||
CodeNormalizer instance for the language
|
||||
|
||||
Raises:
|
||||
ValueError: If no normalizer exists for the language
|
||||
|
||||
"""
|
||||
language = language.lower()
|
||||
|
||||
# Check cache first
|
||||
if language in _normalizer_instances:
|
||||
return _normalizer_instances[language]
|
||||
|
||||
# Get normalizer class
|
||||
if language not in _NORMALIZERS:
|
||||
supported = ", ".join(sorted(_NORMALIZERS.keys()))
|
||||
msg = f"No normalizer available for language '{language}'. Supported: {supported}"
|
||||
raise ValueError(msg)
|
||||
|
||||
# Create and cache instance
|
||||
normalizer = _NORMALIZERS[language]()
|
||||
_normalizer_instances[language] = normalizer
|
||||
return normalizer
|
||||
|
||||
|
||||
def get_normalizer_for_extension(extension: str) -> CodeNormalizer | None:
|
||||
"""Get a code normalizer based on file extension.
|
||||
|
||||
Args:
|
||||
extension: File extension including dot (e.g., '.py', '.js')
|
||||
|
||||
Returns:
|
||||
CodeNormalizer instance if found, None otherwise
|
||||
|
||||
"""
|
||||
extension = extension.lower()
|
||||
if not extension.startswith("."):
|
||||
extension = f".{extension}"
|
||||
|
||||
for language in _NORMALIZERS:
|
||||
normalizer = get_normalizer(language)
|
||||
if extension in normalizer.supported_extensions:
|
||||
return normalizer
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def register_normalizer(language: str, normalizer_class: type[CodeNormalizer]) -> None:
|
||||
"""Register a new normalizer for a language.
|
||||
|
||||
Args:
|
||||
language: Language name
|
||||
normalizer_class: CodeNormalizer subclass
|
||||
|
||||
"""
|
||||
_NORMALIZERS[language.lower()] = normalizer_class
|
||||
# Clear cached instance if it exists
|
||||
_normalizer_instances.pop(language.lower(), None)
|
||||
|
|
@ -1,104 +0,0 @@
|
|||
"""Abstract base class for code normalizers.
|
||||
|
||||
Code normalizers transform source code into a canonical form for duplicate detection.
|
||||
They normalize variable names, remove comments/docstrings, and produce consistent output
|
||||
that can be compared across different implementations of the same algorithm.
|
||||
"""
|
||||
|
||||
# TODO:{claude} move to base.py in language folder
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
|
||||
class CodeNormalizer(ABC):
|
||||
"""Abstract base class for language-specific code normalizers.
|
||||
|
||||
Subclasses must implement the normalize() method for their specific language.
|
||||
The normalization should:
|
||||
- Normalize local variable names to canonical forms (var_0, var_1, etc.)
|
||||
- Preserve function names, class names, parameters, and imports
|
||||
- Remove or normalize comments and docstrings
|
||||
- Produce consistent output for structurally identical code
|
||||
|
||||
Example:
|
||||
>>> normalizer = PythonNormalizer()
|
||||
>>> code1 = "def foo(x): y = x + 1; return y"
|
||||
>>> code2 = "def foo(x): z = x + 1; return z"
|
||||
>>> normalizer.normalize(code1) == normalizer.normalize(code2)
|
||||
True
|
||||
|
||||
"""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def language(self) -> str:
|
||||
"""Return the language this normalizer handles."""
|
||||
...
|
||||
|
||||
@property
|
||||
def supported_extensions(self) -> tuple[str, ...]:
|
||||
"""Return file extensions this normalizer can handle."""
|
||||
return ()
|
||||
|
||||
@abstractmethod
|
||||
def normalize(self, code: str) -> str:
|
||||
"""Normalize code to a canonical form for comparison.
|
||||
|
||||
Args:
|
||||
code: Source code to normalize
|
||||
|
||||
Returns:
|
||||
Normalized representation of the code
|
||||
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def normalize_for_hash(self, code: str) -> str:
|
||||
"""Normalize code optimized for hashing/fingerprinting.
|
||||
|
||||
This may return a more compact representation than normalize().
|
||||
|
||||
Args:
|
||||
code: Source code to normalize
|
||||
|
||||
Returns:
|
||||
Normalized representation suitable for hashing
|
||||
|
||||
"""
|
||||
...
|
||||
|
||||
def are_duplicates(self, code1: str, code2: str) -> bool:
|
||||
"""Check if two code segments are duplicates after normalization.
|
||||
|
||||
Args:
|
||||
code1: First code segment
|
||||
code2: Second code segment
|
||||
|
||||
Returns:
|
||||
True if codes are structurally identical
|
||||
|
||||
"""
|
||||
try:
|
||||
normalized1 = self.normalize_for_hash(code1)
|
||||
normalized2 = self.normalize_for_hash(code2)
|
||||
except Exception:
|
||||
return False
|
||||
else:
|
||||
return normalized1 == normalized2
|
||||
|
||||
def get_fingerprint(self, code: str) -> str:
|
||||
"""Generate a fingerprint hash for normalized code.
|
||||
|
||||
Args:
|
||||
code: Source code to fingerprint
|
||||
|
||||
Returns:
|
||||
SHA-256 hash of normalized code
|
||||
|
||||
"""
|
||||
import hashlib
|
||||
|
||||
normalized = self.normalize_for_hash(code)
|
||||
return hashlib.sha256(normalized.encode()).hexdigest()
|
||||
|
|
@ -1,290 +0,0 @@
|
|||
"""JavaScript/TypeScript code normalizer using tree-sitter."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from codeflash.code_utils.normalizers.base import CodeNormalizer
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from tree_sitter import Node
|
||||
|
||||
|
||||
# TODO:{claude} move to language support directory to keep the directory structure clean
|
||||
class JavaScriptVariableNormalizer:
|
||||
"""Normalizes JavaScript/TypeScript code for duplicate detection using tree-sitter.
|
||||
|
||||
Normalizes local variable names while preserving function names, class names,
|
||||
parameters, and imported names.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.var_counter = 0
|
||||
self.var_mapping: dict[str, str] = {}
|
||||
self.preserved_names: set[str] = set()
|
||||
# Common JavaScript builtins
|
||||
self.builtins = {
|
||||
"console",
|
||||
"window",
|
||||
"document",
|
||||
"Math",
|
||||
"JSON",
|
||||
"Object",
|
||||
"Array",
|
||||
"String",
|
||||
"Number",
|
||||
"Boolean",
|
||||
"Date",
|
||||
"RegExp",
|
||||
"Error",
|
||||
"Promise",
|
||||
"Map",
|
||||
"Set",
|
||||
"WeakMap",
|
||||
"WeakSet",
|
||||
"Symbol",
|
||||
"Proxy",
|
||||
"Reflect",
|
||||
"undefined",
|
||||
"null",
|
||||
"NaN",
|
||||
"Infinity",
|
||||
"globalThis",
|
||||
"parseInt",
|
||||
"parseFloat",
|
||||
"isNaN",
|
||||
"isFinite",
|
||||
"eval",
|
||||
"setTimeout",
|
||||
"setInterval",
|
||||
"clearTimeout",
|
||||
"clearInterval",
|
||||
"fetch",
|
||||
"require",
|
||||
"module",
|
||||
"exports",
|
||||
"process",
|
||||
"__dirname",
|
||||
"__filename",
|
||||
"Buffer",
|
||||
}
|
||||
|
||||
def get_normalized_name(self, name: str) -> str:
|
||||
"""Get or create normalized name for a variable."""
|
||||
if name in self.builtins or name in self.preserved_names:
|
||||
return name
|
||||
if name not in self.var_mapping:
|
||||
self.var_mapping[name] = f"var_{self.var_counter}"
|
||||
self.var_counter += 1
|
||||
return self.var_mapping[name]
|
||||
|
||||
def collect_preserved_names(self, node: Node, source_code: bytes) -> None:
|
||||
"""Collect names that should be preserved (function names, class names, imports, params)."""
|
||||
# Function declarations and expressions - preserve the function name
|
||||
if node.type in ("function_declaration", "function_expression", "method_definition", "arrow_function"):
|
||||
name_node = node.child_by_field_name("name")
|
||||
if name_node:
|
||||
self.preserved_names.add(source_code[name_node.start_byte : name_node.end_byte].decode("utf-8"))
|
||||
# Preserve parameters
|
||||
params_node = node.child_by_field_name("parameters") or node.child_by_field_name("parameter")
|
||||
if params_node:
|
||||
self._collect_parameter_names(params_node, source_code)
|
||||
|
||||
# Class declarations
|
||||
elif node.type == "class_declaration":
|
||||
name_node = node.child_by_field_name("name")
|
||||
if name_node:
|
||||
self.preserved_names.add(source_code[name_node.start_byte : name_node.end_byte].decode("utf-8"))
|
||||
|
||||
# Import declarations
|
||||
elif node.type in ("import_statement", "import_declaration"):
|
||||
for child in node.children:
|
||||
if child.type == "import_clause":
|
||||
self._collect_import_names(child, source_code)
|
||||
elif child.type == "identifier":
|
||||
self.preserved_names.add(source_code[child.start_byte : child.end_byte].decode("utf-8"))
|
||||
|
||||
# Recurse
|
||||
for child in node.children:
|
||||
self.collect_preserved_names(child, source_code)
|
||||
|
||||
def _collect_parameter_names(self, node: Node, source_code: bytes) -> None:
|
||||
"""Collect parameter names from a parameters node."""
|
||||
for child in node.children:
|
||||
if child.type == "identifier":
|
||||
self.preserved_names.add(source_code[child.start_byte : child.end_byte].decode("utf-8"))
|
||||
elif child.type in ("required_parameter", "optional_parameter", "rest_parameter"):
|
||||
pattern_node = child.child_by_field_name("pattern")
|
||||
if pattern_node and pattern_node.type == "identifier":
|
||||
self.preserved_names.add(
|
||||
source_code[pattern_node.start_byte : pattern_node.end_byte].decode("utf-8")
|
||||
)
|
||||
# Recurse for nested patterns
|
||||
self._collect_parameter_names(child, source_code)
|
||||
|
||||
def _collect_import_names(self, node: Node, source_code: bytes) -> None:
|
||||
"""Collect imported names from import clause."""
|
||||
for child in node.children:
|
||||
if child.type == "identifier":
|
||||
self.preserved_names.add(source_code[child.start_byte : child.end_byte].decode("utf-8"))
|
||||
elif child.type == "import_specifier":
|
||||
# Get the local name (alias or original)
|
||||
alias_node = child.child_by_field_name("alias")
|
||||
name_node = child.child_by_field_name("name")
|
||||
if alias_node:
|
||||
self.preserved_names.add(source_code[alias_node.start_byte : alias_node.end_byte].decode("utf-8"))
|
||||
elif name_node:
|
||||
self.preserved_names.add(source_code[name_node.start_byte : name_node.end_byte].decode("utf-8"))
|
||||
self._collect_import_names(child, source_code)
|
||||
|
||||
def normalize_tree(self, node: Node, source_code: bytes) -> str:
|
||||
"""Normalize the AST tree to a string representation for comparison."""
|
||||
parts: list[str] = []
|
||||
self._normalize_node(node, source_code, parts)
|
||||
return " ".join(parts)
|
||||
|
||||
def _normalize_node(self, node: Node, source_code: bytes, parts: list[str]) -> None:
|
||||
"""Recursively normalize a node."""
|
||||
# Skip comments
|
||||
if node.type in ("comment", "line_comment", "block_comment"):
|
||||
return
|
||||
|
||||
# Handle identifiers - normalize variable names
|
||||
if node.type == "identifier":
|
||||
name = source_code[node.start_byte : node.end_byte].decode("utf-8")
|
||||
normalized = self.get_normalized_name(name)
|
||||
parts.append(normalized)
|
||||
return
|
||||
|
||||
# Handle type identifiers (TypeScript) - preserve as-is
|
||||
if node.type == "type_identifier":
|
||||
parts.append(source_code[node.start_byte : node.end_byte].decode("utf-8"))
|
||||
return
|
||||
|
||||
# Handle string literals - normalize to placeholder
|
||||
if node.type in ("string", "template_string", "string_fragment"):
|
||||
parts.append('"STR"')
|
||||
return
|
||||
|
||||
# Handle number literals - normalize to placeholder
|
||||
if node.type == "number":
|
||||
parts.append("NUM")
|
||||
return
|
||||
|
||||
# For leaf nodes, output the node type
|
||||
if len(node.children) == 0:
|
||||
text = source_code[node.start_byte : node.end_byte].decode("utf-8")
|
||||
parts.append(text)
|
||||
return
|
||||
|
||||
# Output node type for structure
|
||||
parts.append(f"({node.type}")
|
||||
|
||||
# Recurse into children
|
||||
for child in node.children:
|
||||
self._normalize_node(child, source_code, parts)
|
||||
|
||||
parts.append(")")
|
||||
|
||||
|
||||
def _basic_normalize(code: str) -> str:
|
||||
"""Basic normalization: remove comments and normalize whitespace."""
|
||||
# Remove single-line comments
|
||||
code = re.sub(r"//.*$", "", code, flags=re.MULTILINE)
|
||||
# Remove multi-line comments
|
||||
code = re.sub(r"/\*.*?\*/", "", code, flags=re.DOTALL)
|
||||
# Normalize whitespace
|
||||
return " ".join(code.split())
|
||||
|
||||
|
||||
class JavaScriptNormalizer(CodeNormalizer):
|
||||
"""JavaScript code normalizer using tree-sitter.
|
||||
|
||||
Normalizes JavaScript code by:
|
||||
- Replacing local variable names with canonical forms (var_0, var_1, etc.)
|
||||
- Preserving function names, class names, parameters, and imports
|
||||
- Removing comments
|
||||
- Normalizing string and number literals
|
||||
"""
|
||||
|
||||
@property
|
||||
def language(self) -> str:
|
||||
"""Return the language this normalizer handles."""
|
||||
return "javascript"
|
||||
|
||||
@property
|
||||
def supported_extensions(self) -> tuple[str, ...]:
|
||||
"""Return file extensions this normalizer can handle."""
|
||||
return (".js", ".jsx", ".mjs", ".cjs")
|
||||
|
||||
def _get_tree_sitter_language(self) -> str:
|
||||
"""Get the tree-sitter language identifier."""
|
||||
return "javascript"
|
||||
|
||||
def normalize(self, code: str) -> str:
|
||||
"""Normalize JavaScript code to a canonical form.
|
||||
|
||||
Args:
|
||||
code: JavaScript source code to normalize
|
||||
|
||||
Returns:
|
||||
Normalized representation of the code
|
||||
|
||||
"""
|
||||
try:
|
||||
from codeflash.languages.javascript.treesitter import TreeSitterAnalyzer, TreeSitterLanguage
|
||||
|
||||
lang_map = {"javascript": TreeSitterLanguage.JAVASCRIPT, "typescript": TreeSitterLanguage.TYPESCRIPT}
|
||||
lang = lang_map.get(self._get_tree_sitter_language(), TreeSitterLanguage.JAVASCRIPT)
|
||||
analyzer = TreeSitterAnalyzer(lang)
|
||||
tree = analyzer.parse(code)
|
||||
|
||||
if tree.root_node.has_error:
|
||||
return _basic_normalize(code)
|
||||
|
||||
normalizer = JavaScriptVariableNormalizer()
|
||||
source_bytes = code.encode("utf-8")
|
||||
|
||||
# First pass: collect preserved names
|
||||
normalizer.collect_preserved_names(tree.root_node, source_bytes)
|
||||
|
||||
# Second pass: normalize and build representation
|
||||
return normalizer.normalize_tree(tree.root_node, source_bytes)
|
||||
except Exception:
|
||||
return _basic_normalize(code)
|
||||
|
||||
def normalize_for_hash(self, code: str) -> str:
|
||||
"""Normalize JavaScript code optimized for hashing.
|
||||
|
||||
For JavaScript, this is the same as normalize().
|
||||
|
||||
Args:
|
||||
code: JavaScript source code to normalize
|
||||
|
||||
Returns:
|
||||
Normalized representation suitable for hashing
|
||||
|
||||
"""
|
||||
return self.normalize(code)
|
||||
|
||||
|
||||
class TypeScriptNormalizer(JavaScriptNormalizer):
|
||||
"""TypeScript code normalizer using tree-sitter.
|
||||
|
||||
Inherits from JavaScriptNormalizer and overrides language-specific settings.
|
||||
"""
|
||||
|
||||
@property
|
||||
def language(self) -> str:
|
||||
"""Return the language this normalizer handles."""
|
||||
return "typescript"
|
||||
|
||||
@property
|
||||
def supported_extensions(self) -> tuple[str, ...]:
|
||||
"""Return file extensions this normalizer can handle."""
|
||||
return (".ts", ".tsx", ".mts", ".cts")
|
||||
|
||||
def _get_tree_sitter_language(self) -> str:
|
||||
"""Get the tree-sitter language identifier."""
|
||||
return "typescript"
|
||||
|
|
@ -1,226 +0,0 @@
|
|||
"""Python code normalizer using AST transformation."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import ast
|
||||
|
||||
from codeflash.code_utils.normalizers.base import CodeNormalizer
|
||||
|
||||
|
||||
class VariableNormalizer(ast.NodeTransformer):
|
||||
"""Normalizes only local variable names in AST to canonical forms like var_0, var_1, etc.
|
||||
|
||||
Preserves function names, class names, parameters, built-ins, and imported names.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.var_counter = 0
|
||||
self.var_mapping: dict[str, str] = {}
|
||||
self.scope_stack: list[dict] = []
|
||||
self.builtins = set(dir(__builtins__))
|
||||
self.imports: set[str] = set()
|
||||
self.global_vars: set[str] = set()
|
||||
self.nonlocal_vars: set[str] = set()
|
||||
self.parameters: set[str] = set()
|
||||
|
||||
def enter_scope(self) -> None:
|
||||
"""Enter a new scope (function/class)."""
|
||||
self.scope_stack.append(
|
||||
{"var_mapping": dict(self.var_mapping), "var_counter": self.var_counter, "parameters": set(self.parameters)}
|
||||
)
|
||||
|
||||
def exit_scope(self) -> None:
|
||||
"""Exit current scope and restore parent scope."""
|
||||
if self.scope_stack:
|
||||
scope = self.scope_stack.pop()
|
||||
self.var_mapping = scope["var_mapping"]
|
||||
self.var_counter = scope["var_counter"]
|
||||
self.parameters = scope["parameters"]
|
||||
|
||||
def get_normalized_name(self, name: str) -> str:
|
||||
"""Get or create normalized name for a variable."""
|
||||
if (
|
||||
name in self.builtins
|
||||
or name in self.imports
|
||||
or name in self.global_vars
|
||||
or name in self.nonlocal_vars
|
||||
or name in self.parameters
|
||||
):
|
||||
return name
|
||||
|
||||
if name not in self.var_mapping:
|
||||
self.var_mapping[name] = f"var_{self.var_counter}"
|
||||
self.var_counter += 1
|
||||
return self.var_mapping[name]
|
||||
|
||||
def visit_Import(self, node: ast.Import) -> ast.Import:
|
||||
"""Track imported names."""
|
||||
for alias in node.names:
|
||||
name = alias.asname if alias.asname else alias.name
|
||||
self.imports.add(name.split(".")[0])
|
||||
return node
|
||||
|
||||
def visit_ImportFrom(self, node: ast.ImportFrom) -> ast.ImportFrom:
|
||||
"""Track imported names from modules."""
|
||||
for alias in node.names:
|
||||
name = alias.asname if alias.asname else alias.name
|
||||
self.imports.add(name)
|
||||
return node
|
||||
|
||||
def visit_Global(self, node: ast.Global) -> ast.Global:
|
||||
"""Track global variable declarations."""
|
||||
self.global_vars.update(node.names)
|
||||
return node
|
||||
|
||||
def visit_Nonlocal(self, node: ast.Nonlocal) -> ast.Nonlocal:
|
||||
"""Track nonlocal variable declarations."""
|
||||
self.nonlocal_vars.update(node.names)
|
||||
return node
|
||||
|
||||
def visit_FunctionDef(self, node: ast.FunctionDef) -> ast.FunctionDef:
|
||||
"""Process function but keep function name and parameters unchanged."""
|
||||
self.enter_scope()
|
||||
|
||||
for arg in node.args.args:
|
||||
self.parameters.add(arg.arg)
|
||||
if node.args.vararg:
|
||||
self.parameters.add(node.args.vararg.arg)
|
||||
if node.args.kwarg:
|
||||
self.parameters.add(node.args.kwarg.arg)
|
||||
for arg in node.args.kwonlyargs:
|
||||
self.parameters.add(arg.arg)
|
||||
|
||||
node = self.generic_visit(node)
|
||||
self.exit_scope()
|
||||
return node
|
||||
|
||||
def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> ast.AsyncFunctionDef:
|
||||
"""Handle async functions same as regular functions."""
|
||||
return self.visit_FunctionDef(node) # type: ignore[return-value]
|
||||
|
||||
def visit_ClassDef(self, node: ast.ClassDef) -> ast.ClassDef:
|
||||
"""Process class but keep class name unchanged."""
|
||||
self.enter_scope()
|
||||
node = self.generic_visit(node)
|
||||
self.exit_scope()
|
||||
return node
|
||||
|
||||
def visit_Name(self, node: ast.Name) -> ast.Name:
|
||||
"""Normalize variable names in Name nodes."""
|
||||
if isinstance(node.ctx, (ast.Store, ast.Del)):
|
||||
if (
|
||||
node.id not in self.builtins
|
||||
and node.id not in self.imports
|
||||
and node.id not in self.parameters
|
||||
and node.id not in self.global_vars
|
||||
and node.id not in self.nonlocal_vars
|
||||
):
|
||||
node.id = self.get_normalized_name(node.id)
|
||||
elif isinstance(node.ctx, ast.Load) and node.id in self.var_mapping:
|
||||
node.id = self.var_mapping[node.id]
|
||||
return node
|
||||
|
||||
def visit_ExceptHandler(self, node: ast.ExceptHandler) -> ast.ExceptHandler:
|
||||
"""Normalize exception variable names."""
|
||||
if node.name:
|
||||
node.name = self.get_normalized_name(node.name)
|
||||
return self.generic_visit(node)
|
||||
|
||||
def visit_comprehension(self, node: ast.comprehension) -> ast.comprehension:
|
||||
"""Normalize comprehension target variables."""
|
||||
old_mapping = dict(self.var_mapping)
|
||||
old_counter = self.var_counter
|
||||
|
||||
node = self.generic_visit(node)
|
||||
|
||||
self.var_mapping = old_mapping
|
||||
self.var_counter = old_counter
|
||||
return node
|
||||
|
||||
def visit_For(self, node: ast.For) -> ast.For:
|
||||
"""Handle for loop target variables."""
|
||||
return self.generic_visit(node)
|
||||
|
||||
def visit_With(self, node: ast.With) -> ast.With:
|
||||
"""Handle with statement as variables."""
|
||||
return self.generic_visit(node)
|
||||
|
||||
|
||||
def _remove_docstrings_from_ast(node: ast.AST) -> None:
|
||||
"""Remove docstrings from AST nodes."""
|
||||
node_types = (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef, ast.Module)
|
||||
stack = [node]
|
||||
while stack:
|
||||
current_node = stack.pop()
|
||||
if isinstance(current_node, node_types):
|
||||
body = current_node.body
|
||||
if (
|
||||
body
|
||||
and isinstance(body[0], ast.Expr)
|
||||
and isinstance(body[0].value, ast.Constant)
|
||||
and isinstance(body[0].value.value, str)
|
||||
):
|
||||
current_node.body = body[1:]
|
||||
stack.extend([child for child in body if isinstance(child, node_types)])
|
||||
|
||||
|
||||
class PythonNormalizer(CodeNormalizer):
|
||||
"""Python code normalizer using AST transformation.
|
||||
|
||||
Normalizes Python code by:
|
||||
- Replacing local variable names with canonical forms (var_0, var_1, etc.)
|
||||
- Preserving function names, class names, parameters, and imports
|
||||
- Optionally removing docstrings
|
||||
"""
|
||||
|
||||
@property
|
||||
def language(self) -> str:
|
||||
"""Return the language this normalizer handles."""
|
||||
return "python"
|
||||
|
||||
@property
|
||||
def supported_extensions(self) -> tuple[str, ...]:
|
||||
"""Return file extensions this normalizer can handle."""
|
||||
return (".py", ".pyw", ".pyi")
|
||||
|
||||
def normalize(self, code: str, remove_docstrings: bool = True) -> str:
|
||||
"""Normalize Python code to a canonical form.
|
||||
|
||||
Args:
|
||||
code: Python source code to normalize
|
||||
remove_docstrings: Whether to remove docstrings
|
||||
|
||||
Returns:
|
||||
Normalized Python code as a string
|
||||
|
||||
"""
|
||||
tree = ast.parse(code)
|
||||
|
||||
if remove_docstrings:
|
||||
_remove_docstrings_from_ast(tree)
|
||||
|
||||
normalizer = VariableNormalizer()
|
||||
normalized_tree = normalizer.visit(tree)
|
||||
ast.fix_missing_locations(normalized_tree)
|
||||
|
||||
return ast.unparse(normalized_tree)
|
||||
|
||||
def normalize_for_hash(self, code: str) -> str:
|
||||
"""Normalize Python code optimized for hashing.
|
||||
|
||||
Returns AST dump which is faster than unparsing.
|
||||
|
||||
Args:
|
||||
code: Python source code to normalize
|
||||
|
||||
Returns:
|
||||
AST dump string suitable for hashing
|
||||
|
||||
"""
|
||||
tree = ast.parse(code)
|
||||
_remove_docstrings_from_ast(tree)
|
||||
|
||||
normalizer = VariableNormalizer()
|
||||
normalized_tree = normalizer.visit(tree)
|
||||
|
||||
return ast.dump(normalized_tree, annotate_fields=False, include_attributes=False)
|
||||
|
|
@ -536,7 +536,9 @@ def run_mocha_benchmarking_tests(
|
|||
)
|
||||
mocha_env["CODEFLASH_TEST_MODULE"] = test_module_path
|
||||
|
||||
total_timeout = max(120, (target_duration_ms // 1000) + 60, timeout or 120)
|
||||
# Subprocess timeout: target_duration + 120s headroom for Mocha startup.
|
||||
# capturePerf's time budget governs actual looping.
|
||||
total_timeout = max(120, (target_duration_ms // 1000) + 120, timeout or 120)
|
||||
|
||||
logger.debug(f"Running Mocha benchmarking tests: {' '.join(mocha_cmd)}")
|
||||
logger.debug(
|
||||
|
|
|
|||
|
|
@ -1025,9 +1025,9 @@ def run_jest_benchmarking_tests(
|
|||
if "--max-old-space-size" not in existing_node_options:
|
||||
jest_env["NODE_OPTIONS"] = f"{existing_node_options} --max-old-space-size=4096".strip()
|
||||
|
||||
# Total timeout for the entire benchmark run (longer than single-loop timeout)
|
||||
# Account for startup overhead + target duration + buffer
|
||||
total_timeout = max(120, (target_duration_ms // 1000) + 60, timeout or 120)
|
||||
# Subprocess timeout: target_duration + 120s headroom for Jest startup
|
||||
# and TS compilation. capturePerf's time budget governs actual looping.
|
||||
total_timeout = max(120, (target_duration_ms // 1000) + 120)
|
||||
|
||||
logger.debug(f"Running Jest benchmarking tests with in-process loop runner: {' '.join(jest_cmd)}")
|
||||
logger.debug(
|
||||
|
|
|
|||
|
|
@ -616,8 +616,10 @@ def run_vitest_benchmarking_tests(
|
|||
vitest_env["CODEFLASH_TEST_MODULE"] = test_module_path
|
||||
logger.debug(f"[VITEST-BENCH] Set CODEFLASH_TEST_MODULE={test_module_path}")
|
||||
|
||||
# Total timeout for the entire benchmark run
|
||||
total_timeout = max(120, (target_duration_ms // 1000) + 60, timeout or 120)
|
||||
# Subprocess timeout: target_duration + 120s headroom for Vitest startup
|
||||
# (TS compilation, module resolution). The capturePerf time budget (10s default)
|
||||
# governs actual looping; this is just a safety net for process-level hangs.
|
||||
total_timeout = max(120, (target_duration_ms // 1000) + 120)
|
||||
|
||||
logger.debug(f"[VITEST-BENCH] Running Vitest benchmarking tests: {' '.join(vitest_cmd)}")
|
||||
logger.debug(
|
||||
|
|
|
|||
|
|
@ -773,7 +773,13 @@ class InvocationId:
|
|||
test_src = test_path.read_text(encoding="utf-8")
|
||||
module_node = cst.parse_module(test_src)
|
||||
except Exception:
|
||||
return None
|
||||
# libcst can't parse non-Python files (JS/TS) — return a descriptive string
|
||||
# so the code repair API receives a non-None test_src_code.
|
||||
return (
|
||||
f"// Test: {self.test_function_name}\n"
|
||||
f"// File: {test_path.name}\n"
|
||||
f"// Testing function: {self.function_getting_tested}"
|
||||
)
|
||||
|
||||
if self.test_class_name:
|
||||
for stmt in module_node.body:
|
||||
|
|
|
|||
|
|
@ -1,2 +1,2 @@
|
|||
# These version placeholders will be replaced by uv-dynamic-versioning during build.
|
||||
__version__ = "0.20.1.post242.dev0+7c7eeb5b"
|
||||
__version__ = "0.20.1.post675.dev0+1218a1cd"
|
||||
|
|
|
|||
4
packages/codeflash/package-lock.json
generated
4
packages/codeflash/package-lock.json
generated
|
|
@ -1,12 +1,12 @@
|
|||
{
|
||||
"name": "codeflash",
|
||||
"version": "0.9.0",
|
||||
"version": "0.10.1",
|
||||
"lockfileVersion": 3,
|
||||
"requires": true,
|
||||
"packages": {
|
||||
"": {
|
||||
"name": "codeflash",
|
||||
"version": "0.9.0",
|
||||
"version": "0.10.1",
|
||||
"hasInstallScript": true,
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
{
|
||||
"name": "codeflash",
|
||||
"version": "0.10.0",
|
||||
"version": "0.10.1",
|
||||
"description": "Codeflash - AI-powered code optimization for JavaScript and TypeScript",
|
||||
"main": "runtime/index.js",
|
||||
"types": "runtime/index.d.ts",
|
||||
|
|
|
|||
|
|
@ -926,11 +926,7 @@ class TestTestFrameworkConfigOverride:
|
|||
package_json = tmp_path / "package.json"
|
||||
package_json.write_text(
|
||||
json.dumps(
|
||||
{
|
||||
"name": "test-project",
|
||||
"devDependencies": {"vitest": "^1.0.0"},
|
||||
"codeflash": {"moduleRoot": "src"},
|
||||
}
|
||||
{"name": "test-project", "devDependencies": {"vitest": "^1.0.0"}, "codeflash": {"moduleRoot": "src"}}
|
||||
)
|
||||
)
|
||||
|
||||
|
|
@ -945,11 +941,7 @@ class TestTestFrameworkConfigOverride:
|
|||
package_json = tmp_path / "package.json"
|
||||
package_json.write_text(
|
||||
json.dumps(
|
||||
{
|
||||
"name": "test-project",
|
||||
"devDependencies": {"jest": "^29.0.0"},
|
||||
"codeflash": {"test-framework": ""},
|
||||
}
|
||||
{"name": "test-project", "devDependencies": {"jest": "^29.0.0"}, "codeflash": {"test-framework": ""}}
|
||||
)
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -2,7 +2,10 @@ from __future__ import annotations
|
|||
|
||||
from typing import Any
|
||||
|
||||
from codeflash.languages.python.static_analysis.coverage_utils import build_fully_qualified_name, extract_dependent_function
|
||||
from codeflash.languages.python.static_analysis.coverage_utils import (
|
||||
build_fully_qualified_name,
|
||||
extract_dependent_function,
|
||||
)
|
||||
from codeflash.models.function_types import FunctionParent
|
||||
from codeflash.models.models import CodeOptimizationContext, CodeString, CodeStringsMarkdown
|
||||
from codeflash.verification.coverage_utils import CoverageUtils
|
||||
|
|
|
|||
|
|
@ -9,7 +9,6 @@ from __future__ import annotations
|
|||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
from junitparser import JUnitXml
|
||||
|
||||
from codeflash.languages.javascript.parse import jest_end_pattern, jest_start_pattern
|
||||
|
|
@ -338,9 +337,7 @@ class TestFilenameBasedLookupFallback:
|
|||
path2.touch()
|
||||
|
||||
test_file1 = TestFile(
|
||||
original_file_path=path1,
|
||||
test_type=TestType.GENERATED_REGRESSION,
|
||||
instrumented_behavior_file_path=path1,
|
||||
original_file_path=path1, test_type=TestType.GENERATED_REGRESSION, instrumented_behavior_file_path=path1
|
||||
)
|
||||
test_file2 = TestFile(
|
||||
original_file_path=path2,
|
||||
|
|
|
|||
|
|
@ -9,8 +9,6 @@ from __future__ import annotations
|
|||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from codeflash.languages.javascript.vitest_runner import (
|
||||
_build_vitest_behavioral_command,
|
||||
_build_vitest_benchmarking_command,
|
||||
|
|
|
|||
|
|
@ -6,9 +6,7 @@ from end_to_end_test_utilities import TestConfig, run_codeflash_command, run_wit
|
|||
|
||||
def run_test(expected_improvement_pct: int) -> bool:
|
||||
config = TestConfig(
|
||||
file_path="src/main/java/com/example/Fibonacci.java",
|
||||
function_name="fibonacci",
|
||||
min_improvement_x=0.70,
|
||||
file_path="src/main/java/com/example/Fibonacci.java", function_name="fibonacci", min_improvement_x=0.70
|
||||
)
|
||||
cwd = (pathlib.Path(__file__).parent.parent.parent / "code_to_optimize" / "java").resolve()
|
||||
return run_codeflash_command(cwd, config, expected_improvement_pct)
|
||||
|
|
|
|||
|
|
@ -1,6 +1,5 @@
|
|||
"""Tests for cleanup of instrumented test files."""
|
||||
|
||||
from pathlib import Path
|
||||
from codeflash.optimization.optimizer import Optimizer
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -3809,8 +3809,7 @@ def test_enrich_testgen_context_deduplicates(tmp_path: Path) -> None:
|
|||
package_dir.mkdir()
|
||||
(package_dir / "__init__.py").write_text("", encoding="utf-8")
|
||||
(package_dir / "base.py").write_text(
|
||||
"class Base:\n def __init__(self, x: int):\n self.x = x\n",
|
||||
encoding="utf-8",
|
||||
"class Base:\n def __init__(self, x: int):\n self.x = x\n", encoding="utf-8"
|
||||
)
|
||||
|
||||
code = "from mypkg.base import Base\n\nclass A(Base):\n pass\n\nclass B(Base):\n pass\n"
|
||||
|
|
@ -3954,8 +3953,7 @@ def test_testgen_context_includes_external_base_inits(tmp_path: Path) -> None:
|
|||
package_dir.mkdir()
|
||||
(package_dir / "__init__.py").write_text("", encoding="utf-8")
|
||||
(package_dir / "base.py").write_text(
|
||||
"class BaseDict:\n def __init__(self, data=None):\n self.data = data or {}\n",
|
||||
encoding="utf-8",
|
||||
"class BaseDict:\n def __init__(self, data=None):\n self.data = data or {}\n", encoding="utf-8"
|
||||
)
|
||||
|
||||
code = "from mypkg.base import BaseDict\n\nclass MyCustomDict(BaseDict):\n def target_method(self):\n return self.data\n"
|
||||
|
|
@ -4009,8 +4007,7 @@ def test_enrich_testgen_context_attribute_base(tmp_path: Path) -> None:
|
|||
package_dir.mkdir()
|
||||
(package_dir / "__init__.py").write_text("", encoding="utf-8")
|
||||
(package_dir / "base.py").write_text(
|
||||
"class CustomDict:\n def __init__(self, data=None):\n self.data = data or {}\n",
|
||||
encoding="utf-8",
|
||||
"class CustomDict:\n def __init__(self, data=None):\n self.data = data or {}\n", encoding="utf-8"
|
||||
)
|
||||
|
||||
code = "from mypkg.base import CustomDict\n\nclass MyDict(CustomDict):\n def custom_method(self):\n return self.data\n"
|
||||
|
|
@ -4702,18 +4699,17 @@ def get_log_level() -> str:
|
|||
assert "class AppConfig:" in combined
|
||||
assert "@property" in combined
|
||||
|
||||
|
||||
def test_extract_parameter_type_constructors_isinstance_single(tmp_path: Path) -> None:
|
||||
"""isinstance(x, SomeType) in function body should be picked up."""
|
||||
pkg = tmp_path / "mypkg"
|
||||
pkg.mkdir()
|
||||
(pkg / "__init__.py").write_text("", encoding="utf-8")
|
||||
(pkg / "models.py").write_text(
|
||||
"class Widget:\n def __init__(self, size: int):\n self.size = size\n",
|
||||
encoding="utf-8",
|
||||
"class Widget:\n def __init__(self, size: int):\n self.size = size\n", encoding="utf-8"
|
||||
)
|
||||
(pkg / "processor.py").write_text(
|
||||
"from mypkg.models import Widget\n\ndef check(x) -> bool:\n return isinstance(x, Widget)\n",
|
||||
encoding="utf-8",
|
||||
"from mypkg.models import Widget\n\ndef check(x) -> bool:\n return isinstance(x, Widget)\n", encoding="utf-8"
|
||||
)
|
||||
fto = FunctionToOptimize(
|
||||
function_name="check", file_path=(pkg / "processor.py").resolve(), starting_line=3, ending_line=4
|
||||
|
|
@ -4754,12 +4750,10 @@ def test_extract_parameter_type_constructors_type_is_pattern(tmp_path: Path) ->
|
|||
pkg.mkdir()
|
||||
(pkg / "__init__.py").write_text("", encoding="utf-8")
|
||||
(pkg / "models.py").write_text(
|
||||
"class Gadget:\n def __init__(self, val: float):\n self.val = val\n",
|
||||
encoding="utf-8",
|
||||
"class Gadget:\n def __init__(self, val: float):\n self.val = val\n", encoding="utf-8"
|
||||
)
|
||||
(pkg / "processor.py").write_text(
|
||||
"from mypkg.models import Gadget\n\ndef check(x) -> bool:\n return type(x) is Gadget\n",
|
||||
encoding="utf-8",
|
||||
"from mypkg.models import Gadget\n\ndef check(x) -> bool:\n return type(x) is Gadget\n", encoding="utf-8"
|
||||
)
|
||||
fto = FunctionToOptimize(
|
||||
function_name="check", file_path=(pkg / "processor.py").resolve(), starting_line=3, ending_line=4
|
||||
|
|
@ -4775,8 +4769,7 @@ def test_extract_parameter_type_constructors_base_classes(tmp_path: Path) -> Non
|
|||
pkg.mkdir()
|
||||
(pkg / "__init__.py").write_text("", encoding="utf-8")
|
||||
(pkg / "base.py").write_text(
|
||||
"class BaseProcessor:\n def __init__(self, config: str):\n self.config = config\n",
|
||||
encoding="utf-8",
|
||||
"class BaseProcessor:\n def __init__(self, config: str):\n self.config = config\n", encoding="utf-8"
|
||||
)
|
||||
(pkg / "child.py").write_text(
|
||||
"from mypkg.base import BaseProcessor\n\nclass ChildProcessor(BaseProcessor):\n"
|
||||
|
|
@ -4801,8 +4794,7 @@ def test_extract_parameter_type_constructors_isinstance_builtins_excluded(tmp_pa
|
|||
pkg.mkdir()
|
||||
(pkg / "__init__.py").write_text("", encoding="utf-8")
|
||||
(pkg / "func.py").write_text(
|
||||
"def check(x) -> bool:\n return isinstance(x, (int, str, float))\n",
|
||||
encoding="utf-8",
|
||||
"def check(x) -> bool:\n return isinstance(x, (int, str, float))\n", encoding="utf-8"
|
||||
)
|
||||
fto = FunctionToOptimize(
|
||||
function_name="check", file_path=(pkg / "func.py").resolve(), starting_line=1, ending_line=2
|
||||
|
|
@ -4817,8 +4809,7 @@ def test_extract_parameter_type_constructors_transitive(tmp_path: Path) -> None:
|
|||
pkg.mkdir()
|
||||
(pkg / "__init__.py").write_text("", encoding="utf-8")
|
||||
(pkg / "config.py").write_text(
|
||||
"class Config:\n def __init__(self, debug: bool = False):\n self.debug = debug\n",
|
||||
encoding="utf-8",
|
||||
"class Config:\n def __init__(self, debug: bool = False):\n self.debug = debug\n", encoding="utf-8"
|
||||
)
|
||||
(pkg / "models.py").write_text(
|
||||
"from mypkg.config import Config\n\n"
|
||||
|
|
@ -4826,8 +4817,7 @@ def test_extract_parameter_type_constructors_transitive(tmp_path: Path) -> None:
|
|||
encoding="utf-8",
|
||||
)
|
||||
(pkg / "processor.py").write_text(
|
||||
"from mypkg.models import Widget\n\ndef process(w: Widget) -> str:\n return str(w)\n",
|
||||
encoding="utf-8",
|
||||
"from mypkg.models import Widget\n\ndef process(w: Widget) -> str:\n return str(w)\n", encoding="utf-8"
|
||||
)
|
||||
fto = FunctionToOptimize(
|
||||
function_name="process", file_path=(pkg / "processor.py").resolve(), starting_line=3, ending_line=4
|
||||
|
|
@ -4838,8 +4828,6 @@ def test_extract_parameter_type_constructors_transitive(tmp_path: Path) -> None:
|
|||
assert "class Config:" in combined
|
||||
|
||||
|
||||
|
||||
|
||||
def test_enrich_testgen_context_third_party_uses_stubs(tmp_path: Path) -> None:
|
||||
"""Third-party classes should produce compact __init__ stubs, not full class source."""
|
||||
# Use a real third-party package (pydantic) so jedi can actually resolve it
|
||||
|
|
|
|||
|
|
@ -7,7 +7,12 @@ from pathlib import Path
|
|||
|
||||
import libcst as cst
|
||||
|
||||
from codeflash.languages.python.static_analysis.code_extractor import delete___future___aliased_imports, find_preexisting_objects
|
||||
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
|
||||
from codeflash.languages.python.function_optimizer import PythonFunctionOptimizer
|
||||
from codeflash.languages.python.static_analysis.code_extractor import (
|
||||
delete___future___aliased_imports,
|
||||
find_preexisting_objects,
|
||||
)
|
||||
from codeflash.languages.python.static_analysis.code_replacer import (
|
||||
AddRequestArgument,
|
||||
AutouseFixtureModifier,
|
||||
|
|
@ -16,9 +21,7 @@ from codeflash.languages.python.static_analysis.code_replacer import (
|
|||
replace_functions_and_add_imports,
|
||||
replace_functions_in_file,
|
||||
)
|
||||
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
|
||||
from codeflash.models.models import CodeOptimizationContext, CodeStringsMarkdown, FunctionParent, FunctionSource
|
||||
from codeflash.languages.python.function_optimizer import PythonFunctionOptimizer
|
||||
from codeflash.verification.verification_utils import TestConfig
|
||||
|
||||
os.environ["CODEFLASH_API_KEY"] = "cf-test-key"
|
||||
|
|
|
|||
|
|
@ -20,7 +20,11 @@ from codeflash.code_utils.code_utils import (
|
|||
validate_python_code,
|
||||
)
|
||||
from codeflash.languages.python.static_analysis.concolic_utils import clean_concolic_tests
|
||||
from codeflash.languages.python.static_analysis.coverage_utils import extract_dependent_function, generate_candidates, prepare_coverage_files
|
||||
from codeflash.languages.python.static_analysis.coverage_utils import (
|
||||
extract_dependent_function,
|
||||
generate_candidates,
|
||||
prepare_coverage_files,
|
||||
)
|
||||
from codeflash.models.models import CodeStringsMarkdown
|
||||
from codeflash.verification.parse_test_output import resolve_test_file_from_class_path
|
||||
|
||||
|
|
|
|||
|
|
@ -1,4 +1,3 @@
|
|||
|
||||
from codeflash.benchmarking.codeflash_trace import codeflash_trace
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -592,12 +592,10 @@ def test_itertools_permutations_combinations() -> None:
|
|||
assert comparator(itertools.combinations("ABCD", 2), itertools.combinations("ABCD", 2))
|
||||
assert not comparator(itertools.combinations("ABCD", 2), itertools.combinations("ABCD", 3))
|
||||
assert comparator(
|
||||
itertools.combinations_with_replacement("ABC", 2),
|
||||
itertools.combinations_with_replacement("ABC", 2),
|
||||
itertools.combinations_with_replacement("ABC", 2), itertools.combinations_with_replacement("ABC", 2)
|
||||
)
|
||||
assert not comparator(
|
||||
itertools.combinations_with_replacement("ABC", 2),
|
||||
itertools.combinations_with_replacement("ABD", 2),
|
||||
itertools.combinations_with_replacement("ABC", 2), itertools.combinations_with_replacement("ABD", 2)
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -615,38 +613,31 @@ def test_itertools_filtering() -> None:
|
|||
|
||||
# compress
|
||||
assert comparator(
|
||||
itertools.compress("ABCDEF", [1, 0, 1, 0, 1, 1]),
|
||||
itertools.compress("ABCDEF", [1, 0, 1, 0, 1, 1]),
|
||||
itertools.compress("ABCDEF", [1, 0, 1, 0, 1, 1]), itertools.compress("ABCDEF", [1, 0, 1, 0, 1, 1])
|
||||
)
|
||||
assert not comparator(
|
||||
itertools.compress("ABCDEF", [1, 0, 1, 0, 1, 1]),
|
||||
itertools.compress("ABCDEF", [1, 1, 1, 0, 1, 1]),
|
||||
itertools.compress("ABCDEF", [1, 0, 1, 0, 1, 1]), itertools.compress("ABCDEF", [1, 1, 1, 0, 1, 1])
|
||||
)
|
||||
|
||||
# dropwhile
|
||||
assert comparator(
|
||||
itertools.dropwhile(lambda x: x < 5, [1, 4, 6, 4, 1]),
|
||||
itertools.dropwhile(lambda x: x < 5, [1, 4, 6, 4, 1]),
|
||||
itertools.dropwhile(lambda x: x < 5, [1, 4, 6, 4, 1]), itertools.dropwhile(lambda x: x < 5, [1, 4, 6, 4, 1])
|
||||
)
|
||||
assert not comparator(
|
||||
itertools.dropwhile(lambda x: x < 5, [1, 4, 6, 4, 1]),
|
||||
itertools.dropwhile(lambda x: x < 5, [1, 4, 7, 4, 1]),
|
||||
itertools.dropwhile(lambda x: x < 5, [1, 4, 6, 4, 1]), itertools.dropwhile(lambda x: x < 5, [1, 4, 7, 4, 1])
|
||||
)
|
||||
|
||||
# takewhile
|
||||
assert comparator(
|
||||
itertools.takewhile(lambda x: x < 5, [1, 4, 6, 4, 1]),
|
||||
itertools.takewhile(lambda x: x < 5, [1, 4, 6, 4, 1]),
|
||||
itertools.takewhile(lambda x: x < 5, [1, 4, 6, 4, 1]), itertools.takewhile(lambda x: x < 5, [1, 4, 6, 4, 1])
|
||||
)
|
||||
assert not comparator(
|
||||
itertools.takewhile(lambda x: x < 5, [1, 4, 6, 4, 1]),
|
||||
itertools.takewhile(lambda x: x < 5, [1, 3, 6, 4, 1]),
|
||||
itertools.takewhile(lambda x: x < 5, [1, 4, 6, 4, 1]), itertools.takewhile(lambda x: x < 5, [1, 3, 6, 4, 1])
|
||||
)
|
||||
|
||||
# filterfalse
|
||||
assert comparator(
|
||||
itertools.filterfalse(lambda x: x % 2, range(10)),
|
||||
itertools.filterfalse(lambda x: x % 2, range(10)),
|
||||
itertools.filterfalse(lambda x: x % 2, range(10)), itertools.filterfalse(lambda x: x % 2, range(10))
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -654,25 +645,19 @@ def test_itertools_starmap() -> None:
|
|||
import itertools
|
||||
|
||||
assert comparator(
|
||||
itertools.starmap(pow, [(2, 3), (3, 2), (10, 0)]),
|
||||
itertools.starmap(pow, [(2, 3), (3, 2), (10, 0)]),
|
||||
)
|
||||
assert not comparator(
|
||||
itertools.starmap(pow, [(2, 3), (3, 2)]),
|
||||
itertools.starmap(pow, [(2, 3), (3, 3)]),
|
||||
itertools.starmap(pow, [(2, 3), (3, 2), (10, 0)]), itertools.starmap(pow, [(2, 3), (3, 2), (10, 0)])
|
||||
)
|
||||
assert not comparator(itertools.starmap(pow, [(2, 3), (3, 2)]), itertools.starmap(pow, [(2, 3), (3, 3)]))
|
||||
|
||||
|
||||
def test_itertools_zip_longest() -> None:
|
||||
import itertools
|
||||
|
||||
assert comparator(
|
||||
itertools.zip_longest("AB", "xyz", fillvalue="-"),
|
||||
itertools.zip_longest("AB", "xyz", fillvalue="-"),
|
||||
itertools.zip_longest("AB", "xyz", fillvalue="-"), itertools.zip_longest("AB", "xyz", fillvalue="-")
|
||||
)
|
||||
assert not comparator(
|
||||
itertools.zip_longest("AB", "xyz", fillvalue="-"),
|
||||
itertools.zip_longest("AB", "xyz", fillvalue="*"),
|
||||
itertools.zip_longest("AB", "xyz", fillvalue="-"), itertools.zip_longest("AB", "xyz", fillvalue="*")
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -685,8 +670,7 @@ def test_itertools_groupby() -> None:
|
|||
|
||||
# With key function
|
||||
assert comparator(
|
||||
itertools.groupby([1, 1, 2, 2, 3], key=lambda x: x),
|
||||
itertools.groupby([1, 1, 2, 2, 3], key=lambda x: x),
|
||||
itertools.groupby([1, 1, 2, 2, 3], key=lambda x: x), itertools.groupby([1, 1, 2, 2, 3], key=lambda x: x)
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -714,10 +698,7 @@ def test_itertools_in_containers() -> None:
|
|||
{"a": itertools.chain([1], [2]), "b": itertools.islice(range(5), 3)},
|
||||
{"a": itertools.chain([1], [2]), "b": itertools.islice(range(5), 3)},
|
||||
)
|
||||
assert not comparator(
|
||||
[itertools.product("AB", repeat=2)],
|
||||
[itertools.product("AC", repeat=2)],
|
||||
)
|
||||
assert not comparator([itertools.product("AB", repeat=2)], [itertools.product("AC", repeat=2)])
|
||||
|
||||
# Different itertools types should not match
|
||||
assert not comparator(itertools.chain([1, 2]), itertools.islice([1, 2], 2))
|
||||
|
|
@ -2017,59 +1998,30 @@ def test_torch_nn_sequential():
|
|||
|
||||
# Test identical Sequential modules
|
||||
torch.manual_seed(42)
|
||||
a = nn.Sequential(
|
||||
nn.Linear(10, 20),
|
||||
nn.ReLU(),
|
||||
nn.Linear(20, 5)
|
||||
)
|
||||
a = nn.Sequential(nn.Linear(10, 20), nn.ReLU(), nn.Linear(20, 5))
|
||||
torch.manual_seed(42)
|
||||
b = nn.Sequential(
|
||||
nn.Linear(10, 20),
|
||||
nn.ReLU(),
|
||||
nn.Linear(20, 5)
|
||||
)
|
||||
b = nn.Sequential(nn.Linear(10, 20), nn.ReLU(), nn.Linear(20, 5))
|
||||
assert comparator(a, b)
|
||||
|
||||
# Test Sequential with different weights
|
||||
torch.manual_seed(42)
|
||||
c = nn.Sequential(
|
||||
nn.Linear(10, 20),
|
||||
nn.ReLU(),
|
||||
nn.Linear(20, 5)
|
||||
)
|
||||
c = nn.Sequential(nn.Linear(10, 20), nn.ReLU(), nn.Linear(20, 5))
|
||||
torch.manual_seed(123)
|
||||
d = nn.Sequential(
|
||||
nn.Linear(10, 20),
|
||||
nn.ReLU(),
|
||||
nn.Linear(20, 5)
|
||||
)
|
||||
d = nn.Sequential(nn.Linear(10, 20), nn.ReLU(), nn.Linear(20, 5))
|
||||
assert not comparator(c, d)
|
||||
|
||||
# Test Sequential with different number of layers
|
||||
torch.manual_seed(42)
|
||||
e = nn.Sequential(
|
||||
nn.Linear(10, 20),
|
||||
nn.ReLU()
|
||||
)
|
||||
e = nn.Sequential(nn.Linear(10, 20), nn.ReLU())
|
||||
torch.manual_seed(42)
|
||||
f = nn.Sequential(
|
||||
nn.Linear(10, 20),
|
||||
nn.ReLU(),
|
||||
nn.Linear(20, 5)
|
||||
)
|
||||
f = nn.Sequential(nn.Linear(10, 20), nn.ReLU(), nn.Linear(20, 5))
|
||||
assert not comparator(e, f)
|
||||
|
||||
# Test Sequential with different layer types
|
||||
torch.manual_seed(42)
|
||||
g = nn.Sequential(
|
||||
nn.Linear(10, 20),
|
||||
nn.ReLU()
|
||||
)
|
||||
g = nn.Sequential(nn.Linear(10, 20), nn.ReLU())
|
||||
torch.manual_seed(42)
|
||||
h = nn.Sequential(
|
||||
nn.Linear(10, 20),
|
||||
nn.Sigmoid()
|
||||
)
|
||||
h = nn.Sequential(nn.Linear(10, 20), nn.Sigmoid())
|
||||
assert not comparator(g, h)
|
||||
|
||||
|
||||
|
|
@ -2106,28 +2058,16 @@ def test_torch_nn_moduledict():
|
|||
|
||||
# Test identical ModuleDict
|
||||
torch.manual_seed(42)
|
||||
a = nn.ModuleDict({
|
||||
"fc1": nn.Linear(10, 20),
|
||||
"fc2": nn.Linear(20, 5)
|
||||
})
|
||||
a = nn.ModuleDict({"fc1": nn.Linear(10, 20), "fc2": nn.Linear(20, 5)})
|
||||
torch.manual_seed(42)
|
||||
b = nn.ModuleDict({
|
||||
"fc1": nn.Linear(10, 20),
|
||||
"fc2": nn.Linear(20, 5)
|
||||
})
|
||||
b = nn.ModuleDict({"fc1": nn.Linear(10, 20), "fc2": nn.Linear(20, 5)})
|
||||
assert comparator(a, b)
|
||||
|
||||
# Test ModuleDict with different keys
|
||||
torch.manual_seed(42)
|
||||
c = nn.ModuleDict({
|
||||
"fc1": nn.Linear(10, 20),
|
||||
"fc2": nn.Linear(20, 5)
|
||||
})
|
||||
c = nn.ModuleDict({"fc1": nn.Linear(10, 20), "fc2": nn.Linear(20, 5)})
|
||||
torch.manual_seed(42)
|
||||
d = nn.ModuleDict({
|
||||
"layer1": nn.Linear(10, 20),
|
||||
"layer2": nn.Linear(20, 5)
|
||||
})
|
||||
d = nn.ModuleDict({"layer1": nn.Linear(10, 20), "layer2": nn.Linear(20, 5)})
|
||||
assert not comparator(c, d)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -294,7 +294,7 @@ class MockTestConfig:
|
|||
"""Mocks codeflash.verification.verification_utils.TestConfig"""
|
||||
|
||||
tests_root: Path
|
||||
tests_project_rootdir: Path = Path(".")
|
||||
tests_project_rootdir: Path = Path()
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
|
|
|
|||
|
|
@ -4,8 +4,8 @@ import pytest
|
|||
|
||||
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
|
||||
from codeflash.either import is_successful
|
||||
from codeflash.models.models import FunctionParent
|
||||
from codeflash.languages.python.function_optimizer import PythonFunctionOptimizer
|
||||
from codeflash.models.models import FunctionParent
|
||||
from codeflash.verification.verification_utils import TestConfig
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -3,8 +3,8 @@ from pathlib import Path
|
|||
|
||||
import pytest
|
||||
|
||||
from codeflash.languages.python.static_analysis.code_extractor import get_code
|
||||
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
|
||||
from codeflash.languages.python.static_analysis.code_extractor import get_code
|
||||
from codeflash.models.models import FunctionParent
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -6,8 +6,8 @@ import pytest
|
|||
|
||||
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
|
||||
from codeflash.either import is_successful
|
||||
from codeflash.models.models import FunctionParent, get_code_block_splitter
|
||||
from codeflash.languages.python.function_optimizer import PythonFunctionOptimizer
|
||||
from codeflash.models.models import FunctionParent, get_code_block_splitter
|
||||
from codeflash.optimization.optimizer import Optimizer
|
||||
from codeflash.verification.verification_utils import TestConfig
|
||||
|
||||
|
|
|
|||
|
|
@ -412,7 +412,9 @@ def test_conditional_class_definitions() -> None:
|
|||
platform = "other"
|
||||
"""
|
||||
|
||||
output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_ONLY, {"PlatformClass.target_method"}, set()).code
|
||||
output = parse_code_and_prune_cst(
|
||||
dedent(code), CodeContextType.READ_ONLY, {"PlatformClass.target_method"}, set()
|
||||
).code
|
||||
assert dedent(expected).strip() == output.strip()
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -123,7 +123,9 @@ def test_multiple_top_level_classes() -> None:
|
|||
def process(self):
|
||||
return "C"
|
||||
"""
|
||||
result = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_WRITABLE, {"ClassA.process", "ClassC.process"}).code
|
||||
result = parse_code_and_prune_cst(
|
||||
dedent(code), CodeContextType.READ_WRITABLE, {"ClassA.process", "ClassC.process"}
|
||||
).code
|
||||
|
||||
expected = dedent("""
|
||||
class ClassA:
|
||||
|
|
|
|||
|
|
@ -304,7 +304,9 @@ def test_conditional_class_definitions() -> None:
|
|||
print("other")
|
||||
"""
|
||||
|
||||
output = parse_code_and_prune_cst(dedent(code), CodeContextType.TESTGEN, {"PlatformClass.target_method"}, set()).code
|
||||
output = parse_code_and_prune_cst(
|
||||
dedent(code), CodeContextType.TESTGEN, {"PlatformClass.target_method"}, set()
|
||||
).code
|
||||
assert dedent(expected).strip() == output.strip()
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -305,9 +305,7 @@ class TestShouldModifySkipConfirm:
|
|||
"""With skip_confirm and valid config, should return (False, config) — no reconfigure."""
|
||||
monkeypatch.chdir(tmp_project)
|
||||
codeflash_config = {"moduleRoot": "."}
|
||||
(tmp_project / "package.json").write_text(
|
||||
json.dumps({"name": "test", "codeflash": codeflash_config})
|
||||
)
|
||||
(tmp_project / "package.json").write_text(json.dumps({"name": "test", "codeflash": codeflash_config}))
|
||||
|
||||
should_modify, config = should_modify_package_json_config(skip_confirm=True)
|
||||
|
||||
|
|
@ -320,9 +318,7 @@ class TestShouldModifySkipConfirm:
|
|||
"""With skip_confirm and invalid config (bad moduleRoot), should return (True, None)."""
|
||||
monkeypatch.chdir(tmp_project)
|
||||
codeflash_config = {"moduleRoot": "/nonexistent/path/that/does/not/exist"}
|
||||
(tmp_project / "package.json").write_text(
|
||||
json.dumps({"name": "test", "codeflash": codeflash_config})
|
||||
)
|
||||
(tmp_project / "package.json").write_text(json.dumps({"name": "test", "codeflash": codeflash_config}))
|
||||
|
||||
should_modify, config = should_modify_package_json_config(skip_confirm=True)
|
||||
|
||||
|
|
|
|||
|
|
@ -470,8 +470,7 @@ class OuterClass:
|
|||
|
||||
decorator_name = get_decorator_name_for_mode(TestingMode.BEHAVIOR)
|
||||
code_with_decorator = nested_async_code.replace(
|
||||
" async def nested_async_method",
|
||||
f" @{decorator_name}\n async def nested_async_method",
|
||||
" async def nested_async_method", f" @{decorator_name}\n async def nested_async_method"
|
||||
)
|
||||
code_with_import = f"from codeflash_async_wrapper import {decorator_name}\n{code_with_decorator}"
|
||||
expected = sort_imports(code=code_with_import, float_to_top=True)
|
||||
|
|
|
|||
|
|
@ -2,10 +2,10 @@ import os
|
|||
from pathlib import Path
|
||||
from tempfile import TemporaryDirectory
|
||||
|
||||
from codeflash.languages.python.static_analysis.line_profile_utils import add_decorator_imports, contains_jit_decorator
|
||||
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
|
||||
from codeflash.models.models import CodeOptimizationContext
|
||||
from codeflash.languages.python.function_optimizer import PythonFunctionOptimizer
|
||||
from codeflash.languages.python.static_analysis.line_profile_utils import add_decorator_imports, contains_jit_decorator
|
||||
from codeflash.models.models import CodeOptimizationContext
|
||||
from codeflash.verification.verification_utils import TestConfig
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -15,8 +15,9 @@ from codeflash.code_utils.instrument_existing_tests import (
|
|||
FunctionImportedAsVisitor,
|
||||
inject_profiling_into_existing_test,
|
||||
)
|
||||
from codeflash.languages.python.static_analysis.line_profile_utils import add_decorator_imports
|
||||
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
|
||||
from codeflash.languages.python.function_optimizer import PythonFunctionOptimizer
|
||||
from codeflash.languages.python.static_analysis.line_profile_utils import add_decorator_imports
|
||||
from codeflash.models.models import (
|
||||
CodeOptimizationContext,
|
||||
CodePosition,
|
||||
|
|
@ -27,7 +28,6 @@ from codeflash.models.models import (
|
|||
TestsInFile,
|
||||
TestType,
|
||||
)
|
||||
from codeflash.languages.python.function_optimizer import PythonFunctionOptimizer
|
||||
from codeflash.verification.verification_utils import TestConfig
|
||||
|
||||
codeflash_wrap_string = """def codeflash_wrap(codeflash_wrapped, codeflash_test_module_name, codeflash_test_class_name, codeflash_test_name, codeflash_function_name, codeflash_line_id, codeflash_loop_index, codeflash_cur, codeflash_con, *args, **kwargs):
|
||||
|
|
|
|||
|
|
@ -776,7 +776,8 @@ public void testFibonacci() {
|
|||
|
||||
def test_junit4_message_first_with_string_expected(self):
|
||||
"""When assertEquals has 3 args and the first is a message but the second is also a string,
|
||||
the type should be inferred from the second arg (the real expected value), not the message."""
|
||||
the type should be inferred from the second arg (the real expected value), not the message.
|
||||
"""
|
||||
source = """\
|
||||
@Test
|
||||
public void testGetName() {
|
||||
|
|
@ -807,7 +808,8 @@ public void testIsValid() {
|
|||
|
||||
def test_two_arg_string_expected_not_treated_as_message(self):
|
||||
"""When assertEquals has only 2 args and the first is a string, it IS the expected value,
|
||||
not a message. This tests that we don't incorrectly skip the first arg."""
|
||||
not a message. This tests that we don't incorrectly skip the first arg.
|
||||
"""
|
||||
source = """\
|
||||
@Test
|
||||
public void testGetGreeting() {
|
||||
|
|
@ -869,8 +871,7 @@ void test() {
|
|||
|
||||
def test_qualified_name_support(self):
|
||||
transformer = JavaAssertTransformer(
|
||||
function_name="fibonacci",
|
||||
qualified_name="com.example.Calculator.fibonacci",
|
||||
function_name="fibonacci", qualified_name="com.example.Calculator.fibonacci"
|
||||
)
|
||||
assert transformer.qualified_name == "com.example.Calculator.fibonacci"
|
||||
|
||||
|
|
|
|||
|
|
@ -2,14 +2,11 @@
|
|||
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock, patch
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from codeflash.languages.java.test_runner import (
|
||||
_multimodule_deps_installed,
|
||||
ensure_multi_module_deps_installed,
|
||||
)
|
||||
from codeflash.languages.java.test_runner import _multimodule_deps_installed, ensure_multi_module_deps_installed
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
|
|
@ -85,9 +82,7 @@ def test_different_modules_not_cached(mock_run, mock_mvn):
|
|||
@patch("codeflash.languages.java.test_runner._run_cmd_kill_pg_on_timeout")
|
||||
def test_returns_false_on_maven_failure(mock_run, mock_mvn):
|
||||
"""Non-zero exit code should return False and NOT cache."""
|
||||
mock_run.return_value = subprocess.CompletedProcess(
|
||||
args=["mvn"], returncode=1, stdout="", stderr="BUILD FAILURE"
|
||||
)
|
||||
mock_run.return_value = subprocess.CompletedProcess(args=["mvn"], returncode=1, stdout="", stderr="BUILD FAILURE")
|
||||
|
||||
root = Path("/project")
|
||||
result = ensure_multi_module_deps_installed(root, "guava-tests", {})
|
||||
|
|
|
|||
|
|
@ -41,7 +41,7 @@ def make_func(name: str, class_name: str, file_path: Path | None = None) -> Func
|
|||
|
||||
|
||||
def make_test_method(
|
||||
name: str, class_name: str, starting_line: int, ending_line: int, file_path: Path | None = None,
|
||||
name: str, class_name: str, starting_line: int, ending_line: int, file_path: Path | None = None
|
||||
) -> FunctionToOptimize:
|
||||
return FunctionToOptimize(
|
||||
function_name=name,
|
||||
|
|
@ -329,9 +329,7 @@ class FooTest {
|
|||
source_bytes = source.encode("utf8")
|
||||
tree = analyzer.parse(source_bytes)
|
||||
type_map = {"calc": "Calculator"}
|
||||
resolved = _resolve_method_calls_in_range(
|
||||
tree.root_node, source_bytes, 2, 5, analyzer, type_map, {},
|
||||
)
|
||||
resolved = _resolve_method_calls_in_range(tree.root_node, source_bytes, 2, 5, analyzer, type_map, {})
|
||||
assert "Calculator.add" in resolved
|
||||
|
||||
def test_static_method_call(self, analyzer):
|
||||
|
|
@ -344,9 +342,7 @@ class FooTest {
|
|||
"""
|
||||
source_bytes = source.encode("utf8")
|
||||
tree = analyzer.parse(source_bytes)
|
||||
resolved = _resolve_method_calls_in_range(
|
||||
tree.root_node, source_bytes, 2, 4, analyzer, {}, {},
|
||||
)
|
||||
resolved = _resolve_method_calls_in_range(tree.root_node, source_bytes, 2, 4, analyzer, {}, {})
|
||||
assert "Calculator.add" in resolved
|
||||
|
||||
def test_static_import_call(self, analyzer):
|
||||
|
|
@ -361,9 +357,7 @@ class FooTest {
|
|||
source_bytes = source.encode("utf8")
|
||||
tree = analyzer.parse(source_bytes)
|
||||
static_map = {"add": "Calculator"}
|
||||
resolved = _resolve_method_calls_in_range(
|
||||
tree.root_node, source_bytes, 3, 5, analyzer, {}, static_map,
|
||||
)
|
||||
resolved = _resolve_method_calls_in_range(tree.root_node, source_bytes, 3, 5, analyzer, {}, static_map)
|
||||
assert "Calculator.add" in resolved
|
||||
|
||||
def test_new_expression_method_call(self, analyzer):
|
||||
|
|
@ -376,9 +370,7 @@ class FooTest {
|
|||
"""
|
||||
source_bytes = source.encode("utf8")
|
||||
tree = analyzer.parse(source_bytes)
|
||||
resolved = _resolve_method_calls_in_range(
|
||||
tree.root_node, source_bytes, 2, 4, analyzer, {}, {},
|
||||
)
|
||||
resolved = _resolve_method_calls_in_range(tree.root_node, source_bytes, 2, 4, analyzer, {}, {})
|
||||
assert "Calculator.add" in resolved
|
||||
|
||||
def test_field_access_via_this(self, analyzer):
|
||||
|
|
@ -393,9 +385,7 @@ class FooTest {
|
|||
source_bytes = source.encode("utf8")
|
||||
tree = analyzer.parse(source_bytes)
|
||||
type_map = {"calculator": "Calculator"}
|
||||
resolved = _resolve_method_calls_in_range(
|
||||
tree.root_node, source_bytes, 3, 5, analyzer, type_map, {},
|
||||
)
|
||||
resolved = _resolve_method_calls_in_range(tree.root_node, source_bytes, 3, 5, analyzer, type_map, {})
|
||||
assert "Calculator.add" in resolved
|
||||
|
||||
def test_unresolvable_call_not_included(self, analyzer):
|
||||
|
|
@ -408,9 +398,7 @@ class FooTest {
|
|||
"""
|
||||
source_bytes = source.encode("utf8")
|
||||
tree = analyzer.parse(source_bytes)
|
||||
resolved = _resolve_method_calls_in_range(
|
||||
tree.root_node, source_bytes, 2, 4, analyzer, {}, {},
|
||||
)
|
||||
resolved = _resolve_method_calls_in_range(tree.root_node, source_bytes, 2, 4, analyzer, {}, {})
|
||||
# someUnknown is lowercase and not in type_map → not resolved
|
||||
assert len(resolved) == 0
|
||||
|
||||
|
|
@ -425,9 +413,7 @@ class FooTest {
|
|||
source_bytes = source.encode("utf8")
|
||||
tree = analyzer.parse(source_bytes)
|
||||
# assertEquals has no receiver, and not in static_import_map
|
||||
resolved = _resolve_method_calls_in_range(
|
||||
tree.root_node, source_bytes, 2, 4, analyzer, {}, {},
|
||||
)
|
||||
resolved = _resolve_method_calls_in_range(tree.root_node, source_bytes, 2, 4, analyzer, {}, {})
|
||||
assert len(resolved) == 0
|
||||
|
||||
def test_multiple_different_receivers(self, analyzer):
|
||||
|
|
@ -444,9 +430,7 @@ class FooTest {
|
|||
source_bytes = source.encode("utf8")
|
||||
tree = analyzer.parse(source_bytes)
|
||||
type_map = {"calc": "Calculator", "buf": "Buffer"}
|
||||
resolved = _resolve_method_calls_in_range(
|
||||
tree.root_node, source_bytes, 2, 7, analyzer, type_map, {},
|
||||
)
|
||||
resolved = _resolve_method_calls_in_range(tree.root_node, source_bytes, 2, 7, analyzer, type_map, {})
|
||||
assert "Calculator.add" in resolved
|
||||
assert "Buffer.read" in resolved
|
||||
|
||||
|
|
@ -466,9 +450,7 @@ class FooTest {
|
|||
source_bytes = source.encode("utf8")
|
||||
tree = analyzer.parse(source_bytes)
|
||||
type_map = {"calc": "Calculator"}
|
||||
resolved = _resolve_method_calls_in_range(
|
||||
tree.root_node, source_bytes, 6, 9, analyzer, type_map, {},
|
||||
)
|
||||
resolved = _resolve_method_calls_in_range(tree.root_node, source_bytes, 6, 9, analyzer, type_map, {})
|
||||
assert "Calculator.add" in resolved
|
||||
assert "Calculator.init" not in resolved
|
||||
|
||||
|
|
@ -769,10 +751,7 @@ class CalculatorTest {
|
|||
}
|
||||
}
|
||||
"""
|
||||
func_map = {
|
||||
"Calculator.add": make_func("add", "Calculator"),
|
||||
"Buffer.add": make_func("add", "Buffer"),
|
||||
}
|
||||
func_map = {"Calculator.add": make_func("add", "Calculator"), "Buffer.add": make_func("add", "Buffer")}
|
||||
test_method = make_test_method("testAdd", "CalculatorTest", 6, 10)
|
||||
matched = _match_test_to_functions(test_method, test_source, func_map, analyzer)
|
||||
# Local Calculator declaration shadows the Buffer field
|
||||
|
|
@ -792,7 +771,8 @@ class TestDiscoverTests:
|
|||
test_dir.mkdir(parents=True)
|
||||
|
||||
test_file = test_dir / "CalculatorTest.java"
|
||||
test_file.write_text("""\
|
||||
test_file.write_text(
|
||||
"""\
|
||||
package com.example;
|
||||
|
||||
import com.example.Calculator;
|
||||
|
|
@ -814,7 +794,9 @@ class CalculatorTest {
|
|||
assertEquals(2, result);
|
||||
}
|
||||
}
|
||||
""", encoding="utf-8")
|
||||
""",
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
source_functions = [
|
||||
make_func("add", "Calculator"),
|
||||
|
|
@ -840,7 +822,8 @@ class CalculatorTest {
|
|||
test_dir.mkdir(parents=True)
|
||||
|
||||
test_file = test_dir / "MathUtilsTest.java"
|
||||
test_file.write_text("""\
|
||||
test_file.write_text(
|
||||
"""\
|
||||
package com.example;
|
||||
|
||||
import com.example.MathUtils;
|
||||
|
|
@ -857,7 +840,9 @@ class MathUtilsTest {
|
|||
int result = MathUtils.abs(-3);
|
||||
}
|
||||
}
|
||||
""", encoding="utf-8")
|
||||
""",
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
source_functions = [
|
||||
make_func("square", "MathUtils"),
|
||||
|
|
@ -880,7 +865,8 @@ class MathUtilsTest {
|
|||
test_dir.mkdir(parents=True)
|
||||
|
||||
test_file = test_dir / "CalculatorTest.java"
|
||||
test_file.write_text("""\
|
||||
test_file.write_text(
|
||||
"""\
|
||||
package com.example;
|
||||
|
||||
import com.example.Calculator;
|
||||
|
|
@ -905,7 +891,9 @@ class CalculatorTest {
|
|||
calculator.multiply(3, 4);
|
||||
}
|
||||
}
|
||||
""", encoding="utf-8")
|
||||
""",
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
source_functions = [
|
||||
make_func("add", "Calculator"),
|
||||
|
|
@ -1074,9 +1062,7 @@ class FooTest {
|
|||
"""
|
||||
source_bytes = source.encode("utf8")
|
||||
tree = analyzer.parse(source_bytes)
|
||||
resolved = _resolve_method_calls_in_range(
|
||||
tree.root_node, source_bytes, 2, 5, analyzer, {"obj": "Object"}, {},
|
||||
)
|
||||
resolved = _resolve_method_calls_in_range(tree.root_node, source_bytes, 2, 5, analyzer, {"obj": "Object"}, {})
|
||||
assert "Calculator.add" in resolved
|
||||
|
||||
def test_method_call_inside_if(self, analyzer):
|
||||
|
|
@ -1093,9 +1079,7 @@ class FooTest {
|
|||
source_bytes = source.encode("utf8")
|
||||
tree = analyzer.parse(source_bytes)
|
||||
type_map = {"calc": "Calculator"}
|
||||
resolved = _resolve_method_calls_in_range(
|
||||
tree.root_node, source_bytes, 2, 7, analyzer, type_map, {},
|
||||
)
|
||||
resolved = _resolve_method_calls_in_range(tree.root_node, source_bytes, 2, 7, analyzer, type_map, {})
|
||||
assert "Calculator.add" in resolved
|
||||
|
||||
def test_method_call_inside_try_catch(self, analyzer):
|
||||
|
|
@ -1114,9 +1098,7 @@ class FooTest {
|
|||
source_bytes = source.encode("utf8")
|
||||
tree = analyzer.parse(source_bytes)
|
||||
type_map = {"calc": "Calculator"}
|
||||
resolved = _resolve_method_calls_in_range(
|
||||
tree.root_node, source_bytes, 2, 9, analyzer, type_map, {},
|
||||
)
|
||||
resolved = _resolve_method_calls_in_range(tree.root_node, source_bytes, 2, 9, analyzer, type_map, {})
|
||||
assert "Calculator.add" in resolved
|
||||
assert "Calculator.reset" in resolved
|
||||
|
||||
|
|
@ -1134,9 +1116,7 @@ class FooTest {
|
|||
source_bytes = source.encode("utf8")
|
||||
tree = analyzer.parse(source_bytes)
|
||||
type_map = {"calc": "Calculator"}
|
||||
resolved = _resolve_method_calls_in_range(
|
||||
tree.root_node, source_bytes, 2, 7, analyzer, type_map, {},
|
||||
)
|
||||
resolved = _resolve_method_calls_in_range(tree.root_node, source_bytes, 2, 7, analyzer, type_map, {})
|
||||
assert "Calculator.add" in resolved
|
||||
|
||||
def test_method_call_inside_lambda(self, analyzer):
|
||||
|
|
@ -1151,9 +1131,7 @@ class FooTest {
|
|||
source_bytes = source.encode("utf8")
|
||||
tree = analyzer.parse(source_bytes)
|
||||
type_map = {"calc": "Calculator"}
|
||||
resolved = _resolve_method_calls_in_range(
|
||||
tree.root_node, source_bytes, 2, 5, analyzer, type_map, {},
|
||||
)
|
||||
resolved = _resolve_method_calls_in_range(tree.root_node, source_bytes, 2, 5, analyzer, type_map, {})
|
||||
assert "Calculator.add" in resolved
|
||||
|
||||
def test_duplicate_calls_resolved_once(self, analyzer):
|
||||
|
|
@ -1170,9 +1148,7 @@ class FooTest {
|
|||
source_bytes = source.encode("utf8")
|
||||
tree = analyzer.parse(source_bytes)
|
||||
type_map = {"calc": "Calculator"}
|
||||
resolved = _resolve_method_calls_in_range(
|
||||
tree.root_node, source_bytes, 2, 7, analyzer, type_map, {},
|
||||
)
|
||||
resolved = _resolve_method_calls_in_range(tree.root_node, source_bytes, 2, 7, analyzer, type_map, {})
|
||||
# resolved is a set, so duplicates are naturally deduplicated
|
||||
assert resolved == {"Calculator.add", "Calculator.Calculator", "Calculator.<init>"}
|
||||
|
||||
|
|
@ -1190,9 +1166,7 @@ class FooTest {
|
|||
source_bytes = source.encode("utf8")
|
||||
tree = analyzer.parse(source_bytes)
|
||||
type_map = {"calc": "Calculator", "buf": "Buffer"}
|
||||
resolved = _resolve_method_calls_in_range(
|
||||
tree.root_node, source_bytes, 2, 7, analyzer, type_map, {},
|
||||
)
|
||||
resolved = _resolve_method_calls_in_range(tree.root_node, source_bytes, 2, 7, analyzer, type_map, {})
|
||||
assert "Calculator.add" in resolved
|
||||
assert "Buffer.add" in resolved
|
||||
# Also includes constructor refs: Calculator.Calculator, Calculator.<init>, Buffer.Buffer, Buffer.<init>
|
||||
|
|
@ -1212,9 +1186,7 @@ class FooTest {
|
|||
source_bytes = source.encode("utf8")
|
||||
tree = analyzer.parse(source_bytes)
|
||||
type_map = {"calc": "Calculator"}
|
||||
resolved = _resolve_method_calls_in_range(
|
||||
tree.root_node, source_bytes, 2, 5, analyzer, type_map, {},
|
||||
)
|
||||
resolved = _resolve_method_calls_in_range(tree.root_node, source_bytes, 2, 5, analyzer, type_map, {})
|
||||
# calc.getResult() resolves to Calculator.getResult
|
||||
assert "Calculator.getResult" in resolved
|
||||
# toString() is called on the return of getResult() which is unresolvable
|
||||
|
|
@ -1231,9 +1203,7 @@ class FooTest {
|
|||
"""
|
||||
source_bytes = source.encode("utf8")
|
||||
tree = analyzer.parse(source_bytes)
|
||||
resolved = _resolve_method_calls_in_range(
|
||||
tree.root_node, source_bytes, 2, 4, analyzer, {}, {},
|
||||
)
|
||||
resolved = _resolve_method_calls_in_range(tree.root_node, source_bytes, 2, 4, analyzer, {}, {})
|
||||
assert len(resolved) == 0
|
||||
|
||||
def test_this_method_call_not_resolved(self, analyzer):
|
||||
|
|
@ -1247,9 +1217,7 @@ class FooTest {
|
|||
"""
|
||||
source_bytes = source.encode("utf8")
|
||||
tree = analyzer.parse(source_bytes)
|
||||
resolved = _resolve_method_calls_in_range(
|
||||
tree.root_node, source_bytes, 2, 4, analyzer, {}, {},
|
||||
)
|
||||
resolved = _resolve_method_calls_in_range(tree.root_node, source_bytes, 2, 4, analyzer, {}, {})
|
||||
# this is not a field_access with a field that's in the type map, so not resolved
|
||||
assert len(resolved) == 0
|
||||
|
||||
|
|
@ -1263,9 +1231,7 @@ class FooTest {
|
|||
"""
|
||||
source_bytes = source.encode("utf8")
|
||||
tree = analyzer.parse(source_bytes)
|
||||
resolved = _resolve_method_calls_in_range(
|
||||
tree.root_node, source_bytes, 2, 4, analyzer, {}, {},
|
||||
)
|
||||
resolved = _resolve_method_calls_in_range(tree.root_node, source_bytes, 2, 4, analyzer, {}, {})
|
||||
# getCalculator() returns a method_invocation node as object, can't resolve
|
||||
assert "Calculator.add" not in resolved
|
||||
|
||||
|
|
@ -1279,9 +1245,7 @@ class FooTest {
|
|||
"""
|
||||
source_bytes = source.encode("utf8")
|
||||
tree = analyzer.parse(source_bytes)
|
||||
resolved = _resolve_method_calls_in_range(
|
||||
tree.root_node, source_bytes, 2, 4, analyzer, {}, {},
|
||||
)
|
||||
resolved = _resolve_method_calls_in_range(tree.root_node, source_bytes, 2, 4, analyzer, {}, {})
|
||||
assert "ArrayList.add" in resolved
|
||||
|
||||
def test_assertion_via_static_import_mapped_to_assertions_class(self, analyzer):
|
||||
|
|
@ -1297,9 +1261,7 @@ class FooTest {
|
|||
source_bytes = source.encode("utf8")
|
||||
tree = analyzer.parse(source_bytes)
|
||||
static_map = {"assertEquals": "Assertions"}
|
||||
resolved = _resolve_method_calls_in_range(
|
||||
tree.root_node, source_bytes, 3, 5, analyzer, {}, static_map,
|
||||
)
|
||||
resolved = _resolve_method_calls_in_range(tree.root_node, source_bytes, 3, 5, analyzer, {}, static_map)
|
||||
assert "Assertions.assertEquals" in resolved
|
||||
assert len(resolved) == 1
|
||||
|
||||
|
|
@ -1314,9 +1276,7 @@ class FooTest {
|
|||
"""
|
||||
source_bytes = source.encode("utf8")
|
||||
tree = analyzer.parse(source_bytes)
|
||||
resolved = _resolve_method_calls_in_range(
|
||||
tree.root_node, source_bytes, 2, 4, analyzer, {}, {},
|
||||
)
|
||||
resolved = _resolve_method_calls_in_range(tree.root_node, source_bytes, 2, 4, analyzer, {}, {})
|
||||
assert "Calculator.Calculator" in resolved
|
||||
assert "Calculator.<init>" in resolved
|
||||
|
||||
|
|
@ -1334,9 +1294,7 @@ class FooTest {
|
|||
source_bytes = source.encode("utf8")
|
||||
tree = analyzer.parse(source_bytes)
|
||||
type_map = {"records": "List"}
|
||||
resolved = _resolve_method_calls_in_range(
|
||||
tree.root_node, source_bytes, 2, 6, analyzer, type_map, {},
|
||||
)
|
||||
resolved = _resolve_method_calls_in_range(tree.root_node, source_bytes, 2, 6, analyzer, type_map, {})
|
||||
assert "BatchRead.BatchRead" in resolved
|
||||
assert "BatchRead.<init>" in resolved
|
||||
assert "Key.Key" in resolved
|
||||
|
|
@ -1353,9 +1311,7 @@ class FooTest {
|
|||
"""
|
||||
source_bytes = source.encode("utf8")
|
||||
tree = analyzer.parse(source_bytes)
|
||||
resolved = _resolve_method_calls_in_range(
|
||||
tree.root_node, source_bytes, 2, 4, analyzer, {}, {},
|
||||
)
|
||||
resolved = _resolve_method_calls_in_range(tree.root_node, source_bytes, 2, 4, analyzer, {}, {})
|
||||
assert "HashMap.HashMap" in resolved
|
||||
assert "HashMap.<init>" in resolved
|
||||
|
||||
|
|
@ -1379,10 +1335,7 @@ class MyTest {
|
|||
}
|
||||
}
|
||||
"""
|
||||
func_map = {
|
||||
"Calculator.add": make_func("add", "Calculator"),
|
||||
"MathUtils.add": make_func("add", "MathUtils"),
|
||||
}
|
||||
func_map = {"Calculator.add": make_func("add", "Calculator"), "MathUtils.add": make_func("add", "MathUtils")}
|
||||
test_method = make_test_method("testAdd", "MyTest", 4, 8)
|
||||
matched = _match_test_to_functions(test_method, test_source, func_map, analyzer)
|
||||
assert matched == ["Calculator.add"]
|
||||
|
|
@ -1564,7 +1517,7 @@ class CalculatorTest {
|
|||
assert matched == []
|
||||
|
||||
def test_constructor_matched(self, analyzer):
|
||||
"""new ClassName() should match the constructor in the function map."""
|
||||
"""New ClassName() should match the constructor in the function map."""
|
||||
test_source = """\
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
|
|
@ -1582,7 +1535,7 @@ class BatchReadTest {
|
|||
assert "BatchRead.BatchRead" in matched
|
||||
|
||||
def test_constructor_init_convention_matched(self, analyzer):
|
||||
"""new ClassName() should also match <init> naming convention."""
|
||||
"""New ClassName() should also match <init> naming convention."""
|
||||
test_source = """\
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
|
|
@ -1599,7 +1552,7 @@ class BatchReadTest {
|
|||
assert "BatchRead.<init>" in matched
|
||||
|
||||
def test_constructor_does_not_match_unrelated_methods(self, analyzer):
|
||||
"""new BatchRead() should not cause BatchRead.read to match."""
|
||||
"""New BatchRead() should not cause BatchRead.read to match."""
|
||||
test_source = """\
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
|
|
@ -1660,7 +1613,8 @@ class TestDiscoverTestsExtended:
|
|||
test_dir = tmp_path / "test"
|
||||
test_dir.mkdir(parents=True)
|
||||
|
||||
(test_dir / "CalculatorTests.java").write_text("""\
|
||||
(test_dir / "CalculatorTests.java").write_text(
|
||||
"""\
|
||||
package com.example;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
|
|
@ -1671,7 +1625,9 @@ class CalculatorTests {
|
|||
calc.add(1, 2);
|
||||
}
|
||||
}
|
||||
""", encoding="utf-8")
|
||||
""",
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
source_functions = [make_func("add", "Calculator")]
|
||||
result = discover_tests(tmp_path, source_functions, analyzer)
|
||||
|
|
@ -1682,7 +1638,8 @@ class CalculatorTests {
|
|||
test_dir = tmp_path / "test"
|
||||
test_dir.mkdir(parents=True)
|
||||
|
||||
(test_dir / "TestCalculator.java").write_text("""\
|
||||
(test_dir / "TestCalculator.java").write_text(
|
||||
"""\
|
||||
package com.example;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
|
|
@ -1693,7 +1650,9 @@ class TestCalculator {
|
|||
calc.add(1, 2);
|
||||
}
|
||||
}
|
||||
""", encoding="utf-8")
|
||||
""",
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
source_functions = [make_func("add", "Calculator")]
|
||||
result = discover_tests(tmp_path, source_functions, analyzer)
|
||||
|
|
@ -1710,7 +1669,8 @@ class TestCalculator {
|
|||
test_dir = tmp_path / "test"
|
||||
test_dir.mkdir(parents=True)
|
||||
|
||||
(test_dir / "CalculatorTest.java").write_text("""\
|
||||
(test_dir / "CalculatorTest.java").write_text(
|
||||
"""\
|
||||
package com.example;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
|
|
@ -1733,12 +1693,11 @@ class CalculatorTest {
|
|||
calc.subtract(5, 3);
|
||||
}
|
||||
}
|
||||
""", encoding="utf-8")
|
||||
""",
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
source_functions = [
|
||||
make_func("add", "Calculator"),
|
||||
make_func("subtract", "Calculator"),
|
||||
]
|
||||
source_functions = [make_func("add", "Calculator"), make_func("subtract", "Calculator")]
|
||||
result = discover_tests(tmp_path, source_functions, analyzer)
|
||||
|
||||
assert "Calculator.add" in result
|
||||
|
|
@ -1753,7 +1712,8 @@ class CalculatorTest {
|
|||
test_dir = tmp_path / "test"
|
||||
test_dir.mkdir(parents=True)
|
||||
|
||||
(test_dir / "CalculatorTest.java").write_text("""\
|
||||
(test_dir / "CalculatorTest.java").write_text(
|
||||
"""\
|
||||
package com.example;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
|
|
@ -1764,9 +1724,12 @@ class CalculatorTest {
|
|||
calc.add(1, 2);
|
||||
}
|
||||
}
|
||||
""", encoding="utf-8")
|
||||
""",
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
(test_dir / "IntegrationTest.java").write_text("""\
|
||||
(test_dir / "IntegrationTest.java").write_text(
|
||||
"""\
|
||||
package com.example;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
|
|
@ -1777,7 +1740,9 @@ class IntegrationTest {
|
|||
calc.add(10, 20);
|
||||
}
|
||||
}
|
||||
""", encoding="utf-8")
|
||||
""",
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
source_functions = [make_func("add", "Calculator")]
|
||||
result = discover_tests(tmp_path, source_functions, analyzer)
|
||||
|
|
@ -1791,7 +1756,8 @@ class IntegrationTest {
|
|||
test_dir = tmp_path / "test"
|
||||
test_dir.mkdir(parents=True)
|
||||
|
||||
(test_dir / "CalculatorTest.java").write_text("""\
|
||||
(test_dir / "CalculatorTest.java").write_text(
|
||||
"""\
|
||||
package com.example;
|
||||
import org.junit.jupiter.params.ParameterizedTest;
|
||||
import org.junit.jupiter.params.provider.CsvSource;
|
||||
|
|
@ -1804,7 +1770,9 @@ class CalculatorTest {
|
|||
calc.add(a, b);
|
||||
}
|
||||
}
|
||||
""", encoding="utf-8")
|
||||
""",
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
source_functions = [make_func("add", "Calculator")]
|
||||
result = discover_tests(tmp_path, source_functions, analyzer)
|
||||
|
|
@ -1815,7 +1783,8 @@ class CalculatorTest {
|
|||
deep_dir = tmp_path / "test" / "com" / "example" / "deep"
|
||||
deep_dir.mkdir(parents=True)
|
||||
|
||||
(deep_dir / "NestedTest.java").write_text("""\
|
||||
(deep_dir / "NestedTest.java").write_text(
|
||||
"""\
|
||||
package com.example.deep;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
|
|
@ -1826,7 +1795,9 @@ class NestedTest {
|
|||
calc.add(1, 2);
|
||||
}
|
||||
}
|
||||
""", encoding="utf-8")
|
||||
""",
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
source_functions = [make_func("add", "Calculator")]
|
||||
result = discover_tests(tmp_path, source_functions, analyzer)
|
||||
|
|
@ -1836,7 +1807,8 @@ class NestedTest {
|
|||
test_dir = tmp_path / "test"
|
||||
test_dir.mkdir(parents=True)
|
||||
|
||||
(test_dir / "CalculatorTest.java").write_text("""\
|
||||
(test_dir / "CalculatorTest.java").write_text(
|
||||
"""\
|
||||
package com.example;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
|
|
@ -1847,7 +1819,9 @@ class CalculatorTest {
|
|||
calc.add(1, 2);
|
||||
}
|
||||
}
|
||||
""", encoding="utf-8")
|
||||
""",
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
source_functions = [make_func("add", "Calculator")]
|
||||
result = discover_tests(tmp_path, source_functions, analyzer)
|
||||
|
|
@ -1857,7 +1831,8 @@ class CalculatorTest {
|
|||
test_dir = tmp_path / "test"
|
||||
test_dir.mkdir(parents=True)
|
||||
|
||||
(test_dir / "CalculatorTest.java").write_text("""\
|
||||
(test_dir / "CalculatorTest.java").write_text(
|
||||
"""\
|
||||
package com.example;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
|
|
@ -1868,7 +1843,9 @@ class CalculatorTest {
|
|||
calc.add(1, 2);
|
||||
}
|
||||
}
|
||||
""", encoding="utf-8")
|
||||
""",
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
result = discover_tests(tmp_path, [], analyzer)
|
||||
assert result == {}
|
||||
|
|
@ -1878,7 +1855,8 @@ class CalculatorTest {
|
|||
test_dir = tmp_path / "test"
|
||||
test_dir.mkdir(parents=True)
|
||||
|
||||
(test_dir / "BatchReadTest.java").write_text("""\
|
||||
(test_dir / "BatchReadTest.java").write_text(
|
||||
"""\
|
||||
package com.aerospike.test;
|
||||
import com.aerospike.client.BatchRead;
|
||||
import com.aerospike.client.Key;
|
||||
|
|
@ -1892,7 +1870,9 @@ class BatchReadTest {
|
|||
records.add(new BatchRead(new Key("ns", "set", "k2"), false));
|
||||
}
|
||||
}
|
||||
""", encoding="utf-8")
|
||||
""",
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
source_functions = [
|
||||
make_func("BatchRead", "BatchRead"),
|
||||
|
|
@ -1988,7 +1968,8 @@ class TestFindTestsForFunction:
|
|||
test_dir = tmp_path / "test"
|
||||
test_dir.mkdir(parents=True)
|
||||
|
||||
(test_dir / "CalculatorTest.java").write_text("""\
|
||||
(test_dir / "CalculatorTest.java").write_text(
|
||||
"""\
|
||||
package com.example;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
|
|
@ -1999,7 +1980,9 @@ class CalculatorTest {
|
|||
calc.add(1, 2);
|
||||
}
|
||||
}
|
||||
""", encoding="utf-8")
|
||||
""",
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
func = make_func("add", "Calculator")
|
||||
result = find_tests_for_function(func, tmp_path, analyzer)
|
||||
|
|
@ -2020,7 +2003,8 @@ class TestDiscoverAllTests:
|
|||
test_dir = tmp_path / "test"
|
||||
test_dir.mkdir(parents=True)
|
||||
|
||||
(test_dir / "CalculatorTest.java").write_text("""\
|
||||
(test_dir / "CalculatorTest.java").write_text(
|
||||
"""\
|
||||
package com.example;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
|
|
@ -2031,7 +2015,9 @@ class CalculatorTest {
|
|||
@Test
|
||||
void testSubtract() {}
|
||||
}
|
||||
""", encoding="utf-8")
|
||||
""",
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
all_tests = discover_all_tests(tmp_path, analyzer)
|
||||
names = {t.function_name for t in all_tests}
|
||||
|
|
@ -2048,32 +2034,40 @@ class CalculatorTest {
|
|||
test_dir = tmp_path / "test"
|
||||
test_dir.mkdir(parents=True)
|
||||
|
||||
(test_dir / "ATest.java").write_text("""\
|
||||
(test_dir / "ATest.java").write_text(
|
||||
"""\
|
||||
import org.junit.jupiter.api.Test;
|
||||
class ATest {
|
||||
@Test
|
||||
void testA() {}
|
||||
}
|
||||
""", encoding="utf-8")
|
||||
""",
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
(test_dir / "BTest.java").write_text("""\
|
||||
(test_dir / "BTest.java").write_text(
|
||||
"""\
|
||||
import org.junit.jupiter.api.Test;
|
||||
class BTest {
|
||||
@Test
|
||||
void testB() {}
|
||||
}
|
||||
""", encoding="utf-8")
|
||||
""",
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
all_tests = discover_all_tests(tmp_path, analyzer)
|
||||
names = {t.function_name for t in all_tests}
|
||||
assert names == {"testA", "testB"}
|
||||
|
||||
def test_no_false_positive_import_only_integration(self, tmp_path, analyzer):
|
||||
"""A test file that imports Calculator but never calls its methods should not match."""
|
||||
test_dir = tmp_path / "test"
|
||||
test_dir.mkdir(parents=True)
|
||||
|
||||
test_file = test_dir / "SomeTest.java"
|
||||
test_file.write_text("""\
|
||||
test_file.write_text(
|
||||
"""\
|
||||
package com.example;
|
||||
|
||||
import com.example.Calculator;
|
||||
|
|
@ -2085,12 +2079,11 @@ class SomeTest {
|
|||
int x = 42;
|
||||
}
|
||||
}
|
||||
""", encoding="utf-8")
|
||||
""",
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
source_functions = [
|
||||
make_func("add", "Calculator"),
|
||||
make_func("subtract", "Calculator"),
|
||||
]
|
||||
source_functions = [make_func("add", "Calculator"), make_func("subtract", "Calculator")]
|
||||
|
||||
result = discover_tests(tmp_path, source_functions, analyzer)
|
||||
assert result == {}
|
||||
|
|
@ -2099,7 +2092,8 @@ class SomeTest {
|
|||
test_dir = tmp_path / "test"
|
||||
test_dir.mkdir(parents=True)
|
||||
|
||||
(test_dir / "CalculatorTest.java").write_text("""\
|
||||
(test_dir / "CalculatorTest.java").write_text(
|
||||
"""\
|
||||
package com.example;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
|
|
@ -2110,9 +2104,12 @@ class CalculatorTest {
|
|||
calc.add(1, 2);
|
||||
}
|
||||
}
|
||||
""", encoding="utf-8")
|
||||
""",
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
(test_dir / "BufferTest.java").write_text("""\
|
||||
(test_dir / "BufferTest.java").write_text(
|
||||
"""\
|
||||
package com.example;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
|
|
@ -2123,13 +2120,11 @@ class BufferTest {
|
|||
buf.read();
|
||||
}
|
||||
}
|
||||
""", encoding="utf-8")
|
||||
""",
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
source_functions = [
|
||||
make_func("add", "Calculator"),
|
||||
make_func("read", "Buffer"),
|
||||
make_func("write", "Buffer"),
|
||||
]
|
||||
source_functions = [make_func("add", "Calculator"), make_func("read", "Buffer"), make_func("write", "Buffer")]
|
||||
|
||||
result = discover_tests(tmp_path, source_functions, analyzer)
|
||||
|
||||
|
|
@ -2147,7 +2142,8 @@ class BufferTest {
|
|||
test_dir.mkdir(parents=True)
|
||||
|
||||
# This file matches *Test.java pattern
|
||||
(test_dir / "CalculatorTest.java").write_text("""\
|
||||
(test_dir / "CalculatorTest.java").write_text(
|
||||
"""\
|
||||
package com.example;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
|
|
@ -2158,7 +2154,9 @@ class CalculatorTest {
|
|||
calc.add(1, 2);
|
||||
}
|
||||
}
|
||||
""", encoding="utf-8")
|
||||
""",
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
source_functions = [make_func("add", "Calculator")]
|
||||
result = discover_tests(tmp_path, source_functions, analyzer)
|
||||
|
|
@ -2171,7 +2169,8 @@ class CalculatorTest {
|
|||
test_dir = tmp_path / "test"
|
||||
test_dir.mkdir(parents=True)
|
||||
|
||||
(test_dir / "MathUtilsTest.java").write_text("""\
|
||||
(test_dir / "MathUtilsTest.java").write_text(
|
||||
"""\
|
||||
package com.example;
|
||||
import static com.example.MathUtils.square;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
|
@ -2182,12 +2181,11 @@ class MathUtilsTest {
|
|||
int result = square(5);
|
||||
}
|
||||
}
|
||||
""", encoding="utf-8")
|
||||
""",
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
source_functions = [
|
||||
make_func("square", "MathUtils"),
|
||||
make_func("cube", "MathUtils"),
|
||||
]
|
||||
source_functions = [make_func("square", "MathUtils"), make_func("cube", "MathUtils")]
|
||||
|
||||
result = discover_tests(tmp_path, source_functions, analyzer)
|
||||
|
||||
|
|
@ -2198,7 +2196,8 @@ class MathUtilsTest {
|
|||
test_dir = tmp_path / "test"
|
||||
test_dir.mkdir(parents=True)
|
||||
|
||||
(test_dir / "CalculatorTest.java").write_text("""\
|
||||
(test_dir / "CalculatorTest.java").write_text(
|
||||
"""\
|
||||
package com.example;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
|
|
@ -2210,7 +2209,9 @@ class CalculatorTest {
|
|||
int b = calc.multiply(a, 3);
|
||||
}
|
||||
}
|
||||
""", encoding="utf-8")
|
||||
""",
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
source_functions = [
|
||||
make_func("add", "Calculator"),
|
||||
|
|
|
|||
|
|
@ -3,9 +3,10 @@
|
|||
import subprocess
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from codeflash.languages.java.test_runner import _run_maven_tests, _build_test_filter
|
||||
from codeflash.languages.java.test_runner import _build_test_filter, _run_maven_tests
|
||||
from codeflash.models.models import TestFile, TestFiles, TestType
|
||||
|
||||
|
||||
|
|
@ -40,15 +41,11 @@ def test_build_test_filter_with_valid_paths():
|
|||
test_files = TestFiles(
|
||||
test_files=[
|
||||
TestFile(
|
||||
instrumented_behavior_file_path=Path(
|
||||
"/project/src/test/java/com/example/Test1__perfinstrumented.java"
|
||||
),
|
||||
benchmarking_file_path=Path(
|
||||
"/project/src/test/java/com/example/Test1__perfonlyinstrumented.java"
|
||||
),
|
||||
instrumented_behavior_file_path=Path("/project/src/test/java/com/example/Test1__perfinstrumented.java"),
|
||||
benchmarking_file_path=Path("/project/src/test/java/com/example/Test1__perfonlyinstrumented.java"),
|
||||
original_file_path=Path("/project/src/test/java/com/example/Test1.java"),
|
||||
test_type=TestType.EXISTING_UNIT_TEST,
|
||||
),
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
|
|
@ -71,7 +68,7 @@ def test_run_maven_tests_raises_on_empty_filter():
|
|||
benchmarking_file_path=None, # Will cause empty filter in performance mode
|
||||
original_file_path=Path("/tmp/test.java"),
|
||||
test_type=TestType.EXISTING_UNIT_TEST,
|
||||
),
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
|
|
@ -99,37 +96,26 @@ def test_run_maven_tests_succeeds_with_valid_filter():
|
|||
test_files = TestFiles(
|
||||
test_files=[
|
||||
TestFile(
|
||||
instrumented_behavior_file_path=Path(
|
||||
"/tmp/src/test/java/com/example/Test__perfinstrumented.java"
|
||||
),
|
||||
benchmarking_file_path=Path(
|
||||
"/tmp/src/test/java/com/example/Test__perfonlyinstrumented.java"
|
||||
),
|
||||
instrumented_behavior_file_path=Path("/tmp/src/test/java/com/example/Test__perfinstrumented.java"),
|
||||
benchmarking_file_path=Path("/tmp/src/test/java/com/example/Test__perfonlyinstrumented.java"),
|
||||
original_file_path=Path("/tmp/src/test/java/com/example/Test.java"),
|
||||
test_type=TestType.EXISTING_UNIT_TEST,
|
||||
),
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
# Mock Maven executable and _run_cmd_kill_pg_on_timeout (which replaced subprocess.run)
|
||||
with patch("codeflash.languages.java.test_runner.find_maven_executable") as mock_maven, \
|
||||
patch("codeflash.languages.java.test_runner._run_cmd_kill_pg_on_timeout") as mock_run:
|
||||
with (
|
||||
patch("codeflash.languages.java.test_runner.find_maven_executable") as mock_maven,
|
||||
patch("codeflash.languages.java.test_runner._run_cmd_kill_pg_on_timeout") as mock_run,
|
||||
):
|
||||
mock_maven.return_value = "mvn"
|
||||
mock_run.return_value = subprocess.CompletedProcess(
|
||||
args=[],
|
||||
returncode=0,
|
||||
stdout="Tests run: 1, Failures: 0, Errors: 0, Skipped: 0",
|
||||
stderr="",
|
||||
args=[], returncode=0, stdout="Tests run: 1, Failures: 0, Errors: 0, Skipped: 0", stderr=""
|
||||
)
|
||||
|
||||
# Should not raise - filter is valid
|
||||
result = _run_maven_tests(
|
||||
project_root,
|
||||
test_files,
|
||||
env,
|
||||
timeout=60,
|
||||
mode="performance",
|
||||
)
|
||||
result = _run_maven_tests(project_root, test_files, env, timeout=60, mode="performance")
|
||||
|
||||
# Verify Maven was called with -Dtest parameter
|
||||
assert mock_run.called
|
||||
|
|
|
|||
|
|
@ -47,8 +47,7 @@ def test_java_tests_project_rootdir_set_to_tests_root(tmp_path):
|
|||
|
||||
# Verify that tests_project_rootdir was updated to tests_root
|
||||
assert test_cfg.tests_project_rootdir == tests_root, (
|
||||
f"Expected tests_project_rootdir to be {tests_root}, "
|
||||
f"but got {test_cfg.tests_project_rootdir}"
|
||||
f"Expected tests_project_rootdir to be {tests_root}, but got {test_cfg.tests_project_rootdir}"
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -68,9 +67,7 @@ def test_python_tests_project_rootdir_unchanged(tmp_path):
|
|||
# Create test config
|
||||
original_tests_project_rootdir = project_root / "some" / "other" / "dir"
|
||||
test_cfg = TestConfig(
|
||||
tests_root=tests_root,
|
||||
project_root_path=project_root,
|
||||
tests_project_rootdir=original_tests_project_rootdir,
|
||||
tests_root=tests_root, project_root_path=project_root, tests_project_rootdir=original_tests_project_rootdir
|
||||
)
|
||||
|
||||
# Mock pytest discovery
|
||||
|
|
|
|||
|
|
@ -17,10 +17,7 @@ def make_func(name: str, class_name: str | None = None) -> FunctionToOptimize:
|
|||
"""Helper to create FunctionToOptimize for testing."""
|
||||
parents = [FunctionParent(name=class_name, type="ClassDef")] if class_name else []
|
||||
return FunctionToOptimize(
|
||||
function_name=name,
|
||||
file_path=Path("/test/file.js"),
|
||||
parents=parents,
|
||||
language="javascript",
|
||||
function_name=name, file_path=Path("/test/file.js"), parents=parents, language="javascript"
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -458,7 +455,9 @@ class TestQualifiedNames:
|
|||
def test_simple_qualified_name(self) -> None:
|
||||
"""Test simple qualified name."""
|
||||
code = "expect(func(5)).toBe(5);"
|
||||
result, _ = transform_expect_calls(code, make_func("func", class_name="module"), "capture", remove_assertions=True)
|
||||
result, _ = transform_expect_calls(
|
||||
code, make_func("func", class_name="module"), "capture", remove_assertions=True
|
||||
)
|
||||
assert result == "codeflash.capture('module.func', '1', func, 5);"
|
||||
|
||||
def test_nested_qualified_name(self) -> None:
|
||||
|
|
|
|||
|
|
@ -26,8 +26,8 @@ import pytest
|
|||
|
||||
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
|
||||
from codeflash.languages.base import Language
|
||||
from codeflash.languages.javascript.support import JavaScriptSupport, TypeScriptSupport
|
||||
from codeflash.languages.javascript.function_optimizer import JavaScriptFunctionOptimizer
|
||||
from codeflash.languages.javascript.support import JavaScriptSupport, TypeScriptSupport
|
||||
from codeflash.verification.verification_utils import TestConfig
|
||||
|
||||
|
||||
|
|
@ -1840,7 +1840,9 @@ export const sendSlackMessage = async (
|
|||
test_config = TestConfig(
|
||||
tests_root=temp_project, tests_project_rootdir=temp_project, project_root_path=temp_project
|
||||
)
|
||||
func_optimizer = JavaScriptFunctionOptimizer(function_to_optimize=fto, test_cfg=test_config, aiservice_client=MagicMock())
|
||||
func_optimizer = JavaScriptFunctionOptimizer(
|
||||
function_to_optimize=fto, test_cfg=test_config, aiservice_client=MagicMock()
|
||||
)
|
||||
ctx = func_optimizer.get_code_optimization_context().unwrap()
|
||||
|
||||
# The read_writable_code should contain the target function AND helper functions
|
||||
|
|
|
|||
|
|
@ -10,18 +10,19 @@ Each test verifies:
|
|||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
|
||||
from codeflash.languages.base import Language, ReferenceInfo
|
||||
from codeflash.languages.javascript.find_references import (
|
||||
ExportedFunction,
|
||||
Reference,
|
||||
ReferenceFinder,
|
||||
ExportedFunction,
|
||||
ReferenceSearchContext,
|
||||
find_references,
|
||||
)
|
||||
from codeflash.languages.base import Language, FunctionInfo, ReferenceInfo
|
||||
from codeflash.languages.python.static_analysis.code_extractor import _format_references_as_markdown
|
||||
from codeflash.models.models import FunctionParent
|
||||
|
||||
|
|
@ -29,12 +30,7 @@ from codeflash.models.models import FunctionParent
|
|||
def make_func(name: str, file_path: Path, class_name: str | None = None) -> FunctionToOptimize:
|
||||
"""Helper to create FunctionToOptimize for testing."""
|
||||
parents = [FunctionParent(name=class_name, type="ClassDef")] if class_name else []
|
||||
return FunctionToOptimize(
|
||||
function_name=name,
|
||||
file_path=file_path,
|
||||
parents=parents,
|
||||
language="javascript",
|
||||
)
|
||||
return FunctionToOptimize(function_name=name, file_path=file_path, parents=parents, language="javascript")
|
||||
|
||||
|
||||
class TestReferenceFinder:
|
||||
|
|
@ -93,30 +89,30 @@ class TestBasicNamedExports:
|
|||
|
||||
# Source file with named export
|
||||
(utils_dir / "DynamicBindingUtils.ts").write_text(
|
||||
'export function getDynamicBindings(value: string): string[] {\n'
|
||||
' const regex = /{{([^}]+)}}/g;\n'
|
||||
' return [];\n'
|
||||
'}\n'
|
||||
"export function getDynamicBindings(value: string): string[] {\n"
|
||||
" const regex = /{{([^}]+)}}/g;\n"
|
||||
" return [];\n"
|
||||
"}\n"
|
||||
)
|
||||
|
||||
# File that imports and uses the function
|
||||
(src_dir / "evaluator.ts").write_text(
|
||||
"import { getDynamicBindings } from './utils/DynamicBindingUtils';\n"
|
||||
'\n'
|
||||
'export function evaluate(expression: string) {\n'
|
||||
' const bindings = getDynamicBindings(expression);\n'
|
||||
' return bindings;\n'
|
||||
'}\n'
|
||||
"\n"
|
||||
"export function evaluate(expression: string) {\n"
|
||||
" const bindings = getDynamicBindings(expression);\n"
|
||||
" return bindings;\n"
|
||||
"}\n"
|
||||
)
|
||||
|
||||
# Another file that uses the function
|
||||
(src_dir / "validator.ts").write_text(
|
||||
"import { getDynamicBindings } from './utils/DynamicBindingUtils';\n"
|
||||
'\n'
|
||||
'export function validateBindings(input: string) {\n'
|
||||
' const bindings = getDynamicBindings(input);\n'
|
||||
' return bindings.length > 0;\n'
|
||||
'}\n'
|
||||
"\n"
|
||||
"export function validateBindings(input: string) {\n"
|
||||
" const bindings = getDynamicBindings(input);\n"
|
||||
" return bindings.length > 0;\n"
|
||||
"}\n"
|
||||
)
|
||||
|
||||
return tmp_path
|
||||
|
|
@ -158,7 +154,8 @@ class TestBasicNamedExports:
|
|||
refs = finder.find_references(make_func("getDynamicBindings", source_file))
|
||||
|
||||
# Convert to ReferenceInfo and sort for consistent ordering
|
||||
ref_infos = sorted([
|
||||
ref_infos = sorted(
|
||||
[
|
||||
ReferenceInfo(
|
||||
file_path=r.file_path,
|
||||
line=r.line,
|
||||
|
|
@ -171,23 +168,25 @@ class TestBasicNamedExports:
|
|||
caller_function=r.caller_function,
|
||||
)
|
||||
for r in refs
|
||||
], key=lambda r: str(r.file_path))
|
||||
],
|
||||
key=lambda r: str(r.file_path),
|
||||
)
|
||||
|
||||
markdown = _format_references_as_markdown(ref_infos, source_file, project_root, Language.TYPESCRIPT)
|
||||
|
||||
expected_markdown = (
|
||||
'```typescript:src/evaluator.ts\n'
|
||||
'function evaluate(expression: string) {\n'
|
||||
' const bindings = getDynamicBindings(expression);\n'
|
||||
' return bindings;\n'
|
||||
'}\n'
|
||||
'```\n'
|
||||
'```typescript:src/validator.ts\n'
|
||||
'function validateBindings(input: string) {\n'
|
||||
' const bindings = getDynamicBindings(input);\n'
|
||||
' return bindings.length > 0;\n'
|
||||
'}\n'
|
||||
'```\n'
|
||||
"```typescript:src/evaluator.ts\n"
|
||||
"function evaluate(expression: string) {\n"
|
||||
" const bindings = getDynamicBindings(expression);\n"
|
||||
" return bindings;\n"
|
||||
"}\n"
|
||||
"```\n"
|
||||
"```typescript:src/validator.ts\n"
|
||||
"function validateBindings(input: string) {\n"
|
||||
" const bindings = getDynamicBindings(input);\n"
|
||||
" return bindings.length > 0;\n"
|
||||
"}\n"
|
||||
"```\n"
|
||||
)
|
||||
assert markdown == expected_markdown
|
||||
|
||||
|
|
@ -203,30 +202,30 @@ class TestDefaultExports:
|
|||
|
||||
# Source file with default export
|
||||
(src_dir / "helper.ts").write_text(
|
||||
'function processData(data: any[]) {\n'
|
||||
' return data.filter(item => item.active);\n'
|
||||
'}\n'
|
||||
'\n'
|
||||
'export default processData;\n'
|
||||
"function processData(data: any[]) {\n"
|
||||
" return data.filter(item => item.active);\n"
|
||||
"}\n"
|
||||
"\n"
|
||||
"export default processData;\n"
|
||||
)
|
||||
|
||||
# File that imports the default export
|
||||
(src_dir / "main.ts").write_text(
|
||||
"import processData from './helper';\n"
|
||||
'\n'
|
||||
'export function handleData(items: any[]) {\n'
|
||||
' const processed = processData(items);\n'
|
||||
' return processed.length;\n'
|
||||
'}\n'
|
||||
"\n"
|
||||
"export function handleData(items: any[]) {\n"
|
||||
" const processed = processData(items);\n"
|
||||
" return processed.length;\n"
|
||||
"}\n"
|
||||
)
|
||||
|
||||
# File that imports with a different name
|
||||
(src_dir / "alternative.ts").write_text(
|
||||
"import myProcessor from './helper';\n"
|
||||
'\n'
|
||||
'export function process(items: any[]) {\n'
|
||||
' return myProcessor(items);\n'
|
||||
'}\n'
|
||||
"\n"
|
||||
"export function process(items: any[]) {\n"
|
||||
" return myProcessor(items);\n"
|
||||
"}\n"
|
||||
)
|
||||
|
||||
return tmp_path
|
||||
|
|
@ -263,30 +262,38 @@ class TestDefaultExports:
|
|||
source_file = project_root / "src" / "helper.ts"
|
||||
|
||||
refs = finder.find_references(make_func("processData", source_file))
|
||||
ref_infos = sorted([
|
||||
ref_infos = sorted(
|
||||
[
|
||||
ReferenceInfo(
|
||||
file_path=r.file_path, line=r.line, column=r.column,
|
||||
end_line=r.end_line, end_column=r.end_column, context=r.context,
|
||||
reference_type=r.reference_type, import_name=r.import_name,
|
||||
file_path=r.file_path,
|
||||
line=r.line,
|
||||
column=r.column,
|
||||
end_line=r.end_line,
|
||||
end_column=r.end_column,
|
||||
context=r.context,
|
||||
reference_type=r.reference_type,
|
||||
import_name=r.import_name,
|
||||
caller_function=r.caller_function,
|
||||
)
|
||||
for r in refs
|
||||
], key=lambda r: str(r.file_path))
|
||||
],
|
||||
key=lambda r: str(r.file_path),
|
||||
)
|
||||
|
||||
markdown = _format_references_as_markdown(ref_infos, source_file, project_root, Language.TYPESCRIPT)
|
||||
|
||||
expected_markdown = (
|
||||
'```typescript:src/alternative.ts\n'
|
||||
'function process(items: any[]) {\n'
|
||||
' return myProcessor(items);\n'
|
||||
'}\n'
|
||||
'```\n'
|
||||
'```typescript:src/main.ts\n'
|
||||
'function handleData(items: any[]) {\n'
|
||||
' const processed = processData(items);\n'
|
||||
' return processed.length;\n'
|
||||
'}\n'
|
||||
'```\n'
|
||||
"```typescript:src/alternative.ts\n"
|
||||
"function process(items: any[]) {\n"
|
||||
" return myProcessor(items);\n"
|
||||
"}\n"
|
||||
"```\n"
|
||||
"```typescript:src/main.ts\n"
|
||||
"function handleData(items: any[]) {\n"
|
||||
" const processed = processData(items);\n"
|
||||
" return processed.length;\n"
|
||||
"}\n"
|
||||
"```\n"
|
||||
)
|
||||
assert markdown == expected_markdown
|
||||
|
||||
|
|
@ -304,23 +311,21 @@ class TestReExports:
|
|||
|
||||
# Original function file
|
||||
(utils_dir / "filterUtils.ts").write_text(
|
||||
'export function filterBySearchTerm(items: any[], term: string) {\n'
|
||||
' return items.filter(i => i.name.includes(term));\n'
|
||||
'}\n'
|
||||
"export function filterBySearchTerm(items: any[], term: string) {\n"
|
||||
" return items.filter(i => i.name.includes(term));\n"
|
||||
"}\n"
|
||||
)
|
||||
|
||||
# Index file that re-exports
|
||||
(utils_dir / "index.ts").write_text(
|
||||
"export { filterBySearchTerm } from './filterUtils';\n"
|
||||
)
|
||||
(utils_dir / "index.ts").write_text("export { filterBySearchTerm } from './filterUtils';\n")
|
||||
|
||||
# Consumer that imports from index
|
||||
(src_dir / "consumer.ts").write_text(
|
||||
"import { filterBySearchTerm } from './utils';\n"
|
||||
'\n'
|
||||
'export function searchItems(items: any[], query: string) {\n'
|
||||
' return filterBySearchTerm(items, query);\n'
|
||||
'}\n'
|
||||
"\n"
|
||||
"export function searchItems(items: any[], query: string) {\n"
|
||||
" return filterBySearchTerm(items, query);\n"
|
||||
"}\n"
|
||||
)
|
||||
|
||||
return tmp_path
|
||||
|
|
@ -352,27 +357,35 @@ class TestReExports:
|
|||
source_file = project_root / "src" / "utils" / "filterUtils.ts"
|
||||
|
||||
refs = finder.find_references(make_func("filterBySearchTerm", source_file))
|
||||
ref_infos = sorted([
|
||||
ref_infos = sorted(
|
||||
[
|
||||
ReferenceInfo(
|
||||
file_path=r.file_path, line=r.line, column=r.column,
|
||||
end_line=r.end_line, end_column=r.end_column, context=r.context,
|
||||
reference_type=r.reference_type, import_name=r.import_name,
|
||||
file_path=r.file_path,
|
||||
line=r.line,
|
||||
column=r.column,
|
||||
end_line=r.end_line,
|
||||
end_column=r.end_column,
|
||||
context=r.context,
|
||||
reference_type=r.reference_type,
|
||||
import_name=r.import_name,
|
||||
caller_function=r.caller_function,
|
||||
)
|
||||
for r in refs
|
||||
], key=lambda r: str(r.file_path))
|
||||
],
|
||||
key=lambda r: str(r.file_path),
|
||||
)
|
||||
|
||||
markdown = _format_references_as_markdown(ref_infos, source_file, project_root, Language.TYPESCRIPT)
|
||||
|
||||
expected_markdown = (
|
||||
'```typescript:src/consumer.ts\n'
|
||||
'function searchItems(items: any[], query: string) {\n'
|
||||
' return filterBySearchTerm(items, query);\n'
|
||||
'}\n'
|
||||
'```\n'
|
||||
'```typescript:src/utils/index.ts\n'
|
||||
"```typescript:src/consumer.ts\n"
|
||||
"function searchItems(items: any[], query: string) {\n"
|
||||
" return filterBySearchTerm(items, query);\n"
|
||||
"}\n"
|
||||
"```\n"
|
||||
"```typescript:src/utils/index.ts\n"
|
||||
"export { filterBySearchTerm } from './filterUtils';\n"
|
||||
'```\n'
|
||||
"```\n"
|
||||
)
|
||||
assert markdown == expected_markdown
|
||||
|
||||
|
|
@ -388,19 +401,17 @@ class TestCallbackPatterns:
|
|||
|
||||
# Helper function
|
||||
(src_dir / "transforms.ts").write_text(
|
||||
'export function normalizeItem(item: any) {\n'
|
||||
' return { ...item, normalized: true };\n'
|
||||
'}\n'
|
||||
"export function normalizeItem(item: any) {\n return { ...item, normalized: true };\n}\n"
|
||||
)
|
||||
|
||||
# Consumer using callbacks
|
||||
(src_dir / "processor.ts").write_text(
|
||||
"import { normalizeItem } from './transforms';\n"
|
||||
'\n'
|
||||
'export function processItems(items: any[]) {\n'
|
||||
' const normalized = items.map(normalizeItem);\n'
|
||||
' return normalized;\n'
|
||||
'}\n'
|
||||
"\n"
|
||||
"export function processItems(items: any[]) {\n"
|
||||
" const normalized = items.map(normalizeItem);\n"
|
||||
" return normalized;\n"
|
||||
"}\n"
|
||||
)
|
||||
|
||||
return tmp_path
|
||||
|
|
@ -430,9 +441,14 @@ class TestCallbackPatterns:
|
|||
refs = finder.find_references(make_func("normalizeItem", source_file))
|
||||
ref_infos = [
|
||||
ReferenceInfo(
|
||||
file_path=r.file_path, line=r.line, column=r.column,
|
||||
end_line=r.end_line, end_column=r.end_column, context=r.context,
|
||||
reference_type=r.reference_type, import_name=r.import_name,
|
||||
file_path=r.file_path,
|
||||
line=r.line,
|
||||
column=r.column,
|
||||
end_line=r.end_line,
|
||||
end_column=r.end_column,
|
||||
context=r.context,
|
||||
reference_type=r.reference_type,
|
||||
import_name=r.import_name,
|
||||
caller_function=r.caller_function,
|
||||
)
|
||||
for r in refs
|
||||
|
|
@ -441,12 +457,12 @@ class TestCallbackPatterns:
|
|||
markdown = _format_references_as_markdown(ref_infos, source_file, project_root, Language.TYPESCRIPT)
|
||||
|
||||
expected_markdown = (
|
||||
'```typescript:src/processor.ts\n'
|
||||
'function processItems(items: any[]) {\n'
|
||||
' const normalized = items.map(normalizeItem);\n'
|
||||
' return normalized;\n'
|
||||
'}\n'
|
||||
'```\n'
|
||||
"```typescript:src/processor.ts\n"
|
||||
"function processItems(items: any[]) {\n"
|
||||
" const normalized = items.map(normalizeItem);\n"
|
||||
" return normalized;\n"
|
||||
"}\n"
|
||||
"```\n"
|
||||
)
|
||||
assert expected_markdown == markdown
|
||||
|
||||
|
|
@ -462,19 +478,17 @@ class TestAliasImports:
|
|||
|
||||
# Source file
|
||||
(src_dir / "utils.ts").write_text(
|
||||
'export function computeValue(input: number): number {\n'
|
||||
' return input * 2;\n'
|
||||
'}\n'
|
||||
"export function computeValue(input: number): number {\n return input * 2;\n}\n"
|
||||
)
|
||||
|
||||
# File using alias
|
||||
(src_dir / "consumer.ts").write_text(
|
||||
"import { computeValue as calculate } from './utils';\n"
|
||||
'\n'
|
||||
'export function processNumber(n: number) {\n'
|
||||
' const result = calculate(n);\n'
|
||||
' return result + 10;\n'
|
||||
'}\n'
|
||||
"\n"
|
||||
"export function processNumber(n: number) {\n"
|
||||
" const result = calculate(n);\n"
|
||||
" return result + 10;\n"
|
||||
"}\n"
|
||||
)
|
||||
|
||||
return tmp_path
|
||||
|
|
@ -504,9 +518,14 @@ class TestAliasImports:
|
|||
refs = finder.find_references(make_func("computeValue", source_file))
|
||||
ref_infos = [
|
||||
ReferenceInfo(
|
||||
file_path=r.file_path, line=r.line, column=r.column,
|
||||
end_line=r.end_line, end_column=r.end_column, context=r.context,
|
||||
reference_type=r.reference_type, import_name=r.import_name,
|
||||
file_path=r.file_path,
|
||||
line=r.line,
|
||||
column=r.column,
|
||||
end_line=r.end_line,
|
||||
end_column=r.end_column,
|
||||
context=r.context,
|
||||
reference_type=r.reference_type,
|
||||
import_name=r.import_name,
|
||||
caller_function=r.caller_function,
|
||||
)
|
||||
for r in refs
|
||||
|
|
@ -515,12 +534,12 @@ class TestAliasImports:
|
|||
markdown = _format_references_as_markdown(ref_infos, source_file, project_root, Language.TYPESCRIPT)
|
||||
|
||||
expected_markdown = (
|
||||
'```typescript:src/consumer.ts\n'
|
||||
'function processNumber(n: number) {\n'
|
||||
' const result = calculate(n);\n'
|
||||
' return result + 10;\n'
|
||||
'}\n'
|
||||
'```\n'
|
||||
"```typescript:src/consumer.ts\n"
|
||||
"function processNumber(n: number) {\n"
|
||||
" const result = calculate(n);\n"
|
||||
" return result + 10;\n"
|
||||
"}\n"
|
||||
"```\n"
|
||||
)
|
||||
assert expected_markdown == markdown
|
||||
|
||||
|
|
@ -536,18 +555,16 @@ class TestNamespaceImports:
|
|||
|
||||
# Source file with multiple exports
|
||||
(src_dir / "mathUtils.ts").write_text(
|
||||
'export function add(a: number, b: number): number {\n'
|
||||
' return a + b;\n'
|
||||
'}\n'
|
||||
"export function add(a: number, b: number): number {\n return a + b;\n}\n"
|
||||
)
|
||||
|
||||
# File using namespace import
|
||||
(src_dir / "calculator.ts").write_text(
|
||||
"import * as MathUtils from './mathUtils';\n"
|
||||
'\n'
|
||||
'export function calculate(a: number, b: number) {\n'
|
||||
' return MathUtils.add(a, b);\n'
|
||||
'}\n'
|
||||
"\n"
|
||||
"export function calculate(a: number, b: number) {\n"
|
||||
" return MathUtils.add(a, b);\n"
|
||||
"}\n"
|
||||
)
|
||||
|
||||
return tmp_path
|
||||
|
|
@ -576,9 +593,14 @@ class TestNamespaceImports:
|
|||
refs = finder.find_references(make_func("add", source_file))
|
||||
ref_infos = [
|
||||
ReferenceInfo(
|
||||
file_path=r.file_path, line=r.line, column=r.column,
|
||||
end_line=r.end_line, end_column=r.end_column, context=r.context,
|
||||
reference_type=r.reference_type, import_name=r.import_name,
|
||||
file_path=r.file_path,
|
||||
line=r.line,
|
||||
column=r.column,
|
||||
end_line=r.end_line,
|
||||
end_column=r.end_column,
|
||||
context=r.context,
|
||||
reference_type=r.reference_type,
|
||||
import_name=r.import_name,
|
||||
caller_function=r.caller_function,
|
||||
)
|
||||
for r in refs
|
||||
|
|
@ -587,11 +609,11 @@ class TestNamespaceImports:
|
|||
markdown = _format_references_as_markdown(ref_infos, source_file, project_root, Language.TYPESCRIPT)
|
||||
|
||||
expected_markdown = (
|
||||
'```typescript:src/calculator.ts\n'
|
||||
'function calculate(a: number, b: number) {\n'
|
||||
' return MathUtils.add(a, b);\n'
|
||||
'}\n'
|
||||
'```\n'
|
||||
"```typescript:src/calculator.ts\n"
|
||||
"function calculate(a: number, b: number) {\n"
|
||||
" return MathUtils.add(a, b);\n"
|
||||
"}\n"
|
||||
"```\n"
|
||||
)
|
||||
assert expected_markdown == markdown
|
||||
|
||||
|
|
@ -607,21 +629,19 @@ class TestMemoizedFunctions:
|
|||
|
||||
# Source file with function to be memoized
|
||||
(src_dir / "expensive.ts").write_text(
|
||||
'export function computeExpensive(x: number): number {\n'
|
||||
' return x * x;\n'
|
||||
'}\n'
|
||||
"export function computeExpensive(x: number): number {\n return x * x;\n}\n"
|
||||
)
|
||||
|
||||
# File that memoizes the function
|
||||
(src_dir / "memoized.ts").write_text(
|
||||
"import memoize from 'micro-memoize';\n"
|
||||
"import { computeExpensive } from './expensive';\n"
|
||||
'\n'
|
||||
'export const memoizedCompute = memoize(computeExpensive);\n'
|
||||
'\n'
|
||||
'export function process(x: number) {\n'
|
||||
' return computeExpensive(x) + memoizedCompute(x);\n'
|
||||
'}\n'
|
||||
"\n"
|
||||
"export const memoizedCompute = memoize(computeExpensive);\n"
|
||||
"\n"
|
||||
"export function process(x: number) {\n"
|
||||
" return computeExpensive(x) + memoizedCompute(x);\n"
|
||||
"}\n"
|
||||
)
|
||||
|
||||
return tmp_path
|
||||
|
|
@ -659,10 +679,10 @@ class TestSameFileReferences:
|
|||
|
||||
# File with internal references
|
||||
(src_dir / "recursive.ts").write_text(
|
||||
'export function factorial(n: number): number {\n'
|
||||
' if (n <= 1) return 1;\n'
|
||||
' return n * factorial(n - 1);\n'
|
||||
'}\n'
|
||||
"export function factorial(n: number): number {\n"
|
||||
" if (n <= 1) return 1;\n"
|
||||
" return n * factorial(n - 1);\n"
|
||||
"}\n"
|
||||
)
|
||||
|
||||
return tmp_path
|
||||
|
|
@ -697,24 +717,20 @@ class TestComplexMultiFileScenarios:
|
|||
|
||||
# Core utility function
|
||||
(src_dir / "utils" / "widgetUtils.ts").write_text(
|
||||
'export function isLargeWidget(type: string): boolean {\n'
|
||||
" return ['TABLE', 'LIST'].includes(type);\n"
|
||||
'}\n'
|
||||
"export function isLargeWidget(type: string): boolean {\n return ['TABLE', 'LIST'].includes(type);\n}\n"
|
||||
)
|
||||
|
||||
# Re-export from index
|
||||
(src_dir / "utils" / "index.ts").write_text(
|
||||
"export { isLargeWidget } from './widgetUtils';\n"
|
||||
)
|
||||
(src_dir / "utils" / "index.ts").write_text("export { isLargeWidget } from './widgetUtils';\n")
|
||||
|
||||
# Component using the function via re-export
|
||||
(src_dir / "components" / "Widget.tsx").write_text(
|
||||
"import { isLargeWidget } from '../utils';\n"
|
||||
'\n'
|
||||
'export function Widget({ type }: { type: string }) {\n'
|
||||
' const isLarge = isLargeWidget(type);\n'
|
||||
' return isLarge;\n'
|
||||
'}\n'
|
||||
"\n"
|
||||
"export function Widget({ type }: { type: string }) {\n"
|
||||
" const isLarge = isLargeWidget(type);\n"
|
||||
" return isLarge;\n"
|
||||
"}\n"
|
||||
)
|
||||
|
||||
return tmp_path
|
||||
|
|
@ -745,28 +761,36 @@ class TestComplexMultiFileScenarios:
|
|||
source_file = project_root / "src" / "utils" / "widgetUtils.ts"
|
||||
|
||||
refs = finder.find_references(make_func("isLargeWidget", source_file))
|
||||
ref_infos = sorted([
|
||||
ref_infos = sorted(
|
||||
[
|
||||
ReferenceInfo(
|
||||
file_path=r.file_path, line=r.line, column=r.column,
|
||||
end_line=r.end_line, end_column=r.end_column, context=r.context,
|
||||
reference_type=r.reference_type, import_name=r.import_name,
|
||||
file_path=r.file_path,
|
||||
line=r.line,
|
||||
column=r.column,
|
||||
end_line=r.end_line,
|
||||
end_column=r.end_column,
|
||||
context=r.context,
|
||||
reference_type=r.reference_type,
|
||||
import_name=r.import_name,
|
||||
caller_function=r.caller_function,
|
||||
)
|
||||
for r in refs
|
||||
], key=lambda r: str(r.file_path))
|
||||
],
|
||||
key=lambda r: str(r.file_path),
|
||||
)
|
||||
|
||||
markdown = _format_references_as_markdown(ref_infos, source_file, project_root, Language.TYPESCRIPT)
|
||||
|
||||
expected_markdown = (
|
||||
'```typescript:src/components/Widget.tsx\n'
|
||||
'function Widget({ type }: { type: string }) {\n'
|
||||
' const isLarge = isLargeWidget(type);\n'
|
||||
' return isLarge;\n'
|
||||
'}\n'
|
||||
'```\n'
|
||||
'```typescript:src/utils/index.ts\n'
|
||||
"```typescript:src/components/Widget.tsx\n"
|
||||
"function Widget({ type }: { type: string }) {\n"
|
||||
" const isLarge = isLargeWidget(type);\n"
|
||||
" return isLarge;\n"
|
||||
"}\n"
|
||||
"```\n"
|
||||
"```typescript:src/utils/index.ts\n"
|
||||
"export { isLargeWidget } from './widgetUtils';\n"
|
||||
'```\n'
|
||||
"```\n"
|
||||
)
|
||||
assert markdown == expected_markdown
|
||||
|
||||
|
|
@ -794,13 +818,13 @@ class TestEdgeCases:
|
|||
"""Test handling of non-exported function."""
|
||||
# Create a file with non-exported function
|
||||
(project_root / "src" / "private.ts").write_text(
|
||||
'function internalHelper() {\n'
|
||||
' return 42;\n'
|
||||
'}\n'
|
||||
'\n'
|
||||
'export function publicFunction() {\n'
|
||||
' return internalHelper();\n'
|
||||
'}\n'
|
||||
"function internalHelper() {\n"
|
||||
" return 42;\n"
|
||||
"}\n"
|
||||
"\n"
|
||||
"export function publicFunction() {\n"
|
||||
" return internalHelper();\n"
|
||||
"}\n"
|
||||
)
|
||||
|
||||
finder = ReferenceFinder(project_root)
|
||||
|
|
@ -824,7 +848,9 @@ class TestEdgeCases:
|
|||
|
||||
def test_format_references_empty_list(self, project_root):
|
||||
"""Test _format_references_as_markdown with empty list."""
|
||||
markdown = _format_references_as_markdown([], project_root / "src" / "file.ts", project_root, Language.TYPESCRIPT)
|
||||
markdown = _format_references_as_markdown(
|
||||
[], project_root / "src" / "file.ts", project_root, Language.TYPESCRIPT
|
||||
)
|
||||
assert markdown == ""
|
||||
|
||||
|
||||
|
|
@ -839,22 +865,22 @@ class TestCommonJSPatterns:
|
|||
|
||||
# CommonJS module
|
||||
(src_dir / "helpers.js").write_text(
|
||||
'function processConfig(config) {\n'
|
||||
' return { ...config, processed: true };\n'
|
||||
'}\n'
|
||||
'\n'
|
||||
'module.exports = { processConfig };\n'
|
||||
"function processConfig(config) {\n"
|
||||
" return { ...config, processed: true };\n"
|
||||
"}\n"
|
||||
"\n"
|
||||
"module.exports = { processConfig };\n"
|
||||
)
|
||||
|
||||
# Consumer using destructured require
|
||||
(src_dir / "main.js").write_text(
|
||||
"const { processConfig } = require('./helpers');\n"
|
||||
'\n'
|
||||
'function handleConfig(config) {\n'
|
||||
' return processConfig(config);\n'
|
||||
'}\n'
|
||||
'\n'
|
||||
'module.exports = handleConfig;\n'
|
||||
"\n"
|
||||
"function handleConfig(config) {\n"
|
||||
" return processConfig(config);\n"
|
||||
"}\n"
|
||||
"\n"
|
||||
"module.exports = handleConfig;\n"
|
||||
)
|
||||
|
||||
return tmp_path
|
||||
|
|
@ -879,24 +905,28 @@ class TestCommonJSPatterns:
|
|||
source_file = project_root / "src" / "helpers.js"
|
||||
|
||||
refs = finder.find_references(make_func("processConfig", source_file))
|
||||
ref_infos = sorted([
|
||||
ref_infos = sorted(
|
||||
[
|
||||
ReferenceInfo(
|
||||
file_path=r.file_path, line=r.line, column=r.column,
|
||||
end_line=r.end_line, end_column=r.end_column, context=r.context,
|
||||
reference_type=r.reference_type, import_name=r.import_name,
|
||||
file_path=r.file_path,
|
||||
line=r.line,
|
||||
column=r.column,
|
||||
end_line=r.end_line,
|
||||
end_column=r.end_column,
|
||||
context=r.context,
|
||||
reference_type=r.reference_type,
|
||||
import_name=r.import_name,
|
||||
caller_function=r.caller_function,
|
||||
)
|
||||
for r in refs
|
||||
], key=lambda r: str(r.file_path))
|
||||
],
|
||||
key=lambda r: str(r.file_path),
|
||||
)
|
||||
|
||||
markdown = _format_references_as_markdown(ref_infos, source_file, project_root, Language.JAVASCRIPT)
|
||||
|
||||
expected_markdown = (
|
||||
'```javascript:src/main.js\n'
|
||||
'function handleConfig(config) {\n'
|
||||
' return processConfig(config);\n'
|
||||
'}\n'
|
||||
'```\n'
|
||||
"```javascript:src/main.js\nfunction handleConfig(config) {\n return processConfig(config);\n}\n```\n"
|
||||
)
|
||||
assert markdown == expected_markdown
|
||||
|
||||
|
|
@ -910,18 +940,10 @@ class TestConvenienceFunction:
|
|||
src_dir = tmp_path / "src"
|
||||
src_dir.mkdir()
|
||||
|
||||
(src_dir / "utils.ts").write_text(
|
||||
'export function helper() {\n'
|
||||
' return 42;\n'
|
||||
'}\n'
|
||||
)
|
||||
(src_dir / "utils.ts").write_text("export function helper() {\n return 42;\n}\n")
|
||||
|
||||
(src_dir / "main.ts").write_text(
|
||||
"import { helper } from './utils';\n"
|
||||
'\n'
|
||||
'export function main() {\n'
|
||||
' return helper();\n'
|
||||
'}\n'
|
||||
"import { helper } from './utils';\n\nexport function main() {\n return helper();\n}\n"
|
||||
)
|
||||
|
||||
return tmp_path
|
||||
|
|
@ -988,10 +1010,7 @@ class TestExportedFunctionDataclass:
|
|||
def test_exported_function_named(self, tmp_path):
|
||||
"""Test ExportedFunction for named export."""
|
||||
exp = ExportedFunction(
|
||||
function_name="myHelper",
|
||||
export_name="myHelper",
|
||||
is_default=False,
|
||||
file_path=tmp_path / "utils.ts",
|
||||
function_name="myHelper", export_name="myHelper", is_default=False, file_path=tmp_path / "utils.ts"
|
||||
)
|
||||
|
||||
assert exp.function_name == "myHelper"
|
||||
|
|
@ -1002,10 +1021,7 @@ class TestExportedFunctionDataclass:
|
|||
def test_exported_function_default(self, tmp_path):
|
||||
"""Test ExportedFunction for default export."""
|
||||
exp = ExportedFunction(
|
||||
function_name="processData",
|
||||
export_name="default",
|
||||
is_default=True,
|
||||
file_path=tmp_path / "processor.ts",
|
||||
function_name="processData", export_name="default", is_default=True, file_path=tmp_path / "processor.ts"
|
||||
)
|
||||
|
||||
assert exp.function_name == "processData"
|
||||
|
|
@ -1046,23 +1062,19 @@ class TestEdgeCasesAdvanced:
|
|||
|
||||
# Create circular import structure
|
||||
(src_dir / "a.ts").write_text(
|
||||
"import { funcB } from './b';\n"
|
||||
'\n'
|
||||
'export function funcA() {\n'
|
||||
' return funcB() + 1;\n'
|
||||
'}\n'
|
||||
"import { funcB } from './b';\n\nexport function funcA() {\n return funcB() + 1;\n}\n"
|
||||
)
|
||||
|
||||
(src_dir / "b.ts").write_text(
|
||||
"import { funcA } from './a';\n"
|
||||
'\n'
|
||||
'export function funcB() {\n'
|
||||
' return 42;\n'
|
||||
'}\n'
|
||||
'\n'
|
||||
'export function callsA() {\n'
|
||||
' return funcA();\n'
|
||||
'}\n'
|
||||
"\n"
|
||||
"export function funcB() {\n"
|
||||
" return 42;\n"
|
||||
"}\n"
|
||||
"\n"
|
||||
"export function callsA() {\n"
|
||||
" return funcA();\n"
|
||||
"}\n"
|
||||
)
|
||||
|
||||
finder = ReferenceFinder(project_root)
|
||||
|
|
@ -1080,19 +1092,11 @@ class TestEdgeCasesAdvanced:
|
|||
"""Test that syntax errors in files are handled gracefully."""
|
||||
src_dir = project_root / "src"
|
||||
|
||||
(src_dir / "valid.ts").write_text(
|
||||
'export function validFunction() {\n'
|
||||
' return 42;\n'
|
||||
'}\n'
|
||||
)
|
||||
(src_dir / "valid.ts").write_text("export function validFunction() {\n return 42;\n}\n")
|
||||
|
||||
# Create a file with syntax error
|
||||
(src_dir / "invalid.ts").write_text(
|
||||
"import { validFunction } from './valid';\n"
|
||||
'\n'
|
||||
'export function broken( {\n'
|
||||
' return validFunction(\n'
|
||||
'}\n'
|
||||
"import { validFunction } from './valid';\n\nexport function broken( {\n return validFunction(\n}\n"
|
||||
)
|
||||
|
||||
finder = ReferenceFinder(project_root)
|
||||
|
|
|
|||
|
|
@ -4,7 +4,6 @@ These tests verify that the ImportResolver correctly resolves import paths
|
|||
to actual file paths, enabling multi-file context extraction.
|
||||
"""
|
||||
|
||||
|
||||
import pytest
|
||||
|
||||
from codeflash.languages.javascript.import_resolver import HelperSearchContext, ImportResolver, MultiFileHelperFinder
|
||||
|
|
|
|||
|
|
@ -278,6 +278,7 @@ version = '1.0.0'
|
|||
assert len(info.source_roots) == 1
|
||||
assert len(info.test_roots) == 1
|
||||
|
||||
|
||||
class TestXmlModuleExtraction:
|
||||
"""Tests for XML-based module extraction replacing regex."""
|
||||
|
||||
|
|
@ -374,6 +375,7 @@ class TestMavenProfiles:
|
|||
profiles = os.environ.get("CODEFLASH_MAVEN_PROFILES", "").strip()
|
||||
assert profiles == "my-profile"
|
||||
|
||||
|
||||
class TestMavenExecutableWithProjectRoot:
|
||||
"""Tests for find_maven_executable with project_root parameter."""
|
||||
|
||||
|
|
@ -554,7 +556,6 @@ class TestAddCodeflashDependencyToPom:
|
|||
def test_returns_false_when_no_dependencies_tag(self, tmp_path):
|
||||
pom = tmp_path / "pom.xml"
|
||||
pom.write_text(
|
||||
'<?xml version="1.0"?>\n<project><modelVersion>4.0.0</modelVersion></project>\n',
|
||||
encoding="utf-8",
|
||||
'<?xml version="1.0"?>\n<project><modelVersion>4.0.0</modelVersion></project>\n', encoding="utf-8"
|
||||
)
|
||||
assert add_codeflash_dependency_to_pom(pom) is False
|
||||
|
|
|
|||
|
|
@ -6,15 +6,15 @@ proceeding to SQLite file comparison (which would crash with FileNotFoundError s
|
|||
instrumentation hooks never fired).
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
|
||||
from codeflash.either import Failure
|
||||
from codeflash.models.models import FunctionTestInvocation, InvocationId, TestResults
|
||||
from codeflash.models.test_type import TestType
|
||||
|
||||
|
||||
def make_test_invocation(*, did_pass: bool, test_type: TestType = TestType.EXISTING_UNIT_TEST) -> FunctionTestInvocation:
|
||||
def make_test_invocation(
|
||||
*, did_pass: bool, test_type: TestType = TestType.EXISTING_UNIT_TEST
|
||||
) -> FunctionTestInvocation:
|
||||
"""Helper to create a FunctionTestInvocation with minimal required fields."""
|
||||
return FunctionTestInvocation(
|
||||
loop_index=1,
|
||||
|
|
@ -101,7 +101,8 @@ class TestCandidateBehavioralTestGuard:
|
|||
"""All test types failing should yield 0 total passed."""
|
||||
results = TestResults()
|
||||
for tt in [TestType.EXISTING_UNIT_TEST, TestType.GENERATED_REGRESSION, TestType.REPLAY_TEST]:
|
||||
results.add(FunctionTestInvocation(
|
||||
results.add(
|
||||
FunctionTestInvocation(
|
||||
loop_index=1,
|
||||
id=InvocationId(
|
||||
test_module_path="com.example.FooTest",
|
||||
|
|
@ -117,7 +118,8 @@ class TestCandidateBehavioralTestGuard:
|
|||
test_type=tt,
|
||||
return_value=None,
|
||||
timed_out=False,
|
||||
))
|
||||
)
|
||||
)
|
||||
|
||||
report = results.get_test_pass_fail_report_by_type()
|
||||
total_passed = sum(r.get("passed", 0) for r in report.values())
|
||||
|
|
@ -129,7 +131,8 @@ class TestCandidateBehavioralTestGuard:
|
|||
results = TestResults()
|
||||
# Many failures
|
||||
for i in range(5):
|
||||
results.add(FunctionTestInvocation(
|
||||
results.add(
|
||||
FunctionTestInvocation(
|
||||
loop_index=1,
|
||||
id=InvocationId(
|
||||
test_module_path="com.example.FooTest",
|
||||
|
|
@ -145,9 +148,11 @@ class TestCandidateBehavioralTestGuard:
|
|||
test_type=TestType.GENERATED_REGRESSION,
|
||||
return_value=None,
|
||||
timed_out=False,
|
||||
))
|
||||
)
|
||||
)
|
||||
# One pass
|
||||
results.add(FunctionTestInvocation(
|
||||
results.add(
|
||||
FunctionTestInvocation(
|
||||
loop_index=1,
|
||||
id=InvocationId(
|
||||
test_module_path="com.example.FooTest",
|
||||
|
|
@ -163,7 +168,8 @@ class TestCandidateBehavioralTestGuard:
|
|||
test_type=TestType.EXISTING_UNIT_TEST,
|
||||
return_value=None,
|
||||
timed_out=False,
|
||||
))
|
||||
)
|
||||
)
|
||||
|
||||
report = results.get_test_pass_fail_report_by_type()
|
||||
total_passed = sum(r.get("passed", 0) for r in report.values())
|
||||
|
|
|
|||
|
|
@ -1,24 +1,17 @@
|
|||
"""Tests for Java test result comparison."""
|
||||
|
||||
import json
|
||||
import shutil
|
||||
import sqlite3
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from codeflash.languages.java.comparator import (
|
||||
compare_invocations_directly,
|
||||
compare_test_results,
|
||||
values_equal,
|
||||
)
|
||||
from codeflash.languages.java.comparator import compare_invocations_directly, compare_test_results, values_equal
|
||||
from codeflash.models.models import TestDiffScope
|
||||
|
||||
# Skip tests that require Java runtime if Java is not available
|
||||
requires_java = pytest.mark.skipif(
|
||||
shutil.which("java") is None,
|
||||
reason="Java not found - skipping Comparator integration tests",
|
||||
shutil.which("java") is None, reason="Java not found - skipping Comparator integration tests"
|
||||
)
|
||||
|
||||
# Kryo-serialized bytes for common test values.
|
||||
|
|
@ -38,7 +31,9 @@ KRYO_STR_RESULT3 = bytes([0x03, 0x01, 0x7B, 0x22, 0x72, 0x65, 0x73, 0x75, 0x6C,
|
|||
KRYO_STR_VALUE1 = bytes([0x03, 0x01, 0x7B, 0x22, 0x76, 0x61, 0x6C, 0x75, 0x65, 0x22, 0x3A, 0x20, 0x31, 0xFD])
|
||||
KRYO_STR_VALUE2 = bytes([0x03, 0x01, 0x7B, 0x22, 0x76, 0x61, 0x6C, 0x75, 0x65, 0x22, 0x3A, 0x20, 0x32, 0xFD])
|
||||
KRYO_STR_VALUE42 = bytes([0x03, 0x01, 0x7B, 0x22, 0x76, 0x61, 0x6C, 0x75, 0x65, 0x22, 0x3A, 0x20, 0x34, 0x32, 0xFD])
|
||||
KRYO_STR_VALUE100 = bytes([0x03, 0x01, 0x7B, 0x22, 0x76, 0x61, 0x6C, 0x75, 0x65, 0x22, 0x3A, 0x20, 0x31, 0x30, 0x30, 0xFD])
|
||||
KRYO_STR_VALUE100 = bytes(
|
||||
[0x03, 0x01, 0x7B, 0x22, 0x76, 0x61, 0x6C, 0x75, 0x65, 0x22, 0x3A, 0x20, 0x31, 0x30, 0x30, 0xFD]
|
||||
)
|
||||
KRYO_DOUBLE_1_0000000001 = bytes([0x0A, 0x38, 0xDF, 0x06, 0x00, 0x00, 0x00, 0xF0, 0x3F])
|
||||
KRYO_DOUBLE_1_0000000002 = bytes([0x0A, 0x70, 0xBE, 0x0D, 0x00, 0x00, 0x00, 0xF0, 0x3F])
|
||||
KRYO_NAN = bytes([0x0A, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xF8, 0x7F])
|
||||
|
|
@ -67,12 +62,8 @@ class TestDirectComparison:
|
|||
|
||||
def test_different_return_values(self):
|
||||
"""Test detecting different return values."""
|
||||
original = {
|
||||
"1": {"result_json": '{"value": 42}', "error_json": None},
|
||||
}
|
||||
candidate = {
|
||||
"1": {"result_json": '{"value": 99}', "error_json": None},
|
||||
}
|
||||
original = {"1": {"result_json": '{"value": 42}', "error_json": None}}
|
||||
candidate = {"1": {"result_json": '{"value": 99}', "error_json": None}}
|
||||
|
||||
equivalent, diffs = compare_invocations_directly(original, candidate)
|
||||
|
||||
|
|
@ -89,7 +80,7 @@ class TestDirectComparison:
|
|||
"2": {"result_json": '{"value": 100}', "error_json": None},
|
||||
}
|
||||
candidate = {
|
||||
"1": {"result_json": '{"value": 42}', "error_json": None},
|
||||
"1": {"result_json": '{"value": 42}', "error_json": None}
|
||||
# Missing invocation 2
|
||||
}
|
||||
|
||||
|
|
@ -101,9 +92,7 @@ class TestDirectComparison:
|
|||
|
||||
def test_extra_invocation_in_candidate(self):
|
||||
"""Test detecting extra invocation in candidate."""
|
||||
original = {
|
||||
"1": {"result_json": '{"value": 42}', "error_json": None},
|
||||
}
|
||||
original = {"1": {"result_json": '{"value": 42}', "error_json": None}}
|
||||
candidate = {
|
||||
"1": {"result_json": '{"value": 42}', "error_json": None},
|
||||
"2": {"result_json": '{"value": 100}', "error_json": None}, # Extra
|
||||
|
|
@ -116,11 +105,9 @@ class TestDirectComparison:
|
|||
|
||||
def test_exception_differences(self):
|
||||
"""Test detecting exception differences."""
|
||||
original = {
|
||||
"1": {"result_json": None, "error_json": '{"type": "NullPointerException"}'},
|
||||
}
|
||||
original = {"1": {"result_json": None, "error_json": '{"type": "NullPointerException"}'}}
|
||||
candidate = {
|
||||
"1": {"result_json": '{"value": 42}', "error_json": None}, # No exception
|
||||
"1": {"result_json": '{"value": 42}', "error_json": None} # No exception
|
||||
}
|
||||
|
||||
equivalent, diffs = compare_invocations_directly(original, candidate)
|
||||
|
|
@ -176,12 +163,8 @@ class TestNumericValueEquality:
|
|||
|
||||
def test_numeric_comparison_in_direct_invocation(self):
|
||||
"""Test that compare_invocations_directly uses numeric-aware comparison."""
|
||||
original = {
|
||||
"1": {"result_json": "0", "error_json": None},
|
||||
}
|
||||
candidate = {
|
||||
"1": {"result_json": "0.0", "error_json": None},
|
||||
}
|
||||
original = {"1": {"result_json": "0", "error_json": None}}
|
||||
candidate = {"1": {"result_json": "0.0", "error_json": None}}
|
||||
|
||||
equivalent, diffs = compare_invocations_directly(original, candidate)
|
||||
assert equivalent is True
|
||||
|
|
@ -189,12 +172,8 @@ class TestNumericValueEquality:
|
|||
|
||||
def test_integer_long_mismatch_resolved(self):
|
||||
"""Test that Integer(42) vs Long(42) serialized differently are still equal."""
|
||||
original = {
|
||||
"1": {"result_json": "42", "error_json": None},
|
||||
}
|
||||
candidate = {
|
||||
"1": {"result_json": "42.0", "error_json": None},
|
||||
}
|
||||
original = {"1": {"result_json": "42", "error_json": None}}
|
||||
candidate = {"1": {"result_json": "42.0", "error_json": None}}
|
||||
|
||||
equivalent, diffs = compare_invocations_directly(original, candidate)
|
||||
assert equivalent is True
|
||||
|
|
@ -263,46 +242,30 @@ class TestNumericValueEquality:
|
|||
|
||||
def test_boolean_invocation_comparison(self):
|
||||
"""Test boolean return values in full invocation comparison."""
|
||||
original = {
|
||||
"1": {"result_json": "true", "error_json": None},
|
||||
}
|
||||
candidate = {
|
||||
"1": {"result_json": "true", "error_json": None},
|
||||
}
|
||||
original = {"1": {"result_json": "true", "error_json": None}}
|
||||
candidate = {"1": {"result_json": "true", "error_json": None}}
|
||||
equivalent, diffs = compare_invocations_directly(original, candidate)
|
||||
assert equivalent is True
|
||||
|
||||
def test_boolean_mismatch_invocation_comparison(self):
|
||||
"""Test boolean mismatch is correctly detected."""
|
||||
original = {
|
||||
"1": {"result_json": "true", "error_json": None},
|
||||
}
|
||||
candidate = {
|
||||
"1": {"result_json": "false", "error_json": None},
|
||||
}
|
||||
original = {"1": {"result_json": "true", "error_json": None}}
|
||||
candidate = {"1": {"result_json": "false", "error_json": None}}
|
||||
equivalent, diffs = compare_invocations_directly(original, candidate)
|
||||
assert equivalent is False
|
||||
assert len(diffs) == 1
|
||||
|
||||
def test_array_invocation_comparison(self):
|
||||
"""Test array return values in full invocation comparison."""
|
||||
original = {
|
||||
"1": {"result_json": "[1, 2, 3]", "error_json": None},
|
||||
}
|
||||
candidate = {
|
||||
"1": {"result_json": "[1, 2, 3]", "error_json": None},
|
||||
}
|
||||
original = {"1": {"result_json": "[1, 2, 3]", "error_json": None}}
|
||||
candidate = {"1": {"result_json": "[1, 2, 3]", "error_json": None}}
|
||||
equivalent, diffs = compare_invocations_directly(original, candidate)
|
||||
assert equivalent is True
|
||||
|
||||
def test_array_mismatch_invocation_comparison(self):
|
||||
"""Test array mismatch is correctly detected."""
|
||||
original = {
|
||||
"1": {"result_json": "[1, 2, 3]", "error_json": None},
|
||||
}
|
||||
candidate = {
|
||||
"1": {"result_json": "[1, 2, 4]", "error_json": None},
|
||||
}
|
||||
original = {"1": {"result_json": "[1, 2, 3]", "error_json": None}}
|
||||
candidate = {"1": {"result_json": "[1, 2, 4]", "error_json": None}}
|
||||
equivalent, diffs = compare_invocations_directly(original, candidate)
|
||||
assert equivalent is False
|
||||
assert len(diffs) == 1
|
||||
|
|
@ -382,35 +345,25 @@ class TestComparisonWithRealData:
|
|||
|
||||
def test_string_result_comparison(self):
|
||||
"""Test comparing string results."""
|
||||
original = {
|
||||
"1": {"result_json": '"Hello World"', "error_json": None},
|
||||
}
|
||||
candidate = {
|
||||
"1": {"result_json": '"Hello World"', "error_json": None},
|
||||
}
|
||||
original = {"1": {"result_json": '"Hello World"', "error_json": None}}
|
||||
candidate = {"1": {"result_json": '"Hello World"', "error_json": None}}
|
||||
|
||||
equivalent, diffs = compare_invocations_directly(original, candidate)
|
||||
assert equivalent is True
|
||||
|
||||
def test_array_result_comparison(self):
|
||||
"""Test comparing array results."""
|
||||
original = {
|
||||
"1": {"result_json": "[1, 2, 3, 4, 5]", "error_json": None},
|
||||
}
|
||||
candidate = {
|
||||
"1": {"result_json": "[1, 2, 3, 4, 5]", "error_json": None},
|
||||
}
|
||||
original = {"1": {"result_json": "[1, 2, 3, 4, 5]", "error_json": None}}
|
||||
candidate = {"1": {"result_json": "[1, 2, 3, 4, 5]", "error_json": None}}
|
||||
|
||||
equivalent, diffs = compare_invocations_directly(original, candidate)
|
||||
assert equivalent is True
|
||||
|
||||
def test_array_order_matters(self):
|
||||
"""Test that array order matters for comparison."""
|
||||
original = {
|
||||
"1": {"result_json": "[1, 2, 3]", "error_json": None},
|
||||
}
|
||||
original = {"1": {"result_json": "[1, 2, 3]", "error_json": None}}
|
||||
candidate = {
|
||||
"1": {"result_json": "[3, 2, 1]", "error_json": None}, # Different order
|
||||
"1": {"result_json": "[3, 2, 1]", "error_json": None} # Different order
|
||||
}
|
||||
|
||||
equivalent, diffs = compare_invocations_directly(original, candidate)
|
||||
|
|
@ -418,24 +371,16 @@ class TestComparisonWithRealData:
|
|||
|
||||
def test_object_result_comparison(self):
|
||||
"""Test comparing object results."""
|
||||
original = {
|
||||
"1": {"result_json": '{"name": "John", "age": 30}', "error_json": None},
|
||||
}
|
||||
candidate = {
|
||||
"1": {"result_json": '{"name": "John", "age": 30}', "error_json": None},
|
||||
}
|
||||
original = {"1": {"result_json": '{"name": "John", "age": 30}', "error_json": None}}
|
||||
candidate = {"1": {"result_json": '{"name": "John", "age": 30}', "error_json": None}}
|
||||
|
||||
equivalent, diffs = compare_invocations_directly(original, candidate)
|
||||
assert equivalent is True
|
||||
|
||||
def test_null_result(self):
|
||||
"""Test comparing null results."""
|
||||
original = {
|
||||
"1": {"result_json": "null", "error_json": None},
|
||||
}
|
||||
candidate = {
|
||||
"1": {"result_json": "null", "error_json": None},
|
||||
}
|
||||
original = {"1": {"result_json": "null", "error_json": None}}
|
||||
candidate = {"1": {"result_json": "null", "error_json": None}}
|
||||
|
||||
equivalent, diffs = compare_invocations_directly(original, candidate)
|
||||
assert equivalent is True
|
||||
|
|
@ -462,11 +407,9 @@ class TestEdgeCases:
|
|||
|
||||
def test_whitespace_in_json(self):
|
||||
"""Test that whitespace differences in JSON don't cause issues."""
|
||||
original = {
|
||||
"1": {"result_json": '{"a":1,"b":2}', "error_json": None},
|
||||
}
|
||||
original = {"1": {"result_json": '{"a":1,"b":2}', "error_json": None}}
|
||||
candidate = {
|
||||
"1": {"result_json": '{ "a": 1, "b": 2 }', "error_json": None}, # With spaces
|
||||
"1": {"result_json": '{ "a": 1, "b": 2 }', "error_json": None} # With spaces
|
||||
}
|
||||
|
||||
# Note: Direct string comparison will see these as different
|
||||
|
|
@ -486,12 +429,8 @@ class TestEdgeCases:
|
|||
|
||||
def test_unicode_in_results(self):
|
||||
"""Test handling unicode in results."""
|
||||
original = {
|
||||
"1": {"result_json": '"Hello 世界 🌍"', "error_json": None},
|
||||
}
|
||||
candidate = {
|
||||
"1": {"result_json": '"Hello 世界 🌍"', "error_json": None},
|
||||
}
|
||||
original = {"1": {"result_json": '"Hello 世界 🌍"', "error_json": None}}
|
||||
candidate = {"1": {"result_json": '"Hello 世界 🌍"', "error_json": None}}
|
||||
|
||||
equivalent, diffs = compare_invocations_directly(original, candidate)
|
||||
assert equivalent is True
|
||||
|
|
@ -499,12 +438,8 @@ class TestEdgeCases:
|
|||
def test_deeply_nested_objects(self):
|
||||
"""Test handling deeply nested objects."""
|
||||
nested = '{"a": {"b": {"c": {"d": {"e": 1}}}}}'
|
||||
original = {
|
||||
"1": {"result_json": nested, "error_json": None},
|
||||
}
|
||||
candidate = {
|
||||
"1": {"result_json": nested, "error_json": None},
|
||||
}
|
||||
original = {"1": {"result_json": nested, "error_json": None}}
|
||||
candidate = {"1": {"result_json": nested, "error_json": None}}
|
||||
|
||||
equivalent, diffs = compare_invocations_directly(original, candidate)
|
||||
assert equivalent is True
|
||||
|
|
@ -573,9 +508,7 @@ class TestTestResultsTableSchema:
|
|||
|
||||
return _create
|
||||
|
||||
def test_comparator_reads_test_results_table_identical(
|
||||
self, tmp_path: Path, create_test_results_db
|
||||
):
|
||||
def test_comparator_reads_test_results_table_identical(self, tmp_path: Path, create_test_results_db):
|
||||
"""Test that Comparator correctly reads test_results table with identical results."""
|
||||
original_path = tmp_path / "original.db"
|
||||
candidate_path = tmp_path / "candidate.db"
|
||||
|
|
@ -607,9 +540,7 @@ class TestTestResultsTableSchema:
|
|||
assert equivalent is True
|
||||
assert len(diffs) == 0
|
||||
|
||||
def test_comparator_reads_test_results_table_different_values(
|
||||
self, tmp_path: Path, create_test_results_db
|
||||
):
|
||||
def test_comparator_reads_test_results_table_different_values(self, tmp_path: Path, create_test_results_db):
|
||||
"""Test that Comparator detects different return values from test_results table."""
|
||||
original_path = tmp_path / "original.db"
|
||||
candidate_path = tmp_path / "candidate.db"
|
||||
|
|
@ -621,7 +552,7 @@ class TestTestResultsTableSchema:
|
|||
"loop_index": 1,
|
||||
"iteration_id": "1_0",
|
||||
"return_value": KRYO_STR_OLLEH,
|
||||
},
|
||||
}
|
||||
]
|
||||
|
||||
candidate_results = [
|
||||
|
|
@ -631,7 +562,7 @@ class TestTestResultsTableSchema:
|
|||
"loop_index": 1,
|
||||
"iteration_id": "1_0",
|
||||
"return_value": KRYO_STR_WRONG, # Different result
|
||||
},
|
||||
}
|
||||
]
|
||||
|
||||
create_test_results_db(original_path, original_results)
|
||||
|
|
@ -644,9 +575,7 @@ class TestTestResultsTableSchema:
|
|||
assert len(diffs) == 1
|
||||
assert diffs[0].scope == TestDiffScope.RETURN_VALUE
|
||||
|
||||
def test_comparator_handles_multiple_loop_iterations(
|
||||
self, tmp_path: Path, create_test_results_db
|
||||
):
|
||||
def test_comparator_handles_multiple_loop_iterations(self, tmp_path: Path, create_test_results_db):
|
||||
"""Test that Comparator correctly handles multiple loop iterations."""
|
||||
original_path = tmp_path / "original.db"
|
||||
candidate_path = tmp_path / "candidate.db"
|
||||
|
|
@ -676,9 +605,7 @@ class TestTestResultsTableSchema:
|
|||
assert equivalent is True
|
||||
assert len(diffs) == 0
|
||||
|
||||
def test_comparator_iteration_id_parsing(
|
||||
self, tmp_path: Path, create_test_results_db
|
||||
):
|
||||
def test_comparator_iteration_id_parsing(self, tmp_path: Path, create_test_results_db):
|
||||
"""Test that Comparator correctly parses iteration_id format 'iter_testIteration'."""
|
||||
original_path = tmp_path / "original.db"
|
||||
candidate_path = tmp_path / "candidate.db"
|
||||
|
|
@ -711,32 +638,18 @@ class TestTestResultsTableSchema:
|
|||
assert equivalent is True
|
||||
assert len(diffs) == 0
|
||||
|
||||
def test_comparator_missing_result_in_candidate(
|
||||
self, tmp_path: Path, create_test_results_db
|
||||
):
|
||||
def test_comparator_missing_result_in_candidate(self, tmp_path: Path, create_test_results_db):
|
||||
"""Test that Comparator detects missing results in candidate."""
|
||||
original_path = tmp_path / "original.db"
|
||||
candidate_path = tmp_path / "candidate.db"
|
||||
|
||||
original_results = [
|
||||
{
|
||||
"loop_index": 1,
|
||||
"iteration_id": "1_0",
|
||||
"return_value": KRYO_INT_1,
|
||||
},
|
||||
{
|
||||
"loop_index": 1,
|
||||
"iteration_id": "2_0",
|
||||
"return_value": KRYO_INT_2,
|
||||
},
|
||||
{"loop_index": 1, "iteration_id": "1_0", "return_value": KRYO_INT_1},
|
||||
{"loop_index": 1, "iteration_id": "2_0", "return_value": KRYO_INT_2},
|
||||
]
|
||||
|
||||
candidate_results = [
|
||||
{
|
||||
"loop_index": 1,
|
||||
"iteration_id": "1_0",
|
||||
"return_value": KRYO_INT_1,
|
||||
},
|
||||
{"loop_index": 1, "iteration_id": "1_0", "return_value": KRYO_INT_1}
|
||||
# Missing second iteration
|
||||
]
|
||||
|
||||
|
|
@ -779,12 +692,8 @@ class TestComparatorEdgeCases:
|
|||
For truly different values, the difference must exceed the epsilon threshold.
|
||||
"""
|
||||
# These values differ by ~3e-10, which is within epsilon tolerance (1e-9)
|
||||
original = {
|
||||
"1": {"result_json": "3.14159", "error_json": None},
|
||||
}
|
||||
candidate = {
|
||||
"1": {"result_json": "3.141590001", "error_json": None},
|
||||
}
|
||||
original = {"1": {"result_json": "3.14159", "error_json": None}}
|
||||
candidate = {"1": {"result_json": "3.141590001", "error_json": None}}
|
||||
|
||||
equivalent, diffs = compare_invocations_directly(original, candidate)
|
||||
assert equivalent is True # Within epsilon tolerance
|
||||
|
|
@ -792,11 +701,9 @@ class TestComparatorEdgeCases:
|
|||
|
||||
def test_float_values_significantly_different(self):
|
||||
"""Float strings outside epsilon tolerance should be detected as different."""
|
||||
original = {
|
||||
"1": {"result_json": "3.14159", "error_json": None},
|
||||
}
|
||||
original = {"1": {"result_json": "3.14159", "error_json": None}}
|
||||
candidate = {
|
||||
"1": {"result_json": "3.14160", "error_json": None}, # Differs by ~1e-5
|
||||
"1": {"result_json": "3.14160", "error_json": None} # Differs by ~1e-5
|
||||
}
|
||||
|
||||
equivalent, diffs = compare_invocations_directly(original, candidate)
|
||||
|
|
@ -806,12 +713,8 @@ class TestComparatorEdgeCases:
|
|||
|
||||
def test_nan_string_comparison(self):
|
||||
"""NaN as a string return value should be comparable."""
|
||||
original = {
|
||||
"1": {"result_json": "NaN", "error_json": None},
|
||||
}
|
||||
candidate = {
|
||||
"1": {"result_json": "NaN", "error_json": None},
|
||||
}
|
||||
original = {"1": {"result_json": "NaN", "error_json": None}}
|
||||
candidate = {"1": {"result_json": "NaN", "error_json": None}}
|
||||
|
||||
equivalent, diffs = compare_invocations_directly(original, candidate)
|
||||
assert equivalent is True
|
||||
|
|
@ -819,12 +722,8 @@ class TestComparatorEdgeCases:
|
|||
|
||||
def test_nan_vs_number(self):
|
||||
"""NaN vs a normal number should be detected as different."""
|
||||
original = {
|
||||
"1": {"result_json": "NaN", "error_json": None},
|
||||
}
|
||||
candidate = {
|
||||
"1": {"result_json": "0.0", "error_json": None},
|
||||
}
|
||||
original = {"1": {"result_json": "NaN", "error_json": None}}
|
||||
candidate = {"1": {"result_json": "0.0", "error_json": None}}
|
||||
|
||||
equivalent, diffs = compare_invocations_directly(original, candidate)
|
||||
assert equivalent is False
|
||||
|
|
@ -832,12 +731,8 @@ class TestComparatorEdgeCases:
|
|||
|
||||
def test_infinity_string_comparison(self):
|
||||
"""Infinity as a string return value should be comparable."""
|
||||
original = {
|
||||
"1": {"result_json": "Infinity", "error_json": None},
|
||||
}
|
||||
candidate = {
|
||||
"1": {"result_json": "Infinity", "error_json": None},
|
||||
}
|
||||
original = {"1": {"result_json": "Infinity", "error_json": None}}
|
||||
candidate = {"1": {"result_json": "Infinity", "error_json": None}}
|
||||
|
||||
equivalent, diffs = compare_invocations_directly(original, candidate)
|
||||
assert equivalent is True
|
||||
|
|
@ -845,12 +740,8 @@ class TestComparatorEdgeCases:
|
|||
|
||||
def test_negative_infinity(self):
|
||||
"""-Infinity as a string return value should be comparable."""
|
||||
original = {
|
||||
"1": {"result_json": "-Infinity", "error_json": None},
|
||||
}
|
||||
candidate = {
|
||||
"1": {"result_json": "-Infinity", "error_json": None},
|
||||
}
|
||||
original = {"1": {"result_json": "-Infinity", "error_json": None}}
|
||||
candidate = {"1": {"result_json": "-Infinity", "error_json": None}}
|
||||
|
||||
equivalent, diffs = compare_invocations_directly(original, candidate)
|
||||
assert equivalent is True
|
||||
|
|
@ -858,12 +749,8 @@ class TestComparatorEdgeCases:
|
|||
|
||||
def test_infinity_vs_negative_infinity(self):
|
||||
"""Infinity and -Infinity should be detected as different."""
|
||||
original = {
|
||||
"1": {"result_json": "Infinity", "error_json": None},
|
||||
}
|
||||
candidate = {
|
||||
"1": {"result_json": "-Infinity", "error_json": None},
|
||||
}
|
||||
original = {"1": {"result_json": "Infinity", "error_json": None}}
|
||||
candidate = {"1": {"result_json": "-Infinity", "error_json": None}}
|
||||
|
||||
equivalent, diffs = compare_invocations_directly(original, candidate)
|
||||
assert equivalent is False
|
||||
|
|
@ -871,12 +758,8 @@ class TestComparatorEdgeCases:
|
|||
|
||||
def test_empty_collection_results(self):
|
||||
"""Empty array '[]' as return value should be comparable."""
|
||||
original = {
|
||||
"1": {"result_json": "[]", "error_json": None},
|
||||
}
|
||||
candidate = {
|
||||
"1": {"result_json": "[]", "error_json": None},
|
||||
}
|
||||
original = {"1": {"result_json": "[]", "error_json": None}}
|
||||
candidate = {"1": {"result_json": "[]", "error_json": None}}
|
||||
|
||||
equivalent, diffs = compare_invocations_directly(original, candidate)
|
||||
assert equivalent is True
|
||||
|
|
@ -884,12 +767,8 @@ class TestComparatorEdgeCases:
|
|||
|
||||
def test_empty_object_results(self):
|
||||
"""Empty object '{}' as return value should be comparable."""
|
||||
original = {
|
||||
"1": {"result_json": "{}", "error_json": None},
|
||||
}
|
||||
candidate = {
|
||||
"1": {"result_json": "{}", "error_json": None},
|
||||
}
|
||||
original = {"1": {"result_json": "{}", "error_json": None}}
|
||||
candidate = {"1": {"result_json": "{}", "error_json": None}}
|
||||
|
||||
equivalent, diffs = compare_invocations_directly(original, candidate)
|
||||
assert equivalent is True
|
||||
|
|
@ -917,12 +796,8 @@ class TestComparatorEdgeCases:
|
|||
1e+17 as floats due to precision limits, making them indistinguishable.
|
||||
This is a known limitation of floating-point comparison for very large integers.
|
||||
"""
|
||||
original = {
|
||||
"1": {"result_json": "99999999999999999", "error_json": None},
|
||||
}
|
||||
candidate = {
|
||||
"1": {"result_json": "99999999999999998", "error_json": None},
|
||||
}
|
||||
original = {"1": {"result_json": "99999999999999999", "error_json": None}}
|
||||
candidate = {"1": {"result_json": "99999999999999998", "error_json": None}}
|
||||
|
||||
equivalent, diffs = compare_invocations_directly(original, candidate)
|
||||
# Due to float precision limits, these are considered equal
|
||||
|
|
@ -931,12 +806,8 @@ class TestComparatorEdgeCases:
|
|||
|
||||
def test_large_number_significantly_different(self):
|
||||
"""Large numbers with significant differences should be detected."""
|
||||
original = {
|
||||
"1": {"result_json": "100000000000000000", "error_json": None},
|
||||
}
|
||||
candidate = {
|
||||
"1": {"result_json": "200000000000000000", "error_json": None},
|
||||
}
|
||||
original = {"1": {"result_json": "100000000000000000", "error_json": None}}
|
||||
candidate = {"1": {"result_json": "200000000000000000", "error_json": None}}
|
||||
|
||||
equivalent, diffs = compare_invocations_directly(original, candidate)
|
||||
assert equivalent is False
|
||||
|
|
@ -944,12 +815,8 @@ class TestComparatorEdgeCases:
|
|||
|
||||
def test_null_vs_empty_string(self):
|
||||
"""'null' and '""' should NOT be equivalent."""
|
||||
original = {
|
||||
"1": {"result_json": "null", "error_json": None},
|
||||
}
|
||||
candidate = {
|
||||
"1": {"result_json": '""', "error_json": None},
|
||||
}
|
||||
original = {"1": {"result_json": "null", "error_json": None}}
|
||||
candidate = {"1": {"result_json": '""', "error_json": None}}
|
||||
|
||||
equivalent, diffs = compare_invocations_directly(original, candidate)
|
||||
assert equivalent is False
|
||||
|
|
@ -958,10 +825,7 @@ class TestComparatorEdgeCases:
|
|||
|
||||
def test_boolean_string_comparison(self):
|
||||
"""Boolean strings 'true'/'false' should compare correctly."""
|
||||
original = {
|
||||
"1": {"result_json": "true", "error_json": None},
|
||||
"2": {"result_json": "false", "error_json": None},
|
||||
}
|
||||
original = {"1": {"result_json": "true", "error_json": None}, "2": {"result_json": "false", "error_json": None}}
|
||||
candidate = {
|
||||
"1": {"result_json": "true", "error_json": None},
|
||||
"2": {"result_json": "false", "error_json": None},
|
||||
|
|
@ -972,12 +836,8 @@ class TestComparatorEdgeCases:
|
|||
|
||||
def test_boolean_true_vs_false(self):
|
||||
"""'true' vs 'false' should be detected as different."""
|
||||
original = {
|
||||
"1": {"result_json": "true", "error_json": None},
|
||||
}
|
||||
candidate = {
|
||||
"1": {"result_json": "false", "error_json": None},
|
||||
}
|
||||
original = {"1": {"result_json": "true", "error_json": None}}
|
||||
candidate = {"1": {"result_json": "false", "error_json": None}}
|
||||
|
||||
equivalent, diffs = compare_invocations_directly(original, candidate)
|
||||
assert equivalent is False
|
||||
|
|
@ -1024,12 +884,8 @@ class TestComparatorErrorHandling:
|
|||
|
||||
def test_compare_with_none_return_values_direct(self):
|
||||
"""Rows where result_json is None should be handled in direct comparison."""
|
||||
original = {
|
||||
"1": {"result_json": None, "error_json": None},
|
||||
}
|
||||
candidate = {
|
||||
"1": {"result_json": None, "error_json": None},
|
||||
}
|
||||
original = {"1": {"result_json": None, "error_json": None}}
|
||||
candidate = {"1": {"result_json": None, "error_json": None}}
|
||||
|
||||
equivalent, diffs = compare_invocations_directly(original, candidate)
|
||||
assert equivalent is True
|
||||
|
|
@ -1037,12 +893,8 @@ class TestComparatorErrorHandling:
|
|||
|
||||
def test_compare_one_none_one_value_direct(self):
|
||||
"""One None result vs a real value should detect the difference."""
|
||||
original = {
|
||||
"1": {"result_json": None, "error_json": None},
|
||||
}
|
||||
candidate = {
|
||||
"1": {"result_json": "42", "error_json": None},
|
||||
}
|
||||
original = {"1": {"result_json": None, "error_json": None}}
|
||||
candidate = {"1": {"result_json": "42", "error_json": None}}
|
||||
|
||||
equivalent, diffs = compare_invocations_directly(original, candidate)
|
||||
assert equivalent is False
|
||||
|
|
@ -1050,12 +902,8 @@ class TestComparatorErrorHandling:
|
|||
|
||||
def test_compare_both_errors_identical(self):
|
||||
"""Identical errors in both original and candidate should be equivalent."""
|
||||
original = {
|
||||
"1": {"result_json": None, "error_json": '{"type": "IOException", "message": "file not found"}'},
|
||||
}
|
||||
candidate = {
|
||||
"1": {"result_json": None, "error_json": '{"type": "IOException", "message": "file not found"}'},
|
||||
}
|
||||
original = {"1": {"result_json": None, "error_json": '{"type": "IOException", "message": "file not found"}'}}
|
||||
candidate = {"1": {"result_json": None, "error_json": '{"type": "IOException", "message": "file not found"}'}}
|
||||
|
||||
equivalent, diffs = compare_invocations_directly(original, candidate)
|
||||
assert equivalent is True
|
||||
|
|
@ -1063,12 +911,8 @@ class TestComparatorErrorHandling:
|
|||
|
||||
def test_compare_different_error_types(self):
|
||||
"""Different error types should be detected."""
|
||||
original = {
|
||||
"1": {"result_json": None, "error_json": '{"type": "IOException"}'},
|
||||
}
|
||||
candidate = {
|
||||
"1": {"result_json": None, "error_json": '{"type": "NullPointerException"}'},
|
||||
}
|
||||
original = {"1": {"result_json": None, "error_json": '{"type": "IOException"}'}}
|
||||
candidate = {"1": {"result_json": None, "error_json": '{"type": "NullPointerException"}'}}
|
||||
|
||||
equivalent, diffs = compare_invocations_directly(original, candidate)
|
||||
assert equivalent is False
|
||||
|
|
@ -1083,9 +927,7 @@ class TestComparatorJavaEdgeCases(TestTestResultsTableSchema):
|
|||
Extends TestTestResultsTableSchema to reuse the create_test_results_db fixture.
|
||||
"""
|
||||
|
||||
def test_comparator_float_epsilon_tolerance(
|
||||
self, tmp_path: Path, create_test_results_db
|
||||
):
|
||||
def test_comparator_float_epsilon_tolerance(self, tmp_path: Path, create_test_results_db):
|
||||
"""Values differing by less than EPSILON (1e-9) should be treated as equivalent.
|
||||
|
||||
The Java Comparator uses EPSILON=1e-9 for float comparison.
|
||||
|
|
@ -1102,7 +944,7 @@ class TestComparatorJavaEdgeCases(TestTestResultsTableSchema):
|
|||
"loop_index": 1,
|
||||
"iteration_id": "1_0",
|
||||
"return_value": KRYO_DOUBLE_1_0000000001,
|
||||
},
|
||||
}
|
||||
]
|
||||
|
||||
candidate_results = [
|
||||
|
|
@ -1112,7 +954,7 @@ class TestComparatorJavaEdgeCases(TestTestResultsTableSchema):
|
|||
"loop_index": 1,
|
||||
"iteration_id": "1_0",
|
||||
"return_value": KRYO_DOUBLE_1_0000000002,
|
||||
},
|
||||
}
|
||||
]
|
||||
|
||||
create_test_results_db(original_path, original_results)
|
||||
|
|
@ -1124,9 +966,7 @@ class TestComparatorJavaEdgeCases(TestTestResultsTableSchema):
|
|||
assert equivalent is True
|
||||
assert len(diffs) == 0
|
||||
|
||||
def test_comparator_nan_handling(
|
||||
self, tmp_path: Path, create_test_results_db
|
||||
):
|
||||
def test_comparator_nan_handling(self, tmp_path: Path, create_test_results_db):
|
||||
"""Java Comparator should handle NaN return values."""
|
||||
original_path = tmp_path / "original.db"
|
||||
candidate_path = tmp_path / "candidate.db"
|
||||
|
|
@ -1138,7 +978,7 @@ class TestComparatorJavaEdgeCases(TestTestResultsTableSchema):
|
|||
"loop_index": 1,
|
||||
"iteration_id": "1_0",
|
||||
"return_value": KRYO_NAN,
|
||||
},
|
||||
}
|
||||
]
|
||||
|
||||
create_test_results_db(original_path, results)
|
||||
|
|
@ -1150,9 +990,7 @@ class TestComparatorJavaEdgeCases(TestTestResultsTableSchema):
|
|||
assert equivalent is True
|
||||
assert len(diffs) == 0
|
||||
|
||||
def test_comparator_empty_table(
|
||||
self, tmp_path: Path, create_test_results_db
|
||||
):
|
||||
def test_comparator_empty_table(self, tmp_path: Path, create_test_results_db):
|
||||
"""Empty test_results tables should result in equivalent=False (vacuous equivalence guard)."""
|
||||
original_path = tmp_path / "original.db"
|
||||
candidate_path = tmp_path / "candidate.db"
|
||||
|
|
@ -1167,9 +1005,7 @@ class TestComparatorJavaEdgeCases(TestTestResultsTableSchema):
|
|||
assert equivalent is False
|
||||
assert len(diffs) == 0
|
||||
|
||||
def test_comparator_infinity_handling(
|
||||
self, tmp_path: Path, create_test_results_db
|
||||
):
|
||||
def test_comparator_infinity_handling(self, tmp_path: Path, create_test_results_db):
|
||||
"""Java Comparator should handle Infinity return values correctly."""
|
||||
original_path = tmp_path / "original.db"
|
||||
candidate_path = tmp_path / "candidate.db"
|
||||
|
|
|
|||
|
|
@ -7,25 +7,12 @@ fail with an error to maintain strict correctness guarantees.
|
|||
|
||||
import inspect
|
||||
import sqlite3
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from codeflash.languages.java.comparator import (
|
||||
compare_test_results as java_compare_test_results,
|
||||
)
|
||||
from codeflash.models.models import (
|
||||
FunctionTestInvocation,
|
||||
InvocationId,
|
||||
TestDiffScope,
|
||||
TestResults,
|
||||
TestType,
|
||||
VerificationType,
|
||||
)
|
||||
from codeflash.verification.equivalence import (
|
||||
compare_test_results as python_compare_test_results,
|
||||
)
|
||||
from codeflash.languages.java.comparator import compare_test_results as java_compare_test_results
|
||||
from codeflash.models.models import FunctionTestInvocation, InvocationId, TestResults, TestType, VerificationType
|
||||
|
||||
|
||||
def make_invocation(
|
||||
|
|
@ -142,7 +129,7 @@ class TestSqlitePathSelection:
|
|||
"loop_index": 1,
|
||||
"iteration_id": "1_0",
|
||||
"return_value": '{"value": 42}',
|
||||
},
|
||||
}
|
||||
]
|
||||
create_test_results_db(original_path, results)
|
||||
create_test_results_db(candidate_path, results)
|
||||
|
|
|
|||
|
|
@ -3,14 +3,9 @@
|
|||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from codeflash.languages.base import FunctionInfo
|
||||
from codeflash.languages.java.concurrency_analyzer import JavaConcurrencyAnalyzer, analyze_function_concurrency
|
||||
from codeflash.languages.language_enum import Language
|
||||
from codeflash.languages.java.concurrency_analyzer import (
|
||||
JavaConcurrencyAnalyzer,
|
||||
analyze_function_concurrency,
|
||||
)
|
||||
|
||||
|
||||
class TestCompletableFutureDetection:
|
||||
|
|
|
|||
|
|
@ -2,25 +2,19 @@
|
|||
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from codeflash.languages.base import FunctionFilterCriteria, Language, ParentInfo
|
||||
from codeflash.languages.base import FunctionFilterCriteria, Language
|
||||
from codeflash.languages.java.context import (
|
||||
TypeSkeleton,
|
||||
_extract_public_method_signatures,
|
||||
_extract_type_skeleton,
|
||||
_format_skeleton_for_context,
|
||||
extract_class_context,
|
||||
extract_code_context,
|
||||
extract_function_source,
|
||||
extract_read_only_context,
|
||||
find_helper_functions,
|
||||
get_java_imported_type_skeletons,
|
||||
_extract_public_method_signatures,
|
||||
_format_skeleton_for_context,
|
||||
)
|
||||
from codeflash.languages.java.discovery import discover_functions_from_source
|
||||
from codeflash.languages.java.import_resolver import JavaImportResolver, ResolvedImport
|
||||
from codeflash.languages.java.parser import JavaImportInfo, get_java_analyzer
|
||||
|
||||
from codeflash.languages.java.parser import get_java_analyzer
|
||||
|
||||
# Filter criteria that includes void methods
|
||||
NO_RETURN_FILTER = FunctionFilterCriteria(require_return=False)
|
||||
|
|
@ -1785,12 +1779,15 @@ class TestExtractCodeContextEdgeCases:
|
|||
def test_unicode_in_source(self, tmp_path: Path):
|
||||
"""Test context extraction for method with unicode characters."""
|
||||
java_file = tmp_path / "Unicode.java"
|
||||
java_file.write_text("""public class Unicode {
|
||||
java_file.write_text(
|
||||
"""public class Unicode {
|
||||
public String greet() {
|
||||
return "こんにちは世界";
|
||||
}
|
||||
}
|
||||
""", encoding="utf-8")
|
||||
""",
|
||||
encoding="utf-8",
|
||||
)
|
||||
functions = discover_functions_from_source(java_file.read_text(encoding="utf-8"), file_path=java_file)
|
||||
assert len(functions) == 1
|
||||
|
||||
|
|
|
|||
|
|
@ -4,14 +4,7 @@ import os
|
|||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from codeflash.languages.java.formatter import (
|
||||
JavaFormatter,
|
||||
format_java_code,
|
||||
format_java_file,
|
||||
normalize_java_code,
|
||||
)
|
||||
from codeflash.languages.java.formatter import JavaFormatter, format_java_code, format_java_file, normalize_java_code
|
||||
from codeflash.setup.detector import _detect_formatter
|
||||
|
||||
|
||||
|
|
@ -201,12 +194,12 @@ class TestNormalizationEdgeCases:
|
|||
|
||||
def test_string_with_comment_chars(self):
|
||||
"""Test string containing comment characters."""
|
||||
source = '''
|
||||
source = """
|
||||
public class Example {
|
||||
String s1 = "// not a comment";
|
||||
String s2 = "/* also not */";
|
||||
}
|
||||
'''
|
||||
"""
|
||||
normalized = normalize_java_code(source)
|
||||
# Note: current implementation incorrectly removes content in s2 string
|
||||
expected = 'public class Example {\nString s1 = "// not a comment";\nString s2 = "";\n}'
|
||||
|
|
@ -273,10 +266,7 @@ class TestDetectJavaFormatter:
|
|||
|
||||
def test_detect_formatter_returns_empty_when_java_not_available(self, tmp_path: Path):
|
||||
"""Detector returns empty list with descriptive message when Java is not found."""
|
||||
with (
|
||||
patch.dict(os.environ, {}, clear=True),
|
||||
patch("shutil.which", return_value=None),
|
||||
):
|
||||
with patch.dict(os.environ, {}, clear=True), patch("shutil.which", return_value=None):
|
||||
cmds, description = _detect_formatter(tmp_path, "java")
|
||||
|
||||
assert cmds == []
|
||||
|
|
|
|||
|
|
@ -2,14 +2,7 @@
|
|||
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from codeflash.languages.java.import_resolver import (
|
||||
JavaImportResolver,
|
||||
ResolvedImport,
|
||||
find_helper_files,
|
||||
resolve_imports_for_file,
|
||||
)
|
||||
from codeflash.languages.java.import_resolver import JavaImportResolver, ResolvedImport, find_helper_files
|
||||
from codeflash.languages.java.parser import JavaImportInfo
|
||||
|
||||
|
||||
|
|
@ -21,11 +14,7 @@ class TestJavaImportResolver:
|
|||
resolver = JavaImportResolver(tmp_path)
|
||||
|
||||
import_info = JavaImportInfo(
|
||||
import_path="java.util.List",
|
||||
is_static=False,
|
||||
is_wildcard=False,
|
||||
start_line=1,
|
||||
end_line=1,
|
||||
import_path="java.util.List", is_static=False, is_wildcard=False, start_line=1, end_line=1
|
||||
)
|
||||
|
||||
resolved = resolver.resolve_import(import_info)
|
||||
|
|
@ -38,11 +27,7 @@ class TestJavaImportResolver:
|
|||
resolver = JavaImportResolver(tmp_path)
|
||||
|
||||
import_info = JavaImportInfo(
|
||||
import_path="javax.annotation.Nullable",
|
||||
is_static=False,
|
||||
is_wildcard=False,
|
||||
start_line=1,
|
||||
end_line=1,
|
||||
import_path="javax.annotation.Nullable", is_static=False, is_wildcard=False, start_line=1, end_line=1
|
||||
)
|
||||
|
||||
resolved = resolver.resolve_import(import_info)
|
||||
|
|
@ -53,11 +38,7 @@ class TestJavaImportResolver:
|
|||
resolver = JavaImportResolver(tmp_path)
|
||||
|
||||
import_info = JavaImportInfo(
|
||||
import_path="org.junit.jupiter.api.Test",
|
||||
is_static=False,
|
||||
is_wildcard=False,
|
||||
start_line=1,
|
||||
end_line=1,
|
||||
import_path="org.junit.jupiter.api.Test", is_static=False, is_wildcard=False, start_line=1, end_line=1
|
||||
)
|
||||
|
||||
resolved = resolver.resolve_import(import_info)
|
||||
|
|
@ -89,11 +70,7 @@ public class StringUtils {
|
|||
resolver = JavaImportResolver(tmp_path)
|
||||
|
||||
import_info = JavaImportInfo(
|
||||
import_path="com.example.utils.StringUtils",
|
||||
is_static=False,
|
||||
is_wildcard=False,
|
||||
start_line=1,
|
||||
end_line=1,
|
||||
import_path="com.example.utils.StringUtils", is_static=False, is_wildcard=False, start_line=1, end_line=1
|
||||
)
|
||||
|
||||
resolved = resolver.resolve_import(import_info)
|
||||
|
|
@ -107,11 +84,7 @@ public class StringUtils {
|
|||
resolver = JavaImportResolver(tmp_path)
|
||||
|
||||
import_info = JavaImportInfo(
|
||||
import_path="java.util",
|
||||
is_static=False,
|
||||
is_wildcard=True,
|
||||
start_line=1,
|
||||
end_line=1,
|
||||
import_path="java.util", is_static=False, is_wildcard=True, start_line=1, end_line=1
|
||||
)
|
||||
|
||||
resolved = resolver.resolve_import(import_info)
|
||||
|
|
@ -123,11 +96,7 @@ public class StringUtils {
|
|||
resolver = JavaImportResolver(tmp_path)
|
||||
|
||||
import_info = JavaImportInfo(
|
||||
import_path="java.lang.Math.PI",
|
||||
is_static=True,
|
||||
is_wildcard=False,
|
||||
start_line=1,
|
||||
end_line=1,
|
||||
import_path="java.lang.Math.PI", is_static=True, is_wildcard=False, start_line=1, end_line=1
|
||||
)
|
||||
|
||||
resolved = resolver.resolve_import(import_info)
|
||||
|
|
@ -286,11 +255,7 @@ class TestResolvedImport:
|
|||
def test_resolved_import_external(self):
|
||||
"""Test ResolvedImport for external dependency."""
|
||||
resolved = ResolvedImport(
|
||||
import_path="java.util.List",
|
||||
file_path=None,
|
||||
is_external=True,
|
||||
is_wildcard=False,
|
||||
class_name="List",
|
||||
import_path="java.util.List", file_path=None, is_external=True, is_wildcard=False, class_name="List"
|
||||
)
|
||||
assert resolved.is_external is True
|
||||
assert resolved.file_path is None
|
||||
|
|
|
|||
|
|
@ -221,9 +221,7 @@ public class StringUtilsTest {
|
|||
return new String(chars);
|
||||
}"""
|
||||
|
||||
optimized = support.replace_function(
|
||||
src_file.read_text(), functions[0], new_code
|
||||
)
|
||||
optimized = support.replace_function(src_file.read_text(), functions[0], new_code)
|
||||
|
||||
assert "Optimized version" in optimized
|
||||
assert "StringUtils" in optimized
|
||||
|
|
|
|||
|
|
@ -100,7 +100,9 @@ class TestFixJavaTestPathsIntegration:
|
|||
|
||||
# Bind the actual methods
|
||||
mock_optimizer._get_java_sources_root = lambda: JavaFunctionOptimizer._get_java_sources_root(mock_optimizer)
|
||||
mock_optimizer._fix_java_test_paths = lambda behavior_source, perf_source, used_paths: JavaFunctionOptimizer._fix_java_test_paths(mock_optimizer, behavior_source, perf_source, used_paths)
|
||||
mock_optimizer._fix_java_test_paths = lambda behavior_source, perf_source, used_paths: (
|
||||
JavaFunctionOptimizer._fix_java_test_paths(mock_optimizer, behavior_source, perf_source, used_paths)
|
||||
)
|
||||
|
||||
return mock_optimizer
|
||||
|
||||
|
|
@ -133,8 +135,14 @@ public class UnpackerTest__perfonlyinstrumented {
|
|||
# The path should be test/src/com/aerospike/client/util/UnpackerTest__perfinstrumented.java
|
||||
# NOT test/src/com/aerospike/test/com/aerospike/client/util/...
|
||||
expected_java_root = tmp_path / "test" / "src"
|
||||
assert behavior_path == expected_java_root / "com" / "aerospike" / "client" / "util" / "UnpackerTest__perfinstrumented.java"
|
||||
assert perf_path == expected_java_root / "com" / "aerospike" / "client" / "util" / "UnpackerTest__perfonlyinstrumented.java"
|
||||
assert (
|
||||
behavior_path
|
||||
== expected_java_root / "com" / "aerospike" / "client" / "util" / "UnpackerTest__perfinstrumented.java"
|
||||
)
|
||||
assert (
|
||||
perf_path
|
||||
== expected_java_root / "com" / "aerospike" / "client" / "util" / "UnpackerTest__perfonlyinstrumented.java"
|
||||
)
|
||||
|
||||
# Verify there's no duplication in the path
|
||||
assert "com/aerospike/test/com" not in str(behavior_path)
|
||||
|
|
@ -169,6 +177,7 @@ public class CalculatorTest__perfonlyinstrumented {
|
|||
assert behavior_path == tests_root / "com" / "example" / "CalculatorTest__perfinstrumented.java"
|
||||
assert perf_path == tests_root / "com" / "example" / "CalculatorTest__perfonlyinstrumented.java"
|
||||
|
||||
|
||||
class TestPathToClassNameWithCustomDirs:
|
||||
"""Tests for _path_to_class_name with custom source directories."""
|
||||
|
||||
|
|
|
|||
|
|
@ -62,14 +62,7 @@ public class Calculator {
|
|||
"targets": [
|
||||
{
|
||||
"className": "com/example/Calculator",
|
||||
"methods": [
|
||||
{
|
||||
"name": "add",
|
||||
"startLine": 4,
|
||||
"endLine": 7,
|
||||
"sourceFile": file_path.as_posix(),
|
||||
}
|
||||
],
|
||||
"methods": [{"name": "add", "startLine": 4, "endLine": 7, "sourceFile": file_path.as_posix()}],
|
||||
}
|
||||
],
|
||||
"lineContents": {
|
||||
|
|
@ -172,18 +165,8 @@ public class Calculator {
|
|||
config = json.loads(config_path.read_text())
|
||||
|
||||
assert config["targets"][0]["methods"] == [
|
||||
{
|
||||
"name": "method1",
|
||||
"startLine": 2,
|
||||
"endLine": 4,
|
||||
"sourceFile": file_path.as_posix(),
|
||||
},
|
||||
{
|
||||
"name": "method2",
|
||||
"startLine": 6,
|
||||
"endLine": 8,
|
||||
"sourceFile": file_path.as_posix(),
|
||||
},
|
||||
{"name": "method1", "startLine": 2, "endLine": 4, "sourceFile": file_path.as_posix()},
|
||||
{"name": "method2", "startLine": 6, "endLine": 8, "sourceFile": file_path.as_posix()},
|
||||
]
|
||||
|
||||
def test_empty_function_list(self):
|
||||
|
|
@ -403,12 +386,7 @@ class TestAgentConfigBoundaryConditions:
|
|||
{
|
||||
"className": "Test",
|
||||
"methods": [
|
||||
{
|
||||
"name": "foo",
|
||||
"startLine": 100,
|
||||
"endLine": 200,
|
||||
"sourceFile": file_path.as_posix(),
|
||||
}
|
||||
{"name": "foo", "startLine": 100, "endLine": 200, "sourceFile": file_path.as_posix()}
|
||||
],
|
||||
}
|
||||
],
|
||||
|
|
@ -450,12 +428,7 @@ class TestAgentConfigBoundaryConditions:
|
|||
{
|
||||
"className": "Test",
|
||||
"methods": [
|
||||
{
|
||||
"name": "foo",
|
||||
"startLine": -5,
|
||||
"endLine": -1,
|
||||
"sourceFile": file_path.as_posix(),
|
||||
}
|
||||
{"name": "foo", "startLine": -5, "endLine": -1, "sourceFile": file_path.as_posix()}
|
||||
],
|
||||
}
|
||||
],
|
||||
|
|
@ -496,9 +469,7 @@ class TestLineProfileResultsParsing:
|
|||
results = JavaLineProfiler.parse_results(profile_file)
|
||||
|
||||
assert results["unit"] == 1e-9
|
||||
assert results["timings"] == {
|
||||
("/tmp/Test.java", 10, "Test.java"): [(10, 100, 5000000), (11, 100, 95000000)]
|
||||
}
|
||||
assert results["timings"] == {("/tmp/Test.java", 10, "Test.java"): [(10, 100, 5000000), (11, 100, 95000000)]}
|
||||
assert results["line_contents"] == {
|
||||
("/tmp/Test.java", 10): "int x = compute();",
|
||||
("/tmp/Test.java", 11): "result = slowOperation(x);",
|
||||
|
|
@ -601,9 +572,7 @@ class TestLineProfileResultsParsing:
|
|||
|
||||
assert results == {
|
||||
"unit": 1e-9,
|
||||
"timings": {
|
||||
("/tmp/Sorter.java", 5, "Sorter.java"): [(5, 10, 2000000), (6, 10, 8000000)]
|
||||
},
|
||||
"timings": {("/tmp/Sorter.java", 5, "Sorter.java"): [(5, 10, 2000000), (6, 10, 8000000)]},
|
||||
"line_contents": {
|
||||
("/tmp/Sorter.java", 5): "int n = arr.length;",
|
||||
("/tmp/Sorter.java", 6): "for (int i = 0; i < n; i++) {",
|
||||
|
|
|
|||
|
|
@ -1,5 +1,4 @@
|
|||
"""Integration tests for Java line profiler with JavaSupport.
|
||||
"""
|
||||
"""Integration tests for Java line profiler with JavaSupport."""
|
||||
|
||||
import json
|
||||
import math
|
||||
|
|
@ -16,12 +15,10 @@ from codeflash.languages.java.support import get_java_support
|
|||
|
||||
|
||||
class TestLineProfilerInstrumentation:
|
||||
"""Integration tests for line profiler instrumentation through JavaSupport.
|
||||
"""
|
||||
"""Integration tests for line profiler instrumentation through JavaSupport."""
|
||||
|
||||
def test_instrument_with_package(self):
|
||||
"""Test instrumentation for a class with a package declaration.
|
||||
"""
|
||||
"""Test instrumentation for a class with a package declaration."""
|
||||
source = """package com.example;
|
||||
|
||||
public class Calculator {
|
||||
|
|
@ -70,14 +67,7 @@ public class Calculator {
|
|||
"targets": [
|
||||
{
|
||||
"className": "com/example/Calculator",
|
||||
"methods": [
|
||||
{
|
||||
"name": "add",
|
||||
"startLine": 4,
|
||||
"endLine": 7,
|
||||
"sourceFile": java_file.as_posix(),
|
||||
}
|
||||
],
|
||||
"methods": [{"name": "add", "startLine": 4, "endLine": 7, "sourceFile": java_file.as_posix()}],
|
||||
}
|
||||
],
|
||||
"lineContents": {
|
||||
|
|
@ -155,12 +145,7 @@ public class Calculator {
|
|||
{
|
||||
"className": "Sorter",
|
||||
"methods": [
|
||||
{
|
||||
"name": "sort",
|
||||
"startLine": 2,
|
||||
"endLine": 14,
|
||||
"sourceFile": java_file.as_posix(),
|
||||
}
|
||||
{"name": "sort", "startLine": 2, "endLine": 14, "sourceFile": java_file.as_posix()}
|
||||
],
|
||||
}
|
||||
],
|
||||
|
|
@ -254,9 +239,7 @@ public class Calculator {
|
|||
|
||||
# Both methods should appear as targets when generated together
|
||||
profiler = JavaLineProfiler(output_file=profile_output)
|
||||
profiler.generate_agent_config(
|
||||
source, java_file, [func_reverse, func_palindrome], config_path
|
||||
)
|
||||
profiler.generate_agent_config(source, java_file, [func_reverse, func_palindrome], config_path)
|
||||
config = json.loads(config_path.read_text(encoding="utf-8"))
|
||||
|
||||
assert config == {
|
||||
|
|
@ -266,12 +249,7 @@ public class Calculator {
|
|||
{
|
||||
"className": "StringProcessor",
|
||||
"methods": [
|
||||
{
|
||||
"name": "reverse",
|
||||
"startLine": 2,
|
||||
"endLine": 14,
|
||||
"sourceFile": java_file.as_posix(),
|
||||
},
|
||||
{"name": "reverse", "startLine": 2, "endLine": 14, "sourceFile": java_file.as_posix()},
|
||||
{
|
||||
"name": "isPalindrome",
|
||||
"startLine": 16,
|
||||
|
|
@ -355,12 +333,7 @@ public class StringUtils {
|
|||
{
|
||||
"className": "org/apache/commons/lang3/StringUtils",
|
||||
"methods": [
|
||||
{
|
||||
"name": "isEmpty",
|
||||
"startLine": 4,
|
||||
"endLine": 6,
|
||||
"sourceFile": java_file.as_posix(),
|
||||
}
|
||||
{"name": "isEmpty", "startLine": 4, "endLine": 6, "sourceFile": java_file.as_posix()}
|
||||
],
|
||||
}
|
||||
],
|
||||
|
|
@ -484,10 +457,7 @@ def run_spin_timer_profiled(tmppath: Path, spin_durations_ns: list[int]) -> dict
|
|||
agent_arg = profiler.build_javaagent_arg(config_path)
|
||||
|
||||
result = subprocess.run(
|
||||
["javac", "--release", "11", str(java_file)],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
cwd=str(tmppath),
|
||||
["javac", "--release", "11", str(java_file)], capture_output=True, text=True, cwd=str(tmppath)
|
||||
)
|
||||
assert result.returncode == 0, f"javac failed: {result.stderr}"
|
||||
|
||||
|
|
@ -512,13 +482,7 @@ class TestSpinTimerProfiling:
|
|||
profiler-reported total time matches the expected sum of all spin durations.
|
||||
"""
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"spin_durations_ns",
|
||||
[
|
||||
[50_000_000, 100_000_000],
|
||||
[30_000_000, 40_000_000, 80_000_000],
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize("spin_durations_ns", [[50_000_000, 100_000_000], [30_000_000, 40_000_000, 80_000_000]])
|
||||
def test_total_time_matches_expected(self, spin_durations_ns):
|
||||
"""Profiler total time should match the sum of all spin durations."""
|
||||
expected_ns = sum(spin_durations_ns)
|
||||
|
|
|
|||
|
|
@ -3,13 +3,8 @@
|
|||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from codeflash.languages.java.discovery import discover_functions_from_source
|
||||
from codeflash.languages.java.test_discovery import (
|
||||
disambiguate_overloads,
|
||||
discover_tests,
|
||||
)
|
||||
from codeflash.languages.java.test_discovery import disambiguate_overloads, discover_tests
|
||||
|
||||
|
||||
class TestOverloadDisambiguation:
|
||||
|
|
@ -109,9 +104,7 @@ public class CalculatorTest {
|
|||
"""When overloaded methods are detected, info log fires."""
|
||||
matched_names = ["Calculator.add", "StringUtils.add"]
|
||||
with caplog.at_level(logging.INFO):
|
||||
result = disambiguate_overloads(
|
||||
matched_names, "testAdd", "some test source code"
|
||||
)
|
||||
result = disambiguate_overloads(matched_names, "testAdd", "some test source code")
|
||||
|
||||
assert result == matched_names
|
||||
info_messages = [r.message for r in caplog.records if r.levelno == logging.INFO]
|
||||
|
|
|
|||
|
|
@ -1,15 +1,6 @@
|
|||
"""Tests for the Java tree-sitter parser utilities."""
|
||||
|
||||
import pytest
|
||||
|
||||
from codeflash.languages.java.parser import (
|
||||
JavaAnalyzer,
|
||||
JavaClassNode,
|
||||
JavaFieldInfo,
|
||||
JavaImportInfo,
|
||||
JavaMethodNode,
|
||||
get_java_analyzer,
|
||||
)
|
||||
from codeflash.languages.java.parser import JavaAnalyzer, get_java_analyzer
|
||||
|
||||
|
||||
class TestJavaAnalyzerBasic:
|
||||
|
|
|
|||
|
|
@ -118,12 +118,7 @@ def java_project(tmp_path: Path):
|
|||
|
||||
def _make_optimizer(project_root: Path, test_dir: Path, function_name: str, src_file: Path) -> tuple:
|
||||
"""Create an Optimizer and FunctionOptimizer for the given function."""
|
||||
fto = FunctionToOptimize(
|
||||
function_name=function_name,
|
||||
file_path=src_file,
|
||||
parents=[],
|
||||
language="java",
|
||||
)
|
||||
fto = FunctionToOptimize(function_name=function_name, file_path=src_file, parents=[], language="java")
|
||||
opt = Optimizer(
|
||||
Namespace(
|
||||
project_root=project_root,
|
||||
|
|
@ -493,11 +488,7 @@ public class PreciseWaiterTest {
|
|||
project_root, src_dir, test_dir = self._setup_precise_waiter_project(java_project)
|
||||
|
||||
test_results = self._instrument_and_run(
|
||||
project_root,
|
||||
src_dir,
|
||||
test_dir,
|
||||
self.PRECISE_WAITER_TEST,
|
||||
"PreciseWaiterTest.java",
|
||||
project_root, src_dir, test_dir, self.PRECISE_WAITER_TEST, "PreciseWaiterTest.java"
|
||||
)
|
||||
|
||||
# 2 outer loops × 2 inner iterations = 4 total results
|
||||
|
|
@ -542,9 +533,7 @@ public class PreciseWaiterTest {
|
|||
runtime_by_test = test_results.usable_runtime_data_by_test_case()
|
||||
|
||||
# Should have 1 test case (constant iteration_id per call site)
|
||||
assert len(runtime_by_test) == 1, (
|
||||
f"Expected 1 test case (constant iteration_id), got {len(runtime_by_test)}"
|
||||
)
|
||||
assert len(runtime_by_test) == 1, f"Expected 1 test case (constant iteration_id), got {len(runtime_by_test)}"
|
||||
|
||||
# The single test case should have 4 runtimes (2 outer loops × 2 inner iterations)
|
||||
for test_id, test_runtimes in runtime_by_test.items():
|
||||
|
|
@ -584,11 +573,7 @@ public class PreciseWaiterMultiTest {
|
|||
}
|
||||
"""
|
||||
test_results = self._instrument_and_run(
|
||||
project_root,
|
||||
src_dir,
|
||||
test_dir,
|
||||
multi_test_source,
|
||||
"PreciseWaiterMultiTest.java",
|
||||
project_root, src_dir, test_dir, multi_test_source, "PreciseWaiterMultiTest.java"
|
||||
)
|
||||
|
||||
# 2 test methods × 2 outer loops × 2 inner iterations = 8 total results
|
||||
|
|
@ -651,5 +636,3 @@ public class PreciseWaiterMultiTest {
|
|||
f"total_passed_runtime {total_runtime / 1_000_000:.3f}ms not close to expected "
|
||||
f"{expected_total_ns / 1_000_000:.1f}ms (2 methods × min of 4 runtimes × 10ms, ±3%)"
|
||||
)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -4,11 +4,7 @@ from pathlib import Path
|
|||
|
||||
import pytest
|
||||
|
||||
from codeflash.languages.java.test_runner import (
|
||||
_validate_java_class_name,
|
||||
_validate_test_filter,
|
||||
get_test_run_command,
|
||||
)
|
||||
from codeflash.languages.java.test_runner import _validate_java_class_name, _validate_test_filter, get_test_run_command
|
||||
|
||||
|
||||
class TestInputValidation:
|
||||
|
|
@ -62,12 +58,7 @@ class TestInputValidation:
|
|||
|
||||
def test_validate_test_filter_wildcards(self):
|
||||
"""Test validation of wildcard patterns."""
|
||||
valid_patterns = [
|
||||
"My*Test",
|
||||
"*Test",
|
||||
"com.example.*Test",
|
||||
"com.example.**",
|
||||
]
|
||||
valid_patterns = ["My*Test", "*Test", "com.example.*Test", "com.example.**"]
|
||||
|
||||
for pattern in valid_patterns:
|
||||
result = _validate_test_filter(pattern)
|
||||
|
|
@ -203,7 +194,7 @@ class TestXMLParsingSecurity:
|
|||
for i in range(3):
|
||||
xml_file = surefire_dir / f"TEST-Suite{i}.xml"
|
||||
xml_file.write_text(f"""<?xml version="1.0" encoding="UTF-8"?>
|
||||
<testsuite tests="{i+1}" failures="0" errors="0" skipped="0">
|
||||
<testsuite tests="{i + 1}" failures="0" errors="0" skipped="0">
|
||||
<testcase name="test1" classname="Suite{i}" time="0.001"/>
|
||||
</testsuite>
|
||||
""")
|
||||
|
|
|
|||
|
|
@ -108,9 +108,7 @@ public class CalculatorTest {
|
|||
""")
|
||||
|
||||
# Get source functions
|
||||
source_functions = discover_functions_from_source(
|
||||
src_file.read_text(), file_path=src_file
|
||||
)
|
||||
source_functions = discover_functions_from_source(src_file.read_text(), file_path=src_file)
|
||||
|
||||
# Discover tests
|
||||
result = discover_tests(test_dir, source_functions)
|
||||
|
|
@ -168,7 +166,6 @@ public class StringUtilsTest {
|
|||
|
||||
# Create source function
|
||||
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
|
||||
from codeflash.languages.base import Language
|
||||
|
||||
func = FunctionToOptimize(
|
||||
function_name="reverse",
|
||||
|
|
@ -242,9 +239,7 @@ public class TestQueryBlob {
|
|||
""")
|
||||
|
||||
# Get source functions
|
||||
source_functions = discover_functions_from_source(
|
||||
src_file.read_text(), file_path=src_file
|
||||
)
|
||||
source_functions = discover_functions_from_source(src_file.read_text(), file_path=src_file)
|
||||
|
||||
# Filter to just bytesToHexString
|
||||
target_functions = [f for f in source_functions if f.function_name == "bytesToHexString"]
|
||||
|
|
@ -288,9 +283,7 @@ public class IntegrationTest {
|
|||
""")
|
||||
|
||||
# Get source functions
|
||||
source_functions = discover_functions_from_source(
|
||||
src_file.read_text(), file_path=src_file
|
||||
)
|
||||
source_functions = discover_functions_from_source(src_file.read_text(), file_path=src_file)
|
||||
|
||||
# Discover tests
|
||||
result = discover_tests(test_dir, source_functions)
|
||||
|
|
@ -325,8 +318,8 @@ class TestImportExtraction:
|
|||
|
||||
def test_basic_import(self):
|
||||
"""Test extraction of basic import statement."""
|
||||
from codeflash.languages.java.test_discovery import _extract_imports
|
||||
from codeflash.languages.java.parser import get_java_analyzer
|
||||
from codeflash.languages.java.test_discovery import _extract_imports
|
||||
|
||||
analyzer = get_java_analyzer()
|
||||
source = """
|
||||
|
|
@ -341,8 +334,8 @@ public class Test {}
|
|||
|
||||
def test_multiple_imports(self):
|
||||
"""Test extraction of multiple imports."""
|
||||
from codeflash.languages.java.test_discovery import _extract_imports
|
||||
from codeflash.languages.java.parser import get_java_analyzer
|
||||
from codeflash.languages.java.test_discovery import _extract_imports
|
||||
|
||||
analyzer = get_java_analyzer()
|
||||
source = """
|
||||
|
|
@ -358,8 +351,8 @@ public class Test {}
|
|||
|
||||
def test_wildcard_import_returns_empty(self):
|
||||
"""Test that wildcard imports don't add specific classes."""
|
||||
from codeflash.languages.java.test_discovery import _extract_imports
|
||||
from codeflash.languages.java.parser import get_java_analyzer
|
||||
from codeflash.languages.java.test_discovery import _extract_imports
|
||||
|
||||
analyzer = get_java_analyzer()
|
||||
source = """
|
||||
|
|
@ -374,8 +367,8 @@ public class Test {}
|
|||
|
||||
def test_static_import_extracts_class(self):
|
||||
"""Test that static imports extract the class name, not the method."""
|
||||
from codeflash.languages.java.test_discovery import _extract_imports
|
||||
from codeflash.languages.java.parser import get_java_analyzer
|
||||
from codeflash.languages.java.test_discovery import _extract_imports
|
||||
|
||||
analyzer = get_java_analyzer()
|
||||
source = """
|
||||
|
|
@ -390,8 +383,8 @@ public class Test {}
|
|||
|
||||
def test_static_wildcard_import_extracts_class(self):
|
||||
"""Test that static wildcard imports extract the class name."""
|
||||
from codeflash.languages.java.test_discovery import _extract_imports
|
||||
from codeflash.languages.java.parser import get_java_analyzer
|
||||
from codeflash.languages.java.test_discovery import _extract_imports
|
||||
|
||||
analyzer = get_java_analyzer()
|
||||
source = """
|
||||
|
|
@ -406,8 +399,8 @@ public class Test {}
|
|||
|
||||
def test_deeply_nested_package(self):
|
||||
"""Test extraction from deeply nested package."""
|
||||
from codeflash.languages.java.test_discovery import _extract_imports
|
||||
from codeflash.languages.java.parser import get_java_analyzer
|
||||
from codeflash.languages.java.test_discovery import _extract_imports
|
||||
|
||||
analyzer = get_java_analyzer()
|
||||
source = """
|
||||
|
|
@ -422,8 +415,8 @@ public class Test {}
|
|||
|
||||
def test_mixed_imports(self):
|
||||
"""Test extraction with mix of regular, static, and wildcard imports."""
|
||||
from codeflash.languages.java.test_discovery import _extract_imports
|
||||
from codeflash.languages.java.parser import get_java_analyzer
|
||||
from codeflash.languages.java.test_discovery import _extract_imports
|
||||
|
||||
analyzer = get_java_analyzer()
|
||||
source = """
|
||||
|
|
@ -448,8 +441,8 @@ class TestMethodCallDetection:
|
|||
|
||||
def test_find_method_calls(self):
|
||||
"""Test detection of method calls within a code range."""
|
||||
from codeflash.languages.java.test_discovery import _find_method_calls_in_range
|
||||
from codeflash.languages.java.parser import get_java_analyzer
|
||||
from codeflash.languages.java.test_discovery import _find_method_calls_in_range
|
||||
|
||||
analyzer = get_java_analyzer()
|
||||
source = """
|
||||
|
|
|
|||
|
|
@ -319,12 +319,7 @@ class TestJavaCompilation:
|
|||
pytest.skip("Maven not installed")
|
||||
|
||||
# Compile the project
|
||||
result = subprocess.run(
|
||||
["mvn", "compile", "-q"],
|
||||
cwd=java_project_dir,
|
||||
capture_output=True,
|
||||
timeout=120,
|
||||
)
|
||||
result = subprocess.run(["mvn", "compile", "-q"], cwd=java_project_dir, capture_output=True, timeout=120)
|
||||
|
||||
assert result.returncode == 0, f"Compilation failed: {result.stderr.decode()}"
|
||||
|
||||
|
|
@ -342,11 +337,6 @@ class TestJavaCompilation:
|
|||
pytest.skip("Maven not installed")
|
||||
|
||||
# Run tests
|
||||
result = subprocess.run(
|
||||
["mvn", "test", "-q"],
|
||||
cwd=java_project_dir,
|
||||
capture_output=True,
|
||||
timeout=180,
|
||||
)
|
||||
result = subprocess.run(["mvn", "test", "-q"], cwd=java_project_dir, capture_output=True, timeout=180)
|
||||
|
||||
assert result.returncode == 0, f"Tests failed: {result.stderr.decode()}"
|
||||
|
|
|
|||
|
|
@ -19,10 +19,7 @@ def make_func(name: str, class_name: str | None = None) -> FunctionToOptimize:
|
|||
"""Helper to create FunctionToOptimize for testing."""
|
||||
parents = [FunctionParent(name=class_name, type="ClassDef")] if class_name else []
|
||||
return FunctionToOptimize(
|
||||
function_name=name,
|
||||
file_path=Path("/test/file.js"),
|
||||
parents=parents,
|
||||
language="javascript",
|
||||
function_name=name, file_path=Path("/test/file.js"), parents=parents, language="javascript"
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -386,7 +383,9 @@ test('fibonacci works', () => {
|
|||
});
|
||||
"""
|
||||
transformed, counter = transform_expect_calls(
|
||||
code=code, function_to_optimize=make_func("fibonacci", class_name="FibonacciCalculator"), capture_func="capture"
|
||||
code=code,
|
||||
function_to_optimize=make_func("fibonacci", class_name="FibonacciCalculator"),
|
||||
capture_func="capture",
|
||||
)
|
||||
|
||||
# Should transform expect(calc.fibonacci(10)) to
|
||||
|
|
@ -433,7 +432,9 @@ class FibonacciCalculator {
|
|||
}
|
||||
"""
|
||||
transformed, counter = transform_standalone_calls(
|
||||
code=code, function_to_optimize=make_func("fibonacci", class_name="FibonacciCalculator"), capture_func="capture"
|
||||
code=code,
|
||||
function_to_optimize=make_func("fibonacci", class_name="FibonacciCalculator"),
|
||||
capture_func="capture",
|
||||
)
|
||||
|
||||
# The method definition should NOT be transformed
|
||||
|
|
@ -452,7 +453,9 @@ FibonacciCalculator.prototype.fibonacci = function(n) {
|
|||
};
|
||||
"""
|
||||
transformed, counter = transform_standalone_calls(
|
||||
code=code, function_to_optimize=make_func("fibonacci", class_name="FibonacciCalculator"), capture_func="capture"
|
||||
code=code,
|
||||
function_to_optimize=make_func("fibonacci", class_name="FibonacciCalculator"),
|
||||
capture_func="capture",
|
||||
)
|
||||
|
||||
# The prototype assignment should NOT be transformed
|
||||
|
|
@ -558,7 +561,10 @@ describe('Calculator', () => {
|
|||
});
|
||||
"""
|
||||
instrumented = _instrument_js_test_code(
|
||||
code=test_code, function_to_optimize=make_func("add", class_name="Calculator"), test_file_path="test.js", mode="behavior"
|
||||
code=test_code,
|
||||
function_to_optimize=make_func("add", class_name="Calculator"),
|
||||
test_file_path="test.js",
|
||||
mode="behavior",
|
||||
)
|
||||
|
||||
# describe and test structure should be preserved
|
||||
|
|
@ -886,15 +892,15 @@ test('should compute fibonacci(20) and fibonacci(30) to known values', () => {
|
|||
from codeflash.languages.javascript.instrument import transform_standalone_calls
|
||||
|
||||
func = make_func("fibonacci")
|
||||
code = '''
|
||||
code = """
|
||||
test("should compute fibonacci(20) correctly", () => {
|
||||
const result = fibonacci(10);
|
||||
});
|
||||
'''
|
||||
"""
|
||||
transformed, _counter = transform_standalone_calls(code, func, "capture")
|
||||
|
||||
# The function call in the test description should NOT be transformed
|
||||
assert 'fibonacci(20)' in transformed
|
||||
assert "fibonacci(20)" in transformed
|
||||
# The actual call should be transformed
|
||||
assert "codeflash.capture('fibonacci'" in transformed
|
||||
|
||||
|
|
|
|||
|
|
@ -8,13 +8,11 @@ These tests call the actual backend /testgen API endpoint and verify:
|
|||
Similar to test_validate_python_code.py but for JavaScript/TypeScript.
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from codeflash.api.aiservice import AiServiceClient
|
||||
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
|
||||
from codeflash.languages.base import Language
|
||||
from codeflash.models.models import CodeString, OptimizedCandidateSource
|
||||
|
||||
|
|
@ -23,6 +21,7 @@ def skip_if_js_not_supported():
|
|||
"""Skip test if JavaScript/TypeScript languages are not supported."""
|
||||
try:
|
||||
from codeflash.languages import get_language_support
|
||||
|
||||
get_language_support(Language.JAVASCRIPT)
|
||||
except Exception as e:
|
||||
pytest.skip(f"JavaScript/TypeScript language support not available: {e}")
|
||||
|
|
@ -218,12 +217,13 @@ export function add(a: number, b: number): number {
|
|||
|
||||
def capture_request(*args, **kwargs):
|
||||
nonlocal captured_payload
|
||||
if 'payload' in kwargs:
|
||||
captured_payload = kwargs['payload']
|
||||
if "payload" in kwargs:
|
||||
captured_payload = kwargs["payload"]
|
||||
elif len(args) > 1:
|
||||
captured_payload = args[1]
|
||||
# Return a mock response to avoid actual API call
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = {
|
||||
|
|
@ -233,7 +233,7 @@ export function add(a: number, b: number): number {
|
|||
}
|
||||
return mock_response
|
||||
|
||||
with patch.object(ai_client, 'make_ai_service_request', side_effect=capture_request):
|
||||
with patch.object(ai_client, "make_ai_service_request", side_effect=capture_request):
|
||||
ai_client.generate_regression_tests(
|
||||
source_code_being_tested=ts_file.read_text(),
|
||||
function_to_optimize=func,
|
||||
|
|
@ -248,8 +248,9 @@ export function add(a: number, b: number): number {
|
|||
)
|
||||
|
||||
assert captured_payload is not None
|
||||
assert captured_payload.get('language') == 'typescript', \
|
||||
assert captured_payload.get("language") == "typescript", (
|
||||
f"Expected language='typescript', got: {captured_payload.get('language')}"
|
||||
)
|
||||
|
||||
def test_testgen_request_includes_javascript_language(self, tmp_path):
|
||||
"""Verify the language parameter is sent as 'javascript' for .js files."""
|
||||
|
|
@ -279,11 +280,12 @@ module.exports = { add };
|
|||
|
||||
def capture_request(*args, **kwargs):
|
||||
nonlocal captured_payload
|
||||
if 'payload' in kwargs:
|
||||
captured_payload = kwargs['payload']
|
||||
if "payload" in kwargs:
|
||||
captured_payload = kwargs["payload"]
|
||||
elif len(args) > 1:
|
||||
captured_payload = args[1]
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = {
|
||||
|
|
@ -293,7 +295,7 @@ module.exports = { add };
|
|||
}
|
||||
return mock_response
|
||||
|
||||
with patch.object(ai_client, 'make_ai_service_request', side_effect=capture_request):
|
||||
with patch.object(ai_client, "make_ai_service_request", side_effect=capture_request):
|
||||
ai_client.generate_regression_tests(
|
||||
source_code_being_tested=js_file.read_text(),
|
||||
function_to_optimize=func,
|
||||
|
|
@ -308,5 +310,6 @@ module.exports = { add };
|
|||
)
|
||||
|
||||
assert captured_payload is not None
|
||||
assert captured_payload.get('language') == 'javascript', \
|
||||
assert captured_payload.get("language") == "javascript", (
|
||||
f"Expected language='javascript', got: {captured_payload.get('language')}"
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1,5 +1,4 @@
|
|||
"""Tests for JavaScript module system detection.
|
||||
"""
|
||||
"""Tests for JavaScript module system detection."""
|
||||
|
||||
import json
|
||||
import tempfile
|
||||
|
|
|
|||
|
|
@ -86,9 +86,7 @@ export function add(a: number, b: number): number {
|
|||
|
||||
ts_support = get_language_support(Language.TYPESCRIPT)
|
||||
code_context = ts_support.extract_code_context(func, tmp_path, tmp_path)
|
||||
context = JavaScriptFunctionOptimizer._build_optimization_context(
|
||||
code_context, ts_file, "typescript", tmp_path
|
||||
)
|
||||
context = JavaScriptFunctionOptimizer._build_optimization_context(code_context, ts_file, "typescript", tmp_path)
|
||||
|
||||
assert context.read_writable_code is not None
|
||||
assert context.read_writable_code.language == "typescript"
|
||||
|
|
@ -193,8 +191,9 @@ export function add(a: number, b: number): number {
|
|||
assert mock_request.called, "API request should have been made"
|
||||
call_args = mock_request.call_args
|
||||
payload = call_args[1].get("payload", call_args[0][1] if len(call_args[0]) > 1 else {})
|
||||
assert payload.get("language") == "typescript", \
|
||||
assert payload.get("language") == "typescript", (
|
||||
f"Expected language='typescript', got language='{payload.get('language')}'"
|
||||
)
|
||||
|
||||
|
||||
class TestFunctionOptimizerForJavaScript:
|
||||
|
|
@ -328,9 +327,7 @@ describe('fibonacci', () => {
|
|||
)
|
||||
|
||||
optimizer = FunctionOptimizer(
|
||||
function_to_optimize=func_to_optimize,
|
||||
test_cfg=test_config,
|
||||
aiservice_client=MagicMock(),
|
||||
function_to_optimize=func_to_optimize, test_cfg=test_config, aiservice_client=MagicMock()
|
||||
)
|
||||
|
||||
assert optimizer is not None
|
||||
|
|
@ -363,9 +360,7 @@ describe('fibonacci', () => {
|
|||
)
|
||||
|
||||
optimizer = FunctionOptimizer(
|
||||
function_to_optimize=func_to_optimize,
|
||||
test_cfg=test_config,
|
||||
aiservice_client=MagicMock(),
|
||||
function_to_optimize=func_to_optimize, test_cfg=test_config, aiservice_client=MagicMock()
|
||||
)
|
||||
|
||||
assert optimizer is not None
|
||||
|
|
@ -398,9 +393,7 @@ describe('fibonacci', () => {
|
|||
)
|
||||
|
||||
optimizer = JavaScriptFunctionOptimizer(
|
||||
function_to_optimize=func_to_optimize,
|
||||
test_cfg=test_config,
|
||||
aiservice_client=MagicMock(),
|
||||
function_to_optimize=func_to_optimize, test_cfg=test_config, aiservice_client=MagicMock()
|
||||
)
|
||||
|
||||
result = optimizer.get_code_optimization_context()
|
||||
|
|
@ -437,9 +430,7 @@ describe('fibonacci', () => {
|
|||
)
|
||||
|
||||
optimizer = JavaScriptFunctionOptimizer(
|
||||
function_to_optimize=func_to_optimize,
|
||||
test_cfg=test_config,
|
||||
aiservice_client=MagicMock(),
|
||||
function_to_optimize=func_to_optimize, test_cfg=test_config, aiservice_client=MagicMock()
|
||||
)
|
||||
|
||||
result = optimizer.get_code_optimization_context()
|
||||
|
|
@ -486,16 +477,11 @@ module.exports = { main };
|
|||
)
|
||||
|
||||
test_config = TestConfig(
|
||||
tests_root=tmp_path,
|
||||
tests_project_rootdir=tmp_path,
|
||||
project_root_path=tmp_path,
|
||||
pytest_cmd="jest",
|
||||
tests_root=tmp_path, tests_project_rootdir=tmp_path, project_root_path=tmp_path, pytest_cmd="jest"
|
||||
)
|
||||
|
||||
optimizer = JavaScriptFunctionOptimizer(
|
||||
function_to_optimize=func_to_optimize,
|
||||
test_cfg=test_config,
|
||||
aiservice_client=MagicMock(),
|
||||
function_to_optimize=func_to_optimize, test_cfg=test_config, aiservice_client=MagicMock()
|
||||
)
|
||||
|
||||
result = optimizer.get_code_optimization_context()
|
||||
|
|
@ -535,16 +521,11 @@ export function main(): number {
|
|||
)
|
||||
|
||||
test_config = TestConfig(
|
||||
tests_root=tmp_path,
|
||||
tests_project_rootdir=tmp_path,
|
||||
project_root_path=tmp_path,
|
||||
pytest_cmd="vitest",
|
||||
tests_root=tmp_path, tests_project_rootdir=tmp_path, project_root_path=tmp_path, pytest_cmd="vitest"
|
||||
)
|
||||
|
||||
optimizer = JavaScriptFunctionOptimizer(
|
||||
function_to_optimize=func_to_optimize,
|
||||
test_cfg=test_config,
|
||||
aiservice_client=MagicMock(),
|
||||
function_to_optimize=func_to_optimize, test_cfg=test_config, aiservice_client=MagicMock()
|
||||
)
|
||||
|
||||
result = optimizer.get_code_optimization_context()
|
||||
|
|
|
|||
|
|
@ -4,7 +4,6 @@ Tests the verify_requirements function that checks Node.js, npm, and test framew
|
|||
"""
|
||||
|
||||
import json
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
|
|
@ -30,14 +29,7 @@ class TestVerifyRequirements:
|
|||
(node_modules / "codeflash").mkdir()
|
||||
|
||||
package_json = tmp_path / "package.json"
|
||||
package_json.write_text(
|
||||
json.dumps(
|
||||
{
|
||||
"name": "test-project",
|
||||
"devDependencies": {"jest": "^29.0.0"},
|
||||
}
|
||||
)
|
||||
)
|
||||
package_json.write_text(json.dumps({"name": "test-project", "devDependencies": {"jest": "^29.0.0"}}))
|
||||
return tmp_path
|
||||
|
||||
@pytest.fixture
|
||||
|
|
@ -49,14 +41,7 @@ class TestVerifyRequirements:
|
|||
(node_modules / "codeflash").mkdir()
|
||||
|
||||
package_json = tmp_path / "package.json"
|
||||
package_json.write_text(
|
||||
json.dumps(
|
||||
{
|
||||
"name": "test-project",
|
||||
"devDependencies": {"vitest": "^2.0.0"},
|
||||
}
|
||||
)
|
||||
)
|
||||
package_json.write_text(json.dumps({"name": "test-project", "devDependencies": {"vitest": "^2.0.0"}}))
|
||||
return tmp_path
|
||||
|
||||
@pytest.fixture
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@
|
|||
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch, MagicMock
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
|
|
@ -58,10 +58,7 @@ class TestJestRootsConfiguration:
|
|||
|
||||
try:
|
||||
run_jest_behavioral_tests(
|
||||
test_paths=mock_test_files,
|
||||
test_env={},
|
||||
cwd=tmpdir_path,
|
||||
project_root=tmpdir_path,
|
||||
test_paths=mock_test_files, test_env={}, cwd=tmpdir_path, project_root=tmpdir_path
|
||||
)
|
||||
except Exception:
|
||||
pass # Expected to fail since no real Jest
|
||||
|
|
@ -79,9 +76,9 @@ class TestJestRootsConfiguration:
|
|||
|
||||
# Should have added the test directory as a root
|
||||
assert len(roots_flags) > 0, "Expected --roots flag in Jest command"
|
||||
assert str(test_dir) in roots_flags or any(
|
||||
str(test_dir) in root for root in roots_flags
|
||||
), f"Expected test directory {test_dir} in --roots flags: {roots_flags}"
|
||||
assert str(test_dir) in roots_flags or any(str(test_dir) in root for root in roots_flags), (
|
||||
f"Expected test directory {test_dir} in --roots flags: {roots_flags}"
|
||||
)
|
||||
|
||||
def test_benchmarking_tests_adds_roots_for_test_directories(self):
|
||||
"""Test that run_jest_benchmarking_tests adds --roots for test directories."""
|
||||
|
|
@ -106,7 +103,7 @@ class TestJestRootsConfiguration:
|
|||
instrumented_behavior_file_path=test_file,
|
||||
benchmarking_file_path=test_file,
|
||||
test_type=TestType.GENERATED_REGRESSION,
|
||||
),
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
|
|
@ -119,10 +116,7 @@ class TestJestRootsConfiguration:
|
|||
|
||||
try:
|
||||
run_jest_benchmarking_tests(
|
||||
test_paths=mock_test_files,
|
||||
test_env={},
|
||||
cwd=tmpdir_path,
|
||||
project_root=tmpdir_path,
|
||||
test_paths=mock_test_files, test_env={}, cwd=tmpdir_path, project_root=tmpdir_path
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
|
@ -161,7 +155,7 @@ class TestJestRootsConfiguration:
|
|||
instrumented_behavior_file_path=test_file,
|
||||
benchmarking_file_path=test_file,
|
||||
test_type=TestType.GENERATED_REGRESSION,
|
||||
),
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
|
|
@ -174,10 +168,7 @@ class TestJestRootsConfiguration:
|
|||
|
||||
try:
|
||||
run_jest_line_profile_tests(
|
||||
test_paths=mock_test_files,
|
||||
test_env={},
|
||||
cwd=tmpdir_path,
|
||||
project_root=tmpdir_path,
|
||||
test_paths=mock_test_files, test_env={}, cwd=tmpdir_path, project_root=tmpdir_path
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
|
@ -239,10 +230,7 @@ class TestJestRootsConfiguration:
|
|||
|
||||
try:
|
||||
run_jest_behavioral_tests(
|
||||
test_paths=mock_test_files,
|
||||
test_env={},
|
||||
cwd=tmpdir_path,
|
||||
project_root=tmpdir_path,
|
||||
test_paths=mock_test_files, test_env={}, cwd=tmpdir_path, project_root=tmpdir_path
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
|
@ -286,7 +274,7 @@ class TestVitestTimeoutConfiguration:
|
|||
instrumented_behavior_file_path=test_file,
|
||||
benchmarking_file_path=test_file,
|
||||
test_type=TestType.GENERATED_REGRESSION,
|
||||
),
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
|
|
@ -314,7 +302,9 @@ class TestVitestTimeoutConfiguration:
|
|||
# Subprocess timeout should be at least 120 seconds (minimum)
|
||||
# or 10x the per-test timeout (150 seconds)
|
||||
assert subprocess_timeout >= 120, f"Expected subprocess timeout >= 120s, got {subprocess_timeout}s"
|
||||
assert subprocess_timeout >= 15 * 10, f"Expected subprocess timeout >= 150s (10x per-test), got {subprocess_timeout}s"
|
||||
assert subprocess_timeout >= 15 * 10, (
|
||||
f"Expected subprocess timeout >= 150s (10x per-test), got {subprocess_timeout}s"
|
||||
)
|
||||
|
||||
def test_vitest_line_profile_subprocess_timeout_larger_than_test_timeout(self):
|
||||
"""Test that subprocess timeout is larger than per-test timeout for Vitest line profile tests."""
|
||||
|
|
@ -339,7 +329,7 @@ class TestVitestTimeoutConfiguration:
|
|||
instrumented_behavior_file_path=test_file,
|
||||
benchmarking_file_path=test_file,
|
||||
test_type=TestType.GENERATED_REGRESSION,
|
||||
),
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
|
|
@ -351,11 +341,7 @@ class TestVitestTimeoutConfiguration:
|
|||
mock_run.return_value = mock_result
|
||||
|
||||
run_vitest_line_profile_tests(
|
||||
test_paths=mock_test_files,
|
||||
test_env={},
|
||||
cwd=tmpdir_path,
|
||||
timeout=15,
|
||||
project_root=tmpdir_path,
|
||||
test_paths=mock_test_files, test_env={}, cwd=tmpdir_path, timeout=15, project_root=tmpdir_path
|
||||
)
|
||||
|
||||
assert mock_run.called
|
||||
|
|
@ -387,7 +373,7 @@ class TestVitestTimeoutConfiguration:
|
|||
instrumented_behavior_file_path=test_file,
|
||||
benchmarking_file_path=test_file,
|
||||
test_type=TestType.GENERATED_REGRESSION,
|
||||
),
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
|
|
@ -400,10 +386,7 @@ class TestVitestTimeoutConfiguration:
|
|||
|
||||
# Run without specifying a timeout
|
||||
run_vitest_behavioral_tests(
|
||||
test_paths=mock_test_files,
|
||||
test_env={},
|
||||
cwd=tmpdir_path,
|
||||
project_root=tmpdir_path,
|
||||
test_paths=mock_test_files, test_env={}, cwd=tmpdir_path, project_root=tmpdir_path
|
||||
)
|
||||
|
||||
assert mock_run.called
|
||||
|
|
@ -445,7 +428,7 @@ class TestVitestInternalLoopingConfiguration:
|
|||
instrumented_behavior_file_path=test_file,
|
||||
benchmarking_file_path=test_file,
|
||||
test_type=TestType.GENERATED_REGRESSION,
|
||||
),
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
|
|
@ -503,7 +486,7 @@ class TestVitestInternalLoopingConfiguration:
|
|||
instrumented_behavior_file_path=test_file,
|
||||
benchmarking_file_path=test_file,
|
||||
test_type=TestType.GENERATED_REGRESSION,
|
||||
),
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
|
|
@ -550,13 +533,7 @@ class TestBundlerModuleResolutionFix:
|
|||
tmpdir_path = Path(tmpdir)
|
||||
|
||||
# Create tsconfig with bundler moduleResolution
|
||||
tsconfig = {
|
||||
"compilerOptions": {
|
||||
"moduleResolution": "bundler",
|
||||
"module": "preserve",
|
||||
"target": "ES2022",
|
||||
}
|
||||
}
|
||||
tsconfig = {"compilerOptions": {"moduleResolution": "bundler", "module": "preserve", "target": "ES2022"}}
|
||||
(tmpdir_path / "tsconfig.json").write_text(json.dumps(tsconfig))
|
||||
|
||||
assert _detect_bundler_module_resolution(tmpdir_path) is True
|
||||
|
|
@ -571,12 +548,7 @@ class TestBundlerModuleResolutionFix:
|
|||
tmpdir_path = Path(tmpdir)
|
||||
|
||||
# Create tsconfig with Node moduleResolution
|
||||
tsconfig = {
|
||||
"compilerOptions": {
|
||||
"moduleResolution": "Node",
|
||||
"module": "ESNext",
|
||||
}
|
||||
}
|
||||
tsconfig = {"compilerOptions": {"moduleResolution": "Node", "module": "ESNext"}}
|
||||
(tmpdir_path / "tsconfig.json").write_text(json.dumps(tsconfig))
|
||||
|
||||
assert _detect_bundler_module_resolution(tmpdir_path) is False
|
||||
|
|
@ -601,21 +573,11 @@ class TestBundlerModuleResolutionFix:
|
|||
# Create a base config with bundler in a subdirectory (simulating node_modules)
|
||||
node_modules = tmpdir_path / "node_modules" / "@myorg" / "tsconfig"
|
||||
node_modules.mkdir(parents=True)
|
||||
base_tsconfig = {
|
||||
"compilerOptions": {
|
||||
"moduleResolution": "bundler",
|
||||
"module": "preserve",
|
||||
}
|
||||
}
|
||||
base_tsconfig = {"compilerOptions": {"moduleResolution": "bundler", "module": "preserve"}}
|
||||
(node_modules / "tsconfig.json").write_text(json.dumps(base_tsconfig))
|
||||
|
||||
# Create a project tsconfig that extends the base
|
||||
project_tsconfig = {
|
||||
"extends": "@myorg/tsconfig/tsconfig.json",
|
||||
"compilerOptions": {
|
||||
"target": "ES2022",
|
||||
}
|
||||
}
|
||||
project_tsconfig = {"extends": "@myorg/tsconfig/tsconfig.json", "compilerOptions": {"target": "ES2022"}}
|
||||
(tmpdir_path / "tsconfig.json").write_text(json.dumps(project_tsconfig))
|
||||
|
||||
# Should detect bundler from extended config
|
||||
|
|
@ -632,11 +594,7 @@ class TestBundlerModuleResolutionFix:
|
|||
|
||||
# Create original tsconfig
|
||||
original_tsconfig = {
|
||||
"compilerOptions": {
|
||||
"moduleResolution": "bundler",
|
||||
"module": "preserve",
|
||||
"target": "ES2022",
|
||||
},
|
||||
"compilerOptions": {"moduleResolution": "bundler", "module": "preserve", "target": "ES2022"},
|
||||
"include": ["src/**/*.ts"],
|
||||
"exclude": ["node_modules"],
|
||||
}
|
||||
|
|
@ -683,12 +641,7 @@ class TestBundlerModuleResolutionFix:
|
|||
tmpdir_path = Path(tmpdir)
|
||||
|
||||
# Create tsconfig with bundler
|
||||
tsconfig = {
|
||||
"compilerOptions": {
|
||||
"moduleResolution": "bundler",
|
||||
"module": "preserve",
|
||||
}
|
||||
}
|
||||
tsconfig = {"compilerOptions": {"moduleResolution": "bundler", "module": "preserve"}}
|
||||
(tmpdir_path / "tsconfig.json").write_text(json.dumps(tsconfig))
|
||||
(tmpdir_path / "package.json").write_text('{"name": "test"}')
|
||||
|
||||
|
|
@ -709,12 +662,7 @@ class TestBundlerModuleResolutionFix:
|
|||
tmpdir_path = Path(tmpdir)
|
||||
|
||||
# Create tsconfig with Node moduleResolution
|
||||
tsconfig = {
|
||||
"compilerOptions": {
|
||||
"moduleResolution": "Node",
|
||||
"module": "ESNext",
|
||||
}
|
||||
}
|
||||
tsconfig = {"compilerOptions": {"moduleResolution": "Node", "module": "ESNext"}}
|
||||
(tmpdir_path / "tsconfig.json").write_text(json.dumps(tsconfig))
|
||||
(tmpdir_path / "package.json").write_text('{"name": "test"}')
|
||||
|
||||
|
|
@ -772,7 +720,7 @@ class TestBundledJestReporter:
|
|||
instrumented_behavior_file_path=test_file,
|
||||
benchmarking_file_path=test_file,
|
||||
test_type=TestType.GENERATED_REGRESSION,
|
||||
),
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
|
|
@ -785,10 +733,7 @@ class TestBundledJestReporter:
|
|||
|
||||
try:
|
||||
run_jest_behavioral_tests(
|
||||
test_paths=mock_test_files,
|
||||
test_env={},
|
||||
cwd=tmpdir_path,
|
||||
project_root=tmpdir_path,
|
||||
test_paths=mock_test_files, test_env={}, cwd=tmpdir_path, project_root=tmpdir_path
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
|
@ -796,7 +741,9 @@ class TestBundledJestReporter:
|
|||
if mock_run.called:
|
||||
cmd = mock_run.call_args[0][0]
|
||||
reporter_args = [a for a in cmd if "--reporters=" in a and "jest-reporter" in a]
|
||||
assert len(reporter_args) == 1, f"Expected exactly one codeflash/jest-reporter flag, got: {reporter_args}"
|
||||
assert len(reporter_args) == 1, (
|
||||
f"Expected exactly one codeflash/jest-reporter flag, got: {reporter_args}"
|
||||
)
|
||||
assert reporter_args[0] == "--reporters=codeflash/jest-reporter"
|
||||
# Must NOT reference jest-junit
|
||||
jest_junit_args = [a for a in cmd if "jest-junit" in a]
|
||||
|
|
@ -823,7 +770,7 @@ class TestBundledJestReporter:
|
|||
instrumented_behavior_file_path=test_file,
|
||||
benchmarking_file_path=test_file,
|
||||
test_type=TestType.GENERATED_REGRESSION,
|
||||
),
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
|
|
@ -836,10 +783,7 @@ class TestBundledJestReporter:
|
|||
|
||||
try:
|
||||
run_jest_benchmarking_tests(
|
||||
test_paths=mock_test_files,
|
||||
test_env={},
|
||||
cwd=tmpdir_path,
|
||||
project_root=tmpdir_path,
|
||||
test_paths=mock_test_files, test_env={}, cwd=tmpdir_path, project_root=tmpdir_path
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
|
@ -870,7 +814,7 @@ class TestBundledJestReporter:
|
|||
instrumented_behavior_file_path=test_file,
|
||||
benchmarking_file_path=test_file,
|
||||
test_type=TestType.GENERATED_REGRESSION,
|
||||
),
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
|
|
@ -883,10 +827,7 @@ class TestBundledJestReporter:
|
|||
|
||||
try:
|
||||
run_jest_line_profile_tests(
|
||||
test_paths=mock_test_files,
|
||||
test_env={},
|
||||
cwd=tmpdir_path,
|
||||
project_root=tmpdir_path,
|
||||
test_paths=mock_test_files, test_env={}, cwd=tmpdir_path, project_root=tmpdir_path
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
|
@ -963,12 +904,7 @@ reporter.onRunComplete([], results);
|
|||
console.log('OK');
|
||||
""")
|
||||
|
||||
result = subprocess.run(
|
||||
["node", str(test_script)],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=10,
|
||||
)
|
||||
result = subprocess.run(["node", str(test_script)], capture_output=True, text=True, timeout=10)
|
||||
|
||||
assert result.returncode == 0, f"Reporter script failed: {result.stderr}"
|
||||
assert output_file.exists(), "Reporter did not create output file"
|
||||
|
|
@ -1020,7 +956,6 @@ console.log('OK');
|
|||
assert exports["./jest-reporter"]["require"] == "./runtime/jest-reporter.js"
|
||||
|
||||
|
||||
|
||||
class TestUnsupportedFrameworkError:
|
||||
"""Tests for clear error on unsupported test frameworks."""
|
||||
|
||||
|
|
@ -1030,12 +965,7 @@ class TestUnsupportedFrameworkError:
|
|||
|
||||
support = JavaScriptSupport()
|
||||
with pytest.raises(NotImplementedError, match="not yet supported"):
|
||||
support.run_behavioral_tests(
|
||||
test_paths=MagicMock(),
|
||||
test_env={},
|
||||
cwd=Path("."),
|
||||
test_framework="tap",
|
||||
)
|
||||
support.run_behavioral_tests(test_paths=MagicMock(), test_env={}, cwd=Path(), test_framework="tap")
|
||||
|
||||
def test_unknown_framework_raises_error_benchmarking(self):
|
||||
"""run_benchmarking_tests should raise NotImplementedError for unknown frameworks."""
|
||||
|
|
@ -1043,12 +973,7 @@ class TestUnsupportedFrameworkError:
|
|||
|
||||
support = JavaScriptSupport()
|
||||
with pytest.raises(NotImplementedError, match="not yet supported"):
|
||||
support.run_benchmarking_tests(
|
||||
test_paths=MagicMock(),
|
||||
test_env={},
|
||||
cwd=Path("."),
|
||||
test_framework="tap",
|
||||
)
|
||||
support.run_benchmarking_tests(test_paths=MagicMock(), test_env={}, cwd=Path(), test_framework="tap")
|
||||
|
||||
def test_unknown_framework_raises_error_line_profile(self):
|
||||
"""run_line_profile_tests should raise NotImplementedError for unknown frameworks."""
|
||||
|
|
@ -1056,42 +981,27 @@ class TestUnsupportedFrameworkError:
|
|||
|
||||
support = JavaScriptSupport()
|
||||
with pytest.raises(NotImplementedError, match="not yet supported"):
|
||||
support.run_line_profile_tests(
|
||||
test_paths=MagicMock(),
|
||||
test_env={},
|
||||
cwd=Path("."),
|
||||
test_framework="tap",
|
||||
)
|
||||
support.run_line_profile_tests(test_paths=MagicMock(), test_env={}, cwd=Path(), test_framework="tap")
|
||||
|
||||
def test_jest_framework_does_not_raise_not_implemented(self):
|
||||
"""jest framework should NOT raise NotImplementedError."""
|
||||
"""Jest framework should NOT raise NotImplementedError."""
|
||||
from codeflash.languages.javascript.support import JavaScriptSupport
|
||||
|
||||
support = JavaScriptSupport()
|
||||
try:
|
||||
support.run_behavioral_tests(
|
||||
test_paths=MagicMock(),
|
||||
test_env={},
|
||||
cwd=Path("."),
|
||||
test_framework="jest",
|
||||
)
|
||||
support.run_behavioral_tests(test_paths=MagicMock(), test_env={}, cwd=Path(), test_framework="jest")
|
||||
except NotImplementedError:
|
||||
pytest.fail("jest framework should not raise NotImplementedError")
|
||||
except Exception:
|
||||
pass # Other exceptions are fine — Jest isn't installed in test env
|
||||
|
||||
def test_mocha_framework_does_not_raise_not_implemented(self):
|
||||
"""mocha framework should NOT raise NotImplementedError."""
|
||||
"""Mocha framework should NOT raise NotImplementedError."""
|
||||
from codeflash.languages.javascript.support import JavaScriptSupport
|
||||
|
||||
support = JavaScriptSupport()
|
||||
try:
|
||||
support.run_behavioral_tests(
|
||||
test_paths=MagicMock(),
|
||||
test_env={},
|
||||
cwd=Path("."),
|
||||
test_framework="mocha",
|
||||
)
|
||||
support.run_behavioral_tests(test_paths=MagicMock(), test_env={}, cwd=Path(), test_framework="mocha")
|
||||
except NotImplementedError:
|
||||
pytest.fail("mocha framework should not raise NotImplementedError")
|
||||
except Exception:
|
||||
|
|
|
|||
|
|
@ -8,12 +8,13 @@ from pathlib import Path
|
|||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
|
||||
from codeflash.languages.base import Language
|
||||
from codeflash.languages.javascript.function_optimizer import JavaScriptFunctionOptimizer
|
||||
from codeflash.languages.javascript.support import JavaScriptSupport, TypeScriptSupport
|
||||
from codeflash.languages.registry import get_language_support
|
||||
from codeflash.models.models import FunctionParent
|
||||
from codeflash.languages.javascript.function_optimizer import JavaScriptFunctionOptimizer
|
||||
from codeflash.verification.verification_utils import TestConfig
|
||||
|
||||
FIXTURES_DIR = Path(__file__).parent / "fixtures"
|
||||
|
|
|
|||
|
|
@ -5,7 +5,6 @@ import tempfile
|
|||
from pathlib import Path
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from junitparser import JUnitXml
|
||||
|
||||
|
||||
|
|
@ -19,12 +18,7 @@ class TestMochaJsonToJunitXml:
|
|||
{
|
||||
"stats": {"tests": 2, "passes": 2, "failures": 0, "duration": 50},
|
||||
"tests": [
|
||||
{
|
||||
"title": "should add numbers",
|
||||
"fullTitle": "math should add numbers",
|
||||
"duration": 20,
|
||||
"err": {},
|
||||
},
|
||||
{"title": "should add numbers", "fullTitle": "math should add numbers", "duration": 20, "err": {}},
|
||||
{
|
||||
"title": "should subtract numbers",
|
||||
"fullTitle": "math should subtract numbers",
|
||||
|
|
@ -62,7 +56,7 @@ class TestMochaJsonToJunitXml:
|
|||
"message": "expected 1 to equal 2",
|
||||
"stack": "AssertionError: expected 1 to equal 2\n at Context.<anonymous>",
|
||||
},
|
||||
},
|
||||
}
|
||||
],
|
||||
"passes": [],
|
||||
"failures": [],
|
||||
|
|
@ -92,7 +86,7 @@ class TestMochaJsonToJunitXml:
|
|||
"duration": 0,
|
||||
"pending": True,
|
||||
"err": {},
|
||||
},
|
||||
}
|
||||
],
|
||||
"passes": [],
|
||||
"failures": [],
|
||||
|
|
@ -198,9 +192,7 @@ class TestMochaJsonToJunitXml:
|
|||
mocha_json = json.dumps(
|
||||
{
|
||||
"stats": {"tests": 1, "passes": 1, "failures": 0, "duration": 10},
|
||||
"tests": [
|
||||
{"title": "test1", "fullTitle": "someOtherSuite test1", "duration": 10, "err": {}},
|
||||
],
|
||||
"tests": [{"title": "test1", "fullTitle": "someOtherSuite test1", "duration": 10, "err": {}}],
|
||||
"passes": [],
|
||||
"failures": [],
|
||||
"pending": [],
|
||||
|
|
@ -229,9 +221,7 @@ class TestMochaJsonToJunitXml:
|
|||
mocha_json = json.dumps(
|
||||
{
|
||||
"stats": {"tests": 1, "passes": 1, "failures": 0, "duration": 10},
|
||||
"tests": [
|
||||
{"title": "test1", "fullTitle": "suite test1", "duration": 10, "err": {}},
|
||||
],
|
||||
"tests": [{"title": "test1", "fullTitle": "suite test1", "duration": 10, "err": {}}],
|
||||
"passes": [],
|
||||
"failures": [],
|
||||
"pending": [],
|
||||
|
|
@ -435,7 +425,13 @@ class TestRunMochaBehavioralTests:
|
|||
from codeflash.models.test_type import TestType
|
||||
|
||||
mocha_output = json.dumps(
|
||||
{"stats": {"tests": 1, "passes": 1, "failures": 0, "duration": 10}, "tests": [{"title": "t", "fullTitle": "s t", "duration": 10, "err": {}}], "passes": [], "failures": [], "pending": []}
|
||||
{
|
||||
"stats": {"tests": 1, "passes": 1, "failures": 0, "duration": 10},
|
||||
"tests": [{"title": "t", "fullTitle": "s t", "duration": 10, "err": {}}],
|
||||
"passes": [],
|
||||
"failures": [],
|
||||
"pending": [],
|
||||
}
|
||||
)
|
||||
mock_run.return_value = MagicMock(returncode=0, stdout=mocha_output, stderr="", args=[])
|
||||
|
||||
|
|
@ -457,10 +453,7 @@ class TestRunMochaBehavioralTests:
|
|||
)
|
||||
|
||||
result_file, result, cov, _ = run_mocha_behavioral_tests(
|
||||
test_paths=test_paths,
|
||||
test_env={},
|
||||
cwd=tmpdir_path,
|
||||
candidate_index=3,
|
||||
test_paths=test_paths, test_env={}, cwd=tmpdir_path, candidate_index=3
|
||||
)
|
||||
|
||||
# Verify env vars were passed
|
||||
|
|
@ -478,7 +471,13 @@ class TestRunMochaBehavioralTests:
|
|||
from codeflash.models.test_type import TestType
|
||||
|
||||
mocha_output = json.dumps(
|
||||
{"stats": {"tests": 0, "passes": 0, "failures": 0, "duration": 0}, "tests": [], "passes": [], "failures": [], "pending": []}
|
||||
{
|
||||
"stats": {"tests": 0, "passes": 0, "failures": 0, "duration": 0},
|
||||
"tests": [],
|
||||
"passes": [],
|
||||
"failures": [],
|
||||
"pending": [],
|
||||
}
|
||||
)
|
||||
mock_run.return_value = MagicMock(returncode=0, stdout=mocha_output, stderr="", args=[])
|
||||
|
||||
|
|
@ -499,11 +498,7 @@ class TestRunMochaBehavioralTests:
|
|||
]
|
||||
)
|
||||
|
||||
_, _, coverage_path, _ = run_mocha_behavioral_tests(
|
||||
test_paths=test_paths,
|
||||
test_env={},
|
||||
cwd=tmpdir_path,
|
||||
)
|
||||
_, _, coverage_path, _ = run_mocha_behavioral_tests(test_paths=test_paths, test_env={}, cwd=tmpdir_path)
|
||||
assert coverage_path is None
|
||||
|
||||
|
||||
|
|
@ -518,7 +513,13 @@ class TestRunMochaBenchmarkingTests:
|
|||
from codeflash.models.test_type import TestType
|
||||
|
||||
mocha_output = json.dumps(
|
||||
{"stats": {"tests": 1, "passes": 1, "failures": 0, "duration": 100}, "tests": [{"title": "perf", "fullTitle": "bench perf", "duration": 100, "err": {}}], "passes": [], "failures": [], "pending": []}
|
||||
{
|
||||
"stats": {"tests": 1, "passes": 1, "failures": 0, "duration": 100},
|
||||
"tests": [{"title": "perf", "fullTitle": "bench perf", "duration": 100, "err": {}}],
|
||||
"passes": [],
|
||||
"failures": [],
|
||||
"pending": [],
|
||||
}
|
||||
)
|
||||
mock_run.return_value = MagicMock(returncode=0, stdout=mocha_output, stderr="", args=[])
|
||||
|
||||
|
|
@ -729,7 +730,13 @@ class TestRunMochaLineProfileTests:
|
|||
from codeflash.models.test_type import TestType
|
||||
|
||||
mocha_output = json.dumps(
|
||||
{"stats": {"tests": 0, "passes": 0, "failures": 0, "duration": 0}, "tests": [], "passes": [], "failures": [], "pending": []}
|
||||
{
|
||||
"stats": {"tests": 0, "passes": 0, "failures": 0, "duration": 0},
|
||||
"tests": [],
|
||||
"passes": [],
|
||||
"failures": [],
|
||||
"pending": [],
|
||||
}
|
||||
)
|
||||
mock_run.return_value = MagicMock(returncode=0, stdout=mocha_output, stderr="", args=[])
|
||||
|
||||
|
|
@ -752,10 +759,7 @@ class TestRunMochaLineProfileTests:
|
|||
)
|
||||
|
||||
run_mocha_line_profile_tests(
|
||||
test_paths=test_paths,
|
||||
test_env={},
|
||||
cwd=tmpdir_path,
|
||||
line_profile_output_file=profile_output,
|
||||
test_paths=test_paths, test_env={}, cwd=tmpdir_path, line_profile_output_file=profile_output
|
||||
)
|
||||
|
||||
call_kwargs = mock_run.call_args
|
||||
|
|
@ -769,7 +773,8 @@ class TestParserUnknownTestNameFallback:
|
|||
|
||||
def test_unknown_markers_matched_to_first_testcase(self):
|
||||
"""When capturePerf markers have 'unknown' test name (Vitest beforeEach not firing),
|
||||
the parser should still match them to testcases via the fallback logic."""
|
||||
the parser should still match them to testcases via the fallback logic.
|
||||
"""
|
||||
from codeflash.languages.javascript.parse import parse_jest_test_xml
|
||||
from codeflash.models.models import TestFile, TestFiles
|
||||
from codeflash.models.test_type import TestType
|
||||
|
|
@ -817,10 +822,7 @@ class TestParserUnknownTestNameFallback:
|
|||
test_config.test_framework = "vitest"
|
||||
|
||||
results = parse_jest_test_xml(
|
||||
test_xml_file_path=xml_path,
|
||||
test_files=test_files,
|
||||
test_config=test_config,
|
||||
run_result=mock_result,
|
||||
test_xml_file_path=xml_path, test_files=test_files, test_config=test_config, run_result=mock_result
|
||||
)
|
||||
|
||||
# The "unknown" fallback should assign all 5 markers to the testcase
|
||||
|
|
|
|||
|
|
@ -272,8 +272,8 @@ class TestClearFunctions:
|
|||
assert not is_language_supported(Language.PYTHON)
|
||||
|
||||
# Re-register all languages by importing
|
||||
from codeflash.languages.python.support import PythonSupport
|
||||
from codeflash.languages.javascript.support import JavaScriptSupport, TypeScriptSupport
|
||||
from codeflash.languages.python.support import PythonSupport
|
||||
|
||||
# Need to manually register since decorator already ran
|
||||
register_language(PythonSupport)
|
||||
|
|
|
|||
|
|
@ -839,7 +839,7 @@ class TestNamedExportConstArrow:
|
|||
return TreeSitterAnalyzer(TreeSitterLanguage.JAVASCRIPT)
|
||||
|
||||
def test_named_export_const_arrow(self, ts_analyzer):
|
||||
"""const arrow function exported via separate export { } clause."""
|
||||
"""Const arrow function exported via separate export { } clause."""
|
||||
code = """const joinBy = (arr: string[], separator: string) => {
|
||||
return arr.join(separator);
|
||||
};
|
||||
|
|
@ -852,7 +852,7 @@ export { joinBy };"""
|
|||
assert joinBy.is_exported is True
|
||||
|
||||
def test_named_export_alias(self, ts_analyzer):
|
||||
"""export { foo as bar } — foo should be marked as exported."""
|
||||
"""Export { foo as bar } — foo should be marked as exported."""
|
||||
code = """const foo = (x: number) => {
|
||||
return x * 2;
|
||||
};
|
||||
|
|
|
|||
|
|
@ -60,8 +60,9 @@ class TestTypeScriptFunctionDiscovery:
|
|||
|
||||
# Critical: Verify language is "typescript", not "javascript"
|
||||
for func in func_list:
|
||||
assert func.language == "typescript", \
|
||||
assert func.language == "typescript", (
|
||||
f"Function {func.function_name} should have language='typescript', got '{func.language}'"
|
||||
)
|
||||
|
||||
def test_discover_functions_with_type_annotations(self):
|
||||
"""Test discovering TypeScript functions with type annotations."""
|
||||
|
|
@ -176,11 +177,7 @@ function multiply(a: number, b: number): number {
|
|||
ts_support = get_language_support(Language.TYPESCRIPT)
|
||||
|
||||
func_info = FunctionInfo(
|
||||
function_name="add",
|
||||
file_path=Path("/tmp/test.ts"),
|
||||
starting_line=2,
|
||||
ending_line=4,
|
||||
language="typescript"
|
||||
function_name="add", file_path=Path("/tmp/test.ts"), starting_line=2, ending_line=4, language="typescript"
|
||||
)
|
||||
|
||||
result = ts_support.replace_function(original_source, func_info, new_function)
|
||||
|
|
@ -227,7 +224,7 @@ function processConfig(config: Config): string {
|
|||
file_path=Path("/tmp/test.ts"),
|
||||
starting_line=7,
|
||||
ending_line=9,
|
||||
language="typescript"
|
||||
language="typescript",
|
||||
)
|
||||
|
||||
result = ts_support.replace_function(original_source, func_info, new_function)
|
||||
|
|
@ -264,11 +261,7 @@ class TestTypeScriptTestDiscovery:
|
|||
|
||||
fib_file = ts_project_dir / "fibonacci.ts"
|
||||
func_info = FunctionInfo(
|
||||
function_name="fibonacci",
|
||||
file_path=fib_file,
|
||||
starting_line=1,
|
||||
ending_line=7,
|
||||
language="typescript"
|
||||
function_name="fibonacci", file_path=fib_file, starting_line=1, ending_line=7, language="typescript"
|
||||
)
|
||||
|
||||
tests = ts_support.discover_tests(test_root, [func_info])
|
||||
|
|
@ -328,7 +321,7 @@ export function standalone(x: number): number {
|
|||
CodeString(
|
||||
code="function add(a: number, b: number): number { return a + b; }",
|
||||
file_path=Path("test.ts"),
|
||||
language="typescript"
|
||||
language="typescript",
|
||||
)
|
||||
],
|
||||
language="typescript",
|
||||
|
|
|
|||
|
|
@ -301,15 +301,7 @@ class TestVitestVsJestDetection:
|
|||
|
||||
package_json = tmp_path / "package.json"
|
||||
package_json.write_text(
|
||||
json.dumps(
|
||||
{
|
||||
"name": "test",
|
||||
"devDependencies": {
|
||||
"vitest": "^2.0.0",
|
||||
"jest": "^29.0.0",
|
||||
},
|
||||
}
|
||||
)
|
||||
json.dumps({"name": "test", "devDependencies": {"vitest": "^2.0.0", "jest": "^29.0.0"}})
|
||||
)
|
||||
package_data = get_package_json_data(package_json)
|
||||
|
||||
|
|
|
|||
|
|
@ -23,7 +23,13 @@ def pytest_loops_instance(pytestconfig: Config) -> PytestLoops:
|
|||
@pytest.fixture
|
||||
def mock_item() -> type:
|
||||
class MockItem:
|
||||
def __init__(self, function: types.FunctionType, name: str = "test_func", cls: type = None, module: types.ModuleType = None) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
function: types.FunctionType,
|
||||
name: str = "test_func",
|
||||
cls: type = None,
|
||||
module: types.ModuleType = None,
|
||||
) -> None:
|
||||
self.function = function
|
||||
self.name = name
|
||||
self.cls = cls
|
||||
|
|
@ -352,7 +358,9 @@ obj.my_method(5)
|
|||
item = mock_item(no_cache_func)
|
||||
pytest_loops_instance._clear_lru_caches(item) # noqa: SLF001
|
||||
|
||||
def test_clears_module_level_caches_via_sys_modules(self, pytest_loops_instance: PytestLoops, mock_item: type) -> None:
|
||||
def test_clears_module_level_caches_via_sys_modules(
|
||||
self, pytest_loops_instance: PytestLoops, mock_item: type
|
||||
) -> None:
|
||||
module_name = "_cf_test_module_scan"
|
||||
source_code = """
|
||||
import functools
|
||||
|
|
|
|||
|
|
@ -127,7 +127,9 @@ code_to_optimize/tests/test_simple.py:10: AssertionError
|
|||
)
|
||||
|
||||
assert "TestCalculator.test_divide_by_zero" in errors
|
||||
assert errors["TestCalculator.test_divide_by_zero"] == """
|
||||
assert (
|
||||
errors["TestCalculator.test_divide_by_zero"]
|
||||
== """
|
||||
class TestCalculator:
|
||||
def test_divide_by_zero(self):
|
||||
> Calculator().divide(10, 0)
|
||||
|
|
@ -135,6 +137,7 @@ E ZeroDivisionError: division by zero
|
|||
|
||||
code_to_optimize/tests/test_calculator.py:22: ZeroDivisionError
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
def test_extracting_from_invalid_pytest_stdout():
|
||||
|
|
|
|||
|
|
@ -1,11 +1,6 @@
|
|||
"""Tests for the regex patterns and string matching in parse_test_output.py."""
|
||||
|
||||
from codeflash.verification.parse_test_output import (
|
||||
matches_re_end,
|
||||
matches_re_start,
|
||||
parse_test_failures_from_stdout,
|
||||
)
|
||||
|
||||
from codeflash.verification.parse_test_output import matches_re_end, matches_re_start, parse_test_failures_from_stdout
|
||||
|
||||
# --- matches_re_start tests ---
|
||||
|
||||
|
|
@ -42,10 +37,7 @@ class TestMatchesReStart:
|
|||
assert m.groups() == ("mod", "", "test_fn", "f", "1", "x")
|
||||
|
||||
def test_multiple_matches(self) -> None:
|
||||
s = (
|
||||
"!$######m1:C1.fn1:t1:1:a######$!\n"
|
||||
"!$######m2:fn2:t2:2:b######$!\n"
|
||||
)
|
||||
s = "!$######m1:C1.fn1:t1:1:a######$!\n!$######m2:fn2:t2:2:b######$!\n"
|
||||
matches = list(matches_re_start.finditer(s))
|
||||
assert len(matches) == 2
|
||||
assert matches[0].groups() == ("m1", "C1.", "fn1", "t1", "1", "a")
|
||||
|
|
@ -170,20 +162,12 @@ class TestParseTestFailuresHeader:
|
|||
|
||||
def test_word_failures_without_equals_is_not_matched(self) -> None:
|
||||
"""'FAILURES' without surrounding '=' signs should not trigger the header detection."""
|
||||
stdout = (
|
||||
"FAILURES detected in module\n"
|
||||
"_______ test_baz _______\n"
|
||||
"\n"
|
||||
" assert False\n"
|
||||
)
|
||||
stdout = "FAILURES detected in module\n_______ test_baz _______\n\n assert False\n"
|
||||
result = parse_test_failures_from_stdout(stdout)
|
||||
assert result == {}
|
||||
|
||||
def test_failures_in_test_output_not_matched(self) -> None:
|
||||
"""A test printing 'FAILURES' (no = signs) should not trigger header detection."""
|
||||
stdout = (
|
||||
"Testing FAILURES handling\n"
|
||||
"All good\n"
|
||||
)
|
||||
stdout = "Testing FAILURES handling\nAll good\n"
|
||||
result = parse_test_failures_from_stdout(stdout)
|
||||
assert result == {}
|
||||
|
|
|
|||
|
|
@ -1,5 +1,3 @@
|
|||
|
||||
|
||||
from codeflash.languages.python.context.unused_definition_remover import remove_unused_definitions_by_function_names
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -74,10 +74,7 @@ class TestCodeflashConfig:
|
|||
|
||||
def test_to_pyproject_dict_minimal(self):
|
||||
"""Should only include non-default values."""
|
||||
config = CodeflashConfig(
|
||||
language="python",
|
||||
module_root="src",
|
||||
)
|
||||
config = CodeflashConfig(language="python", module_root="src")
|
||||
|
||||
result = config.to_pyproject_dict()
|
||||
|
||||
|
|
@ -149,11 +146,7 @@ class TestCodeflashConfig:
|
|||
|
||||
def test_from_package_json_dict(self):
|
||||
"""Should create config from package.json dict."""
|
||||
data = {
|
||||
"moduleRoot": "lib",
|
||||
"formatterCmds": ["npx prettier --write $file"],
|
||||
"disableTelemetry": True,
|
||||
}
|
||||
data = {"moduleRoot": "lib", "formatterCmds": ["npx prettier --write $file"], "disableTelemetry": True}
|
||||
|
||||
config = CodeflashConfig.from_package_json_dict(data)
|
||||
|
||||
|
|
@ -168,11 +161,7 @@ class TestWritePyprojectToml:
|
|||
|
||||
def test_creates_new_pyproject(self, tmp_path):
|
||||
"""Should create pyproject.toml if it doesn't exist."""
|
||||
config = CodeflashConfig(
|
||||
language="python",
|
||||
module_root="src",
|
||||
tests_root="tests",
|
||||
)
|
||||
config = CodeflashConfig(language="python", module_root="src", tests_root="tests")
|
||||
|
||||
success, message = _write_pyproject_toml(tmp_path, config)
|
||||
|
||||
|
|
@ -192,10 +181,7 @@ class TestWritePyprojectToml:
|
|||
'[project]\nname = "myapp"\nversion = "1.0.0"\n\n[tool.ruff]\nline-length = 120'
|
||||
)
|
||||
|
||||
config = CodeflashConfig(
|
||||
language="python",
|
||||
module_root="src",
|
||||
)
|
||||
config = CodeflashConfig(language="python", module_root="src")
|
||||
|
||||
success, message = _write_pyproject_toml(tmp_path, config)
|
||||
|
||||
|
|
@ -210,15 +196,9 @@ class TestWritePyprojectToml:
|
|||
|
||||
def test_updates_existing_codeflash_section(self, tmp_path):
|
||||
"""Should update existing codeflash section."""
|
||||
(tmp_path / "pyproject.toml").write_text(
|
||||
'[tool.codeflash]\nmodule-root = "old"\ntests-root = "old_tests"'
|
||||
)
|
||||
(tmp_path / "pyproject.toml").write_text('[tool.codeflash]\nmodule-root = "old"\ntests-root = "old_tests"')
|
||||
|
||||
config = CodeflashConfig(
|
||||
language="python",
|
||||
module_root="new",
|
||||
tests_root="new_tests",
|
||||
)
|
||||
config = CodeflashConfig(language="python", module_root="new", tests_root="new_tests")
|
||||
|
||||
success, message = _write_pyproject_toml(tmp_path, config)
|
||||
|
||||
|
|
@ -235,15 +215,10 @@ class TestWritePackageJson:
|
|||
|
||||
def test_adds_codeflash_section(self, tmp_path):
|
||||
"""Should add codeflash section to package.json."""
|
||||
(tmp_path / "package.json").write_text(json.dumps({
|
||||
"name": "myapp",
|
||||
"version": "1.0.0"
|
||||
}, indent=2))
|
||||
(tmp_path / "package.json").write_text(json.dumps({"name": "myapp", "version": "1.0.0"}, indent=2))
|
||||
|
||||
config = CodeflashConfig(
|
||||
language="javascript",
|
||||
module_root="lib",
|
||||
formatter_cmds=["npx prettier --write $file"],
|
||||
language="javascript", module_root="lib", formatter_cmds=["npx prettier --write $file"]
|
||||
)
|
||||
|
||||
success, message = _write_package_json(tmp_path, config)
|
||||
|
|
@ -259,16 +234,14 @@ class TestWritePackageJson:
|
|||
|
||||
def test_preserves_existing_content(self, tmp_path):
|
||||
"""Should preserve existing package.json content."""
|
||||
(tmp_path / "package.json").write_text(json.dumps({
|
||||
"name": "myapp",
|
||||
"dependencies": {"lodash": "^4.17.0"},
|
||||
"devDependencies": {"jest": "^29.0.0"}
|
||||
}, indent=2))
|
||||
|
||||
config = CodeflashConfig(
|
||||
language="javascript",
|
||||
module_root="lib",
|
||||
(tmp_path / "package.json").write_text(
|
||||
json.dumps(
|
||||
{"name": "myapp", "dependencies": {"lodash": "^4.17.0"}, "devDependencies": {"jest": "^29.0.0"}},
|
||||
indent=2,
|
||||
)
|
||||
)
|
||||
|
||||
config = CodeflashConfig(language="javascript", module_root="lib")
|
||||
|
||||
success, message = _write_package_json(tmp_path, config)
|
||||
|
||||
|
|
@ -281,10 +254,9 @@ class TestWritePackageJson:
|
|||
|
||||
def test_removes_empty_codeflash_section(self, tmp_path):
|
||||
"""Should remove codeflash section if all defaults."""
|
||||
(tmp_path / "package.json").write_text(json.dumps({
|
||||
"name": "myapp",
|
||||
"codeflash": {"moduleRoot": "old"}
|
||||
}, indent=2))
|
||||
(tmp_path / "package.json").write_text(
|
||||
json.dumps({"name": "myapp", "codeflash": {"moduleRoot": "old"}}, indent=2)
|
||||
)
|
||||
|
||||
# Config with all defaults - should result in empty dict
|
||||
config = CodeflashConfig(
|
||||
|
|
@ -342,9 +314,7 @@ class TestRemoveConfig:
|
|||
|
||||
def test_removes_from_pyproject(self, tmp_path):
|
||||
"""Should remove codeflash section from pyproject.toml."""
|
||||
(tmp_path / "pyproject.toml").write_text(
|
||||
'[project]\nname = "test"\n\n[tool.codeflash]\nmodule-root = "src"'
|
||||
)
|
||||
(tmp_path / "pyproject.toml").write_text('[project]\nname = "test"\n\n[tool.codeflash]\nmodule-root = "src"')
|
||||
|
||||
success, message = remove_config(tmp_path, "python")
|
||||
|
||||
|
|
@ -357,10 +327,9 @@ class TestRemoveConfig:
|
|||
|
||||
def test_removes_from_package_json(self, tmp_path):
|
||||
"""Should remove codeflash section from package.json."""
|
||||
(tmp_path / "package.json").write_text(json.dumps({
|
||||
"name": "test",
|
||||
"codeflash": {"moduleRoot": "src"}
|
||||
}, indent=2))
|
||||
(tmp_path / "package.json").write_text(
|
||||
json.dumps({"name": "test", "codeflash": {"moduleRoot": "src"}}, indent=2)
|
||||
)
|
||||
|
||||
success, message = remove_config(tmp_path, "javascript")
|
||||
|
||||
|
|
|
|||
|
|
@ -141,10 +141,9 @@ class TestDetectModuleRoot:
|
|||
|
||||
def test_js_detects_from_exports(self, tmp_path):
|
||||
"""Should detect module root from package.json exports when no common src dir exists."""
|
||||
(tmp_path / "package.json").write_text(json.dumps({
|
||||
"name": "test",
|
||||
"exports": {".": "./packages/core/index.js"}
|
||||
}))
|
||||
(tmp_path / "package.json").write_text(
|
||||
json.dumps({"name": "test", "exports": {".": "./packages/core/index.js"}})
|
||||
)
|
||||
(tmp_path / "packages" / "core").mkdir(parents=True)
|
||||
|
||||
module_root, detail = _detect_js_module_root(tmp_path)
|
||||
|
|
@ -161,11 +160,9 @@ class TestDetectModuleRoot:
|
|||
|
||||
def test_js_prefers_src_over_build_src(self, tmp_path):
|
||||
"""Should prefer src/ over build/src/ even when package.json points to build/."""
|
||||
(tmp_path / "package.json").write_text(json.dumps({
|
||||
"name": "test",
|
||||
"main": "build/src/index.js",
|
||||
"module": "build/src/index.js"
|
||||
}))
|
||||
(tmp_path / "package.json").write_text(
|
||||
json.dumps({"name": "test", "main": "build/src/index.js", "module": "build/src/index.js"})
|
||||
)
|
||||
(tmp_path / "src").mkdir()
|
||||
(tmp_path / "build" / "src").mkdir(parents=True)
|
||||
|
||||
|
|
@ -175,10 +172,7 @@ class TestDetectModuleRoot:
|
|||
|
||||
def test_js_skips_build_dir_from_main(self, tmp_path):
|
||||
"""Should skip build output directories from package.json main field."""
|
||||
(tmp_path / "package.json").write_text(json.dumps({
|
||||
"name": "test",
|
||||
"main": "build/index.js"
|
||||
}))
|
||||
(tmp_path / "package.json").write_text(json.dumps({"name": "test", "main": "build/index.js"}))
|
||||
(tmp_path / "build").mkdir()
|
||||
|
||||
module_root, detail = _detect_js_module_root(tmp_path)
|
||||
|
|
@ -187,10 +181,7 @@ class TestDetectModuleRoot:
|
|||
|
||||
def test_js_skips_dist_dir_from_exports(self, tmp_path):
|
||||
"""Should skip dist output directories from package.json exports field."""
|
||||
(tmp_path / "package.json").write_text(json.dumps({
|
||||
"name": "test",
|
||||
"exports": {".": "./dist/index.js"}
|
||||
}))
|
||||
(tmp_path / "package.json").write_text(json.dumps({"name": "test", "exports": {".": "./dist/index.js"}}))
|
||||
(tmp_path / "dist").mkdir()
|
||||
|
||||
module_root, detail = _detect_js_module_root(tmp_path)
|
||||
|
|
@ -199,10 +190,7 @@ class TestDetectModuleRoot:
|
|||
|
||||
def test_js_skips_out_dir_from_module(self, tmp_path):
|
||||
"""Should skip out output directories from package.json module field."""
|
||||
(tmp_path / "package.json").write_text(json.dumps({
|
||||
"name": "test",
|
||||
"module": "out/esm/index.js"
|
||||
}))
|
||||
(tmp_path / "package.json").write_text(json.dumps({"name": "test", "module": "out/esm/index.js"}))
|
||||
(tmp_path / "out" / "esm").mkdir(parents=True)
|
||||
|
||||
module_root, detail = _detect_js_module_root(tmp_path)
|
||||
|
|
@ -211,10 +199,7 @@ class TestDetectModuleRoot:
|
|||
|
||||
def test_js_prefers_lib_over_build_dir(self, tmp_path):
|
||||
"""Should prefer lib/ over build output directories."""
|
||||
(tmp_path / "package.json").write_text(json.dumps({
|
||||
"name": "test",
|
||||
"main": "dist/index.js"
|
||||
}))
|
||||
(tmp_path / "package.json").write_text(json.dumps({"name": "test", "main": "dist/index.js"}))
|
||||
(tmp_path / "lib").mkdir()
|
||||
(tmp_path / "dist").mkdir()
|
||||
|
||||
|
|
@ -224,10 +209,7 @@ class TestDetectModuleRoot:
|
|||
|
||||
def test_js_prefers_source_over_build_dir(self, tmp_path):
|
||||
"""Should prefer source/ over build output directories."""
|
||||
(tmp_path / "package.json").write_text(json.dumps({
|
||||
"name": "test",
|
||||
"main": "build/index.js"
|
||||
}))
|
||||
(tmp_path / "package.json").write_text(json.dumps({"name": "test", "main": "build/index.js"}))
|
||||
(tmp_path / "source").mkdir()
|
||||
(tmp_path / "build").mkdir()
|
||||
|
||||
|
|
@ -237,10 +219,9 @@ class TestDetectModuleRoot:
|
|||
|
||||
def test_js_falls_back_to_valid_exports_path(self, tmp_path):
|
||||
"""Should use exports path when no common source dirs exist and path is not build output."""
|
||||
(tmp_path / "package.json").write_text(json.dumps({
|
||||
"name": "test",
|
||||
"exports": {".": "./packages/core/index.js"}
|
||||
}))
|
||||
(tmp_path / "package.json").write_text(
|
||||
json.dumps({"name": "test", "exports": {".": "./packages/core/index.js"}})
|
||||
)
|
||||
(tmp_path / "packages" / "core").mkdir(parents=True)
|
||||
|
||||
module_root, detail = _detect_js_module_root(tmp_path)
|
||||
|
|
@ -249,10 +230,7 @@ class TestDetectModuleRoot:
|
|||
|
||||
def test_js_falls_back_to_valid_main_path(self, tmp_path):
|
||||
"""Should use main path when no common source dirs exist and path is not build output."""
|
||||
(tmp_path / "package.json").write_text(json.dumps({
|
||||
"name": "test",
|
||||
"main": "packages/main/index.js"
|
||||
}))
|
||||
(tmp_path / "package.json").write_text(json.dumps({"name": "test", "main": "packages/main/index.js"}))
|
||||
(tmp_path / "packages" / "main").mkdir(parents=True)
|
||||
|
||||
module_root, detail = _detect_js_module_root(tmp_path)
|
||||
|
|
@ -261,10 +239,7 @@ class TestDetectModuleRoot:
|
|||
|
||||
def test_js_falls_back_to_valid_module_path(self, tmp_path):
|
||||
"""Should use module path when no common source dirs exist and path is not build output."""
|
||||
(tmp_path / "package.json").write_text(json.dumps({
|
||||
"name": "test",
|
||||
"module": "esm/index.js"
|
||||
}))
|
||||
(tmp_path / "package.json").write_text(json.dumps({"name": "test", "module": "esm/index.js"}))
|
||||
(tmp_path / "esm").mkdir()
|
||||
|
||||
module_root, detail = _detect_js_module_root(tmp_path)
|
||||
|
|
@ -273,12 +248,16 @@ class TestDetectModuleRoot:
|
|||
|
||||
def test_js_returns_project_root_when_all_paths_are_build_output(self, tmp_path):
|
||||
"""Should return project root when all package.json paths point to build outputs."""
|
||||
(tmp_path / "package.json").write_text(json.dumps({
|
||||
(tmp_path / "package.json").write_text(
|
||||
json.dumps(
|
||||
{
|
||||
"name": "test",
|
||||
"main": "dist/cjs/index.js",
|
||||
"module": "dist/esm/index.js",
|
||||
"exports": {".": "./build/index.js"}
|
||||
}))
|
||||
"exports": {".": "./build/index.js"},
|
||||
}
|
||||
)
|
||||
)
|
||||
(tmp_path / "dist" / "cjs").mkdir(parents=True)
|
||||
(tmp_path / "dist" / "esm").mkdir(parents=True)
|
||||
(tmp_path / "build").mkdir()
|
||||
|
|
@ -302,6 +281,7 @@ class TestIsBuildOutputDir:
|
|||
def test_detects_build_dir(self):
|
||||
"""Should detect build/ as build output."""
|
||||
from pathlib import Path
|
||||
|
||||
assert is_build_output_dir(Path("build"))
|
||||
assert is_build_output_dir(Path("build/src"))
|
||||
assert is_build_output_dir(Path("build/src/index.js"))
|
||||
|
|
@ -309,6 +289,7 @@ class TestIsBuildOutputDir:
|
|||
def test_detects_dist_dir(self):
|
||||
"""Should detect dist/ as build output."""
|
||||
from pathlib import Path
|
||||
|
||||
assert is_build_output_dir(Path("dist"))
|
||||
assert is_build_output_dir(Path("dist/esm"))
|
||||
assert is_build_output_dir(Path("dist/cjs/index.js"))
|
||||
|
|
@ -316,53 +297,62 @@ class TestIsBuildOutputDir:
|
|||
def test_detects_out_dir(self):
|
||||
"""Should detect out/ as build output."""
|
||||
from pathlib import Path
|
||||
|
||||
assert is_build_output_dir(Path("out"))
|
||||
assert is_build_output_dir(Path("out/src"))
|
||||
|
||||
def test_detects_next_dir(self):
|
||||
"""Should detect .next/ as build output."""
|
||||
from pathlib import Path
|
||||
|
||||
assert is_build_output_dir(Path(".next"))
|
||||
assert is_build_output_dir(Path(".next/static"))
|
||||
|
||||
def test_detects_nuxt_dir(self):
|
||||
"""Should detect .nuxt/ as build output."""
|
||||
from pathlib import Path
|
||||
|
||||
assert is_build_output_dir(Path(".nuxt"))
|
||||
assert is_build_output_dir(Path(".nuxt/dist"))
|
||||
|
||||
def test_detects_nested_build_dir(self):
|
||||
"""Should detect build dir nested in path."""
|
||||
from pathlib import Path
|
||||
|
||||
assert is_build_output_dir(Path("packages/build/index.js"))
|
||||
assert is_build_output_dir(Path("foo/dist/bar"))
|
||||
|
||||
def test_does_not_detect_src(self):
|
||||
"""Should not detect src/ as build output."""
|
||||
from pathlib import Path
|
||||
|
||||
assert not is_build_output_dir(Path("src"))
|
||||
assert not is_build_output_dir(Path("src/index.js"))
|
||||
|
||||
def test_does_not_detect_lib(self):
|
||||
"""Should not detect lib/ as build output."""
|
||||
from pathlib import Path
|
||||
|
||||
assert not is_build_output_dir(Path("lib"))
|
||||
assert not is_build_output_dir(Path("lib/utils"))
|
||||
|
||||
def test_does_not_detect_source(self):
|
||||
"""Should not detect source/ as build output."""
|
||||
from pathlib import Path
|
||||
|
||||
assert not is_build_output_dir(Path("source"))
|
||||
|
||||
def test_does_not_detect_packages(self):
|
||||
"""Should not detect packages/ as build output."""
|
||||
from pathlib import Path
|
||||
|
||||
assert not is_build_output_dir(Path("packages"))
|
||||
assert not is_build_output_dir(Path("packages/core"))
|
||||
|
||||
def test_does_not_detect_similar_names(self):
|
||||
"""Should not detect directories with similar but different names."""
|
||||
from pathlib import Path
|
||||
|
||||
assert not is_build_output_dir(Path("builder"))
|
||||
assert not is_build_output_dir(Path("distribution"))
|
||||
assert not is_build_output_dir(Path("output"))
|
||||
|
|
@ -417,18 +407,14 @@ class TestDetectTestRunner:
|
|||
|
||||
def test_js_detects_jest_from_deps(self, tmp_path):
|
||||
"""Should detect jest from devDependencies."""
|
||||
(tmp_path / "package.json").write_text(json.dumps({
|
||||
"devDependencies": {"jest": "^29.0.0"}
|
||||
}))
|
||||
(tmp_path / "package.json").write_text(json.dumps({"devDependencies": {"jest": "^29.0.0"}}))
|
||||
|
||||
runner, detail = _detect_js_test_runner(tmp_path)
|
||||
assert runner == "jest"
|
||||
|
||||
def test_js_detects_vitest_from_deps(self, tmp_path):
|
||||
"""Should detect vitest from devDependencies (preferred over jest)."""
|
||||
(tmp_path / "package.json").write_text(json.dumps({
|
||||
"devDependencies": {"vitest": "^1.0.0", "jest": "^29.0.0"}
|
||||
}))
|
||||
(tmp_path / "package.json").write_text(json.dumps({"devDependencies": {"vitest": "^1.0.0", "jest": "^29.0.0"}}))
|
||||
|
||||
runner, detail = _detect_js_test_runner(tmp_path)
|
||||
assert runner == "vitest"
|
||||
|
|
@ -469,9 +455,7 @@ class TestDetectFormatter:
|
|||
|
||||
def test_js_detects_prettier_from_deps(self, tmp_path):
|
||||
"""Should detect prettier from devDependencies."""
|
||||
(tmp_path / "package.json").write_text(json.dumps({
|
||||
"devDependencies": {"prettier": "^3.0.0"}
|
||||
}))
|
||||
(tmp_path / "package.json").write_text(json.dumps({"devDependencies": {"prettier": "^3.0.0"}}))
|
||||
|
||||
formatter, detail = _detect_js_formatter(tmp_path)
|
||||
assert any("prettier" in cmd for cmd in formatter)
|
||||
|
|
@ -483,9 +467,7 @@ class TestDetectProject:
|
|||
def test_detects_python_project(self, tmp_path):
|
||||
"""Should correctly detect a Python project."""
|
||||
# Create Python project structure
|
||||
(tmp_path / "pyproject.toml").write_text(
|
||||
'[project]\nname = "myapp"\n\n[tool.ruff]\nline-length = 120'
|
||||
)
|
||||
(tmp_path / "pyproject.toml").write_text('[project]\nname = "myapp"\n\n[tool.ruff]\nline-length = 120')
|
||||
(tmp_path / "myapp").mkdir()
|
||||
(tmp_path / "myapp" / "__init__.py").write_text("")
|
||||
(tmp_path / "tests").mkdir()
|
||||
|
|
@ -503,10 +485,9 @@ class TestDetectProject:
|
|||
def test_detects_javascript_project(self, tmp_path):
|
||||
"""Should correctly detect a JavaScript project."""
|
||||
# Create JS project structure
|
||||
(tmp_path / "package.json").write_text(json.dumps({
|
||||
"name": "myapp",
|
||||
"devDependencies": {"jest": "^29.0.0", "prettier": "^3.0.0"}
|
||||
}))
|
||||
(tmp_path / "package.json").write_text(
|
||||
json.dumps({"name": "myapp", "devDependencies": {"jest": "^29.0.0", "prettier": "^3.0.0"}})
|
||||
)
|
||||
(tmp_path / "src").mkdir()
|
||||
(tmp_path / "tests").mkdir()
|
||||
(tmp_path / ".git").mkdir()
|
||||
|
|
@ -523,10 +504,9 @@ class TestDetectProject:
|
|||
def test_detects_typescript_project(self, tmp_path):
|
||||
"""Should correctly detect a TypeScript project."""
|
||||
# Create TS project structure
|
||||
(tmp_path / "package.json").write_text(json.dumps({
|
||||
"name": "myapp",
|
||||
"devDependencies": {"vitest": "^1.0.0", "typescript": "^5.0.0"}
|
||||
}))
|
||||
(tmp_path / "package.json").write_text(
|
||||
json.dumps({"name": "myapp", "devDependencies": {"vitest": "^1.0.0", "typescript": "^5.0.0"}})
|
||||
)
|
||||
(tmp_path / "tsconfig.json").write_text("{}")
|
||||
(tmp_path / "src").mkdir()
|
||||
(tmp_path / ".git").mkdir()
|
||||
|
|
@ -556,9 +536,7 @@ class TestHasExistingConfig:
|
|||
|
||||
def test_detects_pyproject_config(self, tmp_path):
|
||||
"""Should detect config in pyproject.toml."""
|
||||
(tmp_path / "pyproject.toml").write_text(
|
||||
'[tool.codeflash]\nmodule-root = "src"'
|
||||
)
|
||||
(tmp_path / "pyproject.toml").write_text('[tool.codeflash]\nmodule-root = "src"')
|
||||
|
||||
has_config, config_type = has_existing_config(tmp_path)
|
||||
assert has_config is True
|
||||
|
|
@ -566,10 +544,7 @@ class TestHasExistingConfig:
|
|||
|
||||
def test_detects_package_json_config(self, tmp_path):
|
||||
"""Should detect config in package.json."""
|
||||
(tmp_path / "package.json").write_text(json.dumps({
|
||||
"name": "test",
|
||||
"codeflash": {"moduleRoot": "src"}
|
||||
}))
|
||||
(tmp_path / "package.json").write_text(json.dumps({"name": "test", "codeflash": {"moduleRoot": "src"}}))
|
||||
|
||||
has_config, config_type = has_existing_config(tmp_path)
|
||||
assert has_config is True
|
||||
|
|
|
|||
|
|
@ -31,7 +31,8 @@ from codeflash.setup import (
|
|||
def python_src_layout(tmp_path):
|
||||
"""Create a Python project with src/ layout."""
|
||||
# pyproject.toml with poetry
|
||||
(tmp_path / "pyproject.toml").write_text("""
|
||||
(tmp_path / "pyproject.toml").write_text(
|
||||
"""
|
||||
[tool.poetry]
|
||||
name = "myapp"
|
||||
version = "0.1.0"
|
||||
|
|
@ -41,7 +42,8 @@ line-length = 120
|
|||
|
||||
[tool.pytest.ini_options]
|
||||
testpaths = ["tests"]
|
||||
""".strip())
|
||||
""".strip()
|
||||
)
|
||||
|
||||
# src/myapp package
|
||||
src_dir = tmp_path / "src" / "myapp"
|
||||
|
|
@ -66,14 +68,16 @@ testpaths = ["tests"]
|
|||
@pytest.fixture
|
||||
def python_flat_layout(tmp_path):
|
||||
"""Create a Python project with flat layout (package at root)."""
|
||||
(tmp_path / "pyproject.toml").write_text("""
|
||||
(tmp_path / "pyproject.toml").write_text(
|
||||
"""
|
||||
[project]
|
||||
name = "myapp"
|
||||
version = "0.1.0"
|
||||
|
||||
[tool.black]
|
||||
line-length = 88
|
||||
""".strip())
|
||||
""".strip()
|
||||
)
|
||||
|
||||
# Package at root
|
||||
pkg_dir = tmp_path / "myapp"
|
||||
|
|
@ -93,14 +97,16 @@ line-length = 88
|
|||
@pytest.fixture
|
||||
def python_setup_py_project(tmp_path):
|
||||
"""Create a Python project with setup.py (legacy)."""
|
||||
(tmp_path / "setup.py").write_text("""
|
||||
(tmp_path / "setup.py").write_text(
|
||||
"""
|
||||
from setuptools import setup, find_packages
|
||||
setup(
|
||||
name="legacyapp",
|
||||
version="1.0.0",
|
||||
packages=find_packages(),
|
||||
)
|
||||
""".strip())
|
||||
""".strip()
|
||||
)
|
||||
|
||||
pkg_dir = tmp_path / "legacyapp"
|
||||
pkg_dir.mkdir()
|
||||
|
|
@ -114,19 +120,18 @@ setup(
|
|||
@pytest.fixture
|
||||
def javascript_npm_project(tmp_path):
|
||||
"""Create a JavaScript project with npm."""
|
||||
(tmp_path / "package.json").write_text(json.dumps({
|
||||
(tmp_path / "package.json").write_text(
|
||||
json.dumps(
|
||||
{
|
||||
"name": "my-js-app",
|
||||
"version": "1.0.0",
|
||||
"main": "src/index.js",
|
||||
"scripts": {
|
||||
"test": "jest",
|
||||
"lint": "eslint src/"
|
||||
"scripts": {"test": "jest", "lint": "eslint src/"},
|
||||
"devDependencies": {"jest": "^29.7.0", "prettier": "^3.0.0"},
|
||||
},
|
||||
"devDependencies": {
|
||||
"jest": "^29.7.0",
|
||||
"prettier": "^3.0.0"
|
||||
}
|
||||
}, indent=2))
|
||||
indent=2,
|
||||
)
|
||||
)
|
||||
|
||||
(tmp_path / "package-lock.json").write_text("{}")
|
||||
|
||||
|
|
@ -147,15 +152,17 @@ def javascript_npm_project(tmp_path):
|
|||
@pytest.fixture
|
||||
def javascript_yarn_project(tmp_path):
|
||||
"""Create a JavaScript project with yarn."""
|
||||
(tmp_path / "package.json").write_text(json.dumps({
|
||||
(tmp_path / "package.json").write_text(
|
||||
json.dumps(
|
||||
{
|
||||
"name": "yarn-app",
|
||||
"version": "1.0.0",
|
||||
"main": "lib/index.js",
|
||||
"devDependencies": {
|
||||
"jest": "^29.0.0",
|
||||
"eslint": "^8.0.0"
|
||||
}
|
||||
}, indent=2))
|
||||
"devDependencies": {"jest": "^29.0.0", "eslint": "^8.0.0"},
|
||||
},
|
||||
indent=2,
|
||||
)
|
||||
)
|
||||
|
||||
(tmp_path / "yarn.lock").write_text("# yarn lockfile")
|
||||
|
||||
|
|
@ -171,16 +178,17 @@ def javascript_yarn_project(tmp_path):
|
|||
@pytest.fixture
|
||||
def javascript_pnpm_project(tmp_path):
|
||||
"""Create a JavaScript project with pnpm."""
|
||||
(tmp_path / "package.json").write_text(json.dumps({
|
||||
(tmp_path / "package.json").write_text(
|
||||
json.dumps(
|
||||
{
|
||||
"name": "pnpm-app",
|
||||
"version": "1.0.0",
|
||||
"exports": {
|
||||
".": "./dist/index.js"
|
||||
"exports": {".": "./dist/index.js"},
|
||||
"devDependencies": {"vitest": "^1.0.0"},
|
||||
},
|
||||
"devDependencies": {
|
||||
"vitest": "^1.0.0"
|
||||
}
|
||||
}, indent=2))
|
||||
indent=2,
|
||||
)
|
||||
)
|
||||
|
||||
(tmp_path / "pnpm-lock.yaml").write_text("lockfileVersion: 5.4")
|
||||
|
||||
|
|
@ -193,14 +201,17 @@ def javascript_pnpm_project(tmp_path):
|
|||
@pytest.fixture
|
||||
def javascript_bun_project(tmp_path):
|
||||
"""Create a JavaScript project with bun."""
|
||||
(tmp_path / "package.json").write_text(json.dumps({
|
||||
(tmp_path / "package.json").write_text(
|
||||
json.dumps(
|
||||
{
|
||||
"name": "bun-app",
|
||||
"version": "1.0.0",
|
||||
"module": "src/index.ts",
|
||||
"devDependencies": {
|
||||
"bun-types": "latest"
|
||||
}
|
||||
}, indent=2))
|
||||
"devDependencies": {"bun-types": "latest"},
|
||||
},
|
||||
indent=2,
|
||||
)
|
||||
)
|
||||
|
||||
(tmp_path / "bun.lockb").write_bytes(b"bun lockfile")
|
||||
|
||||
|
|
@ -212,32 +223,35 @@ def javascript_bun_project(tmp_path):
|
|||
@pytest.fixture
|
||||
def typescript_project(tmp_path):
|
||||
"""Create a TypeScript project."""
|
||||
(tmp_path / "package.json").write_text(json.dumps({
|
||||
(tmp_path / "package.json").write_text(
|
||||
json.dumps(
|
||||
{
|
||||
"name": "ts-app",
|
||||
"version": "1.0.0",
|
||||
"main": "dist/index.js",
|
||||
"types": "dist/index.d.ts",
|
||||
"scripts": {
|
||||
"build": "tsc",
|
||||
"test": "vitest"
|
||||
"scripts": {"build": "tsc", "test": "vitest"},
|
||||
"devDependencies": {"typescript": "^5.0.0", "vitest": "^1.0.0", "@types/node": "^20.0.0"},
|
||||
},
|
||||
"devDependencies": {
|
||||
"typescript": "^5.0.0",
|
||||
"vitest": "^1.0.0",
|
||||
"@types/node": "^20.0.0"
|
||||
}
|
||||
}, indent=2))
|
||||
indent=2,
|
||||
)
|
||||
)
|
||||
|
||||
(tmp_path / "tsconfig.json").write_text(json.dumps({
|
||||
(tmp_path / "tsconfig.json").write_text(
|
||||
json.dumps(
|
||||
{
|
||||
"compilerOptions": {
|
||||
"target": "ES2020",
|
||||
"module": "commonjs",
|
||||
"outDir": "./dist",
|
||||
"rootDir": "./src",
|
||||
"strict": True
|
||||
"strict": True,
|
||||
},
|
||||
"include": ["src/**/*"]
|
||||
}, indent=2))
|
||||
"include": ["src/**/*"],
|
||||
},
|
||||
indent=2,
|
||||
)
|
||||
)
|
||||
|
||||
src_dir = tmp_path / "src"
|
||||
src_dir.mkdir()
|
||||
|
|
@ -255,7 +269,9 @@ def typescript_project(tmp_path):
|
|||
@pytest.fixture
|
||||
def typescript_react_project(tmp_path):
|
||||
"""Create a TypeScript React project (like Create React App)."""
|
||||
(tmp_path / "package.json").write_text(json.dumps({
|
||||
(tmp_path / "package.json").write_text(
|
||||
json.dumps(
|
||||
{
|
||||
"name": "react-app",
|
||||
"version": "0.1.0",
|
||||
"private": True,
|
||||
|
|
@ -263,27 +279,26 @@ def typescript_react_project(tmp_path):
|
|||
"react": "^18.2.0",
|
||||
"react-dom": "^18.2.0",
|
||||
"react-scripts": "5.0.1",
|
||||
"jest": "^29.0.0"
|
||||
"jest": "^29.0.0",
|
||||
},
|
||||
"devDependencies": {
|
||||
"@types/react": "^18.0.0",
|
||||
"@testing-library/react": "^14.0.0",
|
||||
"typescript": "^5.0.0"
|
||||
"typescript": "^5.0.0",
|
||||
},
|
||||
"scripts": {
|
||||
"start": "react-scripts start",
|
||||
"build": "react-scripts build",
|
||||
"test": "react-scripts test"
|
||||
}
|
||||
}, indent=2))
|
||||
"test": "react-scripts test",
|
||||
},
|
||||
},
|
||||
indent=2,
|
||||
)
|
||||
)
|
||||
|
||||
(tmp_path / "tsconfig.json").write_text(json.dumps({
|
||||
"compilerOptions": {
|
||||
"target": "es5",
|
||||
"lib": ["dom", "es2015"],
|
||||
"jsx": "react-jsx"
|
||||
}
|
||||
}, indent=2))
|
||||
(tmp_path / "tsconfig.json").write_text(
|
||||
json.dumps({"compilerOptions": {"target": "es5", "lib": ["dom", "es2015"], "jsx": "react-jsx"}}, indent=2)
|
||||
)
|
||||
|
||||
src_dir = tmp_path / "src"
|
||||
src_dir.mkdir()
|
||||
|
|
@ -299,7 +314,8 @@ def typescript_react_project(tmp_path):
|
|||
@pytest.fixture
|
||||
def project_with_existing_config(tmp_path):
|
||||
"""Create a project with existing codeflash config."""
|
||||
(tmp_path / "pyproject.toml").write_text("""
|
||||
(tmp_path / "pyproject.toml").write_text(
|
||||
"""
|
||||
[project]
|
||||
name = "configured-app"
|
||||
|
||||
|
|
@ -307,7 +323,8 @@ name = "configured-app"
|
|||
module-root = "src"
|
||||
tests-root = "tests"
|
||||
formatter-cmds = ["black $file"]
|
||||
""".strip())
|
||||
""".strip()
|
||||
)
|
||||
|
||||
(tmp_path / "src").mkdir()
|
||||
(tmp_path / "tests").mkdir()
|
||||
|
|
@ -319,13 +336,15 @@ formatter-cmds = ["black $file"]
|
|||
def mixed_python_js_project(tmp_path):
|
||||
"""Create a project with both Python and JS files (monorepo-like)."""
|
||||
# Python backend
|
||||
(tmp_path / "pyproject.toml").write_text("""
|
||||
(tmp_path / "pyproject.toml").write_text(
|
||||
"""
|
||||
[project]
|
||||
name = "fullstack-app"
|
||||
|
||||
[tool.codeflash]
|
||||
module-root = "backend"
|
||||
""".strip())
|
||||
""".strip()
|
||||
)
|
||||
|
||||
backend_dir = tmp_path / "backend"
|
||||
backend_dir.mkdir()
|
||||
|
|
@ -335,10 +354,7 @@ module-root = "backend"
|
|||
# JS frontend
|
||||
frontend_dir = tmp_path / "frontend"
|
||||
frontend_dir.mkdir()
|
||||
(frontend_dir / "package.json").write_text(json.dumps({
|
||||
"name": "frontend",
|
||||
"devDependencies": {"jest": "^29.0.0"}
|
||||
}))
|
||||
(frontend_dir / "package.json").write_text(json.dumps({"name": "frontend", "devDependencies": {"jest": "^29.0.0"}}))
|
||||
(frontend_dir / "src").mkdir()
|
||||
(frontend_dir / "src" / "app.js").write_text("")
|
||||
|
||||
|
|
@ -458,10 +474,7 @@ class TestE2EFirstRunCheck:
|
|||
|
||||
def test_has_existing_config_js(self, tmp_path):
|
||||
"""Should find existing config in package.json."""
|
||||
(tmp_path / "package.json").write_text(json.dumps({
|
||||
"name": "test",
|
||||
"codeflash": {"moduleRoot": "src"}
|
||||
}))
|
||||
(tmp_path / "package.json").write_text(json.dumps({"name": "test", "codeflash": {"moduleRoot": "src"}}))
|
||||
|
||||
has_config, config_type = has_existing_config(tmp_path)
|
||||
assert has_config is True
|
||||
|
|
@ -610,17 +623,9 @@ class TestE2EFirstRunExperience:
|
|||
monkeypatch.chdir(python_flat_layout)
|
||||
monkeypatch.setenv("CODEFLASH_API_KEY", "cf-test-key-12345")
|
||||
|
||||
existing_args = Namespace(
|
||||
file="myapp/core.py",
|
||||
function="process",
|
||||
custom_flag=True,
|
||||
)
|
||||
existing_args = Namespace(file="myapp/core.py", function="process", custom_flag=True)
|
||||
|
||||
result = handle_first_run(
|
||||
args=existing_args,
|
||||
skip_confirm=True,
|
||||
skip_api_key=True,
|
||||
)
|
||||
result = handle_first_run(args=existing_args, skip_confirm=True, skip_api_key=True)
|
||||
|
||||
assert result is not None
|
||||
assert result.custom_flag is True # Preserved
|
||||
|
|
@ -681,10 +686,9 @@ class TestE2EEdgeCases:
|
|||
|
||||
def test_project_without_formatter(self, tmp_path):
|
||||
"""Should handle project without detectable formatter."""
|
||||
(tmp_path / "package.json").write_text(json.dumps({
|
||||
"name": "no-formatter",
|
||||
"devDependencies": {"jest": "^29.0.0"}
|
||||
}))
|
||||
(tmp_path / "package.json").write_text(
|
||||
json.dumps({"name": "no-formatter", "devDependencies": {"jest": "^29.0.0"}})
|
||||
)
|
||||
|
||||
detected = detect_project(tmp_path)
|
||||
|
||||
|
|
@ -868,9 +872,11 @@ class TestE2ECLIFlags:
|
|||
printed_messages.append(str(msg))
|
||||
|
||||
from codeflash.cli_cmds import console
|
||||
|
||||
monkeypatch.setattr(console.console, "print", mock_print)
|
||||
|
||||
from codeflash.cli_cmds.cli import _handle_show_config
|
||||
|
||||
_handle_show_config()
|
||||
|
||||
# Verify config path is displayed
|
||||
|
|
@ -889,9 +895,11 @@ class TestE2ECLIFlags:
|
|||
printed_messages.append(str(msg))
|
||||
|
||||
from codeflash.cli_cmds import console
|
||||
|
||||
monkeypatch.setattr(console.console, "print", mock_print)
|
||||
|
||||
from codeflash.cli_cmds.cli import _handle_show_config
|
||||
|
||||
_handle_show_config()
|
||||
|
||||
# Verify no config path line is displayed
|
||||
|
|
|
|||
|
|
@ -27,19 +27,14 @@ class TestIsFirstRun:
|
|||
|
||||
def test_returns_false_when_pyproject_config_exists(self, tmp_path):
|
||||
"""Should return False when codeflash config exists in pyproject.toml."""
|
||||
(tmp_path / "pyproject.toml").write_text(
|
||||
'[tool.codeflash]\nmodule-root = "src"'
|
||||
)
|
||||
(tmp_path / "pyproject.toml").write_text('[tool.codeflash]\nmodule-root = "src"')
|
||||
|
||||
result = is_first_run(tmp_path)
|
||||
assert result is False
|
||||
|
||||
def test_returns_false_when_package_json_config_exists(self, tmp_path):
|
||||
"""Should return False when codeflash config exists in package.json."""
|
||||
(tmp_path / "package.json").write_text(json.dumps({
|
||||
"name": "test",
|
||||
"codeflash": {"moduleRoot": "src"}
|
||||
}))
|
||||
(tmp_path / "package.json").write_text(json.dumps({"name": "test", "codeflash": {"moduleRoot": "src"}}))
|
||||
|
||||
result = is_first_run(tmp_path)
|
||||
assert result is False
|
||||
|
|
@ -109,11 +104,7 @@ class TestHandleFirstRun:
|
|||
|
||||
existing_args = Namespace(custom_flag=True, module_root=None)
|
||||
|
||||
result = handle_first_run(
|
||||
args=existing_args,
|
||||
skip_confirm=True,
|
||||
skip_api_key=True,
|
||||
)
|
||||
result = handle_first_run(args=existing_args, skip_confirm=True, skip_api_key=True)
|
||||
|
||||
assert result is not None
|
||||
assert result.custom_flag is True # Preserved
|
||||
|
|
@ -229,9 +220,7 @@ class TestFirstRunIntegration:
|
|||
def test_full_python_first_run(self, tmp_path, monkeypatch):
|
||||
"""Should complete full first-run for Python project."""
|
||||
# Create Python project
|
||||
(tmp_path / "pyproject.toml").write_text(
|
||||
'[project]\nname = "myapp"\n\n[tool.ruff]\nline-length = 120'
|
||||
)
|
||||
(tmp_path / "pyproject.toml").write_text('[project]\nname = "myapp"\n\n[tool.ruff]\nline-length = 120')
|
||||
pkg_dir = tmp_path / "myapp"
|
||||
pkg_dir.mkdir()
|
||||
(pkg_dir / "__init__.py").write_text("")
|
||||
|
|
@ -257,10 +246,9 @@ class TestFirstRunIntegration:
|
|||
def test_full_javascript_first_run(self, tmp_path, monkeypatch):
|
||||
"""Should complete full first-run for JavaScript project."""
|
||||
# Create JS project
|
||||
(tmp_path / "package.json").write_text(json.dumps({
|
||||
"name": "myapp",
|
||||
"devDependencies": {"jest": "^29.0.0"}
|
||||
}, indent=2))
|
||||
(tmp_path / "package.json").write_text(
|
||||
json.dumps({"name": "myapp", "devDependencies": {"jest": "^29.0.0"}}, indent=2)
|
||||
)
|
||||
(tmp_path / "src").mkdir()
|
||||
(tmp_path / "tests").mkdir()
|
||||
|
||||
|
|
@ -277,9 +265,7 @@ class TestFirstRunIntegration:
|
|||
def test_subsequent_run_uses_saved_config(self, tmp_path, monkeypatch):
|
||||
"""After first run, subsequent runs should not trigger first-run."""
|
||||
# Create project with existing config
|
||||
(tmp_path / "pyproject.toml").write_text(
|
||||
'[tool.codeflash]\nmodule-root = "src"'
|
||||
)
|
||||
(tmp_path / "pyproject.toml").write_text('[tool.codeflash]\nmodule-root = "src"')
|
||||
|
||||
monkeypatch.chdir(tmp_path)
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue