Merge branch 'main' into jit-docs

This commit is contained in:
Aseem Saxena 2026-01-29 12:06:54 -08:00 committed by GitHub
commit d020da8294
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
266 changed files with 61765 additions and 3145 deletions

41
.github/workflows/codeflash.yaml vendored Normal file
View file

@ -0,0 +1,41 @@
# Codeflash optimization workflow for the ESM JS sample project.
name: Codeflash
on:
  pull_request:
    paths:
      # So that this workflow only runs when code within the target module is modified
      - 'code_to_optimize_js_esm/**'
  workflow_dispatch:
concurrency:
  # Any new push to the PR will cancel the previous run, so that only the latest code is optimized
  group: ${{ github.workflow }}-${{ github.ref }}
  cancel-in-progress: true
jobs:
  optimize:
    name: Optimize new code
    # Don't run codeflash on codeflash-ai[bot] commits, prevent duplicate optimizations
    if: ${{ github.actor != 'codeflash-ai[bot]' }}
    runs-on: ubuntu-latest
    env:
      CODEFLASH_API_KEY: ${{ secrets.CODEFLASH_API_KEY }}
    defaults:
      run:
        working-directory: ./code_to_optimize_js_esm
    steps:
      - name: 🛎️ Checkout
        uses: actions/checkout@v4
        with:
          # Full history so codeflash can diff against the base branch.
          fetch-depth: 0
      - name: 🟢 Setup Node.js
        uses: actions/setup-node@v4
        with:
          node-version: '22'
          cache: 'npm'
      - name: 📦 Install Dependencies
        run: npm ci
      - name: ⚡️ Codeflash Optimization
        run: npx codeflash

View file

@ -0,0 +1,88 @@
# End-to-end test: optimize a CommonJS function and require >= 50% improvement.
name: E2E - JS CommonJS Function
on:
  pull_request:
    paths:
      - '**' # Trigger for all paths
  workflow_dispatch:
concurrency:
  group: ${{ github.workflow }}-${{ github.ref_name }}
  cancel-in-progress: true
jobs:
  js-cjs-function-optimization:
    # Dynamically determine if environment is needed only when workflow files change and contributor is external.
    # NOTE(review): `github.event.pull_request.files` is not part of the pull_request
    # webhook payload, so this contains() check likely always evaluates false — TODO confirm.
    environment: ${{ (github.event_name == 'workflow_dispatch' || (contains(toJSON(github.event.pull_request.files.*.filename), '.github/workflows/') && github.event.pull_request.user.login != 'misrasaurabh1' && github.event.pull_request.user.login != 'KRRT7')) && 'external-trusted-contributors' || '' }}
    runs-on: ubuntu-latest
    env:
      CODEFLASH_AIS_SERVER: prod
      POSTHOG_API_KEY: ${{ secrets.POSTHOG_API_KEY }}
      CODEFLASH_API_KEY: ${{ secrets.CODEFLASH_API_KEY }}
      COLUMNS: 110
      MAX_RETRIES: 3
      RETRY_DELAY: 5
      EXPECTED_IMPROVEMENT_PCT: 50
      CODEFLASH_END_TO_END: 1
    steps:
      - name: 🛎️ Checkout
        uses: actions/checkout@v4
        with:
          ref: ${{ github.event.pull_request.head.ref }}
          repository: ${{ github.event.pull_request.head.repo.full_name }}
          fetch-depth: 0
          token: ${{ secrets.GITHUB_TOKEN }}
      # NOTE(review): the elif below passes for ANY open PR, so the allowlist
      # above it never actually blocks unauthorized workflow changes — confirm
      # this guard is intentional before relying on it for security.
      - name: Validate PR
        run: |
          # Check for any workflow changes
          if git diff --name-only "${{ github.event.pull_request.base.sha }}" "${{ github.event.pull_request.head.sha }}" | grep -q "^.github/workflows/"; then
            echo "⚠️ Workflow changes detected."
            # Get the PR author
            AUTHOR="${{ github.event.pull_request.user.login }}"
            echo "PR Author: $AUTHOR"
            # Allowlist check
            if [[ "$AUTHOR" == "misrasaurabh1" || "$AUTHOR" == "KRRT7" ]]; then
              echo "✅ Authorized user ($AUTHOR). Proceeding."
            elif [[ "${{ github.event.pull_request.state }}" == "open" ]]; then
              echo "✅ PR triggered by 'pull_request_target' and is open. Assuming protection rules are in place. Proceeding."
            else
              echo "⛔ Unauthorized user ($AUTHOR) attempting to modify workflows. Exiting."
              exit 1
            fi
          else
            echo "✅ No workflow file changes detected. Proceeding."
          fi
      - name: Set up Node.js
        uses: actions/setup-node@v4
        with:
          node-version: '20'
      - name: Install codeflash npm package dependencies
        run: |
          cd packages/codeflash
          npm install
      - name: Install JS test project dependencies
        run: |
          cd code_to_optimize/js/code_to_optimize_js
          npm install
      - name: Set up Python 3.11 for CLI
        uses: astral-sh/setup-uv@v6
        with:
          python-version: "3.11.6"
      - name: Install dependencies (CLI)
        run: |
          uv sync
      - name: Run Codeflash to optimize JS CommonJS function
        id: optimize_code
        run: |
          uv run python tests/scripts/end_to_end_test_js_cjs_function.py

88
.github/workflows/e2e-js-esm-async.yaml vendored Normal file
View file

@ -0,0 +1,88 @@
# End-to-end test: optimize an ESM async function and require >= 10% improvement.
name: E2E - JS ESM Async
on:
  pull_request:
    paths:
      - '**' # Trigger for all paths
  workflow_dispatch:
concurrency:
  group: ${{ github.workflow }}-${{ github.ref_name }}
  cancel-in-progress: true
jobs:
  js-esm-async-optimization:
    # Dynamically determine if environment is needed only when workflow files change and contributor is external.
    # NOTE(review): `github.event.pull_request.files` is not part of the pull_request
    # webhook payload, so this contains() check likely always evaluates false — TODO confirm.
    environment: ${{ (github.event_name == 'workflow_dispatch' || (contains(toJSON(github.event.pull_request.files.*.filename), '.github/workflows/') && github.event.pull_request.user.login != 'misrasaurabh1' && github.event.pull_request.user.login != 'KRRT7')) && 'external-trusted-contributors' || '' }}
    runs-on: ubuntu-latest
    env:
      CODEFLASH_AIS_SERVER: prod
      POSTHOG_API_KEY: ${{ secrets.POSTHOG_API_KEY }}
      CODEFLASH_API_KEY: ${{ secrets.CODEFLASH_API_KEY }}
      COLUMNS: 110
      MAX_RETRIES: 3
      RETRY_DELAY: 5
      EXPECTED_IMPROVEMENT_PCT: 10
      CODEFLASH_END_TO_END: 1
    steps:
      - name: 🛎️ Checkout
        uses: actions/checkout@v4
        with:
          ref: ${{ github.event.pull_request.head.ref }}
          repository: ${{ github.event.pull_request.head.repo.full_name }}
          fetch-depth: 0
          token: ${{ secrets.GITHUB_TOKEN }}
      # NOTE(review): the elif below passes for ANY open PR, so the allowlist
      # above it never actually blocks unauthorized workflow changes — confirm
      # this guard is intentional before relying on it for security.
      - name: Validate PR
        run: |
          # Check for any workflow changes
          if git diff --name-only "${{ github.event.pull_request.base.sha }}" "${{ github.event.pull_request.head.sha }}" | grep -q "^.github/workflows/"; then
            echo "⚠️ Workflow changes detected."
            # Get the PR author
            AUTHOR="${{ github.event.pull_request.user.login }}"
            echo "PR Author: $AUTHOR"
            # Allowlist check
            if [[ "$AUTHOR" == "misrasaurabh1" || "$AUTHOR" == "KRRT7" ]]; then
              echo "✅ Authorized user ($AUTHOR). Proceeding."
            elif [[ "${{ github.event.pull_request.state }}" == "open" ]]; then
              echo "✅ PR triggered by 'pull_request_target' and is open. Assuming protection rules are in place. Proceeding."
            else
              echo "⛔ Unauthorized user ($AUTHOR) attempting to modify workflows. Exiting."
              exit 1
            fi
          else
            echo "✅ No workflow file changes detected. Proceeding."
          fi
      - name: Set up Node.js
        uses: actions/setup-node@v4
        with:
          node-version: '20'
      - name: Install codeflash npm package dependencies
        run: |
          cd packages/codeflash
          npm install
      - name: Install JS test project dependencies
        run: |
          cd code_to_optimize/js/code_to_optimize_js_esm
          npm install
      - name: Set up Python 3.11 for CLI
        uses: astral-sh/setup-uv@v6
        with:
          python-version: "3.11.6"
      - name: Install dependencies (CLI)
        run: |
          uv sync
      - name: Run Codeflash to optimize ESM async function
        id: optimize_code
        run: |
          uv run python tests/scripts/end_to_end_test_js_esm_async.py

88
.github/workflows/e2e-js-ts-class.yaml vendored Normal file
View file

@ -0,0 +1,88 @@
# End-to-end test: optimize a TypeScript class method and require >= 30% improvement.
name: E2E - JS TypeScript Class
on:
  pull_request:
    paths:
      - '**' # Trigger for all paths
  workflow_dispatch:
concurrency:
  group: ${{ github.workflow }}-${{ github.ref_name }}
  cancel-in-progress: true
jobs:
  js-ts-class-optimization:
    # Dynamically determine if environment is needed only when workflow files change and contributor is external.
    # NOTE(review): `github.event.pull_request.files` is not part of the pull_request
    # webhook payload, so this contains() check likely always evaluates false — TODO confirm.
    environment: ${{ (github.event_name == 'workflow_dispatch' || (contains(toJSON(github.event.pull_request.files.*.filename), '.github/workflows/') && github.event.pull_request.user.login != 'misrasaurabh1' && github.event.pull_request.user.login != 'KRRT7')) && 'external-trusted-contributors' || '' }}
    runs-on: ubuntu-latest
    env:
      CODEFLASH_AIS_SERVER: prod
      POSTHOG_API_KEY: ${{ secrets.POSTHOG_API_KEY }}
      CODEFLASH_API_KEY: ${{ secrets.CODEFLASH_API_KEY }}
      COLUMNS: 110
      MAX_RETRIES: 3
      RETRY_DELAY: 5
      EXPECTED_IMPROVEMENT_PCT: 30
      CODEFLASH_END_TO_END: 1
    steps:
      - name: 🛎️ Checkout
        uses: actions/checkout@v4
        with:
          ref: ${{ github.event.pull_request.head.ref }}
          repository: ${{ github.event.pull_request.head.repo.full_name }}
          fetch-depth: 0
          token: ${{ secrets.GITHUB_TOKEN }}
      # NOTE(review): the elif below passes for ANY open PR, so the allowlist
      # above it never actually blocks unauthorized workflow changes — confirm
      # this guard is intentional before relying on it for security.
      - name: Validate PR
        run: |
          # Check for any workflow changes
          if git diff --name-only "${{ github.event.pull_request.base.sha }}" "${{ github.event.pull_request.head.sha }}" | grep -q "^.github/workflows/"; then
            echo "⚠️ Workflow changes detected."
            # Get the PR author
            AUTHOR="${{ github.event.pull_request.user.login }}"
            echo "PR Author: $AUTHOR"
            # Allowlist check
            if [[ "$AUTHOR" == "misrasaurabh1" || "$AUTHOR" == "KRRT7" ]]; then
              echo "✅ Authorized user ($AUTHOR). Proceeding."
            elif [[ "${{ github.event.pull_request.state }}" == "open" ]]; then
              echo "✅ PR triggered by 'pull_request_target' and is open. Assuming protection rules are in place. Proceeding."
            else
              echo "⛔ Unauthorized user ($AUTHOR) attempting to modify workflows. Exiting."
              exit 1
            fi
          else
            echo "✅ No workflow file changes detected. Proceeding."
          fi
      - name: Set up Node.js
        uses: actions/setup-node@v4
        with:
          node-version: '20'
      - name: Install codeflash npm package dependencies
        run: |
          cd packages/codeflash
          npm install
      - name: Install JS test project dependencies
        run: |
          cd code_to_optimize/js/code_to_optimize_ts
          npm install
      - name: Set up Python 3.11 for CLI
        uses: astral-sh/setup-uv@v6
        with:
          python-version: "3.11.6"
      - name: Install dependencies (CLI)
        run: |
          uv sync
      - name: Run Codeflash to optimize TypeScript class method
        id: optimize_code
        run: |
          uv run python tests/scripts/end_to_end_test_js_ts_class.py

4
.gitignore vendored
View file

@ -163,7 +163,6 @@ cython_debug/
#.idea/
.aider*
/js/common/node_modules/
/node_modules/
*.xml
*.pem
@ -259,6 +258,9 @@ WARP.MD
.mcp.json
.tessl/
tessl.json
**/node_modules/**
**/dist-nuitka/**
**/.npmrc
# Tessl auto-generates AGENTS.md on install; ignore to avoid cluttering git status
AGENTS.md

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,49 @@
/**
* Bubble sort implementation - intentionally inefficient for optimization testing.
*/
/**
* Sort an array using bubble sort algorithm.
* @param {number[]} arr - The array to sort
* @returns {number[]} - The sorted array
*/
/**
 * Sort an array of numbers in ascending order using bubble sort.
 * Deliberately quadratic — this module is an optimization-target fixture.
 * @param {number[]} arr - The array to sort (not modified).
 * @returns {number[]} A new sorted copy of the array.
 */
function bubbleSort(arr) {
  const sorted = [...arr];
  const len = sorted.length;
  for (let pass = 0; pass < len; pass++) {
    for (let idx = 0; idx + 1 < len; idx++) {
      if (sorted[idx] > sorted[idx + 1]) {
        // Swap adjacent out-of-order elements.
        [sorted[idx], sorted[idx + 1]] = [sorted[idx + 1], sorted[idx]];
      }
    }
  }
  return sorted;
}
/**
* Sort an array in descending order.
* @param {number[]} arr - The array to sort
* @returns {number[]} - The sorted array in descending order
*/
/**
 * Sort an array of numbers in descending order using bubble sort.
 * @param {number[]} arr - The array to sort (not modified).
 * @returns {number[]} A new copy sorted high-to-low.
 */
function bubbleSortDescending(arr) {
  const copy = arr.slice();
  const len = copy.length;
  for (let pass = 0; pass < len - 1; pass++) {
    // Each pass bubbles the smallest remaining value toward the end.
    for (let k = 0; k < len - pass - 1; k++) {
      if (copy[k] < copy[k + 1]) {
        [copy[k], copy[k + 1]] = [copy[k + 1], copy[k]];
      }
    }
  }
  return copy;
}
// Public fixture API consumed by the test suite.
module.exports = { bubbleSort, bubbleSortDescending };

View file

@ -0,0 +1,85 @@
/**
* Calculator module - demonstrates cross-file function calls.
* Uses helper functions from math_helpers.js.
*/
const { sumArray, average, findMax, findMin } = require('./math_helpers');
/**
* Calculate statistics for an array of numbers.
* @param numbers - Array of numbers to analyze
* @returns Object containing sum, average, min, max, and range
*/
/**
 * Calculate statistics for an array of numbers.
 * @param numbers - Array of numbers to analyze
 * @returns Object containing sum, average, min, max, and range
 */
function calculateStats(numbers) {
  // Empty input: report all-zero statistics instead of NaN/±Infinity.
  if (numbers.length === 0) {
    return { sum: 0, average: 0, min: 0, max: 0, range: 0 };
  }
  const total = sumArray(numbers);
  const mean = average(numbers);
  const smallest = findMin(numbers);
  const largest = findMax(numbers);
  return {
    sum: total,
    average: mean,
    min: smallest,
    max: largest,
    range: largest - smallest
  };
}
/**
* Normalize an array of numbers to a 0-1 range.
* @param numbers - Array of numbers to normalize
* @returns Normalized array
*/
/**
 * Normalize an array of numbers to a 0-1 range.
 * @param numbers - Array of numbers to normalize
 * @returns Normalized array
 */
function normalizeArray(numbers) {
  if (numbers.length === 0) return [];
  const lo = findMin(numbers);
  const hi = findMax(numbers);
  const span = hi - lo;
  // All values equal: map everything to the midpoint to avoid 0/0.
  return span === 0
    ? numbers.map(() => 0.5)
    : numbers.map(value => (value - lo) / span);
}
/**
* Calculate the weighted average of values with corresponding weights.
* @param values - Array of values
* @param weights - Array of weights (same length as values)
* @returns The weighted average
*/
/**
 * Calculate the weighted average of values with corresponding weights.
 * @param values - Array of values
 * @param weights - Array of weights (same length as values)
 * @returns The weighted average (0 for empty or mismatched inputs)
 */
function weightedAverage(values, weights) {
  // Mismatched or empty inputs yield 0 by convention.
  if (values.length === 0 || values.length !== weights.length) {
    return 0;
  }
  // Left-to-right accumulation, same order as a plain index loop.
  const weightedSum = values.reduce(
    (acc, value, i) => acc + value * weights[i],
    0
  );
  const totalWeight = sumArray(weights);
  return totalWeight === 0 ? 0 : weightedSum / totalWeight;
}
// Public fixture API consumed by the test suite.
module.exports = {
  calculateStats,
  normalizeArray,
  weightedAverage
};

View file

@ -0,0 +1,54 @@
/**
* Fibonacci implementations - intentionally inefficient for optimization testing.
*/
/**
* Calculate the nth Fibonacci number using naive recursion.
* This is intentionally slow to demonstrate optimization potential.
* @param {number} n - The index of the Fibonacci number to calculate
* @returns {number} - The nth Fibonacci number
*/
/**
 * Calculate the nth Fibonacci number using naive recursion.
 * Intentionally exponential-time to give the optimizer headroom.
 * @param {number} n - The index of the Fibonacci number to calculate
 * @returns {number} - The nth Fibonacci number
 */
function fibonacci(n) {
  // Base cases: fib(0) = 0, fib(1) = 1.
  return n <= 1 ? n : fibonacci(n - 1) + fibonacci(n - 2);
}
/**
* Check if a number is a Fibonacci number.
* @param {number} num - The number to check
* @returns {boolean} - True if num is a Fibonacci number
*/
/**
 * Check if a number is a Fibonacci number.
 * @param {number} num - The number to check
 * @returns {boolean} - True if num is a Fibonacci number
 */
function isFibonacci(num) {
  // num is Fibonacci iff 5*num^2 + 4 or 5*num^2 - 4 is a perfect square.
  const fiveSquared = 5 * num * num;
  return isPerfectSquare(fiveSquared + 4) || isPerfectSquare(fiveSquared - 4);
}
/**
 * Check if a number is a perfect square.
 * NOTE(review): float sqrt — exact only while n fits in double precision.
 * @param {number} n - The number to check
 * @returns {boolean} - True if n is a perfect square
 */
function isPerfectSquare(n) {
  const root = Math.sqrt(n);
  return Math.floor(root) === root;
}
/**
* Generate an array of Fibonacci numbers up to n.
* @param {number} n - The number of Fibonacci numbers to generate
* @returns {number[]} - Array of Fibonacci numbers
*/
/**
 * Generate an array of the first n Fibonacci numbers.
 * @param {number} n - The number of Fibonacci numbers to generate
 * @returns {number[]} - Array of Fibonacci numbers
 */
function fibonacciSequence(n) {
  // Recomputes each term from scratch via the naive recursion (fixture intent).
  return Array.from({ length: n }, (_, index) => fibonacci(index));
}
// Public fixture API consumed by the test suite.
module.exports = { fibonacci, isFibonacci, isPerfectSquare, fibonacciSequence };

View file

@ -0,0 +1,61 @@
/**
* Math helper functions - used by other modules.
* Some implementations are intentionally inefficient for optimization testing.
*/
/**
* Calculate the sum of an array of numbers.
* @param numbers - Array of numbers to sum
* @returns The sum of all numbers
*/
/**
 * Calculate the sum of an array of numbers.
 * NOTE(review): the previous comment claimed a reduce+spread implementation;
 * the code is (and remains) a plain accumulator loop.
 * @param numbers - Array of numbers to sum
 * @returns The sum of all numbers
 */
function sumArray(numbers) {
  let total = 0;
  for (const value of numbers) {
    total += value;
  }
  return total;
}
/**
* Calculate the average of an array of numbers.
* @param numbers - Array of numbers
* @returns The average value
*/
/**
 * Calculate the average of an array of numbers.
 * @param numbers - Array of numbers
 * @returns The average value (0 for an empty array, avoiding division by zero)
 */
function average(numbers) {
  return numbers.length === 0 ? 0 : sumArray(numbers) / numbers.length;
}
/**
* Find the maximum value in an array.
* @param numbers - Array of numbers
* @returns The maximum value
*/
/**
 * Find the maximum value in an array.
 * Intentionally O(n log n) (sort instead of a linear scan) — fixture intent.
 * @param numbers - Array of numbers
 * @returns The maximum value (-Infinity for an empty array)
 */
function findMax(numbers) {
  if (numbers.length === 0) return -Infinity;
  const ascending = numbers.slice().sort((x, y) => x - y);
  return ascending[ascending.length - 1];
}
/**
* Find the minimum value in an array.
* @param numbers - Array of numbers
* @returns The minimum value
*/
/**
 * Find the minimum value in an array.
 * Intentionally O(n log n) (sort instead of a linear scan) — fixture intent.
 * @param numbers - Array of numbers
 * @returns The minimum value (Infinity for an empty array)
 */
function findMin(numbers) {
  if (numbers.length === 0) return Infinity;
  const descending = numbers.slice().sort((x, y) => y - x);
  return descending[descending.length - 1];
}
// Public fixture API consumed by other modules and the test suite.
module.exports = {
  sumArray,
  average,
  findMax,
  findMin
};

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,35 @@
{
"name": "codeflash-js-test",
"version": "1.0.0",
"description": "Sample JavaScript project for codeflash optimization testing",
"main": "index.js",
"scripts": {
"test": "jest"
},
"codeflash": {
"moduleRoot": ".",
"testsRoot": "tests"
},
"devDependencies": {
"codeflash": "file:../../../packages/codeflash",
"jest": "^29.7.0",
"jest-junit": "^16.0.0"
},
"jest": {
"testEnvironment": "node",
"testMatch": [
"**/tests/**/*.test.js"
],
"reporters": [
"default",
[
"jest-junit",
{
"outputDirectory": ".codeflash",
"outputName": "jest-results.xml",
"includeConsoleOutput": true
}
]
]
}
}

View file

@ -0,0 +1,95 @@
/**
* String utility functions - some intentionally inefficient for optimization testing.
*/
/**
* Reverse a string character by character.
* @param {string} str - The string to reverse
* @returns {string} - The reversed string
*/
/**
 * Reverse a string character by character.
 * Intentionally O(n²): the accumulator is recopied one character at a time on
 * every iteration so the optimizer has something measurable to improve.
 * @param {string} str - The string to reverse
 * @returns {string} - The reversed string
 */
function reverseString(str) {
  let reversed = '';
  for (let i = str.length - 1; i >= 0; i--) {
    // Rebuild the whole accumulator before appending the next character.
    let rebuilt = '';
    for (const ch of reversed) {
      rebuilt += ch;
    }
    reversed = rebuilt + str[i];
  }
  return reversed;
}
/**
* Check if a string is a palindrome.
* @param {string} str - The string to check
* @returns {boolean} - True if str is a palindrome
*/
/**
 * Check if a string is a palindrome, ignoring case and non-alphanumerics.
 * @param {string} str - The string to check
 * @returns {boolean} - True if str is a palindrome
 */
function isPalindrome(str) {
  const normalized = str.toLowerCase().replace(/[^a-z0-9]/g, '');
  return reverseString(normalized) === normalized;
}
/**
* Count occurrences of a substring in a string.
* @param {string} str - The string to search in
* @param {string} sub - The substring to count
* @returns {number} - Number of occurrences
*/
/**
 * Count occurrences of a substring in a string.
 * Advances by one character after each hit, so OVERLAPPING matches are
 * counted (e.g. 'aa' occurs twice in 'aaa').
 * @param {string} str - The string to search in
 * @param {string} sub - The substring to count (must be non-empty)
 * @returns {number} - Number of occurrences
 */
function countOccurrences(str, sub) {
  let total = 0;
  let from = 0;
  for (;;) {
    const hit = str.indexOf(sub, from);
    if (hit === -1) {
      return total;
    }
    total += 1;
    from = hit + 1;
  }
}
/**
* Find the longest common prefix of an array of strings.
* @param {string[]} strs - Array of strings
* @returns {string} - The longest common prefix
*/
/**
 * Find the longest common prefix of an array of strings.
 * @param {string[]} strs - Array of strings
 * @returns {string} - The longest common prefix ('' if none)
 */
function longestCommonPrefix(strs) {
  if (strs.length === 0) return '';
  if (strs.length === 1) return strs[0];
  let candidate = strs[0];
  for (let pos = 1; pos < strs.length; pos++) {
    // Shrink the candidate until it prefixes the current string.
    while (!strs[pos].startsWith(candidate)) {
      candidate = candidate.slice(0, -1);
      if (candidate === '') return '';
    }
  }
  return candidate;
}
/**
* Convert a string to title case.
* @param {string} str - The string to convert
* @returns {string} - The title-cased string
*/
/**
 * Convert a string to title case (first letter of each space-separated word
 * uppercased, the rest lowercased).
 * @param {string} str - The string to convert
 * @returns {string} - The title-cased string
 */
function toTitleCase(str) {
  const words = str.toLowerCase().split(' ');
  const titled = words.map(
    word => word.charAt(0).toUpperCase() + word.slice(1)
  );
  return titled.join(' ');
}
// Public fixture API consumed by the test suite.
module.exports = {
  reverseString,
  isPalindrome,
  countOccurrences,
  longestCommonPrefix,
  toTitleCase
};

View file

@ -0,0 +1,70 @@
const { bubbleSort, bubbleSortDescending } = require('../bubble_sort');
// Spec for the bubbleSort fixture: edge cases (empty/single/sorted/reverse),
// duplicates, negatives, input purity, and two larger workloads that give the
// optimizer measurable runtime to improve.
describe('bubbleSort', () => {
  test('sorts an empty array', () => {
    expect(bubbleSort([])).toEqual([]);
  });
  test('sorts a single element array', () => {
    expect(bubbleSort([1])).toEqual([1]);
  });
  test('sorts an already sorted array', () => {
    expect(bubbleSort([1, 2, 3, 4, 5])).toEqual([1, 2, 3, 4, 5]);
  });
  test('sorts a reverse sorted array', () => {
    expect(bubbleSort([5, 4, 3, 2, 1])).toEqual([1, 2, 3, 4, 5]);
  });
  test('sorts an array with duplicates', () => {
    expect(bubbleSort([3, 1, 4, 1, 5, 9, 2, 6])).toEqual([1, 1, 2, 3, 4, 5, 6, 9]);
  });
  test('sorts negative numbers', () => {
    expect(bubbleSort([-3, -1, -4, -1, -5])).toEqual([-5, -4, -3, -1, -1]);
  });
  test('does not mutate original array', () => {
    const original = [3, 1, 2];
    bubbleSort(original);
    expect(original).toEqual([3, 1, 2]);
  });
  test('sorts a larger reverse sorted array for performance', () => {
    // 501 elements (500..0) — worst case input for bubble sort.
    const input = [];
    for (let i = 500; i >= 0; i--) {
      input.push(i);
    }
    const result = bubbleSort(input);
    expect(result[0]).toBe(0);
    expect(result[result.length - 1]).toBe(500);
  });
  test('sorts a larger random array for performance', () => {
    const input = [
      42, 17, 93, 8, 67, 31, 55, 22, 89, 4,
      76, 12, 39, 58, 95, 26, 71, 48, 83, 19,
      64, 3, 88, 37, 52, 11, 79, 46, 91, 28,
      63, 7, 84, 33, 57, 14, 72, 41, 96, 24,
      69, 6, 81, 36, 54, 16, 77, 44, 90, 29
    ];
    const result = bubbleSort(input);
    expect(result[0]).toBe(3);
    expect(result[result.length - 1]).toBe(96);
  });
});
// Spec for the descending variant: ordering plus trivial edge cases.
describe('bubbleSortDescending', () => {
  test('sorts in descending order', () => {
    expect(bubbleSortDescending([1, 3, 2, 5, 4])).toEqual([5, 4, 3, 2, 1]);
  });
  test('handles empty array', () => {
    expect(bubbleSortDescending([])).toEqual([]);
  });
  test('handles single element', () => {
    expect(bubbleSortDescending([42])).toEqual([42]);
  });
});

View file

@ -0,0 +1,470 @@
/**
* End-to-End Behavior Comparison Test
*
* This test verifies that:
* 1. The instrumentation correctly captures function behavior (args + return value)
* 2. Serialization/deserialization preserves all value types
* 3. The comparator correctly identifies equivalent behaviors
*
* It simulates what happens during optimization verification:
* - Run the same tests twice (original vs optimized) with different LOOP_INDEX
* - Store results to different locations
* - Compare the serialized values using the comparator
*/
const fs = require('fs');
const path = require('path');
const { execSync, spawn } = require('child_process');
// Import our modules from npm package
const { serialize, deserialize, getSerializerType, comparator } = require('codeflash');
// Test output directory — scratch space for the file-based comparison tests;
// created in beforeAll and removed in afterAll.
const TEST_OUTPUT_DIR = '/tmp/codeflash_e2e_test';
// Sample functions to test with various return types
const testFunctions = {
  // Primitives
  returnNumber: (x) => x * 2,
  returnString: (s) => s.toUpperCase(),
  returnBoolean: (x) => x > 0,
  returnNull: () => null,
  returnUndefined: () => undefined,
  // Special numbers
  returnNaN: () => NaN,
  returnInfinity: () => Infinity,
  returnNegInfinity: () => -Infinity,
  // Complex types
  returnArray: (arr) => arr.map(x => x * 2),
  returnObject: (obj) => ({ ...obj, processed: true }),
  returnMap: (entries) => new Map(entries),
  returnSet: (values) => new Set(values),
  returnDate: (ts) => new Date(ts),
  returnRegExp: (pattern, flags) => new RegExp(pattern, flags),
  // Nested structures
  returnNested: (data) => ({
    array: [1, 2, 3],
    map: new Map([['key', data]]),
    set: new Set([data]),
    date: new Date('2024-01-15'),
  }),
  // TypedArrays
  returnTypedArray: (data) => new Float64Array(data),
  // Error handling
  mayThrow: (shouldThrow) => {
    if (shouldThrow) throw new Error('Test error');
    return 'success';
  },
};
describe('E2E Behavior Comparison', () => {
beforeAll(() => {
// Clean up and create test directory
if (fs.existsSync(TEST_OUTPUT_DIR)) {
fs.rmSync(TEST_OUTPUT_DIR, { recursive: true });
}
fs.mkdirSync(TEST_OUTPUT_DIR, { recursive: true });
console.log('Using serializer:', getSerializerType());
});
afterAll(() => {
// Cleanup
if (fs.existsSync(TEST_OUTPUT_DIR)) {
fs.rmSync(TEST_OUTPUT_DIR, { recursive: true });
}
});
describe('Direct Serialization Round-Trip', () => {
// Test that serialize -> deserialize -> compare works for all types
test('primitives round-trip correctly', () => {
const testCases = [
42,
-3.14159,
'hello world',
true,
false,
null,
undefined,
BigInt('9007199254740991'),
];
for (const original of testCases) {
const serialized = serialize(original);
const restored = deserialize(serialized);
expect(comparator(original, restored)).toBe(true);
}
});
test('special numbers round-trip correctly', () => {
const testCases = [NaN, Infinity, -Infinity, -0];
for (const original of testCases) {
const serialized = serialize(original);
const restored = deserialize(serialized);
expect(comparator(original, restored)).toBe(true);
}
});
test('complex objects round-trip correctly', () => {
const testCases = [
new Map([['a', 1], ['b', 2]]),
new Set([1, 2, 3]),
new Date('2024-01-15'),
/test\d+/gi,
new Error('test error'),
new Float64Array([1.1, 2.2, 3.3]),
];
for (const original of testCases) {
const serialized = serialize(original);
const restored = deserialize(serialized);
expect(comparator(original, restored)).toBe(true);
}
});
test('nested structures round-trip correctly', () => {
const original = {
array: [1, 'two', { three: 3 }],
map: new Map([['nested', new Set([1, 2, 3])]]),
date: new Date('2024-06-15'),
regex: /pattern/i,
typed: new Int32Array([10, 20, 30]),
};
const serialized = serialize(original);
const restored = deserialize(serialized);
expect(comparator(original, restored)).toBe(true);
});
});
describe('Function Behavior Format', () => {
// Test the [args, kwargs, return_value] format used by instrumentation
test('behavior tuple format serializes correctly', () => {
// Simulate what recordResult does: [args, {}, returnValue]
const args = [42, 'hello'];
const kwargs = {}; // JS doesn't have kwargs, always empty
const returnValue = { result: 84, message: 'HELLO' };
const behaviorTuple = [args, kwargs, returnValue];
const serialized = serialize(behaviorTuple);
const restored = deserialize(serialized);
expect(comparator(behaviorTuple, restored)).toBe(true);
expect(restored[0]).toEqual(args);
expect(restored[1]).toEqual(kwargs);
expect(comparator(restored[2], returnValue)).toBe(true);
});
test('behavior with Map return value', () => {
const args = [['a', 1], ['b', 2]];
const returnValue = new Map(args);
const behaviorTuple = [args, {}, returnValue];
const serialized = serialize(behaviorTuple);
const restored = deserialize(serialized);
expect(comparator(behaviorTuple, restored)).toBe(true);
expect(restored[2] instanceof Map).toBe(true);
expect(restored[2].get('a')).toBe(1);
});
test('behavior with Set return value', () => {
const args = [[1, 2, 3]];
const returnValue = new Set([1, 2, 3]);
const behaviorTuple = [args, {}, returnValue];
const serialized = serialize(behaviorTuple);
const restored = deserialize(serialized);
expect(comparator(behaviorTuple, restored)).toBe(true);
expect(restored[2] instanceof Set).toBe(true);
expect(restored[2].has(2)).toBe(true);
});
test('behavior with Date return value', () => {
const args = [1705276800000]; // 2024-01-15
const returnValue = new Date(1705276800000);
const behaviorTuple = [args, {}, returnValue];
const serialized = serialize(behaviorTuple);
const restored = deserialize(serialized);
expect(comparator(behaviorTuple, restored)).toBe(true);
expect(restored[2] instanceof Date).toBe(true);
expect(restored[2].getTime()).toBe(1705276800000);
});
test('behavior with TypedArray return value', () => {
const args = [[1.1, 2.2, 3.3]];
const returnValue = new Float64Array([1.1, 2.2, 3.3]);
const behaviorTuple = [args, {}, returnValue];
const serialized = serialize(behaviorTuple);
const restored = deserialize(serialized);
expect(comparator(behaviorTuple, restored)).toBe(true);
expect(restored[2] instanceof Float64Array).toBe(true);
});
test('behavior with Error (exception case)', () => {
const error = new TypeError('Invalid argument');
const serialized = serialize(error);
const restored = deserialize(serialized);
expect(comparator(error, restored)).toBe(true);
expect(restored.name).toBe('TypeError');
expect(restored.message).toBe('Invalid argument');
});
});
describe('Simulated Original vs Optimized Comparison', () => {
// Simulate running the same function twice and comparing results
function runAndCapture(fn, args) {
try {
const returnValue = fn(...args);
return { success: true, value: [args, {}, returnValue] };
} catch (error) {
return { success: false, error };
}
}
test('identical behaviors are equal - number function', () => {
const fn = testFunctions.returnNumber;
const args = [21];
// "Original" run
const original = runAndCapture(fn, args);
const originalSerialized = serialize(original.value);
// "Optimized" run (same function, simulating optimization)
const optimized = runAndCapture(fn, args);
const optimizedSerialized = serialize(optimized.value);
// Deserialize and compare (what verification does)
const originalRestored = deserialize(originalSerialized);
const optimizedRestored = deserialize(optimizedSerialized);
expect(comparator(originalRestored, optimizedRestored)).toBe(true);
});
test('identical behaviors are equal - Map function', () => {
const fn = testFunctions.returnMap;
const args = [[['x', 10], ['y', 20]]];
const original = runAndCapture(fn, args);
const originalSerialized = serialize(original.value);
const optimized = runAndCapture(fn, args);
const optimizedSerialized = serialize(optimized.value);
const originalRestored = deserialize(originalSerialized);
const optimizedRestored = deserialize(optimizedSerialized);
expect(comparator(originalRestored, optimizedRestored)).toBe(true);
});
test('identical behaviors are equal - nested structure', () => {
const fn = testFunctions.returnNested;
const args = ['test-data'];
const original = runAndCapture(fn, args);
const originalSerialized = serialize(original.value);
const optimized = runAndCapture(fn, args);
const optimizedSerialized = serialize(optimized.value);
const originalRestored = deserialize(originalSerialized);
const optimizedRestored = deserialize(optimizedSerialized);
expect(comparator(originalRestored, optimizedRestored)).toBe(true);
});
test('different behaviors are NOT equal', () => {
const fn1 = (x) => x * 2;
const fn2 = (x) => x * 3; // Different behavior!
const args = [10];
const original = runAndCapture(fn1, args);
const originalSerialized = serialize(original.value);
const optimized = runAndCapture(fn2, args);
const optimizedSerialized = serialize(optimized.value);
const originalRestored = deserialize(originalSerialized);
const optimizedRestored = deserialize(optimizedSerialized);
// Should be FALSE - behaviors differ (20 vs 30)
expect(comparator(originalRestored, optimizedRestored)).toBe(false);
});
test('floating point tolerance works', () => {
// Simulate slight floating point differences from optimization
const original = [[[1.0]], {}, 0.30000000000000004];
const optimized = [[[1.0]], {}, 0.3];
const originalSerialized = serialize(original);
const optimizedSerialized = serialize(optimized);
const originalRestored = deserialize(originalSerialized);
const optimizedRestored = deserialize(optimizedSerialized);
// Should be TRUE with default tolerance
expect(comparator(originalRestored, optimizedRestored)).toBe(true);
});
});
describe('Multiple Invocations Comparison', () => {
// Simulate multiple test invocations being stored and compared
test('batch of invocations can be compared', () => {
const testCases = [
{ fn: testFunctions.returnNumber, args: [1] },
{ fn: testFunctions.returnNumber, args: [100] },
{ fn: testFunctions.returnString, args: ['hello'] },
{ fn: testFunctions.returnArray, args: [[1, 2, 3]] },
{ fn: testFunctions.returnMap, args: [[['a', 1]]] },
{ fn: testFunctions.returnSet, args: [[1, 2, 3]] },
{ fn: testFunctions.returnDate, args: [1705276800000] },
{ fn: testFunctions.returnNested, args: ['data'] },
];
// Simulate original run
const originalResults = testCases.map(({ fn, args }) => {
const returnValue = fn(...args);
return serialize([args, {}, returnValue]);
});
// Simulate optimized run (same functions)
const optimizedResults = testCases.map(({ fn, args }) => {
const returnValue = fn(...args);
return serialize([args, {}, returnValue]);
});
// Compare all results
for (let i = 0; i < testCases.length; i++) {
const originalRestored = deserialize(originalResults[i]);
const optimizedRestored = deserialize(optimizedResults[i]);
expect(comparator(originalRestored, optimizedRestored)).toBe(true);
}
});
});
// File-backed round-trip: mirrors how codeflash persists original/candidate
// behavior snapshots to disk before comparing them.
describe('File-Based Comparison (SQLite Simulation)', () => {
  // Simulate writing to files and reading back for comparison
  test('can write and read back serialized results', () => {
    const originalPath = path.join(TEST_OUTPUT_DIR, 'original.bin');
    const optimizedPath = path.join(TEST_OUTPUT_DIR, 'optimized.bin');
    // Test data: return value mixes Map/Set/Date to prove non-JSON types survive disk I/O
    const behaviorData = {
      args: [42, 'test', { nested: true }],
      kwargs: {},
      returnValue: {
        result: new Map([['answer', 42]]),
        metadata: new Set(['processed', 'validated']),
        timestamp: new Date('2024-01-15'),
      },
    };
    const tuple = [behaviorData.args, behaviorData.kwargs, behaviorData.returnValue];
    // Write "original" result
    const originalBuffer = serialize(tuple);
    fs.writeFileSync(originalPath, originalBuffer);
    // Write "optimized" result (same data, simulating correct optimization)
    const optimizedBuffer = serialize(tuple);
    fs.writeFileSync(optimizedPath, optimizedBuffer);
    // Read back and compare
    const originalRead = fs.readFileSync(originalPath);
    const optimizedRead = fs.readFileSync(optimizedPath);
    const originalRestored = deserialize(originalRead);
    const optimizedRestored = deserialize(optimizedRead);
    expect(comparator(originalRestored, optimizedRestored)).toBe(true);
    // Verify the complex types survived the disk round-trip
    expect(originalRestored[2].result instanceof Map).toBe(true);
    expect(originalRestored[2].metadata instanceof Set).toBe(true);
    expect(originalRestored[2].timestamp instanceof Date).toBe(true);
  });
  test('detects differences in file-based comparison', () => {
    const originalPath = path.join(TEST_OUTPUT_DIR, 'original2.bin');
    const optimizedPath = path.join(TEST_OUTPUT_DIR, 'optimized2.bin');
    // Original behavior
    const originalTuple = [[10], {}, 100];
    fs.writeFileSync(originalPath, serialize(originalTuple));
    // "Buggy" optimized behavior
    const optimizedTuple = [[10], {}, 99]; // Wrong result!
    fs.writeFileSync(optimizedPath, serialize(optimizedTuple));
    // Read back and compare
    const originalRestored = deserialize(fs.readFileSync(originalPath));
    const optimizedRestored = deserialize(fs.readFileSync(optimizedPath));
    // Should detect the difference
    expect(comparator(originalRestored, optimizedRestored)).toBe(false);
  });
});
// Edge cases for the serialize/deserialize/comparator trio: IEEE special
// values, circular references, undefined results, and a large payload.
describe('Edge Cases', () => {
  // Shared serialize→deserialize round-trip used by every test below.
  const roundTrip = (payload) => deserialize(serialize(payload));

  test('handles special values in args', () => {
    const payload = [[NaN, Infinity, undefined, null], {}, 'processed'];
    const restored = roundTrip(payload);
    expect(comparator(payload, restored)).toBe(true);
    expect(Number.isNaN(restored[0][0])).toBe(true);
    expect(restored[0][1]).toBe(Infinity);
    expect(restored[0][2]).toBe(undefined);
    expect(restored[0][3]).toBe(null);
  });

  test('handles circular references in return value', () => {
    const node = { value: 42 };
    node.self = node; // cycle back to itself
    const payload = [[], {}, node];
    const restored = roundTrip(payload);
    expect(comparator(payload, restored)).toBe(true);
    // The cycle must be reconstructed as a real self-reference, not a copy.
    expect(restored[2].self).toBe(restored[2]);
  });

  test('handles empty results', () => {
    const payload = [[], {}, undefined];
    expect(comparator(payload, roundTrip(payload))).toBe(true);
  });

  test('handles large arrays', () => {
    const values = Array.from({ length: 1000 }, (_, index) => index);
    const payload = [[values], {}, values.reduce((acc, v) => acc + v, 0)];
    expect(comparator(payload, roundTrip(payload))).toBe(true);
  });
});
});

View file

@@ -0,0 +1,354 @@
#!/usr/bin/env node
/**
* End-to-End Comparison Test
*
* This test validates the full behavior comparison workflow:
* 1. Serialize test results to SQLite (simulating codeflash-jest-helper)
* 2. Run the comparison script
* 3. Verify results match expectations
*/
const fs = require('fs');
const path = require('path');
// Import our modules from npm package
const { serialize, readTestResults, compareResults } = require('codeflash');
// Try to load better-sqlite3
let Database;
try {
Database = require('better-sqlite3');
} catch (e) {
console.error('better-sqlite3 not installed, skipping E2E test');
process.exit(0);
}
// Scratch directory for the generated SQLite files; removed again by main().
const TEST_DIR = '/tmp/codeflash_e2e_comparison_test';
/**
 * Create a SQLite database with test results.
 *
 * @param {string} dbPath - Destination file; parent dir is created, any existing file replaced.
 * @param {Array<object>} results - Invocation records; returnValue (if present) is stored as a serialized BLOB.
 * @returns {string} dbPath, for chaining straight into readTestResults().
 */
function createTestDatabase(dbPath, results) {
  // Ensure directory exists
  const dir = path.dirname(dbPath);
  if (!fs.existsSync(dir)) {
    fs.mkdirSync(dir, { recursive: true });
  }
  // Remove existing file so every call starts from a clean database
  if (fs.existsSync(dbPath)) {
    fs.unlinkSync(dbPath);
  }
  const db = new Database(dbPath);
  // Create table (schema mirrors what readTestResults expects)
  db.exec(`
    CREATE TABLE test_results (
      test_module_path TEXT,
      test_class_name TEXT,
      test_function_name TEXT,
      function_getting_tested TEXT,
      loop_index INTEGER,
      iteration_id TEXT,
      runtime INTEGER,
      return_value BLOB,
      verification_type TEXT
    )
  `);
  // Insert results
  const stmt = db.prepare(`
    INSERT INTO test_results VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
  `);
  for (const result of results) {
    stmt.run(
      result.testModulePath,
      result.testClassName || null,
      result.testFunctionName,
      result.functionGettingTested,
      result.loopIndex,
      result.iterationId,
      result.runtime,
      // NOTE(review): truthiness check means a falsy-but-real returnValue
      // (0, '', false) would be stored as NULL — confirm that is intended.
      result.returnValue ? serialize(result.returnValue) : null,
      result.verificationType || 'function_call'
    );
  }
  db.close();
  return dbPath;
}
/**
 * Test 1: Identical results should be equivalent.
 * Writes the same two invocations to both databases and expects
 * compareResults to report equivalence with zero diffs.
 * @returns {boolean} true when the expectation holds.
 */
function testIdenticalResults() {
  console.log('\n=== Test 1: Identical Results ===');
  const results = [
    {
      testModulePath: 'tests/math.test.js',
      testFunctionName: 'test adds numbers',
      functionGettingTested: 'add',
      loopIndex: 1,
      iterationId: '0_0',
      runtime: 1000,
      returnValue: [[1, 2], {}, 3], // [args, kwargs, returnValue]
    },
    {
      testModulePath: 'tests/math.test.js',
      testFunctionName: 'test multiplies numbers',
      functionGettingTested: 'multiply',
      loopIndex: 1,
      iterationId: '0_1',
      runtime: 1000,
      returnValue: [[2, 3], {}, 6],
    },
  ];
  const originalDb = createTestDatabase(path.join(TEST_DIR, 'original1.sqlite'), results);
  const candidateDb = createTestDatabase(path.join(TEST_DIR, 'candidate1.sqlite'), results);
  const originalResults = readTestResults(originalDb);
  const candidateResults = readTestResults(candidateDb);
  const comparison = compareResults(originalResults, candidateResults);
  console.log(` Original invocations: ${originalResults.size}`);
  console.log(` Candidate invocations: ${candidateResults.size}`);
  console.log(` Equivalent: ${comparison.equivalent}`);
  console.log(` Diffs: ${comparison.diffs.length}`);
  if (!comparison.equivalent || comparison.diffs.length > 0) {
    console.log(' ❌ FAILED: Expected identical results to be equivalent');
    return false;
  }
  console.log(' ✅ PASSED');
  return true;
}
/**
 * Test 2: Different return values should NOT be equivalent.
 * Same invocation metadata on both sides; only the captured return value
 * differs, so the comparison must flag a diff.
 * @returns {boolean} true when the difference was detected.
 */
function testDifferentReturnValues() {
  console.log('\n=== Test 2: Different Return Values ===');
  // Shared invocation shape — only returnValue varies between the two sides.
  const invocation = {
    testModulePath: 'tests/math.test.js',
    testFunctionName: 'test adds numbers',
    functionGettingTested: 'add',
    loopIndex: 1,
    iterationId: '0_0',
    runtime: 1000,
  };
  const originalResults = [
    { ...invocation, returnValue: [[1, 2], {}, 3] }, // Correct: 1 + 2 = 3
  ];
  const candidateResults = [
    { ...invocation, returnValue: [[1, 2], {}, 4] }, // Wrong: should be 3, not 4
  ];
  const originalDb = createTestDatabase(path.join(TEST_DIR, 'original2.sqlite'), originalResults);
  const candidateDb = createTestDatabase(path.join(TEST_DIR, 'candidate2.sqlite'), candidateResults);
  const comparison = compareResults(readTestResults(originalDb), readTestResults(candidateDb));
  console.log(` Equivalent: ${comparison.equivalent}`);
  console.log(` Diffs: ${comparison.diffs.length}`);
  if (comparison.equivalent || comparison.diffs.length === 0) {
    console.log(' ❌ FAILED: Expected different results to NOT be equivalent');
    return false;
  }
  console.log(` Diff found: ${comparison.diffs[0].scope}`);
  console.log(' ✅ PASSED');
  return true;
}
/**
 * Test 3: Complex JavaScript types (Map, Set, Date) should compare correctly.
 * Identical data on both sides; ensures the BLOB round-trip through SQLite
 * preserves non-JSON types well enough for compareResults.
 * @returns {boolean} true when the expectation holds.
 */
function testComplexTypes() {
  console.log('\n=== Test 3: Complex JavaScript Types ===');
  const complexValue = {
    map: new Map([['a', 1], ['b', 2]]),
    set: new Set([1, 2, 3]),
    date: new Date('2024-01-15T00:00:00.000Z'),
    nested: {
      array: [1, 2, 3],
      map: new Map([['nested', true]]),
    },
  };
  const results = [
    {
      testModulePath: 'tests/complex.test.js',
      testFunctionName: 'test complex return',
      functionGettingTested: 'processData',
      loopIndex: 1,
      iterationId: '0_0',
      runtime: 1000,
      returnValue: [[], {}, complexValue],
    },
  ];
  const originalDb = createTestDatabase(path.join(TEST_DIR, 'original3.sqlite'), results);
  const candidateDb = createTestDatabase(path.join(TEST_DIR, 'candidate3.sqlite'), results);
  const original = readTestResults(originalDb);
  const candidate = readTestResults(candidateDb);
  const comparison = compareResults(original, candidate);
  console.log(` Original invocations: ${original.size}`);
  console.log(` Equivalent: ${comparison.equivalent}`);
  console.log(` Diffs: ${comparison.diffs.length}`);
  if (!comparison.equivalent) {
    console.log(' ❌ FAILED: Expected complex types to be equivalent');
    if (comparison.diffs.length > 0) {
      console.log(` Diff: ${JSON.stringify(comparison.diffs[0])}`);
    }
    return false;
  }
  console.log(' ✅ PASSED');
  return true;
}
/**
 * Test 4: Floating point tolerance should allow small differences.
 * 0.1 + 0.2 (0.30000000000000004) vs 0.3 must be treated as equivalent.
 * NOTE(review): the tolerance itself lives inside compareResults —
 * presumably an epsilon comparison; confirm against the codeflash package.
 * @returns {boolean} true when the expectation holds.
 */
function testFloatingPointTolerance() {
  console.log('\n=== Test 4: Floating Point Tolerance ===');
  const originalResults = [
    {
      testModulePath: 'tests/float.test.js',
      testFunctionName: 'test float calculation',
      functionGettingTested: 'calculate',
      loopIndex: 1,
      iterationId: '0_0',
      runtime: 1000,
      returnValue: [[], {}, 0.1 + 0.2], // 0.30000000000000004
    },
  ];
  const candidateResults = [
    {
      testModulePath: 'tests/float.test.js',
      testFunctionName: 'test float calculation',
      functionGettingTested: 'calculate',
      loopIndex: 1,
      iterationId: '0_0',
      runtime: 1000,
      returnValue: [[], {}, 0.3], // 0.3 (optimized calculation)
    },
  ];
  const originalDb = createTestDatabase(path.join(TEST_DIR, 'original4.sqlite'), originalResults);
  const candidateDb = createTestDatabase(path.join(TEST_DIR, 'candidate4.sqlite'), candidateResults);
  const original = readTestResults(originalDb);
  const candidate = readTestResults(candidateDb);
  const comparison = compareResults(original, candidate);
  console.log(` Original value: ${0.1 + 0.2}`);
  console.log(` Candidate value: ${0.3}`);
  console.log(` Equivalent: ${comparison.equivalent}`);
  if (!comparison.equivalent) {
    console.log(' ❌ FAILED: Expected floating point values to be equivalent within tolerance');
    return false;
  }
  console.log(' ✅ PASSED');
  return true;
}
/**
 * Test 5: NaN values should be equal to each other.
 * NaN !== NaN under ===, so this pins that the comparator treats two NaN
 * return values as equivalent.
 * @returns {boolean} true when the expectation holds.
 */
function testNaNEquality() {
  console.log('\n=== Test 5: NaN Equality ===');
  const results = [
    {
      testModulePath: 'tests/nan.test.js',
      testFunctionName: 'test NaN return',
      functionGettingTested: 'divideByZero',
      loopIndex: 1,
      iterationId: '0_0',
      runtime: 1000,
      returnValue: [[], {}, NaN],
    },
  ];
  // Both databases receive the identical NaN-bearing invocation.
  const [originalDb, candidateDb] = ['original5.sqlite', 'candidate5.sqlite'].map(
    (fileName) => createTestDatabase(path.join(TEST_DIR, fileName), results)
  );
  const comparison = compareResults(readTestResults(originalDb), readTestResults(candidateDb));
  console.log(` Equivalent: ${comparison.equivalent}`);
  if (!comparison.equivalent) {
    console.log(' ❌ FAILED: Expected NaN values to be equivalent');
    return false;
  }
  console.log(' ✅ PASSED');
  return true;
}
/**
 * Main test runner: provisions a clean scratch directory, runs every test
 * case, always cleans up, prints a summary, and exits 0 (all passed) or 1.
 *
 * Fix over the original: a test that threw used to abort the whole suite
 * with an uncaught exception AND leak TEST_DIR on disk. Test execution is
 * now wrapped so a throw counts as a failure, and cleanup runs in `finally`.
 */
function main() {
  console.log('='.repeat(60));
  console.log('E2E Comparison Test Suite');
  console.log('='.repeat(60));
  // Setup: start from an empty scratch directory
  if (fs.existsSync(TEST_DIR)) {
    fs.rmSync(TEST_DIR, { recursive: true });
  }
  fs.mkdirSync(TEST_DIR, { recursive: true });
  const suite = [
    testIdenticalResults,
    testDifferentReturnValues,
    testComplexTypes,
    testFloatingPointTolerance,
    testNaNEquality,
  ];
  const results = [];
  try {
    for (const testCase of suite) {
      // A throwing test is recorded as a failure instead of aborting the suite.
      try {
        results.push(testCase());
      } catch (error) {
        console.log(` ❌ ${testCase.name} threw: ${error.message}`);
        results.push(false);
      }
    }
  } finally {
    // Cleanup must happen even on failure; `force` tolerates a missing dir.
    fs.rmSync(TEST_DIR, { recursive: true, force: true });
  }
  // Summary
  console.log('\n' + '='.repeat(60));
  console.log('Summary');
  console.log('='.repeat(60));
  const passed = results.filter(r => r).length;
  const total = results.length;
  console.log(`Passed: ${passed}/${total}`);
  if (passed === total) {
    console.log('\n✅ ALL TESTS PASSED');
    process.exit(0);
  } else {
    console.log('\n❌ SOME TESTS FAILED');
    process.exit(1);
  }
}
main();

View file

@@ -0,0 +1,97 @@
const { fibonacci, isFibonacci, isPerfectSquare, fibonacciSequence } = require('../fibonacci');
// Unit tests for the CommonJS fibonacci module: fibonacci, isFibonacci,
// isPerfectSquare, fibonacciSequence. Pure value checks, no fixtures.
describe('fibonacci', () => {
  // fib(0)=0, fib(1)=1, then each term is the sum of the previous two.
  test('returns 0 for n=0', () => {
    expect(fibonacci(0)).toBe(0);
  });
  test('returns 1 for n=1', () => {
    expect(fibonacci(1)).toBe(1);
  });
  test('returns 1 for n=2', () => {
    expect(fibonacci(2)).toBe(1);
  });
  test('returns 5 for n=5', () => {
    expect(fibonacci(5)).toBe(5);
  });
  test('returns 55 for n=10', () => {
    expect(fibonacci(10)).toBe(55);
  });
  test('returns 233 for n=13', () => {
    expect(fibonacci(13)).toBe(233);
  });
});
describe('isFibonacci', () => {
  // Membership test: 0, 1, 8, 13 are Fibonacci numbers; 4 and 6 are not.
  test('returns true for 0', () => {
    expect(isFibonacci(0)).toBe(true);
  });
  test('returns true for 1', () => {
    expect(isFibonacci(1)).toBe(true);
  });
  test('returns true for 8', () => {
    expect(isFibonacci(8)).toBe(true);
  });
  test('returns true for 13', () => {
    expect(isFibonacci(13)).toBe(true);
  });
  test('returns false for 4', () => {
    expect(isFibonacci(4)).toBe(false);
  });
  test('returns false for 6', () => {
    expect(isFibonacci(6)).toBe(false);
  });
});
describe('isPerfectSquare', () => {
  // 0 and 1 count as perfect squares; 2 and 3 do not.
  test('returns true for 0', () => {
    expect(isPerfectSquare(0)).toBe(true);
  });
  test('returns true for 1', () => {
    expect(isPerfectSquare(1)).toBe(true);
  });
  test('returns true for 4', () => {
    expect(isPerfectSquare(4)).toBe(true);
  });
  test('returns true for 16', () => {
    expect(isPerfectSquare(16)).toBe(true);
  });
  test('returns false for 2', () => {
    expect(isPerfectSquare(2)).toBe(false);
  });
  test('returns false for 3', () => {
    expect(isPerfectSquare(3)).toBe(false);
  });
});
describe('fibonacciSequence', () => {
  // Returns the first n Fibonacci numbers starting from 0 (empty for n=0).
  test('returns empty array for n=0', () => {
    expect(fibonacciSequence(0)).toEqual([]);
  });
  test('returns [0] for n=1', () => {
    expect(fibonacciSequence(1)).toEqual([0]);
  });
  test('returns first 5 Fibonacci numbers', () => {
    expect(fibonacciSequence(5)).toEqual([0, 1, 1, 2, 3]);
  });
  test('returns first 10 Fibonacci numbers', () => {
    expect(fibonacciSequence(10)).toEqual([0, 1, 1, 2, 3, 5, 8, 13, 21, 34]);
  });
});

View file

@@ -0,0 +1,281 @@
#!/usr/bin/env node
/**
* Integration Test: Behavior Testing with Different Optimization Indices
*
* This script simulates the actual codeflash workflow:
* 1. Run tests with CODEFLASH_LOOP_INDEX=1 (original code)
* 2. Run tests with CODEFLASH_LOOP_INDEX=2 (optimized code)
* 3. Read back both result files
* 4. Compare using the comparator to verify equivalence
*
* Run directly: node tests/integration-behavior-test.js
*/
const fs = require('fs');
const path = require('path');
const { execSync } = require('child_process');
// Import our modules from npm package
const { serialize, deserialize, getSerializerType, comparator } = require('codeflash');
// Test configuration
const TEST_DIR = '/tmp/codeflash_integration_test';
const ORIGINAL_RESULTS = path.join(TEST_DIR, 'original_results.bin');
const OPTIMIZED_RESULTS = path.join(TEST_DIR, 'optimized_results.bin');
// Sample function to test — the reference ("original") implementation whose
// behavior the optimized variant must reproduce. Doubles every number and
// totals the input; metadata carries a nondeterministic timestamp that the
// capture step strips before comparison.
function processData(input) {
  const doubledNumbers = input.numbers.map((value) => value * 2);
  const total = input.numbers.reduce((acc, value) => acc + value, 0);
  return {
    numbers: doubledNumbers,
    sum: total,
    metadata: new Map([
      ['processed', true],
      ['timestamp', new Date()],
    ]),
    tags: new Set(input.tags || []),
  };
}
// "Optimized" version - same behavior, different implementation
function processDataOptimized(input) {
// Optimized implementation (same behavior)
const doubled = [];
let sum = 0;
for (const n of input.numbers) {
doubled.push(n * 2);
sum += n;
}
return {
numbers: doubled,
sum,
metadata: new Map([
['processed', true],
['timestamp', new Date()],
]),
tags: new Set(input.tags || []),
};
}
// Test cases
const testCases = [
{ numbers: [1, 2, 3], tags: ['a', 'b'] },
{ numbers: [10, 20, 30, 40] },
{ numbers: [-5, 0, 5], tags: ['negative', 'zero', 'positive'] },
{ numbers: [1.5, 2.5, 3.5] },
{ numbers: [] },
];
// Helper that runs `fn` once per input and records each invocation's
// behavior: {success, args, kwargs, returnValue} on success, or
// {success: false, args, kwargs, error: {name, message}} when it throws.
// Any 'timestamp' entry in a returned Map-valued `metadata` is stripped
// (it is nondeterministic and would break comparison). Note the metadata
// access sits inside the try, so a null/undefined return is captured as a
// failure — same as the original behavior.
function captureAllBehaviors(fn, inputs) {
  return inputs.map((input) => {
    try {
      const returnValue = fn(input);
      if (returnValue.metadata) {
        returnValue.metadata.delete('timestamp');
      }
      return {
        success: true,
        args: [input],
        kwargs: {},
        returnValue,
      };
    } catch (error) {
      return {
        success: false,
        args: [input],
        kwargs: {},
        error: { name: error.name, message: error.message },
      };
    }
  });
}
/**
 * Main integration test: captures the behavior of processData and
 * processDataOptimized over the shared testCases, round-trips both captures
 * through serialize/deserialize via files on disk, then verifies every
 * invocation compares equal with the comparator.
 * @returns {Promise<{success: boolean, results: Array}>} per-invocation outcome.
 */
async function runIntegrationTest() {
  console.log('='.repeat(60));
  console.log('Integration Test: Behavior Comparison');
  console.log('='.repeat(60));
  console.log(`Serializer type: ${getSerializerType()}`);
  console.log();
  // Setup: fresh scratch directory for the result files
  if (fs.existsSync(TEST_DIR)) {
    fs.rmSync(TEST_DIR, { recursive: true });
  }
  fs.mkdirSync(TEST_DIR, { recursive: true });
  // Phase 1: Run "original" code (LOOP_INDEX=1)
  console.log('Phase 1: Capturing original behavior...');
  const originalBehaviors = captureAllBehaviors(processData, testCases);
  const originalSerialized = serialize(originalBehaviors);
  fs.writeFileSync(ORIGINAL_RESULTS, originalSerialized);
  console.log(` - Captured ${originalBehaviors.length} invocations`);
  console.log(` - Serialized size: ${originalSerialized.length} bytes`);
  console.log(` - Saved to: ${ORIGINAL_RESULTS}`);
  console.log();
  // Phase 2: Run "optimized" code (LOOP_INDEX=2)
  console.log('Phase 2: Capturing optimized behavior...');
  const optimizedBehaviors = captureAllBehaviors(processDataOptimized, testCases);
  const optimizedSerialized = serialize(optimizedBehaviors);
  fs.writeFileSync(OPTIMIZED_RESULTS, optimizedSerialized);
  console.log(` - Captured ${optimizedBehaviors.length} invocations`);
  console.log(` - Serialized size: ${optimizedSerialized.length} bytes`);
  console.log(` - Saved to: ${OPTIMIZED_RESULTS}`);
  console.log();
  // Phase 3: Read back and compare
  console.log('Phase 3: Comparing behaviors...');
  const originalRestored = deserialize(fs.readFileSync(ORIGINAL_RESULTS));
  const optimizedRestored = deserialize(fs.readFileSync(OPTIMIZED_RESULTS));
  console.log(` - Original results restored: ${originalRestored.length} invocations`);
  console.log(` - Optimized results restored: ${optimizedRestored.length} invocations`);
  console.log();
  // Compare each invocation as an (args, kwargs, returnValue) tuple
  let allEqual = true;
  const comparisonResults = [];
  for (let i = 0; i < originalRestored.length; i++) {
    const orig = originalRestored[i];
    const opt = optimizedRestored[i];
    // Compare the behavior tuples
    const isEqual = comparator(
      [orig.args, orig.kwargs, orig.returnValue],
      [opt.args, opt.kwargs, opt.returnValue]
    );
    comparisonResults.push({
      invocation: i,
      isEqual,
      args: orig.args,
    });
    if (!isEqual) {
      allEqual = false;
      console.log(` ❌ Invocation ${i}: DIFFERENT`);
      console.log(` Args: ${JSON.stringify(orig.args)}`);
    } else {
      console.log(` ✓ Invocation ${i}: EQUAL`);
    }
  }
  console.log();
  console.log('='.repeat(60));
  if (allEqual) {
    console.log('✅ SUCCESS: All behaviors are equivalent!');
    console.log(' The optimization preserves correctness.');
  } else {
    console.log('❌ FAILURE: Some behaviors differ!');
    console.log(' The optimization changed the behavior.');
  }
  console.log('='.repeat(60));
  // Cleanup
  fs.rmSync(TEST_DIR, { recursive: true });
  // Return result for programmatic use
  return { success: allEqual, results: comparisonResults };
}
// Also test with a "broken" optimization
async function runBrokenOptimizationTest() {
console.log();
console.log('='.repeat(60));
console.log('Testing detection of broken optimization...');
console.log('='.repeat(60));
// Setup
if (!fs.existsSync(TEST_DIR)) {
fs.mkdirSync(TEST_DIR, { recursive: true });
}
// Original function
const original = (x) => x * 2;
// "Broken" optimized function
const brokenOptimized = (x) => x * 2 + 1; // Bug: adds 1
const inputs = [1, 5, 10, 100];
// Capture original
const originalResults = inputs.map(x => ({
args: [x],
kwargs: {},
returnValue: original(x),
}));
// Capture broken optimized
const brokenResults = inputs.map(x => ({
args: [x],
kwargs: {},
returnValue: brokenOptimized(x),
}));
// Serialize
const originalSerialized = serialize(originalResults);
const brokenSerialized = serialize(brokenResults);
// Compare
const originalRestored = deserialize(originalSerialized);
const brokenRestored = deserialize(brokenSerialized);
let detectedBug = false;
for (let i = 0; i < originalRestored.length; i++) {
const isEqual = comparator(
[originalRestored[i].args, {}, originalRestored[i].returnValue],
[brokenRestored[i].args, {}, brokenRestored[i].returnValue]
);
if (!isEqual) {
detectedBug = true;
console.log(` ❌ Invocation ${i}: Difference detected`);
console.log(` Input: ${originalRestored[i].args[0]}`);
console.log(` Original: ${originalRestored[i].returnValue}`);
console.log(` Broken: ${brokenRestored[i].returnValue}`);
}
}
console.log();
if (detectedBug) {
console.log('✅ SUCCESS: Bug in optimization was detected!');
} else {
console.log('❌ FAILURE: Bug was not detected!');
}
console.log('='.repeat(60));
// Cleanup
if (fs.existsSync(TEST_DIR)) {
fs.rmSync(TEST_DIR, { recursive: true });
}
return { success: detectedBug };
}
/**
 * Entry point: runs both scenarios and exits 0 only when the correct
 * optimization compared equal AND the broken one was detected.
 */
async function main() {
  try {
    const result1 = await runIntegrationTest();
    const result2 = await runBrokenOptimizationTest();
    console.log();
    console.log('='.repeat(60));
    console.log('FINAL SUMMARY');
    console.log('='.repeat(60));
    console.log(`Correct optimization test: ${result1.success ? 'PASS' : 'FAIL'}`);
    console.log(`Broken optimization detection: ${result2.success ? 'PASS' : 'FAIL'}`);
    process.exit(result1.success && result2.success ? 0 : 1);
  } catch (error) {
    // Any unexpected throw (setup I/O, serializer failure) is a hard failure.
    console.error('Test failed with error:', error);
    process.exit(1);
  }
}
main();

View file

@@ -0,0 +1,294 @@
#!/usr/bin/env node
/**
* Codeflash Jest Loop Runner
*
* This script runs Jest tests multiple times to collect stable performance measurements.
* It mimics the Python pytest_plugin.py looping behavior.
*
* Usage:
* node loop-runner.js <test-file> [options]
*
* Options:
* --min-loops=N Minimum loops to run (default: 5)
* --max-loops=N Maximum loops to run (default: 100000)
* --duration=N Target duration in seconds (default: 10)
* --stability-check Enable stability-based early stopping
*/
const { spawn } = require('child_process');
const path = require('path');
// Configuration
const DEFAULT_MIN_LOOPS = 5;
const DEFAULT_MAX_LOOPS = 100000;
const DEFAULT_DURATION_SECONDS = 10;
const STABILITY_WINDOW_SIZE = 0.35;
const STABILITY_CENTER_TOLERANCE = 0.0025;
const STABILITY_SPREAD_TOLERANCE = 0.0025;
/**
 * Extract timing measurements embedded in Jest stdout.
 *
 * Each marker carries six colon-separated fields wrapped in !######...######!:
 *   testModule:testClass:testFunc:funcName:<loop>_<call>:<durationNs>
 * e.g. !######math.test.js::test adds:add:0_0:123456######!
 * (testClass may be empty for top-level tests).
 *
 * @param {string} stdout - Raw Jest output.
 * @returns {Map<string, number[]>} testId (first five fields joined with ':')
 *   mapped to every observed duration in nanoseconds, in order of appearance.
 */
function parseTimingFromStdout(stdout) {
  const pattern = /!######([^:]+):([^:]*):([^:]+):([^:]+):(\d+_\d+):(\d+)######!/g;
  const timings = new Map();
  for (const match of stdout.matchAll(pattern)) {
    const testId = match.slice(1, 6).join(':');
    const durationNs = Number.parseInt(match[6], 10);
    const bucket = timings.get(testId);
    if (bucket) {
      bucket.push(durationNs);
    } else {
      timings.set(testId, [durationNs]);
    }
  }
  return timings;
}
/**
 * Run Jest once and return timing data.
 *
 * @param {string} testFile - Test file path forwarded to Jest.
 * @param {number} loopIndex - Exposed to the test run as CODEFLASH_LOOP_INDEX.
 * @param {number} timeout - Per-test timeout in seconds (Jest takes milliseconds).
 * @param {string} cwd - Working directory for the spawned process.
 * @returns {Promise<{code: number, stdout: string, stderr: string, timings: Map}>}
 *   Resolves on process exit regardless of exit code; rejects only on spawn
 *   errors (e.g. npx not found).
 */
async function runJestOnce(testFile, loopIndex, timeout, cwd) {
  return new Promise((resolve, reject) => {
    const env = {
      ...process.env,
      CODEFLASH_LOOP_INDEX: String(loopIndex),
    };
    const jestArgs = [
      'jest',
      testFile,
      '--runInBand', // serial execution, keeps timings comparable
      '--forceExit',
      `--testTimeout=${timeout * 1000}`,
    ];
    const proc = spawn('npx', jestArgs, {
      cwd,
      env,
      stdio: ['pipe', 'pipe', 'pipe'],
    });
    let stdout = '';
    let stderr = '';
    proc.stdout.on('data', (data) => {
      stdout += data.toString();
    });
    proc.stderr.on('data', (data) => {
      stderr += data.toString();
    });
    proc.on('close', (code) => {
      resolve({
        code,
        stdout,
        stderr,
        timings: parseTimingFromStdout(stdout),
      });
    });
    proc.on('error', reject);
  });
}
/**
 * Check if performance has stabilized.
 * Implements the same stability check as Python's pytest_plugin: the last
 * `windowSize` per-loop totals must all lie within ±0.25% of their median
 * AND have a max-min spread of at most 0.25% of the minimum.
 *
 * @param {Map<number, Map<string, number[]>>} allTimings - loopIndex -> (testId -> durations in ns).
 * @param {number} windowSize - Number of most recent loops to consider.
 * @returns {boolean} true when both tolerance checks pass.
 */
function shouldStopForStability(allTimings, windowSize) {
  // Get total runtime for each loop: sum of the per-test minimum durations
  const loopTotals = [];
  for (const [loopIndex, timings] of allTimings.entries()) {
    let total = 0;
    for (const durations of timings.values()) {
      total += Math.min(...durations);
    }
    loopTotals.push(total);
  }
  // Not enough loops yet to fill the window
  if (loopTotals.length < windowSize) {
    return false;
  }
  // Get recent window
  const window = loopTotals.slice(-windowSize);
  // Check center tolerance (all values within ±0.25% of median)
  // NOTE(review): for an even-sized window this picks the upper median.
  const sorted = [...window].sort((a, b) => a - b);
  const median = sorted[Math.floor(sorted.length / 2)];
  const centerTolerance = median * STABILITY_CENTER_TOLERANCE;
  const withinCenter = window.every(v => Math.abs(v - median) <= centerTolerance);
  // Check spread tolerance (max-min ≤ 0.25% of min)
  const minVal = Math.min(...window);
  const maxVal = Math.max(...window);
  const spreadTolerance = minVal * STABILITY_SPREAD_TOLERANCE;
  const withinSpread = (maxVal - minVal) <= spreadTolerance;
  return withinCenter && withinSpread;
}
/**
 * Main loop runner: repeatedly executes the Jest file until max loops,
 * the time budget, or (optionally) performance stability is reached,
 * then aggregates and prints per-test timing statistics.
 *
 * @param {string} testFile - Test file passed to each Jest run.
 * @param {object} [options] - minLoops, maxLoops, durationSeconds,
 *   stabilityCheck (default true), timeout (s), cwd.
 * @returns {Promise<{loopCount, allTimings, aggregatedTimings, exitCode}>}
 *   exitCode is the LAST Jest run's exit code.
 */
async function runLoopedTests(testFile, options = {}) {
  const minLoops = options.minLoops || DEFAULT_MIN_LOOPS;
  const maxLoops = options.maxLoops || DEFAULT_MAX_LOOPS;
  const durationSeconds = options.durationSeconds || DEFAULT_DURATION_SECONDS;
  // Stability checking is ON unless explicitly disabled with `false`.
  const stabilityCheck = options.stabilityCheck !== false;
  const timeout = options.timeout || 15;
  const cwd = options.cwd || process.cwd();
  console.log(`[codeflash-loop-runner] Starting looped test execution`);
  console.log(` Test file: ${testFile}`);
  console.log(` Min loops: ${minLoops}`);
  console.log(` Max loops: ${maxLoops}`);
  console.log(` Duration: ${durationSeconds}s`);
  console.log(` Stability check: ${stabilityCheck}`);
  console.log('');
  const startTime = Date.now();
  const allTimings = new Map(); // Map<loopIndex, Map<testId, number[]>>
  let loopCount = 0;
  let lastExitCode = 0;
  while (true) {
    loopCount++;
    const loopStart = Date.now();
    console.log(`[loop ${loopCount}] Running...`);
    // loopCount doubles as CODEFLASH_LOOP_INDEX for the run
    const result = await runJestOnce(testFile, loopCount, timeout, cwd);
    lastExitCode = result.code;
    // Store timings for this loop
    allTimings.set(loopCount, result.timings);
    const loopDuration = Date.now() - loopStart;
    const totalElapsed = (Date.now() - startTime) / 1000;
    // Count timing entries
    let timingCount = 0;
    for (const durations of result.timings.values()) {
      timingCount += durations.length;
    }
    console.log(`[loop ${loopCount}] Completed in ${loopDuration}ms, ${timingCount} timing entries`);
    // Check stopping conditions (checked in priority order)
    if (loopCount >= maxLoops) {
      console.log(`[codeflash-loop-runner] Reached max loops (${maxLoops})`);
      break;
    }
    if (loopCount >= minLoops && totalElapsed >= durationSeconds) {
      console.log(`[codeflash-loop-runner] Reached duration limit (${durationSeconds}s)`);
      break;
    }
    // Stability check: window sized as a fraction of the projected loop count.
    // NOTE(review): projection assumes past loop cost predicts future cost.
    if (stabilityCheck && loopCount >= minLoops) {
      const estimatedTotalLoops = Math.floor((durationSeconds / totalElapsed) * loopCount);
      const windowSize = Math.max(3, Math.floor(STABILITY_WINDOW_SIZE * estimatedTotalLoops));
      if (shouldStopForStability(allTimings, windowSize)) {
        console.log(`[codeflash-loop-runner] Performance stabilized after ${loopCount} loops`);
        break;
      }
    }
  }
  // Aggregate results across all loops
  const aggregatedTimings = new Map(); // Map<testId, {min, max, avg, count}>
  for (const [loopIndex, timings] of allTimings.entries()) {
    for (const [testId, durations] of timings.entries()) {
      if (!aggregatedTimings.has(testId)) {
        aggregatedTimings.set(testId, { values: [], min: Infinity, max: 0, sum: 0, count: 0 });
      }
      const agg = aggregatedTimings.get(testId);
      for (const d of durations) {
        agg.values.push(d);
        agg.min = Math.min(agg.min, d);
        agg.max = Math.max(agg.max, d);
        agg.sum += d;
        agg.count++;
      }
    }
  }
  // Print summary (durations are nanoseconds; /1000 renders microseconds)
  console.log('');
  console.log('=== Performance Summary ===');
  console.log(`Total loops: ${loopCount}`);
  console.log(`Total time: ${((Date.now() - startTime) / 1000).toFixed(2)}s`);
  console.log('');
  for (const [testId, agg] of aggregatedTimings.entries()) {
    const avg = agg.sum / agg.count;
    console.log(`${testId}:`);
    console.log(` Min: ${(agg.min / 1000).toFixed(2)} μs`);
    console.log(` Max: ${(agg.max / 1000).toFixed(2)} μs`);
    console.log(` Avg: ${(avg / 1000).toFixed(2)} μs`);
    console.log(` Samples: ${agg.count}`);
  }
  return {
    loopCount,
    allTimings,
    aggregatedTimings,
    exitCode: lastExitCode,
  };
}
// CLI interface: parses flags of the form --key=value, runs the loop, and
// propagates the last Jest exit code as the process exit code.
if (require.main === module) {
  const args = process.argv.slice(2);
  if (args.length === 0 || args[0] === '--help') {
    console.log('Usage: node loop-runner.js <test-file> [options]');
    console.log('');
    console.log('Options:');
    console.log(' --min-loops=N Minimum loops to run (default: 5)');
    console.log(' --max-loops=N Maximum loops to run (default: 100000)');
    console.log(' --duration=N Target duration in seconds (default: 10)');
    console.log(' --stability-check Enable stability-based early stopping');
    console.log(' --cwd=PATH Working directory for Jest');
    process.exit(0);
  }
  const testFile = args[0];
  const options = {};
  // Unknown flags are silently ignored.
  for (const arg of args.slice(1)) {
    if (arg.startsWith('--min-loops=')) {
      options.minLoops = parseInt(arg.split('=')[1], 10);
    } else if (arg.startsWith('--max-loops=')) {
      options.maxLoops = parseInt(arg.split('=')[1], 10);
    } else if (arg.startsWith('--duration=')) {
      options.durationSeconds = parseFloat(arg.split('=')[1]);
    } else if (arg === '--stability-check') {
      // NOTE(review): stability checking is already runLoopedTests' default,
      // so this flag is effectively a no-op; there is no --no-stability-check.
      options.stabilityCheck = true;
    } else if (arg.startsWith('--cwd=')) {
      options.cwd = arg.split('=')[1];
    }
  }
  runLoopedTests(testFile, options)
    .then((result) => {
      process.exit(result.exitCode);
    })
    .catch((error) => {
      console.error('Error:', error);
      process.exit(1);
    });
}
module.exports = { runLoopedTests, parseTimingFromStdout };

View file

@@ -0,0 +1,35 @@
/**
* Test for session-level looping performance measurement.
*
* Note: Looping is now done at the session level by Python (test_runner.py)
* which runs Jest multiple times. Each Jest run executes the test once,
* and timing data is aggregated across runs for stability checking.
*/
// Load the codeflash helper from npm package
const codeflash = require('codeflash');
// Simple function to test: iterative n-th Fibonacci number
// (fib(0) = 0, fib(1) = 1; n <= 1 is returned unchanged).
function fibonacci(n) {
  if (n <= 1) return n;
  let prev = 0;
  let curr = 1;
  for (let step = 2; step <= n; step++) {
    [prev, curr] = [curr, prev + curr];
  }
  return curr;
}
describe('Session-Level Looping Performance Test', () => {
  // capturePerf wraps the call so codeflash records its runtime; the second
  // argument ('10' / '16') is an id string — presumably a source-line id,
  // TODO confirm against the codeflash helper.
  test('fibonacci(20) with session-level looping', () => {
    // Looping is controlled by Python via CODEFLASH_LOOP_INDEX env var
    const result = codeflash.capturePerf('fibonacci', '10', fibonacci, 20);
    expect(result).toBe(6765);
  });
  test('fibonacci(30) with session-level looping', () => {
    const result = codeflash.capturePerf('fibonacci', '16', fibonacci, 30);
    expect(result).toBe(832040);
  });
});

View file

@@ -0,0 +1,41 @@
/**
* Sample performance test to verify looping mechanism.
*/
// Load the codeflash helper from npm package
const codeflash = require('codeflash');
// Simple function to test: iterative Fibonacci.
// Returns n for n <= 1, otherwise advances the pair (previous, current)
// n - 1 times starting from (0, 1).
function fibonacci(n) {
  if (n <= 1) return n;
  let previous = 0;
  let current = 1;
  let remaining = n - 1;
  while (remaining > 0) {
    const next = previous + current;
    previous = current;
    current = next;
    remaining--;
  }
  return current;
}
describe('Looping Performance Test', () => {
  // capturePerf records the wrapped call's runtime and returns its result;
  // the '10'/'16'/'22' strings are line ids used to key the measurements —
  // presumably source-line markers, TODO confirm against the codeflash helper.
  test('fibonacci(20) timing', () => {
    const result = codeflash.capturePerf('fibonacci', '10', fibonacci, 20);
    expect(result).toBe(6765);
  });
  test('fibonacci(30) timing', () => {
    const result = codeflash.capturePerf('fibonacci', '16', fibonacci, 30);
    expect(result).toBe(832040);
  });
  test('multiple calls in one test', () => {
    // Same lineId, multiple calls - should increment invocation counter
    const r1 = codeflash.capturePerf('fibonacci', '22', fibonacci, 5);
    const r2 = codeflash.capturePerf('fibonacci', '22', fibonacci, 10);
    const r3 = codeflash.capturePerf('fibonacci', '22', fibonacci, 15);
    expect(r1).toBe(5);
    expect(r2).toBe(55);
    expect(r3).toBe(610);
  });
});

View file

@@ -0,0 +1,135 @@
const {
reverseString,
isPalindrome,
countOccurrences,
longestCommonPrefix,
toTitleCase
} = require('../string_utils');
// Unit tests for the string_utils module. Pure value checks, no fixtures.
describe('reverseString', () => {
  test('reverses a simple string', () => {
    expect(reverseString('hello')).toBe('olleh');
  });
  test('returns empty string for empty input', () => {
    expect(reverseString('')).toBe('');
  });
  test('handles single character', () => {
    expect(reverseString('a')).toBe('a');
  });
  test('handles palindrome', () => {
    expect(reverseString('radar')).toBe('radar');
  });
  test('handles spaces', () => {
    expect(reverseString('hello world')).toBe('dlrow olleh');
  });
  test('reverses a longer string for performance', () => {
    // 520-char input; spot-checks the endpoints rather than the full value
    const input = 'abcdefghijklmnopqrstuvwxyz'.repeat(20);
    const result = reverseString(input);
    expect(result.length).toBe(input.length);
    expect(result[0]).toBe('z');
    expect(result[result.length - 1]).toBe('a');
  });
  test('reverses a medium string', () => {
    const input = 'The quick brown fox jumps over the lazy dog';
    const expected = 'god yzal eht revo spmuj xof nworb kciuq ehT';
    expect(reverseString(input)).toBe(expected);
  });
});
describe('isPalindrome', () => {
  // Per these expectations the check ignores case and non-alphanumerics.
  test('returns true for simple palindrome', () => {
    expect(isPalindrome('radar')).toBe(true);
  });
  test('returns true for palindrome with mixed case', () => {
    expect(isPalindrome('RaceCar')).toBe(true);
  });
  test('returns true for palindrome with spaces and punctuation', () => {
    expect(isPalindrome('A man, a plan, a canal: Panama')).toBe(true);
  });
  test('returns false for non-palindrome', () => {
    expect(isPalindrome('hello')).toBe(false);
  });
  test('returns true for empty string', () => {
    expect(isPalindrome('')).toBe(true);
  });
  test('returns true for single character', () => {
    expect(isPalindrome('a')).toBe(true);
  });
});
describe('countOccurrences', () => {
  // Per these expectations: overlapping matches count ('aaa'/'aa' -> 2) and
  // an empty needle yields length + 1 matches ('hello'/'' -> 6).
  test('counts single occurrence', () => {
    expect(countOccurrences('hello', 'ell')).toBe(1);
  });
  test('counts multiple occurrences', () => {
    expect(countOccurrences('abababab', 'ab')).toBe(4);
  });
  test('returns 0 for no occurrences', () => {
    expect(countOccurrences('hello', 'xyz')).toBe(0);
  });
  test('handles overlapping matches', () => {
    expect(countOccurrences('aaa', 'aa')).toBe(2);
  });
  test('handles empty substring', () => {
    expect(countOccurrences('hello', '')).toBe(6);
  });
});
describe('longestCommonPrefix', () => {
  test('finds common prefix', () => {
    expect(longestCommonPrefix(['flower', 'flow', 'flight'])).toBe('fl');
  });
  test('returns empty for no common prefix', () => {
    expect(longestCommonPrefix(['dog', 'racecar', 'car'])).toBe('');
  });
  test('returns empty for empty array', () => {
    expect(longestCommonPrefix([])).toBe('');
  });
  test('returns the string for single element array', () => {
    expect(longestCommonPrefix(['hello'])).toBe('hello');
  });
  test('handles identical strings', () => {
    expect(longestCommonPrefix(['test', 'test', 'test'])).toBe('test');
  });
});
describe('toTitleCase', () => {
test('converts simple string', () => {
expect(toTitleCase('hello world')).toBe('Hello World');
});
test('handles already title case', () => {
expect(toTitleCase('Hello World')).toBe('Hello World');
});
test('handles uppercase input', () => {
expect(toTitleCase('HELLO WORLD')).toBe('Hello World');
});
test('handles single word', () => {
expect(toTitleCase('hello')).toBe('Hello');
});
test('handles empty string', () => {
expect(toTitleCase('')).toBe('');
});
});

View file

@ -0,0 +1,5 @@
# Codeflash Configuration for CommonJS JavaScript Project
module_root: "."
tests_root: "tests"
test_framework: "jest"
formatter_cmds: []

View file

@ -0,0 +1,60 @@
/**
* Fibonacci implementations - CommonJS module
* Intentionally inefficient for optimization testing.
*/
/**
 * Naive recursive Fibonacci.
 *
 * Deliberately exponential-time: this module is an optimizer fixture,
 * so the slow implementation is the point.
 * @param {number} n - Zero-based index into the Fibonacci sequence
 * @returns {number} The nth Fibonacci number
 */
function fibonacci(n) {
  return n <= 1 ? n : fibonacci(n - 1) + fibonacci(n - 2);
}
/**
 * Determine whether num belongs to the Fibonacci sequence, using the
 * classic identity: num is Fibonacci iff 5*num^2 + 4 or 5*num^2 - 4 is
 * a perfect square.
 * @param {number} num - The number to test
 * @returns {boolean} True when num is a Fibonacci number
 */
function isFibonacci(num) {
  const squared = 5 * num * num;
  return isPerfectSquare(squared + 4) || isPerfectSquare(squared - 4);
}
/**
 * Test whether n is a perfect square.
 * @param {number} n - The number to test
 * @returns {boolean} True when sqrt(n) is an integer
 */
function isPerfectSquare(n) {
  const root = Math.sqrt(n);
  return Math.floor(root) === root;
}
/**
 * Build the first n Fibonacci numbers.
 * @param {number} n - How many sequence entries to produce
 * @returns {number[]} [fib(0), ..., fib(n-1)]
 */
function fibonacciSequence(n) {
  return Array.from({ length: n }, (_, idx) => fibonacci(idx));
}
// CommonJS exports
module.exports = {
fibonacci,
isFibonacci,
isPerfectSquare,
fibonacciSequence,
};

View file

@ -0,0 +1,61 @@
/**
* Fibonacci Calculator Class - CommonJS module
* Intentionally inefficient for optimization testing.
*/
/**
 * Calculator wrapping the intentionally slow Fibonacci routines as
 * instance methods. Kept inefficient on purpose: the class is an
 * optimization-tool fixture.
 */
class FibonacciCalculator {
  constructor() {
    // Stateless; nothing to initialize.
  }
  /**
   * Naive recursive Fibonacci (exponential time, by design).
   * @param {number} n - Zero-based sequence index
   * @returns {number} The nth Fibonacci number
   */
  fibonacci(n) {
    return n <= 1 ? n : this.fibonacci(n - 1) + this.fibonacci(n - 2);
  }
  /**
   * Whether num is a Fibonacci number, via the 5*num^2 +/- 4
   * perfect-square identity.
   * @param {number} num - The number to test
   * @returns {boolean} True when num is in the sequence
   */
  isFibonacci(num) {
    const squared = 5 * num * num;
    return this.isPerfectSquare(squared + 4) || this.isPerfectSquare(squared - 4);
  }
  /**
   * Whether n is a perfect square.
   * @param {number} n - The number to test
   * @returns {boolean} True when sqrt(n) is an integer
   */
  isPerfectSquare(n) {
    const root = Math.sqrt(n);
    return Math.floor(root) === root;
  }
  /**
   * First n Fibonacci numbers.
   * @param {number} n - Count of entries to produce
   * @returns {number[]} [fib(0), ..., fib(n-1)]
   */
  fibonacciSequence(n) {
    const sequence = [];
    for (let idx = 0; idx < n; idx++) {
      sequence.push(this.fibonacci(idx));
    }
    return sequence;
  }
}
// CommonJS exports
module.exports = { FibonacciCalculator };

View file

@ -0,0 +1,6 @@
// Jest configuration for the CommonJS fixture project.
module.exports = {
  testEnvironment: 'node',
  // Only files under tests/ ending in .test.js are collected.
  testMatch: ['**/tests/**/*.test.js'],
  // jest-junit writes XML results into .codeflash for the tooling to parse.
  reporters: ['default', ['jest-junit', { outputDirectory: '.codeflash' }]],
  verbose: true,
};

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,14 @@
{
"name": "code-to-optimize-js-cjs",
"version": "1.0.0",
"description": "CommonJS JavaScript test project for Codeflash E2E testing",
"main": "index.js",
"scripts": {
"test": "jest"
},
"devDependencies": {
"codeflash": "file:../../../packages/codeflash",
"jest": "^29.7.0",
"jest-junit": "^16.0.0"
}
}

View file

@ -0,0 +1,76 @@
/**
* Tests for Fibonacci functions - CommonJS module
*/
const { fibonacci, isFibonacci, isPerfectSquare, fibonacciSequence } = require('../fibonacci');
// Unit tests for the CommonJS fibonacci module. Expected values follow
// the standard sequence: 0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, ...
describe('fibonacci', () => {
  test('returns 0 for n=0', () => {
    expect(fibonacci(0)).toBe(0);
  });
  test('returns 1 for n=1', () => {
    expect(fibonacci(1)).toBe(1);
  });
  test('returns 1 for n=2', () => {
    expect(fibonacci(2)).toBe(1);
  });
  test('returns 5 for n=5', () => {
    expect(fibonacci(5)).toBe(5);
  });
  test('returns 55 for n=10', () => {
    expect(fibonacci(10)).toBe(55);
  });
  test('returns 233 for n=13', () => {
    expect(fibonacci(13)).toBe(233);
  });
});
describe('isFibonacci', () => {
  test('returns true for Fibonacci numbers', () => {
    expect(isFibonacci(0)).toBe(true);
    expect(isFibonacci(1)).toBe(true);
    expect(isFibonacci(5)).toBe(true);
    expect(isFibonacci(8)).toBe(true);
    expect(isFibonacci(13)).toBe(true);
  });
  test('returns false for non-Fibonacci numbers', () => {
    expect(isFibonacci(4)).toBe(false);
    expect(isFibonacci(6)).toBe(false);
    expect(isFibonacci(7)).toBe(false);
  });
});
describe('isPerfectSquare', () => {
  test('returns true for perfect squares', () => {
    expect(isPerfectSquare(0)).toBe(true);
    expect(isPerfectSquare(1)).toBe(true);
    expect(isPerfectSquare(4)).toBe(true);
    expect(isPerfectSquare(9)).toBe(true);
    expect(isPerfectSquare(16)).toBe(true);
  });
  test('returns false for non-perfect squares', () => {
    expect(isPerfectSquare(2)).toBe(false);
    expect(isPerfectSquare(3)).toBe(false);
    expect(isPerfectSquare(5)).toBe(false);
  });
});
describe('fibonacciSequence', () => {
  test('returns empty array for n=0', () => {
    expect(fibonacciSequence(0)).toEqual([]);
  });
  test('returns first 5 Fibonacci numbers', () => {
    expect(fibonacciSequence(5)).toEqual([0, 1, 1, 2, 3]);
  });
  test('returns first 10 Fibonacci numbers', () => {
    expect(fibonacciSequence(10)).toEqual([0, 1, 1, 2, 3, 5, 8, 13, 21, 34]);
  });
});

View file

@ -0,0 +1,105 @@
const { FibonacciCalculator } = require('../fibonacci_class');
// Unit tests for FibonacciCalculator. beforeEach builds a fresh instance
// so individual tests cannot leak state into each other.
describe('FibonacciCalculator', () => {
  let calc;
  beforeEach(() => {
    calc = new FibonacciCalculator();
  });
  describe('fibonacci', () => {
    test('returns 0 for n=0', () => {
      expect(calc.fibonacci(0)).toBe(0);
    });
    test('returns 1 for n=1', () => {
      expect(calc.fibonacci(1)).toBe(1);
    });
    test('returns 1 for n=2', () => {
      expect(calc.fibonacci(2)).toBe(1);
    });
    test('returns 5 for n=5', () => {
      expect(calc.fibonacci(5)).toBe(5);
    });
    test('returns 55 for n=10', () => {
      expect(calc.fibonacci(10)).toBe(55);
    });
    test('returns 233 for n=13', () => {
      expect(calc.fibonacci(13)).toBe(233);
    });
  });
  describe('isFibonacci', () => {
    test('returns true for 0', () => {
      expect(calc.isFibonacci(0)).toBe(true);
    });
    test('returns true for 1', () => {
      expect(calc.isFibonacci(1)).toBe(true);
    });
    test('returns true for 8', () => {
      expect(calc.isFibonacci(8)).toBe(true);
    });
    test('returns true for 13', () => {
      expect(calc.isFibonacci(13)).toBe(true);
    });
    test('returns false for 4', () => {
      expect(calc.isFibonacci(4)).toBe(false);
    });
    test('returns false for 6', () => {
      expect(calc.isFibonacci(6)).toBe(false);
    });
  });
  describe('isPerfectSquare', () => {
    test('returns true for 0', () => {
      expect(calc.isPerfectSquare(0)).toBe(true);
    });
    test('returns true for 1', () => {
      expect(calc.isPerfectSquare(1)).toBe(true);
    });
    test('returns true for 4', () => {
      expect(calc.isPerfectSquare(4)).toBe(true);
    });
    test('returns true for 16', () => {
      expect(calc.isPerfectSquare(16)).toBe(true);
    });
    test('returns false for 2', () => {
      expect(calc.isPerfectSquare(2)).toBe(false);
    });
    test('returns false for 3', () => {
      expect(calc.isPerfectSquare(3)).toBe(false);
    });
  });
  describe('fibonacciSequence', () => {
    test('returns empty array for n=0', () => {
      expect(calc.fibonacciSequence(0)).toEqual([]);
    });
    test('returns [0] for n=1', () => {
      expect(calc.fibonacciSequence(1)).toEqual([0]);
    });
    test('returns first 5 Fibonacci numbers', () => {
      expect(calc.fibonacciSequence(5)).toEqual([0, 1, 1, 2, 3]);
    });
    test('returns first 10 Fibonacci numbers', () => {
      expect(calc.fibonacciSequence(10)).toEqual([0, 1, 1, 2, 3, 5, 8, 13, 21, 34]);
    });
  });
});

View file

@ -0,0 +1,64 @@
/**
* Async utility functions - ES Module version.
* Contains intentionally inefficient implementations for optimization testing.
*/
/**
 * Promise-based sleep.
 * @param {number} ms - Milliseconds to wait before resolving
 * @returns {Promise<void>} Resolves after ms milliseconds
 */
export function delay(ms) {
  return new Promise((resolve) => {
    setTimeout(resolve, ms);
  });
}
/**
 * Await the processor for each item, one at a time, in input order.
 *
 * Deliberately serial (no Promise.all): this module is a fixture whose
 * slow patterns exist for the optimizer to find.
 * @param {any[]} items - Items to process
 * @param {function} processor - Async function applied to each item
 * @returns {Promise<any[]>} Results in input order
 */
export async function processItemsSequential(items, processor) {
  const output = [];
  for (const item of items) {
    output.push(await processor(item));
  }
  return output;
}
/**
 * Async map over items.
 *
 * The concurrency parameter is accepted but deliberately ignored - the
 * fixture awaits each mapper call before starting the next.
 * @param {any[]} items - Items to process
 * @param {function} mapper - Async mapper function
 * @param {number} concurrency - Max concurrent operations (currently ignored)
 * @returns {Promise<any[]>} Mapped results in input order
 */
export async function asyncMap(items, mapper, concurrency = 1) {
  const mapped = [];
  for (let idx = 0; idx < items.length; idx++) {
    const value = await mapper(items[idx]);
    mapped.push(value);
  }
  return mapped;
}
/**
 * Async filter: keep items whose predicate resolves truthy.
 *
 * Runs the predicate serially, one item at a time (intentionally
 * unoptimized - this module is an optimizer fixture).
 * @param {any[]} items - Items to filter
 * @param {function} predicate - Async predicate function
 * @returns {Promise<any[]>} Items that passed, in input order
 */
export async function asyncFilter(items, predicate) {
  const kept = [];
  for (const candidate of items) {
    if (await predicate(candidate)) {
      kept.push(candidate);
    }
  }
  return kept;
}

View file

@ -0,0 +1,5 @@
# Codeflash Configuration for ES Module JavaScript Project
module_root: "."
tests_root: "tests"
test_framework: "jest"
formatter_cmds: []

View file

@ -0,0 +1,52 @@
/**
* Fibonacci implementations - ES Module
* Intentionally inefficient for optimization testing.
*/
/**
 * Naive recursive Fibonacci. Exponential time on purpose - this module
 * is an optimization-tool fixture.
 * @param {number} n - Zero-based index into the sequence
 * @returns {number} The nth Fibonacci number
 */
export function fibonacci(n) {
  if (n <= 1) {
    return n;
  }
  const prev = fibonacci(n - 1);
  const prevPrev = fibonacci(n - 2);
  return prev + prevPrev;
}
/**
 * Membership test for the Fibonacci sequence, via the identity that num
 * is Fibonacci iff 5*num^2 + 4 or 5*num^2 - 4 is a perfect square.
 * @param {number} num - The number to check
 * @returns {boolean} True if num is a Fibonacci number
 */
export function isFibonacci(num) {
  const base = 5 * num * num;
  return isPerfectSquare(base + 4) || isPerfectSquare(base - 4);
}
/**
 * Whether n is a perfect square.
 * @param {number} n - The number to check
 * @returns {boolean} True when sqrt(n) has no fractional part
 */
export function isPerfectSquare(n) {
  const root = Math.sqrt(n);
  return root === Math.trunc(root);
}
/**
 * First n Fibonacci numbers.
 * @param {number} n - How many entries to generate
 * @returns {number[]} [fib(0), ..., fib(n-1)]
 */
export function fibonacciSequence(n) {
  const sequence = [];
  let index = 0;
  while (index < n) {
    sequence.push(fibonacci(index));
    index += 1;
  }
  return sequence;
}

View file

@ -0,0 +1,11 @@
// Jest config for ES Module project (using .cjs since package is type: module)
module.exports = {
  testEnvironment: 'node',
  // Only files under tests/ ending in .test.js are collected.
  testMatch: ['**/tests/**/*.test.js'],
  // jest-junit writes XML results into .codeflash for the tooling to parse.
  reporters: ['default', ['jest-junit', { outputDirectory: '.codeflash' }]],
  verbose: true,
  // Empty transform: ESM sources run natively (package.json passes
  // --experimental-vm-modules via NODE_OPTIONS).
  transform: {},
  // Tell Jest to also look for modules in the project's node_modules when
  // resolving modules from symlinked packages (like codeflash)
  moduleDirectories: ['node_modules', '<rootDir>/node_modules'],
};

View file

@ -0,0 +1,23 @@
{
"name": "code-to-optimize-js-esm",
"version": "1.0.0",
"description": "ES Module JavaScript test project for Codeflash E2E testing",
"type": "module",
"main": "index.js",
"scripts": {
"test": "NODE_OPTIONS='--experimental-vm-modules' jest"
},
"devDependencies": {
"@eslint/js": "^9.39.2",
"codeflash": "file:../../../packages/codeflash",
"eslint": "^9.39.2",
"globals": "^17.1.0",
"jest": "^29.7.0",
"jest-junit": "^16.0.0"
},
"codeflash": {
"moduleRoot": ".",
"testsRoot": "tests",
"disableTelemetry": true
}
}

View file

@ -0,0 +1,85 @@
/**
* Tests for async utility functions - ES Module
*/
import { delay, processItemsSequential, asyncMap, asyncFilter } from '../async_utils.js';
// Unit tests for the ES-module async utilities. All helpers are awaited;
// delays are kept to 1 ms so the suite stays fast.
describe('processItemsSequential', () => {
  test('processes all items', async () => {
    const items = [1, 2, 3, 4, 5];
    const processor = async (x) => x * 2;
    const results = await processItemsSequential(items, processor);
    expect(results).toEqual([2, 4, 6, 8, 10]);
  });
  test('handles empty array', async () => {
    const results = await processItemsSequential([], async (x) => x);
    expect(results).toEqual([]);
  });
  test('handles async operations with delays', async () => {
    const items = [1, 2, 3];
    const processor = async (x) => {
      await delay(1);
      return x + 10;
    };
    const results = await processItemsSequential(items, processor);
    expect(results).toEqual([11, 12, 13]);
  });
  test('preserves order', async () => {
    const items = [5, 4, 3, 2, 1];
    const processor = async (x) => x.toString();
    const results = await processItemsSequential(items, processor);
    expect(results).toEqual(['5', '4', '3', '2', '1']);
  });
  test('handles larger arrays', async () => {
    const items = Array.from({ length: 20 }, (_, i) => i);
    const processor = async (x) => x * 2;
    const results = await processItemsSequential(items, processor);
    expect(results.length).toBe(20);
    expect(results[0]).toBe(0);
    expect(results[19]).toBe(38);
  });
});
describe('asyncMap', () => {
  test('maps all items', async () => {
    const items = [1, 2, 3];
    const mapper = async (x) => x * 10;
    const results = await asyncMap(items, mapper);
    expect(results).toEqual([10, 20, 30]);
  });
  test('handles empty array', async () => {
    const results = await asyncMap([], async (x) => x);
    expect(results).toEqual([]);
  });
  test('handles objects', async () => {
    const items = [{ a: 1 }, { a: 2 }];
    const mapper = async (obj) => ({ ...obj, b: obj.a * 2 });
    const results = await asyncMap(items, mapper);
    expect(results).toEqual([{ a: 1, b: 2 }, { a: 2, b: 4 }]);
  });
});
describe('asyncFilter', () => {
  test('filters items based on predicate', async () => {
    const items = [1, 2, 3, 4, 5, 6];
    const predicate = async (x) => x % 2 === 0;
    const results = await asyncFilter(items, predicate);
    expect(results).toEqual([2, 4, 6]);
  });
  test('handles empty array', async () => {
    const results = await asyncFilter([], async () => true);
    expect(results).toEqual([]);
  });
  test('handles all items filtered out', async () => {
    const items = [1, 2, 3];
    const results = await asyncFilter(items, async () => false);
    expect(results).toEqual([]);
  });
});

View file

@ -0,0 +1,76 @@
/**
* Tests for Fibonacci functions - ES Module
*/
import { fibonacci, isFibonacci, isPerfectSquare, fibonacciSequence } from '../fibonacci.js';
// Unit tests for the ES-module fibonacci functions; mirrors the CommonJS
// suite so both module flavors are verified identically.
describe('fibonacci', () => {
  test('returns 0 for n=0', () => {
    expect(fibonacci(0)).toBe(0);
  });
  test('returns 1 for n=1', () => {
    expect(fibonacci(1)).toBe(1);
  });
  test('returns 1 for n=2', () => {
    expect(fibonacci(2)).toBe(1);
  });
  test('returns 5 for n=5', () => {
    expect(fibonacci(5)).toBe(5);
  });
  test('returns 55 for n=10', () => {
    expect(fibonacci(10)).toBe(55);
  });
  test('returns 233 for n=13', () => {
    expect(fibonacci(13)).toBe(233);
  });
});
describe('isFibonacci', () => {
  test('returns true for Fibonacci numbers', () => {
    expect(isFibonacci(0)).toBe(true);
    expect(isFibonacci(1)).toBe(true);
    expect(isFibonacci(5)).toBe(true);
    expect(isFibonacci(8)).toBe(true);
    expect(isFibonacci(13)).toBe(true);
  });
  test('returns false for non-Fibonacci numbers', () => {
    expect(isFibonacci(4)).toBe(false);
    expect(isFibonacci(6)).toBe(false);
    expect(isFibonacci(7)).toBe(false);
  });
});
describe('isPerfectSquare', () => {
  test('returns true for perfect squares', () => {
    expect(isPerfectSquare(0)).toBe(true);
    expect(isPerfectSquare(1)).toBe(true);
    expect(isPerfectSquare(4)).toBe(true);
    expect(isPerfectSquare(9)).toBe(true);
    expect(isPerfectSquare(16)).toBe(true);
  });
  test('returns false for non-perfect squares', () => {
    expect(isPerfectSquare(2)).toBe(false);
    expect(isPerfectSquare(3)).toBe(false);
    expect(isPerfectSquare(5)).toBe(false);
  });
});
describe('fibonacciSequence', () => {
  test('returns empty array for n=0', () => {
    expect(fibonacciSequence(0)).toEqual([]);
  });
  test('returns first 5 Fibonacci numbers', () => {
    expect(fibonacciSequence(5)).toEqual([0, 1, 1, 2, 3]);
  });
  test('returns first 10 Fibonacci numbers', () => {
    expect(fibonacciSequence(10)).toEqual([0, 1, 1, 2, 3, 5, 8, 13, 21, 34]);
  });
});

View file

@ -0,0 +1,63 @@
/**
* Bubble sort implementation - intentionally inefficient for optimization testing.
*/
/**
 * Bubble sort, ascending. Intentionally O(n^2) - this file is an
 * optimization-tool fixture.
 * @param arr - The array to sort
 * @returns A new sorted array; the input is not mutated
 */
export function bubbleSort<T>(arr: T[]): T[] {
  const sorted = [...arr];
  for (let pass = 0; pass < sorted.length - 1; pass++) {
    // Each pass bubbles the next-largest element to its final slot.
    for (let idx = 0; idx < sorted.length - pass - 1; idx++) {
      if (sorted[idx] > sorted[idx + 1]) {
        [sorted[idx], sorted[idx + 1]] = [sorted[idx + 1], sorted[idx]];
      }
    }
  }
  return sorted;
}
/**
 * Bubble sort, descending. Intentionally O(n^2) - optimizer fixture.
 * @param arr - The array to sort
 * @returns A new array sorted in descending order; input is not mutated
 */
export function bubbleSortDescending<T>(arr: T[]): T[] {
  const sorted = [...arr];
  for (let pass = 0; pass < sorted.length - 1; pass++) {
    // Each pass bubbles the next-smallest element to its final slot.
    for (let idx = 0; idx < sorted.length - pass - 1; idx++) {
      if (sorted[idx] < sorted[idx + 1]) {
        [sorted[idx], sorted[idx + 1]] = [sorted[idx + 1], sorted[idx]];
      }
    }
  }
  return sorted;
}
/**
 * Check whether an array is in ascending order.
 * @param arr - The array to check
 * @returns True when no element is greater than its successor
 */
export function isSorted<T>(arr: T[]): boolean {
  let idx = 1;
  while (idx < arr.length) {
    if (arr[idx - 1] > arr[idx]) {
      return false;
    }
    idx++;
  }
  return true;
}

View file

@ -0,0 +1,2 @@
# Codeflash Configuration for TypeScript Project
module_root: .
tests_root: tests

View file

@ -0,0 +1,88 @@
/**
* DataProcessor class - demonstrates class method optimization in TypeScript.
* Contains intentionally inefficient implementations for optimization testing.
*/
/**
 * Wraps a data array and exposes operations on it. The O(n^2)
 * implementations are deliberate: this class is an optimizer fixture.
 */
export class DataProcessor<T> {
  private data: T[];
  /**
   * Create a DataProcessor instance.
   * @param data - Initial contents; defensively copied
   */
  constructor(data: T[] = []) {
    this.data = [...data];
  }
  /**
   * Values that appear more than once, in order of first detection.
   * Quadratic pairwise scan, by design.
   * @returns Array of duplicate values (each listed once)
   */
  findDuplicates(): T[] {
    const seenTwice: T[] = [];
    for (let i = 0; i < this.data.length; i++) {
      for (let j = i + 1; j < this.data.length; j++) {
        if (this.data[i] === this.data[j] && !seenTwice.includes(this.data[i])) {
          seenTwice.push(this.data[i]);
        }
      }
    }
    return seenTwice;
  }
  /**
   * Sorted copy of the data (ascending), via bubble sort.
   * Quadratic by design.
   * @returns Sorted copy; internal data is untouched
   */
  sortData(): T[] {
    const sorted = [...this.data];
    for (let pass = 0; pass < sorted.length; pass++) {
      for (let idx = 0; idx + 1 < sorted.length; idx++) {
        if (sorted[idx] > sorted[idx + 1]) {
          [sorted[idx], sorted[idx + 1]] = [sorted[idx + 1], sorted[idx]];
        }
      }
    }
    return sorted;
  }
  /**
   * Unique values in order of first occurrence.
   * Quadratic linear-scan dedup, by design.
   * @returns Array of unique values
   */
  getUnique(): T[] {
    const unique: T[] = [];
    for (const value of this.data) {
      let alreadyKept = false;
      for (const kept of unique) {
        if (kept === value) {
          alreadyKept = true;
          break;
        }
      }
      if (!alreadyKept) {
        unique.push(value);
      }
    }
    return unique;
  }
  /**
   * Copy of the wrapped data array.
   * @returns The data array (defensive copy)
   */
  getData(): T[] {
    return this.data.slice();
  }
}

View file

@ -0,0 +1,52 @@
/**
* Fibonacci implementations - intentionally inefficient for optimization testing.
*/
/**
 * Naive recursive Fibonacci. Exponential on purpose - this module is an
 * optimization-tool fixture.
 * @param n - Zero-based index of the desired Fibonacci number
 * @returns The nth Fibonacci number
 */
export function fibonacci(n: number): number {
  return n <= 1 ? n : fibonacci(n - 1) + fibonacci(n - 2);
}
/**
 * Membership test for the Fibonacci sequence: num is Fibonacci iff
 * 5*num^2 + 4 or 5*num^2 - 4 is a perfect square.
 * @param num - The number to check
 * @returns True if num is a Fibonacci number
 */
export function isFibonacci(num: number): boolean {
  const base = 5 * num * num;
  return isPerfectSquare(base + 4) || isPerfectSquare(base - 4);
}
/**
 * Check if a number is a perfect square.
 * @param n - The number to check
 * @returns True when sqrt(n) is an integer
 */
export function isPerfectSquare(n: number): boolean {
  const root = Math.sqrt(n);
  return Math.floor(root) === root;
}
/**
 * First n Fibonacci numbers.
 * @param n - Count of entries to generate
 * @returns [fib(0), ..., fib(n-1)]
 */
export function fibonacciSequence(n: number): number[] {
  return Array.from({ length: n }, (_, idx) => fibonacci(idx));
}

View file

@ -0,0 +1,35 @@
import type { Config } from 'jest';
// Jest configuration for the TypeScript fixture project (ts-jest, CJS mode).
const config: Config = {
  preset: 'ts-jest',
  testEnvironment: 'node',
  // Only test files under tests/ are collected.
  testMatch: [
    '**/tests/**/*.test.ts',
    '**/tests/**/*.spec.ts'
  ],
  moduleFileExtensions: ['ts', 'tsx', 'js', 'jsx', 'json', 'node'],
  // Coverage spans all TS sources except deps, build output, and this config.
  collectCoverageFrom: [
    '**/*.ts',
    '!**/node_modules/**',
    '!**/dist/**',
    '!jest.config.ts'
  ],
  // jest-junit writes XML results into .codeflash for the tooling to parse.
  reporters: [
    'default',
    [
      'jest-junit',
      {
        outputDirectory: '.codeflash',
        outputName: 'jest-results.xml',
        includeConsoleOutput: true
      }
    ]
  ],
  transform: {
    // Compile TS via ts-jest as CommonJS (ESM support disabled).
    '^.+\\.tsx?$': ['ts-jest', {
      useESM: false
    }]
  }
};
export default config;

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,33 @@
{
"name": "codeflash-ts-test",
"version": "1.0.0",
"description": "Sample TypeScript project for codeflash optimization testing",
"main": "dist/index.js",
"scripts": {
"test": "jest",
"test:coverage": "jest --coverage",
"build": "tsc"
},
"codeflash": {
"moduleRoot": ".",
"testsRoot": "tests"
},
"keywords": [
"codeflash",
"optimization",
"testing",
"typescript"
],
"author": "CodeFlash Inc.",
"license": "BSL 1.1",
"devDependencies": {
"@types/jest": "^29.5.0",
"@types/node": "^20.0.0",
"codeflash": "file:../../../packages/codeflash",
"jest": "^29.7.0",
"jest-junit": "^16.0.0",
"ts-jest": "^29.1.0",
"ts-node": "^10.9.2",
"typescript": "^5.0.0"
}
}

View file

@ -0,0 +1,84 @@
/**
* String utility functions - intentionally inefficient for optimization testing.
*/
/**
 * Reverse a string.
 *
 * Deliberately O(n^2): the accumulator is rebuilt from scratch on every
 * iteration. This file is an optimization-tool fixture.
 * @param str - The string to reverse
 * @returns The reversed string
 */
export function reverseString(str: string): string {
  let reversed = '';
  for (let i = str.length - 1; i >= 0; i--) {
    // Copy the accumulator character by character, then append str[i]
    // (quadratic on purpose).
    let rebuilt = '';
    for (const ch of reversed) {
      rebuilt += ch;
    }
    reversed = rebuilt + str[i];
  }
  return reversed;
}
/**
 * Case-insensitive palindrome test; non-alphanumeric characters are
 * ignored.
 * @param str - The string to check
 * @returns True when the cleaned string reads the same both ways
 */
export function isPalindrome(str: string): boolean {
  const normalized = str.toLowerCase().replace(/[^a-z0-9]/g, '');
  return normalized === reverseString(normalized);
}
/**
 * Count (possibly overlapping) occurrences of a substring in a string.
 * @param str - The string to search in
 * @param substr - The substring to count
 * @returns The number of occurrences
 */
export function countOccurrences(str: string, substr: string): number {
  // BUG FIX: for substr === '', indexOf clamps the position to str.length
  // and never returns -1 (e.g. 'hello'.indexOf('', 6) === 5), so the scan
  // below looped forever. An empty pattern matches before each character
  // and once at the end, hence length + 1 (matches the sibling CJS test
  // expecting countOccurrences('hello', '') === 6).
  if (substr === '') {
    return str.length + 1;
  }
  let count = 0;
  let pos = 0;
  while ((pos = str.indexOf(substr, pos)) !== -1) {
    count++;
    pos += 1; // Move forward to find overlapping occurrences
  }
  return count;
}
/**
 * Find the longest common prefix among an array of strings.
 * @param strs - Array of strings
 * @returns The longest common prefix ('' when there is none)
 */
export function longestCommonPrefix(strs: string[]): string {
  if (strs.length === 0) return '';
  let prefix = strs[0];
  for (let k = 1; k < strs.length && prefix !== ''; k++) {
    // Trim the candidate from the right until it prefixes strs[k].
    while (!strs[k].startsWith(prefix)) {
      prefix = prefix.slice(0, -1);
    }
  }
  return prefix;
}
/**
 * Convert a string to title case (first letter of each space-separated
 * word upper-cased, the rest lower-cased).
 * @param str - The string to convert
 * @returns The title-cased string
 */
export function toTitleCase(str: string): string {
  const words = str.toLowerCase().split(' ');
  const titled = words.map((word) => {
    if (word === '') return word;
    return word[0].toUpperCase() + word.slice(1);
  });
  return titled.join(' ');
}

View file

@ -0,0 +1,92 @@
import { bubbleSort, bubbleSortDescending, isSorted } from '../bubble_sort';
// Unit tests for the bubble sort module. The 500-element arrays double as
// timing workloads for the optimizer.
describe('bubbleSort', () => {
  test('sorts an empty array', () => {
    expect(bubbleSort([])).toEqual([]);
  });
  test('sorts a single element array', () => {
    expect(bubbleSort([1])).toEqual([1]);
  });
  test('sorts an already sorted array', () => {
    expect(bubbleSort([1, 2, 3, 4, 5])).toEqual([1, 2, 3, 4, 5]);
  });
  test('sorts a reverse sorted array', () => {
    expect(bubbleSort([5, 4, 3, 2, 1])).toEqual([1, 2, 3, 4, 5]);
  });
  test('sorts an unsorted array', () => {
    expect(bubbleSort([3, 1, 4, 1, 5, 9, 2, 6])).toEqual([1, 1, 2, 3, 4, 5, 6, 9]);
  });
  test('handles negative numbers', () => {
    expect(bubbleSort([-3, -1, -4, -1, -5])).toEqual([-5, -4, -3, -1, -1]);
  });
  test('handles mixed positive and negative', () => {
    expect(bubbleSort([3, -1, 4, -1, 5])).toEqual([-1, -1, 3, 4, 5]);
  });
  test('does not mutate original array', () => {
    const original = [3, 1, 2];
    bubbleSort(original);
    expect(original).toEqual([3, 1, 2]);
  });
  test('sorts a larger reverse sorted array for performance', () => {
    const input: number[] = [];
    for (let i = 500; i >= 0; i--) {
      input.push(i);
    }
    const result = bubbleSort(input);
    expect(result[0]).toBe(0);
    expect(result[result.length - 1]).toBe(500);
  });
  test('sorts a larger random array for performance', () => {
    const input = [
      42, 17, 93, 8, 67, 31, 55, 22, 89, 4,
      76, 12, 39, 58, 95, 26, 71, 48, 83, 19,
      64, 3, 88, 37, 52, 11, 79, 46, 91, 28,
      63, 7, 84, 33, 57, 14, 72, 41, 96, 24,
      69, 6, 81, 36, 54, 16, 77, 44, 90, 29
    ];
    const result = bubbleSort(input);
    expect(result[0]).toBe(3);
    expect(result[result.length - 1]).toBe(96);
  });
});
describe('bubbleSortDescending', () => {
  test('sorts in descending order', () => {
    expect(bubbleSortDescending([1, 3, 2, 5, 4])).toEqual([5, 4, 3, 2, 1]);
  });
  test('handles empty array', () => {
    expect(bubbleSortDescending([])).toEqual([]);
  });
  test('handles single element', () => {
    expect(bubbleSortDescending([42])).toEqual([42]);
  });
});
describe('isSorted', () => {
  test('returns true for empty array', () => {
    expect(isSorted([])).toBe(true);
  });
  test('returns true for single element', () => {
    expect(isSorted([1])).toBe(true);
  });
  test('returns true for sorted array', () => {
    expect(isSorted([1, 2, 3, 4, 5])).toBe(true);
  });
  test('returns false for unsorted array', () => {
    expect(isSorted([1, 3, 2, 4, 5])).toBe(false);
  });
});

View file

@ -0,0 +1,95 @@
import { DataProcessor } from '../data_processor';
// Unit tests for DataProcessor. The 100/500-element fixtures double as
// timing workloads for the optimizer.
describe('DataProcessor', () => {
  describe('findDuplicates', () => {
    test('finds duplicates in array with repeated values', () => {
      const processor = new DataProcessor([1, 2, 3, 2, 4, 3, 5]);
      expect(processor.findDuplicates().sort()).toEqual([2, 3]);
    });
    test('returns empty array when no duplicates', () => {
      const processor = new DataProcessor([1, 2, 3, 4, 5]);
      expect(processor.findDuplicates()).toEqual([]);
    });
    test('handles empty array', () => {
      const processor = new DataProcessor<number>([]);
      expect(processor.findDuplicates()).toEqual([]);
    });
    test('handles array with all same values', () => {
      const processor = new DataProcessor([5, 5, 5, 5]);
      expect(processor.findDuplicates()).toEqual([5]);
    });
    test('handles larger arrays with duplicates', () => {
      const data: number[] = [];
      for (let i = 0; i < 100; i++) {
        data.push(i % 20);
      }
      const processor = new DataProcessor(data);
      const duplicates = processor.findDuplicates();
      expect(duplicates.length).toBe(20);
    });
  });
  describe('sortData', () => {
    test('sorts numbers in ascending order', () => {
      const processor = new DataProcessor([5, 2, 8, 1, 9]);
      expect(processor.sortData()).toEqual([1, 2, 5, 8, 9]);
    });
    test('handles already sorted array', () => {
      const processor = new DataProcessor([1, 2, 3, 4, 5]);
      expect(processor.sortData()).toEqual([1, 2, 3, 4, 5]);
    });
    test('handles reverse sorted array', () => {
      const processor = new DataProcessor([5, 4, 3, 2, 1]);
      expect(processor.sortData()).toEqual([1, 2, 3, 4, 5]);
    });
    test('handles array with duplicates', () => {
      const processor = new DataProcessor([3, 1, 4, 1, 5, 9, 2, 6, 5]);
      expect(processor.sortData()).toEqual([1, 1, 2, 3, 4, 5, 5, 6, 9]);
    });
    test('handles larger arrays', () => {
      const data: number[] = [];
      for (let i = 500; i >= 0; i--) {
        data.push(i);
      }
      const processor = new DataProcessor(data);
      const sorted = processor.sortData();
      expect(sorted[0]).toBe(0);
      expect(sorted[sorted.length - 1]).toBe(500);
    });
  });
  describe('getUnique', () => {
    test('returns unique values', () => {
      const processor = new DataProcessor([1, 2, 2, 3, 3, 3, 4]);
      expect(processor.getUnique()).toEqual([1, 2, 3, 4]);
    });
    test('preserves order of first occurrence', () => {
      const processor = new DataProcessor([3, 1, 2, 1, 3, 2]);
      expect(processor.getUnique()).toEqual([3, 1, 2]);
    });
    test('handles empty array', () => {
      const processor = new DataProcessor<number>([]);
      expect(processor.getUnique()).toEqual([]);
    });
    test('handles array with all unique values', () => {
      const processor = new DataProcessor([1, 2, 3, 4, 5]);
      expect(processor.getUnique()).toEqual([1, 2, 3, 4, 5]);
    });
    test('handles strings', () => {
      const processor = new DataProcessor(['a', 'b', 'a', 'c', 'b']);
      expect(processor.getUnique()).toEqual(['a', 'b', 'c']);
    });
  });
});

View file

@ -0,0 +1,97 @@
import { fibonacci, isFibonacci, isPerfectSquare, fibonacciSequence } from '../fibonacci';
describe('fibonacci', () => {
test('returns 0 for n=0', () => {
expect(fibonacci(0)).toBe(0);
});
test('returns 1 for n=1', () => {
expect(fibonacci(1)).toBe(1);
});
test('returns 1 for n=2', () => {
expect(fibonacci(2)).toBe(1);
});
test('returns 5 for n=5', () => {
expect(fibonacci(5)).toBe(5);
});
test('returns 55 for n=10', () => {
expect(fibonacci(10)).toBe(55);
});
test('returns 233 for n=13', () => {
expect(fibonacci(13)).toBe(233);
});
});
describe('isFibonacci', () => {
test('returns true for 0', () => {
expect(isFibonacci(0)).toBe(true);
});
test('returns true for 1', () => {
expect(isFibonacci(1)).toBe(true);
});
test('returns true for 8', () => {
expect(isFibonacci(8)).toBe(true);
});
test('returns true for 13', () => {
expect(isFibonacci(13)).toBe(true);
});
test('returns false for 4', () => {
expect(isFibonacci(4)).toBe(false);
});
test('returns false for 6', () => {
expect(isFibonacci(6)).toBe(false);
});
});
describe('isPerfectSquare', () => {
test('returns true for 0', () => {
expect(isPerfectSquare(0)).toBe(true);
});
test('returns true for 1', () => {
expect(isPerfectSquare(1)).toBe(true);
});
test('returns true for 4', () => {
expect(isPerfectSquare(4)).toBe(true);
});
test('returns true for 16', () => {
expect(isPerfectSquare(16)).toBe(true);
});
test('returns false for 2', () => {
expect(isPerfectSquare(2)).toBe(false);
});
test('returns false for 3', () => {
expect(isPerfectSquare(3)).toBe(false);
});
});
describe('fibonacciSequence', () => {
test('returns empty array for n=0', () => {
expect(fibonacciSequence(0)).toEqual([]);
});
test('returns [0] for n=1', () => {
expect(fibonacciSequence(1)).toEqual([0]);
});
test('returns first 5 Fibonacci numbers', () => {
expect(fibonacciSequence(5)).toEqual([0, 1, 1, 2, 3]);
});
test('returns first 10 Fibonacci numbers', () => {
expect(fibonacciSequence(10)).toEqual([0, 1, 1, 2, 3, 5, 8, 13, 21, 34]);
});
});

View file

@ -0,0 +1,133 @@
import { reverseString, isPalindrome, countOccurrences, longestCommonPrefix, toTitleCase } from '../string_utils';
describe('reverseString', () => {
test('reverses an empty string', () => {
expect(reverseString('')).toBe('');
});
test('reverses a single character', () => {
expect(reverseString('a')).toBe('a');
});
test('reverses a word', () => {
expect(reverseString('hello')).toBe('olleh');
});
test('reverses a sentence', () => {
expect(reverseString('hello world')).toBe('dlrow olleh');
});
test('handles special characters', () => {
expect(reverseString('a!b@c#')).toBe('#c@b!a');
});
test('reverses a longer string for performance', () => {
const input = 'abcdefghijklmnopqrstuvwxyz'.repeat(20);
const result = reverseString(input);
expect(result.length).toBe(input.length);
expect(result[0]).toBe('z');
expect(result[result.length - 1]).toBe('a');
});
test('reverses a medium string', () => {
const input = 'The quick brown fox jumps over the lazy dog';
const expected = 'god yzal eht revo spmuj xof nworb kciuq ehT';
expect(reverseString(input)).toBe(expected);
});
});
describe('isPalindrome', () => {
test('returns true for empty string', () => {
expect(isPalindrome('')).toBe(true);
});
test('returns true for single character', () => {
expect(isPalindrome('a')).toBe(true);
});
test('returns true for palindrome word', () => {
expect(isPalindrome('racecar')).toBe(true);
});
test('returns true for palindrome with mixed case', () => {
expect(isPalindrome('RaceCar')).toBe(true);
});
test('returns true for palindrome with spaces', () => {
expect(isPalindrome('A man a plan a canal Panama')).toBe(true);
});
test('returns false for non-palindrome', () => {
expect(isPalindrome('hello')).toBe(false);
});
});
describe('countOccurrences', () => {
test('returns 0 for empty string', () => {
expect(countOccurrences('', 'a')).toBe(0);
});
test('returns 0 when substring not found', () => {
expect(countOccurrences('hello', 'x')).toBe(0);
});
test('counts single occurrence', () => {
expect(countOccurrences('hello', 'e')).toBe(1);
});
test('counts multiple occurrences', () => {
expect(countOccurrences('hello', 'l')).toBe(2);
});
test('counts overlapping occurrences', () => {
expect(countOccurrences('aaa', 'aa')).toBe(2);
});
test('counts multi-character substring', () => {
expect(countOccurrences('abcabc', 'abc')).toBe(2);
});
});
describe('longestCommonPrefix', () => {
test('returns empty for empty array', () => {
expect(longestCommonPrefix([])).toBe('');
});
test('returns the string for single element', () => {
expect(longestCommonPrefix(['hello'])).toBe('hello');
});
test('finds common prefix', () => {
expect(longestCommonPrefix(['flower', 'flow', 'flight'])).toBe('fl');
});
test('returns empty when no common prefix', () => {
expect(longestCommonPrefix(['dog', 'racecar', 'car'])).toBe('');
});
test('handles identical strings', () => {
expect(longestCommonPrefix(['test', 'test', 'test'])).toBe('test');
});
});
describe('toTitleCase', () => {
test('converts single word', () => {
expect(toTitleCase('hello')).toBe('Hello');
});
test('converts multiple words', () => {
expect(toTitleCase('hello world')).toBe('Hello World');
});
test('handles already title case', () => {
expect(toTitleCase('Hello World')).toBe('Hello World');
});
test('handles all uppercase', () => {
expect(toTitleCase('HELLO WORLD')).toBe('Hello World');
});
test('handles empty string', () => {
expect(toTitleCase('')).toBe('');
});
});

View file

@ -0,0 +1,20 @@
{
  // TypeScript compiler settings for the CommonJS optimization target.
  // tsc parses tsconfig.json as JSONC, so these comments are legal here.
  "compilerOptions": {
    // Emit ES2020 syntax with Node-style CommonJS modules.
    "target": "ES2020",
    "module": "commonjs",
    "lib": ["ES2020"],
    // Compiled output goes to ./dist; sources live at the project root.
    "outDir": "./dist",
    "rootDir": ".",
    // Full strict type-checking suite.
    "strict": true,
    "esModuleInterop": true,
    "skipLibCheck": true,
    "forceConsistentCasingInFileNames": true,
    // Emit .d.ts files and maps alongside source maps for debugging.
    "declaration": true,
    "declarationMap": true,
    "sourceMap": true,
    "resolveJsonModule": true,
    "moduleResolution": "node"
  },
  // Compile root-level sources and everything under tests/.
  "include": ["*.ts", "tests/**/*.ts"],
  "exclude": ["node_modules", "dist"]
}

View file

@ -14,6 +14,7 @@ from codeflash.cli_cmds.console import console, logger
from codeflash.code_utils.env_utils import get_codeflash_api_key
from codeflash.code_utils.git_utils import get_last_commit_author_if_pr_exists, get_repo_owner_and_name
from codeflash.code_utils.time_utils import humanize_runtime
from codeflash.languages import is_javascript, is_python
from codeflash.models.ExperimentMetadata import ExperimentMetadata
from codeflash.models.models import (
AIServiceRefinerRequest,
@ -101,11 +102,11 @@ class AiServiceClient:
return response
def _get_valid_candidates(
self, optimizations_json: list[dict[str, Any]], source: OptimizedCandidateSource
self, optimizations_json: list[dict[str, Any]], source: OptimizedCandidateSource, language: str = "python"
) -> list[OptimizedCandidate]:
candidates: list[OptimizedCandidate] = []
for opt in optimizations_json:
code = CodeStringsMarkdown.parse_markdown_code(opt["source_code"])
code = CodeStringsMarkdown.parse_markdown_code(opt["source_code"], expected_language=language)
if not code.code_strings:
continue
candidates.append(
@ -120,25 +121,32 @@ class AiServiceClient:
)
return candidates
def optimize_python_code( # noqa: D417
def optimize_code(
self,
source_code: str,
dependency_code: str,
trace_id: str,
experiment_metadata: ExperimentMetadata | None = None,
*,
language: str = "python",
language_version: str
| None = None, # TODO:{claude} add language version to the language support and it should be cached
module_system: str | None = None,
is_async: bool = False,
n_candidates: int = 5,
is_numerical_code: bool | None = None,
) -> list[OptimizedCandidate]:
"""Optimize the given python code for performance by making a request to the Django endpoint.
"""Optimize the given code for performance by making a request to the Django endpoint.
Parameters
----------
- source_code (str): The python code to optimize.
- source_code (str): The code to optimize.
- dependency_code (str): The dependency code used as read-only context for the optimization
- trace_id (str): Trace id of optimization run
- experiment_metadata (Optional[ExperimentalMetadata, None]): Any available experiment metadata for this optimization
- language (str): Programming language ("python", "javascript", "typescript")
- language_version (str | None): Language version (e.g., "3.11.0" for Python, "ES2022" for JS)
- module_system (str | None): JS/TS module system ("esm", "commonjs", or None for Python)
- is_async (bool): Whether the function being optimized is async
- n_candidates (int): Number of candidates to generate
@ -152,11 +160,12 @@ class AiServiceClient:
start_time = time.perf_counter()
git_repo_owner, git_repo_name = safe_get_repo_owner_and_name()
payload = {
# Build payload with language-specific fields
payload: dict[str, Any] = {
"source_code": source_code,
"dependency_code": dependency_code,
"trace_id": trace_id,
"python_version": platform.python_version(),
"language": language,
"experiment_metadata": experiment_metadata,
"codeflash_version": codeflash_version,
"current_username": get_last_commit_author_if_pr_exists(None),
@ -167,6 +176,22 @@ class AiServiceClient:
"n_candidates": n_candidates,
"is_numerical_code": is_numerical_code,
}
# Add language-specific version fields
# Always include python_version for backward compatibility with older backend
payload["python_version"] = platform.python_version()
if is_python():
pass # python_version already set
else:
payload["language_version"] = language_version or "ES2022"
# Add module system for JavaScript/TypeScript (esm or commonjs)
if module_system:
payload["module_system"] = module_system
# DEBUG: Print payload language field
logger.debug(
f"Sending optimize request with language='{payload['language']}' (type: {type(payload['language'])})"
)
logger.debug(f"Sending optimize request: trace_id={trace_id}, n_candidates={payload['n_candidates']}")
try:
@ -183,7 +208,7 @@ class AiServiceClient:
logger.debug(f"!lsp|Generating possible optimizations took {end_time - start_time:.2f} seconds.")
logger.info(f"!lsp|Received {len(optimizations_json)} optimization candidates.")
console.rule()
return self._get_valid_candidates(optimizations_json, OptimizedCandidateSource.OPTIMIZE)
return self._get_valid_candidates(optimizations_json, OptimizedCandidateSource.OPTIMIZE, language)
try:
error = response.json()["error"]
except Exception:
@ -193,9 +218,29 @@ class AiServiceClient:
console.rule()
return []
def get_jit_rewritten_code( # noqa: D417
self, source_code: str, trace_id: str
# Backward-compatible alias
def optimize_python_code(
self,
source_code: str,
dependency_code: str,
trace_id: str,
experiment_metadata: ExperimentMetadata | None = None,
*,
is_async: bool = False,
n_candidates: int = 5,
) -> list[OptimizedCandidate]:
"""Backward-compatible alias for optimize_code() with language='python'."""
return self.optimize_code(
source_code=source_code,
dependency_code=dependency_code,
trace_id=trace_id,
experiment_metadata=experiment_metadata,
language="python",
is_async=is_async,
n_candidates=n_candidates,
)
def get_jit_rewritten_code(self, source_code: str, trace_id: str) -> list[OptimizedCandidate]:
"""Rewrite the given python code for performance via jit compilation by making a request to the Django endpoint.
Parameters
@ -245,7 +290,7 @@ class AiServiceClient:
console.rule()
return []
def optimize_python_code_line_profiler( # noqa: D417
def optimize_python_code_line_profiler(
self,
source_code: str,
dependency_code: str,
@ -253,18 +298,22 @@ class AiServiceClient:
line_profiler_results: str,
n_candidates: int,
experiment_metadata: ExperimentMetadata | None = None,
is_numerical_code: bool | None = None, # noqa: FBT001
is_numerical_code: bool | None = None,
language: str = "python",
language_version: str | None = None,
) -> list[OptimizedCandidate]:
"""Optimize the given python code for performance using line profiler results.
"""Optimize code for performance using line profiler results.
Parameters
----------
- source_code (str): The python code to optimize.
- source_code (str): The code to optimize.
- dependency_code (str): The dependency code used as read-only context for the optimization
- trace_id (str): Trace id of optimization run
- line_profiler_results (str): Line profiler output to guide optimization
- experiment_metadata (Optional[ExperimentalMetadata, None]): Any available experiment metadata for this optimization
- n_candidates (int): Number of candidates to generate
- language (str): Programming language (python, javascript, typescript)
- language_version (str): Language version (e.g., "3.12.0" for Python, "ES2022" for JavaScript)
Returns
-------
@ -278,13 +327,18 @@ class AiServiceClient:
logger.info("Generating optimized candidates with line profiler…")
console.rule()
# Set python_version for backward compatibility with Python, or use language_version
python_version = language_version if language_version else platform.python_version()
payload = {
"source_code": source_code,
"dependency_code": dependency_code,
"n_candidates": n_candidates,
"line_profiler_results": line_profiler_results,
"trace_id": trace_id,
"python_version": platform.python_version(),
"python_version": python_version,
"language": language,
"language_version": language_version,
"experiment_metadata": experiment_metadata,
"codeflash_version": codeflash_version,
"call_sequence": self.get_next_sequence(),
@ -345,19 +399,22 @@ class AiServiceClient:
ph("cli-optimize-error-response", {"response_status_code": response.status_code, "error": error})
return None
def optimize_python_code_refinement(self, request: list[AIServiceRefinerRequest]) -> list[OptimizedCandidate]:
"""Optimize the given python code for performance by making a request to the Django endpoint.
def optimize_code_refinement(self, request: list[AIServiceRefinerRequest]) -> list[OptimizedCandidate]:
"""Refine optimization candidates for improved performance.
Supports Python, JavaScript, and TypeScript code refinement with optional
multi-file context for better understanding of imports and dependencies.
Args:
request: A list of optimization candidate details for refinement
request: A list of optimization candidate details for refinement
Returns:
-------
- List[OptimizationCandidate]: A list of Optimization Candidates.
List of refined optimization candidates
"""
payload = [
{
payload: list[dict[str, Any]] = []
for opt in request:
item: dict[str, Any] = {
"optimization_id": opt.optimization_id,
"original_source_code": opt.original_source_code,
"read_only_dependency_code": opt.read_only_dependency_code,
@ -370,11 +427,26 @@ class AiServiceClient:
"speedup": opt.speedup,
"trace_id": opt.trace_id,
"function_references": opt.function_references,
"python_version": platform.python_version(),
"call_sequence": self.get_next_sequence(),
# Multi-language support
"language": opt.language,
}
for opt in request
]
# Add language version - always include python_version for backward compatibility
item["python_version"] = platform.python_version()
if is_python():
pass # python_version already set
elif opt.language_version:
item["language_version"] = opt.language_version
else:
item["language_version"] = "ES2022" # Default for JS/TS
# Add multi-file context if provided
if opt.additional_context_files:
item["additional_context_files"] = opt.additional_context_files
payload.append(item)
try:
response = self.make_ai_service_request("/refinement", payload=payload, timeout=self.timeout)
except requests.exceptions.RequestException as e:
@ -396,6 +468,9 @@ class AiServiceClient:
console.rule()
return []
# Alias for backward compatibility
optimize_python_code_refinement = optimize_code_refinement
def code_repair(self, request: AIServiceCodeRepairRequest) -> OptimizedCandidate | None:
"""Repair the optimization candidate that is not matching the test result of the original code.
@ -415,6 +490,7 @@ class AiServiceClient:
"modified_source_code": request.modified_source_code,
"trace_id": request.trace_id,
"test_diffs": request.test_diffs,
"language": request.language,
}
response = self.make_ai_service_request("/code_repair", payload=payload, timeout=self.timeout)
except (requests.exceptions.RequestException, TypeError) as e:
@ -426,7 +502,9 @@ class AiServiceClient:
fixed_optimization = response.json()
console.rule()
valid_candidates = self._get_valid_candidates([fixed_optimization], OptimizedCandidateSource.REPAIR)
valid_candidates = self._get_valid_candidates(
[fixed_optimization], OptimizedCandidateSource.REPAIR, request.language
)
if not valid_candidates:
logger.error("Code repair failed to generate a valid candidate.")
return None
@ -442,7 +520,7 @@ class AiServiceClient:
console.rule()
return None
def get_new_explanation( # noqa: D417
def get_new_explanation(
self,
source_code: str,
optimized_code: str,
@ -542,7 +620,7 @@ class AiServiceClient:
console.rule()
return ""
def generate_ranking( # noqa: D417
def generate_ranking(
self,
trace_id: str,
diffs: list[str],
@ -594,7 +672,7 @@ class AiServiceClient:
console.rule()
return None
def log_results( # noqa: D417
def log_results(
self,
function_trace_id: str,
speedup_ratio: dict[str, float | None] | None,
@ -635,7 +713,7 @@ class AiServiceClient:
except requests.exceptions.RequestException as e:
logger.exception(f"Error logging features: {e}")
def generate_regression_tests( # noqa: D417
def generate_regression_tests(
self,
source_code_being_tested: str,
function_to_optimize: FunctionToOptimize,
@ -646,7 +724,11 @@ class AiServiceClient:
test_timeout: int,
trace_id: str,
test_index: int,
is_numerical_code: bool | None = None, # noqa: FBT001
*,
language: str = "python",
language_version: str | None = None,
module_system: str | None = None,
is_numerical_code: bool | None = None,
) -> tuple[str, str, str] | None:
"""Generate regression tests for the given function by making a request to the Django endpoint.
@ -657,19 +739,31 @@ class AiServiceClient:
- helper_function_names (list[Source]): List of helper function names.
- module_path (Path): The module path where the function is located.
- test_module_path (Path): The module path for the test code.
- test_framework (str): The test framework to use, e.g., "pytest".
- test_framework (str): The test framework to use, e.g., "pytest", "jest".
- test_timeout (int): The timeout for each test in seconds.
- test_index (int): The index from 0-(n-1) if n tests are generated for a single trace_id
- language (str): Programming language ("python", "javascript", "typescript")
- language_version (str | None): Language version (e.g., "3.11.0" for Python, "ES2022" for JS)
- module_system (str | None): JS/TS module system ("esm", "commonjs", or None for Python)
Returns
-------
- Dict[str, str] | None: The generated regression tests and instrumented tests, or None if an error occurred.
"""
assert test_framework in ["pytest", "unittest"], (
f"Invalid test framework, got {test_framework} but expected 'pytest' or 'unittest'"
)
payload = {
# Validate test framework based on language
python_frameworks = ["pytest", "unittest"]
javascript_frameworks = ["jest", "mocha", "vitest"]
if is_python():
assert test_framework in python_frameworks, (
f"Invalid test framework for Python, got {test_framework} but expected one of {python_frameworks}"
)
elif is_javascript():
assert test_framework in javascript_frameworks, (
f"Invalid test framework for JavaScript, got {test_framework} but expected one of {javascript_frameworks}"
)
payload: dict[str, Any] = {
"source_code_being_tested": source_code_being_tested,
"function_to_optimize": function_to_optimize,
"helper_function_names": helper_function_names,
@ -679,12 +773,26 @@ class AiServiceClient:
"test_timeout": test_timeout,
"trace_id": trace_id,
"test_index": test_index,
"python_version": platform.python_version(),
"language": language,
"codeflash_version": codeflash_version,
"is_async": function_to_optimize.is_async,
"call_sequence": self.get_next_sequence(),
"is_numerical_code": is_numerical_code,
}
# Add language-specific version fields
# Always include python_version for backward compatibility with older backend
payload["python_version"] = platform.python_version()
if is_python():
pass # python_version already set
else:
payload["language_version"] = language_version or "ES2022"
# Add module system for JavaScript/TypeScript (esm or commonjs)
if module_system:
payload["module_system"] = module_system
# DEBUG: Print payload language field
logger.debug(f"Sending testgen request with language='{payload['language']}', framework='{test_framework}'")
try:
response = self.make_ai_service_request("/testgen", payload=payload, timeout=self.timeout)
except requests.exceptions.RequestException as e:
@ -706,7 +814,7 @@ class AiServiceClient:
error = response.json()["error"]
logger.error(f"Error generating tests: {response.status_code} - {error}")
ph("cli-testgen-error-response", {"response_status_code": response.status_code, "error": error})
return None # noqa: TRY300
return None
except Exception:
logger.error(f"Error generating tests: {response.status_code} - {response.text}")
ph("cli-testgen-error-response", {"response_status_code": response.status_code, "error": response.text})
@ -722,8 +830,8 @@ class AiServiceClient:
function_trace_id: str,
coverage_message: str,
replay_tests: str,
concolic_tests: str, # noqa: ARG002
calling_fn_details: str,
language: str = "python",
) -> OptimizationReviewResult:
"""Compute the optimization review of current Pull Request.
@ -765,7 +873,8 @@ class AiServiceClient:
"original_runtime": humanize_runtime(explanation.original_runtime_ns),
"codeflash_version": codeflash_version,
"calling_fn_details": calling_fn_details,
"python_version": platform.python_version(),
"language": language,
"python_version": platform.python_version() if is_python() else None,
"call_sequence": self.get_next_sequence(),
}
console.rule()

View file

@ -81,7 +81,7 @@ def make_cfapi_request(
else:
response = requests.get(url, headers=cfapi_headers, params=params, timeout=60)
response.raise_for_status()
return response # noqa: TRY300
return response
except requests.exceptions.HTTPError:
# response may be either a string or JSON, so we handle both cases
error_message = ""
@ -102,7 +102,7 @@ def make_cfapi_request(
@lru_cache(maxsize=1)
def get_user_id(api_key: Optional[str] = None) -> Optional[str]: # noqa: PLR0911
def get_user_id(api_key: Optional[str] = None) -> Optional[str]:
"""Retrieve the user's userid by making a request to the /cfapi/cli-get-user endpoint.
:param api_key: The API key to use. If None, uses get_codeflash_api_key().
@ -396,7 +396,7 @@ def get_blocklisted_functions() -> dict[str, set[str]] | dict[str, Any]:
def is_function_being_optimized_again(
owner: str, repo: str, pr_number: int, code_contexts: list[dict[str, str]]
) -> Any: # noqa: ANN401
) -> Any:
"""Check if the function being optimized is being optimized again."""
response = make_cfapi_request(
"/is-already-optimized",

260
codeflash/api/schemas.py Normal file
View file

@ -0,0 +1,260 @@
"""Language-agnostic schemas for AI service communication.
This module defines standardized payload schemas that work across all supported
languages (Python, JavaScript, TypeScript, and future languages).
Design principles:
1. General fields that apply to any language
2. Language-specific fields grouped in a nested object
3. Backward compatible with existing backend
4. Extensible for future languages without breaking changes
"""
from __future__ import annotations
from dataclasses import dataclass, field
from enum import Enum
from typing import Any
class ModuleSystem(str, Enum):
    """How the code's modules are loaded.

    String-valued so members serialize directly into JSON payloads.
    """

    COMMONJS = "commonjs"  # Node.js require()/module.exports
    ESM = "esm"  # ECMAScript import/export modules
    PYTHON = "python"  # Python's import system
    UNKNOWN = "unknown"  # module system not (yet) detected
class TestFramework(str, Enum):
    """Test frameworks the AI service can target.

    String-valued so members serialize directly into JSON payloads.
    """

    # Python frameworks
    PYTEST = "pytest"
    UNITTEST = "unittest"
    # JavaScript/TypeScript frameworks
    JEST = "jest"
    MOCHA = "mocha"
    VITEST = "vitest"
@dataclass
class LanguageInfo:
    """Language-specific information, extensible to future languages.

    Attributes:
        name: Core language identifier ("python", "javascript",
            "typescript", "rust", ...).
        version: Language version; the format varies by language
            (Python "3.11.0", JS/TS "ES2022", Rust "1.70.0"). None if unknown.
        module_system: Module system (primarily for JS/TS, but could apply
            to other languages).
        file_extension: Extension for generated files (".py", ".js", ".mjs",
            ".cjs", ".ts", ...). Empty when unspecified.
        has_type_annotations: Whether the language carries type annotations.
        type_checker: Type checker name ("mypy", "typescript", "pyright", ...)
            or None.
    """

    name: str
    version: str | None = None
    module_system: ModuleSystem = ModuleSystem.UNKNOWN
    file_extension: str = ""
    has_type_annotations: bool = False
    type_checker: str | None = None
@dataclass
class TestInfo:
    """Test-related configuration for a generation/verification run.

    Attributes:
        framework: Test framework in use (pytest, jest, ...).
        timeout: Per-test execution timeout in seconds.
        test_patterns: File-name patterns used for test discovery.
        tests_root: Test directory path relative to the project root.
    """

    framework: TestFramework
    timeout: int = 60
    test_patterns: list[str] = field(default_factory=list)
    tests_root: str = "tests"
@dataclass
class OptimizeRequest:
    """Request payload for code optimization.

    Language-agnostic: language-specific details live in ``language_info``.

    Attributes:
        source_code: Code to optimize.
        trace_id: Unique identifier for this optimization run.
        language_info: Language metadata for the source code.
        dependency_code: Read-only context code.
        module_path: Path to the module being optimized. NOTE(review): not
            serialized by to_payload() — confirm whether that is intentional.
        is_async: Whether the target function is async/await.
        is_numerical_code: Whether the code does numerical computation.
        n_candidates: Number of optimization candidates to request.
        codeflash_version: Client version string.
        experiment_metadata: Optional experiment metadata.
        repo_owner: Git repository owner, if known.
        repo_name: Git repository name, if known.
        current_username: Username attributed to the run, if known.
    """

    source_code: str
    trace_id: str
    language_info: LanguageInfo
    dependency_code: str = ""
    module_path: str = ""
    is_async: bool = False
    is_numerical_code: bool | None = None
    n_candidates: int = 5
    codeflash_version: str = ""
    experiment_metadata: dict[str, Any] | None = None
    repo_owner: str | None = None
    repo_name: str | None = None
    current_username: str | None = None

    def to_payload(self) -> dict[str, Any]:
        """Build the API payload dict, keeping backward compatibility."""
        import platform

        info = self.language_info
        payload: dict[str, Any] = {
            "source_code": self.source_code,
            "trace_id": self.trace_id,
            "language": info.name,
            "dependency_code": self.dependency_code,
            "is_async": self.is_async,
            "n_candidates": self.n_candidates,
            "codeflash_version": self.codeflash_version,
            "experiment_metadata": self.experiment_metadata,
            "repo_owner": self.repo_owner,
            "repo_name": self.repo_name,
            "current_username": self.current_username,
            "is_numerical_code": self.is_numerical_code,
        }
        if info.version:
            payload["language_version"] = info.version
        # Older backends always expect python_version, regardless of language.
        payload["python_version"] = platform.python_version()
        # Module system is only meaningful for JS/TS; omit when unknown.
        if info.module_system != ModuleSystem.UNKNOWN:
            payload["module_system"] = info.module_system.value
        return payload
@dataclass
class TestGenRequest:
    """Request payload for test generation.

    Language-agnostic: language specifics live in ``language_info`` and
    ``test_info``.

    Attributes:
        source_code: Code being tested.
        function_name: Function to generate tests for.
        trace_id: Unique identifier for this run.
        language_info: Language metadata.
        test_info: Test framework/timeout configuration.
        module_path: Path to the source module.
        test_module_path: Path where the generated test will live.
        helper_function_names: Helper functions referenced by the target.
        is_async: Whether the target function is async.
        is_numerical_code: Whether the code does numerical computation.
        test_index: Index when generating multiple tests per trace.
        codeflash_version: Client version string.
    """

    source_code: str
    function_name: str
    trace_id: str
    language_info: LanguageInfo
    test_info: TestInfo
    module_path: str = ""
    test_module_path: str = ""
    helper_function_names: list[str] = field(default_factory=list)
    is_async: bool = False
    is_numerical_code: bool | None = None
    test_index: int = 0
    codeflash_version: str = ""

    def to_payload(self) -> dict[str, Any]:
        """Build the API payload dict, keeping backward compatibility."""
        import platform

        info = self.language_info
        payload: dict[str, Any] = {
            "source_code_being_tested": self.source_code,
            "function_to_optimize": {"function_name": self.function_name, "is_async": self.is_async},
            "helper_function_names": self.helper_function_names,
            "module_path": self.module_path,
            "test_module_path": self.test_module_path,
            "test_framework": self.test_info.framework.value,
            "test_timeout": self.test_info.timeout,
            "trace_id": self.trace_id,
            "test_index": self.test_index,
            "language": info.name,
            "codeflash_version": self.codeflash_version,
            "is_async": self.is_async,
            "is_numerical_code": self.is_numerical_code,
        }
        if info.version:
            payload["language_version"] = info.version
        # Older backends always expect python_version, regardless of language.
        payload["python_version"] = platform.python_version()
        # Module system is only meaningful for JS/TS; omit when unknown.
        if info.module_system != ModuleSystem.UNKNOWN:
            payload["module_system"] = info.module_system.value
        return payload
# === Helper functions to create language info ===
def python_language_info(version: str | None = None) -> LanguageInfo:
    """Build a LanguageInfo for Python, defaulting to the running interpreter's version."""
    import platform

    resolved = version if version is not None else platform.python_version()
    return LanguageInfo(
        name="python",
        version=resolved,
        module_system=ModuleSystem.PYTHON,
        file_extension=".py",
        has_type_annotations=True,
        type_checker="mypy",
    )
def javascript_language_info(
    module_system: ModuleSystem = ModuleSystem.COMMONJS, version: str = "ES2022"
) -> LanguageInfo:
    """Build a LanguageInfo for JavaScript; the extension tracks the module system."""
    # ESM sources conventionally use .mjs; everything else falls back to .js.
    extension = ".mjs" if module_system == ModuleSystem.ESM else ".js"
    return LanguageInfo(
        name="javascript",
        version=version,
        module_system=module_system,
        file_extension=extension,
        has_type_annotations=False,
    )
def typescript_language_info(module_system: ModuleSystem = ModuleSystem.ESM, version: str = "ES2022") -> LanguageInfo:
    """Build a LanguageInfo for TypeScript (defaults: ESM module system, ES2022)."""
    info = LanguageInfo(
        name="typescript",
        version=version,
        module_system=module_system,
        file_extension=".ts",
        has_type_annotations=True,
        type_checker="typescript",
    )
    return info

View file

@ -108,7 +108,7 @@ class CodeflashTrace:
func_id = (func.__module__, func.__name__)
@functools.wraps(func)
def wrapper(*args, **kwargs) -> Any: # noqa: ANN002, ANN003, ANN401
def wrapper(*args, **kwargs) -> Any: # noqa: ANN002, ANN003
# Initialize thread-local active functions set if it doesn't exist
if not hasattr(self._thread_local, "active_functions"):
self._thread_local.active_functions = set()

View file

@ -53,7 +53,7 @@ class AddDecoratorTransformer(cst.CSTTransformer):
return updated_node
def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module: # noqa: ARG002
def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module:
# Create import statement for codeflash_trace
if not self.added_codeflash_trace:
return updated_node

View file

@ -110,7 +110,7 @@ class CodeFlashBenchmarkPlugin:
# Process each row
for row in cursor.fetchall():
module_name, class_name, function_name, benchmark_file, benchmark_func, benchmark_line, time_ns = row
module_name, class_name, function_name, benchmark_file, benchmark_func, _benchmark_line, time_ns = row
# Create the function key (module_name.class_name.function_name)
if class_name:
@ -172,7 +172,7 @@ class CodeFlashBenchmarkPlugin:
# Process overhead information
for row in cursor.fetchall():
benchmark_file, benchmark_func, benchmark_line, total_overhead_ns = row
benchmark_file, benchmark_func, _benchmark_line, total_overhead_ns = row
benchmark_key = BenchmarkKey(module_path=benchmark_file, function_name=benchmark_func)
overhead_by_benchmark[benchmark_key] = total_overhead_ns or 0 # Handle NULL sum case
@ -184,7 +184,7 @@ class CodeFlashBenchmarkPlugin:
# Process each row and subtract overhead
for row in cursor.fetchall():
benchmark_file, benchmark_func, benchmark_line, time_ns = row
benchmark_file, benchmark_func, _benchmark_line, time_ns = row
# Create the benchmark key (file::function::line)
benchmark_key = BenchmarkKey(module_path=benchmark_file, function_name=benchmark_func)
@ -200,7 +200,7 @@ class CodeFlashBenchmarkPlugin:
# Pytest hooks
@pytest.hookimpl
def pytest_sessionfinish(self, session, exitstatus) -> None: # noqa: ANN001, ARG002
def pytest_sessionfinish(self, session, exitstatus) -> None: # noqa: ANN001
"""Execute after whole test run is completed."""
# Write any remaining benchmark timings to the database
codeflash_trace.close()
@ -218,7 +218,7 @@ class CodeFlashBenchmarkPlugin:
skip_no_benchmark = pytest.mark.skip(reason="Test requires benchmark fixture")
for item in items:
# Check for direct benchmark fixture usage
has_fixture = hasattr(item, "fixturenames") and "benchmark" in item.fixturenames
has_fixture = hasattr(item, "fixturenames") and "benchmark" in item.fixturenames # ty:ignore[unsupported-operator]
# Check for @pytest.mark.benchmark marker
has_marker = False
@ -236,7 +236,7 @@ class CodeFlashBenchmarkPlugin:
def __init__(self, request: pytest.FixtureRequest) -> None:
self.request = request
def __call__(self, func, *args, **kwargs): # type: ignore # noqa: ANN001, ANN002, ANN003, ANN204, PGH003
def __call__(self, func, *args, **kwargs): # noqa: ANN001, ANN002, ANN003, ANN204
"""Handle both direct function calls and decorator usage."""
if args or kwargs:
# Used as benchmark(func, *args, **kwargs)
@ -249,7 +249,7 @@ class CodeFlashBenchmarkPlugin:
self._run_benchmark(func)
return wrapped_func
def _run_benchmark(self, func, *args, **kwargs): # noqa: ANN001, ANN002, ANN003, ANN202
def _run_benchmark(self, func, *args, **kwargs): # noqa: ANN002, ANN003, ANN202
"""Actual benchmark implementation."""
node_path = getattr(self.request.node, "path", None) or getattr(self.request.node, "fspath", None)
if node_path is None:

View file

@ -20,7 +20,7 @@ def parse_args() -> Namespace:
parser = ArgumentParser()
subparsers = parser.add_subparsers(dest="command", help="Sub-commands")
init_parser = subparsers.add_parser("init", help="Initialize Codeflash for a Python project.")
init_parser = subparsers.add_parser("init", help="Initialize Codeflash for your project.")
init_parser.set_defaults(func=init_codeflash)
subparsers.add_parser("vscode-install", help="Install the Codeflash VSCode extension")
@ -28,7 +28,7 @@ def parse_args() -> Namespace:
init_actions_parser = subparsers.add_parser("init-actions", help="Initialize GitHub Actions workflow")
init_actions_parser.set_defaults(func=install_github_actions)
trace_optimize = subparsers.add_parser("optimize", help="Trace and optimize a Python project.")
trace_optimize = subparsers.add_parser("optimize", help="Trace and optimize your project.")
from codeflash.tracer import main as tracer_main
@ -70,8 +70,8 @@ def parse_args() -> Namespace:
parser.add_argument(
"--module-root",
type=str,
help="Path to the project's Python module that you want to optimize."
" This is the top-level root directory where all the Python source code is located.",
help="Path to the project's module that you want to optimize."
" This is the top-level root directory where all the source code is located.",
)
parser.add_argument(
"--tests-root", type=str, help="Path to the test directory of the project, where all the tests are located."
@ -206,7 +206,21 @@ def process_pyproject_config(args: Namespace) -> Namespace:
setattr(args, key.replace("-", "_"), pyproject_config[key])
assert args.module_root is not None, "--module-root must be specified"
assert Path(args.module_root).is_dir(), f"--module-root {args.module_root} must be a valid directory"
assert args.tests_root is not None, "--tests-root must be specified"
# For JS/TS projects, tests_root is optional (Jest auto-discovers tests)
# Default to module_root if not specified
is_js_ts_project = pyproject_config.get("language") in ("javascript", "typescript")
if args.tests_root is None:
if is_js_ts_project:
# Try common JS test directories, or default to module_root
for test_dir in ["test", "tests", "__tests__"]:
if Path(test_dir).is_dir():
args.tests_root = test_dir
break
if args.tests_root is None:
args.tests_root = args.module_root
else:
raise AssertionError("--tests-root must be specified")
assert Path(args.tests_root).is_dir(), f"--tests-root {args.tests_root} must be a valid directory"
if args.benchmark:
assert args.benchmarks_root is not None, "--benchmarks-root must be specified when running with --benchmark"

View file

@ -43,7 +43,7 @@ def inquirer_wrapper(func: Callable[..., str | bool], *args: str | bool, **kwarg
return func(*new_args, **new_kwargs)
def split_string_to_cli_width(string: str, is_confirm: bool = False) -> list[str]: # noqa: FBT001, FBT002
def split_string_to_cli_width(string: str, is_confirm: bool = False) -> list[str]:
cli_width, _ = shutil.get_terminal_size()
# split string to lines that accommodate "[?] " prefix
cli_width -= len("[?] ")

View file

@ -26,6 +26,15 @@ from codeflash.api.cfapi import get_user_id, is_github_app_installed_on_repo, se
from codeflash.cli_cmds.cli_common import apologize_and_exit
from codeflash.cli_cmds.console import console, logger
from codeflash.cli_cmds.extension import install_vscode_extension
# Import JS/TS init module
from codeflash.cli_cmds.init_javascript import (
ProjectLanguage,
detect_project_language,
determine_js_package_manager,
get_js_dependency_installation_commands,
init_js_project,
)
from codeflash.code_utils.code_utils import validate_relative_directory_path
from codeflash.code_utils.compat import LF
from codeflash.code_utils.config_parser import parse_config_file
@ -57,6 +66,8 @@ CODEFLASH_LOGO: str = (
@dataclass(frozen=True)
class CLISetupInfo:
"""Setup info for Python projects."""
module_root: str
tests_root: str
benchmarks_root: Union[str, None]
@ -68,12 +79,16 @@ class CLISetupInfo:
@dataclass(frozen=True)
class VsCodeSetupInfo:
"""Setup info for VSCode extension initialization."""
module_root: str
tests_root: str
formatter: Union[str, list[str]]
class DependencyManager(Enum):
"""Python dependency managers."""
PIP = auto()
POETRY = auto()
UV = auto()
@ -95,6 +110,15 @@ def init_codeflash() -> None:
console.print(welcome_panel)
console.print()
# TODO:{claude} move the init_javascript to the support folder. Move any other language related specific implementation (other than python) to its support.
# Detect project language
project_language = detect_project_language()
if project_language in (ProjectLanguage.JAVASCRIPT, ProjectLanguage.TYPESCRIPT):
init_js_project(project_language)
return
# Python project flow
did_add_new_key = prompt_api_key()
should_modify, config = should_modify_pyproject_toml()
@ -663,7 +687,7 @@ def create_empty_pyproject_toml(pyproject_toml_path: Path) -> None:
apologize_and_exit()
def install_github_actions(override_formatter_check: bool = False) -> None: # noqa: FBT001, FBT002
def install_github_actions(override_formatter_check: bool = False) -> None:
try:
config, _config_file_path = parse_config_file(override_formatter_check=override_formatter_check)
@ -771,8 +795,16 @@ def install_github_actions(override_formatter_check: bool = False) -> None: # n
# Generate workflow content AFTER user confirmation
logger.info("[cmd_init.py:install_github_actions] User confirmed, generating workflow content...")
# Select the appropriate workflow template based on project language
project_language = detect_project_language_for_workflow(Path.cwd())
if project_language in ("javascript", "typescript"):
workflow_template = "codeflash-optimize-js.yaml"
else:
workflow_template = "codeflash-optimize.yaml"
optimize_yml_content = (
files("codeflash").joinpath("cli_cmds", "workflows", "codeflash-optimize.yaml").read_text(encoding="utf-8")
files("codeflash").joinpath("cli_cmds", "workflows", workflow_template).read_text(encoding="utf-8")
)
materialized_optimize_yml_content = generate_dynamic_workflow_content(
optimize_yml_content, config, git_root, benchmark_mode
@ -1089,11 +1121,12 @@ def install_github_actions(override_formatter_check: bool = False) -> None: # n
apologize_and_exit()
def determine_dependency_manager(pyproject_data: dict[str, Any]) -> DependencyManager: # noqa: PLR0911
def determine_dependency_manager(pyproject_data: dict[str, Any]) -> DependencyManager:
"""Determine which dependency manager is being used based on pyproject.toml contents."""
if (Path.cwd() / "poetry.lock").exists():
cwd = Path.cwd()
if (cwd / "poetry.lock").exists():
return DependencyManager.POETRY
if (Path.cwd() / "uv.lock").exists():
if (cwd / "uv.lock").exists():
return DependencyManager.UV
if "tool" not in pyproject_data:
return DependencyManager.PIP
@ -1168,6 +1201,48 @@ def get_github_action_working_directory(toml_path: Path, git_root: Path) -> str:
working-directory: ./{working_dir}"""
# ============================================================================
# JavaScript/TypeScript GitHub Actions Support
# ============================================================================
# Note: JS package manager and workflow helper functions are imported from init_javascript.py
def detect_project_language_for_workflow(project_root: Path) -> str:
    """Detect the primary language of the project for workflow generation.

    Args:
        project_root: Directory to inspect.

    Returns:
        'python', 'javascript', or 'typescript'
    """
    # Check for TypeScript config - a tsconfig.json is a definitive signal.
    if (project_root / "tsconfig.json").exists():
        return "typescript"

    # Check for JavaScript/TypeScript indicators
    has_package_json = (project_root / "package.json").exists()
    has_pyproject = (project_root / "pyproject.toml").exists()

    if has_package_json and not has_pyproject:
        # Pure JS/TS project
        return "javascript"
    if has_pyproject and not has_package_json:
        # Pure Python project
        return "python"

    # Both (or neither) exist - count source files to determine the primary
    # language. Vendored/generated directories are skipped so an installed
    # node_modules tree (thousands of .js files) or a virtualenv does not
    # drown out the first-party code and misclassify the project.
    skip_dirs = {"node_modules", ".git", ".venv", "venv", "__pycache__", "dist", "build"}
    js_extensions = {".js", ".jsx", ".ts", ".tsx", ".mjs", ".cjs"}
    js_count = 0
    py_count = 0
    for file in project_root.rglob("*"):
        if not file.is_file():
            continue
        # Skip files living under any vendored/generated directory.
        if any(part in skip_dirs for part in file.relative_to(project_root).parts[:-1]):
            continue
        suffix = file.suffix.lower()
        if suffix in js_extensions:
            js_count += 1
        elif suffix == ".py":
            py_count += 1
    return "javascript" if js_count > py_count else "python"
def collect_repo_files_for_workflow(git_root: Path) -> dict[str, Any]:
"""Collect important repository files and directory structure for workflow generation.
@ -1251,10 +1326,7 @@ def collect_repo_files_for_workflow(git_root: Path) -> dict[str, Any]:
def generate_dynamic_workflow_content(
optimize_yml_content: str,
config: tuple[dict[str, Any], Path],
git_root: Path,
benchmark_mode: bool = False, # noqa: FBT001, FBT002
optimize_yml_content: str, config: tuple[dict[str, Any], Path], git_root: Path, benchmark_mode: bool = False
) -> str:
"""Generate workflow content with dynamic steps from AI service, falling back to static template.
@ -1268,7 +1340,15 @@ def generate_dynamic_workflow_content(
module_path = str(Path(config["module_root"]).relative_to(git_root) / "**")
optimize_yml_content = optimize_yml_content.replace("{{ codeflash_module_path }}", module_path)
# Get working directory
# Detect project language
project_language = detect_project_language_for_workflow(Path.cwd())
# For JavaScript/TypeScript projects, use static template customization
# (AI-generated steps are currently Python-only)
if project_language in ("javascript", "typescript"):
return customize_codeflash_yaml_content(optimize_yml_content, config, git_root, benchmark_mode)
# Python project - try AI-generated steps
toml_path = Path.cwd() / "pyproject.toml"
try:
with toml_path.open(encoding="utf8") as pyproject_file:
@ -1378,14 +1458,28 @@ def generate_dynamic_workflow_content(
def customize_codeflash_yaml_content(
optimize_yml_content: str,
config: tuple[dict[str, Any], Path],
git_root: Path,
benchmark_mode: bool = False, # noqa: FBT001, FBT002
optimize_yml_content: str, config: tuple[dict[str, Any], Path], git_root: Path, benchmark_mode: bool = False
) -> str:
module_path = str(Path(config["module_root"]).relative_to(git_root) / "**")
optimize_yml_content = optimize_yml_content.replace("{{ codeflash_module_path }}", module_path)
# Detect project language
project_language = detect_project_language_for_workflow(Path.cwd())
if project_language in ("javascript", "typescript"):
# JavaScript/TypeScript project
return _customize_js_workflow_content(optimize_yml_content, git_root, benchmark_mode)
# Python project (default)
return _customize_python_workflow_content(optimize_yml_content, git_root, benchmark_mode)
def _customize_python_workflow_content(
optimize_yml_content: str,
git_root: Path,
benchmark_mode: bool = False, # noqa: FBT001, FBT002
) -> str:
"""Customize workflow content for Python projects."""
# Get dependency installation commands
toml_path = Path.cwd() / "pyproject.toml"
try:
@ -1404,7 +1498,7 @@ def customize_codeflash_yaml_content(
python_depmanager_installation = get_dependency_manager_installation_string(dep_manager)
optimize_yml_content = optimize_yml_content.replace(
"{{ setup_python_dependency_manager }}", python_depmanager_installation
"{{ setup_runtime_environment }}", python_depmanager_installation
)
install_deps_cmd = get_dependency_installation_commands(dep_manager)
@ -1418,6 +1512,64 @@ def customize_codeflash_yaml_content(
return optimize_yml_content.replace("{{ codeflash_command }}", codeflash_cmd)
# TODO:{claude} Refactor and move to support for language specific
def _customize_js_workflow_content(
    optimize_yml_content: str,
    git_root: Path,
    benchmark_mode: bool = False,  # noqa: FBT001, FBT002
) -> str:
    """Customize workflow content for JavaScript/TypeScript projects.

    Substitutes the template placeholders ({{ working_directory }},
    {{ setup_runtime_steps }}, {{ install_dependencies_command }},
    {{ install_codeflash_step }}, {{ codeflash_command }}) based on the
    detected package manager and package.json contents.

    Args:
        optimize_yml_content: Raw workflow template text with placeholders.
        git_root: Root of the git repository, used to compute the workflow's
            working directory relative to the repo.
        benchmark_mode: When True, append "--benchmark" to the run command.

    Returns:
        The workflow YAML with all placeholders substituted.

    Exits via apologize_and_exit() when no package.json exists in the cwd.
    """
    # Imported lazily to avoid a circular import with init_javascript.
    from codeflash.cli_cmds.init_javascript import (
        get_js_codeflash_install_step,
        get_js_codeflash_run_command,
        get_js_runtime_setup_steps,
        is_codeflash_dependency,
    )

    project_root = Path.cwd()
    package_json_path = project_root / "package.json"
    if not package_json_path.exists():
        click.echo(
            f"I couldn't find a package.json in the current directory.{LF}"
            f"Please run `npm init` or create a package.json file first."
        )
        apologize_and_exit()
    # Determine working directory relative to git root; empty when the
    # project lives at the repo root (no defaults block needed then).
    if project_root == git_root:
        working_dir = ""
    else:
        rel_path = str(project_root.relative_to(git_root))
        working_dir = f"""defaults:
      run:
        working-directory: ./{rel_path}"""
    optimize_yml_content = optimize_yml_content.replace("{{ working_directory }}", working_dir)
    # Determine package manager and codeflash dependency status
    pkg_manager = determine_js_package_manager(project_root)
    codeflash_is_dep = is_codeflash_dependency(project_root)
    # Setup runtime environment (Node.js/Bun)
    runtime_setup = get_js_runtime_setup_steps(pkg_manager)
    optimize_yml_content = optimize_yml_content.replace("{{ setup_runtime_steps }}", runtime_setup)
    # Install dependencies
    install_deps_cmd = get_js_dependency_installation_commands(pkg_manager)
    optimize_yml_content = optimize_yml_content.replace("{{ install_dependencies_command }}", install_deps_cmd)
    # Install codeflash step (only if not a dependency)
    install_codeflash = get_js_codeflash_install_step(pkg_manager, is_dependency=codeflash_is_dep)
    optimize_yml_content = optimize_yml_content.replace("{{ install_codeflash_step }}", install_codeflash)
    # Codeflash run command
    codeflash_cmd = get_js_codeflash_run_command(pkg_manager, is_dependency=codeflash_is_dep)
    if benchmark_mode:
        codeflash_cmd += " --benchmark"
    return optimize_yml_content.replace("{{ codeflash_command }}", codeflash_cmd)
def get_formatter_cmds(formatter: str) -> list[str]:
if formatter == "black":
return ["black $file"]

View file

@ -98,17 +98,32 @@ def code_print(
file_name: Optional[str] = None,
function_name: Optional[str] = None,
lsp_message_id: Optional[str] = None,
language: str = "python",
) -> None:
"""Print code with syntax highlighting.
Args:
code_str: The code to print
file_name: Optional file name for LSP
function_name: Optional function name for LSP
lsp_message_id: Optional LSP message ID
language: Programming language for syntax highlighting ('python', 'javascript', 'typescript')
"""
if is_LSP_enabled():
lsp_log(
LspCodeMessage(code=code_str, file_name=file_name, function_name=function_name, message_id=lsp_message_id)
)
return
"""Print code with syntax highlighting."""
from rich.syntax import Syntax
# Map codeflash language names to rich/pygments lexer names
lexer_map = {"python": "python", "javascript": "javascript", "typescript": "typescript"}
lexer = lexer_map.get(language, "python")
console.rule()
console.print(Syntax(code_str, "python", line_numbers=True, theme="github-dark"))
console.print(Syntax(code_str, lexer, line_numbers=True, theme="github-dark"))
console.rule()

View file

@ -0,0 +1,657 @@
"""JavaScript/TypeScript project initialization for Codeflash."""
# TODO:{claude} move to language support directory
from __future__ import annotations
import json
import os
import sys
from dataclasses import dataclass
from enum import Enum, auto
from pathlib import Path
from typing import Any, Union
import click
import inquirer
from git import InvalidGitRepositoryError, Repo
from rich.console import Group
from rich.panel import Panel
from rich.table import Table
from rich.text import Text
from codeflash.cli_cmds.cli_common import apologize_and_exit
from codeflash.cli_cmds.console import console
from codeflash.code_utils.code_utils import validate_relative_directory_path
from codeflash.code_utils.compat import LF
from codeflash.code_utils.git_utils import get_git_remotes
from codeflash.code_utils.shell_utils import get_shell_rc_path, is_powershell
from codeflash.telemetry.posthog_cf import ph
class ProjectLanguage(Enum):
    """Supported project languages.

    Assigned by detect_project_language() from project marker files.
    """

    PYTHON = auto()  # default when no JS/TS markers are found
    JAVASCRIPT = auto()  # package.json present, no pyproject.toml/setup.py
    TYPESCRIPT = auto()  # tsconfig.json present
class JsPackageManager(Enum):
    """JavaScript/TypeScript package managers.

    Detected from lock files by determine_js_package_manager().
    """

    NPM = auto()  # package-lock.json (also the fallback when only package.json exists)
    YARN = auto()  # yarn.lock
    PNPM = auto()  # pnpm-lock.yaml
    BUN = auto()  # bun.lockb / bun.lock
    UNKNOWN = auto()  # no package.json and no lock file found
@dataclass(frozen=True)
class JSSetupInfo:
    """Setup info for JavaScript/TypeScript projects.

    Only stores values that override auto-detection or user preferences.
    Most config is auto-detected from package.json and project structure.
    """

    # Override values (None means use auto-detected value)
    module_root_override: Union[str, None] = None  # source directory chosen by the user
    formatter_override: Union[list[str], None] = None  # formatter commands; ["disabled"] means no formatter
    # User preferences (stored in config only if non-default)
    git_remote: str = "origin"  # remote Codeflash uses for pull requests
    disable_telemetry: bool = False  # True -> write disableTelemetry into package.json
    ignore_paths: list[str] | None = None  # written as ignorePaths; presumably paths to skip — TODO confirm semantics
    benchmarks_root: Union[str, None] = None  # written as benchmarksRoot when set
# Import theme from cmd_init to avoid duplication
def _get_theme():  # noqa: ANN202
    """Return a fresh CodeflashTheme instance (lazy import breaks a circular dependency)."""
    from codeflash.cli_cmds.cmd_init import CodeflashTheme as _Theme

    return _Theme()
def detect_project_language(project_root: Path | None = None) -> ProjectLanguage:
    """Detect the primary language of the project.

    Args:
        project_root: Root directory to check. Defaults to current directory.

    Returns:
        ProjectLanguage enum value
    """
    root = project_root if project_root is not None else Path.cwd()

    # A tsconfig.json is treated as a definitive TypeScript marker.
    if (root / "tsconfig.json").exists():
        return ProjectLanguage.TYPESCRIPT

    # package.json with no Python packaging files -> pure JS project.
    looks_pythonic = any((root / marker).exists() for marker in ("pyproject.toml", "setup.py"))
    if (root / "package.json").exists() and not looks_pythonic:
        return ProjectLanguage.JAVASCRIPT

    # Everything else defaults to Python.
    return ProjectLanguage.PYTHON
def determine_js_package_manager(project_root: Path) -> JsPackageManager:
    """Determine which JavaScript package manager is being used based on lock files."""
    # First matching lock file wins; order mirrors detection priority.
    lockfile_to_manager = (
        ("bun.lockb", JsPackageManager.BUN),
        ("bun.lock", JsPackageManager.BUN),
        ("pnpm-lock.yaml", JsPackageManager.PNPM),
        ("yarn.lock", JsPackageManager.YARN),
        ("package-lock.json", JsPackageManager.NPM),
    )
    for lockfile, manager in lockfile_to_manager:
        if (project_root / lockfile).exists():
            return manager
    # No lock file at all: assume npm when a package.json is present.
    if (project_root / "package.json").exists():
        return JsPackageManager.NPM
    return JsPackageManager.UNKNOWN
def init_js_project(language: ProjectLanguage) -> None:
    """Initialize Codeflash for a JavaScript/TypeScript project.

    Walks the user through API-key setup, optional package.json
    configuration, GitHub App installation, and GitHub Actions workflow
    generation, prints a completion summary, then exits the process.

    Args:
        language: Detected project language (JAVASCRIPT or TYPESCRIPT);
            affects display strings and the collected settings.

    Note: never returns — terminates with sys.exit(0) on success.
    """
    # Imported lazily to avoid a circular import with cmd_init.
    from codeflash.cli_cmds.cmd_init import install_github_actions, install_github_app, prompt_api_key

    lang_name = "TypeScript" if language == ProjectLanguage.TYPESCRIPT else "JavaScript"
    lang_panel = Panel(
        Text(
            f"📦 Detected {lang_name} project!\n\nI'll help you set up Codeflash for your project.",
            style="cyan",
            justify="center",
        ),
        title=f"🟨 {lang_name} Setup",
        border_style="bright_yellow",
    )
    console.print(lang_panel)
    console.print()
    did_add_new_key = prompt_api_key()
    should_modify, _config = should_modify_package_json_config()
    # Default git remote
    git_remote = "origin"
    if should_modify:
        setup_info = collect_js_setup_info(language)
        git_remote = setup_info.git_remote or "origin"
        configured = configure_package_json(setup_info)
        if not configured:
            apologize_and_exit()
    install_github_app(git_remote)
    install_github_actions(override_formatter_check=True)
    # Show completion message
    usage_table = Table(show_header=False, show_lines=False, border_style="dim")
    usage_table.add_column("Command", style="cyan")
    usage_table.add_column("Description", style="white")
    usage_table.add_row("codeflash --file <path-to-file> --function <function-name>", "Optimize a specific function")
    usage_table.add_row("codeflash --all", "Optimize all functions in all files")
    usage_table.add_row("codeflash --help", "See all available options")
    completion_message = (
        f"⚡️ Codeflash is now set up for your {lang_name} project!\n\nYou can now run any of these commands:"
    )
    if did_add_new_key:
        completion_message += (
            "\n\n🐚 Don't forget to restart your shell to load the CODEFLASH_API_KEY environment variable!"
        )
        # Build a platform-appropriate command for reloading the shell rc file.
        if os.name == "nt":
            reload_cmd = f". {get_shell_rc_path()}" if is_powershell() else f"call {get_shell_rc_path()}"
        else:
            reload_cmd = f"source {get_shell_rc_path()}"
        completion_message += f"\nOr run: {reload_cmd}"
    completion_panel = Panel(
        Group(Text(completion_message, style="bold green"), Text(""), usage_table),
        title="🎉 Setup Complete!",
        border_style="bright_green",
        padding=(1, 2),
    )
    console.print(completion_panel)
    ph("cli-js-installation-successful", {"language": lang_name, "did_add_new_key": did_add_new_key})
    # Setup is terminal: exit instead of falling back into the Python flow.
    sys.exit(0)
def should_modify_package_json_config() -> tuple[bool, dict[str, Any] | None]:
    """Check if package.json has valid codeflash config for JS/TS projects.

    Returns:
        (should_modify, existing_config): should_modify is True when no
        valid "codeflash" section exists or the user opted to re-configure;
        existing_config is the current section when one was found, else None.

    Exits via apologize_and_exit() when no package.json exists in the cwd.
    """
    from rich.prompt import Confirm

    package_json_path = Path.cwd() / "package.json"
    if not package_json_path.exists():
        click.echo("❌ No package.json found. Please run 'npm init' first.")
        apologize_and_exit()
    try:
        with package_json_path.open(encoding="utf8") as f:
            package_data = json.load(f)
        config = package_data.get("codeflash", {})
        if not config:
            return True, None
        # Check if module_root is valid (defaults to "." if not specified)
        module_root = config.get("moduleRoot", ".")
        if not Path(module_root).is_dir():
            return True, None
        # Config is valid - ask if user wants to reconfigure
        return Confirm.ask(
            "✅ A valid Codeflash config already exists in package.json. Do you want to re-configure it?",
            default=False,
            show_default=True,
        ), config
    except Exception:
        # Unreadable or malformed package.json: treat as "needs configuration".
        return True, None
def collect_js_setup_info(language: ProjectLanguage) -> JSSetupInfo:
    """Collect setup information for JavaScript/TypeScript projects.

    Uses auto-detection for most settings and only asks for overrides if needed.

    Args:
        language: Detected project language (affects display strings).

    Returns:
        JSSetupInfo holding only user overrides and preferences; None-valued
        fields mean "use the auto-detected value".

    Exits the process when the cwd is not writable or a prompt is cancelled.
    """
    from rich.prompt import Confirm

    from codeflash.cli_cmds.cmd_init import ask_for_telemetry, get_valid_subdirs
    from codeflash.code_utils.config_js import (
        detect_formatter,
        detect_module_root,
        detect_test_runner,
        get_package_json_data,
    )

    curdir = Path.cwd()
    # Bail out early: the config will be written here later.
    if not os.access(curdir, os.W_OK):
        click.echo(f"❌ The current directory isn't writable, please check your folder permissions and try again.{LF}")
        sys.exit(1)
    lang_name = "TypeScript" if language == ProjectLanguage.TYPESCRIPT else "JavaScript"
    # Load package.json data for detection
    package_json_path = curdir / "package.json"
    package_data = get_package_json_data(package_json_path) or {}
    # Auto-detect values
    detected_module_root = detect_module_root(curdir, package_data)
    detected_test_runner = detect_test_runner(curdir, package_data)
    detected_formatter = detect_formatter(curdir, package_data)
    # Build detection summary
    formatter_display = detected_formatter[0] if detected_formatter else "none detected"
    detection_table = Table(show_header=False, box=None, padding=(0, 2))
    detection_table.add_column("Setting", style="cyan")
    detection_table.add_column("Value", style="green")
    detection_table.add_row("Module root", detected_module_root)
    detection_table.add_row("Test runner", detected_test_runner)
    detection_table.add_row("Formatter", formatter_display)
    detection_panel = Panel(
        Group(Text(f"Auto-detected settings for your {lang_name} project:\n", style="cyan"), detection_table),
        title="🔍 Auto-Detection Results",
        border_style="bright_blue",
    )
    console.print(detection_panel)
    console.print()
    # Ask if user wants to change any settings
    module_root_override = None
    formatter_override = None
    if Confirm.ask("Would you like to change any of these settings?", default=False):
        # Module root override: keep detected, pick a known subdir (test and
        # vendor dirs excluded), use the cwd, or enter a custom path.
        valid_subdirs = get_valid_subdirs()
        curdir_option = f"current directory ({curdir})"
        custom_dir_option = "enter a custom directory…"
        keep_detected_option = f"✓ keep detected ({detected_module_root})"
        module_options = [
            keep_detected_option,
            *[d for d in valid_subdirs if d not in ("tests", "__tests__", "node_modules", detected_module_root)],
            curdir_option,
            custom_dir_option,
        ]
        module_questions = [
            inquirer.List(
                "module_root",
                message=f"Which directory contains your {lang_name} source code?",
                choices=module_options,
                default=keep_detected_option,
                carousel=True,
            )
        ]
        module_answers = inquirer.prompt(module_questions, theme=_get_theme())
        if not module_answers:
            # Prompt cancelled (e.g. Ctrl-C).
            apologize_and_exit()
        module_root_answer = module_answers["module_root"]
        if module_root_answer == keep_detected_option:
            pass  # Keep auto-detected value
        elif module_root_answer == curdir_option:
            module_root_override = "."
        elif module_root_answer == custom_dir_option:
            module_root_override = _prompt_custom_directory("module")
        else:
            module_root_override = module_root_answer
        ph("cli-js-module-root-provided", {"overridden": module_root_override is not None})
        # Formatter override
        formatter_questions = [
            inquirer.List(
                "formatter",
                message="Which code formatter do you use?",
                choices=[
                    (f"✓ keep detected ({formatter_display})", "keep"),
                    ("💅 prettier", "prettier"),
                    ("📐 eslint --fix", "eslint"),
                    ("🔧 other", "other"),
                    ("❌ don't use a formatter", "disabled"),
                ],
                default="keep",
                carousel=True,
            )
        ]
        formatter_answers = inquirer.prompt(formatter_questions, theme=_get_theme())
        if not formatter_answers:
            apologize_and_exit()
        formatter_choice = formatter_answers["formatter"]
        if formatter_choice != "keep":
            formatter_override = get_js_formatter_cmd(formatter_choice)
        ph("cli-js-formatter-provided", {"overridden": formatter_override is not None})
    # Git remote
    git_remote = _get_git_remote_for_setup()
    # Telemetry
    disable_telemetry = not ask_for_telemetry()
    return JSSetupInfo(
        module_root_override=module_root_override,
        formatter_override=formatter_override,
        git_remote=git_remote,
        disable_telemetry=disable_telemetry,
    )
def _prompt_custom_directory(dir_type: str) -> str:
    """Prompt for a custom directory path.

    Loops until the user supplies an existing directory that passes
    validate_relative_directory_path; exits setup if the prompt is cancelled.

    Args:
        dir_type: Human-readable label used in the prompt message
            (e.g. "module").

    Returns:
        The validated relative directory path entered by the user.
    """
    while True:
        custom_questions = [
            inquirer.Path(
                "custom_path",
                message=f"Enter the path to your {dir_type} directory",
                path_type=inquirer.Path.DIRECTORY,
                exists=True,
            )
        ]
        custom_answers = inquirer.prompt(custom_questions, theme=_get_theme())
        if not custom_answers:
            # Prompt cancelled (e.g. Ctrl-C).
            apologize_and_exit()
        custom_path_str = str(custom_answers["custom_path"])
        is_valid, error_msg = validate_relative_directory_path(custom_path_str)
        if is_valid:
            return custom_path_str
        # Invalid path: explain and re-prompt.
        click.echo(f"❌ Invalid path: {error_msg}")
        click.echo("Please enter a valid relative directory path.")
        console.print()
def _get_git_remote_for_setup() -> str:
    """Get git remote for project setup.

    Returns:
        The chosen remote name. "" when the cwd is not inside a git repo or
        the repo has no remotes; a single remote is returned without
        prompting; with several remotes the user picks one (falling back to
        the first remote if the prompt is cancelled).
    """
    try:
        repo = Repo(Path.cwd(), search_parent_directories=True)
        git_remotes = get_git_remotes(repo)
        if not git_remotes:
            return ""
        if len(git_remotes) == 1:
            return git_remotes[0]
        # Multiple remotes: let the user choose which one PRs go through.
        git_panel = Panel(
            Text(
                "🔗 Configure Git Remote for Pull Requests.\n\nCodeflash will use this remote to create pull requests.",
                style="blue",
            ),
            title="🔗 Git Remote Setup",
            border_style="bright_blue",
        )
        console.print(git_panel)
        console.print()
        git_questions = [
            inquirer.List(
                "git_remote",
                message="Which git remote should Codeflash use?",
                choices=git_remotes,
                default="origin",
                carousel=True,
            )
        ]
        git_answers = inquirer.prompt(git_questions, theme=_get_theme())
        return git_answers["git_remote"] if git_answers else git_remotes[0]
    except InvalidGitRepositoryError:
        return ""
def get_js_formatter_cmd(formatter: str) -> list[str]:
    """Get formatter commands for JavaScript/TypeScript."""
    # Known formatters map directly to their npx invocation.
    known_formatters = {
        "prettier": ["npx prettier --write $file"],
        "eslint": ["npx eslint --fix $file"],
    }
    commands = known_formatters.get(formatter)
    if commands is not None:
        return commands
    if formatter == "other":
        # Placeholder command the user must edit by hand.
        click.echo("🔧 In package.json, please replace 'your-formatter' with your formatter command.")
        return ["your-formatter $file"]
    # Anything else (including "disabled") means: no formatter.
    return ["disabled"]
def configure_package_json(setup_info: JSSetupInfo) -> bool:
    """Configure codeflash section in package.json for JavaScript/TypeScript projects.

    Only writes minimal config - values that override auto-detection or user preferences.
    Auto-detected values (language, moduleRoot, testRunner, formatter) are NOT stored
    unless explicitly overridden by the user.

    Args:
        setup_info: Overrides and preferences collected during setup.

    Returns:
        True when package.json was updated (or needed no changes); False when
        it is missing, unparsable, or could not be written.
    """
    package_json_path = Path.cwd() / "package.json"
    try:
        with package_json_path.open(encoding="utf8") as f:
            package_data = json.load(f)
    except FileNotFoundError:
        click.echo("❌ No package.json found. Please run 'npm init' first.")
        return False
    except json.JSONDecodeError as e:
        click.echo(f"❌ Invalid package.json: {e}")
        return False
    # Build minimal codeflash config using camelCase (JS convention)
    # Only include values that override auto-detection or are user preferences
    codeflash_config: dict[str, Any] = {}
    # Module root override (only if user changed from auto-detected)
    if setup_info.module_root_override is not None:
        codeflash_config["moduleRoot"] = setup_info.module_root_override
    # Formatter override (only if user changed from auto-detected)
    if setup_info.formatter_override is not None:
        if setup_info.formatter_override != ["disabled"]:
            codeflash_config["formatterCmds"] = setup_info.formatter_override
        else:
            # "disabled" sentinel becomes an explicit empty command list.
            codeflash_config["formatterCmds"] = []
    # Git remote (only if not default "origin")
    if setup_info.git_remote and setup_info.git_remote not in ("", "origin"):
        codeflash_config["gitRemote"] = setup_info.git_remote
    # User preferences
    if setup_info.disable_telemetry:
        codeflash_config["disableTelemetry"] = True
    if setup_info.ignore_paths:
        codeflash_config["ignorePaths"] = setup_info.ignore_paths
    if setup_info.benchmarks_root:
        codeflash_config["benchmarksRoot"] = setup_info.benchmarks_root
    # Only write codeflash section if there's something to write
    if codeflash_config:
        package_data["codeflash"] = codeflash_config
        action = "Updated"
    else:
        # Remove codeflash section if empty (all auto-detected)
        if "codeflash" in package_data:
            del package_data["codeflash"]
        action = "Configured"
    try:
        with package_json_path.open("w", encoding="utf8") as f:
            json.dump(package_data, f, indent=2)
            f.write("\n")  # Trailing newline
    except OSError as e:
        click.echo(f"❌ Failed to update package.json: {e}")
        return False
    else:
        if codeflash_config:
            click.echo(f"{action} Codeflash configuration in {package_json_path}")
        else:
            click.echo("✅ Using auto-detected configuration (no overrides needed)")
        click.echo()
        return True
# ============================================================================
# GitHub Actions Workflow Helpers for JS/TS
# ============================================================================
def is_codeflash_dependency(project_root: Path) -> bool:
    """Check if codeflash is listed as a dependency in package.json.

    Looks at both "dependencies" and "devDependencies". A missing or
    unparseable package.json counts as "not a dependency".
    """
    manifest = project_root / "package.json"
    if not manifest.exists():
        return False
    try:
        with manifest.open(encoding="utf8") as fh:
            manifest_data = json.load(fh)
    except (json.JSONDecodeError, OSError):
        # Unreadable/malformed manifest: treat as no dependency.
        return False
    return any(
        "codeflash" in manifest_data.get(section, {})
        for section in ("dependencies", "devDependencies")
    )
def get_js_runtime_setup_steps(pkg_manager: JsPackageManager) -> str:
    """Generate the appropriate Node.js/Bun setup steps for GitHub Actions.

    Returns properly indented YAML steps for the workflow template.
    """
    # NOTE(review): this rendered view strips leading whitespace inside the
    # YAML string literals below — confirm their indentation against the
    # actual file before relying on it.
    # Bun ships its own runtime, so it uses the oven-sh setup action
    # instead of setup-node.
    if pkg_manager == JsPackageManager.BUN:
        return """- name: 🥟 Setup Bun
uses: oven-sh/setup-bun@v2
with:
bun-version: latest"""
    # pnpm must be installed before setup-node so the 'pnpm' cache backend
    # is available to actions/setup-node.
    if pkg_manager == JsPackageManager.PNPM:
        return """- name: 📦 Setup pnpm
uses: pnpm/action-setup@v4
with:
version: 9
- name: 🟢 Setup Node.js
uses: actions/setup-node@v4
with:
node-version: '22'
cache: 'pnpm'"""
    # Yarn only differs from npm in the cache backend selected.
    if pkg_manager == JsPackageManager.YARN:
        return """- name: 🟢 Setup Node.js
uses: actions/setup-node@v4
with:
node-version: '22'
cache: 'yarn'"""
    # NPM or UNKNOWN
    return """- name: 🟢 Setup Node.js
uses: actions/setup-node@v4
with:
node-version: '22'
cache: 'npm'"""
def get_js_codeflash_install_step(pkg_manager: JsPackageManager, *, is_dependency: bool) -> str:
    """Generate the codeflash installation step if not already a dependency.

    Args:
        pkg_manager: The package manager being used.
        is_dependency: Whether codeflash is already in package.json dependencies.

    Returns:
        YAML step string for installing codeflash, or empty string if not needed.
    """
    if is_dependency:
        # Codeflash will be installed with other dependencies
        return ""
    # Need to install codeflash separately; all branches install globally (-g).
    # NOTE(review): the rendered view strips indentation inside these YAML
    # literals — confirm against the actual file.
    if pkg_manager == JsPackageManager.BUN:
        return """- name: 📥 Install Codeflash
run: bun add -g codeflash"""
    if pkg_manager == JsPackageManager.PNPM:
        return """- name: 📥 Install Codeflash
run: pnpm add -g codeflash"""
    if pkg_manager == JsPackageManager.YARN:
        return """- name: 📥 Install Codeflash
run: yarn global add codeflash"""
    # NPM or UNKNOWN
    return """- name: 📥 Install Codeflash
run: npm install -g codeflash"""
def get_js_codeflash_run_command(pkg_manager: JsPackageManager, *, is_dependency: bool) -> str:
    """Generate the codeflash run command for GitHub Actions.

    Args:
        pkg_manager: The package manager being used.
        is_dependency: Whether codeflash is in package.json dependencies.

    Returns:
        Command string to run codeflash.
    """
    if not is_dependency:
        # Globally installed - just run directly
        return "codeflash"
    # Local dependency: invoke through the package manager's runner
    # (npx is the fallback for npm and anything unrecognized).
    runner_commands = {
        JsPackageManager.BUN: "bun run codeflash",
        JsPackageManager.PNPM: "pnpm exec codeflash",
        JsPackageManager.YARN: "yarn codeflash",
    }
    return runner_commands.get(pkg_manager, "npx codeflash")
def get_js_runtime_setup_string(pkg_manager: JsPackageManager) -> str:
    """Generate the appropriate Node.js setup step for GitHub Actions.

    Deprecated: Use get_js_runtime_setup_steps instead.
    """
    # Thin backward-compatibility alias kept so existing callers keep working;
    # it simply forwards to the current implementation.
    return get_js_runtime_setup_steps(pkg_manager)
def get_js_dependency_installation_commands(pkg_manager: JsPackageManager) -> str:
    """Generate commands to install JavaScript/TypeScript dependencies."""
    # npm deliberately uses `ci` (clean, lockfile-exact install); the other
    # managers use their plain install verb. Unknown managers fall back to npm.
    install_commands = {
        JsPackageManager.BUN: "bun install",
        JsPackageManager.PNPM: "pnpm install",
        JsPackageManager.YARN: "yarn install",
    }
    return install_commands.get(pkg_manager, "npm ci")
def get_js_codeflash_command(pkg_manager: JsPackageManager) -> str:
    """Generate the appropriate codeflash command for JavaScript/TypeScript projects."""
    # Each manager has its own one-off package runner; npx covers npm and
    # any unrecognized manager.
    runner_by_manager = {
        JsPackageManager.BUN: "bunx codeflash",
        JsPackageManager.PNPM: "pnpm dlx codeflash",
        JsPackageManager.YARN: "yarn dlx codeflash",
    }
    return runner_by_manager.get(pkg_manager, "npx codeflash")

View file

@ -0,0 +1,35 @@
name: Codeflash
on:
pull_request:
paths:
# So that this workflow only runs when code within the target module is modified
- '{{ codeflash_module_path }}'
workflow_dispatch:
concurrency:
# Any new push to the PR will cancel the previous run, so that only the latest code is optimized
group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: true
jobs:
optimize:
name: Optimize new code
# Don't run codeflash on codeflash-ai[bot] commits, prevent duplicate optimizations
if: ${{ github.actor != 'codeflash-ai[bot]' }}
runs-on: ubuntu-latest
env:
CODEFLASH_API_KEY: ${{ secrets.CODEFLASH_API_KEY }}
{{ working_directory }}
steps:
- name: 🛎️ Checkout
uses: actions/checkout@v4
with:
fetch-depth: 0
{{ setup_runtime_steps }}
- name: 📦 Install Dependencies
run: {{ install_dependencies_command }}
{{ install_codeflash_step }}
- name: ⚡️ Codeflash Optimization
run: {{ codeflash_command }}

View file

@ -16,6 +16,7 @@ from libcst.helpers import calculate_module_and_package
from codeflash.cli_cmds.console import logger
from codeflash.code_utils.config_consts import MAX_CONTEXT_LEN_REVIEW
from codeflash.languages.base import Language
from codeflash.models.models import CodePosition, FunctionParent
if TYPE_CHECKING:
@ -44,14 +45,14 @@ class GlobalFunctionCollector(cst.CSTVisitor):
self.scope_depth += 1
return True
def leave_FunctionDef(self, original_node: cst.FunctionDef) -> None: # noqa: ARG002
def leave_FunctionDef(self, original_node: cst.FunctionDef) -> None:
self.scope_depth -= 1
def visit_ClassDef(self, node: cst.ClassDef) -> Optional[bool]: # noqa: ARG002
def visit_ClassDef(self, node: cst.ClassDef) -> Optional[bool]:
self.scope_depth += 1
return True
def leave_ClassDef(self, original_node: cst.ClassDef) -> None: # noqa: ARG002
def leave_ClassDef(self, original_node: cst.ClassDef) -> None:
self.scope_depth -= 1
@ -65,7 +66,7 @@ class GlobalFunctionTransformer(cst.CSTTransformer):
self.processed_functions: set[str] = set()
self.scope_depth = 0
def visit_FunctionDef(self, node: cst.FunctionDef) -> None: # noqa: ARG002
def visit_FunctionDef(self, node: cst.FunctionDef) -> None:
self.scope_depth += 1
def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef) -> cst.FunctionDef:
@ -80,14 +81,14 @@ class GlobalFunctionTransformer(cst.CSTTransformer):
return self.new_functions[name]
return updated_node
def visit_ClassDef(self, node: cst.ClassDef) -> None: # noqa: ARG002
def visit_ClassDef(self, node: cst.ClassDef) -> None:
self.scope_depth += 1
def leave_ClassDef(self, original_node: cst.ClassDef, updated_node: cst.ClassDef) -> cst.ClassDef: # noqa: ARG002
def leave_ClassDef(self, original_node: cst.ClassDef, updated_node: cst.ClassDef) -> cst.ClassDef:
self.scope_depth -= 1
return updated_node
def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module: # noqa: ARG002
def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module:
# Add any new functions that weren't in the original file
new_statements = list(updated_node.body)
@ -141,28 +142,28 @@ class GlobalAssignmentCollector(cst.CSTVisitor):
self.scope_depth = 0
self.if_else_depth = 0
def visit_FunctionDef(self, node: cst.FunctionDef) -> Optional[bool]: # noqa: ARG002
def visit_FunctionDef(self, node: cst.FunctionDef) -> Optional[bool]:
self.scope_depth += 1
return True
def leave_FunctionDef(self, original_node: cst.FunctionDef) -> None: # noqa: ARG002
def leave_FunctionDef(self, original_node: cst.FunctionDef) -> None:
self.scope_depth -= 1
def visit_ClassDef(self, node: cst.ClassDef) -> Optional[bool]: # noqa: ARG002
def visit_ClassDef(self, node: cst.ClassDef) -> Optional[bool]:
self.scope_depth += 1
return True
def leave_ClassDef(self, original_node: cst.ClassDef) -> None: # noqa: ARG002
def leave_ClassDef(self, original_node: cst.ClassDef) -> None:
self.scope_depth -= 1
def visit_If(self, node: cst.If) -> Optional[bool]: # noqa: ARG002
def visit_If(self, node: cst.If) -> Optional[bool]:
self.if_else_depth += 1
return True
def leave_If(self, original_node: cst.If) -> None: # noqa: ARG002
def leave_If(self, original_node: cst.If) -> None:
self.if_else_depth -= 1
def visit_Else(self, node: cst.Else) -> Optional[bool]: # noqa: ARG002
def visit_Else(self, node: cst.Else) -> Optional[bool]:
# Else blocks are already counted as part of the if statement
return True
@ -231,24 +232,24 @@ class GlobalAssignmentTransformer(cst.CSTTransformer):
self.scope_depth = 0
self.if_else_depth = 0
def visit_FunctionDef(self, node: cst.FunctionDef) -> None: # noqa: ARG002
def visit_FunctionDef(self, node: cst.FunctionDef) -> None:
self.scope_depth += 1
def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef) -> cst.FunctionDef: # noqa: ARG002
def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef) -> cst.FunctionDef:
self.scope_depth -= 1
return updated_node
def visit_ClassDef(self, node: cst.ClassDef) -> None: # noqa: ARG002
def visit_ClassDef(self, node: cst.ClassDef) -> None:
self.scope_depth += 1
def leave_ClassDef(self, original_node: cst.ClassDef, updated_node: cst.ClassDef) -> cst.ClassDef: # noqa: ARG002
def leave_ClassDef(self, original_node: cst.ClassDef, updated_node: cst.ClassDef) -> cst.ClassDef:
self.scope_depth -= 1
return updated_node
def visit_If(self, node: cst.If) -> None: # noqa: ARG002
def visit_If(self, node: cst.If) -> None:
self.if_else_depth += 1
def leave_If(self, original_node: cst.If, updated_node: cst.If) -> cst.If: # noqa: ARG002
def leave_If(self, original_node: cst.If, updated_node: cst.If) -> cst.If:
self.if_else_depth -= 1
return updated_node
@ -283,7 +284,7 @@ class GlobalAssignmentTransformer(cst.CSTTransformer):
return updated_node
def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module: # noqa: ARG002
def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module:
# Add any new assignments that weren't in the original file
new_statements = list(updated_node.body)
@ -370,7 +371,7 @@ class GlobalStatementTransformer(cst.CSTTransformer):
super().__init__()
self.global_statements = global_statements
def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module: # noqa: ARG002
def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module:
if not self.global_statements:
return updated_node
@ -396,20 +397,20 @@ class GlobalStatementCollector(cst.CSTVisitor):
self.global_statements = []
self.in_function_or_class = False
def visit_ClassDef(self, node: cst.ClassDef) -> bool: # noqa: ARG002
def visit_ClassDef(self, node: cst.ClassDef) -> bool:
# Don't visit inside classes
self.in_function_or_class = True
return False
def leave_ClassDef(self, original_node: cst.ClassDef) -> None: # noqa: ARG002
def leave_ClassDef(self, original_node: cst.ClassDef) -> None:
self.in_function_or_class = False
def visit_FunctionDef(self, node: cst.FunctionDef) -> bool: # noqa: ARG002
def visit_FunctionDef(self, node: cst.FunctionDef) -> bool:
# Don't visit inside functions
self.in_function_or_class = True
return False
def leave_FunctionDef(self, original_node: cst.FunctionDef) -> None: # noqa: ARG002
def leave_FunctionDef(self, original_node: cst.FunctionDef) -> None:
self.in_function_or_class = False
def visit_SimpleStatementLine(self, node: cst.SimpleStatementLine) -> None:
@ -490,16 +491,16 @@ class DottedImportCollector(cst.CSTVisitor):
self.depth = 0
self._collect_imports_from_block(node)
def visit_FunctionDef(self, node: cst.FunctionDef) -> None: # noqa: ARG002
def visit_FunctionDef(self, node: cst.FunctionDef) -> None:
self.depth += 1
def leave_FunctionDef(self, node: cst.FunctionDef) -> None: # noqa: ARG002
def leave_FunctionDef(self, node: cst.FunctionDef) -> None:
self.depth -= 1
def visit_ClassDef(self, node: cst.ClassDef) -> None: # noqa: ARG002
def visit_ClassDef(self, node: cst.ClassDef) -> None:
self.depth += 1
def leave_ClassDef(self, node: cst.ClassDef) -> None: # noqa: ARG002
def leave_ClassDef(self, node: cst.ClassDef) -> None:
self.depth -= 1
def visit_If(self, node: cst.If) -> None:
@ -529,9 +530,7 @@ def find_last_import_line(target_code: str) -> int:
class FutureAliasedImportTransformer(cst.CSTTransformer):
def leave_ImportFrom(
self,
original_node: cst.ImportFrom, # noqa: ARG002
updated_node: cst.ImportFrom,
self, original_node: cst.ImportFrom, updated_node: cst.ImportFrom
) -> cst.BaseSmallStatement | cst.FlattenSentinel[cst.BaseSmallStatement] | cst.RemovalSentinel:
import libcst.matchers as m
@ -676,7 +675,7 @@ def resolve_star_import(module_name: str, project_root: Path) -> set[str]:
if not name.startswith("_"):
public_names.add(name)
return public_names # noqa: TRY300
return public_names
except Exception as e:
logger.warning(f"Error resolving star import for {module_name}: {e}")
@ -1165,7 +1164,7 @@ class FunctionCallFinder(ast.NodeVisitor):
return False
def _get_call_name(self, func_node) -> Optional[str]: # noqa: ANN001
def _get_call_name(self, func_node) -> Optional[str]:
"""Extract the name being called from a function node."""
# Fast path short-circuit for ast.Name nodes
if isinstance(func_node, ast.Name):
@ -1341,9 +1340,12 @@ def get_fn_references_jedi(
source_code: str, file_path: Path, project_root: Path, target_function: str, target_class: str | None
) -> list[Path]:
start_time = time.perf_counter()
function_position: CodePosition = find_specific_function_in_file(
function_position: CodePosition | None = find_specific_function_in_file(
source_code, file_path, target_function, target_class
)
if function_position is None:
# Function not found (may be non-Python code)
return []
try:
script = jedi.Script(code=source_code, path=file_path, project=jedi.Project(path=project_root))
# Get references to the function
@ -1555,15 +1557,15 @@ def is_numerical_code(code_string: str, function_name: str | None = None) -> boo
# If numba is not installed and all modules used require numba for optimization,
# return False since we can't optimize this code
if not has_numba and modules_used.issubset(NUMBA_REQUIRED_MODULES): # noqa : SIM103
return False
return True
return not (not has_numba and modules_used.issubset(NUMBA_REQUIRED_MODULES))
def get_opt_review_metrics(
source_code: str, file_path: Path, qualified_name: str, project_root: Path, tests_root: Path
source_code: str, file_path: Path, qualified_name: str, project_root: Path, tests_root: Path, language: Language
) -> str:
if language != Language.PYTHON:
# TODO: {Claude} handle function refrences for other languages
return ""
start_time = time.perf_counter()
try:
qualified_name_split = qualified_name.rsplit(".", maxsplit=1)

View file

@ -18,12 +18,15 @@ from codeflash.code_utils.code_extractor import (
from codeflash.code_utils.config_parser import find_conftest_files
from codeflash.code_utils.formatter import sort_imports
from codeflash.code_utils.line_profile_utils import ImportAdder
from codeflash.languages import is_python
from codeflash.models.models import FunctionParent
if TYPE_CHECKING:
from pathlib import Path
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.languages.base import Language, LanguageSupport
from codeflash.languages.treesitter_utils import TreeSitterAnalyzer
from codeflash.models.models import CodeOptimizationContext, CodeStringsMarkdown, OptimizedCandidate, ValidCode
ASTNodeT = TypeVar("ASTNodeT", bound=ast.AST)
@ -109,7 +112,7 @@ class PytestMarkAdder(cst.CSTTransformer):
if isinstance(import_alias, cst.ImportAlias) and import_alias.name.value == "pytest":
self.has_pytest_import = True
def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module: # noqa: ARG002
def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module:
"""Add pytest import if not present."""
if not self.has_pytest_import:
# Create import statement
@ -118,7 +121,7 @@ class PytestMarkAdder(cst.CSTTransformer):
updated_node = updated_node.with_changes(body=[import_stmt, *updated_node.body])
return updated_node
def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef) -> cst.FunctionDef: # noqa: ARG002
def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef) -> cst.FunctionDef:
"""Add pytest mark to test functions."""
# Check if the mark already exists
for decorator in updated_node.decorators:
@ -291,7 +294,7 @@ class OptimFunctionCollector(cst.CSTVisitor):
return True
def leave_ClassDef(self, node: cst.ClassDef) -> None: # noqa: ARG002
def leave_ClassDef(self, node: cst.ClassDef) -> None:
if self.current_class:
self.current_class = None
@ -315,7 +318,7 @@ class OptimFunctionReplacer(cst.CSTTransformer):
)
self.current_class = None
def visit_FunctionDef(self, node: cst.FunctionDef) -> bool: # noqa: ARG002
def visit_FunctionDef(self, node: cst.FunctionDef) -> bool:
return False
def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef) -> cst.FunctionDef:
@ -344,7 +347,7 @@ class OptimFunctionReplacer(cst.CSTTransformer):
)
return updated_node
def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module: # noqa: ARG002
def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module:
node = updated_node
max_function_index = None
max_class_index = None
@ -440,8 +443,15 @@ def replace_function_definitions_in_module(
module_abspath: Path,
preexisting_objects: set[tuple[str, tuple[FunctionParent, ...]]],
project_root_path: Path,
should_add_global_assignments: bool = True, # noqa: FBT001, FBT002
should_add_global_assignments: bool = True,
function_to_optimize: Optional[FunctionToOptimize] = None,
) -> bool:
# Route to language-specific implementation for non-Python languages
if not is_python():
return replace_function_definitions_for_language(
function_names, optimized_code, module_abspath, project_root_path, function_to_optimize
)
source_code: str = module_abspath.read_text(encoding="utf8")
code_to_apply = get_optimized_code_for_module(module_abspath.relative_to(project_root_path), optimized_code)
@ -463,16 +473,271 @@ def replace_function_definitions_in_module(
return True
def replace_function_definitions_for_language(
    function_names: list[str],
    optimized_code: CodeStringsMarkdown,
    module_abspath: Path,
    project_root_path: Path,
    function_to_optimize: Optional[FunctionToOptimize] = None,
) -> bool:
    """Replace function definitions for non-Python languages.

    Uses the language support abstraction to perform code replacement.

    Args:
        function_names: List of qualified function names to replace.
        optimized_code: The optimized code to apply.
        module_abspath: Path to the module file.
        project_root_path: Root of the project.
        function_to_optimize: The function being optimized (needed for line info).

    Returns:
        True if the code was modified, False if no changes.
    """
    # Local imports keep the language machinery out of the module import path.
    from codeflash.languages import get_language_support
    from codeflash.languages.base import FunctionInfo, Language, ParentInfo

    original_source_code: str = module_abspath.read_text(encoding="utf8")
    code_to_apply = get_optimized_code_for_module(module_abspath.relative_to(project_root_path), optimized_code)
    if not code_to_apply.strip():
        # Nothing targeted at this module: leave the file untouched.
        return False
    # Get language support
    language = Language(optimized_code.language)
    lang_support = get_language_support(language)
    # Add any new global declarations from the optimized code to the original source
    original_source_code = _add_global_declarations_for_language(
        optimized_code=code_to_apply,
        original_source=original_source_code,
        module_abspath=module_abspath,
        language=language,
    )
    # If we have function_to_optimize with line info and this is the main file, use it for precise replacement
    if (
        function_to_optimize
        and function_to_optimize.starting_line
        and function_to_optimize.ending_line
        and function_to_optimize.file_path == module_abspath
    ):
        parents = tuple(ParentInfo(name=p.name, type=p.type) for p in function_to_optimize.parents)
        func_info = FunctionInfo(
            name=function_to_optimize.function_name,
            file_path=module_abspath,
            start_line=function_to_optimize.starting_line,
            end_line=function_to_optimize.ending_line,
            parents=parents,
            is_async=function_to_optimize.is_async,
            language=language,
        )
        # Extract just the target function from the optimized code
        optimized_func = _extract_function_from_code(
            lang_support, code_to_apply, function_to_optimize.function_name, module_abspath
        )
        if optimized_func:
            new_code = lang_support.replace_function(original_source_code, func_info, optimized_func)
        else:
            # Fallback: use the entire optimized code (for simple single-function files)
            new_code = lang_support.replace_function(original_source_code, func_info, code_to_apply)
    else:
        # For helper files or when we don't have precise line info:
        # Find each function by name in both original and optimized code
        # Then replace with the corresponding optimized version
        new_code = original_source_code
        modified = False
        # Get the list of function names to replace
        functions_to_replace = list(function_names)
        for func_name in functions_to_replace:
            # Re-discover functions from current code state to get correct line numbers
            # (each replacement shifts subsequent functions' line numbers).
            current_functions = lang_support.discover_functions_from_source(new_code, module_abspath)
            # Find the function in current code
            func = None
            for f in current_functions:
                if func_name in (f.qualified_name, f.name):
                    func = f
                    break
            if func is None:
                continue
            # Extract just this function from the optimized code
            optimized_func = _extract_function_from_code(lang_support, code_to_apply, func.name, module_abspath)
            if optimized_func:
                new_code = lang_support.replace_function(new_code, func, optimized_func)
                modified = True
        if not modified:
            # NOTE(review): if only global declarations were added (no function
            # replaced), this path discards them — confirm that is intended.
            logger.warning(f"Could not find function {function_names} in {module_abspath}")
            return False
    # Check if there was actually a change
    if original_source_code.strip() == new_code.strip():
        return False
    module_abspath.write_text(new_code, encoding="utf8")
    return True
def _extract_function_from_code(
    lang_support: LanguageSupport, source_code: str, function_name: str, file_path: Path | None = None
) -> str | None:
    """Extract a specific function's source code from a code string.

    Includes JSDoc/docstring comments if present.

    Args:
        lang_support: Language support instance.
        source_code: The full source code containing the function.
        function_name: Name of the function to extract.
        file_path: Path to the file (used to determine correct analyzer for JS/TS).

    Returns:
        The function's source code (including doc comments), or None if not found.
    """
    try:
        # file_path lets JS/TS pick the matching analyzer (TypeScript vs JavaScript).
        for candidate in lang_support.discover_functions_from_source(source_code, file_path):
            if candidate.name != function_name:
                continue
            source_lines = source_code.splitlines(keepends=True)
            # Prefer doc_start_line so the JSDoc/docstring travels with the body.
            first_line = candidate.doc_start_line or candidate.start_line
            if first_line and candidate.end_line and first_line <= len(source_lines):
                return "".join(source_lines[first_line - 1 : candidate.end_line])
    except Exception as e:
        logger.debug(f"Error extracting function {function_name}: {e}")
    return None
def _add_global_declarations_for_language(
    optimized_code: str, original_source: str, module_abspath: Path, language: Language
) -> str:
    """Add new global declarations from optimized code to original source.

    Finds module-level declarations (const, let, var, class, type, interface, enum)
    in the optimized code that don't exist in the original source and adds them.

    Args:
        optimized_code: The optimized code that may contain new declarations.
        original_source: The original source code.
        module_abspath: Path to the module file (for parser selection).
        language: The language of the code.

    Returns:
        Original source with new declarations added after imports.
    """
    # NOTE(review): re-imports Language locally; this shadows any module-level
    # import of the same name within this function.
    from codeflash.languages.base import Language

    # Only process JavaScript/TypeScript
    if language not in (Language.JAVASCRIPT, Language.TYPESCRIPT):
        return original_source
    try:
        from codeflash.languages.treesitter_utils import get_analyzer_for_file

        analyzer = get_analyzer_for_file(module_abspath)
        # Find declarations in both original and optimized code
        original_declarations = analyzer.find_module_level_declarations(original_source)
        optimized_declarations = analyzer.find_module_level_declarations(optimized_code)
        if not optimized_declarations:
            return original_source
        # Get names of existing declarations
        existing_names = {decl.name for decl in original_declarations}
        # Find new declarations (names that don't exist in original)
        new_declarations = []
        seen_sources = set()  # Track to avoid duplicates from destructuring
        for decl in optimized_declarations:
            if decl.name not in existing_names and decl.source_code not in seen_sources:
                new_declarations.append(decl)
                seen_sources.add(decl.source_code)
        if not new_declarations:
            return original_source
        # Sort by line number to maintain order
        new_declarations.sort(key=lambda d: d.start_line)
        # Find insertion point (after imports)
        lines = original_source.splitlines(keepends=True)
        insertion_line = _find_insertion_line_after_imports_js(lines, analyzer, original_source)
        # Build new declarations string
        new_decl_code = "\n".join(decl.source_code for decl in new_declarations)
        new_decl_code = new_decl_code + "\n\n"
        # Insert declarations
        before = lines[:insertion_line]
        after = lines[insertion_line:]
        result_lines = [*before, new_decl_code, *after]
        return "".join(result_lines)
    except Exception as e:
        # Best-effort pass: any parser/analyzer failure leaves the source untouched.
        logger.debug(f"Error adding global declarations: {e}")
        return original_source
def _find_insertion_line_after_imports_js(lines: list[str], analyzer: TreeSitterAnalyzer, source: str) -> int:
"""Find the line index where new declarations should be inserted (after imports).
Args:
lines: Source lines.
analyzer: TreeSitter analyzer for the file.
source: Full source code.
Returns:
Line index (0-based) for insertion.
"""
try:
imports = analyzer.find_imports(source)
if imports:
# Find the last import's end line
return max(imp.end_line for imp in imports)
except Exception as exc:
logger.debug(f"Exception occurred in _find_insertion_line_after_imports_js: {exc}")
# Default: insert at beginning (after any shebang/directive comments)
for i, line in enumerate(lines):
stripped = line.strip()
if stripped and not stripped.startswith("//") and not stripped.startswith("#!"):
return i
return 0
def get_optimized_code_for_module(relative_path: Path, optimized_code: CodeStringsMarkdown) -> str:
    """Return the optimized code block targeted at ``relative_path``.

    Args:
        relative_path: Module path relative to the project root; its string
            form is the lookup key into the optimized-code context.
        optimized_code: Parsed markdown containing per-file code blocks.

    Returns:
        The optimized source for the module, or "" when no matching block exists.
    """
    file_to_code_context = optimized_code.file_to_path()
    module_optimized_code = file_to_code_context.get(str(relative_path))
    if module_optimized_code is None:
        # As rendered, this span contained both the pre-merge body (warn and
        # reset unconditionally) and the post-merge fallback, so a successful
        # fallback lookup was immediately clobbered with "". The fallback must
        # be nested under the miss case, and the warning only fires when the
        # fallback also fails.
        # Fallback: if there's only one code block with None file path,
        # use it regardless of the expected path (the AI server doesn't always
        # include file paths).
        if "None" in file_to_code_context and len(file_to_code_context) == 1:
            module_optimized_code = file_to_code_context["None"]
            logger.debug(f"Using code block with None file_path for {relative_path}")
        else:
            logger.warning(
                f"Optimized code not found for {relative_path} In the context\n-------\n{optimized_code}\n-------\n"
                "re-check your 'markdown code structure'"
                f"existing files are {file_to_code_context.keys()}"
            )
            module_optimized_code = ""
    return module_optimized_code
@ -518,7 +783,8 @@ def replace_optimized_code(
[
callee.qualified_name
for callee in code_context.helper_functions
if callee.file_path == module_path and callee.jedi_definition.type != "class"
if callee.file_path == module_path
and (callee.jedi_definition is None or callee.jedi_definition.type != "class")
]
),
candidate.source_code,

View file

@ -166,7 +166,7 @@ def filter_args(addopts_args: list[str]) -> list[str]:
return filtered_args
def modify_addopts(config_file: Path) -> tuple[str, bool]: # noqa : PLR0911
def modify_addopts(config_file: Path) -> tuple[str, bool]:
file_type = config_file.suffix.lower()
filename = config_file.name
config = None

View file

@ -46,7 +46,7 @@ def extract_test_context_from_env() -> tuple[str, str | None, str]:
def codeflash_behavior_async(func: F) -> F:
@wraps(func)
async def async_wrapper(*args: Any, **kwargs: Any) -> Any: # noqa: ANN401
async def async_wrapper(*args: Any, **kwargs: Any) -> Any:
loop = asyncio.get_running_loop()
function_name = func.__name__
line_id = os.environ["CODEFLASH_CURRENT_LINE_ID"]
@ -122,7 +122,7 @@ def codeflash_behavior_async(func: F) -> F:
def codeflash_performance_async(func: F) -> F:
@wraps(func)
async def async_wrapper(*args: Any, **kwargs: Any) -> Any: # noqa: ANN401
async def async_wrapper(*args: Any, **kwargs: Any) -> Any:
loop = asyncio.get_running_loop()
function_name = func.__name__
line_id = os.environ["CODEFLASH_CURRENT_LINE_ID"]
@ -172,7 +172,7 @@ def codeflash_concurrency_async(func: F) -> F:
"""Measures concurrent vs sequential execution performance for async functions."""
@wraps(func)
async def async_wrapper(*args: Any, **kwargs: Any) -> Any: # noqa: ANN401
async def async_wrapper(*args: Any, **kwargs: Any) -> Any:
function_name = func.__name__
concurrency_factor = int(os.environ.get("CODEFLASH_CONCURRENCY_FACTOR", "10"))

View file

@ -71,7 +71,7 @@ class AssertCleanup:
unittest_match = self.unittest_re.match(line)
if unittest_match:
indent, assert_method, args = unittest_match.groups()
indent, _assert_method, args = unittest_match.groups()
if args:
arg_parts = self._first_top_level_arg(args)

View file

@ -91,7 +91,7 @@ EFFORT_VALUES: dict[str, dict[EffortLevel, Any]] = {
}
def get_effort_value(key: EffortKeys, effort: Union[EffortLevel, str]) -> Any: # noqa: ANN401
def get_effort_value(key: EffortKeys, effort: Union[EffortLevel, str]) -> Any:
key_str = key.value
if isinstance(effort, str):

View file

@ -0,0 +1,290 @@
"""JavaScript/TypeScript configuration parsing from package.json."""
from __future__ import annotations
import json
from pathlib import Path
from typing import Any
# Module-level caches shared across calls.
PACKAGE_JSON_CACHE: dict[Path, Path] = {}
PACKAGE_JSON_DATA_CACHE: dict[Path, dict[str, Any]] = {}


def get_package_json_data(package_json_path: Path) -> dict[str, Any] | None:
    """Load and cache package.json data.

    Args:
        package_json_path: Path to package.json file.

    Returns:
        Parsed package.json data or None if invalid.
    """
    cached = PACKAGE_JSON_DATA_CACHE.get(package_json_path)
    if cached is not None:
        return cached
    try:
        raw_text = package_json_path.read_text(encoding="utf8")
        data: dict[str, Any] = json.loads(raw_text)
    except (json.JSONDecodeError, OSError):
        # Missing or malformed manifests are treated as absent; failures are
        # not cached so a later fix is picked up.
        return None
    PACKAGE_JSON_DATA_CACHE[package_json_path] = data
    return data
def detect_language(project_root: Path) -> str:
    """Detect project language from tsconfig.json presence.

    Args:
        project_root: Root directory of the project.

    Returns:
        "typescript" if tsconfig.json exists, "javascript" otherwise.
    """
    # tsconfig.json is the canonical marker of a TypeScript project.
    if (project_root / "tsconfig.json").exists():
        return "typescript"
    return "javascript"
def detect_module_root(project_root: Path, package_data: dict[str, Any]) -> str:
"""Detect module root from package.json fields or directory conventions.
Detection order:
1. package.json "exports" field (extract directory from main export)
2. package.json "module" field (ESM entry point)
3. package.json "main" field (CJS entry point)
4. "src/" directory if it exists
5. Fall back to "." (project root)
Args:
project_root: Root directory of the project.
package_data: Parsed package.json data.
Returns:
Detected module root path (relative to project root).
"""
# Check exports field (modern Node.js)
exports = package_data.get("exports")
if exports:
entry_path = None
if isinstance(exports, str):
entry_path = exports
elif isinstance(exports, dict):
# Handle {"." : "./src/index.js"} or {".": {"import": "./src/index.js"}}
main_export = exports.get(".") or exports.get("import") or exports.get("default")
if isinstance(main_export, str):
entry_path = main_export
elif isinstance(main_export, dict):
entry_path = main_export.get("import") or main_export.get("default") or main_export.get("require")
if entry_path and isinstance(entry_path, str):
parent = Path(entry_path).parent
if parent != Path() and (project_root / parent).is_dir():
return parent.as_posix()
# Check module field (ESM)
module_field = package_data.get("module")
if module_field and isinstance(module_field, str):
parent = Path(module_field).parent
if parent != Path() and (project_root / parent).is_dir():
return parent.as_posix()
# Check main field (CJS)
main_field = package_data.get("main")
if main_field and isinstance(main_field, str):
parent = Path(main_field).parent
if parent != Path() and (project_root / parent).is_dir():
return parent.as_posix()
# Check for src/ directory convention
if (project_root / "src").is_dir():
return "src"
# Default to project root
return "."
def detect_test_runner(project_root: Path, package_data: dict[str, Any]) -> str:  # noqa: ARG001
    """Infer the project's test runner from dependencies or the test script.

    Detection order:
    1. dependencies/devDependencies, preferring vitest, then jest, then mocha
    2. a substring match against scripts.test
    3. "jest" as the default

    Args:
        project_root: Root directory of the project (unused; kept for API symmetry).
        package_data: Parsed package.json data.

    Returns:
        Test runner command name ("vitest", "jest", or "mocha").
    """
    known_runners = ("vitest", "jest", "mocha")

    # devDependencies take precedence over dependencies on key collisions.
    merged_deps = dict(package_data.get("dependencies", {}))
    merged_deps.update(package_data.get("devDependencies", {}))
    for candidate in known_runners:
        if candidate in merged_deps:
            return candidate

    # Fall back to sniffing the "test" script for a runner name.
    test_script = package_data.get("scripts", {}).get("test", "")
    if isinstance(test_script, str):
        lowered = test_script.lower()
        for candidate in known_runners:
            if candidate in lowered:
                return candidate

    return "jest"
def detect_formatter(project_root: Path, package_data: dict[str, Any]) -> list[str] | None: # noqa: ARG001
"""Detect formatter from devDependencies.
Detection order:
1. Check devDependencies for prettier
2. Check devDependencies for eslint (with --fix)
3. Return None if no formatter detected
Args:
project_root: Root directory of the project.
package_data: Parsed package.json data.
Returns:
List of formatter commands or None if not detected.
"""
dev_deps = package_data.get("devDependencies", {})
deps = package_data.get("dependencies", {})
all_deps = {**deps, **dev_deps}
# Check for prettier (preferred)
if "prettier" in all_deps:
return ["npx prettier --write $file"]
# Check for eslint (can format with --fix)
if "eslint" in all_deps:
return ["npx eslint --fix $file"]
return None
def find_package_json(config_file: Path | None = None) -> Path | None:
"""Find package.json file for JavaScript/TypeScript projects.
Args:
config_file: Optional explicit config file path.
Returns:
Path to package.json if found, None otherwise.
"""
if config_file is not None:
config_file = Path(config_file)
if config_file.name == "package.json" and config_file.exists():
return config_file
return None
dir_path = Path.cwd()
cur_path = dir_path
if cur_path in PACKAGE_JSON_CACHE:
return PACKAGE_JSON_CACHE[cur_path]
while dir_path != dir_path.parent:
config_file = dir_path / "package.json"
if config_file.exists():
PACKAGE_JSON_CACHE[cur_path] = config_file
return config_file
dir_path = dir_path.parent
return None
def parse_package_json_config(package_json_path: Path) -> tuple[dict[str, Any], Path] | None:
    """Build a codeflash config from package.json, auto-detecting most values.

    Only a minimal set of settings lives under the "codeflash" key:
    - benchmarksRoot: where to store benchmark files (optional)
    - ignorePaths: paths excluded from optimization (optional)
    - disableTelemetry: privacy preference (optional, defaults to false)
    - formatterCmds: override for the auto-detected formatter (optional)
    - moduleRoot / gitRemote: optional overrides

    Auto-detected (never stored in the config file):
    - language: from tsconfig.json presence
    - moduleRoot: from exports/module/main fields or a src/ directory
    - testRunner: from dependencies (vitest/jest/mocha)
    - formatter: from dependencies (prettier/eslint)

    Args:
        package_json_path: Path to package.json.

    Returns:
        (config dict, package.json path) on success, None if package.json
        is missing or invalid.
    """
    package_data = get_package_json_data(package_json_path)
    if package_data is None:
        return None

    project_root = package_json_path.parent
    raw_section = package_data.get("codeflash", {})
    user_config = raw_section if isinstance(raw_section, dict) else {}

    config: dict[str, Any] = {"language": detect_language(project_root)}

    # Module root: explicit override wins, otherwise auto-detect.
    module_root = user_config.get("moduleRoot") or detect_module_root(project_root, package_data)
    config["module_root"] = str((project_root / Path(module_root)).resolve())

    runner = detect_test_runner(project_root, package_data)
    config["test_runner"] = runner
    config["pytest_cmd"] = runner  # kept for backwards compatibility with existing code

    # Formatter: explicit override wins, otherwise auto-detect (empty list when none).
    if "formatterCmds" in user_config:
        config["formatter_cmds"] = user_config["formatterCmds"]
    else:
        config["formatter_cmds"] = detect_formatter(project_root, package_data) or []

    if user_config.get("benchmarksRoot"):
        config["benchmarks_root"] = str((project_root / Path(user_config["benchmarksRoot"])).resolve())

    ignore_paths = user_config.get("ignorePaths")
    config["ignore_paths"] = [str((project_root / p).resolve()) for p in ignore_paths] if ignore_paths else []

    config["disable_telemetry"] = user_config.get("disableTelemetry", False)
    config["git_remote"] = user_config.get("gitRemote", "origin")

    # Remaining defaults kept for backwards compatibility.
    config.setdefault("disable_imports_sorting", False)
    config.setdefault("override_fixtures", False)
    return config, package_json_path
def clear_cache() -> None:
    """Empty both module-level package.json caches (location and parsed data)."""
    for cache in (PACKAGE_JSON_CACHE, PACKAGE_JSON_DATA_CACHE):
        cache.clear()

View file

@ -5,10 +5,11 @@ from typing import Any
import tomlkit
from codeflash.code_utils.config_js import find_package_json, parse_package_json_config
from codeflash.lsp.helpers import is_LSP_enabled
PYPROJECT_TOML_CACHE = {}
ALL_CONFIG_FILES = {} # map path to closest config file
PYPROJECT_TOML_CACHE: dict[Path, Path] = {}
ALL_CONFIG_FILES: dict[Path, dict[str, Path]] = {}
def find_pyproject_toml(config_file: Path | None = None) -> Path:
@ -83,10 +84,27 @@ def find_conftest_files(test_paths: list[Path]) -> list[Path]:
return list(list_of_conftest_files)
# TODO for claude: There should be different functions to parse it per language, which should be chosen during runtime
def parse_config_file(
config_file_path: Path | None = None,
override_formatter_check: bool = False, # noqa: FBT001, FBT002
config_file_path: Path | None = None, override_formatter_check: bool = False
) -> tuple[dict[str, Any], Path]:
# First try package.json for JS/TS projects
package_json_path = find_package_json(config_file_path)
if package_json_path:
result = parse_package_json_config(package_json_path)
if result is not None:
config, path = result
# Validate formatter if needed
if not override_formatter_check and config.get("formatter_cmds"):
formatter_cmds = config.get("formatter_cmds", [])
if formatter_cmds and formatter_cmds[0] == "your-formatter $file":
raise ValueError(
"The formatter command is not set correctly in package.json. Please set the "
"formatter command in the 'formatterCmds' key."
)
return config, path
# Fall back to pyproject.toml
config_file_path = find_pyproject_toml(config_file_path)
try:
with config_file_path.open("rb") as f:

View file

@ -1,250 +1,129 @@
import ast
"""Code deduplication utilities using language-specific normalizers.
This module provides functions to normalize code, generate fingerprints,
and detect duplicate code segments across different programming languages.
"""
from __future__ import annotations
import hashlib
import re
from codeflash.code_utils.normalizers import get_normalizer
from codeflash.languages import current_language, is_python
class VariableNormalizer(ast.NodeTransformer):
    """Renames only local variables in an AST to canonical forms var_0, var_1, ...

    Function names, class names, parameters, built-ins, imported names, and
    global/nonlocal declarations are preserved, so two snippets differing only
    in local variable naming normalize to identical trees.
    """

    def __init__(self) -> None:
        import builtins  # the real builtins module; see the note below

        self.var_counter = 0
        self.var_mapping: dict[str, str] = {}
        self.scope_stack = []  # saved (mapping, counter, parameters) per enclosing scope
        # Bug fix: this previously used dir(__builtins__). In an *imported*
        # module, __builtins__ is a dict, so dir() returned dict method names
        # instead of builtin names (len, print, ...) and builtins got renamed.
        self.builtins = set(dir(builtins))
        self.imports: set[str] = set()
        self.global_vars: set[str] = set()
        self.nonlocal_vars: set[str] = set()
        self.parameters: set[str] = set()  # parameters visible in the current scope

    def enter_scope(self):  # noqa: ANN201
        """Snapshot renaming state before descending into a function/class body."""
        self.scope_stack.append(
            {"var_mapping": dict(self.var_mapping), "var_counter": self.var_counter, "parameters": set(self.parameters)}
        )

    def exit_scope(self):  # noqa: ANN201
        """Restore the parent scope's renaming state."""
        if self.scope_stack:
            scope = self.scope_stack.pop()
            self.var_mapping = scope["var_mapping"]
            self.var_counter = scope["var_counter"]
            self.parameters = scope["parameters"]

    def get_normalized_name(self, name: str) -> str:
        """Return the canonical name for a local; protected names pass through."""
        if (
            name in self.builtins
            or name in self.imports
            or name in self.global_vars
            or name in self.nonlocal_vars
            or name in self.parameters
        ):
            return name
        if name not in self.var_mapping:
            self.var_mapping[name] = f"var_{self.var_counter}"
            self.var_counter += 1
        return self.var_mapping[name]

    def visit_Import(self, node):  # noqa: ANN001, ANN201
        """Record imported top-level names so they are never renamed."""
        for alias in node.names:
            name = alias.asname if alias.asname else alias.name
            self.imports.add(name.split(".")[0])
        return node

    def visit_ImportFrom(self, node):  # noqa: ANN001, ANN201
        """Record names imported via ``from ... import ...``."""
        for alias in node.names:
            self.imports.add(alias.asname if alias.asname else alias.name)
        return node

    def visit_Global(self, node):  # noqa: ANN001, ANN201
        """Names declared global keep their original spelling."""
        self.global_vars.update(node.names)
        return node

    def visit_Nonlocal(self, node):  # noqa: ANN001, ANN201
        """Names declared nonlocal keep their original spelling."""
        self.nonlocal_vars.update(node.names)
        return node

    def visit_FunctionDef(self, node):  # noqa: ANN001, ANN201
        """Visit a function body in a fresh scope; name and parameters unchanged."""
        self.enter_scope()
        args = node.args
        # Bug fix: positional-only parameters were previously not registered,
        # so their uses in the body were renamed while the signature kept the
        # original name, producing broken normalized code.
        for arg in args.posonlyargs:
            self.parameters.add(arg.arg)
        for arg in args.args:
            self.parameters.add(arg.arg)
        if args.vararg:
            self.parameters.add(args.vararg.arg)
        if args.kwarg:
            self.parameters.add(args.kwarg.arg)
        for arg in args.kwonlyargs:
            self.parameters.add(arg.arg)
        node = self.generic_visit(node)
        self.exit_scope()
        return node

    def visit_AsyncFunctionDef(self, node):  # noqa: ANN001, ANN201
        """Async functions are handled identically to sync ones."""
        return self.visit_FunctionDef(node)

    def visit_ClassDef(self, node):  # noqa: ANN001, ANN201
        """Visit a class body in a fresh scope; the class name is unchanged."""
        self.enter_scope()
        node = self.generic_visit(node)
        self.exit_scope()
        return node

    def visit_Name(self, node):  # noqa: ANN001, ANN201
        """Apply the rename map to variable writes/deletes and reads."""
        if isinstance(node.ctx, (ast.Store, ast.Del)):
            # get_normalized_name leaves builtins/imports/params/globals intact.
            node.id = self.get_normalized_name(node.id)
        elif isinstance(node.ctx, ast.Load) and node.id in self.var_mapping:
            node.id = self.var_mapping[node.id]
        return node

    def visit_ExceptHandler(self, node):  # noqa: ANN001, ANN201
        """Normalize ``except ... as name`` binding names."""
        if node.name:
            node.name = self.get_normalized_name(node.name)
        return self.generic_visit(node)

    def visit_comprehension(self, node):  # noqa: ANN001, ANN201
        """Comprehension targets get a throwaway scope; restore state afterwards."""
        saved_mapping = dict(self.var_mapping)
        saved_counter = self.var_counter
        node = self.generic_visit(node)
        self.var_mapping = saved_mapping
        self.var_counter = saved_counter
        return node

    def visit_For(self, node):  # noqa: ANN001, ANN201
        """For-loop targets are ordinary locals; default traversal handles them."""
        return self.generic_visit(node)

    def visit_With(self, node):  # noqa: ANN001, ANN201
        """``with ... as name`` targets are ordinary locals."""
        return self.generic_visit(node)
def normalize_code(code: str, remove_docstrings: bool = True, return_ast_dump: bool = False) -> str: # noqa : FBT002, FBT001
"""Normalize Python code by parsing, cleaning, and normalizing only variable names.
def normalize_code(
code: str,
remove_docstrings: bool = True,
return_ast_dump: bool = False,
language: str | None = None,
) -> str:
"""Normalize code by parsing, cleaning, and normalizing variable names.
Function names, class names, and parameters are preserved.
Args:
code: Python source code as string
remove_docstrings: Whether to remove docstrings
return_ast_dump: return_ast_dump
code: Source code as string
remove_docstrings: Whether to remove docstrings (Python only)
return_ast_dump: Return AST dump instead of unparsed code (Python only)
language: Language of the code. If None, uses the current session language.
Returns:
Normalized code as string
"""
if language is None:
language = current_language().value
try:
# Parse the code
tree = ast.parse(code)
normalizer = get_normalizer(language)
# Remove docstrings if requested
if remove_docstrings:
remove_docstrings_from_ast(tree)
# Python has additional options
if is_python():
if return_ast_dump:
return normalizer.normalize_for_hash(code)
return normalizer.normalize(code, remove_docstrings=remove_docstrings)
# Normalize variable names
normalizer = VariableNormalizer()
normalized_tree = normalizer.visit(tree)
if return_ast_dump:
# This is faster than unparsing etc
return ast.dump(normalized_tree, annotate_fields=False, include_attributes=False)
# Fix missing locations in the AST
ast.fix_missing_locations(normalized_tree)
# Unparse back to code
return ast.unparse(normalized_tree)
except SyntaxError as e:
msg = f"Invalid Python syntax: {e}"
raise ValueError(msg) from e
# For other languages, use standard normalization
return normalizer.normalize(code)
except ValueError:
# Unknown language - fall back to basic normalization
return _basic_normalize(code)
except Exception:
# Parsing error - try other languages or fall back
if is_python():
# Try JavaScript as fallback
try:
js_normalizer = get_normalizer("javascript")
js_result = js_normalizer.normalize(code)
if js_result != _basic_normalize(code):
return js_result
except Exception:
pass
return _basic_normalize(code)
def remove_docstrings_from_ast(node):  # noqa: ANN001, ANN201
    """Strip docstrings, in place, from every function/class/module under ``node``.

    Walks only node kinds that can carry a docstring in body[0] (Module,
    ClassDef, FunctionDef, AsyncFunctionDef), using an explicit stack instead
    of ast.walk so unrelated subtrees are never traversed.
    """
    node_types = (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef, ast.Module)
    stack = [node]
    while stack:
        current_node = stack.pop()
        if isinstance(current_node, node_types):
            body = current_node.body
            if (
                body
                and isinstance(body[0], ast.Expr)
                and isinstance(body[0].value, ast.Constant)
                and isinstance(body[0].value.value, str)
            ):
                rest = body[1:]
                # Bug fix: a body consisting solely of a docstring must not
                # become empty — an empty body is an invalid AST that breaks
                # ast.unparse/compile downstream.
                current_node.body = rest if rest else [ast.Pass()]
            # Only these node kinds can nest further docstring carriers.
            stack.extend(child for child in body if isinstance(child, node_types))
def _basic_normalize(code: str) -> str:
"""Basic normalization: remove comments and normalize whitespace."""
# Remove single-line comments (// and #)
code = re.sub(r"//.*$", "", code, flags=re.MULTILINE)
code = re.sub(r"#.*$", "", code, flags=re.MULTILINE)
# Remove multi-line comments
code = re.sub(r"/\*.*?\*/", "", code, flags=re.DOTALL)
code = re.sub(r'""".*?"""', "", code, flags=re.DOTALL)
code = re.sub(r"'''.*?'''", "", code, flags=re.DOTALL)
# Normalize whitespace
return " ".join(code.split())
def get_code_fingerprint(code: str) -> str:
def get_code_fingerprint(code: str, language: str | None = None) -> str:
"""Generate a fingerprint for normalized code.
Args:
code: Python source code
code: Source code
language: Language of the code. If None, uses the current session language.
Returns:
SHA-256 hash of normalized code
"""
normalized = normalize_code(code)
return hashlib.sha256(normalized.encode()).hexdigest()
if language is None:
language = current_language().value
try:
normalizer = get_normalizer(language)
return normalizer.get_fingerprint(code)
except ValueError:
# Unknown language - use basic normalization
normalized = _basic_normalize(code)
return hashlib.sha256(normalized.encode()).hexdigest()
def are_codes_duplicate(code1: str, code2: str) -> bool:
def are_codes_duplicate(code1: str, code2: str, language: str | None = None) -> bool:
"""Check if two code segments are duplicates after normalization.
Args:
code1: First code segment
code2: Second code segment
language: Language of the code. If None, uses the current session language.
Returns:
True if codes are structurally identical (ignoring local variable names)
"""
if language is None:
language = current_language().value
try:
normalized1 = normalize_code(code1, return_ast_dump=True)
normalized2 = normalize_code(code2, return_ast_dump=True)
normalizer = get_normalizer(language)
return normalizer.are_duplicates(code1, code2)
except ValueError:
# Unknown language - use basic comparison
return _basic_normalize(code1) == _basic_normalize(code2)
except Exception:
return False
else:
return normalized1 == normalized2
# Re-export for backward compatibility
__all__ = ["are_codes_duplicate", "get_code_fingerprint", "normalize_code"]

View file

@ -12,6 +12,7 @@ from libcst.metadata import PositionProvider
from codeflash.cli_cmds.console import logger
from codeflash.code_utils.time_utils import format_perf, format_time
from codeflash.languages.registry import get_language_support
from codeflash.models.models import GeneratedTests, GeneratedTestsList
from codeflash.result.critic import performance_gain
@ -149,25 +150,85 @@ class CommentAdder(cst.CSTTransformer):
return updated_node
def _is_python_file(file_path: Path) -> bool:
"""Check if a file is a Python file."""
return file_path.suffix == ".py"
# TODO:{self} Needs cleanup for jest logic in else block
def unique_inv_id(inv_id_runtimes: dict[InvocationId, list[int]], tests_project_rootdir: Path) -> dict[str, int]:
unique_inv_ids: dict[str, int] = {}
logger.debug(f"[unique_inv_id] Processing {len(inv_id_runtimes)} invocation IDs")
for inv_id, runtimes in inv_id_runtimes.items():
test_qualified_name = (
inv_id.test_class_name + "." + inv_id.test_function_name # type: ignore[operator]
if inv_id.test_class_name
else inv_id.test_function_name
)
abs_path = tests_project_rootdir / Path(inv_id.test_module_path.replace(".", os.sep)).with_suffix(".py")
# Detect if test_module_path is a file path (like in js tests) or a Python module name
# File paths contain slashes, module names use dots
test_module_path = inv_id.test_module_path
if "/" in test_module_path or "\\" in test_module_path:
# Already a file path - use directly
abs_path = tests_project_rootdir / Path(test_module_path)
else:
# Check for Jest test file extensions (e.g., tests.fibonacci.test.ts)
# These need special handling to avoid converting .test.ts -> /test/ts
jest_test_extensions = (
".test.ts",
".test.js",
".test.tsx",
".test.jsx",
".spec.ts",
".spec.js",
".spec.tsx",
".spec.jsx",
".ts",
".js",
".tsx",
".jsx",
".mjs",
".mts",
)
matched_ext = None
for ext in jest_test_extensions:
if test_module_path.endswith(ext):
matched_ext = ext
break
if matched_ext:
# JavaScript/TypeScript: convert module-style path to file path
# "tests.fibonacci__perfonlyinstrumented.test.ts" -> "tests/fibonacci__perfonlyinstrumented.test.ts"
base_path = test_module_path[: -len(matched_ext)]
file_path = base_path.replace(".", os.sep) + matched_ext
# Check if the module path includes the tests directory name
tests_dir_name = tests_project_rootdir.name
if file_path.startswith((tests_dir_name + os.sep, tests_dir_name + "/")):
# Module path includes "tests." - use parent directory
abs_path = tests_project_rootdir.parent / Path(file_path)
else:
# Module path doesn't include tests dir - use tests root directly
abs_path = tests_project_rootdir / Path(file_path)
else:
# Python module name - convert dots to path separators and add .py
abs_path = tests_project_rootdir / Path(test_module_path.replace(".", os.sep)).with_suffix(".py")
abs_path_str = str(abs_path.resolve().with_suffix(""))
if "__unit_test_" not in abs_path_str or not test_qualified_name:
# Include both unit test and perf test paths for runtime annotations
# (performance test runtimes are used for annotations)
if ("__unit_test_" not in abs_path_str and "__perf_test_" not in abs_path_str) or not test_qualified_name:
logger.debug(f"[unique_inv_id] Skipping: path={abs_path_str}, test_qualified_name={test_qualified_name}")
continue
key = test_qualified_name + "#" + abs_path_str
parts = inv_id.iteration_id.split("_").__len__() # type: ignore[union-attr]
cur_invid = inv_id.iteration_id.split("_")[0] if parts < 3 else "_".join(inv_id.iteration_id.split("_")[:-1]) # type: ignore[union-attr]
match_key = key + "#" + cur_invid
logger.debug(f"[unique_inv_id] Adding key: {match_key} with runtime {min(runtimes)}")
if match_key not in unique_inv_ids:
unique_inv_ids[match_key] = 0
unique_inv_ids[match_key] += min(runtimes)
logger.debug(f"[unique_inv_id] Result has {len(unique_inv_ids)} entries")
return unique_inv_ids
@ -183,25 +244,46 @@ def add_runtime_comments_to_generated_tests(
# Process each generated test
modified_tests = []
for test in generated_tests.generated_tests:
try:
tree = cst.parse_module(test.generated_original_test_source)
wrapper = MetadataWrapper(tree)
line_to_comments = get_fn_call_linenos(test, original_runtimes_dict, optimized_runtimes_dict)
comment_adder = CommentAdder(line_to_comments)
modified_tree = wrapper.visit(comment_adder)
modified_source = modified_tree.code
modified_test = GeneratedTests(
generated_original_test_source=modified_source,
instrumented_behavior_test_source=test.instrumented_behavior_test_source,
instrumented_perf_test_source=test.instrumented_perf_test_source,
behavior_file_path=test.behavior_file_path,
perf_file_path=test.perf_file_path,
)
modified_tests.append(modified_test)
except Exception as e:
# If parsing fails, keep the original test
logger.debug(f"Failed to add runtime comments to test: {e}")
modified_tests.append(test)
is_python = _is_python_file(test.behavior_file_path)
if is_python:
# Use Python libcst-based comment insertion
try:
tree = cst.parse_module(test.generated_original_test_source)
wrapper = MetadataWrapper(tree)
line_to_comments = get_fn_call_linenos(test, original_runtimes_dict, optimized_runtimes_dict)
comment_adder = CommentAdder(line_to_comments)
modified_tree = wrapper.visit(comment_adder)
modified_source = modified_tree.code
modified_test = GeneratedTests(
generated_original_test_source=modified_source,
instrumented_behavior_test_source=test.instrumented_behavior_test_source,
instrumented_perf_test_source=test.instrumented_perf_test_source,
behavior_file_path=test.behavior_file_path,
perf_file_path=test.perf_file_path,
)
modified_tests.append(modified_test)
except Exception as e:
# If parsing fails, keep the original test
logger.debug(f"Failed to add runtime comments to test: {e}")
modified_tests.append(test)
else:
try:
language_support = get_language_support(test.behavior_file_path)
modified_source = language_support.add_runtime_comments(
test.generated_original_test_source, original_runtimes_dict, optimized_runtimes_dict
)
modified_test = GeneratedTests(
generated_original_test_source=modified_source,
instrumented_behavior_test_source=test.instrumented_behavior_test_source,
instrumented_perf_test_source=test.instrumented_perf_test_source,
behavior_file_path=test.behavior_file_path,
perf_file_path=test.perf_file_path,
)
modified_tests.append(modified_test)
except Exception as e:
logger.debug(f"Failed to add runtime comments to test: {e}")
modified_tests.append(test)
return GeneratedTestsList(generated_tests=modified_tests)
@ -247,3 +329,103 @@ def _compile_function_patterns(test_functions_to_remove: list[str]) -> list[re.P
)
for func in test_functions_to_remove
]
# Patterns matching legacy local-file codeflash helper imports (CJS and ESM).
_CODEFLASH_REQUIRE_PATTERN = re.compile(
    r"(const|let|var)\s+(\w+)\s*=\s*require\s*\(\s*['\"]\.?/?codeflash-jest-helper['\"]\s*\)"
)
_CODEFLASH_IMPORT_PATTERN = re.compile(r"import\s+(?:\*\s+as\s+)?(\w+)\s+from\s+['\"]\.?/?codeflash-jest-helper['\"]")


def normalize_codeflash_imports(source: str) -> str:
    """Rewrite legacy './codeflash-jest-helper' imports to the npm 'codeflash' package.

    Handles both forms, preserving the declared binding name:
        const codeflash = require('./codeflash-jest-helper')
        import codeflash from './codeflash-jest-helper'

    Args:
        source: JavaScript/TypeScript source code.

    Returns:
        Source code with normalized imports.
    """
    rewritten = _CODEFLASH_REQUIRE_PATTERN.sub(r"\1 \2 = require('codeflash')", source)
    return _CODEFLASH_IMPORT_PATTERN.sub(r"import \1 from 'codeflash'", rewritten)
def inject_test_globals(generated_tests: GeneratedTestsList) -> GeneratedTestsList:
    # TODO: inside the prompt tell the llm if it should import jest functions or it's already injected in the global window
    """Prepend the @jest/globals import to every generated test source.

    Only needed for ESM modules. Mutates the test objects in place and
    returns the same list object.

    Args:
        generated_tests: List of generated tests.

    Returns:
        The same GeneratedTestsList with the import header prepended to all
        three source variants of each test.
    """
    header = (
        "import { jest, describe, it, expect, beforeEach, afterEach, beforeAll, test } from '@jest/globals'\n"
    )
    for test in generated_tests.generated_tests:
        test.generated_original_test_source = header + test.generated_original_test_source
        test.instrumented_behavior_test_source = header + test.instrumented_behavior_test_source
        test.instrumented_perf_test_source = header + test.instrumented_perf_test_source
    return generated_tests
def disable_ts_check(generated_tests: GeneratedTestsList) -> GeneratedTestsList:
    """Prepend "// @ts-nocheck" to every generated test source.

    Disables TypeScript type checking for the generated files. Mutates the
    test objects in place and returns the same list object.

    Args:
        generated_tests: List of generated tests.

    Returns:
        The same GeneratedTestsList with the directive prepended to all three
        source variants of each test.
    """
    directive = "// @ts-nocheck\n"
    for test in generated_tests.generated_tests:
        test.generated_original_test_source = directive + test.generated_original_test_source
        test.instrumented_behavior_test_source = directive + test.instrumented_behavior_test_source
        test.instrumented_perf_test_source = directive + test.instrumented_perf_test_source
    return generated_tests
def normalize_generated_tests_imports(generated_tests: GeneratedTestsList) -> GeneratedTestsList:
    """Normalize codeflash imports in every JS/TS generated test.

    Non-JS/TS tests pass through unchanged; JS/TS tests are rebuilt with
    normalize_codeflash_imports applied to all three source variants.

    Args:
        generated_tests: List of generated tests.

    Returns:
        A new GeneratedTestsList with normalized imports.
    """
    js_suffixes = {".js", ".ts", ".jsx", ".tsx", ".mjs", ".mts"}
    updated = []
    for test in generated_tests.generated_tests:
        if test.behavior_file_path.suffix not in js_suffixes:
            updated.append(test)
            continue
        updated.append(
            GeneratedTests(
                generated_original_test_source=normalize_codeflash_imports(test.generated_original_test_source),
                instrumented_behavior_test_source=normalize_codeflash_imports(test.instrumented_behavior_test_source),
                instrumented_perf_test_source=normalize_codeflash_imports(test.instrumented_perf_test_source),
                behavior_file_path=test.behavior_file_path,
                perf_file_path=test.perf_file_path,
            )
        )
    return GeneratedTestsList(generated_tests=updated)

View file

@ -16,7 +16,7 @@ from codeflash.code_utils.shell_utils import read_api_key_from_shell_config, sav
from codeflash.lsp.helpers import is_LSP_enabled
def check_formatter_installed(formatter_cmds: list[str], exit_on_failure: bool = True) -> bool: # noqa
def check_formatter_installed(formatter_cmds: list[str], exit_on_failure: bool = True) -> bool:
if not formatter_cmds or formatter_cmds[0] == "disabled":
return True
first_cmd = formatter_cmds[0]
@ -155,7 +155,8 @@ def get_cached_gh_event_data() -> dict[str, Any]:
if not event_path:
return {}
with open(event_path, encoding="utf-8") as f: # noqa: PTH123
return json.load(f) # type: ignore # noqa
result: dict[str, Any] = json.load(f)
return result
def is_repo_a_fork() -> bool:

View file

@ -40,11 +40,7 @@ def generate_unified_diff(original: str, modified: str, from_file: str, to_file:
def apply_formatter_cmds(
cmds: list[str],
path: Path,
test_dir_str: Optional[str],
print_status: bool, # noqa
exit_on_failure: bool = True, # noqa
cmds: list[str], path: Path, test_dir_str: Optional[str], print_status: bool, exit_on_failure: bool = True
) -> tuple[Path, str, bool]:
if not path.exists():
msg = f"File {path} does not exist. Cannot apply formatter commands."
@ -111,9 +107,9 @@ def format_code(
formatter_cmds: list[str],
path: Union[str, Path],
optimized_code: str = "",
check_diff: bool = False, # noqa
print_status: bool = True, # noqa
exit_on_failure: bool = True, # noqa
check_diff: bool = False,
print_status: bool = True,
exit_on_failure: bool = True,
) -> str:
if is_LSP_enabled():
exit_on_failure = False
@ -174,7 +170,7 @@ def format_code(
return formatted_code
def sort_imports(code: str, **kwargs: Any) -> str: # noqa : ANN401
def sort_imports(code: str, **kwargs: Any) -> str:
try:
# Deduplicate and sort imports, modify the code in memory, not on disk
sorted_code = isort.code(code, **kwargs)

View file

@ -89,7 +89,7 @@ class InjectPerfOnly(ast.NodeTransformer):
# it's much more efficient to visit nodes manually. We'll only descend into expressions/statements.
# Helper for manual walk
def iter_ast_calls(node): # noqa: ANN202, ANN001
def iter_ast_calls(node): # noqa: ANN202
# Generator to yield each ast.Call in test_node, preserves node identity
stack = [node]
while stack:
@ -690,15 +690,14 @@ def detect_frameworks_from_code(code: str) -> dict[str, str]:
frameworks["tensorflow"] = alias.asname if alias.asname else module_name
elif module_name == "jax":
frameworks["jax"] = alias.asname if alias.asname else module_name
elif isinstance(node, ast.ImportFrom): # noqa: SIM102
if node.module:
module_name = node.module.split(".")[0]
if module_name == "torch" and "torch" not in frameworks:
frameworks["torch"] = module_name
elif module_name == "tensorflow" and "tensorflow" not in frameworks:
frameworks["tensorflow"] = module_name
elif module_name == "jax" and "jax" not in frameworks:
frameworks["jax"] = module_name
elif isinstance(node, ast.ImportFrom) and node.module:
module_name = node.module.split(".")[0]
if module_name == "torch" and "torch" not in frameworks:
frameworks["torch"] = module_name
elif module_name == "tensorflow" and "tensorflow" not in frameworks:
frameworks["tensorflow"] = module_name
elif module_name == "jax" and "jax" not in frameworks:
frameworks["jax"] = module_name
return frameworks
@ -910,8 +909,7 @@ def _create_device_sync_precompute_statements(used_frameworks: dict[str, str] |
def _create_device_sync_statements(
used_frameworks: dict[str, str] | None,
for_return_value: bool = False, # noqa: FBT001, FBT002
used_frameworks: dict[str, str] | None, for_return_value: bool = False
) -> list[ast.stmt]:
"""Create AST statements for device synchronization using pre-computed conditions.
@ -1450,7 +1448,7 @@ class AsyncDecoratorAdder(cst.CSTTransformer):
# Track when we enter a class
self.context_stack.append(node.name.value)
def leave_ClassDef(self, original_node: cst.ClassDef, updated_node: cst.ClassDef) -> cst.ClassDef: # noqa: ARG002
def leave_ClassDef(self, original_node: cst.ClassDef, updated_node: cst.ClassDef) -> cst.ClassDef:
# Pop the context when we leave a class
self.context_stack.pop()
return updated_node
@ -1530,7 +1528,7 @@ class AsyncDecoratorImportAdder(cst.CSTTransformer):
if import_alias.name.value == decorator_name:
self.has_import = True
def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module: # noqa: ARG002
def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module:
# If the import is already there, don't add it again
if self.has_import:
return updated_node

View file

@ -204,7 +204,7 @@ class LineProfilerDecoratorAdder(cst.CSTTransformer):
# Track when we enter a class
self.context_stack.append(node.name.value)
def leave_ClassDef(self, original_node: cst.ClassDef, updated_node: cst.ClassDef) -> cst.ClassDef: # noqa: ARG002
def leave_ClassDef(self, original_node: cst.ClassDef, updated_node: cst.ClassDef) -> cst.ClassDef:
# Pop the context when we leave a class
self.context_stack.pop()
return updated_node
@ -268,7 +268,7 @@ class ProfileEnableTransformer(cst.CSTTransformer):
return updated_node
def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module: # noqa: ARG002
def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module:
if not self.found_import:
return updated_node
@ -332,11 +332,11 @@ def add_profile_enable(original_code: str, line_profile_output_file: str) -> str
class ImportAdder(cst.CSTTransformer):
def __init__(self, import_statement) -> None: # noqa: ANN001
def __init__(self, import_statement) -> None:
self.import_statement = import_statement
self.has_import = False
def leave_Module(self, original_node, updated_node): # noqa: ANN001, ANN201, ARG002
def leave_Module(self, original_node, updated_node): # noqa: ANN201
# If the import is already there, don't add it again
if self.has_import:
return updated_node
@ -347,7 +347,7 @@ class ImportAdder(cst.CSTTransformer):
# Add the import to the module's body
return updated_node.with_changes(body=[import_node, *list(updated_node.body)])
def visit_ImportFrom(self, node) -> None: # noqa: ANN001
def visit_ImportFrom(self, node) -> None:
# Check if the profile is already imported from line_profiler
if node.module and node.module.value == "line_profiler":
for import_alias in node.names:

View file

@ -0,0 +1,106 @@
"""Code normalizers for different programming languages.
This module provides language-specific code normalizers that transform source code
into canonical forms for duplicate detection. The normalizers:
- Replace local variable names with canonical forms (var_0, var_1, etc.)
- Preserve function names, class names, parameters, and imports
- Remove or normalize comments and docstrings
- Produce consistent output for structurally identical code
Usage:
>>> normalizer = get_normalizer("python")
>>> normalized = normalizer.normalize(code)
>>> fingerprint = normalizer.get_fingerprint(code)
>>> are_same = normalizer.are_duplicates(code1, code2)
"""
from __future__ import annotations
from codeflash.code_utils.normalizers.base import CodeNormalizer
from codeflash.code_utils.normalizers.javascript import JavaScriptNormalizer, TypeScriptNormalizer
from codeflash.code_utils.normalizers.python import PythonNormalizer
__all__ = [
"CodeNormalizer",
"JavaScriptNormalizer",
"PythonNormalizer",
"TypeScriptNormalizer",
"get_normalizer",
"get_normalizer_for_extension",
]
# Registry of normalizer classes, keyed by lowercased language name.
# Extend at runtime via register_normalizer().
_NORMALIZERS: dict[str, type[CodeNormalizer]] = {
    "python": PythonNormalizer,
    "javascript": JavaScriptNormalizer,
    "typescript": TypeScriptNormalizer,
}
# Singleton cache of instantiated normalizers, populated lazily by
# get_normalizer(); register_normalizer() evicts the stale entry.
_normalizer_instances: dict[str, CodeNormalizer] = {}
def get_normalizer(language: str) -> CodeNormalizer:
    """Return the (cached) code normalizer for ``language``.

    Args:
        language: Language name ('python', 'javascript', 'typescript')

    Returns:
        CodeNormalizer instance for the language

    Raises:
        ValueError: If no normalizer exists for the language

    """
    key = language.lower()
    # Reuse an existing instance when one has already been built.
    cached = _normalizer_instances.get(key)
    if cached is not None:
        return cached
    normalizer_cls = _NORMALIZERS.get(key)
    if normalizer_cls is None:
        supported = ", ".join(sorted(_NORMALIZERS.keys()))
        msg = f"No normalizer available for language '{key}'. Supported: {supported}"
        raise ValueError(msg)
    # Instantiate once and memoize for subsequent lookups.
    instance = normalizer_cls()
    _normalizer_instances[key] = instance
    return instance
def get_normalizer_for_extension(extension: str) -> CodeNormalizer | None:
    """Look up a code normalizer by file extension.

    Args:
        extension: File extension including dot (e.g., '.py', '.js')

    Returns:
        CodeNormalizer instance if found, None otherwise

    """
    ext = extension.lower()
    # Accept both ".py" and "py" spellings.
    if not ext.startswith("."):
        ext = f".{ext}"
    for lang in _NORMALIZERS:
        candidate = get_normalizer(lang)
        if ext in candidate.supported_extensions:
            return candidate
    return None
def register_normalizer(language: str, normalizer_class: type[CodeNormalizer]) -> None:
    """Register (or replace) the normalizer class used for ``language``.

    Args:
        language: Language name
        normalizer_class: CodeNormalizer subclass

    """
    key = language.lower()
    _NORMALIZERS[key] = normalizer_class
    # Evict any cached instance of the previously registered class so the
    # next get_normalizer() call builds the new one.
    _normalizer_instances.pop(key, None)

View file

@ -0,0 +1,104 @@
"""Abstract base class for code normalizers.
Code normalizers transform source code into a canonical form for duplicate detection.
They normalize variable names, remove comments/docstrings, and produce consistent output
that can be compared across different implementations of the same algorithm.
"""
# TODO:{claude} move to base.py in language folder
from __future__ import annotations
from abc import ABC, abstractmethod
class CodeNormalizer(ABC):
    """Abstract base class for language-specific code normalizers.

    Subclasses implement normalize()/normalize_for_hash() for a concrete
    language. A conforming implementation:
    - maps local variable names to canonical forms (var_0, var_1, ...)
    - keeps function names, class names, parameters, and imports intact
    - drops or normalizes comments and docstrings
    - yields identical output for structurally identical code

    Example:
        >>> normalizer = PythonNormalizer()
        >>> code1 = "def foo(x): y = x + 1; return y"
        >>> code2 = "def foo(x): z = x + 1; return z"
        >>> normalizer.normalize(code1) == normalizer.normalize(code2)
        True

    """

    @property
    @abstractmethod
    def language(self) -> str:
        """Return the language this normalizer handles."""
        ...

    @property
    def supported_extensions(self) -> tuple[str, ...]:
        """Return file extensions this normalizer can handle (empty by default)."""
        return ()

    @abstractmethod
    def normalize(self, code: str) -> str:
        """Normalize code to a canonical form for comparison.

        Args:
            code: Source code to normalize

        Returns:
            Normalized representation of the code

        """
        ...

    @abstractmethod
    def normalize_for_hash(self, code: str) -> str:
        """Normalize code optimized for hashing/fingerprinting.

        This may return a more compact representation than normalize().

        Args:
            code: Source code to normalize

        Returns:
            Normalized representation suitable for hashing

        """
        ...

    def are_duplicates(self, code1: str, code2: str) -> bool:
        """Report whether two code segments normalize to the same form.

        Any failure during normalization (e.g. unparseable input) is treated
        as "not duplicates" rather than propagated to the caller.

        Args:
            code1: First code segment
            code2: Second code segment

        Returns:
            True if codes are structurally identical

        """
        try:
            return self.normalize_for_hash(code1) == self.normalize_for_hash(code2)
        except Exception:
            return False

    def get_fingerprint(self, code: str) -> str:
        """Return the SHA-256 hex digest of the normalized code.

        Args:
            code: Source code to fingerprint

        Returns:
            SHA-256 hash of normalized code

        """
        import hashlib

        return hashlib.sha256(self.normalize_for_hash(code).encode()).hexdigest()

View file

@ -0,0 +1,290 @@
"""JavaScript/TypeScript code normalizer using tree-sitter."""
from __future__ import annotations
import re
from typing import TYPE_CHECKING
from codeflash.code_utils.normalizers.base import CodeNormalizer
if TYPE_CHECKING:
from tree_sitter import Node
# TODO:{claude} move to language support directory to keep the directory structure clean
class JavaScriptVariableNormalizer:
    """Normalizes JavaScript/TypeScript code for duplicate detection using tree-sitter.

    Normalizes local variable names while preserving function names, class names,
    parameters, and imported names.

    Usage is a two-pass protocol: call collect_preserved_names() over the
    parsed tree first, then normalize_tree() to produce the canonical string.
    """

    def __init__(self) -> None:
        self.var_counter = 0  # next numeric suffix for var_N names
        self.var_mapping: dict[str, str] = {}  # original identifier -> var_N
        self.preserved_names: set[str] = set()  # names that must never be renamed
        # Common JavaScript builtins
        self.builtins = {
            "console",
            "window",
            "document",
            "Math",
            "JSON",
            "Object",
            "Array",
            "String",
            "Number",
            "Boolean",
            "Date",
            "RegExp",
            "Error",
            "Promise",
            "Map",
            "Set",
            "WeakMap",
            "WeakSet",
            "Symbol",
            "Proxy",
            "Reflect",
            "undefined",
            "null",
            "NaN",
            "Infinity",
            "globalThis",
            "parseInt",
            "parseFloat",
            "isNaN",
            "isFinite",
            "eval",
            "setTimeout",
            "setInterval",
            "clearTimeout",
            "clearInterval",
            "fetch",
            "require",
            "module",
            "exports",
            "process",
            "__dirname",
            "__filename",
            "Buffer",
        }

    def get_normalized_name(self, name: str) -> str:
        """Get or create normalized name for a variable.

        Builtins and collected preserved names pass through unchanged; every
        other identifier is mapped (stably, first-come-first-numbered) to var_N.
        """
        if name in self.builtins or name in self.preserved_names:
            return name
        if name not in self.var_mapping:
            self.var_mapping[name] = f"var_{self.var_counter}"
            self.var_counter += 1
        return self.var_mapping[name]

    def collect_preserved_names(self, node: Node, source_code: bytes) -> None:
        """Collect names that should be preserved (function names, class names, imports, params).

        Recurses over the entire tree; node.type strings follow the
        tree-sitter JS/TS grammars.
        """
        # Function declarations and expressions - preserve the function name
        if node.type in ("function_declaration", "function_expression", "method_definition", "arrow_function"):
            name_node = node.child_by_field_name("name")
            if name_node:
                self.preserved_names.add(source_code[name_node.start_byte : name_node.end_byte].decode("utf-8"))
            # Preserve parameters
            params_node = node.child_by_field_name("parameters") or node.child_by_field_name("parameter")
            if params_node:
                self._collect_parameter_names(params_node, source_code)
        # Class declarations
        elif node.type == "class_declaration":
            name_node = node.child_by_field_name("name")
            if name_node:
                self.preserved_names.add(source_code[name_node.start_byte : name_node.end_byte].decode("utf-8"))
        # Import declarations
        elif node.type in ("import_statement", "import_declaration"):
            for child in node.children:
                if child.type == "import_clause":
                    self._collect_import_names(child, source_code)
                elif child.type == "identifier":
                    self.preserved_names.add(source_code[child.start_byte : child.end_byte].decode("utf-8"))
        # Recurse
        for child in node.children:
            self.collect_preserved_names(child, source_code)

    def _collect_parameter_names(self, node: Node, source_code: bytes) -> None:
        """Collect parameter names from a parameters node.

        Handles plain identifiers plus TS required/optional/rest parameters;
        recurses for nested destructuring patterns.
        """
        for child in node.children:
            if child.type == "identifier":
                self.preserved_names.add(source_code[child.start_byte : child.end_byte].decode("utf-8"))
            elif child.type in ("required_parameter", "optional_parameter", "rest_parameter"):
                pattern_node = child.child_by_field_name("pattern")
                if pattern_node and pattern_node.type == "identifier":
                    self.preserved_names.add(
                        source_code[pattern_node.start_byte : pattern_node.end_byte].decode("utf-8")
                    )
            # Recurse for nested patterns
            self._collect_parameter_names(child, source_code)

    def _collect_import_names(self, node: Node, source_code: bytes) -> None:
        """Collect imported names from import clause.

        For `import {a as b}` the local alias takes precedence over the
        original name. NOTE(review): the unconditional recursion below also
        descends into import_specifier children, which may additionally add
        the original name when an alias exists — harmless for preservation,
        but confirm against the tree-sitter grammar if tightening this.
        """
        for child in node.children:
            if child.type == "identifier":
                self.preserved_names.add(source_code[child.start_byte : child.end_byte].decode("utf-8"))
            elif child.type == "import_specifier":
                # Get the local name (alias or original)
                alias_node = child.child_by_field_name("alias")
                name_node = child.child_by_field_name("name")
                if alias_node:
                    self.preserved_names.add(source_code[alias_node.start_byte : alias_node.end_byte].decode("utf-8"))
                elif name_node:
                    self.preserved_names.add(source_code[name_node.start_byte : name_node.end_byte].decode("utf-8"))
            self._collect_import_names(child, source_code)

    def normalize_tree(self, node: Node, source_code: bytes) -> str:
        """Normalize the AST tree to a string representation for comparison."""
        parts: list[str] = []
        self._normalize_node(node, source_code, parts)
        return " ".join(parts)

    def _normalize_node(self, node: Node, source_code: bytes, parts: list[str]) -> None:
        """Recursively normalize a node, appending tokens to ``parts``.

        Output shape: "(node_type child child ... )" for interior nodes,
        literal/normalized text for leaves.
        """
        # Skip comments
        if node.type in ("comment", "line_comment", "block_comment"):
            return
        # Handle identifiers - normalize variable names
        if node.type == "identifier":
            name = source_code[node.start_byte : node.end_byte].decode("utf-8")
            normalized = self.get_normalized_name(name)
            parts.append(normalized)
            return
        # Handle type identifiers (TypeScript) - preserve as-is
        if node.type == "type_identifier":
            parts.append(source_code[node.start_byte : node.end_byte].decode("utf-8"))
            return
        # Handle string literals - normalize to placeholder
        if node.type in ("string", "template_string", "string_fragment"):
            parts.append('"STR"')
            return
        # Handle number literals - normalize to placeholder
        if node.type == "number":
            parts.append("NUM")
            return
        # For leaf nodes, output the node type
        if len(node.children) == 0:
            text = source_code[node.start_byte : node.end_byte].decode("utf-8")
            parts.append(text)
            return
        # Output node type for structure
        parts.append(f"({node.type}")
        # Recurse into children
        for child in node.children:
            self._normalize_node(child, source_code, parts)
        parts.append(")")
def _basic_normalize(code: str) -> str:
"""Basic normalization: remove comments and normalize whitespace."""
# Remove single-line comments
code = re.sub(r"//.*$", "", code, flags=re.MULTILINE)
# Remove multi-line comments
code = re.sub(r"/\*.*?\*/", "", code, flags=re.DOTALL)
# Normalize whitespace
return " ".join(code.split())
class JavaScriptNormalizer(CodeNormalizer):
    """Tree-sitter based normalizer for JavaScript source.

    Produces a canonical string in which local variable names become
    var_0, var_1, ... while function names, class names, parameters, and
    imported names survive. Comments are dropped and string/number literals
    collapse to placeholders. Any parse failure falls back to the
    regex-based _basic_normalize().
    """

    @property
    def language(self) -> str:
        """Language identifier handled by this normalizer."""
        return "javascript"

    @property
    def supported_extensions(self) -> tuple[str, ...]:
        """File extensions this normalizer accepts."""
        return (".js", ".jsx", ".mjs", ".cjs")

    def _get_tree_sitter_language(self) -> str:
        """Tree-sitter grammar name used for parsing."""
        return "javascript"

    def normalize(self, code: str) -> str:
        """Normalize JavaScript code to its canonical comparison form.

        Args:
            code: JavaScript source code to normalize

        Returns:
            Normalized representation of the code

        """
        try:
            from codeflash.languages.treesitter_utils import TreeSitterAnalyzer, TreeSitterLanguage

            grammar_by_name = {"javascript": TreeSitterLanguage.JAVASCRIPT, "typescript": TreeSitterLanguage.TYPESCRIPT}
            grammar = grammar_by_name.get(self._get_tree_sitter_language(), TreeSitterLanguage.JAVASCRIPT)
            tree = TreeSitterAnalyzer(grammar).parse(code)
            if tree.root_node.has_error:
                return _basic_normalize(code)
            source_bytes = code.encode("utf-8")
            variable_normalizer = JavaScriptVariableNormalizer()
            # Pass 1: record names that must survive; pass 2: emit canonical text.
            variable_normalizer.collect_preserved_names(tree.root_node, source_bytes)
            return variable_normalizer.normalize_tree(tree.root_node, source_bytes)
        except Exception:
            return _basic_normalize(code)

    def normalize_for_hash(self, code: str) -> str:
        """Normalize JavaScript code for hashing (same form as normalize()).

        Args:
            code: JavaScript source code to normalize

        Returns:
            Normalized representation suitable for hashing

        """
        return self.normalize(code)
class TypeScriptNormalizer(JavaScriptNormalizer):
    """TypeScript code normalizer using tree-sitter.

    Inherits the full normalization pipeline from JavaScriptNormalizer;
    only the grammar name and recognized file extensions differ.
    """

    @property
    def language(self) -> str:
        """Return the language this normalizer handles."""
        return "typescript"

    @property
    def supported_extensions(self) -> tuple[str, ...]:
        """Return file extensions this normalizer can handle."""
        return (".ts", ".tsx", ".mts", ".cts")

    def _get_tree_sitter_language(self) -> str:
        """Get the tree-sitter language identifier (selects the TS grammar)."""
        return "typescript"

View file

@ -0,0 +1,226 @@
"""Python code normalizer using AST transformation."""
from __future__ import annotations
import ast
from codeflash.code_utils.normalizers.base import CodeNormalizer
class VariableNormalizer(ast.NodeTransformer):
    """Normalizes only local variable names in AST to canonical forms like var_0, var_1, etc.

    Preserves function names, class names, parameters, built-ins, and imported names.
    """

    def __init__(self) -> None:
        # Import the real `builtins` module: in an imported module the global
        # `__builtins__` may be a plain dict, and dir() on a dict lists dict
        # methods instead of builtin names — which would let names like `len`
        # or `print` get renamed to var_N.
        import builtins

        self.var_counter = 0  # next numeric suffix for canonical names
        self.var_mapping: dict[str, str] = {}  # original name -> var_N
        self.scope_stack: list[dict] = []  # saved state of enclosing scopes
        self.builtins = set(dir(builtins))
        self.imports: set[str] = set()  # names bound by import statements
        self.global_vars: set[str] = set()  # names declared `global`
        self.nonlocal_vars: set[str] = set()  # names declared `nonlocal`
        self.parameters: set[str] = set()  # parameter names of the current scope

    def enter_scope(self) -> None:
        """Enter a new scope (function/class), snapshotting the current state."""
        self.scope_stack.append(
            {"var_mapping": dict(self.var_mapping), "var_counter": self.var_counter, "parameters": set(self.parameters)}
        )

    def exit_scope(self) -> None:
        """Exit current scope and restore parent scope."""
        if self.scope_stack:
            scope = self.scope_stack.pop()
            self.var_mapping = scope["var_mapping"]
            self.var_counter = scope["var_counter"]
            self.parameters = scope["parameters"]

    def get_normalized_name(self, name: str) -> str:
        """Get or create normalized name for a variable.

        Builtins, imports, globals, nonlocals, and parameters pass through
        unchanged; everything else maps stably to var_N.
        """
        if (
            name in self.builtins
            or name in self.imports
            or name in self.global_vars
            or name in self.nonlocal_vars
            or name in self.parameters
        ):
            return name
        if name not in self.var_mapping:
            self.var_mapping[name] = f"var_{self.var_counter}"
            self.var_counter += 1
        return self.var_mapping[name]

    def visit_Import(self, node: ast.Import) -> ast.Import:
        """Track imported names (top-level package name for dotted imports)."""
        for alias in node.names:
            name = alias.asname if alias.asname else alias.name
            self.imports.add(name.split(".")[0])
        return node

    def visit_ImportFrom(self, node: ast.ImportFrom) -> ast.ImportFrom:
        """Track imported names from modules."""
        for alias in node.names:
            name = alias.asname if alias.asname else alias.name
            self.imports.add(name)
        return node

    def visit_Global(self, node: ast.Global) -> ast.Global:
        """Track global variable declarations."""
        self.global_vars.update(node.names)
        return node

    def visit_Nonlocal(self, node: ast.Nonlocal) -> ast.Nonlocal:
        """Track nonlocal variable declarations."""
        self.nonlocal_vars.update(node.names)
        return node

    def visit_FunctionDef(self, node: ast.FunctionDef) -> ast.FunctionDef:
        """Process function but keep function name and parameters unchanged."""
        self.enter_scope()
        # Register every parameter kind — including positional-only ones,
        # which the original implementation missed — so none get renamed.
        for arg in node.args.posonlyargs:
            self.parameters.add(arg.arg)
        for arg in node.args.args:
            self.parameters.add(arg.arg)
        if node.args.vararg:
            self.parameters.add(node.args.vararg.arg)
        if node.args.kwarg:
            self.parameters.add(node.args.kwarg.arg)
        for arg in node.args.kwonlyargs:
            self.parameters.add(arg.arg)
        node = self.generic_visit(node)
        self.exit_scope()
        return node

    def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> ast.AsyncFunctionDef:
        """Handle async functions same as regular functions."""
        return self.visit_FunctionDef(node)  # type: ignore[return-value]

    def visit_ClassDef(self, node: ast.ClassDef) -> ast.ClassDef:
        """Process class but keep class name unchanged."""
        self.enter_scope()
        node = self.generic_visit(node)
        self.exit_scope()
        return node

    def visit_Name(self, node: ast.Name) -> ast.Name:
        """Normalize variable names in Name nodes."""
        if isinstance(node.ctx, (ast.Store, ast.Del)):
            if (
                node.id not in self.builtins
                and node.id not in self.imports
                and node.id not in self.parameters
                and node.id not in self.global_vars
                and node.id not in self.nonlocal_vars
            ):
                node.id = self.get_normalized_name(node.id)
        elif isinstance(node.ctx, ast.Load) and node.id in self.var_mapping:
            # Only rewrite loads of names that were already renamed by a store.
            node.id = self.var_mapping[node.id]
        return node

    def visit_ExceptHandler(self, node: ast.ExceptHandler) -> ast.ExceptHandler:
        """Normalize exception variable names (`except E as name`)."""
        if node.name:
            node.name = self.get_normalized_name(node.name)
        return self.generic_visit(node)

    def visit_comprehension(self, node: ast.comprehension) -> ast.comprehension:
        """Normalize comprehension target variables in an isolated mapping."""
        old_mapping = dict(self.var_mapping)
        old_counter = self.var_counter
        node = self.generic_visit(node)
        self.var_mapping = old_mapping
        self.var_counter = old_counter
        return node

    def visit_For(self, node: ast.For) -> ast.For:
        """Handle for loop target variables via the default traversal."""
        return self.generic_visit(node)

    def visit_With(self, node: ast.With) -> ast.With:
        """Handle with statement `as` variables via the default traversal."""
        return self.generic_visit(node)
def _remove_docstrings_from_ast(node: ast.AST) -> None:
"""Remove docstrings from AST nodes."""
node_types = (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef, ast.Module)
stack = [node]
while stack:
current_node = stack.pop()
if isinstance(current_node, node_types):
body = current_node.body
if (
body
and isinstance(body[0], ast.Expr)
and isinstance(body[0].value, ast.Constant)
and isinstance(body[0].value.value, str)
):
current_node.body = body[1:]
stack.extend([child for child in body if isinstance(child, node_types)])
class PythonNormalizer(CodeNormalizer):
    """AST-based normalizer for Python source.

    Rewrites local variable names to canonical forms (var_0, var_1, ...)
    while keeping function names, class names, parameters, and imports
    intact; docstrings can optionally be stripped.
    """

    @property
    def language(self) -> str:
        """Language identifier handled by this normalizer."""
        return "python"

    @property
    def supported_extensions(self) -> tuple[str, ...]:
        """File extensions this normalizer accepts."""
        return (".py", ".pyw", ".pyi")

    def normalize(self, code: str, remove_docstrings: bool = True) -> str:
        """Return ``code`` re-rendered in its canonical source form.

        Args:
            code: Python source code to normalize
            remove_docstrings: Whether to remove docstrings

        Returns:
            Normalized Python code as a string

        """
        tree = ast.parse(code)
        if remove_docstrings:
            _remove_docstrings_from_ast(tree)
        rewritten = VariableNormalizer().visit(tree)
        ast.fix_missing_locations(rewritten)
        return ast.unparse(rewritten)

    def normalize_for_hash(self, code: str) -> str:
        """Return a compact AST dump of the normalized code for hashing.

        Dumping the AST avoids the cost of unparsing back to source text.

        Args:
            code: Python source code to normalize

        Returns:
            AST dump string suitable for hashing

        """
        tree = ast.parse(code)
        _remove_docstrings_from_ast(tree)
        rewritten = VariableNormalizer().visit(tree)
        return ast.dump(rewritten, annotate_fields=False, include_attributes=False)

View file

@ -702,7 +702,7 @@ def _wait_for_manual_code_input(oauth: OAuthHandler) -> None:
if not oauth.is_complete:
oauth.manual_code = code.strip()
oauth.is_complete = True
except Exception: # noqa: S110
except Exception:
pass

View file

@ -242,9 +242,9 @@ def get_cross_platform_subprocess_run_args(
cwd: Path | str | None = None,
env: Mapping[str, str] | None = None,
timeout: Optional[float] = None,
check: bool = False, # noqa: FBT001, FBT002
text: bool = True, # noqa: FBT001, FBT002
capture_output: bool = True, # noqa: FBT001, FBT002 (only for non-Windows)
check: bool = False,
text: bool = True,
capture_output: bool = True,
) -> dict[str, str]:
run_args = {"cwd": cwd, "env": env, "text": text, "timeout": timeout, "check": check}
if sys.platform == "win32":

View file

@ -649,7 +649,7 @@ def tabulate(
headersalign=None,
rowalign=None,
maxheadercolwidths=None,
):
) -> str:
if tabular_data is None:
tabular_data = []

View file

@ -39,7 +39,7 @@ def get_latest_version_from_pypi() -> str | None:
return latest_version
logger.debug(f"Failed to fetch version from PyPI: {response.status_code}")
return None # noqa: TRY300
return None
except requests.RequestException as e:
logger.debug(f"Network error fetching version from PyPI: {e}")
return None

View file

@ -21,6 +21,10 @@ from codeflash.context.unused_definition_remover import (
remove_unused_definitions_by_function_names,
)
from codeflash.discovery.functions_to_optimize import FunctionToOptimize # noqa: TC001
# Language support imports for multi-language code context extraction
from codeflash.languages import is_python
from codeflash.languages.base import Language
from codeflash.models.models import (
CodeContextType,
CodeOptimizationContext,
@ -35,6 +39,7 @@ if TYPE_CHECKING:
from libcst import CSTNode
from codeflash.context.unused_definition_remover import UsageInfo
from codeflash.languages.base import HelperFunction
def build_testgen_context(
@ -75,6 +80,12 @@ def get_code_optimization_context(
optim_token_limit: int = OPTIMIZATION_CONTEXT_TOKEN_LIMIT,
testgen_token_limit: int = TESTGEN_CONTEXT_TOKEN_LIMIT,
) -> CodeOptimizationContext:
# Route to language-specific implementation for non-Python languages
if not is_python():
return get_code_optimization_context_for_language(
function_to_optimize, project_root_path, optim_token_limit, testgen_token_limit
)
# Get FunctionSource representation of helpers of FTO
helpers_of_fto_dict, helpers_of_fto_list = get_function_sources_from_jedi(
{function_to_optimize.file_path: {function_to_optimize.qualified_name}}, project_root_path
@ -95,7 +106,7 @@ def get_code_optimization_context(
qualified_names.update({f"{qn.rsplit('.', 1)[0]}.__init__" for qn in qualified_names if "." in qn})
# Get FunctionSource representation of helpers of helpers of FTO
helpers_of_helpers_dict, helpers_of_helpers_list = get_function_sources_from_jedi(
helpers_of_helpers_dict, _helpers_of_helpers_list = get_function_sources_from_jedi(
helpers_of_fto_qualified_names_dict, project_root_path
)
@ -198,11 +209,161 @@ def get_code_optimization_context(
)
def get_code_optimization_context_for_language(
    function_to_optimize: FunctionToOptimize,
    project_root_path: Path,
    optim_token_limit: int = OPTIMIZATION_CONTEXT_TOKEN_LIMIT,
    testgen_token_limit: int = TESTGEN_CONTEXT_TOKEN_LIMIT,
) -> CodeOptimizationContext:
    """Extract code optimization context for non-Python languages.

    Uses the language support abstraction to extract code context and converts
    it to the CodeOptimizationContext format expected by the pipeline.

    This function supports multi-file context extraction, grouping helpers by file
    and creating proper CodeStringsMarkdown with file paths for multi-file replacement.

    Args:
        function_to_optimize: The function to extract context for.
        project_root_path: Root of the project.
        optim_token_limit: Token limit for optimization context.
        testgen_token_limit: Token limit for testgen context.

    Returns:
        CodeOptimizationContext with target code and dependencies.

    Raises:
        ValueError: If the assembled context exceeds either token limit.

    """
    from codeflash.languages import get_language_support
    from codeflash.languages.base import FunctionInfo, ParentInfo

    # Get language support for this function
    language = Language(function_to_optimize.language)
    lang_support = get_language_support(language)

    # Convert FunctionToOptimize to FunctionInfo for language support.
    # Missing line numbers default to 1 — presumably upstream tolerates this;
    # TODO confirm against the language-support extractors.
    parents = tuple(ParentInfo(name=p.name, type=p.type) for p in function_to_optimize.parents)
    func_info = FunctionInfo(
        name=function_to_optimize.function_name,
        file_path=function_to_optimize.file_path,
        start_line=function_to_optimize.starting_line or 1,
        end_line=function_to_optimize.ending_line or 1,
        parents=parents,
        is_async=function_to_optimize.is_async,
        is_method=len(function_to_optimize.parents) > 0,
        language=language,
    )

    # Extract code context using language support
    code_context = lang_support.extract_code_context(func_info, project_root_path, project_root_path)

    # Build imports string if available
    imports_code = "\n".join(code_context.imports) if code_context.imports else ""

    # Get relative path for target file; fall back to the absolute path when
    # the file lives outside the project root (relative_to raises ValueError).
    try:
        target_relative_path = function_to_optimize.file_path.resolve().relative_to(project_root_path.resolve())
    except ValueError:
        target_relative_path = function_to_optimize.file_path

    # Group helpers by file path
    helpers_by_file: dict[Path, list[HelperFunction]] = defaultdict(list)
    helper_function_sources = []
    for helper in code_context.helper_functions:
        helpers_by_file[helper.file_path].append(helper)
        # Convert to FunctionSource for pipeline compatibility
        helper_function_sources.append(
            FunctionSource(
                file_path=helper.file_path,
                qualified_name=helper.qualified_name,
                fully_qualified_name=helper.qualified_name,
                only_function_name=helper.name,
                source_code=helper.source_code,
                jedi_definition=None,
            )
        )

    # Build read-writable code (target file + same-file helpers + global variables)
    read_writable_code_strings = []

    # Combine target code with same-file helpers
    target_file_code = code_context.target_code
    same_file_helpers = helpers_by_file.get(function_to_optimize.file_path, [])
    if same_file_helpers:
        helper_code = "\n\n".join(h.source_code for h in same_file_helpers)
        target_file_code = target_file_code + "\n\n" + helper_code

    # Add global variables (module-level declarations) referenced by the function and helpers
    # These should be included in read-writable context so AI can modify them if needed
    if code_context.read_only_context:
        target_file_code = code_context.read_only_context + "\n\n" + target_file_code

    # Add imports to target file code (prepended, so final order is:
    # imports, module-level declarations, target code, same-file helpers)
    if imports_code:
        target_file_code = imports_code + "\n\n" + target_file_code

    read_writable_code_strings.append(
        CodeString(code=target_file_code, file_path=target_relative_path, language=function_to_optimize.language)
    )

    # Add helper files (cross-file helpers)
    for file_path, file_helpers in helpers_by_file.items():
        if file_path == function_to_optimize.file_path:
            continue  # Already included in target file
        try:
            helper_relative_path = file_path.resolve().relative_to(project_root_path.resolve())
        except ValueError:
            helper_relative_path = file_path
        # Combine all helpers from this file
        combined_helper_code = "\n\n".join(h.source_code for h in file_helpers)
        read_writable_code_strings.append(
            CodeString(
                code=combined_helper_code, file_path=helper_relative_path, language=function_to_optimize.language
            )
        )

    read_writable_code = CodeStringsMarkdown(
        code_strings=read_writable_code_strings, language=function_to_optimize.language
    )

    # Build testgen context (same as read_writable for non-Python); shallow
    # copy of the list so the two contexts don't share the same list object.
    testgen_context = CodeStringsMarkdown(
        code_strings=read_writable_code_strings.copy(), language=function_to_optimize.language
    )

    # Check token limits
    read_writable_tokens = encoded_tokens_len(read_writable_code.markdown)
    if read_writable_tokens > optim_token_limit:
        raise ValueError("Read-writable code has exceeded token limit, cannot proceed")

    testgen_tokens = encoded_tokens_len(testgen_context.markdown)
    if testgen_tokens > testgen_token_limit:
        raise ValueError("Testgen code context has exceeded token limit, cannot proceed")

    # Generate code hash from all read-writable code
    code_hash = hashlib.sha256(read_writable_code.flat.encode("utf-8")).hexdigest()

    return CodeOptimizationContext(
        testgen_context=testgen_context,
        read_writable_code=read_writable_code,
        # Global variables are now included in read-writable code, so don't duplicate in read-only
        read_only_context_code="",
        hashing_code_context=read_writable_code.flat,
        hashing_code_context_hash=code_hash,
        helper_functions=helper_function_sources,
        preexisting_objects=set(),  # Not implemented for non-Python yet
    )
def extract_code_markdown_context_from_files(
helpers_of_fto: dict[Path, set[FunctionSource]],
helpers_of_helpers: dict[Path, set[FunctionSource]],
project_root_path: Path,
remove_docstrings: bool = False, # noqa: FBT001, FBT002
remove_docstrings: bool = False,
code_context_type: CodeContextType = CodeContextType.READ_ONLY,
) -> CodeStringsMarkdown:
"""Extract code context from files containing target functions and their helpers, formatting them as markdown.
@ -833,7 +994,7 @@ def get_imported_names(import_node: cst.Import | cst.ImportFrom) -> set[str]:
def remove_docstring_from_body(indented_block: cst.IndentedBlock) -> cst.CSTNode:
"""Removes the docstring from an indented block if it exists.""" # noqa: D401
"""Removes the docstring from an indented block if it exists."""
if not isinstance(indented_block.body[0], cst.SimpleStatementLine):
return indented_block
first_stmt = indented_block.body[0].body[0]
@ -847,7 +1008,7 @@ def parse_code_and_prune_cst(
code_context_type: CodeContextType,
target_functions: set[str],
helpers_of_helper_functions: set[str] = set(), # noqa: B006
remove_docstrings: bool = False, # noqa: FBT001, FBT002
remove_docstrings: bool = False,
) -> str:
"""Create a read-only version of the code by parsing and filtering the code to keep only class contextual information, and other module scoped variables."""
module = cst.parse_module(code)
@ -888,7 +1049,7 @@ def parse_code_and_prune_cst(
return ""
def prune_cst_for_read_writable_code( # noqa: PLR0911
def prune_cst_for_read_writable_code(
node: cst.CSTNode, target_functions: set[str], defs_with_usages: dict[str, UsageInfo], prefix: str = ""
) -> tuple[cst.CSTNode | None, bool]:
"""Recursively filter the node and its children to build the read-writable codeblock. This contains nodes that lead to target functions.
@ -1006,7 +1167,7 @@ def prune_cst_for_read_writable_code( # noqa: PLR0911
return (node.with_changes(**updates) if updates else node), True
def prune_cst_for_code_hashing( # noqa: PLR0911
def prune_cst_for_code_hashing(
node: cst.CSTNode, target_functions: set[str], prefix: str = ""
) -> tuple[cst.CSTNode | None, bool]:
"""Recursively filter the node and its children to build the read-writable codeblock. This contains nodes that lead to target functions.
@ -1095,14 +1256,14 @@ def prune_cst_for_code_hashing( # noqa: PLR0911
return (node.with_changes(**updates) if updates else node), True
def prune_cst_for_context( # noqa: PLR0911
def prune_cst_for_context(
node: cst.CSTNode,
target_functions: set[str],
helpers_of_helper_functions: set[str],
prefix: str = "",
remove_docstrings: bool = False, # noqa: FBT001, FBT002
include_target_in_output: bool = False, # noqa: FBT001, FBT002
include_init_dunder: bool = False, # noqa: FBT001, FBT002
remove_docstrings: bool = False,
include_target_in_output: bool = False,
include_init_dunder: bool = False,
) -> tuple[cst.CSTNode | None, bool]:
"""Recursively filter the node for code context extraction.

View file

@ -11,6 +11,7 @@ import libcst as cst
from codeflash.cli_cmds.console import logger
from codeflash.code_utils.code_replacer import replace_function_definitions_in_module
from codeflash.languages import is_javascript
from codeflash.models.models import CodeString, CodeStringsMarkdown
if TYPE_CHECKING:
@ -208,7 +209,7 @@ class DependencyCollector(cst.CSTVisitor):
self._extract_names_from_annotation(node.value)
# No need to check the attribute name itself as it's likely not a top-level definition
def leave_FunctionDef(self, original_node: cst.FunctionDef) -> None: # noqa: ARG002
def leave_FunctionDef(self, original_node: cst.FunctionDef) -> None:
self.function_depth -= 1
if self.function_depth == 0 and self.class_depth == 0:
@ -237,7 +238,7 @@ class DependencyCollector(cst.CSTVisitor):
self.class_depth += 1
def leave_ClassDef(self, original_node: cst.ClassDef) -> None: # noqa: ARG002
def leave_ClassDef(self, original_node: cst.ClassDef) -> None:
self.class_depth -= 1
if self.class_depth == 0:
@ -260,7 +261,7 @@ class DependencyCollector(cst.CSTVisitor):
# Use the first tracked name as the current top-level name (for dependency tracking)
self.current_top_level_name = tracked_names[0]
def leave_Assign(self, original_node: cst.Assign) -> None: # noqa: ARG002
def leave_Assign(self, original_node: cst.Assign) -> None:
if self.processing_variable:
self.processing_variable = False
self.current_variable_names.clear()
@ -370,7 +371,7 @@ class QualifiedFunctionUsageMarker:
self.mark_as_used_recursively(dep)
def remove_unused_definitions_recursively( # noqa: PLR0911
def remove_unused_definitions_recursively(
node: cst.CSTNode, definitions: dict[str, UsageInfo]
) -> tuple[cst.CSTNode | None, bool]:
"""Recursively filter the node to remove unused definitions.
@ -553,7 +554,7 @@ def remove_unused_definitions_by_function_names(code: str, qualified_function_na
# Apply the recursive removal transformation
modified_module, _ = remove_unused_definitions_recursively(module, defs_with_usages)
return modified_module.code if modified_module else "" # noqa: TRY300
return modified_module.code if modified_module else ""
except Exception as e:
# If any other error occurs during processing, return the original code
logger.debug(f"Error processing code to remove unused definitions: {type(e).__name__}: {e}")
@ -629,8 +630,8 @@ def _analyze_imports_in_optimized_code(
helpers_by_file_and_func = defaultdict(dict)
helpers_by_file = defaultdict(list) # preserved for "import module"
for helper in code_context.helper_functions:
jedi_type = helper.jedi_definition.type
if jedi_type != "class":
jedi_type = helper.jedi_definition.type if helper.jedi_definition else None
if jedi_type != "class": # Include when jedi_definition is None (non-Python)
func_name = helper.only_function_name
module_name = helper.file_path.stem
# Cache function lookup for this (module, func)
@ -716,6 +717,11 @@ def detect_unused_helper_functions(
List of FunctionSource objects representing unused helper functions
"""
# Skip this analysis for non-Python languages since we use Python's ast module
if is_javascript():
logger.debug("Skipping unused helper function detection for JavaScript/TypeScript")
return []
if isinstance(optimized_code, CodeStringsMarkdown) and len(optimized_code.code_strings) > 0:
return list(
chain.from_iterable(
@ -783,7 +789,8 @@ def detect_unused_helper_functions(
unused_helpers = []
entrypoint_file_path = function_to_optimize.file_path
for helper_function in code_context.helper_functions:
if helper_function.jedi_definition.type != "class":
jedi_type = helper_function.jedi_definition.type if helper_function.jedi_definition else None
if jedi_type != "class": # Include when jedi_definition is None (non-Python)
# Check if the helper function is called using multiple name variants
helper_qualified_name = helper_function.qualified_name
helper_simple_name = helper_function.only_function_name

View file

@ -29,6 +29,7 @@ from codeflash.code_utils.code_utils import (
)
from codeflash.code_utils.compat import SAFE_SYS_EXECUTABLE, codeflash_cache_db
from codeflash.code_utils.shell_utils import get_cross_platform_subprocess_run_args
from codeflash.languages import is_javascript, is_python
from codeflash.models.models import CodePosition, FunctionCalledInTest, TestsInFile, TestType
if TYPE_CHECKING:
@ -554,11 +555,119 @@ def filter_test_files_by_imports(
return filtered_map
def _detect_language_from_functions(file_to_funcs: dict[Path, list[FunctionToOptimize]] | None) -> str | None:
"""Detect language from the functions to optimize.
Args:
file_to_funcs: Dictionary mapping file paths to functions.
Returns:
Language string (e.g., "python", "javascript") or None if not determinable.
"""
if not file_to_funcs:
return None
for funcs in file_to_funcs.values():
if funcs:
return funcs[0].language
return None
def discover_tests_for_language(
    cfg: TestConfig, language: str, file_to_funcs_to_optimize: dict[Path, list[FunctionToOptimize]] | None
) -> tuple[dict[str, set[FunctionCalledInTest]], int, int]:
    """Discover tests using language-specific support.

    Bridges the generic language-support API (FunctionInfo/TestInfo) and the
    Python-era data model (FunctionCalledInTest keyed by fully-qualified names).

    Args:
        cfg: Test configuration (provides tests_root and project_root_path).
        language: Language identifier (e.g., "javascript").
        file_to_funcs_to_optimize: Dictionary mapping file paths to functions.

    Returns:
        Tuple of (function_to_tests_map, num_tests, num_replay_tests).
        num_replay_tests is always 0 here — replay tests are not produced by
        this non-Python path.

    """
    from codeflash.languages import get_language_support
    from codeflash.languages.base import FunctionInfo, Language, ParentInfo

    try:
        lang_support = get_language_support(Language(language))
    except Exception:
        # Unknown language string or unregistered support — degrade gracefully.
        logger.warning(f"Unsupported language {language}, returning empty test map")
        return {}, 0, 0
    # Convert FunctionToOptimize to FunctionInfo for the language support API
    # Also build a mapping from simple qualified_name to full qualified_name_with_modules
    function_infos: list[FunctionInfo] = []
    simple_to_full_name: dict[str, str] = {}
    if file_to_funcs_to_optimize:
        for funcs in file_to_funcs_to_optimize.values():
            for func in funcs:
                parents = tuple(ParentInfo(p.name, p.type) for p in func.parents)
                func_info = FunctionInfo(
                    name=func.function_name,
                    file_path=func.file_path,
                    # Missing line info is normalized to 0 rather than None.
                    start_line=func.starting_line or 0,
                    end_line=func.ending_line or 0,
                    start_col=func.starting_col,
                    end_col=func.ending_col,
                    is_async=func.is_async,
                    is_method=bool(func.parents and any(p.type == "ClassDef" for p in func.parents)),
                    parents=parents,
                    language=Language(language),
                )
                function_infos.append(func_info)
                # Map simple qualified_name to full qualified_name_with_modules_from_root
                simple_to_full_name[func_info.qualified_name] = func.qualified_name_with_modules_from_root(
                    cfg.project_root_path
                )
    # Use language support to discover tests
    test_map = lang_support.discover_tests(cfg.tests_root, function_infos)
    # Convert TestInfo back to FunctionCalledInTest format
    # Use the full qualified name (with modules) as the key for consistency with Python
    function_to_tests: dict[str, set[FunctionCalledInTest]] = defaultdict(set)
    num_tests = 0
    for qualified_name, test_infos in test_map.items():
        # Convert simple qualified_name to full qualified_name_with_modules;
        # fall back to the simple name if the function wasn't in our input set.
        full_qualified_name = simple_to_full_name.get(qualified_name, qualified_name)
        for test_info in test_infos:
            function_to_tests[full_qualified_name].add(
                FunctionCalledInTest(
                    tests_in_file=TestsInFile(
                        test_file=test_info.test_file,
                        test_class=test_info.test_class,
                        test_function=test_info.test_name,
                        test_type=TestType.EXISTING_UNIT_TEST,
                    ),
                    # Static discovery has no call-site position; (0, 0) is a placeholder.
                    position=CodePosition(line_no=0, col_no=0),
                )
            )
            # NOTE(review): this counts every (function, test) pairing, including any
            # duplicates the set above deduplicates — confirm that's the intended metric.
            num_tests += 1
    return dict(function_to_tests), num_tests, 0
def discover_unit_tests(
cfg: TestConfig,
discover_only_these_tests: list[Path] | None = None,
file_to_funcs_to_optimize: dict[Path, list[FunctionToOptimize]] | None = None,
) -> tuple[dict[str, set[FunctionCalledInTest]], int, int]:
# Detect language from functions being optimized
language = _detect_language_from_functions(file_to_funcs_to_optimize)
# Route to language-specific test discovery for non-Python languages
if not is_python():
# For JavaScript/TypeScript, tests_project_rootdir should be tests_root itself
# The Jest helper will be configured to NOT include "tests." prefix to match
if is_javascript():
cfg.tests_project_rootdir = cfg.tests_root
return discover_tests_for_language(cfg, language, file_to_funcs_to_optimize)
# Existing Python logic
framework_strategies: dict[str, Callable] = {"pytest": discover_tests_pytest, "unittest": discover_tests_unittest}
strategy = framework_strategies.get(cfg.test_framework, None)
if not strategy:

View file

@ -26,6 +26,9 @@ from codeflash.code_utils.code_utils import (
from codeflash.code_utils.env_utils import get_pr_number
from codeflash.code_utils.git_utils import get_git_diff, get_repo_owner_and_name
from codeflash.discovery.discover_unit_tests import discover_unit_tests
from codeflash.languages import get_language_support, get_supported_extensions
from codeflash.languages.base import Language
from codeflash.languages.registry import is_language_supported
from codeflash.lsp.helpers import is_LSP_enabled
from codeflash.models.models import FunctionParent
from codeflash.telemetry.posthog_cf import ph
@ -59,7 +62,7 @@ class ReturnStatementVisitor(cst.CSTVisitor):
super().__init__()
self.has_return_statement: bool = False
def visit_Return(self, node: cst.Return) -> None: # noqa: ARG002
def visit_Return(self, node: cst.Return) -> None:
self.has_return_statement = True
@ -135,7 +138,10 @@ class FunctionToOptimize:
parents: A list of parent scopes, which could be classes or functions.
starting_line: The starting line number of the function in the file.
ending_line: The ending line number of the function in the file.
starting_col: The starting column offset (for precise location in multi-line contexts).
ending_col: The ending column offset (for precise location in multi-line contexts).
is_async: Whether this function is defined as async.
language: The programming language of this function (default: "python").
The qualified_name property provides the full name of the function, including
any parent class or function names. The qualified_name_with_modules_from_root
@ -148,7 +154,10 @@ class FunctionToOptimize:
parents: list[FunctionParent] # list[ClassDef | FunctionDef | AsyncFunctionDef]
starting_line: Optional[int] = None
ending_line: Optional[int] = None
starting_col: Optional[int] = None # Column offset for precise location
ending_col: Optional[int] = None # Column offset for precise location
is_async: bool = False
language: str = "python" # Language identifier for multi-language support
@property
def top_level_parent_name(self) -> str:
@ -172,6 +181,98 @@ class FunctionToOptimize:
return f"{module_name_from_file_path(self.file_path, project_root_path)}.{self.qualified_name}"
# =============================================================================
# Multi-language support helpers
# =============================================================================
def get_files_for_language(
    module_root_path: Path, ignore_paths: list[Path], language: Language | None = None
) -> list[Path]:
    """Get all source files for supported languages.

    Args:
        module_root_path: Root path searched recursively for source files.
        ignore_paths: Paths to ignore (files or directories); anything located
            under one of these paths is skipped.
        language: Optional specific language to filter for. If None, includes
            extensions of all supported languages.

    Returns:
        List of file paths matching supported extensions.

    """
    if language is not None:
        support = get_language_support(language)
        extensions = support.file_extensions
    else:
        extensions = tuple(get_supported_extensions())

    files: list[Path] = []
    for ext in extensions:
        pattern = f"*{ext}"
        for file_path in module_root_path.rglob(pattern):
            # rglob matches directories too (e.g. a folder named "pkg.js");
            # only real files are source candidates.
            if not file_path.is_file():
                continue
            if any(file_path.is_relative_to(ignore_path) for ignore_path in ignore_paths):
                continue
            files.append(file_path)
    return files
def _find_all_functions_in_python_file(file_path: Path) -> dict[Path, list[FunctionToOptimize]]:
    """Find all optimizable functions in a Python file using AST parsing.

    This is the original Python implementation preserved for backward
    compatibility. Unparseable source yields an empty mapping (logged only in
    debug mode); errors opening the file propagate to the caller.
    """
    result: dict[Path, list[FunctionToOptimize]] = {}
    with file_path.open(encoding="utf8") as source_file:
        try:
            module_ast = ast.parse(source_file.read())
        except Exception as exc:
            if DEBUG_MODE:
                logger.exception(exc)
            return result
        visitor = FunctionWithReturnStatement(file_path)
        visitor.visit(module_ast)
        result[file_path] = visitor.functions
    return result
def _find_all_functions_via_language_support(file_path: Path) -> dict[Path, list[FunctionToOptimize]]:
    """Find all optimizable functions using the language support abstraction.

    Uses the registered language support for the file's language to discover
    functions, then converts each discovered function into a FunctionToOptimize.
    Any failure during discovery is logged at debug level and results in an
    empty mapping for this file.
    """
    from codeflash.languages.base import FunctionFilterCriteria

    discovered: dict[Path, list[FunctionToOptimize]] = {}
    try:
        support = get_language_support(file_path)
        infos = support.discover_functions(file_path, FunctionFilterCriteria(require_return=True))
        discovered[file_path] = [
            FunctionToOptimize(
                function_name=info.name,
                file_path=info.file_path,
                parents=[FunctionParent(p.name, p.type) for p in info.parents],
                starting_line=info.start_line,
                ending_line=info.end_line,
                starting_col=info.start_col,
                ending_col=info.end_col,
                is_async=info.is_async,
                language=info.language.value,
            )
            for info in infos
        ]
    except Exception as e:
        logger.debug(f"Failed to discover functions in {file_path}: {e}")
    return discovered
def get_functions_to_optimize(
optimize_all: str | None,
replay_test: list[Path] | None,
@ -194,7 +295,7 @@ def get_functions_to_optimize(
if optimize_all:
logger.info("!lsp|Finding all functions in the module '%s'", optimize_all)
console.rule()
functions = get_all_files_and_functions(Path(optimize_all))
functions = get_all_files_and_functions(Path(optimize_all), ignore_paths)
elif replay_test:
functions, trace_file_path = get_all_replay_test_functions(
replay_test=replay_test, test_cfg=test_cfg, project_root_path=project_root
@ -251,7 +352,7 @@ def get_functions_to_optimize(
return filtered_modified_functions, functions_count, trace_file_path
def get_functions_within_git_diff(uncommitted_changes: bool) -> dict[str, list[FunctionToOptimize]]: # noqa: FBT001
def get_functions_within_git_diff(uncommitted_changes: bool) -> dict[str, list[FunctionToOptimize]]:
modified_lines: dict[str, list[int]] = get_git_diff(uncommitted_changes=uncommitted_changes)
return get_functions_within_lines(modified_lines)
@ -356,9 +457,22 @@ def get_functions_within_lines(modified_lines: dict[str, list[int]]) -> dict[str
return functions
def get_all_files_and_functions(module_root_path: Path) -> dict[str, list[FunctionToOptimize]]:
def get_all_files_and_functions(
module_root_path: Path, ignore_paths: list[Path], language: Language | None = None
) -> dict[str, list[FunctionToOptimize]]:
"""Get all optimizable functions from files in the module root.
Args:
module_root_path: Root path to search for source files.
ignore_paths: List of paths to ignore.
language: Optional specific language to filter for. If None, includes all supported languages.
Returns:
Dictionary mapping file paths to lists of FunctionToOptimize.
"""
functions: dict[str, list[FunctionToOptimize]] = {}
for file_path in module_root_path.rglob("*.py"):
for file_path in get_files_for_language(module_root_path, ignore_paths, language):
# Find all the functions in the file
functions.update(find_all_functions_in_file(file_path).items())
# Randomize the order of the files to optimize to avoid optimizing the same file in the same order every time.
@ -369,18 +483,34 @@ def get_all_files_and_functions(module_root_path: Path) -> dict[str, list[Functi
def find_all_functions_in_file(file_path: Path) -> dict[Path, list[FunctionToOptimize]]:
functions: dict[Path, list[FunctionToOptimize]] = {}
with file_path.open(encoding="utf8") as f:
try:
ast_module = ast.parse(f.read())
except Exception as e:
if DEBUG_MODE:
logger.exception(e)
return functions
function_name_visitor = FunctionWithReturnStatement(file_path)
function_name_visitor.visit(ast_module)
functions[file_path] = function_name_visitor.functions
return functions
"""Find all optimizable functions in a file, routing to the appropriate language handler.
This function checks if the file extension is supported and routes to either
the Python-specific implementation (for backward compatibility) or the
language support abstraction for other languages.
Args:
file_path: Path to the source file.
Returns:
Dictionary mapping file path to list of FunctionToOptimize.
"""
# Check if the file extension is supported
if not is_language_supported(file_path):
return {}
try:
lang_support = get_language_support(file_path)
except Exception:
return {}
# Route to Python-specific implementation for backward compatibility
if lang_support.language == Language.PYTHON:
return _find_all_functions_in_python_file(file_path)
# Use language support abstraction for other languages
return _find_all_functions_via_language_support(file_path)
def get_all_replay_test_functions(
@ -472,7 +602,7 @@ def get_all_replay_test_functions(
def is_git_repo(file_path: str) -> bool:
try:
git.Repo(file_path, search_parent_directories=True)
return True # noqa: TRY300
return True
except git.InvalidGitRepositoryError:
return False
@ -704,11 +834,14 @@ def filter_functions(
if not file_path_normalized.startswith(module_root_str + os.sep):
non_modules_removed_count += len(_functions)
continue
try:
ast.parse(f"import {module_name_from_file_path(Path(file_path), project_root)}")
except SyntaxError:
malformed_paths_count += 1
continue
lang_support = get_language_support(Path(file_path))
if lang_support.language == Language.PYTHON:
try:
ast.parse(f"import {module_name_from_file_path(Path(file_path), project_root)}")
except SyntaxError:
malformed_paths_count += 1
continue
if blocklist_funcs:
functions_tmp = []

View file

@ -0,0 +1,76 @@
"""Multi-language support for Codeflash.
This package provides the abstraction layer that allows Codeflash to support
multiple programming languages while keeping the core optimization pipeline
language-agnostic.
Usage:
from codeflash.languages import get_language_support, Language
# Get language support for a file
lang = get_language_support(Path("example.py"))
# Discover functions
functions = lang.discover_functions(file_path)
# Replace a function
new_source = lang.replace_function(file_path, function, new_code)
"""
from codeflash.languages.base import (
CodeContext,
FunctionInfo,
HelperFunction,
Language,
LanguageSupport,
ParentInfo,
TestInfo,
TestResult,
)
from codeflash.languages.current import (
current_language,
current_language_support,
is_javascript,
is_python,
is_typescript,
reset_current_language,
set_current_language,
)
from codeflash.languages.javascript import JavaScriptSupport, TypeScriptSupport # noqa: F401
# Import language support modules to trigger auto-registration
# This ensures all supported languages are available when this package is imported
from codeflash.languages.python import PythonSupport # noqa: F401
from codeflash.languages.registry import (
detect_project_language,
get_language_support,
get_supported_extensions,
get_supported_languages,
register_language,
)
__all__ = [
# Base types
"CodeContext",
"FunctionInfo",
"HelperFunction",
"Language",
"LanguageSupport",
"ParentInfo",
"TestInfo",
"TestResult",
# Current language singleton
"current_language",
"current_language_support",
# Registry functions
"detect_project_language",
"get_language_support",
"get_supported_extensions",
"get_supported_languages",
"is_javascript",
"is_python",
"is_typescript",
"register_language",
"reset_current_language",
"set_current_language",
]

688
codeflash/languages/base.py Normal file
View file

@ -0,0 +1,688 @@
"""Base types and protocol for multi-language support in Codeflash.
This module defines the core abstractions that all language implementations must follow.
The LanguageSupport protocol defines the interface that each language must implement,
while the dataclasses define language-agnostic representations of code constructs.
"""
from __future__ import annotations
from dataclasses import dataclass, field
from enum import Enum
from typing import TYPE_CHECKING, Any, Protocol, runtime_checkable
if TYPE_CHECKING:
from collections.abc import Sequence
from pathlib import Path
class Language(str, Enum):
    """Supported programming languages.

    Mixing in ``str`` lets members compare equal to their plain string values
    (e.g. ``Language.PYTHON == "python"``) and serialize naturally.
    """

    PYTHON = "python"
    JAVASCRIPT = "javascript"
    TYPESCRIPT = "typescript"

    def __str__(self) -> str:
        # Render as the bare language identifier, not "Language.PYTHON".
        return self.value
@dataclass(frozen=True)
class ParentInfo:
    """Parent scope of a nested function or method.

    Captures the class or function that encloses a nested definition; a chain
    of these is used to build a function's qualified name.

    Attributes:
        name: Name of the parent scope (class or function name).
        type: Kind of parent ("ClassDef", "FunctionDef", "AsyncFunctionDef", ...).
    """

    name: str
    type: str  # "ClassDef", "FunctionDef", "AsyncFunctionDef", etc.

    def __str__(self) -> str:
        # "<type>:<name>", e.g. "ClassDef:Calculator".
        return f"{self.type}:{self.name}"
@dataclass(frozen=True)
class FunctionInfo:
    """Language-agnostic description of a function to optimize.

    Carries everything needed to identify, locate, and work with a function
    regardless of source language.

    Attributes:
        name: The simple function name (e.g., "add").
        file_path: Absolute path to the file containing the function.
        start_line: Starting line number (1-indexed).
        end_line: Ending line number (1-indexed, inclusive).
        parents: Parent scopes, outermost first (for nested functions/methods).
        is_async: Whether this is an async function.
        is_method: Whether this function belongs to a class.
        language: The programming language.
        start_col: Starting column (0-indexed), optional for precise location.
        end_col: Ending column (0-indexed), optional.
        doc_start_line: Line where the docstring/JSDoc starts, or None.
    """

    name: str
    file_path: Path
    start_line: int
    end_line: int
    parents: tuple[ParentInfo, ...] = ()
    is_async: bool = False
    is_method: bool = False
    language: Language = Language.PYTHON
    start_col: int | None = None
    end_col: int | None = None
    doc_start_line: int | None = None  # Line where docstring/JSDoc starts (or None if no doc comment)

    @property
    def qualified_name(self) -> str:
        """Dotted name including all parent scopes (e.g. "Calculator.add")."""
        segments = [parent.name for parent in self.parents]
        segments.append(self.name)
        return ".".join(segments)

    @property
    def class_name(self) -> str | None:
        """Innermost enclosing class name, or None for free functions."""
        return next((parent.name for parent in reversed(self.parents) if parent.type == "ClassDef"), None)

    @property
    def top_level_parent_name(self) -> str:
        """Outermost parent's name, or this function's own name when top-level."""
        if self.parents:
            return self.parents[0].name
        return self.name

    def __str__(self) -> str:
        return f"FunctionInfo({self.qualified_name} at {self.file_path}:{self.start_line}-{self.end_line})"
@dataclass
class HelperFunction:
    """A dependency of the target function that lives in the same project.

    A helper is a function the optimization target calls that belongs to the
    project's own code (not an external library), so its source is part of the
    optimization context.

    Attributes:
        name: Simple function name.
        qualified_name: Dotted name including parent scopes.
        file_path: File that contains the helper.
        source_code: Source text of the helper function.
        start_line: First line of the helper.
        end_line: Last line of the helper.
    """

    name: str
    qualified_name: str
    file_path: Path
    source_code: str
    start_line: int
    end_line: int
@dataclass
class CodeContext:
    """Code context assembled for optimization.

    Bundles the target function's source with the dependencies the AI needs to
    understand and optimize it.

    Attributes:
        target_code: Source code of the function to optimize.
        target_file: Path to the file containing the target function.
        helper_functions: Helper functions called by the target.
        read_only_context: Additional read-only dependency code.
        imports: Import statements needed by the context.
        language: The programming language.
    """

    target_code: str
    target_file: Path
    helper_functions: list[HelperFunction] = field(default_factory=list)
    read_only_context: str = ""
    imports: list[str] = field(default_factory=list)
    language: Language = Language.PYTHON
@dataclass
class TestInfo:
    """A test that exercises a source function.

    Attributes:
        test_name: Name of the test function.
        test_file: Path to the test file.
        test_class: Containing test class name, if any.
    """

    test_name: str
    test_file: Path
    test_class: str | None = None

    @property
    def full_test_path(self) -> str:
        """Pytest-style identifier: file::class::function (class omitted when None)."""
        segments = [self.test_file.as_posix()]
        if self.test_class:
            segments.append(self.test_class)
        segments.append(self.test_name)
        return "::".join(segments)
@dataclass
class TestResult:
    """Language-agnostic outcome of a single test run.

    Records pass/fail state plus the timing and behavioral data used for
    equivalence checking between original and optimized code.

    Attributes:
        test_name: Name of the test function.
        test_file: Path to the test file.
        passed: Whether the test passed.
        runtime_ns: Execution time in nanoseconds, when measured.
        return_value: Return value captured from the test, if any.
        stdout: Standard output captured during the run.
        stderr: Standard error captured during the run.
        error_message: Failure message when the test did not pass.
    """

    test_name: str
    test_file: Path
    passed: bool
    runtime_ns: int | None = None
    return_value: Any = None
    stdout: str = ""
    stderr: str = ""
    error_message: str | None = None
@dataclass
class FunctionFilterCriteria:
    """Filter settings controlling which functions discovery returns.

    Attributes:
        include_patterns: Glob patterns for functions to include.
        exclude_patterns: Glob patterns for functions to exclude.
        require_return: Only include functions containing a return statement.
        include_async: Include async functions.
        include_methods: Include class methods.
        min_lines: Minimum function length in lines, if set.
        max_lines: Maximum function length in lines, if set.
    """

    include_patterns: list[str] = field(default_factory=list)
    exclude_patterns: list[str] = field(default_factory=list)
    require_return: bool = True
    include_async: bool = True
    include_methods: bool = True
    min_lines: int | None = None
    max_lines: int | None = None
@runtime_checkable
class LanguageSupport(Protocol):
"""Protocol defining what a language implementation must provide.
All language-specific implementations (Python, JavaScript, etc.) must
implement this protocol. The protocol defines the interface for:
- Function discovery
- Code context extraction
- Code transformation (replacement)
- Test execution
- Test discovery
- Instrumentation for tracing
Example:
class PythonSupport(LanguageSupport):
@property
def language(self) -> Language:
return Language.PYTHON
def discover_functions(self, file_path: Path, ...) -> list[FunctionInfo]:
# Python-specific implementation using LibCST
...
"""
# === Properties ===
@property
def language(self) -> Language:
    """The language this implementation supports."""
    ...

@property
def file_extensions(self) -> tuple[str, ...]:
    """File extensions supported by this language.

    Returns:
        Tuple of extensions with leading dots (e.g., (".py",) for Python).
    """
    ...

@property
def test_framework(self) -> str:
    """Primary test framework name.

    Returns:
        Test framework identifier (e.g., "pytest", "jest").
    """
    ...

@property
def comment_prefix(self) -> str:
    """Single-line comment prefix for this language (e.g., "#" or "//").

    NOTE(review): overlaps with the get_comment_prefix() method declared
    later in this interface — consider consolidating on one of them.
    """
    ...

# === Discovery ===
def discover_functions(
    self, file_path: Path, filter_criteria: FunctionFilterCriteria | None = None
) -> list[FunctionInfo]:
    """Find all optimizable functions in a file.

    Args:
        file_path: Path to the source file to analyze.
        filter_criteria: Optional criteria to filter functions.

    Returns:
        List of FunctionInfo objects for discovered functions.
    """
    ...

def discover_tests(self, test_root: Path, source_functions: Sequence[FunctionInfo]) -> dict[str, list[TestInfo]]:
    """Map source functions to their tests via static analysis.

    Args:
        test_root: Root directory containing tests.
        source_functions: Functions to find tests for.

    Returns:
        Dict mapping qualified function names to lists of TestInfo.
    """
    ...

# === Code Analysis ===
def extract_code_context(self, function: FunctionInfo, project_root: Path, module_root: Path) -> CodeContext:
    """Extract function code and its dependencies.

    Args:
        function: The function to extract context for.
        project_root: Root of the project.
        module_root: Root of the module containing the function.

    Returns:
        CodeContext with target code and dependencies.
    """
    ...

def find_helper_functions(self, function: FunctionInfo, project_root: Path) -> list[HelperFunction]:
    """Find helper functions called by the target function.

    Args:
        function: The target function to analyze.
        project_root: Root of the project.

    Returns:
        List of HelperFunction objects.
    """
    ...

# === Code Transformation ===
def replace_function(self, source: str, function: FunctionInfo, new_source: str) -> str:
    """Replace a function in source code with new implementation.

    Args:
        source: Original source code.
        function: FunctionInfo identifying the function to replace.
        new_source: New function source code.

    Returns:
        Modified source code with function replaced.
    """
    ...

def format_code(self, source: str, file_path: Path | None = None) -> str:
    """Format code using language-specific formatter.

    Args:
        source: Source code to format.
        file_path: Optional file path for context (e.g., locating formatter config).

    Returns:
        Formatted source code.
    """
    ...
# === Test Execution ===
def run_tests(
    self, test_files: Sequence[Path], cwd: Path, env: dict[str, str], timeout: int
) -> tuple[list[TestResult], Path]:
    """Run tests and return results.

    Args:
        test_files: Paths to test files to run.
        cwd: Working directory for test execution.
        env: Environment variables.
        timeout: Maximum execution time in seconds.

    Returns:
        Tuple of (list of TestResults, path to JUnit XML).
    """
    ...

def parse_test_results(self, junit_xml_path: Path, stdout: str) -> list[TestResult]:
    """Parse test results from JUnit XML and stdout.

    Args:
        junit_xml_path: Path to JUnit XML results file.
        stdout: Standard output from test execution.

    Returns:
        List of TestResult objects.
    """
    ...

# === Instrumentation ===
def instrument_for_behavior(self, source: str, functions: Sequence[FunctionInfo]) -> str:
    """Add behavior instrumentation to capture inputs/outputs.

    Args:
        source: Source code to instrument.
        functions: Functions to add behavior capture.

    Returns:
        Instrumented source code.
    """
    ...

def instrument_for_benchmarking(self, test_source: str, target_function: FunctionInfo) -> str:
    """Add timing instrumentation to test code.

    Args:
        test_source: Test source code to instrument.
        target_function: Function being benchmarked.

    Returns:
        Instrumented test source code.
    """
    ...

# === Validation ===
def validate_syntax(self, source: str) -> bool:
    """Check if source code is syntactically valid.

    Args:
        source: Source code to validate.

    Returns:
        True if valid, False otherwise.
    """
    ...

def normalize_code(self, source: str) -> str:
    """Normalize code for deduplication.

    Removes comments, normalizes whitespace, etc. to allow
    comparison of semantically equivalent code.

    Args:
        source: Source code to normalize.

    Returns:
        Normalized source code.
    """
    ...

# === Test Editing ===
def add_runtime_comments(
    self, test_source: str, original_runtimes: dict[str, int], optimized_runtimes: dict[str, int]
) -> str:
    """Add runtime performance comments to test source code.

    Adds comments showing the original vs optimized runtime for each
    function call (e.g., "// 1.5ms -> 0.3ms (80% faster)").

    Args:
        test_source: Test source code to annotate.
        original_runtimes: Map of invocation IDs to original runtimes (ns).
        optimized_runtimes: Map of invocation IDs to optimized runtimes (ns).

    Returns:
        Test source code with runtime comments added.
    """
    ...

def remove_test_functions(self, test_source: str, functions_to_remove: list[str]) -> str:
    """Remove specific test functions from test source code.

    Args:
        test_source: Test source code.
        functions_to_remove: List of function names to remove.

    Returns:
        Test source code with specified functions removed.
    """
    ...

# === Test Result Comparison ===
def compare_test_results(
    self, original_results_path: Path, candidate_results_path: Path, project_root: Path | None = None
) -> tuple[bool, list]:
    """Compare test results between original and candidate code.

    Args:
        original_results_path: Path to original test results (e.g., SQLite DB).
        candidate_results_path: Path to candidate test results.
        project_root: Project root directory (for finding node_modules, etc.).

    Returns:
        Tuple of (are_equivalent, list of TestDiff objects).
    """
    ...
# === Configuration ===
def get_test_file_suffix(self) -> str:
    """Get the test file suffix for this language.

    Returns:
        Test file suffix (e.g., ".test.js", "_test.py").
    """
    ...

def get_comment_prefix(self) -> str:
    """Get the comment prefix for this language.

    NOTE(review): duplicates the comment_prefix property declared earlier
    in this interface — consider consolidating on one of them.

    Returns:
        Comment prefix (e.g., "//" for JS, "#" for Python).
    """
    ...

def find_test_root(self, project_root: Path) -> Path | None:
    """Find the test root directory for a project.

    Args:
        project_root: Root directory of the project.

    Returns:
        Path to test root, or None if not found.
    """
    ...

def get_runtime_files(self) -> list[Path]:
    """Get paths to runtime files that need to be copied to user's project.

    Returns:
        List of paths to runtime files (e.g., codeflash-jest-helper.js).
    """
    ...

def ensure_runtime_environment(self, project_root: Path) -> bool:
    """Ensure the runtime environment is set up for the project.

    This method handles language-specific runtime setup, such as installing
    npm packages for JavaScript or pip packages for Python.

    Args:
        project_root: The project root directory.

    Returns:
        True if runtime environment is ready, False otherwise.
    """
    # Default implementation performs no setup and reports "not ready";
    # language implementations must override this to do real setup.
    return False

def instrument_existing_test(
    self,
    test_path: Path,
    call_positions: Sequence[Any],
    function_to_optimize: Any,
    tests_project_root: Path,
    mode: str,
) -> tuple[bool, str | None]:
    """Inject profiling code into an existing test file.

    Wraps function calls with capture/benchmark instrumentation for
    behavioral verification and performance benchmarking.

    Args:
        test_path: Path to the test file.
        call_positions: List of code positions where the function is called.
        function_to_optimize: The function being optimized.
        tests_project_root: Root directory of tests.
        mode: Testing mode - "behavior" or "performance".

    Returns:
        Tuple of (success, instrumented_code).
    """
    ...

def instrument_source_for_line_profiler(self, func_info: FunctionInfo, line_profiler_output_file: Path) -> bool:
    """Instrument source code before line profiling.

    Returns:
        True on success, False otherwise.
    """
    ...

def parse_line_profile_results(self, line_profiler_output_file: Path) -> dict:
    """Parse line profiler output into a result mapping."""
    ...

# === Test Execution ===
def run_behavioral_tests(
    self,
    test_paths: Any,
    test_env: dict[str, str],
    cwd: Path,
    timeout: int | None = None,
    project_root: Path | None = None,
    enable_coverage: bool = False,
    candidate_index: int = 0,
) -> tuple[Path, Any, Path | None, Path | None]:
    """Run behavioral tests for this language.

    Args:
        test_paths: TestFiles object containing test file information.
        test_env: Environment variables for the test run.
        cwd: Working directory for running tests.
        timeout: Optional timeout in seconds.
        project_root: Project root directory.
        enable_coverage: Whether to collect coverage information.
        candidate_index: Index of the candidate being tested.

    Returns:
        Tuple of (result_file_path, subprocess_result, coverage_path, config_path).
    """
    ...

def run_benchmarking_tests(
    self,
    test_paths: Any,
    test_env: dict[str, str],
    cwd: Path,
    timeout: int | None = None,
    project_root: Path | None = None,
    min_loops: int = 5,
    max_loops: int = 100_000,
    target_duration_seconds: float = 10.0,
) -> tuple[Path, Any]:
    """Run benchmarking tests for this language.

    Args:
        test_paths: TestFiles object containing test file information.
        test_env: Environment variables for the test run.
        cwd: Working directory for running tests.
        timeout: Optional timeout in seconds.
        project_root: Project root directory.
        min_loops: Minimum number of loops for benchmarking.
        max_loops: Maximum number of loops for benchmarking.
        target_duration_seconds: Target duration for benchmarking in seconds.

    Returns:
        Tuple of (result_file_path, subprocess_result).
    """
    ...
def convert_parents_to_tuple(parents: list | tuple) -> tuple[ParentInfo, ...]:
    """Convert a sequence of parent objects into a tuple of ParentInfo.

    Bridges the existing FunctionParent dataclass (any object exposing
    ``name`` and ``type`` attributes) to the newer ParentInfo dataclass.

    Args:
        parents: List or tuple of objects with ``name`` and ``type`` attributes.

    Returns:
        Tuple of ParentInfo objects, in the same order as the input.
    """
    converted = [ParentInfo(name=parent.name, type=parent.type) for parent in parents]
    return tuple(converted)

View file

@ -0,0 +1,118 @@
"""Singleton for the current language being used in the codeflash session.
This module provides a centralized way to access and set the current language
throughout the codeflash codebase, eliminating scattered language checks and
string comparisons.
Usage:
from codeflash.languages import current_language, set_current_language, is_python
# Set the language at the start of a session
set_current_language(Language.PYTHON)
# or
set_current_language("javascript")
# Check the current language anywhere in the codebase
if is_python():
# Python-specific code
...
# Get the current language
lang = current_language()
# Get language support for the current language
support = current_language_support()
"""
from __future__ import annotations
from typing import TYPE_CHECKING
from codeflash.languages.base import Language
if TYPE_CHECKING:
from codeflash.languages.base import LanguageSupport
# Module-level singleton for the current language
_current_language: Language | None = None
def current_language() -> Language | None:
    """Get the current language being used in this codeflash session.

    Returns:
        The current Language enum value, or None if set_current_language()
        has not been called yet in this session.
    """
    return _current_language
def set_current_language(language: Language | str) -> None:
    """Set the current language for this codeflash session.

    This should be called once at the start of an optimization run,
    typically after reading the project configuration.

    Note:
        First call wins: once a language has been set (including by
        reset_current_language, which sets Python), subsequent calls are
        silently ignored.

    Args:
        language: Either a Language enum value or a string like "python",
            "javascript", "typescript".
    """
    global _current_language
    # First-set-wins: ignore later calls so mid-session reconfiguration
    # cannot flip language-specific behavior part-way through a run.
    if _current_language is not None:
        return
    _current_language = Language(language) if isinstance(language, str) else language
def reset_current_language() -> None:
    """Reset the current language to the default (Python).

    Useful for testing or when starting a new session.

    Note:
        This sets the language to Language.PYTHON rather than clearing it
        back to None, so a subsequent set_current_language() call will
        still be a no-op (first-set-wins).
    """
    global _current_language
    _current_language = Language.PYTHON
def is_python() -> bool:
    """Report whether the session language is Python.

    Returns:
        True if the current language is Python; False otherwise
        (including when no language has been set yet).
    """
    # Enum members are singletons, so identity comparison is equivalent
    # to equality here.
    return _current_language is Language.PYTHON
def is_javascript() -> bool:
    """Report whether the session language is JavaScript or TypeScript.

    Both are treated the same way in the optimization pipeline, so this
    returns True for either.

    Returns:
        True if the current language is JavaScript or TypeScript.
    """
    return _current_language in {Language.JAVASCRIPT, Language.TYPESCRIPT}
def is_typescript() -> bool:
    """Report whether the session language is specifically TypeScript.

    Returns:
        True if the current language is TypeScript.
    """
    # Enum members are singletons, so identity comparison is equivalent
    # to equality here.
    return _current_language is Language.TYPESCRIPT
def current_language_support() -> LanguageSupport:
    """Get the LanguageSupport instance for the current language.

    Returns:
        The LanguageSupport instance for the current language.

    Note:
        If no language has been set yet, None is forwarded to
        get_language_support; the result then depends on that function.
    """
    # Imported here rather than at module top — presumably to avoid a
    # circular import with the registry module; TODO confirm.
    from codeflash.languages.registry import get_language_support

    return get_language_support(_current_language)

View file

@ -0,0 +1,5 @@
"""JavaScript/TypeScript language support for codeflash."""
from codeflash.languages.javascript.support import JavaScriptSupport, TypeScriptSupport
__all__ = ["JavaScriptSupport", "TypeScriptSupport"]

View file

@ -0,0 +1,192 @@
"""JavaScript test result comparison.
This module provides functionality to compare test results between
original and optimized JavaScript code using a Node.js comparison script.
"""
from __future__ import annotations
import json
import os
import subprocess
from pathlib import Path
from codeflash.cli_cmds.console import logger
from codeflash.models.models import TestDiff, TestDiffScope
def _get_compare_results_script(project_root: Path | None = None) -> Path | None:
"""Find the compare-results.js script from the installed codeflash npm package.
Args:
project_root: Project root directory where node_modules is installed.
Returns:
Path to compare-results.js if found, None otherwise.
"""
search_dirs = []
if project_root:
search_dirs.append(project_root)
search_dirs.append(Path.cwd())
for base_dir in search_dirs:
script_path = base_dir / "node_modules" / "codeflash" / "runtime" / "compare-results.js"
if script_path.exists():
return script_path
return None
def compare_test_results(
    original_sqlite_path: Path,
    candidate_sqlite_path: Path,
    comparator_script: Path | None = None,
    project_root: Path | None = None,
) -> tuple[bool, list[TestDiff]]:
    """Compare JavaScript test results using the Node.js comparator.

    This function calls a Node.js script that:
    1. Reads serialized behavior data from both SQLite databases
    2. Deserializes using the codeflash serializer module
    3. Compares using the codeflash comparator module (handles Map, Set, Date, etc. natively)
    4. Returns comparison results as JSON

    Args:
        original_sqlite_path: Path to SQLite database with original code results.
        candidate_sqlite_path: Path to SQLite database with candidate code results.
        comparator_script: Optional path to the comparison script.
        project_root: Project root directory where node_modules is installed.

    Returns:
        Tuple of (all_equivalent, list of TestDiff objects). All failure
        modes (missing script/databases, subprocess errors, bad output)
        return (False, []).
    """
    script_path = comparator_script or _get_compare_results_script(project_root)
    if not script_path or not script_path.exists():
        logger.error(
            "JavaScript comparator script not found. "
            "Please ensure the 'codeflash' npm package is installed in your project."
        )
        return False, []
    if not original_sqlite_path.exists():
        logger.error(f"Original SQLite database not found: {original_sqlite_path}")
        return False, []
    if not candidate_sqlite_path.exists():
        logger.error(f"Candidate SQLite database not found: {candidate_sqlite_path}")
        return False, []
    # Determine working directory - should be where node_modules is installed.
    # The script needs better-sqlite3 which is installed in the project's node_modules.
    cwd = project_root or Path.cwd()
    # Set NODE_PATH to include the project's node_modules.
    # This is needed because the script runs from the codeflash package directory,
    # but needs to resolve modules from the project's node_modules.
    env = os.environ.copy()
    node_modules_path = cwd / "node_modules"
    if node_modules_path.exists():
        existing_node_path = env.get("NODE_PATH", "")
        if existing_node_path:
            env["NODE_PATH"] = f"{node_modules_path}:{existing_node_path}"
        else:
            env["NODE_PATH"] = str(node_modules_path)
    try:
        result = subprocess.run(
            ["node", str(script_path), str(original_sqlite_path), str(candidate_sqlite_path)],
            check=False,
            capture_output=True,
            text=True,
            timeout=60,
            cwd=str(cwd),
            env=env,
        )
        # Parse the JSON output first - errors are reported in JSON too,
        # so the exit code is only checked after JSON decoding succeeds.
        try:
            if not result.stdout or not result.stdout.strip():
                logger.error("JavaScript comparator returned empty output")
                if result.stderr:
                    logger.error(f"stderr: {result.stderr}")
                return False, []
            comparison = json.loads(result.stdout)
        except json.JSONDecodeError as e:
            logger.error(f"Failed to parse JavaScript comparator output: {e}")
            logger.error(f"stdout: {result.stdout[:500] if result.stdout else '(empty)'}")
            if result.stderr:
                logger.error(f"stderr: {result.stderr[:500]}")
            return False, []
        # Check for errors in the JSON response.
        # Exit code 0 = equivalent, 1 = not equivalent, 2 = setup error.
        if comparison.get("error"):
            logger.error(f"JavaScript comparator error: {comparison['error']}")
            return False, []
        # Check for unexpected exit codes (not 0 or 1).
        if result.returncode not in {0, 1}:
            logger.error(f"JavaScript comparator failed with exit code {result.returncode}")
            if result.stderr:
                logger.error(f"stderr: {result.stderr}")
            return False, []
        # Convert diffs to TestDiff objects.
        test_diffs: list[TestDiff] = []
        for diff in comparison.get("diffs", []):
            scope_str = diff.get("scope", "return_value")
            scope = TestDiffScope.RETURN_VALUE
            if scope_str == "stdout":
                scope = TestDiffScope.STDOUT
            elif scope_str == "did_pass":
                scope = TestDiffScope.DID_PASS
            test_info = diff.get("test_info", {})
            # Build a test identifier string for JavaScript tests.
            test_function_name = test_info.get("test_function_name", "unknown")
            function_getting_tested = test_info.get("function_getting_tested", "unknown")
            test_src_code = f"// Test: {test_function_name}\n// Testing function: {function_getting_tested}"
            test_diffs.append(
                TestDiff(
                    scope=scope,
                    original_value=diff.get("original"),
                    candidate_value=diff.get("candidate"),
                    test_src_code=test_src_code,
                    candidate_pytest_error=diff.get("candidate_error"),
                    original_pass=True,  # Assume passed if we got results
                    candidate_pass=diff.get("scope") != "missing",
                    original_pytest_error=None,
                )
            )
            logger.debug(
                f"JavaScript test diff:\n"
                f"  Test: {test_function_name}\n"
                f"  Function: {function_getting_tested}\n"
                f"  Scope: {scope_str}\n"
                f"  Original: {str(diff.get('original', 'N/A'))[:100]}\n"
                f"  Candidate: {str(diff.get('candidate', 'N/A'))[:100] if diff.get('candidate') else 'N/A'}"
            )
        equivalent = comparison.get("equivalent", False)
        logger.info(
            f"JavaScript comparison: {'equivalent' if equivalent else 'DIFFERENT'} "
            f"({comparison.get('total_invocations', 0)} invocations, {len(test_diffs)} diffs)"
        )
        return equivalent, test_diffs
    except subprocess.TimeoutExpired:
        logger.error("JavaScript comparator timed out")
        return False, []
    except FileNotFoundError:
        logger.error("Node.js not found. Please install Node.js to compare JavaScript test results.")
        return False, []
    except Exception as e:
        logger.error(f"Error running JavaScript comparator: {e}")
        return False, []

View file

@ -0,0 +1,230 @@
"""JavaScript test editing utilities.
This module provides functionality for editing JavaScript/TypeScript test files,
including adding runtime comments and removing test functions.
"""
from __future__ import annotations
import re
from codeflash.cli_cmds.console import logger
from codeflash.code_utils.time_utils import format_perf, format_time
from codeflash.result.critic import performance_gain
def format_runtime_comment(original_time: int, optimized_time: int) -> str:
    """Format a runtime comparison comment for JavaScript.

    Args:
        original_time: Original runtime in nanoseconds.
        optimized_time: Optimized runtime in nanoseconds.

    Returns:
        Formatted comment string with // prefix,
        e.g. "// 1.50ms -> 0.30ms (80% faster)".
    """
    gain_fraction = performance_gain(
        original_runtime_ns=original_time, optimized_runtime_ns=optimized_time
    )
    perf_gain = format_perf(abs(gain_fraction * 100))
    if optimized_time > original_time:
        status = "slower"
    else:
        status = "faster"
    before = format_time(original_time)
    after = format_time(optimized_time)
    return f"// {before} -> {after} ({perf_gain}% {status})"
def add_runtime_comments(source: str, original_runtimes: dict[str, int], optimized_runtimes: dict[str, int]) -> str:
    """Add runtime comments to JavaScript test source code.

    For JavaScript, we match timing data by test function name and add
    comments to expect() lines. Only the first eligible line of each
    matched test is annotated.

    Args:
        source: JavaScript test source code.
        original_runtimes: Map of invocation keys to original runtimes (ns).
        optimized_runtimes: Map of invocation keys to optimized runtimes (ns).

    Returns:
        Source code with runtime comments added.
    """
    logger.debug(f"[js-annotations] original_runtimes has {len(original_runtimes)} entries")
    logger.debug(f"[js-annotations] optimized_runtimes has {len(optimized_runtimes)} entries")
    if not original_runtimes or not optimized_runtimes:
        logger.debug("[js-annotations] No runtimes available, returning unchanged source")
        return source

    lines = source.split("\n")
    modified_lines = []

    # Build a lookup by FULL test name (including describe blocks) for suffix matching.
    # The keys in original_runtimes look like: "full_test_name#/path/to/test#invocation_id"
    # where full_test_name includes describe blocks: "fibonacci Edge cases should return 0"
    timing_by_full_name: dict[str, tuple[int, int]] = {}
    for key in original_runtimes:
        if key in optimized_runtimes:
            # Extract test function name from the key (first part before #)
            parts = key.split("#")
            if parts:
                full_test_name = parts[0]
                logger.debug(f"[js-annotations] Found timing for full test name: '{full_test_name}'")
                if full_test_name not in timing_by_full_name:
                    timing_by_full_name[full_test_name] = (original_runtimes[key], optimized_runtimes[key])
                else:
                    # Sum up timings for same test (multiple invocations)
                    old_orig, old_opt = timing_by_full_name[full_test_name]
                    timing_by_full_name[full_test_name] = (
                        old_orig + original_runtimes[key],
                        old_opt + optimized_runtimes[key],
                    )
    logger.debug(f"[js-annotations] Built timing_by_full_name with {len(timing_by_full_name)} entries")

    def find_matching_test(test_description: str) -> str | None:
        """Find a timing key that ends with the given test description (suffix match).

        Timing keys are like: "fibonacci Edge cases should return 0"
        Source test names are like: "should return 0"
        We need to match by suffix because timing includes all describe block names.

        NOTE(review): a plain suffix match can pair a test with the wrong
        timing entry when one test's name is a suffix of another's — verify
        against real Jest output if ambiguous names show up.
        """
        # Try to match by finding a key that ends with the test description
        for full_name in timing_by_full_name:
            # Check if the full name ends with the test description (case-insensitive)
            if full_name.lower().endswith(test_description.lower()):
                logger.debug(f"[js-annotations] Suffix match: '{test_description}' matches '{full_name}'")
                return full_name
        return None

    # Track current test context while scanning line by line.
    current_test_name = None
    current_matched_full_name = None
    test_pattern = re.compile(r"(?:test|it)\s*\(\s*['\"]([^'\"]+)['\"]")
    # Match function calls that look like: funcName(args) or expect(funcName(args))
    func_call_pattern = re.compile(r"(?:expect\s*\(\s*)?(\w+)\s*\([^)]*\)")

    for line in lines:
        # Check if this line starts a new test
        test_match = test_pattern.search(line)
        if test_match:
            current_test_name = test_match.group(1)
            logger.debug(f"[js-annotations] Found test: '{current_test_name}'")
            # Find the matching full name from timing data using suffix match
            current_matched_full_name = find_matching_test(current_test_name)
            if current_matched_full_name:
                logger.debug(f"[js-annotations] Test '{current_test_name}' matched to '{current_matched_full_name}'")
        # Check if this line has a function call and we have timing for current test
        if current_matched_full_name and current_matched_full_name in timing_by_full_name:
            # Only add comment if line has a function call and doesn't already have a comment.
            # NOTE(review): the `"//" not in line` guard also skips lines whose
            # string literals contain "//" (e.g. URLs) — confirm that is acceptable.
            if func_call_pattern.search(line) and "//" not in line and "expect(" in line:
                orig_time, opt_time = timing_by_full_name[current_matched_full_name]
                comment = format_runtime_comment(orig_time, opt_time)
                logger.debug(f"[js-annotations] Adding comment to test '{current_test_name}': {comment}")
                # Add comment at end of line
                line = f"{line.rstrip()} {comment}"
                # Clear timing so we only annotate first call in each test
                del timing_by_full_name[current_matched_full_name]
                current_matched_full_name = None
        modified_lines.append(line)

    return "\n".join(modified_lines)
def remove_test_functions(source: str, functions_to_remove: list[str]) -> str:
    """Remove specific test functions from JavaScript test source code.

    Handles Jest test patterns registered via test() and it(). Only the
    first occurrence of each named test is removed.

    Args:
        source: JavaScript test source code.
        functions_to_remove: List of test function/describe names to remove.

    Returns:
        Source code with specified functions removed.
    """
    if not functions_to_remove:
        return source

    for name in functions_to_remove:
        # Match the opening of a test('name', ...) or it('name', ...) call.
        opener = re.compile(
            r"(?:test|it)\s*\(\s*['\"]" + re.escape(name) + r"['\"].*?\)\s*;?\s*\n?",
            re.DOTALL,
        )
        match = opener.search(source)
        if match is None:
            continue
        # Determine the full extent of the test block via brace matching,
        # then splice it out of the source.
        start = match.start()
        end = _find_block_end(source, match.end() - 1)
        if end > start:
            source = source[:start] + source[end:]

    return source
def _find_block_end(source: str, start: int) -> int:
"""Find the end of a JavaScript block starting from a position.
Tracks brace matching to find where a function/block ends.
Args:
source: Source code.
start: Starting position (should be at or before opening brace).
Returns:
Position after the closing brace, or start if not found.
"""
# Find the opening brace
brace_pos = source.find("{", start)
if brace_pos == -1:
# No block found, try to find end of arrow function or simple statement
semicolon_pos = source.find(";", start)
newline_pos = source.find("\n", start)
if semicolon_pos != -1:
return semicolon_pos + 1
if newline_pos != -1:
return newline_pos + 1
return start
# Track brace depth
depth = 0
in_string = False
string_char = None
i = brace_pos
while i < len(source):
char = source[i]
# Handle string literals
if char in ('"', "'", "`") and (i == 0 or source[i - 1] != "\\"):
if not in_string:
in_string = True
string_char = char
elif char == string_char:
in_string = False
string_char = None
elif not in_string:
if char == "{":
depth += 1
elif char == "}":
depth -= 1
if depth == 0:
# Found the matching closing brace
# Skip any trailing semicolon or newline
end = i + 1
while end < len(source) and source[end] in " \t":
end += 1
if end < len(source) and source[end] == ";":
end += 1
while end < len(source) and source[end] in " \t\n":
end += 1
return end
i += 1
return start

View file

@ -0,0 +1,540 @@
"""Import resolution for JavaScript/TypeScript.
This module provides utilities to resolve JavaScript/TypeScript import paths
to actual file paths, enabling multi-file context extraction.
"""
from __future__ import annotations
import logging
from dataclasses import dataclass, field
from pathlib import Path
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from codeflash.languages.base import FunctionInfo, HelperFunction
from codeflash.languages.treesitter_utils import ImportInfo, TreeSitterAnalyzer
logger = logging.getLogger(__name__)
@dataclass
class ResolvedImport:
    """Result of resolving an import statement to a concrete file path."""

    file_path: Path  # Resolved absolute file path
    module_path: str  # Original import path (e.g., './utils')
    imported_names: list[str]  # Names imported (named imports, plus any default)
    is_default_import: bool  # Whether it's a default import
    is_namespace_import: bool  # Whether it's `import * as X`
    namespace_name: str | None  # The namespace alias (X in `import * as X`)
class ImportResolver:
    """Resolves JavaScript/TypeScript import paths to file paths."""

    # Supported extensions in resolution order (prefer TS over JS).
    EXTENSIONS = (".ts", ".tsx", ".js", ".jsx", ".mjs", ".cjs")

    def __init__(self, project_root: Path) -> None:
        """Initialize the resolver.

        Args:
            project_root: Root directory of the project.
        """
        self.project_root = project_root
        # Memoizes (source_file, module_path) -> resolved path; failed
        # resolutions are cached as None so they are not retried.
        self._resolution_cache: dict[tuple[Path, str], Path | None] = {}

    def resolve_import(self, import_info: ImportInfo, source_file: Path) -> ResolvedImport | None:
        """Resolve an import to its actual file path.

        Args:
            import_info: The import statement information.
            source_file: The file containing the import statement.

        Returns:
            ResolvedImport if resolution successful, None otherwise.
        """
        module_path = import_info.module_path
        # External packages (node_modules / Node built-ins) are out of scope.
        if self._is_external_package(module_path):
            logger.debug("Skipping external package: %s", module_path)
            return None
        cache_key = (source_file, module_path)
        if cache_key in self._resolution_cache:
            cached_path = self._resolution_cache[cache_key]
            if cached_path is None:
                return None
            return self._build_resolved_import(import_info, cached_path)
        resolved_path = self._resolve_module_path(module_path, source_file.parent)
        self._resolution_cache[cache_key] = resolved_path
        if resolved_path is None:
            logger.debug("Could not resolve import: %s from %s", module_path, source_file)
            return None
        return self._build_resolved_import(import_info, resolved_path)

    def _build_resolved_import(self, import_info: ImportInfo, resolved_path: Path) -> ResolvedImport:
        """Build a ResolvedImport from import info and resolved path."""
        imported_names = []
        # Named imports use the alias when one was given.
        for name, alias in import_info.named_imports:
            imported_names.append(alias if alias else name)
        # The default import (if any) is included alongside named imports.
        if import_info.default_import:
            imported_names.append(import_info.default_import)
        return ResolvedImport(
            file_path=resolved_path,
            module_path=import_info.module_path,
            imported_names=imported_names,
            is_default_import=import_info.default_import is not None,
            is_namespace_import=import_info.namespace_import is not None,
            namespace_name=import_info.namespace_import,
        )

    def _resolve_module_path(self, module_path: str, source_dir: Path) -> Path | None:
        """Resolve a module path to an absolute file path.

        Args:
            module_path: The import path (e.g., './utils', '../lib/helper').
            source_dir: Directory of the file containing the import.

        Returns:
            Resolved absolute path, or None if not found.
        """
        if module_path.startswith("."):
            return self._resolve_relative_import(module_path, source_dir)
        if module_path.startswith("/"):
            return self._resolve_absolute_import(module_path)
        # Bare imports (e.g., 'lodash') are external packages.
        return None

    def _resolve_relative_import(self, module_path: str, source_dir: Path) -> Path | None:
        """Resolve relative imports like ./utils or ../lib/helper.

        Args:
            module_path: The relative import path.
            source_dir: Directory to resolve from.

        Returns:
            Resolved absolute path, or None if not found.
        """
        base_path = (source_dir / module_path).resolve()
        # Reject anything that escapes the project root.
        try:
            base_path.relative_to(self.project_root)
        except ValueError:
            logger.debug("Import path outside project root: %s", base_path)
            return None
        # If the path already has a known extension, try it directly first.
        if base_path.suffix in self.EXTENSIONS:
            if base_path.exists() and base_path.is_file():
                return base_path
            # TypeScript allows importing .ts files with a .js extension.
            if base_path.suffix == ".js":
                ts_path = base_path.with_suffix(".ts")
                if ts_path.exists() and ts_path.is_file():
                    return ts_path
        resolved = self._try_extensions(base_path)
        if resolved:
            return resolved
        # Fall back to directory-with-index resolution.
        return self._try_index_file(base_path)

    def _resolve_absolute_import(self, module_path: str) -> Path | None:
        """Resolve absolute imports starting with /.

        Args:
            module_path: The absolute import path.

        Returns:
            Resolved absolute path, or None if not found.
        """
        # Treat as relative to the project root.
        base_path = (self.project_root / module_path.lstrip("/")).resolve()
        resolved = self._try_extensions(base_path)
        if resolved:
            return resolved
        return self._try_index_file(base_path)

    def _try_extensions(self, base_path: Path) -> Path | None:
        """Try adding extensions to find the actual file.

        Appended extensions are tried BEFORE suffix replacement: for an
        import like './app.config', Path.with_suffix would turn the path
        into 'app.ts' (replacing '.config'), which can silently resolve to
        the wrong file when both 'app.ts' and 'app.config.ts' exist.
        Node-style resolution appends the extension, so that is preferred.

        Args:
            base_path: The path, possibly without an extension.

        Returns:
            Path if a file is found, None otherwise.
        """
        if base_path.exists() and base_path.is_file():
            return base_path
        # Node semantics: append each candidate extension to the full name.
        for ext in self.EXTENSIONS:
            appended = Path(str(base_path) + ext)
            if appended.exists() and appended.is_file():
                return appended
        # Backward-compatible fallback: replace an existing suffix
        # (e.g. './utils.js' resolving to 'utils.jsx' when only the
        # sibling extension exists).
        if base_path.suffix:
            for ext in self.EXTENSIONS:
                replaced = base_path.with_suffix(ext)
                if replaced.exists() and replaced.is_file():
                    return replaced
        return None

    def _try_index_file(self, dir_path: Path) -> Path | None:
        """Try resolving to an index file in a directory.

        Args:
            dir_path: The directory path to check.

        Returns:
            Path to index file if found, None otherwise.
        """
        if not dir_path.exists() or not dir_path.is_dir():
            return None
        for ext in self.EXTENSIONS:
            index_path = dir_path / f"index{ext}"
            if index_path.exists() and index_path.is_file():
                return index_path
        return None

    def _is_external_package(self, module_path: str) -> bool:
        """Check if import refers to an external package (node_modules).

        Args:
            module_path: The import module path.

        Returns:
            True if this is an external package import.
        """
        # Relative imports are not external.
        if module_path.startswith("."):
            return False
        # Absolute imports (starting with /) are project-internal.
        if module_path.startswith("/"):
            return False
        # Bare specifiers ('lodash', '@company/utils', 'fs', ...) are external.
        return True
@dataclass
class HelperSearchContext:
    """Context for recursive helper search."""

    # Files already processed on this search path; prevents re-reading modules.
    visited_files: set[Path] = field(default_factory=set)
    # (file, function-name) pairs already extracted; prevents duplicate helpers.
    visited_functions: set[tuple[Path, str]] = field(default_factory=set)
    # Recursion depth of the current search (0 == the target function itself).
    current_depth: int = 0
    # Hard ceiling on recursion depth.
    max_depth: int = 2
class MultiFileHelperFinder:
    """Finds helper functions across multiple files."""

    DEFAULT_MAX_DEPTH = 2  # Target → helpers → helpers of helpers

    def __init__(self, project_root: Path, import_resolver: ImportResolver) -> None:
        """Initialize the finder.

        Args:
            project_root: Root directory of the project.
            import_resolver: ImportResolver instance for resolving imports.
        """
        self.project_root = project_root
        self.import_resolver = import_resolver
    def find_helpers(
        self,
        function: FunctionInfo,
        source: str,
        analyzer: TreeSitterAnalyzer,
        imports: list[ImportInfo],
        max_depth: int = DEFAULT_MAX_DEPTH,
    ) -> dict[Path, list[HelperFunction]]:
        """Find all helper functions including cross-file dependencies.

        Args:
            function: The target function to find helpers for.
            source: Source code of the file containing the function.
            analyzer: TreeSitterAnalyzer for parsing.
            imports: List of imports in the source file.
            max_depth: Maximum recursion depth for finding helpers of helpers.

        Returns:
            Dictionary mapping file paths to lists of helper functions.
        """
        context = HelperSearchContext(max_depth=max_depth)
        context.visited_files.add(function.file_path)

        # Re-locate the target in the freshly parsed function list; name plus
        # start line together disambiguate same-named functions in one file.
        all_functions = analyzer.find_functions(source, include_methods=True)
        target_func = None
        for func in all_functions:
            if func.name == function.name and func.start_line == function.start_line:
                target_func = func
                break
        if not target_func:
            return {}

        calls = analyzer.find_function_calls(source, target_func)
        # Match calls to imports
        call_to_import = self._match_calls_to_imports(calls, imports)

        # Find helpers from imported modules
        results: dict[Path, list[HelperFunction]] = {}
        for import_info, actual_name in call_to_import.values():
            # Resolve the import to a file path
            resolved = self.import_resolver.resolve_import(import_info, function.file_path)
            if resolved is None:
                # External package or unresolvable specifier — nothing to extract.
                continue
            # Skip if already visited
            key = (resolved.file_path, actual_name)
            if key in context.visited_functions:
                continue
            context.visited_functions.add(key)
            # Extract the helper function from the resolved file
            helper = self._extract_helper_from_file(resolved.file_path, actual_name, analyzer)
            if helper:
                if resolved.file_path not in results:
                    results[resolved.file_path] = []
                results[resolved.file_path].append(helper)
                # Recursively find helpers of this helper (if depth allows).
                # Each branch gets COPIES of the visited sets so sibling
                # branches do not suppress each other's discoveries.
                if context.current_depth < context.max_depth:
                    nested_results = self._find_helpers_recursive(
                        resolved.file_path,
                        helper,
                        HelperSearchContext(
                            visited_files=context.visited_files.copy(),
                            visited_functions=context.visited_functions.copy(),
                            current_depth=context.current_depth + 1,
                            max_depth=context.max_depth,
                        ),
                    )
                    # Merge nested results
                    for path, helpers in nested_results.items():
                        if path not in results:
                            results[path] = []
                        results[path].extend(helpers)
        return results
def _match_calls_to_imports(self, calls: set[str], imports: list[ImportInfo]) -> dict[str, tuple[ImportInfo, str]]:
"""Match function calls to their import sources.
Args:
calls: Set of function call names found in the code.
imports: List of import statements.
Returns:
Dictionary mapping call names to (ImportInfo, actual_function_name) tuples.
"""
matches: dict[str, tuple[ImportInfo, str]] = {}
for call in calls:
# Check for namespace calls (e.g., utils.helper)
if "." in call:
namespace, func_name = call.split(".", 1)
for imp in imports:
if imp.namespace_import == namespace:
matches[call] = (imp, func_name)
break
else:
# Check for direct imports
for imp in imports:
# Check default import
if imp.default_import == call:
matches[call] = (imp, "default")
break
# Check named imports
for name, alias in imp.named_imports:
if (alias and alias == call) or (not alias and name == call):
matches[call] = (imp, name)
break
return matches
def _extract_helper_from_file(
self, file_path: Path, function_name: str, analyzer: TreeSitterAnalyzer
) -> HelperFunction | None:
"""Extract a helper function from a resolved file.
Args:
file_path: Path to the file containing the function.
function_name: Name of the function to extract.
analyzer: TreeSitterAnalyzer for parsing.
Returns:
HelperFunction if found, None otherwise.
"""
from codeflash.languages.base import HelperFunction
from codeflash.languages.treesitter_utils import get_analyzer_for_file
try:
source = file_path.read_text(encoding="utf-8")
except Exception as e:
logger.warning("Failed to read %s: %s", file_path, e)
return None
# Get analyzer for this file type
file_analyzer = get_analyzer_for_file(file_path)
# Split source into lines for JSDoc extraction
lines = source.splitlines(keepends=True)
# Handle "default" export - look for default exported function
if function_name == "default":
# Find the default export
functions = file_analyzer.find_functions(source, include_methods=True)
# For now, return first function if looking for default
# TODO: Implement proper default export detection
for func in functions:
# Extract source including JSDoc if present
effective_start = func.doc_start_line or func.start_line
helper_lines = lines[effective_start - 1 : func.end_line]
helper_source = "".join(helper_lines)
return HelperFunction(
name=func.name,
qualified_name=func.name,
file_path=file_path,
source_code=helper_source,
start_line=effective_start,
end_line=func.end_line,
)
return None
# Find the function by name
functions = file_analyzer.find_functions(source, include_methods=True)
for func in functions:
if func.name == function_name:
# Extract source including JSDoc if present
effective_start = func.doc_start_line or func.start_line
helper_lines = lines[effective_start - 1 : func.end_line]
helper_source = "".join(helper_lines)
return HelperFunction(
name=func.name,
qualified_name=func.name,
file_path=file_path,
source_code=helper_source,
start_line=effective_start,
end_line=func.end_line,
)
logger.debug("Function %s not found in %s", function_name, file_path)
return None
    def _find_helpers_recursive(
        self, file_path: Path, helper: HelperFunction, context: HelperSearchContext
    ) -> dict[Path, list[HelperFunction]]:
        """Recursively find helpers of a helper function.

        Args:
            file_path: Path to the file containing the helper.
            helper: The helper function to analyze.
            context: Search context with visited tracking and depth limit.

        Returns:
            Dictionary mapping file paths to lists of helper functions.
        """
        from codeflash.languages.base import FunctionInfo
        from codeflash.languages.treesitter_utils import get_analyzer_for_file

        # Depth guard: stop when the configured recursion budget is exhausted.
        if context.current_depth >= context.max_depth:
            return {}
        # Cycle guard: never re-enter a file already on this search path.
        if file_path in context.visited_files:
            return {}
        context.visited_files.add(file_path)

        try:
            source = file_path.read_text(encoding="utf-8")
        except Exception as e:
            logger.warning("Failed to read %s: %s", file_path, e)
            return {}

        # Get analyzer and imports for this file
        analyzer = get_analyzer_for_file(file_path)
        imports = analyzer.find_imports(source)

        # Re-wrap the helper as a FunctionInfo so find_helpers can process it
        # like any other target function.
        func_info = FunctionInfo(
            name=helper.name, file_path=file_path, start_line=helper.start_line, end_line=helper.end_line, parents=()
        )

        # NOTE(review): find_helpers builds a fresh HelperSearchContext
        # internally, so this context's visited sets are not consulted by the
        # nested call; only the shrinking max_depth bounds the recursion.
        # Confirm this dedup behavior is intended.
        return self.find_helpers(
            function=func_info,
            source=source,
            analyzer=analyzer,
            imports=imports,
            max_depth=context.max_depth - context.current_depth,
        )

View file

@ -0,0 +1,974 @@
"""JavaScript test instrumentation for existing tests.
This module provides functionality to inject profiling code into existing JavaScript
test files, similar to Python's inject_profiling_into_existing_test.
"""
from __future__ import annotations
import re
from dataclasses import dataclass
from pathlib import Path
from typing import TYPE_CHECKING
from codeflash.cli_cmds.console import logger
if TYPE_CHECKING:
from codeflash.code_utils.code_position import CodePosition
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
class TestingMode:
    """Testing mode constants."""

    # Verify behavior by capturing return values of instrumented calls.
    BEHAVIOR = "behavior"
    # Benchmark runtime of instrumented calls.
    PERFORMANCE = "performance"
@dataclass
class ExpectCallMatch:
    """Represents a matched expect(func(...)).toXXX() call."""

    start_pos: int  # Index in the source where the match begins
    end_pos: int  # Index just past the end of the full statement
    leading_whitespace: str  # Indentation captured before the expect call
    func_args: str  # Raw argument text passed to the tested function
    assertion_chain: str  # Assertion suffix, e.g. ".not.resolves.toBe(5)"
    has_trailing_semicolon: bool  # Whether the statement ended with ';'
    object_prefix: str = ""  # Object prefix like "calc." or "this." or ""
@dataclass
class StandaloneCallMatch:
    """Represents a matched standalone func(...) call."""

    start_pos: int  # Index in the source where the match begins
    end_pos: int  # Index just past the end of the full statement
    leading_whitespace: str  # Indentation captured before the call
    func_args: str  # Raw argument text passed to the function
    prefix: str  # "await " or ""
    object_prefix: str  # Object prefix like "calc." or "this." or ""
    has_trailing_semicolon: bool  # Whether the statement ended with ';'
class StandaloneCallTransformer:
    """Transforms standalone func(...) calls in JavaScript test code.

    This class handles the transformation of standalone function calls that are NOT
    inside expect() wrappers. These calls need to be wrapped with codeflash.capture()
    or codeflash.capturePerf() for instrumentation.

    Examples:
        - await func(args) -> await codeflash.capturePerf('name', 'id', func, args)
        - func(args) -> codeflash.capturePerf('name', 'id', func, args)
        - const result = func(args) -> const result = codeflash.capturePerf(...)
        - arr.map(() => func(args)) -> arr.map(() => codeflash.capturePerf(..., func, args))
        - calc.fibonacci(n) -> codeflash.capturePerf('...', 'id', calc.fibonacci.bind(calc), n)
    """

    def __init__(self, func_name: str, qualified_name: str, capture_func: str) -> None:
        """Initialize the transformer.

        Args:
            func_name: Bare name of the function being instrumented.
            qualified_name: Fully qualified name recorded in the capture call.
            capture_func: Codeflash helper to call ('capture' or 'capturePerf').
        """
        self.func_name = func_name
        self.qualified_name = qualified_name
        self.capture_func = capture_func
        # Monotonic counter giving each instrumented call site a unique id.
        self.invocation_counter = 0
        # Pattern to match func_name( with optional leading await and optional object prefix
        # Captures: (whitespace)(await )?(object.)*func_name(
        # We'll filter out expect() and codeflash. cases in the transform loop
        self._call_pattern = re.compile(rf"(\s*)(await\s+)?((?:\w+\.)*){re.escape(func_name)}\s*\(")
    def transform(self, code: str) -> str:
        """Transform all standalone calls in the code.

        Scans left to right; each candidate match is either skipped
        (definition sites, calls inside expect(), already-instrumented calls)
        or replaced with a codeflash capture wrapper.
        """
        result: list[str] = []
        pos = 0
        while pos < len(code):
            match = self._call_pattern.search(code, pos)
            if not match:
                # No more candidates; keep the remaining tail verbatim.
                result.append(code[pos:])
                break
            match_start = match.start()
            # Check if this call is inside an expect() or already transformed
            if self._should_skip_match(code, match_start, match):
                result.append(code[pos : match.end()])
                pos = match.end()
                continue
            # Add everything before the match
            result.append(code[pos:match_start])
            # Try to parse the full standalone call
            standalone_match = self._parse_standalone_call(code, match)
            if standalone_match is None:
                # Couldn't parse, skip this match
                result.append(code[match_start : match.end()])
                pos = match.end()
                continue
            # Generate the transformed code
            self.invocation_counter += 1
            transformed = self._generate_transformed_call(standalone_match)
            result.append(transformed)
            pos = standalone_match.end_pos
        return "".join(result)
    def _should_skip_match(self, code: str, start: int, match: re.Match) -> bool:
        """Check if the match should be skipped (inside expect, already transformed, etc.)."""
        # Look backwards to check context
        lookback_start = max(0, start - 200)
        lookback = code[lookback_start:start]

        # Skip if already transformed with codeflash.capture
        if f"codeflash.{self.capture_func}(" in lookback[-60:]:
            return True

        # Skip if this is a function/method definition, not a call
        # Patterns to skip:
        # - ClassName.prototype.funcName = function(
        # - funcName = function(
        # - funcName: function(
        # - function funcName(
        # - funcName() { (method definition in class)
        near_context = lookback[-80:] if len(lookback) >= 80 else lookback

        # Skip prototype assignment: ClassName.prototype.funcName = function(
        if re.search(r"\.prototype\.\w+\s*=\s*function\s*$", near_context):
            return True
        # Skip function assignment: funcName = function(
        if re.search(rf"{re.escape(self.func_name)}\s*=\s*function\s*$", near_context):
            return True
        # Skip function declaration: function funcName(
        if re.search(rf"function\s+{re.escape(self.func_name)}\s*$", near_context):
            return True

        # Skip method definition in class body: funcName(params) { or async funcName(params) {
        # Check by looking at what comes after the closing paren
        # The match ends at the opening paren, so find the closing paren and check what follows
        close_paren_pos = self._find_matching_paren(code, match.end() - 1)
        if close_paren_pos != -1:
            # Check if followed by { (method definition) after optional whitespace
            after_close = code[close_paren_pos : close_paren_pos + 20].lstrip()
            if after_close.startswith("{"):
                # This is a method definition like "fibonacci(n) {"
                # But we still want to capture certain patterns like arrow functions
                # Check if there's no => before the {
                between = code[close_paren_pos : close_paren_pos + 20].strip()
                if not between.startswith("=>"):
                    return True

        # Skip if inside expect() - look for 'expect(' with unmatched parens
        # Find the last 'expect(' and check if it's still open
        # NOTE(review): the fixed 100-char window can miss expect() wrappers
        # with very long argument lists — confirm the window is sufficient.
        expect_search_start = max(0, start - 100)
        expect_lookback = code[expect_search_start:start]
        # Find all expect( positions
        expect_pos = expect_lookback.rfind("expect(")
        if expect_pos != -1:
            # Count parens from expect( to our match position
            between = expect_lookback[expect_pos:]
            open_parens = between.count("(") - between.count(")")
            if open_parens > 0:
                # We're inside an unclosed expect()
                return True
        return False
def _find_matching_paren(self, code: str, open_paren_pos: int) -> int:
"""Find the position of the closing paren for the given opening paren."""
if open_paren_pos >= len(code) or code[open_paren_pos] != "(":
return -1
depth = 1
pos = open_paren_pos + 1
while pos < len(code) and depth > 0:
if code[pos] == "(":
depth += 1
elif code[pos] == ")":
depth -= 1
pos += 1
return pos if depth == 0 else -1
    def _parse_standalone_call(self, code: str, match: re.Match) -> StandaloneCallMatch | None:
        """Parse a complete standalone func(...) call.

        Returns:
            A StandaloneCallMatch describing the full call span, or None when
            the match is not a call this transformer should rewrite.
        """
        leading_ws = match.group(1)
        prefix = match.group(2) or ""  # "await " or ""
        object_prefix = match.group(3) or ""  # Object prefix like "calc." or ""

        # If qualified_name is a standalone function (no dot), don't match method calls
        # e.g., if qualified_name="func", don't match "obj.func()" - only match "func()"
        if "." not in self.qualified_name and object_prefix:
            return None

        # Find the opening paren position (the regex match ends just past it).
        match_text = match.group(0)
        paren_offset = match_text.rfind("(")
        open_paren_pos = match.start() + paren_offset

        # Find the arguments (content inside parens)
        func_args, close_pos = self._find_balanced_parens(code, open_paren_pos)
        if func_args is None:
            return None

        # Check for trailing semicolon
        end_pos = close_pos
        # Skip spaces/tabs only — the semicolon must be on the same line.
        while end_pos < len(code) and code[end_pos] in " \t":
            end_pos += 1
        has_trailing_semicolon = end_pos < len(code) and code[end_pos] == ";"
        if has_trailing_semicolon:
            end_pos += 1

        return StandaloneCallMatch(
            start_pos=match.start(),
            end_pos=end_pos,
            leading_whitespace=leading_ws,
            func_args=func_args,
            prefix=prefix,
            object_prefix=object_prefix,
            has_trailing_semicolon=has_trailing_semicolon,
        )
def _find_balanced_parens(self, code: str, open_paren_pos: int) -> tuple[str | None, int]:
"""Find content within balanced parentheses."""
if open_paren_pos >= len(code) or code[open_paren_pos] != "(":
return None, -1
depth = 1
pos = open_paren_pos + 1
in_string = False
string_char = None
while pos < len(code) and depth > 0:
char = code[pos]
# Handle string literals
if char in "\"'`" and (pos == 0 or code[pos - 1] != "\\"):
if not in_string:
in_string = True
string_char = char
elif char == string_char:
in_string = False
string_char = None
elif not in_string:
if char == "(":
depth += 1
elif char == ")":
depth -= 1
pos += 1
if depth != 0:
return None, -1
return code[open_paren_pos + 1 : pos - 1], pos
def _generate_transformed_call(self, match: StandaloneCallMatch) -> str:
"""Generate the transformed code for a standalone call."""
line_id = str(self.invocation_counter)
args_str = match.func_args.strip()
semicolon = ";" if match.has_trailing_semicolon else ""
# Handle method calls on objects (e.g., calc.fibonacci, this.method)
if match.object_prefix:
# Remove trailing dot from object prefix for the bind call
obj = match.object_prefix.rstrip(".")
full_method = f"{obj}.{self.func_name}"
if args_str:
return (
f"{match.leading_whitespace}{match.prefix}codeflash.{self.capture_func}('{self.qualified_name}', "
f"'{line_id}', {full_method}.bind({obj}), {args_str}){semicolon}"
)
return (
f"{match.leading_whitespace}{match.prefix}codeflash.{self.capture_func}('{self.qualified_name}', "
f"'{line_id}', {full_method}.bind({obj})){semicolon}"
)
# Handle standalone function calls
if args_str:
return (
f"{match.leading_whitespace}{match.prefix}codeflash.{self.capture_func}('{self.qualified_name}', "
f"'{line_id}', {self.func_name}, {args_str}){semicolon}"
)
return (
f"{match.leading_whitespace}{match.prefix}codeflash.{self.capture_func}('{self.qualified_name}', "
f"'{line_id}', {self.func_name}){semicolon}"
)
def transform_standalone_calls(
    code: str, func_name: str, qualified_name: str, capture_func: str, start_counter: int = 0
) -> tuple[str, int]:
    """Transform standalone func(...) calls in JavaScript test code.

    This transforms function calls that are NOT inside expect() wrappers.

    Args:
        code: The test code to transform.
        func_name: Name of the function being tested.
        qualified_name: Fully qualified function name.
        capture_func: The capture function to use ('capture' or 'capturePerf').
        start_counter: Starting value for the invocation counter (lets call-site
            ids continue from a prior transformation pass).

    Returns:
        Tuple of (transformed code, final counter value).
    """
    standalone_transformer = StandaloneCallTransformer(
        func_name=func_name, qualified_name=qualified_name, capture_func=capture_func
    )
    # Seed the counter so ids stay unique across transformation passes.
    standalone_transformer.invocation_counter = start_counter
    transformed_code = standalone_transformer.transform(code)
    return transformed_code, standalone_transformer.invocation_counter
class ExpectCallTransformer:
    """Transforms expect(func(...)).assertion() calls in JavaScript test code.

    This class handles the parsing and transformation of Jest/Vitest expect calls,
    supporting various assertion patterns including:
    - Basic: expect(func(args)).toBe(value)
    - Negated: expect(func(args)).not.toBe(value)
    - Async: expect(func(args)).resolves.toBe(value)
    - Chained: expect(func(args)).not.resolves.toBe(value)
    - No-arg assertions: expect(func(args)).toBeTruthy()
    - Multi-arg assertions: expect(func(args)).toBeCloseTo(0.5, 2)
    """

    def __init__(self, func_name: str, qualified_name: str, capture_func: str, remove_assertions: bool = False) -> None:
        """Initialize the transformer.

        Args:
            func_name: Bare name of the function being instrumented.
            qualified_name: Fully qualified name recorded in the capture call.
            capture_func: Codeflash helper to call ('capture' or 'capturePerf').
            remove_assertions: When True, strip the expect wrapper entirely
                (used for generated/regression tests).
        """
        self.func_name = func_name
        self.qualified_name = qualified_name
        self.capture_func = capture_func
        self.remove_assertions = remove_assertions
        # Monotonic counter giving each instrumented call site a unique id.
        self.invocation_counter = 0
        # Pattern to match start of expect((object.)*func_name(
        # Captures: (whitespace), (object prefix like calc. or this.)
        self._expect_pattern = re.compile(rf"(\s*)expect\s*\(\s*((?:\w+\.)*){re.escape(func_name)}\s*\(")
    def transform(self, code: str) -> str:
        """Transform all expect calls in the code.

        Scans left to right, replacing each parseable expect(func(...)) call
        with an instrumented version; unparseable candidates are left intact.
        """
        result: list[str] = []
        pos = 0
        while pos < len(code):
            match = self._expect_pattern.search(code, pos)
            if not match:
                # No more candidates; keep the remaining tail verbatim.
                result.append(code[pos:])
                break
            # Add everything before the match
            result.append(code[pos : match.start()])
            # Try to parse the full expect call
            expect_match = self._parse_expect_call(code, match)
            if expect_match is None:
                # Couldn't parse, skip this match
                result.append(code[match.start() : match.end()])
                pos = match.end()
                continue
            # Generate the transformed code
            self.invocation_counter += 1
            transformed = self._generate_transformed_call(expect_match)
            result.append(transformed)
            pos = expect_match.end_pos
        return "".join(result)
    def _parse_expect_call(self, code: str, match: re.Match) -> ExpectCallMatch | None:
        """Parse a complete expect(func(...)).assertion() call.

        Returns None if the pattern doesn't match expected structure.
        """
        leading_ws = match.group(1)
        object_prefix = match.group(2) or ""  # Object prefix like "calc." or ""

        # If qualified_name is a standalone function (no dot), don't match method calls
        # e.g., if qualified_name="func", don't match "obj.func()" - only match "func()"
        if "." not in self.qualified_name and object_prefix:
            return None

        # Find the arguments of the function call (handling nested parens).
        # The regex match ends just past the function's '(' — back up one char.
        args_start = match.end()
        func_args, func_close_pos = self._find_balanced_parens(code, args_start - 1)
        if func_args is None:
            return None

        # Skip whitespace and find closing ) of expect(
        expect_close_pos = func_close_pos
        while expect_close_pos < len(code) and code[expect_close_pos].isspace():
            expect_close_pos += 1
        if expect_close_pos >= len(code) or code[expect_close_pos] != ")":
            return None
        expect_close_pos += 1  # Move past )

        # Parse the assertion chain (e.g., .not.resolves.toBe(value))
        assertion_chain, chain_end_pos = self._parse_assertion_chain(code, expect_close_pos)
        if assertion_chain is None:
            return None

        # Check for trailing semicolon
        has_trailing_semicolon = chain_end_pos < len(code) and code[chain_end_pos] == ";"
        if has_trailing_semicolon:
            chain_end_pos += 1

        return ExpectCallMatch(
            start_pos=match.start(),
            end_pos=chain_end_pos,
            leading_whitespace=leading_ws,
            func_args=func_args,
            assertion_chain=assertion_chain,
            has_trailing_semicolon=has_trailing_semicolon,
            object_prefix=object_prefix,
        )
def _find_balanced_parens(self, code: str, open_paren_pos: int) -> tuple[str | None, int]:
"""Find content within balanced parentheses.
Args:
code: The source code
open_paren_pos: Position of the opening parenthesis
Returns:
Tuple of (content inside parens, position after closing paren) or (None, -1)
"""
if open_paren_pos >= len(code) or code[open_paren_pos] != "(":
return None, -1
depth = 1
pos = open_paren_pos + 1
in_string = False
string_char = None
while pos < len(code) and depth > 0:
char = code[pos]
# Handle string literals
if char in "\"'`" and (pos == 0 or code[pos - 1] != "\\"):
if not in_string:
in_string = True
string_char = char
elif char == string_char:
in_string = False
string_char = None
elif not in_string:
if char == "(":
depth += 1
elif char == ")":
depth -= 1
pos += 1
if depth != 0:
return None, -1
# Return content (excluding parens) and position after closing paren
return code[open_paren_pos + 1 : pos - 1], pos
    def _parse_assertion_chain(self, code: str, start_pos: int) -> tuple[str | None, int]:
        """Parse assertion chain like .not.resolves.toBe(value).

        Handles:
        - .toBe(value)
        - .not.toBe(value)
        - .resolves.toBe(value)
        - .rejects.toThrow()
        - .not.resolves.toBe(value)
        - .toBeTruthy() (no args)
        - .toBeCloseTo(0.5, 2) (multiple args with nested parens)

        Returns:
            Tuple of (assertion chain string, end position) or (None, -1)
        """
        pos = start_pos
        chain_parts: list[str] = []

        # Skip any leading whitespace (for multi-line)
        while pos < len(code) and code[pos] in " \t\n\r":
            pos += 1
        # Must start with a dot
        if pos >= len(code) or code[pos] != ".":
            return None, -1

        while pos < len(code):
            # Skip whitespace between chain elements
            while pos < len(code) and code[pos] in " \t\n\r":
                pos += 1
            if pos >= len(code) or code[pos] != ".":
                # Chain ended without a terminal assertion; validated below.
                break
            pos += 1  # Skip the dot
            # Skip whitespace after dot
            while pos < len(code) and code[pos] in " \t\n\r":
                pos += 1

            # Parse the method name (identifier characters only).
            method_start = pos
            while pos < len(code) and (code[pos].isalnum() or code[pos] == "_"):
                pos += 1
            if pos == method_start:
                # A dot with no identifier after it — malformed chain.
                return None, -1
            method_name = code[method_start:pos]

            # Skip whitespace before potential parens
            while pos < len(code) and code[pos] in " \t\n\r":
                pos += 1

            # Check for parentheses (method call)
            if pos < len(code) and code[pos] == "(":
                args_content, after_paren = self._find_balanced_parens(code, pos)
                if args_content is None:
                    return None, -1
                chain_parts.append(f".{method_name}({args_content})")
                pos = after_paren
            else:
                # Method without parens (like .not, .resolves, .rejects)
                # Or assertion without args like .toBeTruthy
                chain_parts.append(f".{method_name}")

            # If this is a terminal assertion (starts with 'to'), we're done
            if method_name.startswith("to"):
                break

        if not chain_parts:
            return None, -1
        # Verify we have a terminal assertion (should end with .toXXX)
        last_part = chain_parts[-1]
        if not last_part.startswith(".to"):
            return None, -1
        return "".join(chain_parts), pos
def _generate_transformed_call(self, match: ExpectCallMatch) -> str:
"""Generate the transformed code for an expect call."""
line_id = str(self.invocation_counter)
args_str = match.func_args.strip()
# Determine the function reference to use
if match.object_prefix:
# Method call on object: calc.fibonacci -> calc.fibonacci.bind(calc)
obj = match.object_prefix.rstrip(".")
func_ref = f"{obj}.{self.func_name}.bind({obj})"
else:
func_ref = self.func_name
if self.remove_assertions:
# For generated/regression tests: remove expect wrapper and assertion
if args_str:
return (
f"{match.leading_whitespace}codeflash.{self.capture_func}('{self.qualified_name}', "
f"'{line_id}', {func_ref}, {args_str});"
)
return (
f"{match.leading_whitespace}codeflash.{self.capture_func}('{self.qualified_name}', "
f"'{line_id}', {func_ref});"
)
# For existing tests: keep the expect wrapper
semicolon = ";" if match.has_trailing_semicolon else ""
if args_str:
return (
f"{match.leading_whitespace}expect(codeflash.{self.capture_func}('{self.qualified_name}', "
f"'{line_id}', {func_ref}, {args_str})){match.assertion_chain}{semicolon}"
)
return (
f"{match.leading_whitespace}expect(codeflash.{self.capture_func}('{self.qualified_name}', "
f"'{line_id}', {func_ref})){match.assertion_chain}{semicolon}"
)
def transform_expect_calls(
    code: str, func_name: str, qualified_name: str, capture_func: str, remove_assertions: bool = False
) -> tuple[str, int]:
    """Transform expect(func(...)).assertion() calls in JavaScript test code.

    This is the main entry point for expect call transformation.

    Args:
        code: The test code to transform.
        func_name: Name of the function being tested.
        qualified_name: Fully qualified function name.
        capture_func: The capture function to use ('capture' or 'capturePerf').
        remove_assertions: If True, remove assertions entirely (for generated tests).

    Returns:
        Tuple of (transformed code, final invocation counter value).
    """
    expect_transformer = ExpectCallTransformer(
        func_name=func_name,
        qualified_name=qualified_name,
        capture_func=capture_func,
        remove_assertions=remove_assertions,
    )
    transformed_code = expect_transformer.transform(code)
    # The final counter value lets the caller continue numbering call sites.
    return transformed_code, expect_transformer.invocation_counter
def inject_profiling_into_existing_js_test(
    test_path: Path,
    call_positions: list[CodePosition],
    function_to_optimize: FunctionToOptimize,
    tests_project_root: Path,
    mode: str = TestingMode.BEHAVIOR,
) -> tuple[bool, str | None]:
    """Inject profiling code into an existing JavaScript test file.

    This function wraps function calls with codeflash.capture() or codeflash.capturePerf()
    to enable behavioral verification and performance benchmarking.

    Args:
        test_path: Path to the test file.
        call_positions: List of code positions where the function is called.
            NOTE(review): currently unused — instrumentation is driven by
            regex matching over the whole file; confirm whether positions
            should narrow which call sites are wrapped.
        function_to_optimize: The function being optimized.
        tests_project_root: Root directory of tests.
        mode: Testing mode - "behavior" or "performance".

    Returns:
        Tuple of (success, instrumented_code).
    """
    try:
        with test_path.open(encoding="utf8") as f:
            test_code = f.read()
    except Exception as e:
        logger.error(f"Failed to read test file {test_path}: {e}")
        return False, None

    func_name = function_to_optimize.function_name

    # Get the relative path for test identification
    try:
        rel_path = test_path.relative_to(tests_project_root)
    except ValueError:
        # Test file lives outside the tests root; fall back to the given path.
        rel_path = test_path

    # Check if the function is imported/required in this test file
    if not _is_function_used_in_test(test_code, func_name):
        logger.debug(f"Function '{func_name}' not found in test file {test_path}")
        return False, None

    # Instrument the test code
    instrumented_code = _instrument_js_test_code(
        test_code, func_name, str(rel_path), mode, function_to_optimize.qualified_name
    )

    # An unchanged file means no call sites matched — report failure so the
    # caller does not run an uninstrumented test.
    if instrumented_code == test_code:
        logger.debug(f"No changes made to test file {test_path}")
        return False, None

    return True, instrumented_code
def _is_function_used_in_test(code: str, func_name: str) -> bool:
"""Check if a function is imported or used in the test code.
This function handles both standalone functions and class methods.
For class methods, it checks if the method is called on any object
(e.g., calc.fibonacci, this.fibonacci).
"""
# Check for CommonJS require with named export
require_pattern = rf"(?:const|let|var)\s+\{{\s*[^}}]*\b{re.escape(func_name)}\b[^}}]*\}}\s*=\s*require\s*\("
if re.search(require_pattern, code):
return True
# Check for ES6 import with named export
import_pattern = rf"import\s+\{{\s*[^}}]*\b{re.escape(func_name)}\b[^}}]*\}}\s+from"
if re.search(import_pattern, code):
return True
# Check for default import (import func from or const func = require())
default_require = rf"(?:const|let|var)\s+{re.escape(func_name)}\s*=\s*require\s*\("
if re.search(default_require, code):
return True
default_import = rf"import\s+{re.escape(func_name)}\s+from"
if re.search(default_import, code):
return True
# Check for method calls: obj.funcName( or this.funcName(
# This handles class methods called on instances
method_call_pattern = rf"\w+\.{re.escape(func_name)}\s*\("
return bool(re.search(method_call_pattern, code))
def _instrument_js_test_code(
    code: str, func_name: str, test_file_path: str, mode: str, qualified_name: str, remove_assertions: bool = False
) -> str:
    """Instrument JavaScript test code with profiling capture calls.

    Args:
        code: Original test code.
        func_name: Name of the function to instrument.
        test_file_path: Relative path to test file.
        mode: Testing mode (behavior or performance).
        qualified_name: Fully qualified function name.
        remove_assertions: If True, remove expect assertions entirely (for generated/regression tests).
            If False, keep the expect wrapper (for existing user-written tests).

    Returns:
        Instrumented code.
    """
    # Add codeflash helper import if not already present.
    # Support both npm package (codeflash) and legacy local file (codeflash-jest-helper).
    # Match an actual require()/import of the helper instead of any occurrence
    # of the substring "codeflash" (which could appear in a comment, a file
    # path, or an identifier and would wrongly suppress the import).
    has_codeflash_import = bool(
        re.search(r"require\s*\(\s*['\"]codeflash|import\s+[^;\n]*?\bfrom\s+['\"]codeflash", code)
    )
    if not has_codeflash_import:
        # Detect module system: ESM uses "import ... from", CommonJS uses "require()"
        is_esm = bool(re.search(r"^\s*import\s+.+\s+from\s+['\"]", code, re.MULTILINE))
        if is_esm:
            # ESM: Use import statement at the top of the file (after any other imports)
            helper_import = "import codeflash from 'codeflash';\n"
            # Find the last import statement to add after
            import_matches = list(re.finditer(r"^import\s+.+\s+from\s+['\"][^'\"]+['\"]\s*;?\s*\n", code, re.MULTILINE))
            if import_matches:
                # Add after the last import
                last_import = import_matches[-1]
                insert_pos = last_import.end()
                code = code[:insert_pos] + helper_import + code[insert_pos:]
            else:
                # No imports found, add at beginning
                code = helper_import + "\n" + code
        else:
            # CommonJS: Use require statement
            helper_require = "const codeflash = require('codeflash');\n"
            # Find the first require statement to add after
            import_match = re.search(r"^((?:const|let|var)\s+.+?require\([^)]+\).*;?\s*\n)", code, re.MULTILINE)
            if import_match:
                insert_pos = import_match.end()
                code = code[:insert_pos] + helper_require + code[insert_pos:]
            else:
                # Add at the beginning if no requires found
                code = helper_require + "\n" + code

    # Choose capture function based on mode
    capture_func = "capturePerf" if mode == TestingMode.PERFORMANCE else "capture"

    # Transform expect calls first so the standalone pass does not touch them.
    code, expect_counter = transform_expect_calls(
        code=code,
        func_name=func_name,
        qualified_name=qualified_name,
        capture_func=capture_func,
        remove_assertions=remove_assertions,
    )

    # Transform standalone calls (not inside expect wrappers).
    # Continue counter from expect transformer to ensure unique IDs.
    code, _final_counter = transform_standalone_calls(
        code=code,
        func_name=func_name,
        qualified_name=qualified_name,
        capture_func=capture_func,
        start_counter=expect_counter,
    )
    return code
def validate_and_fix_import_style(test_code: str, source_file_path: Path, function_name: str) -> str:
    """Validate and fix import style in generated test code to match source export.

    The AI may generate tests with incorrect import styles (e.g., using a named
    import for a default export). This function detects such mismatches and
    rewrites the offending import/require statements in place.

    Args:
        test_code: The generated test code.
        source_file_path: Path to the source file being tested.
        function_name: Name of the function being tested.

    Returns:
        Fixed test code with correct import style.

    """
    from codeflash.languages.treesitter_utils import get_analyzer_for_file

    # Read the source file to determine export style; on any failure, leave the
    # generated test untouched rather than guessing.
    try:
        source_code = source_file_path.read_text(encoding="utf-8")
    except Exception as e:
        logger.warning(f"Could not read source file {source_file_path}: {e}")
        return test_code
    # Get analyzer for the source file
    try:
        analyzer = get_analyzer_for_file(source_file_path)
        exports = analyzer.find_exports(source_code)
    except Exception as e:
        logger.warning(f"Could not analyze exports in {source_file_path}: {e}")
        return test_code
    if not exports:
        return test_code

    # Determine how the function is exported.
    is_default_export = False
    is_named_export = False
    for export in exports:
        if export.default_export == function_name:
            is_default_export = True
            break
        for name, _alias in export.exported_names:
            if name == function_name:
                is_named_export = True
                break
        if is_named_export:
            break
    if not is_default_export and not is_named_export:
        # Check if it might be an anonymous default export.
        for export in exports:
            if export.default_export == "default":
                is_default_export = True
                break
    # If we can't determine the export style, don't modify anything.
    if not is_default_export and not is_named_export:
        return test_code

    # Module specifiers that plausibly refer to the source file.
    source_name = source_file_path.stem
    source_patterns = [source_name, f"./{source_name}", f"../{source_name}", source_file_path.as_posix()]

    # Pattern for named import: const { funcName } = require(...) or import { funcName } from ...
    named_require_pattern = re.compile(
        rf"(const|let|var)\s+\{{\s*{re.escape(function_name)}\s*\}}\s*=\s*require\s*\(\s*['\"]([^'\"]+)['\"]\s*\)"
    )
    named_import_pattern = re.compile(rf"import\s+\{{\s*{re.escape(function_name)}\s*\}}\s+from\s+['\"]([^'\"]+)['\"]")
    # Pattern for default import: const funcName = require(...) or import funcName from ...
    default_require_pattern = re.compile(
        rf"(const|let|var)\s+{re.escape(function_name)}\s*=\s*require\s*\(\s*['\"]([^'\"]+)['\"]\s*\)"
    )
    default_import_pattern = re.compile(rf"import\s+{re.escape(function_name)}\s+from\s+['\"]([^'\"]+)['\"]")

    def is_relevant_import(module_path: str) -> bool:
        """Check whether the module path plausibly refers to our source file.

        NOTE(review): substring matching is deliberately loose (a pattern like
        "utils" also matches "test_utils"). A false positive only rewrites an
        import of the same function name, so this is acceptable here.
        """
        # Previously the equality check was embedded inside the any() generator,
        # where it did not depend on the loop variable; hoist it out for clarity.
        if Path(module_path).stem == source_name:
            return True
        return any(p in module_path for p in source_patterns)

    # Check for mismatch and fix. str.replace substitutes every occurrence of
    # the matched import text, which is what we want for duplicated imports.
    if is_default_export:
        # Function is default exported, but test uses a named import - fix it.
        for match in named_require_pattern.finditer(test_code):
            module_path = match.group(2)
            if is_relevant_import(module_path):
                logger.debug(f"Fixing named require to default for {function_name} from {module_path}")
                old_import = match.group(0)
                new_import = f"{match.group(1)} {function_name} = require('{module_path}')"
                test_code = test_code.replace(old_import, new_import)
        for match in named_import_pattern.finditer(test_code):
            module_path = match.group(1)
            if is_relevant_import(module_path):
                logger.debug(f"Fixing named import to default for {function_name} from {module_path}")
                old_import = match.group(0)
                new_import = f"import {function_name} from '{module_path}'"
                test_code = test_code.replace(old_import, new_import)
    elif is_named_export:
        # Function is named exported, but test uses a default import - fix it.
        for match in default_require_pattern.finditer(test_code):
            module_path = match.group(2)
            if is_relevant_import(module_path):
                logger.debug(f"Fixing default require to named for {function_name} from {module_path}")
                old_import = match.group(0)
                new_import = f"{match.group(1)} {{ {function_name} }} = require('{module_path}')"
                test_code = test_code.replace(old_import, new_import)
        for match in default_import_pattern.finditer(test_code):
            module_path = match.group(1)
            if is_relevant_import(module_path):
                logger.debug(f"Fixing default import to named for {function_name} from {module_path}")
                old_import = match.group(0)
                new_import = f"import {{ {function_name} }} from '{module_path}'"
                test_code = test_code.replace(old_import, new_import)
    return test_code
def get_instrumented_test_path(original_path: Path, mode: str) -> Path:
    """Generate the path for an instrumented test file.

    Inserts a codeflash marker before the ``.test`` / ``.spec`` part of the
    stem when present (``a.test.js`` -> ``a_codeflash_behavior.test.js``),
    otherwise appends it to the stem.

    Args:
        original_path: Original test file path.
        mode: Testing mode (behavior or performance).

    Returns:
        Path for the instrumented file.
    """
    marker = "_codeflash_behavior" if mode == TestingMode.BEHAVIOR else "_codeflash_perf"
    stem = original_path.stem
    # Default: no .test/.spec segment, so just append the marker.
    new_stem = f"{stem}{marker}"
    for segment in (".test", ".spec"):
        if segment in stem:
            # Split at the LAST occurrence, mirroring rsplit(segment, 1).
            head, _sep, _tail = stem.rpartition(segment)
            new_stem = f"{head}{marker}{segment}"
            break
    return original_path.parent / f"{new_stem}{original_path.suffix}"
def instrument_generated_js_test(
    test_code: str, function_name: str, qualified_name: str, mode: str = TestingMode.BEHAVIOR
) -> str:
    """Instrument generated JavaScript/TypeScript test code.

    Used for tests generated by the aiservice. Unlike
    inject_profiling_into_existing_js_test, this takes the test code as a
    string rather than reading it from a file.

    Generated tests have their expect() assertions removed entirely because:
    1. LLM-generated expected values may be incorrect
    2. These are treated as regression tests where correctness is verified
       by comparing outputs between original and optimized code

    Args:
        test_code: The generated test code to instrument.
        function_name: Name of the function being tested.
        qualified_name: Fully qualified function name (e.g., 'module.funcName').
        mode: Testing mode - "behavior" or "performance".

    Returns:
        Instrumented test code with assertions removed.
    """
    # Empty or whitespace-only input is returned unchanged.
    stripped = test_code.strip() if test_code else ""
    if not stripped:
        return test_code
    # Delegate to the shared instrumentation routine with assertion removal
    # enabled, since generated tests are treated as regression tests.
    return _instrument_js_test_code(
        code=test_code,
        func_name=function_name,
        test_file_path="generated_test",
        mode=mode,
        qualified_name=qualified_name,
        remove_assertions=True,
    )

View file

@ -0,0 +1,333 @@
"""Line profiler instrumentation for JavaScript.
This module provides functionality to instrument JavaScript code with line-level
profiling similar to Python's line_profiler. It tracks execution counts and timing
for each line in instrumented functions.
"""
from __future__ import annotations
import json
import logging
from typing import TYPE_CHECKING
from codeflash.languages.treesitter_utils import get_analyzer_for_file
if TYPE_CHECKING:
from pathlib import Path
from codeflash.languages.base import FunctionInfo
logger = logging.getLogger(__name__)
class JavaScriptLineProfiler:
    """Instruments JavaScript code for line-level profiling.

    This class adds profiling code to JavaScript functions to track:
    - How many times each line executes
    - How much time is spent on each line
    - Total execution time per function

    A global profiler object (named by ``profiler_var``) is injected at the top
    of the instrumented file; ``hit()`` calls inserted before executable lines
    feed it, and results are written as JSON to ``output_file``.
    """

    def __init__(self, output_file: Path) -> None:
        """Initialize the line profiler.

        Args:
            output_file: Path where profiling results will be written.
        """
        self.output_file = output_file
        # Name of the global profiler object injected into the instrumented JS.
        self.profiler_var = "__codeflash_line_profiler__"

    def instrument_source(self, source: str, file_path: Path, functions: list[FunctionInfo]) -> str:
        """Instrument JavaScript source code with line profiling.

        Adds profiling instrumentation to track line-level execution for the
        specified functions.

        Args:
            source: Original JavaScript source code.
            file_path: Path to the source file.
            functions: List of functions to instrument.

        Returns:
            Instrumented source code with profiling.
        """
        if not functions:
            return source
        # Initialize line contents map to collect source content during instrumentation.
        # Keys are "<posix path>:<line>" using the ORIGINAL file's line numbers.
        self.line_contents: dict[str, str] = {}
        # Add instrumentation to each function
        lines = source.splitlines(keepends=True)
        # Process functions in reverse order to preserve line numbers:
        # splicing bottom-up keeps earlier functions' start/end indices valid.
        for func in sorted(functions, key=lambda f: f.start_line, reverse=True):
            func_lines = self._instrument_function(func, lines, file_path)
            start_idx = func.start_line - 1
            end_idx = func.end_line
            lines = lines[:start_idx] + func_lines + lines[end_idx:]
        instrumented_source = "".join(lines)
        # Add profiler initialization at the top (after collecting line contents,
        # which _generate_profiler_init embeds into the generated JS)
        profiler_init = self._generate_profiler_init()
        # Add profiler save at the end
        profiler_save = self._generate_profiler_save()
        return profiler_init + "\n" + instrumented_source + "\n" + profiler_save

    def _generate_profiler_init(self) -> str:
        """Generate JavaScript code for profiler initialization."""
        # Serialize line contents map for embedding in JavaScript
        line_contents_json = json.dumps(getattr(self, "line_contents", {}))
        # NOTE(review): the output path is embedded in single quotes below; a
        # path containing "'" would break the generated JS — confirm upstream
        # paths cannot contain apostrophes.
        return f"""
// Codeflash line profiler initialization
// @ts-nocheck
const {self.profiler_var} = {{
stats: {{}},
lineContents: {line_contents_json},
lastLineTime: null,
lastKey: null,
totalHits: 0,
// Called at the start of each function to reset timing state
// This prevents "between function calls" time from being attributed to the last line
enterFunction: function() {{
this.lastKey = null;
this.lastLineTime = null;
}},
hit: function(file, line) {{
const now = performance.now(); // microsecond precision
// Attribute elapsed time to the PREVIOUS line (the one that was executing)
if (this.lastKey !== null && this.lastLineTime !== null) {{
this.stats[this.lastKey].time += (now - this.lastLineTime);
}}
const key = file + ':' + line;
if (!this.stats[key]) {{
this.stats[key] = {{ hits: 0, time: 0, file: file, line: line }};
}}
this.stats[key].hits++;
// Record current line as the one now executing
this.lastKey = key;
this.lastLineTime = now;
this.totalHits++;
// Save every 100 hits to ensure we capture results even with --forceExit
if (this.totalHits % 100 === 0) {{
this.save();
}}
}},
save: function() {{
const fs = require('fs');
const pathModule = require('path');
const outputDir = pathModule.dirname('{self.output_file.as_posix()}');
try {{
if (!fs.existsSync(outputDir)) {{
fs.mkdirSync(outputDir, {{ recursive: true }});
}}
// Merge line contents into stats before saving
const statsWithContent = {{}};
for (const key of Object.keys(this.stats)) {{
statsWithContent[key] = {{
...this.stats[key],
content: this.lineContents[key] || ''
}};
}}
fs.writeFileSync(
'{self.output_file.as_posix()}',
JSON.stringify(statsWithContent, null, 2)
);
}} catch (e) {{
console.error('Failed to save line profile results:', e);
}}
}}
}};
"""

    def _generate_profiler_save(self) -> str:
        """Generate JavaScript code to save profiler results."""
        # Registers save hooks on common exit paths plus a periodic interval,
        # since Jest's --forceExit can skip graceful shutdown handlers.
        return f"""
// Save profiler results on process exit and periodically
// Use beforeExit for graceful shutdowns
process.on('beforeExit', () => {self.profiler_var}.save());
process.on('exit', () => {self.profiler_var}.save());
process.on('SIGINT', () => {{ {self.profiler_var}.save(); process.exit(); }});
process.on('SIGTERM', () => {{ {self.profiler_var}.save(); process.exit(); }});
// For Jest --forceExit compatibility, save periodically (every 500ms)
const __codeflash_save_interval__ = setInterval(() => {self.profiler_var}.save(), 500);
if (__codeflash_save_interval__.unref) __codeflash_save_interval__.unref(); // Don't keep process alive
"""

    def _instrument_function(self, func: FunctionInfo, lines: list[str], file_path: Path) -> list[str]:
        """Instrument a single function with line profiling.

        Args:
            func: Function to instrument.
            lines: Source lines.
            file_path: Path to source file.

        Returns:
            Instrumented function lines.
        """
        func_lines = lines[func.start_line - 1 : func.end_line]
        instrumented_lines = []
        # Parse the function to find executable lines
        analyzer = get_analyzer_for_file(file_path)
        source = "".join(func_lines)
        try:
            tree = analyzer.parse(source.encode("utf8"))
            executable_lines = self._find_executable_lines(tree.root_node, source.encode("utf8"))
        except Exception as e:
            # On parse failure, return the function unmodified rather than risk
            # emitting broken instrumentation.
            logger.warning("Failed to parse function %s: %s", func.name, e)
            return func_lines
        # Add profiling to each executable line
        # executable_lines contains 1-indexed line numbers within the function snippet
        function_entry_added = False
        for local_idx, line in enumerate(func_lines):
            local_line_num = local_idx + 1  # 1-indexed within function
            global_line_num = func.start_line + local_idx  # Global line number in original file
            stripped = line.strip()
            # Add enterFunction() call after the opening brace of the function.
            # NOTE(review): this fires on the FIRST line containing "{" — an
            # object literal in a default parameter could trigger it early;
            # confirm this is acceptable for the inputs seen in practice.
            if not function_entry_added and "{" in line:
                # Find indentation for the function body (use next line's indentation or default)
                body_indent = "    "  # Default 4 spaces
                if local_idx + 1 < len(func_lines):
                    next_line = func_lines[local_idx + 1]
                    if next_line.strip():
                        body_indent = " " * (len(next_line) - len(next_line.lstrip()))
                # Add the line with enterFunction() call after it
                instrumented_lines.append(line)
                instrumented_lines.append(f"{body_indent}{self.profiler_var}.enterFunction();\n")
                function_entry_added = True
                continue
            # Skip empty lines, comments, and closing braces
            if local_line_num in executable_lines and stripped and not stripped.startswith("//") and stripped != "}":
                # Get indentation
                indent = len(line) - len(line.lstrip())
                indent_str = " " * indent
                # Store line content for the profiler output
                content_key = f"{file_path.as_posix()}:{global_line_num}"
                self.line_contents[content_key] = stripped
                # Add hit() call before the line
                profiled_line = (
                    f"{indent_str}{self.profiler_var}.hit('{file_path.as_posix()}', {global_line_num});\n{line}"
                )
                instrumented_lines.append(profiled_line)
            else:
                instrumented_lines.append(line)
        return instrumented_lines

    def _find_executable_lines(self, node, source_bytes: bytes) -> set[int]:
        """Find lines that contain executable statements.

        Args:
            node: Tree-sitter AST node.
            source_bytes: Source code as bytes. (Currently unused; kept for
                interface stability.)

        Returns:
            Set of line numbers with executable statements.
        """
        executable_lines = set()
        # Node types that represent executable statements
        executable_types = {
            "expression_statement",
            "return_statement",
            "if_statement",
            "for_statement",
            "while_statement",
            "do_statement",
            "switch_statement",
            "throw_statement",
            "try_statement",
            "variable_declaration",
            "lexical_declaration",
            "assignment_expression",
            "call_expression",
            "await_expression",
        }

        def walk(n) -> None:
            # Recursively record starting lines; tree-sitter rows are 0-indexed.
            if n.type in executable_types:
                # Add the starting line (1-indexed)
                executable_lines.add(n.start_point[0] + 1)
            for child in n.children:
                walk(child)

        walk(node)
        return executable_lines

    @staticmethod
    def parse_results(profile_file: Path) -> dict:
        """Parse line profiling results from output file.

        Args:
            profile_file: Path to profiling results JSON file.

        Returns:
            Dictionary with profiling statistics.
        """
        if not profile_file.exists():
            return {"timings": {}, "unit": 1e-9, "functions": {}}
        try:
            with profile_file.open("r") as f:
                data = json.load(f)
            # Group by file and function
            timings = {}
            for key, stats in data.items():
                # Keys look like "<path>:<line>"; rsplit tolerates ':' in paths.
                file_path, line_num = key.rsplit(":", 1)
                line_num = int(line_num)
                # performance.now() returns milliseconds, convert to nanoseconds
                time_ms = float(stats["time"])
                time_ns = int(time_ms * 1e6)
                hits = stats["hits"]
                if file_path not in timings:
                    timings[file_path] = {}
                content = stats.get("content", "")
                timings[file_path][line_num] = {
                    "hits": hits,
                    "time_ns": time_ns,
                    "time_ms": time_ms,
                    "content": content,
                }
            # NOTE(review): the success path returns a "raw_data" key while
            # failure paths return "functions" — callers should not rely on
            # both keys being present.
            return {
                "timings": timings,
                "unit": 1e-9,  # nanoseconds
                "raw_data": data,
            }
        except Exception as e:
            logger.exception("Failed to parse line profile results: %s", e)
            return {"timings": {}, "unit": 1e-9, "functions": {}}

View file

@ -0,0 +1,324 @@
"""Module system detection for JavaScript/TypeScript projects.
Determines whether a project uses CommonJS (require/module.exports) or
ES Modules (import/export).
"""
from __future__ import annotations
import json
import logging
import re
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from pathlib import Path
logger = logging.getLogger(__name__)
class ModuleSystem:
    """Enum-like class for module systems.

    Plain string constants are used (rather than enum.Enum) so values can be
    compared and serialized directly.
    """

    COMMONJS = "commonjs"  # require() / module.exports
    ES_MODULE = "esm"  # import / export
    UNKNOWN = "unknown"  # could not be determined
# Module-level regexes used by convert_commonjs_to_esm to rewrite require
# statements. Order of application matters: most specific patterns first.
# Pattern for destructured require: const { a, b } = require('...')
destructured_require = re.compile(
    r"(const|let|var)\s+\{\s*([^}]+)\s*\}\s*=\s*require\s*\(\s*['\"]([^'\"]+)['\"]\s*\)\s*;?"
)
# Pattern for require with property access: const foo = require('...').propertyName
# This must come before simple_require to match first
property_access_require = re.compile(
    r"(const|let|var)\s+(\w+)\s*=\s*require\s*\(\s*['\"]([^'\"]+)['\"]\s*\)\.(\w+)\s*;?"
)
# Pattern for simple require: const foo = require('...')
simple_require = re.compile(r"(const|let|var)\s+(\w+)\s*=\s*require\s*\(\s*['\"]([^'\"]+)['\"]\s*\)\s*;?")
def detect_module_system(project_root: Path, file_path: Path | None = None) -> str:
    """Detect the module system used by a JavaScript/TypeScript project.

    Detection strategy:
    1. Check package.json for "type" field
    2. If file_path provided, check file extension (.mjs = ESM, .cjs = CommonJS)
    3. Analyze import statements in the file
    4. Default to CommonJS if uncertain

    Args:
        project_root: Root directory of the project containing package.json.
        file_path: Optional specific file to analyze.

    Returns:
        ModuleSystem constant (COMMONJS, ES_MODULE, or UNKNOWN).

    """
    # Strategy 1: the package.json "type" field is authoritative when present.
    package_json = project_root / "package.json"
    if package_json.exists():
        try:
            # Explicit UTF-8: package.json is JSON, which is UTF-8 by spec;
            # relying on the platform default encoding breaks e.g. on Windows.
            with package_json.open("r", encoding="utf-8") as f:
                pkg = json.load(f)
            pkg_type = pkg.get("type", "commonjs")
            if pkg_type == "module":
                logger.debug("Detected ES Module from package.json type field")
                return ModuleSystem.ES_MODULE
            if pkg_type == "commonjs":
                logger.debug("Detected CommonJS from package.json type field")
                return ModuleSystem.COMMONJS
        except Exception as e:
            logger.warning("Failed to parse package.json: %s", e)
    if file_path:
        # Strategy 2: .mjs / .cjs extensions unambiguously select a system.
        suffix = file_path.suffix
        if suffix == ".mjs":
            logger.debug("Detected ES Module from .mjs extension")
            return ModuleSystem.ES_MODULE
        if suffix == ".cjs":
            logger.debug("Detected CommonJS from .cjs extension")
            return ModuleSystem.COMMONJS
        # Strategy 3: sniff the file content for import/export vs require.
        if file_path.exists():
            try:
                content = file_path.read_text(encoding="utf-8")
                # ES module syntax markers. "export " also covers
                # "export default" and "export {", which both contain it.
                has_import = "import " in content and "from " in content
                has_export = "export " in content
                # CommonJS syntax markers
                has_require = "require(" in content
                has_module_exports = "module.exports" in content or "exports." in content
                # Only decide when the evidence is unambiguous.
                if (has_import or has_export) and not (has_require or has_module_exports):
                    logger.debug("Detected ES Module from import/export statements")
                    return ModuleSystem.ES_MODULE
                if (has_require or has_module_exports) and not (has_import or has_export):
                    logger.debug("Detected CommonJS from require/module.exports")
                    return ModuleSystem.COMMONJS
            except Exception as e:
                logger.warning("Failed to analyze file %s: %s", file_path, e)
    # Default to CommonJS (more common and backward compatible)
    logger.debug("Defaulting to CommonJS")
    return ModuleSystem.COMMONJS
def get_import_statement(
    module_system: str, target_path: Path, source_path: Path, imported_names: list[str] | None = None
) -> str:
    """Generate the appropriate import statement for the module system.

    Args:
        module_system: ModuleSystem constant (COMMONJS or ES_MODULE).
        target_path: Path to the module being imported.
        source_path: Path to the file doing the importing.
        imported_names: List of names to import (for named imports).

    Returns:
        Import statement string.
    """
    # Work out the relative specifier from the importing file to the target.
    rel_path = _get_relative_import_path(target_path, source_path)
    is_esm = module_system == ModuleSystem.ES_MODULE
    if imported_names:
        # Named import / destructured require.
        joined = ", ".join(imported_names)
        if is_esm:
            return f"import {{ {joined} }} from '{rel_path}';"
        return f"const {{ {joined} }} = require('{rel_path}');"
    # Whole-module import: bind under the target file's stem.
    binding = target_path.stem
    if is_esm:
        return f"import {binding} from '{rel_path}';"
    return f"const {binding} = require('{rel_path}');"
def _get_relative_import_path(target_path: Path, source_path: Path) -> str:
"""Calculate relative import path from source to target.
For JavaScript imports, we calculate the path from the source file's directory
to the target file.
Args:
target_path: Absolute path to the file being imported.
source_path: Absolute path to the file doing the importing.
Returns:
Relative import path (without file extension for .js files).
"""
# Both paths should be absolute - get the directory containing source
source_dir = source_path.parent
# Try to use os.path.relpath for accuracy
import os
rel_path_str = os.path.relpath(str(target_path), str(source_dir))
# Normalize to forward slashes
rel_path_str = rel_path_str.replace("\\", "/")
# Remove .js extension (Node.js convention)
rel_path_str = rel_path_str.removesuffix(".js")
# Ensure it starts with ./ or ../ for relative imports
if not rel_path_str.startswith("./") and not rel_path_str.startswith("../"):
rel_path_str = "./" + rel_path_str
return rel_path_str
def add_js_extension(module_path: str) -> str:
    """Append ``.js`` to relative module specifiers for ESM compatibility.

    Bare specifiers (npm packages) and paths already ending in ``.js`` or
    ``.mjs`` are returned unchanged.
    """
    if not module_path.startswith(("./", "../")):
        return module_path
    if module_path.endswith((".js", ".mjs")):
        return module_path
    return module_path + ".js"
# Replace destructured requires with named imports
def replace_destructured(match: re.Match) -> str:
    """Rewrite ``const { a, b } = require('m')`` as ``import { a, b } from 'm';``."""
    module_path = add_js_extension(match.group(3))
    bindings = match.group(2).strip()
    return f"import {{ {bindings} }} from '{module_path}';"
# Replace property access requires with named imports with alias
# e.g., const foo = require('./module').bar -> import { bar as foo } from './module';
def replace_property_access(match: re.Match) -> str:
    """Rewrite ``const x = require('m').y`` into the equivalent ESM import."""
    alias_name = match.group(2)  # The local variable name
    property_name = match.group(4)  # The property accessed on the module
    module_path = add_js_extension(match.group(3))
    # Accessing ``.default`` maps to a default import.
    if property_name == "default":
        return f"import {alias_name} from '{module_path}';"
    # Plain named import when no aliasing is needed.
    if alias_name == property_name:
        return f"import {{ {property_name} }} from '{module_path}';"
    # Otherwise alias the named export to the local binding.
    return f"import {{ {property_name} as {alias_name} }} from '{module_path}';"
# Replace simple requires with default imports
def replace_simple(match: re.Match) -> str:
    """Rewrite ``const x = require('m')`` as ``import x from 'm';``."""
    module_path = add_js_extension(match.group(3))
    return f"import {match.group(2)} from '{module_path}';"
def convert_commonjs_to_esm(code: str) -> str:
    """Convert CommonJS require statements to ES Module imports.

    Converts:
        const { foo, bar } = require('./module'); -> import { foo, bar } from './module';
        const foo = require('./module'); -> import foo from './module';
        const foo = require('./module').default; -> import foo from './module';
        const foo = require('./module').bar; -> import { bar as foo } from './module';

    Special handling:
        - Local codeflash helper (./codeflash-jest-helper) is converted to npm package codeflash
          because the local helper uses CommonJS exports which don't work in ESM projects

    Args:
        code: JavaScript code with CommonJS require statements.

    Returns:
        Code with ES Module import statements.
    """
    # Most specific patterns run first so property-access requires are not
    # swallowed by the simple-require pattern.
    rewrite_passes = (
        (destructured_require, replace_destructured),
        (property_access_require, replace_property_access),
        (simple_require, replace_simple),
    )
    for pattern, replacement in rewrite_passes:
        code = pattern.sub(replacement, code)
    return code
def convert_esm_to_commonjs(code: str) -> str:
    """Convert ES Module imports to CommonJS require statements.

    Converts:
        import { foo, bar } from './module'; -> const { foo, bar } = require('./module');
        import foo from './module'; -> const foo = require('./module');

    Note: namespace imports (``import * as foo from '...'``) are not
    converted and pass through unchanged.

    Args:
        code: JavaScript code with ES Module import statements.

    Returns:
        Code with CommonJS require statements.
    """
    # Uses the module-level ``re`` import; the previous function-local
    # ``import re`` was redundant.
    # Pattern for named import: import { a, b } from '...'; (semicolon optional)
    named_import = re.compile(r"import\s+\{\s*([^}]+)\s*\}\s+from\s+['\"]([^'\"]+)['\"];?")
    # Pattern for default import: import foo from '...'; (semicolon optional)
    default_import = re.compile(r"import\s+(\w+)\s+from\s+['\"]([^'\"]+)['\"];?")

    def replace_named(match: re.Match) -> str:
        """Rewrite a named import as a destructured require."""
        names = match.group(1).strip()
        # Remove .js extension for CommonJS (optional but cleaner)
        module_path = match.group(2).removesuffix(".js")
        return f"const {{ {names} }} = require('{module_path}');"

    def replace_default(match: re.Match) -> str:
        """Rewrite a default import as a simple require."""
        name = match.group(1)
        module_path = match.group(2).removesuffix(".js")
        return f"const {name} = require('{module_path}');"

    # Apply conversions (named first as it's more specific)
    code = named_import.sub(replace_named, code)
    return default_import.sub(replace_default, code)
def ensure_module_system_compatibility(code: str, target_module_system: str) -> str:
    """Ensure code uses the correct module system syntax.

    Detects the current module system in the code and converts if needed.
    Handles mixed-style code (e.g., ESM imports with CommonJS require for
    npm packages).

    Args:
        code: JavaScript code to check and potentially convert.
        target_module_system: Target ModuleSystem (COMMONJS or ES_MODULE).

    Returns:
        Code with correct module system syntax.
    """
    # ESM target: rewrite any require() calls, including mixed-style code
    # (ESM imports alongside CommonJS requires for npm packages).
    if target_module_system == ModuleSystem.ES_MODULE and "require(" in code:
        logger.debug("Converting CommonJS requires to ESM imports")
        return convert_commonjs_to_esm(code)
    # CommonJS target: rewrite any import ... from ... statements.
    if target_module_system == ModuleSystem.COMMONJS and "import " in code and "from " in code:
        logger.debug("Converting ESM imports to CommonJS requires")
        return convert_esm_to_commonjs(code)
    return code

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,660 @@
"""JavaScript test runner using Jest.
This module provides functions for running Jest tests for behavioral
verification and performance benchmarking.
"""
from __future__ import annotations
import json
import subprocess
import time
from pathlib import Path
from typing import TYPE_CHECKING
from codeflash.cli_cmds.console import logger
from codeflash.code_utils.code_utils import get_run_tmp_file
from codeflash.code_utils.config_consts import STABILITY_CENTER_TOLERANCE, STABILITY_SPREAD_TOLERANCE
from codeflash.code_utils.shell_utils import get_cross_platform_subprocess_run_args
if TYPE_CHECKING:
from codeflash.models.models import TestFiles
def _find_node_project_root(file_path: Path) -> Path | None:
"""Find the Node.js project root by looking for package.json.
Traverses up from the given file path to find the nearest directory
containing package.json or jest.config.js.
Args:
file_path: A file path within the Node.js project.
Returns:
The project root directory, or None if not found.
"""
current = file_path.parent if file_path.is_file() else file_path
while current != current.parent: # Stop at filesystem root
if (
(current / "package.json").exists()
or (current / "jest.config.js").exists()
or (current / "jest.config.ts").exists()
or (current / "tsconfig.json").exists()
):
return current
current = current.parent
return None
def _is_esm_project(project_root: Path) -> bool:
"""Check if the project uses ES Modules.
Detects ESM by checking package.json for "type": "module".
Args:
project_root: The project root directory.
Returns:
True if the project uses ES Modules, False otherwise.
"""
package_json = project_root / "package.json"
if package_json.exists():
try:
with package_json.open("r") as f:
pkg = json.load(f)
return pkg.get("type") == "module"
except Exception as e:
logger.debug(f"Failed to read package.json: {e}")
return False
def _uses_ts_jest(project_root: Path) -> bool:
"""Check if the project uses ts-jest for TypeScript transformation.
ts-jest handles ESM transformation internally, so we don't need the
--experimental-vm-modules flag when it's being used. Adding that flag
can actually break Jest's module resolution for jest.mock() with relative paths.
Args:
project_root: The project root directory.
Returns:
True if ts-jest is being used, False otherwise.
"""
# Check for ts-jest in devDependencies
package_json = project_root / "package.json"
if package_json.exists():
try:
with package_json.open("r") as f:
pkg = json.load(f)
dev_deps = pkg.get("devDependencies", {})
deps = pkg.get("dependencies", {})
if "ts-jest" in dev_deps or "ts-jest" in deps:
return True
except Exception as e:
logger.debug(f"Failed to read package.json for ts-jest detection: {e}")
# Also check for jest.config with ts-jest preset
for config_file in ["jest.config.js", "jest.config.cjs", "jest.config.ts", "jest.config.mjs"]:
config_path = project_root / config_file
if config_path.exists():
try:
content = config_path.read_text()
if "ts-jest" in content:
return True
except Exception as e:
logger.debug(f"Failed to read {config_file}: {e}")
return False
def _configure_esm_environment(jest_env: dict[str, str], project_root: Path) -> None:
    """Configure environment variables for ES Module support in Jest.

    Jest requires the --experimental-vm-modules Node flag for ESM support,
    passed via the NODE_OPTIONS environment variable.

    IMPORTANT: When ts-jest is being used, we skip adding
    --experimental-vm-modules because ts-jest handles ESM transformation
    internally. Adding this flag can break Jest's module resolution for
    jest.mock() calls with relative paths.

    Args:
        jest_env: Environment variables dictionary to modify (in place).
        project_root: The project root directory.
    """
    if not _is_esm_project(project_root):
        return
    # Skip if ts-jest is being used - it handles ESM internally and
    # --experimental-vm-modules breaks module resolution for relative mocks
    if _uses_ts_jest(project_root):
        logger.debug("Skipping --experimental-vm-modules: ts-jest handles ESM transformation")
        return
    logger.debug("Configuring Jest for ES Module support")
    esm_flag = "--experimental-vm-modules"
    current_options = jest_env.get("NODE_OPTIONS", "")
    if esm_flag not in current_options:
        jest_env["NODE_OPTIONS"] = f"{current_options} {esm_flag}".strip()
def _ensure_runtime_files(project_root: Path) -> None:
    """Ensure the codeflash JavaScript runtime package is installed.

    The package provides all runtime files needed for test instrumentation.
    A local development checkout is preferred over the npm registry; if both
    attempts fail an error is logged.

    Args:
        project_root: The project root directory.
    """
    # Already installed? Nothing to do.
    if (project_root / "node_modules" / "codeflash").exists():
        logger.debug("codeflash already installed")
        return

    def attempt_install(package_spec: str, success_msg: str, failure_prefix: str, error_prefix: str) -> bool:
        """Run one `npm install --save-dev` attempt; return True on success."""
        try:
            result = subprocess.run(
                ["npm", "install", "--save-dev", package_spec],
                check=False,
                cwd=project_root,
                capture_output=True,
                text=True,
                timeout=120,
            )
        except Exception as e:
            logger.warning(f"{error_prefix}{e}")
            return False
        if result.returncode == 0:
            logger.debug(success_msg)
            return True
        logger.warning(f"{failure_prefix}{result.stderr}")
        return False

    # Prefer the local development package when it exists in the repo checkout.
    local_package_path = Path(__file__).parent.parent.parent.parent / "packages" / "codeflash"
    if local_package_path.exists() and attempt_install(
        str(local_package_path),
        "Installed codeflash from local package",
        "Failed to install local package: ",
        "Error installing local package: ",
    ):
        return
    # Fall back to the public npm registry.
    if attempt_install(
        "codeflash",
        "Installed codeflash from npm registry",
        "Failed to install from npm: ",
        "Error installing from npm: ",
    ):
        return
    logger.error("Could not install codeflash. Please install it manually: npm install --save-dev codeflash")
def run_jest_behavioral_tests(
    test_paths: TestFiles,
    test_env: dict[str, str],
    cwd: Path,
    *,
    timeout: int | None = None,
    project_root: Path | None = None,
    enable_coverage: bool = False,
    candidate_index: int = 0,
) -> tuple[Path, subprocess.CompletedProcess, Path | None, Path | None]:
    """Run Jest tests and return results in a format compatible with pytest output.
    Runs the instrumented behavior test files serially under Jest with the
    jest-junit reporter, then normalizes the subprocess output so that timing
    markers emitted via console.log always end up on stdout.
    Args:
        test_paths: TestFiles object containing test file information.
        test_env: Environment variables for the test run.
        cwd: Working directory for running tests.
        timeout: Optional timeout in seconds.
        project_root: JavaScript project root (directory containing package.json).
        enable_coverage: Whether to collect coverage information.
        candidate_index: Index of the candidate being tested.
    Returns:
        Tuple of (result_file_path, subprocess_result, coverage_json_path, None).
        The fourth element is always None in this implementation.
    """
    result_file_path = get_run_tmp_file(Path("jest_results.xml"))
    # Get test files to run
    test_files = [str(file.instrumented_behavior_file_path) for file in test_paths.test_files]
    # Use provided project_root, or detect it as fallback from the first test file
    if project_root is None and test_files:
        first_test_file = Path(test_files[0])
        project_root = _find_node_project_root(first_test_file)
    # Use the project root, or fall back to provided cwd
    effective_cwd = project_root if project_root else cwd
    logger.debug(f"Jest working directory: {effective_cwd}")
    # Ensure the codeflash npm package is installed
    _ensure_runtime_files(effective_cwd)
    # Coverage output directory (JSON report written here when coverage is enabled)
    coverage_dir = get_run_tmp_file(Path("jest_coverage"))
    coverage_json_path = coverage_dir / "coverage-final.json" if enable_coverage else None
    # Build Jest command
    jest_cmd = [
        "npx",
        "jest",
        "--reporters=default",
        "--reporters=jest-junit",
        "--runInBand",  # Run tests serially for consistent timing
        "--forceExit",
    ]
    # Add coverage flags if enabled
    if enable_coverage:
        jest_cmd.extend(["--coverage", "--coverageReporters=json", f"--coverageDirectory={coverage_dir}"])
    if test_files:
        jest_cmd.append("--runTestsByPath")
        jest_cmd.extend(str(Path(f).resolve()) for f in test_files)
    if timeout:
        jest_cmd.append(f"--testTimeout={timeout * 1000}")  # Jest uses milliseconds
    # Set up environment for jest-junit reporting
    jest_env = test_env.copy()
    jest_env["JEST_JUNIT_OUTPUT_FILE"] = str(result_file_path)
    jest_env["JEST_JUNIT_OUTPUT_DIR"] = str(result_file_path.parent)
    jest_env["JEST_JUNIT_OUTPUT_NAME"] = result_file_path.name
    # Configure jest-junit to use filepath-based classnames for proper parsing
    jest_env["JEST_JUNIT_CLASSNAME"] = "{filepath}"
    jest_env["JEST_JUNIT_SUITE_NAME"] = "{filepath}"
    jest_env["JEST_JUNIT_ADD_FILE_ATTRIBUTE"] = "true"
    # Include console.log output in JUnit XML for timing marker parsing
    jest_env["JEST_JUNIT_INCLUDE_CONSOLE_OUTPUT"] = "true"
    # Set codeflash output file for the jest helper to write timing/behavior data (SQLite format)
    # Use candidate_index to differentiate between baseline (0) and optimization candidates
    codeflash_sqlite_file = get_run_tmp_file(Path(f"test_return_values_{candidate_index}.sqlite"))
    jest_env["CODEFLASH_OUTPUT_FILE"] = str(codeflash_sqlite_file)
    jest_env["CODEFLASH_TEST_ITERATION"] = str(candidate_index)
    jest_env["CODEFLASH_LOOP_INDEX"] = "1"
    jest_env["CODEFLASH_MODE"] = "behavior"
    # Seed random number generator for reproducible test runs across original and optimized code
    jest_env["CODEFLASH_RANDOM_SEED"] = "42"
    # Configure ESM support if project uses ES Modules
    _configure_esm_environment(jest_env, effective_cwd)
    logger.debug(f"Running Jest tests with command: {' '.join(jest_cmd)}")
    start_time_ns = time.perf_counter_ns()
    try:
        run_args = get_cross_platform_subprocess_run_args(
            cwd=effective_cwd, env=jest_env, timeout=timeout or 600, check=False, text=True, capture_output=True
        )
        result = subprocess.run(jest_cmd, **run_args)  # noqa: PLW1510
        # Jest sends console.log output to stderr by default - move it to stdout
        # so our timing markers (printed via console.log) are in the expected place
        if result.stderr and not result.stdout:
            result = subprocess.CompletedProcess(
                args=result.args, returncode=result.returncode, stdout=result.stderr, stderr=""
            )
        elif result.stderr:
            # Combine stderr into stdout if both have content
            result = subprocess.CompletedProcess(
                args=result.args, returncode=result.returncode, stdout=result.stdout + "\n" + result.stderr, stderr=""
            )
        logger.debug(f"Jest result: returncode={result.returncode}")
    except subprocess.TimeoutExpired:
        # Synthesize a failed CompletedProcess so callers get a uniform return shape
        logger.warning(f"Jest tests timed out after {timeout}s")
        result = subprocess.CompletedProcess(args=jest_cmd, returncode=-1, stdout="", stderr="Test execution timed out")
    except FileNotFoundError:
        logger.error("Jest not found. Make sure Jest is installed (npm install jest)")
        result = subprocess.CompletedProcess(
            args=jest_cmd, returncode=-1, stdout="", stderr="Jest not found. Run: npm install jest jest-junit"
        )
    finally:
        # Log wall-clock duration regardless of success, timeout, or missing Jest
        wall_clock_ns = time.perf_counter_ns() - start_time_ns
        logger.debug(f"Jest behavioral tests completed in {wall_clock_ns / 1e9:.2f}s")
    return result_file_path, result, coverage_json_path, None
def _parse_timing_from_jest_output(stdout: str) -> dict[str, int]:
"""Parse timing data from Jest stdout markers.
Extracts timing information from markers like:
!######testModule:testFunc:funcName:loopIndex:invocationId:durationNs######!
Args:
stdout: Jest stdout containing timing markers.
Returns:
Dictionary mapping test case IDs to duration in nanoseconds.
"""
import re
# Pattern: !######module:testFunc:funcName:loopIndex:invocationId:durationNs######!
pattern = re.compile(r"!######([^:]+):([^:]*):([^:]+):([^:]+):([^:]+):(\d+)######!")
timings: dict[str, int] = {}
for match in pattern.finditer(stdout):
module, test_class, func_name, _loop_index, invocation_id, duration_ns = match.groups()
# Create test case ID (same format as Python)
test_id = f"{module}:{test_class}:{func_name}:{invocation_id}"
timings[test_id] = int(duration_ns)
return timings
def _should_stop_stability(
    runtimes: list[int],
    window: int,
    min_window_size: int,
    center_rel_tol: float = STABILITY_CENTER_TOLERANCE,
    spread_rel_tol: float = STABILITY_SPREAD_TOLERANCE,
) -> bool:
    """Check if performance has stabilized (matches Python's pytest_plugin.should_stop exactly).
    This function implements the same stability criteria as the Python pytest_plugin.py
    to ensure consistent behavior between Python and JavaScript performance testing.
    Args:
        runtimes: List of aggregate runtimes (sum of min per test case).
        window: Size of the window to check for stability.
        min_window_size: Minimum number of data points required.
        center_rel_tol: Center tolerance - all recent points must be within this fraction of median.
        spread_rel_tol: Spread tolerance - (max-min)/min must be within this fraction.
    Returns:
        True if performance has stabilized, False otherwise.
    """
    if len(runtimes) < window:
        return False
    if len(runtimes) < min_window_size:
        return False
    recent = runtimes[-window:]
    # Use sorted array for faster median and min/max operations
    recent_sorted = sorted(recent)
    r_min, r_max = recent_sorted[0], recent_sorted[-1]
    mid = window // 2
    m = recent_sorted[mid] if window % 2 else (recent_sorted[mid - 1] + recent_sorted[mid]) / 2
    # Guard degenerate data BEFORE the relative-tolerance checks: a zero median
    # or zero minimum would otherwise raise ZeroDivisionError below (e.g. when
    # all recent runtimes are 0). Treat such windows as not-yet-stable.
    if r_min == 0 or m == 0:
        return False
    # 1) All recent points close to the median
    centered = all(abs(r - m) / m <= center_rel_tol for r in recent)
    # 2) Window spread is small relative to the minimum
    spread_ok = (r_max - r_min) / r_min <= spread_rel_tol
    return centered and spread_ok
def run_jest_benchmarking_tests(
    test_paths: TestFiles,
    test_env: dict[str, str],
    cwd: Path,
    *,
    timeout: int | None = None,
    project_root: Path | None = None,
    min_loops: int = 5,
    max_loops: int = 100,
    target_duration_ms: int = 10_000,  # 10 seconds for benchmarking tests
    stability_check: bool = True,
) -> tuple[Path, subprocess.CompletedProcess]:
    """Run Jest benchmarking tests with in-process session-level looping.
    Uses a custom Jest runner (codeflash/loop-runner) to loop all tests
    within a single Jest process, eliminating process startup overhead.
    This matches Python's pytest_plugin behavior:
    - All tests are run multiple times within a single Jest process
    - Timing data is collected per iteration
    - Stability is checked within the runner
    Args:
        test_paths: TestFiles object containing test file information.
        test_env: Environment variables for the test run.
        cwd: Working directory for running tests.
        timeout: Optional timeout in seconds for the entire benchmark run.
        project_root: JavaScript project root (directory containing package.json).
        min_loops: Minimum number of loop iterations.
        max_loops: Maximum number of loop iterations.
        target_duration_ms: Target TOTAL duration in milliseconds for all loops.
        stability_check: Whether to enable stability-based early stopping.
    Returns:
        Tuple of (result_file_path, subprocess_result with stdout from all iterations).
    """
    result_file_path = get_run_tmp_file(Path("jest_perf_results.xml"))
    # Get performance test files (only entries that actually have a benchmarking file)
    test_files = [str(file.benchmarking_file_path) for file in test_paths.test_files if file.benchmarking_file_path]
    # Use provided project_root, or detect it as fallback from the first test file
    if project_root is None and test_files:
        first_test_file = Path(test_files[0])
        project_root = _find_node_project_root(first_test_file)
    effective_cwd = project_root if project_root else cwd
    logger.debug(f"Jest benchmarking working directory: {effective_cwd}")
    # Ensure the codeflash npm package is installed
    _ensure_runtime_files(effective_cwd)
    # Build Jest command for performance tests with custom loop runner
    jest_cmd = [
        "npx",
        "jest",
        "--reporters=default",
        "--reporters=jest-junit",
        "--runInBand",  # Ensure serial execution even though runner enforces it
        "--forceExit",
        "--runner=codeflash/loop-runner",  # Use custom loop runner for in-process looping
    ]
    if test_files:
        jest_cmd.append("--runTestsByPath")
        jest_cmd.extend(str(Path(f).resolve()) for f in test_files)
    if timeout:
        jest_cmd.append(f"--testTimeout={timeout * 1000}")
    # Base environment setup (jest-junit output configuration)
    jest_env = test_env.copy()
    jest_env["JEST_JUNIT_OUTPUT_FILE"] = str(result_file_path)
    jest_env["JEST_JUNIT_OUTPUT_DIR"] = str(result_file_path.parent)
    jest_env["JEST_JUNIT_OUTPUT_NAME"] = result_file_path.name
    jest_env["JEST_JUNIT_CLASSNAME"] = "{filepath}"
    jest_env["JEST_JUNIT_SUITE_NAME"] = "{filepath}"
    jest_env["JEST_JUNIT_ADD_FILE_ATTRIBUTE"] = "true"
    jest_env["JEST_JUNIT_INCLUDE_CONSOLE_OUTPUT"] = "true"
    # Benchmarking always uses iteration index 0 for the SQLite output file
    codeflash_sqlite_file = get_run_tmp_file(Path("test_return_values_0.sqlite"))
    jest_env["CODEFLASH_OUTPUT_FILE"] = str(codeflash_sqlite_file)
    jest_env["CODEFLASH_TEST_ITERATION"] = "0"
    jest_env["CODEFLASH_MODE"] = "performance"
    jest_env["CODEFLASH_RANDOM_SEED"] = "42"
    # Internal loop configuration for capturePerf (eliminates Jest environment overhead)
    # Looping happens inside capturePerf() for maximum efficiency
    jest_env["CODEFLASH_PERF_LOOP_COUNT"] = str(max_loops)
    jest_env["CODEFLASH_PERF_MIN_LOOPS"] = str(min_loops)
    jest_env["CODEFLASH_PERF_TARGET_DURATION_MS"] = str(target_duration_ms)
    jest_env["CODEFLASH_PERF_STABILITY_CHECK"] = "true" if stability_check else "false"
    jest_env["CODEFLASH_LOOP_INDEX"] = "1"  # Initial value for compatibility
    # Configure ESM support if project uses ES Modules
    _configure_esm_environment(jest_env, effective_cwd)
    # Total timeout for the entire benchmark run (longer than single-loop timeout)
    # Account for startup overhead + target duration + buffer
    total_timeout = max(120, (target_duration_ms // 1000) + 60, timeout or 120)
    logger.debug(f"Running Jest benchmarking tests with in-process loop runner: {' '.join(jest_cmd)}")
    logger.debug(
        f"Jest benchmarking config: min_loops={min_loops}, max_loops={max_loops}, "
        f"target_duration={target_duration_ms}ms, stability_check={stability_check}"
    )
    total_start_time = time.time()
    try:
        run_args = get_cross_platform_subprocess_run_args(
            cwd=effective_cwd, env=jest_env, timeout=total_timeout, check=False, text=True, capture_output=True
        )
        result = subprocess.run(jest_cmd, **run_args)  # noqa: PLW1510
        # Combine stderr into stdout for timing markers
        stdout = result.stdout or ""
        if result.stderr:
            stdout = stdout + "\n" + result.stderr if stdout else result.stderr
        # Create result with combined stdout
        result = subprocess.CompletedProcess(args=result.args, returncode=result.returncode, stdout=stdout, stderr="")
    except subprocess.TimeoutExpired:
        # Synthesize a failed CompletedProcess so callers get a uniform return shape
        logger.warning(f"Jest benchmarking timed out after {total_timeout}s")
        result = subprocess.CompletedProcess(args=jest_cmd, returncode=-1, stdout="", stderr="Benchmarking timed out")
    except FileNotFoundError:
        logger.error("Jest not found for benchmarking")
        result = subprocess.CompletedProcess(args=jest_cmd, returncode=-1, stdout="", stderr="Jest not found")
    wall_clock_seconds = time.time() - total_start_time
    logger.debug(f"Jest benchmarking completed in {wall_clock_seconds:.2f}s")
    return result_file_path, result
def run_jest_line_profile_tests(
    test_paths: TestFiles,
    test_env: dict[str, str],
    cwd: Path,
    *,
    timeout: int | None = None,
    project_root: Path | None = None,
    line_profile_output_file: Path | None = None,
) -> tuple[Path, subprocess.CompletedProcess]:
    """Run Jest tests for line profiling.
    This runs tests against source code that has been instrumented with line profiler.
    The instrumentation collects execution counts and timing per line.
    Args:
        test_paths: TestFiles object containing test file information.
        test_env: Environment variables for the test run.
        cwd: Working directory for running tests.
        timeout: Optional timeout in seconds for the subprocess.
        project_root: JavaScript project root (directory containing package.json).
        line_profile_output_file: Path where line profile results will be written.
    Returns:
        Tuple of (result_file_path, subprocess_result).
    """
    result_file_path = get_run_tmp_file(Path("jest_line_profile_results.xml"))
    # Get test files to run - use instrumented behavior files if available, otherwise benchmarking files
    test_files = []
    for file in test_paths.test_files:
        if file.instrumented_behavior_file_path:
            test_files.append(str(file.instrumented_behavior_file_path))
        elif file.benchmarking_file_path:
            test_files.append(str(file.benchmarking_file_path))
    # Use provided project_root, or detect it as fallback from the first test file
    if project_root is None and test_files:
        first_test_file = Path(test_files[0])
        project_root = _find_node_project_root(first_test_file)
    effective_cwd = project_root if project_root else cwd
    logger.debug(f"Jest line profiling working directory: {effective_cwd}")
    # Ensure the codeflash npm package is installed
    _ensure_runtime_files(effective_cwd)
    # Build Jest command for line profiling - simple run without benchmarking loops
    jest_cmd = [
        "npx",
        "jest",
        "--reporters=default",
        "--reporters=jest-junit",
        "--runInBand",  # Run tests serially for consistent line profiling
        "--forceExit",
    ]
    if test_files:
        jest_cmd.append("--runTestsByPath")
        jest_cmd.extend(str(Path(f).resolve()) for f in test_files)
    if timeout:
        jest_cmd.append(f"--testTimeout={timeout * 1000}")
    # Set up environment for jest-junit reporting
    jest_env = test_env.copy()
    jest_env["JEST_JUNIT_OUTPUT_FILE"] = str(result_file_path)
    jest_env["JEST_JUNIT_OUTPUT_DIR"] = str(result_file_path.parent)
    jest_env["JEST_JUNIT_OUTPUT_NAME"] = result_file_path.name
    jest_env["JEST_JUNIT_CLASSNAME"] = "{filepath}"
    jest_env["JEST_JUNIT_SUITE_NAME"] = "{filepath}"
    jest_env["JEST_JUNIT_ADD_FILE_ATTRIBUTE"] = "true"
    jest_env["JEST_JUNIT_INCLUDE_CONSOLE_OUTPUT"] = "true"
    # Set codeflash output file for the jest helper
    codeflash_sqlite_file = get_run_tmp_file(Path("test_return_values_line_profile.sqlite"))
    jest_env["CODEFLASH_OUTPUT_FILE"] = str(codeflash_sqlite_file)
    jest_env["CODEFLASH_TEST_ITERATION"] = "0"
    jest_env["CODEFLASH_LOOP_INDEX"] = "1"
    jest_env["CODEFLASH_MODE"] = "line_profile"
    # Seed random number generator for reproducibility
    jest_env["CODEFLASH_RANDOM_SEED"] = "42"
    # Pass the line profile output file path to the instrumented code
    if line_profile_output_file:
        jest_env["CODEFLASH_LINE_PROFILE_OUTPUT"] = str(line_profile_output_file)
    # Configure ESM support if project uses ES Modules
    _configure_esm_environment(jest_env, effective_cwd)
    subprocess_timeout = timeout or 600
    logger.debug(f"Running Jest line profile tests: {' '.join(jest_cmd)}")
    start_time_ns = time.perf_counter_ns()
    try:
        run_args = get_cross_platform_subprocess_run_args(
            cwd=effective_cwd, env=jest_env, timeout=subprocess_timeout, check=False, text=True, capture_output=True
        )
        result = subprocess.run(jest_cmd, **run_args)  # noqa: PLW1510
        # Jest sends console.log output to stderr by default - move it to stdout
        if result.stderr and not result.stdout:
            result = subprocess.CompletedProcess(
                args=result.args, returncode=result.returncode, stdout=result.stderr, stderr=""
            )
        elif result.stderr:
            # Combine stderr into stdout when both streams have content
            result = subprocess.CompletedProcess(
                args=result.args, returncode=result.returncode, stdout=result.stdout + "\n" + result.stderr, stderr=""
            )
        logger.debug(f"Jest line profile result: returncode={result.returncode}")
    except subprocess.TimeoutExpired:
        # Synthesize a failed CompletedProcess so callers get a uniform return shape
        logger.warning(f"Jest line profile tests timed out after {subprocess_timeout}s")
        result = subprocess.CompletedProcess(
            args=jest_cmd, returncode=-1, stdout="", stderr="Line profile tests timed out"
        )
    except FileNotFoundError:
        logger.error("Jest not found for line profiling")
        result = subprocess.CompletedProcess(args=jest_cmd, returncode=-1, stdout="", stderr="Jest not found")
    finally:
        # Log wall-clock duration regardless of success, timeout, or missing Jest
        wall_clock_ns = time.perf_counter_ns() - start_time_ns
        logger.debug(f"Jest line profile tests completed in {wall_clock_ns / 1e9:.2f}s")
    return result_file_path, result

Some files were not shown because too many files have changed in this diff Show more