mirror of
https://github.com/codeflash-ai/codeflash.git
synced 2026-05-04 18:25:17 +00:00
check
This commit is contained in:
parent
c4ee0371f8
commit
c3b3f2db9c
24 changed files with 101 additions and 123 deletions
|
|
@ -110,7 +110,7 @@ class CodeFlashBenchmarkPlugin:
|
|||
|
||||
# Process each row
|
||||
for row in cursor.fetchall():
|
||||
module_name, class_name, function_name, benchmark_file, benchmark_func, benchmark_line, time_ns = row
|
||||
module_name, class_name, function_name, benchmark_file, benchmark_func, _benchmark_line, time_ns = row
|
||||
|
||||
# Create the function key (module_name.class_name.function_name)
|
||||
if class_name:
|
||||
|
|
@ -184,7 +184,7 @@ class CodeFlashBenchmarkPlugin:
|
|||
|
||||
# Process each row and subtract overhead
|
||||
for row in cursor.fetchall():
|
||||
benchmark_file, benchmark_func, benchmark_line, time_ns = row
|
||||
benchmark_file, benchmark_func, _benchmark_line, time_ns = row
|
||||
|
||||
# Create the benchmark key (file::function::line)
|
||||
benchmark_key = BenchmarkKey(module_path=benchmark_file, function_name=benchmark_func)
|
||||
|
|
|
|||
|
|
@ -677,7 +677,7 @@ def _add_global_declarations_for_language(
|
|||
# Insert declarations
|
||||
before = lines[:insertion_line]
|
||||
after = lines[insertion_line:]
|
||||
result_lines = before + [new_decl_code] + after
|
||||
result_lines = [*before, new_decl_code, *after]
|
||||
|
||||
return "".join(result_lines)
|
||||
|
||||
|
|
@ -702,8 +702,7 @@ def _find_insertion_line_after_imports_js(lines: list[str], analyzer: TreeSitter
|
|||
imports = analyzer.find_imports(source)
|
||||
if imports:
|
||||
# Find the last import's end line
|
||||
last_import_end = max(imp.end_line for imp in imports)
|
||||
return last_import_end # 0-based index = end_line (since end_line is 1-indexed)
|
||||
return max(imp.end_line for imp in imports)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
|
|
|||
|
|
@ -71,7 +71,7 @@ class AssertCleanup:
|
|||
|
||||
unittest_match = self.unittest_re.match(line)
|
||||
if unittest_match:
|
||||
indent, assert_method, args = unittest_match.groups()
|
||||
indent, _assert_method, args = unittest_match.groups()
|
||||
|
||||
if args:
|
||||
arg_parts = self._first_top_level_arg(args)
|
||||
|
|
|
|||
|
|
@ -204,7 +204,7 @@ def unique_inv_id(inv_id_runtimes: dict[InvocationId, list[int]], tests_project_
|
|||
file_path = base_path.replace(".", os.sep) + matched_ext
|
||||
# Check if the module path includes the tests directory name
|
||||
tests_dir_name = tests_project_rootdir.name
|
||||
if file_path.startswith(tests_dir_name + os.sep) or file_path.startswith(tests_dir_name + "/"):
|
||||
if file_path.startswith((tests_dir_name + os.sep, tests_dir_name + "/")):
|
||||
# Module path includes "tests." - use parent directory
|
||||
abs_path = tests_project_rootdir.parent / Path(file_path)
|
||||
else:
|
||||
|
|
@ -358,8 +358,7 @@ def normalize_codeflash_imports(source: str) -> str:
|
|||
# Replace CommonJS require
|
||||
source = _CODEFLASH_REQUIRE_PATTERN.sub(r"\1 \2 = require('codeflash')", source)
|
||||
# Replace ES module import
|
||||
source = _CODEFLASH_IMPORT_PATTERN.sub(r"import \1 from 'codeflash'", source)
|
||||
return source
|
||||
return _CODEFLASH_IMPORT_PATTERN.sub(r"import \1 from 'codeflash'", source)
|
||||
|
||||
|
||||
def inject_test_globals(generated_tests: GeneratedTestsList) -> GeneratedTestsList:
|
||||
|
|
|
|||
|
|
@ -87,7 +87,7 @@ def get_normalizer_for_extension(extension: str) -> CodeNormalizer | None:
|
|||
if not extension.startswith("."):
|
||||
extension = f".{extension}"
|
||||
|
||||
for language, normalizer_class in _NORMALIZERS.items():
|
||||
for language in _NORMALIZERS:
|
||||
normalizer = get_normalizer(language)
|
||||
if extension in normalizer.supported_extensions:
|
||||
return normalizer
|
||||
|
|
|
|||
|
|
@ -116,9 +116,8 @@ class VariableNormalizer(ast.NodeTransformer):
|
|||
and node.id not in self.nonlocal_vars
|
||||
):
|
||||
node.id = self.get_normalized_name(node.id)
|
||||
elif isinstance(node.ctx, ast.Load):
|
||||
if node.id in self.var_mapping:
|
||||
node.id = self.var_mapping[node.id]
|
||||
elif isinstance(node.ctx, ast.Load) and node.id in self.var_mapping:
|
||||
node.id = self.var_mapping[node.id]
|
||||
return node
|
||||
|
||||
def visit_ExceptHandler(self, node: ast.ExceptHandler) -> ast.ExceptHandler:
|
||||
|
|
|
|||
|
|
@ -106,7 +106,7 @@ def get_code_optimization_context(
|
|||
qualified_names.update({f"{qn.rsplit('.', 1)[0]}.__init__" for qn in qualified_names if "." in qn})
|
||||
|
||||
# Get FunctionSource representation of helpers of helpers of FTO
|
||||
helpers_of_helpers_dict, helpers_of_helpers_list = get_function_sources_from_jedi(
|
||||
helpers_of_helpers_dict, _helpers_of_helpers_list = get_function_sources_from_jedi(
|
||||
helpers_of_fto_qualified_names_dict, project_root_path
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -62,15 +62,15 @@ __all__ = [
|
|||
# Current language singleton
|
||||
"current_language",
|
||||
"current_language_support",
|
||||
"is_javascript",
|
||||
"is_python",
|
||||
"is_typescript",
|
||||
"reset_current_language",
|
||||
"set_current_language",
|
||||
# Registry functions
|
||||
"detect_project_language",
|
||||
"get_language_support",
|
||||
"get_supported_extensions",
|
||||
"get_supported_languages",
|
||||
"is_javascript",
|
||||
"is_python",
|
||||
"is_typescript",
|
||||
"register_language",
|
||||
"reset_current_language",
|
||||
"set_current_language",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -128,7 +128,7 @@ def compare_test_results(
|
|||
return False, []
|
||||
|
||||
# Check for unexpected exit codes (not 0 or 1)
|
||||
if result.returncode != 0 and result.returncode != 1:
|
||||
if result.returncode not in {0, 1}:
|
||||
logger.error(f"JavaScript comparator failed with exit code {result.returncode}")
|
||||
if result.stderr:
|
||||
logger.error(f"stderr: {result.stderr}")
|
||||
|
|
|
|||
|
|
@ -36,7 +36,7 @@ class ImportResolver:
|
|||
# Supported extensions in resolution order (prefer TS over JS)
|
||||
EXTENSIONS = (".ts", ".tsx", ".js", ".jsx", ".mjs", ".cjs")
|
||||
|
||||
def __init__(self, project_root: Path):
|
||||
def __init__(self, project_root: Path) -> None:
|
||||
"""Initialize the resolver.
|
||||
|
||||
Args:
|
||||
|
|
@ -289,7 +289,7 @@ class MultiFileHelperFinder:
|
|||
|
||||
DEFAULT_MAX_DEPTH = 2 # Target → helpers → helpers of helpers
|
||||
|
||||
def __init__(self, project_root: Path, import_resolver: ImportResolver):
|
||||
def __init__(self, project_root: Path, import_resolver: ImportResolver) -> None:
|
||||
"""Initialize the finder.
|
||||
|
||||
Args:
|
||||
|
|
@ -343,7 +343,7 @@ class MultiFileHelperFinder:
|
|||
# Find helpers from imported modules
|
||||
results: dict[Path, list[HelperFunction]] = {}
|
||||
|
||||
for call_name, (import_info, actual_name) in call_to_import.items():
|
||||
for (import_info, actual_name) in call_to_import.values():
|
||||
# Resolve the import to a file path
|
||||
resolved = self.import_resolver.resolve_import(import_info, function.file_path)
|
||||
if resolved is None:
|
||||
|
|
|
|||
|
|
@ -707,10 +707,7 @@ def _is_function_used_in_test(code: str, func_name: str) -> bool:
|
|||
# Check for method calls: obj.funcName( or this.funcName(
|
||||
# This handles class methods called on instances
|
||||
method_call_pattern = rf"\w+\.{re.escape(func_name)}\s*\("
|
||||
if re.search(method_call_pattern, code):
|
||||
return True
|
||||
|
||||
return False
|
||||
return bool(re.search(method_call_pattern, code))
|
||||
|
||||
|
||||
def _instrument_js_test_code(
|
||||
|
|
@ -967,7 +964,7 @@ def instrument_generated_js_test(
|
|||
|
||||
# Use the internal instrumentation function with assertion removal enabled
|
||||
# Generated tests are treated as regression tests, so we remove LLM-generated assertions
|
||||
instrumented_code = _instrument_js_test_code(
|
||||
return _instrument_js_test_code(
|
||||
code=test_code,
|
||||
func_name=function_name,
|
||||
test_file_path="generated_test",
|
||||
|
|
@ -976,4 +973,3 @@ def instrument_generated_js_test(
|
|||
remove_assertions=True,
|
||||
)
|
||||
|
||||
return instrumented_code
|
||||
|
|
|
|||
|
|
@ -9,12 +9,13 @@ from __future__ import annotations
|
|||
|
||||
import json
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from codeflash.languages.treesitter_utils import get_analyzer_for_file
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pathlib import Path
|
||||
|
||||
from codeflash.languages.base import FunctionInfo
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -29,7 +30,7 @@ class JavaScriptLineProfiler:
|
|||
- Total execution time per function
|
||||
"""
|
||||
|
||||
def __init__(self, output_file: Path):
|
||||
def __init__(self, output_file: Path) -> None:
|
||||
"""Initialize the line profiler.
|
||||
|
||||
Args:
|
||||
|
|
@ -64,7 +65,7 @@ class JavaScriptLineProfiler:
|
|||
lines = source.splitlines(keepends=True)
|
||||
|
||||
# Process functions in reverse order to preserve line numbers
|
||||
for func in reversed(sorted(functions, key=lambda f: f.start_line)):
|
||||
for func in sorted(functions, key=lambda f: f.start_line, reverse=True):
|
||||
func_lines = self._instrument_function(func, lines, file_path)
|
||||
start_idx = func.start_line - 1
|
||||
end_idx = func.end_line
|
||||
|
|
@ -271,7 +272,7 @@ if (__codeflash_save_interval__.unref) __codeflash_save_interval__.unref(); // D
|
|||
"await_expression",
|
||||
}
|
||||
|
||||
def walk(n):
|
||||
def walk(n) -> None:
|
||||
if n.type in executable_types:
|
||||
# Add the starting line (1-indexed)
|
||||
executable_lines.add(n.start_point[0] + 1)
|
||||
|
|
@ -328,5 +329,5 @@ if (__codeflash_save_interval__.unref) __codeflash_save_interval__.unref(); // D
|
|||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to parse line profile results: {e}")
|
||||
logger.exception(f"Failed to parse line profile results: {e}")
|
||||
return {"timings": {}, "unit": 1e-9, "functions": {}}
|
||||
|
|
|
|||
|
|
@ -9,7 +9,10 @@ from __future__ import annotations
|
|||
import json
|
||||
import logging
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pathlib import Path
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -96,15 +99,13 @@ def detect_module_system(project_root: Path, file_path: Path | None = None) -> s
|
|||
has_module_exports = "module.exports" in content or "exports." in content
|
||||
|
||||
# Determine based on what we found
|
||||
if has_import or has_export:
|
||||
if not (has_require or has_module_exports):
|
||||
logger.debug("Detected ES Module from import/export statements")
|
||||
return ModuleSystem.ES_MODULE
|
||||
if (has_import or has_export) and not (has_require or has_module_exports):
|
||||
logger.debug("Detected ES Module from import/export statements")
|
||||
return ModuleSystem.ES_MODULE
|
||||
|
||||
if has_require or has_module_exports:
|
||||
if not (has_import or has_export):
|
||||
logger.debug("Detected CommonJS from require/module.exports")
|
||||
return ModuleSystem.COMMONJS
|
||||
if (has_require or has_module_exports) and not (has_import or has_export):
|
||||
logger.debug("Detected CommonJS from require/module.exports")
|
||||
return ModuleSystem.COMMONJS
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to analyze file {file_path}: {e}")
|
||||
|
|
@ -270,7 +271,7 @@ def convert_esm_to_commonjs(code: str) -> str:
|
|||
default_import = re.compile(r"import\s+(\w+)\s+from\s+['\"]([^'\"]+)['\"];?")
|
||||
|
||||
# Replace named imports with destructured requires
|
||||
def replace_named(match):
|
||||
def replace_named(match) -> str:
|
||||
names = match.group(1).strip()
|
||||
module_path = match.group(2)
|
||||
# Remove .js extension for CommonJS (optional but cleaner)
|
||||
|
|
@ -278,7 +279,7 @@ def convert_esm_to_commonjs(code: str) -> str:
|
|||
return f"const {{ {names} }} = require('{module_path}');"
|
||||
|
||||
# Replace default imports with simple requires
|
||||
def replace_default(match):
|
||||
def replace_default(match) -> str:
|
||||
name = match.group(1)
|
||||
module_path = match.group(2)
|
||||
# Remove .js extension for CommonJS
|
||||
|
|
@ -287,9 +288,8 @@ def convert_esm_to_commonjs(code: str) -> str:
|
|||
|
||||
# Apply conversions (named first as it's more specific)
|
||||
code = named_import.sub(replace_named, code)
|
||||
code = default_import.sub(replace_default, code)
|
||||
return default_import.sub(replace_default, code)
|
||||
|
||||
return code
|
||||
|
||||
|
||||
def ensure_module_system_compatibility(code: str, target_module_system: str) -> str:
|
||||
|
|
|
|||
|
|
@ -23,16 +23,13 @@ from codeflash.languages.base import (
|
|||
TestResult,
|
||||
)
|
||||
from codeflash.languages.registry import register_language
|
||||
from codeflash.languages.treesitter_utils import (
|
||||
TreeSitterAnalyzer,
|
||||
TreeSitterLanguage,
|
||||
TypeDefinition,
|
||||
get_analyzer_for_file,
|
||||
)
|
||||
from codeflash.languages.treesitter_utils import TreeSitterAnalyzer, TreeSitterLanguage, get_analyzer_for_file
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Sequence
|
||||
|
||||
from codeflash.languages.treesitter_utils import TypeDefinition
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
|
|
@ -302,7 +299,7 @@ class JavaScriptSupport:
|
|||
try:
|
||||
source = function.file_path.read_text()
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to read {function.file_path}: {e}")
|
||||
logger.exception(f"Failed to read {function.file_path}: {e}")
|
||||
return CodeContext(target_code="", target_file=function.file_path, language=Language.JAVASCRIPT)
|
||||
|
||||
# Find imports and helper functions
|
||||
|
|
@ -1306,7 +1303,7 @@ class JavaScriptSupport:
|
|||
logger.warning(f"Test execution timed out after {timeout}s")
|
||||
return [], junit_xml
|
||||
except Exception as e:
|
||||
logger.error(f"Test execution failed: {e}")
|
||||
logger.exception(f"Test execution failed: {e}")
|
||||
return [], junit_xml
|
||||
|
||||
def parse_test_results(self, junit_xml_path: Path, stdout: str) -> list[TestResult]:
|
||||
|
|
@ -1920,14 +1917,7 @@ class TypeScriptSupport(JavaScriptSupport):
|
|||
List of glob patterns for test files.
|
||||
|
||||
"""
|
||||
return [
|
||||
"*.test.ts",
|
||||
"*.test.tsx",
|
||||
"*.spec.ts",
|
||||
"*.spec.tsx",
|
||||
"__tests__/**/*.ts",
|
||||
"__tests__/**/*.tsx",
|
||||
] + super()._get_test_patterns()
|
||||
return ["*.test.ts", "*.test.tsx", "*.spec.ts", "*.spec.tsx", "__tests__/**/*.ts", "__tests__/**/*.tsx", *super()._get_test_patterns()]
|
||||
|
||||
def get_test_file_suffix(self) -> str:
|
||||
"""Get the test file suffix for TypeScript.
|
||||
|
|
@ -1970,10 +1960,7 @@ class TypeScriptSupport(JavaScriptSupport):
|
|||
"""
|
||||
try:
|
||||
# Determine file extension for prettier
|
||||
if file_path:
|
||||
stdin_filepath = str(file_path.name)
|
||||
else:
|
||||
stdin_filepath = "file.ts"
|
||||
stdin_filepath = str(file_path.name) if file_path else "file.ts"
|
||||
|
||||
# Try to use prettier via npx
|
||||
result = subprocess.run(
|
||||
|
|
|
|||
|
|
@ -10,10 +10,11 @@ from __future__ import annotations
|
|||
import json
|
||||
import logging
|
||||
import sqlite3
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pathlib import Path
|
||||
|
||||
from codeflash.languages.base import FunctionInfo
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -29,7 +30,7 @@ class JavaScriptTracer:
|
|||
- Execution time
|
||||
"""
|
||||
|
||||
def __init__(self, output_db: Path):
|
||||
def __init__(self, output_db: Path) -> None:
|
||||
"""Initialize the tracer.
|
||||
|
||||
Args:
|
||||
|
|
@ -63,7 +64,7 @@ class JavaScriptTracer:
|
|||
lines = source.splitlines(keepends=True)
|
||||
|
||||
# Process functions in reverse order to preserve line numbers
|
||||
for func in reversed(sorted(functions, key=lambda f: f.start_line)):
|
||||
for func in sorted(functions, key=lambda f: f.start_line, reverse=True):
|
||||
instrumented = self._instrument_function(func, lines, file_path)
|
||||
start_idx = func.start_line - 1
|
||||
end_idx = func.end_line
|
||||
|
|
@ -365,7 +366,7 @@ process.on('exit', () => {{
|
|||
with json_file.open("r") as f:
|
||||
return json.load(f)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to parse trace JSON: {e}")
|
||||
logger.exception(f"Failed to parse trace JSON: {e}")
|
||||
return []
|
||||
|
||||
# Try SQLite database
|
||||
|
|
@ -397,5 +398,5 @@ process.on('exit', () => {{
|
|||
return traces
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to parse trace database: {e}")
|
||||
logger.exception(f"Failed to parse trace database: {e}")
|
||||
return []
|
||||
|
|
|
|||
|
|
@ -183,7 +183,7 @@ class PythonSupport:
|
|||
try:
|
||||
source = function.file_path.read_text()
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to read {function.file_path}: {e}")
|
||||
logger.exception(f"Failed to read {function.file_path}: {e}")
|
||||
return CodeContext(target_code="", target_file=function.file_path, language=Language.PYTHON)
|
||||
|
||||
# Extract the function source
|
||||
|
|
@ -201,7 +201,7 @@ class PythonSupport:
|
|||
import_lines = []
|
||||
for line in lines:
|
||||
stripped = line.strip()
|
||||
if stripped.startswith("import ") or stripped.startswith("from "):
|
||||
if stripped.startswith(("import ", "from ")):
|
||||
import_lines.append(stripped)
|
||||
elif stripped and not stripped.startswith("#"):
|
||||
# Stop at first non-import, non-comment line
|
||||
|
|
@ -312,13 +312,12 @@ class PythonSupport:
|
|||
original_function_names = [function.qualified_name]
|
||||
|
||||
# Use the existing replacer
|
||||
result = replace_functions_in_file(
|
||||
return replace_functions_in_file(
|
||||
source_code=source,
|
||||
original_function_names=original_function_names,
|
||||
optimized_code=new_source,
|
||||
preexisting_objects=set(),
|
||||
)
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to replace function {function.name}: {e}")
|
||||
return source
|
||||
|
|
@ -399,7 +398,7 @@ class PythonSupport:
|
|||
logger.warning(f"Test execution timed out after {timeout}s")
|
||||
return [], junit_xml
|
||||
except Exception as e:
|
||||
logger.error(f"Test execution failed: {e}")
|
||||
logger.exception(f"Test execution failed: {e}")
|
||||
return [], junit_xml
|
||||
|
||||
def parse_test_results(self, junit_xml_path: Path, stdout: str) -> list[TestResult]:
|
||||
|
|
@ -567,7 +566,7 @@ class PythonSupport:
|
|||
import libcst as cst
|
||||
|
||||
class TestFunctionRemover(cst.CSTTransformer):
|
||||
def __init__(self, names_to_remove: list[str]):
|
||||
def __init__(self, names_to_remove: list[str]) -> None:
|
||||
self.names_to_remove = set(names_to_remove)
|
||||
|
||||
def leave_FunctionDef(
|
||||
|
|
|
|||
|
|
@ -11,11 +11,13 @@ import logging
|
|||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from codeflash.languages.base import Language, LanguageSupport
|
||||
from codeflash.languages.base import Language
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Iterable
|
||||
|
||||
from codeflash.languages.base import LanguageSupport
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
|
|
@ -32,7 +34,7 @@ _SUPPORT_CACHE: dict[Language, LanguageSupport] = {}
|
|||
class UnsupportedLanguageError(Exception):
|
||||
"""Raised when attempting to use an unsupported language."""
|
||||
|
||||
def __init__(self, identifier: str | Path, supported: Iterable[str] | None = None):
|
||||
def __init__(self, identifier: str | Path, supported: Iterable[str] | None = None) -> None:
|
||||
self.identifier = identifier
|
||||
self.supported = list(supported) if supported else []
|
||||
msg = f"Unsupported language: {identifier}"
|
||||
|
|
@ -74,10 +76,13 @@ def register_language(cls: type[LanguageSupport]) -> type[LanguageSupport]:
|
|||
language = instance.language
|
||||
extensions = instance.file_extensions
|
||||
except Exception as e:
|
||||
raise ValueError(
|
||||
msg = (
|
||||
f"Failed to instantiate {cls.__name__} for registration. "
|
||||
f"Language support classes must be instantiable without arguments. "
|
||||
f"Error: {e}"
|
||||
)
|
||||
raise ValueError(
|
||||
msg
|
||||
) from e
|
||||
|
||||
# Register by extension
|
||||
|
|
@ -241,7 +246,8 @@ def detect_project_language(project_root: Path, module_root: Path) -> Language:
|
|||
logger.info(f"Detected language: {cls().language} (found {count} '{ext}' files)")
|
||||
return cls().language
|
||||
|
||||
raise UnsupportedLanguageError(f"No supported language detected in {module_root}", get_supported_languages())
|
||||
msg = f"No supported language detected in {module_root}"
|
||||
raise UnsupportedLanguageError(msg, get_supported_languages())
|
||||
|
||||
|
||||
def get_supported_languages() -> list[str]:
|
||||
|
|
|
|||
|
|
@ -9,13 +9,14 @@ from __future__ import annotations
|
|||
import logging
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from tree_sitter import Language, Node, Parser
|
||||
from tree_sitter import Language, Parser
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from tree_sitter import Tree
|
||||
from pathlib import Path
|
||||
|
||||
from tree_sitter import Node, Tree
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -127,7 +128,7 @@ class TreeSitterAnalyzer:
|
|||
finding functions, imports, and other code structures.
|
||||
"""
|
||||
|
||||
def __init__(self, language: TreeSitterLanguage | str):
|
||||
def __init__(self, language: TreeSitterLanguage | str) -> None:
|
||||
"""Initialize the analyzer for a specific language.
|
||||
|
||||
Args:
|
||||
|
|
@ -235,7 +236,7 @@ class TreeSitterAnalyzer:
|
|||
new_class = current_class
|
||||
new_function = current_function
|
||||
|
||||
if node.type == "class_declaration" or node.type == "class":
|
||||
if node.type in {"class_declaration", "class"}:
|
||||
# Get class name
|
||||
name_node = node.child_by_field_name("name")
|
||||
if name_node:
|
||||
|
|
@ -683,12 +684,9 @@ class TreeSitterAnalyzer:
|
|||
if child.type == "default":
|
||||
# Find what's being exported as default
|
||||
for sibling in node.children:
|
||||
if sibling.type == "function_declaration" or sibling.type == "class_declaration":
|
||||
if sibling.type in {"function_declaration", "class_declaration"}:
|
||||
name_node = sibling.child_by_field_name("name")
|
||||
if name_node:
|
||||
default_export = self.get_node_text(name_node, source_bytes)
|
||||
else:
|
||||
default_export = "default"
|
||||
default_export = self.get_node_text(name_node, source_bytes) if name_node else "default"
|
||||
elif sibling.type == "identifier":
|
||||
default_export = self.get_node_text(sibling, source_bytes)
|
||||
elif sibling.type in ("arrow_function", "function_expression", "object", "array"):
|
||||
|
|
@ -771,13 +769,10 @@ class TreeSitterAnalyzer:
|
|||
|
||||
if left_text == "module.exports":
|
||||
# module.exports = something
|
||||
if right_node.type == "function_expression" or right_node.type == "arrow_function":
|
||||
if right_node.type in {"function_expression", "arrow_function"}:
|
||||
# module.exports = function foo() {} or module.exports = () => {}
|
||||
name_node = right_node.child_by_field_name("name")
|
||||
if name_node:
|
||||
default_export = self.get_node_text(name_node, source_bytes)
|
||||
else:
|
||||
default_export = "default"
|
||||
default_export = self.get_node_text(name_node, source_bytes) if name_node else "default"
|
||||
elif right_node.type == "identifier":
|
||||
# module.exports = someFunction
|
||||
default_export = self.get_node_text(right_node, source_bytes)
|
||||
|
|
@ -1045,7 +1040,7 @@ class TreeSitterAnalyzer:
|
|||
identifiers: list[str] = []
|
||||
|
||||
def walk(n: Node) -> None:
|
||||
if n.type == "identifier" or n.type == "shorthand_property_identifier_pattern":
|
||||
if n.type in {"identifier", "shorthand_property_identifier_pattern"}:
|
||||
identifiers.append(self.get_node_text(n, source_bytes))
|
||||
for child in n.children:
|
||||
walk(child)
|
||||
|
|
@ -1091,22 +1086,19 @@ class TreeSitterAnalyzer:
|
|||
return
|
||||
|
||||
# Skip variable declarator names (left side of declaration)
|
||||
if parent.type == "variable_declarator":
|
||||
if parent.child_by_field_name("name") == node:
|
||||
# Don't recurse - the value will be visited when we visit the declarator
|
||||
return
|
||||
if parent.type == "variable_declarator" and parent.child_by_field_name("name") == node:
|
||||
# Don't recurse - the value will be visited when we visit the declarator
|
||||
return
|
||||
|
||||
# Skip property names in object literals (keys)
|
||||
if parent.type == "pair":
|
||||
if parent.child_by_field_name("key") == node:
|
||||
# Don't recurse - the value will be visited when we visit the pair
|
||||
return
|
||||
if parent.type == "pair" and parent.child_by_field_name("key") == node:
|
||||
# Don't recurse - the value will be visited when we visit the pair
|
||||
return
|
||||
|
||||
# Skip property access property names (obj.property - skip 'property')
|
||||
if parent.type == "member_expression":
|
||||
if parent.child_by_field_name("property") == node:
|
||||
# Don't recurse - the object will be visited when we visit the member_expression
|
||||
return
|
||||
if parent.type == "member_expression" and parent.child_by_field_name("property") == node:
|
||||
# Don't recurse - the object will be visited when we visit the member_expression
|
||||
return
|
||||
|
||||
# Skip import specifier names
|
||||
if parent.type in ("import_specifier", "import_clause", "namespace_import"):
|
||||
|
|
@ -1117,7 +1109,7 @@ class TreeSitterAnalyzer:
|
|||
return
|
||||
|
||||
# Skip parameter names in function definitions
|
||||
if parent.type == "formal_parameters" or parent.type == "required_parameter":
|
||||
if parent.type in {"formal_parameters", "required_parameter"}:
|
||||
return
|
||||
|
||||
# This is a reference
|
||||
|
|
@ -1165,11 +1157,7 @@ class TreeSitterAnalyzer:
|
|||
return True
|
||||
return False
|
||||
|
||||
for child in node.children:
|
||||
if self._node_has_return(child):
|
||||
return True
|
||||
|
||||
return False
|
||||
return any(self._node_has_return(child) for child in node.children)
|
||||
|
||||
def extract_type_annotations(self, source: str, function_name: str, function_line: int) -> set[str]:
|
||||
"""Extract type annotation names from a function's parameters and return type.
|
||||
|
|
|
|||
|
|
@ -124,7 +124,7 @@ def existing_tests_source_for(
|
|||
file_path = base_path.replace(".", os.sep) + matched_ext
|
||||
# Check if the module path includes the tests directory name
|
||||
tests_dir_name = test_cfg.tests_project_rootdir.name
|
||||
if file_path.startswith(tests_dir_name + os.sep) or file_path.startswith(tests_dir_name + "/"):
|
||||
if file_path.startswith((tests_dir_name + os.sep, tests_dir_name + "/")):
|
||||
# Module path includes "tests." - use project root parent
|
||||
instrumented_abs_path = (test_cfg.tests_project_rootdir.parent / file_path).resolve()
|
||||
else:
|
||||
|
|
|
|||
|
|
@ -71,7 +71,7 @@ class ProfileStats(pstats.Stats):
|
|||
time_unit = {"ns": "nanoseconds", "us": "microseconds", "ms": "milliseconds", "s": "seconds"}[self.time_unit]
|
||||
print(f"in {self.total_tt:.3f} {time_unit}", file=self.stream)
|
||||
print(file=self.stream)
|
||||
width, list_ = self.get_print_list(amount)
|
||||
_width, list_ = self.get_print_list(amount)
|
||||
if list_:
|
||||
self.print_title()
|
||||
for func in list_:
|
||||
|
|
|
|||
|
|
@ -470,7 +470,7 @@ class Tracer:
|
|||
# In multi-threaded contexts, we need to be more careful about frame comparisons
|
||||
if self.cur and frame.f_back is not self.cur[-2]:
|
||||
# This happens when we're in a different thread
|
||||
rpt, rit, ret, rfn, rframe, rcur = self.cur
|
||||
_rpt, _rit, _ret, _rfn, rframe, _rcur = self.cur
|
||||
|
||||
# Only attempt to handle the frame mismatch if we have a valid rframe
|
||||
if (
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
import os
|
||||
import re
|
||||
import sqlite3
|
||||
|
|
@ -457,7 +458,7 @@ def parse_sqlite_test_results(sqlite_file_path: Path, test_files: TestFiles, tes
|
|||
file_path = base_path.replace(".", os.sep) + extension
|
||||
# Check if the module path includes the tests directory name
|
||||
tests_dir_name = test_config.tests_project_rootdir.name
|
||||
if file_path.startswith(tests_dir_name + os.sep) or file_path.startswith(tests_dir_name + "/"):
|
||||
if file_path.startswith((tests_dir_name + os.sep, tests_dir_name + "/")):
|
||||
# Module path includes "tests." - use project root parent
|
||||
test_file_path = test_config.tests_project_rootdir.parent / file_path
|
||||
else:
|
||||
|
|
@ -845,10 +846,8 @@ def parse_jest_test_xml(
|
|||
runtime = None
|
||||
if end_match:
|
||||
# Duration is in the 6th group (index 5)
|
||||
try:
|
||||
with contextlib.suppress(ValueError, IndexError):
|
||||
runtime = int(end_match.group(6))
|
||||
except (ValueError, IndexError):
|
||||
pass
|
||||
test_results.add(
|
||||
FunctionTestInvocation(
|
||||
loop_index=loop_index,
|
||||
|
|
|
|||
|
|
@ -4,6 +4,8 @@ This script tests all three approaches against the test cases and generates
|
|||
a comparison report.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
|
|
|
|||
|
|
@ -9,6 +9,8 @@ Each test case includes:
|
|||
- description: What edge case this tests
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue