fix: Use FunctionToOptimize field names consistently across JS code

- Fix field name mismatches: .name → .function_name, .start_line → .starting_line,
  .end_line → .ending_line, .start_col → .starting_col, .end_col → .ending_col
- Fix circular imports by creating function_types.py with FunctionParent
- Add lazy language registration via _ensure_languages_registered()
- Fix macOS symlink path resolution in ImportResolver
- Update all affected code and tests to use correct FunctionToOptimize attributes

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
misrasaurabh1 2026-02-02 11:32:39 -08:00
parent ac682b81dd
commit a5edb73b13
23 changed files with 224 additions and 199 deletions

View file

@ -155,22 +155,21 @@ def get_package_install_command(project_root: Path, package: str, dev: bool = Tr
if dev:
cmd.append("--save-dev")
return cmd
elif pkg_manager == JsPackageManager.YARN:
if pkg_manager == JsPackageManager.YARN:
cmd = ["yarn", "add", package]
if dev:
cmd.append("--dev")
return cmd
elif pkg_manager == JsPackageManager.BUN:
if pkg_manager == JsPackageManager.BUN:
cmd = ["bun", "add", package]
if dev:
cmd.append("--dev")
return cmd
else:
# Default to npm
cmd = ["npm", "install", package]
if dev:
cmd.append("--save-dev")
return cmd
# Default to npm
cmd = ["npm", "install", package]
if dev:
cmd.append("--save-dev")
return cmd
def init_js_project(language: ProjectLanguage) -> None:

View file

@ -549,7 +549,7 @@ def replace_function_definitions_for_language(
# Find the function in current code
func = None
for f in current_functions:
if func_name in (f.qualified_name, f.name):
if func_name in (f.qualified_name, f.function_name):
func = f
break
@ -557,7 +557,9 @@ def replace_function_definitions_for_language(
continue
# Extract just this function from the optimized code
optimized_func = _extract_function_from_code(lang_support, code_to_apply, func.name, module_abspath)
optimized_func = _extract_function_from_code(
lang_support, code_to_apply, func.function_name, module_abspath
)
if optimized_func:
new_code = lang_support.replace_function(new_code, func, optimized_func)
modified = True
@ -596,13 +598,13 @@ def _extract_function_from_code(
# file_path is needed for JS/TS to determine correct analyzer (TypeScript vs JavaScript)
functions = lang_support.discover_functions_from_source(source_code, file_path)
for func in functions:
if func.name == function_name:
if func.function_name == function_name:
# Extract the function's source using line numbers
# Use doc_start_line if available to include JSDoc/docstring
lines = source_code.splitlines(keepends=True)
effective_start = func.doc_start_line or func.start_line
if effective_start and func.end_line and effective_start <= len(lines):
func_lines = lines[effective_start - 1 : func.end_line]
effective_start = func.doc_start_line or func.starting_line
if effective_start and func.ending_line and effective_start <= len(lines):
func_lines = lines[effective_start - 1 : func.ending_line]
return "".join(func_lines)
except Exception as e:
logger.debug(f"Error extracting function {function_name}: {e}")

View file

@ -335,25 +335,8 @@ def _find_all_functions_via_language_support(file_path: Path) -> dict[Path, list
try:
lang_support = get_language_support(file_path)
criteria = FunctionFilterCriteria(require_return=True)
function_infos = lang_support.discover_functions(file_path, criteria)
ftos = []
for func_info in function_infos:
parents = [FunctionParent(p.name, p.type) for p in func_info.parents]
ftos.append(
FunctionToOptimize(
function_name=func_info.name,
file_path=func_info.file_path,
parents=parents,
starting_line=func_info.start_line,
ending_line=func_info.end_line,
starting_col=func_info.start_col,
ending_col=func_info.end_col,
is_async=func_info.is_async,
language=func_info.language.value,
)
)
functions[file_path] = ftos
# discover_functions already returns FunctionToOptimize objects
functions[file_path] = lang_support.discover_functions(file_path, criteria)
except Exception as e:
logger.debug(f"Failed to discover functions in {file_path}: {e}")

View file

@ -26,17 +26,6 @@ from codeflash.languages.base import (
TestInfo,
TestResult,
)
# Lazy import for FunctionInfo to avoid circular imports
def __getattr__(name: str):
if name == "FunctionInfo":
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
return FunctionToOptimize
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
from codeflash.languages.current import (
current_language,
current_language_support,
@ -46,11 +35,9 @@ from codeflash.languages.current import (
reset_current_language,
set_current_language,
)
from codeflash.languages.javascript import JavaScriptSupport, TypeScriptSupport # noqa: F401
# Import language support modules to trigger auto-registration
# This ensures all supported languages are available when this package is imported
from codeflash.languages.python import PythonSupport # noqa: F401
# Language support modules are imported lazily to avoid circular imports
# They get registered when first accessed via get_language_support()
from codeflash.languages.registry import (
detect_project_language,
get_language_support,
@ -70,6 +57,29 @@ from codeflash.languages.test_framework import (
set_current_test_framework,
)
# Lazy imports to avoid circular imports
def __getattr__(name: str):
if name == "FunctionInfo":
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
return FunctionToOptimize
if name == "JavaScriptSupport":
from codeflash.languages.javascript.support import JavaScriptSupport
return JavaScriptSupport
if name == "TypeScriptSupport":
from codeflash.languages.javascript.support import TypeScriptSupport
return TypeScriptSupport
if name == "PythonSupport":
from codeflash.languages.python.support import PythonSupport
return PythonSupport
msg = f"module {__name__!r} has no attribute {name!r}"
raise AttributeError(msg)
__all__ = [
"CodeContext",
"FunctionInfo",

View file

@ -17,7 +17,7 @@ if TYPE_CHECKING:
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.languages.language_enum import Language
from codeflash.models.models import FunctionParent
from codeflash.models.function_types import FunctionParent
# Backward compatibility aliases - ParentInfo is now FunctionParent
ParentInfo = FunctionParent
@ -30,7 +30,8 @@ def __getattr__(name: str) -> Any:
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
return FunctionToOptimize
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
msg = f"module {__name__!r} has no attribute {name!r}"
raise AttributeError(msg)
@dataclass

View file

@ -15,16 +15,16 @@ from __future__ import annotations
import logging
from dataclasses import dataclass, field
from pathlib import Path
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from pathlib import Path
from tree_sitter import Node
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.languages.treesitter_utils import ImportInfo, TreeSitterAnalyzer
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
logger = logging.getLogger(__name__)
@ -781,10 +781,7 @@ class ReferenceFinder:
"""
path_str = str(file_path)
for pattern in self.exclude_patterns:
if pattern in path_str:
return True
return False
return any(pattern in path_str for pattern in self.exclude_patterns)
def _read_file(self, file_path: Path) -> str | None:
"""Read a file's contents with caching.

View file

@ -44,7 +44,8 @@ class ImportResolver:
project_root: Root directory of the project.
"""
self.project_root = project_root
# Resolve to real path to handle macOS symlinks like /var -> /private/var
self.project_root = project_root.resolve()
self._resolution_cache: dict[tuple[Path, str], Path | None] = {}
def resolve_import(self, import_info: ImportInfo, source_file: Path) -> ResolvedImport | None:
@ -329,7 +330,7 @@ class MultiFileHelperFinder:
all_functions = analyzer.find_functions(source, include_methods=True)
target_func = None
for func in all_functions:
if func.name == function.name and func.start_line == function.start_line:
if func.name == function.function_name and func.start_line == function.starting_line:
target_func = func
break
@ -506,7 +507,7 @@ class MultiFileHelperFinder:
Dictionary mapping file paths to lists of helper functions.
"""
from codeflash.languages.base import FunctionToOptimize
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.languages.treesitter_utils import get_analyzer_for_file
if context.current_depth >= context.max_depth:

View file

@ -15,8 +15,7 @@ from codeflash.cli_cmds.console import logger
if TYPE_CHECKING:
from codeflash.code_utils.code_position import CodePosition
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
class TestingMode:

View file

@ -65,10 +65,10 @@ class JavaScriptLineProfiler:
lines = source.splitlines(keepends=True)
# Process functions in reverse order to preserve line numbers
for func in sorted(functions, key=lambda f: f.start_line, reverse=True):
for func in sorted(functions, key=lambda f: f.starting_line, reverse=True):
func_lines = self._instrument_function(func, lines, file_path)
start_idx = func.start_line - 1
end_idx = func.end_line
start_idx = func.starting_line - 1
end_idx = func.ending_line
lines = lines[:start_idx] + func_lines + lines[end_idx:]
instrumented_source = "".join(lines)
@ -183,7 +183,7 @@ if (__codeflash_save_interval__.unref) __codeflash_save_interval__.unref(); // D
Instrumented function lines.
"""
func_lines = lines[func.start_line - 1 : func.end_line]
func_lines = lines[func.starting_line - 1 : func.ending_line]
instrumented_lines = []
# Parse the function to find executable lines
@ -194,7 +194,7 @@ if (__codeflash_save_interval__.unref) __codeflash_save_interval__.unref(); // D
tree = analyzer.parse(source.encode("utf8"))
executable_lines = self._find_executable_lines(tree.root_node, source.encode("utf8"))
except Exception as e:
logger.warning("Failed to parse function %s: %s", func.name, e)
logger.warning("Failed to parse function %s: %s", func.function_name, e)
return func_lines
# Add profiling to each executable line
@ -203,7 +203,7 @@ if (__codeflash_save_interval__.unref) __codeflash_save_interval__.unref(); // D
for local_idx, line in enumerate(func_lines):
local_line_num = local_idx + 1 # 1-indexed within function
global_line_num = func.start_line + local_idx # Global line number in original file
global_line_num = func.starting_line + local_idx # Global line number in original file
stripped = line.strip()
# Add enterFunction() call after the opening brace of the function

View file

@ -241,7 +241,7 @@ class JavaScriptSupport:
# Match source functions to tests
for func in source_functions:
if func.name in imported_names or func.name in source:
if func.function_name in imported_names or func.function_name in source:
if func.qualified_name not in result:
result[func.qualified_name] = []
for test_name in test_functions:

View file

@ -63,13 +63,7 @@ def _find_jest_config(project_root: Path) -> Path | None:
"""
# Common Jest config file names, in order of preference
config_names = [
"jest.config.ts",
"jest.config.js",
"jest.config.mjs",
"jest.config.cjs",
"jest.config.json",
]
config_names = ["jest.config.ts", "jest.config.js", "jest.config.mjs", "jest.config.cjs", "jest.config.json"]
# First check the project root itself
for config_name in config_names:
@ -226,14 +220,7 @@ def _ensure_runtime_files(project_root: Path) -> None:
install_cmd = get_package_install_command(project_root, "codeflash", dev=True)
try:
result = subprocess.run(
install_cmd,
check=False,
cwd=project_root,
capture_output=True,
text=True,
timeout=120,
)
result = subprocess.run(install_cmd, check=False, cwd=project_root, capture_output=True, text=True, timeout=120)
if result.returncode == 0:
logger.debug(f"Installed codeflash using {install_cmd[0]}")
return

View file

@ -64,10 +64,10 @@ class JavaScriptTracer:
lines = source.splitlines(keepends=True)
# Process functions in reverse order to preserve line numbers
for func in sorted(functions, key=lambda f: f.start_line, reverse=True):
for func in sorted(functions, key=lambda f: f.starting_line, reverse=True):
instrumented = self._instrument_function(func, lines, file_path)
start_idx = func.start_line - 1
end_idx = func.end_line
start_idx = func.starting_line - 1
end_idx = func.ending_line
lines = lines[:start_idx] + instrumented + lines[end_idx:]
instrumented_source = "".join(lines)
@ -281,11 +281,11 @@ process.on('exit', () => {{
Instrumented function lines.
"""
func_lines = lines[func.start_line - 1 : func.end_line]
func_lines = lines[func.starting_line - 1 : func.ending_line]
func_text = "".join(func_lines)
# Detect function pattern
func_name = func.name
func_name = func.function_name
is_arrow = "=>" in func_text.split("\n")[0]
is_method = func.is_method
is_async = func.is_async

View file

@ -86,14 +86,7 @@ def _ensure_runtime_files(project_root: Path) -> None:
install_cmd = get_package_install_command(project_root, "codeflash", dev=True)
try:
result = subprocess.run(
install_cmd,
check=False,
cwd=project_root,
capture_output=True,
text=True,
timeout=120,
)
result = subprocess.run(install_cmd, check=False, cwd=project_root, capture_output=True, text=True, timeout=120)
if result.returncode == 0:
logger.debug(f"Installed codeflash using {install_cmd[0]}")
return

View file

@ -159,7 +159,7 @@ class PythonSupport:
try:
source = test_file.read_text()
# Check if function name appears in test file
if func.name in source:
if func.function_name in source:
result[func.qualified_name].append(
TestInfo(test_name=test_file.stem, test_file=test_file, test_class=None)
)
@ -289,7 +289,7 @@ class PythonSupport:
)
except Exception as e:
logger.warning("Failed to find helpers for %s: %s", function.name, e)
logger.warning("Failed to find helpers for %s: %s", function.function_name, e)
return helpers
@ -389,10 +389,10 @@ class PythonSupport:
line=ref.line,
column=ref.column,
end_line=ref.line,
end_column=ref.column + len(function.name),
end_column=ref.column + len(function.function_name),
context=context.strip(),
reference_type="call",
import_name=function.name,
import_name=function.function_name,
caller_function=caller_function,
)
)
@ -400,7 +400,7 @@ class PythonSupport:
return result
except Exception as e:
logger.warning("Failed to find references for %s: %s", function.name, e)
logger.warning("Failed to find references for %s: %s", function.function_name, e)
return []
# === Code Transformation ===
@ -433,7 +433,7 @@ class PythonSupport:
preexisting_objects=set(),
)
except Exception as e:
logger.warning("Failed to replace function %s: %s", function.name, e)
logger.warning("Failed to replace function %s: %s", function.function_name, e)
return source
def format_code(self, source: str, file_path: Path | None = None) -> str:

View file

@ -30,6 +30,33 @@ _LANGUAGE_REGISTRY: dict[Language, type[LanguageSupport]] = {}
# Cache of instantiated language support objects
_SUPPORT_CACHE: dict[Language, LanguageSupport] = {}
# Flag to track if language modules have been imported
_languages_registered = False
def _ensure_languages_registered() -> None:
"""Ensure all language support modules are imported and registered.
This lazily imports the language support modules to avoid circular imports
at module load time. The imports trigger the @register_language decorators
which populate the registries.
"""
global _languages_registered
if _languages_registered:
return
# Import support modules to trigger registration
# These imports are deferred to avoid circular imports
import contextlib
with contextlib.suppress(ImportError):
from codeflash.languages.python import support as _
with contextlib.suppress(ImportError):
from codeflash.languages.javascript import support as _ # noqa: F401
_languages_registered = True
class UnsupportedLanguageError(Exception):
"""Raised when attempting to use an unsupported language."""
@ -123,6 +150,10 @@ def get_language_support(identifier: Path | Language | str) -> LanguageSupport:
Raises:
UnsupportedLanguageError: If the language is not supported.
Note:
This function lazily imports language support modules on first call
to avoid circular import issues at module load time.
Example:
# By file path
lang = get_language_support(Path("example.py"))
@ -137,6 +168,7 @@ def get_language_support(identifier: Path | Language | str) -> LanguageSupport:
lang = get_language_support("python")
"""
_ensure_languages_registered()
language: Language | None = None
if isinstance(identifier, Language):
@ -179,6 +211,7 @@ _FRAMEWORK_CACHE: dict[str, LanguageSupport] = {}
def get_language_support_by_common_formatters(formatter_cmd: str | list[str]) -> LanguageSupport | None:
_ensure_languages_registered()
language: Language | None = None
if isinstance(formatter_cmd, str):
formatter_cmd = [formatter_cmd]
@ -263,6 +296,7 @@ def detect_project_language(project_root: Path, module_root: Path) -> Language:
UnsupportedLanguageError: If no supported language is detected.
"""
_ensure_languages_registered()
extension_counts: dict[str, int] = {}
# Count files by extension
@ -290,6 +324,7 @@ def get_supported_languages() -> list[str]:
List of language name strings.
"""
_ensure_languages_registered()
return [lang.value for lang in _LANGUAGE_REGISTRY]
@ -300,6 +335,7 @@ def get_supported_extensions() -> list[str]:
List of extension strings (with leading dots).
"""
_ensure_languages_registered()
return list(_EXTENSION_REGISTRY.keys())
@ -325,10 +361,12 @@ def clear_registry() -> None:
Primarily useful for testing.
"""
global _languages_registered
_EXTENSION_REGISTRY.clear()
_LANGUAGE_REGISTRY.clear()
_SUPPORT_CACHE.clear()
_FRAMEWORK_CACHE.clear()
_languages_registered = False
def clear_cache() -> None:

View file

@ -0,0 +1,18 @@
"""Simple function-related types with no dependencies.
This module contains basic types used for function representation.
It is intentionally kept dependency-free to avoid circular imports.
"""
from __future__ import annotations
from pydantic.dataclasses import dataclass
@dataclass(frozen=True)
class FunctionParent:
name: str
type: str
def __str__(self) -> str:
return f"{self.type}:{self.name}"

View file

@ -598,13 +598,8 @@ class CodePosition:
col_no: int
@dataclass(frozen=True)
class FunctionParent:
name: str
type: str
def __str__(self) -> str:
return f"{self.type}:{self.name}"
# Re-export FunctionParent for backward compatibility
from codeflash.models.function_types import FunctionParent # noqa: E402
class OriginalCodeBaseline(BaseModel):

View file

@ -90,7 +90,7 @@ const multiply = (a, b) => a * b;
functions = js_support.discover_functions(file_path)
assert len(functions) == 1
func = functions[0]
assert func.name == "multiply"
assert func.function_name == "multiply"
context = js_support.extract_code_context(func, temp_project, temp_project)
@ -268,7 +268,7 @@ class CacheManager {
file_path.write_text(code, encoding="utf-8")
functions = js_support.discover_functions(file_path)
get_or_compute = next(f for f in functions if f.name == "getOrCompute")
get_or_compute = next(f for f in functions if f.function_name == "getOrCompute")
context = js_support.extract_code_context(get_or_compute, temp_project, temp_project)
@ -370,7 +370,7 @@ function validateUserData(data, validators) {
file_path.write_text(code, encoding="utf-8")
functions = js_support.discover_functions(file_path)
func = next(f for f in functions if f.name == "validateUserData")
func = next(f for f in functions if f.function_name == "validateUserData")
context = js_support.extract_code_context(func, temp_project, temp_project)
@ -466,7 +466,7 @@ async function fetchWithRetry(endpoint, options = {}) {
file_path.write_text(code, encoding="utf-8")
functions = js_support.discover_functions(file_path)
func = next(f for f in functions if f.name == "fetchWithRetry")
func = next(f for f in functions if f.function_name == "fetchWithRetry")
context = js_support.extract_code_context(func, temp_project, temp_project)
@ -615,7 +615,7 @@ function processUserInput(rawInput) {
file_path.write_text(code, encoding="utf-8")
functions = js_support.discover_functions(file_path)
process_func = next(f for f in functions if f.name == "processUserInput")
process_func = next(f for f in functions if f.function_name == "processUserInput")
context = js_support.extract_code_context(process_func, temp_project, temp_project)
@ -670,7 +670,7 @@ function generateReport(data) {
file_path.write_text(code, encoding="utf-8")
functions = js_support.discover_functions(file_path)
report_func = next(f for f in functions if f.name == "generateReport")
report_func = next(f for f in functions if f.function_name == "generateReport")
context = js_support.extract_code_context(report_func, temp_project, temp_project)
@ -768,7 +768,7 @@ class Graph {
file_path.write_text(code, encoding="utf-8")
functions = js_support.discover_functions(file_path)
topo_sort = next(f for f in functions if f.name == "topologicalSort")
topo_sort = next(f for f in functions if f.function_name == "topologicalSort")
context = js_support.extract_code_context(topo_sort, temp_project, temp_project)
@ -843,7 +843,7 @@ class MainClass {
file_path.write_text(code, encoding="utf-8")
functions = js_support.discover_functions(file_path)
main_method = next(f for f in functions if f.name == "mainMethod" and f.class_name == "MainClass")
main_method = next(f for f in functions if f.function_name == "mainMethod" and f.class_name == "MainClass")
context = js_support.extract_code_context(main_method, temp_project, temp_project)
@ -899,7 +899,7 @@ module.exports = { sortFromAnotherFile };
main_path.write_text(main_code, encoding="utf-8")
functions = js_support.discover_functions(main_path)
main_func = next(f for f in functions if f.name == "sortFromAnotherFile")
main_func = next(f for f in functions if f.function_name == "sortFromAnotherFile")
context = js_support.extract_code_context(main_func, temp_project, temp_project)
@ -952,7 +952,7 @@ export { processNumber };
main_path.write_text(main_code, encoding="utf-8")
functions = js_support.discover_functions(main_path)
process_func = next(f for f in functions if f.name == "processNumber")
process_func = next(f for f in functions if f.function_name == "processNumber")
context = js_support.extract_code_context(process_func, temp_project, temp_project)
@ -1020,7 +1020,7 @@ export { handleUserInput };
main_path.write_text(main_code, encoding="utf-8")
functions = js_support.discover_functions(main_path)
handle_func = next(f for f in functions if f.name == "handleUserInput")
handle_func = next(f for f in functions if f.function_name == "handleUserInput")
context = js_support.extract_code_context(handle_func, temp_project, temp_project)
@ -1161,7 +1161,7 @@ class TypedCache<T> {
file_path.write_text(code, encoding="utf-8")
functions = ts_support.discover_functions(file_path)
get_method = next(f for f in functions if f.name == "get")
get_method = next(f for f in functions if f.function_name == "get")
context = ts_support.extract_code_context(get_method, temp_project, temp_project)
@ -1247,7 +1247,7 @@ export { createUser };
service_path.write_text(service_code, encoding="utf-8")
functions = ts_support.discover_functions(service_path)
func = next(f for f in functions if f.name == "createUser")
func = next(f for f in functions if f.function_name == "createUser")
context = ts_support.extract_code_context(func, temp_project, temp_project)
@ -1331,7 +1331,7 @@ function isOdd(n) {
file_path.write_text(code, encoding="utf-8")
functions = js_support.discover_functions(file_path)
is_even = next(f for f in functions if f.name == "isEven")
is_even = next(f for f in functions if f.function_name == "isEven")
context = js_support.extract_code_context(is_even, temp_project, temp_project)
@ -1393,7 +1393,7 @@ function collectAllValues(root) {
file_path.write_text(code, encoding="utf-8")
functions = js_support.discover_functions(file_path)
collect_func = next(f for f in functions if f.name == "collectAllValues")
collect_func = next(f for f in functions if f.function_name == "collectAllValues")
context = js_support.extract_code_context(collect_func, temp_project, temp_project)
@ -1458,7 +1458,7 @@ async function fetchUserProfile(userId) {
file_path.write_text(code, encoding="utf-8")
functions = js_support.discover_functions(file_path)
profile_func = next(f for f in functions if f.name == "fetchUserProfile")
profile_func = next(f for f in functions if f.function_name == "fetchUserProfile")
context = js_support.extract_code_context(profile_func, temp_project, temp_project)
@ -1513,7 +1513,7 @@ module.exports = { Counter };
file_path.write_text(original_source, encoding="utf-8")
functions = js_support.discover_functions(file_path)
increment_func = next(fn for fn in functions if fn.name == "increment")
increment_func = next(fn for fn in functions if fn.function_name == "increment")
# Step 1: Extract code context
context = js_support.extract_code_context(increment_func, temp_project, temp_project)
@ -1635,7 +1635,7 @@ function* fibonacci(limit) {
file_path.write_text(code, encoding="utf-8")
functions = js_support.discover_functions(file_path)
range_func = next(f for f in functions if f.name == "range")
range_func = next(f for f in functions if f.function_name == "range")
context = js_support.extract_code_context(range_func, temp_project, temp_project)
@ -1772,7 +1772,7 @@ class Calculator {
functions = js_support.discover_functions(file_path)
for func in functions:
if func.name != "constructor":
if func.function_name != "constructor":
context = js_support.extract_code_context(func, temp_project, temp_project)
is_valid = js_support.validate_syntax(context.target_code)
assert is_valid is True, f"Invalid syntax for {func.name}:\n{context.target_code}"

View file

@ -1715,7 +1715,7 @@ module.exports = { Calculator };
functions = js_support.discover_functions(source_file)
# Check qualified names include class
add_func = next((f for f in functions if f.name == "add"), None)
add_func = next((f for f in functions if f.function_name == "add"), None)
assert add_func is not None
assert add_func.class_name == "Calculator"

View file

@ -39,7 +39,7 @@ class TestCodeExtractorCJS:
calculator_file = cjs_project / "calculator.js"
functions = js_support.discover_functions(calculator_file)
method_names = {f.name for f in functions}
method_names = {f.function_name for f in functions}
expected_methods = {"calculateCompoundInterest", "permutation", "quickAdd"}
assert method_names == expected_methods, f"Expected methods {expected_methods}, got {method_names}"
@ -51,15 +51,15 @@ class TestCodeExtractorCJS:
for func in functions:
# All methods should belong to Calculator class
assert func.is_method is True, f"{func.name} should be a method"
assert func.class_name == "Calculator", f"{func.name} should belong to Calculator, got {func.class_name}"
assert func.is_method is True, f"{func.function_name} should be a method"
assert func.class_name == "Calculator", f"{func.function_name} should belong to Calculator, got {func.class_name}"
def test_extract_permutation_code(self, js_support, cjs_project):
"""Test permutation method code extraction."""
calculator_file = cjs_project / "calculator.js"
functions = js_support.discover_functions(calculator_file)
permutation_func = next(f for f in functions if f.name == "permutation")
permutation_func = next(f for f in functions if f.function_name == "permutation")
context = js_support.extract_code_context(
function=permutation_func, project_root=cjs_project, module_root=cjs_project
@ -95,7 +95,7 @@ class Calculator {
calculator_file = cjs_project / "calculator.js"
functions = js_support.discover_functions(calculator_file)
permutation_func = next(f for f in functions if f.name == "permutation")
permutation_func = next(f for f in functions if f.function_name == "permutation")
context = js_support.extract_code_context(
function=permutation_func, project_root=cjs_project, module_root=cjs_project
@ -136,7 +136,7 @@ function factorial(n) {
calculator_file = cjs_project / "calculator.js"
functions = js_support.discover_functions(calculator_file)
compound_func = next(f for f in functions if f.name == "calculateCompoundInterest")
compound_func = next(f for f in functions if f.function_name == "calculateCompoundInterest")
context = js_support.extract_code_context(
function=compound_func, project_root=cjs_project, module_root=cjs_project
@ -182,7 +182,7 @@ class Calculator {
calculator_file = cjs_project / "calculator.js"
functions = js_support.discover_functions(calculator_file)
compound_func = next(f for f in functions if f.name == "calculateCompoundInterest")
compound_func = next(f for f in functions if f.function_name == "calculateCompoundInterest")
context = js_support.extract_code_context(
function=compound_func, project_root=cjs_project, module_root=cjs_project
@ -266,7 +266,7 @@ function validateInput(value, name) {
calculator_file = cjs_project / "calculator.js"
functions = js_support.discover_functions(calculator_file)
compound_func = next(f for f in functions if f.name == "calculateCompoundInterest")
compound_func = next(f for f in functions if f.function_name == "calculateCompoundInterest")
context = js_support.extract_code_context(
function=compound_func, project_root=cjs_project, module_root=cjs_project
@ -287,7 +287,7 @@ function validateInput(value, name) {
calculator_file = cjs_project / "calculator.js"
functions = js_support.discover_functions(calculator_file)
quick_add_func = next(f for f in functions if f.name == "quickAdd")
quick_add_func = next(f for f in functions if f.function_name == "quickAdd")
context = js_support.extract_code_context(
function=quick_add_func, project_root=cjs_project, module_root=cjs_project
@ -352,7 +352,7 @@ class TestCodeExtractorESM:
calculator_file = esm_project / "calculator.js"
functions = js_support.discover_functions(calculator_file)
method_names = {f.name for f in functions}
method_names = {f.function_name for f in functions}
# Should find same methods as CJS version
expected_methods = {"calculateCompoundInterest", "permutation", "quickAdd"}
@ -363,7 +363,7 @@ class TestCodeExtractorESM:
calculator_file = esm_project / "calculator.js"
functions = js_support.discover_functions(calculator_file)
permutation_func = next(f for f in functions if f.name == "permutation")
permutation_func = next(f for f in functions if f.function_name == "permutation")
context = js_support.extract_code_context(
function=permutation_func, project_root=esm_project, module_root=esm_project
@ -413,7 +413,7 @@ export function factorial(n) {
calculator_file = esm_project / "calculator.js"
functions = js_support.discover_functions(calculator_file)
compound_func = next(f for f in functions if f.name == "calculateCompoundInterest")
compound_func = next(f for f in functions if f.function_name == "calculateCompoundInterest")
context = js_support.extract_code_context(
function=compound_func, project_root=esm_project, module_root=esm_project
@ -539,7 +539,7 @@ class TestCodeExtractorTypeScript:
calculator_file = ts_project / "calculator.ts"
functions = ts_support.discover_functions(calculator_file)
method_names = {f.name for f in functions}
method_names = {f.function_name for f in functions}
# TypeScript has additional getHistory method
expected_methods = {"calculateCompoundInterest", "permutation", "getHistory", "quickAdd"}
@ -550,7 +550,7 @@ class TestCodeExtractorTypeScript:
calculator_file = ts_project / "calculator.ts"
functions = ts_support.discover_functions(calculator_file)
permutation_func = next(f for f in functions if f.name == "permutation")
permutation_func = next(f for f in functions if f.function_name == "permutation")
context = ts_support.extract_code_context(
function=permutation_func, project_root=ts_project, module_root=ts_project
@ -603,7 +603,7 @@ export function factorial(n: number): number {
calculator_file = ts_project / "calculator.ts"
functions = ts_support.discover_functions(calculator_file)
compound_func = next(f for f in functions if f.name == "calculateCompoundInterest")
compound_func = next(f for f in functions if f.function_name == "calculateCompoundInterest")
context = ts_support.extract_code_context(
function=compound_func, project_root=ts_project, module_root=ts_project
@ -712,7 +712,7 @@ module.exports = { standalone };
test_file.write_text(source)
functions = js_support.discover_functions(test_file)
func = next(f for f in functions if f.name == "standalone")
func = next(f for f in functions if f.function_name == "standalone")
context = js_support.extract_code_context(function=func, project_root=tmp_path, module_root=tmp_path)
@ -745,7 +745,7 @@ module.exports = { processArray };
test_file.write_text(source)
functions = js_support.discover_functions(test_file)
func = next(f for f in functions if f.name == "processArray")
func = next(f for f in functions if f.function_name == "processArray")
context = js_support.extract_code_context(function=func, project_root=tmp_path, module_root=tmp_path)
@ -780,7 +780,7 @@ module.exports = { fibonacci };
test_file.write_text(source)
functions = js_support.discover_functions(test_file)
func = next(f for f in functions if f.name == "fibonacci")
func = next(f for f in functions if f.function_name == "fibonacci")
context = js_support.extract_code_context(function=func, project_root=tmp_path, module_root=tmp_path)
@ -813,7 +813,7 @@ module.exports = { processValue };
test_file.write_text(source)
functions = js_support.discover_functions(test_file)
func = next(f for f in functions if f.name == "processValue")
func = next(f for f in functions if f.function_name == "processValue")
context = js_support.extract_code_context(function=func, project_root=tmp_path, module_root=tmp_path)
@ -871,7 +871,7 @@ module.exports = { Counter };
test_file.write_text(source)
functions = js_support.discover_functions(test_file)
increment_func = next(f for f in functions if f.name == "increment")
increment_func = next(f for f in functions if f.function_name == "increment")
context = js_support.extract_code_context(function=increment_func, project_root=tmp_path, module_root=tmp_path)
@ -910,7 +910,7 @@ module.exports = { MathUtils };
test_file.write_text(source)
functions = js_support.discover_functions(test_file)
add_func = next(f for f in functions if f.name == "add")
add_func = next(f for f in functions if f.function_name == "add")
context = js_support.extract_code_context(function=add_func, project_root=tmp_path, module_root=tmp_path)
@ -948,7 +948,7 @@ export { User };
test_file.write_text(source)
functions = ts_support.discover_functions(test_file)
get_name_func = next(f for f in functions if f.name == "getName")
get_name_func = next(f for f in functions if f.function_name == "getName")
context = ts_support.extract_code_context(function=get_name_func, project_root=tmp_path, module_root=tmp_path)
@ -989,7 +989,7 @@ export { Config };
test_file.write_text(source)
functions = ts_support.discover_functions(test_file)
get_url_func = next(f for f in functions if f.name == "getUrl")
get_url_func = next(f for f in functions if f.function_name == "getUrl")
context = ts_support.extract_code_context(function=get_url_func, project_root=tmp_path, module_root=tmp_path)
@ -1030,7 +1030,7 @@ module.exports = { Logger };
test_file.write_text(source)
functions = js_support.discover_functions(test_file)
get_prefix_func = next(f for f in functions if f.name == "getPrefix")
get_prefix_func = next(f for f in functions if f.function_name == "getPrefix")
context = js_support.extract_code_context(function=get_prefix_func, project_root=tmp_path, module_root=tmp_path)
@ -1072,7 +1072,7 @@ module.exports = { Factory };
test_file.write_text(source)
functions = js_support.discover_functions(test_file)
create_func = next(f for f in functions if f.name == "create")
create_func = next(f for f in functions if f.function_name == "create")
context = js_support.extract_code_context(function=create_func, project_root=tmp_path, module_root=tmp_path)
@ -1114,19 +1114,20 @@ class TestCodeExtractorIntegration:
calculator_file = cjs_project / "calculator.js"
functions = js_support.discover_functions(calculator_file)
target = next(f for f in functions if f.name == "permutation")
target = next(f for f in functions if f.function_name == "permutation")
parents = [FunctionParent(name=p.name, type=p.type) for p in target.parents]
func = FunctionToOptimize(
function_name=target.name,
function_name=target.function_name,
file_path=target.file_path,
parents=parents,
starting_line=target.start_line,
ending_line=target.end_line,
starting_col=target.start_col,
ending_col=target.end_col,
starting_line=target.starting_line,
ending_line=target.ending_line,
starting_col=target.starting_col,
ending_col=target.ending_col,
is_async=target.is_async,
is_method=target.is_method,
language=target.language,
)
@ -1223,7 +1224,7 @@ export { distance };
test_file.write_text(source)
functions = ts_support.discover_functions(test_file)
distance_func = next(f for f in functions if f.name == "distance")
distance_func = next(f for f in functions if f.function_name == "distance")
context = ts_support.extract_code_context(function=distance_func, project_root=tmp_path, module_root=tmp_path)
@ -1267,7 +1268,7 @@ export { processStatus };
test_file.write_text(source)
functions = ts_support.discover_functions(test_file)
process_func = next(f for f in functions if f.name == "processStatus")
process_func = next(f for f in functions if f.function_name == "processStatus")
context = ts_support.extract_code_context(function=process_func, project_root=tmp_path, module_root=tmp_path)
@ -1304,7 +1305,7 @@ export { compute };
test_file.write_text(source)
functions = ts_support.discover_functions(test_file)
compute_func = next(f for f in functions if f.name == "compute")
compute_func = next(f for f in functions if f.function_name == "compute")
context = ts_support.extract_code_context(function=compute_func, project_root=tmp_path, module_root=tmp_path)
@ -1348,7 +1349,7 @@ export { Service };
test_file.write_text(source)
functions = ts_support.discover_functions(test_file)
get_timeout_func = next(f for f in functions if f.name == "getTimeout")
get_timeout_func = next(f for f in functions if f.function_name == "getTimeout")
context = ts_support.extract_code_context(
function=get_timeout_func, project_root=tmp_path, module_root=tmp_path
@ -1381,7 +1382,7 @@ export { add };
test_file.write_text(source)
functions = ts_support.discover_functions(test_file)
add_func = next(f for f in functions if f.name == "add")
add_func = next(f for f in functions if f.function_name == "add")
context = ts_support.extract_code_context(function=add_func, project_root=tmp_path, module_root=tmp_path)
@ -1414,7 +1415,7 @@ export { createRect };
test_file.write_text(source)
functions = ts_support.discover_functions(test_file)
create_rect_func = next(f for f in functions if f.name == "createRect")
create_rect_func = next(f for f in functions if f.function_name == "createRect")
context = ts_support.extract_code_context(
function=create_rect_func, project_root=tmp_path, module_root=tmp_path
@ -1462,7 +1463,7 @@ export { calculateDistance };
""")
functions = ts_support.discover_functions(geometry_file)
calc_distance_func = next(f for f in functions if f.name == "calculateDistance")
calc_distance_func = next(f for f in functions if f.function_name == "calculateDistance")
context = ts_support.extract_code_context(
function=calc_distance_func, project_root=ts_types_project, module_root=ts_types_project
@ -1515,7 +1516,7 @@ export { greetUser };
test_file.write_text(source)
functions = ts_support.discover_functions(test_file)
greet_func = next(f for f in functions if f.name == "greetUser")
greet_func = next(f for f in functions if f.function_name == "greetUser")
context = ts_support.extract_code_context(function=greet_func, project_root=tmp_path, module_root=tmp_path)

View file

@ -711,7 +711,7 @@ module.exports = { targetFunction, otherFunction };
file_path.write_text(original_source, encoding="utf-8")
functions = js_support.discover_functions(file_path)
target_func = next(f for f in functions if f.name == "targetFunction")
target_func = next(f for f in functions if f.function_name == "targetFunction")
optimized_code = """\
function targetFunction(x) {
@ -763,7 +763,7 @@ class Calculator {
file_path.write_text(original_source, encoding="utf-8")
functions = js_support.discover_functions(file_path)
add_method = next(f for f in functions if f.name == "add")
add_method = next(f for f in functions if f.function_name == "add")
# Optimized version provided in class context
optimized_code = """\
@ -826,7 +826,7 @@ class DataProcessor {
file_path.write_text(original_source, encoding="utf-8")
functions = js_support.discover_functions(file_path)
process_method = next(f for f in functions if f.name == "process")
process_method = next(f for f in functions if f.function_name == "process")
optimized_code = """\
class DataProcessor {
@ -948,7 +948,7 @@ class Cache {
file_path.write_text(original_source, encoding="utf-8")
functions = js_support.discover_functions(file_path)
get_method = next(f for f in functions if f.name == "get")
get_method = next(f for f in functions if f.function_name == "get")
optimized_code = """\
class Cache {
@ -1050,7 +1050,7 @@ class ApiClient {
file_path.write_text(original_source, encoding="utf-8")
functions = js_support.discover_functions(file_path)
get_method = next(f for f in functions if f.name == "get")
get_method = next(f for f in functions if f.function_name == "get")
optimized_code = """\
class ApiClient {
@ -1181,7 +1181,7 @@ class Container<T> {
file_path.write_text(original_source, encoding="utf-8")
functions = ts_support.discover_functions(file_path)
get_all_method = next(f for f in functions if f.name == "getAll")
get_all_method = next(f for f in functions if f.function_name == "getAll")
optimized_code = """\
class Container<T> {
@ -1234,7 +1234,7 @@ function createUser(name: string, email: string): User {
file_path.write_text(original_source, encoding="utf-8")
functions = ts_support.discover_functions(file_path)
func = next(f for f in functions if f.name == "createUser")
func = next(f for f in functions if f.function_name == "createUser")
optimized_code = """\
function createUser(name: string, email: string): User {
@ -1289,7 +1289,7 @@ function processItems(items) {
file_path.write_text(original_source, encoding="utf-8")
functions = js_support.discover_functions(file_path)
process_func = next(f for f in functions if f.name == "processItems")
process_func = next(f for f in functions if f.function_name == "processItems")
optimized_code = """\
function processItems(items) {
@ -1336,7 +1336,7 @@ class MathUtils {
# First replacement: sum method
functions = js_support.discover_functions(file_path)
sum_method = next(f for f in functions if f.name == "sum")
sum_method = next(f for f in functions if f.function_name == "sum")
optimized_sum = """\
class MathUtils {
@ -1554,7 +1554,7 @@ module.exports = { main, helper };
file_path.write_text(original_source, encoding="utf-8")
functions = js_support.discover_functions(file_path)
main_func = next(f for f in functions if f.name == "main")
main_func = next(f for f in functions if f.function_name == "main")
optimized_code = """\
function main(data) {
@ -1597,7 +1597,7 @@ export function main(data) {
file_path.write_text(original_source, encoding="utf-8")
functions = js_support.discover_functions(file_path)
main_func = next(f for f in functions if f.name == "main")
main_func = next(f for f in functions if f.function_name == "main")
optimized_code = """\
export function main(data) {
@ -1756,7 +1756,7 @@ export class DataProcessor<T> {
# find function
target_func_info = None
for func in functions:
if func.name == target_func and func.parents[0].name == parent_class:
if func.function_name == target_func and func.parents[0].name == parent_class:
target_func_info = func
break
assert target_func_info is not None

View file

@ -113,19 +113,20 @@ def test_js_replcement() -> None:
functions = js_support.discover_functions(main_file)
target = None
for func in functions:
if func.name == "calculateStats":
if func.function_name == "calculateStats":
target = func
break
assert target is not None
func = FunctionToOptimize(
function_name=target.name,
function_name=target.function_name,
file_path=target.file_path,
parents=target.parents,
starting_line=target.start_line,
ending_line=target.end_line,
starting_col=target.start_col,
ending_col=target.end_col,
starting_line=target.starting_line,
ending_line=target.ending_line,
starting_col=target.starting_col,
ending_col=target.ending_col,
is_async=target.is_async,
is_method=target.is_method,
language=target.language,
)
test_config = TestConfig(

View file

@ -136,7 +136,7 @@ function add(a, b) {
functions = js_analyzer.find_functions(code)
assert len(functions) == 1
assert functions[0].function_name == "add"
assert functions[0].name == "add"
assert functions[0].is_arrow is False
assert functions[0].is_async is False
assert functions[0].is_method is False
@ -151,7 +151,7 @@ const add = (a, b) => {
functions = js_analyzer.find_functions(code)
assert len(functions) == 1
assert functions[0].function_name == "add"
assert functions[0].name == "add"
assert functions[0].is_arrow is True
def test_find_arrow_function_concise(self, js_analyzer):
@ -160,7 +160,7 @@ const add = (a, b) => {
functions = js_analyzer.find_functions(code)
assert len(functions) == 1
assert functions[0].function_name == "double"
assert functions[0].name == "double"
assert functions[0].is_arrow is True
def test_find_async_function(self, js_analyzer):
@ -173,7 +173,7 @@ async function fetchData(url) {
functions = js_analyzer.find_functions(code)
assert len(functions) == 1
assert functions[0].function_name == "fetchData"
assert functions[0].name == "fetchData"
assert functions[0].is_async is True
def test_find_class_methods(self, js_analyzer):
@ -188,7 +188,7 @@ class Calculator {
functions = js_analyzer.find_functions(code, include_methods=True)
assert len(functions) == 1
assert functions[0].function_name == "add"
assert functions[0].name == "add"
assert functions[0].is_method is True
assert functions[0].class_name == "Calculator"
@ -208,7 +208,7 @@ function standalone() {
functions = js_analyzer.find_functions(code, include_methods=False)
assert len(functions) == 1
assert functions[0].function_name == "standalone"
assert functions[0].name == "standalone"
def test_exclude_arrow_functions(self, js_analyzer):
"""Test excluding arrow functions."""
@ -222,7 +222,7 @@ const arrow = () => 2;
functions = js_analyzer.find_functions(code, include_arrow_functions=False)
assert len(functions) == 1
assert functions[0].function_name == "regular"
assert functions[0].name == "regular"
def test_find_generator_function(self, js_analyzer):
"""Test finding generator functions."""
@ -235,7 +235,7 @@ function* numberGenerator() {
functions = js_analyzer.find_functions(code)
assert len(functions) == 1
assert functions[0].function_name == "numberGenerator"
assert functions[0].name == "numberGenerator"
assert functions[0].is_generator is True
def test_function_line_numbers(self, js_analyzer):
@ -291,7 +291,7 @@ function named() {
functions = js_analyzer.find_functions(code, require_name=True)
assert len(functions) == 1
assert functions[0].function_name == "named"
assert functions[0].name == "named"
def test_function_expression_in_variable(self, js_analyzer):
"""Test function expression assigned to variable."""
@ -303,7 +303,7 @@ const add = function(a, b) {
functions = js_analyzer.find_functions(code)
assert len(functions) == 1
assert functions[0].function_name == "add"
assert functions[0].name == "add"
class TestFindImports:
@ -515,7 +515,7 @@ function add(a: number, b: number): number {
functions = ts_analyzer.find_functions(code)
assert len(functions) == 1
assert functions[0].function_name == "add"
assert functions[0].name == "add"
def test_find_interface_method(self, ts_analyzer):
"""Test that interface methods are not found (they're declarations)."""
@ -544,4 +544,4 @@ function identity<T>(value: T): T {
functions = ts_analyzer.find_functions(code)
assert len(functions) == 1
assert functions[0].function_name == "identity"
assert functions[0].name == "identity"