fix: resolve merge conflict and standardize Java to use FunctionToOptimize

- Resolve merge conflict in code_replacer.py with Java-specific handling
- Update all Java modules to use FunctionToOptimize instead of FunctionInfo
- Add Language.JAVA to language_enum.py
- Update attribute names: name→function_name, start_line→starting_line, etc.
- Update all Java tests to use correct attribute names

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
misrasaurabh1 2026-02-02 19:08:09 -08:00
commit 520a1ff08e
107 changed files with 22062 additions and 1476 deletions

View file

@ -19,16 +19,30 @@ jobs:
(github.event_name == 'issues' && (contains(github.event.issue.body, '@claude') || contains(github.event.issue.title, '@claude')))
runs-on: ubuntu-latest
permissions:
contents: read
pull-requests: read
contents: write
pull-requests: write
issues: read
id-token: write
actions: read # Required for Claude to read CI results on PRs
steps:
- name: Get PR head ref
id: pr-ref
env:
GH_TOKEN: ${{ github.token }}
run: |
# For issue_comment events, we need to fetch the PR info
if [ "${{ github.event_name }}" = "issue_comment" ]; then
PR_REF=$(gh api repos/${{ github.repository }}/pulls/${{ github.event.issue.number }} --jq '.head.ref')
echo "ref=$PR_REF" >> $GITHUB_OUTPUT
else
echo "ref=${{ github.event.pull_request.head.ref || github.head_ref }}" >> $GITHUB_OUTPUT
fi
- name: Checkout repository
uses: actions/checkout@v4
with:
fetch-depth: 1
fetch-depth: 0
ref: ${{ steps.pr-ref.outputs.ref }}
- name: Run Claude Code
id: claude

50
.github/workflows/js-tests.yml vendored Normal file
View file

@ -0,0 +1,50 @@
name: JavaScript/TypeScript Integration Tests
on:
push:
branches:
- main
pull_request:
workflow_dispatch:
concurrency:
group: ${{ github.workflow }}-${{ github.ref_name }}
cancel-in-progress: true
jobs:
js-integration-tests:
name: JS/TS Integration Tests
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@v4
with:
fetch-depth: 0
token: ${{ secrets.GITHUB_TOKEN }}
- name: Setup Node.js
uses: actions/setup-node@v4
with:
node-version: '20'
- name: Install uv
uses: astral-sh/setup-uv@v6
- name: Install Python dependencies
run: |
uv venv --seed
uv sync
- name: Install npm dependencies for test projects
run: |
npm install --prefix code_to_optimize/js/code_to_optimize_js
npm install --prefix code_to_optimize/js/code_to_optimize_ts
npm install --prefix code_to_optimize/js/code_to_optimize_vitest
- name: Run JavaScript integration tests
run: |
uv run pytest tests/languages/javascript/ -v
uv run pytest tests/test_languages/test_vitest_e2e.py -v
uv run pytest tests/test_languages/test_javascript_e2e.py -v
uv run pytest tests/test_languages/test_javascript_support.py -v
uv run pytest tests/code_utils/test_config_js.py -v

View file

@ -14,15 +14,13 @@
}
},
"../../../packages/codeflash": {
"version": "0.1.0",
"version": "0.3.1",
"dev": true,
"hasInstallScript": true,
"license": "MIT",
"dependencies": {
"@msgpack/msgpack": "^3.0.0",
"better-sqlite3": "^12.0.0",
"jest-junit": "^16.0.0",
"jest-runner": "^29.7.0"
"better-sqlite3": "^12.0.0"
},
"bin": {
"codeflash": "bin/codeflash.js",
@ -33,7 +31,8 @@
},
"peerDependencies": {
"jest": ">=27.0.0",
"jest-runner": ">=27.0.0"
"jest-runner": ">=27.0.0",
"vitest": ">=1.0.0"
},
"peerDependenciesMeta": {
"jest": {
@ -41,6 +40,9 @@
},
"jest-runner": {
"optional": true
},
"vitest": {
"optional": true
}
}
},

View file

@ -14,15 +14,13 @@
}
},
"../../../packages/codeflash": {
"version": "0.2.0",
"version": "0.3.1",
"dev": true,
"hasInstallScript": true,
"license": "MIT",
"dependencies": {
"@msgpack/msgpack": "^3.0.0",
"better-sqlite3": "^12.0.0",
"jest-junit": "^16.0.0",
"jest-runner": "^29.7.0"
"better-sqlite3": "^12.0.0"
},
"bin": {
"codeflash": "bin/codeflash.js",
@ -33,7 +31,8 @@
},
"peerDependencies": {
"jest": ">=27.0.0",
"jest-runner": ">=27.0.0"
"jest-runner": ">=27.0.0",
"vitest": ">=1.0.0"
},
"peerDependenciesMeta": {
"jest": {
@ -41,6 +40,9 @@
},
"jest-runner": {
"optional": true
},
"vitest": {
"optional": true
}
}
},

File diff suppressed because it is too large Load diff

View file

@ -20,7 +20,7 @@
}
},
"../../../packages/codeflash": {
"version": "0.1.0",
"version": "0.3.1",
"dev": true,
"hasInstallScript": true,
"license": "MIT",
@ -36,11 +36,19 @@
"node": ">=18.0.0"
},
"peerDependencies": {
"jest": ">=27.0.0"
"jest": ">=27.0.0",
"jest-runner": ">=27.0.0",
"vitest": ">=1.0.0"
},
"peerDependenciesMeta": {
"jest": {
"optional": true
},
"jest-runner": {
"optional": true
},
"vitest": {
"optional": true
}
}
},

View file

@ -0,0 +1 @@
language: typescript

View file

@ -0,0 +1,52 @@
/**
* Fibonacci implementations - intentionally inefficient for optimization testing.
*/
/**
* Calculate the nth Fibonacci number using naive recursion.
* This is intentionally slow to demonstrate optimization potential.
* @param n - The index of the Fibonacci number to calculate
* @returns The nth Fibonacci number
*/
export function fibonacci(n: number): number {
if (n <= 1) {
return n;
}
return fibonacci(n - 1) + fibonacci(n - 2);
}
/**
* Check if a number is a Fibonacci number.
* @param num - The number to check
* @returns True if num is a Fibonacci number
*/
export function isFibonacci(num: number): boolean {
// A number is Fibonacci if one of (5*n*n + 4) or (5*n*n - 4) is a perfect square
const check1 = 5 * num * num + 4;
const check2 = 5 * num * num - 4;
return isPerfectSquare(check1) || isPerfectSquare(check2);
}
/**
* Check if a number is a perfect square.
* @param n - The number to check
* @returns True if n is a perfect square
*/
export function isPerfectSquare(n: number): boolean {
const sqrt = Math.sqrt(n);
return sqrt === Math.floor(sqrt);
}
/**
* Generate an array of Fibonacci numbers up to n.
* @param n - The number of Fibonacci numbers to generate
* @returns Array of Fibonacci numbers
*/
export function fibonacciSequence(n: number): number[] {
const result: number[] = [];
for (let i = 0; i < n; i++) {
result.push(fibonacci(i));
}
return result;
}

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,23 @@
{
"name": "codeflash-vitest-test",
"version": "1.0.0",
"description": "Sample TypeScript project with Vitest for codeflash optimization testing",
"type": "module",
"main": "dist/index.js",
"scripts": {
"test": "vitest run",
"test:watch": "vitest",
"test:coverage": "vitest run --coverage",
"build": "tsc"
},
"codeflash": {
"moduleRoot": ".",
"testsRoot": "tests"
},
"devDependencies": {
"@types/node": "^20.0.0",
"codeflash": "file:../../../packages/codeflash",
"typescript": "^5.0.0",
"vitest": "^2.0.0"
}
}

View file

@ -0,0 +1,62 @@
/**
* String utilities - intentionally inefficient for optimization testing.
*/
/**
* Reverse a string character by character.
* @param str - The string to reverse
* @returns The reversed string
*/
export function reverseString(str: string): string {
let result = '';
for (let i = str.length - 1; i >= 0; i--) {
result += str[i];
}
return result;
}
/**
* Check if a string is a palindrome.
* @param str - The string to check
* @returns True if the string is a palindrome
*/
export function isPalindrome(str: string): boolean {
const cleaned = str.toLowerCase().replace(/[^a-z0-9]/g, '');
return cleaned === reverseString(cleaned);
}
/**
* Count vowels in a string.
* @param str - The string to analyze
* @returns The number of vowels
*/
export function countVowels(str: string): number {
const vowels = 'aeiouAEIOU';
let count = 0;
for (const char of str) {
if (vowels.includes(char)) {
count++;
}
}
return count;
}
/**
* Find all unique words in a string.
* @param str - The string to analyze
* @returns Array of unique words
*/
export function uniqueWords(str: string): string[] {
const words = str.toLowerCase().split(/\s+/).filter(w => w.length > 0);
const seen = new Set<string>();
const result: string[] = [];
for (const word of words) {
if (!seen.has(word)) {
seen.add(word);
result.push(word);
}
}
return result;
}

View file

@ -0,0 +1,98 @@
import { describe, test, expect } from 'vitest';
import { fibonacci, isFibonacci, isPerfectSquare, fibonacciSequence } from '../fibonacci';
describe('fibonacci', () => {
test('returns 0 for n=0', () => {
expect(fibonacci(0)).toBe(0);
});
test('returns 1 for n=1', () => {
expect(fibonacci(1)).toBe(1);
});
test('returns 1 for n=2', () => {
expect(fibonacci(2)).toBe(1);
});
test('returns 5 for n=5', () => {
expect(fibonacci(5)).toBe(5);
});
test('returns 55 for n=10', () => {
expect(fibonacci(10)).toBe(55);
});
test('returns 233 for n=13', () => {
expect(fibonacci(13)).toBe(233);
});
});
describe('isFibonacci', () => {
test('returns true for 0', () => {
expect(isFibonacci(0)).toBe(true);
});
test('returns true for 1', () => {
expect(isFibonacci(1)).toBe(true);
});
test('returns true for 8', () => {
expect(isFibonacci(8)).toBe(true);
});
test('returns true for 13', () => {
expect(isFibonacci(13)).toBe(true);
});
test('returns false for 4', () => {
expect(isFibonacci(4)).toBe(false);
});
test('returns false for 6', () => {
expect(isFibonacci(6)).toBe(false);
});
});
describe('isPerfectSquare', () => {
test('returns true for 0', () => {
expect(isPerfectSquare(0)).toBe(true);
});
test('returns true for 1', () => {
expect(isPerfectSquare(1)).toBe(true);
});
test('returns true for 4', () => {
expect(isPerfectSquare(4)).toBe(true);
});
test('returns true for 16', () => {
expect(isPerfectSquare(16)).toBe(true);
});
test('returns false for 2', () => {
expect(isPerfectSquare(2)).toBe(false);
});
test('returns false for 3', () => {
expect(isPerfectSquare(3)).toBe(false);
});
});
describe('fibonacciSequence', () => {
test('returns empty array for n=0', () => {
expect(fibonacciSequence(0)).toEqual([]);
});
test('returns [0] for n=1', () => {
expect(fibonacciSequence(1)).toEqual([0]);
});
test('returns first 5 Fibonacci numbers', () => {
expect(fibonacciSequence(5)).toEqual([0, 1, 1, 2, 3]);
});
test('returns first 10 Fibonacci numbers', () => {
expect(fibonacciSequence(10)).toEqual([0, 1, 1, 2, 3, 5, 8, 13, 21, 34]);
});
});

View file

@ -0,0 +1,94 @@
import { describe, test, expect } from 'vitest';
import { reverseString, isPalindrome, countVowels, uniqueWords } from '../string_utils';
describe('reverseString', () => {
test('reverses a simple string', () => {
expect(reverseString('hello')).toBe('olleh');
});
test('returns empty string for empty input', () => {
expect(reverseString('')).toBe('');
});
test('returns same character for single character', () => {
expect(reverseString('a')).toBe('a');
});
test('handles strings with spaces', () => {
expect(reverseString('hello world')).toBe('dlrow olleh');
});
test('handles palindrome', () => {
expect(reverseString('racecar')).toBe('racecar');
});
});
describe('isPalindrome', () => {
test('returns true for palindrome', () => {
expect(isPalindrome('racecar')).toBe(true);
});
test('returns true for palindrome with spaces', () => {
expect(isPalindrome('A man a plan a canal Panama')).toBe(true);
});
test('returns false for non-palindrome', () => {
expect(isPalindrome('hello')).toBe(false);
});
test('returns true for empty string', () => {
expect(isPalindrome('')).toBe(true);
});
test('returns true for single character', () => {
expect(isPalindrome('a')).toBe(true);
});
test('handles mixed case', () => {
expect(isPalindrome('RaceCar')).toBe(true);
});
});
describe('countVowels', () => {
test('counts vowels in simple string', () => {
expect(countVowels('hello')).toBe(2);
});
test('returns 0 for string with no vowels', () => {
expect(countVowels('bcdfg')).toBe(0);
});
test('returns 0 for empty string', () => {
expect(countVowels('')).toBe(0);
});
test('counts uppercase vowels', () => {
expect(countVowels('HELLO')).toBe(2);
});
test('counts all vowels', () => {
expect(countVowels('aeiouAEIOU')).toBe(10);
});
});
describe('uniqueWords', () => {
test('finds unique words in simple string', () => {
expect(uniqueWords('hello world')).toEqual(['hello', 'world']);
});
test('removes duplicates', () => {
expect(uniqueWords('hello hello world')).toEqual(['hello', 'world']);
});
test('returns empty array for empty string', () => {
expect(uniqueWords('')).toEqual([]);
});
test('handles multiple spaces', () => {
expect(uniqueWords('hello world')).toEqual(['hello', 'world']);
});
test('normalizes case', () => {
expect(uniqueWords('Hello hello HELLO')).toEqual(['hello']);
});
});

View file

@ -0,0 +1,15 @@
{
"compilerOptions": {
"target": "ES2020",
"module": "ESNext",
"moduleResolution": "node",
"esModuleInterop": true,
"strict": true,
"skipLibCheck": true,
"outDir": "./dist",
"declaration": true,
"types": ["vitest/globals", "node"]
},
"include": ["./*.ts", "./tests/**/*.ts"],
"exclude": ["node_modules", "dist"]
}

View file

@ -0,0 +1,13 @@
import { defineConfig } from 'vitest/config';
export default defineConfig({
test: {
globals: true,
environment: 'node',
include: ['tests/**/*.test.ts'],
reporters: ['default', 'junit'],
outputFile: {
junit: '.codeflash/vitest-results.xml',
},
},
});

View file

@ -12,6 +12,7 @@ from codeflash.cli_cmds.extension import install_vscode_extension
from codeflash.code_utils import env_utils
from codeflash.code_utils.code_utils import exit_with_message
from codeflash.code_utils.config_parser import parse_config_file
from codeflash.languages.test_framework import set_current_test_framework
from codeflash.lsp.helpers import is_LSP_enabled
from codeflash.version import __version__ as version
@ -121,6 +122,15 @@ def parse_args() -> Namespace:
"--effort", type=str, help="Effort level for optimization", choices=["low", "medium", "high"], default="medium"
)
# Config management flags
parser.add_argument(
"--show-config", action="store_true", help="Show current or auto-detected configuration and exit."
)
parser.add_argument(
"--reset-config", action="store_true", help="Remove codeflash configuration from project config file."
)
parser.add_argument("-y", "--yes", action="store_true", help="Skip confirmation prompts (useful for CI/scripts).")
args, unknown_args = parser.parse_known_args()
sys.argv[:] = [sys.argv[0], *unknown_args]
return process_and_validate_cmd_args(args)
@ -147,6 +157,16 @@ def process_and_validate_cmd_args(args: Namespace) -> Namespace:
logger.info(f"Codeflash version {version}")
sys.exit()
# Handle --show-config
if getattr(args, "show_config", False):
_handle_show_config()
sys.exit()
# Handle --reset-config
if getattr(args, "reset_config", False):
_handle_reset_config(confirm=not getattr(args, "yes", False))
sys.exit()
if args.command == "vscode-install":
install_vscode_extension()
sys.exit()
@ -210,6 +230,11 @@ def process_pyproject_config(args: Namespace) -> Namespace:
# For JS/TS projects, tests_root is optional (Jest auto-discovers tests)
# Default to module_root if not specified
is_js_ts_project = pyproject_config.get("language") in ("javascript", "typescript")
# Set the test framework singleton for JS/TS projects
if is_js_ts_project and pyproject_config.get("test_framework"):
set_current_test_framework(pyproject_config["test_framework"])
if args.tests_root is None:
if is_js_ts_project:
# Try common JS test directories at project root first
@ -334,3 +359,92 @@ def handle_optimize_all_arg_parsing(args: Namespace) -> Namespace:
else:
args.all = Path(args.all).resolve()
return args
def _handle_show_config() -> None:
"""Show current or auto-detected Codeflash configuration."""
from rich.table import Table
from codeflash.cli_cmds.console import console
from codeflash.setup.detector import detect_project, has_existing_config
project_root = Path.cwd()
detected = detect_project(project_root)
# Check if config exists or is auto-detected
config_exists, _ = has_existing_config(project_root)
status = "Saved config" if config_exists else "Auto-detected (not saved)"
console.print()
console.print(f"[bold]Codeflash Configuration[/bold] ({status})")
console.print()
table = Table(show_header=True, header_style="bold cyan")
table.add_column("Setting", style="dim")
table.add_column("Value")
table.add_row("Language", detected.language)
table.add_row("Project root", str(detected.project_root))
table.add_row("Module root", str(detected.module_root))
table.add_row("Tests root", str(detected.tests_root) if detected.tests_root else "(not detected)")
table.add_row("Test runner", detected.test_runner or "(not detected)")
table.add_row("Formatter", ", ".join(detected.formatter_cmds) if detected.formatter_cmds else "(not detected)")
table.add_row(
"Ignore paths", ", ".join(str(p) for p in detected.ignore_paths) if detected.ignore_paths else "(none)"
)
table.add_row("Confidence", f"{detected.confidence:.0%}")
console.print(table)
console.print()
if not config_exists:
console.print("[dim]Run [bold]codeflash --file <file>[/bold] to auto-save this config.[/dim]")
def _handle_reset_config(confirm: bool = True) -> None:
"""Remove Codeflash configuration from project config file.
Args:
confirm: If True, prompt for confirmation before removing.
"""
from codeflash.cli_cmds.console import console
from codeflash.setup.config_writer import remove_config
from codeflash.setup.detector import detect_project, has_existing_config
project_root = Path.cwd()
config_exists, _ = has_existing_config(project_root)
if not config_exists:
console.print("[yellow]No Codeflash configuration found to remove.[/yellow]")
return
detected = detect_project(project_root)
if confirm:
console.print("[bold]This will remove Codeflash configuration from your project.[/bold]")
console.print()
config_file = "pyproject.toml" if detected.language == "python" else "package.json"
console.print(f" Config file: {project_root / config_file}")
console.print()
try:
response = console.input("[bold]Are you sure you want to remove the config? [y/N][/bold] ")
except (EOFError, KeyboardInterrupt):
console.print("\n[yellow]Cancelled.[/yellow]")
return
if response.lower() not in ("y", "yes"):
console.print("[yellow]Cancelled.[/yellow]")
return
success, message = remove_config(project_root, detected.language)
# Escape brackets in message to prevent Rich markup interpretation
escaped_message = message.replace("[", "\\[")
if success:
console.print(f"[green]✓[/green] {escaped_message}")
else:
console.print(f"[red]✗[/red] {escaped_message}")

View file

@ -614,7 +614,33 @@ def check_for_toml_or_setup_file() -> str | None:
curdir = Path.cwd()
pyproject_toml_path = curdir / "pyproject.toml"
setup_py_path = curdir / "setup.py"
package_json_path = curdir / "package.json"
project_name = None
# Check if this might be a JavaScript/TypeScript project that wasn't detected
if package_json_path.exists() and not pyproject_toml_path.exists() and not setup_py_path.exists():
js_redirect_panel = Panel(
Text(
f"📦 I found a package.json in {curdir}.\n\n"
"This looks like a JavaScript/TypeScript project!\n"
"Redirecting to JavaScript setup...",
style="cyan",
),
title="🟨 JavaScript Project Detected",
border_style="bright_yellow",
)
console.print(js_redirect_panel)
console.print()
ph("cli-js-project-redirect")
# Redirect to JS init
from codeflash.cli_cmds.init_javascript import ProjectLanguage, detect_project_language, init_js_project
project_language = detect_project_language()
if project_language in (ProjectLanguage.JAVASCRIPT, ProjectLanguage.TYPESCRIPT):
init_js_project(project_language)
sys.exit(0) # init_js_project handles its own exit, but ensure we don't continue
if pyproject_toml_path.exists():
try:
pyproject_toml_content = pyproject_toml_path.read_text(encoding="utf8")
@ -624,28 +650,44 @@ def check_for_toml_or_setup_file() -> str | None:
except Exception:
click.echo("✅ I found a pyproject.toml for your project.")
ph("cli-pyproject-toml-found")
elif setup_py_path.exists():
setup_py_content = setup_py_path.read_text(encoding="utf8")
project_name_match = re.search(r"setup\s*\([^)]*?name\s*=\s*['\"](.*?)['\"]", setup_py_content, re.DOTALL)
if project_name_match:
project_name = project_name_match.group(1)
click.echo(f"✅ Found setup.py for your project {project_name}")
ph("cli-setup-py-found-name")
else:
click.echo("✅ Found setup.py.")
ph("cli-setup-py-found")
else:
if setup_py_path.exists():
setup_py_content = setup_py_path.read_text(encoding="utf8")
project_name_match = re.search(r"setup\s*\([^)]*?name\s*=\s*['\"](.*?)['\"]", setup_py_content, re.DOTALL)
if project_name_match:
project_name = project_name_match.group(1)
click.echo(f"✅ Found setup.py for your project {project_name}")
ph("cli-setup-py-found-name")
else:
click.echo("✅ Found setup.py.")
ph("cli-setup-py-found")
toml_info_panel = Panel(
Text(
f"💡 No pyproject.toml found in {curdir}.\n\n"
"This file is essential for Codeflash to store its configuration.\n"
"Please ensure you are running `codeflash init` from your project's root directory.",
style="yellow",
),
title="📋 pyproject.toml Required",
border_style="bright_yellow",
)
console.print(toml_info_panel)
# No Python config files found - show appropriate message
# Check again if this might be a JS project
if package_json_path.exists():
js_hint_panel = Panel(
Text(
f"📦 I found a package.json but no pyproject.toml in {curdir}.\n\n"
"If this is a JavaScript/TypeScript project, please run:\n"
" codeflash init\n\n"
"from the project root directory.",
style="yellow",
),
title="🤔 Mixed Project?",
border_style="bright_yellow",
)
console.print(js_hint_panel)
else:
toml_info_panel = Panel(
Text(
f"💡 No pyproject.toml found in {curdir}.\n\n"
"This file is essential for Codeflash to store its configuration.\n"
"Please ensure you are running `codeflash init` from your project's root directory.",
style="yellow",
),
title="📋 pyproject.toml Required",
border_style="bright_yellow",
)
console.print(toml_info_panel)
console.print()
ph("cli-no-pyproject-toml-or-setup-py")

View file

@ -129,6 +129,9 @@ def code_print(
spinners = cycle(SPINNER_TYPES)
# Track whether a progress bar is already active to prevent nested Live displays
_progress_bar_active = False
@contextmanager
def progress_bar(
@ -138,28 +141,38 @@ def progress_bar(
If revert_to_print is True, falls back to printing a single logger.info message
instead of showing a progress bar.
If a progress bar is already active, yields a dummy task ID to avoid Rich's
LiveError from nested Live displays.
"""
global _progress_bar_active
if is_LSP_enabled():
lsp_log(LspTextMessage(text=message, takes_time=True))
yield
return
if revert_to_print:
logger.info(message)
if revert_to_print or _progress_bar_active:
if revert_to_print:
logger.info(message)
# Create a fake task ID since we still need to yield something
yield DummyTask().id
else:
progress = Progress(
SpinnerColumn(next(spinners)),
*Progress.get_default_columns(),
TimeElapsedColumn(),
console=console,
transient=transient,
)
task = progress.add_task(message, total=None)
with progress:
yield task
_progress_bar_active = True
try:
progress = Progress(
SpinnerColumn(next(spinners)),
*Progress.get_default_columns(),
TimeElapsedColumn(),
console=console,
transient=transient,
)
task = progress.add_task(message, total=None)
with progress:
yield task
finally:
_progress_bar_active = False
@contextmanager

View file

@ -16,6 +16,7 @@ import inquirer
from git import InvalidGitRepositoryError, Repo
from rich.console import Group
from rich.panel import Panel
from rich.prompt import Confirm
from rich.table import Table
from rich.text import Text
@ -98,34 +99,97 @@ def detect_project_language(project_root: Path | None = None) -> ProjectLanguage
if has_pom_xml or has_build_gradle or has_java_src:
return ProjectLanguage.JAVA
# TypeScript project
# TypeScript project (tsconfig.json is definitive)
if has_tsconfig:
return ProjectLanguage.TYPESCRIPT
# Pure JS project (has package.json but no Python files)
if has_package_json and not has_pyproject and not has_setup_py:
return ProjectLanguage.JAVASCRIPT
# JavaScript project - package.json without Python-specific files takes priority
# Note: If both package.json and pyproject.toml exist, check for typical JS project indicators
if has_package_json:
# If no Python config files, it's definitely JavaScript
if not has_pyproject and not has_setup_py:
return ProjectLanguage.JAVASCRIPT
# If package.json exists with Python files, check for JS-specific indicators
# Common React/Node patterns indicate a JS project
js_indicators = [
(root / "node_modules").exists(),
(root / ".npmrc").exists(),
(root / "yarn.lock").exists(),
(root / "package-lock.json").exists(),
(root / "pnpm-lock.yaml").exists(),
(root / "bun.lockb").exists(),
(root / "bun.lock").exists(),
]
if any(js_indicators):
return ProjectLanguage.JAVASCRIPT
# Python project (default)
return ProjectLanguage.PYTHON
def determine_js_package_manager(project_root: Path) -> JsPackageManager:
"""Determine which JavaScript package manager is being used based on lock files."""
if (project_root / "bun.lockb").exists() or (project_root / "bun.lock").exists():
return JsPackageManager.BUN
if (project_root / "pnpm-lock.yaml").exists():
return JsPackageManager.PNPM
if (project_root / "yarn.lock").exists():
return JsPackageManager.YARN
if (project_root / "package-lock.json").exists():
return JsPackageManager.NPM
# Default to npm if package.json exists but no lock file
"""Determine which JavaScript package manager is being used based on lock files.
Searches the project_root directory and parent directories (for monorepo support)
to find lock files that indicate which package manager is being used.
"""
# Search from project_root up to filesystem root for lock files
# This supports monorepo setups where lock file is at workspace root
current_dir = project_root.resolve()
while current_dir != current_dir.parent:
if (current_dir / "bun.lockb").exists() or (current_dir / "bun.lock").exists():
return JsPackageManager.BUN
if (current_dir / "pnpm-lock.yaml").exists():
return JsPackageManager.PNPM
if (current_dir / "yarn.lock").exists():
return JsPackageManager.YARN
if (current_dir / "package-lock.json").exists():
return JsPackageManager.NPM
current_dir = current_dir.parent
# Default to npm if package.json exists but no lock file found anywhere
if (project_root / "package.json").exists():
return JsPackageManager.NPM
return JsPackageManager.UNKNOWN
def get_package_install_command(project_root: Path, package: str, dev: bool = True) -> list[str]:
"""Get the correct install command for the project's package manager.
Args:
project_root: The project root directory.
package: The package name to install.
dev: Whether to install as a dev dependency (default: True).
Returns:
List of command arguments for subprocess execution.
"""
pkg_manager = determine_js_package_manager(project_root)
if pkg_manager == JsPackageManager.PNPM:
cmd = ["pnpm", "add", package]
if dev:
cmd.append("--save-dev")
return cmd
if pkg_manager == JsPackageManager.YARN:
cmd = ["yarn", "add", package]
if dev:
cmd.append("--dev")
return cmd
if pkg_manager == JsPackageManager.BUN:
cmd = ["bun", "add", package]
if dev:
cmd.append("--dev")
return cmd
# Default to npm
cmd = ["npm", "install", package]
if dev:
cmd.append("--save-dev")
return cmd
def init_js_project(language: ProjectLanguage) -> None:
"""Initialize Codeflash for a JavaScript/TypeScript project."""
from codeflash.cli_cmds.cmd_init import install_github_actions, install_github_app, prompt_api_key
@ -199,9 +263,7 @@ def init_js_project(language: ProjectLanguage) -> None:
def should_modify_package_json_config() -> tuple[bool, dict[str, Any] | None]:
"""Check if package.json has valid codeflash config for JS/TS projects."""
from rich.prompt import Confirm
package_json_path = Path.cwd() / "package.json"
package_json_path = Path("package.json")
if not package_json_path.exists():
click.echo("❌ No package.json found. Please run 'npm init' first.")
@ -221,6 +283,10 @@ def should_modify_package_json_config() -> tuple[bool, dict[str, Any] | None]:
if not Path(module_root).is_dir():
return True, None
tests_root = config.get("testsRoot", None)
if tests_root and not Path(tests_root).is_dir():
return True, None
# Config is valid - ask if user wants to reconfigure
return Confirm.ask(
"✅ A valid Codeflash config already exists in package.json. Do you want to re-configure it?",

View file

@ -1563,23 +1563,228 @@ def is_numerical_code(code_string: str, function_name: str | None = None) -> boo
def get_opt_review_metrics(
source_code: str, file_path: Path, qualified_name: str, project_root: Path, tests_root: Path, language: Language
) -> str:
if language != Language.PYTHON:
# TODO: {Claude} handle function refrences for other languages
return ""
"""Get function reference metrics for optimization review.
Uses the LanguageSupport abstraction to find references, supporting both Python and JavaScript/TypeScript.
Args:
source_code: Source code of the file containing the function.
file_path: Path to the file.
qualified_name: Qualified name of the function (e.g., "module.ClassName.method").
project_root: Root of the project.
tests_root: Root of the tests directory.
language: The programming language.
Returns:
Markdown-formatted string with code blocks showing calling functions.
"""
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.languages.registry import get_language_support
from codeflash.models.models import FunctionParent
start_time = time.perf_counter()
try:
# Get the language support
lang_support = get_language_support(language)
if lang_support is None:
return ""
# Parse qualified name to get function name and class name
qualified_name_split = qualified_name.rsplit(".", maxsplit=1)
if len(qualified_name_split) == 1:
target_function, target_class = qualified_name_split[0], None
function_name, class_name = qualified_name_split[0], None
else:
target_function, target_class = qualified_name_split[1], qualified_name_split[0]
matches = get_fn_references_jedi(
source_code, file_path, project_root, target_function, target_class
) # jedi is not perfect, it doesn't capture aliased references
calling_fns_details = find_occurances(qualified_name, str(file_path), matches, project_root, tests_root)
function_name, class_name = qualified_name_split[1], qualified_name_split[0]
# Create a FunctionToOptimize for the function
# We don't have full line info here, so we'll use defaults
parents: list[FunctionParent] = []
if class_name:
parents = [FunctionParent(name=class_name, type="ClassDef")]
func_info = FunctionToOptimize(
function_name=function_name,
file_path=file_path,
parents=parents,
starting_line=1,
ending_line=1,
language=str(language),
)
# Find references using language support
references = lang_support.find_references(func_info, project_root, tests_root, max_files=500)
if not references:
return ""
# Format references as markdown code blocks
calling_fns_details = _format_references_as_markdown(references, file_path, project_root, language)
except Exception as e:
logger.debug(f"Error getting function references: {e}")
calling_fns_details = ""
logger.debug(f"Investigate {e}")
end_time = time.perf_counter()
logger.debug(f"Got function references in {end_time - start_time:.2f} seconds")
return calling_fns_details
def _format_references_as_markdown(references: list, file_path: Path, project_root: Path, language: Language) -> str:
"""Format references as markdown code blocks with calling function code.
Args:
references: List of ReferenceInfo objects.
file_path: Path to the source file (to exclude).
project_root: Root of the project.
language: The programming language.
Returns:
Markdown-formatted string.
"""
# Group references by file
refs_by_file: dict[Path, list] = {}
for ref in references:
# Exclude the source file's definition/import references
if ref.file_path == file_path and ref.reference_type in ("import", "reexport"):
continue
if ref.file_path not in refs_by_file:
refs_by_file[ref.file_path] = []
refs_by_file[ref.file_path].append(ref)
fn_call_context = ""
context_len = 0
for ref_file, file_refs in refs_by_file.items():
if context_len > MAX_CONTEXT_LEN_REVIEW:
break
try:
path_relative = ref_file.relative_to(project_root)
except ValueError:
continue
# Get syntax highlighting language
ext = ref_file.suffix.lstrip(".")
if language == Language.PYTHON:
lang_hint = "python"
elif ext in ("ts", "tsx"):
lang_hint = "typescript"
else:
lang_hint = "javascript"
# Read the file to extract calling function context
try:
file_content = ref_file.read_text(encoding="utf-8")
lines = file_content.splitlines()
except Exception:
continue
# Get unique caller functions from this file
callers_seen: set[str] = set()
caller_contexts: list[str] = []
for ref in file_refs:
caller = ref.caller_function or "<module>"
if caller in callers_seen:
continue
callers_seen.add(caller)
# Extract context around the reference
if ref.caller_function:
# Try to extract the full calling function
func_code = _extract_calling_function(file_content, ref.caller_function, ref.line, language)
if func_code:
caller_contexts.append(func_code)
context_len += len(func_code)
else:
# Module-level call - show a few lines of context
start_line = max(0, ref.line - 3)
end_line = min(len(lines), ref.line + 2)
context_code = "\n".join(lines[start_line:end_line])
caller_contexts.append(context_code)
context_len += len(context_code)
if caller_contexts:
fn_call_context += f"```{lang_hint}:{path_relative.as_posix()}\n"
fn_call_context += "\n".join(caller_contexts)
fn_call_context += "\n```\n"
return fn_call_context
def _extract_calling_function(source_code: str, function_name: str, ref_line: int, language: Language) -> str | None:
"""Extract the source code of a calling function.
Args:
source_code: Full source code of the file.
function_name: Name of the function to extract.
ref_line: Line number where the reference is.
language: The programming language.
Returns:
Source code of the function, or None if not found.
"""
if language == Language.PYTHON:
return _extract_calling_function_python(source_code, function_name, ref_line)
return _extract_calling_function_js(source_code, function_name, ref_line)
def _extract_calling_function_python(source_code: str, function_name: str, ref_line: int) -> str | None:
"""Extract the source code of a calling function in Python."""
try:
import ast
tree = ast.parse(source_code)
lines = source_code.splitlines()
for node in ast.walk(tree):
if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
if node.name == function_name:
# Check if the reference line is within this function
start_line = node.lineno
end_line = node.end_lineno or start_line
if start_line <= ref_line <= end_line:
return "\n".join(lines[start_line - 1 : end_line])
return None
except Exception:
return None
def _extract_calling_function_js(source_code: str, function_name: str, ref_line: int) -> str | None:
"""Extract the source code of a calling function in JavaScript/TypeScript.
Args:
source_code: Full source code of the file.
function_name: Name of the function to extract.
ref_line: Line number where the reference is (helps identify the right function).
Returns:
Source code of the function, or None if not found.
"""
try:
from codeflash.languages.treesitter_utils import TreeSitterAnalyzer, TreeSitterLanguage
# Try TypeScript first, fall back to JavaScript
for lang in [TreeSitterLanguage.TYPESCRIPT, TreeSitterLanguage.TSX, TreeSitterLanguage.JAVASCRIPT]:
try:
analyzer = TreeSitterAnalyzer(lang)
functions = analyzer.find_functions(source_code, include_methods=True)
for func in functions:
if func.name == function_name:
# Check if the reference line is within this function
if func.start_line <= ref_line <= func.end_line:
return func.source_text
break
except Exception:
continue
return None
except Exception:
return None

View file

@ -497,7 +497,7 @@ def replace_function_definitions_for_language(
"""
from codeflash.languages import get_language_support
from codeflash.languages.base import FunctionInfo, Language, ParentInfo
from codeflash.languages.base import Language
original_source_code: str = module_abspath.read_text(encoding="utf8")
code_to_apply = get_optimized_code_for_module(module_abspath.relative_to(project_root_path), optimized_code)
@ -525,31 +525,21 @@ def replace_function_definitions_for_language(
and function_to_optimize.ending_line
and function_to_optimize.file_path == module_abspath
):
parents = tuple(ParentInfo(name=p.name, type=p.type) for p in function_to_optimize.parents)
func_info = FunctionInfo(
name=function_to_optimize.function_name,
file_path=module_abspath,
start_line=function_to_optimize.starting_line,
end_line=function_to_optimize.ending_line,
parents=parents,
is_async=function_to_optimize.is_async,
language=language,
)
# For Java, we need to pass the full optimized code so replace_function can
# extract and add any new class members (static fields, helper methods).
# For other languages, we extract just the target function.
if language == Language.JAVA:
new_code = lang_support.replace_function(original_source_code, func_info, code_to_apply)
new_code = lang_support.replace_function(original_source_code, function_to_optimize, code_to_apply)
else:
# Extract just the target function from the optimized code
optimized_func = _extract_function_from_code(
lang_support, code_to_apply, function_to_optimize.function_name, module_abspath
)
if optimized_func:
new_code = lang_support.replace_function(original_source_code, func_info, optimized_func)
new_code = lang_support.replace_function(original_source_code, function_to_optimize, optimized_func)
else:
# Fallback: use the entire optimized code (for simple single-function files)
new_code = lang_support.replace_function(original_source_code, func_info, code_to_apply)
new_code = lang_support.replace_function(original_source_code, function_to_optimize, code_to_apply)
else:
# For helper files or when we don't have precise line info:
# Find each function by name in both original and optimized code
@ -567,7 +557,7 @@ def replace_function_definitions_for_language(
# Find the function in current code
func = None
for f in current_functions:
if func_name in (f.qualified_name, f.name):
if func_name in (f.qualified_name, f.function_name):
func = f
break
@ -581,7 +571,9 @@ def replace_function_definitions_for_language(
modified = True
else:
# Extract just this function from the optimized code
optimized_func = _extract_function_from_code(lang_support, code_to_apply, func.name, module_abspath)
optimized_func = _extract_function_from_code(
lang_support, code_to_apply, func.function_name, module_abspath
)
if optimized_func:
new_code = lang_support.replace_function(new_code, func, optimized_func)
modified = True
@ -620,13 +612,13 @@ def _extract_function_from_code(
# file_path is needed for JS/TS to determine correct analyzer (TypeScript vs JavaScript)
functions = lang_support.discover_functions_from_source(source_code, file_path)
for func in functions:
if func.name == function_name:
if func.function_name == function_name:
# Extract the function's source using line numbers
# Use doc_start_line if available to include JSDoc/docstring
lines = source_code.splitlines(keepends=True)
effective_start = func.doc_start_line or func.start_line
if effective_start and func.end_line and effective_start <= len(lines):
func_lines = lines[effective_start - 1 : func.end_line]
effective_start = func.doc_start_line or func.starting_line
if effective_start and func.ending_line and effective_start <= len(lines):
func_lines = lines[effective_start - 1 : func.ending_line]
return "".join(func_lines)
except Exception as e:
logger.debug(f"Error extracting function {function_name}: {e}")
@ -753,6 +745,10 @@ def _add_global_declarations_for_language(
replacement.py handles them. Adding them here would shift line numbers and
break method matching for overloaded methods.
New declarations are inserted after any existing declarations they depend on.
For example, if optimized code has `const _has = FOO.bar.bind(FOO)`, and `FOO`
is already declared in the original source, `_has` will be inserted after `FOO`.
Args:
optimized_code: The optimized code that may contain new declarations.
original_source: The original source code.
@ -761,7 +757,7 @@ def _add_global_declarations_for_language(
target_function_names: List of function names being optimized (to exclude from Java helpers).
Returns:
Original source with new declarations added.
Original source with new declarations added in dependency order.
"""
from codeflash.languages.base import Language
@ -771,7 +767,6 @@ def _add_global_declarations_for_language(
if language == Language.JAVA:
return original_source
# Only process JavaScript/TypeScript for module-level declarations
if language not in (Language.JAVASCRIPT, Language.TYPESCRIPT):
return original_source
@ -780,84 +775,164 @@ def _add_global_declarations_for_language(
analyzer = get_analyzer_for_file(module_abspath)
# Find declarations in both original and optimized code
original_declarations = analyzer.find_module_level_declarations(original_source)
optimized_declarations = analyzer.find_module_level_declarations(optimized_code)
if not optimized_declarations:
return original_source
# Get names of existing declarations
existing_names = {decl.name for decl in original_declarations}
# Also exclude names that are already imported (to avoid duplicating imported types)
original_imports = analyzer.find_imports(original_source)
for imp in original_imports:
# Add default import name
if imp.default_import:
existing_names.add(imp.default_import)
# Add named imports (use alias if present, otherwise use original name)
for name, alias in imp.named_imports:
existing_names.add(alias if alias else name)
# Add namespace import
if imp.namespace_import:
existing_names.add(imp.namespace_import)
# Find new declarations (names that don't exist in original)
new_declarations = []
seen_sources = set() # Track to avoid duplicates from destructuring
for decl in optimized_declarations:
if decl.name not in existing_names and decl.source_code not in seen_sources:
new_declarations.append(decl)
seen_sources.add(decl.source_code)
existing_names = _get_existing_names(original_declarations, analyzer, original_source)
new_declarations = _filter_new_declarations(optimized_declarations, existing_names)
if not new_declarations:
return original_source
# Sort by line number to maintain order
new_declarations.sort(key=lambda d: d.start_line)
# Build a map of existing declaration names to their end lines (1-indexed)
existing_decl_end_lines = {decl.name: decl.end_line for decl in original_declarations}
# Find insertion point (after imports)
lines = original_source.splitlines(keepends=True)
insertion_line = _find_insertion_line_after_imports_js(lines, analyzer, original_source)
# Insert each new declaration after its dependencies
result = original_source
for decl in new_declarations:
result = _insert_declaration_after_dependencies(
result, decl, existing_decl_end_lines, analyzer, module_abspath
)
# Update the map with the newly inserted declaration for subsequent insertions
# Re-parse to get accurate line numbers after insertion
updated_declarations = analyzer.find_module_level_declarations(result)
existing_decl_end_lines = {d.name: d.end_line for d in updated_declarations}
# Build new declarations string
new_decl_code = "\n".join(decl.source_code for decl in new_declarations)
new_decl_code = new_decl_code + "\n\n"
# Insert declarations
before = lines[:insertion_line]
after = lines[insertion_line:]
result_lines = [*before, new_decl_code, *after]
return "".join(result_lines)
return result
except Exception as e:
logger.debug(f"Error adding global declarations: {e}")
return original_source
def _find_insertion_line_after_imports_js(lines: list[str], analyzer: TreeSitterAnalyzer, source: str) -> int:
"""Find the line index where new declarations should be inserted (after imports).
def _get_existing_names(original_declarations: list, analyzer: TreeSitterAnalyzer, original_source: str) -> set[str]:
"""Get all names that already exist in the original source (declarations + imports)."""
existing_names = {decl.name for decl in original_declarations}
original_imports = analyzer.find_imports(original_source)
for imp in original_imports:
if imp.default_import:
existing_names.add(imp.default_import)
for name, alias in imp.named_imports:
existing_names.add(alias if alias else name)
if imp.namespace_import:
existing_names.add(imp.namespace_import)
return existing_names
def _filter_new_declarations(optimized_declarations: list, existing_names: set[str]) -> list:
"""Filter declarations to only those that don't exist in the original source."""
new_declarations = []
seen_sources: set[str] = set()
# Sort by line number to maintain order from optimized code
sorted_declarations = sorted(optimized_declarations, key=lambda d: d.start_line)
for decl in sorted_declarations:
if decl.name not in existing_names and decl.source_code not in seen_sources:
new_declarations.append(decl)
seen_sources.add(decl.source_code)
return new_declarations
def _insert_declaration_after_dependencies(
source: str,
declaration,
existing_decl_end_lines: dict[str, int],
analyzer: TreeSitterAnalyzer,
module_abspath: Path,
) -> str:
"""Insert a declaration after the last existing declaration it depends on.
Args:
source: Current source code.
declaration: The declaration to insert.
existing_decl_end_lines: Map of existing declaration names to their end lines.
analyzer: TreeSitter analyzer.
module_abspath: Path to the module file.
Returns:
Source code with the declaration inserted at the correct position.
"""
# Find identifiers referenced in this declaration
referenced_names = analyzer.find_referenced_identifiers(declaration.source_code)
# Find the latest end line among all referenced declarations
insertion_line = _find_insertion_line_for_declaration(source, referenced_names, existing_decl_end_lines, analyzer)
lines = source.splitlines(keepends=True)
# Ensure proper spacing
decl_code = declaration.source_code
if not decl_code.endswith("\n"):
decl_code += "\n"
# Add blank line before if inserting after content
if insertion_line > 0 and lines[insertion_line - 1].strip():
decl_code = "\n" + decl_code
before = lines[:insertion_line]
after = lines[insertion_line:]
return "".join([*before, decl_code, *after])
def _find_insertion_line_for_declaration(
source: str, referenced_names: set[str], existing_decl_end_lines: dict[str, int], analyzer: TreeSitterAnalyzer
) -> int:
"""Find the line where a declaration should be inserted based on its dependencies.
Args:
source: Source code.
referenced_names: Names referenced by the declaration.
existing_decl_end_lines: Map of declaration names to their end lines (1-indexed).
analyzer: TreeSitter analyzer.
Returns:
Line index (0-based) where the declaration should be inserted.
"""
# Find the maximum end line among referenced declarations
max_dependency_line = 0
for name in referenced_names:
if name in existing_decl_end_lines:
max_dependency_line = max(max_dependency_line, existing_decl_end_lines[name])
if max_dependency_line > 0:
# Insert after the last dependency (end_line is 1-indexed, we need 0-indexed)
return max_dependency_line
# No dependencies found - insert after imports
lines = source.splitlines(keepends=True)
return _find_line_after_imports(lines, analyzer, source)
def _find_line_after_imports(lines: list[str], analyzer: TreeSitterAnalyzer, source: str) -> int:
"""Find the line index after all imports.
Args:
lines: Source lines.
analyzer: TreeSitter analyzer for the file.
analyzer: TreeSitter analyzer.
source: Full source code.
Returns:
Line index (0-based) for insertion.
Line index (0-based) for insertion after imports.
"""
try:
imports = analyzer.find_imports(source)
if imports:
# Find the last import's end line
return max(imp.end_line for imp in imports)
except Exception as exc:
logger.debug(f"Exception occurred in _find_insertion_line_after_imports_js: {exc}")
logger.debug(f"Exception in _find_line_after_imports: {exc}")
# Default: insert at beginning (after any shebang/directive comments)
# Default: insert at beginning (after shebang/directive comments)
for i, line in enumerate(lines):
stripped = line.strip()
if stripped and not stripped.startswith("//") and not stripped.startswith("#!"):

View file

@ -212,15 +212,18 @@ def parse_package_json_config(package_json_path: Path) -> tuple[dict[str, Any],
Most configuration is auto-detected from package.json and project structure.
Only minimal config is stored in the "codeflash" key:
- moduleRoot: Override auto-detected module root (optional)
- testsRoot: Override auto-detected tests root (optional)
- test-framework: Override auto-detected test framework - "jest", "vitest", or "mocha" (optional)
- benchmarksRoot: Where to store benchmark files (optional, defaults to __benchmarks__)
- ignorePaths: Paths to exclude from optimization (optional)
- disableTelemetry: Privacy preference (optional, defaults to false)
- formatterCmds: Override auto-detected formatter (optional)
Auto-detected values (not stored in config):
Auto-detected values (used when not explicitly configured):
- language: Detected from tsconfig.json presence
- moduleRoot: Detected from package.json exports/module/main or src/ convention
- testRunner: Detected from devDependencies (vitest/jest/mocha)
- test-framework: Detected from devDependencies (vitest/jest/mocha)
- formatter: Detected from devDependencies (prettier/eslint)
Args:
@ -251,10 +254,17 @@ def parse_package_json_config(package_json_path: Path) -> tuple[dict[str, Any],
detected_module_root = detect_module_root(project_root, package_data)
config["module_root"] = str((project_root / Path(detected_module_root)).resolve())
# Auto-detect test runner
config["test_runner"] = detect_test_runner(project_root, package_data)
if codeflash_config.get("testsRoot"):
config["tests_root"] = str(project_root / Path(codeflash_config["testsRoot"]).resolve())
# Check for explicit test framework override, otherwise auto-detect
# Uses "test-framework" to match Python's pyproject.toml convention
if codeflash_config.get("test-framework"):
config["test_framework"] = codeflash_config["test-framework"]
else:
config["test_framework"] = detect_test_runner(project_root, package_data)
# Keep pytest_cmd for backwards compatibility with existing code
config["pytest_cmd"] = config["test_runner"]
config["pytest_cmd"] = config["test_framework"]
# Auto-detect formatter (with optional override from config)
if "formatterCmds" in codeflash_config:

View file

@ -13,10 +13,13 @@ from codeflash.cli_cmds.console import logger
from codeflash.code_utils.code_utils import exit_with_message
from codeflash.code_utils.formatter import format_code
from codeflash.code_utils.shell_utils import read_api_key_from_shell_config, save_api_key_to_rc
from codeflash.languages.registry import get_language_support_by_common_formatters
from codeflash.lsp.helpers import is_LSP_enabled
def check_formatter_installed(formatter_cmds: list[str], exit_on_failure: bool = True) -> bool:
def check_formatter_installed(
formatter_cmds: list[str], exit_on_failure: bool = True, language: str = "python"
) -> bool:
if not formatter_cmds or formatter_cmds[0] == "disabled":
return True
first_cmd = formatter_cmds[0]
@ -35,10 +38,21 @@ def check_formatter_installed(formatter_cmds: list[str], exit_on_failure: bool =
)
return False
tmp_code = """print("hello world")"""
lang_support = get_language_support_by_common_formatters(formatter_cmds)
if not lang_support:
logger.debug(f"Could not determine language for formatter: {formatter_cmds}")
return True
if str(lang_support.language) == "python":
tmp_code = """print("hello world")"""
elif str(lang_support.language) in ("javascript", "typescript"):
tmp_code = "console.log('hello world');"
else:
return True
try:
with tempfile.TemporaryDirectory() as tmpdir:
tmp_file = Path(tmpdir) / "test_codeflash_formatter.py"
tmp_file = Path(tmpdir) / ("test_codeflash_formatter" + lang_support.default_file_extension)
tmp_file.write_text(tmp_code, encoding="utf-8")
format_code(formatter_cmds, tmp_file, print_status=False, exit_on_failure=False)
return True

View file

@ -42,13 +42,16 @@ def generate_unified_diff(original: str, modified: str, from_file: str, to_file:
def apply_formatter_cmds(
cmds: list[str], path: Path, test_dir_str: Optional[str], print_status: bool, exit_on_failure: bool = True
) -> tuple[Path, str, bool]:
from codeflash.languages.registry import get_language_support
if not path.exists():
msg = f"File {path} does not exist. Cannot apply formatter commands."
raise FileNotFoundError(msg)
file_path = path
lang_support = get_language_support(path)
if test_dir_str:
file_path = Path(test_dir_str) / "temp.py"
file_path = Path(test_dir_str) / ("temp" + lang_support.default_file_extension)
shutil.copy2(path, file_path)
file_token = "$file" # noqa: S105
@ -87,13 +90,16 @@ def get_diff_lines_count(diff_output: str) -> int:
return len(diff_lines)
def format_generated_code(generated_test_source: str, formatter_cmds: list[str]) -> str:
def format_generated_code(generated_test_source: str, formatter_cmds: list[str], language: str = "python") -> str:
from codeflash.languages.registry import get_language_support
formatter_name = formatter_cmds[0].lower() if formatter_cmds else "disabled"
if formatter_name == "disabled": # nothing to do if no formatter provided
return re.sub(r"\n{2,}", "\n\n", generated_test_source)
with tempfile.TemporaryDirectory() as test_dir_str:
# try running formatter, if nothing changes (could be due to formatting failing or no actual formatting needed) return code with 2 or more newlines substituted with 2 newlines
original_temp = Path(test_dir_str) / "original_temp.py"
lang_support = get_language_support(language)
original_temp = Path(test_dir_str) / ("original_temp" + lang_support.default_file_extension)
original_temp.write_text(generated_test_source, encoding="utf8")
_, formatted_code, changed = apply_formatter_cmds(
formatter_cmds, original_temp, test_dir_str, print_status=False, exit_on_failure=False
@ -111,6 +117,8 @@ def format_code(
print_status: bool = True,
exit_on_failure: bool = True,
) -> str:
from codeflash.languages.registry import get_language_support
if is_LSP_enabled():
exit_on_failure = False
@ -130,7 +138,8 @@ def format_code(
# we don't count the formatting diff for the optimized function as it should be well-formatted
original_code_without_opfunc = original_code.replace(optimized_code, "")
original_temp = Path(test_dir_str) / "original_temp.py"
lang_support = get_language_support(path)
original_temp = Path(test_dir_str) / ("original_temp" + lang_support.default_file_extension)
original_temp.write_text(original_code_without_opfunc, encoding="utf8")
formatted_temp, formatted_code, changed = apply_formatter_cmds(
@ -160,6 +169,7 @@ def format_code(
_, formatted_code, changed = apply_formatter_cmds(
formatter_cmds, path, test_dir_str=None, print_status=print_status, exit_on_failure=exit_on_failure
)
if not changed:
logger.warning(
f"No changes detected in {path} after formatting, are you sure you have valid formatter commands?"

View file

@ -8,7 +8,7 @@ from typing import TYPE_CHECKING, TypeVar
from pydantic import BaseModel, ConfigDict, field_validator
if TYPE_CHECKING:
from codeflash.models.models import FunctionParent
from codeflash.models.function_types import FunctionParent
ObjectDefT = TypeVar("ObjectDefT", ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)

View file

@ -23,8 +23,7 @@ from codeflash.context.unused_definition_remover import (
from codeflash.discovery.functions_to_optimize import FunctionToOptimize # noqa: TC001
# Language support imports for multi-language code context extraction
from codeflash.languages import is_python
from codeflash.languages.base import Language
from codeflash.languages import Language, is_python
from codeflash.models.models import (
CodeContextType,
CodeOptimizationContext,
@ -234,27 +233,13 @@ def get_code_optimization_context_for_language(
"""
from codeflash.languages import get_language_support
from codeflash.languages.base import FunctionInfo, ParentInfo
# Get language support for this function
language = Language(function_to_optimize.language)
lang_support = get_language_support(language)
# Convert FunctionToOptimize to FunctionInfo for language support
parents = tuple(ParentInfo(name=p.name, type=p.type) for p in function_to_optimize.parents)
func_info = FunctionInfo(
name=function_to_optimize.function_name,
file_path=function_to_optimize.file_path,
start_line=function_to_optimize.starting_line or 1,
end_line=function_to_optimize.ending_line or 1,
parents=parents,
is_async=function_to_optimize.is_async,
is_method=len(function_to_optimize.parents) > 0,
language=language,
)
# Extract code context using language support
code_context = lang_support.extract_code_context(func_info, project_root_path, project_root_path)
code_context = lang_support.extract_code_context(function_to_optimize, project_root_path, project_root_path)
# Build imports string if available
imports_code = "\n".join(code_context.imports) if code_context.imports else ""
@ -294,10 +279,9 @@ def get_code_optimization_context_for_language(
helper_code = "\n\n".join(h.source_code for h in same_file_helpers)
target_file_code = target_file_code + "\n\n" + helper_code
# Add global variables (module-level declarations) referenced by the function and helpers
# These should be included in read-writable context so AI can modify them if needed
if code_context.read_only_context:
target_file_code = code_context.read_only_context + "\n\n" + target_file_code
# Note: code_context.read_only_context contains type definitions and global variables
# These should be passed as read-only context to the AI, not prepended to the target code
# If prepended to target code, the AI treats them as code to optimize and includes them in output
# Add imports to target file code
if imports_code:
@ -350,8 +334,9 @@ def get_code_optimization_context_for_language(
return CodeOptimizationContext(
testgen_context=testgen_context,
read_writable_code=read_writable_code,
# Global variables are now included in read-writable code, so don't duplicate in read-only
read_only_context_code="",
# Pass type definitions and globals as read-only context for the AI
# This way the AI sees them as context but doesn't include them in optimized output
read_only_context_code=code_context.read_only_context,
hashing_code_context=read_writable_code.flat,
hashing_code_context_hash=code_hash,
helper_functions=helper_function_sources,

View file

@ -29,7 +29,6 @@ from codeflash.code_utils.code_utils import (
)
from codeflash.code_utils.compat import SAFE_SYS_EXECUTABLE, codeflash_cache_db
from codeflash.code_utils.shell_utils import get_cross_platform_subprocess_run_args
from codeflash.languages import is_javascript, is_python
from codeflash.models.models import CodePosition, FunctionCalledInTest, TestsInFile, TestType
if TYPE_CHECKING:
@ -589,7 +588,7 @@ def discover_tests_for_language(
"""
from codeflash.languages import get_language_support
from codeflash.languages.base import FunctionInfo, Language, ParentInfo
from codeflash.languages.base import Language
try:
lang_support = get_language_support(Language(language))
@ -597,34 +596,20 @@ def discover_tests_for_language(
logger.warning(f"Unsupported language {language}, returning empty test map")
return {}, 0, 0
# Convert FunctionToOptimize to FunctionInfo for the language support API
# Also build a mapping from simple qualified_name to full qualified_name_with_modules
function_infos: list[FunctionInfo] = []
# Collect all functions and build a mapping from simple qualified_name to full qualified_name_with_modules
all_functions: list[FunctionToOptimize] = []
simple_to_full_name: dict[str, str] = {}
if file_to_funcs_to_optimize:
for funcs in file_to_funcs_to_optimize.values():
for func in funcs:
parents = tuple(ParentInfo(p.name, p.type) for p in func.parents)
func_info = FunctionInfo(
name=func.function_name,
file_path=func.file_path,
start_line=func.starting_line or 0,
end_line=func.ending_line or 0,
start_col=func.starting_col,
end_col=func.ending_col,
is_async=func.is_async,
is_method=bool(func.parents and any(p.type == "ClassDef" for p in func.parents)),
parents=parents,
language=Language(language),
)
function_infos.append(func_info)
all_functions.append(func)
# Map simple qualified_name to full qualified_name_with_modules_from_root
simple_to_full_name[func_info.qualified_name] = func.qualified_name_with_modules_from_root(
simple_to_full_name[func.qualified_name] = func.qualified_name_with_modules_from_root(
cfg.project_root_path
)
# Use language support to discover tests
test_map = lang_support.discover_tests(cfg.tests_root, function_infos)
test_map = lang_support.discover_tests(cfg.tests_root, all_functions)
# Convert TestInfo back to FunctionCalledInTest format
# Use the full qualified name (with modules) as the key for consistency with Python
@ -656,6 +641,8 @@ def discover_unit_tests(
discover_only_these_tests: list[Path] | None = None,
file_to_funcs_to_optimize: dict[Path, list[FunctionToOptimize]] | None = None,
) -> tuple[dict[str, set[FunctionCalledInTest]], int, int]:
from codeflash.languages import is_javascript, is_python
# Detect language from functions being optimized
language = _detect_language_from_functions(file_to_funcs_to_optimize)

View file

@ -12,6 +12,7 @@ from typing import TYPE_CHECKING, Any, Optional
import git
import libcst as cst
from pydantic import Field
from pydantic.dataclasses import dataclass
from rich.tree import Tree
@ -26,9 +27,8 @@ from codeflash.code_utils.code_utils import (
from codeflash.code_utils.env_utils import get_pr_number
from codeflash.code_utils.git_utils import get_git_diff, get_repo_owner_and_name
from codeflash.discovery.discover_unit_tests import discover_unit_tests
from codeflash.languages import get_language_support, get_supported_extensions
from codeflash.languages.base import Language
from codeflash.languages.registry import is_language_supported
from codeflash.languages.language_enum import Language
from codeflash.languages.registry import get_language_support, get_supported_extensions, is_language_supported
from codeflash.lsp.helpers import is_LSP_enabled
from codeflash.models.models import FunctionParent
from codeflash.telemetry.posthog_cf import ph
@ -39,8 +39,11 @@ if TYPE_CHECKING:
from libcst import CSTNode
from libcst.metadata import CodeRange
from codeflash.languages.base import FunctionInfo
from codeflash.models.models import CodeOptimizationContext
from codeflash.verification.verification_utils import TestConfig
import contextlib
from rich.text import Text
_property_id = "property"
@ -131,17 +134,23 @@ class FunctionWithReturnStatement(ast.NodeVisitor):
class FunctionToOptimize:
"""Represent a function that is a candidate for optimization.
This is the canonical dataclass for representing functions across all languages
(Python, JavaScript, TypeScript). It captures all information needed to identify,
locate, and work with a function.
Attributes
----------
function_name: The name of the function.
file_path: The absolute file path where the function is located.
parents: A list of parent scopes, which could be classes or functions.
starting_line: The starting line number of the function in the file.
ending_line: The ending line number of the function in the file.
starting_col: The starting column offset (for precise location in multi-line contexts).
ending_col: The ending column offset (for precise location in multi-line contexts).
starting_line: The starting line number of the function in the file (1-indexed).
ending_line: The ending line number of the function in the file (1-indexed).
starting_col: The starting column offset (0-indexed, for precise location).
ending_col: The ending column offset (0-indexed, for precise location).
is_async: Whether this function is defined as async.
is_method: Whether this is a method (belongs to a class).
language: The programming language of this function (default: "python").
doc_start_line: Line where docstring/JSDoc starts (or None if no doc comment).
The qualified_name property provides the full name of the function, including
any parent class or function names. The qualified_name_with_modules_from_root
@ -151,23 +160,32 @@ class FunctionToOptimize:
function_name: str
file_path: Path
parents: list[FunctionParent] # list[ClassDef | FunctionDef | AsyncFunctionDef]
parents: list[FunctionParent] = Field(default_factory=list) # list[ClassDef | FunctionDef | AsyncFunctionDef]
starting_line: Optional[int] = None
ending_line: Optional[int] = None
starting_col: Optional[int] = None # Column offset for precise location
ending_col: Optional[int] = None # Column offset for precise location
is_async: bool = False
is_method: bool = False # Whether this is a method (belongs to a class)
language: str = "python" # Language identifier for multi-language support
doc_start_line: Optional[int] = None # Line where docstring/JSDoc starts
@property
def top_level_parent_name(self) -> str:
return self.function_name if not self.parents else self.parents[0].name
@property
def class_name(self) -> str | None:
"""Get the immediate parent class name, if any."""
for parent in reversed(self.parents):
if parent.type == "ClassDef":
return parent.name
return None
def __str__(self) -> str:
return (
f"{self.file_path}:{'.'.join([p.name for p in self.parents])}"
f"{'.' if self.parents else ''}{self.function_name}"
)
qualified = f"{'.'.join([p.name for p in self.parents])}{'.' if self.parents else ''}{self.function_name}"
line_info = f":{self.starting_line}-{self.ending_line}" if self.starting_line and self.ending_line else ""
return f"{self.file_path}:{qualified}{line_info}"
@property
def qualified_name(self) -> str:
@ -180,6 +198,28 @@ class FunctionToOptimize:
def qualified_name_with_modules_from_root(self, project_root_path: Path) -> str:
return f"{module_name_from_file_path(self.file_path, project_root_path)}.{self.qualified_name}"
@classmethod
def from_function_info(cls, func_info: FunctionInfo) -> FunctionToOptimize:
"""Create a FunctionToOptimize from a FunctionInfo instance.
This is a temporary method for backward compatibility during migration.
Once FunctionInfo is fully removed, this method can be deleted.
"""
parents = [FunctionParent(name=p.name, type=p.type) for p in func_info.parents]
return cls(
function_name=func_info.name,
file_path=func_info.file_path,
parents=parents,
starting_line=func_info.start_line,
ending_line=func_info.end_line,
starting_col=func_info.start_col,
ending_col=func_info.end_col,
is_async=func_info.is_async,
is_method=func_info.is_method,
language=func_info.language.value,
doc_start_line=func_info.doc_start_line,
)
# =============================================================================
# Multi-language support helpers
@ -187,7 +227,7 @@ class FunctionToOptimize:
def get_files_for_language(
module_root_path: Path, ignore_paths: list[Path], language: Language | None = None
module_root_path: Path, ignore_paths: list[Path] | None = None, language: Language | None = None
) -> list[Path]:
"""Get all source files for supported languages.
@ -200,22 +240,69 @@ def get_files_for_language(
List of file paths matching supported extensions.
"""
if ignore_paths is None:
ignore_paths = []
if language is not None:
support = get_language_support(language)
extensions = support.file_extensions
else:
extensions = tuple(get_supported_extensions())
# Default directory patterns to always exclude for JS/TS
js_ts_default_excludes = {
"node_modules",
"dist",
"build",
".next",
".nuxt",
"coverage",
".cache",
".turbo",
".vercel",
"__pycache__",
}
files = []
for ext in extensions:
pattern = f"*{ext}"
for file_path in module_root_path.rglob(pattern):
# Check explicit ignore paths
if any(file_path.is_relative_to(ignore_path) for ignore_path in ignore_paths):
continue
# Check default JS/TS excludes in path parts
if any(part in js_ts_default_excludes for part in file_path.parts):
continue
files.append(file_path)
return files
def _is_js_ts_function_exported(file_path: Path, function_name: str) -> tuple[bool, str | None]:
"""Check if a JavaScript/TypeScript function is exported from its module.
For JS/TS, functions that are not exported cannot be imported by tests,
making them impossible to optimize.
Args:
file_path: Path to the source file.
function_name: Name of the function to check.
Returns:
Tuple of (is_exported, export_name). export_name may be 'default' for default exports.
"""
from codeflash.languages.treesitter_utils import get_analyzer_for_file
try:
source = file_path.read_text(encoding="utf-8")
analyzer = get_analyzer_for_file(file_path)
return analyzer.is_function_exported(source, function_name)
except Exception as e:
logger.debug(f"Failed to check export status for {function_name}: {e}")
# Return True to avoid blocking in case of errors
return True, None
def _find_all_functions_in_python_file(file_path: Path) -> dict[Path, list[FunctionToOptimize]]:
"""Find all optimizable functions in a Python file using AST parsing.
@ -248,25 +335,8 @@ def _find_all_functions_via_language_support(file_path: Path) -> dict[Path, list
try:
lang_support = get_language_support(file_path)
criteria = FunctionFilterCriteria(require_return=True)
function_infos = lang_support.discover_functions(file_path, criteria)
ftos = []
for func_info in function_infos:
parents = [FunctionParent(p.name, p.type) for p in func_info.parents]
ftos.append(
FunctionToOptimize(
function_name=func_info.name,
file_path=func_info.file_path,
parents=parents,
starting_line=func_info.start_line,
ending_line=func_info.end_line,
starting_col=func_info.start_col,
ending_col=func_info.end_col,
is_async=func_info.is_async,
language=func_info.language.value,
)
)
functions[file_path] = ftos
# discover_functions already returns FunctionToOptimize objects
functions[file_path] = lang_support.discover_functions(file_path, criteria)
except Exception as e:
logger.debug(f"Failed to discover functions in {file_path}: {e}")
@ -338,6 +408,36 @@ def get_functions_to_optimize(
exit_with_message(
f"Function {only_get_this_function} not found in file {file}\nor the function does not have a 'return' statement or is a property"
)
# For JavaScript/TypeScript, verify that the function (or its parent class) is exported
# Non-exported functions cannot be imported by tests
if found_function.language in ("javascript", "typescript"):
# For class methods, check if the parent class is exported
# For standalone functions, check if the function itself is exported
if found_function.parents:
# It's a class method - check if the class is exported
name_to_check = found_function.top_level_parent_name
else:
# It's a standalone function - check if the function is exported
name_to_check = found_function.function_name
is_exported, export_name = _is_js_ts_function_exported(file, name_to_check)
if not is_exported:
if found_function.parents:
logger.debug(
f"Class '{name_to_check}' containing method '{found_function.function_name}' "
f"is not exported from {file}. "
f"In JavaScript/TypeScript, only exported classes/functions can be optimized "
f"because tests need to import them."
)
else:
logger.debug(
f"Function '{found_function.function_name}' is not exported from {file}. "
f"In JavaScript/TypeScript, only exported functions can be optimized because "
f"tests need to import them."
)
return {}, 0, None
functions[file] = [found_function]
else:
logger.info("Finding all functions modified in the current git diff ...")
@ -539,9 +639,10 @@ def get_all_replay_test_functions(
except Exception as e:
logger.warning(f"Error parsing replay test file {replay_test_file}: {e}")
if not trace_file_path:
if trace_file_path is None:
logger.error("Could not find trace_file_path in replay test files.")
exit_with_message("Could not find trace_file_path in replay test files.")
raise AssertionError("Unreachable") # exit_with_message never returns
if not trace_file_path.exists():
logger.error(f"Trace file not found: {trace_file_path}")
@ -596,7 +697,7 @@ def get_all_replay_test_functions(
if filtered_list:
filtered_valid_functions[file_path] = filtered_list
return filtered_valid_functions, trace_file_path
return dict(filtered_valid_functions), trace_file_path
def is_git_repo(file_path: str) -> bool:
@ -608,11 +709,13 @@ def is_git_repo(file_path: str) -> bool:
@cache
def ignored_submodule_paths(module_root: str) -> list[str]:
def ignored_submodule_paths(module_root: str) -> list[Path]:
if is_git_repo(module_root):
git_repo = git.Repo(module_root, search_parent_directories=True)
try:
return [Path(git_repo.working_tree_dir, submodule.path).resolve() for submodule in git_repo.submodules]
working_dir = git_repo.working_tree_dir
if working_dir is not None:
return [Path(working_dir, submodule.path).resolve() for submodule in git_repo.submodules]
except Exception as e:
logger.warning(f"Error getting submodule paths: {e}")
return []
@ -626,7 +729,7 @@ class TopLevelFunctionOrMethodVisitor(ast.NodeVisitor):
self.class_name = class_name
self.function_name = function_or_method_name
self.is_top_level = False
self.function_has_args = None
self.function_has_args: bool | None = None
self.line_no = line_no
self.is_staticmethod = False
self.is_classmethod = False
@ -740,31 +843,28 @@ def was_function_previously_optimized(
# Check optimization status if repository info is provided
# already_optimized_count = 0
try:
# Check optimization status if repository info is provided
# already_optimized_count = 0
owner = None
repo = None
with contextlib.suppress(git.exc.InvalidGitRepositoryError):
owner, repo = get_repo_owner_and_name()
except git.exc.InvalidGitRepositoryError:
logger.warning("No git repository found")
owner, repo = None, None
pr_number = get_pr_number()
if not owner or not repo or pr_number is None or getattr(args, "no_pr", False):
return False
code_contexts = []
func_hash = code_context.hashing_code_context_hash
# Use a unique path identifier that includes function info
code_contexts.append(
code_contexts = [
{
"file_path": function_to_optimize.file_path,
"file_path": str(function_to_optimize.file_path),
"function_name": function_to_optimize.qualified_name,
"code_hash": func_hash,
}
)
if not code_contexts:
return False
]
try:
result = is_function_being_optimized_again(owner, repo, pr_number, code_contexts)
@ -783,7 +883,7 @@ def filter_functions(
ignore_paths: list[Path],
project_root: Path,
module_root: Path,
previous_checkpoint_functions: dict[Path, dict[str, Any]] | None = None,
previous_checkpoint_functions: dict[str, dict[str, Any]] | None = None,
*,
disable_logs: bool = False,
) -> tuple[dict[Path, list[FunctionToOptimize]], int]:
@ -808,21 +908,49 @@ def filter_functions(
# Normalize paths for case-insensitive comparison on Windows
tests_root_str = os.path.normcase(str(tests_root))
module_root_str = os.path.normcase(str(module_root))
project_root_str = os.path.normcase(str(project_root))
# Check if tests_root overlaps with module_root or project_root
# In this case, we need to use file pattern matching instead of directory matching
tests_root_overlaps_source = tests_root_str in (module_root_str, project_root_str) or module_root_str.startswith(
tests_root_str + os.sep
)
# Test file patterns for when tests_root overlaps with source
test_file_name_patterns = (".test.", ".spec.", "_test.", "_spec.")
test_dir_patterns = (os.sep + "test" + os.sep, os.sep + "tests" + os.sep, os.sep + "__tests__" + os.sep)
def is_test_file(file_path_normalized: str) -> bool:
"""Check if a file is a test file based on patterns."""
if tests_root_overlaps_source:
# Use file pattern matching when tests_root overlaps with source
file_lower = file_path_normalized.lower()
# Check filename patterns (e.g., .test.ts, .spec.ts)
if any(pattern in file_lower for pattern in test_file_name_patterns):
return True
# Check directory patterns, but only within the project root
# to avoid false positives from parent directories
relative_path = file_lower
if project_root_str and file_lower.startswith(project_root_str.lower()):
relative_path = file_lower[len(project_root_str) :]
return any(pattern in relative_path for pattern in test_dir_patterns)
# Use directory-based filtering when tests are in a separate directory
return file_path_normalized.startswith(tests_root_str + os.sep)
# We desperately need Python 3.10+ only support to make this code readable with structural pattern matching
for file_path_path, functions in modified_functions.items():
_functions = functions
file_path = str(file_path_path)
file_path_normalized = os.path.normcase(file_path)
if file_path_normalized.startswith(tests_root_str + os.sep):
if is_test_file(file_path_normalized):
test_functions_removed_count += len(_functions)
continue
if file_path in ignore_paths or any(
if file_path_path in ignore_paths or any(
file_path_normalized.startswith(os.path.normcase(str(ignore_path)) + os.sep) for ignore_path in ignore_paths
):
ignore_paths_removed_count += 1
continue
if file_path in submodule_paths or any(
if file_path_path in submodule_paths or any(
file_path_normalized.startswith(os.path.normcase(str(submodule_path)) + os.sep)
for submodule_path in submodule_paths
):
@ -914,7 +1042,7 @@ def filter_files_optimized(file_path: Path, tests_root: Path, ignore_paths: list
def function_has_return_statement(function_node: FunctionDef | AsyncFunctionDef) -> bool:
# Custom DFS, return True as soon as a Return node is found
stack = [function_node]
stack: list[ast.AST] = [function_node]
while stack:
node = stack.pop()
if isinstance(node, ast.Return):

View file

@ -19,7 +19,6 @@ Usage:
from codeflash.languages.base import (
CodeContext,
FunctionInfo,
HelperFunction,
Language,
LanguageSupport,
@ -53,6 +52,40 @@ from codeflash.languages.registry import (
get_supported_languages,
register_language,
)
from codeflash.languages.test_framework import (
current_test_framework,
get_js_test_framework_or_default,
is_jest,
is_mocha,
is_pytest,
is_unittest,
is_vitest,
reset_test_framework,
set_current_test_framework,
)
# Lazy imports to avoid circular imports
def __getattr__(name: str):
if name == "FunctionInfo":
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
return FunctionToOptimize
if name == "JavaScriptSupport":
from codeflash.languages.javascript.support import JavaScriptSupport
return JavaScriptSupport
if name == "TypeScriptSupport":
from codeflash.languages.javascript.support import TypeScriptSupport
return TypeScriptSupport
if name == "PythonSupport":
from codeflash.languages.python.support import PythonSupport
return PythonSupport
msg = f"module {__name__!r} has no attribute {name!r}"
raise AttributeError(msg)
__all__ = [
# Base types
@ -67,16 +100,24 @@ __all__ = [
# Current language singleton
"current_language",
"current_language_support",
# Registry functions
"current_test_framework",
"detect_project_language",
"get_js_test_framework_or_default",
"get_language_support",
"get_supported_extensions",
"get_supported_languages",
"is_java",
"is_javascript",
"is_jest",
"is_mocha",
"is_pytest",
"is_python",
"is_typescript",
"is_unittest",
"is_vitest",
"register_language",
"reset_current_language",
"reset_test_framework",
"set_current_language",
"set_current_test_framework",
]

View file

@ -2,112 +2,36 @@
This module defines the core abstractions that all language implementations must follow.
The LanguageSupport protocol defines the interface that each language must implement,
while the dataclasses define language-agnostic representations of code constructs.
while FunctionToOptimize is the canonical representation of functions across all languages.
"""
from __future__ import annotations
from dataclasses import dataclass, field
from enum import Enum
from typing import TYPE_CHECKING, Any, Protocol, runtime_checkable
if TYPE_CHECKING:
from collections.abc import Sequence
from pathlib import Path
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
class Language(str, Enum):
"""Supported programming languages."""
from codeflash.languages.language_enum import Language
from codeflash.models.function_types import FunctionParent
PYTHON = "python"
JAVASCRIPT = "javascript"
TYPESCRIPT = "typescript"
JAVA = "java"
def __str__(self) -> str:
return self.value
# Backward compatibility aliases - ParentInfo is now FunctionParent
ParentInfo = FunctionParent
@dataclass(frozen=True)
class ParentInfo:
"""Parent scope information for nested functions/methods.
# Lazy import for FunctionInfo to avoid circular imports
# This allows `from codeflash.languages.base import FunctionInfo` to work at runtime
def __getattr__(name: str) -> Any:
if name == "FunctionInfo":
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
Represents the parent class or function that contains a nested function.
Used to construct the qualified name of a function.
Attributes:
name: The name of the parent scope (class name or function name).
type: The type of parent ("ClassDef", "FunctionDef", "AsyncFunctionDef", etc.).
"""
name: str
type: str # "ClassDef", "FunctionDef", "AsyncFunctionDef", etc.
def __str__(self) -> str:
return f"{self.type}:{self.name}"
@dataclass(frozen=True)
class FunctionInfo:
"""Language-agnostic representation of a function to optimize.
This class captures all the information needed to identify, locate, and
work with a function across different programming languages.
Attributes:
name: The simple function name (e.g., "add").
file_path: Absolute path to the file containing the function.
start_line: Starting line number (1-indexed).
end_line: Ending line number (1-indexed, inclusive).
parents: List of parent scopes (for nested functions/methods).
is_async: Whether this is an async function.
is_method: Whether this is a method (belongs to a class).
language: The programming language.
start_col: Starting column (0-indexed), optional for more precise location.
end_col: Ending column (0-indexed), optional.
"""
name: str
file_path: Path
start_line: int
end_line: int
parents: tuple[ParentInfo, ...] = ()
is_async: bool = False
is_method: bool = False
language: Language = Language.PYTHON
start_col: int | None = None
end_col: int | None = None
doc_start_line: int | None = None # Line where docstring/JSDoc starts (or None if no doc comment)
@property
def qualified_name(self) -> str:
"""Full qualified name including parent scopes.
For a method `add` in class `Calculator`, returns "Calculator.add".
For nested functions, includes all parent scopes.
"""
if not self.parents:
return self.name
parent_path = ".".join(parent.name for parent in self.parents)
return f"{parent_path}.{self.name}"
@property
def class_name(self) -> str | None:
"""Get the immediate parent class name, if any."""
for parent in reversed(self.parents):
if parent.type == "ClassDef":
return parent.name
return None
@property
def top_level_parent_name(self) -> str:
"""Get the top-level parent name, or function name if no parents."""
return self.parents[0].name if self.parents else self.name
def __str__(self) -> str:
return f"FunctionInfo({self.qualified_name} at {self.file_path}:{self.start_line}-{self.end_line})"
return FunctionToOptimize
msg = f"module {__name__!r} has no attribute {name!r}"
raise AttributeError(msg)
@dataclass
@ -237,6 +161,37 @@ class FunctionFilterCriteria:
max_lines: int | None = None
@dataclass
class ReferenceInfo:
"""Information about a reference (call site) to a function.
This class captures information about where a function is called
from, including the file, line number, context, and caller function.
Attributes:
file_path: Path to the file containing the reference.
line: Line number (1-indexed).
column: Column number (0-indexed).
end_line: End line number (1-indexed).
end_column: End column number (0-indexed).
context: The line of code containing the reference.
reference_type: Type of reference ("call", "callback", "memoized", "import", "reexport").
import_name: Name used to import the function (may differ from original).
caller_function: Name of the function containing this reference (or None for module-level).
"""
file_path: Path
line: int
column: int
end_line: int
end_column: int
context: str
reference_type: str
import_name: str | None
caller_function: str | None = None
@runtime_checkable
class LanguageSupport(Protocol):
"""Protocol defining what a language implementation must provide.
@ -279,6 +234,11 @@ class LanguageSupport(Protocol):
"""
...
@property
def default_file_extension(self) -> str:
"""Default file extension for this language."""
...
@property
def test_framework(self) -> str:
"""Primary test framework name.
@ -298,7 +258,7 @@ class LanguageSupport(Protocol):
def discover_functions(
self, file_path: Path, filter_criteria: FunctionFilterCriteria | None = None
) -> list[FunctionInfo]:
) -> list[FunctionToOptimize]:
"""Find all optimizable functions in a file.
Args:
@ -306,12 +266,14 @@ class LanguageSupport(Protocol):
filter_criteria: Optional criteria to filter functions.
Returns:
List of FunctionInfo objects for discovered functions.
List of FunctionToOptimize objects for discovered functions.
"""
...
def discover_tests(self, test_root: Path, source_functions: Sequence[FunctionInfo]) -> dict[str, list[TestInfo]]:
def discover_tests(
self, test_root: Path, source_functions: Sequence[FunctionToOptimize]
) -> dict[str, list[TestInfo]]:
"""Map source functions to their tests via static analysis.
Args:
@ -326,7 +288,7 @@ class LanguageSupport(Protocol):
# === Code Analysis ===
def extract_code_context(self, function: FunctionInfo, project_root: Path, module_root: Path) -> CodeContext:
def extract_code_context(self, function: FunctionToOptimize, project_root: Path, module_root: Path) -> CodeContext:
"""Extract function code and its dependencies.
Args:
@ -340,7 +302,7 @@ class LanguageSupport(Protocol):
"""
...
def find_helper_functions(self, function: FunctionInfo, project_root: Path) -> list[HelperFunction]:
def find_helper_functions(self, function: FunctionToOptimize, project_root: Path) -> list[HelperFunction]:
"""Find helper functions called by the target function.
Args:
@ -353,14 +315,37 @@ class LanguageSupport(Protocol):
"""
...
def find_references(
self, function: FunctionToOptimize, project_root: Path, tests_root: Path | None = None, max_files: int = 500
) -> list[ReferenceInfo]:
"""Find all references (call sites) to a function across the codebase.
This method finds all places where a function is called, including:
- Direct calls
- Callbacks (passed to other functions)
- Memoized versions
- Re-exports
Args:
function: The function to find references for.
project_root: Root of the project to search.
tests_root: Root of tests directory (references in tests are excluded).
max_files: Maximum number of files to search.
Returns:
List of ReferenceInfo objects describing each reference location.
"""
...
# === Code Transformation ===
def replace_function(self, source: str, function: FunctionInfo, new_source: str) -> str:
def replace_function(self, source: str, function: FunctionToOptimize, new_source: str) -> str:
"""Replace a function in source code with new implementation.
Args:
source: Original source code.
function: FunctionInfo identifying the function to replace.
function: FunctionToOptimize identifying the function to replace.
new_source: New function source code.
Returns:
@ -416,7 +401,7 @@ class LanguageSupport(Protocol):
# === Instrumentation ===
def instrument_for_behavior(self, source: str, functions: Sequence[FunctionInfo]) -> str:
def instrument_for_behavior(self, source: str, functions: Sequence[FunctionToOptimize]) -> str:
"""Add behavior instrumentation to capture inputs/outputs.
Args:
@ -429,7 +414,7 @@ class LanguageSupport(Protocol):
"""
...
def instrument_for_benchmarking(self, test_source: str, target_function: FunctionInfo) -> str:
def instrument_for_benchmarking(self, test_source: str, target_function: FunctionToOptimize) -> str:
"""Add timing instrumentation to test code.
Args:
@ -606,7 +591,9 @@ class LanguageSupport(Protocol):
"""
...
def instrument_source_for_line_profiler(self, func_info: FunctionInfo, line_profiler_output_file: Path) -> bool:
def instrument_source_for_line_profiler(
self, func_info: FunctionToOptimize, line_profiler_output_file: Path
) -> bool:
"""Instrument source code before line profiling."""
...
@ -675,17 +662,14 @@ class LanguageSupport(Protocol):
...
def convert_parents_to_tuple(parents: list | tuple) -> tuple[ParentInfo, ...]:
"""Convert a list of parent objects to a tuple of ParentInfo.
This helper handles conversion from the existing FunctionParent
dataclass to the new ParentInfo dataclass.
def convert_parents_to_tuple(parents: list | tuple) -> tuple[FunctionParent, ...]:
"""Convert a list of parent objects to a tuple of FunctionParent.
Args:
parents: List or tuple of parent objects with name and type attributes.
Returns:
Tuple of ParentInfo objects.
Tuple of FunctionParent objects.
"""
return tuple(ParentInfo(name=p.name, type=p.type) for p in parents)
return tuple(FunctionParent(name=p.name, type=p.type) for p in parents)

View file

@ -11,7 +11,8 @@ import logging
from pathlib import Path
from typing import TYPE_CHECKING
from codeflash.languages.base import CodeContext, FunctionInfo, HelperFunction, Language
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.languages.base import CodeContext, HelperFunction, Language
from codeflash.languages.java.discovery import discover_functions_from_source
from codeflash.languages.java.import_resolver import JavaImportResolver, find_helper_files
from codeflash.languages.java.parser import JavaAnalyzer, JavaClassNode, get_java_analyzer
@ -29,7 +30,7 @@ class InvalidJavaSyntaxError(Exception):
def extract_code_context(
function: FunctionInfo,
function: FunctionToOptimize,
project_root: Path,
module_root: Path | None = None,
max_helper_depth: int = 2,
@ -83,7 +84,7 @@ def extract_code_context(
# This provides necessary context for optimization
parent_type_name = _get_parent_type_name(function)
if parent_type_name:
type_skeleton = _extract_type_skeleton(source, parent_type_name, function.name, analyzer)
type_skeleton = _extract_type_skeleton(source, parent_type_name, function.function_name, analyzer)
if type_skeleton:
target_code = _wrap_method_in_type_skeleton(target_code, type_skeleton)
wrapped_in_skeleton = True
@ -107,7 +108,7 @@ def extract_code_context(
if validate_syntax and target_code:
if not analyzer.validate_syntax(target_code):
raise InvalidJavaSyntaxError(
f"Extracted code for {function.name} is not syntactically valid Java:\n{target_code}"
f"Extracted code for {function.function_name} is not syntactically valid Java:\n{target_code}"
)
return CodeContext(
@ -120,7 +121,7 @@ def extract_code_context(
)
def _get_parent_type_name(function: FunctionInfo) -> str | None:
def _get_parent_type_name(function: FunctionToOptimize) -> str | None:
"""Get the parent type name (class, interface, or enum) for a function.
Args:
@ -558,7 +559,7 @@ def _wrap_method_in_type_skeleton(method_code: str, skeleton: TypeSkeleton) -> s
_wrap_method_in_class_skeleton = _wrap_method_in_type_skeleton
def extract_function_source(source: str, function: FunctionInfo) -> str:
def extract_function_source(source: str, function: FunctionToOptimize) -> str:
"""Extract the source code of a function from the full file source.
Args:
@ -572,8 +573,8 @@ def extract_function_source(source: str, function: FunctionInfo) -> str:
lines = source.splitlines(keepends=True)
# Include Javadoc if present
start_line = function.doc_start_line or function.start_line
end_line = function.end_line
start_line = function.doc_start_line or function.starting_line
end_line = function.ending_line
# Convert from 1-indexed to 0-indexed
start_idx = start_line - 1
@ -583,7 +584,7 @@ def extract_function_source(source: str, function: FunctionInfo) -> str:
def find_helper_functions(
function: FunctionInfo,
function: FunctionToOptimize,
project_root: Path,
max_depth: int = 2,
analyzer: JavaAnalyzer | None = None,
@ -624,12 +625,12 @@ def find_helper_functions(
helpers.append(
HelperFunction(
name=func.name,
name=func.function_name,
qualified_name=func.qualified_name,
file_path=file_path,
source_code=func_source,
start_line=func.start_line,
end_line=func.end_line,
start_line=func.starting_line,
end_line=func.ending_line,
)
)
@ -648,7 +649,7 @@ def find_helper_functions(
def _find_same_class_helpers(
function: FunctionInfo,
function: FunctionToOptimize,
analyzer: JavaAnalyzer,
) -> list[HelperFunction]:
"""Find helper methods in the same class as the target function.
@ -676,7 +677,7 @@ def _find_same_class_helpers(
# Find which methods the target function calls
target_method = None
for method in methods:
if method.name == function.name and method.class_name == function.class_name:
if method.name == function.function_name and method.class_name == function.class_name:
target_method = method
break
@ -689,7 +690,7 @@ def _find_same_class_helpers(
# Add called methods from the same class as helpers
for method in methods:
if (
method.name != function.name
method.name != function.function_name
and method.class_name == function.class_name
and method.name in called_methods
):
@ -716,7 +717,7 @@ def _find_same_class_helpers(
def extract_read_only_context(
source: str,
function: FunctionInfo,
function: FunctionToOptimize,
analyzer: JavaAnalyzer,
) -> str:
"""Extract read-only context (fields, constants, inner classes).

View file

@ -10,13 +10,10 @@ import logging
from pathlib import Path
from typing import TYPE_CHECKING
from codeflash.languages.base import (
FunctionFilterCriteria,
FunctionInfo,
Language,
ParentInfo,
)
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.languages.base import FunctionFilterCriteria
from codeflash.languages.java.parser import JavaAnalyzer, JavaMethodNode, get_java_analyzer
from codeflash.models.function_types import FunctionParent
if TYPE_CHECKING:
pass
@ -28,7 +25,7 @@ def discover_functions(
file_path: Path,
filter_criteria: FunctionFilterCriteria | None = None,
analyzer: JavaAnalyzer | None = None,
) -> list[FunctionInfo]:
) -> list[FunctionToOptimize]:
"""Find all optimizable functions/methods in a Java file.
Uses tree-sitter to parse the file and find methods that can be optimized.
@ -39,7 +36,7 @@ def discover_functions(
analyzer: Optional JavaAnalyzer instance (created if not provided).
Returns:
List of FunctionInfo objects for discovered functions.
List of FunctionToOptimize objects for discovered functions.
"""
criteria = filter_criteria or FunctionFilterCriteria()
@ -58,7 +55,7 @@ def discover_functions_from_source(
file_path: Path | None = None,
filter_criteria: FunctionFilterCriteria | None = None,
analyzer: JavaAnalyzer | None = None,
) -> list[FunctionInfo]:
) -> list[FunctionToOptimize]:
"""Find all optimizable functions/methods in Java source code.
Args:
@ -68,7 +65,7 @@ def discover_functions_from_source(
analyzer: Optional JavaAnalyzer instance.
Returns:
List of FunctionInfo objects for discovered functions.
List of FunctionToOptimize objects for discovered functions.
"""
criteria = filter_criteria or FunctionFilterCriteria()
@ -82,7 +79,7 @@ def discover_functions_from_source(
include_static=True,
)
functions: list[FunctionInfo] = []
functions: list[FunctionToOptimize] = []
for method in methods:
# Apply filters
@ -90,22 +87,22 @@ def discover_functions_from_source(
continue
# Build parents list
parents: list[ParentInfo] = []
parents: list[FunctionParent] = []
if method.class_name:
parents.append(ParentInfo(name=method.class_name, type="ClassDef"))
parents.append(FunctionParent(name=method.class_name, type="ClassDef"))
functions.append(
FunctionInfo(
name=method.name,
FunctionToOptimize(
function_name=method.name,
file_path=file_path or Path("unknown.java"),
start_line=method.start_line,
end_line=method.end_line,
start_col=method.start_col,
end_col=method.end_col,
parents=tuple(parents),
starting_line=method.start_line,
ending_line=method.end_line,
starting_col=method.start_col,
ending_col=method.end_col,
parents=parents,
is_async=False, # Java doesn't have async keyword
is_method=method.class_name is not None,
language=Language.JAVA,
language="java",
doc_start_line=method.javadoc_start_line,
)
)
@ -182,7 +179,7 @@ def _should_include_method(
def discover_test_methods(
file_path: Path,
analyzer: JavaAnalyzer | None = None,
) -> list[FunctionInfo]:
) -> list[FunctionToOptimize]:
"""Find all JUnit test methods in a Java test file.
Looks for methods annotated with @Test, @ParameterizedTest, @RepeatedTest, etc.
@ -192,7 +189,7 @@ def discover_test_methods(
analyzer: Optional JavaAnalyzer instance.
Returns:
List of FunctionInfo objects for discovered test methods.
List of FunctionToOptimize objects for discovered test methods.
"""
try:
@ -205,7 +202,7 @@ def discover_test_methods(
source_bytes = source.encode("utf8")
tree = analyzer.parse(source_bytes)
test_methods: list[FunctionInfo] = []
test_methods: list[FunctionToOptimize] = []
# Find methods with test annotations
_walk_tree_for_test_methods(tree.root_node, source_bytes, file_path, test_methods, analyzer, current_class=None)
@ -217,7 +214,7 @@ def _walk_tree_for_test_methods(
node,
source_bytes: bytes,
file_path: Path,
test_methods: list[FunctionInfo],
test_methods: list[FunctionToOptimize],
analyzer: JavaAnalyzer,
current_class: str | None,
) -> None:
@ -250,22 +247,22 @@ def _walk_tree_for_test_methods(
if name_node:
method_name = analyzer.get_node_text(name_node, source_bytes)
parents: list[ParentInfo] = []
parents: list[FunctionParent] = []
if current_class:
parents.append(ParentInfo(name=current_class, type="ClassDef"))
parents.append(FunctionParent(name=current_class, type="ClassDef"))
test_methods.append(
FunctionInfo(
name=method_name,
FunctionToOptimize(
function_name=method_name,
file_path=file_path,
start_line=node.start_point[0] + 1,
end_line=node.end_point[0] + 1,
start_col=node.start_point[1],
end_col=node.end_point[1],
parents=tuple(parents),
starting_line=node.start_point[0] + 1,
ending_line=node.end_point[0] + 1,
starting_col=node.start_point[1],
ending_col=node.end_point[1],
parents=list(parents),
is_async=False,
is_method=current_class is not None,
language=Language.JAVA,
language="java",
)
)
@ -285,7 +282,7 @@ def get_method_by_name(
method_name: str,
class_name: str | None = None,
analyzer: JavaAnalyzer | None = None,
) -> FunctionInfo | None:
) -> FunctionToOptimize | None:
"""Find a specific method by name in a Java file.
Args:
@ -295,13 +292,13 @@ def get_method_by_name(
analyzer: Optional JavaAnalyzer instance.
Returns:
FunctionInfo for the method, or None if not found.
FunctionToOptimize for the method, or None if not found.
"""
functions = discover_functions(file_path, analyzer=analyzer)
for func in functions:
if func.name == method_name:
if func.function_name == method_name:
if class_name is None or func.class_name == class_name:
return func
@ -312,7 +309,7 @@ def get_class_methods(
file_path: Path,
class_name: str,
analyzer: JavaAnalyzer | None = None,
) -> list[FunctionInfo]:
) -> list[FunctionToOptimize]:
"""Get all methods in a specific class.
Args:
@ -321,7 +318,7 @@ def get_class_methods(
analyzer: Optional JavaAnalyzer instance.
Returns:
List of FunctionInfo objects for methods in the class.
List of FunctionToOptimize objects for methods in the class.
"""
functions = discover_functions(file_path, analyzer=analyzer)

View file

@ -19,7 +19,7 @@ import re
from pathlib import Path
from typing import TYPE_CHECKING
from codeflash.languages.base import FunctionInfo
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.languages.java.parser import JavaAnalyzer
if TYPE_CHECKING:
@ -30,16 +30,16 @@ logger = logging.getLogger(__name__)
def _get_function_name(func: Any) -> str:
"""Get the function name from either FunctionInfo or FunctionToOptimize."""
if hasattr(func, "name"):
return func.name
"""Get the function name from FunctionToOptimize."""
if hasattr(func, "function_name"):
return func.function_name
if hasattr(func, "name"):
return func.name
raise AttributeError(f"Cannot get function name from {type(func)}")
def _get_qualified_name(func: Any) -> str:
"""Get the qualified name from either FunctionInfo or FunctionToOptimize."""
"""Get the qualified name from FunctionToOptimize."""
if hasattr(func, "qualified_name"):
return func.qualified_name
# Build qualified name from function_name and parents
@ -56,7 +56,7 @@ def _get_qualified_name(func: Any) -> str:
def instrument_for_behavior(
source: str,
functions: Sequence[FunctionInfo],
functions: Sequence[FunctionToOptimize],
analyzer: JavaAnalyzer | None = None,
) -> str:
"""Add behavior instrumentation to capture inputs/outputs.
@ -84,7 +84,7 @@ def instrument_for_behavior(
def instrument_for_benchmarking(
test_source: str,
target_function: FunctionInfo,
target_function: FunctionToOptimize,
analyzer: JavaAnalyzer | None = None,
) -> str:
"""Add timing instrumentation to test code.
@ -109,7 +109,7 @@ def instrument_for_benchmarking(
def instrument_existing_test(
test_path: Path,
call_positions: Sequence,
function_to_optimize: Any, # FunctionInfo or FunctionToOptimize
function_to_optimize: Any, # FunctionToOptimize or FunctionToOptimize
tests_project_root: Path,
mode: str, # "behavior" or "performance"
analyzer: JavaAnalyzer | None = None,
@ -573,7 +573,7 @@ def _add_timing_instrumentation(source: str, class_name: str, func_name: str) ->
def create_benchmark_test(
target_function: FunctionInfo,
target_function: FunctionToOptimize,
test_setup_code: str,
invocation_code: str,
iterations: int = 1000,

View file

@ -17,7 +17,7 @@ from dataclasses import dataclass
from pathlib import Path
from typing import TYPE_CHECKING
from codeflash.languages.base import FunctionInfo
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.languages.java.parser import JavaAnalyzer, JavaMethodNode, get_java_analyzer
if TYPE_CHECKING:
@ -213,7 +213,7 @@ def _insert_class_members(
def replace_function(
source: str,
function: FunctionInfo,
function: FunctionToOptimize,
new_source: str,
analyzer: JavaAnalyzer | None = None,
) -> str:
@ -233,7 +233,7 @@ def replace_function(
Args:
source: Original source code.
function: FunctionInfo identifying the function to replace.
function: FunctionToOptimize identifying the function to replace.
new_source: New function source code (may include class with helpers).
analyzer: Optional JavaAnalyzer instance.
@ -243,8 +243,12 @@ def replace_function(
"""
analyzer = analyzer or get_java_analyzer()
func_name = function.function_name
func_start_line = function.starting_line
func_end_line = function.ending_line
# Parse the optimization to extract components
parsed = _parse_optimization_source(new_source, function.name, analyzer)
parsed = _parse_optimization_source(new_source, func_name, analyzer)
# Find the method in the original source
methods = analyzer.find_methods(source)
@ -254,7 +258,7 @@ def replace_function(
# Find all methods matching the name (there may be overloads)
matching_methods = [
m for m in methods
if m.name == function.name
if m.name == func_name
and (function.class_name is None or m.class_name == function.class_name)
]
@ -267,18 +271,18 @@ def replace_function(
logger.debug(
"Found %d overloads of %s. Function start_line=%s, end_line=%s",
len(matching_methods),
function.name,
function.start_line,
function.end_line,
func_name,
func_start_line,
func_end_line,
)
for i, m in enumerate(matching_methods):
logger.debug(" Overload %d: lines %d-%d", i, m.start_line, m.end_line)
if function.start_line and function.end_line:
if func_start_line and func_end_line:
for i, method in enumerate(matching_methods):
# Check if the line numbers are close (account for minor differences
# that can occur due to different parsing or file transformations)
# Use a tolerance of 5 lines to handle edge cases
if abs(method.start_line - function.start_line) <= 5:
if abs(method.start_line - func_start_line) <= 5:
target_method = method
target_overload_index = i
logger.debug(
@ -286,21 +290,21 @@ def replace_function(
i,
method.start_line,
method.end_line,
function.start_line,
function.end_line,
func_start_line,
func_end_line,
)
break
if not target_method:
# Fallback: use the first match
logger.warning(
"Multiple overloads of %s found but no line match, using first match",
function.name,
func_name,
)
target_method = matching_methods[0]
target_overload_index = 0
if not target_method:
logger.error("Could not find method %s in source", function.name)
logger.error("Could not find method %s in source", func_name)
return source
# Get the class name for inserting new members
@ -348,7 +352,7 @@ def replace_function(
methods = analyzer.find_methods(source)
matching_methods = [
m for m in methods
if m.name == function.name
if m.name == func_name
and (function.class_name is None or m.class_name == function.class_name)
]
@ -363,7 +367,7 @@ def replace_function(
else:
logger.error(
"Lost target method %s after adding members (had index %d, found %d overloads)",
function.name,
func_name,
target_overload_index,
len(matching_methods),
)
@ -457,7 +461,7 @@ def _apply_indentation(lines: list[str], base_indent: str) -> str:
def replace_method_body(
source: str,
function: FunctionInfo,
function: FunctionToOptimize,
new_body: str,
analyzer: JavaAnalyzer | None = None,
) -> str:
@ -465,7 +469,7 @@ def replace_method_body(
Args:
source: Original source code.
function: FunctionInfo identifying the function.
function: FunctionToOptimize identifying the function.
new_body: New method body (code between braces).
analyzer: Optional JavaAnalyzer instance.
@ -476,24 +480,26 @@ def replace_method_body(
analyzer = analyzer or get_java_analyzer()
source_bytes = source.encode("utf8")
func_name = function.function_name
# Find the method
methods = analyzer.find_methods(source)
target_method = None
for method in methods:
if method.name == function.name:
if method.name == func_name:
if function.class_name is None or method.class_name == function.class_name:
target_method = method
break
if not target_method:
logger.error("Could not find method %s", function.name)
logger.error("Could not find method %s", func_name)
return source
# Find the body node
body_node = target_method.node.child_by_field_name("body")
if not body_node:
logger.error("Method %s has no body (abstract?)", function.name)
logger.error("Method %s has no body (abstract?)", func_name)
return source
# Get the body's byte positions
@ -596,14 +602,14 @@ def insert_method(
def remove_method(
source: str,
function: FunctionInfo,
function: FunctionToOptimize,
analyzer: JavaAnalyzer | None = None,
) -> str:
"""Remove a method from source code.
Args:
source: The source code.
function: FunctionInfo identifying the method to remove.
function: FunctionToOptimize identifying the method to remove.
analyzer: Optional JavaAnalyzer instance.
Returns:
@ -612,18 +618,20 @@ def remove_method(
"""
analyzer = analyzer or get_java_analyzer()
func_name = function.function_name
# Find the method
methods = analyzer.find_methods(source)
target_method = None
for method in methods:
if method.name == function.name:
if method.name == func_name:
if function.class_name is None or method.class_name == function.class_name:
target_method = method
break
if not target_method:
logger.error("Could not find method %s", function.name)
logger.error("Could not find method %s", func_name)
return source
# Determine removal range (include Javadoc)
@ -669,14 +677,15 @@ def remove_test_functions(
result = test_source
for method in methods_to_remove:
# Create a FunctionInfo for removal
func_info = FunctionInfo(
name=method.name,
# Create a FunctionToOptimize for removal
func_info = FunctionToOptimize(
function_name=method.name,
file_path=Path("temp.java"),
start_line=method.start_line,
end_line=method.end_line,
parents=(),
starting_line=method.start_line,
ending_line=method.end_line,
parents=[],
is_method=True,
language="java",
)
result = remove_method(result, func_info, analyzer)

View file

@ -10,10 +10,10 @@ import logging
from pathlib import Path
from typing import TYPE_CHECKING, Any
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.languages.base import (
CodeContext,
FunctionFilterCriteria,
FunctionInfo,
HelperFunction,
Language,
LanguageSupport,
@ -94,18 +94,18 @@ class JavaSupport(LanguageSupport):
def discover_functions(
self, file_path: Path, filter_criteria: FunctionFilterCriteria | None = None
) -> list[FunctionInfo]:
) -> list[FunctionToOptimize]:
"""Find all optimizable functions in a Java file."""
return discover_functions(file_path, filter_criteria, self._analyzer)
def discover_functions_from_source(
self, source: str, file_path: Path | None = None, filter_criteria: FunctionFilterCriteria | None = None
) -> list[FunctionInfo]:
) -> list[FunctionToOptimize]:
"""Find all optimizable functions in Java source code."""
return discover_functions_from_source(source, file_path, filter_criteria, self._analyzer)
def discover_tests(
self, test_root: Path, source_functions: Sequence[FunctionInfo]
self, test_root: Path, source_functions: Sequence[FunctionToOptimize]
) -> dict[str, list[TestInfo]]:
"""Map source functions to their tests."""
return discover_tests(test_root, source_functions, self._analyzer)
@ -113,13 +113,13 @@ class JavaSupport(LanguageSupport):
# === Code Analysis ===
def extract_code_context(
self, function: FunctionInfo, project_root: Path, module_root: Path
self, function: FunctionToOptimize, project_root: Path, module_root: Path
) -> CodeContext:
"""Extract function code and its dependencies."""
return extract_code_context(function, project_root, module_root, analyzer=self._analyzer)
def find_helper_functions(
self, function: FunctionInfo, project_root: Path
self, function: FunctionToOptimize, project_root: Path
) -> list[HelperFunction]:
"""Find helper functions called by the target function."""
return find_helper_functions(function, project_root, analyzer=self._analyzer)
@ -127,7 +127,7 @@ class JavaSupport(LanguageSupport):
# === Code Transformation ===
def replace_function(
self, source: str, function: FunctionInfo, new_source: str
self, source: str, function: FunctionToOptimize, new_source: str
) -> str:
"""Replace a function in source code with new implementation."""
return replace_function(source, function, new_source, self._analyzer)
@ -156,13 +156,13 @@ class JavaSupport(LanguageSupport):
# === Instrumentation ===
def instrument_for_behavior(
self, source: str, functions: Sequence[FunctionInfo]
self, source: str, functions: Sequence[FunctionToOptimize]
) -> str:
"""Add behavior instrumentation to capture inputs/outputs."""
return instrument_for_behavior(source, functions, self._analyzer)
def instrument_for_benchmarking(
self, test_source: str, target_function: FunctionInfo
self, test_source: str, target_function: FunctionToOptimize
) -> str:
"""Add timing instrumentation to test code."""
return instrument_for_benchmarking(test_source, target_function, self._analyzer)
@ -317,7 +317,7 @@ class JavaSupport(LanguageSupport):
)
def instrument_source_for_line_profiler(
self, func_info: FunctionInfo, line_profiler_output_file: Path
self, func_info: FunctionToOptimize, line_profiler_output_file: Path
) -> bool:
"""Instrument source code before line profiling."""
# Not yet implemented for Java

View file

@ -12,7 +12,8 @@ from collections import defaultdict
from pathlib import Path
from typing import TYPE_CHECKING
from codeflash.languages.base import FunctionInfo, TestInfo
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.languages.base import TestInfo
from codeflash.languages.java.config import detect_java_project
from codeflash.languages.java.discovery import discover_test_methods
from codeflash.languages.java.parser import JavaAnalyzer, get_java_analyzer
@ -25,7 +26,7 @@ logger = logging.getLogger(__name__)
def discover_tests(
test_root: Path,
source_functions: Sequence[FunctionInfo],
source_functions: Sequence[FunctionToOptimize],
analyzer: JavaAnalyzer | None = None,
) -> dict[str, list[TestInfo]]:
"""Map source functions to their tests via static analysis.
@ -48,9 +49,9 @@ def discover_tests(
analyzer = analyzer or get_java_analyzer()
# Build a map of function names for quick lookup
function_map: dict[str, FunctionInfo] = {}
function_map: dict[str, FunctionToOptimize] = {}
for func in source_functions:
function_map[func.name] = func
function_map[func.function_name] = func
function_map[func.qualified_name] = func
# Find all test files (various naming conventions)
@ -77,7 +78,7 @@ def discover_tests(
for func_name in matched_functions:
result[func_name].append(
TestInfo(
test_name=test_method.name,
test_name=test_method.function_name,
test_file=test_file,
test_class=test_method.class_name,
)
@ -90,9 +91,9 @@ def discover_tests(
def _match_test_to_functions(
test_method: FunctionInfo,
test_method: FunctionToOptimize,
test_source: str,
function_map: dict[str, FunctionInfo],
function_map: dict[str, FunctionToOptimize],
analyzer: JavaAnalyzer,
) -> list[str]:
"""Match a test method to source functions it might exercise.
@ -100,7 +101,7 @@ def _match_test_to_functions(
Args:
test_method: The test method.
test_source: Full source code of the test file.
function_map: Map of function names to FunctionInfo.
function_map: Map of function names to FunctionToOptimize.
analyzer: JavaAnalyzer instance.
Returns:
@ -111,10 +112,10 @@ def _match_test_to_functions(
# Strategy 1: Test method name contains function name
# e.g., testAdd -> add, testCalculatorAdd -> Calculator.add
test_name_lower = test_method.name.lower()
test_name_lower = test_method.function_name.lower()
for func_name, func_info in function_map.items():
if func_info.name.lower() in test_name_lower:
if func_info.function_name.lower() in test_name_lower:
matched.append(func_info.qualified_name)
# Strategy 2: Method call analysis
@ -126,8 +127,8 @@ def _match_test_to_functions(
method_calls = _find_method_calls_in_range(
tree.root_node,
source_bytes,
test_method.start_line,
test_method.end_line,
test_method.starting_line,
test_method.ending_line,
analyzer,
)
@ -285,7 +286,7 @@ def _find_method_calls_in_range(
def find_tests_for_function(
function: FunctionInfo,
function: FunctionToOptimize,
test_root: Path,
analyzer: JavaAnalyzer | None = None,
) -> list[TestInfo]:
@ -336,7 +337,7 @@ def get_test_class_for_source_class(
def discover_all_tests(
test_root: Path,
analyzer: JavaAnalyzer | None = None,
) -> list[FunctionInfo]:
) -> list[FunctionToOptimize]:
"""Discover all test methods in a test directory.
Args:
@ -344,11 +345,11 @@ def discover_all_tests(
analyzer: Optional JavaAnalyzer instance.
Returns:
List of FunctionInfo for all test methods.
List of FunctionToOptimize for all test methods.
"""
analyzer = analyzer or get_java_analyzer()
all_tests: list[FunctionInfo] = []
all_tests: list[FunctionToOptimize] = []
# Find all test files (various naming conventions)
test_files = (
@ -408,7 +409,7 @@ def get_test_methods_for_class(
test_file: Path,
test_class_name: str | None = None,
analyzer: JavaAnalyzer | None = None,
) -> list[FunctionInfo]:
) -> list[FunctionToOptimize]:
"""Get all test methods in a specific test class.
Args:
@ -417,7 +418,7 @@ def get_test_methods_for_class(
analyzer: Optional JavaAnalyzer instance.
Returns:
List of FunctionInfo for test methods.
List of FunctionToOptimize for test methods.
"""
tests = discover_test_methods(test_file, analyzer)
@ -455,7 +456,7 @@ def build_test_mapping_for_project(
# Discover all source functions
from codeflash.languages.java.discovery import discover_functions
source_functions: list[FunctionInfo] = []
source_functions: list[FunctionToOptimize] = []
for java_file in config.source_root.rglob("*.java"):
funcs = discover_functions(java_file, analyzer=analyzer)
source_functions.extend(funcs)

View file

@ -0,0 +1,842 @@
"""Find references for JavaScript/TypeScript functions.
This module provides functionality to find all references (call sites) of a function
across a JavaScript/TypeScript codebase. Similar to Jedi's find_references for Python,
this uses tree-sitter to parse and analyze code.
Key features:
- Finds all call sites of a function across multiple files
- Handles various import patterns (named, default, namespace, re-exports, aliases)
- Supports both ES modules and CommonJS
- Handles memoized functions, callbacks, and method calls
"""
from __future__ import annotations
import logging
from dataclasses import dataclass, field
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from pathlib import Path
from tree_sitter import Node
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.languages.treesitter_utils import ImportInfo, TreeSitterAnalyzer
logger = logging.getLogger(__name__)
@dataclass
class Reference:
"""Represents a reference (call site) to a function."""
file_path: Path # File containing the reference
line: int # 1-indexed line number
column: int # 0-indexed column number
end_line: int # 1-indexed end line
end_column: int # 0-indexed end column
context: str # The line of code containing the reference
reference_type: str # Type: "call", "callback", "memoized", "import", "reexport"
import_name: str | None # Name used to import the function (may differ from original)
caller_function: str | None = None # Name of the function containing this reference
@dataclass
class ExportedFunction:
"""Represents how a function is exported from its source file."""
function_name: str # The local function name
export_name: str | None # The name it's exported as (may differ)
is_default: bool # Whether it's a default export
file_path: Path # The source file
@dataclass
class ReferenceSearchContext:
"""Context for tracking visited files during reference search."""
visited_files: set[Path] = field(default_factory=set)
max_files: int = 1000 # Limit to prevent runaway searches
class ReferenceFinder:
"""Finds all references to a function across a JavaScript/TypeScript codebase.
This class provides functionality similar to Jedi's find_references for Python,
but for JavaScript/TypeScript using tree-sitter.
Example usage:
```python
from codeflash.languages.javascript.find_references import ReferenceFinder
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
func = FunctionToOptimize(
function_name="myHelper", file_path=Path("/my/project/src/utils.ts"), parents=[], language="javascript"
)
finder = ReferenceFinder(project_root=Path("/my/project"))
references = finder.find_references(func)
for ref in references:
print(f"{ref.file_path}:{ref.line} - {ref.context}")
```
"""
# File extensions to search
EXTENSIONS = (".ts", ".tsx", ".js", ".jsx", ".mjs", ".cjs")
def __init__(self, project_root: Path, exclude_patterns: list[str] | None = None) -> None:
"""Initialize the ReferenceFinder.
Args:
project_root: Root directory of the project to search.
exclude_patterns: Glob patterns of directories/files to exclude.
Defaults to ['node_modules', 'dist', 'build', '.git'].
"""
self.project_root = project_root
self.exclude_patterns = exclude_patterns or ["node_modules", "dist", "build", ".git", "coverage", "__pycache__"]
self._file_cache: dict[Path, str] = {}
def find_references(
self, function_to_optimize: FunctionToOptimize, include_definition: bool = False, max_files: int = 1000
) -> list[Reference]:
"""Find all references to a function across the project.
Args:
function_to_optimize: The function to find references for.
include_definition: Whether to include the function definition itself.
max_files: Maximum number of files to search (prevents runaway searches).
Returns:
List of Reference objects describing each call site.
"""
from codeflash.languages.treesitter_utils import get_analyzer_for_file
function_name = function_to_optimize.function_name
source_file = function_to_optimize.file_path
references: list[Reference] = []
context = ReferenceSearchContext(max_files=max_files)
# Step 1: Analyze how the function is exported from its source file
source_code = self._read_file(source_file)
if source_code is None:
logger.warning("Could not read source file: %s", source_file)
return references
analyzer = get_analyzer_for_file(source_file)
exported = self._analyze_exports(function_to_optimize, source_file, source_code, analyzer)
if not exported:
logger.debug("Function %s is not exported from %s", function_name, source_file)
# Still search in same file for internal references
same_file_refs = self._find_references_in_file(
source_file, source_code, function_name, None, analyzer, include_self=not include_definition
)
references.extend(same_file_refs)
return references
# Step 2: Find all files that might import from the source file
context.visited_files.add(source_file)
# Track files that re-export our function (we'll search for imports to these too)
reexport_files: list[tuple[Path, str]] = [] # (file_path, export_name)
# Step 3: Search all project files for imports and calls
# We use a separate set for files checked for re-exports to avoid duplicate work
checked_for_reexports: set[Path] = set()
for file_path in self._iter_project_files():
if file_path in context.visited_files:
continue
if len(context.visited_files) >= context.max_files:
logger.warning("Reached max file limit (%d), stopping search", max_files)
break
file_code = self._read_file(file_path)
if file_code is None:
continue
file_analyzer = get_analyzer_for_file(file_path)
# Check if this file imports from the source file
imports = file_analyzer.find_imports(file_code)
import_info = self._find_matching_import(imports, source_file, file_path, exported)
if import_info:
# Found an import - mark as visited and search for calls
context.visited_files.add(file_path)
import_name, original_import = import_info
file_refs = self._find_references_in_file(
file_path, file_code, function_name, import_name, file_analyzer, include_self=True
)
references.extend(file_refs)
# Always check for re-exports (even without direct import match)
# This handles the case where a file re-exports from our source file
if file_path not in checked_for_reexports:
checked_for_reexports.add(file_path)
reexport_refs = self._find_reexports_direct(file_path, file_code, source_file, exported, file_analyzer)
references.extend(reexport_refs)
# Track re-export files for later searching
for ref in reexport_refs:
reexport_files.append((file_path, ref.import_name))
# Step 4: Follow re-export chains to find references through re-exports
for reexport_file, reexport_name in reexport_files:
# Create a new ExportedFunction for the re-exported function
reexported = ExportedFunction(
function_name=reexport_name, export_name=reexport_name, is_default=False, file_path=reexport_file
)
# Search for imports to the re-export file
for file_path in self._iter_project_files():
if file_path in context.visited_files:
continue
if file_path == reexport_file:
continue
if len(context.visited_files) >= context.max_files:
break
file_code = self._read_file(file_path)
if file_code is None:
continue
file_analyzer = get_analyzer_for_file(file_path)
imports = file_analyzer.find_imports(file_code)
# Check if this file imports from the re-export file
import_info = self._find_matching_import(imports, reexport_file, file_path, reexported)
if import_info:
context.visited_files.add(file_path)
import_name, original_import = import_info
file_refs = self._find_references_in_file(
file_path, file_code, reexport_name, import_name, file_analyzer, include_self=True
)
# Avoid duplicates
existing_locs = {(r.file_path, r.line, r.column) for r in references}
for ref in file_refs:
if (ref.file_path, ref.line, ref.column) not in existing_locs:
references.append(ref)
# Step 5: Include references in the same file (internal calls)
if include_definition or not exported:
same_file_refs = self._find_references_in_file(
source_file, source_code, function_name, None, analyzer, include_self=True
)
# Filter out duplicate references
existing_locs = {(r.file_path, r.line, r.column) for r in references}
for ref in same_file_refs:
if (ref.file_path, ref.line, ref.column) not in existing_locs:
references.append(ref)
# Step 6: Deduplicate references (same file, line, column)
seen: set[tuple[Path, int, int]] = set()
unique_refs: list[Reference] = []
for ref in references:
key = (ref.file_path, ref.line, ref.column)
if key not in seen:
seen.add(key)
unique_refs.append(ref)
return unique_refs
def _analyze_exports(
self, function_to_optimize: FunctionToOptimize, file_path: Path, source_code: str, analyzer: TreeSitterAnalyzer
) -> ExportedFunction | None:
"""Analyze how a function is exported from its file.
For class methods, also checks if the containing class is exported.
Args:
function_to_optimize: The function to check.
file_path: Path to the source file.
source_code: Source code content.
analyzer: TreeSitterAnalyzer instance.
Returns:
ExportedFunction if the function is exported, None otherwise.
"""
function_name = function_to_optimize.function_name
class_name = function_to_optimize.class_name
is_exported, export_name = analyzer.is_function_exported(source_code, function_name, class_name)
if not is_exported:
return None
return ExportedFunction(
function_name=function_name,
export_name=export_name,
is_default=(export_name == "default"),
file_path=file_path,
)
def _find_matching_import(
self, imports: list[ImportInfo], source_file: Path, importing_file: Path, exported: ExportedFunction
) -> tuple[str, ImportInfo] | None:
"""Find if any import in a file imports the target function.
Args:
imports: List of imports in the file.
source_file: Path to the file containing the function definition.
importing_file: Path to the file being checked for imports.
exported: Information about how the function is exported.
Returns:
Tuple of (imported_name, ImportInfo) if found, None otherwise.
"""
from codeflash.languages.javascript.import_resolver import ImportResolver
resolver = ImportResolver(self.project_root)
for imp in imports:
# Resolve the import to see if it points to our source file
resolved = resolver.resolve_import(imp, importing_file)
if resolved is None:
continue
if resolved.file_path != source_file:
continue
# This import is from our source file - check if it imports our function
if exported.is_default:
# Default export - check default import
if imp.default_import:
return (imp.default_import, imp)
# Also check namespace import
if imp.namespace_import:
return (f"{imp.namespace_import}.default", imp)
else:
# Named export - check named imports
export_name = exported.export_name or exported.function_name
for name, alias in imp.named_imports:
if name == export_name:
return (alias if alias else name, imp)
# Check namespace import
if imp.namespace_import:
return (f"{imp.namespace_import}.{export_name}", imp)
# Handle CommonJS default import used as namespace
# e.g., const helpers = require('./helpers'); helpers.processConfig()
# In this case, default_import acts like a namespace
if imp.default_import and not imp.named_imports:
return (f"{imp.default_import}.{export_name}", imp)
return None
def _find_references_in_file(
self,
file_path: Path,
source_code: str,
function_name: str,
import_name: str | None,
analyzer: TreeSitterAnalyzer,
include_self: bool = True,
) -> list[Reference]:
"""Find all references to a function within a single file.
Args:
file_path: Path to the file to search.
source_code: Source code content.
function_name: Original function name.
import_name: Name the function is imported as (may be different).
analyzer: TreeSitterAnalyzer instance.
include_self: Whether to include references in the file.
Returns:
List of Reference objects.
"""
references: list[Reference] = []
source_bytes = source_code.encode("utf8")
tree = analyzer.parse(source_bytes)
lines = source_code.splitlines()
# The name to search for (either imported name or original)
search_name = import_name if import_name else function_name
# Handle namespace imports (e.g., "utils.helper")
if "." in search_name:
namespace, member = search_name.split(".", 1)
self._find_member_calls(tree.root_node, source_bytes, lines, file_path, namespace, member, references, None)
else:
# Find direct calls and other reference types
self._find_identifier_references(
tree.root_node, source_bytes, lines, file_path, search_name, function_name, references, None
)
return references
def _find_identifier_references(
self,
node: Node,
source_bytes: bytes,
lines: list[str],
file_path: Path,
search_name: str,
original_name: str,
references: list[Reference],
current_function: str | None,
) -> None:
"""Recursively find references to an identifier in the AST.
Args:
node: Current tree-sitter node.
source_bytes: Source code as bytes.
lines: Source code split into lines.
file_path: Path to the file.
search_name: Name to search for.
original_name: Original function name.
references: List to append references to.
current_function: Name of the containing function (for context).
"""
# Track current function context
new_current_function = current_function
if node.type in ("function_declaration", "method_definition"):
name_node = node.child_by_field_name("name")
if name_node:
new_current_function = source_bytes[name_node.start_byte : name_node.end_byte].decode("utf8")
elif node.type in ("variable_declarator",):
# Arrow function or function expression assigned to variable
name_node = node.child_by_field_name("name")
value_node = node.child_by_field_name("value")
if name_node and value_node and value_node.type in ("arrow_function", "function_expression"):
new_current_function = source_bytes[name_node.start_byte : name_node.end_byte].decode("utf8")
# Check for call expressions
if node.type == "call_expression":
func_node = node.child_by_field_name("function")
if func_node and func_node.type == "identifier":
name = source_bytes[func_node.start_byte : func_node.end_byte].decode("utf8")
if name == search_name:
ref = self._create_reference(file_path, func_node, lines, "call", search_name, current_function)
references.append(ref)
# Check for identifiers used as callbacks or passed as arguments
elif node.type == "identifier":
name = source_bytes[node.start_byte : node.end_byte].decode("utf8")
if name == search_name:
parent = node.parent
# Determine reference type based on context
ref_type = self._determine_reference_type(node, parent, source_bytes)
if ref_type:
ref = self._create_reference(file_path, node, lines, ref_type, search_name, current_function)
references.append(ref)
# Recurse into children
for child in node.children:
self._find_identifier_references(
child, source_bytes, lines, file_path, search_name, original_name, references, new_current_function
)
def _find_member_calls(
self,
node: Node,
source_bytes: bytes,
lines: list[str],
file_path: Path,
namespace: str,
member: str,
references: list[Reference],
current_function: str | None,
) -> None:
"""Find calls to namespace.member (e.g., utils.helper()).
Args:
node: Current tree-sitter node.
source_bytes: Source code as bytes.
lines: Source code split into lines.
file_path: Path to the file.
namespace: The namespace/object name.
member: The member/property name.
references: List to append references to.
current_function: Name of the containing function.
"""
# Track current function context
new_current_function = current_function
if node.type in ("function_declaration", "method_definition"):
name_node = node.child_by_field_name("name")
if name_node:
new_current_function = source_bytes[name_node.start_byte : name_node.end_byte].decode("utf8")
# Check for call expressions with member access
if node.type == "call_expression":
func_node = node.child_by_field_name("function")
if func_node and func_node.type == "member_expression":
obj_node = func_node.child_by_field_name("object")
prop_node = func_node.child_by_field_name("property")
if obj_node and prop_node:
obj_name = source_bytes[obj_node.start_byte : obj_node.end_byte].decode("utf8")
prop_name = source_bytes[prop_node.start_byte : prop_node.end_byte].decode("utf8")
if obj_name == namespace and prop_name == member:
ref = self._create_reference(
file_path, func_node, lines, "call", f"{namespace}.{member}", current_function
)
references.append(ref)
# Also check for member expression used as callback
elif node.type == "member_expression":
obj_node = node.child_by_field_name("object")
prop_node = node.child_by_field_name("property")
if obj_node and prop_node:
obj_name = source_bytes[obj_node.start_byte : obj_node.end_byte].decode("utf8")
prop_name = source_bytes[prop_node.start_byte : prop_node.end_byte].decode("utf8")
if obj_name == namespace and prop_name == member:
parent = node.parent
if parent and parent.type != "call_expression":
ref_type = self._determine_reference_type(node, parent, source_bytes)
if ref_type:
ref = self._create_reference(
file_path, node, lines, ref_type, f"{namespace}.{member}", current_function
)
references.append(ref)
# Recurse into children
for child in node.children:
self._find_member_calls(
child, source_bytes, lines, file_path, namespace, member, references, new_current_function
)
def _determine_reference_type(self, node: Node, parent: Node | None, source_bytes: bytes) -> str | None:
"""Determine the type of reference based on AST context.
Args:
node: The identifier node.
parent: The parent node.
source_bytes: Source code as bytes.
Returns:
Reference type string or None if this isn't a valid reference.
"""
if parent is None:
return None
# Skip import statements
if parent.type in ("import_specifier", "import_clause", "named_imports"):
return None
# Skip function declarations (the function name itself)
if parent.type in ("function_declaration", "method_definition"):
name_node = parent.child_by_field_name("name")
if name_node and name_node.id == node.id:
return None
# Skip variable declarations where this is being defined
if parent.type == "variable_declarator":
name_node = parent.child_by_field_name("name")
if name_node and name_node.id == node.id:
return None
# Skip export specifiers
if parent.type == "export_specifier":
return None
# Check if passed as argument (callback or memoized)
if parent.type == "arguments":
# Check if grandparent is a memoize call
grandparent = parent.parent
if grandparent and grandparent.type == "call_expression":
func_node = grandparent.child_by_field_name("function")
if func_node:
func_name = source_bytes[func_node.start_byte : func_node.end_byte].decode("utf8")
if any(m in func_name.lower() for m in ["memoize", "memo", "cache"]):
return "memoized"
return "callback"
# Check if used in array (often callback patterns)
if parent.type == "array":
return "callback"
# Check if passed to memoize/memoization functions (direct call check)
if parent.type == "call_expression":
func_node = parent.child_by_field_name("function")
if func_node:
func_name = source_bytes[func_node.start_byte : func_node.end_byte].decode("utf8")
if any(m in func_name.lower() for m in ["memoize", "memo", "cache"]):
return "memoized"
# Check if used in a call expression as the function
if parent.type == "call_expression":
func_node = parent.child_by_field_name("function")
if func_node and func_node.id == node.id:
return "call"
# Check if assigned to a property
if parent.type in ("pair", "property"):
return "property"
# Check if part of member expression (method call setup)
if parent.type == "member_expression":
obj_node = parent.child_by_field_name("object")
if obj_node and obj_node.id == node.id:
# This is the object in obj.method
return None # We'll catch the actual call elsewhere
# Generic reference
return "reference"
def _create_reference(
self,
file_path: Path,
node: Node,
lines: list[str],
ref_type: str,
import_name: str,
caller_function: str | None,
) -> Reference:
"""Create a Reference object from a node.
Args:
file_path: Path to the file.
node: The tree-sitter node.
lines: Source code lines.
ref_type: Type of reference.
import_name: Name the function was imported as.
caller_function: Name of the containing function.
Returns:
A Reference object.
"""
line_num = node.start_point[0] + 1 # 1-indexed
context = lines[node.start_point[0]] if node.start_point[0] < len(lines) else ""
return Reference(
file_path=file_path,
line=line_num,
column=node.start_point[1],
end_line=node.end_point[0] + 1,
end_column=node.end_point[1],
context=context.strip(),
reference_type=ref_type,
import_name=import_name,
caller_function=caller_function,
)
def _find_reexports(
self,
file_path: Path,
source_code: str,
exported: ExportedFunction,
analyzer: TreeSitterAnalyzer,
context: ReferenceSearchContext,
) -> list[Reference]:
"""Find re-exports of the function.
Re-exports look like: export { helper } from './utils'
Args:
file_path: Path to the file being checked.
source_code: Source code content.
exported: Information about the original export.
analyzer: TreeSitterAnalyzer instance.
context: Search context.
Returns:
List of Reference objects for re-exports.
"""
references: list[Reference] = []
exports = analyzer.find_exports(source_code)
lines = source_code.splitlines()
for exp in exports:
if not exp.is_reexport:
continue
# Check if this re-exports our function
export_name = exported.export_name or exported.function_name
for name, alias in exp.exported_names:
if name == export_name:
# This is a re-export of our function
# Create a reference with the line info from the export
context_line = lines[exp.start_line - 1] if exp.start_line <= len(lines) else ""
ref = Reference(
file_path=file_path,
line=exp.start_line,
column=0,
end_line=exp.end_line,
end_column=0,
context=context_line.strip(),
reference_type="reexport",
import_name=alias if alias else name,
caller_function=None,
)
references.append(ref)
return references
def _find_reexports_direct(
self,
file_path: Path,
source_code: str,
source_file: Path,
exported: ExportedFunction,
analyzer: TreeSitterAnalyzer,
) -> list[Reference]:
"""Find re-exports that directly reference our source file.
This method checks if a file has re-export statements that
reference our source file.
Args:
file_path: Path to the file being checked.
source_code: Source code content.
source_file: The original source file we're looking for references to.
exported: Information about the original export.
analyzer: TreeSitterAnalyzer instance.
Returns:
List of Reference objects for re-exports.
"""
from codeflash.languages.javascript.import_resolver import ImportResolver
references: list[Reference] = []
exports = analyzer.find_exports(source_code)
lines = source_code.splitlines()
resolver = ImportResolver(self.project_root)
for exp in exports:
if not exp.is_reexport or not exp.reexport_source:
continue
# Create a fake ImportInfo to resolve the re-export source
from codeflash.languages.treesitter_utils import ImportInfo
fake_import = ImportInfo(
module_path=exp.reexport_source,
default_import=None,
named_imports=[],
namespace_import=None,
is_type_only=False,
start_line=exp.start_line,
end_line=exp.end_line,
)
resolved = resolver.resolve_import(fake_import, file_path)
if resolved is None or resolved.file_path != source_file:
continue
# This file re-exports from our source file
export_name = exported.export_name or exported.function_name
for name, alias in exp.exported_names:
if name == export_name:
context_line = lines[exp.start_line - 1] if exp.start_line <= len(lines) else ""
ref = Reference(
file_path=file_path,
line=exp.start_line,
column=0,
end_line=exp.end_line,
end_column=0,
context=context_line.strip(),
reference_type="reexport",
import_name=alias if alias else name,
caller_function=None,
)
references.append(ref)
return references
def _iter_project_files(self) -> list[Path]:
"""Iterate over all JavaScript/TypeScript files in the project.
Returns:
List of file paths to search.
"""
files: list[Path] = []
for ext in self.EXTENSIONS:
for file_path in self.project_root.rglob(f"*{ext}"):
# Check exclusion patterns
if self._should_exclude(file_path):
continue
files.append(file_path)
return files
def _should_exclude(self, file_path: Path) -> bool:
"""Check if a file should be excluded from search.
Args:
file_path: Path to check.
Returns:
True if the file should be excluded.
"""
path_str = str(file_path)
return any(pattern in path_str for pattern in self.exclude_patterns)
def _read_file(self, file_path: Path) -> str | None:
"""Read a file's contents with caching.
Args:
file_path: Path to the file.
Returns:
File contents or None if unreadable.
"""
if file_path in self._file_cache:
return self._file_cache[file_path]
try:
content = file_path.read_text(encoding="utf-8")
self._file_cache[file_path] = content
return content
except Exception as e:
logger.debug("Could not read file %s: %s", file_path, e)
return None
def find_references(
function_to_optimize: FunctionToOptimize, project_root: Path | None = None, max_files: int = 1000
) -> list[Reference]:
"""Convenience function to find all references to a function.
This is a simple wrapper around ReferenceFinder for common use cases.
Args:
function_to_optimize: The function to find references for.
project_root: Root directory of the project. If None, uses source_file's parent.
max_files: Maximum number of files to search.
Returns:
List of Reference objects describing each call site.
Example:
```python
from pathlib import Path
from codeflash.languages.javascript.find_references import find_references
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
func = FunctionToOptimize(
function_name="myHelper", file_path=Path("/my/project/src/utils.ts"), parents=[], language="javascript"
)
refs = find_references(func, project_root=Path("/my/project"))
for ref in refs:
print(f"{ref.file_path}:{ref.line}:{ref.column} - {ref.reference_type}")
```
"""
if project_root is None:
project_root = function_to_optimize.file_path.parent
finder = ReferenceFinder(project_root)
return finder.find_references(function_to_optimize, max_files=max_files)

View file

@ -12,7 +12,8 @@ from pathlib import Path
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from codeflash.languages.base import FunctionInfo, HelperFunction
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.languages.base import HelperFunction
from codeflash.languages.treesitter_utils import ImportInfo, TreeSitterAnalyzer
logger = logging.getLogger(__name__)
@ -43,7 +44,8 @@ class ImportResolver:
project_root: Root directory of the project.
"""
self.project_root = project_root
# Resolve to real path to handle macOS symlinks like /var -> /private/var
self.project_root = project_root.resolve()
self._resolution_cache: dict[tuple[Path, str], Path | None] = {}
def resolve_import(self, import_info: ImportInfo, source_file: Path) -> ResolvedImport | None:
@ -302,7 +304,7 @@ class MultiFileHelperFinder:
def find_helpers(
self,
function: FunctionInfo,
function: FunctionToOptimize,
source: str,
analyzer: TreeSitterAnalyzer,
imports: list[ImportInfo],
@ -328,7 +330,7 @@ class MultiFileHelperFinder:
all_functions = analyzer.find_functions(source, include_methods=True)
target_func = None
for func in all_functions:
if func.name == function.name and func.start_line == function.start_line:
if func.name == function.function_name and func.start_line == function.starting_line:
target_func = func
break
@ -505,7 +507,7 @@ class MultiFileHelperFinder:
Dictionary mapping file paths to lists of helper functions.
"""
from codeflash.languages.base import FunctionInfo
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.languages.treesitter_utils import get_analyzer_for_file
if context.current_depth >= context.max_depth:
@ -525,9 +527,13 @@ class MultiFileHelperFinder:
analyzer = get_analyzer_for_file(file_path)
imports = analyzer.find_imports(source)
# Create FunctionInfo for the helper
func_info = FunctionInfo(
name=helper.name, file_path=file_path, start_line=helper.start_line, end_line=helper.end_line, parents=()
# Create FunctionToOptimize for the helper
func_info = FunctionToOptimize(
function_name=helper.name,
file_path=file_path,
parents=[],
starting_line=helper.start_line,
ending_line=helper.end_line,
)
# Recursively find helpers

View file

@ -72,15 +72,16 @@ class StandaloneCallTransformer:
"""
def __init__(self, func_name: str, qualified_name: str, capture_func: str) -> None:
self.func_name = func_name
self.qualified_name = qualified_name
def __init__(self, function_to_optimize: FunctionToOptimize, capture_func: str) -> None:
self.function_to_optimize = function_to_optimize
self.func_name = function_to_optimize.function_name
self.qualified_name = function_to_optimize.qualified_name
self.capture_func = capture_func
self.invocation_counter = 0
# Pattern to match func_name( with optional leading await and optional object prefix
# Captures: (whitespace)(await )?(object.)*func_name(
# We'll filter out expect() and codeflash. cases in the transform loop
self._call_pattern = re.compile(rf"(\s*)(await\s+)?((?:\w+\.)*){re.escape(func_name)}\s*\(")
self._call_pattern = re.compile(rf"(\s*)(await\s+)?((?:\w+\.)*){re.escape(self.func_name)}\s*\(")
def transform(self, code: str) -> str:
"""Transform all standalone calls in the code."""
@ -310,7 +311,7 @@ class StandaloneCallTransformer:
def transform_standalone_calls(
code: str, func_name: str, qualified_name: str, capture_func: str, start_counter: int = 0
code: str, function_to_optimize: FunctionToOptimize, capture_func: str, start_counter: int = 0
) -> tuple[str, int]:
"""Transform standalone func(...) calls in JavaScript test code.
@ -318,8 +319,7 @@ def transform_standalone_calls(
Args:
code: The test code to transform.
func_name: Name of the function being tested.
qualified_name: Fully qualified function name.
function_to_optimize: The function being tested.
capture_func: The capture function to use ('capture' or 'capturePerf').
start_counter: Starting value for the invocation counter.
@ -327,9 +327,7 @@ def transform_standalone_calls(
Tuple of (transformed code, final counter value).
"""
transformer = StandaloneCallTransformer(
func_name=func_name, qualified_name=qualified_name, capture_func=capture_func
)
transformer = StandaloneCallTransformer(function_to_optimize=function_to_optimize, capture_func=capture_func)
transformer.invocation_counter = start_counter
result = transformer.transform(code)
return result, transformer.invocation_counter
@ -348,15 +346,18 @@ class ExpectCallTransformer:
- Multi-arg assertions: expect(func(args)).toBeCloseTo(0.5, 2)
"""
def __init__(self, func_name: str, qualified_name: str, capture_func: str, remove_assertions: bool = False) -> None:
self.func_name = func_name
self.qualified_name = qualified_name
def __init__(
self, function_to_optimize: FunctionToOptimize, capture_func: str, remove_assertions: bool = False
) -> None:
self.function_to_optimize = function_to_optimize
self.func_name = function_to_optimize.function_name
self.qualified_name = function_to_optimize.qualified_name
self.capture_func = capture_func
self.remove_assertions = remove_assertions
self.invocation_counter = 0
# Pattern to match start of expect((object.)*func_name(
# Captures: (whitespace), (object prefix like calc. or this.)
self._expect_pattern = re.compile(rf"(\s*)expect\s*\(\s*((?:\w+\.)*){re.escape(func_name)}\s*\(")
self._expect_pattern = re.compile(rf"(\s*)expect\s*\(\s*((?:\w+\.)*){re.escape(self.func_name)}\s*\(")
def transform(self, code: str) -> str:
"""Transform all expect calls in the code."""
@ -601,7 +602,7 @@ class ExpectCallTransformer:
def transform_expect_calls(
code: str, func_name: str, qualified_name: str, capture_func: str, remove_assertions: bool = False
code: str, function_to_optimize: FunctionToOptimize, capture_func: str, remove_assertions: bool = False
) -> tuple[str, int]:
"""Transform expect(func(...)).assertion() calls in JavaScript test code.
@ -609,8 +610,7 @@ def transform_expect_calls(
Args:
code: The test code to transform.
func_name: Name of the function being tested.
qualified_name: Fully qualified function name.
function_to_optimize: The function being tested.
capture_func: The capture function to use ('capture' or 'capturePerf').
remove_assertions: If True, remove assertions entirely (for generated tests).
@ -619,10 +619,7 @@ def transform_expect_calls(
"""
transformer = ExpectCallTransformer(
func_name=func_name,
qualified_name=qualified_name,
capture_func=capture_func,
remove_assertions=remove_assertions,
function_to_optimize=function_to_optimize, capture_func=capture_func, remove_assertions=remove_assertions
)
result = transformer.transform(code)
return result, transformer.invocation_counter
@ -658,8 +655,6 @@ def inject_profiling_into_existing_js_test(
logger.error(f"Failed to read test file {test_path}: {e}")
return False, None
func_name = function_to_optimize.function_name
# Get the relative path for test identification
try:
rel_path = test_path.relative_to(tests_project_root)
@ -667,14 +662,12 @@ def inject_profiling_into_existing_js_test(
rel_path = test_path
# Check if the function is imported/required in this test file
if not _is_function_used_in_test(test_code, func_name):
logger.debug(f"Function '{func_name}' not found in test file {test_path}")
if not _is_function_used_in_test(test_code, function_to_optimize.function_name):
logger.debug(f"Function '{function_to_optimize.function_name}' not found in test file {test_path}")
return False, None
# Instrument the test code
instrumented_code = _instrument_js_test_code(
test_code, func_name, str(rel_path), mode, function_to_optimize.qualified_name
)
instrumented_code = _instrument_js_test_code(test_code, function_to_optimize, str(rel_path), mode)
if instrumented_code == test_code:
logger.debug(f"No changes made to test file {test_path}")
@ -716,16 +709,15 @@ def _is_function_used_in_test(code: str, func_name: str) -> bool:
def _instrument_js_test_code(
code: str, func_name: str, test_file_path: str, mode: str, qualified_name: str, remove_assertions: bool = False
code: str, function_to_optimize: FunctionToOptimize, test_file_path: str, mode: str, remove_assertions: bool = False
) -> str:
"""Instrument JavaScript test code with profiling capture calls.
Args:
code: Original test code.
func_name: Name of the function to instrument.
function_to_optimize: The function to instrument.
test_file_path: Relative path to test file.
mode: Testing mode (behavior or performance).
qualified_name: Fully qualified function name.
remove_assertions: If True, remove expect assertions entirely (for generated/regression tests).
If False, keep the expect wrapper (for existing user-written tests).
@ -771,8 +763,7 @@ def _instrument_js_test_code(
# Transform expect calls using the refactored transformer
code, expect_counter = transform_expect_calls(
code=code,
func_name=func_name,
qualified_name=qualified_name,
function_to_optimize=function_to_optimize,
capture_func=capture_func,
remove_assertions=remove_assertions,
)
@ -780,11 +771,7 @@ def _instrument_js_test_code(
# Transform standalone calls (not inside expect wrappers)
# Continue counter from expect transformer to ensure unique IDs
code, _final_counter = transform_standalone_calls(
code=code,
func_name=func_name,
qualified_name=qualified_name,
capture_func=capture_func,
start_counter=expect_counter,
code=code, function_to_optimize=function_to_optimize, capture_func=capture_func, start_counter=expect_counter
)
return code
@ -941,7 +928,7 @@ def get_instrumented_test_path(original_path: Path, mode: str) -> Path:
def instrument_generated_js_test(
test_code: str, function_name: str, qualified_name: str, mode: str = TestingMode.BEHAVIOR
test_code: str, function_to_optimize: FunctionToOptimize, mode: str = TestingMode.BEHAVIOR
) -> str:
"""Instrument generated JavaScript/TypeScript test code.
@ -956,8 +943,7 @@ def instrument_generated_js_test(
Args:
test_code: The generated test code to instrument.
function_name: Name of the function being tested.
qualified_name: Fully qualified function name (e.g., 'module.funcName').
function_to_optimize: The function being tested.
mode: Testing mode - "behavior" or "performance".
Returns:
@ -971,9 +957,8 @@ def instrument_generated_js_test(
# Generated tests are treated as regression tests, so we remove LLM-generated assertions
return _instrument_js_test_code(
code=test_code,
func_name=function_name,
function_to_optimize=function_to_optimize,
test_file_path="generated_test",
mode=mode,
qualified_name=qualified_name,
remove_assertions=True,
)

View file

@ -16,7 +16,7 @@ from codeflash.languages.treesitter_utils import get_analyzer_for_file
if TYPE_CHECKING:
from pathlib import Path
from codeflash.languages.base import FunctionInfo
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
logger = logging.getLogger(__name__)
@ -40,7 +40,7 @@ class JavaScriptLineProfiler:
self.output_file = output_file
self.profiler_var = "__codeflash_line_profiler__"
def instrument_source(self, source: str, file_path: Path, functions: list[FunctionInfo]) -> str:
def instrument_source(self, source: str, file_path: Path, functions: list[FunctionToOptimize]) -> str:
"""Instrument JavaScript source code with line profiling.
Adds profiling instrumentation to track line-level execution for the
@ -65,10 +65,10 @@ class JavaScriptLineProfiler:
lines = source.splitlines(keepends=True)
# Process functions in reverse order to preserve line numbers
for func in sorted(functions, key=lambda f: f.start_line, reverse=True):
for func in sorted(functions, key=lambda f: f.starting_line, reverse=True):
func_lines = self._instrument_function(func, lines, file_path)
start_idx = func.start_line - 1
end_idx = func.end_line
start_idx = func.starting_line - 1
end_idx = func.ending_line
lines = lines[:start_idx] + func_lines + lines[end_idx:]
instrumented_source = "".join(lines)
@ -171,7 +171,7 @@ const __codeflash_save_interval__ = setInterval(() => {self.profiler_var}.save()
if (__codeflash_save_interval__.unref) __codeflash_save_interval__.unref(); // Don't keep process alive
"""
def _instrument_function(self, func: FunctionInfo, lines: list[str], file_path: Path) -> list[str]:
def _instrument_function(self, func: FunctionToOptimize, lines: list[str], file_path: Path) -> list[str]:
"""Instrument a single function with line profiling.
Args:
@ -183,7 +183,7 @@ if (__codeflash_save_interval__.unref) __codeflash_save_interval__.unref(); // D
Instrumented function lines.
"""
func_lines = lines[func.start_line - 1 : func.end_line]
func_lines = lines[func.starting_line - 1 : func.ending_line]
instrumented_lines = []
# Parse the function to find executable lines
@ -194,7 +194,7 @@ if (__codeflash_save_interval__.unref) __codeflash_save_interval__.unref(); // D
tree = analyzer.parse(source.encode("utf8"))
executable_lines = self._find_executable_lines(tree.root_node, source.encode("utf8"))
except Exception as e:
logger.warning("Failed to parse function %s: %s", func.name, e)
logger.warning("Failed to parse function %s: %s", func.function_name, e)
return func_lines
# Add profiling to each executable line
@ -203,7 +203,7 @@ if (__codeflash_save_interval__.unref) __codeflash_save_interval__.unref(); // D
for local_idx, line in enumerate(func_lines):
local_line_num = local_idx + 1 # 1-indexed within function
global_line_num = func.start_line + local_idx # Global line number in original file
global_line_num = func.starting_line + local_idx # Global line number in original file
stripped = line.strip()
# Add enterFunction() call after the opening brace of the function

View file

@ -184,10 +184,18 @@ def _get_relative_import_path(target_path: Path, source_path: Path) -> str:
def add_js_extension(module_path: str) -> str:
"""Add .js extension to relative module paths for ESM compatibility."""
if module_path.startswith(("./", "../")):
if not module_path.endswith(".js") and not module_path.endswith(".mjs"):
return module_path + ".js"
"""Process module path for ESM compatibility.
NOTE: This function intentionally does NOT add extensions because:
1. TypeScript projects resolve modules without explicit extensions
2. Adding .js to .ts imports causes "Cannot find module" errors
3. Modern bundlers (webpack, vite, etc.) handle extension resolution automatically
The function name is preserved for backward compatibility but the behavior
has been changed to NOT add extensions.
"""
# Previously this function added .js extensions, but this caused module resolution
# errors in TypeScript projects. We now preserve paths without adding extensions.
return module_path

View file

@ -12,22 +12,16 @@ import xml.etree.ElementTree as ET
from pathlib import Path
from typing import TYPE_CHECKING, Any
from codeflash.languages.base import (
CodeContext,
FunctionFilterCriteria,
FunctionInfo,
HelperFunction,
Language,
ParentInfo,
TestInfo,
TestResult,
)
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.languages.base import CodeContext, FunctionFilterCriteria, HelperFunction, Language, TestInfo, TestResult
from codeflash.languages.registry import register_language
from codeflash.languages.treesitter_utils import TreeSitterAnalyzer, TreeSitterLanguage, get_analyzer_for_file
from codeflash.models.models import FunctionParent
if TYPE_CHECKING:
from collections.abc import Sequence
from codeflash.languages.base import ReferenceInfo
from codeflash.languages.treesitter_utils import TypeDefinition
logger = logging.getLogger(__name__)
@ -53,6 +47,11 @@ class JavaScriptSupport:
"""File extensions supported by JavaScript."""
return (".js", ".jsx", ".mjs", ".cjs")
@property
def default_file_extension(self) -> str:
"""Default file extension for JavaScript."""
return ".js"
@property
def test_framework(self) -> str:
"""Primary test framework for JavaScript."""
@ -66,7 +65,7 @@ class JavaScriptSupport:
def discover_functions(
self, file_path: Path, filter_criteria: FunctionFilterCriteria | None = None
) -> list[FunctionInfo]:
) -> list[FunctionToOptimize]:
"""Find all optimizable functions in a JavaScript file.
Uses tree-sitter to parse the file and find functions.
@ -76,7 +75,7 @@ class JavaScriptSupport:
filter_criteria: Optional criteria to filter functions.
Returns:
List of FunctionInfo objects for discovered functions.
List of FunctionToOptimize objects for discovered functions.
"""
criteria = filter_criteria or FunctionFilterCriteria()
@ -93,7 +92,7 @@ class JavaScriptSupport:
source, include_methods=criteria.include_methods, include_arrow_functions=True, require_name=True
)
functions: list[FunctionInfo] = []
functions: list[FunctionToOptimize] = []
for func in tree_functions:
# Check for return statement if required
if criteria.require_return and not analyzer.has_return_statement(func, source):
@ -104,24 +103,24 @@ class JavaScriptSupport:
continue
# Build parents list
parents: list[ParentInfo] = []
parents: list[FunctionParent] = []
if func.class_name:
parents.append(ParentInfo(name=func.class_name, type="ClassDef"))
parents.append(FunctionParent(name=func.class_name, type="ClassDef"))
if func.parent_function:
parents.append(ParentInfo(name=func.parent_function, type="FunctionDef"))
parents.append(FunctionParent(name=func.parent_function, type="FunctionDef"))
functions.append(
FunctionInfo(
name=func.name,
FunctionToOptimize(
function_name=func.name,
file_path=file_path,
start_line=func.start_line,
end_line=func.end_line,
start_col=func.start_col,
end_col=func.end_col,
parents=tuple(parents),
parents=parents,
starting_line=func.start_line,
ending_line=func.end_line,
starting_col=func.start_col,
ending_col=func.end_col,
is_async=func.is_async,
is_method=func.is_method,
language=self.language,
language=str(self.language),
doc_start_line=func.doc_start_line,
)
)
@ -132,7 +131,7 @@ class JavaScriptSupport:
logger.warning("Failed to parse %s: %s", file_path, e)
return []
def discover_functions_from_source(self, source: str, file_path: Path | None = None) -> list[FunctionInfo]:
def discover_functions_from_source(self, source: str, file_path: Path | None = None) -> list[FunctionToOptimize]:
"""Find all functions in source code string.
Uses tree-sitter to parse the source and find functions.
@ -142,7 +141,7 @@ class JavaScriptSupport:
file_path: Optional file path for context (used for language detection).
Returns:
List of FunctionInfo objects for discovered functions.
List of FunctionToOptimize objects for discovered functions.
"""
try:
@ -156,27 +155,27 @@ class JavaScriptSupport:
source, include_methods=True, include_arrow_functions=True, require_name=True
)
functions: list[FunctionInfo] = []
functions: list[FunctionToOptimize] = []
for func in tree_functions:
# Build parents list
parents: list[ParentInfo] = []
parents: list[FunctionParent] = []
if func.class_name:
parents.append(ParentInfo(name=func.class_name, type="ClassDef"))
parents.append(FunctionParent(name=func.class_name, type="ClassDef"))
if func.parent_function:
parents.append(ParentInfo(name=func.parent_function, type="FunctionDef"))
parents.append(FunctionParent(name=func.parent_function, type="FunctionDef"))
functions.append(
FunctionInfo(
name=func.name,
FunctionToOptimize(
function_name=func.name,
file_path=file_path or Path("unknown"),
start_line=func.start_line,
end_line=func.end_line,
start_col=func.start_col,
end_col=func.end_col,
parents=tuple(parents),
parents=parents,
starting_line=func.start_line,
ending_line=func.end_line,
starting_col=func.start_col,
ending_col=func.end_col,
is_async=func.is_async,
is_method=func.is_method,
language=self.language,
language=str(self.language),
doc_start_line=func.doc_start_line,
)
)
@ -198,7 +197,9 @@ class JavaScriptSupport:
"""
return ["*.test.js", "*.test.jsx", "*.spec.js", "*.spec.jsx", "__tests__/**/*.js", "__tests__/**/*.jsx"]
def discover_tests(self, test_root: Path, source_functions: Sequence[FunctionInfo]) -> dict[str, list[TestInfo]]:
def discover_tests(
self, test_root: Path, source_functions: Sequence[FunctionToOptimize]
) -> dict[str, list[TestInfo]]:
"""Map source functions to their tests via static analysis.
For JavaScript, this uses static analysis to find test files
@ -240,7 +241,7 @@ class JavaScriptSupport:
# Match source functions to tests
for func in source_functions:
if func.name in imported_names or func.name in source:
if func.function_name in imported_names or func.function_name in source:
if func.qualified_name not in result:
result[func.qualified_name] = []
for test_name in test_functions:
@ -282,7 +283,7 @@ class JavaScriptSupport:
# === Code Analysis ===
def extract_code_context(self, function: FunctionInfo, project_root: Path, module_root: Path) -> CodeContext:
def extract_code_context(self, function: FunctionToOptimize, project_root: Path, module_root: Path) -> CodeContext:
"""Extract function code and its dependencies.
Uses tree-sitter to analyze imports and find helper functions.
@ -309,16 +310,16 @@ class JavaScriptSupport:
tree_functions = analyzer.find_functions(source, include_methods=True, include_arrow_functions=True)
target_func = None
for func in tree_functions:
if func.name == function.name and func.start_line == function.start_line:
if func.name == function.function_name and func.start_line == function.starting_line:
target_func = func
break
# Extract the function source, including JSDoc if present
lines = source.splitlines(keepends=True)
if function.start_line and function.end_line:
if function.starting_line and function.ending_line:
# Use doc_start_line if available, otherwise fall back to start_line
effective_start = (target_func.doc_start_line if target_func else None) or function.start_line
target_lines = lines[effective_start - 1 : function.end_line]
effective_start = (target_func.doc_start_line if target_func else None) or function.starting_line
target_lines = lines[effective_start - 1 : function.ending_line]
target_code = "".join(target_lines)
else:
target_code = ""
@ -334,7 +335,7 @@ class JavaScriptSupport:
if class_name:
# Find the class definition in the source to get proper indentation, JSDoc, constructor, and fields
class_info = self._find_class_definition(source, class_name, analyzer, function.name)
class_info = self._find_class_definition(source, class_name, analyzer, function.function_name)
if class_info:
class_jsdoc, class_indent, constructor_code, fields_code = class_info
# Build the class body with fields, constructor, and target method
@ -395,7 +396,7 @@ class JavaScriptSupport:
# If not, raise an error to fail the optimization early
if target_code and not self.validate_syntax(target_code):
error_msg = (
f"Extracted code for {function.name} is not syntactically valid JavaScript. "
f"Extracted code for {function.function_name} is not syntactically valid JavaScript. "
f"Cannot proceed with optimization."
)
logger.error(error_msg)
@ -544,7 +545,12 @@ class JavaScriptSupport:
return (constructor_code, fields_code)
def _find_helper_functions(
self, function: FunctionInfo, source: str, analyzer: TreeSitterAnalyzer, imports: list[Any], module_root: Path
self,
function: FunctionToOptimize,
source: str,
analyzer: TreeSitterAnalyzer,
imports: list[Any],
module_root: Path,
) -> list[HelperFunction]:
"""Find helper functions called by the target function.
@ -569,7 +575,7 @@ class JavaScriptSupport:
# Find the target function's tree-sitter node
target_func = None
for func in all_functions:
if func.name == function.name and func.start_line == function.start_line:
if func.name == function.function_name and func.start_line == function.starting_line:
target_func = func
break
@ -585,7 +591,7 @@ class JavaScriptSupport:
# Match calls to functions in the same file
for func in all_functions:
if func.name in calls_set and func.name != function.name:
if func.name in calls_set and func.name != function.function_name:
# Extract source including JSDoc if present
effective_start = func.doc_start_line or func.start_line
helper_lines = lines[effective_start - 1 : func.end_line]
@ -715,7 +721,12 @@ class JavaScriptSupport:
return "\n".join(global_lines)
def _extract_type_definitions_context(
self, function: FunctionInfo, source: str, analyzer: TreeSitterAnalyzer, imports: list[Any], module_root: Path
self,
function: FunctionToOptimize,
source: str,
analyzer: TreeSitterAnalyzer,
imports: list[Any],
module_root: Path,
) -> tuple[str, set[str]]:
"""Extract type definitions used by the function for read-only context.
@ -741,7 +752,7 @@ class JavaScriptSupport:
"""
# Extract type names from function parameters and return type
type_names = analyzer.extract_type_annotations(source, function.name, function.start_line or 1)
type_names = analyzer.extract_type_annotations(source, function.function_name, function.starting_line or 1)
# If this is a class method, also extract types from class fields
if function.is_method and function.parents:
@ -939,7 +950,7 @@ class JavaScriptSupport:
return found_definitions
def find_helper_functions(self, function: FunctionInfo, project_root: Path) -> list[HelperFunction]:
def find_helper_functions(self, function: FunctionToOptimize, project_root: Path) -> list[HelperFunction]:
"""Find helper functions called by the target function.
Args:
@ -956,12 +967,68 @@ class JavaScriptSupport:
imports = analyzer.find_imports(source)
return self._find_helper_functions(function, source, analyzer, imports, project_root)
except Exception as e:
logger.warning("Failed to find helpers for %s: %s", function.name, e)
logger.warning("Failed to find helpers for %s: %s", function.function_name, e)
return []
def find_references(
self, function: FunctionToOptimize, project_root: Path, tests_root: Path | None = None, max_files: int = 500
) -> list[ReferenceInfo]:
"""Find all references (call sites) to a function across the codebase.
Uses tree-sitter to find all places where a JavaScript/TypeScript function
is called, including direct calls, callbacks, memoized versions, and re-exports.
Args:
function: The function to find references for.
project_root: Root of the project to search.
tests_root: Root of tests directory (references in tests are excluded).
max_files: Maximum number of files to search.
Returns:
List of ReferenceInfo objects describing each reference location.
"""
from codeflash.languages.base import ReferenceInfo
from codeflash.languages.javascript.find_references import ReferenceFinder
try:
finder = ReferenceFinder(project_root)
refs = finder.find_references(function, max_files=max_files)
# Convert to ReferenceInfo and filter out tests
result: list[ReferenceInfo] = []
for ref in refs:
# Exclude test files if tests_root is provided
if tests_root:
try:
ref.file_path.relative_to(tests_root)
continue # Skip if in tests_root
except ValueError:
pass # Not in tests_root, include it
result.append(
ReferenceInfo(
file_path=ref.file_path,
line=ref.line,
column=ref.column,
end_line=ref.end_line,
end_column=ref.end_column,
context=ref.context,
reference_type=ref.reference_type,
import_name=ref.import_name,
caller_function=ref.caller_function,
)
)
return result
except Exception as e:
logger.warning("Failed to find references for %s: %s", function.function_name, e)
return []
# === Code Transformation ===
def replace_function(self, source: str, function: FunctionInfo, new_source: str) -> str:
def replace_function(self, source: str, function: FunctionToOptimize, new_source: str) -> str:
"""Replace a function in source code with new implementation.
Uses node-based replacement to extract the method body from the optimized code
@ -973,7 +1040,7 @@ class JavaScriptSupport:
Args:
source: Original source code.
function: FunctionInfo identifying the function to replace.
function: FunctionToOptimize identifying the function to replace.
new_source: New source code containing the optimized function.
Returns:
@ -981,13 +1048,13 @@ class JavaScriptSupport:
if new_source is empty or invalid.
"""
if function.start_line is None or function.end_line is None:
logger.error("Function %s has no line information", function.name)
if function.starting_line is None or function.ending_line is None:
logger.error("Function %s has no line information", function.function_name)
return source
# If new_source is empty or whitespace-only, return original unchanged
if not new_source or not new_source.strip():
logger.warning("Empty new_source provided for %s, returning original", function.name)
logger.warning("Empty new_source provided for %s, returning original", function.function_name)
return source
# Get analyzer for parsing
@ -1001,19 +1068,21 @@ class JavaScriptSupport:
stripped_new_source = new_source.strip()
if stripped_new_source.startswith("/**"):
# new_source includes JSDoc, use full replacement to apply the new JSDoc
if not self._contains_function_declaration(new_source, function.name, analyzer):
logger.warning("new_source does not contain function %s, returning original", function.name)
if not self._contains_function_declaration(new_source, function.function_name, analyzer):
logger.warning("new_source does not contain function %s, returning original", function.function_name)
return source
return self._replace_function_text_based(source, function, new_source, analyzer)
# Extract just the method body from the new source
new_body = self._extract_function_body(new_source, function.name, analyzer)
new_body = self._extract_function_body(new_source, function.function_name, analyzer)
if new_body is None:
logger.warning("Could not extract body for %s from optimized code, using full replacement", function.name)
logger.warning(
"Could not extract body for %s from optimized code, using full replacement", function.function_name
)
# Verify that new_source contains actual code before falling back to text replacement
# This prevents deletion of the original function when new_source is invalid
if not self._contains_function_declaration(new_source, function.name, analyzer):
logger.warning("new_source does not contain function %s, returning original", function.name)
if not self._contains_function_declaration(new_source, function.function_name, analyzer):
logger.warning("new_source does not contain function %s, returning original", function.function_name)
return source
return self._replace_function_text_based(source, function, new_source, analyzer)
@ -1141,7 +1210,7 @@ class JavaScriptSupport:
return source_bytes[body_node.start_byte : body_node.end_byte].decode("utf8")
def _replace_function_body(
self, source: str, function: FunctionInfo, new_body: str, analyzer: TreeSitterAnalyzer
self, source: str, function: FunctionToOptimize, new_body: str, analyzer: TreeSitterAnalyzer
) -> str:
"""Replace the body of a function in source code with new body content.
@ -1149,7 +1218,7 @@ class JavaScriptSupport:
Args:
source: Original source code.
function: FunctionInfo identifying the function to modify.
function: FunctionToOptimize identifying the function to modify.
new_body: New body content (including braces).
analyzer: TreeSitterAnalyzer for parsing.
@ -1200,9 +1269,9 @@ class JavaScriptSupport:
return None
func_node = find_function_at_line(tree.root_node, function.name, function.start_line)
func_node = find_function_at_line(tree.root_node, function.function_name, function.starting_line)
if not func_node:
logger.warning("Could not find function %s at line %s", function.name, function.start_line)
logger.warning("Could not find function %s at line %s", function.function_name, function.starting_line)
return source
# Find the body node in the original
@ -1214,7 +1283,7 @@ class JavaScriptSupport:
break
if not body_node:
logger.warning("Could not find body for function %s", function.name)
logger.warning("Could not find body for function %s", function.function_name)
return source
# Get the indentation of the original body's opening brace
@ -1282,7 +1351,7 @@ class JavaScriptSupport:
return result.decode("utf8")
def _replace_function_text_based(
self, source: str, function: FunctionInfo, new_source: str, analyzer: TreeSitterAnalyzer
self, source: str, function: FunctionToOptimize, new_source: str, analyzer: TreeSitterAnalyzer
) -> str:
"""Fallback text-based replacement when node-based replacement fails.
@ -1290,7 +1359,7 @@ class JavaScriptSupport:
Args:
source: Original source code.
function: FunctionInfo identifying the function to replace.
function: FunctionToOptimize identifying the function to replace.
new_source: New function source code.
analyzer: TreeSitterAnalyzer for parsing.
@ -1307,16 +1376,16 @@ class JavaScriptSupport:
tree_functions = analyzer.find_functions(source, include_methods=True, include_arrow_functions=True)
target_func = None
for func in tree_functions:
if func.name == function.name and func.start_line == function.start_line:
if func.name == function.function_name and func.start_line == function.starting_line:
target_func = func
break
# Use doc_start_line if available, otherwise fall back to start_line
effective_start = (target_func.doc_start_line if target_func else None) or function.start_line
effective_start = (target_func.doc_start_line if target_func else None) or function.starting_line
# Get indentation from original function's first line
if function.start_line <= len(lines):
original_first_line = lines[function.start_line - 1]
if function.starting_line <= len(lines):
original_first_line = lines[function.starting_line - 1]
original_indent = len(original_first_line) - len(original_first_line.lstrip())
else:
original_indent = 0
@ -1359,7 +1428,7 @@ class JavaScriptSupport:
# Build result
before = lines[: effective_start - 1]
after = lines[function.end_line :]
after = lines[function.ending_line :]
result_lines = before + new_lines + after
return "".join(result_lines)
@ -1510,7 +1579,7 @@ class JavaScriptSupport:
# === Instrumentation ===
def instrument_for_behavior(
self, source: str, functions: Sequence[FunctionInfo], output_file: Path | None = None
self, source: str, functions: Sequence[FunctionToOptimize], output_file: Path | None = None
) -> str:
"""Add behavior instrumentation to capture inputs/outputs.
@ -1539,7 +1608,7 @@ class JavaScriptSupport:
tracer = JavaScriptTracer(output_file)
return tracer.instrument_source(source, functions[0].file_path, list(functions))
def instrument_for_benchmarking(self, test_source: str, target_function: FunctionInfo) -> str:
def instrument_for_benchmarking(self, test_source: str, target_function: FunctionToOptimize) -> str:
"""Add timing instrumentation to test code.
For JavaScript/Jest, we can use Jest's built-in timing or add custom timing.
@ -1769,51 +1838,98 @@ class JavaScriptSupport:
rel_path = source_file.relative_to(project_root)
return "../" + rel_path.with_suffix("").as_posix()
def verify_requirements(self, project_root: Path, test_framework: str = "jest") -> tuple[bool, list[str]]:
"""Verify that all JavaScript requirements are met.
Checks for:
1. Node.js installation
2. npm availability
3. Test framework (jest/vitest) installation
4. node_modules existence
Args:
project_root: The project root directory.
test_framework: The test framework to check for ("jest" or "vitest").
Returns:
Tuple of (success, list of error messages).
"""
errors: list[str] = []
# Check Node.js
try:
result = subprocess.run(["node", "--version"], check=False, capture_output=True, text=True, timeout=10)
if result.returncode != 0:
errors.append("Node.js is not installed. Please install Node.js 18+ from https://nodejs.org/")
except FileNotFoundError:
errors.append("Node.js is not installed. Please install Node.js 18+ from https://nodejs.org/")
except Exception as e:
errors.append(f"Failed to check Node.js: {e}")
# Check npm
try:
result = subprocess.run(["npm", "--version"], check=False, capture_output=True, text=True, timeout=10)
if result.returncode != 0:
errors.append("npm is not available. Please ensure npm is installed with Node.js.")
except FileNotFoundError:
errors.append("npm is not available. Please ensure npm is installed with Node.js.")
except Exception as e:
errors.append(f"Failed to check npm: {e}")
# Check node_modules exists
node_modules = project_root / "node_modules"
if not node_modules.exists():
errors.append(
f"node_modules not found in {project_root}. Please run 'npm install' to install dependencies."
)
else:
# Check test framework is installed
framework_path = node_modules / test_framework
if not framework_path.exists():
errors.append(
f"{test_framework} is not installed. "
f"Please run 'npm install --save-dev {test_framework}' to install it."
)
return len(errors) == 0, errors
def ensure_runtime_environment(self, project_root: Path) -> bool:
"""Ensure codeflash npm package is installed.
Attempts to install the npm package for test instrumentation.
Falls back to copying files if npm install fails.
Args:
project_root: The project root directory.
Returns:
True if npm package is installed, False if falling back to file copy.
True if npm package is installed, False otherwise.
"""
import subprocess
from codeflash.cli_cmds.console import logger
# Check if package is already installed
node_modules_pkg = project_root / "node_modules" / "codeflash"
if node_modules_pkg.exists():
logger.debug("codeflash already installed")
return True
# Try to install from local package first (for development)
local_package_path = Path(__file__).parent.parent.parent.parent / "packages" / "cli"
if local_package_path.exists():
try:
result = subprocess.run(
["npm", "install", "--save-dev", str(local_package_path)],
check=False,
cwd=project_root,
capture_output=True,
text=True,
timeout=120,
)
if result.returncode == 0:
logger.debug("Installed codeflash from local package")
return True
logger.warning(f"Failed to install local package: {result.stderr}")
except Exception as e:
logger.warning(f"Error installing local package: {e}")
# Could try npm registry here in the future:
# subprocess.run(["npm", "install", "--save-dev", "codeflash"], ...)
try:
result = subprocess.run(
["npm", "install", "--save-dev", "codeflash"],
check=False,
cwd=project_root,
capture_output=True,
text=True,
timeout=120,
)
if result.returncode == 0:
logger.debug("Installed codeflash from npm registry")
return True
logger.warning(f"Failed to install codeflash: {result.stderr}")
except Exception as e:
logger.warning(f"Error installing codeflash: {e}")
logger.error("Could not install codeflash. Please run: npm install --save-dev codeflash")
return False
def instrument_existing_test(
@ -1853,7 +1969,7 @@ class JavaScriptSupport:
def instrument_source_for_line_profiler(
# TODO: use the context to instrument helper files also
self,
func_info: FunctionInfo,
func_info: FunctionToOptimize,
line_profiler_output_file: Path,
) -> bool:
from codeflash.languages.javascript.line_profiler import JavaScriptLineProfiler
@ -1925,8 +2041,9 @@ class JavaScriptSupport:
project_root: Path | None = None,
enable_coverage: bool = False,
candidate_index: int = 0,
test_framework: str | None = None,
) -> tuple[Path, Any, Path | None, Path | None]:
"""Run Jest behavioral tests.
"""Run behavioral tests using the detected test framework.
Args:
test_paths: TestFiles object containing test file information.
@ -1936,11 +2053,29 @@ class JavaScriptSupport:
project_root: Project root directory.
enable_coverage: Whether to collect coverage information.
candidate_index: Index of the candidate being tested.
test_framework: Test framework to use ("jest" or "vitest"). If None, uses singleton.
Returns:
Tuple of (result_file_path, subprocess_result, coverage_path, config_path).
"""
from codeflash.languages.test_framework import get_js_test_framework_or_default
framework = test_framework or get_js_test_framework_or_default()
if framework == "vitest":
from codeflash.languages.javascript.vitest_runner import run_vitest_behavioral_tests
return run_vitest_behavioral_tests(
test_paths=test_paths,
test_env=test_env,
cwd=cwd,
timeout=timeout,
project_root=project_root,
enable_coverage=enable_coverage,
candidate_index=candidate_index,
)
from codeflash.languages.javascript.test_runner import run_jest_behavioral_tests
return run_jest_behavioral_tests(
@ -1963,8 +2098,9 @@ class JavaScriptSupport:
min_loops: int = 5,
max_loops: int = 100_000,
target_duration_seconds: float = 10.0,
test_framework: str | None = None,
) -> tuple[Path, Any]:
"""Run Jest benchmarking tests.
"""Run benchmarking tests using the detected test framework.
Args:
test_paths: TestFiles object containing test file information.
@ -1975,11 +2111,30 @@ class JavaScriptSupport:
min_loops: Minimum number of loops for benchmarking.
max_loops: Maximum number of loops for benchmarking.
target_duration_seconds: Target duration for benchmarking in seconds.
test_framework: Test framework to use ("jest" or "vitest"). If None, uses singleton.
Returns:
Tuple of (result_file_path, subprocess_result).
"""
from codeflash.languages.test_framework import get_js_test_framework_or_default
framework = test_framework or get_js_test_framework_or_default()
if framework == "vitest":
from codeflash.languages.javascript.vitest_runner import run_vitest_benchmarking_tests
return run_vitest_benchmarking_tests(
test_paths=test_paths,
test_env=test_env,
cwd=cwd,
timeout=timeout,
project_root=project_root,
min_loops=min_loops,
max_loops=max_loops,
target_duration_ms=int(target_duration_seconds * 1000),
)
from codeflash.languages.javascript.test_runner import run_jest_benchmarking_tests
return run_jest_benchmarking_tests(
@ -2001,8 +2156,9 @@ class JavaScriptSupport:
timeout: int | None = None,
project_root: Path | None = None,
line_profile_output_file: Path | None = None,
test_framework: str | None = None,
) -> tuple[Path, Any]:
"""Run Jest tests for line profiling.
"""Run tests for line profiling using the detected test framework.
Args:
test_paths: TestFiles object containing test file information.
@ -2011,11 +2167,28 @@ class JavaScriptSupport:
timeout: Optional timeout in seconds.
project_root: Project root directory.
line_profile_output_file: Path where line profile results will be written.
test_framework: Test framework to use ("jest" or "vitest"). If None, uses singleton.
Returns:
Tuple of (result_file_path, subprocess_result).
"""
from codeflash.languages.test_framework import get_js_test_framework_or_default
framework = test_framework or get_js_test_framework_or_default()
if framework == "vitest":
from codeflash.languages.javascript.vitest_runner import run_vitest_line_profile_tests
return run_vitest_line_profile_tests(
test_paths=test_paths,
test_env=test_env,
cwd=cwd,
timeout=timeout,
project_root=project_root,
line_profile_output_file=line_profile_output_file,
)
from codeflash.languages.javascript.test_runner import run_jest_line_profile_tests
return run_jest_line_profile_tests(

View file

@ -13,6 +13,7 @@ from pathlib import Path
from typing import TYPE_CHECKING
from codeflash.cli_cmds.console import logger
from codeflash.cli_cmds.init_javascript import get_package_install_command
from codeflash.code_utils.code_utils import get_run_tmp_file
from codeflash.code_utils.config_consts import STABILITY_CENTER_TOLERANCE, STABILITY_SPREAD_TOLERANCE
from codeflash.code_utils.shell_utils import get_cross_platform_subprocess_run_args
@ -21,6 +22,254 @@ if TYPE_CHECKING:
from codeflash.models.models import TestFiles
def _detect_bundler_module_resolution(project_root: Path) -> bool:
"""Detect if the project uses moduleResolution: 'bundler' in tsconfig.
TypeScript 5+ supports 'bundler' moduleResolution which requires
module: 'preserve' or ES2015+. This can cause issues with ts-jest
in some configurations.
This function also resolves extended tsconfigs to find bundler setting
in parent configs.
Args:
project_root: Root of the project to check.
Returns:
True if the project uses bundler moduleResolution.
"""
tsconfig_path = project_root / "tsconfig.json"
if not tsconfig_path.exists():
return False
visited_configs: set[Path] = set()
def check_tsconfig(config_path: Path) -> bool:
"""Recursively check tsconfig and its extends for bundler moduleResolution."""
if config_path in visited_configs:
return False
visited_configs.add(config_path)
if not config_path.exists():
return False
try:
content = config_path.read_text()
tsconfig = json.loads(content)
# Check direct moduleResolution setting
compiler_options = tsconfig.get("compilerOptions", {})
module_resolution = compiler_options.get("moduleResolution", "").lower()
if module_resolution == "bundler":
return True
# Check extended config if present
extends = tsconfig.get("extends")
if extends:
# Resolve the extended config path
if extends.startswith("."):
# Relative path
extended_path = (config_path.parent / extends).resolve()
if not extended_path.suffix:
extended_path = extended_path.with_suffix(".json")
else:
# Package reference (e.g., "@n8n/typescript-config/modern/tsconfig.json")
# Try to find it in node_modules
node_modules_path = project_root / "node_modules" / extends
if not node_modules_path.suffix:
node_modules_path = node_modules_path.with_suffix(".json")
if node_modules_path.exists():
extended_path = node_modules_path
else:
# Try parent directories for monorepo support
current = project_root.parent
extended_path = None
while current != current.parent:
candidate = current / "node_modules" / extends
if not candidate.suffix:
candidate = candidate.with_suffix(".json")
if candidate.exists():
extended_path = candidate
break
# Also check packages directory for workspace packages
packages_candidate = current / "packages" / extends
if not packages_candidate.suffix:
packages_candidate = packages_candidate.with_suffix(".json")
if packages_candidate.exists():
extended_path = packages_candidate
break
current = current.parent
if extended_path and extended_path.exists():
return check_tsconfig(extended_path)
return False
except Exception as e:
logger.debug(f"Failed to read {config_path}: {e}")
return False
return check_tsconfig(tsconfig_path)
def _create_codeflash_tsconfig(project_root: Path) -> Path:
"""Create a codeflash-compatible tsconfig for projects using bundler moduleResolution.
This creates a tsconfig that inherits from the project's tsconfig but overrides
moduleResolution to 'Node' for compatibility with ts-jest.
Args:
project_root: Root of the project.
Returns:
Path to the created tsconfig.codeflash.json file.
"""
codeflash_tsconfig_path = project_root / "tsconfig.codeflash.json"
# If it already exists, use it
if codeflash_tsconfig_path.exists():
logger.debug(f"Using existing {codeflash_tsconfig_path}")
return codeflash_tsconfig_path
# Read the original tsconfig to preserve most settings
original_tsconfig_path = project_root / "tsconfig.json"
try:
original_content = original_tsconfig_path.read_text()
original_tsconfig = json.loads(original_content)
except Exception:
original_tsconfig = {}
# Create a new tsconfig that extends the original but fixes moduleResolution
codeflash_tsconfig = {
"extends": "./tsconfig.json",
"compilerOptions": {
# Override bundler to Node for ts-jest compatibility
"moduleResolution": "Node",
# Ensure module is set to a compatible value
"module": "ESNext",
# These are generally safe defaults for testing
"esModuleInterop": True,
"skipLibCheck": True,
"isolatedModules": True,
},
}
# Preserve include/exclude from original if not in extends
if "include" in original_tsconfig:
codeflash_tsconfig["include"] = original_tsconfig["include"]
if "exclude" in original_tsconfig:
codeflash_tsconfig["exclude"] = original_tsconfig["exclude"]
try:
codeflash_tsconfig_path.write_text(json.dumps(codeflash_tsconfig, indent=2))
logger.debug(f"Created {codeflash_tsconfig_path} with Node moduleResolution")
except Exception as e:
logger.warning(f"Failed to create codeflash tsconfig: {e}")
return codeflash_tsconfig_path
def _create_codeflash_jest_config(project_root: Path, original_jest_config: Path | None) -> Path | None:
"""Create a Jest config that uses the codeflash tsconfig for ts-jest.
Args:
project_root: Root of the project.
original_jest_config: Path to the original Jest config, or None.
Returns:
Path to the codeflash Jest config, or None if creation failed.
"""
codeflash_jest_config_path = project_root / "jest.codeflash.config.js"
# If it already exists, use it
if codeflash_jest_config_path.exists():
logger.debug(f"Using existing {codeflash_jest_config_path}")
return codeflash_jest_config_path
# Create a wrapper Jest config that uses tsconfig.codeflash.json
if original_jest_config:
# Extend the original config
jest_config_content = f"""// Auto-generated by codeflash for bundler moduleResolution compatibility
const originalConfig = require('./{original_jest_config.name}');
const tsJestOptions = {{
isolatedModules: true,
tsconfig: 'tsconfig.codeflash.json',
}};
module.exports = {{
...originalConfig,
transform: {{
...originalConfig.transform,
'^.+\\\\.tsx?$': ['ts-jest', tsJestOptions],
}},
globals: {{
...originalConfig.globals,
'ts-jest': tsJestOptions,
}},
}};
"""
else:
# Create a minimal Jest config for TypeScript
jest_config_content = """// Auto-generated by codeflash for bundler moduleResolution compatibility
const tsJestOptions = {
isolatedModules: true,
tsconfig: 'tsconfig.codeflash.json',
};
module.exports = {
verbose: true,
testEnvironment: 'node',
testRegex: '\\\\.(test|spec)\\\\.(js|ts|tsx)$',
testPathIgnorePatterns: ['/dist/', '/node_modules/'],
transform: {
'^.+\\\\.tsx?$': ['ts-jest', tsJestOptions],
},
moduleFileExtensions: ['ts', 'tsx', 'js', 'jsx', 'json', 'node'],
};
"""
try:
codeflash_jest_config_path.write_text(jest_config_content)
logger.debug(f"Created {codeflash_jest_config_path} with codeflash tsconfig")
return codeflash_jest_config_path
except Exception as e:
logger.warning(f"Failed to create codeflash Jest config: {e}")
return None
def _get_jest_config_for_project(project_root: Path) -> Path | None:
"""Get the appropriate Jest config for the project.
If the project uses bundler moduleResolution, creates and returns a
codeflash-compatible Jest config. Otherwise, returns the project's
existing Jest config.
Args:
project_root: Root of the project.
Returns:
Path to the Jest config to use, or None if not found.
"""
# First check for existing Jest config
original_jest_config = _find_jest_config(project_root)
# Check if project uses bundler moduleResolution
if _detect_bundler_module_resolution(project_root):
logger.info("Detected bundler moduleResolution - creating compatible config")
# Create codeflash-compatible tsconfig
_create_codeflash_tsconfig(project_root)
# Create codeflash Jest config that uses it
codeflash_jest_config = _create_codeflash_jest_config(project_root, original_jest_config)
if codeflash_jest_config:
return codeflash_jest_config
return original_jest_config
def _find_node_project_root(file_path: Path) -> Path | None:
"""Find the Node.js project root by looking for package.json.
@ -47,6 +296,67 @@ def _find_node_project_root(file_path: Path) -> Path | None:
return None
def _find_jest_config(project_root: Path) -> Path | None:
"""Find Jest configuration file in the project.
Searches for common Jest config file names in the project root and parent
directories (for monorepo support). This is important for TypeScript projects
that require specific transformation configurations (e.g., next/jest, ts-jest, babel-jest).
Args:
project_root: Root of the project to search.
Returns:
Path to Jest config file, or None if not found.
"""
# Common Jest config file names, in order of preference
config_names = ["jest.config.ts", "jest.config.js", "jest.config.mjs", "jest.config.cjs", "jest.config.json"]
# First check the project root itself
for config_name in config_names:
config_path = project_root / config_name
if config_path.exists():
logger.debug(f"Found Jest config: {config_path}")
return config_path
# For monorepos, search parent directories up to the filesystem root
# Stop at common monorepo root indicators (git root, package.json with workspaces)
current = project_root.parent
max_depth = 5 # Don't search too far up
depth = 0
while current != current.parent and depth < max_depth:
for config_name in config_names:
config_path = current / config_name
if config_path.exists():
logger.debug(f"Found Jest config in parent directory: {config_path}")
return config_path
# Check if this looks like a monorepo root
package_json = current / "package.json"
if package_json.exists():
try:
import json
with package_json.open("r") as f:
pkg = json.load(f)
if "workspaces" in pkg:
# This is likely the monorepo root, stop here
break
except Exception:
pass
# Check for git root as another stopping point
if (current / ".git").exists():
break
current = current.parent
depth += 1
return None
def _is_esm_project(project_root: Path) -> bool:
"""Check if the project uses ES Modules.
@ -145,54 +455,28 @@ def _ensure_runtime_files(project_root: Path) -> None:
Installs codeflash package if not already present.
The package provides all runtime files needed for test instrumentation.
Uses the project's detected package manager (npm, pnpm, yarn, or bun).
Args:
project_root: The project root directory.
"""
# Check if package is already installed
node_modules_pkg = project_root / "node_modules" / "codeflash"
if node_modules_pkg.exists():
logger.debug("codeflash already installed")
return
# Try to install from local package first (for development)
local_package_path = Path(__file__).parent.parent.parent.parent / "packages" / "codeflash"
if local_package_path.exists():
try:
result = subprocess.run(
["npm", "install", "--save-dev", str(local_package_path)],
check=False,
cwd=project_root,
capture_output=True,
text=True,
timeout=120,
)
if result.returncode == 0:
logger.debug("Installed codeflash from local package")
return
logger.warning(f"Failed to install local package: {result.stderr}")
except Exception as e:
logger.warning(f"Error installing local package: {e}")
# Try to install from npm registry
install_cmd = get_package_install_command(project_root, "codeflash", dev=True)
try:
result = subprocess.run(
["npm", "install", "--save-dev", "codeflash"],
check=False,
cwd=project_root,
capture_output=True,
text=True,
timeout=120,
)
result = subprocess.run(install_cmd, check=False, cwd=project_root, capture_output=True, text=True, timeout=120)
if result.returncode == 0:
logger.debug("Installed codeflash from npm registry")
logger.debug(f"Installed codeflash using {install_cmd[0]}")
return
logger.warning(f"Failed to install from npm: {result.stderr}")
logger.warning(f"Failed to install codeflash: {result.stderr}")
except Exception as e:
logger.warning(f"Error installing from npm: {e}")
logger.warning(f"Error installing codeflash: {e}")
logger.error("Could not install codeflash. Please install it manually: npm install --save-dev codeflash")
logger.error(f"Could not install codeflash. Please install it manually: {' '.join(install_cmd)}")
def run_jest_behavioral_tests(
@ -251,13 +535,26 @@ def run_jest_behavioral_tests(
"--forceExit",
]
# Add Jest config if found - needed for TypeScript transformation
# Uses codeflash-compatible config if project has bundler moduleResolution
jest_config = _get_jest_config_for_project(effective_cwd)
if jest_config:
jest_cmd.append(f"--config={jest_config}")
# Add coverage flags if enabled
if enable_coverage:
jest_cmd.extend(["--coverage", "--coverageReporters=json", f"--coverageDirectory={coverage_dir}"])
if test_files:
jest_cmd.append("--runTestsByPath")
jest_cmd.extend(str(Path(f).resolve()) for f in test_files)
resolved_test_files = [str(Path(f).resolve()) for f in test_files]
jest_cmd.extend(resolved_test_files)
# Add --roots to include directories containing test files
# This is needed because some projects configure Jest with restricted roots
# (e.g., roots: ["<rootDir>/src"]) which excludes the test directory
test_dirs = {str(Path(f).resolve().parent) for f in test_files}
for test_dir in sorted(test_dirs):
jest_cmd.extend(["--roots", test_dir])
if timeout:
jest_cmd.append(f"--testTimeout={timeout * 1000}") # Jest uses milliseconds
@ -306,6 +603,14 @@ def run_jest_behavioral_tests(
args=result.args, returncode=result.returncode, stdout=result.stdout + "\n" + result.stderr, stderr=""
)
logger.debug(f"Jest result: returncode={result.returncode}")
# Log Jest output at WARNING level if tests fail and no XML output will be created
# This helps debug issues like import errors that cause Jest to fail early
if result.returncode != 0 and not result_file_path.exists():
logger.warning(
f"Jest failed with returncode={result.returncode} and no XML output created.\n"
f"Jest stdout: {result.stdout[:2000] if result.stdout else '(empty)'}\n"
f"Jest stderr: {result.stderr[:500] if result.stderr else '(empty)'}"
)
except subprocess.TimeoutExpired:
logger.warning(f"Jest tests timed out after {timeout}s")
result = subprocess.CompletedProcess(args=jest_cmd, returncode=-1, stdout="", stderr="Test execution timed out")
@ -465,9 +770,20 @@ def run_jest_benchmarking_tests(
"--runner=codeflash/loop-runner", # Use custom loop runner for in-process looping
]
# Add Jest config if found - needed for TypeScript transformation
# Uses codeflash-compatible config if project has bundler moduleResolution
jest_config = _get_jest_config_for_project(effective_cwd)
if jest_config:
jest_cmd.append(f"--config={jest_config}")
if test_files:
jest_cmd.append("--runTestsByPath")
jest_cmd.extend(str(Path(f).resolve()) for f in test_files)
resolved_test_files = [str(Path(f).resolve()) for f in test_files]
jest_cmd.extend(resolved_test_files)
# Add --roots to include directories containing test files
test_dirs = {str(Path(f).resolve().parent) for f in test_files}
for test_dir in sorted(test_dirs):
jest_cmd.extend(["--roots", test_dir])
if timeout:
jest_cmd.append(f"--testTimeout={timeout * 1000}")
@ -594,9 +910,20 @@ def run_jest_line_profile_tests(
"--forceExit",
]
# Add Jest config if found - needed for TypeScript transformation
# Uses codeflash-compatible config if project has bundler moduleResolution
jest_config = _get_jest_config_for_project(effective_cwd)
if jest_config:
jest_cmd.append(f"--config={jest_config}")
if test_files:
jest_cmd.append("--runTestsByPath")
jest_cmd.extend(str(Path(f).resolve()) for f in test_files)
resolved_test_files = [str(Path(f).resolve()) for f in test_files]
jest_cmd.extend(resolved_test_files)
# Add --roots to include directories containing test files
test_dirs = {str(Path(f).resolve().parent) for f in test_files}
for test_dir in sorted(test_dirs):
jest_cmd.extend(["--roots", test_dir])
if timeout:
jest_cmd.append(f"--testTimeout={timeout * 1000}")

View file

@ -15,7 +15,7 @@ from typing import TYPE_CHECKING, Any
if TYPE_CHECKING:
from pathlib import Path
from codeflash.languages.base import FunctionInfo
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
logger = logging.getLogger(__name__)
@ -40,7 +40,7 @@ class JavaScriptTracer:
self.output_db = output_db
self.tracer_var = "__codeflash_tracer__"
def instrument_source(self, source: str, file_path: Path, functions: list[FunctionInfo]) -> str:
def instrument_source(self, source: str, file_path: Path, functions: list[FunctionToOptimize]) -> str:
"""Instrument JavaScript source code with function tracing.
Wraps specified functions to capture their inputs and outputs.
@ -64,10 +64,10 @@ class JavaScriptTracer:
lines = source.splitlines(keepends=True)
# Process functions in reverse order to preserve line numbers
for func in sorted(functions, key=lambda f: f.start_line, reverse=True):
for func in sorted(functions, key=lambda f: f.starting_line, reverse=True):
instrumented = self._instrument_function(func, lines, file_path)
start_idx = func.start_line - 1
end_idx = func.end_line
start_idx = func.starting_line - 1
end_idx = func.ending_line
lines = lines[:start_idx] + instrumented + lines[end_idx:]
instrumented_source = "".join(lines)
@ -269,7 +269,7 @@ process.on('exit', () => {{
}});
"""
def _instrument_function(self, func: FunctionInfo, lines: list[str], file_path: Path) -> list[str]:
def _instrument_function(self, func: FunctionToOptimize, lines: list[str], file_path: Path) -> list[str]:
"""Instrument a single function with tracing.
Args:
@ -281,11 +281,11 @@ process.on('exit', () => {{
Instrumented function lines.
"""
func_lines = lines[func.start_line - 1 : func.end_line]
func_lines = lines[func.starting_line - 1 : func.ending_line]
func_text = "".join(func_lines)
# Detect function pattern
func_name = func.name
func_name = func.function_name
is_arrow = "=>" in func_text.split("\n")[0]
is_method = func.is_method
is_async = func.is_async

View file

@ -0,0 +1,492 @@
"""Vitest test runner for JavaScript/TypeScript.
This module provides functions for running Vitest tests for behavioral
verification and performance benchmarking.
"""
from __future__ import annotations
import subprocess
import time
from pathlib import Path
from typing import TYPE_CHECKING
from codeflash.cli_cmds.console import logger
from codeflash.cli_cmds.init_javascript import get_package_install_command
from codeflash.code_utils.code_utils import get_run_tmp_file
from codeflash.code_utils.shell_utils import get_cross_platform_subprocess_run_args
if TYPE_CHECKING:
from codeflash.models.models import TestFiles
def _find_vitest_project_root(file_path: Path) -> Path | None:
"""Find the Vitest project root by looking for vitest/vite config or package.json.
Traverses up from the given file path to find the nearest directory
containing vitest.config.js/ts, vite.config.js/ts, or package.json.
Args:
file_path: A file path within the Vitest project.
Returns:
The project root directory, or None if not found.
"""
current = file_path.parent if file_path.is_file() else file_path
while current != current.parent: # Stop at filesystem root
# Check for Vitest-specific config files first
if (
(current / "vitest.config.js").exists()
or (current / "vitest.config.ts").exists()
or (current / "vitest.config.mjs").exists()
or (current / "vitest.config.mts").exists()
or (current / "vite.config.js").exists()
or (current / "vite.config.ts").exists()
or (current / "vite.config.mjs").exists()
or (current / "vite.config.mts").exists()
or (current / "package.json").exists()
):
return current
current = current.parent
return None
def _is_vitest_coverage_available(project_root: Path) -> bool:
"""Check if Vitest coverage package is available.
Args:
project_root: The project root directory.
Returns:
True if @vitest/coverage-v8 or @vitest/coverage-istanbul is installed.
"""
node_modules = project_root / "node_modules"
return (node_modules / "@vitest" / "coverage-v8").exists() or (
node_modules / "@vitest" / "coverage-istanbul"
).exists()
def _ensure_runtime_files(project_root: Path) -> None:
"""Ensure JavaScript runtime package is installed in the project.
Installs codeflash package if not already present.
The package provides all runtime files needed for test instrumentation.
Uses the project's detected package manager (npm, pnpm, yarn, or bun).
Args:
project_root: The project root directory.
"""
node_modules_pkg = project_root / "node_modules" / "codeflash"
if node_modules_pkg.exists():
logger.debug("codeflash already installed")
return
install_cmd = get_package_install_command(project_root, "codeflash", dev=True)
try:
result = subprocess.run(install_cmd, check=False, cwd=project_root, capture_output=True, text=True, timeout=120)
if result.returncode == 0:
logger.debug(f"Installed codeflash using {install_cmd[0]}")
return
logger.warning(f"Failed to install codeflash: {result.stderr}")
except Exception as e:
logger.warning(f"Error installing codeflash: {e}")
logger.error(f"Could not install codeflash. Please install it manually: {' '.join(install_cmd)}")
def _build_vitest_behavioral_command(
test_files: list[Path], timeout: int | None = None, output_file: Path | None = None
) -> list[str]:
"""Build Vitest command for behavioral tests.
Args:
test_files: List of test files to run.
timeout: Optional timeout in seconds.
output_file: Optional path for JUnit XML output.
Returns:
Command list for subprocess execution.
"""
cmd = [
"npx",
"vitest",
"run", # Single execution (not watch mode)
"--reporter=default",
"--reporter=junit",
"--no-file-parallelism", # Serial execution for deterministic timing
]
if output_file:
cmd.append(f"--outputFile={output_file}")
if timeout:
cmd.append(f"--test-timeout={timeout * 1000}") # Vitest uses milliseconds
# Add test files as positional arguments (Vitest style)
cmd.extend(str(f.resolve()) for f in test_files)
return cmd
def _build_vitest_benchmarking_command(
test_files: list[Path], timeout: int | None = None, output_file: Path | None = None
) -> list[str]:
"""Build Vitest command for benchmarking tests.
Args:
test_files: List of test files to run.
timeout: Optional timeout in seconds.
output_file: Optional path for JUnit XML output.
Returns:
Command list for subprocess execution.
"""
cmd = [
"npx",
"vitest",
"run", # Single execution (not watch mode)
"--reporter=default",
"--reporter=junit",
"--no-file-parallelism", # Serial execution for consistent benchmarking
]
if output_file:
cmd.append(f"--outputFile={output_file}")
if timeout:
cmd.append(f"--test-timeout={timeout * 1000}")
# Add test files as positional arguments
cmd.extend(str(f.resolve()) for f in test_files)
return cmd
def run_vitest_behavioral_tests(
test_paths: TestFiles,
test_env: dict[str, str],
cwd: Path,
*,
timeout: int | None = None,
project_root: Path | None = None,
enable_coverage: bool = False,
candidate_index: int = 0,
) -> tuple[Path, subprocess.CompletedProcess, Path | None, Path | None]:
"""Run Vitest tests and return results in a format compatible with pytest output.
Args:
test_paths: TestFiles object containing test file information.
test_env: Environment variables for the test run.
cwd: Working directory for running tests.
timeout: Optional timeout in seconds.
project_root: Vitest project root (directory containing vitest.config or package.json).
enable_coverage: Whether to collect coverage information.
candidate_index: Index of the candidate being tested.
Returns:
Tuple of (result_file_path, subprocess_result, coverage_json_path, None).
"""
result_file_path = get_run_tmp_file(Path("vitest_results.xml"))
# Get test files to run
test_files = [Path(file.instrumented_behavior_file_path) for file in test_paths.test_files]
# Use provided project_root, or detect it as fallback
if project_root is None and test_files:
project_root = _find_vitest_project_root(test_files[0])
# Use the project root, or fall back to provided cwd
effective_cwd = project_root if project_root else cwd
logger.debug(f"Vitest working directory: {effective_cwd}")
# Ensure the codeflash npm package is installed
_ensure_runtime_files(effective_cwd)
# Coverage output directory - only enable if coverage package is available
coverage_dir = get_run_tmp_file(Path("vitest_coverage"))
coverage_available = _is_vitest_coverage_available(effective_cwd) if enable_coverage else False
coverage_json_path = coverage_dir / "coverage-final.json" if coverage_available else None
if enable_coverage and not coverage_available:
logger.debug("Vitest coverage package not installed, running without coverage")
# Build Vitest command
vitest_cmd = _build_vitest_behavioral_command(test_files=test_files, timeout=timeout, output_file=result_file_path)
# Add coverage flags only if coverage is available
if coverage_available:
vitest_cmd.extend(["--coverage", "--coverage.reporter=json", f"--coverage.reportsDirectory={coverage_dir}"])
# Set up environment
vitest_env = test_env.copy()
# Set codeflash output file for the vitest helper to write timing/behavior data (SQLite format)
codeflash_sqlite_file = get_run_tmp_file(Path(f"test_return_values_{candidate_index}.sqlite"))
vitest_env["CODEFLASH_OUTPUT_FILE"] = str(codeflash_sqlite_file)
vitest_env["CODEFLASH_TEST_ITERATION"] = str(candidate_index)
vitest_env["CODEFLASH_LOOP_INDEX"] = "1"
vitest_env["CODEFLASH_MODE"] = "behavior"
# Seed random number generator for reproducible test runs across original and optimized code
vitest_env["CODEFLASH_RANDOM_SEED"] = "42"
logger.debug(f"Running Vitest tests with command: {' '.join(vitest_cmd)}")
# Subprocess timeout should be much larger than per-test timeout to account for:
# - Vitest startup time (loading modules, compiling TypeScript)
# - Multiple tests running sequentially
# Use at least 120 seconds, or 10x the per-test timeout, whichever is larger
subprocess_timeout = max(120, (timeout or 60) * 10)
start_time_ns = time.perf_counter_ns()
try:
run_args = get_cross_platform_subprocess_run_args(
cwd=effective_cwd, env=vitest_env, timeout=subprocess_timeout, check=False, text=True, capture_output=True
)
result = subprocess.run(vitest_cmd, **run_args) # noqa: PLW1510
# Combine stderr into stdout for timing markers
if result.stderr and not result.stdout:
result = subprocess.CompletedProcess(
args=result.args, returncode=result.returncode, stdout=result.stderr, stderr=""
)
elif result.stderr:
result = subprocess.CompletedProcess(
args=result.args, returncode=result.returncode, stdout=result.stdout + "\n" + result.stderr, stderr=""
)
logger.debug(f"Vitest result: returncode={result.returncode}")
except subprocess.TimeoutExpired:
logger.warning(f"Vitest tests timed out after {subprocess_timeout}s")
result = subprocess.CompletedProcess(
args=vitest_cmd, returncode=-1, stdout="", stderr="Test execution timed out"
)
except FileNotFoundError:
logger.error("Vitest not found. Make sure Vitest is installed (npm install vitest)")
result = subprocess.CompletedProcess(
args=vitest_cmd, returncode=-1, stdout="", stderr="Vitest not found. Run: npm install vitest"
)
finally:
wall_clock_ns = time.perf_counter_ns() - start_time_ns
logger.debug(f"Vitest behavioral tests completed in {wall_clock_ns / 1e9:.2f}s")
return result_file_path, result, coverage_json_path, None
def run_vitest_benchmarking_tests(
test_paths: TestFiles,
test_env: dict[str, str],
cwd: Path,
*,
timeout: int | None = None,
project_root: Path | None = None,
min_loops: int = 5,
max_loops: int = 100,
target_duration_ms: int = 10_000,
stability_check: bool = True,
) -> tuple[Path, subprocess.CompletedProcess]:
"""Run Vitest benchmarking tests with external looping from Python.
Uses external process-level looping to run tests multiple times and
collect timing data. This matches the Python pytest approach where
looping is controlled externally for simplicity.
Args:
test_paths: TestFiles object containing test file information.
test_env: Environment variables for the test run.
cwd: Working directory for running tests.
timeout: Optional timeout in seconds for the entire benchmark run.
project_root: Vitest project root (directory containing vitest.config or package.json).
min_loops: Minimum number of loop iterations.
max_loops: Maximum number of loop iterations.
target_duration_ms: Target TOTAL duration in milliseconds for all loops.
stability_check: Whether to enable stability-based early stopping.
Returns:
Tuple of (result_file_path, subprocess_result with stdout from all iterations).
"""
result_file_path = get_run_tmp_file(Path("vitest_perf_results.xml"))
# Get performance test files
test_files = [Path(file.benchmarking_file_path) for file in test_paths.test_files if file.benchmarking_file_path]
# Use provided project_root, or detect it as fallback
if project_root is None and test_files:
project_root = _find_vitest_project_root(test_files[0])
effective_cwd = project_root if project_root else cwd
logger.debug(f"Vitest benchmarking working directory: {effective_cwd}")
# Ensure the codeflash npm package is installed
_ensure_runtime_files(effective_cwd)
# Build Vitest command for performance tests
vitest_cmd = _build_vitest_benchmarking_command(
test_files=test_files, timeout=timeout, output_file=result_file_path
)
# Base environment setup
vitest_env = test_env.copy()
codeflash_sqlite_file = get_run_tmp_file(Path("test_return_values_0.sqlite"))
vitest_env["CODEFLASH_OUTPUT_FILE"] = str(codeflash_sqlite_file)
vitest_env["CODEFLASH_TEST_ITERATION"] = "0"
vitest_env["CODEFLASH_MODE"] = "performance"
vitest_env["CODEFLASH_RANDOM_SEED"] = "42"
# Internal loop configuration for capturePerf
vitest_env["CODEFLASH_PERF_LOOP_COUNT"] = str(max_loops)
vitest_env["CODEFLASH_PERF_MIN_LOOPS"] = str(min_loops)
vitest_env["CODEFLASH_PERF_TARGET_DURATION_MS"] = str(target_duration_ms)
vitest_env["CODEFLASH_PERF_STABILITY_CHECK"] = "true" if stability_check else "false"
vitest_env["CODEFLASH_LOOP_INDEX"] = "1"
# Total timeout for the entire benchmark run
total_timeout = max(120, (target_duration_ms // 1000) + 60, timeout or 120)
logger.debug(f"Running Vitest benchmarking tests: {' '.join(vitest_cmd)}")
logger.debug(
f"Vitest benchmarking config: min_loops={min_loops}, max_loops={max_loops}, "
f"target_duration={target_duration_ms}ms, stability_check={stability_check}"
)
total_start_time = time.time()
try:
run_args = get_cross_platform_subprocess_run_args(
cwd=effective_cwd, env=vitest_env, timeout=total_timeout, check=False, text=True, capture_output=True
)
result = subprocess.run(vitest_cmd, **run_args) # noqa: PLW1510
# Combine stderr into stdout for timing markers
stdout = result.stdout or ""
if result.stderr:
stdout = stdout + "\n" + result.stderr if stdout else result.stderr
result = subprocess.CompletedProcess(args=result.args, returncode=result.returncode, stdout=stdout, stderr="")
except subprocess.TimeoutExpired:
logger.warning(f"Vitest benchmarking timed out after {total_timeout}s")
result = subprocess.CompletedProcess(args=vitest_cmd, returncode=-1, stdout="", stderr="Benchmarking timed out")
except FileNotFoundError:
logger.error("Vitest not found for benchmarking")
result = subprocess.CompletedProcess(args=vitest_cmd, returncode=-1, stdout="", stderr="Vitest not found")
wall_clock_seconds = time.time() - total_start_time
logger.debug(f"Vitest benchmarking completed in {wall_clock_seconds:.2f}s")
return result_file_path, result
def run_vitest_line_profile_tests(
test_paths: TestFiles,
test_env: dict[str, str],
cwd: Path,
*,
timeout: int | None = None,
project_root: Path | None = None,
line_profile_output_file: Path | None = None,
) -> tuple[Path, subprocess.CompletedProcess]:
"""Run Vitest tests for line profiling.
This runs tests against source code that has been instrumented with line profiler.
The instrumentation collects execution counts and timing per line.
Args:
test_paths: TestFiles object containing test file information.
test_env: Environment variables for the test run.
cwd: Working directory for running tests.
timeout: Optional timeout in seconds for the subprocess.
project_root: Vitest project root (directory containing vitest.config or package.json).
line_profile_output_file: Path where line profile results will be written.
Returns:
Tuple of (result_file_path, subprocess_result).
"""
result_file_path = get_run_tmp_file(Path("vitest_line_profile_results.xml"))
# Get test files to run - use instrumented behavior files if available, otherwise benchmarking files
test_files = []
for file in test_paths.test_files:
if file.instrumented_behavior_file_path:
test_files.append(Path(file.instrumented_behavior_file_path))
elif file.benchmarking_file_path:
test_files.append(Path(file.benchmarking_file_path))
# Use provided project_root, or detect it as fallback
if project_root is None and test_files:
project_root = _find_vitest_project_root(test_files[0])
effective_cwd = project_root if project_root else cwd
logger.debug(f"Vitest line profiling working directory: {effective_cwd}")
# Ensure the codeflash npm package is installed
_ensure_runtime_files(effective_cwd)
# Build Vitest command for line profiling - simple run without benchmarking loops
vitest_cmd = [
"npx",
"vitest",
"run",
"--reporter=default",
"--reporter=junit",
"--no-file-parallelism", # Serial execution for consistent line profiling
]
vitest_cmd.append(f"--outputFile={result_file_path}")
if timeout:
vitest_cmd.append(f"--test-timeout={timeout * 1000}")
vitest_cmd.extend(str(f.resolve()) for f in test_files)
# Set up environment
vitest_env = test_env.copy()
codeflash_sqlite_file = get_run_tmp_file(Path("test_return_values_line_profile.sqlite"))
vitest_env["CODEFLASH_OUTPUT_FILE"] = str(codeflash_sqlite_file)
vitest_env["CODEFLASH_TEST_ITERATION"] = "0"
vitest_env["CODEFLASH_LOOP_INDEX"] = "1"
vitest_env["CODEFLASH_MODE"] = "line_profile"
vitest_env["CODEFLASH_RANDOM_SEED"] = "42"
# Pass the line profile output file path to the instrumented code
if line_profile_output_file:
vitest_env["CODEFLASH_LINE_PROFILE_OUTPUT"] = str(line_profile_output_file)
# Subprocess timeout should be larger than per-test timeout to account for startup
subprocess_timeout = max(120, (timeout or 60) * 10)
logger.debug(f"Running Vitest line profile tests: {' '.join(vitest_cmd)}")
start_time_ns = time.perf_counter_ns()
try:
run_args = get_cross_platform_subprocess_run_args(
cwd=effective_cwd, env=vitest_env, timeout=subprocess_timeout, check=False, text=True, capture_output=True
)
result = subprocess.run(vitest_cmd, **run_args) # noqa: PLW1510
# Combine stderr into stdout
if result.stderr and not result.stdout:
result = subprocess.CompletedProcess(
args=result.args, returncode=result.returncode, stdout=result.stderr, stderr=""
)
elif result.stderr:
result = subprocess.CompletedProcess(
args=result.args, returncode=result.returncode, stdout=result.stdout + "\n" + result.stderr, stderr=""
)
logger.debug(f"Vitest line profile result: returncode={result.returncode}")
except subprocess.TimeoutExpired:
logger.warning(f"Vitest line profile tests timed out after {subprocess_timeout}s")
result = subprocess.CompletedProcess(
args=vitest_cmd, returncode=-1, stdout="", stderr="Line profile tests timed out"
)
except FileNotFoundError:
logger.error("Vitest not found for line profiling")
result = subprocess.CompletedProcess(args=vitest_cmd, returncode=-1, stdout="", stderr="Vitest not found")
finally:
wall_clock_ns = time.perf_counter_ns() - start_time_ns
logger.debug(f"Vitest line profile tests completed in {wall_clock_ns / 1e9:.2f}s")
return result_file_path, result

View file

@ -0,0 +1,18 @@
"""Language enum for multi-language support.
This module is kept separate to avoid circular imports.
"""
from enum import Enum
class Language(str, Enum):
"""Supported programming languages."""
PYTHON = "python"
JAVASCRIPT = "javascript"
TYPESCRIPT = "typescript"
JAVA = "java"
def __str__(self) -> str:
return self.value

View file

@ -6,13 +6,13 @@ import logging
from pathlib import Path
from typing import TYPE_CHECKING, Any
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.languages.base import (
CodeContext,
FunctionFilterCriteria,
FunctionInfo,
HelperFunction,
Language,
ParentInfo,
ReferenceInfo,
TestInfo,
TestResult,
)
@ -45,6 +45,11 @@ class PythonSupport:
"""File extensions supported by Python."""
return (".py", ".pyw")
@property
def default_file_extension(self) -> str:
"""Default file extension for Python."""
return ".py"
@property
def test_framework(self) -> str:
"""Primary test framework for Python."""
@ -58,7 +63,7 @@ class PythonSupport:
def discover_functions(
self, file_path: Path, filter_criteria: FunctionFilterCriteria | None = None
) -> list[FunctionInfo]:
) -> list[FunctionToOptimize]:
"""Find all optimizable functions in a Python file.
Uses libcst to parse the file and find functions with return statements.
@ -68,12 +73,12 @@ class PythonSupport:
filter_criteria: Optional criteria to filter functions.
Returns:
List of FunctionInfo objects for discovered functions.
List of FunctionToOptimize objects for discovered functions.
"""
import libcst as cst
from codeflash.discovery.functions_to_optimize import FunctionToOptimize, FunctionVisitor
from codeflash.discovery.functions_to_optimize import FunctionVisitor
criteria = filter_criteria or FunctionFilterCriteria()
@ -90,7 +95,7 @@ class PythonSupport:
function_visitor = FunctionVisitor(file_path=str(file_path))
wrapper.visit(function_visitor)
functions: list[FunctionInfo] = []
functions: list[FunctionToOptimize] = []
for func in function_visitor.functions:
if not isinstance(func, FunctionToOptimize):
continue
@ -107,23 +112,20 @@ class PythonSupport:
if criteria.require_return and func.starting_line is None:
continue
# Convert FunctionToOptimize to FunctionInfo
parents = tuple(ParentInfo(name=p.name, type=p.type) for p in func.parents)
functions.append(
FunctionInfo(
name=func.function_name,
file_path=file_path,
start_line=func.starting_line or 1,
end_line=func.ending_line or 1,
start_col=func.starting_col,
end_col=func.ending_col,
parents=parents,
is_async=func.is_async,
is_method=len(func.parents) > 0,
language=Language.PYTHON,
)
# Add is_method field based on parents
func_with_is_method = FunctionToOptimize(
function_name=func.function_name,
file_path=file_path,
parents=func.parents,
starting_line=func.starting_line,
ending_line=func.ending_line,
starting_col=func.starting_col,
ending_col=func.ending_col,
is_async=func.is_async,
is_method=len(func.parents) > 0 and any(p.type == "ClassDef" for p in func.parents),
language="python",
)
functions.append(func_with_is_method)
return functions
@ -131,7 +133,9 @@ class PythonSupport:
logger.warning("Failed to discover functions in %s: %s", file_path, e)
return []
def discover_tests(self, test_root: Path, source_functions: Sequence[FunctionInfo]) -> dict[str, list[TestInfo]]:
def discover_tests(
self, test_root: Path, source_functions: Sequence[FunctionToOptimize]
) -> dict[str, list[TestInfo]]:
"""Map source functions to their tests via static analysis.
Args:
@ -155,7 +159,7 @@ class PythonSupport:
try:
source = test_file.read_text()
# Check if function name appears in test file
if func.name in source:
if func.function_name in source:
result[func.qualified_name].append(
TestInfo(test_name=test_file.stem, test_file=test_file, test_class=None)
)
@ -166,7 +170,7 @@ class PythonSupport:
# === Code Analysis ===
def extract_code_context(self, function: FunctionInfo, project_root: Path, module_root: Path) -> CodeContext:
def extract_code_context(self, function: FunctionToOptimize, project_root: Path, module_root: Path) -> CodeContext:
"""Extract function code and its dependencies.
Uses jedi and libcst for Python code analysis.
@ -188,8 +192,8 @@ class PythonSupport:
# Extract the function source
lines = source.splitlines(keepends=True)
if function.start_line and function.end_line:
target_lines = lines[function.start_line - 1 : function.end_line]
if function.starting_line and function.ending_line:
target_lines = lines[function.starting_line - 1 : function.ending_line]
target_code = "".join(target_lines)
else:
target_code = ""
@ -216,7 +220,7 @@ class PythonSupport:
language=Language.PYTHON,
)
def find_helper_functions(self, function: FunctionInfo, project_root: Path) -> list[HelperFunction]:
def find_helper_functions(self, function: FunctionToOptimize, project_root: Path) -> list[HelperFunction]:
"""Find helper functions called by the target function.
Uses jedi for Python code analysis.
@ -285,20 +289,130 @@ class PythonSupport:
)
except Exception as e:
logger.warning("Failed to find helpers for %s: %s", function.name, e)
logger.warning("Failed to find helpers for %s: %s", function.function_name, e)
return helpers
def find_references(
self, function: FunctionToOptimize, project_root: Path, tests_root: Path | None = None, max_files: int = 500
) -> list[ReferenceInfo]:
"""Find all references (call sites) to a function across the codebase.
Uses jedi to find all places where a Python function is called.
Args:
function: The function to find references for.
project_root: Root of the project to search.
tests_root: Root of tests directory (references in tests are excluded).
max_files: Maximum number of files to search.
Returns:
List of ReferenceInfo objects describing each reference location.
"""
try:
import jedi
source = function.file_path.read_text()
# Find the function position
script = jedi.Script(code=source, path=function.file_path)
names = script.get_names(all_scopes=True, definitions=True)
function_pos = None
for name in names:
if name.type == "function" and name.name == function.name:
# Check for class parent if it's a method
if function.class_name:
parent = name.parent()
if parent and parent.name == function.class_name and parent.type == "class":
function_pos = (name.line, name.column)
break
else:
function_pos = (name.line, name.column)
break
if function_pos is None:
return []
# Get references using jedi
script = jedi.Script(code=source, path=function.file_path, project=jedi.Project(path=project_root))
references = script.get_references(line=function_pos[0], column=function_pos[1])
result: list[ReferenceInfo] = []
seen_locations: set[tuple[Path, int, int]] = set()
for ref in references:
if not ref.module_path:
continue
ref_path = Path(ref.module_path)
# Skip the definition itself
if ref_path == function.file_path and ref.line == function_pos[0]:
continue
# Skip test files
if tests_root:
try:
ref_path.relative_to(tests_root)
continue
except ValueError:
pass
# Avoid duplicates
loc_key = (ref_path, ref.line, ref.column)
if loc_key in seen_locations:
continue
seen_locations.add(loc_key)
# Get context line
try:
ref_source = ref_path.read_text()
lines = ref_source.splitlines()
context = lines[ref.line - 1] if ref.line <= len(lines) else ""
except Exception:
context = ""
# Determine caller function
caller_function = None
try:
parent = ref.parent()
if parent and parent.type == "function":
caller_function = parent.name
except Exception:
pass
result.append(
ReferenceInfo(
file_path=ref_path,
line=ref.line,
column=ref.column,
end_line=ref.line,
end_column=ref.column + len(function.function_name),
context=context.strip(),
reference_type="call",
import_name=function.function_name,
caller_function=caller_function,
)
)
return result
except Exception as e:
logger.warning("Failed to find references for %s: %s", function.function_name, e)
return []
# === Code Transformation ===
def replace_function(self, source: str, function: FunctionInfo, new_source: str) -> str:
def replace_function(self, source: str, function: FunctionToOptimize, new_source: str) -> str:
"""Replace a function in source code with new implementation.
Uses libcst for Python code transformation.
Args:
source: Original source code.
function: FunctionInfo identifying the function to replace.
function: FunctionToOptimize identifying the function to replace.
new_source: New function source code.
Returns:
@ -319,7 +433,7 @@ class PythonSupport:
preexisting_objects=set(),
)
except Exception as e:
logger.warning("Failed to replace function %s: %s", function.name, e)
logger.warning("Failed to replace function %s: %s", function.function_name, e)
return source
def format_code(self, source: str, file_path: Path | None = None) -> str:
@ -465,7 +579,7 @@ class PythonSupport:
# === Instrumentation ===
def instrument_for_behavior(self, source: str, functions: Sequence[FunctionInfo]) -> str:
def instrument_for_behavior(self, source: str, functions: Sequence[FunctionToOptimize]) -> str:
"""Add behavior instrumentation to capture inputs/outputs.
Args:
@ -480,7 +594,7 @@ class PythonSupport:
# This is a pass-through for now
return source
def instrument_for_benchmarking(self, test_source: str, target_function: FunctionInfo) -> str:
def instrument_for_benchmarking(self, test_source: str, target_function: FunctionToOptimize) -> str:
"""Add timing instrumentation to test code.
Args:
@ -721,7 +835,9 @@ class PythonSupport:
mode=testing_mode,
)
def instrument_source_for_line_profiler(self, func_info: FunctionInfo, line_profiler_output_file: Path) -> bool:
def instrument_source_for_line_profiler(
self, func_info: FunctionToOptimize, line_profiler_output_file: Path
) -> bool:
"""Instrument source code for line profiling.
Args:

View file

@ -11,7 +11,7 @@ import logging
from pathlib import Path
from typing import TYPE_CHECKING
from codeflash.languages.base import Language
from codeflash.languages.language_enum import Language
if TYPE_CHECKING:
from collections.abc import Iterable
@ -30,6 +30,33 @@ _LANGUAGE_REGISTRY: dict[Language, type[LanguageSupport]] = {}
# Cache of instantiated language support objects
_SUPPORT_CACHE: dict[Language, LanguageSupport] = {}
# Flag to track if language modules have been imported
_languages_registered = False
def _ensure_languages_registered() -> None:
"""Ensure all language support modules are imported and registered.
This lazily imports the language support modules to avoid circular imports
at module load time. The imports trigger the @register_language decorators
which populate the registries.
"""
global _languages_registered
if _languages_registered:
return
# Import support modules to trigger registration
# These imports are deferred to avoid circular imports
import contextlib
with contextlib.suppress(ImportError):
from codeflash.languages.python import support as _
with contextlib.suppress(ImportError):
from codeflash.languages.javascript import support as _ # noqa: F401
_languages_registered = True
class UnsupportedLanguageError(Exception):
"""Raised when attempting to use an unsupported language."""
@ -123,6 +150,10 @@ def get_language_support(identifier: Path | Language | str) -> LanguageSupport:
Raises:
UnsupportedLanguageError: If the language is not supported.
Note:
This function lazily imports language support modules on first call
to avoid circular import issues at module load time.
Example:
# By file path
lang = get_language_support(Path("example.py"))
@ -137,6 +168,7 @@ def get_language_support(identifier: Path | Language | str) -> LanguageSupport:
lang = get_language_support("python")
"""
_ensure_languages_registered()
language: Language | None = None
if isinstance(identifier, Language):
@ -178,6 +210,42 @@ def get_language_support(identifier: Path | Language | str) -> LanguageSupport:
_FRAMEWORK_CACHE: dict[str, LanguageSupport] = {}
def get_language_support_by_common_formatters(formatter_cmd: str | list[str]) -> LanguageSupport | None:
_ensure_languages_registered()
language: Language | None = None
if isinstance(formatter_cmd, str):
formatter_cmd = [formatter_cmd]
if len(formatter_cmd) == 1:
formatter_cmd = formatter_cmd[0].split(" ")
# Try as extension first
ext = None
py_formatters = ["black", "isort", "ruff", "autopep8", "yapf", "pyfmt"]
js_ts_formatters = ["prettier", "eslint", "biome", "rome", "deno", "standard", "tslint"]
if any(cmd in py_formatters for cmd in formatter_cmd):
ext = ".py"
elif any(cmd in js_ts_formatters for cmd in formatter_cmd):
ext = ".js"
if ext is None:
# can't determine language
return None
cls = _EXTENSION_REGISTRY[ext]
language = cls().language
# Return cached instance or create new one
if language not in _SUPPORT_CACHE:
if language not in _LANGUAGE_REGISTRY:
raise UnsupportedLanguageError(str(language), get_supported_languages())
_SUPPORT_CACHE[language] = _LANGUAGE_REGISTRY[language]()
return _SUPPORT_CACHE[language]
def get_language_support_by_framework(test_framework: str) -> LanguageSupport | None:
"""Get language support for a test framework.
@ -238,6 +306,7 @@ def detect_project_language(project_root: Path, module_root: Path) -> Language:
UnsupportedLanguageError: If no supported language is detected.
"""
_ensure_languages_registered()
extension_counts: dict[str, int] = {}
# Count files by extension
@ -265,6 +334,7 @@ def get_supported_languages() -> list[str]:
List of language name strings.
"""
_ensure_languages_registered()
return [lang.value for lang in _LANGUAGE_REGISTRY]
@ -275,6 +345,7 @@ def get_supported_extensions() -> list[str]:
List of extension strings (with leading dots).
"""
_ensure_languages_registered()
return list(_EXTENSION_REGISTRY.keys())
@ -300,10 +371,12 @@ def clear_registry() -> None:
Primarily useful for testing.
"""
global _languages_registered
_EXTENSION_REGISTRY.clear()
_LANGUAGE_REGISTRY.clear()
_SUPPORT_CACHE.clear()
_FRAMEWORK_CACHE.clear()
_languages_registered = False
def clear_cache() -> None:

View file

@ -0,0 +1,145 @@
"""Singleton for the current test framework being used in the codeflash session.
This module provides a centralized way to access and set the current test framework
throughout the codeflash codebase, similar to how the language singleton works.
For JavaScript/TypeScript projects, this determines whether to use Jest or Vitest
for test execution.
Usage:
from codeflash.languages.test_framework import (
current_test_framework,
set_current_test_framework,
is_jest,
is_vitest,
)
# Set the test framework at the start of a session (auto-detected from package.json)
set_current_test_framework("vitest")
# Check the current test framework anywhere in the codebase
if is_vitest():
# Vitest-specific code
...
# Get the current test framework
framework = current_test_framework()
"""
from __future__ import annotations
from typing import Literal
TestFramework = Literal["jest", "vitest", "mocha", "pytest", "unittest"]
# Module-level singleton for the current test framework
_current_test_framework: TestFramework | None = None
def current_test_framework() -> TestFramework | None:
"""Get the current test framework being used in this codeflash session.
Returns:
The current test framework string, or None if not set.
"""
return _current_test_framework
def set_current_test_framework(framework: TestFramework | str | None) -> None:
"""Set the current test framework for this codeflash session.
This should be called once at the start of an optimization run,
typically after reading the project configuration.
Args:
framework: Test framework name ("jest", "vitest", "mocha", "pytest", "unittest").
"""
global _current_test_framework
if _current_test_framework is not None:
return
if framework is not None:
framework = framework.lower()
if framework not in ("jest", "vitest", "mocha", "pytest", "unittest"):
# Default to jest for unknown JS frameworks, pytest for unknown Python
from codeflash.languages.current import is_javascript
framework = "jest" if is_javascript() else "pytest"
_current_test_framework = framework
def reset_test_framework() -> None:
"""Reset the current test framework to None.
Useful for testing or when starting a new session.
"""
global _current_test_framework
_current_test_framework = None
def is_jest() -> bool:
"""Check if the current test framework is Jest.
Returns:
True if the current test framework is Jest.
"""
return _current_test_framework == "jest"
def is_vitest() -> bool:
"""Check if the current test framework is Vitest.
Returns:
True if the current test framework is Vitest.
"""
return _current_test_framework == "vitest"
def is_mocha() -> bool:
"""Check if the current test framework is Mocha.
Returns:
True if the current test framework is Mocha.
"""
return _current_test_framework == "mocha"
def is_pytest() -> bool:
"""Check if the current test framework is pytest.
Returns:
True if the current test framework is pytest.
"""
return _current_test_framework == "pytest"
def is_unittest() -> bool:
"""Check if the current test framework is unittest.
Returns:
True if the current test framework is unittest.
"""
return _current_test_framework == "unittest"
def get_js_test_framework_or_default() -> TestFramework:
"""Get the current test framework for JS/TS, defaulting to 'jest' if not set.
This is a convenience function for JS/TS code that needs a framework.
Returns:
The current test framework, or 'jest' as default.
"""
if _current_test_framework in ("jest", "vitest", "mocha"):
return _current_test_framework
return "jest"

View file

@ -454,23 +454,47 @@ class TreeSitterAnalyzer:
return imports
def _walk_tree_for_imports(self, node: Node, source_bytes: bytes, imports: list[ImportInfo]) -> None:
"""Recursively walk the tree to find import statements."""
def _walk_tree_for_imports(
self, node: Node, source_bytes: bytes, imports: list[ImportInfo], in_function: bool = False
) -> None:
"""Recursively walk the tree to find import statements.
Args:
node: Current node to check.
source_bytes: Source code bytes.
imports: List to append found imports to.
in_function: Whether we're currently inside a function/method body.
"""
# Track when we enter function/method bodies
# These node types contain function/method bodies where require() should not be treated as imports
function_body_types = {
"function_declaration",
"method_definition",
"arrow_function",
"function_expression",
"function", # Generic function in some grammars
}
if node.type == "import_statement":
import_info = self._extract_import_info(node, source_bytes)
if import_info:
imports.append(import_info)
# Also handle require() calls for CommonJS
if node.type == "call_expression":
# Also handle require() calls for CommonJS, but only at module level
# require() inside functions is a dynamic import, not a module import
if node.type == "call_expression" and not in_function:
func_node = node.child_by_field_name("function")
if func_node and self.get_node_text(func_node, source_bytes) == "require":
import_info = self._extract_require_info(node, source_bytes)
if import_info:
imports.append(import_info)
# Update in_function flag for children
child_in_function = in_function or node.type in function_body_types
for child in node.children:
self._walk_tree_for_imports(child, source_bytes, imports)
self._walk_tree_for_imports(child, source_bytes, imports, child_in_function)
def _extract_import_info(self, node: Node, source_bytes: bytes) -> ImportInfo | None:
"""Extract import information from an import statement node."""
@ -841,20 +865,27 @@ class TreeSitterAnalyzer:
end_line=node.end_point[0] + 1,
)
def is_function_exported(self, source: str, function_name: str) -> tuple[bool, str | None]:
def is_function_exported(
self, source: str, function_name: str, class_name: str | None = None
) -> tuple[bool, str | None]:
"""Check if a function is exported and get its export name.
For class methods, also checks if the containing class is exported.
Args:
source: The source code to analyze.
function_name: The name of the function to check.
class_name: For class methods, the name of the containing class.
Returns:
Tuple of (is_exported, export_name). export_name may differ from
function_name if exported with an alias.
function_name if exported with an alias. For class methods,
returns the class export name.
"""
exports = self.find_exports(source)
# First, check if the function itself is directly exported
for export in exports:
# Check default export
if export.default_export == function_name:
@ -865,6 +896,18 @@ class TreeSitterAnalyzer:
if name == function_name:
return (True, alias if alias else name)
# For class methods, check if the containing class is exported
if class_name:
for export in exports:
# Check if class is default export
if export.default_export == class_name:
return (True, class_name)
# Check if class is in named exports
for name, alias in export.exported_names:
if name == class_name:
return (True, alias if alias else name)
return (False, None)
def find_function_calls(self, source: str, within_function: FunctionNode) -> list[str]:

View file

@ -4,7 +4,12 @@ If you might want to work with us on finally making performance a
solved problem, please reach out to us at careers@codeflash.ai. We're hiring!
"""
from __future__ import annotations
import os
import sys
from pathlib import Path
from typing import TYPE_CHECKING
from codeflash.cli_cmds.cli import parse_args, process_pyproject_config
from codeflash.cli_cmds.cmd_init import CODEFLASH_LOGO, ask_run_end_to_end_test
@ -16,6 +21,9 @@ from codeflash.code_utils.version_check import check_for_newer_minor_version
from codeflash.telemetry import posthog_cf
from codeflash.telemetry.sentry import init_sentry
if TYPE_CHECKING:
from argparse import Namespace
def main() -> None:
"""Entry point for the codeflash command-line interface."""
@ -39,7 +47,12 @@ def main() -> None:
posthog_cf.initialize_posthog(enabled=not args.disable_telemetry)
ask_run_end_to_end_test(args)
else:
args = process_pyproject_config(args)
# Check for first-run experience (no config exists)
loaded_args = _handle_config_loading(args)
if loaded_args is None:
sys.exit(0)
args = loaded_args
if not env_utils.check_formatter_installed(args.formatter_cmds):
return
args.previous_checkpoint_functions = ask_should_use_checkpoint_get_functions(args)
@ -51,6 +64,52 @@ def main() -> None:
optimizer.run_with_args(args)
def _handle_config_loading(args: Namespace) -> Namespace | None:
"""Handle config loading with first-run experience support.
If no config exists and not in CI, triggers the first-run experience.
Otherwise, loads config normally.
Args:
args: CLI args namespace.
Returns:
Updated args with config loaded, or None if user cancelled first-run.
"""
from codeflash.setup.first_run import handle_first_run, is_first_run
# Check if we're in CI environment
is_ci = any(
var in ("true", "1", "True") for var in [os.environ.get("CI", ""), os.environ.get("GITHUB_ACTIONS", "")]
)
# Check if first run (no config exists)
if is_first_run() and not is_ci:
# Skip API key check if already set
skip_api_key = bool(os.environ.get("CODEFLASH_API_KEY"))
# Handle first-run experience
result = handle_first_run(args=args, skip_confirm=getattr(args, "yes", False), skip_api_key=skip_api_key)
if result is None:
return None
# Merge first-run results with any CLI overrides
args = result
# Still need to process some config values
# Config might not exist yet if first run just saved it - that's OK
import contextlib
with contextlib.suppress(ValueError):
args = process_pyproject_config(args)
return args
# Normal config loading
return process_pyproject_config(args)
def print_codeflash_banner() -> None:
paneled_text(
CODEFLASH_LOGO, panel_args={"title": "https://codeflash.ai", "expand": False}, text_args={"style": "bold gold3"}

View file

@ -0,0 +1,18 @@
"""Simple function-related types with no dependencies.
This module contains basic types used for function representation.
It is intentionally kept dependency-free to avoid circular imports.
"""
from __future__ import annotations
from pydantic.dataclasses import dataclass
@dataclass(frozen=True)
class FunctionParent:
name: str
type: str
def __str__(self) -> str:
return f"{self.type}:{self.name}"

View file

@ -599,10 +599,8 @@ class CodePosition:
col_no: int
@dataclass(frozen=True)
class FunctionParent:
name: str
type: str
# Re-export FunctionParent for backward compatibility
from codeflash.models.function_types import FunctionParent # noqa: E402
class OriginalCodeBaseline(BaseModel):

View file

@ -77,7 +77,7 @@ from codeflash.context.unused_definition_remover import detect_unused_helper_fun
from codeflash.discovery.functions_to_optimize import was_function_previously_optimized
from codeflash.either import Failure, Success, is_successful
from codeflash.languages import is_java, is_python
from codeflash.languages.base import FunctionInfo, Language
from codeflash.languages.base import Language
from codeflash.languages.current import current_language_support, is_typescript
from codeflash.languages.javascript.module_system import detect_module_system
from codeflash.lsp.helpers import is_LSP_enabled, report_to_markdown_table, tree_to_markdown
@ -2285,10 +2285,10 @@ class FunctionOptimizer:
else self.function_trace_id,
"coverage_message": coverage_message,
"replay_tests": replay_tests,
"concolic_tests": concolic_tests,
# "concolic_tests": concolic_tests,
"language": self.function_to_optimize.language,
"original_line_profiler": original_code_baseline.line_profile_results.get("str_out", ""),
"optimized_line_profiler": best_optimization.line_profiler_test_results.get("str_out", ""),
# "original_line_profiler": original_code_baseline.line_profile_results.get("str_out", ""),
# "optimized_line_profiler": best_optimization.line_profiler_test_results.get("str_out", ""),
}
raise_pr = not self.args.no_pr
@ -2955,18 +2955,8 @@ class FunctionOptimizer:
# NOTE: currently this handles single file only, add support to multi file instrumentation (or should it be kept for the main file only)
original_source = Path(self.function_to_optimize.file_path).read_text()
# Instrument source code
func_info = FunctionInfo(
name=self.function_to_optimize.function_name,
file_path=self.function_to_optimize.file_path,
start_line=self.function_to_optimize.starting_line,
end_line=self.function_to_optimize.ending_line,
start_col=self.function_to_optimize.starting_col,
end_col=self.function_to_optimize.ending_col,
is_async=self.function_to_optimize.is_async,
language=self.language_support.language,
)
success = self.language_support.instrument_source_for_line_profiler(
func_info=func_info, line_profiler_output_file=line_profiler_output_path
func_info=self.function_to_optimize, line_profiler_output_file=line_profiler_output_path
)
if not success:
return {"timings": {}, "unit": 0, "str_out": ""}

View file

@ -90,6 +90,33 @@ class Optimizer:
current = current.parent
return None
def _verify_js_requirements(self) -> None:
"""Verify JavaScript/TypeScript requirements before optimization.
Checks that Node.js, npm, and the test framework are available.
Logs warnings if requirements are not met but does not abort.
"""
from codeflash.languages import get_language_support
from codeflash.languages.base import Language
from codeflash.languages.test_framework import get_js_test_framework_or_default
js_project_root = self.test_cfg.js_project_root
if not js_project_root:
return
try:
js_support = get_language_support(Language.JAVASCRIPT)
test_framework = get_js_test_framework_or_default()
success, errors = js_support.verify_requirements(js_project_root, test_framework)
if not success:
logger.warning("JavaScript requirements check found issues:")
for error in errors:
logger.warning(f" - {error}")
except Exception as e:
logger.debug(f"Failed to verify JS requirements: {e}")
def run_benchmarks(
self, file_to_funcs_to_optimize: dict[Path, list[FunctionToOptimize]], num_optimizable_functions: int
) -> tuple[dict[str, dict[BenchmarkKey, float]], dict[BenchmarkKey, float]]:
@ -466,6 +493,8 @@ class Optimizer:
# For JavaScript, also set js_project_root for test execution
if is_javascript():
self.test_cfg.js_project_root = self._find_js_project_root(file_path)
# Verify JS requirements before proceeding
self._verify_js_requirements()
break
if self.args.all:

View file

@ -281,8 +281,8 @@ def check_create_pr(
function_trace_id: str,
coverage_message: str,
replay_tests: str,
concolic_tests: str,
root_dir: Path,
concolic_tests: str = "",
git_remote: Optional[str] = None,
optimization_review: str = "",
original_line_profiler: str | None = None,

View file

@ -0,0 +1,22 @@
"""Setup module for Codeflash auto-detection and first-run experience.
This module provides:
- Universal project detection across all supported languages
- First-run experience with auto-detection and quick confirm
- Config writing to native config files (pyproject.toml, package.json)
"""
from codeflash.setup.config_schema import CodeflashConfig
from codeflash.setup.config_writer import write_config
from codeflash.setup.detector import DetectedProject, detect_project, has_existing_config
from codeflash.setup.first_run import handle_first_run, is_first_run
__all__ = [
"CodeflashConfig",
"DetectedProject",
"detect_project",
"handle_first_run",
"has_existing_config",
"is_first_run",
"write_config",
]

View file

@ -0,0 +1,200 @@
"""Codeflash configuration schema using Pydantic.
This module provides a language-agnostic internal representation of Codeflash
configuration that can be serialized to different formats (TOML, JSON).
"""
from __future__ import annotations
from typing import TYPE_CHECKING, Any
from pydantic import BaseModel, ConfigDict, Field
if TYPE_CHECKING:
from pathlib import Path
class CodeflashConfig(BaseModel):
"""Internal representation of Codeflash configuration.
This is the canonical config format used internally. It can be converted
to/from pyproject.toml (Python) or package.json (JS/TS) formats.
Note: All paths are stored as strings (relative to project root).
"""
# Core settings (always present after detection)
language: str = Field(description="Project language: python, javascript, typescript")
module_root: str = Field(default=".", description="Root directory containing source code")
tests_root: str | None = Field(default=None, description="Root directory containing tests")
# Tooling settings (auto-detected, can be overridden)
test_framework: str | None = Field(default=None, description="Test framework: pytest, jest, vitest, mocha")
formatter_cmds: list[str] = Field(default_factory=list, description="Formatter commands")
# Optional settings
ignore_paths: list[str] = Field(default_factory=list, description="Paths to ignore")
benchmarks_root: str | None = Field(default=None, description="Benchmarks directory")
# Git settings
git_remote: str = Field(default="origin", description="Git remote for PRs")
# Privacy settings
disable_telemetry: bool = Field(default=False, description="Disable telemetry")
# Python-specific settings
pytest_cmd: str = Field(default="pytest", description="Pytest command (Python only)")
disable_imports_sorting: bool = Field(default=False, description="Disable import sorting (Python only)")
override_fixtures: bool = Field(default=False, description="Override test fixtures (Python only)")
model_config = ConfigDict(extra="allow") # Allow extra fields for forward compatibility
def to_pyproject_dict(self) -> dict[str, Any]:
"""Convert to pyproject.toml [tool.codeflash] format.
Uses kebab-case keys as per TOML conventions.
Only includes non-default values to keep config minimal.
"""
config: dict[str, Any] = {}
# Always include required fields
config["module-root"] = self.module_root
if self.tests_root:
config["tests-root"] = self.tests_root
# Include non-default optional fields
if self.ignore_paths:
config["ignore-paths"] = self.ignore_paths
if self.formatter_cmds and self.formatter_cmds != ["black $file"]:
config["formatter-cmds"] = self.formatter_cmds
elif not self.formatter_cmds:
config["formatter-cmds"] = ["disabled"]
if self.benchmarks_root:
config["benchmarks-root"] = self.benchmarks_root
if self.git_remote and self.git_remote != "origin":
config["git-remote"] = self.git_remote
if self.disable_telemetry:
config["disable-telemetry"] = True
if self.pytest_cmd and self.pytest_cmd != "pytest":
config["pytest-cmd"] = self.pytest_cmd
if self.disable_imports_sorting:
config["disable-imports-sorting"] = True
if self.override_fixtures:
config["override-fixtures"] = True
return config
def to_package_json_dict(self) -> dict[str, Any]:
"""Convert to package.json codeflash section format.
Uses camelCase keys as per JSON/JS conventions.
Only includes values that override auto-detection.
"""
config: dict[str, Any] = {}
# Module root (only if not auto-detected default)
if self.module_root and self.module_root not in (".", "src"):
config["moduleRoot"] = self.module_root
if self.tests_root:
config["testsRoot"] = self.tests_root
# Formatter (only if explicitly set)
if self.formatter_cmds:
config["formatterCmds"] = self.formatter_cmds
# Ignore paths (only if set)
if self.ignore_paths:
config["ignorePaths"] = self.ignore_paths
# Benchmarks root
if self.benchmarks_root:
config["benchmarksRoot"] = self.benchmarks_root
# Git remote (only if not default)
if self.git_remote and self.git_remote != "origin":
config["gitRemote"] = self.git_remote
# Telemetry
if self.disable_telemetry:
config["disableTelemetry"] = True
return config
@classmethod
def from_detected_project(cls, detected: Any) -> CodeflashConfig:
"""Create config from DetectedProject.
Args:
detected: DetectedProject instance from detector.
Returns:
CodeflashConfig instance.
"""
return cls(
language=detected.language,
module_root=str(detected.module_root.relative_to(detected.project_root))
if detected.module_root != detected.project_root
else ".",
tests_root=str(detected.tests_root.relative_to(detected.project_root)) if detected.tests_root else None,
test_framework=detected.test_runner,
formatter_cmds=detected.formatter_cmds,
ignore_paths=[
str(p.relative_to(detected.project_root)) for p in detected.ignore_paths if p != detected.project_root
],
pytest_cmd=detected.test_runner if detected.language == "python" else "pytest",
)
@classmethod
def from_pyproject_dict(cls, data: dict[str, Any], project_root: Path | None = None) -> CodeflashConfig:
"""Create config from pyproject.toml [tool.codeflash] section.
Args:
data: Dict from [tool.codeflash] section.
project_root: Project root path (reserved for future path resolution).
Returns:
CodeflashConfig instance.
"""
_ = project_root # Reserved for future path resolution
def convert_key(key: str) -> str:
"""Convert kebab-case to snake_case."""
return key.replace("-", "_")
converted = {convert_key(k): v for k, v in data.items()}
converted.setdefault("language", "python")
return cls(**converted)
@classmethod
def from_package_json_dict(cls, data: dict[str, Any], project_root: Path | None = None) -> CodeflashConfig:
"""Create config from package.json codeflash section.
Args:
data: Dict from package.json "codeflash" key.
project_root: Project root path (reserved for future path resolution).
Returns:
CodeflashConfig instance.
"""
_ = project_root # Reserved for future path resolution
def convert_key(key: str) -> str:
"""Convert camelCase to snake_case."""
import re
return re.sub(r"(?<!^)(?=[A-Z])", "_", key).lower()
converted = {convert_key(k): v for k, v in data.items()}
converted.setdefault("language", "javascript")
return cls(**converted)

View file

@ -0,0 +1,246 @@
"""Config writer for native config files.
This module writes Codeflash configuration to native config files:
- Python: pyproject.toml [tool.codeflash]
- JavaScript/TypeScript: package.json { "codeflash": {} }
"""
from __future__ import annotations
import json
from typing import TYPE_CHECKING
import tomlkit
if TYPE_CHECKING:
from pathlib import Path
from codeflash.setup.config_schema import CodeflashConfig
from codeflash.setup.detector import DetectedProject
def write_config(detected: DetectedProject, config: CodeflashConfig | None = None) -> tuple[bool, str]:
"""Write Codeflash config to the appropriate native config file.
Args:
detected: DetectedProject with project information.
config: Optional CodeflashConfig to write. If None, creates from detected.
Returns:
Tuple of (success, message).
"""
from codeflash.setup.config_schema import CodeflashConfig
if config is None:
config = CodeflashConfig.from_detected_project(detected)
if detected.language == "python":
return _write_pyproject_toml(detected.project_root, config)
return _write_package_json(detected.project_root, config)
def _write_pyproject_toml(project_root: Path, config: CodeflashConfig) -> tuple[bool, str]:
"""Write config to pyproject.toml [tool.codeflash] section.
Creates pyproject.toml if it doesn't exist.
Preserves existing content and formatting.
Args:
project_root: Project root directory.
config: CodeflashConfig to write.
Returns:
Tuple of (success, message).
"""
pyproject_path = project_root / "pyproject.toml"
try:
# Load existing or create new
if pyproject_path.exists():
with pyproject_path.open("rb") as f:
doc = tomlkit.parse(f.read())
else:
doc = tomlkit.document()
# Ensure [tool] section exists
if "tool" not in doc:
doc["tool"] = tomlkit.table()
# Create codeflash section
codeflash_table = tomlkit.table()
codeflash_table.add(tomlkit.comment("Codeflash configuration - https://docs.codeflash.ai"))
# Add config values
config_dict = config.to_pyproject_dict()
for key, value in config_dict.items():
codeflash_table[key] = value
# Update the document
doc["tool"]["codeflash"] = codeflash_table
# Write back
with pyproject_path.open("w", encoding="utf8") as f:
f.write(tomlkit.dumps(doc))
return True, f"Config saved to {pyproject_path}"
except Exception as e:
return False, f"Failed to write pyproject.toml: {e}"
def _write_package_json(project_root: Path, config: CodeflashConfig) -> tuple[bool, str]:
"""Write config to package.json codeflash section.
Preserves existing content and formatting.
Creates minimal config (only non-default values).
Args:
project_root: Project root directory.
config: CodeflashConfig to write.
Returns:
Tuple of (success, message).
"""
package_json_path = project_root / "package.json"
if not package_json_path.exists():
return False, f"No package.json found at {project_root}"
try:
# Load existing
with package_json_path.open(encoding="utf8") as f:
doc = json.load(f)
# Get config dict (only non-default values)
config_dict = config.to_package_json_dict()
# Update or remove codeflash section
if config_dict:
doc["codeflash"] = config_dict
action = "Updated"
else:
# Remove codeflash section if empty (all defaults)
doc.pop("codeflash", None)
action = "Using auto-detected defaults (no config needed)"
# Write back with nice formatting
with package_json_path.open("w", encoding="utf8") as f:
json.dump(doc, f, indent=2)
f.write("\n") # Trailing newline
if config_dict:
return True, f"{action} config in {package_json_path}"
return True, action
except json.JSONDecodeError as e:
return False, f"Invalid JSON in package.json: {e}"
except Exception as e:
return False, f"Failed to write package.json: {e}"
def create_pyproject_toml(project_root: Path) -> tuple[bool, str]:
"""Create a minimal pyproject.toml file.
Used when no pyproject.toml exists for a Python project.
Args:
project_root: Project root directory.
Returns:
Tuple of (success, message).
"""
pyproject_path = project_root / "pyproject.toml"
if pyproject_path.exists():
return False, f"pyproject.toml already exists at {pyproject_path}"
try:
doc = tomlkit.document()
doc.add(tomlkit.comment("Created by Codeflash"))
doc.add(tomlkit.nl())
# Add minimal [tool.codeflash] section
tool_table = tomlkit.table()
codeflash_table = tomlkit.table()
codeflash_table.add(tomlkit.comment("Codeflash configuration - https://docs.codeflash.ai"))
tool_table["codeflash"] = codeflash_table
doc["tool"] = tool_table
with pyproject_path.open("w", encoding="utf8") as f:
f.write(tomlkit.dumps(doc))
return True, f"Created {pyproject_path}"
except Exception as e:
return False, f"Failed to create pyproject.toml: {e}"
def remove_config(project_root: Path, language: str) -> tuple[bool, str]:
"""Remove Codeflash config from native config file.
Args:
project_root: Project root directory.
language: Project language ("python", "javascript", "typescript").
Returns:
Tuple of (success, message).
"""
if language == "python":
return _remove_from_pyproject(project_root)
return _remove_from_package_json(project_root)
def _remove_from_pyproject(project_root: Path) -> tuple[bool, str]:
"""Remove [tool.codeflash] section from pyproject.toml."""
pyproject_path = project_root / "pyproject.toml"
if not pyproject_path.exists():
return True, "No pyproject.toml found"
try:
with pyproject_path.open("rb") as f:
doc = tomlkit.parse(f.read())
if "tool" in doc and "codeflash" in doc["tool"]:
del doc["tool"]["codeflash"]
with pyproject_path.open("w", encoding="utf8") as f:
f.write(tomlkit.dumps(doc))
return True, "Removed [tool.codeflash] section from pyproject.toml"
return True, "No codeflash config found in pyproject.toml"
except Exception as e:
return False, f"Failed to remove config: {e}"
def _remove_from_package_json(project_root: Path) -> tuple[bool, str]:
"""Remove codeflash section from package.json."""
package_json_path = project_root / "package.json"
if not package_json_path.exists():
return True, "No package.json found"
try:
with package_json_path.open(encoding="utf8") as f:
doc = json.load(f)
if "codeflash" in doc:
del doc["codeflash"]
with package_json_path.open("w", encoding="utf8") as f:
json.dump(doc, f, indent=2)
f.write("\n")
return True, "Removed codeflash section from package.json"
return True, "No codeflash config found in package.json"
except Exception as e:
return False, f"Failed to remove config: {e}"

692
codeflash/setup/detector.py Normal file
View file

@ -0,0 +1,692 @@
"""Universal project detection engine for Codeflash.
This module provides a single detection engine that works for all supported languages,
consolidating detection logic from various parts of the codebase.
Usage:
from codeflash.setup import detect_project
detected = detect_project()
print(f"Language: {detected.language}")
print(f"Module root: {detected.module_root}")
print(f"Test runner: {detected.test_runner}")
"""
from __future__ import annotations
import json
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any
import tomlkit
@dataclass
class DetectedProject:
"""Result of project auto-detection.
All paths are absolute. The confidence score indicates how certain
we are about the detection (0.0 = guessing, 1.0 = certain).
"""
# Core detection results
language: str # "python" | "javascript" | "typescript"
project_root: Path
module_root: Path
tests_root: Path | None
# Tooling detection
test_runner: str # "pytest" | "jest" | "vitest" | "mocha"
formatter_cmds: list[str]
# Ignore paths (absolute paths to ignore)
ignore_paths: list[Path] = field(default_factory=list)
# Confidence score for the detection (0.0 - 1.0)
confidence: float = 0.8
# Detection details (for debugging/display)
detection_details: dict[str, str] = field(default_factory=dict)
def to_display_dict(self) -> dict[str, str]:
"""Convert to dictionary for display purposes."""
formatter_display = self.formatter_cmds[0] if self.formatter_cmds else "none detected"
if len(self.formatter_cmds) > 1:
formatter_display += f" (+{len(self.formatter_cmds) - 1} more)"
ignore_display = ", ".join(p.name for p in self.ignore_paths[:3])
if len(self.ignore_paths) > 3:
ignore_display += f" (+{len(self.ignore_paths) - 3} more)"
return {
"Language": self.language.capitalize(),
"Module root": str(self.module_root.relative_to(self.project_root))
if self.module_root != self.project_root
else ".",
"Tests root": str(self.tests_root.relative_to(self.project_root)) if self.tests_root else "not detected",
"Test runner": self.test_runner,
"Formatter": formatter_display or "none",
"Ignoring": ignore_display or "defaults only",
}
def detect_project(path: Path | None = None) -> DetectedProject:
"""Auto-detect all project settings.
This is the main entry point for project detection. It finds the project root,
detects the language, and auto-detects all configuration values.
Args:
path: Starting path for detection. Defaults to current working directory.
Returns:
DetectedProject with all detected settings.
Raises:
ValueError: If no valid project can be detected.
"""
start_path = path or Path.cwd()
detection_details: dict[str, str] = {}
# Step 1: Find project root
project_root = _find_project_root(start_path)
if project_root is None:
# No project root found, use start_path
project_root = start_path
detection_details["project_root"] = "using current directory (no markers found)"
else:
detection_details["project_root"] = f"found at {project_root}"
# Step 2: Detect language
language, lang_confidence, lang_detail = _detect_language(project_root)
detection_details["language"] = lang_detail
# Step 3: Detect module root
module_root, module_detail = _detect_module_root(project_root, language)
detection_details["module_root"] = module_detail
# Step 4: Detect tests root
tests_root, tests_detail = _detect_tests_root(project_root, language)
detection_details["tests_root"] = tests_detail
# Step 5: Detect test runner
test_runner, runner_detail = _detect_test_runner(project_root, language)
detection_details["test_runner"] = runner_detail
# Step 6: Detect formatter
formatter_cmds, formatter_detail = _detect_formatter(project_root, language)
detection_details["formatter"] = formatter_detail
# Step 7: Detect ignore paths
ignore_paths, ignore_detail = _detect_ignore_paths(project_root, language)
detection_details["ignore_paths"] = ignore_detail
# Calculate overall confidence
confidence = lang_confidence * 0.4 + 0.6 # Language detection is 40% of confidence
return DetectedProject(
language=language,
project_root=project_root,
module_root=module_root,
tests_root=tests_root,
test_runner=test_runner,
formatter_cmds=formatter_cmds,
ignore_paths=ignore_paths,
confidence=confidence,
detection_details=detection_details,
)
def _find_project_root(start_path: Path) -> Path | None:
"""Find the project root by walking up the directory tree.
Looks for:
- .git directory (git repository root)
- pyproject.toml (Python project)
- package.json (JavaScript/TypeScript project)
- Cargo.toml (Rust project - future)
Args:
start_path: Starting directory for search.
Returns:
Path to project root, or None if not found.
"""
current = start_path.resolve()
while current != current.parent:
# Check for project markers
markers = [".git", "pyproject.toml", "package.json", "Cargo.toml"]
for marker in markers:
if (current / marker).exists():
return current
current = current.parent
return None
def _detect_language(project_root: Path) -> tuple[str, float, str]:
"""Detect the primary programming language of the project.
Detection priority:
1. tsconfig.json TypeScript (high confidence)
2. pyproject.toml or setup.py Python (high confidence)
3. package.json JavaScript (medium confidence)
4. File extension counting best guess (low confidence)
Args:
project_root: Root directory of the project.
Returns:
Tuple of (language, confidence, detail_string).
"""
has_tsconfig = (project_root / "tsconfig.json").exists()
has_pyproject = (project_root / "pyproject.toml").exists()
has_setup_py = (project_root / "setup.py").exists()
has_package_json = (project_root / "package.json").exists()
# TypeScript (tsconfig.json is definitive)
if has_tsconfig:
return "typescript", 1.0, "tsconfig.json found"
# Python (pyproject.toml or setup.py)
if has_pyproject or has_setup_py:
marker = "pyproject.toml" if has_pyproject else "setup.py"
# Check if it's also a JS project (monorepo)
if has_package_json:
# Count files to determine primary language
py_count = len(list(project_root.rglob("*.py")))
js_count = len(list(project_root.rglob("*.js"))) + len(list(project_root.rglob("*.ts")))
if js_count > py_count * 2: # JS files significantly outnumber Python
return "javascript", 0.7, "package.json found (more JS files than Python)"
return "python", 1.0, f"{marker} found"
# JavaScript (package.json without Python markers)
if has_package_json:
return "javascript", 0.9, "package.json found"
# Fall back to file extension counting
py_count = len(list(project_root.rglob("*.py")))
js_count = len(list(project_root.rglob("*.js")))
ts_count = len(list(project_root.rglob("*.ts")))
if ts_count > 0:
return "typescript", 0.5, f"found {ts_count} .ts files"
if js_count > py_count:
return "javascript", 0.5, f"found {js_count} .js files"
if py_count > 0:
return "python", 0.5, f"found {py_count} .py files"
# Default to Python
return "python", 0.3, "defaulting to Python"
def _detect_module_root(project_root: Path, language: str) -> tuple[Path, str]:
"""Detect the module/source root directory.
Args:
project_root: Root directory of the project.
language: Detected language.
Returns:
Tuple of (module_root_path, detail_string).
"""
if language in ("javascript", "typescript"):
return _detect_js_module_root(project_root)
return _detect_python_module_root(project_root)
def _detect_python_module_root(project_root: Path) -> tuple[Path, str]:
"""Detect Python module root.
Priority:
1. pyproject.toml [tool.poetry.name] or [project.name]
2. src/ directory with __init__.py
3. Directory with __init__.py matching project name
4. src/ directory (even without __init__.py)
5. Project root
"""
# Try to get project name from pyproject.toml
pyproject_path = project_root / "pyproject.toml"
project_name = None
if pyproject_path.exists():
try:
with pyproject_path.open("rb") as f:
data = tomlkit.parse(f.read())
# Try poetry name
project_name = data.get("tool", {}).get("poetry", {}).get("name")
# Try standard project name
if not project_name:
project_name = data.get("project", {}).get("name")
except Exception:
pass
# Check for src layout
src_dir = project_root / "src"
if src_dir.is_dir():
# Check for package inside src
if project_name:
pkg_dir = src_dir / project_name
if pkg_dir.is_dir() and (pkg_dir / "__init__.py").exists():
return pkg_dir, f"src/{project_name}/ (from pyproject.toml name)"
# Check for any package in src
for child in src_dir.iterdir():
if child.is_dir() and (child / "__init__.py").exists():
return child, f"src/{child.name}/ (first package in src)"
# Use src/ even without __init__.py
return src_dir, "src/ directory"
# Check for package at project root
if project_name:
pkg_dir = project_root / project_name
if pkg_dir.is_dir() and (pkg_dir / "__init__.py").exists():
return pkg_dir, f"{project_name}/ (from pyproject.toml name)"
# Look for any directory with __init__.py at project root
for child in project_root.iterdir():
if (
child.is_dir()
and not child.name.startswith(".")
and child.name not in ("tests", "test", "docs", "venv", ".venv", "env", "node_modules")
):
if (child / "__init__.py").exists():
return child, f"{child.name}/ (has __init__.py)"
# Default to project root
return project_root, "project root (no package structure detected)"
def _detect_js_module_root(project_root: Path) -> tuple[Path, str]:
"""Detect JavaScript/TypeScript module root.
Priority:
1. package.json "exports" field
2. package.json "module" field (ESM)
3. package.json "main" field (CJS)
4. src/ directory
5. lib/ directory
6. Project root
"""
package_json_path = project_root / "package.json"
package_data: dict[str, Any] = {}
if package_json_path.exists():
try:
with package_json_path.open(encoding="utf8") as f:
package_data = json.load(f)
except (json.JSONDecodeError, OSError):
pass
# Check exports field (modern Node.js)
exports = package_data.get("exports")
if exports:
entry_path = _extract_entry_path(exports)
if entry_path:
parent = Path(entry_path).parent
if parent != Path() and parent.as_posix() != "." and (project_root / parent).is_dir():
return project_root / parent, f'{parent.as_posix()}/ (from package.json "exports")'
# Check module field (ESM)
module_field = package_data.get("module")
if module_field and isinstance(module_field, str):
parent = Path(module_field).parent
if parent != Path() and parent.as_posix() != "." and (project_root / parent).is_dir():
return project_root / parent, f'{parent.as_posix()}/ (from package.json "module")'
# Check main field (CJS)
main_field = package_data.get("main")
if main_field and isinstance(main_field, str):
parent = Path(main_field).parent
if parent != Path() and parent.as_posix() != "." and (project_root / parent).is_dir():
return project_root / parent, f'{parent.as_posix()}/ (from package.json "main")'
# Check for common source directories
for src_dir in ["src", "lib", "source"]:
if (project_root / src_dir).is_dir():
return project_root / src_dir, f"{src_dir}/ directory"
# Default to project root
return project_root, "project root"
def _extract_entry_path(exports: Any) -> str | None:
"""Extract entry path from package.json exports field."""
if isinstance(exports, str):
return exports
if isinstance(exports, dict):
# Handle {"." : "./src/index.js"} or {".": {"import": "./src/index.js"}}
main_export = exports.get(".") or exports.get("import") or exports.get("default")
if isinstance(main_export, str):
return main_export
if isinstance(main_export, dict):
return main_export.get("import") or main_export.get("default") or main_export.get("require")
return None
def _detect_tests_root(project_root: Path, language: str) -> tuple[Path | None, str]:
"""Detect the tests directory.
Common patterns:
- tests/ or test/
- __tests__/ (JavaScript)
- spec/ (Ruby/JavaScript)
"""
# Common test directory names
test_dirs = ["tests", "test", "__tests__", "spec"]
for test_dir in test_dirs:
test_path = project_root / test_dir
if test_path.is_dir():
return test_path, f"{test_dir}/ directory"
# For Python, check if tests are alongside source
if language == "python":
# Look for test_*.py files in project root
test_files = list(project_root.glob("test_*.py"))
if test_files:
return project_root, "test files in project root"
# For JS/TS, check for *.test.js or *.spec.js files
if language in ("javascript", "typescript"):
test_patterns = ["*.test.js", "*.test.ts", "*.spec.js", "*.spec.ts"]
for pattern in test_patterns:
test_files = list(project_root.rglob(pattern))
if test_files:
# Find common parent
return project_root, f"found {pattern} files"
return None, "not detected"
def _detect_test_runner(project_root: Path, language: str) -> tuple[str, str]:
"""Detect the test runner.
Python: pytest > unittest
JavaScript: vitest > jest > mocha
"""
if language in ("javascript", "typescript"):
return _detect_js_test_runner(project_root)
return _detect_python_test_runner(project_root)
def _detect_python_test_runner(project_root: Path) -> tuple[str, str]:
"""Detect Python test runner."""
# Check for pytest markers
pytest_markers = ["pytest.ini", "pyproject.toml", "conftest.py", "setup.cfg"]
for marker in pytest_markers:
marker_path = project_root / marker
if marker_path.exists():
if marker == "pyproject.toml":
# Check for [tool.pytest] section
try:
with marker_path.open("rb") as f:
data = tomlkit.parse(f.read())
if "tool" in data and "pytest" in data["tool"]:
return "pytest", "pyproject.toml [tool.pytest]"
except Exception:
pass
elif marker == "conftest.py":
return "pytest", "conftest.py found"
elif marker in ("pytest.ini", "setup.cfg"):
# Check for pytest section in setup.cfg
if marker == "setup.cfg":
try:
content = marker_path.read_text(encoding="utf8")
if "[tool:pytest]" in content or "[pytest]" in content:
return "pytest", "setup.cfg [pytest]"
except Exception:
pass
else:
return "pytest", "pytest.ini found"
# Default to pytest (most common)
return "pytest", "default"
def _detect_js_test_runner(project_root: Path) -> tuple[str, str]:
"""Detect JavaScript test runner."""
package_json_path = project_root / "package.json"
if not package_json_path.exists():
return "jest", "default (no package.json)"
try:
with package_json_path.open(encoding="utf8") as f:
package_data = json.load(f)
except (json.JSONDecodeError, OSError):
return "jest", "default (invalid package.json)"
runners = ["vitest", "jest", "mocha"]
dev_deps = package_data.get("devDependencies", {})
deps = package_data.get("dependencies", {})
all_deps = {**deps, **dev_deps}
# Check dependencies (order matters - prefer more modern runners)
for runner in runners:
if runner in all_deps:
return runner, "from devDependencies"
# Parse scripts.test for hints
scripts = package_data.get("scripts", {})
test_script = scripts.get("test", "")
if isinstance(test_script, str):
test_lower = test_script.lower()
for runner in runners:
if runner in test_lower:
return runner, "from scripts.test"
# Check for config files
config_files = {
"vitest": ["vitest.config.js", "vitest.config.ts", "vitest.config.mjs"],
"jest": ["jest.config.js", "jest.config.ts", "jest.config.mjs", "jest.config.json"],
"mocha": [".mocharc.js", ".mocharc.json", ".mocharc.yaml"],
}
for runner, configs in config_files.items():
for config in configs:
if (project_root / config).exists():
return runner, f"{config} found"
return "jest", "default"
def _detect_formatter(project_root: Path, language: str) -> tuple[list[str], str]:
"""Detect code formatter.
Python: ruff > black
JavaScript: prettier > eslint --fix
"""
if language in ("javascript", "typescript"):
return _detect_js_formatter(project_root)
return _detect_python_formatter(project_root)
def _detect_python_formatter(project_root: Path) -> tuple[list[str], str]:
"""Detect Python formatter."""
pyproject_path = project_root / "pyproject.toml"
if pyproject_path.exists():
try:
with pyproject_path.open("rb") as f:
data = tomlkit.parse(f.read())
tool = data.get("tool", {})
# Check for ruff
if "ruff" in tool:
return ["ruff check --exit-zero --fix $file", "ruff format $file"], "from pyproject.toml [tool.ruff]"
# Check for black
if "black" in tool:
return ["black $file"], "from pyproject.toml [tool.black]"
except Exception:
pass
# Check for config files
if (project_root / "ruff.toml").exists() or (project_root / ".ruff.toml").exists():
return ["ruff check --exit-zero --fix $file", "ruff format $file"], "ruff.toml found"
if (project_root / ".black").exists() or (project_root / "pyproject.toml").exists():
# Default to black if pyproject.toml exists (common setup)
return ["black $file"], "default (black)"
return [], "none detected"
def _detect_js_formatter(project_root: Path) -> tuple[list[str], str]:
"""Detect JavaScript formatter."""
package_json_path = project_root / "package.json"
# Check for prettier config files
prettier_configs = [".prettierrc", ".prettierrc.js", ".prettierrc.json", "prettier.config.js"]
for config in prettier_configs:
if (project_root / config).exists():
return ["npx prettier --write $file"], f"{config} found"
# Check for eslint config files
eslint_configs = [".eslintrc", ".eslintrc.js", ".eslintrc.json", "eslint.config.js"]
for config in eslint_configs:
if (project_root / config).exists():
return ["npx eslint --fix $file"], f"{config} found"
# Check package.json dependencies
if package_json_path.exists():
try:
with package_json_path.open(encoding="utf8") as f:
package_data = json.load(f)
dev_deps = package_data.get("devDependencies", {})
deps = package_data.get("dependencies", {})
all_deps = {**deps, **dev_deps}
if "prettier" in all_deps:
return ["npx prettier --write $file"], "from devDependencies"
if "eslint" in all_deps:
return ["npx eslint --fix $file"], "from devDependencies"
except (json.JSONDecodeError, OSError):
pass
return [], "none detected"
def _detect_ignore_paths(project_root: Path, language: str) -> tuple[list[Path], str]:
"""Detect paths to ignore during optimization.
Sources:
1. .gitignore
2. Language-specific defaults
"""
ignore_paths: list[Path] = []
sources: list[str] = []
# Default ignore patterns by language
default_ignores: dict[str, list[str]] = {
"python": [
"__pycache__",
".pytest_cache",
".mypy_cache",
".ruff_cache",
"venv",
".venv",
"env",
".env",
"dist",
"build",
"*.egg-info",
".tox",
".nox",
"htmlcov",
".coverage",
],
"javascript": ["node_modules", "dist", "build", ".next", ".nuxt", "coverage", ".cache"],
"typescript": ["node_modules", "dist", "build", ".next", ".nuxt", "coverage", ".cache"],
}
# Add default ignores
for pattern in default_ignores.get(language, []):
path = project_root / pattern.replace("*", "")
if path.exists():
ignore_paths.append(path)
if ignore_paths:
sources.append("defaults")
# Parse .gitignore
gitignore_path = project_root / ".gitignore"
if gitignore_path.exists():
try:
content = gitignore_path.read_text(encoding="utf8")
for line in content.splitlines():
line = line.strip()
# Skip comments and empty lines
if not line or line.startswith("#"):
continue
# Skip negation patterns
if line.startswith("!"):
continue
# Convert gitignore pattern to path
pattern = line.rstrip("/").lstrip("/")
# Skip complex patterns for now
if "*" in pattern or "?" in pattern:
continue
path = project_root / pattern
if path.exists() and path not in ignore_paths:
ignore_paths.append(path)
if ".gitignore" not in sources:
sources.append(".gitignore")
except Exception:
pass
detail = " + ".join(sources) if sources else "none"
return ignore_paths, detail
def has_existing_config(project_root: Path) -> tuple[bool, str | None]:
"""Check if project has existing Codeflash configuration.
Args:
project_root: Root directory of the project.
Returns:
Tuple of (has_config, config_file_type).
config_file_type is "pyproject.toml", "package.json", or None.
"""
# Check pyproject.toml
pyproject_path = project_root / "pyproject.toml"
if pyproject_path.exists():
try:
with pyproject_path.open("rb") as f:
data = tomlkit.parse(f.read())
if "tool" in data and "codeflash" in data["tool"]:
return True, "pyproject.toml"
except Exception:
pass
# Check package.json
package_json_path = project_root / "package.json"
if package_json_path.exists():
try:
with package_json_path.open(encoding="utf8") as f:
data = json.load(f)
if "codeflash" in data:
return True, "package.json"
except Exception:
pass
return False, None

View file

@ -0,0 +1,294 @@
"""First-run experience for Codeflash.
This module handles the seamless first-run experience:
1. Auto-detect project settings
2. Display detected settings
3. Quick confirmation
4. API key setup
5. Save config and continue
Usage:
from codeflash.setup.first_run import handle_first_run, is_first_run
if is_first_run():
args = handle_first_run(args)
"""
from __future__ import annotations
import os
import sys
from typing import TYPE_CHECKING
from rich.panel import Panel
from rich.table import Table
from rich.text import Text
from codeflash.cli_cmds.console import console
from codeflash.setup.config_writer import write_config
from codeflash.setup.detector import detect_project, has_existing_config
if TYPE_CHECKING:
from argparse import Namespace
from pathlib import Path
def is_first_run(project_root: Path | None = None) -> bool:
"""Check if this is the first run (no config exists).
Args:
project_root: Project root to check. Defaults to auto-detect.
Returns:
True if no Codeflash config exists.
"""
if project_root is None:
try:
detected = detect_project()
project_root = detected.project_root
except Exception:
return True
has_config, _ = has_existing_config(project_root)
return not has_config
def handle_first_run(
args: Namespace | None = None, skip_confirm: bool = False, skip_api_key: bool = False
) -> Namespace | None:
"""Handle the first-run experience with auto-detection and quick confirm.
This is the main entry point for the frictionless setup experience.
Args:
args: Optional CLI args namespace to update.
skip_confirm: Skip confirmation prompt (--yes flag).
skip_api_key: Skip API key prompt.
Returns:
Updated args namespace with detected settings, or None if user cancelled.
"""
from argparse import Namespace
# Auto-detect project
try:
detected = detect_project()
except Exception as e:
_show_detection_error(str(e))
return None
# Show welcome message
_show_welcome()
# Show detected settings
_show_detected_settings(detected)
# Get user confirmation
if not skip_confirm:
choice = _prompt_confirmation()
if choice == "n":
_show_cancelled()
return None
if choice == "customize":
# TODO: Implement customize flow (redirect to codeflash init)
console.print("\n💡 Run [cyan]codeflash init[/cyan] for full customization.\n")
return None
# Handle API key
if not skip_api_key:
api_key_ok = _handle_api_key()
if not api_key_ok:
return None
# Save config
success, message = write_config(detected)
if success:
console.print(f"\n{message}\n")
else:
console.print(f"\n⚠️ {message}\n")
console.print("Continuing with detected settings (not saved).\n")
# Create/update args namespace
if args is None:
args = Namespace()
# Populate args with detected values
args.module_root = str(detected.module_root)
args.tests_root = str(detected.tests_root) if detected.tests_root else None
args.project_root = str(detected.project_root)
args.formatter_cmds = detected.formatter_cmds
args.ignore_paths = [str(p) for p in detected.ignore_paths]
args.pytest_cmd = detected.test_runner
args.language = detected.language
# Set defaults for other common args
if not hasattr(args, "disable_telemetry"):
args.disable_telemetry = False
if not hasattr(args, "git_remote"):
args.git_remote = "origin"
return args
def _show_welcome() -> None:
"""Show welcome message for first-time users."""
welcome_panel = Panel(
Text(
"⚡ Welcome to Codeflash!\n\nI've auto-detected your project settings.\nThis will only take a moment.",
style="bold cyan",
justify="center",
),
title="🚀 First-Time Setup",
border_style="bright_cyan",
padding=(1, 2),
)
console.print(welcome_panel)
console.print()
def _show_detected_settings(detected: detect_project) -> None:
"""Display detected settings in a nice table."""
from codeflash.setup.detector import DetectedProject
if not isinstance(detected, DetectedProject):
return
# Create settings table
table = Table(show_header=False, box=None, padding=(0, 2))
table.add_column("Setting", style="cyan", width=15)
table.add_column("Value", style="green")
table.add_column("Source", style="dim")
display_dict = detected.to_display_dict()
details = detected.detection_details
for key, value in display_dict.items():
source = details.get(key.lower().replace(" ", "_"), "")
# Truncate long sources
if len(source) > 30:
source = source[:27] + "..."
table.add_row(key, value, f"({source})" if source else "")
settings_panel = Panel(table, title="🔍 Auto-Detected Settings", border_style="bright_blue", padding=(1, 2))
console.print(settings_panel)
console.print()
def _prompt_confirmation() -> str:
"""Prompt user for confirmation.
Returns:
"y" for yes, "n" for no, "customize" for customization.
"""
# Check if we're in a non-interactive environment
if not sys.stdin.isatty():
console.print("⚠️ Non-interactive environment detected. Use --yes to skip confirmation.")
return "n"
console.print("? [bold]Proceed with these settings?[/bold]")
console.print(" [green]Y[/green] - Yes, save and continue")
console.print(" [yellow]n[/yellow] - No, cancel")
console.print(" [cyan]c[/cyan] - Customize (run full setup)")
console.print()
try:
choice = console.input("[bold]Your choice[/bold] [green][Y][/green]/n/c: ").strip().lower()
except (KeyboardInterrupt, EOFError):
return "n"
if choice in ("", "y", "yes"):
return "y"
if choice in ("c", "customize"):
return "customize"
return "n"
def _handle_api_key() -> bool:
"""Handle API key setup if not already configured.
Returns:
True if API key is available, False if user cancelled.
"""
from codeflash.code_utils.env_utils import get_codeflash_api_key
# Check for existing API key
try:
existing_key = get_codeflash_api_key()
if existing_key:
display_key = f"{existing_key[:3]}****{existing_key[-4:]}"
console.print(f"✅ Found API key: [green]{display_key}[/green]\n")
return True
except OSError:
pass
# Prompt for API key
console.print("🔑 [bold]API Key Required[/bold]")
console.print(" Get your API key at: [cyan]https://app.codeflash.ai/app/apikeys[/cyan]\n")
try:
api_key = console.input(" Enter API key (or press Enter to open browser): ").strip()
except (KeyboardInterrupt, EOFError):
return False
if not api_key:
# Open browser
import click
click.launch("https://app.codeflash.ai/app/apikeys")
console.print("\n Opening browser...")
try:
api_key = console.input(" Enter API key: ").strip()
except (KeyboardInterrupt, EOFError):
return False
if not api_key:
console.print("\n⚠️ API key required. Run [cyan]codeflash init[/cyan] to set up.\n")
return False
if not api_key.startswith("cf-"):
console.print("\n⚠️ Invalid API key format. Should start with 'cf-'.\n")
return False
# Save API key to environment
os.environ["CODEFLASH_API_KEY"] = api_key
# Try to save to shell rc
try:
from codeflash.code_utils.shell_utils import save_api_key_to_rc
from codeflash.either import is_successful
result = save_api_key_to_rc(api_key)
if is_successful(result):
console.print(f"\n✅ API key saved. {result.unwrap()}\n")
else:
console.print(f"\n⚠️ Could not save to shell: {result.failure()}")
console.print(" API key set for this session only.\n")
except Exception:
console.print("\n✅ API key set for this session.\n")
return True
def _show_detection_error(error: str) -> None:
"""Show error message when detection fails."""
error_panel = Panel(
Text(
f"❌ Could not auto-detect project settings.\n\n"
f"Error: {error}\n\n"
"Please run [cyan]codeflash init[/cyan] for manual setup.",
style="red",
),
title="⚠️ Detection Failed",
border_style="red",
padding=(1, 2),
)
console.print(error_panel)
def _show_cancelled() -> None:
"""Show cancellation message."""
console.print("\n⏹️ Setup cancelled. Run [cyan]codeflash init[/cyan] when ready.\n")

View file

@ -672,7 +672,7 @@ def parse_jest_test_xml(
test_results = TestResults()
if not test_xml_file_path.exists():
logger.warning(f"No Jest test results for {test_xml_file_path} found.")
logger.warning(f"No JavaScript test results for {test_xml_file_path} found.")
return test_results
# Log file size for debugging
@ -1486,6 +1486,9 @@ def parse_test_results(
get_run_tmp_file(Path("unittest_results.xml")).unlink(missing_ok=True)
get_run_tmp_file(Path("jest_results.xml")).unlink(missing_ok=True)
get_run_tmp_file(Path("jest_perf_results.xml")).unlink(missing_ok=True)
get_run_tmp_file(Path("vitest_results.xml")).unlink(missing_ok=True)
get_run_tmp_file(Path("vitest_perf_results.xml")).unlink(missing_ok=True)
get_run_tmp_file(Path("vitest_line_profile_results.xml")).unlink(missing_ok=True)
# For Jest tests, SQLite cleanup is deferred until after comparison
# (comparison happens via language_support.compare_test_results)

View file

@ -72,29 +72,23 @@ def generate_tests(
from codeflash.languages.javascript.module_system import ensure_module_system_compatibility
source_file = Path(function_to_optimize.file_path)
func_name = function_to_optimize.function_name
qualified_name = function_to_optimize.qualified_name
# First validate and fix import styles
generated_test_source = validate_and_fix_import_style(generated_test_source, source_file, func_name)
# Validate and fix import styles (default vs named exports)
generated_test_source = validate_and_fix_import_style(
generated_test_source, source_file, function_to_optimize.function_name
)
# Convert module system if needed (e.g., CommonJS -> ESM for ESM projects)
generated_test_source = ensure_module_system_compatibility(generated_test_source, project_module_system)
# Instrument for behavior verification (writes to SQLite)
instrumented_behavior_test_source = instrument_generated_js_test(
test_code=generated_test_source,
function_name=func_name,
qualified_name=qualified_name,
mode=TestingMode.BEHAVIOR,
test_code=generated_test_source, function_to_optimize=function_to_optimize, mode=TestingMode.BEHAVIOR
)
# Instrument for performance measurement (prints to stdout)
instrumented_perf_test_source = instrument_generated_js_test(
test_code=generated_test_source,
function_name=func_name,
qualified_name=qualified_name,
mode=TestingMode.PERFORMANCE,
test_code=generated_test_source, function_to_optimize=function_to_optimize, mode=TestingMode.PERFORMANCE
)
logger.debug(f"Instrumented JS/TS tests locally for {func_name}")

File diff suppressed because it is too large Load diff

View file

@ -1,6 +1,6 @@
{
"name": "codeflash",
"version": "0.3.0",
"version": "0.5.0",
"description": "Codeflash - AI-powered code optimization for JavaScript and TypeScript",
"main": "runtime/index.js",
"types": "runtime/index.d.ts",
@ -9,7 +9,7 @@
"codeflash-setup": "./bin/codeflash-setup.js"
},
"publishConfig": {
"access": "public"
"access": "public"
},
"exports": {
".": {
@ -52,7 +52,8 @@
"typescript",
"ai",
"cli",
"jest"
"jest",
"vitest"
],
"author": "Codeflash AI",
"license": "MIT",
@ -70,7 +71,8 @@
},
"peerDependencies": {
"jest": ">=27.0.0",
"jest-runner": ">=27.0.0"
"jest-runner": ">=27.0.0",
"vitest": ">=1.0.0"
},
"peerDependenciesMeta": {
"jest": {
@ -78,12 +80,13 @@
},
"jest-runner": {
"optional": true
},
"vitest": {
"optional": true
}
},
"dependencies": {
"better-sqlite3": "^12.0.0",
"@msgpack/msgpack": "^3.0.0",
"jest-runner": "^29.7.0",
"jest-junit": "^16.0.0"
"@msgpack/msgpack": "^3.0.0"
}
}

View file

@ -646,9 +646,16 @@ function capturePerf(funcName, lineId, fn, ...args) {
let lastReturnValue;
let lastError = null;
// Batched looping: run BATCH_SIZE loops per capturePerf call
// This ensures fair distribution across all test invocations
const batchSize = shouldLoop ? PERF_BATCH_SIZE : 1;
// Determine if we're running with external loop-runner (Jest) or internal looping (Vitest)
// loop-runner sets CODEFLASH_PERF_CURRENT_BATCH before each batch
// If not set, we're in Vitest mode and need to do all loops internally
const hasExternalLoopRunner = process.env.CODEFLASH_PERF_CURRENT_BATCH !== undefined;
// Batched looping: run BATCH_SIZE loops per capturePerf call when using loop-runner
// For Vitest (no loop-runner), do all loops internally in a single call
const batchSize = shouldLoop
? (hasExternalLoopRunner ? PERF_BATCH_SIZE : PERF_LOOP_COUNT)
: 1;
for (let batchIndex = 0; batchIndex < batchSize; batchIndex++) {
// Check shared time limit BEFORE each iteration

View file

@ -20,6 +20,10 @@
*
* Usage:
* npx jest --runner=codeflash/loop-runner
*
* NOTE: This runner requires jest-runner to be installed in your project.
* It is a Jest-specific feature and does not work with Vitest.
* For Vitest projects, capturePerf() does all loops internally in a single call.
*/
'use strict';
@ -27,12 +31,27 @@
const { createRequire } = require('module');
const path = require('path');
const jestRunnerPath = require.resolve('jest-runner');
const internalRequire = createRequire(jestRunnerPath);
const runTest = internalRequire('./runTest').default;
// Try to load jest-runner - it's a peer dependency that must be installed by the user
let runTest;
let jestRunnerAvailable = false;
try {
const jestRunnerPath = require.resolve('jest-runner');
const internalRequire = createRequire(jestRunnerPath);
runTest = internalRequire('./runTest').default;
jestRunnerAvailable = true;
} catch (e) {
// jest-runner not installed - this is expected for Vitest projects
// The runner will throw a helpful error if someone tries to use it without jest-runner
jestRunnerAvailable = false;
}
// Configuration
const MAX_BATCHES = parseInt(process.env.CODEFLASH_PERF_LOOP_COUNT || '10000', 10);
const PERF_LOOP_COUNT = parseInt(process.env.CODEFLASH_PERF_LOOP_COUNT || '10000', 10);
const PERF_BATCH_SIZE = parseInt(process.env.CODEFLASH_PERF_BATCH_SIZE || '10', 10);
// MAX_BATCHES = how many batches needed to reach PERF_LOOP_COUNT iterations
// Add 1 to handle any rounding, but cap at PERF_LOOP_COUNT to avoid excessive batches
const MAX_BATCHES = Math.min(Math.ceil(PERF_LOOP_COUNT / PERF_BATCH_SIZE) + 1, PERF_LOOP_COUNT);
const TARGET_DURATION_MS = parseInt(process.env.CODEFLASH_PERF_TARGET_DURATION_MS || '10000', 10);
const MIN_BATCHES = parseInt(process.env.CODEFLASH_PERF_MIN_LOOPS || '5', 10);
@ -90,6 +109,14 @@ function deepCopy(obj, seen = new WeakMap()) {
*/
class CodeflashLoopRunner {
constructor(globalConfig, context) {
if (!jestRunnerAvailable) {
throw new Error(
'codeflash/loop-runner requires jest-runner to be installed.\n' +
'Please install it: npm install --save-dev jest-runner\n\n' +
'If you are using Vitest, the loop-runner is not needed - ' +
'Vitest projects use external looping handled by the Python runner.'
);
}
this._globalConfig = globalConfig;
this._context = context || {};
this._eventEmitter = new SimpleEventEmitter();

View file

@ -115,7 +115,7 @@ function installCodeflash(uvBin) {
try {
// Use uv tool install to install codeflash in an isolated environment
// This avoids conflicts with any existing Python environments
execSync(`"${uvBin}" tool install codeflash --force`, {
execSync(`"${uvBin}" tool install --force --python python3.12 codeflash`, {
stdio: 'inherit',
shell: true,
});

View file

@ -488,7 +488,7 @@ class TestParsePackageJsonConfig:
assert result is not None
config, path = result
assert config["language"] == "javascript"
assert config["test_runner"] == "jest"
assert config["test_framework"] == "jest"
assert config["pytest_cmd"] == "jest"
assert path == package_json
@ -728,7 +728,7 @@ class TestRealWorldPackageJsonExamples:
config, _ = result
assert config["language"] == "typescript"
assert config["module_root"] == str((tmp_path / "src").resolve())
assert config["test_runner"] == "jest"
assert config["test_framework"] == "jest"
assert config["formatter_cmds"] == ["npx prettier --write $file"]
def test_vite_react_project(self, tmp_path: Path) -> None:
@ -752,7 +752,7 @@ class TestRealWorldPackageJsonExamples:
assert result is not None
config, _ = result
assert config["language"] == "typescript"
assert config["test_runner"] == "vitest"
assert config["test_framework"] == "vitest"
assert config["formatter_cmds"] == ["npx eslint --fix $file"]
def test_library_with_exports(self, tmp_path: Path) -> None:
@ -812,7 +812,7 @@ class TestRealWorldPackageJsonExamples:
assert result is not None
config, _ = result
assert config["module_root"] == str((tmp_path / "lib").resolve())
assert config["test_runner"] == "mocha"
assert config["test_framework"] == "mocha"
def test_minimal_project(self, tmp_path: Path) -> None:
"""Should handle minimal package.json."""
@ -825,7 +825,7 @@ class TestRealWorldPackageJsonExamples:
config, _ = result
assert config["language"] == "javascript"
assert config["module_root"] == str(tmp_path.resolve())
assert config["test_runner"] == "jest"
assert config["test_framework"] == "jest"
assert config["formatter_cmds"] == []
def test_existing_codeflash_config_with_overrides(self, tmp_path: Path) -> None:
@ -855,3 +855,143 @@ class TestRealWorldPackageJsonExamples:
assert config["formatter_cmds"] == ["npx prettier --write --single-quote $file"]
assert len(config["ignore_paths"]) == 2
assert config["disable_telemetry"] is True
class TestTestFrameworkConfigOverride:
"""Tests for explicit test-framework config override (matches Python's pyproject.toml)."""
def test_test_framework_overrides_auto_detection(self, tmp_path: Path) -> None:
"""Should use test-framework from codeflash config instead of auto-detecting from devDependencies."""
package_json = tmp_path / "package.json"
package_json.write_text(
json.dumps(
{
"name": "test-project",
"devDependencies": {"vitest": "^1.0.0"},
"codeflash": {"test-framework": "jest"},
}
)
)
result = parse_package_json_config(package_json)
assert result is not None
config, _ = result
assert config["test_framework"] == "jest"
assert config["pytest_cmd"] == "jest"
def test_explicit_vitest_config_with_jest_in_deps(self, tmp_path: Path) -> None:
"""Should use explicit vitest config even when jest is in devDependencies."""
package_json = tmp_path / "package.json"
package_json.write_text(
json.dumps(
{
"name": "test-project",
"devDependencies": {"jest": "^29.0.0", "vitest": "^1.0.0"},
"codeflash": {"test-framework": "vitest"},
}
)
)
result = parse_package_json_config(package_json)
assert result is not None
config, _ = result
assert config["test_framework"] == "vitest"
def test_explicit_mocha_overrides_vitest_and_jest(self, tmp_path: Path) -> None:
"""Should use explicit mocha config even when vitest and jest are in devDependencies."""
package_json = tmp_path / "package.json"
package_json.write_text(
json.dumps(
{
"name": "test-project",
"devDependencies": {"vitest": "^1.0.0", "jest": "^29.0.0"},
"codeflash": {"test-framework": "mocha"},
}
)
)
result = parse_package_json_config(package_json)
assert result is not None
config, _ = result
assert config["test_framework"] == "mocha"
def test_auto_detection_when_no_explicit_config(self, tmp_path: Path) -> None:
"""Should auto-detect test framework when no explicit config is provided."""
package_json = tmp_path / "package.json"
package_json.write_text(
json.dumps(
{
"name": "test-project",
"devDependencies": {"vitest": "^1.0.0"},
"codeflash": {"moduleRoot": "src"},
}
)
)
result = parse_package_json_config(package_json)
assert result is not None
config, _ = result
assert config["test_framework"] == "vitest"
def test_empty_test_framework_falls_back_to_auto_detection(self, tmp_path: Path) -> None:
"""Should auto-detect when test-framework is empty string."""
package_json = tmp_path / "package.json"
package_json.write_text(
json.dumps(
{
"name": "test-project",
"devDependencies": {"jest": "^29.0.0"},
"codeflash": {"test-framework": ""},
}
)
)
result = parse_package_json_config(package_json)
assert result is not None
config, _ = result
assert config["test_framework"] == "jest"
def test_custom_test_framework_value(self, tmp_path: Path) -> None:
"""Should accept custom test framework values not in the standard list."""
package_json = tmp_path / "package.json"
package_json.write_text(
json.dumps(
{
"name": "test-project",
"devDependencies": {"vitest": "^1.0.0"},
"codeflash": {"test-framework": "ava"},
}
)
)
result = parse_package_json_config(package_json)
assert result is not None
config, _ = result
assert config["test_framework"] == "ava"
def test_pytest_cmd_matches_test_framework_with_override(self, tmp_path: Path) -> None:
"""Should set pytest_cmd to match test_framework when using explicit config."""
package_json = tmp_path / "package.json"
package_json.write_text(
json.dumps(
{
"name": "test-project",
"devDependencies": {"vitest": "^1.0.0"},
"codeflash": {"test-framework": "jest"},
}
)
)
result = parse_package_json_config(package_json)
assert result is not None
config, _ = result
assert config["test_framework"] == "jest"
assert config["pytest_cmd"] == "jest"
assert config["test_framework"] == config["pytest_cmd"]

View file

View file

@ -0,0 +1,383 @@
"""Tests for JavaScript/TypeScript support.py test framework dispatch logic.
These tests verify that run_behavioral_tests, run_benchmarking_tests, and
run_line_profile_tests correctly dispatch to Jest or Vitest based on the
test_framework parameter or singleton.
"""
from __future__ import annotations
import tempfile
from pathlib import Path
from unittest.mock import MagicMock, patch
import pytest
from codeflash.languages.javascript.support import JavaScriptSupport, TypeScriptSupport
from codeflash.languages.test_framework import reset_test_framework, set_current_test_framework
from codeflash.models.models import TestFile, TestFiles, TestType
@pytest.fixture
def js_support() -> JavaScriptSupport:
"""Create a JavaScriptSupport instance."""
return JavaScriptSupport()
@pytest.fixture
def ts_support() -> TypeScriptSupport:
"""Create a TypeScriptSupport instance."""
return TypeScriptSupport()
@pytest.fixture(autouse=True)
def reset_singleton():
"""Reset the test framework singleton before each test."""
reset_test_framework()
yield
reset_test_framework()
def create_test_files(tmp_path: Path) -> TestFiles:
"""Create a TestFiles object with real file paths."""
test_file = tmp_path / "tests" / "test_func.test.ts"
test_file.parent.mkdir(parents=True, exist_ok=True)
test_file.write_text("// test file")
perf_file = tmp_path / "tests" / "test_func__perf.test.ts"
perf_file.write_text("// perf test file")
return TestFiles(
test_files=[
TestFile(
instrumented_behavior_file_path=test_file,
benchmarking_file_path=perf_file,
test_type=TestType.GENERATED_REGRESSION,
)
]
)
class TestBehavioralTestsDispatch:
"""Tests for run_behavioral_tests dispatch logic."""
@patch("codeflash.languages.javascript.test_runner.run_jest_behavioral_tests")
def test_dispatches_to_jest_by_default(self, mock_jest_runner: MagicMock, js_support: JavaScriptSupport) -> None:
"""Should dispatch to Jest when test_framework is not specified."""
with tempfile.TemporaryDirectory() as tmp_dir:
tmp_path = Path(tmp_dir)
(tmp_path / "package.json").write_text('{"name": "test"}')
test_files = create_test_files(tmp_path)
mock_jest_runner.return_value = (tmp_path / "result.xml", MagicMock(), None, None)
js_support.run_behavioral_tests(test_paths=test_files, test_env={}, cwd=tmp_path, project_root=tmp_path)
mock_jest_runner.assert_called_once()
@patch("codeflash.languages.javascript.test_runner.run_jest_behavioral_tests")
def test_dispatches_to_jest_explicitly(self, mock_jest_runner: MagicMock, js_support: JavaScriptSupport) -> None:
"""Should dispatch to Jest when test_framework='jest'."""
with tempfile.TemporaryDirectory() as tmp_dir:
tmp_path = Path(tmp_dir)
(tmp_path / "package.json").write_text('{"name": "test"}')
test_files = create_test_files(tmp_path)
mock_jest_runner.return_value = (tmp_path / "result.xml", MagicMock(), None, None)
js_support.run_behavioral_tests(
test_paths=test_files, test_env={}, cwd=tmp_path, project_root=tmp_path, test_framework="jest"
)
mock_jest_runner.assert_called_once()
@patch("codeflash.languages.javascript.vitest_runner.run_vitest_behavioral_tests")
def test_dispatches_to_vitest(self, mock_vitest_runner: MagicMock, js_support: JavaScriptSupport) -> None:
"""Should dispatch to Vitest when test_framework='vitest'."""
with tempfile.TemporaryDirectory() as tmp_dir:
tmp_path = Path(tmp_dir)
(tmp_path / "package.json").write_text('{"name": "test"}')
test_files = create_test_files(tmp_path)
mock_vitest_runner.return_value = (tmp_path / "result.xml", MagicMock(), None, None)
js_support.run_behavioral_tests(
test_paths=test_files, test_env={}, cwd=tmp_path, project_root=tmp_path, test_framework="vitest"
)
mock_vitest_runner.assert_called_once()
@patch("codeflash.languages.javascript.vitest_runner.run_vitest_behavioral_tests")
def test_typescript_support_dispatches_to_vitest(
self, mock_vitest_runner: MagicMock, ts_support: TypeScriptSupport
) -> None:
"""TypeScriptSupport should also dispatch to Vitest when test_framework='vitest'."""
with tempfile.TemporaryDirectory() as tmp_dir:
tmp_path = Path(tmp_dir)
(tmp_path / "package.json").write_text('{"name": "test"}')
test_files = create_test_files(tmp_path)
mock_vitest_runner.return_value = (tmp_path / "result.xml", MagicMock(), None, None)
ts_support.run_behavioral_tests(
test_paths=test_files, test_env={}, cwd=tmp_path, project_root=tmp_path, test_framework="vitest"
)
mock_vitest_runner.assert_called_once()
class TestBenchmarkingTestsDispatch:
"""Tests for run_benchmarking_tests dispatch logic."""
@patch("codeflash.languages.javascript.test_runner.run_jest_benchmarking_tests")
def test_dispatches_to_jest_by_default(self, mock_jest_runner: MagicMock, js_support: JavaScriptSupport) -> None:
"""Should dispatch to Jest when test_framework is not specified."""
with tempfile.TemporaryDirectory() as tmp_dir:
tmp_path = Path(tmp_dir)
(tmp_path / "package.json").write_text('{"name": "test"}')
test_files = create_test_files(tmp_path)
mock_jest_runner.return_value = (tmp_path / "result.xml", MagicMock())
js_support.run_benchmarking_tests(test_paths=test_files, test_env={}, cwd=tmp_path, project_root=tmp_path)
mock_jest_runner.assert_called_once()
@patch("codeflash.languages.javascript.vitest_runner.run_vitest_benchmarking_tests")
def test_dispatches_to_vitest(self, mock_vitest_runner: MagicMock, js_support: JavaScriptSupport) -> None:
"""Should dispatch to Vitest when test_framework='vitest'."""
with tempfile.TemporaryDirectory() as tmp_dir:
tmp_path = Path(tmp_dir)
(tmp_path / "package.json").write_text('{"name": "test"}')
test_files = create_test_files(tmp_path)
mock_vitest_runner.return_value = (tmp_path / "result.xml", MagicMock())
js_support.run_benchmarking_tests(
test_paths=test_files, test_env={}, cwd=tmp_path, project_root=tmp_path, test_framework="vitest"
)
mock_vitest_runner.assert_called_once()
@patch("codeflash.languages.javascript.vitest_runner.run_vitest_benchmarking_tests")
def test_passes_loop_parameters(self, mock_vitest_runner: MagicMock, js_support: JavaScriptSupport) -> None:
"""Should pass loop parameters to Vitest runner."""
with tempfile.TemporaryDirectory() as tmp_dir:
tmp_path = Path(tmp_dir)
(tmp_path / "package.json").write_text('{"name": "test"}')
test_files = create_test_files(tmp_path)
mock_vitest_runner.return_value = (tmp_path / "result.xml", MagicMock())
js_support.run_benchmarking_tests(
test_paths=test_files,
test_env={},
cwd=tmp_path,
project_root=tmp_path,
test_framework="vitest",
min_loops=10,
max_loops=50,
target_duration_seconds=5.0,
)
call_kwargs = mock_vitest_runner.call_args.kwargs
assert call_kwargs["min_loops"] == 10
assert call_kwargs["max_loops"] == 50
assert call_kwargs["target_duration_ms"] == 5000
class TestLineProfileTestsDispatch:
"""Tests for run_line_profile_tests dispatch logic."""
@patch("codeflash.languages.javascript.test_runner.run_jest_line_profile_tests")
def test_dispatches_to_jest_by_default(self, mock_jest_runner: MagicMock, js_support: JavaScriptSupport) -> None:
"""Should dispatch to Jest when test_framework is not specified."""
with tempfile.TemporaryDirectory() as tmp_dir:
tmp_path = Path(tmp_dir)
(tmp_path / "package.json").write_text('{"name": "test"}')
test_files = create_test_files(tmp_path)
mock_jest_runner.return_value = (tmp_path / "result.xml", MagicMock())
js_support.run_line_profile_tests(test_paths=test_files, test_env={}, cwd=tmp_path, project_root=tmp_path)
mock_jest_runner.assert_called_once()
@patch("codeflash.languages.javascript.vitest_runner.run_vitest_line_profile_tests")
def test_dispatches_to_vitest(self, mock_vitest_runner: MagicMock, js_support: JavaScriptSupport) -> None:
"""Should dispatch to Vitest when test_framework='vitest'."""
with tempfile.TemporaryDirectory() as tmp_dir:
tmp_path = Path(tmp_dir)
(tmp_path / "package.json").write_text('{"name": "test"}')
test_files = create_test_files(tmp_path)
mock_vitest_runner.return_value = (tmp_path / "result.xml", MagicMock())
js_support.run_line_profile_tests(
test_paths=test_files, test_env={}, cwd=tmp_path, project_root=tmp_path, test_framework="vitest"
)
mock_vitest_runner.assert_called_once()
@patch("codeflash.languages.javascript.vitest_runner.run_vitest_line_profile_tests")
def test_passes_line_profile_output_file(
self, mock_vitest_runner: MagicMock, js_support: JavaScriptSupport
) -> None:
"""Should pass line_profile_output_file to Vitest runner."""
with tempfile.TemporaryDirectory() as tmp_dir:
tmp_path = Path(tmp_dir)
(tmp_path / "package.json").write_text('{"name": "test"}')
test_files = create_test_files(tmp_path)
output_file = tmp_path / "line_profile.json"
mock_vitest_runner.return_value = (tmp_path / "result.xml", MagicMock())
js_support.run_line_profile_tests(
test_paths=test_files,
test_env={},
cwd=tmp_path,
project_root=tmp_path,
test_framework="vitest",
line_profile_output_file=output_file,
)
call_kwargs = mock_vitest_runner.call_args.kwargs
assert call_kwargs["line_profile_output_file"] == output_file
class TestTestFrameworkProperty:
"""Tests for test_framework property."""
def test_javascript_default_framework_is_jest(self, js_support: JavaScriptSupport) -> None:
"""JavaScriptSupport should have Jest as default test framework."""
assert js_support.test_framework == "jest"
def test_typescript_default_framework_is_jest(self, ts_support: TypeScriptSupport) -> None:
"""TypeScriptSupport should have Jest as default test framework."""
assert ts_support.test_framework == "jest"
class TestTestFrameworkSingleton:
"""Tests for test_framework singleton behavior."""
@patch("codeflash.languages.javascript.vitest_runner.run_vitest_behavioral_tests")
def test_uses_singleton_when_param_not_provided(
self, mock_vitest_runner: MagicMock, js_support: JavaScriptSupport
) -> None:
"""Should use singleton test_framework when parameter is not provided."""
with tempfile.TemporaryDirectory() as tmp_dir:
tmp_path = Path(tmp_dir)
(tmp_path / "package.json").write_text('{"name": "test"}')
test_files = create_test_files(tmp_path)
mock_vitest_runner.return_value = (tmp_path / "result.xml", MagicMock(), None, None)
set_current_test_framework("vitest")
js_support.run_behavioral_tests(test_paths=test_files, test_env={}, cwd=tmp_path, project_root=tmp_path)
mock_vitest_runner.assert_called_once()
@patch("codeflash.languages.javascript.test_runner.run_jest_behavioral_tests")
def test_explicit_param_overrides_singleton(
self, mock_jest_runner: MagicMock, js_support: JavaScriptSupport
) -> None:
"""Explicit test_framework parameter should override singleton."""
with tempfile.TemporaryDirectory() as tmp_dir:
tmp_path = Path(tmp_dir)
(tmp_path / "package.json").write_text('{"name": "test"}')
test_files = create_test_files(tmp_path)
mock_jest_runner.return_value = (tmp_path / "result.xml", MagicMock(), None, None)
set_current_test_framework("vitest")
js_support.run_behavioral_tests(
test_paths=test_files, test_env={}, cwd=tmp_path, project_root=tmp_path, test_framework="jest"
)
mock_jest_runner.assert_called_once()
@patch("codeflash.languages.javascript.vitest_runner.run_vitest_benchmarking_tests")
def test_benchmarking_uses_singleton(self, mock_vitest_runner: MagicMock, js_support: JavaScriptSupport) -> None:
"""run_benchmarking_tests should use singleton when param not provided."""
with tempfile.TemporaryDirectory() as tmp_dir:
tmp_path = Path(tmp_dir)
(tmp_path / "package.json").write_text('{"name": "test"}')
test_files = create_test_files(tmp_path)
mock_vitest_runner.return_value = (tmp_path / "result.xml", MagicMock())
set_current_test_framework("vitest")
js_support.run_benchmarking_tests(test_paths=test_files, test_env={}, cwd=tmp_path, project_root=tmp_path)
mock_vitest_runner.assert_called_once()
@patch("codeflash.languages.javascript.vitest_runner.run_vitest_line_profile_tests")
def test_line_profile_uses_singleton(self, mock_vitest_runner: MagicMock, js_support: JavaScriptSupport) -> None:
"""run_line_profile_tests should use singleton when param not provided."""
with tempfile.TemporaryDirectory() as tmp_dir:
tmp_path = Path(tmp_dir)
(tmp_path / "package.json").write_text('{"name": "test"}')
test_files = create_test_files(tmp_path)
mock_vitest_runner.return_value = (tmp_path / "result.xml", MagicMock())
set_current_test_framework("vitest")
js_support.run_line_profile_tests(test_paths=test_files, test_env={}, cwd=tmp_path, project_root=tmp_path)
mock_vitest_runner.assert_called_once()
class TestTestFrameworkSingletonModule:
"""Tests for the test_framework singleton module itself."""
def test_initial_state_is_none(self) -> None:
"""Singleton should start as None."""
from codeflash.languages.test_framework import current_test_framework
assert current_test_framework() is None
def test_set_and_get(self) -> None:
"""Should be able to set and get test framework."""
from codeflash.languages.test_framework import current_test_framework, set_current_test_framework
set_current_test_framework("vitest")
assert current_test_framework() == "vitest"
def test_set_only_once(self) -> None:
"""Once set, singleton should not change."""
from codeflash.languages.test_framework import current_test_framework, set_current_test_framework
set_current_test_framework("jest")
set_current_test_framework("vitest")
assert current_test_framework() == "jest"
def test_is_jest(self) -> None:
"""is_jest() should return True when framework is Jest."""
from codeflash.languages.test_framework import is_jest, set_current_test_framework
set_current_test_framework("jest")
assert is_jest() is True
def test_is_vitest(self) -> None:
"""is_vitest() should return True when framework is Vitest."""
from codeflash.languages.test_framework import is_vitest, set_current_test_framework
set_current_test_framework("vitest")
assert is_vitest() is True
def test_get_js_test_framework_or_default_returns_jest(self) -> None:
"""get_js_test_framework_or_default should return 'jest' when not set."""
from codeflash.languages.test_framework import get_js_test_framework_or_default
assert get_js_test_framework_or_default() == "jest"
def test_get_js_test_framework_or_default_returns_vitest(self) -> None:
"""get_js_test_framework_or_default should return 'vitest' when set."""
from codeflash.languages.test_framework import get_js_test_framework_or_default, set_current_test_framework
set_current_test_framework("vitest")
assert get_js_test_framework_or_default() == "vitest"

View file

@ -0,0 +1,247 @@
"""Tests for Vitest JUnit XML output parsing and compatibility.
These tests verify that Vitest's JUnit XML output can be parsed
by the existing parsing infrastructure.
"""
from __future__ import annotations
import tempfile
from pathlib import Path
import pytest
from junitparser import JUnitXml
from codeflash.verification.parse_test_output import jest_end_pattern, jest_start_pattern
class TestVitestJunitXmlFormat:
"""Tests for Vitest JUnit XML format compatibility."""
def test_can_parse_vitest_junit_xml(self) -> None:
"""Should be able to parse Vitest JUnit XML with junitparser."""
xml_content = """<?xml version="1.0" encoding="UTF-8" ?>
<testsuites name="vitest tests" tests="4" failures="1" errors="0" time="0.537">
<testsuite name="tests/fibonacci.test.ts" timestamp="2026-01-30T18:03:49.433Z" hostname="localhost" tests="3" failures="0" errors="0" skipped="0" time="0.008">
<testcase classname="tests/fibonacci.test.ts" name="fibonacci &gt; returns 0 for n=0" time="0.001">
</testcase>
<testcase classname="tests/fibonacci.test.ts" name="fibonacci &gt; returns 1 for n=1" time="0.0005">
</testcase>
<testcase classname="tests/fibonacci.test.ts" name="fibonacci &gt; returns 55 for n=10" time="0.0001">
</testcase>
</testsuite>
<testsuite name="tests/string_utils.test.ts" timestamp="2026-01-30T18:03:49.438Z" hostname="localhost" tests="1" failures="1" errors="0" skipped="0" time="0.01">
<testcase classname="tests/string_utils.test.ts" name="reverseString &gt; reverses a simple string" time="0.0007">
<failure message="expected &apos;olleh&apos; to equal &apos;hello&apos;" type="AssertionError">AssertionError: expected 'olleh' to equal 'hello'</failure>
</testcase>
</testsuite>
</testsuites>"""
with tempfile.NamedTemporaryFile(suffix=".xml", mode="w", delete=False) as f:
f.write(xml_content)
f.flush()
junit_file = Path(f.name)
xml = JUnitXml.fromfile(str(junit_file))
assert xml is not None
test_count = sum(len(list(suite)) for suite in xml)
assert test_count == 4
def test_extracts_test_suite_names(self) -> None:
"""Should extract test suite names from Vitest JUnit XML."""
xml_content = """<?xml version="1.0" encoding="UTF-8" ?>
<testsuites name="vitest tests" tests="2" failures="0" errors="0" time="0.1">
<testsuite name="tests/fibonacci.test.ts" tests="1" failures="0" time="0.01">
<testcase classname="tests/fibonacci.test.ts" name="test1" time="0.001"></testcase>
</testsuite>
<testsuite name="tests/string_utils.test.ts" tests="1" failures="0" time="0.01">
<testcase classname="tests/string_utils.test.ts" name="test2" time="0.001"></testcase>
</testsuite>
</testsuites>"""
with tempfile.NamedTemporaryFile(suffix=".xml", mode="w", delete=False) as f:
f.write(xml_content)
f.flush()
junit_file = Path(f.name)
xml = JUnitXml.fromfile(str(junit_file))
suite_names = [suite.name for suite in xml]
assert suite_names == ["tests/fibonacci.test.ts", "tests/string_utils.test.ts"]
def test_extracts_test_case_names_with_vitest_separator(self) -> None:
"""Should extract test case names from Vitest JUnit XML (uses > as separator)."""
xml_content = """<?xml version="1.0" encoding="UTF-8" ?>
<testsuites name="vitest tests" tests="2" failures="0" errors="0" time="0.1">
<testsuite name="tests/fibonacci.test.ts" tests="2" failures="0" time="0.01">
<testcase classname="tests/fibonacci.test.ts" name="fibonacci &gt; returns 0 for n=0" time="0.001"></testcase>
<testcase classname="tests/fibonacci.test.ts" name="fibonacci &gt; returns 1 for n=1" time="0.001"></testcase>
</testsuite>
</testsuites>"""
with tempfile.NamedTemporaryFile(suffix=".xml", mode="w", delete=False) as f:
f.write(xml_content)
f.flush()
junit_file = Path(f.name)
xml = JUnitXml.fromfile(str(junit_file))
test_names = []
for suite in xml:
for case in suite:
test_names.append(case.name)
assert test_names == ["fibonacci > returns 0 for n=0", "fibonacci > returns 1 for n=1"]
def test_extracts_classname_as_file_path(self) -> None:
"""Should extract classname which contains file path in Vitest."""
xml_content = """<?xml version="1.0" encoding="UTF-8" ?>
<testsuites name="vitest tests" tests="1" failures="0" errors="0" time="0.1">
<testsuite name="tests/fibonacci.test.ts" tests="1" failures="0" time="0.01">
<testcase classname="tests/fibonacci.test.ts" name="test1" time="0.001"></testcase>
</testsuite>
</testsuites>"""
with tempfile.NamedTemporaryFile(suffix=".xml", mode="w", delete=False) as f:
f.write(xml_content)
f.flush()
junit_file = Path(f.name)
xml = JUnitXml.fromfile(str(junit_file))
for suite in xml:
for case in suite:
assert case.classname == "tests/fibonacci.test.ts"
def test_extracts_test_time_as_float(self) -> None:
"""Should extract test execution time as float from Vitest JUnit XML."""
xml_content = """<?xml version="1.0" encoding="UTF-8" ?>
<testsuites name="vitest tests" tests="1" failures="0" errors="0" time="0.1">
<testsuite name="tests/test.ts" tests="1" failures="0" time="0.01">
<testcase classname="tests/test.ts" name="test1" time="0.0015"></testcase>
</testsuite>
</testsuites>"""
with tempfile.NamedTemporaryFile(suffix=".xml", mode="w", delete=False) as f:
f.write(xml_content)
f.flush()
junit_file = Path(f.name)
xml = JUnitXml.fromfile(str(junit_file))
for suite in xml:
for case in suite:
assert isinstance(case.time, float)
assert case.time == 0.0015
def test_detects_failures(self) -> None:
"""Should detect test failures in Vitest JUnit XML."""
xml_content = """<?xml version="1.0" encoding="UTF-8" ?>
<testsuites name="vitest tests" tests="2" failures="1" errors="0" time="0.1">
<testsuite name="tests/test.ts" tests="2" failures="1" time="0.01">
<testcase classname="tests/test.ts" name="passing test" time="0.001"></testcase>
<testcase classname="tests/test.ts" name="failing test" time="0.001">
<failure message="expected true to be false" type="AssertionError">AssertionError: expected true to be false</failure>
</testcase>
</testsuite>
</testsuites>"""
with tempfile.NamedTemporaryFile(suffix=".xml", mode="w", delete=False) as f:
f.write(xml_content)
f.flush()
junit_file = Path(f.name)
xml = JUnitXml.fromfile(str(junit_file))
failures = []
for suite in xml:
for case in suite:
if not case.is_passed:
failures.append(case.name)
assert failures == ["failing test"]
def test_extracts_failure_message(self) -> None:
"""Should extract failure message from Vitest JUnit XML."""
xml_content = """<?xml version="1.0" encoding="UTF-8" ?>
<testsuites name="vitest tests" tests="1" failures="1" errors="0" time="0.1">
<testsuite name="tests/test.ts" tests="1" failures="1" time="0.01">
<testcase classname="tests/test.ts" name="failing test" time="0.001">
<failure message="expected 'actual' to equal 'expected'" type="AssertionError">AssertionError: expected 'actual' to equal 'expected'</failure>
</testcase>
</testsuite>
</testsuites>"""
with tempfile.NamedTemporaryFile(suffix=".xml", mode="w", delete=False) as f:
f.write(xml_content)
f.flush()
junit_file = Path(f.name)
xml = JUnitXml.fromfile(str(junit_file))
for suite in xml:
for case in suite:
if not case.is_passed:
for result in case.result:
if hasattr(result, "message"):
assert result.message == "expected 'actual' to equal 'expected'"
class TestVitestTimingMarkers:
"""Tests for Vitest timing marker extraction.
Timing markers are used to measure function execution time during benchmarking.
The format is the same for Jest and Vitest since they use the same codeflash helper.
"""
def test_parses_start_timing_marker(self) -> None:
"""Should parse start timing marker from Vitest output."""
output = "!$######fibonacci.test.ts:returns 0 for n=0:fibonacci:1:line_0######$!"
matches = jest_start_pattern.findall(output)
assert len(matches) == 1
test_file, test_name, func_name, loop_index, line_id = matches[0]
assert test_file == "fibonacci.test.ts"
assert test_name == "returns 0 for n=0"
assert func_name == "fibonacci"
assert loop_index == "1"
assert line_id == "line_0"
def test_parses_end_timing_marker(self) -> None:
"""Should parse end timing marker from Vitest output."""
output = "!######fibonacci.test.ts:returns 0 for n=0:fibonacci:1:line_0:123456######!"
matches = jest_end_pattern.findall(output)
assert len(matches) == 1
test_file, test_name, func_name, loop_index, line_id, duration = matches[0]
assert test_file == "fibonacci.test.ts"
assert test_name == "returns 0 for n=0"
assert func_name == "fibonacci"
assert loop_index == "1"
assert line_id == "line_0"
assert duration == "123456"
def test_extracts_multiple_timing_markers(self) -> None:
"""Should extract multiple timing markers from Vitest output."""
output = """Running tests...
!$######test.ts:test1:func:1:id1######$!
executing...
!######test.ts:test1:func:1:id1:100000######!
!$######test.ts:test2:func:1:id2######$!
executing...
!######test.ts:test2:func:1:id2:200000######!
Done."""
start_matches = jest_start_pattern.findall(output)
end_matches = jest_end_pattern.findall(output)
assert len(start_matches) == 2
assert len(end_matches) == 2
durations = [int(m[5]) for m in end_matches]
assert durations == [100000, 200000]
def test_timing_marker_with_special_characters_in_test_name(self) -> None:
"""Should handle test names with special characters."""
output = "!$######test.ts:handles_n=0_correctly:fibonacci:1:id######$!"
matches = jest_start_pattern.findall(output)
assert len(matches) == 1
assert matches[0][1] == "handles_n=0_correctly"

View file

@ -0,0 +1,243 @@
"""Tests for Vitest test runner command construction.
These tests verify that Vitest commands are correctly constructed
with the appropriate flags and arguments.
"""
from __future__ import annotations
import tempfile
from pathlib import Path
import pytest
from codeflash.languages.javascript.vitest_runner import (
_build_vitest_behavioral_command,
_build_vitest_benchmarking_command,
_find_vitest_project_root,
)
class TestFindVitestProjectRoot:
"""Tests for _find_vitest_project_root function."""
def test_finds_vitest_config_js(self) -> None:
"""Should find project root via vitest.config.js."""
with tempfile.TemporaryDirectory() as tmp_dir:
tmp_path = Path(tmp_dir)
(tmp_path / "vitest.config.js").write_text("export default {}")
test_file = tmp_path / "tests" / "test.test.ts"
test_file.parent.mkdir(parents=True)
test_file.write_text("")
result = _find_vitest_project_root(test_file)
assert result == tmp_path
def test_finds_vitest_config_ts(self) -> None:
"""Should find project root via vitest.config.ts."""
with tempfile.TemporaryDirectory() as tmp_dir:
tmp_path = Path(tmp_dir)
(tmp_path / "vitest.config.ts").write_text("export default {}")
test_file = tmp_path / "tests" / "test.test.ts"
test_file.parent.mkdir(parents=True)
test_file.write_text("")
result = _find_vitest_project_root(test_file)
assert result == tmp_path
def test_finds_vite_config_js(self) -> None:
"""Should find project root via vite.config.js (Vitest can be configured in vite config)."""
with tempfile.TemporaryDirectory() as tmp_dir:
tmp_path = Path(tmp_dir)
(tmp_path / "vite.config.js").write_text("export default {}")
test_file = tmp_path / "tests" / "test.test.ts"
test_file.parent.mkdir(parents=True)
test_file.write_text("")
result = _find_vitest_project_root(test_file)
assert result == tmp_path
def test_falls_back_to_package_json(self) -> None:
"""Should fall back to package.json when no vitest config found."""
with tempfile.TemporaryDirectory() as tmp_dir:
tmp_path = Path(tmp_dir)
(tmp_path / "package.json").write_text('{"name": "test"}')
test_file = tmp_path / "tests" / "test.test.ts"
test_file.parent.mkdir(parents=True)
test_file.write_text("")
result = _find_vitest_project_root(test_file)
assert result == tmp_path
def test_returns_none_when_no_config(self) -> None:
"""Should return None when no vitest/vite config or package.json found."""
with tempfile.TemporaryDirectory() as tmp_dir:
tmp_path = Path(tmp_dir)
test_file = tmp_path / "tests" / "test.test.ts"
test_file.parent.mkdir(parents=True)
test_file.write_text("")
result = _find_vitest_project_root(test_file)
assert result is None
class TestBuildVitestBehavioralCommand:
"""Tests for _build_vitest_behavioral_command function."""
def test_basic_command_structure(self) -> None:
"""Should build basic Vitest command with required flags."""
with tempfile.TemporaryDirectory() as tmp_dir:
tmp_path = Path(tmp_dir)
test_file = tmp_path / "test.test.ts"
test_file.write_text("")
cmd = _build_vitest_behavioral_command([test_file], timeout=60)
assert cmd[0] == "npx"
assert cmd[1] == "vitest"
assert cmd[2] == "run"
def test_includes_reporter_flags(self) -> None:
"""Should include reporter flags for JUnit output."""
with tempfile.TemporaryDirectory() as tmp_dir:
tmp_path = Path(tmp_dir)
test_file = tmp_path / "test.test.ts"
test_file.write_text("")
cmd = _build_vitest_behavioral_command([test_file], timeout=60)
assert "--reporter=default" in cmd
assert "--reporter=junit" in cmd
def test_includes_serial_execution_flag(self) -> None:
"""Should include flag for serial test execution."""
with tempfile.TemporaryDirectory() as tmp_dir:
tmp_path = Path(tmp_dir)
test_file = tmp_path / "test.test.ts"
test_file.write_text("")
cmd = _build_vitest_behavioral_command([test_file], timeout=60)
assert "--no-file-parallelism" in cmd
def test_includes_test_files_as_absolute_paths(self) -> None:
"""Should include test files at the end of command as absolute paths."""
with tempfile.TemporaryDirectory() as tmp_dir:
tmp_path = Path(tmp_dir)
test_file_a = tmp_path / "test_a.test.ts"
test_file_b = tmp_path / "test_b.test.ts"
test_file_a.write_text("")
test_file_b.write_text("")
cmd = _build_vitest_behavioral_command([test_file_a, test_file_b], timeout=60)
assert str(test_file_a.resolve()) in cmd
assert str(test_file_b.resolve()) in cmd
def test_includes_timeout_in_milliseconds(self) -> None:
"""Should include test timeout in milliseconds."""
with tempfile.TemporaryDirectory() as tmp_dir:
tmp_path = Path(tmp_dir)
test_file = tmp_path / "test.test.ts"
test_file.write_text("")
cmd = _build_vitest_behavioral_command([test_file], timeout=120)
assert "--test-timeout=120000" in cmd
def test_includes_output_file_when_provided(self) -> None:
"""Should include --outputFile flag when output_file is provided."""
with tempfile.TemporaryDirectory() as tmp_dir:
tmp_path = Path(tmp_dir)
test_file = tmp_path / "test.test.ts"
test_file.write_text("")
output_file = tmp_path / "results.xml"
cmd = _build_vitest_behavioral_command([test_file], timeout=60, output_file=output_file)
assert f"--outputFile={output_file}" in cmd
class TestBuildVitestBenchmarkingCommand:
"""Tests for _build_vitest_benchmarking_command function."""
def test_basic_command_structure(self) -> None:
"""Should build basic Vitest benchmarking command."""
with tempfile.TemporaryDirectory() as tmp_dir:
tmp_path = Path(tmp_dir)
test_file = tmp_path / "test__perf.test.ts"
test_file.write_text("")
cmd = _build_vitest_benchmarking_command([test_file], timeout=60)
assert cmd[0] == "npx"
assert cmd[1] == "vitest"
assert cmd[2] == "run"
def test_includes_serial_execution(self) -> None:
"""Should include serial execution for consistent benchmarking."""
with tempfile.TemporaryDirectory() as tmp_dir:
tmp_path = Path(tmp_dir)
test_file = tmp_path / "test__perf.test.ts"
test_file.write_text("")
cmd = _build_vitest_benchmarking_command([test_file], timeout=60)
assert "--no-file-parallelism" in cmd
class TestVitestVsJestCommandDifferences:
"""Tests documenting the key differences between Vitest and Jest commands."""
def test_vitest_uses_run_subcommand(self) -> None:
"""Vitest uses 'run' for single execution, Jest doesn't need it."""
with tempfile.TemporaryDirectory() as tmp_dir:
tmp_path = Path(tmp_dir)
test_file = tmp_path / "test.test.ts"
test_file.write_text("")
vitest_cmd = _build_vitest_behavioral_command([test_file], timeout=60)
assert vitest_cmd[0:3] == ["npx", "vitest", "run"]
def test_vitest_uses_hyphenated_timeout(self) -> None:
"""Vitest uses --test-timeout, Jest uses --testTimeout (camelCase)."""
with tempfile.TemporaryDirectory() as tmp_dir:
tmp_path = Path(tmp_dir)
test_file = tmp_path / "test.test.ts"
test_file.write_text("")
vitest_cmd = _build_vitest_behavioral_command([test_file], timeout=60)
timeout_args = [arg for arg in vitest_cmd if "timeout" in arg.lower()]
assert len(timeout_args) == 1
assert timeout_args[0] == "--test-timeout=60000"
def test_vitest_uses_no_file_parallelism(self) -> None:
"""Vitest uses --no-file-parallelism, Jest uses --runInBand."""
with tempfile.TemporaryDirectory() as tmp_dir:
tmp_path = Path(tmp_dir)
test_file = tmp_path / "test.test.ts"
test_file.write_text("")
vitest_cmd = _build_vitest_behavioral_command([test_file], timeout=60)
assert "--no-file-parallelism" in vitest_cmd
assert "--runInBand" not in vitest_cmd
def test_vitest_positional_test_files(self) -> None:
"""Vitest uses positional args for test files, not --runTestsByPath."""
with tempfile.TemporaryDirectory() as tmp_dir:
tmp_path = Path(tmp_dir)
test_file = tmp_path / "test.test.ts"
test_file.write_text("")
vitest_cmd = _build_vitest_behavioral_command([test_file], timeout=60)
assert "--runTestsByPath" not in vitest_cmd
assert str(test_file.resolve()) in vitest_cmd

View file

@ -604,3 +604,349 @@ def test_function_in_tests_dir():
assert "vanilla_function" not in remaining_functions
files_and_funcs = get_all_files_and_functions(module_root_path=temp_dir, ignore_paths=[])
assert len(files_and_funcs) == 6
def test_filter_functions_tests_root_overlaps_source():
"""Test that source files are not filtered when tests_root equals module_root or project_root.
This is a critical test for monorepo structures where tests live alongside source code
(e.g., TypeScript projects with .test.ts files in the same directories as source).
"""
with tempfile.TemporaryDirectory() as temp_dir_str:
temp_dir = Path(temp_dir_str)
# Create a source file (NOT a test file)
source_file = temp_dir / "utils.py"
with source_file.open("w") as f:
f.write("""
def process_data(items):
return [item * 2 for item in items]
def calculate_sum(numbers):
return sum(numbers)
""")
# Create a test file with standard naming pattern
test_file = temp_dir / "utils.test.py"
with test_file.open("w") as f:
f.write("""
def test_process_data():
return "test"
""")
# Create a test file with _test suffix pattern
test_file_underscore = temp_dir / "utils_test.py"
with test_file_underscore.open("w") as f:
f.write("""
def test_calculate_sum():
return "test"
""")
# Create a spec file
spec_file = temp_dir / "utils.spec.py"
with spec_file.open("w") as f:
f.write("""
def spec_function():
return "spec"
""")
# Create a file in a tests subdirectory
tests_subdir = temp_dir / "tests"
tests_subdir.mkdir()
tests_subdir_file = tests_subdir / "test_main.py"
with tests_subdir_file.open("w") as f:
f.write("""
def test_in_tests_dir():
return "test"
""")
# Create a file in __tests__ subdirectory (common in JS/TS projects)
dunder_tests_subdir = temp_dir / "__tests__"
dunder_tests_subdir.mkdir()
dunder_tests_file = dunder_tests_subdir / "main.py"
with dunder_tests_file.open("w") as f:
f.write("""
def test_in_dunder_tests():
return "test"
""")
# Discover all functions
discovered_source = find_all_functions_in_file(source_file)
discovered_test = find_all_functions_in_file(test_file)
discovered_test_underscore = find_all_functions_in_file(test_file_underscore)
discovered_spec = find_all_functions_in_file(spec_file)
discovered_tests_dir = find_all_functions_in_file(tests_subdir_file)
discovered_dunder_tests = find_all_functions_in_file(dunder_tests_file)
# Combine all discovered functions
all_functions = {}
for discovered in [discovered_source, discovered_test, discovered_test_underscore,
discovered_spec, discovered_tests_dir, discovered_dunder_tests]:
all_functions.update(discovered)
# Test Case 1: tests_root == module_root (overlapping case)
# This is the bug scenario where all functions were being filtered
with unittest.mock.patch(
"codeflash.discovery.functions_to_optimize.get_blocklisted_functions", return_value={}
):
filtered, count = filter_functions(
all_functions,
tests_root=temp_dir, # Same as module_root
ignore_paths=[],
project_root=temp_dir,
module_root=temp_dir, # Same as tests_root
)
# Strict check: only source_file should remain in filtered results
assert set(filtered.keys()) == {source_file}, (
f"Expected only source file in filtered results, got: {set(filtered.keys())}"
)
# Strict check: exactly these two functions should be present
source_functions = sorted([fn.function_name for fn in filtered.get(source_file, [])])
assert source_functions == ["calculate_sum", "process_data"], (
f"Expected ['calculate_sum', 'process_data'], got {source_functions}"
)
# Strict check: exactly 2 functions remaining
assert count == 2, f"Expected exactly 2 functions, got {count}"
# Test Case 2: tests_root == project_root (another overlapping case)
with unittest.mock.patch(
"codeflash.discovery.functions_to_optimize.get_blocklisted_functions", return_value={}
):
filtered2, count2 = filter_functions(
{source_file: discovered_source[source_file]},
tests_root=temp_dir, # Same as project_root
ignore_paths=[],
project_root=temp_dir,
module_root=temp_dir,
)
# Strict check: only source_file should remain
assert set(filtered2.keys()) == {source_file}, (
f"Expected only source file when tests_root == project_root, got: {set(filtered2.keys())}"
)
assert count2 == 2, f"Expected exactly 2 functions, got {count2}"
def test_filter_functions_strict_string_matching():
"""Test that test file pattern matching uses strict string matching.
Ensures patterns like '.test.' only match actual test files and don't
accidentally match files with similar names like 'contest.py' or 'latest.py'.
"""
with tempfile.TemporaryDirectory() as temp_dir_str:
temp_dir = Path(temp_dir_str)
# Files that should NOT be filtered (contain 'test' as substring but not as pattern)
contest_file = temp_dir / "contest.py"
with contest_file.open("w") as f:
f.write("def run_contest(): return 1")
latest_file = temp_dir / "latest.py"
with latest_file.open("w") as f:
f.write("def get_latest(): return 1")
attestation_file = temp_dir / "attestation.py"
with attestation_file.open("w") as f:
f.write("def verify_attestation(): return 1")
# File that SHOULD be filtered (matches .test. pattern)
actual_test_file = temp_dir / "utils.test.py"
with actual_test_file.open("w") as f:
f.write("def test_utils(): return 1")
# File that SHOULD be filtered (matches _test. pattern)
underscore_test_file = temp_dir / "utils_test.py"
with underscore_test_file.open("w") as f:
f.write("def test_stuff(): return 1")
# Discover all functions
all_functions = {}
for file_path in [contest_file, latest_file, attestation_file, actual_test_file, underscore_test_file]:
discovered = find_all_functions_in_file(file_path)
all_functions.update(discovered)
with unittest.mock.patch(
"codeflash.discovery.functions_to_optimize.get_blocklisted_functions", return_value={}
):
filtered, count = filter_functions(
all_functions,
tests_root=temp_dir, # Overlapping case to trigger pattern matching
ignore_paths=[],
project_root=temp_dir,
module_root=temp_dir,
)
# Strict check: exactly these 3 files should remain (those with 'test' as substring only)
expected_files = {contest_file, latest_file, attestation_file}
assert set(filtered.keys()) == expected_files, (
f"Expected files {expected_files}, got {set(filtered.keys())}"
)
# Strict check: each file should have exactly 1 function with the expected name
assert [fn.function_name for fn in filtered[contest_file]] == ["run_contest"], (
f"Expected ['run_contest'], got {[fn.function_name for fn in filtered[contest_file]]}"
)
assert [fn.function_name for fn in filtered[latest_file]] == ["get_latest"], (
f"Expected ['get_latest'], got {[fn.function_name for fn in filtered[latest_file]]}"
)
assert [fn.function_name for fn in filtered[attestation_file]] == ["verify_attestation"], (
f"Expected ['verify_attestation'], got {[fn.function_name for fn in filtered[attestation_file]]}"
)
# Strict check: exactly 3 functions remaining
assert count == 3, f"Expected exactly 3 functions, got {count}"
def test_filter_functions_test_directory_patterns():
"""Test that test directory patterns work correctly with strict matching.
Ensures that /test/, /tests/, and /__tests__/ patterns only match actual
test directories and not directories that happen to contain 'test' in name.
"""
with tempfile.TemporaryDirectory() as temp_dir_str:
temp_dir = Path(temp_dir_str)
# Directory that should NOT be filtered (contains 'test' but not as /test/ pattern)
contest_dir = temp_dir / "contest_results"
contest_dir.mkdir()
contest_file = contest_dir / "scores.py"
with contest_file.open("w") as f:
f.write("def get_scores(): return [1, 2, 3]")
latest_dir = temp_dir / "latest_data"
latest_dir.mkdir()
latest_file = latest_dir / "data.py"
with latest_file.open("w") as f:
f.write("def load_data(): return {}")
# Directory that SHOULD be filtered (matches /tests/ pattern)
tests_dir = temp_dir / "tests"
tests_dir.mkdir()
tests_file = tests_dir / "test_main.py"
with tests_file.open("w") as f:
f.write("def test_main(): return True")
# Directory that SHOULD be filtered (matches /test/ pattern - singular)
test_dir = temp_dir / "test"
test_dir.mkdir()
test_file = test_dir / "test_utils.py"
with test_file.open("w") as f:
f.write("def test_utils(): return True")
# Directory that SHOULD be filtered (matches /__tests__/ pattern)
dunder_tests_dir = temp_dir / "__tests__"
dunder_tests_dir.mkdir()
dunder_file = dunder_tests_dir / "component.py"
with dunder_file.open("w") as f:
f.write("def test_component(): return True")
# Nested test directory
src_dir = temp_dir / "src"
src_dir.mkdir()
nested_tests_dir = src_dir / "tests"
nested_tests_dir.mkdir()
nested_test_file = nested_tests_dir / "test_nested.py"
with nested_test_file.open("w") as f:
f.write("def test_nested(): return True")
# Discover all functions
all_functions = {}
for file_path in [contest_file, latest_file, tests_file, test_file, dunder_file, nested_test_file]:
discovered = find_all_functions_in_file(file_path)
all_functions.update(discovered)
with unittest.mock.patch(
"codeflash.discovery.functions_to_optimize.get_blocklisted_functions", return_value={}
):
filtered, count = filter_functions(
all_functions,
tests_root=temp_dir, # Overlapping case
ignore_paths=[],
project_root=temp_dir,
module_root=temp_dir,
)
# Strict check: exactly these 2 files should remain (those in non-test directories)
expected_files = {contest_file, latest_file}
assert set(filtered.keys()) == expected_files, (
f"Expected files {expected_files}, got {set(filtered.keys())}"
)
# Strict check: each file should have exactly 1 function with the expected name
assert [fn.function_name for fn in filtered[contest_file]] == ["get_scores"], (
f"Expected ['get_scores'], got {[fn.function_name for fn in filtered[contest_file]]}"
)
assert [fn.function_name for fn in filtered[latest_file]] == ["load_data"], (
f"Expected ['load_data'], got {[fn.function_name for fn in filtered[latest_file]]}"
)
# Strict check: exactly 2 functions remaining
assert count == 2, f"Expected exactly 2 functions, got {count}"
def test_filter_functions_non_overlapping_tests_root():
"""Test that the original directory-based filtering still works when tests_root is separate.
When tests_root is a distinct directory (e.g., 'tests/'), the original behavior
of filtering files that start with tests_root should still work.
"""
with tempfile.TemporaryDirectory() as temp_dir_str:
temp_dir = Path(temp_dir_str)
# Create source directory structure
src_dir = temp_dir / "src"
src_dir.mkdir()
source_file = src_dir / "utils.py"
with source_file.open("w") as f:
f.write("def process(): return 1")
# Create a file with .test. pattern in source (should NOT be filtered in non-overlapping mode)
# because directory-based filtering takes precedence
test_in_src = src_dir / "helper.test.py"
with test_in_src.open("w") as f:
f.write("def helper_test(): return 1")
# Create separate tests directory
tests_dir = temp_dir / "tests"
tests_dir.mkdir()
test_file = tests_dir / "test_utils.py"
with test_file.open("w") as f:
f.write("def test_process(): return 1")
# Discover functions
all_functions = {}
for file_path in [source_file, test_in_src, test_file]:
discovered = find_all_functions_in_file(file_path)
all_functions.update(discovered)
# Non-overlapping case: tests_root is a separate directory
with unittest.mock.patch(
"codeflash.discovery.functions_to_optimize.get_blocklisted_functions", return_value={}
):
filtered, count = filter_functions(
all_functions,
tests_root=tests_dir, # Separate from module_root
ignore_paths=[],
project_root=temp_dir,
module_root=src_dir, # Different from tests_root
)
# Strict check: exactly these 2 files should remain (both in src/, not in tests/)
expected_files = {source_file, test_in_src}
assert set(filtered.keys()) == expected_files, (
f"Expected files {expected_files}, got {set(filtered.keys())}"
)
# Strict check: each file should have exactly 1 function with the expected name
assert [fn.function_name for fn in filtered[source_file]] == ["process"], (
f"Expected ['process'], got {[fn.function_name for fn in filtered[source_file]]}"
)
assert [fn.function_name for fn in filtered[test_in_src]] == ["helper_test"], (
f"Expected ['helper_test'], got {[fn.function_name for fn in filtered[test_in_src]]}"
)
# Strict check: exactly 2 functions remaining
assert count == 2, f"Expected exactly 2 functions, got {count}"

View file

@ -0,0 +1,283 @@
"""Tests for JavaScript/TypeScript project initialization and package manager detection."""
from pathlib import Path
import pytest
from codeflash.cli_cmds.init_javascript import (
JsPackageManager,
determine_js_package_manager,
get_package_install_command,
)
@pytest.fixture
def tmp_project(tmp_path: Path) -> Path:
"""Create a temporary project directory."""
return tmp_path
class TestDetermineJsPackageManager:
"""Tests for determine_js_package_manager function."""
def test_detects_pnpm_from_lockfile(self, tmp_project: Path) -> None:
"""Should detect pnpm from pnpm-lock.yaml."""
(tmp_project / "pnpm-lock.yaml").write_text("")
(tmp_project / "package.json").write_text("{}")
result = determine_js_package_manager(tmp_project)
assert result == JsPackageManager.PNPM
def test_detects_yarn_from_lockfile(self, tmp_project: Path) -> None:
"""Should detect yarn from yarn.lock."""
(tmp_project / "yarn.lock").write_text("")
(tmp_project / "package.json").write_text("{}")
result = determine_js_package_manager(tmp_project)
assert result == JsPackageManager.YARN
def test_detects_npm_from_lockfile(self, tmp_project: Path) -> None:
"""Should detect npm from package-lock.json."""
(tmp_project / "package-lock.json").write_text("{}")
(tmp_project / "package.json").write_text("{}")
result = determine_js_package_manager(tmp_project)
assert result == JsPackageManager.NPM
def test_detects_bun_from_lockfile(self, tmp_project: Path) -> None:
"""Should detect bun from bun.lockb."""
(tmp_project / "bun.lockb").write_text("")
(tmp_project / "package.json").write_text("{}")
result = determine_js_package_manager(tmp_project)
assert result == JsPackageManager.BUN
def test_detects_bun_from_bun_lock(self, tmp_project: Path) -> None:
"""Should detect bun from bun.lock."""
(tmp_project / "bun.lock").write_text("")
(tmp_project / "package.json").write_text("{}")
result = determine_js_package_manager(tmp_project)
assert result == JsPackageManager.BUN
def test_defaults_to_npm_with_package_json_only(self, tmp_project: Path) -> None:
"""Should default to npm when only package.json exists."""
(tmp_project / "package.json").write_text("{}")
result = determine_js_package_manager(tmp_project)
assert result == JsPackageManager.NPM
def test_returns_unknown_without_package_json(self, tmp_project: Path) -> None:
"""Should return UNKNOWN when no package.json exists."""
result = determine_js_package_manager(tmp_project)
assert result == JsPackageManager.UNKNOWN
def test_pnpm_takes_precedence_over_npm(self, tmp_project: Path) -> None:
"""Should prefer pnpm when both lockfiles exist (migration scenario)."""
(tmp_project / "pnpm-lock.yaml").write_text("")
(tmp_project / "package-lock.json").write_text("{}")
(tmp_project / "package.json").write_text("{}")
result = determine_js_package_manager(tmp_project)
assert result == JsPackageManager.PNPM
def test_bun_takes_precedence_over_others(self, tmp_project: Path) -> None:
"""Should prefer bun when bun.lockb exists alongside others."""
(tmp_project / "bun.lockb").write_text("")
(tmp_project / "pnpm-lock.yaml").write_text("")
(tmp_project / "package.json").write_text("{}")
result = determine_js_package_manager(tmp_project)
assert result == JsPackageManager.BUN
# Monorepo tests - lock file in parent directory
def test_detects_pnpm_from_parent_lockfile(self, tmp_project: Path) -> None:
"""Should detect pnpm from pnpm-lock.yaml in parent directory (monorepo)."""
# Create monorepo structure: root/packages/my-package
workspace_root = tmp_project
package_dir = workspace_root / "packages" / "my-package"
package_dir.mkdir(parents=True)
# Lock file at workspace root
(workspace_root / "pnpm-lock.yaml").write_text("")
(workspace_root / "package.json").write_text("{}")
# Package has its own package.json but no lock file
(package_dir / "package.json").write_text("{}")
result = determine_js_package_manager(package_dir)
assert result == JsPackageManager.PNPM
def test_detects_yarn_from_parent_lockfile(self, tmp_project: Path) -> None:
"""Should detect yarn from yarn.lock in parent directory (monorepo)."""
workspace_root = tmp_project
package_dir = workspace_root / "packages" / "my-package"
package_dir.mkdir(parents=True)
(workspace_root / "yarn.lock").write_text("")
(workspace_root / "package.json").write_text("{}")
(package_dir / "package.json").write_text("{}")
result = determine_js_package_manager(package_dir)
assert result == JsPackageManager.YARN
def test_detects_npm_from_parent_lockfile(self, tmp_project: Path) -> None:
"""Should detect npm from package-lock.json in parent directory (monorepo)."""
workspace_root = tmp_project
package_dir = workspace_root / "packages" / "my-package"
package_dir.mkdir(parents=True)
(workspace_root / "package-lock.json").write_text("{}")
(workspace_root / "package.json").write_text("{}")
(package_dir / "package.json").write_text("{}")
result = determine_js_package_manager(package_dir)
assert result == JsPackageManager.NPM
def test_detects_bun_from_parent_lockfile(self, tmp_project: Path) -> None:
"""Should detect bun from bun.lockb in parent directory (monorepo)."""
workspace_root = tmp_project
package_dir = workspace_root / "packages" / "my-package"
package_dir.mkdir(parents=True)
(workspace_root / "bun.lockb").write_text("")
(workspace_root / "package.json").write_text("{}")
(package_dir / "package.json").write_text("{}")
result = determine_js_package_manager(package_dir)
assert result == JsPackageManager.BUN
def test_local_lockfile_takes_precedence_over_parent(self, tmp_project: Path) -> None:
"""Should prefer local lock file over parent directory lock file."""
workspace_root = tmp_project
package_dir = workspace_root / "packages" / "my-package"
package_dir.mkdir(parents=True)
# Parent has pnpm, but local package has yarn
(workspace_root / "pnpm-lock.yaml").write_text("")
(workspace_root / "package.json").write_text("{}")
(package_dir / "yarn.lock").write_text("")
(package_dir / "package.json").write_text("{}")
result = determine_js_package_manager(package_dir)
# Should detect yarn from local directory first
assert result == JsPackageManager.YARN
def test_deeply_nested_package_finds_root_lockfile(self, tmp_project: Path) -> None:
"""Should find lock file in deeply nested monorepo structure."""
workspace_root = tmp_project
# Simulate: root/apps/web/src/features/auth
deep_dir = workspace_root / "apps" / "web" / "src" / "features" / "auth"
deep_dir.mkdir(parents=True)
(workspace_root / "pnpm-lock.yaml").write_text("")
(workspace_root / "package.json").write_text("{}")
result = determine_js_package_manager(deep_dir)
assert result == JsPackageManager.PNPM
class TestGetPackageInstallCommand:
"""Tests for get_package_install_command function."""
def test_npm_install_command(self, tmp_project: Path) -> None:
"""Should return npm install command for npm projects."""
(tmp_project / "package-lock.json").write_text("{}")
(tmp_project / "package.json").write_text("{}")
result = get_package_install_command(tmp_project, "codeflash", dev=True)
assert result == ["npm", "install", "codeflash", "--save-dev"]
def test_npm_install_command_non_dev(self, tmp_project: Path) -> None:
"""Should return npm install command without --save-dev when dev=False."""
(tmp_project / "package-lock.json").write_text("{}")
(tmp_project / "package.json").write_text("{}")
result = get_package_install_command(tmp_project, "codeflash", dev=False)
assert result == ["npm", "install", "codeflash"]
def test_pnpm_add_command(self, tmp_project: Path) -> None:
"""Should return pnpm add command for pnpm projects."""
(tmp_project / "pnpm-lock.yaml").write_text("")
(tmp_project / "package.json").write_text("{}")
result = get_package_install_command(tmp_project, "codeflash", dev=True)
assert result == ["pnpm", "add", "codeflash", "--save-dev"]
def test_pnpm_add_command_non_dev(self, tmp_project: Path) -> None:
"""Should return pnpm add command without --save-dev when dev=False."""
(tmp_project / "pnpm-lock.yaml").write_text("")
(tmp_project / "package.json").write_text("{}")
result = get_package_install_command(tmp_project, "codeflash", dev=False)
assert result == ["pnpm", "add", "codeflash"]
def test_yarn_add_command(self, tmp_project: Path) -> None:
"""Should return yarn add command for yarn projects."""
(tmp_project / "yarn.lock").write_text("")
(tmp_project / "package.json").write_text("{}")
result = get_package_install_command(tmp_project, "codeflash", dev=True)
assert result == ["yarn", "add", "codeflash", "--dev"]
def test_yarn_add_command_non_dev(self, tmp_project: Path) -> None:
"""Should return yarn add command without --dev when dev=False."""
(tmp_project / "yarn.lock").write_text("")
(tmp_project / "package.json").write_text("{}")
result = get_package_install_command(tmp_project, "codeflash", dev=False)
assert result == ["yarn", "add", "codeflash"]
def test_bun_add_command(self, tmp_project: Path) -> None:
"""Should return bun add command for bun projects."""
(tmp_project / "bun.lockb").write_text("")
(tmp_project / "package.json").write_text("{}")
result = get_package_install_command(tmp_project, "codeflash", dev=True)
assert result == ["bun", "add", "codeflash", "--dev"]
def test_bun_add_command_non_dev(self, tmp_project: Path) -> None:
"""Should return bun add command without --dev when dev=False."""
(tmp_project / "bun.lockb").write_text("")
(tmp_project / "package.json").write_text("{}")
result = get_package_install_command(tmp_project, "codeflash", dev=False)
assert result == ["bun", "add", "codeflash"]
def test_defaults_to_npm_for_unknown(self, tmp_project: Path) -> None:
"""Should default to npm for unknown package manager."""
# No lockfile, no package.json - unknown package manager
result = get_package_install_command(tmp_project, "codeflash", dev=True)
assert result == ["npm", "install", "codeflash", "--save-dev"]
def test_different_package_name(self, tmp_project: Path) -> None:
"""Should work with different package names."""
(tmp_project / "pnpm-lock.yaml").write_text("")
(tmp_project / "package.json").write_text("{}")
result = get_package_install_command(tmp_project, "typescript", dev=True)
assert result == ["pnpm", "add", "typescript", "--save-dev"]

View file

@ -6,7 +6,22 @@ covering all patterns that might be seen in the wild.
from __future__ import annotations
from pathlib import Path
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.languages.javascript.instrument import TestingMode, instrument_generated_js_test, transform_expect_calls
from codeflash.models.models import FunctionParent
def make_func(name: str, class_name: str | None = None) -> FunctionToOptimize:
"""Helper to create FunctionToOptimize for testing."""
parents = [FunctionParent(name=class_name, type="ClassDef")] if class_name else []
return FunctionToOptimize(
function_name=name,
file_path=Path("/test/file.js"),
parents=parents,
language="javascript",
)
class TestExpectCallTransformer:
@ -15,139 +30,139 @@ class TestExpectCallTransformer:
def test_basic_toBe_assertion(self) -> None:
"""Test basic .toBe() assertion removal."""
code = "expect(fibonacci(5)).toBe(5);"
result, _ = transform_expect_calls(code, "fibonacci", "fibonacci", "capture", remove_assertions=True)
result, _ = transform_expect_calls(code, make_func("fibonacci"), "capture", remove_assertions=True)
assert result == "codeflash.capture('fibonacci', '1', fibonacci, 5);"
def test_basic_toEqual_assertion(self) -> None:
"""Test .toEqual() assertion removal."""
code = "expect(func([1, 2, 3])).toEqual([1, 2, 3]);"
result, _ = transform_expect_calls(code, "func", "func", "capture", remove_assertions=True)
result, _ = transform_expect_calls(code, make_func("func"), "capture", remove_assertions=True)
assert result == "codeflash.capture('func', '1', func, [1, 2, 3]);"
def test_toStrictEqual_assertion(self) -> None:
"""Test .toStrictEqual() assertion removal."""
code = "expect(func({a: 1})).toStrictEqual({a: 1});"
result, _ = transform_expect_calls(code, "func", "func", "capture", remove_assertions=True)
result, _ = transform_expect_calls(code, make_func("func"), "capture", remove_assertions=True)
assert result == "codeflash.capture('func', '1', func, {a: 1});"
def test_toBeCloseTo_with_precision(self) -> None:
"""Test .toBeCloseTo() with precision argument."""
code = "expect(func(3.14159)).toBeCloseTo(3.14, 2);"
result, _ = transform_expect_calls(code, "func", "func", "capture", remove_assertions=True)
result, _ = transform_expect_calls(code, make_func("func"), "capture", remove_assertions=True)
assert result == "codeflash.capture('func', '1', func, 3.14159);"
def test_toBeTruthy_no_args(self) -> None:
"""Test .toBeTruthy() assertion without arguments."""
code = "expect(func(true)).toBeTruthy();"
result, _ = transform_expect_calls(code, "func", "func", "capture", remove_assertions=True)
result, _ = transform_expect_calls(code, make_func("func"), "capture", remove_assertions=True)
assert result == "codeflash.capture('func', '1', func, true);"
def test_toBeFalsy_no_args(self) -> None:
"""Test .toBeFalsy() assertion without arguments."""
code = "expect(func(0)).toBeFalsy();"
result, _ = transform_expect_calls(code, "func", "func", "capture", remove_assertions=True)
result, _ = transform_expect_calls(code, make_func("func"), "capture", remove_assertions=True)
assert result == "codeflash.capture('func', '1', func, 0);"
def test_toBeNull(self) -> None:
"""Test .toBeNull() assertion."""
code = "expect(func(null)).toBeNull();"
result, _ = transform_expect_calls(code, "func", "func", "capture", remove_assertions=True)
result, _ = transform_expect_calls(code, make_func("func"), "capture", remove_assertions=True)
assert result == "codeflash.capture('func', '1', func, null);"
def test_toBeUndefined(self) -> None:
"""Test .toBeUndefined() assertion."""
code = "expect(func()).toBeUndefined();"
result, _ = transform_expect_calls(code, "func", "func", "capture", remove_assertions=True)
result, _ = transform_expect_calls(code, make_func("func"), "capture", remove_assertions=True)
assert result == "codeflash.capture('func', '1', func);"
def test_toBeDefined(self) -> None:
"""Test .toBeDefined() assertion."""
code = "expect(func(1)).toBeDefined();"
result, _ = transform_expect_calls(code, "func", "func", "capture", remove_assertions=True)
result, _ = transform_expect_calls(code, make_func("func"), "capture", remove_assertions=True)
assert result == "codeflash.capture('func', '1', func, 1);"
def test_toBeNaN(self) -> None:
"""Test .toBeNaN() assertion."""
code = "expect(func(NaN)).toBeNaN();"
result, _ = transform_expect_calls(code, "func", "func", "capture", remove_assertions=True)
result, _ = transform_expect_calls(code, make_func("func"), "capture", remove_assertions=True)
assert result == "codeflash.capture('func', '1', func, NaN);"
def test_toBeGreaterThan(self) -> None:
"""Test .toBeGreaterThan() assertion."""
code = "expect(func(10)).toBeGreaterThan(5);"
result, _ = transform_expect_calls(code, "func", "func", "capture", remove_assertions=True)
result, _ = transform_expect_calls(code, make_func("func"), "capture", remove_assertions=True)
assert result == "codeflash.capture('func', '1', func, 10);"
def test_toBeLessThan(self) -> None:
"""Test .toBeLessThan() assertion."""
code = "expect(func(3)).toBeLessThan(5);"
result, _ = transform_expect_calls(code, "func", "func", "capture", remove_assertions=True)
result, _ = transform_expect_calls(code, make_func("func"), "capture", remove_assertions=True)
assert result == "codeflash.capture('func', '1', func, 3);"
def test_toBeGreaterThanOrEqual(self) -> None:
"""Test .toBeGreaterThanOrEqual() assertion."""
code = "expect(func(5)).toBeGreaterThanOrEqual(5);"
result, _ = transform_expect_calls(code, "func", "func", "capture", remove_assertions=True)
result, _ = transform_expect_calls(code, make_func("func"), "capture", remove_assertions=True)
assert result == "codeflash.capture('func', '1', func, 5);"
def test_toBeLessThanOrEqual(self) -> None:
"""Test .toBeLessThanOrEqual() assertion."""
code = "expect(func(5)).toBeLessThanOrEqual(5);"
result, _ = transform_expect_calls(code, "func", "func", "capture", remove_assertions=True)
result, _ = transform_expect_calls(code, make_func("func"), "capture", remove_assertions=True)
assert result == "codeflash.capture('func', '1', func, 5);"
def test_toContain(self) -> None:
"""Test .toContain() assertion."""
code = "expect(func([1, 2, 3])).toContain(2);"
result, _ = transform_expect_calls(code, "func", "func", "capture", remove_assertions=True)
result, _ = transform_expect_calls(code, make_func("func"), "capture", remove_assertions=True)
assert result == "codeflash.capture('func', '1', func, [1, 2, 3]);"
def test_toContainEqual(self) -> None:
"""Test .toContainEqual() assertion."""
code = "expect(func([{a: 1}])).toContainEqual({a: 1});"
result, _ = transform_expect_calls(code, "func", "func", "capture", remove_assertions=True)
result, _ = transform_expect_calls(code, make_func("func"), "capture", remove_assertions=True)
assert result == "codeflash.capture('func', '1', func, [{a: 1}]);"
def test_toHaveLength(self) -> None:
"""Test .toHaveLength() assertion."""
code = "expect(func([1, 2, 3])).toHaveLength(3);"
result, _ = transform_expect_calls(code, "func", "func", "capture", remove_assertions=True)
result, _ = transform_expect_calls(code, make_func("func"), "capture", remove_assertions=True)
assert result == "codeflash.capture('func', '1', func, [1, 2, 3]);"
def test_toMatch_string(self) -> None:
"""Test .toMatch() with string pattern."""
code = "expect(func('hello')).toMatch('ell');"
result, _ = transform_expect_calls(code, "func", "func", "capture", remove_assertions=True)
result, _ = transform_expect_calls(code, make_func("func"), "capture", remove_assertions=True)
assert result == "codeflash.capture('func', '1', func, 'hello');"
def test_toMatch_regex(self) -> None:
"""Test .toMatch() with regex pattern."""
code = "expect(func('hello')).toMatch(/ell/);"
result, _ = transform_expect_calls(code, "func", "func", "capture", remove_assertions=True)
result, _ = transform_expect_calls(code, make_func("func"), "capture", remove_assertions=True)
assert result == "codeflash.capture('func', '1', func, 'hello');"
def test_toMatchObject(self) -> None:
"""Test .toMatchObject() assertion."""
code = "expect(func({a: 1, b: 2})).toMatchObject({a: 1});"
result, _ = transform_expect_calls(code, "func", "func", "capture", remove_assertions=True)
result, _ = transform_expect_calls(code, make_func("func"), "capture", remove_assertions=True)
assert result == "codeflash.capture('func', '1', func, {a: 1, b: 2});"
def test_toHaveProperty(self) -> None:
"""Test .toHaveProperty() assertion."""
code = "expect(func({a: 1})).toHaveProperty('a');"
result, _ = transform_expect_calls(code, "func", "func", "capture", remove_assertions=True)
result, _ = transform_expect_calls(code, make_func("func"), "capture", remove_assertions=True)
assert result == "codeflash.capture('func', '1', func, {a: 1});"
def test_toHaveProperty_with_value(self) -> None:
"""Test .toHaveProperty() with value."""
code = "expect(func({a: 1})).toHaveProperty('a', 1);"
result, _ = transform_expect_calls(code, "func", "func", "capture", remove_assertions=True)
result, _ = transform_expect_calls(code, make_func("func"), "capture", remove_assertions=True)
assert result == "codeflash.capture('func', '1', func, {a: 1});"
def test_toBeInstanceOf(self) -> None:
"""Test .toBeInstanceOf() assertion."""
code = "expect(func()).toBeInstanceOf(Array);"
result, _ = transform_expect_calls(code, "func", "func", "capture", remove_assertions=True)
result, _ = transform_expect_calls(code, make_func("func"), "capture", remove_assertions=True)
assert result == "codeflash.capture('func', '1', func);"
@ -157,31 +172,31 @@ class TestNegatedAssertions:
def test_not_toBe(self) -> None:
"""Test .not.toBe() assertion removal."""
code = "expect(func(5)).not.toBe(10);"
result, _ = transform_expect_calls(code, "func", "func", "capture", remove_assertions=True)
result, _ = transform_expect_calls(code, make_func("func"), "capture", remove_assertions=True)
assert result == "codeflash.capture('func', '1', func, 5);"
def test_not_toEqual(self) -> None:
"""Test .not.toEqual() assertion removal."""
code = "expect(func([1, 2])).not.toEqual([3, 4]);"
result, _ = transform_expect_calls(code, "func", "func", "capture", remove_assertions=True)
result, _ = transform_expect_calls(code, make_func("func"), "capture", remove_assertions=True)
assert result == "codeflash.capture('func', '1', func, [1, 2]);"
def test_not_toBeTruthy(self) -> None:
"""Test .not.toBeTruthy() assertion removal."""
code = "expect(func(0)).not.toBeTruthy();"
result, _ = transform_expect_calls(code, "func", "func", "capture", remove_assertions=True)
result, _ = transform_expect_calls(code, make_func("func"), "capture", remove_assertions=True)
assert result == "codeflash.capture('func', '1', func, 0);"
def test_not_toContain(self) -> None:
"""Test .not.toContain() assertion removal."""
code = "expect(func([1, 2, 3])).not.toContain(4);"
result, _ = transform_expect_calls(code, "func", "func", "capture", remove_assertions=True)
result, _ = transform_expect_calls(code, make_func("func"), "capture", remove_assertions=True)
assert result == "codeflash.capture('func', '1', func, [1, 2, 3]);"
def test_not_toBeNull(self) -> None:
"""Test .not.toBeNull() assertion removal."""
code = "expect(func(1)).not.toBeNull();"
result, _ = transform_expect_calls(code, "func", "func", "capture", remove_assertions=True)
result, _ = transform_expect_calls(code, make_func("func"), "capture", remove_assertions=True)
assert result == "codeflash.capture('func', '1', func, 1);"
@ -191,31 +206,31 @@ class TestAsyncAssertions:
def test_resolves_toBe(self) -> None:
"""Test .resolves.toBe() assertion removal."""
code = "expect(asyncFunc(5)).resolves.toBe(10);"
result, _ = transform_expect_calls(code, "asyncFunc", "asyncFunc", "capture", remove_assertions=True)
result, _ = transform_expect_calls(code, make_func("asyncFunc"), "capture", remove_assertions=True)
assert result == "codeflash.capture('asyncFunc', '1', asyncFunc, 5);"
def test_resolves_toEqual(self) -> None:
"""Test .resolves.toEqual() assertion removal."""
code = "expect(asyncFunc()).resolves.toEqual({data: 'test'});"
result, _ = transform_expect_calls(code, "asyncFunc", "asyncFunc", "capture", remove_assertions=True)
result, _ = transform_expect_calls(code, make_func("asyncFunc"), "capture", remove_assertions=True)
assert result == "codeflash.capture('asyncFunc', '1', asyncFunc);"
def test_rejects_toThrow(self) -> None:
"""Test .rejects.toThrow() assertion removal."""
code = "expect(asyncFunc()).rejects.toThrow();"
result, _ = transform_expect_calls(code, "asyncFunc", "asyncFunc", "capture", remove_assertions=True)
result, _ = transform_expect_calls(code, make_func("asyncFunc"), "capture", remove_assertions=True)
assert result == "codeflash.capture('asyncFunc', '1', asyncFunc);"
def test_rejects_toThrow_with_message(self) -> None:
"""Test .rejects.toThrow() with error message."""
code = "expect(asyncFunc()).rejects.toThrow('Error message');"
result, _ = transform_expect_calls(code, "asyncFunc", "asyncFunc", "capture", remove_assertions=True)
result, _ = transform_expect_calls(code, make_func("asyncFunc"), "capture", remove_assertions=True)
assert result == "codeflash.capture('asyncFunc', '1', asyncFunc);"
def test_not_resolves_toBe(self) -> None:
"""Test .not.resolves.toBe() (rare but valid)."""
code = "expect(asyncFunc()).not.resolves.toBe(5);"
result, _ = transform_expect_calls(code, "asyncFunc", "asyncFunc", "capture", remove_assertions=True)
result, _ = transform_expect_calls(code, make_func("asyncFunc"), "capture", remove_assertions=True)
assert result == "codeflash.capture('asyncFunc', '1', asyncFunc);"
@ -225,31 +240,31 @@ class TestNestedParentheses:
def test_nested_function_call(self) -> None:
"""Test nested function call in arguments."""
code = "expect(func(getN(5))).toBe(10);"
result, _ = transform_expect_calls(code, "func", "func", "capture", remove_assertions=True)
result, _ = transform_expect_calls(code, make_func("func"), "capture", remove_assertions=True)
assert result == "codeflash.capture('func', '1', func, getN(5));"
def test_deeply_nested_calls(self) -> None:
"""Test deeply nested function calls."""
code = "expect(func(outer(inner(deep(1))))).toBe(100);"
result, _ = transform_expect_calls(code, "func", "func", "capture", remove_assertions=True)
result, _ = transform_expect_calls(code, make_func("func"), "capture", remove_assertions=True)
assert result == "codeflash.capture('func', '1', func, outer(inner(deep(1))));"
def test_multiple_nested_args(self) -> None:
"""Test multiple arguments with nested calls."""
code = "expect(func(getA(), getB(getC()))).toBe(5);"
result, _ = transform_expect_calls(code, "func", "func", "capture", remove_assertions=True)
result, _ = transform_expect_calls(code, make_func("func"), "capture", remove_assertions=True)
assert result == "codeflash.capture('func', '1', func, getA(), getB(getC()));"
def test_object_with_nested_calls(self) -> None:
"""Test object argument with nested function calls."""
code = "expect(func({key: getValue()})).toEqual({key: 1});"
result, _ = transform_expect_calls(code, "func", "func", "capture", remove_assertions=True)
result, _ = transform_expect_calls(code, make_func("func"), "capture", remove_assertions=True)
assert result == "codeflash.capture('func', '1', func, {key: getValue()});"
def test_array_with_nested_calls(self) -> None:
"""Test array argument with nested function calls."""
code = "expect(func([getA(), getB()])).toEqual([1, 2]);"
result, _ = transform_expect_calls(code, "func", "func", "capture", remove_assertions=True)
result, _ = transform_expect_calls(code, make_func("func"), "capture", remove_assertions=True)
assert result == "codeflash.capture('func', '1', func, [getA(), getB()]);"
@ -259,31 +274,31 @@ class TestStringLiterals:
def test_string_with_parentheses(self) -> None:
"""Test string argument containing parentheses."""
code = "expect(func('hello (world)')).toBe('result');"
result, _ = transform_expect_calls(code, "func", "func", "capture", remove_assertions=True)
result, _ = transform_expect_calls(code, make_func("func"), "capture", remove_assertions=True)
assert result == "codeflash.capture('func', '1', func, 'hello (world)');"
def test_double_quoted_string_with_parens(self) -> None:
"""Test double-quoted string with parentheses."""
code = 'expect(func("hello (world)")).toBe("result");'
result, _ = transform_expect_calls(code, "func", "func", "capture", remove_assertions=True)
result, _ = transform_expect_calls(code, make_func("func"), "capture", remove_assertions=True)
assert result == "codeflash.capture('func', '1', func, \"hello (world)\");"
def test_template_literal(self) -> None:
"""Test template literal argument."""
code = "expect(func(`template ${value}`)).toBe('result');"
result, _ = transform_expect_calls(code, "func", "func", "capture", remove_assertions=True)
result, _ = transform_expect_calls(code, make_func("func"), "capture", remove_assertions=True)
assert result == "codeflash.capture('func', '1', func, `template ${value}`);"
def test_template_literal_with_parens(self) -> None:
"""Test template literal with parentheses inside."""
code = "expect(func(`hello (${name})`)).toBe('greeting');"
result, _ = transform_expect_calls(code, "func", "func", "capture", remove_assertions=True)
result, _ = transform_expect_calls(code, make_func("func"), "capture", remove_assertions=True)
assert result == "codeflash.capture('func', '1', func, `hello (${name})`);"
def test_escaped_quotes(self) -> None:
"""Test string with escaped quotes."""
code = "expect(func('it\\'s working')).toBe('yes');"
result, _ = transform_expect_calls(code, "func", "func", "capture", remove_assertions=True)
result, _ = transform_expect_calls(code, make_func("func"), "capture", remove_assertions=True)
assert result == "codeflash.capture('func', '1', func, 'it\\'s working');"
@ -293,39 +308,39 @@ class TestWhitespaceHandling:
def test_leading_whitespace_preserved(self) -> None:
"""Test that leading whitespace is preserved."""
code = " expect(func(5)).toBe(5);"
result, _ = transform_expect_calls(code, "func", "func", "capture", remove_assertions=True)
result, _ = transform_expect_calls(code, make_func("func"), "capture", remove_assertions=True)
assert result == " codeflash.capture('func', '1', func, 5);"
def test_tab_indentation(self) -> None:
"""Test tab indentation is preserved."""
code = "\t\texpect(func(5)).toBe(5);"
result, _ = transform_expect_calls(code, "func", "func", "capture", remove_assertions=True)
result, _ = transform_expect_calls(code, make_func("func"), "capture", remove_assertions=True)
assert result == "\t\tcodeflash.capture('func', '1', func, 5);"
def test_no_space_after_expect(self) -> None:
"""Test expect without space before parenthesis."""
code = "expect(func(5)).toBe(5);"
result, _ = transform_expect_calls(code, "func", "func", "capture", remove_assertions=True)
result, _ = transform_expect_calls(code, make_func("func"), "capture", remove_assertions=True)
assert result == "codeflash.capture('func', '1', func, 5);"
def test_space_after_expect(self) -> None:
"""Test expect with space before parenthesis."""
code = "expect (func(5)).toBe(5);"
result, _ = transform_expect_calls(code, "func", "func", "capture", remove_assertions=True)
result, _ = transform_expect_calls(code, make_func("func"), "capture", remove_assertions=True)
assert result == "codeflash.capture('func', '1', func, 5);"
def test_newline_in_assertion(self) -> None:
"""Test assertion split across lines."""
code = """expect(func(5))
.toBe(5);"""
result, _ = transform_expect_calls(code, "func", "func", "capture", remove_assertions=True)
result, _ = transform_expect_calls(code, make_func("func"), "capture", remove_assertions=True)
assert result == "codeflash.capture('func', '1', func, 5);"
def test_newline_after_expect_close(self) -> None:
"""Test newline after expect closing paren."""
code = """expect(func(5))
.toBe(5);"""
result, _ = transform_expect_calls(code, "func", "func", "capture", remove_assertions=True)
result, _ = transform_expect_calls(code, make_func("func"), "capture", remove_assertions=True)
assert result == "codeflash.capture('func', '1', func, 5);"
@ -337,7 +352,7 @@ class TestMultipleAssertions:
code = """expect(func(1)).toBe(1);
expect(func(2)).toBe(2);
expect(func(3)).toBe(3);"""
result, _ = transform_expect_calls(code, "func", "func", "capture", remove_assertions=True)
result, _ = transform_expect_calls(code, make_func("func"), "capture", remove_assertions=True)
expected = """codeflash.capture('func', '1', func, 1);
codeflash.capture('func', '2', func, 2);
codeflash.capture('func', '3', func, 3);"""
@ -348,7 +363,7 @@ codeflash.capture('func', '3', func, 3);"""
code = """expect(func(1)).toBe(1);
expect(func(2)).toEqual(2);
expect(func(3)).not.toBe(0);"""
result, _ = transform_expect_calls(code, "func", "func", "capture", remove_assertions=True)
result, _ = transform_expect_calls(code, make_func("func"), "capture", remove_assertions=True)
expected = """codeflash.capture('func', '1', func, 1);
codeflash.capture('func', '2', func, 2);
codeflash.capture('func', '3', func, 3);"""
@ -360,7 +375,7 @@ codeflash.capture('func', '3', func, 3);"""
expect(func(x)).toBe(10);
console.log('done');
expect(func(x + 1)).toBe(12);"""
result, _ = transform_expect_calls(code, "func", "func", "capture", remove_assertions=True)
result, _ = transform_expect_calls(code, make_func("func"), "capture", remove_assertions=True)
expected = """const x = 5;
codeflash.capture('func', '1', func, x);
console.log('done');
@ -374,20 +389,20 @@ class TestSemicolonHandling:
def test_with_semicolon(self) -> None:
"""Test assertion with trailing semicolon."""
code = "expect(func(5)).toBe(5);"
result, _ = transform_expect_calls(code, "func", "func", "capture", remove_assertions=True)
result, _ = transform_expect_calls(code, make_func("func"), "capture", remove_assertions=True)
assert result == "codeflash.capture('func', '1', func, 5);"
def test_without_semicolon(self) -> None:
"""Test assertion without trailing semicolon."""
code = "expect(func(5)).toBe(5)"
result, _ = transform_expect_calls(code, "func", "func", "capture", remove_assertions=True)
result, _ = transform_expect_calls(code, make_func("func"), "capture", remove_assertions=True)
assert result == "codeflash.capture('func', '1', func, 5);"
def test_multiple_without_semicolons(self) -> None:
"""Test multiple assertions without semicolons (common in some styles)."""
code = """expect(func(1)).toBe(1)
expect(func(2)).toBe(2)"""
result, _ = transform_expect_calls(code, "func", "func", "capture", remove_assertions=True)
result, _ = transform_expect_calls(code, make_func("func"), "capture", remove_assertions=True)
expected = """codeflash.capture('func', '1', func, 1);
codeflash.capture('func', '2', func, 2);"""
assert result == expected
@ -399,25 +414,25 @@ class TestPreservingAssertions:
def test_preserve_toBe(self) -> None:
"""Test preserving .toBe() assertion."""
code = "expect(func(5)).toBe(5);"
result, _ = transform_expect_calls(code, "func", "func", "capture", remove_assertions=False)
result, _ = transform_expect_calls(code, make_func("func"), "capture", remove_assertions=False)
assert result == "expect(codeflash.capture('func', '1', func, 5)).toBe(5);"
def test_preserve_not_toBe(self) -> None:
"""Test preserving .not.toBe() assertion."""
code = "expect(func(5)).not.toBe(10);"
result, _ = transform_expect_calls(code, "func", "func", "capture", remove_assertions=False)
result, _ = transform_expect_calls(code, make_func("func"), "capture", remove_assertions=False)
assert result == "expect(codeflash.capture('func', '1', func, 5)).not.toBe(10);"
def test_preserve_resolves(self) -> None:
"""Test preserving .resolves assertion."""
code = "expect(asyncFunc(5)).resolves.toBe(10);"
result, _ = transform_expect_calls(code, "asyncFunc", "asyncFunc", "capture", remove_assertions=False)
result, _ = transform_expect_calls(code, make_func("asyncFunc"), "capture", remove_assertions=False)
assert result == "expect(codeflash.capture('asyncFunc', '1', asyncFunc, 5)).resolves.toBe(10);"
def test_preserve_toBeCloseTo(self) -> None:
"""Test preserving .toBeCloseTo() with args."""
code = "expect(func(3.14159)).toBeCloseTo(3.14, 2);"
result, _ = transform_expect_calls(code, "func", "func", "capture", remove_assertions=False)
result, _ = transform_expect_calls(code, make_func("func"), "capture", remove_assertions=False)
assert result == "expect(codeflash.capture('func', '1', func, 3.14159)).toBeCloseTo(3.14, 2);"
@ -427,13 +442,13 @@ class TestCaptureFunction:
def test_behavior_mode_uses_capture(self) -> None:
"""Test behavior mode uses capture function."""
code = "expect(func(5)).toBe(5);"
result, _ = transform_expect_calls(code, "func", "func", "capture", remove_assertions=True)
result, _ = transform_expect_calls(code, make_func("func"), "capture", remove_assertions=True)
assert "codeflash.capture(" in result
def test_performance_mode_uses_capturePerf(self) -> None:
"""Test performance mode uses capturePerf function."""
code = "expect(func(5)).toBe(5);"
result, _ = transform_expect_calls(code, "func", "func", "capturePerf", remove_assertions=True)
result, _ = transform_expect_calls(code, make_func("func"), "capturePerf", remove_assertions=True)
assert "codeflash.capturePerf(" in result
@ -443,13 +458,19 @@ class TestQualifiedNames:
def test_simple_qualified_name(self) -> None:
"""Test simple qualified name."""
code = "expect(func(5)).toBe(5);"
result, _ = transform_expect_calls(code, "func", "module.func", "capture", remove_assertions=True)
result, _ = transform_expect_calls(code, make_func("func", class_name="module"), "capture", remove_assertions=True)
assert result == "codeflash.capture('module.func', '1', func, 5);"
def test_nested_qualified_name(self) -> None:
"""Test nested qualified name."""
code = "expect(func(5)).toBe(5);"
result, _ = transform_expect_calls(code, "func", "pkg.module.func", "capture", remove_assertions=True)
func = FunctionToOptimize(
function_name="func",
file_path=Path("/test/file.js"),
parents=[FunctionParent(name="pkg", type="ClassDef"), FunctionParent(name="module", type="ClassDef")],
language="javascript",
)
result, _ = transform_expect_calls(code, func, "capture", remove_assertions=True)
assert result == "codeflash.capture('pkg.module.func', '1', func, 5);"
@ -459,7 +480,7 @@ class TestEdgeCases:
def test_function_name_as_substring(self) -> None:
"""Test that function name matching is exact."""
code = "expect(myFunc(5)).toBe(5); expect(func(10)).toBe(10);"
result, _ = transform_expect_calls(code, "func", "func", "capture", remove_assertions=True)
result, _ = transform_expect_calls(code, make_func("func"), "capture", remove_assertions=True)
# Should only transform func, not myFunc
assert "expect(myFunc(5)).toBe(5)" in result
assert "codeflash.capture('func', '1', func, 10)" in result
@ -467,26 +488,26 @@ class TestEdgeCases:
def test_empty_args(self) -> None:
"""Test function call with no arguments."""
code = "expect(func()).toBe(undefined);"
result, _ = transform_expect_calls(code, "func", "func", "capture", remove_assertions=True)
result, _ = transform_expect_calls(code, make_func("func"), "capture", remove_assertions=True)
assert result == "codeflash.capture('func', '1', func);"
def test_object_method_style(self) -> None:
"""Test that method calls on objects are not matched."""
code = "expect(obj.func(5)).toBe(5);"
result, _ = transform_expect_calls(code, "func", "func", "capture", remove_assertions=True)
result, _ = transform_expect_calls(code, make_func("func"), "capture", remove_assertions=True)
# Should not transform method calls
assert result == "expect(obj.func(5)).toBe(5);"
def test_non_matching_code_unchanged(self) -> None:
"""Test that non-matching code remains unchanged."""
code = "const x = func(5); console.log(x);"
result, _ = transform_expect_calls(code, "func", "func", "capture", remove_assertions=True)
result, _ = transform_expect_calls(code, make_func("func"), "capture", remove_assertions=True)
assert result == code
def test_expect_without_assertion(self) -> None:
"""Test expect without assertion is not transformed."""
code = "const result = expect(func(5));"
result, _ = transform_expect_calls(code, "func", "func", "capture", remove_assertions=True)
result, _ = transform_expect_calls(code, make_func("func"), "capture", remove_assertions=True)
# Should not transform as there's no assertion
assert result == code
@ -504,7 +525,7 @@ describe('fibonacci', () => {
expect(fibonacci(10)).toBe(55);
});
});"""
result = instrument_generated_js_test(code, "fibonacci", "fibonacci", TestingMode.BEHAVIOR)
result = instrument_generated_js_test(code, make_func("fibonacci"), TestingMode.BEHAVIOR)
assert "import codeflash from 'codeflash'" in result
assert "codeflash.capture('fibonacci'" in result
assert ".toBe(" not in result
@ -518,7 +539,7 @@ describe('fibonacci', () => {
expect(fibonacci(5)).toBe(5);
});
});"""
result = instrument_generated_js_test(code, "fibonacci", "fibonacci", TestingMode.PERFORMANCE)
result = instrument_generated_js_test(code, make_func("fibonacci"), TestingMode.PERFORMANCE)
assert "import codeflash from 'codeflash'" in result
assert "codeflash.capturePerf('fibonacci'" in result
assert ".toBe(" not in result
@ -532,7 +553,7 @@ describe('fibonacci', () => {
expect(fibonacci(5)).toBe(5);
});
});"""
result = instrument_generated_js_test(code, "fibonacci", "fibonacci", TestingMode.BEHAVIOR)
result = instrument_generated_js_test(code, make_func("fibonacci"), TestingMode.BEHAVIOR)
assert "const codeflash = require('codeflash')" in result
assert "codeflash.capture('fibonacci'" in result
@ -549,7 +570,7 @@ describe('func', () => {
expect(func(null)).toBeNull();
});
});"""
result = instrument_generated_js_test(code, "func", "func", TestingMode.BEHAVIOR)
result = instrument_generated_js_test(code, make_func("func"), TestingMode.BEHAVIOR)
# All assertions should be removed
assert ".toBe(" not in result
assert ".not." not in result
@ -561,12 +582,12 @@ describe('func', () => {
def test_empty_code(self) -> None:
"""Test with empty code."""
result = instrument_generated_js_test("", "func", "func", TestingMode.BEHAVIOR)
result = instrument_generated_js_test("", make_func("func"), TestingMode.BEHAVIOR)
assert result == ""
def test_whitespace_only_code(self) -> None:
"""Test with whitespace-only code."""
result = instrument_generated_js_test(" \n\t ", "func", "func", TestingMode.BEHAVIOR)
result = instrument_generated_js_test(" \n\t ", make_func("func"), TestingMode.BEHAVIOR)
assert result == " \n\t "
@ -594,7 +615,7 @@ describe('processData', () => {
});
});
});"""
result = instrument_generated_js_test(code, "processData", "processData", TestingMode.BEHAVIOR)
result = instrument_generated_js_test(code, make_func("processData"), TestingMode.BEHAVIOR)
assert result.count("codeflash.capture(") == 3
assert "toEqual(" not in result
assert "toBeNull(" not in result
@ -612,7 +633,7 @@ describe('calculate', () => {
expect(calculate(2, 3, 'mul')).toBe(6);
});
});"""
result = instrument_generated_js_test(code, "calculate", "calculate", TestingMode.BEHAVIOR)
result = instrument_generated_js_test(code, make_func("calculate"), TestingMode.BEHAVIOR)
assert result.count("codeflash.capture(") == 2
assert ".toBe(" not in result
@ -629,7 +650,7 @@ describe('fetchData', () => {
expect(fetchData('/invalid')).rejects.toThrow('Not found');
});
});"""
result = instrument_generated_js_test(code, "fetchData", "fetchData", TestingMode.BEHAVIOR)
result = instrument_generated_js_test(code, make_func("fetchData"), TestingMode.BEHAVIOR)
assert result.count("codeflash.capture(") == 2
assert ".resolves." not in result
assert ".rejects." not in result
@ -647,6 +668,6 @@ describe('calculatePi', () => {
expect(calculatePi(5)).toBeCloseTo(3.14159, 5);
});
});"""
result = instrument_generated_js_test(code, "calculatePi", "calculatePi", TestingMode.BEHAVIOR)
result = instrument_generated_js_test(code, make_func("calculatePi"), TestingMode.BEHAVIOR)
assert result.count("codeflash.capture(") == 2
assert ".toBeCloseTo(" not in result

View file

@ -346,11 +346,11 @@ function capitalize(str) {
"""Test getting a specific function by name."""
js_file = tmp_path / "math_utils.js"
js_file.write_text("""
function add(a, b) {
export function add(a, b) {
return a + b;
}
function subtract(a, b) {
export function subtract(a, b) {
return a - b;
}
""")
@ -378,7 +378,7 @@ function subtract(a, b) {
"""Test getting a specific class method."""
js_file = tmp_path / "calculator.js"
js_file.write_text("""
class Calculator {
export class Calculator {
add(a, b) {
return a + b;
}
@ -388,7 +388,7 @@ class Calculator {
}
}
function standaloneFunc() {
export function standaloneFunc() {
return 42;
}
""")

View file

@ -90,138 +90,138 @@ class TestParentInfo:
class TestFunctionInfo:
"""Tests for the FunctionInfo dataclass."""
"""Tests for the FunctionInfo dataclass (alias for FunctionToOptimize)."""
def test_function_info_creation_minimal(self):
"""Test creating FunctionInfo with minimal args."""
func = FunctionInfo(name="add", file_path=Path("/test/example.py"), start_line=1, end_line=3)
assert func.name == "add"
func = FunctionInfo(function_name="add", file_path=Path("/test/example.py"), starting_line=1, ending_line=3)
assert func.function_name == "add"
assert func.file_path == Path("/test/example.py")
assert func.start_line == 1
assert func.end_line == 3
assert func.parents == ()
assert func.starting_line == 1
assert func.ending_line == 3
assert func.parents == []
assert func.is_async is False
assert func.is_method is False
assert func.language == Language.PYTHON
assert func.language == "python"
def test_function_info_creation_full(self):
"""Test creating FunctionInfo with all args."""
parents = (ParentInfo(name="Calculator", type="ClassDef"),)
parents = [ParentInfo(name="Calculator", type="ClassDef")]
func = FunctionInfo(
name="add",
function_name="add",
file_path=Path("/test/example.py"),
start_line=10,
end_line=15,
starting_line=10,
ending_line=15,
parents=parents,
is_async=True,
is_method=True,
language=Language.PYTHON,
start_col=4,
end_col=20,
language="python",
starting_col=4,
ending_col=20,
)
assert func.name == "add"
assert func.function_name == "add"
assert func.parents == parents
assert func.is_async is True
assert func.is_method is True
assert func.start_col == 4
assert func.end_col == 20
assert func.starting_col == 4
assert func.ending_col == 20
def test_function_info_frozen(self):
"""Test that FunctionInfo is immutable."""
func = FunctionInfo(name="add", file_path=Path("/test/example.py"), start_line=1, end_line=3)
func = FunctionInfo(function_name="add", file_path=Path("/test/example.py"), starting_line=1, ending_line=3)
with pytest.raises(AttributeError):
func.name = "new_name"
func.function_name = "new_name"
def test_qualified_name_no_parents(self):
"""Test qualified_name without parents."""
func = FunctionInfo(name="add", file_path=Path("/test/example.py"), start_line=1, end_line=3)
func = FunctionInfo(function_name="add", file_path=Path("/test/example.py"), starting_line=1, ending_line=3)
assert func.qualified_name == "add"
def test_qualified_name_with_class(self):
"""Test qualified_name with class parent."""
func = FunctionInfo(
name="add",
function_name="add",
file_path=Path("/test/example.py"),
start_line=1,
end_line=3,
parents=(ParentInfo(name="Calculator", type="ClassDef"),),
starting_line=1,
ending_line=3,
parents=[ParentInfo(name="Calculator", type="ClassDef")],
)
assert func.qualified_name == "Calculator.add"
def test_qualified_name_nested(self):
"""Test qualified_name with nested parents."""
func = FunctionInfo(
name="inner",
function_name="inner",
file_path=Path("/test/example.py"),
start_line=1,
end_line=3,
parents=(ParentInfo(name="Outer", type="ClassDef"), ParentInfo(name="Inner", type="ClassDef")),
starting_line=1,
ending_line=3,
parents=[ParentInfo(name="Outer", type="ClassDef"), ParentInfo(name="Inner", type="ClassDef")],
)
assert func.qualified_name == "Outer.Inner.inner"
def test_class_name_with_class(self):
"""Test class_name property with class parent."""
func = FunctionInfo(
name="add",
function_name="add",
file_path=Path("/test/example.py"),
start_line=1,
end_line=3,
parents=(ParentInfo(name="Calculator", type="ClassDef"),),
starting_line=1,
ending_line=3,
parents=[ParentInfo(name="Calculator", type="ClassDef")],
)
assert func.class_name == "Calculator"
def test_class_name_without_class(self):
"""Test class_name property without class parent."""
func = FunctionInfo(name="add", file_path=Path("/test/example.py"), start_line=1, end_line=3)
func = FunctionInfo(function_name="add", file_path=Path("/test/example.py"), starting_line=1, ending_line=3)
assert func.class_name is None
def test_class_name_nested_function(self):
"""Test class_name for function nested in another function."""
func = FunctionInfo(
name="inner",
function_name="inner",
file_path=Path("/test/example.py"),
start_line=1,
end_line=3,
parents=(ParentInfo(name="outer", type="FunctionDef"),),
starting_line=1,
ending_line=3,
parents=[ParentInfo(name="outer", type="FunctionDef")],
)
assert func.class_name is None
def test_class_name_method_in_nested_class(self):
"""Test class_name for method in nested class."""
func = FunctionInfo(
name="method",
function_name="method",
file_path=Path("/test/example.py"),
start_line=1,
end_line=3,
parents=(ParentInfo(name="Outer", type="ClassDef"), ParentInfo(name="Inner", type="ClassDef")),
starting_line=1,
ending_line=3,
parents=[ParentInfo(name="Outer", type="ClassDef"), ParentInfo(name="Inner", type="ClassDef")],
)
# Should return the immediate parent class
assert func.class_name == "Inner"
def test_top_level_parent_name_no_parents(self):
"""Test top_level_parent_name without parents."""
func = FunctionInfo(name="add", file_path=Path("/test/example.py"), start_line=1, end_line=3)
func = FunctionInfo(function_name="add", file_path=Path("/test/example.py"), starting_line=1, ending_line=3)
assert func.top_level_parent_name == "add"
def test_top_level_parent_name_with_parents(self):
"""Test top_level_parent_name with parents."""
func = FunctionInfo(
name="method",
function_name="method",
file_path=Path("/test/example.py"),
start_line=1,
end_line=3,
parents=(ParentInfo(name="Outer", type="ClassDef"), ParentInfo(name="Inner", type="ClassDef")),
starting_line=1,
ending_line=3,
parents=[ParentInfo(name="Outer", type="ClassDef"), ParentInfo(name="Inner", type="ClassDef")],
)
assert func.top_level_parent_name == "Outer"
def test_function_info_str(self):
"""Test string representation."""
func = FunctionInfo(
name="add",
function_name="add",
file_path=Path("/test/example.py"),
start_line=1,
end_line=3,
parents=(ParentInfo(name="Calculator", type="ClassDef"),),
starting_line=1,
ending_line=3,
parents=[ParentInfo(name="Calculator", type="ClassDef")],
)
s = str(func)
assert "Calculator.add" in s

View file

@ -90,7 +90,7 @@ const multiply = (a, b) => a * b;
functions = js_support.discover_functions(file_path)
assert len(functions) == 1
func = functions[0]
assert func.name == "multiply"
assert func.function_name == "multiply"
context = js_support.extract_code_context(func, temp_project, temp_project)
@ -268,7 +268,7 @@ class CacheManager {
file_path.write_text(code, encoding="utf-8")
functions = js_support.discover_functions(file_path)
get_or_compute = next(f for f in functions if f.name == "getOrCompute")
get_or_compute = next(f for f in functions if f.function_name == "getOrCompute")
context = js_support.extract_code_context(get_or_compute, temp_project, temp_project)
@ -370,7 +370,7 @@ function validateUserData(data, validators) {
file_path.write_text(code, encoding="utf-8")
functions = js_support.discover_functions(file_path)
func = next(f for f in functions if f.name == "validateUserData")
func = next(f for f in functions if f.function_name == "validateUserData")
context = js_support.extract_code_context(func, temp_project, temp_project)
@ -466,7 +466,7 @@ async function fetchWithRetry(endpoint, options = {}) {
file_path.write_text(code, encoding="utf-8")
functions = js_support.discover_functions(file_path)
func = next(f for f in functions if f.name == "fetchWithRetry")
func = next(f for f in functions if f.function_name == "fetchWithRetry")
context = js_support.extract_code_context(func, temp_project, temp_project)
@ -615,7 +615,7 @@ function processUserInput(rawInput) {
file_path.write_text(code, encoding="utf-8")
functions = js_support.discover_functions(file_path)
process_func = next(f for f in functions if f.name == "processUserInput")
process_func = next(f for f in functions if f.function_name == "processUserInput")
context = js_support.extract_code_context(process_func, temp_project, temp_project)
@ -670,7 +670,7 @@ function generateReport(data) {
file_path.write_text(code, encoding="utf-8")
functions = js_support.discover_functions(file_path)
report_func = next(f for f in functions if f.name == "generateReport")
report_func = next(f for f in functions if f.function_name == "generateReport")
context = js_support.extract_code_context(report_func, temp_project, temp_project)
@ -768,7 +768,7 @@ class Graph {
file_path.write_text(code, encoding="utf-8")
functions = js_support.discover_functions(file_path)
topo_sort = next(f for f in functions if f.name == "topologicalSort")
topo_sort = next(f for f in functions if f.function_name == "topologicalSort")
context = js_support.extract_code_context(topo_sort, temp_project, temp_project)
@ -843,7 +843,7 @@ class MainClass {
file_path.write_text(code, encoding="utf-8")
functions = js_support.discover_functions(file_path)
main_method = next(f for f in functions if f.name == "mainMethod" and f.class_name == "MainClass")
main_method = next(f for f in functions if f.function_name == "mainMethod" and f.class_name == "MainClass")
context = js_support.extract_code_context(main_method, temp_project, temp_project)
@ -899,7 +899,7 @@ module.exports = { sortFromAnotherFile };
main_path.write_text(main_code, encoding="utf-8")
functions = js_support.discover_functions(main_path)
main_func = next(f for f in functions if f.name == "sortFromAnotherFile")
main_func = next(f for f in functions if f.function_name == "sortFromAnotherFile")
context = js_support.extract_code_context(main_func, temp_project, temp_project)
@ -952,7 +952,7 @@ export { processNumber };
main_path.write_text(main_code, encoding="utf-8")
functions = js_support.discover_functions(main_path)
process_func = next(f for f in functions if f.name == "processNumber")
process_func = next(f for f in functions if f.function_name == "processNumber")
context = js_support.extract_code_context(process_func, temp_project, temp_project)
@ -1020,7 +1020,7 @@ export { handleUserInput };
main_path.write_text(main_code, encoding="utf-8")
functions = js_support.discover_functions(main_path)
handle_func = next(f for f in functions if f.name == "handleUserInput")
handle_func = next(f for f in functions if f.function_name == "handleUserInput")
context = js_support.extract_code_context(handle_func, temp_project, temp_project)
@ -1161,7 +1161,7 @@ class TypedCache<T> {
file_path.write_text(code, encoding="utf-8")
functions = ts_support.discover_functions(file_path)
get_method = next(f for f in functions if f.name == "get")
get_method = next(f for f in functions if f.function_name == "get")
context = ts_support.extract_code_context(get_method, temp_project, temp_project)
@ -1247,7 +1247,7 @@ export { createUser };
service_path.write_text(service_code, encoding="utf-8")
functions = ts_support.discover_functions(service_path)
func = next(f for f in functions if f.name == "createUser")
func = next(f for f in functions if f.function_name == "createUser")
context = ts_support.extract_code_context(func, temp_project, temp_project)
@ -1331,7 +1331,7 @@ function isOdd(n) {
file_path.write_text(code, encoding="utf-8")
functions = js_support.discover_functions(file_path)
is_even = next(f for f in functions if f.name == "isEven")
is_even = next(f for f in functions if f.function_name == "isEven")
context = js_support.extract_code_context(is_even, temp_project, temp_project)
@ -1393,7 +1393,7 @@ function collectAllValues(root) {
file_path.write_text(code, encoding="utf-8")
functions = js_support.discover_functions(file_path)
collect_func = next(f for f in functions if f.name == "collectAllValues")
collect_func = next(f for f in functions if f.function_name == "collectAllValues")
context = js_support.extract_code_context(collect_func, temp_project, temp_project)
@ -1458,7 +1458,7 @@ async function fetchUserProfile(userId) {
file_path.write_text(code, encoding="utf-8")
functions = js_support.discover_functions(file_path)
profile_func = next(f for f in functions if f.name == "fetchUserProfile")
profile_func = next(f for f in functions if f.function_name == "fetchUserProfile")
context = js_support.extract_code_context(profile_func, temp_project, temp_project)
@ -1513,7 +1513,7 @@ module.exports = { Counter };
file_path.write_text(original_source, encoding="utf-8")
functions = js_support.discover_functions(file_path)
increment_func = next(fn for fn in functions if fn.name == "increment")
increment_func = next(fn for fn in functions if fn.function_name == "increment")
# Step 1: Extract code context
context = js_support.extract_code_context(increment_func, temp_project, temp_project)
@ -1635,7 +1635,7 @@ function* fibonacci(limit) {
file_path.write_text(code, encoding="utf-8")
functions = js_support.discover_functions(file_path)
range_func = next(f for f in functions if f.name == "range")
range_func = next(f for f in functions if f.function_name == "range")
context = js_support.extract_code_context(range_func, temp_project, temp_project)
@ -1772,7 +1772,7 @@ class Calculator {
functions = js_support.discover_functions(file_path)
for func in functions:
if func.name != "constructor":
if func.function_name != "constructor":
context = js_support.extract_code_context(func, temp_project, temp_project)
is_valid = js_support.validate_syntax(context.target_code)
assert is_valid is True, f"Invalid syntax for {func.name}:\n{context.target_code}"

File diff suppressed because it is too large Load diff

View file

@ -474,6 +474,13 @@ class TestCommonJSExports:
return TreeSitterAnalyzer(TreeSitterLanguage.JAVASCRIPT)
@pytest.fixture
def ts_analyzer(self):
"""Create a TypeScript analyzer."""
from codeflash.languages.treesitter_utils import TreeSitterAnalyzer, TreeSitterLanguage
return TreeSitterAnalyzer(TreeSitterLanguage.TYPESCRIPT)
def test_module_exports_function(self, js_analyzer):
"""Test module.exports = function() {}."""
code = "module.exports = function helper() { return 1; };"
@ -584,6 +591,51 @@ exports.helper = helper;
assert is_exported is True
assert export_name == "helper"
def test_is_class_method_exported_via_class(self, ts_analyzer):
"""Test is_function_exported returns True for method of exported class."""
code = """
export class BloomFilter {
getHashValues(key: string): number[] {
return [1, 2, 3];
}
}
"""
# Method itself is not directly exported
is_exported, export_name = ts_analyzer.is_function_exported(code, "getHashValues")
assert is_exported is False
assert export_name is None
# But when we pass the class name, it should find the class export
is_exported, export_name = ts_analyzer.is_function_exported(code, "getHashValues", "BloomFilter")
assert is_exported is True
assert export_name == "BloomFilter"
def test_is_class_method_exported_default_class(self, ts_analyzer):
"""Test is_function_exported returns True for method of default exported class."""
code = """
export default class Calculator {
add(a: number, b: number): number {
return a + b;
}
}
"""
# When we pass the class name, it should find the default export
is_exported, export_name = ts_analyzer.is_function_exported(code, "add", "Calculator")
assert is_exported is True
assert export_name == "Calculator"
def test_is_class_method_not_exported_non_exported_class(self, ts_analyzer):
"""Test is_function_exported returns False for method of non-exported class."""
code = """
class InternalClass {
helper(): void {}
}
"""
# Even with class name, non-exported class method should not be exported
is_exported, export_name = ts_analyzer.is_function_exported(code, "helper", "InternalClass")
assert is_exported is False
assert export_name is None
class TestCommonJSImportResolver:
"""Tests for ImportResolver with CommonJS require() imports."""

View file

@ -312,7 +312,7 @@ public class Calculator {
functions = discover_functions_from_source(
java_file.read_text(), file_path=java_file
)
add_func = next((f for f in functions if f.name == "add"), None)
add_func = next((f for f in functions if f.function_name == "add"), None)
assert add_func is not None
context = extract_code_context(add_func, tmp_path)
@ -678,7 +678,7 @@ class TestExtractCodeContextWithHelpers:
functions = discover_functions_from_source(
java_file.read_text(), file_path=java_file
)
process_func = next((f for f in functions if f.name == "process"), None)
process_func = next((f for f in functions if f.function_name == "process"), None)
assert process_func is not None
context = extract_code_context(process_func, tmp_path)
@ -719,7 +719,7 @@ class TestExtractCodeContextWithHelpers:
functions = discover_functions_from_source(
java_file.read_text(), file_path=java_file
)
process_func = next((f for f in functions if f.name == "process"), None)
process_func = next((f for f in functions if f.function_name == "process"), None)
assert process_func is not None
context = extract_code_context(process_func, tmp_path)
@ -756,7 +756,7 @@ class TestExtractCodeContextWithHelpers:
functions = discover_functions_from_source(
java_file.read_text(), file_path=java_file
)
process_func = next((f for f in functions if f.name == "process"), None)
process_func = next((f for f in functions if f.function_name == "process"), None)
assert process_func is not None
context = extract_code_context(process_func, tmp_path)
@ -780,7 +780,7 @@ class TestExtractCodeContextWithHelpers:
functions = discover_functions_from_source(
java_file.read_text(), file_path=java_file
)
add_func = next((f for f in functions if f.name == "add"), None)
add_func = next((f for f in functions if f.function_name == "add"), None)
assert add_func is not None
context = extract_code_context(add_func, tmp_path)
@ -809,7 +809,7 @@ class TestExtractCodeContextWithHelpers:
functions = discover_functions_from_source(
java_file.read_text(), file_path=java_file
)
calc_func = next((f for f in functions if f.name == "calculate"), None)
calc_func = next((f for f in functions if f.function_name == "calculate"), None)
assert calc_func is not None
context = extract_code_context(calc_func, tmp_path)
@ -1392,7 +1392,7 @@ class TestExtractCodeContextWithInheritance:
)
assert len(functions) == 2
run_func = next((f for f in functions if f.name == "run"), None)
run_func = next((f for f in functions if f.function_name == "run"), None)
assert run_func is not None
context = extract_code_context(run_func, tmp_path)
@ -1417,7 +1417,7 @@ class TestExtractCodeContextWithInheritance:
functions = discover_functions_from_source(
java_file.read_text(), file_path=java_file
)
greet_func = next((f for f in functions if f.name == "greet"), None)
greet_func = next((f for f in functions if f.function_name == "greet"), None)
assert greet_func is not None
context = extract_code_context(greet_func, tmp_path)
@ -1449,7 +1449,7 @@ class TestExtractCodeContextWithInnerClasses:
functions = discover_functions_from_source(
java_file.read_text(), file_path=java_file
)
compute_func = next((f for f in functions if f.name == "compute"), None)
compute_func = next((f for f in functions if f.function_name == "compute"), None)
assert compute_func is not None
context = extract_code_context(compute_func, tmp_path)
@ -1481,7 +1481,7 @@ class TestExtractCodeContextWithInnerClasses:
functions = discover_functions_from_source(
java_file.read_text(), file_path=java_file
)
get_func = next((f for f in functions if f.name == "getValue"), None)
get_func = next((f for f in functions if f.function_name == "getValue"), None)
assert get_func is not None
context = extract_code_context(get_func, tmp_path)
@ -1521,7 +1521,7 @@ class TestExtractCodeContextWithEnumAndInterface:
functions = discover_functions_from_source(
java_file.read_text(), file_path=java_file
)
apply_func = next((f for f in functions if f.name == "apply"), None)
apply_func = next((f for f in functions if f.function_name == "apply"), None)
assert apply_func is not None
context = extract_code_context(apply_func, tmp_path)
@ -1555,7 +1555,7 @@ class TestExtractCodeContextWithEnumAndInterface:
functions = discover_functions_from_source(
java_file.read_text(), file_path=java_file
)
greet_func = next((f for f in functions if f.name == "greet"), None)
greet_func = next((f for f in functions if f.function_name == "greet"), None)
assert greet_func is not None
context = extract_code_context(greet_func, tmp_path)
@ -1581,7 +1581,7 @@ class TestExtractCodeContextWithEnumAndInterface:
functions = discover_functions_from_source(
java_file.read_text(), file_path=java_file
)
create_func = next((f for f in functions if f.name == "create"), None)
create_func = next((f for f in functions if f.function_name == "create"), None)
assert create_func is not None
context = extract_code_context(create_func, tmp_path)
@ -1767,16 +1767,17 @@ class TestExtractCodeContextEdgeCases:
def test_file_not_found(self, tmp_path: Path):
"""Test context extraction for missing file."""
from codeflash.languages.base import FunctionInfo
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.models.function_types import FunctionParent
missing_file = tmp_path / "NonExistent.java"
func = FunctionInfo(
name="test",
func = FunctionToOptimize(
function_name="test",
file_path=missing_file,
start_line=1,
end_line=5,
parents=(ParentInfo(name="Test", type="ClassDef"),),
language=Language.JAVA,
starting_line=1,
ending_line=5,
parents=[FunctionParent(name="Test", type="ClassDef")],
language="java",
)
context = extract_code_context(func, tmp_path)
@ -1801,7 +1802,7 @@ class TestExtractCodeContextEdgeCases:
functions = discover_functions_from_source(
java_file.read_text(), file_path=java_file
)
calc_func = next((f for f in functions if f.name == "calculate"), None)
calc_func = next((f for f in functions if f.function_name == "calculate"), None)
assert calc_func is not None
context = extract_code_context(calc_func, tmp_path, max_helper_depth=0)
@ -1838,7 +1839,7 @@ class TestExtractCodeContextWithConstructor:
functions = discover_functions_from_source(
java_file.read_text(), file_path=java_file
)
get_func = next((f for f in functions if f.name == "getName"), None)
get_func = next((f for f in functions if f.function_name == "getName"), None)
assert get_func is not None
context = extract_code_context(get_func, tmp_path)
@ -1885,7 +1886,7 @@ class TestExtractCodeContextWithConstructor:
functions = discover_functions_from_source(
java_file.read_text(), file_path=java_file
)
get_func = next((f for f in functions if f.name == "getName"), None)
get_func = next((f for f in functions if f.function_name == "getName"), None)
assert get_func is not None
context = extract_code_context(get_func, tmp_path)
@ -1940,7 +1941,7 @@ public class Service {
functions = discover_functions_from_source(
java_file.read_text(), file_path=java_file
)
process_func = next((f for f in functions if f.name == "process"), None)
process_func = next((f for f in functions if f.function_name == "process"), None)
assert process_func is not None
context = extract_code_context(process_func, tmp_path)
@ -2000,7 +2001,7 @@ public class Calculator {
functions = discover_functions_from_source(
java_file.read_text(), file_path=java_file
)
sqrt_func = next((f for f in functions if f.name == "sqrtNewton"), None)
sqrt_func = next((f for f in functions if f.function_name == "sqrtNewton"), None)
assert sqrt_func is not None
context = extract_code_context(sqrt_func, tmp_path)

View file

@ -28,7 +28,7 @@ public class Calculator {
"""
functions = discover_functions_from_source(source)
assert len(functions) == 1
assert functions[0].name == "add"
assert functions[0].function_name == "add"
assert functions[0].language == Language.JAVA
assert functions[0].is_method is True
assert functions[0].class_name == "Calculator"
@ -52,7 +52,7 @@ public class Calculator {
"""
functions = discover_functions_from_source(source)
assert len(functions) == 3
method_names = {f.name for f in functions}
method_names = {f.function_name for f in functions}
assert method_names == {"add", "subtract", "multiply"}
def test_skip_abstract_methods(self):
@ -69,7 +69,7 @@ public abstract class Shape {
functions = discover_functions_from_source(source)
# Should only find perimeter, not area
assert len(functions) == 1
assert functions[0].name == "perimeter"
assert functions[0].function_name == "perimeter"
def test_skip_constructors(self):
"""Test that constructors are skipped."""
@ -89,7 +89,7 @@ public class Person {
functions = discover_functions_from_source(source)
# Should only find getName, not the constructor
assert len(functions) == 1
assert functions[0].name == "getName"
assert functions[0].function_name == "getName"
def test_filter_by_pattern(self):
"""Test filtering by include patterns."""
@ -111,7 +111,7 @@ public class StringUtils {
criteria = FunctionFilterCriteria(include_patterns=["*Upper*", "*Lower*"])
functions = discover_functions_from_source(source, filter_criteria=criteria)
assert len(functions) == 2
method_names = {f.name for f in functions}
method_names = {f.function_name for f in functions}
assert method_names == {"toUpperCase", "toLowerCase"}
def test_filter_exclude_pattern(self):
@ -128,7 +128,7 @@ public class DataService {
require_return=False, # Allow void methods
)
functions = discover_functions_from_source(source, filter_criteria=criteria)
method_names = {f.name for f in functions}
method_names = {f.function_name for f in functions}
assert "setData" not in method_names
def test_filter_require_return(self):
@ -145,7 +145,7 @@ public class Example {
criteria = FunctionFilterCriteria(require_return=True)
functions = discover_functions_from_source(source, filter_criteria=criteria)
assert len(functions) == 1
assert functions[0].name == "getValue"
assert functions[0].function_name == "getValue"
def test_filter_by_line_count(self):
"""Test filtering by line count."""
@ -167,7 +167,7 @@ public class Example {
functions = discover_functions_from_source(source, filter_criteria=criteria)
# The 'long' method should be included (>3 lines)
# The 'short' method should be excluded (1 line)
method_names = {f.name for f in functions}
method_names = {f.function_name for f in functions}
assert "long" in method_names or len(functions) >= 1
def test_method_with_javadoc(self):
@ -189,7 +189,7 @@ public class Example {
assert len(functions) == 1
assert functions[0].doc_start_line is not None
# Doc should start before the method
assert functions[0].doc_start_line < functions[0].start_line
assert functions[0].doc_start_line < functions[0].starting_line
class TestDiscoverTestMethods:
@ -222,7 +222,7 @@ class CalculatorTest {
""")
tests = discover_test_methods(test_file)
assert len(tests) == 2
test_names = {t.name for t in tests}
test_names = {t.function_name for t in tests}
assert test_names == {"testAdd", "testSubtract"}
def test_discover_parameterized_tests(self, tmp_path: Path):
@ -244,7 +244,7 @@ class StringTest {
""")
tests = discover_test_methods(test_file)
assert len(tests) == 1
assert tests[0].name == "testLength"
assert tests[0].function_name == "testLength"
class TestGetMethodByName:
@ -266,7 +266,7 @@ public class Calculator {
""")
method = get_method_by_name(java_file, "add")
assert method is not None
assert method.name == "add"
assert method.function_name == "add"
def test_get_method_not_found(self, tmp_path: Path):
"""Test getting a method that doesn't exist."""
@ -299,7 +299,7 @@ class Helper {
""")
methods = get_class_methods(java_file, "Calculator")
assert len(methods) == 1
assert methods[0].name == "add"
assert methods[0].function_name == "add"
class TestFileBasedDiscovery:
@ -321,7 +321,7 @@ class TestFileBasedDiscovery:
functions = discover_functions(calculator_file)
assert len(functions) > 0
method_names = {f.name for f in functions}
method_names = {f.function_name for f in functions}
# Should find methods from Calculator.java
assert "fibonacci" in method_names or "add" in method_names or len(method_names) > 0

View file

@ -18,8 +18,10 @@ from pathlib import Path
import pytest
from codeflash.languages.base import FunctionInfo, Language
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.languages.base import Language
from codeflash.languages.current import set_current_language
from codeflash.models.function_types import FunctionParent
from codeflash.languages.java.build_tools import find_maven_executable
from codeflash.languages.java.discovery import discover_functions_from_source
from codeflash.languages.java.instrumentation import (
@ -76,14 +78,14 @@ public class CalculatorTest {
}
}
"""
func = FunctionInfo(
name="add",
func = FunctionToOptimize(
function_name="add",
file_path=Path("Calculator.java"),
start_line=1,
end_line=5,
parents=(),
starting_line=1,
ending_line=5,
parents=[],
is_method=True,
language=Language.JAVA,
language="java",
)
result = instrument_for_benchmarking(source, func)
@ -108,14 +110,14 @@ public class CalculatorTest {
"""
test_file.write_text(source)
func = FunctionInfo(
name="add",
func = FunctionToOptimize(
function_name="add",
file_path=tmp_path / "Calculator.java",
start_line=1,
end_line=5,
parents=(),
starting_line=1,
ending_line=5,
parents=[],
is_method=True,
language=Language.JAVA,
language="java",
)
success, result = instrument_existing_test(
@ -158,14 +160,14 @@ public class CalculatorTest {
"""
test_file.write_text(source)
func = FunctionInfo(
name="add",
func = FunctionToOptimize(
function_name="add",
file_path=tmp_path / "Calculator.java",
start_line=1,
end_line=5,
parents=(),
starting_line=1,
ending_line=5,
parents=[],
is_method=True,
language=Language.JAVA,
language="java",
)
success, result = instrument_existing_test(
@ -225,14 +227,14 @@ public class MathTest {
"""
test_file.write_text(source)
func = FunctionInfo(
name="calculate",
func = FunctionToOptimize(
function_name="calculate",
file_path=tmp_path / "Math.java",
start_line=1,
end_line=5,
parents=(),
starting_line=1,
ending_line=5,
parents=[],
is_method=True,
language=Language.JAVA,
language="java",
)
success, result = instrument_existing_test(
@ -317,14 +319,14 @@ public class ServiceTest {
"""
test_file.write_text(source)
func = FunctionInfo(
name="call",
func = FunctionToOptimize(
function_name="call",
file_path=tmp_path / "Service.java",
start_line=1,
end_line=5,
parents=(),
starting_line=1,
ending_line=5,
parents=[],
is_method=True,
language=Language.JAVA,
language="java",
)
success, result = instrument_existing_test(
@ -394,14 +396,14 @@ public class ServiceTest__perfonlyinstrumented {
"""Test handling missing test file."""
test_file = tmp_path / "NonExistent.java"
func = FunctionInfo(
name="add",
func = FunctionToOptimize(
function_name="add",
file_path=tmp_path / "Calculator.java",
start_line=1,
end_line=5,
parents=(),
starting_line=1,
ending_line=5,
parents=[],
is_method=True,
language=Language.JAVA,
language="java",
)
success, result = instrument_existing_test(
@ -562,14 +564,14 @@ class TestCreateBenchmarkTest:
def test_create_benchmark(self):
"""Test creating a benchmark test."""
func = FunctionInfo(
name="add",
func = FunctionToOptimize(
function_name="add",
file_path=Path("Calculator.java"),
start_line=1,
end_line=5,
parents=(),
starting_line=1,
ending_line=5,
parents=[],
is_method=True,
language=Language.JAVA,
language="java",
)
result = create_benchmark_test(
@ -617,14 +619,14 @@ public class TargetBenchmark {
def test_create_benchmark_different_iterations(self):
"""Test benchmark with different iteration count."""
func = FunctionInfo(
name="multiply",
func = FunctionToOptimize(
function_name="multiply",
file_path=Path("Math.java"),
start_line=1,
end_line=3,
parents=(),
starting_line=1,
ending_line=3,
parents=[],
is_method=True,
language=Language.JAVA,
language="java",
)
result = create_benchmark_test(
@ -905,14 +907,14 @@ public class BraceTest {
"""
test_file.write_text(source)
func = FunctionInfo(
name="process",
func = FunctionToOptimize(
function_name="process",
file_path=tmp_path / "Processor.java",
start_line=1,
end_line=5,
parents=(),
starting_line=1,
ending_line=5,
parents=[],
is_method=True,
language=Language.JAVA,
language="java",
)
success, result = instrument_existing_test(
@ -998,14 +1000,14 @@ public class ImportTest {
"""
test_file.write_text(source)
func = FunctionInfo(
name="size",
func = FunctionToOptimize(
function_name="size",
file_path=tmp_path / "Collection.java",
start_line=1,
end_line=5,
parents=(),
starting_line=1,
ending_line=5,
parents=[],
is_method=True,
language=Language.JAVA,
language="java",
)
success, result = instrument_existing_test(
@ -1068,14 +1070,14 @@ public class EmptyTest {
"""
test_file.write_text(source)
func = FunctionInfo(
name="empty",
func = FunctionToOptimize(
function_name="empty",
file_path=tmp_path / "Empty.java",
start_line=1,
end_line=5,
parents=(),
starting_line=1,
ending_line=5,
parents=[],
is_method=True,
language=Language.JAVA,
language="java",
)
success, result = instrument_existing_test(
@ -1134,14 +1136,14 @@ public class NestedTest {
"""
test_file.write_text(source)
func = FunctionInfo(
name="process",
func = FunctionToOptimize(
function_name="process",
file_path=tmp_path / "Processor.java",
start_line=1,
end_line=5,
parents=(),
starting_line=1,
ending_line=5,
parents=[],
is_method=True,
language=Language.JAVA,
language="java",
)
success, result = instrument_existing_test(
@ -1210,14 +1212,14 @@ public class InnerClassTest {
"""
test_file.write_text(source)
func = FunctionInfo(
name="testMethod",
func = FunctionToOptimize(
function_name="testMethod",
file_path=tmp_path / "Target.java",
start_line=1,
end_line=5,
parents=(),
starting_line=1,
ending_line=5,
parents=[],
is_method=True,
language=Language.JAVA,
language="java",
)
success, result = instrument_existing_test(
@ -1412,14 +1414,14 @@ public class CalculatorTest {
test_file = test_dir / "CalculatorTest.java"
test_file.write_text(test_source, encoding="utf-8")
func_info = FunctionInfo(
name="add",
func_info = FunctionToOptimize(
function_name="add",
file_path=src_dir / "Calculator.java",
start_line=4,
end_line=6,
parents=(),
starting_line=4,
ending_line=6,
parents=[],
is_method=True,
language=Language.JAVA,
language="java",
)
success, instrumented = instrument_existing_test(
@ -1524,14 +1526,14 @@ public class MathUtilsTest {
test_file = test_dir / "MathUtilsTest.java"
test_file.write_text(test_source, encoding="utf-8")
func_info = FunctionInfo(
name="multiply",
func_info = FunctionToOptimize(
function_name="multiply",
file_path=src_dir / "MathUtils.java",
start_line=4,
end_line=6,
parents=(),
starting_line=4,
ending_line=6,
parents=[],
is_method=True,
language=Language.JAVA,
language="java",
)
success, instrumented = instrument_existing_test(
@ -1657,14 +1659,14 @@ public class StringUtilsTest {
test_file = test_dir / "StringUtilsTest.java"
test_file.write_text(test_source, encoding="utf-8")
func_info = FunctionInfo(
name="reverse",
func_info = FunctionToOptimize(
function_name="reverse",
file_path=src_dir / "StringUtils.java",
start_line=4,
end_line=6,
parents=(),
starting_line=4,
ending_line=6,
parents=[],
is_method=True,
language=Language.JAVA,
language="java",
)
success, instrumented = instrument_existing_test(
@ -1759,14 +1761,14 @@ public class BrokenCalcTest {
test_file = test_dir / "BrokenCalcTest.java"
test_file.write_text(test_source, encoding="utf-8")
func_info = FunctionInfo(
name="add",
func_info = FunctionToOptimize(
function_name="add",
file_path=src_dir / "BrokenCalc.java",
start_line=4,
end_line=6,
parents=(),
starting_line=4,
ending_line=6,
parents=[],
is_method=True,
language=Language.JAVA,
language="java",
)
success, instrumented = instrument_existing_test(
@ -1869,14 +1871,14 @@ public class CounterTest {
test_file.write_text(test_source, encoding="utf-8")
# Instrument for BEHAVIOR mode (this should include SQLite writing)
func_info = FunctionInfo(
name="increment",
func_info = FunctionToOptimize(
function_name="increment",
file_path=src_dir / "Counter.java",
start_line=6,
end_line=8,
parents=(),
starting_line=6,
ending_line=8,
parents=[],
is_method=True,
language=Language.JAVA,
language="java",
)
success, instrumented = instrument_existing_test(
@ -2033,14 +2035,14 @@ public class FibonacciTest {
test_file.write_text(test_source, encoding="utf-8")
# Instrument for performance mode (adds inner loop)
func_info = FunctionInfo(
name="fib",
func_info = FunctionToOptimize(
function_name="fib",
file_path=src_dir / "Fibonacci.java",
start_line=4,
end_line=7,
parents=(),
starting_line=4,
ending_line=7,
parents=[],
is_method=True,
language=Language.JAVA,
language="java",
)
success, instrumented = instrument_existing_test(
@ -2154,14 +2156,14 @@ public class MathOpsTest {
test_file.write_text(test_source, encoding="utf-8")
# Instrument for performance mode
func_info = FunctionInfo(
name="add",
func_info = FunctionToOptimize(
function_name="add",
file_path=src_dir / "MathOps.java",
start_line=4,
end_line=6,
parents=(),
starting_line=4,
ending_line=6,
parents=[],
is_method=True,
language=Language.JAVA,
language="java",
)
success, instrumented = instrument_existing_test(

View file

@ -104,7 +104,7 @@ class TestEndToEndWorkflow:
context = extract_code_context(func, java_fixture_path)
assert context.target_code
assert func.name in context.target_code
assert func.function_name in context.target_code
assert context.language == Language.JAVA
def test_code_replacement_workflow(self):
@ -198,7 +198,7 @@ public class StringUtilsTest {
# 1. Discover functions
functions = support.discover_functions(src_file)
assert len(functions) == 1
assert functions[0].name == "reverse"
assert functions[0].function_name == "reverse"
# 2. Extract code context
context = support.extract_code_context(functions[0], tmp_path, tmp_path)
@ -330,14 +330,14 @@ public class Example {
criteria = FunctionFilterCriteria(include_patterns=["public*"])
functions = discover_functions_from_source(source, filter_criteria=criteria)
# Should match publicMethod
public_names = {f.name for f in functions}
public_names = {f.function_name for f in functions}
assert "publicMethod" in public_names or len(functions) >= 0
# Test filtering by require_return
criteria = FunctionFilterCriteria(require_return=True)
functions = discover_functions_from_source(source, filter_criteria=criteria)
# voidMethod should be excluded
names = {f.name for f in functions}
names = {f.function_name for f in functions}
assert "voidMethod" not in names

View file

@ -13,6 +13,7 @@ from codeflash.code_utils.code_replacer import (
replace_function_definitions_for_language,
replace_function_definitions_in_module,
)
from codeflash.models.function_types import FunctionParent
from codeflash.languages.base import Language
from codeflash.languages import current as language_current
from codeflash.models.models import CodeStringsMarkdown
@ -1524,7 +1525,7 @@ public final class Buffer {{
file_path=java_file,
starting_line=13, # Line where 3-arg version starts (1-indexed)
ending_line=18,
parents=(FunctionParent(name="Buffer", type="class"),),
parents=[FunctionParent(name="Buffer", type="ClassDef")],
qualified_name="Buffer.bytesToHexString",
is_method=True,
)

View file

@ -58,7 +58,7 @@ public class Calculator {
functions = support.discover_functions(java_file)
assert len(functions) == 1
assert functions[0].name == "add"
assert functions[0].function_name == "add"
assert functions[0].language == Language.JAVA
def test_validate_syntax_valid(self, support):

View file

@ -167,16 +167,17 @@ public class StringUtilsTest {
""")
# Create source function
from codeflash.languages.base import FunctionInfo, Language
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.languages.base import Language
func = FunctionInfo(
name="reverse",
func = FunctionToOptimize(
function_name="reverse",
file_path=tmp_path / "StringUtils.java",
start_line=1,
end_line=5,
parents=(),
starting_line=1,
ending_line=5,
parents=[],
is_method=True,
language=Language.JAVA,
language="java",
)
tests = find_tests_for_function(func, test_dir)
@ -246,7 +247,7 @@ public class TestQueryBlob {
)
# Filter to just bytesToHexString
target_functions = [f for f in source_functions if f.name == "bytesToHexString"]
target_functions = [f for f in source_functions if f.function_name == "bytesToHexString"]
assert len(target_functions) == 1, "Should find bytesToHexString function"
# Discover tests

View file

@ -5,16 +5,29 @@ Tests the full optimization pipeline for JavaScript:
- Code context extraction
- Test discovery
- Code replacement
Note: These tests require JS/TS language support to be registered.
They will be skipped in environments where only Python is supported.
"""
import tempfile
from pathlib import Path
import pytest
from codeflash.discovery.functions_to_optimize import find_all_functions_in_file, get_files_for_language
from codeflash.languages.base import Language
def skip_if_js_not_supported():
"""Skip test if JavaScript/TypeScript languages are not supported."""
try:
from codeflash.languages import get_language_support
get_language_support(Language.JAVASCRIPT)
except Exception as e:
pytest.skip(f"JavaScript/TypeScript language support not available: {e}")
class TestJavaScriptFunctionDiscovery:
"""Tests for JavaScript function discovery in the main pipeline."""
@ -29,6 +42,9 @@ class TestJavaScriptFunctionDiscovery:
def test_discover_functions_in_fibonacci(self, js_project_dir):
"""Test discovering functions in fibonacci.js."""
skip_if_js_not_supported()
from codeflash.discovery.functions_to_optimize import find_all_functions_in_file
fib_file = js_project_dir / "fibonacci.js"
if not fib_file.exists():
pytest.skip("fibonacci.js not found")
@ -38,19 +54,17 @@ class TestJavaScriptFunctionDiscovery:
assert fib_file in functions
func_list = functions[fib_file]
# Should find the main exported functions
func_names = {f.function_name for f in func_list}
assert "fibonacci" in func_names
assert "isFibonacci" in func_names
assert "isPerfectSquare" in func_names
assert "fibonacciSequence" in func_names
assert func_names == {"fibonacci", "isFibonacci", "isPerfectSquare", "fibonacciSequence"}
# All should be JavaScript functions
for func in func_list:
assert func.language == "javascript"
def test_discover_functions_in_bubble_sort(self, js_project_dir):
"""Test discovering functions in bubble_sort.js."""
skip_if_js_not_supported()
from codeflash.discovery.functions_to_optimize import find_all_functions_in_file
sort_file = js_project_dir / "bubble_sort.js"
if not sort_file.exists():
pytest.skip("bubble_sort.js not found")
@ -65,13 +79,14 @@ class TestJavaScriptFunctionDiscovery:
def test_get_javascript_files(self, js_project_dir):
"""Test getting JavaScript files from directory."""
skip_if_js_not_supported()
from codeflash.discovery.functions_to_optimize import get_files_for_language
files = get_files_for_language(js_project_dir, Language.JAVASCRIPT)
# Should find .js files
js_files = [f for f in files if f.suffix == ".js"]
assert len(js_files) >= 3 # fibonacci.js, bubble_sort.js, string_utils.js
assert len(js_files) >= 3
# Should not include test files in root (they're in tests/)
root_files = [f for f in js_files if f.parent == js_project_dir]
assert len(root_files) >= 3
@ -90,11 +105,11 @@ class TestJavaScriptCodeContext:
def test_extract_code_context_for_javascript(self, js_project_dir):
"""Test extracting code context for a JavaScript function."""
skip_if_js_not_supported()
from codeflash.context.code_context_extractor import get_code_optimization_context
from codeflash.discovery.functions_to_optimize import find_all_functions_in_file
from codeflash.languages import current as lang_current
from codeflash.languages.base import Language
# Force set language to JavaScript for proper context extraction routing
lang_current._current_language = Language.JAVASCRIPT
fib_file = js_project_dir / "fibonacci.js"
@ -104,21 +119,30 @@ class TestJavaScriptCodeContext:
functions = find_all_functions_in_file(fib_file)
func_list = functions[fib_file]
# Find the fibonacci function
fib_func = next((f for f in func_list if f.function_name == "fibonacci"), None)
assert fib_func is not None
# Extract code context
context = get_code_optimization_context(fib_func, js_project_dir)
# Verify context structure
assert context.read_writable_code is not None
assert context.read_writable_code.language == "javascript"
assert len(context.read_writable_code.code_strings) > 0
# The code should contain the function
code = context.read_writable_code.code_strings[0].code
assert "fibonacci" in code
expected_code = """/**
* Calculate the nth Fibonacci number using naive recursion.
* This is intentionally slow to demonstrate optimization potential.
* @param {number} n - The index of the Fibonacci number to calculate
* @returns {number} - The nth Fibonacci number
*/
function fibonacci(n) {
if (n <= 1) {
return n;
}
return fibonacci(n - 1) + fibonacci(n - 2);
}
"""
assert code == expected_code
class TestJavaScriptCodeReplacement:
@ -126,8 +150,9 @@ class TestJavaScriptCodeReplacement:
def test_replace_function_in_javascript_file(self):
"""Test replacing a function in a JavaScript file."""
skip_if_js_not_supported()
from codeflash.languages import get_language_support
from codeflash.languages.base import FunctionInfo, Language
from codeflash.languages.base import FunctionInfo
original_source = """
function add(a, b) {
@ -146,16 +171,23 @@ function multiply(a, b) {
js_support = get_language_support(Language.JAVASCRIPT)
# Create FunctionInfo for the add function
func_info = FunctionInfo(
name="add", file_path=Path("/tmp/test.js"), start_line=2, end_line=4, language=Language.JAVASCRIPT
function_name="add", file_path=Path("/tmp/test.js"), starting_line=2, ending_line=4, language="javascript"
)
result = js_support.replace_function(original_source, func_info, new_function)
# Verify the function was replaced
assert "// Optimized version" in result
assert "multiply" in result # Other function should still be there
expected_result = """
function add(a, b) {
// Optimized version
return a + b;
}
function multiply(a, b) {
return a * b;
}
"""
assert result == expected_result
class TestJavaScriptTestDiscovery:
@ -172,8 +204,9 @@ class TestJavaScriptTestDiscovery:
def test_discover_jest_tests(self, js_project_dir):
"""Test discovering Jest tests for JavaScript functions."""
skip_if_js_not_supported()
from codeflash.languages import get_language_support
from codeflash.languages.base import FunctionInfo, Language
from codeflash.languages.base import FunctionInfo
js_support = get_language_support(Language.JAVASCRIPT)
test_root = js_project_dir / "tests"
@ -181,17 +214,14 @@ class TestJavaScriptTestDiscovery:
if not test_root.exists():
pytest.skip("tests directory not found")
# Create FunctionInfo for fibonacci function
fib_file = js_project_dir / "fibonacci.js"
func_info = FunctionInfo(
name="fibonacci", file_path=fib_file, start_line=11, end_line=16, language=Language.JAVASCRIPT
function_name="fibonacci", file_path=fib_file, starting_line=11, ending_line=16, language="javascript"
)
# Discover tests
tests = js_support.discover_tests(test_root, [func_info])
# Should find tests for fibonacci
assert func_info.qualified_name in tests or "fibonacci" in str(tests)
assert func_info.qualified_name in tests or len(tests) > 0
class TestJavaScriptPipelineIntegration:
@ -199,6 +229,9 @@ class TestJavaScriptPipelineIntegration:
def test_function_to_optimize_has_correct_fields(self):
"""Test that FunctionToOptimize from JavaScript has all required fields."""
skip_if_js_not_supported()
from codeflash.discovery.functions_to_optimize import find_all_functions_in_file
with tempfile.NamedTemporaryFile(suffix=".js", mode="w", delete=False) as f:
f.write("""
class Calculator {
@ -220,16 +253,13 @@ function standalone(x) {
functions = find_all_functions_in_file(file_path)
# Should find class methods and standalone function
assert len(functions.get(file_path, [])) >= 3
# Check standalone function
standalone_fn = next((fn for fn in functions[file_path] if fn.function_name == "standalone"), None)
assert standalone_fn is not None
assert standalone_fn.language == "javascript"
assert len(standalone_fn.parents) == 0
# Check class method
add_fn = next((fn for fn in functions[file_path] if fn.function_name == "add"), None)
assert add_fn is not None
assert add_fn.language == "javascript"
@ -250,4 +280,4 @@ function standalone(x) {
)
markdown = code_strings.markdown
assert "```javascript" in markdown or "```js" in markdown.lower()
assert "```javascript" in markdown

View file

@ -3,12 +3,27 @@
This module tests the line profiling and tracing instrumentation for JavaScript code.
"""
from __future__ import annotations
import tempfile
from pathlib import Path
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.languages.base import FunctionInfo, Language
from codeflash.languages.javascript.line_profiler import JavaScriptLineProfiler
from codeflash.languages.javascript.tracer import JavaScriptTracer
from codeflash.models.models import FunctionParent
def make_func(name: str, class_name: str | None = None) -> FunctionToOptimize:
"""Helper to create FunctionToOptimize for testing."""
parents = [FunctionParent(name=class_name, type="ClassDef")] if class_name else []
return FunctionToOptimize(
function_name=name,
file_path=Path("/test/file.js"),
parents=parents,
language="javascript",
)
class TestJavaScriptLineProfiler:
@ -49,7 +64,7 @@ function add(a, b) {
file_path = Path(f.name)
func_info = FunctionInfo(
name="add", file_path=file_path, start_line=2, end_line=5, language=Language.JAVASCRIPT
function_name="add", file_path=file_path, starting_line=2, ending_line=5, language="javascript"
)
output_file = Path("/tmp/test_profile.json")
@ -110,7 +125,7 @@ function multiply(x, y) {
file_path = Path(f.name)
func_info = FunctionInfo(
name="multiply", file_path=file_path, start_line=2, end_line=4, language=Language.JAVASCRIPT
function_name="multiply", file_path=file_path, starting_line=2, ending_line=4, language="javascript"
)
output_db = Path("/tmp/test_traces.db")
@ -154,7 +169,7 @@ function greet(name) {
file_path = Path(f.name)
func_info = FunctionInfo(
name="greet", file_path=file_path, start_line=2, end_line=4, language=Language.JAVASCRIPT
function_name="greet", file_path=file_path, starting_line=2, ending_line=4, language="javascript"
)
output_file = file_path.parent / ".codeflash" / "traces.db"
@ -185,7 +200,7 @@ function square(n) {
file_path = Path(f.name)
func_info = FunctionInfo(
name="square", file_path=file_path, start_line=2, end_line=5, language=Language.JAVASCRIPT
function_name="square", file_path=file_path, starting_line=2, ending_line=5, language="javascript"
)
output_file = file_path.parent / ".codeflash" / "line_profile.json"
@ -352,7 +367,7 @@ const result = calc.fibonacci(10);
console.log(result);
"""
transformed, counter = transform_standalone_calls(
code=code, func_name="fibonacci", qualified_name="Calculator.fibonacci", capture_func="capture"
code=code, function_to_optimize=make_func("fibonacci", class_name="Calculator"), capture_func="capture"
)
# Should transform calc.fibonacci(10) to codeflash.capture(..., calc.fibonacci.bind(calc), 10)
@ -371,7 +386,7 @@ test('fibonacci works', () => {
});
"""
transformed, counter = transform_expect_calls(
code=code, func_name="fibonacci", qualified_name="FibonacciCalculator.fibonacci", capture_func="capture"
code=code, function_to_optimize=make_func("fibonacci", class_name="FibonacciCalculator"), capture_func="capture"
)
# Should transform expect(calc.fibonacci(10)) to
@ -393,8 +408,7 @@ test('fibonacci works', () => {
"""
transformed, counter = transform_expect_calls(
code=code,
func_name="fibonacci",
qualified_name="FibonacciCalculator.fibonacci",
function_to_optimize=make_func("fibonacci", class_name="FibonacciCalculator"),
capture_func="capture",
remove_assertions=True,
)
@ -419,7 +433,7 @@ class FibonacciCalculator {
}
"""
transformed, counter = transform_standalone_calls(
code=code, func_name="fibonacci", qualified_name="FibonacciCalculator.fibonacci", capture_func="capture"
code=code, function_to_optimize=make_func("fibonacci", class_name="FibonacciCalculator"), capture_func="capture"
)
# The method definition should NOT be transformed
@ -438,7 +452,7 @@ FibonacciCalculator.prototype.fibonacci = function(n) {
};
"""
transformed, counter = transform_standalone_calls(
code=code, func_name="fibonacci", qualified_name="FibonacciCalculator.fibonacci", capture_func="capture"
code=code, function_to_optimize=make_func("fibonacci", class_name="FibonacciCalculator"), capture_func="capture"
)
# The prototype assignment should NOT be transformed
@ -456,7 +470,7 @@ const b = calc.fibonacci(10);
const sum = a + b;
"""
transformed, counter = transform_standalone_calls(
code=code, func_name="fibonacci", qualified_name="Calculator.fibonacci", capture_func="capture"
code=code, function_to_optimize=make_func("fibonacci", class_name="Calculator"), capture_func="capture"
)
# Should transform both calls
@ -475,7 +489,7 @@ class Wrapper {
}
"""
transformed, counter = transform_standalone_calls(
code=code, func_name="fibonacci", qualified_name="Wrapper.fibonacci", capture_func="capture"
code=code, function_to_optimize=make_func("fibonacci", class_name="Wrapper"), capture_func="capture"
)
# Should transform this.fibonacci(n)
@ -515,10 +529,9 @@ describe('FibonacciCalculator', () => {
"""
instrumented = _instrument_js_test_code(
code=test_code,
func_name="fibonacci",
function_to_optimize=make_func("fibonacci", class_name="FibonacciCalculator"),
test_file_path="test.js",
mode="behavior",
qualified_name="FibonacciCalculator.fibonacci",
)
# Check that codeflash import was added
@ -545,7 +558,7 @@ describe('Calculator', () => {
});
"""
instrumented = _instrument_js_test_code(
code=test_code, func_name="add", test_file_path="test.js", mode="behavior", qualified_name="Calculator.add"
code=test_code, function_to_optimize=make_func("add", class_name="Calculator"), test_file_path="test.js", mode="behavior"
)
# describe and test structure should be preserved
@ -567,7 +580,7 @@ const data = await api.fetchData('http://example.com');
console.log(data);
"""
transformed, counter = transform_standalone_calls(
code=code, func_name="fetchData", qualified_name="ApiClient.fetchData", capture_func="capture"
code=code, function_to_optimize=make_func("fetchData", class_name="ApiClient"), capture_func="capture"
)
# Should preserve await
@ -586,7 +599,7 @@ class TestInstrumentationFullStringEquality:
code = " calc.fibonacci(10);"
transformed, counter = transform_standalone_calls(
code=code, func_name="fibonacci", qualified_name="Calculator.fibonacci", capture_func="capture"
code=code, function_to_optimize=make_func("fibonacci", class_name="Calculator"), capture_func="capture"
)
expected = " codeflash.capture('Calculator.fibonacci', '1', calc.fibonacci.bind(calc), 10);"
@ -600,7 +613,7 @@ class TestInstrumentationFullStringEquality:
code = " expect(calc.fibonacci(10)).toBe(55);"
transformed, counter = transform_expect_calls(
code=code, func_name="fibonacci", qualified_name="Calculator.fibonacci", capture_func="capture"
code=code, function_to_optimize=make_func("fibonacci", class_name="Calculator"), capture_func="capture"
)
expected = " expect(codeflash.capture('Calculator.fibonacci', '1', calc.fibonacci.bind(calc), 10)).toBe(55);"
@ -615,8 +628,7 @@ class TestInstrumentationFullStringEquality:
transformed, counter = transform_expect_calls(
code=code,
func_name="fibonacci",
qualified_name="Calculator.fibonacci",
function_to_optimize=make_func("fibonacci", class_name="Calculator"),
capture_func="capture",
remove_assertions=True,
)
@ -632,7 +644,7 @@ class TestInstrumentationFullStringEquality:
code = " fibonacci(10);"
transformed, counter = transform_standalone_calls(
code=code, func_name="fibonacci", qualified_name="fibonacci", capture_func="capture"
code=code, function_to_optimize=make_func("fibonacci"), capture_func="capture"
)
expected = " codeflash.capture('fibonacci', '1', fibonacci, 10);"
@ -646,9 +658,9 @@ class TestInstrumentationFullStringEquality:
code = " return this.fibonacci(n - 1);"
transformed, counter = transform_standalone_calls(
code=code, func_name="fibonacci", qualified_name="Class.fibonacci", capture_func="capture"
code=code, function_to_optimize=make_func("fibonacci", class_name="Class"), capture_func="capture"
)
expected = " return codeflash.capture('Class.fibonacci', '1', this.fibonacci.bind(this), n - 1);"
assert transformed == expected, f"Expected:\n{expected}\nGot:\n{transformed}"
assert counter == 1
assert counter == 1

View file

@ -0,0 +1,251 @@
"""Tests for JavaScript requirements verification.
Tests the verify_requirements function that checks Node.js, npm, and test framework availability.
"""
import json
import subprocess
from pathlib import Path
from unittest.mock import MagicMock, patch
import pytest
from codeflash.languages.javascript.support import JavaScriptSupport
class TestVerifyRequirements:
"""Tests for JavaScriptSupport.verify_requirements()."""
@pytest.fixture
def js_support(self):
"""Create a JavaScriptSupport instance."""
return JavaScriptSupport()
@pytest.fixture
def project_with_jest(self, tmp_path):
"""Create a project directory with Jest installed."""
node_modules = tmp_path / "node_modules"
node_modules.mkdir()
(node_modules / "jest").mkdir()
(node_modules / "codeflash").mkdir()
package_json = tmp_path / "package.json"
package_json.write_text(
json.dumps(
{
"name": "test-project",
"devDependencies": {"jest": "^29.0.0"},
}
)
)
return tmp_path
@pytest.fixture
def project_with_vitest(self, tmp_path):
"""Create a project directory with Vitest installed."""
node_modules = tmp_path / "node_modules"
node_modules.mkdir()
(node_modules / "vitest").mkdir()
(node_modules / "codeflash").mkdir()
package_json = tmp_path / "package.json"
package_json.write_text(
json.dumps(
{
"name": "test-project",
"devDependencies": {"vitest": "^2.0.0"},
}
)
)
return tmp_path
@pytest.fixture
def project_without_node_modules(self, tmp_path):
"""Create a project directory without node_modules."""
package_json = tmp_path / "package.json"
package_json.write_text(json.dumps({"name": "test-project"}))
return tmp_path
@pytest.fixture
def project_without_jest(self, tmp_path):
"""Create a project directory with node_modules but without Jest."""
node_modules = tmp_path / "node_modules"
node_modules.mkdir()
(node_modules / "some-other-package").mkdir()
package_json = tmp_path / "package.json"
package_json.write_text(json.dumps({"name": "test-project"}))
return tmp_path
def test_verify_requirements_success_with_jest(self, js_support, project_with_jest):
"""Test successful verification when Jest is installed."""
with patch("subprocess.run") as mock_run:
mock_run.return_value = MagicMock(returncode=0)
success, errors = js_support.verify_requirements(project_with_jest, "jest")
assert success is True
assert errors == []
def test_verify_requirements_success_with_vitest(self, js_support, project_with_vitest):
"""Test successful verification when Vitest is installed."""
with patch("subprocess.run") as mock_run:
mock_run.return_value = MagicMock(returncode=0)
success, errors = js_support.verify_requirements(project_with_vitest, "vitest")
assert success is True
assert errors == []
def test_verify_requirements_fails_without_node(self, js_support, project_with_jest):
"""Test verification fails when Node.js is not installed."""
with patch("subprocess.run") as mock_run:
mock_run.side_effect = FileNotFoundError("node not found")
success, errors = js_support.verify_requirements(project_with_jest, "jest")
assert success is False
assert len(errors) >= 1
node_error_found = any("Node.js" in error for error in errors)
assert node_error_found is True
def test_verify_requirements_fails_without_npm(self, js_support, project_with_jest):
"""Test verification fails when npm is not available."""
def mock_run_side_effect(cmd, **kwargs):
if cmd[0] == "node":
return MagicMock(returncode=0)
if cmd[0] == "npm":
raise FileNotFoundError("npm not found")
return MagicMock(returncode=0)
with patch("subprocess.run", side_effect=mock_run_side_effect):
success, errors = js_support.verify_requirements(project_with_jest, "jest")
assert success is False
npm_error_found = any("npm" in error for error in errors)
assert npm_error_found is True
def test_verify_requirements_fails_without_node_modules(self, js_support, project_without_node_modules):
"""Test verification fails when node_modules doesn't exist."""
with patch("subprocess.run") as mock_run:
mock_run.return_value = MagicMock(returncode=0)
success, errors = js_support.verify_requirements(project_without_node_modules, "jest")
assert success is False
assert len(errors) == 1
expected_error = (
f"node_modules not found in {project_without_node_modules}. "
f"Please run 'npm install' to install dependencies."
)
assert errors[0] == expected_error
def test_verify_requirements_fails_without_test_framework(self, js_support, project_without_jest):
"""Test verification fails when test framework is not installed."""
with patch("subprocess.run") as mock_run:
mock_run.return_value = MagicMock(returncode=0)
success, errors = js_support.verify_requirements(project_without_jest, "jest")
assert success is False
assert len(errors) == 1
expected_error = "jest is not installed. Please run 'npm install --save-dev jest' to install it."
assert errors[0] == expected_error
def test_verify_requirements_returns_multiple_errors(self, js_support, project_without_node_modules):
"""Test that multiple errors can be returned."""
with patch("subprocess.run") as mock_run:
mock_run.side_effect = FileNotFoundError("command not found")
success, errors = js_support.verify_requirements(project_without_node_modules, "jest")
assert success is False
assert len(errors) >= 2
# Should have errors for Node.js, npm, and node_modules
error_text = " ".join(errors)
assert "Node.js" in error_text
assert "npm" in error_text
def test_verify_requirements_vitest_not_installed(self, js_support, project_with_jest):
"""Test verification fails when Vitest is requested but only Jest is installed."""
with patch("subprocess.run") as mock_run:
mock_run.return_value = MagicMock(returncode=0)
success, errors = js_support.verify_requirements(project_with_jest, "vitest")
assert success is False
assert len(errors) == 1
expected_error = "vitest is not installed. Please run 'npm install --save-dev vitest' to install it."
assert errors[0] == expected_error
def test_verify_requirements_jest_not_installed(self, js_support, project_with_vitest):
"""Test verification fails when Jest is requested but only Vitest is installed."""
with patch("subprocess.run") as mock_run:
mock_run.return_value = MagicMock(returncode=0)
success, errors = js_support.verify_requirements(project_with_vitest, "jest")
assert success is False
assert len(errors) == 1
expected_error = "jest is not installed. Please run 'npm install --save-dev jest' to install it."
assert errors[0] == expected_error
class TestVerifyRequirementsIntegration:
"""Integration tests for verify_requirements with real filesystem."""
@pytest.fixture
def js_support(self):
"""Create a JavaScriptSupport instance."""
return JavaScriptSupport()
def test_verify_on_real_vitest_project(self, js_support):
"""Test verification on the real vitest sample project."""
project_root = Path(__file__).parent.parent.parent / "code_to_optimize" / "js" / "code_to_optimize_vitest"
if not project_root.exists():
pytest.skip("code_to_optimize_vitest directory not found")
node_modules = project_root / "node_modules"
if not node_modules.exists():
pytest.skip("node_modules not installed in vitest project")
# This test verifies the real project structure
with patch("subprocess.run") as mock_run:
mock_run.return_value = MagicMock(returncode=0)
success, errors = js_support.verify_requirements(project_root, "vitest")
# If vitest is installed, should succeed
vitest_installed = (node_modules / "vitest").exists()
if vitest_installed:
assert success is True
assert errors == []
else:
assert success is False
assert len(errors) >= 1
def test_verify_on_real_jest_project(self, js_support):
"""Test verification on the real Jest sample project."""
project_root = Path(__file__).parent.parent.parent / "code_to_optimize" / "js" / "code_to_optimize_ts"
if not project_root.exists():
pytest.skip("code_to_optimize_ts directory not found")
node_modules = project_root / "node_modules"
if not node_modules.exists():
pytest.skip("node_modules not installed in jest project")
with patch("subprocess.run") as mock_run:
mock_run.return_value = MagicMock(returncode=0)
success, errors = js_support.verify_requirements(project_root, "jest")
jest_installed = (node_modules / "jest").exists()
if jest_installed:
assert success is True
assert errors == []
else:
assert success is False
assert len(errors) >= 1

View file

@ -55,7 +55,7 @@ function add(a, b) {
functions = js_support.discover_functions(Path(f.name))
assert len(functions) == 1
assert functions[0].name == "add"
assert functions[0].function_name == "add"
assert functions[0].language == Language.JAVASCRIPT
def test_discover_multiple_functions(self, js_support):
@ -79,7 +79,7 @@ function multiply(a, b) {
functions = js_support.discover_functions(Path(f.name))
assert len(functions) == 3
names = {func.name for func in functions}
names = {func.function_name for func in functions}
assert names == {"add", "subtract", "multiply"}
def test_discover_arrow_function(self, js_support):
@ -97,7 +97,7 @@ const multiply = (x, y) => x * y;
functions = js_support.discover_functions(Path(f.name))
assert len(functions) == 2
names = {func.name for func in functions}
names = {func.function_name for func in functions}
assert names == {"add", "multiply"}
def test_discover_function_without_return_excluded(self, js_support):
@ -118,7 +118,7 @@ function withoutReturn() {
# Only the function with return should be discovered
assert len(functions) == 1
assert functions[0].name == "withReturn"
assert functions[0].function_name == "withReturn"
def test_discover_class_methods(self, js_support):
"""Test discovering class methods."""
@ -161,8 +161,8 @@ function syncFunction() {
assert len(functions) == 2
async_func = next(f for f in functions if f.name == "fetchData")
sync_func = next(f for f in functions if f.name == "syncFunction")
async_func = next(f for f in functions if f.function_name == "fetchData")
sync_func = next(f for f in functions if f.function_name == "syncFunction")
assert async_func.is_async is True
assert sync_func.is_async is False
@ -185,7 +185,7 @@ function syncFunc() {
functions = js_support.discover_functions(Path(f.name), criteria)
assert len(functions) == 1
assert functions[0].name == "syncFunc"
assert functions[0].function_name == "syncFunc"
def test_discover_with_filter_exclude_methods(self, js_support):
"""Test filtering out class methods."""
@ -207,7 +207,7 @@ class MyClass {
functions = js_support.discover_functions(Path(f.name), criteria)
assert len(functions) == 1
assert functions[0].name == "standalone"
assert functions[0].function_name == "standalone"
def test_discover_line_numbers(self, js_support):
"""Test that line numbers are correctly captured."""
@ -226,13 +226,13 @@ function func2() {
functions = js_support.discover_functions(Path(f.name))
func1 = next(f for f in functions if f.name == "func1")
func2 = next(f for f in functions if f.name == "func2")
func1 = next(f for f in functions if f.function_name == "func1")
func2 = next(f for f in functions if f.function_name == "func2")
assert func1.start_line == 1
assert func1.end_line == 3
assert func2.start_line == 5
assert func2.end_line == 9
assert func1.starting_line == 1
assert func1.ending_line == 3
assert func2.starting_line == 5
assert func2.ending_line == 9
def test_discover_generator_function(self, js_support):
"""Test discovering generator functions."""
@ -249,7 +249,7 @@ function* numberGenerator() {
functions = js_support.discover_functions(Path(f.name))
assert len(functions) == 1
assert functions[0].name == "numberGenerator"
assert functions[0].function_name == "numberGenerator"
def test_discover_invalid_file_returns_empty(self, js_support):
"""Test that invalid JavaScript file returns empty list."""
@ -280,7 +280,7 @@ const add = function(a, b) {
functions = js_support.discover_functions(Path(f.name))
assert len(functions) == 1
assert functions[0].name == "add"
assert functions[0].function_name == "add"
def test_discover_immediately_invoked_function_excluded(self, js_support):
"""Test that IIFEs without names are excluded when require_name is True."""
@ -300,7 +300,7 @@ function named() {
# Only the named function should be discovered
assert len(functions) == 1
assert functions[0].name == "named"
assert functions[0].function_name == "named"
class TestReplaceFunction:
@ -316,7 +316,7 @@ function multiply(a, b) {
return a * b;
}
"""
func = FunctionInfo(name="add", file_path=Path("/test.js"), start_line=1, end_line=3)
func = FunctionInfo(function_name="add", file_path=Path("/test.js"), starting_line=1, ending_line=3)
new_code = """function add(a, b) {
// Optimized
return (a + b) | 0;
@ -343,7 +343,7 @@ function other() {
// Footer
"""
func = FunctionInfo(name="target", file_path=Path("/test.js"), start_line=4, end_line=6)
func = FunctionInfo(function_name="target", file_path=Path("/test.js"), starting_line=4, ending_line=6)
new_code = """function target() {
return 42;
}
@ -365,11 +365,11 @@ function other() {
}
"""
func = FunctionInfo(
name="add",
function_name="add",
file_path=Path("/test.js"),
start_line=2,
end_line=4,
parents=(ParentInfo(name="Calculator", type="ClassDef"),),
starting_line=2,
ending_line=4,
parents=[ParentInfo(name="Calculator", type="ClassDef")],
)
# New code has no indentation
new_code = """add(a, b) {
@ -391,7 +391,7 @@ function other() {
const multiply = (x, y) => x * y;
"""
func = FunctionInfo(name="add", file_path=Path("/test.js"), start_line=1, end_line=3)
func = FunctionInfo(function_name="add", file_path=Path("/test.js"), starting_line=1, ending_line=3)
new_code = """const add = (a, b) => {
return (a + b) | 0;
};
@ -483,7 +483,7 @@ class TestExtractCodeContext:
f.flush()
file_path = Path(f.name)
func = FunctionInfo(name="add", file_path=file_path, start_line=1, end_line=3)
func = FunctionInfo(function_name="add", file_path=file_path, starting_line=1, ending_line=3)
context = js_support.extract_code_context(func, file_path.parent, file_path.parent)
@ -508,7 +508,7 @@ function main(a) {
# First discover functions to get accurate line numbers
functions = js_support.discover_functions(file_path)
main_func = next(f for f in functions if f.name == "main")
main_func = next(f for f in functions if f.function_name == "main")
context = js_support.extract_code_context(main_func, file_path.parent, file_path.parent)
@ -538,7 +538,7 @@ class TestIntegration:
functions = js_support.discover_functions(file_path)
assert len(functions) == 1
func = functions[0]
assert func.name == "fibonacci"
assert func.function_name == "fibonacci"
# Replace
optimized_code = """function fibonacci(n) {
@ -626,7 +626,7 @@ export default Button;
functions = js_support.discover_functions(file_path)
# Should find both components
names = {f.name for f in functions}
names = {f.function_name for f in functions}
assert "Button" in names
assert "Card" in names
@ -688,7 +688,7 @@ class TestClassMethodExtraction:
# Discover the method
functions = js_support.discover_functions(file_path)
add_method = next(f for f in functions if f.name == "add")
add_method = next(f for f in functions if f.function_name == "add")
# Extract code context
context = js_support.extract_code_context(add_method, file_path.parent, file_path.parent)
@ -725,7 +725,7 @@ class Calculator {
file_path = Path(f.name)
functions = js_support.discover_functions(file_path)
add_method = next(f for f in functions if f.name == "add")
add_method = next(f for f in functions if f.function_name == "add")
context = js_support.extract_code_context(add_method, file_path.parent, file_path.parent)
@ -764,7 +764,7 @@ class Calculator {
file_path = Path(f.name)
functions = js_support.discover_functions(file_path)
fib_method = next(f for f in functions if f.name == "fibonacci")
fib_method = next(f for f in functions if f.function_name == "fibonacci")
context = js_support.extract_code_context(fib_method, file_path.parent, file_path.parent)
@ -802,7 +802,7 @@ class Calculator {
file_path = Path(f.name)
functions = js_support.discover_functions(file_path)
add_method = next((f for f in functions if f.name == "add"), None)
add_method = next((f for f in functions if f.function_name == "add"), None)
if add_method:
context = js_support.extract_code_context(add_method, file_path.parent, file_path.parent)
@ -831,7 +831,7 @@ class Calculator {
file_path = Path(f.name)
functions = js_support.discover_functions(file_path)
fetch_method = next(f for f in functions if f.name == "fetchData")
fetch_method = next(f for f in functions if f.function_name == "fetchData")
context = js_support.extract_code_context(fetch_method, file_path.parent, file_path.parent)
@ -863,7 +863,7 @@ class Calculator {
file_path = Path(f.name)
functions = js_support.discover_functions(file_path)
add_method = next((f for f in functions if f.name == "add"), None)
add_method = next((f for f in functions if f.function_name == "add"), None)
if add_method:
context = js_support.extract_code_context(add_method, file_path.parent, file_path.parent)
@ -891,7 +891,7 @@ class Calculator {
file_path = Path(f.name)
functions = js_support.discover_functions(file_path)
method = next(f for f in functions if f.name == "simpleMethod")
method = next(f for f in functions if f.function_name == "simpleMethod")
context = js_support.extract_code_context(method, file_path.parent, file_path.parent)
@ -922,11 +922,11 @@ class TestClassMethodReplacement:
}
"""
func = FunctionInfo(
name="add",
function_name="add",
file_path=Path("/test.js"),
start_line=2,
end_line=4,
parents=(ParentInfo(name="Calculator", type="ClassDef"),),
starting_line=2,
ending_line=4,
parents=[ParentInfo(name="Calculator", type="ClassDef")],
is_method=True,
)
new_code = """ add(a, b) {
@ -963,12 +963,12 @@ class TestClassMethodReplacement:
}
"""
func = FunctionInfo(
name="add",
function_name="add",
file_path=Path("/test.js"),
start_line=5, # Method starts here
end_line=7,
starting_line=5, # Method starts here
ending_line=7,
doc_start_line=2, # JSDoc starts here
parents=(ParentInfo(name="Calculator", type="ClassDef"),),
parents=[ParentInfo(name="Calculator", type="ClassDef")],
is_method=True,
)
new_code = """ /**
@ -1000,11 +1000,11 @@ class TestClassMethodReplacement:
"""
# Replace add first
add_func = FunctionInfo(
name="add",
function_name="add",
file_path=Path("/test.js"),
start_line=2,
end_line=4,
parents=(ParentInfo(name="Math", type="ClassDef"),),
starting_line=2,
ending_line=4,
parents=[ParentInfo(name="Math", type="ClassDef")],
is_method=True,
)
source = js_support.replace_function(
@ -1032,11 +1032,11 @@ class TestClassMethodReplacement:
}
"""
func = FunctionInfo(
name="innerMethod",
function_name="innerMethod",
file_path=Path("/test.js"),
start_line=2,
end_line=4,
parents=(ParentInfo(name="Indented", type="ClassDef"),),
starting_line=2,
ending_line=4,
parents=[ParentInfo(name="Indented", type="ClassDef")],
is_method=True,
)
# New code with no indentation
@ -1077,7 +1077,7 @@ class TestClassMethodEdgeCases:
functions = js_support.discover_functions(file_path)
# Should find constructor and increment
names = {f.name for f in functions}
names = {f.function_name for f in functions}
assert "constructor" in names or "increment" in names
def test_class_with_getters_setters(self, js_support):
@ -1107,7 +1107,7 @@ class TestClassMethodEdgeCases:
functions = js_support.discover_functions(file_path)
# Should find at least greet
names = {f.name for f in functions}
names = {f.function_name for f in functions}
assert "greet" in names
def test_class_extending_another(self, js_support):
@ -1135,7 +1135,7 @@ class Dog extends Animal {
functions = js_support.discover_functions(file_path)
# Find Dog's fetch method
fetch_method = next((f for f in functions if f.name == "fetch" and f.class_name == "Dog"), None)
fetch_method = next((f for f in functions if f.function_name == "fetch" and f.class_name == "Dog"), None)
if fetch_method:
context = js_support.extract_code_context(fetch_method, file_path.parent, file_path.parent)
@ -1169,7 +1169,7 @@ class Dog extends Animal {
functions = js_support.discover_functions(file_path)
# Should at least find publicMethod
names = {f.name for f in functions}
names = {f.function_name for f in functions}
assert "publicMethod" in names
def test_commonjs_class_export(self, js_support):
@ -1187,7 +1187,7 @@ module.exports = { Calculator };
file_path = Path(f.name)
functions = js_support.discover_functions(file_path)
add_method = next(f for f in functions if f.name == "add")
add_method = next(f for f in functions if f.function_name == "add")
context = js_support.extract_code_context(add_method, file_path.parent, file_path.parent)
@ -1209,7 +1209,7 @@ module.exports = { Calculator };
functions = js_support.discover_functions(file_path)
# Find the add method
add_method = next((f for f in functions if f.name == "add"), None)
add_method = next((f for f in functions if f.function_name == "add"), None)
if add_method:
context = js_support.extract_code_context(add_method, file_path.parent, file_path.parent)
@ -1260,7 +1260,7 @@ module.exports = { Counter };
file_path = Path(f.name)
functions = js_support.discover_functions(file_path)
increment_func = next(fn for fn in functions if fn.name == "increment")
increment_func = next(fn for fn in functions if fn.function_name == "increment")
# Step 1: Extract code context (includes constructor for AI context)
context = js_support.extract_code_context(increment_func, file_path.parent, file_path.parent)
@ -1359,7 +1359,7 @@ export { User };
file_path = Path(f.name)
functions = ts_support.discover_functions(file_path)
get_name_func = next(fn for fn in functions if fn.name == "getName")
get_name_func = next(fn for fn in functions if fn.function_name == "getName")
# Step 1: Extract code context (includes fields and constructor)
context = ts_support.extract_code_context(get_name_func, file_path.parent, file_path.parent)
@ -1461,7 +1461,7 @@ class Calculator {
file_path = Path(f.name)
functions = js_support.discover_functions(file_path)
add_func = next(fn for fn in functions if fn.name == "add")
add_func = next(fn for fn in functions if fn.function_name == "add")
# Extract context for add
context = js_support.extract_code_context(add_func, file_path.parent, file_path.parent)
@ -1547,7 +1547,7 @@ module.exports = { MathUtils };
file_path = Path(f.name)
functions = js_support.discover_functions(file_path)
add_func = next(fn for fn in functions if fn.name == "add")
add_func = next(fn for fn in functions if fn.function_name == "add")
# Extract context
context = js_support.extract_code_context(add_func, file_path.parent, file_path.parent)

View file

@ -1715,7 +1715,7 @@ module.exports = { Calculator };
functions = js_support.discover_functions(source_file)
# Check qualified names include class
add_func = next((f for f in functions if f.name == "add"), None)
add_func = next((f for f in functions if f.function_name == "add"), None)
assert add_func is not None
assert add_func.class_name == "Calculator"

View file

@ -0,0 +1,730 @@
"""Tests for JavaScript/Jest test runner functionality."""
import tempfile
from pathlib import Path
from unittest.mock import patch, MagicMock
import pytest
class TestJestRootsConfiguration:
"""Tests for Jest --roots flag handling."""
def test_behavioral_tests_adds_roots_for_test_directories(self):
"""Test that run_jest_behavioral_tests adds --roots for test directories."""
from codeflash.languages.javascript.test_runner import run_jest_behavioral_tests
from codeflash.models.models import TestFile, TestFiles
from codeflash.models.test_type import TestType
# Create mock test files in a test directory
with tempfile.TemporaryDirectory() as tmpdir:
tmpdir_path = Path(tmpdir).resolve()
test_dir = tmpdir_path / "test"
test_dir.mkdir()
# Create package.json to simulate a Node project
(tmpdir_path / "package.json").write_text('{"name": "test"}')
# Create mock test files
test_file1 = test_dir / "test_func__unit_test_0.test.ts"
test_file2 = test_dir / "test_func__unit_test_1.test.ts"
test_file1.write_text("// test 1")
test_file2.write_text("// test 2")
mock_test_files = TestFiles(
test_files=[
TestFile(
original_file_path=test_file1,
instrumented_behavior_file_path=test_file1,
benchmarking_file_path=test_file1,
test_type=TestType.GENERATED_REGRESSION,
),
TestFile(
original_file_path=test_file2,
instrumented_behavior_file_path=test_file2,
benchmarking_file_path=test_file2,
test_type=TestType.GENERATED_REGRESSION,
),
]
)
# Mock subprocess.run to capture the command
with patch("subprocess.run") as mock_run:
mock_result = MagicMock()
mock_result.stdout = ""
mock_result.stderr = ""
mock_result.returncode = 1
mock_run.return_value = mock_result
try:
run_jest_behavioral_tests(
test_paths=mock_test_files,
test_env={},
cwd=tmpdir_path,
project_root=tmpdir_path,
)
except Exception:
pass # Expected to fail since no real Jest
# Verify the command included --roots
if mock_run.called:
call_args = mock_run.call_args
cmd = call_args[0][0]
# Find --roots flags in the command
roots_flags = []
for i, arg in enumerate(cmd):
if arg == "--roots" and i + 1 < len(cmd):
roots_flags.append(cmd[i + 1])
# Should have added the test directory as a root
assert len(roots_flags) > 0, "Expected --roots flag in Jest command"
assert str(test_dir) in roots_flags or any(
str(test_dir) in root for root in roots_flags
), f"Expected test directory {test_dir} in --roots flags: {roots_flags}"
def test_benchmarking_tests_adds_roots_for_test_directories(self):
"""Test that run_jest_benchmarking_tests adds --roots for test directories."""
from codeflash.languages.javascript.test_runner import run_jest_benchmarking_tests
from codeflash.models.models import TestFile, TestFiles
from codeflash.models.test_type import TestType
with tempfile.TemporaryDirectory() as tmpdir:
tmpdir_path = Path(tmpdir).resolve()
test_dir = tmpdir_path / "test"
test_dir.mkdir()
(tmpdir_path / "package.json").write_text('{"name": "test"}')
test_file = test_dir / "test_func__perf_test_0.test.ts"
test_file.write_text("// perf test")
mock_test_files = TestFiles(
test_files=[
TestFile(
original_file_path=test_file,
instrumented_behavior_file_path=test_file,
benchmarking_file_path=test_file,
test_type=TestType.GENERATED_REGRESSION,
),
]
)
with patch("subprocess.run") as mock_run:
mock_result = MagicMock()
mock_result.stdout = ""
mock_result.stderr = ""
mock_result.returncode = 1
mock_run.return_value = mock_result
try:
run_jest_benchmarking_tests(
test_paths=mock_test_files,
test_env={},
cwd=tmpdir_path,
project_root=tmpdir_path,
)
except Exception:
pass
if mock_run.called:
call_args = mock_run.call_args
cmd = call_args[0][0]
roots_flags = []
for i, arg in enumerate(cmd):
if arg == "--roots" and i + 1 < len(cmd):
roots_flags.append(cmd[i + 1])
assert len(roots_flags) > 0, "Expected --roots flag in Jest command"
def test_line_profile_tests_adds_roots_for_test_directories(self):
"""Test that run_jest_line_profile_tests adds --roots for test directories."""
from codeflash.languages.javascript.test_runner import run_jest_line_profile_tests
from codeflash.models.models import TestFile, TestFiles
from codeflash.models.test_type import TestType
with tempfile.TemporaryDirectory() as tmpdir:
tmpdir_path = Path(tmpdir)
test_dir = tmpdir_path / "test"
test_dir.mkdir()
(tmpdir_path / "package.json").write_text('{"name": "test"}')
test_file = test_dir / "test_func__line_profile.test.ts"
test_file.write_text("// line profile test")
mock_test_files = TestFiles(
test_files=[
TestFile(
original_file_path=test_file,
instrumented_behavior_file_path=test_file,
benchmarking_file_path=test_file,
test_type=TestType.GENERATED_REGRESSION,
),
]
)
with patch("subprocess.run") as mock_run:
mock_result = MagicMock()
mock_result.stdout = ""
mock_result.stderr = ""
mock_result.returncode = 1
mock_run.return_value = mock_result
try:
run_jest_line_profile_tests(
test_paths=mock_test_files,
test_env={},
cwd=tmpdir_path,
project_root=tmpdir_path,
)
except Exception:
pass
if mock_run.called:
call_args = mock_run.call_args
cmd = call_args[0][0]
roots_flags = []
for i, arg in enumerate(cmd):
if arg == "--roots" and i + 1 < len(cmd):
roots_flags.append(cmd[i + 1])
assert len(roots_flags) > 0, "Expected --roots flag in Jest command"
def test_multiple_test_directories_all_added_to_roots(self):
"""Test that multiple test directories are all added as --roots."""
from codeflash.languages.javascript.test_runner import run_jest_behavioral_tests
from codeflash.models.models import TestFile, TestFiles
from codeflash.models.test_type import TestType
with tempfile.TemporaryDirectory() as tmpdir:
tmpdir_path = Path(tmpdir)
test_dir1 = tmpdir_path / "test"
test_dir2 = tmpdir_path / "spec"
test_dir1.mkdir()
test_dir2.mkdir()
(tmpdir_path / "package.json").write_text('{"name": "test"}')
test_file1 = test_dir1 / "test_func__unit_test_0.test.ts"
test_file2 = test_dir2 / "test_func__unit_test_1.test.ts"
test_file1.write_text("// test 1")
test_file2.write_text("// test 2")
mock_test_files = TestFiles(
test_files=[
TestFile(
original_file_path=test_file1,
instrumented_behavior_file_path=test_file1,
benchmarking_file_path=test_file1,
test_type=TestType.GENERATED_REGRESSION,
),
TestFile(
original_file_path=test_file2,
instrumented_behavior_file_path=test_file2,
benchmarking_file_path=test_file2,
test_type=TestType.GENERATED_REGRESSION,
),
]
)
with patch("subprocess.run") as mock_run:
mock_result = MagicMock()
mock_result.stdout = ""
mock_result.stderr = ""
mock_result.returncode = 1
mock_run.return_value = mock_result
try:
run_jest_behavioral_tests(
test_paths=mock_test_files,
test_env={},
cwd=tmpdir_path,
project_root=tmpdir_path,
)
except Exception:
pass
if mock_run.called:
call_args = mock_run.call_args
cmd = call_args[0][0]
roots_flags = []
for i, arg in enumerate(cmd):
if arg == "--roots" and i + 1 < len(cmd):
roots_flags.append(cmd[i + 1])
# Should have two --roots flags (one for each directory)
assert len(roots_flags) == 2, f"Expected 2 --roots flags, got {len(roots_flags)}"
class TestVitestTimeoutConfiguration:
"""Tests for Vitest subprocess timeout handling."""
def test_vitest_behavioral_subprocess_timeout_larger_than_test_timeout(self):
"""Test that subprocess timeout is larger than per-test timeout for Vitest behavioral tests."""
from codeflash.languages.javascript.vitest_runner import run_vitest_behavioral_tests
from codeflash.models.models import TestFile, TestFiles
from codeflash.models.test_type import TestType
with tempfile.TemporaryDirectory() as tmpdir:
tmpdir_path = Path(tmpdir)
test_dir = tmpdir_path / "test"
test_dir.mkdir()
(tmpdir_path / "package.json").write_text('{"name": "test", "devDependencies": {"vitest": "^1.0.0"}}')
test_file = test_dir / "test_func.test.ts"
test_file.write_text("// test")
mock_test_files = TestFiles(
test_files=[
TestFile(
original_file_path=test_file,
instrumented_behavior_file_path=test_file,
benchmarking_file_path=test_file,
test_type=TestType.GENERATED_REGRESSION,
),
]
)
with patch("subprocess.run") as mock_run:
mock_result = MagicMock()
mock_result.stdout = ""
mock_result.stderr = ""
mock_result.returncode = 0
mock_run.return_value = mock_result
# Run with a 15 second per-test timeout
run_vitest_behavioral_tests(
test_paths=mock_test_files,
test_env={},
cwd=tmpdir_path,
timeout=15, # 15 second per-test timeout
project_root=tmpdir_path,
)
# Verify subprocess was called with a larger timeout
assert mock_run.called
call_kwargs = mock_run.call_args[1]
subprocess_timeout = call_kwargs.get("timeout")
# Subprocess timeout should be at least 120 seconds (minimum)
# or 10x the per-test timeout (150 seconds)
assert subprocess_timeout >= 120, f"Expected subprocess timeout >= 120s, got {subprocess_timeout}s"
assert subprocess_timeout >= 15 * 10, f"Expected subprocess timeout >= 150s (10x per-test), got {subprocess_timeout}s"
def test_vitest_line_profile_subprocess_timeout_larger_than_test_timeout(self):
"""Test that subprocess timeout is larger than per-test timeout for Vitest line profile tests."""
from codeflash.languages.javascript.vitest_runner import run_vitest_line_profile_tests
from codeflash.models.models import TestFile, TestFiles
from codeflash.models.test_type import TestType
with tempfile.TemporaryDirectory() as tmpdir:
tmpdir_path = Path(tmpdir)
test_dir = tmpdir_path / "test"
test_dir.mkdir()
(tmpdir_path / "package.json").write_text('{"name": "test", "devDependencies": {"vitest": "^1.0.0"}}')
test_file = test_dir / "test_func.test.ts"
test_file.write_text("// test")
mock_test_files = TestFiles(
test_files=[
TestFile(
original_file_path=test_file,
instrumented_behavior_file_path=test_file,
benchmarking_file_path=test_file,
test_type=TestType.GENERATED_REGRESSION,
),
]
)
with patch("subprocess.run") as mock_run:
mock_result = MagicMock()
mock_result.stdout = ""
mock_result.stderr = ""
mock_result.returncode = 0
mock_run.return_value = mock_result
run_vitest_line_profile_tests(
test_paths=mock_test_files,
test_env={},
cwd=tmpdir_path,
timeout=15,
project_root=tmpdir_path,
)
assert mock_run.called
call_kwargs = mock_run.call_args[1]
subprocess_timeout = call_kwargs.get("timeout")
assert subprocess_timeout >= 120, f"Expected subprocess timeout >= 120s, got {subprocess_timeout}s"
def test_vitest_default_subprocess_timeout_is_reasonable(self):
"""Test that default subprocess timeout is at least 120 seconds when no timeout specified."""
from codeflash.languages.javascript.vitest_runner import run_vitest_behavioral_tests
from codeflash.models.models import TestFile, TestFiles
from codeflash.models.test_type import TestType
with tempfile.TemporaryDirectory() as tmpdir:
tmpdir_path = Path(tmpdir)
test_dir = tmpdir_path / "test"
test_dir.mkdir()
(tmpdir_path / "package.json").write_text('{"name": "test", "devDependencies": {"vitest": "^1.0.0"}}')
test_file = test_dir / "test_func.test.ts"
test_file.write_text("// test")
mock_test_files = TestFiles(
test_files=[
TestFile(
original_file_path=test_file,
instrumented_behavior_file_path=test_file,
benchmarking_file_path=test_file,
test_type=TestType.GENERATED_REGRESSION,
),
]
)
with patch("subprocess.run") as mock_run:
mock_result = MagicMock()
mock_result.stdout = ""
mock_result.stderr = ""
mock_result.returncode = 0
mock_run.return_value = mock_result
# Run without specifying a timeout
run_vitest_behavioral_tests(
test_paths=mock_test_files,
test_env={},
cwd=tmpdir_path,
project_root=tmpdir_path,
)
assert mock_run.called
call_kwargs = mock_run.call_args[1]
subprocess_timeout = call_kwargs.get("timeout")
# Default should be at least 120 seconds (or 600 from the default)
assert subprocess_timeout >= 120, f"Expected subprocess timeout >= 120s, got {subprocess_timeout}s"
class TestVitestInternalLoopingConfiguration:
"""Tests for Vitest internal looping (no external loop-runner)."""
def test_vitest_benchmarking_does_not_set_current_batch_env(self):
"""Test that Vitest runner does NOT set CODEFLASH_PERF_CURRENT_BATCH.
This is critical: when CODEFLASH_PERF_CURRENT_BATCH is not set,
capturePerf() in the npm package will do all loops internally
(PERF_LOOP_COUNT iterations) instead of just PERF_BATCH_SIZE.
"""
from codeflash.languages.javascript.vitest_runner import run_vitest_benchmarking_tests
from codeflash.models.models import TestFile, TestFiles
from codeflash.models.test_type import TestType
with tempfile.TemporaryDirectory() as tmpdir:
tmpdir_path = Path(tmpdir)
test_dir = tmpdir_path / "test"
test_dir.mkdir()
(tmpdir_path / "package.json").write_text('{"name": "test", "devDependencies": {"vitest": "^1.0.0"}}')
test_file = test_dir / "test_func.test.ts"
test_file.write_text("// perf test")
mock_test_files = TestFiles(
test_files=[
TestFile(
original_file_path=test_file,
instrumented_behavior_file_path=test_file,
benchmarking_file_path=test_file,
test_type=TestType.GENERATED_REGRESSION,
),
]
)
with patch("subprocess.run") as mock_run:
mock_result = MagicMock()
mock_result.stdout = ""
mock_result.stderr = ""
mock_result.returncode = 0
mock_run.return_value = mock_result
run_vitest_benchmarking_tests(
test_paths=mock_test_files,
test_env={},
cwd=tmpdir_path,
project_root=tmpdir_path,
max_loops=100,
min_loops=5,
)
assert mock_run.called
call_kwargs = mock_run.call_args[1]
env = call_kwargs.get("env", {})
# CODEFLASH_PERF_CURRENT_BATCH should NOT be set
# This allows capturePerf() to do all loops internally
assert "CODEFLASH_PERF_CURRENT_BATCH" not in env, (
"CODEFLASH_PERF_CURRENT_BATCH should not be set for Vitest - "
"internal looping relies on this being undefined"
)
# But CODEFLASH_PERF_LOOP_COUNT should be set
assert "CODEFLASH_PERF_LOOP_COUNT" in env, "CODEFLASH_PERF_LOOP_COUNT should be set"
assert env["CODEFLASH_PERF_LOOP_COUNT"] == "100"
def test_vitest_benchmarking_sets_loop_configuration_env_vars(self):
"""Test that Vitest benchmarking sets correct loop configuration environment variables."""
from codeflash.languages.javascript.vitest_runner import run_vitest_benchmarking_tests
from codeflash.models.models import TestFile, TestFiles
from codeflash.models.test_type import TestType
with tempfile.TemporaryDirectory() as tmpdir:
tmpdir_path = Path(tmpdir)
test_dir = tmpdir_path / "test"
test_dir.mkdir()
(tmpdir_path / "package.json").write_text('{"name": "test", "devDependencies": {"vitest": "^1.0.0"}}')
test_file = test_dir / "test_func.test.ts"
test_file.write_text("// perf test")
mock_test_files = TestFiles(
test_files=[
TestFile(
original_file_path=test_file,
instrumented_behavior_file_path=test_file,
benchmarking_file_path=test_file,
test_type=TestType.GENERATED_REGRESSION,
),
]
)
with patch("subprocess.run") as mock_run:
mock_result = MagicMock()
mock_result.stdout = ""
mock_result.stderr = ""
mock_result.returncode = 0
mock_run.return_value = mock_result
run_vitest_benchmarking_tests(
test_paths=mock_test_files,
test_env={},
cwd=tmpdir_path,
project_root=tmpdir_path,
max_loops=50,
min_loops=10,
target_duration_ms=5000,
stability_check=True,
)
assert mock_run.called
call_kwargs = mock_run.call_args[1]
env = call_kwargs.get("env", {})
# Verify all loop configuration env vars are set correctly
assert env.get("CODEFLASH_PERF_LOOP_COUNT") == "50"
assert env.get("CODEFLASH_PERF_MIN_LOOPS") == "10"
assert env.get("CODEFLASH_PERF_TARGET_DURATION_MS") == "5000"
assert env.get("CODEFLASH_PERF_STABILITY_CHECK") == "true"
assert env.get("CODEFLASH_MODE") == "performance"
class TestBundlerModuleResolutionFix:
"""Tests for bundler moduleResolution compatibility fix."""
def test_detect_bundler_module_resolution_true(self):
"""Test detection of bundler moduleResolution in tsconfig."""
import json
from codeflash.languages.javascript.test_runner import _detect_bundler_module_resolution
with tempfile.TemporaryDirectory() as tmpdir:
tmpdir_path = Path(tmpdir)
# Create tsconfig with bundler moduleResolution
tsconfig = {
"compilerOptions": {
"moduleResolution": "bundler",
"module": "preserve",
"target": "ES2022",
}
}
(tmpdir_path / "tsconfig.json").write_text(json.dumps(tsconfig))
assert _detect_bundler_module_resolution(tmpdir_path) is True
def test_detect_bundler_module_resolution_false(self):
"""Test detection returns false for Node moduleResolution."""
import json
from codeflash.languages.javascript.test_runner import _detect_bundler_module_resolution
with tempfile.TemporaryDirectory() as tmpdir:
tmpdir_path = Path(tmpdir)
# Create tsconfig with Node moduleResolution
tsconfig = {
"compilerOptions": {
"moduleResolution": "Node",
"module": "ESNext",
}
}
(tmpdir_path / "tsconfig.json").write_text(json.dumps(tsconfig))
assert _detect_bundler_module_resolution(tmpdir_path) is False
def test_detect_bundler_module_resolution_no_tsconfig(self):
"""Test detection returns false when no tsconfig exists."""
from codeflash.languages.javascript.test_runner import _detect_bundler_module_resolution
with tempfile.TemporaryDirectory() as tmpdir:
tmpdir_path = Path(tmpdir)
assert _detect_bundler_module_resolution(tmpdir_path) is False
def test_detect_bundler_module_resolution_extended_config(self):
"""Test detection works with extended tsconfig files."""
import json
from codeflash.languages.javascript.test_runner import _detect_bundler_module_resolution
with tempfile.TemporaryDirectory() as tmpdir:
tmpdir_path = Path(tmpdir)
# Create a base config with bundler in a subdirectory (simulating node_modules)
node_modules = tmpdir_path / "node_modules" / "@myorg" / "tsconfig"
node_modules.mkdir(parents=True)
base_tsconfig = {
"compilerOptions": {
"moduleResolution": "bundler",
"module": "preserve",
}
}
(node_modules / "tsconfig.json").write_text(json.dumps(base_tsconfig))
# Create a project tsconfig that extends the base
project_tsconfig = {
"extends": "@myorg/tsconfig/tsconfig.json",
"compilerOptions": {
"target": "ES2022",
}
}
(tmpdir_path / "tsconfig.json").write_text(json.dumps(project_tsconfig))
# Should detect bundler from extended config
assert _detect_bundler_module_resolution(tmpdir_path) is True
def test_create_codeflash_tsconfig(self):
"""Test creation of codeflash-compatible tsconfig."""
import json
from codeflash.languages.javascript.test_runner import _create_codeflash_tsconfig
with tempfile.TemporaryDirectory() as tmpdir:
tmpdir_path = Path(tmpdir)
# Create original tsconfig
original_tsconfig = {
"compilerOptions": {
"moduleResolution": "bundler",
"module": "preserve",
"target": "ES2022",
},
"include": ["src/**/*.ts"],
"exclude": ["node_modules"],
}
(tmpdir_path / "tsconfig.json").write_text(json.dumps(original_tsconfig))
# Create codeflash tsconfig
result_path = _create_codeflash_tsconfig(tmpdir_path)
assert result_path.exists()
assert result_path.name == "tsconfig.codeflash.json"
# Verify contents
codeflash_tsconfig = json.loads(result_path.read_text())
assert codeflash_tsconfig["extends"] == "./tsconfig.json"
assert codeflash_tsconfig["compilerOptions"]["moduleResolution"] == "Node"
assert "include" in codeflash_tsconfig
def test_create_codeflash_jest_config(self):
"""Test creation of codeflash Jest config."""
from codeflash.languages.javascript.test_runner import _create_codeflash_jest_config
with tempfile.TemporaryDirectory() as tmpdir:
tmpdir_path = Path(tmpdir)
# Create codeflash Jest config without original
result_path = _create_codeflash_jest_config(tmpdir_path, None)
assert result_path is not None
assert result_path.exists()
assert result_path.name == "jest.codeflash.config.js"
# Verify it contains the tsconfig reference
content = result_path.read_text()
assert "tsconfig.codeflash.json" in content
assert "ts-jest" in content
def test_get_jest_config_for_project_with_bundler(self):
"""Test that bundler projects get codeflash Jest config."""
import json
from codeflash.languages.javascript.test_runner import _get_jest_config_for_project
with tempfile.TemporaryDirectory() as tmpdir:
tmpdir_path = Path(tmpdir)
# Create tsconfig with bundler
tsconfig = {
"compilerOptions": {
"moduleResolution": "bundler",
"module": "preserve",
}
}
(tmpdir_path / "tsconfig.json").write_text(json.dumps(tsconfig))
(tmpdir_path / "package.json").write_text('{"name": "test"}')
result = _get_jest_config_for_project(tmpdir_path)
assert result is not None
assert result.name == "jest.codeflash.config.js"
# Also verify tsconfig.codeflash.json was created
assert (tmpdir_path / "tsconfig.codeflash.json").exists()
def test_get_jest_config_for_project_without_bundler(self):
"""Test that non-bundler projects use original Jest config."""
import json
from codeflash.languages.javascript.test_runner import _get_jest_config_for_project
with tempfile.TemporaryDirectory() as tmpdir:
tmpdir_path = Path(tmpdir)
# Create tsconfig with Node moduleResolution
tsconfig = {
"compilerOptions": {
"moduleResolution": "Node",
"module": "ESNext",
}
}
(tmpdir_path / "tsconfig.json").write_text(json.dumps(tsconfig))
(tmpdir_path / "package.json").write_text('{"name": "test"}')
# Create original Jest config
(tmpdir_path / "jest.config.js").write_text("module.exports = {};")
result = _get_jest_config_for_project(tmpdir_path)
assert result is not None
assert result.name == "jest.config.js"
# Verify codeflash configs were NOT created
assert not (tmpdir_path / "jest.codeflash.config.js").exists()
assert not (tmpdir_path / "tsconfig.codeflash.json").exists()

View file

@ -39,7 +39,7 @@ class TestCodeExtractorCJS:
calculator_file = cjs_project / "calculator.js"
functions = js_support.discover_functions(calculator_file)
method_names = {f.name for f in functions}
method_names = {f.function_name for f in functions}
expected_methods = {"calculateCompoundInterest", "permutation", "quickAdd"}
assert method_names == expected_methods, f"Expected methods {expected_methods}, got {method_names}"
@ -51,15 +51,15 @@ class TestCodeExtractorCJS:
for func in functions:
# All methods should belong to Calculator class
assert func.is_method is True, f"{func.name} should be a method"
assert func.class_name == "Calculator", f"{func.name} should belong to Calculator, got {func.class_name}"
assert func.is_method is True, f"{func.function_name} should be a method"
assert func.class_name == "Calculator", f"{func.function_name} should belong to Calculator, got {func.class_name}"
def test_extract_permutation_code(self, js_support, cjs_project):
"""Test permutation method code extraction."""
calculator_file = cjs_project / "calculator.js"
functions = js_support.discover_functions(calculator_file)
permutation_func = next(f for f in functions if f.name == "permutation")
permutation_func = next(f for f in functions if f.function_name == "permutation")
context = js_support.extract_code_context(
function=permutation_func, project_root=cjs_project, module_root=cjs_project
@ -95,7 +95,7 @@ class Calculator {
calculator_file = cjs_project / "calculator.js"
functions = js_support.discover_functions(calculator_file)
permutation_func = next(f for f in functions if f.name == "permutation")
permutation_func = next(f for f in functions if f.function_name == "permutation")
context = js_support.extract_code_context(
function=permutation_func, project_root=cjs_project, module_root=cjs_project
@ -136,7 +136,7 @@ function factorial(n) {
calculator_file = cjs_project / "calculator.js"
functions = js_support.discover_functions(calculator_file)
compound_func = next(f for f in functions if f.name == "calculateCompoundInterest")
compound_func = next(f for f in functions if f.function_name == "calculateCompoundInterest")
context = js_support.extract_code_context(
function=compound_func, project_root=cjs_project, module_root=cjs_project
@ -182,7 +182,7 @@ class Calculator {
calculator_file = cjs_project / "calculator.js"
functions = js_support.discover_functions(calculator_file)
compound_func = next(f for f in functions if f.name == "calculateCompoundInterest")
compound_func = next(f for f in functions if f.function_name == "calculateCompoundInterest")
context = js_support.extract_code_context(
function=compound_func, project_root=cjs_project, module_root=cjs_project
@ -266,7 +266,7 @@ function validateInput(value, name) {
calculator_file = cjs_project / "calculator.js"
functions = js_support.discover_functions(calculator_file)
compound_func = next(f for f in functions if f.name == "calculateCompoundInterest")
compound_func = next(f for f in functions if f.function_name == "calculateCompoundInterest")
context = js_support.extract_code_context(
function=compound_func, project_root=cjs_project, module_root=cjs_project
@ -287,7 +287,7 @@ function validateInput(value, name) {
calculator_file = cjs_project / "calculator.js"
functions = js_support.discover_functions(calculator_file)
quick_add_func = next(f for f in functions if f.name == "quickAdd")
quick_add_func = next(f for f in functions if f.function_name == "quickAdd")
context = js_support.extract_code_context(
function=quick_add_func, project_root=cjs_project, module_root=cjs_project
@ -352,7 +352,7 @@ class TestCodeExtractorESM:
calculator_file = esm_project / "calculator.js"
functions = js_support.discover_functions(calculator_file)
method_names = {f.name for f in functions}
method_names = {f.function_name for f in functions}
# Should find same methods as CJS version
expected_methods = {"calculateCompoundInterest", "permutation", "quickAdd"}
@ -363,7 +363,7 @@ class TestCodeExtractorESM:
calculator_file = esm_project / "calculator.js"
functions = js_support.discover_functions(calculator_file)
permutation_func = next(f for f in functions if f.name == "permutation")
permutation_func = next(f for f in functions if f.function_name == "permutation")
context = js_support.extract_code_context(
function=permutation_func, project_root=esm_project, module_root=esm_project
@ -413,7 +413,7 @@ export function factorial(n) {
calculator_file = esm_project / "calculator.js"
functions = js_support.discover_functions(calculator_file)
compound_func = next(f for f in functions if f.name == "calculateCompoundInterest")
compound_func = next(f for f in functions if f.function_name == "calculateCompoundInterest")
context = js_support.extract_code_context(
function=compound_func, project_root=esm_project, module_root=esm_project
@ -539,7 +539,7 @@ class TestCodeExtractorTypeScript:
calculator_file = ts_project / "calculator.ts"
functions = ts_support.discover_functions(calculator_file)
method_names = {f.name for f in functions}
method_names = {f.function_name for f in functions}
# TypeScript has additional getHistory method
expected_methods = {"calculateCompoundInterest", "permutation", "getHistory", "quickAdd"}
@ -550,7 +550,7 @@ class TestCodeExtractorTypeScript:
calculator_file = ts_project / "calculator.ts"
functions = ts_support.discover_functions(calculator_file)
permutation_func = next(f for f in functions if f.name == "permutation")
permutation_func = next(f for f in functions if f.function_name == "permutation")
context = ts_support.extract_code_context(
function=permutation_func, project_root=ts_project, module_root=ts_project
@ -603,7 +603,7 @@ export function factorial(n: number): number {
calculator_file = ts_project / "calculator.ts"
functions = ts_support.discover_functions(calculator_file)
compound_func = next(f for f in functions if f.name == "calculateCompoundInterest")
compound_func = next(f for f in functions if f.function_name == "calculateCompoundInterest")
context = ts_support.extract_code_context(
function=compound_func, project_root=ts_project, module_root=ts_project
@ -712,7 +712,7 @@ module.exports = { standalone };
test_file.write_text(source)
functions = js_support.discover_functions(test_file)
func = next(f for f in functions if f.name == "standalone")
func = next(f for f in functions if f.function_name == "standalone")
context = js_support.extract_code_context(function=func, project_root=tmp_path, module_root=tmp_path)
@ -745,7 +745,7 @@ module.exports = { processArray };
test_file.write_text(source)
functions = js_support.discover_functions(test_file)
func = next(f for f in functions if f.name == "processArray")
func = next(f for f in functions if f.function_name == "processArray")
context = js_support.extract_code_context(function=func, project_root=tmp_path, module_root=tmp_path)
@ -780,7 +780,7 @@ module.exports = { fibonacci };
test_file.write_text(source)
functions = js_support.discover_functions(test_file)
func = next(f for f in functions if f.name == "fibonacci")
func = next(f for f in functions if f.function_name == "fibonacci")
context = js_support.extract_code_context(function=func, project_root=tmp_path, module_root=tmp_path)
@ -813,7 +813,7 @@ module.exports = { processValue };
test_file.write_text(source)
functions = js_support.discover_functions(test_file)
func = next(f for f in functions if f.name == "processValue")
func = next(f for f in functions if f.function_name == "processValue")
context = js_support.extract_code_context(function=func, project_root=tmp_path, module_root=tmp_path)
@ -871,7 +871,7 @@ module.exports = { Counter };
test_file.write_text(source)
functions = js_support.discover_functions(test_file)
increment_func = next(f for f in functions if f.name == "increment")
increment_func = next(f for f in functions if f.function_name == "increment")
context = js_support.extract_code_context(function=increment_func, project_root=tmp_path, module_root=tmp_path)
@ -910,7 +910,7 @@ module.exports = { MathUtils };
test_file.write_text(source)
functions = js_support.discover_functions(test_file)
add_func = next(f for f in functions if f.name == "add")
add_func = next(f for f in functions if f.function_name == "add")
context = js_support.extract_code_context(function=add_func, project_root=tmp_path, module_root=tmp_path)
@ -948,7 +948,7 @@ export { User };
test_file.write_text(source)
functions = ts_support.discover_functions(test_file)
get_name_func = next(f for f in functions if f.name == "getName")
get_name_func = next(f for f in functions if f.function_name == "getName")
context = ts_support.extract_code_context(function=get_name_func, project_root=tmp_path, module_root=tmp_path)
@ -989,7 +989,7 @@ export { Config };
test_file.write_text(source)
functions = ts_support.discover_functions(test_file)
get_url_func = next(f for f in functions if f.name == "getUrl")
get_url_func = next(f for f in functions if f.function_name == "getUrl")
context = ts_support.extract_code_context(function=get_url_func, project_root=tmp_path, module_root=tmp_path)
@ -1030,7 +1030,7 @@ module.exports = { Logger };
test_file.write_text(source)
functions = js_support.discover_functions(test_file)
get_prefix_func = next(f for f in functions if f.name == "getPrefix")
get_prefix_func = next(f for f in functions if f.function_name == "getPrefix")
context = js_support.extract_code_context(function=get_prefix_func, project_root=tmp_path, module_root=tmp_path)
@ -1072,7 +1072,7 @@ module.exports = { Factory };
test_file.write_text(source)
functions = js_support.discover_functions(test_file)
create_func = next(f for f in functions if f.name == "create")
create_func = next(f for f in functions if f.function_name == "create")
context = js_support.extract_code_context(function=create_func, project_root=tmp_path, module_root=tmp_path)
@ -1114,19 +1114,20 @@ class TestCodeExtractorIntegration:
calculator_file = cjs_project / "calculator.js"
functions = js_support.discover_functions(calculator_file)
target = next(f for f in functions if f.name == "permutation")
target = next(f for f in functions if f.function_name == "permutation")
parents = [FunctionParent(name=p.name, type=p.type) for p in target.parents]
func = FunctionToOptimize(
function_name=target.name,
function_name=target.function_name,
file_path=target.file_path,
parents=parents,
starting_line=target.start_line,
ending_line=target.end_line,
starting_col=target.start_col,
ending_col=target.end_col,
starting_line=target.starting_line,
ending_line=target.ending_line,
starting_col=target.starting_col,
ending_col=target.ending_col,
is_async=target.is_async,
is_method=target.is_method,
language=target.language,
)
@ -1223,7 +1224,7 @@ export { distance };
test_file.write_text(source)
functions = ts_support.discover_functions(test_file)
distance_func = next(f for f in functions if f.name == "distance")
distance_func = next(f for f in functions if f.function_name == "distance")
context = ts_support.extract_code_context(function=distance_func, project_root=tmp_path, module_root=tmp_path)
@ -1267,7 +1268,7 @@ export { processStatus };
test_file.write_text(source)
functions = ts_support.discover_functions(test_file)
process_func = next(f for f in functions if f.name == "processStatus")
process_func = next(f for f in functions if f.function_name == "processStatus")
context = ts_support.extract_code_context(function=process_func, project_root=tmp_path, module_root=tmp_path)
@ -1304,7 +1305,7 @@ export { compute };
test_file.write_text(source)
functions = ts_support.discover_functions(test_file)
compute_func = next(f for f in functions if f.name == "compute")
compute_func = next(f for f in functions if f.function_name == "compute")
context = ts_support.extract_code_context(function=compute_func, project_root=tmp_path, module_root=tmp_path)
@ -1348,7 +1349,7 @@ export { Service };
test_file.write_text(source)
functions = ts_support.discover_functions(test_file)
get_timeout_func = next(f for f in functions if f.name == "getTimeout")
get_timeout_func = next(f for f in functions if f.function_name == "getTimeout")
context = ts_support.extract_code_context(
function=get_timeout_func, project_root=tmp_path, module_root=tmp_path
@ -1381,7 +1382,7 @@ export { add };
test_file.write_text(source)
functions = ts_support.discover_functions(test_file)
add_func = next(f for f in functions if f.name == "add")
add_func = next(f for f in functions if f.function_name == "add")
context = ts_support.extract_code_context(function=add_func, project_root=tmp_path, module_root=tmp_path)
@ -1414,7 +1415,7 @@ export { createRect };
test_file.write_text(source)
functions = ts_support.discover_functions(test_file)
create_rect_func = next(f for f in functions if f.name == "createRect")
create_rect_func = next(f for f in functions if f.function_name == "createRect")
context = ts_support.extract_code_context(
function=create_rect_func, project_root=tmp_path, module_root=tmp_path
@ -1462,7 +1463,7 @@ export { calculateDistance };
""")
functions = ts_support.discover_functions(geometry_file)
calc_distance_func = next(f for f in functions if f.name == "calculateDistance")
calc_distance_func = next(f for f in functions if f.function_name == "calculateDistance")
context = ts_support.extract_code_context(
function=calc_distance_func, project_root=ts_types_project, module_root=ts_types_project
@ -1515,7 +1516,7 @@ export { greetUser };
test_file.write_text(source)
functions = ts_support.discover_functions(test_file)
greet_func = next(f for f in functions if f.name == "greetUser")
greet_func = next(f for f in functions if f.function_name == "greetUser")
context = ts_support.extract_code_context(function=greet_func, project_root=tmp_path, module_root=tmp_path)

View file

@ -167,12 +167,12 @@ class TestCommonJSToESMConversion:
f"CJS to ESM conversion failed.\nInput: {code}\nExpected: {expected}\nGot: {result}"
)
def test_convert_relative_require_adds_extension(self):
"""Test that relative imports get .js extension added - exact output."""
def test_convert_relative_require_preserves_path(self):
"""Test that relative imports preserve the original path without adding extension."""
code = "const { helper } = require('./utils');"
result = convert_commonjs_to_esm(code)
expected = "import { helper } from './utils.js';"
expected = "import { helper } from './utils';"
assert result.strip() == expected, (
f"CJS to ESM conversion failed.\nInput: {code}\nExpected: {expected}\nGot: {result}"
)
@ -182,7 +182,7 @@ class TestCommonJSToESMConversion:
code = "const myHelper = require('./utils').helperFunction;"
result = convert_commonjs_to_esm(code)
expected = "import { helperFunction as myHelper } from './utils.js';"
expected = "import { helperFunction as myHelper } from './utils';"
assert result.strip() == expected, (
f"CJS to ESM conversion failed.\nInput: {code}\nExpected: {expected}\nGot: {result}"
)
@ -192,7 +192,7 @@ class TestCommonJSToESMConversion:
code = "const MyClass = require('./class').default;"
result = convert_commonjs_to_esm(code)
expected = "import MyClass from './class.js';"
expected = "import MyClass from './class';"
assert result.strip() == expected, (
f"CJS to ESM conversion failed.\nInput: {code}\nExpected: {expected}\nGot: {result}"
)
@ -207,7 +207,7 @@ const path = require('path');"""
result = convert_commonjs_to_esm(code)
expected = """\
import { add, subtract } from './math.js';
import { add, subtract } from './math';
import lodash from 'lodash';
import path from 'path';"""
@ -316,7 +316,7 @@ function process() {
result = ensure_module_system_compatibility(code, ModuleSystem.ES_MODULE)
# Should convert require to import
assert "import { helper } from './helpers.js';" in result
assert "import { helper } from './helpers';" in result
assert "require" not in result, f"require should be converted to import. Got:\n{result}"
def test_convert_mixed_code_to_commonjs(self):
@ -711,7 +711,7 @@ module.exports = { targetFunction, otherFunction };
file_path.write_text(original_source, encoding="utf-8")
functions = js_support.discover_functions(file_path)
target_func = next(f for f in functions if f.name == "targetFunction")
target_func = next(f for f in functions if f.function_name == "targetFunction")
optimized_code = """\
function targetFunction(x) {
@ -763,7 +763,7 @@ class Calculator {
file_path.write_text(original_source, encoding="utf-8")
functions = js_support.discover_functions(file_path)
add_method = next(f for f in functions if f.name == "add")
add_method = next(f for f in functions if f.function_name == "add")
# Optimized version provided in class context
optimized_code = """\
@ -826,7 +826,7 @@ class DataProcessor {
file_path.write_text(original_source, encoding="utf-8")
functions = js_support.discover_functions(file_path)
process_method = next(f for f in functions if f.name == "process")
process_method = next(f for f in functions if f.function_name == "process")
optimized_code = """\
class DataProcessor {
@ -948,7 +948,7 @@ class Cache {
file_path.write_text(original_source, encoding="utf-8")
functions = js_support.discover_functions(file_path)
get_method = next(f for f in functions if f.name == "get")
get_method = next(f for f in functions if f.function_name == "get")
optimized_code = """\
class Cache {
@ -1050,7 +1050,7 @@ class ApiClient {
file_path.write_text(original_source, encoding="utf-8")
functions = js_support.discover_functions(file_path)
get_method = next(f for f in functions if f.name == "get")
get_method = next(f for f in functions if f.function_name == "get")
optimized_code = """\
class ApiClient {
@ -1181,7 +1181,7 @@ class Container<T> {
file_path.write_text(original_source, encoding="utf-8")
functions = ts_support.discover_functions(file_path)
get_all_method = next(f for f in functions if f.name == "getAll")
get_all_method = next(f for f in functions if f.function_name == "getAll")
optimized_code = """\
class Container<T> {
@ -1234,7 +1234,7 @@ function createUser(name: string, email: string): User {
file_path.write_text(original_source, encoding="utf-8")
functions = ts_support.discover_functions(file_path)
func = next(f for f in functions if f.name == "createUser")
func = next(f for f in functions if f.function_name == "createUser")
optimized_code = """\
function createUser(name: string, email: string): User {
@ -1289,7 +1289,7 @@ function processItems(items) {
file_path.write_text(original_source, encoding="utf-8")
functions = js_support.discover_functions(file_path)
process_func = next(f for f in functions if f.name == "processItems")
process_func = next(f for f in functions if f.function_name == "processItems")
optimized_code = """\
function processItems(items) {
@ -1336,7 +1336,7 @@ class MathUtils {
# First replacement: sum method
functions = js_support.discover_functions(file_path)
sum_method = next(f for f in functions if f.name == "sum")
sum_method = next(f for f in functions if f.function_name == "sum")
optimized_sum = """\
class MathUtils {
@ -1554,7 +1554,7 @@ module.exports = { main, helper };
file_path.write_text(original_source, encoding="utf-8")
functions = js_support.discover_functions(file_path)
main_func = next(f for f in functions if f.name == "main")
main_func = next(f for f in functions if f.function_name == "main")
optimized_code = """\
function main(data) {
@ -1597,7 +1597,7 @@ export function main(data) {
file_path.write_text(original_source, encoding="utf-8")
functions = js_support.discover_functions(file_path)
main_func = next(f for f in functions if f.name == "main")
main_func = next(f for f in functions if f.function_name == "main")
optimized_code = """\
export function main(data) {
@ -1756,7 +1756,7 @@ export class DataProcessor<T> {
# find function
target_func_info = None
for func in functions:
if func.name == target_func and func.parents[0].name == parent_class:
if func.function_name == target_func and func.parents[0].name == parent_class:
target_func_info = func
break
assert target_func_info is not None
@ -1893,3 +1893,252 @@ export class DataProcessor<T> {
}
"""
class TestNewVariableFromOptimizedCode:
"""Tests for handling new variables introduced in optimized code."""
def test_new_bound_method_variable_added_after_referenced_constant(self, ts_support, temp_project):
"""Test that a new variable binding a method is added after the constant it references.
When optimized code introduces a new module-level variable (like `_has`) that
references an existing constant (like `CODEFLASH_EMPLOYEE_GITHUB_IDS`), the
replacement should:
1. Add the new variable after the constant it references
2. Replace the function with the optimized version
"""
from codeflash.models.models import CodeStringsMarkdown, CodeString
original_source = '''\
const CODEFLASH_EMPLOYEE_GITHUB_IDS = new Set([
"1234",
]);
export function isCodeflashEmployee(userId: string): boolean {
return CODEFLASH_EMPLOYEE_GITHUB_IDS.has(userId);
}
'''
file_path = temp_project / "auth.ts"
file_path.write_text(original_source, encoding="utf-8")
# Optimized code introduces a bound method variable for performance
optimized_code = '''const _has: (id: string) => boolean = CODEFLASH_EMPLOYEE_GITHUB_IDS.has.bind(
CODEFLASH_EMPLOYEE_GITHUB_IDS
);
export function isCodeflashEmployee(userId: string): boolean {
return _has(userId);
}
'''
code_markdown = CodeStringsMarkdown(
code_strings=[
CodeString(
code=optimized_code,
file_path=Path("auth.ts"),
language="typescript"
)
],
language="typescript"
)
replaced = replace_function_definitions_for_language(
["isCodeflashEmployee"],
code_markdown,
file_path,
temp_project,
)
assert replaced
result = file_path.read_text()
# Expected result for strict equality check
expected_result = '''\
const CODEFLASH_EMPLOYEE_GITHUB_IDS = new Set([
"1234",
]);
const _has: (id: string) => boolean = CODEFLASH_EMPLOYEE_GITHUB_IDS.has.bind(
CODEFLASH_EMPLOYEE_GITHUB_IDS
);
export function isCodeflashEmployee(userId: string): boolean {
return _has(userId);
}
'''
assert result == expected_result, (
f"Result does not match expected output.\n"
f"Expected:\n{expected_result}\n\n"
f"Got:\n{result}"
)
class TestImportedTypeNotDuplicated:
"""Tests to ensure imported types are not duplicated during code replacement.
When a type is already imported in the original file, it should NOT be
added as a new declaration from the optimized code, even if the optimized
code contains the type definition (because it was provided as context).
See: https://github.com/codeflash-ai/appsmith/pull/20
"""
def test_imported_interface_not_added_as_declaration(self, ts_support, temp_project):
"""Test that an imported interface is not duplicated in the output.
When TreeNode is imported from another file and the optimized code
contains the TreeNode interface definition (from read-only context),
the replacement should NOT add the interface to the original file.
"""
from codeflash.models.models import CodeStringsMarkdown, CodeString
# Original source imports TreeNode
original_source = """\
import type { TreeNode } from "./constants";
export function getNearestAbove(
tree: Record<string, TreeNode>,
effectedBoxId: string,
) {
const aboves = tree[effectedBoxId].aboves;
return aboves.reduce((prev: string[], next: string) => {
if (!prev[0]) return [next];
let nextBottomRow = tree[next].bottomRow;
let prevBottomRow = tree[prev[0]].bottomRow;
if (nextBottomRow > prevBottomRow) return [next];
return prev;
}, []);
}
"""
file_path = temp_project / "helpers.ts"
file_path.write_text(original_source, encoding="utf-8")
# Optimized code includes the TreeNode interface (from read-only context)
# This simulates what the AI might return when type definitions are included in context
optimized_code_with_interface = """\
interface TreeNode {
aboves: string[];
belows: string[];
topRow: number;
bottomRow: number;
}
export function getNearestAbove(
tree: Record<string, TreeNode>,
effectedBoxId: string,
) {
const aboves = tree[effectedBoxId].aboves;
return aboves.reduce((prev: string[], next: string) => {
if (!prev[0]) return [next];
// Optimized: cache lookups
const nextBottomRow = tree[next].bottomRow;
const prevBottomRow = tree[prev[0]].bottomRow;
return nextBottomRow > prevBottomRow ? [next] : prev;
}, []);
}
"""
code_markdown = CodeStringsMarkdown(
code_strings=[
CodeString(
code=optimized_code_with_interface,
file_path=Path("helpers.ts"),
language="typescript"
)
],
language="typescript"
)
replace_function_definitions_for_language(
["getNearestAbove"],
code_markdown,
file_path,
temp_project,
)
result = file_path.read_text()
# The TreeNode interface should NOT appear in the result
# (it's already imported, so adding it would cause a duplicate)
assert "interface TreeNode" not in result, (
f"TreeNode interface should NOT be added to the file since it's already imported.\n"
f"Result contains:\n{result}"
)
# The import should still be there
assert 'import type { TreeNode } from "./constants"' in result, (
f"Original import should be preserved.\nResult:\n{result}"
)
# The optimized function should be there
assert "// Optimized: cache lookups" in result, (
f"Optimized function should be in the result.\nResult:\n{result}"
)
# The result should be valid TypeScript
assert ts_support.validate_syntax(result) is True
def test_multiple_imported_types_not_duplicated(self, ts_support, temp_project):
"""Test that multiple imported types are not duplicated."""
from codeflash.models.models import CodeStringsMarkdown, CodeString
original_source = """\
import type { TreeNode, NodeSpace } from "./constants";
import { MAX_BOX_SIZE } from "./constants";
export function processNode(node: TreeNode, space: NodeSpace): number {
return node.topRow + space.top;
}
"""
file_path = temp_project / "processor.ts"
file_path.write_text(original_source, encoding="utf-8")
# Optimized code includes both interfaces
optimized_code = """\
interface TreeNode {
topRow: number;
bottomRow: number;
}
interface NodeSpace {
top: number;
bottom: number;
}
export function processNode(node: TreeNode, space: NodeSpace): number {
// Optimized
return (node.topRow + space.top) | 0;
}
"""
code_markdown = CodeStringsMarkdown(
code_strings=[
CodeString(
code=optimized_code,
file_path=Path("processor.ts"),
language="typescript"
)
],
language="typescript"
)
replace_function_definitions_for_language(
["processNode"],
code_markdown,
file_path,
temp_project,
)
result = file_path.read_text()
# Neither interface should be added
assert "interface TreeNode" not in result
assert "interface NodeSpace" not in result
# Imports should be preserved
assert 'import type { TreeNode, NodeSpace } from "./constants"' in result
# Optimized code should be there
assert "// Optimized" in result
assert ts_support.validate_syntax(result) is True

View file

@ -353,8 +353,8 @@ class TestDiscoverFunctionsParity:
assert len(js_funcs) == 1, f"JavaScript found {len(js_funcs)}, expected 1"
# Both should find 'add'
assert py_funcs[0].name == "add"
assert js_funcs[0].name == "add"
assert py_funcs[0].function_name == "add"
assert js_funcs[0].function_name == "add"
# Both should have correct language
assert py_funcs[0].language == Language.PYTHON
@ -373,8 +373,8 @@ class TestDiscoverFunctionsParity:
assert len(js_funcs) == 3, f"JavaScript found {len(js_funcs)}, expected 3"
# Both should find the same function names
py_names = {f.name for f in py_funcs}
js_names = {f.name for f in js_funcs}
py_names = {f.function_name for f in py_funcs}
js_names = {f.function_name for f in js_funcs}
assert py_names == {"add", "subtract", "multiply"}
assert js_names == {"add", "subtract", "multiply"}
@ -392,8 +392,8 @@ class TestDiscoverFunctionsParity:
assert len(js_funcs) == 1, f"JavaScript found {len(js_funcs)}, expected 1"
# The function with return should be found
assert py_funcs[0].name == "with_return"
assert js_funcs[0].name == "withReturn"
assert py_funcs[0].function_name == "with_return"
assert js_funcs[0].function_name == "withReturn"
def test_class_methods_discovery(self, python_support, js_support):
"""Both should discover class methods with proper metadata."""
@ -409,12 +409,12 @@ class TestDiscoverFunctionsParity:
# All should be marked as methods
for func in py_funcs:
assert func.is_method is True, f"Python {func.name} should be a method"
assert func.class_name == "Calculator", f"Python {func.name} should belong to Calculator"
assert func.is_method is True, f"Python {func.function_name} should be a method"
assert func.class_name == "Calculator", f"Python {func.function_name} should belong to Calculator"
for func in js_funcs:
assert func.is_method is True, f"JavaScript {func.name} should be a method"
assert func.class_name == "Calculator", f"JavaScript {func.name} should belong to Calculator"
assert func.is_method is True, f"JavaScript {func.function_name} should be a method"
assert func.class_name == "Calculator", f"JavaScript {func.function_name} should belong to Calculator"
def test_async_functions_discovery(self, python_support, js_support):
"""Both should correctly identify async functions."""
@ -429,10 +429,10 @@ class TestDiscoverFunctionsParity:
assert len(js_funcs) == 2, f"JavaScript found {len(js_funcs)}, expected 2"
# Check async flags
py_async = next(f for f in py_funcs if "fetch" in f.name.lower())
py_sync = next(f for f in py_funcs if "sync" in f.name.lower())
js_async = next(f for f in js_funcs if "fetch" in f.name.lower())
js_sync = next(f for f in js_funcs if "sync" in f.name.lower())
py_async = next(f for f in py_funcs if "fetch" in f.function_name.lower())
py_sync = next(f for f in py_funcs if "sync" in f.function_name.lower())
js_async = next(f for f in js_funcs if "fetch" in f.function_name.lower())
js_sync = next(f for f in js_funcs if "sync" in f.function_name.lower())
assert py_async.is_async is True, "Python async function should have is_async=True"
assert py_sync.is_async is False, "Python sync function should have is_async=False"
@ -452,15 +452,15 @@ class TestDiscoverFunctionsParity:
assert len(js_funcs) == 2, f"JavaScript found {len(js_funcs)}, expected 2"
# Check names
py_names = {f.name for f in py_funcs}
js_names = {f.name for f in js_funcs}
py_names = {f.function_name for f in py_funcs}
js_names = {f.function_name for f in js_funcs}
assert py_names == {"outer", "inner"}, f"Python found {py_names}"
assert js_names == {"outer", "inner"}, f"JavaScript found {js_names}"
# Check parent info for inner function
py_inner = next(f for f in py_funcs if f.name == "inner")
js_inner = next(f for f in js_funcs if f.name == "inner")
py_inner = next(f for f in py_funcs if f.function_name == "inner")
js_inner = next(f for f in js_funcs if f.function_name == "inner")
assert len(py_inner.parents) >= 1, "Python inner should have parent info"
assert py_inner.parents[0].name == "outer", "Python inner's parent should be outer"
@ -482,8 +482,8 @@ class TestDiscoverFunctionsParity:
assert len(js_funcs) == 1, f"JavaScript found {len(js_funcs)}, expected 1"
# Both should find 'helper' belonging to 'Utils'
assert py_funcs[0].name == "helper"
assert js_funcs[0].name == "helper"
assert py_funcs[0].function_name == "helper"
assert js_funcs[0].function_name == "helper"
assert py_funcs[0].class_name == "Utils"
assert js_funcs[0].class_name == "Utils"
@ -532,8 +532,8 @@ class TestDiscoverFunctionsParity:
assert len(js_funcs) == 1, f"JavaScript found {len(js_funcs)}, expected 1"
# Should be the sync function
assert "sync" in py_funcs[0].name.lower()
assert "sync" in js_funcs[0].name.lower()
assert "sync" in py_funcs[0].function_name.lower()
assert "sync" in js_funcs[0].function_name.lower()
def test_filter_exclude_methods(self, python_support, js_support):
"""Both should support filtering out class methods."""
@ -550,8 +550,8 @@ class TestDiscoverFunctionsParity:
assert len(js_funcs) == 1, f"JavaScript found {len(js_funcs)}, expected 1"
# Should be the standalone function
assert py_funcs[0].name == "standalone"
assert js_funcs[0].name == "standalone"
assert py_funcs[0].function_name == "standalone"
assert js_funcs[0].function_name == "standalone"
def test_nonexistent_file_returns_empty(self, python_support, js_support):
"""Both should return empty list for nonexistent files."""
@ -570,14 +570,14 @@ class TestDiscoverFunctionsParity:
js_funcs = js_support.discover_functions(js_file)
# Both should have start_line and end_line
assert py_funcs[0].start_line is not None
assert py_funcs[0].end_line is not None
assert js_funcs[0].start_line is not None
assert js_funcs[0].end_line is not None
assert py_funcs[0].starting_line is not None
assert py_funcs[0].ending_line is not None
assert js_funcs[0].starting_line is not None
assert js_funcs[0].ending_line is not None
# Start should be before or equal to end
assert py_funcs[0].start_line <= py_funcs[0].end_line
assert js_funcs[0].start_line <= js_funcs[0].end_line
assert py_funcs[0].starting_line <= py_funcs[0].ending_line
assert js_funcs[0].starting_line <= js_funcs[0].ending_line
# ============================================================================
@ -604,8 +604,8 @@ function multiply(a, b) {
return a * b;
}
"""
py_func = FunctionInfo(name="add", file_path=Path("/test.py"), start_line=1, end_line=2)
js_func = FunctionInfo(name="add", file_path=Path("/test.js"), start_line=1, end_line=3)
py_func = FunctionInfo(function_name="add", file_path=Path("/test.py"), starting_line=1, ending_line=2)
js_func = FunctionInfo(function_name="add", file_path=Path("/test.js"), starting_line=1, ending_line=3)
py_new = """def add(a, b):
return (a + b) | 0
@ -651,8 +651,8 @@ function other() {
// Footer
"""
py_func = FunctionInfo(name="target", file_path=Path("/test.py"), start_line=4, end_line=5)
js_func = FunctionInfo(name="target", file_path=Path("/test.js"), start_line=4, end_line=6)
py_func = FunctionInfo(function_name="target", file_path=Path("/test.py"), starting_line=4, ending_line=5)
js_func = FunctionInfo(function_name="target", file_path=Path("/test.js"), starting_line=4, ending_line=6)
py_new = """def target():
return 42
@ -693,18 +693,18 @@ function other() {
}
"""
py_func = FunctionInfo(
name="add",
function_name="add",
file_path=Path("/test.py"),
start_line=2,
end_line=3,
parents=(ParentInfo(name="Calculator", type="ClassDef"),),
starting_line=2,
ending_line=3,
parents=[ParentInfo(name="Calculator", type="ClassDef")],
)
js_func = FunctionInfo(
name="add",
function_name="add",
file_path=Path("/test.js"),
start_line=2,
end_line=4,
parents=(ParentInfo(name="Calculator", type="ClassDef"),),
starting_line=2,
ending_line=4,
parents=[ParentInfo(name="Calculator", type="ClassDef")],
)
# New code without indentation
@ -872,8 +872,8 @@ class TestExtractCodeContextParity:
".js",
)
py_func = FunctionInfo(name="add", file_path=py_file, start_line=1, end_line=2)
js_func = FunctionInfo(name="add", file_path=js_file, start_line=1, end_line=3)
py_func = FunctionInfo(function_name="add", file_path=py_file, starting_line=1, ending_line=2)
js_func = FunctionInfo(function_name="add", file_path=js_file, starting_line=1, ending_line=3)
py_context = python_support.extract_code_context(py_func, py_file.parent, py_file.parent)
js_context = js_support.extract_code_context(js_func, js_file.parent, js_file.parent)
@ -922,8 +922,8 @@ class TestIntegrationParity:
assert len(py_funcs) == 1
assert len(js_funcs) == 1
assert py_funcs[0].name == "fibonacci"
assert js_funcs[0].name == "fibonacci"
assert py_funcs[0].function_name == "fibonacci"
assert js_funcs[0].function_name == "fibonacci"
# Replace
py_optimized = """def fibonacci(n):
@ -974,20 +974,20 @@ class TestFeatureGaps:
for py_func in py_funcs:
# Check all expected fields are populated
assert py_func.name is not None, "Python: name should be populated"
assert py_func.function_name is not None, "Python: name should be populated"
assert py_func.file_path is not None, "Python: file_path should be populated"
assert py_func.start_line is not None, "Python: start_line should be populated"
assert py_func.end_line is not None, "Python: end_line should be populated"
assert py_func.starting_line is not None, "Python: start_line should be populated"
assert py_func.ending_line is not None, "Python: end_line should be populated"
assert py_func.language is not None, "Python: language should be populated"
# is_method and class_name should be set for class methods
assert py_func.is_method is not None, "Python: is_method should be populated"
for js_func in js_funcs:
# JavaScript should populate the same fields
assert js_func.name is not None, "JavaScript: name should be populated"
assert js_func.function_name is not None, "JavaScript: name should be populated"
assert js_func.file_path is not None, "JavaScript: file_path should be populated"
assert js_func.start_line is not None, "JavaScript: start_line should be populated"
assert js_func.end_line is not None, "JavaScript: end_line should be populated"
assert js_func.starting_line is not None, "JavaScript: start_line should be populated"
assert js_func.ending_line is not None, "JavaScript: end_line should be populated"
assert js_func.language is not None, "JavaScript: language should be populated"
assert js_func.is_method is not None, "JavaScript: is_method should be populated"
@ -1006,7 +1006,7 @@ const identity = x => x;
funcs = js_support.discover_functions(js_file)
# Should find all arrow functions
names = {f.name for f in funcs}
names = {f.function_name for f in funcs}
assert "add" in names, "Should find arrow function 'add'"
assert "multiply" in names, "Should find concise arrow function 'multiply'"
# identity might or might not be found depending on implicit return handling
@ -1057,7 +1057,7 @@ def multi_decorated():
funcs = python_support.discover_functions(py_file)
# Should find all functions regardless of decorators
names = {f.name for f in funcs}
names = {f.function_name for f in funcs}
assert "decorated" in names
assert "decorated_with_args" in names
assert "multi_decorated" in names
@ -1077,7 +1077,7 @@ const namedExpr = function myFunc(x) {
funcs = js_support.discover_functions(js_file)
# Should find function expressions
names = {f.name for f in funcs}
names = {f.function_name for f in funcs}
assert "add" in names, "Should find anonymous function expression assigned to 'add'"
@ -1144,5 +1144,5 @@ function greeting() {
assert len(py_funcs) == 1
assert len(js_funcs) == 1
assert py_funcs[0].name == "greeting"
assert js_funcs[0].name == "greeting"
assert py_funcs[0].function_name == "greeting"
assert js_funcs[0].function_name == "greeting"

View file

@ -113,19 +113,20 @@ def test_js_replcement() -> None:
functions = js_support.discover_functions(main_file)
target = None
for func in functions:
if func.name == "calculateStats":
if func.function_name == "calculateStats":
target = func
break
assert target is not None
func = FunctionToOptimize(
function_name=target.name,
function_name=target.function_name,
file_path=target.file_path,
parents=target.parents,
starting_line=target.start_line,
ending_line=target.end_line,
starting_col=target.start_col,
ending_col=target.end_col,
starting_line=target.starting_line,
ending_line=target.ending_line,
starting_col=target.starting_col,
ending_col=target.ending_col,
is_async=target.is_async,
is_method=target.is_method,
language=target.language,
)
test_config = TestConfig(

View file

@ -52,7 +52,7 @@ def add(a, b):
functions = python_support.discover_functions(Path(f.name))
assert len(functions) == 1
assert functions[0].name == "add"
assert functions[0].function_name == "add"
assert functions[0].language == Language.PYTHON
def test_discover_multiple_functions(self, python_support):
@ -73,7 +73,7 @@ def multiply(a, b):
functions = python_support.discover_functions(Path(f.name))
assert len(functions) == 3
names = {func.name for func in functions}
names = {func.function_name for func in functions}
assert names == {"add", "subtract", "multiply"}
def test_discover_function_with_no_return_excluded(self, python_support):
@ -92,7 +92,7 @@ def without_return():
# Only the function with return should be discovered
assert len(functions) == 1
assert functions[0].name == "with_return"
assert functions[0].function_name == "with_return"
def test_discover_class_methods(self, python_support):
"""Test discovering class methods."""
@ -130,8 +130,8 @@ def sync_function():
assert len(functions) == 2
async_func = next(f for f in functions if f.name == "fetch_data")
sync_func = next(f for f in functions if f.name == "sync_function")
async_func = next(f for f in functions if f.function_name == "fetch_data")
sync_func = next(f for f in functions if f.function_name == "sync_function")
assert async_func.is_async is True
assert sync_func.is_async is False
@ -151,11 +151,11 @@ def outer():
# Both outer and inner should be discovered
assert len(functions) == 2
names = {func.name for func in functions}
names = {func.function_name for func in functions}
assert names == {"outer", "inner"}
# Inner should have outer as parent
inner = next(f for f in functions if f.name == "inner")
inner = next(f for f in functions if f.function_name == "inner")
assert len(inner.parents) == 1
assert inner.parents[0].name == "outer"
assert inner.parents[0].type == "FunctionDef"
@ -174,7 +174,7 @@ class Utils:
functions = python_support.discover_functions(Path(f.name))
assert len(functions) == 1
assert functions[0].name == "helper"
assert functions[0].function_name == "helper"
assert functions[0].class_name == "Utils"
def test_discover_with_filter_exclude_async(self, python_support):
@ -193,7 +193,7 @@ def sync_func():
functions = python_support.discover_functions(Path(f.name), criteria)
assert len(functions) == 1
assert functions[0].name == "sync_func"
assert functions[0].function_name == "sync_func"
def test_discover_with_filter_exclude_methods(self, python_support):
"""Test filtering out class methods."""
@ -212,7 +212,7 @@ class MyClass:
functions = python_support.discover_functions(Path(f.name), criteria)
assert len(functions) == 1
assert functions[0].name == "standalone"
assert functions[0].function_name == "standalone"
def test_discover_line_numbers(self, python_support):
"""Test that line numbers are correctly captured."""
@ -229,13 +229,13 @@ def func2():
functions = python_support.discover_functions(Path(f.name))
func1 = next(f for f in functions if f.name == "func1")
func2 = next(f for f in functions if f.name == "func2")
func1 = next(f for f in functions if f.function_name == "func1")
func2 = next(f for f in functions if f.function_name == "func2")
assert func1.start_line == 1
assert func1.end_line == 2
assert func2.start_line == 4
assert func2.end_line == 7
assert func1.starting_line == 1
assert func1.ending_line == 2
assert func2.starting_line == 4
assert func2.ending_line == 7
def test_discover_invalid_file_returns_empty(self, python_support):
"""Test that invalid Python file returns empty list."""
@ -263,7 +263,7 @@ class TestReplaceFunction:
def multiply(a, b):
return a * b
"""
func = FunctionInfo(name="add", file_path=Path("/test.py"), start_line=1, end_line=2)
func = FunctionInfo(function_name="add", file_path=Path("/test.py"), starting_line=1, ending_line=2)
new_code = """def add(a, b):
# Optimized
return (a + b) | 0
@ -287,7 +287,7 @@ def other():
# Footer
"""
func = FunctionInfo(name="target", file_path=Path("/test.py"), start_line=4, end_line=5)
func = FunctionInfo(function_name="target", file_path=Path("/test.py"), starting_line=4, ending_line=5)
new_code = """def target():
return 42
"""
@ -306,11 +306,11 @@ def other():
return a + b
"""
func = FunctionInfo(
name="add",
function_name="add",
file_path=Path("/test.py"),
start_line=2,
end_line=3,
parents=(ParentInfo(name="Calculator", type="ClassDef"),),
starting_line=2,
ending_line=3,
parents=[ParentInfo(name="Calculator", type="ClassDef")],
)
# New code has no indentation
new_code = """def add(self, a, b):
@ -331,7 +331,7 @@ def other():
def second():
return 2
"""
func = FunctionInfo(name="first", file_path=Path("/test.py"), start_line=1, end_line=2)
func = FunctionInfo(function_name="first", file_path=Path("/test.py"), starting_line=1, ending_line=2)
new_code = """def first():
return 100
"""
@ -348,7 +348,7 @@ def second():
def last():
return 999
"""
func = FunctionInfo(name="last", file_path=Path("/test.py"), start_line=4, end_line=5)
func = FunctionInfo(function_name="last", file_path=Path("/test.py"), starting_line=4, ending_line=5)
new_code = """def last():
return 1000
"""
@ -362,7 +362,7 @@ def last():
source = """def only():
return 42
"""
func = FunctionInfo(name="only", file_path=Path("/test.py"), start_line=1, end_line=2)
func = FunctionInfo(function_name="only", file_path=Path("/test.py"), starting_line=1, ending_line=2)
new_code = """def only():
return 100
"""
@ -474,7 +474,7 @@ class TestExtractCodeContext:
f.flush()
file_path = Path(f.name)
func = FunctionInfo(name="add", file_path=file_path, start_line=1, end_line=2)
func = FunctionInfo(function_name="add", file_path=file_path, starting_line=1, ending_line=2)
context = python_support.extract_code_context(func, file_path.parent, file_path.parent)
@ -503,7 +503,7 @@ class TestIntegration:
functions = python_support.discover_functions(file_path)
assert len(functions) == 1
func = functions[0]
assert func.name == "fibonacci"
assert func.function_name == "fibonacci"
# Replace
optimized_code = """def fibonacci(n):

View file

@ -351,6 +351,32 @@ class TestFindImports:
assert imports[0].module_path == "fs"
assert imports[0].default_import == "fs"
def test_require_inside_function_not_import(self, js_analyzer):
"""Test that require() inside functions is not treated as an import.
This is important because dynamic require() calls inside functions are
not module-level imports and should not be extracted as such.
"""
code = """
const fs = require('fs');
function loadModule() {
const dynamic = require('dynamic-module');
return dynamic;
}
class MyClass {
method() {
const inMethod = require('method-module');
}
}
"""
imports = js_analyzer.find_imports(code)
# Only the module-level require should be found
assert len(imports) == 1
assert imports[0].module_path == "fs"
def test_find_multiple_imports(self, js_analyzer):
"""Test finding multiple imports."""
code = """

Some files were not shown because too many files have changed in this diff Show more