Merge remote-tracking branch 'origin/main' into omni-java

# Conflicts:
#	.claude/rules/architecture.md
#	.claude/rules/code-style.md
#	.github/workflows/claude.yml
#	.github/workflows/duplicate-code-detector.yml
#	codeflash/api/aiservice.py
#	codeflash/cli_cmds/console.py
#	codeflash/cli_cmds/logging_config.py
#	codeflash/code_utils/deduplicate_code.py
#	codeflash/discovery/discover_unit_tests.py
#	codeflash/languages/base.py
#	codeflash/languages/code_replacer.py
#	codeflash/languages/javascript/mocha_runner.py
#	codeflash/languages/javascript/support.py
#	codeflash/languages/python/support.py
#	codeflash/optimization/function_optimizer.py
#	codeflash/verification/parse_test_output.py
#	codeflash/verification/verification_utils.py
#	codeflash/verification/verifier.py
#	packages/codeflash/package-lock.json
#	packages/codeflash/package.json
#	tests/languages/javascript/test_support_dispatch.py
#	tests/test_codeflash_capture.py
#	tests/test_languages/test_javascript_test_runner.py
#	tests/test_multi_file_code_replacement.py
This commit is contained in:
Kevin Turcios 2026-03-04 01:52:32 -05:00
commit eceac13fc3
85 changed files with 1230 additions and 2771 deletions

@@ -9,3 +9,5 @@
- **Naming**: NEVER use leading underscores (`_function_name`) - Python has no true private functions; use public names instead
- **Paths**: Always use absolute paths
- **Encoding**: Always pass `encoding="utf-8"` to `open()`, `read_text()`, `write_text()`, etc. in new or changed code — Windows defaults to `cp1252` which breaks on non-ASCII content. Don't flag pre-existing code that lacks it unless you're already modifying that line.
- **Pre-commit**: Run `uv run prek` before committing — fix any issues before creating the commit
- **Pre-push**: Before pushing, run `uv run prek run --from-ref origin/<base>` to check all changed files against the PR base — this matches CI behavior and catches issues that per-commit prek misses. To detect the base branch: `gh pr view --json baseRefName -q .baseRefName 2>/dev/null || echo main`
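The encoding rule above can be sketched standalone. This is a minimal illustration (the file name and content are invented); the point is that an explicit `encoding="utf-8"` makes reads and writes of non-ASCII text behave the same on Windows (where the default is the locale codepage, often `cp1252`) as on other platforms:

```python
import tempfile
from pathlib import Path

# Illustrative only: a temp file standing in for project files.
with tempfile.TemporaryDirectory() as d:
    p = Path(d) / "notes.txt"
    # Explicit encoding avoids the Windows cp1252 default breaking on non-ASCII.
    p.write_text("café naïve", encoding="utf-8")
    text = p.read_text(encoding="utf-8")
```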

@@ -1,8 +1,20 @@
name: Claude Code
on:
workflow_dispatch:
pull_request:
types: [opened, synchronize, ready_for_review, reopened]
paths-ignore:
- '.github/workflows/**'
- '*.md'
- 'docs/**'
- 'demos/**'
- 'experiments/**'
- 'LICENSE'
- '.tessl/**'
- 'code_to_optimize/**'
- 'codeflash.code-workspace'
- 'uv.lock'
issue_comment:
types: [created]
pull_request_review_comment:
@@ -16,10 +28,16 @@ jobs:
# Automatic PR review (can fix linting issues and push)
# Blocked for fork PRs to prevent malicious code execution
pr-review:
concurrency:
group: pr-review-${{ github.head_ref || github.run_id }}
cancel-in-progress: true
if: |
(
github.event_name == 'pull_request' &&
github.actor != 'claude[bot]' &&
github.event.sender.login != 'claude[bot]' &&
github.event.pull_request.head.repo.full_name == github.repository
) ||
github.event_name == 'workflow_dispatch'
runs-on: ubuntu-latest
permissions:
contents: write
@@ -32,7 +50,7 @@ jobs:
uses: actions/checkout@v4
with:
fetch-depth: 0
ref: ${{ github.event.pull_request.head.ref }}
ref: ${{ github.event.pull_request.head.ref || github.ref }}
- name: Install uv
uses: astral-sh/setup-uv@v6
@@ -54,7 +72,9 @@ jobs:
with:
use_bedrock: "true"
use_sticky_comment: true
track_progress: true
allowed_bots: "claude[bot],codeflash-ai[bot]"
exclude_comments_by_actor: "*[bot]"
prompt: |
<context>
repo: ${{ github.repository }}
@@ -68,6 +88,20 @@ jobs:
Post all review findings in a single summary comment only — never as inline PR review comments.
</commitment>
<step name="triage">
Before doing any work, assess the PR scope:
1. Run `gh pr diff ${{ github.event.pull_request.number }} --name-only` to get changed files.
2. Classify as TRIVIAL if ALL changed files are:
- Config/CI files (.github/, .tessl/, *.toml, *.lock, *.json, *.yml, *.yaml)
- Documentation (*.md, docs/)
- Non-production code (demos/, experiments/, code_to_optimize/)
- Only whitespace, formatting, or comment changes
If TRIVIAL: post a single comment "No substantive code changes to review." and stop — do not execute any further steps.
Otherwise: continue with the full review below.
</step>
<step name="lint_and_typecheck">
Run checks on files changed in this PR and auto-fix what you can.
@@ -109,6 +143,33 @@ jobs:
Record findings for the summary comment. Refer to CLAUDE.md for project conventions.
</step>
<step name="duplicate_detection">
Check whether this PR introduces code that duplicates logic already present elsewhere in the repository — including across languages. Focus on finding true duplicates, not just similar-looking code.
1. Get changed source files (excluding tests and config):
`git diff --name-only origin/main...HEAD -- '*.py' '*.js' '*.ts' '*.java' | grep -v -E '(test_|_test\.(py|js|ts)|\.test\.(js|ts)|\.spec\.(js|ts)|conftest\.py|/tests/|/test/|/__tests__/)' | grep -v -E '^(\.github/|code_to_optimize/|\.tessl/|node_modules/)'`
2. For each changed file, read it and identify functions/methods added or substantially modified (longer than 5 lines).
3. Search for duplicates using Grep:
- Same function name defined elsewhere
- 2-3 distinctive operations from the body (specific API calls, algorithm patterns, string literals)
4. Cross-module check: this codebase has parallel modules under `languages/python/`, `languages/javascript/`, and `languages/java/` plus runtimes under `packages/codeflash/runtime/` and `codeflash-java-runtime/`. When a changed file is under one of these areas, search the others for equivalent logic. Only flag cases where the logic is genuinely shared or one module could import from the other.
5. When a Grep hit looks promising, read the full function and compare semantics. Flag only:
- Same function with same/very similar body in another module
- Same helper logic repeated in sibling files
- Same logic implemented inline across multiple classes
- Same algorithm reimplemented across language modules (Python code, not target-language differences)
Report at most 5 findings with confidence (HIGH/MEDIUM), locations, what's duplicated, and suggestion.
DO NOT report: boilerplate, functions under 5 lines, config/setup, intentional polymorphism, test files, imports, code that must differ due to target-language semantics.
If no duplicates found, include "No duplicates detected" in the summary.
</step>
<step name="coverage">
Analyze test coverage for changed files:
@@ -120,30 +181,17 @@ jobs:
</step>
<step name="summary_comment">
Post exactly one summary comment containing all results from previous steps.
Post exactly one summary comment containing all results from previous steps using this format:
To ensure one comment: find an existing claude[bot] comment and update it, or create one if none exists.
Delete any duplicate claude[bot] comments.
```
gh api repos/${{ github.repository }}/issues/${{ github.event.pull_request.number }}/comments --jq '.[] | select(.user.login == "claude[bot]") | .id' | head -1
```
Format:
## PR Review Summary
### Prek Checks
### Code Review
### Duplicate Detection
### Test Coverage
---
*Last updated: <timestamp>*
</step>
<step name="simplify">
Run /simplify to review recently changed code for reuse, quality, and efficiency opportunities.
If improvements are found, commit with "refactor: simplify <description>" and push.
Only make behavior-preserving changes.
</step>
<step name="merge_optimization_prs">
Check for open PRs from codeflash-ai[bot]:
`gh pr list --author "codeflash-ai[bot]" --state open --json number,title,headRefName,createdAt,mergeable`
@@ -165,12 +213,15 @@ jobs:
- All findings are in a single summary comment (no inline review comments were created)
- If fixes were made, they were verified with prek
</verification>
claude_args: '--model us.anthropic.claude-opus-4-6-v1 --allowedTools "Bash(gh pr comment:*),Bash(gh pr diff:*),Bash(gh pr view:*),Bash(gh pr list:*),Bash(gh pr checks:*),Bash(gh pr merge:*),Bash(gh issue view:*),Bash(gh issue list:*),Bash(gh api:*),Bash(uv run prek *),Bash(uv run mypy *),Bash(uv run coverage *),Bash(uv run pytest *),Bash(git status*),Bash(git add *),Bash(git commit *),Bash(git push*),Bash(git diff *),Bash(git checkout *),Read,Glob,Grep,Edit,Skill"'
claude_args: '--model us.anthropic.claude-sonnet-4-6 --allowedTools "Bash(gh pr comment:*),Bash(gh pr diff:*),Bash(gh pr view:*),Bash(gh pr list:*),Bash(gh pr checks:*),Bash(gh pr merge:*),Bash(gh pr close:*),Bash(gh issue view:*),Bash(gh issue list:*),Bash(gh api:*),Bash(uv run prek *),Bash(uv run mypy *),Bash(uv run coverage *),Bash(uv run pytest *),Bash(git status*),Bash(git add *),Bash(git commit *),Bash(git push*),Bash(git diff *),Bash(git checkout *),Read,Glob,Grep,Edit"'
additional_permissions: |
actions: read
# @claude mentions (can edit and push) - restricted to maintainers only
claude-mention:
concurrency:
group: claude-mention-${{ github.event.issue.number || github.event.pull_request.number || github.run_id }}
cancel-in-progress: false
if: |
(
github.event_name == 'issue_comment' &&
@@ -240,6 +291,6 @@ jobs:
uses: anthropics/claude-code-action@v1
with:
use_bedrock: "true"
claude_args: '--model us.anthropic.claude-opus-4-6-v1 --allowedTools "Read,Edit,Write,Glob,Grep,Bash(git status*),Bash(git diff*),Bash(git add *),Bash(git commit *),Bash(git push*),Bash(git log*),Bash(git merge*),Bash(git fetch*),Bash(git checkout*),Bash(git branch*),Bash(uv run prek *),Bash(prek *),Bash(uv run ruff *),Bash(uv run pytest *),Bash(uv run mypy *),Bash(uv run coverage *),Bash(gh pr comment*),Bash(gh pr view*),Bash(gh pr diff*),Bash(gh pr merge*),Bash(gh pr close*)"'
claude_args: '--model us.anthropic.claude-sonnet-4-6 --allowedTools "Read,Edit,Write,Glob,Grep,Bash(git status*),Bash(git diff*),Bash(git add *),Bash(git commit *),Bash(git push*),Bash(git log*),Bash(git merge*),Bash(git fetch*),Bash(git checkout*),Bash(git branch*),Bash(uv run prek *),Bash(prek *),Bash(uv run ruff *),Bash(uv run pytest *),Bash(uv run mypy *),Bash(uv run coverage *),Bash(gh pr comment*),Bash(gh pr view*),Bash(gh pr diff*),Bash(gh pr merge*),Bash(gh pr close*)"'
additional_permissions: |
actions: read

@@ -1,119 +0,0 @@
name: Duplicate Code Detector
on:
workflow_dispatch:
pull_request:
types: [opened, synchronize]
jobs:
detect-duplicates:
if: github.event.pull_request.head.repo.full_name == github.repository || github.event_name == 'workflow_dispatch'
runs-on: ubuntu-latest
permissions:
contents: read
pull-requests: write
issues: write
id-token: write
steps:
- name: Checkout repository
uses: actions/checkout@v4
with:
fetch-depth: 0
ref: ${{ github.event.pull_request.head.ref || github.ref }}
- name: Configure AWS Credentials
uses: aws-actions/configure-aws-credentials@v4
with:
role-to-assume: ${{ secrets.AWS_ROLE_TO_ASSUME }}
aws-region: ${{ secrets.AWS_REGION }}
- name: Get changed source files
id: changed-files
run: |
FILES=$(git diff --name-only origin/main...HEAD -- '*.py' '*.js' '*.ts' '*.java' \
| grep -v -E '(test_|_test\.(py|js|ts)|\.test\.(js|ts)|\.spec\.(js|ts)|conftest\.py|/tests/|/test/|/__tests__/)' \
| grep -v -E '^(\.github/|code_to_optimize/|\.tessl/|node_modules/)' \
|| true)
if [ -z "$FILES" ]; then
echo "files=" >> "$GITHUB_OUTPUT"
echo "No changed source files to analyze."
else
echo "files<<EOF" >> "$GITHUB_OUTPUT"
echo "$FILES" >> "$GITHUB_OUTPUT"
echo "EOF" >> "$GITHUB_OUTPUT"
echo "Changed files:"
echo "$FILES"
fi
- name: Run Claude Code
if: steps.changed-files.outputs.files != ''
uses: anthropics/claude-code-action@v1
with:
use_bedrock: "true"
use_sticky_comment: true
allowed_bots: "claude[bot],codeflash-ai[bot]"
claude_args: '--allowedTools "Read,Glob,Grep,Bash(git diff:*),Bash(git log:*),Bash(git show:*),Bash(wc *),Bash(gh pr comment:*)"'
prompt: |
REPO: ${{ github.repository }}
PR NUMBER: ${{ github.event.pull_request.number }}
You are a duplicate code detector for a multi-language codebase (Python, JavaScript, TypeScript, Java). Check whether this PR introduces code that duplicates logic already present elsewhere in the repository — including across languages. Focus on finding true duplicates, not just similar-looking code.
## Changed files
```
${{ steps.changed-files.outputs.files }}
```
## Steps
1. **Read changed files.** For each file above, read it and identify functions or methods that were added or substantially modified (longer than 5 lines).
2. **Search for duplicates.** For each function, use Grep to search the codebase for:
- The same function name defined elsewhere (`def function_name` for Python, `function function_name` / `const function_name` / `module.exports` for the JS files under `packages/`)
- 2-3 distinctive operations from the body (specific API calls, algorithm patterns, string literals, exception types) — this catches duplicates that have different names but implement the same logic
3. **Cross-module check.** This codebase has parallel Python modules under `languages/python/`, `languages/javascript/`, and `languages/java/` that handle the same concerns (parsing, code replacement, test running, etc.) for different target languages. It also has a JS runtime under `packages/codeflash/runtime/` and a Java runtime under `codeflash-java-runtime/`. When a changed file is under one of these areas, also search the others for equivalent logic. For example:
- `languages/javascript/code_replacer.py` and `languages/python/static_analysis/code_replacer.py` both handle code replacement — shared logic should be extracted
- Shared concepts (AST traversal, scope analysis, import resolution, test running) are prime candidates for duplication across these modules
4. **Compare candidates.** When a Grep hit looks promising (not just a shared import or call site), read the full function and compare semantics. Flag it only if it matches one of these patterns:
- **Same function in two modules** — a function with the same or very similar body exists in another module. One should import from the other instead (within the same language).
- **Shared logic across sibling files** — the same helper logic repeated in files within the same package. Should be extracted to a common module.
- **Repeated pattern across classes** — multiple classes implement the same logic inline (e.g., identical traversal, identical validation). Should be a mixin or shared helper.
- **Cross-module reimplementation** — the same algorithm or utility implemented in both `languages/python/` and `languages/javascript/` (both are Python) or between Python orchestration code and JS runtime code in `packages/`. Note: some duplication is unavoidable (each target language needs its own parser, for example). Only flag cases where the logic is genuinely shared or where one module could import from the other.
5. **Report findings.** Post a single PR comment. Report at most 5 findings.
**If duplicates found**, for each one:
- **Confidence**: HIGH (identical or near-identical logic) / MEDIUM (same intent, minor differences worth reviewing)
- **Locations**: `file_path:line_number` for both the new and existing code
- **What's duplicated**: One sentence describing the shared logic
- **Suggestion**: How to consolidate — import from canonical location, extract to shared module, create a mixin. For cross-module duplicates (between language directories or Python↔JS runtime), just flag it for a tech lead to review rather than prescribing a specific fix.
**If no duplicates found**, post a comment that just says "No duplicates detected." so the sticky comment gets updated.
## Examples (illustrative — these are past cases, some already resolved)
**IS a duplicate (HIGH):** A 12-line `is_build_output_dir()` function was defined identically in two modules (`setup/detector.py` and `code_utils/config_js.py`). Fix: delete one, import from the other.
**IS a duplicate (MEDIUM):** `is_assignment_used()` was implemented separately in two context files with the same logic. Fix: move to a shared module, import from both call sites.
**IS a duplicate (MEDIUM, cross-module):** `normalize_path()` implemented in both `languages/python/support.py` and `languages/javascript/support.py` with identical logic. Flagging for tech lead review — should likely be extracted to `languages/base.py` or a shared utility.
**NOT a duplicate:** Two classes each define a `visit()` method that traverses an AST, but they handle different node types and produce different outputs. This is intentional polymorphism.
**NOT a duplicate (cross-module):** `languages/python/static_analysis/code_extractor.py` and `languages/javascript/parse.py` both extract functions from source code, but they use fundamentally different parsing strategies (Python AST vs tree-sitter). The logic is necessarily different.
## DO NOT report
- Standard boilerplate (`__init__`, `__repr__`, `__str__`, `__eq__`, simple property accessors, constructors)
- Functions under 5 lines
- Config/setup code that naturally has similar structure
- Intentional polymorphism (same method name, genuinely different behavior)
- Test files, conftest files, spec files
- Import statements and logging setup
- Files under `.github/`, `code_to_optimize/`, `.tessl/`
- Code across language modules that must differ due to target-language semantics (parsers, AST node types, runtime-specific APIs)
Do NOT create issues or edit any files. Only post a PR comment.
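The exclusion filters in the workflow's `changed-files` step can be exercised in isolation. This sketch pipes a fabricated file list (names invented for illustration) through the same two `grep -v -E` filters used above: the first drops test files, the second drops non-production directories.

```shell
#!/bin/sh
# Fabricated list of changed files for illustration.
FILES=$(printf '%s\n' \
  'codeflash/languages/python/support.py' \
  'tests/test_codeflash_capture.py' \
  '.github/workflows/claude.yml' \
  'codeflash/code_utils/deduplicate_code.py')

# Same exclusions as the workflow step: tests first, then
# config/demo directories.
KEPT=$(echo "$FILES" \
  | grep -v -E '(test_|_test\.(py|js|ts)|\.test\.(js|ts)|\.spec\.(js|ts)|conftest\.py|/tests/|/test/|/__tests__/)' \
  | grep -v -E '^(\.github/|code_to_optimize/|\.tessl/|node_modules/)' \
  || true)

echo "$KEPT"  # only the two production source files survive
```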

@@ -1,126 +0,0 @@
"""Code deduplication utilities using language-specific normalizers.
This module provides functions to normalize code, generate fingerprints,
and detect duplicate code segments across different programming languages.
"""
from __future__ import annotations
import hashlib
import re
from codeflash.code_utils.normalizers import get_normalizer
from codeflash.languages import current_language
def normalize_code(
code: str, remove_docstrings: bool = True, return_ast_dump: bool = False, language: str | None = None
) -> str:
"""Normalize code by parsing, cleaning, and normalizing variable names.
Function names, class names, and parameters are preserved.
Args:
code: Source code as string
remove_docstrings: Whether to remove docstrings (Python only)
return_ast_dump: Return AST dump instead of unparsed code (Python only)
language: Language of the code. If None, uses the current session language.
Returns:
Normalized code as string
"""
if language is None:
language = current_language().value
try:
normalizer = get_normalizer(language)
if return_ast_dump:
return normalizer.normalize_for_hash(code)
# Only Python normalizer accepts remove_docstrings; pass it via **kwargs
# so non-Python normalizers (which don't accept it) still work
try:
return normalizer.normalize(code, remove_docstrings=remove_docstrings)
except TypeError:
return normalizer.normalize(code)
except ValueError:
# Unknown language - fall back to basic normalization
return _basic_normalize(code)
except Exception:
# Parsing error - try other languages or fall back
if language == "python":
# Try JavaScript as fallback
try:
js_normalizer = get_normalizer("javascript")
js_result = js_normalizer.normalize(code)
if js_result != _basic_normalize(code):
return js_result
except Exception:
pass
return _basic_normalize(code)
def _basic_normalize(code: str) -> str:
"""Basic normalization: remove comments and normalize whitespace."""
# Remove single-line comments (// and #)
code = re.sub(r"//.*$", "", code, flags=re.MULTILINE)
code = re.sub(r"#.*$", "", code, flags=re.MULTILINE)
# Remove multi-line comments
code = re.sub(r"/\*.*?\*/", "", code, flags=re.DOTALL)
code = re.sub(r'""".*?"""', "", code, flags=re.DOTALL)
code = re.sub(r"'''.*?'''", "", code, flags=re.DOTALL)
# Normalize whitespace
return " ".join(code.split())
def get_code_fingerprint(code: str, language: str | None = None) -> str:
"""Generate a fingerprint for normalized code.
Args:
code: Source code
language: Language of the code. If None, uses the current session language.
Returns:
SHA-256 hash of normalized code
"""
if language is None:
language = current_language().value
try:
normalizer = get_normalizer(language)
return normalizer.get_fingerprint(code)
except ValueError:
# Unknown language - use basic normalization
normalized = _basic_normalize(code)
return hashlib.sha256(normalized.encode()).hexdigest()
def are_codes_duplicate(code1: str, code2: str, language: str | None = None) -> bool:
"""Check if two code segments are duplicates after normalization.
Args:
code1: First code segment
code2: Second code segment
language: Language of the code. If None, uses the current session language.
Returns:
True if codes are structurally identical (ignoring local variable names)
"""
if language is None:
language = current_language().value
try:
normalizer = get_normalizer(language)
return normalizer.are_duplicates(code1, code2)
except ValueError:
# Unknown language - use basic comparison
return _basic_normalize(code1) == _basic_normalize(code2)
except Exception:
return False
# Re-export for backward compatibility
__all__ = ["are_codes_duplicate", "get_code_fingerprint", "normalize_code"]
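The `_basic_normalize` fallback above can be exercised on its own. This sketch copies its regexes into a standalone function (no codeflash imports) to show that two snippets differing only in comments and whitespace compare equal; note the fallback does not rename variables, so it only catches near-textual duplicates.

```python
import re

def basic_normalize(code: str) -> str:
    # Mirrors _basic_normalize: strip comments, collapse whitespace.
    code = re.sub(r"//.*$", "", code, flags=re.MULTILINE)
    code = re.sub(r"#.*$", "", code, flags=re.MULTILINE)
    code = re.sub(r"/\*.*?\*/", "", code, flags=re.DOTALL)
    code = re.sub(r'""".*?"""', "", code, flags=re.DOTALL)
    code = re.sub(r"'''.*?'''", "", code, flags=re.DOTALL)
    return " ".join(code.split())

a = "def f(x):\n    # add one\n    return x + 1\n"
b = "def f(x):  return x + 1  # increment"
assert basic_normalize(a) == basic_normalize(b)
```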

@@ -1,106 +0,0 @@
"""Code normalizers for different programming languages.
This module provides language-specific code normalizers that transform source code
into canonical forms for duplicate detection. The normalizers:
- Replace local variable names with canonical forms (var_0, var_1, etc.)
- Preserve function names, class names, parameters, and imports
- Remove or normalize comments and docstrings
- Produce consistent output for structurally identical code
Usage:
>>> normalizer = get_normalizer("python")
>>> normalized = normalizer.normalize(code)
>>> fingerprint = normalizer.get_fingerprint(code)
>>> are_same = normalizer.are_duplicates(code1, code2)
"""
from __future__ import annotations
from codeflash.code_utils.normalizers.base import CodeNormalizer
from codeflash.code_utils.normalizers.javascript import JavaScriptNormalizer, TypeScriptNormalizer
from codeflash.code_utils.normalizers.python import PythonNormalizer
__all__ = [
"CodeNormalizer",
"JavaScriptNormalizer",
"PythonNormalizer",
"TypeScriptNormalizer",
"get_normalizer",
"get_normalizer_for_extension",
]
# Registry of normalizers by language
_NORMALIZERS: dict[str, type[CodeNormalizer]] = {
"python": PythonNormalizer,
"javascript": JavaScriptNormalizer,
"typescript": TypeScriptNormalizer,
}
# Singleton cache for normalizer instances
_normalizer_instances: dict[str, CodeNormalizer] = {}
def get_normalizer(language: str) -> CodeNormalizer:
"""Get a code normalizer for the specified language.
Args:
language: Language name ('python', 'javascript', 'typescript')
Returns:
CodeNormalizer instance for the language
Raises:
ValueError: If no normalizer exists for the language
"""
language = language.lower()
# Check cache first
if language in _normalizer_instances:
return _normalizer_instances[language]
# Get normalizer class
if language not in _NORMALIZERS:
supported = ", ".join(sorted(_NORMALIZERS.keys()))
msg = f"No normalizer available for language '{language}'. Supported: {supported}"
raise ValueError(msg)
# Create and cache instance
normalizer = _NORMALIZERS[language]()
_normalizer_instances[language] = normalizer
return normalizer
def get_normalizer_for_extension(extension: str) -> CodeNormalizer | None:
"""Get a code normalizer based on file extension.
Args:
extension: File extension including dot (e.g., '.py', '.js')
Returns:
CodeNormalizer instance if found, None otherwise
"""
extension = extension.lower()
if not extension.startswith("."):
extension = f".{extension}"
for language in _NORMALIZERS:
normalizer = get_normalizer(language)
if extension in normalizer.supported_extensions:
return normalizer
return None
def register_normalizer(language: str, normalizer_class: type[CodeNormalizer]) -> None:
"""Register a new normalizer for a language.
Args:
language: Language name
normalizer_class: CodeNormalizer subclass
"""
_NORMALIZERS[language.lower()] = normalizer_class
# Clear cached instance if it exists
_normalizer_instances.pop(language.lower(), None)
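The registry-plus-singleton-cache pattern used by `get_normalizer` and `register_normalizer` can be sketched in miniature. This is a toy version (the `UpperNormalizer` class is fabricated for illustration) showing the two invariants: one cached instance per language, and registration evicting any stale cached instance.

```python
# Toy registry mirroring the pattern above.
_REGISTRY: dict[str, type] = {}
_instances: dict[str, object] = {}

class UpperNormalizer:
    def normalize(self, code: str) -> str:
        return " ".join(code.upper().split())

def register(language: str, cls: type) -> None:
    _REGISTRY[language.lower()] = cls
    # Drop any stale cached instance, as register_normalizer does.
    _instances.pop(language.lower(), None)

def get(language: str):
    language = language.lower()
    if language in _instances:
        return _instances[language]
    if language not in _REGISTRY:
        raise ValueError(f"No normalizer for {language!r}")
    inst = _REGISTRY[language]()
    _instances[language] = inst
    return inst

register("toy", UpperNormalizer)
assert get("toy") is get("TOY")  # one instance, case-insensitive lookup
```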

@@ -1,104 +0,0 @@
"""Abstract base class for code normalizers.
Code normalizers transform source code into a canonical form for duplicate detection.
They normalize variable names, remove comments/docstrings, and produce consistent output
that can be compared across different implementations of the same algorithm.
"""
# TODO:{claude} move to base.py in language folder
from __future__ import annotations
from abc import ABC, abstractmethod
class CodeNormalizer(ABC):
"""Abstract base class for language-specific code normalizers.
Subclasses must implement the normalize() method for their specific language.
The normalization should:
- Normalize local variable names to canonical forms (var_0, var_1, etc.)
- Preserve function names, class names, parameters, and imports
- Remove or normalize comments and docstrings
- Produce consistent output for structurally identical code
Example:
>>> normalizer = PythonNormalizer()
>>> code1 = "def foo(x): y = x + 1; return y"
>>> code2 = "def foo(x): z = x + 1; return z"
>>> normalizer.normalize(code1) == normalizer.normalize(code2)
True
"""
@property
@abstractmethod
def language(self) -> str:
"""Return the language this normalizer handles."""
...
@property
def supported_extensions(self) -> tuple[str, ...]:
"""Return file extensions this normalizer can handle."""
return ()
@abstractmethod
def normalize(self, code: str) -> str:
"""Normalize code to a canonical form for comparison.
Args:
code: Source code to normalize
Returns:
Normalized representation of the code
"""
...
@abstractmethod
def normalize_for_hash(self, code: str) -> str:
"""Normalize code optimized for hashing/fingerprinting.
This may return a more compact representation than normalize().
Args:
code: Source code to normalize
Returns:
Normalized representation suitable for hashing
"""
...
def are_duplicates(self, code1: str, code2: str) -> bool:
"""Check if two code segments are duplicates after normalization.
Args:
code1: First code segment
code2: Second code segment
Returns:
True if codes are structurally identical
"""
try:
normalized1 = self.normalize_for_hash(code1)
normalized2 = self.normalize_for_hash(code2)
except Exception:
return False
else:
return normalized1 == normalized2
def get_fingerprint(self, code: str) -> str:
"""Generate a fingerprint hash for normalized code.
Args:
code: Source code to fingerprint
Returns:
SHA-256 hash of normalized code
"""
import hashlib
normalized = self.normalize_for_hash(code)
return hashlib.sha256(normalized.encode()).hexdigest()
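The template-method design above (concrete `are_duplicates`/`get_fingerprint` on top of an abstract `normalize_for_hash`) can be demonstrated with a minimal subclass. This sketch inlines a trimmed copy of the base-class contract so it is self-contained; `WhitespaceNormalizer` is a toy, not a real codeflash normalizer.

```python
import hashlib
from abc import ABC, abstractmethod

# Trimmed copy of the CodeNormalizer contract, for a self-contained demo.
class CodeNormalizer(ABC):
    @abstractmethod
    def normalize_for_hash(self, code: str) -> str: ...

    def are_duplicates(self, code1: str, code2: str) -> bool:
        try:
            return self.normalize_for_hash(code1) == self.normalize_for_hash(code2)
        except Exception:
            return False

    def get_fingerprint(self, code: str) -> str:
        return hashlib.sha256(self.normalize_for_hash(code).encode()).hexdigest()

class WhitespaceNormalizer(CodeNormalizer):
    """Toy normalizer: identity up to whitespace only."""
    def normalize_for_hash(self, code: str) -> str:
        return " ".join(code.split())

n = WhitespaceNormalizer()
assert n.are_duplicates("return  x + 1", "return x + 1")
assert len(n.get_fingerprint("pass")) == 64  # hex-encoded SHA-256
```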

@@ -1,290 +0,0 @@
"""JavaScript/TypeScript code normalizer using tree-sitter."""
from __future__ import annotations
import re
from typing import TYPE_CHECKING
from codeflash.code_utils.normalizers.base import CodeNormalizer
if TYPE_CHECKING:
from tree_sitter import Node
# TODO:{claude} move to language support directory to keep the directory structure clean
class JavaScriptVariableNormalizer:
"""Normalizes JavaScript/TypeScript code for duplicate detection using tree-sitter.
Normalizes local variable names while preserving function names, class names,
parameters, and imported names.
"""
def __init__(self) -> None:
self.var_counter = 0
self.var_mapping: dict[str, str] = {}
self.preserved_names: set[str] = set()
# Common JavaScript builtins
self.builtins = {
"console",
"window",
"document",
"Math",
"JSON",
"Object",
"Array",
"String",
"Number",
"Boolean",
"Date",
"RegExp",
"Error",
"Promise",
"Map",
"Set",
"WeakMap",
"WeakSet",
"Symbol",
"Proxy",
"Reflect",
"undefined",
"null",
"NaN",
"Infinity",
"globalThis",
"parseInt",
"parseFloat",
"isNaN",
"isFinite",
"eval",
"setTimeout",
"setInterval",
"clearTimeout",
"clearInterval",
"fetch",
"require",
"module",
"exports",
"process",
"__dirname",
"__filename",
"Buffer",
}
def get_normalized_name(self, name: str) -> str:
"""Get or create normalized name for a variable."""
if name in self.builtins or name in self.preserved_names:
return name
if name not in self.var_mapping:
self.var_mapping[name] = f"var_{self.var_counter}"
self.var_counter += 1
return self.var_mapping[name]
def collect_preserved_names(self, node: Node, source_code: bytes) -> None:
"""Collect names that should be preserved (function names, class names, imports, params)."""
# Function declarations and expressions - preserve the function name
if node.type in ("function_declaration", "function_expression", "method_definition", "arrow_function"):
name_node = node.child_by_field_name("name")
if name_node:
self.preserved_names.add(source_code[name_node.start_byte : name_node.end_byte].decode("utf-8"))
# Preserve parameters
params_node = node.child_by_field_name("parameters") or node.child_by_field_name("parameter")
if params_node:
self._collect_parameter_names(params_node, source_code)
# Class declarations
elif node.type == "class_declaration":
name_node = node.child_by_field_name("name")
if name_node:
self.preserved_names.add(source_code[name_node.start_byte : name_node.end_byte].decode("utf-8"))
# Import declarations
elif node.type in ("import_statement", "import_declaration"):
for child in node.children:
if child.type == "import_clause":
self._collect_import_names(child, source_code)
elif child.type == "identifier":
self.preserved_names.add(source_code[child.start_byte : child.end_byte].decode("utf-8"))
# Recurse
for child in node.children:
self.collect_preserved_names(child, source_code)
def _collect_parameter_names(self, node: Node, source_code: bytes) -> None:
"""Collect parameter names from a parameters node."""
for child in node.children:
if child.type == "identifier":
self.preserved_names.add(source_code[child.start_byte : child.end_byte].decode("utf-8"))
elif child.type in ("required_parameter", "optional_parameter", "rest_parameter"):
pattern_node = child.child_by_field_name("pattern")
if pattern_node and pattern_node.type == "identifier":
self.preserved_names.add(
source_code[pattern_node.start_byte : pattern_node.end_byte].decode("utf-8")
)
# Recurse for nested patterns
self._collect_parameter_names(child, source_code)
def _collect_import_names(self, node: Node, source_code: bytes) -> None:
"""Collect imported names from import clause."""
for child in node.children:
if child.type == "identifier":
self.preserved_names.add(source_code[child.start_byte : child.end_byte].decode("utf-8"))
elif child.type == "import_specifier":
# Get the local name (alias or original)
alias_node = child.child_by_field_name("alias")
name_node = child.child_by_field_name("name")
if alias_node:
self.preserved_names.add(source_code[alias_node.start_byte : alias_node.end_byte].decode("utf-8"))
elif name_node:
self.preserved_names.add(source_code[name_node.start_byte : name_node.end_byte].decode("utf-8"))
self._collect_import_names(child, source_code)
def normalize_tree(self, node: Node, source_code: bytes) -> str:
"""Normalize the AST tree to a string representation for comparison."""
parts: list[str] = []
self._normalize_node(node, source_code, parts)
return " ".join(parts)
def _normalize_node(self, node: Node, source_code: bytes, parts: list[str]) -> None:
"""Recursively normalize a node."""
# Skip comments
if node.type in ("comment", "line_comment", "block_comment"):
return
# Handle identifiers - normalize variable names
if node.type == "identifier":
name = source_code[node.start_byte : node.end_byte].decode("utf-8")
normalized = self.get_normalized_name(name)
parts.append(normalized)
return
# Handle type identifiers (TypeScript) - preserve as-is
if node.type == "type_identifier":
parts.append(source_code[node.start_byte : node.end_byte].decode("utf-8"))
return
# Handle string literals - normalize to placeholder
if node.type in ("string", "template_string", "string_fragment"):
parts.append('"STR"')
return
# Handle number literals - normalize to placeholder
if node.type == "number":
parts.append("NUM")
return
# For leaf nodes, output the node type
if len(node.children) == 0:
text = source_code[node.start_byte : node.end_byte].decode("utf-8")
parts.append(text)
return
# Output node type for structure
parts.append(f"({node.type}")
# Recurse into children
for child in node.children:
self._normalize_node(child, source_code, parts)
parts.append(")")
def _basic_normalize(code: str) -> str:
"""Basic normalization: remove comments and normalize whitespace."""
# Remove single-line comments
code = re.sub(r"//.*$", "", code, flags=re.MULTILINE)
# Remove multi-line comments
code = re.sub(r"/\*.*?\*/", "", code, flags=re.DOTALL)
# Normalize whitespace
return " ".join(code.split())
class JavaScriptNormalizer(CodeNormalizer):
"""JavaScript code normalizer using tree-sitter.
Normalizes JavaScript code by:
- Replacing local variable names with canonical forms (var_0, var_1, etc.)
- Preserving function names, class names, parameters, and imports
- Removing comments
- Normalizing string and number literals
"""
@property
def language(self) -> str:
"""Return the language this normalizer handles."""
return "javascript"
@property
def supported_extensions(self) -> tuple[str, ...]:
"""Return file extensions this normalizer can handle."""
return (".js", ".jsx", ".mjs", ".cjs")
def _get_tree_sitter_language(self) -> str:
"""Get the tree-sitter language identifier."""
return "javascript"
def normalize(self, code: str) -> str:
"""Normalize JavaScript code to a canonical form.
Args:
code: JavaScript source code to normalize
Returns:
Normalized representation of the code
"""
try:
from codeflash.languages.javascript.treesitter import TreeSitterAnalyzer, TreeSitterLanguage
lang_map = {"javascript": TreeSitterLanguage.JAVASCRIPT, "typescript": TreeSitterLanguage.TYPESCRIPT}
lang = lang_map.get(self._get_tree_sitter_language(), TreeSitterLanguage.JAVASCRIPT)
analyzer = TreeSitterAnalyzer(lang)
tree = analyzer.parse(code)
if tree.root_node.has_error:
return _basic_normalize(code)
normalizer = JavaScriptVariableNormalizer()
source_bytes = code.encode("utf-8")
# First pass: collect preserved names
normalizer.collect_preserved_names(tree.root_node, source_bytes)
# Second pass: normalize and build representation
return normalizer.normalize_tree(tree.root_node, source_bytes)
except Exception:
return _basic_normalize(code)
def normalize_for_hash(self, code: str) -> str:
"""Normalize JavaScript code optimized for hashing.
For JavaScript, this is the same as normalize().
Args:
code: JavaScript source code to normalize
Returns:
Normalized representation suitable for hashing
"""
return self.normalize(code)
class TypeScriptNormalizer(JavaScriptNormalizer):
"""TypeScript code normalizer using tree-sitter.
Inherits from JavaScriptNormalizer and overrides language-specific settings.
"""
@property
def language(self) -> str:
"""Return the language this normalizer handles."""
return "typescript"
@property
def supported_extensions(self) -> tuple[str, ...]:
"""Return file extensions this normalizer can handle."""
return (".ts", ".tsx", ".mts", ".cts")
def _get_tree_sitter_language(self) -> str:
"""Get the tree-sitter language identifier."""
return "typescript"

View file

@@ -1,226 +0,0 @@
"""Python code normalizer using AST transformation."""
from __future__ import annotations
import ast
from codeflash.code_utils.normalizers.base import CodeNormalizer
class VariableNormalizer(ast.NodeTransformer):
"""Normalizes only local variable names in AST to canonical forms like var_0, var_1, etc.
Preserves function names, class names, parameters, built-ins, and imported names.
"""
def __init__(self) -> None:
self.var_counter = 0
self.var_mapping: dict[str, str] = {}
self.scope_stack: list[dict] = []
self.builtins = set(dir(__builtins__))
self.imports: set[str] = set()
self.global_vars: set[str] = set()
self.nonlocal_vars: set[str] = set()
self.parameters: set[str] = set()
def enter_scope(self) -> None:
"""Enter a new scope (function/class)."""
self.scope_stack.append(
{"var_mapping": dict(self.var_mapping), "var_counter": self.var_counter, "parameters": set(self.parameters)}
)
def exit_scope(self) -> None:
"""Exit current scope and restore parent scope."""
if self.scope_stack:
scope = self.scope_stack.pop()
self.var_mapping = scope["var_mapping"]
self.var_counter = scope["var_counter"]
self.parameters = scope["parameters"]
def get_normalized_name(self, name: str) -> str:
"""Get or create normalized name for a variable."""
if (
name in self.builtins
or name in self.imports
or name in self.global_vars
or name in self.nonlocal_vars
or name in self.parameters
):
return name
if name not in self.var_mapping:
self.var_mapping[name] = f"var_{self.var_counter}"
self.var_counter += 1
return self.var_mapping[name]
def visit_Import(self, node: ast.Import) -> ast.Import:
"""Track imported names."""
for alias in node.names:
name = alias.asname if alias.asname else alias.name
self.imports.add(name.split(".")[0])
return node
def visit_ImportFrom(self, node: ast.ImportFrom) -> ast.ImportFrom:
"""Track imported names from modules."""
for alias in node.names:
name = alias.asname if alias.asname else alias.name
self.imports.add(name)
return node
def visit_Global(self, node: ast.Global) -> ast.Global:
"""Track global variable declarations."""
self.global_vars.update(node.names)
return node
def visit_Nonlocal(self, node: ast.Nonlocal) -> ast.Nonlocal:
"""Track nonlocal variable declarations."""
self.nonlocal_vars.update(node.names)
return node
def visit_FunctionDef(self, node: ast.FunctionDef) -> ast.FunctionDef:
"""Process function but keep function name and parameters unchanged."""
self.enter_scope()
for arg in node.args.args:
self.parameters.add(arg.arg)
if node.args.vararg:
self.parameters.add(node.args.vararg.arg)
if node.args.kwarg:
self.parameters.add(node.args.kwarg.arg)
for arg in node.args.kwonlyargs:
self.parameters.add(arg.arg)
node = self.generic_visit(node)
self.exit_scope()
return node
def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> ast.AsyncFunctionDef:
"""Handle async functions same as regular functions."""
return self.visit_FunctionDef(node) # type: ignore[return-value]
def visit_ClassDef(self, node: ast.ClassDef) -> ast.ClassDef:
"""Process class but keep class name unchanged."""
self.enter_scope()
node = self.generic_visit(node)
self.exit_scope()
return node
def visit_Name(self, node: ast.Name) -> ast.Name:
"""Normalize variable names in Name nodes."""
if isinstance(node.ctx, (ast.Store, ast.Del)):
if (
node.id not in self.builtins
and node.id not in self.imports
and node.id not in self.parameters
and node.id not in self.global_vars
and node.id not in self.nonlocal_vars
):
node.id = self.get_normalized_name(node.id)
elif isinstance(node.ctx, ast.Load) and node.id in self.var_mapping:
node.id = self.var_mapping[node.id]
return node
def visit_ExceptHandler(self, node: ast.ExceptHandler) -> ast.ExceptHandler:
"""Normalize exception variable names."""
if node.name:
node.name = self.get_normalized_name(node.name)
return self.generic_visit(node)
def visit_comprehension(self, node: ast.comprehension) -> ast.comprehension:
"""Normalize comprehension target variables."""
old_mapping = dict(self.var_mapping)
old_counter = self.var_counter
node = self.generic_visit(node)
self.var_mapping = old_mapping
self.var_counter = old_counter
return node
def visit_For(self, node: ast.For) -> ast.For:
"""Handle for loop target variables."""
return self.generic_visit(node)
def visit_With(self, node: ast.With) -> ast.With:
"""Handle with statement as variables."""
return self.generic_visit(node)
def _remove_docstrings_from_ast(node: ast.AST) -> None:
"""Remove docstrings from AST nodes."""
node_types = (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef, ast.Module)
stack = [node]
while stack:
current_node = stack.pop()
if isinstance(current_node, node_types):
body = current_node.body
if (
body
and isinstance(body[0], ast.Expr)
and isinstance(body[0].value, ast.Constant)
and isinstance(body[0].value.value, str)
):
current_node.body = body[1:]
stack.extend([child for child in body if isinstance(child, node_types)])
class PythonNormalizer(CodeNormalizer):
"""Python code normalizer using AST transformation.
Normalizes Python code by:
- Replacing local variable names with canonical forms (var_0, var_1, etc.)
- Preserving function names, class names, parameters, and imports
- Optionally removing docstrings
"""
@property
def language(self) -> str:
"""Return the language this normalizer handles."""
return "python"
@property
def supported_extensions(self) -> tuple[str, ...]:
"""Return file extensions this normalizer can handle."""
return (".py", ".pyw", ".pyi")
def normalize(self, code: str, remove_docstrings: bool = True) -> str:
"""Normalize Python code to a canonical form.
Args:
code: Python source code to normalize
remove_docstrings: Whether to remove docstrings
Returns:
Normalized Python code as a string
"""
tree = ast.parse(code)
if remove_docstrings:
_remove_docstrings_from_ast(tree)
normalizer = VariableNormalizer()
normalized_tree = normalizer.visit(tree)
ast.fix_missing_locations(normalized_tree)
return ast.unparse(normalized_tree)
def normalize_for_hash(self, code: str) -> str:
"""Normalize Python code optimized for hashing.
Returns AST dump which is faster than unparsing.
Args:
code: Python source code to normalize
Returns:
AST dump string suitable for hashing
"""
tree = ast.parse(code)
_remove_docstrings_from_ast(tree)
normalizer = VariableNormalizer()
normalized_tree = normalizer.visit(tree)
return ast.dump(normalized_tree, annotate_fields=False, include_attributes=False)
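The core idea behind the removed `VariableNormalizer` — rename assigned locals to `var_0`, `var_1`, … while leaving parameters untouched — can be sketched in a few lines. This is a deliberately simplified version without the scope-stack, import, and global/nonlocal tracking of the full class:

```python
import ast


class LocalRenamer(ast.NodeTransformer):
    """Rename assigned local names to var_N; keep function parameters as-is."""

    def __init__(self, params: set[str]) -> None:
        self.params = params
        self.mapping: dict[str, str] = {}

    def visit_Name(self, node: ast.Name) -> ast.Name:
        if node.id in self.params:
            return node
        if isinstance(node.ctx, ast.Store):
            # First assignment mints the canonical name.
            self.mapping.setdefault(node.id, f"var_{len(self.mapping)}")
        if node.id in self.mapping:
            node.id = self.mapping[node.id]
        return node


src = "def f(x):\n    total = x + 1\n    return total\n"
tree = ast.parse(src)
fn = tree.body[0]
renamer = LocalRenamer({a.arg for a in fn.args.args})
fn.body = [renamer.visit(stmt) for stmt in fn.body]
normalized = ast.unparse(ast.fix_missing_locations(tree))
print(normalized)
```

Here `total` becomes `var_0` while the parameter `x` survives, so two functions differing only in local variable names unparse identically.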

View file

@@ -536,7 +536,9 @@ def run_mocha_benchmarking_tests(
)
mocha_env["CODEFLASH_TEST_MODULE"] = test_module_path
total_timeout = max(120, (target_duration_ms // 1000) + 60, timeout or 120)
# Subprocess timeout: target_duration + 120s headroom for Mocha startup.
# capturePerf's time budget governs actual looping.
total_timeout = max(120, (target_duration_ms // 1000) + 120, timeout or 120)
logger.debug(f"Running Mocha benchmarking tests: {' '.join(mocha_cmd)}")
logger.debug(

View file

@@ -1025,9 +1025,9 @@ def run_jest_benchmarking_tests(
if "--max-old-space-size" not in existing_node_options:
jest_env["NODE_OPTIONS"] = f"{existing_node_options} --max-old-space-size=4096".strip()
# Total timeout for the entire benchmark run (longer than single-loop timeout)
# Account for startup overhead + target duration + buffer
total_timeout = max(120, (target_duration_ms // 1000) + 60, timeout or 120)
# Subprocess timeout: target_duration + 120s headroom for Jest startup
# and TS compilation. capturePerf's time budget governs actual looping.
total_timeout = max(120, (target_duration_ms // 1000) + 120)
logger.debug(f"Running Jest benchmarking tests with in-process loop runner: {' '.join(jest_cmd)}")
logger.debug(

View file

@@ -616,8 +616,10 @@ def run_vitest_benchmarking_tests(
vitest_env["CODEFLASH_TEST_MODULE"] = test_module_path
logger.debug(f"[VITEST-BENCH] Set CODEFLASH_TEST_MODULE={test_module_path}")
# Total timeout for the entire benchmark run
total_timeout = max(120, (target_duration_ms // 1000) + 60, timeout or 120)
# Subprocess timeout: target_duration + 120s headroom for Vitest startup
# (TS compilation, module resolution). The capturePerf time budget (10s default)
# governs actual looping; this is just a safety net for process-level hangs.
total_timeout = max(120, (target_duration_ms // 1000) + 120)
logger.debug(f"[VITEST-BENCH] Running Vitest benchmarking tests: {' '.join(vitest_cmd)}")
logger.debug(
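All three runners (Mocha, Jest, Vitest) now share the same subprocess-timeout expression. Its behavior is easy to check numerically; a standalone sketch of the shared formula:

```python
def total_timeout(target_duration_ms: int) -> int:
    # 120 s floor, plus 120 s of startup headroom on top of the target duration.
    return max(120, (target_duration_ms // 1000) + 120)


print(total_timeout(0))       # → 120 (floor applies)
print(total_timeout(10_000))  # → 130 (10 s target + 120 s headroom)
```

Even a 1-second target yields 121 s, so the headroom term, not the floor, dominates for any nonzero target duration.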

View file

@@ -773,7 +773,13 @@ class InvocationId:
test_src = test_path.read_text(encoding="utf-8")
module_node = cst.parse_module(test_src)
except Exception:
return None
# libcst can't parse non-Python files (JS/TS) — return a descriptive string
# so the code repair API receives a non-None test_src_code.
return (
f"// Test: {self.test_function_name}\n"
f"// File: {test_path.name}\n"
f"// Testing function: {self.function_getting_tested}"
)
if self.test_class_name:
for stmt in module_node.body:

View file

@@ -1,2 +1,2 @@
# These version placeholders will be replaced by uv-dynamic-versioning during build.
__version__ = "0.20.1.post242.dev0+7c7eeb5b"
__version__ = "0.20.1.post675.dev0+1218a1cd"

View file

@@ -1,12 +1,12 @@
{
"name": "codeflash",
"version": "0.9.0",
"version": "0.10.1",
"lockfileVersion": 3,
"requires": true,
"packages": {
"": {
"name": "codeflash",
"version": "0.9.0",
"version": "0.10.1",
"hasInstallScript": true,
"license": "MIT",
"dependencies": {

View file

@@ -1,6 +1,6 @@
{
"name": "codeflash",
"version": "0.10.0",
"version": "0.10.1",
"description": "Codeflash - AI-powered code optimization for JavaScript and TypeScript",
"main": "runtime/index.js",
"types": "runtime/index.d.ts",

View file

@@ -926,11 +926,7 @@ class TestTestFrameworkConfigOverride:
package_json = tmp_path / "package.json"
package_json.write_text(
json.dumps(
{
"name": "test-project",
"devDependencies": {"vitest": "^1.0.0"},
"codeflash": {"moduleRoot": "src"},
}
{"name": "test-project", "devDependencies": {"vitest": "^1.0.0"}, "codeflash": {"moduleRoot": "src"}}
)
)
@@ -945,11 +941,7 @@ class TestTestFrameworkConfigOverride:
package_json = tmp_path / "package.json"
package_json.write_text(
json.dumps(
{
"name": "test-project",
"devDependencies": {"jest": "^29.0.0"},
"codeflash": {"test-framework": ""},
}
{"name": "test-project", "devDependencies": {"jest": "^29.0.0"}, "codeflash": {"test-framework": ""}}
)
)

View file

@@ -2,7 +2,10 @@ from __future__ import annotations
from typing import Any
from codeflash.languages.python.static_analysis.coverage_utils import build_fully_qualified_name, extract_dependent_function
from codeflash.languages.python.static_analysis.coverage_utils import (
build_fully_qualified_name,
extract_dependent_function,
)
from codeflash.models.function_types import FunctionParent
from codeflash.models.models import CodeOptimizationContext, CodeString, CodeStringsMarkdown
from codeflash.verification.coverage_utils import CoverageUtils

View file

@@ -9,7 +9,6 @@ from __future__ import annotations
import tempfile
from pathlib import Path
import pytest
from junitparser import JUnitXml
from codeflash.languages.javascript.parse import jest_end_pattern, jest_start_pattern
@@ -338,9 +337,7 @@ class TestFilenameBasedLookupFallback:
path2.touch()
test_file1 = TestFile(
original_file_path=path1,
test_type=TestType.GENERATED_REGRESSION,
instrumented_behavior_file_path=path1,
original_file_path=path1, test_type=TestType.GENERATED_REGRESSION, instrumented_behavior_file_path=path1
)
test_file2 = TestFile(
original_file_path=path2,

View file

@@ -9,8 +9,6 @@ from __future__ import annotations
import tempfile
from pathlib import Path
import pytest
from codeflash.languages.javascript.vitest_runner import (
_build_vitest_behavioral_command,
_build_vitest_benchmarking_command,

View file

@@ -6,9 +6,7 @@ from end_to_end_test_utilities import TestConfig, run_codeflash_command, run_wit
def run_test(expected_improvement_pct: int) -> bool:
config = TestConfig(
file_path="src/main/java/com/example/Fibonacci.java",
function_name="fibonacci",
min_improvement_x=0.70,
file_path="src/main/java/com/example/Fibonacci.java", function_name="fibonacci", min_improvement_x=0.70
)
cwd = (pathlib.Path(__file__).parent.parent.parent / "code_to_optimize" / "java").resolve()
return run_codeflash_command(cwd, config, expected_improvement_pct)

View file

@@ -1,6 +1,5 @@
"""Tests for cleanup of instrumented test files."""
from pathlib import Path
from codeflash.optimization.optimizer import Optimizer

View file

@@ -3809,8 +3809,7 @@ def test_enrich_testgen_context_deduplicates(tmp_path: Path) -> None:
package_dir.mkdir()
(package_dir / "__init__.py").write_text("", encoding="utf-8")
(package_dir / "base.py").write_text(
"class Base:\n def __init__(self, x: int):\n self.x = x\n",
encoding="utf-8",
"class Base:\n def __init__(self, x: int):\n self.x = x\n", encoding="utf-8"
)
code = "from mypkg.base import Base\n\nclass A(Base):\n pass\n\nclass B(Base):\n pass\n"
@@ -3954,8 +3953,7 @@ def test_testgen_context_includes_external_base_inits(tmp_path: Path) -> None:
package_dir.mkdir()
(package_dir / "__init__.py").write_text("", encoding="utf-8")
(package_dir / "base.py").write_text(
"class BaseDict:\n def __init__(self, data=None):\n self.data = data or {}\n",
encoding="utf-8",
"class BaseDict:\n def __init__(self, data=None):\n self.data = data or {}\n", encoding="utf-8"
)
code = "from mypkg.base import BaseDict\n\nclass MyCustomDict(BaseDict):\n def target_method(self):\n return self.data\n"
@@ -4009,8 +4007,7 @@ def test_enrich_testgen_context_attribute_base(tmp_path: Path) -> None:
package_dir.mkdir()
(package_dir / "__init__.py").write_text("", encoding="utf-8")
(package_dir / "base.py").write_text(
"class CustomDict:\n def __init__(self, data=None):\n self.data = data or {}\n",
encoding="utf-8",
"class CustomDict:\n def __init__(self, data=None):\n self.data = data or {}\n", encoding="utf-8"
)
code = "from mypkg.base import CustomDict\n\nclass MyDict(CustomDict):\n def custom_method(self):\n return self.data\n"
@@ -4702,18 +4699,17 @@ def get_log_level() -> str:
assert "class AppConfig:" in combined
assert "@property" in combined
def test_extract_parameter_type_constructors_isinstance_single(tmp_path: Path) -> None:
"""isinstance(x, SomeType) in function body should be picked up."""
pkg = tmp_path / "mypkg"
pkg.mkdir()
(pkg / "__init__.py").write_text("", encoding="utf-8")
(pkg / "models.py").write_text(
"class Widget:\n def __init__(self, size: int):\n self.size = size\n",
encoding="utf-8",
"class Widget:\n def __init__(self, size: int):\n self.size = size\n", encoding="utf-8"
)
(pkg / "processor.py").write_text(
"from mypkg.models import Widget\n\ndef check(x) -> bool:\n return isinstance(x, Widget)\n",
encoding="utf-8",
"from mypkg.models import Widget\n\ndef check(x) -> bool:\n return isinstance(x, Widget)\n", encoding="utf-8"
)
fto = FunctionToOptimize(
function_name="check", file_path=(pkg / "processor.py").resolve(), starting_line=3, ending_line=4
@@ -4754,12 +4750,10 @@ def test_extract_parameter_type_constructors_type_is_pattern(tmp_path: Path) ->
pkg.mkdir()
(pkg / "__init__.py").write_text("", encoding="utf-8")
(pkg / "models.py").write_text(
"class Gadget:\n def __init__(self, val: float):\n self.val = val\n",
encoding="utf-8",
"class Gadget:\n def __init__(self, val: float):\n self.val = val\n", encoding="utf-8"
)
(pkg / "processor.py").write_text(
"from mypkg.models import Gadget\n\ndef check(x) -> bool:\n return type(x) is Gadget\n",
encoding="utf-8",
"from mypkg.models import Gadget\n\ndef check(x) -> bool:\n return type(x) is Gadget\n", encoding="utf-8"
)
fto = FunctionToOptimize(
function_name="check", file_path=(pkg / "processor.py").resolve(), starting_line=3, ending_line=4
@@ -4775,8 +4769,7 @@ def test_extract_parameter_type_constructors_base_classes(tmp_path: Path) -> Non
pkg.mkdir()
(pkg / "__init__.py").write_text("", encoding="utf-8")
(pkg / "base.py").write_text(
"class BaseProcessor:\n def __init__(self, config: str):\n self.config = config\n",
encoding="utf-8",
"class BaseProcessor:\n def __init__(self, config: str):\n self.config = config\n", encoding="utf-8"
)
(pkg / "child.py").write_text(
"from mypkg.base import BaseProcessor\n\nclass ChildProcessor(BaseProcessor):\n"
@@ -4801,8 +4794,7 @@ def test_extract_parameter_type_constructors_isinstance_builtins_excluded(tmp_pa
pkg.mkdir()
(pkg / "__init__.py").write_text("", encoding="utf-8")
(pkg / "func.py").write_text(
"def check(x) -> bool:\n return isinstance(x, (int, str, float))\n",
encoding="utf-8",
"def check(x) -> bool:\n return isinstance(x, (int, str, float))\n", encoding="utf-8"
)
fto = FunctionToOptimize(
function_name="check", file_path=(pkg / "func.py").resolve(), starting_line=1, ending_line=2
@@ -4817,8 +4809,7 @@ def test_extract_parameter_type_constructors_transitive(tmp_path: Path) -> None:
pkg.mkdir()
(pkg / "__init__.py").write_text("", encoding="utf-8")
(pkg / "config.py").write_text(
"class Config:\n def __init__(self, debug: bool = False):\n self.debug = debug\n",
encoding="utf-8",
"class Config:\n def __init__(self, debug: bool = False):\n self.debug = debug\n", encoding="utf-8"
)
(pkg / "models.py").write_text(
"from mypkg.config import Config\n\n"
@@ -4826,8 +4817,7 @@ def test_extract_parameter_type_constructors_transitive(tmp_path: Path) -> None:
encoding="utf-8",
)
(pkg / "processor.py").write_text(
"from mypkg.models import Widget\n\ndef process(w: Widget) -> str:\n return str(w)\n",
encoding="utf-8",
"from mypkg.models import Widget\n\ndef process(w: Widget) -> str:\n return str(w)\n", encoding="utf-8"
)
fto = FunctionToOptimize(
function_name="process", file_path=(pkg / "processor.py").resolve(), starting_line=3, ending_line=4
@@ -4838,8 +4828,6 @@ def test_extract_parameter_type_constructors_transitive(tmp_path: Path) -> None:
assert "class Config:" in combined
def test_enrich_testgen_context_third_party_uses_stubs(tmp_path: Path) -> None:
"""Third-party classes should produce compact __init__ stubs, not full class source."""
# Use a real third-party package (pydantic) so jedi can actually resolve it

View file

@@ -7,7 +7,12 @@ from pathlib import Path
import libcst as cst
from codeflash.languages.python.static_analysis.code_extractor import delete___future___aliased_imports, find_preexisting_objects
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.languages.python.function_optimizer import PythonFunctionOptimizer
from codeflash.languages.python.static_analysis.code_extractor import (
delete___future___aliased_imports,
find_preexisting_objects,
)
from codeflash.languages.python.static_analysis.code_replacer import (
AddRequestArgument,
AutouseFixtureModifier,
@@ -16,9 +21,7 @@ from codeflash.languages.python.static_analysis.code_replacer import (
replace_functions_and_add_imports,
replace_functions_in_file,
)
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.models.models import CodeOptimizationContext, CodeStringsMarkdown, FunctionParent, FunctionSource
from codeflash.languages.python.function_optimizer import PythonFunctionOptimizer
from codeflash.verification.verification_utils import TestConfig
os.environ["CODEFLASH_API_KEY"] = "cf-test-key"

View file

@@ -20,7 +20,11 @@ from codeflash.code_utils.code_utils import (
validate_python_code,
)
from codeflash.languages.python.static_analysis.concolic_utils import clean_concolic_tests
from codeflash.languages.python.static_analysis.coverage_utils import extract_dependent_function, generate_candidates, prepare_coverage_files
from codeflash.languages.python.static_analysis.coverage_utils import (
extract_dependent_function,
generate_candidates,
prepare_coverage_files,
)
from codeflash.models.models import CodeStringsMarkdown
from codeflash.verification.parse_test_output import resolve_test_file_from_class_path

View file

@@ -1,4 +1,3 @@
from codeflash.benchmarking.codeflash_trace import codeflash_trace

View file

@@ -592,12 +592,10 @@ def test_itertools_permutations_combinations() -> None:
assert comparator(itertools.combinations("ABCD", 2), itertools.combinations("ABCD", 2))
assert not comparator(itertools.combinations("ABCD", 2), itertools.combinations("ABCD", 3))
assert comparator(
itertools.combinations_with_replacement("ABC", 2),
itertools.combinations_with_replacement("ABC", 2),
itertools.combinations_with_replacement("ABC", 2), itertools.combinations_with_replacement("ABC", 2)
)
assert not comparator(
itertools.combinations_with_replacement("ABC", 2),
itertools.combinations_with_replacement("ABD", 2),
itertools.combinations_with_replacement("ABC", 2), itertools.combinations_with_replacement("ABD", 2)
)
@@ -615,38 +613,31 @@ def test_itertools_filtering() -> None:
# compress
assert comparator(
itertools.compress("ABCDEF", [1, 0, 1, 0, 1, 1]),
itertools.compress("ABCDEF", [1, 0, 1, 0, 1, 1]),
itertools.compress("ABCDEF", [1, 0, 1, 0, 1, 1]), itertools.compress("ABCDEF", [1, 0, 1, 0, 1, 1])
)
assert not comparator(
itertools.compress("ABCDEF", [1, 0, 1, 0, 1, 1]),
itertools.compress("ABCDEF", [1, 1, 1, 0, 1, 1]),
itertools.compress("ABCDEF", [1, 0, 1, 0, 1, 1]), itertools.compress("ABCDEF", [1, 1, 1, 0, 1, 1])
)
# dropwhile
assert comparator(
itertools.dropwhile(lambda x: x < 5, [1, 4, 6, 4, 1]),
itertools.dropwhile(lambda x: x < 5, [1, 4, 6, 4, 1]),
itertools.dropwhile(lambda x: x < 5, [1, 4, 6, 4, 1]), itertools.dropwhile(lambda x: x < 5, [1, 4, 6, 4, 1])
)
assert not comparator(
itertools.dropwhile(lambda x: x < 5, [1, 4, 6, 4, 1]),
itertools.dropwhile(lambda x: x < 5, [1, 4, 7, 4, 1]),
itertools.dropwhile(lambda x: x < 5, [1, 4, 6, 4, 1]), itertools.dropwhile(lambda x: x < 5, [1, 4, 7, 4, 1])
)
# takewhile
assert comparator(
itertools.takewhile(lambda x: x < 5, [1, 4, 6, 4, 1]),
itertools.takewhile(lambda x: x < 5, [1, 4, 6, 4, 1]),
itertools.takewhile(lambda x: x < 5, [1, 4, 6, 4, 1]), itertools.takewhile(lambda x: x < 5, [1, 4, 6, 4, 1])
)
assert not comparator(
itertools.takewhile(lambda x: x < 5, [1, 4, 6, 4, 1]),
itertools.takewhile(lambda x: x < 5, [1, 3, 6, 4, 1]),
itertools.takewhile(lambda x: x < 5, [1, 4, 6, 4, 1]), itertools.takewhile(lambda x: x < 5, [1, 3, 6, 4, 1])
)
# filterfalse
assert comparator(
itertools.filterfalse(lambda x: x % 2, range(10)),
itertools.filterfalse(lambda x: x % 2, range(10)),
itertools.filterfalse(lambda x: x % 2, range(10)), itertools.filterfalse(lambda x: x % 2, range(10))
)
@@ -654,25 +645,19 @@ def test_itertools_starmap() -> None:
import itertools
assert comparator(
itertools.starmap(pow, [(2, 3), (3, 2), (10, 0)]),
itertools.starmap(pow, [(2, 3), (3, 2), (10, 0)]),
)
assert not comparator(
itertools.starmap(pow, [(2, 3), (3, 2)]),
itertools.starmap(pow, [(2, 3), (3, 3)]),
itertools.starmap(pow, [(2, 3), (3, 2), (10, 0)]), itertools.starmap(pow, [(2, 3), (3, 2), (10, 0)])
)
assert not comparator(itertools.starmap(pow, [(2, 3), (3, 2)]), itertools.starmap(pow, [(2, 3), (3, 3)]))
def test_itertools_zip_longest() -> None:
import itertools
assert comparator(
itertools.zip_longest("AB", "xyz", fillvalue="-"),
itertools.zip_longest("AB", "xyz", fillvalue="-"),
itertools.zip_longest("AB", "xyz", fillvalue="-"), itertools.zip_longest("AB", "xyz", fillvalue="-")
)
assert not comparator(
itertools.zip_longest("AB", "xyz", fillvalue="-"),
itertools.zip_longest("AB", "xyz", fillvalue="*"),
itertools.zip_longest("AB", "xyz", fillvalue="-"), itertools.zip_longest("AB", "xyz", fillvalue="*")
)
@@ -685,8 +670,7 @@ def test_itertools_groupby() -> None:
# With key function
assert comparator(
itertools.groupby([1, 1, 2, 2, 3], key=lambda x: x),
itertools.groupby([1, 1, 2, 2, 3], key=lambda x: x),
itertools.groupby([1, 1, 2, 2, 3], key=lambda x: x), itertools.groupby([1, 1, 2, 2, 3], key=lambda x: x)
)
@@ -714,10 +698,7 @@ def test_itertools_in_containers() -> None:
{"a": itertools.chain([1], [2]), "b": itertools.islice(range(5), 3)},
{"a": itertools.chain([1], [2]), "b": itertools.islice(range(5), 3)},
)
assert not comparator(
[itertools.product("AB", repeat=2)],
[itertools.product("AC", repeat=2)],
)
assert not comparator([itertools.product("AB", repeat=2)], [itertools.product("AC", repeat=2)])
# Different itertools types should not match
assert not comparator(itertools.chain([1, 2]), itertools.islice([1, 2], 2))
@@ -2017,59 +1998,30 @@ def test_torch_nn_sequential():
# Test identical Sequential modules
torch.manual_seed(42)
a = nn.Sequential(
nn.Linear(10, 20),
nn.ReLU(),
nn.Linear(20, 5)
)
a = nn.Sequential(nn.Linear(10, 20), nn.ReLU(), nn.Linear(20, 5))
torch.manual_seed(42)
b = nn.Sequential(
nn.Linear(10, 20),
nn.ReLU(),
nn.Linear(20, 5)
)
b = nn.Sequential(nn.Linear(10, 20), nn.ReLU(), nn.Linear(20, 5))
assert comparator(a, b)
# Test Sequential with different weights
torch.manual_seed(42)
c = nn.Sequential(
nn.Linear(10, 20),
nn.ReLU(),
nn.Linear(20, 5)
)
c = nn.Sequential(nn.Linear(10, 20), nn.ReLU(), nn.Linear(20, 5))
torch.manual_seed(123)
d = nn.Sequential(
nn.Linear(10, 20),
nn.ReLU(),
nn.Linear(20, 5)
)
d = nn.Sequential(nn.Linear(10, 20), nn.ReLU(), nn.Linear(20, 5))
assert not comparator(c, d)
# Test Sequential with different number of layers
torch.manual_seed(42)
e = nn.Sequential(
nn.Linear(10, 20),
nn.ReLU()
)
e = nn.Sequential(nn.Linear(10, 20), nn.ReLU())
torch.manual_seed(42)
f = nn.Sequential(
nn.Linear(10, 20),
nn.ReLU(),
nn.Linear(20, 5)
)
f = nn.Sequential(nn.Linear(10, 20), nn.ReLU(), nn.Linear(20, 5))
assert not comparator(e, f)
# Test Sequential with different layer types
torch.manual_seed(42)
g = nn.Sequential(
nn.Linear(10, 20),
nn.ReLU()
)
g = nn.Sequential(nn.Linear(10, 20), nn.ReLU())
torch.manual_seed(42)
h = nn.Sequential(
nn.Linear(10, 20),
nn.Sigmoid()
)
h = nn.Sequential(nn.Linear(10, 20), nn.Sigmoid())
assert not comparator(g, h)
@@ -2106,28 +2058,16 @@ def test_torch_nn_moduledict():
# Test identical ModuleDict
torch.manual_seed(42)
a = nn.ModuleDict({
"fc1": nn.Linear(10, 20),
"fc2": nn.Linear(20, 5)
})
a = nn.ModuleDict({"fc1": nn.Linear(10, 20), "fc2": nn.Linear(20, 5)})
torch.manual_seed(42)
b = nn.ModuleDict({
"fc1": nn.Linear(10, 20),
"fc2": nn.Linear(20, 5)
})
b = nn.ModuleDict({"fc1": nn.Linear(10, 20), "fc2": nn.Linear(20, 5)})
assert comparator(a, b)
# Test ModuleDict with different keys
torch.manual_seed(42)
c = nn.ModuleDict({
"fc1": nn.Linear(10, 20),
"fc2": nn.Linear(20, 5)
})
c = nn.ModuleDict({"fc1": nn.Linear(10, 20), "fc2": nn.Linear(20, 5)})
torch.manual_seed(42)
d = nn.ModuleDict({
"layer1": nn.Linear(10, 20),
"layer2": nn.Linear(20, 5)
})
d = nn.ModuleDict({"layer1": nn.Linear(10, 20), "layer2": nn.Linear(20, 5)})
assert not comparator(c, d)

View file

@@ -294,7 +294,7 @@ class MockTestConfig:
"""Mocks codeflash.verification.verification_utils.TestConfig"""
tests_root: Path
tests_project_rootdir: Path = Path(".")
tests_project_rootdir: Path = Path()
@contextlib.contextmanager

View file

@@ -4,8 +4,8 @@ import pytest
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.either import is_successful
from codeflash.models.models import FunctionParent
from codeflash.languages.python.function_optimizer import PythonFunctionOptimizer
from codeflash.models.models import FunctionParent
from codeflash.verification.verification_utils import TestConfig

View file

@@ -3,8 +3,8 @@ from pathlib import Path
import pytest
from codeflash.languages.python.static_analysis.code_extractor import get_code
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.languages.python.static_analysis.code_extractor import get_code
from codeflash.models.models import FunctionParent

View file

@@ -6,8 +6,8 @@ import pytest
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.either import is_successful
from codeflash.models.models import FunctionParent, get_code_block_splitter
from codeflash.languages.python.function_optimizer import PythonFunctionOptimizer
from codeflash.models.models import FunctionParent, get_code_block_splitter
from codeflash.optimization.optimizer import Optimizer
from codeflash.verification.verification_utils import TestConfig

View file

@@ -412,7 +412,9 @@ def test_conditional_class_definitions() -> None:
platform = "other"
"""
output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_ONLY, {"PlatformClass.target_method"}, set()).code
output = parse_code_and_prune_cst(
dedent(code), CodeContextType.READ_ONLY, {"PlatformClass.target_method"}, set()
).code
assert dedent(expected).strip() == output.strip()

View file

@@ -123,7 +123,9 @@ def test_multiple_top_level_classes() -> None:
def process(self):
return "C"
"""
result = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_WRITABLE, {"ClassA.process", "ClassC.process"}).code
result = parse_code_and_prune_cst(
dedent(code), CodeContextType.READ_WRITABLE, {"ClassA.process", "ClassC.process"}
).code
expected = dedent("""
class ClassA:

View file

@@ -304,7 +304,9 @@ def test_conditional_class_definitions() -> None:
print("other")
"""
output = parse_code_and_prune_cst(dedent(code), CodeContextType.TESTGEN, {"PlatformClass.target_method"}, set()).code
output = parse_code_and_prune_cst(
dedent(code), CodeContextType.TESTGEN, {"PlatformClass.target_method"}, set()
).code
assert dedent(expected).strip() == output.strip()


@ -305,9 +305,7 @@ class TestShouldModifySkipConfirm:
"""With skip_confirm and valid config, should return (False, config) — no reconfigure."""
monkeypatch.chdir(tmp_project)
codeflash_config = {"moduleRoot": "."}
(tmp_project / "package.json").write_text(
json.dumps({"name": "test", "codeflash": codeflash_config})
)
(tmp_project / "package.json").write_text(json.dumps({"name": "test", "codeflash": codeflash_config}))
should_modify, config = should_modify_package_json_config(skip_confirm=True)
@ -320,9 +318,7 @@ class TestShouldModifySkipConfirm:
"""With skip_confirm and invalid config (bad moduleRoot), should return (True, None)."""
monkeypatch.chdir(tmp_project)
codeflash_config = {"moduleRoot": "/nonexistent/path/that/does/not/exist"}
(tmp_project / "package.json").write_text(
json.dumps({"name": "test", "codeflash": codeflash_config})
)
(tmp_project / "package.json").write_text(json.dumps({"name": "test", "codeflash": codeflash_config}))
should_modify, config = should_modify_package_json_config(skip_confirm=True)


@ -470,8 +470,7 @@ class OuterClass:
decorator_name = get_decorator_name_for_mode(TestingMode.BEHAVIOR)
code_with_decorator = nested_async_code.replace(
" async def nested_async_method",
f" @{decorator_name}\n async def nested_async_method",
" async def nested_async_method", f" @{decorator_name}\n async def nested_async_method"
)
code_with_import = f"from codeflash_async_wrapper import {decorator_name}\n{code_with_decorator}"
expected = sort_imports(code=code_with_import, float_to_top=True)


@ -2,10 +2,10 @@ import os
from pathlib import Path
from tempfile import TemporaryDirectory
from codeflash.languages.python.static_analysis.line_profile_utils import add_decorator_imports, contains_jit_decorator
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.models.models import CodeOptimizationContext
from codeflash.languages.python.function_optimizer import PythonFunctionOptimizer
from codeflash.languages.python.static_analysis.line_profile_utils import add_decorator_imports, contains_jit_decorator
from codeflash.models.models import CodeOptimizationContext
from codeflash.verification.verification_utils import TestConfig


@ -15,8 +15,9 @@ from codeflash.code_utils.instrument_existing_tests import (
FunctionImportedAsVisitor,
inject_profiling_into_existing_test,
)
from codeflash.languages.python.static_analysis.line_profile_utils import add_decorator_imports
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.languages.python.function_optimizer import PythonFunctionOptimizer
from codeflash.languages.python.static_analysis.line_profile_utils import add_decorator_imports
from codeflash.models.models import (
CodeOptimizationContext,
CodePosition,
@ -27,7 +28,6 @@ from codeflash.models.models import (
TestsInFile,
TestType,
)
from codeflash.languages.python.function_optimizer import PythonFunctionOptimizer
from codeflash.verification.verification_utils import TestConfig
codeflash_wrap_string = """def codeflash_wrap(codeflash_wrapped, codeflash_test_module_name, codeflash_test_class_name, codeflash_test_name, codeflash_function_name, codeflash_line_id, codeflash_loop_index, codeflash_cur, codeflash_con, *args, **kwargs):


@ -776,7 +776,8 @@ public void testFibonacci() {
def test_junit4_message_first_with_string_expected(self):
"""When assertEquals has 3 args and the first is a message but the second is also a string,
the type should be inferred from the second arg (the real expected value), not the message."""
the type should be inferred from the second arg (the real expected value), not the message.
"""
source = """\
@Test
public void testGetName() {
@ -807,7 +808,8 @@ public void testIsValid() {
def test_two_arg_string_expected_not_treated_as_message(self):
"""When assertEquals has only 2 args and the first is a string, it IS the expected value,
not a message. This tests that we don't incorrectly skip the first arg."""
not a message. This tests that we don't incorrectly skip the first arg.
"""
source = """\
@Test
public void testGetGreeting() {
@ -869,8 +871,7 @@ void test() {
def test_qualified_name_support(self):
transformer = JavaAssertTransformer(
function_name="fibonacci",
qualified_name="com.example.Calculator.fibonacci",
function_name="fibonacci", qualified_name="com.example.Calculator.fibonacci"
)
assert transformer.qualified_name == "com.example.Calculator.fibonacci"


@ -2,14 +2,11 @@
import subprocess
from pathlib import Path
from unittest.mock import MagicMock, patch
from unittest.mock import patch
import pytest
from codeflash.languages.java.test_runner import (
_multimodule_deps_installed,
ensure_multi_module_deps_installed,
)
from codeflash.languages.java.test_runner import _multimodule_deps_installed, ensure_multi_module_deps_installed
@pytest.fixture(autouse=True)
@ -85,9 +82,7 @@ def test_different_modules_not_cached(mock_run, mock_mvn):
@patch("codeflash.languages.java.test_runner._run_cmd_kill_pg_on_timeout")
def test_returns_false_on_maven_failure(mock_run, mock_mvn):
"""Non-zero exit code should return False and NOT cache."""
mock_run.return_value = subprocess.CompletedProcess(
args=["mvn"], returncode=1, stdout="", stderr="BUILD FAILURE"
)
mock_run.return_value = subprocess.CompletedProcess(args=["mvn"], returncode=1, stdout="", stderr="BUILD FAILURE")
root = Path("/project")
result = ensure_multi_module_deps_installed(root, "guava-tests", {})


@ -41,7 +41,7 @@ def make_func(name: str, class_name: str, file_path: Path | None = None) -> Func
def make_test_method(
name: str, class_name: str, starting_line: int, ending_line: int, file_path: Path | None = None,
name: str, class_name: str, starting_line: int, ending_line: int, file_path: Path | None = None
) -> FunctionToOptimize:
return FunctionToOptimize(
function_name=name,
@ -329,9 +329,7 @@ class FooTest {
source_bytes = source.encode("utf8")
tree = analyzer.parse(source_bytes)
type_map = {"calc": "Calculator"}
resolved = _resolve_method_calls_in_range(
tree.root_node, source_bytes, 2, 5, analyzer, type_map, {},
)
resolved = _resolve_method_calls_in_range(tree.root_node, source_bytes, 2, 5, analyzer, type_map, {})
assert "Calculator.add" in resolved
def test_static_method_call(self, analyzer):
@ -344,9 +342,7 @@ class FooTest {
"""
source_bytes = source.encode("utf8")
tree = analyzer.parse(source_bytes)
resolved = _resolve_method_calls_in_range(
tree.root_node, source_bytes, 2, 4, analyzer, {}, {},
)
resolved = _resolve_method_calls_in_range(tree.root_node, source_bytes, 2, 4, analyzer, {}, {})
assert "Calculator.add" in resolved
def test_static_import_call(self, analyzer):
@ -361,9 +357,7 @@ class FooTest {
source_bytes = source.encode("utf8")
tree = analyzer.parse(source_bytes)
static_map = {"add": "Calculator"}
resolved = _resolve_method_calls_in_range(
tree.root_node, source_bytes, 3, 5, analyzer, {}, static_map,
)
resolved = _resolve_method_calls_in_range(tree.root_node, source_bytes, 3, 5, analyzer, {}, static_map)
assert "Calculator.add" in resolved
def test_new_expression_method_call(self, analyzer):
@ -376,9 +370,7 @@ class FooTest {
"""
source_bytes = source.encode("utf8")
tree = analyzer.parse(source_bytes)
resolved = _resolve_method_calls_in_range(
tree.root_node, source_bytes, 2, 4, analyzer, {}, {},
)
resolved = _resolve_method_calls_in_range(tree.root_node, source_bytes, 2, 4, analyzer, {}, {})
assert "Calculator.add" in resolved
def test_field_access_via_this(self, analyzer):
@ -393,9 +385,7 @@ class FooTest {
source_bytes = source.encode("utf8")
tree = analyzer.parse(source_bytes)
type_map = {"calculator": "Calculator"}
resolved = _resolve_method_calls_in_range(
tree.root_node, source_bytes, 3, 5, analyzer, type_map, {},
)
resolved = _resolve_method_calls_in_range(tree.root_node, source_bytes, 3, 5, analyzer, type_map, {})
assert "Calculator.add" in resolved
def test_unresolvable_call_not_included(self, analyzer):
@ -408,9 +398,7 @@ class FooTest {
"""
source_bytes = source.encode("utf8")
tree = analyzer.parse(source_bytes)
resolved = _resolve_method_calls_in_range(
tree.root_node, source_bytes, 2, 4, analyzer, {}, {},
)
resolved = _resolve_method_calls_in_range(tree.root_node, source_bytes, 2, 4, analyzer, {}, {})
# someUnknown is lowercase and not in type_map → not resolved
assert len(resolved) == 0
@ -425,9 +413,7 @@ class FooTest {
source_bytes = source.encode("utf8")
tree = analyzer.parse(source_bytes)
# assertEquals has no receiver, and not in static_import_map
resolved = _resolve_method_calls_in_range(
tree.root_node, source_bytes, 2, 4, analyzer, {}, {},
)
resolved = _resolve_method_calls_in_range(tree.root_node, source_bytes, 2, 4, analyzer, {}, {})
assert len(resolved) == 0
def test_multiple_different_receivers(self, analyzer):
@ -444,9 +430,7 @@ class FooTest {
source_bytes = source.encode("utf8")
tree = analyzer.parse(source_bytes)
type_map = {"calc": "Calculator", "buf": "Buffer"}
resolved = _resolve_method_calls_in_range(
tree.root_node, source_bytes, 2, 7, analyzer, type_map, {},
)
resolved = _resolve_method_calls_in_range(tree.root_node, source_bytes, 2, 7, analyzer, type_map, {})
assert "Calculator.add" in resolved
assert "Buffer.read" in resolved
@ -466,9 +450,7 @@ class FooTest {
source_bytes = source.encode("utf8")
tree = analyzer.parse(source_bytes)
type_map = {"calc": "Calculator"}
resolved = _resolve_method_calls_in_range(
tree.root_node, source_bytes, 6, 9, analyzer, type_map, {},
)
resolved = _resolve_method_calls_in_range(tree.root_node, source_bytes, 6, 9, analyzer, type_map, {})
assert "Calculator.add" in resolved
assert "Calculator.init" not in resolved
@ -769,10 +751,7 @@ class CalculatorTest {
}
}
"""
func_map = {
"Calculator.add": make_func("add", "Calculator"),
"Buffer.add": make_func("add", "Buffer"),
}
func_map = {"Calculator.add": make_func("add", "Calculator"), "Buffer.add": make_func("add", "Buffer")}
test_method = make_test_method("testAdd", "CalculatorTest", 6, 10)
matched = _match_test_to_functions(test_method, test_source, func_map, analyzer)
# Local Calculator declaration shadows the Buffer field
@ -792,7 +771,8 @@ class TestDiscoverTests:
test_dir.mkdir(parents=True)
test_file = test_dir / "CalculatorTest.java"
test_file.write_text("""\
test_file.write_text(
"""\
package com.example;
import com.example.Calculator;
@ -814,7 +794,9 @@ class CalculatorTest {
assertEquals(2, result);
}
}
""", encoding="utf-8")
""",
encoding="utf-8",
)
source_functions = [
make_func("add", "Calculator"),
@ -840,7 +822,8 @@ class CalculatorTest {
test_dir.mkdir(parents=True)
test_file = test_dir / "MathUtilsTest.java"
test_file.write_text("""\
test_file.write_text(
"""\
package com.example;
import com.example.MathUtils;
@ -857,7 +840,9 @@ class MathUtilsTest {
int result = MathUtils.abs(-3);
}
}
""", encoding="utf-8")
""",
encoding="utf-8",
)
source_functions = [
make_func("square", "MathUtils"),
@ -880,7 +865,8 @@ class MathUtilsTest {
test_dir.mkdir(parents=True)
test_file = test_dir / "CalculatorTest.java"
test_file.write_text("""\
test_file.write_text(
"""\
package com.example;
import com.example.Calculator;
@ -905,7 +891,9 @@ class CalculatorTest {
calculator.multiply(3, 4);
}
}
""", encoding="utf-8")
""",
encoding="utf-8",
)
source_functions = [
make_func("add", "Calculator"),
@ -1074,9 +1062,7 @@ class FooTest {
"""
source_bytes = source.encode("utf8")
tree = analyzer.parse(source_bytes)
resolved = _resolve_method_calls_in_range(
tree.root_node, source_bytes, 2, 5, analyzer, {"obj": "Object"}, {},
)
resolved = _resolve_method_calls_in_range(tree.root_node, source_bytes, 2, 5, analyzer, {"obj": "Object"}, {})
assert "Calculator.add" in resolved
def test_method_call_inside_if(self, analyzer):
@ -1093,9 +1079,7 @@ class FooTest {
source_bytes = source.encode("utf8")
tree = analyzer.parse(source_bytes)
type_map = {"calc": "Calculator"}
resolved = _resolve_method_calls_in_range(
tree.root_node, source_bytes, 2, 7, analyzer, type_map, {},
)
resolved = _resolve_method_calls_in_range(tree.root_node, source_bytes, 2, 7, analyzer, type_map, {})
assert "Calculator.add" in resolved
def test_method_call_inside_try_catch(self, analyzer):
@ -1114,9 +1098,7 @@ class FooTest {
source_bytes = source.encode("utf8")
tree = analyzer.parse(source_bytes)
type_map = {"calc": "Calculator"}
resolved = _resolve_method_calls_in_range(
tree.root_node, source_bytes, 2, 9, analyzer, type_map, {},
)
resolved = _resolve_method_calls_in_range(tree.root_node, source_bytes, 2, 9, analyzer, type_map, {})
assert "Calculator.add" in resolved
assert "Calculator.reset" in resolved
@ -1134,9 +1116,7 @@ class FooTest {
source_bytes = source.encode("utf8")
tree = analyzer.parse(source_bytes)
type_map = {"calc": "Calculator"}
resolved = _resolve_method_calls_in_range(
tree.root_node, source_bytes, 2, 7, analyzer, type_map, {},
)
resolved = _resolve_method_calls_in_range(tree.root_node, source_bytes, 2, 7, analyzer, type_map, {})
assert "Calculator.add" in resolved
def test_method_call_inside_lambda(self, analyzer):
@ -1151,9 +1131,7 @@ class FooTest {
source_bytes = source.encode("utf8")
tree = analyzer.parse(source_bytes)
type_map = {"calc": "Calculator"}
resolved = _resolve_method_calls_in_range(
tree.root_node, source_bytes, 2, 5, analyzer, type_map, {},
)
resolved = _resolve_method_calls_in_range(tree.root_node, source_bytes, 2, 5, analyzer, type_map, {})
assert "Calculator.add" in resolved
def test_duplicate_calls_resolved_once(self, analyzer):
@ -1170,9 +1148,7 @@ class FooTest {
source_bytes = source.encode("utf8")
tree = analyzer.parse(source_bytes)
type_map = {"calc": "Calculator"}
resolved = _resolve_method_calls_in_range(
tree.root_node, source_bytes, 2, 7, analyzer, type_map, {},
)
resolved = _resolve_method_calls_in_range(tree.root_node, source_bytes, 2, 7, analyzer, type_map, {})
# resolved is a set, so duplicates are naturally deduplicated
assert resolved == {"Calculator.add", "Calculator.Calculator", "Calculator.<init>"}
@ -1190,9 +1166,7 @@ class FooTest {
source_bytes = source.encode("utf8")
tree = analyzer.parse(source_bytes)
type_map = {"calc": "Calculator", "buf": "Buffer"}
resolved = _resolve_method_calls_in_range(
tree.root_node, source_bytes, 2, 7, analyzer, type_map, {},
)
resolved = _resolve_method_calls_in_range(tree.root_node, source_bytes, 2, 7, analyzer, type_map, {})
assert "Calculator.add" in resolved
assert "Buffer.add" in resolved
# Also includes constructor refs: Calculator.Calculator, Calculator.<init>, Buffer.Buffer, Buffer.<init>
@ -1212,9 +1186,7 @@ class FooTest {
source_bytes = source.encode("utf8")
tree = analyzer.parse(source_bytes)
type_map = {"calc": "Calculator"}
resolved = _resolve_method_calls_in_range(
tree.root_node, source_bytes, 2, 5, analyzer, type_map, {},
)
resolved = _resolve_method_calls_in_range(tree.root_node, source_bytes, 2, 5, analyzer, type_map, {})
# calc.getResult() resolves to Calculator.getResult
assert "Calculator.getResult" in resolved
# toString() is called on the return of getResult() which is unresolvable
@ -1231,9 +1203,7 @@ class FooTest {
"""
source_bytes = source.encode("utf8")
tree = analyzer.parse(source_bytes)
resolved = _resolve_method_calls_in_range(
tree.root_node, source_bytes, 2, 4, analyzer, {}, {},
)
resolved = _resolve_method_calls_in_range(tree.root_node, source_bytes, 2, 4, analyzer, {}, {})
assert len(resolved) == 0
def test_this_method_call_not_resolved(self, analyzer):
@ -1247,9 +1217,7 @@ class FooTest {
"""
source_bytes = source.encode("utf8")
tree = analyzer.parse(source_bytes)
resolved = _resolve_method_calls_in_range(
tree.root_node, source_bytes, 2, 4, analyzer, {}, {},
)
resolved = _resolve_method_calls_in_range(tree.root_node, source_bytes, 2, 4, analyzer, {}, {})
# this is not a field_access with a field that's in the type map, so not resolved
assert len(resolved) == 0
@ -1263,9 +1231,7 @@ class FooTest {
"""
source_bytes = source.encode("utf8")
tree = analyzer.parse(source_bytes)
resolved = _resolve_method_calls_in_range(
tree.root_node, source_bytes, 2, 4, analyzer, {}, {},
)
resolved = _resolve_method_calls_in_range(tree.root_node, source_bytes, 2, 4, analyzer, {}, {})
# getCalculator() returns a method_invocation node as object, can't resolve
assert "Calculator.add" not in resolved
@ -1279,9 +1245,7 @@ class FooTest {
"""
source_bytes = source.encode("utf8")
tree = analyzer.parse(source_bytes)
resolved = _resolve_method_calls_in_range(
tree.root_node, source_bytes, 2, 4, analyzer, {}, {},
)
resolved = _resolve_method_calls_in_range(tree.root_node, source_bytes, 2, 4, analyzer, {}, {})
assert "ArrayList.add" in resolved
def test_assertion_via_static_import_mapped_to_assertions_class(self, analyzer):
@ -1297,9 +1261,7 @@ class FooTest {
source_bytes = source.encode("utf8")
tree = analyzer.parse(source_bytes)
static_map = {"assertEquals": "Assertions"}
resolved = _resolve_method_calls_in_range(
tree.root_node, source_bytes, 3, 5, analyzer, {}, static_map,
)
resolved = _resolve_method_calls_in_range(tree.root_node, source_bytes, 3, 5, analyzer, {}, static_map)
assert "Assertions.assertEquals" in resolved
assert len(resolved) == 1
@ -1314,9 +1276,7 @@ class FooTest {
"""
source_bytes = source.encode("utf8")
tree = analyzer.parse(source_bytes)
resolved = _resolve_method_calls_in_range(
tree.root_node, source_bytes, 2, 4, analyzer, {}, {},
)
resolved = _resolve_method_calls_in_range(tree.root_node, source_bytes, 2, 4, analyzer, {}, {})
assert "Calculator.Calculator" in resolved
assert "Calculator.<init>" in resolved
@ -1334,9 +1294,7 @@ class FooTest {
source_bytes = source.encode("utf8")
tree = analyzer.parse(source_bytes)
type_map = {"records": "List"}
resolved = _resolve_method_calls_in_range(
tree.root_node, source_bytes, 2, 6, analyzer, type_map, {},
)
resolved = _resolve_method_calls_in_range(tree.root_node, source_bytes, 2, 6, analyzer, type_map, {})
assert "BatchRead.BatchRead" in resolved
assert "BatchRead.<init>" in resolved
assert "Key.Key" in resolved
@ -1353,9 +1311,7 @@ class FooTest {
"""
source_bytes = source.encode("utf8")
tree = analyzer.parse(source_bytes)
resolved = _resolve_method_calls_in_range(
tree.root_node, source_bytes, 2, 4, analyzer, {}, {},
)
resolved = _resolve_method_calls_in_range(tree.root_node, source_bytes, 2, 4, analyzer, {}, {})
assert "HashMap.HashMap" in resolved
assert "HashMap.<init>" in resolved
@ -1379,10 +1335,7 @@ class MyTest {
}
}
"""
func_map = {
"Calculator.add": make_func("add", "Calculator"),
"MathUtils.add": make_func("add", "MathUtils"),
}
func_map = {"Calculator.add": make_func("add", "Calculator"), "MathUtils.add": make_func("add", "MathUtils")}
test_method = make_test_method("testAdd", "MyTest", 4, 8)
matched = _match_test_to_functions(test_method, test_source, func_map, analyzer)
assert matched == ["Calculator.add"]
@ -1564,7 +1517,7 @@ class CalculatorTest {
assert matched == []
def test_constructor_matched(self, analyzer):
"""new ClassName() should match the constructor in the function map."""
"""New ClassName() should match the constructor in the function map."""
test_source = """\
import org.junit.jupiter.api.Test;
@ -1582,7 +1535,7 @@ class BatchReadTest {
assert "BatchRead.BatchRead" in matched
def test_constructor_init_convention_matched(self, analyzer):
"""new ClassName() should also match <init> naming convention."""
"""New ClassName() should also match <init> naming convention."""
test_source = """\
import org.junit.jupiter.api.Test;
@ -1599,7 +1552,7 @@ class BatchReadTest {
assert "BatchRead.<init>" in matched
def test_constructor_does_not_match_unrelated_methods(self, analyzer):
"""new BatchRead() should not cause BatchRead.read to match."""
"""New BatchRead() should not cause BatchRead.read to match."""
test_source = """\
import org.junit.jupiter.api.Test;
@ -1660,7 +1613,8 @@ class TestDiscoverTestsExtended:
test_dir = tmp_path / "test"
test_dir.mkdir(parents=True)
(test_dir / "CalculatorTests.java").write_text("""\
(test_dir / "CalculatorTests.java").write_text(
"""\
package com.example;
import org.junit.jupiter.api.Test;
@ -1671,7 +1625,9 @@ class CalculatorTests {
calc.add(1, 2);
}
}
""", encoding="utf-8")
""",
encoding="utf-8",
)
source_functions = [make_func("add", "Calculator")]
result = discover_tests(tmp_path, source_functions, analyzer)
@ -1682,7 +1638,8 @@ class CalculatorTests {
test_dir = tmp_path / "test"
test_dir.mkdir(parents=True)
(test_dir / "TestCalculator.java").write_text("""\
(test_dir / "TestCalculator.java").write_text(
"""\
package com.example;
import org.junit.jupiter.api.Test;
@ -1693,7 +1650,9 @@ class TestCalculator {
calc.add(1, 2);
}
}
""", encoding="utf-8")
""",
encoding="utf-8",
)
source_functions = [make_func("add", "Calculator")]
result = discover_tests(tmp_path, source_functions, analyzer)
@ -1710,7 +1669,8 @@ class TestCalculator {
test_dir = tmp_path / "test"
test_dir.mkdir(parents=True)
(test_dir / "CalculatorTest.java").write_text("""\
(test_dir / "CalculatorTest.java").write_text(
"""\
package com.example;
import org.junit.jupiter.api.Test;
@ -1733,12 +1693,11 @@ class CalculatorTest {
calc.subtract(5, 3);
}
}
""", encoding="utf-8")
""",
encoding="utf-8",
)
source_functions = [
make_func("add", "Calculator"),
make_func("subtract", "Calculator"),
]
source_functions = [make_func("add", "Calculator"), make_func("subtract", "Calculator")]
result = discover_tests(tmp_path, source_functions, analyzer)
assert "Calculator.add" in result
@ -1753,7 +1712,8 @@ class CalculatorTest {
test_dir = tmp_path / "test"
test_dir.mkdir(parents=True)
(test_dir / "CalculatorTest.java").write_text("""\
(test_dir / "CalculatorTest.java").write_text(
"""\
package com.example;
import org.junit.jupiter.api.Test;
@ -1764,9 +1724,12 @@ class CalculatorTest {
calc.add(1, 2);
}
}
""", encoding="utf-8")
""",
encoding="utf-8",
)
(test_dir / "IntegrationTest.java").write_text("""\
(test_dir / "IntegrationTest.java").write_text(
"""\
package com.example;
import org.junit.jupiter.api.Test;
@ -1777,7 +1740,9 @@ class IntegrationTest {
calc.add(10, 20);
}
}
""", encoding="utf-8")
""",
encoding="utf-8",
)
source_functions = [make_func("add", "Calculator")]
result = discover_tests(tmp_path, source_functions, analyzer)
@ -1791,7 +1756,8 @@ class IntegrationTest {
test_dir = tmp_path / "test"
test_dir.mkdir(parents=True)
(test_dir / "CalculatorTest.java").write_text("""\
(test_dir / "CalculatorTest.java").write_text(
"""\
package com.example;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.CsvSource;
@ -1804,7 +1770,9 @@ class CalculatorTest {
calc.add(a, b);
}
}
""", encoding="utf-8")
""",
encoding="utf-8",
)
source_functions = [make_func("add", "Calculator")]
result = discover_tests(tmp_path, source_functions, analyzer)
@ -1815,7 +1783,8 @@ class CalculatorTest {
deep_dir = tmp_path / "test" / "com" / "example" / "deep"
deep_dir.mkdir(parents=True)
(deep_dir / "NestedTest.java").write_text("""\
(deep_dir / "NestedTest.java").write_text(
"""\
package com.example.deep;
import org.junit.jupiter.api.Test;
@ -1826,7 +1795,9 @@ class NestedTest {
calc.add(1, 2);
}
}
""", encoding="utf-8")
""",
encoding="utf-8",
)
source_functions = [make_func("add", "Calculator")]
result = discover_tests(tmp_path, source_functions, analyzer)
@ -1836,7 +1807,8 @@ class NestedTest {
test_dir = tmp_path / "test"
test_dir.mkdir(parents=True)
(test_dir / "CalculatorTest.java").write_text("""\
(test_dir / "CalculatorTest.java").write_text(
"""\
package com.example;
import org.junit.jupiter.api.Test;
@ -1847,7 +1819,9 @@ class CalculatorTest {
calc.add(1, 2);
}
}
""", encoding="utf-8")
""",
encoding="utf-8",
)
source_functions = [make_func("add", "Calculator")]
result = discover_tests(tmp_path, source_functions, analyzer)
@ -1857,7 +1831,8 @@ class CalculatorTest {
test_dir = tmp_path / "test"
test_dir.mkdir(parents=True)
(test_dir / "CalculatorTest.java").write_text("""\
(test_dir / "CalculatorTest.java").write_text(
"""\
package com.example;
import org.junit.jupiter.api.Test;
@ -1868,7 +1843,9 @@ class CalculatorTest {
calc.add(1, 2);
}
}
""", encoding="utf-8")
""",
encoding="utf-8",
)
result = discover_tests(tmp_path, [], analyzer)
assert result == {}
@ -1878,7 +1855,8 @@ class CalculatorTest {
test_dir = tmp_path / "test"
test_dir.mkdir(parents=True)
(test_dir / "BatchReadTest.java").write_text("""\
(test_dir / "BatchReadTest.java").write_text(
"""\
package com.aerospike.test;
import com.aerospike.client.BatchRead;
import com.aerospike.client.Key;
@ -1892,7 +1870,9 @@ class BatchReadTest {
records.add(new BatchRead(new Key("ns", "set", "k2"), false));
}
}
""", encoding="utf-8")
""",
encoding="utf-8",
)
source_functions = [
make_func("BatchRead", "BatchRead"),
@ -1988,7 +1968,8 @@ class TestFindTestsForFunction:
test_dir = tmp_path / "test"
test_dir.mkdir(parents=True)
(test_dir / "CalculatorTest.java").write_text("""\
(test_dir / "CalculatorTest.java").write_text(
"""\
package com.example;
import org.junit.jupiter.api.Test;
@ -1999,7 +1980,9 @@ class CalculatorTest {
calc.add(1, 2);
}
}
""", encoding="utf-8")
""",
encoding="utf-8",
)
func = make_func("add", "Calculator")
result = find_tests_for_function(func, tmp_path, analyzer)
@ -2020,7 +2003,8 @@ class TestDiscoverAllTests:
test_dir = tmp_path / "test"
test_dir.mkdir(parents=True)
(test_dir / "CalculatorTest.java").write_text("""\
(test_dir / "CalculatorTest.java").write_text(
"""\
package com.example;
import org.junit.jupiter.api.Test;
@ -2031,7 +2015,9 @@ class CalculatorTest {
@Test
void testSubtract() {}
}
""", encoding="utf-8")
""",
encoding="utf-8",
)
all_tests = discover_all_tests(tmp_path, analyzer)
names = {t.function_name for t in all_tests}
@ -2048,32 +2034,40 @@ class CalculatorTest {
test_dir = tmp_path / "test"
test_dir.mkdir(parents=True)
(test_dir / "ATest.java").write_text("""\
(test_dir / "ATest.java").write_text(
"""\
import org.junit.jupiter.api.Test;
class ATest {
@Test
void testA() {}
}
""", encoding="utf-8")
""",
encoding="utf-8",
)
(test_dir / "BTest.java").write_text("""\
(test_dir / "BTest.java").write_text(
"""\
import org.junit.jupiter.api.Test;
class BTest {
@Test
void testB() {}
}
""", encoding="utf-8")
""",
encoding="utf-8",
)
all_tests = discover_all_tests(tmp_path, analyzer)
names = {t.function_name for t in all_tests}
assert names == {"testA", "testB"}
def test_no_false_positive_import_only_integration(self, tmp_path, analyzer):
"""A test file that imports Calculator but never calls its methods should not match."""
test_dir = tmp_path / "test"
test_dir.mkdir(parents=True)
test_file = test_dir / "SomeTest.java"
test_file.write_text("""\
test_file.write_text(
"""\
package com.example;
import com.example.Calculator;
@ -2085,12 +2079,11 @@ class SomeTest {
int x = 42;
}
}
""", encoding="utf-8")
""",
encoding="utf-8",
)
source_functions = [
make_func("add", "Calculator"),
make_func("subtract", "Calculator"),
]
source_functions = [make_func("add", "Calculator"), make_func("subtract", "Calculator")]
result = discover_tests(tmp_path, source_functions, analyzer)
assert result == {}
@ -2099,7 +2092,8 @@ class SomeTest {
test_dir = tmp_path / "test"
test_dir.mkdir(parents=True)
(test_dir / "CalculatorTest.java").write_text("""\
(test_dir / "CalculatorTest.java").write_text(
"""\
package com.example;
import org.junit.jupiter.api.Test;
@ -2110,9 +2104,12 @@ class CalculatorTest {
calc.add(1, 2);
}
}
""", encoding="utf-8")
""",
encoding="utf-8",
)
(test_dir / "BufferTest.java").write_text("""\
(test_dir / "BufferTest.java").write_text(
"""\
package com.example;
import org.junit.jupiter.api.Test;
@ -2123,13 +2120,11 @@ class BufferTest {
buf.read();
}
}
""", encoding="utf-8")
""",
encoding="utf-8",
)
source_functions = [
make_func("add", "Calculator"),
make_func("read", "Buffer"),
make_func("write", "Buffer"),
]
source_functions = [make_func("add", "Calculator"), make_func("read", "Buffer"), make_func("write", "Buffer")]
result = discover_tests(tmp_path, source_functions, analyzer)
@ -2147,7 +2142,8 @@ class BufferTest {
test_dir.mkdir(parents=True)
# This file matches *Test.java pattern
(test_dir / "CalculatorTest.java").write_text("""\
(test_dir / "CalculatorTest.java").write_text(
"""\
package com.example;
import org.junit.jupiter.api.Test;
@ -2158,7 +2154,9 @@ class CalculatorTest {
calc.add(1, 2);
}
}
""", encoding="utf-8")
""",
encoding="utf-8",
)
source_functions = [make_func("add", "Calculator")]
result = discover_tests(tmp_path, source_functions, analyzer)
@ -2171,7 +2169,8 @@ class CalculatorTest {
test_dir = tmp_path / "test"
test_dir.mkdir(parents=True)
(test_dir / "MathUtilsTest.java").write_text("""\
(test_dir / "MathUtilsTest.java").write_text(
"""\
package com.example;
import static com.example.MathUtils.square;
import org.junit.jupiter.api.Test;
@ -2182,12 +2181,11 @@ class MathUtilsTest {
int result = square(5);
}
}
""", encoding="utf-8")
""",
encoding="utf-8",
)
source_functions = [
make_func("square", "MathUtils"),
make_func("cube", "MathUtils"),
]
source_functions = [make_func("square", "MathUtils"), make_func("cube", "MathUtils")]
result = discover_tests(tmp_path, source_functions, analyzer)
@ -2198,7 +2196,8 @@ class MathUtilsTest {
test_dir = tmp_path / "test"
test_dir.mkdir(parents=True)
(test_dir / "CalculatorTest.java").write_text("""\
(test_dir / "CalculatorTest.java").write_text(
"""\
package com.example;
import org.junit.jupiter.api.Test;
@ -2210,7 +2209,9 @@ class CalculatorTest {
int b = calc.multiply(a, 3);
}
}
""", encoding="utf-8")
""",
encoding="utf-8",
)
source_functions = [
make_func("add", "Calculator"),


@ -3,9 +3,10 @@
import subprocess
from pathlib import Path
from unittest.mock import patch
import pytest
from codeflash.languages.java.test_runner import _run_maven_tests, _build_test_filter
from codeflash.languages.java.test_runner import _build_test_filter, _run_maven_tests
from codeflash.models.models import TestFile, TestFiles, TestType
@ -40,15 +41,11 @@ def test_build_test_filter_with_valid_paths():
test_files = TestFiles(
test_files=[
TestFile(
instrumented_behavior_file_path=Path(
"/project/src/test/java/com/example/Test1__perfinstrumented.java"
),
benchmarking_file_path=Path(
"/project/src/test/java/com/example/Test1__perfonlyinstrumented.java"
),
instrumented_behavior_file_path=Path("/project/src/test/java/com/example/Test1__perfinstrumented.java"),
benchmarking_file_path=Path("/project/src/test/java/com/example/Test1__perfonlyinstrumented.java"),
original_file_path=Path("/project/src/test/java/com/example/Test1.java"),
test_type=TestType.EXISTING_UNIT_TEST,
),
)
]
)
@ -71,7 +68,7 @@ def test_run_maven_tests_raises_on_empty_filter():
benchmarking_file_path=None, # Will cause empty filter in performance mode
original_file_path=Path("/tmp/test.java"),
test_type=TestType.EXISTING_UNIT_TEST,
),
)
]
)
@ -99,37 +96,26 @@ def test_run_maven_tests_succeeds_with_valid_filter():
test_files = TestFiles(
test_files=[
TestFile(
instrumented_behavior_file_path=Path(
"/tmp/src/test/java/com/example/Test__perfinstrumented.java"
),
benchmarking_file_path=Path(
"/tmp/src/test/java/com/example/Test__perfonlyinstrumented.java"
),
instrumented_behavior_file_path=Path("/tmp/src/test/java/com/example/Test__perfinstrumented.java"),
benchmarking_file_path=Path("/tmp/src/test/java/com/example/Test__perfonlyinstrumented.java"),
original_file_path=Path("/tmp/src/test/java/com/example/Test.java"),
test_type=TestType.EXISTING_UNIT_TEST,
),
)
]
)
# Mock Maven executable and _run_cmd_kill_pg_on_timeout (which replaced subprocess.run)
with patch("codeflash.languages.java.test_runner.find_maven_executable") as mock_maven, \
patch("codeflash.languages.java.test_runner._run_cmd_kill_pg_on_timeout") as mock_run:
with (
patch("codeflash.languages.java.test_runner.find_maven_executable") as mock_maven,
patch("codeflash.languages.java.test_runner._run_cmd_kill_pg_on_timeout") as mock_run,
):
mock_maven.return_value = "mvn"
mock_run.return_value = subprocess.CompletedProcess(
args=[],
returncode=0,
stdout="Tests run: 1, Failures: 0, Errors: 0, Skipped: 0",
stderr="",
args=[], returncode=0, stdout="Tests run: 1, Failures: 0, Errors: 0, Skipped: 0", stderr=""
)
# Should not raise - filter is valid
result = _run_maven_tests(
project_root,
test_files,
env,
timeout=60,
mode="performance",
)
result = _run_maven_tests(project_root, test_files, env, timeout=60, mode="performance")
# Verify Maven was called with -Dtest parameter
assert mock_run.called

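The hunk above converts backslash-continued `with patch(...)` statements to the parenthesized multi-context-manager form (valid from Python 3.10). A minimal standalone sketch of the same pattern, patching stdlib names for illustration (the project's `find_maven_executable` and `_run_cmd_kill_pg_on_timeout` targets are not reproduced here):

```python
import os
from unittest.mock import patch

# Parenthesized context managers (Python 3.10+) replace the older
# backslash-continuation style when stacking several patches.
with (
    patch("os.getcwd") as mock_cwd,
    patch("os.cpu_count") as mock_cpus,
):
    mock_cwd.return_value = "/project"
    mock_cpus.return_value = 8
    cwd = os.getcwd()       # mocked while the patches are active
    cpus = os.cpu_count()

assert cwd == "/project"
assert cpus == 8
```

Both forms compile to the same bytecode; the parenthesized variant is simply easier to extend and is what formatters such as ruff emit.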

@@ -47,8 +47,7 @@ def test_java_tests_project_rootdir_set_to_tests_root(tmp_path):
# Verify that tests_project_rootdir was updated to tests_root
assert test_cfg.tests_project_rootdir == tests_root, (
f"Expected tests_project_rootdir to be {tests_root}, "
f"but got {test_cfg.tests_project_rootdir}"
f"Expected tests_project_rootdir to be {tests_root}, but got {test_cfg.tests_project_rootdir}"
)
@@ -68,9 +67,7 @@ def test_python_tests_project_rootdir_unchanged(tmp_path):
# Create test config
original_tests_project_rootdir = project_root / "some" / "other" / "dir"
test_cfg = TestConfig(
tests_root=tests_root,
project_root_path=project_root,
tests_project_rootdir=original_tests_project_rootdir,
tests_root=tests_root, project_root_path=project_root, tests_project_rootdir=original_tests_project_rootdir
)
# Mock pytest discovery


@@ -17,10 +17,7 @@ def make_func(name: str, class_name: str | None = None) -> FunctionToOptimize:
"""Helper to create FunctionToOptimize for testing."""
parents = [FunctionParent(name=class_name, type="ClassDef")] if class_name else []
return FunctionToOptimize(
function_name=name,
file_path=Path("/test/file.js"),
parents=parents,
language="javascript",
function_name=name, file_path=Path("/test/file.js"), parents=parents, language="javascript"
)
@@ -458,7 +455,9 @@ class TestQualifiedNames:
def test_simple_qualified_name(self) -> None:
"""Test simple qualified name."""
code = "expect(func(5)).toBe(5);"
result, _ = transform_expect_calls(code, make_func("func", class_name="module"), "capture", remove_assertions=True)
result, _ = transform_expect_calls(
code, make_func("func", class_name="module"), "capture", remove_assertions=True
)
assert result == "codeflash.capture('module.func', '1', func, 5);"
def test_nested_qualified_name(self) -> None:


@@ -26,8 +26,8 @@ import pytest
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.languages.base import Language
from codeflash.languages.javascript.support import JavaScriptSupport, TypeScriptSupport
from codeflash.languages.javascript.function_optimizer import JavaScriptFunctionOptimizer
from codeflash.languages.javascript.support import JavaScriptSupport, TypeScriptSupport
from codeflash.verification.verification_utils import TestConfig
@@ -1840,7 +1840,9 @@ export const sendSlackMessage = async (
test_config = TestConfig(
tests_root=temp_project, tests_project_rootdir=temp_project, project_root_path=temp_project
)
func_optimizer = JavaScriptFunctionOptimizer(function_to_optimize=fto, test_cfg=test_config, aiservice_client=MagicMock())
func_optimizer = JavaScriptFunctionOptimizer(
function_to_optimize=fto, test_cfg=test_config, aiservice_client=MagicMock()
)
ctx = func_optimizer.get_code_optimization_context().unwrap()
# The read_writable_code should contain the target function AND helper functions


@@ -10,18 +10,19 @@ Each test verifies:
from __future__ import annotations
import pytest
from pathlib import Path
import pytest
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.languages.base import Language, ReferenceInfo
from codeflash.languages.javascript.find_references import (
ExportedFunction,
Reference,
ReferenceFinder,
ExportedFunction,
ReferenceSearchContext,
find_references,
)
from codeflash.languages.base import Language, FunctionInfo, ReferenceInfo
from codeflash.languages.python.static_analysis.code_extractor import _format_references_as_markdown
from codeflash.models.models import FunctionParent
@@ -29,12 +30,7 @@ from codeflash.models.models import FunctionParent
def make_func(name: str, file_path: Path, class_name: str | None = None) -> FunctionToOptimize:
"""Helper to create FunctionToOptimize for testing."""
parents = [FunctionParent(name=class_name, type="ClassDef")] if class_name else []
return FunctionToOptimize(
function_name=name,
file_path=file_path,
parents=parents,
language="javascript",
)
return FunctionToOptimize(function_name=name, file_path=file_path, parents=parents, language="javascript")
class TestReferenceFinder:
@@ -93,30 +89,30 @@ class TestBasicNamedExports:
# Source file with named export
(utils_dir / "DynamicBindingUtils.ts").write_text(
'export function getDynamicBindings(value: string): string[] {\n'
' const regex = /{{([^}]+)}}/g;\n'
' return [];\n'
'}\n'
"export function getDynamicBindings(value: string): string[] {\n"
" const regex = /{{([^}]+)}}/g;\n"
" return [];\n"
"}\n"
)
# File that imports and uses the function
(src_dir / "evaluator.ts").write_text(
"import { getDynamicBindings } from './utils/DynamicBindingUtils';\n"
'\n'
'export function evaluate(expression: string) {\n'
' const bindings = getDynamicBindings(expression);\n'
' return bindings;\n'
'}\n'
"\n"
"export function evaluate(expression: string) {\n"
" const bindings = getDynamicBindings(expression);\n"
" return bindings;\n"
"}\n"
)
# Another file that uses the function
(src_dir / "validator.ts").write_text(
"import { getDynamicBindings } from './utils/DynamicBindingUtils';\n"
'\n'
'export function validateBindings(input: string) {\n'
' const bindings = getDynamicBindings(input);\n'
' return bindings.length > 0;\n'
'}\n'
"\n"
"export function validateBindings(input: string) {\n"
" const bindings = getDynamicBindings(input);\n"
" return bindings.length > 0;\n"
"}\n"
)
return tmp_path
@@ -158,7 +154,8 @@ class TestBasicNamedExports:
refs = finder.find_references(make_func("getDynamicBindings", source_file))
# Convert to ReferenceInfo and sort for consistent ordering
ref_infos = sorted([
ref_infos = sorted(
[
ReferenceInfo(
file_path=r.file_path,
line=r.line,
@@ -171,23 +168,25 @@ class TestBasicNamedExports:
caller_function=r.caller_function,
)
for r in refs
], key=lambda r: str(r.file_path))
],
key=lambda r: str(r.file_path),
)
markdown = _format_references_as_markdown(ref_infos, source_file, project_root, Language.TYPESCRIPT)
expected_markdown = (
'```typescript:src/evaluator.ts\n'
'function evaluate(expression: string) {\n'
' const bindings = getDynamicBindings(expression);\n'
' return bindings;\n'
'}\n'
'```\n'
'```typescript:src/validator.ts\n'
'function validateBindings(input: string) {\n'
' const bindings = getDynamicBindings(input);\n'
' return bindings.length > 0;\n'
'}\n'
'```\n'
"```typescript:src/evaluator.ts\n"
"function evaluate(expression: string) {\n"
" const bindings = getDynamicBindings(expression);\n"
" return bindings;\n"
"}\n"
"```\n"
"```typescript:src/validator.ts\n"
"function validateBindings(input: string) {\n"
" const bindings = getDynamicBindings(input);\n"
" return bindings.length > 0;\n"
"}\n"
"```\n"
)
assert markdown == expected_markdown
@@ -203,30 +202,30 @@ class TestDefaultExports:
# Source file with default export
(src_dir / "helper.ts").write_text(
'function processData(data: any[]) {\n'
' return data.filter(item => item.active);\n'
'}\n'
'\n'
'export default processData;\n'
"function processData(data: any[]) {\n"
" return data.filter(item => item.active);\n"
"}\n"
"\n"
"export default processData;\n"
)
# File that imports the default export
(src_dir / "main.ts").write_text(
"import processData from './helper';\n"
'\n'
'export function handleData(items: any[]) {\n'
' const processed = processData(items);\n'
' return processed.length;\n'
'}\n'
"\n"
"export function handleData(items: any[]) {\n"
" const processed = processData(items);\n"
" return processed.length;\n"
"}\n"
)
# File that imports with a different name
(src_dir / "alternative.ts").write_text(
"import myProcessor from './helper';\n"
'\n'
'export function process(items: any[]) {\n'
' return myProcessor(items);\n'
'}\n'
"\n"
"export function process(items: any[]) {\n"
" return myProcessor(items);\n"
"}\n"
)
return tmp_path
@@ -263,30 +262,38 @@ class TestDefaultExports:
source_file = project_root / "src" / "helper.ts"
refs = finder.find_references(make_func("processData", source_file))
ref_infos = sorted([
ref_infos = sorted(
[
ReferenceInfo(
file_path=r.file_path, line=r.line, column=r.column,
end_line=r.end_line, end_column=r.end_column, context=r.context,
reference_type=r.reference_type, import_name=r.import_name,
file_path=r.file_path,
line=r.line,
column=r.column,
end_line=r.end_line,
end_column=r.end_column,
context=r.context,
reference_type=r.reference_type,
import_name=r.import_name,
caller_function=r.caller_function,
)
for r in refs
], key=lambda r: str(r.file_path))
],
key=lambda r: str(r.file_path),
)
markdown = _format_references_as_markdown(ref_infos, source_file, project_root, Language.TYPESCRIPT)
expected_markdown = (
'```typescript:src/alternative.ts\n'
'function process(items: any[]) {\n'
' return myProcessor(items);\n'
'}\n'
'```\n'
'```typescript:src/main.ts\n'
'function handleData(items: any[]) {\n'
' const processed = processData(items);\n'
' return processed.length;\n'
'}\n'
'```\n'
"```typescript:src/alternative.ts\n"
"function process(items: any[]) {\n"
" return myProcessor(items);\n"
"}\n"
"```\n"
"```typescript:src/main.ts\n"
"function handleData(items: any[]) {\n"
" const processed = processData(items);\n"
" return processed.length;\n"
"}\n"
"```\n"
)
assert markdown == expected_markdown
@@ -304,23 +311,21 @@ class TestReExports:
# Original function file
(utils_dir / "filterUtils.ts").write_text(
'export function filterBySearchTerm(items: any[], term: string) {\n'
' return items.filter(i => i.name.includes(term));\n'
'}\n'
"export function filterBySearchTerm(items: any[], term: string) {\n"
" return items.filter(i => i.name.includes(term));\n"
"}\n"
)
# Index file that re-exports
(utils_dir / "index.ts").write_text(
"export { filterBySearchTerm } from './filterUtils';\n"
)
(utils_dir / "index.ts").write_text("export { filterBySearchTerm } from './filterUtils';\n")
# Consumer that imports from index
(src_dir / "consumer.ts").write_text(
"import { filterBySearchTerm } from './utils';\n"
'\n'
'export function searchItems(items: any[], query: string) {\n'
' return filterBySearchTerm(items, query);\n'
'}\n'
"\n"
"export function searchItems(items: any[], query: string) {\n"
" return filterBySearchTerm(items, query);\n"
"}\n"
)
return tmp_path
@@ -352,27 +357,35 @@ class TestReExports:
source_file = project_root / "src" / "utils" / "filterUtils.ts"
refs = finder.find_references(make_func("filterBySearchTerm", source_file))
ref_infos = sorted([
ref_infos = sorted(
[
ReferenceInfo(
file_path=r.file_path, line=r.line, column=r.column,
end_line=r.end_line, end_column=r.end_column, context=r.context,
reference_type=r.reference_type, import_name=r.import_name,
file_path=r.file_path,
line=r.line,
column=r.column,
end_line=r.end_line,
end_column=r.end_column,
context=r.context,
reference_type=r.reference_type,
import_name=r.import_name,
caller_function=r.caller_function,
)
for r in refs
], key=lambda r: str(r.file_path))
],
key=lambda r: str(r.file_path),
)
markdown = _format_references_as_markdown(ref_infos, source_file, project_root, Language.TYPESCRIPT)
expected_markdown = (
'```typescript:src/consumer.ts\n'
'function searchItems(items: any[], query: string) {\n'
' return filterBySearchTerm(items, query);\n'
'}\n'
'```\n'
'```typescript:src/utils/index.ts\n'
"```typescript:src/consumer.ts\n"
"function searchItems(items: any[], query: string) {\n"
" return filterBySearchTerm(items, query);\n"
"}\n"
"```\n"
"```typescript:src/utils/index.ts\n"
"export { filterBySearchTerm } from './filterUtils';\n"
'```\n'
"```\n"
)
assert markdown == expected_markdown
@@ -388,19 +401,17 @@ class TestCallbackPatterns:
# Helper function
(src_dir / "transforms.ts").write_text(
'export function normalizeItem(item: any) {\n'
' return { ...item, normalized: true };\n'
'}\n'
"export function normalizeItem(item: any) {\n return { ...item, normalized: true };\n}\n"
)
# Consumer using callbacks
(src_dir / "processor.ts").write_text(
"import { normalizeItem } from './transforms';\n"
'\n'
'export function processItems(items: any[]) {\n'
' const normalized = items.map(normalizeItem);\n'
' return normalized;\n'
'}\n'
"\n"
"export function processItems(items: any[]) {\n"
" const normalized = items.map(normalizeItem);\n"
" return normalized;\n"
"}\n"
)
return tmp_path
@@ -430,9 +441,14 @@ class TestCallbackPatterns:
refs = finder.find_references(make_func("normalizeItem", source_file))
ref_infos = [
ReferenceInfo(
file_path=r.file_path, line=r.line, column=r.column,
end_line=r.end_line, end_column=r.end_column, context=r.context,
reference_type=r.reference_type, import_name=r.import_name,
file_path=r.file_path,
line=r.line,
column=r.column,
end_line=r.end_line,
end_column=r.end_column,
context=r.context,
reference_type=r.reference_type,
import_name=r.import_name,
caller_function=r.caller_function,
)
for r in refs
@@ -441,12 +457,12 @@ class TestCallbackPatterns:
markdown = _format_references_as_markdown(ref_infos, source_file, project_root, Language.TYPESCRIPT)
expected_markdown = (
'```typescript:src/processor.ts\n'
'function processItems(items: any[]) {\n'
' const normalized = items.map(normalizeItem);\n'
' return normalized;\n'
'}\n'
'```\n'
"```typescript:src/processor.ts\n"
"function processItems(items: any[]) {\n"
" const normalized = items.map(normalizeItem);\n"
" return normalized;\n"
"}\n"
"```\n"
)
assert expected_markdown == markdown
@@ -462,19 +478,17 @@ class TestAliasImports:
# Source file
(src_dir / "utils.ts").write_text(
'export function computeValue(input: number): number {\n'
' return input * 2;\n'
'}\n'
"export function computeValue(input: number): number {\n return input * 2;\n}\n"
)
# File using alias
(src_dir / "consumer.ts").write_text(
"import { computeValue as calculate } from './utils';\n"
'\n'
'export function processNumber(n: number) {\n'
' const result = calculate(n);\n'
' return result + 10;\n'
'}\n'
"\n"
"export function processNumber(n: number) {\n"
" const result = calculate(n);\n"
" return result + 10;\n"
"}\n"
)
return tmp_path
@@ -504,9 +518,14 @@ class TestAliasImports:
refs = finder.find_references(make_func("computeValue", source_file))
ref_infos = [
ReferenceInfo(
file_path=r.file_path, line=r.line, column=r.column,
end_line=r.end_line, end_column=r.end_column, context=r.context,
reference_type=r.reference_type, import_name=r.import_name,
file_path=r.file_path,
line=r.line,
column=r.column,
end_line=r.end_line,
end_column=r.end_column,
context=r.context,
reference_type=r.reference_type,
import_name=r.import_name,
caller_function=r.caller_function,
)
for r in refs
@@ -515,12 +534,12 @@ class TestAliasImports:
markdown = _format_references_as_markdown(ref_infos, source_file, project_root, Language.TYPESCRIPT)
expected_markdown = (
'```typescript:src/consumer.ts\n'
'function processNumber(n: number) {\n'
' const result = calculate(n);\n'
' return result + 10;\n'
'}\n'
'```\n'
"```typescript:src/consumer.ts\n"
"function processNumber(n: number) {\n"
" const result = calculate(n);\n"
" return result + 10;\n"
"}\n"
"```\n"
)
assert expected_markdown == markdown
@@ -536,18 +555,16 @@ class TestNamespaceImports:
# Source file with multiple exports
(src_dir / "mathUtils.ts").write_text(
'export function add(a: number, b: number): number {\n'
' return a + b;\n'
'}\n'
"export function add(a: number, b: number): number {\n return a + b;\n}\n"
)
# File using namespace import
(src_dir / "calculator.ts").write_text(
"import * as MathUtils from './mathUtils';\n"
'\n'
'export function calculate(a: number, b: number) {\n'
' return MathUtils.add(a, b);\n'
'}\n'
"\n"
"export function calculate(a: number, b: number) {\n"
" return MathUtils.add(a, b);\n"
"}\n"
)
return tmp_path
@@ -576,9 +593,14 @@ class TestNamespaceImports:
refs = finder.find_references(make_func("add", source_file))
ref_infos = [
ReferenceInfo(
file_path=r.file_path, line=r.line, column=r.column,
end_line=r.end_line, end_column=r.end_column, context=r.context,
reference_type=r.reference_type, import_name=r.import_name,
file_path=r.file_path,
line=r.line,
column=r.column,
end_line=r.end_line,
end_column=r.end_column,
context=r.context,
reference_type=r.reference_type,
import_name=r.import_name,
caller_function=r.caller_function,
)
for r in refs
@@ -587,11 +609,11 @@ class TestNamespaceImports:
markdown = _format_references_as_markdown(ref_infos, source_file, project_root, Language.TYPESCRIPT)
expected_markdown = (
'```typescript:src/calculator.ts\n'
'function calculate(a: number, b: number) {\n'
' return MathUtils.add(a, b);\n'
'}\n'
'```\n'
"```typescript:src/calculator.ts\n"
"function calculate(a: number, b: number) {\n"
" return MathUtils.add(a, b);\n"
"}\n"
"```\n"
)
assert expected_markdown == markdown
@@ -607,21 +629,19 @@ class TestMemoizedFunctions:
# Source file with function to be memoized
(src_dir / "expensive.ts").write_text(
'export function computeExpensive(x: number): number {\n'
' return x * x;\n'
'}\n'
"export function computeExpensive(x: number): number {\n return x * x;\n}\n"
)
# File that memoizes the function
(src_dir / "memoized.ts").write_text(
"import memoize from 'micro-memoize';\n"
"import { computeExpensive } from './expensive';\n"
'\n'
'export const memoizedCompute = memoize(computeExpensive);\n'
'\n'
'export function process(x: number) {\n'
' return computeExpensive(x) + memoizedCompute(x);\n'
'}\n'
"\n"
"export const memoizedCompute = memoize(computeExpensive);\n"
"\n"
"export function process(x: number) {\n"
" return computeExpensive(x) + memoizedCompute(x);\n"
"}\n"
)
return tmp_path
@@ -659,10 +679,10 @@ class TestSameFileReferences:
# File with internal references
(src_dir / "recursive.ts").write_text(
'export function factorial(n: number): number {\n'
' if (n <= 1) return 1;\n'
' return n * factorial(n - 1);\n'
'}\n'
"export function factorial(n: number): number {\n"
" if (n <= 1) return 1;\n"
" return n * factorial(n - 1);\n"
"}\n"
)
return tmp_path
@@ -697,24 +717,20 @@ class TestComplexMultiFileScenarios:
# Core utility function
(src_dir / "utils" / "widgetUtils.ts").write_text(
'export function isLargeWidget(type: string): boolean {\n'
" return ['TABLE', 'LIST'].includes(type);\n"
'}\n'
"export function isLargeWidget(type: string): boolean {\n return ['TABLE', 'LIST'].includes(type);\n}\n"
)
# Re-export from index
(src_dir / "utils" / "index.ts").write_text(
"export { isLargeWidget } from './widgetUtils';\n"
)
(src_dir / "utils" / "index.ts").write_text("export { isLargeWidget } from './widgetUtils';\n")
# Component using the function via re-export
(src_dir / "components" / "Widget.tsx").write_text(
"import { isLargeWidget } from '../utils';\n"
'\n'
'export function Widget({ type }: { type: string }) {\n'
' const isLarge = isLargeWidget(type);\n'
' return isLarge;\n'
'}\n'
"\n"
"export function Widget({ type }: { type: string }) {\n"
" const isLarge = isLargeWidget(type);\n"
" return isLarge;\n"
"}\n"
)
return tmp_path
@@ -745,28 +761,36 @@ class TestComplexMultiFileScenarios:
source_file = project_root / "src" / "utils" / "widgetUtils.ts"
refs = finder.find_references(make_func("isLargeWidget", source_file))
ref_infos = sorted([
ref_infos = sorted(
[
ReferenceInfo(
file_path=r.file_path, line=r.line, column=r.column,
end_line=r.end_line, end_column=r.end_column, context=r.context,
reference_type=r.reference_type, import_name=r.import_name,
file_path=r.file_path,
line=r.line,
column=r.column,
end_line=r.end_line,
end_column=r.end_column,
context=r.context,
reference_type=r.reference_type,
import_name=r.import_name,
caller_function=r.caller_function,
)
for r in refs
], key=lambda r: str(r.file_path))
],
key=lambda r: str(r.file_path),
)
markdown = _format_references_as_markdown(ref_infos, source_file, project_root, Language.TYPESCRIPT)
expected_markdown = (
'```typescript:src/components/Widget.tsx\n'
'function Widget({ type }: { type: string }) {\n'
' const isLarge = isLargeWidget(type);\n'
' return isLarge;\n'
'}\n'
'```\n'
'```typescript:src/utils/index.ts\n'
"```typescript:src/components/Widget.tsx\n"
"function Widget({ type }: { type: string }) {\n"
" const isLarge = isLargeWidget(type);\n"
" return isLarge;\n"
"}\n"
"```\n"
"```typescript:src/utils/index.ts\n"
"export { isLargeWidget } from './widgetUtils';\n"
'```\n'
"```\n"
)
assert markdown == expected_markdown
@@ -794,13 +818,13 @@ class TestEdgeCases:
"""Test handling of non-exported function."""
# Create a file with non-exported function
(project_root / "src" / "private.ts").write_text(
'function internalHelper() {\n'
' return 42;\n'
'}\n'
'\n'
'export function publicFunction() {\n'
' return internalHelper();\n'
'}\n'
"function internalHelper() {\n"
" return 42;\n"
"}\n"
"\n"
"export function publicFunction() {\n"
" return internalHelper();\n"
"}\n"
)
finder = ReferenceFinder(project_root)
@@ -824,7 +848,9 @@ class TestEdgeCases:
def test_format_references_empty_list(self, project_root):
"""Test _format_references_as_markdown with empty list."""
markdown = _format_references_as_markdown([], project_root / "src" / "file.ts", project_root, Language.TYPESCRIPT)
markdown = _format_references_as_markdown(
[], project_root / "src" / "file.ts", project_root, Language.TYPESCRIPT
)
assert markdown == ""
@@ -839,22 +865,22 @@ class TestCommonJSPatterns:
# CommonJS module
(src_dir / "helpers.js").write_text(
'function processConfig(config) {\n'
' return { ...config, processed: true };\n'
'}\n'
'\n'
'module.exports = { processConfig };\n'
"function processConfig(config) {\n"
" return { ...config, processed: true };\n"
"}\n"
"\n"
"module.exports = { processConfig };\n"
)
# Consumer using destructured require
(src_dir / "main.js").write_text(
"const { processConfig } = require('./helpers');\n"
'\n'
'function handleConfig(config) {\n'
' return processConfig(config);\n'
'}\n'
'\n'
'module.exports = handleConfig;\n'
"\n"
"function handleConfig(config) {\n"
" return processConfig(config);\n"
"}\n"
"\n"
"module.exports = handleConfig;\n"
)
return tmp_path
@@ -879,24 +905,28 @@ class TestCommonJSPatterns:
source_file = project_root / "src" / "helpers.js"
refs = finder.find_references(make_func("processConfig", source_file))
ref_infos = sorted([
ref_infos = sorted(
[
ReferenceInfo(
file_path=r.file_path, line=r.line, column=r.column,
end_line=r.end_line, end_column=r.end_column, context=r.context,
reference_type=r.reference_type, import_name=r.import_name,
file_path=r.file_path,
line=r.line,
column=r.column,
end_line=r.end_line,
end_column=r.end_column,
context=r.context,
reference_type=r.reference_type,
import_name=r.import_name,
caller_function=r.caller_function,
)
for r in refs
], key=lambda r: str(r.file_path))
],
key=lambda r: str(r.file_path),
)
markdown = _format_references_as_markdown(ref_infos, source_file, project_root, Language.JAVASCRIPT)
expected_markdown = (
'```javascript:src/main.js\n'
'function handleConfig(config) {\n'
' return processConfig(config);\n'
'}\n'
'```\n'
"```javascript:src/main.js\nfunction handleConfig(config) {\n return processConfig(config);\n}\n```\n"
)
assert markdown == expected_markdown
@@ -910,18 +940,10 @@ class TestConvenienceFunction:
src_dir = tmp_path / "src"
src_dir.mkdir()
(src_dir / "utils.ts").write_text(
'export function helper() {\n'
' return 42;\n'
'}\n'
)
(src_dir / "utils.ts").write_text("export function helper() {\n return 42;\n}\n")
(src_dir / "main.ts").write_text(
"import { helper } from './utils';\n"
'\n'
'export function main() {\n'
' return helper();\n'
'}\n'
"import { helper } from './utils';\n\nexport function main() {\n return helper();\n}\n"
)
return tmp_path
@@ -988,10 +1010,7 @@ class TestExportedFunctionDataclass:
def test_exported_function_named(self, tmp_path):
"""Test ExportedFunction for named export."""
exp = ExportedFunction(
function_name="myHelper",
export_name="myHelper",
is_default=False,
file_path=tmp_path / "utils.ts",
function_name="myHelper", export_name="myHelper", is_default=False, file_path=tmp_path / "utils.ts"
)
assert exp.function_name == "myHelper"
@@ -1002,10 +1021,7 @@ class TestExportedFunctionDataclass:
def test_exported_function_default(self, tmp_path):
"""Test ExportedFunction for default export."""
exp = ExportedFunction(
function_name="processData",
export_name="default",
is_default=True,
file_path=tmp_path / "processor.ts",
function_name="processData", export_name="default", is_default=True, file_path=tmp_path / "processor.ts"
)
assert exp.function_name == "processData"
@@ -1046,23 +1062,19 @@ class TestEdgeCasesAdvanced:
# Create circular import structure
(src_dir / "a.ts").write_text(
"import { funcB } from './b';\n"
'\n'
'export function funcA() {\n'
' return funcB() + 1;\n'
'}\n'
"import { funcB } from './b';\n\nexport function funcA() {\n return funcB() + 1;\n}\n"
)
(src_dir / "b.ts").write_text(
"import { funcA } from './a';\n"
'\n'
'export function funcB() {\n'
' return 42;\n'
'}\n'
'\n'
'export function callsA() {\n'
' return funcA();\n'
'}\n'
"\n"
"export function funcB() {\n"
" return 42;\n"
"}\n"
"\n"
"export function callsA() {\n"
" return funcA();\n"
"}\n"
)
finder = ReferenceFinder(project_root)
@@ -1080,19 +1092,11 @@ class TestEdgeCasesAdvanced:
"""Test that syntax errors in files are handled gracefully."""
src_dir = project_root / "src"
(src_dir / "valid.ts").write_text(
'export function validFunction() {\n'
' return 42;\n'
'}\n'
)
(src_dir / "valid.ts").write_text("export function validFunction() {\n return 42;\n}\n")
# Create a file with syntax error
(src_dir / "invalid.ts").write_text(
"import { validFunction } from './valid';\n"
'\n'
'export function broken( {\n'
' return validFunction(\n'
'}\n'
"import { validFunction } from './valid';\n\nexport function broken( {\n return validFunction(\n}\n"
)
finder = ReferenceFinder(project_root)

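The `expected_markdown` values in the tests above pin down the output shape of `_format_references_as_markdown`: one fenced block per referencing file, tagged `language:relative/path`. A small sketch that produces the same shape (`format_refs_as_markdown` is a hypothetical stand-in, not the project's implementation):

```python
from pathlib import Path

def format_refs_as_markdown(refs, project_root, lang):
    # Hypothetical stand-in: emit one "lang:path"-tagged fenced block
    # per (file_path, snippet) reference, paths relative to the root.
    chunks = []
    for file_path, snippet in refs:
        rel = Path(file_path).relative_to(project_root).as_posix()
        chunks.append(f"```{lang}:{rel}\n{snippet}\n```\n")
    return "".join(chunks)

md = format_refs_as_markdown(
    [("/p/src/main.js", "function handleConfig(config) {\n  return processConfig(config);\n}")],
    "/p",
    "javascript",
)
assert md.startswith("```javascript:src/main.js\n")
assert md.endswith("}\n```\n")
```

Sorting references by `str(file_path)` before formatting, as the tests do, is what makes the concatenated blocks deterministic across filesystems.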

@@ -4,7 +4,6 @@ These tests verify that the ImportResolver correctly resolves import paths
to actual file paths, enabling multi-file context extraction.
"""
import pytest
from codeflash.languages.javascript.import_resolver import HelperSearchContext, ImportResolver, MultiFileHelperFinder


@@ -278,6 +278,7 @@ version = '1.0.0'
assert len(info.source_roots) == 1
assert len(info.test_roots) == 1
class TestXmlModuleExtraction:
"""Tests for XML-based module extraction replacing regex."""
@@ -374,6 +375,7 @@ class TestMavenProfiles:
profiles = os.environ.get("CODEFLASH_MAVEN_PROFILES", "").strip()
assert profiles == "my-profile"
class TestMavenExecutableWithProjectRoot:
"""Tests for find_maven_executable with project_root parameter."""
@@ -554,7 +556,6 @@ class TestAddCodeflashDependencyToPom:
def test_returns_false_when_no_dependencies_tag(self, tmp_path):
pom = tmp_path / "pom.xml"
pom.write_text(
'<?xml version="1.0"?>\n<project><modelVersion>4.0.0</modelVersion></project>\n',
encoding="utf-8",
'<?xml version="1.0"?>\n<project><modelVersion>4.0.0</modelVersion></project>\n', encoding="utf-8"
)
assert add_codeflash_dependency_to_pom(pom) is False


@@ -6,15 +6,15 @@ proceeding to SQLite file comparison (which would crash with FileNotFoundError s
instrumentation hooks never fired).
"""
from dataclasses import dataclass
from pathlib import Path
from codeflash.either import Failure
from codeflash.models.models import FunctionTestInvocation, InvocationId, TestResults
from codeflash.models.test_type import TestType
def make_test_invocation(*, did_pass: bool, test_type: TestType = TestType.EXISTING_UNIT_TEST) -> FunctionTestInvocation:
def make_test_invocation(
*, did_pass: bool, test_type: TestType = TestType.EXISTING_UNIT_TEST
) -> FunctionTestInvocation:
"""Helper to create a FunctionTestInvocation with minimal required fields."""
return FunctionTestInvocation(
loop_index=1,
@@ -101,7 +101,8 @@ class TestCandidateBehavioralTestGuard:
"""All test types failing should yield 0 total passed."""
results = TestResults()
for tt in [TestType.EXISTING_UNIT_TEST, TestType.GENERATED_REGRESSION, TestType.REPLAY_TEST]:
results.add(FunctionTestInvocation(
results.add(
FunctionTestInvocation(
loop_index=1,
id=InvocationId(
test_module_path="com.example.FooTest",
@@ -117,7 +118,8 @@ class TestCandidateBehavioralTestGuard:
test_type=tt,
return_value=None,
timed_out=False,
))
)
)
report = results.get_test_pass_fail_report_by_type()
total_passed = sum(r.get("passed", 0) for r in report.values())
@@ -129,7 +131,8 @@ class TestCandidateBehavioralTestGuard:
results = TestResults()
# Many failures
for i in range(5):
results.add(FunctionTestInvocation(
results.add(
FunctionTestInvocation(
loop_index=1,
id=InvocationId(
test_module_path="com.example.FooTest",
@@ -145,9 +148,11 @@ class TestCandidateBehavioralTestGuard:
test_type=TestType.GENERATED_REGRESSION,
return_value=None,
timed_out=False,
))
)
)
# One pass
results.add(FunctionTestInvocation(
results.add(
FunctionTestInvocation(
loop_index=1,
id=InvocationId(
test_module_path="com.example.FooTest",
@@ -163,7 +168,8 @@ class TestCandidateBehavioralTestGuard:
test_type=TestType.EXISTING_UNIT_TEST,
return_value=None,
timed_out=False,
))
)
)
report = results.get_test_pass_fail_report_by_type()
total_passed = sum(r.get("passed", 0) for r in report.values())


@@ -1,24 +1,17 @@
"""Tests for Java test result comparison."""
import json
import shutil
import sqlite3
import tempfile
from pathlib import Path
import pytest
from codeflash.languages.java.comparator import (
compare_invocations_directly,
compare_test_results,
values_equal,
)
from codeflash.languages.java.comparator import compare_invocations_directly, compare_test_results, values_equal
from codeflash.models.models import TestDiffScope
# Skip tests that require Java runtime if Java is not available
requires_java = pytest.mark.skipif(
shutil.which("java") is None,
reason="Java not found - skipping Comparator integration tests",
shutil.which("java") is None, reason="Java not found - skipping Comparator integration tests"
)
# Kryo-serialized bytes for common test values.
@@ -38,7 +31,9 @@ KRYO_STR_RESULT3 = bytes([0x03, 0x01, 0x7B, 0x22, 0x72, 0x65, 0x73, 0x75, 0x6C,
KRYO_STR_VALUE1 = bytes([0x03, 0x01, 0x7B, 0x22, 0x76, 0x61, 0x6C, 0x75, 0x65, 0x22, 0x3A, 0x20, 0x31, 0xFD])
KRYO_STR_VALUE2 = bytes([0x03, 0x01, 0x7B, 0x22, 0x76, 0x61, 0x6C, 0x75, 0x65, 0x22, 0x3A, 0x20, 0x32, 0xFD])
KRYO_STR_VALUE42 = bytes([0x03, 0x01, 0x7B, 0x22, 0x76, 0x61, 0x6C, 0x75, 0x65, 0x22, 0x3A, 0x20, 0x34, 0x32, 0xFD])
KRYO_STR_VALUE100 = bytes([0x03, 0x01, 0x7B, 0x22, 0x76, 0x61, 0x6C, 0x75, 0x65, 0x22, 0x3A, 0x20, 0x31, 0x30, 0x30, 0xFD])
KRYO_STR_VALUE100 = bytes(
[0x03, 0x01, 0x7B, 0x22, 0x76, 0x61, 0x6C, 0x75, 0x65, 0x22, 0x3A, 0x20, 0x31, 0x30, 0x30, 0xFD]
)
KRYO_DOUBLE_1_0000000001 = bytes([0x0A, 0x38, 0xDF, 0x06, 0x00, 0x00, 0x00, 0xF0, 0x3F])
KRYO_DOUBLE_1_0000000002 = bytes([0x0A, 0x70, 0xBE, 0x0D, 0x00, 0x00, 0x00, 0xF0, 0x3F])
KRYO_NAN = bytes([0x0A, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xF8, 0x7F])
@@ -67,12 +62,8 @@ class TestDirectComparison:
def test_different_return_values(self):
"""Test detecting different return values."""
original = {
"1": {"result_json": '{"value": 42}', "error_json": None},
}
candidate = {
"1": {"result_json": '{"value": 99}', "error_json": None},
}
original = {"1": {"result_json": '{"value": 42}', "error_json": None}}
candidate = {"1": {"result_json": '{"value": 99}', "error_json": None}}
equivalent, diffs = compare_invocations_directly(original, candidate)
@@ -89,7 +80,7 @@ class TestDirectComparison:
"2": {"result_json": '{"value": 100}', "error_json": None},
}
candidate = {
"1": {"result_json": '{"value": 42}', "error_json": None},
"1": {"result_json": '{"value": 42}', "error_json": None}
# Missing invocation 2
}
@@ -101,9 +92,7 @@ class TestDirectComparison:
def test_extra_invocation_in_candidate(self):
"""Test detecting extra invocation in candidate."""
original = {
"1": {"result_json": '{"value": 42}', "error_json": None},
}
original = {"1": {"result_json": '{"value": 42}', "error_json": None}}
candidate = {
"1": {"result_json": '{"value": 42}', "error_json": None},
"2": {"result_json": '{"value": 100}', "error_json": None}, # Extra
@ -116,11 +105,9 @@ class TestDirectComparison:
def test_exception_differences(self):
"""Test detecting exception differences."""
original = {
"1": {"result_json": None, "error_json": '{"type": "NullPointerException"}'},
}
original = {"1": {"result_json": None, "error_json": '{"type": "NullPointerException"}'}}
candidate = {
"1": {"result_json": '{"value": 42}', "error_json": None}, # No exception
"1": {"result_json": '{"value": 42}', "error_json": None} # No exception
}
equivalent, diffs = compare_invocations_directly(original, candidate)
@ -176,12 +163,8 @@ class TestNumericValueEquality:
def test_numeric_comparison_in_direct_invocation(self):
"""Test that compare_invocations_directly uses numeric-aware comparison."""
original = {
"1": {"result_json": "0", "error_json": None},
}
candidate = {
"1": {"result_json": "0.0", "error_json": None},
}
original = {"1": {"result_json": "0", "error_json": None}}
candidate = {"1": {"result_json": "0.0", "error_json": None}}
equivalent, diffs = compare_invocations_directly(original, candidate)
assert equivalent is True
@ -189,12 +172,8 @@ class TestNumericValueEquality:
def test_integer_long_mismatch_resolved(self):
"""Test that Integer(42) vs Long(42) serialized differently are still equal."""
original = {
"1": {"result_json": "42", "error_json": None},
}
candidate = {
"1": {"result_json": "42.0", "error_json": None},
}
original = {"1": {"result_json": "42", "error_json": None}}
candidate = {"1": {"result_json": "42.0", "error_json": None}}
equivalent, diffs = compare_invocations_directly(original, candidate)
assert equivalent is True
@ -263,46 +242,30 @@ class TestNumericValueEquality:
def test_boolean_invocation_comparison(self):
"""Test boolean return values in full invocation comparison."""
original = {
"1": {"result_json": "true", "error_json": None},
}
candidate = {
"1": {"result_json": "true", "error_json": None},
}
original = {"1": {"result_json": "true", "error_json": None}}
candidate = {"1": {"result_json": "true", "error_json": None}}
equivalent, diffs = compare_invocations_directly(original, candidate)
assert equivalent is True
def test_boolean_mismatch_invocation_comparison(self):
"""Test boolean mismatch is correctly detected."""
original = {
"1": {"result_json": "true", "error_json": None},
}
candidate = {
"1": {"result_json": "false", "error_json": None},
}
original = {"1": {"result_json": "true", "error_json": None}}
candidate = {"1": {"result_json": "false", "error_json": None}}
equivalent, diffs = compare_invocations_directly(original, candidate)
assert equivalent is False
assert len(diffs) == 1
def test_array_invocation_comparison(self):
"""Test array return values in full invocation comparison."""
original = {
"1": {"result_json": "[1, 2, 3]", "error_json": None},
}
candidate = {
"1": {"result_json": "[1, 2, 3]", "error_json": None},
}
original = {"1": {"result_json": "[1, 2, 3]", "error_json": None}}
candidate = {"1": {"result_json": "[1, 2, 3]", "error_json": None}}
equivalent, diffs = compare_invocations_directly(original, candidate)
assert equivalent is True
def test_array_mismatch_invocation_comparison(self):
"""Test array mismatch is correctly detected."""
original = {
"1": {"result_json": "[1, 2, 3]", "error_json": None},
}
candidate = {
"1": {"result_json": "[1, 2, 4]", "error_json": None},
}
original = {"1": {"result_json": "[1, 2, 3]", "error_json": None}}
candidate = {"1": {"result_json": "[1, 2, 4]", "error_json": None}}
equivalent, diffs = compare_invocations_directly(original, candidate)
assert equivalent is False
assert len(diffs) == 1
@ -382,35 +345,25 @@ class TestComparisonWithRealData:
def test_string_result_comparison(self):
"""Test comparing string results."""
original = {
"1": {"result_json": '"Hello World"', "error_json": None},
}
candidate = {
"1": {"result_json": '"Hello World"', "error_json": None},
}
original = {"1": {"result_json": '"Hello World"', "error_json": None}}
candidate = {"1": {"result_json": '"Hello World"', "error_json": None}}
equivalent, diffs = compare_invocations_directly(original, candidate)
assert equivalent is True
def test_array_result_comparison(self):
"""Test comparing array results."""
original = {
"1": {"result_json": "[1, 2, 3, 4, 5]", "error_json": None},
}
candidate = {
"1": {"result_json": "[1, 2, 3, 4, 5]", "error_json": None},
}
original = {"1": {"result_json": "[1, 2, 3, 4, 5]", "error_json": None}}
candidate = {"1": {"result_json": "[1, 2, 3, 4, 5]", "error_json": None}}
equivalent, diffs = compare_invocations_directly(original, candidate)
assert equivalent is True
def test_array_order_matters(self):
"""Test that array order matters for comparison."""
original = {
"1": {"result_json": "[1, 2, 3]", "error_json": None},
}
original = {"1": {"result_json": "[1, 2, 3]", "error_json": None}}
candidate = {
"1": {"result_json": "[3, 2, 1]", "error_json": None}, # Different order
"1": {"result_json": "[3, 2, 1]", "error_json": None} # Different order
}
equivalent, diffs = compare_invocations_directly(original, candidate)
@ -418,24 +371,16 @@ class TestComparisonWithRealData:
def test_object_result_comparison(self):
"""Test comparing object results."""
original = {
"1": {"result_json": '{"name": "John", "age": 30}', "error_json": None},
}
candidate = {
"1": {"result_json": '{"name": "John", "age": 30}', "error_json": None},
}
original = {"1": {"result_json": '{"name": "John", "age": 30}', "error_json": None}}
candidate = {"1": {"result_json": '{"name": "John", "age": 30}', "error_json": None}}
equivalent, diffs = compare_invocations_directly(original, candidate)
assert equivalent is True
def test_null_result(self):
"""Test comparing null results."""
original = {
"1": {"result_json": "null", "error_json": None},
}
candidate = {
"1": {"result_json": "null", "error_json": None},
}
original = {"1": {"result_json": "null", "error_json": None}}
candidate = {"1": {"result_json": "null", "error_json": None}}
equivalent, diffs = compare_invocations_directly(original, candidate)
assert equivalent is True
@ -462,11 +407,9 @@ class TestEdgeCases:
def test_whitespace_in_json(self):
"""Test that whitespace differences in JSON don't cause issues."""
original = {
"1": {"result_json": '{"a":1,"b":2}', "error_json": None},
}
original = {"1": {"result_json": '{"a":1,"b":2}', "error_json": None}}
candidate = {
"1": {"result_json": '{ "a": 1, "b": 2 }', "error_json": None}, # With spaces
"1": {"result_json": '{ "a": 1, "b": 2 }', "error_json": None} # With spaces
}
# Note: Direct string comparison will see these as different
@ -486,12 +429,8 @@ class TestEdgeCases:
def test_unicode_in_results(self):
"""Test handling unicode in results."""
original = {
"1": {"result_json": '"Hello 世界 🌍"', "error_json": None},
}
candidate = {
"1": {"result_json": '"Hello 世界 🌍"', "error_json": None},
}
original = {"1": {"result_json": '"Hello 世界 🌍"', "error_json": None}}
candidate = {"1": {"result_json": '"Hello 世界 🌍"', "error_json": None}}
equivalent, diffs = compare_invocations_directly(original, candidate)
assert equivalent is True
@ -499,12 +438,8 @@ class TestEdgeCases:
def test_deeply_nested_objects(self):
"""Test handling deeply nested objects."""
nested = '{"a": {"b": {"c": {"d": {"e": 1}}}}}'
original = {
"1": {"result_json": nested, "error_json": None},
}
candidate = {
"1": {"result_json": nested, "error_json": None},
}
original = {"1": {"result_json": nested, "error_json": None}}
candidate = {"1": {"result_json": nested, "error_json": None}}
equivalent, diffs = compare_invocations_directly(original, candidate)
assert equivalent is True
@ -573,9 +508,7 @@ class TestTestResultsTableSchema:
return _create
def test_comparator_reads_test_results_table_identical(
self, tmp_path: Path, create_test_results_db
):
def test_comparator_reads_test_results_table_identical(self, tmp_path: Path, create_test_results_db):
"""Test that Comparator correctly reads test_results table with identical results."""
original_path = tmp_path / "original.db"
candidate_path = tmp_path / "candidate.db"
@ -607,9 +540,7 @@ class TestTestResultsTableSchema:
assert equivalent is True
assert len(diffs) == 0
def test_comparator_reads_test_results_table_different_values(
self, tmp_path: Path, create_test_results_db
):
def test_comparator_reads_test_results_table_different_values(self, tmp_path: Path, create_test_results_db):
"""Test that Comparator detects different return values from test_results table."""
original_path = tmp_path / "original.db"
candidate_path = tmp_path / "candidate.db"
@ -621,7 +552,7 @@ class TestTestResultsTableSchema:
"loop_index": 1,
"iteration_id": "1_0",
"return_value": KRYO_STR_OLLEH,
},
}
]
candidate_results = [
@ -631,7 +562,7 @@ class TestTestResultsTableSchema:
"loop_index": 1,
"iteration_id": "1_0",
"return_value": KRYO_STR_WRONG, # Different result
},
}
]
create_test_results_db(original_path, original_results)
@ -644,9 +575,7 @@ class TestTestResultsTableSchema:
assert len(diffs) == 1
assert diffs[0].scope == TestDiffScope.RETURN_VALUE
def test_comparator_handles_multiple_loop_iterations(
self, tmp_path: Path, create_test_results_db
):
def test_comparator_handles_multiple_loop_iterations(self, tmp_path: Path, create_test_results_db):
"""Test that Comparator correctly handles multiple loop iterations."""
original_path = tmp_path / "original.db"
candidate_path = tmp_path / "candidate.db"
@ -676,9 +605,7 @@ class TestTestResultsTableSchema:
assert equivalent is True
assert len(diffs) == 0
def test_comparator_iteration_id_parsing(
self, tmp_path: Path, create_test_results_db
):
def test_comparator_iteration_id_parsing(self, tmp_path: Path, create_test_results_db):
"""Test that Comparator correctly parses iteration_id format 'iter_testIteration'."""
original_path = tmp_path / "original.db"
candidate_path = tmp_path / "candidate.db"
@ -711,32 +638,18 @@ class TestTestResultsTableSchema:
assert equivalent is True
assert len(diffs) == 0
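The `"iter_testIteration"` id format the test above exercises can be sketched as a tiny parser. This is a hypothetical helper for illustration only — the name `parse_iteration_id` and the exact split rule are assumptions, not the comparator's actual code:

```python
def parse_iteration_id(iteration_id: str) -> tuple[int, int]:
    """Split an id like "3_1" into (invocation, test_iteration).

    Hypothetical helper illustrating the "iter_testIteration" shape;
    the real Comparator's parsing may differ.
    """
    iter_part, test_iter = iteration_id.split("_", 1)
    return int(iter_part), int(test_iter)


print(parse_iteration_id("1_0"))  # (1, 0)
```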
def test_comparator_missing_result_in_candidate(
self, tmp_path: Path, create_test_results_db
):
def test_comparator_missing_result_in_candidate(self, tmp_path: Path, create_test_results_db):
"""Test that Comparator detects missing results in candidate."""
original_path = tmp_path / "original.db"
candidate_path = tmp_path / "candidate.db"
original_results = [
{
"loop_index": 1,
"iteration_id": "1_0",
"return_value": KRYO_INT_1,
},
{
"loop_index": 1,
"iteration_id": "2_0",
"return_value": KRYO_INT_2,
},
{"loop_index": 1, "iteration_id": "1_0", "return_value": KRYO_INT_1},
{"loop_index": 1, "iteration_id": "2_0", "return_value": KRYO_INT_2},
]
candidate_results = [
{
"loop_index": 1,
"iteration_id": "1_0",
"return_value": KRYO_INT_1,
},
{"loop_index": 1, "iteration_id": "1_0", "return_value": KRYO_INT_1}
# Missing second iteration
]
@ -779,12 +692,8 @@ class TestComparatorEdgeCases:
For truly different values, the difference must exceed the epsilon threshold.
"""
# These values differ by ~3e-10, which is within epsilon tolerance (1e-9)
original = {
"1": {"result_json": "3.14159", "error_json": None},
}
candidate = {
"1": {"result_json": "3.141590001", "error_json": None},
}
original = {"1": {"result_json": "3.14159", "error_json": None}}
candidate = {"1": {"result_json": "3.141590001", "error_json": None}}
equivalent, diffs = compare_invocations_directly(original, candidate)
assert equivalent is True # Within epsilon tolerance
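A minimal sketch of the epsilon comparison these tests rely on. Treating EPSILON=1e-9 as a *relative* tolerance is an inference from the test values (3.14159 vs 3.141590001 has a relative difference of ~3e-10), not a confirmed detail of the Java Comparator:

```python
import math

EPSILON = 1e-9  # assumed relative tolerance, inferred from the tests


def numerically_equal(a: str, b: str) -> bool:
    # Compare JSON scalar strings numerically when both parse as floats,
    # otherwise fall back to literal string equality.
    try:
        return math.isclose(float(a), float(b), rel_tol=EPSILON, abs_tol=0.0)
    except ValueError:
        return a == b


print(numerically_equal("3.14159", "3.141590001"))  # True: relative diff ~3e-10
print(numerically_equal("3.14159", "3.14160"))      # False: relative diff ~3e-6
```

Note this sketch deliberately ignores the NaN special-casing the comparator also needs (`math.isclose(nan, nan)` is False in Python).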
@ -792,11 +701,9 @@ class TestComparatorEdgeCases:
def test_float_values_significantly_different(self):
"""Float strings outside epsilon tolerance should be detected as different."""
original = {
"1": {"result_json": "3.14159", "error_json": None},
}
original = {"1": {"result_json": "3.14159", "error_json": None}}
candidate = {
"1": {"result_json": "3.14160", "error_json": None}, # Differs by ~1e-5
"1": {"result_json": "3.14160", "error_json": None} # Differs by ~1e-5
}
equivalent, diffs = compare_invocations_directly(original, candidate)
@ -806,12 +713,8 @@ class TestComparatorEdgeCases:
def test_nan_string_comparison(self):
"""NaN as a string return value should be comparable."""
original = {
"1": {"result_json": "NaN", "error_json": None},
}
candidate = {
"1": {"result_json": "NaN", "error_json": None},
}
original = {"1": {"result_json": "NaN", "error_json": None}}
candidate = {"1": {"result_json": "NaN", "error_json": None}}
equivalent, diffs = compare_invocations_directly(original, candidate)
assert equivalent is True
@ -819,12 +722,8 @@ class TestComparatorEdgeCases:
def test_nan_vs_number(self):
"""NaN vs a normal number should be detected as different."""
original = {
"1": {"result_json": "NaN", "error_json": None},
}
candidate = {
"1": {"result_json": "0.0", "error_json": None},
}
original = {"1": {"result_json": "NaN", "error_json": None}}
candidate = {"1": {"result_json": "0.0", "error_json": None}}
equivalent, diffs = compare_invocations_directly(original, candidate)
assert equivalent is False
@ -832,12 +731,8 @@ class TestComparatorEdgeCases:
def test_infinity_string_comparison(self):
"""Infinity as a string return value should be comparable."""
original = {
"1": {"result_json": "Infinity", "error_json": None},
}
candidate = {
"1": {"result_json": "Infinity", "error_json": None},
}
original = {"1": {"result_json": "Infinity", "error_json": None}}
candidate = {"1": {"result_json": "Infinity", "error_json": None}}
equivalent, diffs = compare_invocations_directly(original, candidate)
assert equivalent is True
@ -845,12 +740,8 @@ class TestComparatorEdgeCases:
def test_negative_infinity(self):
"""-Infinity as a string return value should be comparable."""
original = {
"1": {"result_json": "-Infinity", "error_json": None},
}
candidate = {
"1": {"result_json": "-Infinity", "error_json": None},
}
original = {"1": {"result_json": "-Infinity", "error_json": None}}
candidate = {"1": {"result_json": "-Infinity", "error_json": None}}
equivalent, diffs = compare_invocations_directly(original, candidate)
assert equivalent is True
@ -858,12 +749,8 @@ class TestComparatorEdgeCases:
def test_infinity_vs_negative_infinity(self):
"""Infinity and -Infinity should be detected as different."""
original = {
"1": {"result_json": "Infinity", "error_json": None},
}
candidate = {
"1": {"result_json": "-Infinity", "error_json": None},
}
original = {"1": {"result_json": "Infinity", "error_json": None}}
candidate = {"1": {"result_json": "-Infinity", "error_json": None}}
equivalent, diffs = compare_invocations_directly(original, candidate)
assert equivalent is False
@ -871,12 +758,8 @@ class TestComparatorEdgeCases:
def test_empty_collection_results(self):
"""Empty array '[]' as return value should be comparable."""
original = {
"1": {"result_json": "[]", "error_json": None},
}
candidate = {
"1": {"result_json": "[]", "error_json": None},
}
original = {"1": {"result_json": "[]", "error_json": None}}
candidate = {"1": {"result_json": "[]", "error_json": None}}
equivalent, diffs = compare_invocations_directly(original, candidate)
assert equivalent is True
@ -884,12 +767,8 @@ class TestComparatorEdgeCases:
def test_empty_object_results(self):
"""Empty object '{}' as return value should be comparable."""
original = {
"1": {"result_json": "{}", "error_json": None},
}
candidate = {
"1": {"result_json": "{}", "error_json": None},
}
original = {"1": {"result_json": "{}", "error_json": None}}
candidate = {"1": {"result_json": "{}", "error_json": None}}
equivalent, diffs = compare_invocations_directly(original, candidate)
assert equivalent is True
@ -917,12 +796,8 @@ class TestComparatorEdgeCases:
1e+17 as floats due to precision limits, making them indistinguishable.
This is a known limitation of floating-point comparison for very large integers.
"""
original = {
"1": {"result_json": "99999999999999999", "error_json": None},
}
candidate = {
"1": {"result_json": "99999999999999998", "error_json": None},
}
original = {"1": {"result_json": "99999999999999999", "error_json": None}}
candidate = {"1": {"result_json": "99999999999999998", "error_json": None}}
equivalent, diffs = compare_invocations_directly(original, candidate)
# Due to float precision limits, these are considered equal
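The precision limitation described above is easy to reproduce directly: IEEE-754 doubles carry roughly 15-17 significant decimal digits, so adjacent 17-digit integers collapse to the same value once parsed as floats.

```python
a = float("99999999999999999")
b = float("99999999999999998")

# Both values round to the nearest representable double (1e17),
# so a float-based comparison cannot tell them apart.
print(a == b)  # True

# Exact integer parsing still distinguishes them.
print(int("99999999999999999") == int("99999999999999998"))  # False
```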
@ -931,12 +806,8 @@ class TestComparatorEdgeCases:
def test_large_number_significantly_different(self):
"""Large numbers with significant differences should be detected."""
original = {
"1": {"result_json": "100000000000000000", "error_json": None},
}
candidate = {
"1": {"result_json": "200000000000000000", "error_json": None},
}
original = {"1": {"result_json": "100000000000000000", "error_json": None}}
candidate = {"1": {"result_json": "200000000000000000", "error_json": None}}
equivalent, diffs = compare_invocations_directly(original, candidate)
assert equivalent is False
@ -944,12 +815,8 @@ class TestComparatorEdgeCases:
def test_null_vs_empty_string(self):
"""'null' and '""' should NOT be equivalent."""
original = {
"1": {"result_json": "null", "error_json": None},
}
candidate = {
"1": {"result_json": '""', "error_json": None},
}
original = {"1": {"result_json": "null", "error_json": None}}
candidate = {"1": {"result_json": '""', "error_json": None}}
equivalent, diffs = compare_invocations_directly(original, candidate)
assert equivalent is False
@ -958,10 +825,7 @@ class TestComparatorEdgeCases:
def test_boolean_string_comparison(self):
"""Boolean strings 'true'/'false' should compare correctly."""
original = {
"1": {"result_json": "true", "error_json": None},
"2": {"result_json": "false", "error_json": None},
}
original = {"1": {"result_json": "true", "error_json": None}, "2": {"result_json": "false", "error_json": None}}
candidate = {
"1": {"result_json": "true", "error_json": None},
"2": {"result_json": "false", "error_json": None},
@ -972,12 +836,8 @@ class TestComparatorEdgeCases:
def test_boolean_true_vs_false(self):
"""'true' vs 'false' should be detected as different."""
original = {
"1": {"result_json": "true", "error_json": None},
}
candidate = {
"1": {"result_json": "false", "error_json": None},
}
original = {"1": {"result_json": "true", "error_json": None}}
candidate = {"1": {"result_json": "false", "error_json": None}}
equivalent, diffs = compare_invocations_directly(original, candidate)
assert equivalent is False
@ -1024,12 +884,8 @@ class TestComparatorErrorHandling:
def test_compare_with_none_return_values_direct(self):
"""Rows where result_json is None should be handled in direct comparison."""
original = {
"1": {"result_json": None, "error_json": None},
}
candidate = {
"1": {"result_json": None, "error_json": None},
}
original = {"1": {"result_json": None, "error_json": None}}
candidate = {"1": {"result_json": None, "error_json": None}}
equivalent, diffs = compare_invocations_directly(original, candidate)
assert equivalent is True
@ -1037,12 +893,8 @@ class TestComparatorErrorHandling:
def test_compare_one_none_one_value_direct(self):
"""One None result vs a real value should detect the difference."""
original = {
"1": {"result_json": None, "error_json": None},
}
candidate = {
"1": {"result_json": "42", "error_json": None},
}
original = {"1": {"result_json": None, "error_json": None}}
candidate = {"1": {"result_json": "42", "error_json": None}}
equivalent, diffs = compare_invocations_directly(original, candidate)
assert equivalent is False
@ -1050,12 +902,8 @@ class TestComparatorErrorHandling:
def test_compare_both_errors_identical(self):
"""Identical errors in both original and candidate should be equivalent."""
original = {
"1": {"result_json": None, "error_json": '{"type": "IOException", "message": "file not found"}'},
}
candidate = {
"1": {"result_json": None, "error_json": '{"type": "IOException", "message": "file not found"}'},
}
original = {"1": {"result_json": None, "error_json": '{"type": "IOException", "message": "file not found"}'}}
candidate = {"1": {"result_json": None, "error_json": '{"type": "IOException", "message": "file not found"}'}}
equivalent, diffs = compare_invocations_directly(original, candidate)
assert equivalent is True
@ -1063,12 +911,8 @@ class TestComparatorErrorHandling:
def test_compare_different_error_types(self):
"""Different error types should be detected."""
original = {
"1": {"result_json": None, "error_json": '{"type": "IOException"}'},
}
candidate = {
"1": {"result_json": None, "error_json": '{"type": "NullPointerException"}'},
}
original = {"1": {"result_json": None, "error_json": '{"type": "IOException"}'}}
candidate = {"1": {"result_json": None, "error_json": '{"type": "NullPointerException"}'}}
equivalent, diffs = compare_invocations_directly(original, candidate)
assert equivalent is False
@ -1083,9 +927,7 @@ class TestComparatorJavaEdgeCases(TestTestResultsTableSchema):
Extends TestTestResultsTableSchema to reuse the create_test_results_db fixture.
"""
def test_comparator_float_epsilon_tolerance(
self, tmp_path: Path, create_test_results_db
):
def test_comparator_float_epsilon_tolerance(self, tmp_path: Path, create_test_results_db):
"""Values differing by less than EPSILON (1e-9) should be treated as equivalent.
The Java Comparator uses EPSILON=1e-9 for float comparison.
@ -1102,7 +944,7 @@ class TestComparatorJavaEdgeCases(TestTestResultsTableSchema):
"loop_index": 1,
"iteration_id": "1_0",
"return_value": KRYO_DOUBLE_1_0000000001,
},
}
]
candidate_results = [
@ -1112,7 +954,7 @@ class TestComparatorJavaEdgeCases(TestTestResultsTableSchema):
"loop_index": 1,
"iteration_id": "1_0",
"return_value": KRYO_DOUBLE_1_0000000002,
},
}
]
create_test_results_db(original_path, original_results)
@ -1124,9 +966,7 @@ class TestComparatorJavaEdgeCases(TestTestResultsTableSchema):
assert equivalent is True
assert len(diffs) == 0
def test_comparator_nan_handling(
self, tmp_path: Path, create_test_results_db
):
def test_comparator_nan_handling(self, tmp_path: Path, create_test_results_db):
"""Java Comparator should handle NaN return values."""
original_path = tmp_path / "original.db"
candidate_path = tmp_path / "candidate.db"
@ -1138,7 +978,7 @@ class TestComparatorJavaEdgeCases(TestTestResultsTableSchema):
"loop_index": 1,
"iteration_id": "1_0",
"return_value": KRYO_NAN,
},
}
]
create_test_results_db(original_path, results)
@ -1150,9 +990,7 @@ class TestComparatorJavaEdgeCases(TestTestResultsTableSchema):
assert equivalent is True
assert len(diffs) == 0
def test_comparator_empty_table(
self, tmp_path: Path, create_test_results_db
):
def test_comparator_empty_table(self, tmp_path: Path, create_test_results_db):
"""Empty test_results tables should result in equivalent=False (vacuous equivalence guard)."""
original_path = tmp_path / "original.db"
candidate_path = tmp_path / "candidate.db"
@ -1167,9 +1005,7 @@ class TestComparatorJavaEdgeCases(TestTestResultsTableSchema):
assert equivalent is False
assert len(diffs) == 0
def test_comparator_infinity_handling(
self, tmp_path: Path, create_test_results_db
):
def test_comparator_infinity_handling(self, tmp_path: Path, create_test_results_db):
"""Java Comparator should handle Infinity return values correctly."""
original_path = tmp_path / "original.db"
candidate_path = tmp_path / "candidate.db"

View file

@ -7,25 +7,12 @@ fail with an error to maintain strict correctness guarantees.
import inspect
import sqlite3
from dataclasses import dataclass
from pathlib import Path
import pytest
from codeflash.languages.java.comparator import (
compare_test_results as java_compare_test_results,
)
from codeflash.models.models import (
FunctionTestInvocation,
InvocationId,
TestDiffScope,
TestResults,
TestType,
VerificationType,
)
from codeflash.verification.equivalence import (
compare_test_results as python_compare_test_results,
)
from codeflash.languages.java.comparator import compare_test_results as java_compare_test_results
from codeflash.models.models import FunctionTestInvocation, InvocationId, TestResults, TestType, VerificationType
def make_invocation(
@ -142,7 +129,7 @@ class TestSqlitePathSelection:
"loop_index": 1,
"iteration_id": "1_0",
"return_value": '{"value": 42}',
},
}
]
create_test_results_db(original_path, results)
create_test_results_db(candidate_path, results)

View file

@ -3,14 +3,9 @@
import tempfile
from pathlib import Path
import pytest
from codeflash.languages.base import FunctionInfo
from codeflash.languages.java.concurrency_analyzer import JavaConcurrencyAnalyzer, analyze_function_concurrency
from codeflash.languages.language_enum import Language
from codeflash.languages.java.concurrency_analyzer import (
JavaConcurrencyAnalyzer,
analyze_function_concurrency,
)
class TestCompletableFutureDetection:

View file

@ -2,25 +2,19 @@
from pathlib import Path
import pytest
from codeflash.languages.base import FunctionFilterCriteria, Language, ParentInfo
from codeflash.languages.base import FunctionFilterCriteria, Language
from codeflash.languages.java.context import (
TypeSkeleton,
_extract_public_method_signatures,
_extract_type_skeleton,
_format_skeleton_for_context,
extract_class_context,
extract_code_context,
extract_function_source,
extract_read_only_context,
find_helper_functions,
get_java_imported_type_skeletons,
_extract_public_method_signatures,
_format_skeleton_for_context,
)
from codeflash.languages.java.discovery import discover_functions_from_source
from codeflash.languages.java.import_resolver import JavaImportResolver, ResolvedImport
from codeflash.languages.java.parser import JavaImportInfo, get_java_analyzer
from codeflash.languages.java.parser import get_java_analyzer
# Filter criteria that includes void methods
NO_RETURN_FILTER = FunctionFilterCriteria(require_return=False)
@ -1785,12 +1779,15 @@ class TestExtractCodeContextEdgeCases:
def test_unicode_in_source(self, tmp_path: Path):
"""Test context extraction for method with unicode characters."""
java_file = tmp_path / "Unicode.java"
java_file.write_text("""public class Unicode {
java_file.write_text(
"""public class Unicode {
public String greet() {
return "こんにちは世界";
}
}
""", encoding="utf-8")
""",
encoding="utf-8",
)
functions = discover_functions_from_source(java_file.read_text(encoding="utf-8"), file_path=java_file)
assert len(functions) == 1

View file

@ -4,14 +4,7 @@ import os
from pathlib import Path
from unittest.mock import patch
import pytest
from codeflash.languages.java.formatter import (
JavaFormatter,
format_java_code,
format_java_file,
normalize_java_code,
)
from codeflash.languages.java.formatter import JavaFormatter, format_java_code, format_java_file, normalize_java_code
from codeflash.setup.detector import _detect_formatter
@ -201,12 +194,12 @@ class TestNormalizationEdgeCases:
def test_string_with_comment_chars(self):
"""Test string containing comment characters."""
source = '''
source = """
public class Example {
String s1 = "// not a comment";
String s2 = "/* also not */";
}
'''
"""
normalized = normalize_java_code(source)
# Note: current implementation incorrectly removes content in s2 string
expected = 'public class Example {\nString s1 = "// not a comment";\nString s2 = "";\n}'
@ -273,10 +266,7 @@ class TestDetectJavaFormatter:
def test_detect_formatter_returns_empty_when_java_not_available(self, tmp_path: Path):
"""Detector returns empty list with descriptive message when Java is not found."""
with (
patch.dict(os.environ, {}, clear=True),
patch("shutil.which", return_value=None),
):
with patch.dict(os.environ, {}, clear=True), patch("shutil.which", return_value=None):
cmds, description = _detect_formatter(tmp_path, "java")
assert cmds == []

View file

@ -2,14 +2,7 @@
from pathlib import Path
import pytest
from codeflash.languages.java.import_resolver import (
JavaImportResolver,
ResolvedImport,
find_helper_files,
resolve_imports_for_file,
)
from codeflash.languages.java.import_resolver import JavaImportResolver, ResolvedImport, find_helper_files
from codeflash.languages.java.parser import JavaImportInfo
@ -21,11 +14,7 @@ class TestJavaImportResolver:
resolver = JavaImportResolver(tmp_path)
import_info = JavaImportInfo(
import_path="java.util.List",
is_static=False,
is_wildcard=False,
start_line=1,
end_line=1,
import_path="java.util.List", is_static=False, is_wildcard=False, start_line=1, end_line=1
)
resolved = resolver.resolve_import(import_info)
@ -38,11 +27,7 @@ class TestJavaImportResolver:
resolver = JavaImportResolver(tmp_path)
import_info = JavaImportInfo(
import_path="javax.annotation.Nullable",
is_static=False,
is_wildcard=False,
start_line=1,
end_line=1,
import_path="javax.annotation.Nullable", is_static=False, is_wildcard=False, start_line=1, end_line=1
)
resolved = resolver.resolve_import(import_info)
@ -53,11 +38,7 @@ class TestJavaImportResolver:
resolver = JavaImportResolver(tmp_path)
import_info = JavaImportInfo(
import_path="org.junit.jupiter.api.Test",
is_static=False,
is_wildcard=False,
start_line=1,
end_line=1,
import_path="org.junit.jupiter.api.Test", is_static=False, is_wildcard=False, start_line=1, end_line=1
)
resolved = resolver.resolve_import(import_info)
@ -89,11 +70,7 @@ public class StringUtils {
resolver = JavaImportResolver(tmp_path)
import_info = JavaImportInfo(
import_path="com.example.utils.StringUtils",
is_static=False,
is_wildcard=False,
start_line=1,
end_line=1,
import_path="com.example.utils.StringUtils", is_static=False, is_wildcard=False, start_line=1, end_line=1
)
resolved = resolver.resolve_import(import_info)
@ -107,11 +84,7 @@ public class StringUtils {
resolver = JavaImportResolver(tmp_path)
import_info = JavaImportInfo(
import_path="java.util",
is_static=False,
is_wildcard=True,
start_line=1,
end_line=1,
import_path="java.util", is_static=False, is_wildcard=True, start_line=1, end_line=1
)
resolved = resolver.resolve_import(import_info)
@ -123,11 +96,7 @@ public class StringUtils {
resolver = JavaImportResolver(tmp_path)
import_info = JavaImportInfo(
import_path="java.lang.Math.PI",
is_static=True,
is_wildcard=False,
start_line=1,
end_line=1,
import_path="java.lang.Math.PI", is_static=True, is_wildcard=False, start_line=1, end_line=1
)
resolved = resolver.resolve_import(import_info)
@ -286,11 +255,7 @@ class TestResolvedImport:
def test_resolved_import_external(self):
"""Test ResolvedImport for external dependency."""
resolved = ResolvedImport(
import_path="java.util.List",
file_path=None,
is_external=True,
is_wildcard=False,
class_name="List",
import_path="java.util.List", file_path=None, is_external=True, is_wildcard=False, class_name="List"
)
assert resolved.is_external is True
assert resolved.file_path is None

View file

@ -221,9 +221,7 @@ public class StringUtilsTest {
return new String(chars);
}"""
optimized = support.replace_function(
src_file.read_text(), functions[0], new_code
)
optimized = support.replace_function(src_file.read_text(), functions[0], new_code)
assert "Optimized version" in optimized
assert "StringUtils" in optimized

View file

@ -100,7 +100,9 @@ class TestFixJavaTestPathsIntegration:
# Bind the actual methods
mock_optimizer._get_java_sources_root = lambda: JavaFunctionOptimizer._get_java_sources_root(mock_optimizer)
mock_optimizer._fix_java_test_paths = lambda behavior_source, perf_source, used_paths: JavaFunctionOptimizer._fix_java_test_paths(mock_optimizer, behavior_source, perf_source, used_paths)
mock_optimizer._fix_java_test_paths = lambda behavior_source, perf_source, used_paths: (
JavaFunctionOptimizer._fix_java_test_paths(mock_optimizer, behavior_source, perf_source, used_paths)
)
return mock_optimizer
@ -133,8 +135,14 @@ public class UnpackerTest__perfonlyinstrumented {
# The path should be test/src/com/aerospike/client/util/UnpackerTest__perfinstrumented.java
# NOT test/src/com/aerospike/test/com/aerospike/client/util/...
expected_java_root = tmp_path / "test" / "src"
assert behavior_path == expected_java_root / "com" / "aerospike" / "client" / "util" / "UnpackerTest__perfinstrumented.java"
assert perf_path == expected_java_root / "com" / "aerospike" / "client" / "util" / "UnpackerTest__perfonlyinstrumented.java"
assert (
behavior_path
== expected_java_root / "com" / "aerospike" / "client" / "util" / "UnpackerTest__perfinstrumented.java"
)
assert (
perf_path
== expected_java_root / "com" / "aerospike" / "client" / "util" / "UnpackerTest__perfonlyinstrumented.java"
)
# Verify there's no duplication in the path
assert "com/aerospike/test/com" not in str(behavior_path)
@@ -169,6 +177,7 @@ public class CalculatorTest__perfonlyinstrumented {
assert behavior_path == tests_root / "com" / "example" / "CalculatorTest__perfinstrumented.java"
assert perf_path == tests_root / "com" / "example" / "CalculatorTest__perfonlyinstrumented.java"
class TestPathToClassNameWithCustomDirs:
"""Tests for _path_to_class_name with custom source directories."""

View file

@@ -62,14 +62,7 @@ public class Calculator {
"targets": [
{
"className": "com/example/Calculator",
"methods": [
{
"name": "add",
"startLine": 4,
"endLine": 7,
"sourceFile": file_path.as_posix(),
}
],
"methods": [{"name": "add", "startLine": 4, "endLine": 7, "sourceFile": file_path.as_posix()}],
}
],
"lineContents": {
@@ -172,18 +165,8 @@ public class Calculator {
config = json.loads(config_path.read_text())
assert config["targets"][0]["methods"] == [
{
"name": "method1",
"startLine": 2,
"endLine": 4,
"sourceFile": file_path.as_posix(),
},
{
"name": "method2",
"startLine": 6,
"endLine": 8,
"sourceFile": file_path.as_posix(),
},
{"name": "method1", "startLine": 2, "endLine": 4, "sourceFile": file_path.as_posix()},
{"name": "method2", "startLine": 6, "endLine": 8, "sourceFile": file_path.as_posix()},
]
def test_empty_function_list(self):
@@ -403,12 +386,7 @@ class TestAgentConfigBoundaryConditions:
{
"className": "Test",
"methods": [
{
"name": "foo",
"startLine": 100,
"endLine": 200,
"sourceFile": file_path.as_posix(),
}
{"name": "foo", "startLine": 100, "endLine": 200, "sourceFile": file_path.as_posix()}
],
}
],
@@ -450,12 +428,7 @@ class TestAgentConfigBoundaryConditions:
{
"className": "Test",
"methods": [
{
"name": "foo",
"startLine": -5,
"endLine": -1,
"sourceFile": file_path.as_posix(),
}
{"name": "foo", "startLine": -5, "endLine": -1, "sourceFile": file_path.as_posix()}
],
}
],
@@ -496,9 +469,7 @@ class TestLineProfileResultsParsing:
results = JavaLineProfiler.parse_results(profile_file)
assert results["unit"] == 1e-9
assert results["timings"] == {
("/tmp/Test.java", 10, "Test.java"): [(10, 100, 5000000), (11, 100, 95000000)]
}
assert results["timings"] == {("/tmp/Test.java", 10, "Test.java"): [(10, 100, 5000000), (11, 100, 95000000)]}
assert results["line_contents"] == {
("/tmp/Test.java", 10): "int x = compute();",
("/tmp/Test.java", 11): "result = slowOperation(x);",
@@ -601,9 +572,7 @@ class TestLineProfileResultsParsing:
assert results == {
"unit": 1e-9,
"timings": {
("/tmp/Sorter.java", 5, "Sorter.java"): [(5, 10, 2000000), (6, 10, 8000000)]
},
"timings": {("/tmp/Sorter.java", 5, "Sorter.java"): [(5, 10, 2000000), (6, 10, 8000000)]},
"line_contents": {
("/tmp/Sorter.java", 5): "int n = arr.length;",
("/tmp/Sorter.java", 6): "for (int i = 0; i < n; i++) {",

View file

@@ -1,5 +1,4 @@
"""Integration tests for Java line profiler with JavaSupport.
"""
"""Integration tests for Java line profiler with JavaSupport."""
import json
import math
@@ -16,12 +15,10 @@ from codeflash.languages.java.support import get_java_support
class TestLineProfilerInstrumentation:
"""Integration tests for line profiler instrumentation through JavaSupport.
"""
"""Integration tests for line profiler instrumentation through JavaSupport."""
def test_instrument_with_package(self):
"""Test instrumentation for a class with a package declaration.
"""
"""Test instrumentation for a class with a package declaration."""
source = """package com.example;
public class Calculator {
@@ -70,14 +67,7 @@ public class Calculator {
"targets": [
{
"className": "com/example/Calculator",
"methods": [
{
"name": "add",
"startLine": 4,
"endLine": 7,
"sourceFile": java_file.as_posix(),
}
],
"methods": [{"name": "add", "startLine": 4, "endLine": 7, "sourceFile": java_file.as_posix()}],
}
],
"lineContents": {
@@ -155,12 +145,7 @@ public class Calculator {
{
"className": "Sorter",
"methods": [
{
"name": "sort",
"startLine": 2,
"endLine": 14,
"sourceFile": java_file.as_posix(),
}
{"name": "sort", "startLine": 2, "endLine": 14, "sourceFile": java_file.as_posix()}
],
}
],
@@ -254,9 +239,7 @@ public class Calculator {
# Both methods should appear as targets when generated together
profiler = JavaLineProfiler(output_file=profile_output)
profiler.generate_agent_config(
source, java_file, [func_reverse, func_palindrome], config_path
)
profiler.generate_agent_config(source, java_file, [func_reverse, func_palindrome], config_path)
config = json.loads(config_path.read_text(encoding="utf-8"))
assert config == {
@@ -266,12 +249,7 @@ public class Calculator {
{
"className": "StringProcessor",
"methods": [
{
"name": "reverse",
"startLine": 2,
"endLine": 14,
"sourceFile": java_file.as_posix(),
},
{"name": "reverse", "startLine": 2, "endLine": 14, "sourceFile": java_file.as_posix()},
{
"name": "isPalindrome",
"startLine": 16,
@@ -355,12 +333,7 @@ public class StringUtils {
{
"className": "org/apache/commons/lang3/StringUtils",
"methods": [
{
"name": "isEmpty",
"startLine": 4,
"endLine": 6,
"sourceFile": java_file.as_posix(),
}
{"name": "isEmpty", "startLine": 4, "endLine": 6, "sourceFile": java_file.as_posix()}
],
}
],
@@ -484,10 +457,7 @@ def run_spin_timer_profiled(tmppath: Path, spin_durations_ns: list[int]) -> dict
agent_arg = profiler.build_javaagent_arg(config_path)
result = subprocess.run(
["javac", "--release", "11", str(java_file)],
capture_output=True,
text=True,
cwd=str(tmppath),
["javac", "--release", "11", str(java_file)], capture_output=True, text=True, cwd=str(tmppath)
)
assert result.returncode == 0, f"javac failed: {result.stderr}"
@@ -512,13 +482,7 @@ class TestSpinTimerProfiling:
profiler-reported total time matches the expected sum of all spin durations.
"""
@pytest.mark.parametrize(
"spin_durations_ns",
[
[50_000_000, 100_000_000],
[30_000_000, 40_000_000, 80_000_000],
],
)
@pytest.mark.parametrize("spin_durations_ns", [[50_000_000, 100_000_000], [30_000_000, 40_000_000, 80_000_000]])
def test_total_time_matches_expected(self, spin_durations_ns):
"""Profiler total time should match the sum of all spin durations."""
expected_ns = sum(spin_durations_ns)

View file

@@ -3,13 +3,8 @@
import logging
from pathlib import Path
import pytest
from codeflash.languages.java.discovery import discover_functions_from_source
from codeflash.languages.java.test_discovery import (
disambiguate_overloads,
discover_tests,
)
from codeflash.languages.java.test_discovery import disambiguate_overloads, discover_tests
class TestOverloadDisambiguation:
@@ -109,9 +104,7 @@ public class CalculatorTest {
"""When overloaded methods are detected, info log fires."""
matched_names = ["Calculator.add", "StringUtils.add"]
with caplog.at_level(logging.INFO):
result = disambiguate_overloads(
matched_names, "testAdd", "some test source code"
)
result = disambiguate_overloads(matched_names, "testAdd", "some test source code")
assert result == matched_names
info_messages = [r.message for r in caplog.records if r.levelno == logging.INFO]

View file

@@ -1,15 +1,6 @@
"""Tests for the Java tree-sitter parser utilities."""
import pytest
from codeflash.languages.java.parser import (
JavaAnalyzer,
JavaClassNode,
JavaFieldInfo,
JavaImportInfo,
JavaMethodNode,
get_java_analyzer,
)
from codeflash.languages.java.parser import JavaAnalyzer, get_java_analyzer
class TestJavaAnalyzerBasic:

View file

@@ -118,12 +118,7 @@ def java_project(tmp_path: Path):
def _make_optimizer(project_root: Path, test_dir: Path, function_name: str, src_file: Path) -> tuple:
"""Create an Optimizer and FunctionOptimizer for the given function."""
fto = FunctionToOptimize(
function_name=function_name,
file_path=src_file,
parents=[],
language="java",
)
fto = FunctionToOptimize(function_name=function_name, file_path=src_file, parents=[], language="java")
opt = Optimizer(
Namespace(
project_root=project_root,
@@ -493,11 +488,7 @@ public class PreciseWaiterTest {
project_root, src_dir, test_dir = self._setup_precise_waiter_project(java_project)
test_results = self._instrument_and_run(
project_root,
src_dir,
test_dir,
self.PRECISE_WAITER_TEST,
"PreciseWaiterTest.java",
project_root, src_dir, test_dir, self.PRECISE_WAITER_TEST, "PreciseWaiterTest.java"
)
# 2 outer loops × 2 inner iterations = 4 total results
@@ -542,9 +533,7 @@ public class PreciseWaiterTest {
runtime_by_test = test_results.usable_runtime_data_by_test_case()
# Should have 1 test case (constant iteration_id per call site)
assert len(runtime_by_test) == 1, (
f"Expected 1 test case (constant iteration_id), got {len(runtime_by_test)}"
)
assert len(runtime_by_test) == 1, f"Expected 1 test case (constant iteration_id), got {len(runtime_by_test)}"
# The single test case should have 4 runtimes (2 outer loops × 2 inner iterations)
for test_id, test_runtimes in runtime_by_test.items():
@@ -584,11 +573,7 @@ public class PreciseWaiterMultiTest {
}
"""
test_results = self._instrument_and_run(
project_root,
src_dir,
test_dir,
multi_test_source,
"PreciseWaiterMultiTest.java",
project_root, src_dir, test_dir, multi_test_source, "PreciseWaiterMultiTest.java"
)
# 2 test methods × 2 outer loops × 2 inner iterations = 8 total results
@@ -651,5 +636,3 @@ public class PreciseWaiterMultiTest {
f"total_passed_runtime {total_runtime / 1_000_000:.3f}ms not close to expected "
f"{expected_total_ns / 1_000_000:.1f}ms (2 methods × min of 4 runtimes × 10ms, ±3%)"
)

View file

@@ -4,11 +4,7 @@ from pathlib import Path
import pytest
from codeflash.languages.java.test_runner import (
_validate_java_class_name,
_validate_test_filter,
get_test_run_command,
)
from codeflash.languages.java.test_runner import _validate_java_class_name, _validate_test_filter, get_test_run_command
class TestInputValidation:
@@ -62,12 +58,7 @@ class TestInputValidation:
def test_validate_test_filter_wildcards(self):
"""Test validation of wildcard patterns."""
valid_patterns = [
"My*Test",
"*Test",
"com.example.*Test",
"com.example.**",
]
valid_patterns = ["My*Test", "*Test", "com.example.*Test", "com.example.**"]
for pattern in valid_patterns:
result = _validate_test_filter(pattern)
@@ -203,7 +194,7 @@ class TestXMLParsingSecurity:
for i in range(3):
xml_file = surefire_dir / f"TEST-Suite{i}.xml"
xml_file.write_text(f"""<?xml version="1.0" encoding="UTF-8"?>
<testsuite tests="{i+1}" failures="0" errors="0" skipped="0">
<testsuite tests="{i + 1}" failures="0" errors="0" skipped="0">
<testcase name="test1" classname="Suite{i}" time="0.001"/>
</testsuite>
""")

View file

@@ -108,9 +108,7 @@ public class CalculatorTest {
""")
# Get source functions
source_functions = discover_functions_from_source(
src_file.read_text(), file_path=src_file
)
source_functions = discover_functions_from_source(src_file.read_text(), file_path=src_file)
# Discover tests
result = discover_tests(test_dir, source_functions)
@@ -168,7 +166,6 @@ public class StringUtilsTest {
# Create source function
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.languages.base import Language
func = FunctionToOptimize(
function_name="reverse",
@@ -242,9 +239,7 @@ public class TestQueryBlob {
""")
# Get source functions
source_functions = discover_functions_from_source(
src_file.read_text(), file_path=src_file
)
source_functions = discover_functions_from_source(src_file.read_text(), file_path=src_file)
# Filter to just bytesToHexString
target_functions = [f for f in source_functions if f.function_name == "bytesToHexString"]
@@ -288,9 +283,7 @@ public class IntegrationTest {
""")
# Get source functions
source_functions = discover_functions_from_source(
src_file.read_text(), file_path=src_file
)
source_functions = discover_functions_from_source(src_file.read_text(), file_path=src_file)
# Discover tests
result = discover_tests(test_dir, source_functions)
@@ -325,8 +318,8 @@ class TestImportExtraction:
def test_basic_import(self):
"""Test extraction of basic import statement."""
from codeflash.languages.java.test_discovery import _extract_imports
from codeflash.languages.java.parser import get_java_analyzer
from codeflash.languages.java.test_discovery import _extract_imports
analyzer = get_java_analyzer()
source = """
@@ -341,8 +334,8 @@ public class Test {}
def test_multiple_imports(self):
"""Test extraction of multiple imports."""
from codeflash.languages.java.test_discovery import _extract_imports
from codeflash.languages.java.parser import get_java_analyzer
from codeflash.languages.java.test_discovery import _extract_imports
analyzer = get_java_analyzer()
source = """
@@ -358,8 +351,8 @@ public class Test {}
def test_wildcard_import_returns_empty(self):
"""Test that wildcard imports don't add specific classes."""
from codeflash.languages.java.test_discovery import _extract_imports
from codeflash.languages.java.parser import get_java_analyzer
from codeflash.languages.java.test_discovery import _extract_imports
analyzer = get_java_analyzer()
source = """
@@ -374,8 +367,8 @@ public class Test {}
def test_static_import_extracts_class(self):
"""Test that static imports extract the class name, not the method."""
from codeflash.languages.java.test_discovery import _extract_imports
from codeflash.languages.java.parser import get_java_analyzer
from codeflash.languages.java.test_discovery import _extract_imports
analyzer = get_java_analyzer()
source = """
@@ -390,8 +383,8 @@ public class Test {}
def test_static_wildcard_import_extracts_class(self):
"""Test that static wildcard imports extract the class name."""
from codeflash.languages.java.test_discovery import _extract_imports
from codeflash.languages.java.parser import get_java_analyzer
from codeflash.languages.java.test_discovery import _extract_imports
analyzer = get_java_analyzer()
source = """
@@ -406,8 +399,8 @@ public class Test {}
def test_deeply_nested_package(self):
"""Test extraction from deeply nested package."""
from codeflash.languages.java.test_discovery import _extract_imports
from codeflash.languages.java.parser import get_java_analyzer
from codeflash.languages.java.test_discovery import _extract_imports
analyzer = get_java_analyzer()
source = """
@ -422,8 +415,8 @@ public class Test {}
def test_mixed_imports(self):
"""Test extraction with mix of regular, static, and wildcard imports."""
from codeflash.languages.java.test_discovery import _extract_imports
from codeflash.languages.java.parser import get_java_analyzer
from codeflash.languages.java.test_discovery import _extract_imports
analyzer = get_java_analyzer()
source = """
@@ -448,8 +441,8 @@ class TestMethodCallDetection:
def test_find_method_calls(self):
"""Test detection of method calls within a code range."""
from codeflash.languages.java.test_discovery import _find_method_calls_in_range
from codeflash.languages.java.parser import get_java_analyzer
from codeflash.languages.java.test_discovery import _find_method_calls_in_range
analyzer = get_java_analyzer()
source = """

View file

@@ -319,12 +319,7 @@ class TestJavaCompilation:
pytest.skip("Maven not installed")
# Compile the project
result = subprocess.run(
["mvn", "compile", "-q"],
cwd=java_project_dir,
capture_output=True,
timeout=120,
)
result = subprocess.run(["mvn", "compile", "-q"], cwd=java_project_dir, capture_output=True, timeout=120)
assert result.returncode == 0, f"Compilation failed: {result.stderr.decode()}"
@@ -342,11 +337,6 @@ class TestJavaCompilation:
pytest.skip("Maven not installed")
# Run tests
result = subprocess.run(
["mvn", "test", "-q"],
cwd=java_project_dir,
capture_output=True,
timeout=180,
)
result = subprocess.run(["mvn", "test", "-q"], cwd=java_project_dir, capture_output=True, timeout=180)
assert result.returncode == 0, f"Tests failed: {result.stderr.decode()}"

View file

@@ -19,10 +19,7 @@ def make_func(name: str, class_name: str | None = None) -> FunctionToOptimize:
"""Helper to create FunctionToOptimize for testing."""
parents = [FunctionParent(name=class_name, type="ClassDef")] if class_name else []
return FunctionToOptimize(
function_name=name,
file_path=Path("/test/file.js"),
parents=parents,
language="javascript",
function_name=name, file_path=Path("/test/file.js"), parents=parents, language="javascript"
)
@@ -386,7 +383,9 @@ test('fibonacci works', () => {
});
"""
transformed, counter = transform_expect_calls(
code=code, function_to_optimize=make_func("fibonacci", class_name="FibonacciCalculator"), capture_func="capture"
code=code,
function_to_optimize=make_func("fibonacci", class_name="FibonacciCalculator"),
capture_func="capture",
)
# Should transform expect(calc.fibonacci(10)) to
@@ -433,7 +432,9 @@ class FibonacciCalculator {
}
"""
transformed, counter = transform_standalone_calls(
code=code, function_to_optimize=make_func("fibonacci", class_name="FibonacciCalculator"), capture_func="capture"
code=code,
function_to_optimize=make_func("fibonacci", class_name="FibonacciCalculator"),
capture_func="capture",
)
# The method definition should NOT be transformed
@@ -452,7 +453,9 @@ FibonacciCalculator.prototype.fibonacci = function(n) {
};
"""
transformed, counter = transform_standalone_calls(
code=code, function_to_optimize=make_func("fibonacci", class_name="FibonacciCalculator"), capture_func="capture"
code=code,
function_to_optimize=make_func("fibonacci", class_name="FibonacciCalculator"),
capture_func="capture",
)
# The prototype assignment should NOT be transformed
@@ -558,7 +561,10 @@ describe('Calculator', () => {
});
"""
instrumented = _instrument_js_test_code(
code=test_code, function_to_optimize=make_func("add", class_name="Calculator"), test_file_path="test.js", mode="behavior"
code=test_code,
function_to_optimize=make_func("add", class_name="Calculator"),
test_file_path="test.js",
mode="behavior",
)
# describe and test structure should be preserved
@@ -886,15 +892,15 @@ test('should compute fibonacci(20) and fibonacci(30) to known values', () => {
from codeflash.languages.javascript.instrument import transform_standalone_calls
func = make_func("fibonacci")
code = '''
code = """
test("should compute fibonacci(20) correctly", () => {
const result = fibonacci(10);
});
'''
"""
transformed, _counter = transform_standalone_calls(code, func, "capture")
# The function call in the test description should NOT be transformed
assert 'fibonacci(20)' in transformed
assert "fibonacci(20)" in transformed
# The actual call should be transformed
assert "codeflash.capture('fibonacci'" in transformed

View file

@@ -8,13 +8,11 @@ These tests call the actual backend /testgen API endpoint and verify:
Similar to test_validate_python_code.py but for JavaScript/TypeScript.
"""
from pathlib import Path
from unittest.mock import patch
import pytest
from codeflash.api.aiservice import AiServiceClient
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.languages.base import Language
from codeflash.models.models import CodeString, OptimizedCandidateSource
@@ -23,6 +21,7 @@ def skip_if_js_not_supported():
"""Skip test if JavaScript/TypeScript languages are not supported."""
try:
from codeflash.languages import get_language_support
get_language_support(Language.JAVASCRIPT)
except Exception as e:
pytest.skip(f"JavaScript/TypeScript language support not available: {e}")
@@ -218,12 +217,13 @@ export function add(a: number, b: number): number {
def capture_request(*args, **kwargs):
nonlocal captured_payload
if 'payload' in kwargs:
captured_payload = kwargs['payload']
if "payload" in kwargs:
captured_payload = kwargs["payload"]
elif len(args) > 1:
captured_payload = args[1]
# Return a mock response to avoid actual API call
from unittest.mock import MagicMock
mock_response = MagicMock()
mock_response.status_code = 200
mock_response.json.return_value = {
@@ -233,7 +233,7 @@ }
}
return mock_response
with patch.object(ai_client, 'make_ai_service_request', side_effect=capture_request):
with patch.object(ai_client, "make_ai_service_request", side_effect=capture_request):
ai_client.generate_regression_tests(
source_code_being_tested=ts_file.read_text(),
function_to_optimize=func,
@@ -248,8 +248,9 @@ export function add(a: number, b: number): number {
)
assert captured_payload is not None
assert captured_payload.get('language') == 'typescript', \
assert captured_payload.get("language") == "typescript", (
f"Expected language='typescript', got: {captured_payload.get('language')}"
)
def test_testgen_request_includes_javascript_language(self, tmp_path):
"""Verify the language parameter is sent as 'javascript' for .js files."""
@@ -279,11 +280,12 @@ module.exports = { add };
def capture_request(*args, **kwargs):
nonlocal captured_payload
if 'payload' in kwargs:
captured_payload = kwargs['payload']
if "payload" in kwargs:
captured_payload = kwargs["payload"]
elif len(args) > 1:
captured_payload = args[1]
from unittest.mock import MagicMock
mock_response = MagicMock()
mock_response.status_code = 200
mock_response.json.return_value = {
@@ -293,7 +295,7 @@ }
}
return mock_response
with patch.object(ai_client, 'make_ai_service_request', side_effect=capture_request):
with patch.object(ai_client, "make_ai_service_request", side_effect=capture_request):
ai_client.generate_regression_tests(
source_code_being_tested=js_file.read_text(),
function_to_optimize=func,
@@ -308,5 +310,6 @@ module.exports = { add };
)
assert captured_payload is not None
assert captured_payload.get('language') == 'javascript', \
assert captured_payload.get("language") == "javascript", (
f"Expected language='javascript', got: {captured_payload.get('language')}"
)

View file

@@ -1,5 +1,4 @@
"""Tests for JavaScript module system detection.
"""
"""Tests for JavaScript module system detection."""
import json
import tempfile

View file

@@ -86,9 +86,7 @@ export function add(a: number, b: number): number {
ts_support = get_language_support(Language.TYPESCRIPT)
code_context = ts_support.extract_code_context(func, tmp_path, tmp_path)
context = JavaScriptFunctionOptimizer._build_optimization_context(
code_context, ts_file, "typescript", tmp_path
)
context = JavaScriptFunctionOptimizer._build_optimization_context(code_context, ts_file, "typescript", tmp_path)
assert context.read_writable_code is not None
assert context.read_writable_code.language == "typescript"
@@ -193,8 +191,9 @@ export function add(a: number, b: number): number {
assert mock_request.called, "API request should have been made"
call_args = mock_request.call_args
payload = call_args[1].get("payload", call_args[0][1] if len(call_args[0]) > 1 else {})
assert payload.get("language") == "typescript", \
assert payload.get("language") == "typescript", (
f"Expected language='typescript', got language='{payload.get('language')}'"
)
class TestFunctionOptimizerForJavaScript:
@@ -328,9 +327,7 @@ describe('fibonacci', () => {
)
optimizer = FunctionOptimizer(
function_to_optimize=func_to_optimize,
test_cfg=test_config,
aiservice_client=MagicMock(),
function_to_optimize=func_to_optimize, test_cfg=test_config, aiservice_client=MagicMock()
)
assert optimizer is not None
@@ -363,9 +360,7 @@ describe('fibonacci', () => {
)
optimizer = FunctionOptimizer(
function_to_optimize=func_to_optimize,
test_cfg=test_config,
aiservice_client=MagicMock(),
function_to_optimize=func_to_optimize, test_cfg=test_config, aiservice_client=MagicMock()
)
assert optimizer is not None
@@ -398,9 +393,7 @@ describe('fibonacci', () => {
)
optimizer = JavaScriptFunctionOptimizer(
function_to_optimize=func_to_optimize,
test_cfg=test_config,
aiservice_client=MagicMock(),
function_to_optimize=func_to_optimize, test_cfg=test_config, aiservice_client=MagicMock()
)
result = optimizer.get_code_optimization_context()
@@ -437,9 +430,7 @@ describe('fibonacci', () => {
)
optimizer = JavaScriptFunctionOptimizer(
function_to_optimize=func_to_optimize,
test_cfg=test_config,
aiservice_client=MagicMock(),
function_to_optimize=func_to_optimize, test_cfg=test_config, aiservice_client=MagicMock()
)
result = optimizer.get_code_optimization_context()
@@ -486,16 +477,11 @@ module.exports = { main };
)
test_config = TestConfig(
tests_root=tmp_path,
tests_project_rootdir=tmp_path,
project_root_path=tmp_path,
pytest_cmd="jest",
tests_root=tmp_path, tests_project_rootdir=tmp_path, project_root_path=tmp_path, pytest_cmd="jest"
)
optimizer = JavaScriptFunctionOptimizer(
function_to_optimize=func_to_optimize,
test_cfg=test_config,
aiservice_client=MagicMock(),
function_to_optimize=func_to_optimize, test_cfg=test_config, aiservice_client=MagicMock()
)
result = optimizer.get_code_optimization_context()
@@ -535,16 +521,11 @@ export function main(): number {
)
test_config = TestConfig(
tests_root=tmp_path,
tests_project_rootdir=tmp_path,
project_root_path=tmp_path,
pytest_cmd="vitest",
tests_root=tmp_path, tests_project_rootdir=tmp_path, project_root_path=tmp_path, pytest_cmd="vitest"
)
optimizer = JavaScriptFunctionOptimizer(
function_to_optimize=func_to_optimize,
test_cfg=test_config,
aiservice_client=MagicMock(),
function_to_optimize=func_to_optimize, test_cfg=test_config, aiservice_client=MagicMock()
)
result = optimizer.get_code_optimization_context()

View file

@@ -4,7 +4,6 @@ Tests the verify_requirements function that checks Node.js, npm, and test framew
"""
import json
import subprocess
from pathlib import Path
from unittest.mock import MagicMock, patch
@@ -30,14 +29,7 @@ class TestVerifyRequirements:
(node_modules / "codeflash").mkdir()
package_json = tmp_path / "package.json"
package_json.write_text(
json.dumps(
{
"name": "test-project",
"devDependencies": {"jest": "^29.0.0"},
}
)
)
package_json.write_text(json.dumps({"name": "test-project", "devDependencies": {"jest": "^29.0.0"}}))
return tmp_path
@pytest.fixture
@@ -49,14 +41,7 @@ class TestVerifyRequirements:
(node_modules / "codeflash").mkdir()
package_json = tmp_path / "package.json"
package_json.write_text(
json.dumps(
{
"name": "test-project",
"devDependencies": {"vitest": "^2.0.0"},
}
)
)
package_json.write_text(json.dumps({"name": "test-project", "devDependencies": {"vitest": "^2.0.0"}}))
return tmp_path
@pytest.fixture

View file

@@ -2,7 +2,7 @@
import tempfile
from pathlib import Path
from unittest.mock import patch, MagicMock
from unittest.mock import MagicMock, patch
import pytest
@@ -58,10 +58,7 @@ class TestJestRootsConfiguration:
try:
run_jest_behavioral_tests(
test_paths=mock_test_files,
test_env={},
cwd=tmpdir_path,
project_root=tmpdir_path,
test_paths=mock_test_files, test_env={}, cwd=tmpdir_path, project_root=tmpdir_path
)
except Exception:
pass # Expected to fail since no real Jest
@@ -79,9 +76,9 @@ class TestJestRootsConfiguration:
# Should have added the test directory as a root
assert len(roots_flags) > 0, "Expected --roots flag in Jest command"
assert str(test_dir) in roots_flags or any(
str(test_dir) in root for root in roots_flags
), f"Expected test directory {test_dir} in --roots flags: {roots_flags}"
assert str(test_dir) in roots_flags or any(str(test_dir) in root for root in roots_flags), (
f"Expected test directory {test_dir} in --roots flags: {roots_flags}"
)
def test_benchmarking_tests_adds_roots_for_test_directories(self):
"""Test that run_jest_benchmarking_tests adds --roots for test directories."""
@@ -106,7 +103,7 @@ class TestJestRootsConfiguration:
instrumented_behavior_file_path=test_file,
benchmarking_file_path=test_file,
test_type=TestType.GENERATED_REGRESSION,
),
)
]
)
@@ -119,10 +116,7 @@ class TestJestRootsConfiguration:
try:
run_jest_benchmarking_tests(
test_paths=mock_test_files,
test_env={},
cwd=tmpdir_path,
project_root=tmpdir_path,
test_paths=mock_test_files, test_env={}, cwd=tmpdir_path, project_root=tmpdir_path
)
except Exception:
pass
@@ -161,7 +155,7 @@ class TestJestRootsConfiguration:
instrumented_behavior_file_path=test_file,
benchmarking_file_path=test_file,
test_type=TestType.GENERATED_REGRESSION,
),
)
]
)
@@ -174,10 +168,7 @@ class TestJestRootsConfiguration:
try:
run_jest_line_profile_tests(
test_paths=mock_test_files,
test_env={},
cwd=tmpdir_path,
project_root=tmpdir_path,
test_paths=mock_test_files, test_env={}, cwd=tmpdir_path, project_root=tmpdir_path
)
except Exception:
pass
@@ -239,10 +230,7 @@ class TestJestRootsConfiguration:
try:
run_jest_behavioral_tests(
test_paths=mock_test_files,
test_env={},
cwd=tmpdir_path,
project_root=tmpdir_path,
test_paths=mock_test_files, test_env={}, cwd=tmpdir_path, project_root=tmpdir_path
)
except Exception:
pass
@@ -286,7 +274,7 @@ class TestVitestTimeoutConfiguration:
instrumented_behavior_file_path=test_file,
benchmarking_file_path=test_file,
test_type=TestType.GENERATED_REGRESSION,
),
)
]
)
@@ -314,7 +302,9 @@ class TestVitestTimeoutConfiguration:
# Subprocess timeout should be at least 120 seconds (minimum)
# or 10x the per-test timeout (150 seconds)
assert subprocess_timeout >= 120, f"Expected subprocess timeout >= 120s, got {subprocess_timeout}s"
assert subprocess_timeout >= 15 * 10, f"Expected subprocess timeout >= 150s (10x per-test), got {subprocess_timeout}s"
assert subprocess_timeout >= 15 * 10, (
f"Expected subprocess timeout >= 150s (10x per-test), got {subprocess_timeout}s"
)
def test_vitest_line_profile_subprocess_timeout_larger_than_test_timeout(self):
"""Test that subprocess timeout is larger than per-test timeout for Vitest line profile tests."""
@ -339,7 +329,7 @@ class TestVitestTimeoutConfiguration:
instrumented_behavior_file_path=test_file,
benchmarking_file_path=test_file,
test_type=TestType.GENERATED_REGRESSION,
),
)
]
)
@@ -351,11 +341,7 @@ class TestVitestTimeoutConfiguration:
mock_run.return_value = mock_result
run_vitest_line_profile_tests(
test_paths=mock_test_files,
test_env={},
cwd=tmpdir_path,
timeout=15,
project_root=tmpdir_path,
test_paths=mock_test_files, test_env={}, cwd=tmpdir_path, timeout=15, project_root=tmpdir_path
)
assert mock_run.called
@@ -387,7 +373,7 @@ class TestVitestTimeoutConfiguration:
instrumented_behavior_file_path=test_file,
benchmarking_file_path=test_file,
test_type=TestType.GENERATED_REGRESSION,
),
)
]
)
@@ -400,10 +386,7 @@ class TestVitestTimeoutConfiguration:
# Run without specifying a timeout
run_vitest_behavioral_tests(
test_paths=mock_test_files,
test_env={},
cwd=tmpdir_path,
project_root=tmpdir_path,
test_paths=mock_test_files, test_env={}, cwd=tmpdir_path, project_root=tmpdir_path
)
assert mock_run.called
@@ -445,7 +428,7 @@ class TestVitestInternalLoopingConfiguration:
instrumented_behavior_file_path=test_file,
benchmarking_file_path=test_file,
test_type=TestType.GENERATED_REGRESSION,
),
)
]
)
@@ -503,7 +486,7 @@ class TestVitestInternalLoopingConfiguration:
instrumented_behavior_file_path=test_file,
benchmarking_file_path=test_file,
test_type=TestType.GENERATED_REGRESSION,
),
)
]
)
@@ -550,13 +533,7 @@ class TestBundlerModuleResolutionFix:
tmpdir_path = Path(tmpdir)
# Create tsconfig with bundler moduleResolution
tsconfig = {
"compilerOptions": {
"moduleResolution": "bundler",
"module": "preserve",
"target": "ES2022",
}
}
tsconfig = {"compilerOptions": {"moduleResolution": "bundler", "module": "preserve", "target": "ES2022"}}
(tmpdir_path / "tsconfig.json").write_text(json.dumps(tsconfig))
assert _detect_bundler_module_resolution(tmpdir_path) is True
@@ -571,12 +548,7 @@ class TestBundlerModuleResolutionFix:
tmpdir_path = Path(tmpdir)
# Create tsconfig with Node moduleResolution
tsconfig = {
"compilerOptions": {
"moduleResolution": "Node",
"module": "ESNext",
}
}
tsconfig = {"compilerOptions": {"moduleResolution": "Node", "module": "ESNext"}}
(tmpdir_path / "tsconfig.json").write_text(json.dumps(tsconfig))
assert _detect_bundler_module_resolution(tmpdir_path) is False
@@ -601,21 +573,11 @@ class TestBundlerModuleResolutionFix:
# Create a base config with bundler in a subdirectory (simulating node_modules)
node_modules = tmpdir_path / "node_modules" / "@myorg" / "tsconfig"
node_modules.mkdir(parents=True)
base_tsconfig = {
"compilerOptions": {
"moduleResolution": "bundler",
"module": "preserve",
}
}
base_tsconfig = {"compilerOptions": {"moduleResolution": "bundler", "module": "preserve"}}
(node_modules / "tsconfig.json").write_text(json.dumps(base_tsconfig))
# Create a project tsconfig that extends the base
project_tsconfig = {
"extends": "@myorg/tsconfig/tsconfig.json",
"compilerOptions": {
"target": "ES2022",
}
}
project_tsconfig = {"extends": "@myorg/tsconfig/tsconfig.json", "compilerOptions": {"target": "ES2022"}}
(tmpdir_path / "tsconfig.json").write_text(json.dumps(project_tsconfig))
# Should detect bundler from extended config
@ -632,11 +594,7 @@ class TestBundlerModuleResolutionFix:
# Create original tsconfig
original_tsconfig = {
"compilerOptions": {
"moduleResolution": "bundler",
"module": "preserve",
"target": "ES2022",
},
"compilerOptions": {"moduleResolution": "bundler", "module": "preserve", "target": "ES2022"},
"include": ["src/**/*.ts"],
"exclude": ["node_modules"],
}
@ -683,12 +641,7 @@ class TestBundlerModuleResolutionFix:
tmpdir_path = Path(tmpdir)
# Create tsconfig with bundler
tsconfig = {
"compilerOptions": {
"moduleResolution": "bundler",
"module": "preserve",
}
}
tsconfig = {"compilerOptions": {"moduleResolution": "bundler", "module": "preserve"}}
(tmpdir_path / "tsconfig.json").write_text(json.dumps(tsconfig))
(tmpdir_path / "package.json").write_text('{"name": "test"}')
@ -709,12 +662,7 @@ class TestBundlerModuleResolutionFix:
tmpdir_path = Path(tmpdir)
# Create tsconfig with Node moduleResolution
tsconfig = {
"compilerOptions": {
"moduleResolution": "Node",
"module": "ESNext",
}
}
tsconfig = {"compilerOptions": {"moduleResolution": "Node", "module": "ESNext"}}
(tmpdir_path / "tsconfig.json").write_text(json.dumps(tsconfig))
(tmpdir_path / "package.json").write_text('{"name": "test"}')
@ -772,7 +720,7 @@ class TestBundledJestReporter:
instrumented_behavior_file_path=test_file,
benchmarking_file_path=test_file,
test_type=TestType.GENERATED_REGRESSION,
),
)
]
)
@ -785,10 +733,7 @@ class TestBundledJestReporter:
try:
run_jest_behavioral_tests(
test_paths=mock_test_files,
test_env={},
cwd=tmpdir_path,
project_root=tmpdir_path,
test_paths=mock_test_files, test_env={}, cwd=tmpdir_path, project_root=tmpdir_path
)
except Exception:
pass
@ -796,7 +741,9 @@ class TestBundledJestReporter:
if mock_run.called:
cmd = mock_run.call_args[0][0]
reporter_args = [a for a in cmd if "--reporters=" in a and "jest-reporter" in a]
assert len(reporter_args) == 1, f"Expected exactly one codeflash/jest-reporter flag, got: {reporter_args}"
assert len(reporter_args) == 1, (
f"Expected exactly one codeflash/jest-reporter flag, got: {reporter_args}"
)
assert reporter_args[0] == "--reporters=codeflash/jest-reporter"
# Must NOT reference jest-junit
jest_junit_args = [a for a in cmd if "jest-junit" in a]
@ -823,7 +770,7 @@ class TestBundledJestReporter:
instrumented_behavior_file_path=test_file,
benchmarking_file_path=test_file,
test_type=TestType.GENERATED_REGRESSION,
),
)
]
)
@ -836,10 +783,7 @@ class TestBundledJestReporter:
try:
run_jest_benchmarking_tests(
test_paths=mock_test_files,
test_env={},
cwd=tmpdir_path,
project_root=tmpdir_path,
test_paths=mock_test_files, test_env={}, cwd=tmpdir_path, project_root=tmpdir_path
)
except Exception:
pass
@ -870,7 +814,7 @@ class TestBundledJestReporter:
instrumented_behavior_file_path=test_file,
benchmarking_file_path=test_file,
test_type=TestType.GENERATED_REGRESSION,
),
)
]
)
@ -883,10 +827,7 @@ class TestBundledJestReporter:
try:
run_jest_line_profile_tests(
test_paths=mock_test_files,
test_env={},
cwd=tmpdir_path,
project_root=tmpdir_path,
test_paths=mock_test_files, test_env={}, cwd=tmpdir_path, project_root=tmpdir_path
)
except Exception:
pass
@ -963,12 +904,7 @@ reporter.onRunComplete([], results);
console.log('OK');
""")
result = subprocess.run(
["node", str(test_script)],
capture_output=True,
text=True,
timeout=10,
)
result = subprocess.run(["node", str(test_script)], capture_output=True, text=True, timeout=10)
assert result.returncode == 0, f"Reporter script failed: {result.stderr}"
assert output_file.exists(), "Reporter did not create output file"
@ -1020,7 +956,6 @@ console.log('OK');
assert exports["./jest-reporter"]["require"] == "./runtime/jest-reporter.js"
class TestUnsupportedFrameworkError:
"""Tests for clear error on unsupported test frameworks."""
@ -1030,12 +965,7 @@ class TestUnsupportedFrameworkError:
support = JavaScriptSupport()
with pytest.raises(NotImplementedError, match="not yet supported"):
support.run_behavioral_tests(
test_paths=MagicMock(),
test_env={},
cwd=Path("."),
test_framework="tap",
)
support.run_behavioral_tests(test_paths=MagicMock(), test_env={}, cwd=Path(), test_framework="tap")
def test_unknown_framework_raises_error_benchmarking(self):
"""run_benchmarking_tests should raise NotImplementedError for unknown frameworks."""
@ -1043,12 +973,7 @@ class TestUnsupportedFrameworkError:
support = JavaScriptSupport()
with pytest.raises(NotImplementedError, match="not yet supported"):
support.run_benchmarking_tests(
test_paths=MagicMock(),
test_env={},
cwd=Path("."),
test_framework="tap",
)
support.run_benchmarking_tests(test_paths=MagicMock(), test_env={}, cwd=Path(), test_framework="tap")
def test_unknown_framework_raises_error_line_profile(self):
"""run_line_profile_tests should raise NotImplementedError for unknown frameworks."""
@ -1056,42 +981,27 @@ class TestUnsupportedFrameworkError:
support = JavaScriptSupport()
with pytest.raises(NotImplementedError, match="not yet supported"):
support.run_line_profile_tests(
test_paths=MagicMock(),
test_env={},
cwd=Path("."),
test_framework="tap",
)
support.run_line_profile_tests(test_paths=MagicMock(), test_env={}, cwd=Path(), test_framework="tap")
def test_jest_framework_does_not_raise_not_implemented(self):
"""jest framework should NOT raise NotImplementedError."""
"""Jest framework should NOT raise NotImplementedError."""
from codeflash.languages.javascript.support import JavaScriptSupport
support = JavaScriptSupport()
try:
support.run_behavioral_tests(
test_paths=MagicMock(),
test_env={},
cwd=Path("."),
test_framework="jest",
)
support.run_behavioral_tests(test_paths=MagicMock(), test_env={}, cwd=Path(), test_framework="jest")
except NotImplementedError:
pytest.fail("jest framework should not raise NotImplementedError")
except Exception:
pass # Other exceptions are fine — Jest isn't installed in test env
def test_mocha_framework_does_not_raise_not_implemented(self):
"""mocha framework should NOT raise NotImplementedError."""
"""Mocha framework should NOT raise NotImplementedError."""
from codeflash.languages.javascript.support import JavaScriptSupport
support = JavaScriptSupport()
try:
support.run_behavioral_tests(
test_paths=MagicMock(),
test_env={},
cwd=Path("."),
test_framework="mocha",
)
support.run_behavioral_tests(test_paths=MagicMock(), test_env={}, cwd=Path(), test_framework="mocha")
except NotImplementedError:
pytest.fail("mocha framework should not raise NotImplementedError")
except Exception:

@ -8,12 +8,13 @@ from pathlib import Path
from unittest.mock import MagicMock
import pytest
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.languages.base import Language
from codeflash.languages.javascript.function_optimizer import JavaScriptFunctionOptimizer
from codeflash.languages.javascript.support import JavaScriptSupport, TypeScriptSupport
from codeflash.languages.registry import get_language_support
from codeflash.models.models import FunctionParent
from codeflash.languages.javascript.function_optimizer import JavaScriptFunctionOptimizer
from codeflash.verification.verification_utils import TestConfig
FIXTURES_DIR = Path(__file__).parent / "fixtures"
@ -5,7 +5,6 @@ import tempfile
from pathlib import Path
from unittest.mock import MagicMock, patch
import pytest
from junitparser import JUnitXml
@ -19,12 +18,7 @@ class TestMochaJsonToJunitXml:
{
"stats": {"tests": 2, "passes": 2, "failures": 0, "duration": 50},
"tests": [
{
"title": "should add numbers",
"fullTitle": "math should add numbers",
"duration": 20,
"err": {},
},
{"title": "should add numbers", "fullTitle": "math should add numbers", "duration": 20, "err": {}},
{
"title": "should subtract numbers",
"fullTitle": "math should subtract numbers",
@ -62,7 +56,7 @@ class TestMochaJsonToJunitXml:
"message": "expected 1 to equal 2",
"stack": "AssertionError: expected 1 to equal 2\n at Context.<anonymous>",
},
},
}
],
"passes": [],
"failures": [],
@ -92,7 +86,7 @@ class TestMochaJsonToJunitXml:
"duration": 0,
"pending": True,
"err": {},
},
}
],
"passes": [],
"failures": [],
@ -198,9 +192,7 @@ class TestMochaJsonToJunitXml:
mocha_json = json.dumps(
{
"stats": {"tests": 1, "passes": 1, "failures": 0, "duration": 10},
"tests": [
{"title": "test1", "fullTitle": "someOtherSuite test1", "duration": 10, "err": {}},
],
"tests": [{"title": "test1", "fullTitle": "someOtherSuite test1", "duration": 10, "err": {}}],
"passes": [],
"failures": [],
"pending": [],
@ -229,9 +221,7 @@ class TestMochaJsonToJunitXml:
mocha_json = json.dumps(
{
"stats": {"tests": 1, "passes": 1, "failures": 0, "duration": 10},
"tests": [
{"title": "test1", "fullTitle": "suite test1", "duration": 10, "err": {}},
],
"tests": [{"title": "test1", "fullTitle": "suite test1", "duration": 10, "err": {}}],
"passes": [],
"failures": [],
"pending": [],
@ -435,7 +425,13 @@ class TestRunMochaBehavioralTests:
from codeflash.models.test_type import TestType
mocha_output = json.dumps(
{"stats": {"tests": 1, "passes": 1, "failures": 0, "duration": 10}, "tests": [{"title": "t", "fullTitle": "s t", "duration": 10, "err": {}}], "passes": [], "failures": [], "pending": []}
{
"stats": {"tests": 1, "passes": 1, "failures": 0, "duration": 10},
"tests": [{"title": "t", "fullTitle": "s t", "duration": 10, "err": {}}],
"passes": [],
"failures": [],
"pending": [],
}
)
mock_run.return_value = MagicMock(returncode=0, stdout=mocha_output, stderr="", args=[])
@ -457,10 +453,7 @@ class TestRunMochaBehavioralTests:
)
result_file, result, cov, _ = run_mocha_behavioral_tests(
test_paths=test_paths,
test_env={},
cwd=tmpdir_path,
candidate_index=3,
test_paths=test_paths, test_env={}, cwd=tmpdir_path, candidate_index=3
)
# Verify env vars were passed
@ -478,7 +471,13 @@ class TestRunMochaBehavioralTests:
from codeflash.models.test_type import TestType
mocha_output = json.dumps(
{"stats": {"tests": 0, "passes": 0, "failures": 0, "duration": 0}, "tests": [], "passes": [], "failures": [], "pending": []}
{
"stats": {"tests": 0, "passes": 0, "failures": 0, "duration": 0},
"tests": [],
"passes": [],
"failures": [],
"pending": [],
}
)
mock_run.return_value = MagicMock(returncode=0, stdout=mocha_output, stderr="", args=[])
@ -499,11 +498,7 @@ class TestRunMochaBehavioralTests:
]
)
_, _, coverage_path, _ = run_mocha_behavioral_tests(
test_paths=test_paths,
test_env={},
cwd=tmpdir_path,
)
_, _, coverage_path, _ = run_mocha_behavioral_tests(test_paths=test_paths, test_env={}, cwd=tmpdir_path)
assert coverage_path is None
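Reviewer note: the Mocha hunks above repeatedly build the JSON reporter payload (`stats` / `tests` / `err` / `pending`) that gets converted to JUnit XML. A hypothetical sketch of that conversion — the function name and the exact XML shape are assumptions; only the Mocha JSON field names are taken from the tests:

```python
import json
from xml.sax.saxutils import escape


def mocha_json_to_junit_xml(mocha_json: str, suite_name: str = "mocha") -> str:
    # Convert Mocha's JSON reporter output to a minimal JUnit XML document.
    report = json.loads(mocha_json)
    stats = report.get("stats", {})
    cases = []
    for test in report.get("tests", []):
        name = escape(test.get("fullTitle") or test.get("title", ""), {'"': "&quot;"})
        time_s = (test.get("duration") or 0) / 1000.0  # Mocha reports ms
        err = test.get("err") or {}
        if err.get("message"):
            msg = escape(err["message"], {'"': "&quot;"})
            body = f'<failure message="{msg}">{escape(err.get("stack", ""))}</failure>'
        elif test.get("pending"):
            body = "<skipped/>"
        else:
            body = ""
        cases.append(f'<testcase name="{name}" time="{time_s:.3f}">{body}</testcase>')
    return (
        f'<testsuite name="{suite_name}" tests="{stats.get("tests", 0)}" '
        f'failures="{stats.get("failures", 0)}">{"".join(cases)}</testsuite>'
    )
```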
@ -518,7 +513,13 @@ class TestRunMochaBenchmarkingTests:
from codeflash.models.test_type import TestType
mocha_output = json.dumps(
{"stats": {"tests": 1, "passes": 1, "failures": 0, "duration": 100}, "tests": [{"title": "perf", "fullTitle": "bench perf", "duration": 100, "err": {}}], "passes": [], "failures": [], "pending": []}
{
"stats": {"tests": 1, "passes": 1, "failures": 0, "duration": 100},
"tests": [{"title": "perf", "fullTitle": "bench perf", "duration": 100, "err": {}}],
"passes": [],
"failures": [],
"pending": [],
}
)
mock_run.return_value = MagicMock(returncode=0, stdout=mocha_output, stderr="", args=[])
@ -729,7 +730,13 @@ class TestRunMochaLineProfileTests:
from codeflash.models.test_type import TestType
mocha_output = json.dumps(
{"stats": {"tests": 0, "passes": 0, "failures": 0, "duration": 0}, "tests": [], "passes": [], "failures": [], "pending": []}
{
"stats": {"tests": 0, "passes": 0, "failures": 0, "duration": 0},
"tests": [],
"passes": [],
"failures": [],
"pending": [],
}
)
mock_run.return_value = MagicMock(returncode=0, stdout=mocha_output, stderr="", args=[])
@ -752,10 +759,7 @@ class TestRunMochaLineProfileTests:
)
run_mocha_line_profile_tests(
test_paths=test_paths,
test_env={},
cwd=tmpdir_path,
line_profile_output_file=profile_output,
test_paths=test_paths, test_env={}, cwd=tmpdir_path, line_profile_output_file=profile_output
)
call_kwargs = mock_run.call_args
@ -769,7 +773,8 @@ class TestParserUnknownTestNameFallback:
def test_unknown_markers_matched_to_first_testcase(self):
"""When capturePerf markers have 'unknown' test name (Vitest beforeEach not firing),
the parser should still match them to testcases via the fallback logic."""
the parser should still match them to testcases via the fallback logic.
"""
from codeflash.languages.javascript.parse import parse_jest_test_xml
from codeflash.models.models import TestFile, TestFiles
from codeflash.models.test_type import TestType
@ -817,10 +822,7 @@ class TestParserUnknownTestNameFallback:
test_config.test_framework = "vitest"
results = parse_jest_test_xml(
test_xml_file_path=xml_path,
test_files=test_files,
test_config=test_config,
run_result=mock_result,
test_xml_file_path=xml_path, test_files=test_files, test_config=test_config, run_result=mock_result
)
# The "unknown" fallback should assign all 5 markers to the testcase
@ -272,8 +272,8 @@ class TestClearFunctions:
assert not is_language_supported(Language.PYTHON)
# Re-register all languages by importing
from codeflash.languages.python.support import PythonSupport
from codeflash.languages.javascript.support import JavaScriptSupport, TypeScriptSupport
from codeflash.languages.python.support import PythonSupport
# Need to manually register since decorator already ran
register_language(PythonSupport)
@ -839,7 +839,7 @@ class TestNamedExportConstArrow:
return TreeSitterAnalyzer(TreeSitterLanguage.JAVASCRIPT)
def test_named_export_const_arrow(self, ts_analyzer):
"""const arrow function exported via separate export { } clause."""
"""Const arrow function exported via separate export { } clause."""
code = """const joinBy = (arr: string[], separator: string) => {
return arr.join(separator);
};
@ -852,7 +852,7 @@ export { joinBy };"""
assert joinBy.is_exported is True
def test_named_export_alias(self, ts_analyzer):
"""export { foo as bar } — foo should be marked as exported."""
"""Export { foo as bar } — foo should be marked as exported."""
code = """const foo = (x: number) => {
return x * 2;
};
@ -60,8 +60,9 @@ class TestTypeScriptFunctionDiscovery:
# Critical: Verify language is "typescript", not "javascript"
for func in func_list:
assert func.language == "typescript", \
assert func.language == "typescript", (
f"Function {func.function_name} should have language='typescript', got '{func.language}'"
)
def test_discover_functions_with_type_annotations(self):
"""Test discovering TypeScript functions with type annotations."""
@ -176,11 +177,7 @@ function multiply(a: number, b: number): number {
ts_support = get_language_support(Language.TYPESCRIPT)
func_info = FunctionInfo(
function_name="add",
file_path=Path("/tmp/test.ts"),
starting_line=2,
ending_line=4,
language="typescript"
function_name="add", file_path=Path("/tmp/test.ts"), starting_line=2, ending_line=4, language="typescript"
)
result = ts_support.replace_function(original_source, func_info, new_function)
@ -227,7 +224,7 @@ function processConfig(config: Config): string {
file_path=Path("/tmp/test.ts"),
starting_line=7,
ending_line=9,
language="typescript"
language="typescript",
)
result = ts_support.replace_function(original_source, func_info, new_function)
@ -264,11 +261,7 @@ class TestTypeScriptTestDiscovery:
fib_file = ts_project_dir / "fibonacci.ts"
func_info = FunctionInfo(
function_name="fibonacci",
file_path=fib_file,
starting_line=1,
ending_line=7,
language="typescript"
function_name="fibonacci", file_path=fib_file, starting_line=1, ending_line=7, language="typescript"
)
tests = ts_support.discover_tests(test_root, [func_info])
@ -328,7 +321,7 @@ export function standalone(x: number): number {
CodeString(
code="function add(a: number, b: number): number { return a + b; }",
file_path=Path("test.ts"),
language="typescript"
language="typescript",
)
],
language="typescript",
@ -301,15 +301,7 @@ class TestVitestVsJestDetection:
package_json = tmp_path / "package.json"
package_json.write_text(
json.dumps(
{
"name": "test",
"devDependencies": {
"vitest": "^2.0.0",
"jest": "^29.0.0",
},
}
)
json.dumps({"name": "test", "devDependencies": {"vitest": "^2.0.0", "jest": "^29.0.0"}})
)
package_data = get_package_json_data(package_json)
@ -23,7 +23,13 @@ def pytest_loops_instance(pytestconfig: Config) -> PytestLoops:
@pytest.fixture
def mock_item() -> type:
class MockItem:
def __init__(self, function: types.FunctionType, name: str = "test_func", cls: type = None, module: types.ModuleType = None) -> None:
def __init__(
self,
function: types.FunctionType,
name: str = "test_func",
cls: type = None,
module: types.ModuleType = None,
) -> None:
self.function = function
self.name = name
self.cls = cls
@ -352,7 +358,9 @@ obj.my_method(5)
item = mock_item(no_cache_func)
pytest_loops_instance._clear_lru_caches(item) # noqa: SLF001
def test_clears_module_level_caches_via_sys_modules(self, pytest_loops_instance: PytestLoops, mock_item: type) -> None:
def test_clears_module_level_caches_via_sys_modules(
self, pytest_loops_instance: PytestLoops, mock_item: type
) -> None:
module_name = "_cf_test_module_scan"
source_code = """
import functools
@ -127,7 +127,9 @@ code_to_optimize/tests/test_simple.py:10: AssertionError
)
assert "TestCalculator.test_divide_by_zero" in errors
assert errors["TestCalculator.test_divide_by_zero"] == """
assert (
errors["TestCalculator.test_divide_by_zero"]
== """
class TestCalculator:
def test_divide_by_zero(self):
> Calculator().divide(10, 0)
@ -135,6 +137,7 @@ E ZeroDivisionError: division by zero
code_to_optimize/tests/test_calculator.py:22: ZeroDivisionError
"""
)
def test_extracting_from_invalid_pytest_stdout():
@ -1,11 +1,6 @@
"""Tests for the regex patterns and string matching in parse_test_output.py."""
from codeflash.verification.parse_test_output import (
matches_re_end,
matches_re_start,
parse_test_failures_from_stdout,
)
from codeflash.verification.parse_test_output import matches_re_end, matches_re_start, parse_test_failures_from_stdout
# --- matches_re_start tests ---
@ -42,10 +37,7 @@ class TestMatchesReStart:
assert m.groups() == ("mod", "", "test_fn", "f", "1", "x")
def test_multiple_matches(self) -> None:
s = (
"!$######m1:C1.fn1:t1:1:a######$!\n"
"!$######m2:fn2:t2:2:b######$!\n"
)
s = "!$######m1:C1.fn1:t1:1:a######$!\n!$######m2:fn2:t2:2:b######$!\n"
matches = list(matches_re_start.finditer(s))
assert len(matches) == 2
assert matches[0].groups() == ("m1", "C1.", "fn1", "t1", "1", "a")
@ -170,20 +162,12 @@ class TestParseTestFailuresHeader:
def test_word_failures_without_equals_is_not_matched(self) -> None:
"""'FAILURES' without surrounding '=' signs should not trigger the header detection."""
stdout = (
"FAILURES detected in module\n"
"_______ test_baz _______\n"
"\n"
" assert False\n"
)
stdout = "FAILURES detected in module\n_______ test_baz _______\n\n assert False\n"
result = parse_test_failures_from_stdout(stdout)
assert result == {}
def test_failures_in_test_output_not_matched(self) -> None:
"""A test printing 'FAILURES' (no = signs) should not trigger header detection."""
stdout = (
"Testing FAILURES handling\n"
"All good\n"
)
stdout = "Testing FAILURES handling\nAll good\n"
result = parse_test_failures_from_stdout(stdout)
assert result == {}
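Reviewer note: the two negative tests above encode the key rule for detecting pytest's failures section: the literal word `FAILURES` only counts when padded with `=` signs, as pytest prints it. A minimal sketch of that header check (the regex and helper name are illustrative, not the actual `parse_test_output` internals):

```python
import re

# pytest prints the section header as e.g. "=========== FAILURES ===========";
# a bare "FAILURES" inside test output must not trigger detection.
FAILURES_HEADER_RE = re.compile(r"^=+ FAILURES =+$")


def is_failures_header(line: str) -> bool:
    return bool(FAILURES_HEADER_RE.match(line.strip()))
```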
@ -1,5 +1,3 @@
from codeflash.languages.python.context.unused_definition_remover import remove_unused_definitions_by_function_names
@ -74,10 +74,7 @@ class TestCodeflashConfig:
def test_to_pyproject_dict_minimal(self):
"""Should only include non-default values."""
config = CodeflashConfig(
language="python",
module_root="src",
)
config = CodeflashConfig(language="python", module_root="src")
result = config.to_pyproject_dict()
@ -149,11 +146,7 @@ class TestCodeflashConfig:
def test_from_package_json_dict(self):
"""Should create config from package.json dict."""
data = {
"moduleRoot": "lib",
"formatterCmds": ["npx prettier --write $file"],
"disableTelemetry": True,
}
data = {"moduleRoot": "lib", "formatterCmds": ["npx prettier --write $file"], "disableTelemetry": True}
config = CodeflashConfig.from_package_json_dict(data)
@ -168,11 +161,7 @@ class TestWritePyprojectToml:
def test_creates_new_pyproject(self, tmp_path):
"""Should create pyproject.toml if it doesn't exist."""
config = CodeflashConfig(
language="python",
module_root="src",
tests_root="tests",
)
config = CodeflashConfig(language="python", module_root="src", tests_root="tests")
success, message = _write_pyproject_toml(tmp_path, config)
@ -192,10 +181,7 @@ class TestWritePyprojectToml:
'[project]\nname = "myapp"\nversion = "1.0.0"\n\n[tool.ruff]\nline-length = 120'
)
config = CodeflashConfig(
language="python",
module_root="src",
)
config = CodeflashConfig(language="python", module_root="src")
success, message = _write_pyproject_toml(tmp_path, config)
@ -210,15 +196,9 @@ class TestWritePyprojectToml:
def test_updates_existing_codeflash_section(self, tmp_path):
"""Should update existing codeflash section."""
(tmp_path / "pyproject.toml").write_text(
'[tool.codeflash]\nmodule-root = "old"\ntests-root = "old_tests"'
)
(tmp_path / "pyproject.toml").write_text('[tool.codeflash]\nmodule-root = "old"\ntests-root = "old_tests"')
config = CodeflashConfig(
language="python",
module_root="new",
tests_root="new_tests",
)
config = CodeflashConfig(language="python", module_root="new", tests_root="new_tests")
success, message = _write_pyproject_toml(tmp_path, config)
@ -235,15 +215,10 @@ class TestWritePackageJson:
def test_adds_codeflash_section(self, tmp_path):
"""Should add codeflash section to package.json."""
(tmp_path / "package.json").write_text(json.dumps({
"name": "myapp",
"version": "1.0.0"
}, indent=2))
(tmp_path / "package.json").write_text(json.dumps({"name": "myapp", "version": "1.0.0"}, indent=2))
config = CodeflashConfig(
language="javascript",
module_root="lib",
formatter_cmds=["npx prettier --write $file"],
language="javascript", module_root="lib", formatter_cmds=["npx prettier --write $file"]
)
success, message = _write_package_json(tmp_path, config)
@ -259,16 +234,14 @@ class TestWritePackageJson:
def test_preserves_existing_content(self, tmp_path):
"""Should preserve existing package.json content."""
(tmp_path / "package.json").write_text(json.dumps({
"name": "myapp",
"dependencies": {"lodash": "^4.17.0"},
"devDependencies": {"jest": "^29.0.0"}
}, indent=2))
config = CodeflashConfig(
language="javascript",
module_root="lib",
(tmp_path / "package.json").write_text(
json.dumps(
{"name": "myapp", "dependencies": {"lodash": "^4.17.0"}, "devDependencies": {"jest": "^29.0.0"}},
indent=2,
)
)
config = CodeflashConfig(language="javascript", module_root="lib")
success, message = _write_package_json(tmp_path, config)
@ -281,10 +254,9 @@ class TestWritePackageJson:
def test_removes_empty_codeflash_section(self, tmp_path):
"""Should remove codeflash section if all defaults."""
(tmp_path / "package.json").write_text(json.dumps({
"name": "myapp",
"codeflash": {"moduleRoot": "old"}
}, indent=2))
(tmp_path / "package.json").write_text(
json.dumps({"name": "myapp", "codeflash": {"moduleRoot": "old"}}, indent=2)
)
# Config with all defaults - should result in empty dict
config = CodeflashConfig(
@ -342,9 +314,7 @@ class TestRemoveConfig:
def test_removes_from_pyproject(self, tmp_path):
"""Should remove codeflash section from pyproject.toml."""
(tmp_path / "pyproject.toml").write_text(
'[project]\nname = "test"\n\n[tool.codeflash]\nmodule-root = "src"'
)
(tmp_path / "pyproject.toml").write_text('[project]\nname = "test"\n\n[tool.codeflash]\nmodule-root = "src"')
success, message = remove_config(tmp_path, "python")
@ -357,10 +327,9 @@ class TestRemoveConfig:
def test_removes_from_package_json(self, tmp_path):
"""Should remove codeflash section from package.json."""
(tmp_path / "package.json").write_text(json.dumps({
"name": "test",
"codeflash": {"moduleRoot": "src"}
}, indent=2))
(tmp_path / "package.json").write_text(
json.dumps({"name": "test", "codeflash": {"moduleRoot": "src"}}, indent=2)
)
success, message = remove_config(tmp_path, "javascript")
@ -141,10 +141,9 @@ class TestDetectModuleRoot:
def test_js_detects_from_exports(self, tmp_path):
"""Should detect module root from package.json exports when no common src dir exists."""
(tmp_path / "package.json").write_text(json.dumps({
"name": "test",
"exports": {".": "./packages/core/index.js"}
}))
(tmp_path / "package.json").write_text(
json.dumps({"name": "test", "exports": {".": "./packages/core/index.js"}})
)
(tmp_path / "packages" / "core").mkdir(parents=True)
module_root, detail = _detect_js_module_root(tmp_path)
@ -161,11 +160,9 @@ class TestDetectModuleRoot:
def test_js_prefers_src_over_build_src(self, tmp_path):
"""Should prefer src/ over build/src/ even when package.json points to build/."""
(tmp_path / "package.json").write_text(json.dumps({
"name": "test",
"main": "build/src/index.js",
"module": "build/src/index.js"
}))
(tmp_path / "package.json").write_text(
json.dumps({"name": "test", "main": "build/src/index.js", "module": "build/src/index.js"})
)
(tmp_path / "src").mkdir()
(tmp_path / "build" / "src").mkdir(parents=True)
@ -175,10 +172,7 @@ class TestDetectModuleRoot:
def test_js_skips_build_dir_from_main(self, tmp_path):
"""Should skip build output directories from package.json main field."""
(tmp_path / "package.json").write_text(json.dumps({
"name": "test",
"main": "build/index.js"
}))
(tmp_path / "package.json").write_text(json.dumps({"name": "test", "main": "build/index.js"}))
(tmp_path / "build").mkdir()
module_root, detail = _detect_js_module_root(tmp_path)
@ -187,10 +181,7 @@ class TestDetectModuleRoot:
def test_js_skips_dist_dir_from_exports(self, tmp_path):
"""Should skip dist output directories from package.json exports field."""
(tmp_path / "package.json").write_text(json.dumps({
"name": "test",
"exports": {".": "./dist/index.js"}
}))
(tmp_path / "package.json").write_text(json.dumps({"name": "test", "exports": {".": "./dist/index.js"}}))
(tmp_path / "dist").mkdir()
module_root, detail = _detect_js_module_root(tmp_path)
@ -199,10 +190,7 @@ class TestDetectModuleRoot:
def test_js_skips_out_dir_from_module(self, tmp_path):
"""Should skip out output directories from package.json module field."""
(tmp_path / "package.json").write_text(json.dumps({
"name": "test",
"module": "out/esm/index.js"
}))
(tmp_path / "package.json").write_text(json.dumps({"name": "test", "module": "out/esm/index.js"}))
(tmp_path / "out" / "esm").mkdir(parents=True)
module_root, detail = _detect_js_module_root(tmp_path)
@ -211,10 +199,7 @@ class TestDetectModuleRoot:
def test_js_prefers_lib_over_build_dir(self, tmp_path):
"""Should prefer lib/ over build output directories."""
(tmp_path / "package.json").write_text(json.dumps({
"name": "test",
"main": "dist/index.js"
}))
(tmp_path / "package.json").write_text(json.dumps({"name": "test", "main": "dist/index.js"}))
(tmp_path / "lib").mkdir()
(tmp_path / "dist").mkdir()
@ -224,10 +209,7 @@ class TestDetectModuleRoot:
def test_js_prefers_source_over_build_dir(self, tmp_path):
"""Should prefer source/ over build output directories."""
(tmp_path / "package.json").write_text(json.dumps({
"name": "test",
"main": "build/index.js"
}))
(tmp_path / "package.json").write_text(json.dumps({"name": "test", "main": "build/index.js"}))
(tmp_path / "source").mkdir()
(tmp_path / "build").mkdir()
@ -237,10 +219,9 @@ class TestDetectModuleRoot:
def test_js_falls_back_to_valid_exports_path(self, tmp_path):
"""Should use exports path when no common source dirs exist and path is not build output."""
(tmp_path / "package.json").write_text(json.dumps({
"name": "test",
"exports": {".": "./packages/core/index.js"}
}))
(tmp_path / "package.json").write_text(
json.dumps({"name": "test", "exports": {".": "./packages/core/index.js"}})
)
(tmp_path / "packages" / "core").mkdir(parents=True)
module_root, detail = _detect_js_module_root(tmp_path)
@ -249,10 +230,7 @@ class TestDetectModuleRoot:
def test_js_falls_back_to_valid_main_path(self, tmp_path):
"""Should use main path when no common source dirs exist and path is not build output."""
(tmp_path / "package.json").write_text(json.dumps({
"name": "test",
"main": "packages/main/index.js"
}))
(tmp_path / "package.json").write_text(json.dumps({"name": "test", "main": "packages/main/index.js"}))
(tmp_path / "packages" / "main").mkdir(parents=True)
module_root, detail = _detect_js_module_root(tmp_path)
@ -261,10 +239,7 @@ class TestDetectModuleRoot:
def test_js_falls_back_to_valid_module_path(self, tmp_path):
"""Should use module path when no common source dirs exist and path is not build output."""
(tmp_path / "package.json").write_text(json.dumps({
"name": "test",
"module": "esm/index.js"
}))
(tmp_path / "package.json").write_text(json.dumps({"name": "test", "module": "esm/index.js"}))
(tmp_path / "esm").mkdir()
module_root, detail = _detect_js_module_root(tmp_path)
@ -273,12 +248,16 @@ class TestDetectModuleRoot:
def test_js_returns_project_root_when_all_paths_are_build_output(self, tmp_path):
"""Should return project root when all package.json paths point to build outputs."""
(tmp_path / "package.json").write_text(json.dumps({
(tmp_path / "package.json").write_text(
json.dumps(
{
"name": "test",
"main": "dist/cjs/index.js",
"module": "dist/esm/index.js",
"exports": {".": "./build/index.js"}
}))
"exports": {".": "./build/index.js"},
}
)
)
(tmp_path / "dist" / "cjs").mkdir(parents=True)
(tmp_path / "dist" / "esm").mkdir(parents=True)
(tmp_path / "build").mkdir()
@ -302,6 +281,7 @@ class TestIsBuildOutputDir:
def test_detects_build_dir(self):
"""Should detect build/ as build output."""
from pathlib import Path
assert is_build_output_dir(Path("build"))
assert is_build_output_dir(Path("build/src"))
assert is_build_output_dir(Path("build/src/index.js"))
@ -309,6 +289,7 @@ class TestIsBuildOutputDir:
def test_detects_dist_dir(self):
"""Should detect dist/ as build output."""
from pathlib import Path
assert is_build_output_dir(Path("dist"))
assert is_build_output_dir(Path("dist/esm"))
assert is_build_output_dir(Path("dist/cjs/index.js"))
@ -316,53 +297,62 @@ class TestIsBuildOutputDir:
def test_detects_out_dir(self):
"""Should detect out/ as build output."""
from pathlib import Path
assert is_build_output_dir(Path("out"))
assert is_build_output_dir(Path("out/src"))
def test_detects_next_dir(self):
"""Should detect .next/ as build output."""
from pathlib import Path
assert is_build_output_dir(Path(".next"))
assert is_build_output_dir(Path(".next/static"))
def test_detects_nuxt_dir(self):
"""Should detect .nuxt/ as build output."""
from pathlib import Path
assert is_build_output_dir(Path(".nuxt"))
assert is_build_output_dir(Path(".nuxt/dist"))
def test_detects_nested_build_dir(self):
"""Should detect build dir nested in path."""
from pathlib import Path
assert is_build_output_dir(Path("packages/build/index.js"))
assert is_build_output_dir(Path("foo/dist/bar"))
def test_does_not_detect_src(self):
"""Should not detect src/ as build output."""
from pathlib import Path
assert not is_build_output_dir(Path("src"))
assert not is_build_output_dir(Path("src/index.js"))
def test_does_not_detect_lib(self):
"""Should not detect lib/ as build output."""
from pathlib import Path
assert not is_build_output_dir(Path("lib"))
assert not is_build_output_dir(Path("lib/utils"))
def test_does_not_detect_source(self):
"""Should not detect source/ as build output."""
from pathlib import Path
assert not is_build_output_dir(Path("source"))
def test_does_not_detect_packages(self):
"""Should not detect packages/ as build output."""
from pathlib import Path
assert not is_build_output_dir(Path("packages"))
assert not is_build_output_dir(Path("packages/core"))
def test_does_not_detect_similar_names(self):
"""Should not detect directories with similar but different names."""
from pathlib import Path
assert not is_build_output_dir(Path("builder"))
assert not is_build_output_dir(Path("distribution"))
assert not is_build_output_dir(Path("output"))
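The assertions above pin down the contract of `is_build_output_dir`: exact-name matching on any path component (`packages/build/index.js` matches, `builder` does not), never substring matching. A minimal sketch consistent with these tests — the real codeflash implementation may differ:

```python
from pathlib import Path

# Hypothetical directory-name set inferred from the tests above; the real
# implementation may recognize more build-output names.
_BUILD_DIR_NAMES = {"build", "dist", "out", ".next", ".nuxt"}


def is_build_output_dir(path: Path) -> bool:
    """Return True if any component of `path` is a known build-output directory."""
    return any(part in _BUILD_DIR_NAMES for part in path.parts)
```

Comparing whole components via `Path.parts` is what keeps `builder` and `distribution` from false-matching `build` and `dist`.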
@@ -417,18 +407,14 @@ class TestDetectTestRunner:
def test_js_detects_jest_from_deps(self, tmp_path):
"""Should detect jest from devDependencies."""
(tmp_path / "package.json").write_text(json.dumps({
"devDependencies": {"jest": "^29.0.0"}
}))
(tmp_path / "package.json").write_text(json.dumps({"devDependencies": {"jest": "^29.0.0"}}))
runner, detail = _detect_js_test_runner(tmp_path)
assert runner == "jest"
def test_js_detects_vitest_from_deps(self, tmp_path):
"""Should detect vitest from devDependencies (preferred over jest)."""
(tmp_path / "package.json").write_text(json.dumps({
"devDependencies": {"vitest": "^1.0.0", "jest": "^29.0.0"}
}))
(tmp_path / "package.json").write_text(json.dumps({"devDependencies": {"vitest": "^1.0.0", "jest": "^29.0.0"}}))
runner, detail = _detect_js_test_runner(tmp_path)
assert runner == "vitest"
@@ -469,9 +455,7 @@ class TestDetectFormatter:
def test_js_detects_prettier_from_deps(self, tmp_path):
"""Should detect prettier from devDependencies."""
(tmp_path / "package.json").write_text(json.dumps({
"devDependencies": {"prettier": "^3.0.0"}
}))
(tmp_path / "package.json").write_text(json.dumps({"devDependencies": {"prettier": "^3.0.0"}}))
formatter, detail = _detect_js_formatter(tmp_path)
assert any("prettier" in cmd for cmd in formatter)
@@ -483,9 +467,7 @@ class TestDetectProject:
def test_detects_python_project(self, tmp_path):
"""Should correctly detect a Python project."""
# Create Python project structure
(tmp_path / "pyproject.toml").write_text(
'[project]\nname = "myapp"\n\n[tool.ruff]\nline-length = 120'
)
(tmp_path / "pyproject.toml").write_text('[project]\nname = "myapp"\n\n[tool.ruff]\nline-length = 120')
(tmp_path / "myapp").mkdir()
(tmp_path / "myapp" / "__init__.py").write_text("")
(tmp_path / "tests").mkdir()
@@ -503,10 +485,9 @@ class TestDetectProject:
def test_detects_javascript_project(self, tmp_path):
"""Should correctly detect a JavaScript project."""
# Create JS project structure
(tmp_path / "package.json").write_text(json.dumps({
"name": "myapp",
"devDependencies": {"jest": "^29.0.0", "prettier": "^3.0.0"}
}))
(tmp_path / "package.json").write_text(
json.dumps({"name": "myapp", "devDependencies": {"jest": "^29.0.0", "prettier": "^3.0.0"}})
)
(tmp_path / "src").mkdir()
(tmp_path / "tests").mkdir()
(tmp_path / ".git").mkdir()
@@ -523,10 +504,9 @@ class TestDetectProject:
def test_detects_typescript_project(self, tmp_path):
"""Should correctly detect a TypeScript project."""
# Create TS project structure
(tmp_path / "package.json").write_text(json.dumps({
"name": "myapp",
"devDependencies": {"vitest": "^1.0.0", "typescript": "^5.0.0"}
}))
(tmp_path / "package.json").write_text(
json.dumps({"name": "myapp", "devDependencies": {"vitest": "^1.0.0", "typescript": "^5.0.0"}})
)
(tmp_path / "tsconfig.json").write_text("{}")
(tmp_path / "src").mkdir()
(tmp_path / ".git").mkdir()
@@ -556,9 +536,7 @@ class TestHasExistingConfig:
def test_detects_pyproject_config(self, tmp_path):
"""Should detect config in pyproject.toml."""
(tmp_path / "pyproject.toml").write_text(
'[tool.codeflash]\nmodule-root = "src"'
)
(tmp_path / "pyproject.toml").write_text('[tool.codeflash]\nmodule-root = "src"')
has_config, config_type = has_existing_config(tmp_path)
assert has_config is True
@@ -566,10 +544,7 @@ class TestHasExistingConfig:
def test_detects_package_json_config(self, tmp_path):
"""Should detect config in package.json."""
(tmp_path / "package.json").write_text(json.dumps({
"name": "test",
"codeflash": {"moduleRoot": "src"}
}))
(tmp_path / "package.json").write_text(json.dumps({"name": "test", "codeflash": {"moduleRoot": "src"}}))
has_config, config_type = has_existing_config(tmp_path)
assert has_config is True


@@ -31,7 +31,8 @@ from codeflash.setup import (
def python_src_layout(tmp_path):
"""Create a Python project with src/ layout."""
# pyproject.toml with poetry
(tmp_path / "pyproject.toml").write_text("""
(tmp_path / "pyproject.toml").write_text(
"""
[tool.poetry]
name = "myapp"
version = "0.1.0"
@@ -41,7 +42,8 @@ line-length = 120
[tool.pytest.ini_options]
testpaths = ["tests"]
""".strip())
""".strip()
)
# src/myapp package
src_dir = tmp_path / "src" / "myapp"
@@ -66,14 +68,16 @@ testpaths = ["tests"]
@pytest.fixture
def python_flat_layout(tmp_path):
"""Create a Python project with flat layout (package at root)."""
(tmp_path / "pyproject.toml").write_text("""
(tmp_path / "pyproject.toml").write_text(
"""
[project]
name = "myapp"
version = "0.1.0"
[tool.black]
line-length = 88
""".strip())
""".strip()
)
# Package at root
pkg_dir = tmp_path / "myapp"
@@ -93,14 +97,16 @@ line-length = 88
@pytest.fixture
def python_setup_py_project(tmp_path):
"""Create a Python project with setup.py (legacy)."""
(tmp_path / "setup.py").write_text("""
(tmp_path / "setup.py").write_text(
"""
from setuptools import setup, find_packages
setup(
name="legacyapp",
version="1.0.0",
packages=find_packages(),
)
""".strip())
""".strip()
)
pkg_dir = tmp_path / "legacyapp"
pkg_dir.mkdir()
@@ -114,19 +120,18 @@ setup(
@pytest.fixture
def javascript_npm_project(tmp_path):
"""Create a JavaScript project with npm."""
(tmp_path / "package.json").write_text(json.dumps({
(tmp_path / "package.json").write_text(
json.dumps(
{
"name": "my-js-app",
"version": "1.0.0",
"main": "src/index.js",
"scripts": {
"test": "jest",
"lint": "eslint src/"
"scripts": {"test": "jest", "lint": "eslint src/"},
"devDependencies": {"jest": "^29.7.0", "prettier": "^3.0.0"},
},
"devDependencies": {
"jest": "^29.7.0",
"prettier": "^3.0.0"
}
}, indent=2))
indent=2,
)
)
(tmp_path / "package-lock.json").write_text("{}")
@@ -147,15 +152,17 @@ def javascript_npm_project(tmp_path):
@pytest.fixture
def javascript_yarn_project(tmp_path):
"""Create a JavaScript project with yarn."""
(tmp_path / "package.json").write_text(json.dumps({
(tmp_path / "package.json").write_text(
json.dumps(
{
"name": "yarn-app",
"version": "1.0.0",
"main": "lib/index.js",
"devDependencies": {
"jest": "^29.0.0",
"eslint": "^8.0.0"
}
}, indent=2))
"devDependencies": {"jest": "^29.0.0", "eslint": "^8.0.0"},
},
indent=2,
)
)
(tmp_path / "yarn.lock").write_text("# yarn lockfile")
@@ -171,16 +178,17 @@ def javascript_yarn_project(tmp_path):
@pytest.fixture
def javascript_pnpm_project(tmp_path):
"""Create a JavaScript project with pnpm."""
(tmp_path / "package.json").write_text(json.dumps({
(tmp_path / "package.json").write_text(
json.dumps(
{
"name": "pnpm-app",
"version": "1.0.0",
"exports": {
".": "./dist/index.js"
"exports": {".": "./dist/index.js"},
"devDependencies": {"vitest": "^1.0.0"},
},
"devDependencies": {
"vitest": "^1.0.0"
}
}, indent=2))
indent=2,
)
)
(tmp_path / "pnpm-lock.yaml").write_text("lockfileVersion: 5.4")
@@ -193,14 +201,17 @@ def javascript_pnpm_project(tmp_path):
@pytest.fixture
def javascript_bun_project(tmp_path):
"""Create a JavaScript project with bun."""
(tmp_path / "package.json").write_text(json.dumps({
(tmp_path / "package.json").write_text(
json.dumps(
{
"name": "bun-app",
"version": "1.0.0",
"module": "src/index.ts",
"devDependencies": {
"bun-types": "latest"
}
}, indent=2))
"devDependencies": {"bun-types": "latest"},
},
indent=2,
)
)
(tmp_path / "bun.lockb").write_bytes(b"bun lockfile")
@@ -212,32 +223,35 @@ def javascript_bun_project(tmp_path):
@pytest.fixture
def typescript_project(tmp_path):
"""Create a TypeScript project."""
(tmp_path / "package.json").write_text(json.dumps({
(tmp_path / "package.json").write_text(
json.dumps(
{
"name": "ts-app",
"version": "1.0.0",
"main": "dist/index.js",
"types": "dist/index.d.ts",
"scripts": {
"build": "tsc",
"test": "vitest"
"scripts": {"build": "tsc", "test": "vitest"},
"devDependencies": {"typescript": "^5.0.0", "vitest": "^1.0.0", "@types/node": "^20.0.0"},
},
"devDependencies": {
"typescript": "^5.0.0",
"vitest": "^1.0.0",
"@types/node": "^20.0.0"
}
}, indent=2))
indent=2,
)
)
(tmp_path / "tsconfig.json").write_text(json.dumps({
(tmp_path / "tsconfig.json").write_text(
json.dumps(
{
"compilerOptions": {
"target": "ES2020",
"module": "commonjs",
"outDir": "./dist",
"rootDir": "./src",
"strict": True
"strict": True,
},
"include": ["src/**/*"]
}, indent=2))
"include": ["src/**/*"],
},
indent=2,
)
)
src_dir = tmp_path / "src"
src_dir.mkdir()
@@ -255,7 +269,9 @@ def typescript_project(tmp_path):
@pytest.fixture
def typescript_react_project(tmp_path):
"""Create a TypeScript React project (like Create React App)."""
(tmp_path / "package.json").write_text(json.dumps({
(tmp_path / "package.json").write_text(
json.dumps(
{
"name": "react-app",
"version": "0.1.0",
"private": True,
@@ -263,27 +279,26 @@ def typescript_react_project(tmp_path):
"react": "^18.2.0",
"react-dom": "^18.2.0",
"react-scripts": "5.0.1",
"jest": "^29.0.0"
"jest": "^29.0.0",
},
"devDependencies": {
"@types/react": "^18.0.0",
"@testing-library/react": "^14.0.0",
"typescript": "^5.0.0"
"typescript": "^5.0.0",
},
"scripts": {
"start": "react-scripts start",
"build": "react-scripts build",
"test": "react-scripts test"
}
}, indent=2))
"test": "react-scripts test",
},
},
indent=2,
)
)
(tmp_path / "tsconfig.json").write_text(json.dumps({
"compilerOptions": {
"target": "es5",
"lib": ["dom", "es2015"],
"jsx": "react-jsx"
}
}, indent=2))
(tmp_path / "tsconfig.json").write_text(
json.dumps({"compilerOptions": {"target": "es5", "lib": ["dom", "es2015"], "jsx": "react-jsx"}}, indent=2)
)
src_dir = tmp_path / "src"
src_dir.mkdir()
@@ -299,7 +314,8 @@ def typescript_react_project(tmp_path):
@pytest.fixture
def project_with_existing_config(tmp_path):
"""Create a project with existing codeflash config."""
(tmp_path / "pyproject.toml").write_text("""
(tmp_path / "pyproject.toml").write_text(
"""
[project]
name = "configured-app"
@@ -307,7 +323,8 @@ name = "configured-app"
module-root = "src"
tests-root = "tests"
formatter-cmds = ["black $file"]
""".strip())
""".strip()
)
(tmp_path / "src").mkdir()
(tmp_path / "tests").mkdir()
@@ -319,13 +336,15 @@ formatter-cmds = ["black $file"]
def mixed_python_js_project(tmp_path):
"""Create a project with both Python and JS files (monorepo-like)."""
# Python backend
(tmp_path / "pyproject.toml").write_text("""
(tmp_path / "pyproject.toml").write_text(
"""
[project]
name = "fullstack-app"
[tool.codeflash]
module-root = "backend"
""".strip())
""".strip()
)
backend_dir = tmp_path / "backend"
backend_dir.mkdir()
@@ -335,10 +354,7 @@ module-root = "backend"
# JS frontend
frontend_dir = tmp_path / "frontend"
frontend_dir.mkdir()
(frontend_dir / "package.json").write_text(json.dumps({
"name": "frontend",
"devDependencies": {"jest": "^29.0.0"}
}))
(frontend_dir / "package.json").write_text(json.dumps({"name": "frontend", "devDependencies": {"jest": "^29.0.0"}}))
(frontend_dir / "src").mkdir()
(frontend_dir / "src" / "app.js").write_text("")
@@ -458,10 +474,7 @@ class TestE2EFirstRunCheck:
def test_has_existing_config_js(self, tmp_path):
"""Should find existing config in package.json."""
(tmp_path / "package.json").write_text(json.dumps({
"name": "test",
"codeflash": {"moduleRoot": "src"}
}))
(tmp_path / "package.json").write_text(json.dumps({"name": "test", "codeflash": {"moduleRoot": "src"}}))
has_config, config_type = has_existing_config(tmp_path)
assert has_config is True
@@ -610,17 +623,9 @@ class TestE2EFirstRunExperience:
monkeypatch.chdir(python_flat_layout)
monkeypatch.setenv("CODEFLASH_API_KEY", "cf-test-key-12345")
existing_args = Namespace(
file="myapp/core.py",
function="process",
custom_flag=True,
)
existing_args = Namespace(file="myapp/core.py", function="process", custom_flag=True)
result = handle_first_run(
args=existing_args,
skip_confirm=True,
skip_api_key=True,
)
result = handle_first_run(args=existing_args, skip_confirm=True, skip_api_key=True)
assert result is not None
assert result.custom_flag is True # Preserved
@@ -681,10 +686,9 @@ class TestE2EEdgeCases:
def test_project_without_formatter(self, tmp_path):
"""Should handle project without detectable formatter."""
(tmp_path / "package.json").write_text(json.dumps({
"name": "no-formatter",
"devDependencies": {"jest": "^29.0.0"}
}))
(tmp_path / "package.json").write_text(
json.dumps({"name": "no-formatter", "devDependencies": {"jest": "^29.0.0"}})
)
detected = detect_project(tmp_path)
@@ -868,9 +872,11 @@ class TestE2ECLIFlags:
printed_messages.append(str(msg))
from codeflash.cli_cmds import console
monkeypatch.setattr(console.console, "print", mock_print)
from codeflash.cli_cmds.cli import _handle_show_config
_handle_show_config()
# Verify config path is displayed
@@ -889,9 +895,11 @@ class TestE2ECLIFlags:
printed_messages.append(str(msg))
from codeflash.cli_cmds import console
monkeypatch.setattr(console.console, "print", mock_print)
from codeflash.cli_cmds.cli import _handle_show_config
_handle_show_config()
# Verify no config path line is displayed


@@ -27,19 +27,14 @@ class TestIsFirstRun:
def test_returns_false_when_pyproject_config_exists(self, tmp_path):
"""Should return False when codeflash config exists in pyproject.toml."""
(tmp_path / "pyproject.toml").write_text(
'[tool.codeflash]\nmodule-root = "src"'
)
(tmp_path / "pyproject.toml").write_text('[tool.codeflash]\nmodule-root = "src"')
result = is_first_run(tmp_path)
assert result is False
def test_returns_false_when_package_json_config_exists(self, tmp_path):
"""Should return False when codeflash config exists in package.json."""
(tmp_path / "package.json").write_text(json.dumps({
"name": "test",
"codeflash": {"moduleRoot": "src"}
}))
(tmp_path / "package.json").write_text(json.dumps({"name": "test", "codeflash": {"moduleRoot": "src"}}))
result = is_first_run(tmp_path)
assert result is False
@@ -109,11 +104,7 @@ class TestHandleFirstRun:
existing_args = Namespace(custom_flag=True, module_root=None)
result = handle_first_run(
args=existing_args,
skip_confirm=True,
skip_api_key=True,
)
result = handle_first_run(args=existing_args, skip_confirm=True, skip_api_key=True)
assert result is not None
assert result.custom_flag is True # Preserved
@@ -229,9 +220,7 @@ class TestFirstRunIntegration:
def test_full_python_first_run(self, tmp_path, monkeypatch):
"""Should complete full first-run for Python project."""
# Create Python project
(tmp_path / "pyproject.toml").write_text(
'[project]\nname = "myapp"\n\n[tool.ruff]\nline-length = 120'
)
(tmp_path / "pyproject.toml").write_text('[project]\nname = "myapp"\n\n[tool.ruff]\nline-length = 120')
pkg_dir = tmp_path / "myapp"
pkg_dir.mkdir()
(pkg_dir / "__init__.py").write_text("")
@@ -257,10 +246,9 @@ class TestFirstRunIntegration:
def test_full_javascript_first_run(self, tmp_path, monkeypatch):
"""Should complete full first-run for JavaScript project."""
# Create JS project
(tmp_path / "package.json").write_text(json.dumps({
"name": "myapp",
"devDependencies": {"jest": "^29.0.0"}
}, indent=2))
(tmp_path / "package.json").write_text(
json.dumps({"name": "myapp", "devDependencies": {"jest": "^29.0.0"}}, indent=2)
)
(tmp_path / "src").mkdir()
(tmp_path / "tests").mkdir()
@@ -277,9 +265,7 @@ class TestFirstRunIntegration:
def test_subsequent_run_uses_saved_config(self, tmp_path, monkeypatch):
"""After first run, subsequent runs should not trigger first-run."""
# Create project with existing config
(tmp_path / "pyproject.toml").write_text(
'[tool.codeflash]\nmodule-root = "src"'
)
(tmp_path / "pyproject.toml").write_text('[tool.codeflash]\nmodule-root = "src"')
monkeypatch.chdir(tmp_path)