mirror of
https://github.com/codeflash-ai/codeflash.git
synced 2026-05-04 18:25:17 +00:00
fix: resolve merge conflict and standardize Java to use FunctionToOptimize
- Resolve merge conflict in code_replacer.py with Java-specific handling - Update all Java modules to use FunctionToOptimize instead of FunctionInfo - Add Language.JAVA to language_enum.py - Update attribute names: name→function_name, start_line→starting_line, etc. - Update all Java tests to use correct attribute names Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
commit
520a1ff08e
107 changed files with 22062 additions and 1476 deletions
20
.github/workflows/claude.yml
vendored
20
.github/workflows/claude.yml
vendored
|
|
@ -19,16 +19,30 @@ jobs:
|
|||
(github.event_name == 'issues' && (contains(github.event.issue.body, '@claude') || contains(github.event.issue.title, '@claude')))
|
||||
runs-on: ubuntu-latest
|
||||
permissions:
|
||||
contents: read
|
||||
pull-requests: read
|
||||
contents: write
|
||||
pull-requests: write
|
||||
issues: read
|
||||
id-token: write
|
||||
actions: read # Required for Claude to read CI results on PRs
|
||||
steps:
|
||||
- name: Get PR head ref
|
||||
id: pr-ref
|
||||
env:
|
||||
GH_TOKEN: ${{ github.token }}
|
||||
run: |
|
||||
# For issue_comment events, we need to fetch the PR info
|
||||
if [ "${{ github.event_name }}" = "issue_comment" ]; then
|
||||
PR_REF=$(gh api repos/${{ github.repository }}/pulls/${{ github.event.issue.number }} --jq '.head.ref')
|
||||
echo "ref=$PR_REF" >> $GITHUB_OUTPUT
|
||||
else
|
||||
echo "ref=${{ github.event.pull_request.head.ref || github.head_ref }}" >> $GITHUB_OUTPUT
|
||||
fi
|
||||
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 1
|
||||
fetch-depth: 0
|
||||
ref: ${{ steps.pr-ref.outputs.ref }}
|
||||
|
||||
- name: Run Claude Code
|
||||
id: claude
|
||||
|
|
|
|||
50
.github/workflows/js-tests.yml
vendored
Normal file
50
.github/workflows/js-tests.yml
vendored
Normal file
|
|
@ -0,0 +1,50 @@
|
|||
name: JavaScript/TypeScript Integration Tests
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
pull_request:
|
||||
workflow_dispatch:
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.ref_name }}
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
js-integration-tests:
|
||||
name: JS/TS Integration Tests
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
token: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
- name: Setup Node.js
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: '20'
|
||||
|
||||
- name: Install uv
|
||||
uses: astral-sh/setup-uv@v6
|
||||
|
||||
- name: Install Python dependencies
|
||||
run: |
|
||||
uv venv --seed
|
||||
uv sync
|
||||
|
||||
- name: Install npm dependencies for test projects
|
||||
run: |
|
||||
npm install --prefix code_to_optimize/js/code_to_optimize_js
|
||||
npm install --prefix code_to_optimize/js/code_to_optimize_ts
|
||||
npm install --prefix code_to_optimize/js/code_to_optimize_vitest
|
||||
|
||||
- name: Run JavaScript integration tests
|
||||
run: |
|
||||
uv run pytest tests/languages/javascript/ -v
|
||||
uv run pytest tests/test_languages/test_vitest_e2e.py -v
|
||||
uv run pytest tests/test_languages/test_javascript_e2e.py -v
|
||||
uv run pytest tests/test_languages/test_javascript_support.py -v
|
||||
uv run pytest tests/code_utils/test_config_js.py -v
|
||||
|
|
@ -14,15 +14,13 @@
|
|||
}
|
||||
},
|
||||
"../../../packages/codeflash": {
|
||||
"version": "0.1.0",
|
||||
"version": "0.3.1",
|
||||
"dev": true,
|
||||
"hasInstallScript": true,
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
"@msgpack/msgpack": "^3.0.0",
|
||||
"better-sqlite3": "^12.0.0",
|
||||
"jest-junit": "^16.0.0",
|
||||
"jest-runner": "^29.7.0"
|
||||
"better-sqlite3": "^12.0.0"
|
||||
},
|
||||
"bin": {
|
||||
"codeflash": "bin/codeflash.js",
|
||||
|
|
@ -33,7 +31,8 @@
|
|||
},
|
||||
"peerDependencies": {
|
||||
"jest": ">=27.0.0",
|
||||
"jest-runner": ">=27.0.0"
|
||||
"jest-runner": ">=27.0.0",
|
||||
"vitest": ">=1.0.0"
|
||||
},
|
||||
"peerDependenciesMeta": {
|
||||
"jest": {
|
||||
|
|
@ -41,6 +40,9 @@
|
|||
},
|
||||
"jest-runner": {
|
||||
"optional": true
|
||||
},
|
||||
"vitest": {
|
||||
"optional": true
|
||||
}
|
||||
}
|
||||
},
|
||||
|
|
|
|||
|
|
@ -14,15 +14,13 @@
|
|||
}
|
||||
},
|
||||
"../../../packages/codeflash": {
|
||||
"version": "0.2.0",
|
||||
"version": "0.3.1",
|
||||
"dev": true,
|
||||
"hasInstallScript": true,
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
"@msgpack/msgpack": "^3.0.0",
|
||||
"better-sqlite3": "^12.0.0",
|
||||
"jest-junit": "^16.0.0",
|
||||
"jest-runner": "^29.7.0"
|
||||
"better-sqlite3": "^12.0.0"
|
||||
},
|
||||
"bin": {
|
||||
"codeflash": "bin/codeflash.js",
|
||||
|
|
@ -33,7 +31,8 @@
|
|||
},
|
||||
"peerDependencies": {
|
||||
"jest": ">=27.0.0",
|
||||
"jest-runner": ">=27.0.0"
|
||||
"jest-runner": ">=27.0.0",
|
||||
"vitest": ">=1.0.0"
|
||||
},
|
||||
"peerDependenciesMeta": {
|
||||
"jest": {
|
||||
|
|
@ -41,6 +40,9 @@
|
|||
},
|
||||
"jest-runner": {
|
||||
"optional": true
|
||||
},
|
||||
"vitest": {
|
||||
"optional": true
|
||||
}
|
||||
}
|
||||
},
|
||||
|
|
|
|||
7409
code_to_optimize/js/code_to_optimize_js_esm/package-lock.json
generated
Normal file
7409
code_to_optimize/js/code_to_optimize_js_esm/package-lock.json
generated
Normal file
File diff suppressed because it is too large
Load diff
|
|
@ -20,7 +20,7 @@
|
|||
}
|
||||
},
|
||||
"../../../packages/codeflash": {
|
||||
"version": "0.1.0",
|
||||
"version": "0.3.1",
|
||||
"dev": true,
|
||||
"hasInstallScript": true,
|
||||
"license": "MIT",
|
||||
|
|
@ -36,11 +36,19 @@
|
|||
"node": ">=18.0.0"
|
||||
},
|
||||
"peerDependencies": {
|
||||
"jest": ">=27.0.0"
|
||||
"jest": ">=27.0.0",
|
||||
"jest-runner": ">=27.0.0",
|
||||
"vitest": ">=1.0.0"
|
||||
},
|
||||
"peerDependenciesMeta": {
|
||||
"jest": {
|
||||
"optional": true
|
||||
},
|
||||
"jest-runner": {
|
||||
"optional": true
|
||||
},
|
||||
"vitest": {
|
||||
"optional": true
|
||||
}
|
||||
}
|
||||
},
|
||||
|
|
|
|||
|
|
@ -0,0 +1 @@
|
|||
language: typescript
|
||||
52
code_to_optimize/js/code_to_optimize_vitest/fibonacci.ts
Normal file
52
code_to_optimize/js/code_to_optimize_vitest/fibonacci.ts
Normal file
|
|
@ -0,0 +1,52 @@
|
|||
/**
|
||||
* Fibonacci implementations - intentionally inefficient for optimization testing.
|
||||
*/
|
||||
|
||||
/**
|
||||
* Calculate the nth Fibonacci number using naive recursion.
|
||||
* This is intentionally slow to demonstrate optimization potential.
|
||||
* @param n - The index of the Fibonacci number to calculate
|
||||
* @returns The nth Fibonacci number
|
||||
*/
|
||||
export function fibonacci(n: number): number {
|
||||
if (n <= 1) {
|
||||
return n;
|
||||
}
|
||||
return fibonacci(n - 1) + fibonacci(n - 2);
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if a number is a Fibonacci number.
|
||||
* @param num - The number to check
|
||||
* @returns True if num is a Fibonacci number
|
||||
*/
|
||||
export function isFibonacci(num: number): boolean {
|
||||
// A number is Fibonacci if one of (5*n*n + 4) or (5*n*n - 4) is a perfect square
|
||||
const check1 = 5 * num * num + 4;
|
||||
const check2 = 5 * num * num - 4;
|
||||
|
||||
return isPerfectSquare(check1) || isPerfectSquare(check2);
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if a number is a perfect square.
|
||||
* @param n - The number to check
|
||||
* @returns True if n is a perfect square
|
||||
*/
|
||||
export function isPerfectSquare(n: number): boolean {
|
||||
const sqrt = Math.sqrt(n);
|
||||
return sqrt === Math.floor(sqrt);
|
||||
}
|
||||
|
||||
/**
|
||||
* Generate an array of Fibonacci numbers up to n.
|
||||
* @param n - The number of Fibonacci numbers to generate
|
||||
* @returns Array of Fibonacci numbers
|
||||
*/
|
||||
export function fibonacciSequence(n: number): number[] {
|
||||
const result: number[] = [];
|
||||
for (let i = 0; i < n; i++) {
|
||||
result.push(fibonacci(i));
|
||||
}
|
||||
return result;
|
||||
}
|
||||
1492
code_to_optimize/js/code_to_optimize_vitest/package-lock.json
generated
Normal file
1492
code_to_optimize/js/code_to_optimize_vitest/package-lock.json
generated
Normal file
File diff suppressed because it is too large
Load diff
23
code_to_optimize/js/code_to_optimize_vitest/package.json
Normal file
23
code_to_optimize/js/code_to_optimize_vitest/package.json
Normal file
|
|
@ -0,0 +1,23 @@
|
|||
{
|
||||
"name": "codeflash-vitest-test",
|
||||
"version": "1.0.0",
|
||||
"description": "Sample TypeScript project with Vitest for codeflash optimization testing",
|
||||
"type": "module",
|
||||
"main": "dist/index.js",
|
||||
"scripts": {
|
||||
"test": "vitest run",
|
||||
"test:watch": "vitest",
|
||||
"test:coverage": "vitest run --coverage",
|
||||
"build": "tsc"
|
||||
},
|
||||
"codeflash": {
|
||||
"moduleRoot": ".",
|
||||
"testsRoot": "tests"
|
||||
},
|
||||
"devDependencies": {
|
||||
"@types/node": "^20.0.0",
|
||||
"codeflash": "file:../../../packages/codeflash",
|
||||
"typescript": "^5.0.0",
|
||||
"vitest": "^2.0.0"
|
||||
}
|
||||
}
|
||||
62
code_to_optimize/js/code_to_optimize_vitest/string_utils.ts
Normal file
62
code_to_optimize/js/code_to_optimize_vitest/string_utils.ts
Normal file
|
|
@ -0,0 +1,62 @@
|
|||
/**
|
||||
* String utilities - intentionally inefficient for optimization testing.
|
||||
*/
|
||||
|
||||
/**
|
||||
* Reverse a string character by character.
|
||||
* @param str - The string to reverse
|
||||
* @returns The reversed string
|
||||
*/
|
||||
export function reverseString(str: string): string {
|
||||
let result = '';
|
||||
for (let i = str.length - 1; i >= 0; i--) {
|
||||
result += str[i];
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if a string is a palindrome.
|
||||
* @param str - The string to check
|
||||
* @returns True if the string is a palindrome
|
||||
*/
|
||||
export function isPalindrome(str: string): boolean {
|
||||
const cleaned = str.toLowerCase().replace(/[^a-z0-9]/g, '');
|
||||
return cleaned === reverseString(cleaned);
|
||||
}
|
||||
|
||||
/**
|
||||
* Count vowels in a string.
|
||||
* @param str - The string to analyze
|
||||
* @returns The number of vowels
|
||||
*/
|
||||
export function countVowels(str: string): number {
|
||||
const vowels = 'aeiouAEIOU';
|
||||
let count = 0;
|
||||
for (const char of str) {
|
||||
if (vowels.includes(char)) {
|
||||
count++;
|
||||
}
|
||||
}
|
||||
return count;
|
||||
}
|
||||
|
||||
/**
|
||||
* Find all unique words in a string.
|
||||
* @param str - The string to analyze
|
||||
* @returns Array of unique words
|
||||
*/
|
||||
export function uniqueWords(str: string): string[] {
|
||||
const words = str.toLowerCase().split(/\s+/).filter(w => w.length > 0);
|
||||
const seen = new Set<string>();
|
||||
const result: string[] = [];
|
||||
|
||||
for (const word of words) {
|
||||
if (!seen.has(word)) {
|
||||
seen.add(word);
|
||||
result.push(word);
|
||||
}
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
|
@ -0,0 +1,98 @@
|
|||
import { describe, test, expect } from 'vitest';
|
||||
import { fibonacci, isFibonacci, isPerfectSquare, fibonacciSequence } from '../fibonacci';
|
||||
|
||||
describe('fibonacci', () => {
|
||||
test('returns 0 for n=0', () => {
|
||||
expect(fibonacci(0)).toBe(0);
|
||||
});
|
||||
|
||||
test('returns 1 for n=1', () => {
|
||||
expect(fibonacci(1)).toBe(1);
|
||||
});
|
||||
|
||||
test('returns 1 for n=2', () => {
|
||||
expect(fibonacci(2)).toBe(1);
|
||||
});
|
||||
|
||||
test('returns 5 for n=5', () => {
|
||||
expect(fibonacci(5)).toBe(5);
|
||||
});
|
||||
|
||||
test('returns 55 for n=10', () => {
|
||||
expect(fibonacci(10)).toBe(55);
|
||||
});
|
||||
|
||||
test('returns 233 for n=13', () => {
|
||||
expect(fibonacci(13)).toBe(233);
|
||||
});
|
||||
});
|
||||
|
||||
describe('isFibonacci', () => {
|
||||
test('returns true for 0', () => {
|
||||
expect(isFibonacci(0)).toBe(true);
|
||||
});
|
||||
|
||||
test('returns true for 1', () => {
|
||||
expect(isFibonacci(1)).toBe(true);
|
||||
});
|
||||
|
||||
test('returns true for 8', () => {
|
||||
expect(isFibonacci(8)).toBe(true);
|
||||
});
|
||||
|
||||
test('returns true for 13', () => {
|
||||
expect(isFibonacci(13)).toBe(true);
|
||||
});
|
||||
|
||||
test('returns false for 4', () => {
|
||||
expect(isFibonacci(4)).toBe(false);
|
||||
});
|
||||
|
||||
test('returns false for 6', () => {
|
||||
expect(isFibonacci(6)).toBe(false);
|
||||
});
|
||||
});
|
||||
|
||||
describe('isPerfectSquare', () => {
|
||||
test('returns true for 0', () => {
|
||||
expect(isPerfectSquare(0)).toBe(true);
|
||||
});
|
||||
|
||||
test('returns true for 1', () => {
|
||||
expect(isPerfectSquare(1)).toBe(true);
|
||||
});
|
||||
|
||||
test('returns true for 4', () => {
|
||||
expect(isPerfectSquare(4)).toBe(true);
|
||||
});
|
||||
|
||||
test('returns true for 16', () => {
|
||||
expect(isPerfectSquare(16)).toBe(true);
|
||||
});
|
||||
|
||||
test('returns false for 2', () => {
|
||||
expect(isPerfectSquare(2)).toBe(false);
|
||||
});
|
||||
|
||||
test('returns false for 3', () => {
|
||||
expect(isPerfectSquare(3)).toBe(false);
|
||||
});
|
||||
});
|
||||
|
||||
describe('fibonacciSequence', () => {
|
||||
test('returns empty array for n=0', () => {
|
||||
expect(fibonacciSequence(0)).toEqual([]);
|
||||
});
|
||||
|
||||
test('returns [0] for n=1', () => {
|
||||
expect(fibonacciSequence(1)).toEqual([0]);
|
||||
});
|
||||
|
||||
test('returns first 5 Fibonacci numbers', () => {
|
||||
expect(fibonacciSequence(5)).toEqual([0, 1, 1, 2, 3]);
|
||||
});
|
||||
|
||||
test('returns first 10 Fibonacci numbers', () => {
|
||||
expect(fibonacciSequence(10)).toEqual([0, 1, 1, 2, 3, 5, 8, 13, 21, 34]);
|
||||
});
|
||||
});
|
||||
|
|
@ -0,0 +1,94 @@
|
|||
import { describe, test, expect } from 'vitest';
|
||||
import { reverseString, isPalindrome, countVowels, uniqueWords } from '../string_utils';
|
||||
|
||||
describe('reverseString', () => {
|
||||
test('reverses a simple string', () => {
|
||||
expect(reverseString('hello')).toBe('olleh');
|
||||
});
|
||||
|
||||
test('returns empty string for empty input', () => {
|
||||
expect(reverseString('')).toBe('');
|
||||
});
|
||||
|
||||
test('returns same character for single character', () => {
|
||||
expect(reverseString('a')).toBe('a');
|
||||
});
|
||||
|
||||
test('handles strings with spaces', () => {
|
||||
expect(reverseString('hello world')).toBe('dlrow olleh');
|
||||
});
|
||||
|
||||
test('handles palindrome', () => {
|
||||
expect(reverseString('racecar')).toBe('racecar');
|
||||
});
|
||||
});
|
||||
|
||||
describe('isPalindrome', () => {
|
||||
test('returns true for palindrome', () => {
|
||||
expect(isPalindrome('racecar')).toBe(true);
|
||||
});
|
||||
|
||||
test('returns true for palindrome with spaces', () => {
|
||||
expect(isPalindrome('A man a plan a canal Panama')).toBe(true);
|
||||
});
|
||||
|
||||
test('returns false for non-palindrome', () => {
|
||||
expect(isPalindrome('hello')).toBe(false);
|
||||
});
|
||||
|
||||
test('returns true for empty string', () => {
|
||||
expect(isPalindrome('')).toBe(true);
|
||||
});
|
||||
|
||||
test('returns true for single character', () => {
|
||||
expect(isPalindrome('a')).toBe(true);
|
||||
});
|
||||
|
||||
test('handles mixed case', () => {
|
||||
expect(isPalindrome('RaceCar')).toBe(true);
|
||||
});
|
||||
});
|
||||
|
||||
describe('countVowels', () => {
|
||||
test('counts vowels in simple string', () => {
|
||||
expect(countVowels('hello')).toBe(2);
|
||||
});
|
||||
|
||||
test('returns 0 for string with no vowels', () => {
|
||||
expect(countVowels('bcdfg')).toBe(0);
|
||||
});
|
||||
|
||||
test('returns 0 for empty string', () => {
|
||||
expect(countVowels('')).toBe(0);
|
||||
});
|
||||
|
||||
test('counts uppercase vowels', () => {
|
||||
expect(countVowels('HELLO')).toBe(2);
|
||||
});
|
||||
|
||||
test('counts all vowels', () => {
|
||||
expect(countVowels('aeiouAEIOU')).toBe(10);
|
||||
});
|
||||
});
|
||||
|
||||
describe('uniqueWords', () => {
|
||||
test('finds unique words in simple string', () => {
|
||||
expect(uniqueWords('hello world')).toEqual(['hello', 'world']);
|
||||
});
|
||||
|
||||
test('removes duplicates', () => {
|
||||
expect(uniqueWords('hello hello world')).toEqual(['hello', 'world']);
|
||||
});
|
||||
|
||||
test('returns empty array for empty string', () => {
|
||||
expect(uniqueWords('')).toEqual([]);
|
||||
});
|
||||
|
||||
test('handles multiple spaces', () => {
|
||||
expect(uniqueWords('hello world')).toEqual(['hello', 'world']);
|
||||
});
|
||||
|
||||
test('normalizes case', () => {
|
||||
expect(uniqueWords('Hello hello HELLO')).toEqual(['hello']);
|
||||
});
|
||||
});
|
||||
15
code_to_optimize/js/code_to_optimize_vitest/tsconfig.json
Normal file
15
code_to_optimize/js/code_to_optimize_vitest/tsconfig.json
Normal file
|
|
@ -0,0 +1,15 @@
|
|||
{
|
||||
"compilerOptions": {
|
||||
"target": "ES2020",
|
||||
"module": "ESNext",
|
||||
"moduleResolution": "node",
|
||||
"esModuleInterop": true,
|
||||
"strict": true,
|
||||
"skipLibCheck": true,
|
||||
"outDir": "./dist",
|
||||
"declaration": true,
|
||||
"types": ["vitest/globals", "node"]
|
||||
},
|
||||
"include": ["./*.ts", "./tests/**/*.ts"],
|
||||
"exclude": ["node_modules", "dist"]
|
||||
}
|
||||
13
code_to_optimize/js/code_to_optimize_vitest/vitest.config.ts
Normal file
13
code_to_optimize/js/code_to_optimize_vitest/vitest.config.ts
Normal file
|
|
@ -0,0 +1,13 @@
|
|||
import { defineConfig } from 'vitest/config';
|
||||
|
||||
export default defineConfig({
|
||||
test: {
|
||||
globals: true,
|
||||
environment: 'node',
|
||||
include: ['tests/**/*.test.ts'],
|
||||
reporters: ['default', 'junit'],
|
||||
outputFile: {
|
||||
junit: '.codeflash/vitest-results.xml',
|
||||
},
|
||||
},
|
||||
});
|
||||
|
|
@ -12,6 +12,7 @@ from codeflash.cli_cmds.extension import install_vscode_extension
|
|||
from codeflash.code_utils import env_utils
|
||||
from codeflash.code_utils.code_utils import exit_with_message
|
||||
from codeflash.code_utils.config_parser import parse_config_file
|
||||
from codeflash.languages.test_framework import set_current_test_framework
|
||||
from codeflash.lsp.helpers import is_LSP_enabled
|
||||
from codeflash.version import __version__ as version
|
||||
|
||||
|
|
@ -121,6 +122,15 @@ def parse_args() -> Namespace:
|
|||
"--effort", type=str, help="Effort level for optimization", choices=["low", "medium", "high"], default="medium"
|
||||
)
|
||||
|
||||
# Config management flags
|
||||
parser.add_argument(
|
||||
"--show-config", action="store_true", help="Show current or auto-detected configuration and exit."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--reset-config", action="store_true", help="Remove codeflash configuration from project config file."
|
||||
)
|
||||
parser.add_argument("-y", "--yes", action="store_true", help="Skip confirmation prompts (useful for CI/scripts).")
|
||||
|
||||
args, unknown_args = parser.parse_known_args()
|
||||
sys.argv[:] = [sys.argv[0], *unknown_args]
|
||||
return process_and_validate_cmd_args(args)
|
||||
|
|
@ -147,6 +157,16 @@ def process_and_validate_cmd_args(args: Namespace) -> Namespace:
|
|||
logger.info(f"Codeflash version {version}")
|
||||
sys.exit()
|
||||
|
||||
# Handle --show-config
|
||||
if getattr(args, "show_config", False):
|
||||
_handle_show_config()
|
||||
sys.exit()
|
||||
|
||||
# Handle --reset-config
|
||||
if getattr(args, "reset_config", False):
|
||||
_handle_reset_config(confirm=not getattr(args, "yes", False))
|
||||
sys.exit()
|
||||
|
||||
if args.command == "vscode-install":
|
||||
install_vscode_extension()
|
||||
sys.exit()
|
||||
|
|
@ -210,6 +230,11 @@ def process_pyproject_config(args: Namespace) -> Namespace:
|
|||
# For JS/TS projects, tests_root is optional (Jest auto-discovers tests)
|
||||
# Default to module_root if not specified
|
||||
is_js_ts_project = pyproject_config.get("language") in ("javascript", "typescript")
|
||||
|
||||
# Set the test framework singleton for JS/TS projects
|
||||
if is_js_ts_project and pyproject_config.get("test_framework"):
|
||||
set_current_test_framework(pyproject_config["test_framework"])
|
||||
|
||||
if args.tests_root is None:
|
||||
if is_js_ts_project:
|
||||
# Try common JS test directories at project root first
|
||||
|
|
@ -334,3 +359,92 @@ def handle_optimize_all_arg_parsing(args: Namespace) -> Namespace:
|
|||
else:
|
||||
args.all = Path(args.all).resolve()
|
||||
return args
|
||||
|
||||
|
||||
def _handle_show_config() -> None:
|
||||
"""Show current or auto-detected Codeflash configuration."""
|
||||
from rich.table import Table
|
||||
|
||||
from codeflash.cli_cmds.console import console
|
||||
from codeflash.setup.detector import detect_project, has_existing_config
|
||||
|
||||
project_root = Path.cwd()
|
||||
detected = detect_project(project_root)
|
||||
|
||||
# Check if config exists or is auto-detected
|
||||
config_exists, _ = has_existing_config(project_root)
|
||||
status = "Saved config" if config_exists else "Auto-detected (not saved)"
|
||||
|
||||
console.print()
|
||||
console.print(f"[bold]Codeflash Configuration[/bold] ({status})")
|
||||
console.print()
|
||||
|
||||
table = Table(show_header=True, header_style="bold cyan")
|
||||
table.add_column("Setting", style="dim")
|
||||
table.add_column("Value")
|
||||
|
||||
table.add_row("Language", detected.language)
|
||||
table.add_row("Project root", str(detected.project_root))
|
||||
table.add_row("Module root", str(detected.module_root))
|
||||
table.add_row("Tests root", str(detected.tests_root) if detected.tests_root else "(not detected)")
|
||||
table.add_row("Test runner", detected.test_runner or "(not detected)")
|
||||
table.add_row("Formatter", ", ".join(detected.formatter_cmds) if detected.formatter_cmds else "(not detected)")
|
||||
table.add_row(
|
||||
"Ignore paths", ", ".join(str(p) for p in detected.ignore_paths) if detected.ignore_paths else "(none)"
|
||||
)
|
||||
table.add_row("Confidence", f"{detected.confidence:.0%}")
|
||||
|
||||
console.print(table)
|
||||
console.print()
|
||||
|
||||
if not config_exists:
|
||||
console.print("[dim]Run [bold]codeflash --file <file>[/bold] to auto-save this config.[/dim]")
|
||||
|
||||
|
||||
def _handle_reset_config(confirm: bool = True) -> None:
|
||||
"""Remove Codeflash configuration from project config file.
|
||||
|
||||
Args:
|
||||
confirm: If True, prompt for confirmation before removing.
|
||||
|
||||
"""
|
||||
from codeflash.cli_cmds.console import console
|
||||
from codeflash.setup.config_writer import remove_config
|
||||
from codeflash.setup.detector import detect_project, has_existing_config
|
||||
|
||||
project_root = Path.cwd()
|
||||
|
||||
config_exists, _ = has_existing_config(project_root)
|
||||
if not config_exists:
|
||||
console.print("[yellow]No Codeflash configuration found to remove.[/yellow]")
|
||||
return
|
||||
|
||||
detected = detect_project(project_root)
|
||||
|
||||
if confirm:
|
||||
console.print("[bold]This will remove Codeflash configuration from your project.[/bold]")
|
||||
console.print()
|
||||
|
||||
config_file = "pyproject.toml" if detected.language == "python" else "package.json"
|
||||
console.print(f" Config file: {project_root / config_file}")
|
||||
console.print()
|
||||
|
||||
try:
|
||||
response = console.input("[bold]Are you sure you want to remove the config? [y/N][/bold] ")
|
||||
except (EOFError, KeyboardInterrupt):
|
||||
console.print("\n[yellow]Cancelled.[/yellow]")
|
||||
return
|
||||
|
||||
if response.lower() not in ("y", "yes"):
|
||||
console.print("[yellow]Cancelled.[/yellow]")
|
||||
return
|
||||
|
||||
success, message = remove_config(project_root, detected.language)
|
||||
|
||||
# Escape brackets in message to prevent Rich markup interpretation
|
||||
escaped_message = message.replace("[", "\\[")
|
||||
|
||||
if success:
|
||||
console.print(f"[green]✓[/green] {escaped_message}")
|
||||
else:
|
||||
console.print(f"[red]✗[/red] {escaped_message}")
|
||||
|
|
|
|||
|
|
@ -614,7 +614,33 @@ def check_for_toml_or_setup_file() -> str | None:
|
|||
curdir = Path.cwd()
|
||||
pyproject_toml_path = curdir / "pyproject.toml"
|
||||
setup_py_path = curdir / "setup.py"
|
||||
package_json_path = curdir / "package.json"
|
||||
project_name = None
|
||||
|
||||
# Check if this might be a JavaScript/TypeScript project that wasn't detected
|
||||
if package_json_path.exists() and not pyproject_toml_path.exists() and not setup_py_path.exists():
|
||||
js_redirect_panel = Panel(
|
||||
Text(
|
||||
f"📦 I found a package.json in {curdir}.\n\n"
|
||||
"This looks like a JavaScript/TypeScript project!\n"
|
||||
"Redirecting to JavaScript setup...",
|
||||
style="cyan",
|
||||
),
|
||||
title="🟨 JavaScript Project Detected",
|
||||
border_style="bright_yellow",
|
||||
)
|
||||
console.print(js_redirect_panel)
|
||||
console.print()
|
||||
ph("cli-js-project-redirect")
|
||||
|
||||
# Redirect to JS init
|
||||
from codeflash.cli_cmds.init_javascript import ProjectLanguage, detect_project_language, init_js_project
|
||||
|
||||
project_language = detect_project_language()
|
||||
if project_language in (ProjectLanguage.JAVASCRIPT, ProjectLanguage.TYPESCRIPT):
|
||||
init_js_project(project_language)
|
||||
sys.exit(0) # init_js_project handles its own exit, but ensure we don't continue
|
||||
|
||||
if pyproject_toml_path.exists():
|
||||
try:
|
||||
pyproject_toml_content = pyproject_toml_path.read_text(encoding="utf8")
|
||||
|
|
@ -624,28 +650,44 @@ def check_for_toml_or_setup_file() -> str | None:
|
|||
except Exception:
|
||||
click.echo("✅ I found a pyproject.toml for your project.")
|
||||
ph("cli-pyproject-toml-found")
|
||||
elif setup_py_path.exists():
|
||||
setup_py_content = setup_py_path.read_text(encoding="utf8")
|
||||
project_name_match = re.search(r"setup\s*\([^)]*?name\s*=\s*['\"](.*?)['\"]", setup_py_content, re.DOTALL)
|
||||
if project_name_match:
|
||||
project_name = project_name_match.group(1)
|
||||
click.echo(f"✅ Found setup.py for your project {project_name}")
|
||||
ph("cli-setup-py-found-name")
|
||||
else:
|
||||
click.echo("✅ Found setup.py.")
|
||||
ph("cli-setup-py-found")
|
||||
else:
|
||||
if setup_py_path.exists():
|
||||
setup_py_content = setup_py_path.read_text(encoding="utf8")
|
||||
project_name_match = re.search(r"setup\s*\([^)]*?name\s*=\s*['\"](.*?)['\"]", setup_py_content, re.DOTALL)
|
||||
if project_name_match:
|
||||
project_name = project_name_match.group(1)
|
||||
click.echo(f"✅ Found setup.py for your project {project_name}")
|
||||
ph("cli-setup-py-found-name")
|
||||
else:
|
||||
click.echo("✅ Found setup.py.")
|
||||
ph("cli-setup-py-found")
|
||||
toml_info_panel = Panel(
|
||||
Text(
|
||||
f"💡 No pyproject.toml found in {curdir}.\n\n"
|
||||
"This file is essential for Codeflash to store its configuration.\n"
|
||||
"Please ensure you are running `codeflash init` from your project's root directory.",
|
||||
style="yellow",
|
||||
),
|
||||
title="📋 pyproject.toml Required",
|
||||
border_style="bright_yellow",
|
||||
)
|
||||
console.print(toml_info_panel)
|
||||
# No Python config files found - show appropriate message
|
||||
# Check again if this might be a JS project
|
||||
if package_json_path.exists():
|
||||
js_hint_panel = Panel(
|
||||
Text(
|
||||
f"📦 I found a package.json but no pyproject.toml in {curdir}.\n\n"
|
||||
"If this is a JavaScript/TypeScript project, please run:\n"
|
||||
" codeflash init\n\n"
|
||||
"from the project root directory.",
|
||||
style="yellow",
|
||||
),
|
||||
title="🤔 Mixed Project?",
|
||||
border_style="bright_yellow",
|
||||
)
|
||||
console.print(js_hint_panel)
|
||||
else:
|
||||
toml_info_panel = Panel(
|
||||
Text(
|
||||
f"💡 No pyproject.toml found in {curdir}.\n\n"
|
||||
"This file is essential for Codeflash to store its configuration.\n"
|
||||
"Please ensure you are running `codeflash init` from your project's root directory.",
|
||||
style="yellow",
|
||||
),
|
||||
title="📋 pyproject.toml Required",
|
||||
border_style="bright_yellow",
|
||||
)
|
||||
console.print(toml_info_panel)
|
||||
console.print()
|
||||
ph("cli-no-pyproject-toml-or-setup-py")
|
||||
|
||||
|
|
|
|||
|
|
@ -129,6 +129,9 @@ def code_print(
|
|||
|
||||
spinners = cycle(SPINNER_TYPES)
|
||||
|
||||
# Track whether a progress bar is already active to prevent nested Live displays
|
||||
_progress_bar_active = False
|
||||
|
||||
|
||||
@contextmanager
|
||||
def progress_bar(
|
||||
|
|
@ -138,28 +141,38 @@ def progress_bar(
|
|||
|
||||
If revert_to_print is True, falls back to printing a single logger.info message
|
||||
instead of showing a progress bar.
|
||||
|
||||
If a progress bar is already active, yields a dummy task ID to avoid Rich's
|
||||
LiveError from nested Live displays.
|
||||
"""
|
||||
global _progress_bar_active
|
||||
|
||||
if is_LSP_enabled():
|
||||
lsp_log(LspTextMessage(text=message, takes_time=True))
|
||||
yield
|
||||
return
|
||||
|
||||
if revert_to_print:
|
||||
logger.info(message)
|
||||
if revert_to_print or _progress_bar_active:
|
||||
if revert_to_print:
|
||||
logger.info(message)
|
||||
|
||||
# Create a fake task ID since we still need to yield something
|
||||
yield DummyTask().id
|
||||
else:
|
||||
progress = Progress(
|
||||
SpinnerColumn(next(spinners)),
|
||||
*Progress.get_default_columns(),
|
||||
TimeElapsedColumn(),
|
||||
console=console,
|
||||
transient=transient,
|
||||
)
|
||||
task = progress.add_task(message, total=None)
|
||||
with progress:
|
||||
yield task
|
||||
_progress_bar_active = True
|
||||
try:
|
||||
progress = Progress(
|
||||
SpinnerColumn(next(spinners)),
|
||||
*Progress.get_default_columns(),
|
||||
TimeElapsedColumn(),
|
||||
console=console,
|
||||
transient=transient,
|
||||
)
|
||||
task = progress.add_task(message, total=None)
|
||||
with progress:
|
||||
yield task
|
||||
finally:
|
||||
_progress_bar_active = False
|
||||
|
||||
|
||||
@contextmanager
|
||||
|
|
|
|||
|
|
@ -16,6 +16,7 @@ import inquirer
|
|||
from git import InvalidGitRepositoryError, Repo
|
||||
from rich.console import Group
|
||||
from rich.panel import Panel
|
||||
from rich.prompt import Confirm
|
||||
from rich.table import Table
|
||||
from rich.text import Text
|
||||
|
||||
|
|
@ -98,34 +99,97 @@ def detect_project_language(project_root: Path | None = None) -> ProjectLanguage
|
|||
if has_pom_xml or has_build_gradle or has_java_src:
|
||||
return ProjectLanguage.JAVA
|
||||
|
||||
# TypeScript project
|
||||
# TypeScript project (tsconfig.json is definitive)
|
||||
if has_tsconfig:
|
||||
return ProjectLanguage.TYPESCRIPT
|
||||
|
||||
# Pure JS project (has package.json but no Python files)
|
||||
if has_package_json and not has_pyproject and not has_setup_py:
|
||||
return ProjectLanguage.JAVASCRIPT
|
||||
# JavaScript project - package.json without Python-specific files takes priority
|
||||
# Note: If both package.json and pyproject.toml exist, check for typical JS project indicators
|
||||
if has_package_json:
|
||||
# If no Python config files, it's definitely JavaScript
|
||||
if not has_pyproject and not has_setup_py:
|
||||
return ProjectLanguage.JAVASCRIPT
|
||||
|
||||
# If package.json exists with Python files, check for JS-specific indicators
|
||||
# Common React/Node patterns indicate a JS project
|
||||
js_indicators = [
|
||||
(root / "node_modules").exists(),
|
||||
(root / ".npmrc").exists(),
|
||||
(root / "yarn.lock").exists(),
|
||||
(root / "package-lock.json").exists(),
|
||||
(root / "pnpm-lock.yaml").exists(),
|
||||
(root / "bun.lockb").exists(),
|
||||
(root / "bun.lock").exists(),
|
||||
]
|
||||
if any(js_indicators):
|
||||
return ProjectLanguage.JAVASCRIPT
|
||||
|
||||
# Python project (default)
|
||||
return ProjectLanguage.PYTHON
|
||||
|
||||
|
||||
def determine_js_package_manager(project_root: Path) -> JsPackageManager:
|
||||
"""Determine which JavaScript package manager is being used based on lock files."""
|
||||
if (project_root / "bun.lockb").exists() or (project_root / "bun.lock").exists():
|
||||
return JsPackageManager.BUN
|
||||
if (project_root / "pnpm-lock.yaml").exists():
|
||||
return JsPackageManager.PNPM
|
||||
if (project_root / "yarn.lock").exists():
|
||||
return JsPackageManager.YARN
|
||||
if (project_root / "package-lock.json").exists():
|
||||
return JsPackageManager.NPM
|
||||
# Default to npm if package.json exists but no lock file
|
||||
"""Determine which JavaScript package manager is being used based on lock files.
|
||||
|
||||
Searches the project_root directory and parent directories (for monorepo support)
|
||||
to find lock files that indicate which package manager is being used.
|
||||
"""
|
||||
# Search from project_root up to filesystem root for lock files
|
||||
# This supports monorepo setups where lock file is at workspace root
|
||||
current_dir = project_root.resolve()
|
||||
while current_dir != current_dir.parent:
|
||||
if (current_dir / "bun.lockb").exists() or (current_dir / "bun.lock").exists():
|
||||
return JsPackageManager.BUN
|
||||
if (current_dir / "pnpm-lock.yaml").exists():
|
||||
return JsPackageManager.PNPM
|
||||
if (current_dir / "yarn.lock").exists():
|
||||
return JsPackageManager.YARN
|
||||
if (current_dir / "package-lock.json").exists():
|
||||
return JsPackageManager.NPM
|
||||
current_dir = current_dir.parent
|
||||
|
||||
# Default to npm if package.json exists but no lock file found anywhere
|
||||
if (project_root / "package.json").exists():
|
||||
return JsPackageManager.NPM
|
||||
return JsPackageManager.UNKNOWN
|
||||
|
||||
|
||||
def get_package_install_command(project_root: Path, package: str, dev: bool = True) -> list[str]:
|
||||
"""Get the correct install command for the project's package manager.
|
||||
|
||||
Args:
|
||||
project_root: The project root directory.
|
||||
package: The package name to install.
|
||||
dev: Whether to install as a dev dependency (default: True).
|
||||
|
||||
Returns:
|
||||
List of command arguments for subprocess execution.
|
||||
|
||||
"""
|
||||
pkg_manager = determine_js_package_manager(project_root)
|
||||
|
||||
if pkg_manager == JsPackageManager.PNPM:
|
||||
cmd = ["pnpm", "add", package]
|
||||
if dev:
|
||||
cmd.append("--save-dev")
|
||||
return cmd
|
||||
if pkg_manager == JsPackageManager.YARN:
|
||||
cmd = ["yarn", "add", package]
|
||||
if dev:
|
||||
cmd.append("--dev")
|
||||
return cmd
|
||||
if pkg_manager == JsPackageManager.BUN:
|
||||
cmd = ["bun", "add", package]
|
||||
if dev:
|
||||
cmd.append("--dev")
|
||||
return cmd
|
||||
# Default to npm
|
||||
cmd = ["npm", "install", package]
|
||||
if dev:
|
||||
cmd.append("--save-dev")
|
||||
return cmd
|
||||
|
||||
|
||||
def init_js_project(language: ProjectLanguage) -> None:
|
||||
"""Initialize Codeflash for a JavaScript/TypeScript project."""
|
||||
from codeflash.cli_cmds.cmd_init import install_github_actions, install_github_app, prompt_api_key
|
||||
|
|
@ -199,9 +263,7 @@ def init_js_project(language: ProjectLanguage) -> None:
|
|||
|
||||
def should_modify_package_json_config() -> tuple[bool, dict[str, Any] | None]:
|
||||
"""Check if package.json has valid codeflash config for JS/TS projects."""
|
||||
from rich.prompt import Confirm
|
||||
|
||||
package_json_path = Path.cwd() / "package.json"
|
||||
package_json_path = Path("package.json")
|
||||
|
||||
if not package_json_path.exists():
|
||||
click.echo("❌ No package.json found. Please run 'npm init' first.")
|
||||
|
|
@ -221,6 +283,10 @@ def should_modify_package_json_config() -> tuple[bool, dict[str, Any] | None]:
|
|||
if not Path(module_root).is_dir():
|
||||
return True, None
|
||||
|
||||
tests_root = config.get("testsRoot", None)
|
||||
if tests_root and not Path(tests_root).is_dir():
|
||||
return True, None
|
||||
|
||||
# Config is valid - ask if user wants to reconfigure
|
||||
return Confirm.ask(
|
||||
"✅ A valid Codeflash config already exists in package.json. Do you want to re-configure it?",
|
||||
|
|
|
|||
|
|
@ -1563,23 +1563,228 @@ def is_numerical_code(code_string: str, function_name: str | None = None) -> boo
|
|||
def get_opt_review_metrics(
|
||||
source_code: str, file_path: Path, qualified_name: str, project_root: Path, tests_root: Path, language: Language
|
||||
) -> str:
|
||||
if language != Language.PYTHON:
|
||||
# TODO: {Claude} handle function refrences for other languages
|
||||
return ""
|
||||
"""Get function reference metrics for optimization review.
|
||||
|
||||
Uses the LanguageSupport abstraction to find references, supporting both Python and JavaScript/TypeScript.
|
||||
|
||||
Args:
|
||||
source_code: Source code of the file containing the function.
|
||||
file_path: Path to the file.
|
||||
qualified_name: Qualified name of the function (e.g., "module.ClassName.method").
|
||||
project_root: Root of the project.
|
||||
tests_root: Root of the tests directory.
|
||||
language: The programming language.
|
||||
|
||||
Returns:
|
||||
Markdown-formatted string with code blocks showing calling functions.
|
||||
|
||||
"""
|
||||
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
|
||||
from codeflash.languages.registry import get_language_support
|
||||
from codeflash.models.models import FunctionParent
|
||||
|
||||
start_time = time.perf_counter()
|
||||
|
||||
try:
|
||||
# Get the language support
|
||||
lang_support = get_language_support(language)
|
||||
if lang_support is None:
|
||||
return ""
|
||||
|
||||
# Parse qualified name to get function name and class name
|
||||
qualified_name_split = qualified_name.rsplit(".", maxsplit=1)
|
||||
if len(qualified_name_split) == 1:
|
||||
target_function, target_class = qualified_name_split[0], None
|
||||
function_name, class_name = qualified_name_split[0], None
|
||||
else:
|
||||
target_function, target_class = qualified_name_split[1], qualified_name_split[0]
|
||||
matches = get_fn_references_jedi(
|
||||
source_code, file_path, project_root, target_function, target_class
|
||||
) # jedi is not perfect, it doesn't capture aliased references
|
||||
calling_fns_details = find_occurances(qualified_name, str(file_path), matches, project_root, tests_root)
|
||||
function_name, class_name = qualified_name_split[1], qualified_name_split[0]
|
||||
|
||||
# Create a FunctionToOptimize for the function
|
||||
# We don't have full line info here, so we'll use defaults
|
||||
parents: list[FunctionParent] = []
|
||||
if class_name:
|
||||
parents = [FunctionParent(name=class_name, type="ClassDef")]
|
||||
|
||||
func_info = FunctionToOptimize(
|
||||
function_name=function_name,
|
||||
file_path=file_path,
|
||||
parents=parents,
|
||||
starting_line=1,
|
||||
ending_line=1,
|
||||
language=str(language),
|
||||
)
|
||||
|
||||
# Find references using language support
|
||||
references = lang_support.find_references(func_info, project_root, tests_root, max_files=500)
|
||||
|
||||
if not references:
|
||||
return ""
|
||||
|
||||
# Format references as markdown code blocks
|
||||
calling_fns_details = _format_references_as_markdown(references, file_path, project_root, language)
|
||||
|
||||
except Exception as e:
|
||||
logger.debug(f"Error getting function references: {e}")
|
||||
calling_fns_details = ""
|
||||
logger.debug(f"Investigate {e}")
|
||||
|
||||
end_time = time.perf_counter()
|
||||
logger.debug(f"Got function references in {end_time - start_time:.2f} seconds")
|
||||
return calling_fns_details
|
||||
|
||||
|
||||
def _format_references_as_markdown(references: list, file_path: Path, project_root: Path, language: Language) -> str:
|
||||
"""Format references as markdown code blocks with calling function code.
|
||||
|
||||
Args:
|
||||
references: List of ReferenceInfo objects.
|
||||
file_path: Path to the source file (to exclude).
|
||||
project_root: Root of the project.
|
||||
language: The programming language.
|
||||
|
||||
Returns:
|
||||
Markdown-formatted string.
|
||||
|
||||
"""
|
||||
# Group references by file
|
||||
refs_by_file: dict[Path, list] = {}
|
||||
for ref in references:
|
||||
# Exclude the source file's definition/import references
|
||||
if ref.file_path == file_path and ref.reference_type in ("import", "reexport"):
|
||||
continue
|
||||
|
||||
if ref.file_path not in refs_by_file:
|
||||
refs_by_file[ref.file_path] = []
|
||||
refs_by_file[ref.file_path].append(ref)
|
||||
|
||||
fn_call_context = ""
|
||||
context_len = 0
|
||||
|
||||
for ref_file, file_refs in refs_by_file.items():
|
||||
if context_len > MAX_CONTEXT_LEN_REVIEW:
|
||||
break
|
||||
|
||||
try:
|
||||
path_relative = ref_file.relative_to(project_root)
|
||||
except ValueError:
|
||||
continue
|
||||
|
||||
# Get syntax highlighting language
|
||||
ext = ref_file.suffix.lstrip(".")
|
||||
if language == Language.PYTHON:
|
||||
lang_hint = "python"
|
||||
elif ext in ("ts", "tsx"):
|
||||
lang_hint = "typescript"
|
||||
else:
|
||||
lang_hint = "javascript"
|
||||
|
||||
# Read the file to extract calling function context
|
||||
try:
|
||||
file_content = ref_file.read_text(encoding="utf-8")
|
||||
lines = file_content.splitlines()
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
# Get unique caller functions from this file
|
||||
callers_seen: set[str] = set()
|
||||
caller_contexts: list[str] = []
|
||||
|
||||
for ref in file_refs:
|
||||
caller = ref.caller_function or "<module>"
|
||||
if caller in callers_seen:
|
||||
continue
|
||||
callers_seen.add(caller)
|
||||
|
||||
# Extract context around the reference
|
||||
if ref.caller_function:
|
||||
# Try to extract the full calling function
|
||||
func_code = _extract_calling_function(file_content, ref.caller_function, ref.line, language)
|
||||
if func_code:
|
||||
caller_contexts.append(func_code)
|
||||
context_len += len(func_code)
|
||||
else:
|
||||
# Module-level call - show a few lines of context
|
||||
start_line = max(0, ref.line - 3)
|
||||
end_line = min(len(lines), ref.line + 2)
|
||||
context_code = "\n".join(lines[start_line:end_line])
|
||||
caller_contexts.append(context_code)
|
||||
context_len += len(context_code)
|
||||
|
||||
if caller_contexts:
|
||||
fn_call_context += f"```{lang_hint}:{path_relative.as_posix()}\n"
|
||||
fn_call_context += "\n".join(caller_contexts)
|
||||
fn_call_context += "\n```\n"
|
||||
|
||||
return fn_call_context
|
||||
|
||||
|
||||
def _extract_calling_function(source_code: str, function_name: str, ref_line: int, language: Language) -> str | None:
|
||||
"""Extract the source code of a calling function.
|
||||
|
||||
Args:
|
||||
source_code: Full source code of the file.
|
||||
function_name: Name of the function to extract.
|
||||
ref_line: Line number where the reference is.
|
||||
language: The programming language.
|
||||
|
||||
Returns:
|
||||
Source code of the function, or None if not found.
|
||||
|
||||
"""
|
||||
if language == Language.PYTHON:
|
||||
return _extract_calling_function_python(source_code, function_name, ref_line)
|
||||
return _extract_calling_function_js(source_code, function_name, ref_line)
|
||||
|
||||
|
||||
def _extract_calling_function_python(source_code: str, function_name: str, ref_line: int) -> str | None:
|
||||
"""Extract the source code of a calling function in Python."""
|
||||
try:
|
||||
import ast
|
||||
|
||||
tree = ast.parse(source_code)
|
||||
lines = source_code.splitlines()
|
||||
|
||||
for node in ast.walk(tree):
|
||||
if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
|
||||
if node.name == function_name:
|
||||
# Check if the reference line is within this function
|
||||
start_line = node.lineno
|
||||
end_line = node.end_lineno or start_line
|
||||
if start_line <= ref_line <= end_line:
|
||||
return "\n".join(lines[start_line - 1 : end_line])
|
||||
return None
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
def _extract_calling_function_js(source_code: str, function_name: str, ref_line: int) -> str | None:
|
||||
"""Extract the source code of a calling function in JavaScript/TypeScript.
|
||||
|
||||
Args:
|
||||
source_code: Full source code of the file.
|
||||
function_name: Name of the function to extract.
|
||||
ref_line: Line number where the reference is (helps identify the right function).
|
||||
|
||||
Returns:
|
||||
Source code of the function, or None if not found.
|
||||
|
||||
"""
|
||||
try:
|
||||
from codeflash.languages.treesitter_utils import TreeSitterAnalyzer, TreeSitterLanguage
|
||||
|
||||
# Try TypeScript first, fall back to JavaScript
|
||||
for lang in [TreeSitterLanguage.TYPESCRIPT, TreeSitterLanguage.TSX, TreeSitterLanguage.JAVASCRIPT]:
|
||||
try:
|
||||
analyzer = TreeSitterAnalyzer(lang)
|
||||
functions = analyzer.find_functions(source_code, include_methods=True)
|
||||
|
||||
for func in functions:
|
||||
if func.name == function_name:
|
||||
# Check if the reference line is within this function
|
||||
if func.start_line <= ref_line <= func.end_line:
|
||||
return func.source_text
|
||||
break
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
return None
|
||||
except Exception:
|
||||
return None
|
||||
|
|
|
|||
|
|
@ -497,7 +497,7 @@ def replace_function_definitions_for_language(
|
|||
|
||||
"""
|
||||
from codeflash.languages import get_language_support
|
||||
from codeflash.languages.base import FunctionInfo, Language, ParentInfo
|
||||
from codeflash.languages.base import Language
|
||||
|
||||
original_source_code: str = module_abspath.read_text(encoding="utf8")
|
||||
code_to_apply = get_optimized_code_for_module(module_abspath.relative_to(project_root_path), optimized_code)
|
||||
|
|
@ -525,31 +525,21 @@ def replace_function_definitions_for_language(
|
|||
and function_to_optimize.ending_line
|
||||
and function_to_optimize.file_path == module_abspath
|
||||
):
|
||||
parents = tuple(ParentInfo(name=p.name, type=p.type) for p in function_to_optimize.parents)
|
||||
func_info = FunctionInfo(
|
||||
name=function_to_optimize.function_name,
|
||||
file_path=module_abspath,
|
||||
start_line=function_to_optimize.starting_line,
|
||||
end_line=function_to_optimize.ending_line,
|
||||
parents=parents,
|
||||
is_async=function_to_optimize.is_async,
|
||||
language=language,
|
||||
)
|
||||
# For Java, we need to pass the full optimized code so replace_function can
|
||||
# extract and add any new class members (static fields, helper methods).
|
||||
# For other languages, we extract just the target function.
|
||||
if language == Language.JAVA:
|
||||
new_code = lang_support.replace_function(original_source_code, func_info, code_to_apply)
|
||||
new_code = lang_support.replace_function(original_source_code, function_to_optimize, code_to_apply)
|
||||
else:
|
||||
# Extract just the target function from the optimized code
|
||||
optimized_func = _extract_function_from_code(
|
||||
lang_support, code_to_apply, function_to_optimize.function_name, module_abspath
|
||||
)
|
||||
if optimized_func:
|
||||
new_code = lang_support.replace_function(original_source_code, func_info, optimized_func)
|
||||
new_code = lang_support.replace_function(original_source_code, function_to_optimize, optimized_func)
|
||||
else:
|
||||
# Fallback: use the entire optimized code (for simple single-function files)
|
||||
new_code = lang_support.replace_function(original_source_code, func_info, code_to_apply)
|
||||
new_code = lang_support.replace_function(original_source_code, function_to_optimize, code_to_apply)
|
||||
else:
|
||||
# For helper files or when we don't have precise line info:
|
||||
# Find each function by name in both original and optimized code
|
||||
|
|
@ -567,7 +557,7 @@ def replace_function_definitions_for_language(
|
|||
# Find the function in current code
|
||||
func = None
|
||||
for f in current_functions:
|
||||
if func_name in (f.qualified_name, f.name):
|
||||
if func_name in (f.qualified_name, f.function_name):
|
||||
func = f
|
||||
break
|
||||
|
||||
|
|
@ -581,7 +571,9 @@ def replace_function_definitions_for_language(
|
|||
modified = True
|
||||
else:
|
||||
# Extract just this function from the optimized code
|
||||
optimized_func = _extract_function_from_code(lang_support, code_to_apply, func.name, module_abspath)
|
||||
optimized_func = _extract_function_from_code(
|
||||
lang_support, code_to_apply, func.function_name, module_abspath
|
||||
)
|
||||
if optimized_func:
|
||||
new_code = lang_support.replace_function(new_code, func, optimized_func)
|
||||
modified = True
|
||||
|
|
@ -620,13 +612,13 @@ def _extract_function_from_code(
|
|||
# file_path is needed for JS/TS to determine correct analyzer (TypeScript vs JavaScript)
|
||||
functions = lang_support.discover_functions_from_source(source_code, file_path)
|
||||
for func in functions:
|
||||
if func.name == function_name:
|
||||
if func.function_name == function_name:
|
||||
# Extract the function's source using line numbers
|
||||
# Use doc_start_line if available to include JSDoc/docstring
|
||||
lines = source_code.splitlines(keepends=True)
|
||||
effective_start = func.doc_start_line or func.start_line
|
||||
if effective_start and func.end_line and effective_start <= len(lines):
|
||||
func_lines = lines[effective_start - 1 : func.end_line]
|
||||
effective_start = func.doc_start_line or func.starting_line
|
||||
if effective_start and func.ending_line and effective_start <= len(lines):
|
||||
func_lines = lines[effective_start - 1 : func.ending_line]
|
||||
return "".join(func_lines)
|
||||
except Exception as e:
|
||||
logger.debug(f"Error extracting function {function_name}: {e}")
|
||||
|
|
@ -753,6 +745,10 @@ def _add_global_declarations_for_language(
|
|||
replacement.py handles them. Adding them here would shift line numbers and
|
||||
break method matching for overloaded methods.
|
||||
|
||||
New declarations are inserted after any existing declarations they depend on.
|
||||
For example, if optimized code has `const _has = FOO.bar.bind(FOO)`, and `FOO`
|
||||
is already declared in the original source, `_has` will be inserted after `FOO`.
|
||||
|
||||
Args:
|
||||
optimized_code: The optimized code that may contain new declarations.
|
||||
original_source: The original source code.
|
||||
|
|
@ -761,7 +757,7 @@ def _add_global_declarations_for_language(
|
|||
target_function_names: List of function names being optimized (to exclude from Java helpers).
|
||||
|
||||
Returns:
|
||||
Original source with new declarations added.
|
||||
Original source with new declarations added in dependency order.
|
||||
|
||||
"""
|
||||
from codeflash.languages.base import Language
|
||||
|
|
@ -771,7 +767,6 @@ def _add_global_declarations_for_language(
|
|||
if language == Language.JAVA:
|
||||
return original_source
|
||||
|
||||
# Only process JavaScript/TypeScript for module-level declarations
|
||||
if language not in (Language.JAVASCRIPT, Language.TYPESCRIPT):
|
||||
return original_source
|
||||
|
||||
|
|
@ -780,84 +775,164 @@ def _add_global_declarations_for_language(
|
|||
|
||||
analyzer = get_analyzer_for_file(module_abspath)
|
||||
|
||||
# Find declarations in both original and optimized code
|
||||
original_declarations = analyzer.find_module_level_declarations(original_source)
|
||||
optimized_declarations = analyzer.find_module_level_declarations(optimized_code)
|
||||
|
||||
if not optimized_declarations:
|
||||
return original_source
|
||||
|
||||
# Get names of existing declarations
|
||||
existing_names = {decl.name for decl in original_declarations}
|
||||
|
||||
# Also exclude names that are already imported (to avoid duplicating imported types)
|
||||
original_imports = analyzer.find_imports(original_source)
|
||||
for imp in original_imports:
|
||||
# Add default import name
|
||||
if imp.default_import:
|
||||
existing_names.add(imp.default_import)
|
||||
# Add named imports (use alias if present, otherwise use original name)
|
||||
for name, alias in imp.named_imports:
|
||||
existing_names.add(alias if alias else name)
|
||||
# Add namespace import
|
||||
if imp.namespace_import:
|
||||
existing_names.add(imp.namespace_import)
|
||||
|
||||
# Find new declarations (names that don't exist in original)
|
||||
new_declarations = []
|
||||
seen_sources = set() # Track to avoid duplicates from destructuring
|
||||
for decl in optimized_declarations:
|
||||
if decl.name not in existing_names and decl.source_code not in seen_sources:
|
||||
new_declarations.append(decl)
|
||||
seen_sources.add(decl.source_code)
|
||||
existing_names = _get_existing_names(original_declarations, analyzer, original_source)
|
||||
new_declarations = _filter_new_declarations(optimized_declarations, existing_names)
|
||||
|
||||
if not new_declarations:
|
||||
return original_source
|
||||
|
||||
# Sort by line number to maintain order
|
||||
new_declarations.sort(key=lambda d: d.start_line)
|
||||
# Build a map of existing declaration names to their end lines (1-indexed)
|
||||
existing_decl_end_lines = {decl.name: decl.end_line for decl in original_declarations}
|
||||
|
||||
# Find insertion point (after imports)
|
||||
lines = original_source.splitlines(keepends=True)
|
||||
insertion_line = _find_insertion_line_after_imports_js(lines, analyzer, original_source)
|
||||
# Insert each new declaration after its dependencies
|
||||
result = original_source
|
||||
for decl in new_declarations:
|
||||
result = _insert_declaration_after_dependencies(
|
||||
result, decl, existing_decl_end_lines, analyzer, module_abspath
|
||||
)
|
||||
# Update the map with the newly inserted declaration for subsequent insertions
|
||||
# Re-parse to get accurate line numbers after insertion
|
||||
updated_declarations = analyzer.find_module_level_declarations(result)
|
||||
existing_decl_end_lines = {d.name: d.end_line for d in updated_declarations}
|
||||
|
||||
# Build new declarations string
|
||||
new_decl_code = "\n".join(decl.source_code for decl in new_declarations)
|
||||
new_decl_code = new_decl_code + "\n\n"
|
||||
|
||||
# Insert declarations
|
||||
before = lines[:insertion_line]
|
||||
after = lines[insertion_line:]
|
||||
result_lines = [*before, new_decl_code, *after]
|
||||
|
||||
return "".join(result_lines)
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.debug(f"Error adding global declarations: {e}")
|
||||
return original_source
|
||||
|
||||
|
||||
def _find_insertion_line_after_imports_js(lines: list[str], analyzer: TreeSitterAnalyzer, source: str) -> int:
|
||||
"""Find the line index where new declarations should be inserted (after imports).
|
||||
def _get_existing_names(original_declarations: list, analyzer: TreeSitterAnalyzer, original_source: str) -> set[str]:
|
||||
"""Get all names that already exist in the original source (declarations + imports)."""
|
||||
existing_names = {decl.name for decl in original_declarations}
|
||||
|
||||
original_imports = analyzer.find_imports(original_source)
|
||||
for imp in original_imports:
|
||||
if imp.default_import:
|
||||
existing_names.add(imp.default_import)
|
||||
for name, alias in imp.named_imports:
|
||||
existing_names.add(alias if alias else name)
|
||||
if imp.namespace_import:
|
||||
existing_names.add(imp.namespace_import)
|
||||
|
||||
return existing_names
|
||||
|
||||
|
||||
def _filter_new_declarations(optimized_declarations: list, existing_names: set[str]) -> list:
|
||||
"""Filter declarations to only those that don't exist in the original source."""
|
||||
new_declarations = []
|
||||
seen_sources: set[str] = set()
|
||||
|
||||
# Sort by line number to maintain order from optimized code
|
||||
sorted_declarations = sorted(optimized_declarations, key=lambda d: d.start_line)
|
||||
|
||||
for decl in sorted_declarations:
|
||||
if decl.name not in existing_names and decl.source_code not in seen_sources:
|
||||
new_declarations.append(decl)
|
||||
seen_sources.add(decl.source_code)
|
||||
|
||||
return new_declarations
|
||||
|
||||
|
||||
def _insert_declaration_after_dependencies(
|
||||
source: str,
|
||||
declaration,
|
||||
existing_decl_end_lines: dict[str, int],
|
||||
analyzer: TreeSitterAnalyzer,
|
||||
module_abspath: Path,
|
||||
) -> str:
|
||||
"""Insert a declaration after the last existing declaration it depends on.
|
||||
|
||||
Args:
|
||||
source: Current source code.
|
||||
declaration: The declaration to insert.
|
||||
existing_decl_end_lines: Map of existing declaration names to their end lines.
|
||||
analyzer: TreeSitter analyzer.
|
||||
module_abspath: Path to the module file.
|
||||
|
||||
Returns:
|
||||
Source code with the declaration inserted at the correct position.
|
||||
|
||||
"""
|
||||
# Find identifiers referenced in this declaration
|
||||
referenced_names = analyzer.find_referenced_identifiers(declaration.source_code)
|
||||
|
||||
# Find the latest end line among all referenced declarations
|
||||
insertion_line = _find_insertion_line_for_declaration(source, referenced_names, existing_decl_end_lines, analyzer)
|
||||
|
||||
lines = source.splitlines(keepends=True)
|
||||
|
||||
# Ensure proper spacing
|
||||
decl_code = declaration.source_code
|
||||
if not decl_code.endswith("\n"):
|
||||
decl_code += "\n"
|
||||
|
||||
# Add blank line before if inserting after content
|
||||
if insertion_line > 0 and lines[insertion_line - 1].strip():
|
||||
decl_code = "\n" + decl_code
|
||||
|
||||
before = lines[:insertion_line]
|
||||
after = lines[insertion_line:]
|
||||
|
||||
return "".join([*before, decl_code, *after])
|
||||
|
||||
|
||||
def _find_insertion_line_for_declaration(
|
||||
source: str, referenced_names: set[str], existing_decl_end_lines: dict[str, int], analyzer: TreeSitterAnalyzer
|
||||
) -> int:
|
||||
"""Find the line where a declaration should be inserted based on its dependencies.
|
||||
|
||||
Args:
|
||||
source: Source code.
|
||||
referenced_names: Names referenced by the declaration.
|
||||
existing_decl_end_lines: Map of declaration names to their end lines (1-indexed).
|
||||
analyzer: TreeSitter analyzer.
|
||||
|
||||
Returns:
|
||||
Line index (0-based) where the declaration should be inserted.
|
||||
|
||||
"""
|
||||
# Find the maximum end line among referenced declarations
|
||||
max_dependency_line = 0
|
||||
for name in referenced_names:
|
||||
if name in existing_decl_end_lines:
|
||||
max_dependency_line = max(max_dependency_line, existing_decl_end_lines[name])
|
||||
|
||||
if max_dependency_line > 0:
|
||||
# Insert after the last dependency (end_line is 1-indexed, we need 0-indexed)
|
||||
return max_dependency_line
|
||||
|
||||
# No dependencies found - insert after imports
|
||||
lines = source.splitlines(keepends=True)
|
||||
return _find_line_after_imports(lines, analyzer, source)
|
||||
|
||||
|
||||
def _find_line_after_imports(lines: list[str], analyzer: TreeSitterAnalyzer, source: str) -> int:
|
||||
"""Find the line index after all imports.
|
||||
|
||||
Args:
|
||||
lines: Source lines.
|
||||
analyzer: TreeSitter analyzer for the file.
|
||||
analyzer: TreeSitter analyzer.
|
||||
source: Full source code.
|
||||
|
||||
Returns:
|
||||
Line index (0-based) for insertion.
|
||||
Line index (0-based) for insertion after imports.
|
||||
|
||||
"""
|
||||
try:
|
||||
imports = analyzer.find_imports(source)
|
||||
if imports:
|
||||
# Find the last import's end line
|
||||
return max(imp.end_line for imp in imports)
|
||||
except Exception as exc:
|
||||
logger.debug(f"Exception occurred in _find_insertion_line_after_imports_js: {exc}")
|
||||
logger.debug(f"Exception in _find_line_after_imports: {exc}")
|
||||
|
||||
# Default: insert at beginning (after any shebang/directive comments)
|
||||
# Default: insert at beginning (after shebang/directive comments)
|
||||
for i, line in enumerate(lines):
|
||||
stripped = line.strip()
|
||||
if stripped and not stripped.startswith("//") and not stripped.startswith("#!"):
|
||||
|
|
|
|||
|
|
@ -212,15 +212,18 @@ def parse_package_json_config(package_json_path: Path) -> tuple[dict[str, Any],
|
|||
|
||||
Most configuration is auto-detected from package.json and project structure.
|
||||
Only minimal config is stored in the "codeflash" key:
|
||||
- moduleRoot: Override auto-detected module root (optional)
|
||||
- testsRoot: Override auto-detected tests root (optional)
|
||||
- test-framework: Override auto-detected test framework - "jest", "vitest", or "mocha" (optional)
|
||||
- benchmarksRoot: Where to store benchmark files (optional, defaults to __benchmarks__)
|
||||
- ignorePaths: Paths to exclude from optimization (optional)
|
||||
- disableTelemetry: Privacy preference (optional, defaults to false)
|
||||
- formatterCmds: Override auto-detected formatter (optional)
|
||||
|
||||
Auto-detected values (not stored in config):
|
||||
Auto-detected values (used when not explicitly configured):
|
||||
- language: Detected from tsconfig.json presence
|
||||
- moduleRoot: Detected from package.json exports/module/main or src/ convention
|
||||
- testRunner: Detected from devDependencies (vitest/jest/mocha)
|
||||
- test-framework: Detected from devDependencies (vitest/jest/mocha)
|
||||
- formatter: Detected from devDependencies (prettier/eslint)
|
||||
|
||||
Args:
|
||||
|
|
@ -251,10 +254,17 @@ def parse_package_json_config(package_json_path: Path) -> tuple[dict[str, Any],
|
|||
detected_module_root = detect_module_root(project_root, package_data)
|
||||
config["module_root"] = str((project_root / Path(detected_module_root)).resolve())
|
||||
|
||||
# Auto-detect test runner
|
||||
config["test_runner"] = detect_test_runner(project_root, package_data)
|
||||
if codeflash_config.get("testsRoot"):
|
||||
config["tests_root"] = str(project_root / Path(codeflash_config["testsRoot"]).resolve())
|
||||
|
||||
# Check for explicit test framework override, otherwise auto-detect
|
||||
# Uses "test-framework" to match Python's pyproject.toml convention
|
||||
if codeflash_config.get("test-framework"):
|
||||
config["test_framework"] = codeflash_config["test-framework"]
|
||||
else:
|
||||
config["test_framework"] = detect_test_runner(project_root, package_data)
|
||||
# Keep pytest_cmd for backwards compatibility with existing code
|
||||
config["pytest_cmd"] = config["test_runner"]
|
||||
config["pytest_cmd"] = config["test_framework"]
|
||||
|
||||
# Auto-detect formatter (with optional override from config)
|
||||
if "formatterCmds" in codeflash_config:
|
||||
|
|
|
|||
|
|
@ -13,10 +13,13 @@ from codeflash.cli_cmds.console import logger
|
|||
from codeflash.code_utils.code_utils import exit_with_message
|
||||
from codeflash.code_utils.formatter import format_code
|
||||
from codeflash.code_utils.shell_utils import read_api_key_from_shell_config, save_api_key_to_rc
|
||||
from codeflash.languages.registry import get_language_support_by_common_formatters
|
||||
from codeflash.lsp.helpers import is_LSP_enabled
|
||||
|
||||
|
||||
def check_formatter_installed(formatter_cmds: list[str], exit_on_failure: bool = True) -> bool:
|
||||
def check_formatter_installed(
|
||||
formatter_cmds: list[str], exit_on_failure: bool = True, language: str = "python"
|
||||
) -> bool:
|
||||
if not formatter_cmds or formatter_cmds[0] == "disabled":
|
||||
return True
|
||||
first_cmd = formatter_cmds[0]
|
||||
|
|
@ -35,10 +38,21 @@ def check_formatter_installed(formatter_cmds: list[str], exit_on_failure: bool =
|
|||
)
|
||||
return False
|
||||
|
||||
tmp_code = """print("hello world")"""
|
||||
lang_support = get_language_support_by_common_formatters(formatter_cmds)
|
||||
if not lang_support:
|
||||
logger.debug(f"Could not determine language for formatter: {formatter_cmds}")
|
||||
return True
|
||||
|
||||
if str(lang_support.language) == "python":
|
||||
tmp_code = """print("hello world")"""
|
||||
elif str(lang_support.language) in ("javascript", "typescript"):
|
||||
tmp_code = "console.log('hello world');"
|
||||
else:
|
||||
return True
|
||||
|
||||
try:
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
tmp_file = Path(tmpdir) / "test_codeflash_formatter.py"
|
||||
tmp_file = Path(tmpdir) / ("test_codeflash_formatter" + lang_support.default_file_extension)
|
||||
tmp_file.write_text(tmp_code, encoding="utf-8")
|
||||
format_code(formatter_cmds, tmp_file, print_status=False, exit_on_failure=False)
|
||||
return True
|
||||
|
|
|
|||
|
|
@ -42,13 +42,16 @@ def generate_unified_diff(original: str, modified: str, from_file: str, to_file:
|
|||
def apply_formatter_cmds(
|
||||
cmds: list[str], path: Path, test_dir_str: Optional[str], print_status: bool, exit_on_failure: bool = True
|
||||
) -> tuple[Path, str, bool]:
|
||||
from codeflash.languages.registry import get_language_support
|
||||
|
||||
if not path.exists():
|
||||
msg = f"File {path} does not exist. Cannot apply formatter commands."
|
||||
raise FileNotFoundError(msg)
|
||||
|
||||
file_path = path
|
||||
lang_support = get_language_support(path)
|
||||
if test_dir_str:
|
||||
file_path = Path(test_dir_str) / "temp.py"
|
||||
file_path = Path(test_dir_str) / ("temp" + lang_support.default_file_extension)
|
||||
shutil.copy2(path, file_path)
|
||||
|
||||
file_token = "$file" # noqa: S105
|
||||
|
|
@ -87,13 +90,16 @@ def get_diff_lines_count(diff_output: str) -> int:
|
|||
return len(diff_lines)
|
||||
|
||||
|
||||
def format_generated_code(generated_test_source: str, formatter_cmds: list[str]) -> str:
|
||||
def format_generated_code(generated_test_source: str, formatter_cmds: list[str], language: str = "python") -> str:
|
||||
from codeflash.languages.registry import get_language_support
|
||||
|
||||
formatter_name = formatter_cmds[0].lower() if formatter_cmds else "disabled"
|
||||
if formatter_name == "disabled": # nothing to do if no formatter provided
|
||||
return re.sub(r"\n{2,}", "\n\n", generated_test_source)
|
||||
with tempfile.TemporaryDirectory() as test_dir_str:
|
||||
# try running formatter, if nothing changes (could be due to formatting failing or no actual formatting needed) return code with 2 or more newlines substituted with 2 newlines
|
||||
original_temp = Path(test_dir_str) / "original_temp.py"
|
||||
lang_support = get_language_support(language)
|
||||
original_temp = Path(test_dir_str) / ("original_temp" + lang_support.default_file_extension)
|
||||
original_temp.write_text(generated_test_source, encoding="utf8")
|
||||
_, formatted_code, changed = apply_formatter_cmds(
|
||||
formatter_cmds, original_temp, test_dir_str, print_status=False, exit_on_failure=False
|
||||
|
|
@ -111,6 +117,8 @@ def format_code(
|
|||
print_status: bool = True,
|
||||
exit_on_failure: bool = True,
|
||||
) -> str:
|
||||
from codeflash.languages.registry import get_language_support
|
||||
|
||||
if is_LSP_enabled():
|
||||
exit_on_failure = False
|
||||
|
||||
|
|
@ -130,7 +138,8 @@ def format_code(
|
|||
# we don't count the formatting diff for the optimized function as it should be well-formatted
|
||||
original_code_without_opfunc = original_code.replace(optimized_code, "")
|
||||
|
||||
original_temp = Path(test_dir_str) / "original_temp.py"
|
||||
lang_support = get_language_support(path)
|
||||
original_temp = Path(test_dir_str) / ("original_temp" + lang_support.default_file_extension)
|
||||
original_temp.write_text(original_code_without_opfunc, encoding="utf8")
|
||||
|
||||
formatted_temp, formatted_code, changed = apply_formatter_cmds(
|
||||
|
|
@ -160,6 +169,7 @@ def format_code(
|
|||
_, formatted_code, changed = apply_formatter_cmds(
|
||||
formatter_cmds, path, test_dir_str=None, print_status=print_status, exit_on_failure=exit_on_failure
|
||||
)
|
||||
|
||||
if not changed:
|
||||
logger.warning(
|
||||
f"No changes detected in {path} after formatting, are you sure you have valid formatter commands?"
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ from typing import TYPE_CHECKING, TypeVar
|
|||
from pydantic import BaseModel, ConfigDict, field_validator
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from codeflash.models.models import FunctionParent
|
||||
from codeflash.models.function_types import FunctionParent
|
||||
|
||||
|
||||
ObjectDefT = TypeVar("ObjectDefT", ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)
|
||||
|
|
|
|||
|
|
@ -23,8 +23,7 @@ from codeflash.context.unused_definition_remover import (
|
|||
from codeflash.discovery.functions_to_optimize import FunctionToOptimize # noqa: TC001
|
||||
|
||||
# Language support imports for multi-language code context extraction
|
||||
from codeflash.languages import is_python
|
||||
from codeflash.languages.base import Language
|
||||
from codeflash.languages import Language, is_python
|
||||
from codeflash.models.models import (
|
||||
CodeContextType,
|
||||
CodeOptimizationContext,
|
||||
|
|
@ -234,27 +233,13 @@ def get_code_optimization_context_for_language(
|
|||
|
||||
"""
|
||||
from codeflash.languages import get_language_support
|
||||
from codeflash.languages.base import FunctionInfo, ParentInfo
|
||||
|
||||
# Get language support for this function
|
||||
language = Language(function_to_optimize.language)
|
||||
lang_support = get_language_support(language)
|
||||
|
||||
# Convert FunctionToOptimize to FunctionInfo for language support
|
||||
parents = tuple(ParentInfo(name=p.name, type=p.type) for p in function_to_optimize.parents)
|
||||
func_info = FunctionInfo(
|
||||
name=function_to_optimize.function_name,
|
||||
file_path=function_to_optimize.file_path,
|
||||
start_line=function_to_optimize.starting_line or 1,
|
||||
end_line=function_to_optimize.ending_line or 1,
|
||||
parents=parents,
|
||||
is_async=function_to_optimize.is_async,
|
||||
is_method=len(function_to_optimize.parents) > 0,
|
||||
language=language,
|
||||
)
|
||||
|
||||
# Extract code context using language support
|
||||
code_context = lang_support.extract_code_context(func_info, project_root_path, project_root_path)
|
||||
code_context = lang_support.extract_code_context(function_to_optimize, project_root_path, project_root_path)
|
||||
|
||||
# Build imports string if available
|
||||
imports_code = "\n".join(code_context.imports) if code_context.imports else ""
|
||||
|
|
@ -294,10 +279,9 @@ def get_code_optimization_context_for_language(
|
|||
helper_code = "\n\n".join(h.source_code for h in same_file_helpers)
|
||||
target_file_code = target_file_code + "\n\n" + helper_code
|
||||
|
||||
# Add global variables (module-level declarations) referenced by the function and helpers
|
||||
# These should be included in read-writable context so AI can modify them if needed
|
||||
if code_context.read_only_context:
|
||||
target_file_code = code_context.read_only_context + "\n\n" + target_file_code
|
||||
# Note: code_context.read_only_context contains type definitions and global variables
|
||||
# These should be passed as read-only context to the AI, not prepended to the target code
|
||||
# If prepended to target code, the AI treats them as code to optimize and includes them in output
|
||||
|
||||
# Add imports to target file code
|
||||
if imports_code:
|
||||
|
|
@ -350,8 +334,9 @@ def get_code_optimization_context_for_language(
|
|||
return CodeOptimizationContext(
|
||||
testgen_context=testgen_context,
|
||||
read_writable_code=read_writable_code,
|
||||
# Global variables are now included in read-writable code, so don't duplicate in read-only
|
||||
read_only_context_code="",
|
||||
# Pass type definitions and globals as read-only context for the AI
|
||||
# This way the AI sees them as context but doesn't include them in optimized output
|
||||
read_only_context_code=code_context.read_only_context,
|
||||
hashing_code_context=read_writable_code.flat,
|
||||
hashing_code_context_hash=code_hash,
|
||||
helper_functions=helper_function_sources,
|
||||
|
|
|
|||
|
|
@ -29,7 +29,6 @@ from codeflash.code_utils.code_utils import (
|
|||
)
|
||||
from codeflash.code_utils.compat import SAFE_SYS_EXECUTABLE, codeflash_cache_db
|
||||
from codeflash.code_utils.shell_utils import get_cross_platform_subprocess_run_args
|
||||
from codeflash.languages import is_javascript, is_python
|
||||
from codeflash.models.models import CodePosition, FunctionCalledInTest, TestsInFile, TestType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
|
@ -589,7 +588,7 @@ def discover_tests_for_language(
|
|||
|
||||
"""
|
||||
from codeflash.languages import get_language_support
|
||||
from codeflash.languages.base import FunctionInfo, Language, ParentInfo
|
||||
from codeflash.languages.base import Language
|
||||
|
||||
try:
|
||||
lang_support = get_language_support(Language(language))
|
||||
|
|
@ -597,34 +596,20 @@ def discover_tests_for_language(
|
|||
logger.warning(f"Unsupported language {language}, returning empty test map")
|
||||
return {}, 0, 0
|
||||
|
||||
# Convert FunctionToOptimize to FunctionInfo for the language support API
|
||||
# Also build a mapping from simple qualified_name to full qualified_name_with_modules
|
||||
function_infos: list[FunctionInfo] = []
|
||||
# Collect all functions and build a mapping from simple qualified_name to full qualified_name_with_modules
|
||||
all_functions: list[FunctionToOptimize] = []
|
||||
simple_to_full_name: dict[str, str] = {}
|
||||
if file_to_funcs_to_optimize:
|
||||
for funcs in file_to_funcs_to_optimize.values():
|
||||
for func in funcs:
|
||||
parents = tuple(ParentInfo(p.name, p.type) for p in func.parents)
|
||||
func_info = FunctionInfo(
|
||||
name=func.function_name,
|
||||
file_path=func.file_path,
|
||||
start_line=func.starting_line or 0,
|
||||
end_line=func.ending_line or 0,
|
||||
start_col=func.starting_col,
|
||||
end_col=func.ending_col,
|
||||
is_async=func.is_async,
|
||||
is_method=bool(func.parents and any(p.type == "ClassDef" for p in func.parents)),
|
||||
parents=parents,
|
||||
language=Language(language),
|
||||
)
|
||||
function_infos.append(func_info)
|
||||
all_functions.append(func)
|
||||
# Map simple qualified_name to full qualified_name_with_modules_from_root
|
||||
simple_to_full_name[func_info.qualified_name] = func.qualified_name_with_modules_from_root(
|
||||
simple_to_full_name[func.qualified_name] = func.qualified_name_with_modules_from_root(
|
||||
cfg.project_root_path
|
||||
)
|
||||
|
||||
# Use language support to discover tests
|
||||
test_map = lang_support.discover_tests(cfg.tests_root, function_infos)
|
||||
test_map = lang_support.discover_tests(cfg.tests_root, all_functions)
|
||||
|
||||
# Convert TestInfo back to FunctionCalledInTest format
|
||||
# Use the full qualified name (with modules) as the key for consistency with Python
|
||||
|
|
@ -656,6 +641,8 @@ def discover_unit_tests(
|
|||
discover_only_these_tests: list[Path] | None = None,
|
||||
file_to_funcs_to_optimize: dict[Path, list[FunctionToOptimize]] | None = None,
|
||||
) -> tuple[dict[str, set[FunctionCalledInTest]], int, int]:
|
||||
from codeflash.languages import is_javascript, is_python
|
||||
|
||||
# Detect language from functions being optimized
|
||||
language = _detect_language_from_functions(file_to_funcs_to_optimize)
|
||||
|
||||
|
|
|
|||
|
|
@ -12,6 +12,7 @@ from typing import TYPE_CHECKING, Any, Optional
|
|||
|
||||
import git
|
||||
import libcst as cst
|
||||
from pydantic import Field
|
||||
from pydantic.dataclasses import dataclass
|
||||
from rich.tree import Tree
|
||||
|
||||
|
|
@ -26,9 +27,8 @@ from codeflash.code_utils.code_utils import (
|
|||
from codeflash.code_utils.env_utils import get_pr_number
|
||||
from codeflash.code_utils.git_utils import get_git_diff, get_repo_owner_and_name
|
||||
from codeflash.discovery.discover_unit_tests import discover_unit_tests
|
||||
from codeflash.languages import get_language_support, get_supported_extensions
|
||||
from codeflash.languages.base import Language
|
||||
from codeflash.languages.registry import is_language_supported
|
||||
from codeflash.languages.language_enum import Language
|
||||
from codeflash.languages.registry import get_language_support, get_supported_extensions, is_language_supported
|
||||
from codeflash.lsp.helpers import is_LSP_enabled
|
||||
from codeflash.models.models import FunctionParent
|
||||
from codeflash.telemetry.posthog_cf import ph
|
||||
|
|
@ -39,8 +39,11 @@ if TYPE_CHECKING:
|
|||
from libcst import CSTNode
|
||||
from libcst.metadata import CodeRange
|
||||
|
||||
from codeflash.languages.base import FunctionInfo
|
||||
from codeflash.models.models import CodeOptimizationContext
|
||||
from codeflash.verification.verification_utils import TestConfig
|
||||
import contextlib
|
||||
|
||||
from rich.text import Text
|
||||
|
||||
_property_id = "property"
|
||||
|
|
@ -131,17 +134,23 @@ class FunctionWithReturnStatement(ast.NodeVisitor):
|
|||
class FunctionToOptimize:
|
||||
"""Represent a function that is a candidate for optimization.
|
||||
|
||||
This is the canonical dataclass for representing functions across all languages
|
||||
(Python, JavaScript, TypeScript). It captures all information needed to identify,
|
||||
locate, and work with a function.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
function_name: The name of the function.
|
||||
file_path: The absolute file path where the function is located.
|
||||
parents: A list of parent scopes, which could be classes or functions.
|
||||
starting_line: The starting line number of the function in the file.
|
||||
ending_line: The ending line number of the function in the file.
|
||||
starting_col: The starting column offset (for precise location in multi-line contexts).
|
||||
ending_col: The ending column offset (for precise location in multi-line contexts).
|
||||
starting_line: The starting line number of the function in the file (1-indexed).
|
||||
ending_line: The ending line number of the function in the file (1-indexed).
|
||||
starting_col: The starting column offset (0-indexed, for precise location).
|
||||
ending_col: The ending column offset (0-indexed, for precise location).
|
||||
is_async: Whether this function is defined as async.
|
||||
is_method: Whether this is a method (belongs to a class).
|
||||
language: The programming language of this function (default: "python").
|
||||
doc_start_line: Line where docstring/JSDoc starts (or None if no doc comment).
|
||||
|
||||
The qualified_name property provides the full name of the function, including
|
||||
any parent class or function names. The qualified_name_with_modules_from_root
|
||||
|
|
@ -151,23 +160,32 @@ class FunctionToOptimize:
|
|||
|
||||
function_name: str
|
||||
file_path: Path
|
||||
parents: list[FunctionParent] # list[ClassDef | FunctionDef | AsyncFunctionDef]
|
||||
parents: list[FunctionParent] = Field(default_factory=list) # list[ClassDef | FunctionDef | AsyncFunctionDef]
|
||||
starting_line: Optional[int] = None
|
||||
ending_line: Optional[int] = None
|
||||
starting_col: Optional[int] = None # Column offset for precise location
|
||||
ending_col: Optional[int] = None # Column offset for precise location
|
||||
is_async: bool = False
|
||||
is_method: bool = False # Whether this is a method (belongs to a class)
|
||||
language: str = "python" # Language identifier for multi-language support
|
||||
doc_start_line: Optional[int] = None # Line where docstring/JSDoc starts
|
||||
|
||||
@property
|
||||
def top_level_parent_name(self) -> str:
|
||||
return self.function_name if not self.parents else self.parents[0].name
|
||||
|
||||
@property
|
||||
def class_name(self) -> str | None:
|
||||
"""Get the immediate parent class name, if any."""
|
||||
for parent in reversed(self.parents):
|
||||
if parent.type == "ClassDef":
|
||||
return parent.name
|
||||
return None
|
||||
|
||||
def __str__(self) -> str:
|
||||
return (
|
||||
f"{self.file_path}:{'.'.join([p.name for p in self.parents])}"
|
||||
f"{'.' if self.parents else ''}{self.function_name}"
|
||||
)
|
||||
qualified = f"{'.'.join([p.name for p in self.parents])}{'.' if self.parents else ''}{self.function_name}"
|
||||
line_info = f":{self.starting_line}-{self.ending_line}" if self.starting_line and self.ending_line else ""
|
||||
return f"{self.file_path}:{qualified}{line_info}"
|
||||
|
||||
@property
|
||||
def qualified_name(self) -> str:
|
||||
|
|
@ -180,6 +198,28 @@ class FunctionToOptimize:
|
|||
def qualified_name_with_modules_from_root(self, project_root_path: Path) -> str:
|
||||
return f"{module_name_from_file_path(self.file_path, project_root_path)}.{self.qualified_name}"
|
||||
|
||||
@classmethod
|
||||
def from_function_info(cls, func_info: FunctionInfo) -> FunctionToOptimize:
|
||||
"""Create a FunctionToOptimize from a FunctionInfo instance.
|
||||
|
||||
This is a temporary method for backward compatibility during migration.
|
||||
Once FunctionInfo is fully removed, this method can be deleted.
|
||||
"""
|
||||
parents = [FunctionParent(name=p.name, type=p.type) for p in func_info.parents]
|
||||
return cls(
|
||||
function_name=func_info.name,
|
||||
file_path=func_info.file_path,
|
||||
parents=parents,
|
||||
starting_line=func_info.start_line,
|
||||
ending_line=func_info.end_line,
|
||||
starting_col=func_info.start_col,
|
||||
ending_col=func_info.end_col,
|
||||
is_async=func_info.is_async,
|
||||
is_method=func_info.is_method,
|
||||
language=func_info.language.value,
|
||||
doc_start_line=func_info.doc_start_line,
|
||||
)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Multi-language support helpers
|
||||
|
|
@ -187,7 +227,7 @@ class FunctionToOptimize:
|
|||
|
||||
|
||||
def get_files_for_language(
|
||||
module_root_path: Path, ignore_paths: list[Path], language: Language | None = None
|
||||
module_root_path: Path, ignore_paths: list[Path] | None = None, language: Language | None = None
|
||||
) -> list[Path]:
|
||||
"""Get all source files for supported languages.
|
||||
|
||||
|
|
@ -200,22 +240,69 @@ def get_files_for_language(
|
|||
List of file paths matching supported extensions.
|
||||
|
||||
"""
|
||||
if ignore_paths is None:
|
||||
ignore_paths = []
|
||||
|
||||
if language is not None:
|
||||
support = get_language_support(language)
|
||||
extensions = support.file_extensions
|
||||
else:
|
||||
extensions = tuple(get_supported_extensions())
|
||||
|
||||
# Default directory patterns to always exclude for JS/TS
|
||||
js_ts_default_excludes = {
|
||||
"node_modules",
|
||||
"dist",
|
||||
"build",
|
||||
".next",
|
||||
".nuxt",
|
||||
"coverage",
|
||||
".cache",
|
||||
".turbo",
|
||||
".vercel",
|
||||
"__pycache__",
|
||||
}
|
||||
|
||||
files = []
|
||||
for ext in extensions:
|
||||
pattern = f"*{ext}"
|
||||
for file_path in module_root_path.rglob(pattern):
|
||||
# Check explicit ignore paths
|
||||
if any(file_path.is_relative_to(ignore_path) for ignore_path in ignore_paths):
|
||||
continue
|
||||
# Check default JS/TS excludes in path parts
|
||||
if any(part in js_ts_default_excludes for part in file_path.parts):
|
||||
continue
|
||||
files.append(file_path)
|
||||
return files
|
||||
|
||||
|
||||
def _is_js_ts_function_exported(file_path: Path, function_name: str) -> tuple[bool, str | None]:
|
||||
"""Check if a JavaScript/TypeScript function is exported from its module.
|
||||
|
||||
For JS/TS, functions that are not exported cannot be imported by tests,
|
||||
making them impossible to optimize.
|
||||
|
||||
Args:
|
||||
file_path: Path to the source file.
|
||||
function_name: Name of the function to check.
|
||||
|
||||
Returns:
|
||||
Tuple of (is_exported, export_name). export_name may be 'default' for default exports.
|
||||
|
||||
"""
|
||||
from codeflash.languages.treesitter_utils import get_analyzer_for_file
|
||||
|
||||
try:
|
||||
source = file_path.read_text(encoding="utf-8")
|
||||
analyzer = get_analyzer_for_file(file_path)
|
||||
return analyzer.is_function_exported(source, function_name)
|
||||
except Exception as e:
|
||||
logger.debug(f"Failed to check export status for {function_name}: {e}")
|
||||
# Return True to avoid blocking in case of errors
|
||||
return True, None
|
||||
|
||||
|
||||
def _find_all_functions_in_python_file(file_path: Path) -> dict[Path, list[FunctionToOptimize]]:
|
||||
"""Find all optimizable functions in a Python file using AST parsing.
|
||||
|
||||
|
|
@ -248,25 +335,8 @@ def _find_all_functions_via_language_support(file_path: Path) -> dict[Path, list
|
|||
try:
|
||||
lang_support = get_language_support(file_path)
|
||||
criteria = FunctionFilterCriteria(require_return=True)
|
||||
function_infos = lang_support.discover_functions(file_path, criteria)
|
||||
|
||||
ftos = []
|
||||
for func_info in function_infos:
|
||||
parents = [FunctionParent(p.name, p.type) for p in func_info.parents]
|
||||
ftos.append(
|
||||
FunctionToOptimize(
|
||||
function_name=func_info.name,
|
||||
file_path=func_info.file_path,
|
||||
parents=parents,
|
||||
starting_line=func_info.start_line,
|
||||
ending_line=func_info.end_line,
|
||||
starting_col=func_info.start_col,
|
||||
ending_col=func_info.end_col,
|
||||
is_async=func_info.is_async,
|
||||
language=func_info.language.value,
|
||||
)
|
||||
)
|
||||
functions[file_path] = ftos
|
||||
# discover_functions already returns FunctionToOptimize objects
|
||||
functions[file_path] = lang_support.discover_functions(file_path, criteria)
|
||||
except Exception as e:
|
||||
logger.debug(f"Failed to discover functions in {file_path}: {e}")
|
||||
|
||||
|
|
@ -338,6 +408,36 @@ def get_functions_to_optimize(
|
|||
exit_with_message(
|
||||
f"Function {only_get_this_function} not found in file {file}\nor the function does not have a 'return' statement or is a property"
|
||||
)
|
||||
|
||||
# For JavaScript/TypeScript, verify that the function (or its parent class) is exported
|
||||
# Non-exported functions cannot be imported by tests
|
||||
if found_function.language in ("javascript", "typescript"):
|
||||
# For class methods, check if the parent class is exported
|
||||
# For standalone functions, check if the function itself is exported
|
||||
if found_function.parents:
|
||||
# It's a class method - check if the class is exported
|
||||
name_to_check = found_function.top_level_parent_name
|
||||
else:
|
||||
# It's a standalone function - check if the function is exported
|
||||
name_to_check = found_function.function_name
|
||||
|
||||
is_exported, export_name = _is_js_ts_function_exported(file, name_to_check)
|
||||
if not is_exported:
|
||||
if found_function.parents:
|
||||
logger.debug(
|
||||
f"Class '{name_to_check}' containing method '{found_function.function_name}' "
|
||||
f"is not exported from {file}. "
|
||||
f"In JavaScript/TypeScript, only exported classes/functions can be optimized "
|
||||
f"because tests need to import them."
|
||||
)
|
||||
else:
|
||||
logger.debug(
|
||||
f"Function '{found_function.function_name}' is not exported from {file}. "
|
||||
f"In JavaScript/TypeScript, only exported functions can be optimized because "
|
||||
f"tests need to import them."
|
||||
)
|
||||
return {}, 0, None
|
||||
|
||||
functions[file] = [found_function]
|
||||
else:
|
||||
logger.info("Finding all functions modified in the current git diff ...")
|
||||
|
|
@ -539,9 +639,10 @@ def get_all_replay_test_functions(
|
|||
except Exception as e:
|
||||
logger.warning(f"Error parsing replay test file {replay_test_file}: {e}")
|
||||
|
||||
if not trace_file_path:
|
||||
if trace_file_path is None:
|
||||
logger.error("Could not find trace_file_path in replay test files.")
|
||||
exit_with_message("Could not find trace_file_path in replay test files.")
|
||||
raise AssertionError("Unreachable") # exit_with_message never returns
|
||||
|
||||
if not trace_file_path.exists():
|
||||
logger.error(f"Trace file not found: {trace_file_path}")
|
||||
|
|
@ -596,7 +697,7 @@ def get_all_replay_test_functions(
|
|||
if filtered_list:
|
||||
filtered_valid_functions[file_path] = filtered_list
|
||||
|
||||
return filtered_valid_functions, trace_file_path
|
||||
return dict(filtered_valid_functions), trace_file_path
|
||||
|
||||
|
||||
def is_git_repo(file_path: str) -> bool:
|
||||
|
|
@ -608,11 +709,13 @@ def is_git_repo(file_path: str) -> bool:
|
|||
|
||||
|
||||
@cache
|
||||
def ignored_submodule_paths(module_root: str) -> list[str]:
|
||||
def ignored_submodule_paths(module_root: str) -> list[Path]:
|
||||
if is_git_repo(module_root):
|
||||
git_repo = git.Repo(module_root, search_parent_directories=True)
|
||||
try:
|
||||
return [Path(git_repo.working_tree_dir, submodule.path).resolve() for submodule in git_repo.submodules]
|
||||
working_dir = git_repo.working_tree_dir
|
||||
if working_dir is not None:
|
||||
return [Path(working_dir, submodule.path).resolve() for submodule in git_repo.submodules]
|
||||
except Exception as e:
|
||||
logger.warning(f"Error getting submodule paths: {e}")
|
||||
return []
|
||||
|
|
@ -626,7 +729,7 @@ class TopLevelFunctionOrMethodVisitor(ast.NodeVisitor):
|
|||
self.class_name = class_name
|
||||
self.function_name = function_or_method_name
|
||||
self.is_top_level = False
|
||||
self.function_has_args = None
|
||||
self.function_has_args: bool | None = None
|
||||
self.line_no = line_no
|
||||
self.is_staticmethod = False
|
||||
self.is_classmethod = False
|
||||
|
|
@ -740,31 +843,28 @@ def was_function_previously_optimized(
|
|||
|
||||
# Check optimization status if repository info is provided
|
||||
# already_optimized_count = 0
|
||||
try:
|
||||
|
||||
# Check optimization status if repository info is provided
|
||||
# already_optimized_count = 0
|
||||
owner = None
|
||||
repo = None
|
||||
with contextlib.suppress(git.exc.InvalidGitRepositoryError):
|
||||
owner, repo = get_repo_owner_and_name()
|
||||
except git.exc.InvalidGitRepositoryError:
|
||||
logger.warning("No git repository found")
|
||||
owner, repo = None, None
|
||||
|
||||
pr_number = get_pr_number()
|
||||
|
||||
if not owner or not repo or pr_number is None or getattr(args, "no_pr", False):
|
||||
return False
|
||||
|
||||
code_contexts = []
|
||||
|
||||
func_hash = code_context.hashing_code_context_hash
|
||||
# Use a unique path identifier that includes function info
|
||||
|
||||
code_contexts.append(
|
||||
code_contexts = [
|
||||
{
|
||||
"file_path": function_to_optimize.file_path,
|
||||
"file_path": str(function_to_optimize.file_path),
|
||||
"function_name": function_to_optimize.qualified_name,
|
||||
"code_hash": func_hash,
|
||||
}
|
||||
)
|
||||
|
||||
if not code_contexts:
|
||||
return False
|
||||
]
|
||||
|
||||
try:
|
||||
result = is_function_being_optimized_again(owner, repo, pr_number, code_contexts)
|
||||
|
|
@ -783,7 +883,7 @@ def filter_functions(
|
|||
ignore_paths: list[Path],
|
||||
project_root: Path,
|
||||
module_root: Path,
|
||||
previous_checkpoint_functions: dict[Path, dict[str, Any]] | None = None,
|
||||
previous_checkpoint_functions: dict[str, dict[str, Any]] | None = None,
|
||||
*,
|
||||
disable_logs: bool = False,
|
||||
) -> tuple[dict[Path, list[FunctionToOptimize]], int]:
|
||||
|
|
@ -808,21 +908,49 @@ def filter_functions(
|
|||
# Normalize paths for case-insensitive comparison on Windows
|
||||
tests_root_str = os.path.normcase(str(tests_root))
|
||||
module_root_str = os.path.normcase(str(module_root))
|
||||
project_root_str = os.path.normcase(str(project_root))
|
||||
|
||||
# Check if tests_root overlaps with module_root or project_root
|
||||
# In this case, we need to use file pattern matching instead of directory matching
|
||||
tests_root_overlaps_source = tests_root_str in (module_root_str, project_root_str) or module_root_str.startswith(
|
||||
tests_root_str + os.sep
|
||||
)
|
||||
|
||||
# Test file patterns for when tests_root overlaps with source
|
||||
test_file_name_patterns = (".test.", ".spec.", "_test.", "_spec.")
|
||||
test_dir_patterns = (os.sep + "test" + os.sep, os.sep + "tests" + os.sep, os.sep + "__tests__" + os.sep)
|
||||
|
||||
def is_test_file(file_path_normalized: str) -> bool:
|
||||
"""Check if a file is a test file based on patterns."""
|
||||
if tests_root_overlaps_source:
|
||||
# Use file pattern matching when tests_root overlaps with source
|
||||
file_lower = file_path_normalized.lower()
|
||||
# Check filename patterns (e.g., .test.ts, .spec.ts)
|
||||
if any(pattern in file_lower for pattern in test_file_name_patterns):
|
||||
return True
|
||||
# Check directory patterns, but only within the project root
|
||||
# to avoid false positives from parent directories
|
||||
relative_path = file_lower
|
||||
if project_root_str and file_lower.startswith(project_root_str.lower()):
|
||||
relative_path = file_lower[len(project_root_str) :]
|
||||
return any(pattern in relative_path for pattern in test_dir_patterns)
|
||||
# Use directory-based filtering when tests are in a separate directory
|
||||
return file_path_normalized.startswith(tests_root_str + os.sep)
|
||||
|
||||
# We desperately need Python 3.10+ only support to make this code readable with structural pattern matching
|
||||
for file_path_path, functions in modified_functions.items():
|
||||
_functions = functions
|
||||
file_path = str(file_path_path)
|
||||
file_path_normalized = os.path.normcase(file_path)
|
||||
if file_path_normalized.startswith(tests_root_str + os.sep):
|
||||
if is_test_file(file_path_normalized):
|
||||
test_functions_removed_count += len(_functions)
|
||||
continue
|
||||
if file_path in ignore_paths or any(
|
||||
if file_path_path in ignore_paths or any(
|
||||
file_path_normalized.startswith(os.path.normcase(str(ignore_path)) + os.sep) for ignore_path in ignore_paths
|
||||
):
|
||||
ignore_paths_removed_count += 1
|
||||
continue
|
||||
if file_path in submodule_paths or any(
|
||||
if file_path_path in submodule_paths or any(
|
||||
file_path_normalized.startswith(os.path.normcase(str(submodule_path)) + os.sep)
|
||||
for submodule_path in submodule_paths
|
||||
):
|
||||
|
|
@ -914,7 +1042,7 @@ def filter_files_optimized(file_path: Path, tests_root: Path, ignore_paths: list
|
|||
|
||||
def function_has_return_statement(function_node: FunctionDef | AsyncFunctionDef) -> bool:
|
||||
# Custom DFS, return True as soon as a Return node is found
|
||||
stack = [function_node]
|
||||
stack: list[ast.AST] = [function_node]
|
||||
while stack:
|
||||
node = stack.pop()
|
||||
if isinstance(node, ast.Return):
|
||||
|
|
|
|||
|
|
@ -19,7 +19,6 @@ Usage:
|
|||
|
||||
from codeflash.languages.base import (
|
||||
CodeContext,
|
||||
FunctionInfo,
|
||||
HelperFunction,
|
||||
Language,
|
||||
LanguageSupport,
|
||||
|
|
@ -53,6 +52,40 @@ from codeflash.languages.registry import (
|
|||
get_supported_languages,
|
||||
register_language,
|
||||
)
|
||||
from codeflash.languages.test_framework import (
|
||||
current_test_framework,
|
||||
get_js_test_framework_or_default,
|
||||
is_jest,
|
||||
is_mocha,
|
||||
is_pytest,
|
||||
is_unittest,
|
||||
is_vitest,
|
||||
reset_test_framework,
|
||||
set_current_test_framework,
|
||||
)
|
||||
|
||||
|
||||
# Lazy imports to avoid circular imports
|
||||
def __getattr__(name: str):
|
||||
if name == "FunctionInfo":
|
||||
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
|
||||
|
||||
return FunctionToOptimize
|
||||
if name == "JavaScriptSupport":
|
||||
from codeflash.languages.javascript.support import JavaScriptSupport
|
||||
|
||||
return JavaScriptSupport
|
||||
if name == "TypeScriptSupport":
|
||||
from codeflash.languages.javascript.support import TypeScriptSupport
|
||||
|
||||
return TypeScriptSupport
|
||||
if name == "PythonSupport":
|
||||
from codeflash.languages.python.support import PythonSupport
|
||||
|
||||
return PythonSupport
|
||||
msg = f"module {__name__!r} has no attribute {name!r}"
|
||||
raise AttributeError(msg)
|
||||
|
||||
|
||||
__all__ = [
|
||||
# Base types
|
||||
|
|
@ -67,16 +100,24 @@ __all__ = [
|
|||
# Current language singleton
|
||||
"current_language",
|
||||
"current_language_support",
|
||||
# Registry functions
|
||||
"current_test_framework",
|
||||
"detect_project_language",
|
||||
"get_js_test_framework_or_default",
|
||||
"get_language_support",
|
||||
"get_supported_extensions",
|
||||
"get_supported_languages",
|
||||
"is_java",
|
||||
"is_javascript",
|
||||
"is_jest",
|
||||
"is_mocha",
|
||||
"is_pytest",
|
||||
"is_python",
|
||||
"is_typescript",
|
||||
"is_unittest",
|
||||
"is_vitest",
|
||||
"register_language",
|
||||
"reset_current_language",
|
||||
"reset_test_framework",
|
||||
"set_current_language",
|
||||
"set_current_test_framework",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -2,112 +2,36 @@
|
|||
|
||||
This module defines the core abstractions that all language implementations must follow.
|
||||
The LanguageSupport protocol defines the interface that each language must implement,
|
||||
while the dataclasses define language-agnostic representations of code constructs.
|
||||
while FunctionToOptimize is the canonical representation of functions across all languages.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import TYPE_CHECKING, Any, Protocol, runtime_checkable
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Sequence
|
||||
from pathlib import Path
|
||||
|
||||
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
|
||||
|
||||
class Language(str, Enum):
|
||||
"""Supported programming languages."""
|
||||
from codeflash.languages.language_enum import Language
|
||||
from codeflash.models.function_types import FunctionParent
|
||||
|
||||
PYTHON = "python"
|
||||
JAVASCRIPT = "javascript"
|
||||
TYPESCRIPT = "typescript"
|
||||
JAVA = "java"
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self.value
|
||||
# Backward compatibility aliases - ParentInfo is now FunctionParent
|
||||
ParentInfo = FunctionParent
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ParentInfo:
|
||||
"""Parent scope information for nested functions/methods.
|
||||
# Lazy import for FunctionInfo to avoid circular imports
|
||||
# This allows `from codeflash.languages.base import FunctionInfo` to work at runtime
|
||||
def __getattr__(name: str) -> Any:
|
||||
if name == "FunctionInfo":
|
||||
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
|
||||
|
||||
Represents the parent class or function that contains a nested function.
|
||||
Used to construct the qualified name of a function.
|
||||
|
||||
Attributes:
|
||||
name: The name of the parent scope (class name or function name).
|
||||
type: The type of parent ("ClassDef", "FunctionDef", "AsyncFunctionDef", etc.).
|
||||
|
||||
"""
|
||||
|
||||
name: str
|
||||
type: str # "ClassDef", "FunctionDef", "AsyncFunctionDef", etc.
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"{self.type}:{self.name}"
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class FunctionInfo:
|
||||
"""Language-agnostic representation of a function to optimize.
|
||||
|
||||
This class captures all the information needed to identify, locate, and
|
||||
work with a function across different programming languages.
|
||||
|
||||
Attributes:
|
||||
name: The simple function name (e.g., "add").
|
||||
file_path: Absolute path to the file containing the function.
|
||||
start_line: Starting line number (1-indexed).
|
||||
end_line: Ending line number (1-indexed, inclusive).
|
||||
parents: List of parent scopes (for nested functions/methods).
|
||||
is_async: Whether this is an async function.
|
||||
is_method: Whether this is a method (belongs to a class).
|
||||
language: The programming language.
|
||||
start_col: Starting column (0-indexed), optional for more precise location.
|
||||
end_col: Ending column (0-indexed), optional.
|
||||
|
||||
"""
|
||||
|
||||
name: str
|
||||
file_path: Path
|
||||
start_line: int
|
||||
end_line: int
|
||||
parents: tuple[ParentInfo, ...] = ()
|
||||
is_async: bool = False
|
||||
is_method: bool = False
|
||||
language: Language = Language.PYTHON
|
||||
start_col: int | None = None
|
||||
end_col: int | None = None
|
||||
doc_start_line: int | None = None # Line where docstring/JSDoc starts (or None if no doc comment)
|
||||
|
||||
@property
|
||||
def qualified_name(self) -> str:
|
||||
"""Full qualified name including parent scopes.
|
||||
|
||||
For a method `add` in class `Calculator`, returns "Calculator.add".
|
||||
For nested functions, includes all parent scopes.
|
||||
"""
|
||||
if not self.parents:
|
||||
return self.name
|
||||
parent_path = ".".join(parent.name for parent in self.parents)
|
||||
return f"{parent_path}.{self.name}"
|
||||
|
||||
@property
|
||||
def class_name(self) -> str | None:
|
||||
"""Get the immediate parent class name, if any."""
|
||||
for parent in reversed(self.parents):
|
||||
if parent.type == "ClassDef":
|
||||
return parent.name
|
||||
return None
|
||||
|
||||
@property
|
||||
def top_level_parent_name(self) -> str:
|
||||
"""Get the top-level parent name, or function name if no parents."""
|
||||
return self.parents[0].name if self.parents else self.name
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"FunctionInfo({self.qualified_name} at {self.file_path}:{self.start_line}-{self.end_line})"
|
||||
return FunctionToOptimize
|
||||
msg = f"module {__name__!r} has no attribute {name!r}"
|
||||
raise AttributeError(msg)
|
||||
|
||||
|
||||
@dataclass
|
||||
|
|
@ -237,6 +161,37 @@ class FunctionFilterCriteria:
|
|||
max_lines: int | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class ReferenceInfo:
|
||||
"""Information about a reference (call site) to a function.
|
||||
|
||||
This class captures information about where a function is called
|
||||
from, including the file, line number, context, and caller function.
|
||||
|
||||
Attributes:
|
||||
file_path: Path to the file containing the reference.
|
||||
line: Line number (1-indexed).
|
||||
column: Column number (0-indexed).
|
||||
end_line: End line number (1-indexed).
|
||||
end_column: End column number (0-indexed).
|
||||
context: The line of code containing the reference.
|
||||
reference_type: Type of reference ("call", "callback", "memoized", "import", "reexport").
|
||||
import_name: Name used to import the function (may differ from original).
|
||||
caller_function: Name of the function containing this reference (or None for module-level).
|
||||
|
||||
"""
|
||||
|
||||
file_path: Path
|
||||
line: int
|
||||
column: int
|
||||
end_line: int
|
||||
end_column: int
|
||||
context: str
|
||||
reference_type: str
|
||||
import_name: str | None
|
||||
caller_function: str | None = None
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class LanguageSupport(Protocol):
|
||||
"""Protocol defining what a language implementation must provide.
|
||||
|
|
@ -279,6 +234,11 @@ class LanguageSupport(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@property
|
||||
def default_file_extension(self) -> str:
|
||||
"""Default file extension for this language."""
|
||||
...
|
||||
|
||||
@property
|
||||
def test_framework(self) -> str:
|
||||
"""Primary test framework name.
|
||||
|
|
@ -298,7 +258,7 @@ class LanguageSupport(Protocol):
|
|||
|
||||
def discover_functions(
|
||||
self, file_path: Path, filter_criteria: FunctionFilterCriteria | None = None
|
||||
) -> list[FunctionInfo]:
|
||||
) -> list[FunctionToOptimize]:
|
||||
"""Find all optimizable functions in a file.
|
||||
|
||||
Args:
|
||||
|
|
@ -306,12 +266,14 @@ class LanguageSupport(Protocol):
|
|||
filter_criteria: Optional criteria to filter functions.
|
||||
|
||||
Returns:
|
||||
List of FunctionInfo objects for discovered functions.
|
||||
List of FunctionToOptimize objects for discovered functions.
|
||||
|
||||
"""
|
||||
...
|
||||
|
||||
def discover_tests(self, test_root: Path, source_functions: Sequence[FunctionInfo]) -> dict[str, list[TestInfo]]:
|
||||
def discover_tests(
|
||||
self, test_root: Path, source_functions: Sequence[FunctionToOptimize]
|
||||
) -> dict[str, list[TestInfo]]:
|
||||
"""Map source functions to their tests via static analysis.
|
||||
|
||||
Args:
|
||||
|
|
@ -326,7 +288,7 @@ class LanguageSupport(Protocol):
|
|||
|
||||
# === Code Analysis ===
|
||||
|
||||
def extract_code_context(self, function: FunctionInfo, project_root: Path, module_root: Path) -> CodeContext:
|
||||
def extract_code_context(self, function: FunctionToOptimize, project_root: Path, module_root: Path) -> CodeContext:
|
||||
"""Extract function code and its dependencies.
|
||||
|
||||
Args:
|
||||
|
|
@ -340,7 +302,7 @@ class LanguageSupport(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
def find_helper_functions(self, function: FunctionInfo, project_root: Path) -> list[HelperFunction]:
|
||||
def find_helper_functions(self, function: FunctionToOptimize, project_root: Path) -> list[HelperFunction]:
|
||||
"""Find helper functions called by the target function.
|
||||
|
||||
Args:
|
||||
|
|
@ -353,14 +315,37 @@ class LanguageSupport(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
def find_references(
|
||||
self, function: FunctionToOptimize, project_root: Path, tests_root: Path | None = None, max_files: int = 500
|
||||
) -> list[ReferenceInfo]:
|
||||
"""Find all references (call sites) to a function across the codebase.
|
||||
|
||||
This method finds all places where a function is called, including:
|
||||
- Direct calls
|
||||
- Callbacks (passed to other functions)
|
||||
- Memoized versions
|
||||
- Re-exports
|
||||
|
||||
Args:
|
||||
function: The function to find references for.
|
||||
project_root: Root of the project to search.
|
||||
tests_root: Root of tests directory (references in tests are excluded).
|
||||
max_files: Maximum number of files to search.
|
||||
|
||||
Returns:
|
||||
List of ReferenceInfo objects describing each reference location.
|
||||
|
||||
"""
|
||||
...
|
||||
|
||||
# === Code Transformation ===
|
||||
|
||||
def replace_function(self, source: str, function: FunctionInfo, new_source: str) -> str:
|
||||
def replace_function(self, source: str, function: FunctionToOptimize, new_source: str) -> str:
|
||||
"""Replace a function in source code with new implementation.
|
||||
|
||||
Args:
|
||||
source: Original source code.
|
||||
function: FunctionInfo identifying the function to replace.
|
||||
function: FunctionToOptimize identifying the function to replace.
|
||||
new_source: New function source code.
|
||||
|
||||
Returns:
|
||||
|
|
@ -416,7 +401,7 @@ class LanguageSupport(Protocol):
|
|||
|
||||
# === Instrumentation ===
|
||||
|
||||
def instrument_for_behavior(self, source: str, functions: Sequence[FunctionInfo]) -> str:
|
||||
def instrument_for_behavior(self, source: str, functions: Sequence[FunctionToOptimize]) -> str:
|
||||
"""Add behavior instrumentation to capture inputs/outputs.
|
||||
|
||||
Args:
|
||||
|
|
@ -429,7 +414,7 @@ class LanguageSupport(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
def instrument_for_benchmarking(self, test_source: str, target_function: FunctionInfo) -> str:
|
||||
def instrument_for_benchmarking(self, test_source: str, target_function: FunctionToOptimize) -> str:
|
||||
"""Add timing instrumentation to test code.
|
||||
|
||||
Args:
|
||||
|
|
@ -606,7 +591,9 @@ class LanguageSupport(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
def instrument_source_for_line_profiler(self, func_info: FunctionInfo, line_profiler_output_file: Path) -> bool:
|
||||
def instrument_source_for_line_profiler(
|
||||
self, func_info: FunctionToOptimize, line_profiler_output_file: Path
|
||||
) -> bool:
|
||||
"""Instrument source code before line profiling."""
|
||||
...
|
||||
|
||||
|
|
@ -675,17 +662,14 @@ class LanguageSupport(Protocol):
|
|||
...
|
||||
|
||||
|
||||
def convert_parents_to_tuple(parents: list | tuple) -> tuple[ParentInfo, ...]:
|
||||
"""Convert a list of parent objects to a tuple of ParentInfo.
|
||||
|
||||
This helper handles conversion from the existing FunctionParent
|
||||
dataclass to the new ParentInfo dataclass.
|
||||
def convert_parents_to_tuple(parents: list | tuple) -> tuple[FunctionParent, ...]:
|
||||
"""Convert a list of parent objects to a tuple of FunctionParent.
|
||||
|
||||
Args:
|
||||
parents: List or tuple of parent objects with name and type attributes.
|
||||
|
||||
Returns:
|
||||
Tuple of ParentInfo objects.
|
||||
Tuple of FunctionParent objects.
|
||||
|
||||
"""
|
||||
return tuple(ParentInfo(name=p.name, type=p.type) for p in parents)
|
||||
return tuple(FunctionParent(name=p.name, type=p.type) for p in parents)
|
||||
|
|
|
|||
|
|
@ -11,7 +11,8 @@ import logging
|
|||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from codeflash.languages.base import CodeContext, FunctionInfo, HelperFunction, Language
|
||||
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
|
||||
from codeflash.languages.base import CodeContext, HelperFunction, Language
|
||||
from codeflash.languages.java.discovery import discover_functions_from_source
|
||||
from codeflash.languages.java.import_resolver import JavaImportResolver, find_helper_files
|
||||
from codeflash.languages.java.parser import JavaAnalyzer, JavaClassNode, get_java_analyzer
|
||||
|
|
@ -29,7 +30,7 @@ class InvalidJavaSyntaxError(Exception):
|
|||
|
||||
|
||||
def extract_code_context(
|
||||
function: FunctionInfo,
|
||||
function: FunctionToOptimize,
|
||||
project_root: Path,
|
||||
module_root: Path | None = None,
|
||||
max_helper_depth: int = 2,
|
||||
|
|
@ -83,7 +84,7 @@ def extract_code_context(
|
|||
# This provides necessary context for optimization
|
||||
parent_type_name = _get_parent_type_name(function)
|
||||
if parent_type_name:
|
||||
type_skeleton = _extract_type_skeleton(source, parent_type_name, function.name, analyzer)
|
||||
type_skeleton = _extract_type_skeleton(source, parent_type_name, function.function_name, analyzer)
|
||||
if type_skeleton:
|
||||
target_code = _wrap_method_in_type_skeleton(target_code, type_skeleton)
|
||||
wrapped_in_skeleton = True
|
||||
|
|
@ -107,7 +108,7 @@ def extract_code_context(
|
|||
if validate_syntax and target_code:
|
||||
if not analyzer.validate_syntax(target_code):
|
||||
raise InvalidJavaSyntaxError(
|
||||
f"Extracted code for {function.name} is not syntactically valid Java:\n{target_code}"
|
||||
f"Extracted code for {function.function_name} is not syntactically valid Java:\n{target_code}"
|
||||
)
|
||||
|
||||
return CodeContext(
|
||||
|
|
@ -120,7 +121,7 @@ def extract_code_context(
|
|||
)
|
||||
|
||||
|
||||
def _get_parent_type_name(function: FunctionInfo) -> str | None:
|
||||
def _get_parent_type_name(function: FunctionToOptimize) -> str | None:
|
||||
"""Get the parent type name (class, interface, or enum) for a function.
|
||||
|
||||
Args:
|
||||
|
|
@ -558,7 +559,7 @@ def _wrap_method_in_type_skeleton(method_code: str, skeleton: TypeSkeleton) -> s
|
|||
_wrap_method_in_class_skeleton = _wrap_method_in_type_skeleton
|
||||
|
||||
|
||||
def extract_function_source(source: str, function: FunctionInfo) -> str:
|
||||
def extract_function_source(source: str, function: FunctionToOptimize) -> str:
|
||||
"""Extract the source code of a function from the full file source.
|
||||
|
||||
Args:
|
||||
|
|
@ -572,8 +573,8 @@ def extract_function_source(source: str, function: FunctionInfo) -> str:
|
|||
lines = source.splitlines(keepends=True)
|
||||
|
||||
# Include Javadoc if present
|
||||
start_line = function.doc_start_line or function.start_line
|
||||
end_line = function.end_line
|
||||
start_line = function.doc_start_line or function.starting_line
|
||||
end_line = function.ending_line
|
||||
|
||||
# Convert from 1-indexed to 0-indexed
|
||||
start_idx = start_line - 1
|
||||
|
|
@ -583,7 +584,7 @@ def extract_function_source(source: str, function: FunctionInfo) -> str:
|
|||
|
||||
|
||||
def find_helper_functions(
|
||||
function: FunctionInfo,
|
||||
function: FunctionToOptimize,
|
||||
project_root: Path,
|
||||
max_depth: int = 2,
|
||||
analyzer: JavaAnalyzer | None = None,
|
||||
|
|
@ -624,12 +625,12 @@ def find_helper_functions(
|
|||
|
||||
helpers.append(
|
||||
HelperFunction(
|
||||
name=func.name,
|
||||
name=func.function_name,
|
||||
qualified_name=func.qualified_name,
|
||||
file_path=file_path,
|
||||
source_code=func_source,
|
||||
start_line=func.start_line,
|
||||
end_line=func.end_line,
|
||||
start_line=func.starting_line,
|
||||
end_line=func.ending_line,
|
||||
)
|
||||
)
|
||||
|
||||
|
|
@ -648,7 +649,7 @@ def find_helper_functions(
|
|||
|
||||
|
||||
def _find_same_class_helpers(
|
||||
function: FunctionInfo,
|
||||
function: FunctionToOptimize,
|
||||
analyzer: JavaAnalyzer,
|
||||
) -> list[HelperFunction]:
|
||||
"""Find helper methods in the same class as the target function.
|
||||
|
|
@ -676,7 +677,7 @@ def _find_same_class_helpers(
|
|||
# Find which methods the target function calls
|
||||
target_method = None
|
||||
for method in methods:
|
||||
if method.name == function.name and method.class_name == function.class_name:
|
||||
if method.name == function.function_name and method.class_name == function.class_name:
|
||||
target_method = method
|
||||
break
|
||||
|
||||
|
|
@ -689,7 +690,7 @@ def _find_same_class_helpers(
|
|||
# Add called methods from the same class as helpers
|
||||
for method in methods:
|
||||
if (
|
||||
method.name != function.name
|
||||
method.name != function.function_name
|
||||
and method.class_name == function.class_name
|
||||
and method.name in called_methods
|
||||
):
|
||||
|
|
@ -716,7 +717,7 @@ def _find_same_class_helpers(
|
|||
|
||||
def extract_read_only_context(
|
||||
source: str,
|
||||
function: FunctionInfo,
|
||||
function: FunctionToOptimize,
|
||||
analyzer: JavaAnalyzer,
|
||||
) -> str:
|
||||
"""Extract read-only context (fields, constants, inner classes).
|
||||
|
|
|
|||
|
|
@ -10,13 +10,10 @@ import logging
|
|||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from codeflash.languages.base import (
|
||||
FunctionFilterCriteria,
|
||||
FunctionInfo,
|
||||
Language,
|
||||
ParentInfo,
|
||||
)
|
||||
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
|
||||
from codeflash.languages.base import FunctionFilterCriteria
|
||||
from codeflash.languages.java.parser import JavaAnalyzer, JavaMethodNode, get_java_analyzer
|
||||
from codeflash.models.function_types import FunctionParent
|
||||
|
||||
if TYPE_CHECKING:
|
||||
pass
|
||||
|
|
@ -28,7 +25,7 @@ def discover_functions(
|
|||
file_path: Path,
|
||||
filter_criteria: FunctionFilterCriteria | None = None,
|
||||
analyzer: JavaAnalyzer | None = None,
|
||||
) -> list[FunctionInfo]:
|
||||
) -> list[FunctionToOptimize]:
|
||||
"""Find all optimizable functions/methods in a Java file.
|
||||
|
||||
Uses tree-sitter to parse the file and find methods that can be optimized.
|
||||
|
|
@ -39,7 +36,7 @@ def discover_functions(
|
|||
analyzer: Optional JavaAnalyzer instance (created if not provided).
|
||||
|
||||
Returns:
|
||||
List of FunctionInfo objects for discovered functions.
|
||||
List of FunctionToOptimize objects for discovered functions.
|
||||
|
||||
"""
|
||||
criteria = filter_criteria or FunctionFilterCriteria()
|
||||
|
|
@ -58,7 +55,7 @@ def discover_functions_from_source(
|
|||
file_path: Path | None = None,
|
||||
filter_criteria: FunctionFilterCriteria | None = None,
|
||||
analyzer: JavaAnalyzer | None = None,
|
||||
) -> list[FunctionInfo]:
|
||||
) -> list[FunctionToOptimize]:
|
||||
"""Find all optimizable functions/methods in Java source code.
|
||||
|
||||
Args:
|
||||
|
|
@ -68,7 +65,7 @@ def discover_functions_from_source(
|
|||
analyzer: Optional JavaAnalyzer instance.
|
||||
|
||||
Returns:
|
||||
List of FunctionInfo objects for discovered functions.
|
||||
List of FunctionToOptimize objects for discovered functions.
|
||||
|
||||
"""
|
||||
criteria = filter_criteria or FunctionFilterCriteria()
|
||||
|
|
@ -82,7 +79,7 @@ def discover_functions_from_source(
|
|||
include_static=True,
|
||||
)
|
||||
|
||||
functions: list[FunctionInfo] = []
|
||||
functions: list[FunctionToOptimize] = []
|
||||
|
||||
for method in methods:
|
||||
# Apply filters
|
||||
|
|
@ -90,22 +87,22 @@ def discover_functions_from_source(
|
|||
continue
|
||||
|
||||
# Build parents list
|
||||
parents: list[ParentInfo] = []
|
||||
parents: list[FunctionParent] = []
|
||||
if method.class_name:
|
||||
parents.append(ParentInfo(name=method.class_name, type="ClassDef"))
|
||||
parents.append(FunctionParent(name=method.class_name, type="ClassDef"))
|
||||
|
||||
functions.append(
|
||||
FunctionInfo(
|
||||
name=method.name,
|
||||
FunctionToOptimize(
|
||||
function_name=method.name,
|
||||
file_path=file_path or Path("unknown.java"),
|
||||
start_line=method.start_line,
|
||||
end_line=method.end_line,
|
||||
start_col=method.start_col,
|
||||
end_col=method.end_col,
|
||||
parents=tuple(parents),
|
||||
starting_line=method.start_line,
|
||||
ending_line=method.end_line,
|
||||
starting_col=method.start_col,
|
||||
ending_col=method.end_col,
|
||||
parents=parents,
|
||||
is_async=False, # Java doesn't have async keyword
|
||||
is_method=method.class_name is not None,
|
||||
language=Language.JAVA,
|
||||
language="java",
|
||||
doc_start_line=method.javadoc_start_line,
|
||||
)
|
||||
)
|
||||
|
|
@ -182,7 +179,7 @@ def _should_include_method(
|
|||
def discover_test_methods(
|
||||
file_path: Path,
|
||||
analyzer: JavaAnalyzer | None = None,
|
||||
) -> list[FunctionInfo]:
|
||||
) -> list[FunctionToOptimize]:
|
||||
"""Find all JUnit test methods in a Java test file.
|
||||
|
||||
Looks for methods annotated with @Test, @ParameterizedTest, @RepeatedTest, etc.
|
||||
|
|
@ -192,7 +189,7 @@ def discover_test_methods(
|
|||
analyzer: Optional JavaAnalyzer instance.
|
||||
|
||||
Returns:
|
||||
List of FunctionInfo objects for discovered test methods.
|
||||
List of FunctionToOptimize objects for discovered test methods.
|
||||
|
||||
"""
|
||||
try:
|
||||
|
|
@ -205,7 +202,7 @@ def discover_test_methods(
|
|||
source_bytes = source.encode("utf8")
|
||||
tree = analyzer.parse(source_bytes)
|
||||
|
||||
test_methods: list[FunctionInfo] = []
|
||||
test_methods: list[FunctionToOptimize] = []
|
||||
|
||||
# Find methods with test annotations
|
||||
_walk_tree_for_test_methods(tree.root_node, source_bytes, file_path, test_methods, analyzer, current_class=None)
|
||||
|
|
@ -217,7 +214,7 @@ def _walk_tree_for_test_methods(
|
|||
node,
|
||||
source_bytes: bytes,
|
||||
file_path: Path,
|
||||
test_methods: list[FunctionInfo],
|
||||
test_methods: list[FunctionToOptimize],
|
||||
analyzer: JavaAnalyzer,
|
||||
current_class: str | None,
|
||||
) -> None:
|
||||
|
|
@ -250,22 +247,22 @@ def _walk_tree_for_test_methods(
|
|||
if name_node:
|
||||
method_name = analyzer.get_node_text(name_node, source_bytes)
|
||||
|
||||
parents: list[ParentInfo] = []
|
||||
parents: list[FunctionParent] = []
|
||||
if current_class:
|
||||
parents.append(ParentInfo(name=current_class, type="ClassDef"))
|
||||
parents.append(FunctionParent(name=current_class, type="ClassDef"))
|
||||
|
||||
test_methods.append(
|
||||
FunctionInfo(
|
||||
name=method_name,
|
||||
FunctionToOptimize(
|
||||
function_name=method_name,
|
||||
file_path=file_path,
|
||||
start_line=node.start_point[0] + 1,
|
||||
end_line=node.end_point[0] + 1,
|
||||
start_col=node.start_point[1],
|
||||
end_col=node.end_point[1],
|
||||
parents=tuple(parents),
|
||||
starting_line=node.start_point[0] + 1,
|
||||
ending_line=node.end_point[0] + 1,
|
||||
starting_col=node.start_point[1],
|
||||
ending_col=node.end_point[1],
|
||||
parents=list(parents),
|
||||
is_async=False,
|
||||
is_method=current_class is not None,
|
||||
language=Language.JAVA,
|
||||
language="java",
|
||||
)
|
||||
)
|
||||
|
||||
|
|
@ -285,7 +282,7 @@ def get_method_by_name(
|
|||
method_name: str,
|
||||
class_name: str | None = None,
|
||||
analyzer: JavaAnalyzer | None = None,
|
||||
) -> FunctionInfo | None:
|
||||
) -> FunctionToOptimize | None:
|
||||
"""Find a specific method by name in a Java file.
|
||||
|
||||
Args:
|
||||
|
|
@ -295,13 +292,13 @@ def get_method_by_name(
|
|||
analyzer: Optional JavaAnalyzer instance.
|
||||
|
||||
Returns:
|
||||
FunctionInfo for the method, or None if not found.
|
||||
FunctionToOptimize for the method, or None if not found.
|
||||
|
||||
"""
|
||||
functions = discover_functions(file_path, analyzer=analyzer)
|
||||
|
||||
for func in functions:
|
||||
if func.name == method_name:
|
||||
if func.function_name == method_name:
|
||||
if class_name is None or func.class_name == class_name:
|
||||
return func
|
||||
|
||||
|
|
@ -312,7 +309,7 @@ def get_class_methods(
|
|||
file_path: Path,
|
||||
class_name: str,
|
||||
analyzer: JavaAnalyzer | None = None,
|
||||
) -> list[FunctionInfo]:
|
||||
) -> list[FunctionToOptimize]:
|
||||
"""Get all methods in a specific class.
|
||||
|
||||
Args:
|
||||
|
|
@ -321,7 +318,7 @@ def get_class_methods(
|
|||
analyzer: Optional JavaAnalyzer instance.
|
||||
|
||||
Returns:
|
||||
List of FunctionInfo objects for methods in the class.
|
||||
List of FunctionToOptimize objects for methods in the class.
|
||||
|
||||
"""
|
||||
functions = discover_functions(file_path, analyzer=analyzer)
|
||||
|
|
|
|||
|
|
@ -19,7 +19,7 @@ import re
|
|||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from codeflash.languages.base import FunctionInfo
|
||||
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
|
||||
from codeflash.languages.java.parser import JavaAnalyzer
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
|
@ -30,16 +30,16 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
|
||||
def _get_function_name(func: Any) -> str:
|
||||
"""Get the function name from either FunctionInfo or FunctionToOptimize."""
|
||||
if hasattr(func, "name"):
|
||||
return func.name
|
||||
"""Get the function name from FunctionToOptimize."""
|
||||
if hasattr(func, "function_name"):
|
||||
return func.function_name
|
||||
if hasattr(func, "name"):
|
||||
return func.name
|
||||
raise AttributeError(f"Cannot get function name from {type(func)}")
|
||||
|
||||
|
||||
def _get_qualified_name(func: Any) -> str:
|
||||
"""Get the qualified name from either FunctionInfo or FunctionToOptimize."""
|
||||
"""Get the qualified name from FunctionToOptimize."""
|
||||
if hasattr(func, "qualified_name"):
|
||||
return func.qualified_name
|
||||
# Build qualified name from function_name and parents
|
||||
|
|
@ -56,7 +56,7 @@ def _get_qualified_name(func: Any) -> str:
|
|||
|
||||
def instrument_for_behavior(
|
||||
source: str,
|
||||
functions: Sequence[FunctionInfo],
|
||||
functions: Sequence[FunctionToOptimize],
|
||||
analyzer: JavaAnalyzer | None = None,
|
||||
) -> str:
|
||||
"""Add behavior instrumentation to capture inputs/outputs.
|
||||
|
|
@ -84,7 +84,7 @@ def instrument_for_behavior(
|
|||
|
||||
def instrument_for_benchmarking(
|
||||
test_source: str,
|
||||
target_function: FunctionInfo,
|
||||
target_function: FunctionToOptimize,
|
||||
analyzer: JavaAnalyzer | None = None,
|
||||
) -> str:
|
||||
"""Add timing instrumentation to test code.
|
||||
|
|
@ -109,7 +109,7 @@ def instrument_for_benchmarking(
|
|||
def instrument_existing_test(
|
||||
test_path: Path,
|
||||
call_positions: Sequence,
|
||||
function_to_optimize: Any, # FunctionInfo or FunctionToOptimize
|
||||
function_to_optimize: Any, # FunctionToOptimize or FunctionToOptimize
|
||||
tests_project_root: Path,
|
||||
mode: str, # "behavior" or "performance"
|
||||
analyzer: JavaAnalyzer | None = None,
|
||||
|
|
@ -573,7 +573,7 @@ def _add_timing_instrumentation(source: str, class_name: str, func_name: str) ->
|
|||
|
||||
|
||||
def create_benchmark_test(
|
||||
target_function: FunctionInfo,
|
||||
target_function: FunctionToOptimize,
|
||||
test_setup_code: str,
|
||||
invocation_code: str,
|
||||
iterations: int = 1000,
|
||||
|
|
|
|||
|
|
@ -17,7 +17,7 @@ from dataclasses import dataclass
|
|||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from codeflash.languages.base import FunctionInfo
|
||||
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
|
||||
from codeflash.languages.java.parser import JavaAnalyzer, JavaMethodNode, get_java_analyzer
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
|
@ -213,7 +213,7 @@ def _insert_class_members(
|
|||
|
||||
def replace_function(
|
||||
source: str,
|
||||
function: FunctionInfo,
|
||||
function: FunctionToOptimize,
|
||||
new_source: str,
|
||||
analyzer: JavaAnalyzer | None = None,
|
||||
) -> str:
|
||||
|
|
@ -233,7 +233,7 @@ def replace_function(
|
|||
|
||||
Args:
|
||||
source: Original source code.
|
||||
function: FunctionInfo identifying the function to replace.
|
||||
function: FunctionToOptimize identifying the function to replace.
|
||||
new_source: New function source code (may include class with helpers).
|
||||
analyzer: Optional JavaAnalyzer instance.
|
||||
|
||||
|
|
@ -243,8 +243,12 @@ def replace_function(
|
|||
"""
|
||||
analyzer = analyzer or get_java_analyzer()
|
||||
|
||||
func_name = function.function_name
|
||||
func_start_line = function.starting_line
|
||||
func_end_line = function.ending_line
|
||||
|
||||
# Parse the optimization to extract components
|
||||
parsed = _parse_optimization_source(new_source, function.name, analyzer)
|
||||
parsed = _parse_optimization_source(new_source, func_name, analyzer)
|
||||
|
||||
# Find the method in the original source
|
||||
methods = analyzer.find_methods(source)
|
||||
|
|
@ -254,7 +258,7 @@ def replace_function(
|
|||
# Find all methods matching the name (there may be overloads)
|
||||
matching_methods = [
|
||||
m for m in methods
|
||||
if m.name == function.name
|
||||
if m.name == func_name
|
||||
and (function.class_name is None or m.class_name == function.class_name)
|
||||
]
|
||||
|
||||
|
|
@ -267,18 +271,18 @@ def replace_function(
|
|||
logger.debug(
|
||||
"Found %d overloads of %s. Function start_line=%s, end_line=%s",
|
||||
len(matching_methods),
|
||||
function.name,
|
||||
function.start_line,
|
||||
function.end_line,
|
||||
func_name,
|
||||
func_start_line,
|
||||
func_end_line,
|
||||
)
|
||||
for i, m in enumerate(matching_methods):
|
||||
logger.debug(" Overload %d: lines %d-%d", i, m.start_line, m.end_line)
|
||||
if function.start_line and function.end_line:
|
||||
if func_start_line and func_end_line:
|
||||
for i, method in enumerate(matching_methods):
|
||||
# Check if the line numbers are close (account for minor differences
|
||||
# that can occur due to different parsing or file transformations)
|
||||
# Use a tolerance of 5 lines to handle edge cases
|
||||
if abs(method.start_line - function.start_line) <= 5:
|
||||
if abs(method.start_line - func_start_line) <= 5:
|
||||
target_method = method
|
||||
target_overload_index = i
|
||||
logger.debug(
|
||||
|
|
@ -286,21 +290,21 @@ def replace_function(
|
|||
i,
|
||||
method.start_line,
|
||||
method.end_line,
|
||||
function.start_line,
|
||||
function.end_line,
|
||||
func_start_line,
|
||||
func_end_line,
|
||||
)
|
||||
break
|
||||
if not target_method:
|
||||
# Fallback: use the first match
|
||||
logger.warning(
|
||||
"Multiple overloads of %s found but no line match, using first match",
|
||||
function.name,
|
||||
func_name,
|
||||
)
|
||||
target_method = matching_methods[0]
|
||||
target_overload_index = 0
|
||||
|
||||
if not target_method:
|
||||
logger.error("Could not find method %s in source", function.name)
|
||||
logger.error("Could not find method %s in source", func_name)
|
||||
return source
|
||||
|
||||
# Get the class name for inserting new members
|
||||
|
|
@ -348,7 +352,7 @@ def replace_function(
|
|||
methods = analyzer.find_methods(source)
|
||||
matching_methods = [
|
||||
m for m in methods
|
||||
if m.name == function.name
|
||||
if m.name == func_name
|
||||
and (function.class_name is None or m.class_name == function.class_name)
|
||||
]
|
||||
|
||||
|
|
@ -363,7 +367,7 @@ def replace_function(
|
|||
else:
|
||||
logger.error(
|
||||
"Lost target method %s after adding members (had index %d, found %d overloads)",
|
||||
function.name,
|
||||
func_name,
|
||||
target_overload_index,
|
||||
len(matching_methods),
|
||||
)
|
||||
|
|
@ -457,7 +461,7 @@ def _apply_indentation(lines: list[str], base_indent: str) -> str:
|
|||
|
||||
def replace_method_body(
|
||||
source: str,
|
||||
function: FunctionInfo,
|
||||
function: FunctionToOptimize,
|
||||
new_body: str,
|
||||
analyzer: JavaAnalyzer | None = None,
|
||||
) -> str:
|
||||
|
|
@ -465,7 +469,7 @@ def replace_method_body(
|
|||
|
||||
Args:
|
||||
source: Original source code.
|
||||
function: FunctionInfo identifying the function.
|
||||
function: FunctionToOptimize identifying the function.
|
||||
new_body: New method body (code between braces).
|
||||
analyzer: Optional JavaAnalyzer instance.
|
||||
|
||||
|
|
@ -476,24 +480,26 @@ def replace_method_body(
|
|||
analyzer = analyzer or get_java_analyzer()
|
||||
source_bytes = source.encode("utf8")
|
||||
|
||||
func_name = function.function_name
|
||||
|
||||
# Find the method
|
||||
methods = analyzer.find_methods(source)
|
||||
target_method = None
|
||||
|
||||
for method in methods:
|
||||
if method.name == function.name:
|
||||
if method.name == func_name:
|
||||
if function.class_name is None or method.class_name == function.class_name:
|
||||
target_method = method
|
||||
break
|
||||
|
||||
if not target_method:
|
||||
logger.error("Could not find method %s", function.name)
|
||||
logger.error("Could not find method %s", func_name)
|
||||
return source
|
||||
|
||||
# Find the body node
|
||||
body_node = target_method.node.child_by_field_name("body")
|
||||
if not body_node:
|
||||
logger.error("Method %s has no body (abstract?)", function.name)
|
||||
logger.error("Method %s has no body (abstract?)", func_name)
|
||||
return source
|
||||
|
||||
# Get the body's byte positions
|
||||
|
|
@ -596,14 +602,14 @@ def insert_method(
|
|||
|
||||
def remove_method(
|
||||
source: str,
|
||||
function: FunctionInfo,
|
||||
function: FunctionToOptimize,
|
||||
analyzer: JavaAnalyzer | None = None,
|
||||
) -> str:
|
||||
"""Remove a method from source code.
|
||||
|
||||
Args:
|
||||
source: The source code.
|
||||
function: FunctionInfo identifying the method to remove.
|
||||
function: FunctionToOptimize identifying the method to remove.
|
||||
analyzer: Optional JavaAnalyzer instance.
|
||||
|
||||
Returns:
|
||||
|
|
@ -612,18 +618,20 @@ def remove_method(
|
|||
"""
|
||||
analyzer = analyzer or get_java_analyzer()
|
||||
|
||||
func_name = function.function_name
|
||||
|
||||
# Find the method
|
||||
methods = analyzer.find_methods(source)
|
||||
target_method = None
|
||||
|
||||
for method in methods:
|
||||
if method.name == function.name:
|
||||
if method.name == func_name:
|
||||
if function.class_name is None or method.class_name == function.class_name:
|
||||
target_method = method
|
||||
break
|
||||
|
||||
if not target_method:
|
||||
logger.error("Could not find method %s", function.name)
|
||||
logger.error("Could not find method %s", func_name)
|
||||
return source
|
||||
|
||||
# Determine removal range (include Javadoc)
|
||||
|
|
@ -669,14 +677,15 @@ def remove_test_functions(
|
|||
result = test_source
|
||||
|
||||
for method in methods_to_remove:
|
||||
# Create a FunctionInfo for removal
|
||||
func_info = FunctionInfo(
|
||||
name=method.name,
|
||||
# Create a FunctionToOptimize for removal
|
||||
func_info = FunctionToOptimize(
|
||||
function_name=method.name,
|
||||
file_path=Path("temp.java"),
|
||||
start_line=method.start_line,
|
||||
end_line=method.end_line,
|
||||
parents=(),
|
||||
starting_line=method.start_line,
|
||||
ending_line=method.end_line,
|
||||
parents=[],
|
||||
is_method=True,
|
||||
language="java",
|
||||
)
|
||||
result = remove_method(result, func_info, analyzer)
|
||||
|
||||
|
|
|
|||
|
|
@ -10,10 +10,10 @@ import logging
|
|||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
|
||||
from codeflash.languages.base import (
|
||||
CodeContext,
|
||||
FunctionFilterCriteria,
|
||||
FunctionInfo,
|
||||
HelperFunction,
|
||||
Language,
|
||||
LanguageSupport,
|
||||
|
|
@ -94,18 +94,18 @@ class JavaSupport(LanguageSupport):
|
|||
|
||||
def discover_functions(
|
||||
self, file_path: Path, filter_criteria: FunctionFilterCriteria | None = None
|
||||
) -> list[FunctionInfo]:
|
||||
) -> list[FunctionToOptimize]:
|
||||
"""Find all optimizable functions in a Java file."""
|
||||
return discover_functions(file_path, filter_criteria, self._analyzer)
|
||||
|
||||
def discover_functions_from_source(
|
||||
self, source: str, file_path: Path | None = None, filter_criteria: FunctionFilterCriteria | None = None
|
||||
) -> list[FunctionInfo]:
|
||||
) -> list[FunctionToOptimize]:
|
||||
"""Find all optimizable functions in Java source code."""
|
||||
return discover_functions_from_source(source, file_path, filter_criteria, self._analyzer)
|
||||
|
||||
def discover_tests(
|
||||
self, test_root: Path, source_functions: Sequence[FunctionInfo]
|
||||
self, test_root: Path, source_functions: Sequence[FunctionToOptimize]
|
||||
) -> dict[str, list[TestInfo]]:
|
||||
"""Map source functions to their tests."""
|
||||
return discover_tests(test_root, source_functions, self._analyzer)
|
||||
|
|
@ -113,13 +113,13 @@ class JavaSupport(LanguageSupport):
|
|||
# === Code Analysis ===
|
||||
|
||||
def extract_code_context(
|
||||
self, function: FunctionInfo, project_root: Path, module_root: Path
|
||||
self, function: FunctionToOptimize, project_root: Path, module_root: Path
|
||||
) -> CodeContext:
|
||||
"""Extract function code and its dependencies."""
|
||||
return extract_code_context(function, project_root, module_root, analyzer=self._analyzer)
|
||||
|
||||
def find_helper_functions(
|
||||
self, function: FunctionInfo, project_root: Path
|
||||
self, function: FunctionToOptimize, project_root: Path
|
||||
) -> list[HelperFunction]:
|
||||
"""Find helper functions called by the target function."""
|
||||
return find_helper_functions(function, project_root, analyzer=self._analyzer)
|
||||
|
|
@ -127,7 +127,7 @@ class JavaSupport(LanguageSupport):
|
|||
# === Code Transformation ===
|
||||
|
||||
def replace_function(
|
||||
self, source: str, function: FunctionInfo, new_source: str
|
||||
self, source: str, function: FunctionToOptimize, new_source: str
|
||||
) -> str:
|
||||
"""Replace a function in source code with new implementation."""
|
||||
return replace_function(source, function, new_source, self._analyzer)
|
||||
|
|
@ -156,13 +156,13 @@ class JavaSupport(LanguageSupport):
|
|||
# === Instrumentation ===
|
||||
|
||||
def instrument_for_behavior(
|
||||
self, source: str, functions: Sequence[FunctionInfo]
|
||||
self, source: str, functions: Sequence[FunctionToOptimize]
|
||||
) -> str:
|
||||
"""Add behavior instrumentation to capture inputs/outputs."""
|
||||
return instrument_for_behavior(source, functions, self._analyzer)
|
||||
|
||||
def instrument_for_benchmarking(
|
||||
self, test_source: str, target_function: FunctionInfo
|
||||
self, test_source: str, target_function: FunctionToOptimize
|
||||
) -> str:
|
||||
"""Add timing instrumentation to test code."""
|
||||
return instrument_for_benchmarking(test_source, target_function, self._analyzer)
|
||||
|
|
@ -317,7 +317,7 @@ class JavaSupport(LanguageSupport):
|
|||
)
|
||||
|
||||
def instrument_source_for_line_profiler(
|
||||
self, func_info: FunctionInfo, line_profiler_output_file: Path
|
||||
self, func_info: FunctionToOptimize, line_profiler_output_file: Path
|
||||
) -> bool:
|
||||
"""Instrument source code before line profiling."""
|
||||
# Not yet implemented for Java
|
||||
|
|
|
|||
|
|
@ -12,7 +12,8 @@ from collections import defaultdict
|
|||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from codeflash.languages.base import FunctionInfo, TestInfo
|
||||
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
|
||||
from codeflash.languages.base import TestInfo
|
||||
from codeflash.languages.java.config import detect_java_project
|
||||
from codeflash.languages.java.discovery import discover_test_methods
|
||||
from codeflash.languages.java.parser import JavaAnalyzer, get_java_analyzer
|
||||
|
|
@ -25,7 +26,7 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
def discover_tests(
|
||||
test_root: Path,
|
||||
source_functions: Sequence[FunctionInfo],
|
||||
source_functions: Sequence[FunctionToOptimize],
|
||||
analyzer: JavaAnalyzer | None = None,
|
||||
) -> dict[str, list[TestInfo]]:
|
||||
"""Map source functions to their tests via static analysis.
|
||||
|
|
@ -48,9 +49,9 @@ def discover_tests(
|
|||
analyzer = analyzer or get_java_analyzer()
|
||||
|
||||
# Build a map of function names for quick lookup
|
||||
function_map: dict[str, FunctionInfo] = {}
|
||||
function_map: dict[str, FunctionToOptimize] = {}
|
||||
for func in source_functions:
|
||||
function_map[func.name] = func
|
||||
function_map[func.function_name] = func
|
||||
function_map[func.qualified_name] = func
|
||||
|
||||
# Find all test files (various naming conventions)
|
||||
|
|
@ -77,7 +78,7 @@ def discover_tests(
|
|||
for func_name in matched_functions:
|
||||
result[func_name].append(
|
||||
TestInfo(
|
||||
test_name=test_method.name,
|
||||
test_name=test_method.function_name,
|
||||
test_file=test_file,
|
||||
test_class=test_method.class_name,
|
||||
)
|
||||
|
|
@ -90,9 +91,9 @@ def discover_tests(
|
|||
|
||||
|
||||
def _match_test_to_functions(
|
||||
test_method: FunctionInfo,
|
||||
test_method: FunctionToOptimize,
|
||||
test_source: str,
|
||||
function_map: dict[str, FunctionInfo],
|
||||
function_map: dict[str, FunctionToOptimize],
|
||||
analyzer: JavaAnalyzer,
|
||||
) -> list[str]:
|
||||
"""Match a test method to source functions it might exercise.
|
||||
|
|
@ -100,7 +101,7 @@ def _match_test_to_functions(
|
|||
Args:
|
||||
test_method: The test method.
|
||||
test_source: Full source code of the test file.
|
||||
function_map: Map of function names to FunctionInfo.
|
||||
function_map: Map of function names to FunctionToOptimize.
|
||||
analyzer: JavaAnalyzer instance.
|
||||
|
||||
Returns:
|
||||
|
|
@ -111,10 +112,10 @@ def _match_test_to_functions(
|
|||
|
||||
# Strategy 1: Test method name contains function name
|
||||
# e.g., testAdd -> add, testCalculatorAdd -> Calculator.add
|
||||
test_name_lower = test_method.name.lower()
|
||||
test_name_lower = test_method.function_name.lower()
|
||||
|
||||
for func_name, func_info in function_map.items():
|
||||
if func_info.name.lower() in test_name_lower:
|
||||
if func_info.function_name.lower() in test_name_lower:
|
||||
matched.append(func_info.qualified_name)
|
||||
|
||||
# Strategy 2: Method call analysis
|
||||
|
|
@ -126,8 +127,8 @@ def _match_test_to_functions(
|
|||
method_calls = _find_method_calls_in_range(
|
||||
tree.root_node,
|
||||
source_bytes,
|
||||
test_method.start_line,
|
||||
test_method.end_line,
|
||||
test_method.starting_line,
|
||||
test_method.ending_line,
|
||||
analyzer,
|
||||
)
|
||||
|
||||
|
|
@ -285,7 +286,7 @@ def _find_method_calls_in_range(
|
|||
|
||||
|
||||
def find_tests_for_function(
|
||||
function: FunctionInfo,
|
||||
function: FunctionToOptimize,
|
||||
test_root: Path,
|
||||
analyzer: JavaAnalyzer | None = None,
|
||||
) -> list[TestInfo]:
|
||||
|
|
@ -336,7 +337,7 @@ def get_test_class_for_source_class(
|
|||
def discover_all_tests(
|
||||
test_root: Path,
|
||||
analyzer: JavaAnalyzer | None = None,
|
||||
) -> list[FunctionInfo]:
|
||||
) -> list[FunctionToOptimize]:
|
||||
"""Discover all test methods in a test directory.
|
||||
|
||||
Args:
|
||||
|
|
@ -344,11 +345,11 @@ def discover_all_tests(
|
|||
analyzer: Optional JavaAnalyzer instance.
|
||||
|
||||
Returns:
|
||||
List of FunctionInfo for all test methods.
|
||||
List of FunctionToOptimize for all test methods.
|
||||
|
||||
"""
|
||||
analyzer = analyzer or get_java_analyzer()
|
||||
all_tests: list[FunctionInfo] = []
|
||||
all_tests: list[FunctionToOptimize] = []
|
||||
|
||||
# Find all test files (various naming conventions)
|
||||
test_files = (
|
||||
|
|
@ -408,7 +409,7 @@ def get_test_methods_for_class(
|
|||
test_file: Path,
|
||||
test_class_name: str | None = None,
|
||||
analyzer: JavaAnalyzer | None = None,
|
||||
) -> list[FunctionInfo]:
|
||||
) -> list[FunctionToOptimize]:
|
||||
"""Get all test methods in a specific test class.
|
||||
|
||||
Args:
|
||||
|
|
@ -417,7 +418,7 @@ def get_test_methods_for_class(
|
|||
analyzer: Optional JavaAnalyzer instance.
|
||||
|
||||
Returns:
|
||||
List of FunctionInfo for test methods.
|
||||
List of FunctionToOptimize for test methods.
|
||||
|
||||
"""
|
||||
tests = discover_test_methods(test_file, analyzer)
|
||||
|
|
@ -455,7 +456,7 @@ def build_test_mapping_for_project(
|
|||
# Discover all source functions
|
||||
from codeflash.languages.java.discovery import discover_functions
|
||||
|
||||
source_functions: list[FunctionInfo] = []
|
||||
source_functions: list[FunctionToOptimize] = []
|
||||
for java_file in config.source_root.rglob("*.java"):
|
||||
funcs = discover_functions(java_file, analyzer=analyzer)
|
||||
source_functions.extend(funcs)
|
||||
|
|
|
|||
842
codeflash/languages/javascript/find_references.py
Normal file
842
codeflash/languages/javascript/find_references.py
Normal file
|
|
@ -0,0 +1,842 @@
|
|||
"""Find references for JavaScript/TypeScript functions.
|
||||
|
||||
This module provides functionality to find all references (call sites) of a function
|
||||
across a JavaScript/TypeScript codebase. Similar to Jedi's find_references for Python,
|
||||
this uses tree-sitter to parse and analyze code.
|
||||
|
||||
Key features:
|
||||
- Finds all call sites of a function across multiple files
|
||||
- Handles various import patterns (named, default, namespace, re-exports, aliases)
|
||||
- Supports both ES modules and CommonJS
|
||||
- Handles memoized functions, callbacks, and method calls
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pathlib import Path
|
||||
|
||||
from tree_sitter import Node
|
||||
|
||||
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
|
||||
from codeflash.languages.treesitter_utils import ImportInfo, TreeSitterAnalyzer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Reference:
|
||||
"""Represents a reference (call site) to a function."""
|
||||
|
||||
file_path: Path # File containing the reference
|
||||
line: int # 1-indexed line number
|
||||
column: int # 0-indexed column number
|
||||
end_line: int # 1-indexed end line
|
||||
end_column: int # 0-indexed end column
|
||||
context: str # The line of code containing the reference
|
||||
reference_type: str # Type: "call", "callback", "memoized", "import", "reexport"
|
||||
import_name: str | None # Name used to import the function (may differ from original)
|
||||
caller_function: str | None = None # Name of the function containing this reference
|
||||
|
||||
|
||||
@dataclass
|
||||
class ExportedFunction:
|
||||
"""Represents how a function is exported from its source file."""
|
||||
|
||||
function_name: str # The local function name
|
||||
export_name: str | None # The name it's exported as (may differ)
|
||||
is_default: bool # Whether it's a default export
|
||||
file_path: Path # The source file
|
||||
|
||||
|
||||
@dataclass
|
||||
class ReferenceSearchContext:
|
||||
"""Context for tracking visited files during reference search."""
|
||||
|
||||
visited_files: set[Path] = field(default_factory=set)
|
||||
max_files: int = 1000 # Limit to prevent runaway searches
|
||||
|
||||
|
||||
class ReferenceFinder:
|
||||
"""Finds all references to a function across a JavaScript/TypeScript codebase.
|
||||
|
||||
This class provides functionality similar to Jedi's find_references for Python,
|
||||
but for JavaScript/TypeScript using tree-sitter.
|
||||
|
||||
Example usage:
|
||||
```python
|
||||
from codeflash.languages.javascript.find_references import ReferenceFinder
|
||||
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
|
||||
|
||||
func = FunctionToOptimize(
|
||||
function_name="myHelper", file_path=Path("/my/project/src/utils.ts"), parents=[], language="javascript"
|
||||
)
|
||||
finder = ReferenceFinder(project_root=Path("/my/project"))
|
||||
references = finder.find_references(func)
|
||||
for ref in references:
|
||||
print(f"{ref.file_path}:{ref.line} - {ref.context}")
|
||||
```
|
||||
"""
|
||||
|
||||
# File extensions to search
|
||||
EXTENSIONS = (".ts", ".tsx", ".js", ".jsx", ".mjs", ".cjs")
|
||||
|
||||
def __init__(self, project_root: Path, exclude_patterns: list[str] | None = None) -> None:
|
||||
"""Initialize the ReferenceFinder.
|
||||
|
||||
Args:
|
||||
project_root: Root directory of the project to search.
|
||||
exclude_patterns: Glob patterns of directories/files to exclude.
|
||||
Defaults to ['node_modules', 'dist', 'build', '.git'].
|
||||
|
||||
"""
|
||||
self.project_root = project_root
|
||||
self.exclude_patterns = exclude_patterns or ["node_modules", "dist", "build", ".git", "coverage", "__pycache__"]
|
||||
self._file_cache: dict[Path, str] = {}
|
||||
|
||||
def find_references(
|
||||
self, function_to_optimize: FunctionToOptimize, include_definition: bool = False, max_files: int = 1000
|
||||
) -> list[Reference]:
|
||||
"""Find all references to a function across the project.
|
||||
|
||||
Args:
|
||||
function_to_optimize: The function to find references for.
|
||||
include_definition: Whether to include the function definition itself.
|
||||
max_files: Maximum number of files to search (prevents runaway searches).
|
||||
|
||||
Returns:
|
||||
List of Reference objects describing each call site.
|
||||
|
||||
"""
|
||||
from codeflash.languages.treesitter_utils import get_analyzer_for_file
|
||||
|
||||
function_name = function_to_optimize.function_name
|
||||
source_file = function_to_optimize.file_path
|
||||
|
||||
references: list[Reference] = []
|
||||
context = ReferenceSearchContext(max_files=max_files)
|
||||
|
||||
# Step 1: Analyze how the function is exported from its source file
|
||||
source_code = self._read_file(source_file)
|
||||
if source_code is None:
|
||||
logger.warning("Could not read source file: %s", source_file)
|
||||
return references
|
||||
|
||||
analyzer = get_analyzer_for_file(source_file)
|
||||
exported = self._analyze_exports(function_to_optimize, source_file, source_code, analyzer)
|
||||
|
||||
if not exported:
|
||||
logger.debug("Function %s is not exported from %s", function_name, source_file)
|
||||
# Still search in same file for internal references
|
||||
same_file_refs = self._find_references_in_file(
|
||||
source_file, source_code, function_name, None, analyzer, include_self=not include_definition
|
||||
)
|
||||
references.extend(same_file_refs)
|
||||
return references
|
||||
|
||||
# Step 2: Find all files that might import from the source file
|
||||
context.visited_files.add(source_file)
|
||||
|
||||
# Track files that re-export our function (we'll search for imports to these too)
|
||||
reexport_files: list[tuple[Path, str]] = [] # (file_path, export_name)
|
||||
|
||||
# Step 3: Search all project files for imports and calls
|
||||
# We use a separate set for files checked for re-exports to avoid duplicate work
|
||||
checked_for_reexports: set[Path] = set()
|
||||
|
||||
for file_path in self._iter_project_files():
|
||||
if file_path in context.visited_files:
|
||||
continue
|
||||
if len(context.visited_files) >= context.max_files:
|
||||
logger.warning("Reached max file limit (%d), stopping search", max_files)
|
||||
break
|
||||
|
||||
file_code = self._read_file(file_path)
|
||||
if file_code is None:
|
||||
continue
|
||||
|
||||
file_analyzer = get_analyzer_for_file(file_path)
|
||||
|
||||
# Check if this file imports from the source file
|
||||
imports = file_analyzer.find_imports(file_code)
|
||||
import_info = self._find_matching_import(imports, source_file, file_path, exported)
|
||||
|
||||
if import_info:
|
||||
# Found an import - mark as visited and search for calls
|
||||
context.visited_files.add(file_path)
|
||||
import_name, original_import = import_info
|
||||
file_refs = self._find_references_in_file(
|
||||
file_path, file_code, function_name, import_name, file_analyzer, include_self=True
|
||||
)
|
||||
references.extend(file_refs)
|
||||
|
||||
# Always check for re-exports (even without direct import match)
|
||||
# This handles the case where a file re-exports from our source file
|
||||
if file_path not in checked_for_reexports:
|
||||
checked_for_reexports.add(file_path)
|
||||
reexport_refs = self._find_reexports_direct(file_path, file_code, source_file, exported, file_analyzer)
|
||||
references.extend(reexport_refs)
|
||||
|
||||
# Track re-export files for later searching
|
||||
for ref in reexport_refs:
|
||||
reexport_files.append((file_path, ref.import_name))
|
||||
|
||||
# Step 4: Follow re-export chains to find references through re-exports
|
||||
for reexport_file, reexport_name in reexport_files:
|
||||
# Create a new ExportedFunction for the re-exported function
|
||||
reexported = ExportedFunction(
|
||||
function_name=reexport_name, export_name=reexport_name, is_default=False, file_path=reexport_file
|
||||
)
|
||||
|
||||
# Search for imports to the re-export file
|
||||
for file_path in self._iter_project_files():
|
||||
if file_path in context.visited_files:
|
||||
continue
|
||||
if file_path == reexport_file:
|
||||
continue
|
||||
if len(context.visited_files) >= context.max_files:
|
||||
break
|
||||
|
||||
file_code = self._read_file(file_path)
|
||||
if file_code is None:
|
||||
continue
|
||||
|
||||
file_analyzer = get_analyzer_for_file(file_path)
|
||||
imports = file_analyzer.find_imports(file_code)
|
||||
|
||||
# Check if this file imports from the re-export file
|
||||
import_info = self._find_matching_import(imports, reexport_file, file_path, reexported)
|
||||
|
||||
if import_info:
|
||||
context.visited_files.add(file_path)
|
||||
import_name, original_import = import_info
|
||||
file_refs = self._find_references_in_file(
|
||||
file_path, file_code, reexport_name, import_name, file_analyzer, include_self=True
|
||||
)
|
||||
# Avoid duplicates
|
||||
existing_locs = {(r.file_path, r.line, r.column) for r in references}
|
||||
for ref in file_refs:
|
||||
if (ref.file_path, ref.line, ref.column) not in existing_locs:
|
||||
references.append(ref)
|
||||
|
||||
# Step 5: Include references in the same file (internal calls)
|
||||
if include_definition or not exported:
|
||||
same_file_refs = self._find_references_in_file(
|
||||
source_file, source_code, function_name, None, analyzer, include_self=True
|
||||
)
|
||||
# Filter out duplicate references
|
||||
existing_locs = {(r.file_path, r.line, r.column) for r in references}
|
||||
for ref in same_file_refs:
|
||||
if (ref.file_path, ref.line, ref.column) not in existing_locs:
|
||||
references.append(ref)
|
||||
|
||||
# Step 6: Deduplicate references (same file, line, column)
|
||||
seen: set[tuple[Path, int, int]] = set()
|
||||
unique_refs: list[Reference] = []
|
||||
for ref in references:
|
||||
key = (ref.file_path, ref.line, ref.column)
|
||||
if key not in seen:
|
||||
seen.add(key)
|
||||
unique_refs.append(ref)
|
||||
|
||||
return unique_refs
|
||||
|
||||
def _analyze_exports(
|
||||
self, function_to_optimize: FunctionToOptimize, file_path: Path, source_code: str, analyzer: TreeSitterAnalyzer
|
||||
) -> ExportedFunction | None:
|
||||
"""Analyze how a function is exported from its file.
|
||||
|
||||
For class methods, also checks if the containing class is exported.
|
||||
|
||||
Args:
|
||||
function_to_optimize: The function to check.
|
||||
file_path: Path to the source file.
|
||||
source_code: Source code content.
|
||||
analyzer: TreeSitterAnalyzer instance.
|
||||
|
||||
Returns:
|
||||
ExportedFunction if the function is exported, None otherwise.
|
||||
|
||||
"""
|
||||
function_name = function_to_optimize.function_name
|
||||
class_name = function_to_optimize.class_name
|
||||
is_exported, export_name = analyzer.is_function_exported(source_code, function_name, class_name)
|
||||
|
||||
if not is_exported:
|
||||
return None
|
||||
|
||||
return ExportedFunction(
|
||||
function_name=function_name,
|
||||
export_name=export_name,
|
||||
is_default=(export_name == "default"),
|
||||
file_path=file_path,
|
||||
)
|
||||
|
||||
def _find_matching_import(
|
||||
self, imports: list[ImportInfo], source_file: Path, importing_file: Path, exported: ExportedFunction
|
||||
) -> tuple[str, ImportInfo] | None:
|
||||
"""Find if any import in a file imports the target function.
|
||||
|
||||
Args:
|
||||
imports: List of imports in the file.
|
||||
source_file: Path to the file containing the function definition.
|
||||
importing_file: Path to the file being checked for imports.
|
||||
exported: Information about how the function is exported.
|
||||
|
||||
Returns:
|
||||
Tuple of (imported_name, ImportInfo) if found, None otherwise.
|
||||
|
||||
"""
|
||||
from codeflash.languages.javascript.import_resolver import ImportResolver
|
||||
|
||||
resolver = ImportResolver(self.project_root)
|
||||
|
||||
for imp in imports:
|
||||
# Resolve the import to see if it points to our source file
|
||||
resolved = resolver.resolve_import(imp, importing_file)
|
||||
if resolved is None:
|
||||
continue
|
||||
|
||||
if resolved.file_path != source_file:
|
||||
continue
|
||||
|
||||
# This import is from our source file - check if it imports our function
|
||||
if exported.is_default:
|
||||
# Default export - check default import
|
||||
if imp.default_import:
|
||||
return (imp.default_import, imp)
|
||||
# Also check namespace import
|
||||
if imp.namespace_import:
|
||||
return (f"{imp.namespace_import}.default", imp)
|
||||
else:
|
||||
# Named export - check named imports
|
||||
export_name = exported.export_name or exported.function_name
|
||||
for name, alias in imp.named_imports:
|
||||
if name == export_name:
|
||||
return (alias if alias else name, imp)
|
||||
|
||||
# Check namespace import
|
||||
if imp.namespace_import:
|
||||
return (f"{imp.namespace_import}.{export_name}", imp)
|
||||
|
||||
# Handle CommonJS default import used as namespace
|
||||
# e.g., const helpers = require('./helpers'); helpers.processConfig()
|
||||
# In this case, default_import acts like a namespace
|
||||
if imp.default_import and not imp.named_imports:
|
||||
return (f"{imp.default_import}.{export_name}", imp)
|
||||
|
||||
return None
|
||||
|
||||
def _find_references_in_file(
|
||||
self,
|
||||
file_path: Path,
|
||||
source_code: str,
|
||||
function_name: str,
|
||||
import_name: str | None,
|
||||
analyzer: TreeSitterAnalyzer,
|
||||
include_self: bool = True,
|
||||
) -> list[Reference]:
|
||||
"""Find all references to a function within a single file.
|
||||
|
||||
Args:
|
||||
file_path: Path to the file to search.
|
||||
source_code: Source code content.
|
||||
function_name: Original function name.
|
||||
import_name: Name the function is imported as (may be different).
|
||||
analyzer: TreeSitterAnalyzer instance.
|
||||
include_self: Whether to include references in the file.
|
||||
|
||||
Returns:
|
||||
List of Reference objects.
|
||||
|
||||
"""
|
||||
references: list[Reference] = []
|
||||
source_bytes = source_code.encode("utf8")
|
||||
tree = analyzer.parse(source_bytes)
|
||||
lines = source_code.splitlines()
|
||||
|
||||
# The name to search for (either imported name or original)
|
||||
search_name = import_name if import_name else function_name
|
||||
|
||||
# Handle namespace imports (e.g., "utils.helper")
|
||||
if "." in search_name:
|
||||
namespace, member = search_name.split(".", 1)
|
||||
self._find_member_calls(tree.root_node, source_bytes, lines, file_path, namespace, member, references, None)
|
||||
else:
|
||||
# Find direct calls and other reference types
|
||||
self._find_identifier_references(
|
||||
tree.root_node, source_bytes, lines, file_path, search_name, function_name, references, None
|
||||
)
|
||||
|
||||
return references
|
||||
|
||||
def _find_identifier_references(
|
||||
self,
|
||||
node: Node,
|
||||
source_bytes: bytes,
|
||||
lines: list[str],
|
||||
file_path: Path,
|
||||
search_name: str,
|
||||
original_name: str,
|
||||
references: list[Reference],
|
||||
current_function: str | None,
|
||||
) -> None:
|
||||
"""Recursively find references to an identifier in the AST.
|
||||
|
||||
Args:
|
||||
node: Current tree-sitter node.
|
||||
source_bytes: Source code as bytes.
|
||||
lines: Source code split into lines.
|
||||
file_path: Path to the file.
|
||||
search_name: Name to search for.
|
||||
original_name: Original function name.
|
||||
references: List to append references to.
|
||||
current_function: Name of the containing function (for context).
|
||||
|
||||
"""
|
||||
# Track current function context
|
||||
new_current_function = current_function
|
||||
if node.type in ("function_declaration", "method_definition"):
|
||||
name_node = node.child_by_field_name("name")
|
||||
if name_node:
|
||||
new_current_function = source_bytes[name_node.start_byte : name_node.end_byte].decode("utf8")
|
||||
elif node.type in ("variable_declarator",):
|
||||
# Arrow function or function expression assigned to variable
|
||||
name_node = node.child_by_field_name("name")
|
||||
value_node = node.child_by_field_name("value")
|
||||
if name_node and value_node and value_node.type in ("arrow_function", "function_expression"):
|
||||
new_current_function = source_bytes[name_node.start_byte : name_node.end_byte].decode("utf8")
|
||||
|
||||
# Check for call expressions
|
||||
if node.type == "call_expression":
|
||||
func_node = node.child_by_field_name("function")
|
||||
if func_node and func_node.type == "identifier":
|
||||
name = source_bytes[func_node.start_byte : func_node.end_byte].decode("utf8")
|
||||
if name == search_name:
|
||||
ref = self._create_reference(file_path, func_node, lines, "call", search_name, current_function)
|
||||
references.append(ref)
|
||||
|
||||
# Check for identifiers used as callbacks or passed as arguments
|
||||
elif node.type == "identifier":
|
||||
name = source_bytes[node.start_byte : node.end_byte].decode("utf8")
|
||||
if name == search_name:
|
||||
parent = node.parent
|
||||
# Determine reference type based on context
|
||||
ref_type = self._determine_reference_type(node, parent, source_bytes)
|
||||
if ref_type:
|
||||
ref = self._create_reference(file_path, node, lines, ref_type, search_name, current_function)
|
||||
references.append(ref)
|
||||
|
||||
# Recurse into children
|
||||
for child in node.children:
|
||||
self._find_identifier_references(
|
||||
child, source_bytes, lines, file_path, search_name, original_name, references, new_current_function
|
||||
)
|
||||
|
||||
def _find_member_calls(
|
||||
self,
|
||||
node: Node,
|
||||
source_bytes: bytes,
|
||||
lines: list[str],
|
||||
file_path: Path,
|
||||
namespace: str,
|
||||
member: str,
|
||||
references: list[Reference],
|
||||
current_function: str | None,
|
||||
) -> None:
|
||||
"""Find calls to namespace.member (e.g., utils.helper()).
|
||||
|
||||
Args:
|
||||
node: Current tree-sitter node.
|
||||
source_bytes: Source code as bytes.
|
||||
lines: Source code split into lines.
|
||||
file_path: Path to the file.
|
||||
namespace: The namespace/object name.
|
||||
member: The member/property name.
|
||||
references: List to append references to.
|
||||
current_function: Name of the containing function.
|
||||
|
||||
"""
|
||||
# Track current function context
|
||||
new_current_function = current_function
|
||||
if node.type in ("function_declaration", "method_definition"):
|
||||
name_node = node.child_by_field_name("name")
|
||||
if name_node:
|
||||
new_current_function = source_bytes[name_node.start_byte : name_node.end_byte].decode("utf8")
|
||||
|
||||
# Check for call expressions with member access
|
||||
if node.type == "call_expression":
|
||||
func_node = node.child_by_field_name("function")
|
||||
if func_node and func_node.type == "member_expression":
|
||||
obj_node = func_node.child_by_field_name("object")
|
||||
prop_node = func_node.child_by_field_name("property")
|
||||
|
||||
if obj_node and prop_node:
|
||||
obj_name = source_bytes[obj_node.start_byte : obj_node.end_byte].decode("utf8")
|
||||
prop_name = source_bytes[prop_node.start_byte : prop_node.end_byte].decode("utf8")
|
||||
|
||||
if obj_name == namespace and prop_name == member:
|
||||
ref = self._create_reference(
|
||||
file_path, func_node, lines, "call", f"{namespace}.{member}", current_function
|
||||
)
|
||||
references.append(ref)
|
||||
|
||||
# Also check for member expression used as callback
|
||||
elif node.type == "member_expression":
|
||||
obj_node = node.child_by_field_name("object")
|
||||
prop_node = node.child_by_field_name("property")
|
||||
|
||||
if obj_node and prop_node:
|
||||
obj_name = source_bytes[obj_node.start_byte : obj_node.end_byte].decode("utf8")
|
||||
prop_name = source_bytes[prop_node.start_byte : prop_node.end_byte].decode("utf8")
|
||||
|
||||
if obj_name == namespace and prop_name == member:
|
||||
parent = node.parent
|
||||
if parent and parent.type != "call_expression":
|
||||
ref_type = self._determine_reference_type(node, parent, source_bytes)
|
||||
if ref_type:
|
||||
ref = self._create_reference(
|
||||
file_path, node, lines, ref_type, f"{namespace}.{member}", current_function
|
||||
)
|
||||
references.append(ref)
|
||||
|
||||
# Recurse into children
|
||||
for child in node.children:
|
||||
self._find_member_calls(
|
||||
child, source_bytes, lines, file_path, namespace, member, references, new_current_function
|
||||
)
|
||||
|
||||
def _determine_reference_type(self, node: Node, parent: Node | None, source_bytes: bytes) -> str | None:
|
||||
"""Determine the type of reference based on AST context.
|
||||
|
||||
Args:
|
||||
node: The identifier node.
|
||||
parent: The parent node.
|
||||
source_bytes: Source code as bytes.
|
||||
|
||||
Returns:
|
||||
Reference type string or None if this isn't a valid reference.
|
||||
|
||||
"""
|
||||
if parent is None:
|
||||
return None
|
||||
|
||||
# Skip import statements
|
||||
if parent.type in ("import_specifier", "import_clause", "named_imports"):
|
||||
return None
|
||||
|
||||
# Skip function declarations (the function name itself)
|
||||
if parent.type in ("function_declaration", "method_definition"):
|
||||
name_node = parent.child_by_field_name("name")
|
||||
if name_node and name_node.id == node.id:
|
||||
return None
|
||||
|
||||
# Skip variable declarations where this is being defined
|
||||
if parent.type == "variable_declarator":
|
||||
name_node = parent.child_by_field_name("name")
|
||||
if name_node and name_node.id == node.id:
|
||||
return None
|
||||
|
||||
# Skip export specifiers
|
||||
if parent.type == "export_specifier":
|
||||
return None
|
||||
|
||||
# Check if passed as argument (callback or memoized)
|
||||
if parent.type == "arguments":
|
||||
# Check if grandparent is a memoize call
|
||||
grandparent = parent.parent
|
||||
if grandparent and grandparent.type == "call_expression":
|
||||
func_node = grandparent.child_by_field_name("function")
|
||||
if func_node:
|
||||
func_name = source_bytes[func_node.start_byte : func_node.end_byte].decode("utf8")
|
||||
if any(m in func_name.lower() for m in ["memoize", "memo", "cache"]):
|
||||
return "memoized"
|
||||
return "callback"
|
||||
|
||||
# Check if used in array (often callback patterns)
|
||||
if parent.type == "array":
|
||||
return "callback"
|
||||
|
||||
# Check if passed to memoize/memoization functions (direct call check)
|
||||
if parent.type == "call_expression":
|
||||
func_node = parent.child_by_field_name("function")
|
||||
if func_node:
|
||||
func_name = source_bytes[func_node.start_byte : func_node.end_byte].decode("utf8")
|
||||
if any(m in func_name.lower() for m in ["memoize", "memo", "cache"]):
|
||||
return "memoized"
|
||||
|
||||
# Check if used in a call expression as the function
|
||||
if parent.type == "call_expression":
|
||||
func_node = parent.child_by_field_name("function")
|
||||
if func_node and func_node.id == node.id:
|
||||
return "call"
|
||||
|
||||
# Check if assigned to a property
|
||||
if parent.type in ("pair", "property"):
|
||||
return "property"
|
||||
|
||||
# Check if part of member expression (method call setup)
|
||||
if parent.type == "member_expression":
|
||||
obj_node = parent.child_by_field_name("object")
|
||||
if obj_node and obj_node.id == node.id:
|
||||
# This is the object in obj.method
|
||||
return None # We'll catch the actual call elsewhere
|
||||
|
||||
# Generic reference
|
||||
return "reference"
|
||||
|
||||
def _create_reference(
|
||||
self,
|
||||
file_path: Path,
|
||||
node: Node,
|
||||
lines: list[str],
|
||||
ref_type: str,
|
||||
import_name: str,
|
||||
caller_function: str | None,
|
||||
) -> Reference:
|
||||
"""Create a Reference object from a node.
|
||||
|
||||
Args:
|
||||
file_path: Path to the file.
|
||||
node: The tree-sitter node.
|
||||
lines: Source code lines.
|
||||
ref_type: Type of reference.
|
||||
import_name: Name the function was imported as.
|
||||
caller_function: Name of the containing function.
|
||||
|
||||
Returns:
|
||||
A Reference object.
|
||||
|
||||
"""
|
||||
line_num = node.start_point[0] + 1 # 1-indexed
|
||||
context = lines[node.start_point[0]] if node.start_point[0] < len(lines) else ""
|
||||
|
||||
return Reference(
|
||||
file_path=file_path,
|
||||
line=line_num,
|
||||
column=node.start_point[1],
|
||||
end_line=node.end_point[0] + 1,
|
||||
end_column=node.end_point[1],
|
||||
context=context.strip(),
|
||||
reference_type=ref_type,
|
||||
import_name=import_name,
|
||||
caller_function=caller_function,
|
||||
)
|
||||
|
||||
def _find_reexports(
|
||||
self,
|
||||
file_path: Path,
|
||||
source_code: str,
|
||||
exported: ExportedFunction,
|
||||
analyzer: TreeSitterAnalyzer,
|
||||
context: ReferenceSearchContext,
|
||||
) -> list[Reference]:
|
||||
"""Find re-exports of the function.
|
||||
|
||||
Re-exports look like: export { helper } from './utils'
|
||||
|
||||
Args:
|
||||
file_path: Path to the file being checked.
|
||||
source_code: Source code content.
|
||||
exported: Information about the original export.
|
||||
analyzer: TreeSitterAnalyzer instance.
|
||||
context: Search context.
|
||||
|
||||
Returns:
|
||||
List of Reference objects for re-exports.
|
||||
|
||||
"""
|
||||
references: list[Reference] = []
|
||||
exports = analyzer.find_exports(source_code)
|
||||
lines = source_code.splitlines()
|
||||
|
||||
for exp in exports:
|
||||
if not exp.is_reexport:
|
||||
continue
|
||||
|
||||
# Check if this re-exports our function
|
||||
export_name = exported.export_name or exported.function_name
|
||||
for name, alias in exp.exported_names:
|
||||
if name == export_name:
|
||||
# This is a re-export of our function
|
||||
# Create a reference with the line info from the export
|
||||
context_line = lines[exp.start_line - 1] if exp.start_line <= len(lines) else ""
|
||||
ref = Reference(
|
||||
file_path=file_path,
|
||||
line=exp.start_line,
|
||||
column=0,
|
||||
end_line=exp.end_line,
|
||||
end_column=0,
|
||||
context=context_line.strip(),
|
||||
reference_type="reexport",
|
||||
import_name=alias if alias else name,
|
||||
caller_function=None,
|
||||
)
|
||||
references.append(ref)
|
||||
|
||||
return references
|
||||
|
||||
def _find_reexports_direct(
|
||||
self,
|
||||
file_path: Path,
|
||||
source_code: str,
|
||||
source_file: Path,
|
||||
exported: ExportedFunction,
|
||||
analyzer: TreeSitterAnalyzer,
|
||||
) -> list[Reference]:
|
||||
"""Find re-exports that directly reference our source file.
|
||||
|
||||
This method checks if a file has re-export statements that
|
||||
reference our source file.
|
||||
|
||||
Args:
|
||||
file_path: Path to the file being checked.
|
||||
source_code: Source code content.
|
||||
source_file: The original source file we're looking for references to.
|
||||
exported: Information about the original export.
|
||||
analyzer: TreeSitterAnalyzer instance.
|
||||
|
||||
Returns:
|
||||
List of Reference objects for re-exports.
|
||||
|
||||
"""
|
||||
from codeflash.languages.javascript.import_resolver import ImportResolver
|
||||
|
||||
references: list[Reference] = []
|
||||
exports = analyzer.find_exports(source_code)
|
||||
lines = source_code.splitlines()
|
||||
resolver = ImportResolver(self.project_root)
|
||||
|
||||
for exp in exports:
|
||||
if not exp.is_reexport or not exp.reexport_source:
|
||||
continue
|
||||
|
||||
# Create a fake ImportInfo to resolve the re-export source
|
||||
from codeflash.languages.treesitter_utils import ImportInfo
|
||||
|
||||
fake_import = ImportInfo(
|
||||
module_path=exp.reexport_source,
|
||||
default_import=None,
|
||||
named_imports=[],
|
||||
namespace_import=None,
|
||||
is_type_only=False,
|
||||
start_line=exp.start_line,
|
||||
end_line=exp.end_line,
|
||||
)
|
||||
|
||||
resolved = resolver.resolve_import(fake_import, file_path)
|
||||
if resolved is None or resolved.file_path != source_file:
|
||||
continue
|
||||
|
||||
# This file re-exports from our source file
|
||||
export_name = exported.export_name or exported.function_name
|
||||
for name, alias in exp.exported_names:
|
||||
if name == export_name:
|
||||
context_line = lines[exp.start_line - 1] if exp.start_line <= len(lines) else ""
|
||||
ref = Reference(
|
||||
file_path=file_path,
|
||||
line=exp.start_line,
|
||||
column=0,
|
||||
end_line=exp.end_line,
|
||||
end_column=0,
|
||||
context=context_line.strip(),
|
||||
reference_type="reexport",
|
||||
import_name=alias if alias else name,
|
||||
caller_function=None,
|
||||
)
|
||||
references.append(ref)
|
||||
|
||||
return references
|
||||
|
||||
def _iter_project_files(self) -> list[Path]:
|
||||
"""Iterate over all JavaScript/TypeScript files in the project.
|
||||
|
||||
Returns:
|
||||
List of file paths to search.
|
||||
|
||||
"""
|
||||
files: list[Path] = []
|
||||
|
||||
for ext in self.EXTENSIONS:
|
||||
for file_path in self.project_root.rglob(f"*{ext}"):
|
||||
# Check exclusion patterns
|
||||
if self._should_exclude(file_path):
|
||||
continue
|
||||
files.append(file_path)
|
||||
|
||||
return files
|
||||
|
||||
def _should_exclude(self, file_path: Path) -> bool:
|
||||
"""Check if a file should be excluded from search.
|
||||
|
||||
Args:
|
||||
file_path: Path to check.
|
||||
|
||||
Returns:
|
||||
True if the file should be excluded.
|
||||
|
||||
"""
|
||||
path_str = str(file_path)
|
||||
return any(pattern in path_str for pattern in self.exclude_patterns)
|
||||
|
||||
def _read_file(self, file_path: Path) -> str | None:
|
||||
"""Read a file's contents with caching.
|
||||
|
||||
Args:
|
||||
file_path: Path to the file.
|
||||
|
||||
Returns:
|
||||
File contents or None if unreadable.
|
||||
|
||||
"""
|
||||
if file_path in self._file_cache:
|
||||
return self._file_cache[file_path]
|
||||
|
||||
try:
|
||||
content = file_path.read_text(encoding="utf-8")
|
||||
self._file_cache[file_path] = content
|
||||
return content
|
||||
except Exception as e:
|
||||
logger.debug("Could not read file %s: %s", file_path, e)
|
||||
return None
|
||||
|
||||
|
||||
def find_references(
|
||||
function_to_optimize: FunctionToOptimize, project_root: Path | None = None, max_files: int = 1000
|
||||
) -> list[Reference]:
|
||||
"""Convenience function to find all references to a function.
|
||||
|
||||
This is a simple wrapper around ReferenceFinder for common use cases.
|
||||
|
||||
Args:
|
||||
function_to_optimize: The function to find references for.
|
||||
project_root: Root directory of the project. If None, uses source_file's parent.
|
||||
max_files: Maximum number of files to search.
|
||||
|
||||
Returns:
|
||||
List of Reference objects describing each call site.
|
||||
|
||||
Example:
|
||||
```python
|
||||
from pathlib import Path
|
||||
from codeflash.languages.javascript.find_references import find_references
|
||||
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
|
||||
|
||||
func = FunctionToOptimize(
|
||||
function_name="myHelper", file_path=Path("/my/project/src/utils.ts"), parents=[], language="javascript"
|
||||
)
|
||||
refs = find_references(func, project_root=Path("/my/project"))
|
||||
for ref in refs:
|
||||
print(f"{ref.file_path}:{ref.line}:{ref.column} - {ref.reference_type}")
|
||||
```
|
||||
|
||||
"""
|
||||
if project_root is None:
|
||||
project_root = function_to_optimize.file_path.parent
|
||||
|
||||
finder = ReferenceFinder(project_root)
|
||||
return finder.find_references(function_to_optimize, max_files=max_files)
|
||||
|
|
@ -12,7 +12,8 @@ from pathlib import Path
|
|||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from codeflash.languages.base import FunctionInfo, HelperFunction
|
||||
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
|
||||
from codeflash.languages.base import HelperFunction
|
||||
from codeflash.languages.treesitter_utils import ImportInfo, TreeSitterAnalyzer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -43,7 +44,8 @@ class ImportResolver:
|
|||
project_root: Root directory of the project.
|
||||
|
||||
"""
|
||||
self.project_root = project_root
|
||||
# Resolve to real path to handle macOS symlinks like /var -> /private/var
|
||||
self.project_root = project_root.resolve()
|
||||
self._resolution_cache: dict[tuple[Path, str], Path | None] = {}
|
||||
|
||||
def resolve_import(self, import_info: ImportInfo, source_file: Path) -> ResolvedImport | None:
|
||||
|
|
@ -302,7 +304,7 @@ class MultiFileHelperFinder:
|
|||
|
||||
def find_helpers(
|
||||
self,
|
||||
function: FunctionInfo,
|
||||
function: FunctionToOptimize,
|
||||
source: str,
|
||||
analyzer: TreeSitterAnalyzer,
|
||||
imports: list[ImportInfo],
|
||||
|
|
@ -328,7 +330,7 @@ class MultiFileHelperFinder:
|
|||
all_functions = analyzer.find_functions(source, include_methods=True)
|
||||
target_func = None
|
||||
for func in all_functions:
|
||||
if func.name == function.name and func.start_line == function.start_line:
|
||||
if func.name == function.function_name and func.start_line == function.starting_line:
|
||||
target_func = func
|
||||
break
|
||||
|
||||
|
|
@ -505,7 +507,7 @@ class MultiFileHelperFinder:
|
|||
Dictionary mapping file paths to lists of helper functions.
|
||||
|
||||
"""
|
||||
from codeflash.languages.base import FunctionInfo
|
||||
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
|
||||
from codeflash.languages.treesitter_utils import get_analyzer_for_file
|
||||
|
||||
if context.current_depth >= context.max_depth:
|
||||
|
|
@ -525,9 +527,13 @@ class MultiFileHelperFinder:
|
|||
analyzer = get_analyzer_for_file(file_path)
|
||||
imports = analyzer.find_imports(source)
|
||||
|
||||
# Create FunctionInfo for the helper
|
||||
func_info = FunctionInfo(
|
||||
name=helper.name, file_path=file_path, start_line=helper.start_line, end_line=helper.end_line, parents=()
|
||||
# Create FunctionToOptimize for the helper
|
||||
func_info = FunctionToOptimize(
|
||||
function_name=helper.name,
|
||||
file_path=file_path,
|
||||
parents=[],
|
||||
starting_line=helper.start_line,
|
||||
ending_line=helper.end_line,
|
||||
)
|
||||
|
||||
# Recursively find helpers
|
||||
|
|
|
|||
|
|
@ -72,15 +72,16 @@ class StandaloneCallTransformer:
|
|||
|
||||
"""
|
||||
|
||||
def __init__(self, func_name: str, qualified_name: str, capture_func: str) -> None:
|
||||
self.func_name = func_name
|
||||
self.qualified_name = qualified_name
|
||||
def __init__(self, function_to_optimize: FunctionToOptimize, capture_func: str) -> None:
|
||||
self.function_to_optimize = function_to_optimize
|
||||
self.func_name = function_to_optimize.function_name
|
||||
self.qualified_name = function_to_optimize.qualified_name
|
||||
self.capture_func = capture_func
|
||||
self.invocation_counter = 0
|
||||
# Pattern to match func_name( with optional leading await and optional object prefix
|
||||
# Captures: (whitespace)(await )?(object.)*func_name(
|
||||
# We'll filter out expect() and codeflash. cases in the transform loop
|
||||
self._call_pattern = re.compile(rf"(\s*)(await\s+)?((?:\w+\.)*){re.escape(func_name)}\s*\(")
|
||||
self._call_pattern = re.compile(rf"(\s*)(await\s+)?((?:\w+\.)*){re.escape(self.func_name)}\s*\(")
|
||||
|
||||
def transform(self, code: str) -> str:
|
||||
"""Transform all standalone calls in the code."""
|
||||
|
|
@ -310,7 +311,7 @@ class StandaloneCallTransformer:
|
|||
|
||||
|
||||
def transform_standalone_calls(
|
||||
code: str, func_name: str, qualified_name: str, capture_func: str, start_counter: int = 0
|
||||
code: str, function_to_optimize: FunctionToOptimize, capture_func: str, start_counter: int = 0
|
||||
) -> tuple[str, int]:
|
||||
"""Transform standalone func(...) calls in JavaScript test code.
|
||||
|
||||
|
|
@ -318,8 +319,7 @@ def transform_standalone_calls(
|
|||
|
||||
Args:
|
||||
code: The test code to transform.
|
||||
func_name: Name of the function being tested.
|
||||
qualified_name: Fully qualified function name.
|
||||
function_to_optimize: The function being tested.
|
||||
capture_func: The capture function to use ('capture' or 'capturePerf').
|
||||
start_counter: Starting value for the invocation counter.
|
||||
|
||||
|
|
@ -327,9 +327,7 @@ def transform_standalone_calls(
|
|||
Tuple of (transformed code, final counter value).
|
||||
|
||||
"""
|
||||
transformer = StandaloneCallTransformer(
|
||||
func_name=func_name, qualified_name=qualified_name, capture_func=capture_func
|
||||
)
|
||||
transformer = StandaloneCallTransformer(function_to_optimize=function_to_optimize, capture_func=capture_func)
|
||||
transformer.invocation_counter = start_counter
|
||||
result = transformer.transform(code)
|
||||
return result, transformer.invocation_counter
|
||||
|
|
@ -348,15 +346,18 @@ class ExpectCallTransformer:
|
|||
- Multi-arg assertions: expect(func(args)).toBeCloseTo(0.5, 2)
|
||||
"""
|
||||
|
||||
def __init__(self, func_name: str, qualified_name: str, capture_func: str, remove_assertions: bool = False) -> None:
|
||||
self.func_name = func_name
|
||||
self.qualified_name = qualified_name
|
||||
def __init__(
|
||||
self, function_to_optimize: FunctionToOptimize, capture_func: str, remove_assertions: bool = False
|
||||
) -> None:
|
||||
self.function_to_optimize = function_to_optimize
|
||||
self.func_name = function_to_optimize.function_name
|
||||
self.qualified_name = function_to_optimize.qualified_name
|
||||
self.capture_func = capture_func
|
||||
self.remove_assertions = remove_assertions
|
||||
self.invocation_counter = 0
|
||||
# Pattern to match start of expect((object.)*func_name(
|
||||
# Captures: (whitespace), (object prefix like calc. or this.)
|
||||
self._expect_pattern = re.compile(rf"(\s*)expect\s*\(\s*((?:\w+\.)*){re.escape(func_name)}\s*\(")
|
||||
self._expect_pattern = re.compile(rf"(\s*)expect\s*\(\s*((?:\w+\.)*){re.escape(self.func_name)}\s*\(")
|
||||
|
||||
def transform(self, code: str) -> str:
|
||||
"""Transform all expect calls in the code."""
|
||||
|
|
@ -601,7 +602,7 @@ class ExpectCallTransformer:
|
|||
|
||||
|
||||
def transform_expect_calls(
|
||||
code: str, func_name: str, qualified_name: str, capture_func: str, remove_assertions: bool = False
|
||||
code: str, function_to_optimize: FunctionToOptimize, capture_func: str, remove_assertions: bool = False
|
||||
) -> tuple[str, int]:
|
||||
"""Transform expect(func(...)).assertion() calls in JavaScript test code.
|
||||
|
||||
|
|
@ -609,8 +610,7 @@ def transform_expect_calls(
|
|||
|
||||
Args:
|
||||
code: The test code to transform.
|
||||
func_name: Name of the function being tested.
|
||||
qualified_name: Fully qualified function name.
|
||||
function_to_optimize: The function being tested.
|
||||
capture_func: The capture function to use ('capture' or 'capturePerf').
|
||||
remove_assertions: If True, remove assertions entirely (for generated tests).
|
||||
|
||||
|
|
@ -619,10 +619,7 @@ def transform_expect_calls(
|
|||
|
||||
"""
|
||||
transformer = ExpectCallTransformer(
|
||||
func_name=func_name,
|
||||
qualified_name=qualified_name,
|
||||
capture_func=capture_func,
|
||||
remove_assertions=remove_assertions,
|
||||
function_to_optimize=function_to_optimize, capture_func=capture_func, remove_assertions=remove_assertions
|
||||
)
|
||||
result = transformer.transform(code)
|
||||
return result, transformer.invocation_counter
|
||||
|
|
@ -658,8 +655,6 @@ def inject_profiling_into_existing_js_test(
|
|||
logger.error(f"Failed to read test file {test_path}: {e}")
|
||||
return False, None
|
||||
|
||||
func_name = function_to_optimize.function_name
|
||||
|
||||
# Get the relative path for test identification
|
||||
try:
|
||||
rel_path = test_path.relative_to(tests_project_root)
|
||||
|
|
@ -667,14 +662,12 @@ def inject_profiling_into_existing_js_test(
|
|||
rel_path = test_path
|
||||
|
||||
# Check if the function is imported/required in this test file
|
||||
if not _is_function_used_in_test(test_code, func_name):
|
||||
logger.debug(f"Function '{func_name}' not found in test file {test_path}")
|
||||
if not _is_function_used_in_test(test_code, function_to_optimize.function_name):
|
||||
logger.debug(f"Function '{function_to_optimize.function_name}' not found in test file {test_path}")
|
||||
return False, None
|
||||
|
||||
# Instrument the test code
|
||||
instrumented_code = _instrument_js_test_code(
|
||||
test_code, func_name, str(rel_path), mode, function_to_optimize.qualified_name
|
||||
)
|
||||
instrumented_code = _instrument_js_test_code(test_code, function_to_optimize, str(rel_path), mode)
|
||||
|
||||
if instrumented_code == test_code:
|
||||
logger.debug(f"No changes made to test file {test_path}")
|
||||
|
|
@ -716,16 +709,15 @@ def _is_function_used_in_test(code: str, func_name: str) -> bool:
|
|||
|
||||
|
||||
def _instrument_js_test_code(
|
||||
code: str, func_name: str, test_file_path: str, mode: str, qualified_name: str, remove_assertions: bool = False
|
||||
code: str, function_to_optimize: FunctionToOptimize, test_file_path: str, mode: str, remove_assertions: bool = False
|
||||
) -> str:
|
||||
"""Instrument JavaScript test code with profiling capture calls.
|
||||
|
||||
Args:
|
||||
code: Original test code.
|
||||
func_name: Name of the function to instrument.
|
||||
function_to_optimize: The function to instrument.
|
||||
test_file_path: Relative path to test file.
|
||||
mode: Testing mode (behavior or performance).
|
||||
qualified_name: Fully qualified function name.
|
||||
remove_assertions: If True, remove expect assertions entirely (for generated/regression tests).
|
||||
If False, keep the expect wrapper (for existing user-written tests).
|
||||
|
||||
|
|
@ -771,8 +763,7 @@ def _instrument_js_test_code(
|
|||
# Transform expect calls using the refactored transformer
|
||||
code, expect_counter = transform_expect_calls(
|
||||
code=code,
|
||||
func_name=func_name,
|
||||
qualified_name=qualified_name,
|
||||
function_to_optimize=function_to_optimize,
|
||||
capture_func=capture_func,
|
||||
remove_assertions=remove_assertions,
|
||||
)
|
||||
|
|
@ -780,11 +771,7 @@ def _instrument_js_test_code(
|
|||
# Transform standalone calls (not inside expect wrappers)
|
||||
# Continue counter from expect transformer to ensure unique IDs
|
||||
code, _final_counter = transform_standalone_calls(
|
||||
code=code,
|
||||
func_name=func_name,
|
||||
qualified_name=qualified_name,
|
||||
capture_func=capture_func,
|
||||
start_counter=expect_counter,
|
||||
code=code, function_to_optimize=function_to_optimize, capture_func=capture_func, start_counter=expect_counter
|
||||
)
|
||||
|
||||
return code
|
||||
|
|
@ -941,7 +928,7 @@ def get_instrumented_test_path(original_path: Path, mode: str) -> Path:
|
|||
|
||||
|
||||
def instrument_generated_js_test(
|
||||
test_code: str, function_name: str, qualified_name: str, mode: str = TestingMode.BEHAVIOR
|
||||
test_code: str, function_to_optimize: FunctionToOptimize, mode: str = TestingMode.BEHAVIOR
|
||||
) -> str:
|
||||
"""Instrument generated JavaScript/TypeScript test code.
|
||||
|
||||
|
|
@ -956,8 +943,7 @@ def instrument_generated_js_test(
|
|||
|
||||
Args:
|
||||
test_code: The generated test code to instrument.
|
||||
function_name: Name of the function being tested.
|
||||
qualified_name: Fully qualified function name (e.g., 'module.funcName').
|
||||
function_to_optimize: The function being tested.
|
||||
mode: Testing mode - "behavior" or "performance".
|
||||
|
||||
Returns:
|
||||
|
|
@ -971,9 +957,8 @@ def instrument_generated_js_test(
|
|||
# Generated tests are treated as regression tests, so we remove LLM-generated assertions
|
||||
return _instrument_js_test_code(
|
||||
code=test_code,
|
||||
func_name=function_name,
|
||||
function_to_optimize=function_to_optimize,
|
||||
test_file_path="generated_test",
|
||||
mode=mode,
|
||||
qualified_name=qualified_name,
|
||||
remove_assertions=True,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -16,7 +16,7 @@ from codeflash.languages.treesitter_utils import get_analyzer_for_file
|
|||
if TYPE_CHECKING:
|
||||
from pathlib import Path
|
||||
|
||||
from codeflash.languages.base import FunctionInfo
|
||||
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -40,7 +40,7 @@ class JavaScriptLineProfiler:
|
|||
self.output_file = output_file
|
||||
self.profiler_var = "__codeflash_line_profiler__"
|
||||
|
||||
def instrument_source(self, source: str, file_path: Path, functions: list[FunctionInfo]) -> str:
|
||||
def instrument_source(self, source: str, file_path: Path, functions: list[FunctionToOptimize]) -> str:
|
||||
"""Instrument JavaScript source code with line profiling.
|
||||
|
||||
Adds profiling instrumentation to track line-level execution for the
|
||||
|
|
@ -65,10 +65,10 @@ class JavaScriptLineProfiler:
|
|||
lines = source.splitlines(keepends=True)
|
||||
|
||||
# Process functions in reverse order to preserve line numbers
|
||||
for func in sorted(functions, key=lambda f: f.start_line, reverse=True):
|
||||
for func in sorted(functions, key=lambda f: f.starting_line, reverse=True):
|
||||
func_lines = self._instrument_function(func, lines, file_path)
|
||||
start_idx = func.start_line - 1
|
||||
end_idx = func.end_line
|
||||
start_idx = func.starting_line - 1
|
||||
end_idx = func.ending_line
|
||||
lines = lines[:start_idx] + func_lines + lines[end_idx:]
|
||||
|
||||
instrumented_source = "".join(lines)
|
||||
|
|
@ -171,7 +171,7 @@ const __codeflash_save_interval__ = setInterval(() => {self.profiler_var}.save()
|
|||
if (__codeflash_save_interval__.unref) __codeflash_save_interval__.unref(); // Don't keep process alive
|
||||
"""
|
||||
|
||||
def _instrument_function(self, func: FunctionInfo, lines: list[str], file_path: Path) -> list[str]:
|
||||
def _instrument_function(self, func: FunctionToOptimize, lines: list[str], file_path: Path) -> list[str]:
|
||||
"""Instrument a single function with line profiling.
|
||||
|
||||
Args:
|
||||
|
|
@ -183,7 +183,7 @@ if (__codeflash_save_interval__.unref) __codeflash_save_interval__.unref(); // D
|
|||
Instrumented function lines.
|
||||
|
||||
"""
|
||||
func_lines = lines[func.start_line - 1 : func.end_line]
|
||||
func_lines = lines[func.starting_line - 1 : func.ending_line]
|
||||
instrumented_lines = []
|
||||
|
||||
# Parse the function to find executable lines
|
||||
|
|
@ -194,7 +194,7 @@ if (__codeflash_save_interval__.unref) __codeflash_save_interval__.unref(); // D
|
|||
tree = analyzer.parse(source.encode("utf8"))
|
||||
executable_lines = self._find_executable_lines(tree.root_node, source.encode("utf8"))
|
||||
except Exception as e:
|
||||
logger.warning("Failed to parse function %s: %s", func.name, e)
|
||||
logger.warning("Failed to parse function %s: %s", func.function_name, e)
|
||||
return func_lines
|
||||
|
||||
# Add profiling to each executable line
|
||||
|
|
@ -203,7 +203,7 @@ if (__codeflash_save_interval__.unref) __codeflash_save_interval__.unref(); // D
|
|||
|
||||
for local_idx, line in enumerate(func_lines):
|
||||
local_line_num = local_idx + 1 # 1-indexed within function
|
||||
global_line_num = func.start_line + local_idx # Global line number in original file
|
||||
global_line_num = func.starting_line + local_idx # Global line number in original file
|
||||
stripped = line.strip()
|
||||
|
||||
# Add enterFunction() call after the opening brace of the function
|
||||
|
|
|
|||
|
|
@ -184,10 +184,18 @@ def _get_relative_import_path(target_path: Path, source_path: Path) -> str:
|
|||
|
||||
|
||||
def add_js_extension(module_path: str) -> str:
|
||||
"""Add .js extension to relative module paths for ESM compatibility."""
|
||||
if module_path.startswith(("./", "../")):
|
||||
if not module_path.endswith(".js") and not module_path.endswith(".mjs"):
|
||||
return module_path + ".js"
|
||||
"""Process module path for ESM compatibility.
|
||||
|
||||
NOTE: This function intentionally does NOT add extensions because:
|
||||
1. TypeScript projects resolve modules without explicit extensions
|
||||
2. Adding .js to .ts imports causes "Cannot find module" errors
|
||||
3. Modern bundlers (webpack, vite, etc.) handle extension resolution automatically
|
||||
|
||||
The function name is preserved for backward compatibility but the behavior
|
||||
has been changed to NOT add extensions.
|
||||
"""
|
||||
# Previously this function added .js extensions, but this caused module resolution
|
||||
# errors in TypeScript projects. We now preserve paths without adding extensions.
|
||||
return module_path
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -12,22 +12,16 @@ import xml.etree.ElementTree as ET
|
|||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from codeflash.languages.base import (
|
||||
CodeContext,
|
||||
FunctionFilterCriteria,
|
||||
FunctionInfo,
|
||||
HelperFunction,
|
||||
Language,
|
||||
ParentInfo,
|
||||
TestInfo,
|
||||
TestResult,
|
||||
)
|
||||
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
|
||||
from codeflash.languages.base import CodeContext, FunctionFilterCriteria, HelperFunction, Language, TestInfo, TestResult
|
||||
from codeflash.languages.registry import register_language
|
||||
from codeflash.languages.treesitter_utils import TreeSitterAnalyzer, TreeSitterLanguage, get_analyzer_for_file
|
||||
from codeflash.models.models import FunctionParent
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Sequence
|
||||
|
||||
from codeflash.languages.base import ReferenceInfo
|
||||
from codeflash.languages.treesitter_utils import TypeDefinition
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -53,6 +47,11 @@ class JavaScriptSupport:
|
|||
"""File extensions supported by JavaScript."""
|
||||
return (".js", ".jsx", ".mjs", ".cjs")
|
||||
|
||||
@property
|
||||
def default_file_extension(self) -> str:
|
||||
"""Default file extension for JavaScript."""
|
||||
return ".js"
|
||||
|
||||
@property
|
||||
def test_framework(self) -> str:
|
||||
"""Primary test framework for JavaScript."""
|
||||
|
|
@ -66,7 +65,7 @@ class JavaScriptSupport:
|
|||
|
||||
def discover_functions(
|
||||
self, file_path: Path, filter_criteria: FunctionFilterCriteria | None = None
|
||||
) -> list[FunctionInfo]:
|
||||
) -> list[FunctionToOptimize]:
|
||||
"""Find all optimizable functions in a JavaScript file.
|
||||
|
||||
Uses tree-sitter to parse the file and find functions.
|
||||
|
|
@ -76,7 +75,7 @@ class JavaScriptSupport:
|
|||
filter_criteria: Optional criteria to filter functions.
|
||||
|
||||
Returns:
|
||||
List of FunctionInfo objects for discovered functions.
|
||||
List of FunctionToOptimize objects for discovered functions.
|
||||
|
||||
"""
|
||||
criteria = filter_criteria or FunctionFilterCriteria()
|
||||
|
|
@ -93,7 +92,7 @@ class JavaScriptSupport:
|
|||
source, include_methods=criteria.include_methods, include_arrow_functions=True, require_name=True
|
||||
)
|
||||
|
||||
functions: list[FunctionInfo] = []
|
||||
functions: list[FunctionToOptimize] = []
|
||||
for func in tree_functions:
|
||||
# Check for return statement if required
|
||||
if criteria.require_return and not analyzer.has_return_statement(func, source):
|
||||
|
|
@ -104,24 +103,24 @@ class JavaScriptSupport:
|
|||
continue
|
||||
|
||||
# Build parents list
|
||||
parents: list[ParentInfo] = []
|
||||
parents: list[FunctionParent] = []
|
||||
if func.class_name:
|
||||
parents.append(ParentInfo(name=func.class_name, type="ClassDef"))
|
||||
parents.append(FunctionParent(name=func.class_name, type="ClassDef"))
|
||||
if func.parent_function:
|
||||
parents.append(ParentInfo(name=func.parent_function, type="FunctionDef"))
|
||||
parents.append(FunctionParent(name=func.parent_function, type="FunctionDef"))
|
||||
|
||||
functions.append(
|
||||
FunctionInfo(
|
||||
name=func.name,
|
||||
FunctionToOptimize(
|
||||
function_name=func.name,
|
||||
file_path=file_path,
|
||||
start_line=func.start_line,
|
||||
end_line=func.end_line,
|
||||
start_col=func.start_col,
|
||||
end_col=func.end_col,
|
||||
parents=tuple(parents),
|
||||
parents=parents,
|
||||
starting_line=func.start_line,
|
||||
ending_line=func.end_line,
|
||||
starting_col=func.start_col,
|
||||
ending_col=func.end_col,
|
||||
is_async=func.is_async,
|
||||
is_method=func.is_method,
|
||||
language=self.language,
|
||||
language=str(self.language),
|
||||
doc_start_line=func.doc_start_line,
|
||||
)
|
||||
)
|
||||
|
|
@ -132,7 +131,7 @@ class JavaScriptSupport:
|
|||
logger.warning("Failed to parse %s: %s", file_path, e)
|
||||
return []
|
||||
|
||||
def discover_functions_from_source(self, source: str, file_path: Path | None = None) -> list[FunctionInfo]:
|
||||
def discover_functions_from_source(self, source: str, file_path: Path | None = None) -> list[FunctionToOptimize]:
|
||||
"""Find all functions in source code string.
|
||||
|
||||
Uses tree-sitter to parse the source and find functions.
|
||||
|
|
@ -142,7 +141,7 @@ class JavaScriptSupport:
|
|||
file_path: Optional file path for context (used for language detection).
|
||||
|
||||
Returns:
|
||||
List of FunctionInfo objects for discovered functions.
|
||||
List of FunctionToOptimize objects for discovered functions.
|
||||
|
||||
"""
|
||||
try:
|
||||
|
|
@ -156,27 +155,27 @@ class JavaScriptSupport:
|
|||
source, include_methods=True, include_arrow_functions=True, require_name=True
|
||||
)
|
||||
|
||||
functions: list[FunctionInfo] = []
|
||||
functions: list[FunctionToOptimize] = []
|
||||
for func in tree_functions:
|
||||
# Build parents list
|
||||
parents: list[ParentInfo] = []
|
||||
parents: list[FunctionParent] = []
|
||||
if func.class_name:
|
||||
parents.append(ParentInfo(name=func.class_name, type="ClassDef"))
|
||||
parents.append(FunctionParent(name=func.class_name, type="ClassDef"))
|
||||
if func.parent_function:
|
||||
parents.append(ParentInfo(name=func.parent_function, type="FunctionDef"))
|
||||
parents.append(FunctionParent(name=func.parent_function, type="FunctionDef"))
|
||||
|
||||
functions.append(
|
||||
FunctionInfo(
|
||||
name=func.name,
|
||||
FunctionToOptimize(
|
||||
function_name=func.name,
|
||||
file_path=file_path or Path("unknown"),
|
||||
start_line=func.start_line,
|
||||
end_line=func.end_line,
|
||||
start_col=func.start_col,
|
||||
end_col=func.end_col,
|
||||
parents=tuple(parents),
|
||||
parents=parents,
|
||||
starting_line=func.start_line,
|
||||
ending_line=func.end_line,
|
||||
starting_col=func.start_col,
|
||||
ending_col=func.end_col,
|
||||
is_async=func.is_async,
|
||||
is_method=func.is_method,
|
||||
language=self.language,
|
||||
language=str(self.language),
|
||||
doc_start_line=func.doc_start_line,
|
||||
)
|
||||
)
|
||||
|
|
@ -198,7 +197,9 @@ class JavaScriptSupport:
|
|||
"""
|
||||
return ["*.test.js", "*.test.jsx", "*.spec.js", "*.spec.jsx", "__tests__/**/*.js", "__tests__/**/*.jsx"]
|
||||
|
||||
def discover_tests(self, test_root: Path, source_functions: Sequence[FunctionInfo]) -> dict[str, list[TestInfo]]:
|
||||
def discover_tests(
|
||||
self, test_root: Path, source_functions: Sequence[FunctionToOptimize]
|
||||
) -> dict[str, list[TestInfo]]:
|
||||
"""Map source functions to their tests via static analysis.
|
||||
|
||||
For JavaScript, this uses static analysis to find test files
|
||||
|
|
@ -240,7 +241,7 @@ class JavaScriptSupport:
|
|||
|
||||
# Match source functions to tests
|
||||
for func in source_functions:
|
||||
if func.name in imported_names or func.name in source:
|
||||
if func.function_name in imported_names or func.function_name in source:
|
||||
if func.qualified_name not in result:
|
||||
result[func.qualified_name] = []
|
||||
for test_name in test_functions:
|
||||
|
|
@ -282,7 +283,7 @@ class JavaScriptSupport:
|
|||
|
||||
# === Code Analysis ===
|
||||
|
||||
def extract_code_context(self, function: FunctionInfo, project_root: Path, module_root: Path) -> CodeContext:
|
||||
def extract_code_context(self, function: FunctionToOptimize, project_root: Path, module_root: Path) -> CodeContext:
|
||||
"""Extract function code and its dependencies.
|
||||
|
||||
Uses tree-sitter to analyze imports and find helper functions.
|
||||
|
|
@ -309,16 +310,16 @@ class JavaScriptSupport:
|
|||
tree_functions = analyzer.find_functions(source, include_methods=True, include_arrow_functions=True)
|
||||
target_func = None
|
||||
for func in tree_functions:
|
||||
if func.name == function.name and func.start_line == function.start_line:
|
||||
if func.name == function.function_name and func.start_line == function.starting_line:
|
||||
target_func = func
|
||||
break
|
||||
|
||||
# Extract the function source, including JSDoc if present
|
||||
lines = source.splitlines(keepends=True)
|
||||
if function.start_line and function.end_line:
|
||||
if function.starting_line and function.ending_line:
|
||||
# Use doc_start_line if available, otherwise fall back to start_line
|
||||
effective_start = (target_func.doc_start_line if target_func else None) or function.start_line
|
||||
target_lines = lines[effective_start - 1 : function.end_line]
|
||||
effective_start = (target_func.doc_start_line if target_func else None) or function.starting_line
|
||||
target_lines = lines[effective_start - 1 : function.ending_line]
|
||||
target_code = "".join(target_lines)
|
||||
else:
|
||||
target_code = ""
|
||||
|
|
@ -334,7 +335,7 @@ class JavaScriptSupport:
|
|||
|
||||
if class_name:
|
||||
# Find the class definition in the source to get proper indentation, JSDoc, constructor, and fields
|
||||
class_info = self._find_class_definition(source, class_name, analyzer, function.name)
|
||||
class_info = self._find_class_definition(source, class_name, analyzer, function.function_name)
|
||||
if class_info:
|
||||
class_jsdoc, class_indent, constructor_code, fields_code = class_info
|
||||
# Build the class body with fields, constructor, and target method
|
||||
|
|
@ -395,7 +396,7 @@ class JavaScriptSupport:
|
|||
# If not, raise an error to fail the optimization early
|
||||
if target_code and not self.validate_syntax(target_code):
|
||||
error_msg = (
|
||||
f"Extracted code for {function.name} is not syntactically valid JavaScript. "
|
||||
f"Extracted code for {function.function_name} is not syntactically valid JavaScript. "
|
||||
f"Cannot proceed with optimization."
|
||||
)
|
||||
logger.error(error_msg)
|
||||
|
|
@ -544,7 +545,12 @@ class JavaScriptSupport:
|
|||
return (constructor_code, fields_code)
|
||||
|
||||
def _find_helper_functions(
|
||||
self, function: FunctionInfo, source: str, analyzer: TreeSitterAnalyzer, imports: list[Any], module_root: Path
|
||||
self,
|
||||
function: FunctionToOptimize,
|
||||
source: str,
|
||||
analyzer: TreeSitterAnalyzer,
|
||||
imports: list[Any],
|
||||
module_root: Path,
|
||||
) -> list[HelperFunction]:
|
||||
"""Find helper functions called by the target function.
|
||||
|
||||
|
|
@ -569,7 +575,7 @@ class JavaScriptSupport:
|
|||
# Find the target function's tree-sitter node
|
||||
target_func = None
|
||||
for func in all_functions:
|
||||
if func.name == function.name and func.start_line == function.start_line:
|
||||
if func.name == function.function_name and func.start_line == function.starting_line:
|
||||
target_func = func
|
||||
break
|
||||
|
||||
|
|
@ -585,7 +591,7 @@ class JavaScriptSupport:
|
|||
|
||||
# Match calls to functions in the same file
|
||||
for func in all_functions:
|
||||
if func.name in calls_set and func.name != function.name:
|
||||
if func.name in calls_set and func.name != function.function_name:
|
||||
# Extract source including JSDoc if present
|
||||
effective_start = func.doc_start_line or func.start_line
|
||||
helper_lines = lines[effective_start - 1 : func.end_line]
|
||||
|
|
@ -715,7 +721,12 @@ class JavaScriptSupport:
|
|||
return "\n".join(global_lines)
|
||||
|
||||
def _extract_type_definitions_context(
|
||||
self, function: FunctionInfo, source: str, analyzer: TreeSitterAnalyzer, imports: list[Any], module_root: Path
|
||||
self,
|
||||
function: FunctionToOptimize,
|
||||
source: str,
|
||||
analyzer: TreeSitterAnalyzer,
|
||||
imports: list[Any],
|
||||
module_root: Path,
|
||||
) -> tuple[str, set[str]]:
|
||||
"""Extract type definitions used by the function for read-only context.
|
||||
|
||||
|
|
@ -741,7 +752,7 @@ class JavaScriptSupport:
|
|||
|
||||
"""
|
||||
# Extract type names from function parameters and return type
|
||||
type_names = analyzer.extract_type_annotations(source, function.name, function.start_line or 1)
|
||||
type_names = analyzer.extract_type_annotations(source, function.function_name, function.starting_line or 1)
|
||||
|
||||
# If this is a class method, also extract types from class fields
|
||||
if function.is_method and function.parents:
|
||||
|
|
@ -939,7 +950,7 @@ class JavaScriptSupport:
|
|||
|
||||
return found_definitions
|
||||
|
||||
def find_helper_functions(self, function: FunctionInfo, project_root: Path) -> list[HelperFunction]:
|
||||
def find_helper_functions(self, function: FunctionToOptimize, project_root: Path) -> list[HelperFunction]:
|
||||
"""Find helper functions called by the target function.
|
||||
|
||||
Args:
|
||||
|
|
@ -956,12 +967,68 @@ class JavaScriptSupport:
|
|||
imports = analyzer.find_imports(source)
|
||||
return self._find_helper_functions(function, source, analyzer, imports, project_root)
|
||||
except Exception as e:
|
||||
logger.warning("Failed to find helpers for %s: %s", function.name, e)
|
||||
logger.warning("Failed to find helpers for %s: %s", function.function_name, e)
|
||||
return []
|
||||
|
||||
def find_references(
|
||||
self, function: FunctionToOptimize, project_root: Path, tests_root: Path | None = None, max_files: int = 500
|
||||
) -> list[ReferenceInfo]:
|
||||
"""Find all references (call sites) to a function across the codebase.
|
||||
|
||||
Uses tree-sitter to find all places where a JavaScript/TypeScript function
|
||||
is called, including direct calls, callbacks, memoized versions, and re-exports.
|
||||
|
||||
Args:
|
||||
function: The function to find references for.
|
||||
project_root: Root of the project to search.
|
||||
tests_root: Root of tests directory (references in tests are excluded).
|
||||
max_files: Maximum number of files to search.
|
||||
|
||||
Returns:
|
||||
List of ReferenceInfo objects describing each reference location.
|
||||
|
||||
"""
|
||||
from codeflash.languages.base import ReferenceInfo
|
||||
from codeflash.languages.javascript.find_references import ReferenceFinder
|
||||
|
||||
try:
|
||||
finder = ReferenceFinder(project_root)
|
||||
refs = finder.find_references(function, max_files=max_files)
|
||||
|
||||
# Convert to ReferenceInfo and filter out tests
|
||||
result: list[ReferenceInfo] = []
|
||||
for ref in refs:
|
||||
# Exclude test files if tests_root is provided
|
||||
if tests_root:
|
||||
try:
|
||||
ref.file_path.relative_to(tests_root)
|
||||
continue # Skip if in tests_root
|
||||
except ValueError:
|
||||
pass # Not in tests_root, include it
|
||||
|
||||
result.append(
|
||||
ReferenceInfo(
|
||||
file_path=ref.file_path,
|
||||
line=ref.line,
|
||||
column=ref.column,
|
||||
end_line=ref.end_line,
|
||||
end_column=ref.end_column,
|
||||
context=ref.context,
|
||||
reference_type=ref.reference_type,
|
||||
import_name=ref.import_name,
|
||||
caller_function=ref.caller_function,
|
||||
)
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.warning("Failed to find references for %s: %s", function.function_name, e)
|
||||
return []
|
||||
|
||||
# === Code Transformation ===
|
||||
|
||||
def replace_function(self, source: str, function: FunctionInfo, new_source: str) -> str:
|
||||
def replace_function(self, source: str, function: FunctionToOptimize, new_source: str) -> str:
|
||||
"""Replace a function in source code with new implementation.
|
||||
|
||||
Uses node-based replacement to extract the method body from the optimized code
|
||||
|
|
@ -973,7 +1040,7 @@ class JavaScriptSupport:
|
|||
|
||||
Args:
|
||||
source: Original source code.
|
||||
function: FunctionInfo identifying the function to replace.
|
||||
function: FunctionToOptimize identifying the function to replace.
|
||||
new_source: New source code containing the optimized function.
|
||||
|
||||
Returns:
|
||||
|
|
@ -981,13 +1048,13 @@ class JavaScriptSupport:
|
|||
if new_source is empty or invalid.
|
||||
|
||||
"""
|
||||
if function.start_line is None or function.end_line is None:
|
||||
logger.error("Function %s has no line information", function.name)
|
||||
if function.starting_line is None or function.ending_line is None:
|
||||
logger.error("Function %s has no line information", function.function_name)
|
||||
return source
|
||||
|
||||
# If new_source is empty or whitespace-only, return original unchanged
|
||||
if not new_source or not new_source.strip():
|
||||
logger.warning("Empty new_source provided for %s, returning original", function.name)
|
||||
logger.warning("Empty new_source provided for %s, returning original", function.function_name)
|
||||
return source
|
||||
|
||||
# Get analyzer for parsing
|
||||
|
|
@ -1001,19 +1068,21 @@ class JavaScriptSupport:
|
|||
stripped_new_source = new_source.strip()
|
||||
if stripped_new_source.startswith("/**"):
|
||||
# new_source includes JSDoc, use full replacement to apply the new JSDoc
|
||||
if not self._contains_function_declaration(new_source, function.name, analyzer):
|
||||
logger.warning("new_source does not contain function %s, returning original", function.name)
|
||||
if not self._contains_function_declaration(new_source, function.function_name, analyzer):
|
||||
logger.warning("new_source does not contain function %s, returning original", function.function_name)
|
||||
return source
|
||||
return self._replace_function_text_based(source, function, new_source, analyzer)
|
||||
|
||||
# Extract just the method body from the new source
|
||||
new_body = self._extract_function_body(new_source, function.name, analyzer)
|
||||
new_body = self._extract_function_body(new_source, function.function_name, analyzer)
|
||||
if new_body is None:
|
||||
logger.warning("Could not extract body for %s from optimized code, using full replacement", function.name)
|
||||
logger.warning(
|
||||
"Could not extract body for %s from optimized code, using full replacement", function.function_name
|
||||
)
|
||||
# Verify that new_source contains actual code before falling back to text replacement
|
||||
# This prevents deletion of the original function when new_source is invalid
|
||||
if not self._contains_function_declaration(new_source, function.name, analyzer):
|
||||
logger.warning("new_source does not contain function %s, returning original", function.name)
|
||||
if not self._contains_function_declaration(new_source, function.function_name, analyzer):
|
||||
logger.warning("new_source does not contain function %s, returning original", function.function_name)
|
||||
return source
|
||||
return self._replace_function_text_based(source, function, new_source, analyzer)
|
||||
|
||||
|
|
@ -1141,7 +1210,7 @@ class JavaScriptSupport:
|
|||
return source_bytes[body_node.start_byte : body_node.end_byte].decode("utf8")
|
||||
|
||||
def _replace_function_body(
|
||||
self, source: str, function: FunctionInfo, new_body: str, analyzer: TreeSitterAnalyzer
|
||||
self, source: str, function: FunctionToOptimize, new_body: str, analyzer: TreeSitterAnalyzer
|
||||
) -> str:
|
||||
"""Replace the body of a function in source code with new body content.
|
||||
|
||||
|
|
@ -1149,7 +1218,7 @@ class JavaScriptSupport:
|
|||
|
||||
Args:
|
||||
source: Original source code.
|
||||
function: FunctionInfo identifying the function to modify.
|
||||
function: FunctionToOptimize identifying the function to modify.
|
||||
new_body: New body content (including braces).
|
||||
analyzer: TreeSitterAnalyzer for parsing.
|
||||
|
||||
|
|
@ -1200,9 +1269,9 @@ class JavaScriptSupport:
|
|||
|
||||
return None
|
||||
|
||||
func_node = find_function_at_line(tree.root_node, function.name, function.start_line)
|
||||
func_node = find_function_at_line(tree.root_node, function.function_name, function.starting_line)
|
||||
if not func_node:
|
||||
logger.warning("Could not find function %s at line %s", function.name, function.start_line)
|
||||
logger.warning("Could not find function %s at line %s", function.function_name, function.starting_line)
|
||||
return source
|
||||
|
||||
# Find the body node in the original
|
||||
|
|
@ -1214,7 +1283,7 @@ class JavaScriptSupport:
|
|||
break
|
||||
|
||||
if not body_node:
|
||||
logger.warning("Could not find body for function %s", function.name)
|
||||
logger.warning("Could not find body for function %s", function.function_name)
|
||||
return source
|
||||
|
||||
# Get the indentation of the original body's opening brace
|
||||
|
|
@ -1282,7 +1351,7 @@ class JavaScriptSupport:
|
|||
return result.decode("utf8")
|
||||
|
||||
def _replace_function_text_based(
|
||||
self, source: str, function: FunctionInfo, new_source: str, analyzer: TreeSitterAnalyzer
|
||||
self, source: str, function: FunctionToOptimize, new_source: str, analyzer: TreeSitterAnalyzer
|
||||
) -> str:
|
||||
"""Fallback text-based replacement when node-based replacement fails.
|
||||
|
||||
|
|
@ -1290,7 +1359,7 @@ class JavaScriptSupport:
|
|||
|
||||
Args:
|
||||
source: Original source code.
|
||||
function: FunctionInfo identifying the function to replace.
|
||||
function: FunctionToOptimize identifying the function to replace.
|
||||
new_source: New function source code.
|
||||
analyzer: TreeSitterAnalyzer for parsing.
|
||||
|
||||
|
|
@ -1307,16 +1376,16 @@ class JavaScriptSupport:
|
|||
tree_functions = analyzer.find_functions(source, include_methods=True, include_arrow_functions=True)
|
||||
target_func = None
|
||||
for func in tree_functions:
|
||||
if func.name == function.name and func.start_line == function.start_line:
|
||||
if func.name == function.function_name and func.start_line == function.starting_line:
|
||||
target_func = func
|
||||
break
|
||||
|
||||
# Use doc_start_line if available, otherwise fall back to start_line
|
||||
effective_start = (target_func.doc_start_line if target_func else None) or function.start_line
|
||||
effective_start = (target_func.doc_start_line if target_func else None) or function.starting_line
|
||||
|
||||
# Get indentation from original function's first line
|
||||
if function.start_line <= len(lines):
|
||||
original_first_line = lines[function.start_line - 1]
|
||||
if function.starting_line <= len(lines):
|
||||
original_first_line = lines[function.starting_line - 1]
|
||||
original_indent = len(original_first_line) - len(original_first_line.lstrip())
|
||||
else:
|
||||
original_indent = 0
|
||||
|
|
@ -1359,7 +1428,7 @@ class JavaScriptSupport:
|
|||
|
||||
# Build result
|
||||
before = lines[: effective_start - 1]
|
||||
after = lines[function.end_line :]
|
||||
after = lines[function.ending_line :]
|
||||
|
||||
result_lines = before + new_lines + after
|
||||
return "".join(result_lines)
|
||||
|
|
@ -1510,7 +1579,7 @@ class JavaScriptSupport:
|
|||
# === Instrumentation ===
|
||||
|
||||
def instrument_for_behavior(
|
||||
self, source: str, functions: Sequence[FunctionInfo], output_file: Path | None = None
|
||||
self, source: str, functions: Sequence[FunctionToOptimize], output_file: Path | None = None
|
||||
) -> str:
|
||||
"""Add behavior instrumentation to capture inputs/outputs.
|
||||
|
||||
|
|
@ -1539,7 +1608,7 @@ class JavaScriptSupport:
|
|||
tracer = JavaScriptTracer(output_file)
|
||||
return tracer.instrument_source(source, functions[0].file_path, list(functions))
|
||||
|
||||
def instrument_for_benchmarking(self, test_source: str, target_function: FunctionInfo) -> str:
|
||||
def instrument_for_benchmarking(self, test_source: str, target_function: FunctionToOptimize) -> str:
|
||||
"""Add timing instrumentation to test code.
|
||||
|
||||
For JavaScript/Jest, we can use Jest's built-in timing or add custom timing.
|
||||
|
|
@ -1769,51 +1838,98 @@ class JavaScriptSupport:
|
|||
rel_path = source_file.relative_to(project_root)
|
||||
return "../" + rel_path.with_suffix("").as_posix()
|
||||
|
||||
def verify_requirements(self, project_root: Path, test_framework: str = "jest") -> tuple[bool, list[str]]:
|
||||
"""Verify that all JavaScript requirements are met.
|
||||
|
||||
Checks for:
|
||||
1. Node.js installation
|
||||
2. npm availability
|
||||
3. Test framework (jest/vitest) installation
|
||||
4. node_modules existence
|
||||
|
||||
Args:
|
||||
project_root: The project root directory.
|
||||
test_framework: The test framework to check for ("jest" or "vitest").
|
||||
|
||||
Returns:
|
||||
Tuple of (success, list of error messages).
|
||||
|
||||
"""
|
||||
errors: list[str] = []
|
||||
|
||||
# Check Node.js
|
||||
try:
|
||||
result = subprocess.run(["node", "--version"], check=False, capture_output=True, text=True, timeout=10)
|
||||
if result.returncode != 0:
|
||||
errors.append("Node.js is not installed. Please install Node.js 18+ from https://nodejs.org/")
|
||||
except FileNotFoundError:
|
||||
errors.append("Node.js is not installed. Please install Node.js 18+ from https://nodejs.org/")
|
||||
except Exception as e:
|
||||
errors.append(f"Failed to check Node.js: {e}")
|
||||
|
||||
# Check npm
|
||||
try:
|
||||
result = subprocess.run(["npm", "--version"], check=False, capture_output=True, text=True, timeout=10)
|
||||
if result.returncode != 0:
|
||||
errors.append("npm is not available. Please ensure npm is installed with Node.js.")
|
||||
except FileNotFoundError:
|
||||
errors.append("npm is not available. Please ensure npm is installed with Node.js.")
|
||||
except Exception as e:
|
||||
errors.append(f"Failed to check npm: {e}")
|
||||
|
||||
# Check node_modules exists
|
||||
node_modules = project_root / "node_modules"
|
||||
if not node_modules.exists():
|
||||
errors.append(
|
||||
f"node_modules not found in {project_root}. Please run 'npm install' to install dependencies."
|
||||
)
|
||||
else:
|
||||
# Check test framework is installed
|
||||
framework_path = node_modules / test_framework
|
||||
if not framework_path.exists():
|
||||
errors.append(
|
||||
f"{test_framework} is not installed. "
|
||||
f"Please run 'npm install --save-dev {test_framework}' to install it."
|
||||
)
|
||||
|
||||
return len(errors) == 0, errors
|
||||
|
||||
def ensure_runtime_environment(self, project_root: Path) -> bool:
|
||||
"""Ensure codeflash npm package is installed.
|
||||
|
||||
Attempts to install the npm package for test instrumentation.
|
||||
Falls back to copying files if npm install fails.
|
||||
|
||||
Args:
|
||||
project_root: The project root directory.
|
||||
|
||||
Returns:
|
||||
True if npm package is installed, False if falling back to file copy.
|
||||
True if npm package is installed, False otherwise.
|
||||
|
||||
"""
|
||||
import subprocess
|
||||
|
||||
from codeflash.cli_cmds.console import logger
|
||||
|
||||
# Check if package is already installed
|
||||
node_modules_pkg = project_root / "node_modules" / "codeflash"
|
||||
if node_modules_pkg.exists():
|
||||
logger.debug("codeflash already installed")
|
||||
return True
|
||||
|
||||
# Try to install from local package first (for development)
|
||||
local_package_path = Path(__file__).parent.parent.parent.parent / "packages" / "cli"
|
||||
if local_package_path.exists():
|
||||
try:
|
||||
result = subprocess.run(
|
||||
["npm", "install", "--save-dev", str(local_package_path)],
|
||||
check=False,
|
||||
cwd=project_root,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=120,
|
||||
)
|
||||
if result.returncode == 0:
|
||||
logger.debug("Installed codeflash from local package")
|
||||
return True
|
||||
logger.warning(f"Failed to install local package: {result.stderr}")
|
||||
except Exception as e:
|
||||
logger.warning(f"Error installing local package: {e}")
|
||||
|
||||
# Could try npm registry here in the future:
|
||||
# subprocess.run(["npm", "install", "--save-dev", "codeflash"], ...)
|
||||
try:
|
||||
result = subprocess.run(
|
||||
["npm", "install", "--save-dev", "codeflash"],
|
||||
check=False,
|
||||
cwd=project_root,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=120,
|
||||
)
|
||||
if result.returncode == 0:
|
||||
logger.debug("Installed codeflash from npm registry")
|
||||
return True
|
||||
logger.warning(f"Failed to install codeflash: {result.stderr}")
|
||||
except Exception as e:
|
||||
logger.warning(f"Error installing codeflash: {e}")
|
||||
|
||||
logger.error("Could not install codeflash. Please run: npm install --save-dev codeflash")
|
||||
return False
|
||||
|
||||
def instrument_existing_test(
|
||||
|
|
@ -1853,7 +1969,7 @@ class JavaScriptSupport:
|
|||
def instrument_source_for_line_profiler(
|
||||
# TODO: use the context to instrument helper files also
|
||||
self,
|
||||
func_info: FunctionInfo,
|
||||
func_info: FunctionToOptimize,
|
||||
line_profiler_output_file: Path,
|
||||
) -> bool:
|
||||
from codeflash.languages.javascript.line_profiler import JavaScriptLineProfiler
|
||||
|
|
@ -1925,8 +2041,9 @@ class JavaScriptSupport:
|
|||
project_root: Path | None = None,
|
||||
enable_coverage: bool = False,
|
||||
candidate_index: int = 0,
|
||||
test_framework: str | None = None,
|
||||
) -> tuple[Path, Any, Path | None, Path | None]:
|
||||
"""Run Jest behavioral tests.
|
||||
"""Run behavioral tests using the detected test framework.
|
||||
|
||||
Args:
|
||||
test_paths: TestFiles object containing test file information.
|
||||
|
|
@ -1936,11 +2053,29 @@ class JavaScriptSupport:
|
|||
project_root: Project root directory.
|
||||
enable_coverage: Whether to collect coverage information.
|
||||
candidate_index: Index of the candidate being tested.
|
||||
test_framework: Test framework to use ("jest" or "vitest"). If None, uses singleton.
|
||||
|
||||
Returns:
|
||||
Tuple of (result_file_path, subprocess_result, coverage_path, config_path).
|
||||
|
||||
"""
|
||||
from codeflash.languages.test_framework import get_js_test_framework_or_default
|
||||
|
||||
framework = test_framework or get_js_test_framework_or_default()
|
||||
|
||||
if framework == "vitest":
|
||||
from codeflash.languages.javascript.vitest_runner import run_vitest_behavioral_tests
|
||||
|
||||
return run_vitest_behavioral_tests(
|
||||
test_paths=test_paths,
|
||||
test_env=test_env,
|
||||
cwd=cwd,
|
||||
timeout=timeout,
|
||||
project_root=project_root,
|
||||
enable_coverage=enable_coverage,
|
||||
candidate_index=candidate_index,
|
||||
)
|
||||
|
||||
from codeflash.languages.javascript.test_runner import run_jest_behavioral_tests
|
||||
|
||||
return run_jest_behavioral_tests(
|
||||
|
|
@ -1963,8 +2098,9 @@ class JavaScriptSupport:
|
|||
min_loops: int = 5,
|
||||
max_loops: int = 100_000,
|
||||
target_duration_seconds: float = 10.0,
|
||||
test_framework: str | None = None,
|
||||
) -> tuple[Path, Any]:
|
||||
"""Run Jest benchmarking tests.
|
||||
"""Run benchmarking tests using the detected test framework.
|
||||
|
||||
Args:
|
||||
test_paths: TestFiles object containing test file information.
|
||||
|
|
@ -1975,11 +2111,30 @@ class JavaScriptSupport:
|
|||
min_loops: Minimum number of loops for benchmarking.
|
||||
max_loops: Maximum number of loops for benchmarking.
|
||||
target_duration_seconds: Target duration for benchmarking in seconds.
|
||||
test_framework: Test framework to use ("jest" or "vitest"). If None, uses singleton.
|
||||
|
||||
Returns:
|
||||
Tuple of (result_file_path, subprocess_result).
|
||||
|
||||
"""
|
||||
from codeflash.languages.test_framework import get_js_test_framework_or_default
|
||||
|
||||
framework = test_framework or get_js_test_framework_or_default()
|
||||
|
||||
if framework == "vitest":
|
||||
from codeflash.languages.javascript.vitest_runner import run_vitest_benchmarking_tests
|
||||
|
||||
return run_vitest_benchmarking_tests(
|
||||
test_paths=test_paths,
|
||||
test_env=test_env,
|
||||
cwd=cwd,
|
||||
timeout=timeout,
|
||||
project_root=project_root,
|
||||
min_loops=min_loops,
|
||||
max_loops=max_loops,
|
||||
target_duration_ms=int(target_duration_seconds * 1000),
|
||||
)
|
||||
|
||||
from codeflash.languages.javascript.test_runner import run_jest_benchmarking_tests
|
||||
|
||||
return run_jest_benchmarking_tests(
|
||||
|
|
@ -2001,8 +2156,9 @@ class JavaScriptSupport:
|
|||
timeout: int | None = None,
|
||||
project_root: Path | None = None,
|
||||
line_profile_output_file: Path | None = None,
|
||||
test_framework: str | None = None,
|
||||
) -> tuple[Path, Any]:
|
||||
"""Run Jest tests for line profiling.
|
||||
"""Run tests for line profiling using the detected test framework.
|
||||
|
||||
Args:
|
||||
test_paths: TestFiles object containing test file information.
|
||||
|
|
@ -2011,11 +2167,28 @@ class JavaScriptSupport:
|
|||
timeout: Optional timeout in seconds.
|
||||
project_root: Project root directory.
|
||||
line_profile_output_file: Path where line profile results will be written.
|
||||
test_framework: Test framework to use ("jest" or "vitest"). If None, uses singleton.
|
||||
|
||||
Returns:
|
||||
Tuple of (result_file_path, subprocess_result).
|
||||
|
||||
"""
|
||||
from codeflash.languages.test_framework import get_js_test_framework_or_default
|
||||
|
||||
framework = test_framework or get_js_test_framework_or_default()
|
||||
|
||||
if framework == "vitest":
|
||||
from codeflash.languages.javascript.vitest_runner import run_vitest_line_profile_tests
|
||||
|
||||
return run_vitest_line_profile_tests(
|
||||
test_paths=test_paths,
|
||||
test_env=test_env,
|
||||
cwd=cwd,
|
||||
timeout=timeout,
|
||||
project_root=project_root,
|
||||
line_profile_output_file=line_profile_output_file,
|
||||
)
|
||||
|
||||
from codeflash.languages.javascript.test_runner import run_jest_line_profile_tests
|
||||
|
||||
return run_jest_line_profile_tests(
|
||||
|
|
|
|||
|
|
@ -13,6 +13,7 @@ from pathlib import Path
|
|||
from typing import TYPE_CHECKING
|
||||
|
||||
from codeflash.cli_cmds.console import logger
|
||||
from codeflash.cli_cmds.init_javascript import get_package_install_command
|
||||
from codeflash.code_utils.code_utils import get_run_tmp_file
|
||||
from codeflash.code_utils.config_consts import STABILITY_CENTER_TOLERANCE, STABILITY_SPREAD_TOLERANCE
|
||||
from codeflash.code_utils.shell_utils import get_cross_platform_subprocess_run_args
|
||||
|
|
@ -21,6 +22,254 @@ if TYPE_CHECKING:
|
|||
from codeflash.models.models import TestFiles
|
||||
|
||||
|
||||
def _detect_bundler_module_resolution(project_root: Path) -> bool:
|
||||
"""Detect if the project uses moduleResolution: 'bundler' in tsconfig.
|
||||
|
||||
TypeScript 5+ supports 'bundler' moduleResolution which requires
|
||||
module: 'preserve' or ES2015+. This can cause issues with ts-jest
|
||||
in some configurations.
|
||||
|
||||
This function also resolves extended tsconfigs to find bundler setting
|
||||
in parent configs.
|
||||
|
||||
Args:
|
||||
project_root: Root of the project to check.
|
||||
|
||||
Returns:
|
||||
True if the project uses bundler moduleResolution.
|
||||
|
||||
"""
|
||||
tsconfig_path = project_root / "tsconfig.json"
|
||||
if not tsconfig_path.exists():
|
||||
return False
|
||||
|
||||
visited_configs: set[Path] = set()
|
||||
|
||||
def check_tsconfig(config_path: Path) -> bool:
|
||||
"""Recursively check tsconfig and its extends for bundler moduleResolution."""
|
||||
if config_path in visited_configs:
|
||||
return False
|
||||
visited_configs.add(config_path)
|
||||
|
||||
if not config_path.exists():
|
||||
return False
|
||||
|
||||
try:
|
||||
content = config_path.read_text()
|
||||
tsconfig = json.loads(content)
|
||||
|
||||
# Check direct moduleResolution setting
|
||||
compiler_options = tsconfig.get("compilerOptions", {})
|
||||
module_resolution = compiler_options.get("moduleResolution", "").lower()
|
||||
if module_resolution == "bundler":
|
||||
return True
|
||||
|
||||
# Check extended config if present
|
||||
extends = tsconfig.get("extends")
|
||||
if extends:
|
||||
# Resolve the extended config path
|
||||
if extends.startswith("."):
|
||||
# Relative path
|
||||
extended_path = (config_path.parent / extends).resolve()
|
||||
if not extended_path.suffix:
|
||||
extended_path = extended_path.with_suffix(".json")
|
||||
else:
|
||||
# Package reference (e.g., "@n8n/typescript-config/modern/tsconfig.json")
|
||||
# Try to find it in node_modules
|
||||
node_modules_path = project_root / "node_modules" / extends
|
||||
if not node_modules_path.suffix:
|
||||
node_modules_path = node_modules_path.with_suffix(".json")
|
||||
if node_modules_path.exists():
|
||||
extended_path = node_modules_path
|
||||
else:
|
||||
# Try parent directories for monorepo support
|
||||
current = project_root.parent
|
||||
extended_path = None
|
||||
while current != current.parent:
|
||||
candidate = current / "node_modules" / extends
|
||||
if not candidate.suffix:
|
||||
candidate = candidate.with_suffix(".json")
|
||||
if candidate.exists():
|
||||
extended_path = candidate
|
||||
break
|
||||
# Also check packages directory for workspace packages
|
||||
packages_candidate = current / "packages" / extends
|
||||
if not packages_candidate.suffix:
|
||||
packages_candidate = packages_candidate.with_suffix(".json")
|
||||
if packages_candidate.exists():
|
||||
extended_path = packages_candidate
|
||||
break
|
||||
current = current.parent
|
||||
|
||||
if extended_path and extended_path.exists():
|
||||
return check_tsconfig(extended_path)
|
||||
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.debug(f"Failed to read {config_path}: {e}")
|
||||
return False
|
||||
|
||||
return check_tsconfig(tsconfig_path)
|
||||
|
||||
|
||||
def _create_codeflash_tsconfig(project_root: Path) -> Path:
|
||||
"""Create a codeflash-compatible tsconfig for projects using bundler moduleResolution.
|
||||
|
||||
This creates a tsconfig that inherits from the project's tsconfig but overrides
|
||||
moduleResolution to 'Node' for compatibility with ts-jest.
|
||||
|
||||
Args:
|
||||
project_root: Root of the project.
|
||||
|
||||
Returns:
|
||||
Path to the created tsconfig.codeflash.json file.
|
||||
|
||||
"""
|
||||
codeflash_tsconfig_path = project_root / "tsconfig.codeflash.json"
|
||||
|
||||
# If it already exists, use it
|
||||
if codeflash_tsconfig_path.exists():
|
||||
logger.debug(f"Using existing {codeflash_tsconfig_path}")
|
||||
return codeflash_tsconfig_path
|
||||
|
||||
# Read the original tsconfig to preserve most settings
|
||||
original_tsconfig_path = project_root / "tsconfig.json"
|
||||
try:
|
||||
original_content = original_tsconfig_path.read_text()
|
||||
original_tsconfig = json.loads(original_content)
|
||||
except Exception:
|
||||
original_tsconfig = {}
|
||||
|
||||
# Create a new tsconfig that extends the original but fixes moduleResolution
|
||||
codeflash_tsconfig = {
|
||||
"extends": "./tsconfig.json",
|
||||
"compilerOptions": {
|
||||
# Override bundler to Node for ts-jest compatibility
|
||||
"moduleResolution": "Node",
|
||||
# Ensure module is set to a compatible value
|
||||
"module": "ESNext",
|
||||
# These are generally safe defaults for testing
|
||||
"esModuleInterop": True,
|
||||
"skipLibCheck": True,
|
||||
"isolatedModules": True,
|
||||
},
|
||||
}
|
||||
|
||||
# Preserve include/exclude from original if not in extends
|
||||
if "include" in original_tsconfig:
|
||||
codeflash_tsconfig["include"] = original_tsconfig["include"]
|
||||
if "exclude" in original_tsconfig:
|
||||
codeflash_tsconfig["exclude"] = original_tsconfig["exclude"]
|
||||
|
||||
try:
|
||||
codeflash_tsconfig_path.write_text(json.dumps(codeflash_tsconfig, indent=2))
|
||||
logger.debug(f"Created {codeflash_tsconfig_path} with Node moduleResolution")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to create codeflash tsconfig: {e}")
|
||||
|
||||
return codeflash_tsconfig_path
|
||||
|
||||
|
||||
def _create_codeflash_jest_config(project_root: Path, original_jest_config: Path | None) -> Path | None:
|
||||
"""Create a Jest config that uses the codeflash tsconfig for ts-jest.
|
||||
|
||||
Args:
|
||||
project_root: Root of the project.
|
||||
original_jest_config: Path to the original Jest config, or None.
|
||||
|
||||
Returns:
|
||||
Path to the codeflash Jest config, or None if creation failed.
|
||||
|
||||
"""
|
||||
codeflash_jest_config_path = project_root / "jest.codeflash.config.js"
|
||||
|
||||
# If it already exists, use it
|
||||
if codeflash_jest_config_path.exists():
|
||||
logger.debug(f"Using existing {codeflash_jest_config_path}")
|
||||
return codeflash_jest_config_path
|
||||
|
||||
# Create a wrapper Jest config that uses tsconfig.codeflash.json
|
||||
if original_jest_config:
|
||||
# Extend the original config
|
||||
jest_config_content = f"""// Auto-generated by codeflash for bundler moduleResolution compatibility
|
||||
const originalConfig = require('./{original_jest_config.name}');
|
||||
|
||||
const tsJestOptions = {{
|
||||
isolatedModules: true,
|
||||
tsconfig: 'tsconfig.codeflash.json',
|
||||
}};
|
||||
|
||||
module.exports = {{
|
||||
...originalConfig,
|
||||
transform: {{
|
||||
...originalConfig.transform,
|
||||
'^.+\\\\.tsx?$': ['ts-jest', tsJestOptions],
|
||||
}},
|
||||
globals: {{
|
||||
...originalConfig.globals,
|
||||
'ts-jest': tsJestOptions,
|
||||
}},
|
||||
}};
|
||||
"""
|
||||
else:
|
||||
# Create a minimal Jest config for TypeScript
|
||||
jest_config_content = """// Auto-generated by codeflash for bundler moduleResolution compatibility
|
||||
const tsJestOptions = {
|
||||
isolatedModules: true,
|
||||
tsconfig: 'tsconfig.codeflash.json',
|
||||
};
|
||||
|
||||
module.exports = {
|
||||
verbose: true,
|
||||
testEnvironment: 'node',
|
||||
testRegex: '\\\\.(test|spec)\\\\.(js|ts|tsx)$',
|
||||
testPathIgnorePatterns: ['/dist/', '/node_modules/'],
|
||||
transform: {
|
||||
'^.+\\\\.tsx?$': ['ts-jest', tsJestOptions],
|
||||
},
|
||||
moduleFileExtensions: ['ts', 'tsx', 'js', 'jsx', 'json', 'node'],
|
||||
};
|
||||
"""
|
||||
|
||||
try:
|
||||
codeflash_jest_config_path.write_text(jest_config_content)
|
||||
logger.debug(f"Created {codeflash_jest_config_path} with codeflash tsconfig")
|
||||
return codeflash_jest_config_path
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to create codeflash Jest config: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def _get_jest_config_for_project(project_root: Path) -> Path | None:
|
||||
"""Get the appropriate Jest config for the project.
|
||||
|
||||
If the project uses bundler moduleResolution, creates and returns a
|
||||
codeflash-compatible Jest config. Otherwise, returns the project's
|
||||
existing Jest config.
|
||||
|
||||
Args:
|
||||
project_root: Root of the project.
|
||||
|
||||
Returns:
|
||||
Path to the Jest config to use, or None if not found.
|
||||
|
||||
"""
|
||||
# First check for existing Jest config
|
||||
original_jest_config = _find_jest_config(project_root)
|
||||
|
||||
# Check if project uses bundler moduleResolution
|
||||
if _detect_bundler_module_resolution(project_root):
|
||||
logger.info("Detected bundler moduleResolution - creating compatible config")
|
||||
# Create codeflash-compatible tsconfig
|
||||
_create_codeflash_tsconfig(project_root)
|
||||
# Create codeflash Jest config that uses it
|
||||
codeflash_jest_config = _create_codeflash_jest_config(project_root, original_jest_config)
|
||||
if codeflash_jest_config:
|
||||
return codeflash_jest_config
|
||||
|
||||
return original_jest_config
|
||||
|
||||
|
||||
def _find_node_project_root(file_path: Path) -> Path | None:
|
||||
"""Find the Node.js project root by looking for package.json.
|
||||
|
||||
|
|
@ -47,6 +296,67 @@ def _find_node_project_root(file_path: Path) -> Path | None:
|
|||
return None
|
||||
|
||||
|
||||
def _find_jest_config(project_root: Path) -> Path | None:
|
||||
"""Find Jest configuration file in the project.
|
||||
|
||||
Searches for common Jest config file names in the project root and parent
|
||||
directories (for monorepo support). This is important for TypeScript projects
|
||||
that require specific transformation configurations (e.g., next/jest, ts-jest, babel-jest).
|
||||
|
||||
Args:
|
||||
project_root: Root of the project to search.
|
||||
|
||||
Returns:
|
||||
Path to Jest config file, or None if not found.
|
||||
|
||||
"""
|
||||
# Common Jest config file names, in order of preference
|
||||
config_names = ["jest.config.ts", "jest.config.js", "jest.config.mjs", "jest.config.cjs", "jest.config.json"]
|
||||
|
||||
# First check the project root itself
|
||||
for config_name in config_names:
|
||||
config_path = project_root / config_name
|
||||
if config_path.exists():
|
||||
logger.debug(f"Found Jest config: {config_path}")
|
||||
return config_path
|
||||
|
||||
# For monorepos, search parent directories up to the filesystem root
|
||||
# Stop at common monorepo root indicators (git root, package.json with workspaces)
|
||||
current = project_root.parent
|
||||
max_depth = 5 # Don't search too far up
|
||||
depth = 0
|
||||
|
||||
while current != current.parent and depth < max_depth:
|
||||
for config_name in config_names:
|
||||
config_path = current / config_name
|
||||
if config_path.exists():
|
||||
logger.debug(f"Found Jest config in parent directory: {config_path}")
|
||||
return config_path
|
||||
|
||||
# Check if this looks like a monorepo root
|
||||
package_json = current / "package.json"
|
||||
if package_json.exists():
|
||||
try:
|
||||
import json
|
||||
|
||||
with package_json.open("r") as f:
|
||||
pkg = json.load(f)
|
||||
if "workspaces" in pkg:
|
||||
# This is likely the monorepo root, stop here
|
||||
break
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Check for git root as another stopping point
|
||||
if (current / ".git").exists():
|
||||
break
|
||||
|
||||
current = current.parent
|
||||
depth += 1
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _is_esm_project(project_root: Path) -> bool:
|
||||
"""Check if the project uses ES Modules.
|
||||
|
||||
|
|
@ -145,54 +455,28 @@ def _ensure_runtime_files(project_root: Path) -> None:
|
|||
|
||||
Installs codeflash package if not already present.
|
||||
The package provides all runtime files needed for test instrumentation.
|
||||
Uses the project's detected package manager (npm, pnpm, yarn, or bun).
|
||||
|
||||
Args:
|
||||
project_root: The project root directory.
|
||||
|
||||
"""
|
||||
# Check if package is already installed
|
||||
node_modules_pkg = project_root / "node_modules" / "codeflash"
|
||||
if node_modules_pkg.exists():
|
||||
logger.debug("codeflash already installed")
|
||||
return
|
||||
|
||||
# Try to install from local package first (for development)
|
||||
local_package_path = Path(__file__).parent.parent.parent.parent / "packages" / "codeflash"
|
||||
if local_package_path.exists():
|
||||
try:
|
||||
result = subprocess.run(
|
||||
["npm", "install", "--save-dev", str(local_package_path)],
|
||||
check=False,
|
||||
cwd=project_root,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=120,
|
||||
)
|
||||
if result.returncode == 0:
|
||||
logger.debug("Installed codeflash from local package")
|
||||
return
|
||||
logger.warning(f"Failed to install local package: {result.stderr}")
|
||||
except Exception as e:
|
||||
logger.warning(f"Error installing local package: {e}")
|
||||
|
||||
# Try to install from npm registry
|
||||
install_cmd = get_package_install_command(project_root, "codeflash", dev=True)
|
||||
try:
|
||||
result = subprocess.run(
|
||||
["npm", "install", "--save-dev", "codeflash"],
|
||||
check=False,
|
||||
cwd=project_root,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=120,
|
||||
)
|
||||
result = subprocess.run(install_cmd, check=False, cwd=project_root, capture_output=True, text=True, timeout=120)
|
||||
if result.returncode == 0:
|
||||
logger.debug("Installed codeflash from npm registry")
|
||||
logger.debug(f"Installed codeflash using {install_cmd[0]}")
|
||||
return
|
||||
logger.warning(f"Failed to install from npm: {result.stderr}")
|
||||
logger.warning(f"Failed to install codeflash: {result.stderr}")
|
||||
except Exception as e:
|
||||
logger.warning(f"Error installing from npm: {e}")
|
||||
logger.warning(f"Error installing codeflash: {e}")
|
||||
|
||||
logger.error("Could not install codeflash. Please install it manually: npm install --save-dev codeflash")
|
||||
logger.error(f"Could not install codeflash. Please install it manually: {' '.join(install_cmd)}")
|
||||
|
||||
|
||||
def run_jest_behavioral_tests(
|
||||
|
|
@ -251,13 +535,26 @@ def run_jest_behavioral_tests(
|
|||
"--forceExit",
|
||||
]
|
||||
|
||||
# Add Jest config if found - needed for TypeScript transformation
|
||||
# Uses codeflash-compatible config if project has bundler moduleResolution
|
||||
jest_config = _get_jest_config_for_project(effective_cwd)
|
||||
if jest_config:
|
||||
jest_cmd.append(f"--config={jest_config}")
|
||||
|
||||
# Add coverage flags if enabled
|
||||
if enable_coverage:
|
||||
jest_cmd.extend(["--coverage", "--coverageReporters=json", f"--coverageDirectory={coverage_dir}"])
|
||||
|
||||
if test_files:
|
||||
jest_cmd.append("--runTestsByPath")
|
||||
jest_cmd.extend(str(Path(f).resolve()) for f in test_files)
|
||||
resolved_test_files = [str(Path(f).resolve()) for f in test_files]
|
||||
jest_cmd.extend(resolved_test_files)
|
||||
# Add --roots to include directories containing test files
|
||||
# This is needed because some projects configure Jest with restricted roots
|
||||
# (e.g., roots: ["<rootDir>/src"]) which excludes the test directory
|
||||
test_dirs = {str(Path(f).resolve().parent) for f in test_files}
|
||||
for test_dir in sorted(test_dirs):
|
||||
jest_cmd.extend(["--roots", test_dir])
|
||||
|
||||
if timeout:
|
||||
jest_cmd.append(f"--testTimeout={timeout * 1000}") # Jest uses milliseconds
|
||||
|
|
@ -306,6 +603,14 @@ def run_jest_behavioral_tests(
|
|||
args=result.args, returncode=result.returncode, stdout=result.stdout + "\n" + result.stderr, stderr=""
|
||||
)
|
||||
logger.debug(f"Jest result: returncode={result.returncode}")
|
||||
# Log Jest output at WARNING level if tests fail and no XML output will be created
|
||||
# This helps debug issues like import errors that cause Jest to fail early
|
||||
if result.returncode != 0 and not result_file_path.exists():
|
||||
logger.warning(
|
||||
f"Jest failed with returncode={result.returncode} and no XML output created.\n"
|
||||
f"Jest stdout: {result.stdout[:2000] if result.stdout else '(empty)'}\n"
|
||||
f"Jest stderr: {result.stderr[:500] if result.stderr else '(empty)'}"
|
||||
)
|
||||
except subprocess.TimeoutExpired:
|
||||
logger.warning(f"Jest tests timed out after {timeout}s")
|
||||
result = subprocess.CompletedProcess(args=jest_cmd, returncode=-1, stdout="", stderr="Test execution timed out")
|
||||
|
|
@ -465,9 +770,20 @@ def run_jest_benchmarking_tests(
|
|||
"--runner=codeflash/loop-runner", # Use custom loop runner for in-process looping
|
||||
]
|
||||
|
||||
# Add Jest config if found - needed for TypeScript transformation
|
||||
# Uses codeflash-compatible config if project has bundler moduleResolution
|
||||
jest_config = _get_jest_config_for_project(effective_cwd)
|
||||
if jest_config:
|
||||
jest_cmd.append(f"--config={jest_config}")
|
||||
|
||||
if test_files:
|
||||
jest_cmd.append("--runTestsByPath")
|
||||
jest_cmd.extend(str(Path(f).resolve()) for f in test_files)
|
||||
resolved_test_files = [str(Path(f).resolve()) for f in test_files]
|
||||
jest_cmd.extend(resolved_test_files)
|
||||
# Add --roots to include directories containing test files
|
||||
test_dirs = {str(Path(f).resolve().parent) for f in test_files}
|
||||
for test_dir in sorted(test_dirs):
|
||||
jest_cmd.extend(["--roots", test_dir])
|
||||
|
||||
if timeout:
|
||||
jest_cmd.append(f"--testTimeout={timeout * 1000}")
|
||||
|
|
@ -594,9 +910,20 @@ def run_jest_line_profile_tests(
|
|||
"--forceExit",
|
||||
]
|
||||
|
||||
# Add Jest config if found - needed for TypeScript transformation
|
||||
# Uses codeflash-compatible config if project has bundler moduleResolution
|
||||
jest_config = _get_jest_config_for_project(effective_cwd)
|
||||
if jest_config:
|
||||
jest_cmd.append(f"--config={jest_config}")
|
||||
|
||||
if test_files:
|
||||
jest_cmd.append("--runTestsByPath")
|
||||
jest_cmd.extend(str(Path(f).resolve()) for f in test_files)
|
||||
resolved_test_files = [str(Path(f).resolve()) for f in test_files]
|
||||
jest_cmd.extend(resolved_test_files)
|
||||
# Add --roots to include directories containing test files
|
||||
test_dirs = {str(Path(f).resolve().parent) for f in test_files}
|
||||
for test_dir in sorted(test_dirs):
|
||||
jest_cmd.extend(["--roots", test_dir])
|
||||
|
||||
if timeout:
|
||||
jest_cmd.append(f"--testTimeout={timeout * 1000}")
|
||||
|
|
|
|||
|
|
@ -15,7 +15,7 @@ from typing import TYPE_CHECKING, Any
|
|||
if TYPE_CHECKING:
|
||||
from pathlib import Path
|
||||
|
||||
from codeflash.languages.base import FunctionInfo
|
||||
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -40,7 +40,7 @@ class JavaScriptTracer:
|
|||
self.output_db = output_db
|
||||
self.tracer_var = "__codeflash_tracer__"
|
||||
|
||||
def instrument_source(self, source: str, file_path: Path, functions: list[FunctionInfo]) -> str:
|
||||
def instrument_source(self, source: str, file_path: Path, functions: list[FunctionToOptimize]) -> str:
|
||||
"""Instrument JavaScript source code with function tracing.
|
||||
|
||||
Wraps specified functions to capture their inputs and outputs.
|
||||
|
|
@ -64,10 +64,10 @@ class JavaScriptTracer:
|
|||
lines = source.splitlines(keepends=True)
|
||||
|
||||
# Process functions in reverse order to preserve line numbers
|
||||
for func in sorted(functions, key=lambda f: f.start_line, reverse=True):
|
||||
for func in sorted(functions, key=lambda f: f.starting_line, reverse=True):
|
||||
instrumented = self._instrument_function(func, lines, file_path)
|
||||
start_idx = func.start_line - 1
|
||||
end_idx = func.end_line
|
||||
start_idx = func.starting_line - 1
|
||||
end_idx = func.ending_line
|
||||
lines = lines[:start_idx] + instrumented + lines[end_idx:]
|
||||
|
||||
instrumented_source = "".join(lines)
|
||||
|
|
@ -269,7 +269,7 @@ process.on('exit', () => {{
|
|||
}});
|
||||
"""
|
||||
|
||||
def _instrument_function(self, func: FunctionInfo, lines: list[str], file_path: Path) -> list[str]:
|
||||
def _instrument_function(self, func: FunctionToOptimize, lines: list[str], file_path: Path) -> list[str]:
|
||||
"""Instrument a single function with tracing.
|
||||
|
||||
Args:
|
||||
|
|
@ -281,11 +281,11 @@ process.on('exit', () => {{
|
|||
Instrumented function lines.
|
||||
|
||||
"""
|
||||
func_lines = lines[func.start_line - 1 : func.end_line]
|
||||
func_lines = lines[func.starting_line - 1 : func.ending_line]
|
||||
func_text = "".join(func_lines)
|
||||
|
||||
# Detect function pattern
|
||||
func_name = func.name
|
||||
func_name = func.function_name
|
||||
is_arrow = "=>" in func_text.split("\n")[0]
|
||||
is_method = func.is_method
|
||||
is_async = func.is_async
|
||||
|
|
|
|||
492
codeflash/languages/javascript/vitest_runner.py
Normal file
492
codeflash/languages/javascript/vitest_runner.py
Normal file
|
|
@ -0,0 +1,492 @@
|
|||
"""Vitest test runner for JavaScript/TypeScript.
|
||||
|
||||
This module provides functions for running Vitest tests for behavioral
|
||||
verification and performance benchmarking.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import subprocess
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from codeflash.cli_cmds.console import logger
|
||||
from codeflash.cli_cmds.init_javascript import get_package_install_command
|
||||
from codeflash.code_utils.code_utils import get_run_tmp_file
|
||||
from codeflash.code_utils.shell_utils import get_cross_platform_subprocess_run_args
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from codeflash.models.models import TestFiles
|
||||
|
||||
|
||||
def _find_vitest_project_root(file_path: Path) -> Path | None:
|
||||
"""Find the Vitest project root by looking for vitest/vite config or package.json.
|
||||
|
||||
Traverses up from the given file path to find the nearest directory
|
||||
containing vitest.config.js/ts, vite.config.js/ts, or package.json.
|
||||
|
||||
Args:
|
||||
file_path: A file path within the Vitest project.
|
||||
|
||||
Returns:
|
||||
The project root directory, or None if not found.
|
||||
|
||||
"""
|
||||
current = file_path.parent if file_path.is_file() else file_path
|
||||
while current != current.parent: # Stop at filesystem root
|
||||
# Check for Vitest-specific config files first
|
||||
if (
|
||||
(current / "vitest.config.js").exists()
|
||||
or (current / "vitest.config.ts").exists()
|
||||
or (current / "vitest.config.mjs").exists()
|
||||
or (current / "vitest.config.mts").exists()
|
||||
or (current / "vite.config.js").exists()
|
||||
or (current / "vite.config.ts").exists()
|
||||
or (current / "vite.config.mjs").exists()
|
||||
or (current / "vite.config.mts").exists()
|
||||
or (current / "package.json").exists()
|
||||
):
|
||||
return current
|
||||
current = current.parent
|
||||
return None
|
||||
|
||||
|
||||
def _is_vitest_coverage_available(project_root: Path) -> bool:
|
||||
"""Check if Vitest coverage package is available.
|
||||
|
||||
Args:
|
||||
project_root: The project root directory.
|
||||
|
||||
Returns:
|
||||
True if @vitest/coverage-v8 or @vitest/coverage-istanbul is installed.
|
||||
|
||||
"""
|
||||
node_modules = project_root / "node_modules"
|
||||
return (node_modules / "@vitest" / "coverage-v8").exists() or (
|
||||
node_modules / "@vitest" / "coverage-istanbul"
|
||||
).exists()
|
||||
|
||||
|
||||
def _ensure_runtime_files(project_root: Path) -> None:
|
||||
"""Ensure JavaScript runtime package is installed in the project.
|
||||
|
||||
Installs codeflash package if not already present.
|
||||
The package provides all runtime files needed for test instrumentation.
|
||||
Uses the project's detected package manager (npm, pnpm, yarn, or bun).
|
||||
|
||||
Args:
|
||||
project_root: The project root directory.
|
||||
|
||||
"""
|
||||
node_modules_pkg = project_root / "node_modules" / "codeflash"
|
||||
if node_modules_pkg.exists():
|
||||
logger.debug("codeflash already installed")
|
||||
return
|
||||
|
||||
install_cmd = get_package_install_command(project_root, "codeflash", dev=True)
|
||||
try:
|
||||
result = subprocess.run(install_cmd, check=False, cwd=project_root, capture_output=True, text=True, timeout=120)
|
||||
if result.returncode == 0:
|
||||
logger.debug(f"Installed codeflash using {install_cmd[0]}")
|
||||
return
|
||||
logger.warning(f"Failed to install codeflash: {result.stderr}")
|
||||
except Exception as e:
|
||||
logger.warning(f"Error installing codeflash: {e}")
|
||||
|
||||
logger.error(f"Could not install codeflash. Please install it manually: {' '.join(install_cmd)}")
|
||||
|
||||
|
||||
def _build_vitest_behavioral_command(
|
||||
test_files: list[Path], timeout: int | None = None, output_file: Path | None = None
|
||||
) -> list[str]:
|
||||
"""Build Vitest command for behavioral tests.
|
||||
|
||||
Args:
|
||||
test_files: List of test files to run.
|
||||
timeout: Optional timeout in seconds.
|
||||
output_file: Optional path for JUnit XML output.
|
||||
|
||||
Returns:
|
||||
Command list for subprocess execution.
|
||||
|
||||
"""
|
||||
cmd = [
|
||||
"npx",
|
||||
"vitest",
|
||||
"run", # Single execution (not watch mode)
|
||||
"--reporter=default",
|
||||
"--reporter=junit",
|
||||
"--no-file-parallelism", # Serial execution for deterministic timing
|
||||
]
|
||||
|
||||
if output_file:
|
||||
cmd.append(f"--outputFile={output_file}")
|
||||
|
||||
if timeout:
|
||||
cmd.append(f"--test-timeout={timeout * 1000}") # Vitest uses milliseconds
|
||||
|
||||
# Add test files as positional arguments (Vitest style)
|
||||
cmd.extend(str(f.resolve()) for f in test_files)
|
||||
|
||||
return cmd
|
||||
|
||||
|
||||
def _build_vitest_benchmarking_command(
|
||||
test_files: list[Path], timeout: int | None = None, output_file: Path | None = None
|
||||
) -> list[str]:
|
||||
"""Build Vitest command for benchmarking tests.
|
||||
|
||||
Args:
|
||||
test_files: List of test files to run.
|
||||
timeout: Optional timeout in seconds.
|
||||
output_file: Optional path for JUnit XML output.
|
||||
|
||||
Returns:
|
||||
Command list for subprocess execution.
|
||||
|
||||
"""
|
||||
cmd = [
|
||||
"npx",
|
||||
"vitest",
|
||||
"run", # Single execution (not watch mode)
|
||||
"--reporter=default",
|
||||
"--reporter=junit",
|
||||
"--no-file-parallelism", # Serial execution for consistent benchmarking
|
||||
]
|
||||
|
||||
if output_file:
|
||||
cmd.append(f"--outputFile={output_file}")
|
||||
|
||||
if timeout:
|
||||
cmd.append(f"--test-timeout={timeout * 1000}")
|
||||
|
||||
# Add test files as positional arguments
|
||||
cmd.extend(str(f.resolve()) for f in test_files)
|
||||
|
||||
return cmd
|
||||
|
||||
|
||||
def run_vitest_behavioral_tests(
|
||||
test_paths: TestFiles,
|
||||
test_env: dict[str, str],
|
||||
cwd: Path,
|
||||
*,
|
||||
timeout: int | None = None,
|
||||
project_root: Path | None = None,
|
||||
enable_coverage: bool = False,
|
||||
candidate_index: int = 0,
|
||||
) -> tuple[Path, subprocess.CompletedProcess, Path | None, Path | None]:
|
||||
"""Run Vitest tests and return results in a format compatible with pytest output.
|
||||
|
||||
Args:
|
||||
test_paths: TestFiles object containing test file information.
|
||||
test_env: Environment variables for the test run.
|
||||
cwd: Working directory for running tests.
|
||||
timeout: Optional timeout in seconds.
|
||||
project_root: Vitest project root (directory containing vitest.config or package.json).
|
||||
enable_coverage: Whether to collect coverage information.
|
||||
candidate_index: Index of the candidate being tested.
|
||||
|
||||
Returns:
|
||||
Tuple of (result_file_path, subprocess_result, coverage_json_path, None).
|
||||
|
||||
"""
|
||||
result_file_path = get_run_tmp_file(Path("vitest_results.xml"))
|
||||
|
||||
# Get test files to run
|
||||
test_files = [Path(file.instrumented_behavior_file_path) for file in test_paths.test_files]
|
||||
|
||||
# Use provided project_root, or detect it as fallback
|
||||
if project_root is None and test_files:
|
||||
project_root = _find_vitest_project_root(test_files[0])
|
||||
|
||||
# Use the project root, or fall back to provided cwd
|
||||
effective_cwd = project_root if project_root else cwd
|
||||
logger.debug(f"Vitest working directory: {effective_cwd}")
|
||||
|
||||
# Ensure the codeflash npm package is installed
|
||||
_ensure_runtime_files(effective_cwd)
|
||||
|
||||
# Coverage output directory - only enable if coverage package is available
|
||||
coverage_dir = get_run_tmp_file(Path("vitest_coverage"))
|
||||
coverage_available = _is_vitest_coverage_available(effective_cwd) if enable_coverage else False
|
||||
coverage_json_path = coverage_dir / "coverage-final.json" if coverage_available else None
|
||||
|
||||
if enable_coverage and not coverage_available:
|
||||
logger.debug("Vitest coverage package not installed, running without coverage")
|
||||
|
||||
# Build Vitest command
|
||||
vitest_cmd = _build_vitest_behavioral_command(test_files=test_files, timeout=timeout, output_file=result_file_path)
|
||||
|
||||
# Add coverage flags only if coverage is available
|
||||
if coverage_available:
|
||||
vitest_cmd.extend(["--coverage", "--coverage.reporter=json", f"--coverage.reportsDirectory={coverage_dir}"])
|
||||
|
||||
# Set up environment
|
||||
vitest_env = test_env.copy()
|
||||
# Set codeflash output file for the vitest helper to write timing/behavior data (SQLite format)
|
||||
codeflash_sqlite_file = get_run_tmp_file(Path(f"test_return_values_{candidate_index}.sqlite"))
|
||||
vitest_env["CODEFLASH_OUTPUT_FILE"] = str(codeflash_sqlite_file)
|
||||
vitest_env["CODEFLASH_TEST_ITERATION"] = str(candidate_index)
|
||||
vitest_env["CODEFLASH_LOOP_INDEX"] = "1"
|
||||
vitest_env["CODEFLASH_MODE"] = "behavior"
|
||||
# Seed random number generator for reproducible test runs across original and optimized code
|
||||
vitest_env["CODEFLASH_RANDOM_SEED"] = "42"
|
||||
|
||||
logger.debug(f"Running Vitest tests with command: {' '.join(vitest_cmd)}")
|
||||
|
||||
# Subprocess timeout should be much larger than per-test timeout to account for:
|
||||
# - Vitest startup time (loading modules, compiling TypeScript)
|
||||
# - Multiple tests running sequentially
|
||||
# Use at least 120 seconds, or 10x the per-test timeout, whichever is larger
|
||||
subprocess_timeout = max(120, (timeout or 60) * 10)
|
||||
|
||||
start_time_ns = time.perf_counter_ns()
|
||||
try:
|
||||
run_args = get_cross_platform_subprocess_run_args(
|
||||
cwd=effective_cwd, env=vitest_env, timeout=subprocess_timeout, check=False, text=True, capture_output=True
|
||||
)
|
||||
result = subprocess.run(vitest_cmd, **run_args) # noqa: PLW1510
|
||||
# Combine stderr into stdout for timing markers
|
||||
if result.stderr and not result.stdout:
|
||||
result = subprocess.CompletedProcess(
|
||||
args=result.args, returncode=result.returncode, stdout=result.stderr, stderr=""
|
||||
)
|
||||
elif result.stderr:
|
||||
result = subprocess.CompletedProcess(
|
||||
args=result.args, returncode=result.returncode, stdout=result.stdout + "\n" + result.stderr, stderr=""
|
||||
)
|
||||
logger.debug(f"Vitest result: returncode={result.returncode}")
|
||||
except subprocess.TimeoutExpired:
|
||||
logger.warning(f"Vitest tests timed out after {subprocess_timeout}s")
|
||||
result = subprocess.CompletedProcess(
|
||||
args=vitest_cmd, returncode=-1, stdout="", stderr="Test execution timed out"
|
||||
)
|
||||
except FileNotFoundError:
|
||||
logger.error("Vitest not found. Make sure Vitest is installed (npm install vitest)")
|
||||
result = subprocess.CompletedProcess(
|
||||
args=vitest_cmd, returncode=-1, stdout="", stderr="Vitest not found. Run: npm install vitest"
|
||||
)
|
||||
finally:
|
||||
wall_clock_ns = time.perf_counter_ns() - start_time_ns
|
||||
logger.debug(f"Vitest behavioral tests completed in {wall_clock_ns / 1e9:.2f}s")
|
||||
|
||||
return result_file_path, result, coverage_json_path, None
|
||||
|
||||
|
||||
def run_vitest_benchmarking_tests(
|
||||
test_paths: TestFiles,
|
||||
test_env: dict[str, str],
|
||||
cwd: Path,
|
||||
*,
|
||||
timeout: int | None = None,
|
||||
project_root: Path | None = None,
|
||||
min_loops: int = 5,
|
||||
max_loops: int = 100,
|
||||
target_duration_ms: int = 10_000,
|
||||
stability_check: bool = True,
|
||||
) -> tuple[Path, subprocess.CompletedProcess]:
|
||||
"""Run Vitest benchmarking tests with external looping from Python.
|
||||
|
||||
Uses external process-level looping to run tests multiple times and
|
||||
collect timing data. This matches the Python pytest approach where
|
||||
looping is controlled externally for simplicity.
|
||||
|
||||
Args:
|
||||
test_paths: TestFiles object containing test file information.
|
||||
test_env: Environment variables for the test run.
|
||||
cwd: Working directory for running tests.
|
||||
timeout: Optional timeout in seconds for the entire benchmark run.
|
||||
project_root: Vitest project root (directory containing vitest.config or package.json).
|
||||
min_loops: Minimum number of loop iterations.
|
||||
max_loops: Maximum number of loop iterations.
|
||||
target_duration_ms: Target TOTAL duration in milliseconds for all loops.
|
||||
stability_check: Whether to enable stability-based early stopping.
|
||||
|
||||
Returns:
|
||||
Tuple of (result_file_path, subprocess_result with stdout from all iterations).
|
||||
|
||||
"""
|
||||
result_file_path = get_run_tmp_file(Path("vitest_perf_results.xml"))
|
||||
|
||||
# Get performance test files
|
||||
test_files = [Path(file.benchmarking_file_path) for file in test_paths.test_files if file.benchmarking_file_path]
|
||||
|
||||
# Use provided project_root, or detect it as fallback
|
||||
if project_root is None and test_files:
|
||||
project_root = _find_vitest_project_root(test_files[0])
|
||||
|
||||
effective_cwd = project_root if project_root else cwd
|
||||
logger.debug(f"Vitest benchmarking working directory: {effective_cwd}")
|
||||
|
||||
# Ensure the codeflash npm package is installed
|
||||
_ensure_runtime_files(effective_cwd)
|
||||
|
||||
# Build Vitest command for performance tests
|
||||
vitest_cmd = _build_vitest_benchmarking_command(
|
||||
test_files=test_files, timeout=timeout, output_file=result_file_path
|
||||
)
|
||||
|
||||
# Base environment setup
|
||||
vitest_env = test_env.copy()
|
||||
codeflash_sqlite_file = get_run_tmp_file(Path("test_return_values_0.sqlite"))
|
||||
vitest_env["CODEFLASH_OUTPUT_FILE"] = str(codeflash_sqlite_file)
|
||||
vitest_env["CODEFLASH_TEST_ITERATION"] = "0"
|
||||
vitest_env["CODEFLASH_MODE"] = "performance"
|
||||
vitest_env["CODEFLASH_RANDOM_SEED"] = "42"
|
||||
|
||||
# Internal loop configuration for capturePerf
|
||||
vitest_env["CODEFLASH_PERF_LOOP_COUNT"] = str(max_loops)
|
||||
vitest_env["CODEFLASH_PERF_MIN_LOOPS"] = str(min_loops)
|
||||
vitest_env["CODEFLASH_PERF_TARGET_DURATION_MS"] = str(target_duration_ms)
|
||||
vitest_env["CODEFLASH_PERF_STABILITY_CHECK"] = "true" if stability_check else "false"
|
||||
vitest_env["CODEFLASH_LOOP_INDEX"] = "1"
|
||||
|
||||
# Total timeout for the entire benchmark run
|
||||
total_timeout = max(120, (target_duration_ms // 1000) + 60, timeout or 120)
|
||||
|
||||
logger.debug(f"Running Vitest benchmarking tests: {' '.join(vitest_cmd)}")
|
||||
logger.debug(
|
||||
f"Vitest benchmarking config: min_loops={min_loops}, max_loops={max_loops}, "
|
||||
f"target_duration={target_duration_ms}ms, stability_check={stability_check}"
|
||||
)
|
||||
|
||||
total_start_time = time.time()
|
||||
|
||||
try:
|
||||
run_args = get_cross_platform_subprocess_run_args(
|
||||
cwd=effective_cwd, env=vitest_env, timeout=total_timeout, check=False, text=True, capture_output=True
|
||||
)
|
||||
result = subprocess.run(vitest_cmd, **run_args) # noqa: PLW1510
|
||||
|
||||
# Combine stderr into stdout for timing markers
|
||||
stdout = result.stdout or ""
|
||||
if result.stderr:
|
||||
stdout = stdout + "\n" + result.stderr if stdout else result.stderr
|
||||
|
||||
result = subprocess.CompletedProcess(args=result.args, returncode=result.returncode, stdout=stdout, stderr="")
|
||||
|
||||
except subprocess.TimeoutExpired:
|
||||
logger.warning(f"Vitest benchmarking timed out after {total_timeout}s")
|
||||
result = subprocess.CompletedProcess(args=vitest_cmd, returncode=-1, stdout="", stderr="Benchmarking timed out")
|
||||
except FileNotFoundError:
|
||||
logger.error("Vitest not found for benchmarking")
|
||||
result = subprocess.CompletedProcess(args=vitest_cmd, returncode=-1, stdout="", stderr="Vitest not found")
|
||||
|
||||
wall_clock_seconds = time.time() - total_start_time
|
||||
logger.debug(f"Vitest benchmarking completed in {wall_clock_seconds:.2f}s")
|
||||
|
||||
return result_file_path, result
|
||||
|
||||
|
||||
def run_vitest_line_profile_tests(
|
||||
test_paths: TestFiles,
|
||||
test_env: dict[str, str],
|
||||
cwd: Path,
|
||||
*,
|
||||
timeout: int | None = None,
|
||||
project_root: Path | None = None,
|
||||
line_profile_output_file: Path | None = None,
|
||||
) -> tuple[Path, subprocess.CompletedProcess]:
|
||||
"""Run Vitest tests for line profiling.
|
||||
|
||||
This runs tests against source code that has been instrumented with line profiler.
|
||||
The instrumentation collects execution counts and timing per line.
|
||||
|
||||
Args:
|
||||
test_paths: TestFiles object containing test file information.
|
||||
test_env: Environment variables for the test run.
|
||||
cwd: Working directory for running tests.
|
||||
timeout: Optional timeout in seconds for the subprocess.
|
||||
project_root: Vitest project root (directory containing vitest.config or package.json).
|
||||
line_profile_output_file: Path where line profile results will be written.
|
||||
|
||||
Returns:
|
||||
Tuple of (result_file_path, subprocess_result).
|
||||
|
||||
"""
|
||||
result_file_path = get_run_tmp_file(Path("vitest_line_profile_results.xml"))
|
||||
|
||||
# Get test files to run - use instrumented behavior files if available, otherwise benchmarking files
|
||||
test_files = []
|
||||
for file in test_paths.test_files:
|
||||
if file.instrumented_behavior_file_path:
|
||||
test_files.append(Path(file.instrumented_behavior_file_path))
|
||||
elif file.benchmarking_file_path:
|
||||
test_files.append(Path(file.benchmarking_file_path))
|
||||
|
||||
# Use provided project_root, or detect it as fallback
|
||||
if project_root is None and test_files:
|
||||
project_root = _find_vitest_project_root(test_files[0])
|
||||
|
||||
effective_cwd = project_root if project_root else cwd
|
||||
logger.debug(f"Vitest line profiling working directory: {effective_cwd}")
|
||||
|
||||
# Ensure the codeflash npm package is installed
|
||||
_ensure_runtime_files(effective_cwd)
|
||||
|
||||
# Build Vitest command for line profiling - simple run without benchmarking loops
|
||||
vitest_cmd = [
|
||||
"npx",
|
||||
"vitest",
|
||||
"run",
|
||||
"--reporter=default",
|
||||
"--reporter=junit",
|
||||
"--no-file-parallelism", # Serial execution for consistent line profiling
|
||||
]
|
||||
|
||||
vitest_cmd.append(f"--outputFile={result_file_path}")
|
||||
|
||||
if timeout:
|
||||
vitest_cmd.append(f"--test-timeout={timeout * 1000}")
|
||||
|
||||
vitest_cmd.extend(str(f.resolve()) for f in test_files)
|
||||
|
||||
# Set up environment
|
||||
vitest_env = test_env.copy()
|
||||
codeflash_sqlite_file = get_run_tmp_file(Path("test_return_values_line_profile.sqlite"))
|
||||
vitest_env["CODEFLASH_OUTPUT_FILE"] = str(codeflash_sqlite_file)
|
||||
vitest_env["CODEFLASH_TEST_ITERATION"] = "0"
|
||||
vitest_env["CODEFLASH_LOOP_INDEX"] = "1"
|
||||
vitest_env["CODEFLASH_MODE"] = "line_profile"
|
||||
vitest_env["CODEFLASH_RANDOM_SEED"] = "42"
|
||||
|
||||
# Pass the line profile output file path to the instrumented code
|
||||
if line_profile_output_file:
|
||||
vitest_env["CODEFLASH_LINE_PROFILE_OUTPUT"] = str(line_profile_output_file)
|
||||
|
||||
# Subprocess timeout should be larger than per-test timeout to account for startup
|
||||
subprocess_timeout = max(120, (timeout or 60) * 10)
|
||||
|
||||
logger.debug(f"Running Vitest line profile tests: {' '.join(vitest_cmd)}")
|
||||
|
||||
start_time_ns = time.perf_counter_ns()
|
||||
try:
|
||||
run_args = get_cross_platform_subprocess_run_args(
|
||||
cwd=effective_cwd, env=vitest_env, timeout=subprocess_timeout, check=False, text=True, capture_output=True
|
||||
)
|
||||
result = subprocess.run(vitest_cmd, **run_args) # noqa: PLW1510
|
||||
# Combine stderr into stdout
|
||||
if result.stderr and not result.stdout:
|
||||
result = subprocess.CompletedProcess(
|
||||
args=result.args, returncode=result.returncode, stdout=result.stderr, stderr=""
|
||||
)
|
||||
elif result.stderr:
|
||||
result = subprocess.CompletedProcess(
|
||||
args=result.args, returncode=result.returncode, stdout=result.stdout + "\n" + result.stderr, stderr=""
|
||||
)
|
||||
logger.debug(f"Vitest line profile result: returncode={result.returncode}")
|
||||
except subprocess.TimeoutExpired:
|
||||
logger.warning(f"Vitest line profile tests timed out after {subprocess_timeout}s")
|
||||
result = subprocess.CompletedProcess(
|
||||
args=vitest_cmd, returncode=-1, stdout="", stderr="Line profile tests timed out"
|
||||
)
|
||||
except FileNotFoundError:
|
||||
logger.error("Vitest not found for line profiling")
|
||||
result = subprocess.CompletedProcess(args=vitest_cmd, returncode=-1, stdout="", stderr="Vitest not found")
|
||||
finally:
|
||||
wall_clock_ns = time.perf_counter_ns() - start_time_ns
|
||||
logger.debug(f"Vitest line profile tests completed in {wall_clock_ns / 1e9:.2f}s")
|
||||
|
||||
return result_file_path, result
|
||||
18
codeflash/languages/language_enum.py
Normal file
18
codeflash/languages/language_enum.py
Normal file
|
|
@ -0,0 +1,18 @@
|
|||
"""Language enum for multi-language support.
|
||||
|
||||
This module is kept separate to avoid circular imports.
|
||||
"""
|
||||
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class Language(str, Enum):
|
||||
"""Supported programming languages."""
|
||||
|
||||
PYTHON = "python"
|
||||
JAVASCRIPT = "javascript"
|
||||
TYPESCRIPT = "typescript"
|
||||
JAVA = "java"
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self.value
|
||||
|
|
@ -6,13 +6,13 @@ import logging
|
|||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
|
||||
from codeflash.languages.base import (
|
||||
CodeContext,
|
||||
FunctionFilterCriteria,
|
||||
FunctionInfo,
|
||||
HelperFunction,
|
||||
Language,
|
||||
ParentInfo,
|
||||
ReferenceInfo,
|
||||
TestInfo,
|
||||
TestResult,
|
||||
)
|
||||
|
|
@ -45,6 +45,11 @@ class PythonSupport:
|
|||
"""File extensions supported by Python."""
|
||||
return (".py", ".pyw")
|
||||
|
||||
@property
|
||||
def default_file_extension(self) -> str:
|
||||
"""Default file extension for Python."""
|
||||
return ".py"
|
||||
|
||||
@property
|
||||
def test_framework(self) -> str:
|
||||
"""Primary test framework for Python."""
|
||||
|
|
@ -58,7 +63,7 @@ class PythonSupport:
|
|||
|
||||
def discover_functions(
|
||||
self, file_path: Path, filter_criteria: FunctionFilterCriteria | None = None
|
||||
) -> list[FunctionInfo]:
|
||||
) -> list[FunctionToOptimize]:
|
||||
"""Find all optimizable functions in a Python file.
|
||||
|
||||
Uses libcst to parse the file and find functions with return statements.
|
||||
|
|
@ -68,12 +73,12 @@ class PythonSupport:
|
|||
filter_criteria: Optional criteria to filter functions.
|
||||
|
||||
Returns:
|
||||
List of FunctionInfo objects for discovered functions.
|
||||
List of FunctionToOptimize objects for discovered functions.
|
||||
|
||||
"""
|
||||
import libcst as cst
|
||||
|
||||
from codeflash.discovery.functions_to_optimize import FunctionToOptimize, FunctionVisitor
|
||||
from codeflash.discovery.functions_to_optimize import FunctionVisitor
|
||||
|
||||
criteria = filter_criteria or FunctionFilterCriteria()
|
||||
|
||||
|
|
@ -90,7 +95,7 @@ class PythonSupport:
|
|||
function_visitor = FunctionVisitor(file_path=str(file_path))
|
||||
wrapper.visit(function_visitor)
|
||||
|
||||
functions: list[FunctionInfo] = []
|
||||
functions: list[FunctionToOptimize] = []
|
||||
for func in function_visitor.functions:
|
||||
if not isinstance(func, FunctionToOptimize):
|
||||
continue
|
||||
|
|
@ -107,23 +112,20 @@ class PythonSupport:
|
|||
if criteria.require_return and func.starting_line is None:
|
||||
continue
|
||||
|
||||
# Convert FunctionToOptimize to FunctionInfo
|
||||
parents = tuple(ParentInfo(name=p.name, type=p.type) for p in func.parents)
|
||||
|
||||
functions.append(
|
||||
FunctionInfo(
|
||||
name=func.function_name,
|
||||
file_path=file_path,
|
||||
start_line=func.starting_line or 1,
|
||||
end_line=func.ending_line or 1,
|
||||
start_col=func.starting_col,
|
||||
end_col=func.ending_col,
|
||||
parents=parents,
|
||||
is_async=func.is_async,
|
||||
is_method=len(func.parents) > 0,
|
||||
language=Language.PYTHON,
|
||||
)
|
||||
# Add is_method field based on parents
|
||||
func_with_is_method = FunctionToOptimize(
|
||||
function_name=func.function_name,
|
||||
file_path=file_path,
|
||||
parents=func.parents,
|
||||
starting_line=func.starting_line,
|
||||
ending_line=func.ending_line,
|
||||
starting_col=func.starting_col,
|
||||
ending_col=func.ending_col,
|
||||
is_async=func.is_async,
|
||||
is_method=len(func.parents) > 0 and any(p.type == "ClassDef" for p in func.parents),
|
||||
language="python",
|
||||
)
|
||||
functions.append(func_with_is_method)
|
||||
|
||||
return functions
|
||||
|
||||
|
|
@ -131,7 +133,9 @@ class PythonSupport:
|
|||
logger.warning("Failed to discover functions in %s: %s", file_path, e)
|
||||
return []
|
||||
|
||||
def discover_tests(self, test_root: Path, source_functions: Sequence[FunctionInfo]) -> dict[str, list[TestInfo]]:
|
||||
def discover_tests(
|
||||
self, test_root: Path, source_functions: Sequence[FunctionToOptimize]
|
||||
) -> dict[str, list[TestInfo]]:
|
||||
"""Map source functions to their tests via static analysis.
|
||||
|
||||
Args:
|
||||
|
|
@ -155,7 +159,7 @@ class PythonSupport:
|
|||
try:
|
||||
source = test_file.read_text()
|
||||
# Check if function name appears in test file
|
||||
if func.name in source:
|
||||
if func.function_name in source:
|
||||
result[func.qualified_name].append(
|
||||
TestInfo(test_name=test_file.stem, test_file=test_file, test_class=None)
|
||||
)
|
||||
|
|
@ -166,7 +170,7 @@ class PythonSupport:
|
|||
|
||||
# === Code Analysis ===
|
||||
|
||||
def extract_code_context(self, function: FunctionInfo, project_root: Path, module_root: Path) -> CodeContext:
|
||||
def extract_code_context(self, function: FunctionToOptimize, project_root: Path, module_root: Path) -> CodeContext:
|
||||
"""Extract function code and its dependencies.
|
||||
|
||||
Uses jedi and libcst for Python code analysis.
|
||||
|
|
@ -188,8 +192,8 @@ class PythonSupport:
|
|||
|
||||
# Extract the function source
|
||||
lines = source.splitlines(keepends=True)
|
||||
if function.start_line and function.end_line:
|
||||
target_lines = lines[function.start_line - 1 : function.end_line]
|
||||
if function.starting_line and function.ending_line:
|
||||
target_lines = lines[function.starting_line - 1 : function.ending_line]
|
||||
target_code = "".join(target_lines)
|
||||
else:
|
||||
target_code = ""
|
||||
|
|
@ -216,7 +220,7 @@ class PythonSupport:
|
|||
language=Language.PYTHON,
|
||||
)
|
||||
|
||||
def find_helper_functions(self, function: FunctionInfo, project_root: Path) -> list[HelperFunction]:
|
||||
def find_helper_functions(self, function: FunctionToOptimize, project_root: Path) -> list[HelperFunction]:
|
||||
"""Find helper functions called by the target function.
|
||||
|
||||
Uses jedi for Python code analysis.
|
||||
|
|
@ -285,20 +289,130 @@ class PythonSupport:
|
|||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.warning("Failed to find helpers for %s: %s", function.name, e)
|
||||
logger.warning("Failed to find helpers for %s: %s", function.function_name, e)
|
||||
|
||||
return helpers
|
||||
|
||||
def find_references(
|
||||
self, function: FunctionToOptimize, project_root: Path, tests_root: Path | None = None, max_files: int = 500
|
||||
) -> list[ReferenceInfo]:
|
||||
"""Find all references (call sites) to a function across the codebase.
|
||||
|
||||
Uses jedi to find all places where a Python function is called.
|
||||
|
||||
Args:
|
||||
function: The function to find references for.
|
||||
project_root: Root of the project to search.
|
||||
tests_root: Root of tests directory (references in tests are excluded).
|
||||
max_files: Maximum number of files to search.
|
||||
|
||||
Returns:
|
||||
List of ReferenceInfo objects describing each reference location.
|
||||
|
||||
"""
|
||||
try:
|
||||
import jedi
|
||||
|
||||
source = function.file_path.read_text()
|
||||
|
||||
# Find the function position
|
||||
script = jedi.Script(code=source, path=function.file_path)
|
||||
names = script.get_names(all_scopes=True, definitions=True)
|
||||
|
||||
function_pos = None
|
||||
for name in names:
|
||||
if name.type == "function" and name.name == function.name:
|
||||
# Check for class parent if it's a method
|
||||
if function.class_name:
|
||||
parent = name.parent()
|
||||
if parent and parent.name == function.class_name and parent.type == "class":
|
||||
function_pos = (name.line, name.column)
|
||||
break
|
||||
else:
|
||||
function_pos = (name.line, name.column)
|
||||
break
|
||||
|
||||
if function_pos is None:
|
||||
return []
|
||||
|
||||
# Get references using jedi
|
||||
script = jedi.Script(code=source, path=function.file_path, project=jedi.Project(path=project_root))
|
||||
references = script.get_references(line=function_pos[0], column=function_pos[1])
|
||||
|
||||
result: list[ReferenceInfo] = []
|
||||
seen_locations: set[tuple[Path, int, int]] = set()
|
||||
|
||||
for ref in references:
|
||||
if not ref.module_path:
|
||||
continue
|
||||
|
||||
ref_path = Path(ref.module_path)
|
||||
|
||||
# Skip the definition itself
|
||||
if ref_path == function.file_path and ref.line == function_pos[0]:
|
||||
continue
|
||||
|
||||
# Skip test files
|
||||
if tests_root:
|
||||
try:
|
||||
ref_path.relative_to(tests_root)
|
||||
continue
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
# Avoid duplicates
|
||||
loc_key = (ref_path, ref.line, ref.column)
|
||||
if loc_key in seen_locations:
|
||||
continue
|
||||
seen_locations.add(loc_key)
|
||||
|
||||
# Get context line
|
||||
try:
|
||||
ref_source = ref_path.read_text()
|
||||
lines = ref_source.splitlines()
|
||||
context = lines[ref.line - 1] if ref.line <= len(lines) else ""
|
||||
except Exception:
|
||||
context = ""
|
||||
|
||||
# Determine caller function
|
||||
caller_function = None
|
||||
try:
|
||||
parent = ref.parent()
|
||||
if parent and parent.type == "function":
|
||||
caller_function = parent.name
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
result.append(
|
||||
ReferenceInfo(
|
||||
file_path=ref_path,
|
||||
line=ref.line,
|
||||
column=ref.column,
|
||||
end_line=ref.line,
|
||||
end_column=ref.column + len(function.function_name),
|
||||
context=context.strip(),
|
||||
reference_type="call",
|
||||
import_name=function.function_name,
|
||||
caller_function=caller_function,
|
||||
)
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.warning("Failed to find references for %s: %s", function.function_name, e)
|
||||
return []
|
||||
|
||||
# === Code Transformation ===
|
||||
|
||||
def replace_function(self, source: str, function: FunctionInfo, new_source: str) -> str:
|
||||
def replace_function(self, source: str, function: FunctionToOptimize, new_source: str) -> str:
|
||||
"""Replace a function in source code with new implementation.
|
||||
|
||||
Uses libcst for Python code transformation.
|
||||
|
||||
Args:
|
||||
source: Original source code.
|
||||
function: FunctionInfo identifying the function to replace.
|
||||
function: FunctionToOptimize identifying the function to replace.
|
||||
new_source: New function source code.
|
||||
|
||||
Returns:
|
||||
|
|
@ -319,7 +433,7 @@ class PythonSupport:
|
|||
preexisting_objects=set(),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning("Failed to replace function %s: %s", function.name, e)
|
||||
logger.warning("Failed to replace function %s: %s", function.function_name, e)
|
||||
return source
|
||||
|
||||
def format_code(self, source: str, file_path: Path | None = None) -> str:
|
||||
|
|
@ -465,7 +579,7 @@ class PythonSupport:
|
|||
|
||||
# === Instrumentation ===
|
||||
|
||||
def instrument_for_behavior(self, source: str, functions: Sequence[FunctionInfo]) -> str:
|
||||
def instrument_for_behavior(self, source: str, functions: Sequence[FunctionToOptimize]) -> str:
|
||||
"""Add behavior instrumentation to capture inputs/outputs.
|
||||
|
||||
Args:
|
||||
|
|
@ -480,7 +594,7 @@ class PythonSupport:
|
|||
# This is a pass-through for now
|
||||
return source
|
||||
|
||||
def instrument_for_benchmarking(self, test_source: str, target_function: FunctionInfo) -> str:
|
||||
def instrument_for_benchmarking(self, test_source: str, target_function: FunctionToOptimize) -> str:
|
||||
"""Add timing instrumentation to test code.
|
||||
|
||||
Args:
|
||||
|
|
@ -721,7 +835,9 @@ class PythonSupport:
|
|||
mode=testing_mode,
|
||||
)
|
||||
|
||||
def instrument_source_for_line_profiler(self, func_info: FunctionInfo, line_profiler_output_file: Path) -> bool:
|
||||
def instrument_source_for_line_profiler(
|
||||
self, func_info: FunctionToOptimize, line_profiler_output_file: Path
|
||||
) -> bool:
|
||||
"""Instrument source code for line profiling.
|
||||
|
||||
Args:
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ import logging
|
|||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from codeflash.languages.base import Language
|
||||
from codeflash.languages.language_enum import Language
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Iterable
|
||||
|
|
@ -30,6 +30,33 @@ _LANGUAGE_REGISTRY: dict[Language, type[LanguageSupport]] = {}
|
|||
# Cache of instantiated language support objects
|
||||
_SUPPORT_CACHE: dict[Language, LanguageSupport] = {}
|
||||
|
||||
# Flag to track if language modules have been imported
|
||||
_languages_registered = False
|
||||
|
||||
|
||||
def _ensure_languages_registered() -> None:
|
||||
"""Ensure all language support modules are imported and registered.
|
||||
|
||||
This lazily imports the language support modules to avoid circular imports
|
||||
at module load time. The imports trigger the @register_language decorators
|
||||
which populate the registries.
|
||||
"""
|
||||
global _languages_registered
|
||||
if _languages_registered:
|
||||
return
|
||||
|
||||
# Import support modules to trigger registration
|
||||
# These imports are deferred to avoid circular imports
|
||||
import contextlib
|
||||
|
||||
with contextlib.suppress(ImportError):
|
||||
from codeflash.languages.python import support as _
|
||||
|
||||
with contextlib.suppress(ImportError):
|
||||
from codeflash.languages.javascript import support as _ # noqa: F401
|
||||
|
||||
_languages_registered = True
|
||||
|
||||
|
||||
class UnsupportedLanguageError(Exception):
|
||||
"""Raised when attempting to use an unsupported language."""
|
||||
|
|
@ -123,6 +150,10 @@ def get_language_support(identifier: Path | Language | str) -> LanguageSupport:
|
|||
Raises:
|
||||
UnsupportedLanguageError: If the language is not supported.
|
||||
|
||||
Note:
|
||||
This function lazily imports language support modules on first call
|
||||
to avoid circular import issues at module load time.
|
||||
|
||||
Example:
|
||||
# By file path
|
||||
lang = get_language_support(Path("example.py"))
|
||||
|
|
@ -137,6 +168,7 @@ def get_language_support(identifier: Path | Language | str) -> LanguageSupport:
|
|||
lang = get_language_support("python")
|
||||
|
||||
"""
|
||||
_ensure_languages_registered()
|
||||
language: Language | None = None
|
||||
|
||||
if isinstance(identifier, Language):
|
||||
|
|
@ -178,6 +210,42 @@ def get_language_support(identifier: Path | Language | str) -> LanguageSupport:
|
|||
_FRAMEWORK_CACHE: dict[str, LanguageSupport] = {}
|
||||
|
||||
|
||||
def get_language_support_by_common_formatters(formatter_cmd: str | list[str]) -> LanguageSupport | None:
|
||||
_ensure_languages_registered()
|
||||
language: Language | None = None
|
||||
if isinstance(formatter_cmd, str):
|
||||
formatter_cmd = [formatter_cmd]
|
||||
|
||||
if len(formatter_cmd) == 1:
|
||||
formatter_cmd = formatter_cmd[0].split(" ")
|
||||
|
||||
# Try as extension first
|
||||
ext = None
|
||||
|
||||
py_formatters = ["black", "isort", "ruff", "autopep8", "yapf", "pyfmt"]
|
||||
js_ts_formatters = ["prettier", "eslint", "biome", "rome", "deno", "standard", "tslint"]
|
||||
|
||||
if any(cmd in py_formatters for cmd in formatter_cmd):
|
||||
ext = ".py"
|
||||
elif any(cmd in js_ts_formatters for cmd in formatter_cmd):
|
||||
ext = ".js"
|
||||
|
||||
if ext is None:
|
||||
# can't determine language
|
||||
return None
|
||||
|
||||
cls = _EXTENSION_REGISTRY[ext]
|
||||
language = cls().language
|
||||
|
||||
# Return cached instance or create new one
|
||||
if language not in _SUPPORT_CACHE:
|
||||
if language not in _LANGUAGE_REGISTRY:
|
||||
raise UnsupportedLanguageError(str(language), get_supported_languages())
|
||||
_SUPPORT_CACHE[language] = _LANGUAGE_REGISTRY[language]()
|
||||
|
||||
return _SUPPORT_CACHE[language]
|
||||
|
||||
|
||||
def get_language_support_by_framework(test_framework: str) -> LanguageSupport | None:
|
||||
"""Get language support for a test framework.
|
||||
|
||||
|
|
@ -238,6 +306,7 @@ def detect_project_language(project_root: Path, module_root: Path) -> Language:
|
|||
UnsupportedLanguageError: If no supported language is detected.
|
||||
|
||||
"""
|
||||
_ensure_languages_registered()
|
||||
extension_counts: dict[str, int] = {}
|
||||
|
||||
# Count files by extension
|
||||
|
|
@ -265,6 +334,7 @@ def get_supported_languages() -> list[str]:
|
|||
List of language name strings.
|
||||
|
||||
"""
|
||||
_ensure_languages_registered()
|
||||
return [lang.value for lang in _LANGUAGE_REGISTRY]
|
||||
|
||||
|
||||
|
|
@ -275,6 +345,7 @@ def get_supported_extensions() -> list[str]:
|
|||
List of extension strings (with leading dots).
|
||||
|
||||
"""
|
||||
_ensure_languages_registered()
|
||||
return list(_EXTENSION_REGISTRY.keys())
|
||||
|
||||
|
||||
|
|
@ -300,10 +371,12 @@ def clear_registry() -> None:
|
|||
|
||||
Primarily useful for testing.
|
||||
"""
|
||||
global _languages_registered
|
||||
_EXTENSION_REGISTRY.clear()
|
||||
_LANGUAGE_REGISTRY.clear()
|
||||
_SUPPORT_CACHE.clear()
|
||||
_FRAMEWORK_CACHE.clear()
|
||||
_languages_registered = False
|
||||
|
||||
|
||||
def clear_cache() -> None:
|
||||
|
|
|
|||
145
codeflash/languages/test_framework.py
Normal file
145
codeflash/languages/test_framework.py
Normal file
|
|
@ -0,0 +1,145 @@
|
|||
"""Singleton for the current test framework being used in the codeflash session.
|
||||
|
||||
This module provides a centralized way to access and set the current test framework
|
||||
throughout the codeflash codebase, similar to how the language singleton works.
|
||||
|
||||
For JavaScript/TypeScript projects, this determines whether to use Jest or Vitest
|
||||
for test execution.
|
||||
|
||||
Usage:
|
||||
from codeflash.languages.test_framework import (
|
||||
current_test_framework,
|
||||
set_current_test_framework,
|
||||
is_jest,
|
||||
is_vitest,
|
||||
)
|
||||
|
||||
# Set the test framework at the start of a session (auto-detected from package.json)
|
||||
set_current_test_framework("vitest")
|
||||
|
||||
# Check the current test framework anywhere in the codebase
|
||||
if is_vitest():
|
||||
# Vitest-specific code
|
||||
...
|
||||
|
||||
# Get the current test framework
|
||||
framework = current_test_framework()
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Literal
|
||||
|
||||
TestFramework = Literal["jest", "vitest", "mocha", "pytest", "unittest"]
|
||||
|
||||
# Module-level singleton for the current test framework
|
||||
_current_test_framework: TestFramework | None = None
|
||||
|
||||
|
||||
def current_test_framework() -> TestFramework | None:
|
||||
"""Get the current test framework being used in this codeflash session.
|
||||
|
||||
Returns:
|
||||
The current test framework string, or None if not set.
|
||||
|
||||
"""
|
||||
return _current_test_framework
|
||||
|
||||
|
||||
def set_current_test_framework(framework: TestFramework | str | None) -> None:
|
||||
"""Set the current test framework for this codeflash session.
|
||||
|
||||
This should be called once at the start of an optimization run,
|
||||
typically after reading the project configuration.
|
||||
|
||||
Args:
|
||||
framework: Test framework name ("jest", "vitest", "mocha", "pytest", "unittest").
|
||||
|
||||
"""
|
||||
global _current_test_framework
|
||||
|
||||
if _current_test_framework is not None:
|
||||
return
|
||||
|
||||
if framework is not None:
|
||||
framework = framework.lower()
|
||||
if framework not in ("jest", "vitest", "mocha", "pytest", "unittest"):
|
||||
# Default to jest for unknown JS frameworks, pytest for unknown Python
|
||||
from codeflash.languages.current import is_javascript
|
||||
|
||||
framework = "jest" if is_javascript() else "pytest"
|
||||
|
||||
_current_test_framework = framework
|
||||
|
||||
|
||||
def reset_test_framework() -> None:
|
||||
"""Reset the current test framework to None.
|
||||
|
||||
Useful for testing or when starting a new session.
|
||||
"""
|
||||
global _current_test_framework
|
||||
_current_test_framework = None
|
||||
|
||||
|
||||
def is_jest() -> bool:
|
||||
"""Check if the current test framework is Jest.
|
||||
|
||||
Returns:
|
||||
True if the current test framework is Jest.
|
||||
|
||||
"""
|
||||
return _current_test_framework == "jest"
|
||||
|
||||
|
||||
def is_vitest() -> bool:
|
||||
"""Check if the current test framework is Vitest.
|
||||
|
||||
Returns:
|
||||
True if the current test framework is Vitest.
|
||||
|
||||
"""
|
||||
return _current_test_framework == "vitest"
|
||||
|
||||
|
||||
def is_mocha() -> bool:
|
||||
"""Check if the current test framework is Mocha.
|
||||
|
||||
Returns:
|
||||
True if the current test framework is Mocha.
|
||||
|
||||
"""
|
||||
return _current_test_framework == "mocha"
|
||||
|
||||
|
||||
def is_pytest() -> bool:
|
||||
"""Check if the current test framework is pytest.
|
||||
|
||||
Returns:
|
||||
True if the current test framework is pytest.
|
||||
|
||||
"""
|
||||
return _current_test_framework == "pytest"
|
||||
|
||||
|
||||
def is_unittest() -> bool:
|
||||
"""Check if the current test framework is unittest.
|
||||
|
||||
Returns:
|
||||
True if the current test framework is unittest.
|
||||
|
||||
"""
|
||||
return _current_test_framework == "unittest"
|
||||
|
||||
|
||||
def get_js_test_framework_or_default() -> TestFramework:
|
||||
"""Get the current test framework for JS/TS, defaulting to 'jest' if not set.
|
||||
|
||||
This is a convenience function for JS/TS code that needs a framework.
|
||||
|
||||
Returns:
|
||||
The current test framework, or 'jest' as default.
|
||||
|
||||
"""
|
||||
if _current_test_framework in ("jest", "vitest", "mocha"):
|
||||
return _current_test_framework
|
||||
return "jest"
|
||||
|
|
@ -454,23 +454,47 @@ class TreeSitterAnalyzer:
|
|||
|
||||
return imports
|
||||
|
||||
def _walk_tree_for_imports(self, node: Node, source_bytes: bytes, imports: list[ImportInfo]) -> None:
|
||||
"""Recursively walk the tree to find import statements."""
|
||||
def _walk_tree_for_imports(
|
||||
self, node: Node, source_bytes: bytes, imports: list[ImportInfo], in_function: bool = False
|
||||
) -> None:
|
||||
"""Recursively walk the tree to find import statements.
|
||||
|
||||
Args:
|
||||
node: Current node to check.
|
||||
source_bytes: Source code bytes.
|
||||
imports: List to append found imports to.
|
||||
in_function: Whether we're currently inside a function/method body.
|
||||
|
||||
"""
|
||||
# Track when we enter function/method bodies
|
||||
# These node types contain function/method bodies where require() should not be treated as imports
|
||||
function_body_types = {
|
||||
"function_declaration",
|
||||
"method_definition",
|
||||
"arrow_function",
|
||||
"function_expression",
|
||||
"function", # Generic function in some grammars
|
||||
}
|
||||
|
||||
if node.type == "import_statement":
|
||||
import_info = self._extract_import_info(node, source_bytes)
|
||||
if import_info:
|
||||
imports.append(import_info)
|
||||
|
||||
# Also handle require() calls for CommonJS
|
||||
if node.type == "call_expression":
|
||||
# Also handle require() calls for CommonJS, but only at module level
|
||||
# require() inside functions is a dynamic import, not a module import
|
||||
if node.type == "call_expression" and not in_function:
|
||||
func_node = node.child_by_field_name("function")
|
||||
if func_node and self.get_node_text(func_node, source_bytes) == "require":
|
||||
import_info = self._extract_require_info(node, source_bytes)
|
||||
if import_info:
|
||||
imports.append(import_info)
|
||||
|
||||
# Update in_function flag for children
|
||||
child_in_function = in_function or node.type in function_body_types
|
||||
|
||||
for child in node.children:
|
||||
self._walk_tree_for_imports(child, source_bytes, imports)
|
||||
self._walk_tree_for_imports(child, source_bytes, imports, child_in_function)
|
||||
|
||||
def _extract_import_info(self, node: Node, source_bytes: bytes) -> ImportInfo | None:
|
||||
"""Extract import information from an import statement node."""
|
||||
|
|
@ -841,20 +865,27 @@ class TreeSitterAnalyzer:
|
|||
end_line=node.end_point[0] + 1,
|
||||
)
|
||||
|
||||
def is_function_exported(self, source: str, function_name: str) -> tuple[bool, str | None]:
|
||||
def is_function_exported(
|
||||
self, source: str, function_name: str, class_name: str | None = None
|
||||
) -> tuple[bool, str | None]:
|
||||
"""Check if a function is exported and get its export name.
|
||||
|
||||
For class methods, also checks if the containing class is exported.
|
||||
|
||||
Args:
|
||||
source: The source code to analyze.
|
||||
function_name: The name of the function to check.
|
||||
class_name: For class methods, the name of the containing class.
|
||||
|
||||
Returns:
|
||||
Tuple of (is_exported, export_name). export_name may differ from
|
||||
function_name if exported with an alias.
|
||||
function_name if exported with an alias. For class methods,
|
||||
returns the class export name.
|
||||
|
||||
"""
|
||||
exports = self.find_exports(source)
|
||||
|
||||
# First, check if the function itself is directly exported
|
||||
for export in exports:
|
||||
# Check default export
|
||||
if export.default_export == function_name:
|
||||
|
|
@ -865,6 +896,18 @@ class TreeSitterAnalyzer:
|
|||
if name == function_name:
|
||||
return (True, alias if alias else name)
|
||||
|
||||
# For class methods, check if the containing class is exported
|
||||
if class_name:
|
||||
for export in exports:
|
||||
# Check if class is default export
|
||||
if export.default_export == class_name:
|
||||
return (True, class_name)
|
||||
|
||||
# Check if class is in named exports
|
||||
for name, alias in export.exported_names:
|
||||
if name == class_name:
|
||||
return (True, alias if alias else name)
|
||||
|
||||
return (False, None)
|
||||
|
||||
def find_function_calls(self, source: str, within_function: FunctionNode) -> list[str]:
|
||||
|
|
|
|||
|
|
@ -4,7 +4,12 @@ If you might want to work with us on finally making performance a
|
|||
solved problem, please reach out to us at careers@codeflash.ai. We're hiring!
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from codeflash.cli_cmds.cli import parse_args, process_pyproject_config
|
||||
from codeflash.cli_cmds.cmd_init import CODEFLASH_LOGO, ask_run_end_to_end_test
|
||||
|
|
@ -16,6 +21,9 @@ from codeflash.code_utils.version_check import check_for_newer_minor_version
|
|||
from codeflash.telemetry import posthog_cf
|
||||
from codeflash.telemetry.sentry import init_sentry
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from argparse import Namespace
|
||||
|
||||
|
||||
def main() -> None:
|
||||
"""Entry point for the codeflash command-line interface."""
|
||||
|
|
@ -39,7 +47,12 @@ def main() -> None:
|
|||
posthog_cf.initialize_posthog(enabled=not args.disable_telemetry)
|
||||
ask_run_end_to_end_test(args)
|
||||
else:
|
||||
args = process_pyproject_config(args)
|
||||
# Check for first-run experience (no config exists)
|
||||
loaded_args = _handle_config_loading(args)
|
||||
if loaded_args is None:
|
||||
sys.exit(0)
|
||||
args = loaded_args
|
||||
|
||||
if not env_utils.check_formatter_installed(args.formatter_cmds):
|
||||
return
|
||||
args.previous_checkpoint_functions = ask_should_use_checkpoint_get_functions(args)
|
||||
|
|
@ -51,6 +64,52 @@ def main() -> None:
|
|||
optimizer.run_with_args(args)
|
||||
|
||||
|
||||
def _handle_config_loading(args: Namespace) -> Namespace | None:
|
||||
"""Handle config loading with first-run experience support.
|
||||
|
||||
If no config exists and not in CI, triggers the first-run experience.
|
||||
Otherwise, loads config normally.
|
||||
|
||||
Args:
|
||||
args: CLI args namespace.
|
||||
|
||||
Returns:
|
||||
Updated args with config loaded, or None if user cancelled first-run.
|
||||
|
||||
"""
|
||||
from codeflash.setup.first_run import handle_first_run, is_first_run
|
||||
|
||||
# Check if we're in CI environment
|
||||
is_ci = any(
|
||||
var in ("true", "1", "True") for var in [os.environ.get("CI", ""), os.environ.get("GITHUB_ACTIONS", "")]
|
||||
)
|
||||
|
||||
# Check if first run (no config exists)
|
||||
if is_first_run() and not is_ci:
|
||||
# Skip API key check if already set
|
||||
skip_api_key = bool(os.environ.get("CODEFLASH_API_KEY"))
|
||||
|
||||
# Handle first-run experience
|
||||
result = handle_first_run(args=args, skip_confirm=getattr(args, "yes", False), skip_api_key=skip_api_key)
|
||||
|
||||
if result is None:
|
||||
return None
|
||||
|
||||
# Merge first-run results with any CLI overrides
|
||||
args = result
|
||||
# Still need to process some config values
|
||||
# Config might not exist yet if first run just saved it - that's OK
|
||||
import contextlib
|
||||
|
||||
with contextlib.suppress(ValueError):
|
||||
args = process_pyproject_config(args)
|
||||
|
||||
return args
|
||||
|
||||
# Normal config loading
|
||||
return process_pyproject_config(args)
|
||||
|
||||
|
||||
def print_codeflash_banner() -> None:
|
||||
paneled_text(
|
||||
CODEFLASH_LOGO, panel_args={"title": "https://codeflash.ai", "expand": False}, text_args={"style": "bold gold3"}
|
||||
|
|
|
|||
18
codeflash/models/function_types.py
Normal file
18
codeflash/models/function_types.py
Normal file
|
|
@ -0,0 +1,18 @@
|
|||
"""Simple function-related types with no dependencies.
|
||||
|
||||
This module contains basic types used for function representation.
|
||||
It is intentionally kept dependency-free to avoid circular imports.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pydantic.dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class FunctionParent:
|
||||
name: str
|
||||
type: str
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"{self.type}:{self.name}"
|
||||
|
|
@ -599,10 +599,8 @@ class CodePosition:
|
|||
col_no: int
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class FunctionParent:
|
||||
name: str
|
||||
type: str
|
||||
# Re-export FunctionParent for backward compatibility
|
||||
from codeflash.models.function_types import FunctionParent # noqa: E402
|
||||
|
||||
|
||||
class OriginalCodeBaseline(BaseModel):
|
||||
|
|
|
|||
|
|
@ -77,7 +77,7 @@ from codeflash.context.unused_definition_remover import detect_unused_helper_fun
|
|||
from codeflash.discovery.functions_to_optimize import was_function_previously_optimized
|
||||
from codeflash.either import Failure, Success, is_successful
|
||||
from codeflash.languages import is_java, is_python
|
||||
from codeflash.languages.base import FunctionInfo, Language
|
||||
from codeflash.languages.base import Language
|
||||
from codeflash.languages.current import current_language_support, is_typescript
|
||||
from codeflash.languages.javascript.module_system import detect_module_system
|
||||
from codeflash.lsp.helpers import is_LSP_enabled, report_to_markdown_table, tree_to_markdown
|
||||
|
|
@ -2285,10 +2285,10 @@ class FunctionOptimizer:
|
|||
else self.function_trace_id,
|
||||
"coverage_message": coverage_message,
|
||||
"replay_tests": replay_tests,
|
||||
"concolic_tests": concolic_tests,
|
||||
# "concolic_tests": concolic_tests,
|
||||
"language": self.function_to_optimize.language,
|
||||
"original_line_profiler": original_code_baseline.line_profile_results.get("str_out", ""),
|
||||
"optimized_line_profiler": best_optimization.line_profiler_test_results.get("str_out", ""),
|
||||
# "original_line_profiler": original_code_baseline.line_profile_results.get("str_out", ""),
|
||||
# "optimized_line_profiler": best_optimization.line_profiler_test_results.get("str_out", ""),
|
||||
}
|
||||
|
||||
raise_pr = not self.args.no_pr
|
||||
|
|
@ -2955,18 +2955,8 @@ class FunctionOptimizer:
|
|||
# NOTE: currently this handles single file only, add support to multi file instrumentation (or should it be kept for the main file only)
|
||||
original_source = Path(self.function_to_optimize.file_path).read_text()
|
||||
# Instrument source code
|
||||
func_info = FunctionInfo(
|
||||
name=self.function_to_optimize.function_name,
|
||||
file_path=self.function_to_optimize.file_path,
|
||||
start_line=self.function_to_optimize.starting_line,
|
||||
end_line=self.function_to_optimize.ending_line,
|
||||
start_col=self.function_to_optimize.starting_col,
|
||||
end_col=self.function_to_optimize.ending_col,
|
||||
is_async=self.function_to_optimize.is_async,
|
||||
language=self.language_support.language,
|
||||
)
|
||||
success = self.language_support.instrument_source_for_line_profiler(
|
||||
func_info=func_info, line_profiler_output_file=line_profiler_output_path
|
||||
func_info=self.function_to_optimize, line_profiler_output_file=line_profiler_output_path
|
||||
)
|
||||
if not success:
|
||||
return {"timings": {}, "unit": 0, "str_out": ""}
|
||||
|
|
|
|||
|
|
@ -90,6 +90,33 @@ class Optimizer:
|
|||
current = current.parent
|
||||
return None
|
||||
|
||||
def _verify_js_requirements(self) -> None:
|
||||
"""Verify JavaScript/TypeScript requirements before optimization.
|
||||
|
||||
Checks that Node.js, npm, and the test framework are available.
|
||||
Logs warnings if requirements are not met but does not abort.
|
||||
|
||||
"""
|
||||
from codeflash.languages import get_language_support
|
||||
from codeflash.languages.base import Language
|
||||
from codeflash.languages.test_framework import get_js_test_framework_or_default
|
||||
|
||||
js_project_root = self.test_cfg.js_project_root
|
||||
if not js_project_root:
|
||||
return
|
||||
|
||||
try:
|
||||
js_support = get_language_support(Language.JAVASCRIPT)
|
||||
test_framework = get_js_test_framework_or_default()
|
||||
success, errors = js_support.verify_requirements(js_project_root, test_framework)
|
||||
|
||||
if not success:
|
||||
logger.warning("JavaScript requirements check found issues:")
|
||||
for error in errors:
|
||||
logger.warning(f" - {error}")
|
||||
except Exception as e:
|
||||
logger.debug(f"Failed to verify JS requirements: {e}")
|
||||
|
||||
def run_benchmarks(
|
||||
self, file_to_funcs_to_optimize: dict[Path, list[FunctionToOptimize]], num_optimizable_functions: int
|
||||
) -> tuple[dict[str, dict[BenchmarkKey, float]], dict[BenchmarkKey, float]]:
|
||||
|
|
@ -466,6 +493,8 @@ class Optimizer:
|
|||
# For JavaScript, also set js_project_root for test execution
|
||||
if is_javascript():
|
||||
self.test_cfg.js_project_root = self._find_js_project_root(file_path)
|
||||
# Verify JS requirements before proceeding
|
||||
self._verify_js_requirements()
|
||||
break
|
||||
|
||||
if self.args.all:
|
||||
|
|
|
|||
|
|
@ -281,8 +281,8 @@ def check_create_pr(
|
|||
function_trace_id: str,
|
||||
coverage_message: str,
|
||||
replay_tests: str,
|
||||
concolic_tests: str,
|
||||
root_dir: Path,
|
||||
concolic_tests: str = "",
|
||||
git_remote: Optional[str] = None,
|
||||
optimization_review: str = "",
|
||||
original_line_profiler: str | None = None,
|
||||
|
|
|
|||
22
codeflash/setup/__init__.py
Normal file
22
codeflash/setup/__init__.py
Normal file
|
|
@ -0,0 +1,22 @@
|
|||
"""Setup module for Codeflash auto-detection and first-run experience.
|
||||
|
||||
This module provides:
|
||||
- Universal project detection across all supported languages
|
||||
- First-run experience with auto-detection and quick confirm
|
||||
- Config writing to native config files (pyproject.toml, package.json)
|
||||
"""
|
||||
|
||||
from codeflash.setup.config_schema import CodeflashConfig
|
||||
from codeflash.setup.config_writer import write_config
|
||||
from codeflash.setup.detector import DetectedProject, detect_project, has_existing_config
|
||||
from codeflash.setup.first_run import handle_first_run, is_first_run
|
||||
|
||||
__all__ = [
|
||||
"CodeflashConfig",
|
||||
"DetectedProject",
|
||||
"detect_project",
|
||||
"handle_first_run",
|
||||
"has_existing_config",
|
||||
"is_first_run",
|
||||
"write_config",
|
||||
]
|
||||
200
codeflash/setup/config_schema.py
Normal file
200
codeflash/setup/config_schema.py
Normal file
|
|
@ -0,0 +1,200 @@
|
|||
"""Codeflash configuration schema using Pydantic.
|
||||
|
||||
This module provides a language-agnostic internal representation of Codeflash
|
||||
configuration that can be serialized to different formats (TOML, JSON).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
class CodeflashConfig(BaseModel):
|
||||
"""Internal representation of Codeflash configuration.
|
||||
|
||||
This is the canonical config format used internally. It can be converted
|
||||
to/from pyproject.toml (Python) or package.json (JS/TS) formats.
|
||||
|
||||
Note: All paths are stored as strings (relative to project root).
|
||||
"""
|
||||
|
||||
# Core settings (always present after detection)
|
||||
language: str = Field(description="Project language: python, javascript, typescript")
|
||||
module_root: str = Field(default=".", description="Root directory containing source code")
|
||||
tests_root: str | None = Field(default=None, description="Root directory containing tests")
|
||||
|
||||
# Tooling settings (auto-detected, can be overridden)
|
||||
test_framework: str | None = Field(default=None, description="Test framework: pytest, jest, vitest, mocha")
|
||||
formatter_cmds: list[str] = Field(default_factory=list, description="Formatter commands")
|
||||
|
||||
# Optional settings
|
||||
ignore_paths: list[str] = Field(default_factory=list, description="Paths to ignore")
|
||||
benchmarks_root: str | None = Field(default=None, description="Benchmarks directory")
|
||||
|
||||
# Git settings
|
||||
git_remote: str = Field(default="origin", description="Git remote for PRs")
|
||||
|
||||
# Privacy settings
|
||||
disable_telemetry: bool = Field(default=False, description="Disable telemetry")
|
||||
|
||||
# Python-specific settings
|
||||
pytest_cmd: str = Field(default="pytest", description="Pytest command (Python only)")
|
||||
disable_imports_sorting: bool = Field(default=False, description="Disable import sorting (Python only)")
|
||||
override_fixtures: bool = Field(default=False, description="Override test fixtures (Python only)")
|
||||
|
||||
model_config = ConfigDict(extra="allow") # Allow extra fields for forward compatibility
|
||||
|
||||
def to_pyproject_dict(self) -> dict[str, Any]:
|
||||
"""Convert to pyproject.toml [tool.codeflash] format.
|
||||
|
||||
Uses kebab-case keys as per TOML conventions.
|
||||
Only includes non-default values to keep config minimal.
|
||||
"""
|
||||
config: dict[str, Any] = {}
|
||||
|
||||
# Always include required fields
|
||||
config["module-root"] = self.module_root
|
||||
if self.tests_root:
|
||||
config["tests-root"] = self.tests_root
|
||||
|
||||
# Include non-default optional fields
|
||||
if self.ignore_paths:
|
||||
config["ignore-paths"] = self.ignore_paths
|
||||
|
||||
if self.formatter_cmds and self.formatter_cmds != ["black $file"]:
|
||||
config["formatter-cmds"] = self.formatter_cmds
|
||||
elif not self.formatter_cmds:
|
||||
config["formatter-cmds"] = ["disabled"]
|
||||
|
||||
if self.benchmarks_root:
|
||||
config["benchmarks-root"] = self.benchmarks_root
|
||||
|
||||
if self.git_remote and self.git_remote != "origin":
|
||||
config["git-remote"] = self.git_remote
|
||||
|
||||
if self.disable_telemetry:
|
||||
config["disable-telemetry"] = True
|
||||
|
||||
if self.pytest_cmd and self.pytest_cmd != "pytest":
|
||||
config["pytest-cmd"] = self.pytest_cmd
|
||||
|
||||
if self.disable_imports_sorting:
|
||||
config["disable-imports-sorting"] = True
|
||||
|
||||
if self.override_fixtures:
|
||||
config["override-fixtures"] = True
|
||||
|
||||
return config
|
||||
|
||||
def to_package_json_dict(self) -> dict[str, Any]:
|
||||
"""Convert to package.json codeflash section format.
|
||||
|
||||
Uses camelCase keys as per JSON/JS conventions.
|
||||
Only includes values that override auto-detection.
|
||||
"""
|
||||
config: dict[str, Any] = {}
|
||||
|
||||
# Module root (only if not auto-detected default)
|
||||
if self.module_root and self.module_root not in (".", "src"):
|
||||
config["moduleRoot"] = self.module_root
|
||||
|
||||
if self.tests_root:
|
||||
config["testsRoot"] = self.tests_root
|
||||
|
||||
# Formatter (only if explicitly set)
|
||||
if self.formatter_cmds:
|
||||
config["formatterCmds"] = self.formatter_cmds
|
||||
|
||||
# Ignore paths (only if set)
|
||||
if self.ignore_paths:
|
||||
config["ignorePaths"] = self.ignore_paths
|
||||
|
||||
# Benchmarks root
|
||||
if self.benchmarks_root:
|
||||
config["benchmarksRoot"] = self.benchmarks_root
|
||||
|
||||
# Git remote (only if not default)
|
||||
if self.git_remote and self.git_remote != "origin":
|
||||
config["gitRemote"] = self.git_remote
|
||||
|
||||
# Telemetry
|
||||
if self.disable_telemetry:
|
||||
config["disableTelemetry"] = True
|
||||
|
||||
return config
|
||||
|
||||
@classmethod
|
||||
def from_detected_project(cls, detected: Any) -> CodeflashConfig:
|
||||
"""Create config from DetectedProject.
|
||||
|
||||
Args:
|
||||
detected: DetectedProject instance from detector.
|
||||
|
||||
Returns:
|
||||
CodeflashConfig instance.
|
||||
|
||||
"""
|
||||
return cls(
|
||||
language=detected.language,
|
||||
module_root=str(detected.module_root.relative_to(detected.project_root))
|
||||
if detected.module_root != detected.project_root
|
||||
else ".",
|
||||
tests_root=str(detected.tests_root.relative_to(detected.project_root)) if detected.tests_root else None,
|
||||
test_framework=detected.test_runner,
|
||||
formatter_cmds=detected.formatter_cmds,
|
||||
ignore_paths=[
|
||||
str(p.relative_to(detected.project_root)) for p in detected.ignore_paths if p != detected.project_root
|
||||
],
|
||||
pytest_cmd=detected.test_runner if detected.language == "python" else "pytest",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_pyproject_dict(cls, data: dict[str, Any], project_root: Path | None = None) -> CodeflashConfig:
|
||||
"""Create config from pyproject.toml [tool.codeflash] section.
|
||||
|
||||
Args:
|
||||
data: Dict from [tool.codeflash] section.
|
||||
project_root: Project root path (reserved for future path resolution).
|
||||
|
||||
Returns:
|
||||
CodeflashConfig instance.
|
||||
|
||||
"""
|
||||
_ = project_root # Reserved for future path resolution
|
||||
|
||||
def convert_key(key: str) -> str:
|
||||
"""Convert kebab-case to snake_case."""
|
||||
return key.replace("-", "_")
|
||||
|
||||
converted = {convert_key(k): v for k, v in data.items()}
|
||||
converted.setdefault("language", "python")
|
||||
return cls(**converted)
|
||||
|
||||
@classmethod
|
||||
def from_package_json_dict(cls, data: dict[str, Any], project_root: Path | None = None) -> CodeflashConfig:
|
||||
"""Create config from package.json codeflash section.
|
||||
|
||||
Args:
|
||||
data: Dict from package.json "codeflash" key.
|
||||
project_root: Project root path (reserved for future path resolution).
|
||||
|
||||
Returns:
|
||||
CodeflashConfig instance.
|
||||
|
||||
"""
|
||||
_ = project_root # Reserved for future path resolution
|
||||
|
||||
def convert_key(key: str) -> str:
|
||||
"""Convert camelCase to snake_case."""
|
||||
import re
|
||||
|
||||
return re.sub(r"(?<!^)(?=[A-Z])", "_", key).lower()
|
||||
|
||||
converted = {convert_key(k): v for k, v in data.items()}
|
||||
converted.setdefault("language", "javascript")
|
||||
return cls(**converted)
|
||||
246
codeflash/setup/config_writer.py
Normal file
246
codeflash/setup/config_writer.py
Normal file
|
|
@ -0,0 +1,246 @@
|
|||
"""Config writer for native config files.
|
||||
|
||||
This module writes Codeflash configuration to native config files:
|
||||
- Python: pyproject.toml [tool.codeflash]
|
||||
- JavaScript/TypeScript: package.json { "codeflash": {} }
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import tomlkit
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pathlib import Path
|
||||
|
||||
from codeflash.setup.config_schema import CodeflashConfig
|
||||
from codeflash.setup.detector import DetectedProject
|
||||
|
||||
|
||||
def write_config(detected: DetectedProject, config: CodeflashConfig | None = None) -> tuple[bool, str]:
|
||||
"""Write Codeflash config to the appropriate native config file.
|
||||
|
||||
Args:
|
||||
detected: DetectedProject with project information.
|
||||
config: Optional CodeflashConfig to write. If None, creates from detected.
|
||||
|
||||
Returns:
|
||||
Tuple of (success, message).
|
||||
|
||||
"""
|
||||
from codeflash.setup.config_schema import CodeflashConfig
|
||||
|
||||
if config is None:
|
||||
config = CodeflashConfig.from_detected_project(detected)
|
||||
|
||||
if detected.language == "python":
|
||||
return _write_pyproject_toml(detected.project_root, config)
|
||||
return _write_package_json(detected.project_root, config)
|
||||
|
||||
|
||||
def _write_pyproject_toml(project_root: Path, config: CodeflashConfig) -> tuple[bool, str]:
|
||||
"""Write config to pyproject.toml [tool.codeflash] section.
|
||||
|
||||
Creates pyproject.toml if it doesn't exist.
|
||||
Preserves existing content and formatting.
|
||||
|
||||
Args:
|
||||
project_root: Project root directory.
|
||||
config: CodeflashConfig to write.
|
||||
|
||||
Returns:
|
||||
Tuple of (success, message).
|
||||
|
||||
"""
|
||||
pyproject_path = project_root / "pyproject.toml"
|
||||
|
||||
try:
|
||||
# Load existing or create new
|
||||
if pyproject_path.exists():
|
||||
with pyproject_path.open("rb") as f:
|
||||
doc = tomlkit.parse(f.read())
|
||||
else:
|
||||
doc = tomlkit.document()
|
||||
|
||||
# Ensure [tool] section exists
|
||||
if "tool" not in doc:
|
||||
doc["tool"] = tomlkit.table()
|
||||
|
||||
# Create codeflash section
|
||||
codeflash_table = tomlkit.table()
|
||||
codeflash_table.add(tomlkit.comment("Codeflash configuration - https://docs.codeflash.ai"))
|
||||
|
||||
# Add config values
|
||||
config_dict = config.to_pyproject_dict()
|
||||
for key, value in config_dict.items():
|
||||
codeflash_table[key] = value
|
||||
|
||||
# Update the document
|
||||
doc["tool"]["codeflash"] = codeflash_table
|
||||
|
||||
# Write back
|
||||
with pyproject_path.open("w", encoding="utf8") as f:
|
||||
f.write(tomlkit.dumps(doc))
|
||||
|
||||
return True, f"Config saved to {pyproject_path}"
|
||||
|
||||
except Exception as e:
|
||||
return False, f"Failed to write pyproject.toml: {e}"
|
||||
|
||||
|
||||
def _write_package_json(project_root: Path, config: CodeflashConfig) -> tuple[bool, str]:
|
||||
"""Write config to package.json codeflash section.
|
||||
|
||||
Preserves existing content and formatting.
|
||||
Creates minimal config (only non-default values).
|
||||
|
||||
Args:
|
||||
project_root: Project root directory.
|
||||
config: CodeflashConfig to write.
|
||||
|
||||
Returns:
|
||||
Tuple of (success, message).
|
||||
|
||||
"""
|
||||
package_json_path = project_root / "package.json"
|
||||
|
||||
if not package_json_path.exists():
|
||||
return False, f"No package.json found at {project_root}"
|
||||
|
||||
try:
|
||||
# Load existing
|
||||
with package_json_path.open(encoding="utf8") as f:
|
||||
doc = json.load(f)
|
||||
|
||||
# Get config dict (only non-default values)
|
||||
config_dict = config.to_package_json_dict()
|
||||
|
||||
# Update or remove codeflash section
|
||||
if config_dict:
|
||||
doc["codeflash"] = config_dict
|
||||
action = "Updated"
|
||||
else:
|
||||
# Remove codeflash section if empty (all defaults)
|
||||
doc.pop("codeflash", None)
|
||||
action = "Using auto-detected defaults (no config needed)"
|
||||
|
||||
# Write back with nice formatting
|
||||
with package_json_path.open("w", encoding="utf8") as f:
|
||||
json.dump(doc, f, indent=2)
|
||||
f.write("\n") # Trailing newline
|
||||
|
||||
if config_dict:
|
||||
return True, f"{action} config in {package_json_path}"
|
||||
return True, action
|
||||
|
||||
except json.JSONDecodeError as e:
|
||||
return False, f"Invalid JSON in package.json: {e}"
|
||||
except Exception as e:
|
||||
return False, f"Failed to write package.json: {e}"
|
||||
|
||||
|
||||
def create_pyproject_toml(project_root: Path) -> tuple[bool, str]:
|
||||
"""Create a minimal pyproject.toml file.
|
||||
|
||||
Used when no pyproject.toml exists for a Python project.
|
||||
|
||||
Args:
|
||||
project_root: Project root directory.
|
||||
|
||||
Returns:
|
||||
Tuple of (success, message).
|
||||
|
||||
"""
|
||||
pyproject_path = project_root / "pyproject.toml"
|
||||
|
||||
if pyproject_path.exists():
|
||||
return False, f"pyproject.toml already exists at {pyproject_path}"
|
||||
|
||||
try:
|
||||
doc = tomlkit.document()
|
||||
doc.add(tomlkit.comment("Created by Codeflash"))
|
||||
doc.add(tomlkit.nl())
|
||||
|
||||
# Add minimal [tool.codeflash] section
|
||||
tool_table = tomlkit.table()
|
||||
codeflash_table = tomlkit.table()
|
||||
codeflash_table.add(tomlkit.comment("Codeflash configuration - https://docs.codeflash.ai"))
|
||||
tool_table["codeflash"] = codeflash_table
|
||||
doc["tool"] = tool_table
|
||||
|
||||
with pyproject_path.open("w", encoding="utf8") as f:
|
||||
f.write(tomlkit.dumps(doc))
|
||||
|
||||
return True, f"Created {pyproject_path}"
|
||||
|
||||
except Exception as e:
|
||||
return False, f"Failed to create pyproject.toml: {e}"
|
||||
|
||||
|
||||
def remove_config(project_root: Path, language: str) -> tuple[bool, str]:
|
||||
"""Remove Codeflash config from native config file.
|
||||
|
||||
Args:
|
||||
project_root: Project root directory.
|
||||
language: Project language ("python", "javascript", "typescript").
|
||||
|
||||
Returns:
|
||||
Tuple of (success, message).
|
||||
|
||||
"""
|
||||
if language == "python":
|
||||
return _remove_from_pyproject(project_root)
|
||||
return _remove_from_package_json(project_root)
|
||||
|
||||
|
||||
def _remove_from_pyproject(project_root: Path) -> tuple[bool, str]:
|
||||
"""Remove [tool.codeflash] section from pyproject.toml."""
|
||||
pyproject_path = project_root / "pyproject.toml"
|
||||
|
||||
if not pyproject_path.exists():
|
||||
return True, "No pyproject.toml found"
|
||||
|
||||
try:
|
||||
with pyproject_path.open("rb") as f:
|
||||
doc = tomlkit.parse(f.read())
|
||||
|
||||
if "tool" in doc and "codeflash" in doc["tool"]:
|
||||
del doc["tool"]["codeflash"]
|
||||
|
||||
with pyproject_path.open("w", encoding="utf8") as f:
|
||||
f.write(tomlkit.dumps(doc))
|
||||
|
||||
return True, "Removed [tool.codeflash] section from pyproject.toml"
|
||||
|
||||
return True, "No codeflash config found in pyproject.toml"
|
||||
|
||||
except Exception as e:
|
||||
return False, f"Failed to remove config: {e}"
|
||||
|
||||
|
||||
def _remove_from_package_json(project_root: Path) -> tuple[bool, str]:
|
||||
"""Remove codeflash section from package.json."""
|
||||
package_json_path = project_root / "package.json"
|
||||
|
||||
if not package_json_path.exists():
|
||||
return True, "No package.json found"
|
||||
|
||||
try:
|
||||
with package_json_path.open(encoding="utf8") as f:
|
||||
doc = json.load(f)
|
||||
|
||||
if "codeflash" in doc:
|
||||
del doc["codeflash"]
|
||||
|
||||
with package_json_path.open("w", encoding="utf8") as f:
|
||||
json.dump(doc, f, indent=2)
|
||||
f.write("\n")
|
||||
|
||||
return True, "Removed codeflash section from package.json"
|
||||
|
||||
return True, "No codeflash config found in package.json"
|
||||
|
||||
except Exception as e:
|
||||
return False, f"Failed to remove config: {e}"
|
||||
692
codeflash/setup/detector.py
Normal file
692
codeflash/setup/detector.py
Normal file
|
|
@ -0,0 +1,692 @@
|
|||
"""Universal project detection engine for Codeflash.
|
||||
|
||||
This module provides a single detection engine that works for all supported languages,
|
||||
consolidating detection logic from various parts of the codebase.
|
||||
|
||||
Usage:
|
||||
from codeflash.setup import detect_project
|
||||
|
||||
detected = detect_project()
|
||||
print(f"Language: {detected.language}")
|
||||
print(f"Module root: {detected.module_root}")
|
||||
print(f"Test runner: {detected.test_runner}")
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import tomlkit
|
||||
|
||||
|
||||
@dataclass
|
||||
class DetectedProject:
|
||||
"""Result of project auto-detection.
|
||||
|
||||
All paths are absolute. The confidence score indicates how certain
|
||||
we are about the detection (0.0 = guessing, 1.0 = certain).
|
||||
"""
|
||||
|
||||
# Core detection results
|
||||
language: str # "python" | "javascript" | "typescript"
|
||||
project_root: Path
|
||||
module_root: Path
|
||||
tests_root: Path | None
|
||||
|
||||
# Tooling detection
|
||||
test_runner: str # "pytest" | "jest" | "vitest" | "mocha"
|
||||
formatter_cmds: list[str]
|
||||
|
||||
# Ignore paths (absolute paths to ignore)
|
||||
ignore_paths: list[Path] = field(default_factory=list)
|
||||
|
||||
# Confidence score for the detection (0.0 - 1.0)
|
||||
confidence: float = 0.8
|
||||
|
||||
# Detection details (for debugging/display)
|
||||
detection_details: dict[str, str] = field(default_factory=dict)
|
||||
|
||||
def to_display_dict(self) -> dict[str, str]:
|
||||
"""Convert to dictionary for display purposes."""
|
||||
formatter_display = self.formatter_cmds[0] if self.formatter_cmds else "none detected"
|
||||
if len(self.formatter_cmds) > 1:
|
||||
formatter_display += f" (+{len(self.formatter_cmds) - 1} more)"
|
||||
|
||||
ignore_display = ", ".join(p.name for p in self.ignore_paths[:3])
|
||||
if len(self.ignore_paths) > 3:
|
||||
ignore_display += f" (+{len(self.ignore_paths) - 3} more)"
|
||||
|
||||
return {
|
||||
"Language": self.language.capitalize(),
|
||||
"Module root": str(self.module_root.relative_to(self.project_root))
|
||||
if self.module_root != self.project_root
|
||||
else ".",
|
||||
"Tests root": str(self.tests_root.relative_to(self.project_root)) if self.tests_root else "not detected",
|
||||
"Test runner": self.test_runner,
|
||||
"Formatter": formatter_display or "none",
|
||||
"Ignoring": ignore_display or "defaults only",
|
||||
}
|
||||
|
||||
|
||||
def detect_project(path: Path | None = None) -> DetectedProject:
|
||||
"""Auto-detect all project settings.
|
||||
|
||||
This is the main entry point for project detection. It finds the project root,
|
||||
detects the language, and auto-detects all configuration values.
|
||||
|
||||
Args:
|
||||
path: Starting path for detection. Defaults to current working directory.
|
||||
|
||||
Returns:
|
||||
DetectedProject with all detected settings.
|
||||
|
||||
Raises:
|
||||
ValueError: If no valid project can be detected.
|
||||
|
||||
"""
|
||||
start_path = path or Path.cwd()
|
||||
detection_details: dict[str, str] = {}
|
||||
|
||||
# Step 1: Find project root
|
||||
project_root = _find_project_root(start_path)
|
||||
if project_root is None:
|
||||
# No project root found, use start_path
|
||||
project_root = start_path
|
||||
detection_details["project_root"] = "using current directory (no markers found)"
|
||||
else:
|
||||
detection_details["project_root"] = f"found at {project_root}"
|
||||
|
||||
# Step 2: Detect language
|
||||
language, lang_confidence, lang_detail = _detect_language(project_root)
|
||||
detection_details["language"] = lang_detail
|
||||
|
||||
# Step 3: Detect module root
|
||||
module_root, module_detail = _detect_module_root(project_root, language)
|
||||
detection_details["module_root"] = module_detail
|
||||
|
||||
# Step 4: Detect tests root
|
||||
tests_root, tests_detail = _detect_tests_root(project_root, language)
|
||||
detection_details["tests_root"] = tests_detail
|
||||
|
||||
# Step 5: Detect test runner
|
||||
test_runner, runner_detail = _detect_test_runner(project_root, language)
|
||||
detection_details["test_runner"] = runner_detail
|
||||
|
||||
# Step 6: Detect formatter
|
||||
formatter_cmds, formatter_detail = _detect_formatter(project_root, language)
|
||||
detection_details["formatter"] = formatter_detail
|
||||
|
||||
# Step 7: Detect ignore paths
|
||||
ignore_paths, ignore_detail = _detect_ignore_paths(project_root, language)
|
||||
detection_details["ignore_paths"] = ignore_detail
|
||||
|
||||
# Calculate overall confidence
|
||||
confidence = lang_confidence * 0.4 + 0.6 # Language detection is 40% of confidence
|
||||
|
||||
return DetectedProject(
|
||||
language=language,
|
||||
project_root=project_root,
|
||||
module_root=module_root,
|
||||
tests_root=tests_root,
|
||||
test_runner=test_runner,
|
||||
formatter_cmds=formatter_cmds,
|
||||
ignore_paths=ignore_paths,
|
||||
confidence=confidence,
|
||||
detection_details=detection_details,
|
||||
)
|
||||
|
||||
|
||||
def _find_project_root(start_path: Path) -> Path | None:
|
||||
"""Find the project root by walking up the directory tree.
|
||||
|
||||
Looks for:
|
||||
- .git directory (git repository root)
|
||||
- pyproject.toml (Python project)
|
||||
- package.json (JavaScript/TypeScript project)
|
||||
- Cargo.toml (Rust project - future)
|
||||
|
||||
Args:
|
||||
start_path: Starting directory for search.
|
||||
|
||||
Returns:
|
||||
Path to project root, or None if not found.
|
||||
|
||||
"""
|
||||
current = start_path.resolve()
|
||||
|
||||
while current != current.parent:
|
||||
# Check for project markers
|
||||
markers = [".git", "pyproject.toml", "package.json", "Cargo.toml"]
|
||||
for marker in markers:
|
||||
if (current / marker).exists():
|
||||
return current
|
||||
current = current.parent
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _detect_language(project_root: Path) -> tuple[str, float, str]:
|
||||
"""Detect the primary programming language of the project.
|
||||
|
||||
Detection priority:
|
||||
1. tsconfig.json → TypeScript (high confidence)
|
||||
2. pyproject.toml or setup.py → Python (high confidence)
|
||||
3. package.json → JavaScript (medium confidence)
|
||||
4. File extension counting → best guess (low confidence)
|
||||
|
||||
Args:
|
||||
project_root: Root directory of the project.
|
||||
|
||||
Returns:
|
||||
Tuple of (language, confidence, detail_string).
|
||||
|
||||
"""
|
||||
has_tsconfig = (project_root / "tsconfig.json").exists()
|
||||
has_pyproject = (project_root / "pyproject.toml").exists()
|
||||
has_setup_py = (project_root / "setup.py").exists()
|
||||
has_package_json = (project_root / "package.json").exists()
|
||||
|
||||
# TypeScript (tsconfig.json is definitive)
|
||||
if has_tsconfig:
|
||||
return "typescript", 1.0, "tsconfig.json found"
|
||||
|
||||
# Python (pyproject.toml or setup.py)
|
||||
if has_pyproject or has_setup_py:
|
||||
marker = "pyproject.toml" if has_pyproject else "setup.py"
|
||||
# Check if it's also a JS project (monorepo)
|
||||
if has_package_json:
|
||||
# Count files to determine primary language
|
||||
py_count = len(list(project_root.rglob("*.py")))
|
||||
js_count = len(list(project_root.rglob("*.js"))) + len(list(project_root.rglob("*.ts")))
|
||||
if js_count > py_count * 2: # JS files significantly outnumber Python
|
||||
return "javascript", 0.7, "package.json found (more JS files than Python)"
|
||||
return "python", 1.0, f"{marker} found"
|
||||
|
||||
# JavaScript (package.json without Python markers)
|
||||
if has_package_json:
|
||||
return "javascript", 0.9, "package.json found"
|
||||
|
||||
# Fall back to file extension counting
|
||||
py_count = len(list(project_root.rglob("*.py")))
|
||||
js_count = len(list(project_root.rglob("*.js")))
|
||||
ts_count = len(list(project_root.rglob("*.ts")))
|
||||
|
||||
if ts_count > 0:
|
||||
return "typescript", 0.5, f"found {ts_count} .ts files"
|
||||
if js_count > py_count:
|
||||
return "javascript", 0.5, f"found {js_count} .js files"
|
||||
if py_count > 0:
|
||||
return "python", 0.5, f"found {py_count} .py files"
|
||||
|
||||
# Default to Python
|
||||
return "python", 0.3, "defaulting to Python"
|
||||
|
||||
|
||||
def _detect_module_root(project_root: Path, language: str) -> tuple[Path, str]:
|
||||
"""Detect the module/source root directory.
|
||||
|
||||
Args:
|
||||
project_root: Root directory of the project.
|
||||
language: Detected language.
|
||||
|
||||
Returns:
|
||||
Tuple of (module_root_path, detail_string).
|
||||
|
||||
"""
|
||||
if language in ("javascript", "typescript"):
|
||||
return _detect_js_module_root(project_root)
|
||||
return _detect_python_module_root(project_root)
|
||||
|
||||
|
||||
def _detect_python_module_root(project_root: Path) -> tuple[Path, str]:
|
||||
"""Detect Python module root.
|
||||
|
||||
Priority:
|
||||
1. pyproject.toml [tool.poetry.name] or [project.name]
|
||||
2. src/ directory with __init__.py
|
||||
3. Directory with __init__.py matching project name
|
||||
4. src/ directory (even without __init__.py)
|
||||
5. Project root
|
||||
|
||||
"""
|
||||
# Try to get project name from pyproject.toml
|
||||
pyproject_path = project_root / "pyproject.toml"
|
||||
project_name = None
|
||||
|
||||
if pyproject_path.exists():
|
||||
try:
|
||||
with pyproject_path.open("rb") as f:
|
||||
data = tomlkit.parse(f.read())
|
||||
|
||||
# Try poetry name
|
||||
project_name = data.get("tool", {}).get("poetry", {}).get("name")
|
||||
# Try standard project name
|
||||
if not project_name:
|
||||
project_name = data.get("project", {}).get("name")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Check for src layout
|
||||
src_dir = project_root / "src"
|
||||
if src_dir.is_dir():
|
||||
# Check for package inside src
|
||||
if project_name:
|
||||
pkg_dir = src_dir / project_name
|
||||
if pkg_dir.is_dir() and (pkg_dir / "__init__.py").exists():
|
||||
return pkg_dir, f"src/{project_name}/ (from pyproject.toml name)"
|
||||
|
||||
# Check for any package in src
|
||||
for child in src_dir.iterdir():
|
||||
if child.is_dir() and (child / "__init__.py").exists():
|
||||
return child, f"src/{child.name}/ (first package in src)"
|
||||
|
||||
# Use src/ even without __init__.py
|
||||
return src_dir, "src/ directory"
|
||||
|
||||
# Check for package at project root
|
||||
if project_name:
|
||||
pkg_dir = project_root / project_name
|
||||
if pkg_dir.is_dir() and (pkg_dir / "__init__.py").exists():
|
||||
return pkg_dir, f"{project_name}/ (from pyproject.toml name)"
|
||||
|
||||
# Look for any directory with __init__.py at project root
|
||||
for child in project_root.iterdir():
|
||||
if (
|
||||
child.is_dir()
|
||||
and not child.name.startswith(".")
|
||||
and child.name not in ("tests", "test", "docs", "venv", ".venv", "env", "node_modules")
|
||||
):
|
||||
if (child / "__init__.py").exists():
|
||||
return child, f"{child.name}/ (has __init__.py)"
|
||||
|
||||
# Default to project root
|
||||
return project_root, "project root (no package structure detected)"
|
||||
|
||||
|
||||
def _detect_js_module_root(project_root: Path) -> tuple[Path, str]:
|
||||
"""Detect JavaScript/TypeScript module root.
|
||||
|
||||
Priority:
|
||||
1. package.json "exports" field
|
||||
2. package.json "module" field (ESM)
|
||||
3. package.json "main" field (CJS)
|
||||
4. src/ directory
|
||||
5. lib/ directory
|
||||
6. Project root
|
||||
|
||||
"""
|
||||
package_json_path = project_root / "package.json"
|
||||
package_data: dict[str, Any] = {}
|
||||
|
||||
if package_json_path.exists():
|
||||
try:
|
||||
with package_json_path.open(encoding="utf8") as f:
|
||||
package_data = json.load(f)
|
||||
except (json.JSONDecodeError, OSError):
|
||||
pass
|
||||
|
||||
# Check exports field (modern Node.js)
|
||||
exports = package_data.get("exports")
|
||||
if exports:
|
||||
entry_path = _extract_entry_path(exports)
|
||||
if entry_path:
|
||||
parent = Path(entry_path).parent
|
||||
if parent != Path() and parent.as_posix() != "." and (project_root / parent).is_dir():
|
||||
return project_root / parent, f'{parent.as_posix()}/ (from package.json "exports")'
|
||||
|
||||
# Check module field (ESM)
|
||||
module_field = package_data.get("module")
|
||||
if module_field and isinstance(module_field, str):
|
||||
parent = Path(module_field).parent
|
||||
if parent != Path() and parent.as_posix() != "." and (project_root / parent).is_dir():
|
||||
return project_root / parent, f'{parent.as_posix()}/ (from package.json "module")'
|
||||
|
||||
# Check main field (CJS)
|
||||
main_field = package_data.get("main")
|
||||
if main_field and isinstance(main_field, str):
|
||||
parent = Path(main_field).parent
|
||||
if parent != Path() and parent.as_posix() != "." and (project_root / parent).is_dir():
|
||||
return project_root / parent, f'{parent.as_posix()}/ (from package.json "main")'
|
||||
|
||||
# Check for common source directories
|
||||
for src_dir in ["src", "lib", "source"]:
|
||||
if (project_root / src_dir).is_dir():
|
||||
return project_root / src_dir, f"{src_dir}/ directory"
|
||||
|
||||
# Default to project root
|
||||
return project_root, "project root"
|
||||
|
||||
|
||||
def _extract_entry_path(exports: Any) -> str | None:
|
||||
"""Extract entry path from package.json exports field."""
|
||||
if isinstance(exports, str):
|
||||
return exports
|
||||
if isinstance(exports, dict):
|
||||
# Handle {"." : "./src/index.js"} or {".": {"import": "./src/index.js"}}
|
||||
main_export = exports.get(".") or exports.get("import") or exports.get("default")
|
||||
if isinstance(main_export, str):
|
||||
return main_export
|
||||
if isinstance(main_export, dict):
|
||||
return main_export.get("import") or main_export.get("default") or main_export.get("require")
|
||||
return None
|
||||
|
||||
|
||||
def _detect_tests_root(project_root: Path, language: str) -> tuple[Path | None, str]:
|
||||
"""Detect the tests directory.
|
||||
|
||||
Common patterns:
|
||||
- tests/ or test/
|
||||
- __tests__/ (JavaScript)
|
||||
- spec/ (Ruby/JavaScript)
|
||||
|
||||
"""
|
||||
# Common test directory names
|
||||
test_dirs = ["tests", "test", "__tests__", "spec"]
|
||||
|
||||
for test_dir in test_dirs:
|
||||
test_path = project_root / test_dir
|
||||
if test_path.is_dir():
|
||||
return test_path, f"{test_dir}/ directory"
|
||||
|
||||
# For Python, check if tests are alongside source
|
||||
if language == "python":
|
||||
# Look for test_*.py files in project root
|
||||
test_files = list(project_root.glob("test_*.py"))
|
||||
if test_files:
|
||||
return project_root, "test files in project root"
|
||||
|
||||
# For JS/TS, check for *.test.js or *.spec.js files
|
||||
if language in ("javascript", "typescript"):
|
||||
test_patterns = ["*.test.js", "*.test.ts", "*.spec.js", "*.spec.ts"]
|
||||
for pattern in test_patterns:
|
||||
test_files = list(project_root.rglob(pattern))
|
||||
if test_files:
|
||||
# Find common parent
|
||||
return project_root, f"found {pattern} files"
|
||||
|
||||
return None, "not detected"
|
||||
|
||||
|
||||
def _detect_test_runner(project_root: Path, language: str) -> tuple[str, str]:
|
||||
"""Detect the test runner.
|
||||
|
||||
Python: pytest > unittest
|
||||
JavaScript: vitest > jest > mocha
|
||||
|
||||
"""
|
||||
if language in ("javascript", "typescript"):
|
||||
return _detect_js_test_runner(project_root)
|
||||
return _detect_python_test_runner(project_root)
|
||||
|
||||
|
||||
def _detect_python_test_runner(project_root: Path) -> tuple[str, str]:
|
||||
"""Detect Python test runner."""
|
||||
# Check for pytest markers
|
||||
pytest_markers = ["pytest.ini", "pyproject.toml", "conftest.py", "setup.cfg"]
|
||||
for marker in pytest_markers:
|
||||
marker_path = project_root / marker
|
||||
if marker_path.exists():
|
||||
if marker == "pyproject.toml":
|
||||
# Check for [tool.pytest] section
|
||||
try:
|
||||
with marker_path.open("rb") as f:
|
||||
data = tomlkit.parse(f.read())
|
||||
if "tool" in data and "pytest" in data["tool"]:
|
||||
return "pytest", "pyproject.toml [tool.pytest]"
|
||||
except Exception:
|
||||
pass
|
||||
elif marker == "conftest.py":
|
||||
return "pytest", "conftest.py found"
|
||||
elif marker in ("pytest.ini", "setup.cfg"):
|
||||
# Check for pytest section in setup.cfg
|
||||
if marker == "setup.cfg":
|
||||
try:
|
||||
content = marker_path.read_text(encoding="utf8")
|
||||
if "[tool:pytest]" in content or "[pytest]" in content:
|
||||
return "pytest", "setup.cfg [pytest]"
|
||||
except Exception:
|
||||
pass
|
||||
else:
|
||||
return "pytest", "pytest.ini found"
|
||||
|
||||
# Default to pytest (most common)
|
||||
return "pytest", "default"
|
||||
|
||||
|
||||
def _detect_js_test_runner(project_root: Path) -> tuple[str, str]:
|
||||
"""Detect JavaScript test runner."""
|
||||
package_json_path = project_root / "package.json"
|
||||
|
||||
if not package_json_path.exists():
|
||||
return "jest", "default (no package.json)"
|
||||
|
||||
try:
|
||||
with package_json_path.open(encoding="utf8") as f:
|
||||
package_data = json.load(f)
|
||||
except (json.JSONDecodeError, OSError):
|
||||
return "jest", "default (invalid package.json)"
|
||||
|
||||
runners = ["vitest", "jest", "mocha"]
|
||||
dev_deps = package_data.get("devDependencies", {})
|
||||
deps = package_data.get("dependencies", {})
|
||||
all_deps = {**deps, **dev_deps}
|
||||
|
||||
# Check dependencies (order matters - prefer more modern runners)
|
||||
for runner in runners:
|
||||
if runner in all_deps:
|
||||
return runner, "from devDependencies"
|
||||
|
||||
# Parse scripts.test for hints
|
||||
scripts = package_data.get("scripts", {})
|
||||
test_script = scripts.get("test", "")
|
||||
if isinstance(test_script, str):
|
||||
test_lower = test_script.lower()
|
||||
for runner in runners:
|
||||
if runner in test_lower:
|
||||
return runner, "from scripts.test"
|
||||
|
||||
# Check for config files
|
||||
config_files = {
|
||||
"vitest": ["vitest.config.js", "vitest.config.ts", "vitest.config.mjs"],
|
||||
"jest": ["jest.config.js", "jest.config.ts", "jest.config.mjs", "jest.config.json"],
|
||||
"mocha": [".mocharc.js", ".mocharc.json", ".mocharc.yaml"],
|
||||
}
|
||||
for runner, configs in config_files.items():
|
||||
for config in configs:
|
||||
if (project_root / config).exists():
|
||||
return runner, f"{config} found"
|
||||
|
||||
return "jest", "default"
|
||||
|
||||
|
||||
def _detect_formatter(project_root: Path, language: str) -> tuple[list[str], str]:
|
||||
"""Detect code formatter.
|
||||
|
||||
Python: ruff > black
|
||||
JavaScript: prettier > eslint --fix
|
||||
|
||||
"""
|
||||
if language in ("javascript", "typescript"):
|
||||
return _detect_js_formatter(project_root)
|
||||
return _detect_python_formatter(project_root)
|
||||
|
||||
|
||||
def _detect_python_formatter(project_root: Path) -> tuple[list[str], str]:
|
||||
"""Detect Python formatter."""
|
||||
pyproject_path = project_root / "pyproject.toml"
|
||||
|
||||
if pyproject_path.exists():
|
||||
try:
|
||||
with pyproject_path.open("rb") as f:
|
||||
data = tomlkit.parse(f.read())
|
||||
|
||||
tool = data.get("tool", {})
|
||||
|
||||
# Check for ruff
|
||||
if "ruff" in tool:
|
||||
return ["ruff check --exit-zero --fix $file", "ruff format $file"], "from pyproject.toml [tool.ruff]"
|
||||
|
||||
# Check for black
|
||||
if "black" in tool:
|
||||
return ["black $file"], "from pyproject.toml [tool.black]"
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Check for config files
|
||||
if (project_root / "ruff.toml").exists() or (project_root / ".ruff.toml").exists():
|
||||
return ["ruff check --exit-zero --fix $file", "ruff format $file"], "ruff.toml found"
|
||||
|
||||
if (project_root / ".black").exists() or (project_root / "pyproject.toml").exists():
|
||||
# Default to black if pyproject.toml exists (common setup)
|
||||
return ["black $file"], "default (black)"
|
||||
|
||||
return [], "none detected"
|
||||
|
||||
|
||||
def _detect_js_formatter(project_root: Path) -> tuple[list[str], str]:
|
||||
"""Detect JavaScript formatter."""
|
||||
package_json_path = project_root / "package.json"
|
||||
|
||||
# Check for prettier config files
|
||||
prettier_configs = [".prettierrc", ".prettierrc.js", ".prettierrc.json", "prettier.config.js"]
|
||||
for config in prettier_configs:
|
||||
if (project_root / config).exists():
|
||||
return ["npx prettier --write $file"], f"{config} found"
|
||||
|
||||
# Check for eslint config files
|
||||
eslint_configs = [".eslintrc", ".eslintrc.js", ".eslintrc.json", "eslint.config.js"]
|
||||
for config in eslint_configs:
|
||||
if (project_root / config).exists():
|
||||
return ["npx eslint --fix $file"], f"{config} found"
|
||||
|
||||
# Check package.json dependencies
|
||||
if package_json_path.exists():
|
||||
try:
|
||||
with package_json_path.open(encoding="utf8") as f:
|
||||
package_data = json.load(f)
|
||||
|
||||
dev_deps = package_data.get("devDependencies", {})
|
||||
deps = package_data.get("dependencies", {})
|
||||
all_deps = {**deps, **dev_deps}
|
||||
|
||||
if "prettier" in all_deps:
|
||||
return ["npx prettier --write $file"], "from devDependencies"
|
||||
if "eslint" in all_deps:
|
||||
return ["npx eslint --fix $file"], "from devDependencies"
|
||||
except (json.JSONDecodeError, OSError):
|
||||
pass
|
||||
|
||||
return [], "none detected"
|
||||
|
||||
|
||||
def _detect_ignore_paths(project_root: Path, language: str) -> tuple[list[Path], str]:
|
||||
"""Detect paths to ignore during optimization.
|
||||
|
||||
Sources:
|
||||
1. .gitignore
|
||||
2. Language-specific defaults
|
||||
|
||||
"""
|
||||
ignore_paths: list[Path] = []
|
||||
sources: list[str] = []
|
||||
|
||||
# Default ignore patterns by language
|
||||
default_ignores: dict[str, list[str]] = {
|
||||
"python": [
|
||||
"__pycache__",
|
||||
".pytest_cache",
|
||||
".mypy_cache",
|
||||
".ruff_cache",
|
||||
"venv",
|
||||
".venv",
|
||||
"env",
|
||||
".env",
|
||||
"dist",
|
||||
"build",
|
||||
"*.egg-info",
|
||||
".tox",
|
||||
".nox",
|
||||
"htmlcov",
|
||||
".coverage",
|
||||
],
|
||||
"javascript": ["node_modules", "dist", "build", ".next", ".nuxt", "coverage", ".cache"],
|
||||
"typescript": ["node_modules", "dist", "build", ".next", ".nuxt", "coverage", ".cache"],
|
||||
}
|
||||
|
||||
# Add default ignores
|
||||
for pattern in default_ignores.get(language, []):
|
||||
path = project_root / pattern.replace("*", "")
|
||||
if path.exists():
|
||||
ignore_paths.append(path)
|
||||
|
||||
if ignore_paths:
|
||||
sources.append("defaults")
|
||||
|
||||
# Parse .gitignore
|
||||
gitignore_path = project_root / ".gitignore"
|
||||
if gitignore_path.exists():
|
||||
try:
|
||||
content = gitignore_path.read_text(encoding="utf8")
|
||||
for line in content.splitlines():
|
||||
line = line.strip()
|
||||
# Skip comments and empty lines
|
||||
if not line or line.startswith("#"):
|
||||
continue
|
||||
# Skip negation patterns
|
||||
if line.startswith("!"):
|
||||
continue
|
||||
# Convert gitignore pattern to path
|
||||
pattern = line.rstrip("/").lstrip("/")
|
||||
# Skip complex patterns for now
|
||||
if "*" in pattern or "?" in pattern:
|
||||
continue
|
||||
path = project_root / pattern
|
||||
if path.exists() and path not in ignore_paths:
|
||||
ignore_paths.append(path)
|
||||
|
||||
if ".gitignore" not in sources:
|
||||
sources.append(".gitignore")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
detail = " + ".join(sources) if sources else "none"
|
||||
return ignore_paths, detail
|
||||
|
||||
|
||||
def has_existing_config(project_root: Path) -> tuple[bool, str | None]:
|
||||
"""Check if project has existing Codeflash configuration.
|
||||
|
||||
Args:
|
||||
project_root: Root directory of the project.
|
||||
|
||||
Returns:
|
||||
Tuple of (has_config, config_file_type).
|
||||
config_file_type is "pyproject.toml", "package.json", or None.
|
||||
|
||||
"""
|
||||
# Check pyproject.toml
|
||||
pyproject_path = project_root / "pyproject.toml"
|
||||
if pyproject_path.exists():
|
||||
try:
|
||||
with pyproject_path.open("rb") as f:
|
||||
data = tomlkit.parse(f.read())
|
||||
if "tool" in data and "codeflash" in data["tool"]:
|
||||
return True, "pyproject.toml"
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Check package.json
|
||||
package_json_path = project_root / "package.json"
|
||||
if package_json_path.exists():
|
||||
try:
|
||||
with package_json_path.open(encoding="utf8") as f:
|
||||
data = json.load(f)
|
||||
if "codeflash" in data:
|
||||
return True, "package.json"
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return False, None
|
||||
294
codeflash/setup/first_run.py
Normal file
294
codeflash/setup/first_run.py
Normal file
|
|
@ -0,0 +1,294 @@
|
|||
"""First-run experience for Codeflash.
|
||||
|
||||
This module handles the seamless first-run experience:
|
||||
1. Auto-detect project settings
|
||||
2. Display detected settings
|
||||
3. Quick confirmation
|
||||
4. API key setup
|
||||
5. Save config and continue
|
||||
|
||||
Usage:
|
||||
from codeflash.setup.first_run import handle_first_run, is_first_run
|
||||
|
||||
if is_first_run():
|
||||
args = handle_first_run(args)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import sys
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from rich.panel import Panel
|
||||
from rich.table import Table
|
||||
from rich.text import Text
|
||||
|
||||
from codeflash.cli_cmds.console import console
|
||||
from codeflash.setup.config_writer import write_config
|
||||
from codeflash.setup.detector import detect_project, has_existing_config
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from argparse import Namespace
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def is_first_run(project_root: Path | None = None) -> bool:
|
||||
"""Check if this is the first run (no config exists).
|
||||
|
||||
Args:
|
||||
project_root: Project root to check. Defaults to auto-detect.
|
||||
|
||||
Returns:
|
||||
True if no Codeflash config exists.
|
||||
|
||||
"""
|
||||
if project_root is None:
|
||||
try:
|
||||
detected = detect_project()
|
||||
project_root = detected.project_root
|
||||
except Exception:
|
||||
return True
|
||||
|
||||
has_config, _ = has_existing_config(project_root)
|
||||
return not has_config
|
||||
|
||||
|
||||
def handle_first_run(
|
||||
args: Namespace | None = None, skip_confirm: bool = False, skip_api_key: bool = False
|
||||
) -> Namespace | None:
|
||||
"""Handle the first-run experience with auto-detection and quick confirm.
|
||||
|
||||
This is the main entry point for the frictionless setup experience.
|
||||
|
||||
Args:
|
||||
args: Optional CLI args namespace to update.
|
||||
skip_confirm: Skip confirmation prompt (--yes flag).
|
||||
skip_api_key: Skip API key prompt.
|
||||
|
||||
Returns:
|
||||
Updated args namespace with detected settings, or None if user cancelled.
|
||||
|
||||
"""
|
||||
from argparse import Namespace
|
||||
|
||||
# Auto-detect project
|
||||
try:
|
||||
detected = detect_project()
|
||||
except Exception as e:
|
||||
_show_detection_error(str(e))
|
||||
return None
|
||||
|
||||
# Show welcome message
|
||||
_show_welcome()
|
||||
|
||||
# Show detected settings
|
||||
_show_detected_settings(detected)
|
||||
|
||||
# Get user confirmation
|
||||
if not skip_confirm:
|
||||
choice = _prompt_confirmation()
|
||||
if choice == "n":
|
||||
_show_cancelled()
|
||||
return None
|
||||
if choice == "customize":
|
||||
# TODO: Implement customize flow (redirect to codeflash init)
|
||||
console.print("\n💡 Run [cyan]codeflash init[/cyan] for full customization.\n")
|
||||
return None
|
||||
|
||||
# Handle API key
|
||||
if not skip_api_key:
|
||||
api_key_ok = _handle_api_key()
|
||||
if not api_key_ok:
|
||||
return None
|
||||
|
||||
# Save config
|
||||
success, message = write_config(detected)
|
||||
if success:
|
||||
console.print(f"\n✅ {message}\n")
|
||||
else:
|
||||
console.print(f"\n⚠️ {message}\n")
|
||||
console.print("Continuing with detected settings (not saved).\n")
|
||||
|
||||
# Create/update args namespace
|
||||
if args is None:
|
||||
args = Namespace()
|
||||
|
||||
# Populate args with detected values
|
||||
args.module_root = str(detected.module_root)
|
||||
args.tests_root = str(detected.tests_root) if detected.tests_root else None
|
||||
args.project_root = str(detected.project_root)
|
||||
args.formatter_cmds = detected.formatter_cmds
|
||||
args.ignore_paths = [str(p) for p in detected.ignore_paths]
|
||||
args.pytest_cmd = detected.test_runner
|
||||
args.language = detected.language
|
||||
|
||||
# Set defaults for other common args
|
||||
if not hasattr(args, "disable_telemetry"):
|
||||
args.disable_telemetry = False
|
||||
if not hasattr(args, "git_remote"):
|
||||
args.git_remote = "origin"
|
||||
|
||||
return args
|
||||
|
||||
|
||||
def _show_welcome() -> None:
|
||||
"""Show welcome message for first-time users."""
|
||||
welcome_panel = Panel(
|
||||
Text(
|
||||
"⚡ Welcome to Codeflash!\n\nI've auto-detected your project settings.\nThis will only take a moment.",
|
||||
style="bold cyan",
|
||||
justify="center",
|
||||
),
|
||||
title="🚀 First-Time Setup",
|
||||
border_style="bright_cyan",
|
||||
padding=(1, 2),
|
||||
)
|
||||
console.print(welcome_panel)
|
||||
console.print()
|
||||
|
||||
|
||||
def _show_detected_settings(detected: detect_project) -> None:
|
||||
"""Display detected settings in a nice table."""
|
||||
from codeflash.setup.detector import DetectedProject
|
||||
|
||||
if not isinstance(detected, DetectedProject):
|
||||
return
|
||||
|
||||
# Create settings table
|
||||
table = Table(show_header=False, box=None, padding=(0, 2))
|
||||
table.add_column("Setting", style="cyan", width=15)
|
||||
table.add_column("Value", style="green")
|
||||
table.add_column("Source", style="dim")
|
||||
|
||||
display_dict = detected.to_display_dict()
|
||||
details = detected.detection_details
|
||||
|
||||
for key, value in display_dict.items():
|
||||
source = details.get(key.lower().replace(" ", "_"), "")
|
||||
# Truncate long sources
|
||||
if len(source) > 30:
|
||||
source = source[:27] + "..."
|
||||
table.add_row(key, value, f"({source})" if source else "")
|
||||
|
||||
settings_panel = Panel(table, title="🔍 Auto-Detected Settings", border_style="bright_blue", padding=(1, 2))
|
||||
console.print(settings_panel)
|
||||
console.print()
|
||||
|
||||
|
||||
def _prompt_confirmation() -> str:
|
||||
"""Prompt user for confirmation.
|
||||
|
||||
Returns:
|
||||
"y" for yes, "n" for no, "customize" for customization.
|
||||
|
||||
"""
|
||||
# Check if we're in a non-interactive environment
|
||||
if not sys.stdin.isatty():
|
||||
console.print("⚠️ Non-interactive environment detected. Use --yes to skip confirmation.")
|
||||
return "n"
|
||||
|
||||
console.print("? [bold]Proceed with these settings?[/bold]")
|
||||
console.print(" [green]Y[/green] - Yes, save and continue")
|
||||
console.print(" [yellow]n[/yellow] - No, cancel")
|
||||
console.print(" [cyan]c[/cyan] - Customize (run full setup)")
|
||||
console.print()
|
||||
|
||||
try:
|
||||
choice = console.input("[bold]Your choice[/bold] [green][Y][/green]/n/c: ").strip().lower()
|
||||
except (KeyboardInterrupt, EOFError):
|
||||
return "n"
|
||||
|
||||
if choice in ("", "y", "yes"):
|
||||
return "y"
|
||||
if choice in ("c", "customize"):
|
||||
return "customize"
|
||||
return "n"
|
||||
|
||||
|
||||
def _handle_api_key() -> bool:
|
||||
"""Handle API key setup if not already configured.
|
||||
|
||||
Returns:
|
||||
True if API key is available, False if user cancelled.
|
||||
|
||||
"""
|
||||
from codeflash.code_utils.env_utils import get_codeflash_api_key
|
||||
|
||||
# Check for existing API key
|
||||
try:
|
||||
existing_key = get_codeflash_api_key()
|
||||
if existing_key:
|
||||
display_key = f"{existing_key[:3]}****{existing_key[-4:]}"
|
||||
console.print(f"✅ Found API key: [green]{display_key}[/green]\n")
|
||||
return True
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
# Prompt for API key
|
||||
console.print("🔑 [bold]API Key Required[/bold]")
|
||||
console.print(" Get your API key at: [cyan]https://app.codeflash.ai/app/apikeys[/cyan]\n")
|
||||
|
||||
try:
|
||||
api_key = console.input(" Enter API key (or press Enter to open browser): ").strip()
|
||||
except (KeyboardInterrupt, EOFError):
|
||||
return False
|
||||
|
||||
if not api_key:
|
||||
# Open browser
|
||||
import click
|
||||
|
||||
click.launch("https://app.codeflash.ai/app/apikeys")
|
||||
console.print("\n Opening browser...")
|
||||
try:
|
||||
api_key = console.input(" Enter API key: ").strip()
|
||||
except (KeyboardInterrupt, EOFError):
|
||||
return False
|
||||
|
||||
if not api_key:
|
||||
console.print("\n⚠️ API key required. Run [cyan]codeflash init[/cyan] to set up.\n")
|
||||
return False
|
||||
|
||||
if not api_key.startswith("cf-"):
|
||||
console.print("\n⚠️ Invalid API key format. Should start with 'cf-'.\n")
|
||||
return False
|
||||
|
||||
# Save API key to environment
|
||||
os.environ["CODEFLASH_API_KEY"] = api_key
|
||||
|
||||
# Try to save to shell rc
|
||||
try:
|
||||
from codeflash.code_utils.shell_utils import save_api_key_to_rc
|
||||
from codeflash.either import is_successful
|
||||
|
||||
result = save_api_key_to_rc(api_key)
|
||||
if is_successful(result):
|
||||
console.print(f"\n✅ API key saved. {result.unwrap()}\n")
|
||||
else:
|
||||
console.print(f"\n⚠️ Could not save to shell: {result.failure()}")
|
||||
console.print(" API key set for this session only.\n")
|
||||
except Exception:
|
||||
console.print("\n✅ API key set for this session.\n")
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def _show_detection_error(error: str) -> None:
|
||||
"""Show error message when detection fails."""
|
||||
error_panel = Panel(
|
||||
Text(
|
||||
f"❌ Could not auto-detect project settings.\n\n"
|
||||
f"Error: {error}\n\n"
|
||||
"Please run [cyan]codeflash init[/cyan] for manual setup.",
|
||||
style="red",
|
||||
),
|
||||
title="⚠️ Detection Failed",
|
||||
border_style="red",
|
||||
padding=(1, 2),
|
||||
)
|
||||
console.print(error_panel)
|
||||
|
||||
|
||||
def _show_cancelled() -> None:
|
||||
"""Show cancellation message."""
|
||||
console.print("\n⏹️ Setup cancelled. Run [cyan]codeflash init[/cyan] when ready.\n")
|
||||
|
|
@ -672,7 +672,7 @@ def parse_jest_test_xml(
|
|||
test_results = TestResults()
|
||||
|
||||
if not test_xml_file_path.exists():
|
||||
logger.warning(f"No Jest test results for {test_xml_file_path} found.")
|
||||
logger.warning(f"No JavaScript test results for {test_xml_file_path} found.")
|
||||
return test_results
|
||||
|
||||
# Log file size for debugging
|
||||
|
|
@ -1486,6 +1486,9 @@ def parse_test_results(
|
|||
get_run_tmp_file(Path("unittest_results.xml")).unlink(missing_ok=True)
|
||||
get_run_tmp_file(Path("jest_results.xml")).unlink(missing_ok=True)
|
||||
get_run_tmp_file(Path("jest_perf_results.xml")).unlink(missing_ok=True)
|
||||
get_run_tmp_file(Path("vitest_results.xml")).unlink(missing_ok=True)
|
||||
get_run_tmp_file(Path("vitest_perf_results.xml")).unlink(missing_ok=True)
|
||||
get_run_tmp_file(Path("vitest_line_profile_results.xml")).unlink(missing_ok=True)
|
||||
|
||||
# For Jest tests, SQLite cleanup is deferred until after comparison
|
||||
# (comparison happens via language_support.compare_test_results)
|
||||
|
|
|
|||
|
|
@ -72,29 +72,23 @@ def generate_tests(
|
|||
from codeflash.languages.javascript.module_system import ensure_module_system_compatibility
|
||||
|
||||
source_file = Path(function_to_optimize.file_path)
|
||||
func_name = function_to_optimize.function_name
|
||||
qualified_name = function_to_optimize.qualified_name
|
||||
|
||||
# First validate and fix import styles
|
||||
generated_test_source = validate_and_fix_import_style(generated_test_source, source_file, func_name)
|
||||
# Validate and fix import styles (default vs named exports)
|
||||
generated_test_source = validate_and_fix_import_style(
|
||||
generated_test_source, source_file, function_to_optimize.function_name
|
||||
)
|
||||
|
||||
# Convert module system if needed (e.g., CommonJS -> ESM for ESM projects)
|
||||
generated_test_source = ensure_module_system_compatibility(generated_test_source, project_module_system)
|
||||
|
||||
# Instrument for behavior verification (writes to SQLite)
|
||||
instrumented_behavior_test_source = instrument_generated_js_test(
|
||||
test_code=generated_test_source,
|
||||
function_name=func_name,
|
||||
qualified_name=qualified_name,
|
||||
mode=TestingMode.BEHAVIOR,
|
||||
test_code=generated_test_source, function_to_optimize=function_to_optimize, mode=TestingMode.BEHAVIOR
|
||||
)
|
||||
|
||||
# Instrument for performance measurement (prints to stdout)
|
||||
instrumented_perf_test_source = instrument_generated_js_test(
|
||||
test_code=generated_test_source,
|
||||
function_name=func_name,
|
||||
qualified_name=qualified_name,
|
||||
mode=TestingMode.PERFORMANCE,
|
||||
test_code=generated_test_source, function_to_optimize=function_to_optimize, mode=TestingMode.PERFORMANCE
|
||||
)
|
||||
|
||||
logger.debug(f"Instrumented JS/TS tests locally for {func_name}")
|
||||
|
|
|
|||
505
packages/codeflash/package-lock.json
generated
505
packages/codeflash/package-lock.json
generated
File diff suppressed because it is too large
Load diff
|
|
@ -1,6 +1,6 @@
|
|||
{
|
||||
"name": "codeflash",
|
||||
"version": "0.3.0",
|
||||
"version": "0.5.0",
|
||||
"description": "Codeflash - AI-powered code optimization for JavaScript and TypeScript",
|
||||
"main": "runtime/index.js",
|
||||
"types": "runtime/index.d.ts",
|
||||
|
|
@ -9,7 +9,7 @@
|
|||
"codeflash-setup": "./bin/codeflash-setup.js"
|
||||
},
|
||||
"publishConfig": {
|
||||
"access": "public"
|
||||
"access": "public"
|
||||
},
|
||||
"exports": {
|
||||
".": {
|
||||
|
|
@ -52,7 +52,8 @@
|
|||
"typescript",
|
||||
"ai",
|
||||
"cli",
|
||||
"jest"
|
||||
"jest",
|
||||
"vitest"
|
||||
],
|
||||
"author": "Codeflash AI",
|
||||
"license": "MIT",
|
||||
|
|
@ -70,7 +71,8 @@
|
|||
},
|
||||
"peerDependencies": {
|
||||
"jest": ">=27.0.0",
|
||||
"jest-runner": ">=27.0.0"
|
||||
"jest-runner": ">=27.0.0",
|
||||
"vitest": ">=1.0.0"
|
||||
},
|
||||
"peerDependenciesMeta": {
|
||||
"jest": {
|
||||
|
|
@ -78,12 +80,13 @@
|
|||
},
|
||||
"jest-runner": {
|
||||
"optional": true
|
||||
},
|
||||
"vitest": {
|
||||
"optional": true
|
||||
}
|
||||
},
|
||||
"dependencies": {
|
||||
"better-sqlite3": "^12.0.0",
|
||||
"@msgpack/msgpack": "^3.0.0",
|
||||
"jest-runner": "^29.7.0",
|
||||
"jest-junit": "^16.0.0"
|
||||
"@msgpack/msgpack": "^3.0.0"
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -646,9 +646,16 @@ function capturePerf(funcName, lineId, fn, ...args) {
|
|||
let lastReturnValue;
|
||||
let lastError = null;
|
||||
|
||||
// Batched looping: run BATCH_SIZE loops per capturePerf call
|
||||
// This ensures fair distribution across all test invocations
|
||||
const batchSize = shouldLoop ? PERF_BATCH_SIZE : 1;
|
||||
// Determine if we're running with external loop-runner (Jest) or internal looping (Vitest)
|
||||
// loop-runner sets CODEFLASH_PERF_CURRENT_BATCH before each batch
|
||||
// If not set, we're in Vitest mode and need to do all loops internally
|
||||
const hasExternalLoopRunner = process.env.CODEFLASH_PERF_CURRENT_BATCH !== undefined;
|
||||
|
||||
// Batched looping: run BATCH_SIZE loops per capturePerf call when using loop-runner
|
||||
// For Vitest (no loop-runner), do all loops internally in a single call
|
||||
const batchSize = shouldLoop
|
||||
? (hasExternalLoopRunner ? PERF_BATCH_SIZE : PERF_LOOP_COUNT)
|
||||
: 1;
|
||||
|
||||
for (let batchIndex = 0; batchIndex < batchSize; batchIndex++) {
|
||||
// Check shared time limit BEFORE each iteration
|
||||
|
|
|
|||
|
|
@ -20,6 +20,10 @@
|
|||
*
|
||||
* Usage:
|
||||
* npx jest --runner=codeflash/loop-runner
|
||||
*
|
||||
* NOTE: This runner requires jest-runner to be installed in your project.
|
||||
* It is a Jest-specific feature and does not work with Vitest.
|
||||
* For Vitest projects, capturePerf() does all loops internally in a single call.
|
||||
*/
|
||||
|
||||
'use strict';
|
||||
|
|
@ -27,12 +31,27 @@
|
|||
const { createRequire } = require('module');
|
||||
const path = require('path');
|
||||
|
||||
const jestRunnerPath = require.resolve('jest-runner');
|
||||
const internalRequire = createRequire(jestRunnerPath);
|
||||
const runTest = internalRequire('./runTest').default;
|
||||
// Try to load jest-runner - it's a peer dependency that must be installed by the user
|
||||
let runTest;
|
||||
let jestRunnerAvailable = false;
|
||||
|
||||
try {
|
||||
const jestRunnerPath = require.resolve('jest-runner');
|
||||
const internalRequire = createRequire(jestRunnerPath);
|
||||
runTest = internalRequire('./runTest').default;
|
||||
jestRunnerAvailable = true;
|
||||
} catch (e) {
|
||||
// jest-runner not installed - this is expected for Vitest projects
|
||||
// The runner will throw a helpful error if someone tries to use it without jest-runner
|
||||
jestRunnerAvailable = false;
|
||||
}
|
||||
|
||||
// Configuration
|
||||
const MAX_BATCHES = parseInt(process.env.CODEFLASH_PERF_LOOP_COUNT || '10000', 10);
|
||||
const PERF_LOOP_COUNT = parseInt(process.env.CODEFLASH_PERF_LOOP_COUNT || '10000', 10);
|
||||
const PERF_BATCH_SIZE = parseInt(process.env.CODEFLASH_PERF_BATCH_SIZE || '10', 10);
|
||||
// MAX_BATCHES = how many batches needed to reach PERF_LOOP_COUNT iterations
|
||||
// Add 1 to handle any rounding, but cap at PERF_LOOP_COUNT to avoid excessive batches
|
||||
const MAX_BATCHES = Math.min(Math.ceil(PERF_LOOP_COUNT / PERF_BATCH_SIZE) + 1, PERF_LOOP_COUNT);
|
||||
const TARGET_DURATION_MS = parseInt(process.env.CODEFLASH_PERF_TARGET_DURATION_MS || '10000', 10);
|
||||
const MIN_BATCHES = parseInt(process.env.CODEFLASH_PERF_MIN_LOOPS || '5', 10);
|
||||
|
||||
|
|
@ -90,6 +109,14 @@ function deepCopy(obj, seen = new WeakMap()) {
|
|||
*/
|
||||
class CodeflashLoopRunner {
|
||||
constructor(globalConfig, context) {
|
||||
if (!jestRunnerAvailable) {
|
||||
throw new Error(
|
||||
'codeflash/loop-runner requires jest-runner to be installed.\n' +
|
||||
'Please install it: npm install --save-dev jest-runner\n\n' +
|
||||
'If you are using Vitest, the loop-runner is not needed - ' +
|
||||
'Vitest projects use external looping handled by the Python runner.'
|
||||
);
|
||||
}
|
||||
this._globalConfig = globalConfig;
|
||||
this._context = context || {};
|
||||
this._eventEmitter = new SimpleEventEmitter();
|
||||
|
|
|
|||
|
|
@ -115,7 +115,7 @@ function installCodeflash(uvBin) {
|
|||
try {
|
||||
// Use uv tool install to install codeflash in an isolated environment
|
||||
// This avoids conflicts with any existing Python environments
|
||||
execSync(`"${uvBin}" tool install codeflash --force`, {
|
||||
execSync(`"${uvBin}" tool install --force --python python3.12 codeflash`, {
|
||||
stdio: 'inherit',
|
||||
shell: true,
|
||||
});
|
||||
|
|
|
|||
|
|
@ -488,7 +488,7 @@ class TestParsePackageJsonConfig:
|
|||
assert result is not None
|
||||
config, path = result
|
||||
assert config["language"] == "javascript"
|
||||
assert config["test_runner"] == "jest"
|
||||
assert config["test_framework"] == "jest"
|
||||
assert config["pytest_cmd"] == "jest"
|
||||
assert path == package_json
|
||||
|
||||
|
|
@ -728,7 +728,7 @@ class TestRealWorldPackageJsonExamples:
|
|||
config, _ = result
|
||||
assert config["language"] == "typescript"
|
||||
assert config["module_root"] == str((tmp_path / "src").resolve())
|
||||
assert config["test_runner"] == "jest"
|
||||
assert config["test_framework"] == "jest"
|
||||
assert config["formatter_cmds"] == ["npx prettier --write $file"]
|
||||
|
||||
def test_vite_react_project(self, tmp_path: Path) -> None:
|
||||
|
|
@ -752,7 +752,7 @@ class TestRealWorldPackageJsonExamples:
|
|||
assert result is not None
|
||||
config, _ = result
|
||||
assert config["language"] == "typescript"
|
||||
assert config["test_runner"] == "vitest"
|
||||
assert config["test_framework"] == "vitest"
|
||||
assert config["formatter_cmds"] == ["npx eslint --fix $file"]
|
||||
|
||||
def test_library_with_exports(self, tmp_path: Path) -> None:
|
||||
|
|
@ -812,7 +812,7 @@ class TestRealWorldPackageJsonExamples:
|
|||
assert result is not None
|
||||
config, _ = result
|
||||
assert config["module_root"] == str((tmp_path / "lib").resolve())
|
||||
assert config["test_runner"] == "mocha"
|
||||
assert config["test_framework"] == "mocha"
|
||||
|
||||
def test_minimal_project(self, tmp_path: Path) -> None:
|
||||
"""Should handle minimal package.json."""
|
||||
|
|
@ -825,7 +825,7 @@ class TestRealWorldPackageJsonExamples:
|
|||
config, _ = result
|
||||
assert config["language"] == "javascript"
|
||||
assert config["module_root"] == str(tmp_path.resolve())
|
||||
assert config["test_runner"] == "jest"
|
||||
assert config["test_framework"] == "jest"
|
||||
assert config["formatter_cmds"] == []
|
||||
|
||||
def test_existing_codeflash_config_with_overrides(self, tmp_path: Path) -> None:
|
||||
|
|
@ -855,3 +855,143 @@ class TestRealWorldPackageJsonExamples:
|
|||
assert config["formatter_cmds"] == ["npx prettier --write --single-quote $file"]
|
||||
assert len(config["ignore_paths"]) == 2
|
||||
assert config["disable_telemetry"] is True
|
||||
|
||||
|
||||
class TestTestFrameworkConfigOverride:
|
||||
"""Tests for explicit test-framework config override (matches Python's pyproject.toml)."""
|
||||
|
||||
def test_test_framework_overrides_auto_detection(self, tmp_path: Path) -> None:
|
||||
"""Should use test-framework from codeflash config instead of auto-detecting from devDependencies."""
|
||||
package_json = tmp_path / "package.json"
|
||||
package_json.write_text(
|
||||
json.dumps(
|
||||
{
|
||||
"name": "test-project",
|
||||
"devDependencies": {"vitest": "^1.0.0"},
|
||||
"codeflash": {"test-framework": "jest"},
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
result = parse_package_json_config(package_json)
|
||||
|
||||
assert result is not None
|
||||
config, _ = result
|
||||
assert config["test_framework"] == "jest"
|
||||
assert config["pytest_cmd"] == "jest"
|
||||
|
||||
def test_explicit_vitest_config_with_jest_in_deps(self, tmp_path: Path) -> None:
|
||||
"""Should use explicit vitest config even when jest is in devDependencies."""
|
||||
package_json = tmp_path / "package.json"
|
||||
package_json.write_text(
|
||||
json.dumps(
|
||||
{
|
||||
"name": "test-project",
|
||||
"devDependencies": {"jest": "^29.0.0", "vitest": "^1.0.0"},
|
||||
"codeflash": {"test-framework": "vitest"},
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
result = parse_package_json_config(package_json)
|
||||
|
||||
assert result is not None
|
||||
config, _ = result
|
||||
assert config["test_framework"] == "vitest"
|
||||
|
||||
def test_explicit_mocha_overrides_vitest_and_jest(self, tmp_path: Path) -> None:
|
||||
"""Should use explicit mocha config even when vitest and jest are in devDependencies."""
|
||||
package_json = tmp_path / "package.json"
|
||||
package_json.write_text(
|
||||
json.dumps(
|
||||
{
|
||||
"name": "test-project",
|
||||
"devDependencies": {"vitest": "^1.0.0", "jest": "^29.0.0"},
|
||||
"codeflash": {"test-framework": "mocha"},
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
result = parse_package_json_config(package_json)
|
||||
|
||||
assert result is not None
|
||||
config, _ = result
|
||||
assert config["test_framework"] == "mocha"
|
||||
|
||||
def test_auto_detection_when_no_explicit_config(self, tmp_path: Path) -> None:
|
||||
"""Should auto-detect test framework when no explicit config is provided."""
|
||||
package_json = tmp_path / "package.json"
|
||||
package_json.write_text(
|
||||
json.dumps(
|
||||
{
|
||||
"name": "test-project",
|
||||
"devDependencies": {"vitest": "^1.0.0"},
|
||||
"codeflash": {"moduleRoot": "src"},
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
result = parse_package_json_config(package_json)
|
||||
|
||||
assert result is not None
|
||||
config, _ = result
|
||||
assert config["test_framework"] == "vitest"
|
||||
|
||||
def test_empty_test_framework_falls_back_to_auto_detection(self, tmp_path: Path) -> None:
|
||||
"""Should auto-detect when test-framework is empty string."""
|
||||
package_json = tmp_path / "package.json"
|
||||
package_json.write_text(
|
||||
json.dumps(
|
||||
{
|
||||
"name": "test-project",
|
||||
"devDependencies": {"jest": "^29.0.0"},
|
||||
"codeflash": {"test-framework": ""},
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
result = parse_package_json_config(package_json)
|
||||
|
||||
assert result is not None
|
||||
config, _ = result
|
||||
assert config["test_framework"] == "jest"
|
||||
|
||||
def test_custom_test_framework_value(self, tmp_path: Path) -> None:
|
||||
"""Should accept custom test framework values not in the standard list."""
|
||||
package_json = tmp_path / "package.json"
|
||||
package_json.write_text(
|
||||
json.dumps(
|
||||
{
|
||||
"name": "test-project",
|
||||
"devDependencies": {"vitest": "^1.0.0"},
|
||||
"codeflash": {"test-framework": "ava"},
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
result = parse_package_json_config(package_json)
|
||||
|
||||
assert result is not None
|
||||
config, _ = result
|
||||
assert config["test_framework"] == "ava"
|
||||
|
||||
def test_pytest_cmd_matches_test_framework_with_override(self, tmp_path: Path) -> None:
|
||||
"""Should set pytest_cmd to match test_framework when using explicit config."""
|
||||
package_json = tmp_path / "package.json"
|
||||
package_json.write_text(
|
||||
json.dumps(
|
||||
{
|
||||
"name": "test-project",
|
||||
"devDependencies": {"vitest": "^1.0.0"},
|
||||
"codeflash": {"test-framework": "jest"},
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
result = parse_package_json_config(package_json)
|
||||
|
||||
assert result is not None
|
||||
config, _ = result
|
||||
assert config["test_framework"] == "jest"
|
||||
assert config["pytest_cmd"] == "jest"
|
||||
assert config["test_framework"] == config["pytest_cmd"]
|
||||
|
|
|
|||
0
tests/languages/javascript/__init__.py
Normal file
0
tests/languages/javascript/__init__.py
Normal file
383
tests/languages/javascript/test_support_dispatch.py
Normal file
383
tests/languages/javascript/test_support_dispatch.py
Normal file
|
|
@ -0,0 +1,383 @@
|
|||
"""Tests for JavaScript/TypeScript support.py test framework dispatch logic.
|
||||
|
||||
These tests verify that run_behavioral_tests, run_benchmarking_tests, and
|
||||
run_line_profile_tests correctly dispatch to Jest or Vitest based on the
|
||||
test_framework parameter or singleton.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from codeflash.languages.javascript.support import JavaScriptSupport, TypeScriptSupport
|
||||
from codeflash.languages.test_framework import reset_test_framework, set_current_test_framework
|
||||
from codeflash.models.models import TestFile, TestFiles, TestType
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def js_support() -> JavaScriptSupport:
|
||||
"""Create a JavaScriptSupport instance."""
|
||||
return JavaScriptSupport()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def ts_support() -> TypeScriptSupport:
|
||||
"""Create a TypeScriptSupport instance."""
|
||||
return TypeScriptSupport()
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def reset_singleton():
|
||||
"""Reset the test framework singleton before each test."""
|
||||
reset_test_framework()
|
||||
yield
|
||||
reset_test_framework()
|
||||
|
||||
|
||||
def create_test_files(tmp_path: Path) -> TestFiles:
|
||||
"""Create a TestFiles object with real file paths."""
|
||||
test_file = tmp_path / "tests" / "test_func.test.ts"
|
||||
test_file.parent.mkdir(parents=True, exist_ok=True)
|
||||
test_file.write_text("// test file")
|
||||
|
||||
perf_file = tmp_path / "tests" / "test_func__perf.test.ts"
|
||||
perf_file.write_text("// perf test file")
|
||||
|
||||
return TestFiles(
|
||||
test_files=[
|
||||
TestFile(
|
||||
instrumented_behavior_file_path=test_file,
|
||||
benchmarking_file_path=perf_file,
|
||||
test_type=TestType.GENERATED_REGRESSION,
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
class TestBehavioralTestsDispatch:
|
||||
"""Tests for run_behavioral_tests dispatch logic."""
|
||||
|
||||
@patch("codeflash.languages.javascript.test_runner.run_jest_behavioral_tests")
|
||||
def test_dispatches_to_jest_by_default(self, mock_jest_runner: MagicMock, js_support: JavaScriptSupport) -> None:
|
||||
"""Should dispatch to Jest when test_framework is not specified."""
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
tmp_path = Path(tmp_dir)
|
||||
(tmp_path / "package.json").write_text('{"name": "test"}')
|
||||
test_files = create_test_files(tmp_path)
|
||||
|
||||
mock_jest_runner.return_value = (tmp_path / "result.xml", MagicMock(), None, None)
|
||||
|
||||
js_support.run_behavioral_tests(test_paths=test_files, test_env={}, cwd=tmp_path, project_root=tmp_path)
|
||||
|
||||
mock_jest_runner.assert_called_once()
|
||||
|
||||
@patch("codeflash.languages.javascript.test_runner.run_jest_behavioral_tests")
|
||||
def test_dispatches_to_jest_explicitly(self, mock_jest_runner: MagicMock, js_support: JavaScriptSupport) -> None:
|
||||
"""Should dispatch to Jest when test_framework='jest'."""
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
tmp_path = Path(tmp_dir)
|
||||
(tmp_path / "package.json").write_text('{"name": "test"}')
|
||||
test_files = create_test_files(tmp_path)
|
||||
|
||||
mock_jest_runner.return_value = (tmp_path / "result.xml", MagicMock(), None, None)
|
||||
|
||||
js_support.run_behavioral_tests(
|
||||
test_paths=test_files, test_env={}, cwd=tmp_path, project_root=tmp_path, test_framework="jest"
|
||||
)
|
||||
|
||||
mock_jest_runner.assert_called_once()
|
||||
|
||||
@patch("codeflash.languages.javascript.vitest_runner.run_vitest_behavioral_tests")
|
||||
def test_dispatches_to_vitest(self, mock_vitest_runner: MagicMock, js_support: JavaScriptSupport) -> None:
|
||||
"""Should dispatch to Vitest when test_framework='vitest'."""
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
tmp_path = Path(tmp_dir)
|
||||
(tmp_path / "package.json").write_text('{"name": "test"}')
|
||||
test_files = create_test_files(tmp_path)
|
||||
|
||||
mock_vitest_runner.return_value = (tmp_path / "result.xml", MagicMock(), None, None)
|
||||
|
||||
js_support.run_behavioral_tests(
|
||||
test_paths=test_files, test_env={}, cwd=tmp_path, project_root=tmp_path, test_framework="vitest"
|
||||
)
|
||||
|
||||
mock_vitest_runner.assert_called_once()
|
||||
|
||||
@patch("codeflash.languages.javascript.vitest_runner.run_vitest_behavioral_tests")
|
||||
def test_typescript_support_dispatches_to_vitest(
|
||||
self, mock_vitest_runner: MagicMock, ts_support: TypeScriptSupport
|
||||
) -> None:
|
||||
"""TypeScriptSupport should also dispatch to Vitest when test_framework='vitest'."""
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
tmp_path = Path(tmp_dir)
|
||||
(tmp_path / "package.json").write_text('{"name": "test"}')
|
||||
test_files = create_test_files(tmp_path)
|
||||
|
||||
mock_vitest_runner.return_value = (tmp_path / "result.xml", MagicMock(), None, None)
|
||||
|
||||
ts_support.run_behavioral_tests(
|
||||
test_paths=test_files, test_env={}, cwd=tmp_path, project_root=tmp_path, test_framework="vitest"
|
||||
)
|
||||
|
||||
mock_vitest_runner.assert_called_once()
|
||||
|
||||
|
||||
class TestBenchmarkingTestsDispatch:
|
||||
"""Tests for run_benchmarking_tests dispatch logic."""
|
||||
|
||||
@patch("codeflash.languages.javascript.test_runner.run_jest_benchmarking_tests")
|
||||
def test_dispatches_to_jest_by_default(self, mock_jest_runner: MagicMock, js_support: JavaScriptSupport) -> None:
|
||||
"""Should dispatch to Jest when test_framework is not specified."""
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
tmp_path = Path(tmp_dir)
|
||||
(tmp_path / "package.json").write_text('{"name": "test"}')
|
||||
test_files = create_test_files(tmp_path)
|
||||
|
||||
mock_jest_runner.return_value = (tmp_path / "result.xml", MagicMock())
|
||||
|
||||
js_support.run_benchmarking_tests(test_paths=test_files, test_env={}, cwd=tmp_path, project_root=tmp_path)
|
||||
|
||||
mock_jest_runner.assert_called_once()
|
||||
|
||||
@patch("codeflash.languages.javascript.vitest_runner.run_vitest_benchmarking_tests")
|
||||
def test_dispatches_to_vitest(self, mock_vitest_runner: MagicMock, js_support: JavaScriptSupport) -> None:
|
||||
"""Should dispatch to Vitest when test_framework='vitest'."""
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
tmp_path = Path(tmp_dir)
|
||||
(tmp_path / "package.json").write_text('{"name": "test"}')
|
||||
test_files = create_test_files(tmp_path)
|
||||
|
||||
mock_vitest_runner.return_value = (tmp_path / "result.xml", MagicMock())
|
||||
|
||||
js_support.run_benchmarking_tests(
|
||||
test_paths=test_files, test_env={}, cwd=tmp_path, project_root=tmp_path, test_framework="vitest"
|
||||
)
|
||||
|
||||
mock_vitest_runner.assert_called_once()
|
||||
|
||||
@patch("codeflash.languages.javascript.vitest_runner.run_vitest_benchmarking_tests")
|
||||
def test_passes_loop_parameters(self, mock_vitest_runner: MagicMock, js_support: JavaScriptSupport) -> None:
|
||||
"""Should pass loop parameters to Vitest runner."""
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
tmp_path = Path(tmp_dir)
|
||||
(tmp_path / "package.json").write_text('{"name": "test"}')
|
||||
test_files = create_test_files(tmp_path)
|
||||
|
||||
mock_vitest_runner.return_value = (tmp_path / "result.xml", MagicMock())
|
||||
|
||||
js_support.run_benchmarking_tests(
|
||||
test_paths=test_files,
|
||||
test_env={},
|
||||
cwd=tmp_path,
|
||||
project_root=tmp_path,
|
||||
test_framework="vitest",
|
||||
min_loops=10,
|
||||
max_loops=50,
|
||||
target_duration_seconds=5.0,
|
||||
)
|
||||
|
||||
call_kwargs = mock_vitest_runner.call_args.kwargs
|
||||
assert call_kwargs["min_loops"] == 10
|
||||
assert call_kwargs["max_loops"] == 50
|
||||
assert call_kwargs["target_duration_ms"] == 5000
|
||||
|
||||
|
||||
class TestLineProfileTestsDispatch:
|
||||
"""Tests for run_line_profile_tests dispatch logic."""
|
||||
|
||||
@patch("codeflash.languages.javascript.test_runner.run_jest_line_profile_tests")
|
||||
def test_dispatches_to_jest_by_default(self, mock_jest_runner: MagicMock, js_support: JavaScriptSupport) -> None:
|
||||
"""Should dispatch to Jest when test_framework is not specified."""
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
tmp_path = Path(tmp_dir)
|
||||
(tmp_path / "package.json").write_text('{"name": "test"}')
|
||||
test_files = create_test_files(tmp_path)
|
||||
|
||||
mock_jest_runner.return_value = (tmp_path / "result.xml", MagicMock())
|
||||
|
||||
js_support.run_line_profile_tests(test_paths=test_files, test_env={}, cwd=tmp_path, project_root=tmp_path)
|
||||
|
||||
mock_jest_runner.assert_called_once()
|
||||
|
||||
@patch("codeflash.languages.javascript.vitest_runner.run_vitest_line_profile_tests")
|
||||
def test_dispatches_to_vitest(self, mock_vitest_runner: MagicMock, js_support: JavaScriptSupport) -> None:
|
||||
"""Should dispatch to Vitest when test_framework='vitest'."""
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
tmp_path = Path(tmp_dir)
|
||||
(tmp_path / "package.json").write_text('{"name": "test"}')
|
||||
test_files = create_test_files(tmp_path)
|
||||
|
||||
mock_vitest_runner.return_value = (tmp_path / "result.xml", MagicMock())
|
||||
|
||||
js_support.run_line_profile_tests(
|
||||
test_paths=test_files, test_env={}, cwd=tmp_path, project_root=tmp_path, test_framework="vitest"
|
||||
)
|
||||
|
||||
mock_vitest_runner.assert_called_once()
|
||||
|
||||
@patch("codeflash.languages.javascript.vitest_runner.run_vitest_line_profile_tests")
|
||||
def test_passes_line_profile_output_file(
|
||||
self, mock_vitest_runner: MagicMock, js_support: JavaScriptSupport
|
||||
) -> None:
|
||||
"""Should pass line_profile_output_file to Vitest runner."""
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
tmp_path = Path(tmp_dir)
|
||||
(tmp_path / "package.json").write_text('{"name": "test"}')
|
||||
test_files = create_test_files(tmp_path)
|
||||
output_file = tmp_path / "line_profile.json"
|
||||
|
||||
mock_vitest_runner.return_value = (tmp_path / "result.xml", MagicMock())
|
||||
|
||||
js_support.run_line_profile_tests(
|
||||
test_paths=test_files,
|
||||
test_env={},
|
||||
cwd=tmp_path,
|
||||
project_root=tmp_path,
|
||||
test_framework="vitest",
|
||||
line_profile_output_file=output_file,
|
||||
)
|
||||
|
||||
call_kwargs = mock_vitest_runner.call_args.kwargs
|
||||
assert call_kwargs["line_profile_output_file"] == output_file
|
||||
|
||||
|
||||
class TestTestFrameworkProperty:
|
||||
"""Tests for test_framework property."""
|
||||
|
||||
def test_javascript_default_framework_is_jest(self, js_support: JavaScriptSupport) -> None:
|
||||
"""JavaScriptSupport should have Jest as default test framework."""
|
||||
assert js_support.test_framework == "jest"
|
||||
|
||||
def test_typescript_default_framework_is_jest(self, ts_support: TypeScriptSupport) -> None:
|
||||
"""TypeScriptSupport should have Jest as default test framework."""
|
||||
assert ts_support.test_framework == "jest"
|
||||
|
||||
|
||||
class TestTestFrameworkSingleton:
|
||||
"""Tests for test_framework singleton behavior."""
|
||||
|
||||
@patch("codeflash.languages.javascript.vitest_runner.run_vitest_behavioral_tests")
|
||||
def test_uses_singleton_when_param_not_provided(
|
||||
self, mock_vitest_runner: MagicMock, js_support: JavaScriptSupport
|
||||
) -> None:
|
||||
"""Should use singleton test_framework when parameter is not provided."""
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
tmp_path = Path(tmp_dir)
|
||||
(tmp_path / "package.json").write_text('{"name": "test"}')
|
||||
test_files = create_test_files(tmp_path)
|
||||
|
||||
mock_vitest_runner.return_value = (tmp_path / "result.xml", MagicMock(), None, None)
|
||||
|
||||
set_current_test_framework("vitest")
|
||||
|
||||
js_support.run_behavioral_tests(test_paths=test_files, test_env={}, cwd=tmp_path, project_root=tmp_path)
|
||||
|
||||
mock_vitest_runner.assert_called_once()
|
||||
|
||||
@patch("codeflash.languages.javascript.test_runner.run_jest_behavioral_tests")
|
||||
def test_explicit_param_overrides_singleton(
|
||||
self, mock_jest_runner: MagicMock, js_support: JavaScriptSupport
|
||||
) -> None:
|
||||
"""Explicit test_framework parameter should override singleton."""
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
tmp_path = Path(tmp_dir)
|
||||
(tmp_path / "package.json").write_text('{"name": "test"}')
|
||||
test_files = create_test_files(tmp_path)
|
||||
|
||||
mock_jest_runner.return_value = (tmp_path / "result.xml", MagicMock(), None, None)
|
||||
|
||||
set_current_test_framework("vitest")
|
||||
|
||||
js_support.run_behavioral_tests(
|
||||
test_paths=test_files, test_env={}, cwd=tmp_path, project_root=tmp_path, test_framework="jest"
|
||||
)
|
||||
|
||||
mock_jest_runner.assert_called_once()
|
||||
|
||||
@patch("codeflash.languages.javascript.vitest_runner.run_vitest_benchmarking_tests")
|
||||
def test_benchmarking_uses_singleton(self, mock_vitest_runner: MagicMock, js_support: JavaScriptSupport) -> None:
|
||||
"""run_benchmarking_tests should use singleton when param not provided."""
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
tmp_path = Path(tmp_dir)
|
||||
(tmp_path / "package.json").write_text('{"name": "test"}')
|
||||
test_files = create_test_files(tmp_path)
|
||||
|
||||
mock_vitest_runner.return_value = (tmp_path / "result.xml", MagicMock())
|
||||
|
||||
set_current_test_framework("vitest")
|
||||
|
||||
js_support.run_benchmarking_tests(test_paths=test_files, test_env={}, cwd=tmp_path, project_root=tmp_path)
|
||||
|
||||
mock_vitest_runner.assert_called_once()
|
||||
|
||||
@patch("codeflash.languages.javascript.vitest_runner.run_vitest_line_profile_tests")
|
||||
def test_line_profile_uses_singleton(self, mock_vitest_runner: MagicMock, js_support: JavaScriptSupport) -> None:
|
||||
"""run_line_profile_tests should use singleton when param not provided."""
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
tmp_path = Path(tmp_dir)
|
||||
(tmp_path / "package.json").write_text('{"name": "test"}')
|
||||
test_files = create_test_files(tmp_path)
|
||||
|
||||
mock_vitest_runner.return_value = (tmp_path / "result.xml", MagicMock())
|
||||
|
||||
set_current_test_framework("vitest")
|
||||
|
||||
js_support.run_line_profile_tests(test_paths=test_files, test_env={}, cwd=tmp_path, project_root=tmp_path)
|
||||
|
||||
mock_vitest_runner.assert_called_once()
|
||||
|
||||
|
||||
class TestTestFrameworkSingletonModule:
|
||||
"""Tests for the test_framework singleton module itself."""
|
||||
|
||||
def test_initial_state_is_none(self) -> None:
|
||||
"""Singleton should start as None."""
|
||||
from codeflash.languages.test_framework import current_test_framework
|
||||
|
||||
assert current_test_framework() is None
|
||||
|
||||
def test_set_and_get(self) -> None:
|
||||
"""Should be able to set and get test framework."""
|
||||
from codeflash.languages.test_framework import current_test_framework, set_current_test_framework
|
||||
|
||||
set_current_test_framework("vitest")
|
||||
assert current_test_framework() == "vitest"
|
||||
|
||||
def test_set_only_once(self) -> None:
|
||||
"""Once set, singleton should not change."""
|
||||
from codeflash.languages.test_framework import current_test_framework, set_current_test_framework
|
||||
|
||||
set_current_test_framework("jest")
|
||||
set_current_test_framework("vitest")
|
||||
assert current_test_framework() == "jest"
|
||||
|
||||
def test_is_jest(self) -> None:
|
||||
"""is_jest() should return True when framework is Jest."""
|
||||
from codeflash.languages.test_framework import is_jest, set_current_test_framework
|
||||
|
||||
set_current_test_framework("jest")
|
||||
assert is_jest() is True
|
||||
|
||||
def test_is_vitest(self) -> None:
|
||||
"""is_vitest() should return True when framework is Vitest."""
|
||||
from codeflash.languages.test_framework import is_vitest, set_current_test_framework
|
||||
|
||||
set_current_test_framework("vitest")
|
||||
assert is_vitest() is True
|
||||
|
||||
def test_get_js_test_framework_or_default_returns_jest(self) -> None:
|
||||
"""get_js_test_framework_or_default should return 'jest' when not set."""
|
||||
from codeflash.languages.test_framework import get_js_test_framework_or_default
|
||||
|
||||
assert get_js_test_framework_or_default() == "jest"
|
||||
|
||||
def test_get_js_test_framework_or_default_returns_vitest(self) -> None:
|
||||
"""get_js_test_framework_or_default should return 'vitest' when set."""
|
||||
from codeflash.languages.test_framework import get_js_test_framework_or_default, set_current_test_framework
|
||||
|
||||
set_current_test_framework("vitest")
|
||||
assert get_js_test_framework_or_default() == "vitest"
|
||||
247
tests/languages/javascript/test_vitest_junit.py
Normal file
247
tests/languages/javascript/test_vitest_junit.py
Normal file
|
|
@ -0,0 +1,247 @@
|
|||
"""Tests for Vitest JUnit XML output parsing and compatibility.
|
||||
|
||||
These tests verify that Vitest's JUnit XML output can be parsed
|
||||
by the existing parsing infrastructure.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
from junitparser import JUnitXml
|
||||
|
||||
from codeflash.verification.parse_test_output import jest_end_pattern, jest_start_pattern
|
||||
|
||||
|
||||
class TestVitestJunitXmlFormat:
|
||||
"""Tests for Vitest JUnit XML format compatibility."""
|
||||
|
||||
def test_can_parse_vitest_junit_xml(self) -> None:
|
||||
"""Should be able to parse Vitest JUnit XML with junitparser."""
|
||||
xml_content = """<?xml version="1.0" encoding="UTF-8" ?>
|
||||
<testsuites name="vitest tests" tests="4" failures="1" errors="0" time="0.537">
|
||||
<testsuite name="tests/fibonacci.test.ts" timestamp="2026-01-30T18:03:49.433Z" hostname="localhost" tests="3" failures="0" errors="0" skipped="0" time="0.008">
|
||||
<testcase classname="tests/fibonacci.test.ts" name="fibonacci > returns 0 for n=0" time="0.001">
|
||||
</testcase>
|
||||
<testcase classname="tests/fibonacci.test.ts" name="fibonacci > returns 1 for n=1" time="0.0005">
|
||||
</testcase>
|
||||
<testcase classname="tests/fibonacci.test.ts" name="fibonacci > returns 55 for n=10" time="0.0001">
|
||||
</testcase>
|
||||
</testsuite>
|
||||
<testsuite name="tests/string_utils.test.ts" timestamp="2026-01-30T18:03:49.438Z" hostname="localhost" tests="1" failures="1" errors="0" skipped="0" time="0.01">
|
||||
<testcase classname="tests/string_utils.test.ts" name="reverseString > reverses a simple string" time="0.0007">
|
||||
<failure message="expected 'olleh' to equal 'hello'" type="AssertionError">AssertionError: expected 'olleh' to equal 'hello'</failure>
|
||||
</testcase>
|
||||
</testsuite>
|
||||
</testsuites>"""
|
||||
with tempfile.NamedTemporaryFile(suffix=".xml", mode="w", delete=False) as f:
|
||||
f.write(xml_content)
|
||||
f.flush()
|
||||
junit_file = Path(f.name)
|
||||
|
||||
xml = JUnitXml.fromfile(str(junit_file))
|
||||
|
||||
assert xml is not None
|
||||
test_count = sum(len(list(suite)) for suite in xml)
|
||||
assert test_count == 4
|
||||
|
||||
def test_extracts_test_suite_names(self) -> None:
|
||||
"""Should extract test suite names from Vitest JUnit XML."""
|
||||
xml_content = """<?xml version="1.0" encoding="UTF-8" ?>
|
||||
<testsuites name="vitest tests" tests="2" failures="0" errors="0" time="0.1">
|
||||
<testsuite name="tests/fibonacci.test.ts" tests="1" failures="0" time="0.01">
|
||||
<testcase classname="tests/fibonacci.test.ts" name="test1" time="0.001"></testcase>
|
||||
</testsuite>
|
||||
<testsuite name="tests/string_utils.test.ts" tests="1" failures="0" time="0.01">
|
||||
<testcase classname="tests/string_utils.test.ts" name="test2" time="0.001"></testcase>
|
||||
</testsuite>
|
||||
</testsuites>"""
|
||||
with tempfile.NamedTemporaryFile(suffix=".xml", mode="w", delete=False) as f:
|
||||
f.write(xml_content)
|
||||
f.flush()
|
||||
junit_file = Path(f.name)
|
||||
|
||||
xml = JUnitXml.fromfile(str(junit_file))
|
||||
|
||||
suite_names = [suite.name for suite in xml]
|
||||
assert suite_names == ["tests/fibonacci.test.ts", "tests/string_utils.test.ts"]
|
||||
|
||||
def test_extracts_test_case_names_with_vitest_separator(self) -> None:
|
||||
"""Should extract test case names from Vitest JUnit XML (uses > as separator)."""
|
||||
xml_content = """<?xml version="1.0" encoding="UTF-8" ?>
|
||||
<testsuites name="vitest tests" tests="2" failures="0" errors="0" time="0.1">
|
||||
<testsuite name="tests/fibonacci.test.ts" tests="2" failures="0" time="0.01">
|
||||
<testcase classname="tests/fibonacci.test.ts" name="fibonacci > returns 0 for n=0" time="0.001"></testcase>
|
||||
<testcase classname="tests/fibonacci.test.ts" name="fibonacci > returns 1 for n=1" time="0.001"></testcase>
|
||||
</testsuite>
|
||||
</testsuites>"""
|
||||
with tempfile.NamedTemporaryFile(suffix=".xml", mode="w", delete=False) as f:
|
||||
f.write(xml_content)
|
||||
f.flush()
|
||||
junit_file = Path(f.name)
|
||||
|
||||
xml = JUnitXml.fromfile(str(junit_file))
|
||||
|
||||
test_names = []
|
||||
for suite in xml:
|
||||
for case in suite:
|
||||
test_names.append(case.name)
|
||||
|
||||
assert test_names == ["fibonacci > returns 0 for n=0", "fibonacci > returns 1 for n=1"]
|
||||
|
||||
def test_extracts_classname_as_file_path(self) -> None:
|
||||
"""Should extract classname which contains file path in Vitest."""
|
||||
xml_content = """<?xml version="1.0" encoding="UTF-8" ?>
|
||||
<testsuites name="vitest tests" tests="1" failures="0" errors="0" time="0.1">
|
||||
<testsuite name="tests/fibonacci.test.ts" tests="1" failures="0" time="0.01">
|
||||
<testcase classname="tests/fibonacci.test.ts" name="test1" time="0.001"></testcase>
|
||||
</testsuite>
|
||||
</testsuites>"""
|
||||
with tempfile.NamedTemporaryFile(suffix=".xml", mode="w", delete=False) as f:
|
||||
f.write(xml_content)
|
||||
f.flush()
|
||||
junit_file = Path(f.name)
|
||||
|
||||
xml = JUnitXml.fromfile(str(junit_file))
|
||||
|
||||
for suite in xml:
|
||||
for case in suite:
|
||||
assert case.classname == "tests/fibonacci.test.ts"
|
||||
|
||||
def test_extracts_test_time_as_float(self) -> None:
|
||||
"""Should extract test execution time as float from Vitest JUnit XML."""
|
||||
xml_content = """<?xml version="1.0" encoding="UTF-8" ?>
|
||||
<testsuites name="vitest tests" tests="1" failures="0" errors="0" time="0.1">
|
||||
<testsuite name="tests/test.ts" tests="1" failures="0" time="0.01">
|
||||
<testcase classname="tests/test.ts" name="test1" time="0.0015"></testcase>
|
||||
</testsuite>
|
||||
</testsuites>"""
|
||||
with tempfile.NamedTemporaryFile(suffix=".xml", mode="w", delete=False) as f:
|
||||
f.write(xml_content)
|
||||
f.flush()
|
||||
junit_file = Path(f.name)
|
||||
|
||||
xml = JUnitXml.fromfile(str(junit_file))
|
||||
|
||||
for suite in xml:
|
||||
for case in suite:
|
||||
assert isinstance(case.time, float)
|
||||
assert case.time == 0.0015
|
||||
|
||||
def test_detects_failures(self) -> None:
|
||||
"""Should detect test failures in Vitest JUnit XML."""
|
||||
xml_content = """<?xml version="1.0" encoding="UTF-8" ?>
|
||||
<testsuites name="vitest tests" tests="2" failures="1" errors="0" time="0.1">
|
||||
<testsuite name="tests/test.ts" tests="2" failures="1" time="0.01">
|
||||
<testcase classname="tests/test.ts" name="passing test" time="0.001"></testcase>
|
||||
<testcase classname="tests/test.ts" name="failing test" time="0.001">
|
||||
<failure message="expected true to be false" type="AssertionError">AssertionError: expected true to be false</failure>
|
||||
</testcase>
|
||||
</testsuite>
|
||||
</testsuites>"""
|
||||
with tempfile.NamedTemporaryFile(suffix=".xml", mode="w", delete=False) as f:
|
||||
f.write(xml_content)
|
||||
f.flush()
|
||||
junit_file = Path(f.name)
|
||||
|
||||
xml = JUnitXml.fromfile(str(junit_file))
|
||||
|
||||
failures = []
|
||||
for suite in xml:
|
||||
for case in suite:
|
||||
if not case.is_passed:
|
||||
failures.append(case.name)
|
||||
|
||||
assert failures == ["failing test"]
|
||||
|
||||
def test_extracts_failure_message(self) -> None:
|
||||
"""Should extract failure message from Vitest JUnit XML."""
|
||||
xml_content = """<?xml version="1.0" encoding="UTF-8" ?>
|
||||
<testsuites name="vitest tests" tests="1" failures="1" errors="0" time="0.1">
|
||||
<testsuite name="tests/test.ts" tests="1" failures="1" time="0.01">
|
||||
<testcase classname="tests/test.ts" name="failing test" time="0.001">
|
||||
<failure message="expected 'actual' to equal 'expected'" type="AssertionError">AssertionError: expected 'actual' to equal 'expected'</failure>
|
||||
</testcase>
|
||||
</testsuite>
|
||||
</testsuites>"""
|
||||
with tempfile.NamedTemporaryFile(suffix=".xml", mode="w", delete=False) as f:
|
||||
f.write(xml_content)
|
||||
f.flush()
|
||||
junit_file = Path(f.name)
|
||||
|
||||
xml = JUnitXml.fromfile(str(junit_file))
|
||||
|
||||
for suite in xml:
|
||||
for case in suite:
|
||||
if not case.is_passed:
|
||||
for result in case.result:
|
||||
if hasattr(result, "message"):
|
||||
assert result.message == "expected 'actual' to equal 'expected'"
|
||||
|
||||
|
||||
class TestVitestTimingMarkers:
|
||||
"""Tests for Vitest timing marker extraction.
|
||||
|
||||
Timing markers are used to measure function execution time during benchmarking.
|
||||
The format is the same for Jest and Vitest since they use the same codeflash helper.
|
||||
"""
|
||||
|
||||
def test_parses_start_timing_marker(self) -> None:
|
||||
"""Should parse start timing marker from Vitest output."""
|
||||
output = "!$######fibonacci.test.ts:returns 0 for n=0:fibonacci:1:line_0######$!"
|
||||
|
||||
matches = jest_start_pattern.findall(output)
|
||||
|
||||
assert len(matches) == 1
|
||||
test_file, test_name, func_name, loop_index, line_id = matches[0]
|
||||
assert test_file == "fibonacci.test.ts"
|
||||
assert test_name == "returns 0 for n=0"
|
||||
assert func_name == "fibonacci"
|
||||
assert loop_index == "1"
|
||||
assert line_id == "line_0"
|
||||
|
||||
def test_parses_end_timing_marker(self) -> None:
|
||||
"""Should parse end timing marker from Vitest output."""
|
||||
output = "!######fibonacci.test.ts:returns 0 for n=0:fibonacci:1:line_0:123456######!"
|
||||
|
||||
matches = jest_end_pattern.findall(output)
|
||||
|
||||
assert len(matches) == 1
|
||||
test_file, test_name, func_name, loop_index, line_id, duration = matches[0]
|
||||
assert test_file == "fibonacci.test.ts"
|
||||
assert test_name == "returns 0 for n=0"
|
||||
assert func_name == "fibonacci"
|
||||
assert loop_index == "1"
|
||||
assert line_id == "line_0"
|
||||
assert duration == "123456"
|
||||
|
||||
def test_extracts_multiple_timing_markers(self) -> None:
|
||||
"""Should extract multiple timing markers from Vitest output."""
|
||||
output = """Running tests...
|
||||
!$######test.ts:test1:func:1:id1######$!
|
||||
executing...
|
||||
!######test.ts:test1:func:1:id1:100000######!
|
||||
!$######test.ts:test2:func:1:id2######$!
|
||||
executing...
|
||||
!######test.ts:test2:func:1:id2:200000######!
|
||||
Done."""
|
||||
|
||||
start_matches = jest_start_pattern.findall(output)
|
||||
end_matches = jest_end_pattern.findall(output)
|
||||
|
||||
assert len(start_matches) == 2
|
||||
assert len(end_matches) == 2
|
||||
|
||||
durations = [int(m[5]) for m in end_matches]
|
||||
assert durations == [100000, 200000]
|
||||
|
||||
def test_timing_marker_with_special_characters_in_test_name(self) -> None:
|
||||
"""Should handle test names with special characters."""
|
||||
output = "!$######test.ts:handles_n=0_correctly:fibonacci:1:id######$!"
|
||||
|
||||
matches = jest_start_pattern.findall(output)
|
||||
|
||||
assert len(matches) == 1
|
||||
assert matches[0][1] == "handles_n=0_correctly"
|
||||
243
tests/languages/javascript/test_vitest_runner.py
Normal file
243
tests/languages/javascript/test_vitest_runner.py
Normal file
|
|
@ -0,0 +1,243 @@
|
|||
"""Tests for Vitest test runner command construction.
|
||||
|
||||
These tests verify that Vitest commands are correctly constructed
|
||||
with the appropriate flags and arguments.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from codeflash.languages.javascript.vitest_runner import (
|
||||
_build_vitest_behavioral_command,
|
||||
_build_vitest_benchmarking_command,
|
||||
_find_vitest_project_root,
|
||||
)
|
||||
|
||||
|
||||
class TestFindVitestProjectRoot:
|
||||
"""Tests for _find_vitest_project_root function."""
|
||||
|
||||
def test_finds_vitest_config_js(self) -> None:
|
||||
"""Should find project root via vitest.config.js."""
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
tmp_path = Path(tmp_dir)
|
||||
(tmp_path / "vitest.config.js").write_text("export default {}")
|
||||
test_file = tmp_path / "tests" / "test.test.ts"
|
||||
test_file.parent.mkdir(parents=True)
|
||||
test_file.write_text("")
|
||||
|
||||
result = _find_vitest_project_root(test_file)
|
||||
|
||||
assert result == tmp_path
|
||||
|
||||
def test_finds_vitest_config_ts(self) -> None:
|
||||
"""Should find project root via vitest.config.ts."""
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
tmp_path = Path(tmp_dir)
|
||||
(tmp_path / "vitest.config.ts").write_text("export default {}")
|
||||
test_file = tmp_path / "tests" / "test.test.ts"
|
||||
test_file.parent.mkdir(parents=True)
|
||||
test_file.write_text("")
|
||||
|
||||
result = _find_vitest_project_root(test_file)
|
||||
|
||||
assert result == tmp_path
|
||||
|
||||
def test_finds_vite_config_js(self) -> None:
|
||||
"""Should find project root via vite.config.js (Vitest can be configured in vite config)."""
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
tmp_path = Path(tmp_dir)
|
||||
(tmp_path / "vite.config.js").write_text("export default {}")
|
||||
test_file = tmp_path / "tests" / "test.test.ts"
|
||||
test_file.parent.mkdir(parents=True)
|
||||
test_file.write_text("")
|
||||
|
||||
result = _find_vitest_project_root(test_file)
|
||||
|
||||
assert result == tmp_path
|
||||
|
||||
def test_falls_back_to_package_json(self) -> None:
|
||||
"""Should fall back to package.json when no vitest config found."""
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
tmp_path = Path(tmp_dir)
|
||||
(tmp_path / "package.json").write_text('{"name": "test"}')
|
||||
test_file = tmp_path / "tests" / "test.test.ts"
|
||||
test_file.parent.mkdir(parents=True)
|
||||
test_file.write_text("")
|
||||
|
||||
result = _find_vitest_project_root(test_file)
|
||||
|
||||
assert result == tmp_path
|
||||
|
||||
def test_returns_none_when_no_config(self) -> None:
|
||||
"""Should return None when no vitest/vite config or package.json found."""
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
tmp_path = Path(tmp_dir)
|
||||
test_file = tmp_path / "tests" / "test.test.ts"
|
||||
test_file.parent.mkdir(parents=True)
|
||||
test_file.write_text("")
|
||||
|
||||
result = _find_vitest_project_root(test_file)
|
||||
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestBuildVitestBehavioralCommand:
|
||||
"""Tests for _build_vitest_behavioral_command function."""
|
||||
|
||||
def test_basic_command_structure(self) -> None:
|
||||
"""Should build basic Vitest command with required flags."""
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
tmp_path = Path(tmp_dir)
|
||||
test_file = tmp_path / "test.test.ts"
|
||||
test_file.write_text("")
|
||||
|
||||
cmd = _build_vitest_behavioral_command([test_file], timeout=60)
|
||||
|
||||
assert cmd[0] == "npx"
|
||||
assert cmd[1] == "vitest"
|
||||
assert cmd[2] == "run"
|
||||
|
||||
def test_includes_reporter_flags(self) -> None:
|
||||
"""Should include reporter flags for JUnit output."""
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
tmp_path = Path(tmp_dir)
|
||||
test_file = tmp_path / "test.test.ts"
|
||||
test_file.write_text("")
|
||||
|
||||
cmd = _build_vitest_behavioral_command([test_file], timeout=60)
|
||||
|
||||
assert "--reporter=default" in cmd
|
||||
assert "--reporter=junit" in cmd
|
||||
|
||||
def test_includes_serial_execution_flag(self) -> None:
|
||||
"""Should include flag for serial test execution."""
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
tmp_path = Path(tmp_dir)
|
||||
test_file = tmp_path / "test.test.ts"
|
||||
test_file.write_text("")
|
||||
|
||||
cmd = _build_vitest_behavioral_command([test_file], timeout=60)
|
||||
|
||||
assert "--no-file-parallelism" in cmd
|
||||
|
||||
def test_includes_test_files_as_absolute_paths(self) -> None:
|
||||
"""Should include test files at the end of command as absolute paths."""
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
tmp_path = Path(tmp_dir)
|
||||
test_file_a = tmp_path / "test_a.test.ts"
|
||||
test_file_b = tmp_path / "test_b.test.ts"
|
||||
test_file_a.write_text("")
|
||||
test_file_b.write_text("")
|
||||
|
||||
cmd = _build_vitest_behavioral_command([test_file_a, test_file_b], timeout=60)
|
||||
|
||||
assert str(test_file_a.resolve()) in cmd
|
||||
assert str(test_file_b.resolve()) in cmd
|
||||
|
||||
def test_includes_timeout_in_milliseconds(self) -> None:
|
||||
"""Should include test timeout in milliseconds."""
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
tmp_path = Path(tmp_dir)
|
||||
test_file = tmp_path / "test.test.ts"
|
||||
test_file.write_text("")
|
||||
|
||||
cmd = _build_vitest_behavioral_command([test_file], timeout=120)
|
||||
|
||||
assert "--test-timeout=120000" in cmd
|
||||
|
||||
def test_includes_output_file_when_provided(self) -> None:
|
||||
"""Should include --outputFile flag when output_file is provided."""
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
tmp_path = Path(tmp_dir)
|
||||
test_file = tmp_path / "test.test.ts"
|
||||
test_file.write_text("")
|
||||
output_file = tmp_path / "results.xml"
|
||||
|
||||
cmd = _build_vitest_behavioral_command([test_file], timeout=60, output_file=output_file)
|
||||
|
||||
assert f"--outputFile={output_file}" in cmd
|
||||
|
||||
|
||||
class TestBuildVitestBenchmarkingCommand:
|
||||
"""Tests for _build_vitest_benchmarking_command function."""
|
||||
|
||||
def test_basic_command_structure(self) -> None:
|
||||
"""Should build basic Vitest benchmarking command."""
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
tmp_path = Path(tmp_dir)
|
||||
test_file = tmp_path / "test__perf.test.ts"
|
||||
test_file.write_text("")
|
||||
|
||||
cmd = _build_vitest_benchmarking_command([test_file], timeout=60)
|
||||
|
||||
assert cmd[0] == "npx"
|
||||
assert cmd[1] == "vitest"
|
||||
assert cmd[2] == "run"
|
||||
|
||||
def test_includes_serial_execution(self) -> None:
|
||||
"""Should include serial execution for consistent benchmarking."""
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
tmp_path = Path(tmp_dir)
|
||||
test_file = tmp_path / "test__perf.test.ts"
|
||||
test_file.write_text("")
|
||||
|
||||
cmd = _build_vitest_benchmarking_command([test_file], timeout=60)
|
||||
|
||||
assert "--no-file-parallelism" in cmd
|
||||
|
||||
|
||||
class TestVitestVsJestCommandDifferences:
|
||||
"""Tests documenting the key differences between Vitest and Jest commands."""
|
||||
|
||||
def test_vitest_uses_run_subcommand(self) -> None:
|
||||
"""Vitest uses 'run' for single execution, Jest doesn't need it."""
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
tmp_path = Path(tmp_dir)
|
||||
test_file = tmp_path / "test.test.ts"
|
||||
test_file.write_text("")
|
||||
|
||||
vitest_cmd = _build_vitest_behavioral_command([test_file], timeout=60)
|
||||
|
||||
assert vitest_cmd[0:3] == ["npx", "vitest", "run"]
|
||||
|
||||
def test_vitest_uses_hyphenated_timeout(self) -> None:
|
||||
"""Vitest uses --test-timeout, Jest uses --testTimeout (camelCase)."""
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
tmp_path = Path(tmp_dir)
|
||||
test_file = tmp_path / "test.test.ts"
|
||||
test_file.write_text("")
|
||||
|
||||
vitest_cmd = _build_vitest_behavioral_command([test_file], timeout=60)
|
||||
|
||||
timeout_args = [arg for arg in vitest_cmd if "timeout" in arg.lower()]
|
||||
assert len(timeout_args) == 1
|
||||
assert timeout_args[0] == "--test-timeout=60000"
|
||||
|
||||
def test_vitest_uses_no_file_parallelism(self) -> None:
|
||||
"""Vitest uses --no-file-parallelism, Jest uses --runInBand."""
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
tmp_path = Path(tmp_dir)
|
||||
test_file = tmp_path / "test.test.ts"
|
||||
test_file.write_text("")
|
||||
|
||||
vitest_cmd = _build_vitest_behavioral_command([test_file], timeout=60)
|
||||
|
||||
assert "--no-file-parallelism" in vitest_cmd
|
||||
assert "--runInBand" not in vitest_cmd
|
||||
|
||||
def test_vitest_positional_test_files(self) -> None:
|
||||
"""Vitest uses positional args for test files, not --runTestsByPath."""
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
tmp_path = Path(tmp_dir)
|
||||
test_file = tmp_path / "test.test.ts"
|
||||
test_file.write_text("")
|
||||
|
||||
vitest_cmd = _build_vitest_behavioral_command([test_file], timeout=60)
|
||||
|
||||
assert "--runTestsByPath" not in vitest_cmd
|
||||
assert str(test_file.resolve()) in vitest_cmd
|
||||
|
|
@ -604,3 +604,349 @@ def test_function_in_tests_dir():
|
|||
assert "vanilla_function" not in remaining_functions
|
||||
files_and_funcs = get_all_files_and_functions(module_root_path=temp_dir, ignore_paths=[])
|
||||
assert len(files_and_funcs) == 6
|
||||
|
||||
|
||||
def test_filter_functions_tests_root_overlaps_source():
|
||||
"""Test that source files are not filtered when tests_root equals module_root or project_root.
|
||||
|
||||
This is a critical test for monorepo structures where tests live alongside source code
|
||||
(e.g., TypeScript projects with .test.ts files in the same directories as source).
|
||||
"""
|
||||
with tempfile.TemporaryDirectory() as temp_dir_str:
|
||||
temp_dir = Path(temp_dir_str)
|
||||
|
||||
# Create a source file (NOT a test file)
|
||||
source_file = temp_dir / "utils.py"
|
||||
with source_file.open("w") as f:
|
||||
f.write("""
|
||||
def process_data(items):
|
||||
return [item * 2 for item in items]
|
||||
|
||||
def calculate_sum(numbers):
|
||||
return sum(numbers)
|
||||
""")
|
||||
|
||||
# Create a test file with standard naming pattern
|
||||
test_file = temp_dir / "utils.test.py"
|
||||
with test_file.open("w") as f:
|
||||
f.write("""
|
||||
def test_process_data():
|
||||
return "test"
|
||||
""")
|
||||
|
||||
# Create a test file with _test suffix pattern
|
||||
test_file_underscore = temp_dir / "utils_test.py"
|
||||
with test_file_underscore.open("w") as f:
|
||||
f.write("""
|
||||
def test_calculate_sum():
|
||||
return "test"
|
||||
""")
|
||||
|
||||
# Create a spec file
|
||||
spec_file = temp_dir / "utils.spec.py"
|
||||
with spec_file.open("w") as f:
|
||||
f.write("""
|
||||
def spec_function():
|
||||
return "spec"
|
||||
""")
|
||||
|
||||
# Create a file in a tests subdirectory
|
||||
tests_subdir = temp_dir / "tests"
|
||||
tests_subdir.mkdir()
|
||||
tests_subdir_file = tests_subdir / "test_main.py"
|
||||
with tests_subdir_file.open("w") as f:
|
||||
f.write("""
|
||||
def test_in_tests_dir():
|
||||
return "test"
|
||||
""")
|
||||
|
||||
# Create a file in __tests__ subdirectory (common in JS/TS projects)
|
||||
dunder_tests_subdir = temp_dir / "__tests__"
|
||||
dunder_tests_subdir.mkdir()
|
||||
dunder_tests_file = dunder_tests_subdir / "main.py"
|
||||
with dunder_tests_file.open("w") as f:
|
||||
f.write("""
|
||||
def test_in_dunder_tests():
|
||||
return "test"
|
||||
""")
|
||||
|
||||
# Discover all functions
|
||||
discovered_source = find_all_functions_in_file(source_file)
|
||||
discovered_test = find_all_functions_in_file(test_file)
|
||||
discovered_test_underscore = find_all_functions_in_file(test_file_underscore)
|
||||
discovered_spec = find_all_functions_in_file(spec_file)
|
||||
discovered_tests_dir = find_all_functions_in_file(tests_subdir_file)
|
||||
discovered_dunder_tests = find_all_functions_in_file(dunder_tests_file)
|
||||
|
||||
# Combine all discovered functions
|
||||
all_functions = {}
|
||||
for discovered in [discovered_source, discovered_test, discovered_test_underscore,
|
||||
discovered_spec, discovered_tests_dir, discovered_dunder_tests]:
|
||||
all_functions.update(discovered)
|
||||
|
||||
# Test Case 1: tests_root == module_root (overlapping case)
|
||||
# This is the bug scenario where all functions were being filtered
|
||||
with unittest.mock.patch(
|
||||
"codeflash.discovery.functions_to_optimize.get_blocklisted_functions", return_value={}
|
||||
):
|
||||
filtered, count = filter_functions(
|
||||
all_functions,
|
||||
tests_root=temp_dir, # Same as module_root
|
||||
ignore_paths=[],
|
||||
project_root=temp_dir,
|
||||
module_root=temp_dir, # Same as tests_root
|
||||
)
|
||||
|
||||
# Strict check: only source_file should remain in filtered results
|
||||
assert set(filtered.keys()) == {source_file}, (
|
||||
f"Expected only source file in filtered results, got: {set(filtered.keys())}"
|
||||
)
|
||||
|
||||
# Strict check: exactly these two functions should be present
|
||||
source_functions = sorted([fn.function_name for fn in filtered.get(source_file, [])])
|
||||
assert source_functions == ["calculate_sum", "process_data"], (
|
||||
f"Expected ['calculate_sum', 'process_data'], got {source_functions}"
|
||||
)
|
||||
|
||||
# Strict check: exactly 2 functions remaining
|
||||
assert count == 2, f"Expected exactly 2 functions, got {count}"
|
||||
|
||||
# Test Case 2: tests_root == project_root (another overlapping case)
|
||||
with unittest.mock.patch(
|
||||
"codeflash.discovery.functions_to_optimize.get_blocklisted_functions", return_value={}
|
||||
):
|
||||
filtered2, count2 = filter_functions(
|
||||
{source_file: discovered_source[source_file]},
|
||||
tests_root=temp_dir, # Same as project_root
|
||||
ignore_paths=[],
|
||||
project_root=temp_dir,
|
||||
module_root=temp_dir,
|
||||
)
|
||||
|
||||
# Strict check: only source_file should remain
|
||||
assert set(filtered2.keys()) == {source_file}, (
|
||||
f"Expected only source file when tests_root == project_root, got: {set(filtered2.keys())}"
|
||||
)
|
||||
assert count2 == 2, f"Expected exactly 2 functions, got {count2}"
|
||||
|
||||
|
||||
def test_filter_functions_strict_string_matching():
|
||||
"""Test that test file pattern matching uses strict string matching.
|
||||
|
||||
Ensures patterns like '.test.' only match actual test files and don't
|
||||
accidentally match files with similar names like 'contest.py' or 'latest.py'.
|
||||
"""
|
||||
with tempfile.TemporaryDirectory() as temp_dir_str:
|
||||
temp_dir = Path(temp_dir_str)
|
||||
|
||||
# Files that should NOT be filtered (contain 'test' as substring but not as pattern)
|
||||
contest_file = temp_dir / "contest.py"
|
||||
with contest_file.open("w") as f:
|
||||
f.write("def run_contest(): return 1")
|
||||
|
||||
latest_file = temp_dir / "latest.py"
|
||||
with latest_file.open("w") as f:
|
||||
f.write("def get_latest(): return 1")
|
||||
|
||||
attestation_file = temp_dir / "attestation.py"
|
||||
with attestation_file.open("w") as f:
|
||||
f.write("def verify_attestation(): return 1")
|
||||
|
||||
# File that SHOULD be filtered (matches .test. pattern)
|
||||
actual_test_file = temp_dir / "utils.test.py"
|
||||
with actual_test_file.open("w") as f:
|
||||
f.write("def test_utils(): return 1")
|
||||
|
||||
# File that SHOULD be filtered (matches _test. pattern)
|
||||
underscore_test_file = temp_dir / "utils_test.py"
|
||||
with underscore_test_file.open("w") as f:
|
||||
f.write("def test_stuff(): return 1")
|
||||
|
||||
# Discover all functions
|
||||
all_functions = {}
|
||||
for file_path in [contest_file, latest_file, attestation_file, actual_test_file, underscore_test_file]:
|
||||
discovered = find_all_functions_in_file(file_path)
|
||||
all_functions.update(discovered)
|
||||
|
||||
with unittest.mock.patch(
|
||||
"codeflash.discovery.functions_to_optimize.get_blocklisted_functions", return_value={}
|
||||
):
|
||||
filtered, count = filter_functions(
|
||||
all_functions,
|
||||
tests_root=temp_dir, # Overlapping case to trigger pattern matching
|
||||
ignore_paths=[],
|
||||
project_root=temp_dir,
|
||||
module_root=temp_dir,
|
||||
)
|
||||
|
||||
# Strict check: exactly these 3 files should remain (those with 'test' as substring only)
|
||||
expected_files = {contest_file, latest_file, attestation_file}
|
||||
assert set(filtered.keys()) == expected_files, (
|
||||
f"Expected files {expected_files}, got {set(filtered.keys())}"
|
||||
)
|
||||
|
||||
# Strict check: each file should have exactly 1 function with the expected name
|
||||
assert [fn.function_name for fn in filtered[contest_file]] == ["run_contest"], (
|
||||
f"Expected ['run_contest'], got {[fn.function_name for fn in filtered[contest_file]]}"
|
||||
)
|
||||
assert [fn.function_name for fn in filtered[latest_file]] == ["get_latest"], (
|
||||
f"Expected ['get_latest'], got {[fn.function_name for fn in filtered[latest_file]]}"
|
||||
)
|
||||
assert [fn.function_name for fn in filtered[attestation_file]] == ["verify_attestation"], (
|
||||
f"Expected ['verify_attestation'], got {[fn.function_name for fn in filtered[attestation_file]]}"
|
||||
)
|
||||
|
||||
# Strict check: exactly 3 functions remaining
|
||||
assert count == 3, f"Expected exactly 3 functions, got {count}"
|
||||
|
||||
|
||||
def test_filter_functions_test_directory_patterns():
|
||||
"""Test that test directory patterns work correctly with strict matching.
|
||||
|
||||
Ensures that /test/, /tests/, and /__tests__/ patterns only match actual
|
||||
test directories and not directories that happen to contain 'test' in name.
|
||||
"""
|
||||
with tempfile.TemporaryDirectory() as temp_dir_str:
|
||||
temp_dir = Path(temp_dir_str)
|
||||
|
||||
# Directory that should NOT be filtered (contains 'test' but not as /test/ pattern)
|
||||
contest_dir = temp_dir / "contest_results"
|
||||
contest_dir.mkdir()
|
||||
contest_file = contest_dir / "scores.py"
|
||||
with contest_file.open("w") as f:
|
||||
f.write("def get_scores(): return [1, 2, 3]")
|
||||
|
||||
latest_dir = temp_dir / "latest_data"
|
||||
latest_dir.mkdir()
|
||||
latest_file = latest_dir / "data.py"
|
||||
with latest_file.open("w") as f:
|
||||
f.write("def load_data(): return {}")
|
||||
|
||||
# Directory that SHOULD be filtered (matches /tests/ pattern)
|
||||
tests_dir = temp_dir / "tests"
|
||||
tests_dir.mkdir()
|
||||
tests_file = tests_dir / "test_main.py"
|
||||
with tests_file.open("w") as f:
|
||||
f.write("def test_main(): return True")
|
||||
|
||||
# Directory that SHOULD be filtered (matches /test/ pattern - singular)
|
||||
test_dir = temp_dir / "test"
|
||||
test_dir.mkdir()
|
||||
test_file = test_dir / "test_utils.py"
|
||||
with test_file.open("w") as f:
|
||||
f.write("def test_utils(): return True")
|
||||
|
||||
# Directory that SHOULD be filtered (matches /__tests__/ pattern)
|
||||
dunder_tests_dir = temp_dir / "__tests__"
|
||||
dunder_tests_dir.mkdir()
|
||||
dunder_file = dunder_tests_dir / "component.py"
|
||||
with dunder_file.open("w") as f:
|
||||
f.write("def test_component(): return True")
|
||||
|
||||
# Nested test directory
|
||||
src_dir = temp_dir / "src"
|
||||
src_dir.mkdir()
|
||||
nested_tests_dir = src_dir / "tests"
|
||||
nested_tests_dir.mkdir()
|
||||
nested_test_file = nested_tests_dir / "test_nested.py"
|
||||
with nested_test_file.open("w") as f:
|
||||
f.write("def test_nested(): return True")
|
||||
|
||||
# Discover all functions
|
||||
all_functions = {}
|
||||
for file_path in [contest_file, latest_file, tests_file, test_file, dunder_file, nested_test_file]:
|
||||
discovered = find_all_functions_in_file(file_path)
|
||||
all_functions.update(discovered)
|
||||
|
||||
with unittest.mock.patch(
|
||||
"codeflash.discovery.functions_to_optimize.get_blocklisted_functions", return_value={}
|
||||
):
|
||||
filtered, count = filter_functions(
|
||||
all_functions,
|
||||
tests_root=temp_dir, # Overlapping case
|
||||
ignore_paths=[],
|
||||
project_root=temp_dir,
|
||||
module_root=temp_dir,
|
||||
)
|
||||
|
||||
# Strict check: exactly these 2 files should remain (those in non-test directories)
|
||||
expected_files = {contest_file, latest_file}
|
||||
assert set(filtered.keys()) == expected_files, (
|
||||
f"Expected files {expected_files}, got {set(filtered.keys())}"
|
||||
)
|
||||
|
||||
# Strict check: each file should have exactly 1 function with the expected name
|
||||
assert [fn.function_name for fn in filtered[contest_file]] == ["get_scores"], (
|
||||
f"Expected ['get_scores'], got {[fn.function_name for fn in filtered[contest_file]]}"
|
||||
)
|
||||
assert [fn.function_name for fn in filtered[latest_file]] == ["load_data"], (
|
||||
f"Expected ['load_data'], got {[fn.function_name for fn in filtered[latest_file]]}"
|
||||
)
|
||||
|
||||
# Strict check: exactly 2 functions remaining
|
||||
assert count == 2, f"Expected exactly 2 functions, got {count}"
|
||||
|
||||
|
||||
def test_filter_functions_non_overlapping_tests_root():
|
||||
"""Test that the original directory-based filtering still works when tests_root is separate.
|
||||
|
||||
When tests_root is a distinct directory (e.g., 'tests/'), the original behavior
|
||||
of filtering files that start with tests_root should still work.
|
||||
"""
|
||||
with tempfile.TemporaryDirectory() as temp_dir_str:
|
||||
temp_dir = Path(temp_dir_str)
|
||||
|
||||
# Create source directory structure
|
||||
src_dir = temp_dir / "src"
|
||||
src_dir.mkdir()
|
||||
source_file = src_dir / "utils.py"
|
||||
with source_file.open("w") as f:
|
||||
f.write("def process(): return 1")
|
||||
|
||||
# Create a file with .test. pattern in source (should NOT be filtered in non-overlapping mode)
|
||||
# because directory-based filtering takes precedence
|
||||
test_in_src = src_dir / "helper.test.py"
|
||||
with test_in_src.open("w") as f:
|
||||
f.write("def helper_test(): return 1")
|
||||
|
||||
# Create separate tests directory
|
||||
tests_dir = temp_dir / "tests"
|
||||
tests_dir.mkdir()
|
||||
test_file = tests_dir / "test_utils.py"
|
||||
with test_file.open("w") as f:
|
||||
f.write("def test_process(): return 1")
|
||||
|
||||
# Discover functions
|
||||
all_functions = {}
|
||||
for file_path in [source_file, test_in_src, test_file]:
|
||||
discovered = find_all_functions_in_file(file_path)
|
||||
all_functions.update(discovered)
|
||||
|
||||
# Non-overlapping case: tests_root is a separate directory
|
||||
with unittest.mock.patch(
|
||||
"codeflash.discovery.functions_to_optimize.get_blocklisted_functions", return_value={}
|
||||
):
|
||||
filtered, count = filter_functions(
|
||||
all_functions,
|
||||
tests_root=tests_dir, # Separate from module_root
|
||||
ignore_paths=[],
|
||||
project_root=temp_dir,
|
||||
module_root=src_dir, # Different from tests_root
|
||||
)
|
||||
|
||||
# Strict check: exactly these 2 files should remain (both in src/, not in tests/)
|
||||
expected_files = {source_file, test_in_src}
|
||||
assert set(filtered.keys()) == expected_files, (
|
||||
f"Expected files {expected_files}, got {set(filtered.keys())}"
|
||||
)
|
||||
|
||||
# Strict check: each file should have exactly 1 function with the expected name
|
||||
assert [fn.function_name for fn in filtered[source_file]] == ["process"], (
|
||||
f"Expected ['process'], got {[fn.function_name for fn in filtered[source_file]]}"
|
||||
)
|
||||
assert [fn.function_name for fn in filtered[test_in_src]] == ["helper_test"], (
|
||||
f"Expected ['helper_test'], got {[fn.function_name for fn in filtered[test_in_src]]}"
|
||||
)
|
||||
|
||||
# Strict check: exactly 2 functions remaining
|
||||
assert count == 2, f"Expected exactly 2 functions, got {count}"
|
||||
|
|
|
|||
283
tests/test_init_javascript.py
Normal file
283
tests/test_init_javascript.py
Normal file
|
|
@ -0,0 +1,283 @@
|
|||
"""Tests for JavaScript/TypeScript project initialization and package manager detection."""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from codeflash.cli_cmds.init_javascript import (
|
||||
JsPackageManager,
|
||||
determine_js_package_manager,
|
||||
get_package_install_command,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def tmp_project(tmp_path: Path) -> Path:
|
||||
"""Create a temporary project directory."""
|
||||
return tmp_path
|
||||
|
||||
|
||||
class TestDetermineJsPackageManager:
|
||||
"""Tests for determine_js_package_manager function."""
|
||||
|
||||
def test_detects_pnpm_from_lockfile(self, tmp_project: Path) -> None:
|
||||
"""Should detect pnpm from pnpm-lock.yaml."""
|
||||
(tmp_project / "pnpm-lock.yaml").write_text("")
|
||||
(tmp_project / "package.json").write_text("{}")
|
||||
|
||||
result = determine_js_package_manager(tmp_project)
|
||||
|
||||
assert result == JsPackageManager.PNPM
|
||||
|
||||
def test_detects_yarn_from_lockfile(self, tmp_project: Path) -> None:
|
||||
"""Should detect yarn from yarn.lock."""
|
||||
(tmp_project / "yarn.lock").write_text("")
|
||||
(tmp_project / "package.json").write_text("{}")
|
||||
|
||||
result = determine_js_package_manager(tmp_project)
|
||||
|
||||
assert result == JsPackageManager.YARN
|
||||
|
||||
def test_detects_npm_from_lockfile(self, tmp_project: Path) -> None:
|
||||
"""Should detect npm from package-lock.json."""
|
||||
(tmp_project / "package-lock.json").write_text("{}")
|
||||
(tmp_project / "package.json").write_text("{}")
|
||||
|
||||
result = determine_js_package_manager(tmp_project)
|
||||
|
||||
assert result == JsPackageManager.NPM
|
||||
|
||||
def test_detects_bun_from_lockfile(self, tmp_project: Path) -> None:
|
||||
"""Should detect bun from bun.lockb."""
|
||||
(tmp_project / "bun.lockb").write_text("")
|
||||
(tmp_project / "package.json").write_text("{}")
|
||||
|
||||
result = determine_js_package_manager(tmp_project)
|
||||
|
||||
assert result == JsPackageManager.BUN
|
||||
|
||||
def test_detects_bun_from_bun_lock(self, tmp_project: Path) -> None:
|
||||
"""Should detect bun from bun.lock."""
|
||||
(tmp_project / "bun.lock").write_text("")
|
||||
(tmp_project / "package.json").write_text("{}")
|
||||
|
||||
result = determine_js_package_manager(tmp_project)
|
||||
|
||||
assert result == JsPackageManager.BUN
|
||||
|
||||
def test_defaults_to_npm_with_package_json_only(self, tmp_project: Path) -> None:
|
||||
"""Should default to npm when only package.json exists."""
|
||||
(tmp_project / "package.json").write_text("{}")
|
||||
|
||||
result = determine_js_package_manager(tmp_project)
|
||||
|
||||
assert result == JsPackageManager.NPM
|
||||
|
||||
def test_returns_unknown_without_package_json(self, tmp_project: Path) -> None:
|
||||
"""Should return UNKNOWN when no package.json exists."""
|
||||
result = determine_js_package_manager(tmp_project)
|
||||
|
||||
assert result == JsPackageManager.UNKNOWN
|
||||
|
||||
def test_pnpm_takes_precedence_over_npm(self, tmp_project: Path) -> None:
|
||||
"""Should prefer pnpm when both lockfiles exist (migration scenario)."""
|
||||
(tmp_project / "pnpm-lock.yaml").write_text("")
|
||||
(tmp_project / "package-lock.json").write_text("{}")
|
||||
(tmp_project / "package.json").write_text("{}")
|
||||
|
||||
result = determine_js_package_manager(tmp_project)
|
||||
|
||||
assert result == JsPackageManager.PNPM
|
||||
|
||||
def test_bun_takes_precedence_over_others(self, tmp_project: Path) -> None:
|
||||
"""Should prefer bun when bun.lockb exists alongside others."""
|
||||
(tmp_project / "bun.lockb").write_text("")
|
||||
(tmp_project / "pnpm-lock.yaml").write_text("")
|
||||
(tmp_project / "package.json").write_text("{}")
|
||||
|
||||
result = determine_js_package_manager(tmp_project)
|
||||
|
||||
assert result == JsPackageManager.BUN
|
||||
|
||||
# Monorepo tests - lock file in parent directory
|
||||
def test_detects_pnpm_from_parent_lockfile(self, tmp_project: Path) -> None:
|
||||
"""Should detect pnpm from pnpm-lock.yaml in parent directory (monorepo)."""
|
||||
# Create monorepo structure: root/packages/my-package
|
||||
workspace_root = tmp_project
|
||||
package_dir = workspace_root / "packages" / "my-package"
|
||||
package_dir.mkdir(parents=True)
|
||||
|
||||
# Lock file at workspace root
|
||||
(workspace_root / "pnpm-lock.yaml").write_text("")
|
||||
(workspace_root / "package.json").write_text("{}")
|
||||
# Package has its own package.json but no lock file
|
||||
(package_dir / "package.json").write_text("{}")
|
||||
|
||||
result = determine_js_package_manager(package_dir)
|
||||
|
||||
assert result == JsPackageManager.PNPM
|
||||
|
||||
def test_detects_yarn_from_parent_lockfile(self, tmp_project: Path) -> None:
|
||||
"""Should detect yarn from yarn.lock in parent directory (monorepo)."""
|
||||
workspace_root = tmp_project
|
||||
package_dir = workspace_root / "packages" / "my-package"
|
||||
package_dir.mkdir(parents=True)
|
||||
|
||||
(workspace_root / "yarn.lock").write_text("")
|
||||
(workspace_root / "package.json").write_text("{}")
|
||||
(package_dir / "package.json").write_text("{}")
|
||||
|
||||
result = determine_js_package_manager(package_dir)
|
||||
|
||||
assert result == JsPackageManager.YARN
|
||||
|
||||
def test_detects_npm_from_parent_lockfile(self, tmp_project: Path) -> None:
|
||||
"""Should detect npm from package-lock.json in parent directory (monorepo)."""
|
||||
workspace_root = tmp_project
|
||||
package_dir = workspace_root / "packages" / "my-package"
|
||||
package_dir.mkdir(parents=True)
|
||||
|
||||
(workspace_root / "package-lock.json").write_text("{}")
|
||||
(workspace_root / "package.json").write_text("{}")
|
||||
(package_dir / "package.json").write_text("{}")
|
||||
|
||||
result = determine_js_package_manager(package_dir)
|
||||
|
||||
assert result == JsPackageManager.NPM
|
||||
|
||||
def test_detects_bun_from_parent_lockfile(self, tmp_project: Path) -> None:
|
||||
"""Should detect bun from bun.lockb in parent directory (monorepo)."""
|
||||
workspace_root = tmp_project
|
||||
package_dir = workspace_root / "packages" / "my-package"
|
||||
package_dir.mkdir(parents=True)
|
||||
|
||||
(workspace_root / "bun.lockb").write_text("")
|
||||
(workspace_root / "package.json").write_text("{}")
|
||||
(package_dir / "package.json").write_text("{}")
|
||||
|
||||
result = determine_js_package_manager(package_dir)
|
||||
|
||||
assert result == JsPackageManager.BUN
|
||||
|
||||
def test_local_lockfile_takes_precedence_over_parent(self, tmp_project: Path) -> None:
|
||||
"""Should prefer local lock file over parent directory lock file."""
|
||||
workspace_root = tmp_project
|
||||
package_dir = workspace_root / "packages" / "my-package"
|
||||
package_dir.mkdir(parents=True)
|
||||
|
||||
# Parent has pnpm, but local package has yarn
|
||||
(workspace_root / "pnpm-lock.yaml").write_text("")
|
||||
(workspace_root / "package.json").write_text("{}")
|
||||
(package_dir / "yarn.lock").write_text("")
|
||||
(package_dir / "package.json").write_text("{}")
|
||||
|
||||
result = determine_js_package_manager(package_dir)
|
||||
|
||||
# Should detect yarn from local directory first
|
||||
assert result == JsPackageManager.YARN
|
||||
|
||||
def test_deeply_nested_package_finds_root_lockfile(self, tmp_project: Path) -> None:
|
||||
"""Should find lock file in deeply nested monorepo structure."""
|
||||
workspace_root = tmp_project
|
||||
# Simulate: root/apps/web/src/features/auth
|
||||
deep_dir = workspace_root / "apps" / "web" / "src" / "features" / "auth"
|
||||
deep_dir.mkdir(parents=True)
|
||||
|
||||
(workspace_root / "pnpm-lock.yaml").write_text("")
|
||||
(workspace_root / "package.json").write_text("{}")
|
||||
|
||||
result = determine_js_package_manager(deep_dir)
|
||||
|
||||
assert result == JsPackageManager.PNPM
|
||||
|
||||
|
||||
class TestGetPackageInstallCommand:
|
||||
"""Tests for get_package_install_command function."""
|
||||
|
||||
def test_npm_install_command(self, tmp_project: Path) -> None:
|
||||
"""Should return npm install command for npm projects."""
|
||||
(tmp_project / "package-lock.json").write_text("{}")
|
||||
(tmp_project / "package.json").write_text("{}")
|
||||
|
||||
result = get_package_install_command(tmp_project, "codeflash", dev=True)
|
||||
|
||||
assert result == ["npm", "install", "codeflash", "--save-dev"]
|
||||
|
||||
def test_npm_install_command_non_dev(self, tmp_project: Path) -> None:
|
||||
"""Should return npm install command without --save-dev when dev=False."""
|
||||
(tmp_project / "package-lock.json").write_text("{}")
|
||||
(tmp_project / "package.json").write_text("{}")
|
||||
|
||||
result = get_package_install_command(tmp_project, "codeflash", dev=False)
|
||||
|
||||
assert result == ["npm", "install", "codeflash"]
|
||||
|
||||
def test_pnpm_add_command(self, tmp_project: Path) -> None:
|
||||
"""Should return pnpm add command for pnpm projects."""
|
||||
(tmp_project / "pnpm-lock.yaml").write_text("")
|
||||
(tmp_project / "package.json").write_text("{}")
|
||||
|
||||
result = get_package_install_command(tmp_project, "codeflash", dev=True)
|
||||
|
||||
assert result == ["pnpm", "add", "codeflash", "--save-dev"]
|
||||
|
||||
def test_pnpm_add_command_non_dev(self, tmp_project: Path) -> None:
|
||||
"""Should return pnpm add command without --save-dev when dev=False."""
|
||||
(tmp_project / "pnpm-lock.yaml").write_text("")
|
||||
(tmp_project / "package.json").write_text("{}")
|
||||
|
||||
result = get_package_install_command(tmp_project, "codeflash", dev=False)
|
||||
|
||||
assert result == ["pnpm", "add", "codeflash"]
|
||||
|
||||
def test_yarn_add_command(self, tmp_project: Path) -> None:
|
||||
"""Should return yarn add command for yarn projects."""
|
||||
(tmp_project / "yarn.lock").write_text("")
|
||||
(tmp_project / "package.json").write_text("{}")
|
||||
|
||||
result = get_package_install_command(tmp_project, "codeflash", dev=True)
|
||||
|
||||
assert result == ["yarn", "add", "codeflash", "--dev"]
|
||||
|
||||
def test_yarn_add_command_non_dev(self, tmp_project: Path) -> None:
|
||||
"""Should return yarn add command without --dev when dev=False."""
|
||||
(tmp_project / "yarn.lock").write_text("")
|
||||
(tmp_project / "package.json").write_text("{}")
|
||||
|
||||
result = get_package_install_command(tmp_project, "codeflash", dev=False)
|
||||
|
||||
assert result == ["yarn", "add", "codeflash"]
|
||||
|
||||
def test_bun_add_command(self, tmp_project: Path) -> None:
|
||||
"""Should return bun add command for bun projects."""
|
||||
(tmp_project / "bun.lockb").write_text("")
|
||||
(tmp_project / "package.json").write_text("{}")
|
||||
|
||||
result = get_package_install_command(tmp_project, "codeflash", dev=True)
|
||||
|
||||
assert result == ["bun", "add", "codeflash", "--dev"]
|
||||
|
||||
def test_bun_add_command_non_dev(self, tmp_project: Path) -> None:
|
||||
"""Should return bun add command without --dev when dev=False."""
|
||||
(tmp_project / "bun.lockb").write_text("")
|
||||
(tmp_project / "package.json").write_text("{}")
|
||||
|
||||
result = get_package_install_command(tmp_project, "codeflash", dev=False)
|
||||
|
||||
assert result == ["bun", "add", "codeflash"]
|
||||
|
||||
def test_defaults_to_npm_for_unknown(self, tmp_project: Path) -> None:
|
||||
"""Should default to npm for unknown package manager."""
|
||||
# No lockfile, no package.json - unknown package manager
|
||||
result = get_package_install_command(tmp_project, "codeflash", dev=True)
|
||||
|
||||
assert result == ["npm", "install", "codeflash", "--save-dev"]
|
||||
|
||||
def test_different_package_name(self, tmp_project: Path) -> None:
|
||||
"""Should work with different package names."""
|
||||
(tmp_project / "pnpm-lock.yaml").write_text("")
|
||||
(tmp_project / "package.json").write_text("{}")
|
||||
|
||||
result = get_package_install_command(tmp_project, "typescript", dev=True)
|
||||
|
||||
assert result == ["pnpm", "add", "typescript", "--save-dev"]
|
||||
|
|
@ -6,7 +6,22 @@ covering all patterns that might be seen in the wild.
|
|||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
|
||||
from codeflash.languages.javascript.instrument import TestingMode, instrument_generated_js_test, transform_expect_calls
|
||||
from codeflash.models.models import FunctionParent
|
||||
|
||||
|
||||
def make_func(name: str, class_name: str | None = None) -> FunctionToOptimize:
|
||||
"""Helper to create FunctionToOptimize for testing."""
|
||||
parents = [FunctionParent(name=class_name, type="ClassDef")] if class_name else []
|
||||
return FunctionToOptimize(
|
||||
function_name=name,
|
||||
file_path=Path("/test/file.js"),
|
||||
parents=parents,
|
||||
language="javascript",
|
||||
)
|
||||
|
||||
|
||||
class TestExpectCallTransformer:
|
||||
|
|
@ -15,139 +30,139 @@ class TestExpectCallTransformer:
|
|||
def test_basic_toBe_assertion(self) -> None:
|
||||
"""Test basic .toBe() assertion removal."""
|
||||
code = "expect(fibonacci(5)).toBe(5);"
|
||||
result, _ = transform_expect_calls(code, "fibonacci", "fibonacci", "capture", remove_assertions=True)
|
||||
result, _ = transform_expect_calls(code, make_func("fibonacci"), "capture", remove_assertions=True)
|
||||
assert result == "codeflash.capture('fibonacci', '1', fibonacci, 5);"
|
||||
|
||||
def test_basic_toEqual_assertion(self) -> None:
|
||||
"""Test .toEqual() assertion removal."""
|
||||
code = "expect(func([1, 2, 3])).toEqual([1, 2, 3]);"
|
||||
result, _ = transform_expect_calls(code, "func", "func", "capture", remove_assertions=True)
|
||||
result, _ = transform_expect_calls(code, make_func("func"), "capture", remove_assertions=True)
|
||||
assert result == "codeflash.capture('func', '1', func, [1, 2, 3]);"
|
||||
|
||||
def test_toStrictEqual_assertion(self) -> None:
|
||||
"""Test .toStrictEqual() assertion removal."""
|
||||
code = "expect(func({a: 1})).toStrictEqual({a: 1});"
|
||||
result, _ = transform_expect_calls(code, "func", "func", "capture", remove_assertions=True)
|
||||
result, _ = transform_expect_calls(code, make_func("func"), "capture", remove_assertions=True)
|
||||
assert result == "codeflash.capture('func', '1', func, {a: 1});"
|
||||
|
||||
def test_toBeCloseTo_with_precision(self) -> None:
|
||||
"""Test .toBeCloseTo() with precision argument."""
|
||||
code = "expect(func(3.14159)).toBeCloseTo(3.14, 2);"
|
||||
result, _ = transform_expect_calls(code, "func", "func", "capture", remove_assertions=True)
|
||||
result, _ = transform_expect_calls(code, make_func("func"), "capture", remove_assertions=True)
|
||||
assert result == "codeflash.capture('func', '1', func, 3.14159);"
|
||||
|
||||
def test_toBeTruthy_no_args(self) -> None:
|
||||
"""Test .toBeTruthy() assertion without arguments."""
|
||||
code = "expect(func(true)).toBeTruthy();"
|
||||
result, _ = transform_expect_calls(code, "func", "func", "capture", remove_assertions=True)
|
||||
result, _ = transform_expect_calls(code, make_func("func"), "capture", remove_assertions=True)
|
||||
assert result == "codeflash.capture('func', '1', func, true);"
|
||||
|
||||
def test_toBeFalsy_no_args(self) -> None:
|
||||
"""Test .toBeFalsy() assertion without arguments."""
|
||||
code = "expect(func(0)).toBeFalsy();"
|
||||
result, _ = transform_expect_calls(code, "func", "func", "capture", remove_assertions=True)
|
||||
result, _ = transform_expect_calls(code, make_func("func"), "capture", remove_assertions=True)
|
||||
assert result == "codeflash.capture('func', '1', func, 0);"
|
||||
|
||||
def test_toBeNull(self) -> None:
|
||||
"""Test .toBeNull() assertion."""
|
||||
code = "expect(func(null)).toBeNull();"
|
||||
result, _ = transform_expect_calls(code, "func", "func", "capture", remove_assertions=True)
|
||||
result, _ = transform_expect_calls(code, make_func("func"), "capture", remove_assertions=True)
|
||||
assert result == "codeflash.capture('func', '1', func, null);"
|
||||
|
||||
def test_toBeUndefined(self) -> None:
|
||||
"""Test .toBeUndefined() assertion."""
|
||||
code = "expect(func()).toBeUndefined();"
|
||||
result, _ = transform_expect_calls(code, "func", "func", "capture", remove_assertions=True)
|
||||
result, _ = transform_expect_calls(code, make_func("func"), "capture", remove_assertions=True)
|
||||
assert result == "codeflash.capture('func', '1', func);"
|
||||
|
||||
def test_toBeDefined(self) -> None:
|
||||
"""Test .toBeDefined() assertion."""
|
||||
code = "expect(func(1)).toBeDefined();"
|
||||
result, _ = transform_expect_calls(code, "func", "func", "capture", remove_assertions=True)
|
||||
result, _ = transform_expect_calls(code, make_func("func"), "capture", remove_assertions=True)
|
||||
assert result == "codeflash.capture('func', '1', func, 1);"
|
||||
|
||||
def test_toBeNaN(self) -> None:
|
||||
"""Test .toBeNaN() assertion."""
|
||||
code = "expect(func(NaN)).toBeNaN();"
|
||||
result, _ = transform_expect_calls(code, "func", "func", "capture", remove_assertions=True)
|
||||
result, _ = transform_expect_calls(code, make_func("func"), "capture", remove_assertions=True)
|
||||
assert result == "codeflash.capture('func', '1', func, NaN);"
|
||||
|
||||
def test_toBeGreaterThan(self) -> None:
|
||||
"""Test .toBeGreaterThan() assertion."""
|
||||
code = "expect(func(10)).toBeGreaterThan(5);"
|
||||
result, _ = transform_expect_calls(code, "func", "func", "capture", remove_assertions=True)
|
||||
result, _ = transform_expect_calls(code, make_func("func"), "capture", remove_assertions=True)
|
||||
assert result == "codeflash.capture('func', '1', func, 10);"
|
||||
|
||||
def test_toBeLessThan(self) -> None:
|
||||
"""Test .toBeLessThan() assertion."""
|
||||
code = "expect(func(3)).toBeLessThan(5);"
|
||||
result, _ = transform_expect_calls(code, "func", "func", "capture", remove_assertions=True)
|
||||
result, _ = transform_expect_calls(code, make_func("func"), "capture", remove_assertions=True)
|
||||
assert result == "codeflash.capture('func', '1', func, 3);"
|
||||
|
||||
def test_toBeGreaterThanOrEqual(self) -> None:
|
||||
"""Test .toBeGreaterThanOrEqual() assertion."""
|
||||
code = "expect(func(5)).toBeGreaterThanOrEqual(5);"
|
||||
result, _ = transform_expect_calls(code, "func", "func", "capture", remove_assertions=True)
|
||||
result, _ = transform_expect_calls(code, make_func("func"), "capture", remove_assertions=True)
|
||||
assert result == "codeflash.capture('func', '1', func, 5);"
|
||||
|
||||
def test_toBeLessThanOrEqual(self) -> None:
|
||||
"""Test .toBeLessThanOrEqual() assertion."""
|
||||
code = "expect(func(5)).toBeLessThanOrEqual(5);"
|
||||
result, _ = transform_expect_calls(code, "func", "func", "capture", remove_assertions=True)
|
||||
result, _ = transform_expect_calls(code, make_func("func"), "capture", remove_assertions=True)
|
||||
assert result == "codeflash.capture('func', '1', func, 5);"
|
||||
|
||||
def test_toContain(self) -> None:
|
||||
"""Test .toContain() assertion."""
|
||||
code = "expect(func([1, 2, 3])).toContain(2);"
|
||||
result, _ = transform_expect_calls(code, "func", "func", "capture", remove_assertions=True)
|
||||
result, _ = transform_expect_calls(code, make_func("func"), "capture", remove_assertions=True)
|
||||
assert result == "codeflash.capture('func', '1', func, [1, 2, 3]);"
|
||||
|
||||
def test_toContainEqual(self) -> None:
|
||||
"""Test .toContainEqual() assertion."""
|
||||
code = "expect(func([{a: 1}])).toContainEqual({a: 1});"
|
||||
result, _ = transform_expect_calls(code, "func", "func", "capture", remove_assertions=True)
|
||||
result, _ = transform_expect_calls(code, make_func("func"), "capture", remove_assertions=True)
|
||||
assert result == "codeflash.capture('func', '1', func, [{a: 1}]);"
|
||||
|
||||
def test_toHaveLength(self) -> None:
|
||||
"""Test .toHaveLength() assertion."""
|
||||
code = "expect(func([1, 2, 3])).toHaveLength(3);"
|
||||
result, _ = transform_expect_calls(code, "func", "func", "capture", remove_assertions=True)
|
||||
result, _ = transform_expect_calls(code, make_func("func"), "capture", remove_assertions=True)
|
||||
assert result == "codeflash.capture('func', '1', func, [1, 2, 3]);"
|
||||
|
||||
def test_toMatch_string(self) -> None:
|
||||
"""Test .toMatch() with string pattern."""
|
||||
code = "expect(func('hello')).toMatch('ell');"
|
||||
result, _ = transform_expect_calls(code, "func", "func", "capture", remove_assertions=True)
|
||||
result, _ = transform_expect_calls(code, make_func("func"), "capture", remove_assertions=True)
|
||||
assert result == "codeflash.capture('func', '1', func, 'hello');"
|
||||
|
||||
def test_toMatch_regex(self) -> None:
|
||||
"""Test .toMatch() with regex pattern."""
|
||||
code = "expect(func('hello')).toMatch(/ell/);"
|
||||
result, _ = transform_expect_calls(code, "func", "func", "capture", remove_assertions=True)
|
||||
result, _ = transform_expect_calls(code, make_func("func"), "capture", remove_assertions=True)
|
||||
assert result == "codeflash.capture('func', '1', func, 'hello');"
|
||||
|
||||
def test_toMatchObject(self) -> None:
|
||||
"""Test .toMatchObject() assertion."""
|
||||
code = "expect(func({a: 1, b: 2})).toMatchObject({a: 1});"
|
||||
result, _ = transform_expect_calls(code, "func", "func", "capture", remove_assertions=True)
|
||||
result, _ = transform_expect_calls(code, make_func("func"), "capture", remove_assertions=True)
|
||||
assert result == "codeflash.capture('func', '1', func, {a: 1, b: 2});"
|
||||
|
||||
def test_toHaveProperty(self) -> None:
|
||||
"""Test .toHaveProperty() assertion."""
|
||||
code = "expect(func({a: 1})).toHaveProperty('a');"
|
||||
result, _ = transform_expect_calls(code, "func", "func", "capture", remove_assertions=True)
|
||||
result, _ = transform_expect_calls(code, make_func("func"), "capture", remove_assertions=True)
|
||||
assert result == "codeflash.capture('func', '1', func, {a: 1});"
|
||||
|
||||
def test_toHaveProperty_with_value(self) -> None:
|
||||
"""Test .toHaveProperty() with value."""
|
||||
code = "expect(func({a: 1})).toHaveProperty('a', 1);"
|
||||
result, _ = transform_expect_calls(code, "func", "func", "capture", remove_assertions=True)
|
||||
result, _ = transform_expect_calls(code, make_func("func"), "capture", remove_assertions=True)
|
||||
assert result == "codeflash.capture('func', '1', func, {a: 1});"
|
||||
|
||||
def test_toBeInstanceOf(self) -> None:
|
||||
"""Test .toBeInstanceOf() assertion."""
|
||||
code = "expect(func()).toBeInstanceOf(Array);"
|
||||
result, _ = transform_expect_calls(code, "func", "func", "capture", remove_assertions=True)
|
||||
result, _ = transform_expect_calls(code, make_func("func"), "capture", remove_assertions=True)
|
||||
assert result == "codeflash.capture('func', '1', func);"
|
||||
|
||||
|
||||
|
|
@ -157,31 +172,31 @@ class TestNegatedAssertions:
|
|||
def test_not_toBe(self) -> None:
|
||||
"""Test .not.toBe() assertion removal."""
|
||||
code = "expect(func(5)).not.toBe(10);"
|
||||
result, _ = transform_expect_calls(code, "func", "func", "capture", remove_assertions=True)
|
||||
result, _ = transform_expect_calls(code, make_func("func"), "capture", remove_assertions=True)
|
||||
assert result == "codeflash.capture('func', '1', func, 5);"
|
||||
|
||||
def test_not_toEqual(self) -> None:
|
||||
"""Test .not.toEqual() assertion removal."""
|
||||
code = "expect(func([1, 2])).not.toEqual([3, 4]);"
|
||||
result, _ = transform_expect_calls(code, "func", "func", "capture", remove_assertions=True)
|
||||
result, _ = transform_expect_calls(code, make_func("func"), "capture", remove_assertions=True)
|
||||
assert result == "codeflash.capture('func', '1', func, [1, 2]);"
|
||||
|
||||
def test_not_toBeTruthy(self) -> None:
|
||||
"""Test .not.toBeTruthy() assertion removal."""
|
||||
code = "expect(func(0)).not.toBeTruthy();"
|
||||
result, _ = transform_expect_calls(code, "func", "func", "capture", remove_assertions=True)
|
||||
result, _ = transform_expect_calls(code, make_func("func"), "capture", remove_assertions=True)
|
||||
assert result == "codeflash.capture('func', '1', func, 0);"
|
||||
|
||||
def test_not_toContain(self) -> None:
|
||||
"""Test .not.toContain() assertion removal."""
|
||||
code = "expect(func([1, 2, 3])).not.toContain(4);"
|
||||
result, _ = transform_expect_calls(code, "func", "func", "capture", remove_assertions=True)
|
||||
result, _ = transform_expect_calls(code, make_func("func"), "capture", remove_assertions=True)
|
||||
assert result == "codeflash.capture('func', '1', func, [1, 2, 3]);"
|
||||
|
||||
def test_not_toBeNull(self) -> None:
|
||||
"""Test .not.toBeNull() assertion removal."""
|
||||
code = "expect(func(1)).not.toBeNull();"
|
||||
result, _ = transform_expect_calls(code, "func", "func", "capture", remove_assertions=True)
|
||||
result, _ = transform_expect_calls(code, make_func("func"), "capture", remove_assertions=True)
|
||||
assert result == "codeflash.capture('func', '1', func, 1);"
|
||||
|
||||
|
||||
|
|
@ -191,31 +206,31 @@ class TestAsyncAssertions:
|
|||
def test_resolves_toBe(self) -> None:
|
||||
"""Test .resolves.toBe() assertion removal."""
|
||||
code = "expect(asyncFunc(5)).resolves.toBe(10);"
|
||||
result, _ = transform_expect_calls(code, "asyncFunc", "asyncFunc", "capture", remove_assertions=True)
|
||||
result, _ = transform_expect_calls(code, make_func("asyncFunc"), "capture", remove_assertions=True)
|
||||
assert result == "codeflash.capture('asyncFunc', '1', asyncFunc, 5);"
|
||||
|
||||
def test_resolves_toEqual(self) -> None:
|
||||
"""Test .resolves.toEqual() assertion removal."""
|
||||
code = "expect(asyncFunc()).resolves.toEqual({data: 'test'});"
|
||||
result, _ = transform_expect_calls(code, "asyncFunc", "asyncFunc", "capture", remove_assertions=True)
|
||||
result, _ = transform_expect_calls(code, make_func("asyncFunc"), "capture", remove_assertions=True)
|
||||
assert result == "codeflash.capture('asyncFunc', '1', asyncFunc);"
|
||||
|
||||
def test_rejects_toThrow(self) -> None:
|
||||
"""Test .rejects.toThrow() assertion removal."""
|
||||
code = "expect(asyncFunc()).rejects.toThrow();"
|
||||
result, _ = transform_expect_calls(code, "asyncFunc", "asyncFunc", "capture", remove_assertions=True)
|
||||
result, _ = transform_expect_calls(code, make_func("asyncFunc"), "capture", remove_assertions=True)
|
||||
assert result == "codeflash.capture('asyncFunc', '1', asyncFunc);"
|
||||
|
||||
def test_rejects_toThrow_with_message(self) -> None:
|
||||
"""Test .rejects.toThrow() with error message."""
|
||||
code = "expect(asyncFunc()).rejects.toThrow('Error message');"
|
||||
result, _ = transform_expect_calls(code, "asyncFunc", "asyncFunc", "capture", remove_assertions=True)
|
||||
result, _ = transform_expect_calls(code, make_func("asyncFunc"), "capture", remove_assertions=True)
|
||||
assert result == "codeflash.capture('asyncFunc', '1', asyncFunc);"
|
||||
|
||||
def test_not_resolves_toBe(self) -> None:
|
||||
"""Test .not.resolves.toBe() (rare but valid)."""
|
||||
code = "expect(asyncFunc()).not.resolves.toBe(5);"
|
||||
result, _ = transform_expect_calls(code, "asyncFunc", "asyncFunc", "capture", remove_assertions=True)
|
||||
result, _ = transform_expect_calls(code, make_func("asyncFunc"), "capture", remove_assertions=True)
|
||||
assert result == "codeflash.capture('asyncFunc', '1', asyncFunc);"
|
||||
|
||||
|
||||
|
|
@ -225,31 +240,31 @@ class TestNestedParentheses:
|
|||
def test_nested_function_call(self) -> None:
|
||||
"""Test nested function call in arguments."""
|
||||
code = "expect(func(getN(5))).toBe(10);"
|
||||
result, _ = transform_expect_calls(code, "func", "func", "capture", remove_assertions=True)
|
||||
result, _ = transform_expect_calls(code, make_func("func"), "capture", remove_assertions=True)
|
||||
assert result == "codeflash.capture('func', '1', func, getN(5));"
|
||||
|
||||
def test_deeply_nested_calls(self) -> None:
|
||||
"""Test deeply nested function calls."""
|
||||
code = "expect(func(outer(inner(deep(1))))).toBe(100);"
|
||||
result, _ = transform_expect_calls(code, "func", "func", "capture", remove_assertions=True)
|
||||
result, _ = transform_expect_calls(code, make_func("func"), "capture", remove_assertions=True)
|
||||
assert result == "codeflash.capture('func', '1', func, outer(inner(deep(1))));"
|
||||
|
||||
def test_multiple_nested_args(self) -> None:
|
||||
"""Test multiple arguments with nested calls."""
|
||||
code = "expect(func(getA(), getB(getC()))).toBe(5);"
|
||||
result, _ = transform_expect_calls(code, "func", "func", "capture", remove_assertions=True)
|
||||
result, _ = transform_expect_calls(code, make_func("func"), "capture", remove_assertions=True)
|
||||
assert result == "codeflash.capture('func', '1', func, getA(), getB(getC()));"
|
||||
|
||||
def test_object_with_nested_calls(self) -> None:
|
||||
"""Test object argument with nested function calls."""
|
||||
code = "expect(func({key: getValue()})).toEqual({key: 1});"
|
||||
result, _ = transform_expect_calls(code, "func", "func", "capture", remove_assertions=True)
|
||||
result, _ = transform_expect_calls(code, make_func("func"), "capture", remove_assertions=True)
|
||||
assert result == "codeflash.capture('func', '1', func, {key: getValue()});"
|
||||
|
||||
def test_array_with_nested_calls(self) -> None:
|
||||
"""Test array argument with nested function calls."""
|
||||
code = "expect(func([getA(), getB()])).toEqual([1, 2]);"
|
||||
result, _ = transform_expect_calls(code, "func", "func", "capture", remove_assertions=True)
|
||||
result, _ = transform_expect_calls(code, make_func("func"), "capture", remove_assertions=True)
|
||||
assert result == "codeflash.capture('func', '1', func, [getA(), getB()]);"
|
||||
|
||||
|
||||
|
|
@ -259,31 +274,31 @@ class TestStringLiterals:
|
|||
def test_string_with_parentheses(self) -> None:
|
||||
"""Test string argument containing parentheses."""
|
||||
code = "expect(func('hello (world)')).toBe('result');"
|
||||
result, _ = transform_expect_calls(code, "func", "func", "capture", remove_assertions=True)
|
||||
result, _ = transform_expect_calls(code, make_func("func"), "capture", remove_assertions=True)
|
||||
assert result == "codeflash.capture('func', '1', func, 'hello (world)');"
|
||||
|
||||
def test_double_quoted_string_with_parens(self) -> None:
|
||||
"""Test double-quoted string with parentheses."""
|
||||
code = 'expect(func("hello (world)")).toBe("result");'
|
||||
result, _ = transform_expect_calls(code, "func", "func", "capture", remove_assertions=True)
|
||||
result, _ = transform_expect_calls(code, make_func("func"), "capture", remove_assertions=True)
|
||||
assert result == "codeflash.capture('func', '1', func, \"hello (world)\");"
|
||||
|
||||
def test_template_literal(self) -> None:
|
||||
"""Test template literal argument."""
|
||||
code = "expect(func(`template ${value}`)).toBe('result');"
|
||||
result, _ = transform_expect_calls(code, "func", "func", "capture", remove_assertions=True)
|
||||
result, _ = transform_expect_calls(code, make_func("func"), "capture", remove_assertions=True)
|
||||
assert result == "codeflash.capture('func', '1', func, `template ${value}`);"
|
||||
|
||||
def test_template_literal_with_parens(self) -> None:
|
||||
"""Test template literal with parentheses inside."""
|
||||
code = "expect(func(`hello (${name})`)).toBe('greeting');"
|
||||
result, _ = transform_expect_calls(code, "func", "func", "capture", remove_assertions=True)
|
||||
result, _ = transform_expect_calls(code, make_func("func"), "capture", remove_assertions=True)
|
||||
assert result == "codeflash.capture('func', '1', func, `hello (${name})`);"
|
||||
|
||||
def test_escaped_quotes(self) -> None:
|
||||
"""Test string with escaped quotes."""
|
||||
code = "expect(func('it\\'s working')).toBe('yes');"
|
||||
result, _ = transform_expect_calls(code, "func", "func", "capture", remove_assertions=True)
|
||||
result, _ = transform_expect_calls(code, make_func("func"), "capture", remove_assertions=True)
|
||||
assert result == "codeflash.capture('func', '1', func, 'it\\'s working');"
|
||||
|
||||
|
||||
|
|
@ -293,39 +308,39 @@ class TestWhitespaceHandling:
|
|||
def test_leading_whitespace_preserved(self) -> None:
|
||||
"""Test that leading whitespace is preserved."""
|
||||
code = " expect(func(5)).toBe(5);"
|
||||
result, _ = transform_expect_calls(code, "func", "func", "capture", remove_assertions=True)
|
||||
result, _ = transform_expect_calls(code, make_func("func"), "capture", remove_assertions=True)
|
||||
assert result == " codeflash.capture('func', '1', func, 5);"
|
||||
|
||||
def test_tab_indentation(self) -> None:
|
||||
"""Test tab indentation is preserved."""
|
||||
code = "\t\texpect(func(5)).toBe(5);"
|
||||
result, _ = transform_expect_calls(code, "func", "func", "capture", remove_assertions=True)
|
||||
result, _ = transform_expect_calls(code, make_func("func"), "capture", remove_assertions=True)
|
||||
assert result == "\t\tcodeflash.capture('func', '1', func, 5);"
|
||||
|
||||
def test_no_space_after_expect(self) -> None:
|
||||
"""Test expect without space before parenthesis."""
|
||||
code = "expect(func(5)).toBe(5);"
|
||||
result, _ = transform_expect_calls(code, "func", "func", "capture", remove_assertions=True)
|
||||
result, _ = transform_expect_calls(code, make_func("func"), "capture", remove_assertions=True)
|
||||
assert result == "codeflash.capture('func', '1', func, 5);"
|
||||
|
||||
def test_space_after_expect(self) -> None:
|
||||
"""Test expect with space before parenthesis."""
|
||||
code = "expect (func(5)).toBe(5);"
|
||||
result, _ = transform_expect_calls(code, "func", "func", "capture", remove_assertions=True)
|
||||
result, _ = transform_expect_calls(code, make_func("func"), "capture", remove_assertions=True)
|
||||
assert result == "codeflash.capture('func', '1', func, 5);"
|
||||
|
||||
def test_newline_in_assertion(self) -> None:
|
||||
"""Test assertion split across lines."""
|
||||
code = """expect(func(5))
|
||||
.toBe(5);"""
|
||||
result, _ = transform_expect_calls(code, "func", "func", "capture", remove_assertions=True)
|
||||
result, _ = transform_expect_calls(code, make_func("func"), "capture", remove_assertions=True)
|
||||
assert result == "codeflash.capture('func', '1', func, 5);"
|
||||
|
||||
def test_newline_after_expect_close(self) -> None:
|
||||
"""Test newline after expect closing paren."""
|
||||
code = """expect(func(5))
|
||||
.toBe(5);"""
|
||||
result, _ = transform_expect_calls(code, "func", "func", "capture", remove_assertions=True)
|
||||
result, _ = transform_expect_calls(code, make_func("func"), "capture", remove_assertions=True)
|
||||
assert result == "codeflash.capture('func', '1', func, 5);"
|
||||
|
||||
|
||||
|
|
@ -337,7 +352,7 @@ class TestMultipleAssertions:
|
|||
code = """expect(func(1)).toBe(1);
|
||||
expect(func(2)).toBe(2);
|
||||
expect(func(3)).toBe(3);"""
|
||||
result, _ = transform_expect_calls(code, "func", "func", "capture", remove_assertions=True)
|
||||
result, _ = transform_expect_calls(code, make_func("func"), "capture", remove_assertions=True)
|
||||
expected = """codeflash.capture('func', '1', func, 1);
|
||||
codeflash.capture('func', '2', func, 2);
|
||||
codeflash.capture('func', '3', func, 3);"""
|
||||
|
|
@ -348,7 +363,7 @@ codeflash.capture('func', '3', func, 3);"""
|
|||
code = """expect(func(1)).toBe(1);
|
||||
expect(func(2)).toEqual(2);
|
||||
expect(func(3)).not.toBe(0);"""
|
||||
result, _ = transform_expect_calls(code, "func", "func", "capture", remove_assertions=True)
|
||||
result, _ = transform_expect_calls(code, make_func("func"), "capture", remove_assertions=True)
|
||||
expected = """codeflash.capture('func', '1', func, 1);
|
||||
codeflash.capture('func', '2', func, 2);
|
||||
codeflash.capture('func', '3', func, 3);"""
|
||||
|
|
@ -360,7 +375,7 @@ codeflash.capture('func', '3', func, 3);"""
|
|||
expect(func(x)).toBe(10);
|
||||
console.log('done');
|
||||
expect(func(x + 1)).toBe(12);"""
|
||||
result, _ = transform_expect_calls(code, "func", "func", "capture", remove_assertions=True)
|
||||
result, _ = transform_expect_calls(code, make_func("func"), "capture", remove_assertions=True)
|
||||
expected = """const x = 5;
|
||||
codeflash.capture('func', '1', func, x);
|
||||
console.log('done');
|
||||
|
|
@ -374,20 +389,20 @@ class TestSemicolonHandling:
|
|||
def test_with_semicolon(self) -> None:
|
||||
"""Test assertion with trailing semicolon."""
|
||||
code = "expect(func(5)).toBe(5);"
|
||||
result, _ = transform_expect_calls(code, "func", "func", "capture", remove_assertions=True)
|
||||
result, _ = transform_expect_calls(code, make_func("func"), "capture", remove_assertions=True)
|
||||
assert result == "codeflash.capture('func', '1', func, 5);"
|
||||
|
||||
def test_without_semicolon(self) -> None:
|
||||
"""Test assertion without trailing semicolon."""
|
||||
code = "expect(func(5)).toBe(5)"
|
||||
result, _ = transform_expect_calls(code, "func", "func", "capture", remove_assertions=True)
|
||||
result, _ = transform_expect_calls(code, make_func("func"), "capture", remove_assertions=True)
|
||||
assert result == "codeflash.capture('func', '1', func, 5);"
|
||||
|
||||
def test_multiple_without_semicolons(self) -> None:
|
||||
"""Test multiple assertions without semicolons (common in some styles)."""
|
||||
code = """expect(func(1)).toBe(1)
|
||||
expect(func(2)).toBe(2)"""
|
||||
result, _ = transform_expect_calls(code, "func", "func", "capture", remove_assertions=True)
|
||||
result, _ = transform_expect_calls(code, make_func("func"), "capture", remove_assertions=True)
|
||||
expected = """codeflash.capture('func', '1', func, 1);
|
||||
codeflash.capture('func', '2', func, 2);"""
|
||||
assert result == expected
|
||||
|
|
@ -399,25 +414,25 @@ class TestPreservingAssertions:
|
|||
def test_preserve_toBe(self) -> None:
|
||||
"""Test preserving .toBe() assertion."""
|
||||
code = "expect(func(5)).toBe(5);"
|
||||
result, _ = transform_expect_calls(code, "func", "func", "capture", remove_assertions=False)
|
||||
result, _ = transform_expect_calls(code, make_func("func"), "capture", remove_assertions=False)
|
||||
assert result == "expect(codeflash.capture('func', '1', func, 5)).toBe(5);"
|
||||
|
||||
def test_preserve_not_toBe(self) -> None:
|
||||
"""Test preserving .not.toBe() assertion."""
|
||||
code = "expect(func(5)).not.toBe(10);"
|
||||
result, _ = transform_expect_calls(code, "func", "func", "capture", remove_assertions=False)
|
||||
result, _ = transform_expect_calls(code, make_func("func"), "capture", remove_assertions=False)
|
||||
assert result == "expect(codeflash.capture('func', '1', func, 5)).not.toBe(10);"
|
||||
|
||||
def test_preserve_resolves(self) -> None:
|
||||
"""Test preserving .resolves assertion."""
|
||||
code = "expect(asyncFunc(5)).resolves.toBe(10);"
|
||||
result, _ = transform_expect_calls(code, "asyncFunc", "asyncFunc", "capture", remove_assertions=False)
|
||||
result, _ = transform_expect_calls(code, make_func("asyncFunc"), "capture", remove_assertions=False)
|
||||
assert result == "expect(codeflash.capture('asyncFunc', '1', asyncFunc, 5)).resolves.toBe(10);"
|
||||
|
||||
def test_preserve_toBeCloseTo(self) -> None:
|
||||
"""Test preserving .toBeCloseTo() with args."""
|
||||
code = "expect(func(3.14159)).toBeCloseTo(3.14, 2);"
|
||||
result, _ = transform_expect_calls(code, "func", "func", "capture", remove_assertions=False)
|
||||
result, _ = transform_expect_calls(code, make_func("func"), "capture", remove_assertions=False)
|
||||
assert result == "expect(codeflash.capture('func', '1', func, 3.14159)).toBeCloseTo(3.14, 2);"
|
||||
|
||||
|
||||
|
|
@ -427,13 +442,13 @@ class TestCaptureFunction:
|
|||
def test_behavior_mode_uses_capture(self) -> None:
|
||||
"""Test behavior mode uses capture function."""
|
||||
code = "expect(func(5)).toBe(5);"
|
||||
result, _ = transform_expect_calls(code, "func", "func", "capture", remove_assertions=True)
|
||||
result, _ = transform_expect_calls(code, make_func("func"), "capture", remove_assertions=True)
|
||||
assert "codeflash.capture(" in result
|
||||
|
||||
def test_performance_mode_uses_capturePerf(self) -> None:
|
||||
"""Test performance mode uses capturePerf function."""
|
||||
code = "expect(func(5)).toBe(5);"
|
||||
result, _ = transform_expect_calls(code, "func", "func", "capturePerf", remove_assertions=True)
|
||||
result, _ = transform_expect_calls(code, make_func("func"), "capturePerf", remove_assertions=True)
|
||||
assert "codeflash.capturePerf(" in result
|
||||
|
||||
|
||||
|
|
@ -443,13 +458,19 @@ class TestQualifiedNames:
|
|||
def test_simple_qualified_name(self) -> None:
|
||||
"""Test simple qualified name."""
|
||||
code = "expect(func(5)).toBe(5);"
|
||||
result, _ = transform_expect_calls(code, "func", "module.func", "capture", remove_assertions=True)
|
||||
result, _ = transform_expect_calls(code, make_func("func", class_name="module"), "capture", remove_assertions=True)
|
||||
assert result == "codeflash.capture('module.func', '1', func, 5);"
|
||||
|
||||
def test_nested_qualified_name(self) -> None:
|
||||
"""Test nested qualified name."""
|
||||
code = "expect(func(5)).toBe(5);"
|
||||
result, _ = transform_expect_calls(code, "func", "pkg.module.func", "capture", remove_assertions=True)
|
||||
func = FunctionToOptimize(
|
||||
function_name="func",
|
||||
file_path=Path("/test/file.js"),
|
||||
parents=[FunctionParent(name="pkg", type="ClassDef"), FunctionParent(name="module", type="ClassDef")],
|
||||
language="javascript",
|
||||
)
|
||||
result, _ = transform_expect_calls(code, func, "capture", remove_assertions=True)
|
||||
assert result == "codeflash.capture('pkg.module.func', '1', func, 5);"
|
||||
|
||||
|
||||
|
|
@ -459,7 +480,7 @@ class TestEdgeCases:
|
|||
def test_function_name_as_substring(self) -> None:
|
||||
"""Test that function name matching is exact."""
|
||||
code = "expect(myFunc(5)).toBe(5); expect(func(10)).toBe(10);"
|
||||
result, _ = transform_expect_calls(code, "func", "func", "capture", remove_assertions=True)
|
||||
result, _ = transform_expect_calls(code, make_func("func"), "capture", remove_assertions=True)
|
||||
# Should only transform func, not myFunc
|
||||
assert "expect(myFunc(5)).toBe(5)" in result
|
||||
assert "codeflash.capture('func', '1', func, 10)" in result
|
||||
|
|
@ -467,26 +488,26 @@ class TestEdgeCases:
|
|||
def test_empty_args(self) -> None:
|
||||
"""Test function call with no arguments."""
|
||||
code = "expect(func()).toBe(undefined);"
|
||||
result, _ = transform_expect_calls(code, "func", "func", "capture", remove_assertions=True)
|
||||
result, _ = transform_expect_calls(code, make_func("func"), "capture", remove_assertions=True)
|
||||
assert result == "codeflash.capture('func', '1', func);"
|
||||
|
||||
def test_object_method_style(self) -> None:
|
||||
"""Test that method calls on objects are not matched."""
|
||||
code = "expect(obj.func(5)).toBe(5);"
|
||||
result, _ = transform_expect_calls(code, "func", "func", "capture", remove_assertions=True)
|
||||
result, _ = transform_expect_calls(code, make_func("func"), "capture", remove_assertions=True)
|
||||
# Should not transform method calls
|
||||
assert result == "expect(obj.func(5)).toBe(5);"
|
||||
|
||||
def test_non_matching_code_unchanged(self) -> None:
|
||||
"""Test that non-matching code remains unchanged."""
|
||||
code = "const x = func(5); console.log(x);"
|
||||
result, _ = transform_expect_calls(code, "func", "func", "capture", remove_assertions=True)
|
||||
result, _ = transform_expect_calls(code, make_func("func"), "capture", remove_assertions=True)
|
||||
assert result == code
|
||||
|
||||
def test_expect_without_assertion(self) -> None:
|
||||
"""Test expect without assertion is not transformed."""
|
||||
code = "const result = expect(func(5));"
|
||||
result, _ = transform_expect_calls(code, "func", "func", "capture", remove_assertions=True)
|
||||
result, _ = transform_expect_calls(code, make_func("func"), "capture", remove_assertions=True)
|
||||
# Should not transform as there's no assertion
|
||||
assert result == code
|
||||
|
||||
|
|
@ -504,7 +525,7 @@ describe('fibonacci', () => {
|
|||
expect(fibonacci(10)).toBe(55);
|
||||
});
|
||||
});"""
|
||||
result = instrument_generated_js_test(code, "fibonacci", "fibonacci", TestingMode.BEHAVIOR)
|
||||
result = instrument_generated_js_test(code, make_func("fibonacci"), TestingMode.BEHAVIOR)
|
||||
assert "import codeflash from 'codeflash'" in result
|
||||
assert "codeflash.capture('fibonacci'" in result
|
||||
assert ".toBe(" not in result
|
||||
|
|
@ -518,7 +539,7 @@ describe('fibonacci', () => {
|
|||
expect(fibonacci(5)).toBe(5);
|
||||
});
|
||||
});"""
|
||||
result = instrument_generated_js_test(code, "fibonacci", "fibonacci", TestingMode.PERFORMANCE)
|
||||
result = instrument_generated_js_test(code, make_func("fibonacci"), TestingMode.PERFORMANCE)
|
||||
assert "import codeflash from 'codeflash'" in result
|
||||
assert "codeflash.capturePerf('fibonacci'" in result
|
||||
assert ".toBe(" not in result
|
||||
|
|
@ -532,7 +553,7 @@ describe('fibonacci', () => {
|
|||
expect(fibonacci(5)).toBe(5);
|
||||
});
|
||||
});"""
|
||||
result = instrument_generated_js_test(code, "fibonacci", "fibonacci", TestingMode.BEHAVIOR)
|
||||
result = instrument_generated_js_test(code, make_func("fibonacci"), TestingMode.BEHAVIOR)
|
||||
assert "const codeflash = require('codeflash')" in result
|
||||
assert "codeflash.capture('fibonacci'" in result
|
||||
|
||||
|
|
@ -549,7 +570,7 @@ describe('func', () => {
|
|||
expect(func(null)).toBeNull();
|
||||
});
|
||||
});"""
|
||||
result = instrument_generated_js_test(code, "func", "func", TestingMode.BEHAVIOR)
|
||||
result = instrument_generated_js_test(code, make_func("func"), TestingMode.BEHAVIOR)
|
||||
# All assertions should be removed
|
||||
assert ".toBe(" not in result
|
||||
assert ".not." not in result
|
||||
|
|
@ -561,12 +582,12 @@ describe('func', () => {
|
|||
|
||||
def test_empty_code(self) -> None:
|
||||
"""Test with empty code."""
|
||||
result = instrument_generated_js_test("", "func", "func", TestingMode.BEHAVIOR)
|
||||
result = instrument_generated_js_test("", make_func("func"), TestingMode.BEHAVIOR)
|
||||
assert result == ""
|
||||
|
||||
def test_whitespace_only_code(self) -> None:
|
||||
"""Test with whitespace-only code."""
|
||||
result = instrument_generated_js_test(" \n\t ", "func", "func", TestingMode.BEHAVIOR)
|
||||
result = instrument_generated_js_test(" \n\t ", make_func("func"), TestingMode.BEHAVIOR)
|
||||
assert result == " \n\t "
|
||||
|
||||
|
||||
|
|
@ -594,7 +615,7 @@ describe('processData', () => {
|
|||
});
|
||||
});
|
||||
});"""
|
||||
result = instrument_generated_js_test(code, "processData", "processData", TestingMode.BEHAVIOR)
|
||||
result = instrument_generated_js_test(code, make_func("processData"), TestingMode.BEHAVIOR)
|
||||
assert result.count("codeflash.capture(") == 3
|
||||
assert "toEqual(" not in result
|
||||
assert "toBeNull(" not in result
|
||||
|
|
@ -612,7 +633,7 @@ describe('calculate', () => {
|
|||
expect(calculate(2, 3, 'mul')).toBe(6);
|
||||
});
|
||||
});"""
|
||||
result = instrument_generated_js_test(code, "calculate", "calculate", TestingMode.BEHAVIOR)
|
||||
result = instrument_generated_js_test(code, make_func("calculate"), TestingMode.BEHAVIOR)
|
||||
assert result.count("codeflash.capture(") == 2
|
||||
assert ".toBe(" not in result
|
||||
|
||||
|
|
@ -629,7 +650,7 @@ describe('fetchData', () => {
|
|||
expect(fetchData('/invalid')).rejects.toThrow('Not found');
|
||||
});
|
||||
});"""
|
||||
result = instrument_generated_js_test(code, "fetchData", "fetchData", TestingMode.BEHAVIOR)
|
||||
result = instrument_generated_js_test(code, make_func("fetchData"), TestingMode.BEHAVIOR)
|
||||
assert result.count("codeflash.capture(") == 2
|
||||
assert ".resolves." not in result
|
||||
assert ".rejects." not in result
|
||||
|
|
@ -647,6 +668,6 @@ describe('calculatePi', () => {
|
|||
expect(calculatePi(5)).toBeCloseTo(3.14159, 5);
|
||||
});
|
||||
});"""
|
||||
result = instrument_generated_js_test(code, "calculatePi", "calculatePi", TestingMode.BEHAVIOR)
|
||||
result = instrument_generated_js_test(code, make_func("calculatePi"), TestingMode.BEHAVIOR)
|
||||
assert result.count("codeflash.capture(") == 2
|
||||
assert ".toBeCloseTo(" not in result
|
||||
|
|
|
|||
|
|
@ -346,11 +346,11 @@ function capitalize(str) {
|
|||
"""Test getting a specific function by name."""
|
||||
js_file = tmp_path / "math_utils.js"
|
||||
js_file.write_text("""
|
||||
function add(a, b) {
|
||||
export function add(a, b) {
|
||||
return a + b;
|
||||
}
|
||||
|
||||
function subtract(a, b) {
|
||||
export function subtract(a, b) {
|
||||
return a - b;
|
||||
}
|
||||
""")
|
||||
|
|
@ -378,7 +378,7 @@ function subtract(a, b) {
|
|||
"""Test getting a specific class method."""
|
||||
js_file = tmp_path / "calculator.js"
|
||||
js_file.write_text("""
|
||||
class Calculator {
|
||||
export class Calculator {
|
||||
add(a, b) {
|
||||
return a + b;
|
||||
}
|
||||
|
|
@ -388,7 +388,7 @@ class Calculator {
|
|||
}
|
||||
}
|
||||
|
||||
function standaloneFunc() {
|
||||
export function standaloneFunc() {
|
||||
return 42;
|
||||
}
|
||||
""")
|
||||
|
|
|
|||
|
|
@ -90,138 +90,138 @@ class TestParentInfo:
|
|||
|
||||
|
||||
class TestFunctionInfo:
|
||||
"""Tests for the FunctionInfo dataclass."""
|
||||
"""Tests for the FunctionInfo dataclass (alias for FunctionToOptimize)."""
|
||||
|
||||
def test_function_info_creation_minimal(self):
|
||||
"""Test creating FunctionInfo with minimal args."""
|
||||
func = FunctionInfo(name="add", file_path=Path("/test/example.py"), start_line=1, end_line=3)
|
||||
assert func.name == "add"
|
||||
func = FunctionInfo(function_name="add", file_path=Path("/test/example.py"), starting_line=1, ending_line=3)
|
||||
assert func.function_name == "add"
|
||||
assert func.file_path == Path("/test/example.py")
|
||||
assert func.start_line == 1
|
||||
assert func.end_line == 3
|
||||
assert func.parents == ()
|
||||
assert func.starting_line == 1
|
||||
assert func.ending_line == 3
|
||||
assert func.parents == []
|
||||
assert func.is_async is False
|
||||
assert func.is_method is False
|
||||
assert func.language == Language.PYTHON
|
||||
assert func.language == "python"
|
||||
|
||||
def test_function_info_creation_full(self):
|
||||
"""Test creating FunctionInfo with all args."""
|
||||
parents = (ParentInfo(name="Calculator", type="ClassDef"),)
|
||||
parents = [ParentInfo(name="Calculator", type="ClassDef")]
|
||||
func = FunctionInfo(
|
||||
name="add",
|
||||
function_name="add",
|
||||
file_path=Path("/test/example.py"),
|
||||
start_line=10,
|
||||
end_line=15,
|
||||
starting_line=10,
|
||||
ending_line=15,
|
||||
parents=parents,
|
||||
is_async=True,
|
||||
is_method=True,
|
||||
language=Language.PYTHON,
|
||||
start_col=4,
|
||||
end_col=20,
|
||||
language="python",
|
||||
starting_col=4,
|
||||
ending_col=20,
|
||||
)
|
||||
assert func.name == "add"
|
||||
assert func.function_name == "add"
|
||||
assert func.parents == parents
|
||||
assert func.is_async is True
|
||||
assert func.is_method is True
|
||||
assert func.start_col == 4
|
||||
assert func.end_col == 20
|
||||
assert func.starting_col == 4
|
||||
assert func.ending_col == 20
|
||||
|
||||
def test_function_info_frozen(self):
|
||||
"""Test that FunctionInfo is immutable."""
|
||||
func = FunctionInfo(name="add", file_path=Path("/test/example.py"), start_line=1, end_line=3)
|
||||
func = FunctionInfo(function_name="add", file_path=Path("/test/example.py"), starting_line=1, ending_line=3)
|
||||
with pytest.raises(AttributeError):
|
||||
func.name = "new_name"
|
||||
func.function_name = "new_name"
|
||||
|
||||
def test_qualified_name_no_parents(self):
|
||||
"""Test qualified_name without parents."""
|
||||
func = FunctionInfo(name="add", file_path=Path("/test/example.py"), start_line=1, end_line=3)
|
||||
func = FunctionInfo(function_name="add", file_path=Path("/test/example.py"), starting_line=1, ending_line=3)
|
||||
assert func.qualified_name == "add"
|
||||
|
||||
def test_qualified_name_with_class(self):
|
||||
"""Test qualified_name with class parent."""
|
||||
func = FunctionInfo(
|
||||
name="add",
|
||||
function_name="add",
|
||||
file_path=Path("/test/example.py"),
|
||||
start_line=1,
|
||||
end_line=3,
|
||||
parents=(ParentInfo(name="Calculator", type="ClassDef"),),
|
||||
starting_line=1,
|
||||
ending_line=3,
|
||||
parents=[ParentInfo(name="Calculator", type="ClassDef")],
|
||||
)
|
||||
assert func.qualified_name == "Calculator.add"
|
||||
|
||||
def test_qualified_name_nested(self):
|
||||
"""Test qualified_name with nested parents."""
|
||||
func = FunctionInfo(
|
||||
name="inner",
|
||||
function_name="inner",
|
||||
file_path=Path("/test/example.py"),
|
||||
start_line=1,
|
||||
end_line=3,
|
||||
parents=(ParentInfo(name="Outer", type="ClassDef"), ParentInfo(name="Inner", type="ClassDef")),
|
||||
starting_line=1,
|
||||
ending_line=3,
|
||||
parents=[ParentInfo(name="Outer", type="ClassDef"), ParentInfo(name="Inner", type="ClassDef")],
|
||||
)
|
||||
assert func.qualified_name == "Outer.Inner.inner"
|
||||
|
||||
def test_class_name_with_class(self):
|
||||
"""Test class_name property with class parent."""
|
||||
func = FunctionInfo(
|
||||
name="add",
|
||||
function_name="add",
|
||||
file_path=Path("/test/example.py"),
|
||||
start_line=1,
|
||||
end_line=3,
|
||||
parents=(ParentInfo(name="Calculator", type="ClassDef"),),
|
||||
starting_line=1,
|
||||
ending_line=3,
|
||||
parents=[ParentInfo(name="Calculator", type="ClassDef")],
|
||||
)
|
||||
assert func.class_name == "Calculator"
|
||||
|
||||
def test_class_name_without_class(self):
|
||||
"""Test class_name property without class parent."""
|
||||
func = FunctionInfo(name="add", file_path=Path("/test/example.py"), start_line=1, end_line=3)
|
||||
func = FunctionInfo(function_name="add", file_path=Path("/test/example.py"), starting_line=1, ending_line=3)
|
||||
assert func.class_name is None
|
||||
|
||||
def test_class_name_nested_function(self):
|
||||
"""Test class_name for function nested in another function."""
|
||||
func = FunctionInfo(
|
||||
name="inner",
|
||||
function_name="inner",
|
||||
file_path=Path("/test/example.py"),
|
||||
start_line=1,
|
||||
end_line=3,
|
||||
parents=(ParentInfo(name="outer", type="FunctionDef"),),
|
||||
starting_line=1,
|
||||
ending_line=3,
|
||||
parents=[ParentInfo(name="outer", type="FunctionDef")],
|
||||
)
|
||||
assert func.class_name is None
|
||||
|
||||
def test_class_name_method_in_nested_class(self):
|
||||
"""Test class_name for method in nested class."""
|
||||
func = FunctionInfo(
|
||||
name="method",
|
||||
function_name="method",
|
||||
file_path=Path("/test/example.py"),
|
||||
start_line=1,
|
||||
end_line=3,
|
||||
parents=(ParentInfo(name="Outer", type="ClassDef"), ParentInfo(name="Inner", type="ClassDef")),
|
||||
starting_line=1,
|
||||
ending_line=3,
|
||||
parents=[ParentInfo(name="Outer", type="ClassDef"), ParentInfo(name="Inner", type="ClassDef")],
|
||||
)
|
||||
# Should return the immediate parent class
|
||||
assert func.class_name == "Inner"
|
||||
|
||||
def test_top_level_parent_name_no_parents(self):
|
||||
"""Test top_level_parent_name without parents."""
|
||||
func = FunctionInfo(name="add", file_path=Path("/test/example.py"), start_line=1, end_line=3)
|
||||
func = FunctionInfo(function_name="add", file_path=Path("/test/example.py"), starting_line=1, ending_line=3)
|
||||
assert func.top_level_parent_name == "add"
|
||||
|
||||
def test_top_level_parent_name_with_parents(self):
|
||||
"""Test top_level_parent_name with parents."""
|
||||
func = FunctionInfo(
|
||||
name="method",
|
||||
function_name="method",
|
||||
file_path=Path("/test/example.py"),
|
||||
start_line=1,
|
||||
end_line=3,
|
||||
parents=(ParentInfo(name="Outer", type="ClassDef"), ParentInfo(name="Inner", type="ClassDef")),
|
||||
starting_line=1,
|
||||
ending_line=3,
|
||||
parents=[ParentInfo(name="Outer", type="ClassDef"), ParentInfo(name="Inner", type="ClassDef")],
|
||||
)
|
||||
assert func.top_level_parent_name == "Outer"
|
||||
|
||||
def test_function_info_str(self):
|
||||
"""Test string representation."""
|
||||
func = FunctionInfo(
|
||||
name="add",
|
||||
function_name="add",
|
||||
file_path=Path("/test/example.py"),
|
||||
start_line=1,
|
||||
end_line=3,
|
||||
parents=(ParentInfo(name="Calculator", type="ClassDef"),),
|
||||
starting_line=1,
|
||||
ending_line=3,
|
||||
parents=[ParentInfo(name="Calculator", type="ClassDef")],
|
||||
)
|
||||
s = str(func)
|
||||
assert "Calculator.add" in s
|
||||
|
|
|
|||
|
|
@ -90,7 +90,7 @@ const multiply = (a, b) => a * b;
|
|||
functions = js_support.discover_functions(file_path)
|
||||
assert len(functions) == 1
|
||||
func = functions[0]
|
||||
assert func.name == "multiply"
|
||||
assert func.function_name == "multiply"
|
||||
|
||||
context = js_support.extract_code_context(func, temp_project, temp_project)
|
||||
|
||||
|
|
@ -268,7 +268,7 @@ class CacheManager {
|
|||
file_path.write_text(code, encoding="utf-8")
|
||||
|
||||
functions = js_support.discover_functions(file_path)
|
||||
get_or_compute = next(f for f in functions if f.name == "getOrCompute")
|
||||
get_or_compute = next(f for f in functions if f.function_name == "getOrCompute")
|
||||
|
||||
context = js_support.extract_code_context(get_or_compute, temp_project, temp_project)
|
||||
|
||||
|
|
@ -370,7 +370,7 @@ function validateUserData(data, validators) {
|
|||
file_path.write_text(code, encoding="utf-8")
|
||||
|
||||
functions = js_support.discover_functions(file_path)
|
||||
func = next(f for f in functions if f.name == "validateUserData")
|
||||
func = next(f for f in functions if f.function_name == "validateUserData")
|
||||
|
||||
context = js_support.extract_code_context(func, temp_project, temp_project)
|
||||
|
||||
|
|
@ -466,7 +466,7 @@ async function fetchWithRetry(endpoint, options = {}) {
|
|||
file_path.write_text(code, encoding="utf-8")
|
||||
|
||||
functions = js_support.discover_functions(file_path)
|
||||
func = next(f for f in functions if f.name == "fetchWithRetry")
|
||||
func = next(f for f in functions if f.function_name == "fetchWithRetry")
|
||||
|
||||
context = js_support.extract_code_context(func, temp_project, temp_project)
|
||||
|
||||
|
|
@ -615,7 +615,7 @@ function processUserInput(rawInput) {
|
|||
file_path.write_text(code, encoding="utf-8")
|
||||
|
||||
functions = js_support.discover_functions(file_path)
|
||||
process_func = next(f for f in functions if f.name == "processUserInput")
|
||||
process_func = next(f for f in functions if f.function_name == "processUserInput")
|
||||
|
||||
context = js_support.extract_code_context(process_func, temp_project, temp_project)
|
||||
|
||||
|
|
@ -670,7 +670,7 @@ function generateReport(data) {
|
|||
file_path.write_text(code, encoding="utf-8")
|
||||
|
||||
functions = js_support.discover_functions(file_path)
|
||||
report_func = next(f for f in functions if f.name == "generateReport")
|
||||
report_func = next(f for f in functions if f.function_name == "generateReport")
|
||||
|
||||
context = js_support.extract_code_context(report_func, temp_project, temp_project)
|
||||
|
||||
|
|
@ -768,7 +768,7 @@ class Graph {
|
|||
file_path.write_text(code, encoding="utf-8")
|
||||
|
||||
functions = js_support.discover_functions(file_path)
|
||||
topo_sort = next(f for f in functions if f.name == "topologicalSort")
|
||||
topo_sort = next(f for f in functions if f.function_name == "topologicalSort")
|
||||
|
||||
context = js_support.extract_code_context(topo_sort, temp_project, temp_project)
|
||||
|
||||
|
|
@ -843,7 +843,7 @@ class MainClass {
|
|||
file_path.write_text(code, encoding="utf-8")
|
||||
|
||||
functions = js_support.discover_functions(file_path)
|
||||
main_method = next(f for f in functions if f.name == "mainMethod" and f.class_name == "MainClass")
|
||||
main_method = next(f for f in functions if f.function_name == "mainMethod" and f.class_name == "MainClass")
|
||||
|
||||
context = js_support.extract_code_context(main_method, temp_project, temp_project)
|
||||
|
||||
|
|
@ -899,7 +899,7 @@ module.exports = { sortFromAnotherFile };
|
|||
main_path.write_text(main_code, encoding="utf-8")
|
||||
|
||||
functions = js_support.discover_functions(main_path)
|
||||
main_func = next(f for f in functions if f.name == "sortFromAnotherFile")
|
||||
main_func = next(f for f in functions if f.function_name == "sortFromAnotherFile")
|
||||
|
||||
context = js_support.extract_code_context(main_func, temp_project, temp_project)
|
||||
|
||||
|
|
@ -952,7 +952,7 @@ export { processNumber };
|
|||
main_path.write_text(main_code, encoding="utf-8")
|
||||
|
||||
functions = js_support.discover_functions(main_path)
|
||||
process_func = next(f for f in functions if f.name == "processNumber")
|
||||
process_func = next(f for f in functions if f.function_name == "processNumber")
|
||||
|
||||
context = js_support.extract_code_context(process_func, temp_project, temp_project)
|
||||
|
||||
|
|
@ -1020,7 +1020,7 @@ export { handleUserInput };
|
|||
main_path.write_text(main_code, encoding="utf-8")
|
||||
|
||||
functions = js_support.discover_functions(main_path)
|
||||
handle_func = next(f for f in functions if f.name == "handleUserInput")
|
||||
handle_func = next(f for f in functions if f.function_name == "handleUserInput")
|
||||
|
||||
context = js_support.extract_code_context(handle_func, temp_project, temp_project)
|
||||
|
||||
|
|
@ -1161,7 +1161,7 @@ class TypedCache<T> {
|
|||
file_path.write_text(code, encoding="utf-8")
|
||||
|
||||
functions = ts_support.discover_functions(file_path)
|
||||
get_method = next(f for f in functions if f.name == "get")
|
||||
get_method = next(f for f in functions if f.function_name == "get")
|
||||
|
||||
context = ts_support.extract_code_context(get_method, temp_project, temp_project)
|
||||
|
||||
|
|
@ -1247,7 +1247,7 @@ export { createUser };
|
|||
service_path.write_text(service_code, encoding="utf-8")
|
||||
|
||||
functions = ts_support.discover_functions(service_path)
|
||||
func = next(f for f in functions if f.name == "createUser")
|
||||
func = next(f for f in functions if f.function_name == "createUser")
|
||||
|
||||
context = ts_support.extract_code_context(func, temp_project, temp_project)
|
||||
|
||||
|
|
@ -1331,7 +1331,7 @@ function isOdd(n) {
|
|||
file_path.write_text(code, encoding="utf-8")
|
||||
|
||||
functions = js_support.discover_functions(file_path)
|
||||
is_even = next(f for f in functions if f.name == "isEven")
|
||||
is_even = next(f for f in functions if f.function_name == "isEven")
|
||||
|
||||
context = js_support.extract_code_context(is_even, temp_project, temp_project)
|
||||
|
||||
|
|
@ -1393,7 +1393,7 @@ function collectAllValues(root) {
|
|||
file_path.write_text(code, encoding="utf-8")
|
||||
|
||||
functions = js_support.discover_functions(file_path)
|
||||
collect_func = next(f for f in functions if f.name == "collectAllValues")
|
||||
collect_func = next(f for f in functions if f.function_name == "collectAllValues")
|
||||
|
||||
context = js_support.extract_code_context(collect_func, temp_project, temp_project)
|
||||
|
||||
|
|
@ -1458,7 +1458,7 @@ async function fetchUserProfile(userId) {
|
|||
file_path.write_text(code, encoding="utf-8")
|
||||
|
||||
functions = js_support.discover_functions(file_path)
|
||||
profile_func = next(f for f in functions if f.name == "fetchUserProfile")
|
||||
profile_func = next(f for f in functions if f.function_name == "fetchUserProfile")
|
||||
|
||||
context = js_support.extract_code_context(profile_func, temp_project, temp_project)
|
||||
|
||||
|
|
@ -1513,7 +1513,7 @@ module.exports = { Counter };
|
|||
file_path.write_text(original_source, encoding="utf-8")
|
||||
|
||||
functions = js_support.discover_functions(file_path)
|
||||
increment_func = next(fn for fn in functions if fn.name == "increment")
|
||||
increment_func = next(fn for fn in functions if fn.function_name == "increment")
|
||||
|
||||
# Step 1: Extract code context
|
||||
context = js_support.extract_code_context(increment_func, temp_project, temp_project)
|
||||
|
|
@ -1635,7 +1635,7 @@ function* fibonacci(limit) {
|
|||
file_path.write_text(code, encoding="utf-8")
|
||||
|
||||
functions = js_support.discover_functions(file_path)
|
||||
range_func = next(f for f in functions if f.name == "range")
|
||||
range_func = next(f for f in functions if f.function_name == "range")
|
||||
|
||||
context = js_support.extract_code_context(range_func, temp_project, temp_project)
|
||||
|
||||
|
|
@ -1772,7 +1772,7 @@ class Calculator {
|
|||
functions = js_support.discover_functions(file_path)
|
||||
|
||||
for func in functions:
|
||||
if func.name != "constructor":
|
||||
if func.function_name != "constructor":
|
||||
context = js_support.extract_code_context(func, temp_project, temp_project)
|
||||
is_valid = js_support.validate_syntax(context.target_code)
|
||||
assert is_valid is True, f"Invalid syntax for {func.name}:\n{context.target_code}"
|
||||
|
|
|
|||
1103
tests/test_languages/test_find_references.py
Normal file
1103
tests/test_languages/test_find_references.py
Normal file
File diff suppressed because it is too large
Load diff
|
|
@ -474,6 +474,13 @@ class TestCommonJSExports:
|
|||
|
||||
return TreeSitterAnalyzer(TreeSitterLanguage.JAVASCRIPT)
|
||||
|
||||
@pytest.fixture
|
||||
def ts_analyzer(self):
|
||||
"""Create a TypeScript analyzer."""
|
||||
from codeflash.languages.treesitter_utils import TreeSitterAnalyzer, TreeSitterLanguage
|
||||
|
||||
return TreeSitterAnalyzer(TreeSitterLanguage.TYPESCRIPT)
|
||||
|
||||
def test_module_exports_function(self, js_analyzer):
|
||||
"""Test module.exports = function() {}."""
|
||||
code = "module.exports = function helper() { return 1; };"
|
||||
|
|
@ -584,6 +591,51 @@ exports.helper = helper;
|
|||
assert is_exported is True
|
||||
assert export_name == "helper"
|
||||
|
||||
def test_is_class_method_exported_via_class(self, ts_analyzer):
|
||||
"""Test is_function_exported returns True for method of exported class."""
|
||||
code = """
|
||||
export class BloomFilter {
|
||||
getHashValues(key: string): number[] {
|
||||
return [1, 2, 3];
|
||||
}
|
||||
}
|
||||
"""
|
||||
# Method itself is not directly exported
|
||||
is_exported, export_name = ts_analyzer.is_function_exported(code, "getHashValues")
|
||||
assert is_exported is False
|
||||
assert export_name is None
|
||||
|
||||
# But when we pass the class name, it should find the class export
|
||||
is_exported, export_name = ts_analyzer.is_function_exported(code, "getHashValues", "BloomFilter")
|
||||
assert is_exported is True
|
||||
assert export_name == "BloomFilter"
|
||||
|
||||
def test_is_class_method_exported_default_class(self, ts_analyzer):
|
||||
"""Test is_function_exported returns True for method of default exported class."""
|
||||
code = """
|
||||
export default class Calculator {
|
||||
add(a: number, b: number): number {
|
||||
return a + b;
|
||||
}
|
||||
}
|
||||
"""
|
||||
# When we pass the class name, it should find the default export
|
||||
is_exported, export_name = ts_analyzer.is_function_exported(code, "add", "Calculator")
|
||||
assert is_exported is True
|
||||
assert export_name == "Calculator"
|
||||
|
||||
def test_is_class_method_not_exported_non_exported_class(self, ts_analyzer):
|
||||
"""Test is_function_exported returns False for method of non-exported class."""
|
||||
code = """
|
||||
class InternalClass {
|
||||
helper(): void {}
|
||||
}
|
||||
"""
|
||||
# Even with class name, non-exported class method should not be exported
|
||||
is_exported, export_name = ts_analyzer.is_function_exported(code, "helper", "InternalClass")
|
||||
assert is_exported is False
|
||||
assert export_name is None
|
||||
|
||||
|
||||
class TestCommonJSImportResolver:
|
||||
"""Tests for ImportResolver with CommonJS require() imports."""
|
||||
|
|
|
|||
|
|
@ -312,7 +312,7 @@ public class Calculator {
|
|||
functions = discover_functions_from_source(
|
||||
java_file.read_text(), file_path=java_file
|
||||
)
|
||||
add_func = next((f for f in functions if f.name == "add"), None)
|
||||
add_func = next((f for f in functions if f.function_name == "add"), None)
|
||||
assert add_func is not None
|
||||
|
||||
context = extract_code_context(add_func, tmp_path)
|
||||
|
|
@ -678,7 +678,7 @@ class TestExtractCodeContextWithHelpers:
|
|||
functions = discover_functions_from_source(
|
||||
java_file.read_text(), file_path=java_file
|
||||
)
|
||||
process_func = next((f for f in functions if f.name == "process"), None)
|
||||
process_func = next((f for f in functions if f.function_name == "process"), None)
|
||||
assert process_func is not None
|
||||
|
||||
context = extract_code_context(process_func, tmp_path)
|
||||
|
|
@ -719,7 +719,7 @@ class TestExtractCodeContextWithHelpers:
|
|||
functions = discover_functions_from_source(
|
||||
java_file.read_text(), file_path=java_file
|
||||
)
|
||||
process_func = next((f for f in functions if f.name == "process"), None)
|
||||
process_func = next((f for f in functions if f.function_name == "process"), None)
|
||||
assert process_func is not None
|
||||
|
||||
context = extract_code_context(process_func, tmp_path)
|
||||
|
|
@ -756,7 +756,7 @@ class TestExtractCodeContextWithHelpers:
|
|||
functions = discover_functions_from_source(
|
||||
java_file.read_text(), file_path=java_file
|
||||
)
|
||||
process_func = next((f for f in functions if f.name == "process"), None)
|
||||
process_func = next((f for f in functions if f.function_name == "process"), None)
|
||||
assert process_func is not None
|
||||
|
||||
context = extract_code_context(process_func, tmp_path)
|
||||
|
|
@ -780,7 +780,7 @@ class TestExtractCodeContextWithHelpers:
|
|||
functions = discover_functions_from_source(
|
||||
java_file.read_text(), file_path=java_file
|
||||
)
|
||||
add_func = next((f for f in functions if f.name == "add"), None)
|
||||
add_func = next((f for f in functions if f.function_name == "add"), None)
|
||||
assert add_func is not None
|
||||
|
||||
context = extract_code_context(add_func, tmp_path)
|
||||
|
|
@ -809,7 +809,7 @@ class TestExtractCodeContextWithHelpers:
|
|||
functions = discover_functions_from_source(
|
||||
java_file.read_text(), file_path=java_file
|
||||
)
|
||||
calc_func = next((f for f in functions if f.name == "calculate"), None)
|
||||
calc_func = next((f for f in functions if f.function_name == "calculate"), None)
|
||||
assert calc_func is not None
|
||||
|
||||
context = extract_code_context(calc_func, tmp_path)
|
||||
|
|
@ -1392,7 +1392,7 @@ class TestExtractCodeContextWithInheritance:
|
|||
)
|
||||
assert len(functions) == 2
|
||||
|
||||
run_func = next((f for f in functions if f.name == "run"), None)
|
||||
run_func = next((f for f in functions if f.function_name == "run"), None)
|
||||
assert run_func is not None
|
||||
|
||||
context = extract_code_context(run_func, tmp_path)
|
||||
|
|
@ -1417,7 +1417,7 @@ class TestExtractCodeContextWithInheritance:
|
|||
functions = discover_functions_from_source(
|
||||
java_file.read_text(), file_path=java_file
|
||||
)
|
||||
greet_func = next((f for f in functions if f.name == "greet"), None)
|
||||
greet_func = next((f for f in functions if f.function_name == "greet"), None)
|
||||
assert greet_func is not None
|
||||
|
||||
context = extract_code_context(greet_func, tmp_path)
|
||||
|
|
@ -1449,7 +1449,7 @@ class TestExtractCodeContextWithInnerClasses:
|
|||
functions = discover_functions_from_source(
|
||||
java_file.read_text(), file_path=java_file
|
||||
)
|
||||
compute_func = next((f for f in functions if f.name == "compute"), None)
|
||||
compute_func = next((f for f in functions if f.function_name == "compute"), None)
|
||||
assert compute_func is not None
|
||||
|
||||
context = extract_code_context(compute_func, tmp_path)
|
||||
|
|
@ -1481,7 +1481,7 @@ class TestExtractCodeContextWithInnerClasses:
|
|||
functions = discover_functions_from_source(
|
||||
java_file.read_text(), file_path=java_file
|
||||
)
|
||||
get_func = next((f for f in functions if f.name == "getValue"), None)
|
||||
get_func = next((f for f in functions if f.function_name == "getValue"), None)
|
||||
assert get_func is not None
|
||||
|
||||
context = extract_code_context(get_func, tmp_path)
|
||||
|
|
@ -1521,7 +1521,7 @@ class TestExtractCodeContextWithEnumAndInterface:
|
|||
functions = discover_functions_from_source(
|
||||
java_file.read_text(), file_path=java_file
|
||||
)
|
||||
apply_func = next((f for f in functions if f.name == "apply"), None)
|
||||
apply_func = next((f for f in functions if f.function_name == "apply"), None)
|
||||
assert apply_func is not None
|
||||
|
||||
context = extract_code_context(apply_func, tmp_path)
|
||||
|
|
@ -1555,7 +1555,7 @@ class TestExtractCodeContextWithEnumAndInterface:
|
|||
functions = discover_functions_from_source(
|
||||
java_file.read_text(), file_path=java_file
|
||||
)
|
||||
greet_func = next((f for f in functions if f.name == "greet"), None)
|
||||
greet_func = next((f for f in functions if f.function_name == "greet"), None)
|
||||
assert greet_func is not None
|
||||
|
||||
context = extract_code_context(greet_func, tmp_path)
|
||||
|
|
@ -1581,7 +1581,7 @@ class TestExtractCodeContextWithEnumAndInterface:
|
|||
functions = discover_functions_from_source(
|
||||
java_file.read_text(), file_path=java_file
|
||||
)
|
||||
create_func = next((f for f in functions if f.name == "create"), None)
|
||||
create_func = next((f for f in functions if f.function_name == "create"), None)
|
||||
assert create_func is not None
|
||||
|
||||
context = extract_code_context(create_func, tmp_path)
|
||||
|
|
@ -1767,16 +1767,17 @@ class TestExtractCodeContextEdgeCases:
|
|||
|
||||
def test_file_not_found(self, tmp_path: Path):
|
||||
"""Test context extraction for missing file."""
|
||||
from codeflash.languages.base import FunctionInfo
|
||||
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
|
||||
from codeflash.models.function_types import FunctionParent
|
||||
|
||||
missing_file = tmp_path / "NonExistent.java"
|
||||
func = FunctionInfo(
|
||||
name="test",
|
||||
func = FunctionToOptimize(
|
||||
function_name="test",
|
||||
file_path=missing_file,
|
||||
start_line=1,
|
||||
end_line=5,
|
||||
parents=(ParentInfo(name="Test", type="ClassDef"),),
|
||||
language=Language.JAVA,
|
||||
starting_line=1,
|
||||
ending_line=5,
|
||||
parents=[FunctionParent(name="Test", type="ClassDef")],
|
||||
language="java",
|
||||
)
|
||||
|
||||
context = extract_code_context(func, tmp_path)
|
||||
|
|
@ -1801,7 +1802,7 @@ class TestExtractCodeContextEdgeCases:
|
|||
functions = discover_functions_from_source(
|
||||
java_file.read_text(), file_path=java_file
|
||||
)
|
||||
calc_func = next((f for f in functions if f.name == "calculate"), None)
|
||||
calc_func = next((f for f in functions if f.function_name == "calculate"), None)
|
||||
assert calc_func is not None
|
||||
|
||||
context = extract_code_context(calc_func, tmp_path, max_helper_depth=0)
|
||||
|
|
@ -1838,7 +1839,7 @@ class TestExtractCodeContextWithConstructor:
|
|||
functions = discover_functions_from_source(
|
||||
java_file.read_text(), file_path=java_file
|
||||
)
|
||||
get_func = next((f for f in functions if f.name == "getName"), None)
|
||||
get_func = next((f for f in functions if f.function_name == "getName"), None)
|
||||
assert get_func is not None
|
||||
|
||||
context = extract_code_context(get_func, tmp_path)
|
||||
|
|
@ -1885,7 +1886,7 @@ class TestExtractCodeContextWithConstructor:
|
|||
functions = discover_functions_from_source(
|
||||
java_file.read_text(), file_path=java_file
|
||||
)
|
||||
get_func = next((f for f in functions if f.name == "getName"), None)
|
||||
get_func = next((f for f in functions if f.function_name == "getName"), None)
|
||||
assert get_func is not None
|
||||
|
||||
context = extract_code_context(get_func, tmp_path)
|
||||
|
|
@ -1940,7 +1941,7 @@ public class Service {
|
|||
functions = discover_functions_from_source(
|
||||
java_file.read_text(), file_path=java_file
|
||||
)
|
||||
process_func = next((f for f in functions if f.name == "process"), None)
|
||||
process_func = next((f for f in functions if f.function_name == "process"), None)
|
||||
assert process_func is not None
|
||||
|
||||
context = extract_code_context(process_func, tmp_path)
|
||||
|
|
@ -2000,7 +2001,7 @@ public class Calculator {
|
|||
functions = discover_functions_from_source(
|
||||
java_file.read_text(), file_path=java_file
|
||||
)
|
||||
sqrt_func = next((f for f in functions if f.name == "sqrtNewton"), None)
|
||||
sqrt_func = next((f for f in functions if f.function_name == "sqrtNewton"), None)
|
||||
assert sqrt_func is not None
|
||||
|
||||
context = extract_code_context(sqrt_func, tmp_path)
|
||||
|
|
|
|||
|
|
@ -28,7 +28,7 @@ public class Calculator {
|
|||
"""
|
||||
functions = discover_functions_from_source(source)
|
||||
assert len(functions) == 1
|
||||
assert functions[0].name == "add"
|
||||
assert functions[0].function_name == "add"
|
||||
assert functions[0].language == Language.JAVA
|
||||
assert functions[0].is_method is True
|
||||
assert functions[0].class_name == "Calculator"
|
||||
|
|
@ -52,7 +52,7 @@ public class Calculator {
|
|||
"""
|
||||
functions = discover_functions_from_source(source)
|
||||
assert len(functions) == 3
|
||||
method_names = {f.name for f in functions}
|
||||
method_names = {f.function_name for f in functions}
|
||||
assert method_names == {"add", "subtract", "multiply"}
|
||||
|
||||
def test_skip_abstract_methods(self):
|
||||
|
|
@ -69,7 +69,7 @@ public abstract class Shape {
|
|||
functions = discover_functions_from_source(source)
|
||||
# Should only find perimeter, not area
|
||||
assert len(functions) == 1
|
||||
assert functions[0].name == "perimeter"
|
||||
assert functions[0].function_name == "perimeter"
|
||||
|
||||
def test_skip_constructors(self):
|
||||
"""Test that constructors are skipped."""
|
||||
|
|
@ -89,7 +89,7 @@ public class Person {
|
|||
functions = discover_functions_from_source(source)
|
||||
# Should only find getName, not the constructor
|
||||
assert len(functions) == 1
|
||||
assert functions[0].name == "getName"
|
||||
assert functions[0].function_name == "getName"
|
||||
|
||||
def test_filter_by_pattern(self):
|
||||
"""Test filtering by include patterns."""
|
||||
|
|
@ -111,7 +111,7 @@ public class StringUtils {
|
|||
criteria = FunctionFilterCriteria(include_patterns=["*Upper*", "*Lower*"])
|
||||
functions = discover_functions_from_source(source, filter_criteria=criteria)
|
||||
assert len(functions) == 2
|
||||
method_names = {f.name for f in functions}
|
||||
method_names = {f.function_name for f in functions}
|
||||
assert method_names == {"toUpperCase", "toLowerCase"}
|
||||
|
||||
def test_filter_exclude_pattern(self):
|
||||
|
|
@ -128,7 +128,7 @@ public class DataService {
|
|||
require_return=False, # Allow void methods
|
||||
)
|
||||
functions = discover_functions_from_source(source, filter_criteria=criteria)
|
||||
method_names = {f.name for f in functions}
|
||||
method_names = {f.function_name for f in functions}
|
||||
assert "setData" not in method_names
|
||||
|
||||
def test_filter_require_return(self):
|
||||
|
|
@ -145,7 +145,7 @@ public class Example {
|
|||
criteria = FunctionFilterCriteria(require_return=True)
|
||||
functions = discover_functions_from_source(source, filter_criteria=criteria)
|
||||
assert len(functions) == 1
|
||||
assert functions[0].name == "getValue"
|
||||
assert functions[0].function_name == "getValue"
|
||||
|
||||
def test_filter_by_line_count(self):
|
||||
"""Test filtering by line count."""
|
||||
|
|
@ -167,7 +167,7 @@ public class Example {
|
|||
functions = discover_functions_from_source(source, filter_criteria=criteria)
|
||||
# The 'long' method should be included (>3 lines)
|
||||
# The 'short' method should be excluded (1 line)
|
||||
method_names = {f.name for f in functions}
|
||||
method_names = {f.function_name for f in functions}
|
||||
assert "long" in method_names or len(functions) >= 1
|
||||
|
||||
def test_method_with_javadoc(self):
|
||||
|
|
@ -189,7 +189,7 @@ public class Example {
|
|||
assert len(functions) == 1
|
||||
assert functions[0].doc_start_line is not None
|
||||
# Doc should start before the method
|
||||
assert functions[0].doc_start_line < functions[0].start_line
|
||||
assert functions[0].doc_start_line < functions[0].starting_line
|
||||
|
||||
|
||||
class TestDiscoverTestMethods:
|
||||
|
|
@ -222,7 +222,7 @@ class CalculatorTest {
|
|||
""")
|
||||
tests = discover_test_methods(test_file)
|
||||
assert len(tests) == 2
|
||||
test_names = {t.name for t in tests}
|
||||
test_names = {t.function_name for t in tests}
|
||||
assert test_names == {"testAdd", "testSubtract"}
|
||||
|
||||
def test_discover_parameterized_tests(self, tmp_path: Path):
|
||||
|
|
@ -244,7 +244,7 @@ class StringTest {
|
|||
""")
|
||||
tests = discover_test_methods(test_file)
|
||||
assert len(tests) == 1
|
||||
assert tests[0].name == "testLength"
|
||||
assert tests[0].function_name == "testLength"
|
||||
|
||||
|
||||
class TestGetMethodByName:
|
||||
|
|
@ -266,7 +266,7 @@ public class Calculator {
|
|||
""")
|
||||
method = get_method_by_name(java_file, "add")
|
||||
assert method is not None
|
||||
assert method.name == "add"
|
||||
assert method.function_name == "add"
|
||||
|
||||
def test_get_method_not_found(self, tmp_path: Path):
|
||||
"""Test getting a method that doesn't exist."""
|
||||
|
|
@ -299,7 +299,7 @@ class Helper {
|
|||
""")
|
||||
methods = get_class_methods(java_file, "Calculator")
|
||||
assert len(methods) == 1
|
||||
assert methods[0].name == "add"
|
||||
assert methods[0].function_name == "add"
|
||||
|
||||
|
||||
class TestFileBasedDiscovery:
|
||||
|
|
@ -321,7 +321,7 @@ class TestFileBasedDiscovery:
|
|||
|
||||
functions = discover_functions(calculator_file)
|
||||
assert len(functions) > 0
|
||||
method_names = {f.name for f in functions}
|
||||
method_names = {f.function_name for f in functions}
|
||||
# Should find methods from Calculator.java
|
||||
assert "fibonacci" in method_names or "add" in method_names or len(method_names) > 0
|
||||
|
||||
|
|
|
|||
|
|
@ -18,8 +18,10 @@ from pathlib import Path
|
|||
|
||||
import pytest
|
||||
|
||||
from codeflash.languages.base import FunctionInfo, Language
|
||||
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
|
||||
from codeflash.languages.base import Language
|
||||
from codeflash.languages.current import set_current_language
|
||||
from codeflash.models.function_types import FunctionParent
|
||||
from codeflash.languages.java.build_tools import find_maven_executable
|
||||
from codeflash.languages.java.discovery import discover_functions_from_source
|
||||
from codeflash.languages.java.instrumentation import (
|
||||
|
|
@ -76,14 +78,14 @@ public class CalculatorTest {
|
|||
}
|
||||
}
|
||||
"""
|
||||
func = FunctionInfo(
|
||||
name="add",
|
||||
func = FunctionToOptimize(
|
||||
function_name="add",
|
||||
file_path=Path("Calculator.java"),
|
||||
start_line=1,
|
||||
end_line=5,
|
||||
parents=(),
|
||||
starting_line=1,
|
||||
ending_line=5,
|
||||
parents=[],
|
||||
is_method=True,
|
||||
language=Language.JAVA,
|
||||
language="java",
|
||||
)
|
||||
|
||||
result = instrument_for_benchmarking(source, func)
|
||||
|
|
@ -108,14 +110,14 @@ public class CalculatorTest {
|
|||
"""
|
||||
test_file.write_text(source)
|
||||
|
||||
func = FunctionInfo(
|
||||
name="add",
|
||||
func = FunctionToOptimize(
|
||||
function_name="add",
|
||||
file_path=tmp_path / "Calculator.java",
|
||||
start_line=1,
|
||||
end_line=5,
|
||||
parents=(),
|
||||
starting_line=1,
|
||||
ending_line=5,
|
||||
parents=[],
|
||||
is_method=True,
|
||||
language=Language.JAVA,
|
||||
language="java",
|
||||
)
|
||||
|
||||
success, result = instrument_existing_test(
|
||||
|
|
@ -158,14 +160,14 @@ public class CalculatorTest {
|
|||
"""
|
||||
test_file.write_text(source)
|
||||
|
||||
func = FunctionInfo(
|
||||
name="add",
|
||||
func = FunctionToOptimize(
|
||||
function_name="add",
|
||||
file_path=tmp_path / "Calculator.java",
|
||||
start_line=1,
|
||||
end_line=5,
|
||||
parents=(),
|
||||
starting_line=1,
|
||||
ending_line=5,
|
||||
parents=[],
|
||||
is_method=True,
|
||||
language=Language.JAVA,
|
||||
language="java",
|
||||
)
|
||||
|
||||
success, result = instrument_existing_test(
|
||||
|
|
@ -225,14 +227,14 @@ public class MathTest {
|
|||
"""
|
||||
test_file.write_text(source)
|
||||
|
||||
func = FunctionInfo(
|
||||
name="calculate",
|
||||
func = FunctionToOptimize(
|
||||
function_name="calculate",
|
||||
file_path=tmp_path / "Math.java",
|
||||
start_line=1,
|
||||
end_line=5,
|
||||
parents=(),
|
||||
starting_line=1,
|
||||
ending_line=5,
|
||||
parents=[],
|
||||
is_method=True,
|
||||
language=Language.JAVA,
|
||||
language="java",
|
||||
)
|
||||
|
||||
success, result = instrument_existing_test(
|
||||
|
|
@ -317,14 +319,14 @@ public class ServiceTest {
|
|||
"""
|
||||
test_file.write_text(source)
|
||||
|
||||
func = FunctionInfo(
|
||||
name="call",
|
||||
func = FunctionToOptimize(
|
||||
function_name="call",
|
||||
file_path=tmp_path / "Service.java",
|
||||
start_line=1,
|
||||
end_line=5,
|
||||
parents=(),
|
||||
starting_line=1,
|
||||
ending_line=5,
|
||||
parents=[],
|
||||
is_method=True,
|
||||
language=Language.JAVA,
|
||||
language="java",
|
||||
)
|
||||
|
||||
success, result = instrument_existing_test(
|
||||
|
|
@ -394,14 +396,14 @@ public class ServiceTest__perfonlyinstrumented {
|
|||
"""Test handling missing test file."""
|
||||
test_file = tmp_path / "NonExistent.java"
|
||||
|
||||
func = FunctionInfo(
|
||||
name="add",
|
||||
func = FunctionToOptimize(
|
||||
function_name="add",
|
||||
file_path=tmp_path / "Calculator.java",
|
||||
start_line=1,
|
||||
end_line=5,
|
||||
parents=(),
|
||||
starting_line=1,
|
||||
ending_line=5,
|
||||
parents=[],
|
||||
is_method=True,
|
||||
language=Language.JAVA,
|
||||
language="java",
|
||||
)
|
||||
|
||||
success, result = instrument_existing_test(
|
||||
|
|
@ -562,14 +564,14 @@ class TestCreateBenchmarkTest:
|
|||
|
||||
def test_create_benchmark(self):
|
||||
"""Test creating a benchmark test."""
|
||||
func = FunctionInfo(
|
||||
name="add",
|
||||
func = FunctionToOptimize(
|
||||
function_name="add",
|
||||
file_path=Path("Calculator.java"),
|
||||
start_line=1,
|
||||
end_line=5,
|
||||
parents=(),
|
||||
starting_line=1,
|
||||
ending_line=5,
|
||||
parents=[],
|
||||
is_method=True,
|
||||
language=Language.JAVA,
|
||||
language="java",
|
||||
)
|
||||
|
||||
result = create_benchmark_test(
|
||||
|
|
@ -617,14 +619,14 @@ public class TargetBenchmark {
|
|||
|
||||
def test_create_benchmark_different_iterations(self):
|
||||
"""Test benchmark with different iteration count."""
|
||||
func = FunctionInfo(
|
||||
name="multiply",
|
||||
func = FunctionToOptimize(
|
||||
function_name="multiply",
|
||||
file_path=Path("Math.java"),
|
||||
start_line=1,
|
||||
end_line=3,
|
||||
parents=(),
|
||||
starting_line=1,
|
||||
ending_line=3,
|
||||
parents=[],
|
||||
is_method=True,
|
||||
language=Language.JAVA,
|
||||
language="java",
|
||||
)
|
||||
|
||||
result = create_benchmark_test(
|
||||
|
|
@ -905,14 +907,14 @@ public class BraceTest {
|
|||
"""
|
||||
test_file.write_text(source)
|
||||
|
||||
func = FunctionInfo(
|
||||
name="process",
|
||||
func = FunctionToOptimize(
|
||||
function_name="process",
|
||||
file_path=tmp_path / "Processor.java",
|
||||
start_line=1,
|
||||
end_line=5,
|
||||
parents=(),
|
||||
starting_line=1,
|
||||
ending_line=5,
|
||||
parents=[],
|
||||
is_method=True,
|
||||
language=Language.JAVA,
|
||||
language="java",
|
||||
)
|
||||
|
||||
success, result = instrument_existing_test(
|
||||
|
|
@ -998,14 +1000,14 @@ public class ImportTest {
|
|||
"""
|
||||
test_file.write_text(source)
|
||||
|
||||
func = FunctionInfo(
|
||||
name="size",
|
||||
func = FunctionToOptimize(
|
||||
function_name="size",
|
||||
file_path=tmp_path / "Collection.java",
|
||||
start_line=1,
|
||||
end_line=5,
|
||||
parents=(),
|
||||
starting_line=1,
|
||||
ending_line=5,
|
||||
parents=[],
|
||||
is_method=True,
|
||||
language=Language.JAVA,
|
||||
language="java",
|
||||
)
|
||||
|
||||
success, result = instrument_existing_test(
|
||||
|
|
@ -1068,14 +1070,14 @@ public class EmptyTest {
|
|||
"""
|
||||
test_file.write_text(source)
|
||||
|
||||
func = FunctionInfo(
|
||||
name="empty",
|
||||
func = FunctionToOptimize(
|
||||
function_name="empty",
|
||||
file_path=tmp_path / "Empty.java",
|
||||
start_line=1,
|
||||
end_line=5,
|
||||
parents=(),
|
||||
starting_line=1,
|
||||
ending_line=5,
|
||||
parents=[],
|
||||
is_method=True,
|
||||
language=Language.JAVA,
|
||||
language="java",
|
||||
)
|
||||
|
||||
success, result = instrument_existing_test(
|
||||
|
|
@ -1134,14 +1136,14 @@ public class NestedTest {
|
|||
"""
|
||||
test_file.write_text(source)
|
||||
|
||||
func = FunctionInfo(
|
||||
name="process",
|
||||
func = FunctionToOptimize(
|
||||
function_name="process",
|
||||
file_path=tmp_path / "Processor.java",
|
||||
start_line=1,
|
||||
end_line=5,
|
||||
parents=(),
|
||||
starting_line=1,
|
||||
ending_line=5,
|
||||
parents=[],
|
||||
is_method=True,
|
||||
language=Language.JAVA,
|
||||
language="java",
|
||||
)
|
||||
|
||||
success, result = instrument_existing_test(
|
||||
|
|
@ -1210,14 +1212,14 @@ public class InnerClassTest {
|
|||
"""
|
||||
test_file.write_text(source)
|
||||
|
||||
func = FunctionInfo(
|
||||
name="testMethod",
|
||||
func = FunctionToOptimize(
|
||||
function_name="testMethod",
|
||||
file_path=tmp_path / "Target.java",
|
||||
start_line=1,
|
||||
end_line=5,
|
||||
parents=(),
|
||||
starting_line=1,
|
||||
ending_line=5,
|
||||
parents=[],
|
||||
is_method=True,
|
||||
language=Language.JAVA,
|
||||
language="java",
|
||||
)
|
||||
|
||||
success, result = instrument_existing_test(
|
||||
|
|
@ -1412,14 +1414,14 @@ public class CalculatorTest {
|
|||
test_file = test_dir / "CalculatorTest.java"
|
||||
test_file.write_text(test_source, encoding="utf-8")
|
||||
|
||||
func_info = FunctionInfo(
|
||||
name="add",
|
||||
func_info = FunctionToOptimize(
|
||||
function_name="add",
|
||||
file_path=src_dir / "Calculator.java",
|
||||
start_line=4,
|
||||
end_line=6,
|
||||
parents=(),
|
||||
starting_line=4,
|
||||
ending_line=6,
|
||||
parents=[],
|
||||
is_method=True,
|
||||
language=Language.JAVA,
|
||||
language="java",
|
||||
)
|
||||
|
||||
success, instrumented = instrument_existing_test(
|
||||
|
|
@ -1524,14 +1526,14 @@ public class MathUtilsTest {
|
|||
test_file = test_dir / "MathUtilsTest.java"
|
||||
test_file.write_text(test_source, encoding="utf-8")
|
||||
|
||||
func_info = FunctionInfo(
|
||||
name="multiply",
|
||||
func_info = FunctionToOptimize(
|
||||
function_name="multiply",
|
||||
file_path=src_dir / "MathUtils.java",
|
||||
start_line=4,
|
||||
end_line=6,
|
||||
parents=(),
|
||||
starting_line=4,
|
||||
ending_line=6,
|
||||
parents=[],
|
||||
is_method=True,
|
||||
language=Language.JAVA,
|
||||
language="java",
|
||||
)
|
||||
|
||||
success, instrumented = instrument_existing_test(
|
||||
|
|
@ -1657,14 +1659,14 @@ public class StringUtilsTest {
|
|||
test_file = test_dir / "StringUtilsTest.java"
|
||||
test_file.write_text(test_source, encoding="utf-8")
|
||||
|
||||
func_info = FunctionInfo(
|
||||
name="reverse",
|
||||
func_info = FunctionToOptimize(
|
||||
function_name="reverse",
|
||||
file_path=src_dir / "StringUtils.java",
|
||||
start_line=4,
|
||||
end_line=6,
|
||||
parents=(),
|
||||
starting_line=4,
|
||||
ending_line=6,
|
||||
parents=[],
|
||||
is_method=True,
|
||||
language=Language.JAVA,
|
||||
language="java",
|
||||
)
|
||||
|
||||
success, instrumented = instrument_existing_test(
|
||||
|
|
@ -1759,14 +1761,14 @@ public class BrokenCalcTest {
|
|||
test_file = test_dir / "BrokenCalcTest.java"
|
||||
test_file.write_text(test_source, encoding="utf-8")
|
||||
|
||||
func_info = FunctionInfo(
|
||||
name="add",
|
||||
func_info = FunctionToOptimize(
|
||||
function_name="add",
|
||||
file_path=src_dir / "BrokenCalc.java",
|
||||
start_line=4,
|
||||
end_line=6,
|
||||
parents=(),
|
||||
starting_line=4,
|
||||
ending_line=6,
|
||||
parents=[],
|
||||
is_method=True,
|
||||
language=Language.JAVA,
|
||||
language="java",
|
||||
)
|
||||
|
||||
success, instrumented = instrument_existing_test(
|
||||
|
|
@ -1869,14 +1871,14 @@ public class CounterTest {
|
|||
test_file.write_text(test_source, encoding="utf-8")
|
||||
|
||||
# Instrument for BEHAVIOR mode (this should include SQLite writing)
|
||||
func_info = FunctionInfo(
|
||||
name="increment",
|
||||
func_info = FunctionToOptimize(
|
||||
function_name="increment",
|
||||
file_path=src_dir / "Counter.java",
|
||||
start_line=6,
|
||||
end_line=8,
|
||||
parents=(),
|
||||
starting_line=6,
|
||||
ending_line=8,
|
||||
parents=[],
|
||||
is_method=True,
|
||||
language=Language.JAVA,
|
||||
language="java",
|
||||
)
|
||||
|
||||
success, instrumented = instrument_existing_test(
|
||||
|
|
@ -2033,14 +2035,14 @@ public class FibonacciTest {
|
|||
test_file.write_text(test_source, encoding="utf-8")
|
||||
|
||||
# Instrument for performance mode (adds inner loop)
|
||||
func_info = FunctionInfo(
|
||||
name="fib",
|
||||
func_info = FunctionToOptimize(
|
||||
function_name="fib",
|
||||
file_path=src_dir / "Fibonacci.java",
|
||||
start_line=4,
|
||||
end_line=7,
|
||||
parents=(),
|
||||
starting_line=4,
|
||||
ending_line=7,
|
||||
parents=[],
|
||||
is_method=True,
|
||||
language=Language.JAVA,
|
||||
language="java",
|
||||
)
|
||||
|
||||
success, instrumented = instrument_existing_test(
|
||||
|
|
@ -2154,14 +2156,14 @@ public class MathOpsTest {
|
|||
test_file.write_text(test_source, encoding="utf-8")
|
||||
|
||||
# Instrument for performance mode
|
||||
func_info = FunctionInfo(
|
||||
name="add",
|
||||
func_info = FunctionToOptimize(
|
||||
function_name="add",
|
||||
file_path=src_dir / "MathOps.java",
|
||||
start_line=4,
|
||||
end_line=6,
|
||||
parents=(),
|
||||
starting_line=4,
|
||||
ending_line=6,
|
||||
parents=[],
|
||||
is_method=True,
|
||||
language=Language.JAVA,
|
||||
language="java",
|
||||
)
|
||||
|
||||
success, instrumented = instrument_existing_test(
|
||||
|
|
|
|||
|
|
@ -104,7 +104,7 @@ class TestEndToEndWorkflow:
|
|||
context = extract_code_context(func, java_fixture_path)
|
||||
|
||||
assert context.target_code
|
||||
assert func.name in context.target_code
|
||||
assert func.function_name in context.target_code
|
||||
assert context.language == Language.JAVA
|
||||
|
||||
def test_code_replacement_workflow(self):
|
||||
|
|
@ -198,7 +198,7 @@ public class StringUtilsTest {
|
|||
# 1. Discover functions
|
||||
functions = support.discover_functions(src_file)
|
||||
assert len(functions) == 1
|
||||
assert functions[0].name == "reverse"
|
||||
assert functions[0].function_name == "reverse"
|
||||
|
||||
# 2. Extract code context
|
||||
context = support.extract_code_context(functions[0], tmp_path, tmp_path)
|
||||
|
|
@ -330,14 +330,14 @@ public class Example {
|
|||
criteria = FunctionFilterCriteria(include_patterns=["public*"])
|
||||
functions = discover_functions_from_source(source, filter_criteria=criteria)
|
||||
# Should match publicMethod
|
||||
public_names = {f.name for f in functions}
|
||||
public_names = {f.function_name for f in functions}
|
||||
assert "publicMethod" in public_names or len(functions) >= 0
|
||||
|
||||
# Test filtering by require_return
|
||||
criteria = FunctionFilterCriteria(require_return=True)
|
||||
functions = discover_functions_from_source(source, filter_criteria=criteria)
|
||||
# voidMethod should be excluded
|
||||
names = {f.name for f in functions}
|
||||
names = {f.function_name for f in functions}
|
||||
assert "voidMethod" not in names
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -13,6 +13,7 @@ from codeflash.code_utils.code_replacer import (
|
|||
replace_function_definitions_for_language,
|
||||
replace_function_definitions_in_module,
|
||||
)
|
||||
from codeflash.models.function_types import FunctionParent
|
||||
from codeflash.languages.base import Language
|
||||
from codeflash.languages import current as language_current
|
||||
from codeflash.models.models import CodeStringsMarkdown
|
||||
|
|
@ -1524,7 +1525,7 @@ public final class Buffer {{
|
|||
file_path=java_file,
|
||||
starting_line=13, # Line where 3-arg version starts (1-indexed)
|
||||
ending_line=18,
|
||||
parents=(FunctionParent(name="Buffer", type="class"),),
|
||||
parents=[FunctionParent(name="Buffer", type="ClassDef")],
|
||||
qualified_name="Buffer.bytesToHexString",
|
||||
is_method=True,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -58,7 +58,7 @@ public class Calculator {
|
|||
|
||||
functions = support.discover_functions(java_file)
|
||||
assert len(functions) == 1
|
||||
assert functions[0].name == "add"
|
||||
assert functions[0].function_name == "add"
|
||||
assert functions[0].language == Language.JAVA
|
||||
|
||||
def test_validate_syntax_valid(self, support):
|
||||
|
|
|
|||
|
|
@ -167,16 +167,17 @@ public class StringUtilsTest {
|
|||
""")
|
||||
|
||||
# Create source function
|
||||
from codeflash.languages.base import FunctionInfo, Language
|
||||
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
|
||||
from codeflash.languages.base import Language
|
||||
|
||||
func = FunctionInfo(
|
||||
name="reverse",
|
||||
func = FunctionToOptimize(
|
||||
function_name="reverse",
|
||||
file_path=tmp_path / "StringUtils.java",
|
||||
start_line=1,
|
||||
end_line=5,
|
||||
parents=(),
|
||||
starting_line=1,
|
||||
ending_line=5,
|
||||
parents=[],
|
||||
is_method=True,
|
||||
language=Language.JAVA,
|
||||
language="java",
|
||||
)
|
||||
|
||||
tests = find_tests_for_function(func, test_dir)
|
||||
|
|
@ -246,7 +247,7 @@ public class TestQueryBlob {
|
|||
)
|
||||
|
||||
# Filter to just bytesToHexString
|
||||
target_functions = [f for f in source_functions if f.name == "bytesToHexString"]
|
||||
target_functions = [f for f in source_functions if f.function_name == "bytesToHexString"]
|
||||
assert len(target_functions) == 1, "Should find bytesToHexString function"
|
||||
|
||||
# Discover tests
|
||||
|
|
|
|||
|
|
@ -5,16 +5,29 @@ Tests the full optimization pipeline for JavaScript:
|
|||
- Code context extraction
|
||||
- Test discovery
|
||||
- Code replacement
|
||||
|
||||
Note: These tests require JS/TS language support to be registered.
|
||||
They will be skipped in environments where only Python is supported.
|
||||
"""
|
||||
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
from codeflash.discovery.functions_to_optimize import find_all_functions_in_file, get_files_for_language
|
||||
|
||||
from codeflash.languages.base import Language
|
||||
|
||||
|
||||
def skip_if_js_not_supported():
|
||||
"""Skip test if JavaScript/TypeScript languages are not supported."""
|
||||
try:
|
||||
from codeflash.languages import get_language_support
|
||||
|
||||
get_language_support(Language.JAVASCRIPT)
|
||||
except Exception as e:
|
||||
pytest.skip(f"JavaScript/TypeScript language support not available: {e}")
|
||||
|
||||
|
||||
class TestJavaScriptFunctionDiscovery:
|
||||
"""Tests for JavaScript function discovery in the main pipeline."""
|
||||
|
||||
|
|
@ -29,6 +42,9 @@ class TestJavaScriptFunctionDiscovery:
|
|||
|
||||
def test_discover_functions_in_fibonacci(self, js_project_dir):
|
||||
"""Test discovering functions in fibonacci.js."""
|
||||
skip_if_js_not_supported()
|
||||
from codeflash.discovery.functions_to_optimize import find_all_functions_in_file
|
||||
|
||||
fib_file = js_project_dir / "fibonacci.js"
|
||||
if not fib_file.exists():
|
||||
pytest.skip("fibonacci.js not found")
|
||||
|
|
@ -38,19 +54,17 @@ class TestJavaScriptFunctionDiscovery:
|
|||
assert fib_file in functions
|
||||
func_list = functions[fib_file]
|
||||
|
||||
# Should find the main exported functions
|
||||
func_names = {f.function_name for f in func_list}
|
||||
assert "fibonacci" in func_names
|
||||
assert "isFibonacci" in func_names
|
||||
assert "isPerfectSquare" in func_names
|
||||
assert "fibonacciSequence" in func_names
|
||||
assert func_names == {"fibonacci", "isFibonacci", "isPerfectSquare", "fibonacciSequence"}
|
||||
|
||||
# All should be JavaScript functions
|
||||
for func in func_list:
|
||||
assert func.language == "javascript"
|
||||
|
||||
def test_discover_functions_in_bubble_sort(self, js_project_dir):
|
||||
"""Test discovering functions in bubble_sort.js."""
|
||||
skip_if_js_not_supported()
|
||||
from codeflash.discovery.functions_to_optimize import find_all_functions_in_file
|
||||
|
||||
sort_file = js_project_dir / "bubble_sort.js"
|
||||
if not sort_file.exists():
|
||||
pytest.skip("bubble_sort.js not found")
|
||||
|
|
@ -65,13 +79,14 @@ class TestJavaScriptFunctionDiscovery:
|
|||
|
||||
def test_get_javascript_files(self, js_project_dir):
|
||||
"""Test getting JavaScript files from directory."""
|
||||
skip_if_js_not_supported()
|
||||
from codeflash.discovery.functions_to_optimize import get_files_for_language
|
||||
|
||||
files = get_files_for_language(js_project_dir, Language.JAVASCRIPT)
|
||||
|
||||
# Should find .js files
|
||||
js_files = [f for f in files if f.suffix == ".js"]
|
||||
assert len(js_files) >= 3 # fibonacci.js, bubble_sort.js, string_utils.js
|
||||
assert len(js_files) >= 3
|
||||
|
||||
# Should not include test files in root (they're in tests/)
|
||||
root_files = [f for f in js_files if f.parent == js_project_dir]
|
||||
assert len(root_files) >= 3
|
||||
|
||||
|
|
@ -90,11 +105,11 @@ class TestJavaScriptCodeContext:
|
|||
|
||||
def test_extract_code_context_for_javascript(self, js_project_dir):
|
||||
"""Test extracting code context for a JavaScript function."""
|
||||
skip_if_js_not_supported()
|
||||
from codeflash.context.code_context_extractor import get_code_optimization_context
|
||||
from codeflash.discovery.functions_to_optimize import find_all_functions_in_file
|
||||
from codeflash.languages import current as lang_current
|
||||
from codeflash.languages.base import Language
|
||||
|
||||
# Force set language to JavaScript for proper context extraction routing
|
||||
lang_current._current_language = Language.JAVASCRIPT
|
||||
|
||||
fib_file = js_project_dir / "fibonacci.js"
|
||||
|
|
@ -104,21 +119,30 @@ class TestJavaScriptCodeContext:
|
|||
functions = find_all_functions_in_file(fib_file)
|
||||
func_list = functions[fib_file]
|
||||
|
||||
# Find the fibonacci function
|
||||
fib_func = next((f for f in func_list if f.function_name == "fibonacci"), None)
|
||||
assert fib_func is not None
|
||||
|
||||
# Extract code context
|
||||
context = get_code_optimization_context(fib_func, js_project_dir)
|
||||
|
||||
# Verify context structure
|
||||
assert context.read_writable_code is not None
|
||||
assert context.read_writable_code.language == "javascript"
|
||||
assert len(context.read_writable_code.code_strings) > 0
|
||||
|
||||
# The code should contain the function
|
||||
code = context.read_writable_code.code_strings[0].code
|
||||
assert "fibonacci" in code
|
||||
expected_code = """/**
|
||||
* Calculate the nth Fibonacci number using naive recursion.
|
||||
* This is intentionally slow to demonstrate optimization potential.
|
||||
* @param {number} n - The index of the Fibonacci number to calculate
|
||||
* @returns {number} - The nth Fibonacci number
|
||||
*/
|
||||
function fibonacci(n) {
|
||||
if (n <= 1) {
|
||||
return n;
|
||||
}
|
||||
return fibonacci(n - 1) + fibonacci(n - 2);
|
||||
}
|
||||
"""
|
||||
assert code == expected_code
|
||||
|
||||
|
||||
class TestJavaScriptCodeReplacement:
|
||||
|
|
@ -126,8 +150,9 @@ class TestJavaScriptCodeReplacement:
|
|||
|
||||
def test_replace_function_in_javascript_file(self):
|
||||
"""Test replacing a function in a JavaScript file."""
|
||||
skip_if_js_not_supported()
|
||||
from codeflash.languages import get_language_support
|
||||
from codeflash.languages.base import FunctionInfo, Language
|
||||
from codeflash.languages.base import FunctionInfo
|
||||
|
||||
original_source = """
|
||||
function add(a, b) {
|
||||
|
|
@ -146,16 +171,23 @@ function multiply(a, b) {
|
|||
|
||||
js_support = get_language_support(Language.JAVASCRIPT)
|
||||
|
||||
# Create FunctionInfo for the add function
|
||||
func_info = FunctionInfo(
|
||||
name="add", file_path=Path("/tmp/test.js"), start_line=2, end_line=4, language=Language.JAVASCRIPT
|
||||
function_name="add", file_path=Path("/tmp/test.js"), starting_line=2, ending_line=4, language="javascript"
|
||||
)
|
||||
|
||||
result = js_support.replace_function(original_source, func_info, new_function)
|
||||
|
||||
# Verify the function was replaced
|
||||
assert "// Optimized version" in result
|
||||
assert "multiply" in result # Other function should still be there
|
||||
expected_result = """
|
||||
function add(a, b) {
|
||||
// Optimized version
|
||||
return a + b;
|
||||
}
|
||||
|
||||
function multiply(a, b) {
|
||||
return a * b;
|
||||
}
|
||||
"""
|
||||
assert result == expected_result
|
||||
|
||||
|
||||
class TestJavaScriptTestDiscovery:
|
||||
|
|
@ -172,8 +204,9 @@ class TestJavaScriptTestDiscovery:
|
|||
|
||||
def test_discover_jest_tests(self, js_project_dir):
|
||||
"""Test discovering Jest tests for JavaScript functions."""
|
||||
skip_if_js_not_supported()
|
||||
from codeflash.languages import get_language_support
|
||||
from codeflash.languages.base import FunctionInfo, Language
|
||||
from codeflash.languages.base import FunctionInfo
|
||||
|
||||
js_support = get_language_support(Language.JAVASCRIPT)
|
||||
test_root = js_project_dir / "tests"
|
||||
|
|
@ -181,17 +214,14 @@ class TestJavaScriptTestDiscovery:
|
|||
if not test_root.exists():
|
||||
pytest.skip("tests directory not found")
|
||||
|
||||
# Create FunctionInfo for fibonacci function
|
||||
fib_file = js_project_dir / "fibonacci.js"
|
||||
func_info = FunctionInfo(
|
||||
name="fibonacci", file_path=fib_file, start_line=11, end_line=16, language=Language.JAVASCRIPT
|
||||
function_name="fibonacci", file_path=fib_file, starting_line=11, ending_line=16, language="javascript"
|
||||
)
|
||||
|
||||
# Discover tests
|
||||
tests = js_support.discover_tests(test_root, [func_info])
|
||||
|
||||
# Should find tests for fibonacci
|
||||
assert func_info.qualified_name in tests or "fibonacci" in str(tests)
|
||||
assert func_info.qualified_name in tests or len(tests) > 0
|
||||
|
||||
|
||||
class TestJavaScriptPipelineIntegration:
|
||||
|
|
@ -199,6 +229,9 @@ class TestJavaScriptPipelineIntegration:
|
|||
|
||||
def test_function_to_optimize_has_correct_fields(self):
|
||||
"""Test that FunctionToOptimize from JavaScript has all required fields."""
|
||||
skip_if_js_not_supported()
|
||||
from codeflash.discovery.functions_to_optimize import find_all_functions_in_file
|
||||
|
||||
with tempfile.NamedTemporaryFile(suffix=".js", mode="w", delete=False) as f:
|
||||
f.write("""
|
||||
class Calculator {
|
||||
|
|
@ -220,16 +253,13 @@ function standalone(x) {
|
|||
|
||||
functions = find_all_functions_in_file(file_path)
|
||||
|
||||
# Should find class methods and standalone function
|
||||
assert len(functions.get(file_path, [])) >= 3
|
||||
|
||||
# Check standalone function
|
||||
standalone_fn = next((fn for fn in functions[file_path] if fn.function_name == "standalone"), None)
|
||||
assert standalone_fn is not None
|
||||
assert standalone_fn.language == "javascript"
|
||||
assert len(standalone_fn.parents) == 0
|
||||
|
||||
# Check class method
|
||||
add_fn = next((fn for fn in functions[file_path] if fn.function_name == "add"), None)
|
||||
assert add_fn is not None
|
||||
assert add_fn.language == "javascript"
|
||||
|
|
@ -250,4 +280,4 @@ function standalone(x) {
|
|||
)
|
||||
|
||||
markdown = code_strings.markdown
|
||||
assert "```javascript" in markdown or "```js" in markdown.lower()
|
||||
assert "```javascript" in markdown
|
||||
|
|
|
|||
|
|
@ -3,12 +3,27 @@
|
|||
This module tests the line profiling and tracing instrumentation for JavaScript code.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
|
||||
from codeflash.languages.base import FunctionInfo, Language
|
||||
from codeflash.languages.javascript.line_profiler import JavaScriptLineProfiler
|
||||
from codeflash.languages.javascript.tracer import JavaScriptTracer
|
||||
from codeflash.models.models import FunctionParent
|
||||
|
||||
|
||||
def make_func(name: str, class_name: str | None = None) -> FunctionToOptimize:
|
||||
"""Helper to create FunctionToOptimize for testing."""
|
||||
parents = [FunctionParent(name=class_name, type="ClassDef")] if class_name else []
|
||||
return FunctionToOptimize(
|
||||
function_name=name,
|
||||
file_path=Path("/test/file.js"),
|
||||
parents=parents,
|
||||
language="javascript",
|
||||
)
|
||||
|
||||
|
||||
class TestJavaScriptLineProfiler:
|
||||
|
|
@ -49,7 +64,7 @@ function add(a, b) {
|
|||
file_path = Path(f.name)
|
||||
|
||||
func_info = FunctionInfo(
|
||||
name="add", file_path=file_path, start_line=2, end_line=5, language=Language.JAVASCRIPT
|
||||
function_name="add", file_path=file_path, starting_line=2, ending_line=5, language="javascript"
|
||||
)
|
||||
|
||||
output_file = Path("/tmp/test_profile.json")
|
||||
|
|
@ -110,7 +125,7 @@ function multiply(x, y) {
|
|||
file_path = Path(f.name)
|
||||
|
||||
func_info = FunctionInfo(
|
||||
name="multiply", file_path=file_path, start_line=2, end_line=4, language=Language.JAVASCRIPT
|
||||
function_name="multiply", file_path=file_path, starting_line=2, ending_line=4, language="javascript"
|
||||
)
|
||||
|
||||
output_db = Path("/tmp/test_traces.db")
|
||||
|
|
@ -154,7 +169,7 @@ function greet(name) {
|
|||
file_path = Path(f.name)
|
||||
|
||||
func_info = FunctionInfo(
|
||||
name="greet", file_path=file_path, start_line=2, end_line=4, language=Language.JAVASCRIPT
|
||||
function_name="greet", file_path=file_path, starting_line=2, ending_line=4, language="javascript"
|
||||
)
|
||||
|
||||
output_file = file_path.parent / ".codeflash" / "traces.db"
|
||||
|
|
@ -185,7 +200,7 @@ function square(n) {
|
|||
file_path = Path(f.name)
|
||||
|
||||
func_info = FunctionInfo(
|
||||
name="square", file_path=file_path, start_line=2, end_line=5, language=Language.JAVASCRIPT
|
||||
function_name="square", file_path=file_path, starting_line=2, ending_line=5, language="javascript"
|
||||
)
|
||||
|
||||
output_file = file_path.parent / ".codeflash" / "line_profile.json"
|
||||
|
|
@ -352,7 +367,7 @@ const result = calc.fibonacci(10);
|
|||
console.log(result);
|
||||
"""
|
||||
transformed, counter = transform_standalone_calls(
|
||||
code=code, func_name="fibonacci", qualified_name="Calculator.fibonacci", capture_func="capture"
|
||||
code=code, function_to_optimize=make_func("fibonacci", class_name="Calculator"), capture_func="capture"
|
||||
)
|
||||
|
||||
# Should transform calc.fibonacci(10) to codeflash.capture(..., calc.fibonacci.bind(calc), 10)
|
||||
|
|
@ -371,7 +386,7 @@ test('fibonacci works', () => {
|
|||
});
|
||||
"""
|
||||
transformed, counter = transform_expect_calls(
|
||||
code=code, func_name="fibonacci", qualified_name="FibonacciCalculator.fibonacci", capture_func="capture"
|
||||
code=code, function_to_optimize=make_func("fibonacci", class_name="FibonacciCalculator"), capture_func="capture"
|
||||
)
|
||||
|
||||
# Should transform expect(calc.fibonacci(10)) to
|
||||
|
|
@ -393,8 +408,7 @@ test('fibonacci works', () => {
|
|||
"""
|
||||
transformed, counter = transform_expect_calls(
|
||||
code=code,
|
||||
func_name="fibonacci",
|
||||
qualified_name="FibonacciCalculator.fibonacci",
|
||||
function_to_optimize=make_func("fibonacci", class_name="FibonacciCalculator"),
|
||||
capture_func="capture",
|
||||
remove_assertions=True,
|
||||
)
|
||||
|
|
@ -419,7 +433,7 @@ class FibonacciCalculator {
|
|||
}
|
||||
"""
|
||||
transformed, counter = transform_standalone_calls(
|
||||
code=code, func_name="fibonacci", qualified_name="FibonacciCalculator.fibonacci", capture_func="capture"
|
||||
code=code, function_to_optimize=make_func("fibonacci", class_name="FibonacciCalculator"), capture_func="capture"
|
||||
)
|
||||
|
||||
# The method definition should NOT be transformed
|
||||
|
|
@ -438,7 +452,7 @@ FibonacciCalculator.prototype.fibonacci = function(n) {
|
|||
};
|
||||
"""
|
||||
transformed, counter = transform_standalone_calls(
|
||||
code=code, func_name="fibonacci", qualified_name="FibonacciCalculator.fibonacci", capture_func="capture"
|
||||
code=code, function_to_optimize=make_func("fibonacci", class_name="FibonacciCalculator"), capture_func="capture"
|
||||
)
|
||||
|
||||
# The prototype assignment should NOT be transformed
|
||||
|
|
@ -456,7 +470,7 @@ const b = calc.fibonacci(10);
|
|||
const sum = a + b;
|
||||
"""
|
||||
transformed, counter = transform_standalone_calls(
|
||||
code=code, func_name="fibonacci", qualified_name="Calculator.fibonacci", capture_func="capture"
|
||||
code=code, function_to_optimize=make_func("fibonacci", class_name="Calculator"), capture_func="capture"
|
||||
)
|
||||
|
||||
# Should transform both calls
|
||||
|
|
@ -475,7 +489,7 @@ class Wrapper {
|
|||
}
|
||||
"""
|
||||
transformed, counter = transform_standalone_calls(
|
||||
code=code, func_name="fibonacci", qualified_name="Wrapper.fibonacci", capture_func="capture"
|
||||
code=code, function_to_optimize=make_func("fibonacci", class_name="Wrapper"), capture_func="capture"
|
||||
)
|
||||
|
||||
# Should transform this.fibonacci(n)
|
||||
|
|
@ -515,10 +529,9 @@ describe('FibonacciCalculator', () => {
|
|||
"""
|
||||
instrumented = _instrument_js_test_code(
|
||||
code=test_code,
|
||||
func_name="fibonacci",
|
||||
function_to_optimize=make_func("fibonacci", class_name="FibonacciCalculator"),
|
||||
test_file_path="test.js",
|
||||
mode="behavior",
|
||||
qualified_name="FibonacciCalculator.fibonacci",
|
||||
)
|
||||
|
||||
# Check that codeflash import was added
|
||||
|
|
@ -545,7 +558,7 @@ describe('Calculator', () => {
|
|||
});
|
||||
"""
|
||||
instrumented = _instrument_js_test_code(
|
||||
code=test_code, func_name="add", test_file_path="test.js", mode="behavior", qualified_name="Calculator.add"
|
||||
code=test_code, function_to_optimize=make_func("add", class_name="Calculator"), test_file_path="test.js", mode="behavior"
|
||||
)
|
||||
|
||||
# describe and test structure should be preserved
|
||||
|
|
@ -567,7 +580,7 @@ const data = await api.fetchData('http://example.com');
|
|||
console.log(data);
|
||||
"""
|
||||
transformed, counter = transform_standalone_calls(
|
||||
code=code, func_name="fetchData", qualified_name="ApiClient.fetchData", capture_func="capture"
|
||||
code=code, function_to_optimize=make_func("fetchData", class_name="ApiClient"), capture_func="capture"
|
||||
)
|
||||
|
||||
# Should preserve await
|
||||
|
|
@ -586,7 +599,7 @@ class TestInstrumentationFullStringEquality:
|
|||
code = " calc.fibonacci(10);"
|
||||
|
||||
transformed, counter = transform_standalone_calls(
|
||||
code=code, func_name="fibonacci", qualified_name="Calculator.fibonacci", capture_func="capture"
|
||||
code=code, function_to_optimize=make_func("fibonacci", class_name="Calculator"), capture_func="capture"
|
||||
)
|
||||
|
||||
expected = " codeflash.capture('Calculator.fibonacci', '1', calc.fibonacci.bind(calc), 10);"
|
||||
|
|
@ -600,7 +613,7 @@ class TestInstrumentationFullStringEquality:
|
|||
code = " expect(calc.fibonacci(10)).toBe(55);"
|
||||
|
||||
transformed, counter = transform_expect_calls(
|
||||
code=code, func_name="fibonacci", qualified_name="Calculator.fibonacci", capture_func="capture"
|
||||
code=code, function_to_optimize=make_func("fibonacci", class_name="Calculator"), capture_func="capture"
|
||||
)
|
||||
|
||||
expected = " expect(codeflash.capture('Calculator.fibonacci', '1', calc.fibonacci.bind(calc), 10)).toBe(55);"
|
||||
|
|
@ -615,8 +628,7 @@ class TestInstrumentationFullStringEquality:
|
|||
|
||||
transformed, counter = transform_expect_calls(
|
||||
code=code,
|
||||
func_name="fibonacci",
|
||||
qualified_name="Calculator.fibonacci",
|
||||
function_to_optimize=make_func("fibonacci", class_name="Calculator"),
|
||||
capture_func="capture",
|
||||
remove_assertions=True,
|
||||
)
|
||||
|
|
@ -632,7 +644,7 @@ class TestInstrumentationFullStringEquality:
|
|||
code = " fibonacci(10);"
|
||||
|
||||
transformed, counter = transform_standalone_calls(
|
||||
code=code, func_name="fibonacci", qualified_name="fibonacci", capture_func="capture"
|
||||
code=code, function_to_optimize=make_func("fibonacci"), capture_func="capture"
|
||||
)
|
||||
|
||||
expected = " codeflash.capture('fibonacci', '1', fibonacci, 10);"
|
||||
|
|
@ -646,9 +658,9 @@ class TestInstrumentationFullStringEquality:
|
|||
code = " return this.fibonacci(n - 1);"
|
||||
|
||||
transformed, counter = transform_standalone_calls(
|
||||
code=code, func_name="fibonacci", qualified_name="Class.fibonacci", capture_func="capture"
|
||||
code=code, function_to_optimize=make_func("fibonacci", class_name="Class"), capture_func="capture"
|
||||
)
|
||||
|
||||
expected = " return codeflash.capture('Class.fibonacci', '1', this.fibonacci.bind(this), n - 1);"
|
||||
assert transformed == expected, f"Expected:\n{expected}\nGot:\n{transformed}"
|
||||
assert counter == 1
|
||||
assert counter == 1
|
||||
251
tests/test_languages/test_javascript_requirements.py
Normal file
251
tests/test_languages/test_javascript_requirements.py
Normal file
|
|
@ -0,0 +1,251 @@
|
|||
"""Tests for JavaScript requirements verification.
|
||||
|
||||
Tests the verify_requirements function that checks Node.js, npm, and test framework availability.
|
||||
"""
|
||||
|
||||
import json
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from codeflash.languages.javascript.support import JavaScriptSupport
|
||||
|
||||
|
||||
class TestVerifyRequirements:
|
||||
"""Tests for JavaScriptSupport.verify_requirements()."""
|
||||
|
||||
@pytest.fixture
|
||||
def js_support(self):
|
||||
"""Create a JavaScriptSupport instance."""
|
||||
return JavaScriptSupport()
|
||||
|
||||
@pytest.fixture
|
||||
def project_with_jest(self, tmp_path):
|
||||
"""Create a project directory with Jest installed."""
|
||||
node_modules = tmp_path / "node_modules"
|
||||
node_modules.mkdir()
|
||||
(node_modules / "jest").mkdir()
|
||||
(node_modules / "codeflash").mkdir()
|
||||
|
||||
package_json = tmp_path / "package.json"
|
||||
package_json.write_text(
|
||||
json.dumps(
|
||||
{
|
||||
"name": "test-project",
|
||||
"devDependencies": {"jest": "^29.0.0"},
|
||||
}
|
||||
)
|
||||
)
|
||||
return tmp_path
|
||||
|
||||
@pytest.fixture
|
||||
def project_with_vitest(self, tmp_path):
|
||||
"""Create a project directory with Vitest installed."""
|
||||
node_modules = tmp_path / "node_modules"
|
||||
node_modules.mkdir()
|
||||
(node_modules / "vitest").mkdir()
|
||||
(node_modules / "codeflash").mkdir()
|
||||
|
||||
package_json = tmp_path / "package.json"
|
||||
package_json.write_text(
|
||||
json.dumps(
|
||||
{
|
||||
"name": "test-project",
|
||||
"devDependencies": {"vitest": "^2.0.0"},
|
||||
}
|
||||
)
|
||||
)
|
||||
return tmp_path
|
||||
|
||||
@pytest.fixture
|
||||
def project_without_node_modules(self, tmp_path):
|
||||
"""Create a project directory without node_modules."""
|
||||
package_json = tmp_path / "package.json"
|
||||
package_json.write_text(json.dumps({"name": "test-project"}))
|
||||
return tmp_path
|
||||
|
||||
@pytest.fixture
|
||||
def project_without_jest(self, tmp_path):
|
||||
"""Create a project directory with node_modules but without Jest."""
|
||||
node_modules = tmp_path / "node_modules"
|
||||
node_modules.mkdir()
|
||||
(node_modules / "some-other-package").mkdir()
|
||||
|
||||
package_json = tmp_path / "package.json"
|
||||
package_json.write_text(json.dumps({"name": "test-project"}))
|
||||
return tmp_path
|
||||
|
||||
def test_verify_requirements_success_with_jest(self, js_support, project_with_jest):
|
||||
"""Test successful verification when Jest is installed."""
|
||||
with patch("subprocess.run") as mock_run:
|
||||
mock_run.return_value = MagicMock(returncode=0)
|
||||
|
||||
success, errors = js_support.verify_requirements(project_with_jest, "jest")
|
||||
|
||||
assert success is True
|
||||
assert errors == []
|
||||
|
||||
def test_verify_requirements_success_with_vitest(self, js_support, project_with_vitest):
|
||||
"""Test successful verification when Vitest is installed."""
|
||||
with patch("subprocess.run") as mock_run:
|
||||
mock_run.return_value = MagicMock(returncode=0)
|
||||
|
||||
success, errors = js_support.verify_requirements(project_with_vitest, "vitest")
|
||||
|
||||
assert success is True
|
||||
assert errors == []
|
||||
|
||||
def test_verify_requirements_fails_without_node(self, js_support, project_with_jest):
|
||||
"""Test verification fails when Node.js is not installed."""
|
||||
with patch("subprocess.run") as mock_run:
|
||||
mock_run.side_effect = FileNotFoundError("node not found")
|
||||
|
||||
success, errors = js_support.verify_requirements(project_with_jest, "jest")
|
||||
|
||||
assert success is False
|
||||
assert len(errors) >= 1
|
||||
node_error_found = any("Node.js" in error for error in errors)
|
||||
assert node_error_found is True
|
||||
|
||||
def test_verify_requirements_fails_without_npm(self, js_support, project_with_jest):
|
||||
"""Test verification fails when npm is not available."""
|
||||
|
||||
def mock_run_side_effect(cmd, **kwargs):
|
||||
if cmd[0] == "node":
|
||||
return MagicMock(returncode=0)
|
||||
if cmd[0] == "npm":
|
||||
raise FileNotFoundError("npm not found")
|
||||
return MagicMock(returncode=0)
|
||||
|
||||
with patch("subprocess.run", side_effect=mock_run_side_effect):
|
||||
success, errors = js_support.verify_requirements(project_with_jest, "jest")
|
||||
|
||||
assert success is False
|
||||
npm_error_found = any("npm" in error for error in errors)
|
||||
assert npm_error_found is True
|
||||
|
||||
def test_verify_requirements_fails_without_node_modules(self, js_support, project_without_node_modules):
|
||||
"""Test verification fails when node_modules doesn't exist."""
|
||||
with patch("subprocess.run") as mock_run:
|
||||
mock_run.return_value = MagicMock(returncode=0)
|
||||
|
||||
success, errors = js_support.verify_requirements(project_without_node_modules, "jest")
|
||||
|
||||
assert success is False
|
||||
assert len(errors) == 1
|
||||
expected_error = (
|
||||
f"node_modules not found in {project_without_node_modules}. "
|
||||
f"Please run 'npm install' to install dependencies."
|
||||
)
|
||||
assert errors[0] == expected_error
|
||||
|
||||
def test_verify_requirements_fails_without_test_framework(self, js_support, project_without_jest):
|
||||
"""Test verification fails when test framework is not installed."""
|
||||
with patch("subprocess.run") as mock_run:
|
||||
mock_run.return_value = MagicMock(returncode=0)
|
||||
|
||||
success, errors = js_support.verify_requirements(project_without_jest, "jest")
|
||||
|
||||
assert success is False
|
||||
assert len(errors) == 1
|
||||
expected_error = "jest is not installed. Please run 'npm install --save-dev jest' to install it."
|
||||
assert errors[0] == expected_error
|
||||
|
||||
def test_verify_requirements_returns_multiple_errors(self, js_support, project_without_node_modules):
|
||||
"""Test that multiple errors can be returned."""
|
||||
with patch("subprocess.run") as mock_run:
|
||||
mock_run.side_effect = FileNotFoundError("command not found")
|
||||
|
||||
success, errors = js_support.verify_requirements(project_without_node_modules, "jest")
|
||||
|
||||
assert success is False
|
||||
assert len(errors) >= 2
|
||||
# Should have errors for Node.js, npm, and node_modules
|
||||
error_text = " ".join(errors)
|
||||
assert "Node.js" in error_text
|
||||
assert "npm" in error_text
|
||||
|
||||
def test_verify_requirements_vitest_not_installed(self, js_support, project_with_jest):
|
||||
"""Test verification fails when Vitest is requested but only Jest is installed."""
|
||||
with patch("subprocess.run") as mock_run:
|
||||
mock_run.return_value = MagicMock(returncode=0)
|
||||
|
||||
success, errors = js_support.verify_requirements(project_with_jest, "vitest")
|
||||
|
||||
assert success is False
|
||||
assert len(errors) == 1
|
||||
expected_error = "vitest is not installed. Please run 'npm install --save-dev vitest' to install it."
|
||||
assert errors[0] == expected_error
|
||||
|
||||
def test_verify_requirements_jest_not_installed(self, js_support, project_with_vitest):
|
||||
"""Test verification fails when Jest is requested but only Vitest is installed."""
|
||||
with patch("subprocess.run") as mock_run:
|
||||
mock_run.return_value = MagicMock(returncode=0)
|
||||
|
||||
success, errors = js_support.verify_requirements(project_with_vitest, "jest")
|
||||
|
||||
assert success is False
|
||||
assert len(errors) == 1
|
||||
expected_error = "jest is not installed. Please run 'npm install --save-dev jest' to install it."
|
||||
assert errors[0] == expected_error
|
||||
|
||||
|
||||
class TestVerifyRequirementsIntegration:
|
||||
"""Integration tests for verify_requirements with real filesystem."""
|
||||
|
||||
@pytest.fixture
|
||||
def js_support(self):
|
||||
"""Create a JavaScriptSupport instance."""
|
||||
return JavaScriptSupport()
|
||||
|
||||
def test_verify_on_real_vitest_project(self, js_support):
|
||||
"""Test verification on the real vitest sample project."""
|
||||
project_root = Path(__file__).parent.parent.parent / "code_to_optimize" / "js" / "code_to_optimize_vitest"
|
||||
|
||||
if not project_root.exists():
|
||||
pytest.skip("code_to_optimize_vitest directory not found")
|
||||
|
||||
node_modules = project_root / "node_modules"
|
||||
if not node_modules.exists():
|
||||
pytest.skip("node_modules not installed in vitest project")
|
||||
|
||||
# This test verifies the real project structure
|
||||
with patch("subprocess.run") as mock_run:
|
||||
mock_run.return_value = MagicMock(returncode=0)
|
||||
|
||||
success, errors = js_support.verify_requirements(project_root, "vitest")
|
||||
|
||||
# If vitest is installed, should succeed
|
||||
vitest_installed = (node_modules / "vitest").exists()
|
||||
if vitest_installed:
|
||||
assert success is True
|
||||
assert errors == []
|
||||
else:
|
||||
assert success is False
|
||||
assert len(errors) >= 1
|
||||
|
||||
def test_verify_on_real_jest_project(self, js_support):
|
||||
"""Test verification on the real Jest sample project."""
|
||||
project_root = Path(__file__).parent.parent.parent / "code_to_optimize" / "js" / "code_to_optimize_ts"
|
||||
|
||||
if not project_root.exists():
|
||||
pytest.skip("code_to_optimize_ts directory not found")
|
||||
|
||||
node_modules = project_root / "node_modules"
|
||||
if not node_modules.exists():
|
||||
pytest.skip("node_modules not installed in jest project")
|
||||
|
||||
with patch("subprocess.run") as mock_run:
|
||||
mock_run.return_value = MagicMock(returncode=0)
|
||||
|
||||
success, errors = js_support.verify_requirements(project_root, "jest")
|
||||
|
||||
jest_installed = (node_modules / "jest").exists()
|
||||
if jest_installed:
|
||||
assert success is True
|
||||
assert errors == []
|
||||
else:
|
||||
assert success is False
|
||||
assert len(errors) >= 1
|
||||
|
|
@ -55,7 +55,7 @@ function add(a, b) {
|
|||
functions = js_support.discover_functions(Path(f.name))
|
||||
|
||||
assert len(functions) == 1
|
||||
assert functions[0].name == "add"
|
||||
assert functions[0].function_name == "add"
|
||||
assert functions[0].language == Language.JAVASCRIPT
|
||||
|
||||
def test_discover_multiple_functions(self, js_support):
|
||||
|
|
@ -79,7 +79,7 @@ function multiply(a, b) {
|
|||
functions = js_support.discover_functions(Path(f.name))
|
||||
|
||||
assert len(functions) == 3
|
||||
names = {func.name for func in functions}
|
||||
names = {func.function_name for func in functions}
|
||||
assert names == {"add", "subtract", "multiply"}
|
||||
|
||||
def test_discover_arrow_function(self, js_support):
|
||||
|
|
@ -97,7 +97,7 @@ const multiply = (x, y) => x * y;
|
|||
functions = js_support.discover_functions(Path(f.name))
|
||||
|
||||
assert len(functions) == 2
|
||||
names = {func.name for func in functions}
|
||||
names = {func.function_name for func in functions}
|
||||
assert names == {"add", "multiply"}
|
||||
|
||||
def test_discover_function_without_return_excluded(self, js_support):
|
||||
|
|
@ -118,7 +118,7 @@ function withoutReturn() {
|
|||
|
||||
# Only the function with return should be discovered
|
||||
assert len(functions) == 1
|
||||
assert functions[0].name == "withReturn"
|
||||
assert functions[0].function_name == "withReturn"
|
||||
|
||||
def test_discover_class_methods(self, js_support):
|
||||
"""Test discovering class methods."""
|
||||
|
|
@ -161,8 +161,8 @@ function syncFunction() {
|
|||
|
||||
assert len(functions) == 2
|
||||
|
||||
async_func = next(f for f in functions if f.name == "fetchData")
|
||||
sync_func = next(f for f in functions if f.name == "syncFunction")
|
||||
async_func = next(f for f in functions if f.function_name == "fetchData")
|
||||
sync_func = next(f for f in functions if f.function_name == "syncFunction")
|
||||
|
||||
assert async_func.is_async is True
|
||||
assert sync_func.is_async is False
|
||||
|
|
@ -185,7 +185,7 @@ function syncFunc() {
|
|||
functions = js_support.discover_functions(Path(f.name), criteria)
|
||||
|
||||
assert len(functions) == 1
|
||||
assert functions[0].name == "syncFunc"
|
||||
assert functions[0].function_name == "syncFunc"
|
||||
|
||||
def test_discover_with_filter_exclude_methods(self, js_support):
|
||||
"""Test filtering out class methods."""
|
||||
|
|
@ -207,7 +207,7 @@ class MyClass {
|
|||
functions = js_support.discover_functions(Path(f.name), criteria)
|
||||
|
||||
assert len(functions) == 1
|
||||
assert functions[0].name == "standalone"
|
||||
assert functions[0].function_name == "standalone"
|
||||
|
||||
def test_discover_line_numbers(self, js_support):
|
||||
"""Test that line numbers are correctly captured."""
|
||||
|
|
@ -226,13 +226,13 @@ function func2() {
|
|||
|
||||
functions = js_support.discover_functions(Path(f.name))
|
||||
|
||||
func1 = next(f for f in functions if f.name == "func1")
|
||||
func2 = next(f for f in functions if f.name == "func2")
|
||||
func1 = next(f for f in functions if f.function_name == "func1")
|
||||
func2 = next(f for f in functions if f.function_name == "func2")
|
||||
|
||||
assert func1.start_line == 1
|
||||
assert func1.end_line == 3
|
||||
assert func2.start_line == 5
|
||||
assert func2.end_line == 9
|
||||
assert func1.starting_line == 1
|
||||
assert func1.ending_line == 3
|
||||
assert func2.starting_line == 5
|
||||
assert func2.ending_line == 9
|
||||
|
||||
def test_discover_generator_function(self, js_support):
|
||||
"""Test discovering generator functions."""
|
||||
|
|
@ -249,7 +249,7 @@ function* numberGenerator() {
|
|||
functions = js_support.discover_functions(Path(f.name))
|
||||
|
||||
assert len(functions) == 1
|
||||
assert functions[0].name == "numberGenerator"
|
||||
assert functions[0].function_name == "numberGenerator"
|
||||
|
||||
def test_discover_invalid_file_returns_empty(self, js_support):
|
||||
"""Test that invalid JavaScript file returns empty list."""
|
||||
|
|
@ -280,7 +280,7 @@ const add = function(a, b) {
|
|||
functions = js_support.discover_functions(Path(f.name))
|
||||
|
||||
assert len(functions) == 1
|
||||
assert functions[0].name == "add"
|
||||
assert functions[0].function_name == "add"
|
||||
|
||||
def test_discover_immediately_invoked_function_excluded(self, js_support):
|
||||
"""Test that IIFEs without names are excluded when require_name is True."""
|
||||
|
|
@ -300,7 +300,7 @@ function named() {
|
|||
|
||||
# Only the named function should be discovered
|
||||
assert len(functions) == 1
|
||||
assert functions[0].name == "named"
|
||||
assert functions[0].function_name == "named"
|
||||
|
||||
|
||||
class TestReplaceFunction:
|
||||
|
|
@ -316,7 +316,7 @@ function multiply(a, b) {
|
|||
return a * b;
|
||||
}
|
||||
"""
|
||||
func = FunctionInfo(name="add", file_path=Path("/test.js"), start_line=1, end_line=3)
|
||||
func = FunctionInfo(function_name="add", file_path=Path("/test.js"), starting_line=1, ending_line=3)
|
||||
new_code = """function add(a, b) {
|
||||
// Optimized
|
||||
return (a + b) | 0;
|
||||
|
|
@ -343,7 +343,7 @@ function other() {
|
|||
|
||||
// Footer
|
||||
"""
|
||||
func = FunctionInfo(name="target", file_path=Path("/test.js"), start_line=4, end_line=6)
|
||||
func = FunctionInfo(function_name="target", file_path=Path("/test.js"), starting_line=4, ending_line=6)
|
||||
new_code = """function target() {
|
||||
return 42;
|
||||
}
|
||||
|
|
@ -365,11 +365,11 @@ function other() {
|
|||
}
|
||||
"""
|
||||
func = FunctionInfo(
|
||||
name="add",
|
||||
function_name="add",
|
||||
file_path=Path("/test.js"),
|
||||
start_line=2,
|
||||
end_line=4,
|
||||
parents=(ParentInfo(name="Calculator", type="ClassDef"),),
|
||||
starting_line=2,
|
||||
ending_line=4,
|
||||
parents=[ParentInfo(name="Calculator", type="ClassDef")],
|
||||
)
|
||||
# New code has no indentation
|
||||
new_code = """add(a, b) {
|
||||
|
|
@ -391,7 +391,7 @@ function other() {
|
|||
|
||||
const multiply = (x, y) => x * y;
|
||||
"""
|
||||
func = FunctionInfo(name="add", file_path=Path("/test.js"), start_line=1, end_line=3)
|
||||
func = FunctionInfo(function_name="add", file_path=Path("/test.js"), starting_line=1, ending_line=3)
|
||||
new_code = """const add = (a, b) => {
|
||||
return (a + b) | 0;
|
||||
};
|
||||
|
|
@ -483,7 +483,7 @@ class TestExtractCodeContext:
|
|||
f.flush()
|
||||
file_path = Path(f.name)
|
||||
|
||||
func = FunctionInfo(name="add", file_path=file_path, start_line=1, end_line=3)
|
||||
func = FunctionInfo(function_name="add", file_path=file_path, starting_line=1, ending_line=3)
|
||||
|
||||
context = js_support.extract_code_context(func, file_path.parent, file_path.parent)
|
||||
|
||||
|
|
@ -508,7 +508,7 @@ function main(a) {
|
|||
|
||||
# First discover functions to get accurate line numbers
|
||||
functions = js_support.discover_functions(file_path)
|
||||
main_func = next(f for f in functions if f.name == "main")
|
||||
main_func = next(f for f in functions if f.function_name == "main")
|
||||
|
||||
context = js_support.extract_code_context(main_func, file_path.parent, file_path.parent)
|
||||
|
||||
|
|
@ -538,7 +538,7 @@ class TestIntegration:
|
|||
functions = js_support.discover_functions(file_path)
|
||||
assert len(functions) == 1
|
||||
func = functions[0]
|
||||
assert func.name == "fibonacci"
|
||||
assert func.function_name == "fibonacci"
|
||||
|
||||
# Replace
|
||||
optimized_code = """function fibonacci(n) {
|
||||
|
|
@ -626,7 +626,7 @@ export default Button;
|
|||
functions = js_support.discover_functions(file_path)
|
||||
|
||||
# Should find both components
|
||||
names = {f.name for f in functions}
|
||||
names = {f.function_name for f in functions}
|
||||
assert "Button" in names
|
||||
assert "Card" in names
|
||||
|
||||
|
|
@ -688,7 +688,7 @@ class TestClassMethodExtraction:
|
|||
|
||||
# Discover the method
|
||||
functions = js_support.discover_functions(file_path)
|
||||
add_method = next(f for f in functions if f.name == "add")
|
||||
add_method = next(f for f in functions if f.function_name == "add")
|
||||
|
||||
# Extract code context
|
||||
context = js_support.extract_code_context(add_method, file_path.parent, file_path.parent)
|
||||
|
|
@ -725,7 +725,7 @@ class Calculator {
|
|||
file_path = Path(f.name)
|
||||
|
||||
functions = js_support.discover_functions(file_path)
|
||||
add_method = next(f for f in functions if f.name == "add")
|
||||
add_method = next(f for f in functions if f.function_name == "add")
|
||||
|
||||
context = js_support.extract_code_context(add_method, file_path.parent, file_path.parent)
|
||||
|
||||
|
|
@ -764,7 +764,7 @@ class Calculator {
|
|||
file_path = Path(f.name)
|
||||
|
||||
functions = js_support.discover_functions(file_path)
|
||||
fib_method = next(f for f in functions if f.name == "fibonacci")
|
||||
fib_method = next(f for f in functions if f.function_name == "fibonacci")
|
||||
|
||||
context = js_support.extract_code_context(fib_method, file_path.parent, file_path.parent)
|
||||
|
||||
|
|
@ -802,7 +802,7 @@ class Calculator {
|
|||
file_path = Path(f.name)
|
||||
|
||||
functions = js_support.discover_functions(file_path)
|
||||
add_method = next((f for f in functions if f.name == "add"), None)
|
||||
add_method = next((f for f in functions if f.function_name == "add"), None)
|
||||
|
||||
if add_method:
|
||||
context = js_support.extract_code_context(add_method, file_path.parent, file_path.parent)
|
||||
|
|
@ -831,7 +831,7 @@ class Calculator {
|
|||
file_path = Path(f.name)
|
||||
|
||||
functions = js_support.discover_functions(file_path)
|
||||
fetch_method = next(f for f in functions if f.name == "fetchData")
|
||||
fetch_method = next(f for f in functions if f.function_name == "fetchData")
|
||||
|
||||
context = js_support.extract_code_context(fetch_method, file_path.parent, file_path.parent)
|
||||
|
||||
|
|
@ -863,7 +863,7 @@ class Calculator {
|
|||
file_path = Path(f.name)
|
||||
|
||||
functions = js_support.discover_functions(file_path)
|
||||
add_method = next((f for f in functions if f.name == "add"), None)
|
||||
add_method = next((f for f in functions if f.function_name == "add"), None)
|
||||
|
||||
if add_method:
|
||||
context = js_support.extract_code_context(add_method, file_path.parent, file_path.parent)
|
||||
|
|
@ -891,7 +891,7 @@ class Calculator {
|
|||
file_path = Path(f.name)
|
||||
|
||||
functions = js_support.discover_functions(file_path)
|
||||
method = next(f for f in functions if f.name == "simpleMethod")
|
||||
method = next(f for f in functions if f.function_name == "simpleMethod")
|
||||
|
||||
context = js_support.extract_code_context(method, file_path.parent, file_path.parent)
|
||||
|
||||
|
|
@ -922,11 +922,11 @@ class TestClassMethodReplacement:
|
|||
}
|
||||
"""
|
||||
func = FunctionInfo(
|
||||
name="add",
|
||||
function_name="add",
|
||||
file_path=Path("/test.js"),
|
||||
start_line=2,
|
||||
end_line=4,
|
||||
parents=(ParentInfo(name="Calculator", type="ClassDef"),),
|
||||
starting_line=2,
|
||||
ending_line=4,
|
||||
parents=[ParentInfo(name="Calculator", type="ClassDef")],
|
||||
is_method=True,
|
||||
)
|
||||
new_code = """ add(a, b) {
|
||||
|
|
@ -963,12 +963,12 @@ class TestClassMethodReplacement:
|
|||
}
|
||||
"""
|
||||
func = FunctionInfo(
|
||||
name="add",
|
||||
function_name="add",
|
||||
file_path=Path("/test.js"),
|
||||
start_line=5, # Method starts here
|
||||
end_line=7,
|
||||
starting_line=5, # Method starts here
|
||||
ending_line=7,
|
||||
doc_start_line=2, # JSDoc starts here
|
||||
parents=(ParentInfo(name="Calculator", type="ClassDef"),),
|
||||
parents=[ParentInfo(name="Calculator", type="ClassDef")],
|
||||
is_method=True,
|
||||
)
|
||||
new_code = """ /**
|
||||
|
|
@ -1000,11 +1000,11 @@ class TestClassMethodReplacement:
|
|||
"""
|
||||
# Replace add first
|
||||
add_func = FunctionInfo(
|
||||
name="add",
|
||||
function_name="add",
|
||||
file_path=Path("/test.js"),
|
||||
start_line=2,
|
||||
end_line=4,
|
||||
parents=(ParentInfo(name="Math", type="ClassDef"),),
|
||||
starting_line=2,
|
||||
ending_line=4,
|
||||
parents=[ParentInfo(name="Math", type="ClassDef")],
|
||||
is_method=True,
|
||||
)
|
||||
source = js_support.replace_function(
|
||||
|
|
@ -1032,11 +1032,11 @@ class TestClassMethodReplacement:
|
|||
}
|
||||
"""
|
||||
func = FunctionInfo(
|
||||
name="innerMethod",
|
||||
function_name="innerMethod",
|
||||
file_path=Path("/test.js"),
|
||||
start_line=2,
|
||||
end_line=4,
|
||||
parents=(ParentInfo(name="Indented", type="ClassDef"),),
|
||||
starting_line=2,
|
||||
ending_line=4,
|
||||
parents=[ParentInfo(name="Indented", type="ClassDef")],
|
||||
is_method=True,
|
||||
)
|
||||
# New code with no indentation
|
||||
|
|
@ -1077,7 +1077,7 @@ class TestClassMethodEdgeCases:
|
|||
functions = js_support.discover_functions(file_path)
|
||||
|
||||
# Should find constructor and increment
|
||||
names = {f.name for f in functions}
|
||||
names = {f.function_name for f in functions}
|
||||
assert "constructor" in names or "increment" in names
|
||||
|
||||
def test_class_with_getters_setters(self, js_support):
|
||||
|
|
@ -1107,7 +1107,7 @@ class TestClassMethodEdgeCases:
|
|||
functions = js_support.discover_functions(file_path)
|
||||
|
||||
# Should find at least greet
|
||||
names = {f.name for f in functions}
|
||||
names = {f.function_name for f in functions}
|
||||
assert "greet" in names
|
||||
|
||||
def test_class_extending_another(self, js_support):
|
||||
|
|
@ -1135,7 +1135,7 @@ class Dog extends Animal {
|
|||
functions = js_support.discover_functions(file_path)
|
||||
|
||||
# Find Dog's fetch method
|
||||
fetch_method = next((f for f in functions if f.name == "fetch" and f.class_name == "Dog"), None)
|
||||
fetch_method = next((f for f in functions if f.function_name == "fetch" and f.class_name == "Dog"), None)
|
||||
|
||||
if fetch_method:
|
||||
context = js_support.extract_code_context(fetch_method, file_path.parent, file_path.parent)
|
||||
|
|
@ -1169,7 +1169,7 @@ class Dog extends Animal {
|
|||
functions = js_support.discover_functions(file_path)
|
||||
|
||||
# Should at least find publicMethod
|
||||
names = {f.name for f in functions}
|
||||
names = {f.function_name for f in functions}
|
||||
assert "publicMethod" in names
|
||||
|
||||
def test_commonjs_class_export(self, js_support):
|
||||
|
|
@ -1187,7 +1187,7 @@ module.exports = { Calculator };
|
|||
file_path = Path(f.name)
|
||||
|
||||
functions = js_support.discover_functions(file_path)
|
||||
add_method = next(f for f in functions if f.name == "add")
|
||||
add_method = next(f for f in functions if f.function_name == "add")
|
||||
|
||||
context = js_support.extract_code_context(add_method, file_path.parent, file_path.parent)
|
||||
|
||||
|
|
@ -1209,7 +1209,7 @@ module.exports = { Calculator };
|
|||
functions = js_support.discover_functions(file_path)
|
||||
|
||||
# Find the add method
|
||||
add_method = next((f for f in functions if f.name == "add"), None)
|
||||
add_method = next((f for f in functions if f.function_name == "add"), None)
|
||||
|
||||
if add_method:
|
||||
context = js_support.extract_code_context(add_method, file_path.parent, file_path.parent)
|
||||
|
|
@ -1260,7 +1260,7 @@ module.exports = { Counter };
|
|||
file_path = Path(f.name)
|
||||
|
||||
functions = js_support.discover_functions(file_path)
|
||||
increment_func = next(fn for fn in functions if fn.name == "increment")
|
||||
increment_func = next(fn for fn in functions if fn.function_name == "increment")
|
||||
|
||||
# Step 1: Extract code context (includes constructor for AI context)
|
||||
context = js_support.extract_code_context(increment_func, file_path.parent, file_path.parent)
|
||||
|
|
@ -1359,7 +1359,7 @@ export { User };
|
|||
file_path = Path(f.name)
|
||||
|
||||
functions = ts_support.discover_functions(file_path)
|
||||
get_name_func = next(fn for fn in functions if fn.name == "getName")
|
||||
get_name_func = next(fn for fn in functions if fn.function_name == "getName")
|
||||
|
||||
# Step 1: Extract code context (includes fields and constructor)
|
||||
context = ts_support.extract_code_context(get_name_func, file_path.parent, file_path.parent)
|
||||
|
|
@ -1461,7 +1461,7 @@ class Calculator {
|
|||
file_path = Path(f.name)
|
||||
|
||||
functions = js_support.discover_functions(file_path)
|
||||
add_func = next(fn for fn in functions if fn.name == "add")
|
||||
add_func = next(fn for fn in functions if fn.function_name == "add")
|
||||
|
||||
# Extract context for add
|
||||
context = js_support.extract_code_context(add_func, file_path.parent, file_path.parent)
|
||||
|
|
@ -1547,7 +1547,7 @@ module.exports = { MathUtils };
|
|||
file_path = Path(f.name)
|
||||
|
||||
functions = js_support.discover_functions(file_path)
|
||||
add_func = next(fn for fn in functions if fn.name == "add")
|
||||
add_func = next(fn for fn in functions if fn.function_name == "add")
|
||||
|
||||
# Extract context
|
||||
context = js_support.extract_code_context(add_func, file_path.parent, file_path.parent)
|
||||
|
|
|
|||
|
|
@ -1715,7 +1715,7 @@ module.exports = { Calculator };
|
|||
functions = js_support.discover_functions(source_file)
|
||||
|
||||
# Check qualified names include class
|
||||
add_func = next((f for f in functions if f.name == "add"), None)
|
||||
add_func = next((f for f in functions if f.function_name == "add"), None)
|
||||
assert add_func is not None
|
||||
assert add_func.class_name == "Calculator"
|
||||
|
||||
|
|
|
|||
730
tests/test_languages/test_javascript_test_runner.py
Normal file
730
tests/test_languages/test_javascript_test_runner.py
Normal file
|
|
@ -0,0 +1,730 @@
|
|||
"""Tests for JavaScript/Jest test runner functionality."""
|
||||
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
class TestJestRootsConfiguration:
|
||||
"""Tests for Jest --roots flag handling."""
|
||||
|
||||
def test_behavioral_tests_adds_roots_for_test_directories(self):
|
||||
"""Test that run_jest_behavioral_tests adds --roots for test directories."""
|
||||
from codeflash.languages.javascript.test_runner import run_jest_behavioral_tests
|
||||
from codeflash.models.models import TestFile, TestFiles
|
||||
from codeflash.models.test_type import TestType
|
||||
|
||||
# Create mock test files in a test directory
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
tmpdir_path = Path(tmpdir).resolve()
|
||||
test_dir = tmpdir_path / "test"
|
||||
test_dir.mkdir()
|
||||
|
||||
# Create package.json to simulate a Node project
|
||||
(tmpdir_path / "package.json").write_text('{"name": "test"}')
|
||||
|
||||
# Create mock test files
|
||||
test_file1 = test_dir / "test_func__unit_test_0.test.ts"
|
||||
test_file2 = test_dir / "test_func__unit_test_1.test.ts"
|
||||
test_file1.write_text("// test 1")
|
||||
test_file2.write_text("// test 2")
|
||||
|
||||
mock_test_files = TestFiles(
|
||||
test_files=[
|
||||
TestFile(
|
||||
original_file_path=test_file1,
|
||||
instrumented_behavior_file_path=test_file1,
|
||||
benchmarking_file_path=test_file1,
|
||||
test_type=TestType.GENERATED_REGRESSION,
|
||||
),
|
||||
TestFile(
|
||||
original_file_path=test_file2,
|
||||
instrumented_behavior_file_path=test_file2,
|
||||
benchmarking_file_path=test_file2,
|
||||
test_type=TestType.GENERATED_REGRESSION,
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
# Mock subprocess.run to capture the command
|
||||
with patch("subprocess.run") as mock_run:
|
||||
mock_result = MagicMock()
|
||||
mock_result.stdout = ""
|
||||
mock_result.stderr = ""
|
||||
mock_result.returncode = 1
|
||||
mock_run.return_value = mock_result
|
||||
|
||||
try:
|
||||
run_jest_behavioral_tests(
|
||||
test_paths=mock_test_files,
|
||||
test_env={},
|
||||
cwd=tmpdir_path,
|
||||
project_root=tmpdir_path,
|
||||
)
|
||||
except Exception:
|
||||
pass # Expected to fail since no real Jest
|
||||
|
||||
# Verify the command included --roots
|
||||
if mock_run.called:
|
||||
call_args = mock_run.call_args
|
||||
cmd = call_args[0][0]
|
||||
|
||||
# Find --roots flags in the command
|
||||
roots_flags = []
|
||||
for i, arg in enumerate(cmd):
|
||||
if arg == "--roots" and i + 1 < len(cmd):
|
||||
roots_flags.append(cmd[i + 1])
|
||||
|
||||
# Should have added the test directory as a root
|
||||
assert len(roots_flags) > 0, "Expected --roots flag in Jest command"
|
||||
assert str(test_dir) in roots_flags or any(
|
||||
str(test_dir) in root for root in roots_flags
|
||||
), f"Expected test directory {test_dir} in --roots flags: {roots_flags}"
|
||||
|
||||
def test_benchmarking_tests_adds_roots_for_test_directories(self):
|
||||
"""Test that run_jest_benchmarking_tests adds --roots for test directories."""
|
||||
from codeflash.languages.javascript.test_runner import run_jest_benchmarking_tests
|
||||
from codeflash.models.models import TestFile, TestFiles
|
||||
from codeflash.models.test_type import TestType
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
tmpdir_path = Path(tmpdir).resolve()
|
||||
test_dir = tmpdir_path / "test"
|
||||
test_dir.mkdir()
|
||||
|
||||
(tmpdir_path / "package.json").write_text('{"name": "test"}')
|
||||
|
||||
test_file = test_dir / "test_func__perf_test_0.test.ts"
|
||||
test_file.write_text("// perf test")
|
||||
|
||||
mock_test_files = TestFiles(
|
||||
test_files=[
|
||||
TestFile(
|
||||
original_file_path=test_file,
|
||||
instrumented_behavior_file_path=test_file,
|
||||
benchmarking_file_path=test_file,
|
||||
test_type=TestType.GENERATED_REGRESSION,
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
with patch("subprocess.run") as mock_run:
|
||||
mock_result = MagicMock()
|
||||
mock_result.stdout = ""
|
||||
mock_result.stderr = ""
|
||||
mock_result.returncode = 1
|
||||
mock_run.return_value = mock_result
|
||||
|
||||
try:
|
||||
run_jest_benchmarking_tests(
|
||||
test_paths=mock_test_files,
|
||||
test_env={},
|
||||
cwd=tmpdir_path,
|
||||
project_root=tmpdir_path,
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if mock_run.called:
|
||||
call_args = mock_run.call_args
|
||||
cmd = call_args[0][0]
|
||||
|
||||
roots_flags = []
|
||||
for i, arg in enumerate(cmd):
|
||||
if arg == "--roots" and i + 1 < len(cmd):
|
||||
roots_flags.append(cmd[i + 1])
|
||||
|
||||
assert len(roots_flags) > 0, "Expected --roots flag in Jest command"
|
||||
|
||||
def test_line_profile_tests_adds_roots_for_test_directories(self):
|
||||
"""Test that run_jest_line_profile_tests adds --roots for test directories."""
|
||||
from codeflash.languages.javascript.test_runner import run_jest_line_profile_tests
|
||||
from codeflash.models.models import TestFile, TestFiles
|
||||
from codeflash.models.test_type import TestType
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
tmpdir_path = Path(tmpdir)
|
||||
test_dir = tmpdir_path / "test"
|
||||
test_dir.mkdir()
|
||||
|
||||
(tmpdir_path / "package.json").write_text('{"name": "test"}')
|
||||
|
||||
test_file = test_dir / "test_func__line_profile.test.ts"
|
||||
test_file.write_text("// line profile test")
|
||||
|
||||
mock_test_files = TestFiles(
|
||||
test_files=[
|
||||
TestFile(
|
||||
original_file_path=test_file,
|
||||
instrumented_behavior_file_path=test_file,
|
||||
benchmarking_file_path=test_file,
|
||||
test_type=TestType.GENERATED_REGRESSION,
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
with patch("subprocess.run") as mock_run:
|
||||
mock_result = MagicMock()
|
||||
mock_result.stdout = ""
|
||||
mock_result.stderr = ""
|
||||
mock_result.returncode = 1
|
||||
mock_run.return_value = mock_result
|
||||
|
||||
try:
|
||||
run_jest_line_profile_tests(
|
||||
test_paths=mock_test_files,
|
||||
test_env={},
|
||||
cwd=tmpdir_path,
|
||||
project_root=tmpdir_path,
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if mock_run.called:
|
||||
call_args = mock_run.call_args
|
||||
cmd = call_args[0][0]
|
||||
|
||||
roots_flags = []
|
||||
for i, arg in enumerate(cmd):
|
||||
if arg == "--roots" and i + 1 < len(cmd):
|
||||
roots_flags.append(cmd[i + 1])
|
||||
|
||||
assert len(roots_flags) > 0, "Expected --roots flag in Jest command"
|
||||
|
||||
def test_multiple_test_directories_all_added_to_roots(self):
|
||||
"""Test that multiple test directories are all added as --roots."""
|
||||
from codeflash.languages.javascript.test_runner import run_jest_behavioral_tests
|
||||
from codeflash.models.models import TestFile, TestFiles
|
||||
from codeflash.models.test_type import TestType
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
tmpdir_path = Path(tmpdir)
|
||||
test_dir1 = tmpdir_path / "test"
|
||||
test_dir2 = tmpdir_path / "spec"
|
||||
test_dir1.mkdir()
|
||||
test_dir2.mkdir()
|
||||
|
||||
(tmpdir_path / "package.json").write_text('{"name": "test"}')
|
||||
|
||||
test_file1 = test_dir1 / "test_func__unit_test_0.test.ts"
|
||||
test_file2 = test_dir2 / "test_func__unit_test_1.test.ts"
|
||||
test_file1.write_text("// test 1")
|
||||
test_file2.write_text("// test 2")
|
||||
|
||||
mock_test_files = TestFiles(
|
||||
test_files=[
|
||||
TestFile(
|
||||
original_file_path=test_file1,
|
||||
instrumented_behavior_file_path=test_file1,
|
||||
benchmarking_file_path=test_file1,
|
||||
test_type=TestType.GENERATED_REGRESSION,
|
||||
),
|
||||
TestFile(
|
||||
original_file_path=test_file2,
|
||||
instrumented_behavior_file_path=test_file2,
|
||||
benchmarking_file_path=test_file2,
|
||||
test_type=TestType.GENERATED_REGRESSION,
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
with patch("subprocess.run") as mock_run:
|
||||
mock_result = MagicMock()
|
||||
mock_result.stdout = ""
|
||||
mock_result.stderr = ""
|
||||
mock_result.returncode = 1
|
||||
mock_run.return_value = mock_result
|
||||
|
||||
try:
|
||||
run_jest_behavioral_tests(
|
||||
test_paths=mock_test_files,
|
||||
test_env={},
|
||||
cwd=tmpdir_path,
|
||||
project_root=tmpdir_path,
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if mock_run.called:
|
||||
call_args = mock_run.call_args
|
||||
cmd = call_args[0][0]
|
||||
|
||||
roots_flags = []
|
||||
for i, arg in enumerate(cmd):
|
||||
if arg == "--roots" and i + 1 < len(cmd):
|
||||
roots_flags.append(cmd[i + 1])
|
||||
|
||||
# Should have two --roots flags (one for each directory)
|
||||
assert len(roots_flags) == 2, f"Expected 2 --roots flags, got {len(roots_flags)}"
|
||||
|
||||
|
||||
class TestVitestTimeoutConfiguration:
|
||||
"""Tests for Vitest subprocess timeout handling."""
|
||||
|
||||
def test_vitest_behavioral_subprocess_timeout_larger_than_test_timeout(self):
|
||||
"""Test that subprocess timeout is larger than per-test timeout for Vitest behavioral tests."""
|
||||
from codeflash.languages.javascript.vitest_runner import run_vitest_behavioral_tests
|
||||
from codeflash.models.models import TestFile, TestFiles
|
||||
from codeflash.models.test_type import TestType
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
tmpdir_path = Path(tmpdir)
|
||||
test_dir = tmpdir_path / "test"
|
||||
test_dir.mkdir()
|
||||
|
||||
(tmpdir_path / "package.json").write_text('{"name": "test", "devDependencies": {"vitest": "^1.0.0"}}')
|
||||
|
||||
test_file = test_dir / "test_func.test.ts"
|
||||
test_file.write_text("// test")
|
||||
|
||||
mock_test_files = TestFiles(
|
||||
test_files=[
|
||||
TestFile(
|
||||
original_file_path=test_file,
|
||||
instrumented_behavior_file_path=test_file,
|
||||
benchmarking_file_path=test_file,
|
||||
test_type=TestType.GENERATED_REGRESSION,
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
with patch("subprocess.run") as mock_run:
|
||||
mock_result = MagicMock()
|
||||
mock_result.stdout = ""
|
||||
mock_result.stderr = ""
|
||||
mock_result.returncode = 0
|
||||
mock_run.return_value = mock_result
|
||||
|
||||
# Run with a 15 second per-test timeout
|
||||
run_vitest_behavioral_tests(
|
||||
test_paths=mock_test_files,
|
||||
test_env={},
|
||||
cwd=tmpdir_path,
|
||||
timeout=15, # 15 second per-test timeout
|
||||
project_root=tmpdir_path,
|
||||
)
|
||||
|
||||
# Verify subprocess was called with a larger timeout
|
||||
assert mock_run.called
|
||||
call_kwargs = mock_run.call_args[1]
|
||||
subprocess_timeout = call_kwargs.get("timeout")
|
||||
|
||||
# Subprocess timeout should be at least 120 seconds (minimum)
|
||||
# or 10x the per-test timeout (150 seconds)
|
||||
assert subprocess_timeout >= 120, f"Expected subprocess timeout >= 120s, got {subprocess_timeout}s"
|
||||
assert subprocess_timeout >= 15 * 10, f"Expected subprocess timeout >= 150s (10x per-test), got {subprocess_timeout}s"
|
||||
|
||||
def test_vitest_line_profile_subprocess_timeout_larger_than_test_timeout(self):
|
||||
"""Test that subprocess timeout is larger than per-test timeout for Vitest line profile tests."""
|
||||
from codeflash.languages.javascript.vitest_runner import run_vitest_line_profile_tests
|
||||
from codeflash.models.models import TestFile, TestFiles
|
||||
from codeflash.models.test_type import TestType
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
tmpdir_path = Path(tmpdir)
|
||||
test_dir = tmpdir_path / "test"
|
||||
test_dir.mkdir()
|
||||
|
||||
(tmpdir_path / "package.json").write_text('{"name": "test", "devDependencies": {"vitest": "^1.0.0"}}')
|
||||
|
||||
test_file = test_dir / "test_func.test.ts"
|
||||
test_file.write_text("// test")
|
||||
|
||||
mock_test_files = TestFiles(
|
||||
test_files=[
|
||||
TestFile(
|
||||
original_file_path=test_file,
|
||||
instrumented_behavior_file_path=test_file,
|
||||
benchmarking_file_path=test_file,
|
||||
test_type=TestType.GENERATED_REGRESSION,
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
with patch("subprocess.run") as mock_run:
|
||||
mock_result = MagicMock()
|
||||
mock_result.stdout = ""
|
||||
mock_result.stderr = ""
|
||||
mock_result.returncode = 0
|
||||
mock_run.return_value = mock_result
|
||||
|
||||
run_vitest_line_profile_tests(
|
||||
test_paths=mock_test_files,
|
||||
test_env={},
|
||||
cwd=tmpdir_path,
|
||||
timeout=15,
|
||||
project_root=tmpdir_path,
|
||||
)
|
||||
|
||||
assert mock_run.called
|
||||
call_kwargs = mock_run.call_args[1]
|
||||
subprocess_timeout = call_kwargs.get("timeout")
|
||||
|
||||
assert subprocess_timeout >= 120, f"Expected subprocess timeout >= 120s, got {subprocess_timeout}s"
|
||||
|
||||
def test_vitest_default_subprocess_timeout_is_reasonable(self):
|
||||
"""Test that default subprocess timeout is at least 120 seconds when no timeout specified."""
|
||||
from codeflash.languages.javascript.vitest_runner import run_vitest_behavioral_tests
|
||||
from codeflash.models.models import TestFile, TestFiles
|
||||
from codeflash.models.test_type import TestType
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
tmpdir_path = Path(tmpdir)
|
||||
test_dir = tmpdir_path / "test"
|
||||
test_dir.mkdir()
|
||||
|
||||
(tmpdir_path / "package.json").write_text('{"name": "test", "devDependencies": {"vitest": "^1.0.0"}}')
|
||||
|
||||
test_file = test_dir / "test_func.test.ts"
|
||||
test_file.write_text("// test")
|
||||
|
||||
mock_test_files = TestFiles(
|
||||
test_files=[
|
||||
TestFile(
|
||||
original_file_path=test_file,
|
||||
instrumented_behavior_file_path=test_file,
|
||||
benchmarking_file_path=test_file,
|
||||
test_type=TestType.GENERATED_REGRESSION,
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
with patch("subprocess.run") as mock_run:
|
||||
mock_result = MagicMock()
|
||||
mock_result.stdout = ""
|
||||
mock_result.stderr = ""
|
||||
mock_result.returncode = 0
|
||||
mock_run.return_value = mock_result
|
||||
|
||||
# Run without specifying a timeout
|
||||
run_vitest_behavioral_tests(
|
||||
test_paths=mock_test_files,
|
||||
test_env={},
|
||||
cwd=tmpdir_path,
|
||||
project_root=tmpdir_path,
|
||||
)
|
||||
|
||||
assert mock_run.called
|
||||
call_kwargs = mock_run.call_args[1]
|
||||
subprocess_timeout = call_kwargs.get("timeout")
|
||||
|
||||
# Default should be at least 120 seconds (or 600 from the default)
|
||||
assert subprocess_timeout >= 120, f"Expected subprocess timeout >= 120s, got {subprocess_timeout}s"
|
||||
|
||||
|
||||
class TestVitestInternalLoopingConfiguration:
|
||||
"""Tests for Vitest internal looping (no external loop-runner)."""
|
||||
|
||||
def test_vitest_benchmarking_does_not_set_current_batch_env(self):
|
||||
"""Test that Vitest runner does NOT set CODEFLASH_PERF_CURRENT_BATCH.
|
||||
|
||||
This is critical: when CODEFLASH_PERF_CURRENT_BATCH is not set,
|
||||
capturePerf() in the npm package will do all loops internally
|
||||
(PERF_LOOP_COUNT iterations) instead of just PERF_BATCH_SIZE.
|
||||
"""
|
||||
from codeflash.languages.javascript.vitest_runner import run_vitest_benchmarking_tests
|
||||
from codeflash.models.models import TestFile, TestFiles
|
||||
from codeflash.models.test_type import TestType
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
tmpdir_path = Path(tmpdir)
|
||||
test_dir = tmpdir_path / "test"
|
||||
test_dir.mkdir()
|
||||
|
||||
(tmpdir_path / "package.json").write_text('{"name": "test", "devDependencies": {"vitest": "^1.0.0"}}')
|
||||
|
||||
test_file = test_dir / "test_func.test.ts"
|
||||
test_file.write_text("// perf test")
|
||||
|
||||
mock_test_files = TestFiles(
|
||||
test_files=[
|
||||
TestFile(
|
||||
original_file_path=test_file,
|
||||
instrumented_behavior_file_path=test_file,
|
||||
benchmarking_file_path=test_file,
|
||||
test_type=TestType.GENERATED_REGRESSION,
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
with patch("subprocess.run") as mock_run:
|
||||
mock_result = MagicMock()
|
||||
mock_result.stdout = ""
|
||||
mock_result.stderr = ""
|
||||
mock_result.returncode = 0
|
||||
mock_run.return_value = mock_result
|
||||
|
||||
run_vitest_benchmarking_tests(
|
||||
test_paths=mock_test_files,
|
||||
test_env={},
|
||||
cwd=tmpdir_path,
|
||||
project_root=tmpdir_path,
|
||||
max_loops=100,
|
||||
min_loops=5,
|
||||
)
|
||||
|
||||
assert mock_run.called
|
||||
call_kwargs = mock_run.call_args[1]
|
||||
env = call_kwargs.get("env", {})
|
||||
|
||||
# CODEFLASH_PERF_CURRENT_BATCH should NOT be set
|
||||
# This allows capturePerf() to do all loops internally
|
||||
assert "CODEFLASH_PERF_CURRENT_BATCH" not in env, (
|
||||
"CODEFLASH_PERF_CURRENT_BATCH should not be set for Vitest - "
|
||||
"internal looping relies on this being undefined"
|
||||
)
|
||||
|
||||
# But CODEFLASH_PERF_LOOP_COUNT should be set
|
||||
assert "CODEFLASH_PERF_LOOP_COUNT" in env, "CODEFLASH_PERF_LOOP_COUNT should be set"
|
||||
assert env["CODEFLASH_PERF_LOOP_COUNT"] == "100"
|
||||
|
||||
def test_vitest_benchmarking_sets_loop_configuration_env_vars(self):
|
||||
"""Test that Vitest benchmarking sets correct loop configuration environment variables."""
|
||||
from codeflash.languages.javascript.vitest_runner import run_vitest_benchmarking_tests
|
||||
from codeflash.models.models import TestFile, TestFiles
|
||||
from codeflash.models.test_type import TestType
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
tmpdir_path = Path(tmpdir)
|
||||
test_dir = tmpdir_path / "test"
|
||||
test_dir.mkdir()
|
||||
|
||||
(tmpdir_path / "package.json").write_text('{"name": "test", "devDependencies": {"vitest": "^1.0.0"}}')
|
||||
|
||||
test_file = test_dir / "test_func.test.ts"
|
||||
test_file.write_text("// perf test")
|
||||
|
||||
mock_test_files = TestFiles(
|
||||
test_files=[
|
||||
TestFile(
|
||||
original_file_path=test_file,
|
||||
instrumented_behavior_file_path=test_file,
|
||||
benchmarking_file_path=test_file,
|
||||
test_type=TestType.GENERATED_REGRESSION,
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
with patch("subprocess.run") as mock_run:
|
||||
mock_result = MagicMock()
|
||||
mock_result.stdout = ""
|
||||
mock_result.stderr = ""
|
||||
mock_result.returncode = 0
|
||||
mock_run.return_value = mock_result
|
||||
|
||||
run_vitest_benchmarking_tests(
|
||||
test_paths=mock_test_files,
|
||||
test_env={},
|
||||
cwd=tmpdir_path,
|
||||
project_root=tmpdir_path,
|
||||
max_loops=50,
|
||||
min_loops=10,
|
||||
target_duration_ms=5000,
|
||||
stability_check=True,
|
||||
)
|
||||
|
||||
assert mock_run.called
|
||||
call_kwargs = mock_run.call_args[1]
|
||||
env = call_kwargs.get("env", {})
|
||||
|
||||
# Verify all loop configuration env vars are set correctly
|
||||
assert env.get("CODEFLASH_PERF_LOOP_COUNT") == "50"
|
||||
assert env.get("CODEFLASH_PERF_MIN_LOOPS") == "10"
|
||||
assert env.get("CODEFLASH_PERF_TARGET_DURATION_MS") == "5000"
|
||||
assert env.get("CODEFLASH_PERF_STABILITY_CHECK") == "true"
|
||||
assert env.get("CODEFLASH_MODE") == "performance"
|
||||
|
||||
|
||||
class TestBundlerModuleResolutionFix:
|
||||
"""Tests for bundler moduleResolution compatibility fix."""
|
||||
|
||||
def test_detect_bundler_module_resolution_true(self):
|
||||
"""Test detection of bundler moduleResolution in tsconfig."""
|
||||
import json
|
||||
|
||||
from codeflash.languages.javascript.test_runner import _detect_bundler_module_resolution
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
tmpdir_path = Path(tmpdir)
|
||||
|
||||
# Create tsconfig with bundler moduleResolution
|
||||
tsconfig = {
|
||||
"compilerOptions": {
|
||||
"moduleResolution": "bundler",
|
||||
"module": "preserve",
|
||||
"target": "ES2022",
|
||||
}
|
||||
}
|
||||
(tmpdir_path / "tsconfig.json").write_text(json.dumps(tsconfig))
|
||||
|
||||
assert _detect_bundler_module_resolution(tmpdir_path) is True
|
||||
|
||||
def test_detect_bundler_module_resolution_false(self):
|
||||
"""Test detection returns false for Node moduleResolution."""
|
||||
import json
|
||||
|
||||
from codeflash.languages.javascript.test_runner import _detect_bundler_module_resolution
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
tmpdir_path = Path(tmpdir)
|
||||
|
||||
# Create tsconfig with Node moduleResolution
|
||||
tsconfig = {
|
||||
"compilerOptions": {
|
||||
"moduleResolution": "Node",
|
||||
"module": "ESNext",
|
||||
}
|
||||
}
|
||||
(tmpdir_path / "tsconfig.json").write_text(json.dumps(tsconfig))
|
||||
|
||||
assert _detect_bundler_module_resolution(tmpdir_path) is False
|
||||
|
||||
def test_detect_bundler_module_resolution_no_tsconfig(self):
|
||||
"""Test detection returns false when no tsconfig exists."""
|
||||
from codeflash.languages.javascript.test_runner import _detect_bundler_module_resolution
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
tmpdir_path = Path(tmpdir)
|
||||
assert _detect_bundler_module_resolution(tmpdir_path) is False
|
||||
|
||||
def test_detect_bundler_module_resolution_extended_config(self):
|
||||
"""Test detection works with extended tsconfig files."""
|
||||
import json
|
||||
|
||||
from codeflash.languages.javascript.test_runner import _detect_bundler_module_resolution
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
tmpdir_path = Path(tmpdir)
|
||||
|
||||
# Create a base config with bundler in a subdirectory (simulating node_modules)
|
||||
node_modules = tmpdir_path / "node_modules" / "@myorg" / "tsconfig"
|
||||
node_modules.mkdir(parents=True)
|
||||
base_tsconfig = {
|
||||
"compilerOptions": {
|
||||
"moduleResolution": "bundler",
|
||||
"module": "preserve",
|
||||
}
|
||||
}
|
||||
(node_modules / "tsconfig.json").write_text(json.dumps(base_tsconfig))
|
||||
|
||||
# Create a project tsconfig that extends the base
|
||||
project_tsconfig = {
|
||||
"extends": "@myorg/tsconfig/tsconfig.json",
|
||||
"compilerOptions": {
|
||||
"target": "ES2022",
|
||||
}
|
||||
}
|
||||
(tmpdir_path / "tsconfig.json").write_text(json.dumps(project_tsconfig))
|
||||
|
||||
# Should detect bundler from extended config
|
||||
assert _detect_bundler_module_resolution(tmpdir_path) is True
|
||||
|
||||
def test_create_codeflash_tsconfig(self):
|
||||
"""Test creation of codeflash-compatible tsconfig."""
|
||||
import json
|
||||
|
||||
from codeflash.languages.javascript.test_runner import _create_codeflash_tsconfig
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
tmpdir_path = Path(tmpdir)
|
||||
|
||||
# Create original tsconfig
|
||||
original_tsconfig = {
|
||||
"compilerOptions": {
|
||||
"moduleResolution": "bundler",
|
||||
"module": "preserve",
|
||||
"target": "ES2022",
|
||||
},
|
||||
"include": ["src/**/*.ts"],
|
||||
"exclude": ["node_modules"],
|
||||
}
|
||||
(tmpdir_path / "tsconfig.json").write_text(json.dumps(original_tsconfig))
|
||||
|
||||
# Create codeflash tsconfig
|
||||
result_path = _create_codeflash_tsconfig(tmpdir_path)
|
||||
|
||||
assert result_path.exists()
|
||||
assert result_path.name == "tsconfig.codeflash.json"
|
||||
|
||||
# Verify contents
|
||||
codeflash_tsconfig = json.loads(result_path.read_text())
|
||||
assert codeflash_tsconfig["extends"] == "./tsconfig.json"
|
||||
assert codeflash_tsconfig["compilerOptions"]["moduleResolution"] == "Node"
|
||||
assert "include" in codeflash_tsconfig
|
||||
|
||||
def test_create_codeflash_jest_config(self):
|
||||
"""Test creation of codeflash Jest config."""
|
||||
from codeflash.languages.javascript.test_runner import _create_codeflash_jest_config
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
tmpdir_path = Path(tmpdir)
|
||||
|
||||
# Create codeflash Jest config without original
|
||||
result_path = _create_codeflash_jest_config(tmpdir_path, None)
|
||||
|
||||
assert result_path is not None
|
||||
assert result_path.exists()
|
||||
assert result_path.name == "jest.codeflash.config.js"
|
||||
|
||||
# Verify it contains the tsconfig reference
|
||||
content = result_path.read_text()
|
||||
assert "tsconfig.codeflash.json" in content
|
||||
assert "ts-jest" in content
|
||||
|
||||
def test_get_jest_config_for_project_with_bundler(self):
|
||||
"""Test that bundler projects get codeflash Jest config."""
|
||||
import json
|
||||
|
||||
from codeflash.languages.javascript.test_runner import _get_jest_config_for_project
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
tmpdir_path = Path(tmpdir)
|
||||
|
||||
# Create tsconfig with bundler
|
||||
tsconfig = {
|
||||
"compilerOptions": {
|
||||
"moduleResolution": "bundler",
|
||||
"module": "preserve",
|
||||
}
|
||||
}
|
||||
(tmpdir_path / "tsconfig.json").write_text(json.dumps(tsconfig))
|
||||
(tmpdir_path / "package.json").write_text('{"name": "test"}')
|
||||
|
||||
result = _get_jest_config_for_project(tmpdir_path)
|
||||
|
||||
assert result is not None
|
||||
assert result.name == "jest.codeflash.config.js"
|
||||
# Also verify tsconfig.codeflash.json was created
|
||||
assert (tmpdir_path / "tsconfig.codeflash.json").exists()
|
||||
|
||||
def test_get_jest_config_for_project_without_bundler(self):
|
||||
"""Test that non-bundler projects use original Jest config."""
|
||||
import json
|
||||
|
||||
from codeflash.languages.javascript.test_runner import _get_jest_config_for_project
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
tmpdir_path = Path(tmpdir)
|
||||
|
||||
# Create tsconfig with Node moduleResolution
|
||||
tsconfig = {
|
||||
"compilerOptions": {
|
||||
"moduleResolution": "Node",
|
||||
"module": "ESNext",
|
||||
}
|
||||
}
|
||||
(tmpdir_path / "tsconfig.json").write_text(json.dumps(tsconfig))
|
||||
(tmpdir_path / "package.json").write_text('{"name": "test"}')
|
||||
|
||||
# Create original Jest config
|
||||
(tmpdir_path / "jest.config.js").write_text("module.exports = {};")
|
||||
|
||||
result = _get_jest_config_for_project(tmpdir_path)
|
||||
|
||||
assert result is not None
|
||||
assert result.name == "jest.config.js"
|
||||
# Verify codeflash configs were NOT created
|
||||
assert not (tmpdir_path / "jest.codeflash.config.js").exists()
|
||||
assert not (tmpdir_path / "tsconfig.codeflash.json").exists()
|
||||
|
|
@ -39,7 +39,7 @@ class TestCodeExtractorCJS:
|
|||
calculator_file = cjs_project / "calculator.js"
|
||||
functions = js_support.discover_functions(calculator_file)
|
||||
|
||||
method_names = {f.name for f in functions}
|
||||
method_names = {f.function_name for f in functions}
|
||||
|
||||
expected_methods = {"calculateCompoundInterest", "permutation", "quickAdd"}
|
||||
assert method_names == expected_methods, f"Expected methods {expected_methods}, got {method_names}"
|
||||
|
|
@ -51,15 +51,15 @@ class TestCodeExtractorCJS:
|
|||
|
||||
for func in functions:
|
||||
# All methods should belong to Calculator class
|
||||
assert func.is_method is True, f"{func.name} should be a method"
|
||||
assert func.class_name == "Calculator", f"{func.name} should belong to Calculator, got {func.class_name}"
|
||||
assert func.is_method is True, f"{func.function_name} should be a method"
|
||||
assert func.class_name == "Calculator", f"{func.function_name} should belong to Calculator, got {func.class_name}"
|
||||
|
||||
def test_extract_permutation_code(self, js_support, cjs_project):
|
||||
"""Test permutation method code extraction."""
|
||||
calculator_file = cjs_project / "calculator.js"
|
||||
functions = js_support.discover_functions(calculator_file)
|
||||
|
||||
permutation_func = next(f for f in functions if f.name == "permutation")
|
||||
permutation_func = next(f for f in functions if f.function_name == "permutation")
|
||||
|
||||
context = js_support.extract_code_context(
|
||||
function=permutation_func, project_root=cjs_project, module_root=cjs_project
|
||||
|
|
@ -95,7 +95,7 @@ class Calculator {
|
|||
calculator_file = cjs_project / "calculator.js"
|
||||
functions = js_support.discover_functions(calculator_file)
|
||||
|
||||
permutation_func = next(f for f in functions if f.name == "permutation")
|
||||
permutation_func = next(f for f in functions if f.function_name == "permutation")
|
||||
|
||||
context = js_support.extract_code_context(
|
||||
function=permutation_func, project_root=cjs_project, module_root=cjs_project
|
||||
|
|
@ -136,7 +136,7 @@ function factorial(n) {
|
|||
calculator_file = cjs_project / "calculator.js"
|
||||
functions = js_support.discover_functions(calculator_file)
|
||||
|
||||
compound_func = next(f for f in functions if f.name == "calculateCompoundInterest")
|
||||
compound_func = next(f for f in functions if f.function_name == "calculateCompoundInterest")
|
||||
|
||||
context = js_support.extract_code_context(
|
||||
function=compound_func, project_root=cjs_project, module_root=cjs_project
|
||||
|
|
@ -182,7 +182,7 @@ class Calculator {
|
|||
calculator_file = cjs_project / "calculator.js"
|
||||
functions = js_support.discover_functions(calculator_file)
|
||||
|
||||
compound_func = next(f for f in functions if f.name == "calculateCompoundInterest")
|
||||
compound_func = next(f for f in functions if f.function_name == "calculateCompoundInterest")
|
||||
|
||||
context = js_support.extract_code_context(
|
||||
function=compound_func, project_root=cjs_project, module_root=cjs_project
|
||||
|
|
@ -266,7 +266,7 @@ function validateInput(value, name) {
|
|||
calculator_file = cjs_project / "calculator.js"
|
||||
functions = js_support.discover_functions(calculator_file)
|
||||
|
||||
compound_func = next(f for f in functions if f.name == "calculateCompoundInterest")
|
||||
compound_func = next(f for f in functions if f.function_name == "calculateCompoundInterest")
|
||||
|
||||
context = js_support.extract_code_context(
|
||||
function=compound_func, project_root=cjs_project, module_root=cjs_project
|
||||
|
|
@ -287,7 +287,7 @@ function validateInput(value, name) {
|
|||
calculator_file = cjs_project / "calculator.js"
|
||||
functions = js_support.discover_functions(calculator_file)
|
||||
|
||||
quick_add_func = next(f for f in functions if f.name == "quickAdd")
|
||||
quick_add_func = next(f for f in functions if f.function_name == "quickAdd")
|
||||
|
||||
context = js_support.extract_code_context(
|
||||
function=quick_add_func, project_root=cjs_project, module_root=cjs_project
|
||||
|
|
@ -352,7 +352,7 @@ class TestCodeExtractorESM:
|
|||
calculator_file = esm_project / "calculator.js"
|
||||
functions = js_support.discover_functions(calculator_file)
|
||||
|
||||
method_names = {f.name for f in functions}
|
||||
method_names = {f.function_name for f in functions}
|
||||
|
||||
# Should find same methods as CJS version
|
||||
expected_methods = {"calculateCompoundInterest", "permutation", "quickAdd"}
|
||||
|
|
@ -363,7 +363,7 @@ class TestCodeExtractorESM:
|
|||
calculator_file = esm_project / "calculator.js"
|
||||
functions = js_support.discover_functions(calculator_file)
|
||||
|
||||
permutation_func = next(f for f in functions if f.name == "permutation")
|
||||
permutation_func = next(f for f in functions if f.function_name == "permutation")
|
||||
|
||||
context = js_support.extract_code_context(
|
||||
function=permutation_func, project_root=esm_project, module_root=esm_project
|
||||
|
|
@ -413,7 +413,7 @@ export function factorial(n) {
|
|||
calculator_file = esm_project / "calculator.js"
|
||||
functions = js_support.discover_functions(calculator_file)
|
||||
|
||||
compound_func = next(f for f in functions if f.name == "calculateCompoundInterest")
|
||||
compound_func = next(f for f in functions if f.function_name == "calculateCompoundInterest")
|
||||
|
||||
context = js_support.extract_code_context(
|
||||
function=compound_func, project_root=esm_project, module_root=esm_project
|
||||
|
|
@ -539,7 +539,7 @@ class TestCodeExtractorTypeScript:
|
|||
calculator_file = ts_project / "calculator.ts"
|
||||
functions = ts_support.discover_functions(calculator_file)
|
||||
|
||||
method_names = {f.name for f in functions}
|
||||
method_names = {f.function_name for f in functions}
|
||||
|
||||
# TypeScript has additional getHistory method
|
||||
expected_methods = {"calculateCompoundInterest", "permutation", "getHistory", "quickAdd"}
|
||||
|
|
@ -550,7 +550,7 @@ class TestCodeExtractorTypeScript:
|
|||
calculator_file = ts_project / "calculator.ts"
|
||||
functions = ts_support.discover_functions(calculator_file)
|
||||
|
||||
permutation_func = next(f for f in functions if f.name == "permutation")
|
||||
permutation_func = next(f for f in functions if f.function_name == "permutation")
|
||||
|
||||
context = ts_support.extract_code_context(
|
||||
function=permutation_func, project_root=ts_project, module_root=ts_project
|
||||
|
|
@ -603,7 +603,7 @@ export function factorial(n: number): number {
|
|||
calculator_file = ts_project / "calculator.ts"
|
||||
functions = ts_support.discover_functions(calculator_file)
|
||||
|
||||
compound_func = next(f for f in functions if f.name == "calculateCompoundInterest")
|
||||
compound_func = next(f for f in functions if f.function_name == "calculateCompoundInterest")
|
||||
|
||||
context = ts_support.extract_code_context(
|
||||
function=compound_func, project_root=ts_project, module_root=ts_project
|
||||
|
|
@ -712,7 +712,7 @@ module.exports = { standalone };
|
|||
test_file.write_text(source)
|
||||
|
||||
functions = js_support.discover_functions(test_file)
|
||||
func = next(f for f in functions if f.name == "standalone")
|
||||
func = next(f for f in functions if f.function_name == "standalone")
|
||||
|
||||
context = js_support.extract_code_context(function=func, project_root=tmp_path, module_root=tmp_path)
|
||||
|
||||
|
|
@ -745,7 +745,7 @@ module.exports = { processArray };
|
|||
test_file.write_text(source)
|
||||
|
||||
functions = js_support.discover_functions(test_file)
|
||||
func = next(f for f in functions if f.name == "processArray")
|
||||
func = next(f for f in functions if f.function_name == "processArray")
|
||||
|
||||
context = js_support.extract_code_context(function=func, project_root=tmp_path, module_root=tmp_path)
|
||||
|
||||
|
|
@ -780,7 +780,7 @@ module.exports = { fibonacci };
|
|||
test_file.write_text(source)
|
||||
|
||||
functions = js_support.discover_functions(test_file)
|
||||
func = next(f for f in functions if f.name == "fibonacci")
|
||||
func = next(f for f in functions if f.function_name == "fibonacci")
|
||||
|
||||
context = js_support.extract_code_context(function=func, project_root=tmp_path, module_root=tmp_path)
|
||||
|
||||
|
|
@ -813,7 +813,7 @@ module.exports = { processValue };
|
|||
test_file.write_text(source)
|
||||
|
||||
functions = js_support.discover_functions(test_file)
|
||||
func = next(f for f in functions if f.name == "processValue")
|
||||
func = next(f for f in functions if f.function_name == "processValue")
|
||||
|
||||
context = js_support.extract_code_context(function=func, project_root=tmp_path, module_root=tmp_path)
|
||||
|
||||
|
|
@ -871,7 +871,7 @@ module.exports = { Counter };
|
|||
test_file.write_text(source)
|
||||
|
||||
functions = js_support.discover_functions(test_file)
|
||||
increment_func = next(f for f in functions if f.name == "increment")
|
||||
increment_func = next(f for f in functions if f.function_name == "increment")
|
||||
|
||||
context = js_support.extract_code_context(function=increment_func, project_root=tmp_path, module_root=tmp_path)
|
||||
|
||||
|
|
@ -910,7 +910,7 @@ module.exports = { MathUtils };
|
|||
test_file.write_text(source)
|
||||
|
||||
functions = js_support.discover_functions(test_file)
|
||||
add_func = next(f for f in functions if f.name == "add")
|
||||
add_func = next(f for f in functions if f.function_name == "add")
|
||||
|
||||
context = js_support.extract_code_context(function=add_func, project_root=tmp_path, module_root=tmp_path)
|
||||
|
||||
|
|
@ -948,7 +948,7 @@ export { User };
|
|||
test_file.write_text(source)
|
||||
|
||||
functions = ts_support.discover_functions(test_file)
|
||||
get_name_func = next(f for f in functions if f.name == "getName")
|
||||
get_name_func = next(f for f in functions if f.function_name == "getName")
|
||||
|
||||
context = ts_support.extract_code_context(function=get_name_func, project_root=tmp_path, module_root=tmp_path)
|
||||
|
||||
|
|
@ -989,7 +989,7 @@ export { Config };
|
|||
test_file.write_text(source)
|
||||
|
||||
functions = ts_support.discover_functions(test_file)
|
||||
get_url_func = next(f for f in functions if f.name == "getUrl")
|
||||
get_url_func = next(f for f in functions if f.function_name == "getUrl")
|
||||
|
||||
context = ts_support.extract_code_context(function=get_url_func, project_root=tmp_path, module_root=tmp_path)
|
||||
|
||||
|
|
@ -1030,7 +1030,7 @@ module.exports = { Logger };
|
|||
test_file.write_text(source)
|
||||
|
||||
functions = js_support.discover_functions(test_file)
|
||||
get_prefix_func = next(f for f in functions if f.name == "getPrefix")
|
||||
get_prefix_func = next(f for f in functions if f.function_name == "getPrefix")
|
||||
|
||||
context = js_support.extract_code_context(function=get_prefix_func, project_root=tmp_path, module_root=tmp_path)
|
||||
|
||||
|
|
@ -1072,7 +1072,7 @@ module.exports = { Factory };
|
|||
test_file.write_text(source)
|
||||
|
||||
functions = js_support.discover_functions(test_file)
|
||||
create_func = next(f for f in functions if f.name == "create")
|
||||
create_func = next(f for f in functions if f.function_name == "create")
|
||||
|
||||
context = js_support.extract_code_context(function=create_func, project_root=tmp_path, module_root=tmp_path)
|
||||
|
||||
|
|
@ -1114,19 +1114,20 @@ class TestCodeExtractorIntegration:
|
|||
calculator_file = cjs_project / "calculator.js"
|
||||
|
||||
functions = js_support.discover_functions(calculator_file)
|
||||
target = next(f for f in functions if f.name == "permutation")
|
||||
target = next(f for f in functions if f.function_name == "permutation")
|
||||
|
||||
parents = [FunctionParent(name=p.name, type=p.type) for p in target.parents]
|
||||
|
||||
func = FunctionToOptimize(
|
||||
function_name=target.name,
|
||||
function_name=target.function_name,
|
||||
file_path=target.file_path,
|
||||
parents=parents,
|
||||
starting_line=target.start_line,
|
||||
ending_line=target.end_line,
|
||||
starting_col=target.start_col,
|
||||
ending_col=target.end_col,
|
||||
starting_line=target.starting_line,
|
||||
ending_line=target.ending_line,
|
||||
starting_col=target.starting_col,
|
||||
ending_col=target.ending_col,
|
||||
is_async=target.is_async,
|
||||
is_method=target.is_method,
|
||||
language=target.language,
|
||||
)
|
||||
|
||||
|
|
@ -1223,7 +1224,7 @@ export { distance };
|
|||
test_file.write_text(source)
|
||||
|
||||
functions = ts_support.discover_functions(test_file)
|
||||
distance_func = next(f for f in functions if f.name == "distance")
|
||||
distance_func = next(f for f in functions if f.function_name == "distance")
|
||||
|
||||
context = ts_support.extract_code_context(function=distance_func, project_root=tmp_path, module_root=tmp_path)
|
||||
|
||||
|
|
@ -1267,7 +1268,7 @@ export { processStatus };
|
|||
test_file.write_text(source)
|
||||
|
||||
functions = ts_support.discover_functions(test_file)
|
||||
process_func = next(f for f in functions if f.name == "processStatus")
|
||||
process_func = next(f for f in functions if f.function_name == "processStatus")
|
||||
|
||||
context = ts_support.extract_code_context(function=process_func, project_root=tmp_path, module_root=tmp_path)
|
||||
|
||||
|
|
@ -1304,7 +1305,7 @@ export { compute };
|
|||
test_file.write_text(source)
|
||||
|
||||
functions = ts_support.discover_functions(test_file)
|
||||
compute_func = next(f for f in functions if f.name == "compute")
|
||||
compute_func = next(f for f in functions if f.function_name == "compute")
|
||||
|
||||
context = ts_support.extract_code_context(function=compute_func, project_root=tmp_path, module_root=tmp_path)
|
||||
|
||||
|
|
@ -1348,7 +1349,7 @@ export { Service };
|
|||
test_file.write_text(source)
|
||||
|
||||
functions = ts_support.discover_functions(test_file)
|
||||
get_timeout_func = next(f for f in functions if f.name == "getTimeout")
|
||||
get_timeout_func = next(f for f in functions if f.function_name == "getTimeout")
|
||||
|
||||
context = ts_support.extract_code_context(
|
||||
function=get_timeout_func, project_root=tmp_path, module_root=tmp_path
|
||||
|
|
@ -1381,7 +1382,7 @@ export { add };
|
|||
test_file.write_text(source)
|
||||
|
||||
functions = ts_support.discover_functions(test_file)
|
||||
add_func = next(f for f in functions if f.name == "add")
|
||||
add_func = next(f for f in functions if f.function_name == "add")
|
||||
|
||||
context = ts_support.extract_code_context(function=add_func, project_root=tmp_path, module_root=tmp_path)
|
||||
|
||||
|
|
@ -1414,7 +1415,7 @@ export { createRect };
|
|||
test_file.write_text(source)
|
||||
|
||||
functions = ts_support.discover_functions(test_file)
|
||||
create_rect_func = next(f for f in functions if f.name == "createRect")
|
||||
create_rect_func = next(f for f in functions if f.function_name == "createRect")
|
||||
|
||||
context = ts_support.extract_code_context(
|
||||
function=create_rect_func, project_root=tmp_path, module_root=tmp_path
|
||||
|
|
@ -1462,7 +1463,7 @@ export { calculateDistance };
|
|||
""")
|
||||
|
||||
functions = ts_support.discover_functions(geometry_file)
|
||||
calc_distance_func = next(f for f in functions if f.name == "calculateDistance")
|
||||
calc_distance_func = next(f for f in functions if f.function_name == "calculateDistance")
|
||||
|
||||
context = ts_support.extract_code_context(
|
||||
function=calc_distance_func, project_root=ts_types_project, module_root=ts_types_project
|
||||
|
|
@ -1515,7 +1516,7 @@ export { greetUser };
|
|||
test_file.write_text(source)
|
||||
|
||||
functions = ts_support.discover_functions(test_file)
|
||||
greet_func = next(f for f in functions if f.name == "greetUser")
|
||||
greet_func = next(f for f in functions if f.function_name == "greetUser")
|
||||
|
||||
context = ts_support.extract_code_context(function=greet_func, project_root=tmp_path, module_root=tmp_path)
|
||||
|
||||
|
|
|
|||
|
|
@ -167,12 +167,12 @@ class TestCommonJSToESMConversion:
|
|||
f"CJS to ESM conversion failed.\nInput: {code}\nExpected: {expected}\nGot: {result}"
|
||||
)
|
||||
|
||||
def test_convert_relative_require_adds_extension(self):
|
||||
"""Test that relative imports get .js extension added - exact output."""
|
||||
def test_convert_relative_require_preserves_path(self):
|
||||
"""Test that relative imports preserve the original path without adding extension."""
|
||||
code = "const { helper } = require('./utils');"
|
||||
result = convert_commonjs_to_esm(code)
|
||||
|
||||
expected = "import { helper } from './utils.js';"
|
||||
expected = "import { helper } from './utils';"
|
||||
assert result.strip() == expected, (
|
||||
f"CJS to ESM conversion failed.\nInput: {code}\nExpected: {expected}\nGot: {result}"
|
||||
)
|
||||
|
|
@ -182,7 +182,7 @@ class TestCommonJSToESMConversion:
|
|||
code = "const myHelper = require('./utils').helperFunction;"
|
||||
result = convert_commonjs_to_esm(code)
|
||||
|
||||
expected = "import { helperFunction as myHelper } from './utils.js';"
|
||||
expected = "import { helperFunction as myHelper } from './utils';"
|
||||
assert result.strip() == expected, (
|
||||
f"CJS to ESM conversion failed.\nInput: {code}\nExpected: {expected}\nGot: {result}"
|
||||
)
|
||||
|
|
@ -192,7 +192,7 @@ class TestCommonJSToESMConversion:
|
|||
code = "const MyClass = require('./class').default;"
|
||||
result = convert_commonjs_to_esm(code)
|
||||
|
||||
expected = "import MyClass from './class.js';"
|
||||
expected = "import MyClass from './class';"
|
||||
assert result.strip() == expected, (
|
||||
f"CJS to ESM conversion failed.\nInput: {code}\nExpected: {expected}\nGot: {result}"
|
||||
)
|
||||
|
|
@ -207,7 +207,7 @@ const path = require('path');"""
|
|||
result = convert_commonjs_to_esm(code)
|
||||
|
||||
expected = """\
|
||||
import { add, subtract } from './math.js';
|
||||
import { add, subtract } from './math';
|
||||
import lodash from 'lodash';
|
||||
import path from 'path';"""
|
||||
|
||||
|
|
@ -316,7 +316,7 @@ function process() {
|
|||
result = ensure_module_system_compatibility(code, ModuleSystem.ES_MODULE)
|
||||
|
||||
# Should convert require to import
|
||||
assert "import { helper } from './helpers.js';" in result
|
||||
assert "import { helper } from './helpers';" in result
|
||||
assert "require" not in result, f"require should be converted to import. Got:\n{result}"
|
||||
|
||||
def test_convert_mixed_code_to_commonjs(self):
|
||||
|
|
@ -711,7 +711,7 @@ module.exports = { targetFunction, otherFunction };
|
|||
file_path.write_text(original_source, encoding="utf-8")
|
||||
|
||||
functions = js_support.discover_functions(file_path)
|
||||
target_func = next(f for f in functions if f.name == "targetFunction")
|
||||
target_func = next(f for f in functions if f.function_name == "targetFunction")
|
||||
|
||||
optimized_code = """\
|
||||
function targetFunction(x) {
|
||||
|
|
@ -763,7 +763,7 @@ class Calculator {
|
|||
file_path.write_text(original_source, encoding="utf-8")
|
||||
|
||||
functions = js_support.discover_functions(file_path)
|
||||
add_method = next(f for f in functions if f.name == "add")
|
||||
add_method = next(f for f in functions if f.function_name == "add")
|
||||
|
||||
# Optimized version provided in class context
|
||||
optimized_code = """\
|
||||
|
|
@ -826,7 +826,7 @@ class DataProcessor {
|
|||
file_path.write_text(original_source, encoding="utf-8")
|
||||
|
||||
functions = js_support.discover_functions(file_path)
|
||||
process_method = next(f for f in functions if f.name == "process")
|
||||
process_method = next(f for f in functions if f.function_name == "process")
|
||||
|
||||
optimized_code = """\
|
||||
class DataProcessor {
|
||||
|
|
@ -948,7 +948,7 @@ class Cache {
|
|||
file_path.write_text(original_source, encoding="utf-8")
|
||||
|
||||
functions = js_support.discover_functions(file_path)
|
||||
get_method = next(f for f in functions if f.name == "get")
|
||||
get_method = next(f for f in functions if f.function_name == "get")
|
||||
|
||||
optimized_code = """\
|
||||
class Cache {
|
||||
|
|
@ -1050,7 +1050,7 @@ class ApiClient {
|
|||
file_path.write_text(original_source, encoding="utf-8")
|
||||
|
||||
functions = js_support.discover_functions(file_path)
|
||||
get_method = next(f for f in functions if f.name == "get")
|
||||
get_method = next(f for f in functions if f.function_name == "get")
|
||||
|
||||
optimized_code = """\
|
||||
class ApiClient {
|
||||
|
|
@ -1181,7 +1181,7 @@ class Container<T> {
|
|||
file_path.write_text(original_source, encoding="utf-8")
|
||||
|
||||
functions = ts_support.discover_functions(file_path)
|
||||
get_all_method = next(f for f in functions if f.name == "getAll")
|
||||
get_all_method = next(f for f in functions if f.function_name == "getAll")
|
||||
|
||||
optimized_code = """\
|
||||
class Container<T> {
|
||||
|
|
@ -1234,7 +1234,7 @@ function createUser(name: string, email: string): User {
|
|||
file_path.write_text(original_source, encoding="utf-8")
|
||||
|
||||
functions = ts_support.discover_functions(file_path)
|
||||
func = next(f for f in functions if f.name == "createUser")
|
||||
func = next(f for f in functions if f.function_name == "createUser")
|
||||
|
||||
optimized_code = """\
|
||||
function createUser(name: string, email: string): User {
|
||||
|
|
@ -1289,7 +1289,7 @@ function processItems(items) {
|
|||
file_path.write_text(original_source, encoding="utf-8")
|
||||
|
||||
functions = js_support.discover_functions(file_path)
|
||||
process_func = next(f for f in functions if f.name == "processItems")
|
||||
process_func = next(f for f in functions if f.function_name == "processItems")
|
||||
|
||||
optimized_code = """\
|
||||
function processItems(items) {
|
||||
|
|
@ -1336,7 +1336,7 @@ class MathUtils {
|
|||
|
||||
# First replacement: sum method
|
||||
functions = js_support.discover_functions(file_path)
|
||||
sum_method = next(f for f in functions if f.name == "sum")
|
||||
sum_method = next(f for f in functions if f.function_name == "sum")
|
||||
|
||||
optimized_sum = """\
|
||||
class MathUtils {
|
||||
|
|
@ -1554,7 +1554,7 @@ module.exports = { main, helper };
|
|||
file_path.write_text(original_source, encoding="utf-8")
|
||||
|
||||
functions = js_support.discover_functions(file_path)
|
||||
main_func = next(f for f in functions if f.name == "main")
|
||||
main_func = next(f for f in functions if f.function_name == "main")
|
||||
|
||||
optimized_code = """\
|
||||
function main(data) {
|
||||
|
|
@ -1597,7 +1597,7 @@ export function main(data) {
|
|||
file_path.write_text(original_source, encoding="utf-8")
|
||||
|
||||
functions = js_support.discover_functions(file_path)
|
||||
main_func = next(f for f in functions if f.name == "main")
|
||||
main_func = next(f for f in functions if f.function_name == "main")
|
||||
|
||||
optimized_code = """\
|
||||
export function main(data) {
|
||||
|
|
@ -1756,7 +1756,7 @@ export class DataProcessor<T> {
|
|||
# find function
|
||||
target_func_info = None
|
||||
for func in functions:
|
||||
if func.name == target_func and func.parents[0].name == parent_class:
|
||||
if func.function_name == target_func and func.parents[0].name == parent_class:
|
||||
target_func_info = func
|
||||
break
|
||||
assert target_func_info is not None
|
||||
|
|
@ -1893,3 +1893,252 @@ export class DataProcessor<T> {
|
|||
}
|
||||
"""
|
||||
|
||||
|
||||
|
||||
class TestNewVariableFromOptimizedCode:
|
||||
"""Tests for handling new variables introduced in optimized code."""
|
||||
|
||||
def test_new_bound_method_variable_added_after_referenced_constant(self, ts_support, temp_project):
|
||||
"""Test that a new variable binding a method is added after the constant it references.
|
||||
|
||||
When optimized code introduces a new module-level variable (like `_has`) that
|
||||
references an existing constant (like `CODEFLASH_EMPLOYEE_GITHUB_IDS`), the
|
||||
replacement should:
|
||||
1. Add the new variable after the constant it references
|
||||
2. Replace the function with the optimized version
|
||||
"""
|
||||
from codeflash.models.models import CodeStringsMarkdown, CodeString
|
||||
|
||||
original_source = '''\
|
||||
const CODEFLASH_EMPLOYEE_GITHUB_IDS = new Set([
|
||||
"1234",
|
||||
]);
|
||||
|
||||
export function isCodeflashEmployee(userId: string): boolean {
|
||||
return CODEFLASH_EMPLOYEE_GITHUB_IDS.has(userId);
|
||||
}
|
||||
'''
|
||||
file_path = temp_project / "auth.ts"
|
||||
file_path.write_text(original_source, encoding="utf-8")
|
||||
|
||||
# Optimized code introduces a bound method variable for performance
|
||||
optimized_code = '''const _has: (id: string) => boolean = CODEFLASH_EMPLOYEE_GITHUB_IDS.has.bind(
|
||||
CODEFLASH_EMPLOYEE_GITHUB_IDS
|
||||
);
|
||||
|
||||
export function isCodeflashEmployee(userId: string): boolean {
|
||||
return _has(userId);
|
||||
}
|
||||
'''
|
||||
|
||||
code_markdown = CodeStringsMarkdown(
|
||||
code_strings=[
|
||||
CodeString(
|
||||
code=optimized_code,
|
||||
file_path=Path("auth.ts"),
|
||||
language="typescript"
|
||||
)
|
||||
],
|
||||
language="typescript"
|
||||
)
|
||||
|
||||
replaced = replace_function_definitions_for_language(
|
||||
["isCodeflashEmployee"],
|
||||
code_markdown,
|
||||
file_path,
|
||||
temp_project,
|
||||
)
|
||||
|
||||
assert replaced
|
||||
result = file_path.read_text()
|
||||
|
||||
# Expected result for strict equality check
|
||||
expected_result = '''\
|
||||
const CODEFLASH_EMPLOYEE_GITHUB_IDS = new Set([
|
||||
"1234",
|
||||
]);
|
||||
|
||||
const _has: (id: string) => boolean = CODEFLASH_EMPLOYEE_GITHUB_IDS.has.bind(
|
||||
CODEFLASH_EMPLOYEE_GITHUB_IDS
|
||||
);
|
||||
|
||||
export function isCodeflashEmployee(userId: string): boolean {
|
||||
return _has(userId);
|
||||
}
|
||||
'''
|
||||
assert result == expected_result, (
|
||||
f"Result does not match expected output.\n"
|
||||
f"Expected:\n{expected_result}\n\n"
|
||||
f"Got:\n{result}"
|
||||
)
|
||||
|
||||
|
||||
class TestImportedTypeNotDuplicated:
|
||||
"""Tests to ensure imported types are not duplicated during code replacement.
|
||||
|
||||
When a type is already imported in the original file, it should NOT be
|
||||
added as a new declaration from the optimized code, even if the optimized
|
||||
code contains the type definition (because it was provided as context).
|
||||
|
||||
See: https://github.com/codeflash-ai/appsmith/pull/20
|
||||
"""
|
||||
|
||||
def test_imported_interface_not_added_as_declaration(self, ts_support, temp_project):
|
||||
"""Test that an imported interface is not duplicated in the output.
|
||||
|
||||
When TreeNode is imported from another file and the optimized code
|
||||
contains the TreeNode interface definition (from read-only context),
|
||||
the replacement should NOT add the interface to the original file.
|
||||
"""
|
||||
from codeflash.models.models import CodeStringsMarkdown, CodeString
|
||||
|
||||
# Original source imports TreeNode
|
||||
original_source = """\
|
||||
import type { TreeNode } from "./constants";
|
||||
|
||||
export function getNearestAbove(
|
||||
tree: Record<string, TreeNode>,
|
||||
effectedBoxId: string,
|
||||
) {
|
||||
const aboves = tree[effectedBoxId].aboves;
|
||||
return aboves.reduce((prev: string[], next: string) => {
|
||||
if (!prev[0]) return [next];
|
||||
let nextBottomRow = tree[next].bottomRow;
|
||||
let prevBottomRow = tree[prev[0]].bottomRow;
|
||||
if (nextBottomRow > prevBottomRow) return [next];
|
||||
return prev;
|
||||
}, []);
|
||||
}
|
||||
"""
|
||||
file_path = temp_project / "helpers.ts"
|
||||
file_path.write_text(original_source, encoding="utf-8")
|
||||
|
||||
# Optimized code includes the TreeNode interface (from read-only context)
|
||||
# This simulates what the AI might return when type definitions are included in context
|
||||
optimized_code_with_interface = """\
|
||||
interface TreeNode {
|
||||
aboves: string[];
|
||||
belows: string[];
|
||||
topRow: number;
|
||||
bottomRow: number;
|
||||
}
|
||||
|
||||
export function getNearestAbove(
|
||||
tree: Record<string, TreeNode>,
|
||||
effectedBoxId: string,
|
||||
) {
|
||||
const aboves = tree[effectedBoxId].aboves;
|
||||
return aboves.reduce((prev: string[], next: string) => {
|
||||
if (!prev[0]) return [next];
|
||||
// Optimized: cache lookups
|
||||
const nextBottomRow = tree[next].bottomRow;
|
||||
const prevBottomRow = tree[prev[0]].bottomRow;
|
||||
return nextBottomRow > prevBottomRow ? [next] : prev;
|
||||
}, []);
|
||||
}
|
||||
"""
|
||||
|
||||
code_markdown = CodeStringsMarkdown(
|
||||
code_strings=[
|
||||
CodeString(
|
||||
code=optimized_code_with_interface,
|
||||
file_path=Path("helpers.ts"),
|
||||
language="typescript"
|
||||
)
|
||||
],
|
||||
language="typescript"
|
||||
)
|
||||
|
||||
replace_function_definitions_for_language(
|
||||
["getNearestAbove"],
|
||||
code_markdown,
|
||||
file_path,
|
||||
temp_project,
|
||||
)
|
||||
|
||||
result = file_path.read_text()
|
||||
|
||||
# The TreeNode interface should NOT appear in the result
|
||||
# (it's already imported, so adding it would cause a duplicate)
|
||||
assert "interface TreeNode" not in result, (
|
||||
f"TreeNode interface should NOT be added to the file since it's already imported.\n"
|
||||
f"Result contains:\n{result}"
|
||||
)
|
||||
|
||||
# The import should still be there
|
||||
assert 'import type { TreeNode } from "./constants"' in result, (
|
||||
f"Original import should be preserved.\nResult:\n{result}"
|
||||
)
|
||||
|
||||
# The optimized function should be there
|
||||
assert "// Optimized: cache lookups" in result, (
|
||||
f"Optimized function should be in the result.\nResult:\n{result}"
|
||||
)
|
||||
|
||||
# The result should be valid TypeScript
|
||||
assert ts_support.validate_syntax(result) is True
|
||||
|
||||
def test_multiple_imported_types_not_duplicated(self, ts_support, temp_project):
|
||||
"""Test that multiple imported types are not duplicated."""
|
||||
from codeflash.models.models import CodeStringsMarkdown, CodeString
|
||||
|
||||
original_source = """\
|
||||
import type { TreeNode, NodeSpace } from "./constants";
|
||||
import { MAX_BOX_SIZE } from "./constants";
|
||||
|
||||
export function processNode(node: TreeNode, space: NodeSpace): number {
|
||||
return node.topRow + space.top;
|
||||
}
|
||||
"""
|
||||
file_path = temp_project / "processor.ts"
|
||||
file_path.write_text(original_source, encoding="utf-8")
|
||||
|
||||
# Optimized code includes both interfaces
|
||||
optimized_code = """\
|
||||
interface TreeNode {
|
||||
topRow: number;
|
||||
bottomRow: number;
|
||||
}
|
||||
|
||||
interface NodeSpace {
|
||||
top: number;
|
||||
bottom: number;
|
||||
}
|
||||
|
||||
export function processNode(node: TreeNode, space: NodeSpace): number {
|
||||
// Optimized
|
||||
return (node.topRow + space.top) | 0;
|
||||
}
|
||||
"""
|
||||
|
||||
code_markdown = CodeStringsMarkdown(
|
||||
code_strings=[
|
||||
CodeString(
|
||||
code=optimized_code,
|
||||
file_path=Path("processor.ts"),
|
||||
language="typescript"
|
||||
)
|
||||
],
|
||||
language="typescript"
|
||||
)
|
||||
|
||||
replace_function_definitions_for_language(
|
||||
["processNode"],
|
||||
code_markdown,
|
||||
file_path,
|
||||
temp_project,
|
||||
)
|
||||
|
||||
result = file_path.read_text()
|
||||
|
||||
# Neither interface should be added
|
||||
assert "interface TreeNode" not in result
|
||||
assert "interface NodeSpace" not in result
|
||||
|
||||
# Imports should be preserved
|
||||
assert 'import type { TreeNode, NodeSpace } from "./constants"' in result
|
||||
|
||||
# Optimized code should be there
|
||||
assert "// Optimized" in result
|
||||
|
||||
assert ts_support.validate_syntax(result) is True
|
||||
|
|
|
|||
|
|
@ -353,8 +353,8 @@ class TestDiscoverFunctionsParity:
|
|||
assert len(js_funcs) == 1, f"JavaScript found {len(js_funcs)}, expected 1"
|
||||
|
||||
# Both should find 'add'
|
||||
assert py_funcs[0].name == "add"
|
||||
assert js_funcs[0].name == "add"
|
||||
assert py_funcs[0].function_name == "add"
|
||||
assert js_funcs[0].function_name == "add"
|
||||
|
||||
# Both should have correct language
|
||||
assert py_funcs[0].language == Language.PYTHON
|
||||
|
|
@ -373,8 +373,8 @@ class TestDiscoverFunctionsParity:
|
|||
assert len(js_funcs) == 3, f"JavaScript found {len(js_funcs)}, expected 3"
|
||||
|
||||
# Both should find the same function names
|
||||
py_names = {f.name for f in py_funcs}
|
||||
js_names = {f.name for f in js_funcs}
|
||||
py_names = {f.function_name for f in py_funcs}
|
||||
js_names = {f.function_name for f in js_funcs}
|
||||
|
||||
assert py_names == {"add", "subtract", "multiply"}
|
||||
assert js_names == {"add", "subtract", "multiply"}
|
||||
|
|
@ -392,8 +392,8 @@ class TestDiscoverFunctionsParity:
|
|||
assert len(js_funcs) == 1, f"JavaScript found {len(js_funcs)}, expected 1"
|
||||
|
||||
# The function with return should be found
|
||||
assert py_funcs[0].name == "with_return"
|
||||
assert js_funcs[0].name == "withReturn"
|
||||
assert py_funcs[0].function_name == "with_return"
|
||||
assert js_funcs[0].function_name == "withReturn"
|
||||
|
||||
def test_class_methods_discovery(self, python_support, js_support):
|
||||
"""Both should discover class methods with proper metadata."""
|
||||
|
|
@ -409,12 +409,12 @@ class TestDiscoverFunctionsParity:
|
|||
|
||||
# All should be marked as methods
|
||||
for func in py_funcs:
|
||||
assert func.is_method is True, f"Python {func.name} should be a method"
|
||||
assert func.class_name == "Calculator", f"Python {func.name} should belong to Calculator"
|
||||
assert func.is_method is True, f"Python {func.function_name} should be a method"
|
||||
assert func.class_name == "Calculator", f"Python {func.function_name} should belong to Calculator"
|
||||
|
||||
for func in js_funcs:
|
||||
assert func.is_method is True, f"JavaScript {func.name} should be a method"
|
||||
assert func.class_name == "Calculator", f"JavaScript {func.name} should belong to Calculator"
|
||||
assert func.is_method is True, f"JavaScript {func.function_name} should be a method"
|
||||
assert func.class_name == "Calculator", f"JavaScript {func.function_name} should belong to Calculator"
|
||||
|
||||
def test_async_functions_discovery(self, python_support, js_support):
|
||||
"""Both should correctly identify async functions."""
|
||||
|
|
@ -429,10 +429,10 @@ class TestDiscoverFunctionsParity:
|
|||
assert len(js_funcs) == 2, f"JavaScript found {len(js_funcs)}, expected 2"
|
||||
|
||||
# Check async flags
|
||||
py_async = next(f for f in py_funcs if "fetch" in f.name.lower())
|
||||
py_sync = next(f for f in py_funcs if "sync" in f.name.lower())
|
||||
js_async = next(f for f in js_funcs if "fetch" in f.name.lower())
|
||||
js_sync = next(f for f in js_funcs if "sync" in f.name.lower())
|
||||
py_async = next(f for f in py_funcs if "fetch" in f.function_name.lower())
|
||||
py_sync = next(f for f in py_funcs if "sync" in f.function_name.lower())
|
||||
js_async = next(f for f in js_funcs if "fetch" in f.function_name.lower())
|
||||
js_sync = next(f for f in js_funcs if "sync" in f.function_name.lower())
|
||||
|
||||
assert py_async.is_async is True, "Python async function should have is_async=True"
|
||||
assert py_sync.is_async is False, "Python sync function should have is_async=False"
|
||||
|
|
@ -452,15 +452,15 @@ class TestDiscoverFunctionsParity:
|
|||
assert len(js_funcs) == 2, f"JavaScript found {len(js_funcs)}, expected 2"
|
||||
|
||||
# Check names
|
||||
py_names = {f.name for f in py_funcs}
|
||||
js_names = {f.name for f in js_funcs}
|
||||
py_names = {f.function_name for f in py_funcs}
|
||||
js_names = {f.function_name for f in js_funcs}
|
||||
|
||||
assert py_names == {"outer", "inner"}, f"Python found {py_names}"
|
||||
assert js_names == {"outer", "inner"}, f"JavaScript found {js_names}"
|
||||
|
||||
# Check parent info for inner function
|
||||
py_inner = next(f for f in py_funcs if f.name == "inner")
|
||||
js_inner = next(f for f in js_funcs if f.name == "inner")
|
||||
py_inner = next(f for f in py_funcs if f.function_name == "inner")
|
||||
js_inner = next(f for f in js_funcs if f.function_name == "inner")
|
||||
|
||||
assert len(py_inner.parents) >= 1, "Python inner should have parent info"
|
||||
assert py_inner.parents[0].name == "outer", "Python inner's parent should be outer"
|
||||
|
|
@ -482,8 +482,8 @@ class TestDiscoverFunctionsParity:
|
|||
assert len(js_funcs) == 1, f"JavaScript found {len(js_funcs)}, expected 1"
|
||||
|
||||
# Both should find 'helper' belonging to 'Utils'
|
||||
assert py_funcs[0].name == "helper"
|
||||
assert js_funcs[0].name == "helper"
|
||||
assert py_funcs[0].function_name == "helper"
|
||||
assert js_funcs[0].function_name == "helper"
|
||||
assert py_funcs[0].class_name == "Utils"
|
||||
assert js_funcs[0].class_name == "Utils"
|
||||
|
||||
|
|
@ -532,8 +532,8 @@ class TestDiscoverFunctionsParity:
|
|||
assert len(js_funcs) == 1, f"JavaScript found {len(js_funcs)}, expected 1"
|
||||
|
||||
# Should be the sync function
|
||||
assert "sync" in py_funcs[0].name.lower()
|
||||
assert "sync" in js_funcs[0].name.lower()
|
||||
assert "sync" in py_funcs[0].function_name.lower()
|
||||
assert "sync" in js_funcs[0].function_name.lower()
|
||||
|
||||
def test_filter_exclude_methods(self, python_support, js_support):
|
||||
"""Both should support filtering out class methods."""
|
||||
|
|
@ -550,8 +550,8 @@ class TestDiscoverFunctionsParity:
|
|||
assert len(js_funcs) == 1, f"JavaScript found {len(js_funcs)}, expected 1"
|
||||
|
||||
# Should be the standalone function
|
||||
assert py_funcs[0].name == "standalone"
|
||||
assert js_funcs[0].name == "standalone"
|
||||
assert py_funcs[0].function_name == "standalone"
|
||||
assert js_funcs[0].function_name == "standalone"
|
||||
|
||||
def test_nonexistent_file_returns_empty(self, python_support, js_support):
|
||||
"""Both should return empty list for nonexistent files."""
|
||||
|
|
@ -570,14 +570,14 @@ class TestDiscoverFunctionsParity:
|
|||
js_funcs = js_support.discover_functions(js_file)
|
||||
|
||||
# Both should have start_line and end_line
|
||||
assert py_funcs[0].start_line is not None
|
||||
assert py_funcs[0].end_line is not None
|
||||
assert js_funcs[0].start_line is not None
|
||||
assert js_funcs[0].end_line is not None
|
||||
assert py_funcs[0].starting_line is not None
|
||||
assert py_funcs[0].ending_line is not None
|
||||
assert js_funcs[0].starting_line is not None
|
||||
assert js_funcs[0].ending_line is not None
|
||||
|
||||
# Start should be before or equal to end
|
||||
assert py_funcs[0].start_line <= py_funcs[0].end_line
|
||||
assert js_funcs[0].start_line <= js_funcs[0].end_line
|
||||
assert py_funcs[0].starting_line <= py_funcs[0].ending_line
|
||||
assert js_funcs[0].starting_line <= js_funcs[0].ending_line
|
||||
|
||||
|
||||
# ============================================================================
|
||||
|
|
@ -604,8 +604,8 @@ function multiply(a, b) {
|
|||
return a * b;
|
||||
}
|
||||
"""
|
||||
py_func = FunctionInfo(name="add", file_path=Path("/test.py"), start_line=1, end_line=2)
|
||||
js_func = FunctionInfo(name="add", file_path=Path("/test.js"), start_line=1, end_line=3)
|
||||
py_func = FunctionInfo(function_name="add", file_path=Path("/test.py"), starting_line=1, ending_line=2)
|
||||
js_func = FunctionInfo(function_name="add", file_path=Path("/test.js"), starting_line=1, ending_line=3)
|
||||
|
||||
py_new = """def add(a, b):
|
||||
return (a + b) | 0
|
||||
|
|
@ -651,8 +651,8 @@ function other() {
|
|||
|
||||
// Footer
|
||||
"""
|
||||
py_func = FunctionInfo(name="target", file_path=Path("/test.py"), start_line=4, end_line=5)
|
||||
js_func = FunctionInfo(name="target", file_path=Path("/test.js"), start_line=4, end_line=6)
|
||||
py_func = FunctionInfo(function_name="target", file_path=Path("/test.py"), starting_line=4, ending_line=5)
|
||||
js_func = FunctionInfo(function_name="target", file_path=Path("/test.js"), starting_line=4, ending_line=6)
|
||||
|
||||
py_new = """def target():
|
||||
return 42
|
||||
|
|
@ -693,18 +693,18 @@ function other() {
|
|||
}
|
||||
"""
|
||||
py_func = FunctionInfo(
|
||||
name="add",
|
||||
function_name="add",
|
||||
file_path=Path("/test.py"),
|
||||
start_line=2,
|
||||
end_line=3,
|
||||
parents=(ParentInfo(name="Calculator", type="ClassDef"),),
|
||||
starting_line=2,
|
||||
ending_line=3,
|
||||
parents=[ParentInfo(name="Calculator", type="ClassDef")],
|
||||
)
|
||||
js_func = FunctionInfo(
|
||||
name="add",
|
||||
function_name="add",
|
||||
file_path=Path("/test.js"),
|
||||
start_line=2,
|
||||
end_line=4,
|
||||
parents=(ParentInfo(name="Calculator", type="ClassDef"),),
|
||||
starting_line=2,
|
||||
ending_line=4,
|
||||
parents=[ParentInfo(name="Calculator", type="ClassDef")],
|
||||
)
|
||||
|
||||
# New code without indentation
|
||||
|
|
@ -872,8 +872,8 @@ class TestExtractCodeContextParity:
|
|||
".js",
|
||||
)
|
||||
|
||||
py_func = FunctionInfo(name="add", file_path=py_file, start_line=1, end_line=2)
|
||||
js_func = FunctionInfo(name="add", file_path=js_file, start_line=1, end_line=3)
|
||||
py_func = FunctionInfo(function_name="add", file_path=py_file, starting_line=1, ending_line=2)
|
||||
js_func = FunctionInfo(function_name="add", file_path=js_file, starting_line=1, ending_line=3)
|
||||
|
||||
py_context = python_support.extract_code_context(py_func, py_file.parent, py_file.parent)
|
||||
js_context = js_support.extract_code_context(js_func, js_file.parent, js_file.parent)
|
||||
|
|
@ -922,8 +922,8 @@ class TestIntegrationParity:
|
|||
|
||||
assert len(py_funcs) == 1
|
||||
assert len(js_funcs) == 1
|
||||
assert py_funcs[0].name == "fibonacci"
|
||||
assert js_funcs[0].name == "fibonacci"
|
||||
assert py_funcs[0].function_name == "fibonacci"
|
||||
assert js_funcs[0].function_name == "fibonacci"
|
||||
|
||||
# Replace
|
||||
py_optimized = """def fibonacci(n):
|
||||
|
|
@ -974,20 +974,20 @@ class TestFeatureGaps:
|
|||
|
||||
for py_func in py_funcs:
|
||||
# Check all expected fields are populated
|
||||
assert py_func.name is not None, "Python: name should be populated"
|
||||
assert py_func.function_name is not None, "Python: name should be populated"
|
||||
assert py_func.file_path is not None, "Python: file_path should be populated"
|
||||
assert py_func.start_line is not None, "Python: start_line should be populated"
|
||||
assert py_func.end_line is not None, "Python: end_line should be populated"
|
||||
assert py_func.starting_line is not None, "Python: start_line should be populated"
|
||||
assert py_func.ending_line is not None, "Python: end_line should be populated"
|
||||
assert py_func.language is not None, "Python: language should be populated"
|
||||
# is_method and class_name should be set for class methods
|
||||
assert py_func.is_method is not None, "Python: is_method should be populated"
|
||||
|
||||
for js_func in js_funcs:
|
||||
# JavaScript should populate the same fields
|
||||
assert js_func.name is not None, "JavaScript: name should be populated"
|
||||
assert js_func.function_name is not None, "JavaScript: name should be populated"
|
||||
assert js_func.file_path is not None, "JavaScript: file_path should be populated"
|
||||
assert js_func.start_line is not None, "JavaScript: start_line should be populated"
|
||||
assert js_func.end_line is not None, "JavaScript: end_line should be populated"
|
||||
assert js_func.starting_line is not None, "JavaScript: start_line should be populated"
|
||||
assert js_func.ending_line is not None, "JavaScript: end_line should be populated"
|
||||
assert js_func.language is not None, "JavaScript: language should be populated"
|
||||
assert js_func.is_method is not None, "JavaScript: is_method should be populated"
|
||||
|
||||
|
|
@ -1006,7 +1006,7 @@ const identity = x => x;
|
|||
funcs = js_support.discover_functions(js_file)
|
||||
|
||||
# Should find all arrow functions
|
||||
names = {f.name for f in funcs}
|
||||
names = {f.function_name for f in funcs}
|
||||
assert "add" in names, "Should find arrow function 'add'"
|
||||
assert "multiply" in names, "Should find concise arrow function 'multiply'"
|
||||
# identity might or might not be found depending on implicit return handling
|
||||
|
|
@ -1057,7 +1057,7 @@ def multi_decorated():
|
|||
funcs = python_support.discover_functions(py_file)
|
||||
|
||||
# Should find all functions regardless of decorators
|
||||
names = {f.name for f in funcs}
|
||||
names = {f.function_name for f in funcs}
|
||||
assert "decorated" in names
|
||||
assert "decorated_with_args" in names
|
||||
assert "multi_decorated" in names
|
||||
|
|
@ -1077,7 +1077,7 @@ const namedExpr = function myFunc(x) {
|
|||
funcs = js_support.discover_functions(js_file)
|
||||
|
||||
# Should find function expressions
|
||||
names = {f.name for f in funcs}
|
||||
names = {f.function_name for f in funcs}
|
||||
assert "add" in names, "Should find anonymous function expression assigned to 'add'"
|
||||
|
||||
|
||||
|
|
@ -1144,5 +1144,5 @@ function greeting() {
|
|||
|
||||
assert len(py_funcs) == 1
|
||||
assert len(js_funcs) == 1
|
||||
assert py_funcs[0].name == "greeting"
|
||||
assert js_funcs[0].name == "greeting"
|
||||
assert py_funcs[0].function_name == "greeting"
|
||||
assert js_funcs[0].function_name == "greeting"
|
||||
|
|
|
|||
|
|
@ -113,19 +113,20 @@ def test_js_replcement() -> None:
|
|||
functions = js_support.discover_functions(main_file)
|
||||
target = None
|
||||
for func in functions:
|
||||
if func.name == "calculateStats":
|
||||
if func.function_name == "calculateStats":
|
||||
target = func
|
||||
break
|
||||
assert target is not None
|
||||
func = FunctionToOptimize(
|
||||
function_name=target.name,
|
||||
function_name=target.function_name,
|
||||
file_path=target.file_path,
|
||||
parents=target.parents,
|
||||
starting_line=target.start_line,
|
||||
ending_line=target.end_line,
|
||||
starting_col=target.start_col,
|
||||
ending_col=target.end_col,
|
||||
starting_line=target.starting_line,
|
||||
ending_line=target.ending_line,
|
||||
starting_col=target.starting_col,
|
||||
ending_col=target.ending_col,
|
||||
is_async=target.is_async,
|
||||
is_method=target.is_method,
|
||||
language=target.language,
|
||||
)
|
||||
test_config = TestConfig(
|
||||
|
|
|
|||
|
|
@ -52,7 +52,7 @@ def add(a, b):
|
|||
functions = python_support.discover_functions(Path(f.name))
|
||||
|
||||
assert len(functions) == 1
|
||||
assert functions[0].name == "add"
|
||||
assert functions[0].function_name == "add"
|
||||
assert functions[0].language == Language.PYTHON
|
||||
|
||||
def test_discover_multiple_functions(self, python_support):
|
||||
|
|
@ -73,7 +73,7 @@ def multiply(a, b):
|
|||
functions = python_support.discover_functions(Path(f.name))
|
||||
|
||||
assert len(functions) == 3
|
||||
names = {func.name for func in functions}
|
||||
names = {func.function_name for func in functions}
|
||||
assert names == {"add", "subtract", "multiply"}
|
||||
|
||||
def test_discover_function_with_no_return_excluded(self, python_support):
|
||||
|
|
@ -92,7 +92,7 @@ def without_return():
|
|||
|
||||
# Only the function with return should be discovered
|
||||
assert len(functions) == 1
|
||||
assert functions[0].name == "with_return"
|
||||
assert functions[0].function_name == "with_return"
|
||||
|
||||
def test_discover_class_methods(self, python_support):
|
||||
"""Test discovering class methods."""
|
||||
|
|
@ -130,8 +130,8 @@ def sync_function():
|
|||
|
||||
assert len(functions) == 2
|
||||
|
||||
async_func = next(f for f in functions if f.name == "fetch_data")
|
||||
sync_func = next(f for f in functions if f.name == "sync_function")
|
||||
async_func = next(f for f in functions if f.function_name == "fetch_data")
|
||||
sync_func = next(f for f in functions if f.function_name == "sync_function")
|
||||
|
||||
assert async_func.is_async is True
|
||||
assert sync_func.is_async is False
|
||||
|
|
@ -151,11 +151,11 @@ def outer():
|
|||
|
||||
# Both outer and inner should be discovered
|
||||
assert len(functions) == 2
|
||||
names = {func.name for func in functions}
|
||||
names = {func.function_name for func in functions}
|
||||
assert names == {"outer", "inner"}
|
||||
|
||||
# Inner should have outer as parent
|
||||
inner = next(f for f in functions if f.name == "inner")
|
||||
inner = next(f for f in functions if f.function_name == "inner")
|
||||
assert len(inner.parents) == 1
|
||||
assert inner.parents[0].name == "outer"
|
||||
assert inner.parents[0].type == "FunctionDef"
|
||||
|
|
@ -174,7 +174,7 @@ class Utils:
|
|||
functions = python_support.discover_functions(Path(f.name))
|
||||
|
||||
assert len(functions) == 1
|
||||
assert functions[0].name == "helper"
|
||||
assert functions[0].function_name == "helper"
|
||||
assert functions[0].class_name == "Utils"
|
||||
|
||||
def test_discover_with_filter_exclude_async(self, python_support):
|
||||
|
|
@ -193,7 +193,7 @@ def sync_func():
|
|||
functions = python_support.discover_functions(Path(f.name), criteria)
|
||||
|
||||
assert len(functions) == 1
|
||||
assert functions[0].name == "sync_func"
|
||||
assert functions[0].function_name == "sync_func"
|
||||
|
||||
def test_discover_with_filter_exclude_methods(self, python_support):
|
||||
"""Test filtering out class methods."""
|
||||
|
|
@ -212,7 +212,7 @@ class MyClass:
|
|||
functions = python_support.discover_functions(Path(f.name), criteria)
|
||||
|
||||
assert len(functions) == 1
|
||||
assert functions[0].name == "standalone"
|
||||
assert functions[0].function_name == "standalone"
|
||||
|
||||
def test_discover_line_numbers(self, python_support):
|
||||
"""Test that line numbers are correctly captured."""
|
||||
|
|
@ -229,13 +229,13 @@ def func2():
|
|||
|
||||
functions = python_support.discover_functions(Path(f.name))
|
||||
|
||||
func1 = next(f for f in functions if f.name == "func1")
|
||||
func2 = next(f for f in functions if f.name == "func2")
|
||||
func1 = next(f for f in functions if f.function_name == "func1")
|
||||
func2 = next(f for f in functions if f.function_name == "func2")
|
||||
|
||||
assert func1.start_line == 1
|
||||
assert func1.end_line == 2
|
||||
assert func2.start_line == 4
|
||||
assert func2.end_line == 7
|
||||
assert func1.starting_line == 1
|
||||
assert func1.ending_line == 2
|
||||
assert func2.starting_line == 4
|
||||
assert func2.ending_line == 7
|
||||
|
||||
def test_discover_invalid_file_returns_empty(self, python_support):
|
||||
"""Test that invalid Python file returns empty list."""
|
||||
|
|
@ -263,7 +263,7 @@ class TestReplaceFunction:
|
|||
def multiply(a, b):
|
||||
return a * b
|
||||
"""
|
||||
func = FunctionInfo(name="add", file_path=Path("/test.py"), start_line=1, end_line=2)
|
||||
func = FunctionInfo(function_name="add", file_path=Path("/test.py"), starting_line=1, ending_line=2)
|
||||
new_code = """def add(a, b):
|
||||
# Optimized
|
||||
return (a + b) | 0
|
||||
|
|
@ -287,7 +287,7 @@ def other():
|
|||
|
||||
# Footer
|
||||
"""
|
||||
func = FunctionInfo(name="target", file_path=Path("/test.py"), start_line=4, end_line=5)
|
||||
func = FunctionInfo(function_name="target", file_path=Path("/test.py"), starting_line=4, ending_line=5)
|
||||
new_code = """def target():
|
||||
return 42
|
||||
"""
|
||||
|
|
@ -306,11 +306,11 @@ def other():
|
|||
return a + b
|
||||
"""
|
||||
func = FunctionInfo(
|
||||
name="add",
|
||||
function_name="add",
|
||||
file_path=Path("/test.py"),
|
||||
start_line=2,
|
||||
end_line=3,
|
||||
parents=(ParentInfo(name="Calculator", type="ClassDef"),),
|
||||
starting_line=2,
|
||||
ending_line=3,
|
||||
parents=[ParentInfo(name="Calculator", type="ClassDef")],
|
||||
)
|
||||
# New code has no indentation
|
||||
new_code = """def add(self, a, b):
|
||||
|
|
@ -331,7 +331,7 @@ def other():
|
|||
def second():
|
||||
return 2
|
||||
"""
|
||||
func = FunctionInfo(name="first", file_path=Path("/test.py"), start_line=1, end_line=2)
|
||||
func = FunctionInfo(function_name="first", file_path=Path("/test.py"), starting_line=1, ending_line=2)
|
||||
new_code = """def first():
|
||||
return 100
|
||||
"""
|
||||
|
|
@ -348,7 +348,7 @@ def second():
|
|||
def last():
|
||||
return 999
|
||||
"""
|
||||
func = FunctionInfo(name="last", file_path=Path("/test.py"), start_line=4, end_line=5)
|
||||
func = FunctionInfo(function_name="last", file_path=Path("/test.py"), starting_line=4, ending_line=5)
|
||||
new_code = """def last():
|
||||
return 1000
|
||||
"""
|
||||
|
|
@ -362,7 +362,7 @@ def last():
|
|||
source = """def only():
|
||||
return 42
|
||||
"""
|
||||
func = FunctionInfo(name="only", file_path=Path("/test.py"), start_line=1, end_line=2)
|
||||
func = FunctionInfo(function_name="only", file_path=Path("/test.py"), starting_line=1, ending_line=2)
|
||||
new_code = """def only():
|
||||
return 100
|
||||
"""
|
||||
|
|
@ -474,7 +474,7 @@ class TestExtractCodeContext:
|
|||
f.flush()
|
||||
file_path = Path(f.name)
|
||||
|
||||
func = FunctionInfo(name="add", file_path=file_path, start_line=1, end_line=2)
|
||||
func = FunctionInfo(function_name="add", file_path=file_path, starting_line=1, ending_line=2)
|
||||
|
||||
context = python_support.extract_code_context(func, file_path.parent, file_path.parent)
|
||||
|
||||
|
|
@ -503,7 +503,7 @@ class TestIntegration:
|
|||
functions = python_support.discover_functions(file_path)
|
||||
assert len(functions) == 1
|
||||
func = functions[0]
|
||||
assert func.name == "fibonacci"
|
||||
assert func.function_name == "fibonacci"
|
||||
|
||||
# Replace
|
||||
optimized_code = """def fibonacci(n):
|
||||
|
|
|
|||
|
|
@ -351,6 +351,32 @@ class TestFindImports:
|
|||
assert imports[0].module_path == "fs"
|
||||
assert imports[0].default_import == "fs"
|
||||
|
||||
def test_require_inside_function_not_import(self, js_analyzer):
|
||||
"""Test that require() inside functions is not treated as an import.
|
||||
|
||||
This is important because dynamic require() calls inside functions are
|
||||
not module-level imports and should not be extracted as such.
|
||||
"""
|
||||
code = """
|
||||
const fs = require('fs');
|
||||
|
||||
function loadModule() {
|
||||
const dynamic = require('dynamic-module');
|
||||
return dynamic;
|
||||
}
|
||||
|
||||
class MyClass {
|
||||
method() {
|
||||
const inMethod = require('method-module');
|
||||
}
|
||||
}
|
||||
"""
|
||||
imports = js_analyzer.find_imports(code)
|
||||
|
||||
# Only the module-level require should be found
|
||||
assert len(imports) == 1
|
||||
assert imports[0].module_path == "fs"
|
||||
|
||||
def test_find_multiple_imports(self, js_analyzer):
|
||||
"""Test finding multiple imports."""
|
||||
code = """
|
||||
|
|
|
|||
Some files were not shown because too many files have changed in this diff Show more
Loading…
Reference in a new issue