mirror of
https://github.com/codeflash-ai/codeflash.git
synced 2026-05-04 18:25:17 +00:00
34 test files updated with main's refactored tests for new language support protocol, JS/TS improvements, and code context extraction.
1487 lines
52 KiB
Python
1487 lines
52 KiB
Python
"""Tests for JavaScript/TypeScript code extractor.
|
|
|
|
Uses strict string equality to verify extraction results.
|
|
"""
|
|
|
|
import shutil
|
|
from pathlib import Path
|
|
from unittest.mock import MagicMock
|
|
|
|
import pytest
|
|
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
|
|
from codeflash.languages.base import Language
|
|
from codeflash.languages.javascript.support import JavaScriptSupport, TypeScriptSupport
|
|
from codeflash.languages.registry import get_language_support
|
|
from codeflash.models.models import FunctionParent
|
|
from codeflash.languages.javascript.function_optimizer import JavaScriptFunctionOptimizer
|
|
from codeflash.verification.verification_utils import TestConfig
|
|
|
|
FIXTURES_DIR = Path(__file__).parent / "fixtures"
|
|
|
|
|
|
class TestCodeExtractorCJS:
|
|
"""Tests for CommonJS module code extraction."""
|
|
|
|
@pytest.fixture
def cjs_project(self, tmp_path):
    """Copy the ``js_cjs`` fixture tree into ``tmp_path`` and return the copy.

    The copy lives at ``tmp_path / "cjs_project"``, so tests operate on a
    duplicate of the files under ``FIXTURES_DIR`` rather than the originals.
    """
    fixture_tree = FIXTURES_DIR / "js_cjs"
    destination = tmp_path / "cjs_project"
    shutil.copytree(fixture_tree, destination)
    return destination
|
|
|
|
@pytest.fixture
def js_support(self):
    """Provide a newly constructed ``JavaScriptSupport`` object."""
    support = JavaScriptSupport()
    return support
|
|
|
|
def test_discover_class_methods(self, js_support, cjs_project):
    """Discovery on ``calculator.js`` must yield exactly the expected function names."""
    source_path = cjs_project / "calculator.js"
    source_text = source_path.read_text(encoding="utf-8")
    discovered = js_support.discover_functions(source_text, source_path)

    expected_methods = {"calculateCompoundInterest", "permutation", "quickAdd"}

    # Collapse to a set: the comparison is about which names exist, not their order.
    method_names = set()
    for entry in discovered:
        method_names.add(entry.function_name)

    assert method_names == expected_methods, f"Expected methods {expected_methods}, got {method_names}"
|
|
|
|
def test_class_method_has_correct_parent(self, js_support, cjs_project):
    """Check that every function discovered in ``calculator.js`` is a Calculator method.

    Each discovered entry must report ``is_method is True`` and
    ``class_name == "Calculator"``.
    """
    calculator_file = cjs_project / "calculator.js"
    functions = js_support.discover_functions(calculator_file.read_text(encoding="utf-8"), calculator_file)

    # Guard against a vacuous pass: with an empty result the loop below
    # would execute zero assertions and the test would succeed silently.
    assert functions, "No functions were discovered in calculator.js"

    for func in functions:
        # All methods should belong to Calculator class
        assert func.is_method is True, f"{func.function_name} should be a method"
        assert func.class_name == "Calculator", (
            f"{func.function_name} should belong to Calculator, got {func.class_name}"
        )
|
|
|
|
def test_extract_permutation_code(self, js_support, cjs_project):
|
|
"""Test permutation method code extraction."""
|
|
calculator_file = cjs_project / "calculator.js"
|
|
functions = js_support.discover_functions(calculator_file.read_text(encoding="utf-8"), calculator_file)
|
|
|
|
permutation_func = next(f for f in functions if f.function_name == "permutation")
|
|
|
|
context = js_support.extract_code_context(
|
|
function=permutation_func, project_root=cjs_project, module_root=cjs_project
|
|
)
|
|
|
|
expected_code = """\
|
|
class Calculator {
|
|
constructor(precision = 2) {
|
|
this.precision = precision;
|
|
this.history = [];
|
|
}
|
|
|
|
/**
|
|
* Calculate permutation using factorial helper.
|
|
* @param n - Total items
|
|
* @param r - Items to choose
|
|
* @returns Permutation result
|
|
*/
|
|
permutation(n, r) {
|
|
if (n < r) return 0;
|
|
// Inefficient: calculates factorial(n) fully even when not needed
|
|
return factorial(n) / factorial(n - r);
|
|
}
|
|
}"""
|
|
|
|
assert context.target_code is not None, "target_code should not be None"
|
|
assert context.target_code.strip() == expected_code.strip(), (
|
|
f"Extracted code does not match expected.\nExpected:\n{expected_code}\n\nGot:\n{context.target_code}"
|
|
)
|
|
|
|
def test_extract_context_includes_direct_helpers(self, js_support, cjs_project):
|
|
"""Test that direct helper functions are included in context."""
|
|
calculator_file = cjs_project / "calculator.js"
|
|
functions = js_support.discover_functions(calculator_file.read_text(encoding="utf-8"), calculator_file)
|
|
|
|
permutation_func = next(f for f in functions if f.function_name == "permutation")
|
|
|
|
context = js_support.extract_code_context(
|
|
function=permutation_func, project_root=cjs_project, module_root=cjs_project
|
|
)
|
|
|
|
# Find factorial helper
|
|
helper_dict = {h.name: h for h in context.helper_functions}
|
|
|
|
assert "factorial" in helper_dict, f"factorial helper not found. Found helpers: {list(helper_dict.keys())}"
|
|
|
|
factorial_helper = helper_dict["factorial"]
|
|
|
|
expected_factorial_code = """\
|
|
export function factorial(n) {
|
|
// Intentionally inefficient recursive implementation
|
|
if (n <= 1) return 1;
|
|
return n * factorial(n - 1);
|
|
}"""
|
|
|
|
assert factorial_helper.source_code.strip() == expected_factorial_code.strip(), (
|
|
f"Factorial helper code does not match expected.\n"
|
|
f"Expected:\n{expected_factorial_code}\n\n"
|
|
f"Got:\n{factorial_helper.source_code}"
|
|
)
|
|
|
|
# STRICT: Verify file path ends with expected filename
|
|
assert str(factorial_helper.file_path).endswith("math_utils.js"), (
|
|
f"Expected factorial to be from math_utils.js, got {factorial_helper.file_path}"
|
|
)
|
|
|
|
def test_extract_compound_interest_code(self, js_support, cjs_project):
|
|
"""Test calculateCompoundInterest code extraction."""
|
|
calculator_file = cjs_project / "calculator.js"
|
|
functions = js_support.discover_functions(calculator_file.read_text(encoding="utf-8"), calculator_file)
|
|
|
|
compound_func = next(f for f in functions if f.function_name == "calculateCompoundInterest")
|
|
|
|
context = js_support.extract_code_context(
|
|
function=compound_func, project_root=cjs_project, module_root=cjs_project
|
|
)
|
|
|
|
expected_code = """\
|
|
class Calculator {
|
|
constructor(precision = 2) {
|
|
this.precision = precision;
|
|
this.history = [];
|
|
}
|
|
|
|
/**
|
|
* Calculate compound interest with multiple helper dependencies.
|
|
* @param principal - Initial amount
|
|
* @param rate - Interest rate (as decimal)
|
|
* @param time - Time in years
|
|
* @param n - Compounding frequency per year
|
|
* @returns Compound interest result
|
|
*/
|
|
calculateCompoundInterest(principal, rate, time, n) {
|
|
validateInput(principal, 'principal');
|
|
validateInput(rate, 'rate');
|
|
|
|
// Inefficient: recalculates power multiple times
|
|
let result = principal;
|
|
for (let i = 0; i < n * time; i++) {
|
|
result = multiply(result, add(1, rate / n));
|
|
}
|
|
|
|
const interest = result - principal;
|
|
this.history.push({ type: 'compound', result: interest });
|
|
return formatNumber(interest, this.precision);
|
|
}
|
|
}"""
|
|
|
|
assert context.target_code.strip() == expected_code.strip(), (
|
|
f"Extracted code does not match expected.\nExpected:\n{expected_code}\n\nGot:\n{context.target_code}"
|
|
)
|
|
|
|
def test_extract_compound_interest_helpers(self, js_support, cjs_project):
|
|
"""Test helper extraction for calculateCompoundInterest."""
|
|
calculator_file = cjs_project / "calculator.js"
|
|
functions = js_support.discover_functions(calculator_file.read_text(encoding="utf-8"), calculator_file)
|
|
|
|
compound_func = next(f for f in functions if f.function_name == "calculateCompoundInterest")
|
|
|
|
context = js_support.extract_code_context(
|
|
function=compound_func, project_root=cjs_project, module_root=cjs_project
|
|
)
|
|
|
|
helper_dict = {h.name: h for h in context.helper_functions}
|
|
|
|
expected_helpers = {"add", "multiply", "formatNumber", "validateInput"}
|
|
actual_helpers = set(helper_dict.keys())
|
|
assert actual_helpers == expected_helpers, f"Expected helpers {expected_helpers}, got {actual_helpers}"
|
|
|
|
# STRICT: Verify each helper's code exactly
|
|
expected_add_code = """\
|
|
export function add(a, b) {
|
|
return a + b;
|
|
}"""
|
|
|
|
expected_multiply_code = """\
|
|
export function multiply(a, b) {
|
|
return a * b;
|
|
}"""
|
|
|
|
expected_format_number_code = """\
|
|
export function formatNumber(num, decimals) {
|
|
return Number(num.toFixed(decimals));
|
|
}"""
|
|
|
|
expected_validate_input_code = """\
|
|
export function validateInput(value, name) {
|
|
if (typeof value !== 'number' || isNaN(value)) {
|
|
throw new Error(`Invalid ${name}: must be a number`);
|
|
}
|
|
}"""
|
|
|
|
helper_expectations = {
|
|
"add": (expected_add_code, "math_utils.js"),
|
|
"multiply": (expected_multiply_code, "math_utils.js"),
|
|
"formatNumber": (expected_format_number_code, "format.js"),
|
|
"validateInput": (expected_validate_input_code, "format.js"),
|
|
}
|
|
|
|
for helper_name, (expected_code, expected_file) in helper_expectations.items():
|
|
helper = helper_dict[helper_name]
|
|
|
|
assert helper.source_code.strip() == expected_code.strip(), (
|
|
f"{helper_name} helper code does not match expected.\n"
|
|
f"Expected:\n{expected_code}\n\n"
|
|
f"Got:\n{helper.source_code}"
|
|
)
|
|
|
|
assert str(helper.file_path).endswith(expected_file), (
|
|
f"Expected {helper_name} to be from {expected_file}, got {helper.file_path}"
|
|
)
|
|
|
|
def test_extract_context_includes_imports(self, js_support, cjs_project):
|
|
"""Test import statement extraction."""
|
|
calculator_file = cjs_project / "calculator.js"
|
|
functions = js_support.discover_functions(calculator_file.read_text(encoding="utf-8"), calculator_file)
|
|
|
|
compound_func = next(f for f in functions if f.function_name == "calculateCompoundInterest")
|
|
|
|
context = js_support.extract_code_context(
|
|
function=compound_func, project_root=cjs_project, module_root=cjs_project
|
|
)
|
|
|
|
expected_imports = [
|
|
"const { add, multiply, factorial } = require('./math_utils');",
|
|
"const { formatNumber, validateInput } = require('./helpers/format');",
|
|
]
|
|
|
|
assert len(context.imports) == 2, f"Expected 2 imports, got {len(context.imports)}: {context.imports}"
|
|
assert context.imports == expected_imports, (
|
|
f"Imports do not match expected.\nExpected:\n{expected_imports}\n\nGot:\n{context.imports}"
|
|
)
|
|
|
|
def test_extract_static_method(self, js_support, cjs_project):
|
|
"""Test static method extraction (quickAdd)."""
|
|
calculator_file = cjs_project / "calculator.js"
|
|
functions = js_support.discover_functions(calculator_file.read_text(encoding="utf-8"), calculator_file)
|
|
|
|
quick_add_func = next(f for f in functions if f.function_name == "quickAdd")
|
|
|
|
context = js_support.extract_code_context(
|
|
function=quick_add_func, project_root=cjs_project, module_root=cjs_project
|
|
)
|
|
|
|
expected_code = """\
|
|
class Calculator {
|
|
constructor(precision = 2) {
|
|
this.precision = precision;
|
|
this.history = [];
|
|
}
|
|
|
|
/**
|
|
* Static method for quick calculations.
|
|
*/
|
|
static quickAdd(a, b) {
|
|
return add(a, b);
|
|
}
|
|
}"""
|
|
|
|
assert context.target_code.strip() == expected_code.strip(), (
|
|
f"Extracted code does not match expected.\nExpected:\n{expected_code}\n\nGot:\n{context.target_code}"
|
|
)
|
|
|
|
# quickAdd uses add helper from math_utils
|
|
helper_dict = {h.name: h for h in context.helper_functions}
|
|
assert set(helper_dict.keys()) == {"add"}, f"Expected 'add' helper, got: {list(helper_dict.keys())}"
|
|
|
|
expected_add_code = """\
|
|
export function add(a, b) {
|
|
return a + b;
|
|
}"""
|
|
|
|
assert helper_dict["add"].source_code.strip() == expected_add_code.strip(), (
|
|
f"add helper code does not match.\nExpected:\n{expected_add_code}\n\nGot:\n{helper_dict['add'].source_code}"
|
|
)
|
|
|
|
|
|
class TestCodeExtractorESM:
|
|
"""Tests for ES Module code extraction."""
|
|
|
|
@pytest.fixture
def esm_project(self, tmp_path):
    """Copy the ``js_esm`` fixture tree into ``tmp_path`` and return the copy.

    The copy lives at ``tmp_path / "esm_project"``, so tests operate on a
    duplicate of the files under ``FIXTURES_DIR`` rather than the originals.
    """
    fixture_tree = FIXTURES_DIR / "js_esm"
    destination = tmp_path / "esm_project"
    shutil.copytree(fixture_tree, destination)
    return destination
|
|
|
|
@pytest.fixture
def js_support(self):
    """Provide a newly constructed ``JavaScriptSupport`` object."""
    support = JavaScriptSupport()
    return support
|
|
|
|
def test_discover_esm_methods(self, js_support, esm_project):
    """Discovery on the ESM ``calculator.js`` must yield exactly the expected names."""
    source_path = esm_project / "calculator.js"
    source_text = source_path.read_text(encoding="utf-8")
    discovered = js_support.discover_functions(source_text, source_path)

    # Same name set that the CommonJS discovery test expects.
    expected_methods = {"calculateCompoundInterest", "permutation", "quickAdd"}

    method_names = set()
    for entry in discovered:
        method_names.add(entry.function_name)

    assert method_names == expected_methods, f"Expected methods {expected_methods}, got {method_names}"
|
|
|
|
def test_esm_permutation_extraction(self, js_support, esm_project):
|
|
"""Test permutation method extraction in ESM."""
|
|
calculator_file = esm_project / "calculator.js"
|
|
functions = js_support.discover_functions(calculator_file.read_text(encoding="utf-8"), calculator_file)
|
|
|
|
permutation_func = next(f for f in functions if f.function_name == "permutation")
|
|
|
|
context = js_support.extract_code_context(
|
|
function=permutation_func, project_root=esm_project, module_root=esm_project
|
|
)
|
|
|
|
expected_code = """\
|
|
class Calculator {
|
|
constructor(precision = 2) {
|
|
this.precision = precision;
|
|
this.history = [];
|
|
}
|
|
|
|
/**
|
|
* Calculate permutation using factorial helper.
|
|
* @param n - Total items
|
|
* @param r - Items to choose
|
|
* @returns Permutation result
|
|
*/
|
|
permutation(n, r) {
|
|
if (n < r) return 0;
|
|
// Inefficient: calculates factorial(n) fully even when not needed
|
|
return factorial(n) / factorial(n - r);
|
|
}
|
|
}"""
|
|
|
|
assert context.target_code.strip() == expected_code.strip(), (
|
|
f"Extracted code does not match expected.\nExpected:\n{expected_code}\n\nGot:\n{context.target_code}"
|
|
)
|
|
|
|
# ESM permutation uses factorial helper
|
|
helper_dict = {h.name: h for h in context.helper_functions}
|
|
assert set(helper_dict.keys()) == {"factorial"}, f"Expected 'factorial' helper, got: {list(helper_dict.keys())}"
|
|
|
|
expected_factorial_code = """\
|
|
export function factorial(n) {
|
|
// Intentionally inefficient recursive implementation
|
|
if (n <= 1) return 1;
|
|
return n * factorial(n - 1);
|
|
}"""
|
|
|
|
assert helper_dict["factorial"].source_code.strip() == expected_factorial_code.strip(), (
|
|
f"factorial helper code does not match.\nExpected:\n{expected_factorial_code}\n\nGot:\n{helper_dict['factorial'].source_code}"
|
|
)
|
|
|
|
def test_esm_compound_interest_extraction(self, js_support, esm_project):
|
|
"""Test calculateCompoundInterest extraction in ESM with import syntax."""
|
|
calculator_file = esm_project / "calculator.js"
|
|
functions = js_support.discover_functions(calculator_file.read_text(encoding="utf-8"), calculator_file)
|
|
|
|
compound_func = next(f for f in functions if f.function_name == "calculateCompoundInterest")
|
|
|
|
context = js_support.extract_code_context(
|
|
function=compound_func, project_root=esm_project, module_root=esm_project
|
|
)
|
|
|
|
expected_code = """\
|
|
class Calculator {
|
|
constructor(precision = 2) {
|
|
this.precision = precision;
|
|
this.history = [];
|
|
}
|
|
|
|
/**
|
|
* Calculate compound interest with multiple helper dependencies.
|
|
* @param principal - Initial amount
|
|
* @param rate - Interest rate (as decimal)
|
|
* @param time - Time in years
|
|
* @param n - Compounding frequency per year
|
|
* @returns Compound interest result
|
|
*/
|
|
calculateCompoundInterest(principal, rate, time, n) {
|
|
validateInput(principal, 'principal');
|
|
validateInput(rate, 'rate');
|
|
|
|
// Inefficient: recalculates power multiple times
|
|
let result = principal;
|
|
for (let i = 0; i < n * time; i++) {
|
|
result = multiply(result, add(1, rate / n));
|
|
}
|
|
|
|
const interest = result - principal;
|
|
this.history.push({ type: 'compound', result: interest });
|
|
return formatNumber(interest, this.precision);
|
|
}
|
|
}"""
|
|
|
|
assert context.target_code.strip() == expected_code.strip(), (
|
|
f"Extracted code does not match expected.\nExpected:\n{expected_code}\n\nGot:\n{context.target_code}"
|
|
)
|
|
|
|
expected_imports = [
|
|
"import { add, multiply, factorial } from './math_utils.js';",
|
|
"import { formatNumber, validateInput } from './helpers/format.js';",
|
|
]
|
|
|
|
assert len(context.imports) == 2, f"Expected 2 imports, got {len(context.imports)}: {context.imports}"
|
|
assert context.imports == expected_imports, (
|
|
f"Imports do not match expected.\nExpected:\n{expected_imports}\n\nGot:\n{context.imports}"
|
|
)
|
|
|
|
# ESM compound interest uses 4 helpers
|
|
helper_dict = {h.name: h for h in context.helper_functions}
|
|
expected_helper_names = {"validateInput", "formatNumber", "add", "multiply"}
|
|
assert set(helper_dict.keys()) == expected_helper_names, (
|
|
f"Expected helpers {expected_helper_names}, got: {set(helper_dict.keys())}"
|
|
)
|
|
|
|
expected_validate_input_code = """\
|
|
export function validateInput(value, name) {
|
|
if (typeof value !== 'number' || isNaN(value)) {
|
|
throw new Error(`Invalid ${name}: must be a number`);
|
|
}
|
|
}"""
|
|
|
|
expected_format_number_code = """\
|
|
export function formatNumber(num, decimals) {
|
|
return Number(num.toFixed(decimals));
|
|
}"""
|
|
|
|
expected_add_code = """\
|
|
export function add(a, b) {
|
|
return a + b;
|
|
}"""
|
|
|
|
expected_multiply_code = """\
|
|
export function multiply(a, b) {
|
|
return a * b;
|
|
}"""
|
|
|
|
helper_expectations = {
|
|
"validateInput": expected_validate_input_code,
|
|
"formatNumber": expected_format_number_code,
|
|
"add": expected_add_code,
|
|
"multiply": expected_multiply_code,
|
|
}
|
|
|
|
for helper_name, expected_code in helper_expectations.items():
|
|
assert helper_dict[helper_name].source_code.strip() == expected_code.strip(), (
|
|
f"{helper_name} helper code does not match.\n"
|
|
f"Expected:\n{expected_code}\n\n"
|
|
f"Got:\n{helper_dict[helper_name].source_code}"
|
|
)
|
|
|
|
|
|
class TestCodeExtractorTypeScript:
|
|
"""Tests for TypeScript code extraction."""
|
|
|
|
@pytest.fixture
def ts_project(self, tmp_path):
    """Copy the ``ts`` fixture tree into ``tmp_path`` and return the copy.

    The copy lives at ``tmp_path / "ts_project"``, so tests operate on a
    duplicate of the files under ``FIXTURES_DIR`` rather than the originals.
    """
    fixture_tree = FIXTURES_DIR / "ts"
    destination = tmp_path / "ts_project"
    shutil.copytree(fixture_tree, destination)
    return destination
|
|
|
|
@pytest.fixture
def ts_support(self):
    """Provide a newly constructed ``TypeScriptSupport`` object."""
    support = TypeScriptSupport()
    return support
|
|
|
|
def test_typescript_support_properties(self, ts_support):
    """Check the ``language`` and ``file_extensions`` reported by ``TypeScriptSupport``."""
    assert ts_support.language == Language.TYPESCRIPT

    # Subset check, not equality: ".ts" and ".tsx" must be present, and any
    # additional extensions the support object reports are tolerated.
    expected_extensions = {".ts", ".tsx"}
    actual_extensions = set(ts_support.file_extensions)
    has_required_extensions = expected_extensions <= actual_extensions
    assert has_required_extensions, (
        f"Expected extensions {expected_extensions} to be subset of {actual_extensions}"
    )
|
|
|
|
def test_discover_ts_methods(self, ts_support, ts_project):
    """Discovery on ``calculator.ts`` must yield exactly the expected function names."""
    source_path = ts_project / "calculator.ts"
    source_text = source_path.read_text(encoding="utf-8")
    discovered = ts_support.discover_functions(source_text, source_path)

    # The expected set includes getHistory in addition to the three names
    # the JavaScript discovery tests look for.
    expected_methods = {"calculateCompoundInterest", "permutation", "getHistory", "quickAdd"}

    method_names = set()
    for entry in discovered:
        method_names.add(entry.function_name)

    assert method_names == expected_methods, f"Expected methods {expected_methods}, got {method_names}"
|
|
|
|
def test_ts_permutation_extraction(self, ts_support, ts_project):
|
|
"""Test permutation method extraction in TypeScript."""
|
|
calculator_file = ts_project / "calculator.ts"
|
|
functions = ts_support.discover_functions(calculator_file.read_text(encoding="utf-8"), calculator_file)
|
|
|
|
permutation_func = next(f for f in functions if f.function_name == "permutation")
|
|
|
|
context = ts_support.extract_code_context(
|
|
function=permutation_func, project_root=ts_project, module_root=ts_project
|
|
)
|
|
|
|
expected_code = """\
|
|
class Calculator {
|
|
private precision: number;
|
|
private history: HistoryEntry[];
|
|
|
|
constructor(precision: number = 2) {
|
|
this.precision = precision;
|
|
this.history = [];
|
|
}
|
|
|
|
/**
|
|
* Calculate permutation using factorial helper.
|
|
* @param n - Total items
|
|
* @param r - Items to choose
|
|
* @returns Permutation result
|
|
*/
|
|
permutation(n: number, r: number): number {
|
|
if (n < r) return 0;
|
|
// Inefficient: calculates factorial(n) fully even when not needed
|
|
return factorial(n) / factorial(n - r);
|
|
}
|
|
}"""
|
|
|
|
assert context.target_code.strip() == expected_code.strip(), (
|
|
f"Extracted code does not match expected.\nExpected:\n{expected_code}\n\nGot:\n{context.target_code}"
|
|
)
|
|
|
|
# TypeScript permutation uses factorial helper
|
|
helper_dict = {h.name: h for h in context.helper_functions}
|
|
assert set(helper_dict.keys()) == {"factorial"}, f"Expected 'factorial' helper, got: {list(helper_dict.keys())}"
|
|
|
|
expected_factorial_code = """\
|
|
export function factorial(n: number): number {
|
|
// Intentionally inefficient recursive implementation
|
|
if (n <= 1) return 1;
|
|
return n * factorial(n - 1);
|
|
}"""
|
|
|
|
assert helper_dict["factorial"].source_code.strip() == expected_factorial_code.strip(), (
|
|
f"factorial helper code does not match.\nExpected:\n{expected_factorial_code}\n\nGot:\n{helper_dict['factorial'].source_code}"
|
|
)
|
|
|
|
def test_ts_compound_interest_extraction(self, ts_support, ts_project):
|
|
"""Test calculateCompoundInterest extraction in TypeScript."""
|
|
calculator_file = ts_project / "calculator.ts"
|
|
functions = ts_support.discover_functions(calculator_file.read_text(encoding="utf-8"), calculator_file)
|
|
|
|
compound_func = next(f for f in functions if f.function_name == "calculateCompoundInterest")
|
|
|
|
context = ts_support.extract_code_context(
|
|
function=compound_func, project_root=ts_project, module_root=ts_project
|
|
)
|
|
|
|
expected_code = """\
|
|
class Calculator {
|
|
private precision: number;
|
|
private history: HistoryEntry[];
|
|
|
|
constructor(precision: number = 2) {
|
|
this.precision = precision;
|
|
this.history = [];
|
|
}
|
|
|
|
/**
|
|
* Calculate compound interest with multiple helper dependencies.
|
|
* @param principal - Initial amount
|
|
* @param rate - Interest rate (as decimal)
|
|
* @param time - Time in years
|
|
* @param n - Compounding frequency per year
|
|
* @returns Compound interest result
|
|
*/
|
|
calculateCompoundInterest(principal: number, rate: number, time: number, n: number): number {
|
|
validateInput(principal, 'principal');
|
|
validateInput(rate, 'rate');
|
|
|
|
// Inefficient: recalculates power multiple times
|
|
let result = principal;
|
|
for (let i = 0; i < n * time; i++) {
|
|
result = multiply(result, add(1, rate / n));
|
|
}
|
|
|
|
const interest = result - principal;
|
|
this.history.push({ type: 'compound', result: interest });
|
|
return formatNumber(interest, this.precision);
|
|
}
|
|
}"""
|
|
|
|
assert context.target_code.strip() == expected_code.strip(), (
|
|
f"Extracted code does not match expected.\nExpected:\n{expected_code}\n\nGot:\n{context.target_code}"
|
|
)
|
|
|
|
# TypeScript compound interest uses 4 helpers
|
|
helper_dict = {h.name: h for h in context.helper_functions}
|
|
expected_helper_names = {"validateInput", "formatNumber", "add", "multiply"}
|
|
assert set(helper_dict.keys()) == expected_helper_names, (
|
|
f"Expected helpers {expected_helper_names}, got: {set(helper_dict.keys())}"
|
|
)
|
|
|
|
expected_validate_input_code = """\
|
|
export function validateInput(value: unknown, name: string): asserts value is number {
|
|
if (typeof value !== 'number' || isNaN(value)) {
|
|
throw new Error(`Invalid ${name}: must be a number`);
|
|
}
|
|
}"""
|
|
|
|
expected_format_number_code = """\
|
|
export function formatNumber(num: number, decimals: number): number {
|
|
return Number(num.toFixed(decimals));
|
|
}"""
|
|
|
|
expected_add_code = """\
|
|
export function add(a: number, b: number): number {
|
|
return a + b;
|
|
}"""
|
|
|
|
expected_multiply_code = """\
|
|
export function multiply(a: number, b: number): number {
|
|
return a * b;
|
|
}"""
|
|
|
|
helper_expectations = {
|
|
"validateInput": expected_validate_input_code,
|
|
"formatNumber": expected_format_number_code,
|
|
"add": expected_add_code,
|
|
"multiply": expected_multiply_code,
|
|
}
|
|
|
|
for helper_name, expected_code in helper_expectations.items():
|
|
assert helper_dict[helper_name].source_code.strip() == expected_code.strip(), (
|
|
f"{helper_name} helper code does not match.\n"
|
|
f"Expected:\n{expected_code}\n\n"
|
|
f"Got:\n{helper_dict[helper_name].source_code}"
|
|
)
|
|
|
|
|
|
class TestCodeExtractorEdgeCases:
|
|
"""Tests for edge cases."""
|
|
|
|
@pytest.fixture
def js_support(self):
    """Provide a newly constructed ``JavaScriptSupport`` object."""
    support = JavaScriptSupport()
    return support
|
|
|
|
def test_standalone_function(self, js_support, tmp_path):
|
|
"""Test standalone function with no helpers."""
|
|
source = """\
|
|
export function standalone(x) {
|
|
return x * 2;
|
|
}
|
|
|
|
module.exports = { standalone };
|
|
"""
|
|
test_file = tmp_path / "standalone.js"
|
|
test_file.write_text(source)
|
|
|
|
functions = js_support.discover_functions(test_file.read_text(encoding="utf-8"), test_file)
|
|
func = next(f for f in functions if f.function_name == "standalone")
|
|
|
|
context = js_support.extract_code_context(function=func, project_root=tmp_path, module_root=tmp_path)
|
|
|
|
# STRICT: Exact code comparison
|
|
expected_code = """\
|
|
export function standalone(x) {
|
|
return x * 2;
|
|
}"""
|
|
assert context.target_code.strip() == expected_code.strip(), (
|
|
f"Extracted code does not match.\nExpected:\n{expected_code}\n\nGot:\n{context.target_code}"
|
|
)
|
|
|
|
# STRICT: Exactly zero helpers
|
|
assert len(context.helper_functions) == 0, (
|
|
f"Expected 0 helpers, got {len(context.helper_functions)}: {[h.name for h in context.helper_functions]}"
|
|
)
|
|
|
|
def test_external_package_excluded(self, js_support, tmp_path):
|
|
"""Test external packages are not resolved as helpers."""
|
|
source = """\
|
|
const _ = require('lodash');
|
|
|
|
export function processArray(arr) {
|
|
return _.map(arr, x => x * 2);
|
|
}
|
|
|
|
module.exports = { processArray };
|
|
"""
|
|
test_file = tmp_path / "processor.js"
|
|
test_file.write_text(source)
|
|
|
|
functions = js_support.discover_functions(test_file.read_text(encoding="utf-8"), test_file)
|
|
func = next(f for f in functions if f.function_name == "processArray")
|
|
|
|
context = js_support.extract_code_context(function=func, project_root=tmp_path, module_root=tmp_path)
|
|
|
|
expected_code = """\
|
|
export function processArray(arr) {
|
|
return _.map(arr, x => x * 2);
|
|
}"""
|
|
|
|
assert context.target_code.strip() == expected_code.strip(), (
|
|
f"Extracted code does not match.\nExpected:\n{expected_code}\n\nGot:\n{context.target_code}"
|
|
)
|
|
|
|
expected_imports = ["const _ = require('lodash');"]
|
|
assert context.imports == expected_imports, (
|
|
f"Imports do not match expected.\nExpected:\n{expected_imports}\n\nGot:\n{context.imports}"
|
|
)
|
|
|
|
helper_names = {h.name for h in context.helper_functions}
|
|
assert helper_names == set(), f"Expected no helpers for external package usage, got: {helper_names}"
|
|
|
|
def test_recursive_function(self, js_support, tmp_path):
|
|
"""Test recursive function doesn't list itself as helper."""
|
|
source = """\
|
|
export function fibonacci(n) {
|
|
if (n <= 1) return n;
|
|
return fibonacci(n - 1) + fibonacci(n - 2);
|
|
}
|
|
|
|
module.exports = { fibonacci };
|
|
"""
|
|
test_file = tmp_path / "recursive.js"
|
|
test_file.write_text(source)
|
|
|
|
functions = js_support.discover_functions(test_file.read_text(encoding="utf-8"), test_file)
|
|
func = next(f for f in functions if f.function_name == "fibonacci")
|
|
|
|
context = js_support.extract_code_context(function=func, project_root=tmp_path, module_root=tmp_path)
|
|
|
|
# STRICT: Exact code comparison
|
|
expected_code = """\
|
|
export function fibonacci(n) {
|
|
if (n <= 1) return n;
|
|
return fibonacci(n - 1) + fibonacci(n - 2);
|
|
}"""
|
|
assert context.target_code.strip() == expected_code.strip(), (
|
|
f"Extracted code does not match.\nExpected:\n{expected_code}\n\nGot:\n{context.target_code}"
|
|
)
|
|
|
|
# STRICT: Function should NOT be its own helper
|
|
helper_names = {h.name for h in context.helper_functions}
|
|
assert "fibonacci" not in helper_names, f"Recursive function listed itself as helper. Helpers: {helper_names}"
|
|
|
|
def test_arrow_function_helper(self, js_support, tmp_path):
|
|
"""Test arrow function helper extraction."""
|
|
source = """\
|
|
const helper = (x) => x * 2;
|
|
|
|
export const processValue = (value) => {
|
|
return helper(value) + 1;
|
|
};
|
|
|
|
module.exports = { processValue };
|
|
"""
|
|
test_file = tmp_path / "arrow.js"
|
|
test_file.write_text(source)
|
|
|
|
functions = js_support.discover_functions(test_file.read_text(encoding="utf-8"), test_file)
|
|
func = next(f for f in functions if f.function_name == "processValue")
|
|
|
|
context = js_support.extract_code_context(function=func, project_root=tmp_path, module_root=tmp_path)
|
|
|
|
expected_code = """\
|
|
export const processValue = (value) => {
|
|
return helper(value) + 1;
|
|
};"""
|
|
|
|
assert context.target_code.strip() == expected_code.strip(), (
|
|
f"Extracted code does not match.\nExpected:\n{expected_code}\n\nGot:\n{context.target_code}"
|
|
)
|
|
|
|
assert context.imports == [], f"Expected no imports, got: {context.imports}"
|
|
|
|
helper_dict = {h.name: h for h in context.helper_functions}
|
|
assert set(helper_dict.keys()) == {"helper"}, f"Expected only 'helper', got: {list(helper_dict.keys())}"
|
|
|
|
expected_helper_code = "const helper = (x) => x * 2;"
|
|
actual_helper_code = helper_dict["helper"].source_code.strip()
|
|
assert actual_helper_code == expected_helper_code, (
|
|
f"Helper code does not match.\nExpected:\n{expected_helper_code}\n\nGot:\n{actual_helper_code}"
|
|
)
|
|
|
|
|
|
class TestClassContextExtraction:
|
|
"""Tests for class constructor and field extraction in code context."""
|
|
|
|
@pytest.fixture
def js_support(self):
    """Provide a newly constructed ``JavaScriptSupport`` object."""
    support = JavaScriptSupport()
    return support
|
|
|
|
@pytest.fixture
def ts_support(self):
    """Provide a newly constructed ``TypeScriptSupport`` object."""
    support = TypeScriptSupport()
    return support
|
|
|
|
def test_method_extraction_includes_constructor(self, js_support, tmp_path):
|
|
"""Test that extracting a class method includes the constructor."""
|
|
source = """\
|
|
export class Counter {
|
|
constructor(initial = 0) {
|
|
this.count = initial;
|
|
}
|
|
|
|
increment() {
|
|
this.count++;
|
|
return this.count;
|
|
}
|
|
}
|
|
|
|
module.exports = { Counter };
|
|
"""
|
|
test_file = tmp_path / "counter.js"
|
|
test_file.write_text(source)
|
|
|
|
functions = js_support.discover_functions(test_file.read_text(encoding="utf-8"), test_file)
|
|
increment_func = next(f for f in functions if f.function_name == "increment")
|
|
|
|
context = js_support.extract_code_context(function=increment_func, project_root=tmp_path, module_root=tmp_path)
|
|
|
|
expected_code = """\
|
|
class Counter {
|
|
constructor(initial = 0) {
|
|
this.count = initial;
|
|
}
|
|
|
|
increment() {
|
|
this.count++;
|
|
return this.count;
|
|
}
|
|
}"""
|
|
|
|
assert context.target_code.strip() == expected_code.strip(), (
|
|
f"Extracted code does not match expected.\nExpected:\n{expected_code}\n\nGot:\n{context.target_code}"
|
|
)
|
|
|
|
def test_method_extraction_class_without_constructor(self, js_support, tmp_path):
|
|
"""Test extracting a method from a class that has no constructor."""
|
|
source = """\
|
|
export class MathUtils {
|
|
add(a, b) {
|
|
return a + b;
|
|
}
|
|
|
|
multiply(a, b) {
|
|
return a * b;
|
|
}
|
|
}
|
|
|
|
module.exports = { MathUtils };
|
|
"""
|
|
test_file = tmp_path / "math_utils.js"
|
|
test_file.write_text(source)
|
|
|
|
functions = js_support.discover_functions(test_file.read_text(encoding="utf-8"), test_file)
|
|
add_func = next(f for f in functions if f.function_name == "add")
|
|
|
|
context = js_support.extract_code_context(function=add_func, project_root=tmp_path, module_root=tmp_path)
|
|
|
|
expected_code = """\
|
|
class MathUtils {
|
|
add(a, b) {
|
|
return a + b;
|
|
}
|
|
}"""
|
|
|
|
assert context.target_code.strip() == expected_code.strip(), (
|
|
f"Extracted code does not match expected.\nExpected:\n{expected_code}\n\nGot:\n{context.target_code}"
|
|
)
|
|
|
|
def test_typescript_method_extraction_includes_fields(self, ts_support, tmp_path):
|
|
"""Test that TypeScript method extraction includes class fields."""
|
|
source = """\
|
|
export class User {
|
|
private name: string;
|
|
public age: number;
|
|
|
|
constructor(name: string, age: number) {
|
|
this.name = name;
|
|
this.age = age;
|
|
}
|
|
|
|
getName(): string {
|
|
return this.name;
|
|
}
|
|
}
|
|
"""
|
|
test_file = tmp_path / "user.ts"
|
|
test_file.write_text(source)
|
|
|
|
functions = ts_support.discover_functions(test_file.read_text(encoding="utf-8"), test_file)
|
|
get_name_func = next(f for f in functions if f.function_name == "getName")
|
|
|
|
context = ts_support.extract_code_context(function=get_name_func, project_root=tmp_path, module_root=tmp_path)
|
|
|
|
expected_code = """\
|
|
class User {
|
|
private name: string;
|
|
public age: number;
|
|
|
|
constructor(name: string, age: number) {
|
|
this.name = name;
|
|
this.age = age;
|
|
}
|
|
|
|
getName(): string {
|
|
return this.name;
|
|
}
|
|
}"""
|
|
|
|
assert context.target_code.strip() == expected_code.strip(), (
|
|
f"Extracted code does not match expected.\nExpected:\n{expected_code}\n\nGot:\n{context.target_code}"
|
|
)
|
|
|
|
def test_typescript_fields_only_no_constructor(self, ts_support, tmp_path):
|
|
"""Test TypeScript class with fields but no constructor."""
|
|
source = """\
|
|
export class Config {
|
|
readonly apiUrl: string = "https://api.example.com";
|
|
timeout: number = 5000;
|
|
|
|
getUrl(): string {
|
|
return this.apiUrl;
|
|
}
|
|
}
|
|
"""
|
|
test_file = tmp_path / "config.ts"
|
|
test_file.write_text(source)
|
|
|
|
functions = ts_support.discover_functions(test_file.read_text(encoding="utf-8"), test_file)
|
|
get_url_func = next(f for f in functions if f.function_name == "getUrl")
|
|
|
|
context = ts_support.extract_code_context(function=get_url_func, project_root=tmp_path, module_root=tmp_path)
|
|
|
|
expected_code = """\
|
|
class Config {
|
|
readonly apiUrl: string = "https://api.example.com";
|
|
timeout: number = 5000;
|
|
|
|
getUrl(): string {
|
|
return this.apiUrl;
|
|
}
|
|
}"""
|
|
|
|
assert context.target_code.strip() == expected_code.strip(), (
|
|
f"Extracted code does not match expected.\nExpected:\n{expected_code}\n\nGot:\n{context.target_code}"
|
|
)
|
|
|
|
def test_constructor_with_jsdoc(self, js_support, tmp_path):
|
|
"""Test that constructor with JSDoc is fully extracted."""
|
|
source = """\
|
|
export class Logger {
|
|
/**
|
|
* Create a new Logger instance.
|
|
* @param {string} prefix - The prefix to use for log messages.
|
|
*/
|
|
constructor(prefix) {
|
|
this.prefix = prefix;
|
|
}
|
|
|
|
getPrefix() {
|
|
return this.prefix;
|
|
}
|
|
}
|
|
|
|
module.exports = { Logger };
|
|
"""
|
|
test_file = tmp_path / "logger.js"
|
|
test_file.write_text(source)
|
|
|
|
functions = js_support.discover_functions(test_file.read_text(encoding="utf-8"), test_file)
|
|
get_prefix_func = next(f for f in functions if f.function_name == "getPrefix")
|
|
|
|
context = js_support.extract_code_context(function=get_prefix_func, project_root=tmp_path, module_root=tmp_path)
|
|
|
|
expected_code = """\
|
|
class Logger {
|
|
/**
|
|
* Create a new Logger instance.
|
|
* @param {string} prefix - The prefix to use for log messages.
|
|
*/
|
|
constructor(prefix) {
|
|
this.prefix = prefix;
|
|
}
|
|
|
|
getPrefix() {
|
|
return this.prefix;
|
|
}
|
|
}"""
|
|
|
|
assert context.target_code.strip() == expected_code.strip(), (
|
|
f"Extracted code does not match expected.\nExpected:\n{expected_code}\n\nGot:\n{context.target_code}"
|
|
)
|
|
|
|
def test_static_method_includes_constructor(self, js_support, tmp_path):
|
|
"""Test that static method extraction also includes constructor context."""
|
|
source = """\
|
|
export class Factory {
|
|
constructor(config) {
|
|
this.config = config;
|
|
}
|
|
|
|
static create(type) {
|
|
return new Factory({ type: type });
|
|
}
|
|
}
|
|
|
|
module.exports = { Factory };
|
|
"""
|
|
test_file = tmp_path / "factory.js"
|
|
test_file.write_text(source)
|
|
|
|
functions = js_support.discover_functions(test_file.read_text(encoding="utf-8"), test_file)
|
|
create_func = next(f for f in functions if f.function_name == "create")
|
|
|
|
context = js_support.extract_code_context(function=create_func, project_root=tmp_path, module_root=tmp_path)
|
|
|
|
expected_code = """\
|
|
class Factory {
|
|
constructor(config) {
|
|
this.config = config;
|
|
}
|
|
|
|
static create(type) {
|
|
return new Factory({ type: type });
|
|
}
|
|
}"""
|
|
|
|
assert context.target_code.strip() == expected_code.strip(), (
|
|
f"Extracted code does not match expected.\nExpected:\n{expected_code}\n\nGot:\n{context.target_code}"
|
|
)
|
|
|
|
|
|
class TestCodeExtractorIntegration:
    """Integration tests with FunctionOptimizer."""

    @pytest.fixture
    def cjs_project(self, tmp_path):
        """Create a temporary CJS project from fixtures.

        Copies ``FIXTURES_DIR / "js_cjs"`` into ``tmp_path`` so the test works
        on a private copy, and returns the copied project directory.
        """
        project_dir = tmp_path / "cjs_project"
        shutil.copytree(FIXTURES_DIR / "js_cjs", project_dir)
        return project_dir

    def test_function_optimizer_workflow(self, cjs_project, monkeypatch):
        """Test full FunctionOptimizer workflow.

        Discovers ``permutation`` in ``calculator.js``, wraps it in a
        ``FunctionToOptimize``, asks ``JavaScriptFunctionOptimizer`` for the
        optimization context and checks that read-writable code is present and
        that ``factorial`` is among the helper functions.
        """
        from codeflash.languages import current as lang_current
        from codeflash.languages.base import Language

        # Force set language to JavaScript for proper context extraction routing.
        # monkeypatch undoes the change at teardown, so this module-level state
        # does not leak into tests that run afterwards. raising=False keeps the
        # behaviour of a plain assignment if the attribute does not exist yet.
        monkeypatch.setattr(lang_current, "_current_language", Language.JAVASCRIPT, raising=False)

        js_support = get_language_support("javascript")
        calculator_file = cjs_project / "calculator.js"

        functions = js_support.discover_functions(calculator_file.read_text(encoding="utf-8"), calculator_file)
        target = next(f for f in functions if f.function_name == "permutation")

        # Rebuild the parent chain as FunctionParent objects for FunctionToOptimize.
        parents = [FunctionParent(name=p.name, type=p.type) for p in target.parents]

        func = FunctionToOptimize(
            function_name=target.function_name,
            file_path=target.file_path,
            parents=parents,
            starting_line=target.starting_line,
            ending_line=target.ending_line,
            starting_col=target.starting_col,
            ending_col=target.ending_col,
            is_async=target.is_async,
            is_method=target.is_method,
            language=target.language,
        )

        test_config = TestConfig(
            tests_root=cjs_project / "tests",
            tests_project_rootdir=cjs_project,
            project_root_path=cjs_project,
            pytest_cmd="jest",
        )

        # The AI service client is mocked: only context extraction is exercised.
        func_optimizer = JavaScriptFunctionOptimizer(
            function_to_optimize=func, test_cfg=test_config, aiservice_client=MagicMock()
        )
        result = func_optimizer.get_code_optimization_context()

        context = result.unwrap()

        assert context.read_writable_code is not None, "read_writable_code should not be None"

        # FunctionSource uses only_function_name, not name
        helper_names = {h.only_function_name for h in context.helper_functions}
        assert "factorial" in helper_names, f"factorial helper not found. Found: {helper_names}"
|
|
|
|
|
|
class TestTypeDefinitionExtraction:
|
|
"""Tests for TypeScript type definition extraction in read-only context."""
|
|
|
|
@pytest.fixture
def ts_support(self):
    """Create TypeScriptSupport instance.

    Returns a new ``TypeScriptSupport()`` object for each test that requests
    this fixture.
    """
    return TypeScriptSupport()
|
|
|
|
@pytest.fixture
def ts_types_project(self, tmp_path):
    """Create a temporary TypeScript project with type definitions.

    Creates ``tmp_path / "ts_types_project"`` containing a ``types.ts`` file
    that exports two interfaces, an enum and a generic type alias, each
    preceded by a JSDoc comment. Returns the project directory.
    """
    project_dir = tmp_path / "ts_types_project"
    project_dir.mkdir()

    # Create types.ts with type definitions
    types_source = """\
/**
 * Configuration options for calculations.
 */
export interface CalculationConfig {
    precision: number;
    enableCaching: boolean;
}

/**
 * Point in 2D space.
 */
export interface Point {
    x: number;
    y: number;
}

/**
 * Rounding mode enum.
 */
export enum RoundingMode {
    FLOOR = 'floor',
    CEIL = 'ceil',
    ROUND = 'round',
}

/**
 * Result type alias.
 */
export type Result<T> = {
    value: T;
    success: boolean;
};
"""
    types_file = project_dir / "types.ts"
    # Explicit encoding so the file does not depend on the platform default.
    types_file.write_text(types_source, encoding="utf-8")
    return project_dir
|
|
|
|
def test_extract_same_file_interface_from_parameter(self, ts_support, tmp_path):
    """Test extracting interface type definition when used in function parameter.

    ``distance`` takes two ``Point`` parameters; the ``Point`` interface
    declared in the same file must be the entire read-only context.
    """
    source = """\
interface Point {
    x: number;
    y: number;
}

export function distance(p1: Point, p2: Point): number {
    const dx = p2.x - p1.x;
    const dy = p2.y - p1.y;
    return Math.sqrt(dx * dx + dy * dy);
}
"""
    test_file = tmp_path / "geometry.ts"
    # Write with the same explicit encoding that is used for reading below.
    test_file.write_text(source, encoding="utf-8")

    functions = ts_support.discover_functions(test_file.read_text(encoding="utf-8"), test_file)
    distance_func = next(f for f in functions if f.function_name == "distance")

    context = ts_support.extract_code_context(function=distance_func, project_root=tmp_path, module_root=tmp_path)

    # Type definition should be in read-only context with exact match
    expected_read_only = """\
interface Point {
    x: number;
    y: number;
}"""

    assert context.read_only_context is not None, "read_only_context should not be None"
    assert context.read_only_context.strip() == expected_read_only.strip(), (
        f"Read-only context does not match expected.\n"
        f"Expected:\n{expected_read_only}\n\n"
        f"Got:\n{context.read_only_context}"
    )
|
|
|
|
def test_extract_same_file_enum_from_parameter(self, ts_support, tmp_path):
    """Test extracting enum type definition when used in function parameter.

    ``processStatus`` takes a ``Status`` parameter; the ``Status`` enum
    declared in the same file must be the entire read-only context.
    """
    source = """\
enum Status {
    PENDING = 'pending',
    SUCCESS = 'success',
    FAILURE = 'failure',
}

export function processStatus(status: Status): string {
    switch (status) {
        case Status.PENDING:
            return 'Processing...';
        case Status.SUCCESS:
            return 'Done!';
        case Status.FAILURE:
            return 'Failed!';
    }
}
"""
    test_file = tmp_path / "status.ts"
    # Write with the same explicit encoding that is used for reading below.
    test_file.write_text(source, encoding="utf-8")

    functions = ts_support.discover_functions(test_file.read_text(encoding="utf-8"), test_file)
    process_func = next(f for f in functions if f.function_name == "processStatus")

    context = ts_support.extract_code_context(function=process_func, project_root=tmp_path, module_root=tmp_path)

    # Enum should be in read-only context with exact match
    expected_read_only = """\
enum Status {
    PENDING = 'pending',
    SUCCESS = 'success',
    FAILURE = 'failure',
}"""

    assert context.read_only_context is not None, "read_only_context should not be None"
    assert context.read_only_context.strip() == expected_read_only.strip(), (
        f"Read-only context does not match expected.\n"
        f"Expected:\n{expected_read_only}\n\n"
        f"Got:\n{context.read_only_context}"
    )
|
|
|
|
def test_extract_same_file_type_alias_from_return_type(self, ts_support, tmp_path):
    """Test extracting type alias when used in function return type.

    ``compute`` returns ``Result<number>``; the generic ``Result<T>`` alias
    declared in the same file must be the entire read-only context.
    """
    source = """\
type Result<T> = {
    value: T;
    success: boolean;
};

export function compute(x: number): Result<number> {
    return { value: x * 2, success: true };
}
"""
    test_file = tmp_path / "compute.ts"
    # Write with the same explicit encoding that is used for reading below.
    test_file.write_text(source, encoding="utf-8")

    functions = ts_support.discover_functions(test_file.read_text(encoding="utf-8"), test_file)
    compute_func = next(f for f in functions if f.function_name == "compute")

    context = ts_support.extract_code_context(function=compute_func, project_root=tmp_path, module_root=tmp_path)

    # Type alias should be in read-only context with exact match
    expected_read_only = """\
type Result<T> = {
    value: T;
    success: boolean;
};"""

    assert context.read_only_context is not None, "read_only_context should not be None"
    assert context.read_only_context.strip() == expected_read_only.strip(), (
        f"Read-only context does not match expected.\n"
        f"Expected:\n{expected_read_only}\n\n"
        f"Got:\n{context.read_only_context}"
    )
|
|
|
|
def test_extract_class_field_types(self, ts_support, tmp_path):
    """Test extracting type definitions used in class fields.

    ``Service`` has a ``Config``-typed field and constructor parameter; when
    extracting ``getTimeout`` the ``Config`` interface must be the entire
    read-only context.
    """
    source = """\
interface Config {
    timeout: number;
    retries: number;
}

export class Service {
    private config: Config;

    constructor(config: Config) {
        this.config = config;
    }

    getTimeout(): number {
        return this.config.timeout;
    }
}
"""
    test_file = tmp_path / "service.ts"
    # Write with the same explicit encoding that is used for reading below.
    test_file.write_text(source, encoding="utf-8")

    functions = ts_support.discover_functions(test_file.read_text(encoding="utf-8"), test_file)
    get_timeout_func = next(f for f in functions if f.function_name == "getTimeout")

    context = ts_support.extract_code_context(
        function=get_timeout_func, project_root=tmp_path, module_root=tmp_path
    )

    # Config interface should be in read-only context with exact match
    expected_read_only = """\
interface Config {
    timeout: number;
    retries: number;
}"""

    assert context.read_only_context is not None, "read_only_context should not be None"
    assert context.read_only_context.strip() == expected_read_only.strip(), (
        f"Read-only context does not match expected.\n"
        f"Expected:\n{expected_read_only}\n\n"
        f"Got:\n{context.read_only_context}"
    )
|
|
|
|
def test_primitive_types_not_included(self, ts_support, tmp_path):
    """Test that primitive types (number, string, etc.) are not extracted.

    ``add`` only uses ``number`` in its signature, so the read-only context
    must be exactly the empty string.
    """
    source = """\
export function add(a: number, b: number): number {
    return a + b;
}
"""
    test_file = tmp_path / "add.ts"
    # Write with the same explicit encoding that is used for reading below.
    test_file.write_text(source, encoding="utf-8")

    functions = ts_support.discover_functions(test_file.read_text(encoding="utf-8"), test_file)
    add_func = next(f for f in functions if f.function_name == "add")

    context = ts_support.extract_code_context(function=add_func, project_root=tmp_path, module_root=tmp_path)

    # No type definitions should be extracted for primitives - exact empty match
    assert context.read_only_context == "", (
        f"Should not extract type definitions for primitive types.\n"
        f"Expected empty string, got:\n{context.read_only_context}"
    )
|
|
|
|
def test_extract_multiple_types(self, ts_support, tmp_path):
    """Test extracting multiple type definitions from same file.

    ``createRect`` references both ``Point`` and ``Size``; both interfaces
    must appear in the read-only context, separated by one blank line.
    """
    source = """\
interface Point {
    x: number;
    y: number;
}

interface Size {
    width: number;
    height: number;
}

export function createRect(origin: Point, size: Size): { origin: Point; size: Size } {
    return { origin, size };
}
"""
    test_file = tmp_path / "rect.ts"
    # Write with the same explicit encoding that is used for reading below.
    test_file.write_text(source, encoding="utf-8")

    functions = ts_support.discover_functions(test_file.read_text(encoding="utf-8"), test_file)
    create_rect_func = next(f for f in functions if f.function_name == "createRect")

    context = ts_support.extract_code_context(
        function=create_rect_func, project_root=tmp_path, module_root=tmp_path
    )

    # Both Point and Size should be in read-only context with exact match
    expected_read_only = """\
interface Point {
    x: number;
    y: number;
}

interface Size {
    width: number;
    height: number;
}"""

    assert context.read_only_context is not None, "read_only_context should not be None"
    assert context.read_only_context.strip() == expected_read_only.strip(), (
        f"Read-only context does not match expected.\n"
        f"Expected:\n{expected_read_only}\n\n"
        f"Got:\n{context.read_only_context}"
    )
|
|
|
|
def test_extract_imported_type_definition(self, ts_support, ts_types_project):
    """Test extracting type definitions from imported files.

    ``geometry.ts`` imports ``Point`` and ``CalculationConfig`` from
    ``./types``; both definitions must appear in the read-only context under
    a ``// From types.ts`` header.
    """
    # Create a file that imports types from types.ts
    geometry_source = """\
import { Point, CalculationConfig } from './types';

export function calculateDistance(p1: Point, p2: Point, config: CalculationConfig): number {
    const dx = p2.x - p1.x;
    const dy = p2.y - p1.y;
    const distance = Math.sqrt(dx * dx + dy * dy);

    if (config.precision > 0) {
        const factor = Math.pow(10, config.precision);
        return Math.round(distance * factor) / factor;
    }
    return distance;
}
"""
    geometry_file = ts_types_project / "geometry.ts"
    # Write with the same explicit encoding that is used for reading below.
    geometry_file.write_text(geometry_source, encoding="utf-8")

    functions = ts_support.discover_functions(geometry_file.read_text(encoding="utf-8"), geometry_file)
    calc_distance_func = next(f for f in functions if f.function_name == "calculateDistance")

    context = ts_support.extract_code_context(
        function=calc_distance_func, project_root=ts_types_project, module_root=ts_types_project
    )

    # Imported type definitions should be in read-only context with exact match
    # Types are sorted by file path and line number, with file comments
    # Note: The extraction uses tree-sitter which doesn't capture JSDoc for interface
    # definitions in separate files - this is a known limitation
    expected_read_only = """\
// From types.ts

interface CalculationConfig {
    precision: number;
    enableCaching: boolean;
}

interface Point {
    x: number;
    y: number;
}"""

    assert context.read_only_context is not None, "read_only_context should not be None"
    assert context.read_only_context.strip() == expected_read_only.strip(), (
        f"Read-only context does not match expected.\n"
        f"Expected:\n{expected_read_only}\n\n"
        f"Got:\n{context.read_only_context}"
    )
|
|
|
|
def test_type_with_jsdoc_included(self, ts_support, tmp_path):
    """Test that JSDoc comments are included with type definitions.

    The ``User`` interface is declared in the same file as ``greetUser`` and
    carries a JSDoc block; the read-only context must contain the JSDoc
    followed by the interface.
    """
    source = """\
/**
 * Represents a user in the system.
 * @property id - Unique identifier
 * @property name - Display name
 */
interface User {
    id: string;
    name: string;
}

export function greetUser(user: User): string {
    return `Hello, ${user.name}!`;
}
"""
    test_file = tmp_path / "user.ts"
    # Write with the same explicit encoding that is used for reading below.
    test_file.write_text(source, encoding="utf-8")

    functions = ts_support.discover_functions(test_file.read_text(encoding="utf-8"), test_file)
    greet_func = next(f for f in functions if f.function_name == "greetUser")

    context = ts_support.extract_code_context(function=greet_func, project_root=tmp_path, module_root=tmp_path)

    # JSDoc should be included with the interface - exact match
    expected_read_only = """\
/**
 * Represents a user in the system.
 * @property id - Unique identifier
 * @property name - Display name
 */
interface User {
    id: string;
    name: string;
}"""

    assert context.read_only_context is not None, "read_only_context should not be None"
    assert context.read_only_context.strip() == expected_read_only.strip(), (
        f"Read-only context does not match expected.\n"
        f"Expected:\n{expected_read_only}\n\n"
        f"Got:\n{context.read_only_context}"
    )
|