mirror of
https://github.com/codeflash-ai/codeflash.git
synced 2026-05-04 18:25:17 +00:00
Merge branch 'main' of github.com:codeflash-ai/codeflash into fix/js-jest30-loop-runner
This commit is contained in:
commit
e9b7154361
30 changed files with 2469 additions and 419 deletions
2
.github/workflows/claude.yml
vendored
2
.github/workflows/claude.yml
vendored
|
|
@ -48,7 +48,7 @@ jobs:
|
|||
with:
|
||||
use_foundry: "true"
|
||||
use_sticky_comment: true
|
||||
allowed_bots: "claude[bot]"
|
||||
allowed_bots: "claude[bot],codeflash-ai[bot]"
|
||||
prompt: |
|
||||
REPO: ${{ github.repository }}
|
||||
PR NUMBER: ${{ github.event.pull_request.number }}
|
||||
|
|
|
|||
17
CLAUDE.md
17
CLAUDE.md
|
|
@ -27,12 +27,29 @@ uv run ruff format codeflash/ # Format
|
|||
# Linting (run before committing)
|
||||
uv run prek run --from-ref origin/main
|
||||
|
||||
# Mypy type checking (run on changed files before committing)
|
||||
uv run mypy --non-interactive --config-file pyproject.toml <changed_files>
|
||||
|
||||
# Running the CLI
|
||||
uv run codeflash --help
|
||||
uv run codeflash init # Initialize in a project
|
||||
uv run codeflash --all # Optimize entire codebase
|
||||
```
|
||||
|
||||
## Mypy Type Checking
|
||||
|
||||
When modifying code, fix any mypy type errors in the files you changed. Run mypy on changed files:
|
||||
|
||||
```bash
|
||||
uv run mypy --non-interactive --config-file pyproject.toml <changed_files>
|
||||
```
|
||||
|
||||
Rules:
|
||||
- Fix type annotation issues: missing return types, incorrect types, Optional/None unions, import errors for type hints
|
||||
- Do NOT add `# type: ignore` comments — always fix the root cause
|
||||
- Do NOT fix type errors that require logic changes, complex generic type rework, or anything that could change runtime behavior
|
||||
- Files in `mypy_allowlist.txt` are checked in CI — ensure they remain error-free
|
||||
|
||||
<!-- Section below is auto-generated by `tessl install` - do not edit manually -->
|
||||
|
||||
# Agent Rules <!-- tessl-managed -->
|
||||
|
|
|
|||
|
|
@ -386,7 +386,7 @@ class JavaScriptTransformer:
|
|||
|
||||
from pathlib import Path
|
||||
from codeflash.languages.base import LanguageSupport, FunctionInfo, CodeContext
|
||||
from codeflash.languages.treesitter_utils import TreeSitterAnalyzer
|
||||
from codeflash.languages.javascript.treesitter import TreeSitterAnalyzer
|
||||
from codeflash.languages.javascript.transformer import JavaScriptTransformer
|
||||
|
||||
class JavaScriptSupport(LanguageSupport):
|
||||
|
|
@ -523,7 +523,7 @@ class JavaScriptSupport(LanguageSupport):
|
|||
# codeflash/languages/javascript/test_discovery.py
|
||||
|
||||
from pathlib import Path
|
||||
from codeflash.languages.treesitter_utils import TreeSitterAnalyzer
|
||||
from codeflash.languages.javascript.treesitter import TreeSitterAnalyzer
|
||||
|
||||
class JestTestDiscovery:
|
||||
"""Static analysis-based test discovery for Jest."""
|
||||
|
|
|
|||
|
|
@ -1772,7 +1772,7 @@ def _extract_calling_function_js(source_code: str, function_name: str, ref_line:
|
|||
|
||||
"""
|
||||
try:
|
||||
from codeflash.languages.treesitter_utils import TreeSitterAnalyzer, TreeSitterLanguage
|
||||
from codeflash.languages.javascript.treesitter import TreeSitterAnalyzer, TreeSitterLanguage
|
||||
|
||||
# Try TypeScript first, fall back to JavaScript
|
||||
for lang in [TreeSitterLanguage.TYPESCRIPT, TreeSitterLanguage.TSX, TreeSitterLanguage.JAVASCRIPT]:
|
||||
|
|
|
|||
|
|
@ -26,7 +26,7 @@ if TYPE_CHECKING:
|
|||
|
||||
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
|
||||
from codeflash.languages.base import Language, LanguageSupport
|
||||
from codeflash.languages.treesitter_utils import TreeSitterAnalyzer
|
||||
from codeflash.languages.javascript.treesitter import TreeSitterAnalyzer
|
||||
from codeflash.models.models import CodeOptimizationContext, CodeStringsMarkdown, OptimizedCandidate, ValidCode
|
||||
|
||||
ASTNodeT = TypeVar("ASTNodeT", bound=ast.AST)
|
||||
|
|
@ -640,7 +640,7 @@ def _add_global_declarations_for_language(
|
|||
return original_source
|
||||
|
||||
try:
|
||||
from codeflash.languages.treesitter_utils import get_analyzer_for_file
|
||||
from codeflash.languages.javascript.treesitter import get_analyzer_for_file
|
||||
|
||||
analyzer = get_analyzer_for_file(module_abspath)
|
||||
|
||||
|
|
|
|||
|
|
@ -105,12 +105,12 @@ def clean_concolic_tests(test_suite_code: str) -> str:
|
|||
can_parse = False
|
||||
tree = None
|
||||
|
||||
if not can_parse:
|
||||
if not can_parse or tree is None:
|
||||
return AssertCleanup().transform_asserts(test_suite_code)
|
||||
|
||||
for node in ast.walk(tree):
|
||||
if isinstance(node, ast.FunctionDef) and node.name.startswith("test_"):
|
||||
new_body = []
|
||||
new_body: list[ast.stmt] = []
|
||||
for stmt in node.body:
|
||||
if isinstance(stmt, ast.Assert):
|
||||
if isinstance(stmt.test, ast.Compare) and isinstance(stmt.test.left, ast.Call):
|
||||
|
|
|
|||
|
|
@ -13,14 +13,28 @@ if TYPE_CHECKING:
|
|||
def extract_dependent_function(main_function: str, code_context: CodeOptimizationContext) -> str | Literal[False]:
|
||||
"""Extract the single dependent function from the code context excluding the main function."""
|
||||
dependent_functions = set()
|
||||
for code_string in code_context.testgen_context.code_strings:
|
||||
ast_tree = ast.parse(code_string.code)
|
||||
dependent_functions.update(
|
||||
{node.name for node in ast_tree.body if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef))}
|
||||
)
|
||||
|
||||
if main_function in dependent_functions:
|
||||
dependent_functions.discard(main_function)
|
||||
# Compare using bare name since AST extracts bare function names
|
||||
bare_main = main_function.rsplit(".", 1)[-1] if "." in main_function else main_function
|
||||
|
||||
for code_string in code_context.testgen_context.code_strings:
|
||||
# Quick heuristic: skip parsing entirely if there is no 'def' token,
|
||||
# since no function definitions can be present without it.
|
||||
if "def" not in code_string.code:
|
||||
continue
|
||||
|
||||
ast_tree = ast.parse(code_string.code)
|
||||
# Add function names directly, skipping the bare main name.
|
||||
for node in ast_tree.body:
|
||||
if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
|
||||
name = node.name
|
||||
if name == bare_main:
|
||||
continue
|
||||
dependent_functions.add(name)
|
||||
# If more than one dependent function (other than the main) is found,
|
||||
# we can return False early since the final result cannot be a single name.
|
||||
if len(dependent_functions) > 1:
|
||||
return False
|
||||
|
||||
if not dependent_functions:
|
||||
return False
|
||||
|
|
@ -32,6 +46,9 @@ def extract_dependent_function(main_function: str, code_context: CodeOptimizatio
|
|||
|
||||
|
||||
def build_fully_qualified_name(function_name: str, code_context: CodeOptimizationContext) -> str:
|
||||
# If the name is already qualified (contains a dot), return as-is
|
||||
if "." in function_name:
|
||||
return function_name
|
||||
full_name = function_name
|
||||
for obj_name, parents in code_context.preexisting_objects:
|
||||
if obj_name == function_name:
|
||||
|
|
|
|||
|
|
@ -233,7 +233,7 @@ class JavaScriptNormalizer(CodeNormalizer):
|
|||
|
||||
"""
|
||||
try:
|
||||
from codeflash.languages.treesitter_utils import TreeSitterAnalyzer, TreeSitterLanguage
|
||||
from codeflash.languages.javascript.treesitter import TreeSitterAnalyzer, TreeSitterLanguage
|
||||
|
||||
lang_map = {"javascript": TreeSitterLanguage.JAVASCRIPT, "typescript": TreeSitterLanguage.TYPESCRIPT}
|
||||
lang = lang_map.get(self._get_tree_sitter_language(), TreeSitterLanguage.JAVASCRIPT)
|
||||
|
|
|
|||
|
|
@ -201,7 +201,7 @@ def _is_js_ts_function_exported(file_path: Path, function_name: str) -> tuple[bo
|
|||
Tuple of (is_exported, export_name). export_name may be 'default' for default exports.
|
||||
|
||||
"""
|
||||
from codeflash.languages.treesitter_utils import get_analyzer_for_file
|
||||
from codeflash.languages.javascript.treesitter import get_analyzer_for_file
|
||||
|
||||
try:
|
||||
source = file_path.read_text(encoding="utf-8")
|
||||
|
|
|
|||
|
|
@ -26,10 +26,10 @@ class PrComment:
|
|||
|
||||
def to_json(self) -> dict[str, Union[str, int, dict[str, dict[str, int]], list[BenchmarkDetail], None]]:
|
||||
report_table: dict[str, dict[str, int]] = {}
|
||||
for test_type, test_result in self.winning_behavior_test_results.get_test_pass_fail_report_by_type().items():
|
||||
for test_type, counts in self.winning_behavior_test_results.get_test_pass_fail_report_by_type().items():
|
||||
name = test_type.to_name()
|
||||
if name:
|
||||
report_table[name] = test_result
|
||||
report_table[name] = counts
|
||||
|
||||
result: dict[str, Union[str, int, dict[str, dict[str, int]], list[BenchmarkDetail], None]] = {
|
||||
"optimization_explanation": self.optimization_explanation,
|
||||
|
|
|
|||
|
|
@ -23,7 +23,7 @@ if TYPE_CHECKING:
|
|||
from tree_sitter import Node
|
||||
|
||||
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
|
||||
from codeflash.languages.treesitter_utils import ImportInfo, TreeSitterAnalyzer
|
||||
from codeflash.languages.javascript.treesitter import ImportInfo, TreeSitterAnalyzer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -112,7 +112,7 @@ class ReferenceFinder:
|
|||
List of Reference objects describing each call site.
|
||||
|
||||
"""
|
||||
from codeflash.languages.treesitter_utils import get_analyzer_for_file
|
||||
from codeflash.languages.javascript.treesitter import get_analyzer_for_file
|
||||
|
||||
function_name = function_to_optimize.function_name
|
||||
source_file = function_to_optimize.file_path
|
||||
|
|
@ -168,7 +168,7 @@ class ReferenceFinder:
|
|||
if import_info:
|
||||
# Found an import - mark as visited and search for calls
|
||||
context.visited_files.add(file_path)
|
||||
import_name, original_import = import_info
|
||||
import_name, _original_import = import_info
|
||||
file_refs = self._find_references_in_file(
|
||||
file_path, file_code, function_name, import_name, file_analyzer, include_self=True
|
||||
)
|
||||
|
|
@ -213,7 +213,7 @@ class ReferenceFinder:
|
|||
trigger_check = True
|
||||
if import_info:
|
||||
context.visited_files.add(file_path)
|
||||
import_name, original_import = import_info
|
||||
import_name, _original_import = import_info
|
||||
file_refs = self._find_references_in_file(
|
||||
file_path, file_code, reexport_name, import_name, file_analyzer, include_self=True
|
||||
)
|
||||
|
|
@ -404,7 +404,7 @@ class ReferenceFinder:
|
|||
name_node = node.child_by_field_name("name")
|
||||
if name_node:
|
||||
new_current_function = source_bytes[name_node.start_byte : name_node.end_byte].decode("utf8")
|
||||
elif node.type in ("variable_declarator",):
|
||||
elif node.type == "variable_declarator":
|
||||
# Arrow function or function expression assigned to variable
|
||||
name_node = node.child_by_field_name("name")
|
||||
value_node = node.child_by_field_name("value")
|
||||
|
|
@ -719,7 +719,7 @@ class ReferenceFinder:
|
|||
continue
|
||||
|
||||
# Create a fake ImportInfo to resolve the re-export source
|
||||
from codeflash.languages.treesitter_utils import ImportInfo
|
||||
from codeflash.languages.javascript.treesitter import ImportInfo
|
||||
|
||||
fake_import = ImportInfo(
|
||||
module_path=exp.reexport_source,
|
||||
|
|
|
|||
|
|
@ -14,7 +14,7 @@ from typing import TYPE_CHECKING
|
|||
if TYPE_CHECKING:
|
||||
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
|
||||
from codeflash.languages.base import HelperFunction
|
||||
from codeflash.languages.treesitter_utils import ImportInfo, TreeSitterAnalyzer
|
||||
from codeflash.languages.javascript.treesitter import ImportInfo, TreeSitterAnalyzer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -486,7 +486,7 @@ class MultiFileHelperFinder:
|
|||
|
||||
"""
|
||||
from codeflash.languages.base import HelperFunction
|
||||
from codeflash.languages.treesitter_utils import get_analyzer_for_file
|
||||
from codeflash.languages.javascript.treesitter import get_analyzer_for_file
|
||||
|
||||
try:
|
||||
source = file_path.read_text(encoding="utf-8")
|
||||
|
|
@ -558,8 +558,8 @@ class MultiFileHelperFinder:
|
|||
|
||||
"""
|
||||
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
|
||||
from codeflash.languages.javascript.treesitter import get_analyzer_for_file
|
||||
from codeflash.languages.registry import get_language_support
|
||||
from codeflash.languages.treesitter_utils import get_analyzer_for_file
|
||||
|
||||
if context.current_depth >= context.max_depth:
|
||||
return {}
|
||||
|
|
|
|||
|
|
@ -861,7 +861,7 @@ def validate_and_fix_import_style(test_code: str, source_file_path: Path, functi
|
|||
Fixed test code with correct import style.
|
||||
|
||||
"""
|
||||
from codeflash.languages.treesitter_utils import get_analyzer_for_file
|
||||
from codeflash.languages.javascript.treesitter import get_analyzer_for_file
|
||||
|
||||
# Read source file to determine export style
|
||||
try:
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ import json
|
|||
import logging
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from codeflash.languages.treesitter_utils import get_analyzer_for_file
|
||||
from codeflash.languages.javascript.treesitter import get_analyzer_for_file
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pathlib import Path
|
||||
|
|
|
|||
|
|
@ -14,15 +14,15 @@ from typing import TYPE_CHECKING, Any
|
|||
|
||||
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
|
||||
from codeflash.languages.base import CodeContext, FunctionFilterCriteria, HelperFunction, Language, TestInfo, TestResult
|
||||
from codeflash.languages.javascript.treesitter import TreeSitterAnalyzer, TreeSitterLanguage, get_analyzer_for_file
|
||||
from codeflash.languages.registry import register_language
|
||||
from codeflash.languages.treesitter_utils import TreeSitterAnalyzer, TreeSitterLanguage, get_analyzer_for_file
|
||||
from codeflash.models.models import FunctionParent
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Sequence
|
||||
|
||||
from codeflash.languages.base import ReferenceInfo
|
||||
from codeflash.languages.treesitter_utils import TypeDefinition
|
||||
from codeflash.languages.javascript.treesitter import TypeDefinition
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
|
|||
|
|
@ -95,6 +95,9 @@ class ExportInfo:
|
|||
reexport_source: str | None # Module path for re-exports
|
||||
start_line: int
|
||||
end_line: int
|
||||
# Functions passed as arguments to wrapper calls in default exports
|
||||
# e.g., export default curry(traverseEntity) -> ["traverseEntity"]
|
||||
wrapped_default_args: list[str] | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
|
|
@ -864,6 +867,7 @@ class TreeSitterAnalyzer:
|
|||
default_export: str | None = None
|
||||
is_reexport = False
|
||||
reexport_source: str | None = None
|
||||
wrapped_default_args: list[str] | None = None
|
||||
|
||||
# Check for re-export source (export { x } from './other')
|
||||
source_node = node.child_by_field_name("source")
|
||||
|
|
@ -883,6 +887,12 @@ class TreeSitterAnalyzer:
|
|||
default_export = self.get_node_text(sibling, source_bytes)
|
||||
elif sibling.type in ("arrow_function", "function_expression", "object", "array"):
|
||||
default_export = "default"
|
||||
elif sibling.type == "call_expression":
|
||||
# Handle wrapped exports: export default curry(traverseEntity)
|
||||
# The default export is the result of the call, but we track
|
||||
# the wrapped function names for export checking
|
||||
default_export = "default"
|
||||
wrapped_default_args = self._extract_call_expression_identifiers(sibling, source_bytes)
|
||||
break
|
||||
|
||||
# Handle named exports: export { a, b as c }
|
||||
|
|
@ -930,8 +940,37 @@ class TreeSitterAnalyzer:
|
|||
reexport_source=reexport_source,
|
||||
start_line=node.start_point[0] + 1,
|
||||
end_line=node.end_point[0] + 1,
|
||||
wrapped_default_args=wrapped_default_args,
|
||||
)
|
||||
|
||||
def _extract_call_expression_identifiers(self, node: Node, source_bytes: bytes) -> list[str]:
|
||||
"""Extract identifier names from arguments of a call expression.
|
||||
|
||||
For patterns like curry(traverseEntity) or compose(fn1, fn2), this extracts
|
||||
the function names passed as arguments: ["traverseEntity"] or ["fn1", "fn2"].
|
||||
|
||||
Args:
|
||||
node: A call_expression node.
|
||||
source_bytes: The source code as bytes.
|
||||
|
||||
Returns:
|
||||
List of identifier names found in the call arguments.
|
||||
|
||||
"""
|
||||
identifiers: list[str] = []
|
||||
|
||||
# Get the arguments node
|
||||
args_node = node.child_by_field_name("arguments")
|
||||
if args_node:
|
||||
for child in args_node.children:
|
||||
if child.type == "identifier":
|
||||
identifiers.append(self.get_node_text(child, source_bytes))
|
||||
# Also handle nested call expressions: compose(curry(fn))
|
||||
elif child.type == "call_expression":
|
||||
identifiers.extend(self._extract_call_expression_identifiers(child, source_bytes))
|
||||
|
||||
return identifiers
|
||||
|
||||
def _extract_commonjs_export(self, node: Node, source_bytes: bytes) -> ExportInfo | None:
|
||||
"""Extract export information from CommonJS module.exports or exports.* patterns.
|
||||
|
||||
|
|
@ -1033,6 +1072,7 @@ class TreeSitterAnalyzer:
|
|||
"""Check if a function is exported and get its export name.
|
||||
|
||||
For class methods, also checks if the containing class is exported.
|
||||
Also handles wrapped exports like: export default curry(traverseEntity)
|
||||
|
||||
Args:
|
||||
source: The source code to analyze.
|
||||
|
|
@ -1058,6 +1098,11 @@ class TreeSitterAnalyzer:
|
|||
if name == function_name:
|
||||
return (True, alias if alias else name)
|
||||
|
||||
# Check wrapped default exports: export default curry(traverseEntity)
|
||||
# The function is exported via wrapper, so it's accessible as "default"
|
||||
if export.wrapped_default_args and function_name in export.wrapped_default_args:
|
||||
return (True, "default")
|
||||
|
||||
# For class methods, check if the containing class is exported
|
||||
if class_name:
|
||||
for export in exports:
|
||||
|
|
@ -2793,7 +2793,7 @@ class FunctionOptimizer:
|
|||
test_config=self.test_cfg,
|
||||
optimization_iteration=optimization_iteration,
|
||||
run_result=run_result,
|
||||
function_name=self.function_to_optimize.function_name,
|
||||
function_name=self.function_to_optimize.qualified_name,
|
||||
source_file=self.function_to_optimize.file_path,
|
||||
code_context=code_context,
|
||||
coverage_database_file=coverage_database_file,
|
||||
|
|
|
|||
|
|
@ -33,7 +33,7 @@ def pytest_split(
|
|||
|
||||
except ImportError:
|
||||
return None, None
|
||||
test_files = set()
|
||||
test_files_set: set[str] = set()
|
||||
|
||||
# Find all test_*.py files recursively in the directory
|
||||
for test_path in test_paths:
|
||||
|
|
@ -42,12 +42,12 @@ def pytest_split(
|
|||
return None, None
|
||||
if _test_path.is_dir():
|
||||
# Find all test files matching the pattern test_*.py
|
||||
test_files.update(map(str, _test_path.rglob("test_*.py")))
|
||||
test_files.update(map(str, _test_path.rglob("*_test.py")))
|
||||
test_files_set.update(map(str, _test_path.rglob("test_*.py")))
|
||||
test_files_set.update(map(str, _test_path.rglob("*_test.py")))
|
||||
elif _test_path.is_file():
|
||||
test_files.add(str(_test_path))
|
||||
test_files_set.add(str(_test_path))
|
||||
|
||||
if not test_files:
|
||||
if not test_files_set:
|
||||
return [[]], None
|
||||
|
||||
# Determine number of splits
|
||||
|
|
@ -55,7 +55,7 @@ def pytest_split(
|
|||
num_splits = os.cpu_count() or 4
|
||||
|
||||
# randomize to increase chances of all splits being balanced
|
||||
test_files = list(test_files)
|
||||
test_files = list(test_files_set)
|
||||
shuffle(test_files)
|
||||
|
||||
# Apply limit if specified
|
||||
|
|
@ -75,7 +75,7 @@ def pytest_split(
|
|||
chunk_size = ceil(total_files / num_splits)
|
||||
|
||||
# Initialize result groups
|
||||
result_groups = [[] for _ in range(num_splits)]
|
||||
result_groups: list[list[str]] = [[] for _ in range(num_splits)]
|
||||
|
||||
# Distribute files across groups
|
||||
for i, test_file in enumerate(test_files):
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ import enum
|
|||
import math
|
||||
import re
|
||||
import types
|
||||
import weakref
|
||||
from collections import ChainMap, OrderedDict, deque
|
||||
from importlib.util import find_spec
|
||||
from typing import Any, Optional
|
||||
|
|
@ -25,6 +26,7 @@ HAS_JAX = find_spec("jax") is not None
|
|||
HAS_XARRAY = find_spec("xarray") is not None
|
||||
HAS_TENSORFLOW = find_spec("tensorflow") is not None
|
||||
HAS_NUMBA = find_spec("numba") is not None
|
||||
HAS_PYARROW = find_spec("pyarrow") is not None
|
||||
|
||||
# Pattern to match pytest temp directories: /tmp/pytest-of-<user>/pytest-<N>/
|
||||
# These paths vary between test runs but are logically equivalent
|
||||
|
|
@ -93,7 +95,7 @@ def _get_wrapped_exception(exc: BaseException) -> Optional[BaseException]: # no
|
|||
return _extract_exception_from_message(str(exc))
|
||||
|
||||
|
||||
def comparator(orig: Any, new: Any, superset_obj=False) -> bool:
|
||||
def comparator(orig: Any, new: Any, superset_obj: bool = False) -> bool:
|
||||
"""Compare two objects for equality recursively. If superset_obj is True, the new object is allowed to have more keys than the original object. However, the existing keys/values must be equivalent."""
|
||||
try:
|
||||
# Handle exceptions specially - before type check to allow wrapper comparison
|
||||
|
|
@ -171,6 +173,17 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool:
|
|||
return True
|
||||
return math.isclose(orig, new)
|
||||
|
||||
# Handle weak references (e.g., found in torch.nn.LSTM/GRU modules)
|
||||
if isinstance(orig, weakref.ref):
|
||||
orig_referent = orig()
|
||||
new_referent = new()
|
||||
# Both dead refs are equal, otherwise compare referents
|
||||
if orig_referent is None and new_referent is None:
|
||||
return True
|
||||
if orig_referent is None or new_referent is None:
|
||||
return False
|
||||
return comparator(orig_referent, new_referent, superset_obj)
|
||||
|
||||
if HAS_JAX:
|
||||
import jax # type: ignore # noqa: PGH003
|
||||
import jax.numpy as jnp # type: ignore # noqa: PGH003
|
||||
|
|
@ -342,13 +355,57 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool:
|
|||
return False
|
||||
return (orig != new).nnz == 0
|
||||
|
||||
if HAS_PYARROW:
|
||||
import pyarrow as pa # type: ignore # noqa: PGH003
|
||||
|
||||
if isinstance(orig, pa.Table):
|
||||
if orig.schema != new.schema:
|
||||
return False
|
||||
if orig.num_rows != new.num_rows:
|
||||
return False
|
||||
return bool(orig.equals(new))
|
||||
|
||||
if isinstance(orig, pa.RecordBatch):
|
||||
if orig.schema != new.schema:
|
||||
return False
|
||||
if orig.num_rows != new.num_rows:
|
||||
return False
|
||||
return bool(orig.equals(new))
|
||||
|
||||
if isinstance(orig, pa.ChunkedArray):
|
||||
if orig.type != new.type:
|
||||
return False
|
||||
if len(orig) != len(new):
|
||||
return False
|
||||
return bool(orig.equals(new))
|
||||
|
||||
if isinstance(orig, pa.Array):
|
||||
if orig.type != new.type:
|
||||
return False
|
||||
if len(orig) != len(new):
|
||||
return False
|
||||
return bool(orig.equals(new))
|
||||
|
||||
if isinstance(orig, pa.Scalar):
|
||||
if orig.type != new.type:
|
||||
return False
|
||||
# Handle null scalars
|
||||
if not orig.is_valid and not new.is_valid:
|
||||
return True
|
||||
if not orig.is_valid or not new.is_valid:
|
||||
return False
|
||||
return bool(orig.equals(new))
|
||||
|
||||
if isinstance(orig, (pa.Schema, pa.Field, pa.DataType)):
|
||||
return bool(orig.equals(new))
|
||||
|
||||
if HAS_PANDAS:
|
||||
import pandas # noqa: ICN001
|
||||
|
||||
if isinstance(
|
||||
orig, (pandas.DataFrame, pandas.Series, pandas.Index, pandas.Categorical, pandas.arrays.SparseArray)
|
||||
):
|
||||
return orig.equals(new)
|
||||
return bool(orig.equals(new))
|
||||
|
||||
if isinstance(orig, (pandas.CategoricalDtype, pandas.Interval, pandas.Period)):
|
||||
return orig == new
|
||||
|
|
@ -395,10 +452,10 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool:
|
|||
return orig == new
|
||||
|
||||
if HAS_NUMBA:
|
||||
import numba # type: ignore # noqa: PGH003
|
||||
from numba.core.dispatcher import Dispatcher # type: ignore # noqa: PGH003
|
||||
from numba.typed import Dict as NumbaDict # type: ignore # noqa: PGH003
|
||||
from numba.typed import List as NumbaList # type: ignore # noqa: PGH003
|
||||
import numba
|
||||
from numba.core.dispatcher import Dispatcher
|
||||
from numba.typed import Dict as NumbaDict
|
||||
from numba.typed import List as NumbaList
|
||||
|
||||
# Handle numba typed List
|
||||
if isinstance(orig, NumbaList):
|
||||
|
|
|
|||
|
|
@ -353,7 +353,9 @@ class CoverageUtils:
|
|||
for file in files:
|
||||
functions = files[file]["functions"]
|
||||
for function in functions:
|
||||
if dependent_function_name in function:
|
||||
if function == dependent_function_name or (
|
||||
"." in dependent_function_name and function.endswith(f".{dependent_function_name}")
|
||||
):
|
||||
return FunctionCoverage(
|
||||
name=dependent_function_name,
|
||||
coverage=functions[function]["summary"]["percent_covered"],
|
||||
|
|
|
|||
|
|
@ -1,2 +1,2 @@
|
|||
# These version placeholders will be replaced by uv-dynamic-versioning during build.
|
||||
__version__ = "0.20.0"
|
||||
__version__ = "0.20.0.post510.dev0+b8932209"
|
||||
|
|
|
|||
|
|
@ -106,40 +106,63 @@ function installUv() {
|
|||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if git is available
|
||||
*/
|
||||
function hasGit() {
|
||||
try {
|
||||
const result = spawnSync('git', ['--version'], {
|
||||
stdio: 'ignore',
|
||||
shell: true,
|
||||
});
|
||||
return result.status === 0;
|
||||
} catch {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Install codeflash Python CLI using uv tool
|
||||
*
|
||||
* Installation priority:
|
||||
* 1. GitHub main branch (if git available) - gets latest features
|
||||
* 2. PyPI (fallback) - stable release
|
||||
*
|
||||
* We prefer GitHub because it has the latest JS/TS support that may not
|
||||
* be published to PyPI yet. uv handles cloning internally in its cache.
|
||||
*/
|
||||
function installCodeflash(uvBin) {
|
||||
logStep('2/3', 'Installing codeflash Python CLI...');
|
||||
|
||||
const GITHUB_REPO = 'git+https://github.com/codeflash-ai/codeflash.git';
|
||||
|
||||
// Priority 1: Install from GitHub (latest features, requires git)
|
||||
if (hasGit()) {
|
||||
try {
|
||||
execSync(`"${uvBin}" tool install --force --python python3.12 "${GITHUB_REPO}"`, {
|
||||
stdio: 'inherit',
|
||||
shell: true,
|
||||
});
|
||||
logSuccess('codeflash CLI installed from GitHub (latest)');
|
||||
return true;
|
||||
} catch (error) {
|
||||
logWarning(`GitHub installation failed: ${error.message}`);
|
||||
logWarning('Falling back to PyPI...');
|
||||
}
|
||||
} else {
|
||||
logWarning('Git not found, installing from PyPI...');
|
||||
}
|
||||
|
||||
// Priority 2: Install from PyPI (stable release fallback)
|
||||
try {
|
||||
// Use uv tool install to install codeflash in an isolated environment
|
||||
// This avoids conflicts with any existing Python environments
|
||||
execSync(`"${uvBin}" tool install --force --python python3.12 codeflash`, {
|
||||
stdio: 'inherit',
|
||||
shell: true,
|
||||
});
|
||||
logSuccess('codeflash CLI installed successfully');
|
||||
logSuccess('codeflash CLI installed from PyPI');
|
||||
return true;
|
||||
} catch (error) {
|
||||
// If codeflash is not on PyPI yet, try installing from the local package
|
||||
logWarning('codeflash not found on PyPI, trying local installation...');
|
||||
try {
|
||||
// Try installing from the current codeflash repo if we're in development
|
||||
const cliRoot = path.resolve(__dirname, '..', '..', '..');
|
||||
const pyprojectPath = path.join(cliRoot, 'pyproject.toml');
|
||||
|
||||
if (fs.existsSync(pyprojectPath)) {
|
||||
execSync(`"${uvBin}" tool install --force "${cliRoot}"`, {
|
||||
stdio: 'inherit',
|
||||
shell: true,
|
||||
});
|
||||
logSuccess('codeflash CLI installed from local source');
|
||||
return true;
|
||||
}
|
||||
} catch (localError) {
|
||||
logError(`Failed to install codeflash: ${localError.message}`);
|
||||
}
|
||||
logError(`Failed to install codeflash: ${error.message}`);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -46,7 +46,7 @@ dependencies = [
|
|||
"pygls>=2.0.0,<3.0.0",
|
||||
"codeflash-benchmark",
|
||||
"filelock",
|
||||
"pytest-asyncio>=1.2.0",
|
||||
"pytest-asyncio>=0.18.0",
|
||||
]
|
||||
|
||||
[project.urls]
|
||||
|
|
@ -87,6 +87,7 @@ tests = [
|
|||
"jax>=0.4.30",
|
||||
"numpy>=2.0.2",
|
||||
"pandas>=2.3.3",
|
||||
"pyarrow>=15.0.0",
|
||||
"pyrsistent>=0.20.0",
|
||||
"scipy>=1.13.1",
|
||||
"torch>=2.8.0",
|
||||
|
|
|
|||
228
tests/code_utils/test_coverage_utils.py
Normal file
228
tests/code_utils/test_coverage_utils.py
Normal file
|
|
@ -0,0 +1,228 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from codeflash.code_utils.coverage_utils import build_fully_qualified_name, extract_dependent_function
|
||||
from codeflash.models.function_types import FunctionParent
|
||||
from codeflash.models.models import CodeOptimizationContext, CodeString, CodeStringsMarkdown
|
||||
from codeflash.verification.coverage_utils import CoverageUtils
|
||||
|
||||
|
||||
def _make_code_context(
|
||||
preexisting_objects: set[tuple[str, tuple[FunctionParent, ...]]],
|
||||
testgen_code_strings: list[CodeString] | None = None,
|
||||
) -> CodeOptimizationContext:
|
||||
"""Helper to create a minimal CodeOptimizationContext for testing."""
|
||||
return CodeOptimizationContext(
|
||||
testgen_context=CodeStringsMarkdown(code_strings=testgen_code_strings or []),
|
||||
read_writable_code=CodeStringsMarkdown(),
|
||||
helper_functions=[],
|
||||
preexisting_objects=preexisting_objects,
|
||||
)
|
||||
|
||||
|
||||
class TestBuildFullyQualifiedName:
|
||||
def test_bare_name_with_class_parent(self) -> None:
|
||||
ctx = _make_code_context({("__init__", (FunctionParent(name="HttpInterface", type="ClassDef"),))})
|
||||
assert build_fully_qualified_name("__init__", ctx) == "HttpInterface.__init__"
|
||||
|
||||
def test_bare_name_no_parent(self) -> None:
|
||||
ctx = _make_code_context({("helper_func", ())})
|
||||
assert build_fully_qualified_name("helper_func", ctx) == "helper_func"
|
||||
|
||||
def test_already_qualified_name_returned_as_is(self) -> None:
|
||||
"""If name already contains a dot, skip preexisting_objects lookup."""
|
||||
ctx = _make_code_context({("__init__", (FunctionParent(name="WrongClass", type="ClassDef"),))})
|
||||
result = build_fully_qualified_name("HttpInterface.__init__", ctx)
|
||||
assert result == "HttpInterface.__init__"
|
||||
|
||||
def test_bare_name_picks_first_match_from_set(self) -> None:
|
||||
"""With multiple __init__ entries, bare name picks an arbitrary one."""
|
||||
ctx = _make_code_context(
|
||||
{
|
||||
("__init__", (FunctionParent(name="ClassA", type="ClassDef"),)),
|
||||
("__init__", (FunctionParent(name="ClassB", type="ClassDef"),)),
|
||||
}
|
||||
)
|
||||
result = build_fully_qualified_name("__init__", ctx)
|
||||
assert result in {"ClassA.__init__", "ClassB.__init__"}
|
||||
|
||||
def test_qualified_name_avoids_ambiguity(self) -> None:
|
||||
"""Qualified name bypasses preexisting_objects entirely, avoiding ambiguity."""
|
||||
ctx = _make_code_context(
|
||||
{
|
||||
("__init__", (FunctionParent(name="ClassA", type="ClassDef"),)),
|
||||
("__init__", (FunctionParent(name="ClassB", type="ClassDef"),)),
|
||||
}
|
||||
)
|
||||
assert build_fully_qualified_name("ClassB.__init__", ctx) == "ClassB.__init__"
|
||||
|
||||
def test_bare_name_not_in_preexisting_objects(self) -> None:
|
||||
ctx = _make_code_context(set())
|
||||
assert build_fully_qualified_name("some_func", ctx) == "some_func"
|
||||
|
||||
def test_nested_class_parent(self) -> None:
|
||||
"""Bare name under nested class parents gets fully qualified."""
|
||||
ctx = _make_code_context(
|
||||
{("method", (FunctionParent(name="Outer", type="ClassDef"), FunctionParent(name="Inner", type="ClassDef")))}
|
||||
)
|
||||
assert build_fully_qualified_name("method", ctx) == "Inner.Outer.method"
|
||||
|
||||
def test_non_classdef_parent_ignored(self) -> None:
|
||||
"""Only ClassDef parents are prepended to the name."""
|
||||
ctx = _make_code_context({("helper", (FunctionParent(name="wrapper", type="FunctionDef"),))})
|
||||
assert build_fully_qualified_name("helper", ctx) == "helper"
|
||||
|
||||
|
||||
class TestExtractDependentFunction:
|
||||
def test_single_dependent_function(self) -> None:
|
||||
ctx = _make_code_context(
|
||||
preexisting_objects={("helper", ())},
|
||||
testgen_code_strings=[CodeString(code="def main_func(): pass\ndef helper(): pass")],
|
||||
)
|
||||
result = extract_dependent_function("main_func", ctx)
|
||||
assert result == "helper"
|
||||
|
||||
def test_qualified_main_function_discards_bare_match(self) -> None:
|
||||
"""Qualified main_function should still discard the matching bare name."""
|
||||
ctx = _make_code_context(
|
||||
preexisting_objects={("helper", ())},
|
||||
testgen_code_strings=[CodeString(code="def __init__(): pass\ndef helper(): pass")],
|
||||
)
|
||||
result = extract_dependent_function("HttpInterface.__init__", ctx)
|
||||
assert result == "helper"
|
||||
|
||||
def test_bare_main_function_discards_match(self) -> None:
|
||||
"""Bare main_function should still work for discarding."""
|
||||
ctx = _make_code_context(
|
||||
preexisting_objects={("helper", ())},
|
||||
testgen_code_strings=[CodeString(code="def main_func(): pass\ndef helper(): pass")],
|
||||
)
|
||||
result = extract_dependent_function("main_func", ctx)
|
||||
assert result == "helper"
|
||||
|
||||
def test_no_dependent_functions(self) -> None:
|
||||
ctx = _make_code_context(preexisting_objects=set(), testgen_code_strings=[CodeString(code="x = 1\n")])
|
||||
result = extract_dependent_function("main_func", ctx)
|
||||
assert result is False
|
||||
|
||||
def test_multiple_dependent_functions_returns_false(self) -> None:
|
||||
ctx = _make_code_context(
|
||||
preexisting_objects=set(),
|
||||
testgen_code_strings=[CodeString(code="def helper_a(): pass\ndef helper_b(): pass")],
|
||||
)
|
||||
result = extract_dependent_function("main_func", ctx)
|
||||
assert result is False
|
||||
|
||||
def test_dependent_function_gets_qualified(self) -> None:
|
||||
"""The dependent function returned should be qualified via build_fully_qualified_name."""
|
||||
ctx = _make_code_context(
|
||||
preexisting_objects={("helper", (FunctionParent(name="MyClass", type="ClassDef"),))},
|
||||
testgen_code_strings=[CodeString(code="def main_func(): pass\ndef helper(): pass")],
|
||||
)
|
||||
result = extract_dependent_function("main_func", ctx)
|
||||
assert result == "MyClass.helper"
|
||||
|
||||
def test_only_main_in_code_returns_false(self) -> None:
|
||||
"""When code only contains the main function, no dependent function exists."""
|
||||
ctx = _make_code_context(
|
||||
preexisting_objects=set(), testgen_code_strings=[CodeString(code="def __init__(): pass")]
|
||||
)
|
||||
result = extract_dependent_function("HttpInterface.__init__", ctx)
|
||||
assert result is False
|
||||
|
||||
def test_async_functions_extracted(self) -> None:
|
||||
"""Async function definitions are also extracted as dependent functions."""
|
||||
ctx = _make_code_context(
|
||||
preexisting_objects={("async_helper", ())},
|
||||
testgen_code_strings=[CodeString(code="def main(): pass\nasync def async_helper(): pass")],
|
||||
)
|
||||
result = extract_dependent_function("main", ctx)
|
||||
assert result == "async_helper"
|
||||
|
||||
|
||||
class TestGrabDependentFunctionFromCoverageData:
|
||||
def _make_func_data(self, coverage_pct: float = 80.0) -> dict[str, Any]:
|
||||
return {
|
||||
"summary": {"percent_covered": coverage_pct},
|
||||
"executed_lines": [1, 2, 3],
|
||||
"missing_lines": [4],
|
||||
"executed_branches": [[1, 0]],
|
||||
"missing_branches": [[2, 1]],
|
||||
}
|
||||
|
||||
def test_exact_match_in_coverage_data(self) -> None:
|
||||
coverage_data = {"HttpInterface.__init__": self._make_func_data(90.0)}
|
||||
result = CoverageUtils.grab_dependent_function_from_coverage_data("HttpInterface.__init__", coverage_data, {})
|
||||
assert result.name == "HttpInterface.__init__"
|
||||
assert result.coverage == 90.0
|
||||
|
||||
def test_fallback_exact_match_in_original_data(self) -> None:
|
||||
original_cov_data = {
|
||||
"files": {"http_api.py": {"functions": {"HttpInterface.__init__": self._make_func_data(75.0)}}}
|
||||
}
|
||||
result = CoverageUtils.grab_dependent_function_from_coverage_data(
|
||||
"HttpInterface.__init__", {}, original_cov_data
|
||||
)
|
||||
assert result.name == "HttpInterface.__init__"
|
||||
assert result.coverage == 75.0
|
||||
|
||||
def test_fallback_suffix_match_in_original_data(self) -> None:
|
||||
"""Qualified dependent name matches via suffix in original coverage data."""
|
||||
original_cov_data = {
|
||||
"files": {"http_api.py": {"functions": {"module.HttpInterface.__init__": self._make_func_data(60.0)}}}
|
||||
}
|
||||
result = CoverageUtils.grab_dependent_function_from_coverage_data(
|
||||
"HttpInterface.__init__", {}, original_cov_data
|
||||
)
|
||||
assert result.name == "HttpInterface.__init__"
|
||||
assert result.coverage == 60.0
|
||||
|
||||
def test_no_false_substring_match_bare_init(self) -> None:
|
||||
"""Bare __init__ should NOT match PathAwareCORSMiddleware.__init__ via substring."""
|
||||
original_cov_data = {
|
||||
"files": {"cors.py": {"functions": {"PathAwareCORSMiddleware.__init__": self._make_func_data(50.0)}}}
|
||||
}
|
||||
result = CoverageUtils.grab_dependent_function_from_coverage_data("__init__", {}, original_cov_data)
|
||||
assert result.coverage == 0
|
||||
|
||||
def test_no_false_substring_match_different_class(self) -> None:
|
||||
"""Qualified name for one class should not match another class's method."""
|
||||
original_cov_data = {
|
||||
"files": {
|
||||
"api.py": {
|
||||
"functions": {
|
||||
"PathAwareCORSMiddleware.__init__": self._make_func_data(50.0),
|
||||
"HttpInterface.__init__": self._make_func_data(85.0),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
result = CoverageUtils.grab_dependent_function_from_coverage_data(
|
||||
"HttpInterface.__init__", {}, original_cov_data
|
||||
)
|
||||
assert result.name == "HttpInterface.__init__"
|
||||
assert result.coverage == 85.0
|
||||
|
||||
def test_no_match_returns_zero_coverage(self) -> None:
|
||||
result = CoverageUtils.grab_dependent_function_from_coverage_data("nonexistent_func", {}, {"files": {}})
|
||||
assert result.coverage == 0
|
||||
assert result.executed_lines == []
|
||||
|
||||
def test_qualified_suffix_no_match_for_partial_name(self) -> None:
|
||||
"""Ensure suffix match requires a dot boundary, not just string suffix."""
|
||||
original_cov_data = {
|
||||
"files": {"api.py": {"functions": {"XHttpInterface.__init__": self._make_func_data(40.0)}}}
|
||||
}
|
||||
# "HttpInterface.__init__" should NOT match "XHttpInterface.__init__" via suffix
|
||||
result = CoverageUtils.grab_dependent_function_from_coverage_data(
|
||||
"HttpInterface.__init__", {}, original_cov_data
|
||||
)
|
||||
assert result.coverage == 0
|
||||
|
||||
def test_bare_name_exact_match_in_fallback(self) -> None:
|
||||
"""Bare function name should still work with exact match in fallback."""
|
||||
original_cov_data = {"files": {"utils.py": {"functions": {"helper_func": self._make_func_data(95.0)}}}}
|
||||
result = CoverageUtils.grab_dependent_function_from_coverage_data("helper_func", {}, original_cov_data)
|
||||
assert result.name == "helper_func"
|
||||
assert result.coverage == 95.0
|
||||
File diff suppressed because it is too large
Load diff
|
|
@ -8,7 +8,7 @@ to actual file paths, enabling multi-file context extraction.
|
|||
import pytest
|
||||
|
||||
from codeflash.languages.javascript.import_resolver import HelperSearchContext, ImportResolver, MultiFileHelperFinder
|
||||
from codeflash.languages.treesitter_utils import ImportInfo
|
||||
from codeflash.languages.javascript.treesitter import ImportInfo
|
||||
|
||||
|
||||
class TestImportResolver:
|
||||
|
|
@ -286,7 +286,7 @@ class TestExportInfo:
|
|||
@pytest.fixture
|
||||
def js_analyzer(self):
|
||||
"""Create a JavaScript analyzer."""
|
||||
from codeflash.languages.treesitter_utils import TreeSitterAnalyzer, TreeSitterLanguage
|
||||
from codeflash.languages.javascript.treesitter import TreeSitterAnalyzer, TreeSitterLanguage
|
||||
|
||||
return TreeSitterAnalyzer(TreeSitterLanguage.JAVASCRIPT)
|
||||
|
||||
|
|
@ -388,7 +388,7 @@ class TestCommonJSRequire:
|
|||
@pytest.fixture
|
||||
def js_analyzer(self):
|
||||
"""Create a JavaScript analyzer."""
|
||||
from codeflash.languages.treesitter_utils import TreeSitterAnalyzer, TreeSitterLanguage
|
||||
from codeflash.languages.javascript.treesitter import TreeSitterAnalyzer, TreeSitterLanguage
|
||||
|
||||
return TreeSitterAnalyzer(TreeSitterLanguage.JAVASCRIPT)
|
||||
|
||||
|
|
@ -470,14 +470,14 @@ class TestCommonJSExports:
|
|||
@pytest.fixture
|
||||
def js_analyzer(self):
|
||||
"""Create a JavaScript analyzer."""
|
||||
from codeflash.languages.treesitter_utils import TreeSitterAnalyzer, TreeSitterLanguage
|
||||
from codeflash.languages.javascript.treesitter import TreeSitterAnalyzer, TreeSitterLanguage
|
||||
|
||||
return TreeSitterAnalyzer(TreeSitterLanguage.JAVASCRIPT)
|
||||
|
||||
@pytest.fixture
|
||||
def ts_analyzer(self):
|
||||
"""Create a TypeScript analyzer."""
|
||||
from codeflash.languages.treesitter_utils import TreeSitterAnalyzer, TreeSitterLanguage
|
||||
from codeflash.languages.javascript.treesitter import TreeSitterAnalyzer, TreeSitterLanguage
|
||||
|
||||
return TreeSitterAnalyzer(TreeSitterLanguage.TYPESCRIPT)
|
||||
|
||||
|
|
|
|||
|
|
@ -654,7 +654,7 @@ describe('Math functions', () => {
|
|||
file_path = Path(f.name)
|
||||
|
||||
source = file_path.read_text()
|
||||
from codeflash.languages.treesitter_utils import get_analyzer_for_file
|
||||
from codeflash.languages.javascript.treesitter import get_analyzer_for_file
|
||||
|
||||
analyzer = get_analyzer_for_file(file_path)
|
||||
test_names = js_support._find_jest_tests(source, analyzer)
|
||||
|
|
|
|||
|
|
@ -627,7 +627,7 @@ it('third test', () => {});
|
|||
file_path = Path(f.name)
|
||||
|
||||
source = file_path.read_text()
|
||||
from codeflash.languages.treesitter_utils import get_analyzer_for_file
|
||||
from codeflash.languages.javascript.treesitter import get_analyzer_for_file
|
||||
|
||||
analyzer = get_analyzer_for_file(file_path)
|
||||
test_names = js_support._find_jest_tests(source, analyzer)
|
||||
|
|
@ -652,7 +652,7 @@ describe('Suite B', () => {
|
|||
file_path = Path(f.name)
|
||||
|
||||
source = file_path.read_text()
|
||||
from codeflash.languages.treesitter_utils import get_analyzer_for_file
|
||||
from codeflash.languages.javascript.treesitter import get_analyzer_for_file
|
||||
|
||||
analyzer = get_analyzer_for_file(file_path)
|
||||
test_names = js_support._find_jest_tests(source, analyzer)
|
||||
|
|
@ -676,7 +676,7 @@ describe('Outer', () => {
|
|||
file_path = Path(f.name)
|
||||
|
||||
source = file_path.read_text()
|
||||
from codeflash.languages.treesitter_utils import get_analyzer_for_file
|
||||
from codeflash.languages.javascript.treesitter import get_analyzer_for_file
|
||||
|
||||
analyzer = get_analyzer_for_file(file_path)
|
||||
test_names = js_support._find_jest_tests(source, analyzer)
|
||||
|
|
@ -700,7 +700,7 @@ describe.skip('skipped describe', () => {
|
|||
file_path = Path(f.name)
|
||||
|
||||
source = file_path.read_text()
|
||||
from codeflash.languages.treesitter_utils import get_analyzer_for_file
|
||||
from codeflash.languages.javascript.treesitter import get_analyzer_for_file
|
||||
|
||||
analyzer = get_analyzer_for_file(file_path)
|
||||
test_names = js_support._find_jest_tests(source, analyzer)
|
||||
|
|
@ -721,7 +721,7 @@ describe.only('only describe', () => {
|
|||
file_path = Path(f.name)
|
||||
|
||||
source = file_path.read_text()
|
||||
from codeflash.languages.treesitter_utils import get_analyzer_for_file
|
||||
from codeflash.languages.javascript.treesitter import get_analyzer_for_file
|
||||
|
||||
analyzer = get_analyzer_for_file(file_path)
|
||||
test_names = js_support._find_jest_tests(source, analyzer)
|
||||
|
|
@ -739,7 +739,7 @@ describe('describe single', () => {});
|
|||
file_path = Path(f.name)
|
||||
|
||||
source = file_path.read_text()
|
||||
from codeflash.languages.treesitter_utils import get_analyzer_for_file
|
||||
from codeflash.languages.javascript.treesitter import get_analyzer_for_file
|
||||
|
||||
analyzer = get_analyzer_for_file(file_path)
|
||||
test_names = js_support._find_jest_tests(source, analyzer)
|
||||
|
|
@ -758,7 +758,7 @@ describe("describe double", () => {});
|
|||
file_path = Path(f.name)
|
||||
|
||||
source = file_path.read_text()
|
||||
from codeflash.languages.treesitter_utils import get_analyzer_for_file
|
||||
from codeflash.languages.javascript.treesitter import get_analyzer_for_file
|
||||
|
||||
analyzer = get_analyzer_for_file(file_path)
|
||||
test_names = js_support._find_jest_tests(source, analyzer)
|
||||
|
|
@ -774,7 +774,7 @@ describe("describe double", () => {});
|
|||
file_path = Path(f.name)
|
||||
|
||||
source = file_path.read_text()
|
||||
from codeflash.languages.treesitter_utils import get_analyzer_for_file
|
||||
from codeflash.languages.javascript.treesitter import get_analyzer_for_file
|
||||
|
||||
analyzer = get_analyzer_for_file(file_path)
|
||||
test_names = js_support._find_jest_tests(source, analyzer)
|
||||
|
|
@ -1020,7 +1020,7 @@ describe('日本語テスト', () => {
|
|||
file_path = Path(f.name)
|
||||
|
||||
source = file_path.read_text(encoding="utf-8")
|
||||
from codeflash.languages.treesitter_utils import get_analyzer_for_file
|
||||
from codeflash.languages.javascript.treesitter import get_analyzer_for_file
|
||||
|
||||
analyzer = get_analyzer_for_file(file_path)
|
||||
test_names = js_support._find_jest_tests(source, analyzer)
|
||||
|
|
@ -1048,7 +1048,7 @@ test.each([
|
|||
file_path = Path(f.name)
|
||||
|
||||
source = file_path.read_text()
|
||||
from codeflash.languages.treesitter_utils import get_analyzer_for_file
|
||||
from codeflash.languages.javascript.treesitter import get_analyzer_for_file
|
||||
|
||||
analyzer = get_analyzer_for_file(file_path)
|
||||
test_names = js_support._find_jest_tests(source, analyzer)
|
||||
|
|
@ -1074,7 +1074,7 @@ describe.each([
|
|||
file_path = Path(f.name)
|
||||
|
||||
source = file_path.read_text()
|
||||
from codeflash.languages.treesitter_utils import get_analyzer_for_file
|
||||
from codeflash.languages.javascript.treesitter import get_analyzer_for_file
|
||||
|
||||
analyzer = get_analyzer_for_file(file_path)
|
||||
test_names = js_support._find_jest_tests(source, analyzer)
|
||||
|
|
@ -1099,7 +1099,7 @@ describe('Math operations', () => {
|
|||
file_path = Path(f.name)
|
||||
|
||||
source = file_path.read_text()
|
||||
from codeflash.languages.treesitter_utils import get_analyzer_for_file
|
||||
from codeflash.languages.javascript.treesitter import get_analyzer_for_file
|
||||
|
||||
analyzer = get_analyzer_for_file(file_path)
|
||||
test_names = js_support._find_jest_tests(source, analyzer)
|
||||
|
|
@ -1457,7 +1457,7 @@ testCases.forEach(name => {
|
|||
file_path = Path(f.name)
|
||||
|
||||
source = file_path.read_text()
|
||||
from codeflash.languages.treesitter_utils import get_analyzer_for_file
|
||||
from codeflash.languages.javascript.treesitter import get_analyzer_for_file
|
||||
|
||||
analyzer = get_analyzer_for_file(file_path)
|
||||
test_names = js_support._find_jest_tests(source, analyzer)
|
||||
|
|
@ -1485,7 +1485,7 @@ describe('conditional tests', () => {
|
|||
file_path = Path(f.name)
|
||||
|
||||
source = file_path.read_text()
|
||||
from codeflash.languages.treesitter_utils import get_analyzer_for_file
|
||||
from codeflash.languages.javascript.treesitter import get_analyzer_for_file
|
||||
|
||||
analyzer = get_analyzer_for_file(file_path)
|
||||
test_names = js_support._find_jest_tests(source, analyzer)
|
||||
|
|
@ -1509,7 +1509,7 @@ test('slow test', () => {
|
|||
file_path = Path(f.name)
|
||||
|
||||
source = file_path.read_text()
|
||||
from codeflash.languages.treesitter_utils import get_analyzer_for_file
|
||||
from codeflash.languages.javascript.treesitter import get_analyzer_for_file
|
||||
|
||||
analyzer = get_analyzer_for_file(file_path)
|
||||
test_names = js_support._find_jest_tests(source, analyzer)
|
||||
|
|
@ -1532,7 +1532,7 @@ test.todo('also needs implementation');
|
|||
file_path = Path(f.name)
|
||||
|
||||
source = file_path.read_text()
|
||||
from codeflash.languages.treesitter_utils import get_analyzer_for_file
|
||||
from codeflash.languages.javascript.treesitter import get_analyzer_for_file
|
||||
|
||||
analyzer = get_analyzer_for_file(file_path)
|
||||
test_names = js_support._find_jest_tests(source, analyzer)
|
||||
|
|
@ -1555,7 +1555,7 @@ test.concurrent('concurrent test 2', async () => {
|
|||
file_path = Path(f.name)
|
||||
|
||||
source = file_path.read_text()
|
||||
from codeflash.languages.treesitter_utils import get_analyzer_for_file
|
||||
from codeflash.languages.javascript.treesitter import get_analyzer_for_file
|
||||
|
||||
analyzer = get_analyzer_for_file(file_path)
|
||||
test_names = js_support._find_jest_tests(source, analyzer)
|
||||
|
|
@ -1654,7 +1654,7 @@ describe('Array', function() {
|
|||
file_path = Path(f.name)
|
||||
|
||||
source = file_path.read_text()
|
||||
from codeflash.languages.treesitter_utils import get_analyzer_for_file
|
||||
from codeflash.languages.javascript.treesitter import get_analyzer_for_file
|
||||
|
||||
analyzer = get_analyzer_for_file(file_path)
|
||||
test_names = js_support._find_jest_tests(source, analyzer)
|
||||
|
|
@ -1685,7 +1685,7 @@ describe('User', () => {
|
|||
file_path = Path(f.name)
|
||||
|
||||
source = file_path.read_text()
|
||||
from codeflash.languages.treesitter_utils import get_analyzer_for_file
|
||||
from codeflash.languages.javascript.treesitter import get_analyzer_for_file
|
||||
|
||||
analyzer = get_analyzer_for_file(file_path)
|
||||
test_names = js_support._find_jest_tests(source, analyzer)
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ from pathlib import Path
|
|||
|
||||
import pytest
|
||||
|
||||
from codeflash.languages.treesitter_utils import TreeSitterAnalyzer, TreeSitterLanguage, get_analyzer_for_file
|
||||
from codeflash.languages.javascript.treesitter import TreeSitterAnalyzer, TreeSitterLanguage, get_analyzer_for_file
|
||||
|
||||
|
||||
class TestTreeSitterLanguage:
|
||||
|
|
@ -545,3 +545,279 @@ function identity<T>(value: T): T {
|
|||
|
||||
assert len(functions) == 1
|
||||
assert functions[0].name == "identity"
|
||||
|
||||
|
||||
class TestExportConstArrowFunctions:
|
||||
"""Tests for export const arrow function pattern - Issue #10.
|
||||
|
||||
Modern TypeScript codebases commonly use:
|
||||
- export const slugify = (str: string) => { return s; }
|
||||
- export const uniqueBy = <T>(array: T[]) => { ... }
|
||||
|
||||
These must be correctly recognized as optimizable functions.
|
||||
"""
|
||||
|
||||
@pytest.fixture
|
||||
def ts_analyzer(self):
|
||||
"""Create a TypeScript analyzer."""
|
||||
return TreeSitterAnalyzer(TreeSitterLanguage.TYPESCRIPT)
|
||||
|
||||
def test_export_const_arrow_function_basic(self, ts_analyzer):
|
||||
"""Test finding export const arrow function (basic pattern)."""
|
||||
code = """export const slugify = (str: string) => {
|
||||
return str.toLowerCase();
|
||||
};"""
|
||||
functions = ts_analyzer.find_functions(code)
|
||||
|
||||
assert len(functions) == 1
|
||||
assert functions[0].name == "slugify"
|
||||
assert functions[0].is_arrow is True
|
||||
assert ts_analyzer.has_return_statement(functions[0], code) is True
|
||||
|
||||
def test_export_const_arrow_function_optional_param(self, ts_analyzer):
|
||||
"""Test finding export const arrow function with optional parameter."""
|
||||
code = """export const slugify = (str: string, forDisplayingInput?: boolean) => {
|
||||
if (!str) {
|
||||
return "";
|
||||
}
|
||||
const s = str.toLowerCase();
|
||||
return forDisplayingInput ? s : s.replace(/-+$/, "");
|
||||
};"""
|
||||
functions = ts_analyzer.find_functions(code)
|
||||
|
||||
assert len(functions) == 1
|
||||
assert functions[0].name == "slugify"
|
||||
assert functions[0].is_arrow is True
|
||||
assert ts_analyzer.has_return_statement(functions[0], code) is True
|
||||
|
||||
def test_export_const_generic_arrow_function(self, ts_analyzer):
|
||||
"""Test finding export const arrow function with generics."""
|
||||
code = """export const uniqueBy = <T extends { [key: string]: unknown }>(array: T[], keys: (keyof T)[]) => {
|
||||
return array.filter(
|
||||
(item, index, self) => index === self.findIndex((t) => keys.every((key) => t[key] === item[key]))
|
||||
);
|
||||
};"""
|
||||
functions = ts_analyzer.find_functions(code)
|
||||
|
||||
# Should find uniqueBy, and possibly the inner arrow functions
|
||||
uniqueBy = next((f for f in functions if f.name == "uniqueBy"), None)
|
||||
assert uniqueBy is not None
|
||||
assert uniqueBy.is_arrow is True
|
||||
assert ts_analyzer.has_return_statement(uniqueBy, code) is True
|
||||
|
||||
def test_export_const_arrow_function_is_exported(self, ts_analyzer):
|
||||
"""Test that export const arrow functions are recognized as exported."""
|
||||
code = """export const slugify = (str: string) => {
|
||||
return str.toLowerCase();
|
||||
};"""
|
||||
|
||||
# Check is_function_exported
|
||||
is_exported, export_name = ts_analyzer.is_function_exported(code, "slugify")
|
||||
assert is_exported is True
|
||||
assert export_name == "slugify"
|
||||
|
||||
def test_export_const_with_default_export(self, ts_analyzer):
|
||||
"""Test export const with separate default export."""
|
||||
code = """export const slugify = (str: string) => {
|
||||
return str.toLowerCase();
|
||||
};
|
||||
|
||||
export default slugify;"""
|
||||
|
||||
functions = ts_analyzer.find_functions(code)
|
||||
assert len(functions) == 1
|
||||
assert functions[0].name == "slugify"
|
||||
|
||||
# Should be exported both ways
|
||||
is_named, named_name = ts_analyzer.is_function_exported(code, "slugify")
|
||||
assert is_named is True
|
||||
|
||||
def test_multiple_export_const_functions(self, ts_analyzer):
|
||||
"""Test multiple export const arrow functions in same file."""
|
||||
code = """export const notUndefined = <T>(val: T | undefined): val is T => Boolean(val);
|
||||
|
||||
export const uniqueBy = <T extends { [key: string]: unknown }>(array: T[], keys: (keyof T)[]) => {
|
||||
return array.filter(
|
||||
(item, index, self) => index === self.findIndex((t) => keys.every((key) => t[key] === item[key]))
|
||||
);
|
||||
};"""
|
||||
|
||||
functions = ts_analyzer.find_functions(code)
|
||||
|
||||
# Find the top-level exported functions
|
||||
names = {f.name for f in functions if f.parent_function is None}
|
||||
assert "notUndefined" in names
|
||||
assert "uniqueBy" in names
|
||||
|
||||
def test_export_const_arrow_with_implicit_return(self, ts_analyzer):
|
||||
"""Test export const arrow function with implicit return."""
|
||||
code = """export const double = (n: number) => n * 2;"""
|
||||
|
||||
functions = ts_analyzer.find_functions(code)
|
||||
|
||||
assert len(functions) == 1
|
||||
assert functions[0].name == "double"
|
||||
assert functions[0].is_arrow is True
|
||||
assert ts_analyzer.has_return_statement(functions[0], code) is True
|
||||
|
||||
def test_export_const_async_arrow_function(self, ts_analyzer):
|
||||
"""Test export const async arrow function."""
|
||||
code = """export const fetchData = async (url: string) => {
|
||||
const response = await fetch(url);
|
||||
return response.json();
|
||||
};"""
|
||||
|
||||
functions = ts_analyzer.find_functions(code)
|
||||
|
||||
assert len(functions) == 1
|
||||
assert functions[0].name == "fetchData"
|
||||
assert functions[0].is_arrow is True
|
||||
assert functions[0].is_async is True
|
||||
assert ts_analyzer.has_return_statement(functions[0], code) is True
|
||||
|
||||
def test_non_exported_const_not_exported(self, ts_analyzer):
|
||||
"""Test that non-exported const functions are not marked as exported."""
|
||||
code = """const privateFunc = (x: number) => {
|
||||
return x * 2;
|
||||
};
|
||||
|
||||
export const publicFunc = (x: number) => {
|
||||
return privateFunc(x);
|
||||
};"""
|
||||
|
||||
# privateFunc should not be exported
|
||||
is_private_exported, _ = ts_analyzer.is_function_exported(code, "privateFunc")
|
||||
assert is_private_exported is False
|
||||
|
||||
# publicFunc should be exported
|
||||
is_public_exported, name = ts_analyzer.is_function_exported(code, "publicFunc")
|
||||
assert is_public_exported is True
|
||||
assert name == "publicFunc"
|
||||
|
||||
|
||||
class TestWrappedDefaultExports:
|
||||
"""Tests for wrapped default export pattern - Issue #9.
|
||||
|
||||
Handles patterns like:
|
||||
- export default curry(traverseEntity)
|
||||
- export default compose(fn1, fn2)
|
||||
- export default wrapper(myFunc)
|
||||
|
||||
These must be correctly recognized so the wrapped function is exportable.
|
||||
"""
|
||||
|
||||
@pytest.fixture
|
||||
def ts_analyzer(self):
|
||||
"""Create a TypeScript analyzer."""
|
||||
return TreeSitterAnalyzer(TreeSitterLanguage.TYPESCRIPT)
|
||||
|
||||
def test_curry_wrapped_export(self, ts_analyzer):
|
||||
"""Test export default curry(fn) pattern."""
|
||||
code = """import { curry } from 'lodash/fp';
|
||||
|
||||
const traverseEntity = async (visitor, options, entity) => {
|
||||
return entity;
|
||||
};
|
||||
|
||||
export default curry(traverseEntity);"""
|
||||
|
||||
# Check exports parsing
|
||||
exports = ts_analyzer.find_exports(code)
|
||||
assert len(exports) == 1
|
||||
assert exports[0].default_export == "default"
|
||||
assert exports[0].wrapped_default_args == ["traverseEntity"]
|
||||
|
||||
# Check is_function_exported
|
||||
is_exported, export_name = ts_analyzer.is_function_exported(code, "traverseEntity")
|
||||
assert is_exported is True
|
||||
assert export_name == "default"
|
||||
|
||||
def test_compose_wrapped_export(self, ts_analyzer):
|
||||
"""Test export default compose(fn1, fn2) pattern with multiple args."""
|
||||
code = """import { compose } from 'lodash/fp';
|
||||
|
||||
function validateInput(data) { return data; }
|
||||
function processData(data) { return data; }
|
||||
|
||||
export default compose(validateInput, processData);"""
|
||||
|
||||
exports = ts_analyzer.find_exports(code)
|
||||
assert len(exports) == 1
|
||||
assert exports[0].wrapped_default_args == ["validateInput", "processData"]
|
||||
|
||||
# Both functions should be recognized as exported
|
||||
is_exported1, _ = ts_analyzer.is_function_exported(code, "validateInput")
|
||||
is_exported2, _ = ts_analyzer.is_function_exported(code, "processData")
|
||||
assert is_exported1 is True
|
||||
assert is_exported2 is True
|
||||
|
||||
def test_nested_wrapper_export(self, ts_analyzer):
|
||||
"""Test nested wrapper: export default compose(curry(fn))."""
|
||||
code = """export default compose(curry(myFunc));"""
|
||||
|
||||
exports = ts_analyzer.find_exports(code)
|
||||
assert len(exports) == 1
|
||||
assert "myFunc" in exports[0].wrapped_default_args
|
||||
|
||||
is_exported, _ = ts_analyzer.is_function_exported(code, "myFunc")
|
||||
assert is_exported is True
|
||||
|
||||
def test_generic_wrapper_export(self, ts_analyzer):
|
||||
"""Test generic wrapper function."""
|
||||
code = """const myFunction = (x: number) => x * 2;
|
||||
|
||||
export default someWrapper(myFunction);"""
|
||||
|
||||
is_exported, export_name = ts_analyzer.is_function_exported(code, "myFunction")
|
||||
assert is_exported is True
|
||||
assert export_name == "default"
|
||||
|
||||
def test_non_wrapped_function_not_exported(self, ts_analyzer):
|
||||
"""Test that functions not in the wrapper call are not exported."""
|
||||
code = """const helper = (x: number) => x + 1;
|
||||
const main = (x: number) => helper(x) * 2;
|
||||
|
||||
export default curry(main);"""
|
||||
|
||||
# main is wrapped, so it's exported
|
||||
is_main_exported, _ = ts_analyzer.is_function_exported(code, "main")
|
||||
assert is_main_exported is True
|
||||
|
||||
# helper is NOT in the wrapper call, so not exported
|
||||
is_helper_exported, _ = ts_analyzer.is_function_exported(code, "helper")
|
||||
assert is_helper_exported is False
|
||||
|
||||
def test_direct_default_export_still_works(self, ts_analyzer):
|
||||
"""Test that direct default exports still work."""
|
||||
code = """function myFunc() { return 1; }
|
||||
export default myFunc;"""
|
||||
|
||||
is_exported, export_name = ts_analyzer.is_function_exported(code, "myFunc")
|
||||
assert is_exported is True
|
||||
assert export_name == "default"
|
||||
|
||||
def test_strapi_traverse_entity_pattern(self, ts_analyzer):
|
||||
"""Test the exact strapi pattern that was failing."""
|
||||
code = """import { curry } from 'lodash/fp';
|
||||
|
||||
const traverseEntity = async (visitor: Visitor, options: TraverseOptions, entity: Data) => {
|
||||
const { path = { raw: null }, schema, getModel } = options;
|
||||
// ... implementation
|
||||
return copy;
|
||||
};
|
||||
|
||||
const createVisitorUtils = ({ data }: { data: Data }) => ({
|
||||
remove(key: string) { delete data[key]; },
|
||||
set(key: string, value: Data) { data[key] = value; },
|
||||
});
|
||||
|
||||
export default curry(traverseEntity);"""
|
||||
|
||||
# traverseEntity should be recognized as exported
|
||||
is_exported, export_name = ts_analyzer.is_function_exported(code, "traverseEntity")
|
||||
assert is_exported is True
|
||||
assert export_name == "default"
|
||||
|
||||
# createVisitorUtils is NOT wrapped, so not exported via default
|
||||
is_utils_exported, _ = ts_analyzer.is_function_exported(code, "createVisitorUtils")
|
||||
assert is_utils_exported is False
|
||||
|
|
|
|||
Loading…
Reference in a new issue