mirror of
https://github.com/codeflash-ai/codeflash.git
synced 2026-05-04 18:25:17 +00:00
Merge branch 'main' into add_vitest_support_to_js
This commit is contained in:
commit
82d9e435ef
16 changed files with 2900 additions and 47 deletions
20
.github/workflows/claude.yml
vendored
20
.github/workflows/claude.yml
vendored
|
|
@ -19,16 +19,30 @@ jobs:
|
|||
(github.event_name == 'issues' && (contains(github.event.issue.body, '@claude') || contains(github.event.issue.title, '@claude')))
|
||||
runs-on: ubuntu-latest
|
||||
permissions:
|
||||
contents: read
|
||||
pull-requests: read
|
||||
contents: write
|
||||
pull-requests: write
|
||||
issues: read
|
||||
id-token: write
|
||||
actions: read # Required for Claude to read CI results on PRs
|
||||
steps:
|
||||
- name: Get PR head ref
|
||||
id: pr-ref
|
||||
env:
|
||||
GH_TOKEN: ${{ github.token }}
|
||||
run: |
|
||||
# For issue_comment events, we need to fetch the PR info
|
||||
if [ "${{ github.event_name }}" = "issue_comment" ]; then
|
||||
PR_REF=$(gh api repos/${{ github.repository }}/pulls/${{ github.event.issue.number }} --jq '.head.ref')
|
||||
echo "ref=$PR_REF" >> $GITHUB_OUTPUT
|
||||
else
|
||||
echo "ref=${{ github.event.pull_request.head.ref || github.head_ref }}" >> $GITHUB_OUTPUT
|
||||
fi
|
||||
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 1
|
||||
fetch-depth: 0
|
||||
ref: ${{ steps.pr-ref.outputs.ref }}
|
||||
|
||||
- name: Run Claude Code
|
||||
id: claude
|
||||
|
|
|
|||
|
|
@ -26,6 +26,7 @@ from codeflash.code_utils.compat import LF
|
|||
from codeflash.code_utils.git_utils import get_git_remotes
|
||||
from codeflash.code_utils.shell_utils import get_shell_rc_path, is_powershell
|
||||
from codeflash.telemetry.posthog_cf import ph
|
||||
from rich.prompt import Confirm
|
||||
|
||||
|
||||
class ProjectLanguage(Enum):
|
||||
|
|
@ -208,9 +209,7 @@ def init_js_project(language: ProjectLanguage) -> None:
|
|||
|
||||
def should_modify_package_json_config() -> tuple[bool, dict[str, Any] | None]:
|
||||
"""Check if package.json has valid codeflash config for JS/TS projects."""
|
||||
from rich.prompt import Confirm
|
||||
|
||||
package_json_path = Path.cwd() / "package.json"
|
||||
package_json_path = Path("package.json")
|
||||
|
||||
if not package_json_path.exists():
|
||||
click.echo("❌ No package.json found. Please run 'npm init' first.")
|
||||
|
|
@ -230,6 +229,10 @@ def should_modify_package_json_config() -> tuple[bool, dict[str, Any] | None]:
|
|||
if not Path(module_root).is_dir():
|
||||
return True, None
|
||||
|
||||
tests_root = config.get("testsRoot", None)
|
||||
if tests_root and not Path(tests_root).is_dir():
|
||||
return True, None
|
||||
|
||||
# Config is valid - ask if user wants to reconfigure
|
||||
return Confirm.ask(
|
||||
"✅ A valid Codeflash config already exists in package.json. Do you want to re-configure it?",
|
||||
|
|
|
|||
|
|
@ -1563,23 +1563,228 @@ def is_numerical_code(code_string: str, function_name: str | None = None) -> boo
|
|||
def get_opt_review_metrics(
|
||||
source_code: str, file_path: Path, qualified_name: str, project_root: Path, tests_root: Path, language: Language
|
||||
) -> str:
|
||||
if language != Language.PYTHON:
|
||||
# TODO: {Claude} handle function refrences for other languages
|
||||
return ""
|
||||
"""Get function reference metrics for optimization review.
|
||||
|
||||
Uses the LanguageSupport abstraction to find references, supporting both Python and JavaScript/TypeScript.
|
||||
|
||||
Args:
|
||||
source_code: Source code of the file containing the function.
|
||||
file_path: Path to the file.
|
||||
qualified_name: Qualified name of the function (e.g., "module.ClassName.method").
|
||||
project_root: Root of the project.
|
||||
tests_root: Root of the tests directory.
|
||||
language: The programming language.
|
||||
|
||||
Returns:
|
||||
Markdown-formatted string with code blocks showing calling functions.
|
||||
"""
|
||||
from codeflash.languages.base import FunctionInfo, ParentInfo, ReferenceInfo
|
||||
from codeflash.languages.registry import get_language_support
|
||||
|
||||
start_time = time.perf_counter()
|
||||
|
||||
try:
|
||||
# Get the language support
|
||||
lang_support = get_language_support(language)
|
||||
if lang_support is None:
|
||||
return ""
|
||||
|
||||
# Parse qualified name to get function name and class name
|
||||
qualified_name_split = qualified_name.rsplit(".", maxsplit=1)
|
||||
if len(qualified_name_split) == 1:
|
||||
target_function, target_class = qualified_name_split[0], None
|
||||
function_name, class_name = qualified_name_split[0], None
|
||||
else:
|
||||
target_function, target_class = qualified_name_split[1], qualified_name_split[0]
|
||||
matches = get_fn_references_jedi(
|
||||
source_code, file_path, project_root, target_function, target_class
|
||||
) # jedi is not perfect, it doesn't capture aliased references
|
||||
calling_fns_details = find_occurances(qualified_name, str(file_path), matches, project_root, tests_root)
|
||||
function_name, class_name = qualified_name_split[1], qualified_name_split[0]
|
||||
|
||||
# Create a FunctionInfo for the function
|
||||
# We don't have full line info here, so we'll use defaults
|
||||
parents = ()
|
||||
if class_name:
|
||||
parents = (ParentInfo(name=class_name, type="ClassDef"),)
|
||||
|
||||
func_info = FunctionInfo(
|
||||
name=function_name,
|
||||
file_path=file_path,
|
||||
start_line=1,
|
||||
end_line=1,
|
||||
parents=parents,
|
||||
language=language,
|
||||
)
|
||||
|
||||
# Find references using language support
|
||||
references = lang_support.find_references(func_info, project_root, tests_root, max_files=500)
|
||||
|
||||
if not references:
|
||||
return ""
|
||||
|
||||
# Format references as markdown code blocks
|
||||
calling_fns_details = _format_references_as_markdown(
|
||||
references, file_path, project_root, language
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.debug(f"Error getting function references: {e}")
|
||||
calling_fns_details = ""
|
||||
logger.debug(f"Investigate {e}")
|
||||
|
||||
end_time = time.perf_counter()
|
||||
logger.debug(f"Got function references in {end_time - start_time:.2f} seconds")
|
||||
return calling_fns_details
|
||||
|
||||
|
||||
def _format_references_as_markdown(
    references: list, file_path: Path, project_root: Path, language: Language
) -> str:
    """Render references as fenced markdown blocks, one block per referencing file.

    Args:
        references: Reference objects exposing ``file_path``, ``reference_type``,
            ``caller_function`` and ``line``.
        file_path: Path of the file defining the function; its own
            "import"/"reexport" references are dropped.
        project_root: Root of the project; files outside it are skipped.
        language: The programming language, used for the fence hint and for
            choosing how calling functions are extracted.

    Returns:
        Concatenated markdown code blocks (empty string when nothing qualifies).

    """
    # Bucket references per file, preserving first-seen file order.
    grouped: dict[Path, list] = {}
    for reference in references:
        own_import = reference.file_path == file_path and reference.reference_type in ("import", "reexport")
        if own_import:
            continue
        grouped.setdefault(reference.file_path, []).append(reference)

    rendered_blocks: list[str] = []
    total_chars = 0

    for source_path, source_refs in grouped.items():
        # The size budget is only checked between files, never mid-file.
        if total_chars > MAX_CONTEXT_LEN_REVIEW:
            break

        try:
            display_path = source_path.relative_to(project_root)
        except ValueError:
            continue

        # Fence language: Python wins outright, otherwise decide by extension.
        extension = source_path.suffix.lstrip(".")
        if language == Language.PYTHON:
            fence_lang = "python"
        else:
            fence_lang = "typescript" if extension in ("ts", "tsx") else "javascript"

        try:
            text = source_path.read_text(encoding="utf-8")
            text_lines = text.splitlines()
        except Exception:
            continue

        # Only the first reference per caller name contributes a snippet.
        seen_callers: set[str] = set()
        snippets: list[str] = []

        for reference in source_refs:
            caller_key = reference.caller_function or "<module>"
            if caller_key in seen_callers:
                continue
            seen_callers.add(caller_key)

            if not reference.caller_function:
                # Module-level reference: show a small window around the line.
                window_start = max(0, reference.line - 3)
                window_end = min(len(text_lines), reference.line + 2)
                snippet = "\n".join(text_lines[window_start:window_end])
            else:
                # Reference inside a function: try to show that whole function.
                snippet = _extract_calling_function(text, reference.caller_function, reference.line, language)
                if not snippet:
                    continue

            snippets.append(snippet)
            total_chars += len(snippet)

        if snippets:
            rendered_blocks.append(f"```{fence_lang}:{display_path}\n" + "\n".join(snippets) + "\n```\n")

    return "".join(rendered_blocks)
|
||||
|
||||
|
||||
def _extract_calling_function(source_code: str, function_name: str, ref_line: int, language: Language) -> str | None:
    """Return the source of the function named ``function_name`` that contains ``ref_line``.

    Delegates to the Python extractor when ``language`` is ``Language.PYTHON``
    and to the JavaScript/TypeScript extractor for every other language.

    Args:
        source_code: Full source code of the file.
        function_name: Name of the function to extract.
        ref_line: Line number where the reference is.
        language: The programming language.

    Returns:
        Source code of the function, or None if not found.

    """
    extractor = _extract_calling_function_python if language == Language.PYTHON else _extract_calling_function_js
    return extractor(source_code, function_name, ref_line)
|
||||
|
||||
|
||||
def _extract_calling_function_python(source_code: str, function_name: str, ref_line: int) -> str | None:
|
||||
"""Extract the source code of a calling function in Python."""
|
||||
try:
|
||||
import ast
|
||||
|
||||
tree = ast.parse(source_code)
|
||||
lines = source_code.splitlines()
|
||||
|
||||
for node in ast.walk(tree):
|
||||
if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
|
||||
if node.name == function_name:
|
||||
# Check if the reference line is within this function
|
||||
start_line = node.lineno
|
||||
end_line = node.end_lineno or start_line
|
||||
if start_line <= ref_line <= end_line:
|
||||
return "\n".join(lines[start_line - 1 : end_line])
|
||||
return None
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
def _extract_calling_function_js(source_code: str, function_name: str, ref_line: int) -> str | None:
|
||||
"""Extract the source code of a calling function in JavaScript/TypeScript.
|
||||
|
||||
Args:
|
||||
source_code: Full source code of the file.
|
||||
function_name: Name of the function to extract.
|
||||
ref_line: Line number where the reference is (helps identify the right function).
|
||||
|
||||
Returns:
|
||||
Source code of the function, or None if not found.
|
||||
"""
|
||||
try:
|
||||
from codeflash.languages.treesitter_utils import TreeSitterAnalyzer, TreeSitterLanguage
|
||||
|
||||
# Try TypeScript first, fall back to JavaScript
|
||||
for lang in [TreeSitterLanguage.TYPESCRIPT, TreeSitterLanguage.TSX, TreeSitterLanguage.JAVASCRIPT]:
|
||||
try:
|
||||
analyzer = TreeSitterAnalyzer(lang)
|
||||
functions = analyzer.find_functions(source_code, include_methods=True)
|
||||
|
||||
for func in functions:
|
||||
if func.name == function_name:
|
||||
# Check if the reference line is within this function
|
||||
if func.start_line <= ref_line <= func.end_line:
|
||||
return func.source_text
|
||||
break
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
return None
|
||||
except Exception:
|
||||
return None
|
||||
|
|
|
|||
|
|
@ -251,6 +251,9 @@ def parse_package_json_config(package_json_path: Path) -> tuple[dict[str, Any],
|
|||
detected_module_root = detect_module_root(project_root, package_data)
|
||||
config["module_root"] = str((project_root / Path(detected_module_root)).resolve())
|
||||
|
||||
if codeflash_config.get("testsRoot"):
    # Join onto project_root *before* resolving. Calling .resolve() on the bare
    # testsRoot path first would anchor it to the current working directory, and
    # joining project_root with that absolute result would discard project_root.
    config["tests_root"] = str((project_root / Path(codeflash_config["testsRoot"])).resolve())
|
||||
|
||||
# Auto-detect test runner
|
||||
config["test_runner"] = detect_test_runner(project_root, package_data)
|
||||
# Keep pytest_cmd for backwards compatibility with existing code
|
||||
|
|
|
|||
|
|
@ -13,10 +13,14 @@ from codeflash.cli_cmds.console import logger
|
|||
from codeflash.code_utils.code_utils import exit_with_message
|
||||
from codeflash.code_utils.formatter import format_code
|
||||
from codeflash.code_utils.shell_utils import read_api_key_from_shell_config, save_api_key_to_rc
|
||||
from codeflash.languages.base import Language
|
||||
from codeflash.languages.registry import get_language_support_by_common_formatters
|
||||
from codeflash.lsp.helpers import is_LSP_enabled
|
||||
|
||||
|
||||
def check_formatter_installed(formatter_cmds: list[str], exit_on_failure: bool = True) -> bool:
|
||||
def check_formatter_installed(
|
||||
formatter_cmds: list[str], exit_on_failure: bool = True, language: str = "python"
|
||||
) -> bool:
|
||||
if not formatter_cmds or formatter_cmds[0] == "disabled":
|
||||
return True
|
||||
first_cmd = formatter_cmds[0]
|
||||
|
|
@ -35,10 +39,21 @@ def check_formatter_installed(formatter_cmds: list[str], exit_on_failure: bool =
|
|||
)
|
||||
return False
|
||||
|
||||
tmp_code = """print("hello world")"""
|
||||
lang_support = get_language_support_by_common_formatters(formatter_cmds)
|
||||
if not lang_support:
|
||||
logger.debug(f"Could not determine language for formatter: {formatter_cmds}")
|
||||
return True
|
||||
|
||||
if lang_support.language == Language.PYTHON:
|
||||
tmp_code = """print("hello world")"""
|
||||
elif lang_support.language in (Language.JAVASCRIPT, Language.TYPESCRIPT):
|
||||
tmp_code = "console.log('hello world');"
|
||||
else:
|
||||
return True
|
||||
|
||||
try:
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
tmp_file = Path(tmpdir) / "test_codeflash_formatter.py"
|
||||
tmp_file = Path(tmpdir) / ("test_codeflash_formatter" + lang_support.default_file_extension)
|
||||
tmp_file.write_text(tmp_code, encoding="utf-8")
|
||||
format_code(formatter_cmds, tmp_file, print_status=False, exit_on_failure=False)
|
||||
return True
|
||||
|
|
|
|||
|
|
@ -13,6 +13,7 @@ from typing import Any, Optional, Union
|
|||
import isort
|
||||
|
||||
from codeflash.cli_cmds.console import console, logger
|
||||
from codeflash.languages.registry import get_language_support
|
||||
from codeflash.lsp.helpers import is_LSP_enabled
|
||||
|
||||
|
||||
|
|
@ -47,8 +48,9 @@ def apply_formatter_cmds(
|
|||
raise FileNotFoundError(msg)
|
||||
|
||||
file_path = path
|
||||
lang_support = get_language_support(path)
|
||||
if test_dir_str:
|
||||
file_path = Path(test_dir_str) / "temp.py"
|
||||
file_path = Path(test_dir_str) / ("temp" + lang_support.default_file_extension)
|
||||
shutil.copy2(path, file_path)
|
||||
|
||||
file_token = "$file" # noqa: S105
|
||||
|
|
@ -87,13 +89,14 @@ def get_diff_lines_count(diff_output: str) -> int:
|
|||
return len(diff_lines)
|
||||
|
||||
|
||||
def format_generated_code(generated_test_source: str, formatter_cmds: list[str]) -> str:
|
||||
def format_generated_code(generated_test_source: str, formatter_cmds: list[str], language: str = "python") -> str:
|
||||
formatter_name = formatter_cmds[0].lower() if formatter_cmds else "disabled"
|
||||
if formatter_name == "disabled": # nothing to do if no formatter provided
|
||||
return re.sub(r"\n{2,}", "\n\n", generated_test_source)
|
||||
with tempfile.TemporaryDirectory() as test_dir_str:
|
||||
# try running formatter, if nothing changes (could be due to formatting failing or no actual formatting needed) return code with 2 or more newlines substituted with 2 newlines
|
||||
original_temp = Path(test_dir_str) / "original_temp.py"
|
||||
lang_support = get_language_support(language)
|
||||
original_temp = Path(test_dir_str) / ("original_temp" + lang_support.default_file_extension)
|
||||
original_temp.write_text(generated_test_source, encoding="utf8")
|
||||
_, formatted_code, changed = apply_formatter_cmds(
|
||||
formatter_cmds, original_temp, test_dir_str, print_status=False, exit_on_failure=False
|
||||
|
|
@ -130,7 +133,8 @@ def format_code(
|
|||
# we don't count the formatting diff for the optimized function as it should be well-formatted
|
||||
original_code_without_opfunc = original_code.replace(optimized_code, "")
|
||||
|
||||
original_temp = Path(test_dir_str) / "original_temp.py"
|
||||
lang_support = get_language_support(path)
|
||||
original_temp = Path(test_dir_str) / ("original_temp" + lang_support.default_file_extension)
|
||||
original_temp.write_text(original_code_without_opfunc, encoding="utf8")
|
||||
|
||||
formatted_temp, formatted_code, changed = apply_formatter_cmds(
|
||||
|
|
@ -160,6 +164,7 @@ def format_code(
|
|||
_, formatted_code, changed = apply_formatter_cmds(
|
||||
formatter_cmds, path, test_dir_str=None, print_status=print_status, exit_on_failure=exit_on_failure
|
||||
)
|
||||
|
||||
if not changed:
|
||||
logger.warning(
|
||||
f"No changes detected in {path} after formatting, are you sure you have valid formatter commands?"
|
||||
|
|
|
|||
|
|
@ -41,6 +41,8 @@ if TYPE_CHECKING:
|
|||
|
||||
from codeflash.models.models import CodeOptimizationContext
|
||||
from codeflash.verification.verification_utils import TestConfig
|
||||
import contextlib
|
||||
|
||||
from rich.text import Text
|
||||
|
||||
_property_id = "property"
|
||||
|
|
@ -616,9 +618,10 @@ def get_all_replay_test_functions(
|
|||
except Exception as e:
|
||||
logger.warning(f"Error parsing replay test file {replay_test_file}: {e}")
|
||||
|
||||
if not trace_file_path:
|
||||
if trace_file_path is None:
|
||||
logger.error("Could not find trace_file_path in replay test files.")
|
||||
exit_with_message("Could not find trace_file_path in replay test files.")
|
||||
raise AssertionError("Unreachable") # exit_with_message never returns
|
||||
|
||||
if not trace_file_path.exists():
|
||||
logger.error(f"Trace file not found: {trace_file_path}")
|
||||
|
|
@ -673,7 +676,7 @@ def get_all_replay_test_functions(
|
|||
if filtered_list:
|
||||
filtered_valid_functions[file_path] = filtered_list
|
||||
|
||||
return filtered_valid_functions, trace_file_path
|
||||
return dict(filtered_valid_functions), trace_file_path
|
||||
|
||||
|
||||
def is_git_repo(file_path: str) -> bool:
|
||||
|
|
@ -685,11 +688,13 @@ def is_git_repo(file_path: str) -> bool:
|
|||
|
||||
|
||||
@cache
|
||||
def ignored_submodule_paths(module_root: str) -> list[str]:
|
||||
def ignored_submodule_paths(module_root: str) -> list[Path]:
|
||||
if is_git_repo(module_root):
|
||||
git_repo = git.Repo(module_root, search_parent_directories=True)
|
||||
try:
|
||||
return [Path(git_repo.working_tree_dir, submodule.path).resolve() for submodule in git_repo.submodules]
|
||||
working_dir = git_repo.working_tree_dir
|
||||
if working_dir is not None:
|
||||
return [Path(working_dir, submodule.path).resolve() for submodule in git_repo.submodules]
|
||||
except Exception as e:
|
||||
logger.warning(f"Error getting submodule paths: {e}")
|
||||
return []
|
||||
|
|
@ -703,7 +708,7 @@ class TopLevelFunctionOrMethodVisitor(ast.NodeVisitor):
|
|||
self.class_name = class_name
|
||||
self.function_name = function_or_method_name
|
||||
self.is_top_level = False
|
||||
self.function_has_args = None
|
||||
self.function_has_args: bool | None = None
|
||||
self.line_no = line_no
|
||||
self.is_staticmethod = False
|
||||
self.is_classmethod = False
|
||||
|
|
@ -817,31 +822,28 @@ def was_function_previously_optimized(
|
|||
|
||||
# Check optimization status if repository info is provided
|
||||
# already_optimized_count = 0
|
||||
try:
|
||||
|
||||
# Check optimization status if repository info is provided
|
||||
# already_optimized_count = 0
|
||||
owner = None
|
||||
repo = None
|
||||
with contextlib.suppress(git.exc.InvalidGitRepositoryError):
|
||||
owner, repo = get_repo_owner_and_name()
|
||||
except git.exc.InvalidGitRepositoryError:
|
||||
logger.warning("No git repository found")
|
||||
owner, repo = None, None
|
||||
|
||||
pr_number = get_pr_number()
|
||||
|
||||
if not owner or not repo or pr_number is None or getattr(args, "no_pr", False):
|
||||
return False
|
||||
|
||||
code_contexts = []
|
||||
|
||||
func_hash = code_context.hashing_code_context_hash
|
||||
# Use a unique path identifier that includes function info
|
||||
|
||||
code_contexts.append(
|
||||
code_contexts = [
|
||||
{
|
||||
"file_path": function_to_optimize.file_path,
|
||||
"file_path": str(function_to_optimize.file_path),
|
||||
"function_name": function_to_optimize.qualified_name,
|
||||
"code_hash": func_hash,
|
||||
}
|
||||
)
|
||||
|
||||
if not code_contexts:
|
||||
return False
|
||||
]
|
||||
|
||||
try:
|
||||
result = is_function_being_optimized_again(owner, repo, pr_number, code_contexts)
|
||||
|
|
@ -860,7 +862,7 @@ def filter_functions(
|
|||
ignore_paths: list[Path],
|
||||
project_root: Path,
|
||||
module_root: Path,
|
||||
previous_checkpoint_functions: dict[Path, dict[str, Any]] | None = None,
|
||||
previous_checkpoint_functions: dict[str, dict[str, Any]] | None = None,
|
||||
*,
|
||||
disable_logs: bool = False,
|
||||
) -> tuple[dict[Path, list[FunctionToOptimize]], int]:
|
||||
|
|
@ -885,21 +887,49 @@ def filter_functions(
|
|||
# Normalize paths for case-insensitive comparison on Windows
|
||||
tests_root_str = os.path.normcase(str(tests_root))
|
||||
module_root_str = os.path.normcase(str(module_root))
|
||||
project_root_str = os.path.normcase(str(project_root))
|
||||
|
||||
# Check if tests_root overlaps with module_root or project_root
|
||||
# In this case, we need to use file pattern matching instead of directory matching
|
||||
tests_root_overlaps_source = tests_root_str in (module_root_str, project_root_str) or module_root_str.startswith(
|
||||
tests_root_str + os.sep
|
||||
)
|
||||
|
||||
# Test file patterns for when tests_root overlaps with source
|
||||
test_file_name_patterns = (".test.", ".spec.", "_test.", "_spec.")
|
||||
test_dir_patterns = (os.sep + "test" + os.sep, os.sep + "tests" + os.sep, os.sep + "__tests__" + os.sep)
|
||||
|
||||
def is_test_file(file_path_normalized: str) -> bool:
    """Decide whether a normalized file path should be treated as a test file.

    When the tests root is separate from the sources, a file is a test file if it
    lives under the tests root. When the tests root overlaps the sources, the
    decision falls back to filename and directory-name patterns.
    """
    if not tests_root_overlaps_source:
        # Tests live in their own directory: a plain prefix check is enough.
        return file_path_normalized.startswith(tests_root_str + os.sep)

    lowered = file_path_normalized.lower()

    # Filename markers such as ".test." or ".spec.".
    if any(marker in lowered for marker in test_file_name_patterns):
        return True

    # Directory markers are matched only on the part below the project root,
    # so parent directories of the project cannot cause false positives.
    inside_project = lowered
    if project_root_str and lowered.startswith(project_root_str.lower()):
        inside_project = lowered[len(project_root_str) :]
    return any(marker in inside_project for marker in test_dir_patterns)
|
||||
|
||||
# We desperately need Python 3.10+ only support to make this code readable with structural pattern matching
|
||||
for file_path_path, functions in modified_functions.items():
|
||||
_functions = functions
|
||||
file_path = str(file_path_path)
|
||||
file_path_normalized = os.path.normcase(file_path)
|
||||
if file_path_normalized.startswith(tests_root_str + os.sep):
|
||||
if is_test_file(file_path_normalized):
|
||||
test_functions_removed_count += len(_functions)
|
||||
continue
|
||||
if file_path in ignore_paths or any(
|
||||
if file_path_path in ignore_paths or any(
|
||||
file_path_normalized.startswith(os.path.normcase(str(ignore_path)) + os.sep) for ignore_path in ignore_paths
|
||||
):
|
||||
ignore_paths_removed_count += 1
|
||||
continue
|
||||
if file_path in submodule_paths or any(
|
||||
if file_path_path in submodule_paths or any(
|
||||
file_path_normalized.startswith(os.path.normcase(str(submodule_path)) + os.sep)
|
||||
for submodule_path in submodule_paths
|
||||
):
|
||||
|
|
@ -991,7 +1021,7 @@ def filter_files_optimized(file_path: Path, tests_root: Path, ignore_paths: list
|
|||
|
||||
def function_has_return_statement(function_node: FunctionDef | AsyncFunctionDef) -> bool:
|
||||
# Custom DFS, return True as soon as a Return node is found
|
||||
stack = [function_node]
|
||||
stack: list[ast.AST] = [function_node]
|
||||
while stack:
|
||||
node = stack.pop()
|
||||
if isinstance(node, ast.Return):
|
||||
|
|
|
|||
|
|
@ -236,6 +236,37 @@ class FunctionFilterCriteria:
|
|||
max_lines: int | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class ReferenceInfo:
|
||||
"""Information about a reference (call site) to a function.
|
||||
|
||||
This class captures information about where a function is called
|
||||
from, including the file, line number, context, and caller function.
|
||||
|
||||
Attributes:
|
||||
file_path: Path to the file containing the reference.
|
||||
line: Line number (1-indexed).
|
||||
column: Column number (0-indexed).
|
||||
end_line: End line number (1-indexed).
|
||||
end_column: End column number (0-indexed).
|
||||
context: The line of code containing the reference.
|
||||
reference_type: Type of reference ("call", "callback", "memoized", "import", "reexport").
|
||||
import_name: Name used to import the function (may differ from original).
|
||||
caller_function: Name of the function containing this reference (or None for module-level).
|
||||
|
||||
"""
|
||||
|
||||
file_path: Path
|
||||
line: int
|
||||
column: int
|
||||
end_line: int
|
||||
end_column: int
|
||||
context: str
|
||||
reference_type: str
|
||||
import_name: str | None
|
||||
caller_function: str | None = None
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class LanguageSupport(Protocol):
|
||||
"""Protocol defining what a language implementation must provide.
|
||||
|
|
@ -278,6 +309,11 @@ class LanguageSupport(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@property
def default_file_extension(self) -> str:
    """Return the default file extension for this language."""
    ...
|
||||
|
||||
@property
|
||||
def test_framework(self) -> str:
|
||||
"""Primary test framework name.
|
||||
|
|
@ -352,6 +388,29 @@ class LanguageSupport(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
def find_references(
    self, function: FunctionInfo, project_root: Path, tests_root: Path | None = None, max_files: int = 500
) -> list[ReferenceInfo]:
    """Locate every reference (call site) to ``function`` in the codebase.

    Implementations are expected to report direct calls, callbacks (the
    function passed to other functions), memoized versions, and re-exports.

    Args:
        function: The function to find references for.
        project_root: Root of the project to search.
        tests_root: Root of tests directory (references in tests are excluded).
        max_files: Maximum number of files to search.

    Returns:
        List of ReferenceInfo objects describing each reference location.

    """
    ...
|
||||
|
||||
# === Code Transformation ===
|
||||
|
||||
def replace_function(self, source: str, function: FunctionInfo, new_source: str) -> str:
|
||||
|
|
|
|||
861
codeflash/languages/javascript/find_references.py
Normal file
861
codeflash/languages/javascript/find_references.py
Normal file
|
|
@ -0,0 +1,861 @@
|
|||
"""Find references for JavaScript/TypeScript functions.
|
||||
|
||||
This module provides functionality to find all references (call sites) of a function
|
||||
across a JavaScript/TypeScript codebase. Similar to Jedi's find_references for Python,
|
||||
this uses tree-sitter to parse and analyze code.
|
||||
|
||||
Key features:
|
||||
- Finds all call sites of a function across multiple files
|
||||
- Handles various import patterns (named, default, namespace, re-exports, aliases)
|
||||
- Supports both ES modules and CommonJS
|
||||
- Handles memoized functions, callbacks, and method calls
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from tree_sitter import Node
|
||||
|
||||
from codeflash.languages.treesitter_utils import ExportInfo, ImportInfo, TreeSitterAnalyzer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Reference:
|
||||
"""Represents a reference (call site) to a function."""
|
||||
|
||||
file_path: Path # File containing the reference
|
||||
line: int # 1-indexed line number
|
||||
column: int # 0-indexed column number
|
||||
end_line: int # 1-indexed end line
|
||||
end_column: int # 0-indexed end column
|
||||
context: str # The line of code containing the reference
|
||||
reference_type: str # Type: "call", "callback", "memoized", "import", "reexport"
|
||||
import_name: str | None # Name used to import the function (may differ from original)
|
||||
caller_function: str | None = None # Name of the function containing this reference
|
||||
|
||||
|
||||
@dataclass
|
||||
class ExportedFunction:
|
||||
"""Represents how a function is exported from its source file."""
|
||||
|
||||
function_name: str # The local function name
|
||||
export_name: str | None # The name it's exported as (may differ)
|
||||
is_default: bool # Whether it's a default export
|
||||
file_path: Path # The source file
|
||||
|
||||
|
||||
@dataclass
class ReferenceSearchContext:
    """Context for tracking visited files during reference search.

    Attributes:
        visited_files: Files already visited; a fresh empty set per instance.
        max_files: Limit to prevent runaway searches.

    """

    visited_files: set[Path] = field(default_factory=set)
    max_files: int = 1000
|
||||
|
||||
|
||||
class ReferenceFinder:
|
||||
"""Finds all references to a function across a JavaScript/TypeScript codebase.
|
||||
|
||||
This class provides functionality similar to Jedi's find_references for Python,
|
||||
but for JavaScript/TypeScript using tree-sitter.
|
||||
|
||||
Example usage:
|
||||
```python
|
||||
from codeflash.languages.javascript.find_references import ReferenceFinder
|
||||
|
||||
finder = ReferenceFinder(project_root=Path("/my/project"))
|
||||
references = finder.find_references(
|
||||
function_name="myHelper",
|
||||
source_file=Path("/my/project/src/utils.ts")
|
||||
)
|
||||
for ref in references:
|
||||
print(f"{ref.file_path}:{ref.line} - {ref.context}")
|
||||
```
|
||||
"""
|
||||
|
||||
# File extensions to search
|
||||
EXTENSIONS = (".ts", ".tsx", ".js", ".jsx", ".mjs", ".cjs")
|
||||
|
||||
def __init__(self, project_root: Path, exclude_patterns: list[str] | None = None) -> None:
|
||||
"""Initialize the ReferenceFinder.
|
||||
|
||||
Args:
|
||||
project_root: Root directory of the project to search.
|
||||
exclude_patterns: Glob patterns of directories/files to exclude.
|
||||
Defaults to ['node_modules', 'dist', 'build', '.git', 'coverage', '__pycache__'].
|
||||
|
||||
"""
|
||||
self.project_root = project_root
|
||||
self.exclude_patterns = exclude_patterns or ["node_modules", "dist", "build", ".git", "coverage", "__pycache__"]
|
||||
self._file_cache: dict[Path, str] = {}
|
||||
|
||||
def find_references(
    self,
    function_name: str,
    source_file: Path,
    include_definition: bool = False,
    max_files: int = 1000,
) -> list[Reference]:
    """Find all references to a function across the project.

    Args:
        function_name: Name of the function to find references for.
        source_file: Path to the file where the function is defined.
        include_definition: For an exported function, when True the
            references found inside source_file itself are added to the
            result. For a function that is not exported, only source_file
            is searched regardless of this flag.
        max_files: Upper bound on the number of visited files. The count
            covers the source file plus every file found to import the
            function; scanning stops once it is reached.

    Returns:
        List of Reference objects, unique by (file_path, line, column) and
        in discovery order.

    """
    from codeflash.languages.treesitter_utils import get_analyzer_for_file

    def unique_by_location(refs: list[Reference]) -> list[Reference]:
        # Keep the first reference seen for each (file, line, column).
        seen: set[tuple[Path, int, int]] = set()
        unique: list[Reference] = []
        for ref in refs:
            key = (ref.file_path, ref.line, ref.column)
            if key not in seen:
                seen.add(key)
                unique.append(ref)
        return unique

    references: list[Reference] = []
    context = ReferenceSearchContext(max_files=max_files)

    # Step 1: Analyze how the function is exported from its source file
    source_code = self._read_file(source_file)
    if source_code is None:
        logger.warning("Could not read source file: %s", source_file)
        return references

    analyzer = get_analyzer_for_file(source_file)
    exported = self._analyze_exports(function_name, source_file, source_code, analyzer)

    if not exported:
        logger.debug("Function %s is not exported from %s", function_name, source_file)
        # Not importable from elsewhere, so only internal references exist.
        same_file_refs = self._find_references_in_file(
            source_file, source_code, function_name, None, analyzer, include_self=not include_definition
        )
        # This path must be de-duplicated as well: the AST walk can report
        # the same call site more than once.
        return unique_by_location(same_file_refs)

    # Step 2: Walk the project tree once; the listing is reused for the
    # direct-import pass and for every re-export pass below.
    context.visited_files.add(source_file)
    project_files = self._iter_project_files()

    # Files that re-export our function, as (file_path, exported_name)
    reexport_files: list[tuple[Path, str]] = []
    checked_for_reexports: set[Path] = set()

    # Step 3: Search all project files for imports and calls
    for file_path in project_files:
        if file_path in context.visited_files:
            continue
        if len(context.visited_files) >= context.max_files:
            logger.warning("Reached max file limit (%d), stopping search", max_files)
            break

        file_code = self._read_file(file_path)
        if file_code is None:
            continue

        file_analyzer = get_analyzer_for_file(file_path)
        imports = file_analyzer.find_imports(file_code)
        import_info = self._find_matching_import(imports, source_file, file_path, exported)

        if import_info:
            # Found an import - mark as visited and search for uses
            context.visited_files.add(file_path)
            import_name, _ = import_info
            references.extend(
                self._find_references_in_file(
                    file_path, file_code, function_name, import_name, file_analyzer, include_self=True
                )
            )

        # Always check for re-exports, even without a direct import match
        if file_path not in checked_for_reexports:
            checked_for_reexports.add(file_path)
            reexport_refs = self._find_reexports_direct(file_path, file_code, source_file, exported, file_analyzer)
            references.extend(reexport_refs)
            reexport_files.extend((file_path, ref.import_name) for ref in reexport_refs)

    # Step 4: Follow re-exports one level: find importers of each re-exporting file
    for reexport_file, reexport_name in reexport_files:
        reexported = ExportedFunction(
            function_name=reexport_name,
            export_name=reexport_name,
            is_default=False,
            file_path=reexport_file,
        )

        for file_path in project_files:
            if file_path in context.visited_files:
                continue
            if file_path == reexport_file:
                continue
            if len(context.visited_files) >= context.max_files:
                break

            file_code = self._read_file(file_path)
            if file_code is None:
                continue

            file_analyzer = get_analyzer_for_file(file_path)
            imports = file_analyzer.find_imports(file_code)
            import_info = self._find_matching_import(imports, reexport_file, file_path, reexported)

            if import_info:
                context.visited_files.add(file_path)
                import_name, _ = import_info
                references.extend(
                    self._find_references_in_file(
                        file_path, file_code, reexport_name, import_name, file_analyzer, include_self=True
                    )
                )

    # Step 5: Include references in the same file (internal calls)
    if include_definition:
        references.extend(
            self._find_references_in_file(source_file, source_code, function_name, None, analyzer, include_self=True)
        )

    # Step 6: Deduplicate references (same file, line, column)
    return unique_by_location(references)
|
||||
|
||||
def _analyze_exports(
|
||||
self, function_name: str, file_path: Path, source_code: str, analyzer: TreeSitterAnalyzer
|
||||
) -> ExportedFunction | None:
|
||||
"""Analyze how a function is exported from its file.
|
||||
|
||||
Args:
|
||||
function_name: Name of the function to check.
|
||||
file_path: Path to the source file.
|
||||
source_code: Source code content.
|
||||
analyzer: TreeSitterAnalyzer instance.
|
||||
|
||||
Returns:
|
||||
ExportedFunction if the function is exported, None otherwise.
|
||||
|
||||
"""
|
||||
is_exported, export_name = analyzer.is_function_exported(source_code, function_name)
|
||||
|
||||
if not is_exported:
|
||||
return None
|
||||
|
||||
return ExportedFunction(
|
||||
function_name=function_name,
|
||||
export_name=export_name,
|
||||
is_default=(export_name == "default"),
|
||||
file_path=file_path,
|
||||
)
|
||||
|
||||
def _find_matching_import(
    self,
    imports: list[ImportInfo],
    source_file: Path,
    importing_file: Path,
    exported: ExportedFunction,
) -> tuple[str, ImportInfo] | None:
    """Locate the import statement that brings the target function into a file.

    Args:
        imports: List of imports in the file.
        source_file: Path to the file containing the function definition.
        importing_file: Path to the file being checked for imports.
        exported: Information about how the function is exported.

    Returns:
        Tuple of (local_name, ImportInfo) for the first matching import, or
        None when no import resolves to source_file and binds the function.
        For namespace-style access the local name has the dotted form
        "<binding>.<export>".

    """
    from codeflash.languages.javascript.import_resolver import ImportResolver

    resolver = ImportResolver(self.project_root)
    wanted = exported.export_name or exported.function_name

    for imp in imports:
        target = resolver.resolve_import(imp, importing_file)
        # Only imports that resolve to the defining file are of interest.
        if target is None or target.file_path != source_file:
            continue

        if exported.is_default:
            # Default export: bound by a default import or via namespace.default
            if imp.default_import:
                return (imp.default_import, imp)
            if imp.namespace_import:
                return (f"{imp.namespace_import}.default", imp)
            continue

        # Named export: explicit named import, possibly aliased
        for name, alias in imp.named_imports:
            if name == wanted:
                return (alias or name, imp)

        # import * as ns from '...'  ->  ns.<export>
        if imp.namespace_import:
            return (f"{imp.namespace_import}.{wanted}", imp)

        # CommonJS: const helpers = require('./helpers'); helpers.<export>()
        # The default binding behaves like a namespace here.
        if imp.default_import and not imp.named_imports:
            return (f"{imp.default_import}.{wanted}", imp)

    return None
|
||||
|
||||
def _find_references_in_file(
    self,
    file_path: Path,
    source_code: str,
    function_name: str,
    import_name: str | None,
    analyzer: TreeSitterAnalyzer,
    include_self: bool = True,
) -> list[Reference]:
    """Collect every reference to a function inside one file.

    Args:
        file_path: Path to the file to search.
        source_code: Source code content.
        function_name: Original function name.
        import_name: Local name the function is bound to in this file, if
            different; a dotted value ("ns.member") selects member-access search.
        analyzer: TreeSitterAnalyzer instance used to parse the file.
        include_self: Accepted for interface compatibility; this method
            does not currently consult it.

    Returns:
        List of Reference objects.

    """
    found: list[Reference] = []
    encoded = source_code.encode("utf8")
    root = analyzer.parse(encoded).root_node
    text_lines = source_code.splitlines()

    # Prefer the local binding when the file imports under another name.
    target = import_name or function_name

    if "." not in target:
        # Plain identifier: direct calls and other identifier uses
        self._find_identifier_references(
            root, encoded, text_lines, file_path, target, function_name, found, None
        )
        return found

    # Namespace-style binding such as "utils.helper"
    namespace, _, member = target.partition(".")
    self._find_member_calls(root, encoded, text_lines, file_path, namespace, member, found, None)
    return found
|
||||
|
||||
def _find_identifier_references(
|
||||
self,
|
||||
node: Node,
|
||||
source_bytes: bytes,
|
||||
lines: list[str],
|
||||
file_path: Path,
|
||||
search_name: str,
|
||||
original_name: str,
|
||||
references: list[Reference],
|
||||
current_function: str | None,
|
||||
) -> None:
|
||||
"""Recursively find references to an identifier in the AST.
|
||||
|
||||
Args:
|
||||
node: Current tree-sitter node.
|
||||
source_bytes: Source code as bytes.
|
||||
lines: Source code split into lines.
|
||||
file_path: Path to the file.
|
||||
search_name: Name to search for.
|
||||
original_name: Original function name.
|
||||
references: List to append references to.
|
||||
current_function: Name of the containing function (for context).
|
||||
|
||||
"""
|
||||
from codeflash.languages.treesitter_utils import TreeSitterAnalyzer
|
||||
|
||||
# Track current function context
|
||||
new_current_function = current_function
|
||||
if node.type in ("function_declaration", "method_definition"):
|
||||
name_node = node.child_by_field_name("name")
|
||||
if name_node:
|
||||
new_current_function = source_bytes[name_node.start_byte : name_node.end_byte].decode("utf8")
|
||||
elif node.type in ("variable_declarator",):
|
||||
# Arrow function or function expression assigned to variable
|
||||
name_node = node.child_by_field_name("name")
|
||||
value_node = node.child_by_field_name("value")
|
||||
if name_node and value_node and value_node.type in ("arrow_function", "function_expression"):
|
||||
new_current_function = source_bytes[name_node.start_byte : name_node.end_byte].decode("utf8")
|
||||
|
||||
# Check for call expressions
|
||||
if node.type == "call_expression":
|
||||
func_node = node.child_by_field_name("function")
|
||||
if func_node and func_node.type == "identifier":
|
||||
name = source_bytes[func_node.start_byte : func_node.end_byte].decode("utf8")
|
||||
if name == search_name:
|
||||
ref = self._create_reference(
|
||||
file_path, func_node, lines, "call", search_name, current_function
|
||||
)
|
||||
references.append(ref)
|
||||
|
||||
# Check for identifiers used as callbacks or passed as arguments
|
||||
elif node.type == "identifier":
|
||||
name = source_bytes[node.start_byte : node.end_byte].decode("utf8")
|
||||
if name == search_name:
|
||||
parent = node.parent
|
||||
# Determine reference type based on context
|
||||
ref_type = self._determine_reference_type(node, parent, source_bytes)
|
||||
if ref_type:
|
||||
ref = self._create_reference(
|
||||
file_path, node, lines, ref_type, search_name, current_function
|
||||
)
|
||||
references.append(ref)
|
||||
|
||||
# Recurse into children
|
||||
for child in node.children:
|
||||
self._find_identifier_references(
|
||||
child, source_bytes, lines, file_path, search_name, original_name, references, new_current_function
|
||||
)
|
||||
|
||||
def _find_member_calls(
|
||||
self,
|
||||
node: Node,
|
||||
source_bytes: bytes,
|
||||
lines: list[str],
|
||||
file_path: Path,
|
||||
namespace: str,
|
||||
member: str,
|
||||
references: list[Reference],
|
||||
current_function: str | None,
|
||||
) -> None:
|
||||
"""Find calls to namespace.member (e.g., utils.helper()).
|
||||
|
||||
Args:
|
||||
node: Current tree-sitter node.
|
||||
source_bytes: Source code as bytes.
|
||||
lines: Source code split into lines.
|
||||
file_path: Path to the file.
|
||||
namespace: The namespace/object name.
|
||||
member: The member/property name.
|
||||
references: List to append references to.
|
||||
current_function: Name of the containing function.
|
||||
|
||||
"""
|
||||
# Track current function context
|
||||
new_current_function = current_function
|
||||
if node.type in ("function_declaration", "method_definition"):
|
||||
name_node = node.child_by_field_name("name")
|
||||
if name_node:
|
||||
new_current_function = source_bytes[name_node.start_byte : name_node.end_byte].decode("utf8")
|
||||
|
||||
# Check for call expressions with member access
|
||||
if node.type == "call_expression":
|
||||
func_node = node.child_by_field_name("function")
|
||||
if func_node and func_node.type == "member_expression":
|
||||
obj_node = func_node.child_by_field_name("object")
|
||||
prop_node = func_node.child_by_field_name("property")
|
||||
|
||||
if obj_node and prop_node:
|
||||
obj_name = source_bytes[obj_node.start_byte : obj_node.end_byte].decode("utf8")
|
||||
prop_name = source_bytes[prop_node.start_byte : prop_node.end_byte].decode("utf8")
|
||||
|
||||
if obj_name == namespace and prop_name == member:
|
||||
ref = self._create_reference(
|
||||
file_path, func_node, lines, "call", f"{namespace}.{member}", current_function
|
||||
)
|
||||
references.append(ref)
|
||||
|
||||
# Also check for member expression used as callback
|
||||
elif node.type == "member_expression":
|
||||
obj_node = node.child_by_field_name("object")
|
||||
prop_node = node.child_by_field_name("property")
|
||||
|
||||
if obj_node and prop_node:
|
||||
obj_name = source_bytes[obj_node.start_byte : obj_node.end_byte].decode("utf8")
|
||||
prop_name = source_bytes[prop_node.start_byte : prop_node.end_byte].decode("utf8")
|
||||
|
||||
if obj_name == namespace and prop_name == member:
|
||||
parent = node.parent
|
||||
if parent and parent.type != "call_expression":
|
||||
ref_type = self._determine_reference_type(node, parent, source_bytes)
|
||||
if ref_type:
|
||||
ref = self._create_reference(
|
||||
file_path, node, lines, ref_type, f"{namespace}.{member}", current_function
|
||||
)
|
||||
references.append(ref)
|
||||
|
||||
# Recurse into children
|
||||
for child in node.children:
|
||||
self._find_member_calls(
|
||||
child, source_bytes, lines, file_path, namespace, member, references, new_current_function
|
||||
)
|
||||
|
||||
def _determine_reference_type(self, node: Node, parent: Node | None, source_bytes: bytes) -> str | None:
|
||||
"""Determine the type of reference based on AST context.
|
||||
|
||||
Args:
|
||||
node: The identifier node.
|
||||
parent: The parent node.
|
||||
source_bytes: Source code as bytes.
|
||||
|
||||
Returns:
|
||||
Reference type string or None if this isn't a valid reference.
|
||||
|
||||
"""
|
||||
if parent is None:
|
||||
return None
|
||||
|
||||
# Skip import statements
|
||||
if parent.type in ("import_specifier", "import_clause", "named_imports"):
|
||||
return None
|
||||
|
||||
# Skip function declarations (the function name itself)
|
||||
if parent.type in ("function_declaration", "method_definition"):
|
||||
name_node = parent.child_by_field_name("name")
|
||||
if name_node and name_node.id == node.id:
|
||||
return None
|
||||
|
||||
# Skip variable declarations where this is being defined
|
||||
if parent.type == "variable_declarator":
|
||||
name_node = parent.child_by_field_name("name")
|
||||
if name_node and name_node.id == node.id:
|
||||
return None
|
||||
|
||||
# Skip export specifiers
|
||||
if parent.type == "export_specifier":
|
||||
return None
|
||||
|
||||
# Check if passed as argument (callback or memoized)
|
||||
if parent.type == "arguments":
|
||||
# Check if grandparent is a memoize call
|
||||
grandparent = parent.parent
|
||||
if grandparent and grandparent.type == "call_expression":
|
||||
func_node = grandparent.child_by_field_name("function")
|
||||
if func_node:
|
||||
func_name = source_bytes[func_node.start_byte : func_node.end_byte].decode("utf8")
|
||||
if any(m in func_name.lower() for m in ["memoize", "memo", "cache"]):
|
||||
return "memoized"
|
||||
return "callback"
|
||||
|
||||
# Check if used in array (often callback patterns)
|
||||
if parent.type == "array":
|
||||
return "callback"
|
||||
|
||||
# Check if passed to memoize/memoization functions (direct call check)
|
||||
if parent.type == "call_expression":
|
||||
func_node = parent.child_by_field_name("function")
|
||||
if func_node:
|
||||
func_name = source_bytes[func_node.start_byte : func_node.end_byte].decode("utf8")
|
||||
if any(m in func_name.lower() for m in ["memoize", "memo", "cache"]):
|
||||
return "memoized"
|
||||
|
||||
# Check if used in a call expression as the function
|
||||
if parent.type == "call_expression":
|
||||
func_node = parent.child_by_field_name("function")
|
||||
if func_node and func_node.id == node.id:
|
||||
return "call"
|
||||
|
||||
# Check if assigned to a property
|
||||
if parent.type in ("pair", "property"):
|
||||
return "property"
|
||||
|
||||
# Check if part of member expression (method call setup)
|
||||
if parent.type == "member_expression":
|
||||
obj_node = parent.child_by_field_name("object")
|
||||
if obj_node and obj_node.id == node.id:
|
||||
# This is the object in obj.method
|
||||
return None # We'll catch the actual call elsewhere
|
||||
|
||||
# Generic reference
|
||||
return "reference"
|
||||
|
||||
def _create_reference(
    self,
    file_path: Path,
    node: Node,
    lines: list[str],
    ref_type: str,
    import_name: str,
    caller_function: str | None,
) -> Reference:
    """Build a Reference describing where ``node`` sits in the file.

    Args:
        file_path: Path to the file.
        node: The tree-sitter node.
        lines: Source code lines.
        ref_type: Type of reference.
        import_name: Name the function was imported as.
        caller_function: Name of the containing function.

    Returns:
        A Reference whose line numbers are 1-indexed and whose context is
        the stripped source line the node starts on (empty when that row
        lies beyond ``lines``).

    """
    start_row = node.start_point[0]
    end_row = node.end_point[0]

    # Guard against a node row past the end of the line list.
    source_line = lines[start_row] if start_row < len(lines) else ""

    return Reference(
        file_path=file_path,
        line=start_row + 1,  # tree-sitter rows are 0-indexed
        column=node.start_point[1],
        end_line=end_row + 1,
        end_column=node.end_point[1],
        context=source_line.strip(),
        reference_type=ref_type,
        import_name=import_name,
        caller_function=caller_function,
    )
|
||||
|
||||
def _find_reexports(
|
||||
self,
|
||||
file_path: Path,
|
||||
source_code: str,
|
||||
exported: ExportedFunction,
|
||||
analyzer: TreeSitterAnalyzer,
|
||||
context: ReferenceSearchContext,
|
||||
) -> list[Reference]:
|
||||
"""Find re-exports of the function.
|
||||
|
||||
Re-exports look like: export { helper } from './utils'
|
||||
|
||||
Args:
|
||||
file_path: Path to the file being checked.
|
||||
source_code: Source code content.
|
||||
exported: Information about the original export.
|
||||
analyzer: TreeSitterAnalyzer instance.
|
||||
context: Search context.
|
||||
|
||||
Returns:
|
||||
List of Reference objects for re-exports.
|
||||
|
||||
"""
|
||||
references: list[Reference] = []
|
||||
exports = analyzer.find_exports(source_code)
|
||||
lines = source_code.splitlines()
|
||||
|
||||
for exp in exports:
|
||||
if not exp.is_reexport:
|
||||
continue
|
||||
|
||||
# Check if this re-exports our function
|
||||
export_name = exported.export_name or exported.function_name
|
||||
for name, alias in exp.exported_names:
|
||||
if name == export_name:
|
||||
# This is a re-export of our function
|
||||
# Create a reference with the line info from the export
|
||||
context_line = lines[exp.start_line - 1] if exp.start_line <= len(lines) else ""
|
||||
ref = Reference(
|
||||
file_path=file_path,
|
||||
line=exp.start_line,
|
||||
column=0,
|
||||
end_line=exp.end_line,
|
||||
end_column=0,
|
||||
context=context_line.strip(),
|
||||
reference_type="reexport",
|
||||
import_name=alias if alias else name,
|
||||
caller_function=None,
|
||||
)
|
||||
references.append(ref)
|
||||
|
||||
return references
|
||||
|
||||
def _find_reexports_direct(
    self,
    file_path: Path,
    source_code: str,
    source_file: Path,
    exported: ExportedFunction,
    analyzer: TreeSitterAnalyzer,
) -> list[Reference]:
    """Collect re-exports in a file whose source resolves to our source file.

    Each re-export statement's module specifier is resolved through
    ImportResolver; only statements that resolve to ``source_file`` and
    name the target export produce a reference.

    Args:
        file_path: Path to the file being checked.
        source_code: Source code content.
        source_file: The original source file we're looking for references to.
        exported: Information about the original export.
        analyzer: TreeSitterAnalyzer instance.

    Returns:
        List of Reference objects of type "reexport"; import_name holds
        the alias when the re-export renames the function.

    """
    from codeflash.languages.javascript.import_resolver import ImportResolver
    from codeflash.languages.treesitter_utils import ImportInfo

    resolver = ImportResolver(self.project_root)
    wanted = exported.export_name or exported.function_name
    source_lines = source_code.splitlines()
    found: list[Reference] = []

    for exp in analyzer.find_exports(source_code):
        if not exp.is_reexport or not exp.reexport_source:
            continue

        # The resolver works on ImportInfo, so wrap the re-export's module
        # specifier in a placeholder import to resolve it.
        placeholder = ImportInfo(
            module_path=exp.reexport_source,
            default_import=None,
            named_imports=[],
            namespace_import=None,
            is_type_only=False,
            start_line=exp.start_line,
            end_line=exp.end_line,
        )
        target = resolver.resolve_import(placeholder, file_path)
        if target is None or target.file_path != source_file:
            continue

        # This statement re-exports from our source file.
        for name, alias in exp.exported_names:
            if name != wanted:
                continue
            statement = source_lines[exp.start_line - 1] if exp.start_line <= len(source_lines) else ""
            found.append(
                Reference(
                    file_path=file_path,
                    line=exp.start_line,
                    column=0,
                    end_line=exp.end_line,
                    end_column=0,
                    context=statement.strip(),
                    reference_type="reexport",
                    import_name=alias if alias else name,
                    caller_function=None,
                )
            )

    return found
|
||||
|
||||
def _iter_project_files(self) -> list[Path]:
|
||||
"""Iterate over all JavaScript/TypeScript files in the project.
|
||||
|
||||
Returns:
|
||||
List of file paths to search.
|
||||
|
||||
"""
|
||||
files: list[Path] = []
|
||||
|
||||
for ext in self.EXTENSIONS:
|
||||
for file_path in self.project_root.rglob(f"*{ext}"):
|
||||
# Check exclusion patterns
|
||||
if self._should_exclude(file_path):
|
||||
continue
|
||||
files.append(file_path)
|
||||
|
||||
return files
|
||||
|
||||
def _should_exclude(self, file_path: Path) -> bool:
|
||||
"""Check if a file should be excluded from search.
|
||||
|
||||
Args:
|
||||
file_path: Path to check.
|
||||
|
||||
Returns:
|
||||
True if the file should be excluded.
|
||||
|
||||
"""
|
||||
path_str = str(file_path)
|
||||
for pattern in self.exclude_patterns:
|
||||
if pattern in path_str:
|
||||
return True
|
||||
return False
|
||||
|
||||
def _read_file(self, file_path: Path) -> str | None:
|
||||
"""Read a file's contents with caching.
|
||||
|
||||
Args:
|
||||
file_path: Path to the file.
|
||||
|
||||
Returns:
|
||||
File contents or None if unreadable.
|
||||
|
||||
"""
|
||||
if file_path in self._file_cache:
|
||||
return self._file_cache[file_path]
|
||||
|
||||
try:
|
||||
content = file_path.read_text(encoding="utf-8")
|
||||
self._file_cache[file_path] = content
|
||||
return content
|
||||
except Exception as e:
|
||||
logger.debug("Could not read file %s: %s", file_path, e)
|
||||
return None
|
||||
|
||||
|
||||
def find_references(
    function_name: str,
    source_file: Path,
    project_root: Path | None = None,
    max_files: int = 1000,
) -> list[Reference]:
    """Find all references to a function with a one-off ReferenceFinder.

    Thin wrapper that builds a ReferenceFinder for the project and
    delegates to its find_references method.

    Args:
        function_name: Name of the function to find references for.
        source_file: Path to the file where the function is defined.
        project_root: Root directory of the project. If None, uses source_file's parent.
        max_files: Maximum number of files to search.

    Returns:
        List of Reference objects describing each call site.

    Example:
        ```python
        from pathlib import Path
        from codeflash.languages.javascript.find_references import find_references

        refs = find_references(
            function_name="myHelper",
            source_file=Path("/my/project/src/utils.ts"),
            project_root=Path("/my/project")
        )
        for ref in refs:
            print(f"{ref.file_path}:{ref.line}:{ref.column} - {ref.reference_type}")
        ```

    """
    root = source_file.parent if project_root is None else project_root
    return ReferenceFinder(root).find_references(function_name, source_file, max_files=max_files)
|
||||
|
|
@ -19,6 +19,7 @@ from codeflash.languages.base import (
|
|||
HelperFunction,
|
||||
Language,
|
||||
ParentInfo,
|
||||
ReferenceInfo,
|
||||
TestInfo,
|
||||
TestResult,
|
||||
)
|
||||
|
|
@ -53,6 +54,11 @@ class JavaScriptSupport:
|
|||
"""File extensions supported by JavaScript."""
|
||||
return (".js", ".jsx", ".mjs", ".cjs")
|
||||
|
||||
@property
def default_file_extension(self) -> str:
    """Extension, including the leading dot, used by default for JavaScript files."""
    return ".js"
|
||||
|
||||
@property
|
||||
def test_framework(self) -> str:
|
||||
"""Primary test framework for JavaScript."""
|
||||
|
|
@ -959,6 +965,66 @@ class JavaScriptSupport:
|
|||
logger.warning("Failed to find helpers for %s: %s", function.name, e)
|
||||
return []
|
||||
|
||||
def find_references(
    self,
    function: FunctionInfo,
    project_root: Path,
    tests_root: Path | None = None,
    max_files: int = 500,
) -> list[ReferenceInfo]:
    """Find all references (call sites) to a function across the codebase.

    Delegates the search to ReferenceFinder and converts each hit to a
    ReferenceInfo, dropping hits located under ``tests_root``.

    Args:
        function: The function to find references for.
        project_root: Root of the project to search.
        tests_root: Root of tests directory (references in tests are excluded).
        max_files: Maximum number of files to search.

    Returns:
        List of ReferenceInfo objects describing each reference location.
        An empty list is returned (and a warning logged) if the search raises.

    """
    from codeflash.languages.base import ReferenceInfo
    from codeflash.languages.javascript.find_references import ReferenceFinder

    def located_in_tests(path: Path) -> bool:
        # relative_to raises ValueError when path is not below tests_root.
        if not tests_root:
            return False
        try:
            path.relative_to(tests_root)
        except ValueError:
            return False
        return True

    try:
        hits = ReferenceFinder(project_root).find_references(
            function.name, function.file_path, max_files=max_files
        )
        return [
            ReferenceInfo(
                file_path=hit.file_path,
                line=hit.line,
                column=hit.column,
                end_line=hit.end_line,
                end_column=hit.end_column,
                context=hit.context,
                reference_type=hit.reference_type,
                import_name=hit.import_name,
                caller_function=hit.caller_function,
            )
            for hit in hits
            if not located_in_tests(hit.file_path)
        ]
    except Exception as e:
        logger.warning("Failed to find references for %s: %s", function.name, e)
        return []
|
||||
|
||||
# === Code Transformation ===
|
||||
|
||||
def replace_function(self, source: str, function: FunctionInfo, new_source: str) -> str:
|
||||
|
|
|
|||
|
|
@ -13,6 +13,7 @@ from codeflash.languages.base import (
|
|||
HelperFunction,
|
||||
Language,
|
||||
ParentInfo,
|
||||
ReferenceInfo,
|
||||
TestInfo,
|
||||
TestResult,
|
||||
)
|
||||
|
|
@ -45,6 +46,11 @@ class PythonSupport:
|
|||
"""File extensions supported by Python."""
|
||||
return (".py", ".pyw")
|
||||
|
||||
@property
def default_file_extension(self) -> str:
    """Extension, including the leading dot, used by default for Python files."""
    return ".py"
|
||||
|
||||
@property
|
||||
def test_framework(self) -> str:
|
||||
"""Primary test framework for Python."""
|
||||
|
|
@ -289,6 +295,120 @@ class PythonSupport:
|
|||
|
||||
return helpers
|
||||
|
||||
def find_references(
    self,
    function: FunctionInfo,
    project_root: Path,
    tests_root: Path | None = None,
    max_files: int = 500,
) -> list[ReferenceInfo]:
    """Find all references to a Python function across the project using jedi.

    The definition is first located in ``function.file_path`` (matching the
    enclosing class when ``function.class_name`` is set); jedi is then asked
    for every reference to the name at that position.

    Args:
        function: The function to find references for.
        project_root: Root of the project handed to jedi as the search scope.
        tests_root: Root of the tests directory; references located under it
            are excluded from the result.
        max_files: Maximum number of files to search. Not consulted by this
            implementation: jedi performs the search and no limit is applied.

    Returns:
        One ReferenceInfo per unique (file, line, column) location, excluding
        the definition line itself. Every entry is labelled
        ``reference_type="call"``. An empty list is returned when the
        definition cannot be located or when any step fails (the failure is
        logged as a warning).

    """
    try:
        import jedi

        source = function.file_path.read_text()

        # Locate the definition so jedi can be queried at its exact position.
        script = jedi.Script(code=source, path=function.file_path)
        names = script.get_names(all_scopes=True, definitions=True)

        function_pos = None
        for name in names:
            if name.type != "function" or name.name != function.name:
                continue
            if function.class_name:
                # A method only matches when its enclosing scope is the expected class.
                parent = name.parent()
                if parent and parent.name == function.class_name and parent.type == "class":
                    function_pos = (name.line, name.column)
                    break
            else:
                function_pos = (name.line, name.column)
                break

        if function_pos is None:
            return []

        # Re-create the script with a project so the lookup spans the whole project.
        script = jedi.Script(code=source, path=function.file_path, project=jedi.Project(path=project_root))
        references = script.get_references(line=function_pos[0], column=function_pos[1])

        result: list[ReferenceInfo] = []
        seen_locations: set[tuple[Path, int, int]] = set()
        # Each referencing file is read and split once, however many references it holds.
        lines_by_file: dict[Path, list[str]] = {}

        for ref in references:
            if not ref.module_path:
                continue

            ref_path = Path(ref.module_path)

            # Skip the definition itself
            if ref_path == function.file_path and ref.line == function_pos[0]:
                continue

            # Skip test files
            if tests_root:
                try:
                    ref_path.relative_to(tests_root)
                except ValueError:
                    pass  # not under tests_root, so keep the reference
                else:
                    continue

            # Avoid duplicates
            loc_key = (ref_path, ref.line, ref.column)
            if loc_key in seen_locations:
                continue
            seen_locations.add(loc_key)

            # Get context line (best effort: an unreadable file yields an empty context)
            if ref_path not in lines_by_file:
                try:
                    lines_by_file[ref_path] = ref_path.read_text().splitlines()
                except (OSError, UnicodeError):
                    lines_by_file[ref_path] = []
            file_lines = lines_by_file[ref_path]
            context = file_lines[ref.line - 1] if ref.line and ref.line <= len(file_lines) else ""

            # Determine caller function (best effort: jedi may fail to resolve the scope)
            caller_function = None
            try:
                parent = ref.parent()
                if parent and parent.type == "function":
                    caller_function = parent.name
            except Exception:
                caller_function = None

            result.append(
                ReferenceInfo(
                    file_path=ref_path,
                    line=ref.line,
                    column=ref.column,
                    end_line=ref.line,
                    end_column=ref.column + len(function.name),
                    context=context.strip(),
                    reference_type="call",
                    import_name=function.name,
                    caller_function=caller_function,
                )
            )

        return result

    except Exception as e:
        logger.warning("Failed to find references for %s: %s", function.name, e)
        return []
|
||||
|
||||
# === Code Transformation ===
|
||||
|
||||
def replace_function(self, source: str, function: FunctionInfo, new_source: str) -> str:
|
||||
|
|
|
|||
|
|
@ -178,6 +178,41 @@ def get_language_support(identifier: Path | Language | str) -> LanguageSupport:
|
|||
_FRAMEWORK_CACHE: dict[str, LanguageSupport] = {}
|
||||
|
||||
|
||||
def get_language_support_by_common_formatters(formatter_cmd: str | list[str]) -> LanguageSupport | None:
    """Guess the language from well-known formatter names and return its support object.

    Args:
        formatter_cmd: A single command string (e.g. ``"black $file"``) or a
            list of command strings / already-split tokens.

    Returns:
        The cached (or newly created) language support instance, or None when
        no known Python or JS/TS formatter name appears in the command(s).

    Raises:
        UnsupportedLanguageError: If the detected language has no registered support.

    """
    commands = [formatter_cmd] if isinstance(formatter_cmd, str) else formatter_cmd

    # Tokenize every command, not only a lone one, so a configured list such as
    # ["black $file", "isort $file"] is recognised as well.
    tokens = {token for command in commands for token in command.split()}

    py_formatters = {"black", "isort", "ruff", "autopep8", "yapf", "pyfmt"}
    js_ts_formatters = {"prettier", "eslint", "biome", "rome", "deno", "standard", "tslint"}

    # Python formatters take precedence when both families appear.
    if tokens & py_formatters:
        ext = ".py"
    elif tokens & js_ts_formatters:
        ext = ".js"
    else:
        # can't determine language
        return None

    cls = _EXTENSION_REGISTRY[ext]
    language = cls().language

    # Return cached instance or create new one
    if language not in _SUPPORT_CACHE:
        if language not in _LANGUAGE_REGISTRY:
            raise UnsupportedLanguageError(str(language), get_supported_languages())
        _SUPPORT_CACHE[language] = _LANGUAGE_REGISTRY[language]()

    return _SUPPORT_CACHE[language]
|
||||
|
||||
|
||||
def get_language_support_by_framework(test_framework: str) -> LanguageSupport | None:
|
||||
"""Get language support for a test framework.
|
||||
|
||||
|
|
|
|||
|
|
@ -103,6 +103,9 @@ class CodeflashConfig(BaseModel):
|
|||
if self.module_root and self.module_root not in (".", "src"):
|
||||
config["moduleRoot"] = self.module_root
|
||||
|
||||
if self.tests_root:
|
||||
config["testsRoot"] = self.tests_root
|
||||
|
||||
# Formatter (only if explicitly set)
|
||||
if self.formatter_cmds:
|
||||
config["formatterCmds"] = self.formatter_cmds
|
||||
|
|
|
|||
|
|
@ -115,7 +115,7 @@ function installCodeflash(uvBin) {
|
|||
try {
|
||||
// Use uv tool install to install codeflash in an isolated environment
|
||||
// This avoids conflicts with any existing Python environments
|
||||
execSync(`"${uvBin}" tool install codeflash --force`, {
|
||||
execSync(`"${uvBin}" tool install --force --python python3.12 codeflash`, {
|
||||
stdio: 'inherit',
|
||||
shell: true,
|
||||
});
|
||||
|
|
|
|||
|
|
@ -604,3 +604,349 @@ def test_function_in_tests_dir():
|
|||
assert "vanilla_function" not in remaining_functions
|
||||
files_and_funcs = get_all_files_and_functions(module_root_path=temp_dir, ignore_paths=[])
|
||||
assert len(files_and_funcs) == 6
|
||||
|
||||
|
||||
def test_filter_functions_tests_root_overlaps_source():
    """Test that source files are not filtered when tests_root equals module_root or project_root.

    This is a critical test for monorepo structures where tests live alongside source code
    (e.g., TypeScript projects with .test.ts files in the same directories as source).
    """
    with tempfile.TemporaryDirectory() as temp_dir_str:
        temp_dir = Path(temp_dir_str)

        # Create a source file (NOT a test file)
        source_file = temp_dir / "utils.py"
        with source_file.open("w") as f:
            f.write("""
def process_data(items):
    return [item * 2 for item in items]

def calculate_sum(numbers):
    return sum(numbers)
""")

        # Create a test file with standard naming pattern
        test_file = temp_dir / "utils.test.py"
        with test_file.open("w") as f:
            f.write("""
def test_process_data():
    return "test"
""")

        # Create a test file with _test suffix pattern
        test_file_underscore = temp_dir / "utils_test.py"
        with test_file_underscore.open("w") as f:
            f.write("""
def test_calculate_sum():
    return "test"
""")

        # Create a spec file
        spec_file = temp_dir / "utils.spec.py"
        with spec_file.open("w") as f:
            f.write("""
def spec_function():
    return "spec"
""")

        # Create a file in a tests subdirectory
        tests_subdir = temp_dir / "tests"
        tests_subdir.mkdir()
        tests_subdir_file = tests_subdir / "test_main.py"
        with tests_subdir_file.open("w") as f:
            f.write("""
def test_in_tests_dir():
    return "test"
""")

        # Create a file in __tests__ subdirectory (common in JS/TS projects)
        dunder_tests_subdir = temp_dir / "__tests__"
        dunder_tests_subdir.mkdir()
        dunder_tests_file = dunder_tests_subdir / "main.py"
        with dunder_tests_file.open("w") as f:
            f.write("""
def test_in_dunder_tests():
    return "test"
""")

        # Discover all functions
        discovered_source = find_all_functions_in_file(source_file)
        discovered_test = find_all_functions_in_file(test_file)
        discovered_test_underscore = find_all_functions_in_file(test_file_underscore)
        discovered_spec = find_all_functions_in_file(spec_file)
        discovered_tests_dir = find_all_functions_in_file(tests_subdir_file)
        discovered_dunder_tests = find_all_functions_in_file(dunder_tests_file)

        # Combine all discovered functions
        # (each discovery result is merged into one mapping that is passed to filter_functions)
        all_functions = {}
        for discovered in [discovered_source, discovered_test, discovered_test_underscore,
                           discovered_spec, discovered_tests_dir, discovered_dunder_tests]:
            all_functions.update(discovered)

        # Test Case 1: tests_root == module_root (overlapping case)
        # This is the bug scenario where all functions were being filtered
        # The blocklist lookup is patched to return nothing so only path-based filtering is exercised.
        with unittest.mock.patch(
            "codeflash.discovery.functions_to_optimize.get_blocklisted_functions", return_value={}
        ):
            filtered, count = filter_functions(
                all_functions,
                tests_root=temp_dir,  # Same as module_root
                ignore_paths=[],
                project_root=temp_dir,
                module_root=temp_dir,  # Same as tests_root
            )

        # Strict check: only source_file should remain in filtered results
        assert set(filtered.keys()) == {source_file}, (
            f"Expected only source file in filtered results, got: {set(filtered.keys())}"
        )

        # Strict check: exactly these two functions should be present
        # (names are sorted so the comparison does not depend on discovery order)
        source_functions = sorted([fn.function_name for fn in filtered.get(source_file, [])])
        assert source_functions == ["calculate_sum", "process_data"], (
            f"Expected ['calculate_sum', 'process_data'], got {source_functions}"
        )

        # Strict check: exactly 2 functions remaining
        assert count == 2, f"Expected exactly 2 functions, got {count}"

        # Test Case 2: tests_root == project_root (another overlapping case)
        # Only the source file's functions are passed in this time.
        with unittest.mock.patch(
            "codeflash.discovery.functions_to_optimize.get_blocklisted_functions", return_value={}
        ):
            filtered2, count2 = filter_functions(
                {source_file: discovered_source[source_file]},
                tests_root=temp_dir,  # Same as project_root
                ignore_paths=[],
                project_root=temp_dir,
                module_root=temp_dir,
            )

        # Strict check: only source_file should remain
        assert set(filtered2.keys()) == {source_file}, (
            f"Expected only source file when tests_root == project_root, got: {set(filtered2.keys())}"
        )
        assert count2 == 2, f"Expected exactly 2 functions, got {count2}"
|
||||
|
||||
|
||||
def test_filter_functions_strict_string_matching():
    """Check that test-file detection matches whole patterns, not the bare substring 'test'.

    'contest.py', 'latest.py' and 'attestation.py' merely contain the letters
    'test' and must survive filtering, whereas 'utils.test.py' and
    'utils_test.py' follow real test-file naming and must be removed.
    """
    with tempfile.TemporaryDirectory() as temp_dir_str:
        temp_dir = Path(temp_dir_str)

        # Look-alikes: 'test' is only part of a longer word, so these must be kept.
        contest_file = temp_dir / "contest.py"
        contest_file.write_text("def run_contest(): return 1")
        latest_file = temp_dir / "latest.py"
        latest_file.write_text("def get_latest(): return 1")
        attestation_file = temp_dir / "attestation.py"
        attestation_file.write_text("def verify_attestation(): return 1")

        # Genuine test files ('.test.' and '_test.' patterns), so these must be dropped.
        actual_test_file = temp_dir / "utils.test.py"
        actual_test_file.write_text("def test_utils(): return 1")
        underscore_test_file = temp_dir / "utils_test.py"
        underscore_test_file.write_text("def test_stuff(): return 1")

        # Merge the per-file discovery results into a single mapping.
        all_functions = {}
        for path in (contest_file, latest_file, attestation_file, actual_test_file, underscore_test_file):
            all_functions.update(find_all_functions_in_file(path))

        # No blocklist, and tests_root overlapping the source roots to trigger pattern matching.
        blocklist_patch = unittest.mock.patch(
            "codeflash.discovery.functions_to_optimize.get_blocklisted_functions", return_value={}
        )
        with blocklist_patch:
            filtered, count = filter_functions(
                all_functions,
                tests_root=temp_dir,
                ignore_paths=[],
                project_root=temp_dir,
                module_root=temp_dir,
            )

        # Exactly the three look-alike files remain.
        expected_files = {contest_file, latest_file, attestation_file}
        remaining_files = set(filtered.keys())
        assert remaining_files == expected_files, f"Expected files {expected_files}, got {remaining_files}"

        # Each surviving file holds exactly its one function.
        expectations = (
            (contest_file, ["run_contest"]),
            (latest_file, ["get_latest"]),
            (attestation_file, ["verify_attestation"]),
        )
        for path, expected_names in expectations:
            found = [fn.function_name for fn in filtered[path]]
            assert found == expected_names, f"Expected {expected_names}, got {found}"

        # Three functions in total.
        assert count == 3, f"Expected exactly 3 functions, got {count}"
|
||||
|
||||
|
||||
def test_filter_functions_test_directory_patterns():
    """Check that only real test directories are filtered, not names that contain 'test'.

    Files under 'tests/', 'test/', '__tests__/' and a nested 'src/tests/' must
    be removed, while files under 'contest_results/' and 'latest_data/' must
    be kept.
    """
    with tempfile.TemporaryDirectory() as temp_dir_str:
        temp_dir = Path(temp_dir_str)

        # Directories whose names only contain 'test' inside a longer word: keep.
        contest_dir = temp_dir / "contest_results"
        contest_dir.mkdir()
        contest_file = contest_dir / "scores.py"
        contest_file.write_text("def get_scores(): return [1, 2, 3]")

        latest_dir = temp_dir / "latest_data"
        latest_dir.mkdir()
        latest_file = latest_dir / "data.py"
        latest_file.write_text("def load_data(): return {}")

        # Real test directories: 'tests', singular 'test' and '__tests__': drop.
        tests_dir = temp_dir / "tests"
        tests_dir.mkdir()
        tests_file = tests_dir / "test_main.py"
        tests_file.write_text("def test_main(): return True")

        test_dir = temp_dir / "test"
        test_dir.mkdir()
        test_file = test_dir / "test_utils.py"
        test_file.write_text("def test_utils(): return True")

        dunder_tests_dir = temp_dir / "__tests__"
        dunder_tests_dir.mkdir()
        dunder_file = dunder_tests_dir / "component.py"
        dunder_file.write_text("def test_component(): return True")

        # A test directory nested below a source directory: drop as well.
        src_dir = temp_dir / "src"
        src_dir.mkdir()
        nested_tests_dir = src_dir / "tests"
        nested_tests_dir.mkdir()
        nested_test_file = nested_tests_dir / "test_nested.py"
        nested_test_file.write_text("def test_nested(): return True")

        # Merge the per-file discovery results into a single mapping.
        all_functions = {}
        for path in (contest_file, latest_file, tests_file, test_file, dunder_file, nested_test_file):
            all_functions.update(find_all_functions_in_file(path))

        # No blocklist, and tests_root overlapping the source roots.
        blocklist_patch = unittest.mock.patch(
            "codeflash.discovery.functions_to_optimize.get_blocklisted_functions", return_value={}
        )
        with blocklist_patch:
            filtered, count = filter_functions(
                all_functions,
                tests_root=temp_dir,
                ignore_paths=[],
                project_root=temp_dir,
                module_root=temp_dir,
            )

        # Only the two files in non-test directories remain.
        expected_files = {contest_file, latest_file}
        remaining_files = set(filtered.keys())
        assert remaining_files == expected_files, f"Expected files {expected_files}, got {remaining_files}"

        # Each surviving file holds exactly its one function.
        expectations = (
            (contest_file, ["get_scores"]),
            (latest_file, ["load_data"]),
        )
        for path, expected_names in expectations:
            found = [fn.function_name for fn in filtered[path]]
            assert found == expected_names, f"Expected {expected_names}, got {found}"

        # Two functions in total.
        assert count == 2, f"Expected exactly 2 functions, got {count}"
|
||||
|
||||
|
||||
def test_filter_functions_non_overlapping_tests_root():
    """Check that directory-based filtering still applies when tests_root is a separate directory.

    With tests_root set to 'tests/' and module_root set to 'src/', files under
    'tests/' are removed while everything under 'src/' is kept, including a
    file named with the '.test.' pattern.
    """
    with tempfile.TemporaryDirectory() as temp_dir_str:
        temp_dir = Path(temp_dir_str)

        # Source tree with one ordinary module.
        src_dir = temp_dir / "src"
        src_dir.mkdir()
        source_file = src_dir / "utils.py"
        source_file.write_text("def process(): return 1")

        # A '.test.'-named file inside src/: expected to be kept in this mode,
        # because directory-based filtering takes precedence.
        test_in_src = src_dir / "helper.test.py"
        test_in_src.write_text("def helper_test(): return 1")

        # Dedicated tests directory, separate from src/.
        tests_dir = temp_dir / "tests"
        tests_dir.mkdir()
        test_file = tests_dir / "test_utils.py"
        test_file.write_text("def test_process(): return 1")

        # Merge the per-file discovery results into a single mapping.
        all_functions = {}
        for path in (source_file, test_in_src, test_file):
            all_functions.update(find_all_functions_in_file(path))

        # No blocklist; tests_root and module_root are distinct directories.
        blocklist_patch = unittest.mock.patch(
            "codeflash.discovery.functions_to_optimize.get_blocklisted_functions", return_value={}
        )
        with blocklist_patch:
            filtered, count = filter_functions(
                all_functions,
                tests_root=tests_dir,
                ignore_paths=[],
                project_root=temp_dir,
                module_root=src_dir,
            )

        # Both files under src/ remain; the one under tests/ is gone.
        expected_files = {source_file, test_in_src}
        remaining_files = set(filtered.keys())
        assert remaining_files == expected_files, f"Expected files {expected_files}, got {remaining_files}"

        # Each surviving file holds exactly its one function.
        expectations = (
            (source_file, ["process"]),
            (test_in_src, ["helper_test"]),
        )
        for path, expected_names in expectations:
            found = [fn.function_name for fn in filtered[path]]
            assert found == expected_names, f"Expected {expected_names}, got {found}"

        # Two functions in total.
        assert count == 2, f"Expected exactly 2 functions, got {count}"
|
||||
|
|
|
|||
1088
tests/test_languages/test_find_references.py
Normal file
1088
tests/test_languages/test_find_references.py
Normal file
File diff suppressed because it is too large
Load diff
Loading…
Reference in a new issue