Merge branch 'main' into jit-docs
This commit is contained in:
commit
d020da8294
266 changed files with 61765 additions and 3145 deletions
41
.github/workflows/codeflash.yaml
vendored
Normal file
41
.github/workflows/codeflash.yaml
vendored
Normal file
|
|
@ -0,0 +1,41 @@
|
|||
name: Codeflash
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
paths:
|
||||
# So that this workflow only runs when code within the target module is modified
|
||||
- 'code_to_optimize_js_esm/**'
|
||||
workflow_dispatch:
|
||||
|
||||
concurrency:
|
||||
# Any new push to the PR will cancel the previous run, so that only the latest code is optimized
|
||||
group: ${{ github.workflow }}-${{ github.ref }}
|
||||
cancel-in-progress: true
|
||||
|
||||
|
||||
jobs:
|
||||
optimize:
|
||||
name: Optimize new code
|
||||
# Don't run codeflash on codeflash-ai[bot] commits, prevent duplicate optimizations
|
||||
if: ${{ github.actor != 'codeflash-ai[bot]' }}
|
||||
runs-on: ubuntu-latest
|
||||
env:
|
||||
CODEFLASH_API_KEY: ${{ secrets.CODEFLASH_API_KEY }}
|
||||
defaults:
|
||||
run:
|
||||
working-directory: ./code_to_optimize_js_esm
|
||||
steps:
|
||||
- name: 🛎️ Checkout
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
- name: 🟢 Setup Node.js
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: '22'
|
||||
cache: 'npm'
|
||||
- name: 📦 Install Dependencies
|
||||
run: npm ci
|
||||
|
||||
- name: ⚡️ Codeflash Optimization
|
||||
run: npx codeflash
|
||||
88
.github/workflows/e2e-js-cjs-function.yaml
vendored
Normal file
88
.github/workflows/e2e-js-cjs-function.yaml
vendored
Normal file
|
|
@ -0,0 +1,88 @@
|
|||
name: E2E - JS CommonJS Function
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
paths:
|
||||
- '**' # Trigger for all paths
|
||||
|
||||
workflow_dispatch:
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.ref_name }}
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
js-cjs-function-optimization:
|
||||
# Dynamically determine if environment is needed only when workflow files change and contributor is external
|
||||
environment: ${{ (github.event_name == 'workflow_dispatch' || (contains(toJSON(github.event.pull_request.files.*.filename), '.github/workflows/') && github.event.pull_request.user.login != 'misrasaurabh1' && github.event.pull_request.user.login != 'KRRT7')) && 'external-trusted-contributors' || '' }}
|
||||
|
||||
runs-on: ubuntu-latest
|
||||
env:
|
||||
CODEFLASH_AIS_SERVER: prod
|
||||
POSTHOG_API_KEY: ${{ secrets.POSTHOG_API_KEY }}
|
||||
CODEFLASH_API_KEY: ${{ secrets.CODEFLASH_API_KEY }}
|
||||
COLUMNS: 110
|
||||
MAX_RETRIES: 3
|
||||
RETRY_DELAY: 5
|
||||
EXPECTED_IMPROVEMENT_PCT: 50
|
||||
CODEFLASH_END_TO_END: 1
|
||||
steps:
|
||||
- name: 🛎️ Checkout
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
ref: ${{ github.event.pull_request.head.ref }}
|
||||
repository: ${{ github.event.pull_request.head.repo.full_name }}
|
||||
fetch-depth: 0
|
||||
token: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
- name: Validate PR
|
||||
run: |
|
||||
# Check for any workflow changes
|
||||
if git diff --name-only "${{ github.event.pull_request.base.sha }}" "${{ github.event.pull_request.head.sha }}" | grep -q "^.github/workflows/"; then
|
||||
echo "⚠️ Workflow changes detected."
|
||||
|
||||
# Get the PR author
|
||||
AUTHOR="${{ github.event.pull_request.user.login }}"
|
||||
echo "PR Author: $AUTHOR"
|
||||
|
||||
# Allowlist check
|
||||
if [[ "$AUTHOR" == "misrasaurabh1" || "$AUTHOR" == "KRRT7" ]]; then
|
||||
echo "✅ Authorized user ($AUTHOR). Proceeding."
|
||||
elif [[ "${{ github.event.pull_request.state }}" == "open" ]]; then
|
||||
echo "✅ PR triggered by 'pull_request_target' and is open. Assuming protection rules are in place. Proceeding."
|
||||
else
|
||||
echo "⛔ Unauthorized user ($AUTHOR) attempting to modify workflows. Exiting."
|
||||
exit 1
|
||||
fi
|
||||
else
|
||||
echo "✅ No workflow file changes detected. Proceeding."
|
||||
fi
|
||||
|
||||
- name: Set up Node.js
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: '20'
|
||||
|
||||
- name: Install codeflash npm package dependencies
|
||||
run: |
|
||||
cd packages/codeflash
|
||||
npm install
|
||||
|
||||
- name: Install JS test project dependencies
|
||||
run: |
|
||||
cd code_to_optimize/js/code_to_optimize_js
|
||||
npm install
|
||||
|
||||
- name: Set up Python 3.11 for CLI
|
||||
uses: astral-sh/setup-uv@v6
|
||||
with:
|
||||
python-version: 3.11.6
|
||||
|
||||
- name: Install dependencies (CLI)
|
||||
run: |
|
||||
uv sync
|
||||
|
||||
- name: Run Codeflash to optimize JS CommonJS function
|
||||
id: optimize_code
|
||||
run: |
|
||||
uv run python tests/scripts/end_to_end_test_js_cjs_function.py
|
||||
88
.github/workflows/e2e-js-esm-async.yaml
vendored
Normal file
88
.github/workflows/e2e-js-esm-async.yaml
vendored
Normal file
|
|
@ -0,0 +1,88 @@
|
|||
name: E2E - JS ESM Async
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
paths:
|
||||
- '**' # Trigger for all paths
|
||||
|
||||
workflow_dispatch:
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.ref_name }}
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
js-esm-async-optimization:
|
||||
# Dynamically determine if environment is needed only when workflow files change and contributor is external
|
||||
environment: ${{ (github.event_name == 'workflow_dispatch' || (contains(toJSON(github.event.pull_request.files.*.filename), '.github/workflows/') && github.event.pull_request.user.login != 'misrasaurabh1' && github.event.pull_request.user.login != 'KRRT7')) && 'external-trusted-contributors' || '' }}
|
||||
|
||||
runs-on: ubuntu-latest
|
||||
env:
|
||||
CODEFLASH_AIS_SERVER: prod
|
||||
POSTHOG_API_KEY: ${{ secrets.POSTHOG_API_KEY }}
|
||||
CODEFLASH_API_KEY: ${{ secrets.CODEFLASH_API_KEY }}
|
||||
COLUMNS: 110
|
||||
MAX_RETRIES: 3
|
||||
RETRY_DELAY: 5
|
||||
EXPECTED_IMPROVEMENT_PCT: 10
|
||||
CODEFLASH_END_TO_END: 1
|
||||
steps:
|
||||
- name: 🛎️ Checkout
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
ref: ${{ github.event.pull_request.head.ref }}
|
||||
repository: ${{ github.event.pull_request.head.repo.full_name }}
|
||||
fetch-depth: 0
|
||||
token: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
- name: Validate PR
|
||||
run: |
|
||||
# Check for any workflow changes
|
||||
if git diff --name-only "${{ github.event.pull_request.base.sha }}" "${{ github.event.pull_request.head.sha }}" | grep -q "^.github/workflows/"; then
|
||||
echo "⚠️ Workflow changes detected."
|
||||
|
||||
# Get the PR author
|
||||
AUTHOR="${{ github.event.pull_request.user.login }}"
|
||||
echo "PR Author: $AUTHOR"
|
||||
|
||||
# Allowlist check
|
||||
if [[ "$AUTHOR" == "misrasaurabh1" || "$AUTHOR" == "KRRT7" ]]; then
|
||||
echo "✅ Authorized user ($AUTHOR). Proceeding."
|
||||
elif [[ "${{ github.event.pull_request.state }}" == "open" ]]; then
|
||||
echo "✅ PR triggered by 'pull_request_target' and is open. Assuming protection rules are in place. Proceeding."
|
||||
else
|
||||
echo "⛔ Unauthorized user ($AUTHOR) attempting to modify workflows. Exiting."
|
||||
exit 1
|
||||
fi
|
||||
else
|
||||
echo "✅ No workflow file changes detected. Proceeding."
|
||||
fi
|
||||
|
||||
- name: Set up Node.js
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: '20'
|
||||
|
||||
- name: Install codeflash npm package dependencies
|
||||
run: |
|
||||
cd packages/codeflash
|
||||
npm install
|
||||
|
||||
- name: Install JS test project dependencies
|
||||
run: |
|
||||
cd code_to_optimize/js/code_to_optimize_js_esm
|
||||
npm install
|
||||
|
||||
- name: Set up Python 3.11 for CLI
|
||||
uses: astral-sh/setup-uv@v6
|
||||
with:
|
||||
python-version: 3.11.6
|
||||
|
||||
- name: Install dependencies (CLI)
|
||||
run: |
|
||||
uv sync
|
||||
|
||||
- name: Run Codeflash to optimize ESM async function
|
||||
id: optimize_code
|
||||
run: |
|
||||
uv run python tests/scripts/end_to_end_test_js_esm_async.py
|
||||
88
.github/workflows/e2e-js-ts-class.yaml
vendored
Normal file
88
.github/workflows/e2e-js-ts-class.yaml
vendored
Normal file
|
|
@ -0,0 +1,88 @@
|
|||
name: E2E - JS TypeScript Class
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
paths:
|
||||
- '**' # Trigger for all paths
|
||||
|
||||
workflow_dispatch:
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.ref_name }}
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
js-ts-class-optimization:
|
||||
# Dynamically determine if environment is needed only when workflow files change and contributor is external
|
||||
environment: ${{ (github.event_name == 'workflow_dispatch' || (contains(toJSON(github.event.pull_request.files.*.filename), '.github/workflows/') && github.event.pull_request.user.login != 'misrasaurabh1' && github.event.pull_request.user.login != 'KRRT7')) && 'external-trusted-contributors' || '' }}
|
||||
|
||||
runs-on: ubuntu-latest
|
||||
env:
|
||||
CODEFLASH_AIS_SERVER: prod
|
||||
POSTHOG_API_KEY: ${{ secrets.POSTHOG_API_KEY }}
|
||||
CODEFLASH_API_KEY: ${{ secrets.CODEFLASH_API_KEY }}
|
||||
COLUMNS: 110
|
||||
MAX_RETRIES: 3
|
||||
RETRY_DELAY: 5
|
||||
EXPECTED_IMPROVEMENT_PCT: 30
|
||||
CODEFLASH_END_TO_END: 1
|
||||
steps:
|
||||
- name: 🛎️ Checkout
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
ref: ${{ github.event.pull_request.head.ref }}
|
||||
repository: ${{ github.event.pull_request.head.repo.full_name }}
|
||||
fetch-depth: 0
|
||||
token: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
- name: Validate PR
|
||||
run: |
|
||||
# Check for any workflow changes
|
||||
if git diff --name-only "${{ github.event.pull_request.base.sha }}" "${{ github.event.pull_request.head.sha }}" | grep -q "^.github/workflows/"; then
|
||||
echo "⚠️ Workflow changes detected."
|
||||
|
||||
# Get the PR author
|
||||
AUTHOR="${{ github.event.pull_request.user.login }}"
|
||||
echo "PR Author: $AUTHOR"
|
||||
|
||||
# Allowlist check
|
||||
if [[ "$AUTHOR" == "misrasaurabh1" || "$AUTHOR" == "KRRT7" ]]; then
|
||||
echo "✅ Authorized user ($AUTHOR). Proceeding."
|
||||
elif [[ "${{ github.event.pull_request.state }}" == "open" ]]; then
|
||||
echo "✅ PR triggered by 'pull_request_target' and is open. Assuming protection rules are in place. Proceeding."
|
||||
else
|
||||
echo "⛔ Unauthorized user ($AUTHOR) attempting to modify workflows. Exiting."
|
||||
exit 1
|
||||
fi
|
||||
else
|
||||
echo "✅ No workflow file changes detected. Proceeding."
|
||||
fi
|
||||
|
||||
- name: Set up Node.js
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: '20'
|
||||
|
||||
- name: Install codeflash npm package dependencies
|
||||
run: |
|
||||
cd packages/codeflash
|
||||
npm install
|
||||
|
||||
- name: Install JS test project dependencies
|
||||
run: |
|
||||
cd code_to_optimize/js/code_to_optimize_ts
|
||||
npm install
|
||||
|
||||
- name: Set up Python 3.11 for CLI
|
||||
uses: astral-sh/setup-uv@v6
|
||||
with:
|
||||
python-version: 3.11.6
|
||||
|
||||
- name: Install dependencies (CLI)
|
||||
run: |
|
||||
uv sync
|
||||
|
||||
- name: Run Codeflash to optimize TypeScript class method
|
||||
id: optimize_code
|
||||
run: |
|
||||
uv run python tests/scripts/end_to_end_test_js_ts_class.py
|
||||
4
.gitignore
vendored
4
.gitignore
vendored
|
|
@ -163,7 +163,6 @@ cython_debug/
|
|||
#.idea/
|
||||
.aider*
|
||||
/js/common/node_modules/
|
||||
/node_modules/
|
||||
*.xml
|
||||
*.pem
|
||||
|
||||
|
|
@ -259,6 +258,9 @@ WARP.MD
|
|||
.mcp.json
|
||||
.tessl/
|
||||
tessl.json
|
||||
**/node_modules/**
|
||||
**/dist-nuitka/**
|
||||
**/.npmrc
|
||||
|
||||
# Tessl auto-generates AGENTS.md on install; ignore to avoid cluttering git status
|
||||
AGENTS.md
|
||||
|
|
|
|||
1116
MULTI_LANGUAGE_ARCHITECTURE.md
Normal file
1116
MULTI_LANGUAGE_ARCHITECTURE.md
Normal file
File diff suppressed because it is too large
Load diff
49
code_to_optimize/js/code_to_optimize_js/bubble_sort.js
Normal file
49
code_to_optimize/js/code_to_optimize_js/bubble_sort.js
Normal file
|
|
@ -0,0 +1,49 @@
|
|||
/**
|
||||
* Bubble sort implementation - intentionally inefficient for optimization testing.
|
||||
*/
|
||||
|
||||
/**
|
||||
* Sort an array using bubble sort algorithm.
|
||||
* @param {number[]} arr - The array to sort
|
||||
* @returns {number[]} - The sorted array
|
||||
*/
|
||||
function bubbleSort(arr) {
|
||||
const result = arr.slice();
|
||||
const n = result.length;
|
||||
|
||||
for (let i = 0; i < n; i++) {
|
||||
for (let j = 0; j < n - 1; j++) {
|
||||
if (result[j] > result[j + 1]) {
|
||||
const temp = result[j];
|
||||
result[j] = result[j + 1];
|
||||
result[j + 1] = temp;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
/**
|
||||
* Sort an array in descending order.
|
||||
* @param {number[]} arr - The array to sort
|
||||
* @returns {number[]} - The sorted array in descending order
|
||||
*/
|
||||
function bubbleSortDescending(arr) {
|
||||
const n = arr.length;
|
||||
const result = [...arr];
|
||||
|
||||
for (let i = 0; i < n - 1; i++) {
|
||||
for (let j = 0; j < n - i - 1; j++) {
|
||||
if (result[j] < result[j + 1]) {
|
||||
const temp = result[j];
|
||||
result[j] = result[j + 1];
|
||||
result[j + 1] = temp;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
module.exports = { bubbleSort, bubbleSortDescending };
|
||||
85
code_to_optimize/js/code_to_optimize_js/calculator.js
Normal file
85
code_to_optimize/js/code_to_optimize_js/calculator.js
Normal file
|
|
@ -0,0 +1,85 @@
|
|||
/**
|
||||
* Calculator module - demonstrates cross-file function calls.
|
||||
* Uses helper functions from math_helpers.js.
|
||||
*/
|
||||
|
||||
const { sumArray, average, findMax, findMin } = require('./math_helpers');
|
||||
|
||||
|
||||
/**
|
||||
* Calculate statistics for an array of numbers.
|
||||
* @param numbers - Array of numbers to analyze
|
||||
* @returns Object containing sum, average, min, max, and range
|
||||
*/
|
||||
function calculateStats(numbers) {
|
||||
if (numbers.length === 0) {
|
||||
return {
|
||||
sum: 0,
|
||||
average: 0,
|
||||
min: 0,
|
||||
max: 0,
|
||||
range: 0
|
||||
};
|
||||
}
|
||||
|
||||
const sum = sumArray(numbers);
|
||||
const avg = average(numbers);
|
||||
const min = findMin(numbers);
|
||||
const max = findMax(numbers);
|
||||
const range = max - min;
|
||||
|
||||
return {
|
||||
sum,
|
||||
average: avg,
|
||||
min,
|
||||
max,
|
||||
range
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Normalize an array of numbers to a 0-1 range.
|
||||
* @param numbers - Array of numbers to normalize
|
||||
* @returns Normalized array
|
||||
*/
|
||||
function normalizeArray(numbers) {
|
||||
if (numbers.length === 0) return [];
|
||||
|
||||
const min = findMin(numbers);
|
||||
const max = findMax(numbers);
|
||||
const range = max - min;
|
||||
|
||||
if (range === 0) {
|
||||
return numbers.map(() => 0.5);
|
||||
}
|
||||
|
||||
return numbers.map(n => (n - min) / range);
|
||||
}
|
||||
|
||||
/**
|
||||
* Calculate the weighted average of values with corresponding weights.
|
||||
* @param values - Array of values
|
||||
* @param weights - Array of weights (same length as values)
|
||||
* @returns The weighted average
|
||||
*/
|
||||
function weightedAverage(values, weights) {
|
||||
if (values.length === 0 || values.length !== weights.length) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
let weightedSum = 0;
|
||||
for (let i = 0; i < values.length; i++) {
|
||||
weightedSum += values[i] * weights[i];
|
||||
}
|
||||
|
||||
const totalWeight = sumArray(weights);
|
||||
if (totalWeight === 0) return 0;
|
||||
|
||||
return weightedSum / totalWeight;
|
||||
}
|
||||
|
||||
module.exports = {
|
||||
calculateStats,
|
||||
normalizeArray,
|
||||
weightedAverage
|
||||
};
|
||||
54
code_to_optimize/js/code_to_optimize_js/fibonacci.js
Normal file
54
code_to_optimize/js/code_to_optimize_js/fibonacci.js
Normal file
|
|
@ -0,0 +1,54 @@
|
|||
/**
|
||||
* Fibonacci implementations - intentionally inefficient for optimization testing.
|
||||
*/
|
||||
|
||||
/**
|
||||
* Calculate the nth Fibonacci number using naive recursion.
|
||||
* This is intentionally slow to demonstrate optimization potential.
|
||||
* @param {number} n - The index of the Fibonacci number to calculate
|
||||
* @returns {number} - The nth Fibonacci number
|
||||
*/
|
||||
function fibonacci(n) {
|
||||
if (n <= 1) {
|
||||
return n;
|
||||
}
|
||||
return fibonacci(n - 1) + fibonacci(n - 2);
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if a number is a Fibonacci number.
|
||||
* @param {number} num - The number to check
|
||||
* @returns {boolean} - True if num is a Fibonacci number
|
||||
*/
|
||||
function isFibonacci(num) {
|
||||
// A number is Fibonacci if one of (5*n*n + 4) or (5*n*n - 4) is a perfect square
|
||||
const check1 = 5 * num * num + 4;
|
||||
const check2 = 5 * num * num - 4;
|
||||
|
||||
return isPerfectSquare(check1) || isPerfectSquare(check2);
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if a number is a perfect square.
|
||||
* @param {number} n - The number to check
|
||||
* @returns {boolean} - True if n is a perfect square
|
||||
*/
|
||||
function isPerfectSquare(n) {
|
||||
const sqrt = Math.sqrt(n);
|
||||
return sqrt === Math.floor(sqrt);
|
||||
}
|
||||
|
||||
/**
|
||||
* Generate an array of Fibonacci numbers up to n.
|
||||
* @param {number} n - The number of Fibonacci numbers to generate
|
||||
* @returns {number[]} - Array of Fibonacci numbers
|
||||
*/
|
||||
function fibonacciSequence(n) {
|
||||
const result = [];
|
||||
for (let i = 0; i < n; i++) {
|
||||
result.push(fibonacci(i));
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
module.exports = { fibonacci, isFibonacci, isPerfectSquare, fibonacciSequence };
|
||||
61
code_to_optimize/js/code_to_optimize_js/math_helpers.js
Normal file
61
code_to_optimize/js/code_to_optimize_js/math_helpers.js
Normal file
|
|
@ -0,0 +1,61 @@
|
|||
/**
|
||||
* Math helper functions - used by other modules.
|
||||
* Some implementations are intentionally inefficient for optimization testing.
|
||||
*/
|
||||
|
||||
/**
|
||||
* Calculate the sum of an array of numbers.
|
||||
* @param numbers - Array of numbers to sum
|
||||
* @returns The sum of all numbers
|
||||
*/
|
||||
function sumArray(numbers) {
|
||||
// Intentionally inefficient - using reduce with spread operator
|
||||
let result = 0;
|
||||
for (let i = 0; i < numbers.length; i++) {
|
||||
result = result + numbers[i];
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
/**
|
||||
* Calculate the average of an array of numbers.
|
||||
* @param numbers - Array of numbers
|
||||
* @returns The average value
|
||||
*/
|
||||
function average(numbers) {
|
||||
if (numbers.length === 0) return 0;
|
||||
return sumArray(numbers) / numbers.length;
|
||||
}
|
||||
|
||||
/**
|
||||
* Find the maximum value in an array.
|
||||
* @param numbers - Array of numbers
|
||||
* @returns The maximum value
|
||||
*/
|
||||
function findMax(numbers) {
|
||||
if (numbers.length === 0) return -Infinity;
|
||||
|
||||
// Intentionally inefficient - sorting instead of linear scan
|
||||
const sorted = [...numbers].sort((a, b) => b - a);
|
||||
return sorted[0];
|
||||
}
|
||||
|
||||
/**
|
||||
* Find the minimum value in an array.
|
||||
* @param numbers - Array of numbers
|
||||
* @returns The minimum value
|
||||
*/
|
||||
function findMin(numbers) {
|
||||
if (numbers.length === 0) return Infinity;
|
||||
|
||||
// Intentionally inefficient - sorting instead of linear scan
|
||||
const sorted = [...numbers].sort((a, b) => a - b);
|
||||
return sorted[0];
|
||||
}
|
||||
|
||||
module.exports = {
|
||||
sumArray,
|
||||
average,
|
||||
findMax,
|
||||
findMin
|
||||
};
|
||||
3731
code_to_optimize/js/code_to_optimize_js/package-lock.json
generated
Normal file
3731
code_to_optimize/js/code_to_optimize_js/package-lock.json
generated
Normal file
File diff suppressed because it is too large
Load diff
35
code_to_optimize/js/code_to_optimize_js/package.json
Normal file
35
code_to_optimize/js/code_to_optimize_js/package.json
Normal file
|
|
@ -0,0 +1,35 @@
|
|||
{
|
||||
"name": "codeflash-js-test",
|
||||
"version": "1.0.0",
|
||||
"description": "Sample JavaScript project for codeflash optimization testing",
|
||||
"main": "index.js",
|
||||
"scripts": {
|
||||
"test": "jest"
|
||||
},
|
||||
"codeflash": {
|
||||
"moduleRoot": ".",
|
||||
"testsRoot": "tests"
|
||||
},
|
||||
"devDependencies": {
|
||||
"codeflash": "file:../../../packages/codeflash",
|
||||
"jest": "^29.7.0",
|
||||
"jest-junit": "^16.0.0"
|
||||
},
|
||||
"jest": {
|
||||
"testEnvironment": "node",
|
||||
"testMatch": [
|
||||
"**/tests/**/*.test.js"
|
||||
],
|
||||
"reporters": [
|
||||
"default",
|
||||
[
|
||||
"jest-junit",
|
||||
{
|
||||
"outputDirectory": ".codeflash",
|
||||
"outputName": "jest-results.xml",
|
||||
"includeConsoleOutput": true
|
||||
}
|
||||
]
|
||||
]
|
||||
}
|
||||
}
|
||||
95
code_to_optimize/js/code_to_optimize_js/string_utils.js
Normal file
95
code_to_optimize/js/code_to_optimize_js/string_utils.js
Normal file
|
|
@ -0,0 +1,95 @@
|
|||
/**
|
||||
* String utility functions - some intentionally inefficient for optimization testing.
|
||||
*/
|
||||
|
||||
/**
|
||||
* Reverse a string character by character.
|
||||
* @param {string} str - The string to reverse
|
||||
* @returns {string} - The reversed string
|
||||
*/
|
||||
function reverseString(str) {
|
||||
// Intentionally inefficient O(n²) implementation for testing
|
||||
let result = '';
|
||||
for (let i = str.length - 1; i >= 0; i--) {
|
||||
// Rebuild the entire result string each iteration (very inefficient)
|
||||
let temp = '';
|
||||
for (let j = 0; j < result.length; j++) {
|
||||
temp += result[j];
|
||||
}
|
||||
temp += str[i];
|
||||
result = temp;
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if a string is a palindrome.
|
||||
* @param {string} str - The string to check
|
||||
* @returns {boolean} - True if str is a palindrome
|
||||
*/
|
||||
function isPalindrome(str) {
|
||||
const cleaned = str.toLowerCase().replace(/[^a-z0-9]/g, '');
|
||||
return cleaned === reverseString(cleaned);
|
||||
}
|
||||
|
||||
/**
|
||||
* Count occurrences of a substring in a string.
|
||||
* @param {string} str - The string to search in
|
||||
* @param {string} sub - The substring to count
|
||||
* @returns {number} - Number of occurrences
|
||||
*/
|
||||
function countOccurrences(str, sub) {
|
||||
let count = 0;
|
||||
let pos = 0;
|
||||
|
||||
while (true) {
|
||||
pos = str.indexOf(sub, pos);
|
||||
if (pos === -1) break;
|
||||
count++;
|
||||
pos += 1; // Move past current match
|
||||
}
|
||||
|
||||
return count;
|
||||
}
|
||||
|
||||
/**
|
||||
* Find the longest common prefix of an array of strings.
|
||||
* @param {string[]} strs - Array of strings
|
||||
* @returns {string} - The longest common prefix
|
||||
*/
|
||||
function longestCommonPrefix(strs) {
|
||||
if (strs.length === 0) return '';
|
||||
if (strs.length === 1) return strs[0];
|
||||
|
||||
let prefix = strs[0];
|
||||
|
||||
for (let i = 1; i < strs.length; i++) {
|
||||
while (strs[i].indexOf(prefix) !== 0) {
|
||||
prefix = prefix.slice(0, -1);
|
||||
if (prefix === '') return '';
|
||||
}
|
||||
}
|
||||
|
||||
return prefix;
|
||||
}
|
||||
|
||||
/**
|
||||
* Convert a string to title case.
|
||||
* @param {string} str - The string to convert
|
||||
* @returns {string} - The title-cased string
|
||||
*/
|
||||
function toTitleCase(str) {
|
||||
return str
|
||||
.toLowerCase()
|
||||
.split(' ')
|
||||
.map(word => word.charAt(0).toUpperCase() + word.slice(1))
|
||||
.join(' ');
|
||||
}
|
||||
|
||||
module.exports = {
|
||||
reverseString,
|
||||
isPalindrome,
|
||||
countOccurrences,
|
||||
longestCommonPrefix,
|
||||
toTitleCase
|
||||
};
|
||||
|
|
@ -0,0 +1,70 @@
|
|||
const { bubbleSort, bubbleSortDescending } = require('../bubble_sort');
|
||||
|
||||
describe('bubbleSort', () => {
|
||||
test('sorts an empty array', () => {
|
||||
expect(bubbleSort([])).toEqual([]);
|
||||
});
|
||||
|
||||
test('sorts a single element array', () => {
|
||||
expect(bubbleSort([1])).toEqual([1]);
|
||||
});
|
||||
|
||||
test('sorts an already sorted array', () => {
|
||||
expect(bubbleSort([1, 2, 3, 4, 5])).toEqual([1, 2, 3, 4, 5]);
|
||||
});
|
||||
|
||||
test('sorts a reverse sorted array', () => {
|
||||
expect(bubbleSort([5, 4, 3, 2, 1])).toEqual([1, 2, 3, 4, 5]);
|
||||
});
|
||||
|
||||
test('sorts an array with duplicates', () => {
|
||||
expect(bubbleSort([3, 1, 4, 1, 5, 9, 2, 6])).toEqual([1, 1, 2, 3, 4, 5, 6, 9]);
|
||||
});
|
||||
|
||||
test('sorts negative numbers', () => {
|
||||
expect(bubbleSort([-3, -1, -4, -1, -5])).toEqual([-5, -4, -3, -1, -1]);
|
||||
});
|
||||
|
||||
test('does not mutate original array', () => {
|
||||
const original = [3, 1, 2];
|
||||
bubbleSort(original);
|
||||
expect(original).toEqual([3, 1, 2]);
|
||||
});
|
||||
|
||||
test('sorts a larger reverse sorted array for performance', () => {
|
||||
const input = [];
|
||||
for (let i = 500; i >= 0; i--) {
|
||||
input.push(i);
|
||||
}
|
||||
const result = bubbleSort(input);
|
||||
expect(result[0]).toBe(0);
|
||||
expect(result[result.length - 1]).toBe(500);
|
||||
});
|
||||
|
||||
test('sorts a larger random array for performance', () => {
|
||||
const input = [
|
||||
42, 17, 93, 8, 67, 31, 55, 22, 89, 4,
|
||||
76, 12, 39, 58, 95, 26, 71, 48, 83, 19,
|
||||
64, 3, 88, 37, 52, 11, 79, 46, 91, 28,
|
||||
63, 7, 84, 33, 57, 14, 72, 41, 96, 24,
|
||||
69, 6, 81, 36, 54, 16, 77, 44, 90, 29
|
||||
];
|
||||
const result = bubbleSort(input);
|
||||
expect(result[0]).toBe(3);
|
||||
expect(result[result.length - 1]).toBe(96);
|
||||
});
|
||||
});
|
||||
|
||||
describe('bubbleSortDescending', () => {
|
||||
test('sorts in descending order', () => {
|
||||
expect(bubbleSortDescending([1, 3, 2, 5, 4])).toEqual([5, 4, 3, 2, 1]);
|
||||
});
|
||||
|
||||
test('handles empty array', () => {
|
||||
expect(bubbleSortDescending([])).toEqual([]);
|
||||
});
|
||||
|
||||
test('handles single element', () => {
|
||||
expect(bubbleSortDescending([42])).toEqual([42]);
|
||||
});
|
||||
});
|
||||
|
|
@ -0,0 +1,470 @@
|
|||
/**
|
||||
* End-to-End Behavior Comparison Test
|
||||
*
|
||||
* This test verifies that:
|
||||
* 1. The instrumentation correctly captures function behavior (args + return value)
|
||||
* 2. Serialization/deserialization preserves all value types
|
||||
* 3. The comparator correctly identifies equivalent behaviors
|
||||
*
|
||||
* It simulates what happens during optimization verification:
|
||||
* - Run the same tests twice (original vs optimized) with different LOOP_INDEX
|
||||
* - Store results to different locations
|
||||
* - Compare the serialized values using the comparator
|
||||
*/
|
||||
|
||||
const fs = require('fs');
|
||||
const path = require('path');
|
||||
const { execSync, spawn } = require('child_process');
|
||||
|
||||
// Import our modules from npm package
|
||||
const { serialize, deserialize, getSerializerType, comparator } = require('codeflash');
|
||||
|
||||
// Test output directory
|
||||
const TEST_OUTPUT_DIR = '/tmp/codeflash_e2e_test';
|
||||
|
||||
// Sample functions to test with various return types
|
||||
const testFunctions = {
|
||||
// Primitives
|
||||
returnNumber: (x) => x * 2,
|
||||
returnString: (s) => s.toUpperCase(),
|
||||
returnBoolean: (x) => x > 0,
|
||||
returnNull: () => null,
|
||||
returnUndefined: () => undefined,
|
||||
|
||||
// Special numbers
|
||||
returnNaN: () => NaN,
|
||||
returnInfinity: () => Infinity,
|
||||
returnNegInfinity: () => -Infinity,
|
||||
|
||||
// Complex types
|
||||
returnArray: (arr) => arr.map(x => x * 2),
|
||||
returnObject: (obj) => ({ ...obj, processed: true }),
|
||||
returnMap: (entries) => new Map(entries),
|
||||
returnSet: (values) => new Set(values),
|
||||
returnDate: (ts) => new Date(ts),
|
||||
returnRegExp: (pattern, flags) => new RegExp(pattern, flags),
|
||||
|
||||
// Nested structures
|
||||
returnNested: (data) => ({
|
||||
array: [1, 2, 3],
|
||||
map: new Map([['key', data]]),
|
||||
set: new Set([data]),
|
||||
date: new Date('2024-01-15'),
|
||||
}),
|
||||
|
||||
// TypedArrays
|
||||
returnTypedArray: (data) => new Float64Array(data),
|
||||
|
||||
// Error handling
|
||||
mayThrow: (shouldThrow) => {
|
||||
if (shouldThrow) throw new Error('Test error');
|
||||
return 'success';
|
||||
},
|
||||
};
|
||||
|
||||
describe('E2E Behavior Comparison', () => {
|
||||
beforeAll(() => {
|
||||
// Clean up and create test directory
|
||||
if (fs.existsSync(TEST_OUTPUT_DIR)) {
|
||||
fs.rmSync(TEST_OUTPUT_DIR, { recursive: true });
|
||||
}
|
||||
fs.mkdirSync(TEST_OUTPUT_DIR, { recursive: true });
|
||||
console.log('Using serializer:', getSerializerType());
|
||||
});
|
||||
|
||||
afterAll(() => {
|
||||
// Cleanup
|
||||
if (fs.existsSync(TEST_OUTPUT_DIR)) {
|
||||
fs.rmSync(TEST_OUTPUT_DIR, { recursive: true });
|
||||
}
|
||||
});
|
||||
|
||||
describe('Direct Serialization Round-Trip', () => {
|
||||
// Test that serialize -> deserialize -> compare works for all types
|
||||
|
||||
test('primitives round-trip correctly', () => {
|
||||
const testCases = [
|
||||
42,
|
||||
-3.14159,
|
||||
'hello world',
|
||||
true,
|
||||
false,
|
||||
null,
|
||||
undefined,
|
||||
BigInt('9007199254740991'),
|
||||
];
|
||||
|
||||
for (const original of testCases) {
|
||||
const serialized = serialize(original);
|
||||
const restored = deserialize(serialized);
|
||||
expect(comparator(original, restored)).toBe(true);
|
||||
}
|
||||
});
|
||||
|
||||
test('special numbers round-trip correctly', () => {
|
||||
const testCases = [NaN, Infinity, -Infinity, -0];
|
||||
|
||||
for (const original of testCases) {
|
||||
const serialized = serialize(original);
|
||||
const restored = deserialize(serialized);
|
||||
expect(comparator(original, restored)).toBe(true);
|
||||
}
|
||||
});
|
||||
|
||||
test('complex objects round-trip correctly', () => {
|
||||
const testCases = [
|
||||
new Map([['a', 1], ['b', 2]]),
|
||||
new Set([1, 2, 3]),
|
||||
new Date('2024-01-15'),
|
||||
/test\d+/gi,
|
||||
new Error('test error'),
|
||||
new Float64Array([1.1, 2.2, 3.3]),
|
||||
];
|
||||
|
||||
for (const original of testCases) {
|
||||
const serialized = serialize(original);
|
||||
const restored = deserialize(serialized);
|
||||
expect(comparator(original, restored)).toBe(true);
|
||||
}
|
||||
});
|
||||
|
||||
test('nested structures round-trip correctly', () => {
|
||||
const original = {
|
||||
array: [1, 'two', { three: 3 }],
|
||||
map: new Map([['nested', new Set([1, 2, 3])]]),
|
||||
date: new Date('2024-06-15'),
|
||||
regex: /pattern/i,
|
||||
typed: new Int32Array([10, 20, 30]),
|
||||
};
|
||||
|
||||
const serialized = serialize(original);
|
||||
const restored = deserialize(serialized);
|
||||
expect(comparator(original, restored)).toBe(true);
|
||||
});
|
||||
});
|
||||
|
||||
describe('Function Behavior Format', () => {
|
||||
// Test the [args, kwargs, return_value] format used by instrumentation
|
||||
|
||||
test('behavior tuple format serializes correctly', () => {
|
||||
// Simulate what recordResult does: [args, {}, returnValue]
|
||||
const args = [42, 'hello'];
|
||||
const kwargs = {}; // JS doesn't have kwargs, always empty
|
||||
const returnValue = { result: 84, message: 'HELLO' };
|
||||
|
||||
const behaviorTuple = [args, kwargs, returnValue];
|
||||
const serialized = serialize(behaviorTuple);
|
||||
const restored = deserialize(serialized);
|
||||
|
||||
expect(comparator(behaviorTuple, restored)).toBe(true);
|
||||
expect(restored[0]).toEqual(args);
|
||||
expect(restored[1]).toEqual(kwargs);
|
||||
expect(comparator(restored[2], returnValue)).toBe(true);
|
||||
});
|
||||
|
||||
test('behavior with Map return value', () => {
|
||||
const args = [['a', 1], ['b', 2]];
|
||||
const returnValue = new Map(args);
|
||||
const behaviorTuple = [args, {}, returnValue];
|
||||
|
||||
const serialized = serialize(behaviorTuple);
|
||||
const restored = deserialize(serialized);
|
||||
|
||||
expect(comparator(behaviorTuple, restored)).toBe(true);
|
||||
expect(restored[2] instanceof Map).toBe(true);
|
||||
expect(restored[2].get('a')).toBe(1);
|
||||
});
|
||||
|
||||
test('behavior with Set return value', () => {
|
||||
const args = [[1, 2, 3]];
|
||||
const returnValue = new Set([1, 2, 3]);
|
||||
const behaviorTuple = [args, {}, returnValue];
|
||||
|
||||
const serialized = serialize(behaviorTuple);
|
||||
const restored = deserialize(serialized);
|
||||
|
||||
expect(comparator(behaviorTuple, restored)).toBe(true);
|
||||
expect(restored[2] instanceof Set).toBe(true);
|
||||
expect(restored[2].has(2)).toBe(true);
|
||||
});
|
||||
|
||||
test('behavior with Date return value', () => {
|
||||
const args = [1705276800000]; // 2024-01-15
|
||||
const returnValue = new Date(1705276800000);
|
||||
const behaviorTuple = [args, {}, returnValue];
|
||||
|
||||
const serialized = serialize(behaviorTuple);
|
||||
const restored = deserialize(serialized);
|
||||
|
||||
expect(comparator(behaviorTuple, restored)).toBe(true);
|
||||
expect(restored[2] instanceof Date).toBe(true);
|
||||
expect(restored[2].getTime()).toBe(1705276800000);
|
||||
});
|
||||
|
||||
test('behavior with TypedArray return value', () => {
|
||||
const args = [[1.1, 2.2, 3.3]];
|
||||
const returnValue = new Float64Array([1.1, 2.2, 3.3]);
|
||||
const behaviorTuple = [args, {}, returnValue];
|
||||
|
||||
const serialized = serialize(behaviorTuple);
|
||||
const restored = deserialize(serialized);
|
||||
|
||||
expect(comparator(behaviorTuple, restored)).toBe(true);
|
||||
expect(restored[2] instanceof Float64Array).toBe(true);
|
||||
});
|
||||
|
||||
test('behavior with Error (exception case)', () => {
|
||||
const error = new TypeError('Invalid argument');
|
||||
const serialized = serialize(error);
|
||||
const restored = deserialize(serialized);
|
||||
|
||||
expect(comparator(error, restored)).toBe(true);
|
||||
expect(restored.name).toBe('TypeError');
|
||||
expect(restored.message).toBe('Invalid argument');
|
||||
});
|
||||
});
|
||||
|
||||
describe('Simulated Original vs Optimized Comparison', () => {
|
||||
// Simulate running the same function twice and comparing results
|
||||
|
||||
function runAndCapture(fn, args) {
|
||||
try {
|
||||
const returnValue = fn(...args);
|
||||
return { success: true, value: [args, {}, returnValue] };
|
||||
} catch (error) {
|
||||
return { success: false, error };
|
||||
}
|
||||
}
|
||||
|
||||
test('identical behaviors are equal - number function', () => {
|
||||
const fn = testFunctions.returnNumber;
|
||||
const args = [21];
|
||||
|
||||
// "Original" run
|
||||
const original = runAndCapture(fn, args);
|
||||
const originalSerialized = serialize(original.value);
|
||||
|
||||
// "Optimized" run (same function, simulating optimization)
|
||||
const optimized = runAndCapture(fn, args);
|
||||
const optimizedSerialized = serialize(optimized.value);
|
||||
|
||||
// Deserialize and compare (what verification does)
|
||||
const originalRestored = deserialize(originalSerialized);
|
||||
const optimizedRestored = deserialize(optimizedSerialized);
|
||||
|
||||
expect(comparator(originalRestored, optimizedRestored)).toBe(true);
|
||||
});
|
||||
|
||||
test('identical behaviors are equal - Map function', () => {
|
||||
const fn = testFunctions.returnMap;
|
||||
const args = [[['x', 10], ['y', 20]]];
|
||||
|
||||
const original = runAndCapture(fn, args);
|
||||
const originalSerialized = serialize(original.value);
|
||||
|
||||
const optimized = runAndCapture(fn, args);
|
||||
const optimizedSerialized = serialize(optimized.value);
|
||||
|
||||
const originalRestored = deserialize(originalSerialized);
|
||||
const optimizedRestored = deserialize(optimizedSerialized);
|
||||
|
||||
expect(comparator(originalRestored, optimizedRestored)).toBe(true);
|
||||
});
|
||||
|
||||
test('identical behaviors are equal - nested structure', () => {
|
||||
const fn = testFunctions.returnNested;
|
||||
const args = ['test-data'];
|
||||
|
||||
const original = runAndCapture(fn, args);
|
||||
const originalSerialized = serialize(original.value);
|
||||
|
||||
const optimized = runAndCapture(fn, args);
|
||||
const optimizedSerialized = serialize(optimized.value);
|
||||
|
||||
const originalRestored = deserialize(originalSerialized);
|
||||
const optimizedRestored = deserialize(optimizedSerialized);
|
||||
|
||||
expect(comparator(originalRestored, optimizedRestored)).toBe(true);
|
||||
});
|
||||
|
||||
test('different behaviors are NOT equal', () => {
|
||||
const fn1 = (x) => x * 2;
|
||||
const fn2 = (x) => x * 3; // Different behavior!
|
||||
const args = [10];
|
||||
|
||||
const original = runAndCapture(fn1, args);
|
||||
const originalSerialized = serialize(original.value);
|
||||
|
||||
const optimized = runAndCapture(fn2, args);
|
||||
const optimizedSerialized = serialize(optimized.value);
|
||||
|
||||
const originalRestored = deserialize(originalSerialized);
|
||||
const optimizedRestored = deserialize(optimizedSerialized);
|
||||
|
||||
// Should be FALSE - behaviors differ (20 vs 30)
|
||||
expect(comparator(originalRestored, optimizedRestored)).toBe(false);
|
||||
});
|
||||
|
||||
test('floating point tolerance works', () => {
|
||||
// Simulate slight floating point differences from optimization
|
||||
const original = [[[1.0]], {}, 0.30000000000000004];
|
||||
const optimized = [[[1.0]], {}, 0.3];
|
||||
|
||||
const originalSerialized = serialize(original);
|
||||
const optimizedSerialized = serialize(optimized);
|
||||
|
||||
const originalRestored = deserialize(originalSerialized);
|
||||
const optimizedRestored = deserialize(optimizedSerialized);
|
||||
|
||||
// Should be TRUE with default tolerance
|
||||
expect(comparator(originalRestored, optimizedRestored)).toBe(true);
|
||||
});
|
||||
});
|
||||
|
||||
describe('Multiple Invocations Comparison', () => {
|
||||
// Simulate multiple test invocations being stored and compared
|
||||
|
||||
test('batch of invocations can be compared', () => {
|
||||
const testCases = [
|
||||
{ fn: testFunctions.returnNumber, args: [1] },
|
||||
{ fn: testFunctions.returnNumber, args: [100] },
|
||||
{ fn: testFunctions.returnString, args: ['hello'] },
|
||||
{ fn: testFunctions.returnArray, args: [[1, 2, 3]] },
|
||||
{ fn: testFunctions.returnMap, args: [[['a', 1]]] },
|
||||
{ fn: testFunctions.returnSet, args: [[1, 2, 3]] },
|
||||
{ fn: testFunctions.returnDate, args: [1705276800000] },
|
||||
{ fn: testFunctions.returnNested, args: ['data'] },
|
||||
];
|
||||
|
||||
// Simulate original run
|
||||
const originalResults = testCases.map(({ fn, args }) => {
|
||||
const returnValue = fn(...args);
|
||||
return serialize([args, {}, returnValue]);
|
||||
});
|
||||
|
||||
// Simulate optimized run (same functions)
|
||||
const optimizedResults = testCases.map(({ fn, args }) => {
|
||||
const returnValue = fn(...args);
|
||||
return serialize([args, {}, returnValue]);
|
||||
});
|
||||
|
||||
// Compare all results
|
||||
for (let i = 0; i < testCases.length; i++) {
|
||||
const originalRestored = deserialize(originalResults[i]);
|
||||
const optimizedRestored = deserialize(optimizedResults[i]);
|
||||
|
||||
expect(comparator(originalRestored, optimizedRestored)).toBe(true);
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
describe('File-Based Comparison (SQLite Simulation)', () => {
|
||||
// Simulate writing to files and reading back for comparison
|
||||
|
||||
test('can write and read back serialized results', () => {
|
||||
const originalPath = path.join(TEST_OUTPUT_DIR, 'original.bin');
|
||||
const optimizedPath = path.join(TEST_OUTPUT_DIR, 'optimized.bin');
|
||||
|
||||
// Test data
|
||||
const behaviorData = {
|
||||
args: [42, 'test', { nested: true }],
|
||||
kwargs: {},
|
||||
returnValue: {
|
||||
result: new Map([['answer', 42]]),
|
||||
metadata: new Set(['processed', 'validated']),
|
||||
timestamp: new Date('2024-01-15'),
|
||||
},
|
||||
};
|
||||
|
||||
const tuple = [behaviorData.args, behaviorData.kwargs, behaviorData.returnValue];
|
||||
|
||||
// Write "original" result
|
||||
const originalBuffer = serialize(tuple);
|
||||
fs.writeFileSync(originalPath, originalBuffer);
|
||||
|
||||
// Write "optimized" result (same data, simulating correct optimization)
|
||||
const optimizedBuffer = serialize(tuple);
|
||||
fs.writeFileSync(optimizedPath, optimizedBuffer);
|
||||
|
||||
// Read back and compare
|
||||
const originalRead = fs.readFileSync(originalPath);
|
||||
const optimizedRead = fs.readFileSync(optimizedPath);
|
||||
|
||||
const originalRestored = deserialize(originalRead);
|
||||
const optimizedRestored = deserialize(optimizedRead);
|
||||
|
||||
expect(comparator(originalRestored, optimizedRestored)).toBe(true);
|
||||
|
||||
// Verify the complex types survived
|
||||
expect(originalRestored[2].result instanceof Map).toBe(true);
|
||||
expect(originalRestored[2].metadata instanceof Set).toBe(true);
|
||||
expect(originalRestored[2].timestamp instanceof Date).toBe(true);
|
||||
});
|
||||
|
||||
test('detects differences in file-based comparison', () => {
|
||||
const originalPath = path.join(TEST_OUTPUT_DIR, 'original2.bin');
|
||||
const optimizedPath = path.join(TEST_OUTPUT_DIR, 'optimized2.bin');
|
||||
|
||||
// Original behavior
|
||||
const originalTuple = [[10], {}, 100];
|
||||
fs.writeFileSync(originalPath, serialize(originalTuple));
|
||||
|
||||
// "Buggy" optimized behavior
|
||||
const optimizedTuple = [[10], {}, 99]; // Wrong result!
|
||||
fs.writeFileSync(optimizedPath, serialize(optimizedTuple));
|
||||
|
||||
// Read back and compare
|
||||
const originalRestored = deserialize(fs.readFileSync(originalPath));
|
||||
const optimizedRestored = deserialize(fs.readFileSync(optimizedPath));
|
||||
|
||||
// Should detect the difference
|
||||
expect(comparator(originalRestored, optimizedRestored)).toBe(false);
|
||||
});
|
||||
});
|
||||
|
||||
describe('Edge Cases', () => {
|
||||
test('handles special values in args', () => {
|
||||
const tuple = [[NaN, Infinity, undefined, null], {}, 'processed'];
|
||||
|
||||
const serialized = serialize(tuple);
|
||||
const restored = deserialize(serialized);
|
||||
|
||||
expect(comparator(tuple, restored)).toBe(true);
|
||||
expect(Number.isNaN(restored[0][0])).toBe(true);
|
||||
expect(restored[0][1]).toBe(Infinity);
|
||||
expect(restored[0][2]).toBe(undefined);
|
||||
expect(restored[0][3]).toBe(null);
|
||||
});
|
||||
|
||||
test('handles circular references in return value', () => {
|
||||
const obj = { value: 42 };
|
||||
obj.self = obj; // Circular reference
|
||||
|
||||
const tuple = [[], {}, obj];
|
||||
const serialized = serialize(tuple);
|
||||
const restored = deserialize(serialized);
|
||||
|
||||
expect(comparator(tuple, restored)).toBe(true);
|
||||
expect(restored[2].self).toBe(restored[2]);
|
||||
});
|
||||
|
||||
test('handles empty results', () => {
|
||||
const tuple = [[], {}, undefined];
|
||||
|
||||
const serialized = serialize(tuple);
|
||||
const restored = deserialize(serialized);
|
||||
|
||||
expect(comparator(tuple, restored)).toBe(true);
|
||||
});
|
||||
|
||||
test('handles large arrays', () => {
|
||||
const largeArray = Array.from({ length: 1000 }, (_, i) => i);
|
||||
const tuple = [[largeArray], {}, largeArray.reduce((a, b) => a + b, 0)];
|
||||
|
||||
const serialized = serialize(tuple);
|
||||
const restored = deserialize(serialized);
|
||||
|
||||
expect(comparator(tuple, restored)).toBe(true);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
|
@ -0,0 +1,354 @@
|
|||
#!/usr/bin/env node
|
||||
/**
|
||||
* End-to-End Comparison Test
|
||||
*
|
||||
* This test validates the full behavior comparison workflow:
|
||||
* 1. Serialize test results to SQLite (simulating codeflash-jest-helper)
|
||||
* 2. Run the comparison script
|
||||
* 3. Verify results match expectations
|
||||
*/
|
||||
|
||||
const fs = require('fs');
|
||||
const path = require('path');
|
||||
|
||||
// Import our modules from npm package
|
||||
const { serialize, readTestResults, compareResults } = require('codeflash');
|
||||
|
||||
// Try to load better-sqlite3
|
||||
let Database;
|
||||
try {
|
||||
Database = require('better-sqlite3');
|
||||
} catch (e) {
|
||||
console.error('better-sqlite3 not installed, skipping E2E test');
|
||||
process.exit(0);
|
||||
}
|
||||
|
||||
const TEST_DIR = '/tmp/codeflash_e2e_comparison_test';
|
||||
|
||||
/**
|
||||
* Create a SQLite database with test results.
|
||||
*/
|
||||
function createTestDatabase(dbPath, results) {
|
||||
// Ensure directory exists
|
||||
const dir = path.dirname(dbPath);
|
||||
if (!fs.existsSync(dir)) {
|
||||
fs.mkdirSync(dir, { recursive: true });
|
||||
}
|
||||
|
||||
// Remove existing file
|
||||
if (fs.existsSync(dbPath)) {
|
||||
fs.unlinkSync(dbPath);
|
||||
}
|
||||
|
||||
const db = new Database(dbPath);
|
||||
|
||||
// Create table
|
||||
db.exec(`
|
||||
CREATE TABLE test_results (
|
||||
test_module_path TEXT,
|
||||
test_class_name TEXT,
|
||||
test_function_name TEXT,
|
||||
function_getting_tested TEXT,
|
||||
loop_index INTEGER,
|
||||
iteration_id TEXT,
|
||||
runtime INTEGER,
|
||||
return_value BLOB,
|
||||
verification_type TEXT
|
||||
)
|
||||
`);
|
||||
|
||||
// Insert results
|
||||
const stmt = db.prepare(`
|
||||
INSERT INTO test_results VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
`);
|
||||
|
||||
for (const result of results) {
|
||||
stmt.run(
|
||||
result.testModulePath,
|
||||
result.testClassName || null,
|
||||
result.testFunctionName,
|
||||
result.functionGettingTested,
|
||||
result.loopIndex,
|
||||
result.iterationId,
|
||||
result.runtime,
|
||||
result.returnValue ? serialize(result.returnValue) : null,
|
||||
result.verificationType || 'function_call'
|
||||
);
|
||||
}
|
||||
|
||||
db.close();
|
||||
return dbPath;
|
||||
}
|
||||
|
||||
/**
|
||||
* Test 1: Identical results should be equivalent.
|
||||
*/
|
||||
function testIdenticalResults() {
|
||||
console.log('\n=== Test 1: Identical Results ===');
|
||||
|
||||
const results = [
|
||||
{
|
||||
testModulePath: 'tests/math.test.js',
|
||||
testFunctionName: 'test adds numbers',
|
||||
functionGettingTested: 'add',
|
||||
loopIndex: 1,
|
||||
iterationId: '0_0',
|
||||
runtime: 1000,
|
||||
returnValue: [[1, 2], {}, 3], // [args, kwargs, returnValue]
|
||||
},
|
||||
{
|
||||
testModulePath: 'tests/math.test.js',
|
||||
testFunctionName: 'test multiplies numbers',
|
||||
functionGettingTested: 'multiply',
|
||||
loopIndex: 1,
|
||||
iterationId: '0_1',
|
||||
runtime: 1000,
|
||||
returnValue: [[2, 3], {}, 6],
|
||||
},
|
||||
];
|
||||
|
||||
const originalDb = createTestDatabase(path.join(TEST_DIR, 'original1.sqlite'), results);
|
||||
const candidateDb = createTestDatabase(path.join(TEST_DIR, 'candidate1.sqlite'), results);
|
||||
|
||||
const originalResults = readTestResults(originalDb);
|
||||
const candidateResults = readTestResults(candidateDb);
|
||||
const comparison = compareResults(originalResults, candidateResults);
|
||||
|
||||
console.log(` Original invocations: ${originalResults.size}`);
|
||||
console.log(` Candidate invocations: ${candidateResults.size}`);
|
||||
console.log(` Equivalent: ${comparison.equivalent}`);
|
||||
console.log(` Diffs: ${comparison.diffs.length}`);
|
||||
|
||||
if (!comparison.equivalent || comparison.diffs.length > 0) {
|
||||
console.log(' ❌ FAILED: Expected identical results to be equivalent');
|
||||
return false;
|
||||
}
|
||||
console.log(' ✅ PASSED');
|
||||
return true;
|
||||
}
|
||||
|
||||
/**
|
||||
* Test 2: Different return values should NOT be equivalent.
|
||||
*/
|
||||
function testDifferentReturnValues() {
|
||||
console.log('\n=== Test 2: Different Return Values ===');
|
||||
|
||||
const originalResults = [
|
||||
{
|
||||
testModulePath: 'tests/math.test.js',
|
||||
testFunctionName: 'test adds numbers',
|
||||
functionGettingTested: 'add',
|
||||
loopIndex: 1,
|
||||
iterationId: '0_0',
|
||||
runtime: 1000,
|
||||
returnValue: [[1, 2], {}, 3], // Correct: 1 + 2 = 3
|
||||
},
|
||||
];
|
||||
|
||||
const candidateResults = [
|
||||
{
|
||||
testModulePath: 'tests/math.test.js',
|
||||
testFunctionName: 'test adds numbers',
|
||||
functionGettingTested: 'add',
|
||||
loopIndex: 1,
|
||||
iterationId: '0_0',
|
||||
runtime: 1000,
|
||||
returnValue: [[1, 2], {}, 4], // Wrong: should be 3, not 4
|
||||
},
|
||||
];
|
||||
|
||||
const originalDb = createTestDatabase(path.join(TEST_DIR, 'original2.sqlite'), originalResults);
|
||||
const candidateDb = createTestDatabase(path.join(TEST_DIR, 'candidate2.sqlite'), candidateResults);
|
||||
|
||||
const original = readTestResults(originalDb);
|
||||
const candidate = readTestResults(candidateDb);
|
||||
const comparison = compareResults(original, candidate);
|
||||
|
||||
console.log(` Equivalent: ${comparison.equivalent}`);
|
||||
console.log(` Diffs: ${comparison.diffs.length}`);
|
||||
|
||||
if (comparison.equivalent || comparison.diffs.length === 0) {
|
||||
console.log(' ❌ FAILED: Expected different results to NOT be equivalent');
|
||||
return false;
|
||||
}
|
||||
console.log(` Diff found: ${comparison.diffs[0].scope}`);
|
||||
console.log(' ✅ PASSED');
|
||||
return true;
|
||||
}
|
||||
|
||||
/**
|
||||
* Test 3: Complex JavaScript types (Map, Set, Date) should compare correctly.
|
||||
*/
|
||||
function testComplexTypes() {
|
||||
console.log('\n=== Test 3: Complex JavaScript Types ===');
|
||||
|
||||
const complexValue = {
|
||||
map: new Map([['a', 1], ['b', 2]]),
|
||||
set: new Set([1, 2, 3]),
|
||||
date: new Date('2024-01-15T00:00:00.000Z'),
|
||||
nested: {
|
||||
array: [1, 2, 3],
|
||||
map: new Map([['nested', true]]),
|
||||
},
|
||||
};
|
||||
|
||||
const results = [
|
||||
{
|
||||
testModulePath: 'tests/complex.test.js',
|
||||
testFunctionName: 'test complex return',
|
||||
functionGettingTested: 'processData',
|
||||
loopIndex: 1,
|
||||
iterationId: '0_0',
|
||||
runtime: 1000,
|
||||
returnValue: [[], {}, complexValue],
|
||||
},
|
||||
];
|
||||
|
||||
const originalDb = createTestDatabase(path.join(TEST_DIR, 'original3.sqlite'), results);
|
||||
const candidateDb = createTestDatabase(path.join(TEST_DIR, 'candidate3.sqlite'), results);
|
||||
|
||||
const original = readTestResults(originalDb);
|
||||
const candidate = readTestResults(candidateDb);
|
||||
const comparison = compareResults(original, candidate);
|
||||
|
||||
console.log(` Original invocations: ${original.size}`);
|
||||
console.log(` Equivalent: ${comparison.equivalent}`);
|
||||
console.log(` Diffs: ${comparison.diffs.length}`);
|
||||
|
||||
if (!comparison.equivalent) {
|
||||
console.log(' ❌ FAILED: Expected complex types to be equivalent');
|
||||
if (comparison.diffs.length > 0) {
|
||||
console.log(` Diff: ${JSON.stringify(comparison.diffs[0])}`);
|
||||
}
|
||||
return false;
|
||||
}
|
||||
console.log(' ✅ PASSED');
|
||||
return true;
|
||||
}
|
||||
|
||||
/**
|
||||
* Test 4: Floating point tolerance should allow small differences.
|
||||
*/
|
||||
function testFloatingPointTolerance() {
|
||||
console.log('\n=== Test 4: Floating Point Tolerance ===');
|
||||
|
||||
const originalResults = [
|
||||
{
|
||||
testModulePath: 'tests/float.test.js',
|
||||
testFunctionName: 'test float calculation',
|
||||
functionGettingTested: 'calculate',
|
||||
loopIndex: 1,
|
||||
iterationId: '0_0',
|
||||
runtime: 1000,
|
||||
returnValue: [[], {}, 0.1 + 0.2], // 0.30000000000000004
|
||||
},
|
||||
];
|
||||
|
||||
const candidateResults = [
|
||||
{
|
||||
testModulePath: 'tests/float.test.js',
|
||||
testFunctionName: 'test float calculation',
|
||||
functionGettingTested: 'calculate',
|
||||
loopIndex: 1,
|
||||
iterationId: '0_0',
|
||||
runtime: 1000,
|
||||
returnValue: [[], {}, 0.3], // 0.3 (optimized calculation)
|
||||
},
|
||||
];
|
||||
|
||||
const originalDb = createTestDatabase(path.join(TEST_DIR, 'original4.sqlite'), originalResults);
|
||||
const candidateDb = createTestDatabase(path.join(TEST_DIR, 'candidate4.sqlite'), candidateResults);
|
||||
|
||||
const original = readTestResults(originalDb);
|
||||
const candidate = readTestResults(candidateDb);
|
||||
const comparison = compareResults(original, candidate);
|
||||
|
||||
console.log(` Original value: ${0.1 + 0.2}`);
|
||||
console.log(` Candidate value: ${0.3}`);
|
||||
console.log(` Equivalent: ${comparison.equivalent}`);
|
||||
|
||||
if (!comparison.equivalent) {
|
||||
console.log(' ❌ FAILED: Expected floating point values to be equivalent within tolerance');
|
||||
return false;
|
||||
}
|
||||
console.log(' ✅ PASSED');
|
||||
return true;
|
||||
}
|
||||
|
||||
/**
|
||||
* Test 5: NaN values should be equal to each other.
|
||||
*/
|
||||
function testNaNEquality() {
|
||||
console.log('\n=== Test 5: NaN Equality ===');
|
||||
|
||||
const results = [
|
||||
{
|
||||
testModulePath: 'tests/nan.test.js',
|
||||
testFunctionName: 'test NaN return',
|
||||
functionGettingTested: 'divideByZero',
|
||||
loopIndex: 1,
|
||||
iterationId: '0_0',
|
||||
runtime: 1000,
|
||||
returnValue: [[], {}, NaN],
|
||||
},
|
||||
];
|
||||
|
||||
const originalDb = createTestDatabase(path.join(TEST_DIR, 'original5.sqlite'), results);
|
||||
const candidateDb = createTestDatabase(path.join(TEST_DIR, 'candidate5.sqlite'), results);
|
||||
|
||||
const original = readTestResults(originalDb);
|
||||
const candidate = readTestResults(candidateDb);
|
||||
const comparison = compareResults(original, candidate);
|
||||
|
||||
console.log(` Equivalent: ${comparison.equivalent}`);
|
||||
|
||||
if (!comparison.equivalent) {
|
||||
console.log(' ❌ FAILED: Expected NaN values to be equivalent');
|
||||
return false;
|
||||
}
|
||||
console.log(' ✅ PASSED');
|
||||
return true;
|
||||
}
|
||||
|
||||
/**
|
||||
* Main test runner.
|
||||
*/
|
||||
function main() {
|
||||
console.log('='.repeat(60));
|
||||
console.log('E2E Comparison Test Suite');
|
||||
console.log('='.repeat(60));
|
||||
|
||||
// Setup
|
||||
if (fs.existsSync(TEST_DIR)) {
|
||||
fs.rmSync(TEST_DIR, { recursive: true });
|
||||
}
|
||||
fs.mkdirSync(TEST_DIR, { recursive: true });
|
||||
|
||||
const results = [];
|
||||
results.push(testIdenticalResults());
|
||||
results.push(testDifferentReturnValues());
|
||||
results.push(testComplexTypes());
|
||||
results.push(testFloatingPointTolerance());
|
||||
results.push(testNaNEquality());
|
||||
|
||||
// Cleanup
|
||||
fs.rmSync(TEST_DIR, { recursive: true });
|
||||
|
||||
// Summary
|
||||
console.log('\n' + '='.repeat(60));
|
||||
console.log('Summary');
|
||||
console.log('='.repeat(60));
|
||||
const passed = results.filter(r => r).length;
|
||||
const total = results.length;
|
||||
console.log(`Passed: ${passed}/${total}`);
|
||||
|
||||
if (passed === total) {
|
||||
console.log('\n✅ ALL TESTS PASSED');
|
||||
process.exit(0);
|
||||
} else {
|
||||
console.log('\n❌ SOME TESTS FAILED');
|
||||
process.exit(1);
|
||||
}
|
||||
}
|
||||
|
||||
main();
|
||||
|
|
@ -0,0 +1,97 @@
|
|||
const { fibonacci, isFibonacci, isPerfectSquare, fibonacciSequence } = require('../fibonacci');
|
||||
|
||||
describe('fibonacci', () => {
|
||||
test('returns 0 for n=0', () => {
|
||||
expect(fibonacci(0)).toBe(0);
|
||||
});
|
||||
|
||||
test('returns 1 for n=1', () => {
|
||||
expect(fibonacci(1)).toBe(1);
|
||||
});
|
||||
|
||||
test('returns 1 for n=2', () => {
|
||||
expect(fibonacci(2)).toBe(1);
|
||||
});
|
||||
|
||||
test('returns 5 for n=5', () => {
|
||||
expect(fibonacci(5)).toBe(5);
|
||||
});
|
||||
|
||||
test('returns 55 for n=10', () => {
|
||||
expect(fibonacci(10)).toBe(55);
|
||||
});
|
||||
|
||||
test('returns 233 for n=13', () => {
|
||||
expect(fibonacci(13)).toBe(233);
|
||||
});
|
||||
});
|
||||
|
||||
describe('isFibonacci', () => {
|
||||
test('returns true for 0', () => {
|
||||
expect(isFibonacci(0)).toBe(true);
|
||||
});
|
||||
|
||||
test('returns true for 1', () => {
|
||||
expect(isFibonacci(1)).toBe(true);
|
||||
});
|
||||
|
||||
test('returns true for 8', () => {
|
||||
expect(isFibonacci(8)).toBe(true);
|
||||
});
|
||||
|
||||
test('returns true for 13', () => {
|
||||
expect(isFibonacci(13)).toBe(true);
|
||||
});
|
||||
|
||||
test('returns false for 4', () => {
|
||||
expect(isFibonacci(4)).toBe(false);
|
||||
});
|
||||
|
||||
test('returns false for 6', () => {
|
||||
expect(isFibonacci(6)).toBe(false);
|
||||
});
|
||||
});
|
||||
|
||||
describe('isPerfectSquare', () => {
|
||||
test('returns true for 0', () => {
|
||||
expect(isPerfectSquare(0)).toBe(true);
|
||||
});
|
||||
|
||||
test('returns true for 1', () => {
|
||||
expect(isPerfectSquare(1)).toBe(true);
|
||||
});
|
||||
|
||||
test('returns true for 4', () => {
|
||||
expect(isPerfectSquare(4)).toBe(true);
|
||||
});
|
||||
|
||||
test('returns true for 16', () => {
|
||||
expect(isPerfectSquare(16)).toBe(true);
|
||||
});
|
||||
|
||||
test('returns false for 2', () => {
|
||||
expect(isPerfectSquare(2)).toBe(false);
|
||||
});
|
||||
|
||||
test('returns false for 3', () => {
|
||||
expect(isPerfectSquare(3)).toBe(false);
|
||||
});
|
||||
});
|
||||
|
||||
describe('fibonacciSequence', () => {
|
||||
test('returns empty array for n=0', () => {
|
||||
expect(fibonacciSequence(0)).toEqual([]);
|
||||
});
|
||||
|
||||
test('returns [0] for n=1', () => {
|
||||
expect(fibonacciSequence(1)).toEqual([0]);
|
||||
});
|
||||
|
||||
test('returns first 5 Fibonacci numbers', () => {
|
||||
expect(fibonacciSequence(5)).toEqual([0, 1, 1, 2, 3]);
|
||||
});
|
||||
|
||||
test('returns first 10 Fibonacci numbers', () => {
|
||||
expect(fibonacciSequence(10)).toEqual([0, 1, 1, 2, 3, 5, 8, 13, 21, 34]);
|
||||
});
|
||||
});
|
||||
|
|
@ -0,0 +1,281 @@
|
|||
#!/usr/bin/env node
|
||||
/**
|
||||
* Integration Test: Behavior Testing with Different Optimization Indices
|
||||
*
|
||||
* This script simulates the actual codeflash workflow:
|
||||
* 1. Run tests with CODEFLASH_LOOP_INDEX=1 (original code)
|
||||
* 2. Run tests with CODEFLASH_LOOP_INDEX=2 (optimized code)
|
||||
* 3. Read back both result files
|
||||
* 4. Compare using the comparator to verify equivalence
|
||||
*
|
||||
* Run directly: node tests/integration-behavior-test.js
|
||||
*/
|
||||
|
||||
const fs = require('fs');
|
||||
const path = require('path');
|
||||
const { execSync } = require('child_process');
|
||||
|
||||
// Import our modules from npm package
|
||||
const { serialize, deserialize, getSerializerType, comparator } = require('codeflash');
|
||||
|
||||
// Test configuration
|
||||
const TEST_DIR = '/tmp/codeflash_integration_test';
|
||||
const ORIGINAL_RESULTS = path.join(TEST_DIR, 'original_results.bin');
|
||||
const OPTIMIZED_RESULTS = path.join(TEST_DIR, 'optimized_results.bin');
|
||||
|
||||
// Sample function to test - this simulates the "function being optimized"
|
||||
function processData(input) {
|
||||
// Original implementation
|
||||
const result = {
|
||||
numbers: input.numbers.map(n => n * 2),
|
||||
sum: input.numbers.reduce((a, b) => a + b, 0),
|
||||
metadata: new Map([
|
||||
['processed', true],
|
||||
['timestamp', new Date()],
|
||||
]),
|
||||
tags: new Set(input.tags || []),
|
||||
};
|
||||
return result;
|
||||
}
|
||||
|
||||
// "Optimized" version - same behavior, different implementation
|
||||
function processDataOptimized(input) {
|
||||
// Optimized implementation (same behavior)
|
||||
const doubled = [];
|
||||
let sum = 0;
|
||||
for (const n of input.numbers) {
|
||||
doubled.push(n * 2);
|
||||
sum += n;
|
||||
}
|
||||
return {
|
||||
numbers: doubled,
|
||||
sum,
|
||||
metadata: new Map([
|
||||
['processed', true],
|
||||
['timestamp', new Date()],
|
||||
]),
|
||||
tags: new Set(input.tags || []),
|
||||
};
|
||||
}
|
||||
|
||||
// Test cases
|
||||
const testCases = [
|
||||
{ numbers: [1, 2, 3], tags: ['a', 'b'] },
|
||||
{ numbers: [10, 20, 30, 40] },
|
||||
{ numbers: [-5, 0, 5], tags: ['negative', 'zero', 'positive'] },
|
||||
{ numbers: [1.5, 2.5, 3.5] },
|
||||
{ numbers: [] },
|
||||
];
|
||||
|
||||
// Helper to run a function and capture behavior
|
||||
function captureAllBehaviors(fn, inputs) {
|
||||
const results = [];
|
||||
for (const input of inputs) {
|
||||
try {
|
||||
const returnValue = fn(input);
|
||||
// Remove timestamp from metadata for comparison (it will differ)
|
||||
if (returnValue.metadata) {
|
||||
returnValue.metadata.delete('timestamp');
|
||||
}
|
||||
results.push({
|
||||
success: true,
|
||||
args: [input],
|
||||
kwargs: {},
|
||||
returnValue,
|
||||
});
|
||||
} catch (error) {
|
||||
results.push({
|
||||
success: false,
|
||||
args: [input],
|
||||
kwargs: {},
|
||||
error: { name: error.name, message: error.message },
|
||||
});
|
||||
}
|
||||
}
|
||||
return results;
|
||||
}
|
||||
|
||||
// Main test function
|
||||
async function runIntegrationTest() {
|
||||
console.log('='.repeat(60));
|
||||
console.log('Integration Test: Behavior Comparison');
|
||||
console.log('='.repeat(60));
|
||||
console.log(`Serializer type: ${getSerializerType()}`);
|
||||
console.log();
|
||||
|
||||
// Setup
|
||||
if (fs.existsSync(TEST_DIR)) {
|
||||
fs.rmSync(TEST_DIR, { recursive: true });
|
||||
}
|
||||
fs.mkdirSync(TEST_DIR, { recursive: true });
|
||||
|
||||
// Phase 1: Run "original" code (LOOP_INDEX=1)
|
||||
console.log('Phase 1: Capturing original behavior...');
|
||||
const originalBehaviors = captureAllBehaviors(processData, testCases);
|
||||
const originalSerialized = serialize(originalBehaviors);
|
||||
fs.writeFileSync(ORIGINAL_RESULTS, originalSerialized);
|
||||
console.log(` - Captured ${originalBehaviors.length} invocations`);
|
||||
console.log(` - Serialized size: ${originalSerialized.length} bytes`);
|
||||
console.log(` - Saved to: ${ORIGINAL_RESULTS}`);
|
||||
console.log();
|
||||
|
||||
// Phase 2: Run "optimized" code (LOOP_INDEX=2)
|
||||
console.log('Phase 2: Capturing optimized behavior...');
|
||||
const optimizedBehaviors = captureAllBehaviors(processDataOptimized, testCases);
|
||||
const optimizedSerialized = serialize(optimizedBehaviors);
|
||||
fs.writeFileSync(OPTIMIZED_RESULTS, optimizedSerialized);
|
||||
console.log(` - Captured ${optimizedBehaviors.length} invocations`);
|
||||
console.log(` - Serialized size: ${optimizedSerialized.length} bytes`);
|
||||
console.log(` - Saved to: ${OPTIMIZED_RESULTS}`);
|
||||
console.log();
|
||||
|
||||
// Phase 3: Read back and compare
|
||||
console.log('Phase 3: Comparing behaviors...');
|
||||
const originalRestored = deserialize(fs.readFileSync(ORIGINAL_RESULTS));
|
||||
const optimizedRestored = deserialize(fs.readFileSync(OPTIMIZED_RESULTS));
|
||||
|
||||
console.log(` - Original results restored: ${originalRestored.length} invocations`);
|
||||
console.log(` - Optimized results restored: ${optimizedRestored.length} invocations`);
|
||||
console.log();
|
||||
|
||||
// Compare each invocation
|
||||
let allEqual = true;
|
||||
const comparisonResults = [];
|
||||
|
||||
for (let i = 0; i < originalRestored.length; i++) {
|
||||
const orig = originalRestored[i];
|
||||
const opt = optimizedRestored[i];
|
||||
|
||||
// Compare the behavior tuples
|
||||
const isEqual = comparator(
|
||||
[orig.args, orig.kwargs, orig.returnValue],
|
||||
[opt.args, opt.kwargs, opt.returnValue]
|
||||
);
|
||||
|
||||
comparisonResults.push({
|
||||
invocation: i,
|
||||
isEqual,
|
||||
args: orig.args,
|
||||
});
|
||||
|
||||
if (!isEqual) {
|
||||
allEqual = false;
|
||||
console.log(` ❌ Invocation ${i}: DIFFERENT`);
|
||||
console.log(` Args: ${JSON.stringify(orig.args)}`);
|
||||
} else {
|
||||
console.log(` ✓ Invocation ${i}: EQUAL`);
|
||||
}
|
||||
}
|
||||
|
||||
console.log();
|
||||
console.log('='.repeat(60));
|
||||
if (allEqual) {
|
||||
console.log('✅ SUCCESS: All behaviors are equivalent!');
|
||||
console.log(' The optimization preserves correctness.');
|
||||
} else {
|
||||
console.log('❌ FAILURE: Some behaviors differ!');
|
||||
console.log(' The optimization changed the behavior.');
|
||||
}
|
||||
console.log('='.repeat(60));
|
||||
|
||||
// Cleanup
|
||||
fs.rmSync(TEST_DIR, { recursive: true });
|
||||
|
||||
// Return result for programmatic use
|
||||
return { success: allEqual, results: comparisonResults };
|
||||
}
|
||||
|
||||
// Also test with a "broken" optimization
|
||||
async function runBrokenOptimizationTest() {
|
||||
console.log();
|
||||
console.log('='.repeat(60));
|
||||
console.log('Testing detection of broken optimization...');
|
||||
console.log('='.repeat(60));
|
||||
|
||||
// Setup
|
||||
if (!fs.existsSync(TEST_DIR)) {
|
||||
fs.mkdirSync(TEST_DIR, { recursive: true });
|
||||
}
|
||||
|
||||
// Original function
|
||||
const original = (x) => x * 2;
|
||||
|
||||
// "Broken" optimized function
|
||||
const brokenOptimized = (x) => x * 2 + 1; // Bug: adds 1
|
||||
|
||||
const inputs = [1, 5, 10, 100];
|
||||
|
||||
// Capture original
|
||||
const originalResults = inputs.map(x => ({
|
||||
args: [x],
|
||||
kwargs: {},
|
||||
returnValue: original(x),
|
||||
}));
|
||||
|
||||
// Capture broken optimized
|
||||
const brokenResults = inputs.map(x => ({
|
||||
args: [x],
|
||||
kwargs: {},
|
||||
returnValue: brokenOptimized(x),
|
||||
}));
|
||||
|
||||
// Serialize
|
||||
const originalSerialized = serialize(originalResults);
|
||||
const brokenSerialized = serialize(brokenResults);
|
||||
|
||||
// Compare
|
||||
const originalRestored = deserialize(originalSerialized);
|
||||
const brokenRestored = deserialize(brokenSerialized);
|
||||
|
||||
let detectedBug = false;
|
||||
for (let i = 0; i < originalRestored.length; i++) {
|
||||
const isEqual = comparator(
|
||||
[originalRestored[i].args, {}, originalRestored[i].returnValue],
|
||||
[brokenRestored[i].args, {}, brokenRestored[i].returnValue]
|
||||
);
|
||||
if (!isEqual) {
|
||||
detectedBug = true;
|
||||
console.log(` ❌ Invocation ${i}: Difference detected`);
|
||||
console.log(` Input: ${originalRestored[i].args[0]}`);
|
||||
console.log(` Original: ${originalRestored[i].returnValue}`);
|
||||
console.log(` Broken: ${brokenRestored[i].returnValue}`);
|
||||
}
|
||||
}
|
||||
|
||||
console.log();
|
||||
if (detectedBug) {
|
||||
console.log('✅ SUCCESS: Bug in optimization was detected!');
|
||||
} else {
|
||||
console.log('❌ FAILURE: Bug was not detected!');
|
||||
}
|
||||
console.log('='.repeat(60));
|
||||
|
||||
// Cleanup
|
||||
if (fs.existsSync(TEST_DIR)) {
|
||||
fs.rmSync(TEST_DIR, { recursive: true });
|
||||
}
|
||||
|
||||
return { success: detectedBug };
|
||||
}
|
||||
|
||||
// Run tests
|
||||
async function main() {
|
||||
try {
|
||||
const result1 = await runIntegrationTest();
|
||||
const result2 = await runBrokenOptimizationTest();
|
||||
|
||||
console.log();
|
||||
console.log('='.repeat(60));
|
||||
console.log('FINAL SUMMARY');
|
||||
console.log('='.repeat(60));
|
||||
console.log(`Correct optimization test: ${result1.success ? 'PASS' : 'FAIL'}`);
|
||||
console.log(`Broken optimization detection: ${result2.success ? 'PASS' : 'FAIL'}`);
|
||||
|
||||
process.exit(result1.success && result2.success ? 0 : 1);
|
||||
} catch (error) {
|
||||
console.error('Test failed with error:', error);
|
||||
process.exit(1);
|
||||
}
|
||||
}
|
||||
|
||||
main();
|
||||
|
|
@ -0,0 +1,294 @@
|
|||
#!/usr/bin/env node
|
||||
/**
|
||||
* Codeflash Jest Loop Runner
|
||||
*
|
||||
* This script runs Jest tests multiple times to collect stable performance measurements.
|
||||
* It mimics the Python pytest_plugin.py looping behavior.
|
||||
*
|
||||
* Usage:
|
||||
* node loop-runner.js <test-file> [options]
|
||||
*
|
||||
* Options:
|
||||
* --min-loops=N Minimum loops to run (default: 5)
|
||||
* --max-loops=N Maximum loops to run (default: 100000)
|
||||
* --duration=N Target duration in seconds (default: 10)
|
||||
* --stability-check Enable stability-based early stopping
|
||||
*/
|
||||
|
||||
const { spawn } = require('child_process');
|
||||
const path = require('path');
|
||||
|
||||
// Configuration
|
||||
const DEFAULT_MIN_LOOPS = 5;
|
||||
const DEFAULT_MAX_LOOPS = 100000;
|
||||
const DEFAULT_DURATION_SECONDS = 10;
|
||||
const STABILITY_WINDOW_SIZE = 0.35;
|
||||
const STABILITY_CENTER_TOLERANCE = 0.0025;
|
||||
const STABILITY_SPREAD_TOLERANCE = 0.0025;
|
||||
|
||||
/**
|
||||
* Parse timing data from Jest stdout.
|
||||
* Looks for patterns like: !######test:func:1:lineId_0:123456######!
|
||||
* where 123456 is the duration in nanoseconds.
|
||||
*/
|
||||
function parseTimingFromStdout(stdout) {
|
||||
const timings = new Map(); // Map<testId, number[]>
|
||||
const pattern = /!######([^:]+):([^:]*):([^:]+):([^:]+):(\d+_\d+):(\d+)######!/g;
|
||||
|
||||
let match;
|
||||
while ((match = pattern.exec(stdout)) !== null) {
|
||||
const [, testModule, testClass, testFunc, funcName, invocationId, durationNs] = match;
|
||||
const testId = `${testModule}:${testClass}:${testFunc}:${funcName}:${invocationId}`;
|
||||
|
||||
if (!timings.has(testId)) {
|
||||
timings.set(testId, []);
|
||||
}
|
||||
timings.get(testId).push(parseInt(durationNs, 10));
|
||||
}
|
||||
|
||||
return timings;
|
||||
}
|
||||
|
||||
/**
|
||||
* Run Jest once and return timing data.
|
||||
*/
|
||||
async function runJestOnce(testFile, loopIndex, timeout, cwd) {
|
||||
return new Promise((resolve, reject) => {
|
||||
const env = {
|
||||
...process.env,
|
||||
CODEFLASH_LOOP_INDEX: String(loopIndex),
|
||||
};
|
||||
|
||||
const jestArgs = [
|
||||
'jest',
|
||||
testFile,
|
||||
'--runInBand',
|
||||
'--forceExit',
|
||||
`--testTimeout=${timeout * 1000}`,
|
||||
];
|
||||
|
||||
const proc = spawn('npx', jestArgs, {
|
||||
cwd,
|
||||
env,
|
||||
stdio: ['pipe', 'pipe', 'pipe'],
|
||||
});
|
||||
|
||||
let stdout = '';
|
||||
let stderr = '';
|
||||
|
||||
proc.stdout.on('data', (data) => {
|
||||
stdout += data.toString();
|
||||
});
|
||||
|
||||
proc.stderr.on('data', (data) => {
|
||||
stderr += data.toString();
|
||||
});
|
||||
|
||||
proc.on('close', (code) => {
|
||||
resolve({
|
||||
code,
|
||||
stdout,
|
||||
stderr,
|
||||
timings: parseTimingFromStdout(stdout),
|
||||
});
|
||||
});
|
||||
|
||||
proc.on('error', reject);
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if performance has stabilized.
|
||||
* Implements the same stability check as Python's pytest_plugin.
|
||||
*/
|
||||
function shouldStopForStability(allTimings, windowSize) {
|
||||
// Get total runtime for each loop
|
||||
const loopTotals = [];
|
||||
for (const [loopIndex, timings] of allTimings.entries()) {
|
||||
let total = 0;
|
||||
for (const durations of timings.values()) {
|
||||
total += Math.min(...durations);
|
||||
}
|
||||
loopTotals.push(total);
|
||||
}
|
||||
|
||||
if (loopTotals.length < windowSize) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Get recent window
|
||||
const window = loopTotals.slice(-windowSize);
|
||||
|
||||
// Check center tolerance (all values within ±0.25% of median)
|
||||
const sorted = [...window].sort((a, b) => a - b);
|
||||
const median = sorted[Math.floor(sorted.length / 2)];
|
||||
const centerTolerance = median * STABILITY_CENTER_TOLERANCE;
|
||||
|
||||
const withinCenter = window.every(v => Math.abs(v - median) <= centerTolerance);
|
||||
|
||||
// Check spread tolerance (max-min ≤ 0.25% of min)
|
||||
const minVal = Math.min(...window);
|
||||
const maxVal = Math.max(...window);
|
||||
const spreadTolerance = minVal * STABILITY_SPREAD_TOLERANCE;
|
||||
const withinSpread = (maxVal - minVal) <= spreadTolerance;
|
||||
|
||||
return withinCenter && withinSpread;
|
||||
}
|
||||
|
||||
/**
|
||||
* Main loop runner.
|
||||
*/
|
||||
async function runLoopedTests(testFile, options = {}) {
|
||||
const minLoops = options.minLoops || DEFAULT_MIN_LOOPS;
|
||||
const maxLoops = options.maxLoops || DEFAULT_MAX_LOOPS;
|
||||
const durationSeconds = options.durationSeconds || DEFAULT_DURATION_SECONDS;
|
||||
const stabilityCheck = options.stabilityCheck !== false;
|
||||
const timeout = options.timeout || 15;
|
||||
const cwd = options.cwd || process.cwd();
|
||||
|
||||
console.log(`[codeflash-loop-runner] Starting looped test execution`);
|
||||
console.log(` Test file: ${testFile}`);
|
||||
console.log(` Min loops: ${minLoops}`);
|
||||
console.log(` Max loops: ${maxLoops}`);
|
||||
console.log(` Duration: ${durationSeconds}s`);
|
||||
console.log(` Stability check: ${stabilityCheck}`);
|
||||
console.log('');
|
||||
|
||||
const startTime = Date.now();
|
||||
const allTimings = new Map(); // Map<loopIndex, Map<testId, number[]>>
|
||||
let loopCount = 0;
|
||||
let lastExitCode = 0;
|
||||
|
||||
while (true) {
|
||||
loopCount++;
|
||||
const loopStart = Date.now();
|
||||
|
||||
console.log(`[loop ${loopCount}] Running...`);
|
||||
|
||||
const result = await runJestOnce(testFile, loopCount, timeout, cwd);
|
||||
lastExitCode = result.code;
|
||||
|
||||
// Store timings for this loop
|
||||
allTimings.set(loopCount, result.timings);
|
||||
|
||||
const loopDuration = Date.now() - loopStart;
|
||||
const totalElapsed = (Date.now() - startTime) / 1000;
|
||||
|
||||
// Count timing entries
|
||||
let timingCount = 0;
|
||||
for (const durations of result.timings.values()) {
|
||||
timingCount += durations.length;
|
||||
}
|
||||
|
||||
console.log(`[loop ${loopCount}] Completed in ${loopDuration}ms, ${timingCount} timing entries`);
|
||||
|
||||
// Check stopping conditions
|
||||
if (loopCount >= maxLoops) {
|
||||
console.log(`[codeflash-loop-runner] Reached max loops (${maxLoops})`);
|
||||
break;
|
||||
}
|
||||
|
||||
if (loopCount >= minLoops && totalElapsed >= durationSeconds) {
|
||||
console.log(`[codeflash-loop-runner] Reached duration limit (${durationSeconds}s)`);
|
||||
break;
|
||||
}
|
||||
|
||||
// Stability check
|
||||
if (stabilityCheck && loopCount >= minLoops) {
|
||||
const estimatedTotalLoops = Math.floor((durationSeconds / totalElapsed) * loopCount);
|
||||
const windowSize = Math.max(3, Math.floor(STABILITY_WINDOW_SIZE * estimatedTotalLoops));
|
||||
|
||||
if (shouldStopForStability(allTimings, windowSize)) {
|
||||
console.log(`[codeflash-loop-runner] Performance stabilized after ${loopCount} loops`);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Aggregate results
|
||||
const aggregatedTimings = new Map(); // Map<testId, {min, max, avg, count}>
|
||||
|
||||
for (const [loopIndex, timings] of allTimings.entries()) {
|
||||
for (const [testId, durations] of timings.entries()) {
|
||||
if (!aggregatedTimings.has(testId)) {
|
||||
aggregatedTimings.set(testId, { values: [], min: Infinity, max: 0, sum: 0, count: 0 });
|
||||
}
|
||||
const agg = aggregatedTimings.get(testId);
|
||||
for (const d of durations) {
|
||||
agg.values.push(d);
|
||||
agg.min = Math.min(agg.min, d);
|
||||
agg.max = Math.max(agg.max, d);
|
||||
agg.sum += d;
|
||||
agg.count++;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Print summary
|
||||
console.log('');
|
||||
console.log('=== Performance Summary ===');
|
||||
console.log(`Total loops: ${loopCount}`);
|
||||
console.log(`Total time: ${((Date.now() - startTime) / 1000).toFixed(2)}s`);
|
||||
console.log('');
|
||||
|
||||
for (const [testId, agg] of aggregatedTimings.entries()) {
|
||||
const avg = agg.sum / agg.count;
|
||||
console.log(`${testId}:`);
|
||||
console.log(` Min: ${(agg.min / 1000).toFixed(2)} μs`);
|
||||
console.log(` Max: ${(agg.max / 1000).toFixed(2)} μs`);
|
||||
console.log(` Avg: ${(avg / 1000).toFixed(2)} μs`);
|
||||
console.log(` Samples: ${agg.count}`);
|
||||
}
|
||||
|
||||
return {
|
||||
loopCount,
|
||||
allTimings,
|
||||
aggregatedTimings,
|
||||
exitCode: lastExitCode,
|
||||
};
|
||||
}
|
||||
|
||||
// CLI interface
|
||||
if (require.main === module) {
|
||||
const args = process.argv.slice(2);
|
||||
|
||||
if (args.length === 0 || args[0] === '--help') {
|
||||
console.log('Usage: node loop-runner.js <test-file> [options]');
|
||||
console.log('');
|
||||
console.log('Options:');
|
||||
console.log(' --min-loops=N Minimum loops to run (default: 5)');
|
||||
console.log(' --max-loops=N Maximum loops to run (default: 100000)');
|
||||
console.log(' --duration=N Target duration in seconds (default: 10)');
|
||||
console.log(' --stability-check Enable stability-based early stopping');
|
||||
console.log(' --cwd=PATH Working directory for Jest');
|
||||
process.exit(0);
|
||||
}
|
||||
|
||||
const testFile = args[0];
|
||||
const options = {};
|
||||
|
||||
for (const arg of args.slice(1)) {
|
||||
if (arg.startsWith('--min-loops=')) {
|
||||
options.minLoops = parseInt(arg.split('=')[1], 10);
|
||||
} else if (arg.startsWith('--max-loops=')) {
|
||||
options.maxLoops = parseInt(arg.split('=')[1], 10);
|
||||
} else if (arg.startsWith('--duration=')) {
|
||||
options.durationSeconds = parseFloat(arg.split('=')[1]);
|
||||
} else if (arg === '--stability-check') {
|
||||
options.stabilityCheck = true;
|
||||
} else if (arg.startsWith('--cwd=')) {
|
||||
options.cwd = arg.split('=')[1];
|
||||
}
|
||||
}
|
||||
|
||||
runLoopedTests(testFile, options)
|
||||
.then((result) => {
|
||||
process.exit(result.exitCode);
|
||||
})
|
||||
.catch((error) => {
|
||||
console.error('Error:', error);
|
||||
process.exit(1);
|
||||
});
|
||||
}
|
||||
|
||||
module.exports = { runLoopedTests, parseTimingFromStdout };
|
||||
|
|
@ -0,0 +1,35 @@
|
|||
/**
|
||||
* Test for session-level looping performance measurement.
|
||||
*
|
||||
* Note: Looping is now done at the session level by Python (test_runner.py)
|
||||
* which runs Jest multiple times. Each Jest run executes the test once,
|
||||
* and timing data is aggregated across runs for stability checking.
|
||||
*/
|
||||
|
||||
// Load the codeflash helper from npm package
|
||||
const codeflash = require('codeflash');
|
||||
|
||||
// Simple function to test
|
||||
function fibonacci(n) {
|
||||
if (n <= 1) return n;
|
||||
let a = 0, b = 1;
|
||||
for (let i = 2; i <= n; i++) {
|
||||
const temp = a + b;
|
||||
a = b;
|
||||
b = temp;
|
||||
}
|
||||
return b;
|
||||
}
|
||||
|
||||
describe('Session-Level Looping Performance Test', () => {
|
||||
test('fibonacci(20) with session-level looping', () => {
|
||||
// Looping is controlled by Python via CODEFLASH_LOOP_INDEX env var
|
||||
const result = codeflash.capturePerf('fibonacci', '10', fibonacci, 20);
|
||||
expect(result).toBe(6765);
|
||||
});
|
||||
|
||||
test('fibonacci(30) with session-level looping', () => {
|
||||
const result = codeflash.capturePerf('fibonacci', '16', fibonacci, 30);
|
||||
expect(result).toBe(832040);
|
||||
});
|
||||
});
|
||||
|
|
@ -0,0 +1,41 @@
|
|||
/**
|
||||
* Sample performance test to verify looping mechanism.
|
||||
*/
|
||||
|
||||
// Load the codeflash helper from npm package
|
||||
const codeflash = require('codeflash');
|
||||
|
||||
// Simple function to test
|
||||
function fibonacci(n) {
|
||||
if (n <= 1) return n;
|
||||
let a = 0, b = 1;
|
||||
for (let i = 2; i <= n; i++) {
|
||||
const temp = a + b;
|
||||
a = b;
|
||||
b = temp;
|
||||
}
|
||||
return b;
|
||||
}
|
||||
|
||||
describe('Looping Performance Test', () => {
|
||||
test('fibonacci(20) timing', () => {
|
||||
const result = codeflash.capturePerf('fibonacci', '10', fibonacci, 20);
|
||||
expect(result).toBe(6765);
|
||||
});
|
||||
|
||||
test('fibonacci(30) timing', () => {
|
||||
const result = codeflash.capturePerf('fibonacci', '16', fibonacci, 30);
|
||||
expect(result).toBe(832040);
|
||||
});
|
||||
|
||||
test('multiple calls in one test', () => {
|
||||
// Same lineId, multiple calls - should increment invocation counter
|
||||
const r1 = codeflash.capturePerf('fibonacci', '22', fibonacci, 5);
|
||||
const r2 = codeflash.capturePerf('fibonacci', '22', fibonacci, 10);
|
||||
const r3 = codeflash.capturePerf('fibonacci', '22', fibonacci, 15);
|
||||
|
||||
expect(r1).toBe(5);
|
||||
expect(r2).toBe(55);
|
||||
expect(r3).toBe(610);
|
||||
});
|
||||
});
|
||||
|
|
@ -0,0 +1,135 @@
|
|||
const {
|
||||
reverseString,
|
||||
isPalindrome,
|
||||
countOccurrences,
|
||||
longestCommonPrefix,
|
||||
toTitleCase
|
||||
} = require('../string_utils');
|
||||
|
||||
describe('reverseString', () => {
|
||||
test('reverses a simple string', () => {
|
||||
expect(reverseString('hello')).toBe('olleh');
|
||||
});
|
||||
|
||||
test('returns empty string for empty input', () => {
|
||||
expect(reverseString('')).toBe('');
|
||||
});
|
||||
|
||||
test('handles single character', () => {
|
||||
expect(reverseString('a')).toBe('a');
|
||||
});
|
||||
|
||||
test('handles palindrome', () => {
|
||||
expect(reverseString('radar')).toBe('radar');
|
||||
});
|
||||
|
||||
test('handles spaces', () => {
|
||||
expect(reverseString('hello world')).toBe('dlrow olleh');
|
||||
});
|
||||
|
||||
test('reverses a longer string for performance', () => {
|
||||
const input = 'abcdefghijklmnopqrstuvwxyz'.repeat(20);
|
||||
const result = reverseString(input);
|
||||
expect(result.length).toBe(input.length);
|
||||
expect(result[0]).toBe('z');
|
||||
expect(result[result.length - 1]).toBe('a');
|
||||
});
|
||||
|
||||
test('reverses a medium string', () => {
|
||||
const input = 'The quick brown fox jumps over the lazy dog';
|
||||
const expected = 'god yzal eht revo spmuj xof nworb kciuq ehT';
|
||||
expect(reverseString(input)).toBe(expected);
|
||||
});
|
||||
});
|
||||
|
||||
describe('isPalindrome', () => {
|
||||
test('returns true for simple palindrome', () => {
|
||||
expect(isPalindrome('radar')).toBe(true);
|
||||
});
|
||||
|
||||
test('returns true for palindrome with mixed case', () => {
|
||||
expect(isPalindrome('RaceCar')).toBe(true);
|
||||
});
|
||||
|
||||
test('returns true for palindrome with spaces and punctuation', () => {
|
||||
expect(isPalindrome('A man, a plan, a canal: Panama')).toBe(true);
|
||||
});
|
||||
|
||||
test('returns false for non-palindrome', () => {
|
||||
expect(isPalindrome('hello')).toBe(false);
|
||||
});
|
||||
|
||||
test('returns true for empty string', () => {
|
||||
expect(isPalindrome('')).toBe(true);
|
||||
});
|
||||
|
||||
test('returns true for single character', () => {
|
||||
expect(isPalindrome('a')).toBe(true);
|
||||
});
|
||||
});
|
||||
|
||||
describe('countOccurrences', () => {
|
||||
test('counts single occurrence', () => {
|
||||
expect(countOccurrences('hello', 'ell')).toBe(1);
|
||||
});
|
||||
|
||||
test('counts multiple occurrences', () => {
|
||||
expect(countOccurrences('abababab', 'ab')).toBe(4);
|
||||
});
|
||||
|
||||
test('returns 0 for no occurrences', () => {
|
||||
expect(countOccurrences('hello', 'xyz')).toBe(0);
|
||||
});
|
||||
|
||||
test('handles overlapping matches', () => {
|
||||
expect(countOccurrences('aaa', 'aa')).toBe(2);
|
||||
});
|
||||
|
||||
test('handles empty substring', () => {
|
||||
expect(countOccurrences('hello', '')).toBe(6);
|
||||
});
|
||||
});
|
||||
|
||||
describe('longestCommonPrefix', () => {
|
||||
test('finds common prefix', () => {
|
||||
expect(longestCommonPrefix(['flower', 'flow', 'flight'])).toBe('fl');
|
||||
});
|
||||
|
||||
test('returns empty for no common prefix', () => {
|
||||
expect(longestCommonPrefix(['dog', 'racecar', 'car'])).toBe('');
|
||||
});
|
||||
|
||||
test('returns empty for empty array', () => {
|
||||
expect(longestCommonPrefix([])).toBe('');
|
||||
});
|
||||
|
||||
test('returns the string for single element array', () => {
|
||||
expect(longestCommonPrefix(['hello'])).toBe('hello');
|
||||
});
|
||||
|
||||
test('handles identical strings', () => {
|
||||
expect(longestCommonPrefix(['test', 'test', 'test'])).toBe('test');
|
||||
});
|
||||
});
|
||||
|
||||
describe('toTitleCase', () => {
|
||||
test('converts simple string', () => {
|
||||
expect(toTitleCase('hello world')).toBe('Hello World');
|
||||
});
|
||||
|
||||
test('handles already title case', () => {
|
||||
expect(toTitleCase('Hello World')).toBe('Hello World');
|
||||
});
|
||||
|
||||
test('handles uppercase input', () => {
|
||||
expect(toTitleCase('HELLO WORLD')).toBe('Hello World');
|
||||
});
|
||||
|
||||
test('handles single word', () => {
|
||||
expect(toTitleCase('hello')).toBe('Hello');
|
||||
});
|
||||
|
||||
test('handles empty string', () => {
|
||||
expect(toTitleCase('')).toBe('');
|
||||
});
|
||||
});
|
||||
|
|
@ -0,0 +1,5 @@
|
|||
# Codeflash Configuration for CommonJS JavaScript Project
|
||||
module_root: "."
|
||||
tests_root: "tests"
|
||||
test_framework: "jest"
|
||||
formatter_cmds: []
|
||||
60
code_to_optimize/js/code_to_optimize_js_cjs/fibonacci.js
Normal file
60
code_to_optimize/js/code_to_optimize_js_cjs/fibonacci.js
Normal file
|
|
@ -0,0 +1,60 @@
|
|||
/**
|
||||
* Fibonacci implementations - CommonJS module
|
||||
* Intentionally inefficient for optimization testing.
|
||||
*/
|
||||
|
||||
/**
|
||||
* Calculate the nth Fibonacci number using naive recursion.
|
||||
* This is intentionally slow to demonstrate optimization potential.
|
||||
* @param {number} n - The index of the Fibonacci number to calculate
|
||||
* @returns {number} The nth Fibonacci number
|
||||
*/
|
||||
function fibonacci(n) {
|
||||
if (n <= 1) {
|
||||
return n;
|
||||
}
|
||||
return fibonacci(n - 1) + fibonacci(n - 2);
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if a number is a Fibonacci number.
|
||||
* @param {number} num - The number to check
|
||||
* @returns {boolean} True if num is a Fibonacci number
|
||||
*/
|
||||
function isFibonacci(num) {
|
||||
// A number is Fibonacci if one of (5*n*n + 4) or (5*n*n - 4) is a perfect square
|
||||
const check1 = 5 * num * num + 4;
|
||||
const check2 = 5 * num * num - 4;
|
||||
return isPerfectSquare(check1) || isPerfectSquare(check2);
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if a number is a perfect square.
|
||||
* @param {number} n - The number to check
|
||||
* @returns {boolean} True if n is a perfect square
|
||||
*/
|
||||
function isPerfectSquare(n) {
|
||||
const sqrt = Math.sqrt(n);
|
||||
return sqrt === Math.floor(sqrt);
|
||||
}
|
||||
|
||||
/**
|
||||
* Generate an array of Fibonacci numbers up to n.
|
||||
* @param {number} n - The number of Fibonacci numbers to generate
|
||||
* @returns {number[]} Array of Fibonacci numbers
|
||||
*/
|
||||
function fibonacciSequence(n) {
|
||||
const result = [];
|
||||
for (let i = 0; i < n; i++) {
|
||||
result.push(fibonacci(i));
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
// CommonJS exports
|
||||
module.exports = {
|
||||
fibonacci,
|
||||
isFibonacci,
|
||||
isPerfectSquare,
|
||||
fibonacciSequence,
|
||||
};
|
||||
|
|
@ -0,0 +1,61 @@
|
|||
/**
|
||||
* Fibonacci Calculator Class - CommonJS module
|
||||
* Intentionally inefficient for optimization testing.
|
||||
*/
|
||||
|
||||
class FibonacciCalculator {
|
||||
constructor() {
|
||||
// No initialization needed
|
||||
}
|
||||
|
||||
/**
|
||||
* Calculate the nth Fibonacci number using naive recursion.
|
||||
* This is intentionally slow to demonstrate optimization potential.
|
||||
* @param {number} n - The index of the Fibonacci number to calculate
|
||||
* @returns {number} The nth Fibonacci number
|
||||
*/
|
||||
fibonacci(n) {
|
||||
if (n <= 1) {
|
||||
return n;
|
||||
}
|
||||
return this.fibonacci(n - 1) + this.fibonacci(n - 2);
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if a number is a Fibonacci number.
|
||||
* @param {number} num - The number to check
|
||||
* @returns {boolean} True if num is a Fibonacci number
|
||||
*/
|
||||
isFibonacci(num) {
|
||||
// A number is Fibonacci if one of (5*n*n + 4) or (5*n*n - 4) is a perfect square
|
||||
const check1 = 5 * num * num + 4;
|
||||
const check2 = 5 * num * num - 4;
|
||||
return this.isPerfectSquare(check1) || this.isPerfectSquare(check2);
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if a number is a perfect square.
|
||||
* @param {number} n - The number to check
|
||||
* @returns {boolean} True if n is a perfect square
|
||||
*/
|
||||
isPerfectSquare(n) {
|
||||
const sqrt = Math.sqrt(n);
|
||||
return sqrt === Math.floor(sqrt);
|
||||
}
|
||||
|
||||
/**
|
||||
* Generate an array of Fibonacci numbers up to n.
|
||||
* @param {number} n - The number of Fibonacci numbers to generate
|
||||
* @returns {number[]} Array of Fibonacci numbers
|
||||
*/
|
||||
fibonacciSequence(n) {
|
||||
const result = [];
|
||||
for (let i = 0; i < n; i++) {
|
||||
result.push(this.fibonacci(i));
|
||||
}
|
||||
return result;
|
||||
}
|
||||
}
|
||||
|
||||
// CommonJS exports
|
||||
module.exports = { FibonacciCalculator };
|
||||
|
|
@ -0,0 +1,6 @@
|
|||
module.exports = {
|
||||
testEnvironment: 'node',
|
||||
testMatch: ['**/tests/**/*.test.js'],
|
||||
reporters: ['default', ['jest-junit', { outputDirectory: '.codeflash' }]],
|
||||
verbose: true,
|
||||
};
|
||||
3731
code_to_optimize/js/code_to_optimize_js_cjs/package-lock.json
generated
Normal file
3731
code_to_optimize/js/code_to_optimize_js_cjs/package-lock.json
generated
Normal file
File diff suppressed because it is too large
Load diff
14
code_to_optimize/js/code_to_optimize_js_cjs/package.json
Normal file
14
code_to_optimize/js/code_to_optimize_js_cjs/package.json
Normal file
|
|
@ -0,0 +1,14 @@
|
|||
{
|
||||
"name": "code-to-optimize-js-cjs",
|
||||
"version": "1.0.0",
|
||||
"description": "CommonJS JavaScript test project for Codeflash E2E testing",
|
||||
"main": "index.js",
|
||||
"scripts": {
|
||||
"test": "jest"
|
||||
},
|
||||
"devDependencies": {
|
||||
"codeflash": "file:../../../packages/codeflash",
|
||||
"jest": "^29.7.0",
|
||||
"jest-junit": "^16.0.0"
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,76 @@
|
|||
/**
|
||||
* Tests for Fibonacci functions - CommonJS module
|
||||
*/
|
||||
const { fibonacci, isFibonacci, isPerfectSquare, fibonacciSequence } = require('../fibonacci');
|
||||
|
||||
describe('fibonacci', () => {
|
||||
test('returns 0 for n=0', () => {
|
||||
expect(fibonacci(0)).toBe(0);
|
||||
});
|
||||
|
||||
test('returns 1 for n=1', () => {
|
||||
expect(fibonacci(1)).toBe(1);
|
||||
});
|
||||
|
||||
test('returns 1 for n=2', () => {
|
||||
expect(fibonacci(2)).toBe(1);
|
||||
});
|
||||
|
||||
test('returns 5 for n=5', () => {
|
||||
expect(fibonacci(5)).toBe(5);
|
||||
});
|
||||
|
||||
test('returns 55 for n=10', () => {
|
||||
expect(fibonacci(10)).toBe(55);
|
||||
});
|
||||
|
||||
test('returns 233 for n=13', () => {
|
||||
expect(fibonacci(13)).toBe(233);
|
||||
});
|
||||
});
|
||||
|
||||
describe('isFibonacci', () => {
|
||||
test('returns true for Fibonacci numbers', () => {
|
||||
expect(isFibonacci(0)).toBe(true);
|
||||
expect(isFibonacci(1)).toBe(true);
|
||||
expect(isFibonacci(5)).toBe(true);
|
||||
expect(isFibonacci(8)).toBe(true);
|
||||
expect(isFibonacci(13)).toBe(true);
|
||||
});
|
||||
|
||||
test('returns false for non-Fibonacci numbers', () => {
|
||||
expect(isFibonacci(4)).toBe(false);
|
||||
expect(isFibonacci(6)).toBe(false);
|
||||
expect(isFibonacci(7)).toBe(false);
|
||||
});
|
||||
});
|
||||
|
||||
describe('isPerfectSquare', () => {
|
||||
test('returns true for perfect squares', () => {
|
||||
expect(isPerfectSquare(0)).toBe(true);
|
||||
expect(isPerfectSquare(1)).toBe(true);
|
||||
expect(isPerfectSquare(4)).toBe(true);
|
||||
expect(isPerfectSquare(9)).toBe(true);
|
||||
expect(isPerfectSquare(16)).toBe(true);
|
||||
});
|
||||
|
||||
test('returns false for non-perfect squares', () => {
|
||||
expect(isPerfectSquare(2)).toBe(false);
|
||||
expect(isPerfectSquare(3)).toBe(false);
|
||||
expect(isPerfectSquare(5)).toBe(false);
|
||||
});
|
||||
});
|
||||
|
||||
describe('fibonacciSequence', () => {
|
||||
test('returns empty array for n=0', () => {
|
||||
expect(fibonacciSequence(0)).toEqual([]);
|
||||
});
|
||||
|
||||
test('returns first 5 Fibonacci numbers', () => {
|
||||
expect(fibonacciSequence(5)).toEqual([0, 1, 1, 2, 3]);
|
||||
});
|
||||
|
||||
test('returns first 10 Fibonacci numbers', () => {
|
||||
expect(fibonacciSequence(10)).toEqual([0, 1, 1, 2, 3, 5, 8, 13, 21, 34]);
|
||||
});
|
||||
});
|
||||
|
|
@ -0,0 +1,105 @@
|
|||
const { FibonacciCalculator } = require('../fibonacci_class');
|
||||
|
||||
describe('FibonacciCalculator', () => {
|
||||
let calc;
|
||||
|
||||
beforeEach(() => {
|
||||
calc = new FibonacciCalculator();
|
||||
});
|
||||
|
||||
describe('fibonacci', () => {
|
||||
test('returns 0 for n=0', () => {
|
||||
expect(calc.fibonacci(0)).toBe(0);
|
||||
});
|
||||
|
||||
test('returns 1 for n=1', () => {
|
||||
expect(calc.fibonacci(1)).toBe(1);
|
||||
});
|
||||
|
||||
test('returns 1 for n=2', () => {
|
||||
expect(calc.fibonacci(2)).toBe(1);
|
||||
});
|
||||
|
||||
test('returns 5 for n=5', () => {
|
||||
expect(calc.fibonacci(5)).toBe(5);
|
||||
});
|
||||
|
||||
test('returns 55 for n=10', () => {
|
||||
expect(calc.fibonacci(10)).toBe(55);
|
||||
});
|
||||
|
||||
test('returns 233 for n=13', () => {
|
||||
expect(calc.fibonacci(13)).toBe(233);
|
||||
});
|
||||
});
|
||||
|
||||
describe('isFibonacci', () => {
|
||||
test('returns true for 0', () => {
|
||||
expect(calc.isFibonacci(0)).toBe(true);
|
||||
});
|
||||
|
||||
test('returns true for 1', () => {
|
||||
expect(calc.isFibonacci(1)).toBe(true);
|
||||
});
|
||||
|
||||
test('returns true for 8', () => {
|
||||
expect(calc.isFibonacci(8)).toBe(true);
|
||||
});
|
||||
|
||||
test('returns true for 13', () => {
|
||||
expect(calc.isFibonacci(13)).toBe(true);
|
||||
});
|
||||
|
||||
test('returns false for 4', () => {
|
||||
expect(calc.isFibonacci(4)).toBe(false);
|
||||
});
|
||||
|
||||
test('returns false for 6', () => {
|
||||
expect(calc.isFibonacci(6)).toBe(false);
|
||||
});
|
||||
});
|
||||
|
||||
describe('isPerfectSquare', () => {
|
||||
test('returns true for 0', () => {
|
||||
expect(calc.isPerfectSquare(0)).toBe(true);
|
||||
});
|
||||
|
||||
test('returns true for 1', () => {
|
||||
expect(calc.isPerfectSquare(1)).toBe(true);
|
||||
});
|
||||
|
||||
test('returns true for 4', () => {
|
||||
expect(calc.isPerfectSquare(4)).toBe(true);
|
||||
});
|
||||
|
||||
test('returns true for 16', () => {
|
||||
expect(calc.isPerfectSquare(16)).toBe(true);
|
||||
});
|
||||
|
||||
test('returns false for 2', () => {
|
||||
expect(calc.isPerfectSquare(2)).toBe(false);
|
||||
});
|
||||
|
||||
test('returns false for 3', () => {
|
||||
expect(calc.isPerfectSquare(3)).toBe(false);
|
||||
});
|
||||
});
|
||||
|
||||
describe('fibonacciSequence', () => {
|
||||
test('returns empty array for n=0', () => {
|
||||
expect(calc.fibonacciSequence(0)).toEqual([]);
|
||||
});
|
||||
|
||||
test('returns [0] for n=1', () => {
|
||||
expect(calc.fibonacciSequence(1)).toEqual([0]);
|
||||
});
|
||||
|
||||
test('returns first 5 Fibonacci numbers', () => {
|
||||
expect(calc.fibonacciSequence(5)).toEqual([0, 1, 1, 2, 3]);
|
||||
});
|
||||
|
||||
test('returns first 10 Fibonacci numbers', () => {
|
||||
expect(calc.fibonacciSequence(10)).toEqual([0, 1, 1, 2, 3, 5, 8, 13, 21, 34]);
|
||||
});
|
||||
});
|
||||
});
|
||||
64
code_to_optimize/js/code_to_optimize_js_esm/async_utils.js
Normal file
64
code_to_optimize/js/code_to_optimize_js_esm/async_utils.js
Normal file
|
|
@ -0,0 +1,64 @@
|
|||
/**
|
||||
* Async utility functions - ES Module version.
|
||||
* Contains intentionally inefficient implementations for optimization testing.
|
||||
*/
|
||||
|
||||
/**
|
||||
* Simulate a delay (for testing purposes).
|
||||
* @param {number} ms - Milliseconds to delay
|
||||
* @returns {Promise<void>}
|
||||
*/
|
||||
export function delay(ms) {
|
||||
return new Promise(resolve => setTimeout(resolve, ms));
|
||||
}
|
||||
|
||||
/**
|
||||
* Process items sequentially when they could be parallel.
|
||||
* Intentionally inefficient - processes items one at a time.
|
||||
* @param {any[]} items - Items to process
|
||||
* @param {function} processor - Async function to process each item
|
||||
* @returns {Promise<any[]>} Processed results
|
||||
*/
|
||||
/**
 * Process items one at a time, awaiting each before starting the next.
 * Deliberately sequential (no parallelism) — kept this way for optimization testing.
 * @param {any[]} items - Items to process
 * @param {function} processor - Async function applied to each item
 * @returns {Promise<any[]>} Results in input order
 */
export async function processItemsSequential(items, processor) {
  const processed = [];
  for (const item of items) {
    // Await inside the loop: item N+1 does not start until item N resolves.
    processed.push(await processor(item));
  }
  return processed;
}
|
||||
|
||||
/**
|
||||
* Map over items asynchronously with a concurrency limit.
|
||||
* Intentionally simple/inefficient implementation - ignores concurrency.
|
||||
* @param {any[]} items - Items to process
|
||||
* @param {function} mapper - Async mapper function
|
||||
* @param {number} concurrency - Max concurrent operations (currently ignored)
|
||||
* @returns {Promise<any[]>} Mapped results
|
||||
*/
|
||||
/**
 * Map over items asynchronously with a concurrency limit.
 * NOTE: the concurrency argument is accepted but ignored — items are mapped
 * strictly one after another (intentionally simple for optimization testing).
 * @param {any[]} items - Items to map
 * @param {function} mapper - Async mapper function
 * @param {number} concurrency - Max concurrent operations (currently ignored)
 * @returns {Promise<any[]>} Mapped results in input order
 */
export async function asyncMap(items, mapper, concurrency = 1) {
  const mapped = [];
  for (let idx = 0; idx < items.length; idx += 1) {
    mapped.push(await mapper(items[idx]));
  }
  return mapped;
}
|
||||
|
||||
/**
|
||||
* Filter items asynchronously.
|
||||
* Inefficient implementation that processes items one by one.
|
||||
* @param {any[]} items - Items to filter
|
||||
* @param {function} predicate - Async predicate function
|
||||
* @returns {Promise<any[]>} Filtered items
|
||||
*/
|
||||
/**
 * Keep only the items for which the async predicate resolves truthy.
 * Evaluates the predicate for one item at a time (deliberately sequential).
 * @param {any[]} items - Items to filter
 * @param {function} predicate - Async predicate function
 * @returns {Promise<any[]>} Items that passed, in input order
 */
export async function asyncFilter(items, predicate) {
  const kept = [];
  for (let idx = 0; idx < items.length; idx += 1) {
    const candidate = items[idx];
    if (await predicate(candidate)) {
      kept.push(candidate);
    }
  }
  return kept;
}
|
||||
|
|
@ -0,0 +1,5 @@
|
|||
# Codeflash Configuration for ES Module JavaScript Project
|
||||
module_root: "."
|
||||
tests_root: "tests"
|
||||
test_framework: "jest"
|
||||
formatter_cmds: []
|
||||
52
code_to_optimize/js/code_to_optimize_js_esm/fibonacci.js
Normal file
52
code_to_optimize/js/code_to_optimize_js_esm/fibonacci.js
Normal file
|
|
@ -0,0 +1,52 @@
|
|||
/**
|
||||
* Fibonacci implementations - ES Module
|
||||
* Intentionally inefficient for optimization testing.
|
||||
*/
|
||||
|
||||
/**
|
||||
* Calculate the nth Fibonacci number using naive recursion.
|
||||
* This is intentionally slow to demonstrate optimization potential.
|
||||
* @param {number} n - The index of the Fibonacci number to calculate
|
||||
* @returns {number} The nth Fibonacci number
|
||||
*/
|
||||
/**
 * Calculate the nth Fibonacci number via naive double recursion.
 * Exponential-time on purpose — this repo uses it as optimization bait.
 * @param {number} n - The index of the Fibonacci number to calculate
 * @returns {number} The nth Fibonacci number
 */
export function fibonacci(n) {
  return n <= 1 ? n : fibonacci(n - 1) + fibonacci(n - 2);
}
|
||||
|
||||
/**
|
||||
* Check if a number is a Fibonacci number.
|
||||
* @param {number} num - The number to check
|
||||
* @returns {boolean} True if num is a Fibonacci number
|
||||
*/
|
||||
/**
 * Check whether a number belongs to the Fibonacci sequence.
 * Uses the classic identity: num is Fibonacci iff 5·num² + 4 or 5·num² − 4
 * is a perfect square.
 * @param {number} num - The number to check
 * @returns {boolean} True if num is a Fibonacci number
 */
export function isFibonacci(num) {
  const fiveNumSquared = 5 * num * num;
  return isPerfectSquare(fiveNumSquared + 4) || isPerfectSquare(fiveNumSquared - 4);
}
|
||||
|
||||
/**
|
||||
* Check if a number is a perfect square.
|
||||
* @param {number} n - The number to check
|
||||
* @returns {boolean} True if n is a perfect square
|
||||
*/
|
||||
/**
 * Check if a number is a perfect square.
 *
 * Fix over the previous version: `Math.sqrt(n) === Math.floor(Math.sqrt(n))`
 * can misreport near-squares once n approaches 2^53, because the correctly
 * rounded float sqrt of k² ± 1 may land exactly on k. Rounding the root and
 * squaring it back verifies the answer in exact integer arithmetic for all
 * safe-integer inputs.
 * @param {number} n - The number to check
 * @returns {boolean} True if n is a non-negative integer perfect square
 */
export function isPerfectSquare(n) {
  // Negatives and non-integers can never be perfect squares.
  if (!Number.isInteger(n) || n < 0) {
    return false;
  }
  const root = Math.round(Math.sqrt(n));
  return root * root === n;
}
|
||||
|
||||
/**
|
||||
* Generate an array of Fibonacci numbers up to n.
|
||||
* @param {number} n - The number of Fibonacci numbers to generate
|
||||
* @returns {number[]} Array of Fibonacci numbers
|
||||
*/
|
||||
/**
 * Generate the first n Fibonacci numbers.
 * Recomputes each term from scratch via fibonacci(i) — intentionally wasteful.
 * @param {number} n - The number of Fibonacci numbers to generate
 * @returns {number[]} Array of Fibonacci numbers
 */
export function fibonacciSequence(n) {
  const sequence = [];
  let index = 0;
  while (index < n) {
    sequence.push(fibonacci(index));
    index += 1;
  }
  return sequence;
}
|
||||
11
code_to_optimize/js/code_to_optimize_js_esm/jest.config.cjs
Normal file
11
code_to_optimize/js/code_to_optimize_js_esm/jest.config.cjs
Normal file
|
|
@ -0,0 +1,11 @@
|
|||
// Jest config for ES Module project (using .cjs since package is type: module)
|
||||
module.exports = {
|
||||
testEnvironment: 'node',
|
||||
testMatch: ['**/tests/**/*.test.js'],
|
||||
reporters: ['default', ['jest-junit', { outputDirectory: '.codeflash' }]],
|
||||
verbose: true,
|
||||
transform: {},
|
||||
// Tell Jest to also look for modules in the project's node_modules when
|
||||
// resolving modules from symlinked packages (like codeflash)
|
||||
moduleDirectories: ['node_modules', '<rootDir>/node_modules'],
|
||||
};
|
||||
23
code_to_optimize/js/code_to_optimize_js_esm/package.json
Normal file
23
code_to_optimize/js/code_to_optimize_js_esm/package.json
Normal file
|
|
@ -0,0 +1,23 @@
|
|||
{
|
||||
"name": "code-to-optimize-js-esm",
|
||||
"version": "1.0.0",
|
||||
"description": "ES Module JavaScript test project for Codeflash E2E testing",
|
||||
"type": "module",
|
||||
"main": "index.js",
|
||||
"scripts": {
|
||||
"test": "NODE_OPTIONS='--experimental-vm-modules' jest"
|
||||
},
|
||||
"devDependencies": {
|
||||
"@eslint/js": "^9.39.2",
|
||||
"codeflash": "file:../../../packages/codeflash",
|
||||
"eslint": "^9.39.2",
|
||||
"globals": "^17.1.0",
|
||||
"jest": "^29.7.0",
|
||||
"jest-junit": "^16.0.0"
|
||||
},
|
||||
"codeflash": {
|
||||
"moduleRoot": ".",
|
||||
"testsRoot": "tests",
|
||||
"disableTelemetry": true
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,85 @@
|
|||
/**
|
||||
* Tests for async utility functions - ES Module
|
||||
*/
|
||||
import { delay, processItemsSequential, asyncMap, asyncFilter } from '../async_utils.js';
|
||||
|
||||
describe('processItemsSequential', () => {
|
||||
test('processes all items', async () => {
|
||||
const items = [1, 2, 3, 4, 5];
|
||||
const processor = async (x) => x * 2;
|
||||
const results = await processItemsSequential(items, processor);
|
||||
expect(results).toEqual([2, 4, 6, 8, 10]);
|
||||
});
|
||||
|
||||
test('handles empty array', async () => {
|
||||
const results = await processItemsSequential([], async (x) => x);
|
||||
expect(results).toEqual([]);
|
||||
});
|
||||
|
||||
test('handles async operations with delays', async () => {
|
||||
const items = [1, 2, 3];
|
||||
const processor = async (x) => {
|
||||
await delay(1);
|
||||
return x + 10;
|
||||
};
|
||||
const results = await processItemsSequential(items, processor);
|
||||
expect(results).toEqual([11, 12, 13]);
|
||||
});
|
||||
|
||||
test('preserves order', async () => {
|
||||
const items = [5, 4, 3, 2, 1];
|
||||
const processor = async (x) => x.toString();
|
||||
const results = await processItemsSequential(items, processor);
|
||||
expect(results).toEqual(['5', '4', '3', '2', '1']);
|
||||
});
|
||||
|
||||
test('handles larger arrays', async () => {
|
||||
const items = Array.from({ length: 20 }, (_, i) => i);
|
||||
const processor = async (x) => x * 2;
|
||||
const results = await processItemsSequential(items, processor);
|
||||
expect(results.length).toBe(20);
|
||||
expect(results[0]).toBe(0);
|
||||
expect(results[19]).toBe(38);
|
||||
});
|
||||
});
|
||||
|
||||
describe('asyncMap', () => {
|
||||
test('maps all items', async () => {
|
||||
const items = [1, 2, 3];
|
||||
const mapper = async (x) => x * 10;
|
||||
const results = await asyncMap(items, mapper);
|
||||
expect(results).toEqual([10, 20, 30]);
|
||||
});
|
||||
|
||||
test('handles empty array', async () => {
|
||||
const results = await asyncMap([], async (x) => x);
|
||||
expect(results).toEqual([]);
|
||||
});
|
||||
|
||||
test('handles objects', async () => {
|
||||
const items = [{ a: 1 }, { a: 2 }];
|
||||
const mapper = async (obj) => ({ ...obj, b: obj.a * 2 });
|
||||
const results = await asyncMap(items, mapper);
|
||||
expect(results).toEqual([{ a: 1, b: 2 }, { a: 2, b: 4 }]);
|
||||
});
|
||||
});
|
||||
|
||||
describe('asyncFilter', () => {
|
||||
test('filters items based on predicate', async () => {
|
||||
const items = [1, 2, 3, 4, 5, 6];
|
||||
const predicate = async (x) => x % 2 === 0;
|
||||
const results = await asyncFilter(items, predicate);
|
||||
expect(results).toEqual([2, 4, 6]);
|
||||
});
|
||||
|
||||
test('handles empty array', async () => {
|
||||
const results = await asyncFilter([], async () => true);
|
||||
expect(results).toEqual([]);
|
||||
});
|
||||
|
||||
test('handles all items filtered out', async () => {
|
||||
const items = [1, 2, 3];
|
||||
const results = await asyncFilter(items, async () => false);
|
||||
expect(results).toEqual([]);
|
||||
});
|
||||
});
|
||||
|
|
@ -0,0 +1,76 @@
|
|||
/**
|
||||
* Tests for Fibonacci functions - ES Module
|
||||
*/
|
||||
import { fibonacci, isFibonacci, isPerfectSquare, fibonacciSequence } from '../fibonacci.js';
|
||||
|
||||
describe('fibonacci', () => {
|
||||
test('returns 0 for n=0', () => {
|
||||
expect(fibonacci(0)).toBe(0);
|
||||
});
|
||||
|
||||
test('returns 1 for n=1', () => {
|
||||
expect(fibonacci(1)).toBe(1);
|
||||
});
|
||||
|
||||
test('returns 1 for n=2', () => {
|
||||
expect(fibonacci(2)).toBe(1);
|
||||
});
|
||||
|
||||
test('returns 5 for n=5', () => {
|
||||
expect(fibonacci(5)).toBe(5);
|
||||
});
|
||||
|
||||
test('returns 55 for n=10', () => {
|
||||
expect(fibonacci(10)).toBe(55);
|
||||
});
|
||||
|
||||
test('returns 233 for n=13', () => {
|
||||
expect(fibonacci(13)).toBe(233);
|
||||
});
|
||||
});
|
||||
|
||||
describe('isFibonacci', () => {
|
||||
test('returns true for Fibonacci numbers', () => {
|
||||
expect(isFibonacci(0)).toBe(true);
|
||||
expect(isFibonacci(1)).toBe(true);
|
||||
expect(isFibonacci(5)).toBe(true);
|
||||
expect(isFibonacci(8)).toBe(true);
|
||||
expect(isFibonacci(13)).toBe(true);
|
||||
});
|
||||
|
||||
test('returns false for non-Fibonacci numbers', () => {
|
||||
expect(isFibonacci(4)).toBe(false);
|
||||
expect(isFibonacci(6)).toBe(false);
|
||||
expect(isFibonacci(7)).toBe(false);
|
||||
});
|
||||
});
|
||||
|
||||
describe('isPerfectSquare', () => {
|
||||
test('returns true for perfect squares', () => {
|
||||
expect(isPerfectSquare(0)).toBe(true);
|
||||
expect(isPerfectSquare(1)).toBe(true);
|
||||
expect(isPerfectSquare(4)).toBe(true);
|
||||
expect(isPerfectSquare(9)).toBe(true);
|
||||
expect(isPerfectSquare(16)).toBe(true);
|
||||
});
|
||||
|
||||
test('returns false for non-perfect squares', () => {
|
||||
expect(isPerfectSquare(2)).toBe(false);
|
||||
expect(isPerfectSquare(3)).toBe(false);
|
||||
expect(isPerfectSquare(5)).toBe(false);
|
||||
});
|
||||
});
|
||||
|
||||
describe('fibonacciSequence', () => {
|
||||
test('returns empty array for n=0', () => {
|
||||
expect(fibonacciSequence(0)).toEqual([]);
|
||||
});
|
||||
|
||||
test('returns first 5 Fibonacci numbers', () => {
|
||||
expect(fibonacciSequence(5)).toEqual([0, 1, 1, 2, 3]);
|
||||
});
|
||||
|
||||
test('returns first 10 Fibonacci numbers', () => {
|
||||
expect(fibonacciSequence(10)).toEqual([0, 1, 1, 2, 3, 5, 8, 13, 21, 34]);
|
||||
});
|
||||
});
|
||||
63
code_to_optimize/js/code_to_optimize_ts/bubble_sort.ts
Normal file
63
code_to_optimize/js/code_to_optimize_ts/bubble_sort.ts
Normal file
|
|
@ -0,0 +1,63 @@
|
|||
/**
|
||||
* Bubble sort implementation - intentionally inefficient for optimization testing.
|
||||
*/
|
||||
|
||||
/**
|
||||
* Sort an array using bubble sort algorithm.
|
||||
* @param arr - The array to sort
|
||||
* @returns A new sorted array
|
||||
*/
|
||||
export function bubbleSort<T>(arr: T[]): T[] {
|
||||
const result = [...arr];
|
||||
const n = result.length;
|
||||
|
||||
for (let i = 0; i < n - 1; i++) {
|
||||
for (let j = 0; j < n - i - 1; j++) {
|
||||
if (result[j] > result[j + 1]) {
|
||||
// Swap elements
|
||||
const temp = result[j];
|
||||
result[j] = result[j + 1];
|
||||
result[j + 1] = temp;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
/**
|
||||
* Sort an array in descending order using bubble sort.
|
||||
* @param arr - The array to sort
|
||||
* @returns A new sorted array (descending)
|
||||
*/
|
||||
export function bubbleSortDescending<T>(arr: T[]): T[] {
|
||||
const result = [...arr];
|
||||
const n = result.length;
|
||||
|
||||
for (let i = 0; i < n - 1; i++) {
|
||||
for (let j = 0; j < n - i - 1; j++) {
|
||||
if (result[j] < result[j + 1]) {
|
||||
// Swap elements
|
||||
const temp = result[j];
|
||||
result[j] = result[j + 1];
|
||||
result[j + 1] = temp;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if an array is sorted in ascending order.
|
||||
* @param arr - The array to check
|
||||
* @returns True if the array is sorted in ascending order
|
||||
*/
|
||||
export function isSorted<T>(arr: T[]): boolean {
|
||||
for (let i = 0; i < arr.length - 1; i++) {
|
||||
if (arr[i] > arr[i + 1]) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
2
code_to_optimize/js/code_to_optimize_ts/codeflash.yaml
Normal file
2
code_to_optimize/js/code_to_optimize_ts/codeflash.yaml
Normal file
|
|
@ -0,0 +1,2 @@
|
|||
module_root: .
|
||||
tests_root: tests
|
||||
88
code_to_optimize/js/code_to_optimize_ts/data_processor.ts
Normal file
88
code_to_optimize/js/code_to_optimize_ts/data_processor.ts
Normal file
|
|
@ -0,0 +1,88 @@
|
|||
/**
|
||||
* DataProcessor class - demonstrates class method optimization in TypeScript.
|
||||
* Contains intentionally inefficient implementations for optimization testing.
|
||||
*/
|
||||
|
||||
/**
|
||||
* A class for processing data arrays with various operations.
|
||||
*/
|
||||
export class DataProcessor<T> {
|
||||
private data: T[];
|
||||
|
||||
/**
|
||||
* Create a DataProcessor instance.
|
||||
* @param data - Initial data array
|
||||
*/
|
||||
constructor(data: T[] = []) {
|
||||
this.data = [...data];
|
||||
}
|
||||
|
||||
/**
|
||||
* Find duplicates in the data array.
|
||||
* Intentionally inefficient O(n²) implementation.
|
||||
* @returns Array of duplicate values
|
||||
*/
|
||||
findDuplicates(): T[] {
|
||||
const duplicates: T[] = [];
|
||||
for (let i = 0; i < this.data.length; i++) {
|
||||
for (let j = i + 1; j < this.data.length; j++) {
|
||||
if (this.data[i] === this.data[j]) {
|
||||
if (!duplicates.includes(this.data[i])) {
|
||||
duplicates.push(this.data[i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return duplicates;
|
||||
}
|
||||
|
||||
/**
|
||||
* Sort the data using bubble sort.
|
||||
* Intentionally inefficient O(n²) implementation.
|
||||
* @returns Sorted copy of the data
|
||||
*/
|
||||
sortData(): T[] {
|
||||
const result = [...this.data];
|
||||
const n = result.length;
|
||||
for (let i = 0; i < n; i++) {
|
||||
for (let j = 0; j < n - 1; j++) {
|
||||
if (result[j] > result[j + 1]) {
|
||||
const temp = result[j];
|
||||
result[j] = result[j + 1];
|
||||
result[j + 1] = temp;
|
||||
}
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get unique values from the data.
|
||||
* Intentionally inefficient O(n²) implementation.
|
||||
* @returns Array of unique values
|
||||
*/
|
||||
getUnique(): T[] {
|
||||
const unique: T[] = [];
|
||||
for (let i = 0; i < this.data.length; i++) {
|
||||
let found = false;
|
||||
for (let j = 0; j < unique.length; j++) {
|
||||
if (unique[j] === this.data[i]) {
|
||||
found = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (!found) {
|
||||
unique.push(this.data[i]);
|
||||
}
|
||||
}
|
||||
return unique;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the data array.
|
||||
* @returns The data array
|
||||
*/
|
||||
getData(): T[] {
|
||||
return [...this.data];
|
||||
}
|
||||
}
|
||||
52
code_to_optimize/js/code_to_optimize_ts/fibonacci.ts
Normal file
52
code_to_optimize/js/code_to_optimize_ts/fibonacci.ts
Normal file
|
|
@ -0,0 +1,52 @@
|
|||
/**
|
||||
* Fibonacci implementations - intentionally inefficient for optimization testing.
|
||||
*/
|
||||
|
||||
/**
|
||||
* Calculate the nth Fibonacci number using naive recursion.
|
||||
* This is intentionally slow to demonstrate optimization potential.
|
||||
* @param n - The index of the Fibonacci number to calculate
|
||||
* @returns The nth Fibonacci number
|
||||
*/
|
||||
export function fibonacci(n: number): number {
|
||||
if (n <= 1) {
|
||||
return n;
|
||||
}
|
||||
return fibonacci(n - 1) + fibonacci(n - 2);
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if a number is a Fibonacci number.
|
||||
* @param num - The number to check
|
||||
* @returns True if num is a Fibonacci number
|
||||
*/
|
||||
export function isFibonacci(num: number): boolean {
|
||||
// A number is Fibonacci if one of (5*n*n + 4) or (5*n*n - 4) is a perfect square
|
||||
const check1 = 5 * num * num + 4;
|
||||
const check2 = 5 * num * num - 4;
|
||||
|
||||
return isPerfectSquare(check1) || isPerfectSquare(check2);
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if a number is a perfect square.
|
||||
* @param n - The number to check
|
||||
* @returns True if n is a perfect square
|
||||
*/
|
||||
export function isPerfectSquare(n: number): boolean {
|
||||
const sqrt = Math.sqrt(n);
|
||||
return sqrt === Math.floor(sqrt);
|
||||
}
|
||||
|
||||
/**
|
||||
* Generate an array of Fibonacci numbers up to n.
|
||||
* @param n - The number of Fibonacci numbers to generate
|
||||
* @returns Array of Fibonacci numbers
|
||||
*/
|
||||
export function fibonacciSequence(n: number): number[] {
|
||||
const result: number[] = [];
|
||||
for (let i = 0; i < n; i++) {
|
||||
result.push(fibonacci(i));
|
||||
}
|
||||
return result;
|
||||
}
|
||||
35
code_to_optimize/js/code_to_optimize_ts/jest.config.ts
Normal file
35
code_to_optimize/js/code_to_optimize_ts/jest.config.ts
Normal file
|
|
@ -0,0 +1,35 @@
|
|||
import type { Config } from 'jest';
|
||||
|
||||
const config: Config = {
|
||||
preset: 'ts-jest',
|
||||
testEnvironment: 'node',
|
||||
testMatch: [
|
||||
'**/tests/**/*.test.ts',
|
||||
'**/tests/**/*.spec.ts'
|
||||
],
|
||||
moduleFileExtensions: ['ts', 'tsx', 'js', 'jsx', 'json', 'node'],
|
||||
collectCoverageFrom: [
|
||||
'**/*.ts',
|
||||
'!**/node_modules/**',
|
||||
'!**/dist/**',
|
||||
'!jest.config.ts'
|
||||
],
|
||||
reporters: [
|
||||
'default',
|
||||
[
|
||||
'jest-junit',
|
||||
{
|
||||
outputDirectory: '.codeflash',
|
||||
outputName: 'jest-results.xml',
|
||||
includeConsoleOutput: true
|
||||
}
|
||||
]
|
||||
],
|
||||
transform: {
|
||||
'^.+\\.tsx?$': ['ts-jest', {
|
||||
useESM: false
|
||||
}]
|
||||
}
|
||||
};
|
||||
|
||||
export default config;
|
||||
4085
code_to_optimize/js/code_to_optimize_ts/package-lock.json
generated
Normal file
4085
code_to_optimize/js/code_to_optimize_ts/package-lock.json
generated
Normal file
File diff suppressed because it is too large
Load diff
33
code_to_optimize/js/code_to_optimize_ts/package.json
Normal file
33
code_to_optimize/js/code_to_optimize_ts/package.json
Normal file
|
|
@ -0,0 +1,33 @@
|
|||
{
|
||||
"name": "codeflash-ts-test",
|
||||
"version": "1.0.0",
|
||||
"description": "Sample TypeScript project for codeflash optimization testing",
|
||||
"main": "dist/index.js",
|
||||
"scripts": {
|
||||
"test": "jest",
|
||||
"test:coverage": "jest --coverage",
|
||||
"build": "tsc"
|
||||
},
|
||||
"codeflash": {
|
||||
"moduleRoot": ".",
|
||||
"testsRoot": "tests"
|
||||
},
|
||||
"keywords": [
|
||||
"codeflash",
|
||||
"optimization",
|
||||
"testing",
|
||||
"typescript"
|
||||
],
|
||||
"author": "CodeFlash Inc.",
|
||||
"license": "BSL 1.1",
|
||||
"devDependencies": {
|
||||
"@types/jest": "^29.5.0",
|
||||
"@types/node": "^20.0.0",
|
||||
"codeflash": "file:../../../packages/codeflash",
|
||||
"jest": "^29.7.0",
|
||||
"jest-junit": "^16.0.0",
|
||||
"ts-jest": "^29.1.0",
|
||||
"ts-node": "^10.9.2",
|
||||
"typescript": "^5.0.0"
|
||||
}
|
||||
}
|
||||
84
code_to_optimize/js/code_to_optimize_ts/string_utils.ts
Normal file
84
code_to_optimize/js/code_to_optimize_ts/string_utils.ts
Normal file
|
|
@ -0,0 +1,84 @@
|
|||
/**
|
||||
* String utility functions - intentionally inefficient for optimization testing.
|
||||
*/
|
||||
|
||||
/**
|
||||
* Reverse a string character by character.
|
||||
* This is intentionally inefficient O(n²) - rebuilds result string each iteration.
|
||||
* @param str - The string to reverse
|
||||
* @returns The reversed string
|
||||
*/
|
||||
export function reverseString(str: string): string {
|
||||
// Intentionally inefficient O(n²) implementation for testing
|
||||
let result = '';
|
||||
for (let i = str.length - 1; i >= 0; i--) {
|
||||
// Rebuild the entire result string each iteration (very inefficient)
|
||||
let temp = '';
|
||||
for (let j = 0; j < result.length; j++) {
|
||||
temp += result[j];
|
||||
}
|
||||
temp += str[i];
|
||||
result = temp;
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if a string is a palindrome.
|
||||
* @param str - The string to check
|
||||
* @returns True if the string is a palindrome
|
||||
*/
|
||||
export function isPalindrome(str: string): boolean {
|
||||
const cleaned = str.toLowerCase().replace(/[^a-z0-9]/g, '');
|
||||
return cleaned === reverseString(cleaned);
|
||||
}
|
||||
|
||||
/**
|
||||
* Count occurrences of a substring in a string.
|
||||
* @param str - The string to search in
|
||||
* @param substr - The substring to count
|
||||
* @returns The number of occurrences
|
||||
*/
|
||||
export function countOccurrences(str: string, substr: string): number {
|
||||
let count = 0;
|
||||
let pos = 0;
|
||||
|
||||
while ((pos = str.indexOf(substr, pos)) !== -1) {
|
||||
count++;
|
||||
pos += 1; // Move forward to find overlapping occurrences
|
||||
}
|
||||
|
||||
return count;
|
||||
}
|
||||
|
||||
/**
|
||||
* Find the longest common prefix among an array of strings.
|
||||
* @param strs - Array of strings
|
||||
* @returns The longest common prefix
|
||||
*/
|
||||
export function longestCommonPrefix(strs: string[]): string {
|
||||
if (strs.length === 0) return '';
|
||||
if (strs.length === 1) return strs[0];
|
||||
|
||||
let prefix = strs[0];
|
||||
for (let i = 1; i < strs.length; i++) {
|
||||
while (strs[i].indexOf(prefix) !== 0) {
|
||||
prefix = prefix.slice(0, -1);
|
||||
if (prefix === '') return '';
|
||||
}
|
||||
}
|
||||
return prefix;
|
||||
}
|
||||
|
||||
/**
|
||||
* Convert a string to title case.
|
||||
* @param str - The string to convert
|
||||
* @returns The title-cased string
|
||||
*/
|
||||
export function toTitleCase(str: string): string {
|
||||
return str
|
||||
.toLowerCase()
|
||||
.split(' ')
|
||||
.map(word => word.charAt(0).toUpperCase() + word.slice(1))
|
||||
.join(' ');
|
||||
}
|
||||
|
|
@ -0,0 +1,92 @@
|
|||
import { bubbleSort, bubbleSortDescending, isSorted } from '../bubble_sort';
|
||||
|
||||
describe('bubbleSort', () => {
|
||||
test('sorts an empty array', () => {
|
||||
expect(bubbleSort([])).toEqual([]);
|
||||
});
|
||||
|
||||
test('sorts a single element array', () => {
|
||||
expect(bubbleSort([1])).toEqual([1]);
|
||||
});
|
||||
|
||||
test('sorts an already sorted array', () => {
|
||||
expect(bubbleSort([1, 2, 3, 4, 5])).toEqual([1, 2, 3, 4, 5]);
|
||||
});
|
||||
|
||||
test('sorts a reverse sorted array', () => {
|
||||
expect(bubbleSort([5, 4, 3, 2, 1])).toEqual([1, 2, 3, 4, 5]);
|
||||
});
|
||||
|
||||
test('sorts an unsorted array', () => {
|
||||
expect(bubbleSort([3, 1, 4, 1, 5, 9, 2, 6])).toEqual([1, 1, 2, 3, 4, 5, 6, 9]);
|
||||
});
|
||||
|
||||
test('handles negative numbers', () => {
|
||||
expect(bubbleSort([-3, -1, -4, -1, -5])).toEqual([-5, -4, -3, -1, -1]);
|
||||
});
|
||||
|
||||
test('handles mixed positive and negative', () => {
|
||||
expect(bubbleSort([3, -1, 4, -1, 5])).toEqual([-1, -1, 3, 4, 5]);
|
||||
});
|
||||
|
||||
test('does not mutate original array', () => {
|
||||
const original = [3, 1, 2];
|
||||
bubbleSort(original);
|
||||
expect(original).toEqual([3, 1, 2]);
|
||||
});
|
||||
|
||||
test('sorts a larger reverse sorted array for performance', () => {
|
||||
const input: number[] = [];
|
||||
for (let i = 500; i >= 0; i--) {
|
||||
input.push(i);
|
||||
}
|
||||
const result = bubbleSort(input);
|
||||
expect(result[0]).toBe(0);
|
||||
expect(result[result.length - 1]).toBe(500);
|
||||
});
|
||||
|
||||
test('sorts a larger random array for performance', () => {
|
||||
const input = [
|
||||
42, 17, 93, 8, 67, 31, 55, 22, 89, 4,
|
||||
76, 12, 39, 58, 95, 26, 71, 48, 83, 19,
|
||||
64, 3, 88, 37, 52, 11, 79, 46, 91, 28,
|
||||
63, 7, 84, 33, 57, 14, 72, 41, 96, 24,
|
||||
69, 6, 81, 36, 54, 16, 77, 44, 90, 29
|
||||
];
|
||||
const result = bubbleSort(input);
|
||||
expect(result[0]).toBe(3);
|
||||
expect(result[result.length - 1]).toBe(96);
|
||||
});
|
||||
});
|
||||
|
||||
describe('bubbleSortDescending', () => {
|
||||
test('sorts in descending order', () => {
|
||||
expect(bubbleSortDescending([1, 3, 2, 5, 4])).toEqual([5, 4, 3, 2, 1]);
|
||||
});
|
||||
|
||||
test('handles empty array', () => {
|
||||
expect(bubbleSortDescending([])).toEqual([]);
|
||||
});
|
||||
|
||||
test('handles single element', () => {
|
||||
expect(bubbleSortDescending([42])).toEqual([42]);
|
||||
});
|
||||
});
|
||||
|
||||
describe('isSorted', () => {
|
||||
test('returns true for empty array', () => {
|
||||
expect(isSorted([])).toBe(true);
|
||||
});
|
||||
|
||||
test('returns true for single element', () => {
|
||||
expect(isSorted([1])).toBe(true);
|
||||
});
|
||||
|
||||
test('returns true for sorted array', () => {
|
||||
expect(isSorted([1, 2, 3, 4, 5])).toBe(true);
|
||||
});
|
||||
|
||||
test('returns false for unsorted array', () => {
|
||||
expect(isSorted([1, 3, 2, 4, 5])).toBe(false);
|
||||
});
|
||||
});
|
||||
|
|
@ -0,0 +1,95 @@
|
|||
import { DataProcessor } from '../data_processor';
|
||||
|
||||
describe('DataProcessor', () => {
|
||||
describe('findDuplicates', () => {
|
||||
test('finds duplicates in array with repeated values', () => {
|
||||
const processor = new DataProcessor([1, 2, 3, 2, 4, 3, 5]);
|
||||
expect(processor.findDuplicates().sort()).toEqual([2, 3]);
|
||||
});
|
||||
|
||||
test('returns empty array when no duplicates', () => {
|
||||
const processor = new DataProcessor([1, 2, 3, 4, 5]);
|
||||
expect(processor.findDuplicates()).toEqual([]);
|
||||
});
|
||||
|
||||
test('handles empty array', () => {
|
||||
const processor = new DataProcessor<number>([]);
|
||||
expect(processor.findDuplicates()).toEqual([]);
|
||||
});
|
||||
|
||||
test('handles array with all same values', () => {
|
||||
const processor = new DataProcessor([5, 5, 5, 5]);
|
||||
expect(processor.findDuplicates()).toEqual([5]);
|
||||
});
|
||||
|
||||
test('handles larger arrays with duplicates', () => {
|
||||
const data: number[] = [];
|
||||
for (let i = 0; i < 100; i++) {
|
||||
data.push(i % 20);
|
||||
}
|
||||
const processor = new DataProcessor(data);
|
||||
const duplicates = processor.findDuplicates();
|
||||
expect(duplicates.length).toBe(20);
|
||||
});
|
||||
});
|
||||
|
||||
describe('sortData', () => {
|
||||
test('sorts numbers in ascending order', () => {
|
||||
const processor = new DataProcessor([5, 2, 8, 1, 9]);
|
||||
expect(processor.sortData()).toEqual([1, 2, 5, 8, 9]);
|
||||
});
|
||||
|
||||
test('handles already sorted array', () => {
|
||||
const processor = new DataProcessor([1, 2, 3, 4, 5]);
|
||||
expect(processor.sortData()).toEqual([1, 2, 3, 4, 5]);
|
||||
});
|
||||
|
||||
test('handles reverse sorted array', () => {
|
||||
const processor = new DataProcessor([5, 4, 3, 2, 1]);
|
||||
expect(processor.sortData()).toEqual([1, 2, 3, 4, 5]);
|
||||
});
|
||||
|
||||
test('handles array with duplicates', () => {
|
||||
const processor = new DataProcessor([3, 1, 4, 1, 5, 9, 2, 6, 5]);
|
||||
expect(processor.sortData()).toEqual([1, 1, 2, 3, 4, 5, 5, 6, 9]);
|
||||
});
|
||||
|
||||
test('handles larger arrays', () => {
|
||||
const data: number[] = [];
|
||||
for (let i = 500; i >= 0; i--) {
|
||||
data.push(i);
|
||||
}
|
||||
const processor = new DataProcessor(data);
|
||||
const sorted = processor.sortData();
|
||||
expect(sorted[0]).toBe(0);
|
||||
expect(sorted[sorted.length - 1]).toBe(500);
|
||||
});
|
||||
});
|
||||
|
||||
describe('getUnique', () => {
|
||||
test('returns unique values', () => {
|
||||
const processor = new DataProcessor([1, 2, 2, 3, 3, 3, 4]);
|
||||
expect(processor.getUnique()).toEqual([1, 2, 3, 4]);
|
||||
});
|
||||
|
||||
test('preserves order of first occurrence', () => {
|
||||
const processor = new DataProcessor([3, 1, 2, 1, 3, 2]);
|
||||
expect(processor.getUnique()).toEqual([3, 1, 2]);
|
||||
});
|
||||
|
||||
test('handles empty array', () => {
|
||||
const processor = new DataProcessor<number>([]);
|
||||
expect(processor.getUnique()).toEqual([]);
|
||||
});
|
||||
|
||||
test('handles array with all unique values', () => {
|
||||
const processor = new DataProcessor([1, 2, 3, 4, 5]);
|
||||
expect(processor.getUnique()).toEqual([1, 2, 3, 4, 5]);
|
||||
});
|
||||
|
||||
test('handles strings', () => {
|
||||
const processor = new DataProcessor(['a', 'b', 'a', 'c', 'b']);
|
||||
expect(processor.getUnique()).toEqual(['a', 'b', 'c']);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
|
@ -0,0 +1,97 @@
|
|||
import { fibonacci, isFibonacci, isPerfectSquare, fibonacciSequence } from '../fibonacci';
|
||||
|
||||
describe('fibonacci', () => {
|
||||
test('returns 0 for n=0', () => {
|
||||
expect(fibonacci(0)).toBe(0);
|
||||
});
|
||||
|
||||
test('returns 1 for n=1', () => {
|
||||
expect(fibonacci(1)).toBe(1);
|
||||
});
|
||||
|
||||
test('returns 1 for n=2', () => {
|
||||
expect(fibonacci(2)).toBe(1);
|
||||
});
|
||||
|
||||
test('returns 5 for n=5', () => {
|
||||
expect(fibonacci(5)).toBe(5);
|
||||
});
|
||||
|
||||
test('returns 55 for n=10', () => {
|
||||
expect(fibonacci(10)).toBe(55);
|
||||
});
|
||||
|
||||
test('returns 233 for n=13', () => {
|
||||
expect(fibonacci(13)).toBe(233);
|
||||
});
|
||||
});
|
||||
|
||||
describe('isFibonacci', () => {
|
||||
test('returns true for 0', () => {
|
||||
expect(isFibonacci(0)).toBe(true);
|
||||
});
|
||||
|
||||
test('returns true for 1', () => {
|
||||
expect(isFibonacci(1)).toBe(true);
|
||||
});
|
||||
|
||||
test('returns true for 8', () => {
|
||||
expect(isFibonacci(8)).toBe(true);
|
||||
});
|
||||
|
||||
test('returns true for 13', () => {
|
||||
expect(isFibonacci(13)).toBe(true);
|
||||
});
|
||||
|
||||
test('returns false for 4', () => {
|
||||
expect(isFibonacci(4)).toBe(false);
|
||||
});
|
||||
|
||||
test('returns false for 6', () => {
|
||||
expect(isFibonacci(6)).toBe(false);
|
||||
});
|
||||
});
|
||||
|
||||
describe('isPerfectSquare', () => {
|
||||
test('returns true for 0', () => {
|
||||
expect(isPerfectSquare(0)).toBe(true);
|
||||
});
|
||||
|
||||
test('returns true for 1', () => {
|
||||
expect(isPerfectSquare(1)).toBe(true);
|
||||
});
|
||||
|
||||
test('returns true for 4', () => {
|
||||
expect(isPerfectSquare(4)).toBe(true);
|
||||
});
|
||||
|
||||
test('returns true for 16', () => {
|
||||
expect(isPerfectSquare(16)).toBe(true);
|
||||
});
|
||||
|
||||
test('returns false for 2', () => {
|
||||
expect(isPerfectSquare(2)).toBe(false);
|
||||
});
|
||||
|
||||
test('returns false for 3', () => {
|
||||
expect(isPerfectSquare(3)).toBe(false);
|
||||
});
|
||||
});
|
||||
|
||||
describe('fibonacciSequence', () => {
|
||||
test('returns empty array for n=0', () => {
|
||||
expect(fibonacciSequence(0)).toEqual([]);
|
||||
});
|
||||
|
||||
test('returns [0] for n=1', () => {
|
||||
expect(fibonacciSequence(1)).toEqual([0]);
|
||||
});
|
||||
|
||||
test('returns first 5 Fibonacci numbers', () => {
|
||||
expect(fibonacciSequence(5)).toEqual([0, 1, 1, 2, 3]);
|
||||
});
|
||||
|
||||
test('returns first 10 Fibonacci numbers', () => {
|
||||
expect(fibonacciSequence(10)).toEqual([0, 1, 1, 2, 3, 5, 8, 13, 21, 34]);
|
||||
});
|
||||
});
|
||||
|
|
@ -0,0 +1,133 @@
|
|||
import { reverseString, isPalindrome, countOccurrences, longestCommonPrefix, toTitleCase } from '../string_utils';
|
||||
|
||||
describe('reverseString', () => {
|
||||
test('reverses an empty string', () => {
|
||||
expect(reverseString('')).toBe('');
|
||||
});
|
||||
|
||||
test('reverses a single character', () => {
|
||||
expect(reverseString('a')).toBe('a');
|
||||
});
|
||||
|
||||
test('reverses a word', () => {
|
||||
expect(reverseString('hello')).toBe('olleh');
|
||||
});
|
||||
|
||||
test('reverses a sentence', () => {
|
||||
expect(reverseString('hello world')).toBe('dlrow olleh');
|
||||
});
|
||||
|
||||
test('handles special characters', () => {
|
||||
expect(reverseString('a!b@c#')).toBe('#c@b!a');
|
||||
});
|
||||
|
||||
test('reverses a longer string for performance', () => {
|
||||
const input = 'abcdefghijklmnopqrstuvwxyz'.repeat(20);
|
||||
const result = reverseString(input);
|
||||
expect(result.length).toBe(input.length);
|
||||
expect(result[0]).toBe('z');
|
||||
expect(result[result.length - 1]).toBe('a');
|
||||
});
|
||||
|
||||
test('reverses a medium string', () => {
|
||||
const input = 'The quick brown fox jumps over the lazy dog';
|
||||
const expected = 'god yzal eht revo spmuj xof nworb kciuq ehT';
|
||||
expect(reverseString(input)).toBe(expected);
|
||||
});
|
||||
});
|
||||
|
||||
describe('isPalindrome', () => {
|
||||
test('returns true for empty string', () => {
|
||||
expect(isPalindrome('')).toBe(true);
|
||||
});
|
||||
|
||||
test('returns true for single character', () => {
|
||||
expect(isPalindrome('a')).toBe(true);
|
||||
});
|
||||
|
||||
test('returns true for palindrome word', () => {
|
||||
expect(isPalindrome('racecar')).toBe(true);
|
||||
});
|
||||
|
||||
test('returns true for palindrome with mixed case', () => {
|
||||
expect(isPalindrome('RaceCar')).toBe(true);
|
||||
});
|
||||
|
||||
test('returns true for palindrome with spaces', () => {
|
||||
expect(isPalindrome('A man a plan a canal Panama')).toBe(true);
|
||||
});
|
||||
|
||||
test('returns false for non-palindrome', () => {
|
||||
expect(isPalindrome('hello')).toBe(false);
|
||||
});
|
||||
});
|
||||
|
||||
describe('countOccurrences', () => {
|
||||
test('returns 0 for empty string', () => {
|
||||
expect(countOccurrences('', 'a')).toBe(0);
|
||||
});
|
||||
|
||||
test('returns 0 when substring not found', () => {
|
||||
expect(countOccurrences('hello', 'x')).toBe(0);
|
||||
});
|
||||
|
||||
test('counts single occurrence', () => {
|
||||
expect(countOccurrences('hello', 'e')).toBe(1);
|
||||
});
|
||||
|
||||
test('counts multiple occurrences', () => {
|
||||
expect(countOccurrences('hello', 'l')).toBe(2);
|
||||
});
|
||||
|
||||
test('counts overlapping occurrences', () => {
|
||||
expect(countOccurrences('aaa', 'aa')).toBe(2);
|
||||
});
|
||||
|
||||
test('counts multi-character substring', () => {
|
||||
expect(countOccurrences('abcabc', 'abc')).toBe(2);
|
||||
});
|
||||
});
|
||||
|
||||
describe('longestCommonPrefix', () => {
|
||||
test('returns empty for empty array', () => {
|
||||
expect(longestCommonPrefix([])).toBe('');
|
||||
});
|
||||
|
||||
test('returns the string for single element', () => {
|
||||
expect(longestCommonPrefix(['hello'])).toBe('hello');
|
||||
});
|
||||
|
||||
test('finds common prefix', () => {
|
||||
expect(longestCommonPrefix(['flower', 'flow', 'flight'])).toBe('fl');
|
||||
});
|
||||
|
||||
test('returns empty when no common prefix', () => {
|
||||
expect(longestCommonPrefix(['dog', 'racecar', 'car'])).toBe('');
|
||||
});
|
||||
|
||||
test('handles identical strings', () => {
|
||||
expect(longestCommonPrefix(['test', 'test', 'test'])).toBe('test');
|
||||
});
|
||||
});
|
||||
|
||||
describe('toTitleCase', () => {
|
||||
test('converts single word', () => {
|
||||
expect(toTitleCase('hello')).toBe('Hello');
|
||||
});
|
||||
|
||||
test('converts multiple words', () => {
|
||||
expect(toTitleCase('hello world')).toBe('Hello World');
|
||||
});
|
||||
|
||||
test('handles already title case', () => {
|
||||
expect(toTitleCase('Hello World')).toBe('Hello World');
|
||||
});
|
||||
|
||||
test('handles all uppercase', () => {
|
||||
expect(toTitleCase('HELLO WORLD')).toBe('Hello World');
|
||||
});
|
||||
|
||||
test('handles empty string', () => {
|
||||
expect(toTitleCase('')).toBe('');
|
||||
});
|
||||
});
|
||||
20
code_to_optimize/js/code_to_optimize_ts/tsconfig.json
Normal file
20
code_to_optimize/js/code_to_optimize_ts/tsconfig.json
Normal file
|
|
@ -0,0 +1,20 @@
|
|||
{
|
||||
"compilerOptions": {
|
||||
"target": "ES2020",
|
||||
"module": "commonjs",
|
||||
"lib": ["ES2020"],
|
||||
"outDir": "./dist",
|
||||
"rootDir": ".",
|
||||
"strict": true,
|
||||
"esModuleInterop": true,
|
||||
"skipLibCheck": true,
|
||||
"forceConsistentCasingInFileNames": true,
|
||||
"declaration": true,
|
||||
"declarationMap": true,
|
||||
"sourceMap": true,
|
||||
"resolveJsonModule": true,
|
||||
"moduleResolution": "node"
|
||||
},
|
||||
"include": ["*.ts", "tests/**/*.ts"],
|
||||
"exclude": ["node_modules", "dist"]
|
||||
}
|
||||
|
|
@ -14,6 +14,7 @@ from codeflash.cli_cmds.console import console, logger
|
|||
from codeflash.code_utils.env_utils import get_codeflash_api_key
|
||||
from codeflash.code_utils.git_utils import get_last_commit_author_if_pr_exists, get_repo_owner_and_name
|
||||
from codeflash.code_utils.time_utils import humanize_runtime
|
||||
from codeflash.languages import is_javascript, is_python
|
||||
from codeflash.models.ExperimentMetadata import ExperimentMetadata
|
||||
from codeflash.models.models import (
|
||||
AIServiceRefinerRequest,
|
||||
|
|
@ -101,11 +102,11 @@ class AiServiceClient:
|
|||
return response
|
||||
|
||||
def _get_valid_candidates(
|
||||
self, optimizations_json: list[dict[str, Any]], source: OptimizedCandidateSource
|
||||
self, optimizations_json: list[dict[str, Any]], source: OptimizedCandidateSource, language: str = "python"
|
||||
) -> list[OptimizedCandidate]:
|
||||
candidates: list[OptimizedCandidate] = []
|
||||
for opt in optimizations_json:
|
||||
code = CodeStringsMarkdown.parse_markdown_code(opt["source_code"])
|
||||
code = CodeStringsMarkdown.parse_markdown_code(opt["source_code"], expected_language=language)
|
||||
if not code.code_strings:
|
||||
continue
|
||||
candidates.append(
|
||||
|
|
@ -120,25 +121,32 @@ class AiServiceClient:
|
|||
)
|
||||
return candidates
|
||||
|
||||
def optimize_python_code( # noqa: D417
|
||||
def optimize_code(
|
||||
self,
|
||||
source_code: str,
|
||||
dependency_code: str,
|
||||
trace_id: str,
|
||||
experiment_metadata: ExperimentMetadata | None = None,
|
||||
*,
|
||||
language: str = "python",
|
||||
language_version: str
|
||||
| None = None, # TODO:{claude} add language version to the language support and it should be cached
|
||||
module_system: str | None = None,
|
||||
is_async: bool = False,
|
||||
n_candidates: int = 5,
|
||||
is_numerical_code: bool | None = None,
|
||||
) -> list[OptimizedCandidate]:
|
||||
"""Optimize the given python code for performance by making a request to the Django endpoint.
|
||||
"""Optimize the given code for performance by making a request to the Django endpoint.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
- source_code (str): The python code to optimize.
|
||||
- source_code (str): The code to optimize.
|
||||
- dependency_code (str): The dependency code used as read-only context for the optimization
|
||||
- trace_id (str): Trace id of optimization run
|
||||
- experiment_metadata (Optional[ExperimentalMetadata, None]): Any available experiment metadata for this optimization
|
||||
- language (str): Programming language ("python", "javascript", "typescript")
|
||||
- language_version (str | None): Language version (e.g., "3.11.0" for Python, "ES2022" for JS)
|
||||
- module_system (str | None): JS/TS module system ("esm", "commonjs", or None for Python)
|
||||
- is_async (bool): Whether the function being optimized is async
|
||||
- n_candidates (int): Number of candidates to generate
|
||||
|
||||
|
|
@ -152,11 +160,12 @@ class AiServiceClient:
|
|||
start_time = time.perf_counter()
|
||||
git_repo_owner, git_repo_name = safe_get_repo_owner_and_name()
|
||||
|
||||
payload = {
|
||||
# Build payload with language-specific fields
|
||||
payload: dict[str, Any] = {
|
||||
"source_code": source_code,
|
||||
"dependency_code": dependency_code,
|
||||
"trace_id": trace_id,
|
||||
"python_version": platform.python_version(),
|
||||
"language": language,
|
||||
"experiment_metadata": experiment_metadata,
|
||||
"codeflash_version": codeflash_version,
|
||||
"current_username": get_last_commit_author_if_pr_exists(None),
|
||||
|
|
@ -167,6 +176,22 @@ class AiServiceClient:
|
|||
"n_candidates": n_candidates,
|
||||
"is_numerical_code": is_numerical_code,
|
||||
}
|
||||
|
||||
# Add language-specific version fields
|
||||
# Always include python_version for backward compatibility with older backend
|
||||
payload["python_version"] = platform.python_version()
|
||||
if is_python():
|
||||
pass # python_version already set
|
||||
else:
|
||||
payload["language_version"] = language_version or "ES2022"
|
||||
# Add module system for JavaScript/TypeScript (esm or commonjs)
|
||||
if module_system:
|
||||
payload["module_system"] = module_system
|
||||
|
||||
# DEBUG: Print payload language field
|
||||
logger.debug(
|
||||
f"Sending optimize request with language='{payload['language']}' (type: {type(payload['language'])})"
|
||||
)
|
||||
logger.debug(f"Sending optimize request: trace_id={trace_id}, n_candidates={payload['n_candidates']}")
|
||||
|
||||
try:
|
||||
|
|
@ -183,7 +208,7 @@ class AiServiceClient:
|
|||
logger.debug(f"!lsp|Generating possible optimizations took {end_time - start_time:.2f} seconds.")
|
||||
logger.info(f"!lsp|Received {len(optimizations_json)} optimization candidates.")
|
||||
console.rule()
|
||||
return self._get_valid_candidates(optimizations_json, OptimizedCandidateSource.OPTIMIZE)
|
||||
return self._get_valid_candidates(optimizations_json, OptimizedCandidateSource.OPTIMIZE, language)
|
||||
try:
|
||||
error = response.json()["error"]
|
||||
except Exception:
|
||||
|
|
@ -193,9 +218,29 @@ class AiServiceClient:
|
|||
console.rule()
|
||||
return []
|
||||
|
||||
def get_jit_rewritten_code( # noqa: D417
|
||||
self, source_code: str, trace_id: str
|
||||
# Backward-compatible alias
|
||||
def optimize_python_code(
|
||||
self,
|
||||
source_code: str,
|
||||
dependency_code: str,
|
||||
trace_id: str,
|
||||
experiment_metadata: ExperimentMetadata | None = None,
|
||||
*,
|
||||
is_async: bool = False,
|
||||
n_candidates: int = 5,
|
||||
) -> list[OptimizedCandidate]:
|
||||
"""Backward-compatible alias for optimize_code() with language='python'."""
|
||||
return self.optimize_code(
|
||||
source_code=source_code,
|
||||
dependency_code=dependency_code,
|
||||
trace_id=trace_id,
|
||||
experiment_metadata=experiment_metadata,
|
||||
language="python",
|
||||
is_async=is_async,
|
||||
n_candidates=n_candidates,
|
||||
)
|
||||
|
||||
def get_jit_rewritten_code(self, source_code: str, trace_id: str) -> list[OptimizedCandidate]:
|
||||
"""Rewrite the given python code for performance via jit compilation by making a request to the Django endpoint.
|
||||
|
||||
Parameters
|
||||
|
|
@ -245,7 +290,7 @@ class AiServiceClient:
|
|||
console.rule()
|
||||
return []
|
||||
|
||||
def optimize_python_code_line_profiler( # noqa: D417
|
||||
def optimize_python_code_line_profiler(
|
||||
self,
|
||||
source_code: str,
|
||||
dependency_code: str,
|
||||
|
|
@ -253,18 +298,22 @@ class AiServiceClient:
|
|||
line_profiler_results: str,
|
||||
n_candidates: int,
|
||||
experiment_metadata: ExperimentMetadata | None = None,
|
||||
is_numerical_code: bool | None = None, # noqa: FBT001
|
||||
is_numerical_code: bool | None = None,
|
||||
language: str = "python",
|
||||
language_version: str | None = None,
|
||||
) -> list[OptimizedCandidate]:
|
||||
"""Optimize the given python code for performance using line profiler results.
|
||||
"""Optimize code for performance using line profiler results.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
- source_code (str): The python code to optimize.
|
||||
- source_code (str): The code to optimize.
|
||||
- dependency_code (str): The dependency code used as read-only context for the optimization
|
||||
- trace_id (str): Trace id of optimization run
|
||||
- line_profiler_results (str): Line profiler output to guide optimization
|
||||
- experiment_metadata (Optional[ExperimentalMetadata, None]): Any available experiment metadata for this optimization
|
||||
- n_candidates (int): Number of candidates to generate
|
||||
- language (str): Programming language (python, javascript, typescript)
|
||||
- language_version (str): Language version (e.g., "3.12.0" for Python, "ES2022" for JavaScript)
|
||||
|
||||
Returns
|
||||
-------
|
||||
|
|
@ -278,13 +327,18 @@ class AiServiceClient:
|
|||
logger.info("Generating optimized candidates with line profiler…")
|
||||
console.rule()
|
||||
|
||||
# Set python_version for backward compatibility with Python, or use language_version
|
||||
python_version = language_version if language_version else platform.python_version()
|
||||
|
||||
payload = {
|
||||
"source_code": source_code,
|
||||
"dependency_code": dependency_code,
|
||||
"n_candidates": n_candidates,
|
||||
"line_profiler_results": line_profiler_results,
|
||||
"trace_id": trace_id,
|
||||
"python_version": platform.python_version(),
|
||||
"python_version": python_version,
|
||||
"language": language,
|
||||
"language_version": language_version,
|
||||
"experiment_metadata": experiment_metadata,
|
||||
"codeflash_version": codeflash_version,
|
||||
"call_sequence": self.get_next_sequence(),
|
||||
|
|
@ -345,19 +399,22 @@ class AiServiceClient:
|
|||
ph("cli-optimize-error-response", {"response_status_code": response.status_code, "error": error})
|
||||
return None
|
||||
|
||||
def optimize_python_code_refinement(self, request: list[AIServiceRefinerRequest]) -> list[OptimizedCandidate]:
|
||||
"""Optimize the given python code for performance by making a request to the Django endpoint.
|
||||
def optimize_code_refinement(self, request: list[AIServiceRefinerRequest]) -> list[OptimizedCandidate]:
|
||||
"""Refine optimization candidates for improved performance.
|
||||
|
||||
Supports Python, JavaScript, and TypeScript code refinement with optional
|
||||
multi-file context for better understanding of imports and dependencies.
|
||||
|
||||
Args:
|
||||
request: A list of optimization candidate details for refinement
|
||||
request: A list of optimization candidate details for refinement
|
||||
|
||||
Returns:
|
||||
-------
|
||||
- List[OptimizationCandidate]: A list of Optimization Candidates.
|
||||
List of refined optimization candidates
|
||||
|
||||
"""
|
||||
payload = [
|
||||
{
|
||||
payload: list[dict[str, Any]] = []
|
||||
for opt in request:
|
||||
item: dict[str, Any] = {
|
||||
"optimization_id": opt.optimization_id,
|
||||
"original_source_code": opt.original_source_code,
|
||||
"read_only_dependency_code": opt.read_only_dependency_code,
|
||||
|
|
@ -370,11 +427,26 @@ class AiServiceClient:
|
|||
"speedup": opt.speedup,
|
||||
"trace_id": opt.trace_id,
|
||||
"function_references": opt.function_references,
|
||||
"python_version": platform.python_version(),
|
||||
"call_sequence": self.get_next_sequence(),
|
||||
# Multi-language support
|
||||
"language": opt.language,
|
||||
}
|
||||
for opt in request
|
||||
]
|
||||
|
||||
# Add language version - always include python_version for backward compatibility
|
||||
item["python_version"] = platform.python_version()
|
||||
if is_python():
|
||||
pass # python_version already set
|
||||
elif opt.language_version:
|
||||
item["language_version"] = opt.language_version
|
||||
else:
|
||||
item["language_version"] = "ES2022" # Default for JS/TS
|
||||
|
||||
# Add multi-file context if provided
|
||||
if opt.additional_context_files:
|
||||
item["additional_context_files"] = opt.additional_context_files
|
||||
|
||||
payload.append(item)
|
||||
|
||||
try:
|
||||
response = self.make_ai_service_request("/refinement", payload=payload, timeout=self.timeout)
|
||||
except requests.exceptions.RequestException as e:
|
||||
|
|
@ -396,6 +468,9 @@ class AiServiceClient:
|
|||
console.rule()
|
||||
return []
|
||||
|
||||
# Alias for backward compatibility
|
||||
optimize_python_code_refinement = optimize_code_refinement
|
||||
|
||||
def code_repair(self, request: AIServiceCodeRepairRequest) -> OptimizedCandidate | None:
|
||||
"""Repair the optimization candidate that is not matching the test result of the original code.
|
||||
|
||||
|
|
@ -415,6 +490,7 @@ class AiServiceClient:
|
|||
"modified_source_code": request.modified_source_code,
|
||||
"trace_id": request.trace_id,
|
||||
"test_diffs": request.test_diffs,
|
||||
"language": request.language,
|
||||
}
|
||||
response = self.make_ai_service_request("/code_repair", payload=payload, timeout=self.timeout)
|
||||
except (requests.exceptions.RequestException, TypeError) as e:
|
||||
|
|
@ -426,7 +502,9 @@ class AiServiceClient:
|
|||
fixed_optimization = response.json()
|
||||
console.rule()
|
||||
|
||||
valid_candidates = self._get_valid_candidates([fixed_optimization], OptimizedCandidateSource.REPAIR)
|
||||
valid_candidates = self._get_valid_candidates(
|
||||
[fixed_optimization], OptimizedCandidateSource.REPAIR, request.language
|
||||
)
|
||||
if not valid_candidates:
|
||||
logger.error("Code repair failed to generate a valid candidate.")
|
||||
return None
|
||||
|
|
@ -442,7 +520,7 @@ class AiServiceClient:
|
|||
console.rule()
|
||||
return None
|
||||
|
||||
def get_new_explanation( # noqa: D417
|
||||
def get_new_explanation(
|
||||
self,
|
||||
source_code: str,
|
||||
optimized_code: str,
|
||||
|
|
@ -542,7 +620,7 @@ class AiServiceClient:
|
|||
console.rule()
|
||||
return ""
|
||||
|
||||
def generate_ranking( # noqa: D417
|
||||
def generate_ranking(
|
||||
self,
|
||||
trace_id: str,
|
||||
diffs: list[str],
|
||||
|
|
@ -594,7 +672,7 @@ class AiServiceClient:
|
|||
console.rule()
|
||||
return None
|
||||
|
||||
def log_results( # noqa: D417
|
||||
def log_results(
|
||||
self,
|
||||
function_trace_id: str,
|
||||
speedup_ratio: dict[str, float | None] | None,
|
||||
|
|
@ -635,7 +713,7 @@ class AiServiceClient:
|
|||
except requests.exceptions.RequestException as e:
|
||||
logger.exception(f"Error logging features: {e}")
|
||||
|
||||
def generate_regression_tests( # noqa: D417
|
||||
def generate_regression_tests(
|
||||
self,
|
||||
source_code_being_tested: str,
|
||||
function_to_optimize: FunctionToOptimize,
|
||||
|
|
@ -646,7 +724,11 @@ class AiServiceClient:
|
|||
test_timeout: int,
|
||||
trace_id: str,
|
||||
test_index: int,
|
||||
is_numerical_code: bool | None = None, # noqa: FBT001
|
||||
*,
|
||||
language: str = "python",
|
||||
language_version: str | None = None,
|
||||
module_system: str | None = None,
|
||||
is_numerical_code: bool | None = None,
|
||||
) -> tuple[str, str, str] | None:
|
||||
"""Generate regression tests for the given function by making a request to the Django endpoint.
|
||||
|
||||
|
|
@ -657,19 +739,31 @@ class AiServiceClient:
|
|||
- helper_function_names (list[Source]): List of helper function names.
|
||||
- module_path (Path): The module path where the function is located.
|
||||
- test_module_path (Path): The module path for the test code.
|
||||
- test_framework (str): The test framework to use, e.g., "pytest".
|
||||
- test_framework (str): The test framework to use, e.g., "pytest", "jest".
|
||||
- test_timeout (int): The timeout for each test in seconds.
|
||||
- test_index (int): The index from 0-(n-1) if n tests are generated for a single trace_id
|
||||
- language (str): Programming language ("python", "javascript", "typescript")
|
||||
- language_version (str | None): Language version (e.g., "3.11.0" for Python, "ES2022" for JS)
|
||||
- module_system (str | None): JS/TS module system ("esm", "commonjs", or None for Python)
|
||||
|
||||
Returns
|
||||
-------
|
||||
- Dict[str, str] | None: The generated regression tests and instrumented tests, or None if an error occurred.
|
||||
|
||||
"""
|
||||
assert test_framework in ["pytest", "unittest"], (
|
||||
f"Invalid test framework, got {test_framework} but expected 'pytest' or 'unittest'"
|
||||
)
|
||||
payload = {
|
||||
# Validate test framework based on language
|
||||
python_frameworks = ["pytest", "unittest"]
|
||||
javascript_frameworks = ["jest", "mocha", "vitest"]
|
||||
if is_python():
|
||||
assert test_framework in python_frameworks, (
|
||||
f"Invalid test framework for Python, got {test_framework} but expected one of {python_frameworks}"
|
||||
)
|
||||
elif is_javascript():
|
||||
assert test_framework in javascript_frameworks, (
|
||||
f"Invalid test framework for JavaScript, got {test_framework} but expected one of {javascript_frameworks}"
|
||||
)
|
||||
|
||||
payload: dict[str, Any] = {
|
||||
"source_code_being_tested": source_code_being_tested,
|
||||
"function_to_optimize": function_to_optimize,
|
||||
"helper_function_names": helper_function_names,
|
||||
|
|
@ -679,12 +773,26 @@ class AiServiceClient:
|
|||
"test_timeout": test_timeout,
|
||||
"trace_id": trace_id,
|
||||
"test_index": test_index,
|
||||
"python_version": platform.python_version(),
|
||||
"language": language,
|
||||
"codeflash_version": codeflash_version,
|
||||
"is_async": function_to_optimize.is_async,
|
||||
"call_sequence": self.get_next_sequence(),
|
||||
"is_numerical_code": is_numerical_code,
|
||||
}
|
||||
|
||||
# Add language-specific version fields
|
||||
# Always include python_version for backward compatibility with older backend
|
||||
payload["python_version"] = platform.python_version()
|
||||
if is_python():
|
||||
pass # python_version already set
|
||||
else:
|
||||
payload["language_version"] = language_version or "ES2022"
|
||||
# Add module system for JavaScript/TypeScript (esm or commonjs)
|
||||
if module_system:
|
||||
payload["module_system"] = module_system
|
||||
|
||||
# DEBUG: Print payload language field
|
||||
logger.debug(f"Sending testgen request with language='{payload['language']}', framework='{test_framework}'")
|
||||
try:
|
||||
response = self.make_ai_service_request("/testgen", payload=payload, timeout=self.timeout)
|
||||
except requests.exceptions.RequestException as e:
|
||||
|
|
@ -706,7 +814,7 @@ class AiServiceClient:
|
|||
error = response.json()["error"]
|
||||
logger.error(f"Error generating tests: {response.status_code} - {error}")
|
||||
ph("cli-testgen-error-response", {"response_status_code": response.status_code, "error": error})
|
||||
return None # noqa: TRY300
|
||||
return None
|
||||
except Exception:
|
||||
logger.error(f"Error generating tests: {response.status_code} - {response.text}")
|
||||
ph("cli-testgen-error-response", {"response_status_code": response.status_code, "error": response.text})
|
||||
|
|
@ -722,8 +830,8 @@ class AiServiceClient:
|
|||
function_trace_id: str,
|
||||
coverage_message: str,
|
||||
replay_tests: str,
|
||||
concolic_tests: str, # noqa: ARG002
|
||||
calling_fn_details: str,
|
||||
language: str = "python",
|
||||
) -> OptimizationReviewResult:
|
||||
"""Compute the optimization review of current Pull Request.
|
||||
|
||||
|
|
@ -765,7 +873,8 @@ class AiServiceClient:
|
|||
"original_runtime": humanize_runtime(explanation.original_runtime_ns),
|
||||
"codeflash_version": codeflash_version,
|
||||
"calling_fn_details": calling_fn_details,
|
||||
"python_version": platform.python_version(),
|
||||
"language": language,
|
||||
"python_version": platform.python_version() if is_python() else None,
|
||||
"call_sequence": self.get_next_sequence(),
|
||||
}
|
||||
console.rule()
|
||||
|
|
|
|||
|
|
@ -81,7 +81,7 @@ def make_cfapi_request(
|
|||
else:
|
||||
response = requests.get(url, headers=cfapi_headers, params=params, timeout=60)
|
||||
response.raise_for_status()
|
||||
return response # noqa: TRY300
|
||||
return response
|
||||
except requests.exceptions.HTTPError:
|
||||
# response may be either a string or JSON, so we handle both cases
|
||||
error_message = ""
|
||||
|
|
@ -102,7 +102,7 @@ def make_cfapi_request(
|
|||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def get_user_id(api_key: Optional[str] = None) -> Optional[str]: # noqa: PLR0911
|
||||
def get_user_id(api_key: Optional[str] = None) -> Optional[str]:
|
||||
"""Retrieve the user's userid by making a request to the /cfapi/cli-get-user endpoint.
|
||||
|
||||
:param api_key: The API key to use. If None, uses get_codeflash_api_key().
|
||||
|
|
@ -396,7 +396,7 @@ def get_blocklisted_functions() -> dict[str, set[str]] | dict[str, Any]:
|
|||
|
||||
def is_function_being_optimized_again(
|
||||
owner: str, repo: str, pr_number: int, code_contexts: list[dict[str, str]]
|
||||
) -> Any: # noqa: ANN401
|
||||
) -> Any:
|
||||
"""Check if the function being optimized is being optimized again."""
|
||||
response = make_cfapi_request(
|
||||
"/is-already-optimized",
|
||||
|
|
|
|||
260
codeflash/api/schemas.py
Normal file
260
codeflash/api/schemas.py
Normal file
|
|
@ -0,0 +1,260 @@
|
|||
"""Language-agnostic schemas for AI service communication.
|
||||
|
||||
This module defines standardized payload schemas that work across all supported
|
||||
languages (Python, JavaScript, TypeScript, and future languages).
|
||||
|
||||
Design principles:
|
||||
1. General fields that apply to any language
|
||||
2. Language-specific fields grouped in a nested object
|
||||
3. Backward compatible with existing backend
|
||||
4. Extensible for future languages without breaking changes
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
|
||||
class ModuleSystem(str, Enum):
|
||||
"""Module system used by the code."""
|
||||
|
||||
COMMONJS = "commonjs" # JavaScript/Node.js require/exports
|
||||
ESM = "esm" # ES Modules import/export
|
||||
PYTHON = "python" # Python import system
|
||||
UNKNOWN = "unknown"
|
||||
|
||||
|
||||
class TestFramework(str, Enum):
|
||||
"""Supported test frameworks."""
|
||||
|
||||
# Python
|
||||
PYTEST = "pytest"
|
||||
UNITTEST = "unittest"
|
||||
|
||||
# JavaScript/TypeScript
|
||||
JEST = "jest"
|
||||
MOCHA = "mocha"
|
||||
VITEST = "vitest"
|
||||
|
||||
|
||||
@dataclass
|
||||
class LanguageInfo:
|
||||
"""Language-specific information.
|
||||
|
||||
General fields that describe the programming language and its environment.
|
||||
This is designed to be extensible for future languages.
|
||||
"""
|
||||
|
||||
# Core language identifier
|
||||
name: str # "python", "javascript", "typescript", "rust", etc.
|
||||
|
||||
# Language version (format varies by language)
|
||||
# - Python: "3.11.0"
|
||||
# - JavaScript/TypeScript: "ES2022", "ES2023"
|
||||
# - Rust: "1.70.0"
|
||||
version: str | None = None
|
||||
|
||||
# Module system (primarily for JS/TS, but could apply to others)
|
||||
module_system: ModuleSystem = ModuleSystem.UNKNOWN
|
||||
|
||||
# File extension (for generated files)
|
||||
# - Python: ".py"
|
||||
# - JavaScript: ".js", ".mjs", ".cjs"
|
||||
# - TypeScript: ".ts", ".mts", ".cts"
|
||||
file_extension: str = ""
|
||||
|
||||
# Type system info (for typed languages)
|
||||
has_type_annotations: bool = False
|
||||
type_checker: str | None = None # "mypy", "typescript", "pyright", etc.
|
||||
|
||||
|
||||
@dataclass
|
||||
class TestInfo:
|
||||
"""Test-related information."""
|
||||
|
||||
# Test framework being used
|
||||
framework: TestFramework
|
||||
|
||||
# Timeout for test execution (seconds)
|
||||
timeout: int = 60
|
||||
|
||||
# Test file path patterns (for discovery)
|
||||
test_patterns: list[str] = field(default_factory=list)
|
||||
|
||||
# Path to test files relative to project root
|
||||
tests_root: str = "tests"
|
||||
|
||||
|
||||
@dataclass
|
||||
class OptimizeRequest:
|
||||
"""Request payload for code optimization.
|
||||
|
||||
This schema is designed to be language-agnostic while supporting
|
||||
language-specific fields through the `language_info` object.
|
||||
"""
|
||||
|
||||
# === Core required fields ===
|
||||
source_code: str # Code to optimize
|
||||
trace_id: str # Unique identifier for this optimization run
|
||||
|
||||
# === Language information ===
|
||||
language_info: LanguageInfo
|
||||
|
||||
# === Optional context ===
|
||||
dependency_code: str = "" # Read-only context code
|
||||
module_path: str = "" # Path to the module being optimized
|
||||
|
||||
# === Function metadata ===
|
||||
is_async: bool = False # Whether function is async/await
|
||||
is_numerical_code: bool | None = None # Whether code does numerical computation
|
||||
|
||||
# === Generation parameters ===
|
||||
n_candidates: int = 5 # Number of optimization candidates
|
||||
|
||||
# === Metadata ===
|
||||
codeflash_version: str = ""
|
||||
experiment_metadata: dict[str, Any] | None = None
|
||||
repo_owner: str | None = None
|
||||
repo_name: str | None = None
|
||||
current_username: str | None = None
|
||||
|
||||
def to_payload(self) -> dict[str, Any]:
|
||||
"""Convert to API payload dict, maintaining backward compatibility."""
|
||||
payload = {
|
||||
"source_code": self.source_code,
|
||||
"trace_id": self.trace_id,
|
||||
"language": self.language_info.name,
|
||||
"dependency_code": self.dependency_code,
|
||||
"is_async": self.is_async,
|
||||
"n_candidates": self.n_candidates,
|
||||
"codeflash_version": self.codeflash_version,
|
||||
"experiment_metadata": self.experiment_metadata,
|
||||
"repo_owner": self.repo_owner,
|
||||
"repo_name": self.repo_name,
|
||||
"current_username": self.current_username,
|
||||
"is_numerical_code": self.is_numerical_code,
|
||||
}
|
||||
|
||||
# Add language-specific fields
|
||||
if self.language_info.version:
|
||||
payload["language_version"] = self.language_info.version
|
||||
|
||||
# Backward compat: always include python_version
|
||||
import platform
|
||||
|
||||
payload["python_version"] = platform.python_version()
|
||||
|
||||
# Module system for JS/TS
|
||||
if self.language_info.module_system != ModuleSystem.UNKNOWN:
|
||||
payload["module_system"] = self.language_info.module_system.value
|
||||
|
||||
return payload
|
||||
|
||||
|
||||
@dataclass
|
||||
class TestGenRequest:
|
||||
"""Request payload for test generation.
|
||||
|
||||
This schema is designed to be language-agnostic while supporting
|
||||
language-specific fields through the `language_info` and `test_info` objects.
|
||||
"""
|
||||
|
||||
# === Core required fields ===
|
||||
source_code: str # Code being tested
|
||||
function_name: str # Name of function to generate tests for
|
||||
trace_id: str # Unique identifier
|
||||
|
||||
# === Language information ===
|
||||
language_info: LanguageInfo
|
||||
|
||||
# === Test information ===
|
||||
test_info: TestInfo
|
||||
|
||||
# === Path information ===
|
||||
module_path: str = "" # Path to source module
|
||||
test_module_path: str = "" # Path for generated test
|
||||
|
||||
# === Function metadata ===
|
||||
helper_function_names: list[str] = field(default_factory=list)
|
||||
is_async: bool = False
|
||||
is_numerical_code: bool | None = None
|
||||
|
||||
# === Generation parameters ===
|
||||
test_index: int = 0 # Index when generating multiple tests
|
||||
|
||||
# === Metadata ===
|
||||
codeflash_version: str = ""
|
||||
|
||||
def to_payload(self) -> dict[str, Any]:
|
||||
"""Convert to API payload dict, maintaining backward compatibility."""
|
||||
payload = {
|
||||
"source_code_being_tested": self.source_code,
|
||||
"function_to_optimize": {"function_name": self.function_name, "is_async": self.is_async},
|
||||
"helper_function_names": self.helper_function_names,
|
||||
"module_path": self.module_path,
|
||||
"test_module_path": self.test_module_path,
|
||||
"test_framework": self.test_info.framework.value,
|
||||
"test_timeout": self.test_info.timeout,
|
||||
"trace_id": self.trace_id,
|
||||
"test_index": self.test_index,
|
||||
"language": self.language_info.name,
|
||||
"codeflash_version": self.codeflash_version,
|
||||
"is_async": self.is_async,
|
||||
"is_numerical_code": self.is_numerical_code,
|
||||
}
|
||||
|
||||
# Add language version
|
||||
if self.language_info.version:
|
||||
payload["language_version"] = self.language_info.version
|
||||
|
||||
# Backward compat: always include python_version
|
||||
import platform
|
||||
|
||||
payload["python_version"] = platform.python_version()
|
||||
|
||||
# Module system for JS/TS
|
||||
if self.language_info.module_system != ModuleSystem.UNKNOWN:
|
||||
payload["module_system"] = self.language_info.module_system.value
|
||||
|
||||
return payload
|
||||
|
||||
|
||||
# === Helper functions to create language info ===
|
||||
|
||||
|
||||
def python_language_info(version: str | None = None) -> LanguageInfo:
|
||||
"""Create LanguageInfo for Python."""
|
||||
import platform
|
||||
|
||||
return LanguageInfo(
|
||||
name="python",
|
||||
version=version or platform.python_version(),
|
||||
module_system=ModuleSystem.PYTHON,
|
||||
file_extension=".py",
|
||||
has_type_annotations=True,
|
||||
type_checker="mypy",
|
||||
)
|
||||
|
||||
|
||||
def javascript_language_info(
|
||||
module_system: ModuleSystem = ModuleSystem.COMMONJS, version: str = "ES2022"
|
||||
) -> LanguageInfo:
|
||||
"""Create LanguageInfo for JavaScript."""
|
||||
ext = ".mjs" if module_system == ModuleSystem.ESM else ".js"
|
||||
return LanguageInfo(
|
||||
name="javascript", version=version, module_system=module_system, file_extension=ext, has_type_annotations=False
|
||||
)
|
||||
|
||||
|
||||
def typescript_language_info(module_system: ModuleSystem = ModuleSystem.ESM, version: str = "ES2022") -> LanguageInfo:
|
||||
"""Create LanguageInfo for TypeScript."""
|
||||
return LanguageInfo(
|
||||
name="typescript",
|
||||
version=version,
|
||||
module_system=module_system,
|
||||
file_extension=".ts",
|
||||
has_type_annotations=True,
|
||||
type_checker="typescript",
|
||||
)
|
||||
|
|
@ -108,7 +108,7 @@ class CodeflashTrace:
|
|||
func_id = (func.__module__, func.__name__)
|
||||
|
||||
@functools.wraps(func)
|
||||
def wrapper(*args, **kwargs) -> Any: # noqa: ANN002, ANN003, ANN401
|
||||
def wrapper(*args, **kwargs) -> Any: # noqa: ANN002, ANN003
|
||||
# Initialize thread-local active functions set if it doesn't exist
|
||||
if not hasattr(self._thread_local, "active_functions"):
|
||||
self._thread_local.active_functions = set()
|
||||
|
|
|
|||
|
|
@ -53,7 +53,7 @@ class AddDecoratorTransformer(cst.CSTTransformer):
|
|||
|
||||
return updated_node
|
||||
|
||||
def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module: # noqa: ARG002
|
||||
def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module:
|
||||
# Create import statement for codeflash_trace
|
||||
if not self.added_codeflash_trace:
|
||||
return updated_node
|
||||
|
|
|
|||
|
|
@ -110,7 +110,7 @@ class CodeFlashBenchmarkPlugin:
|
|||
|
||||
# Process each row
|
||||
for row in cursor.fetchall():
|
||||
module_name, class_name, function_name, benchmark_file, benchmark_func, benchmark_line, time_ns = row
|
||||
module_name, class_name, function_name, benchmark_file, benchmark_func, _benchmark_line, time_ns = row
|
||||
|
||||
# Create the function key (module_name.class_name.function_name)
|
||||
if class_name:
|
||||
|
|
@ -172,7 +172,7 @@ class CodeFlashBenchmarkPlugin:
|
|||
|
||||
# Process overhead information
|
||||
for row in cursor.fetchall():
|
||||
benchmark_file, benchmark_func, benchmark_line, total_overhead_ns = row
|
||||
benchmark_file, benchmark_func, _benchmark_line, total_overhead_ns = row
|
||||
benchmark_key = BenchmarkKey(module_path=benchmark_file, function_name=benchmark_func)
|
||||
overhead_by_benchmark[benchmark_key] = total_overhead_ns or 0 # Handle NULL sum case
|
||||
|
||||
|
|
@ -184,7 +184,7 @@ class CodeFlashBenchmarkPlugin:
|
|||
|
||||
# Process each row and subtract overhead
|
||||
for row in cursor.fetchall():
|
||||
benchmark_file, benchmark_func, benchmark_line, time_ns = row
|
||||
benchmark_file, benchmark_func, _benchmark_line, time_ns = row
|
||||
|
||||
# Create the benchmark key (file::function::line)
|
||||
benchmark_key = BenchmarkKey(module_path=benchmark_file, function_name=benchmark_func)
|
||||
|
|
@ -200,7 +200,7 @@ class CodeFlashBenchmarkPlugin:
|
|||
|
||||
# Pytest hooks
|
||||
@pytest.hookimpl
|
||||
def pytest_sessionfinish(self, session, exitstatus) -> None: # noqa: ANN001, ARG002
|
||||
def pytest_sessionfinish(self, session, exitstatus) -> None: # noqa: ANN001
|
||||
"""Execute after whole test run is completed."""
|
||||
# Write any remaining benchmark timings to the database
|
||||
codeflash_trace.close()
|
||||
|
|
@ -218,7 +218,7 @@ class CodeFlashBenchmarkPlugin:
|
|||
skip_no_benchmark = pytest.mark.skip(reason="Test requires benchmark fixture")
|
||||
for item in items:
|
||||
# Check for direct benchmark fixture usage
|
||||
has_fixture = hasattr(item, "fixturenames") and "benchmark" in item.fixturenames
|
||||
has_fixture = hasattr(item, "fixturenames") and "benchmark" in item.fixturenames # ty:ignore[unsupported-operator]
|
||||
|
||||
# Check for @pytest.mark.benchmark marker
|
||||
has_marker = False
|
||||
|
|
@ -236,7 +236,7 @@ class CodeFlashBenchmarkPlugin:
|
|||
def __init__(self, request: pytest.FixtureRequest) -> None:
|
||||
self.request = request
|
||||
|
||||
def __call__(self, func, *args, **kwargs): # type: ignore # noqa: ANN001, ANN002, ANN003, ANN204, PGH003
|
||||
def __call__(self, func, *args, **kwargs): # noqa: ANN001, ANN002, ANN003, ANN204
|
||||
"""Handle both direct function calls and decorator usage."""
|
||||
if args or kwargs:
|
||||
# Used as benchmark(func, *args, **kwargs)
|
||||
|
|
@ -249,7 +249,7 @@ class CodeFlashBenchmarkPlugin:
|
|||
self._run_benchmark(func)
|
||||
return wrapped_func
|
||||
|
||||
def _run_benchmark(self, func, *args, **kwargs): # noqa: ANN001, ANN002, ANN003, ANN202
|
||||
def _run_benchmark(self, func, *args, **kwargs): # noqa: ANN002, ANN003, ANN202
|
||||
"""Actual benchmark implementation."""
|
||||
node_path = getattr(self.request.node, "path", None) or getattr(self.request.node, "fspath", None)
|
||||
if node_path is None:
|
||||
|
|
|
|||
|
|
@ -20,7 +20,7 @@ def parse_args() -> Namespace:
|
|||
parser = ArgumentParser()
|
||||
subparsers = parser.add_subparsers(dest="command", help="Sub-commands")
|
||||
|
||||
init_parser = subparsers.add_parser("init", help="Initialize Codeflash for a Python project.")
|
||||
init_parser = subparsers.add_parser("init", help="Initialize Codeflash for your project.")
|
||||
init_parser.set_defaults(func=init_codeflash)
|
||||
|
||||
subparsers.add_parser("vscode-install", help="Install the Codeflash VSCode extension")
|
||||
|
|
@ -28,7 +28,7 @@ def parse_args() -> Namespace:
|
|||
init_actions_parser = subparsers.add_parser("init-actions", help="Initialize GitHub Actions workflow")
|
||||
init_actions_parser.set_defaults(func=install_github_actions)
|
||||
|
||||
trace_optimize = subparsers.add_parser("optimize", help="Trace and optimize a Python project.")
|
||||
trace_optimize = subparsers.add_parser("optimize", help="Trace and optimize your project.")
|
||||
|
||||
from codeflash.tracer import main as tracer_main
|
||||
|
||||
|
|
@ -70,8 +70,8 @@ def parse_args() -> Namespace:
|
|||
parser.add_argument(
|
||||
"--module-root",
|
||||
type=str,
|
||||
help="Path to the project's Python module that you want to optimize."
|
||||
" This is the top-level root directory where all the Python source code is located.",
|
||||
help="Path to the project's module that you want to optimize."
|
||||
" This is the top-level root directory where all the source code is located.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tests-root", type=str, help="Path to the test directory of the project, where all the tests are located."
|
||||
|
|
@ -206,7 +206,21 @@ def process_pyproject_config(args: Namespace) -> Namespace:
|
|||
setattr(args, key.replace("-", "_"), pyproject_config[key])
|
||||
assert args.module_root is not None, "--module-root must be specified"
|
||||
assert Path(args.module_root).is_dir(), f"--module-root {args.module_root} must be a valid directory"
|
||||
assert args.tests_root is not None, "--tests-root must be specified"
|
||||
|
||||
# For JS/TS projects, tests_root is optional (Jest auto-discovers tests)
|
||||
# Default to module_root if not specified
|
||||
is_js_ts_project = pyproject_config.get("language") in ("javascript", "typescript")
|
||||
if args.tests_root is None:
|
||||
if is_js_ts_project:
|
||||
# Try common JS test directories, or default to module_root
|
||||
for test_dir in ["test", "tests", "__tests__"]:
|
||||
if Path(test_dir).is_dir():
|
||||
args.tests_root = test_dir
|
||||
break
|
||||
if args.tests_root is None:
|
||||
args.tests_root = args.module_root
|
||||
else:
|
||||
raise AssertionError("--tests-root must be specified")
|
||||
assert Path(args.tests_root).is_dir(), f"--tests-root {args.tests_root} must be a valid directory"
|
||||
if args.benchmark:
|
||||
assert args.benchmarks_root is not None, "--benchmarks-root must be specified when running with --benchmark"
|
||||
|
|
|
|||
|
|
@ -43,7 +43,7 @@ def inquirer_wrapper(func: Callable[..., str | bool], *args: str | bool, **kwarg
|
|||
return func(*new_args, **new_kwargs)
|
||||
|
||||
|
||||
def split_string_to_cli_width(string: str, is_confirm: bool = False) -> list[str]: # noqa: FBT001, FBT002
|
||||
def split_string_to_cli_width(string: str, is_confirm: bool = False) -> list[str]:
|
||||
cli_width, _ = shutil.get_terminal_size()
|
||||
# split string to lines that accommodate "[?] " prefix
|
||||
cli_width -= len("[?] ")
|
||||
|
|
|
|||
|
|
@ -26,6 +26,15 @@ from codeflash.api.cfapi import get_user_id, is_github_app_installed_on_repo, se
|
|||
from codeflash.cli_cmds.cli_common import apologize_and_exit
|
||||
from codeflash.cli_cmds.console import console, logger
|
||||
from codeflash.cli_cmds.extension import install_vscode_extension
|
||||
|
||||
# Import JS/TS init module
|
||||
from codeflash.cli_cmds.init_javascript import (
|
||||
ProjectLanguage,
|
||||
detect_project_language,
|
||||
determine_js_package_manager,
|
||||
get_js_dependency_installation_commands,
|
||||
init_js_project,
|
||||
)
|
||||
from codeflash.code_utils.code_utils import validate_relative_directory_path
|
||||
from codeflash.code_utils.compat import LF
|
||||
from codeflash.code_utils.config_parser import parse_config_file
|
||||
|
|
@ -57,6 +66,8 @@ CODEFLASH_LOGO: str = (
|
|||
|
||||
@dataclass(frozen=True)
|
||||
class CLISetupInfo:
|
||||
"""Setup info for Python projects."""
|
||||
|
||||
module_root: str
|
||||
tests_root: str
|
||||
benchmarks_root: Union[str, None]
|
||||
|
|
@ -68,12 +79,16 @@ class CLISetupInfo:
|
|||
|
||||
@dataclass(frozen=True)
|
||||
class VsCodeSetupInfo:
|
||||
"""Setup info for VSCode extension initialization."""
|
||||
|
||||
module_root: str
|
||||
tests_root: str
|
||||
formatter: Union[str, list[str]]
|
||||
|
||||
|
||||
class DependencyManager(Enum):
|
||||
"""Python dependency managers."""
|
||||
|
||||
PIP = auto()
|
||||
POETRY = auto()
|
||||
UV = auto()
|
||||
|
|
@ -95,6 +110,15 @@ def init_codeflash() -> None:
|
|||
console.print(welcome_panel)
|
||||
console.print()
|
||||
|
||||
# TODO:{claude} move the init_javascript to the support folder. Move any other language related specific implementation (other than python) to its support.
|
||||
# Detect project language
|
||||
project_language = detect_project_language()
|
||||
|
||||
if project_language in (ProjectLanguage.JAVASCRIPT, ProjectLanguage.TYPESCRIPT):
|
||||
init_js_project(project_language)
|
||||
return
|
||||
|
||||
# Python project flow
|
||||
did_add_new_key = prompt_api_key()
|
||||
|
||||
should_modify, config = should_modify_pyproject_toml()
|
||||
|
|
@ -663,7 +687,7 @@ def create_empty_pyproject_toml(pyproject_toml_path: Path) -> None:
|
|||
apologize_and_exit()
|
||||
|
||||
|
||||
def install_github_actions(override_formatter_check: bool = False) -> None: # noqa: FBT001, FBT002
|
||||
def install_github_actions(override_formatter_check: bool = False) -> None:
|
||||
try:
|
||||
config, _config_file_path = parse_config_file(override_formatter_check=override_formatter_check)
|
||||
|
||||
|
|
@ -771,8 +795,16 @@ def install_github_actions(override_formatter_check: bool = False) -> None: # n
|
|||
|
||||
# Generate workflow content AFTER user confirmation
|
||||
logger.info("[cmd_init.py:install_github_actions] User confirmed, generating workflow content...")
|
||||
|
||||
# Select the appropriate workflow template based on project language
|
||||
project_language = detect_project_language_for_workflow(Path.cwd())
|
||||
if project_language in ("javascript", "typescript"):
|
||||
workflow_template = "codeflash-optimize-js.yaml"
|
||||
else:
|
||||
workflow_template = "codeflash-optimize.yaml"
|
||||
|
||||
optimize_yml_content = (
|
||||
files("codeflash").joinpath("cli_cmds", "workflows", "codeflash-optimize.yaml").read_text(encoding="utf-8")
|
||||
files("codeflash").joinpath("cli_cmds", "workflows", workflow_template).read_text(encoding="utf-8")
|
||||
)
|
||||
materialized_optimize_yml_content = generate_dynamic_workflow_content(
|
||||
optimize_yml_content, config, git_root, benchmark_mode
|
||||
|
|
@ -1089,11 +1121,12 @@ def install_github_actions(override_formatter_check: bool = False) -> None: # n
|
|||
apologize_and_exit()
|
||||
|
||||
|
||||
def determine_dependency_manager(pyproject_data: dict[str, Any]) -> DependencyManager: # noqa: PLR0911
|
||||
def determine_dependency_manager(pyproject_data: dict[str, Any]) -> DependencyManager:
|
||||
"""Determine which dependency manager is being used based on pyproject.toml contents."""
|
||||
if (Path.cwd() / "poetry.lock").exists():
|
||||
cwd = Path.cwd()
|
||||
if (cwd / "poetry.lock").exists():
|
||||
return DependencyManager.POETRY
|
||||
if (Path.cwd() / "uv.lock").exists():
|
||||
if (cwd / "uv.lock").exists():
|
||||
return DependencyManager.UV
|
||||
if "tool" not in pyproject_data:
|
||||
return DependencyManager.PIP
|
||||
|
|
@ -1168,6 +1201,48 @@ def get_github_action_working_directory(toml_path: Path, git_root: Path) -> str:
|
|||
working-directory: ./{working_dir}"""
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# JavaScript/TypeScript GitHub Actions Support
|
||||
# ============================================================================
|
||||
# Note: JS package manager and workflow helper functions are imported from init_javascript.py
|
||||
|
||||
|
||||
def detect_project_language_for_workflow(project_root: Path) -> str:
|
||||
"""Detect the primary language of the project for workflow generation.
|
||||
|
||||
Returns: 'python', 'javascript', or 'typescript'
|
||||
"""
|
||||
# Check for TypeScript config
|
||||
if (project_root / "tsconfig.json").exists():
|
||||
return "typescript"
|
||||
|
||||
# Check for JavaScript/TypeScript indicators
|
||||
has_package_json = (project_root / "package.json").exists()
|
||||
has_pyproject = (project_root / "pyproject.toml").exists()
|
||||
|
||||
if has_package_json and not has_pyproject:
|
||||
# Pure JS/TS project
|
||||
return "javascript"
|
||||
if has_pyproject and not has_package_json:
|
||||
# Pure Python project
|
||||
return "python"
|
||||
|
||||
# Both exist - count files to determine primary language
|
||||
js_count = 0
|
||||
py_count = 0
|
||||
for file in project_root.rglob("*"):
|
||||
if file.is_file():
|
||||
suffix = file.suffix.lower()
|
||||
if suffix in {".js", ".jsx", ".ts", ".tsx", ".mjs", ".cjs"}:
|
||||
js_count += 1
|
||||
elif suffix == ".py":
|
||||
py_count += 1
|
||||
|
||||
if js_count > py_count:
|
||||
return "javascript"
|
||||
return "python"
|
||||
|
||||
|
||||
def collect_repo_files_for_workflow(git_root: Path) -> dict[str, Any]:
|
||||
"""Collect important repository files and directory structure for workflow generation.
|
||||
|
||||
|
|
@ -1251,10 +1326,7 @@ def collect_repo_files_for_workflow(git_root: Path) -> dict[str, Any]:
|
|||
|
||||
|
||||
def generate_dynamic_workflow_content(
|
||||
optimize_yml_content: str,
|
||||
config: tuple[dict[str, Any], Path],
|
||||
git_root: Path,
|
||||
benchmark_mode: bool = False, # noqa: FBT001, FBT002
|
||||
optimize_yml_content: str, config: tuple[dict[str, Any], Path], git_root: Path, benchmark_mode: bool = False
|
||||
) -> str:
|
||||
"""Generate workflow content with dynamic steps from AI service, falling back to static template.
|
||||
|
||||
|
|
@ -1268,7 +1340,15 @@ def generate_dynamic_workflow_content(
|
|||
module_path = str(Path(config["module_root"]).relative_to(git_root) / "**")
|
||||
optimize_yml_content = optimize_yml_content.replace("{{ codeflash_module_path }}", module_path)
|
||||
|
||||
# Get working directory
|
||||
# Detect project language
|
||||
project_language = detect_project_language_for_workflow(Path.cwd())
|
||||
|
||||
# For JavaScript/TypeScript projects, use static template customization
|
||||
# (AI-generated steps are currently Python-only)
|
||||
if project_language in ("javascript", "typescript"):
|
||||
return customize_codeflash_yaml_content(optimize_yml_content, config, git_root, benchmark_mode)
|
||||
|
||||
# Python project - try AI-generated steps
|
||||
toml_path = Path.cwd() / "pyproject.toml"
|
||||
try:
|
||||
with toml_path.open(encoding="utf8") as pyproject_file:
|
||||
|
|
@ -1378,14 +1458,28 @@ def generate_dynamic_workflow_content(
|
|||
|
||||
|
||||
def customize_codeflash_yaml_content(
|
||||
optimize_yml_content: str,
|
||||
config: tuple[dict[str, Any], Path],
|
||||
git_root: Path,
|
||||
benchmark_mode: bool = False, # noqa: FBT001, FBT002
|
||||
optimize_yml_content: str, config: tuple[dict[str, Any], Path], git_root: Path, benchmark_mode: bool = False
|
||||
) -> str:
|
||||
module_path = str(Path(config["module_root"]).relative_to(git_root) / "**")
|
||||
optimize_yml_content = optimize_yml_content.replace("{{ codeflash_module_path }}", module_path)
|
||||
|
||||
# Detect project language
|
||||
project_language = detect_project_language_for_workflow(Path.cwd())
|
||||
|
||||
if project_language in ("javascript", "typescript"):
|
||||
# JavaScript/TypeScript project
|
||||
return _customize_js_workflow_content(optimize_yml_content, git_root, benchmark_mode)
|
||||
|
||||
# Python project (default)
|
||||
return _customize_python_workflow_content(optimize_yml_content, git_root, benchmark_mode)
|
||||
|
||||
|
||||
def _customize_python_workflow_content(
|
||||
optimize_yml_content: str,
|
||||
git_root: Path,
|
||||
benchmark_mode: bool = False, # noqa: FBT001, FBT002
|
||||
) -> str:
|
||||
"""Customize workflow content for Python projects."""
|
||||
# Get dependency installation commands
|
||||
toml_path = Path.cwd() / "pyproject.toml"
|
||||
try:
|
||||
|
|
@ -1404,7 +1498,7 @@ def customize_codeflash_yaml_content(
|
|||
|
||||
python_depmanager_installation = get_dependency_manager_installation_string(dep_manager)
|
||||
optimize_yml_content = optimize_yml_content.replace(
|
||||
"{{ setup_python_dependency_manager }}", python_depmanager_installation
|
||||
"{{ setup_runtime_environment }}", python_depmanager_installation
|
||||
)
|
||||
install_deps_cmd = get_dependency_installation_commands(dep_manager)
|
||||
|
||||
|
|
@ -1418,6 +1512,64 @@ def customize_codeflash_yaml_content(
|
|||
return optimize_yml_content.replace("{{ codeflash_command }}", codeflash_cmd)
|
||||
|
||||
|
||||
# TODO:{claude} Refactor and move to support for language specific
|
||||
def _customize_js_workflow_content(
|
||||
optimize_yml_content: str,
|
||||
git_root: Path,
|
||||
benchmark_mode: bool = False, # noqa: FBT001, FBT002
|
||||
) -> str:
|
||||
"""Customize workflow content for JavaScript/TypeScript projects."""
|
||||
from codeflash.cli_cmds.init_javascript import (
|
||||
get_js_codeflash_install_step,
|
||||
get_js_codeflash_run_command,
|
||||
get_js_runtime_setup_steps,
|
||||
is_codeflash_dependency,
|
||||
)
|
||||
|
||||
project_root = Path.cwd()
|
||||
package_json_path = project_root / "package.json"
|
||||
|
||||
if not package_json_path.exists():
|
||||
click.echo(
|
||||
f"I couldn't find a package.json in the current directory.{LF}"
|
||||
f"Please run `npm init` or create a package.json file first."
|
||||
)
|
||||
apologize_and_exit()
|
||||
|
||||
# Determine working directory relative to git root
|
||||
if project_root == git_root:
|
||||
working_dir = ""
|
||||
else:
|
||||
rel_path = str(project_root.relative_to(git_root))
|
||||
working_dir = f"""defaults:
|
||||
run:
|
||||
working-directory: ./{rel_path}"""
|
||||
|
||||
optimize_yml_content = optimize_yml_content.replace("{{ working_directory }}", working_dir)
|
||||
|
||||
# Determine package manager and codeflash dependency status
|
||||
pkg_manager = determine_js_package_manager(project_root)
|
||||
codeflash_is_dep = is_codeflash_dependency(project_root)
|
||||
|
||||
# Setup runtime environment (Node.js/Bun)
|
||||
runtime_setup = get_js_runtime_setup_steps(pkg_manager)
|
||||
optimize_yml_content = optimize_yml_content.replace("{{ setup_runtime_steps }}", runtime_setup)
|
||||
|
||||
# Install dependencies
|
||||
install_deps_cmd = get_js_dependency_installation_commands(pkg_manager)
|
||||
optimize_yml_content = optimize_yml_content.replace("{{ install_dependencies_command }}", install_deps_cmd)
|
||||
|
||||
# Install codeflash step (only if not a dependency)
|
||||
install_codeflash = get_js_codeflash_install_step(pkg_manager, is_dependency=codeflash_is_dep)
|
||||
optimize_yml_content = optimize_yml_content.replace("{{ install_codeflash_step }}", install_codeflash)
|
||||
|
||||
# Codeflash run command
|
||||
codeflash_cmd = get_js_codeflash_run_command(pkg_manager, is_dependency=codeflash_is_dep)
|
||||
if benchmark_mode:
|
||||
codeflash_cmd += " --benchmark"
|
||||
return optimize_yml_content.replace("{{ codeflash_command }}", codeflash_cmd)
|
||||
|
||||
|
||||
def get_formatter_cmds(formatter: str) -> list[str]:
|
||||
if formatter == "black":
|
||||
return ["black $file"]
|
||||
|
|
|
|||
|
|
@ -98,17 +98,32 @@ def code_print(
|
|||
file_name: Optional[str] = None,
|
||||
function_name: Optional[str] = None,
|
||||
lsp_message_id: Optional[str] = None,
|
||||
language: str = "python",
|
||||
) -> None:
|
||||
"""Print code with syntax highlighting.
|
||||
|
||||
Args:
|
||||
code_str: The code to print
|
||||
file_name: Optional file name for LSP
|
||||
function_name: Optional function name for LSP
|
||||
lsp_message_id: Optional LSP message ID
|
||||
language: Programming language for syntax highlighting ('python', 'javascript', 'typescript')
|
||||
|
||||
"""
|
||||
if is_LSP_enabled():
|
||||
lsp_log(
|
||||
LspCodeMessage(code=code_str, file_name=file_name, function_name=function_name, message_id=lsp_message_id)
|
||||
)
|
||||
return
|
||||
"""Print code with syntax highlighting."""
|
||||
|
||||
from rich.syntax import Syntax
|
||||
|
||||
# Map codeflash language names to rich/pygments lexer names
|
||||
lexer_map = {"python": "python", "javascript": "javascript", "typescript": "typescript"}
|
||||
lexer = lexer_map.get(language, "python")
|
||||
|
||||
console.rule()
|
||||
console.print(Syntax(code_str, "python", line_numbers=True, theme="github-dark"))
|
||||
console.print(Syntax(code_str, lexer, line_numbers=True, theme="github-dark"))
|
||||
console.rule()
|
||||
|
||||
|
||||
|
|
|
|||
657
codeflash/cli_cmds/init_javascript.py
Normal file
657
codeflash/cli_cmds/init_javascript.py
Normal file
|
|
@ -0,0 +1,657 @@
|
|||
"""JavaScript/TypeScript project initialization for Codeflash."""
|
||||
|
||||
# TODO:{claude} move to language support directory
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum, auto
|
||||
from pathlib import Path
|
||||
from typing import Any, Union
|
||||
|
||||
import click
|
||||
import inquirer
|
||||
from git import InvalidGitRepositoryError, Repo
|
||||
from rich.console import Group
|
||||
from rich.panel import Panel
|
||||
from rich.table import Table
|
||||
from rich.text import Text
|
||||
|
||||
from codeflash.cli_cmds.cli_common import apologize_and_exit
|
||||
from codeflash.cli_cmds.console import console
|
||||
from codeflash.code_utils.code_utils import validate_relative_directory_path
|
||||
from codeflash.code_utils.compat import LF
|
||||
from codeflash.code_utils.git_utils import get_git_remotes
|
||||
from codeflash.code_utils.shell_utils import get_shell_rc_path, is_powershell
|
||||
from codeflash.telemetry.posthog_cf import ph
|
||||
|
||||
|
||||
class ProjectLanguage(Enum):
|
||||
"""Supported project languages."""
|
||||
|
||||
PYTHON = auto()
|
||||
JAVASCRIPT = auto()
|
||||
TYPESCRIPT = auto()
|
||||
|
||||
|
||||
class JsPackageManager(Enum):
|
||||
"""JavaScript/TypeScript package managers."""
|
||||
|
||||
NPM = auto()
|
||||
YARN = auto()
|
||||
PNPM = auto()
|
||||
BUN = auto()
|
||||
UNKNOWN = auto()
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class JSSetupInfo:
|
||||
"""Setup info for JavaScript/TypeScript projects.
|
||||
|
||||
Only stores values that override auto-detection or user preferences.
|
||||
Most config is auto-detected from package.json and project structure.
|
||||
"""
|
||||
|
||||
# Override values (None means use auto-detected value)
|
||||
module_root_override: Union[str, None] = None
|
||||
formatter_override: Union[list[str], None] = None
|
||||
|
||||
# User preferences (stored in config only if non-default)
|
||||
git_remote: str = "origin"
|
||||
disable_telemetry: bool = False
|
||||
ignore_paths: list[str] | None = None
|
||||
benchmarks_root: Union[str, None] = None
|
||||
|
||||
|
||||
# Import theme from cmd_init to avoid duplication
|
||||
def _get_theme(): # noqa: ANN202
|
||||
"""Get the CodeflashTheme - imported lazily to avoid circular imports."""
|
||||
from codeflash.cli_cmds.cmd_init import CodeflashTheme
|
||||
|
||||
return CodeflashTheme()
|
||||
|
||||
|
||||
def detect_project_language(project_root: Path | None = None) -> ProjectLanguage:
|
||||
"""Detect the primary language of the project.
|
||||
|
||||
Args:
|
||||
project_root: Root directory to check. Defaults to current directory.
|
||||
|
||||
Returns:
|
||||
ProjectLanguage enum value
|
||||
|
||||
"""
|
||||
root = project_root or Path.cwd()
|
||||
|
||||
has_pyproject = (root / "pyproject.toml").exists()
|
||||
has_setup_py = (root / "setup.py").exists()
|
||||
has_package_json = (root / "package.json").exists()
|
||||
has_tsconfig = (root / "tsconfig.json").exists()
|
||||
|
||||
# TypeScript project
|
||||
if has_tsconfig:
|
||||
return ProjectLanguage.TYPESCRIPT
|
||||
|
||||
# Pure JS project (has package.json but no Python files)
|
||||
if has_package_json and not has_pyproject and not has_setup_py:
|
||||
return ProjectLanguage.JAVASCRIPT
|
||||
|
||||
# Python project (default)
|
||||
return ProjectLanguage.PYTHON
|
||||
|
||||
|
||||
def determine_js_package_manager(project_root: Path) -> JsPackageManager:
|
||||
"""Determine which JavaScript package manager is being used based on lock files."""
|
||||
if (project_root / "bun.lockb").exists() or (project_root / "bun.lock").exists():
|
||||
return JsPackageManager.BUN
|
||||
if (project_root / "pnpm-lock.yaml").exists():
|
||||
return JsPackageManager.PNPM
|
||||
if (project_root / "yarn.lock").exists():
|
||||
return JsPackageManager.YARN
|
||||
if (project_root / "package-lock.json").exists():
|
||||
return JsPackageManager.NPM
|
||||
# Default to npm if package.json exists but no lock file
|
||||
if (project_root / "package.json").exists():
|
||||
return JsPackageManager.NPM
|
||||
return JsPackageManager.UNKNOWN
|
||||
|
||||
|
||||
def init_js_project(language: ProjectLanguage) -> None:
|
||||
"""Initialize Codeflash for a JavaScript/TypeScript project."""
|
||||
from codeflash.cli_cmds.cmd_init import install_github_actions, install_github_app, prompt_api_key
|
||||
|
||||
lang_name = "TypeScript" if language == ProjectLanguage.TYPESCRIPT else "JavaScript"
|
||||
|
||||
lang_panel = Panel(
|
||||
Text(
|
||||
f"📦 Detected {lang_name} project!\n\nI'll help you set up Codeflash for your project.",
|
||||
style="cyan",
|
||||
justify="center",
|
||||
),
|
||||
title=f"🟨 {lang_name} Setup",
|
||||
border_style="bright_yellow",
|
||||
)
|
||||
console.print(lang_panel)
|
||||
console.print()
|
||||
|
||||
did_add_new_key = prompt_api_key()
|
||||
|
||||
should_modify, _config = should_modify_package_json_config()
|
||||
|
||||
# Default git remote
|
||||
git_remote = "origin"
|
||||
|
||||
if should_modify:
|
||||
setup_info = collect_js_setup_info(language)
|
||||
git_remote = setup_info.git_remote or "origin"
|
||||
configured = configure_package_json(setup_info)
|
||||
if not configured:
|
||||
apologize_and_exit()
|
||||
|
||||
install_github_app(git_remote)
|
||||
|
||||
install_github_actions(override_formatter_check=True)
|
||||
|
||||
# Show completion message
|
||||
usage_table = Table(show_header=False, show_lines=False, border_style="dim")
|
||||
usage_table.add_column("Command", style="cyan")
|
||||
usage_table.add_column("Description", style="white")
|
||||
|
||||
usage_table.add_row("codeflash --file <path-to-file> --function <function-name>", "Optimize a specific function")
|
||||
usage_table.add_row("codeflash --all", "Optimize all functions in all files")
|
||||
usage_table.add_row("codeflash --help", "See all available options")
|
||||
|
||||
completion_message = (
|
||||
f"⚡️ Codeflash is now set up for your {lang_name} project!\n\nYou can now run any of these commands:"
|
||||
)
|
||||
|
||||
if did_add_new_key:
|
||||
completion_message += (
|
||||
"\n\n🐚 Don't forget to restart your shell to load the CODEFLASH_API_KEY environment variable!"
|
||||
)
|
||||
if os.name == "nt":
|
||||
reload_cmd = f". {get_shell_rc_path()}" if is_powershell() else f"call {get_shell_rc_path()}"
|
||||
else:
|
||||
reload_cmd = f"source {get_shell_rc_path()}"
|
||||
completion_message += f"\nOr run: {reload_cmd}"
|
||||
|
||||
completion_panel = Panel(
|
||||
Group(Text(completion_message, style="bold green"), Text(""), usage_table),
|
||||
title="🎉 Setup Complete!",
|
||||
border_style="bright_green",
|
||||
padding=(1, 2),
|
||||
)
|
||||
console.print(completion_panel)
|
||||
|
||||
ph("cli-js-installation-successful", {"language": lang_name, "did_add_new_key": did_add_new_key})
|
||||
sys.exit(0)
|
||||
|
||||
|
||||
def should_modify_package_json_config() -> tuple[bool, dict[str, Any] | None]:
|
||||
"""Check if package.json has valid codeflash config for JS/TS projects."""
|
||||
from rich.prompt import Confirm
|
||||
|
||||
package_json_path = Path.cwd() / "package.json"
|
||||
|
||||
if not package_json_path.exists():
|
||||
click.echo("❌ No package.json found. Please run 'npm init' first.")
|
||||
apologize_and_exit()
|
||||
|
||||
try:
|
||||
with package_json_path.open(encoding="utf8") as f:
|
||||
package_data = json.load(f)
|
||||
|
||||
config = package_data.get("codeflash", {})
|
||||
|
||||
if not config:
|
||||
return True, None
|
||||
|
||||
# Check if module_root is valid (defaults to "." if not specified)
|
||||
module_root = config.get("moduleRoot", ".")
|
||||
if not Path(module_root).is_dir():
|
||||
return True, None
|
||||
|
||||
# Config is valid - ask if user wants to reconfigure
|
||||
return Confirm.ask(
|
||||
"✅ A valid Codeflash config already exists in package.json. Do you want to re-configure it?",
|
||||
default=False,
|
||||
show_default=True,
|
||||
), config
|
||||
except Exception:
|
||||
return True, None
|
||||
|
||||
|
||||
def collect_js_setup_info(language: ProjectLanguage) -> JSSetupInfo:
|
||||
"""Collect setup information for JavaScript/TypeScript projects.
|
||||
|
||||
Uses auto-detection for most settings and only asks for overrides if needed.
|
||||
"""
|
||||
from rich.prompt import Confirm
|
||||
|
||||
from codeflash.cli_cmds.cmd_init import ask_for_telemetry, get_valid_subdirs
|
||||
from codeflash.code_utils.config_js import (
|
||||
detect_formatter,
|
||||
detect_module_root,
|
||||
detect_test_runner,
|
||||
get_package_json_data,
|
||||
)
|
||||
|
||||
curdir = Path.cwd()
|
||||
|
||||
if not os.access(curdir, os.W_OK):
|
||||
click.echo(f"❌ The current directory isn't writable, please check your folder permissions and try again.{LF}")
|
||||
sys.exit(1)
|
||||
|
||||
lang_name = "TypeScript" if language == ProjectLanguage.TYPESCRIPT else "JavaScript"
|
||||
|
||||
# Load package.json data for detection
|
||||
package_json_path = curdir / "package.json"
|
||||
package_data = get_package_json_data(package_json_path) or {}
|
||||
|
||||
# Auto-detect values
|
||||
detected_module_root = detect_module_root(curdir, package_data)
|
||||
detected_test_runner = detect_test_runner(curdir, package_data)
|
||||
detected_formatter = detect_formatter(curdir, package_data)
|
||||
|
||||
# Build detection summary
|
||||
formatter_display = detected_formatter[0] if detected_formatter else "none detected"
|
||||
detection_table = Table(show_header=False, box=None, padding=(0, 2))
|
||||
detection_table.add_column("Setting", style="cyan")
|
||||
detection_table.add_column("Value", style="green")
|
||||
detection_table.add_row("Module root", detected_module_root)
|
||||
detection_table.add_row("Test runner", detected_test_runner)
|
||||
detection_table.add_row("Formatter", formatter_display)
|
||||
|
||||
detection_panel = Panel(
|
||||
Group(Text(f"Auto-detected settings for your {lang_name} project:\n", style="cyan"), detection_table),
|
||||
title="🔍 Auto-Detection Results",
|
||||
border_style="bright_blue",
|
||||
)
|
||||
console.print(detection_panel)
|
||||
console.print()
|
||||
|
||||
# Ask if user wants to change any settings
|
||||
module_root_override = None
|
||||
formatter_override = None
|
||||
|
||||
if Confirm.ask("Would you like to change any of these settings?", default=False):
|
||||
# Module root override
|
||||
valid_subdirs = get_valid_subdirs()
|
||||
curdir_option = f"current directory ({curdir})"
|
||||
custom_dir_option = "enter a custom directory…"
|
||||
keep_detected_option = f"✓ keep detected ({detected_module_root})"
|
||||
|
||||
module_options = [
|
||||
keep_detected_option,
|
||||
*[d for d in valid_subdirs if d not in ("tests", "__tests__", "node_modules", detected_module_root)],
|
||||
curdir_option,
|
||||
custom_dir_option,
|
||||
]
|
||||
|
||||
module_questions = [
|
||||
inquirer.List(
|
||||
"module_root",
|
||||
message=f"Which directory contains your {lang_name} source code?",
|
||||
choices=module_options,
|
||||
default=keep_detected_option,
|
||||
carousel=True,
|
||||
)
|
||||
]
|
||||
|
||||
module_answers = inquirer.prompt(module_questions, theme=_get_theme())
|
||||
if not module_answers:
|
||||
apologize_and_exit()
|
||||
|
||||
module_root_answer = module_answers["module_root"]
|
||||
if module_root_answer == keep_detected_option:
|
||||
pass # Keep auto-detected value
|
||||
elif module_root_answer == curdir_option:
|
||||
module_root_override = "."
|
||||
elif module_root_answer == custom_dir_option:
|
||||
module_root_override = _prompt_custom_directory("module")
|
||||
else:
|
||||
module_root_override = module_root_answer
|
||||
|
||||
ph("cli-js-module-root-provided", {"overridden": module_root_override is not None})
|
||||
|
||||
# Formatter override
|
||||
formatter_questions = [
|
||||
inquirer.List(
|
||||
"formatter",
|
||||
message="Which code formatter do you use?",
|
||||
choices=[
|
||||
(f"✓ keep detected ({formatter_display})", "keep"),
|
||||
("💅 prettier", "prettier"),
|
||||
("📐 eslint --fix", "eslint"),
|
||||
("🔧 other", "other"),
|
||||
("❌ don't use a formatter", "disabled"),
|
||||
],
|
||||
default="keep",
|
||||
carousel=True,
|
||||
)
|
||||
]
|
||||
|
||||
formatter_answers = inquirer.prompt(formatter_questions, theme=_get_theme())
|
||||
if not formatter_answers:
|
||||
apologize_and_exit()
|
||||
|
||||
formatter_choice = formatter_answers["formatter"]
|
||||
if formatter_choice != "keep":
|
||||
formatter_override = get_js_formatter_cmd(formatter_choice)
|
||||
|
||||
ph("cli-js-formatter-provided", {"overridden": formatter_override is not None})
|
||||
|
||||
# Git remote
|
||||
git_remote = _get_git_remote_for_setup()
|
||||
|
||||
# Telemetry
|
||||
disable_telemetry = not ask_for_telemetry()
|
||||
|
||||
return JSSetupInfo(
|
||||
module_root_override=module_root_override,
|
||||
formatter_override=formatter_override,
|
||||
git_remote=git_remote,
|
||||
disable_telemetry=disable_telemetry,
|
||||
)
|
||||
|
||||
|
||||
def _prompt_custom_directory(dir_type: str) -> str:
|
||||
"""Prompt for a custom directory path."""
|
||||
while True:
|
||||
custom_questions = [
|
||||
inquirer.Path(
|
||||
"custom_path",
|
||||
message=f"Enter the path to your {dir_type} directory",
|
||||
path_type=inquirer.Path.DIRECTORY,
|
||||
exists=True,
|
||||
)
|
||||
]
|
||||
|
||||
custom_answers = inquirer.prompt(custom_questions, theme=_get_theme())
|
||||
if not custom_answers:
|
||||
apologize_and_exit()
|
||||
|
||||
custom_path_str = str(custom_answers["custom_path"])
|
||||
is_valid, error_msg = validate_relative_directory_path(custom_path_str)
|
||||
if is_valid:
|
||||
return custom_path_str
|
||||
|
||||
click.echo(f"❌ Invalid path: {error_msg}")
|
||||
click.echo("Please enter a valid relative directory path.")
|
||||
console.print()
|
||||
|
||||
|
||||
def _get_git_remote_for_setup() -> str:
|
||||
"""Get git remote for project setup."""
|
||||
try:
|
||||
repo = Repo(Path.cwd(), search_parent_directories=True)
|
||||
git_remotes = get_git_remotes(repo)
|
||||
if not git_remotes:
|
||||
return ""
|
||||
|
||||
if len(git_remotes) == 1:
|
||||
return git_remotes[0]
|
||||
|
||||
git_panel = Panel(
|
||||
Text(
|
||||
"🔗 Configure Git Remote for Pull Requests.\n\nCodeflash will use this remote to create pull requests.",
|
||||
style="blue",
|
||||
),
|
||||
title="🔗 Git Remote Setup",
|
||||
border_style="bright_blue",
|
||||
)
|
||||
console.print(git_panel)
|
||||
console.print()
|
||||
|
||||
git_questions = [
|
||||
inquirer.List(
|
||||
"git_remote",
|
||||
message="Which git remote should Codeflash use?",
|
||||
choices=git_remotes,
|
||||
default="origin",
|
||||
carousel=True,
|
||||
)
|
||||
]
|
||||
|
||||
git_answers = inquirer.prompt(git_questions, theme=_get_theme())
|
||||
return git_answers["git_remote"] if git_answers else git_remotes[0]
|
||||
except InvalidGitRepositoryError:
|
||||
return ""
|
||||
|
||||
|
||||
def get_js_formatter_cmd(formatter: str) -> list[str]:
|
||||
"""Get formatter commands for JavaScript/TypeScript."""
|
||||
if formatter == "prettier":
|
||||
return ["npx prettier --write $file"]
|
||||
if formatter == "eslint":
|
||||
return ["npx eslint --fix $file"]
|
||||
if formatter == "other":
|
||||
click.echo("🔧 In package.json, please replace 'your-formatter' with your formatter command.")
|
||||
return ["your-formatter $file"]
|
||||
return ["disabled"]
|
||||
|
||||
|
||||
def configure_package_json(setup_info: JSSetupInfo) -> bool:
|
||||
"""Configure codeflash section in package.json for JavaScript/TypeScript projects.
|
||||
|
||||
Only writes minimal config - values that override auto-detection or user preferences.
|
||||
Auto-detected values (language, moduleRoot, testRunner, formatter) are NOT stored
|
||||
unless explicitly overridden by the user.
|
||||
"""
|
||||
package_json_path = Path.cwd() / "package.json"
|
||||
|
||||
try:
|
||||
with package_json_path.open(encoding="utf8") as f:
|
||||
package_data = json.load(f)
|
||||
except FileNotFoundError:
|
||||
click.echo("❌ No package.json found. Please run 'npm init' first.")
|
||||
return False
|
||||
except json.JSONDecodeError as e:
|
||||
click.echo(f"❌ Invalid package.json: {e}")
|
||||
return False
|
||||
|
||||
# Build minimal codeflash config using camelCase (JS convention)
|
||||
# Only include values that override auto-detection or are user preferences
|
||||
codeflash_config: dict[str, Any] = {}
|
||||
|
||||
# Module root override (only if user changed from auto-detected)
|
||||
if setup_info.module_root_override is not None:
|
||||
codeflash_config["moduleRoot"] = setup_info.module_root_override
|
||||
|
||||
# Formatter override (only if user changed from auto-detected)
|
||||
if setup_info.formatter_override is not None:
|
||||
if setup_info.formatter_override != ["disabled"]:
|
||||
codeflash_config["formatterCmds"] = setup_info.formatter_override
|
||||
else:
|
||||
codeflash_config["formatterCmds"] = []
|
||||
|
||||
# Git remote (only if not default "origin")
|
||||
if setup_info.git_remote and setup_info.git_remote not in ("", "origin"):
|
||||
codeflash_config["gitRemote"] = setup_info.git_remote
|
||||
|
||||
# User preferences
|
||||
if setup_info.disable_telemetry:
|
||||
codeflash_config["disableTelemetry"] = True
|
||||
|
||||
if setup_info.ignore_paths:
|
||||
codeflash_config["ignorePaths"] = setup_info.ignore_paths
|
||||
|
||||
if setup_info.benchmarks_root:
|
||||
codeflash_config["benchmarksRoot"] = setup_info.benchmarks_root
|
||||
|
||||
# Only write codeflash section if there's something to write
|
||||
if codeflash_config:
|
||||
package_data["codeflash"] = codeflash_config
|
||||
action = "Updated"
|
||||
else:
|
||||
# Remove codeflash section if empty (all auto-detected)
|
||||
if "codeflash" in package_data:
|
||||
del package_data["codeflash"]
|
||||
action = "Configured"
|
||||
|
||||
try:
|
||||
with package_json_path.open("w", encoding="utf8") as f:
|
||||
json.dump(package_data, f, indent=2)
|
||||
f.write("\n") # Trailing newline
|
||||
except OSError as e:
|
||||
click.echo(f"❌ Failed to update package.json: {e}")
|
||||
return False
|
||||
else:
|
||||
if codeflash_config:
|
||||
click.echo(f"✅ {action} Codeflash configuration in {package_json_path}")
|
||||
else:
|
||||
click.echo("✅ Using auto-detected configuration (no overrides needed)")
|
||||
click.echo()
|
||||
return True
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# GitHub Actions Workflow Helpers for JS/TS
|
||||
# ============================================================================
|
||||
|
||||
|
||||
def is_codeflash_dependency(project_root: Path) -> bool:
|
||||
"""Check if codeflash is listed as a dependency in package.json."""
|
||||
package_json_path = project_root / "package.json"
|
||||
if not package_json_path.exists():
|
||||
return False
|
||||
|
||||
try:
|
||||
with package_json_path.open(encoding="utf8") as f:
|
||||
package_data = json.load(f)
|
||||
except (json.JSONDecodeError, OSError):
|
||||
return False
|
||||
|
||||
deps = package_data.get("dependencies", {})
|
||||
dev_deps = package_data.get("devDependencies", {})
|
||||
return "codeflash" in deps or "codeflash" in dev_deps
|
||||
|
||||
|
||||
def get_js_runtime_setup_steps(pkg_manager: JsPackageManager) -> str:
|
||||
"""Generate the appropriate Node.js/Bun setup steps for GitHub Actions.
|
||||
|
||||
Returns properly indented YAML steps for the workflow template.
|
||||
"""
|
||||
if pkg_manager == JsPackageManager.BUN:
|
||||
return """- name: 🥟 Setup Bun
|
||||
uses: oven-sh/setup-bun@v2
|
||||
with:
|
||||
bun-version: latest"""
|
||||
|
||||
if pkg_manager == JsPackageManager.PNPM:
|
||||
return """- name: 📦 Setup pnpm
|
||||
uses: pnpm/action-setup@v4
|
||||
with:
|
||||
version: 9
|
||||
- name: 🟢 Setup Node.js
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: '22'
|
||||
cache: 'pnpm'"""
|
||||
|
||||
if pkg_manager == JsPackageManager.YARN:
|
||||
return """- name: 🟢 Setup Node.js
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: '22'
|
||||
cache: 'yarn'"""
|
||||
|
||||
# NPM or UNKNOWN
|
||||
return """- name: 🟢 Setup Node.js
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: '22'
|
||||
cache: 'npm'"""
|
||||
|
||||
|
||||
def get_js_codeflash_install_step(pkg_manager: JsPackageManager, *, is_dependency: bool) -> str:
|
||||
"""Generate the codeflash installation step if not already a dependency.
|
||||
|
||||
Args:
|
||||
pkg_manager: The package manager being used.
|
||||
is_dependency: Whether codeflash is already in package.json dependencies.
|
||||
|
||||
Returns:
|
||||
YAML step string for installing codeflash, or empty string if not needed.
|
||||
|
||||
"""
|
||||
if is_dependency:
|
||||
# Codeflash will be installed with other dependencies
|
||||
return ""
|
||||
|
||||
# Need to install codeflash separately
|
||||
if pkg_manager == JsPackageManager.BUN:
|
||||
return """- name: 📥 Install Codeflash
|
||||
run: bun add -g codeflash"""
|
||||
|
||||
if pkg_manager == JsPackageManager.PNPM:
|
||||
return """- name: 📥 Install Codeflash
|
||||
run: pnpm add -g codeflash"""
|
||||
|
||||
if pkg_manager == JsPackageManager.YARN:
|
||||
return """- name: 📥 Install Codeflash
|
||||
run: yarn global add codeflash"""
|
||||
|
||||
# NPM or UNKNOWN
|
||||
return """- name: 📥 Install Codeflash
|
||||
run: npm install -g codeflash"""
|
||||
|
||||
|
||||
def get_js_codeflash_run_command(pkg_manager: JsPackageManager, *, is_dependency: bool) -> str:
|
||||
"""Generate the codeflash run command for GitHub Actions.
|
||||
|
||||
Args:
|
||||
pkg_manager: The package manager being used.
|
||||
is_dependency: Whether codeflash is in package.json dependencies.
|
||||
|
||||
Returns:
|
||||
Command string to run codeflash.
|
||||
|
||||
"""
|
||||
if is_dependency:
|
||||
# Use package manager's run command for local dependency
|
||||
if pkg_manager == JsPackageManager.BUN:
|
||||
return "bun run codeflash"
|
||||
if pkg_manager == JsPackageManager.PNPM:
|
||||
return "pnpm exec codeflash"
|
||||
if pkg_manager == JsPackageManager.YARN:
|
||||
return "yarn codeflash"
|
||||
# NPM
|
||||
return "npx codeflash"
|
||||
|
||||
# Globally installed - just run directly
|
||||
return "codeflash"
|
||||
|
||||
|
||||
def get_js_runtime_setup_string(pkg_manager: JsPackageManager) -> str:
|
||||
"""Generate the appropriate Node.js setup step for GitHub Actions.
|
||||
|
||||
Deprecated: Use get_js_runtime_setup_steps instead.
|
||||
"""
|
||||
return get_js_runtime_setup_steps(pkg_manager)
|
||||
|
||||
|
||||
def get_js_dependency_installation_commands(pkg_manager: JsPackageManager) -> str:
|
||||
"""Generate commands to install JavaScript/TypeScript dependencies."""
|
||||
if pkg_manager == JsPackageManager.BUN:
|
||||
return "bun install"
|
||||
if pkg_manager == JsPackageManager.PNPM:
|
||||
return "pnpm install"
|
||||
if pkg_manager == JsPackageManager.YARN:
|
||||
return "yarn install"
|
||||
# NPM or UNKNOWN
|
||||
return "npm ci"
|
||||
|
||||
|
||||
def get_js_codeflash_command(pkg_manager: JsPackageManager) -> str:
|
||||
"""Generate the appropriate codeflash command for JavaScript/TypeScript projects."""
|
||||
if pkg_manager == JsPackageManager.BUN:
|
||||
return "bunx codeflash"
|
||||
if pkg_manager == JsPackageManager.PNPM:
|
||||
return "pnpm dlx codeflash"
|
||||
if pkg_manager == JsPackageManager.YARN:
|
||||
return "yarn dlx codeflash"
|
||||
# NPM or UNKNOWN
|
||||
return "npx codeflash"
|
||||
35
codeflash/cli_cmds/workflows/codeflash-optimize-js.yaml
Normal file
35
codeflash/cli_cmds/workflows/codeflash-optimize-js.yaml
Normal file
|
|
@ -0,0 +1,35 @@
|
|||
name: Codeflash
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
paths:
|
||||
# So that this workflow only runs when code within the target module is modified
|
||||
- '{{ codeflash_module_path }}'
|
||||
workflow_dispatch:
|
||||
|
||||
concurrency:
|
||||
# Any new push to the PR will cancel the previous run, so that only the latest code is optimized
|
||||
group: ${{ github.workflow }}-${{ github.ref }}
|
||||
cancel-in-progress: true
|
||||
|
||||
|
||||
jobs:
|
||||
optimize:
|
||||
name: Optimize new code
|
||||
# Don't run codeflash on codeflash-ai[bot] commits, prevent duplicate optimizations
|
||||
if: ${{ github.actor != 'codeflash-ai[bot]' }}
|
||||
runs-on: ubuntu-latest
|
||||
env:
|
||||
CODEFLASH_API_KEY: ${{ secrets.CODEFLASH_API_KEY }}
|
||||
{{ working_directory }}
|
||||
steps:
|
||||
- name: 🛎️ Checkout
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
{{ setup_runtime_steps }}
|
||||
- name: 📦 Install Dependencies
|
||||
run: {{ install_dependencies_command }}
|
||||
{{ install_codeflash_step }}
|
||||
- name: ⚡️ Codeflash Optimization
|
||||
run: {{ codeflash_command }}
|
||||
|
|
@ -16,6 +16,7 @@ from libcst.helpers import calculate_module_and_package
|
|||
|
||||
from codeflash.cli_cmds.console import logger
|
||||
from codeflash.code_utils.config_consts import MAX_CONTEXT_LEN_REVIEW
|
||||
from codeflash.languages.base import Language
|
||||
from codeflash.models.models import CodePosition, FunctionParent
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
|
@ -44,14 +45,14 @@ class GlobalFunctionCollector(cst.CSTVisitor):
|
|||
self.scope_depth += 1
|
||||
return True
|
||||
|
||||
def leave_FunctionDef(self, original_node: cst.FunctionDef) -> None: # noqa: ARG002
|
||||
def leave_FunctionDef(self, original_node: cst.FunctionDef) -> None:
|
||||
self.scope_depth -= 1
|
||||
|
||||
def visit_ClassDef(self, node: cst.ClassDef) -> Optional[bool]: # noqa: ARG002
|
||||
def visit_ClassDef(self, node: cst.ClassDef) -> Optional[bool]:
|
||||
self.scope_depth += 1
|
||||
return True
|
||||
|
||||
def leave_ClassDef(self, original_node: cst.ClassDef) -> None: # noqa: ARG002
|
||||
def leave_ClassDef(self, original_node: cst.ClassDef) -> None:
|
||||
self.scope_depth -= 1
|
||||
|
||||
|
||||
|
|
@ -65,7 +66,7 @@ class GlobalFunctionTransformer(cst.CSTTransformer):
|
|||
self.processed_functions: set[str] = set()
|
||||
self.scope_depth = 0
|
||||
|
||||
def visit_FunctionDef(self, node: cst.FunctionDef) -> None: # noqa: ARG002
|
||||
def visit_FunctionDef(self, node: cst.FunctionDef) -> None:
|
||||
self.scope_depth += 1
|
||||
|
||||
def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef) -> cst.FunctionDef:
|
||||
|
|
@ -80,14 +81,14 @@ class GlobalFunctionTransformer(cst.CSTTransformer):
|
|||
return self.new_functions[name]
|
||||
return updated_node
|
||||
|
||||
def visit_ClassDef(self, node: cst.ClassDef) -> None: # noqa: ARG002
|
||||
def visit_ClassDef(self, node: cst.ClassDef) -> None:
|
||||
self.scope_depth += 1
|
||||
|
||||
def leave_ClassDef(self, original_node: cst.ClassDef, updated_node: cst.ClassDef) -> cst.ClassDef: # noqa: ARG002
|
||||
def leave_ClassDef(self, original_node: cst.ClassDef, updated_node: cst.ClassDef) -> cst.ClassDef:
|
||||
self.scope_depth -= 1
|
||||
return updated_node
|
||||
|
||||
def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module: # noqa: ARG002
|
||||
def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module:
|
||||
# Add any new functions that weren't in the original file
|
||||
new_statements = list(updated_node.body)
|
||||
|
||||
|
|
@ -141,28 +142,28 @@ class GlobalAssignmentCollector(cst.CSTVisitor):
|
|||
self.scope_depth = 0
|
||||
self.if_else_depth = 0
|
||||
|
||||
def visit_FunctionDef(self, node: cst.FunctionDef) -> Optional[bool]: # noqa: ARG002
|
||||
def visit_FunctionDef(self, node: cst.FunctionDef) -> Optional[bool]:
|
||||
self.scope_depth += 1
|
||||
return True
|
||||
|
||||
def leave_FunctionDef(self, original_node: cst.FunctionDef) -> None: # noqa: ARG002
|
||||
def leave_FunctionDef(self, original_node: cst.FunctionDef) -> None:
|
||||
self.scope_depth -= 1
|
||||
|
||||
def visit_ClassDef(self, node: cst.ClassDef) -> Optional[bool]: # noqa: ARG002
|
||||
def visit_ClassDef(self, node: cst.ClassDef) -> Optional[bool]:
|
||||
self.scope_depth += 1
|
||||
return True
|
||||
|
||||
def leave_ClassDef(self, original_node: cst.ClassDef) -> None: # noqa: ARG002
|
||||
def leave_ClassDef(self, original_node: cst.ClassDef) -> None:
|
||||
self.scope_depth -= 1
|
||||
|
||||
def visit_If(self, node: cst.If) -> Optional[bool]: # noqa: ARG002
|
||||
def visit_If(self, node: cst.If) -> Optional[bool]:
|
||||
self.if_else_depth += 1
|
||||
return True
|
||||
|
||||
def leave_If(self, original_node: cst.If) -> None: # noqa: ARG002
|
||||
def leave_If(self, original_node: cst.If) -> None:
|
||||
self.if_else_depth -= 1
|
||||
|
||||
def visit_Else(self, node: cst.Else) -> Optional[bool]: # noqa: ARG002
|
||||
def visit_Else(self, node: cst.Else) -> Optional[bool]:
|
||||
# Else blocks are already counted as part of the if statement
|
||||
return True
|
||||
|
||||
|
|
@ -231,24 +232,24 @@ class GlobalAssignmentTransformer(cst.CSTTransformer):
|
|||
self.scope_depth = 0
|
||||
self.if_else_depth = 0
|
||||
|
||||
def visit_FunctionDef(self, node: cst.FunctionDef) -> None: # noqa: ARG002
|
||||
def visit_FunctionDef(self, node: cst.FunctionDef) -> None:
|
||||
self.scope_depth += 1
|
||||
|
||||
def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef) -> cst.FunctionDef: # noqa: ARG002
|
||||
def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef) -> cst.FunctionDef:
|
||||
self.scope_depth -= 1
|
||||
return updated_node
|
||||
|
||||
def visit_ClassDef(self, node: cst.ClassDef) -> None: # noqa: ARG002
|
||||
def visit_ClassDef(self, node: cst.ClassDef) -> None:
|
||||
self.scope_depth += 1
|
||||
|
||||
def leave_ClassDef(self, original_node: cst.ClassDef, updated_node: cst.ClassDef) -> cst.ClassDef: # noqa: ARG002
|
||||
def leave_ClassDef(self, original_node: cst.ClassDef, updated_node: cst.ClassDef) -> cst.ClassDef:
|
||||
self.scope_depth -= 1
|
||||
return updated_node
|
||||
|
||||
def visit_If(self, node: cst.If) -> None: # noqa: ARG002
|
||||
def visit_If(self, node: cst.If) -> None:
|
||||
self.if_else_depth += 1
|
||||
|
||||
def leave_If(self, original_node: cst.If, updated_node: cst.If) -> cst.If: # noqa: ARG002
|
||||
def leave_If(self, original_node: cst.If, updated_node: cst.If) -> cst.If:
|
||||
self.if_else_depth -= 1
|
||||
return updated_node
|
||||
|
||||
|
|
@ -283,7 +284,7 @@ class GlobalAssignmentTransformer(cst.CSTTransformer):
|
|||
|
||||
return updated_node
|
||||
|
||||
def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module: # noqa: ARG002
|
||||
def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module:
|
||||
# Add any new assignments that weren't in the original file
|
||||
new_statements = list(updated_node.body)
|
||||
|
||||
|
|
@ -370,7 +371,7 @@ class GlobalStatementTransformer(cst.CSTTransformer):
|
|||
super().__init__()
|
||||
self.global_statements = global_statements
|
||||
|
||||
def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module: # noqa: ARG002
|
||||
def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module:
|
||||
if not self.global_statements:
|
||||
return updated_node
|
||||
|
||||
|
|
@ -396,20 +397,20 @@ class GlobalStatementCollector(cst.CSTVisitor):
|
|||
self.global_statements = []
|
||||
self.in_function_or_class = False
|
||||
|
||||
def visit_ClassDef(self, node: cst.ClassDef) -> bool: # noqa: ARG002
|
||||
def visit_ClassDef(self, node: cst.ClassDef) -> bool:
|
||||
# Don't visit inside classes
|
||||
self.in_function_or_class = True
|
||||
return False
|
||||
|
||||
def leave_ClassDef(self, original_node: cst.ClassDef) -> None: # noqa: ARG002
|
||||
def leave_ClassDef(self, original_node: cst.ClassDef) -> None:
|
||||
self.in_function_or_class = False
|
||||
|
||||
def visit_FunctionDef(self, node: cst.FunctionDef) -> bool: # noqa: ARG002
|
||||
def visit_FunctionDef(self, node: cst.FunctionDef) -> bool:
|
||||
# Don't visit inside functions
|
||||
self.in_function_or_class = True
|
||||
return False
|
||||
|
||||
def leave_FunctionDef(self, original_node: cst.FunctionDef) -> None: # noqa: ARG002
|
||||
def leave_FunctionDef(self, original_node: cst.FunctionDef) -> None:
|
||||
self.in_function_or_class = False
|
||||
|
||||
def visit_SimpleStatementLine(self, node: cst.SimpleStatementLine) -> None:
|
||||
|
|
@ -490,16 +491,16 @@ class DottedImportCollector(cst.CSTVisitor):
|
|||
self.depth = 0
|
||||
self._collect_imports_from_block(node)
|
||||
|
||||
def visit_FunctionDef(self, node: cst.FunctionDef) -> None: # noqa: ARG002
|
||||
def visit_FunctionDef(self, node: cst.FunctionDef) -> None:
|
||||
self.depth += 1
|
||||
|
||||
def leave_FunctionDef(self, node: cst.FunctionDef) -> None: # noqa: ARG002
|
||||
def leave_FunctionDef(self, node: cst.FunctionDef) -> None:
|
||||
self.depth -= 1
|
||||
|
||||
def visit_ClassDef(self, node: cst.ClassDef) -> None: # noqa: ARG002
|
||||
def visit_ClassDef(self, node: cst.ClassDef) -> None:
|
||||
self.depth += 1
|
||||
|
||||
def leave_ClassDef(self, node: cst.ClassDef) -> None: # noqa: ARG002
|
||||
def leave_ClassDef(self, node: cst.ClassDef) -> None:
|
||||
self.depth -= 1
|
||||
|
||||
def visit_If(self, node: cst.If) -> None:
|
||||
|
|
@ -529,9 +530,7 @@ def find_last_import_line(target_code: str) -> int:
|
|||
|
||||
class FutureAliasedImportTransformer(cst.CSTTransformer):
|
||||
def leave_ImportFrom(
|
||||
self,
|
||||
original_node: cst.ImportFrom, # noqa: ARG002
|
||||
updated_node: cst.ImportFrom,
|
||||
self, original_node: cst.ImportFrom, updated_node: cst.ImportFrom
|
||||
) -> cst.BaseSmallStatement | cst.FlattenSentinel[cst.BaseSmallStatement] | cst.RemovalSentinel:
|
||||
import libcst.matchers as m
|
||||
|
||||
|
|
@ -676,7 +675,7 @@ def resolve_star_import(module_name: str, project_root: Path) -> set[str]:
|
|||
if not name.startswith("_"):
|
||||
public_names.add(name)
|
||||
|
||||
return public_names # noqa: TRY300
|
||||
return public_names
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Error resolving star import for {module_name}: {e}")
|
||||
|
|
@ -1165,7 +1164,7 @@ class FunctionCallFinder(ast.NodeVisitor):
|
|||
|
||||
return False
|
||||
|
||||
def _get_call_name(self, func_node) -> Optional[str]: # noqa: ANN001
|
||||
def _get_call_name(self, func_node) -> Optional[str]:
|
||||
"""Extract the name being called from a function node."""
|
||||
# Fast path short-circuit for ast.Name nodes
|
||||
if isinstance(func_node, ast.Name):
|
||||
|
|
@ -1341,9 +1340,12 @@ def get_fn_references_jedi(
|
|||
source_code: str, file_path: Path, project_root: Path, target_function: str, target_class: str | None
|
||||
) -> list[Path]:
|
||||
start_time = time.perf_counter()
|
||||
function_position: CodePosition = find_specific_function_in_file(
|
||||
function_position: CodePosition | None = find_specific_function_in_file(
|
||||
source_code, file_path, target_function, target_class
|
||||
)
|
||||
if function_position is None:
|
||||
# Function not found (may be non-Python code)
|
||||
return []
|
||||
try:
|
||||
script = jedi.Script(code=source_code, path=file_path, project=jedi.Project(path=project_root))
|
||||
# Get references to the function
|
||||
|
|
@ -1555,15 +1557,15 @@ def is_numerical_code(code_string: str, function_name: str | None = None) -> boo
|
|||
|
||||
# If numba is not installed and all modules used require numba for optimization,
|
||||
# return False since we can't optimize this code
|
||||
if not has_numba and modules_used.issubset(NUMBA_REQUIRED_MODULES): # noqa : SIM103
|
||||
return False
|
||||
|
||||
return True
|
||||
return not (not has_numba and modules_used.issubset(NUMBA_REQUIRED_MODULES))
|
||||
|
||||
|
||||
def get_opt_review_metrics(
|
||||
source_code: str, file_path: Path, qualified_name: str, project_root: Path, tests_root: Path
|
||||
source_code: str, file_path: Path, qualified_name: str, project_root: Path, tests_root: Path, language: Language
|
||||
) -> str:
|
||||
if language != Language.PYTHON:
|
||||
# TODO: {Claude} handle function refrences for other languages
|
||||
return ""
|
||||
start_time = time.perf_counter()
|
||||
try:
|
||||
qualified_name_split = qualified_name.rsplit(".", maxsplit=1)
|
||||
|
|
|
|||
|
|
@ -18,12 +18,15 @@ from codeflash.code_utils.code_extractor import (
|
|||
from codeflash.code_utils.config_parser import find_conftest_files
|
||||
from codeflash.code_utils.formatter import sort_imports
|
||||
from codeflash.code_utils.line_profile_utils import ImportAdder
|
||||
from codeflash.languages import is_python
|
||||
from codeflash.models.models import FunctionParent
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pathlib import Path
|
||||
|
||||
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
|
||||
from codeflash.languages.base import Language, LanguageSupport
|
||||
from codeflash.languages.treesitter_utils import TreeSitterAnalyzer
|
||||
from codeflash.models.models import CodeOptimizationContext, CodeStringsMarkdown, OptimizedCandidate, ValidCode
|
||||
|
||||
ASTNodeT = TypeVar("ASTNodeT", bound=ast.AST)
|
||||
|
|
@ -109,7 +112,7 @@ class PytestMarkAdder(cst.CSTTransformer):
|
|||
if isinstance(import_alias, cst.ImportAlias) and import_alias.name.value == "pytest":
|
||||
self.has_pytest_import = True
|
||||
|
||||
def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module: # noqa: ARG002
|
||||
def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module:
|
||||
"""Add pytest import if not present."""
|
||||
if not self.has_pytest_import:
|
||||
# Create import statement
|
||||
|
|
@ -118,7 +121,7 @@ class PytestMarkAdder(cst.CSTTransformer):
|
|||
updated_node = updated_node.with_changes(body=[import_stmt, *updated_node.body])
|
||||
return updated_node
|
||||
|
||||
def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef) -> cst.FunctionDef: # noqa: ARG002
|
||||
def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef) -> cst.FunctionDef:
|
||||
"""Add pytest mark to test functions."""
|
||||
# Check if the mark already exists
|
||||
for decorator in updated_node.decorators:
|
||||
|
|
@ -291,7 +294,7 @@ class OptimFunctionCollector(cst.CSTVisitor):
|
|||
|
||||
return True
|
||||
|
||||
def leave_ClassDef(self, node: cst.ClassDef) -> None: # noqa: ARG002
|
||||
def leave_ClassDef(self, node: cst.ClassDef) -> None:
|
||||
if self.current_class:
|
||||
self.current_class = None
|
||||
|
||||
|
|
@ -315,7 +318,7 @@ class OptimFunctionReplacer(cst.CSTTransformer):
|
|||
)
|
||||
self.current_class = None
|
||||
|
||||
def visit_FunctionDef(self, node: cst.FunctionDef) -> bool: # noqa: ARG002
|
||||
def visit_FunctionDef(self, node: cst.FunctionDef) -> bool:
|
||||
return False
|
||||
|
||||
def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef) -> cst.FunctionDef:
|
||||
|
|
@ -344,7 +347,7 @@ class OptimFunctionReplacer(cst.CSTTransformer):
|
|||
)
|
||||
return updated_node
|
||||
|
||||
def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module: # noqa: ARG002
|
||||
def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module:
|
||||
node = updated_node
|
||||
max_function_index = None
|
||||
max_class_index = None
|
||||
|
|
@ -440,8 +443,15 @@ def replace_function_definitions_in_module(
|
|||
module_abspath: Path,
|
||||
preexisting_objects: set[tuple[str, tuple[FunctionParent, ...]]],
|
||||
project_root_path: Path,
|
||||
should_add_global_assignments: bool = True, # noqa: FBT001, FBT002
|
||||
should_add_global_assignments: bool = True,
|
||||
function_to_optimize: Optional[FunctionToOptimize] = None,
|
||||
) -> bool:
|
||||
# Route to language-specific implementation for non-Python languages
|
||||
if not is_python():
|
||||
return replace_function_definitions_for_language(
|
||||
function_names, optimized_code, module_abspath, project_root_path, function_to_optimize
|
||||
)
|
||||
|
||||
source_code: str = module_abspath.read_text(encoding="utf8")
|
||||
code_to_apply = get_optimized_code_for_module(module_abspath.relative_to(project_root_path), optimized_code)
|
||||
|
||||
|
|
@ -463,16 +473,271 @@ def replace_function_definitions_in_module(
|
|||
return True
|
||||
|
||||
|
||||
def replace_function_definitions_for_language(
|
||||
function_names: list[str],
|
||||
optimized_code: CodeStringsMarkdown,
|
||||
module_abspath: Path,
|
||||
project_root_path: Path,
|
||||
function_to_optimize: Optional[FunctionToOptimize] = None,
|
||||
) -> bool:
|
||||
"""Replace function definitions for non-Python languages.
|
||||
|
||||
Uses the language support abstraction to perform code replacement.
|
||||
|
||||
Args:
|
||||
function_names: List of qualified function names to replace.
|
||||
optimized_code: The optimized code to apply.
|
||||
module_abspath: Path to the module file.
|
||||
project_root_path: Root of the project.
|
||||
function_to_optimize: The function being optimized (needed for line info).
|
||||
|
||||
Returns:
|
||||
True if the code was modified, False if no changes.
|
||||
|
||||
"""
|
||||
from codeflash.languages import get_language_support
|
||||
from codeflash.languages.base import FunctionInfo, Language, ParentInfo
|
||||
|
||||
original_source_code: str = module_abspath.read_text(encoding="utf8")
|
||||
code_to_apply = get_optimized_code_for_module(module_abspath.relative_to(project_root_path), optimized_code)
|
||||
|
||||
if not code_to_apply.strip():
|
||||
return False
|
||||
|
||||
# Get language support
|
||||
language = Language(optimized_code.language)
|
||||
lang_support = get_language_support(language)
|
||||
|
||||
# Add any new global declarations from the optimized code to the original source
|
||||
original_source_code = _add_global_declarations_for_language(
|
||||
optimized_code=code_to_apply,
|
||||
original_source=original_source_code,
|
||||
module_abspath=module_abspath,
|
||||
language=language,
|
||||
)
|
||||
|
||||
# If we have function_to_optimize with line info and this is the main file, use it for precise replacement
|
||||
if (
|
||||
function_to_optimize
|
||||
and function_to_optimize.starting_line
|
||||
and function_to_optimize.ending_line
|
||||
and function_to_optimize.file_path == module_abspath
|
||||
):
|
||||
parents = tuple(ParentInfo(name=p.name, type=p.type) for p in function_to_optimize.parents)
|
||||
func_info = FunctionInfo(
|
||||
name=function_to_optimize.function_name,
|
||||
file_path=module_abspath,
|
||||
start_line=function_to_optimize.starting_line,
|
||||
end_line=function_to_optimize.ending_line,
|
||||
parents=parents,
|
||||
is_async=function_to_optimize.is_async,
|
||||
language=language,
|
||||
)
|
||||
# Extract just the target function from the optimized code
|
||||
optimized_func = _extract_function_from_code(
|
||||
lang_support, code_to_apply, function_to_optimize.function_name, module_abspath
|
||||
)
|
||||
if optimized_func:
|
||||
new_code = lang_support.replace_function(original_source_code, func_info, optimized_func)
|
||||
else:
|
||||
# Fallback: use the entire optimized code (for simple single-function files)
|
||||
new_code = lang_support.replace_function(original_source_code, func_info, code_to_apply)
|
||||
else:
|
||||
# For helper files or when we don't have precise line info:
|
||||
# Find each function by name in both original and optimized code
|
||||
# Then replace with the corresponding optimized version
|
||||
new_code = original_source_code
|
||||
modified = False
|
||||
|
||||
# Get the list of function names to replace
|
||||
functions_to_replace = list(function_names)
|
||||
|
||||
for func_name in functions_to_replace:
|
||||
# Re-discover functions from current code state to get correct line numbers
|
||||
current_functions = lang_support.discover_functions_from_source(new_code, module_abspath)
|
||||
|
||||
# Find the function in current code
|
||||
func = None
|
||||
for f in current_functions:
|
||||
if func_name in (f.qualified_name, f.name):
|
||||
func = f
|
||||
break
|
||||
|
||||
if func is None:
|
||||
continue
|
||||
|
||||
# Extract just this function from the optimized code
|
||||
optimized_func = _extract_function_from_code(lang_support, code_to_apply, func.name, module_abspath)
|
||||
if optimized_func:
|
||||
new_code = lang_support.replace_function(new_code, func, optimized_func)
|
||||
modified = True
|
||||
|
||||
if not modified:
|
||||
logger.warning(f"Could not find function {function_names} in {module_abspath}")
|
||||
return False
|
||||
|
||||
# Check if there was actually a change
|
||||
if original_source_code.strip() == new_code.strip():
|
||||
return False
|
||||
|
||||
module_abspath.write_text(new_code, encoding="utf8")
|
||||
return True
|
||||
|
||||
|
||||
def _extract_function_from_code(
|
||||
lang_support: LanguageSupport, source_code: str, function_name: str, file_path: Path | None = None
|
||||
) -> str | None:
|
||||
"""Extract a specific function's source code from a code string.
|
||||
|
||||
Includes JSDoc/docstring comments if present.
|
||||
|
||||
Args:
|
||||
lang_support: Language support instance.
|
||||
source_code: The full source code containing the function.
|
||||
function_name: Name of the function to extract.
|
||||
file_path: Path to the file (used to determine correct analyzer for JS/TS).
|
||||
|
||||
Returns:
|
||||
The function's source code (including doc comments), or None if not found.
|
||||
|
||||
"""
|
||||
try:
|
||||
# Use the language support to find functions in the source
|
||||
# file_path is needed for JS/TS to determine correct analyzer (TypeScript vs JavaScript)
|
||||
functions = lang_support.discover_functions_from_source(source_code, file_path)
|
||||
for func in functions:
|
||||
if func.name == function_name:
|
||||
# Extract the function's source using line numbers
|
||||
# Use doc_start_line if available to include JSDoc/docstring
|
||||
lines = source_code.splitlines(keepends=True)
|
||||
effective_start = func.doc_start_line or func.start_line
|
||||
if effective_start and func.end_line and effective_start <= len(lines):
|
||||
func_lines = lines[effective_start - 1 : func.end_line]
|
||||
return "".join(func_lines)
|
||||
except Exception as e:
|
||||
logger.debug(f"Error extracting function {function_name}: {e}")
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _add_global_declarations_for_language(
|
||||
optimized_code: str, original_source: str, module_abspath: Path, language: Language
|
||||
) -> str:
|
||||
"""Add new global declarations from optimized code to original source.
|
||||
|
||||
Finds module-level declarations (const, let, var, class, type, interface, enum)
|
||||
in the optimized code that don't exist in the original source and adds them.
|
||||
|
||||
Args:
|
||||
optimized_code: The optimized code that may contain new declarations.
|
||||
original_source: The original source code.
|
||||
module_abspath: Path to the module file (for parser selection).
|
||||
language: The language of the code.
|
||||
|
||||
Returns:
|
||||
Original source with new declarations added after imports.
|
||||
|
||||
"""
|
||||
from codeflash.languages.base import Language
|
||||
|
||||
# Only process JavaScript/TypeScript
|
||||
if language not in (Language.JAVASCRIPT, Language.TYPESCRIPT):
|
||||
return original_source
|
||||
|
||||
try:
|
||||
from codeflash.languages.treesitter_utils import get_analyzer_for_file
|
||||
|
||||
analyzer = get_analyzer_for_file(module_abspath)
|
||||
|
||||
# Find declarations in both original and optimized code
|
||||
original_declarations = analyzer.find_module_level_declarations(original_source)
|
||||
optimized_declarations = analyzer.find_module_level_declarations(optimized_code)
|
||||
|
||||
if not optimized_declarations:
|
||||
return original_source
|
||||
|
||||
# Get names of existing declarations
|
||||
existing_names = {decl.name for decl in original_declarations}
|
||||
|
||||
# Find new declarations (names that don't exist in original)
|
||||
new_declarations = []
|
||||
seen_sources = set() # Track to avoid duplicates from destructuring
|
||||
for decl in optimized_declarations:
|
||||
if decl.name not in existing_names and decl.source_code not in seen_sources:
|
||||
new_declarations.append(decl)
|
||||
seen_sources.add(decl.source_code)
|
||||
|
||||
if not new_declarations:
|
||||
return original_source
|
||||
|
||||
# Sort by line number to maintain order
|
||||
new_declarations.sort(key=lambda d: d.start_line)
|
||||
|
||||
# Find insertion point (after imports)
|
||||
lines = original_source.splitlines(keepends=True)
|
||||
insertion_line = _find_insertion_line_after_imports_js(lines, analyzer, original_source)
|
||||
|
||||
# Build new declarations string
|
||||
new_decl_code = "\n".join(decl.source_code for decl in new_declarations)
|
||||
new_decl_code = new_decl_code + "\n\n"
|
||||
|
||||
# Insert declarations
|
||||
before = lines[:insertion_line]
|
||||
after = lines[insertion_line:]
|
||||
result_lines = [*before, new_decl_code, *after]
|
||||
|
||||
return "".join(result_lines)
|
||||
|
||||
except Exception as e:
|
||||
logger.debug(f"Error adding global declarations: {e}")
|
||||
return original_source
|
||||
|
||||
|
||||
def _find_insertion_line_after_imports_js(lines: list[str], analyzer: TreeSitterAnalyzer, source: str) -> int:
|
||||
"""Find the line index where new declarations should be inserted (after imports).
|
||||
|
||||
Args:
|
||||
lines: Source lines.
|
||||
analyzer: TreeSitter analyzer for the file.
|
||||
source: Full source code.
|
||||
|
||||
Returns:
|
||||
Line index (0-based) for insertion.
|
||||
|
||||
"""
|
||||
try:
|
||||
imports = analyzer.find_imports(source)
|
||||
if imports:
|
||||
# Find the last import's end line
|
||||
return max(imp.end_line for imp in imports)
|
||||
except Exception as exc:
|
||||
logger.debug(f"Exception occurred in _find_insertion_line_after_imports_js: {exc}")
|
||||
|
||||
# Default: insert at beginning (after any shebang/directive comments)
|
||||
for i, line in enumerate(lines):
|
||||
stripped = line.strip()
|
||||
if stripped and not stripped.startswith("//") and not stripped.startswith("#!"):
|
||||
return i
|
||||
|
||||
return 0
|
||||
|
||||
|
||||
def get_optimized_code_for_module(relative_path: Path, optimized_code: CodeStringsMarkdown) -> str:
|
||||
file_to_code_context = optimized_code.file_to_path()
|
||||
module_optimized_code = file_to_code_context.get(str(relative_path))
|
||||
if module_optimized_code is None:
|
||||
logger.warning(
|
||||
f"Optimized code not found for {relative_path} In the context\n-------\n{optimized_code}\n-------\n"
|
||||
"re-check your 'markdown code structure'"
|
||||
f"existing files are {file_to_code_context.keys()}"
|
||||
)
|
||||
module_optimized_code = ""
|
||||
# Fallback: if there's only one code block with None file path,
|
||||
# use it regardless of the expected path (the AI server doesn't always include file paths)
|
||||
if "None" in file_to_code_context and len(file_to_code_context) == 1:
|
||||
module_optimized_code = file_to_code_context["None"]
|
||||
logger.debug(f"Using code block with None file_path for {relative_path}")
|
||||
else:
|
||||
logger.warning(
|
||||
f"Optimized code not found for {relative_path} In the context\n-------\n{optimized_code}\n-------\n"
|
||||
"re-check your 'markdown code structure'"
|
||||
f"existing files are {file_to_code_context.keys()}"
|
||||
)
|
||||
module_optimized_code = ""
|
||||
return module_optimized_code
|
||||
|
||||
|
||||
|
|
@ -518,7 +783,8 @@ def replace_optimized_code(
|
|||
[
|
||||
callee.qualified_name
|
||||
for callee in code_context.helper_functions
|
||||
if callee.file_path == module_path and callee.jedi_definition.type != "class"
|
||||
if callee.file_path == module_path
|
||||
and (callee.jedi_definition is None or callee.jedi_definition.type != "class")
|
||||
]
|
||||
),
|
||||
candidate.source_code,
|
||||
|
|
|
|||
|
|
@ -166,7 +166,7 @@ def filter_args(addopts_args: list[str]) -> list[str]:
|
|||
return filtered_args
|
||||
|
||||
|
||||
def modify_addopts(config_file: Path) -> tuple[str, bool]: # noqa : PLR0911
|
||||
def modify_addopts(config_file: Path) -> tuple[str, bool]:
|
||||
file_type = config_file.suffix.lower()
|
||||
filename = config_file.name
|
||||
config = None
|
||||
|
|
|
|||
|
|
@ -46,7 +46,7 @@ def extract_test_context_from_env() -> tuple[str, str | None, str]:
|
|||
|
||||
def codeflash_behavior_async(func: F) -> F:
|
||||
@wraps(func)
|
||||
async def async_wrapper(*args: Any, **kwargs: Any) -> Any: # noqa: ANN401
|
||||
async def async_wrapper(*args: Any, **kwargs: Any) -> Any:
|
||||
loop = asyncio.get_running_loop()
|
||||
function_name = func.__name__
|
||||
line_id = os.environ["CODEFLASH_CURRENT_LINE_ID"]
|
||||
|
|
@ -122,7 +122,7 @@ def codeflash_behavior_async(func: F) -> F:
|
|||
|
||||
def codeflash_performance_async(func: F) -> F:
|
||||
@wraps(func)
|
||||
async def async_wrapper(*args: Any, **kwargs: Any) -> Any: # noqa: ANN401
|
||||
async def async_wrapper(*args: Any, **kwargs: Any) -> Any:
|
||||
loop = asyncio.get_running_loop()
|
||||
function_name = func.__name__
|
||||
line_id = os.environ["CODEFLASH_CURRENT_LINE_ID"]
|
||||
|
|
@ -172,7 +172,7 @@ def codeflash_concurrency_async(func: F) -> F:
|
|||
"""Measures concurrent vs sequential execution performance for async functions."""
|
||||
|
||||
@wraps(func)
|
||||
async def async_wrapper(*args: Any, **kwargs: Any) -> Any: # noqa: ANN401
|
||||
async def async_wrapper(*args: Any, **kwargs: Any) -> Any:
|
||||
function_name = func.__name__
|
||||
concurrency_factor = int(os.environ.get("CODEFLASH_CONCURRENCY_FACTOR", "10"))
|
||||
|
||||
|
|
|
|||
|
|
@ -71,7 +71,7 @@ class AssertCleanup:
|
|||
|
||||
unittest_match = self.unittest_re.match(line)
|
||||
if unittest_match:
|
||||
indent, assert_method, args = unittest_match.groups()
|
||||
indent, _assert_method, args = unittest_match.groups()
|
||||
|
||||
if args:
|
||||
arg_parts = self._first_top_level_arg(args)
|
||||
|
|
|
|||
|
|
@ -91,7 +91,7 @@ EFFORT_VALUES: dict[str, dict[EffortLevel, Any]] = {
|
|||
}
|
||||
|
||||
|
||||
def get_effort_value(key: EffortKeys, effort: Union[EffortLevel, str]) -> Any: # noqa: ANN401
|
||||
def get_effort_value(key: EffortKeys, effort: Union[EffortLevel, str]) -> Any:
|
||||
key_str = key.value
|
||||
|
||||
if isinstance(effort, str):
|
||||
|
|
|
|||
290
codeflash/code_utils/config_js.py
Normal file
290
codeflash/code_utils/config_js.py
Normal file
|
|
@ -0,0 +1,290 @@
|
|||
"""JavaScript/TypeScript configuration parsing from package.json."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
PACKAGE_JSON_CACHE: dict[Path, Path] = {}
|
||||
PACKAGE_JSON_DATA_CACHE: dict[Path, dict[str, Any]] = {}
|
||||
|
||||
|
||||
def get_package_json_data(package_json_path: Path) -> dict[str, Any] | None:
|
||||
"""Load and cache package.json data.
|
||||
|
||||
Args:
|
||||
package_json_path: Path to package.json file.
|
||||
|
||||
Returns:
|
||||
Parsed package.json data or None if invalid.
|
||||
|
||||
"""
|
||||
if package_json_path in PACKAGE_JSON_DATA_CACHE:
|
||||
return PACKAGE_JSON_DATA_CACHE[package_json_path]
|
||||
|
||||
try:
|
||||
with package_json_path.open(encoding="utf8") as f:
|
||||
data: dict[str, Any] = json.load(f)
|
||||
PACKAGE_JSON_DATA_CACHE[package_json_path] = data
|
||||
return data
|
||||
except (json.JSONDecodeError, OSError):
|
||||
return None
|
||||
|
||||
|
||||
def detect_language(project_root: Path) -> str:
|
||||
"""Detect project language from tsconfig.json presence.
|
||||
|
||||
Args:
|
||||
project_root: Root directory of the project.
|
||||
|
||||
Returns:
|
||||
"typescript" if tsconfig.json exists, "javascript" otherwise.
|
||||
|
||||
"""
|
||||
tsconfig_path = project_root / "tsconfig.json"
|
||||
return "typescript" if tsconfig_path.exists() else "javascript"
|
||||
|
||||
|
||||
def detect_module_root(project_root: Path, package_data: dict[str, Any]) -> str:
|
||||
"""Detect module root from package.json fields or directory conventions.
|
||||
|
||||
Detection order:
|
||||
1. package.json "exports" field (extract directory from main export)
|
||||
2. package.json "module" field (ESM entry point)
|
||||
3. package.json "main" field (CJS entry point)
|
||||
4. "src/" directory if it exists
|
||||
5. Fall back to "." (project root)
|
||||
|
||||
Args:
|
||||
project_root: Root directory of the project.
|
||||
package_data: Parsed package.json data.
|
||||
|
||||
Returns:
|
||||
Detected module root path (relative to project root).
|
||||
|
||||
"""
|
||||
# Check exports field (modern Node.js)
|
||||
exports = package_data.get("exports")
|
||||
if exports:
|
||||
entry_path = None
|
||||
if isinstance(exports, str):
|
||||
entry_path = exports
|
||||
elif isinstance(exports, dict):
|
||||
# Handle {"." : "./src/index.js"} or {".": {"import": "./src/index.js"}}
|
||||
main_export = exports.get(".") or exports.get("import") or exports.get("default")
|
||||
if isinstance(main_export, str):
|
||||
entry_path = main_export
|
||||
elif isinstance(main_export, dict):
|
||||
entry_path = main_export.get("import") or main_export.get("default") or main_export.get("require")
|
||||
|
||||
if entry_path and isinstance(entry_path, str):
|
||||
parent = Path(entry_path).parent
|
||||
if parent != Path() and (project_root / parent).is_dir():
|
||||
return parent.as_posix()
|
||||
|
||||
# Check module field (ESM)
|
||||
module_field = package_data.get("module")
|
||||
if module_field and isinstance(module_field, str):
|
||||
parent = Path(module_field).parent
|
||||
if parent != Path() and (project_root / parent).is_dir():
|
||||
return parent.as_posix()
|
||||
|
||||
# Check main field (CJS)
|
||||
main_field = package_data.get("main")
|
||||
if main_field and isinstance(main_field, str):
|
||||
parent = Path(main_field).parent
|
||||
if parent != Path() and (project_root / parent).is_dir():
|
||||
return parent.as_posix()
|
||||
|
||||
# Check for src/ directory convention
|
||||
if (project_root / "src").is_dir():
|
||||
return "src"
|
||||
|
||||
# Default to project root
|
||||
return "."
|
||||
|
||||
|
||||
def detect_test_runner(project_root: Path, package_data: dict[str, Any]) -> str: # noqa: ARG001
|
||||
"""Detect test runner from devDependencies or scripts.test.
|
||||
|
||||
Detection order:
|
||||
1. Check devDependencies for vitest, jest, mocha
|
||||
2. Parse scripts.test for runner hints
|
||||
3. Fall back to "jest" as default
|
||||
|
||||
Args:
|
||||
project_root: Root directory of the project.
|
||||
package_data: Parsed package.json data.
|
||||
|
||||
Returns:
|
||||
Detected test runner command (e.g., "jest", "vitest", "mocha").
|
||||
|
||||
"""
|
||||
runners = ["vitest", "jest", "mocha"]
|
||||
dev_deps = package_data.get("devDependencies", {})
|
||||
deps = package_data.get("dependencies", {})
|
||||
all_deps = {**deps, **dev_deps}
|
||||
|
||||
# Check devDependencies (order matters - prefer more modern runners)
|
||||
for runner in runners:
|
||||
if runner in all_deps:
|
||||
return runner
|
||||
|
||||
# Parse scripts.test for hints
|
||||
scripts = package_data.get("scripts", {})
|
||||
test_script = scripts.get("test", "")
|
||||
if isinstance(test_script, str):
|
||||
test_lower = test_script.lower()
|
||||
for runner in runners:
|
||||
if runner in test_lower:
|
||||
return runner
|
||||
|
||||
# Default to jest
|
||||
return "jest"
|
||||
|
||||
|
||||
def detect_formatter(project_root: Path, package_data: dict[str, Any]) -> list[str] | None: # noqa: ARG001
|
||||
"""Detect formatter from devDependencies.
|
||||
|
||||
Detection order:
|
||||
1. Check devDependencies for prettier
|
||||
2. Check devDependencies for eslint (with --fix)
|
||||
3. Return None if no formatter detected
|
||||
|
||||
Args:
|
||||
project_root: Root directory of the project.
|
||||
package_data: Parsed package.json data.
|
||||
|
||||
Returns:
|
||||
List of formatter commands or None if not detected.
|
||||
|
||||
"""
|
||||
dev_deps = package_data.get("devDependencies", {})
|
||||
deps = package_data.get("dependencies", {})
|
||||
all_deps = {**deps, **dev_deps}
|
||||
|
||||
# Check for prettier (preferred)
|
||||
if "prettier" in all_deps:
|
||||
return ["npx prettier --write $file"]
|
||||
|
||||
# Check for eslint (can format with --fix)
|
||||
if "eslint" in all_deps:
|
||||
return ["npx eslint --fix $file"]
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def find_package_json(config_file: Path | None = None) -> Path | None:
|
||||
"""Find package.json file for JavaScript/TypeScript projects.
|
||||
|
||||
Args:
|
||||
config_file: Optional explicit config file path.
|
||||
|
||||
Returns:
|
||||
Path to package.json if found, None otherwise.
|
||||
|
||||
"""
|
||||
if config_file is not None:
|
||||
config_file = Path(config_file)
|
||||
if config_file.name == "package.json" and config_file.exists():
|
||||
return config_file
|
||||
return None
|
||||
|
||||
dir_path = Path.cwd()
|
||||
cur_path = dir_path
|
||||
|
||||
if cur_path in PACKAGE_JSON_CACHE:
|
||||
return PACKAGE_JSON_CACHE[cur_path]
|
||||
|
||||
while dir_path != dir_path.parent:
|
||||
config_file = dir_path / "package.json"
|
||||
if config_file.exists():
|
||||
PACKAGE_JSON_CACHE[cur_path] = config_file
|
||||
return config_file
|
||||
dir_path = dir_path.parent
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def parse_package_json_config(package_json_path: Path) -> tuple[dict[str, Any], Path] | None:
|
||||
"""Parse codeflash config from package.json with auto-detection.
|
||||
|
||||
Most configuration is auto-detected from package.json and project structure.
|
||||
Only minimal config is stored in the "codeflash" key:
|
||||
- benchmarksRoot: Where to store benchmark files (optional, defaults to __benchmarks__)
|
||||
- ignorePaths: Paths to exclude from optimization (optional)
|
||||
- disableTelemetry: Privacy preference (optional, defaults to false)
|
||||
- formatterCmds: Override auto-detected formatter (optional)
|
||||
|
||||
Auto-detected values (not stored in config):
|
||||
- language: Detected from tsconfig.json presence
|
||||
- moduleRoot: Detected from package.json exports/module/main or src/ convention
|
||||
- testRunner: Detected from devDependencies (vitest/jest/mocha)
|
||||
- formatter: Detected from devDependencies (prettier/eslint)
|
||||
|
||||
Args:
|
||||
package_json_path: Path to package.json file.
|
||||
|
||||
Returns:
|
||||
Tuple of (config dict, path) if package.json exists, None otherwise.
|
||||
|
||||
"""
|
||||
package_data = get_package_json_data(package_json_path)
|
||||
if package_data is None:
|
||||
return None
|
||||
|
||||
project_root = package_json_path.parent
|
||||
codeflash_config = package_data.get("codeflash", {})
|
||||
if not isinstance(codeflash_config, dict):
|
||||
codeflash_config = {}
|
||||
|
||||
config: dict[str, Any] = {}
|
||||
|
||||
# Auto-detect language
|
||||
config["language"] = detect_language(project_root)
|
||||
|
||||
# Auto-detect module root (can be overridden)
|
||||
if codeflash_config.get("moduleRoot"):
|
||||
config["module_root"] = str((project_root / Path(codeflash_config["moduleRoot"])).resolve())
|
||||
else:
|
||||
detected_module_root = detect_module_root(project_root, package_data)
|
||||
config["module_root"] = str((project_root / Path(detected_module_root)).resolve())
|
||||
|
||||
# Auto-detect test runner
|
||||
config["test_runner"] = detect_test_runner(project_root, package_data)
|
||||
# Keep pytest_cmd for backwards compatibility with existing code
|
||||
config["pytest_cmd"] = config["test_runner"]
|
||||
|
||||
# Auto-detect formatter (with optional override from config)
|
||||
if "formatterCmds" in codeflash_config:
|
||||
config["formatter_cmds"] = codeflash_config["formatterCmds"]
|
||||
else:
|
||||
detected_formatter = detect_formatter(project_root, package_data)
|
||||
config["formatter_cmds"] = detected_formatter if detected_formatter else []
|
||||
|
||||
# Parse optional config values from codeflash section
|
||||
if codeflash_config.get("benchmarksRoot"):
|
||||
config["benchmarks_root"] = str((project_root / Path(codeflash_config["benchmarksRoot"])).resolve())
|
||||
|
||||
if codeflash_config.get("ignorePaths"):
|
||||
config["ignore_paths"] = [str((project_root / path).resolve()) for path in codeflash_config["ignorePaths"]]
|
||||
else:
|
||||
config["ignore_paths"] = []
|
||||
|
||||
config["disable_telemetry"] = codeflash_config.get("disableTelemetry", False)
|
||||
|
||||
# Git remote (from config or default to "origin")
|
||||
config["git_remote"] = codeflash_config.get("gitRemote", "origin")
|
||||
|
||||
# Set remaining defaults for backwards compatibility
|
||||
config.setdefault("disable_imports_sorting", False)
|
||||
config.setdefault("override_fixtures", False)
|
||||
|
||||
return config, package_json_path
|
||||
|
||||
|
||||
def clear_cache() -> None:
|
||||
"""Clear all package.json caches."""
|
||||
PACKAGE_JSON_CACHE.clear()
|
||||
PACKAGE_JSON_DATA_CACHE.clear()
|
||||
|
|
@ -5,10 +5,11 @@ from typing import Any
|
|||
|
||||
import tomlkit
|
||||
|
||||
from codeflash.code_utils.config_js import find_package_json, parse_package_json_config
|
||||
from codeflash.lsp.helpers import is_LSP_enabled
|
||||
|
||||
PYPROJECT_TOML_CACHE = {}
|
||||
ALL_CONFIG_FILES = {} # map path to closest config file
|
||||
PYPROJECT_TOML_CACHE: dict[Path, Path] = {}
|
||||
ALL_CONFIG_FILES: dict[Path, dict[str, Path]] = {}
|
||||
|
||||
|
||||
def find_pyproject_toml(config_file: Path | None = None) -> Path:
|
||||
|
|
@ -83,10 +84,27 @@ def find_conftest_files(test_paths: list[Path]) -> list[Path]:
|
|||
return list(list_of_conftest_files)
|
||||
|
||||
|
||||
# TODO for claude: There should be different functions to parse it per language, which should be chosen during runtime
|
||||
def parse_config_file(
|
||||
config_file_path: Path | None = None,
|
||||
override_formatter_check: bool = False, # noqa: FBT001, FBT002
|
||||
config_file_path: Path | None = None, override_formatter_check: bool = False
|
||||
) -> tuple[dict[str, Any], Path]:
|
||||
# First try package.json for JS/TS projects
|
||||
package_json_path = find_package_json(config_file_path)
|
||||
if package_json_path:
|
||||
result = parse_package_json_config(package_json_path)
|
||||
if result is not None:
|
||||
config, path = result
|
||||
# Validate formatter if needed
|
||||
if not override_formatter_check and config.get("formatter_cmds"):
|
||||
formatter_cmds = config.get("formatter_cmds", [])
|
||||
if formatter_cmds and formatter_cmds[0] == "your-formatter $file":
|
||||
raise ValueError(
|
||||
"The formatter command is not set correctly in package.json. Please set the "
|
||||
"formatter command in the 'formatterCmds' key."
|
||||
)
|
||||
return config, path
|
||||
|
||||
# Fall back to pyproject.toml
|
||||
config_file_path = find_pyproject_toml(config_file_path)
|
||||
try:
|
||||
with config_file_path.open("rb") as f:
|
||||
|
|
|
|||
|
|
@ -1,250 +1,129 @@
|
|||
import ast
|
||||
"""Code deduplication utilities using language-specific normalizers.
|
||||
|
||||
This module provides functions to normalize code, generate fingerprints,
|
||||
and detect duplicate code segments across different programming languages.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import re
|
||||
|
||||
from codeflash.code_utils.normalizers import get_normalizer
|
||||
from codeflash.languages import current_language, is_python
|
||||
|
||||
|
||||
class VariableNormalizer(ast.NodeTransformer):
|
||||
"""Normalizes only local variable names in AST to canonical forms like var_0, var_1, etc.
|
||||
|
||||
Preserves function names, class names, parameters, built-ins, and imported names.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.var_counter = 0
|
||||
self.var_mapping: dict[str, str] = {}
|
||||
self.scope_stack = []
|
||||
self.builtins = set(dir(__builtins__))
|
||||
self.imports: set[str] = set()
|
||||
self.global_vars: set[str] = set()
|
||||
self.nonlocal_vars: set[str] = set()
|
||||
self.parameters: set[str] = set() # Track function parameters
|
||||
|
||||
def enter_scope(self): # noqa : ANN201
|
||||
"""Enter a new scope (function/class)."""
|
||||
self.scope_stack.append(
|
||||
{"var_mapping": dict(self.var_mapping), "var_counter": self.var_counter, "parameters": set(self.parameters)}
|
||||
)
|
||||
|
||||
def exit_scope(self): # noqa : ANN201
|
||||
"""Exit current scope and restore parent scope."""
|
||||
if self.scope_stack:
|
||||
scope = self.scope_stack.pop()
|
||||
self.var_mapping = scope["var_mapping"]
|
||||
self.var_counter = scope["var_counter"]
|
||||
self.parameters = scope["parameters"]
|
||||
|
||||
def get_normalized_name(self, name: str) -> str:
|
||||
"""Get or create normalized name for a variable."""
|
||||
# Don't normalize if it's a builtin, import, global, nonlocal, or parameter
|
||||
if (
|
||||
name in self.builtins
|
||||
or name in self.imports
|
||||
or name in self.global_vars
|
||||
or name in self.nonlocal_vars
|
||||
or name in self.parameters
|
||||
):
|
||||
return name
|
||||
|
||||
# Only normalize local variables
|
||||
if name not in self.var_mapping:
|
||||
self.var_mapping[name] = f"var_{self.var_counter}"
|
||||
self.var_counter += 1
|
||||
return self.var_mapping[name]
|
||||
|
||||
def visit_Import(self, node): # noqa : ANN001, ANN201
|
||||
"""Track imported names."""
|
||||
for alias in node.names:
|
||||
name = alias.asname if alias.asname else alias.name
|
||||
self.imports.add(name.split(".")[0])
|
||||
return node
|
||||
|
||||
def visit_ImportFrom(self, node): # noqa : ANN001, ANN201
|
||||
"""Track imported names from modules."""
|
||||
for alias in node.names:
|
||||
name = alias.asname if alias.asname else alias.name
|
||||
self.imports.add(name)
|
||||
return node
|
||||
|
||||
def visit_Global(self, node): # noqa : ANN001, ANN201
|
||||
"""Track global variable declarations."""
|
||||
# Avoid repeated .add calls by using set.update with list
|
||||
self.global_vars.update(node.names)
|
||||
return node
|
||||
|
||||
def visit_Nonlocal(self, node): # noqa : ANN001, ANN201
|
||||
"""Track nonlocal variable declarations."""
|
||||
# Using set.update for batch insertion (faster than add-in-loop)
|
||||
self.nonlocal_vars.update(node.names)
|
||||
return node
|
||||
|
||||
def visit_FunctionDef(self, node): # noqa : ANN001, ANN201
|
||||
"""Process function but keep function name and parameters unchanged."""
|
||||
self.enter_scope()
|
||||
|
||||
# Track all parameters (don't modify them)
|
||||
for arg in node.args.args:
|
||||
self.parameters.add(arg.arg)
|
||||
if node.args.vararg:
|
||||
self.parameters.add(node.args.vararg.arg)
|
||||
if node.args.kwarg:
|
||||
self.parameters.add(node.args.kwarg.arg)
|
||||
for arg in node.args.kwonlyargs:
|
||||
self.parameters.add(arg.arg)
|
||||
|
||||
# Visit function body
|
||||
node = self.generic_visit(node)
|
||||
self.exit_scope()
|
||||
return node
|
||||
|
||||
def visit_AsyncFunctionDef(self, node): # noqa : ANN001, ANN201
|
||||
"""Handle async functions same as regular functions."""
|
||||
return self.visit_FunctionDef(node)
|
||||
|
||||
def visit_ClassDef(self, node): # noqa : ANN001, ANN201
|
||||
"""Process class but keep class name unchanged."""
|
||||
self.enter_scope()
|
||||
node = self.generic_visit(node)
|
||||
self.exit_scope()
|
||||
return node
|
||||
|
||||
def visit_Name(self, node): # noqa : ANN001, ANN201
|
||||
"""Normalize variable names in Name nodes."""
|
||||
if isinstance(node.ctx, (ast.Store, ast.Del)):
|
||||
# For assignments and deletions, check if we should normalize
|
||||
if (
|
||||
node.id not in self.builtins
|
||||
and node.id not in self.imports
|
||||
and node.id not in self.parameters
|
||||
and node.id not in self.global_vars
|
||||
and node.id not in self.nonlocal_vars
|
||||
):
|
||||
node.id = self.get_normalized_name(node.id)
|
||||
elif isinstance(node.ctx, ast.Load): # noqa : SIM102
|
||||
# For loading, use existing mapping if available
|
||||
if node.id in self.var_mapping:
|
||||
node.id = self.var_mapping[node.id]
|
||||
return node
|
||||
|
||||
def visit_ExceptHandler(self, node): # noqa : ANN001, ANN201
|
||||
"""Normalize exception variable names."""
|
||||
if node.name:
|
||||
node.name = self.get_normalized_name(node.name)
|
||||
return self.generic_visit(node)
|
||||
|
||||
def visit_comprehension(self, node): # noqa : ANN001, ANN201
|
||||
"""Normalize comprehension target variables."""
|
||||
# Create new scope for comprehension
|
||||
old_mapping = dict(self.var_mapping)
|
||||
old_counter = self.var_counter
|
||||
|
||||
# Process the comprehension
|
||||
node = self.generic_visit(node)
|
||||
|
||||
# Restore scope
|
||||
self.var_mapping = old_mapping
|
||||
self.var_counter = old_counter
|
||||
return node
|
||||
|
||||
def visit_For(self, node): # noqa : ANN001, ANN201
|
||||
"""Handle for loop target variables."""
|
||||
# The target in a for loop is a local variable that should be normalized
|
||||
return self.generic_visit(node)
|
||||
|
||||
def visit_With(self, node): # noqa : ANN001, ANN201
|
||||
"""Handle with statement as variables."""
|
||||
return self.generic_visit(node)
|
||||
|
||||
|
||||
def normalize_code(code: str, remove_docstrings: bool = True, return_ast_dump: bool = False) -> str: # noqa : FBT002, FBT001
|
||||
"""Normalize Python code by parsing, cleaning, and normalizing only variable names.
|
||||
def normalize_code(
|
||||
code: str,
|
||||
remove_docstrings: bool = True,
|
||||
return_ast_dump: bool = False,
|
||||
language: str | None = None,
|
||||
) -> str:
|
||||
"""Normalize code by parsing, cleaning, and normalizing variable names.
|
||||
|
||||
Function names, class names, and parameters are preserved.
|
||||
|
||||
Args:
|
||||
code: Python source code as string
|
||||
remove_docstrings: Whether to remove docstrings
|
||||
return_ast_dump: return_ast_dump
|
||||
code: Source code as string
|
||||
remove_docstrings: Whether to remove docstrings (Python only)
|
||||
return_ast_dump: Return AST dump instead of unparsed code (Python only)
|
||||
language: Language of the code. If None, uses the current session language.
|
||||
|
||||
Returns:
|
||||
Normalized code as string
|
||||
|
||||
"""
|
||||
if language is None:
|
||||
language = current_language().value
|
||||
|
||||
try:
|
||||
# Parse the code
|
||||
tree = ast.parse(code)
|
||||
normalizer = get_normalizer(language)
|
||||
|
||||
# Remove docstrings if requested
|
||||
if remove_docstrings:
|
||||
remove_docstrings_from_ast(tree)
|
||||
# Python has additional options
|
||||
if is_python():
|
||||
if return_ast_dump:
|
||||
return normalizer.normalize_for_hash(code)
|
||||
return normalizer.normalize(code, remove_docstrings=remove_docstrings)
|
||||
|
||||
# Normalize variable names
|
||||
normalizer = VariableNormalizer()
|
||||
normalized_tree = normalizer.visit(tree)
|
||||
if return_ast_dump:
|
||||
# This is faster than unparsing etc
|
||||
return ast.dump(normalized_tree, annotate_fields=False, include_attributes=False)
|
||||
|
||||
# Fix missing locations in the AST
|
||||
ast.fix_missing_locations(normalized_tree)
|
||||
|
||||
# Unparse back to code
|
||||
return ast.unparse(normalized_tree)
|
||||
except SyntaxError as e:
|
||||
msg = f"Invalid Python syntax: {e}"
|
||||
raise ValueError(msg) from e
|
||||
# For other languages, use standard normalization
|
||||
return normalizer.normalize(code)
|
||||
except ValueError:
|
||||
# Unknown language - fall back to basic normalization
|
||||
return _basic_normalize(code)
|
||||
except Exception:
|
||||
# Parsing error - try other languages or fall back
|
||||
if is_python():
|
||||
# Try JavaScript as fallback
|
||||
try:
|
||||
js_normalizer = get_normalizer("javascript")
|
||||
js_result = js_normalizer.normalize(code)
|
||||
if js_result != _basic_normalize(code):
|
||||
return js_result
|
||||
except Exception:
|
||||
pass
|
||||
return _basic_normalize(code)
|
||||
|
||||
|
||||
def remove_docstrings_from_ast(node): # noqa : ANN001, ANN201
|
||||
"""Remove docstrings from AST nodes."""
|
||||
# Only FunctionDef, AsyncFunctionDef, ClassDef, and Module can contain docstrings in their body[0]
|
||||
node_types = (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef, ast.Module)
|
||||
# Use our own stack-based DFS instead of ast.walk for efficiency
|
||||
stack = [node]
|
||||
while stack:
|
||||
current_node = stack.pop()
|
||||
if isinstance(current_node, node_types):
|
||||
# Remove docstring if it's the first stmt in body
|
||||
body = current_node.body
|
||||
if (
|
||||
body
|
||||
and isinstance(body[0], ast.Expr)
|
||||
and isinstance(body[0].value, ast.Constant)
|
||||
and isinstance(body[0].value.value, str)
|
||||
):
|
||||
current_node.body = body[1:]
|
||||
# Only these nodes can nest more docstring-containing nodes
|
||||
# Add their body elements to stack, avoiding unnecessary traversal
|
||||
stack.extend([child for child in body if isinstance(child, node_types)])
|
||||
def _basic_normalize(code: str) -> str:
|
||||
"""Basic normalization: remove comments and normalize whitespace."""
|
||||
# Remove single-line comments (// and #)
|
||||
code = re.sub(r"//.*$", "", code, flags=re.MULTILINE)
|
||||
code = re.sub(r"#.*$", "", code, flags=re.MULTILINE)
|
||||
# Remove multi-line comments
|
||||
code = re.sub(r"/\*.*?\*/", "", code, flags=re.DOTALL)
|
||||
code = re.sub(r'""".*?"""', "", code, flags=re.DOTALL)
|
||||
code = re.sub(r"'''.*?'''", "", code, flags=re.DOTALL)
|
||||
# Normalize whitespace
|
||||
return " ".join(code.split())
|
||||
|
||||
|
||||
def get_code_fingerprint(code: str) -> str:
|
||||
def get_code_fingerprint(code: str, language: str | None = None) -> str:
|
||||
"""Generate a fingerprint for normalized code.
|
||||
|
||||
Args:
|
||||
code: Python source code
|
||||
code: Source code
|
||||
language: Language of the code. If None, uses the current session language.
|
||||
|
||||
Returns:
|
||||
SHA-256 hash of normalized code
|
||||
|
||||
"""
|
||||
normalized = normalize_code(code)
|
||||
return hashlib.sha256(normalized.encode()).hexdigest()
|
||||
if language is None:
|
||||
language = current_language().value
|
||||
|
||||
try:
|
||||
normalizer = get_normalizer(language)
|
||||
return normalizer.get_fingerprint(code)
|
||||
except ValueError:
|
||||
# Unknown language - use basic normalization
|
||||
normalized = _basic_normalize(code)
|
||||
return hashlib.sha256(normalized.encode()).hexdigest()
|
||||
|
||||
|
||||
def are_codes_duplicate(code1: str, code2: str) -> bool:
|
||||
def are_codes_duplicate(code1: str, code2: str, language: str | None = None) -> bool:
|
||||
"""Check if two code segments are duplicates after normalization.
|
||||
|
||||
Args:
|
||||
code1: First code segment
|
||||
code2: Second code segment
|
||||
language: Language of the code. If None, uses the current session language.
|
||||
|
||||
Returns:
|
||||
True if codes are structurally identical (ignoring local variable names)
|
||||
|
||||
"""
|
||||
if language is None:
|
||||
language = current_language().value
|
||||
|
||||
try:
|
||||
normalized1 = normalize_code(code1, return_ast_dump=True)
|
||||
normalized2 = normalize_code(code2, return_ast_dump=True)
|
||||
normalizer = get_normalizer(language)
|
||||
return normalizer.are_duplicates(code1, code2)
|
||||
except ValueError:
|
||||
# Unknown language - use basic comparison
|
||||
return _basic_normalize(code1) == _basic_normalize(code2)
|
||||
except Exception:
|
||||
return False
|
||||
else:
|
||||
return normalized1 == normalized2
|
||||
|
||||
|
||||
# Re-export for backward compatibility
|
||||
__all__ = ["are_codes_duplicate", "get_code_fingerprint", "normalize_code"]
|
||||
|
|
|
|||
|
|
@ -12,6 +12,7 @@ from libcst.metadata import PositionProvider
|
|||
|
||||
from codeflash.cli_cmds.console import logger
|
||||
from codeflash.code_utils.time_utils import format_perf, format_time
|
||||
from codeflash.languages.registry import get_language_support
|
||||
from codeflash.models.models import GeneratedTests, GeneratedTestsList
|
||||
from codeflash.result.critic import performance_gain
|
||||
|
||||
|
|
@ -149,25 +150,85 @@ class CommentAdder(cst.CSTTransformer):
|
|||
return updated_node
|
||||
|
||||
|
||||
def _is_python_file(file_path: Path) -> bool:
|
||||
"""Check if a file is a Python file."""
|
||||
return file_path.suffix == ".py"
|
||||
|
||||
|
||||
# TODO:{self} Needs cleanup for jest logic in else block
|
||||
def unique_inv_id(inv_id_runtimes: dict[InvocationId, list[int]], tests_project_rootdir: Path) -> dict[str, int]:
|
||||
unique_inv_ids: dict[str, int] = {}
|
||||
logger.debug(f"[unique_inv_id] Processing {len(inv_id_runtimes)} invocation IDs")
|
||||
for inv_id, runtimes in inv_id_runtimes.items():
|
||||
test_qualified_name = (
|
||||
inv_id.test_class_name + "." + inv_id.test_function_name # type: ignore[operator]
|
||||
if inv_id.test_class_name
|
||||
else inv_id.test_function_name
|
||||
)
|
||||
abs_path = tests_project_rootdir / Path(inv_id.test_module_path.replace(".", os.sep)).with_suffix(".py")
|
||||
|
||||
# Detect if test_module_path is a file path (like in js tests) or a Python module name
|
||||
# File paths contain slashes, module names use dots
|
||||
test_module_path = inv_id.test_module_path
|
||||
if "/" in test_module_path or "\\" in test_module_path:
|
||||
# Already a file path - use directly
|
||||
abs_path = tests_project_rootdir / Path(test_module_path)
|
||||
else:
|
||||
# Check for Jest test file extensions (e.g., tests.fibonacci.test.ts)
|
||||
# These need special handling to avoid converting .test.ts -> /test/ts
|
||||
jest_test_extensions = (
|
||||
".test.ts",
|
||||
".test.js",
|
||||
".test.tsx",
|
||||
".test.jsx",
|
||||
".spec.ts",
|
||||
".spec.js",
|
||||
".spec.tsx",
|
||||
".spec.jsx",
|
||||
".ts",
|
||||
".js",
|
||||
".tsx",
|
||||
".jsx",
|
||||
".mjs",
|
||||
".mts",
|
||||
)
|
||||
matched_ext = None
|
||||
for ext in jest_test_extensions:
|
||||
if test_module_path.endswith(ext):
|
||||
matched_ext = ext
|
||||
break
|
||||
|
||||
if matched_ext:
|
||||
# JavaScript/TypeScript: convert module-style path to file path
|
||||
# "tests.fibonacci__perfonlyinstrumented.test.ts" -> "tests/fibonacci__perfonlyinstrumented.test.ts"
|
||||
base_path = test_module_path[: -len(matched_ext)]
|
||||
file_path = base_path.replace(".", os.sep) + matched_ext
|
||||
# Check if the module path includes the tests directory name
|
||||
tests_dir_name = tests_project_rootdir.name
|
||||
if file_path.startswith((tests_dir_name + os.sep, tests_dir_name + "/")):
|
||||
# Module path includes "tests." - use parent directory
|
||||
abs_path = tests_project_rootdir.parent / Path(file_path)
|
||||
else:
|
||||
# Module path doesn't include tests dir - use tests root directly
|
||||
abs_path = tests_project_rootdir / Path(file_path)
|
||||
else:
|
||||
# Python module name - convert dots to path separators and add .py
|
||||
abs_path = tests_project_rootdir / Path(test_module_path.replace(".", os.sep)).with_suffix(".py")
|
||||
|
||||
abs_path_str = str(abs_path.resolve().with_suffix(""))
|
||||
if "__unit_test_" not in abs_path_str or not test_qualified_name:
|
||||
# Include both unit test and perf test paths for runtime annotations
|
||||
# (performance test runtimes are used for annotations)
|
||||
if ("__unit_test_" not in abs_path_str and "__perf_test_" not in abs_path_str) or not test_qualified_name:
|
||||
logger.debug(f"[unique_inv_id] Skipping: path={abs_path_str}, test_qualified_name={test_qualified_name}")
|
||||
continue
|
||||
key = test_qualified_name + "#" + abs_path_str
|
||||
parts = inv_id.iteration_id.split("_").__len__() # type: ignore[union-attr]
|
||||
cur_invid = inv_id.iteration_id.split("_")[0] if parts < 3 else "_".join(inv_id.iteration_id.split("_")[:-1]) # type: ignore[union-attr]
|
||||
match_key = key + "#" + cur_invid
|
||||
logger.debug(f"[unique_inv_id] Adding key: {match_key} with runtime {min(runtimes)}")
|
||||
if match_key not in unique_inv_ids:
|
||||
unique_inv_ids[match_key] = 0
|
||||
unique_inv_ids[match_key] += min(runtimes)
|
||||
logger.debug(f"[unique_inv_id] Result has {len(unique_inv_ids)} entries")
|
||||
return unique_inv_ids
|
||||
|
||||
|
||||
|
|
@ -183,25 +244,46 @@ def add_runtime_comments_to_generated_tests(
|
|||
# Process each generated test
|
||||
modified_tests = []
|
||||
for test in generated_tests.generated_tests:
|
||||
try:
|
||||
tree = cst.parse_module(test.generated_original_test_source)
|
||||
wrapper = MetadataWrapper(tree)
|
||||
line_to_comments = get_fn_call_linenos(test, original_runtimes_dict, optimized_runtimes_dict)
|
||||
comment_adder = CommentAdder(line_to_comments)
|
||||
modified_tree = wrapper.visit(comment_adder)
|
||||
modified_source = modified_tree.code
|
||||
modified_test = GeneratedTests(
|
||||
generated_original_test_source=modified_source,
|
||||
instrumented_behavior_test_source=test.instrumented_behavior_test_source,
|
||||
instrumented_perf_test_source=test.instrumented_perf_test_source,
|
||||
behavior_file_path=test.behavior_file_path,
|
||||
perf_file_path=test.perf_file_path,
|
||||
)
|
||||
modified_tests.append(modified_test)
|
||||
except Exception as e:
|
||||
# If parsing fails, keep the original test
|
||||
logger.debug(f"Failed to add runtime comments to test: {e}")
|
||||
modified_tests.append(test)
|
||||
is_python = _is_python_file(test.behavior_file_path)
|
||||
|
||||
if is_python:
|
||||
# Use Python libcst-based comment insertion
|
||||
try:
|
||||
tree = cst.parse_module(test.generated_original_test_source)
|
||||
wrapper = MetadataWrapper(tree)
|
||||
line_to_comments = get_fn_call_linenos(test, original_runtimes_dict, optimized_runtimes_dict)
|
||||
comment_adder = CommentAdder(line_to_comments)
|
||||
modified_tree = wrapper.visit(comment_adder)
|
||||
modified_source = modified_tree.code
|
||||
modified_test = GeneratedTests(
|
||||
generated_original_test_source=modified_source,
|
||||
instrumented_behavior_test_source=test.instrumented_behavior_test_source,
|
||||
instrumented_perf_test_source=test.instrumented_perf_test_source,
|
||||
behavior_file_path=test.behavior_file_path,
|
||||
perf_file_path=test.perf_file_path,
|
||||
)
|
||||
modified_tests.append(modified_test)
|
||||
except Exception as e:
|
||||
# If parsing fails, keep the original test
|
||||
logger.debug(f"Failed to add runtime comments to test: {e}")
|
||||
modified_tests.append(test)
|
||||
else:
|
||||
try:
|
||||
language_support = get_language_support(test.behavior_file_path)
|
||||
modified_source = language_support.add_runtime_comments(
|
||||
test.generated_original_test_source, original_runtimes_dict, optimized_runtimes_dict
|
||||
)
|
||||
modified_test = GeneratedTests(
|
||||
generated_original_test_source=modified_source,
|
||||
instrumented_behavior_test_source=test.instrumented_behavior_test_source,
|
||||
instrumented_perf_test_source=test.instrumented_perf_test_source,
|
||||
behavior_file_path=test.behavior_file_path,
|
||||
perf_file_path=test.perf_file_path,
|
||||
)
|
||||
modified_tests.append(modified_test)
|
||||
except Exception as e:
|
||||
logger.debug(f"Failed to add runtime comments to test: {e}")
|
||||
modified_tests.append(test)
|
||||
|
||||
return GeneratedTestsList(generated_tests=modified_tests)
|
||||
|
||||
|
|
@ -247,3 +329,103 @@ def _compile_function_patterns(test_functions_to_remove: list[str]) -> list[re.P
|
|||
)
|
||||
for func in test_functions_to_remove
|
||||
]
|
||||
|
||||
|
||||
# Patterns for normalizing codeflash imports (legacy -> npm package)
|
||||
_CODEFLASH_REQUIRE_PATTERN = re.compile(
|
||||
r"(const|let|var)\s+(\w+)\s*=\s*require\s*\(\s*['\"]\.?/?codeflash-jest-helper['\"]\s*\)"
|
||||
)
|
||||
_CODEFLASH_IMPORT_PATTERN = re.compile(r"import\s+(?:\*\s+as\s+)?(\w+)\s+from\s+['\"]\.?/?codeflash-jest-helper['\"]")
|
||||
|
||||
|
||||
def normalize_codeflash_imports(source: str) -> str:
|
||||
"""Normalize codeflash imports to use the npm package.
|
||||
|
||||
Replaces legacy local file imports:
|
||||
const codeflash = require('./codeflash-jest-helper')
|
||||
import codeflash from './codeflash-jest-helper'
|
||||
|
||||
With npm package imports:
|
||||
const codeflash = require('codeflash')
|
||||
|
||||
Args:
|
||||
source: JavaScript/TypeScript source code.
|
||||
|
||||
Returns:
|
||||
Source code with normalized imports.
|
||||
|
||||
"""
|
||||
# Replace CommonJS require
|
||||
source = _CODEFLASH_REQUIRE_PATTERN.sub(r"\1 \2 = require('codeflash')", source)
|
||||
# Replace ES module import
|
||||
return _CODEFLASH_IMPORT_PATTERN.sub(r"import \1 from 'codeflash'", source)
|
||||
|
||||
|
||||
def inject_test_globals(generated_tests: GeneratedTestsList) -> GeneratedTestsList:
|
||||
# TODO: inside the prompt tell the llm if it should import jest functions or it's already injected in the global window
|
||||
"""Inject test globals into all generated tests.
|
||||
|
||||
Args:
|
||||
generated_tests: List of generated tests.
|
||||
|
||||
Returns:
|
||||
Generated tests with test globals injected.
|
||||
|
||||
"""
|
||||
# we only inject test globals for esm modules
|
||||
global_import = (
|
||||
"import { jest, describe, it, expect, beforeEach, afterEach, beforeAll, test } from '@jest/globals'\n"
|
||||
)
|
||||
|
||||
for test in generated_tests.generated_tests:
|
||||
test.generated_original_test_source = global_import + test.generated_original_test_source
|
||||
test.instrumented_behavior_test_source = global_import + test.instrumented_behavior_test_source
|
||||
test.instrumented_perf_test_source = global_import + test.instrumented_perf_test_source
|
||||
return generated_tests
|
||||
|
||||
|
||||
def disable_ts_check(generated_tests: GeneratedTestsList) -> GeneratedTestsList:
|
||||
"""Disable TypeScript type checking in all generated tests.
|
||||
|
||||
Args:
|
||||
generated_tests: List of generated tests.
|
||||
|
||||
Returns:
|
||||
Generated tests with TypeScript type checking disabled.
|
||||
|
||||
"""
|
||||
# we only inject test globals for esm modules
|
||||
ts_nocheck = "// @ts-nocheck\n"
|
||||
|
||||
for test in generated_tests.generated_tests:
|
||||
test.generated_original_test_source = ts_nocheck + test.generated_original_test_source
|
||||
test.instrumented_behavior_test_source = ts_nocheck + test.instrumented_behavior_test_source
|
||||
test.instrumented_perf_test_source = ts_nocheck + test.instrumented_perf_test_source
|
||||
return generated_tests
|
||||
|
||||
|
||||
def normalize_generated_tests_imports(generated_tests: GeneratedTestsList) -> GeneratedTestsList:
|
||||
"""Normalize codeflash imports in all generated tests.
|
||||
|
||||
Args:
|
||||
generated_tests: List of generated tests.
|
||||
|
||||
Returns:
|
||||
Generated tests with normalized imports.
|
||||
|
||||
"""
|
||||
normalized_tests = []
|
||||
for test in generated_tests.generated_tests:
|
||||
# Only normalize JS/TS files
|
||||
if test.behavior_file_path.suffix in (".js", ".ts", ".jsx", ".tsx", ".mjs", ".mts"):
|
||||
normalized_test = GeneratedTests(
|
||||
generated_original_test_source=normalize_codeflash_imports(test.generated_original_test_source),
|
||||
instrumented_behavior_test_source=normalize_codeflash_imports(test.instrumented_behavior_test_source),
|
||||
instrumented_perf_test_source=normalize_codeflash_imports(test.instrumented_perf_test_source),
|
||||
behavior_file_path=test.behavior_file_path,
|
||||
perf_file_path=test.perf_file_path,
|
||||
)
|
||||
normalized_tests.append(normalized_test)
|
||||
else:
|
||||
normalized_tests.append(test)
|
||||
return GeneratedTestsList(generated_tests=normalized_tests)
|
||||
|
|
|
|||
|
|
@ -16,7 +16,7 @@ from codeflash.code_utils.shell_utils import read_api_key_from_shell_config, sav
|
|||
from codeflash.lsp.helpers import is_LSP_enabled
|
||||
|
||||
|
||||
def check_formatter_installed(formatter_cmds: list[str], exit_on_failure: bool = True) -> bool: # noqa
|
||||
def check_formatter_installed(formatter_cmds: list[str], exit_on_failure: bool = True) -> bool:
|
||||
if not formatter_cmds or formatter_cmds[0] == "disabled":
|
||||
return True
|
||||
first_cmd = formatter_cmds[0]
|
||||
|
|
@ -155,7 +155,8 @@ def get_cached_gh_event_data() -> dict[str, Any]:
|
|||
if not event_path:
|
||||
return {}
|
||||
with open(event_path, encoding="utf-8") as f: # noqa: PTH123
|
||||
return json.load(f) # type: ignore # noqa
|
||||
result: dict[str, Any] = json.load(f)
|
||||
return result
|
||||
|
||||
|
||||
def is_repo_a_fork() -> bool:
|
||||
|
|
|
|||
|
|
@ -40,11 +40,7 @@ def generate_unified_diff(original: str, modified: str, from_file: str, to_file:
|
|||
|
||||
|
||||
def apply_formatter_cmds(
|
||||
cmds: list[str],
|
||||
path: Path,
|
||||
test_dir_str: Optional[str],
|
||||
print_status: bool, # noqa
|
||||
exit_on_failure: bool = True, # noqa
|
||||
cmds: list[str], path: Path, test_dir_str: Optional[str], print_status: bool, exit_on_failure: bool = True
|
||||
) -> tuple[Path, str, bool]:
|
||||
if not path.exists():
|
||||
msg = f"File {path} does not exist. Cannot apply formatter commands."
|
||||
|
|
@ -111,9 +107,9 @@ def format_code(
|
|||
formatter_cmds: list[str],
|
||||
path: Union[str, Path],
|
||||
optimized_code: str = "",
|
||||
check_diff: bool = False, # noqa
|
||||
print_status: bool = True, # noqa
|
||||
exit_on_failure: bool = True, # noqa
|
||||
check_diff: bool = False,
|
||||
print_status: bool = True,
|
||||
exit_on_failure: bool = True,
|
||||
) -> str:
|
||||
if is_LSP_enabled():
|
||||
exit_on_failure = False
|
||||
|
|
@ -174,7 +170,7 @@ def format_code(
|
|||
return formatted_code
|
||||
|
||||
|
||||
def sort_imports(code: str, **kwargs: Any) -> str: # noqa : ANN401
|
||||
def sort_imports(code: str, **kwargs: Any) -> str:
|
||||
try:
|
||||
# Deduplicate and sort imports, modify the code in memory, not on disk
|
||||
sorted_code = isort.code(code, **kwargs)
|
||||
|
|
|
|||
|
|
@ -89,7 +89,7 @@ class InjectPerfOnly(ast.NodeTransformer):
|
|||
# it's much more efficient to visit nodes manually. We'll only descend into expressions/statements.
|
||||
|
||||
# Helper for manual walk
|
||||
def iter_ast_calls(node): # noqa: ANN202, ANN001
|
||||
def iter_ast_calls(node): # noqa: ANN202
|
||||
# Generator to yield each ast.Call in test_node, preserves node identity
|
||||
stack = [node]
|
||||
while stack:
|
||||
|
|
@ -690,15 +690,14 @@ def detect_frameworks_from_code(code: str) -> dict[str, str]:
|
|||
frameworks["tensorflow"] = alias.asname if alias.asname else module_name
|
||||
elif module_name == "jax":
|
||||
frameworks["jax"] = alias.asname if alias.asname else module_name
|
||||
elif isinstance(node, ast.ImportFrom): # noqa: SIM102
|
||||
if node.module:
|
||||
module_name = node.module.split(".")[0]
|
||||
if module_name == "torch" and "torch" not in frameworks:
|
||||
frameworks["torch"] = module_name
|
||||
elif module_name == "tensorflow" and "tensorflow" not in frameworks:
|
||||
frameworks["tensorflow"] = module_name
|
||||
elif module_name == "jax" and "jax" not in frameworks:
|
||||
frameworks["jax"] = module_name
|
||||
elif isinstance(node, ast.ImportFrom) and node.module:
|
||||
module_name = node.module.split(".")[0]
|
||||
if module_name == "torch" and "torch" not in frameworks:
|
||||
frameworks["torch"] = module_name
|
||||
elif module_name == "tensorflow" and "tensorflow" not in frameworks:
|
||||
frameworks["tensorflow"] = module_name
|
||||
elif module_name == "jax" and "jax" not in frameworks:
|
||||
frameworks["jax"] = module_name
|
||||
|
||||
return frameworks
|
||||
|
||||
|
|
@ -910,8 +909,7 @@ def _create_device_sync_precompute_statements(used_frameworks: dict[str, str] |
|
|||
|
||||
|
||||
def _create_device_sync_statements(
|
||||
used_frameworks: dict[str, str] | None,
|
||||
for_return_value: bool = False, # noqa: FBT001, FBT002
|
||||
used_frameworks: dict[str, str] | None, for_return_value: bool = False
|
||||
) -> list[ast.stmt]:
|
||||
"""Create AST statements for device synchronization using pre-computed conditions.
|
||||
|
||||
|
|
@ -1450,7 +1448,7 @@ class AsyncDecoratorAdder(cst.CSTTransformer):
|
|||
# Track when we enter a class
|
||||
self.context_stack.append(node.name.value)
|
||||
|
||||
def leave_ClassDef(self, original_node: cst.ClassDef, updated_node: cst.ClassDef) -> cst.ClassDef: # noqa: ARG002
|
||||
def leave_ClassDef(self, original_node: cst.ClassDef, updated_node: cst.ClassDef) -> cst.ClassDef:
|
||||
# Pop the context when we leave a class
|
||||
self.context_stack.pop()
|
||||
return updated_node
|
||||
|
|
@ -1530,7 +1528,7 @@ class AsyncDecoratorImportAdder(cst.CSTTransformer):
|
|||
if import_alias.name.value == decorator_name:
|
||||
self.has_import = True
|
||||
|
||||
def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module: # noqa: ARG002
|
||||
def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module:
|
||||
# If the import is already there, don't add it again
|
||||
if self.has_import:
|
||||
return updated_node
|
||||
|
|
|
|||
|
|
@ -204,7 +204,7 @@ class LineProfilerDecoratorAdder(cst.CSTTransformer):
|
|||
# Track when we enter a class
|
||||
self.context_stack.append(node.name.value)
|
||||
|
||||
def leave_ClassDef(self, original_node: cst.ClassDef, updated_node: cst.ClassDef) -> cst.ClassDef: # noqa: ARG002
|
||||
def leave_ClassDef(self, original_node: cst.ClassDef, updated_node: cst.ClassDef) -> cst.ClassDef:
|
||||
# Pop the context when we leave a class
|
||||
self.context_stack.pop()
|
||||
return updated_node
|
||||
|
|
@ -268,7 +268,7 @@ class ProfileEnableTransformer(cst.CSTTransformer):
|
|||
|
||||
return updated_node
|
||||
|
||||
def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module: # noqa: ARG002
|
||||
def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module:
|
||||
if not self.found_import:
|
||||
return updated_node
|
||||
|
||||
|
|
@ -332,11 +332,11 @@ def add_profile_enable(original_code: str, line_profile_output_file: str) -> str
|
|||
|
||||
|
||||
class ImportAdder(cst.CSTTransformer):
|
||||
def __init__(self, import_statement) -> None: # noqa: ANN001
|
||||
def __init__(self, import_statement) -> None:
|
||||
self.import_statement = import_statement
|
||||
self.has_import = False
|
||||
|
||||
def leave_Module(self, original_node, updated_node): # noqa: ANN001, ANN201, ARG002
|
||||
def leave_Module(self, original_node, updated_node): # noqa: ANN201
|
||||
# If the import is already there, don't add it again
|
||||
if self.has_import:
|
||||
return updated_node
|
||||
|
|
@ -347,7 +347,7 @@ class ImportAdder(cst.CSTTransformer):
|
|||
# Add the import to the module's body
|
||||
return updated_node.with_changes(body=[import_node, *list(updated_node.body)])
|
||||
|
||||
def visit_ImportFrom(self, node) -> None: # noqa: ANN001
|
||||
def visit_ImportFrom(self, node) -> None:
|
||||
# Check if the profile is already imported from line_profiler
|
||||
if node.module and node.module.value == "line_profiler":
|
||||
for import_alias in node.names:
|
||||
|
|
|
|||
106
codeflash/code_utils/normalizers/__init__.py
Normal file
106
codeflash/code_utils/normalizers/__init__.py
Normal file
|
|
@ -0,0 +1,106 @@
|
|||
"""Code normalizers for different programming languages.
|
||||
|
||||
This module provides language-specific code normalizers that transform source code
|
||||
into canonical forms for duplicate detection. The normalizers:
|
||||
- Replace local variable names with canonical forms (var_0, var_1, etc.)
|
||||
- Preserve function names, class names, parameters, and imports
|
||||
- Remove or normalize comments and docstrings
|
||||
- Produce consistent output for structurally identical code
|
||||
|
||||
Usage:
|
||||
>>> normalizer = get_normalizer("python")
|
||||
>>> normalized = normalizer.normalize(code)
|
||||
>>> fingerprint = normalizer.get_fingerprint(code)
|
||||
>>> are_same = normalizer.are_duplicates(code1, code2)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from codeflash.code_utils.normalizers.base import CodeNormalizer
|
||||
from codeflash.code_utils.normalizers.javascript import JavaScriptNormalizer, TypeScriptNormalizer
|
||||
from codeflash.code_utils.normalizers.python import PythonNormalizer
|
||||
|
||||
__all__ = [
|
||||
"CodeNormalizer",
|
||||
"JavaScriptNormalizer",
|
||||
"PythonNormalizer",
|
||||
"TypeScriptNormalizer",
|
||||
"get_normalizer",
|
||||
"get_normalizer_for_extension",
|
||||
]
|
||||
|
||||
# Registry of normalizers by language
|
||||
_NORMALIZERS: dict[str, type[CodeNormalizer]] = {
|
||||
"python": PythonNormalizer,
|
||||
"javascript": JavaScriptNormalizer,
|
||||
"typescript": TypeScriptNormalizer,
|
||||
}
|
||||
|
||||
# Singleton cache for normalizer instances
|
||||
_normalizer_instances: dict[str, CodeNormalizer] = {}
|
||||
|
||||
|
||||
def get_normalizer(language: str) -> CodeNormalizer:
|
||||
"""Get a code normalizer for the specified language.
|
||||
|
||||
Args:
|
||||
language: Language name ('python', 'javascript', 'typescript')
|
||||
|
||||
Returns:
|
||||
CodeNormalizer instance for the language
|
||||
|
||||
Raises:
|
||||
ValueError: If no normalizer exists for the language
|
||||
|
||||
"""
|
||||
language = language.lower()
|
||||
|
||||
# Check cache first
|
||||
if language in _normalizer_instances:
|
||||
return _normalizer_instances[language]
|
||||
|
||||
# Get normalizer class
|
||||
if language not in _NORMALIZERS:
|
||||
supported = ", ".join(sorted(_NORMALIZERS.keys()))
|
||||
msg = f"No normalizer available for language '{language}'. Supported: {supported}"
|
||||
raise ValueError(msg)
|
||||
|
||||
# Create and cache instance
|
||||
normalizer = _NORMALIZERS[language]()
|
||||
_normalizer_instances[language] = normalizer
|
||||
return normalizer
|
||||
|
||||
|
||||
def get_normalizer_for_extension(extension: str) -> CodeNormalizer | None:
|
||||
"""Get a code normalizer based on file extension.
|
||||
|
||||
Args:
|
||||
extension: File extension including dot (e.g., '.py', '.js')
|
||||
|
||||
Returns:
|
||||
CodeNormalizer instance if found, None otherwise
|
||||
|
||||
"""
|
||||
extension = extension.lower()
|
||||
if not extension.startswith("."):
|
||||
extension = f".{extension}"
|
||||
|
||||
for language in _NORMALIZERS:
|
||||
normalizer = get_normalizer(language)
|
||||
if extension in normalizer.supported_extensions:
|
||||
return normalizer
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def register_normalizer(language: str, normalizer_class: type[CodeNormalizer]) -> None:
|
||||
"""Register a new normalizer for a language.
|
||||
|
||||
Args:
|
||||
language: Language name
|
||||
normalizer_class: CodeNormalizer subclass
|
||||
|
||||
"""
|
||||
_NORMALIZERS[language.lower()] = normalizer_class
|
||||
# Clear cached instance if it exists
|
||||
_normalizer_instances.pop(language.lower(), None)
|
||||
104
codeflash/code_utils/normalizers/base.py
Normal file
104
codeflash/code_utils/normalizers/base.py
Normal file
|
|
@ -0,0 +1,104 @@
|
|||
"""Abstract base class for code normalizers.
|
||||
|
||||
Code normalizers transform source code into a canonical form for duplicate detection.
|
||||
They normalize variable names, remove comments/docstrings, and produce consistent output
|
||||
that can be compared across different implementations of the same algorithm.
|
||||
"""
|
||||
|
||||
# TODO:{claude} move to base.py in language folder
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
|
||||
class CodeNormalizer(ABC):
    """Abstract base class for language-specific code normalizers.

    Subclasses must implement the normalize() method for their specific language.
    The normalization should:
    - Normalize local variable names to canonical forms (var_0, var_1, etc.)
    - Preserve function names, class names, parameters, and imports
    - Remove or normalize comments and docstrings
    - Produce consistent output for structurally identical code

    Example:
        >>> normalizer = PythonNormalizer()
        >>> code1 = "def foo(x): y = x + 1; return y"
        >>> code2 = "def foo(x): z = x + 1; return z"
        >>> normalizer.normalize(code1) == normalizer.normalize(code2)
        True

    """

    @property
    @abstractmethod
    def language(self) -> str:
        """Return the language this normalizer handles."""
        ...

    @property
    def supported_extensions(self) -> tuple[str, ...]:
        """Return file extensions this normalizer can handle (empty by default)."""
        return ()

    @abstractmethod
    def normalize(self, code: str) -> str:
        """Normalize code to a canonical form for comparison.

        Args:
            code: Source code to normalize

        Returns:
            Normalized representation of the code

        """
        ...

    @abstractmethod
    def normalize_for_hash(self, code: str) -> str:
        """Normalize code optimized for hashing/fingerprinting.

        This may return a more compact representation than normalize().

        Args:
            code: Source code to normalize

        Returns:
            Normalized representation suitable for hashing

        """
        ...

    def are_duplicates(self, code1: str, code2: str) -> bool:
        """Check if two code segments are duplicates after normalization.

        Args:
            code1: First code segment
            code2: Second code segment

        Returns:
            True if codes are structurally identical; False if either
            segment fails to normalize (e.g. a syntax error).

        """
        try:
            first = self.normalize_for_hash(code1)
            second = self.normalize_for_hash(code2)
        except Exception:
            # Unparseable input can never be confirmed as a duplicate.
            return False
        return first == second

    def get_fingerprint(self, code: str) -> str:
        """Generate a fingerprint hash for normalized code.

        Args:
            code: Source code to fingerprint

        Returns:
            SHA-256 hex digest of the normalized code

        """
        # Imported lazily to keep the module import light.
        import hashlib

        return hashlib.sha256(self.normalize_for_hash(code).encode()).hexdigest()
|
||||
290
codeflash/code_utils/normalizers/javascript.py
Normal file
290
codeflash/code_utils/normalizers/javascript.py
Normal file
|
|
@ -0,0 +1,290 @@
|
|||
"""JavaScript/TypeScript code normalizer using tree-sitter."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from codeflash.code_utils.normalizers.base import CodeNormalizer
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from tree_sitter import Node
|
||||
|
||||
|
||||
# TODO:{claude} move to language support directory to keep the directory structure clean
|
||||
class JavaScriptVariableNormalizer:
    """Normalizes JavaScript/TypeScript code for duplicate detection using tree-sitter.

    Normalizes local variable names while preserving function names, class names,
    parameters, and imported names.

    Intended two-pass usage (see JavaScriptNormalizer.normalize):
    first call collect_preserved_names() on the root node, then
    normalize_tree() to build the canonical string.
    """

    def __init__(self) -> None:
        self.var_counter = 0  # next canonical index: var_0, var_1, ...
        self.var_mapping: dict[str, str] = {}  # original name -> canonical name
        self.preserved_names: set[str] = set()  # names never renamed (functions, classes, params, imports)
        # Common JavaScript builtins
        self.builtins = {
            "console",
            "window",
            "document",
            "Math",
            "JSON",
            "Object",
            "Array",
            "String",
            "Number",
            "Boolean",
            "Date",
            "RegExp",
            "Error",
            "Promise",
            "Map",
            "Set",
            "WeakMap",
            "WeakSet",
            "Symbol",
            "Proxy",
            "Reflect",
            "undefined",
            "null",
            "NaN",
            "Infinity",
            "globalThis",
            "parseInt",
            "parseFloat",
            "isNaN",
            "isFinite",
            "eval",
            "setTimeout",
            "setInterval",
            "clearTimeout",
            "clearInterval",
            "fetch",
            "require",
            "module",
            "exports",
            "process",
            "__dirname",
            "__filename",
            "Buffer",
        }

    def get_normalized_name(self, name: str) -> str:
        """Get or create normalized name for a variable.

        Builtins and preserved names pass through unchanged; every other
        identifier is mapped (stably, per instance) to var_<n>.
        """
        if name in self.builtins or name in self.preserved_names:
            return name
        if name not in self.var_mapping:
            self.var_mapping[name] = f"var_{self.var_counter}"
            self.var_counter += 1
        return self.var_mapping[name]

    def collect_preserved_names(self, node: Node, source_code: bytes) -> None:
        """Collect names that should be preserved (function names, class names, imports, params)."""
        # Function declarations and expressions - preserve the function name
        # NOTE(review): arrow_function nodes have no "name" field, so only the
        # parameter branch applies to them -- confirm against the grammar.
        if node.type in ("function_declaration", "function_expression", "method_definition", "arrow_function"):
            name_node = node.child_by_field_name("name")
            if name_node:
                self.preserved_names.add(source_code[name_node.start_byte : name_node.end_byte].decode("utf-8"))
            # Preserve parameters
            params_node = node.child_by_field_name("parameters") or node.child_by_field_name("parameter")
            if params_node:
                self._collect_parameter_names(params_node, source_code)

        # Class declarations
        elif node.type == "class_declaration":
            name_node = node.child_by_field_name("name")
            if name_node:
                self.preserved_names.add(source_code[name_node.start_byte : name_node.end_byte].decode("utf-8"))

        # Import declarations
        elif node.type in ("import_statement", "import_declaration"):
            for child in node.children:
                if child.type == "import_clause":
                    self._collect_import_names(child, source_code)
                elif child.type == "identifier":
                    self.preserved_names.add(source_code[child.start_byte : child.end_byte].decode("utf-8"))

        # Recurse (unconditionally, so nested declarations are covered too)
        for child in node.children:
            self.collect_preserved_names(child, source_code)

    def _collect_parameter_names(self, node: Node, source_code: bytes) -> None:
        """Collect parameter names from a parameters node."""
        for child in node.children:
            if child.type == "identifier":
                self.preserved_names.add(source_code[child.start_byte : child.end_byte].decode("utf-8"))
            elif child.type in ("required_parameter", "optional_parameter", "rest_parameter"):
                # TypeScript-style parameter wrappers carry the identifier in
                # their "pattern" field.
                pattern_node = child.child_by_field_name("pattern")
                if pattern_node and pattern_node.type == "identifier":
                    self.preserved_names.add(
                        source_code[pattern_node.start_byte : pattern_node.end_byte].decode("utf-8")
                    )
            # Recurse for nested patterns
            self._collect_parameter_names(child, source_code)

    def _collect_import_names(self, node: Node, source_code: bytes) -> None:
        """Collect imported names from import clause."""
        for child in node.children:
            if child.type == "identifier":
                self.preserved_names.add(source_code[child.start_byte : child.end_byte].decode("utf-8"))
            elif child.type == "import_specifier":
                # Get the local name (alias or original)
                alias_node = child.child_by_field_name("alias")
                name_node = child.child_by_field_name("name")
                if alias_node:
                    self.preserved_names.add(source_code[alias_node.start_byte : alias_node.end_byte].decode("utf-8"))
                elif name_node:
                    self.preserved_names.add(source_code[name_node.start_byte : name_node.end_byte].decode("utf-8"))
            self._collect_import_names(child, source_code)

    def normalize_tree(self, node: Node, source_code: bytes) -> str:
        """Normalize the AST tree to a string representation for comparison."""
        parts: list[str] = []
        self._normalize_node(node, source_code, parts)
        return " ".join(parts)

    def _normalize_node(self, node: Node, source_code: bytes, parts: list[str]) -> None:
        """Recursively normalize a node, appending tokens to ``parts``."""
        # Skip comments
        if node.type in ("comment", "line_comment", "block_comment"):
            return

        # Handle identifiers - normalize variable names
        if node.type == "identifier":
            name = source_code[node.start_byte : node.end_byte].decode("utf-8")
            normalized = self.get_normalized_name(name)
            parts.append(normalized)
            return

        # Handle type identifiers (TypeScript) - preserve as-is
        if node.type == "type_identifier":
            parts.append(source_code[node.start_byte : node.end_byte].decode("utf-8"))
            return

        # Handle string literals - normalize to placeholder so differing
        # literal text does not defeat duplicate detection
        if node.type in ("string", "template_string", "string_fragment"):
            parts.append('"STR"')
            return

        # Handle number literals - normalize to placeholder
        if node.type == "number":
            parts.append("NUM")
            return

        # For leaf nodes, output the node text verbatim (operators, keywords)
        if len(node.children) == 0:
            text = source_code[node.start_byte : node.end_byte].decode("utf-8")
            parts.append(text)
            return

        # Output node type for structure (parenthesized S-expression form)
        parts.append(f"({node.type}")

        # Recurse into children
        for child in node.children:
            self._normalize_node(child, source_code, parts)

        parts.append(")")
|
||||
|
||||
|
||||
def _basic_normalize(code: str) -> str:
|
||||
"""Basic normalization: remove comments and normalize whitespace."""
|
||||
# Remove single-line comments
|
||||
code = re.sub(r"//.*$", "", code, flags=re.MULTILINE)
|
||||
# Remove multi-line comments
|
||||
code = re.sub(r"/\*.*?\*/", "", code, flags=re.DOTALL)
|
||||
# Normalize whitespace
|
||||
return " ".join(code.split())
|
||||
|
||||
|
||||
class JavaScriptNormalizer(CodeNormalizer):
    """JavaScript code normalizer using tree-sitter.

    Normalizes JavaScript code by:
    - Replacing local variable names with canonical forms (var_0, var_1, etc.)
    - Preserving function names, class names, parameters, and imports
    - Removing comments
    - Normalizing string and number literals

    Falls back to regex-based _basic_normalize() whenever tree-sitter is
    unavailable or the parse produces errors.
    """

    @property
    def language(self) -> str:
        """Return the language this normalizer handles."""
        return "javascript"

    @property
    def supported_extensions(self) -> tuple[str, ...]:
        """Return file extensions this normalizer can handle."""
        return (".js", ".jsx", ".mjs", ".cjs")

    def _get_tree_sitter_language(self) -> str:
        """Get the tree-sitter language identifier (overridden by TypeScriptNormalizer)."""
        return "javascript"

    def normalize(self, code: str) -> str:
        """Normalize JavaScript code to a canonical form.

        Args:
            code: JavaScript source code to normalize

        Returns:
            Normalized representation of the code; falls back to a basic
            comment-stripped/whitespace-collapsed form on any failure.

        """
        try:
            # Imported lazily so the normalizer module does not require the
            # tree-sitter stack unless actually used.
            from codeflash.languages.treesitter_utils import TreeSitterAnalyzer, TreeSitterLanguage

            lang_map = {"javascript": TreeSitterLanguage.JAVASCRIPT, "typescript": TreeSitterLanguage.TYPESCRIPT}
            lang = lang_map.get(self._get_tree_sitter_language(), TreeSitterLanguage.JAVASCRIPT)
            analyzer = TreeSitterAnalyzer(lang)
            tree = analyzer.parse(code)

            # A tree with parse errors would produce an unstable representation.
            if tree.root_node.has_error:
                return _basic_normalize(code)

            normalizer = JavaScriptVariableNormalizer()
            source_bytes = code.encode("utf-8")

            # First pass: collect preserved names
            normalizer.collect_preserved_names(tree.root_node, source_bytes)

            # Second pass: normalize and build representation
            return normalizer.normalize_tree(tree.root_node, source_bytes)
        except Exception:
            # Deliberate best-effort: any import/parse failure degrades to the
            # regex-based fallback instead of propagating.
            return _basic_normalize(code)

    def normalize_for_hash(self, code: str) -> str:
        """Normalize JavaScript code optimized for hashing.

        For JavaScript, this is the same as normalize().

        Args:
            code: JavaScript source code to normalize

        Returns:
            Normalized representation suitable for hashing

        """
        return self.normalize(code)
|
||||
|
||||
|
||||
class TypeScriptNormalizer(JavaScriptNormalizer):
    """TypeScript code normalizer using tree-sitter.

    Inherits from JavaScriptNormalizer and overrides language-specific settings:
    the language name, the recognized file extensions, and the tree-sitter
    grammar identifier. All normalization logic lives in the parent class.
    """

    @property
    def language(self) -> str:
        """Return the language this normalizer handles."""
        return "typescript"

    @property
    def supported_extensions(self) -> tuple[str, ...]:
        """Return file extensions this normalizer can handle."""
        return (".ts", ".tsx", ".mts", ".cts")

    def _get_tree_sitter_language(self) -> str:
        """Get the tree-sitter language identifier."""
        return "typescript"
|
||||
226
codeflash/code_utils/normalizers/python.py
Normal file
226
codeflash/code_utils/normalizers/python.py
Normal file
|
|
@ -0,0 +1,226 @@
|
|||
"""Python code normalizer using AST transformation."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import ast
|
||||
|
||||
from codeflash.code_utils.normalizers.base import CodeNormalizer
|
||||
|
||||
|
||||
class VariableNormalizer(ast.NodeTransformer):
    """Normalizes only local variable names in AST to canonical forms like var_0, var_1, etc.

    Preserves function names, class names, parameters, built-ins, and imported names.
    Scopes are tracked with an explicit stack so renames inside nested
    functions/classes do not leak into the enclosing scope.
    """

    def __init__(self) -> None:
        # Bug fix: the implicit ``__builtins__`` global is an implementation
        # detail -- inside an imported module it is a dict, so dir(__builtins__)
        # yields dict method names ('keys', 'items', ...) instead of builtin
        # names like 'len' and 'print'. The ``builtins`` module is always right.
        import builtins

        self.var_counter = 0  # next canonical index (var_0, var_1, ...)
        self.var_mapping: dict[str, str] = {}  # original name -> canonical name
        self.scope_stack: list[dict] = []  # saved (mapping, counter, params) per scope
        self.builtins = set(dir(builtins))
        self.imports: set[str] = set()  # names bound by import statements
        self.global_vars: set[str] = set()  # names declared ``global``
        self.nonlocal_vars: set[str] = set()  # names declared ``nonlocal``
        self.parameters: set[str] = set()  # parameter names of the current function

    def enter_scope(self) -> None:
        """Enter a new scope (function/class), snapshotting the rename state."""
        self.scope_stack.append(
            {"var_mapping": dict(self.var_mapping), "var_counter": self.var_counter, "parameters": set(self.parameters)}
        )

    def exit_scope(self) -> None:
        """Exit current scope and restore parent scope."""
        if self.scope_stack:
            scope = self.scope_stack.pop()
            self.var_mapping = scope["var_mapping"]
            self.var_counter = scope["var_counter"]
            self.parameters = scope["parameters"]

    def get_normalized_name(self, name: str) -> str:
        """Get or create normalized name for a variable; preserved names pass through."""
        if (
            name in self.builtins
            or name in self.imports
            or name in self.global_vars
            or name in self.nonlocal_vars
            or name in self.parameters
        ):
            return name

        if name not in self.var_mapping:
            self.var_mapping[name] = f"var_{self.var_counter}"
            self.var_counter += 1
        return self.var_mapping[name]

    def visit_Import(self, node: ast.Import) -> ast.Import:
        """Track imported names (the ``as`` alias, or the top-level package name)."""
        for alias in node.names:
            name = alias.asname if alias.asname else alias.name
            self.imports.add(name.split(".")[0])
        return node

    def visit_ImportFrom(self, node: ast.ImportFrom) -> ast.ImportFrom:
        """Track imported names from modules."""
        for alias in node.names:
            name = alias.asname if alias.asname else alias.name
            self.imports.add(name)
        return node

    def visit_Global(self, node: ast.Global) -> ast.Global:
        """Track global variable declarations so they are never renamed."""
        self.global_vars.update(node.names)
        return node

    def visit_Nonlocal(self, node: ast.Nonlocal) -> ast.Nonlocal:
        """Track nonlocal variable declarations so they are never renamed."""
        self.nonlocal_vars.update(node.names)
        return node

    def visit_FunctionDef(self, node: ast.FunctionDef) -> ast.FunctionDef:
        """Process function but keep function name and all parameter names unchanged."""
        self.enter_scope()

        # Register every flavor of parameter. Positional-only parameters
        # (``def f(a, /)``) were previously missed: a reassignment of such a
        # parameter in the body would be renamed while the signature kept the
        # old name, producing broken normalized code.
        for arg in node.args.posonlyargs:
            self.parameters.add(arg.arg)
        for arg in node.args.args:
            self.parameters.add(arg.arg)
        if node.args.vararg:
            self.parameters.add(node.args.vararg.arg)
        if node.args.kwarg:
            self.parameters.add(node.args.kwarg.arg)
        for arg in node.args.kwonlyargs:
            self.parameters.add(arg.arg)

        node = self.generic_visit(node)
        self.exit_scope()
        return node

    def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> ast.AsyncFunctionDef:
        """Handle async functions same as regular functions."""
        return self.visit_FunctionDef(node)  # type: ignore[return-value]

    def visit_ClassDef(self, node: ast.ClassDef) -> ast.ClassDef:
        """Process class but keep class name unchanged."""
        self.enter_scope()
        node = self.generic_visit(node)
        self.exit_scope()
        return node

    def visit_Name(self, node: ast.Name) -> ast.Name:
        """Normalize variable names in Name nodes."""
        if isinstance(node.ctx, (ast.Store, ast.Del)):
            if (
                node.id not in self.builtins
                and node.id not in self.imports
                and node.id not in self.parameters
                and node.id not in self.global_vars
                and node.id not in self.nonlocal_vars
            ):
                node.id = self.get_normalized_name(node.id)
        elif isinstance(node.ctx, ast.Load) and node.id in self.var_mapping:
            # Loads are only renamed once a Store has mapped the name.
            node.id = self.var_mapping[node.id]
        return node

    def visit_ExceptHandler(self, node: ast.ExceptHandler) -> ast.ExceptHandler:
        """Normalize exception variable names (``except E as name``)."""
        if node.name:
            node.name = self.get_normalized_name(node.name)
        return self.generic_visit(node)

    def visit_comprehension(self, node: ast.comprehension) -> ast.comprehension:
        """Normalize comprehension target variables without leaking the mapping outward."""
        old_mapping = dict(self.var_mapping)
        old_counter = self.var_counter

        node = self.generic_visit(node)

        # Comprehension targets live in their own scope; restore the outer state.
        self.var_mapping = old_mapping
        self.var_counter = old_counter
        return node

    def visit_For(self, node: ast.For) -> ast.For:
        """Handle for loop target variables (default traversal is sufficient)."""
        return self.generic_visit(node)

    def visit_With(self, node: ast.With) -> ast.With:
        """Handle with statement ``as`` variables (default traversal is sufficient)."""
        return self.generic_visit(node)
|
||||
|
||||
|
||||
def _remove_docstrings_from_ast(node: ast.AST) -> None:
|
||||
"""Remove docstrings from AST nodes."""
|
||||
node_types = (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef, ast.Module)
|
||||
stack = [node]
|
||||
while stack:
|
||||
current_node = stack.pop()
|
||||
if isinstance(current_node, node_types):
|
||||
body = current_node.body
|
||||
if (
|
||||
body
|
||||
and isinstance(body[0], ast.Expr)
|
||||
and isinstance(body[0].value, ast.Constant)
|
||||
and isinstance(body[0].value.value, str)
|
||||
):
|
||||
current_node.body = body[1:]
|
||||
stack.extend([child for child in body if isinstance(child, node_types)])
|
||||
|
||||
|
||||
class PythonNormalizer(CodeNormalizer):
    """Python code normalizer using AST transformation.

    Normalizes Python code by:
    - Replacing local variable names with canonical forms (var_0, var_1, etc.)
    - Preserving function names, class names, parameters, and imports
    - Optionally removing docstrings
    """

    @property
    def language(self) -> str:
        """Return the language this normalizer handles."""
        return "python"

    @property
    def supported_extensions(self) -> tuple[str, ...]:
        """Return file extensions this normalizer can handle."""
        return (".py", ".pyw", ".pyi")

    def normalize(self, code: str, remove_docstrings: bool = True) -> str:
        """Normalize Python code to a canonical form.

        Args:
            code: Python source code to normalize
            remove_docstrings: Whether to remove docstrings

        Returns:
            Normalized Python code as a string

        """
        tree = ast.parse(code)
        if remove_docstrings:
            _remove_docstrings_from_ast(tree)

        transformed = VariableNormalizer().visit(tree)
        # Renamed nodes need location info before unparsing.
        ast.fix_missing_locations(transformed)
        return ast.unparse(transformed)

    def normalize_for_hash(self, code: str) -> str:
        """Normalize Python code optimized for hashing.

        Returns an AST dump, which is faster than unparsing and equally
        stable for fingerprinting. Docstrings are always removed here.

        Args:
            code: Python source code to normalize

        Returns:
            AST dump string suitable for hashing

        """
        tree = ast.parse(code)
        _remove_docstrings_from_ast(tree)
        transformed = VariableNormalizer().visit(tree)
        # Field names and source locations are noise for hashing purposes.
        return ast.dump(transformed, annotate_fields=False, include_attributes=False)
|
||||
|
|
@ -702,7 +702,7 @@ def _wait_for_manual_code_input(oauth: OAuthHandler) -> None:
|
|||
if not oauth.is_complete:
|
||||
oauth.manual_code = code.strip()
|
||||
oauth.is_complete = True
|
||||
except Exception: # noqa: S110
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -242,9 +242,9 @@ def get_cross_platform_subprocess_run_args(
|
|||
cwd: Path | str | None = None,
|
||||
env: Mapping[str, str] | None = None,
|
||||
timeout: Optional[float] = None,
|
||||
check: bool = False, # noqa: FBT001, FBT002
|
||||
text: bool = True, # noqa: FBT001, FBT002
|
||||
capture_output: bool = True, # noqa: FBT001, FBT002 (only for non-Windows)
|
||||
check: bool = False,
|
||||
text: bool = True,
|
||||
capture_output: bool = True,
|
||||
) -> dict[str, str]:
|
||||
run_args = {"cwd": cwd, "env": env, "text": text, "timeout": timeout, "check": check}
|
||||
if sys.platform == "win32":
|
||||
|
|
|
|||
|
|
@ -649,7 +649,7 @@ def tabulate(
|
|||
headersalign=None,
|
||||
rowalign=None,
|
||||
maxheadercolwidths=None,
|
||||
):
|
||||
) -> str:
|
||||
if tabular_data is None:
|
||||
tabular_data = []
|
||||
|
||||
|
|
|
|||
|
|
@ -39,7 +39,7 @@ def get_latest_version_from_pypi() -> str | None:
|
|||
|
||||
return latest_version
|
||||
logger.debug(f"Failed to fetch version from PyPI: {response.status_code}")
|
||||
return None # noqa: TRY300
|
||||
return None
|
||||
except requests.RequestException as e:
|
||||
logger.debug(f"Network error fetching version from PyPI: {e}")
|
||||
return None
|
||||
|
|
|
|||
|
|
@ -21,6 +21,10 @@ from codeflash.context.unused_definition_remover import (
|
|||
remove_unused_definitions_by_function_names,
|
||||
)
|
||||
from codeflash.discovery.functions_to_optimize import FunctionToOptimize # noqa: TC001
|
||||
|
||||
# Language support imports for multi-language code context extraction
|
||||
from codeflash.languages import is_python
|
||||
from codeflash.languages.base import Language
|
||||
from codeflash.models.models import (
|
||||
CodeContextType,
|
||||
CodeOptimizationContext,
|
||||
|
|
@ -35,6 +39,7 @@ if TYPE_CHECKING:
|
|||
from libcst import CSTNode
|
||||
|
||||
from codeflash.context.unused_definition_remover import UsageInfo
|
||||
from codeflash.languages.base import HelperFunction
|
||||
|
||||
|
||||
def build_testgen_context(
|
||||
|
|
@ -75,6 +80,12 @@ def get_code_optimization_context(
|
|||
optim_token_limit: int = OPTIMIZATION_CONTEXT_TOKEN_LIMIT,
|
||||
testgen_token_limit: int = TESTGEN_CONTEXT_TOKEN_LIMIT,
|
||||
) -> CodeOptimizationContext:
|
||||
# Route to language-specific implementation for non-Python languages
|
||||
if not is_python():
|
||||
return get_code_optimization_context_for_language(
|
||||
function_to_optimize, project_root_path, optim_token_limit, testgen_token_limit
|
||||
)
|
||||
|
||||
# Get FunctionSource representation of helpers of FTO
|
||||
helpers_of_fto_dict, helpers_of_fto_list = get_function_sources_from_jedi(
|
||||
{function_to_optimize.file_path: {function_to_optimize.qualified_name}}, project_root_path
|
||||
|
|
@ -95,7 +106,7 @@ def get_code_optimization_context(
|
|||
qualified_names.update({f"{qn.rsplit('.', 1)[0]}.__init__" for qn in qualified_names if "." in qn})
|
||||
|
||||
# Get FunctionSource representation of helpers of helpers of FTO
|
||||
helpers_of_helpers_dict, helpers_of_helpers_list = get_function_sources_from_jedi(
|
||||
helpers_of_helpers_dict, _helpers_of_helpers_list = get_function_sources_from_jedi(
|
||||
helpers_of_fto_qualified_names_dict, project_root_path
|
||||
)
|
||||
|
||||
|
|
@ -198,11 +209,161 @@ def get_code_optimization_context(
|
|||
)
|
||||
|
||||
|
||||
def get_code_optimization_context_for_language(
|
||||
function_to_optimize: FunctionToOptimize,
|
||||
project_root_path: Path,
|
||||
optim_token_limit: int = OPTIMIZATION_CONTEXT_TOKEN_LIMIT,
|
||||
testgen_token_limit: int = TESTGEN_CONTEXT_TOKEN_LIMIT,
|
||||
) -> CodeOptimizationContext:
|
||||
"""Extract code optimization context for non-Python languages.
|
||||
|
||||
Uses the language support abstraction to extract code context and converts
|
||||
it to the CodeOptimizationContext format expected by the pipeline.
|
||||
|
||||
This function supports multi-file context extraction, grouping helpers by file
|
||||
and creating proper CodeStringsMarkdown with file paths for multi-file replacement.
|
||||
|
||||
Args:
|
||||
function_to_optimize: The function to extract context for.
|
||||
project_root_path: Root of the project.
|
||||
optim_token_limit: Token limit for optimization context.
|
||||
testgen_token_limit: Token limit for testgen context.
|
||||
|
||||
Returns:
|
||||
CodeOptimizationContext with target code and dependencies.
|
||||
|
||||
"""
|
||||
from codeflash.languages import get_language_support
|
||||
from codeflash.languages.base import FunctionInfo, ParentInfo
|
||||
|
||||
# Get language support for this function
|
||||
language = Language(function_to_optimize.language)
|
||||
lang_support = get_language_support(language)
|
||||
|
||||
# Convert FunctionToOptimize to FunctionInfo for language support
|
||||
parents = tuple(ParentInfo(name=p.name, type=p.type) for p in function_to_optimize.parents)
|
||||
func_info = FunctionInfo(
|
||||
name=function_to_optimize.function_name,
|
||||
file_path=function_to_optimize.file_path,
|
||||
start_line=function_to_optimize.starting_line or 1,
|
||||
end_line=function_to_optimize.ending_line or 1,
|
||||
parents=parents,
|
||||
is_async=function_to_optimize.is_async,
|
||||
is_method=len(function_to_optimize.parents) > 0,
|
||||
language=language,
|
||||
)
|
||||
|
||||
# Extract code context using language support
|
||||
code_context = lang_support.extract_code_context(func_info, project_root_path, project_root_path)
|
||||
|
||||
# Build imports string if available
|
||||
imports_code = "\n".join(code_context.imports) if code_context.imports else ""
|
||||
|
||||
# Get relative path for target file
|
||||
try:
|
||||
target_relative_path = function_to_optimize.file_path.resolve().relative_to(project_root_path.resolve())
|
||||
except ValueError:
|
||||
target_relative_path = function_to_optimize.file_path
|
||||
|
||||
# Group helpers by file path
|
||||
helpers_by_file: dict[Path, list[HelperFunction]] = defaultdict(list)
|
||||
helper_function_sources = []
|
||||
|
||||
for helper in code_context.helper_functions:
|
||||
helpers_by_file[helper.file_path].append(helper)
|
||||
|
||||
# Convert to FunctionSource for pipeline compatibility
|
||||
helper_function_sources.append(
|
||||
FunctionSource(
|
||||
file_path=helper.file_path,
|
||||
qualified_name=helper.qualified_name,
|
||||
fully_qualified_name=helper.qualified_name,
|
||||
only_function_name=helper.name,
|
||||
source_code=helper.source_code,
|
||||
jedi_definition=None,
|
||||
)
|
||||
)
|
||||
|
||||
# Build read-writable code (target file + same-file helpers + global variables)
|
||||
read_writable_code_strings = []
|
||||
|
||||
# Combine target code with same-file helpers
|
||||
target_file_code = code_context.target_code
|
||||
same_file_helpers = helpers_by_file.get(function_to_optimize.file_path, [])
|
||||
if same_file_helpers:
|
||||
helper_code = "\n\n".join(h.source_code for h in same_file_helpers)
|
||||
target_file_code = target_file_code + "\n\n" + helper_code
|
||||
|
||||
# Add global variables (module-level declarations) referenced by the function and helpers
|
||||
# These should be included in read-writable context so AI can modify them if needed
|
||||
if code_context.read_only_context:
|
||||
target_file_code = code_context.read_only_context + "\n\n" + target_file_code
|
||||
|
||||
# Add imports to target file code
|
||||
if imports_code:
|
||||
target_file_code = imports_code + "\n\n" + target_file_code
|
||||
|
||||
read_writable_code_strings.append(
|
||||
CodeString(code=target_file_code, file_path=target_relative_path, language=function_to_optimize.language)
|
||||
)
|
||||
|
||||
# Add helper files (cross-file helpers)
|
||||
for file_path, file_helpers in helpers_by_file.items():
|
||||
if file_path == function_to_optimize.file_path:
|
||||
continue # Already included in target file
|
||||
|
||||
try:
|
||||
helper_relative_path = file_path.resolve().relative_to(project_root_path.resolve())
|
||||
except ValueError:
|
||||
helper_relative_path = file_path
|
||||
|
||||
# Combine all helpers from this file
|
||||
combined_helper_code = "\n\n".join(h.source_code for h in file_helpers)
|
||||
|
||||
read_writable_code_strings.append(
|
||||
CodeString(
|
||||
code=combined_helper_code, file_path=helper_relative_path, language=function_to_optimize.language
|
||||
)
|
||||
)
|
||||
|
||||
read_writable_code = CodeStringsMarkdown(
|
||||
code_strings=read_writable_code_strings, language=function_to_optimize.language
|
||||
)
|
||||
|
||||
# Build testgen context (same as read_writable for non-Python)
|
||||
testgen_context = CodeStringsMarkdown(
|
||||
code_strings=read_writable_code_strings.copy(), language=function_to_optimize.language
|
||||
)
|
||||
|
||||
# Check token limits
|
||||
read_writable_tokens = encoded_tokens_len(read_writable_code.markdown)
|
||||
if read_writable_tokens > optim_token_limit:
|
||||
raise ValueError("Read-writable code has exceeded token limit, cannot proceed")
|
||||
|
||||
testgen_tokens = encoded_tokens_len(testgen_context.markdown)
|
||||
if testgen_tokens > testgen_token_limit:
|
||||
raise ValueError("Testgen code context has exceeded token limit, cannot proceed")
|
||||
|
||||
# Generate code hash from all read-writable code
|
||||
code_hash = hashlib.sha256(read_writable_code.flat.encode("utf-8")).hexdigest()
|
||||
|
||||
return CodeOptimizationContext(
|
||||
testgen_context=testgen_context,
|
||||
read_writable_code=read_writable_code,
|
||||
# Global variables are now included in read-writable code, so don't duplicate in read-only
|
||||
read_only_context_code="",
|
||||
hashing_code_context=read_writable_code.flat,
|
||||
hashing_code_context_hash=code_hash,
|
||||
helper_functions=helper_function_sources,
|
||||
preexisting_objects=set(), # Not implemented for non-Python yet
|
||||
)
|
||||
|
||||
|
||||
def extract_code_markdown_context_from_files(
|
||||
helpers_of_fto: dict[Path, set[FunctionSource]],
|
||||
helpers_of_helpers: dict[Path, set[FunctionSource]],
|
||||
project_root_path: Path,
|
||||
remove_docstrings: bool = False, # noqa: FBT001, FBT002
|
||||
remove_docstrings: bool = False,
|
||||
code_context_type: CodeContextType = CodeContextType.READ_ONLY,
|
||||
) -> CodeStringsMarkdown:
|
||||
"""Extract code context from files containing target functions and their helpers, formatting them as markdown.
|
||||
|
|
@ -833,7 +994,7 @@ def get_imported_names(import_node: cst.Import | cst.ImportFrom) -> set[str]:
|
|||
|
||||
|
||||
def remove_docstring_from_body(indented_block: cst.IndentedBlock) -> cst.CSTNode:
|
||||
"""Removes the docstring from an indented block if it exists.""" # noqa: D401
|
||||
"""Removes the docstring from an indented block if it exists."""
|
||||
if not isinstance(indented_block.body[0], cst.SimpleStatementLine):
|
||||
return indented_block
|
||||
first_stmt = indented_block.body[0].body[0]
|
||||
|
|
@ -847,7 +1008,7 @@ def parse_code_and_prune_cst(
|
|||
code_context_type: CodeContextType,
|
||||
target_functions: set[str],
|
||||
helpers_of_helper_functions: set[str] = set(), # noqa: B006
|
||||
remove_docstrings: bool = False, # noqa: FBT001, FBT002
|
||||
remove_docstrings: bool = False,
|
||||
) -> str:
|
||||
"""Create a read-only version of the code by parsing and filtering the code to keep only class contextual information, and other module scoped variables."""
|
||||
module = cst.parse_module(code)
|
||||
|
|
@ -888,7 +1049,7 @@ def parse_code_and_prune_cst(
|
|||
return ""
|
||||
|
||||
|
||||
def prune_cst_for_read_writable_code( # noqa: PLR0911
|
||||
def prune_cst_for_read_writable_code(
|
||||
node: cst.CSTNode, target_functions: set[str], defs_with_usages: dict[str, UsageInfo], prefix: str = ""
|
||||
) -> tuple[cst.CSTNode | None, bool]:
|
||||
"""Recursively filter the node and its children to build the read-writable codeblock. This contains nodes that lead to target functions.
|
||||
|
|
@ -1006,7 +1167,7 @@ def prune_cst_for_read_writable_code( # noqa: PLR0911
|
|||
return (node.with_changes(**updates) if updates else node), True
|
||||
|
||||
|
||||
def prune_cst_for_code_hashing( # noqa: PLR0911
|
||||
def prune_cst_for_code_hashing(
|
||||
node: cst.CSTNode, target_functions: set[str], prefix: str = ""
|
||||
) -> tuple[cst.CSTNode | None, bool]:
|
||||
"""Recursively filter the node and its children to build the read-writable codeblock. This contains nodes that lead to target functions.
|
||||
|
|
@ -1095,14 +1256,14 @@ def prune_cst_for_code_hashing( # noqa: PLR0911
|
|||
return (node.with_changes(**updates) if updates else node), True
|
||||
|
||||
|
||||
def prune_cst_for_context( # noqa: PLR0911
|
||||
def prune_cst_for_context(
|
||||
node: cst.CSTNode,
|
||||
target_functions: set[str],
|
||||
helpers_of_helper_functions: set[str],
|
||||
prefix: str = "",
|
||||
remove_docstrings: bool = False, # noqa: FBT001, FBT002
|
||||
include_target_in_output: bool = False, # noqa: FBT001, FBT002
|
||||
include_init_dunder: bool = False, # noqa: FBT001, FBT002
|
||||
remove_docstrings: bool = False,
|
||||
include_target_in_output: bool = False,
|
||||
include_init_dunder: bool = False,
|
||||
) -> tuple[cst.CSTNode | None, bool]:
|
||||
"""Recursively filter the node for code context extraction.
|
||||
|
||||
|
|
|
|||
|
|
@ -11,6 +11,7 @@ import libcst as cst
|
|||
|
||||
from codeflash.cli_cmds.console import logger
|
||||
from codeflash.code_utils.code_replacer import replace_function_definitions_in_module
|
||||
from codeflash.languages import is_javascript
|
||||
from codeflash.models.models import CodeString, CodeStringsMarkdown
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
|
@ -208,7 +209,7 @@ class DependencyCollector(cst.CSTVisitor):
|
|||
self._extract_names_from_annotation(node.value)
|
||||
# No need to check the attribute name itself as it's likely not a top-level definition
|
||||
|
||||
def leave_FunctionDef(self, original_node: cst.FunctionDef) -> None: # noqa: ARG002
|
||||
def leave_FunctionDef(self, original_node: cst.FunctionDef) -> None:
|
||||
self.function_depth -= 1
|
||||
|
||||
if self.function_depth == 0 and self.class_depth == 0:
|
||||
|
|
@ -237,7 +238,7 @@ class DependencyCollector(cst.CSTVisitor):
|
|||
|
||||
self.class_depth += 1
|
||||
|
||||
def leave_ClassDef(self, original_node: cst.ClassDef) -> None: # noqa: ARG002
|
||||
def leave_ClassDef(self, original_node: cst.ClassDef) -> None:
|
||||
self.class_depth -= 1
|
||||
|
||||
if self.class_depth == 0:
|
||||
|
|
@ -260,7 +261,7 @@ class DependencyCollector(cst.CSTVisitor):
|
|||
# Use the first tracked name as the current top-level name (for dependency tracking)
|
||||
self.current_top_level_name = tracked_names[0]
|
||||
|
||||
def leave_Assign(self, original_node: cst.Assign) -> None: # noqa: ARG002
|
||||
def leave_Assign(self, original_node: cst.Assign) -> None:
|
||||
if self.processing_variable:
|
||||
self.processing_variable = False
|
||||
self.current_variable_names.clear()
|
||||
|
|
@ -370,7 +371,7 @@ class QualifiedFunctionUsageMarker:
|
|||
self.mark_as_used_recursively(dep)
|
||||
|
||||
|
||||
def remove_unused_definitions_recursively( # noqa: PLR0911
|
||||
def remove_unused_definitions_recursively(
|
||||
node: cst.CSTNode, definitions: dict[str, UsageInfo]
|
||||
) -> tuple[cst.CSTNode | None, bool]:
|
||||
"""Recursively filter the node to remove unused definitions.
|
||||
|
|
@ -553,7 +554,7 @@ def remove_unused_definitions_by_function_names(code: str, qualified_function_na
|
|||
# Apply the recursive removal transformation
|
||||
modified_module, _ = remove_unused_definitions_recursively(module, defs_with_usages)
|
||||
|
||||
return modified_module.code if modified_module else "" # noqa: TRY300
|
||||
return modified_module.code if modified_module else ""
|
||||
except Exception as e:
|
||||
# If any other error occurs during processing, return the original code
|
||||
logger.debug(f"Error processing code to remove unused definitions: {type(e).__name__}: {e}")
|
||||
|
|
@ -629,8 +630,8 @@ def _analyze_imports_in_optimized_code(
|
|||
helpers_by_file_and_func = defaultdict(dict)
|
||||
helpers_by_file = defaultdict(list) # preserved for "import module"
|
||||
for helper in code_context.helper_functions:
|
||||
jedi_type = helper.jedi_definition.type
|
||||
if jedi_type != "class":
|
||||
jedi_type = helper.jedi_definition.type if helper.jedi_definition else None
|
||||
if jedi_type != "class": # Include when jedi_definition is None (non-Python)
|
||||
func_name = helper.only_function_name
|
||||
module_name = helper.file_path.stem
|
||||
# Cache function lookup for this (module, func)
|
||||
|
|
@ -716,6 +717,11 @@ def detect_unused_helper_functions(
|
|||
List of FunctionSource objects representing unused helper functions
|
||||
|
||||
"""
|
||||
# Skip this analysis for non-Python languages since we use Python's ast module
|
||||
if is_javascript():
|
||||
logger.debug("Skipping unused helper function detection for JavaScript/TypeScript")
|
||||
return []
|
||||
|
||||
if isinstance(optimized_code, CodeStringsMarkdown) and len(optimized_code.code_strings) > 0:
|
||||
return list(
|
||||
chain.from_iterable(
|
||||
|
|
@ -783,7 +789,8 @@ def detect_unused_helper_functions(
|
|||
unused_helpers = []
|
||||
entrypoint_file_path = function_to_optimize.file_path
|
||||
for helper_function in code_context.helper_functions:
|
||||
if helper_function.jedi_definition.type != "class":
|
||||
jedi_type = helper_function.jedi_definition.type if helper_function.jedi_definition else None
|
||||
if jedi_type != "class": # Include when jedi_definition is None (non-Python)
|
||||
# Check if the helper function is called using multiple name variants
|
||||
helper_qualified_name = helper_function.qualified_name
|
||||
helper_simple_name = helper_function.only_function_name
|
||||
|
|
|
|||
|
|
@ -29,6 +29,7 @@ from codeflash.code_utils.code_utils import (
|
|||
)
|
||||
from codeflash.code_utils.compat import SAFE_SYS_EXECUTABLE, codeflash_cache_db
|
||||
from codeflash.code_utils.shell_utils import get_cross_platform_subprocess_run_args
|
||||
from codeflash.languages import is_javascript, is_python
|
||||
from codeflash.models.models import CodePosition, FunctionCalledInTest, TestsInFile, TestType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
|
@ -554,11 +555,119 @@ def filter_test_files_by_imports(
|
|||
return filtered_map
|
||||
|
||||
|
||||
def _detect_language_from_functions(file_to_funcs: dict[Path, list[FunctionToOptimize]] | None) -> str | None:
|
||||
"""Detect language from the functions to optimize.
|
||||
|
||||
Args:
|
||||
file_to_funcs: Dictionary mapping file paths to functions.
|
||||
|
||||
Returns:
|
||||
Language string (e.g., "python", "javascript") or None if not determinable.
|
||||
|
||||
"""
|
||||
if not file_to_funcs:
|
||||
return None
|
||||
|
||||
for funcs in file_to_funcs.values():
|
||||
if funcs:
|
||||
return funcs[0].language
|
||||
return None
|
||||
|
||||
|
||||
def discover_tests_for_language(
|
||||
cfg: TestConfig, language: str, file_to_funcs_to_optimize: dict[Path, list[FunctionToOptimize]] | None
|
||||
) -> tuple[dict[str, set[FunctionCalledInTest]], int, int]:
|
||||
"""Discover tests using language-specific support.
|
||||
|
||||
Args:
|
||||
cfg: Test configuration.
|
||||
language: Language identifier (e.g., "javascript").
|
||||
file_to_funcs_to_optimize: Dictionary mapping file paths to functions.
|
||||
|
||||
Returns:
|
||||
Tuple of (function_to_tests_map, num_tests, num_replay_tests).
|
||||
|
||||
"""
|
||||
from codeflash.languages import get_language_support
|
||||
from codeflash.languages.base import FunctionInfo, Language, ParentInfo
|
||||
|
||||
try:
|
||||
lang_support = get_language_support(Language(language))
|
||||
except Exception:
|
||||
logger.warning(f"Unsupported language {language}, returning empty test map")
|
||||
return {}, 0, 0
|
||||
|
||||
# Convert FunctionToOptimize to FunctionInfo for the language support API
|
||||
# Also build a mapping from simple qualified_name to full qualified_name_with_modules
|
||||
function_infos: list[FunctionInfo] = []
|
||||
simple_to_full_name: dict[str, str] = {}
|
||||
if file_to_funcs_to_optimize:
|
||||
for funcs in file_to_funcs_to_optimize.values():
|
||||
for func in funcs:
|
||||
parents = tuple(ParentInfo(p.name, p.type) for p in func.parents)
|
||||
func_info = FunctionInfo(
|
||||
name=func.function_name,
|
||||
file_path=func.file_path,
|
||||
start_line=func.starting_line or 0,
|
||||
end_line=func.ending_line or 0,
|
||||
start_col=func.starting_col,
|
||||
end_col=func.ending_col,
|
||||
is_async=func.is_async,
|
||||
is_method=bool(func.parents and any(p.type == "ClassDef" for p in func.parents)),
|
||||
parents=parents,
|
||||
language=Language(language),
|
||||
)
|
||||
function_infos.append(func_info)
|
||||
# Map simple qualified_name to full qualified_name_with_modules_from_root
|
||||
simple_to_full_name[func_info.qualified_name] = func.qualified_name_with_modules_from_root(
|
||||
cfg.project_root_path
|
||||
)
|
||||
|
||||
# Use language support to discover tests
|
||||
test_map = lang_support.discover_tests(cfg.tests_root, function_infos)
|
||||
|
||||
# Convert TestInfo back to FunctionCalledInTest format
|
||||
# Use the full qualified name (with modules) as the key for consistency with Python
|
||||
function_to_tests: dict[str, set[FunctionCalledInTest]] = defaultdict(set)
|
||||
num_tests = 0
|
||||
|
||||
for qualified_name, test_infos in test_map.items():
|
||||
# Convert simple qualified_name to full qualified_name_with_modules
|
||||
full_qualified_name = simple_to_full_name.get(qualified_name, qualified_name)
|
||||
for test_info in test_infos:
|
||||
function_to_tests[full_qualified_name].add(
|
||||
FunctionCalledInTest(
|
||||
tests_in_file=TestsInFile(
|
||||
test_file=test_info.test_file,
|
||||
test_class=test_info.test_class,
|
||||
test_function=test_info.test_name,
|
||||
test_type=TestType.EXISTING_UNIT_TEST,
|
||||
),
|
||||
position=CodePosition(line_no=0, col_no=0),
|
||||
)
|
||||
)
|
||||
num_tests += 1
|
||||
|
||||
return dict(function_to_tests), num_tests, 0
|
||||
|
||||
|
||||
def discover_unit_tests(
|
||||
cfg: TestConfig,
|
||||
discover_only_these_tests: list[Path] | None = None,
|
||||
file_to_funcs_to_optimize: dict[Path, list[FunctionToOptimize]] | None = None,
|
||||
) -> tuple[dict[str, set[FunctionCalledInTest]], int, int]:
|
||||
# Detect language from functions being optimized
|
||||
language = _detect_language_from_functions(file_to_funcs_to_optimize)
|
||||
|
||||
# Route to language-specific test discovery for non-Python languages
|
||||
if not is_python():
|
||||
# For JavaScript/TypeScript, tests_project_rootdir should be tests_root itself
|
||||
# The Jest helper will be configured to NOT include "tests." prefix to match
|
||||
if is_javascript():
|
||||
cfg.tests_project_rootdir = cfg.tests_root
|
||||
return discover_tests_for_language(cfg, language, file_to_funcs_to_optimize)
|
||||
|
||||
# Existing Python logic
|
||||
framework_strategies: dict[str, Callable] = {"pytest": discover_tests_pytest, "unittest": discover_tests_unittest}
|
||||
strategy = framework_strategies.get(cfg.test_framework, None)
|
||||
if not strategy:
|
||||
|
|
|
|||
|
|
@ -26,6 +26,9 @@ from codeflash.code_utils.code_utils import (
|
|||
from codeflash.code_utils.env_utils import get_pr_number
|
||||
from codeflash.code_utils.git_utils import get_git_diff, get_repo_owner_and_name
|
||||
from codeflash.discovery.discover_unit_tests import discover_unit_tests
|
||||
from codeflash.languages import get_language_support, get_supported_extensions
|
||||
from codeflash.languages.base import Language
|
||||
from codeflash.languages.registry import is_language_supported
|
||||
from codeflash.lsp.helpers import is_LSP_enabled
|
||||
from codeflash.models.models import FunctionParent
|
||||
from codeflash.telemetry.posthog_cf import ph
|
||||
|
|
@ -59,7 +62,7 @@ class ReturnStatementVisitor(cst.CSTVisitor):
|
|||
super().__init__()
|
||||
self.has_return_statement: bool = False
|
||||
|
||||
def visit_Return(self, node: cst.Return) -> None: # noqa: ARG002
|
||||
def visit_Return(self, node: cst.Return) -> None:
|
||||
self.has_return_statement = True
|
||||
|
||||
|
||||
|
|
@ -135,7 +138,10 @@ class FunctionToOptimize:
|
|||
parents: A list of parent scopes, which could be classes or functions.
|
||||
starting_line: The starting line number of the function in the file.
|
||||
ending_line: The ending line number of the function in the file.
|
||||
starting_col: The starting column offset (for precise location in multi-line contexts).
|
||||
ending_col: The ending column offset (for precise location in multi-line contexts).
|
||||
is_async: Whether this function is defined as async.
|
||||
language: The programming language of this function (default: "python").
|
||||
|
||||
The qualified_name property provides the full name of the function, including
|
||||
any parent class or function names. The qualified_name_with_modules_from_root
|
||||
|
|
@ -148,7 +154,10 @@ class FunctionToOptimize:
|
|||
parents: list[FunctionParent] # list[ClassDef | FunctionDef | AsyncFunctionDef]
|
||||
starting_line: Optional[int] = None
|
||||
ending_line: Optional[int] = None
|
||||
starting_col: Optional[int] = None # Column offset for precise location
|
||||
ending_col: Optional[int] = None # Column offset for precise location
|
||||
is_async: bool = False
|
||||
language: str = "python" # Language identifier for multi-language support
|
||||
|
||||
@property
|
||||
def top_level_parent_name(self) -> str:
|
||||
|
|
@ -172,6 +181,98 @@ class FunctionToOptimize:
|
|||
return f"{module_name_from_file_path(self.file_path, project_root_path)}.{self.qualified_name}"
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Multi-language support helpers
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def get_files_for_language(
|
||||
module_root_path: Path, ignore_paths: list[Path], language: Language | None = None
|
||||
) -> list[Path]:
|
||||
"""Get all source files for supported languages.
|
||||
|
||||
Args:
|
||||
module_root_path: Root path to search for source files.
|
||||
ignore_paths: List of paths to ignore (can be files or directories).
|
||||
language: Optional specific language to filter for. If None, includes all supported languages.
|
||||
|
||||
Returns:
|
||||
List of file paths matching supported extensions.
|
||||
|
||||
"""
|
||||
if language is not None:
|
||||
support = get_language_support(language)
|
||||
extensions = support.file_extensions
|
||||
else:
|
||||
extensions = tuple(get_supported_extensions())
|
||||
|
||||
files = []
|
||||
for ext in extensions:
|
||||
pattern = f"*{ext}"
|
||||
for file_path in module_root_path.rglob(pattern):
|
||||
if any(file_path.is_relative_to(ignore_path) for ignore_path in ignore_paths):
|
||||
continue
|
||||
files.append(file_path)
|
||||
return files
|
||||
|
||||
|
||||
def _find_all_functions_in_python_file(file_path: Path) -> dict[Path, list[FunctionToOptimize]]:
|
||||
"""Find all optimizable functions in a Python file using AST parsing.
|
||||
|
||||
This is the original Python implementation preserved for backward compatibility.
|
||||
"""
|
||||
functions: dict[Path, list[FunctionToOptimize]] = {}
|
||||
with file_path.open(encoding="utf8") as f:
|
||||
try:
|
||||
ast_module = ast.parse(f.read())
|
||||
except Exception as e:
|
||||
if DEBUG_MODE:
|
||||
logger.exception(e)
|
||||
return functions
|
||||
function_name_visitor = FunctionWithReturnStatement(file_path)
|
||||
function_name_visitor.visit(ast_module)
|
||||
functions[file_path] = function_name_visitor.functions
|
||||
return functions
|
||||
|
||||
|
||||
def _find_all_functions_via_language_support(file_path: Path) -> dict[Path, list[FunctionToOptimize]]:
|
||||
"""Find all optimizable functions using the language support abstraction.
|
||||
|
||||
This function uses the registered language support for the file's language
|
||||
to discover functions, then converts them to FunctionToOptimize instances.
|
||||
"""
|
||||
from codeflash.languages.base import FunctionFilterCriteria
|
||||
|
||||
functions: dict[Path, list[FunctionToOptimize]] = {}
|
||||
|
||||
try:
|
||||
lang_support = get_language_support(file_path)
|
||||
criteria = FunctionFilterCriteria(require_return=True)
|
||||
function_infos = lang_support.discover_functions(file_path, criteria)
|
||||
|
||||
ftos = []
|
||||
for func_info in function_infos:
|
||||
parents = [FunctionParent(p.name, p.type) for p in func_info.parents]
|
||||
ftos.append(
|
||||
FunctionToOptimize(
|
||||
function_name=func_info.name,
|
||||
file_path=func_info.file_path,
|
||||
parents=parents,
|
||||
starting_line=func_info.start_line,
|
||||
ending_line=func_info.end_line,
|
||||
starting_col=func_info.start_col,
|
||||
ending_col=func_info.end_col,
|
||||
is_async=func_info.is_async,
|
||||
language=func_info.language.value,
|
||||
)
|
||||
)
|
||||
functions[file_path] = ftos
|
||||
except Exception as e:
|
||||
logger.debug(f"Failed to discover functions in {file_path}: {e}")
|
||||
|
||||
return functions
|
||||
|
||||
|
||||
def get_functions_to_optimize(
|
||||
optimize_all: str | None,
|
||||
replay_test: list[Path] | None,
|
||||
|
|
@ -194,7 +295,7 @@ def get_functions_to_optimize(
|
|||
if optimize_all:
|
||||
logger.info("!lsp|Finding all functions in the module '%s'…", optimize_all)
|
||||
console.rule()
|
||||
functions = get_all_files_and_functions(Path(optimize_all))
|
||||
functions = get_all_files_and_functions(Path(optimize_all), ignore_paths)
|
||||
elif replay_test:
|
||||
functions, trace_file_path = get_all_replay_test_functions(
|
||||
replay_test=replay_test, test_cfg=test_cfg, project_root_path=project_root
|
||||
|
|
@ -251,7 +352,7 @@ def get_functions_to_optimize(
|
|||
return filtered_modified_functions, functions_count, trace_file_path
|
||||
|
||||
|
||||
def get_functions_within_git_diff(uncommitted_changes: bool) -> dict[str, list[FunctionToOptimize]]: # noqa: FBT001
|
||||
def get_functions_within_git_diff(uncommitted_changes: bool) -> dict[str, list[FunctionToOptimize]]:
|
||||
modified_lines: dict[str, list[int]] = get_git_diff(uncommitted_changes=uncommitted_changes)
|
||||
return get_functions_within_lines(modified_lines)
|
||||
|
||||
|
|
@ -356,9 +457,22 @@ def get_functions_within_lines(modified_lines: dict[str, list[int]]) -> dict[str
|
|||
return functions
|
||||
|
||||
|
||||
def get_all_files_and_functions(module_root_path: Path) -> dict[str, list[FunctionToOptimize]]:
|
||||
def get_all_files_and_functions(
|
||||
module_root_path: Path, ignore_paths: list[Path], language: Language | None = None
|
||||
) -> dict[str, list[FunctionToOptimize]]:
|
||||
"""Get all optimizable functions from files in the module root.
|
||||
|
||||
Args:
|
||||
module_root_path: Root path to search for source files.
|
||||
ignore_paths: List of paths to ignore.
|
||||
language: Optional specific language to filter for. If None, includes all supported languages.
|
||||
|
||||
Returns:
|
||||
Dictionary mapping file paths to lists of FunctionToOptimize.
|
||||
|
||||
"""
|
||||
functions: dict[str, list[FunctionToOptimize]] = {}
|
||||
for file_path in module_root_path.rglob("*.py"):
|
||||
for file_path in get_files_for_language(module_root_path, ignore_paths, language):
|
||||
# Find all the functions in the file
|
||||
functions.update(find_all_functions_in_file(file_path).items())
|
||||
# Randomize the order of the files to optimize to avoid optimizing the same file in the same order every time.
|
||||
|
|
@ -369,18 +483,34 @@ def get_all_files_and_functions(module_root_path: Path) -> dict[str, list[Functi
|
|||
|
||||
|
||||
def find_all_functions_in_file(file_path: Path) -> dict[Path, list[FunctionToOptimize]]:
|
||||
functions: dict[Path, list[FunctionToOptimize]] = {}
|
||||
with file_path.open(encoding="utf8") as f:
|
||||
try:
|
||||
ast_module = ast.parse(f.read())
|
||||
except Exception as e:
|
||||
if DEBUG_MODE:
|
||||
logger.exception(e)
|
||||
return functions
|
||||
function_name_visitor = FunctionWithReturnStatement(file_path)
|
||||
function_name_visitor.visit(ast_module)
|
||||
functions[file_path] = function_name_visitor.functions
|
||||
return functions
|
||||
"""Find all optimizable functions in a file, routing to the appropriate language handler.
|
||||
|
||||
This function checks if the file extension is supported and routes to either
|
||||
the Python-specific implementation (for backward compatibility) or the
|
||||
language support abstraction for other languages.
|
||||
|
||||
Args:
|
||||
file_path: Path to the source file.
|
||||
|
||||
Returns:
|
||||
Dictionary mapping file path to list of FunctionToOptimize.
|
||||
|
||||
"""
|
||||
# Check if the file extension is supported
|
||||
if not is_language_supported(file_path):
|
||||
return {}
|
||||
|
||||
try:
|
||||
lang_support = get_language_support(file_path)
|
||||
except Exception:
|
||||
return {}
|
||||
|
||||
# Route to Python-specific implementation for backward compatibility
|
||||
if lang_support.language == Language.PYTHON:
|
||||
return _find_all_functions_in_python_file(file_path)
|
||||
|
||||
# Use language support abstraction for other languages
|
||||
return _find_all_functions_via_language_support(file_path)
|
||||
|
||||
|
||||
def get_all_replay_test_functions(
|
||||
|
|
@ -472,7 +602,7 @@ def get_all_replay_test_functions(
|
|||
def is_git_repo(file_path: str) -> bool:
|
||||
try:
|
||||
git.Repo(file_path, search_parent_directories=True)
|
||||
return True # noqa: TRY300
|
||||
return True
|
||||
except git.InvalidGitRepositoryError:
|
||||
return False
|
||||
|
||||
|
|
@ -704,11 +834,14 @@ def filter_functions(
|
|||
if not file_path_normalized.startswith(module_root_str + os.sep):
|
||||
non_modules_removed_count += len(_functions)
|
||||
continue
|
||||
try:
|
||||
ast.parse(f"import {module_name_from_file_path(Path(file_path), project_root)}")
|
||||
except SyntaxError:
|
||||
malformed_paths_count += 1
|
||||
continue
|
||||
|
||||
lang_support = get_language_support(Path(file_path))
|
||||
if lang_support.language == Language.PYTHON:
|
||||
try:
|
||||
ast.parse(f"import {module_name_from_file_path(Path(file_path), project_root)}")
|
||||
except SyntaxError:
|
||||
malformed_paths_count += 1
|
||||
continue
|
||||
|
||||
if blocklist_funcs:
|
||||
functions_tmp = []
|
||||
|
|
|
|||
76
codeflash/languages/__init__.py
Normal file
76
codeflash/languages/__init__.py
Normal file
|
|
@ -0,0 +1,76 @@
|
|||
"""Multi-language support for Codeflash.
|
||||
|
||||
This package provides the abstraction layer that allows Codeflash to support
|
||||
multiple programming languages while keeping the core optimization pipeline
|
||||
language-agnostic.
|
||||
|
||||
Usage:
|
||||
from codeflash.languages import get_language_support, Language
|
||||
|
||||
# Get language support for a file
|
||||
lang = get_language_support(Path("example.py"))
|
||||
|
||||
# Discover functions
|
||||
functions = lang.discover_functions(file_path)
|
||||
|
||||
# Replace a function
|
||||
new_source = lang.replace_function(file_path, function, new_code)
|
||||
"""
|
||||
|
||||
from codeflash.languages.base import (
|
||||
CodeContext,
|
||||
FunctionInfo,
|
||||
HelperFunction,
|
||||
Language,
|
||||
LanguageSupport,
|
||||
ParentInfo,
|
||||
TestInfo,
|
||||
TestResult,
|
||||
)
|
||||
from codeflash.languages.current import (
|
||||
current_language,
|
||||
current_language_support,
|
||||
is_javascript,
|
||||
is_python,
|
||||
is_typescript,
|
||||
reset_current_language,
|
||||
set_current_language,
|
||||
)
|
||||
from codeflash.languages.javascript import JavaScriptSupport, TypeScriptSupport # noqa: F401
|
||||
|
||||
# Import language support modules to trigger auto-registration
|
||||
# This ensures all supported languages are available when this package is imported
|
||||
from codeflash.languages.python import PythonSupport # noqa: F401
|
||||
from codeflash.languages.registry import (
|
||||
detect_project_language,
|
||||
get_language_support,
|
||||
get_supported_extensions,
|
||||
get_supported_languages,
|
||||
register_language,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
# Base types
|
||||
"CodeContext",
|
||||
"FunctionInfo",
|
||||
"HelperFunction",
|
||||
"Language",
|
||||
"LanguageSupport",
|
||||
"ParentInfo",
|
||||
"TestInfo",
|
||||
"TestResult",
|
||||
# Current language singleton
|
||||
"current_language",
|
||||
"current_language_support",
|
||||
# Registry functions
|
||||
"detect_project_language",
|
||||
"get_language_support",
|
||||
"get_supported_extensions",
|
||||
"get_supported_languages",
|
||||
"is_javascript",
|
||||
"is_python",
|
||||
"is_typescript",
|
||||
"register_language",
|
||||
"reset_current_language",
|
||||
"set_current_language",
|
||||
]
|
||||
688
codeflash/languages/base.py
Normal file
688
codeflash/languages/base.py
Normal file
|
|
@ -0,0 +1,688 @@
|
|||
"""Base types and protocol for multi-language support in Codeflash.
|
||||
|
||||
This module defines the core abstractions that all language implementations must follow.
|
||||
The LanguageSupport protocol defines the interface that each language must implement,
|
||||
while the dataclasses define language-agnostic representations of code constructs.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import TYPE_CHECKING, Any, Protocol, runtime_checkable
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Sequence
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
class Language(str, Enum):
|
||||
"""Supported programming languages."""
|
||||
|
||||
PYTHON = "python"
|
||||
JAVASCRIPT = "javascript"
|
||||
TYPESCRIPT = "typescript"
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self.value
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ParentInfo:
|
||||
"""Parent scope information for nested functions/methods.
|
||||
|
||||
Represents the parent class or function that contains a nested function.
|
||||
Used to construct the qualified name of a function.
|
||||
|
||||
Attributes:
|
||||
name: The name of the parent scope (class name or function name).
|
||||
type: The type of parent ("ClassDef", "FunctionDef", "AsyncFunctionDef", etc.).
|
||||
|
||||
"""
|
||||
|
||||
name: str
|
||||
type: str # "ClassDef", "FunctionDef", "AsyncFunctionDef", etc.
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"{self.type}:{self.name}"
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class FunctionInfo:
|
||||
"""Language-agnostic representation of a function to optimize.
|
||||
|
||||
This class captures all the information needed to identify, locate, and
|
||||
work with a function across different programming languages.
|
||||
|
||||
Attributes:
|
||||
name: The simple function name (e.g., "add").
|
||||
file_path: Absolute path to the file containing the function.
|
||||
start_line: Starting line number (1-indexed).
|
||||
end_line: Ending line number (1-indexed, inclusive).
|
||||
parents: List of parent scopes (for nested functions/methods).
|
||||
is_async: Whether this is an async function.
|
||||
is_method: Whether this is a method (belongs to a class).
|
||||
language: The programming language.
|
||||
start_col: Starting column (0-indexed), optional for more precise location.
|
||||
end_col: Ending column (0-indexed), optional.
|
||||
|
||||
"""
|
||||
|
||||
name: str
|
||||
file_path: Path
|
||||
start_line: int
|
||||
end_line: int
|
||||
parents: tuple[ParentInfo, ...] = ()
|
||||
is_async: bool = False
|
||||
is_method: bool = False
|
||||
language: Language = Language.PYTHON
|
||||
start_col: int | None = None
|
||||
end_col: int | None = None
|
||||
doc_start_line: int | None = None # Line where docstring/JSDoc starts (or None if no doc comment)
|
||||
|
||||
@property
|
||||
def qualified_name(self) -> str:
|
||||
"""Full qualified name including parent scopes.
|
||||
|
||||
For a method `add` in class `Calculator`, returns "Calculator.add".
|
||||
For nested functions, includes all parent scopes.
|
||||
"""
|
||||
if not self.parents:
|
||||
return self.name
|
||||
parent_path = ".".join(parent.name for parent in self.parents)
|
||||
return f"{parent_path}.{self.name}"
|
||||
|
||||
@property
|
||||
def class_name(self) -> str | None:
|
||||
"""Get the immediate parent class name, if any."""
|
||||
for parent in reversed(self.parents):
|
||||
if parent.type == "ClassDef":
|
||||
return parent.name
|
||||
return None
|
||||
|
||||
@property
|
||||
def top_level_parent_name(self) -> str:
|
||||
"""Get the top-level parent name, or function name if no parents."""
|
||||
return self.parents[0].name if self.parents else self.name
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"FunctionInfo({self.qualified_name} at {self.file_path}:{self.start_line}-{self.end_line})"
|
||||
|
||||
|
||||
@dataclass
|
||||
class HelperFunction:
|
||||
"""A helper function that is a dependency of the target function.
|
||||
|
||||
Helper functions are functions called by the target function that are
|
||||
within the same module/project (not external libraries).
|
||||
|
||||
Attributes:
|
||||
name: The simple function name.
|
||||
qualified_name: Full qualified name including parent scopes.
|
||||
file_path: Path to the file containing the helper.
|
||||
source_code: The source code of the helper function.
|
||||
start_line: Starting line number.
|
||||
end_line: Ending line number.
|
||||
|
||||
"""
|
||||
|
||||
name: str
|
||||
qualified_name: str
|
||||
file_path: Path
|
||||
source_code: str
|
||||
start_line: int
|
||||
end_line: int
|
||||
|
||||
|
||||
@dataclass
|
||||
class CodeContext:
|
||||
"""Code context extracted for optimization.
|
||||
|
||||
Contains the target function code and all relevant dependencies
|
||||
needed for the AI to understand and optimize the function.
|
||||
|
||||
Attributes:
|
||||
target_code: Source code of the function to optimize.
|
||||
target_file: Path to the file containing the target function.
|
||||
helper_functions: List of helper functions called by the target.
|
||||
read_only_context: Additional context code (read-only dependencies).
|
||||
imports: List of import statements needed.
|
||||
language: The programming language.
|
||||
|
||||
"""
|
||||
|
||||
target_code: str
|
||||
target_file: Path
|
||||
helper_functions: list[HelperFunction] = field(default_factory=list)
|
||||
read_only_context: str = ""
|
||||
imports: list[str] = field(default_factory=list)
|
||||
language: Language = Language.PYTHON
|
||||
|
||||
|
||||
@dataclass
|
||||
class TestInfo:
|
||||
"""Information about a test that exercises a function.
|
||||
|
||||
Attributes:
|
||||
test_name: Name of the test function.
|
||||
test_file: Path to the test file.
|
||||
test_class: Name of the test class, if any.
|
||||
|
||||
"""
|
||||
|
||||
test_name: str
|
||||
test_file: Path
|
||||
test_class: str | None = None
|
||||
|
||||
@property
|
||||
def full_test_path(self) -> str:
|
||||
"""Get full test path in pytest format (file::class::function)."""
|
||||
file_path = self.test_file.as_posix()
|
||||
if self.test_class:
|
||||
return f"{file_path}::{self.test_class}::{self.test_name}"
|
||||
return f"{file_path}::{self.test_name}"
|
||||
|
||||
|
||||
@dataclass
|
||||
class TestResult:
|
||||
"""Language-agnostic test result.
|
||||
|
||||
Captures the outcome of running a single test, including timing
|
||||
and behavioral data for equivalence checking.
|
||||
|
||||
Attributes:
|
||||
test_name: Name of the test function.
|
||||
test_file: Path to the test file.
|
||||
passed: Whether the test passed.
|
||||
runtime_ns: Execution time in nanoseconds.
|
||||
return_value: The return value captured from the test.
|
||||
stdout: Standard output captured during test execution.
|
||||
stderr: Standard error captured during test execution.
|
||||
error_message: Error message if the test failed.
|
||||
|
||||
"""
|
||||
|
||||
test_name: str
|
||||
test_file: Path
|
||||
passed: bool
|
||||
runtime_ns: int | None = None
|
||||
return_value: Any = None
|
||||
stdout: str = ""
|
||||
stderr: str = ""
|
||||
error_message: str | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class FunctionFilterCriteria:
|
||||
"""Criteria for filtering which functions to discover.
|
||||
|
||||
Attributes:
|
||||
include_patterns: Glob patterns for functions to include.
|
||||
exclude_patterns: Glob patterns for functions to exclude.
|
||||
require_return: Only include functions with return statements.
|
||||
include_async: Include async functions.
|
||||
include_methods: Include class methods.
|
||||
min_lines: Minimum number of lines in the function.
|
||||
max_lines: Maximum number of lines in the function.
|
||||
|
||||
"""
|
||||
|
||||
include_patterns: list[str] = field(default_factory=list)
|
||||
exclude_patterns: list[str] = field(default_factory=list)
|
||||
require_return: bool = True
|
||||
include_async: bool = True
|
||||
include_methods: bool = True
|
||||
min_lines: int | None = None
|
||||
max_lines: int | None = None
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class LanguageSupport(Protocol):
|
||||
"""Protocol defining what a language implementation must provide.
|
||||
|
||||
All language-specific implementations (Python, JavaScript, etc.) must
|
||||
implement this protocol. The protocol defines the interface for:
|
||||
- Function discovery
|
||||
- Code context extraction
|
||||
- Code transformation (replacement)
|
||||
- Test execution
|
||||
- Test discovery
|
||||
- Instrumentation for tracing
|
||||
|
||||
Example:
|
||||
class PythonSupport(LanguageSupport):
|
||||
@property
|
||||
def language(self) -> Language:
|
||||
return Language.PYTHON
|
||||
|
||||
def discover_functions(self, file_path: Path, ...) -> list[FunctionInfo]:
|
||||
# Python-specific implementation using LibCST
|
||||
...
|
||||
|
||||
"""
|
||||
|
||||
# === Properties ===
|
||||
|
||||
@property
|
||||
def language(self) -> Language:
|
||||
"""The language this implementation supports."""
|
||||
...
|
||||
|
||||
@property
|
||||
def file_extensions(self) -> tuple[str, ...]:
|
||||
"""File extensions supported by this language.
|
||||
|
||||
Returns:
|
||||
Tuple of extensions with leading dots (e.g., (".py",) for Python).
|
||||
|
||||
"""
|
||||
...
|
||||
|
||||
@property
|
||||
def test_framework(self) -> str:
|
||||
"""Primary test framework name.
|
||||
|
||||
Returns:
|
||||
Test framework identifier (e.g., "pytest", "jest").
|
||||
|
||||
"""
|
||||
...
|
||||
|
||||
@property
|
||||
def comment_prefix(self) -> str:
|
||||
"""Like # or //."""
|
||||
...
|
||||
|
||||
# === Discovery ===
|
||||
|
||||
def discover_functions(
|
||||
self, file_path: Path, filter_criteria: FunctionFilterCriteria | None = None
|
||||
) -> list[FunctionInfo]:
|
||||
"""Find all optimizable functions in a file.
|
||||
|
||||
Args:
|
||||
file_path: Path to the source file to analyze.
|
||||
filter_criteria: Optional criteria to filter functions.
|
||||
|
||||
Returns:
|
||||
List of FunctionInfo objects for discovered functions.
|
||||
|
||||
"""
|
||||
...
|
||||
|
||||
def discover_tests(self, test_root: Path, source_functions: Sequence[FunctionInfo]) -> dict[str, list[TestInfo]]:
|
||||
"""Map source functions to their tests via static analysis.
|
||||
|
||||
Args:
|
||||
test_root: Root directory containing tests.
|
||||
source_functions: Functions to find tests for.
|
||||
|
||||
Returns:
|
||||
Dict mapping qualified function names to lists of TestInfo.
|
||||
|
||||
"""
|
||||
...
|
||||
|
||||
# === Code Analysis ===
|
||||
|
||||
def extract_code_context(self, function: FunctionInfo, project_root: Path, module_root: Path) -> CodeContext:
|
||||
"""Extract function code and its dependencies.
|
||||
|
||||
Args:
|
||||
function: The function to extract context for.
|
||||
project_root: Root of the project.
|
||||
module_root: Root of the module containing the function.
|
||||
|
||||
Returns:
|
||||
CodeContext with target code and dependencies.
|
||||
|
||||
"""
|
||||
...
|
||||
|
||||
def find_helper_functions(self, function: FunctionInfo, project_root: Path) -> list[HelperFunction]:
|
||||
"""Find helper functions called by the target function.
|
||||
|
||||
Args:
|
||||
function: The target function to analyze.
|
||||
project_root: Root of the project.
|
||||
|
||||
Returns:
|
||||
List of HelperFunction objects.
|
||||
|
||||
"""
|
||||
...
|
||||
|
||||
# === Code Transformation ===
|
||||
|
||||
def replace_function(self, source: str, function: FunctionInfo, new_source: str) -> str:
|
||||
"""Replace a function in source code with new implementation.
|
||||
|
||||
Args:
|
||||
source: Original source code.
|
||||
function: FunctionInfo identifying the function to replace.
|
||||
new_source: New function source code.
|
||||
|
||||
Returns:
|
||||
Modified source code with function replaced.
|
||||
|
||||
"""
|
||||
...
|
||||
|
||||
def format_code(self, source: str, file_path: Path | None = None) -> str:
|
||||
"""Format code using language-specific formatter.
|
||||
|
||||
Args:
|
||||
source: Source code to format.
|
||||
file_path: Optional file path for context.
|
||||
|
||||
Returns:
|
||||
Formatted source code.
|
||||
|
||||
"""
|
||||
...
|
||||
|
||||
# === Test Execution ===
|
||||
|
||||
def run_tests(
|
||||
self, test_files: Sequence[Path], cwd: Path, env: dict[str, str], timeout: int
|
||||
) -> tuple[list[TestResult], Path]:
|
||||
"""Run tests and return results.
|
||||
|
||||
Args:
|
||||
test_files: Paths to test files to run.
|
||||
cwd: Working directory for test execution.
|
||||
env: Environment variables.
|
||||
timeout: Maximum execution time in seconds.
|
||||
|
||||
Returns:
|
||||
Tuple of (list of TestResults, path to JUnit XML).
|
||||
|
||||
"""
|
||||
...
|
||||
|
||||
def parse_test_results(self, junit_xml_path: Path, stdout: str) -> list[TestResult]:
|
||||
"""Parse test results from JUnit XML and stdout.
|
||||
|
||||
Args:
|
||||
junit_xml_path: Path to JUnit XML results file.
|
||||
stdout: Standard output from test execution.
|
||||
|
||||
Returns:
|
||||
List of TestResult objects.
|
||||
|
||||
"""
|
||||
...
|
||||
|
||||
# === Instrumentation ===
|
||||
|
||||
def instrument_for_behavior(self, source: str, functions: Sequence[FunctionInfo]) -> str:
|
||||
"""Add behavior instrumentation to capture inputs/outputs.
|
||||
|
||||
Args:
|
||||
source: Source code to instrument.
|
||||
functions: Functions to add behavior capture.
|
||||
|
||||
Returns:
|
||||
Instrumented source code.
|
||||
|
||||
"""
|
||||
...
|
||||
|
||||
def instrument_for_benchmarking(self, test_source: str, target_function: FunctionInfo) -> str:
|
||||
"""Add timing instrumentation to test code.
|
||||
|
||||
Args:
|
||||
test_source: Test source code to instrument.
|
||||
target_function: Function being benchmarked.
|
||||
|
||||
Returns:
|
||||
Instrumented test source code.
|
||||
|
||||
"""
|
||||
...
|
||||
|
||||
# === Validation ===
|
||||
|
||||
def validate_syntax(self, source: str) -> bool:
|
||||
"""Check if source code is syntactically valid.
|
||||
|
||||
Args:
|
||||
source: Source code to validate.
|
||||
|
||||
Returns:
|
||||
True if valid, False otherwise.
|
||||
|
||||
"""
|
||||
...
|
||||
|
||||
def normalize_code(self, source: str) -> str:
|
||||
"""Normalize code for deduplication.
|
||||
|
||||
Removes comments, normalizes whitespace, etc. to allow
|
||||
comparison of semantically equivalent code.
|
||||
|
||||
Args:
|
||||
source: Source code to normalize.
|
||||
|
||||
Returns:
|
||||
Normalized source code.
|
||||
|
||||
"""
|
||||
...
|
||||
|
||||
# === Test Editing ===
|
||||
|
||||
def add_runtime_comments(
|
||||
self, test_source: str, original_runtimes: dict[str, int], optimized_runtimes: dict[str, int]
|
||||
) -> str:
|
||||
"""Add runtime performance comments to test source code.
|
||||
|
||||
Adds comments showing the original vs optimized runtime for each
|
||||
function call (e.g., "// 1.5ms -> 0.3ms (80% faster)").
|
||||
|
||||
Args:
|
||||
test_source: Test source code to annotate.
|
||||
original_runtimes: Map of invocation IDs to original runtimes (ns).
|
||||
optimized_runtimes: Map of invocation IDs to optimized runtimes (ns).
|
||||
|
||||
Returns:
|
||||
Test source code with runtime comments added.
|
||||
|
||||
"""
|
||||
...
|
||||
|
||||
def remove_test_functions(self, test_source: str, functions_to_remove: list[str]) -> str:
|
||||
"""Remove specific test functions from test source code.
|
||||
|
||||
Args:
|
||||
test_source: Test source code.
|
||||
functions_to_remove: List of function names to remove.
|
||||
|
||||
Returns:
|
||||
Test source code with specified functions removed.
|
||||
|
||||
"""
|
||||
...
|
||||
|
||||
# === Test Result Comparison ===
|
||||
|
||||
def compare_test_results(
|
||||
self, original_results_path: Path, candidate_results_path: Path, project_root: Path | None = None
|
||||
) -> tuple[bool, list]:
|
||||
"""Compare test results between original and candidate code.
|
||||
|
||||
Args:
|
||||
original_results_path: Path to original test results (e.g., SQLite DB).
|
||||
candidate_results_path: Path to candidate test results.
|
||||
project_root: Project root directory (for finding node_modules, etc.).
|
||||
|
||||
Returns:
|
||||
Tuple of (are_equivalent, list of TestDiff objects).
|
||||
|
||||
"""
|
||||
...
|
||||
|
||||
# === Configuration ===
|
||||
|
||||
def get_test_file_suffix(self) -> str:
|
||||
"""Get the test file suffix for this language.
|
||||
|
||||
Returns:
|
||||
Test file suffix (e.g., ".test.js", "_test.py").
|
||||
|
||||
"""
|
||||
...
|
||||
|
||||
def get_comment_prefix(self) -> str:
|
||||
"""Get the comment prefix for this language.
|
||||
|
||||
Returns:
|
||||
Comment prefix (e.g., "//" for JS, "#" for Python).
|
||||
|
||||
"""
|
||||
...
|
||||
|
||||
def find_test_root(self, project_root: Path) -> Path | None:
|
||||
"""Find the test root directory for a project.
|
||||
|
||||
Args:
|
||||
project_root: Root directory of the project.
|
||||
|
||||
Returns:
|
||||
Path to test root, or None if not found.
|
||||
|
||||
"""
|
||||
...
|
||||
|
||||
def get_runtime_files(self) -> list[Path]:
|
||||
"""Get paths to runtime files that need to be copied to user's project.
|
||||
|
||||
Returns:
|
||||
List of paths to runtime files (e.g., codeflash-jest-helper.js).
|
||||
|
||||
"""
|
||||
...
|
||||
|
||||
def ensure_runtime_environment(self, project_root: Path) -> bool:
|
||||
"""Ensure the runtime environment is set up for the project.
|
||||
|
||||
This method handles language-specific runtime setup, such as installing
|
||||
npm packages for JavaScript or pip packages for Python.
|
||||
|
||||
Args:
|
||||
project_root: The project root directory.
|
||||
|
||||
Returns:
|
||||
True if runtime environment is ready, False otherwise.
|
||||
|
||||
"""
|
||||
# Default implementation: just copy runtime files
|
||||
return False
|
||||
|
||||
def instrument_existing_test(
|
||||
self,
|
||||
test_path: Path,
|
||||
call_positions: Sequence[Any],
|
||||
function_to_optimize: Any,
|
||||
tests_project_root: Path,
|
||||
mode: str,
|
||||
) -> tuple[bool, str | None]:
|
||||
"""Inject profiling code into an existing test file.
|
||||
|
||||
Wraps function calls with capture/benchmark instrumentation for
|
||||
behavioral verification and performance benchmarking.
|
||||
|
||||
Args:
|
||||
test_path: Path to the test file.
|
||||
call_positions: List of code positions where the function is called.
|
||||
function_to_optimize: The function being optimized.
|
||||
tests_project_root: Root directory of tests.
|
||||
mode: Testing mode - "behavior" or "performance".
|
||||
|
||||
Returns:
|
||||
Tuple of (success, instrumented_code).
|
||||
|
||||
"""
|
||||
...
|
||||
|
||||
def instrument_source_for_line_profiler(self, func_info: FunctionInfo, line_profiler_output_file: Path) -> bool:
|
||||
"""Instrument source code before line profiling."""
|
||||
...
|
||||
|
||||
def parse_line_profile_results(self, line_profiler_output_file: Path) -> dict:
|
||||
"""Parse line profiler output."""
|
||||
...
|
||||
|
||||
# === Test Execution ===
|
||||
|
||||
def run_behavioral_tests(
|
||||
self,
|
||||
test_paths: Any,
|
||||
test_env: dict[str, str],
|
||||
cwd: Path,
|
||||
timeout: int | None = None,
|
||||
project_root: Path | None = None,
|
||||
enable_coverage: bool = False,
|
||||
candidate_index: int = 0,
|
||||
) -> tuple[Path, Any, Path | None, Path | None]:
|
||||
"""Run behavioral tests for this language.
|
||||
|
||||
Args:
|
||||
test_paths: TestFiles object containing test file information.
|
||||
test_env: Environment variables for the test run.
|
||||
cwd: Working directory for running tests.
|
||||
timeout: Optional timeout in seconds.
|
||||
project_root: Project root directory.
|
||||
enable_coverage: Whether to collect coverage information.
|
||||
candidate_index: Index of the candidate being tested.
|
||||
|
||||
Returns:
|
||||
Tuple of (result_file_path, subprocess_result, coverage_path, config_path).
|
||||
|
||||
"""
|
||||
...
|
||||
|
||||
def run_benchmarking_tests(
|
||||
self,
|
||||
test_paths: Any,
|
||||
test_env: dict[str, str],
|
||||
cwd: Path,
|
||||
timeout: int | None = None,
|
||||
project_root: Path | None = None,
|
||||
min_loops: int = 5,
|
||||
max_loops: int = 100_000,
|
||||
target_duration_seconds: float = 10.0,
|
||||
) -> tuple[Path, Any]:
|
||||
"""Run benchmarking tests for this language.
|
||||
|
||||
Args:
|
||||
test_paths: TestFiles object containing test file information.
|
||||
test_env: Environment variables for the test run.
|
||||
cwd: Working directory for running tests.
|
||||
timeout: Optional timeout in seconds.
|
||||
project_root: Project root directory.
|
||||
min_loops: Minimum number of loops for benchmarking.
|
||||
max_loops: Maximum number of loops for benchmarking.
|
||||
target_duration_seconds: Target duration for benchmarking in seconds.
|
||||
|
||||
Returns:
|
||||
Tuple of (result_file_path, subprocess_result).
|
||||
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
def convert_parents_to_tuple(parents: list | tuple) -> tuple[ParentInfo, ...]:
|
||||
"""Convert a list of parent objects to a tuple of ParentInfo.
|
||||
|
||||
This helper handles conversion from the existing FunctionParent
|
||||
dataclass to the new ParentInfo dataclass.
|
||||
|
||||
Args:
|
||||
parents: List or tuple of parent objects with name and type attributes.
|
||||
|
||||
Returns:
|
||||
Tuple of ParentInfo objects.
|
||||
|
||||
"""
|
||||
return tuple(ParentInfo(name=p.name, type=p.type) for p in parents)
|
||||
118
codeflash/languages/current.py
Normal file
118
codeflash/languages/current.py
Normal file
|
|
@ -0,0 +1,118 @@
|
|||
"""Singleton for the current language being used in the codeflash session.
|
||||
|
||||
This module provides a centralized way to access and set the current language
|
||||
throughout the codeflash codebase, eliminating scattered language checks and
|
||||
string comparisons.
|
||||
|
||||
Usage:
|
||||
from codeflash.languages import current_language, set_current_language, is_python
|
||||
|
||||
# Set the language at the start of a session
|
||||
set_current_language(Language.PYTHON)
|
||||
# or
|
||||
set_current_language("javascript")
|
||||
|
||||
# Check the current language anywhere in the codebase
|
||||
if is_python():
|
||||
# Python-specific code
|
||||
...
|
||||
|
||||
# Get the current language
|
||||
lang = current_language()
|
||||
|
||||
# Get language support for the current language
|
||||
support = current_language_support()
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from codeflash.languages.base import Language
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from codeflash.languages.base import LanguageSupport
|
||||
|
||||
# Module-level singleton for the current language
|
||||
_current_language: Language | None = None
|
||||
|
||||
|
||||
def current_language() -> Language:
|
||||
"""Get the current language being used in this codeflash session.
|
||||
|
||||
Returns:
|
||||
The current Language enum value.
|
||||
|
||||
"""
|
||||
return _current_language
|
||||
|
||||
|
||||
def set_current_language(language: Language | str) -> None:
|
||||
"""Set the current language for this codeflash session.
|
||||
|
||||
This should be called once at the start of an optimization run,
|
||||
typically after reading the project configuration.
|
||||
|
||||
Args:
|
||||
language: Either a Language enum value or a string like "python", "javascript", "typescript".
|
||||
|
||||
"""
|
||||
global _current_language
|
||||
|
||||
if _current_language is not None:
|
||||
return
|
||||
_current_language = Language(language) if isinstance(language, str) else language
|
||||
|
||||
|
||||
def reset_current_language() -> None:
|
||||
"""Reset the current language to the default (Python).
|
||||
|
||||
Useful for testing or when starting a new session.
|
||||
"""
|
||||
global _current_language
|
||||
_current_language = Language.PYTHON
|
||||
|
||||
|
||||
def is_python() -> bool:
|
||||
"""Check if the current language is Python.
|
||||
|
||||
Returns:
|
||||
True if the current language is Python.
|
||||
|
||||
"""
|
||||
return _current_language == Language.PYTHON
|
||||
|
||||
|
||||
def is_javascript() -> bool:
|
||||
"""Check if the current language is JavaScript or TypeScript.
|
||||
|
||||
This returns True for both JavaScript and TypeScript since they are
|
||||
typically treated the same way in the optimization pipeline.
|
||||
|
||||
Returns:
|
||||
True if the current language is JavaScript or TypeScript.
|
||||
|
||||
"""
|
||||
return _current_language in (Language.JAVASCRIPT, Language.TYPESCRIPT)
|
||||
|
||||
|
||||
def is_typescript() -> bool:
|
||||
"""Check if the current language is TypeScript specifically.
|
||||
|
||||
Returns:
|
||||
True if the current language is TypeScript.
|
||||
|
||||
"""
|
||||
return _current_language == Language.TYPESCRIPT
|
||||
|
||||
|
||||
def current_language_support() -> LanguageSupport:
|
||||
"""Get the LanguageSupport instance for the current language.
|
||||
|
||||
Returns:
|
||||
The LanguageSupport instance for the current language.
|
||||
|
||||
"""
|
||||
from codeflash.languages.registry import get_language_support
|
||||
|
||||
return get_language_support(_current_language)
|
||||
5
codeflash/languages/javascript/__init__.py
Normal file
5
codeflash/languages/javascript/__init__.py
Normal file
|
|
@ -0,0 +1,5 @@
|
|||
"""JavaScript/TypeScript language support for codeflash."""
|
||||
|
||||
from codeflash.languages.javascript.support import JavaScriptSupport, TypeScriptSupport
|
||||
|
||||
__all__ = ["JavaScriptSupport", "TypeScriptSupport"]
|
||||
192
codeflash/languages/javascript/comparator.py
Normal file
192
codeflash/languages/javascript/comparator.py
Normal file
|
|
@ -0,0 +1,192 @@
|
|||
"""JavaScript test result comparison.
|
||||
|
||||
This module provides functionality to compare test results between
|
||||
original and optimized JavaScript code using a Node.js comparison script.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
|
||||
from codeflash.cli_cmds.console import logger
|
||||
from codeflash.models.models import TestDiff, TestDiffScope
|
||||
|
||||
|
||||
def _get_compare_results_script(project_root: Path | None = None) -> Path | None:
|
||||
"""Find the compare-results.js script from the installed codeflash npm package.
|
||||
|
||||
Args:
|
||||
project_root: Project root directory where node_modules is installed.
|
||||
|
||||
Returns:
|
||||
Path to compare-results.js if found, None otherwise.
|
||||
|
||||
"""
|
||||
search_dirs = []
|
||||
if project_root:
|
||||
search_dirs.append(project_root)
|
||||
search_dirs.append(Path.cwd())
|
||||
|
||||
for base_dir in search_dirs:
|
||||
script_path = base_dir / "node_modules" / "codeflash" / "runtime" / "compare-results.js"
|
||||
if script_path.exists():
|
||||
return script_path
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def compare_test_results(
|
||||
original_sqlite_path: Path,
|
||||
candidate_sqlite_path: Path,
|
||||
comparator_script: Path | None = None,
|
||||
project_root: Path | None = None,
|
||||
) -> tuple[bool, list[TestDiff]]:
|
||||
"""Compare JavaScript test results using the Node.js comparator.
|
||||
|
||||
This function calls a Node.js script that:
|
||||
1. Reads serialized behavior data from both SQLite databases
|
||||
2. Deserializes using the codeflash serializer module
|
||||
3. Compares using the codeflash comparator module (handles Map, Set, Date, etc. natively)
|
||||
4. Returns comparison results as JSON
|
||||
|
||||
Args:
|
||||
original_sqlite_path: Path to SQLite database with original code results.
|
||||
candidate_sqlite_path: Path to SQLite database with candidate code results.
|
||||
comparator_script: Optional path to the comparison script.
|
||||
project_root: Project root directory where node_modules is installed.
|
||||
|
||||
Returns:
|
||||
Tuple of (all_equivalent, list of TestDiff objects).
|
||||
|
||||
"""
|
||||
script_path = comparator_script or _get_compare_results_script(project_root)
|
||||
|
||||
if not script_path or not script_path.exists():
|
||||
logger.error(
|
||||
"JavaScript comparator script not found. "
|
||||
"Please ensure the 'codeflash' npm package is installed in your project."
|
||||
)
|
||||
return False, []
|
||||
|
||||
if not original_sqlite_path.exists():
|
||||
logger.error(f"Original SQLite database not found: {original_sqlite_path}")
|
||||
return False, []
|
||||
|
||||
if not candidate_sqlite_path.exists():
|
||||
logger.error(f"Candidate SQLite database not found: {candidate_sqlite_path}")
|
||||
return False, []
|
||||
|
||||
# Determine working directory - should be where node_modules is installed
|
||||
# The script needs better-sqlite3 which is installed in the project's node_modules
|
||||
cwd = project_root or Path.cwd()
|
||||
|
||||
# Set NODE_PATH to include the project's node_modules
|
||||
# This is needed because the script runs from the codeflash package directory,
|
||||
# but needs to resolve modules from the project's node_modules
|
||||
env = os.environ.copy()
|
||||
node_modules_path = cwd / "node_modules"
|
||||
if node_modules_path.exists():
|
||||
existing_node_path = env.get("NODE_PATH", "")
|
||||
if existing_node_path:
|
||||
env["NODE_PATH"] = f"{node_modules_path}:{existing_node_path}"
|
||||
else:
|
||||
env["NODE_PATH"] = str(node_modules_path)
|
||||
|
||||
try:
|
||||
result = subprocess.run(
|
||||
["node", str(script_path), str(original_sqlite_path), str(candidate_sqlite_path)],
|
||||
check=False,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=60,
|
||||
cwd=str(cwd),
|
||||
env=env,
|
||||
)
|
||||
|
||||
# Parse the JSON output first - errors are reported in JSON too
|
||||
try:
|
||||
if not result.stdout or not result.stdout.strip():
|
||||
logger.error("JavaScript comparator returned empty output")
|
||||
if result.stderr:
|
||||
logger.error(f"stderr: {result.stderr}")
|
||||
return False, []
|
||||
comparison = json.loads(result.stdout)
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error(f"Failed to parse JavaScript comparator output: {e}")
|
||||
logger.error(f"stdout: {result.stdout[:500] if result.stdout else '(empty)'}")
|
||||
if result.stderr:
|
||||
logger.error(f"stderr: {result.stderr[:500]}")
|
||||
return False, []
|
||||
|
||||
# Check for errors in the JSON response
|
||||
# Exit code 0 = equivalent, 1 = not equivalent, 2 = setup error
|
||||
if comparison.get("error"):
|
||||
logger.error(f"JavaScript comparator error: {comparison['error']}")
|
||||
return False, []
|
||||
|
||||
# Check for unexpected exit codes (not 0 or 1)
|
||||
if result.returncode not in {0, 1}:
|
||||
logger.error(f"JavaScript comparator failed with exit code {result.returncode}")
|
||||
if result.stderr:
|
||||
logger.error(f"stderr: {result.stderr}")
|
||||
return False, []
|
||||
|
||||
# Convert diffs to TestDiff objects
|
||||
test_diffs: list[TestDiff] = []
|
||||
for diff in comparison.get("diffs", []):
|
||||
scope_str = diff.get("scope", "return_value")
|
||||
scope = TestDiffScope.RETURN_VALUE
|
||||
if scope_str == "stdout":
|
||||
scope = TestDiffScope.STDOUT
|
||||
elif scope_str == "did_pass":
|
||||
scope = TestDiffScope.DID_PASS
|
||||
|
||||
test_info = diff.get("test_info", {})
|
||||
# Build a test identifier string for JavaScript tests
|
||||
test_function_name = test_info.get("test_function_name", "unknown")
|
||||
function_getting_tested = test_info.get("function_getting_tested", "unknown")
|
||||
test_src_code = f"// Test: {test_function_name}\n// Testing function: {function_getting_tested}"
|
||||
|
||||
test_diffs.append(
|
||||
TestDiff(
|
||||
scope=scope,
|
||||
original_value=diff.get("original"),
|
||||
candidate_value=diff.get("candidate"),
|
||||
test_src_code=test_src_code,
|
||||
candidate_pytest_error=diff.get("candidate_error"),
|
||||
original_pass=True, # Assume passed if we got results
|
||||
candidate_pass=diff.get("scope") != "missing",
|
||||
original_pytest_error=None,
|
||||
)
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
f"JavaScript test diff:\n"
|
||||
f" Test: {test_function_name}\n"
|
||||
f" Function: {function_getting_tested}\n"
|
||||
f" Scope: {scope_str}\n"
|
||||
f" Original: {str(diff.get('original', 'N/A'))[:100]}\n"
|
||||
f" Candidate: {str(diff.get('candidate', 'N/A'))[:100] if diff.get('candidate') else 'N/A'}"
|
||||
)
|
||||
|
||||
equivalent = comparison.get("equivalent", False)
|
||||
|
||||
logger.info(
|
||||
f"JavaScript comparison: {'equivalent' if equivalent else 'DIFFERENT'} "
|
||||
f"({comparison.get('total_invocations', 0)} invocations, {len(test_diffs)} diffs)"
|
||||
)
|
||||
|
||||
return equivalent, test_diffs
|
||||
|
||||
except subprocess.TimeoutExpired:
|
||||
logger.error("JavaScript comparator timed out")
|
||||
return False, []
|
||||
except FileNotFoundError:
|
||||
logger.error("Node.js not found. Please install Node.js to compare JavaScript test results.")
|
||||
return False, []
|
||||
except Exception as e:
|
||||
logger.error(f"Error running JavaScript comparator: {e}")
|
||||
return False, []
|
||||
230
codeflash/languages/javascript/edit_tests.py
Normal file
230
codeflash/languages/javascript/edit_tests.py
Normal file
|
|
@ -0,0 +1,230 @@
|
|||
"""JavaScript test editing utilities.
|
||||
|
||||
This module provides functionality for editing JavaScript/TypeScript test files,
|
||||
including adding runtime comments and removing test functions.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
|
||||
from codeflash.cli_cmds.console import logger
|
||||
from codeflash.code_utils.time_utils import format_perf, format_time
|
||||
from codeflash.result.critic import performance_gain
|
||||
|
||||
|
||||
def format_runtime_comment(original_time: int, optimized_time: int) -> str:
|
||||
"""Format a runtime comparison comment for JavaScript.
|
||||
|
||||
Args:
|
||||
original_time: Original runtime in nanoseconds.
|
||||
optimized_time: Optimized runtime in nanoseconds.
|
||||
|
||||
Returns:
|
||||
Formatted comment string with // prefix.
|
||||
|
||||
"""
|
||||
perf_gain = format_perf(
|
||||
abs(performance_gain(original_runtime_ns=original_time, optimized_runtime_ns=optimized_time) * 100)
|
||||
)
|
||||
status = "slower" if optimized_time > original_time else "faster"
|
||||
return f"// {format_time(original_time)} -> {format_time(optimized_time)} ({perf_gain}% {status})"
|
||||
|
||||
|
||||
def add_runtime_comments(source: str, original_runtimes: dict[str, int], optimized_runtimes: dict[str, int]) -> str:
|
||||
"""Add runtime comments to JavaScript test source code.
|
||||
|
||||
For JavaScript, we match timing data by test function name and add comments
|
||||
to expect() or function call lines.
|
||||
|
||||
Args:
|
||||
source: JavaScript test source code.
|
||||
original_runtimes: Map of invocation keys to original runtimes (ns).
|
||||
optimized_runtimes: Map of invocation keys to optimized runtimes (ns).
|
||||
|
||||
Returns:
|
||||
Source code with runtime comments added.
|
||||
|
||||
"""
|
||||
logger.debug(f"[js-annotations] original_runtimes has {len(original_runtimes)} entries")
|
||||
logger.debug(f"[js-annotations] optimized_runtimes has {len(optimized_runtimes)} entries")
|
||||
|
||||
if not original_runtimes or not optimized_runtimes:
|
||||
logger.debug("[js-annotations] No runtimes available, returning unchanged source")
|
||||
return source
|
||||
|
||||
lines = source.split("\n")
|
||||
modified_lines = []
|
||||
|
||||
# Build a lookup by FULL test name (including describe blocks) for suffix matching
|
||||
# The keys in original_runtimes look like: "full_test_name#/path/to/test#invocation_id"
|
||||
# where full_test_name includes describe blocks: "fibonacci Edge cases should return 0"
|
||||
timing_by_full_name: dict[str, tuple[int, int]] = {}
|
||||
for key in original_runtimes:
|
||||
if key in optimized_runtimes:
|
||||
# Extract test function name from the key (first part before #)
|
||||
parts = key.split("#")
|
||||
if parts:
|
||||
full_test_name = parts[0]
|
||||
logger.debug(f"[js-annotations] Found timing for full test name: '{full_test_name}'")
|
||||
if full_test_name not in timing_by_full_name:
|
||||
timing_by_full_name[full_test_name] = (original_runtimes[key], optimized_runtimes[key])
|
||||
else:
|
||||
# Sum up timings for same test
|
||||
old_orig, old_opt = timing_by_full_name[full_test_name]
|
||||
timing_by_full_name[full_test_name] = (
|
||||
old_orig + original_runtimes[key],
|
||||
old_opt + optimized_runtimes[key],
|
||||
)
|
||||
|
||||
logger.debug(f"[js-annotations] Built timing_by_full_name with {len(timing_by_full_name)} entries")
|
||||
|
||||
def find_matching_test(test_description: str) -> str | None:
|
||||
"""Find a timing key that ends with the given test description (suffix match).
|
||||
|
||||
Timing keys are like: "fibonacci Edge cases should return 0"
|
||||
Source test names are like: "should return 0"
|
||||
We need to match by suffix because timing includes all describe block names.
|
||||
"""
|
||||
# Try to match by finding a key that ends with the test description
|
||||
for full_name in timing_by_full_name:
|
||||
# Check if the full name ends with the test description (case-insensitive)
|
||||
if full_name.lower().endswith(test_description.lower()):
|
||||
logger.debug(f"[js-annotations] Suffix match: '{test_description}' matches '{full_name}'")
|
||||
return full_name
|
||||
return None
|
||||
|
||||
# Track current test context
|
||||
current_test_name = None
|
||||
current_matched_full_name = None
|
||||
test_pattern = re.compile(r"(?:test|it)\s*\(\s*['\"]([^'\"]+)['\"]")
|
||||
# Match function calls that look like: funcName(args) or expect(funcName(args))
|
||||
func_call_pattern = re.compile(r"(?:expect\s*\(\s*)?(\w+)\s*\([^)]*\)")
|
||||
|
||||
for line in lines:
|
||||
# Check if this line starts a new test
|
||||
test_match = test_pattern.search(line)
|
||||
if test_match:
|
||||
current_test_name = test_match.group(1)
|
||||
logger.debug(f"[js-annotations] Found test: '{current_test_name}'")
|
||||
# Find the matching full name from timing data using suffix match
|
||||
current_matched_full_name = find_matching_test(current_test_name)
|
||||
if current_matched_full_name:
|
||||
logger.debug(f"[js-annotations] Test '{current_test_name}' matched to '{current_matched_full_name}'")
|
||||
|
||||
# Check if this line has a function call and we have timing for current test
|
||||
if current_matched_full_name and current_matched_full_name in timing_by_full_name:
|
||||
# Only add comment if line has a function call and doesn't already have a comment
|
||||
if func_call_pattern.search(line) and "//" not in line and "expect(" in line:
|
||||
orig_time, opt_time = timing_by_full_name[current_matched_full_name]
|
||||
comment = format_runtime_comment(orig_time, opt_time)
|
||||
logger.debug(f"[js-annotations] Adding comment to test '{current_test_name}': {comment}")
|
||||
# Add comment at end of line
|
||||
line = f"{line.rstrip()} {comment}"
|
||||
# Clear timing so we only annotate first call in each test
|
||||
del timing_by_full_name[current_matched_full_name]
|
||||
current_matched_full_name = None
|
||||
|
||||
modified_lines.append(line)
|
||||
|
||||
return "\n".join(modified_lines)
|
||||
|
||||
|
||||
def remove_test_functions(source: str, functions_to_remove: list[str]) -> str:
    """Remove specific test functions from JavaScript test source code.

    Handles Jest test patterns: test(), it(), and describe() blocks.

    Args:
        source: JavaScript test source code.
        functions_to_remove: List of test function/describe names to remove.

    Returns:
        Source code with specified functions removed.

    """
    if not functions_to_remove:
        return source

    for name in functions_to_remove:
        # Locate the opening test('name', ...) / it('name', ...) call.
        # DOTALL lets the lazy body match span multiple lines.
        opener = re.compile(
            r"(?:test|it)\s*\(\s*['\"]" + re.escape(name) + r"['\"].*?\)\s*;?\s*\n?", re.DOTALL
        )

        hit = opener.search(source)
        if hit is None:
            # No test with this name — nothing to remove.
            continue

        # The regex only reaches the first ')'; expand to the full test
        # block by brace matching from the end of the regex match.
        block_start = hit.start()
        block_end = _find_block_end(source, hit.end() - 1)
        if block_end > block_start:
            source = source[:block_start] + source[block_end:]

    return source
|
||||
|
||||
|
||||
def _find_block_end(source: str, start: int) -> int:
|
||||
"""Find the end of a JavaScript block starting from a position.
|
||||
|
||||
Tracks brace matching to find where a function/block ends.
|
||||
|
||||
Args:
|
||||
source: Source code.
|
||||
start: Starting position (should be at or before opening brace).
|
||||
|
||||
Returns:
|
||||
Position after the closing brace, or start if not found.
|
||||
|
||||
"""
|
||||
# Find the opening brace
|
||||
brace_pos = source.find("{", start)
|
||||
if brace_pos == -1:
|
||||
# No block found, try to find end of arrow function or simple statement
|
||||
semicolon_pos = source.find(";", start)
|
||||
newline_pos = source.find("\n", start)
|
||||
if semicolon_pos != -1:
|
||||
return semicolon_pos + 1
|
||||
if newline_pos != -1:
|
||||
return newline_pos + 1
|
||||
return start
|
||||
|
||||
# Track brace depth
|
||||
depth = 0
|
||||
in_string = False
|
||||
string_char = None
|
||||
i = brace_pos
|
||||
|
||||
while i < len(source):
|
||||
char = source[i]
|
||||
|
||||
# Handle string literals
|
||||
if char in ('"', "'", "`") and (i == 0 or source[i - 1] != "\\"):
|
||||
if not in_string:
|
||||
in_string = True
|
||||
string_char = char
|
||||
elif char == string_char:
|
||||
in_string = False
|
||||
string_char = None
|
||||
elif not in_string:
|
||||
if char == "{":
|
||||
depth += 1
|
||||
elif char == "}":
|
||||
depth -= 1
|
||||
if depth == 0:
|
||||
# Found the matching closing brace
|
||||
# Skip any trailing semicolon or newline
|
||||
end = i + 1
|
||||
while end < len(source) and source[end] in " \t":
|
||||
end += 1
|
||||
if end < len(source) and source[end] == ";":
|
||||
end += 1
|
||||
while end < len(source) and source[end] in " \t\n":
|
||||
end += 1
|
||||
return end
|
||||
|
||||
i += 1
|
||||
|
||||
return start
|
||||
540
codeflash/languages/javascript/import_resolver.py
Normal file
540
codeflash/languages/javascript/import_resolver.py
Normal file
|
|
@ -0,0 +1,540 @@
|
|||
"""Import resolution for JavaScript/TypeScript.
|
||||
|
||||
This module provides utilities to resolve JavaScript/TypeScript import paths
|
||||
to actual file paths, enabling multi-file context extraction.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from codeflash.languages.base import FunctionInfo, HelperFunction
|
||||
from codeflash.languages.treesitter_utils import ImportInfo, TreeSitterAnalyzer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
class ResolvedImport:
    """Result of resolving an import to a file path."""

    file_path: Path  # Resolved absolute file path
    module_path: str  # Original import path (e.g., './utils')
    imported_names: list[str]  # Names imported (for named imports)
    is_default_import: bool  # Whether it's a default import
    is_namespace_import: bool  # Whether it's import * as X
    namespace_name: str | None  # The namespace alias (X in import * as X)
|
||||
|
||||
|
||||
class ImportResolver:
    """Resolves JavaScript/TypeScript import paths to file paths."""

    # Supported extensions in resolution order (prefer TS over JS)
    EXTENSIONS = (".ts", ".tsx", ".js", ".jsx", ".mjs", ".cjs")

    def __init__(self, project_root: Path) -> None:
        """Initialize the resolver.

        Args:
            project_root: Root directory of the project.

        """
        self.project_root = project_root
        # Memoizes (importing file, specifier) -> resolved path.
        # Failed resolutions are cached as None so they are not retried.
        self._resolution_cache: dict[tuple[Path, str], Path | None] = {}

    def resolve_import(self, import_info: ImportInfo, source_file: Path) -> ResolvedImport | None:
        """Resolve an import to its actual file path.

        Args:
            import_info: The import statement information.
            source_file: The file containing the import statement.

        Returns:
            ResolvedImport if resolution successful, None otherwise.

        """
        module_path = import_info.module_path

        # Skip external packages (node_modules)
        if self._is_external_package(module_path):
            logger.debug("Skipping external package: %s", module_path)
            return None

        # Check cache (None entries mean a previously failed resolution)
        cache_key = (source_file, module_path)
        if cache_key in self._resolution_cache:
            cached_path = self._resolution_cache[cache_key]
            if cached_path is None:
                return None
            return self._build_resolved_import(import_info, cached_path)

        # Resolve the path
        resolved_path = self._resolve_module_path(module_path, source_file.parent)

        # Cache the result (including failures)
        self._resolution_cache[cache_key] = resolved_path

        if resolved_path is None:
            logger.debug("Could not resolve import: %s from %s", module_path, source_file)
            return None

        return self._build_resolved_import(import_info, resolved_path)

    def _build_resolved_import(self, import_info: ImportInfo, resolved_path: Path) -> ResolvedImport:
        """Build a ResolvedImport from import info and resolved path."""
        imported_names = []

        # Collect named imports (the local alias wins over the exported name)
        for name, alias in import_info.named_imports:
            imported_names.append(alias if alias else name)

        # Add default import if present
        if import_info.default_import:
            imported_names.append(import_info.default_import)

        return ResolvedImport(
            file_path=resolved_path,
            module_path=import_info.module_path,
            imported_names=imported_names,
            is_default_import=import_info.default_import is not None,
            is_namespace_import=import_info.namespace_import is not None,
            namespace_name=import_info.namespace_import,
        )

    def _resolve_module_path(self, module_path: str, source_dir: Path) -> Path | None:
        """Resolve a module path to an absolute file path.

        Args:
            module_path: The import path (e.g., './utils', '../lib/helper').
            source_dir: Directory of the file containing the import.

        Returns:
            Resolved absolute path, or None if not found.

        """
        # Handle relative imports
        if module_path.startswith("."):
            return self._resolve_relative_import(module_path, source_dir)

        # Handle absolute imports (starting with /)
        if module_path.startswith("/"):
            return self._resolve_absolute_import(module_path)

        # Bare imports (e.g., 'lodash') are external packages
        return None

    def _resolve_relative_import(self, module_path: str, source_dir: Path) -> Path | None:
        """Resolve relative imports like ./utils or ../lib/helper.

        Args:
            module_path: The relative import path.
            source_dir: Directory to resolve from.

        Returns:
            Resolved absolute path, or None if not found.

        """
        # Compute base path
        base_path = (source_dir / module_path).resolve()

        # Check if path is within project
        try:
            base_path.relative_to(self.project_root)
        except ValueError:
            logger.debug("Import path outside project root: %s", base_path)
            return None

        # If the path already has an extension, try it directly first
        if base_path.suffix in self.EXTENSIONS:
            if base_path.exists() and base_path.is_file():
                return base_path
            # TypeScript allows importing .ts files with .js extension
            if base_path.suffix == ".js":
                ts_path = base_path.with_suffix(".ts")
                if ts_path.exists() and ts_path.is_file():
                    return ts_path

        # Try adding extensions
        resolved = self._try_extensions(base_path)
        if resolved:
            return resolved

        # Try as directory with index file
        resolved = self._try_index_file(base_path)
        if resolved:
            return resolved

        return None

    def _resolve_absolute_import(self, module_path: str) -> Path | None:
        """Resolve absolute imports starting with /.

        Args:
            module_path: The absolute import path.

        Returns:
            Resolved absolute path, or None if not found.

        """
        # Treat as relative to project root
        # NOTE(review): unlike the relative branch, there is no
        # relative_to(project_root) containment check here — TODO confirm
        # absolute specifiers cannot escape the project root.
        base_path = (self.project_root / module_path.lstrip("/")).resolve()

        # Try adding extensions
        resolved = self._try_extensions(base_path)
        if resolved:
            return resolved

        # Try as directory with index file
        resolved = self._try_index_file(base_path)
        if resolved:
            return resolved

        return None

    def _try_extensions(self, base_path: Path) -> Path | None:
        """Try adding various extensions to find the actual file.

        Args:
            base_path: The path without extension.

        Returns:
            Path if file found with an extension, None otherwise.

        """
        # If base_path already exists as file
        if base_path.exists() and base_path.is_file():
            return base_path

        # Try each extension in order
        # NOTE(review): with_suffix() REPLACES an existing suffix, so for a
        # specifier like './data.json' this tries 'data.ts' before the
        # double-extension loop below tries 'data.json.ts' — confirm that
        # priority is intended.
        for ext in self.EXTENSIONS:
            path_with_ext = base_path.with_suffix(ext)
            if path_with_ext.exists() and path_with_ext.is_file():
                return path_with_ext

        # Also try adding extension to paths that already have one
        # (e.g., './utils.js' might need to become './utils.js.ts' in some setups)
        # This is rare but some bundlers support it
        if base_path.suffix:
            for ext in self.EXTENSIONS:
                path_with_double_ext = Path(str(base_path) + ext)
                if path_with_double_ext.exists() and path_with_double_ext.is_file():
                    return path_with_double_ext

        return None

    def _try_index_file(self, dir_path: Path) -> Path | None:
        """Try resolving to index file in a directory.

        Args:
            dir_path: The directory path to check.

        Returns:
            Path to index file if found, None otherwise.

        """
        if not dir_path.exists() or not dir_path.is_dir():
            return None

        # Try index files with each extension (same TS-first order)
        for ext in self.EXTENSIONS:
            index_path = dir_path / f"index{ext}"
            if index_path.exists() and index_path.is_file():
                return index_path

        return None

    def _is_external_package(self, module_path: str) -> bool:
        """Check if import refers to an external package (node_modules).

        Args:
            module_path: The import module path.

        Returns:
            True if this is an external package import.

        """
        # Relative imports are not external
        if module_path.startswith("."):
            return False

        # Absolute imports (starting with /) are project-internal
        if module_path.startswith("/"):
            return False

        # Bare imports without ./ or ../ are external packages
        # This includes:
        # - 'lodash'
        # - '@company/utils'
        # - 'react'
        # - 'fs' (Node.js built-ins)
        return True
|
||||
|
||||
|
||||
@dataclass
class HelperSearchContext:
    """Context for recursive helper search."""

    # Files already processed; guards against revisiting (e.g. import cycles).
    visited_files: set[Path] = field(default_factory=set)
    # (file, function-name) pairs already extracted, to avoid duplicates.
    visited_functions: set[tuple[Path, str]] = field(default_factory=set)
    # Current recursion depth (0 = the target function itself).
    current_depth: int = 0
    # Maximum recursion depth for helpers-of-helpers.
    max_depth: int = 2
|
||||
|
||||
|
||||
class MultiFileHelperFinder:
    """Finds helper functions across multiple files."""

    DEFAULT_MAX_DEPTH = 2  # Target → helpers → helpers of helpers

    def __init__(self, project_root: Path, import_resolver: ImportResolver) -> None:
        """Initialize the finder.

        Args:
            project_root: Root directory of the project.
            import_resolver: ImportResolver instance for resolving imports.

        """
        self.project_root = project_root
        self.import_resolver = import_resolver

    def find_helpers(
        self,
        function: FunctionInfo,
        source: str,
        analyzer: TreeSitterAnalyzer,
        imports: list[ImportInfo],
        max_depth: int = DEFAULT_MAX_DEPTH,
    ) -> dict[Path, list[HelperFunction]]:
        """Find all helper functions including cross-file dependencies.

        Args:
            function: The target function to find helpers for.
            source: Source code of the file containing the function.
            analyzer: TreeSitterAnalyzer for parsing.
            imports: List of imports in the source file.
            max_depth: Maximum recursion depth for finding helpers of helpers.

        Returns:
            Dictionary mapping file paths to lists of helper functions.

        """
        context = HelperSearchContext(max_depth=max_depth)
        context.visited_files.add(function.file_path)

        # Find all function calls within the target function.
        # Re-locate the target by name AND start line to disambiguate
        # same-named functions in one file.
        all_functions = analyzer.find_functions(source, include_methods=True)
        target_func = None
        for func in all_functions:
            if func.name == function.name and func.start_line == function.start_line:
                target_func = func
                break

        if not target_func:
            return {}

        calls = analyzer.find_function_calls(source, target_func)

        # Match calls to imports
        call_to_import = self._match_calls_to_imports(calls, imports)

        # Find helpers from imported modules
        results: dict[Path, list[HelperFunction]] = {}

        for import_info, actual_name in call_to_import.values():
            # Resolve the import to a file path
            resolved = self.import_resolver.resolve_import(import_info, function.file_path)
            if resolved is None:
                continue

            # Skip if already visited
            key = (resolved.file_path, actual_name)
            if key in context.visited_functions:
                continue
            context.visited_functions.add(key)

            # Extract the helper function from the resolved file
            helper = self._extract_helper_from_file(resolved.file_path, actual_name, analyzer)
            if helper:
                if resolved.file_path not in results:
                    results[resolved.file_path] = []
                results[resolved.file_path].append(helper)

                # Recursively find helpers of this helper (if depth allows).
                # A copied context is passed so sibling branches do not
                # block each other's traversal.
                if context.current_depth < context.max_depth:
                    nested_results = self._find_helpers_recursive(
                        resolved.file_path,
                        helper,
                        HelperSearchContext(
                            visited_files=context.visited_files.copy(),
                            visited_functions=context.visited_functions.copy(),
                            current_depth=context.current_depth + 1,
                            max_depth=context.max_depth,
                        ),
                    )
                    # Merge nested results
                    for path, helpers in nested_results.items():
                        if path not in results:
                            results[path] = []
                        results[path].extend(helpers)

        return results

    def _match_calls_to_imports(self, calls: set[str], imports: list[ImportInfo]) -> dict[str, tuple[ImportInfo, str]]:
        """Match function calls to their import sources.

        Args:
            calls: Set of function call names found in the code.
            imports: List of import statements.

        Returns:
            Dictionary mapping call names to (ImportInfo, actual_function_name) tuples.

        """
        matches: dict[str, tuple[ImportInfo, str]] = {}

        for call in calls:
            # Check for namespace calls (e.g., utils.helper)
            if "." in call:
                namespace, func_name = call.split(".", 1)
                for imp in imports:
                    if imp.namespace_import == namespace:
                        matches[call] = (imp, func_name)
                        break
            else:
                # Check for direct imports
                for imp in imports:
                    # Check default import
                    if imp.default_import == call:
                        matches[call] = (imp, "default")
                        break

                    # Check named imports
                    # NOTE(review): this break only exits the inner
                    # named-imports loop; a later import that also matches
                    # can overwrite the entry — confirm intended.
                    for name, alias in imp.named_imports:
                        if (alias and alias == call) or (not alias and name == call):
                            matches[call] = (imp, name)
                            break

        return matches

    def _extract_helper_from_file(
        self, file_path: Path, function_name: str, analyzer: TreeSitterAnalyzer
    ) -> HelperFunction | None:
        """Extract a helper function from a resolved file.

        Args:
            file_path: Path to the file containing the function.
            function_name: Name of the function to extract.
            analyzer: TreeSitterAnalyzer for parsing.

        Returns:
            HelperFunction if found, None otherwise.

        """
        # NOTE(review): the `analyzer` parameter is unused here — a
        # file-type-specific analyzer is created below instead. Confirm the
        # parameter can be dropped or should be used as a fallback.
        from codeflash.languages.base import HelperFunction
        from codeflash.languages.treesitter_utils import get_analyzer_for_file

        try:
            source = file_path.read_text(encoding="utf-8")
        except Exception as e:
            logger.warning("Failed to read %s: %s", file_path, e)
            return None

        # Get analyzer for this file type
        file_analyzer = get_analyzer_for_file(file_path)

        # Split source into lines for JSDoc extraction
        lines = source.splitlines(keepends=True)

        # Handle "default" export - look for default exported function
        if function_name == "default":
            # Find the default export
            functions = file_analyzer.find_functions(source, include_methods=True)
            # For now, return first function if looking for default
            # TODO: Implement proper default export detection
            for func in functions:
                # Extract source including JSDoc if present
                effective_start = func.doc_start_line or func.start_line
                helper_lines = lines[effective_start - 1 : func.end_line]
                helper_source = "".join(helper_lines)

                return HelperFunction(
                    name=func.name,
                    qualified_name=func.name,
                    file_path=file_path,
                    source_code=helper_source,
                    start_line=effective_start,
                    end_line=func.end_line,
                )
            return None

        # Find the function by name
        functions = file_analyzer.find_functions(source, include_methods=True)
        for func in functions:
            if func.name == function_name:
                # Extract source including JSDoc if present
                effective_start = func.doc_start_line or func.start_line
                helper_lines = lines[effective_start - 1 : func.end_line]
                helper_source = "".join(helper_lines)

                return HelperFunction(
                    name=func.name,
                    qualified_name=func.name,
                    file_path=file_path,
                    source_code=helper_source,
                    start_line=effective_start,
                    end_line=func.end_line,
                )

        logger.debug("Function %s not found in %s", function_name, file_path)
        return None

    def _find_helpers_recursive(
        self, file_path: Path, helper: HelperFunction, context: HelperSearchContext
    ) -> dict[Path, list[HelperFunction]]:
        """Recursively find helpers of a helper function.

        Args:
            file_path: Path to the file containing the helper.
            helper: The helper function to analyze.
            context: Search context with visited tracking and depth limit.

        Returns:
            Dictionary mapping file paths to lists of helper functions.

        """
        from codeflash.languages.base import FunctionInfo
        from codeflash.languages.treesitter_utils import get_analyzer_for_file

        if context.current_depth >= context.max_depth:
            return {}

        if file_path in context.visited_files:
            return {}
        context.visited_files.add(file_path)

        try:
            source = file_path.read_text(encoding="utf-8")
        except Exception as e:
            logger.warning("Failed to read %s: %s", file_path, e)
            return {}

        # Get analyzer and imports for this file
        analyzer = get_analyzer_for_file(file_path)
        imports = analyzer.find_imports(source)

        # Create FunctionInfo for the helper
        func_info = FunctionInfo(
            name=helper.name, file_path=file_path, start_line=helper.start_line, end_line=helper.end_line, parents=()
        )

        # Recursively find helpers; the remaining depth budget is passed
        # down as the new max_depth.
        return self.find_helpers(
            function=func_info,
            source=source,
            analyzer=analyzer,
            imports=imports,
            max_depth=context.max_depth - context.current_depth,
        )
|
||||
974
codeflash/languages/javascript/instrument.py
Normal file
974
codeflash/languages/javascript/instrument.py
Normal file
|
|
@ -0,0 +1,974 @@
|
|||
"""JavaScript test instrumentation for existing tests.
|
||||
|
||||
This module provides functionality to inject profiling code into existing JavaScript
|
||||
test files, similar to Python's inject_profiling_into_existing_test.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from codeflash.cli_cmds.console import logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from codeflash.code_utils.code_position import CodePosition
|
||||
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
|
||||
|
||||
|
||||
class TestingMode:
    """Testing mode constants.

    Mode labels consumed by the instrumentation in this module (behavior
    capture vs. performance capture).
    """

    BEHAVIOR = "behavior"
    PERFORMANCE = "performance"
|
||||
|
||||
|
||||
@dataclass
class ExpectCallMatch:
    """Represents a matched expect(func(...)).toXXX() call."""

    # Character offset in the source where the match begins.
    start_pos: int
    # Character offset just past the end of the match.
    end_pos: int
    # Indentation preceding the expect() call, preserved on rewrite.
    leading_whitespace: str
    # Raw argument text passed to the target function.
    func_args: str
    # The .toXXX(...) assertion chain following expect().
    assertion_chain: str
    # Whether the original statement ended with a semicolon.
    has_trailing_semicolon: bool
    object_prefix: str = ""  # Object prefix like "calc." or "this." or ""
|
||||
|
||||
|
||||
@dataclass
class StandaloneCallMatch:
    """Represents a matched standalone func(...) call."""

    # Character offset in the source where the match begins.
    start_pos: int
    # Character offset just past the end of the match.
    end_pos: int
    # Indentation preceding the call, preserved on rewrite.
    leading_whitespace: str
    # Raw argument text passed to the target function.
    func_args: str
    prefix: str  # "await " or ""
    object_prefix: str  # Object prefix like "calc." or "this." or ""
    # Whether the original statement ended with a semicolon.
    has_trailing_semicolon: bool
|
||||
|
||||
|
||||
class StandaloneCallTransformer:
|
||||
"""Transforms standalone func(...) calls in JavaScript test code.
|
||||
|
||||
This class handles the transformation of standalone function calls that are NOT
|
||||
inside expect() wrappers. These calls need to be wrapped with codeflash.capture()
|
||||
or codeflash.capturePerf() for instrumentation.
|
||||
|
||||
Examples:
|
||||
- await func(args) -> await codeflash.capturePerf('name', 'id', func, args)
|
||||
- func(args) -> codeflash.capturePerf('name', 'id', func, args)
|
||||
- const result = func(args) -> const result = codeflash.capturePerf(...)
|
||||
- arr.map(() => func(args)) -> arr.map(() => codeflash.capturePerf(..., func, args))
|
||||
- calc.fibonacci(n) -> codeflash.capturePerf('...', 'id', calc.fibonacci.bind(calc), n)
|
||||
|
||||
"""
|
||||
|
||||
    def __init__(self, func_name: str, qualified_name: str, capture_func: str) -> None:
        """Initialize the transformer.

        Args:
            func_name: Bare name of the function whose calls are rewritten.
            qualified_name: Qualified name (may contain a '.' for methods).
            capture_func: Codeflash capture API to wrap calls with
                (e.g. 'capture' or 'capturePerf').
        """
        self.func_name = func_name
        self.qualified_name = qualified_name
        self.capture_func = capture_func
        # Counts rewritten call sites; used to build per-invocation ids.
        self.invocation_counter = 0
        # Pattern to match func_name( with optional leading await and optional object prefix
        # Captures: (whitespace)(await )?(object.)*func_name(
        # We'll filter out expect() and codeflash. cases in the transform loop
        self._call_pattern = re.compile(rf"(\s*)(await\s+)?((?:\w+\.)*){re.escape(func_name)}\s*\(")
|
||||
|
||||
    def transform(self, code: str) -> str:
        """Transform all standalone calls in the code.

        Scans ``code`` left to right for occurrences of the target function
        call, skips matches that are inside expect() wrappers, already
        transformed, or definitions rather than calls, and rewrites the
        remaining ones via _generate_transformed_call.

        Args:
            code: JavaScript test source code.

        Returns:
            The source with standalone calls wrapped for instrumentation.
        """
        result: list[str] = []
        pos = 0  # scan cursor into `code`

        while pos < len(code):
            match = self._call_pattern.search(code, pos)
            if not match:
                # No further candidates — keep the rest verbatim.
                result.append(code[pos:])
                break

            match_start = match.start()

            # Check if this call is inside an expect() or already transformed
            if self._should_skip_match(code, match_start, match):
                result.append(code[pos : match.end()])
                pos = match.end()
                continue

            # Add everything before the match
            result.append(code[pos:match_start])

            # Try to parse the full standalone call
            standalone_match = self._parse_standalone_call(code, match)
            if standalone_match is None:
                # Couldn't parse, skip this match
                result.append(code[match_start : match.end()])
                pos = match.end()
                continue

            # Generate the transformed code
            self.invocation_counter += 1
            transformed = self._generate_transformed_call(standalone_match)
            result.append(transformed)
            pos = standalone_match.end_pos

        return "".join(result)
|
||||
|
||||
    def _should_skip_match(self, code: str, start: int, match: re.Match) -> bool:
        """Check if the match should be skipped (inside expect, already transformed, etc.).

        Uses fixed-size lookback windows (200/100/80/60 chars) as a
        heuristic — NOTE(review): a definition or expect() opening farther
        back than the window will not be detected; confirm the window sizes
        are sufficient for the targeted test files.

        Args:
            code: Full source being transformed.
            start: Start offset of the candidate match in ``code``.
            match: The regex match for the candidate call.

        Returns:
            True if this occurrence must NOT be rewritten.
        """
        # Look backwards to check context
        lookback_start = max(0, start - 200)
        lookback = code[lookback_start:start]

        # Skip if already transformed with codeflash.capture
        if f"codeflash.{self.capture_func}(" in lookback[-60:]:
            return True

        # Skip if this is a function/method definition, not a call
        # Patterns to skip:
        # - ClassName.prototype.funcName = function(
        # - funcName = function(
        # - funcName: function(
        # - function funcName(
        # - funcName() { (method definition in class)
        near_context = lookback[-80:] if len(lookback) >= 80 else lookback

        # Skip prototype assignment: ClassName.prototype.funcName = function(
        if re.search(r"\.prototype\.\w+\s*=\s*function\s*$", near_context):
            return True

        # Skip function assignment: funcName = function(
        if re.search(rf"{re.escape(self.func_name)}\s*=\s*function\s*$", near_context):
            return True

        # Skip function declaration: function funcName(
        if re.search(rf"function\s+{re.escape(self.func_name)}\s*$", near_context):
            return True

        # Skip method definition in class body: funcName(params) { or async funcName(params) {
        # Check by looking at what comes after the closing paren
        # The match ends at the opening paren, so find the closing paren and check what follows
        close_paren_pos = self._find_matching_paren(code, match.end() - 1)
        if close_paren_pos != -1:
            # Check if followed by { (method definition) after optional whitespace
            after_close = code[close_paren_pos : close_paren_pos + 20].lstrip()
            if after_close.startswith("{"):
                # This is a method definition like "fibonacci(n) {"
                # But we still want to capture certain patterns like arrow functions
                # Check if there's no => before the {
                between = code[close_paren_pos : close_paren_pos + 20].strip()
                if not between.startswith("=>"):
                    return True

        # Skip if inside expect() - look for 'expect(' with unmatched parens
        # Find the last 'expect(' and check if it's still open
        expect_search_start = max(0, start - 100)
        expect_lookback = code[expect_search_start:start]

        # Find all expect( positions
        expect_pos = expect_lookback.rfind("expect(")
        if expect_pos != -1:
            # Count parens from expect( to our match position
            between = expect_lookback[expect_pos:]
            open_parens = between.count("(") - between.count(")")
            if open_parens > 0:
                # We're inside an unclosed expect()
                return True

        return False
|
||||
|
||||
def _find_matching_paren(self, code: str, open_paren_pos: int) -> int:
|
||||
"""Find the position of the closing paren for the given opening paren."""
|
||||
if open_paren_pos >= len(code) or code[open_paren_pos] != "(":
|
||||
return -1
|
||||
|
||||
depth = 1
|
||||
pos = open_paren_pos + 1
|
||||
|
||||
while pos < len(code) and depth > 0:
|
||||
if code[pos] == "(":
|
||||
depth += 1
|
||||
elif code[pos] == ")":
|
||||
depth -= 1
|
||||
pos += 1
|
||||
|
||||
return pos if depth == 0 else -1
|
||||
|
||||
def _parse_standalone_call(self, code: str, match: re.Match) -> StandaloneCallMatch | None:
    """Parse one standalone func(...) call anchored at a regex match.

    Returns None when the call cannot be parsed (unbalanced parens) or when
    a bare function name would otherwise match a method call.
    """
    whitespace = match.group(1)
    await_prefix = match.group(2) or ""  # "await " or ""
    receiver = match.group(3) or ""  # object prefix like "calc." or ""

    # A dotless qualified_name names a free function; a receiver prefix
    # would mean we matched obj.func(), which must be left alone.
    if receiver and "." not in self.qualified_name:
        return None

    # The regex match ends at the opening paren; locate it precisely.
    open_paren = match.start() + match.group(0).rfind("(")

    # Extract the argument text between balanced parentheses.
    args, after_close = self._find_balanced_parens(code, open_paren)
    if args is None:
        return None

    # Consume trailing spaces/tabs and an optional semicolon.
    cursor = after_close
    while cursor < len(code) and code[cursor] in " \t":
        cursor += 1
    ends_with_semicolon = cursor < len(code) and code[cursor] == ";"
    if ends_with_semicolon:
        cursor += 1

    return StandaloneCallMatch(
        start_pos=match.start(),
        end_pos=cursor,
        leading_whitespace=whitespace,
        func_args=args,
        prefix=await_prefix,
        object_prefix=receiver,
        has_trailing_semicolon=ends_with_semicolon,
    )
def _find_balanced_parens(self, code: str, open_paren_pos: int) -> tuple[str | None, int]:
|
||||
"""Find content within balanced parentheses."""
|
||||
if open_paren_pos >= len(code) or code[open_paren_pos] != "(":
|
||||
return None, -1
|
||||
|
||||
depth = 1
|
||||
pos = open_paren_pos + 1
|
||||
in_string = False
|
||||
string_char = None
|
||||
|
||||
while pos < len(code) and depth > 0:
|
||||
char = code[pos]
|
||||
|
||||
# Handle string literals
|
||||
if char in "\"'`" and (pos == 0 or code[pos - 1] != "\\"):
|
||||
if not in_string:
|
||||
in_string = True
|
||||
string_char = char
|
||||
elif char == string_char:
|
||||
in_string = False
|
||||
string_char = None
|
||||
elif not in_string:
|
||||
if char == "(":
|
||||
depth += 1
|
||||
elif char == ")":
|
||||
depth -= 1
|
||||
|
||||
pos += 1
|
||||
|
||||
if depth != 0:
|
||||
return None, -1
|
||||
|
||||
return code[open_paren_pos + 1 : pos - 1], pos
|
||||
|
||||
def _generate_transformed_call(self, match: StandaloneCallMatch) -> str:
|
||||
"""Generate the transformed code for a standalone call."""
|
||||
line_id = str(self.invocation_counter)
|
||||
args_str = match.func_args.strip()
|
||||
semicolon = ";" if match.has_trailing_semicolon else ""
|
||||
|
||||
# Handle method calls on objects (e.g., calc.fibonacci, this.method)
|
||||
if match.object_prefix:
|
||||
# Remove trailing dot from object prefix for the bind call
|
||||
obj = match.object_prefix.rstrip(".")
|
||||
full_method = f"{obj}.{self.func_name}"
|
||||
|
||||
if args_str:
|
||||
return (
|
||||
f"{match.leading_whitespace}{match.prefix}codeflash.{self.capture_func}('{self.qualified_name}', "
|
||||
f"'{line_id}', {full_method}.bind({obj}), {args_str}){semicolon}"
|
||||
)
|
||||
return (
|
||||
f"{match.leading_whitespace}{match.prefix}codeflash.{self.capture_func}('{self.qualified_name}', "
|
||||
f"'{line_id}', {full_method}.bind({obj})){semicolon}"
|
||||
)
|
||||
|
||||
# Handle standalone function calls
|
||||
if args_str:
|
||||
return (
|
||||
f"{match.leading_whitespace}{match.prefix}codeflash.{self.capture_func}('{self.qualified_name}', "
|
||||
f"'{line_id}', {self.func_name}, {args_str}){semicolon}"
|
||||
)
|
||||
return (
|
||||
f"{match.leading_whitespace}{match.prefix}codeflash.{self.capture_func}('{self.qualified_name}', "
|
||||
f"'{line_id}', {self.func_name}){semicolon}"
|
||||
)
|
||||
|
||||
|
||||
def transform_standalone_calls(
    code: str, func_name: str, qualified_name: str, capture_func: str, start_counter: int = 0
) -> tuple[str, int]:
    """Transform standalone func(...) calls in JavaScript test code.

    Rewrites bare invocations of the target function into codeflash capture
    calls; invocations already wrapped in expect() are handled elsewhere.

    Args:
        code: The test code to transform.
        func_name: Name of the function being tested.
        qualified_name: Fully qualified function name.
        capture_func: The capture function to use ('capture' or 'capturePerf').
        start_counter: Initial invocation counter, letting numbering continue
            from a previous transformation pass.

    Returns:
        Tuple of (transformed code, final counter value).

    """
    rewriter = StandaloneCallTransformer(
        func_name=func_name, qualified_name=qualified_name, capture_func=capture_func
    )
    rewriter.invocation_counter = start_counter
    transformed = rewriter.transform(code)
    return transformed, rewriter.invocation_counter
class ExpectCallTransformer:
    """Transforms expect(func(...)).assertion() calls in JavaScript test code.

    This class handles the parsing and transformation of Jest/Vitest expect calls,
    supporting various assertion patterns including:
    - Basic: expect(func(args)).toBe(value)
    - Negated: expect(func(args)).not.toBe(value)
    - Async: expect(func(args)).resolves.toBe(value)
    - Chained: expect(func(args)).not.resolves.toBe(value)
    - No-arg assertions: expect(func(args)).toBeTruthy()
    - Multi-arg assertions: expect(func(args)).toBeCloseTo(0.5, 2)
    """

    def __init__(self, func_name: str, qualified_name: str, capture_func: str, remove_assertions: bool = False) -> None:
        """Configure a transformer for one target function.

        Args:
            func_name: Bare name of the function under test.
            qualified_name: Fully qualified name (contains dots for methods).
            capture_func: 'capture' (behavior) or 'capturePerf' (performance).
            remove_assertions: When True, strip the expect wrapper entirely
                (used for generated/regression tests).

        """
        self.func_name = func_name
        self.qualified_name = qualified_name
        self.capture_func = capture_func
        self.remove_assertions = remove_assertions
        # Monotonic id embedded in each capture call as its line id.
        self.invocation_counter = 0
        # Pattern to match start of expect((object.)*func_name(
        # Captures: (1) leading whitespace, (2) object prefix like "calc.".
        self._expect_pattern = re.compile(rf"(\s*)expect\s*\(\s*((?:\w+\.)*){re.escape(func_name)}\s*\(")

    def transform(self, code: str) -> str:
        """Transform all expect calls in the code."""
        result: list[str] = []
        pos = 0

        while pos < len(code):
            match = self._expect_pattern.search(code, pos)
            if not match:
                result.append(code[pos:])
                break

            # Copy the untouched text before this candidate.
            result.append(code[pos : match.start()])

            expect_match = self._parse_expect_call(code, match)
            if expect_match is None:
                # Not a parseable expect call; keep the matched text as-is.
                result.append(code[match.start() : match.end()])
                pos = match.end()
                continue

            self.invocation_counter += 1
            result.append(self._generate_transformed_call(expect_match))
            pos = expect_match.end_pos

        return "".join(result)

    def _parse_expect_call(self, code: str, match: re.Match) -> ExpectCallMatch | None:
        """Parse a complete expect(func(...)).assertion() call.

        Returns None if the pattern doesn't match the expected structure.
        """
        leading_ws = match.group(1)
        object_prefix = match.group(2) or ""  # Object prefix like "calc." or ""

        # A dotless qualified_name denotes a free function: refuse to match
        # method calls such as obj.func() in that case.
        if "." not in self.qualified_name and object_prefix:
            return None

        # Extract the function-call arguments (handles nested parens/strings).
        args_start = match.end()
        func_args, func_close_pos = self._find_balanced_parens(code, args_start - 1)
        if func_args is None:
            return None

        # Skip whitespace, then require the closing ")" of expect(.
        expect_close_pos = func_close_pos
        while expect_close_pos < len(code) and code[expect_close_pos].isspace():
            expect_close_pos += 1
        if expect_close_pos >= len(code) or code[expect_close_pos] != ")":
            return None
        expect_close_pos += 1  # Move past )

        # Parse the assertion chain (e.g., .not.resolves.toBe(value)).
        assertion_chain, chain_end_pos = self._parse_assertion_chain(code, expect_close_pos)
        if assertion_chain is None:
            return None

        # Absorb a trailing semicolon into the matched span.
        has_trailing_semicolon = chain_end_pos < len(code) and code[chain_end_pos] == ";"
        if has_trailing_semicolon:
            chain_end_pos += 1

        return ExpectCallMatch(
            start_pos=match.start(),
            end_pos=chain_end_pos,
            leading_whitespace=leading_ws,
            func_args=func_args,
            assertion_chain=assertion_chain,
            has_trailing_semicolon=has_trailing_semicolon,
            object_prefix=object_prefix,
        )

    def _find_balanced_parens(self, code: str, open_paren_pos: int) -> tuple[str | None, int]:
        """Find content within balanced parentheses.

        Parentheses inside string literals ("...", '...', `...`) do not
        affect the nesting depth.

        Args:
            code: The source code
            open_paren_pos: Position of the opening parenthesis

        Returns:
            Tuple of (content inside parens, position after closing paren) or (None, -1)

        """
        if open_paren_pos >= len(code) or code[open_paren_pos] != "(":
            return None, -1

        depth = 1
        pos = open_paren_pos + 1
        in_string = False
        string_char = None

        while pos < len(code) and depth > 0:
            char = code[pos]

            if char in "\"'`":
                # A quote is escaped only when preceded by an ODD number of
                # backslashes; checking a single preceding character would
                # misread a literal backslash before the quote ("\\" + quote)
                # as an escape and never leave the string.
                backslash_count = 0
                scan = pos - 1
                while scan >= 0 and code[scan] == "\\":
                    backslash_count += 1
                    scan -= 1
                if backslash_count % 2 == 0:
                    if not in_string:
                        in_string = True
                        string_char = char
                    elif char == string_char:
                        in_string = False
                        string_char = None
            elif not in_string:
                if char == "(":
                    depth += 1
                elif char == ")":
                    depth -= 1

            pos += 1

        if depth != 0:
            return None, -1

        # Return content (excluding parens) and position after closing paren.
        return code[open_paren_pos + 1 : pos - 1], pos

    def _parse_assertion_chain(self, code: str, start_pos: int) -> tuple[str | None, int]:
        """Parse assertion chain like .not.resolves.toBe(value).

        Handles:
        - .toBe(value)
        - .not.toBe(value)
        - .resolves.toBe(value)
        - .rejects.toThrow()
        - .not.resolves.toBe(value)
        - .toBeTruthy() (no args)
        - .toBeCloseTo(0.5, 2) (multiple args with nested parens)

        Returns:
            Tuple of (assertion chain string, end position) or (None, -1)

        """
        pos = start_pos
        chain_parts: list[str] = []

        # Skip any leading whitespace (for multi-line chains).
        while pos < len(code) and code[pos] in " \t\n\r":
            pos += 1

        # The chain must start with a dot.
        if pos >= len(code) or code[pos] != ".":
            return None, -1

        while pos < len(code):
            # Skip whitespace between chain elements.
            while pos < len(code) and code[pos] in " \t\n\r":
                pos += 1

            if pos >= len(code) or code[pos] != ".":
                break

            pos += 1  # Skip the dot

            # Skip whitespace after the dot.
            while pos < len(code) and code[pos] in " \t\n\r":
                pos += 1

            # Parse the method name.
            method_start = pos
            while pos < len(code) and (code[pos].isalnum() or code[pos] == "_"):
                pos += 1
            if pos == method_start:
                return None, -1
            method_name = code[method_start:pos]

            # Skip whitespace before a potential argument list.
            while pos < len(code) and code[pos] in " \t\n\r":
                pos += 1

            # Check for parentheses (method call).
            if pos < len(code) and code[pos] == "(":
                args_content, after_paren = self._find_balanced_parens(code, pos)
                if args_content is None:
                    return None, -1
                chain_parts.append(f".{method_name}({args_content})")
                pos = after_paren
            else:
                # Modifier without parens (.not, .resolves, .rejects) or an
                # assertion written without arguments.
                chain_parts.append(f".{method_name}")

            # A terminal assertion starts with 'to'; stop after it.
            if method_name.startswith("to"):
                break

        if not chain_parts:
            return None, -1

        # Require a terminal assertion (chain must end with .toXXX).
        if not chain_parts[-1].startswith(".to"):
            return None, -1

        return "".join(chain_parts), pos

    def _generate_transformed_call(self, match: ExpectCallMatch) -> str:
        """Generate the transformed code for an expect call."""
        line_id = str(self.invocation_counter)
        args_str = match.func_args.strip()

        # Determine the function reference to use.
        if match.object_prefix:
            # Method call on object: calc.fibonacci -> calc.fibonacci.bind(calc)
            # so `this` is preserved when codeflash invokes it later.
            obj = match.object_prefix.rstrip(".")
            func_ref = f"{obj}.{self.func_name}.bind({obj})"
        else:
            func_ref = self.func_name

        if self.remove_assertions:
            # Generated/regression tests: drop the expect wrapper and
            # assertion; correctness is judged by output comparison instead.
            if args_str:
                return (
                    f"{match.leading_whitespace}codeflash.{self.capture_func}('{self.qualified_name}', "
                    f"'{line_id}', {func_ref}, {args_str});"
                )
            return (
                f"{match.leading_whitespace}codeflash.{self.capture_func}('{self.qualified_name}', "
                f"'{line_id}', {func_ref});"
            )

        # Existing tests: keep the expect wrapper and assertion chain.
        semicolon = ";" if match.has_trailing_semicolon else ""
        if args_str:
            return (
                f"{match.leading_whitespace}expect(codeflash.{self.capture_func}('{self.qualified_name}', "
                f"'{line_id}', {func_ref}, {args_str})){match.assertion_chain}{semicolon}"
            )
        return (
            f"{match.leading_whitespace}expect(codeflash.{self.capture_func}('{self.qualified_name}', "
            f"'{line_id}', {func_ref})){match.assertion_chain}{semicolon}"
        )
def transform_expect_calls(
    code: str, func_name: str, qualified_name: str, capture_func: str, remove_assertions: bool = False
) -> tuple[str, int]:
    """Transform expect(func(...)).assertion() calls in JavaScript test code.

    Main entry point for expect-call transformation; delegates the actual
    work to ExpectCallTransformer.

    Args:
        code: The test code to transform.
        func_name: Name of the function being tested.
        qualified_name: Fully qualified function name.
        capture_func: The capture function to use ('capture' or 'capturePerf').
        remove_assertions: If True, remove assertions entirely (for generated tests).

    Returns:
        Tuple of (transformed code, final invocation counter value).

    """
    rewriter = ExpectCallTransformer(
        func_name=func_name,
        qualified_name=qualified_name,
        capture_func=capture_func,
        remove_assertions=remove_assertions,
    )
    transformed = rewriter.transform(code)
    return transformed, rewriter.invocation_counter
def inject_profiling_into_existing_js_test(
    test_path: Path,
    call_positions: list[CodePosition],
    function_to_optimize: FunctionToOptimize,
    tests_project_root: Path,
    mode: str = TestingMode.BEHAVIOR,
) -> tuple[bool, str | None]:
    """Inject profiling code into an existing JavaScript test file.

    Wraps invocations of the target function with codeflash.capture() or
    codeflash.capturePerf() so behavior and timing can be recorded.

    Args:
        test_path: Path to the test file.
        call_positions: List of code positions where the function is called.
        function_to_optimize: The function being optimized.
        tests_project_root: Root directory of tests.
        mode: Testing mode - "behavior" or "performance".

    Returns:
        Tuple of (success, instrumented_code).

    """
    try:
        with test_path.open(encoding="utf8") as handle:
            original_code = handle.read()
    except Exception as e:
        logger.error(f"Failed to read test file {test_path}: {e}")
        return False, None

    func_name = function_to_optimize.function_name

    # Identify the test relative to the project root when possible; fall
    # back to the absolute path for files outside the root.
    try:
        identifier_path = test_path.relative_to(tests_project_root)
    except ValueError:
        identifier_path = test_path

    # Skip files that never import or call the target function.
    if not _is_function_used_in_test(original_code, func_name):
        logger.debug(f"Function '{func_name}' not found in test file {test_path}")
        return False, None

    rewritten = _instrument_js_test_code(
        original_code, func_name, str(identifier_path), mode, function_to_optimize.qualified_name
    )

    # An unchanged result means no call sites were instrumented.
    if rewritten == original_code:
        logger.debug(f"No changes made to test file {test_path}")
        return False, None

    return True, rewritten
def _is_function_used_in_test(code: str, func_name: str) -> bool:
|
||||
"""Check if a function is imported or used in the test code.
|
||||
|
||||
This function handles both standalone functions and class methods.
|
||||
For class methods, it checks if the method is called on any object
|
||||
(e.g., calc.fibonacci, this.fibonacci).
|
||||
"""
|
||||
# Check for CommonJS require with named export
|
||||
require_pattern = rf"(?:const|let|var)\s+\{{\s*[^}}]*\b{re.escape(func_name)}\b[^}}]*\}}\s*=\s*require\s*\("
|
||||
if re.search(require_pattern, code):
|
||||
return True
|
||||
|
||||
# Check for ES6 import with named export
|
||||
import_pattern = rf"import\s+\{{\s*[^}}]*\b{re.escape(func_name)}\b[^}}]*\}}\s+from"
|
||||
if re.search(import_pattern, code):
|
||||
return True
|
||||
|
||||
# Check for default import (import func from or const func = require())
|
||||
default_require = rf"(?:const|let|var)\s+{re.escape(func_name)}\s*=\s*require\s*\("
|
||||
if re.search(default_require, code):
|
||||
return True
|
||||
|
||||
default_import = rf"import\s+{re.escape(func_name)}\s+from"
|
||||
if re.search(default_import, code):
|
||||
return True
|
||||
|
||||
# Check for method calls: obj.funcName( or this.funcName(
|
||||
# This handles class methods called on instances
|
||||
method_call_pattern = rf"\w+\.{re.escape(func_name)}\s*\("
|
||||
return bool(re.search(method_call_pattern, code))
|
||||
|
||||
|
||||
def _instrument_js_test_code(
    code: str, func_name: str, test_file_path: str, mode: str, qualified_name: str, remove_assertions: bool = False
) -> str:
    """Instrument JavaScript test code with profiling capture calls.

    Args:
        code: Original test code.
        func_name: Name of the function to instrument.
        test_file_path: Relative path to the test file (kept for interface
            compatibility; not consulted by the current implementation).
        mode: Testing mode (behavior or performance).
        qualified_name: Fully qualified function name.
        remove_assertions: If True, remove expect assertions entirely (for
            generated/regression tests); if False, keep the expect wrapper
            (for existing user-written tests).

    Returns:
        Instrumented code.

    """
    # Inject the codeflash helper import unless the word already appears in
    # the file — this covers both the npm package (codeflash) and the legacy
    # local helper (codeflash-jest-helper).
    if "codeflash" not in code:
        # ESM files use "import ... from"; anything else is treated as CommonJS.
        uses_esm = bool(re.search(r"^\s*import\s+.+\s+from\s+['\"]", code, re.MULTILINE))

        if uses_esm:
            helper_import = "import codeflash from 'codeflash';\n"
            # Place the helper after the final existing import statement.
            existing_imports = list(
                re.finditer(r"^import\s+.+\s+from\s+['\"][^'\"]+['\"]\s*;?\s*\n", code, re.MULTILINE)
            )
            if existing_imports:
                at = existing_imports[-1].end()
                code = code[:at] + helper_import + code[at:]
            else:
                # No imports found: prepend at the top of the file.
                code = helper_import + "\n" + code
        else:
            helper_require = "const codeflash = require('codeflash');\n"
            # Place the helper after the first existing require statement.
            first_require = re.search(r"^((?:const|let|var)\s+.+?require\([^)]+\).*;?\s*\n)", code, re.MULTILINE)
            if first_require:
                at = first_require.end()
                code = code[:at] + helper_require + code[at:]
            else:
                # No requires found: prepend at the top of the file.
                code = helper_require + "\n" + code

    # Behavior runs record values; performance runs record timings.
    capture_func = "capturePerf" if mode == TestingMode.PERFORMANCE else "capture"

    # First rewrite expect(...) wrappers ...
    code, expect_counter = transform_expect_calls(
        code=code,
        func_name=func_name,
        qualified_name=qualified_name,
        capture_func=capture_func,
        remove_assertions=remove_assertions,
    )

    # ... then bare calls, continuing the counter so every capture call
    # carries a unique invocation id.
    code, _final_counter = transform_standalone_calls(
        code=code,
        func_name=func_name,
        qualified_name=qualified_name,
        capture_func=capture_func,
        start_counter=expect_counter,
    )

    return code
def validate_and_fix_import_style(test_code: str, source_file_path: Path, function_name: str) -> str:
    """Validate and fix import style in generated test code to match source export.

    The AI may emit tests whose import style disagrees with the source file's
    export style (e.g., a named import for a default export). This function
    detects such mismatches and rewrites the offending import statements.

    Args:
        test_code: The generated test code.
        source_file_path: Path to the source file being tested.
        function_name: Name of the function being tested.

    Returns:
        Fixed test code with correct import style.

    """
    from codeflash.languages.treesitter_utils import get_analyzer_for_file

    # Read the source to learn how the function is exported; on any failure
    # leave the test untouched (this is a best-effort fixup).
    try:
        source_text = source_file_path.read_text(encoding="utf-8")
    except Exception as e:
        logger.warning(f"Could not read source file {source_file_path}: {e}")
        return test_code

    try:
        source_analyzer = get_analyzer_for_file(source_file_path)
        export_entries = source_analyzer.find_exports(source_text)
    except Exception as e:
        logger.warning(f"Could not analyze exports in {source_file_path}: {e}")
        return test_code

    if not export_entries:
        return test_code

    # Classify how the function is exported.
    exported_as_default = False
    exported_as_named = False
    for entry in export_entries:
        if entry.default_export == function_name:
            exported_as_default = True
            break
        for exported_name, _alias in entry.exported_names:
            if exported_name == function_name:
                exported_as_named = True
                break
        if exported_as_named:
            break

    # An anonymous default export reports the literal name "default".
    if not exported_as_default and not exported_as_named:
        for entry in export_entries:
            if entry.default_export == "default":
                exported_as_default = True
                break

    # Export style undeterminable: don't modify the test.
    if not exported_as_default and not exported_as_named:
        return test_code

    # Module-path variants that should be considered "the source file".
    source_name = source_file_path.stem
    source_patterns = [source_name, f"./{source_name}", f"../{source_name}", source_file_path.as_posix()]

    # Named import forms: const { f } = require(...) / import { f } from ...
    named_require_re = re.compile(
        rf"(const|let|var)\s+\{{\s*{re.escape(function_name)}\s*\}}\s*=\s*require\s*\(\s*['\"]([^'\"]+)['\"]\s*\)"
    )
    named_import_re = re.compile(rf"import\s+\{{\s*{re.escape(function_name)}\s*\}}\s+from\s+['\"]([^'\"]+)['\"]")

    # Default import forms: const f = require(...) / import f from ...
    default_require_re = re.compile(
        rf"(const|let|var)\s+{re.escape(function_name)}\s*=\s*require\s*\(\s*['\"]([^'\"]+)['\"]\s*\)"
    )
    default_import_re = re.compile(rf"import\s+{re.escape(function_name)}\s+from\s+['\"]([^'\"]+)['\"]")

    def is_relevant_import(module_path: str) -> bool:
        """Check if the module path refers to our source file."""
        module_name = Path(module_path).stem
        return any(p in module_path or module_name == source_name for p in source_patterns)

    if exported_as_default:
        # Source uses a default export but the test imports by name: rewrite.
        for m in named_require_re.finditer(test_code):
            module_path = m.group(2)
            if is_relevant_import(module_path):
                logger.debug(f"Fixing named require to default for {function_name} from {module_path}")
                test_code = test_code.replace(
                    m.group(0), f"{m.group(1)} {function_name} = require('{module_path}')"
                )

        for m in named_import_re.finditer(test_code):
            module_path = m.group(1)
            if is_relevant_import(module_path):
                logger.debug(f"Fixing named import to default for {function_name} from {module_path}")
                test_code = test_code.replace(m.group(0), f"import {function_name} from '{module_path}'")

    elif exported_as_named:
        # Source uses a named export but the test imports the default: rewrite.
        for m in default_require_re.finditer(test_code):
            module_path = m.group(2)
            if is_relevant_import(module_path):
                logger.debug(f"Fixing default require to named for {function_name} from {module_path}")
                test_code = test_code.replace(
                    m.group(0), f"{m.group(1)} {{ {function_name} }} = require('{module_path}')"
                )

        for m in default_import_re.finditer(test_code):
            module_path = m.group(1)
            if is_relevant_import(module_path):
                logger.debug(f"Fixing default import to named for {function_name} from {module_path}")
                test_code = test_code.replace(m.group(0), f"import {{ {function_name} }} from '{module_path}'")

    return test_code
def get_instrumented_test_path(original_path: Path, mode: str) -> Path:
    """Generate path for instrumented test file.

    Inserts a mode-specific suffix before the .test/.spec marker so test
    runners still recognize the file, e.g.:
        foo.test.js -> foo_codeflash_behavior.test.js
        foo.spec.ts -> foo_codeflash_perf.spec.ts
        foo.js      -> foo_codeflash_behavior.js

    Args:
        original_path: Original test file path.
        mode: Testing mode (behavior or performance).

    Returns:
        Path for instrumented file.

    """
    suffix = "_codeflash_behavior" if mode == TestingMode.BEHAVIOR else "_codeflash_perf"
    stem = original_path.stem
    # Only treat ".test"/".spec" as a marker when it ends the stem; the
    # previous substring check (".test" in stem) corrupted names such as
    # "foo.tester" or "foo.test.unit" by discarding the trailing fragment.
    if stem.endswith(".test"):
        new_stem = f"{stem[: -len('.test')]}{suffix}.test"
    elif stem.endswith(".spec"):
        new_stem = f"{stem[: -len('.spec')]}{suffix}.spec"
    else:
        new_stem = f"{stem}{suffix}"

    return original_path.parent / f"{new_stem}{original_path.suffix}"
def instrument_generated_js_test(
    test_code: str, function_name: str, qualified_name: str, mode: str = TestingMode.BEHAVIOR
) -> str:
    """Instrument generated JavaScript/TypeScript test code.

    Used for tests produced by the aiservice; unlike
    inject_profiling_into_existing_js_test, this takes the test code as a
    string rather than reading it from a file.

    Generated tests have their expect() assertions removed entirely because:
    1. LLM-generated expected values may be incorrect.
    2. They are treated as regression tests, where correctness is verified
       by comparing outputs between original and optimized code.

    Args:
        test_code: The generated test code to instrument.
        function_name: Name of the function being tested.
        qualified_name: Fully qualified function name (e.g., 'module.funcName').
        mode: Testing mode - "behavior" or "performance".

    Returns:
        Instrumented test code with assertions removed.

    """
    # Nothing to do for empty or whitespace-only input.
    if not (test_code and test_code.strip()):
        return test_code

    # remove_assertions=True marks this as a regression test, dropping the
    # LLM-authored assertions.
    return _instrument_js_test_code(
        code=test_code,
        func_name=function_name,
        test_file_path="generated_test",
        mode=mode,
        qualified_name=qualified_name,
        remove_assertions=True,
    )
333
codeflash/languages/javascript/line_profiler.py
Normal file
333
codeflash/languages/javascript/line_profiler.py
Normal file
|
|
@ -0,0 +1,333 @@
|
|||
"""Line profiler instrumentation for JavaScript.
|
||||
|
||||
This module provides functionality to instrument JavaScript code with line-level
|
||||
profiling similar to Python's line_profiler. It tracks execution counts and timing
|
||||
for each line in instrumented functions.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from codeflash.languages.treesitter_utils import get_analyzer_for_file
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pathlib import Path
|
||||
|
||||
from codeflash.languages.base import FunctionInfo
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class JavaScriptLineProfiler:
|
||||
"""Instruments JavaScript code for line-level profiling.
|
||||
|
||||
This class adds profiling code to JavaScript functions to track:
|
||||
- How many times each line executes
|
||||
- How much time is spent on each line
|
||||
- Total execution time per function
|
||||
"""
|
||||
|
||||
def __init__(self, output_file: Path) -> None:
    """Initialize the line profiler.

    Args:
        output_file: Path where profiling results will be written.

    """
    # Destination file for the profiling results dump.
    self.output_file = output_file
    # Name of the global object the instrumented JavaScript records into.
    self.profiler_var = "__codeflash_line_profiler__"
def instrument_source(self, source: str, file_path: Path, functions: list[FunctionInfo]) -> str:
|
||||
"""Instrument JavaScript source code with line profiling.
|
||||
|
||||
Adds profiling instrumentation to track line-level execution for the
|
||||
specified functions.
|
||||
|
||||
Args:
|
||||
source: Original JavaScript source code.
|
||||
file_path: Path to the source file.
|
||||
functions: List of functions to instrument.
|
||||
|
||||
Returns:
|
||||
Instrumented source code with profiling.
|
||||
|
||||
"""
|
||||
if not functions:
|
||||
return source
|
||||
|
||||
# Initialize line contents map to collect source content during instrumentation
|
||||
self.line_contents: dict[str, str] = {}
|
||||
|
||||
# Add instrumentation to each function
|
||||
lines = source.splitlines(keepends=True)
|
||||
|
||||
# Process functions in reverse order to preserve line numbers
|
||||
for func in sorted(functions, key=lambda f: f.start_line, reverse=True):
|
||||
func_lines = self._instrument_function(func, lines, file_path)
|
||||
start_idx = func.start_line - 1
|
||||
end_idx = func.end_line
|
||||
lines = lines[:start_idx] + func_lines + lines[end_idx:]
|
||||
|
||||
instrumented_source = "".join(lines)
|
||||
|
||||
# Add profiler initialization at the top (after collecting line contents)
|
||||
profiler_init = self._generate_profiler_init()
|
||||
|
||||
# Add profiler save at the end
|
||||
profiler_save = self._generate_profiler_save()
|
||||
|
||||
return profiler_init + "\n" + instrumented_source + "\n" + profiler_save
|
||||
|
||||
def _generate_profiler_init(self) -> str:
|
||||
"""Generate JavaScript code for profiler initialization."""
|
||||
# Serialize line contents map for embedding in JavaScript
|
||||
line_contents_json = json.dumps(getattr(self, "line_contents", {}))
|
||||
|
||||
return f"""
|
||||
// Codeflash line profiler initialization
|
||||
// @ts-nocheck
|
||||
const {self.profiler_var} = {{
|
||||
stats: {{}},
|
||||
lineContents: {line_contents_json},
|
||||
lastLineTime: null,
|
||||
lastKey: null,
|
||||
|
||||
totalHits: 0,
|
||||
|
||||
// Called at the start of each function to reset timing state
|
||||
// This prevents "between function calls" time from being attributed to the last line
|
||||
enterFunction: function() {{
|
||||
this.lastKey = null;
|
||||
this.lastLineTime = null;
|
||||
}},
|
||||
|
||||
hit: function(file, line) {{
|
||||
const now = performance.now(); // microsecond precision
|
||||
|
||||
// Attribute elapsed time to the PREVIOUS line (the one that was executing)
|
||||
if (this.lastKey !== null && this.lastLineTime !== null) {{
|
||||
this.stats[this.lastKey].time += (now - this.lastLineTime);
|
||||
}}
|
||||
|
||||
const key = file + ':' + line;
|
||||
if (!this.stats[key]) {{
|
||||
this.stats[key] = {{ hits: 0, time: 0, file: file, line: line }};
|
||||
}}
|
||||
this.stats[key].hits++;
|
||||
|
||||
// Record current line as the one now executing
|
||||
this.lastKey = key;
|
||||
this.lastLineTime = now;
|
||||
|
||||
this.totalHits++;
|
||||
// Save every 100 hits to ensure we capture results even with --forceExit
|
||||
if (this.totalHits % 100 === 0) {{
|
||||
this.save();
|
||||
}}
|
||||
}},
|
||||
|
||||
save: function() {{
|
||||
const fs = require('fs');
|
||||
const pathModule = require('path');
|
||||
const outputDir = pathModule.dirname('{self.output_file.as_posix()}');
|
||||
try {{
|
||||
if (!fs.existsSync(outputDir)) {{
|
||||
fs.mkdirSync(outputDir, {{ recursive: true }});
|
||||
}}
|
||||
// Merge line contents into stats before saving
|
||||
const statsWithContent = {{}};
|
||||
for (const key of Object.keys(this.stats)) {{
|
||||
statsWithContent[key] = {{
|
||||
...this.stats[key],
|
||||
content: this.lineContents[key] || ''
|
||||
}};
|
||||
}}
|
||||
fs.writeFileSync(
|
||||
'{self.output_file.as_posix()}',
|
||||
JSON.stringify(statsWithContent, null, 2)
|
||||
);
|
||||
}} catch (e) {{
|
||||
console.error('Failed to save line profile results:', e);
|
||||
}}
|
||||
}}
|
||||
}};
|
||||
"""
|
||||
|
||||
def _generate_profiler_save(self) -> str:
|
||||
"""Generate JavaScript code to save profiler results."""
|
||||
return f"""
|
||||
// Save profiler results on process exit and periodically
|
||||
// Use beforeExit for graceful shutdowns
|
||||
process.on('beforeExit', () => {self.profiler_var}.save());
|
||||
process.on('exit', () => {self.profiler_var}.save());
|
||||
process.on('SIGINT', () => {{ {self.profiler_var}.save(); process.exit(); }});
|
||||
process.on('SIGTERM', () => {{ {self.profiler_var}.save(); process.exit(); }});
|
||||
|
||||
// For Jest --forceExit compatibility, save periodically (every 500ms)
|
||||
const __codeflash_save_interval__ = setInterval(() => {self.profiler_var}.save(), 500);
|
||||
if (__codeflash_save_interval__.unref) __codeflash_save_interval__.unref(); // Don't keep process alive
|
||||
"""
|
||||
|
||||
def _instrument_function(self, func: FunctionInfo, lines: list[str], file_path: Path) -> list[str]:
|
||||
"""Instrument a single function with line profiling.
|
||||
|
||||
Args:
|
||||
func: Function to instrument.
|
||||
lines: Source lines.
|
||||
file_path: Path to source file.
|
||||
|
||||
Returns:
|
||||
Instrumented function lines.
|
||||
|
||||
"""
|
||||
func_lines = lines[func.start_line - 1 : func.end_line]
|
||||
instrumented_lines = []
|
||||
|
||||
# Parse the function to find executable lines
|
||||
analyzer = get_analyzer_for_file(file_path)
|
||||
source = "".join(func_lines)
|
||||
|
||||
try:
|
||||
tree = analyzer.parse(source.encode("utf8"))
|
||||
executable_lines = self._find_executable_lines(tree.root_node, source.encode("utf8"))
|
||||
except Exception as e:
|
||||
logger.warning("Failed to parse function %s: %s", func.name, e)
|
||||
return func_lines
|
||||
|
||||
# Add profiling to each executable line
|
||||
# executable_lines contains 1-indexed line numbers within the function snippet
|
||||
function_entry_added = False
|
||||
|
||||
for local_idx, line in enumerate(func_lines):
|
||||
local_line_num = local_idx + 1 # 1-indexed within function
|
||||
global_line_num = func.start_line + local_idx # Global line number in original file
|
||||
stripped = line.strip()
|
||||
|
||||
# Add enterFunction() call after the opening brace of the function
|
||||
if not function_entry_added and "{" in line:
|
||||
# Find indentation for the function body (use next line's indentation or default)
|
||||
body_indent = " " # Default 4 spaces
|
||||
if local_idx + 1 < len(func_lines):
|
||||
next_line = func_lines[local_idx + 1]
|
||||
if next_line.strip():
|
||||
body_indent = " " * (len(next_line) - len(next_line.lstrip()))
|
||||
|
||||
# Add the line with enterFunction() call after it
|
||||
instrumented_lines.append(line)
|
||||
instrumented_lines.append(f"{body_indent}{self.profiler_var}.enterFunction();\n")
|
||||
function_entry_added = True
|
||||
continue
|
||||
|
||||
# Skip empty lines, comments, and closing braces
|
||||
if local_line_num in executable_lines and stripped and not stripped.startswith("//") and stripped != "}":
|
||||
# Get indentation
|
||||
indent = len(line) - len(line.lstrip())
|
||||
indent_str = " " * indent
|
||||
|
||||
# Store line content for the profiler output
|
||||
content_key = f"{file_path.as_posix()}:{global_line_num}"
|
||||
self.line_contents[content_key] = stripped
|
||||
|
||||
# Add hit() call before the line
|
||||
profiled_line = (
|
||||
f"{indent_str}{self.profiler_var}.hit('{file_path.as_posix()}', {global_line_num});\n{line}"
|
||||
)
|
||||
instrumented_lines.append(profiled_line)
|
||||
else:
|
||||
instrumented_lines.append(line)
|
||||
|
||||
return instrumented_lines
|
||||
|
||||
def _find_executable_lines(self, node, source_bytes: bytes) -> set[int]:
|
||||
"""Find lines that contain executable statements.
|
||||
|
||||
Args:
|
||||
node: Tree-sitter AST node.
|
||||
source_bytes: Source code as bytes.
|
||||
|
||||
Returns:
|
||||
Set of line numbers with executable statements.
|
||||
|
||||
"""
|
||||
executable_lines = set()
|
||||
|
||||
# Node types that represent executable statements
|
||||
executable_types = {
|
||||
"expression_statement",
|
||||
"return_statement",
|
||||
"if_statement",
|
||||
"for_statement",
|
||||
"while_statement",
|
||||
"do_statement",
|
||||
"switch_statement",
|
||||
"throw_statement",
|
||||
"try_statement",
|
||||
"variable_declaration",
|
||||
"lexical_declaration",
|
||||
"assignment_expression",
|
||||
"call_expression",
|
||||
"await_expression",
|
||||
}
|
||||
|
||||
def walk(n) -> None:
|
||||
if n.type in executable_types:
|
||||
# Add the starting line (1-indexed)
|
||||
executable_lines.add(n.start_point[0] + 1)
|
||||
|
||||
for child in n.children:
|
||||
walk(child)
|
||||
|
||||
walk(node)
|
||||
return executable_lines
|
||||
|
||||
@staticmethod
|
||||
def parse_results(profile_file: Path) -> dict:
|
||||
"""Parse line profiling results from output file.
|
||||
|
||||
Args:
|
||||
profile_file: Path to profiling results JSON file.
|
||||
|
||||
Returns:
|
||||
Dictionary with profiling statistics.
|
||||
|
||||
"""
|
||||
if not profile_file.exists():
|
||||
return {"timings": {}, "unit": 1e-9, "functions": {}}
|
||||
|
||||
try:
|
||||
with profile_file.open("r") as f:
|
||||
data = json.load(f)
|
||||
|
||||
# Group by file and function
|
||||
timings = {}
|
||||
for key, stats in data.items():
|
||||
file_path, line_num = key.rsplit(":", 1)
|
||||
line_num = int(line_num)
|
||||
# performance.now() returns milliseconds, convert to nanoseconds
|
||||
time_ms = float(stats["time"])
|
||||
time_ns = int(time_ms * 1e6)
|
||||
hits = stats["hits"]
|
||||
|
||||
if file_path not in timings:
|
||||
timings[file_path] = {}
|
||||
|
||||
content = stats.get("content", "")
|
||||
timings[file_path][line_num] = {
|
||||
"hits": hits,
|
||||
"time_ns": time_ns,
|
||||
"time_ms": time_ms,
|
||||
"content": content,
|
||||
}
|
||||
|
||||
return {
|
||||
"timings": timings,
|
||||
"unit": 1e-9, # nanoseconds
|
||||
"raw_data": data,
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("Failed to parse line profile results: %s", e)
|
||||
return {"timings": {}, "unit": 1e-9, "functions": {}}
|
||||
324
codeflash/languages/javascript/module_system.py
Normal file
324
codeflash/languages/javascript/module_system.py
Normal file
|
|
@ -0,0 +1,324 @@
|
|||
"""Module system detection for JavaScript/TypeScript projects.
|
||||
|
||||
Determines whether a project uses CommonJS (require/module.exports) or
|
||||
ES Modules (import/export).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pathlib import Path
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ModuleSystem:
|
||||
"""Enum-like class for module systems."""
|
||||
|
||||
COMMONJS = "commonjs"
|
||||
ES_MODULE = "esm"
|
||||
UNKNOWN = "unknown"
|
||||
|
||||
|
||||
# Pattern for destructured require: const { a, b } = require('...')
|
||||
destructured_require = re.compile(
|
||||
r"(const|let|var)\s+\{\s*([^}]+)\s*\}\s*=\s*require\s*\(\s*['\"]([^'\"]+)['\"]\s*\)\s*;?"
|
||||
)
|
||||
|
||||
# Pattern for require with property access: const foo = require('...').propertyName
|
||||
# This must come before simple_require to match first
|
||||
property_access_require = re.compile(
|
||||
r"(const|let|var)\s+(\w+)\s*=\s*require\s*\(\s*['\"]([^'\"]+)['\"]\s*\)\.(\w+)\s*;?"
|
||||
)
|
||||
|
||||
# Pattern for simple require: const foo = require('...')
|
||||
simple_require = re.compile(r"(const|let|var)\s+(\w+)\s*=\s*require\s*\(\s*['\"]([^'\"]+)['\"]\s*\)\s*;?")
|
||||
|
||||
|
||||
def detect_module_system(project_root: Path, file_path: Path | None = None) -> str:
|
||||
"""Detect the module system used by a JavaScript/TypeScript project.
|
||||
|
||||
Detection strategy:
|
||||
1. Check package.json for "type" field
|
||||
2. If file_path provided, check file extension (.mjs = ESM, .cjs = CommonJS)
|
||||
3. Analyze import statements in the file
|
||||
4. Default to CommonJS if uncertain
|
||||
|
||||
Args:
|
||||
project_root: Root directory of the project containing package.json.
|
||||
file_path: Optional specific file to analyze.
|
||||
|
||||
Returns:
|
||||
ModuleSystem constant (COMMONJS, ES_MODULE, or UNKNOWN).
|
||||
|
||||
"""
|
||||
# Strategy 1: Check package.json
|
||||
package_json = project_root / "package.json"
|
||||
if package_json.exists():
|
||||
try:
|
||||
with package_json.open("r") as f:
|
||||
pkg = json.load(f)
|
||||
pkg_type = pkg.get("type", "commonjs")
|
||||
|
||||
if pkg_type == "module":
|
||||
logger.debug("Detected ES Module from package.json type field")
|
||||
return ModuleSystem.ES_MODULE
|
||||
if pkg_type == "commonjs":
|
||||
logger.debug("Detected CommonJS from package.json type field")
|
||||
return ModuleSystem.COMMONJS
|
||||
|
||||
except Exception as e:
|
||||
logger.warning("Failed to parse package.json: %s", e)
|
||||
|
||||
# Strategy 2: Check file extension
|
||||
if file_path:
|
||||
suffix = file_path.suffix
|
||||
if suffix == ".mjs":
|
||||
logger.debug("Detected ES Module from .mjs extension")
|
||||
return ModuleSystem.ES_MODULE
|
||||
if suffix == ".cjs":
|
||||
logger.debug("Detected CommonJS from .cjs extension")
|
||||
return ModuleSystem.COMMONJS
|
||||
|
||||
# Strategy 3: Analyze file content
|
||||
if file_path.exists():
|
||||
try:
|
||||
content = file_path.read_text()
|
||||
|
||||
# Look for ES module syntax
|
||||
has_import = "import " in content and "from " in content
|
||||
has_export = "export " in content or "export default" in content or "export {" in content
|
||||
|
||||
# Look for CommonJS syntax
|
||||
has_require = "require(" in content
|
||||
has_module_exports = "module.exports" in content or "exports." in content
|
||||
|
||||
# Determine based on what we found
|
||||
if (has_import or has_export) and not (has_require or has_module_exports):
|
||||
logger.debug("Detected ES Module from import/export statements")
|
||||
return ModuleSystem.ES_MODULE
|
||||
|
||||
if (has_require or has_module_exports) and not (has_import or has_export):
|
||||
logger.debug("Detected CommonJS from require/module.exports")
|
||||
return ModuleSystem.COMMONJS
|
||||
|
||||
except Exception as e:
|
||||
logger.warning("Failed to analyze file %s: %s", file_path, e)
|
||||
|
||||
# Default to CommonJS (more common and backward compatible)
|
||||
logger.debug("Defaulting to CommonJS")
|
||||
return ModuleSystem.COMMONJS
|
||||
|
||||
|
||||
def get_import_statement(
|
||||
module_system: str, target_path: Path, source_path: Path, imported_names: list[str] | None = None
|
||||
) -> str:
|
||||
"""Generate the appropriate import statement for the module system.
|
||||
|
||||
Args:
|
||||
module_system: ModuleSystem constant (COMMONJS or ES_MODULE).
|
||||
target_path: Path to the module being imported.
|
||||
source_path: Path to the file doing the importing.
|
||||
imported_names: List of names to import (for named imports).
|
||||
|
||||
Returns:
|
||||
Import statement string.
|
||||
|
||||
"""
|
||||
# Calculate relative import path
|
||||
rel_path = _get_relative_import_path(target_path, source_path)
|
||||
|
||||
if module_system == ModuleSystem.ES_MODULE:
|
||||
if imported_names:
|
||||
names = ", ".join(imported_names)
|
||||
return f"import {{ {names} }} from '{rel_path}';"
|
||||
# Default import
|
||||
module_name = target_path.stem
|
||||
return f"import {module_name} from '{rel_path}';"
|
||||
if imported_names:
|
||||
names = ", ".join(imported_names)
|
||||
return f"const {{ {names} }} = require('{rel_path}');"
|
||||
# Require entire module
|
||||
module_name = target_path.stem
|
||||
return f"const {module_name} = require('{rel_path}');"
|
||||
|
||||
|
||||
def _get_relative_import_path(target_path: Path, source_path: Path) -> str:
|
||||
"""Calculate relative import path from source to target.
|
||||
|
||||
For JavaScript imports, we calculate the path from the source file's directory
|
||||
to the target file.
|
||||
|
||||
Args:
|
||||
target_path: Absolute path to the file being imported.
|
||||
source_path: Absolute path to the file doing the importing.
|
||||
|
||||
Returns:
|
||||
Relative import path (without file extension for .js files).
|
||||
|
||||
"""
|
||||
# Both paths should be absolute - get the directory containing source
|
||||
source_dir = source_path.parent
|
||||
|
||||
# Try to use os.path.relpath for accuracy
|
||||
import os
|
||||
|
||||
rel_path_str = os.path.relpath(str(target_path), str(source_dir))
|
||||
|
||||
# Normalize to forward slashes
|
||||
rel_path_str = rel_path_str.replace("\\", "/")
|
||||
|
||||
# Remove .js extension (Node.js convention)
|
||||
rel_path_str = rel_path_str.removesuffix(".js")
|
||||
|
||||
# Ensure it starts with ./ or ../ for relative imports
|
||||
if not rel_path_str.startswith("./") and not rel_path_str.startswith("../"):
|
||||
rel_path_str = "./" + rel_path_str
|
||||
|
||||
return rel_path_str
|
||||
|
||||
|
||||
def add_js_extension(module_path: str) -> str:
    """Add .js extension to relative module paths for ESM compatibility."""
    is_relative = module_path.startswith(("./", "../"))
    if is_relative and not module_path.endswith((".js", ".mjs")):
        return module_path + ".js"
    return module_path


# Replace destructured requires with named imports
def replace_destructured(match: re.Match) -> str:
    imported = match.group(2).strip()
    specifier = add_js_extension(match.group(3))
    return f"import {{ {imported} }} from '{specifier}';"


# Replace property access requires with named imports with alias
# e.g., const foo = require('./module').bar -> import { bar as foo } from './module';
def replace_property_access(match: re.Match) -> str:
    local_name = match.group(2)  # The variable name on the left-hand side
    specifier = add_js_extension(match.group(3))
    accessed = match.group(4)  # The property being accessed

    # Special case: .default means default export
    if accessed == "default":
        return f"import {local_name} from '{specifier}';"

    # Named export, with an alias only when the names differ
    if local_name == accessed:
        return f"import {{ {accessed} }} from '{specifier}';"
    return f"import {{ {accessed} as {local_name} }} from '{specifier}';"


# Replace simple requires with default imports
def replace_simple(match: re.Match) -> str:
    local_name = match.group(2)
    specifier = add_js_extension(match.group(3))
    return f"import {local_name} from '{specifier}';"


def convert_commonjs_to_esm(code: str) -> str:
    """Convert CommonJS require statements to ES Module imports.

    Converts:
        const { foo, bar } = require('./module'); -> import { foo, bar } from './module';
        const foo = require('./module'); -> import foo from './module';
        const foo = require('./module').default; -> import foo from './module';
        const foo = require('./module').bar; -> import { bar as foo } from './module';

    Special handling:
    - Local codeflash helper (./codeflash-jest-helper) is converted to npm package codeflash
      because the local helper uses CommonJS exports which don't work in ESM projects

    Args:
        code: JavaScript code with CommonJS require statements.

    Returns:
        Code with ES Module import statements.

    """
    # Apply the most specific patterns first so simple_require does not
    # consume destructured or property-access forms.
    code = destructured_require.sub(replace_destructured, code)
    code = property_access_require.sub(replace_property_access, code)
    return simple_require.sub(replace_simple, code)
|
||||
|
||||
|
||||
def convert_esm_to_commonjs(code: str) -> str:
    """Convert ES Module imports to CommonJS require statements.

    Converts:
        import { foo, bar } from './module'; -> const { foo, bar } = require('./module');
        import foo from './module'; -> const foo = require('./module');

    Args:
        code: JavaScript code with ES Module import statements.

    Returns:
        Code with CommonJS require statements.

    """
    import re

    # Named import: import { a, b } from '...'; (semicolon optional)
    named_import = re.compile(r"import\s+\{\s*([^}]+)\s*\}\s+from\s+['\"]([^'\"]+)['\"];?")

    # Default import: import foo from '...'; (semicolon optional)
    default_import = re.compile(r"import\s+(\w+)\s+from\s+['\"]([^'\"]+)['\"];?")

    def to_destructured_require(m) -> str:
        # Drop the .js extension for CommonJS (optional but cleaner).
        spec = m.group(2).removesuffix(".js")
        return f"const {{ {m.group(1).strip()} }} = require('{spec}');"

    def to_simple_require(m) -> str:
        spec = m.group(2).removesuffix(".js")
        return f"const {m.group(1)} = require('{spec}');"

    # Named imports first: that pattern is the more specific of the two.
    converted = named_import.sub(to_destructured_require, code)
    return default_import.sub(to_simple_require, converted)
|
||||
|
||||
|
||||
def ensure_module_system_compatibility(code: str, target_module_system: str) -> str:
    """Ensure code uses the correct module system syntax.

    Detects the current module system in the code and converts if needed.
    Handles mixed-style code (e.g., ESM imports with CommonJS require for npm packages).

    Args:
        code: JavaScript code to check and potentially convert.
        target_module_system: Target ModuleSystem (COMMONJS or ES_MODULE).

    Returns:
        Code with correct module system syntax.

    """
    uses_require = "require(" in code
    uses_import = "import " in code and "from " in code

    if target_module_system == ModuleSystem.ES_MODULE and uses_require:
        # Mixed code is common: ESM imports plus require() for npm packages.
        logger.debug("Converting CommonJS requires to ESM imports")
        return convert_commonjs_to_esm(code)

    if target_module_system == ModuleSystem.COMMONJS and uses_import:
        logger.debug("Converting ESM imports to CommonJS requires")
        return convert_esm_to_commonjs(code)

    return code
|
||||
2129
codeflash/languages/javascript/support.py
Normal file
2129
codeflash/languages/javascript/support.py
Normal file
File diff suppressed because it is too large
Load diff
660
codeflash/languages/javascript/test_runner.py
Normal file
660
codeflash/languages/javascript/test_runner.py
Normal file
|
|
@ -0,0 +1,660 @@
|
|||
"""JavaScript test runner using Jest.
|
||||
|
||||
This module provides functions for running Jest tests for behavioral
|
||||
verification and performance benchmarking.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import subprocess
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from codeflash.cli_cmds.console import logger
|
||||
from codeflash.code_utils.code_utils import get_run_tmp_file
|
||||
from codeflash.code_utils.config_consts import STABILITY_CENTER_TOLERANCE, STABILITY_SPREAD_TOLERANCE
|
||||
from codeflash.code_utils.shell_utils import get_cross_platform_subprocess_run_args
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from codeflash.models.models import TestFiles
|
||||
|
||||
|
||||
def _find_node_project_root(file_path: Path) -> Path | None:
|
||||
"""Find the Node.js project root by looking for package.json.
|
||||
|
||||
Traverses up from the given file path to find the nearest directory
|
||||
containing package.json or jest.config.js.
|
||||
|
||||
Args:
|
||||
file_path: A file path within the Node.js project.
|
||||
|
||||
Returns:
|
||||
The project root directory, or None if not found.
|
||||
|
||||
"""
|
||||
current = file_path.parent if file_path.is_file() else file_path
|
||||
while current != current.parent: # Stop at filesystem root
|
||||
if (
|
||||
(current / "package.json").exists()
|
||||
or (current / "jest.config.js").exists()
|
||||
or (current / "jest.config.ts").exists()
|
||||
or (current / "tsconfig.json").exists()
|
||||
):
|
||||
return current
|
||||
current = current.parent
|
||||
return None
|
||||
|
||||
|
||||
def _is_esm_project(project_root: Path) -> bool:
|
||||
"""Check if the project uses ES Modules.
|
||||
|
||||
Detects ESM by checking package.json for "type": "module".
|
||||
|
||||
Args:
|
||||
project_root: The project root directory.
|
||||
|
||||
Returns:
|
||||
True if the project uses ES Modules, False otherwise.
|
||||
|
||||
"""
|
||||
package_json = project_root / "package.json"
|
||||
if package_json.exists():
|
||||
try:
|
||||
with package_json.open("r") as f:
|
||||
pkg = json.load(f)
|
||||
return pkg.get("type") == "module"
|
||||
except Exception as e:
|
||||
logger.debug(f"Failed to read package.json: {e}")
|
||||
return False
|
||||
|
||||
|
||||
def _uses_ts_jest(project_root: Path) -> bool:
|
||||
"""Check if the project uses ts-jest for TypeScript transformation.
|
||||
|
||||
ts-jest handles ESM transformation internally, so we don't need the
|
||||
--experimental-vm-modules flag when it's being used. Adding that flag
|
||||
can actually break Jest's module resolution for jest.mock() with relative paths.
|
||||
|
||||
Args:
|
||||
project_root: The project root directory.
|
||||
|
||||
Returns:
|
||||
True if ts-jest is being used, False otherwise.
|
||||
|
||||
"""
|
||||
# Check for ts-jest in devDependencies
|
||||
package_json = project_root / "package.json"
|
||||
if package_json.exists():
|
||||
try:
|
||||
with package_json.open("r") as f:
|
||||
pkg = json.load(f)
|
||||
dev_deps = pkg.get("devDependencies", {})
|
||||
deps = pkg.get("dependencies", {})
|
||||
if "ts-jest" in dev_deps or "ts-jest" in deps:
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.debug(f"Failed to read package.json for ts-jest detection: {e}")
|
||||
|
||||
# Also check for jest.config with ts-jest preset
|
||||
for config_file in ["jest.config.js", "jest.config.cjs", "jest.config.ts", "jest.config.mjs"]:
|
||||
config_path = project_root / config_file
|
||||
if config_path.exists():
|
||||
try:
|
||||
content = config_path.read_text()
|
||||
if "ts-jest" in content:
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.debug(f"Failed to read {config_file}: {e}")
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def _configure_esm_environment(jest_env: dict[str, str], project_root: Path) -> None:
    """Configure environment variables for ES Module support in Jest.

    Jest requires the --experimental-vm-modules flag for ESM support,
    passed via the NODE_OPTIONS environment variable.

    IMPORTANT: When ts-jest is being used, we skip adding --experimental-vm-modules
    because ts-jest handles ESM transformation internally. Adding this flag can
    break Jest's module resolution for jest.mock() calls with relative paths.

    Args:
        jest_env: Environment variables dictionary to modify (mutated in place).
        project_root: The project root directory.

    """
    if not _is_esm_project(project_root):
        return

    # ts-jest handles ESM internally; the flag would break relative mocks.
    if _uses_ts_jest(project_root):
        logger.debug("Skipping --experimental-vm-modules: ts-jest handles ESM transformation")
        return

    logger.debug("Configuring Jest for ES Module support")
    esm_flag = "--experimental-vm-modules"
    existing_node_options = jest_env.get("NODE_OPTIONS", "")
    if esm_flag not in existing_node_options:
        jest_env["NODE_OPTIONS"] = f"{existing_node_options} {esm_flag}".strip()
|
||||
|
||||
|
||||
def _ensure_runtime_files(project_root: Path) -> None:
    """Ensure JavaScript runtime package is installed in the project.

    Installs the codeflash npm package if not already present. The package
    provides all runtime files needed for test instrumentation.

    Args:
        project_root: The project root directory.

    """
    # Nothing to do when the package is already in node_modules.
    if (project_root / "node_modules" / "codeflash").exists():
        logger.debug("codeflash already installed")
        return

    # Prefer the local package from this checkout (development workflow).
    local_package_path = Path(__file__).parent.parent.parent.parent / "packages" / "codeflash"
    if local_package_path.exists():
        try:
            result = subprocess.run(
                ["npm", "install", "--save-dev", str(local_package_path)],
                check=False,
                cwd=project_root,
                capture_output=True,
                text=True,
                timeout=120,
            )
            if result.returncode == 0:
                logger.debug("Installed codeflash from local package")
                return
            logger.warning(f"Failed to install local package: {result.stderr}")
        except Exception as e:
            logger.warning(f"Error installing local package: {e}")

    # Fall back to the public npm registry.
    try:
        result = subprocess.run(
            ["npm", "install", "--save-dev", "codeflash"],
            check=False,
            cwd=project_root,
            capture_output=True,
            text=True,
            timeout=120,
        )
        if result.returncode == 0:
            logger.debug("Installed codeflash from npm registry")
            return
        logger.warning(f"Failed to install from npm: {result.stderr}")
    except Exception as e:
        logger.warning(f"Error installing from npm: {e}")

    logger.error("Could not install codeflash. Please install it manually: npm install --save-dev codeflash")
|
||||
|
||||
|
||||
def run_jest_behavioral_tests(
    test_paths: TestFiles,
    test_env: dict[str, str],
    cwd: Path,
    *,
    timeout: int | None = None,
    project_root: Path | None = None,
    enable_coverage: bool = False,
    candidate_index: int = 0,
) -> tuple[Path, subprocess.CompletedProcess, Path | None, Path | None]:
    """Run Jest tests and return results in a format compatible with pytest output.

    Never raises on Jest failure: timeouts and a missing Jest binary are
    converted into a synthetic ``CompletedProcess`` with ``returncode=-1``.

    Args:
        test_paths: TestFiles object containing test file information.
        test_env: Environment variables for the test run.
        cwd: Working directory for running tests.
        timeout: Optional timeout in seconds (also passed to Jest as a per-test
            timeout in milliseconds).
        project_root: JavaScript project root (directory containing package.json).
        enable_coverage: Whether to collect coverage information.
        candidate_index: Index of the candidate being tested.

    Returns:
        Tuple of (result_file_path, subprocess_result, coverage_json_path, None).

    """
    # JUnit XML written by the jest-junit reporter configured via env below.
    result_file_path = get_run_tmp_file(Path("jest_results.xml"))

    # Get test files to run (the instrumented behavior copies, not the originals)
    test_files = [str(file.instrumented_behavior_file_path) for file in test_paths.test_files]

    # Use provided project_root, or detect it as fallback
    if project_root is None and test_files:
        first_test_file = Path(test_files[0])
        project_root = _find_node_project_root(first_test_file)

    # Use the project root, or fall back to provided cwd
    effective_cwd = project_root if project_root else cwd
    logger.debug(f"Jest working directory: {effective_cwd}")

    # Ensure the codeflash npm package is installed
    _ensure_runtime_files(effective_cwd)

    # Coverage output directory (coverage_json_path stays None when coverage is off)
    coverage_dir = get_run_tmp_file(Path("jest_coverage"))
    coverage_json_path = coverage_dir / "coverage-final.json" if enable_coverage else None

    # Build Jest command
    jest_cmd = [
        "npx",
        "jest",
        "--reporters=default",
        "--reporters=jest-junit",
        "--runInBand",  # Run tests serially for consistent timing
        "--forceExit",
    ]

    # Add coverage flags if enabled
    if enable_coverage:
        jest_cmd.extend(["--coverage", "--coverageReporters=json", f"--coverageDirectory={coverage_dir}"])

    if test_files:
        jest_cmd.append("--runTestsByPath")
        jest_cmd.extend(str(Path(f).resolve()) for f in test_files)

    if timeout:
        jest_cmd.append(f"--testTimeout={timeout * 1000}")  # Jest uses milliseconds

    # Set up environment
    jest_env = test_env.copy()
    jest_env["JEST_JUNIT_OUTPUT_FILE"] = str(result_file_path)
    jest_env["JEST_JUNIT_OUTPUT_DIR"] = str(result_file_path.parent)
    jest_env["JEST_JUNIT_OUTPUT_NAME"] = result_file_path.name
    # Configure jest-junit to use filepath-based classnames for proper parsing
    jest_env["JEST_JUNIT_CLASSNAME"] = "{filepath}"
    jest_env["JEST_JUNIT_SUITE_NAME"] = "{filepath}"
    jest_env["JEST_JUNIT_ADD_FILE_ATTRIBUTE"] = "true"
    # Include console.log output in JUnit XML for timing marker parsing
    jest_env["JEST_JUNIT_INCLUDE_CONSOLE_OUTPUT"] = "true"
    # Set codeflash output file for the jest helper to write timing/behavior data (SQLite format)
    # Use candidate_index to differentiate between baseline (0) and optimization candidates
    codeflash_sqlite_file = get_run_tmp_file(Path(f"test_return_values_{candidate_index}.sqlite"))
    jest_env["CODEFLASH_OUTPUT_FILE"] = str(codeflash_sqlite_file)
    jest_env["CODEFLASH_TEST_ITERATION"] = str(candidate_index)
    jest_env["CODEFLASH_LOOP_INDEX"] = "1"
    jest_env["CODEFLASH_MODE"] = "behavior"
    # Seed random number generator for reproducible test runs across original and optimized code
    jest_env["CODEFLASH_RANDOM_SEED"] = "42"

    # Configure ESM support if project uses ES Modules
    _configure_esm_environment(jest_env, effective_cwd)

    logger.debug(f"Running Jest tests with command: {' '.join(jest_cmd)}")

    start_time_ns = time.perf_counter_ns()
    try:
        # Subprocess timeout defaults to a 600s cap when no timeout was given.
        run_args = get_cross_platform_subprocess_run_args(
            cwd=effective_cwd, env=jest_env, timeout=timeout or 600, check=False, text=True, capture_output=True
        )
        result = subprocess.run(jest_cmd, **run_args)  # noqa: PLW1510
        # Jest sends console.log output to stderr by default - move it to stdout
        # so our timing markers (printed via console.log) are in the expected place
        if result.stderr and not result.stdout:
            result = subprocess.CompletedProcess(
                args=result.args, returncode=result.returncode, stdout=result.stderr, stderr=""
            )
        elif result.stderr:
            # Combine stderr into stdout if both have content
            result = subprocess.CompletedProcess(
                args=result.args, returncode=result.returncode, stdout=result.stdout + "\n" + result.stderr, stderr=""
            )
        logger.debug(f"Jest result: returncode={result.returncode}")
    except subprocess.TimeoutExpired:
        logger.warning(f"Jest tests timed out after {timeout}s")
        result = subprocess.CompletedProcess(args=jest_cmd, returncode=-1, stdout="", stderr="Test execution timed out")
    except FileNotFoundError:
        logger.error("Jest not found. Make sure Jest is installed (npm install jest)")
        result = subprocess.CompletedProcess(
            args=jest_cmd, returncode=-1, stdout="", stderr="Jest not found. Run: npm install jest jest-junit"
        )
    finally:
        # Wall-clock logging only; the finally block does not alter the result.
        wall_clock_ns = time.perf_counter_ns() - start_time_ns
        logger.debug(f"Jest behavioral tests completed in {wall_clock_ns / 1e9:.2f}s")

    return result_file_path, result, coverage_json_path, None
|
||||
|
||||
|
||||
def _parse_timing_from_jest_output(stdout: str) -> dict[str, int]:
|
||||
"""Parse timing data from Jest stdout markers.
|
||||
|
||||
Extracts timing information from markers like:
|
||||
!######testModule:testFunc:funcName:loopIndex:invocationId:durationNs######!
|
||||
|
||||
Args:
|
||||
stdout: Jest stdout containing timing markers.
|
||||
|
||||
Returns:
|
||||
Dictionary mapping test case IDs to duration in nanoseconds.
|
||||
|
||||
"""
|
||||
import re
|
||||
|
||||
# Pattern: !######module:testFunc:funcName:loopIndex:invocationId:durationNs######!
|
||||
pattern = re.compile(r"!######([^:]+):([^:]*):([^:]+):([^:]+):([^:]+):(\d+)######!")
|
||||
|
||||
timings: dict[str, int] = {}
|
||||
for match in pattern.finditer(stdout):
|
||||
module, test_class, func_name, _loop_index, invocation_id, duration_ns = match.groups()
|
||||
# Create test case ID (same format as Python)
|
||||
test_id = f"{module}:{test_class}:{func_name}:{invocation_id}"
|
||||
timings[test_id] = int(duration_ns)
|
||||
|
||||
return timings
|
||||
|
||||
|
||||
def _should_stop_stability(
    runtimes: list[int],
    window: int,
    min_window_size: int,
    center_rel_tol: float = STABILITY_CENTER_TOLERANCE,
    spread_rel_tol: float = STABILITY_SPREAD_TOLERANCE,
) -> bool:
    """Check if performance has stabilized (matches Python's pytest_plugin.should_stop exactly).

    This function implements the same stability criteria as the Python pytest_plugin.py
    to ensure consistent behavior between Python and JavaScript performance testing.

    Args:
        runtimes: List of aggregate runtimes (sum of min per test case).
        window: Size of the window to check for stability.
        min_window_size: Minimum number of data points required.
        center_rel_tol: Center tolerance - all recent points must be within this fraction of median.
        spread_rel_tol: Spread tolerance - (max-min)/min must be within this fraction.

    Returns:
        True if performance has stabilized, False otherwise.

    """
    # Need at least a full window and the configured minimum sample count.
    if len(runtimes) < window or len(runtimes) < min_window_size:
        return False

    recent = runtimes[-window:]

    # Use sorted array for faster median and min/max operations
    recent_sorted = sorted(recent)
    r_min, r_max = recent_sorted[0], recent_sorted[-1]

    # Guard BEFORE any division: a zero minimum (and hence a possibly-zero
    # median for non-negative runtimes) would otherwise raise ZeroDivisionError
    # in the centering check below. A zero runtime also can't be "stable".
    if r_min == 0:
        return False

    mid = window // 2
    m = recent_sorted[mid] if window % 2 else (recent_sorted[mid - 1] + recent_sorted[mid]) / 2

    # 1) All recent points close to the median
    centered = all(abs(r - m) / m <= center_rel_tol for r in recent)

    # 2) Window spread is small
    spread_ok = (r_max - r_min) / r_min <= spread_rel_tol

    return centered and spread_ok
|
||||
|
||||
|
||||
def run_jest_benchmarking_tests(
    test_paths: TestFiles,
    test_env: dict[str, str],
    cwd: Path,
    *,
    timeout: int | None = None,
    project_root: Path | None = None,
    min_loops: int = 5,
    max_loops: int = 100,
    target_duration_ms: int = 10_000,  # 10 seconds for benchmarking tests
    stability_check: bool = True,
) -> tuple[Path, subprocess.CompletedProcess]:
    """Run Jest benchmarking tests with in-process session-level looping.

    Uses a custom Jest runner (codeflash/loop-runner) to loop all tests
    within a single Jest process, eliminating process startup overhead.

    This matches Python's pytest_plugin behavior:
    - All tests are run multiple times within a single Jest process
    - Timing data is collected per iteration
    - Stability is checked within the runner

    Never raises on Jest failure: timeouts and a missing Jest binary are
    converted into a synthetic ``CompletedProcess`` with ``returncode=-1``.

    Args:
        test_paths: TestFiles object containing test file information.
        test_env: Environment variables for the test run.
        cwd: Working directory for running tests.
        timeout: Optional timeout in seconds for the entire benchmark run.
        project_root: JavaScript project root (directory containing package.json).
        min_loops: Minimum number of loop iterations.
        max_loops: Maximum number of loop iterations.
        target_duration_ms: Target TOTAL duration in milliseconds for all loops.
        stability_check: Whether to enable stability-based early stopping.

    Returns:
        Tuple of (result_file_path, subprocess_result with stdout from all iterations).

    """
    result_file_path = get_run_tmp_file(Path("jest_perf_results.xml"))

    # Get performance test files (entries without a benchmarking file are skipped)
    test_files = [str(file.benchmarking_file_path) for file in test_paths.test_files if file.benchmarking_file_path]

    # Use provided project_root, or detect it as fallback
    if project_root is None and test_files:
        first_test_file = Path(test_files[0])
        project_root = _find_node_project_root(first_test_file)

    effective_cwd = project_root if project_root else cwd
    logger.debug(f"Jest benchmarking working directory: {effective_cwd}")

    # Ensure the codeflash npm package is installed
    _ensure_runtime_files(effective_cwd)

    # Build Jest command for performance tests with custom loop runner
    jest_cmd = [
        "npx",
        "jest",
        "--reporters=default",
        "--reporters=jest-junit",
        "--runInBand",  # Ensure serial execution even though runner enforces it
        "--forceExit",
        "--runner=codeflash/loop-runner",  # Use custom loop runner for in-process looping
    ]

    if test_files:
        jest_cmd.append("--runTestsByPath")
        jest_cmd.extend(str(Path(f).resolve()) for f in test_files)

    if timeout:
        jest_cmd.append(f"--testTimeout={timeout * 1000}")  # Jest uses milliseconds

    # Base environment setup (same jest-junit configuration as the behavioral runner)
    jest_env = test_env.copy()
    jest_env["JEST_JUNIT_OUTPUT_FILE"] = str(result_file_path)
    jest_env["JEST_JUNIT_OUTPUT_DIR"] = str(result_file_path.parent)
    jest_env["JEST_JUNIT_OUTPUT_NAME"] = result_file_path.name
    jest_env["JEST_JUNIT_CLASSNAME"] = "{filepath}"
    jest_env["JEST_JUNIT_SUITE_NAME"] = "{filepath}"
    jest_env["JEST_JUNIT_ADD_FILE_ATTRIBUTE"] = "true"
    jest_env["JEST_JUNIT_INCLUDE_CONSOLE_OUTPUT"] = "true"
    # Benchmarking always writes to the candidate-0 SQLite file.
    codeflash_sqlite_file = get_run_tmp_file(Path("test_return_values_0.sqlite"))
    jest_env["CODEFLASH_OUTPUT_FILE"] = str(codeflash_sqlite_file)
    jest_env["CODEFLASH_TEST_ITERATION"] = "0"
    jest_env["CODEFLASH_MODE"] = "performance"
    jest_env["CODEFLASH_RANDOM_SEED"] = "42"

    # Internal loop configuration for capturePerf (eliminates Jest environment overhead)
    # Looping happens inside capturePerf() for maximum efficiency
    jest_env["CODEFLASH_PERF_LOOP_COUNT"] = str(max_loops)
    jest_env["CODEFLASH_PERF_MIN_LOOPS"] = str(min_loops)
    jest_env["CODEFLASH_PERF_TARGET_DURATION_MS"] = str(target_duration_ms)
    jest_env["CODEFLASH_PERF_STABILITY_CHECK"] = "true" if stability_check else "false"
    jest_env["CODEFLASH_LOOP_INDEX"] = "1"  # Initial value for compatibility

    # Configure ESM support if project uses ES Modules
    _configure_esm_environment(jest_env, effective_cwd)

    # Total timeout for the entire benchmark run (longer than single-loop timeout)
    # Account for startup overhead + target duration + buffer; never below 120s.
    total_timeout = max(120, (target_duration_ms // 1000) + 60, timeout or 120)

    logger.debug(f"Running Jest benchmarking tests with in-process loop runner: {' '.join(jest_cmd)}")
    logger.debug(
        f"Jest benchmarking config: min_loops={min_loops}, max_loops={max_loops}, "
        f"target_duration={target_duration_ms}ms, stability_check={stability_check}"
    )

    total_start_time = time.time()

    try:
        run_args = get_cross_platform_subprocess_run_args(
            cwd=effective_cwd, env=jest_env, timeout=total_timeout, check=False, text=True, capture_output=True
        )
        result = subprocess.run(jest_cmd, **run_args)  # noqa: PLW1510

        # Combine stderr into stdout for timing markers
        stdout = result.stdout or ""
        if result.stderr:
            stdout = stdout + "\n" + result.stderr if stdout else result.stderr

        # Create result with combined stdout
        result = subprocess.CompletedProcess(args=result.args, returncode=result.returncode, stdout=stdout, stderr="")

    except subprocess.TimeoutExpired:
        logger.warning(f"Jest benchmarking timed out after {total_timeout}s")
        result = subprocess.CompletedProcess(args=jest_cmd, returncode=-1, stdout="", stderr="Benchmarking timed out")
    except FileNotFoundError:
        logger.error("Jest not found for benchmarking")
        result = subprocess.CompletedProcess(args=jest_cmd, returncode=-1, stdout="", stderr="Jest not found")

    wall_clock_seconds = time.time() - total_start_time
    logger.debug(f"Jest benchmarking completed in {wall_clock_seconds:.2f}s")

    return result_file_path, result
|
||||
|
||||
|
||||
def run_jest_line_profile_tests(
    test_paths: TestFiles,
    test_env: dict[str, str],
    cwd: Path,
    *,
    timeout: int | None = None,
    project_root: Path | None = None,
    line_profile_output_file: Path | None = None,
) -> tuple[Path, subprocess.CompletedProcess]:
    """Run Jest tests for line profiling.

    This runs tests against source code that has been instrumented with line profiler.
    The instrumentation collects execution counts and timing per line.

    Never raises on Jest failure: timeouts and a missing Jest binary are
    converted into a synthetic ``CompletedProcess`` with ``returncode=-1``.

    Args:
        test_paths: TestFiles object containing test file information.
        test_env: Environment variables for the test run.
        cwd: Working directory for running tests.
        timeout: Optional timeout in seconds for the subprocess.
        project_root: JavaScript project root (directory containing package.json).
        line_profile_output_file: Path where line profile results will be written.

    Returns:
        Tuple of (result_file_path, subprocess_result).

    """
    result_file_path = get_run_tmp_file(Path("jest_line_profile_results.xml"))

    # Get test files to run - use instrumented behavior files if available, otherwise benchmarking files
    test_files = []
    for file in test_paths.test_files:
        if file.instrumented_behavior_file_path:
            test_files.append(str(file.instrumented_behavior_file_path))
        elif file.benchmarking_file_path:
            test_files.append(str(file.benchmarking_file_path))

    # Use provided project_root, or detect it as fallback
    if project_root is None and test_files:
        first_test_file = Path(test_files[0])
        project_root = _find_node_project_root(first_test_file)

    effective_cwd = project_root if project_root else cwd
    logger.debug(f"Jest line profiling working directory: {effective_cwd}")

    # Ensure the codeflash npm package is installed
    _ensure_runtime_files(effective_cwd)

    # Build Jest command for line profiling - simple run without benchmarking loops
    jest_cmd = [
        "npx",
        "jest",
        "--reporters=default",
        "--reporters=jest-junit",
        "--runInBand",  # Run tests serially for consistent line profiling
        "--forceExit",
    ]

    if test_files:
        jest_cmd.append("--runTestsByPath")
        jest_cmd.extend(str(Path(f).resolve()) for f in test_files)

    if timeout:
        jest_cmd.append(f"--testTimeout={timeout * 1000}")  # Jest uses milliseconds

    # Set up environment (same jest-junit configuration as the behavioral runner)
    jest_env = test_env.copy()
    jest_env["JEST_JUNIT_OUTPUT_FILE"] = str(result_file_path)
    jest_env["JEST_JUNIT_OUTPUT_DIR"] = str(result_file_path.parent)
    jest_env["JEST_JUNIT_OUTPUT_NAME"] = result_file_path.name
    jest_env["JEST_JUNIT_CLASSNAME"] = "{filepath}"
    jest_env["JEST_JUNIT_SUITE_NAME"] = "{filepath}"
    jest_env["JEST_JUNIT_ADD_FILE_ATTRIBUTE"] = "true"
    jest_env["JEST_JUNIT_INCLUDE_CONSOLE_OUTPUT"] = "true"
    # Set codeflash output file for the jest helper
    codeflash_sqlite_file = get_run_tmp_file(Path("test_return_values_line_profile.sqlite"))
    jest_env["CODEFLASH_OUTPUT_FILE"] = str(codeflash_sqlite_file)
    jest_env["CODEFLASH_TEST_ITERATION"] = "0"
    jest_env["CODEFLASH_LOOP_INDEX"] = "1"
    jest_env["CODEFLASH_MODE"] = "line_profile"
    # Seed random number generator for reproducibility
    jest_env["CODEFLASH_RANDOM_SEED"] = "42"
    # Pass the line profile output file path to the instrumented code
    if line_profile_output_file:
        jest_env["CODEFLASH_LINE_PROFILE_OUTPUT"] = str(line_profile_output_file)

    # Configure ESM support if project uses ES Modules
    _configure_esm_environment(jest_env, effective_cwd)

    # Subprocess timeout defaults to a 600s cap when no timeout was given.
    subprocess_timeout = timeout or 600

    logger.debug(f"Running Jest line profile tests: {' '.join(jest_cmd)}")

    start_time_ns = time.perf_counter_ns()
    try:
        run_args = get_cross_platform_subprocess_run_args(
            cwd=effective_cwd, env=jest_env, timeout=subprocess_timeout, check=False, text=True, capture_output=True
        )
        result = subprocess.run(jest_cmd, **run_args)  # noqa: PLW1510
        # Jest sends console.log output to stderr by default - move it to stdout
        if result.stderr and not result.stdout:
            result = subprocess.CompletedProcess(
                args=result.args, returncode=result.returncode, stdout=result.stderr, stderr=""
            )
        elif result.stderr:
            # Combine stderr into stdout if both have content
            result = subprocess.CompletedProcess(
                args=result.args, returncode=result.returncode, stdout=result.stdout + "\n" + result.stderr, stderr=""
            )
        logger.debug(f"Jest line profile result: returncode={result.returncode}")
    except subprocess.TimeoutExpired:
        logger.warning(f"Jest line profile tests timed out after {subprocess_timeout}s")
        result = subprocess.CompletedProcess(
            args=jest_cmd, returncode=-1, stdout="", stderr="Line profile tests timed out"
        )
    except FileNotFoundError:
        logger.error("Jest not found for line profiling")
        result = subprocess.CompletedProcess(args=jest_cmd, returncode=-1, stdout="", stderr="Jest not found")
    finally:
        # Wall-clock logging only; the finally block does not alter the result.
        wall_clock_ns = time.perf_counter_ns() - start_time_ns
        logger.debug(f"Jest line profile tests completed in {wall_clock_ns / 1e9:.2f}s")

    return result_file_path, result
|
||||
Some files were not shown because too many files have changed in this diff Show more
Loading…
Reference in a new issue