Merge branch 'main' into add_vitest_support_to_js

This commit is contained in:
Saurabh Misra 2026-02-01 12:32:09 -08:00 committed by GitHub
commit 82d9e435ef
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
16 changed files with 2900 additions and 47 deletions

View file

@ -19,16 +19,30 @@ jobs:
(github.event_name == 'issues' && (contains(github.event.issue.body, '@claude') || contains(github.event.issue.title, '@claude')))
runs-on: ubuntu-latest
permissions:
contents: read
pull-requests: read
contents: write
pull-requests: write
issues: read
id-token: write
actions: read # Required for Claude to read CI results on PRs
steps:
- name: Get PR head ref
id: pr-ref
env:
GH_TOKEN: ${{ github.token }}
run: |
# For issue_comment events, we need to fetch the PR info
if [ "${{ github.event_name }}" = "issue_comment" ]; then
PR_REF=$(gh api repos/${{ github.repository }}/pulls/${{ github.event.issue.number }} --jq '.head.ref')
echo "ref=$PR_REF" >> $GITHUB_OUTPUT
else
echo "ref=${{ github.event.pull_request.head.ref || github.head_ref }}" >> $GITHUB_OUTPUT
fi
- name: Checkout repository
uses: actions/checkout@v4
with:
fetch-depth: 1
fetch-depth: 0
ref: ${{ steps.pr-ref.outputs.ref }}
- name: Run Claude Code
id: claude

View file

@ -26,6 +26,7 @@ from codeflash.code_utils.compat import LF
from codeflash.code_utils.git_utils import get_git_remotes
from codeflash.code_utils.shell_utils import get_shell_rc_path, is_powershell
from codeflash.telemetry.posthog_cf import ph
from rich.prompt import Confirm
class ProjectLanguage(Enum):
@ -208,9 +209,7 @@ def init_js_project(language: ProjectLanguage) -> None:
def should_modify_package_json_config() -> tuple[bool, dict[str, Any] | None]:
"""Check if package.json has valid codeflash config for JS/TS projects."""
from rich.prompt import Confirm
package_json_path = Path.cwd() / "package.json"
package_json_path = Path("package.json")
if not package_json_path.exists():
click.echo("❌ No package.json found. Please run 'npm init' first.")
@ -230,6 +229,10 @@ def should_modify_package_json_config() -> tuple[bool, dict[str, Any] | None]:
if not Path(module_root).is_dir():
return True, None
tests_root = config.get("testsRoot", None)
if tests_root and not Path(tests_root).is_dir():
return True, None
# Config is valid - ask if user wants to reconfigure
return Confirm.ask(
"✅ A valid Codeflash config already exists in package.json. Do you want to re-configure it?",

View file

@ -1563,23 +1563,228 @@ def is_numerical_code(code_string: str, function_name: str | None = None) -> boo
def get_opt_review_metrics(
source_code: str, file_path: Path, qualified_name: str, project_root: Path, tests_root: Path, language: Language
) -> str:
if language != Language.PYTHON:
# TODO: {Claude} handle function refrences for other languages
return ""
"""Get function reference metrics for optimization review.
Uses the LanguageSupport abstraction to find references, supporting both Python and JavaScript/TypeScript.
Args:
source_code: Source code of the file containing the function.
file_path: Path to the file.
qualified_name: Qualified name of the function (e.g., "module.ClassName.method").
project_root: Root of the project.
tests_root: Root of the tests directory.
language: The programming language.
Returns:
Markdown-formatted string with code blocks showing calling functions.
"""
from codeflash.languages.base import FunctionInfo, ParentInfo, ReferenceInfo
from codeflash.languages.registry import get_language_support
start_time = time.perf_counter()
try:
# Get the language support
lang_support = get_language_support(language)
if lang_support is None:
return ""
# Parse qualified name to get function name and class name
qualified_name_split = qualified_name.rsplit(".", maxsplit=1)
if len(qualified_name_split) == 1:
target_function, target_class = qualified_name_split[0], None
function_name, class_name = qualified_name_split[0], None
else:
target_function, target_class = qualified_name_split[1], qualified_name_split[0]
matches = get_fn_references_jedi(
source_code, file_path, project_root, target_function, target_class
) # jedi is not perfect, it doesn't capture aliased references
calling_fns_details = find_occurances(qualified_name, str(file_path), matches, project_root, tests_root)
function_name, class_name = qualified_name_split[1], qualified_name_split[0]
# Create a FunctionInfo for the function
# We don't have full line info here, so we'll use defaults
parents = ()
if class_name:
parents = (ParentInfo(name=class_name, type="ClassDef"),)
func_info = FunctionInfo(
name=function_name,
file_path=file_path,
start_line=1,
end_line=1,
parents=parents,
language=language,
)
# Find references using language support
references = lang_support.find_references(func_info, project_root, tests_root, max_files=500)
if not references:
return ""
# Format references as markdown code blocks
calling_fns_details = _format_references_as_markdown(
references, file_path, project_root, language
)
except Exception as e:
logger.debug(f"Error getting function references: {e}")
calling_fns_details = ""
logger.debug(f"Investigate {e}")
end_time = time.perf_counter()
logger.debug(f"Got function references in {end_time - start_time:.2f} seconds")
return calling_fns_details
def _format_references_as_markdown(
    references: list, file_path: Path, project_root: Path, language: Language
) -> str:
    """Format references as markdown code blocks with calling function code.

    Args:
        references: List of ReferenceInfo objects.
        file_path: Path to the source file (to exclude).
        project_root: Root of the project.
        language: The programming language.

    Returns:
        Markdown-formatted string.

    """
    # Group references by file, skipping import/re-export references that
    # point back at the source file itself (they are not call sites).
    refs_by_file: dict[Path, list] = {}
    for ref in references:
        if ref.file_path == file_path and ref.reference_type in ("import", "reexport"):
            continue
        refs_by_file.setdefault(ref.file_path, []).append(ref)

    fn_call_context = ""
    context_len = 0
    for ref_file, file_refs in refs_by_file.items():
        # Stop accumulating once the review-context budget is exceeded.
        if context_len > MAX_CONTEXT_LEN_REVIEW:
            break
        try:
            path_relative = ref_file.relative_to(project_root)
        except ValueError:
            # File lives outside the project root; skip it.
            continue

        # Pick the syntax-highlighting hint for the fenced code block.
        ext = ref_file.suffix.lstrip(".")
        if language == Language.PYTHON:
            lang_hint = "python"
        elif ext in ("ts", "tsx"):
            lang_hint = "typescript"
        else:
            lang_hint = "javascript"

        # Read the file to extract calling function context.
        try:
            file_content = ref_file.read_text(encoding="utf-8")
            lines = file_content.splitlines()
        except Exception:
            continue

        # Collect each unique caller function at most once per file.
        callers_seen: set[str] = set()
        caller_contexts: list[str] = []
        for ref in file_refs:
            caller = ref.caller_function or "<module>"
            if caller in callers_seen:
                continue
            callers_seen.add(caller)
            if ref.caller_function:
                # Try to extract the full calling function.
                func_code = _extract_calling_function(file_content, ref.caller_function, ref.line, language)
                if func_code:
                    caller_contexts.append(func_code)
                    context_len += len(func_code)
            else:
                # Module-level call - show a few lines of context.
                start_line = max(0, ref.line - 3)
                end_line = min(len(lines), ref.line + 2)
                context_code = "\n".join(lines[start_line:end_line])
                caller_contexts.append(context_code)
                context_len += len(context_code)

        if caller_contexts:
            # Bug fix: use the detected language hint instead of a hardcoded
            # "python" fence so JS/TS snippets get correct syntax highlighting.
            fn_call_context += f"```{lang_hint}:{path_relative}\n"
            fn_call_context += "\n".join(caller_contexts)
            fn_call_context += "\n```\n"

    return fn_call_context
def _extract_calling_function(source_code: str, function_name: str, ref_line: int, language: Language) -> str | None:
    """Extract the source code of a calling function.

    Dispatches to the Python or JavaScript/TypeScript extractor depending
    on the language of the file being analyzed.

    Args:
        source_code: Full source code of the file.
        function_name: Name of the function to extract.
        ref_line: Line number where the reference is.
        language: The programming language.

    Returns:
        Source code of the function, or None if not found.

    """
    extractor = (
        _extract_calling_function_python if language == Language.PYTHON else _extract_calling_function_js
    )
    return extractor(source_code, function_name, ref_line)
def _extract_calling_function_python(source_code: str, function_name: str, ref_line: int) -> str | None:
"""Extract the source code of a calling function in Python."""
try:
import ast
tree = ast.parse(source_code)
lines = source_code.splitlines()
for node in ast.walk(tree):
if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
if node.name == function_name:
# Check if the reference line is within this function
start_line = node.lineno
end_line = node.end_lineno or start_line
if start_line <= ref_line <= end_line:
return "\n".join(lines[start_line - 1 : end_line])
return None
except Exception:
return None
def _extract_calling_function_js(source_code: str, function_name: str, ref_line: int) -> str | None:
    """Extract the source code of a calling function in JavaScript/TypeScript.

    Args:
        source_code: Full source code of the file.
        function_name: Name of the function to extract.
        ref_line: Line number where the reference is (helps identify the right function).

    Returns:
        Source code of the function, or None if not found.

    """
    try:
        from codeflash.languages.treesitter_utils import TreeSitterAnalyzer, TreeSitterLanguage
    except Exception:
        return None
    # Try the TypeScript grammars first, then fall back to plain JavaScript.
    for grammar in (TreeSitterLanguage.TYPESCRIPT, TreeSitterLanguage.TSX, TreeSitterLanguage.JAVASCRIPT):
        try:
            parser = TreeSitterAnalyzer(grammar)
            for fn in parser.find_functions(source_code, include_methods=True):
                if fn.name == function_name:
                    # The reference line must fall inside the function's span.
                    if fn.start_line <= ref_line <= fn.end_line:
                        return fn.source_text
                    # Name matched but span did not; try the next grammar.
                    break
        except Exception:
            continue
    return None

View file

@ -251,6 +251,9 @@ def parse_package_json_config(package_json_path: Path) -> tuple[dict[str, Any],
detected_module_root = detect_module_root(project_root, package_data)
config["module_root"] = str((project_root / Path(detected_module_root)).resolve())
if codeflash_config.get("testsRoot"):
config["tests_root"] = str(project_root / Path(codeflash_config["testsRoot"]).resolve())
# Auto-detect test runner
config["test_runner"] = detect_test_runner(project_root, package_data)
# Keep pytest_cmd for backwards compatibility with existing code

View file

@ -13,10 +13,14 @@ from codeflash.cli_cmds.console import logger
from codeflash.code_utils.code_utils import exit_with_message
from codeflash.code_utils.formatter import format_code
from codeflash.code_utils.shell_utils import read_api_key_from_shell_config, save_api_key_to_rc
from codeflash.languages.base import Language
from codeflash.languages.registry import get_language_support_by_common_formatters
from codeflash.lsp.helpers import is_LSP_enabled
def check_formatter_installed(formatter_cmds: list[str], exit_on_failure: bool = True) -> bool:
def check_formatter_installed(
formatter_cmds: list[str], exit_on_failure: bool = True, language: str = "python"
) -> bool:
if not formatter_cmds or formatter_cmds[0] == "disabled":
return True
first_cmd = formatter_cmds[0]
@ -35,10 +39,21 @@ def check_formatter_installed(formatter_cmds: list[str], exit_on_failure: bool =
)
return False
tmp_code = """print("hello world")"""
lang_support = get_language_support_by_common_formatters(formatter_cmds)
if not lang_support:
logger.debug(f"Could not determine language for formatter: {formatter_cmds}")
return True
if lang_support.language == Language.PYTHON:
tmp_code = """print("hello world")"""
elif lang_support.language in (Language.JAVASCRIPT, Language.TYPESCRIPT):
tmp_code = "console.log('hello world');"
else:
return True
try:
with tempfile.TemporaryDirectory() as tmpdir:
tmp_file = Path(tmpdir) / "test_codeflash_formatter.py"
tmp_file = Path(tmpdir) / ("test_codeflash_formatter" + lang_support.default_file_extension)
tmp_file.write_text(tmp_code, encoding="utf-8")
format_code(formatter_cmds, tmp_file, print_status=False, exit_on_failure=False)
return True

View file

@ -13,6 +13,7 @@ from typing import Any, Optional, Union
import isort
from codeflash.cli_cmds.console import console, logger
from codeflash.languages.registry import get_language_support
from codeflash.lsp.helpers import is_LSP_enabled
@ -47,8 +48,9 @@ def apply_formatter_cmds(
raise FileNotFoundError(msg)
file_path = path
lang_support = get_language_support(path)
if test_dir_str:
file_path = Path(test_dir_str) / "temp.py"
file_path = Path(test_dir_str) / ("temp" + lang_support.default_file_extension)
shutil.copy2(path, file_path)
file_token = "$file" # noqa: S105
@ -87,13 +89,14 @@ def get_diff_lines_count(diff_output: str) -> int:
return len(diff_lines)
def format_generated_code(generated_test_source: str, formatter_cmds: list[str]) -> str:
def format_generated_code(generated_test_source: str, formatter_cmds: list[str], language: str = "python") -> str:
formatter_name = formatter_cmds[0].lower() if formatter_cmds else "disabled"
if formatter_name == "disabled": # nothing to do if no formatter provided
return re.sub(r"\n{2,}", "\n\n", generated_test_source)
with tempfile.TemporaryDirectory() as test_dir_str:
# try running formatter, if nothing changes (could be due to formatting failing or no actual formatting needed) return code with 2 or more newlines substituted with 2 newlines
original_temp = Path(test_dir_str) / "original_temp.py"
lang_support = get_language_support(language)
original_temp = Path(test_dir_str) / ("original_temp" + lang_support.default_file_extension)
original_temp.write_text(generated_test_source, encoding="utf8")
_, formatted_code, changed = apply_formatter_cmds(
formatter_cmds, original_temp, test_dir_str, print_status=False, exit_on_failure=False
@ -130,7 +133,8 @@ def format_code(
# we don't count the formatting diff for the optimized function as it should be well-formatted
original_code_without_opfunc = original_code.replace(optimized_code, "")
original_temp = Path(test_dir_str) / "original_temp.py"
lang_support = get_language_support(path)
original_temp = Path(test_dir_str) / ("original_temp" + lang_support.default_file_extension)
original_temp.write_text(original_code_without_opfunc, encoding="utf8")
formatted_temp, formatted_code, changed = apply_formatter_cmds(
@ -160,6 +164,7 @@ def format_code(
_, formatted_code, changed = apply_formatter_cmds(
formatter_cmds, path, test_dir_str=None, print_status=print_status, exit_on_failure=exit_on_failure
)
if not changed:
logger.warning(
f"No changes detected in {path} after formatting, are you sure you have valid formatter commands?"

View file

@ -41,6 +41,8 @@ if TYPE_CHECKING:
from codeflash.models.models import CodeOptimizationContext
from codeflash.verification.verification_utils import TestConfig
import contextlib
from rich.text import Text
_property_id = "property"
@ -616,9 +618,10 @@ def get_all_replay_test_functions(
except Exception as e:
logger.warning(f"Error parsing replay test file {replay_test_file}: {e}")
if not trace_file_path:
if trace_file_path is None:
logger.error("Could not find trace_file_path in replay test files.")
exit_with_message("Could not find trace_file_path in replay test files.")
raise AssertionError("Unreachable") # exit_with_message never returns
if not trace_file_path.exists():
logger.error(f"Trace file not found: {trace_file_path}")
@ -673,7 +676,7 @@ def get_all_replay_test_functions(
if filtered_list:
filtered_valid_functions[file_path] = filtered_list
return filtered_valid_functions, trace_file_path
return dict(filtered_valid_functions), trace_file_path
def is_git_repo(file_path: str) -> bool:
@ -685,11 +688,13 @@ def is_git_repo(file_path: str) -> bool:
@cache
def ignored_submodule_paths(module_root: str) -> list[str]:
def ignored_submodule_paths(module_root: str) -> list[Path]:
if is_git_repo(module_root):
git_repo = git.Repo(module_root, search_parent_directories=True)
try:
return [Path(git_repo.working_tree_dir, submodule.path).resolve() for submodule in git_repo.submodules]
working_dir = git_repo.working_tree_dir
if working_dir is not None:
return [Path(working_dir, submodule.path).resolve() for submodule in git_repo.submodules]
except Exception as e:
logger.warning(f"Error getting submodule paths: {e}")
return []
@ -703,7 +708,7 @@ class TopLevelFunctionOrMethodVisitor(ast.NodeVisitor):
self.class_name = class_name
self.function_name = function_or_method_name
self.is_top_level = False
self.function_has_args = None
self.function_has_args: bool | None = None
self.line_no = line_no
self.is_staticmethod = False
self.is_classmethod = False
@ -817,31 +822,28 @@ def was_function_previously_optimized(
# Check optimization status if repository info is provided
# already_optimized_count = 0
try:
# Check optimization status if repository info is provided
# already_optimized_count = 0
owner = None
repo = None
with contextlib.suppress(git.exc.InvalidGitRepositoryError):
owner, repo = get_repo_owner_and_name()
except git.exc.InvalidGitRepositoryError:
logger.warning("No git repository found")
owner, repo = None, None
pr_number = get_pr_number()
if not owner or not repo or pr_number is None or getattr(args, "no_pr", False):
return False
code_contexts = []
func_hash = code_context.hashing_code_context_hash
# Use a unique path identifier that includes function info
code_contexts.append(
code_contexts = [
{
"file_path": function_to_optimize.file_path,
"file_path": str(function_to_optimize.file_path),
"function_name": function_to_optimize.qualified_name,
"code_hash": func_hash,
}
)
if not code_contexts:
return False
]
try:
result = is_function_being_optimized_again(owner, repo, pr_number, code_contexts)
@ -860,7 +862,7 @@ def filter_functions(
ignore_paths: list[Path],
project_root: Path,
module_root: Path,
previous_checkpoint_functions: dict[Path, dict[str, Any]] | None = None,
previous_checkpoint_functions: dict[str, dict[str, Any]] | None = None,
*,
disable_logs: bool = False,
) -> tuple[dict[Path, list[FunctionToOptimize]], int]:
@ -885,21 +887,49 @@ def filter_functions(
# Normalize paths for case-insensitive comparison on Windows
tests_root_str = os.path.normcase(str(tests_root))
module_root_str = os.path.normcase(str(module_root))
project_root_str = os.path.normcase(str(project_root))
# Check if tests_root overlaps with module_root or project_root
# In this case, we need to use file pattern matching instead of directory matching
tests_root_overlaps_source = tests_root_str in (module_root_str, project_root_str) or module_root_str.startswith(
tests_root_str + os.sep
)
# Test file patterns for when tests_root overlaps with source
test_file_name_patterns = (".test.", ".spec.", "_test.", "_spec.")
test_dir_patterns = (os.sep + "test" + os.sep, os.sep + "tests" + os.sep, os.sep + "__tests__" + os.sep)
def is_test_file(file_path_normalized: str) -> bool:
"""Check if a file is a test file based on patterns."""
if tests_root_overlaps_source:
# Use file pattern matching when tests_root overlaps with source
file_lower = file_path_normalized.lower()
# Check filename patterns (e.g., .test.ts, .spec.ts)
if any(pattern in file_lower for pattern in test_file_name_patterns):
return True
# Check directory patterns, but only within the project root
# to avoid false positives from parent directories
relative_path = file_lower
if project_root_str and file_lower.startswith(project_root_str.lower()):
relative_path = file_lower[len(project_root_str) :]
return any(pattern in relative_path for pattern in test_dir_patterns)
# Use directory-based filtering when tests are in a separate directory
return file_path_normalized.startswith(tests_root_str + os.sep)
# We desperately need Python 3.10+ only support to make this code readable with structural pattern matching
for file_path_path, functions in modified_functions.items():
_functions = functions
file_path = str(file_path_path)
file_path_normalized = os.path.normcase(file_path)
if file_path_normalized.startswith(tests_root_str + os.sep):
if is_test_file(file_path_normalized):
test_functions_removed_count += len(_functions)
continue
if file_path in ignore_paths or any(
if file_path_path in ignore_paths or any(
file_path_normalized.startswith(os.path.normcase(str(ignore_path)) + os.sep) for ignore_path in ignore_paths
):
ignore_paths_removed_count += 1
continue
if file_path in submodule_paths or any(
if file_path_path in submodule_paths or any(
file_path_normalized.startswith(os.path.normcase(str(submodule_path)) + os.sep)
for submodule_path in submodule_paths
):
@ -991,7 +1021,7 @@ def filter_files_optimized(file_path: Path, tests_root: Path, ignore_paths: list
def function_has_return_statement(function_node: FunctionDef | AsyncFunctionDef) -> bool:
# Custom DFS, return True as soon as a Return node is found
stack = [function_node]
stack: list[ast.AST] = [function_node]
while stack:
node = stack.pop()
if isinstance(node, ast.Return):

View file

@ -236,6 +236,37 @@ class FunctionFilterCriteria:
max_lines: int | None = None
@dataclass
class ReferenceInfo:
    """A single reference (call site) to a function.

    Captures where a function is used: the file, the exact source span,
    the line of code containing the use, and the enclosing caller (if any).
    """

    file_path: Path  # File containing the reference.
    line: int  # Start line number (1-indexed).
    column: int  # Start column number (0-indexed).
    end_line: int  # End line number (1-indexed).
    end_column: int  # End column number (0-indexed).
    context: str  # The line of code containing the reference.
    reference_type: str  # One of: "call", "callback", "memoized", "import", "reexport".
    import_name: str | None  # Name used to import the function (may differ from original).
    caller_function: str | None = None  # Enclosing function name, or None for module level.
@runtime_checkable
class LanguageSupport(Protocol):
"""Protocol defining what a language implementation must provide.
@ -278,6 +309,11 @@ class LanguageSupport(Protocol):
"""
...
@property
def default_file_extension(self) -> str:
"""Default file extension for this language."""
...
@property
def test_framework(self) -> str:
"""Primary test framework name.
@ -352,6 +388,29 @@ class LanguageSupport(Protocol):
"""
...
def find_references(
self, function: FunctionInfo, project_root: Path, tests_root: Path | None = None, max_files: int = 500
) -> list[ReferenceInfo]:
"""Find all references (call sites) to a function across the codebase.
This method finds all places where a function is called, including:
- Direct calls
- Callbacks (passed to other functions)
- Memoized versions
- Re-exports
Args:
function: The function to find references for.
project_root: Root of the project to search.
tests_root: Root of tests directory (references in tests are excluded).
max_files: Maximum number of files to search.
Returns:
List of ReferenceInfo objects describing each reference location.
"""
...
# === Code Transformation ===
def replace_function(self, source: str, function: FunctionInfo, new_source: str) -> str:

View file

@ -0,0 +1,861 @@
"""Find references for JavaScript/TypeScript functions.
This module provides functionality to find all references (call sites) of a function
across a JavaScript/TypeScript codebase. Similar to Jedi's find_references for Python,
this uses tree-sitter to parse and analyze code.
Key features:
- Finds all call sites of a function across multiple files
- Handles various import patterns (named, default, namespace, re-exports, aliases)
- Supports both ES modules and CommonJS
- Handles memoized functions, callbacks, and method calls
"""
from __future__ import annotations
import logging
from dataclasses import dataclass, field
from pathlib import Path
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from tree_sitter import Node
from codeflash.languages.treesitter_utils import ExportInfo, ImportInfo, TreeSitterAnalyzer
logger = logging.getLogger(__name__)
@dataclass
class Reference:
    """A single call site / usage of the target function.

    Attributes:
        file_path: File containing the reference.
        line: Line number (1-indexed).
        column: Column number (0-indexed).
        end_line: End line number (1-indexed).
        end_column: End column number (0-indexed).
        context: The line of code containing the reference.
        reference_type: One of "call", "callback", "memoized", "import", "reexport".
        import_name: Name used to import the function (may differ from original).
        caller_function: Name of the enclosing function, or None at module level.
    """

    file_path: Path
    line: int
    column: int
    end_line: int
    end_column: int
    context: str
    reference_type: str
    import_name: str | None
    caller_function: str | None = None
@dataclass
class ExportedFunction:
    """Describes how a function leaves its defining module.

    Attributes:
        function_name: The local function name.
        export_name: The name it is exported under (may differ from the local name).
        is_default: True when the function is the module's default export.
        file_path: The source file that defines the function.
    """

    function_name: str
    export_name: str | None
    is_default: bool
    file_path: Path
@dataclass
class ReferenceSearchContext:
    """Mutable bookkeeping shared across one reference search."""

    # Files already processed, so the search never revisits them.
    visited_files: set[Path] = field(default_factory=set)
    # Hard cap on files to examine; guards against runaway searches.
    max_files: int = 1000
class ReferenceFinder:
"""Finds all references to a function across a JavaScript/TypeScript codebase.
This class provides functionality similar to Jedi's find_references for Python,
but for JavaScript/TypeScript using tree-sitter.
Example usage:
```python
from codeflash.languages.javascript.find_references import ReferenceFinder
finder = ReferenceFinder(project_root=Path("/my/project"))
references = finder.find_references(
function_name="myHelper",
source_file=Path("/my/project/src/utils.ts")
)
for ref in references:
print(f"{ref.file_path}:{ref.line} - {ref.context}")
```
"""
# File extensions to search
EXTENSIONS = (".ts", ".tsx", ".js", ".jsx", ".mjs", ".cjs")
def __init__(self, project_root: Path, exclude_patterns: list[str] | None = None) -> None:
"""Initialize the ReferenceFinder.
Args:
project_root: Root directory of the project to search.
exclude_patterns: Glob patterns of directories/files to exclude.
Defaults to ['node_modules', 'dist', 'build', '.git'].
"""
self.project_root = project_root
self.exclude_patterns = exclude_patterns or ["node_modules", "dist", "build", ".git", "coverage", "__pycache__"]
self._file_cache: dict[Path, str] = {}
def find_references(
self,
function_name: str,
source_file: Path,
include_definition: bool = False,
max_files: int = 1000,
) -> list[Reference]:
"""Find all references to a function across the project.
Args:
function_name: Name of the function to find references for.
source_file: Path to the file where the function is defined.
include_definition: Whether to include the function definition itself.
max_files: Maximum number of files to search (prevents runaway searches).
Returns:
List of Reference objects describing each call site.
"""
from codeflash.languages.treesitter_utils import get_analyzer_for_file
references: list[Reference] = []
context = ReferenceSearchContext(max_files=max_files)
# Step 1: Analyze how the function is exported from its source file
source_code = self._read_file(source_file)
if source_code is None:
logger.warning("Could not read source file: %s", source_file)
return references
analyzer = get_analyzer_for_file(source_file)
exported = self._analyze_exports(function_name, source_file, source_code, analyzer)
if not exported:
logger.debug("Function %s is not exported from %s", function_name, source_file)
# Still search in same file for internal references
same_file_refs = self._find_references_in_file(
source_file, source_code, function_name, None, analyzer, include_self=not include_definition
)
references.extend(same_file_refs)
return references
# Step 2: Find all files that might import from the source file
context.visited_files.add(source_file)
# Track files that re-export our function (we'll search for imports to these too)
reexport_files: list[tuple[Path, str]] = [] # (file_path, export_name)
# Step 3: Search all project files for imports and calls
# We use a separate set for files checked for re-exports to avoid duplicate work
checked_for_reexports: set[Path] = set()
for file_path in self._iter_project_files():
if file_path in context.visited_files:
continue
if len(context.visited_files) >= context.max_files:
logger.warning("Reached max file limit (%d), stopping search", max_files)
break
file_code = self._read_file(file_path)
if file_code is None:
continue
file_analyzer = get_analyzer_for_file(file_path)
# Check if this file imports from the source file
imports = file_analyzer.find_imports(file_code)
import_info = self._find_matching_import(imports, source_file, file_path, exported)
if import_info:
# Found an import - mark as visited and search for calls
context.visited_files.add(file_path)
import_name, original_import = import_info
file_refs = self._find_references_in_file(
file_path, file_code, function_name, import_name, file_analyzer, include_self=True
)
references.extend(file_refs)
# Always check for re-exports (even without direct import match)
# This handles the case where a file re-exports from our source file
if file_path not in checked_for_reexports:
checked_for_reexports.add(file_path)
reexport_refs = self._find_reexports_direct(
file_path, file_code, source_file, exported, file_analyzer
)
references.extend(reexport_refs)
# Track re-export files for later searching
for ref in reexport_refs:
reexport_files.append((file_path, ref.import_name))
# Step 4: Follow re-export chains to find references through re-exports
for reexport_file, reexport_name in reexport_files:
# Create a new ExportedFunction for the re-exported function
reexported = ExportedFunction(
function_name=reexport_name,
export_name=reexport_name,
is_default=False,
file_path=reexport_file,
)
# Search for imports to the re-export file
for file_path in self._iter_project_files():
if file_path in context.visited_files:
continue
if file_path == reexport_file:
continue
if len(context.visited_files) >= context.max_files:
break
file_code = self._read_file(file_path)
if file_code is None:
continue
file_analyzer = get_analyzer_for_file(file_path)
imports = file_analyzer.find_imports(file_code)
# Check if this file imports from the re-export file
import_info = self._find_matching_import(imports, reexport_file, file_path, reexported)
if import_info:
context.visited_files.add(file_path)
import_name, original_import = import_info
file_refs = self._find_references_in_file(
file_path, file_code, reexport_name, import_name, file_analyzer, include_self=True
)
# Avoid duplicates
existing_locs = {(r.file_path, r.line, r.column) for r in references}
for ref in file_refs:
if (ref.file_path, ref.line, ref.column) not in existing_locs:
references.append(ref)
# Step 5: Include references in the same file (internal calls)
if include_definition or not exported:
same_file_refs = self._find_references_in_file(
source_file, source_code, function_name, None, analyzer, include_self=True
)
# Filter out duplicate references
existing_locs = {(r.file_path, r.line, r.column) for r in references}
for ref in same_file_refs:
if (ref.file_path, ref.line, ref.column) not in existing_locs:
references.append(ref)
# Step 6: Deduplicate references (same file, line, column)
seen: set[tuple[Path, int, int]] = set()
unique_refs: list[Reference] = []
for ref in references:
key = (ref.file_path, ref.line, ref.column)
if key not in seen:
seen.add(key)
unique_refs.append(ref)
return unique_refs
def _analyze_exports(
self, function_name: str, file_path: Path, source_code: str, analyzer: TreeSitterAnalyzer
) -> ExportedFunction | None:
"""Analyze how a function is exported from its file.
Args:
function_name: Name of the function to check.
file_path: Path to the source file.
source_code: Source code content.
analyzer: TreeSitterAnalyzer instance.
Returns:
ExportedFunction if the function is exported, None otherwise.
"""
is_exported, export_name = analyzer.is_function_exported(source_code, function_name)
if not is_exported:
return None
return ExportedFunction(
function_name=function_name,
export_name=export_name,
is_default=(export_name == "default"),
file_path=file_path,
)
def _find_matching_import(
    self,
    imports: list[ImportInfo],
    source_file: Path,
    importing_file: Path,
    exported: ExportedFunction,
) -> tuple[str, ImportInfo] | None:
    """Find if any import in a file imports the target function.

    Resolution order for named exports matters: explicit named imports win,
    then a namespace import, then a CommonJS default-import-as-namespace
    fallback.

    Args:
        imports: List of imports in the file.
        source_file: Path to the file containing the function definition.
        importing_file: Path to the file being checked for imports.
        exported: Information about how the function is exported.

    Returns:
        Tuple of (imported_name, ImportInfo) if found, None otherwise.
        ``imported_name`` is the local name used in ``importing_file`` and may
        contain a dot for namespace-style access (e.g. "utils.helper").
    """
    # Imported locally rather than at module level — presumably to avoid an
    # import cycle; confirm before hoisting.
    from codeflash.languages.javascript.import_resolver import ImportResolver

    resolver = ImportResolver(self.project_root)
    for imp in imports:
        # Resolve the import to see if it points to our source file
        resolved = resolver.resolve_import(imp, importing_file)
        if resolved is None:
            continue
        if resolved.file_path != source_file:
            continue
        # This import is from our source file - check if it imports our function
        if exported.is_default:
            # Default export - check default import
            if imp.default_import:
                return (imp.default_import, imp)
            # Also check namespace import (default export reached as ns.default)
            if imp.namespace_import:
                return (f"{imp.namespace_import}.default", imp)
        else:
            # Named export - check named imports
            export_name = exported.export_name or exported.function_name
            for name, alias in imp.named_imports:
                if name == export_name:
                    # Prefer the local alias when present: import { x as y }
                    return (alias if alias else name, imp)
            # Check namespace import
            if imp.namespace_import:
                return (f"{imp.namespace_import}.{export_name}", imp)
            # Handle CommonJS default import used as namespace
            # e.g., const helpers = require('./helpers'); helpers.processConfig()
            # In this case, default_import acts like a namespace
            if imp.default_import and not imp.named_imports:
                return (f"{imp.default_import}.{export_name}", imp)
    return None
def _find_references_in_file(
    self,
    file_path: Path,
    source_code: str,
    function_name: str,
    import_name: str | None,
    analyzer: TreeSitterAnalyzer,
    include_self: bool = True,
) -> list[Reference]:
    """Collect every reference to a function within a single file.

    The function may be known locally under a different name (its import
    alias), so the search uses ``import_name`` when provided and falls back
    to the original ``function_name``.

    Args:
        file_path: Path to the file to search.
        source_code: Source code content.
        function_name: Original function name.
        import_name: Name the function is imported as (may differ).
        analyzer: TreeSitterAnalyzer instance.
        include_self: Whether to include references in the file.
            NOTE(review): currently not consulted in this method.

    Returns:
        List of Reference objects found in this file.
    """
    found: list[Reference] = []
    encoded = source_code.encode("utf8")
    root = analyzer.parse(encoded).root_node
    source_lines = source_code.splitlines()

    # Prefer the local (imported) name; fall back to the original name.
    target = import_name or function_name

    if "." in target:
        # Namespace-style access such as "utils.helper" -> match member calls.
        ns, member = target.split(".", 1)
        self._find_member_calls(root, encoded, source_lines, file_path, ns, member, found, None)
    else:
        # Bare identifier: direct calls plus callback/property/etc. usages.
        self._find_identifier_references(
            root, encoded, source_lines, file_path, target, function_name, found, None
        )
    return found
def _find_identifier_references(
    self,
    node: Node,
    source_bytes: bytes,
    lines: list[str],
    file_path: Path,
    search_name: str,
    original_name: str,
    references: list[Reference],
    current_function: str | None,
) -> None:
    """Recursively find references to an identifier in the AST.

    Walks the tree once, tracking the name of the enclosing function so each
    Reference records its caller context. Fix: removed an unused local import
    of TreeSitterAnalyzer that was never referenced in the body.

    Args:
        node: Current tree-sitter node.
        source_bytes: Source code as bytes.
        lines: Source code split into lines.
        file_path: Path to the file.
        search_name: Name to search for (the local/imported name).
        original_name: Original function name (forwarded to recursion; not
            used for matching).
        references: List to append references to.
        current_function: Name of the containing function (for context).
    """
    # Track current function context for the recursive descent.
    new_current_function = current_function
    if node.type in ("function_declaration", "method_definition"):
        name_node = node.child_by_field_name("name")
        if name_node:
            new_current_function = source_bytes[name_node.start_byte : name_node.end_byte].decode("utf8")
    elif node.type in ("variable_declarator",):
        # Arrow function or function expression assigned to a variable:
        # const foo = () => {...}  /  const foo = function () {...}
        name_node = node.child_by_field_name("name")
        value_node = node.child_by_field_name("value")
        if name_node and value_node and value_node.type in ("arrow_function", "function_expression"):
            new_current_function = source_bytes[name_node.start_byte : name_node.end_byte].decode("utf8")

    # Direct calls: foo(...)
    if node.type == "call_expression":
        func_node = node.child_by_field_name("function")
        if func_node and func_node.type == "identifier":
            name = source_bytes[func_node.start_byte : func_node.end_byte].decode("utf8")
            if name == search_name:
                ref = self._create_reference(
                    file_path, func_node, lines, "call", search_name, current_function
                )
                references.append(ref)
    # Other uses of the bare identifier (callbacks, arguments, properties, ...).
    elif node.type == "identifier":
        name = source_bytes[node.start_byte : node.end_byte].decode("utf8")
        if name == search_name:
            parent = node.parent
            # Classify (or reject) the reference based on its syntactic context.
            ref_type = self._determine_reference_type(node, parent, source_bytes)
            if ref_type:
                ref = self._create_reference(
                    file_path, node, lines, ref_type, search_name, current_function
                )
                references.append(ref)

    # Recurse into children with the (possibly updated) function context.
    # Duplicates produced by visiting a call's identifier child are removed
    # by the caller's dedup pass.
    for child in node.children:
        self._find_identifier_references(
            child, source_bytes, lines, file_path, search_name, original_name, references, new_current_function
        )
def _find_member_calls(
    self,
    node: Node,
    source_bytes: bytes,
    lines: list[str],
    file_path: Path,
    namespace: str,
    member: str,
    references: list[Reference],
    current_function: str | None,
) -> None:
    """Find calls to namespace.member (e.g., utils.helper()).

    Args:
        node: Current tree-sitter node.
        source_bytes: Source code as bytes.
        lines: Source code split into lines.
        file_path: Path to the file.
        namespace: The namespace/object name.
        member: The member/property name.
        references: List to append references to.
        current_function: Name of the containing function.
    """
    # Track current function context
    # NOTE(review): unlike _find_identifier_references, arrow functions
    # assigned to variables are not tracked here — confirm if intentional.
    new_current_function = current_function
    if node.type in ("function_declaration", "method_definition"):
        name_node = node.child_by_field_name("name")
        if name_node:
            new_current_function = source_bytes[name_node.start_byte : name_node.end_byte].decode("utf8")
    # Check for call expressions with member access: namespace.member(...)
    if node.type == "call_expression":
        func_node = node.child_by_field_name("function")
        if func_node and func_node.type == "member_expression":
            obj_node = func_node.child_by_field_name("object")
            prop_node = func_node.child_by_field_name("property")
            if obj_node and prop_node:
                obj_name = source_bytes[obj_node.start_byte : obj_node.end_byte].decode("utf8")
                prop_name = source_bytes[prop_node.start_byte : prop_node.end_byte].decode("utf8")
                if obj_name == namespace and prop_name == member:
                    ref = self._create_reference(
                        file_path, func_node, lines, "call", f"{namespace}.{member}", current_function
                    )
                    references.append(ref)
    # Also check for member expression used as callback
    # (namespace.member passed around without being invoked here).
    elif node.type == "member_expression":
        obj_node = node.child_by_field_name("object")
        prop_node = node.child_by_field_name("property")
        if obj_node and prop_node:
            obj_name = source_bytes[obj_node.start_byte : obj_node.end_byte].decode("utf8")
            prop_name = source_bytes[prop_node.start_byte : prop_node.end_byte].decode("utf8")
            if obj_name == namespace and prop_name == member:
                parent = node.parent
                # Direct calls are recorded by the branch above; only
                # classify non-call usages here to avoid double counting.
                if parent and parent.type != "call_expression":
                    ref_type = self._determine_reference_type(node, parent, source_bytes)
                    if ref_type:
                        ref = self._create_reference(
                            file_path, node, lines, ref_type, f"{namespace}.{member}", current_function
                        )
                        references.append(ref)
    # Recurse into children
    for child in node.children:
        self._find_member_calls(
            child, source_bytes, lines, file_path, namespace, member, references, new_current_function
        )
def _determine_reference_type(self, node: Node, parent: Node | None, source_bytes: bytes) -> str | None:
    """Determine the type of reference based on AST context.

    The checks are order-sensitive: the memoize check on ``call_expression``
    parents runs before the generic "used as the called function" check, so
    a match on memoize/memo/cache in the callee name wins over "call".

    Args:
        node: The identifier node.
        parent: The parent node.
        source_bytes: Source code as bytes.

    Returns:
        Reference type string ("call", "callback", "memoized", "property",
        "reference") or None if this isn't a valid reference (e.g. the
        identifier is part of an import, a definition, or an export).
    """
    if parent is None:
        return None
    # Skip import statements
    if parent.type in ("import_specifier", "import_clause", "named_imports"):
        return None
    # Skip function declarations (the function name itself)
    if parent.type in ("function_declaration", "method_definition"):
        name_node = parent.child_by_field_name("name")
        if name_node and name_node.id == node.id:
            return None
    # Skip variable declarations where this is being defined
    if parent.type == "variable_declarator":
        name_node = parent.child_by_field_name("name")
        if name_node and name_node.id == node.id:
            return None
    # Skip export specifiers
    if parent.type == "export_specifier":
        return None
    # Check if passed as argument (callback or memoized)
    if parent.type == "arguments":
        # Check if grandparent is a memoize call, e.g. memoize(fn)
        grandparent = parent.parent
        if grandparent and grandparent.type == "call_expression":
            func_node = grandparent.child_by_field_name("function")
            if func_node:
                func_name = source_bytes[func_node.start_byte : func_node.end_byte].decode("utf8")
                # Substring heuristic: matches memoize/useMemo/cacheFn etc.
                if any(m in func_name.lower() for m in ["memoize", "memo", "cache"]):
                    return "memoized"
        return "callback"
    # Check if used in array (often callback patterns)
    if parent.type == "array":
        return "callback"
    # Check if passed to memoize/memoization functions (direct call check)
    if parent.type == "call_expression":
        func_node = parent.child_by_field_name("function")
        if func_node:
            func_name = source_bytes[func_node.start_byte : func_node.end_byte].decode("utf8")
            if any(m in func_name.lower() for m in ["memoize", "memo", "cache"]):
                return "memoized"
    # Check if used in a call expression as the function being invoked
    if parent.type == "call_expression":
        func_node = parent.child_by_field_name("function")
        if func_node and func_node.id == node.id:
            return "call"
    # Check if assigned to a property (object literal value)
    if parent.type in ("pair", "property"):
        return "property"
    # Check if part of member expression (method call setup)
    if parent.type == "member_expression":
        obj_node = parent.child_by_field_name("object")
        if obj_node and obj_node.id == node.id:
            # This is the object in obj.method
            return None  # We'll catch the actual call elsewhere
    # Generic reference
    return "reference"
def _create_reference(
    self,
    file_path: Path,
    node: Node,
    lines: list[str],
    ref_type: str,
    import_name: str,
    caller_function: str | None,
) -> Reference:
    """Build a Reference describing where ``node`` sits in the file.

    Args:
        file_path: Path to the file.
        node: The tree-sitter node for the reference.
        lines: Source code lines.
        ref_type: Type of reference (e.g. "call", "callback").
        import_name: Name the function was imported as.
        caller_function: Name of the containing function, if any.

    Returns:
        A populated Reference object.
    """
    start_row, start_col = node.start_point
    end_row, end_col = node.end_point
    # Grab the surrounding line as human-readable context (empty when the
    # reported row lies past the end of the file).
    snippet = lines[start_row] if start_row < len(lines) else ""
    return Reference(
        file_path=file_path,
        line=start_row + 1,  # tree-sitter rows are 0-based; report 1-based
        column=start_col,
        end_line=end_row + 1,
        end_column=end_col,
        context=snippet.strip(),
        reference_type=ref_type,
        import_name=import_name,
        caller_function=caller_function,
    )
def _find_reexports(
    self,
    file_path: Path,
    source_code: str,
    exported: ExportedFunction,
    analyzer: TreeSitterAnalyzer,
    context: ReferenceSearchContext,
) -> list[Reference]:
    """Find re-exports of the function.

    Re-exports look like: export { helper } from './utils'

    Note: this matches by exported name only and does not resolve the
    re-export's source module; the resolved-path variant of this check is
    _find_reexports_direct.

    Args:
        file_path: Path to the file being checked.
        source_code: Source code content.
        exported: Information about the original export.
        analyzer: TreeSitterAnalyzer instance.
        context: Search context.  NOTE(review): currently unused in this
            method — confirm whether it can be dropped from the signature.

    Returns:
        List of Reference objects for re-exports.
    """
    references: list[Reference] = []
    exports = analyzer.find_exports(source_code)
    lines = source_code.splitlines()
    for exp in exports:
        if not exp.is_reexport:
            continue
        # Check if this re-exports our function
        export_name = exported.export_name or exported.function_name
        for name, alias in exp.exported_names:
            if name == export_name:
                # This is a re-export of our function
                # Create a reference with the line info from the export
                # (column information is not tracked for export statements).
                context_line = lines[exp.start_line - 1] if exp.start_line <= len(lines) else ""
                ref = Reference(
                    file_path=file_path,
                    line=exp.start_line,
                    column=0,
                    end_line=exp.end_line,
                    end_column=0,
                    context=context_line.strip(),
                    reference_type="reexport",
                    import_name=alias if alias else name,
                    caller_function=None,
                )
                references.append(ref)
    return references
def _find_reexports_direct(
    self,
    file_path: Path,
    source_code: str,
    source_file: Path,
    exported: ExportedFunction,
    analyzer: TreeSitterAnalyzer,
) -> list[Reference]:
    """Find re-exports that directly reference our source file.

    This method checks whether a file has re-export statements whose module
    specifier resolves to our source file.

    Fixes: the ImportInfo import previously ran inside the per-export loop;
    it and the loop-invariant export name are now hoisted out of the loop.

    Args:
        file_path: Path to the file being checked.
        source_code: Source code content.
        source_file: The original source file we're looking for references to.
        exported: Information about the original export.
        analyzer: TreeSitterAnalyzer instance.

    Returns:
        List of Reference objects for re-exports.
    """
    # Local imports — presumably to avoid an import cycle at module load;
    # confirm before hoisting to file level.
    from codeflash.languages.javascript.import_resolver import ImportResolver
    from codeflash.languages.treesitter_utils import ImportInfo

    references: list[Reference] = []
    exports = analyzer.find_exports(source_code)
    lines = source_code.splitlines()
    resolver = ImportResolver(self.project_root)
    # The name under which the function is visible to importers.
    export_name = exported.export_name or exported.function_name

    for exp in exports:
        if not exp.is_reexport or not exp.reexport_source:
            continue
        # Build a synthetic ImportInfo so the resolver can resolve the
        # re-export's module specifier like a regular import.
        fake_import = ImportInfo(
            module_path=exp.reexport_source,
            default_import=None,
            named_imports=[],
            namespace_import=None,
            is_type_only=False,
            start_line=exp.start_line,
            end_line=exp.end_line,
        )
        resolved = resolver.resolve_import(fake_import, file_path)
        if resolved is None or resolved.file_path != source_file:
            continue
        # This file re-exports from our source file; record each matching name.
        for name, alias in exp.exported_names:
            if name == export_name:
                context_line = lines[exp.start_line - 1] if exp.start_line <= len(lines) else ""
                references.append(
                    Reference(
                        file_path=file_path,
                        line=exp.start_line,
                        column=0,
                        end_line=exp.end_line,
                        end_column=0,
                        context=context_line.strip(),
                        reference_type="reexport",
                        import_name=alias if alias else name,
                        caller_function=None,
                    )
                )
    return references
def _iter_project_files(self) -> list[Path]:
    """Collect all JavaScript/TypeScript files under the project root.

    Returns:
        Paths of every candidate file, with excluded paths filtered out.
    """
    # One pass per supported extension; exclusion patterns are applied inline.
    return [
        candidate
        for ext in self.EXTENSIONS
        for candidate in self.project_root.rglob(f"*{ext}")
        if not self._should_exclude(candidate)
    ]
def _should_exclude(self, file_path: Path) -> bool:
    """Check if a file should be excluded from the search.

    Args:
        file_path: Path to check.

    Returns:
        True if any exclusion pattern occurs in the path string.
    """
    # Patterns are plain substrings (e.g. "node_modules"), not globs.
    as_text = str(file_path)
    return any(pattern in as_text for pattern in self.exclude_patterns)
def _read_file(self, file_path: Path) -> str | None:
    """Read a file's contents, memoizing successful reads.

    Args:
        file_path: Path to the file.

    Returns:
        File contents, or None when the file cannot be read.
    """
    # Only successful reads are cached, so a None here means "not cached".
    cached = self._file_cache.get(file_path)
    if cached is not None:
        return cached
    try:
        text = file_path.read_text(encoding="utf-8")
    except Exception as e:  # best-effort: unreadable files are skipped, not fatal
        logger.debug("Could not read file %s: %s", file_path, e)
        return None
    self._file_cache[file_path] = text
    return text
def find_references(
    function_name: str,
    source_file: Path,
    project_root: Path | None = None,
    max_files: int = 1000,
) -> list[Reference]:
    """Find all references to a function (module-level convenience wrapper).

    Thin wrapper around ReferenceFinder for the common case.

    Args:
        function_name: Name of the function to find references for.
        source_file: Path to the file where the function is defined.
        project_root: Root directory of the project; defaults to the
            directory containing ``source_file``.
        max_files: Maximum number of files to search.

    Returns:
        List of Reference objects describing each call site.

    Example:
        ```python
        from pathlib import Path
        from codeflash.languages.javascript.find_references import find_references

        refs = find_references(
            function_name="myHelper",
            source_file=Path("/my/project/src/utils.ts"),
            project_root=Path("/my/project"),
        )
        for ref in refs:
            print(f"{ref.file_path}:{ref.line}:{ref.column} - {ref.reference_type}")
        ```
    """
    root = source_file.parent if project_root is None else project_root
    return ReferenceFinder(root).find_references(function_name, source_file, max_files=max_files)

View file

@ -19,6 +19,7 @@ from codeflash.languages.base import (
HelperFunction,
Language,
ParentInfo,
ReferenceInfo,
TestInfo,
TestResult,
)
@ -53,6 +54,11 @@ class JavaScriptSupport:
"""File extensions supported by JavaScript."""
return (".js", ".jsx", ".mjs", ".cjs")
@property
def default_file_extension(self) -> str:
"""Default file extension for JavaScript."""
return ".js"
@property
def test_framework(self) -> str:
"""Primary test framework for JavaScript."""
@ -959,6 +965,66 @@ class JavaScriptSupport:
logger.warning("Failed to find helpers for %s: %s", function.name, e)
return []
def find_references(
    self,
    function: FunctionInfo,
    project_root: Path,
    tests_root: Path | None = None,
    max_files: int = 500,
) -> list[ReferenceInfo]:
    """Find all references (call sites) to a function across the codebase.

    Uses tree-sitter to find all places where a JavaScript/TypeScript function
    is called, including direct calls, callbacks, memoized versions, and re-exports.

    Args:
        function: The function to find references for.
        project_root: Root of the project to search.
        tests_root: Root of tests directory (references in tests are excluded).
        max_files: Maximum number of files to search.

    Returns:
        List of ReferenceInfo objects describing each reference location.
        Returns an empty list on any failure (logged as a warning).
    """
    from codeflash.languages.base import ReferenceInfo
    from codeflash.languages.javascript.find_references import ReferenceFinder

    try:
        finder = ReferenceFinder(project_root)
        refs = finder.find_references(function.name, function.file_path, max_files=max_files)
        # Convert the language-specific Reference objects to ReferenceInfo
        # and filter out references that live under the tests root.
        result: list[ReferenceInfo] = []
        for ref in refs:
            # Exclude test files if tests_root is provided
            if tests_root:
                try:
                    # relative_to raises ValueError when the file is NOT
                    # inside tests_root, which is the keep case.
                    ref.file_path.relative_to(tests_root)
                    continue  # Skip if in tests_root
                except ValueError:
                    pass  # Not in tests_root, include it
            result.append(
                ReferenceInfo(
                    file_path=ref.file_path,
                    line=ref.line,
                    column=ref.column,
                    end_line=ref.end_line,
                    end_column=ref.end_column,
                    context=ref.context,
                    reference_type=ref.reference_type,
                    import_name=ref.import_name,
                    caller_function=ref.caller_function,
                )
            )
        return result
    except Exception as e:
        # Reference finding is best-effort: a failure must not abort optimization.
        logger.warning("Failed to find references for %s: %s", function.name, e)
        return []
# === Code Transformation ===
def replace_function(self, source: str, function: FunctionInfo, new_source: str) -> str:

View file

@ -13,6 +13,7 @@ from codeflash.languages.base import (
HelperFunction,
Language,
ParentInfo,
ReferenceInfo,
TestInfo,
TestResult,
)
@ -45,6 +46,11 @@ class PythonSupport:
"""File extensions supported by Python."""
return (".py", ".pyw")
@property
def default_file_extension(self) -> str:
"""Default file extension for Python."""
return ".py"
@property
def test_framework(self) -> str:
"""Primary test framework for Python."""
@ -289,6 +295,120 @@ class PythonSupport:
return helpers
def find_references(
    self,
    function: FunctionInfo,
    project_root: Path,
    tests_root: Path | None = None,
    max_files: int = 500,
) -> list[ReferenceInfo]:
    """Find all references (call sites) to a function across the codebase.

    Uses jedi to find all places where a Python function is called.

    Args:
        function: The function to find references for.
        project_root: Root of the project to search.
        tests_root: Root of tests directory (references in tests are excluded).
        max_files: Maximum number of files to search.
            NOTE(review): max_files is not currently enforced in this
            implementation — jedi searches the whole project.

    Returns:
        List of ReferenceInfo objects describing each reference location.
        Returns an empty list on any failure (logged as a warning).
    """
    try:
        import jedi

        source = function.file_path.read_text()
        # Find the definition position of the target function in its file.
        script = jedi.Script(code=source, path=function.file_path)
        names = script.get_names(all_scopes=True, definitions=True)
        function_pos = None
        for name in names:
            if name.type == "function" and name.name == function.name:
                # Check for class parent if it's a method, so a method and a
                # free function with the same name are not confused.
                if function.class_name:
                    parent = name.parent()
                    if parent and parent.name == function.class_name and parent.type == "class":
                        function_pos = (name.line, name.column)
                        break
                else:
                    function_pos = (name.line, name.column)
                    break
        if function_pos is None:
            return []
        # Get references using jedi, scoped to the whole project.
        script = jedi.Script(code=source, path=function.file_path, project=jedi.Project(path=project_root))
        references = script.get_references(line=function_pos[0], column=function_pos[1])
        result: list[ReferenceInfo] = []
        seen_locations: set[tuple[Path, int, int]] = set()
        for ref in references:
            if not ref.module_path:
                continue
            ref_path = Path(ref.module_path)
            # Skip the definition itself
            if ref_path == function.file_path and ref.line == function_pos[0]:
                continue
            # Skip test files
            if tests_root:
                try:
                    ref_path.relative_to(tests_root)
                    continue
                except ValueError:
                    pass
            # Avoid duplicates
            loc_key = (ref_path, ref.line, ref.column)
            if loc_key in seen_locations:
                continue
            seen_locations.add(loc_key)
            # Get context line (best-effort; empty string if the file is unreadable)
            try:
                ref_source = ref_path.read_text()
                lines = ref_source.splitlines()
                context = lines[ref.line - 1] if ref.line <= len(lines) else ""
            except Exception:
                context = ""
            # Determine caller function (best-effort)
            caller_function = None
            try:
                parent = ref.parent()
                if parent and parent.type == "function":
                    caller_function = parent.name
            except Exception:
                pass
            result.append(
                ReferenceInfo(
                    file_path=ref_path,
                    line=ref.line,
                    column=ref.column,
                    end_line=ref.line,
                    # NOTE(review): assumes the reference uses the original
                    # function name; an aliased import would make this span
                    # the wrong width — confirm.
                    end_column=ref.column + len(function.name),
                    context=context.strip(),
                    reference_type="call",
                    import_name=function.name,
                    caller_function=caller_function,
                )
            )
        return result
    except Exception as e:
        # Reference finding is best-effort: a failure must not abort optimization.
        logger.warning("Failed to find references for %s: %s", function.name, e)
        return []
# === Code Transformation ===
def replace_function(self, source: str, function: FunctionInfo, new_source: str) -> str:

View file

@ -178,6 +178,41 @@ def get_language_support(identifier: Path | Language | str) -> LanguageSupport:
_FRAMEWORK_CACHE: dict[str, LanguageSupport] = {}
def get_language_support_by_common_formatters(formatter_cmd: str | list[str]) -> LanguageSupport | None:
    """Infer the project language from its configured formatter command.

    Matches well-known Python and JS/TS formatter executables against the
    tokens of the command and returns the corresponding LanguageSupport.

    Args:
        formatter_cmd: Formatter command, either a single string or an
            argv-style list of strings.

    Returns:
        The cached LanguageSupport for the inferred language, or None when
        no known formatter name appears in the command.

    Raises:
        UnsupportedLanguageError: If the inferred language has no registered
            support class.
    """
    language: Language | None = None
    if isinstance(formatter_cmd, str):
        formatter_cmd = [formatter_cmd]
    if len(formatter_cmd) == 1:
        # A one-element command may be a full shell string ("npx prettier");
        # split on spaces so each token can be matched individually.
        formatter_cmd = formatter_cmd[0].split(" ")
    # Try as extension first
    ext = None
    py_formatters = ["black", "isort", "ruff", "autopep8", "yapf", "pyfmt"]
    js_ts_formatters = ["prettier", "eslint", "biome", "rome", "deno", "standard", "tslint"]
    # NOTE(review): matching is exact per token, so absolute paths like
    # "/usr/bin/black" will not match — confirm whether that case matters.
    if any(cmd in py_formatters for cmd in formatter_cmd):
        ext = ".py"
    elif any(cmd in js_ts_formatters for cmd in formatter_cmd):
        ext = ".js"
    if ext is None:
        # can't determine language
        return None
    cls = _EXTENSION_REGISTRY[ext]
    language = cls().language
    # Return cached instance or create new one
    if language not in _SUPPORT_CACHE:
        if language not in _LANGUAGE_REGISTRY:
            raise UnsupportedLanguageError(str(language), get_supported_languages())
        _SUPPORT_CACHE[language] = _LANGUAGE_REGISTRY[language]()
    return _SUPPORT_CACHE[language]
def get_language_support_by_framework(test_framework: str) -> LanguageSupport | None:
"""Get language support for a test framework.

View file

@ -103,6 +103,9 @@ class CodeflashConfig(BaseModel):
if self.module_root and self.module_root not in (".", "src"):
config["moduleRoot"] = self.module_root
if self.tests_root:
config["testsRoot"] = self.tests_root
# Formatter (only if explicitly set)
if self.formatter_cmds:
config["formatterCmds"] = self.formatter_cmds

View file

@ -115,7 +115,7 @@ function installCodeflash(uvBin) {
try {
// Use uv tool install to install codeflash in an isolated environment
// This avoids conflicts with any existing Python environments
execSync(`"${uvBin}" tool install codeflash --force`, {
execSync(`"${uvBin}" tool install --force --python python3.12 codeflash`, {
stdio: 'inherit',
shell: true,
});

View file

@ -604,3 +604,349 @@ def test_function_in_tests_dir():
assert "vanilla_function" not in remaining_functions
files_and_funcs = get_all_files_and_functions(module_root_path=temp_dir, ignore_paths=[])
assert len(files_and_funcs) == 6
def test_filter_functions_tests_root_overlaps_source():
    """Test that source files are not filtered when tests_root equals module_root or project_root.

    This is a critical test for monorepo structures where tests live alongside source code
    (e.g., TypeScript projects with .test.ts files in the same directories as source).
    """
    with tempfile.TemporaryDirectory() as temp_dir_str:
        temp_dir = Path(temp_dir_str)
        # Create a source file (NOT a test file)
        source_file = temp_dir / "utils.py"
        with source_file.open("w") as f:
            f.write("""
def process_data(items):
    return [item * 2 for item in items]

def calculate_sum(numbers):
    return sum(numbers)
""")
        # Create a test file with standard naming pattern (.test. infix)
        test_file = temp_dir / "utils.test.py"
        with test_file.open("w") as f:
            f.write("""
def test_process_data():
    return "test"
""")
        # Create a test file with _test suffix pattern
        test_file_underscore = temp_dir / "utils_test.py"
        with test_file_underscore.open("w") as f:
            f.write("""
def test_calculate_sum():
    return "test"
""")
        # Create a spec file (.spec. infix, common in JS/TS ecosystems)
        spec_file = temp_dir / "utils.spec.py"
        with spec_file.open("w") as f:
            f.write("""
def spec_function():
    return "spec"
""")
        # Create a file in a tests subdirectory
        tests_subdir = temp_dir / "tests"
        tests_subdir.mkdir()
        tests_subdir_file = tests_subdir / "test_main.py"
        with tests_subdir_file.open("w") as f:
            f.write("""
def test_in_tests_dir():
    return "test"
""")
        # Create a file in __tests__ subdirectory (common in JS/TS projects)
        dunder_tests_subdir = temp_dir / "__tests__"
        dunder_tests_subdir.mkdir()
        dunder_tests_file = dunder_tests_subdir / "main.py"
        with dunder_tests_file.open("w") as f:
            f.write("""
def test_in_dunder_tests():
    return "test"
""")
        # Discover all functions
        discovered_source = find_all_functions_in_file(source_file)
        discovered_test = find_all_functions_in_file(test_file)
        discovered_test_underscore = find_all_functions_in_file(test_file_underscore)
        discovered_spec = find_all_functions_in_file(spec_file)
        discovered_tests_dir = find_all_functions_in_file(tests_subdir_file)
        discovered_dunder_tests = find_all_functions_in_file(dunder_tests_file)
        # Combine all discovered functions
        all_functions = {}
        for discovered in [discovered_source, discovered_test, discovered_test_underscore,
                           discovered_spec, discovered_tests_dir, discovered_dunder_tests]:
            all_functions.update(discovered)
        # Test Case 1: tests_root == module_root (overlapping case)
        # This is the bug scenario where all functions were being filtered
        with unittest.mock.patch(
            "codeflash.discovery.functions_to_optimize.get_blocklisted_functions", return_value={}
        ):
            filtered, count = filter_functions(
                all_functions,
                tests_root=temp_dir,  # Same as module_root
                ignore_paths=[],
                project_root=temp_dir,
                module_root=temp_dir,  # Same as tests_root
            )
        # Strict check: only source_file should remain in filtered results
        assert set(filtered.keys()) == {source_file}, (
            f"Expected only source file in filtered results, got: {set(filtered.keys())}"
        )
        # Strict check: exactly these two functions should be present
        source_functions = sorted([fn.function_name for fn in filtered.get(source_file, [])])
        assert source_functions == ["calculate_sum", "process_data"], (
            f"Expected ['calculate_sum', 'process_data'], got {source_functions}"
        )
        # Strict check: exactly 2 functions remaining
        assert count == 2, f"Expected exactly 2 functions, got {count}"
        # Test Case 2: tests_root == project_root (another overlapping case)
        with unittest.mock.patch(
            "codeflash.discovery.functions_to_optimize.get_blocklisted_functions", return_value={}
        ):
            filtered2, count2 = filter_functions(
                {source_file: discovered_source[source_file]},
                tests_root=temp_dir,  # Same as project_root
                ignore_paths=[],
                project_root=temp_dir,
                module_root=temp_dir,
            )
        # Strict check: only source_file should remain
        assert set(filtered2.keys()) == {source_file}, (
            f"Expected only source file when tests_root == project_root, got: {set(filtered2.keys())}"
        )
        assert count2 == 2, f"Expected exactly 2 functions, got {count2}"
def test_filter_functions_strict_string_matching():
    """Test that test file pattern matching uses strict string matching.

    Ensures patterns like '.test.' only match actual test files and don't
    accidentally match files with similar names like 'contest.py' or 'latest.py'.
    """
    with tempfile.TemporaryDirectory() as temp_dir_str:
        temp_dir = Path(temp_dir_str)
        # Files that should NOT be filtered (contain 'test' as substring but not as pattern)
        contest_file = temp_dir / "contest.py"
        with contest_file.open("w") as f:
            f.write("def run_contest(): return 1")
        latest_file = temp_dir / "latest.py"
        with latest_file.open("w") as f:
            f.write("def get_latest(): return 1")
        attestation_file = temp_dir / "attestation.py"
        with attestation_file.open("w") as f:
            f.write("def verify_attestation(): return 1")
        # File that SHOULD be filtered (matches .test. pattern)
        actual_test_file = temp_dir / "utils.test.py"
        with actual_test_file.open("w") as f:
            f.write("def test_utils(): return 1")
        # File that SHOULD be filtered (matches _test. pattern)
        underscore_test_file = temp_dir / "utils_test.py"
        with underscore_test_file.open("w") as f:
            f.write("def test_stuff(): return 1")
        # Discover all functions across every fixture file
        all_functions = {}
        for file_path in [contest_file, latest_file, attestation_file, actual_test_file, underscore_test_file]:
            discovered = find_all_functions_in_file(file_path)
            all_functions.update(discovered)
        with unittest.mock.patch(
            "codeflash.discovery.functions_to_optimize.get_blocklisted_functions", return_value={}
        ):
            filtered, count = filter_functions(
                all_functions,
                tests_root=temp_dir,  # Overlapping case to trigger pattern matching
                ignore_paths=[],
                project_root=temp_dir,
                module_root=temp_dir,
            )
        # Strict check: exactly these 3 files should remain (those with 'test' as substring only)
        expected_files = {contest_file, latest_file, attestation_file}
        assert set(filtered.keys()) == expected_files, (
            f"Expected files {expected_files}, got {set(filtered.keys())}"
        )
        # Strict check: each file should have exactly 1 function with the expected name
        assert [fn.function_name for fn in filtered[contest_file]] == ["run_contest"], (
            f"Expected ['run_contest'], got {[fn.function_name for fn in filtered[contest_file]]}"
        )
        assert [fn.function_name for fn in filtered[latest_file]] == ["get_latest"], (
            f"Expected ['get_latest'], got {[fn.function_name for fn in filtered[latest_file]]}"
        )
        assert [fn.function_name for fn in filtered[attestation_file]] == ["verify_attestation"], (
            f"Expected ['verify_attestation'], got {[fn.function_name for fn in filtered[attestation_file]]}"
        )
        # Strict check: exactly 3 functions remaining
        assert count == 3, f"Expected exactly 3 functions, got {count}"
def test_filter_functions_test_directory_patterns():
    """Verify directory-based test filtering with strict matching.

    Only directories named exactly 'test', 'tests', or '__tests__' (at any
    depth) may be filtered out; directories that merely contain 'test' as a
    substring of their name (e.g. 'contest_results') must survive.
    """
    with tempfile.TemporaryDirectory() as tmp:
        root = Path(tmp)

        def make_source(relative: str, code: str) -> Path:
            # Helper: create a .py file (and its parent dirs) under the temp root.
            target = root / relative
            target.parent.mkdir(parents=True, exist_ok=True)
            target.write_text(code)
            return target

        # Should be KEPT: 'test' appears only as a substring of the dir name.
        contest_file = make_source("contest_results/scores.py", "def get_scores(): return [1, 2, 3]")
        latest_file = make_source("latest_data/data.py", "def load_data(): return {}")
        # Should be FILTERED: genuine test directories at various depths.
        filtered_sources = [
            make_source("tests/test_main.py", "def test_main(): return True"),
            make_source("test/test_utils.py", "def test_utils(): return True"),
            make_source("__tests__/component.py", "def test_component(): return True"),
            make_source("src/tests/test_nested.py", "def test_nested(): return True"),
        ]

        # Gather every discovered function across all created files.
        all_functions = {}
        for candidate in [contest_file, latest_file, *filtered_sources]:
            all_functions.update(find_all_functions_in_file(candidate))

        with unittest.mock.patch(
            "codeflash.discovery.functions_to_optimize.get_blocklisted_functions", return_value={}
        ):
            filtered, count = filter_functions(
                all_functions,
                tests_root=root,  # Overlapping case
                ignore_paths=[],
                project_root=root,
                module_root=root,
            )

        # Strict check: exactly the two non-test-directory files remain.
        expected_files = {contest_file, latest_file}
        assert set(filtered.keys()) == expected_files, (
            f"Expected files {expected_files}, got {set(filtered.keys())}"
        )
        # Strict check: each surviving file contributes exactly its one function.
        for path, expected_names in ((contest_file, ["get_scores"]), (latest_file, ["load_data"])):
            found_names = [fn.function_name for fn in filtered[path]]
            assert found_names == expected_names, f"Expected {expected_names}, got {found_names}"
        # Strict check: exactly 2 functions remaining overall.
        assert count == 2, f"Expected exactly 2 functions, got {count}"
def test_filter_functions_non_overlapping_tests_root():
    """Verify directory-based filtering when tests_root is a separate directory.

    With a distinct tests_root (e.g. 'tests/'), files located under it are
    dropped while everything under module_root is kept — even a file whose
    name matches a '.test.' pattern, because directory-based filtering takes
    precedence in the non-overlapping mode.
    """
    with tempfile.TemporaryDirectory() as tmp:
        root = Path(tmp)

        # Source tree under src/ — both files here must survive filtering.
        src_dir = root / "src"
        src_dir.mkdir()
        source_file = src_dir / "utils.py"
        source_file.write_text("def process(): return 1")
        # '.test.' in the filename, but it lives in src/, so it is kept.
        test_in_src = src_dir / "helper.test.py"
        test_in_src.write_text("def helper_test(): return 1")

        # Dedicated tests directory — its contents must be dropped.
        tests_dir = root / "tests"
        tests_dir.mkdir()
        dropped_file = tests_dir / "test_utils.py"
        dropped_file.write_text("def test_process(): return 1")

        # Gather discovered functions from every created file.
        all_functions = {}
        for candidate in (source_file, test_in_src, dropped_file):
            all_functions.update(find_all_functions_in_file(candidate))

        # Non-overlapping case: tests_root is distinct from module_root.
        with unittest.mock.patch(
            "codeflash.discovery.functions_to_optimize.get_blocklisted_functions", return_value={}
        ):
            filtered, count = filter_functions(
                all_functions,
                tests_root=tests_dir,  # Separate from module_root
                ignore_paths=[],
                project_root=root,
                module_root=src_dir,  # Different from tests_root
            )

        # Strict check: exactly the two src/ files remain.
        expected_files = {source_file, test_in_src}
        assert set(filtered.keys()) == expected_files, (
            f"Expected files {expected_files}, got {set(filtered.keys())}"
        )
        # Strict check: each surviving file contributes exactly its one function.
        for path, expected_names in ((source_file, ["process"]), (test_in_src, ["helper_test"])):
            found_names = [fn.function_name for fn in filtered[path]]
            assert found_names == expected_names, f"Expected {expected_names}, got {found_names}"
        # Strict check: exactly 2 functions remaining overall.
        assert count == 2, f"Expected exactly 2 functions, got {count}"

File diff suppressed because it is too large Load diff