codeflash/tests/test_languages/test_treesitter_utils.py
ali 8b7ebee5fa
fix: skip object literal methods during JS/TS function discovery
Shorthand method definitions inside object literals (e.g., `{ encrypt(data) {} }`)
were being discovered as optimizable functions. Since bare method syntax is only
valid inside class bodies or object literals, extracting them as standalone code
produced syntactically invalid JavaScript, failing context extraction.

Now skip method_definition nodes whose parent is an `object` node, matching
the existing skip for arrow functions in `pair` nodes.

Trace IDs: 04f07244, 01ac202f, 024a3d42, 04da127a (~274 affected logs)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-25 18:44:08 +02:00

1025 lines
33 KiB
Python

"""Extensive tests for the tree-sitter utilities module.
These tests verify that the TreeSitterAnalyzer correctly parses and
analyzes JavaScript/TypeScript code.
"""
from pathlib import Path
import pytest
from codeflash.languages.javascript.treesitter import TreeSitterAnalyzer, TreeSitterLanguage, get_analyzer_for_file
class TestTreeSitterLanguage:
    """Checks on the string values carried by the TreeSitterLanguage enum."""

    def test_language_values(self):
        """Each supported grammar member maps to its lowercase identifier."""
        expectations = (
            (TreeSitterLanguage.JAVASCRIPT, "javascript"),
            (TreeSitterLanguage.TYPESCRIPT, "typescript"),
            (TreeSitterLanguage.TSX, "tsx"),
        )
        for member, expected_value in expectations:
            assert member.value == expected_value
class TestTreeSitterAnalyzerCreation:
    """Construction-time behavior of TreeSitterAnalyzer."""

    def test_create_javascript_analyzer(self):
        """An analyzer built from the JAVASCRIPT member reports that language."""
        js = TreeSitterAnalyzer(TreeSitterLanguage.JAVASCRIPT)
        assert js.language == TreeSitterLanguage.JAVASCRIPT

    def test_create_typescript_analyzer(self):
        """An analyzer built from the TYPESCRIPT member reports that language."""
        ts = TreeSitterAnalyzer(TreeSitterLanguage.TYPESCRIPT)
        assert ts.language == TreeSitterLanguage.TYPESCRIPT

    def test_create_with_string(self):
        """A plain language name is accepted and exposed as the enum member."""
        from_name = TreeSitterAnalyzer("javascript")
        assert from_name.language == TreeSitterLanguage.JAVASCRIPT

    def test_lazy_parser_creation(self):
        """The private parser slot stays empty until ``parser`` is first read."""
        lazy = TreeSitterAnalyzer(TreeSitterLanguage.JAVASCRIPT)
        assert lazy._parser is None

        # Reading the property is what populates the private slot.
        _ = lazy.parser
        assert lazy._parser is not None
class TestGetAnalyzerForFile:
    """Extension-based analyzer selection performed by get_analyzer_for_file."""

    @staticmethod
    def _language_for(filename):
        """Return the language of the analyzer chosen for ``/test/<filename>``."""
        return get_analyzer_for_file(Path("/test") / filename).language

    def test_js_file(self):
        """``.js`` files use the JavaScript grammar."""
        assert self._language_for("file.js") == TreeSitterLanguage.JAVASCRIPT

    def test_jsx_file(self):
        """``.jsx`` files use the JavaScript grammar."""
        assert self._language_for("file.jsx") == TreeSitterLanguage.JAVASCRIPT

    def test_ts_file(self):
        """``.ts`` files use the TypeScript grammar."""
        assert self._language_for("file.ts") == TreeSitterLanguage.TYPESCRIPT

    def test_tsx_file(self):
        """``.tsx`` files use the TSX grammar."""
        assert self._language_for("file.tsx") == TreeSitterLanguage.TSX

    def test_mjs_file(self):
        """``.mjs`` (ES module) files use the JavaScript grammar."""
        assert self._language_for("file.mjs") == TreeSitterLanguage.JAVASCRIPT

    def test_cjs_file(self):
        """``.cjs`` (CommonJS) files use the JavaScript grammar."""
        assert self._language_for("file.cjs") == TreeSitterLanguage.JAVASCRIPT
class TestParsing:
    """Low-level parsing entry points of TreeSitterAnalyzer."""

    @pytest.fixture
    def js_analyzer(self):
        """Provide a fresh JavaScript analyzer for each test."""
        return TreeSitterAnalyzer(TreeSitterLanguage.JAVASCRIPT)

    def test_parse_simple_code(self, js_analyzer):
        """A single valid statement parses into an error-free tree."""
        root = js_analyzer.parse("const x = 1;").root_node
        assert root is not None
        assert not root.has_error

    def test_parse_bytes(self, js_analyzer):
        """``parse`` also accepts a ``bytes`` source."""
        parsed = js_analyzer.parse(b"const x = 1;")
        assert parsed.root_node is not None

    def test_parse_invalid_code(self, js_analyzer):
        """Malformed input still yields a tree whose root is flagged with errors."""
        broken = js_analyzer.parse("function foo( {")
        assert broken.root_node.has_error

    def test_get_node_text(self, js_analyzer):
        """The root node's text round-trips to the complete original source."""
        source = "const x = 1;"
        raw = source.encode("utf8")
        root = js_analyzer.parse(raw).root_node
        assert js_analyzer.get_node_text(root, raw) == source
class TestFindFunctions:
    """Tests for TreeSitterAnalyzer.find_functions on JavaScript sources."""

    @pytest.fixture
    def js_analyzer(self):
        """Create a JavaScript analyzer."""
        return TreeSitterAnalyzer(TreeSitterLanguage.JAVASCRIPT)

    def test_find_function_declaration(self, js_analyzer):
        """A plain ``function`` declaration is found with all flags off."""
        code = """
function add(a, b) {
return a + b;
}
"""
        functions = js_analyzer.find_functions(code)
        assert len(functions) == 1
        assert functions[0].name == "add"
        assert functions[0].is_arrow is False
        assert functions[0].is_async is False
        assert functions[0].is_method is False

    def test_find_arrow_function(self, js_analyzer):
        """An arrow function bound to a ``const`` is named after the variable."""
        code = """
const add = (a, b) => {
return a + b;
};
"""
        functions = js_analyzer.find_functions(code)
        assert len(functions) == 1
        assert functions[0].name == "add"
        assert functions[0].is_arrow is True

    def test_find_arrow_function_concise(self, js_analyzer):
        """An expression-bodied (concise) arrow function is also discovered."""
        code = "const double = x => x * 2;"
        functions = js_analyzer.find_functions(code)
        assert len(functions) == 1
        assert functions[0].name == "double"
        assert functions[0].is_arrow is True

    def test_find_async_function(self, js_analyzer):
        """``async function`` declarations are flagged with ``is_async``."""
        code = """
async function fetchData(url) {
return await fetch(url);
}
"""
        functions = js_analyzer.find_functions(code)
        assert len(functions) == 1
        assert functions[0].name == "fetchData"
        assert functions[0].is_async is True

    def test_find_class_methods(self, js_analyzer):
        """With ``include_methods=True`` class methods carry their class name."""
        code = """
class Calculator {
add(a, b) {
return a + b;
}
}
"""
        functions = js_analyzer.find_functions(code, include_methods=True)
        assert len(functions) == 1
        assert functions[0].name == "add"
        assert functions[0].is_method is True
        assert functions[0].class_name == "Calculator"

    def test_exclude_methods(self, js_analyzer):
        """``include_methods=False`` drops class methods but keeps free functions."""
        code = """
class Calculator {
add(a, b) {
return a + b;
}
}
function standalone() {
return 1;
}
"""
        functions = js_analyzer.find_functions(code, include_methods=False)
        assert len(functions) == 1
        assert functions[0].name == "standalone"

    def test_exclude_arrow_functions(self, js_analyzer):
        """``include_arrow_functions=False`` drops arrow functions."""
        code = """
function regular() {
return 1;
}
const arrow = () => 2;
"""
        functions = js_analyzer.find_functions(code, include_arrow_functions=False)
        assert len(functions) == 1
        assert functions[0].name == "regular"

    def test_skip_object_literal_methods(self, js_analyzer):
        """Shorthand methods inside an object literal are not reported.

        Extracted on their own such methods are not valid standalone
        JavaScript, so only the enclosing function should be discovered.
        """
        code = """
function getStrategies() {
return {
encrypt(data) {
return data;
},
decrypt(data) {
return data;
}
};
}
"""
        functions = js_analyzer.find_functions(code, include_methods=True)
        names = [f.name for f in functions]
        assert "getStrategies" in names
        assert "encrypt" not in names
        assert "decrypt" not in names

    def test_skip_module_exports_methods(self, js_analyzer):
        """Methods written inside a ``module.exports = {...}`` literal are skipped."""
        code = """
module.exports = {
handler(req) {
return req;
}
};
"""
        functions = js_analyzer.find_functions(code, include_methods=True)
        assert len(functions) == 0

    def test_class_methods_not_skipped(self, js_analyzer):
        """Class methods are kept; only object-literal methods are skipped."""
        code = """
class Encryptor {
encrypt(data) {
return data;
}
}
const strategies = {
decrypt(data) {
return data;
}
};
"""
        functions = js_analyzer.find_functions(code, include_methods=True)
        names = [f.name for f in functions]
        assert "encrypt" in names
        assert "decrypt" not in names

    def test_find_generator_function(self, js_analyzer):
        """``function*`` declarations are flagged with ``is_generator``."""
        code = """
function* numberGenerator() {
yield 1;
yield 2;
}
"""
        functions = js_analyzer.find_functions(code)
        assert len(functions) == 1
        assert functions[0].name == "numberGenerator"
        assert functions[0].is_generator is True

    def test_function_line_numbers(self, js_analyzer):
        """Reported start/end lines are 1-based and match the source layout."""
        # Line 4 is intentionally blank so ``second`` spans lines 5-7; without
        # it the declaration would start on line 4 and the assertions below
        # could never hold.
        code = """function first() {
return 1;
}

function second() {
return 2;
}
"""
        functions = js_analyzer.find_functions(code)
        first = next(f for f in functions if f.name == "first")
        second = next(f for f in functions if f.name == "second")
        assert first.start_line == 1
        assert first.end_line == 3
        assert second.start_line == 5
        assert second.end_line == 7

    def test_nested_functions(self, js_analyzer):
        """Nested declarations are found and linked to their enclosing function."""
        code = """
function outer() {
function inner() {
return 1;
}
return inner();
}
"""
        functions = js_analyzer.find_functions(code)
        assert len(functions) == 2
        names = {f.name for f in functions}
        assert names == {"outer", "inner"}
        inner = next(f for f in functions if f.name == "inner")
        assert inner.parent_function == "outer"

    def test_require_name_filters_anonymous(self, js_analyzer):
        """``require_name=True`` drops anonymous functions such as IIFEs."""
        code = """
(function() {
return 1;
})();
function named() {
return 2;
}
"""
        functions = js_analyzer.find_functions(code, require_name=True)
        assert len(functions) == 1
        assert functions[0].name == "named"

    def test_function_expression_in_variable(self, js_analyzer):
        """A ``function`` expression assigned to a const takes the variable name."""
        code = """
const add = function(a, b) {
return a + b;
};
"""
        functions = js_analyzer.find_functions(code)
        assert len(functions) == 1
        assert functions[0].name == "add"
class TestFindImports:
    """Import discovery (ESM ``import`` and CommonJS ``require``) via find_imports."""

    @pytest.fixture
    def js_analyzer(self):
        """Provide a fresh JavaScript analyzer for each test."""
        return TreeSitterAnalyzer(TreeSitterLanguage.JAVASCRIPT)

    def test_find_default_import(self, js_analyzer):
        """``import X from 'm'`` records ``X`` as the default import of ``m``."""
        found = js_analyzer.find_imports("import React from 'react';")
        assert len(found) == 1
        entry = found[0]
        assert entry.module_path == "react"
        assert entry.default_import == "React"

    def test_find_named_imports(self, js_analyzer):
        """Every name in a ``{ ... }`` clause is recorded as an un-aliased pair."""
        found = js_analyzer.find_imports("import { useState, useEffect } from 'react';")
        assert len(found) == 1
        entry = found[0]
        assert entry.module_path == "react"
        for hook in ("useState", "useEffect"):
            assert (hook, None) in entry.named_imports

    def test_find_namespace_import(self, js_analyzer):
        """``import * as ns from 'm'`` is stored as a namespace import."""
        found = js_analyzer.find_imports("import * as utils from './utils';")
        assert len(found) == 1
        entry = found[0]
        assert entry.module_path == "./utils"
        assert entry.namespace_import == "utils"

    def test_find_require(self, js_analyzer):
        """A top-level ``const x = require('m')`` is reported as a default import."""
        found = js_analyzer.find_imports("const fs = require('fs');")
        assert len(found) == 1
        entry = found[0]
        assert entry.module_path == "fs"
        assert entry.default_import == "fs"

    def test_require_inside_function_not_import(self, js_analyzer):
        """``require()`` calls nested in functions or methods are not imports.

        Only module-level ``require`` calls count; dynamic loads performed
        inside a function or class method must not be extracted as imports.
        """
        code = """
const fs = require('fs');
function loadModule() {
const dynamic = require('dynamic-module');
return dynamic;
}
class MyClass {
method() {
const inMethod = require('method-module');
}
}
"""
        found = js_analyzer.find_imports(code)
        # The top-level ``fs`` require is the only one expected to survive.
        assert len(found) == 1
        assert found[0].module_path == "fs"

    def test_find_multiple_imports(self, js_analyzer):
        """Mixed ESM and CommonJS imports in one file are all collected."""
        code = """
import React from 'react';
import { useState } from 'react';
import * as utils from './utils';
const path = require('path');
"""
        found = js_analyzer.find_imports(code)
        assert len(found) == 4
        assert {entry.module_path for entry in found} == {"react", "./utils", "path"}

    def test_import_with_alias(self, js_analyzer):
        """``{ A as B }`` is recorded as the pair ``(A, B)``."""
        found = js_analyzer.find_imports("import { Component as Comp } from 'react';")
        assert len(found) == 1
        assert ("Component", "Comp") in found[0].named_imports

    def test_relative_import(self, js_analyzer):
        """Relative module specifiers are kept verbatim."""
        found = js_analyzer.find_imports("import { helper } from './helpers/utils';")
        assert len(found) == 1
        assert found[0].module_path == "./helpers/utils"
class TestFindFunctionCalls:
    """Call-site discovery inside a function body via find_function_calls."""

    @pytest.fixture
    def js_analyzer(self):
        """Provide a fresh JavaScript analyzer for each test."""
        return TreeSitterAnalyzer(TreeSitterLanguage.JAVASCRIPT)

    def test_find_simple_calls(self, js_analyzer):
        """A direct call to a sibling function is reported by that function's name."""
        code = """
function helper() {
return 1;
}
function main() {
return helper() + 2;
}
"""
        main_fn = next(fn for fn in js_analyzer.find_functions(code) if fn.name == "main")
        assert "helper" in js_analyzer.find_function_calls(code, main_fn)

    def test_find_method_calls(self, js_analyzer):
        """Chained method calls are reported by their method names."""
        code = """
function process(arr) {
return arr.map(x => x * 2).filter(x => x > 0);
}
"""
        target = next(fn for fn in js_analyzer.find_functions(code) if fn.name == "process")
        called = js_analyzer.find_function_calls(code, target)
        for method in ("map", "filter"):
            assert method in called
class TestHasReturnStatement:
    """Tests for TreeSitterAnalyzer.has_return_statement."""

    @pytest.fixture
    def js_analyzer(self):
        """Create a JavaScript analyzer."""
        return TreeSitterAnalyzer(TreeSitterLanguage.JAVASCRIPT)

    def test_function_with_return(self, js_analyzer):
        """A declaration with an explicit ``return`` is detected."""
        code = """
function add(a, b) {
return a + b;
}
"""
        functions = js_analyzer.find_functions(code)
        assert js_analyzer.has_return_statement(functions[0], code) is True

    def test_function_without_return(self, js_analyzer):
        """A function whose body never returns is reported as having no return."""
        code = """
function log(msg) {
console.log(msg);
}
"""
        functions = js_analyzer.find_functions(code, require_name=True)
        func = next((f for f in functions if f.name == "log"), None)
        # Fail loudly if discovery misses ``log``: the former ``if func:``
        # guard let this test pass without asserting anything.
        assert func is not None
        assert js_analyzer.has_return_statement(func, code) is False

    def test_arrow_function_implicit_return(self, js_analyzer):
        """A concise arrow body counts as an (implicit) return."""
        code = "const double = x => x * 2;"
        functions = js_analyzer.find_functions(code)
        assert js_analyzer.has_return_statement(functions[0], code) is True

    def test_arrow_function_explicit_return(self, js_analyzer):
        """A block-bodied arrow function with ``return`` is detected."""
        code = """
const add = (a, b) => {
return a + b;
};
"""
        functions = js_analyzer.find_functions(code)
        assert js_analyzer.has_return_statement(functions[0], code) is True
class TestTypeScriptSupport:
    """Tests for TypeScript-specific syntax handled by find_functions."""

    @pytest.fixture
    def ts_analyzer(self):
        """Create a TypeScript analyzer."""
        return TreeSitterAnalyzer(TreeSitterLanguage.TYPESCRIPT)

    def test_find_typed_function(self, ts_analyzer):
        """Parameter and return type annotations do not hide a declaration."""
        code = """
function add(a: number, b: number): number {
return a + b;
}
"""
        functions = ts_analyzer.find_functions(code)
        assert len(functions) == 1
        assert functions[0].name == "add"

    def test_find_interface_method(self, ts_analyzer):
        """Interface method signatures are not reported (they have no body)."""
        code = """
interface Calculator {
add(a: number, b: number): number;
}
function helper(): number {
return 1;
}
"""
        functions = ts_analyzer.find_functions(code)
        names = {f.name for f in functions}
        # Only the real function should be found, not the interface method;
        # assert both halves so the test checks what its docstring claims.
        assert "helper" in names
        assert "add" not in names

    def test_find_generic_function(self, ts_analyzer):
        """Generic type parameters do not hide a declaration."""
        code = """
function identity<T>(value: T): T {
return value;
}
"""
        functions = ts_analyzer.find_functions(code)
        assert len(functions) == 1
        assert functions[0].name == "identity"
class TestExportConstArrowFunctions:
    """Discovery and export detection for ``export const f = (...) => ...`` (Issue #10).

    Modern TypeScript codebases commonly declare functions as exported
    ``const`` arrow functions, for example:

    - ``export const slugify = (str: string) => { return s; }``
    - ``export const uniqueBy = <T>(array: T[]) => { ... }``

    Such declarations must be recognized as optimizable functions.
    """

    @pytest.fixture
    def ts_analyzer(self):
        """Provide a fresh TypeScript analyzer for each test."""
        return TreeSitterAnalyzer(TreeSitterLanguage.TYPESCRIPT)

    def test_export_const_arrow_function_basic(self, ts_analyzer):
        """The simplest exported const arrow function is found and has a return."""
        code = """export const slugify = (str: string) => {
return str.toLowerCase();
};"""
        found = ts_analyzer.find_functions(code)
        assert len(found) == 1
        slugify = found[0]
        assert slugify.name == "slugify"
        assert slugify.is_arrow is True
        assert ts_analyzer.has_return_statement(slugify, code) is True

    def test_export_const_arrow_function_optional_param(self, ts_analyzer):
        """An optional (``?``) parameter does not prevent discovery."""
        code = """export const slugify = (str: string, forDisplayingInput?: boolean) => {
if (!str) {
return "";
}
const s = str.toLowerCase();
return forDisplayingInput ? s : s.replace(/-+$/, "");
};"""
        found = ts_analyzer.find_functions(code)
        assert len(found) == 1
        slugify = found[0]
        assert slugify.name == "slugify"
        assert slugify.is_arrow is True
        assert ts_analyzer.has_return_statement(slugify, code) is True

    def test_export_const_generic_arrow_function(self, ts_analyzer):
        """A generic exported arrow function is found by name."""
        code = """export const uniqueBy = <T extends { [key: string]: unknown }>(array: T[], keys: (keyof T)[]) => {
return array.filter(
(item, index, self) => index === self.findIndex((t) => keys.every((key) => t[key] === item[key]))
);
};"""
        found = ts_analyzer.find_functions(code)
        # The inner callbacks may also be reported; only ``uniqueBy`` is checked.
        unique_by = next((fn for fn in found if fn.name == "uniqueBy"), None)
        assert unique_by is not None
        assert unique_by.is_arrow is True
        assert ts_analyzer.has_return_statement(unique_by, code) is True

    def test_export_const_arrow_function_is_exported(self, ts_analyzer):
        """``is_function_exported`` reports the const under its own name."""
        code = """export const slugify = (str: string) => {
return str.toLowerCase();
};"""
        exported, exported_as = ts_analyzer.is_function_exported(code, "slugify")
        assert exported is True
        assert exported_as == "slugify"

    def test_export_const_with_default_export(self, ts_analyzer):
        """An extra ``export default`` of the same const does not break discovery."""
        code = """export const slugify = (str: string) => {
return str.toLowerCase();
};
export default slugify;"""
        found = ts_analyzer.find_functions(code)
        assert len(found) == 1
        assert found[0].name == "slugify"
        # The function must still be detected as exported with both export forms present.
        exported, _ = ts_analyzer.is_function_exported(code, "slugify")
        assert exported is True

    def test_multiple_export_const_functions(self, ts_analyzer):
        """Several exported consts in one file are all found at top level."""
        code = """export const notUndefined = <T>(val: T | undefined): val is T => Boolean(val);
export const uniqueBy = <T extends { [key: string]: unknown }>(array: T[], keys: (keyof T)[]) => {
return array.filter(
(item, index, self) => index === self.findIndex((t) => keys.every((key) => t[key] === item[key]))
);
};"""
        top_level = {
            fn.name for fn in ts_analyzer.find_functions(code) if fn.parent_function is None
        }
        for expected in ("notUndefined", "uniqueBy"):
            assert expected in top_level

    def test_export_const_arrow_with_implicit_return(self, ts_analyzer):
        """A concise-body exported arrow function counts as returning."""
        code = """export const double = (n: number) => n * 2;"""
        found = ts_analyzer.find_functions(code)
        assert len(found) == 1
        double = found[0]
        assert double.name == "double"
        assert double.is_arrow is True
        assert ts_analyzer.has_return_statement(double, code) is True

    def test_export_const_async_arrow_function(self, ts_analyzer):
        """An ``async`` exported arrow function is flagged as both arrow and async."""
        code = """export const fetchData = async (url: string) => {
const response = await fetch(url);
return response.json();
};"""
        found = ts_analyzer.find_functions(code)
        assert len(found) == 1
        fetch_data = found[0]
        assert fetch_data.name == "fetchData"
        assert fetch_data.is_arrow is True
        assert fetch_data.is_async is True
        assert ts_analyzer.has_return_statement(fetch_data, code) is True

    def test_non_exported_const_not_exported(self, ts_analyzer):
        """Only the ``export``-prefixed const is reported as exported."""
        code = """const privateFunc = (x: number) => {
return x * 2;
};
export const publicFunc = (x: number) => {
return privateFunc(x);
};"""
        private_exported, _ = ts_analyzer.is_function_exported(code, "privateFunc")
        assert private_exported is False

        public_exported, public_name = ts_analyzer.is_function_exported(code, "publicFunc")
        assert public_exported is True
        assert public_name == "publicFunc"
class TestWrappedDefaultExports:
    """Default exports that wrap functions in a call expression (Issue #9).

    Covers forms such as ``export default curry(traverseEntity)``,
    ``export default compose(fn1, fn2)`` and ``export default wrapper(myFunc)``;
    the function arguments of the wrapper call must count as exported.
    """

    @pytest.fixture
    def ts_analyzer(self):
        """Provide a fresh TypeScript analyzer for each test."""
        return TreeSitterAnalyzer(TreeSitterLanguage.TYPESCRIPT)

    def test_curry_wrapped_export(self, ts_analyzer):
        """``export default curry(fn)`` records ``fn`` as a wrapped default argument."""
        code = """import { curry } from 'lodash/fp';
const traverseEntity = async (visitor, options, entity) => {
return entity;
};
export default curry(traverseEntity);"""
        exports = ts_analyzer.find_exports(code)
        assert len(exports) == 1
        entry = exports[0]
        assert entry.default_export == "default"
        assert entry.wrapped_default_args == ["traverseEntity"]

        exported, exported_as = ts_analyzer.is_function_exported(code, "traverseEntity")
        assert exported is True
        assert exported_as == "default"

    def test_compose_wrapped_export(self, ts_analyzer):
        """Every argument of a multi-argument wrapper call is treated as exported."""
        code = """import { compose } from 'lodash/fp';
function validateInput(data) { return data; }
function processData(data) { return data; }
export default compose(validateInput, processData);"""
        exports = ts_analyzer.find_exports(code)
        assert len(exports) == 1
        assert exports[0].wrapped_default_args == ["validateInput", "processData"]

        verdicts = [
            ts_analyzer.is_function_exported(code, name)[0]
            for name in ("validateInput", "processData")
        ]
        assert all(verdict is True for verdict in verdicts)

    def test_nested_wrapper_export(self, ts_analyzer):
        """Arguments of nested wrapper calls are still found."""
        code = """export default compose(curry(myFunc));"""
        exports = ts_analyzer.find_exports(code)
        assert len(exports) == 1
        assert "myFunc" in exports[0].wrapped_default_args

        exported, _ = ts_analyzer.is_function_exported(code, "myFunc")
        assert exported is True

    def test_generic_wrapper_export(self, ts_analyzer):
        """An arbitrary (non-lodash) wrapper name is handled the same way."""
        code = """const myFunction = (x: number) => x * 2;
export default someWrapper(myFunction);"""
        exported, exported_as = ts_analyzer.is_function_exported(code, "myFunction")
        assert exported is True
        assert exported_as == "default"

    def test_non_wrapped_function_not_exported(self, ts_analyzer):
        """Only functions passed to the wrapper are exported, not their callees."""
        code = """const helper = (x: number) => x + 1;
const main = (x: number) => helper(x) * 2;
export default curry(main);"""
        main_exported, _ = ts_analyzer.is_function_exported(code, "main")
        assert main_exported is True

        # ``helper`` is only called by ``main``; it is not a wrapper argument.
        helper_exported, _ = ts_analyzer.is_function_exported(code, "helper")
        assert helper_exported is False

    def test_direct_default_export_still_works(self, ts_analyzer):
        """A plain ``export default name;`` is still recognized."""
        code = """function myFunc() { return 1; }
export default myFunc;"""
        exported, exported_as = ts_analyzer.is_function_exported(code, "myFunc")
        assert exported is True
        assert exported_as == "default"

    def test_strapi_traverse_entity_pattern(self, ts_analyzer):
        """Regression case reduced from Strapi's ``traverseEntity`` module."""
        code = """import { curry } from 'lodash/fp';
const traverseEntity = async (visitor: Visitor, options: TraverseOptions, entity: Data) => {
const { path = { raw: null }, schema, getModel } = options;
// ... implementation
return copy;
};
const createVisitorUtils = ({ data }: { data: Data }) => ({
remove(key: string) { delete data[key]; },
set(key: string, value: Data) { data[key] = value; },
});
export default curry(traverseEntity);"""
        exported, exported_as = ts_analyzer.is_function_exported(code, "traverseEntity")
        assert exported is True
        assert exported_as == "default"

        # ``createVisitorUtils`` is not passed to ``curry``, so it is not exported.
        utils_exported, _ = ts_analyzer.is_function_exported(code, "createVisitorUtils")
        assert utils_exported is False
class TestNamedExportConstArrow:
    """Functions exported through a separate ``export { ... }`` clause.

    Pattern: ``const joinBy = () => {}; export { joinBy };`` — common in
    TypeScript codebases such as Strapi.
    """

    @pytest.fixture
    def ts_analyzer(self):
        """Provide a fresh TypeScript analyzer for each test."""
        return TreeSitterAnalyzer(TreeSitterLanguage.TYPESCRIPT)

    @pytest.fixture
    def js_analyzer(self):
        """Provide a fresh JavaScript analyzer for each test."""
        return TreeSitterAnalyzer(TreeSitterLanguage.JAVASCRIPT)

    def test_named_export_const_arrow(self, ts_analyzer):
        """A const arrow function listed in ``export { }`` is marked exported."""
        code = """const joinBy = (arr: string[], separator: string) => {
return arr.join(separator);
};
export { joinBy };"""
        found = ts_analyzer.find_functions(code)
        join_by = next((fn for fn in found if fn.name == "joinBy"), None)
        assert join_by is not None
        assert join_by.is_exported is True

    def test_named_export_alias(self, ts_analyzer):
        """With ``export { foo as bar }`` the local ``foo`` is marked exported."""
        code = """const foo = (x: number) => {
return x * 2;
};
export { foo as bar };"""
        found = ts_analyzer.find_functions(code)
        foo = next((fn for fn in found if fn.name == "foo"), None)
        assert foo is not None
        assert foo.is_exported is True

    def test_named_export_multiple(self, ts_analyzer):
        """Only the names listed in the clause are exported."""
        code = """const a = () => { return 1; };
const b = () => { return 2; };
const c = () => { return 3; };
export { a, b };"""
        found = ts_analyzer.find_functions(code)
        expected_flags = (("a", True), ("b", True), ("c", False))
        matches = [
            (next((fn for fn in found if fn.name == name), None), flag)
            for name, flag in expected_flags
        ]
        for match, flag in matches:
            assert match is not None and match.is_exported is flag

    def test_named_export_function_declaration(self, js_analyzer):
        """A regular function declaration listed in ``export { }`` is exported."""
        code = """function processData(data) {
return data;
}
export { processData };"""
        found = js_analyzer.find_functions(code)
        process_data = next((fn for fn in found if fn.name == "processData"), None)
        assert process_data is not None
        assert process_data.is_exported is True

    def test_is_function_exported_with_named_export(self, ts_analyzer):
        """``is_function_exported`` also honours a separate export clause."""
        code = """const joinBy = (arr: string[], separator: string) => {
return arr.join(separator);
};
export { joinBy };"""
        exported, _ = ts_analyzer.is_function_exported(code, "joinBy")
        assert exported is True
class TestCjsReexportObjectMethods:
    """CommonJS re-export of a local object that holds functions.

    Pattern: ``const utils = { match() {} }; module.exports = utils;`` —
    common in Node.js libraries such as Moleculer.
    """

    @pytest.fixture
    def js_analyzer(self):
        """Provide a fresh JavaScript analyzer for each test."""
        return TreeSitterAnalyzer(TreeSitterLanguage.JAVASCRIPT)

    def test_cjs_reexport_object_methods(self, js_analyzer):
        """Shorthand methods of the re-exported object count as exported."""
        code = """const utils = {
match(text, pattern) {
return text.match(pattern);
},
slugify(str) {
return str.toLowerCase();
}
};
module.exports = utils;"""
        for method in ("match", "slugify"):
            exported, _ = js_analyzer.is_function_exported(code, method)
            assert exported is True

    def test_cjs_reexport_shorthand_props(self, js_analyzer):
        """A shorthand property ``{ match }`` exports the referenced function."""
        code = """function match(text, pattern) {
return text.match(pattern);
}
const utils = { match };
module.exports = utils;"""
        exported, _ = js_analyzer.is_function_exported(code, "match")
        assert exported is True

    def test_cjs_reexport_pair_props(self, js_analyzer):
        """A ``key: value`` property makes the key name count as exported."""
        code = """function myMatch(text, pattern) {
return text.match(pattern);
}
const utils = { match: myMatch };
module.exports = utils;"""
        exported, _ = js_analyzer.is_function_exported(code, "match")
        assert exported is True

    def test_cjs_reexport_nonexistent_prop(self, js_analyzer):
        """A function absent from the re-exported object is not exported."""
        code = """function helper() { return 1; }
const utils = {
match(text) { return text; }
};
module.exports = utils;"""
        exported, _ = js_analyzer.is_function_exported(code, "helper")
        assert exported is False