Shorthand method definitions inside object literals (e.g., `{ encrypt(data) {} }`)
were being discovered as optimizable functions. Because bare method syntax is only
valid inside class bodies or object literals, extracting such a method as standalone
code produced syntactically invalid JavaScript, which made context extraction fail.
The analyzer now skips `method_definition` nodes whose parent is an `object` node,
matching the existing skip for arrow functions inside `pair` nodes.
Trace IDs: 04f07244, 01ac202f, 024a3d42, 04da127a (~274 affected logs)
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1025 lines
33 KiB
Python
1025 lines
33 KiB
Python
"""Extensive tests for the tree-sitter utilities module.

These tests verify that the TreeSitterAnalyzer correctly parses and
analyzes JavaScript/TypeScript code.
"""

from pathlib import Path

import pytest

from codeflash.languages.javascript.treesitter import TreeSitterAnalyzer, TreeSitterLanguage, get_analyzer_for_file


class TestTreeSitterLanguage:
    """Checks on the TreeSitterLanguage enumeration."""

    def test_language_values(self):
        """Each enum member carries its lowercase grammar name as its value."""
        expected = {
            TreeSitterLanguage.JAVASCRIPT: "javascript",
            TreeSitterLanguage.TYPESCRIPT: "typescript",
            TreeSitterLanguage.TSX: "tsx",
        }
        for member, grammar_name in expected.items():
            assert member.value == grammar_name


class TestTreeSitterAnalyzerCreation:
    """Construction-time behaviour of TreeSitterAnalyzer."""

    def test_create_javascript_analyzer(self):
        """An analyzer built for JavaScript reports JavaScript as its language."""
        js = TreeSitterAnalyzer(TreeSitterLanguage.JAVASCRIPT)
        assert js.language == TreeSitterLanguage.JAVASCRIPT

    def test_create_typescript_analyzer(self):
        """An analyzer built for TypeScript reports TypeScript as its language."""
        ts = TreeSitterAnalyzer(TreeSitterLanguage.TYPESCRIPT)
        assert ts.language == TreeSitterLanguage.TYPESCRIPT

    def test_create_with_string(self):
        """A plain language-name string is accepted and mapped to the enum member."""
        from_name = TreeSitterAnalyzer("javascript")
        assert from_name.language == TreeSitterLanguage.JAVASCRIPT

    def test_lazy_parser_creation(self):
        """The underlying parser is only built on first access to ``parser``."""
        lazy = TreeSitterAnalyzer(TreeSitterLanguage.JAVASCRIPT)
        assert lazy._parser is None

        # Reading the property is what triggers construction.
        _ = lazy.parser

        assert lazy._parser is not None


class TestGetAnalyzerForFile:
    """Extension-based analyzer selection in get_analyzer_for_file."""

    @staticmethod
    def _language_for(filename):
        """Return the language of the analyzer chosen for ``/test/<filename>``."""
        return get_analyzer_for_file(Path("/test") / filename).language

    def test_js_file(self):
        """``.js`` files are analyzed as JavaScript."""
        assert self._language_for("file.js") == TreeSitterLanguage.JAVASCRIPT

    def test_jsx_file(self):
        """``.jsx`` files are analyzed as JavaScript."""
        assert self._language_for("file.jsx") == TreeSitterLanguage.JAVASCRIPT

    def test_ts_file(self):
        """``.ts`` files are analyzed as TypeScript."""
        assert self._language_for("file.ts") == TreeSitterLanguage.TYPESCRIPT

    def test_tsx_file(self):
        """``.tsx`` files use the dedicated TSX grammar."""
        assert self._language_for("file.tsx") == TreeSitterLanguage.TSX

    def test_mjs_file(self):
        """ES-module ``.mjs`` files are analyzed as JavaScript."""
        assert self._language_for("file.mjs") == TreeSitterLanguage.JAVASCRIPT

    def test_cjs_file(self):
        """CommonJS ``.cjs`` files are analyzed as JavaScript."""
        assert self._language_for("file.cjs") == TreeSitterLanguage.JAVASCRIPT


class TestParsing:
    """Low-level parsing through TreeSitterAnalyzer.parse."""

    @pytest.fixture
    def js_analyzer(self):
        """Provide a fresh JavaScript analyzer."""
        return TreeSitterAnalyzer(TreeSitterLanguage.JAVASCRIPT)

    def test_parse_simple_code(self, js_analyzer):
        """Well-formed source parses into an error-free tree."""
        root = js_analyzer.parse("const x = 1;").root_node
        assert root is not None
        assert not root.has_error

    def test_parse_bytes(self, js_analyzer):
        """Source may be supplied as ``bytes`` instead of ``str``."""
        root = js_analyzer.parse(b"const x = 1;").root_node
        assert root is not None

    def test_parse_invalid_code(self, js_analyzer):
        """Malformed source yields a tree whose root is flagged with errors."""
        broken = "function foo( {"
        assert js_analyzer.parse(broken).root_node.has_error

    def test_get_node_text(self, js_analyzer):
        """The text of the root node round-trips to the original source."""
        source = "const x = 1;"
        raw = source.encode("utf8")
        root = js_analyzer.parse(raw).root_node
        assert js_analyzer.get_node_text(root, raw) == source


class TestFindFunctions:
    """Function discovery through TreeSitterAnalyzer.find_functions."""

    @pytest.fixture
    def js_analyzer(self):
        """Provide a fresh JavaScript analyzer."""
        return TreeSitterAnalyzer(TreeSitterLanguage.JAVASCRIPT)

    def test_find_function_declaration(self, js_analyzer):
        """A ``function`` declaration is a free, synchronous, non-arrow function."""
        source = """
function add(a, b) {
    return a + b;
}
"""
        found = js_analyzer.find_functions(source)

        assert [fn.name for fn in found] == ["add"]
        add = found[0]
        assert add.is_arrow is False
        assert add.is_async is False
        assert add.is_method is False

    def test_find_arrow_function(self, js_analyzer):
        """An arrow function bound to a ``const`` takes the binding's name."""
        source = """
const add = (a, b) => {
    return a + b;
};
"""
        found = js_analyzer.find_functions(source)

        assert [fn.name for fn in found] == ["add"]
        assert found[0].is_arrow is True

    def test_find_arrow_function_concise(self, js_analyzer):
        """Expression-bodied arrow functions are discovered too."""
        found = js_analyzer.find_functions("const double = x => x * 2;")

        assert [fn.name for fn in found] == ["double"]
        assert found[0].is_arrow is True

    def test_find_async_function(self, js_analyzer):
        """``async function`` declarations are flagged as async."""
        source = """
async function fetchData(url) {
    return await fetch(url);
}
"""
        found = js_analyzer.find_functions(source)

        assert [fn.name for fn in found] == ["fetchData"]
        assert found[0].is_async is True

    def test_find_class_methods(self, js_analyzer):
        """Class methods are reported with their owning class when requested."""
        source = """
class Calculator {
    add(a, b) {
        return a + b;
    }
}
"""
        found = js_analyzer.find_functions(source, include_methods=True)

        assert [fn.name for fn in found] == ["add"]
        method = found[0]
        assert method.is_method is True
        assert method.class_name == "Calculator"

    def test_exclude_methods(self, js_analyzer):
        """``include_methods=False`` leaves only free functions."""
        source = """
class Calculator {
    add(a, b) {
        return a + b;
    }
}

function standalone() {
    return 1;
}
"""
        found = js_analyzer.find_functions(source, include_methods=False)

        assert [fn.name for fn in found] == ["standalone"]

    def test_exclude_arrow_functions(self, js_analyzer):
        """``include_arrow_functions=False`` drops arrow functions."""
        source = """
function regular() {
    return 1;
}

const arrow = () => 2;
"""
        found = js_analyzer.find_functions(source, include_arrow_functions=False)

        assert [fn.name for fn in found] == ["regular"]

    def test_skip_object_literal_methods(self, js_analyzer):
        """Shorthand methods of an object literal are never reported.

        Extracted on their own they would not be valid JavaScript.
        """
        source = """
function getStrategies() {
    return {
        encrypt(data) {
            return data;
        },
        decrypt(data) {
            return data;
        }
    };
}
"""
        names = {fn.name for fn in js_analyzer.find_functions(source, include_methods=True)}

        assert "getStrategies" in names
        assert names.isdisjoint({"encrypt", "decrypt"})

    def test_skip_module_exports_methods(self, js_analyzer):
        """Shorthand methods of a ``module.exports`` object are skipped as well."""
        source = """
module.exports = {
    handler(req) {
        return req;
    }
};
"""
        found = js_analyzer.find_functions(source, include_methods=True)

        assert not found

    def test_class_methods_not_skipped(self, js_analyzer):
        """Only object-literal methods are skipped; class methods still appear."""
        source = """
class Encryptor {
    encrypt(data) {
        return data;
    }
}

const strategies = {
    decrypt(data) {
        return data;
    }
};
"""
        names = {fn.name for fn in js_analyzer.find_functions(source, include_methods=True)}

        assert "encrypt" in names
        assert "decrypt" not in names

    def test_find_generator_function(self, js_analyzer):
        """``function*`` declarations are flagged as generators."""
        source = """
function* numberGenerator() {
    yield 1;
    yield 2;
}
"""
        found = js_analyzer.find_functions(source)

        assert [fn.name for fn in found] == ["numberGenerator"]
        assert found[0].is_generator is True

    def test_function_line_numbers(self, js_analyzer):
        """Start and end lines are 1-based and inclusive."""
        source = """function first() {
    return 1;
}

function second() {
    return 2;
}
"""
        spans = {}
        for fn in js_analyzer.find_functions(source):
            # Keep the first occurrence of each name.
            spans.setdefault(fn.name, (fn.start_line, fn.end_line))

        assert spans["first"] == (1, 3)
        assert spans["second"] == (5, 7)

    def test_nested_functions(self, js_analyzer):
        """Inner functions are found and linked to their enclosing function."""
        source = """
function outer() {
    function inner() {
        return 1;
    }
    return inner();
}
"""
        found = js_analyzer.find_functions(source)

        assert len(found) == 2
        assert {fn.name for fn in found} == {"outer", "inner"}
        inner = next(fn for fn in found if fn.name == "inner")
        assert inner.parent_function == "outer"

    def test_require_name_filters_anonymous(self, js_analyzer):
        """``require_name=True`` drops anonymous function expressions."""
        source = """
(function() {
    return 1;
})();

function named() {
    return 2;
}
"""
        found = js_analyzer.find_functions(source, require_name=True)

        assert [fn.name for fn in found] == ["named"]

    def test_function_expression_in_variable(self, js_analyzer):
        """A function expression assigned to a ``const`` takes the binding's name."""
        source = """
const add = function(a, b) {
    return a + b;
};
"""
        found = js_analyzer.find_functions(source)

        assert [fn.name for fn in found] == ["add"]


class TestFindImports:
    """Import discovery through TreeSitterAnalyzer.find_imports."""

    @pytest.fixture
    def js_analyzer(self):
        """Provide a fresh JavaScript analyzer."""
        return TreeSitterAnalyzer(TreeSitterLanguage.JAVASCRIPT)

    def test_find_default_import(self, js_analyzer):
        """``import X from 'm'`` records X as the default binding."""
        found = js_analyzer.find_imports("import React from 'react';")

        assert len(found) == 1
        imp = found[0]
        assert imp.module_path == "react"
        assert imp.default_import == "React"

    def test_find_named_imports(self, js_analyzer):
        """Each ``{ name }`` binding is recorded with no alias."""
        found = js_analyzer.find_imports("import { useState, useEffect } from 'react';")

        assert len(found) == 1
        imp = found[0]
        assert imp.module_path == "react"
        for binding in ("useState", "useEffect"):
            assert (binding, None) in imp.named_imports

    def test_find_namespace_import(self, js_analyzer):
        """``import * as ns`` records the namespace binding."""
        found = js_analyzer.find_imports("import * as utils from './utils';")

        assert len(found) == 1
        imp = found[0]
        assert imp.module_path == "./utils"
        assert imp.namespace_import == "utils"

    def test_find_require(self, js_analyzer):
        """A top-level ``const x = require('m')`` counts as a default import."""
        found = js_analyzer.find_imports("const fs = require('fs');")

        assert len(found) == 1
        imp = found[0]
        assert imp.module_path == "fs"
        assert imp.default_import == "fs"

    def test_require_inside_function_not_import(self, js_analyzer):
        """Only module-level ``require()`` calls are reported as imports.

        Calls made inside function or method bodies are dynamic loads, not
        module imports, and must not be extracted as if they were.
        """
        source = """
const fs = require('fs');

function loadModule() {
    const dynamic = require('dynamic-module');
    return dynamic;
}

class MyClass {
    method() {
        const inMethod = require('method-module');
    }
}
"""
        found = js_analyzer.find_imports(source)

        assert [imp.module_path for imp in found] == ["fs"]

    def test_find_multiple_imports(self, js_analyzer):
        """ESM and CommonJS imports in one file are all collected."""
        source = """
import React from 'react';
import { useState } from 'react';
import * as utils from './utils';
const path = require('path');
"""
        found = js_analyzer.find_imports(source)

        assert len(found) == 4
        assert {imp.module_path for imp in found} == {"react", "./utils", "path"}

    def test_import_with_alias(self, js_analyzer):
        """``{ A as B }`` is recorded as the pair (A, B)."""
        found = js_analyzer.find_imports("import { Component as Comp } from 'react';")

        assert len(found) == 1
        assert ("Component", "Comp") in found[0].named_imports

    def test_relative_import(self, js_analyzer):
        """Relative module specifiers are kept verbatim."""
        found = js_analyzer.find_imports("import { helper } from './helpers/utils';")

        assert [imp.module_path for imp in found] == ["./helpers/utils"]


class TestFindFunctionCalls:
    """Call-site discovery through TreeSitterAnalyzer.find_function_calls."""

    @pytest.fixture
    def js_analyzer(self):
        """Provide a fresh JavaScript analyzer."""
        return TreeSitterAnalyzer(TreeSitterLanguage.JAVASCRIPT)

    def test_find_simple_calls(self, js_analyzer):
        """A bare call to a sibling function is reported by name."""
        source = """
function helper() {
    return 1;
}

function main() {
    return helper() + 2;
}
"""
        main = next(fn for fn in js_analyzer.find_functions(source) if fn.name == "main")

        assert "helper" in js_analyzer.find_function_calls(source, main)

    def test_find_method_calls(self, js_analyzer):
        """Chained method calls are reported by their method names."""
        source = """
function process(arr) {
    return arr.map(x => x * 2).filter(x => x > 0);
}
"""
        process = next(fn for fn in js_analyzer.find_functions(source) if fn.name == "process")
        calls = js_analyzer.find_function_calls(source, process)

        for method_name in ("map", "filter"):
            assert method_name in calls


class TestHasReturnStatement:
    """Tests for has_return_statement method."""

    @pytest.fixture
    def js_analyzer(self):
        """Create a JavaScript analyzer."""
        return TreeSitterAnalyzer(TreeSitterLanguage.JAVASCRIPT)

    def test_function_with_return(self, js_analyzer):
        """Test function with return statement."""
        code = """
function add(a, b) {
    return a + b;
}
"""
        functions = js_analyzer.find_functions(code)
        assert js_analyzer.has_return_statement(functions[0], code) is True

    def test_function_without_return(self, js_analyzer):
        """Test function without return statement."""
        code = """
function log(msg) {
    console.log(msg);
}
"""
        functions = js_analyzer.find_functions(code, require_name=True)
        func = next((f for f in functions if f.name == "log"), None)
        # Fail loudly if discovery regresses: the previous ``if func:`` guard let
        # this test pass without ever calling has_return_statement.
        assert func is not None, "expected find_functions to discover 'log'"
        assert js_analyzer.has_return_statement(func, code) is False

    def test_arrow_function_implicit_return(self, js_analyzer):
        """Test arrow function with implicit return."""
        code = "const double = x => x * 2;"
        functions = js_analyzer.find_functions(code)
        assert js_analyzer.has_return_statement(functions[0], code) is True

    def test_arrow_function_explicit_return(self, js_analyzer):
        """Test arrow function with explicit return."""
        code = """
const add = (a, b) => {
    return a + b;
};
"""
        functions = js_analyzer.find_functions(code)
        assert js_analyzer.has_return_statement(functions[0], code) is True


class TestTypeScriptSupport:
    """Tests for TypeScript-specific features."""

    @pytest.fixture
    def ts_analyzer(self):
        """Create a TypeScript analyzer."""
        return TreeSitterAnalyzer(TreeSitterLanguage.TYPESCRIPT)

    def test_find_typed_function(self, ts_analyzer):
        """Test finding function with type annotations."""
        code = """
function add(a: number, b: number): number {
    return a + b;
}
"""
        functions = ts_analyzer.find_functions(code)

        assert len(functions) == 1
        assert functions[0].name == "add"

    def test_find_interface_method(self, ts_analyzer):
        """Test that interface methods are not found (they're declarations)."""
        code = """
interface Calculator {
    add(a: number, b: number): number;
}

function helper(): number {
    return 1;
}
"""
        functions = ts_analyzer.find_functions(code)

        # Only the actual function should be found, not the interface method.
        names = {f.name for f in functions}
        assert "helper" in names
        # The docstring's claim was previously unchecked: an interface member is
        # a body-less signature and must not be reported as a function.
        assert "add" not in names

    def test_find_generic_function(self, ts_analyzer):
        """Test finding generic function."""
        code = """
function identity<T>(value: T): T {
    return value;
}
"""
        functions = ts_analyzer.find_functions(code)

        assert len(functions) == 1
        assert functions[0].name == "identity"


class TestExportConstArrowFunctions:
    """Tests for export const arrow function pattern - Issue #10.

    Modern TypeScript codebases commonly use:
    - export const slugify = (str: string) => { return s; }
    - export const uniqueBy = <T>(array: T[]) => { ... }

    These must be correctly recognized as optimizable functions.
    """

    @pytest.fixture
    def ts_analyzer(self):
        """Create a TypeScript analyzer."""
        return TreeSitterAnalyzer(TreeSitterLanguage.TYPESCRIPT)

    def test_export_const_arrow_function_basic(self, ts_analyzer):
        """Test finding export const arrow function (basic pattern)."""
        code = """export const slugify = (str: string) => {
    return str.toLowerCase();
};"""
        functions = ts_analyzer.find_functions(code)

        assert len(functions) == 1
        assert functions[0].name == "slugify"
        assert functions[0].is_arrow is True
        assert ts_analyzer.has_return_statement(functions[0], code) is True

    def test_export_const_arrow_function_optional_param(self, ts_analyzer):
        """Test finding export const arrow function with optional parameter."""
        code = """export const slugify = (str: string, forDisplayingInput?: boolean) => {
    if (!str) {
        return "";
    }
    const s = str.toLowerCase();
    return forDisplayingInput ? s : s.replace(/-+$/, "");
};"""
        functions = ts_analyzer.find_functions(code)

        assert len(functions) == 1
        assert functions[0].name == "slugify"
        assert functions[0].is_arrow is True
        assert ts_analyzer.has_return_statement(functions[0], code) is True

    def test_export_const_generic_arrow_function(self, ts_analyzer):
        """Test finding export const arrow function with generics."""
        code = """export const uniqueBy = <T extends { [key: string]: unknown }>(array: T[], keys: (keyof T)[]) => {
    return array.filter(
        (item, index, self) => index === self.findIndex((t) => keys.every((key) => t[key] === item[key]))
    );
};"""
        functions = ts_analyzer.find_functions(code)

        # Should find uniqueBy, and possibly the inner arrow functions.
        unique_by = next((f for f in functions if f.name == "uniqueBy"), None)
        assert unique_by is not None
        assert unique_by.is_arrow is True
        assert ts_analyzer.has_return_statement(unique_by, code) is True

    def test_export_const_arrow_function_is_exported(self, ts_analyzer):
        """Test that export const arrow functions are recognized as exported."""
        code = """export const slugify = (str: string) => {
    return str.toLowerCase();
};"""

        is_exported, export_name = ts_analyzer.is_function_exported(code, "slugify")
        assert is_exported is True
        assert export_name == "slugify"

    def test_export_const_with_default_export(self, ts_analyzer):
        """Test export const with separate default export."""
        code = """export const slugify = (str: string) => {
    return str.toLowerCase();
};

export default slugify;"""

        functions = ts_analyzer.find_functions(code)
        assert len(functions) == 1
        assert functions[0].name == "slugify"

        # The binding is exported both by name and as the default, but
        # is_function_exported reports a single export. Only "is exported at all"
        # is pinned here; which export name wins is deliberately not asserted.
        is_exported, _ = ts_analyzer.is_function_exported(code, "slugify")
        assert is_exported is True

    def test_multiple_export_const_functions(self, ts_analyzer):
        """Test multiple export const arrow functions in same file."""
        code = """export const notUndefined = <T>(val: T | undefined): val is T => Boolean(val);

export const uniqueBy = <T extends { [key: string]: unknown }>(array: T[], keys: (keyof T)[]) => {
    return array.filter(
        (item, index, self) => index === self.findIndex((t) => keys.every((key) => t[key] === item[key]))
    );
};"""

        functions = ts_analyzer.find_functions(code)

        # Find the top-level exported functions (inner arrows have a parent).
        names = {f.name for f in functions if f.parent_function is None}
        assert "notUndefined" in names
        assert "uniqueBy" in names

    def test_export_const_arrow_with_implicit_return(self, ts_analyzer):
        """Test export const arrow function with implicit return."""
        code = """export const double = (n: number) => n * 2;"""

        functions = ts_analyzer.find_functions(code)

        assert len(functions) == 1
        assert functions[0].name == "double"
        assert functions[0].is_arrow is True
        assert ts_analyzer.has_return_statement(functions[0], code) is True

    def test_export_const_async_arrow_function(self, ts_analyzer):
        """Test export const async arrow function."""
        code = """export const fetchData = async (url: string) => {
    const response = await fetch(url);
    return response.json();
};"""

        functions = ts_analyzer.find_functions(code)

        assert len(functions) == 1
        assert functions[0].name == "fetchData"
        assert functions[0].is_arrow is True
        assert functions[0].is_async is True
        assert ts_analyzer.has_return_statement(functions[0], code) is True

    def test_non_exported_const_not_exported(self, ts_analyzer):
        """Test that non-exported const functions are not marked as exported."""
        code = """const privateFunc = (x: number) => {
    return x * 2;
};

export const publicFunc = (x: number) => {
    return privateFunc(x);
};"""

        # privateFunc should not be exported.
        is_private_exported, _ = ts_analyzer.is_function_exported(code, "privateFunc")
        assert is_private_exported is False

        # publicFunc should be exported under its own name.
        is_public_exported, name = ts_analyzer.is_function_exported(code, "publicFunc")
        assert is_public_exported is True
        assert name == "publicFunc"


class TestWrappedDefaultExports:
    """Default exports that wrap a function in a call (Issue #9).

    Covered shapes include:
    - export default curry(traverseEntity)
    - export default compose(fn1, fn2)
    - export default wrapper(myFunc)

    Every identifier passed into the wrapper call must count as exported, so
    that the wrapped function is eligible for optimization.
    """

    @pytest.fixture
    def ts_analyzer(self):
        """Provide a fresh TypeScript analyzer."""
        return TreeSitterAnalyzer(TreeSitterLanguage.TYPESCRIPT)

    def test_curry_wrapped_export(self, ts_analyzer):
        """``export default curry(fn)`` exports ``fn`` as the default."""
        source = """import { curry } from 'lodash/fp';

const traverseEntity = async (visitor, options, entity) => {
    return entity;
};

export default curry(traverseEntity);"""

        # The export statement itself is parsed with the wrapped argument list.
        found = ts_analyzer.find_exports(source)
        assert len(found) == 1
        default = found[0]
        assert default.default_export == "default"
        assert default.wrapped_default_args == ["traverseEntity"]

        # ...and the wrapped function is reported as the default export.
        exported, via = ts_analyzer.is_function_exported(source, "traverseEntity")
        assert exported is True
        assert via == "default"

    def test_compose_wrapped_export(self, ts_analyzer):
        """Every argument of ``compose(a, b)`` counts as exported."""
        source = """import { compose } from 'lodash/fp';

function validateInput(data) { return data; }
function processData(data) { return data; }

export default compose(validateInput, processData);"""

        found = ts_analyzer.find_exports(source)
        assert len(found) == 1
        assert found[0].wrapped_default_args == ["validateInput", "processData"]

        validate_exported, _ = ts_analyzer.is_function_exported(source, "validateInput")
        process_exported, _ = ts_analyzer.is_function_exported(source, "processData")
        assert validate_exported is True
        assert process_exported is True

    def test_nested_wrapper_export(self, ts_analyzer):
        """Arguments of nested wrapper calls are still collected."""
        source = """export default compose(curry(myFunc));"""

        found = ts_analyzer.find_exports(source)
        assert len(found) == 1
        assert "myFunc" in found[0].wrapped_default_args

        exported, _ = ts_analyzer.is_function_exported(source, "myFunc")
        assert exported is True

    def test_generic_wrapper_export(self, ts_analyzer):
        """Any wrapper name works, not just well-known combinators."""
        source = """const myFunction = (x: number) => x * 2;

export default someWrapper(myFunction);"""

        exported, via = ts_analyzer.is_function_exported(source, "myFunction")
        assert exported is True
        assert via == "default"

    def test_non_wrapped_function_not_exported(self, ts_analyzer):
        """Only the wrapper's arguments become exported; other functions do not."""
        source = """const helper = (x: number) => x + 1;
const main = (x: number) => helper(x) * 2;

export default curry(main);"""

        # main is passed to the wrapper...
        main_exported, _ = ts_analyzer.is_function_exported(source, "main")
        assert main_exported is True

        # ...while helper is merely called by main and stays private.
        helper_exported, _ = ts_analyzer.is_function_exported(source, "helper")
        assert helper_exported is False

    def test_direct_default_export_still_works(self, ts_analyzer):
        """An unwrapped ``export default fn`` is still recognised."""
        source = """function myFunc() { return 1; }
export default myFunc;"""

        exported, via = ts_analyzer.is_function_exported(source, "myFunc")
        assert exported is True
        assert via == "default"

    def test_strapi_traverse_entity_pattern(self, ts_analyzer):
        """Regression case taken verbatim from the failing Strapi source."""
        source = """import { curry } from 'lodash/fp';

const traverseEntity = async (visitor: Visitor, options: TraverseOptions, entity: Data) => {
    const { path = { raw: null }, schema, getModel } = options;
    // ... implementation
    return copy;
};

const createVisitorUtils = ({ data }: { data: Data }) => ({
    remove(key: string) { delete data[key]; },
    set(key: string, value: Data) { data[key] = value; },
});

export default curry(traverseEntity);"""

        # traverseEntity is the wrapped default export.
        exported, via = ts_analyzer.is_function_exported(source, "traverseEntity")
        assert exported is True
        assert via == "default"

        # createVisitorUtils is not wrapped, so the default export does not cover it.
        utils_exported, _ = ts_analyzer.is_function_exported(source, "createVisitorUtils")
        assert utils_exported is False


class TestNamedExportConstArrow:
    """Tests for const arrow functions exported via named export clause.

    Pattern: const joinBy = () => {}; export { joinBy };
    This is common in TypeScript codebases like Strapi.
    """

    @pytest.fixture
    def ts_analyzer(self):
        """Create a TypeScript analyzer."""
        return TreeSitterAnalyzer(TreeSitterLanguage.TYPESCRIPT)

    @pytest.fixture
    def js_analyzer(self):
        """Create a JavaScript analyzer."""
        return TreeSitterAnalyzer(TreeSitterLanguage.JAVASCRIPT)

    def test_named_export_const_arrow(self, ts_analyzer):
        """Const arrow function exported via separate export { } clause."""
        code = """const joinBy = (arr: string[], separator: string) => {
    return arr.join(separator);
};

export { joinBy };"""

        functions = ts_analyzer.find_functions(code)
        join_by = next((f for f in functions if f.name == "joinBy"), None)
        assert join_by is not None
        assert join_by.is_exported is True

    def test_named_export_alias(self, ts_analyzer):
        """Export { foo as bar } — foo should be marked as exported."""
        code = """const foo = (x: number) => {
    return x * 2;
};

export { foo as bar };"""

        functions = ts_analyzer.find_functions(code)
        foo = next((f for f in functions if f.name == "foo"), None)
        assert foo is not None
        assert foo.is_exported is True

    def test_named_export_multiple(self, ts_analyzer):
        """Multiple functions in a single export clause."""
        code = """const a = () => { return 1; };
const b = () => { return 2; };
const c = () => { return 3; };

export { a, b };"""

        functions = ts_analyzer.find_functions(code)
        a = next((f for f in functions if f.name == "a"), None)
        b = next((f for f in functions if f.name == "b"), None)
        c = next((f for f in functions if f.name == "c"), None)
        assert a is not None and a.is_exported is True
        assert b is not None and b.is_exported is True
        # c is declared but left out of the export clause.
        assert c is not None and c.is_exported is False

    def test_named_export_function_declaration(self, js_analyzer):
        """Regular function declarations exported via export { }."""
        code = """function processData(data) {
    return data;
}

export { processData };"""

        functions = js_analyzer.find_functions(code)
        # Distinct name for the result so it does not shadow the generator variable.
        process_data = next((fn for fn in functions if fn.name == "processData"), None)
        assert process_data is not None
        assert process_data.is_exported is True

    def test_is_function_exported_with_named_export(self, ts_analyzer):
        """is_function_exported should detect named export clause."""
        code = """const joinBy = (arr: string[], separator: string) => {
    return arr.join(separator);
};

export { joinBy };"""

        is_exported, _ = ts_analyzer.is_function_exported(code, "joinBy")
        assert is_exported is True


class TestCjsReexportObjectMethods:
    """Tests for CJS re-export of object containing methods.

    Pattern: const utils = { match() {} }; module.exports = utils;
    This is common in Node.js libraries like Moleculer.
    """

    @pytest.fixture
    def js_analyzer(self):
        """Create a JavaScript analyzer."""
        return TreeSitterAnalyzer(TreeSitterLanguage.JAVASCRIPT)

    def test_cjs_reexport_object_methods(self, js_analyzer):
        """module.exports = varName where varName is object with methods."""
        code = """const utils = {
    match(text, pattern) {
        return text.match(pattern);
    },
    slugify(str) {
        return str.toLowerCase();
    }
};

module.exports = utils;"""

        is_exported, _ = js_analyzer.is_function_exported(code, "match")
        assert is_exported is True

        is_exported2, _ = js_analyzer.is_function_exported(code, "slugify")
        assert is_exported2 is True

    def test_cjs_reexport_shorthand_props(self, js_analyzer):
        """module.exports = varName where object has shorthand properties."""
        code = """function match(text, pattern) {
    return text.match(pattern);
}

const utils = { match };
module.exports = utils;"""

        is_exported, _ = js_analyzer.is_function_exported(code, "match")
        assert is_exported is True

    def test_cjs_reexport_pair_props(self, js_analyzer):
        """module.exports = varName where object has key: value pairs."""
        code = """function myMatch(text, pattern) {
    return text.match(pattern);
}

const utils = { match: myMatch };
module.exports = utils;"""

        # The lookup is by the property key ("match"), not the bound function's name.
        is_exported, _ = js_analyzer.is_function_exported(code, "match")
        assert is_exported is True

    def test_cjs_reexport_nonexistent_prop(self, js_analyzer):
        """A function not in the re-exported object should not be exported."""
        code = """function helper() { return 1; }

const utils = {
    match(text) { return text; }
};

module.exports = utils;"""

        is_exported, _ = js_analyzer.is_function_exported(code, "helper")
        assert is_exported is False