codeflash/tests/test_languages/test_js_code_extractor.py
Kevin Turcios 19bd6e4bad test: sync test files from main (safe, main-only changes)
34 test files updated with main's refactored tests for new language
support protocol, JS/TS improvements, and code context extraction.
2026-03-02 15:25:50 -05:00

1487 lines
52 KiB
Python

"""Tests for JavaScript/TypeScript code extractor.
Uses strict string equality to verify extraction results.
"""
import shutil
from pathlib import Path
from unittest.mock import MagicMock
import pytest
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.languages.base import Language
from codeflash.languages.javascript.support import JavaScriptSupport, TypeScriptSupport
from codeflash.languages.registry import get_language_support
from codeflash.models.models import FunctionParent
from codeflash.languages.javascript.function_optimizer import JavaScriptFunctionOptimizer
from codeflash.verification.verification_utils import TestConfig
# Root directory of the on-disk JS/TS fixture projects that the tests copy into tmp_path.
FIXTURES_DIR = Path(__file__).parent.joinpath("fixtures")
class TestCodeExtractorCJS:
    """Tests for CommonJS module code extraction."""

    @pytest.fixture
    def cjs_project(self, tmp_path):
        """Create a temporary CJS project from fixtures."""
        project_dir = tmp_path / "cjs_project"
        shutil.copytree(FIXTURES_DIR / "js_cjs", project_dir)
        return project_dir

    @pytest.fixture
    def js_support(self):
        """Create JavaScriptSupport instance."""
        return JavaScriptSupport()

    @staticmethod
    def _discover(js_support, project_dir):
        """Run function discovery on the fixture project's ``calculator.js`` and return the result."""
        calculator_file = project_dir / "calculator.js"
        return js_support.discover_functions(calculator_file.read_text(encoding="utf-8"), calculator_file)

    @classmethod
    def _extract(cls, js_support, project_dir, function_name):
        """Discover ``function_name`` in ``calculator.js`` and extract its code context.

        ``project_dir`` is used as both project and module root, matching every test here.
        Fails with a descriptive assertion (instead of a bare ``StopIteration``) when
        discovery does not report the requested function.
        """
        functions = cls._discover(js_support, project_dir)
        target = next((f for f in functions if f.function_name == function_name), None)
        assert target is not None, (
            f"{function_name} was not discovered; found {sorted(f.function_name for f in functions)}"
        )
        return js_support.extract_code_context(function=target, project_root=project_dir, module_root=project_dir)

    def test_discover_class_methods(self, js_support, cjs_project):
        """Test that class methods are discovered correctly."""
        functions = self._discover(js_support, cjs_project)
        method_names = {f.function_name for f in functions}
        expected_methods = {"calculateCompoundInterest", "permutation", "quickAdd"}
        assert method_names == expected_methods, f"Expected methods {expected_methods}, got {method_names}"

    def test_class_method_has_correct_parent(self, js_support, cjs_project):
        """Test parent class information for methods."""
        functions = self._discover(js_support, cjs_project)
        # Guard against a vacuous pass: the loop below checks nothing if discovery finds nothing.
        assert functions, "Expected calculator.js to contain discoverable methods"
        for func in functions:
            # All methods should belong to Calculator class
            assert func.is_method is True, f"{func.function_name} should be a method"
            assert func.class_name == "Calculator", (
                f"{func.function_name} should belong to Calculator, got {func.class_name}"
            )

    def test_extract_permutation_code(self, js_support, cjs_project):
        """Test permutation method code extraction."""
        context = self._extract(js_support, cjs_project, "permutation")
        expected_code = """\
class Calculator {
constructor(precision = 2) {
this.precision = precision;
this.history = [];
}
/**
* Calculate permutation using factorial helper.
* @param n - Total items
* @param r - Items to choose
* @returns Permutation result
*/
permutation(n, r) {
if (n < r) return 0;
// Inefficient: calculates factorial(n) fully even when not needed
return factorial(n) / factorial(n - r);
}
}"""
        assert context.target_code is not None, "target_code should not be None"
        assert context.target_code.strip() == expected_code.strip(), (
            f"Extracted code does not match expected.\nExpected:\n{expected_code}\n\nGot:\n{context.target_code}"
        )

    def test_extract_context_includes_direct_helpers(self, js_support, cjs_project):
        """Test that direct helper functions are included in context."""
        context = self._extract(js_support, cjs_project, "permutation")
        # Find factorial helper
        helper_dict = {h.name: h for h in context.helper_functions}
        assert "factorial" in helper_dict, f"factorial helper not found. Found helpers: {list(helper_dict.keys())}"
        factorial_helper = helper_dict["factorial"]
        expected_factorial_code = """\
export function factorial(n) {
// Intentionally inefficient recursive implementation
if (n <= 1) return 1;
return n * factorial(n - 1);
}"""
        assert factorial_helper.source_code.strip() == expected_factorial_code.strip(), (
            f"Factorial helper code does not match expected.\n"
            f"Expected:\n{expected_factorial_code}\n\n"
            f"Got:\n{factorial_helper.source_code}"
        )
        # STRICT: Verify file path ends with expected filename
        assert str(factorial_helper.file_path).endswith("math_utils.js"), (
            f"Expected factorial to be from math_utils.js, got {factorial_helper.file_path}"
        )

    def test_extract_compound_interest_code(self, js_support, cjs_project):
        """Test calculateCompoundInterest code extraction."""
        context = self._extract(js_support, cjs_project, "calculateCompoundInterest")
        expected_code = """\
class Calculator {
constructor(precision = 2) {
this.precision = precision;
this.history = [];
}
/**
* Calculate compound interest with multiple helper dependencies.
* @param principal - Initial amount
* @param rate - Interest rate (as decimal)
* @param time - Time in years
* @param n - Compounding frequency per year
* @returns Compound interest result
*/
calculateCompoundInterest(principal, rate, time, n) {
validateInput(principal, 'principal');
validateInput(rate, 'rate');
// Inefficient: recalculates power multiple times
let result = principal;
for (let i = 0; i < n * time; i++) {
result = multiply(result, add(1, rate / n));
}
const interest = result - principal;
this.history.push({ type: 'compound', result: interest });
return formatNumber(interest, this.precision);
}
}"""
        assert context.target_code.strip() == expected_code.strip(), (
            f"Extracted code does not match expected.\nExpected:\n{expected_code}\n\nGot:\n{context.target_code}"
        )

    def test_extract_compound_interest_helpers(self, js_support, cjs_project):
        """Test helper extraction for calculateCompoundInterest."""
        context = self._extract(js_support, cjs_project, "calculateCompoundInterest")
        helper_dict = {h.name: h for h in context.helper_functions}
        expected_helpers = {"add", "multiply", "formatNumber", "validateInput"}
        actual_helpers = set(helper_dict.keys())
        assert actual_helpers == expected_helpers, f"Expected helpers {expected_helpers}, got {actual_helpers}"
        # STRICT: Verify each helper's code exactly
        expected_add_code = """\
export function add(a, b) {
return a + b;
}"""
        expected_multiply_code = """\
export function multiply(a, b) {
return a * b;
}"""
        expected_format_number_code = """\
export function formatNumber(num, decimals) {
return Number(num.toFixed(decimals));
}"""
        expected_validate_input_code = """\
export function validateInput(value, name) {
if (typeof value !== 'number' || isNaN(value)) {
throw new Error(`Invalid ${name}: must be a number`);
}
}"""
        helper_expectations = {
            "add": (expected_add_code, "math_utils.js"),
            "multiply": (expected_multiply_code, "math_utils.js"),
            "formatNumber": (expected_format_number_code, "format.js"),
            "validateInput": (expected_validate_input_code, "format.js"),
        }
        for helper_name, (expected_code, expected_file) in helper_expectations.items():
            helper = helper_dict[helper_name]
            assert helper.source_code.strip() == expected_code.strip(), (
                f"{helper_name} helper code does not match expected.\n"
                f"Expected:\n{expected_code}\n\n"
                f"Got:\n{helper.source_code}"
            )
            assert str(helper.file_path).endswith(expected_file), (
                f"Expected {helper_name} to be from {expected_file}, got {helper.file_path}"
            )

    def test_extract_context_includes_imports(self, js_support, cjs_project):
        """Test import statement extraction."""
        context = self._extract(js_support, cjs_project, "calculateCompoundInterest")
        expected_imports = [
            "const { add, multiply, factorial } = require('./math_utils');",
            "const { formatNumber, validateInput } = require('./helpers/format');",
        ]
        assert len(context.imports) == 2, f"Expected 2 imports, got {len(context.imports)}: {context.imports}"
        assert context.imports == expected_imports, (
            f"Imports do not match expected.\nExpected:\n{expected_imports}\n\nGot:\n{context.imports}"
        )

    def test_extract_static_method(self, js_support, cjs_project):
        """Test static method extraction (quickAdd)."""
        context = self._extract(js_support, cjs_project, "quickAdd")
        expected_code = """\
class Calculator {
constructor(precision = 2) {
this.precision = precision;
this.history = [];
}
/**
* Static method for quick calculations.
*/
static quickAdd(a, b) {
return add(a, b);
}
}"""
        assert context.target_code.strip() == expected_code.strip(), (
            f"Extracted code does not match expected.\nExpected:\n{expected_code}\n\nGot:\n{context.target_code}"
        )
        # quickAdd uses add helper from math_utils
        helper_dict = {h.name: h for h in context.helper_functions}
        assert set(helper_dict.keys()) == {"add"}, f"Expected 'add' helper, got: {list(helper_dict.keys())}"
        expected_add_code = """\
export function add(a, b) {
return a + b;
}"""
        assert helper_dict["add"].source_code.strip() == expected_add_code.strip(), (
            f"add helper code does not match.\nExpected:\n{expected_add_code}\n\nGot:\n{helper_dict['add'].source_code}"
        )
class TestCodeExtractorESM:
    """Verify code extraction against the ES-module variant of the calculator fixture."""

    @pytest.fixture
    def esm_project(self, tmp_path):
        """Copy the ``js_esm`` fixture tree into a fresh temporary directory and return it."""
        destination = tmp_path / "esm_project"
        shutil.copytree(FIXTURES_DIR / "js_esm", destination)
        return destination

    @pytest.fixture
    def js_support(self):
        """Provide a fresh JavaScriptSupport for each test."""
        return JavaScriptSupport()

    @staticmethod
    def _calculator_functions(support, root):
        """Return whatever discovery reports for ``root/calculator.js``."""
        source_path = root / "calculator.js"
        source_text = source_path.read_text(encoding="utf-8")
        return support.discover_functions(source_text, source_path)

    def _context_for(self, support, root, wanted):
        """Extract the code context of the first discovered function named ``wanted``."""
        candidates = self._calculator_functions(support, root)
        chosen = next(fn for fn in candidates if fn.function_name == wanted)
        return support.extract_code_context(function=chosen, project_root=root, module_root=root)

    def test_discover_esm_methods(self, js_support, esm_project):
        """Discovery in the ESM project must find the same methods as the CJS one."""
        discovered = {fn.function_name for fn in self._calculator_functions(js_support, esm_project)}
        wanted = {"calculateCompoundInterest", "permutation", "quickAdd"}
        assert discovered == wanted, f"Expected methods {wanted}, got {discovered}"

    def test_esm_permutation_extraction(self, js_support, esm_project):
        """``permutation`` extracts with its class shell and exactly one helper, ``factorial``."""
        ctx = self._context_for(js_support, esm_project, "permutation")
        want_target = """\
class Calculator {
constructor(precision = 2) {
this.precision = precision;
this.history = [];
}
/**
* Calculate permutation using factorial helper.
* @param n - Total items
* @param r - Items to choose
* @returns Permutation result
*/
permutation(n, r) {
if (n < r) return 0;
// Inefficient: calculates factorial(n) fully even when not needed
return factorial(n) / factorial(n - r);
}
}"""
        assert ctx.target_code.strip() == want_target.strip(), (
            f"Extracted code does not match expected.\nExpected:\n{want_target}\n\nGot:\n{ctx.target_code}"
        )

        by_name = {helper.name: helper for helper in ctx.helper_functions}
        assert set(by_name.keys()) == {"factorial"}, f"Expected 'factorial' helper, got: {list(by_name.keys())}"
        want_factorial = """\
export function factorial(n) {
// Intentionally inefficient recursive implementation
if (n <= 1) return 1;
return n * factorial(n - 1);
}"""
        got_factorial = by_name["factorial"].source_code
        assert got_factorial.strip() == want_factorial.strip(), (
            f"factorial helper code does not match.\nExpected:\n{want_factorial}\n\nGot:\n{got_factorial}"
        )

    def test_esm_compound_interest_extraction(self, js_support, esm_project):
        """``calculateCompoundInterest`` extracts with ESM imports and four helpers."""
        ctx = self._context_for(js_support, esm_project, "calculateCompoundInterest")
        want_target = """\
class Calculator {
constructor(precision = 2) {
this.precision = precision;
this.history = [];
}
/**
* Calculate compound interest with multiple helper dependencies.
* @param principal - Initial amount
* @param rate - Interest rate (as decimal)
* @param time - Time in years
* @param n - Compounding frequency per year
* @returns Compound interest result
*/
calculateCompoundInterest(principal, rate, time, n) {
validateInput(principal, 'principal');
validateInput(rate, 'rate');
// Inefficient: recalculates power multiple times
let result = principal;
for (let i = 0; i < n * time; i++) {
result = multiply(result, add(1, rate / n));
}
const interest = result - principal;
this.history.push({ type: 'compound', result: interest });
return formatNumber(interest, this.precision);
}
}"""
        assert ctx.target_code.strip() == want_target.strip(), (
            f"Extracted code does not match expected.\nExpected:\n{want_target}\n\nGot:\n{ctx.target_code}"
        )

        want_imports = [
            "import { add, multiply, factorial } from './math_utils.js';",
            "import { formatNumber, validateInput } from './helpers/format.js';",
        ]
        got_imports = ctx.imports
        assert len(got_imports) == 2, f"Expected 2 imports, got {len(got_imports)}: {got_imports}"
        assert got_imports == want_imports, (
            f"Imports do not match expected.\nExpected:\n{want_imports}\n\nGot:\n{got_imports}"
        )

        by_name = {helper.name: helper for helper in ctx.helper_functions}
        want_names = {"validateInput", "formatNumber", "add", "multiply"}
        assert set(by_name.keys()) == want_names, (
            f"Expected helpers {want_names}, got: {set(by_name.keys())}"
        )

        # Checked in a fixed order so the first mismatch reported is deterministic.
        helper_specs = (
            (
                "validateInput",
                """\
export function validateInput(value, name) {
if (typeof value !== 'number' || isNaN(value)) {
throw new Error(`Invalid ${name}: must be a number`);
}
}""",
            ),
            (
                "formatNumber",
                """\
export function formatNumber(num, decimals) {
return Number(num.toFixed(decimals));
}""",
            ),
            (
                "add",
                """\
export function add(a, b) {
return a + b;
}""",
            ),
            (
                "multiply",
                """\
export function multiply(a, b) {
return a * b;
}""",
            ),
        )
        for name, want_source in helper_specs:
            got_source = by_name[name].source_code
            assert got_source.strip() == want_source.strip(), (
                f"{name} helper code does not match.\n"
                f"Expected:\n{want_source}\n\n"
                f"Got:\n{got_source}"
            )
class TestCodeExtractorTypeScript:
    """Tests for TypeScript code extraction."""

    @pytest.fixture
    def ts_project(self, tmp_path):
        """Create a temporary TypeScript project from fixtures."""
        project_dir = tmp_path / "ts_project"
        shutil.copytree(FIXTURES_DIR / "ts", project_dir)
        return project_dir

    @pytest.fixture
    def ts_support(self):
        """Create TypeScriptSupport instance."""
        return TypeScriptSupport()

    @staticmethod
    def _discover(ts_support, project_dir):
        """Run function discovery on the fixture project's ``calculator.ts`` and return the result."""
        calculator_file = project_dir / "calculator.ts"
        return ts_support.discover_functions(calculator_file.read_text(encoding="utf-8"), calculator_file)

    @classmethod
    def _extract(cls, ts_support, project_dir, function_name):
        """Discover ``function_name`` in ``calculator.ts`` and extract its code context.

        Fails with a descriptive assertion (instead of a bare ``StopIteration``)
        when discovery does not report the requested function.
        """
        functions = cls._discover(ts_support, project_dir)
        target = next((f for f in functions if f.function_name == function_name), None)
        assert target is not None, (
            f"{function_name} was not discovered; found {sorted(f.function_name for f in functions)}"
        )
        return ts_support.extract_code_context(function=target, project_root=project_dir, module_root=project_dir)

    def test_typescript_support_properties(self, ts_support):
        """Test TypeScriptSupport properties."""
        assert ts_support.language == Language.TYPESCRIPT
        # Only require that .ts and .tsx are handled; the support object may register
        # further extensions, so this is deliberately a subset check, not an equality check.
        expected_extensions = {".ts", ".tsx"}
        actual_extensions = set(ts_support.file_extensions)
        assert expected_extensions.issubset(actual_extensions), (
            f"Expected extensions {expected_extensions} to be subset of {actual_extensions}"
        )

    def test_discover_ts_methods(self, ts_support, ts_project):
        """Test method discovery in TypeScript."""
        method_names = {f.function_name for f in self._discover(ts_support, ts_project)}
        # TypeScript has additional getHistory method
        expected_methods = {"calculateCompoundInterest", "permutation", "getHistory", "quickAdd"}
        assert method_names == expected_methods, f"Expected methods {expected_methods}, got {method_names}"

    def test_ts_permutation_extraction(self, ts_support, ts_project):
        """Test permutation method extraction in TypeScript."""
        context = self._extract(ts_support, ts_project, "permutation")
        expected_code = """\
class Calculator {
private precision: number;
private history: HistoryEntry[];
constructor(precision: number = 2) {
this.precision = precision;
this.history = [];
}
/**
* Calculate permutation using factorial helper.
* @param n - Total items
* @param r - Items to choose
* @returns Permutation result
*/
permutation(n: number, r: number): number {
if (n < r) return 0;
// Inefficient: calculates factorial(n) fully even when not needed
return factorial(n) / factorial(n - r);
}
}"""
        assert context.target_code.strip() == expected_code.strip(), (
            f"Extracted code does not match expected.\nExpected:\n{expected_code}\n\nGot:\n{context.target_code}"
        )
        # TypeScript permutation uses factorial helper
        helper_dict = {h.name: h for h in context.helper_functions}
        assert set(helper_dict.keys()) == {"factorial"}, f"Expected 'factorial' helper, got: {list(helper_dict.keys())}"
        expected_factorial_code = """\
export function factorial(n: number): number {
// Intentionally inefficient recursive implementation
if (n <= 1) return 1;
return n * factorial(n - 1);
}"""
        assert helper_dict["factorial"].source_code.strip() == expected_factorial_code.strip(), (
            f"factorial helper code does not match.\nExpected:\n{expected_factorial_code}\n\nGot:\n{helper_dict['factorial'].source_code}"
        )

    def test_ts_compound_interest_extraction(self, ts_support, ts_project):
        """Test calculateCompoundInterest extraction in TypeScript."""
        context = self._extract(ts_support, ts_project, "calculateCompoundInterest")
        expected_code = """\
class Calculator {
private precision: number;
private history: HistoryEntry[];
constructor(precision: number = 2) {
this.precision = precision;
this.history = [];
}
/**
* Calculate compound interest with multiple helper dependencies.
* @param principal - Initial amount
* @param rate - Interest rate (as decimal)
* @param time - Time in years
* @param n - Compounding frequency per year
* @returns Compound interest result
*/
calculateCompoundInterest(principal: number, rate: number, time: number, n: number): number {
validateInput(principal, 'principal');
validateInput(rate, 'rate');
// Inefficient: recalculates power multiple times
let result = principal;
for (let i = 0; i < n * time; i++) {
result = multiply(result, add(1, rate / n));
}
const interest = result - principal;
this.history.push({ type: 'compound', result: interest });
return formatNumber(interest, this.precision);
}
}"""
        assert context.target_code.strip() == expected_code.strip(), (
            f"Extracted code does not match expected.\nExpected:\n{expected_code}\n\nGot:\n{context.target_code}"
        )
        # TypeScript compound interest uses 4 helpers
        helper_dict = {h.name: h for h in context.helper_functions}
        expected_helper_names = {"validateInput", "formatNumber", "add", "multiply"}
        assert set(helper_dict.keys()) == expected_helper_names, (
            f"Expected helpers {expected_helper_names}, got: {set(helper_dict.keys())}"
        )
        expected_validate_input_code = """\
export function validateInput(value: unknown, name: string): asserts value is number {
if (typeof value !== 'number' || isNaN(value)) {
throw new Error(`Invalid ${name}: must be a number`);
}
}"""
        expected_format_number_code = """\
export function formatNumber(num: number, decimals: number): number {
return Number(num.toFixed(decimals));
}"""
        expected_add_code = """\
export function add(a: number, b: number): number {
return a + b;
}"""
        expected_multiply_code = """\
export function multiply(a: number, b: number): number {
return a * b;
}"""
        helper_expectations = {
            "validateInput": expected_validate_input_code,
            "formatNumber": expected_format_number_code,
            "add": expected_add_code,
            "multiply": expected_multiply_code,
        }
        for helper_name, expected_code in helper_expectations.items():
            assert helper_dict[helper_name].source_code.strip() == expected_code.strip(), (
                f"{helper_name} helper code does not match.\n"
                f"Expected:\n{expected_code}\n\n"
                f"Got:\n{helper_dict[helper_name].source_code}"
            )
class TestCodeExtractorEdgeCases:
    """Tests for edge cases."""

    @pytest.fixture
    def js_support(self):
        """Create JavaScriptSupport instance."""
        return JavaScriptSupport()

    @staticmethod
    def _extract_from_source(js_support, tmp_path, file_name, source, function_name):
        """Write ``source`` to ``tmp_path / file_name`` and extract ``function_name``'s code context.

        ``tmp_path`` is used as both project and module root.
        """
        source_file = tmp_path / file_name
        # Write with the same explicit encoding used to read it back, so the
        # round-trip does not depend on the platform's locale encoding.
        source_file.write_text(source, encoding="utf-8")
        functions = js_support.discover_functions(source_file.read_text(encoding="utf-8"), source_file)
        target = next((f for f in functions if f.function_name == function_name), None)
        assert target is not None, (
            f"{function_name} was not discovered; found {sorted(f.function_name for f in functions)}"
        )
        return js_support.extract_code_context(function=target, project_root=tmp_path, module_root=tmp_path)

    def test_standalone_function(self, js_support, tmp_path):
        """Test standalone function with no helpers."""
        source = """\
export function standalone(x) {
return x * 2;
}
module.exports = { standalone };
"""
        context = self._extract_from_source(js_support, tmp_path, "standalone.js", source, "standalone")
        # STRICT: Exact code comparison
        expected_code = """\
export function standalone(x) {
return x * 2;
}"""
        assert context.target_code.strip() == expected_code.strip(), (
            f"Extracted code does not match.\nExpected:\n{expected_code}\n\nGot:\n{context.target_code}"
        )
        # STRICT: Exactly zero helpers
        assert len(context.helper_functions) == 0, (
            f"Expected 0 helpers, got {len(context.helper_functions)}: {[h.name for h in context.helper_functions]}"
        )

    def test_external_package_excluded(self, js_support, tmp_path):
        """Test external packages are not resolved as helpers."""
        source = """\
const _ = require('lodash');
export function processArray(arr) {
return _.map(arr, x => x * 2);
}
module.exports = { processArray };
"""
        context = self._extract_from_source(js_support, tmp_path, "processor.js", source, "processArray")
        expected_code = """\
export function processArray(arr) {
return _.map(arr, x => x * 2);
}"""
        assert context.target_code.strip() == expected_code.strip(), (
            f"Extracted code does not match.\nExpected:\n{expected_code}\n\nGot:\n{context.target_code}"
        )
        expected_imports = ["const _ = require('lodash');"]
        assert context.imports == expected_imports, (
            f"Imports do not match expected.\nExpected:\n{expected_imports}\n\nGot:\n{context.imports}"
        )
        helper_names = {h.name for h in context.helper_functions}
        assert helper_names == set(), f"Expected no helpers for external package usage, got: {helper_names}"

    def test_recursive_function(self, js_support, tmp_path):
        """Test recursive function doesn't list itself as helper."""
        source = """\
export function fibonacci(n) {
if (n <= 1) return n;
return fibonacci(n - 1) + fibonacci(n - 2);
}
module.exports = { fibonacci };
"""
        context = self._extract_from_source(js_support, tmp_path, "recursive.js", source, "fibonacci")
        # STRICT: Exact code comparison
        expected_code = """\
export function fibonacci(n) {
if (n <= 1) return n;
return fibonacci(n - 1) + fibonacci(n - 2);
}"""
        assert context.target_code.strip() == expected_code.strip(), (
            f"Extracted code does not match.\nExpected:\n{expected_code}\n\nGot:\n{context.target_code}"
        )
        # STRICT: Function should NOT be its own helper
        helper_names = {h.name for h in context.helper_functions}
        assert "fibonacci" not in helper_names, f"Recursive function listed itself as helper. Helpers: {helper_names}"

    def test_arrow_function_helper(self, js_support, tmp_path):
        """Test arrow function helper extraction."""
        source = """\
const helper = (x) => x * 2;
export const processValue = (value) => {
return helper(value) + 1;
};
module.exports = { processValue };
"""
        context = self._extract_from_source(js_support, tmp_path, "arrow.js", source, "processValue")
        expected_code = """\
export const processValue = (value) => {
return helper(value) + 1;
};"""
        assert context.target_code.strip() == expected_code.strip(), (
            f"Extracted code does not match.\nExpected:\n{expected_code}\n\nGot:\n{context.target_code}"
        )
        assert context.imports == [], f"Expected no imports, got: {context.imports}"
        helper_dict = {h.name: h for h in context.helper_functions}
        assert set(helper_dict.keys()) == {"helper"}, f"Expected only 'helper', got: {list(helper_dict.keys())}"
        expected_helper_code = "const helper = (x) => x * 2;"
        actual_helper_code = helper_dict["helper"].source_code.strip()
        assert actual_helper_code == expected_helper_code, (
            f"Helper code does not match.\nExpected:\n{expected_helper_code}\n\nGot:\n{actual_helper_code}"
        )
class TestClassContextExtraction:
"""Tests for class constructor and field extraction in code context."""
@pytest.fixture
def js_support(self):
"""Create JavaScriptSupport instance."""
return JavaScriptSupport()
@pytest.fixture
def ts_support(self):
"""Create TypeScriptSupport instance."""
return TypeScriptSupport()
def test_method_extraction_includes_constructor(self, js_support, tmp_path):
"""Test that extracting a class method includes the constructor."""
source = """\
export class Counter {
constructor(initial = 0) {
this.count = initial;
}
increment() {
this.count++;
return this.count;
}
}
module.exports = { Counter };
"""
test_file = tmp_path / "counter.js"
test_file.write_text(source)
functions = js_support.discover_functions(test_file.read_text(encoding="utf-8"), test_file)
increment_func = next(f for f in functions if f.function_name == "increment")
context = js_support.extract_code_context(function=increment_func, project_root=tmp_path, module_root=tmp_path)
expected_code = """\
class Counter {
constructor(initial = 0) {
this.count = initial;
}
increment() {
this.count++;
return this.count;
}
}"""
assert context.target_code.strip() == expected_code.strip(), (
f"Extracted code does not match expected.\nExpected:\n{expected_code}\n\nGot:\n{context.target_code}"
)
def test_method_extraction_class_without_constructor(self, js_support, tmp_path):
"""Test extracting a method from a class that has no constructor."""
source = """\
export class MathUtils {
add(a, b) {
return a + b;
}
multiply(a, b) {
return a * b;
}
}
module.exports = { MathUtils };
"""
test_file = tmp_path / "math_utils.js"
test_file.write_text(source)
functions = js_support.discover_functions(test_file.read_text(encoding="utf-8"), test_file)
add_func = next(f for f in functions if f.function_name == "add")
context = js_support.extract_code_context(function=add_func, project_root=tmp_path, module_root=tmp_path)
expected_code = """\
class MathUtils {
add(a, b) {
return a + b;
}
}"""
assert context.target_code.strip() == expected_code.strip(), (
f"Extracted code does not match expected.\nExpected:\n{expected_code}\n\nGot:\n{context.target_code}"
)
def test_typescript_method_extraction_includes_fields(self, ts_support, tmp_path):
"""Test that TypeScript method extraction includes class fields."""
source = """\
export class User {
private name: string;
public age: number;
constructor(name: string, age: number) {
this.name = name;
this.age = age;
}
getName(): string {
return this.name;
}
}
"""
test_file = tmp_path / "user.ts"
test_file.write_text(source)
functions = ts_support.discover_functions(test_file.read_text(encoding="utf-8"), test_file)
get_name_func = next(f for f in functions if f.function_name == "getName")
context = ts_support.extract_code_context(function=get_name_func, project_root=tmp_path, module_root=tmp_path)
expected_code = """\
class User {
private name: string;
public age: number;
constructor(name: string, age: number) {
this.name = name;
this.age = age;
}
getName(): string {
return this.name;
}
}"""
assert context.target_code.strip() == expected_code.strip(), (
f"Extracted code does not match expected.\nExpected:\n{expected_code}\n\nGot:\n{context.target_code}"
)
def test_typescript_fields_only_no_constructor(self, ts_support, tmp_path):
"""Test TypeScript class with fields but no constructor."""
source = """\
export class Config {
readonly apiUrl: string = "https://api.example.com";
timeout: number = 5000;
getUrl(): string {
return this.apiUrl;
}
}
"""
test_file = tmp_path / "config.ts"
test_file.write_text(source)
functions = ts_support.discover_functions(test_file.read_text(encoding="utf-8"), test_file)
get_url_func = next(f for f in functions if f.function_name == "getUrl")
context = ts_support.extract_code_context(function=get_url_func, project_root=tmp_path, module_root=tmp_path)
expected_code = """\
class Config {
readonly apiUrl: string = "https://api.example.com";
timeout: number = 5000;
getUrl(): string {
return this.apiUrl;
}
}"""
assert context.target_code.strip() == expected_code.strip(), (
f"Extracted code does not match expected.\nExpected:\n{expected_code}\n\nGot:\n{context.target_code}"
)
def test_constructor_with_jsdoc(self, js_support, tmp_path):
    """Test that constructor with JSDoc is fully extracted.

    Extracting ``getPrefix`` is expected to yield the ``Logger`` class including
    the constructor's JSDoc block. The trailing ``module.exports`` line is not
    part of the expected target code.
    """
    source = """\
export class Logger {
/**
* Create a new Logger instance.
* @param {string} prefix - The prefix to use for log messages.
*/
constructor(prefix) {
this.prefix = prefix;
}
getPrefix() {
return this.prefix;
}
}
module.exports = { Logger };
"""
    test_file = tmp_path / "logger.js"
    # Explicit encoding keeps the write consistent with the UTF-8 read below.
    test_file.write_text(source, encoding="utf-8")
    functions = js_support.discover_functions(test_file.read_text(encoding="utf-8"), test_file)
    get_prefix_func = next(f for f in functions if f.function_name == "getPrefix")
    context = js_support.extract_code_context(function=get_prefix_func, project_root=tmp_path, module_root=tmp_path)
    expected_code = """\
class Logger {
/**
* Create a new Logger instance.
* @param {string} prefix - The prefix to use for log messages.
*/
constructor(prefix) {
this.prefix = prefix;
}
getPrefix() {
return this.prefix;
}
}"""
    assert context.target_code.strip() == expected_code.strip(), (
        f"Extracted code does not match expected.\nExpected:\n{expected_code}\n\nGot:\n{context.target_code}"
    )
def test_static_method_includes_constructor(self, js_support, tmp_path):
    """Test that static method extraction also includes constructor context.

    Extracting the static ``create`` method is expected to yield the ``Factory``
    class with its constructor, without the ``export`` keyword or the
    ``module.exports`` line.
    """
    source = """\
export class Factory {
constructor(config) {
this.config = config;
}
static create(type) {
return new Factory({ type: type });
}
}
module.exports = { Factory };
"""
    test_file = tmp_path / "factory.js"
    # Explicit encoding keeps the write consistent with the UTF-8 read below.
    test_file.write_text(source, encoding="utf-8")
    functions = js_support.discover_functions(test_file.read_text(encoding="utf-8"), test_file)
    create_func = next(f for f in functions if f.function_name == "create")
    context = js_support.extract_code_context(function=create_func, project_root=tmp_path, module_root=tmp_path)
    expected_code = """\
class Factory {
constructor(config) {
this.config = config;
}
static create(type) {
return new Factory({ type: type });
}
}"""
    assert context.target_code.strip() == expected_code.strip(), (
        f"Extracted code does not match expected.\nExpected:\n{expected_code}\n\nGot:\n{context.target_code}"
    )
class TestCodeExtractorIntegration:
    """Integration tests with FunctionOptimizer."""

    @pytest.fixture
    def cjs_project(self, tmp_path):
        """Create a temporary CJS project from fixtures."""
        project_dir = tmp_path / "cjs_project"
        shutil.copytree(FIXTURES_DIR / "js_cjs", project_dir)
        return project_dir

    def test_function_optimizer_workflow(self, cjs_project, monkeypatch):
        """Test full FunctionOptimizer workflow.

        Discovers ``permutation`` in the fixture's ``calculator.js``, builds its
        optimization context through ``JavaScriptFunctionOptimizer`` and checks
        that ``factorial`` is reported among the helper functions.
        """
        from codeflash.languages import current as lang_current

        # Force the current language to JavaScript so context extraction is
        # routed correctly. monkeypatch restores the previous value (or removes
        # the attribute if it was absent) after the test, so this global setting
        # does not leak into other tests.
        monkeypatch.setattr(lang_current, "_current_language", Language.JAVASCRIPT, raising=False)

        js_support = get_language_support("javascript")
        calculator_file = cjs_project / "calculator.js"
        functions = js_support.discover_functions(calculator_file.read_text(encoding="utf-8"), calculator_file)
        target = next(f for f in functions if f.function_name == "permutation")

        parents = [FunctionParent(name=p.name, type=p.type) for p in target.parents]
        func = FunctionToOptimize(
            function_name=target.function_name,
            file_path=target.file_path,
            parents=parents,
            starting_line=target.starting_line,
            ending_line=target.ending_line,
            starting_col=target.starting_col,
            ending_col=target.ending_col,
            is_async=target.is_async,
            is_method=target.is_method,
            language=target.language,
        )
        test_config = TestConfig(
            tests_root=cjs_project / "tests",
            tests_project_rootdir=cjs_project,
            project_root_path=cjs_project,
            pytest_cmd="jest",
        )
        func_optimizer = JavaScriptFunctionOptimizer(
            function_to_optimize=func, test_cfg=test_config, aiservice_client=MagicMock()
        )

        result = func_optimizer.get_code_optimization_context()
        context = result.unwrap()
        assert context.read_writable_code is not None, "read_writable_code should not be None"

        # FunctionSource uses only_function_name, not name
        helper_names = {h.only_function_name for h in context.helper_functions}
        assert "factorial" in helper_names, f"factorial helper not found. Found: {helper_names}"
class TestTypeDefinitionExtraction:
"""Tests for TypeScript type definition extraction in read-only context."""
@pytest.fixture
def ts_support(self):
    """Provide a new TypeScriptSupport object to each test that requests it."""
    support = TypeScriptSupport()
    return support
@pytest.fixture
def ts_types_project(self, tmp_path):
    """Create a temporary TypeScript project with type definitions.

    The project directory contains a single ``types.ts`` that exports the
    ``CalculationConfig`` and ``Point`` interfaces, the ``RoundingMode`` enum and
    the generic ``Result<T>`` type alias, each preceded by a JSDoc comment.
    Returns the project directory path.
    """
    project_dir = tmp_path / "ts_types_project"
    project_dir.mkdir()
    # Create types.ts with type definitions; encoding is explicit so the file
    # contents do not depend on the platform's locale encoding.
    types_file = project_dir / "types.ts"
    types_file.write_text("""\
/**
* Configuration options for calculations.
*/
export interface CalculationConfig {
precision: number;
enableCaching: boolean;
}
/**
* Point in 2D space.
*/
export interface Point {
x: number;
y: number;
}
/**
* Rounding mode enum.
*/
export enum RoundingMode {
FLOOR = 'floor',
CEIL = 'ceil',
ROUND = 'round',
}
/**
* Result type alias.
*/
export type Result<T> = {
value: T;
success: boolean;
};
""", encoding="utf-8")
    return project_dir
def test_extract_same_file_interface_from_parameter(self, ts_support, tmp_path):
    """Test extracting interface type definition when used in function parameter.

    ``distance`` takes two ``Point`` parameters, and ``Point`` is declared in the
    same file. Its declaration is expected as the entire read-only context.
    """
    source = """\
interface Point {
x: number;
y: number;
}
export function distance(p1: Point, p2: Point): number {
const dx = p2.x - p1.x;
const dy = p2.y - p1.y;
return Math.sqrt(dx * dx + dy * dy);
}
"""
    test_file = tmp_path / "geometry.ts"
    # Explicit encoding keeps the write consistent with the UTF-8 read below.
    test_file.write_text(source, encoding="utf-8")
    functions = ts_support.discover_functions(test_file.read_text(encoding="utf-8"), test_file)
    distance_func = next(f for f in functions if f.function_name == "distance")
    context = ts_support.extract_code_context(function=distance_func, project_root=tmp_path, module_root=tmp_path)
    # Type definition should be in read-only context with exact match
    expected_read_only = """\
interface Point {
x: number;
y: number;
}"""
    assert context.read_only_context is not None, "read_only_context should not be None"
    assert context.read_only_context.strip() == expected_read_only.strip(), (
        f"Read-only context does not match expected.\n"
        f"Expected:\n{expected_read_only}\n\n"
        f"Got:\n{context.read_only_context}"
    )
def test_extract_same_file_enum_from_parameter(self, ts_support, tmp_path):
    """Test extracting enum type definition when used in function parameter.

    ``processStatus`` takes a ``Status`` parameter, and ``Status`` is a
    same-file enum. The enum declaration is expected as the entire read-only
    context.
    """
    source = """\
enum Status {
PENDING = 'pending',
SUCCESS = 'success',
FAILURE = 'failure',
}
export function processStatus(status: Status): string {
switch (status) {
case Status.PENDING:
return 'Processing...';
case Status.SUCCESS:
return 'Done!';
case Status.FAILURE:
return 'Failed!';
}
}
"""
    test_file = tmp_path / "status.ts"
    # Explicit encoding keeps the write consistent with the UTF-8 read below.
    test_file.write_text(source, encoding="utf-8")
    functions = ts_support.discover_functions(test_file.read_text(encoding="utf-8"), test_file)
    process_func = next(f for f in functions if f.function_name == "processStatus")
    context = ts_support.extract_code_context(function=process_func, project_root=tmp_path, module_root=tmp_path)
    # Enum should be in read-only context with exact match
    expected_read_only = """\
enum Status {
PENDING = 'pending',
SUCCESS = 'success',
FAILURE = 'failure',
}"""
    assert context.read_only_context is not None, "read_only_context should not be None"
    assert context.read_only_context.strip() == expected_read_only.strip(), (
        f"Read-only context does not match expected.\n"
        f"Expected:\n{expected_read_only}\n\n"
        f"Got:\n{context.read_only_context}"
    )
def test_extract_same_file_type_alias_from_return_type(self, ts_support, tmp_path):
    """Test extracting type alias when used in function return type.

    ``compute`` returns ``Result<number>``, and the generic ``Result<T>`` alias
    is declared in the same file. The alias is expected as the entire
    read-only context.
    """
    source = """\
type Result<T> = {
value: T;
success: boolean;
};
export function compute(x: number): Result<number> {
return { value: x * 2, success: true };
}
"""
    test_file = tmp_path / "compute.ts"
    # Explicit encoding keeps the write consistent with the UTF-8 read below.
    test_file.write_text(source, encoding="utf-8")
    functions = ts_support.discover_functions(test_file.read_text(encoding="utf-8"), test_file)
    compute_func = next(f for f in functions if f.function_name == "compute")
    context = ts_support.extract_code_context(function=compute_func, project_root=tmp_path, module_root=tmp_path)
    # Type alias should be in read-only context with exact match
    expected_read_only = """\
type Result<T> = {
value: T;
success: boolean;
};"""
    assert context.read_only_context is not None, "read_only_context should not be None"
    assert context.read_only_context.strip() == expected_read_only.strip(), (
        f"Read-only context does not match expected.\n"
        f"Expected:\n{expected_read_only}\n\n"
        f"Got:\n{context.read_only_context}"
    )
def test_extract_class_field_types(self, ts_support, tmp_path):
    """Test extracting type definitions used in class fields.

    ``Service.getTimeout`` has no typed parameters of its own. ``Config`` is
    referenced by the class field and the constructor parameter, and its
    interface is expected as the entire read-only context.
    """
    source = """\
interface Config {
timeout: number;
retries: number;
}
export class Service {
private config: Config;
constructor(config: Config) {
this.config = config;
}
getTimeout(): number {
return this.config.timeout;
}
}
"""
    test_file = tmp_path / "service.ts"
    # Explicit encoding keeps the write consistent with the UTF-8 read below.
    test_file.write_text(source, encoding="utf-8")
    functions = ts_support.discover_functions(test_file.read_text(encoding="utf-8"), test_file)
    get_timeout_func = next(f for f in functions if f.function_name == "getTimeout")
    context = ts_support.extract_code_context(
        function=get_timeout_func, project_root=tmp_path, module_root=tmp_path
    )
    # Config interface should be in read-only context with exact match
    expected_read_only = """\
interface Config {
timeout: number;
retries: number;
}"""
    assert context.read_only_context is not None, "read_only_context should not be None"
    assert context.read_only_context.strip() == expected_read_only.strip(), (
        f"Read-only context does not match expected.\n"
        f"Expected:\n{expected_read_only}\n\n"
        f"Got:\n{context.read_only_context}"
    )
def test_primitive_types_not_included(self, ts_support, tmp_path):
    """Test that primitive types (number, string, etc.) are not extracted.

    ``add`` uses only ``number``, so the read-only context is expected to be
    exactly the empty string.
    """
    source = """\
export function add(a: number, b: number): number {
return a + b;
}
"""
    test_file = tmp_path / "add.ts"
    # Explicit encoding keeps the write consistent with the UTF-8 read below.
    test_file.write_text(source, encoding="utf-8")
    functions = ts_support.discover_functions(test_file.read_text(encoding="utf-8"), test_file)
    add_func = next(f for f in functions if f.function_name == "add")
    context = ts_support.extract_code_context(function=add_func, project_root=tmp_path, module_root=tmp_path)
    # No type definitions should be extracted for primitives - exact empty match
    assert context.read_only_context == "", (
        f"Should not extract type definitions for primitive types.\n"
        f"Expected empty string, got:\n{context.read_only_context}"
    )
def test_extract_multiple_types(self, ts_support, tmp_path):
    """Test extracting multiple type definitions from same file.

    ``createRect`` references both ``Point`` and ``Size``. Both interfaces are
    expected in the read-only context, in source order.
    """
    source = """\
interface Point {
x: number;
y: number;
}
interface Size {
width: number;
height: number;
}
export function createRect(origin: Point, size: Size): { origin: Point; size: Size } {
return { origin, size };
}
"""
    test_file = tmp_path / "rect.ts"
    # Explicit encoding keeps the write consistent with the UTF-8 read below.
    test_file.write_text(source, encoding="utf-8")
    functions = ts_support.discover_functions(test_file.read_text(encoding="utf-8"), test_file)
    create_rect_func = next(f for f in functions if f.function_name == "createRect")
    context = ts_support.extract_code_context(
        function=create_rect_func, project_root=tmp_path, module_root=tmp_path
    )
    # Both Point and Size should be in read-only context with exact match
    expected_read_only = """\
interface Point {
x: number;
y: number;
}
interface Size {
width: number;
height: number;
}"""
    assert context.read_only_context is not None, "read_only_context should not be None"
    assert context.read_only_context.strip() == expected_read_only.strip(), (
        f"Read-only context does not match expected.\n"
        f"Expected:\n{expected_read_only}\n\n"
        f"Got:\n{context.read_only_context}"
    )
def test_extract_imported_type_definition(self, ts_support, ts_types_project):
    """Test extracting type definitions from imported files.

    ``geometry.ts`` imports ``Point`` and ``CalculationConfig`` from the
    fixture's ``types.ts``. Only those two interfaces are expected in the
    read-only context, under a ``// From types.ts`` header. The unused
    ``RoundingMode`` and ``Result`` declarations are not expected.
    """
    # Create a file that imports types from types.ts; encoding is explicit so
    # the write matches the UTF-8 read below.
    geometry_file = ts_types_project / "geometry.ts"
    geometry_file.write_text("""\
import { Point, CalculationConfig } from './types';
export function calculateDistance(p1: Point, p2: Point, config: CalculationConfig): number {
const dx = p2.x - p1.x;
const dy = p2.y - p1.y;
const distance = Math.sqrt(dx * dx + dy * dy);
if (config.precision > 0) {
const factor = Math.pow(10, config.precision);
return Math.round(distance * factor) / factor;
}
return distance;
}
""", encoding="utf-8")
    functions = ts_support.discover_functions(geometry_file.read_text(encoding="utf-8"), geometry_file)
    calc_distance_func = next(f for f in functions if f.function_name == "calculateDistance")
    context = ts_support.extract_code_context(
        function=calc_distance_func, project_root=ts_types_project, module_root=ts_types_project
    )
    # Imported type definitions should be in read-only context with exact match
    # Types are sorted by file path and line number, with file comments
    # Note: The extraction uses tree-sitter which doesn't capture JSDoc for interface
    # definitions in separate files - this is a known limitation
    expected_read_only = """\
// From types.ts
interface CalculationConfig {
precision: number;
enableCaching: boolean;
}
interface Point {
x: number;
y: number;
}"""
    assert context.read_only_context is not None, "read_only_context should not be None"
    assert context.read_only_context.strip() == expected_read_only.strip(), (
        f"Read-only context does not match expected.\n"
        f"Expected:\n{expected_read_only}\n\n"
        f"Got:\n{context.read_only_context}"
    )
def test_type_with_jsdoc_included(self, ts_support, tmp_path):
    """Test that JSDoc comments are included with type definitions.

    ``User`` is declared in the same file as ``greetUser`` with a preceding
    JSDoc block. The read-only context is expected to contain the JSDoc
    followed by the interface.
    """
    source = """\
/**
* Represents a user in the system.
* @property id - Unique identifier
* @property name - Display name
*/
interface User {
id: string;
name: string;
}
export function greetUser(user: User): string {
return `Hello, ${user.name}!`;
}
"""
    test_file = tmp_path / "user.ts"
    # Explicit encoding keeps the write consistent with the UTF-8 read below.
    test_file.write_text(source, encoding="utf-8")
    functions = ts_support.discover_functions(test_file.read_text(encoding="utf-8"), test_file)
    greet_func = next(f for f in functions if f.function_name == "greetUser")
    context = ts_support.extract_code_context(function=greet_func, project_root=tmp_path, module_root=tmp_path)
    # JSDoc should be included with the interface - exact match
    expected_read_only = """\
/**
* Represents a user in the system.
* @property id - Unique identifier
* @property name - Display name
*/
interface User {
id: string;
name: string;
}"""
    assert context.read_only_context is not None, "read_only_context should not be None"
    assert context.read_only_context.strip() == expected_read_only.strip(), (
        f"Read-only context does not match expected.\n"
        f"Expected:\n{expected_read_only}\n\n"
        f"Got:\n{context.read_only_context}"
    )