fix: Use FunctionToOptimize field names consistently across JS code
- Fix field name mismatches: .name → .function_name, .start_line → .starting_line, .end_line → .ending_line, .start_col → .starting_col, .end_col → .ending_col - Fix circular imports by creating function_types.py with FunctionParent - Add lazy language registration via _ensure_languages_registered() - Fix macOS symlink path resolution in ImportResolver - Update all affected code and tests to use correct FunctionToOptimize attributes Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
parent
ac682b81dd
commit
a5edb73b13
23 changed files with 224 additions and 199 deletions
|
|
@ -155,22 +155,21 @@ def get_package_install_command(project_root: Path, package: str, dev: bool = Tr
|
|||
if dev:
|
||||
cmd.append("--save-dev")
|
||||
return cmd
|
||||
elif pkg_manager == JsPackageManager.YARN:
|
||||
if pkg_manager == JsPackageManager.YARN:
|
||||
cmd = ["yarn", "add", package]
|
||||
if dev:
|
||||
cmd.append("--dev")
|
||||
return cmd
|
||||
elif pkg_manager == JsPackageManager.BUN:
|
||||
if pkg_manager == JsPackageManager.BUN:
|
||||
cmd = ["bun", "add", package]
|
||||
if dev:
|
||||
cmd.append("--dev")
|
||||
return cmd
|
||||
else:
|
||||
# Default to npm
|
||||
cmd = ["npm", "install", package]
|
||||
if dev:
|
||||
cmd.append("--save-dev")
|
||||
return cmd
|
||||
# Default to npm
|
||||
cmd = ["npm", "install", package]
|
||||
if dev:
|
||||
cmd.append("--save-dev")
|
||||
return cmd
|
||||
|
||||
|
||||
def init_js_project(language: ProjectLanguage) -> None:
|
||||
|
|
|
|||
|
|
@ -549,7 +549,7 @@ def replace_function_definitions_for_language(
|
|||
# Find the function in current code
|
||||
func = None
|
||||
for f in current_functions:
|
||||
if func_name in (f.qualified_name, f.name):
|
||||
if func_name in (f.qualified_name, f.function_name):
|
||||
func = f
|
||||
break
|
||||
|
||||
|
|
@ -557,7 +557,9 @@ def replace_function_definitions_for_language(
|
|||
continue
|
||||
|
||||
# Extract just this function from the optimized code
|
||||
optimized_func = _extract_function_from_code(lang_support, code_to_apply, func.name, module_abspath)
|
||||
optimized_func = _extract_function_from_code(
|
||||
lang_support, code_to_apply, func.function_name, module_abspath
|
||||
)
|
||||
if optimized_func:
|
||||
new_code = lang_support.replace_function(new_code, func, optimized_func)
|
||||
modified = True
|
||||
|
|
@ -596,13 +598,13 @@ def _extract_function_from_code(
|
|||
# file_path is needed for JS/TS to determine correct analyzer (TypeScript vs JavaScript)
|
||||
functions = lang_support.discover_functions_from_source(source_code, file_path)
|
||||
for func in functions:
|
||||
if func.name == function_name:
|
||||
if func.function_name == function_name:
|
||||
# Extract the function's source using line numbers
|
||||
# Use doc_start_line if available to include JSDoc/docstring
|
||||
lines = source_code.splitlines(keepends=True)
|
||||
effective_start = func.doc_start_line or func.start_line
|
||||
if effective_start and func.end_line and effective_start <= len(lines):
|
||||
func_lines = lines[effective_start - 1 : func.end_line]
|
||||
effective_start = func.doc_start_line or func.starting_line
|
||||
if effective_start and func.ending_line and effective_start <= len(lines):
|
||||
func_lines = lines[effective_start - 1 : func.ending_line]
|
||||
return "".join(func_lines)
|
||||
except Exception as e:
|
||||
logger.debug(f"Error extracting function {function_name}: {e}")
|
||||
|
|
|
|||
|
|
@ -335,25 +335,8 @@ def _find_all_functions_via_language_support(file_path: Path) -> dict[Path, list
|
|||
try:
|
||||
lang_support = get_language_support(file_path)
|
||||
criteria = FunctionFilterCriteria(require_return=True)
|
||||
function_infos = lang_support.discover_functions(file_path, criteria)
|
||||
|
||||
ftos = []
|
||||
for func_info in function_infos:
|
||||
parents = [FunctionParent(p.name, p.type) for p in func_info.parents]
|
||||
ftos.append(
|
||||
FunctionToOptimize(
|
||||
function_name=func_info.name,
|
||||
file_path=func_info.file_path,
|
||||
parents=parents,
|
||||
starting_line=func_info.start_line,
|
||||
ending_line=func_info.end_line,
|
||||
starting_col=func_info.start_col,
|
||||
ending_col=func_info.end_col,
|
||||
is_async=func_info.is_async,
|
||||
language=func_info.language.value,
|
||||
)
|
||||
)
|
||||
functions[file_path] = ftos
|
||||
# discover_functions already returns FunctionToOptimize objects
|
||||
functions[file_path] = lang_support.discover_functions(file_path, criteria)
|
||||
except Exception as e:
|
||||
logger.debug(f"Failed to discover functions in {file_path}: {e}")
|
||||
|
||||
|
|
|
|||
|
|
@ -26,17 +26,6 @@ from codeflash.languages.base import (
|
|||
TestInfo,
|
||||
TestResult,
|
||||
)
|
||||
|
||||
|
||||
# Lazy import for FunctionInfo to avoid circular imports
|
||||
def __getattr__(name: str):
|
||||
if name == "FunctionInfo":
|
||||
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
|
||||
|
||||
return FunctionToOptimize
|
||||
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
|
||||
|
||||
|
||||
from codeflash.languages.current import (
|
||||
current_language,
|
||||
current_language_support,
|
||||
|
|
@ -46,11 +35,9 @@ from codeflash.languages.current import (
|
|||
reset_current_language,
|
||||
set_current_language,
|
||||
)
|
||||
from codeflash.languages.javascript import JavaScriptSupport, TypeScriptSupport # noqa: F401
|
||||
|
||||
# Import language support modules to trigger auto-registration
|
||||
# This ensures all supported languages are available when this package is imported
|
||||
from codeflash.languages.python import PythonSupport # noqa: F401
|
||||
# Language support modules are imported lazily to avoid circular imports
|
||||
# They get registered when first accessed via get_language_support()
|
||||
from codeflash.languages.registry import (
|
||||
detect_project_language,
|
||||
get_language_support,
|
||||
|
|
@ -70,6 +57,29 @@ from codeflash.languages.test_framework import (
|
|||
set_current_test_framework,
|
||||
)
|
||||
|
||||
|
||||
# Lazy imports to avoid circular imports
|
||||
def __getattr__(name: str):
|
||||
if name == "FunctionInfo":
|
||||
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
|
||||
|
||||
return FunctionToOptimize
|
||||
if name == "JavaScriptSupport":
|
||||
from codeflash.languages.javascript.support import JavaScriptSupport
|
||||
|
||||
return JavaScriptSupport
|
||||
if name == "TypeScriptSupport":
|
||||
from codeflash.languages.javascript.support import TypeScriptSupport
|
||||
|
||||
return TypeScriptSupport
|
||||
if name == "PythonSupport":
|
||||
from codeflash.languages.python.support import PythonSupport
|
||||
|
||||
return PythonSupport
|
||||
msg = f"module {__name__!r} has no attribute {name!r}"
|
||||
raise AttributeError(msg)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"CodeContext",
|
||||
"FunctionInfo",
|
||||
|
|
|
|||
|
|
@ -17,7 +17,7 @@ if TYPE_CHECKING:
|
|||
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
|
||||
|
||||
from codeflash.languages.language_enum import Language
|
||||
from codeflash.models.models import FunctionParent
|
||||
from codeflash.models.function_types import FunctionParent
|
||||
|
||||
# Backward compatibility aliases - ParentInfo is now FunctionParent
|
||||
ParentInfo = FunctionParent
|
||||
|
|
@ -30,7 +30,8 @@ def __getattr__(name: str) -> Any:
|
|||
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
|
||||
|
||||
return FunctionToOptimize
|
||||
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
|
||||
msg = f"module {__name__!r} has no attribute {name!r}"
|
||||
raise AttributeError(msg)
|
||||
|
||||
|
||||
@dataclass
|
||||
|
|
|
|||
|
|
@ -15,16 +15,16 @@ from __future__ import annotations
|
|||
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pathlib import Path
|
||||
|
||||
from tree_sitter import Node
|
||||
|
||||
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
|
||||
from codeflash.languages.treesitter_utils import ImportInfo, TreeSitterAnalyzer
|
||||
|
||||
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
|
|
@ -781,10 +781,7 @@ class ReferenceFinder:
|
|||
|
||||
"""
|
||||
path_str = str(file_path)
|
||||
for pattern in self.exclude_patterns:
|
||||
if pattern in path_str:
|
||||
return True
|
||||
return False
|
||||
return any(pattern in path_str for pattern in self.exclude_patterns)
|
||||
|
||||
def _read_file(self, file_path: Path) -> str | None:
|
||||
"""Read a file's contents with caching.
|
||||
|
|
|
|||
|
|
@ -44,7 +44,8 @@ class ImportResolver:
|
|||
project_root: Root directory of the project.
|
||||
|
||||
"""
|
||||
self.project_root = project_root
|
||||
# Resolve to real path to handle macOS symlinks like /var -> /private/var
|
||||
self.project_root = project_root.resolve()
|
||||
self._resolution_cache: dict[tuple[Path, str], Path | None] = {}
|
||||
|
||||
def resolve_import(self, import_info: ImportInfo, source_file: Path) -> ResolvedImport | None:
|
||||
|
|
@ -329,7 +330,7 @@ class MultiFileHelperFinder:
|
|||
all_functions = analyzer.find_functions(source, include_methods=True)
|
||||
target_func = None
|
||||
for func in all_functions:
|
||||
if func.name == function.name and func.start_line == function.start_line:
|
||||
if func.name == function.function_name and func.start_line == function.starting_line:
|
||||
target_func = func
|
||||
break
|
||||
|
||||
|
|
@ -506,7 +507,7 @@ class MultiFileHelperFinder:
|
|||
Dictionary mapping file paths to lists of helper functions.
|
||||
|
||||
"""
|
||||
from codeflash.languages.base import FunctionToOptimize
|
||||
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
|
||||
from codeflash.languages.treesitter_utils import get_analyzer_for_file
|
||||
|
||||
if context.current_depth >= context.max_depth:
|
||||
|
|
|
|||
|
|
@ -15,8 +15,7 @@ from codeflash.cli_cmds.console import logger
|
|||
|
||||
if TYPE_CHECKING:
|
||||
from codeflash.code_utils.code_position import CodePosition
|
||||
|
||||
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
|
||||
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
|
||||
|
||||
|
||||
class TestingMode:
|
||||
|
|
|
|||
|
|
@ -65,10 +65,10 @@ class JavaScriptLineProfiler:
|
|||
lines = source.splitlines(keepends=True)
|
||||
|
||||
# Process functions in reverse order to preserve line numbers
|
||||
for func in sorted(functions, key=lambda f: f.start_line, reverse=True):
|
||||
for func in sorted(functions, key=lambda f: f.starting_line, reverse=True):
|
||||
func_lines = self._instrument_function(func, lines, file_path)
|
||||
start_idx = func.start_line - 1
|
||||
end_idx = func.end_line
|
||||
start_idx = func.starting_line - 1
|
||||
end_idx = func.ending_line
|
||||
lines = lines[:start_idx] + func_lines + lines[end_idx:]
|
||||
|
||||
instrumented_source = "".join(lines)
|
||||
|
|
@ -183,7 +183,7 @@ if (__codeflash_save_interval__.unref) __codeflash_save_interval__.unref(); // D
|
|||
Instrumented function lines.
|
||||
|
||||
"""
|
||||
func_lines = lines[func.start_line - 1 : func.end_line]
|
||||
func_lines = lines[func.starting_line - 1 : func.ending_line]
|
||||
instrumented_lines = []
|
||||
|
||||
# Parse the function to find executable lines
|
||||
|
|
@ -194,7 +194,7 @@ if (__codeflash_save_interval__.unref) __codeflash_save_interval__.unref(); // D
|
|||
tree = analyzer.parse(source.encode("utf8"))
|
||||
executable_lines = self._find_executable_lines(tree.root_node, source.encode("utf8"))
|
||||
except Exception as e:
|
||||
logger.warning("Failed to parse function %s: %s", func.name, e)
|
||||
logger.warning("Failed to parse function %s: %s", func.function_name, e)
|
||||
return func_lines
|
||||
|
||||
# Add profiling to each executable line
|
||||
|
|
@ -203,7 +203,7 @@ if (__codeflash_save_interval__.unref) __codeflash_save_interval__.unref(); // D
|
|||
|
||||
for local_idx, line in enumerate(func_lines):
|
||||
local_line_num = local_idx + 1 # 1-indexed within function
|
||||
global_line_num = func.start_line + local_idx # Global line number in original file
|
||||
global_line_num = func.starting_line + local_idx # Global line number in original file
|
||||
stripped = line.strip()
|
||||
|
||||
# Add enterFunction() call after the opening brace of the function
|
||||
|
|
|
|||
|
|
@ -241,7 +241,7 @@ class JavaScriptSupport:
|
|||
|
||||
# Match source functions to tests
|
||||
for func in source_functions:
|
||||
if func.name in imported_names or func.name in source:
|
||||
if func.function_name in imported_names or func.function_name in source:
|
||||
if func.qualified_name not in result:
|
||||
result[func.qualified_name] = []
|
||||
for test_name in test_functions:
|
||||
|
|
|
|||
|
|
@ -63,13 +63,7 @@ def _find_jest_config(project_root: Path) -> Path | None:
|
|||
|
||||
"""
|
||||
# Common Jest config file names, in order of preference
|
||||
config_names = [
|
||||
"jest.config.ts",
|
||||
"jest.config.js",
|
||||
"jest.config.mjs",
|
||||
"jest.config.cjs",
|
||||
"jest.config.json",
|
||||
]
|
||||
config_names = ["jest.config.ts", "jest.config.js", "jest.config.mjs", "jest.config.cjs", "jest.config.json"]
|
||||
|
||||
# First check the project root itself
|
||||
for config_name in config_names:
|
||||
|
|
@ -226,14 +220,7 @@ def _ensure_runtime_files(project_root: Path) -> None:
|
|||
|
||||
install_cmd = get_package_install_command(project_root, "codeflash", dev=True)
|
||||
try:
|
||||
result = subprocess.run(
|
||||
install_cmd,
|
||||
check=False,
|
||||
cwd=project_root,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=120,
|
||||
)
|
||||
result = subprocess.run(install_cmd, check=False, cwd=project_root, capture_output=True, text=True, timeout=120)
|
||||
if result.returncode == 0:
|
||||
logger.debug(f"Installed codeflash using {install_cmd[0]}")
|
||||
return
|
||||
|
|
|
|||
|
|
@ -64,10 +64,10 @@ class JavaScriptTracer:
|
|||
lines = source.splitlines(keepends=True)
|
||||
|
||||
# Process functions in reverse order to preserve line numbers
|
||||
for func in sorted(functions, key=lambda f: f.start_line, reverse=True):
|
||||
for func in sorted(functions, key=lambda f: f.starting_line, reverse=True):
|
||||
instrumented = self._instrument_function(func, lines, file_path)
|
||||
start_idx = func.start_line - 1
|
||||
end_idx = func.end_line
|
||||
start_idx = func.starting_line - 1
|
||||
end_idx = func.ending_line
|
||||
lines = lines[:start_idx] + instrumented + lines[end_idx:]
|
||||
|
||||
instrumented_source = "".join(lines)
|
||||
|
|
@ -281,11 +281,11 @@ process.on('exit', () => {{
|
|||
Instrumented function lines.
|
||||
|
||||
"""
|
||||
func_lines = lines[func.start_line - 1 : func.end_line]
|
||||
func_lines = lines[func.starting_line - 1 : func.ending_line]
|
||||
func_text = "".join(func_lines)
|
||||
|
||||
# Detect function pattern
|
||||
func_name = func.name
|
||||
func_name = func.function_name
|
||||
is_arrow = "=>" in func_text.split("\n")[0]
|
||||
is_method = func.is_method
|
||||
is_async = func.is_async
|
||||
|
|
|
|||
|
|
@ -86,14 +86,7 @@ def _ensure_runtime_files(project_root: Path) -> None:
|
|||
|
||||
install_cmd = get_package_install_command(project_root, "codeflash", dev=True)
|
||||
try:
|
||||
result = subprocess.run(
|
||||
install_cmd,
|
||||
check=False,
|
||||
cwd=project_root,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=120,
|
||||
)
|
||||
result = subprocess.run(install_cmd, check=False, cwd=project_root, capture_output=True, text=True, timeout=120)
|
||||
if result.returncode == 0:
|
||||
logger.debug(f"Installed codeflash using {install_cmd[0]}")
|
||||
return
|
||||
|
|
|
|||
|
|
@ -159,7 +159,7 @@ class PythonSupport:
|
|||
try:
|
||||
source = test_file.read_text()
|
||||
# Check if function name appears in test file
|
||||
if func.name in source:
|
||||
if func.function_name in source:
|
||||
result[func.qualified_name].append(
|
||||
TestInfo(test_name=test_file.stem, test_file=test_file, test_class=None)
|
||||
)
|
||||
|
|
@ -289,7 +289,7 @@ class PythonSupport:
|
|||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.warning("Failed to find helpers for %s: %s", function.name, e)
|
||||
logger.warning("Failed to find helpers for %s: %s", function.function_name, e)
|
||||
|
||||
return helpers
|
||||
|
||||
|
|
@ -389,10 +389,10 @@ class PythonSupport:
|
|||
line=ref.line,
|
||||
column=ref.column,
|
||||
end_line=ref.line,
|
||||
end_column=ref.column + len(function.name),
|
||||
end_column=ref.column + len(function.function_name),
|
||||
context=context.strip(),
|
||||
reference_type="call",
|
||||
import_name=function.name,
|
||||
import_name=function.function_name,
|
||||
caller_function=caller_function,
|
||||
)
|
||||
)
|
||||
|
|
@ -400,7 +400,7 @@ class PythonSupport:
|
|||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.warning("Failed to find references for %s: %s", function.name, e)
|
||||
logger.warning("Failed to find references for %s: %s", function.function_name, e)
|
||||
return []
|
||||
|
||||
# === Code Transformation ===
|
||||
|
|
@ -433,7 +433,7 @@ class PythonSupport:
|
|||
preexisting_objects=set(),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning("Failed to replace function %s: %s", function.name, e)
|
||||
logger.warning("Failed to replace function %s: %s", function.function_name, e)
|
||||
return source
|
||||
|
||||
def format_code(self, source: str, file_path: Path | None = None) -> str:
|
||||
|
|
|
|||
|
|
@ -30,6 +30,33 @@ _LANGUAGE_REGISTRY: dict[Language, type[LanguageSupport]] = {}
|
|||
# Cache of instantiated language support objects
|
||||
_SUPPORT_CACHE: dict[Language, LanguageSupport] = {}
|
||||
|
||||
# Flag to track if language modules have been imported
|
||||
_languages_registered = False
|
||||
|
||||
|
||||
def _ensure_languages_registered() -> None:
|
||||
"""Ensure all language support modules are imported and registered.
|
||||
|
||||
This lazily imports the language support modules to avoid circular imports
|
||||
at module load time. The imports trigger the @register_language decorators
|
||||
which populate the registries.
|
||||
"""
|
||||
global _languages_registered
|
||||
if _languages_registered:
|
||||
return
|
||||
|
||||
# Import support modules to trigger registration
|
||||
# These imports are deferred to avoid circular imports
|
||||
import contextlib
|
||||
|
||||
with contextlib.suppress(ImportError):
|
||||
from codeflash.languages.python import support as _
|
||||
|
||||
with contextlib.suppress(ImportError):
|
||||
from codeflash.languages.javascript import support as _ # noqa: F401
|
||||
|
||||
_languages_registered = True
|
||||
|
||||
|
||||
class UnsupportedLanguageError(Exception):
|
||||
"""Raised when attempting to use an unsupported language."""
|
||||
|
|
@ -123,6 +150,10 @@ def get_language_support(identifier: Path | Language | str) -> LanguageSupport:
|
|||
Raises:
|
||||
UnsupportedLanguageError: If the language is not supported.
|
||||
|
||||
Note:
|
||||
This function lazily imports language support modules on first call
|
||||
to avoid circular import issues at module load time.
|
||||
|
||||
Example:
|
||||
# By file path
|
||||
lang = get_language_support(Path("example.py"))
|
||||
|
|
@ -137,6 +168,7 @@ def get_language_support(identifier: Path | Language | str) -> LanguageSupport:
|
|||
lang = get_language_support("python")
|
||||
|
||||
"""
|
||||
_ensure_languages_registered()
|
||||
language: Language | None = None
|
||||
|
||||
if isinstance(identifier, Language):
|
||||
|
|
@ -179,6 +211,7 @@ _FRAMEWORK_CACHE: dict[str, LanguageSupport] = {}
|
|||
|
||||
|
||||
def get_language_support_by_common_formatters(formatter_cmd: str | list[str]) -> LanguageSupport | None:
|
||||
_ensure_languages_registered()
|
||||
language: Language | None = None
|
||||
if isinstance(formatter_cmd, str):
|
||||
formatter_cmd = [formatter_cmd]
|
||||
|
|
@ -263,6 +296,7 @@ def detect_project_language(project_root: Path, module_root: Path) -> Language:
|
|||
UnsupportedLanguageError: If no supported language is detected.
|
||||
|
||||
"""
|
||||
_ensure_languages_registered()
|
||||
extension_counts: dict[str, int] = {}
|
||||
|
||||
# Count files by extension
|
||||
|
|
@ -290,6 +324,7 @@ def get_supported_languages() -> list[str]:
|
|||
List of language name strings.
|
||||
|
||||
"""
|
||||
_ensure_languages_registered()
|
||||
return [lang.value for lang in _LANGUAGE_REGISTRY]
|
||||
|
||||
|
||||
|
|
@ -300,6 +335,7 @@ def get_supported_extensions() -> list[str]:
|
|||
List of extension strings (with leading dots).
|
||||
|
||||
"""
|
||||
_ensure_languages_registered()
|
||||
return list(_EXTENSION_REGISTRY.keys())
|
||||
|
||||
|
||||
|
|
@ -325,10 +361,12 @@ def clear_registry() -> None:
|
|||
|
||||
Primarily useful for testing.
|
||||
"""
|
||||
global _languages_registered
|
||||
_EXTENSION_REGISTRY.clear()
|
||||
_LANGUAGE_REGISTRY.clear()
|
||||
_SUPPORT_CACHE.clear()
|
||||
_FRAMEWORK_CACHE.clear()
|
||||
_languages_registered = False
|
||||
|
||||
|
||||
def clear_cache() -> None:
|
||||
|
|
|
|||
18
codeflash/models/function_types.py
Normal file
18
codeflash/models/function_types.py
Normal file
|
|
@ -0,0 +1,18 @@
|
|||
"""Simple function-related types with no dependencies.
|
||||
|
||||
This module contains basic types used for function representation.
|
||||
It is intentionally kept dependency-free to avoid circular imports.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pydantic.dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class FunctionParent:
|
||||
name: str
|
||||
type: str
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"{self.type}:{self.name}"
|
||||
|
|
@ -598,13 +598,8 @@ class CodePosition:
|
|||
col_no: int
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class FunctionParent:
|
||||
name: str
|
||||
type: str
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"{self.type}:{self.name}"
|
||||
# Re-export FunctionParent for backward compatibility
|
||||
from codeflash.models.function_types import FunctionParent # noqa: E402
|
||||
|
||||
|
||||
class OriginalCodeBaseline(BaseModel):
|
||||
|
|
|
|||
|
|
@ -90,7 +90,7 @@ const multiply = (a, b) => a * b;
|
|||
functions = js_support.discover_functions(file_path)
|
||||
assert len(functions) == 1
|
||||
func = functions[0]
|
||||
assert func.name == "multiply"
|
||||
assert func.function_name == "multiply"
|
||||
|
||||
context = js_support.extract_code_context(func, temp_project, temp_project)
|
||||
|
||||
|
|
@ -268,7 +268,7 @@ class CacheManager {
|
|||
file_path.write_text(code, encoding="utf-8")
|
||||
|
||||
functions = js_support.discover_functions(file_path)
|
||||
get_or_compute = next(f for f in functions if f.name == "getOrCompute")
|
||||
get_or_compute = next(f for f in functions if f.function_name == "getOrCompute")
|
||||
|
||||
context = js_support.extract_code_context(get_or_compute, temp_project, temp_project)
|
||||
|
||||
|
|
@ -370,7 +370,7 @@ function validateUserData(data, validators) {
|
|||
file_path.write_text(code, encoding="utf-8")
|
||||
|
||||
functions = js_support.discover_functions(file_path)
|
||||
func = next(f for f in functions if f.name == "validateUserData")
|
||||
func = next(f for f in functions if f.function_name == "validateUserData")
|
||||
|
||||
context = js_support.extract_code_context(func, temp_project, temp_project)
|
||||
|
||||
|
|
@ -466,7 +466,7 @@ async function fetchWithRetry(endpoint, options = {}) {
|
|||
file_path.write_text(code, encoding="utf-8")
|
||||
|
||||
functions = js_support.discover_functions(file_path)
|
||||
func = next(f for f in functions if f.name == "fetchWithRetry")
|
||||
func = next(f for f in functions if f.function_name == "fetchWithRetry")
|
||||
|
||||
context = js_support.extract_code_context(func, temp_project, temp_project)
|
||||
|
||||
|
|
@ -615,7 +615,7 @@ function processUserInput(rawInput) {
|
|||
file_path.write_text(code, encoding="utf-8")
|
||||
|
||||
functions = js_support.discover_functions(file_path)
|
||||
process_func = next(f for f in functions if f.name == "processUserInput")
|
||||
process_func = next(f for f in functions if f.function_name == "processUserInput")
|
||||
|
||||
context = js_support.extract_code_context(process_func, temp_project, temp_project)
|
||||
|
||||
|
|
@ -670,7 +670,7 @@ function generateReport(data) {
|
|||
file_path.write_text(code, encoding="utf-8")
|
||||
|
||||
functions = js_support.discover_functions(file_path)
|
||||
report_func = next(f for f in functions if f.name == "generateReport")
|
||||
report_func = next(f for f in functions if f.function_name == "generateReport")
|
||||
|
||||
context = js_support.extract_code_context(report_func, temp_project, temp_project)
|
||||
|
||||
|
|
@ -768,7 +768,7 @@ class Graph {
|
|||
file_path.write_text(code, encoding="utf-8")
|
||||
|
||||
functions = js_support.discover_functions(file_path)
|
||||
topo_sort = next(f for f in functions if f.name == "topologicalSort")
|
||||
topo_sort = next(f for f in functions if f.function_name == "topologicalSort")
|
||||
|
||||
context = js_support.extract_code_context(topo_sort, temp_project, temp_project)
|
||||
|
||||
|
|
@ -843,7 +843,7 @@ class MainClass {
|
|||
file_path.write_text(code, encoding="utf-8")
|
||||
|
||||
functions = js_support.discover_functions(file_path)
|
||||
main_method = next(f for f in functions if f.name == "mainMethod" and f.class_name == "MainClass")
|
||||
main_method = next(f for f in functions if f.function_name == "mainMethod" and f.class_name == "MainClass")
|
||||
|
||||
context = js_support.extract_code_context(main_method, temp_project, temp_project)
|
||||
|
||||
|
|
@ -899,7 +899,7 @@ module.exports = { sortFromAnotherFile };
|
|||
main_path.write_text(main_code, encoding="utf-8")
|
||||
|
||||
functions = js_support.discover_functions(main_path)
|
||||
main_func = next(f for f in functions if f.name == "sortFromAnotherFile")
|
||||
main_func = next(f for f in functions if f.function_name == "sortFromAnotherFile")
|
||||
|
||||
context = js_support.extract_code_context(main_func, temp_project, temp_project)
|
||||
|
||||
|
|
@ -952,7 +952,7 @@ export { processNumber };
|
|||
main_path.write_text(main_code, encoding="utf-8")
|
||||
|
||||
functions = js_support.discover_functions(main_path)
|
||||
process_func = next(f for f in functions if f.name == "processNumber")
|
||||
process_func = next(f for f in functions if f.function_name == "processNumber")
|
||||
|
||||
context = js_support.extract_code_context(process_func, temp_project, temp_project)
|
||||
|
||||
|
|
@ -1020,7 +1020,7 @@ export { handleUserInput };
|
|||
main_path.write_text(main_code, encoding="utf-8")
|
||||
|
||||
functions = js_support.discover_functions(main_path)
|
||||
handle_func = next(f for f in functions if f.name == "handleUserInput")
|
||||
handle_func = next(f for f in functions if f.function_name == "handleUserInput")
|
||||
|
||||
context = js_support.extract_code_context(handle_func, temp_project, temp_project)
|
||||
|
||||
|
|
@ -1161,7 +1161,7 @@ class TypedCache<T> {
|
|||
file_path.write_text(code, encoding="utf-8")
|
||||
|
||||
functions = ts_support.discover_functions(file_path)
|
||||
get_method = next(f for f in functions if f.name == "get")
|
||||
get_method = next(f for f in functions if f.function_name == "get")
|
||||
|
||||
context = ts_support.extract_code_context(get_method, temp_project, temp_project)
|
||||
|
||||
|
|
@ -1247,7 +1247,7 @@ export { createUser };
|
|||
service_path.write_text(service_code, encoding="utf-8")
|
||||
|
||||
functions = ts_support.discover_functions(service_path)
|
||||
func = next(f for f in functions if f.name == "createUser")
|
||||
func = next(f for f in functions if f.function_name == "createUser")
|
||||
|
||||
context = ts_support.extract_code_context(func, temp_project, temp_project)
|
||||
|
||||
|
|
@ -1331,7 +1331,7 @@ function isOdd(n) {
|
|||
file_path.write_text(code, encoding="utf-8")
|
||||
|
||||
functions = js_support.discover_functions(file_path)
|
||||
is_even = next(f for f in functions if f.name == "isEven")
|
||||
is_even = next(f for f in functions if f.function_name == "isEven")
|
||||
|
||||
context = js_support.extract_code_context(is_even, temp_project, temp_project)
|
||||
|
||||
|
|
@ -1393,7 +1393,7 @@ function collectAllValues(root) {
|
|||
file_path.write_text(code, encoding="utf-8")
|
||||
|
||||
functions = js_support.discover_functions(file_path)
|
||||
collect_func = next(f for f in functions if f.name == "collectAllValues")
|
||||
collect_func = next(f for f in functions if f.function_name == "collectAllValues")
|
||||
|
||||
context = js_support.extract_code_context(collect_func, temp_project, temp_project)
|
||||
|
||||
|
|
@ -1458,7 +1458,7 @@ async function fetchUserProfile(userId) {
|
|||
file_path.write_text(code, encoding="utf-8")
|
||||
|
||||
functions = js_support.discover_functions(file_path)
|
||||
profile_func = next(f for f in functions if f.name == "fetchUserProfile")
|
||||
profile_func = next(f for f in functions if f.function_name == "fetchUserProfile")
|
||||
|
||||
context = js_support.extract_code_context(profile_func, temp_project, temp_project)
|
||||
|
||||
|
|
@ -1513,7 +1513,7 @@ module.exports = { Counter };
|
|||
file_path.write_text(original_source, encoding="utf-8")
|
||||
|
||||
functions = js_support.discover_functions(file_path)
|
||||
increment_func = next(fn for fn in functions if fn.name == "increment")
|
||||
increment_func = next(fn for fn in functions if fn.function_name == "increment")
|
||||
|
||||
# Step 1: Extract code context
|
||||
context = js_support.extract_code_context(increment_func, temp_project, temp_project)
|
||||
|
|
@ -1635,7 +1635,7 @@ function* fibonacci(limit) {
|
|||
file_path.write_text(code, encoding="utf-8")
|
||||
|
||||
functions = js_support.discover_functions(file_path)
|
||||
range_func = next(f for f in functions if f.name == "range")
|
||||
range_func = next(f for f in functions if f.function_name == "range")
|
||||
|
||||
context = js_support.extract_code_context(range_func, temp_project, temp_project)
|
||||
|
||||
|
|
@ -1772,7 +1772,7 @@ class Calculator {
|
|||
functions = js_support.discover_functions(file_path)
|
||||
|
||||
for func in functions:
|
||||
if func.name != "constructor":
|
||||
if func.function_name != "constructor":
|
||||
context = js_support.extract_code_context(func, temp_project, temp_project)
|
||||
is_valid = js_support.validate_syntax(context.target_code)
|
||||
assert is_valid is True, f"Invalid syntax for {func.name}:\n{context.target_code}"
|
||||
|
|
|
|||
|
|
@ -1715,7 +1715,7 @@ module.exports = { Calculator };
|
|||
functions = js_support.discover_functions(source_file)
|
||||
|
||||
# Check qualified names include class
|
||||
add_func = next((f for f in functions if f.name == "add"), None)
|
||||
add_func = next((f for f in functions if f.function_name == "add"), None)
|
||||
assert add_func is not None
|
||||
assert add_func.class_name == "Calculator"
|
||||
|
||||
|
|
|
|||
|
|
@ -39,7 +39,7 @@ class TestCodeExtractorCJS:
|
|||
calculator_file = cjs_project / "calculator.js"
|
||||
functions = js_support.discover_functions(calculator_file)
|
||||
|
||||
method_names = {f.name for f in functions}
|
||||
method_names = {f.function_name for f in functions}
|
||||
|
||||
expected_methods = {"calculateCompoundInterest", "permutation", "quickAdd"}
|
||||
assert method_names == expected_methods, f"Expected methods {expected_methods}, got {method_names}"
|
||||
|
|
@ -51,15 +51,15 @@ class TestCodeExtractorCJS:
|
|||
|
||||
for func in functions:
|
||||
# All methods should belong to Calculator class
|
||||
assert func.is_method is True, f"{func.name} should be a method"
|
||||
assert func.class_name == "Calculator", f"{func.name} should belong to Calculator, got {func.class_name}"
|
||||
assert func.is_method is True, f"{func.function_name} should be a method"
|
||||
assert func.class_name == "Calculator", f"{func.function_name} should belong to Calculator, got {func.class_name}"
|
||||
|
||||
def test_extract_permutation_code(self, js_support, cjs_project):
|
||||
"""Test permutation method code extraction."""
|
||||
calculator_file = cjs_project / "calculator.js"
|
||||
functions = js_support.discover_functions(calculator_file)
|
||||
|
||||
permutation_func = next(f for f in functions if f.name == "permutation")
|
||||
permutation_func = next(f for f in functions if f.function_name == "permutation")
|
||||
|
||||
context = js_support.extract_code_context(
|
||||
function=permutation_func, project_root=cjs_project, module_root=cjs_project
|
||||
|
|
@ -95,7 +95,7 @@ class Calculator {
|
|||
calculator_file = cjs_project / "calculator.js"
|
||||
functions = js_support.discover_functions(calculator_file)
|
||||
|
||||
permutation_func = next(f for f in functions if f.name == "permutation")
|
||||
permutation_func = next(f for f in functions if f.function_name == "permutation")
|
||||
|
||||
context = js_support.extract_code_context(
|
||||
function=permutation_func, project_root=cjs_project, module_root=cjs_project
|
||||
|
|
@ -136,7 +136,7 @@ function factorial(n) {
|
|||
calculator_file = cjs_project / "calculator.js"
|
||||
functions = js_support.discover_functions(calculator_file)
|
||||
|
||||
compound_func = next(f for f in functions if f.name == "calculateCompoundInterest")
|
||||
compound_func = next(f for f in functions if f.function_name == "calculateCompoundInterest")
|
||||
|
||||
context = js_support.extract_code_context(
|
||||
function=compound_func, project_root=cjs_project, module_root=cjs_project
|
||||
|
|
@ -182,7 +182,7 @@ class Calculator {
|
|||
calculator_file = cjs_project / "calculator.js"
|
||||
functions = js_support.discover_functions(calculator_file)
|
||||
|
||||
compound_func = next(f for f in functions if f.name == "calculateCompoundInterest")
|
||||
compound_func = next(f for f in functions if f.function_name == "calculateCompoundInterest")
|
||||
|
||||
context = js_support.extract_code_context(
|
||||
function=compound_func, project_root=cjs_project, module_root=cjs_project
|
||||
|
|
@ -266,7 +266,7 @@ function validateInput(value, name) {
|
|||
calculator_file = cjs_project / "calculator.js"
|
||||
functions = js_support.discover_functions(calculator_file)
|
||||
|
||||
compound_func = next(f for f in functions if f.name == "calculateCompoundInterest")
|
||||
compound_func = next(f for f in functions if f.function_name == "calculateCompoundInterest")
|
||||
|
||||
context = js_support.extract_code_context(
|
||||
function=compound_func, project_root=cjs_project, module_root=cjs_project
|
||||
|
|
@ -287,7 +287,7 @@ function validateInput(value, name) {
|
|||
calculator_file = cjs_project / "calculator.js"
|
||||
functions = js_support.discover_functions(calculator_file)
|
||||
|
||||
quick_add_func = next(f for f in functions if f.name == "quickAdd")
|
||||
quick_add_func = next(f for f in functions if f.function_name == "quickAdd")
|
||||
|
||||
context = js_support.extract_code_context(
|
||||
function=quick_add_func, project_root=cjs_project, module_root=cjs_project
|
||||
|
|
@ -352,7 +352,7 @@ class TestCodeExtractorESM:
|
|||
calculator_file = esm_project / "calculator.js"
|
||||
functions = js_support.discover_functions(calculator_file)
|
||||
|
||||
method_names = {f.name for f in functions}
|
||||
method_names = {f.function_name for f in functions}
|
||||
|
||||
# Should find same methods as CJS version
|
||||
expected_methods = {"calculateCompoundInterest", "permutation", "quickAdd"}
|
||||
|
|
@ -363,7 +363,7 @@ class TestCodeExtractorESM:
|
|||
calculator_file = esm_project / "calculator.js"
|
||||
functions = js_support.discover_functions(calculator_file)
|
||||
|
||||
permutation_func = next(f for f in functions if f.name == "permutation")
|
||||
permutation_func = next(f for f in functions if f.function_name == "permutation")
|
||||
|
||||
context = js_support.extract_code_context(
|
||||
function=permutation_func, project_root=esm_project, module_root=esm_project
|
||||
|
|
@ -413,7 +413,7 @@ export function factorial(n) {
|
|||
calculator_file = esm_project / "calculator.js"
|
||||
functions = js_support.discover_functions(calculator_file)
|
||||
|
||||
compound_func = next(f for f in functions if f.name == "calculateCompoundInterest")
|
||||
compound_func = next(f for f in functions if f.function_name == "calculateCompoundInterest")
|
||||
|
||||
context = js_support.extract_code_context(
|
||||
function=compound_func, project_root=esm_project, module_root=esm_project
|
||||
|
|
@ -539,7 +539,7 @@ class TestCodeExtractorTypeScript:
|
|||
calculator_file = ts_project / "calculator.ts"
|
||||
functions = ts_support.discover_functions(calculator_file)
|
||||
|
||||
method_names = {f.name for f in functions}
|
||||
method_names = {f.function_name for f in functions}
|
||||
|
||||
# TypeScript has additional getHistory method
|
||||
expected_methods = {"calculateCompoundInterest", "permutation", "getHistory", "quickAdd"}
|
||||
|
|
@ -550,7 +550,7 @@ class TestCodeExtractorTypeScript:
|
|||
calculator_file = ts_project / "calculator.ts"
|
||||
functions = ts_support.discover_functions(calculator_file)
|
||||
|
||||
permutation_func = next(f for f in functions if f.name == "permutation")
|
||||
permutation_func = next(f for f in functions if f.function_name == "permutation")
|
||||
|
||||
context = ts_support.extract_code_context(
|
||||
function=permutation_func, project_root=ts_project, module_root=ts_project
|
||||
|
|
@ -603,7 +603,7 @@ export function factorial(n: number): number {
|
|||
calculator_file = ts_project / "calculator.ts"
|
||||
functions = ts_support.discover_functions(calculator_file)
|
||||
|
||||
compound_func = next(f for f in functions if f.name == "calculateCompoundInterest")
|
||||
compound_func = next(f for f in functions if f.function_name == "calculateCompoundInterest")
|
||||
|
||||
context = ts_support.extract_code_context(
|
||||
function=compound_func, project_root=ts_project, module_root=ts_project
|
||||
|
|
@ -712,7 +712,7 @@ module.exports = { standalone };
|
|||
test_file.write_text(source)
|
||||
|
||||
functions = js_support.discover_functions(test_file)
|
||||
func = next(f for f in functions if f.name == "standalone")
|
||||
func = next(f for f in functions if f.function_name == "standalone")
|
||||
|
||||
context = js_support.extract_code_context(function=func, project_root=tmp_path, module_root=tmp_path)
|
||||
|
||||
|
|
@ -745,7 +745,7 @@ module.exports = { processArray };
|
|||
test_file.write_text(source)
|
||||
|
||||
functions = js_support.discover_functions(test_file)
|
||||
func = next(f for f in functions if f.name == "processArray")
|
||||
func = next(f for f in functions if f.function_name == "processArray")
|
||||
|
||||
context = js_support.extract_code_context(function=func, project_root=tmp_path, module_root=tmp_path)
|
||||
|
||||
|
|
@ -780,7 +780,7 @@ module.exports = { fibonacci };
|
|||
test_file.write_text(source)
|
||||
|
||||
functions = js_support.discover_functions(test_file)
|
||||
func = next(f for f in functions if f.name == "fibonacci")
|
||||
func = next(f for f in functions if f.function_name == "fibonacci")
|
||||
|
||||
context = js_support.extract_code_context(function=func, project_root=tmp_path, module_root=tmp_path)
|
||||
|
||||
|
|
@ -813,7 +813,7 @@ module.exports = { processValue };
|
|||
test_file.write_text(source)
|
||||
|
||||
functions = js_support.discover_functions(test_file)
|
||||
func = next(f for f in functions if f.name == "processValue")
|
||||
func = next(f for f in functions if f.function_name == "processValue")
|
||||
|
||||
context = js_support.extract_code_context(function=func, project_root=tmp_path, module_root=tmp_path)
|
||||
|
||||
|
|
@ -871,7 +871,7 @@ module.exports = { Counter };
|
|||
test_file.write_text(source)
|
||||
|
||||
functions = js_support.discover_functions(test_file)
|
||||
increment_func = next(f for f in functions if f.name == "increment")
|
||||
increment_func = next(f for f in functions if f.function_name == "increment")
|
||||
|
||||
context = js_support.extract_code_context(function=increment_func, project_root=tmp_path, module_root=tmp_path)
|
||||
|
||||
|
|
@ -910,7 +910,7 @@ module.exports = { MathUtils };
|
|||
test_file.write_text(source)
|
||||
|
||||
functions = js_support.discover_functions(test_file)
|
||||
add_func = next(f for f in functions if f.name == "add")
|
||||
add_func = next(f for f in functions if f.function_name == "add")
|
||||
|
||||
context = js_support.extract_code_context(function=add_func, project_root=tmp_path, module_root=tmp_path)
|
||||
|
||||
|
|
@ -948,7 +948,7 @@ export { User };
|
|||
test_file.write_text(source)
|
||||
|
||||
functions = ts_support.discover_functions(test_file)
|
||||
get_name_func = next(f for f in functions if f.name == "getName")
|
||||
get_name_func = next(f for f in functions if f.function_name == "getName")
|
||||
|
||||
context = ts_support.extract_code_context(function=get_name_func, project_root=tmp_path, module_root=tmp_path)
|
||||
|
||||
|
|
@ -989,7 +989,7 @@ export { Config };
|
|||
test_file.write_text(source)
|
||||
|
||||
functions = ts_support.discover_functions(test_file)
|
||||
get_url_func = next(f for f in functions if f.name == "getUrl")
|
||||
get_url_func = next(f for f in functions if f.function_name == "getUrl")
|
||||
|
||||
context = ts_support.extract_code_context(function=get_url_func, project_root=tmp_path, module_root=tmp_path)
|
||||
|
||||
|
|
@ -1030,7 +1030,7 @@ module.exports = { Logger };
|
|||
test_file.write_text(source)
|
||||
|
||||
functions = js_support.discover_functions(test_file)
|
||||
get_prefix_func = next(f for f in functions if f.name == "getPrefix")
|
||||
get_prefix_func = next(f for f in functions if f.function_name == "getPrefix")
|
||||
|
||||
context = js_support.extract_code_context(function=get_prefix_func, project_root=tmp_path, module_root=tmp_path)
|
||||
|
||||
|
|
@ -1072,7 +1072,7 @@ module.exports = { Factory };
|
|||
test_file.write_text(source)
|
||||
|
||||
functions = js_support.discover_functions(test_file)
|
||||
create_func = next(f for f in functions if f.name == "create")
|
||||
create_func = next(f for f in functions if f.function_name == "create")
|
||||
|
||||
context = js_support.extract_code_context(function=create_func, project_root=tmp_path, module_root=tmp_path)
|
||||
|
||||
|
|
@ -1114,19 +1114,20 @@ class TestCodeExtractorIntegration:
|
|||
calculator_file = cjs_project / "calculator.js"
|
||||
|
||||
functions = js_support.discover_functions(calculator_file)
|
||||
target = next(f for f in functions if f.name == "permutation")
|
||||
target = next(f for f in functions if f.function_name == "permutation")
|
||||
|
||||
parents = [FunctionParent(name=p.name, type=p.type) for p in target.parents]
|
||||
|
||||
func = FunctionToOptimize(
|
||||
function_name=target.name,
|
||||
function_name=target.function_name,
|
||||
file_path=target.file_path,
|
||||
parents=parents,
|
||||
starting_line=target.start_line,
|
||||
ending_line=target.end_line,
|
||||
starting_col=target.start_col,
|
||||
ending_col=target.end_col,
|
||||
starting_line=target.starting_line,
|
||||
ending_line=target.ending_line,
|
||||
starting_col=target.starting_col,
|
||||
ending_col=target.ending_col,
|
||||
is_async=target.is_async,
|
||||
is_method=target.is_method,
|
||||
language=target.language,
|
||||
)
|
||||
|
||||
|
|
@ -1223,7 +1224,7 @@ export { distance };
|
|||
test_file.write_text(source)
|
||||
|
||||
functions = ts_support.discover_functions(test_file)
|
||||
distance_func = next(f for f in functions if f.name == "distance")
|
||||
distance_func = next(f for f in functions if f.function_name == "distance")
|
||||
|
||||
context = ts_support.extract_code_context(function=distance_func, project_root=tmp_path, module_root=tmp_path)
|
||||
|
||||
|
|
@ -1267,7 +1268,7 @@ export { processStatus };
|
|||
test_file.write_text(source)
|
||||
|
||||
functions = ts_support.discover_functions(test_file)
|
||||
process_func = next(f for f in functions if f.name == "processStatus")
|
||||
process_func = next(f for f in functions if f.function_name == "processStatus")
|
||||
|
||||
context = ts_support.extract_code_context(function=process_func, project_root=tmp_path, module_root=tmp_path)
|
||||
|
||||
|
|
@ -1304,7 +1305,7 @@ export { compute };
|
|||
test_file.write_text(source)
|
||||
|
||||
functions = ts_support.discover_functions(test_file)
|
||||
compute_func = next(f for f in functions if f.name == "compute")
|
||||
compute_func = next(f for f in functions if f.function_name == "compute")
|
||||
|
||||
context = ts_support.extract_code_context(function=compute_func, project_root=tmp_path, module_root=tmp_path)
|
||||
|
||||
|
|
@ -1348,7 +1349,7 @@ export { Service };
|
|||
test_file.write_text(source)
|
||||
|
||||
functions = ts_support.discover_functions(test_file)
|
||||
get_timeout_func = next(f for f in functions if f.name == "getTimeout")
|
||||
get_timeout_func = next(f for f in functions if f.function_name == "getTimeout")
|
||||
|
||||
context = ts_support.extract_code_context(
|
||||
function=get_timeout_func, project_root=tmp_path, module_root=tmp_path
|
||||
|
|
@ -1381,7 +1382,7 @@ export { add };
|
|||
test_file.write_text(source)
|
||||
|
||||
functions = ts_support.discover_functions(test_file)
|
||||
add_func = next(f for f in functions if f.name == "add")
|
||||
add_func = next(f for f in functions if f.function_name == "add")
|
||||
|
||||
context = ts_support.extract_code_context(function=add_func, project_root=tmp_path, module_root=tmp_path)
|
||||
|
||||
|
|
@ -1414,7 +1415,7 @@ export { createRect };
|
|||
test_file.write_text(source)
|
||||
|
||||
functions = ts_support.discover_functions(test_file)
|
||||
create_rect_func = next(f for f in functions if f.name == "createRect")
|
||||
create_rect_func = next(f for f in functions if f.function_name == "createRect")
|
||||
|
||||
context = ts_support.extract_code_context(
|
||||
function=create_rect_func, project_root=tmp_path, module_root=tmp_path
|
||||
|
|
@ -1462,7 +1463,7 @@ export { calculateDistance };
|
|||
""")
|
||||
|
||||
functions = ts_support.discover_functions(geometry_file)
|
||||
calc_distance_func = next(f for f in functions if f.name == "calculateDistance")
|
||||
calc_distance_func = next(f for f in functions if f.function_name == "calculateDistance")
|
||||
|
||||
context = ts_support.extract_code_context(
|
||||
function=calc_distance_func, project_root=ts_types_project, module_root=ts_types_project
|
||||
|
|
@ -1515,7 +1516,7 @@ export { greetUser };
|
|||
test_file.write_text(source)
|
||||
|
||||
functions = ts_support.discover_functions(test_file)
|
||||
greet_func = next(f for f in functions if f.name == "greetUser")
|
||||
greet_func = next(f for f in functions if f.function_name == "greetUser")
|
||||
|
||||
context = ts_support.extract_code_context(function=greet_func, project_root=tmp_path, module_root=tmp_path)
|
||||
|
||||
|
|
|
|||
|
|
@ -711,7 +711,7 @@ module.exports = { targetFunction, otherFunction };
|
|||
file_path.write_text(original_source, encoding="utf-8")
|
||||
|
||||
functions = js_support.discover_functions(file_path)
|
||||
target_func = next(f for f in functions if f.name == "targetFunction")
|
||||
target_func = next(f for f in functions if f.function_name == "targetFunction")
|
||||
|
||||
optimized_code = """\
|
||||
function targetFunction(x) {
|
||||
|
|
@ -763,7 +763,7 @@ class Calculator {
|
|||
file_path.write_text(original_source, encoding="utf-8")
|
||||
|
||||
functions = js_support.discover_functions(file_path)
|
||||
add_method = next(f for f in functions if f.name == "add")
|
||||
add_method = next(f for f in functions if f.function_name == "add")
|
||||
|
||||
# Optimized version provided in class context
|
||||
optimized_code = """\
|
||||
|
|
@ -826,7 +826,7 @@ class DataProcessor {
|
|||
file_path.write_text(original_source, encoding="utf-8")
|
||||
|
||||
functions = js_support.discover_functions(file_path)
|
||||
process_method = next(f for f in functions if f.name == "process")
|
||||
process_method = next(f for f in functions if f.function_name == "process")
|
||||
|
||||
optimized_code = """\
|
||||
class DataProcessor {
|
||||
|
|
@ -948,7 +948,7 @@ class Cache {
|
|||
file_path.write_text(original_source, encoding="utf-8")
|
||||
|
||||
functions = js_support.discover_functions(file_path)
|
||||
get_method = next(f for f in functions if f.name == "get")
|
||||
get_method = next(f for f in functions if f.function_name == "get")
|
||||
|
||||
optimized_code = """\
|
||||
class Cache {
|
||||
|
|
@ -1050,7 +1050,7 @@ class ApiClient {
|
|||
file_path.write_text(original_source, encoding="utf-8")
|
||||
|
||||
functions = js_support.discover_functions(file_path)
|
||||
get_method = next(f for f in functions if f.name == "get")
|
||||
get_method = next(f for f in functions if f.function_name == "get")
|
||||
|
||||
optimized_code = """\
|
||||
class ApiClient {
|
||||
|
|
@ -1181,7 +1181,7 @@ class Container<T> {
|
|||
file_path.write_text(original_source, encoding="utf-8")
|
||||
|
||||
functions = ts_support.discover_functions(file_path)
|
||||
get_all_method = next(f for f in functions if f.name == "getAll")
|
||||
get_all_method = next(f for f in functions if f.function_name == "getAll")
|
||||
|
||||
optimized_code = """\
|
||||
class Container<T> {
|
||||
|
|
@ -1234,7 +1234,7 @@ function createUser(name: string, email: string): User {
|
|||
file_path.write_text(original_source, encoding="utf-8")
|
||||
|
||||
functions = ts_support.discover_functions(file_path)
|
||||
func = next(f for f in functions if f.name == "createUser")
|
||||
func = next(f for f in functions if f.function_name == "createUser")
|
||||
|
||||
optimized_code = """\
|
||||
function createUser(name: string, email: string): User {
|
||||
|
|
@ -1289,7 +1289,7 @@ function processItems(items) {
|
|||
file_path.write_text(original_source, encoding="utf-8")
|
||||
|
||||
functions = js_support.discover_functions(file_path)
|
||||
process_func = next(f for f in functions if f.name == "processItems")
|
||||
process_func = next(f for f in functions if f.function_name == "processItems")
|
||||
|
||||
optimized_code = """\
|
||||
function processItems(items) {
|
||||
|
|
@ -1336,7 +1336,7 @@ class MathUtils {
|
|||
|
||||
# First replacement: sum method
|
||||
functions = js_support.discover_functions(file_path)
|
||||
sum_method = next(f for f in functions if f.name == "sum")
|
||||
sum_method = next(f for f in functions if f.function_name == "sum")
|
||||
|
||||
optimized_sum = """\
|
||||
class MathUtils {
|
||||
|
|
@ -1554,7 +1554,7 @@ module.exports = { main, helper };
|
|||
file_path.write_text(original_source, encoding="utf-8")
|
||||
|
||||
functions = js_support.discover_functions(file_path)
|
||||
main_func = next(f for f in functions if f.name == "main")
|
||||
main_func = next(f for f in functions if f.function_name == "main")
|
||||
|
||||
optimized_code = """\
|
||||
function main(data) {
|
||||
|
|
@ -1597,7 +1597,7 @@ export function main(data) {
|
|||
file_path.write_text(original_source, encoding="utf-8")
|
||||
|
||||
functions = js_support.discover_functions(file_path)
|
||||
main_func = next(f for f in functions if f.name == "main")
|
||||
main_func = next(f for f in functions if f.function_name == "main")
|
||||
|
||||
optimized_code = """\
|
||||
export function main(data) {
|
||||
|
|
@ -1756,7 +1756,7 @@ export class DataProcessor<T> {
|
|||
# find function
|
||||
target_func_info = None
|
||||
for func in functions:
|
||||
if func.name == target_func and func.parents[0].name == parent_class:
|
||||
if func.function_name == target_func and func.parents[0].name == parent_class:
|
||||
target_func_info = func
|
||||
break
|
||||
assert target_func_info is not None
|
||||
|
|
|
|||
|
|
@ -113,19 +113,20 @@ def test_js_replcement() -> None:
|
|||
functions = js_support.discover_functions(main_file)
|
||||
target = None
|
||||
for func in functions:
|
||||
if func.name == "calculateStats":
|
||||
if func.function_name == "calculateStats":
|
||||
target = func
|
||||
break
|
||||
assert target is not None
|
||||
func = FunctionToOptimize(
|
||||
function_name=target.name,
|
||||
function_name=target.function_name,
|
||||
file_path=target.file_path,
|
||||
parents=target.parents,
|
||||
starting_line=target.start_line,
|
||||
ending_line=target.end_line,
|
||||
starting_col=target.start_col,
|
||||
ending_col=target.end_col,
|
||||
starting_line=target.starting_line,
|
||||
ending_line=target.ending_line,
|
||||
starting_col=target.starting_col,
|
||||
ending_col=target.ending_col,
|
||||
is_async=target.is_async,
|
||||
is_method=target.is_method,
|
||||
language=target.language,
|
||||
)
|
||||
test_config = TestConfig(
|
||||
|
|
|
|||
|
|
@ -136,7 +136,7 @@ function add(a, b) {
|
|||
functions = js_analyzer.find_functions(code)
|
||||
|
||||
assert len(functions) == 1
|
||||
assert functions[0].function_name == "add"
|
||||
assert functions[0].name == "add"
|
||||
assert functions[0].is_arrow is False
|
||||
assert functions[0].is_async is False
|
||||
assert functions[0].is_method is False
|
||||
|
|
@ -151,7 +151,7 @@ const add = (a, b) => {
|
|||
functions = js_analyzer.find_functions(code)
|
||||
|
||||
assert len(functions) == 1
|
||||
assert functions[0].function_name == "add"
|
||||
assert functions[0].name == "add"
|
||||
assert functions[0].is_arrow is True
|
||||
|
||||
def test_find_arrow_function_concise(self, js_analyzer):
|
||||
|
|
@ -160,7 +160,7 @@ const add = (a, b) => {
|
|||
functions = js_analyzer.find_functions(code)
|
||||
|
||||
assert len(functions) == 1
|
||||
assert functions[0].function_name == "double"
|
||||
assert functions[0].name == "double"
|
||||
assert functions[0].is_arrow is True
|
||||
|
||||
def test_find_async_function(self, js_analyzer):
|
||||
|
|
@ -173,7 +173,7 @@ async function fetchData(url) {
|
|||
functions = js_analyzer.find_functions(code)
|
||||
|
||||
assert len(functions) == 1
|
||||
assert functions[0].function_name == "fetchData"
|
||||
assert functions[0].name == "fetchData"
|
||||
assert functions[0].is_async is True
|
||||
|
||||
def test_find_class_methods(self, js_analyzer):
|
||||
|
|
@ -188,7 +188,7 @@ class Calculator {
|
|||
functions = js_analyzer.find_functions(code, include_methods=True)
|
||||
|
||||
assert len(functions) == 1
|
||||
assert functions[0].function_name == "add"
|
||||
assert functions[0].name == "add"
|
||||
assert functions[0].is_method is True
|
||||
assert functions[0].class_name == "Calculator"
|
||||
|
||||
|
|
@ -208,7 +208,7 @@ function standalone() {
|
|||
functions = js_analyzer.find_functions(code, include_methods=False)
|
||||
|
||||
assert len(functions) == 1
|
||||
assert functions[0].function_name == "standalone"
|
||||
assert functions[0].name == "standalone"
|
||||
|
||||
def test_exclude_arrow_functions(self, js_analyzer):
|
||||
"""Test excluding arrow functions."""
|
||||
|
|
@ -222,7 +222,7 @@ const arrow = () => 2;
|
|||
functions = js_analyzer.find_functions(code, include_arrow_functions=False)
|
||||
|
||||
assert len(functions) == 1
|
||||
assert functions[0].function_name == "regular"
|
||||
assert functions[0].name == "regular"
|
||||
|
||||
def test_find_generator_function(self, js_analyzer):
|
||||
"""Test finding generator functions."""
|
||||
|
|
@ -235,7 +235,7 @@ function* numberGenerator() {
|
|||
functions = js_analyzer.find_functions(code)
|
||||
|
||||
assert len(functions) == 1
|
||||
assert functions[0].function_name == "numberGenerator"
|
||||
assert functions[0].name == "numberGenerator"
|
||||
assert functions[0].is_generator is True
|
||||
|
||||
def test_function_line_numbers(self, js_analyzer):
|
||||
|
|
@ -291,7 +291,7 @@ function named() {
|
|||
functions = js_analyzer.find_functions(code, require_name=True)
|
||||
|
||||
assert len(functions) == 1
|
||||
assert functions[0].function_name == "named"
|
||||
assert functions[0].name == "named"
|
||||
|
||||
def test_function_expression_in_variable(self, js_analyzer):
|
||||
"""Test function expression assigned to variable."""
|
||||
|
|
@ -303,7 +303,7 @@ const add = function(a, b) {
|
|||
functions = js_analyzer.find_functions(code)
|
||||
|
||||
assert len(functions) == 1
|
||||
assert functions[0].function_name == "add"
|
||||
assert functions[0].name == "add"
|
||||
|
||||
|
||||
class TestFindImports:
|
||||
|
|
@ -515,7 +515,7 @@ function add(a: number, b: number): number {
|
|||
functions = ts_analyzer.find_functions(code)
|
||||
|
||||
assert len(functions) == 1
|
||||
assert functions[0].function_name == "add"
|
||||
assert functions[0].name == "add"
|
||||
|
||||
def test_find_interface_method(self, ts_analyzer):
|
||||
"""Test that interface methods are not found (they're declarations)."""
|
||||
|
|
@ -544,4 +544,4 @@ function identity<T>(value: T): T {
|
|||
functions = ts_analyzer.find_functions(code)
|
||||
|
||||
assert len(functions) == 1
|
||||
assert functions[0].function_name == "identity"
|
||||
assert functions[0].name == "identity"
|
||||
|
|
|
|||
Loading…
Reference in a new issue