Merge branch 'main' of github.com:codeflash-ai/codeflash into fix/js-jest30-loop-runner

This commit is contained in:
ali 2026-02-12 16:39:33 +02:00
commit e9b7154361
No known key found for this signature in database
GPG key ID: 44F9B42770617B9B
30 changed files with 2469 additions and 419 deletions

View file

@ -48,7 +48,7 @@ jobs:
with:
use_foundry: "true"
use_sticky_comment: true
allowed_bots: "claude[bot]"
allowed_bots: "claude[bot],codeflash-ai[bot]"
prompt: |
REPO: ${{ github.repository }}
PR NUMBER: ${{ github.event.pull_request.number }}

View file

@ -27,12 +27,29 @@ uv run ruff format codeflash/ # Format
# Linting (run before committing)
uv run prek run --from-ref origin/main
# Mypy type checking (run on changed files before committing)
uv run mypy --non-interactive --config-file pyproject.toml <changed_files>
# Running the CLI
uv run codeflash --help
uv run codeflash init # Initialize in a project
uv run codeflash --all # Optimize entire codebase
```
## Mypy Type Checking
When modifying code, fix any mypy type errors in the files you changed. Run mypy on changed files:
```bash
uv run mypy --non-interactive --config-file pyproject.toml <changed_files>
```
Rules:
- Fix type annotation issues: missing return types, incorrect types, Optional/None unions, import errors for type hints
- Do NOT add `# type: ignore` comments — always fix the root cause
- Do NOT fix type errors that require logic changes, complex generic type rework, or anything that could change runtime behavior
- Files in `mypy_allowlist.txt` are checked in CI — ensure they remain error-free
<!-- Section below is auto-generated by `tessl install` - do not edit manually -->
# Agent Rules <!-- tessl-managed -->

View file

@ -386,7 +386,7 @@ class JavaScriptTransformer:
from pathlib import Path
from codeflash.languages.base import LanguageSupport, FunctionInfo, CodeContext
from codeflash.languages.treesitter_utils import TreeSitterAnalyzer
from codeflash.languages.javascript.treesitter import TreeSitterAnalyzer
from codeflash.languages.javascript.transformer import JavaScriptTransformer
class JavaScriptSupport(LanguageSupport):
@ -523,7 +523,7 @@ class JavaScriptSupport(LanguageSupport):
# codeflash/languages/javascript/test_discovery.py
from pathlib import Path
from codeflash.languages.treesitter_utils import TreeSitterAnalyzer
from codeflash.languages.javascript.treesitter import TreeSitterAnalyzer
class JestTestDiscovery:
"""Static analysis-based test discovery for Jest."""

View file

@ -1772,7 +1772,7 @@ def _extract_calling_function_js(source_code: str, function_name: str, ref_line:
"""
try:
from codeflash.languages.treesitter_utils import TreeSitterAnalyzer, TreeSitterLanguage
from codeflash.languages.javascript.treesitter import TreeSitterAnalyzer, TreeSitterLanguage
# Try TypeScript first, fall back to JavaScript
for lang in [TreeSitterLanguage.TYPESCRIPT, TreeSitterLanguage.TSX, TreeSitterLanguage.JAVASCRIPT]:

View file

@ -26,7 +26,7 @@ if TYPE_CHECKING:
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.languages.base import Language, LanguageSupport
from codeflash.languages.treesitter_utils import TreeSitterAnalyzer
from codeflash.languages.javascript.treesitter import TreeSitterAnalyzer
from codeflash.models.models import CodeOptimizationContext, CodeStringsMarkdown, OptimizedCandidate, ValidCode
ASTNodeT = TypeVar("ASTNodeT", bound=ast.AST)
@ -640,7 +640,7 @@ def _add_global_declarations_for_language(
return original_source
try:
from codeflash.languages.treesitter_utils import get_analyzer_for_file
from codeflash.languages.javascript.treesitter import get_analyzer_for_file
analyzer = get_analyzer_for_file(module_abspath)

View file

@ -105,12 +105,12 @@ def clean_concolic_tests(test_suite_code: str) -> str:
can_parse = False
tree = None
if not can_parse:
if not can_parse or tree is None:
return AssertCleanup().transform_asserts(test_suite_code)
for node in ast.walk(tree):
if isinstance(node, ast.FunctionDef) and node.name.startswith("test_"):
new_body = []
new_body: list[ast.stmt] = []
for stmt in node.body:
if isinstance(stmt, ast.Assert):
if isinstance(stmt.test, ast.Compare) and isinstance(stmt.test.left, ast.Call):

View file

@ -13,14 +13,28 @@ if TYPE_CHECKING:
def extract_dependent_function(main_function: str, code_context: CodeOptimizationContext) -> str | Literal[False]:
"""Extract the single dependent function from the code context excluding the main function."""
dependent_functions = set()
for code_string in code_context.testgen_context.code_strings:
ast_tree = ast.parse(code_string.code)
dependent_functions.update(
{node.name for node in ast_tree.body if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef))}
)
if main_function in dependent_functions:
dependent_functions.discard(main_function)
# Compare using bare name since AST extracts bare function names
bare_main = main_function.rsplit(".", 1)[-1] if "." in main_function else main_function
for code_string in code_context.testgen_context.code_strings:
# Quick heuristic: skip parsing entirely if there is no 'def' token,
# since no function definitions can be present without it.
if "def" not in code_string.code:
continue
ast_tree = ast.parse(code_string.code)
# Add function names directly, skipping the bare main name.
for node in ast_tree.body:
if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
name = node.name
if name == bare_main:
continue
dependent_functions.add(name)
# If more than one dependent function (other than the main) is found,
# we can return False early since the final result cannot be a single name.
if len(dependent_functions) > 1:
return False
if not dependent_functions:
return False
@ -32,6 +46,9 @@ def extract_dependent_function(main_function: str, code_context: CodeOptimizatio
def build_fully_qualified_name(function_name: str, code_context: CodeOptimizationContext) -> str:
# If the name is already qualified (contains a dot), return as-is
if "." in function_name:
return function_name
full_name = function_name
for obj_name, parents in code_context.preexisting_objects:
if obj_name == function_name:

View file

@ -233,7 +233,7 @@ class JavaScriptNormalizer(CodeNormalizer):
"""
try:
from codeflash.languages.treesitter_utils import TreeSitterAnalyzer, TreeSitterLanguage
from codeflash.languages.javascript.treesitter import TreeSitterAnalyzer, TreeSitterLanguage
lang_map = {"javascript": TreeSitterLanguage.JAVASCRIPT, "typescript": TreeSitterLanguage.TYPESCRIPT}
lang = lang_map.get(self._get_tree_sitter_language(), TreeSitterLanguage.JAVASCRIPT)

View file

@ -201,7 +201,7 @@ def _is_js_ts_function_exported(file_path: Path, function_name: str) -> tuple[bo
Tuple of (is_exported, export_name). export_name may be 'default' for default exports.
"""
from codeflash.languages.treesitter_utils import get_analyzer_for_file
from codeflash.languages.javascript.treesitter import get_analyzer_for_file
try:
source = file_path.read_text(encoding="utf-8")

View file

@ -26,10 +26,10 @@ class PrComment:
def to_json(self) -> dict[str, Union[str, int, dict[str, dict[str, int]], list[BenchmarkDetail], None]]:
report_table: dict[str, dict[str, int]] = {}
for test_type, test_result in self.winning_behavior_test_results.get_test_pass_fail_report_by_type().items():
for test_type, counts in self.winning_behavior_test_results.get_test_pass_fail_report_by_type().items():
name = test_type.to_name()
if name:
report_table[name] = test_result
report_table[name] = counts
result: dict[str, Union[str, int, dict[str, dict[str, int]], list[BenchmarkDetail], None]] = {
"optimization_explanation": self.optimization_explanation,

View file

@ -23,7 +23,7 @@ if TYPE_CHECKING:
from tree_sitter import Node
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.languages.treesitter_utils import ImportInfo, TreeSitterAnalyzer
from codeflash.languages.javascript.treesitter import ImportInfo, TreeSitterAnalyzer
logger = logging.getLogger(__name__)
@ -112,7 +112,7 @@ class ReferenceFinder:
List of Reference objects describing each call site.
"""
from codeflash.languages.treesitter_utils import get_analyzer_for_file
from codeflash.languages.javascript.treesitter import get_analyzer_for_file
function_name = function_to_optimize.function_name
source_file = function_to_optimize.file_path
@ -168,7 +168,7 @@ class ReferenceFinder:
if import_info:
# Found an import - mark as visited and search for calls
context.visited_files.add(file_path)
import_name, original_import = import_info
import_name, _original_import = import_info
file_refs = self._find_references_in_file(
file_path, file_code, function_name, import_name, file_analyzer, include_self=True
)
@ -213,7 +213,7 @@ class ReferenceFinder:
trigger_check = True
if import_info:
context.visited_files.add(file_path)
import_name, original_import = import_info
import_name, _original_import = import_info
file_refs = self._find_references_in_file(
file_path, file_code, reexport_name, import_name, file_analyzer, include_self=True
)
@ -404,7 +404,7 @@ class ReferenceFinder:
name_node = node.child_by_field_name("name")
if name_node:
new_current_function = source_bytes[name_node.start_byte : name_node.end_byte].decode("utf8")
elif node.type in ("variable_declarator",):
elif node.type == "variable_declarator":
# Arrow function or function expression assigned to variable
name_node = node.child_by_field_name("name")
value_node = node.child_by_field_name("value")
@ -719,7 +719,7 @@ class ReferenceFinder:
continue
# Create a fake ImportInfo to resolve the re-export source
from codeflash.languages.treesitter_utils import ImportInfo
from codeflash.languages.javascript.treesitter import ImportInfo
fake_import = ImportInfo(
module_path=exp.reexport_source,

View file

@ -14,7 +14,7 @@ from typing import TYPE_CHECKING
if TYPE_CHECKING:
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.languages.base import HelperFunction
from codeflash.languages.treesitter_utils import ImportInfo, TreeSitterAnalyzer
from codeflash.languages.javascript.treesitter import ImportInfo, TreeSitterAnalyzer
logger = logging.getLogger(__name__)
@ -486,7 +486,7 @@ class MultiFileHelperFinder:
"""
from codeflash.languages.base import HelperFunction
from codeflash.languages.treesitter_utils import get_analyzer_for_file
from codeflash.languages.javascript.treesitter import get_analyzer_for_file
try:
source = file_path.read_text(encoding="utf-8")
@ -558,8 +558,8 @@ class MultiFileHelperFinder:
"""
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.languages.javascript.treesitter import get_analyzer_for_file
from codeflash.languages.registry import get_language_support
from codeflash.languages.treesitter_utils import get_analyzer_for_file
if context.current_depth >= context.max_depth:
return {}

View file

@ -861,7 +861,7 @@ def validate_and_fix_import_style(test_code: str, source_file_path: Path, functi
Fixed test code with correct import style.
"""
from codeflash.languages.treesitter_utils import get_analyzer_for_file
from codeflash.languages.javascript.treesitter import get_analyzer_for_file
# Read source file to determine export style
try:

View file

@ -11,7 +11,7 @@ import json
import logging
from typing import TYPE_CHECKING
from codeflash.languages.treesitter_utils import get_analyzer_for_file
from codeflash.languages.javascript.treesitter import get_analyzer_for_file
if TYPE_CHECKING:
from pathlib import Path

View file

@ -14,15 +14,15 @@ from typing import TYPE_CHECKING, Any
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.languages.base import CodeContext, FunctionFilterCriteria, HelperFunction, Language, TestInfo, TestResult
from codeflash.languages.javascript.treesitter import TreeSitterAnalyzer, TreeSitterLanguage, get_analyzer_for_file
from codeflash.languages.registry import register_language
from codeflash.languages.treesitter_utils import TreeSitterAnalyzer, TreeSitterLanguage, get_analyzer_for_file
from codeflash.models.models import FunctionParent
if TYPE_CHECKING:
from collections.abc import Sequence
from codeflash.languages.base import ReferenceInfo
from codeflash.languages.treesitter_utils import TypeDefinition
from codeflash.languages.javascript.treesitter import TypeDefinition
logger = logging.getLogger(__name__)

View file

@ -95,6 +95,9 @@ class ExportInfo:
reexport_source: str | None # Module path for re-exports
start_line: int
end_line: int
# Functions passed as arguments to wrapper calls in default exports
# e.g., export default curry(traverseEntity) -> ["traverseEntity"]
wrapped_default_args: list[str] | None = None
@dataclass
@ -864,6 +867,7 @@ class TreeSitterAnalyzer:
default_export: str | None = None
is_reexport = False
reexport_source: str | None = None
wrapped_default_args: list[str] | None = None
# Check for re-export source (export { x } from './other')
source_node = node.child_by_field_name("source")
@ -883,6 +887,12 @@ class TreeSitterAnalyzer:
default_export = self.get_node_text(sibling, source_bytes)
elif sibling.type in ("arrow_function", "function_expression", "object", "array"):
default_export = "default"
elif sibling.type == "call_expression":
# Handle wrapped exports: export default curry(traverseEntity)
# The default export is the result of the call, but we track
# the wrapped function names for export checking
default_export = "default"
wrapped_default_args = self._extract_call_expression_identifiers(sibling, source_bytes)
break
# Handle named exports: export { a, b as c }
@ -930,8 +940,37 @@ class TreeSitterAnalyzer:
reexport_source=reexport_source,
start_line=node.start_point[0] + 1,
end_line=node.end_point[0] + 1,
wrapped_default_args=wrapped_default_args,
)
def _extract_call_expression_identifiers(self, node: Node, source_bytes: bytes) -> list[str]:
"""Extract identifier names from arguments of a call expression.
For patterns like curry(traverseEntity) or compose(fn1, fn2), this extracts
the function names passed as arguments: ["traverseEntity"] or ["fn1", "fn2"].
Args:
node: A call_expression node.
source_bytes: The source code as bytes.
Returns:
List of identifier names found in the call arguments.
"""
identifiers: list[str] = []
# Get the arguments node
args_node = node.child_by_field_name("arguments")
if args_node:
for child in args_node.children:
if child.type == "identifier":
identifiers.append(self.get_node_text(child, source_bytes))
# Also handle nested call expressions: compose(curry(fn))
elif child.type == "call_expression":
identifiers.extend(self._extract_call_expression_identifiers(child, source_bytes))
return identifiers
def _extract_commonjs_export(self, node: Node, source_bytes: bytes) -> ExportInfo | None:
"""Extract export information from CommonJS module.exports or exports.* patterns.
@ -1033,6 +1072,7 @@ class TreeSitterAnalyzer:
"""Check if a function is exported and get its export name.
For class methods, also checks if the containing class is exported.
Also handles wrapped exports like: export default curry(traverseEntity)
Args:
source: The source code to analyze.
@ -1058,6 +1098,11 @@ class TreeSitterAnalyzer:
if name == function_name:
return (True, alias if alias else name)
# Check wrapped default exports: export default curry(traverseEntity)
# The function is exported via wrapper, so it's accessible as "default"
if export.wrapped_default_args and function_name in export.wrapped_default_args:
return (True, "default")
# For class methods, check if the containing class is exported
if class_name:
for export in exports:

View file

@ -2793,7 +2793,7 @@ class FunctionOptimizer:
test_config=self.test_cfg,
optimization_iteration=optimization_iteration,
run_result=run_result,
function_name=self.function_to_optimize.function_name,
function_name=self.function_to_optimize.qualified_name,
source_file=self.function_to_optimize.file_path,
code_context=code_context,
coverage_database_file=coverage_database_file,

View file

@ -33,7 +33,7 @@ def pytest_split(
except ImportError:
return None, None
test_files = set()
test_files_set: set[str] = set()
# Find all test_*.py files recursively in the directory
for test_path in test_paths:
@ -42,12 +42,12 @@ def pytest_split(
return None, None
if _test_path.is_dir():
# Find all test files matching the pattern test_*.py
test_files.update(map(str, _test_path.rglob("test_*.py")))
test_files.update(map(str, _test_path.rglob("*_test.py")))
test_files_set.update(map(str, _test_path.rglob("test_*.py")))
test_files_set.update(map(str, _test_path.rglob("*_test.py")))
elif _test_path.is_file():
test_files.add(str(_test_path))
test_files_set.add(str(_test_path))
if not test_files:
if not test_files_set:
return [[]], None
# Determine number of splits
@ -55,7 +55,7 @@ def pytest_split(
num_splits = os.cpu_count() or 4
# randomize to increase chances of all splits being balanced
test_files = list(test_files)
test_files = list(test_files_set)
shuffle(test_files)
# Apply limit if specified
@ -75,7 +75,7 @@ def pytest_split(
chunk_size = ceil(total_files / num_splits)
# Initialize result groups
result_groups = [[] for _ in range(num_splits)]
result_groups: list[list[str]] = [[] for _ in range(num_splits)]
# Distribute files across groups
for i, test_file in enumerate(test_files):

View file

@ -6,6 +6,7 @@ import enum
import math
import re
import types
import weakref
from collections import ChainMap, OrderedDict, deque
from importlib.util import find_spec
from typing import Any, Optional
@ -25,6 +26,7 @@ HAS_JAX = find_spec("jax") is not None
HAS_XARRAY = find_spec("xarray") is not None
HAS_TENSORFLOW = find_spec("tensorflow") is not None
HAS_NUMBA = find_spec("numba") is not None
HAS_PYARROW = find_spec("pyarrow") is not None
# Pattern to match pytest temp directories: /tmp/pytest-of-<user>/pytest-<N>/
# These paths vary between test runs but are logically equivalent
@ -93,7 +95,7 @@ def _get_wrapped_exception(exc: BaseException) -> Optional[BaseException]: # no
return _extract_exception_from_message(str(exc))
def comparator(orig: Any, new: Any, superset_obj=False) -> bool:
def comparator(orig: Any, new: Any, superset_obj: bool = False) -> bool:
"""Compare two objects for equality recursively. If superset_obj is True, the new object is allowed to have more keys than the original object. However, the existing keys/values must be equivalent."""
try:
# Handle exceptions specially - before type check to allow wrapper comparison
@ -171,6 +173,17 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool:
return True
return math.isclose(orig, new)
# Handle weak references (e.g., found in torch.nn.LSTM/GRU modules)
if isinstance(orig, weakref.ref):
orig_referent = orig()
new_referent = new()
# Both dead refs are equal, otherwise compare referents
if orig_referent is None and new_referent is None:
return True
if orig_referent is None or new_referent is None:
return False
return comparator(orig_referent, new_referent, superset_obj)
if HAS_JAX:
import jax # type: ignore # noqa: PGH003
import jax.numpy as jnp # type: ignore # noqa: PGH003
@ -342,13 +355,57 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool:
return False
return (orig != new).nnz == 0
if HAS_PYARROW:
import pyarrow as pa # type: ignore # noqa: PGH003
if isinstance(orig, pa.Table):
if orig.schema != new.schema:
return False
if orig.num_rows != new.num_rows:
return False
return bool(orig.equals(new))
if isinstance(orig, pa.RecordBatch):
if orig.schema != new.schema:
return False
if orig.num_rows != new.num_rows:
return False
return bool(orig.equals(new))
if isinstance(orig, pa.ChunkedArray):
if orig.type != new.type:
return False
if len(orig) != len(new):
return False
return bool(orig.equals(new))
if isinstance(orig, pa.Array):
if orig.type != new.type:
return False
if len(orig) != len(new):
return False
return bool(orig.equals(new))
if isinstance(orig, pa.Scalar):
if orig.type != new.type:
return False
# Handle null scalars
if not orig.is_valid and not new.is_valid:
return True
if not orig.is_valid or not new.is_valid:
return False
return bool(orig.equals(new))
if isinstance(orig, (pa.Schema, pa.Field, pa.DataType)):
return bool(orig.equals(new))
if HAS_PANDAS:
import pandas # noqa: ICN001
if isinstance(
orig, (pandas.DataFrame, pandas.Series, pandas.Index, pandas.Categorical, pandas.arrays.SparseArray)
):
return orig.equals(new)
return bool(orig.equals(new))
if isinstance(orig, (pandas.CategoricalDtype, pandas.Interval, pandas.Period)):
return orig == new
@ -395,10 +452,10 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool:
return orig == new
if HAS_NUMBA:
import numba # type: ignore # noqa: PGH003
from numba.core.dispatcher import Dispatcher # type: ignore # noqa: PGH003
from numba.typed import Dict as NumbaDict # type: ignore # noqa: PGH003
from numba.typed import List as NumbaList # type: ignore # noqa: PGH003
import numba
from numba.core.dispatcher import Dispatcher
from numba.typed import Dict as NumbaDict
from numba.typed import List as NumbaList
# Handle numba typed List
if isinstance(orig, NumbaList):

View file

@ -353,7 +353,9 @@ class CoverageUtils:
for file in files:
functions = files[file]["functions"]
for function in functions:
if dependent_function_name in function:
if function == dependent_function_name or (
"." in dependent_function_name and function.endswith(f".{dependent_function_name}")
):
return FunctionCoverage(
name=dependent_function_name,
coverage=functions[function]["summary"]["percent_covered"],

View file

@ -1,2 +1,2 @@
# These version placeholders will be replaced by uv-dynamic-versioning during build.
__version__ = "0.20.0"
__version__ = "0.20.0.post510.dev0+b8932209"

View file

@ -106,40 +106,63 @@ function installUv() {
}
}
/**
* Check if git is available
*/
function hasGit() {
try {
const result = spawnSync('git', ['--version'], {
stdio: 'ignore',
shell: true,
});
return result.status === 0;
} catch {
return false;
}
}
/**
* Install codeflash Python CLI using uv tool
*
* Installation priority:
* 1. GitHub main branch (if git available) - gets latest features
* 2. PyPI (fallback) - stable release
*
* We prefer GitHub because it has the latest JS/TS support that may not
* be published to PyPI yet. uv handles cloning internally in its cache.
*/
function installCodeflash(uvBin) {
logStep('2/3', 'Installing codeflash Python CLI...');
const GITHUB_REPO = 'git+https://github.com/codeflash-ai/codeflash.git';
// Priority 1: Install from GitHub (latest features, requires git)
if (hasGit()) {
try {
execSync(`"${uvBin}" tool install --force --python python3.12 "${GITHUB_REPO}"`, {
stdio: 'inherit',
shell: true,
});
logSuccess('codeflash CLI installed from GitHub (latest)');
return true;
} catch (error) {
logWarning(`GitHub installation failed: ${error.message}`);
logWarning('Falling back to PyPI...');
}
} else {
logWarning('Git not found, installing from PyPI...');
}
// Priority 2: Install from PyPI (stable release fallback)
try {
// Use uv tool install to install codeflash in an isolated environment
// This avoids conflicts with any existing Python environments
execSync(`"${uvBin}" tool install --force --python python3.12 codeflash`, {
stdio: 'inherit',
shell: true,
});
logSuccess('codeflash CLI installed successfully');
logSuccess('codeflash CLI installed from PyPI');
return true;
} catch (error) {
// If codeflash is not on PyPI yet, try installing from the local package
logWarning('codeflash not found on PyPI, trying local installation...');
try {
// Try installing from the current codeflash repo if we're in development
const cliRoot = path.resolve(__dirname, '..', '..', '..');
const pyprojectPath = path.join(cliRoot, 'pyproject.toml');
if (fs.existsSync(pyprojectPath)) {
execSync(`"${uvBin}" tool install --force "${cliRoot}"`, {
stdio: 'inherit',
shell: true,
});
logSuccess('codeflash CLI installed from local source');
return true;
}
} catch (localError) {
logError(`Failed to install codeflash: ${localError.message}`);
}
logError(`Failed to install codeflash: ${error.message}`);
return false;
}
}

View file

@ -46,7 +46,7 @@ dependencies = [
"pygls>=2.0.0,<3.0.0",
"codeflash-benchmark",
"filelock",
"pytest-asyncio>=1.2.0",
"pytest-asyncio>=0.18.0",
]
[project.urls]
@ -87,6 +87,7 @@ tests = [
"jax>=0.4.30",
"numpy>=2.0.2",
"pandas>=2.3.3",
"pyarrow>=15.0.0",
"pyrsistent>=0.20.0",
"scipy>=1.13.1",
"torch>=2.8.0",

View file

@ -0,0 +1,228 @@
from __future__ import annotations
from typing import Any
from codeflash.code_utils.coverage_utils import build_fully_qualified_name, extract_dependent_function
from codeflash.models.function_types import FunctionParent
from codeflash.models.models import CodeOptimizationContext, CodeString, CodeStringsMarkdown
from codeflash.verification.coverage_utils import CoverageUtils
def _make_code_context(
preexisting_objects: set[tuple[str, tuple[FunctionParent, ...]]],
testgen_code_strings: list[CodeString] | None = None,
) -> CodeOptimizationContext:
"""Helper to create a minimal CodeOptimizationContext for testing."""
return CodeOptimizationContext(
testgen_context=CodeStringsMarkdown(code_strings=testgen_code_strings or []),
read_writable_code=CodeStringsMarkdown(),
helper_functions=[],
preexisting_objects=preexisting_objects,
)
class TestBuildFullyQualifiedName:
def test_bare_name_with_class_parent(self) -> None:
ctx = _make_code_context({("__init__", (FunctionParent(name="HttpInterface", type="ClassDef"),))})
assert build_fully_qualified_name("__init__", ctx) == "HttpInterface.__init__"
def test_bare_name_no_parent(self) -> None:
ctx = _make_code_context({("helper_func", ())})
assert build_fully_qualified_name("helper_func", ctx) == "helper_func"
def test_already_qualified_name_returned_as_is(self) -> None:
"""If name already contains a dot, skip preexisting_objects lookup."""
ctx = _make_code_context({("__init__", (FunctionParent(name="WrongClass", type="ClassDef"),))})
result = build_fully_qualified_name("HttpInterface.__init__", ctx)
assert result == "HttpInterface.__init__"
def test_bare_name_picks_first_match_from_set(self) -> None:
"""With multiple __init__ entries, bare name picks an arbitrary one."""
ctx = _make_code_context(
{
("__init__", (FunctionParent(name="ClassA", type="ClassDef"),)),
("__init__", (FunctionParent(name="ClassB", type="ClassDef"),)),
}
)
result = build_fully_qualified_name("__init__", ctx)
assert result in {"ClassA.__init__", "ClassB.__init__"}
def test_qualified_name_avoids_ambiguity(self) -> None:
"""Qualified name bypasses preexisting_objects entirely, avoiding ambiguity."""
ctx = _make_code_context(
{
("__init__", (FunctionParent(name="ClassA", type="ClassDef"),)),
("__init__", (FunctionParent(name="ClassB", type="ClassDef"),)),
}
)
assert build_fully_qualified_name("ClassB.__init__", ctx) == "ClassB.__init__"
def test_bare_name_not_in_preexisting_objects(self) -> None:
ctx = _make_code_context(set())
assert build_fully_qualified_name("some_func", ctx) == "some_func"
def test_nested_class_parent(self) -> None:
"""Bare name under nested class parents gets fully qualified."""
ctx = _make_code_context(
{("method", (FunctionParent(name="Outer", type="ClassDef"), FunctionParent(name="Inner", type="ClassDef")))}
)
assert build_fully_qualified_name("method", ctx) == "Inner.Outer.method"
def test_non_classdef_parent_ignored(self) -> None:
"""Only ClassDef parents are prepended to the name."""
ctx = _make_code_context({("helper", (FunctionParent(name="wrapper", type="FunctionDef"),))})
assert build_fully_qualified_name("helper", ctx) == "helper"
class TestExtractDependentFunction:
def test_single_dependent_function(self) -> None:
ctx = _make_code_context(
preexisting_objects={("helper", ())},
testgen_code_strings=[CodeString(code="def main_func(): pass\ndef helper(): pass")],
)
result = extract_dependent_function("main_func", ctx)
assert result == "helper"
def test_qualified_main_function_discards_bare_match(self) -> None:
"""Qualified main_function should still discard the matching bare name."""
ctx = _make_code_context(
preexisting_objects={("helper", ())},
testgen_code_strings=[CodeString(code="def __init__(): pass\ndef helper(): pass")],
)
result = extract_dependent_function("HttpInterface.__init__", ctx)
assert result == "helper"
def test_bare_main_function_discards_match(self) -> None:
"""Bare main_function should still work for discarding."""
ctx = _make_code_context(
preexisting_objects={("helper", ())},
testgen_code_strings=[CodeString(code="def main_func(): pass\ndef helper(): pass")],
)
result = extract_dependent_function("main_func", ctx)
assert result == "helper"
def test_no_dependent_functions(self) -> None:
ctx = _make_code_context(preexisting_objects=set(), testgen_code_strings=[CodeString(code="x = 1\n")])
result = extract_dependent_function("main_func", ctx)
assert result is False
def test_multiple_dependent_functions_returns_false(self) -> None:
ctx = _make_code_context(
preexisting_objects=set(),
testgen_code_strings=[CodeString(code="def helper_a(): pass\ndef helper_b(): pass")],
)
result = extract_dependent_function("main_func", ctx)
assert result is False
def test_dependent_function_gets_qualified(self) -> None:
"""The dependent function returned should be qualified via build_fully_qualified_name."""
ctx = _make_code_context(
preexisting_objects={("helper", (FunctionParent(name="MyClass", type="ClassDef"),))},
testgen_code_strings=[CodeString(code="def main_func(): pass\ndef helper(): pass")],
)
result = extract_dependent_function("main_func", ctx)
assert result == "MyClass.helper"
def test_only_main_in_code_returns_false(self) -> None:
"""When code only contains the main function, no dependent function exists."""
ctx = _make_code_context(
preexisting_objects=set(), testgen_code_strings=[CodeString(code="def __init__(): pass")]
)
result = extract_dependent_function("HttpInterface.__init__", ctx)
assert result is False
def test_async_functions_extracted(self) -> None:
"""Async function definitions are also extracted as dependent functions."""
ctx = _make_code_context(
preexisting_objects={("async_helper", ())},
testgen_code_strings=[CodeString(code="def main(): pass\nasync def async_helper(): pass")],
)
result = extract_dependent_function("main", ctx)
assert result == "async_helper"
class TestGrabDependentFunctionFromCoverageData:
    """Tests for CoverageUtils.grab_dependent_function_from_coverage_data.

    Exercises exact-name hits in the primary coverage map, the fallback
    lookups (exact and dot-boundary suffix) against the original coverage
    data, and the zero-coverage result returned when nothing matches.
    """

    @staticmethod
    def _coverage_entry(coverage_pct: float = 80.0) -> dict[str, Any]:
        """Build a minimal per-function coverage payload at the given percentage."""
        return {
            "summary": {"percent_covered": coverage_pct},
            "executed_lines": [1, 2, 3],
            "missing_lines": [4],
            "executed_branches": [[1, 0]],
            "missing_branches": [[2, 1]],
        }

    def test_exact_match_in_coverage_data(self) -> None:
        """A verbatim key in the primary coverage map is used directly."""
        primary = {"HttpInterface.__init__": self._coverage_entry(90.0)}
        found = CoverageUtils.grab_dependent_function_from_coverage_data("HttpInterface.__init__", primary, {})
        assert found.name == "HttpInterface.__init__"
        assert found.coverage == 90.0

    def test_fallback_exact_match_in_original_data(self) -> None:
        """An exact key is found via the original-coverage-data fallback."""
        original = {
            "files": {"http_api.py": {"functions": {"HttpInterface.__init__": self._coverage_entry(75.0)}}}
        }
        found = CoverageUtils.grab_dependent_function_from_coverage_data("HttpInterface.__init__", {}, original)
        assert found.name == "HttpInterface.__init__"
        assert found.coverage == 75.0

    def test_fallback_suffix_match_in_original_data(self) -> None:
        """Qualified dependent name matches via suffix in original coverage data."""
        original = {
            "files": {"http_api.py": {"functions": {"module.HttpInterface.__init__": self._coverage_entry(60.0)}}}
        }
        found = CoverageUtils.grab_dependent_function_from_coverage_data("HttpInterface.__init__", {}, original)
        assert found.name == "HttpInterface.__init__"
        assert found.coverage == 60.0

    def test_no_false_substring_match_bare_init(self) -> None:
        """Bare __init__ should NOT match PathAwareCORSMiddleware.__init__ via substring."""
        original = {
            "files": {"cors.py": {"functions": {"PathAwareCORSMiddleware.__init__": self._coverage_entry(50.0)}}}
        }
        found = CoverageUtils.grab_dependent_function_from_coverage_data("__init__", {}, original)
        assert found.coverage == 0

    def test_no_false_substring_match_different_class(self) -> None:
        """Qualified name for one class should not match another class's method."""
        functions = {
            "PathAwareCORSMiddleware.__init__": self._coverage_entry(50.0),
            "HttpInterface.__init__": self._coverage_entry(85.0),
        }
        original = {"files": {"api.py": {"functions": functions}}}
        found = CoverageUtils.grab_dependent_function_from_coverage_data("HttpInterface.__init__", {}, original)
        assert found.name == "HttpInterface.__init__"
        assert found.coverage == 85.0

    def test_no_match_returns_zero_coverage(self) -> None:
        """An unknown name yields an empty, zero-coverage result."""
        found = CoverageUtils.grab_dependent_function_from_coverage_data("nonexistent_func", {}, {"files": {}})
        assert found.coverage == 0
        assert found.executed_lines == []

    def test_qualified_suffix_no_match_for_partial_name(self) -> None:
        """Ensure suffix match requires a dot boundary, not just string suffix."""
        original = {
            "files": {"api.py": {"functions": {"XHttpInterface.__init__": self._coverage_entry(40.0)}}}
        }
        # "HttpInterface.__init__" should NOT match "XHttpInterface.__init__" via suffix
        found = CoverageUtils.grab_dependent_function_from_coverage_data("HttpInterface.__init__", {}, original)
        assert found.coverage == 0

    def test_bare_name_exact_match_in_fallback(self) -> None:
        """Bare function name should still work with exact match in fallback."""
        original = {"files": {"utils.py": {"functions": {"helper_func": self._coverage_entry(95.0)}}}}
        found = CoverageUtils.grab_dependent_function_from_coverage_data("helper_func", {}, original)
        assert found.name == "helper_func"
        assert found.coverage == 95.0

File diff suppressed because it is too large Load diff

View file

@ -8,7 +8,7 @@ to actual file paths, enabling multi-file context extraction.
import pytest
from codeflash.languages.javascript.import_resolver import HelperSearchContext, ImportResolver, MultiFileHelperFinder
from codeflash.languages.treesitter_utils import ImportInfo
from codeflash.languages.javascript.treesitter import ImportInfo
class TestImportResolver:
@ -286,7 +286,7 @@ class TestExportInfo:
@pytest.fixture
def js_analyzer(self):
"""Create a JavaScript analyzer."""
from codeflash.languages.treesitter_utils import TreeSitterAnalyzer, TreeSitterLanguage
from codeflash.languages.javascript.treesitter import TreeSitterAnalyzer, TreeSitterLanguage
return TreeSitterAnalyzer(TreeSitterLanguage.JAVASCRIPT)
@ -388,7 +388,7 @@ class TestCommonJSRequire:
@pytest.fixture
def js_analyzer(self):
"""Create a JavaScript analyzer."""
from codeflash.languages.treesitter_utils import TreeSitterAnalyzer, TreeSitterLanguage
from codeflash.languages.javascript.treesitter import TreeSitterAnalyzer, TreeSitterLanguage
return TreeSitterAnalyzer(TreeSitterLanguage.JAVASCRIPT)
@ -470,14 +470,14 @@ class TestCommonJSExports:
@pytest.fixture
def js_analyzer(self):
"""Create a JavaScript analyzer."""
from codeflash.languages.treesitter_utils import TreeSitterAnalyzer, TreeSitterLanguage
from codeflash.languages.javascript.treesitter import TreeSitterAnalyzer, TreeSitterLanguage
return TreeSitterAnalyzer(TreeSitterLanguage.JAVASCRIPT)
@pytest.fixture
def ts_analyzer(self):
"""Create a TypeScript analyzer."""
from codeflash.languages.treesitter_utils import TreeSitterAnalyzer, TreeSitterLanguage
from codeflash.languages.javascript.treesitter import TreeSitterAnalyzer, TreeSitterLanguage
return TreeSitterAnalyzer(TreeSitterLanguage.TYPESCRIPT)

View file

@ -654,7 +654,7 @@ describe('Math functions', () => {
file_path = Path(f.name)
source = file_path.read_text()
from codeflash.languages.treesitter_utils import get_analyzer_for_file
from codeflash.languages.javascript.treesitter import get_analyzer_for_file
analyzer = get_analyzer_for_file(file_path)
test_names = js_support._find_jest_tests(source, analyzer)

View file

@ -627,7 +627,7 @@ it('third test', () => {});
file_path = Path(f.name)
source = file_path.read_text()
from codeflash.languages.treesitter_utils import get_analyzer_for_file
from codeflash.languages.javascript.treesitter import get_analyzer_for_file
analyzer = get_analyzer_for_file(file_path)
test_names = js_support._find_jest_tests(source, analyzer)
@ -652,7 +652,7 @@ describe('Suite B', () => {
file_path = Path(f.name)
source = file_path.read_text()
from codeflash.languages.treesitter_utils import get_analyzer_for_file
from codeflash.languages.javascript.treesitter import get_analyzer_for_file
analyzer = get_analyzer_for_file(file_path)
test_names = js_support._find_jest_tests(source, analyzer)
@ -676,7 +676,7 @@ describe('Outer', () => {
file_path = Path(f.name)
source = file_path.read_text()
from codeflash.languages.treesitter_utils import get_analyzer_for_file
from codeflash.languages.javascript.treesitter import get_analyzer_for_file
analyzer = get_analyzer_for_file(file_path)
test_names = js_support._find_jest_tests(source, analyzer)
@ -700,7 +700,7 @@ describe.skip('skipped describe', () => {
file_path = Path(f.name)
source = file_path.read_text()
from codeflash.languages.treesitter_utils import get_analyzer_for_file
from codeflash.languages.javascript.treesitter import get_analyzer_for_file
analyzer = get_analyzer_for_file(file_path)
test_names = js_support._find_jest_tests(source, analyzer)
@ -721,7 +721,7 @@ describe.only('only describe', () => {
file_path = Path(f.name)
source = file_path.read_text()
from codeflash.languages.treesitter_utils import get_analyzer_for_file
from codeflash.languages.javascript.treesitter import get_analyzer_for_file
analyzer = get_analyzer_for_file(file_path)
test_names = js_support._find_jest_tests(source, analyzer)
@ -739,7 +739,7 @@ describe('describe single', () => {});
file_path = Path(f.name)
source = file_path.read_text()
from codeflash.languages.treesitter_utils import get_analyzer_for_file
from codeflash.languages.javascript.treesitter import get_analyzer_for_file
analyzer = get_analyzer_for_file(file_path)
test_names = js_support._find_jest_tests(source, analyzer)
@ -758,7 +758,7 @@ describe("describe double", () => {});
file_path = Path(f.name)
source = file_path.read_text()
from codeflash.languages.treesitter_utils import get_analyzer_for_file
from codeflash.languages.javascript.treesitter import get_analyzer_for_file
analyzer = get_analyzer_for_file(file_path)
test_names = js_support._find_jest_tests(source, analyzer)
@ -774,7 +774,7 @@ describe("describe double", () => {});
file_path = Path(f.name)
source = file_path.read_text()
from codeflash.languages.treesitter_utils import get_analyzer_for_file
from codeflash.languages.javascript.treesitter import get_analyzer_for_file
analyzer = get_analyzer_for_file(file_path)
test_names = js_support._find_jest_tests(source, analyzer)
@ -1020,7 +1020,7 @@ describe('日本語テスト', () => {
file_path = Path(f.name)
source = file_path.read_text(encoding="utf-8")
from codeflash.languages.treesitter_utils import get_analyzer_for_file
from codeflash.languages.javascript.treesitter import get_analyzer_for_file
analyzer = get_analyzer_for_file(file_path)
test_names = js_support._find_jest_tests(source, analyzer)
@ -1048,7 +1048,7 @@ test.each([
file_path = Path(f.name)
source = file_path.read_text()
from codeflash.languages.treesitter_utils import get_analyzer_for_file
from codeflash.languages.javascript.treesitter import get_analyzer_for_file
analyzer = get_analyzer_for_file(file_path)
test_names = js_support._find_jest_tests(source, analyzer)
@ -1074,7 +1074,7 @@ describe.each([
file_path = Path(f.name)
source = file_path.read_text()
from codeflash.languages.treesitter_utils import get_analyzer_for_file
from codeflash.languages.javascript.treesitter import get_analyzer_for_file
analyzer = get_analyzer_for_file(file_path)
test_names = js_support._find_jest_tests(source, analyzer)
@ -1099,7 +1099,7 @@ describe('Math operations', () => {
file_path = Path(f.name)
source = file_path.read_text()
from codeflash.languages.treesitter_utils import get_analyzer_for_file
from codeflash.languages.javascript.treesitter import get_analyzer_for_file
analyzer = get_analyzer_for_file(file_path)
test_names = js_support._find_jest_tests(source, analyzer)
@ -1457,7 +1457,7 @@ testCases.forEach(name => {
file_path = Path(f.name)
source = file_path.read_text()
from codeflash.languages.treesitter_utils import get_analyzer_for_file
from codeflash.languages.javascript.treesitter import get_analyzer_for_file
analyzer = get_analyzer_for_file(file_path)
test_names = js_support._find_jest_tests(source, analyzer)
@ -1485,7 +1485,7 @@ describe('conditional tests', () => {
file_path = Path(f.name)
source = file_path.read_text()
from codeflash.languages.treesitter_utils import get_analyzer_for_file
from codeflash.languages.javascript.treesitter import get_analyzer_for_file
analyzer = get_analyzer_for_file(file_path)
test_names = js_support._find_jest_tests(source, analyzer)
@ -1509,7 +1509,7 @@ test('slow test', () => {
file_path = Path(f.name)
source = file_path.read_text()
from codeflash.languages.treesitter_utils import get_analyzer_for_file
from codeflash.languages.javascript.treesitter import get_analyzer_for_file
analyzer = get_analyzer_for_file(file_path)
test_names = js_support._find_jest_tests(source, analyzer)
@ -1532,7 +1532,7 @@ test.todo('also needs implementation');
file_path = Path(f.name)
source = file_path.read_text()
from codeflash.languages.treesitter_utils import get_analyzer_for_file
from codeflash.languages.javascript.treesitter import get_analyzer_for_file
analyzer = get_analyzer_for_file(file_path)
test_names = js_support._find_jest_tests(source, analyzer)
@ -1555,7 +1555,7 @@ test.concurrent('concurrent test 2', async () => {
file_path = Path(f.name)
source = file_path.read_text()
from codeflash.languages.treesitter_utils import get_analyzer_for_file
from codeflash.languages.javascript.treesitter import get_analyzer_for_file
analyzer = get_analyzer_for_file(file_path)
test_names = js_support._find_jest_tests(source, analyzer)
@ -1654,7 +1654,7 @@ describe('Array', function() {
file_path = Path(f.name)
source = file_path.read_text()
from codeflash.languages.treesitter_utils import get_analyzer_for_file
from codeflash.languages.javascript.treesitter import get_analyzer_for_file
analyzer = get_analyzer_for_file(file_path)
test_names = js_support._find_jest_tests(source, analyzer)
@ -1685,7 +1685,7 @@ describe('User', () => {
file_path = Path(f.name)
source = file_path.read_text()
from codeflash.languages.treesitter_utils import get_analyzer_for_file
from codeflash.languages.javascript.treesitter import get_analyzer_for_file
analyzer = get_analyzer_for_file(file_path)
test_names = js_support._find_jest_tests(source, analyzer)

View file

@ -8,7 +8,7 @@ from pathlib import Path
import pytest
from codeflash.languages.treesitter_utils import TreeSitterAnalyzer, TreeSitterLanguage, get_analyzer_for_file
from codeflash.languages.javascript.treesitter import TreeSitterAnalyzer, TreeSitterLanguage, get_analyzer_for_file
class TestTreeSitterLanguage:
@ -545,3 +545,279 @@ function identity<T>(value: T): T {
assert len(functions) == 1
assert functions[0].name == "identity"
class TestExportConstArrowFunctions:
    """Tests for the `export const name = (...) => ...` pattern (Issue #10).

    Modern TypeScript codebases commonly use:
    - export const slugify = (str: string) => { return s; }
    - export const uniqueBy = <T>(array: T[]) => { ... }
    These must be correctly recognized as optimizable functions.
    """

    @pytest.fixture
    def ts_analyzer(self):
        """A tree-sitter analyzer configured for TypeScript."""
        return TreeSitterAnalyzer(TreeSitterLanguage.TYPESCRIPT)

    def test_export_const_arrow_function_basic(self, ts_analyzer):
        """An exported const arrow function is discovered by name."""
        code = """export const slugify = (str: string) => {
return str.toLowerCase();
};"""
        found = ts_analyzer.find_functions(code)
        assert len(found) == 1
        fn = found[0]
        assert fn.name == "slugify"
        assert fn.is_arrow is True
        assert ts_analyzer.has_return_statement(fn, code) is True

    def test_export_const_arrow_function_optional_param(self, ts_analyzer):
        """Optional parameters do not break detection."""
        code = """export const slugify = (str: string, forDisplayingInput?: boolean) => {
if (!str) {
return "";
}
const s = str.toLowerCase();
return forDisplayingInput ? s : s.replace(/-+$/, "");
};"""
        found = ts_analyzer.find_functions(code)
        assert len(found) == 1
        fn = found[0]
        assert fn.name == "slugify"
        assert fn.is_arrow is True
        assert ts_analyzer.has_return_statement(fn, code) is True

    def test_export_const_generic_arrow_function(self, ts_analyzer):
        """Generic arrow functions are detected; inner arrows may also appear."""
        code = """export const uniqueBy = <T extends { [key: string]: unknown }>(array: T[], keys: (keyof T)[]) => {
return array.filter(
(item, index, self) => index === self.findIndex((t) => keys.every((key) => t[key] === item[key]))
);
};"""
        # Find the first function named uniqueBy, ignoring any inner arrows.
        target = None
        for fn in ts_analyzer.find_functions(code):
            if fn.name == "uniqueBy":
                target = fn
                break
        assert target is not None
        assert target.is_arrow is True
        assert ts_analyzer.has_return_statement(target, code) is True

    def test_export_const_arrow_function_is_exported(self, ts_analyzer):
        """export const arrow functions are reported as exported."""
        code = """export const slugify = (str: string) => {
return str.toLowerCase();
};"""
        exported, export_name = ts_analyzer.is_function_exported(code, "slugify")
        assert exported is True
        assert export_name == "slugify"

    def test_export_const_with_default_export(self, ts_analyzer):
        """A named export followed by a separate default export of the same symbol."""
        code = """export const slugify = (str: string) => {
return str.toLowerCase();
};
export default slugify;"""
        found = ts_analyzer.find_functions(code)
        assert len(found) == 1
        assert found[0].name == "slugify"
        # The symbol must also be visible as a named export.
        named, _named_name = ts_analyzer.is_function_exported(code, "slugify")
        assert named is True

    def test_multiple_export_const_functions(self, ts_analyzer):
        """Several export const arrow functions coexist in one file."""
        code = """export const notUndefined = <T>(val: T | undefined): val is T => Boolean(val);
export const uniqueBy = <T extends { [key: string]: unknown }>(array: T[], keys: (keyof T)[]) => {
return array.filter(
(item, index, self) => index === self.findIndex((t) => keys.every((key) => t[key] === item[key]))
);
};"""
        # Only top-level functions (no enclosing function) are of interest here.
        top_level = [f for f in ts_analyzer.find_functions(code) if f.parent_function is None]
        names = {f.name for f in top_level}
        assert "notUndefined" in names
        assert "uniqueBy" in names

    def test_export_const_arrow_with_implicit_return(self, ts_analyzer):
        """Implicit-return arrow bodies count as having a return."""
        code = """export const double = (n: number) => n * 2;"""
        found = ts_analyzer.find_functions(code)
        assert len(found) == 1
        fn = found[0]
        assert fn.name == "double"
        assert fn.is_arrow is True
        assert ts_analyzer.has_return_statement(fn, code) is True

    def test_export_const_async_arrow_function(self, ts_analyzer):
        """Async arrow functions carry both is_arrow and is_async flags."""
        code = """export const fetchData = async (url: string) => {
const response = await fetch(url);
return response.json();
};"""
        found = ts_analyzer.find_functions(code)
        assert len(found) == 1
        fn = found[0]
        assert fn.name == "fetchData"
        assert fn.is_arrow is True
        assert fn.is_async is True
        assert ts_analyzer.has_return_statement(fn, code) is True

    def test_non_exported_const_not_exported(self, ts_analyzer):
        """Only the exported const is reported as exported."""
        code = """const privateFunc = (x: number) => {
return x * 2;
};
export const publicFunc = (x: number) => {
return privateFunc(x);
};"""
        private_exported, _ = ts_analyzer.is_function_exported(code, "privateFunc")
        assert private_exported is False
        public_exported, name = ts_analyzer.is_function_exported(code, "publicFunc")
        assert public_exported is True
        assert name == "publicFunc"
class TestWrappedDefaultExports:
    """Tests for the wrapped default export pattern (Issue #9).

    Handles patterns like:
    - export default curry(traverseEntity)
    - export default compose(fn1, fn2)
    - export default wrapper(myFunc)
    These must be correctly recognized so the wrapped function is exportable.
    """

    @pytest.fixture
    def ts_analyzer(self):
        """A tree-sitter analyzer configured for TypeScript."""
        return TreeSitterAnalyzer(TreeSitterLanguage.TYPESCRIPT)

    def test_curry_wrapped_export(self, ts_analyzer):
        """export default curry(fn) exposes fn through the default export."""
        code = """import { curry } from 'lodash/fp';
const traverseEntity = async (visitor, options, entity) => {
return entity;
};
export default curry(traverseEntity);"""
        exports = ts_analyzer.find_exports(code)
        assert len(exports) == 1
        export = exports[0]
        assert export.default_export == "default"
        assert export.wrapped_default_args == ["traverseEntity"]
        # The wrapped function is considered exported under "default".
        exported, export_name = ts_analyzer.is_function_exported(code, "traverseEntity")
        assert exported is True
        assert export_name == "default"

    def test_compose_wrapped_export(self, ts_analyzer):
        """export default compose(fn1, fn2) marks every argument as exported."""
        code = """import { compose } from 'lodash/fp';
function validateInput(data) { return data; }
function processData(data) { return data; }
export default compose(validateInput, processData);"""
        exports = ts_analyzer.find_exports(code)
        assert len(exports) == 1
        assert exports[0].wrapped_default_args == ["validateInput", "processData"]
        for wrapped in ("validateInput", "processData"):
            exported, _ = ts_analyzer.is_function_exported(code, wrapped)
            assert exported is True

    def test_nested_wrapper_export(self, ts_analyzer):
        """Nested wrappers: export default compose(curry(fn))."""
        code = """export default compose(curry(myFunc));"""
        exports = ts_analyzer.find_exports(code)
        assert len(exports) == 1
        assert "myFunc" in exports[0].wrapped_default_args
        exported, _ = ts_analyzer.is_function_exported(code, "myFunc")
        assert exported is True

    def test_generic_wrapper_export(self, ts_analyzer):
        """Any identifier-wrapping call counts as a wrapped default export."""
        code = """const myFunction = (x: number) => x * 2;
export default someWrapper(myFunction);"""
        exported, export_name = ts_analyzer.is_function_exported(code, "myFunction")
        assert exported is True
        assert export_name == "default"

    def test_non_wrapped_function_not_exported(self, ts_analyzer):
        """Functions outside the wrapper call are not exported."""
        code = """const helper = (x: number) => x + 1;
const main = (x: number) => helper(x) * 2;
export default curry(main);"""
        # main is the wrapped argument, so it is exported.
        main_exported, _ = ts_analyzer.is_function_exported(code, "main")
        assert main_exported is True
        # helper never appears in the wrapper call, so it is not.
        helper_exported, _ = ts_analyzer.is_function_exported(code, "helper")
        assert helper_exported is False

    def test_direct_default_export_still_works(self, ts_analyzer):
        """A plain `export default fn` is unaffected by wrapper handling."""
        code = """function myFunc() { return 1; }
export default myFunc;"""
        exported, export_name = ts_analyzer.is_function_exported(code, "myFunc")
        assert exported is True
        assert export_name == "default"

    def test_strapi_traverse_entity_pattern(self, ts_analyzer):
        """The exact strapi pattern that was previously misdetected."""
        code = """import { curry } from 'lodash/fp';
const traverseEntity = async (visitor: Visitor, options: TraverseOptions, entity: Data) => {
const { path = { raw: null }, schema, getModel } = options;
// ... implementation
return copy;
};
const createVisitorUtils = ({ data }: { data: Data }) => ({
remove(key: string) { delete data[key]; },
set(key: string, value: Data) { data[key] = value; },
});
export default curry(traverseEntity);"""
        exported, export_name = ts_analyzer.is_function_exported(code, "traverseEntity")
        assert exported is True
        assert export_name == "default"
        # createVisitorUtils is not wrapped, so it is not exported via default.
        utils_exported, _ = ts_analyzer.is_function_exported(code, "createVisitorUtils")
        assert utils_exported is False

790
uv.lock

File diff suppressed because it is too large Load diff