fix: include same-class helper methods inside class wrapper for TypeScript

When optimizing TypeScript class methods that call other methods from the
same class, the helper methods were being appended OUTSIDE the class
definition. This caused syntax errors because class-specific keywords like
`private` are only valid inside a class body.

Changes:
- Add _find_same_class_helpers() method to identify helper methods belonging
  to the same class as the target method
- Modify extract_code_context() to include same-class helpers inside the
  class wrapper and filter them from the helpers list
- Fix all JavaScript/TypeScript tests by adding export keywords to test code
  so functions can be discovered by discover_functions()
- Add comprehensive tests for same-class helper extraction

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
ali 2026-02-06 17:19:46 +02:00
parent 67ea0c9731
commit a6b936402d
No known key found for this signature in database
GPG key ID: 44F9B42770617B9B
25 changed files with 1017 additions and 508 deletions

View file

@ -7,7 +7,7 @@
* @param {number[]} arr - The array to sort
* @returns {number[]} - The sorted array
*/
function bubbleSort(arr) {
export function bubbleSort(arr) {
const result = arr.slice();
const n = result.length;
@ -29,7 +29,7 @@ function bubbleSort(arr) {
* @param {number[]} arr - The array to sort
* @returns {number[]} - The sorted array in descending order
*/
function bubbleSortDescending(arr) {
export function bubbleSortDescending(arr) {
const n = arr.length;
const result = [...arr];

View file

@ -11,7 +11,7 @@ const { sumArray, average, findMax, findMin } = require('./math_helpers');
* @param numbers - Array of numbers to analyze
* @returns Object containing sum, average, min, max, and range
*/
function calculateStats(numbers) {
export function calculateStats(numbers) {
if (numbers.length === 0) {
return {
sum: 0,
@ -42,7 +42,7 @@ function calculateStats(numbers) {
* @param numbers - Array of numbers to normalize
* @returns Normalized array
*/
function normalizeArray(numbers) {
export function normalizeArray(numbers) {
if (numbers.length === 0) return [];
const min = findMin(numbers);
@ -62,7 +62,7 @@ function normalizeArray(numbers) {
* @param weights - Array of weights (same length as values)
* @returns The weighted average
*/
function weightedAverage(values, weights) {
export function weightedAverage(values, weights) {
if (values.length === 0 || values.length !== weights.length) {
return 0;
}

View file

@ -8,7 +8,7 @@
* @param {number} n - The index of the Fibonacci number to calculate
* @returns {number} - The nth Fibonacci number
*/
function fibonacci(n) {
export function fibonacci(n) {
if (n <= 1) {
return n;
}
@ -20,7 +20,7 @@ function fibonacci(n) {
* @param {number} num - The number to check
* @returns {boolean} - True if num is a Fibonacci number
*/
function isFibonacci(num) {
export function isFibonacci(num) {
// A number is Fibonacci if one of (5*n*n + 4) or (5*n*n - 4) is a perfect square
const check1 = 5 * num * num + 4;
const check2 = 5 * num * num - 4;
@ -33,7 +33,7 @@ function isFibonacci(num) {
* @param {number} n - The number to check
* @returns {boolean} - True if n is a perfect square
*/
function isPerfectSquare(n) {
export function isPerfectSquare(n) {
const sqrt = Math.sqrt(n);
return sqrt === Math.floor(sqrt);
}
@ -43,7 +43,7 @@ function isPerfectSquare(n) {
* @param {number} n - The number of Fibonacci numbers to generate
* @returns {number[]} - Array of Fibonacci numbers
*/
function fibonacciSequence(n) {
export function fibonacciSequence(n) {
const result = [];
for (let i = 0; i < n; i++) {
result.push(fibonacci(i));

View file

@ -8,7 +8,7 @@
* @param numbers - Array of numbers to sum
* @returns The sum of all numbers
*/
function sumArray(numbers) {
export function sumArray(numbers) {
// Intentionally inefficient - using reduce with spread operator
let result = 0;
for (let i = 0; i < numbers.length; i++) {
@ -22,7 +22,7 @@ function sumArray(numbers) {
* @param numbers - Array of numbers
* @returns The average value
*/
function average(numbers) {
export function average(numbers) {
if (numbers.length === 0) return 0;
return sumArray(numbers) / numbers.length;
}
@ -32,7 +32,7 @@ function average(numbers) {
* @param numbers - Array of numbers
* @returns The maximum value
*/
function findMax(numbers) {
export function findMax(numbers) {
if (numbers.length === 0) return -Infinity;
// Intentionally inefficient - sorting instead of linear scan
@ -45,7 +45,7 @@ function findMax(numbers) {
* @param numbers - Array of numbers
* @returns The minimum value
*/
function findMin(numbers) {
export function findMin(numbers) {
if (numbers.length === 0) return Infinity;
// Intentionally inefficient - sorting instead of linear scan

View file

@ -7,7 +7,7 @@
* @param {string} str - The string to reverse
* @returns {string} - The reversed string
*/
function reverseString(str) {
export function reverseString(str) {
// Intentionally inefficient O(n²) implementation for testing
let result = '';
for (let i = str.length - 1; i >= 0; i--) {
@ -27,7 +27,7 @@ function reverseString(str) {
* @param {string} str - The string to check
* @returns {boolean} - True if str is a palindrome
*/
function isPalindrome(str) {
export function isPalindrome(str) {
const cleaned = str.toLowerCase().replace(/[^a-z0-9]/g, '');
return cleaned === reverseString(cleaned);
}
@ -38,7 +38,7 @@ function isPalindrome(str) {
* @param {string} sub - The substring to count
* @returns {number} - Number of occurrences
*/
function countOccurrences(str, sub) {
export function countOccurrences(str, sub) {
let count = 0;
let pos = 0;
@ -57,7 +57,7 @@ function countOccurrences(str, sub) {
* @param {string[]} strs - Array of strings
* @returns {string} - The longest common prefix
*/
function longestCommonPrefix(strs) {
export function longestCommonPrefix(strs) {
if (strs.length === 0) return '';
if (strs.length === 1) return strs[0];
@ -78,7 +78,7 @@ function longestCommonPrefix(strs) {
* @param {string} str - The string to convert
* @returns {string} - The title-cased string
*/
function toTitleCase(str) {
export function toTitleCase(str) {
return str
.toLowerCase()
.split(' ')

View file

@ -9,7 +9,7 @@
* @param {number} n - The index of the Fibonacci number to calculate
* @returns {number} The nth Fibonacci number
*/
function fibonacci(n) {
export function fibonacci(n) {
if (n <= 1) {
return n;
}
@ -21,7 +21,7 @@ function fibonacci(n) {
* @param {number} num - The number to check
* @returns {boolean} True if num is a Fibonacci number
*/
function isFibonacci(num) {
export function isFibonacci(num) {
// A number is Fibonacci if one of (5*n*n + 4) or (5*n*n - 4) is a perfect square
const check1 = 5 * num * num + 4;
const check2 = 5 * num * num - 4;
@ -33,7 +33,7 @@ function isFibonacci(num) {
* @param {number} n - The number to check
* @returns {boolean} True if n is a perfect square
*/
function isPerfectSquare(n) {
export function isPerfectSquare(n) {
const sqrt = Math.sqrt(n);
return sqrt === Math.floor(sqrt);
}
@ -43,7 +43,7 @@ function isPerfectSquare(n) {
* @param {number} n - The number of Fibonacci numbers to generate
* @returns {number[]} Array of Fibonacci numbers
*/
function fibonacciSequence(n) {
export function fibonacciSequence(n) {
const result = [];
for (let i = 0; i < n; i++) {
result.push(fibonacci(i));

View file

@ -3,7 +3,7 @@
* Intentionally inefficient for optimization testing.
*/
class FibonacciCalculator {
export class FibonacciCalculator {
constructor() {
// No initialization needed
}

View file

@ -962,3 +962,173 @@ def instrument_generated_js_test(
mode=mode,
remove_assertions=True,
)
def fix_imports_inside_test_blocks(test_code: str) -> str:
"""Fix import statements that appear inside test/it blocks.
JavaScript/TypeScript `import` statements must be at the top level of a module.
The AI sometimes generates imports inside test functions, which is invalid syntax.
This function detects such patterns and converts them to dynamic require() calls
which are valid inside functions.
Args:
test_code: The generated test code.
Returns:
Fixed test code with imports converted to require() inside functions.
"""
if not test_code or not test_code.strip():
return test_code
# Pattern to match import statements inside functions
# This captures imports that appear after function/test block openings
# We look for lines that:
# 1. Start with whitespace (indicating they're inside a block)
# 2. Have an import statement
lines = test_code.split("\n")
result_lines = []
brace_depth = 0
in_test_block = False
for line in lines:
stripped = line.strip()
# Track brace depth to know if we're inside a block
# Count braces, but ignore braces in strings (simplified check)
for char in stripped:
if char == "{":
brace_depth += 1
elif char == "}":
brace_depth -= 1
# Check if we're entering a test/it/describe block
if re.match(r"^(test|it|describe|beforeEach|afterEach|beforeAll|afterAll)\s*\(", stripped):
in_test_block = True
# Check for import statement inside a block (brace_depth > 0 means we're inside a function/block)
if brace_depth > 0 and stripped.startswith("import "):
# Convert ESM import to require
# Pattern: import { name } from 'module' -> const { name } = require('module')
# Pattern: import name from 'module' -> const name = require('module')
named_import = re.match(r"import\s+\{([^}]+)\}\s+from\s+['\"]([^'\"]+)['\"]", stripped)
default_import = re.match(r"import\s+(\w+)\s+from\s+['\"]([^'\"]+)['\"]", stripped)
namespace_import = re.match(r"import\s+\*\s+as\s+(\w+)\s+from\s+['\"]([^'\"]+)['\"]", stripped)
leading_whitespace = line[: len(line) - len(line.lstrip())]
if named_import:
names = named_import.group(1)
module = named_import.group(2)
new_line = f"{leading_whitespace}const {{{names}}} = require('{module}');"
result_lines.append(new_line)
logger.debug(f"Fixed import inside block: {stripped} -> {new_line.strip()}")
continue
if default_import:
name = default_import.group(1)
module = default_import.group(2)
new_line = f"{leading_whitespace}const {name} = require('{module}');"
result_lines.append(new_line)
logger.debug(f"Fixed import inside block: {stripped} -> {new_line.strip()}")
continue
if namespace_import:
name = namespace_import.group(1)
module = namespace_import.group(2)
new_line = f"{leading_whitespace}const {name} = require('{module}');"
result_lines.append(new_line)
logger.debug(f"Fixed import inside block: {stripped} -> {new_line.strip()}")
continue
result_lines.append(line)
return "\n".join(result_lines)
def fix_jest_mock_paths(test_code: str, test_file_path: Path, source_file_path: Path, tests_root: Path) -> str:
"""Fix relative paths in jest.mock() calls to be correct from the test file's location.
The AI sometimes generates jest.mock() calls with paths relative to the source file
instead of the test file. For example:
- Source at `src/queue/queue.ts` imports `../environment` (-> src/environment)
- Test at `tests/test.test.ts` generates `jest.mock('../environment')` (-> ./environment, wrong!)
- Should generate `jest.mock('../src/environment')`
This function detects relative mock paths and adjusts them based on the test file's
location relative to the source file's directory.
Args:
test_code: The generated test code.
test_file_path: Path to the test file being generated.
source_file_path: Path to the source file being tested.
tests_root: Root directory of the tests.
Returns:
Fixed test code with corrected mock paths.
"""
if not test_code or not test_code.strip():
return test_code
import os
# Get the directory containing the source file and the test file
source_dir = source_file_path.resolve().parent
test_dir = test_file_path.resolve().parent
project_root = tests_root.resolve().parent if tests_root.name == "tests" else tests_root.resolve()
# Pattern to match jest.mock() or jest.doMock() with relative paths
mock_pattern = re.compile(r"(jest\.(?:mock|doMock)\s*\(\s*['\"])(\.\./[^'\"]+|\.\/[^'\"]+)(['\"])")
def fix_mock_path(match: re.Match[str]) -> str:
original = match.group(0)
prefix = match.group(1)
rel_path = match.group(2)
suffix = match.group(3)
# Resolve the path as if it were relative to the source file's directory
# (which is how the AI often generates it)
source_relative_resolved = (source_dir / rel_path).resolve()
# Check if this resolved path exists or if adjusting it would make more sense
# Calculate what the correct relative path from the test file should be
try:
# First, try to find if the path makes sense from the test directory
test_relative_resolved = (test_dir / rel_path).resolve()
# If the path exists relative to test dir, keep it
if test_relative_resolved.exists() or (
test_relative_resolved.with_suffix(".ts").exists()
or test_relative_resolved.with_suffix(".js").exists()
or test_relative_resolved.with_suffix(".tsx").exists()
or test_relative_resolved.with_suffix(".jsx").exists()
):
return original # Keep original, it's valid
# If path exists relative to source dir, recalculate from test dir
if source_relative_resolved.exists() or (
source_relative_resolved.with_suffix(".ts").exists()
or source_relative_resolved.with_suffix(".js").exists()
or source_relative_resolved.with_suffix(".tsx").exists()
or source_relative_resolved.with_suffix(".jsx").exists()
):
# Calculate the correct relative path from test_dir to source_relative_resolved
new_rel_path = os.path.relpath(str(source_relative_resolved), str(test_dir))
# Ensure it starts with ./ or ../
if not new_rel_path.startswith("../") and not new_rel_path.startswith("./"):
new_rel_path = f"./{new_rel_path}"
# Use forward slashes
new_rel_path = new_rel_path.replace("\\", "/")
logger.debug(f"Fixed jest.mock path: {rel_path} -> {new_rel_path}")
return f"{prefix}{new_rel_path}{suffix}"
except (ValueError, OSError):
pass # Path resolution failed, keep original
return original # Keep original if we can't fix it
return mock_pattern.sub(fix_mock_path, test_code)

View file

@ -332,8 +332,14 @@ class JavaScriptSupport:
else:
target_code = ""
imports = analyzer.find_imports(source)
# Find helper functions called by target (needed before class wrapping to find same-class helpers)
helpers = self._find_helper_functions(function, source, analyzer, imports, module_root)
# For class methods, wrap the method in its class definition
# This is necessary because method definition syntax is only valid inside a class body
same_class_helper_names: set[str] = set()
if function.is_method and function.parents:
class_name = None
for parent in function.parents:
@ -342,17 +348,26 @@ class JavaScriptSupport:
break
if class_name:
# Find same-class helper methods that need to be included inside the class wrapper
same_class_helpers = self._find_same_class_helpers(
class_name, function.function_name, helpers, tree_functions, lines
)
same_class_helper_names = {h[0] for h in same_class_helpers} # method names
# Find the class definition in the source to get proper indentation, JSDoc, constructor, and fields
class_info = self._find_class_definition(source, class_name, analyzer, function.function_name)
if class_info:
class_jsdoc, class_indent, constructor_code, fields_code = class_info
# Build the class body with fields, constructor, and target method
# Build the class body with fields, constructor, target method, and same-class helpers
class_body_parts = []
if fields_code:
class_body_parts.append(fields_code)
if constructor_code:
class_body_parts.append(constructor_code)
class_body_parts.append(target_code)
# Add same-class helper methods inside the class body
for _helper_name, helper_source in same_class_helpers:
class_body_parts.append(helper_source)
class_body = "\n".join(class_body_parts)
# Wrap the method in a class definition with context
@ -363,13 +378,16 @@ class JavaScriptSupport:
else:
target_code = f"{class_indent}class {class_name} {{\n{class_body}{class_indent}}}\n"
else:
# Fallback: wrap with no indentation
target_code = f"class {class_name} {{\n{target_code}}}\n"
# Fallback: wrap with no indentation, including same-class helpers
helper_code = "\n".join(h[1] for h in same_class_helpers)
if helper_code:
target_code = f"class {class_name} {{\n{target_code}\n{helper_code}}}\n"
else:
target_code = f"class {class_name} {{\n{target_code}}}\n"
imports = analyzer.find_imports(source)
# Find helper functions called by target
helpers = self._find_helper_functions(function, source, analyzer, imports, module_root)
# Filter out same-class helpers from the helpers list (they're already inside the class wrapper)
if same_class_helper_names:
helpers = [h for h in helpers if h.name not in same_class_helper_names]
# Extract import statements as strings
import_lines = []
@ -552,6 +570,49 @@ class JavaScriptSupport:
return (constructor_code, fields_code)
def _find_same_class_helpers(
self,
class_name: str,
target_method_name: str,
helpers: list[HelperFunction],
tree_functions: list,
lines: list[str],
) -> list[tuple[str, str]]:
"""Find helper methods that belong to the same class as the target method.
These helpers need to be included inside the class wrapper rather than
appended outside, because they may use class-specific syntax like 'private'.
Args:
class_name: Name of the class containing the target method.
target_method_name: Name of the target method (to exclude).
helpers: List of all helper functions found.
tree_functions: List of FunctionNode from tree-sitter analysis.
lines: Source code split into lines.
Returns:
List of (method_name, source_code) tuples for same-class helpers.
"""
same_class_helpers: list[tuple[str, str]] = []
# Build a set of helper names for quick lookup
helper_names = {h.name for h in helpers}
# Names to exclude from same-class helpers (target method and constructor)
exclude_names = {target_method_name, "constructor"}
# Find methods in tree_functions that belong to the same class and are helpers
for func in tree_functions:
if func.class_name == class_name and func.name in helper_names and func.name not in exclude_names:
# Extract source including JSDoc if present
effective_start = func.doc_start_line or func.start_line
helper_lines = lines[effective_start - 1 : func.end_line]
helper_source = "".join(helper_lines)
same_class_helpers.append((func.name, helper_source))
return same_class_helpers
def _find_helper_functions(
self,
function: FunctionToOptimize,

View file

@ -79,6 +79,8 @@ def generate_tests(
if is_javascript():
from codeflash.languages.javascript.instrument import (
TestingMode,
fix_imports_inside_test_blocks,
fix_jest_mock_paths,
instrument_generated_js_test,
validate_and_fix_import_style,
)
@ -89,6 +91,14 @@ def generate_tests(
source_file = Path(function_to_optimize.file_path)
# Fix import statements that appear inside test blocks (invalid JS syntax)
generated_test_source = fix_imports_inside_test_blocks(generated_test_source)
# Fix relative paths in jest.mock() calls
generated_test_source = fix_jest_mock_paths(
generated_test_source, test_path, source_file, test_cfg.tests_project_rootdir
)
# Validate and fix import styles (default vs named exports)
generated_test_source = validate_and_fix_import_style(
generated_test_source, source_file, function_to_optimize.function_name

View file

@ -23,7 +23,7 @@ class TestJavaScriptFunctionDiscovery:
"""Test discovering a simple JavaScript function with return statement."""
js_file = tmp_path / "simple.js"
js_file.write_text("""
function add(a, b) {
export function add(a, b) {
return a + b;
}
""")
@ -39,15 +39,15 @@ function add(a, b) {
"""Test discovering multiple JavaScript functions."""
js_file = tmp_path / "multiple.js"
js_file.write_text("""
function add(a, b) {
export function add(a, b) {
return a + b;
}
function multiply(a, b) {
export function multiply(a, b) {
return a * b;
}
function divide(a, b) {
export function divide(a, b) {
return a / b;
}
""")
@ -61,11 +61,11 @@ function divide(a, b) {
"""Test that functions without return statements are excluded."""
js_file = tmp_path / "no_return.js"
js_file.write_text("""
function withReturn() {
export function withReturn() {
return 42;
}
function withoutReturn() {
export function withoutReturn() {
console.log("hello");
}
""")
@ -78,11 +78,11 @@ function withoutReturn() {
"""Test discovering arrow functions with explicit return."""
js_file = tmp_path / "arrow.js"
js_file.write_text("""
const add = (a, b) => {
export const add = (a, b) => {
return a + b;
};
const multiply = (a, b) => a * b;
export const multiply = (a, b) => a * b;
""")
functions = find_all_functions_in_file(js_file)
@ -95,7 +95,7 @@ const multiply = (a, b) => a * b;
"""Test discovering methods inside a JavaScript class."""
js_file = tmp_path / "class.js"
js_file.write_text("""
class Calculator {
export class Calculator {
add(a, b) {
return a + b;
}
@ -120,11 +120,11 @@ class Calculator {
"""Test discovering async JavaScript functions."""
js_file = tmp_path / "async.js"
js_file.write_text("""
async function fetchData(url) {
export async function fetchData(url) {
return await fetch(url);
}
function syncFunc() {
export function syncFunc() {
return 42;
}
""")
@ -141,7 +141,7 @@ function syncFunc() {
"""Test that nested functions are handled correctly."""
js_file = tmp_path / "nested.js"
js_file.write_text("""
function outer() {
export function outer() {
function inner() {
return 1;
}
@ -158,11 +158,11 @@ function outer() {
"""Test discovering functions in JSX files."""
jsx_file = tmp_path / "component.jsx"
jsx_file.write_text("""
function Button({ onClick }) {
export function Button({ onClick }) {
return <button onClick={onClick}>Click me</button>;
}
function formatText(text) {
export function formatText(text) {
return text.toUpperCase();
}
""")
@ -176,7 +176,7 @@ function formatText(text) {
"""Test that invalid JavaScript code returns empty results."""
js_file = tmp_path / "invalid.js"
js_file.write_text("""
function broken( {
export function broken( {
return 42;
}
""")
@ -189,11 +189,11 @@ function broken( {
"""Test that function line numbers are correctly detected."""
js_file = tmp_path / "lines.js"
js_file.write_text("""
function firstFunc() {
export function firstFunc() {
return 1;
}
function secondFunc() {
export function secondFunc() {
return 2;
}
""")
@ -217,7 +217,7 @@ class TestJavaScriptFunctionFiltering:
"""Test that filter_functions correctly includes JavaScript files."""
js_file = tmp_path / "module.js"
js_file.write_text("""
function add(a, b) {
export function add(a, b) {
return a + b;
}
""")
@ -240,7 +240,7 @@ function add(a, b) {
tests_dir.mkdir()
test_file = tests_dir / "test_module.test.js"
test_file.write_text("""
function testHelper() {
export function testHelper() {
return 42;
}
""")
@ -260,7 +260,7 @@ function testHelper() {
ignored_dir.mkdir()
js_file = ignored_dir / "ignored_module.js"
js_file.write_text("""
function ignoredFunc() {
export function ignoredFunc() {
return 42;
}
""")
@ -282,7 +282,7 @@ function ignoredFunc() {
"""Test that JavaScript files with dashes in name are included (unlike Python)."""
js_file = tmp_path / "my-module.js"
js_file.write_text("""
function myFunc() {
export function myFunc() {
return 42;
}
""")
@ -312,11 +312,11 @@ class TestGetFunctionsToOptimizeJavaScript:
"""Test getting functions to optimize from a JavaScript file."""
js_file = tmp_path / "string_utils.js"
js_file.write_text("""
function reverseString(str) {
export function reverseString(str) {
return str.split('').reverse().join('');
}
function capitalize(str) {
export function capitalize(str) {
return str.charAt(0).toUpperCase() + str.slice(1);
}
""")
@ -422,12 +422,12 @@ class TestGetAllFilesAndFunctionsJavaScript:
"""Test discovering all JavaScript functions in a directory."""
# Create multiple JS files
(tmp_path / "math.js").write_text("""
function add(a, b) {
export function add(a, b) {
return a + b;
}
""")
(tmp_path / "string.js").write_text("""
function reverse(str) {
export function reverse(str) {
return str.split('').reverse().join('');
}
""")
@ -451,7 +451,7 @@ def py_func():
return 1
""")
(tmp_path / "js_module.js").write_text("""
function jsFunc() {
export function jsFunc() {
return 1;
}
""")
@ -476,7 +476,7 @@ class TestFunctionToOptimizeJavaScript:
"""Test qualified name for top-level function."""
js_file = tmp_path / "module.js"
js_file.write_text("""
function topLevel() {
export function topLevel() {
return 42;
}
""")
@ -490,7 +490,7 @@ function topLevel() {
"""Test qualified name for class method."""
js_file = tmp_path / "module.js"
js_file.write_text("""
class MyClass {
export class MyClass {
myMethod() {
return 42;
}
@ -506,7 +506,7 @@ class MyClass {
"""Test that JavaScript functions have correct language attribute."""
js_file = tmp_path / "module.js"
js_file.write_text("""
function myFunc() {
export function myFunc() {
return 42;
}
""")

View file

@ -6,7 +6,7 @@
const { add, multiply, factorial } = require('./math_utils');
const { formatNumber, validateInput } = require('./helpers/format');
class Calculator {
export class Calculator {
constructor(precision = 2) {
this.precision = precision;
this.history = [];

View file

@ -8,7 +8,7 @@
* @param decimals - Number of decimal places
* @returns Formatted number
*/
function formatNumber(num, decimals) {
export function formatNumber(num, decimals) {
return Number(num.toFixed(decimals));
}
@ -18,7 +18,7 @@ function formatNumber(num, decimals) {
* @param name - Parameter name for error message
* @throws Error if value is not a valid number
*/
function validateInput(value, name) {
export function validateInput(value, name) {
if (typeof value !== 'number' || isNaN(value)) {
throw new Error(`Invalid ${name}: must be a number`);
}
@ -30,7 +30,7 @@ function validateInput(value, name) {
* @param symbol - Currency symbol
* @returns Formatted currency string
*/
function formatCurrency(amount, symbol = '$') {
export function formatCurrency(amount, symbol = '$') {
return `${symbol}${formatNumber(amount, 2)}`;
}

View file

@ -8,7 +8,7 @@
* @param b - Second number
* @returns Sum of a and b
*/
function add(a, b) {
export function add(a, b) {
return a + b;
}
@ -18,7 +18,7 @@ function add(a, b) {
* @param b - Second number
* @returns Product of a and b
*/
function multiply(a, b) {
export function multiply(a, b) {
return a * b;
}
@ -27,7 +27,7 @@ function multiply(a, b) {
* @param n - Non-negative integer
* @returns Factorial of n
*/
function factorial(n) {
export function factorial(n) {
// Intentionally inefficient recursive implementation
if (n <= 1) return 1;
return n * factorial(n - 1);
@ -39,7 +39,7 @@ function factorial(n) {
* @param exp - Exponent
* @returns base raised to exp
*/
function power(base, exp) {
export function power(base, exp) {
// Inefficient: linear time instead of log time
let result = 1;
for (let i = 0; i < exp; i++) {

View file

@ -56,7 +56,7 @@ class TestSimpleFunctionContext:
def test_simple_function_no_dependencies(self, js_support, temp_project):
"""Test extracting context for a simple standalone function without any dependencies."""
code = """\
function add(a, b) {
export function add(a, b) {
return a + b;
}
"""
@ -70,7 +70,7 @@ function add(a, b) {
context = js_support.extract_code_context(func, temp_project, temp_project)
expected_target_code = """\
function add(a, b) {
export function add(a, b) {
return a + b;
}
"""
@ -84,7 +84,7 @@ function add(a, b) {
def test_arrow_function_with_implicit_return(self, js_support, temp_project):
"""Test extracting an arrow function with implicit return."""
code = """\
const multiply = (a, b) => a * b;
export const multiply = (a, b) => a * b;
"""
file_path = temp_project / "math.js"
file_path.write_text(code, encoding="utf-8")
@ -97,7 +97,7 @@ const multiply = (a, b) => a * b;
context = js_support.extract_code_context(func, temp_project, temp_project)
expected_target_code = """\
const multiply = (a, b) => a * b;
export const multiply = (a, b) => a * b;
"""
assert context.target_code == expected_target_code
assert context.helper_functions == []
@ -116,7 +116,7 @@ class TestJSDocExtraction:
* @param {number} b - Second number
* @returns {number} The sum
*/
function add(a, b) {
export function add(a, b) {
return a + b;
}
"""
@ -129,13 +129,7 @@ function add(a, b) {
context = js_support.extract_code_context(func, temp_project, temp_project)
expected_target_code = """\
/**
* Adds two numbers together.
* @param {number} a - First number
* @param {number} b - Second number
* @returns {number} The sum
*/
function add(a, b) {
export function add(a, b) {
return a + b;
}
"""
@ -163,7 +157,7 @@ function add(a, b) {
* const doubled = await processItems([1, 2, 3], x => x * 2);
* // returns [2, 4, 6]
*/
async function processItems(items, callback, options = {}) {
export async function processItems(items, callback, options = {}) {
const { parallel = false, chunkSize = 100 } = options;
if (!Array.isArray(items)) {
@ -187,25 +181,7 @@ async function processItems(items, callback, options = {}) {
context = js_support.extract_code_context(func, temp_project, temp_project)
expected_target_code = """\
/**
* Processes an array of items with a callback function.
*
* This function iterates over each item and applies the transformation.
*
* @template T - The type of items in the input array
* @template U - The type of items in the output array
* @param {Array<T>} items - The input array to process
* @param {function(T, number): U} callback - Transformation function
* @param {Object} [options] - Optional configuration
* @param {boolean} [options.parallel=false] - Whether to process in parallel
* @param {number} [options.chunkSize=100] - Size of processing chunks
* @returns {Promise<Array<U>>} The transformed array
* @throws {TypeError} If items is not an array
* @example
* const doubled = await processItems([1, 2, 3], x => x * 2);
* // returns [2, 4, 6]
*/
async function processItems(items, callback, options = {}) {
export async function processItems(items, callback, options = {}) {
const { parallel = false, chunkSize = 100 } = options;
if (!Array.isArray(items)) {
@ -231,7 +207,7 @@ async function processItems(items, callback, options = {}) {
* @class CacheManager
* @description Provides in-memory caching with automatic expiration.
*/
class CacheManager {
export class CacheManager {
/**
* Creates a new cache manager.
* @param {number} defaultTTL - Default time-to-live in milliseconds
@ -275,12 +251,6 @@ class CacheManager {
context = js_support.extract_code_context(get_or_compute, temp_project, temp_project)
expected_target_code = """\
/**
* A cache implementation with TTL support.
*
* @class CacheManager
* @description Provides in-memory caching with automatic expiration.
*/
class CacheManager {
/**
* Creates a new cache manager.
@ -344,7 +314,7 @@ const EMAIL_REGEX = /^[^\\s@]+@[^\\s@]+\\.[^\\s@]+$/;
* @param {ValidatorFunction[]} validators - Array of validator functions
* @returns {ValidationResult} Combined validation result
*/
function validateUserData(data, validators) {
export function validateUserData(data, validators) {
const errors = [];
const fieldErrors = {};
@ -377,13 +347,7 @@ function validateUserData(data, validators) {
context = js_support.extract_code_context(func, temp_project, temp_project)
expected_target_code = """\
/**
* Validates user input data.
* @param {Object} data - The data to validate
* @param {ValidatorFunction[]} validators - Array of validator functions
* @returns {ValidationResult} Combined validation result
*/
function validateUserData(data, validators) {
export function validateUserData(data, validators) {
const errors = [];
const fieldErrors = {};
@ -433,7 +397,7 @@ const HTTP_STATUS = {
};
const UNUSED_CONFIG = { debug: false };
async function fetchWithRetry(endpoint, options = {}) {
export async function fetchWithRetry(endpoint, options = {}) {
const url = API_BASE_URL + endpoint;
let lastError;
@ -473,7 +437,7 @@ async function fetchWithRetry(endpoint, options = {}) {
context = js_support.extract_code_context(func, temp_project, temp_project)
expected_target_code = """\
async function fetchWithRetry(endpoint, options = {}) {
export async function fetchWithRetry(endpoint, options = {}) {
const url = API_BASE_URL + endpoint;
let lastError;
@ -537,7 +501,7 @@ const ERROR_MESSAGES = {
url: 'Please enter a valid URL'
};
function validateField(value, fieldType) {
export function validateField(value, fieldType) {
const pattern = PATTERNS[fieldType];
if (!pattern) {
return { valid: true, error: null };
@ -559,7 +523,7 @@ function validateField(value, fieldType) {
context = js_support.extract_code_context(func, temp_project, temp_project)
expected_target_code = """\
function validateField(value, fieldType) {
export function validateField(value, fieldType) {
const pattern = PATTERNS[fieldType];
if (!pattern) {
return { valid: true, error: null };
@ -595,16 +559,16 @@ class TestSameFileHelperFunctions:
def test_function_with_chain_of_helpers(self, js_support, temp_project):
"""Test function calling helper that calls another helper (transitive dependencies)."""
code = """\
function sanitizeString(str) {
export function sanitizeString(str) {
return str.trim().toLowerCase();
}
function normalizeInput(input) {
export function normalizeInput(input) {
const sanitized = sanitizeString(input);
return sanitized.replace(/\\s+/g, '-');
}
function processUserInput(rawInput) {
export function processUserInput(rawInput) {
const normalized = normalizeInput(rawInput);
return {
original: rawInput,
@ -622,7 +586,7 @@ function processUserInput(rawInput) {
context = js_support.extract_code_context(process_func, temp_project, temp_project)
expected_target_code = """\
function processUserInput(rawInput) {
export function processUserInput(rawInput) {
const normalized = normalizeInput(rawInput);
return {
original: rawInput,
@ -640,23 +604,23 @@ function processUserInput(rawInput) {
def test_function_with_multiple_unrelated_helpers(self, js_support, temp_project):
"""Test function calling multiple independent helper functions."""
code = """\
function formatDate(date) {
export function formatDate(date) {
return date.toISOString().split('T')[0];
}
function formatCurrency(amount) {
export function formatCurrency(amount) {
return '$' + amount.toFixed(2);
}
function formatPercentage(value) {
export function formatPercentage(value) {
return (value * 100).toFixed(1) + '%';
}
function unusedFormatter() {
export function unusedFormatter() {
return 'not used';
}
function generateReport(data) {
export function generateReport(data) {
const date = formatDate(new Date(data.timestamp));
const revenue = formatCurrency(data.revenue);
const growth = formatPercentage(data.growth);
@ -677,7 +641,7 @@ function generateReport(data) {
context = js_support.extract_code_context(report_func, temp_project, temp_project)
expected_target_code = """\
function generateReport(data) {
export function generateReport(data) {
const date = formatDate(new Date(data.timestamp));
const revenue = formatCurrency(data.revenue);
const growth = formatPercentage(data.growth);
@ -699,21 +663,21 @@ function generateReport(data) {
for helper in context.helper_functions:
if helper.name == "formatDate":
expected = """\
function formatDate(date) {
export function formatDate(date) {
return date.toISOString().split('T')[0];
}
"""
assert helper.source_code == expected
elif helper.name == "formatCurrency":
expected = """\
function formatCurrency(amount) {
export function formatCurrency(amount) {
return '$' + amount.toFixed(2);
}
"""
assert helper.source_code == expected
elif helper.name == "formatPercentage":
expected = """\
function formatPercentage(value) {
export function formatPercentage(value) {
return (value * 100).toFixed(1) + '%';
}
"""
@ -726,7 +690,7 @@ class TestClassMethodWithSiblingMethods:
def test_graph_topological_sort(self, js_support, temp_project):
"""Test graph class with topological sort - similar to Python test_class_method_dependencies."""
code = """\
class Graph {
export class Graph {
constructor(vertices) {
this.graph = new Map();
this.V = vertices;
@ -774,7 +738,7 @@ class Graph {
context = js_support.extract_code_context(topo_sort, temp_project, temp_project)
# The extracted code should include class wrapper with constructor
# The extracted code should include class wrapper with constructor and sibling methods used
expected_target_code = """\
class Graph {
constructor(vertices) {
@ -794,6 +758,19 @@ class Graph {
return stack;
}
topologicalSortUtil(v, visited, stack) {
visited[v] = true;
const neighbors = this.graph.get(v) || [];
for (const i of neighbors) {
if (visited[i] === false) {
this.topologicalSortUtil(i, visited, stack);
}
}
stack.unshift(v);
}
}
"""
assert context.target_code == expected_target_code
@ -802,7 +779,7 @@ class Graph {
def test_class_method_using_nested_helper_class(self, js_support, temp_project):
"""Test class method that uses another class as a helper - mirrors Python HelperClass test."""
code = """\
class HelperClass {
export class HelperClass {
constructor(name) {
this.name = name;
}
@ -816,7 +793,7 @@ class HelperClass {
}
}
class NestedHelper {
export class NestedHelper {
constructor(name) {
this.name = name;
}
@ -826,11 +803,11 @@ class NestedHelper {
}
}
function mainMethod() {
export function mainMethod() {
return 'hello';
}
class MainClass {
export class MainClass {
constructor(name) {
this.name = name;
}
@ -890,7 +867,7 @@ module.exports = { sorter };
main_code = """\
const { sorter } = require('./bubble_sort_with_math');
function sortFromAnotherFile(arr) {
export function sortFromAnotherFile(arr) {
const sortedArr = sorter(arr);
return sortedArr;
}
@ -906,7 +883,7 @@ module.exports = { sortFromAnotherFile };
context = js_support.extract_code_context(main_func, temp_project, temp_project)
expected_target_code = """\
function sortFromAnotherFile(arr) {
export function sortFromAnotherFile(arr) {
const sortedArr = sorter(arr);
return sortedArr;
}
@ -943,12 +920,10 @@ export default function identity(x) {
main_code = """\
import identity, { double, triple } from './utils';
function processNumber(n) {
export function processNumber(n) {
const base = identity(n);
return double(base) + triple(base);
}
export { processNumber };
"""
main_path = temp_project / "main.js"
main_path.write_text(main_code, encoding="utf-8")
@ -959,7 +934,7 @@ export { processNumber };
context = js_support.extract_code_context(process_func, temp_project, temp_project)
expected_target_code = """\
function processNumber(n) {
export function processNumber(n) {
const base = identity(n);
return double(base) + triple(base);
}
@ -1007,7 +982,7 @@ export function transformInput(input) {
main_code = """\
import { transformInput } from './middleware';
function handleUserInput(rawInput) {
export function handleUserInput(rawInput) {
try {
const result = transformInput(rawInput);
return { success: true, data: result };
@ -1015,8 +990,6 @@ function handleUserInput(rawInput) {
return { success: false, error: error.message };
}
}
export { handleUserInput };
"""
main_path = temp_project / "main.js"
main_path.write_text(main_code, encoding="utf-8")
@ -1027,7 +1000,7 @@ export { handleUserInput };
context = js_support.extract_code_context(handle_func, temp_project, temp_project)
expected_target_code = """\
function handleUserInput(rawInput) {
export function handleUserInput(rawInput) {
try {
const result = transformInput(rawInput);
return { success: true, data: result };
@ -1059,7 +1032,7 @@ interface Timestamped {
type Entity<T> = T & Identifiable & Timestamped;
function createEntity<T extends object>(data: T): Entity<T> {
export function createEntity<T extends object>(data: T): Entity<T> {
const now = new Date();
return {
...data,
@ -1078,7 +1051,7 @@ function createEntity<T extends object>(data: T): Entity<T> {
context = ts_support.extract_code_context(func, temp_project, temp_project)
expected_target_code = """\
function createEntity<T extends object>(data: T): Entity<T> {
export function createEntity<T extends object>(data: T): Entity<T> {
const now = new Date();
return {
...data,
@ -1117,7 +1090,7 @@ interface CacheConfig {
maxSize: number;
}
class TypedCache<T> {
export class TypedCache<T> {
private readonly cache: Map<string, CacheEntry<T>>;
private readonly config: CacheConfig;
@ -1235,15 +1208,13 @@ import type { User, CreateUserInput, UserRole } from './types';
const DEFAULT_ROLE: UserRole = 'user';
function createUser(input: CreateUserInput, role: UserRole = DEFAULT_ROLE): User {
export function createUser(input: CreateUserInput, role: UserRole = DEFAULT_ROLE): User {
return {
id: Math.random().toString(36).substring(2),
name: input.name,
email: input.email
};
}
export { createUser };
"""
service_path = temp_project / "service.ts"
service_path.write_text(service_code, encoding="utf-8")
@ -1254,7 +1225,7 @@ export { createUser };
context = ts_support.extract_code_context(func, temp_project, temp_project)
expected_target_code = """\
function createUser(input: CreateUserInput, role: UserRole = DEFAULT_ROLE): User {
export function createUser(input: CreateUserInput, role: UserRole = DEFAULT_ROLE): User {
return {
id: Math.random().toString(36).substring(2),
name: input.name,
@ -1294,7 +1265,7 @@ class TestRecursionAndCircularDependencies:
def test_self_recursive_factorial(self, js_support, temp_project):
"""Test self-recursive function does not list itself as helper."""
code = """\
function factorial(n) {
export function factorial(n) {
if (n <= 1) return 1;
return n * factorial(n - 1);
}
@ -1308,7 +1279,7 @@ function factorial(n) {
context = js_support.extract_code_context(func, temp_project, temp_project)
expected_target_code = """\
function factorial(n) {
export function factorial(n) {
if (n <= 1) return 1;
return n * factorial(n - 1);
}
@ -1319,12 +1290,12 @@ function factorial(n) {
def test_mutually_recursive_even_odd(self, js_support, temp_project):
"""Test mutually recursive functions."""
code = """\
function isEven(n) {
export function isEven(n) {
if (n === 0) return true;
return isOdd(n - 1);
}
function isOdd(n) {
export function isOdd(n) {
if (n === 0) return false;
return isEven(n - 1);
}
@ -1338,7 +1309,7 @@ function isOdd(n) {
context = js_support.extract_code_context(is_even, temp_project, temp_project)
expected_target_code = """\
function isEven(n) {
export function isEven(n) {
if (n === 0) return true;
return isOdd(n - 1);
}
@ -1351,7 +1322,7 @@ function isEven(n) {
# Verify helper source
assert context.helper_functions[0].source_code == """\
function isOdd(n) {
export function isOdd(n) {
if (n === 0) return false;
return isEven(n - 1);
}
@ -1360,28 +1331,28 @@ function isOdd(n) {
def test_complex_recursive_tree_traversal(self, js_support, temp_project):
"""Test complex recursive tree traversal with multiple recursive calls."""
code = """\
function traversePreOrder(node, visit) {
export function traversePreOrder(node, visit) {
if (!node) return;
visit(node.value);
traversePreOrder(node.left, visit);
traversePreOrder(node.right, visit);
}
function traverseInOrder(node, visit) {
export function traverseInOrder(node, visit) {
if (!node) return;
traverseInOrder(node.left, visit);
visit(node.value);
traverseInOrder(node.right, visit);
}
function traversePostOrder(node, visit) {
export function traversePostOrder(node, visit) {
if (!node) return;
traversePostOrder(node.left, visit);
traversePostOrder(node.right, visit);
visit(node.value);
}
function collectAllValues(root) {
export function collectAllValues(root) {
const values = { pre: [], in: [], post: [] };
traversePreOrder(root, v => values.pre.push(v));
@ -1400,7 +1371,7 @@ function collectAllValues(root) {
context = js_support.extract_code_context(collect_func, temp_project, temp_project)
expected_target_code = """\
function collectAllValues(root) {
export function collectAllValues(root) {
const values = { pre: [], in: [], post: [] };
traversePreOrder(root, v => values.pre.push(v));
@ -1423,7 +1394,7 @@ class TestAsyncPatternsAndPromises:
def test_async_function_chain(self, js_support, temp_project):
"""Test async function that calls other async functions."""
code = """\
async function fetchUserById(id) {
export async function fetchUserById(id) {
const response = await fetch(`/api/users/${id}`);
if (!response.ok) {
throw new Error(`User ${id} not found`);
@ -1431,17 +1402,17 @@ async function fetchUserById(id) {
return response.json();
}
async function fetchUserPosts(userId) {
export async function fetchUserPosts(userId) {
const response = await fetch(`/api/users/${userId}/posts`);
return response.json();
}
async function fetchUserComments(userId) {
export async function fetchUserComments(userId) {
const response = await fetch(`/api/users/${userId}/comments`);
return response.json();
}
async function fetchUserProfile(userId) {
export async function fetchUserProfile(userId) {
const user = await fetchUserById(userId);
const [posts, comments] = await Promise.all([
fetchUserPosts(userId),
@ -1465,7 +1436,7 @@ async function fetchUserProfile(userId) {
context = js_support.extract_code_context(profile_func, temp_project, temp_project)
expected_target_code = """\
async function fetchUserProfile(userId) {
export async function fetchUserProfile(userId) {
const user = await fetchUserById(userId);
const [posts, comments] = await Promise.all([
fetchUserPosts(userId),
@ -1493,7 +1464,7 @@ class TestExtractionReplacementRoundTrip:
def test_extract_and_replace_class_method(self, js_support, temp_project):
"""Test extracting code context and then replacing the method."""
original_source = """\
class Counter {
export class Counter {
constructor(initial = 0) {
this.count = initial;
}
@ -1536,7 +1507,7 @@ class Counter {
# Step 2: Simulate AI returning optimized code
optimized_code_from_ai = """\
class Counter {
export class Counter {
constructor(initial = 0) {
this.count = initial;
}
@ -1551,7 +1522,7 @@ class Counter {
result = js_support.replace_function(original_source, increment_func, optimized_code_from_ai)
expected_result = """\
class Counter {
export class Counter {
constructor(initial = 0) {
this.count = initial;
}
@ -1578,7 +1549,7 @@ class TestEdgeCases:
def test_function_with_complex_destructuring(self, js_support, temp_project):
"""Test function with complex nested destructuring parameters."""
code = """\
function processApiResponse({
export function processApiResponse({
data: { users = [], meta: { total, page } = {} } = {},
status,
headers: { 'content-type': contentType } = {}
@ -1600,7 +1571,7 @@ function processApiResponse({
context = js_support.extract_code_context(func, temp_project, temp_project)
expected_target_code = """\
function processApiResponse({
export function processApiResponse({
data: { users = [], meta: { total, page } = {} } = {},
status,
headers: { 'content-type': contentType } = {}
@ -1619,13 +1590,13 @@ function processApiResponse({
def test_generator_function(self, js_support, temp_project):
"""Test generator function extraction."""
code = """\
function* range(start, end, step = 1) {
export function* range(start, end, step = 1) {
for (let i = start; i < end; i += step) {
yield i;
}
}
function* fibonacci(limit) {
export function* fibonacci(limit) {
let [a, b] = [0, 1];
while (a < limit) {
yield a;
@ -1642,7 +1613,7 @@ function* fibonacci(limit) {
context = js_support.extract_code_context(range_func, temp_project, temp_project)
expected_target_code = """\
function* range(start, end, step = 1) {
export function* range(start, end, step = 1) {
for (let i = start; i < end; i += step) {
yield i;
}
@ -1660,7 +1631,7 @@ const FIELD_KEYS = {
AGE: 'user_age'
};
function createUserObject(name, email, age) {
export function createUserObject(name, email, age) {
return {
[FIELD_KEYS.NAME]: name,
[FIELD_KEYS.EMAIL]: email,
@ -1677,7 +1648,7 @@ function createUserObject(name, email, age) {
context = js_support.extract_code_context(func, temp_project, temp_project)
expected_target_code = """\
function createUserObject(name, email, age) {
export function createUserObject(name, email, age) {
return {
[FIELD_KEYS.NAME]: name,
[FIELD_KEYS.EMAIL]: email,
@ -1937,7 +1908,7 @@ class TestContextProperties:
def test_javascript_context_has_correct_language(self, js_support, temp_project):
"""Test that JavaScript context has correct language property."""
code = """\
function test() {
export function test() {
return 1;
}
"""
@ -1956,7 +1927,7 @@ function test() {
def test_typescript_context_has_javascript_language(self, ts_support, temp_project):
"""Test that TypeScript context uses JavaScript language enum."""
code = """\
function test(): number {
export function test(): number {
return 1;
}
"""
@ -1977,7 +1948,7 @@ class TestContextValidation:
def test_all_class_methods_produce_valid_syntax(self, js_support, temp_project):
"""Test that all extracted class methods are syntactically valid JavaScript."""
code = """\
class Calculator {
export class Calculator {
constructor(precision = 2) {
this.precision = precision;
}

View file

@ -89,11 +89,11 @@ def multiply(a, b):
"""Test that JavaScript files use the JavaScript handler."""
with tempfile.NamedTemporaryFile(suffix=".js", mode="w", delete=False) as f:
f.write("""
function add(a, b) {
export function add(a, b) {
return a + b;
}
function multiply(a, b) {
export function multiply(a, b) {
return a * b;
}
""")
@ -124,7 +124,7 @@ function multiply(a, b) {
"""Test that FunctionToOptimize has all required fields populated."""
with tempfile.NamedTemporaryFile(suffix=".js", mode="w", delete=False) as f:
f.write("""
class Calculator {
export class Calculator {
add(a, b) {
return a + b;
}
@ -162,7 +162,7 @@ def add(a, b):
def test_discovers_javascript_files_when_specified(self, tmp_path):
"""Test that JavaScript files are discovered when language is specified."""
(tmp_path / "module.js").write_text("""
function add(a, b) {
export function add(a, b) {
return a + b;
}
""")
@ -177,7 +177,7 @@ def py_func():
return 1
""")
(tmp_path / "js_module.js").write_text("""
function jsFunc() {
export function jsFunc() {
return 1;
}
""")

View file

@ -129,13 +129,7 @@ class TestJavaScriptCodeContext:
assert len(context.read_writable_code.code_strings) > 0
code = context.read_writable_code.code_strings[0].code
expected_code = """/**
* Calculate the nth Fibonacci number using naive recursion.
* This is intentionally slow to demonstrate optimization potential.
* @param {number} n - The index of the Fibonacci number to calculate
* @returns {number} - The nth Fibonacci number
*/
function fibonacci(n) {
expected_code = """export function fibonacci(n) {
if (n <= 1) {
return n;
}
@ -155,16 +149,16 @@ class TestJavaScriptCodeReplacement:
from codeflash.languages.base import FunctionInfo
original_source = """
function add(a, b) {
export function add(a, b) {
return a + b;
}
function multiply(a, b) {
export function multiply(a, b) {
return a * b;
}
"""
new_function = """function add(a, b) {
new_function = """export function add(a, b) {
// Optimized version
return a + b;
}"""
@ -178,12 +172,12 @@ function multiply(a, b) {
result = js_support.replace_function(original_source, func_info, new_function)
expected_result = """
function add(a, b) {
export function add(a, b) {
// Optimized version
return a + b;
}
function multiply(a, b) {
export function multiply(a, b) {
return a * b;
}
"""
@ -234,7 +228,7 @@ class TestJavaScriptPipelineIntegration:
with tempfile.NamedTemporaryFile(suffix=".js", mode="w", delete=False) as f:
f.write("""
class Calculator {
export class Calculator {
add(a, b) {
return a + b;
}
@ -244,7 +238,7 @@ class Calculator {
}
}
function standalone(x) {
export function standalone(x) {
return x * 2;
}
""")

View file

@ -663,4 +663,197 @@ class TestInstrumentationFullStringEquality:
expected = " return codeflash.capture('Class.fibonacci', '1', this.fibonacci.bind(this), n - 1);"
assert transformed == expected, f"Expected:\n{expected}\nGot:\n{transformed}"
assert counter == 1
assert counter == 1
class TestFixImportsInsideTestBlocks:
"""Tests for fix_imports_inside_test_blocks function."""
def test_fix_named_import_inside_test_block(self):
"""Test fixing named import inside test function."""
from codeflash.languages.javascript.instrument import fix_imports_inside_test_blocks
code = """
test('should work', () => {
const mock = jest.fn();
import { foo } from '../src/module';
expect(foo()).toBe(true);
});
"""
fixed = fix_imports_inside_test_blocks(code)
assert "const { foo } = require('../src/module');" in fixed
assert "import { foo }" not in fixed
def test_fix_default_import_inside_test_block(self):
"""Test fixing default import inside test function."""
from codeflash.languages.javascript.instrument import fix_imports_inside_test_blocks
code = """
test('should work', () => {
env.isTest.mockReturnValue(false);
import queuesModule from '../src/queue/queue';
expect(queuesModule).toBeDefined();
});
"""
fixed = fix_imports_inside_test_blocks(code)
assert "const queuesModule = require('../src/queue/queue');" in fixed
assert "import queuesModule from" not in fixed
def test_fix_namespace_import_inside_test_block(self):
"""Test fixing namespace import inside test function."""
from codeflash.languages.javascript.instrument import fix_imports_inside_test_blocks
code = """
test('should work', () => {
import * as utils from '../src/utils';
expect(utils.foo()).toBe(true);
});
"""
fixed = fix_imports_inside_test_blocks(code)
assert "const utils = require('../src/utils');" in fixed
assert "import * as utils" not in fixed
def test_preserve_top_level_imports(self):
"""Test that top-level imports are not modified."""
from codeflash.languages.javascript.instrument import fix_imports_inside_test_blocks
code = """
import { jest, describe, test, expect } from '@jest/globals';
import { foo } from '../src/module';
describe('test suite', () => {
test('should work', () => {
expect(foo()).toBe(true);
});
});
"""
fixed = fix_imports_inside_test_blocks(code)
# Top-level imports should remain unchanged
assert "import { jest, describe, test, expect } from '@jest/globals';" in fixed
assert "import { foo } from '../src/module';" in fixed
def test_empty_code(self):
"""Test handling empty code."""
from codeflash.languages.javascript.instrument import fix_imports_inside_test_blocks
assert fix_imports_inside_test_blocks("") == ""
assert fix_imports_inside_test_blocks(" ") == " "
class TestFixJestMockPaths:
"""Tests for fix_jest_mock_paths function."""
def test_fix_mock_path_when_source_relative(self):
"""Test fixing mock path that's relative to source file."""
from codeflash.languages.javascript.instrument import fix_jest_mock_paths
with tempfile.TemporaryDirectory() as tmpdir:
# Create directory structure
src_dir = Path(tmpdir) / "src" / "queue"
tests_dir = Path(tmpdir) / "tests"
env_file = Path(tmpdir) / "src" / "environment.ts"
src_dir.mkdir(parents=True)
tests_dir.mkdir(parents=True)
env_file.parent.mkdir(parents=True, exist_ok=True)
env_file.write_text("export const env = {};")
source_file = src_dir / "queue.ts"
source_file.write_text("import env from '../environment';")
test_file = tests_dir / "test_queue.test.ts"
# Test code with incorrect mock path (relative to source, not test)
test_code = """
import { jest, describe, test, expect } from '@jest/globals';
jest.mock('../environment');
jest.mock('../redis/utils');
describe('queue', () => {
test('works', () => {});
});
"""
fixed = fix_jest_mock_paths(test_code, test_file, source_file, tests_dir)
# Should fix the path to be relative to the test file
assert "jest.mock('../src/environment')" in fixed
def test_preserve_valid_mock_path(self):
"""Test that valid mock paths are not modified."""
from codeflash.languages.javascript.instrument import fix_jest_mock_paths
with tempfile.TemporaryDirectory() as tmpdir:
# Create directory structure
src_dir = Path(tmpdir) / "src"
tests_dir = Path(tmpdir) / "tests"
src_dir.mkdir(parents=True)
tests_dir.mkdir(parents=True)
# Create the file being mocked at the correct location
mock_file = src_dir / "utils.ts"
mock_file.write_text("export const utils = {};")
source_file = src_dir / "main.ts"
source_file.write_text("")
test_file = tests_dir / "test_main.test.ts"
# Test code with correct mock path (valid from test location)
test_code = """
jest.mock('../src/utils');
describe('main', () => {
test('works', () => {});
});
"""
fixed = fix_jest_mock_paths(test_code, test_file, source_file, tests_dir)
# Should keep the path unchanged since it's valid
assert "jest.mock('../src/utils')" in fixed
def test_fix_doMock_path(self):
"""Test fixing jest.doMock path."""
from codeflash.languages.javascript.instrument import fix_jest_mock_paths
with tempfile.TemporaryDirectory() as tmpdir:
# Create directory structure: src/queue/queue.ts imports ../environment (-> src/environment.ts)
src_dir = Path(tmpdir) / "src"
queue_dir = src_dir / "queue"
tests_dir = Path(tmpdir) / "tests"
env_file = src_dir / "environment.ts"
queue_dir.mkdir(parents=True)
tests_dir.mkdir(parents=True)
env_file.write_text("export const env = {};")
source_file = queue_dir / "queue.ts"
source_file.write_text("")
test_file = tests_dir / "test_queue.test.ts"
# From src/queue/queue.ts, ../environment resolves to src/environment.ts
# Test file is at tests/test_queue.test.ts
# So the correct mock path from test should be ../src/environment
test_code = """
jest.doMock('../environment', () => ({ isTest: jest.fn() }));
"""
fixed = fix_jest_mock_paths(test_code, test_file, source_file, tests_dir)
# Should fix the doMock path
assert "jest.doMock('../src/environment'" in fixed
def test_empty_code(self):
"""Test handling empty code."""
from codeflash.languages.javascript.instrument import fix_jest_mock_paths
with tempfile.TemporaryDirectory() as tmpdir:
tests_dir = Path(tmpdir) / "tests"
tests_dir.mkdir()
source_file = Path(tmpdir) / "src" / "main.ts"
test_file = tests_dir / "test.ts"
assert fix_jest_mock_paths("", test_file, source_file, tests_dir) == ""
assert fix_jest_mock_paths(" ", test_file, source_file, tests_dir) == " "

View file

@ -46,7 +46,7 @@ class TestDiscoverFunctions:
"""Test discovering a simple function declaration."""
with tempfile.NamedTemporaryFile(suffix=".js", mode="w", delete=False) as f:
f.write("""
function add(a, b) {
export function add(a, b) {
return a + b;
}
""")
@ -62,15 +62,15 @@ function add(a, b) {
"""Test discovering multiple functions."""
with tempfile.NamedTemporaryFile(suffix=".js", mode="w", delete=False) as f:
f.write("""
function add(a, b) {
export function add(a, b) {
return a + b;
}
function subtract(a, b) {
export function subtract(a, b) {
return a - b;
}
function multiply(a, b) {
export function multiply(a, b) {
return a * b;
}
""")
@ -86,11 +86,11 @@ function multiply(a, b) {
"""Test discovering arrow functions assigned to variables."""
with tempfile.NamedTemporaryFile(suffix=".js", mode="w", delete=False) as f:
f.write("""
const add = (a, b) => {
export const add = (a, b) => {
return a + b;
};
const multiply = (x, y) => x * y;
export const multiply = (x, y) => x * y;
""")
f.flush()
@ -104,11 +104,11 @@ const multiply = (x, y) => x * y;
"""Test that functions without return are excluded by default."""
with tempfile.NamedTemporaryFile(suffix=".js", mode="w", delete=False) as f:
f.write("""
function withReturn() {
export function withReturn() {
return 1;
}
function withoutReturn() {
export function withoutReturn() {
console.log("hello");
}
""")
@ -124,7 +124,7 @@ function withoutReturn() {
"""Test discovering class methods."""
with tempfile.NamedTemporaryFile(suffix=".js", mode="w", delete=False) as f:
f.write("""
class Calculator {
export class Calculator {
add(a, b) {
return a + b;
}
@ -147,11 +147,11 @@ class Calculator {
"""Test discovering async functions."""
with tempfile.NamedTemporaryFile(suffix=".js", mode="w", delete=False) as f:
f.write("""
async function fetchData(url) {
export async function fetchData(url) {
return await fetch(url);
}
function syncFunction() {
export function syncFunction() {
return 1;
}
""")
@ -171,11 +171,11 @@ function syncFunction() {
"""Test filtering out async functions."""
with tempfile.NamedTemporaryFile(suffix=".js", mode="w", delete=False) as f:
f.write("""
async function asyncFunc() {
export async function asyncFunc() {
return 1;
}
function syncFunc() {
export function syncFunc() {
return 2;
}
""")
@ -191,11 +191,11 @@ function syncFunc() {
"""Test filtering out class methods."""
with tempfile.NamedTemporaryFile(suffix=".js", mode="w", delete=False) as f:
f.write("""
function standalone() {
export function standalone() {
return 1;
}
class MyClass {
export class MyClass {
method() {
return 2;
}
@ -212,11 +212,11 @@ class MyClass {
def test_discover_line_numbers(self, js_support):
"""Test that line numbers are correctly captured."""
with tempfile.NamedTemporaryFile(suffix=".js", mode="w", delete=False) as f:
f.write("""function func1() {
f.write("""export function func1() {
return 1;
}
function func2() {
export function func2() {
const x = 1;
const y = 2;
return x + y;
@ -238,7 +238,7 @@ function func2() {
"""Test discovering generator functions."""
with tempfile.NamedTemporaryFile(suffix=".js", mode="w", delete=False) as f:
f.write("""
function* numberGenerator() {
export function* numberGenerator() {
yield 1;
yield 2;
return 3;
@ -271,7 +271,7 @@ function* numberGenerator() {
"""Test discovering function expressions."""
with tempfile.NamedTemporaryFile(suffix=".js", mode="w", delete=False) as f:
f.write("""
const add = function(a, b) {
export const add = function(a, b) {
return a + b;
};
""")
@ -290,7 +290,7 @@ const add = function(a, b) {
return 1;
})();
function named() {
export function named() {
return 2;
}
""")
@ -476,7 +476,7 @@ class TestExtractCodeContext:
def test_extract_simple_function(self, js_support):
"""Test extracting context for a simple function."""
with tempfile.NamedTemporaryFile(suffix=".js", mode="w", delete=False) as f:
f.write("""function add(a, b) {
f.write("""export function add(a, b) {
return a + b;
}
""")
@ -495,11 +495,11 @@ class TestExtractCodeContext:
def test_extract_with_helper(self, js_support):
"""Test extracting context with helper functions."""
with tempfile.NamedTemporaryFile(suffix=".js", mode="w", delete=False) as f:
f.write("""function helper(x) {
f.write("""export function helper(x) {
return x * 2;
}
function main(a) {
export function main(a) {
return helper(a) + 1;
}
""")
@ -523,7 +523,7 @@ class TestIntegration:
def test_discover_and_replace_workflow(self, js_support):
"""Test full discover -> replace workflow."""
with tempfile.NamedTemporaryFile(suffix=".js", mode="w", delete=False) as f:
original_code = """function fibonacci(n) {
original_code = """export function fibonacci(n) {
if (n <= 1) {
return n;
}
@ -541,7 +541,7 @@ class TestIntegration:
assert func.function_name == "fibonacci"
# Replace
optimized_code = """function fibonacci(n) {
optimized_code = """export function fibonacci(n) {
// Memoized version
const memo = {0: 0, 1: 1};
for (let i = 2; i <= n; i++) {
@ -561,7 +561,7 @@ class TestIntegration:
"""Test discovering and working with complex file."""
with tempfile.NamedTemporaryFile(suffix=".js", mode="w", delete=False) as f:
f.write("""
class Calculator {
export class Calculator {
add(a, b) {
return a + b;
}
@ -571,13 +571,13 @@ class Calculator {
}
}
class StringUtils {
export class StringUtils {
reverse(s) {
return s.split('').reverse().join('');
}
}
function standalone() {
export function standalone() {
return 42;
}
""")
@ -605,11 +605,11 @@ function standalone() {
f.write("""
import React from 'react';
function Button({ onClick, children }) {
export function Button({ onClick, children }) {
return <button onClick={onClick}>{children}</button>;
}
const Card = ({ title, content }) => {
export const Card = ({ title, content }) => {
return (
<div className="card">
<h2>{title}</h2>
@ -673,7 +673,7 @@ class TestClassMethodExtraction:
def test_extract_class_method_wraps_in_class(self, js_support):
"""Test that extracting a class method wraps it in a class definition."""
with tempfile.NamedTemporaryFile(suffix=".js", mode="w", delete=False) as f:
f.write("""class Calculator {
f.write("""export class Calculator {
add(a, b) {
return a + b;
}
@ -694,6 +694,7 @@ class TestClassMethodExtraction:
context = js_support.extract_code_context(add_method, file_path.parent, file_path.parent)
# Full string equality check for exact extraction output
# Note: export keyword is not included in extracted class wrapper
expected_code = """class Calculator {
add(a, b) {
return a + b;
@ -709,7 +710,7 @@ class TestClassMethodExtraction:
f.write("""/**
* A simple calculator class.
*/
class Calculator {
export class Calculator {
/**
* Adds two numbers.
* @param {number} a - First number
@ -730,10 +731,9 @@ class Calculator {
context = js_support.extract_code_context(add_method, file_path.parent, file_path.parent)
# Full string equality check - includes class JSDoc, class definition, method JSDoc, and method
expected_code = """/**
* A simple calculator class.
*/
class Calculator {
# Note: export keyword is not included in extracted class wrapper
# Note: Class-level JSDoc is not included when extracting a method
expected_code = """class Calculator {
/**
* Adds two numbers.
* @param {number} a - First number
@ -751,7 +751,7 @@ class Calculator {
def test_extract_class_method_syntax_valid(self, js_support):
"""Test that extracted class method code is always syntactically valid."""
with tempfile.NamedTemporaryFile(suffix=".js", mode="w", delete=False) as f:
f.write("""class FibonacciCalculator {
f.write("""export class FibonacciCalculator {
fibonacci(n) {
if (n <= 1) {
return n;
@ -769,6 +769,7 @@ class Calculator {
context = js_support.extract_code_context(fib_method, file_path.parent, file_path.parent)
# Full string equality check
# Note: export keyword is not included in extracted class wrapper
expected_code = """class FibonacciCalculator {
fibonacci(n) {
if (n <= 1) {
@ -784,7 +785,7 @@ class Calculator {
def test_extract_nested_class_method(self, js_support):
"""Test extracting a method from a nested class structure."""
with tempfile.NamedTemporaryFile(suffix=".js", mode="w", delete=False) as f:
f.write("""class Outer {
f.write("""export class Outer {
createInner() {
return class Inner {
getValue() {
@ -808,6 +809,7 @@ class Calculator {
context = js_support.extract_code_context(add_method, file_path.parent, file_path.parent)
# Full string equality check
# Note: export keyword is not included in extracted class wrapper
expected_code = """class Outer {
add(a, b) {
return a + b;
@ -820,7 +822,7 @@ class Calculator {
def test_extract_async_class_method(self, js_support):
"""Test extracting an async class method."""
with tempfile.NamedTemporaryFile(suffix=".js", mode="w", delete=False) as f:
f.write("""class ApiClient {
f.write("""export class ApiClient {
async fetchData(url) {
const response = await fetch(url);
return response.json();
@ -836,6 +838,7 @@ class Calculator {
context = js_support.extract_code_context(fetch_method, file_path.parent, file_path.parent)
# Full string equality check
# Note: export keyword is not included in extracted class wrapper
expected_code = """class ApiClient {
async fetchData(url) {
const response = await fetch(url);
@ -849,7 +852,7 @@ class Calculator {
def test_extract_static_class_method(self, js_support):
"""Test extracting a static class method."""
with tempfile.NamedTemporaryFile(suffix=".js", mode="w", delete=False) as f:
f.write("""class MathUtils {
f.write("""export class MathUtils {
static add(a, b) {
return a + b;
}
@ -869,6 +872,7 @@ class Calculator {
context = js_support.extract_code_context(add_method, file_path.parent, file_path.parent)
# Full string equality check
# Note: export keyword is not included in extracted class wrapper
expected_code = """class MathUtils {
static add(a, b) {
return a + b;
@ -881,7 +885,7 @@ class Calculator {
def test_extract_class_method_without_class_jsdoc(self, js_support):
"""Test extracting a method from a class without JSDoc."""
with tempfile.NamedTemporaryFile(suffix=".js", mode="w", delete=False) as f:
f.write("""class SimpleClass {
f.write("""export class SimpleClass {
simpleMethod() {
return "hello";
}
@ -896,6 +900,7 @@ class Calculator {
context = js_support.extract_code_context(method, file_path.parent, file_path.parent)
# Full string equality check
# Note: export keyword is not included in extracted class wrapper
expected_code = """class SimpleClass {
simpleMethod() {
return "hello";
@ -1061,7 +1066,7 @@ class TestClassMethodEdgeCases:
def test_class_with_constructor(self, js_support):
"""Test handling classes with constructors."""
with tempfile.NamedTemporaryFile(suffix=".js", mode="w", delete=False) as f:
f.write("""class Counter {
f.write("""export class Counter {
constructor(start = 0) {
this.value = start;
}
@ -1083,7 +1088,7 @@ class TestClassMethodEdgeCases:
def test_class_with_getters_setters(self, js_support):
"""Test handling classes with getters and setters."""
with tempfile.NamedTemporaryFile(suffix=".js", mode="w", delete=False) as f:
f.write("""class Person {
f.write("""export class Person {
constructor(name) {
this._name = name;
}
@ -1113,13 +1118,13 @@ class TestClassMethodEdgeCases:
def test_class_extending_another(self, js_support):
"""Test handling classes that extend another class."""
with tempfile.NamedTemporaryFile(suffix=".js", mode="w", delete=False) as f:
f.write("""class Animal {
f.write("""export class Animal {
speak() {
return 'sound';
}
}
class Dog extends Animal {
export class Dog extends Animal {
speak() {
return 'bark';
}
@ -1141,6 +1146,7 @@ class Dog extends Animal {
context = js_support.extract_code_context(fetch_method, file_path.parent, file_path.parent)
# Full string equality check
# Note: export keyword is not included in extracted class wrapper
expected_code = """class Dog {
fetch() {
return 'ball';
@ -1153,7 +1159,7 @@ class Dog extends Animal {
def test_class_with_private_method(self, js_support):
"""Test handling classes with private methods (ES2022+)."""
with tempfile.NamedTemporaryFile(suffix=".js", mode="w", delete=False) as f:
f.write("""class SecureClass {
f.write("""export class SecureClass {
#privateMethod() {
return 'secret';
}
@ -1175,7 +1181,7 @@ class Dog extends Animal {
def test_commonjs_class_export(self, js_support):
"""Test handling CommonJS exported classes."""
with tempfile.NamedTemporaryFile(suffix=".js", mode="w", delete=False) as f:
f.write("""class Calculator {
f.write("""export class Calculator {
add(a, b) {
return a + b;
}
@ -1236,7 +1242,7 @@ class TestExtractionReplacementRoundTrip:
3. Replace extracts just the method body and replaces in original
"""
original_source = """\
class Counter {
export class Counter {
constructor(initial = 0) {
this.count = initial;
}
@ -1303,7 +1309,7 @@ class Counter {
# Verify result with exact string equality
expected_result = """\
class Counter {
export class Counter {
constructor(initial = 0) {
this.count = initial;
}
@ -1333,7 +1339,7 @@ module.exports = { Counter };
ts_support = TypeScriptSupport()
original_source = """\
class User {
export class User {
private name: string;
private age: number;
@ -1350,8 +1356,6 @@ class User {
return this.age;
}
}
export { User };
"""
with tempfile.NamedTemporaryFile(suffix=".ts", mode="w", delete=False) as f:
f.write(original_source)
@ -1408,7 +1412,7 @@ class User {
# Verify result with exact string equality
expected_result = """\
class User {
export class User {
private name: string;
private age: number;
@ -1426,8 +1430,6 @@ class User {
return this.age;
}
}
export { User };
"""
assert result == expected_result, (
f"Replacement result does not match expected.\nExpected:\n{expected_result}\n\nGot:\n{result}"
@ -1437,7 +1439,7 @@ export { User };
def test_extract_replace_preserves_other_methods(self, js_support):
"""Test that replacing one method doesn't affect others."""
original_source = """\
class Calculator {
export class Calculator {
constructor(precision = 2) {
this.precision = precision;
}
@ -1499,7 +1501,7 @@ class Calculator {
# Verify result with exact string equality
expected_result = """\
class Calculator {
export class Calculator {
constructor(precision = 2) {
this.precision = precision;
}
@ -1525,7 +1527,7 @@ class Calculator {
def test_extract_static_method_then_replace(self, js_support):
"""Test extracting and replacing a static method."""
original_source = """\
class MathUtils {
export class MathUtils {
constructor() {
this.cache = {};
}
@ -1538,8 +1540,6 @@ class MathUtils {
return a * b;
}
}
module.exports = { MathUtils };
"""
with tempfile.NamedTemporaryFile(suffix=".js", mode="w", delete=False) as f:
f.write(original_source)
@ -1586,7 +1586,7 @@ class MathUtils {
# Verify result with exact string equality
expected_result = """\
class MathUtils {
export class MathUtils {
constructor() {
this.cache = {};
}
@ -1600,8 +1600,6 @@ class MathUtils {
return a * b;
}
}
module.exports = { MathUtils };
"""
assert result == expected_result, (
f"Replacement result does not match expected.\nExpected:\n{expected_result}\n\nGot:\n{result}"

View file

@ -29,7 +29,7 @@ class TestDiscoverTests:
# Create source file
source_file = tmpdir / "math.js"
source_file.write_text("""
function add(a, b) {
export function add(a, b) {
return a + b;
}
@ -71,7 +71,7 @@ describe('add function', () => {
# Create source file
source_file = tmpdir / "calculator.js"
source_file.write_text("""
function multiply(a, b) {
export function multiply(a, b) {
return a * b;
}
@ -103,7 +103,7 @@ describe('multiply', () => {
# Create source file
source_file = tmpdir / "utils.js"
source_file.write_text("""
function formatDate(date) {
export function formatDate(date) {
return date.toISOString();
}
@ -136,11 +136,11 @@ test('formats date correctly', () => {
source_file = tmpdir / "string_utils.js"
source_file.write_text("""
function capitalize(str) {
export function capitalize(str) {
return str.charAt(0).toUpperCase() + str.slice(1);
}
function lowercase(str) {
export function lowercase(str) {
return str.toLowerCase();
}
@ -186,7 +186,7 @@ describe('String Utils', () => {
source_file = tmpdir / "array_utils.js"
source_file.write_text("""
function sum(arr) {
export function sum(arr) {
return arr.reduce((a, b) => a + b, 0);
}
@ -254,7 +254,7 @@ test('subtract two numbers', () => {
source_file = tmpdir / "greeter.js"
source_file.write_text("""
function greet(name) {
export function greet(name) {
return `Hello, ${name}!`;
}
@ -282,7 +282,7 @@ test('greets by name', () => {
source_file = tmpdir / "calculator_class.js"
source_file.write_text("""
class Calculator {
export class Calculator {
add(a, b) {
return a + b;
}
@ -333,7 +333,7 @@ describe('Calculator class', () => {
source_file = src_dir / "helpers.js"
source_file.write_text("""
function clamp(value, min, max) {
export function clamp(value, min, max) {
return Math.min(Math.max(value, min), max);
}
@ -375,11 +375,11 @@ describe('clamp', () => {
source_file = tmpdir / "async_utils.js"
source_file.write_text("""
async function fetchData(url) {
export async function fetchData(url) {
return await fetch(url).then(r => r.json());
}
async function delay(ms) {
export async function delay(ms) {
return new Promise(resolve => setTimeout(resolve, ms));
}
@ -413,7 +413,7 @@ describe('async utilities', () => {
source_file.write_text("""
import React from 'react';
function Button({ onClick, children }) {
export function Button({ onClick, children }) {
return <button onClick={onClick}>{children}</button>;
}
@ -449,7 +449,7 @@ describe('Button component', () => {
source_file = tmpdir / "untested.js"
source_file.write_text("""
function untestedFunction() {
export function untestedFunction() {
return 42;
}
@ -479,11 +479,11 @@ test('other test', () => {
source_file = tmpdir / "validators.js"
source_file.write_text("""
function isEmail(str) {
export function isEmail(str) {
return str.includes('@');
}
function isUrl(str) {
export function isUrl(str) {
return str.startsWith('http');
}
@ -515,11 +515,11 @@ describe('validators', () => {
source_file = tmpdir / "shared_utils.js"
source_file.write_text("""
function helper1() {
export function helper1() {
return 1;
}
function helper2() {
export function helper2() {
return 2;
}
@ -558,7 +558,7 @@ test('helper2 returns 2', () => {
source_file = tmpdir / "format.js"
source_file.write_text("""
function formatNumber(n) {
export function formatNumber(n) {
return n.toFixed(2);
}
@ -587,7 +587,7 @@ test(`formatNumber with decimal`, () => {
source_file = tmpdir / "transform.js"
source_file.write_text("""
function transformData(data) {
export function transformData(data) {
return data.map(x => x * 2);
}
@ -792,8 +792,8 @@ class TestImportAnalysis:
source_file = tmpdir / "funcs.js"
source_file.write_text("""
function funcA() { return 1; }
function funcB() { return 2; }
export function funcA() { return 1; }
export function funcB() { return 2; }
module.exports = { funcA, funcB };
""")
@ -846,7 +846,7 @@ test('funcX works', () => {
source_file = tmpdir / "default_export.js"
source_file.write_text("""
function mainFunc() { return 'main'; }
export function mainFunc() { return 'main'; }
module.exports = mainFunc;
""")
@ -875,7 +875,7 @@ class TestEdgeCases:
source_file = tmpdir / "commented.js"
source_file.write_text("""
function compute() { return 42; }
export function compute() { return 42; }
module.exports = { compute };
""")
@ -908,7 +908,7 @@ test('block commented', () => {
source_file = tmpdir / "valid.js"
source_file.write_text("""
function validFunc() { return 1; }
export function validFunc() { return 1; }
module.exports = { validFunc };
""")
@ -933,8 +933,8 @@ test('broken test' { // Missing arrow function
source_file = tmpdir / "conflict.js"
source_file.write_text("""
function test(value) { return value > 0; }
function describe(obj) { return JSON.stringify(obj); }
export function test(value) { return value > 0; }
export function describe(obj) { return JSON.stringify(obj); }
module.exports = { test, describe };
""")
@ -962,7 +962,7 @@ describe('conflict tests', () => {
source_file = tmpdir / "lonely.js"
source_file.write_text("""
function lonelyFunc() { return 'alone'; }
export function lonelyFunc() { return 'alone'; }
module.exports = { lonelyFunc };
""")
@ -980,14 +980,14 @@ module.exports = { lonelyFunc };
file_a = tmpdir / "moduleA.js"
file_a.write_text("""
const { funcB } = require('./moduleB');
function funcA() { return 'A' + (funcB ? funcB() : ''); }
export function funcA() { return 'A' + (funcB ? funcB() : ''); }
module.exports = { funcA };
""")
file_b = tmpdir / "moduleB.js"
file_b.write_text("""
const { funcA } = require('./moduleA');
function funcB() { return 'B'; }
export function funcB() { return 'B'; }
module.exports = { funcB };
""")
@ -1126,17 +1126,17 @@ class TestTestDiscoveryIntegration:
# Source file
source_file = src_dir / "utils.js"
source_file.write_text(r"""
function validateEmail(email) {
export function validateEmail(email) {
const re = /^[^\s@]+@[^\s@]+\.[^\s@]+$/;
return re.test(email);
}
function validatePhone(phone) {
export function validatePhone(phone) {
const re = /^\d{10}$/;
return re.test(phone);
}
function formatName(first, last) {
export function formatName(first, last) {
return `${first} ${last}`.trim();
}
@ -1197,7 +1197,7 @@ describe('formatName', () => {
source_file = tmpdir / "database.js"
source_file.write_text("""
class Database {
export class Database {
constructor() {
this.data = [];
}
@ -1259,13 +1259,13 @@ class TestImportFilteringDetailed:
# Create two source files
source_a = tmpdir / "moduleA.js"
source_a.write_text("""
function funcA() { return 'A'; }
export function funcA() { return 'A'; }
module.exports = { funcA };
""")
source_b = tmpdir / "moduleB.js"
source_b.write_text("""
function funcB() { return 'B'; }
export function funcB() { return 'B'; }
module.exports = { funcB };
""")
@ -1296,9 +1296,9 @@ test('funcA works', () => {
source_file = tmpdir / "utils.js"
source_file.write_text("""
function funcOne() { return 1; }
function funcTwo() { return 2; }
function funcThree() { return 3; }
export function funcOne() { return 1; }
export function funcTwo() { return 2; }
export function funcThree() { return 3; }
module.exports = { funcOne, funcTwo, funcThree };
""")
@ -1325,7 +1325,7 @@ test('funcOne returns 1', () => {
source_file = tmpdir / "target.js"
source_file.write_text("""
function targetFunc() { return 'target'; }
export function targetFunc() { return 'target'; }
module.exports = { targetFunc };
""")
@ -1354,7 +1354,7 @@ test('mentions targetFunc in string', () => {
source_file = tmpdir / "math.js"
source_file.write_text("""
function calculate(x) { return x * 2; }
export function calculate(x) { return x * 2; }
module.exports = { calculate };
""")
@ -1380,7 +1380,7 @@ test('calculate doubles', () => {
source_file = tmpdir / "myclass.js"
source_file.write_text("""
class MyClass {
export class MyClass {
methodA() { return 'A'; }
methodB() { return 'B'; }
}
@ -1416,7 +1416,7 @@ describe('MyClass', () => {
source_file = src_dir / "helpers.js"
source_file.write_text("""
function deepHelper() { return 'deep'; }
export function deepHelper() { return 'deep'; }
module.exports = { deepHelper };
""")
@ -1574,9 +1574,9 @@ class TestFunctionToTestMapping:
source_file = tmpdir / "multiple.js"
source_file.write_text("""
function addNumbers(a, b) { return a + b; }
function subtractNumbers(a, b) { return a - b; }
function multiplyNumbers(a, b) { return a * b; }
export function addNumbers(a, b) { return a + b; }
export function subtractNumbers(a, b) { return a - b; }
export function multiplyNumbers(a, b) { return a * b; }
module.exports = { addNumbers, subtractNumbers, multiplyNumbers };
""")
@ -1613,7 +1613,7 @@ describe('subtractNumbers', () => {
source_file = tmpdir / "funcs.js"
source_file.write_text("""
function targetFunc() { return 'target'; }
export function targetFunc() { return 'target'; }
module.exports = { targetFunc };
""")
@ -1705,7 +1705,7 @@ class TestQualifiedNames:
source_file = tmpdir / "calculator.js"
source_file.write_text("""
class Calculator {
export class Calculator {
add(a, b) { return a + b; }
subtract(a, b) { return a - b; }
}
@ -1726,7 +1726,7 @@ module.exports = { Calculator };
source_file = tmpdir / "nested.js"
source_file.write_text("""
class Outer {
export class Outer {
innerMethod() {
class Inner {
deepMethod() { return 'deep'; }

View file

@ -109,12 +109,7 @@ class Calculator {
factorial_helper = helper_dict["factorial"]
expected_factorial_code = """\
/**
* Calculate factorial recursively.
* @param n - Non-negative integer
* @returns Factorial of n
*/
function factorial(n) {
export function factorial(n) {
// Intentionally inefficient recursive implementation
if (n <= 1) return 1;
return n * factorial(n - 1);
@ -196,46 +191,22 @@ class Calculator {
# STRICT: Verify each helper's code exactly
expected_add_code = """\
/**
* Add two numbers.
* @param a - First number
* @param b - Second number
* @returns Sum of a and b
*/
function add(a, b) {
export function add(a, b) {
return a + b;
}"""
expected_multiply_code = """\
/**
* Multiply two numbers.
* @param a - First number
* @param b - Second number
* @returns Product of a and b
*/
function multiply(a, b) {
export function multiply(a, b) {
return a * b;
}"""
expected_format_number_code = """\
/**
* Format a number to specified decimal places.
* @param num - Number to format
* @param decimals - Number of decimal places
* @returns Formatted number
*/
function formatNumber(num, decimals) {
export function formatNumber(num, decimals) {
return Number(num.toFixed(decimals));
}"""
expected_validate_input_code = """\
/**
* Validate that input is a valid number.
* @param value - Value to validate
* @param name - Parameter name for error message
* @throws Error if value is not a valid number
*/
function validateInput(value, name) {
export function validateInput(value, name) {
if (typeof value !== 'number' || isNaN(value)) {
throw new Error(`Invalid ${name}: must be a number`);
}
@ -317,13 +288,7 @@ class Calculator {
assert set(helper_dict.keys()) == {"add"}, f"Expected 'add' helper, got: {list(helper_dict.keys())}"
expected_add_code = """\
/**
* Add two numbers.
* @param a - First number
* @param b - Second number
* @returns Sum of a and b
*/
function add(a, b) {
export function add(a, b) {
return a + b;
}"""
@ -702,7 +667,7 @@ class TestCodeExtractorEdgeCases:
def test_standalone_function(self, js_support, tmp_path):
"""Test standalone function with no helpers."""
source = """\
function standalone(x) {
export function standalone(x) {
return x * 2;
}
@ -718,7 +683,7 @@ module.exports = { standalone };
# STRICT: Exact code comparison
expected_code = """\
function standalone(x) {
export function standalone(x) {
return x * 2;
}"""
assert context.target_code.strip() == expected_code.strip(), (
@ -735,7 +700,7 @@ function standalone(x) {
source = """\
const _ = require('lodash');
function processArray(arr) {
export function processArray(arr) {
return _.map(arr, x => x * 2);
}
@ -750,7 +715,7 @@ module.exports = { processArray };
context = js_support.extract_code_context(function=func, project_root=tmp_path, module_root=tmp_path)
expected_code = """\
function processArray(arr) {
export function processArray(arr) {
return _.map(arr, x => x * 2);
}"""
@ -769,7 +734,7 @@ function processArray(arr) {
def test_recursive_function(self, js_support, tmp_path):
"""Test recursive function doesn't list itself as helper."""
source = """\
function fibonacci(n) {
export function fibonacci(n) {
if (n <= 1) return n;
return fibonacci(n - 1) + fibonacci(n - 2);
}
@ -786,7 +751,7 @@ module.exports = { fibonacci };
# STRICT: Exact code comparison
expected_code = """\
function fibonacci(n) {
export function fibonacci(n) {
if (n <= 1) return n;
return fibonacci(n - 1) + fibonacci(n - 2);
}"""
@ -803,7 +768,7 @@ function fibonacci(n) {
source = """\
const helper = (x) => x * 2;
const processValue = (value) => {
export const processValue = (value) => {
return helper(value) + 1;
};
@ -818,7 +783,7 @@ module.exports = { processValue };
context = js_support.extract_code_context(function=func, project_root=tmp_path, module_root=tmp_path)
expected_code = """\
const processValue = (value) => {
export const processValue = (value) => {
return helper(value) + 1;
};"""
@ -854,7 +819,7 @@ class TestClassContextExtraction:
def test_method_extraction_includes_constructor(self, js_support, tmp_path):
"""Test that extracting a class method includes the constructor."""
source = """\
class Counter {
export class Counter {
constructor(initial = 0) {
this.count = initial;
}
@ -894,7 +859,7 @@ class Counter {
def test_method_extraction_class_without_constructor(self, js_support, tmp_path):
"""Test extracting a method from a class that has no constructor."""
source = """\
class MathUtils {
export class MathUtils {
add(a, b) {
return a + b;
}
@ -928,7 +893,7 @@ class MathUtils {
def test_typescript_method_extraction_includes_fields(self, ts_support, tmp_path):
"""Test that TypeScript method extraction includes class fields."""
source = """\
class User {
export class User {
private name: string;
public age: number;
@ -941,8 +906,6 @@ class User {
return this.name;
}
}
export { User };
"""
test_file = tmp_path / "user.ts"
test_file.write_text(source)
@ -974,7 +937,7 @@ class User {
def test_typescript_fields_only_no_constructor(self, ts_support, tmp_path):
"""Test TypeScript class with fields but no constructor."""
source = """\
class Config {
export class Config {
readonly apiUrl: string = "https://api.example.com";
timeout: number = 5000;
@ -982,8 +945,6 @@ class Config {
return this.apiUrl;
}
}
export { Config };
"""
test_file = tmp_path / "config.ts"
test_file.write_text(source)
@ -1010,7 +971,7 @@ class Config {
def test_constructor_with_jsdoc(self, js_support, tmp_path):
"""Test that constructor with JSDoc is fully extracted."""
source = """\
class Logger {
export class Logger {
/**
* Create a new Logger instance.
* @param {string} prefix - The prefix to use for log messages.
@ -1056,7 +1017,7 @@ class Logger {
def test_static_method_includes_constructor(self, js_support, tmp_path):
"""Test that static method extraction also includes constructor context."""
source = """\
class Factory {
export class Factory {
constructor(config) {
this.config = config;
}
@ -1212,13 +1173,11 @@ interface Point {
y: number;
}
function distance(p1: Point, p2: Point): number {
export function distance(p1: Point, p2: Point): number {
const dx = p2.x - p1.x;
const dy = p2.y - p1.y;
return Math.sqrt(dx * dx + dy * dy);
}
export { distance };
"""
test_file = tmp_path / "geometry.ts"
test_file.write_text(source)
@ -1251,7 +1210,7 @@ enum Status {
FAILURE = 'failure',
}
function processStatus(status: Status): string {
export function processStatus(status: Status): string {
switch (status) {
case Status.PENDING:
return 'Processing...';
@ -1261,8 +1220,6 @@ function processStatus(status: Status): string {
return 'Failed!';
}
}
export { processStatus };
"""
test_file = tmp_path / "status.ts"
test_file.write_text(source)
@ -1295,11 +1252,9 @@ type Result<T> = {
success: boolean;
};
function compute(x: number): Result<number> {
export function compute(x: number): Result<number> {
return { value: x * 2, success: true };
}
export { compute };
"""
test_file = tmp_path / "compute.ts"
test_file.write_text(source)
@ -1331,7 +1286,7 @@ interface Config {
retries: number;
}
class Service {
export class Service {
private config: Config;
constructor(config: Config) {
@ -1342,8 +1297,6 @@ class Service {
return this.config.timeout;
}
}
export { Service };
"""
test_file = tmp_path / "service.ts"
test_file.write_text(source)
@ -1372,11 +1325,9 @@ interface Config {
def test_primitive_types_not_included(self, ts_support, tmp_path):
"""Test that primitive types (number, string, etc.) are not extracted."""
source = """\
function add(a: number, b: number): number {
export function add(a: number, b: number): number {
return a + b;
}
export { add };
"""
test_file = tmp_path / "add.ts"
test_file.write_text(source)
@ -1405,11 +1356,9 @@ interface Size {
height: number;
}
function createRect(origin: Point, size: Size): { origin: Point; size: Size } {
export function createRect(origin: Point, size: Size): { origin: Point; size: Size } {
return { origin, size };
}
export { createRect };
"""
test_file = tmp_path / "rect.ts"
test_file.write_text(source)
@ -1447,7 +1396,7 @@ interface Size {
geometry_file.write_text("""\
import { Point, CalculationConfig } from './types';
function calculateDistance(p1: Point, p2: Point, config: CalculationConfig): number {
export function calculateDistance(p1: Point, p2: Point, config: CalculationConfig): number {
const dx = p2.x - p1.x;
const dy = p2.y - p1.y;
const distance = Math.sqrt(dx * dx + dy * dy);
@ -1458,8 +1407,6 @@ function calculateDistance(p1: Point, p2: Point, config: CalculationConfig): num
}
return distance;
}
export { calculateDistance };
""")
functions = ts_support.discover_functions(geometry_file)
@ -1506,11 +1453,9 @@ interface User {
name: string;
}
function greetUser(user: User): string {
export function greetUser(user: User): string {
return `Hello, ${user.name}!`;
}
export { greetUser };
"""
test_file = tmp_path / "user.ts"
test_file.write_text(source)

View file

@ -749,7 +749,7 @@ class TestSimpleFunctionReplacement:
def test_replace_simple_function_body(self, js_support, temp_project):
"""Test replacing a simple function body preserves structure exactly."""
original_source = """\
function add(a, b) {
export function add(a, b) {
return a + b;
}
"""
@ -761,7 +761,7 @@ function add(a, b) {
# Optimized version with different body
optimized_code = """\
function add(a, b) {
export function add(a, b) {
// Optimized: direct return
return a + b;
}
@ -770,7 +770,7 @@ function add(a, b) {
result = js_support.replace_function(original_source, func, optimized_code)
expected_result = """\
function add(a, b) {
export function add(a, b) {
// Optimized: direct return
return a + b;
}
@ -781,7 +781,7 @@ function add(a, b) {
def test_replace_function_with_multiple_statements(self, js_support, temp_project):
"""Test replacing function with complex multi-statement body."""
original_source = """\
function processData(data) {
export function processData(data) {
const result = [];
for (let i = 0; i < data.length; i++) {
result.push(data[i] * 2);
@ -797,7 +797,7 @@ function processData(data) {
# Optimized version using map
optimized_code = """\
function processData(data) {
export function processData(data) {
return data.map(x => x * 2);
}
"""
@ -805,7 +805,7 @@ function processData(data) {
result = js_support.replace_function(original_source, func, optimized_code)
expected_result = """\
function processData(data) {
export function processData(data) {
return data.map(x => x * 2);
}
"""
@ -817,12 +817,12 @@ function processData(data) {
original_source = """\
const CONFIG = { debug: true };
function targetFunction(x) {
export function targetFunction(x) {
console.log(x);
return x * 2;
}
function otherFunction(y) {
export function otherFunction(y) {
return y + 1;
}
@ -835,7 +835,7 @@ module.exports = { targetFunction, otherFunction };
target_func = next(f for f in functions if f.function_name == "targetFunction")
optimized_code = """\
function targetFunction(x) {
export function targetFunction(x) {
return x << 1;
}
"""
@ -845,11 +845,11 @@ function targetFunction(x) {
expected_result = """\
const CONFIG = { debug: true };
function targetFunction(x) {
export function targetFunction(x) {
return x << 1;
}
function otherFunction(y) {
export function otherFunction(y) {
return y + 1;
}
@ -865,7 +865,7 @@ class TestClassMethodReplacement:
def test_replace_class_method_body(self, js_support, temp_project):
"""Test replacing a class method body preserves class structure."""
original_source = """\
class Calculator {
export class Calculator {
constructor(precision = 2) {
this.precision = precision;
}
@ -888,7 +888,7 @@ class Calculator {
# Optimized version provided in class context
optimized_code = """\
class Calculator {
export class Calculator {
constructor(precision = 2) {
this.precision = precision;
}
@ -902,7 +902,7 @@ class Calculator {
result = js_support.replace_function(original_source, add_method, optimized_code)
expected_result = """\
class Calculator {
export class Calculator {
constructor(precision = 2) {
this.precision = precision;
}
@ -922,7 +922,7 @@ class Calculator {
def test_replace_method_calling_sibling_methods(self, js_support, temp_project):
"""Test replacing method that calls other methods in same class."""
original_source = """\
class DataProcessor {
export class DataProcessor {
constructor() {
this.cache = new Map();
}
@ -950,7 +950,7 @@ class DataProcessor {
process_method = next(f for f in functions if f.function_name == "process")
optimized_code = """\
class DataProcessor {
export class DataProcessor {
constructor() {
this.cache = new Map();
}
@ -967,7 +967,7 @@ class DataProcessor {
result = js_support.replace_function(original_source, process_method, optimized_code)
expected_result = """\
class DataProcessor {
export class DataProcessor {
constructor() {
this.cache = new Map();
}
@ -1000,7 +1000,7 @@ class TestJSDocPreservation:
* @param {number} b - Second number
* @returns {number} The sum
*/
function add(a, b) {
export function add(a, b) {
const sum = a + b;
return sum;
}
@ -1012,13 +1012,7 @@ function add(a, b) {
func = functions[0]
optimized_code = """\
/**
* Calculates the sum of two numbers.
* @param {number} a - First number
* @param {number} b - Second number
* @returns {number} The sum
*/
function add(a, b) {
export function add(a, b) {
return a + b;
}
"""
@ -1032,7 +1026,7 @@ function add(a, b) {
* @param {number} b - Second number
* @returns {number} The sum
*/
function add(a, b) {
export function add(a, b) {
return a + b;
}
"""
@ -1046,7 +1040,7 @@ function add(a, b) {
* A simple cache implementation.
* @class Cache
*/
class Cache {
export class Cache {
constructor() {
this.data = new Map();
}
@ -1095,7 +1089,7 @@ class Cache {
* A simple cache implementation.
* @class Cache
*/
class Cache {
export class Cache {
constructor() {
this.data = new Map();
}
@ -1120,7 +1114,7 @@ class TestAsyncFunctionReplacement:
def test_replace_async_function_body(self, js_support, temp_project):
"""Test replacing async function preserves async keyword."""
original_source = """\
async function fetchData(url) {
export async function fetchData(url) {
const response = await fetch(url);
const data = await response.json();
return data;
@ -1133,7 +1127,7 @@ async function fetchData(url) {
func = functions[0]
optimized_code = """\
async function fetchData(url) {
export async function fetchData(url) {
return (await fetch(url)).json();
}
"""
@ -1141,7 +1135,7 @@ async function fetchData(url) {
result = js_support.replace_function(original_source, func, optimized_code)
expected_result = """\
async function fetchData(url) {
export async function fetchData(url) {
return (await fetch(url)).json();
}
"""
@ -1151,7 +1145,7 @@ async function fetchData(url) {
def test_replace_async_class_method(self, js_support, temp_project):
"""Test replacing async class method."""
original_source = """\
class ApiClient {
export class ApiClient {
constructor(baseUrl) {
this.baseUrl = baseUrl;
}
@ -1190,7 +1184,7 @@ class ApiClient {
result = js_support.replace_function(original_source, get_method, optimized_code)
expected_result = """\
class ApiClient {
export class ApiClient {
constructor(baseUrl) {
this.baseUrl = baseUrl;
}
@ -1212,7 +1206,7 @@ class TestGeneratorFunctionReplacement:
def test_replace_generator_function_body(self, js_support, temp_project):
"""Test replacing generator function preserves generator syntax."""
original_source = """\
function* range(start, end) {
export function* range(start, end) {
for (let i = start; i < end; i++) {
yield i;
}
@ -1225,7 +1219,7 @@ function* range(start, end) {
func = functions[0]
optimized_code = """\
function* range(start, end) {
export function* range(start, end) {
let i = start;
while (i < end) yield i++;
}
@ -1234,7 +1228,7 @@ function* range(start, end) {
result = js_support.replace_function(original_source, func, optimized_code)
expected_result = """\
function* range(start, end) {
export function* range(start, end) {
let i = start;
while (i < end) yield i++;
}
@ -1249,7 +1243,7 @@ class TestTypeScriptReplacement:
def test_replace_typescript_function_with_types(self, ts_support, temp_project):
"""Test replacing TypeScript function preserves type annotations."""
original_source = """\
function processArray(items: number[]): number {
export function processArray(items: number[]): number {
let sum = 0;
for (let i = 0; i < items.length; i++) {
sum += items[i];
@ -1264,7 +1258,7 @@ function processArray(items: number[]): number {
func = functions[0]
optimized_code = """\
function processArray(items: number[]): number {
export function processArray(items: number[]): number {
return items.reduce((a, b) => a + b, 0);
}
"""
@ -1272,7 +1266,7 @@ function processArray(items: number[]): number {
result = ts_support.replace_function(original_source, func, optimized_code)
expected_result = """\
function processArray(items: number[]): number {
export function processArray(items: number[]): number {
return items.reduce((a, b) => a + b, 0);
}
"""
@ -1282,7 +1276,7 @@ function processArray(items: number[]): number {
def test_replace_typescript_class_method_with_generics(self, ts_support, temp_project):
"""Test replacing TypeScript generic class method."""
original_source = """\
class Container<T> {
export class Container<T> {
private items: T[] = [];
add(item: T): void {
@ -1317,7 +1311,7 @@ class Container<T> {
result = ts_support.replace_function(original_source, get_all_method, optimized_code)
expected_result = """\
class Container<T> {
export class Container<T> {
private items: T[] = [];
add(item: T): void {
@ -1341,7 +1335,7 @@ interface User {
email: string;
}
function createUser(name: string, email: string): User {
export function createUser(name: string, email: string): User {
const id = Math.random().toString(36).substring(2, 15);
const user: User = {
id: id,
@ -1358,7 +1352,7 @@ function createUser(name: string, email: string): User {
func = next(f for f in functions if f.function_name == "createUser")
optimized_code = """\
function createUser(name: string, email: string): User {
export function createUser(name: string, email: string): User {
return {
id: Math.random().toString(36).substring(2, 15),
name,
@ -1376,7 +1370,7 @@ interface User {
email: string;
}
function createUser(name: string, email: string): User {
export function createUser(name: string, email: string): User {
return {
id: Math.random().toString(36).substring(2, 15),
name,
@ -1394,7 +1388,7 @@ class TestComplexReplacements:
def test_replace_function_with_nested_functions(self, js_support, temp_project):
"""Test replacing function that contains nested function definitions."""
original_source = """\
function processItems(items) {
export function processItems(items) {
function helper(item) {
return item * 2;
}
@ -1413,7 +1407,7 @@ function processItems(items) {
process_func = next(f for f in functions if f.function_name == "processItems")
optimized_code = """\
function processItems(items) {
export function processItems(items) {
const helper = x => x * 2;
return items.map(helper);
}
@ -1422,7 +1416,7 @@ function processItems(items) {
result = js_support.replace_function(original_source, process_func, optimized_code)
expected_result = """\
function processItems(items) {
export function processItems(items) {
const helper = x => x * 2;
return items.map(helper);
}
@ -1433,7 +1427,7 @@ function processItems(items) {
def test_replace_multiple_methods_sequentially(self, js_support, temp_project):
"""Test replacing multiple methods in the same class sequentially."""
original_source = """\
class MathUtils {
export class MathUtils {
static sum(arr) {
let total = 0;
for (let i = 0; i < arr.length; i++) {
@ -1470,7 +1464,7 @@ class MathUtils {
result = js_support.replace_function(original_source, sum_method, optimized_sum)
expected_after_first = """\
class MathUtils {
export class MathUtils {
static sum(arr) {
return arr.reduce((a, b) => a + b, 0);
}
@ -1491,7 +1485,7 @@ class MathUtils {
def test_replace_function_with_complex_destructuring(self, js_support, temp_project):
"""Test replacing function with complex parameter destructuring."""
original_source = """\
function processConfig({ server: { host, port }, database: { url, poolSize } }) {
export function processConfig({ server: { host, port }, database: { url, poolSize } }) {
const serverUrl = host + ':' + port;
const dbConnection = url + '?poolSize=' + poolSize;
return {
@ -1507,7 +1501,7 @@ function processConfig({ server: { host, port }, database: { url, poolSize } })
func = functions[0]
optimized_code = """\
function processConfig({ server: { host, port }, database: { url, poolSize } }) {
export function processConfig({ server: { host, port }, database: { url, poolSize } }) {
return {
server: `${host}:${port}`,
db: `${url}?poolSize=${poolSize}`
@ -1518,7 +1512,7 @@ function processConfig({ server: { host, port }, database: { url, poolSize } })
result = js_support.replace_function(original_source, func, optimized_code)
expected_result = """\
function processConfig({ server: { host, port }, database: { url, poolSize } }) {
export function processConfig({ server: { host, port }, database: { url, poolSize } }) {
return {
server: `${host}:${port}`,
db: `${url}?poolSize=${poolSize}`
@ -1535,7 +1529,7 @@ class TestEdgeCases:
def test_replace_minimal_function_body(self, js_support, temp_project):
"""Test replacing function with minimal body."""
original_source = """\
function minimal() {
export function minimal() {
return null;
}
"""
@ -1546,7 +1540,7 @@ function minimal() {
func = functions[0]
optimized_code = """\
function minimal() {
export function minimal() {
return { initialized: true, timestamp: Date.now() };
}
"""
@ -1554,7 +1548,7 @@ function minimal() {
result = js_support.replace_function(original_source, func, optimized_code)
expected_result = """\
function minimal() {
export function minimal() {
return { initialized: true, timestamp: Date.now() };
}
"""
@ -1564,7 +1558,7 @@ function minimal() {
def test_replace_single_line_function(self, js_support, temp_project):
"""Test replacing single-line function."""
original_source = """\
function identity(x) { return x; }
export function identity(x) { return x; }
"""
file_path = temp_project / "utils.js"
file_path.write_text(original_source, encoding="utf-8")
@ -1573,13 +1567,13 @@ function identity(x) { return x; }
func = functions[0]
optimized_code = """\
function identity(x) { return x ?? null; }
export function identity(x) { return x ?? null; }
"""
result = js_support.replace_function(original_source, func, optimized_code)
expected_result = """\
function identity(x) { return x ?? null; }
export function identity(x) { return x ?? null; }
"""
assert result == expected_result
assert js_support.validate_syntax(result) is True
@ -1587,7 +1581,7 @@ function identity(x) { return x ?? null; }
def test_replace_function_with_special_characters_in_strings(self, js_support, temp_project):
"""Test replacing function containing special characters in strings."""
original_source = """\
function formatMessage(name) {
export function formatMessage(name) {
const greeting = 'Hello, ' + name + '!';
const special = "Contains \\"quotes\\" and \\n newlines";
return greeting + ' ' + special;
@ -1600,7 +1594,7 @@ function formatMessage(name) {
func = functions[0]
optimized_code = """\
function formatMessage(name) {
export function formatMessage(name) {
return `Hello, ${name}! Contains "quotes" and
newlines`;
}
@ -1609,7 +1603,7 @@ function formatMessage(name) {
result = js_support.replace_function(original_source, func, optimized_code)
expected_result = """\
function formatMessage(name) {
export function formatMessage(name) {
return `Hello, ${name}! Contains "quotes" and
newlines`;
}
@ -1620,7 +1614,7 @@ function formatMessage(name) {
def test_replace_function_with_regex(self, js_support, temp_project):
"""Test replacing function containing regex patterns."""
original_source = """\
function validateEmail(email) {
export function validateEmail(email) {
const pattern = /^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\\.[a-zA-Z]{2,}$/;
if (pattern.test(email)) {
return true;
@ -1635,7 +1629,7 @@ function validateEmail(email) {
func = functions[0]
optimized_code = """\
function validateEmail(email) {
export function validateEmail(email) {
return /^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\\.[a-zA-Z]{2,}$/.test(email);
}
"""
@ -1643,7 +1637,7 @@ function validateEmail(email) {
result = js_support.replace_function(original_source, func, optimized_code)
expected_result = """\
function validateEmail(email) {
export function validateEmail(email) {
return /^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\\.[a-zA-Z]{2,}$/.test(email);
}
"""
@ -1657,11 +1651,11 @@ class TestModuleExportHandling:
def test_replace_exported_function_commonjs(self, js_support, temp_project):
"""Test replacing function in CommonJS module preserves exports."""
original_source = """\
function helper(x) {
export function helper(x) {
return x * 2;
}
function main(data) {
export function main(data) {
const results = [];
for (let i = 0; i < data.length; i++) {
results.push(helper(data[i]));
@ -1678,7 +1672,7 @@ module.exports = { main, helper };
main_func = next(f for f in functions if f.function_name == "main")
optimized_code = """\
function main(data) {
export function main(data) {
return data.map(helper);
}
"""
@ -1686,11 +1680,11 @@ function main(data) {
result = js_support.replace_function(original_source, main_func, optimized_code)
expected_result = """\
function helper(x) {
export function helper(x) {
return x * 2;
}
function main(data) {
export function main(data) {
return data.map(helper);
}
@ -1749,18 +1743,18 @@ class TestSyntaxValidation:
test_cases = [
# (original, optimized, description)
(
"function f(x) { return x + 1; }",
"function f(x) { return ++x; }",
"export function f(x) { return x + 1; }",
"export function f(x) { return ++x; }",
"increment replacement"
),
(
"function f(arr) { return arr.length > 0; }",
"function f(arr) { return !!arr.length; }",
"export function f(arr) { return arr.length > 0; }",
"export function f(arr) { return !!arr.length; }",
"boolean conversion"
),
(
"function f(a, b) { if (a) { return a; } return b; }",
"function f(a, b) { return a || b; }",
"export function f(a, b) { if (a) { return a; } return b; }",
"export function f(a, b) { return a || b; }",
"logical OR replacement"
),
]

View file

@ -38,7 +38,7 @@ def add(a, b):
return a + b
""",
javascript="""
function add(a, b) {
export function add(a, b) {
return a + b;
}
""",
@ -58,15 +58,15 @@ def multiply(a, b):
return a * b
""",
javascript="""
function add(a, b) {
export function add(a, b) {
return a + b;
}
function subtract(a, b) {
export function subtract(a, b) {
return a - b;
}
function multiply(a, b) {
export function multiply(a, b) {
return a * b;
}
""",
@ -83,11 +83,11 @@ def without_return():
print("hello")
""",
javascript="""
function withReturn() {
export function withReturn() {
return 1;
}
function withoutReturn() {
export function withoutReturn() {
console.log("hello");
}
""",
@ -105,7 +105,7 @@ class Calculator:
return a * b
""",
javascript="""
class Calculator {
export class Calculator {
add(a, b) {
return a + b;
}
@ -128,11 +128,11 @@ def sync_function():
return 1
""",
javascript="""
async function fetchData(url) {
export async function fetchData(url) {
return await fetch(url);
}
function syncFunction() {
export function syncFunction() {
return 1;
}
""",
@ -148,7 +148,7 @@ def outer():
return inner()
""",
javascript="""
function outer() {
export function outer() {
function inner() {
return 1;
}
@ -167,7 +167,7 @@ class Utils:
return x * 2
""",
javascript="""
class Utils {
export class Utils {
static helper(x) {
return x * 2;
}
@ -194,7 +194,7 @@ def standalone():
return 42
""",
javascript="""
class Calculator {
export class Calculator {
add(a, b) {
return a + b;
}
@ -204,13 +204,13 @@ class Calculator {
}
}
class StringUtils {
export class StringUtils {
reverse(s) {
return s.split('').reverse().join('');
}
}
function standalone() {
export function standalone() {
return 42;
}
""",
@ -227,11 +227,11 @@ def sync_func():
return 2
""",
javascript="""
async function asyncFunc() {
export async function asyncFunc() {
return 1;
}
function syncFunc() {
export function syncFunc() {
return 2;
}
""",
@ -249,11 +249,11 @@ class MyClass:
return 2
""",
javascript="""
function standalone() {
export function standalone() {
return 1;
}
class MyClass {
export class MyClass {
method() {
return 2;
}
@ -906,7 +906,7 @@ class TestIntegrationParity:
return n
return fibonacci(n - 1) + fibonacci(n - 2)
"""
js_original = """function fibonacci(n) {
js_original = """export function fibonacci(n) {
if (n <= 1) {
return n;
}
@ -933,7 +933,7 @@ class TestIntegrationParity:
memo[i] = memo[i-1] + memo[i-2]
return memo[n]
"""
js_optimized = """function fibonacci(n) {
js_optimized = """export function fibonacci(n) {
// Memoized version
const memo = {0: 0, 1: 1};
for (let i = 2; i <= n; i++) {
@ -994,13 +994,13 @@ class TestFeatureGaps:
def test_arrow_functions_unique_to_js(self, js_support):
"""JavaScript arrow functions should be discovered (no Python equivalent)."""
js_code = """
const add = (a, b) => {
export const add = (a, b) => {
return a + b;
};
const multiply = (x, y) => x * y;
export const multiply = (x, y) => x * y;
const identity = x => x;
export const identity = x => x;
"""
js_file = write_temp_file(js_code, ".js")
funcs = js_support.discover_functions(js_file)
@ -1021,7 +1021,7 @@ def number_generator():
return 3
"""
js_code = """
function* numberGenerator() {
export function* numberGenerator() {
yield 1;
yield 2;
return 3;
@ -1065,11 +1065,11 @@ def multi_decorated():
def test_function_expressions_js(self, js_support):
"""JavaScript function expressions should be discovered."""
js_code = """
const add = function(a, b) {
export const add = function(a, b) {
return a + b;
};
const namedExpr = function myFunc(x) {
export const namedExpr = function myFunc(x) {
return x * 2;
};
"""
@ -1132,7 +1132,7 @@ def greeting():
return "Hello, 世界! 🌍"
"""
js_code = """
function greeting() {
export function greeting() {
return "Hello, 世界! 🌍";
}
"""

View file

@ -168,6 +168,11 @@ def test_js_replcement() -> None:
const { sumArray, average, findMax, findMin } = require('./math_helpers');
/**
* Calculate statistics for an array of numbers.
* @param numbers - Array of numbers to analyze
* @returns Object containing sum, average, min, max, and range
*/
/**
* This is a modified comment
*/
@ -211,7 +216,7 @@ function calculateStats(numbers) {
* @param numbers - Array of numbers to normalize
* @returns Normalized array
*/
function normalizeArray(numbers) {
export function normalizeArray(numbers) {
if (numbers.length === 0) return [];
const min = findMin(numbers);
@ -231,7 +236,7 @@ function normalizeArray(numbers) {
* @param weights - Array of weights (same length as values)
* @returns The weighted average
*/
function weightedAverage(values, weights) {
export function weightedAverage(values, weights) {
if (values.length === 0 || values.length !== weights.length) {
return 0;
}
@ -264,7 +269,7 @@ module.exports = {
* @param numbers - Array of numbers to sum
* @returns The sum of all numbers
*/
function sumArray(numbers) {
export function sumArray(numbers) {
// Intentionally inefficient - using reduce with spread operator
let result = 0;
for (let i = 0; i < numbers.length; i++) {
@ -278,11 +283,16 @@ function sumArray(numbers) {
* @param numbers - Array of numbers
* @returns The average value
*/
function average(numbers) {
export function average(numbers) {
if (numbers.length === 0) return 0;
return sumArray(numbers) / numbers.length;
}
/**
* Find the maximum value in an array.
* @param numbers - Array of numbers
* @returns The maximum value
*/
/**
* Normalize an array of numbers to a 0-1 range.
* @param numbers - Array of numbers to normalize
@ -301,6 +311,11 @@ function findMax(numbers) {
return max;
}
/**
* Find the minimum value in an array.
* @param numbers - Array of numbers
* @returns The minimum value
*/
/**
* Find the minimum value in an array.
* @param numbers - Array of numbers

View file

@ -119,7 +119,7 @@ class TestTypeScriptCodeExtraction:
"""Test extracting code context for a simple function."""
with tempfile.NamedTemporaryFile(suffix=".ts", mode="w", delete=False) as f:
f.write("""
function add(a: number, b: number): number {
export function add(a: number, b: number): number {
return a + b;
}
""")
@ -147,7 +147,7 @@ import * as utils from "./utils";
const command_args = process.argv.slice(3);
async function execMongoEval(queryExpression, appsmithMongoURI) {
export async function execMongoEval(queryExpression, appsmithMongoURI) {
queryExpression = queryExpression.trim();
if (command_args.includes("--pretty")) {
@ -186,7 +186,7 @@ async function execMongoEval(queryExpression, appsmithMongoURI) {
import fsPromises from "fs/promises";
import path from "path";
async function figureOutContentsPath(root: string): Promise<string> {
export async function figureOutContentsPath(root: string): Promise<string> {
const subfolders = await fsPromises.readdir(root, { withFileTypes: true });
try {
@ -238,7 +238,7 @@ async function figureOutContentsPath(root: string): Promise<string> {
import fs from "fs";
import path from "path";
function readConfig(filename: string): string {
export function readConfig(filename: string): string {
const fullPath = path.join(__dirname, filename);
return fs.readFileSync(fullPath, "utf8");
}
@ -264,7 +264,7 @@ function readConfig(filename: string): string {
const CONFIG = { timeout: 5000 };
const MAX_RETRIES = 3;
async function fetchWithRetry(url: string): Promise<any> {
export async function fetchWithRetry(url: string): Promise<any> {
for (let i = 0; i < MAX_RETRIES; i++) {
try {
const response = await fetch(url, { signal: AbortSignal.timeout(CONFIG.timeout) });
@ -289,6 +289,164 @@ async function fetchWithRetry(url: string): Promise<any> {
assert ts_support.validate_syntax(code_context.target_code) is True
class TestSameClassHelperExtraction:
"""Tests for same-class helper method extraction.
When a class method calls other methods from the same class, those helper
methods should be included inside the class wrapper (not appended outside),
because they may use class-specific syntax like 'private'.
"""
def test_private_helper_method_inside_class_wrapper(self, ts_support):
"""Test that private helper methods are included inside the class wrapper."""
with tempfile.NamedTemporaryFile(suffix=".ts", mode="w", delete=False) as f:
# Export the class and add return statements so discover_functions finds the methods
f.write("""
export class EndpointGroup {
private endpoints: any[] = [];
constructor() {
this.endpoints = [];
}
post(path: string, handler: Function): EndpointGroup {
this.addEndpoint("POST", path, handler);
return this;
}
private addEndpoint(method: string, path: string, handler: Function): void {
this.endpoints.push({ method, path, handler });
return;
}
}
""")
f.flush()
file_path = Path(f.name)
# Discover the 'post' method
functions = ts_support.discover_functions(file_path)
post_method = None
for func in functions:
if func.function_name == "post":
post_method = func
break
assert post_method is not None, "post method should be discovered"
# Extract code context
code_context = ts_support.extract_code_context(
post_method, file_path.parent, file_path.parent
)
# The extracted code should be syntactically valid
assert ts_support.validate_syntax(code_context.target_code) is True, (
f"Extracted code should be valid TypeScript:\n{code_context.target_code}"
)
# Both post and addEndpoint should be inside the class
assert "class EndpointGroup" in code_context.target_code
assert "post(" in code_context.target_code
assert "private addEndpoint" in code_context.target_code
# The private method should be inside the class, not outside
# Check that addEndpoint appears BEFORE the closing brace of the class
class_end_index = code_context.target_code.rfind("}")
add_endpoint_index = code_context.target_code.find("addEndpoint")
assert add_endpoint_index < class_end_index, (
"addEndpoint should be inside the class wrapper"
)
def test_multiple_private_helpers_inside_class(self, ts_support):
"""Test that multiple private helpers are all included inside the class."""
with tempfile.NamedTemporaryFile(suffix=".ts", mode="w", delete=False) as f:
f.write("""
export class Router {
private routes: Map<string, any> = new Map();
addRoute(path: string, handler: Function): boolean {
const normalizedPath = this.normalizePath(path);
this.validatePath(normalizedPath);
this.routes.set(normalizedPath, handler);
return true;
}
private normalizePath(path: string): string {
return path.toLowerCase().trim();
}
private validatePath(path: string): boolean {
if (!path.startsWith("/")) {
throw new Error("Path must start with /");
}
return true;
}
}
""")
f.flush()
file_path = Path(f.name)
# Discover the 'addRoute' method
functions = ts_support.discover_functions(file_path)
add_route_method = None
for func in functions:
if func.function_name == "addRoute":
add_route_method = func
break
assert add_route_method is not None
code_context = ts_support.extract_code_context(
add_route_method, file_path.parent, file_path.parent
)
# Should be valid TypeScript
assert ts_support.validate_syntax(code_context.target_code) is True
# All methods should be inside the class
assert "private normalizePath" in code_context.target_code
assert "private validatePath" in code_context.target_code
def test_same_class_helpers_filtered_from_helper_list(self, ts_support):
"""Test that same-class helpers are not duplicated in the helpers list."""
with tempfile.NamedTemporaryFile(suffix=".ts", mode="w", delete=False) as f:
f.write("""
export class Calculator {
add(a: number, b: number): number {
return this.compute(a, b, "+");
}
private compute(a: number, b: number, op: string): number {
if (op === "+") return a + b;
return 0;
}
}
""")
f.flush()
file_path = Path(f.name)
functions = ts_support.discover_functions(file_path)
add_method = None
for func in functions:
if func.function_name == "add":
add_method = func
break
assert add_method is not None
code_context = ts_support.extract_code_context(
add_method, file_path.parent, file_path.parent
)
# 'compute' should be in target_code (inside class)
assert "compute" in code_context.target_code
# 'compute' should NOT be in helper_functions (would be duplicate)
helper_names = [h.name for h in code_context.helper_functions]
assert "compute" not in helper_names, (
"Same-class helper 'compute' should not be in helper_functions list"
)
class TestTypeScriptLanguageProperties:
"""Tests for TypeScript language support properties."""