format and lint all
parent 0e5ad411ec
commit 198487bf81
109 changed files with 2704 additions and 3211 deletions
@@ -129,7 +129,8 @@ class AiServiceClient:
experiment_metadata: ExperimentMetadata | None = None,
*,
language: str = "python",
language_version: str | None = None, # TODO:{claude} add language version to the language support and it should be cached
language_version: str
| None = None, # TODO:{claude} add language version to the language support and it should be cached
module_system: str | None = None,
is_async: bool = False,
n_candidates: int = 5,

@@ -238,6 +239,7 @@ class AiServiceClient:
is_async=is_async,
n_candidates=n_candidates,
)

def get_jit_rewritten_code( # noqa: D417
self, source_code: str, trace_id: str
) -> list[OptimizedCandidate]:

@@ -410,6 +412,7 @@ class AiServiceClient:

Returns:
List of refined optimization candidates

"""
payload = []
for opt in request:

@@ -727,7 +730,7 @@ class AiServiceClient:
language: str = "python",
language_version: str | None = None,
module_system: str | None = None,
is_numerical_code: bool | None = None, # noqa: FBT001
is_numerical_code: bool | None = None,
) -> tuple[str, str, str] | None:
"""Generate regression tests for the given function by making a request to the Django endpoint.
@@ -14,7 +14,6 @@ from __future__ import annotations

from dataclasses import dataclass, field
from enum import Enum
from pathlib import Path
from typing import Any

@@ -192,10 +191,7 @@ class TestGenRequest:
"""Convert to API payload dict, maintaining backward compatibility."""
payload = {
"source_code_being_tested": self.source_code,
"function_to_optimize": {
"function_name": self.function_name,
"is_async": self.is_async,
},
"function_to_optimize": {"function_name": self.function_name, "is_async": self.is_async},
"helper_function_names": self.helper_function_names,
"module_path": self.module_path,
"test_module_path": self.test_module_path,

@@ -243,24 +239,16 @@ def python_language_info(version: str | None = None) -> LanguageInfo:


def javascript_language_info(
module_system: ModuleSystem = ModuleSystem.COMMONJS,
version: str = "ES2022",
module_system: ModuleSystem = ModuleSystem.COMMONJS, version: str = "ES2022"
) -> LanguageInfo:
"""Create LanguageInfo for JavaScript."""
ext = ".mjs" if module_system == ModuleSystem.ESM else ".js"
return LanguageInfo(
name="javascript",
version=version,
module_system=module_system,
file_extension=ext,
has_type_annotations=False,
name="javascript", version=version, module_system=module_system, file_extension=ext, has_type_annotations=False
)


def typescript_language_info(
module_system: ModuleSystem = ModuleSystem.ESM,
version: str = "ES2022",
) -> LanguageInfo:
def typescript_language_info(module_system: ModuleSystem = ModuleSystem.ESM, version: str = "ES2022") -> LanguageInfo:
"""Create LanguageInfo for TypeScript."""
return LanguageInfo(
name="typescript",
@@ -1516,6 +1516,7 @@ def _customize_python_workflow_content(
codeflash_cmd += " --benchmark"
return optimize_yml_content.replace("{{ codeflash_command }}", codeflash_cmd)


# TODO:{claude} Refactor and move to support for language specific
def _customize_js_workflow_content(
optimize_yml_content: str,

@@ -108,6 +108,7 @@ def code_print(
function_name: Optional function name for LSP
lsp_message_id: Optional LSP message ID
language: Programming language for syntax highlighting ('python', 'javascript', 'typescript')

"""
if is_LSP_enabled():
lsp_log(

@@ -118,11 +119,7 @@ def code_print(
from rich.syntax import Syntax

# Map codeflash language names to rich/pygments lexer names
lexer_map = {
"python": "python",
"javascript": "javascript",
"typescript": "typescript",
}
lexer_map = {"python": "python", "javascript": "javascript", "typescript": "typescript"}
lexer = lexer_map.get(language, "python")

console.rule()
@@ -1,4 +1,5 @@
"""JavaScript/TypeScript project initialization for Codeflash."""

# TODO:{claude} move to language support directory
from __future__ import annotations

@@ -125,8 +126,7 @@ def init_js_project(language: ProjectLanguage) -> None:

lang_panel = Panel(
Text(
f"📦 Detected {lang_name} project!\n\n"
"I'll help you set up Codeflash for your project.",
f"📦 Detected {lang_name} project!\n\nI'll help you set up Codeflash for your project.",
style="cyan",
justify="center",
),

@@ -159,13 +159,13 @@ def init_js_project(language: ProjectLanguage) -> None:
usage_table.add_column("Command", style="cyan")
usage_table.add_column("Description", style="white")

usage_table.add_row(
"codeflash --file <path-to-file> --function <function-name>", "Optimize a specific function"
)
usage_table.add_row("codeflash --file <path-to-file> --function <function-name>", "Optimize a specific function")
usage_table.add_row("codeflash --all", "Optimize all functions in all files")
usage_table.add_row("codeflash --help", "See all available options")

completion_message = f"⚡️ Codeflash is now set up for your {lang_name} project!\n\nYou can now run any of these commands:"
completion_message = (
f"⚡️ Codeflash is now set up for your {lang_name} project!\n\nYou can now run any of these commands:"
)

if did_add_new_key:
completion_message += (

@@ -265,10 +265,7 @@ def collect_js_setup_info(language: ProjectLanguage) -> JSSetupInfo:
detection_table.add_row("Formatter", formatter_display)

detection_panel = Panel(
Group(
Text(f"Auto-detected settings for your {lang_name} project:\n", style="cyan"),
detection_table,
),
Group(Text(f"Auto-detected settings for your {lang_name} project:\n", style="cyan"), detection_table),
title="🔍 Auto-Detection Results",
border_style="bright_blue",
)

@@ -399,8 +396,7 @@ def _get_git_remote_for_setup() -> str:

git_panel = Panel(
Text(
"🔗 Configure Git Remote for Pull Requests.\n\n"
"Codeflash will use this remote to create pull requests.",
"🔗 Configure Git Remote for Pull Requests.\n\nCodeflash will use this remote to create pull requests.",
style="blue",
),
title="🔗 Git Remote Setup",
@@ -206,6 +206,7 @@ def find_package_json(config_file: Path | None = None) -> Path | None:

return None


def parse_package_json_config(package_json_path: Path) -> tuple[dict[str, Any], Path] | None:
"""Parse codeflash config from package.json with auto-detection.
@@ -8,14 +8,10 @@ from __future__ import annotations

import hashlib
import re
from typing import TYPE_CHECKING

from codeflash.code_utils.normalizers import get_normalizer
from codeflash.languages import current_language, is_python

if TYPE_CHECKING:
pass


def normalize_code(
code: str,

@@ -35,6 +31,7 @@ def normalize_code(

Returns:
Normalized code as string

"""
if language is None:
language = current_language().value

@@ -89,6 +86,7 @@ def get_code_fingerprint(code: str, language: str | None = None) -> str:

Returns:
SHA-256 hash of normalized code

"""
if language is None:
language = current_language().value

@@ -112,6 +110,7 @@ def are_codes_duplicate(code1: str, code2: str, language: str | None = None) ->

Returns:
True if codes are structurally identical (ignoring local variable names)

"""
if language is None:
language = current_language().value

@@ -127,8 +126,4 @@ def are_codes_duplicate(code1: str, code2: str, language: str | None = None) ->


# Re-export for backward compatibility
__all__ = [
"normalize_code",
"get_code_fingerprint",
"are_codes_duplicate",
]
__all__ = ["are_codes_duplicate", "get_code_fingerprint", "normalize_code"]
@@ -22,13 +22,10 @@ from codeflash.code_utils.normalizers.base import CodeNormalizer
from codeflash.code_utils.normalizers.javascript import JavaScriptNormalizer, TypeScriptNormalizer
from codeflash.code_utils.normalizers.python import PythonNormalizer

if TYPE_CHECKING:
pass

__all__ = [
"CodeNormalizer",
"PythonNormalizer",
"JavaScriptNormalizer",
"PythonNormalizer",
"TypeScriptNormalizer",
"get_normalizer",
"get_normalizer_for_extension",

@@ -56,6 +53,7 @@ def get_normalizer(language: str) -> CodeNormalizer:

Raises:
ValueError: If no normalizer exists for the language

"""
language = language.lower()

@@ -83,6 +81,7 @@ def get_normalizer_for_extension(extension: str) -> CodeNormalizer | None:

Returns:
CodeNormalizer instance if found, None otherwise

"""
extension = extension.lower()
if not extension.startswith("."):

@@ -102,6 +101,7 @@ def register_normalizer(language: str, normalizer_class: type[CodeNormalizer]) -
Args:
language: Language name
normalizer_class: CodeNormalizer subclass

"""
_NORMALIZERS[language.lower()] = normalizer_class
# Clear cached instance if it exists
@@ -4,14 +4,11 @@ Code normalizers transform source code into a canonical form for duplicate detec
They normalize variable names, remove comments/docstrings, and produce consistent output
that can be compared across different implementations of the same algorithm.
"""

# TODO:{claude} move to base.py in language folder
from __future__ import annotations

from abc import ABC, abstractmethod
from typing import TYPE_CHECKING

if TYPE_CHECKING:
pass


class CodeNormalizer(ABC):

@@ -30,6 +27,7 @@ class CodeNormalizer(ABC):
>>> code2 = "def foo(x): z = x + 1; return z"
>>> normalizer.normalize(code1) == normalizer.normalize(code2)
True

"""

@property

@@ -52,6 +50,7 @@ class CodeNormalizer(ABC):

Returns:
Normalized representation of the code

"""
...

@@ -66,6 +65,7 @@ class CodeNormalizer(ABC):

Returns:
Normalized representation suitable for hashing

"""
...

@@ -78,6 +78,7 @@ class CodeNormalizer(ABC):

Returns:
True if codes are structurally identical

"""
try:
normalized1 = self.normalize_for_hash(code1)

@@ -94,6 +95,7 @@ class CodeNormalizer(ABC):

Returns:
SHA-256 hash of normalized code

"""
import hashlib
@@ -10,7 +10,8 @@ from codeflash.code_utils.normalizers.base import CodeNormalizer
if TYPE_CHECKING:
from tree_sitter import Node

# TODO:{claude} move to language support directory to keep the directory structure clean

# TODO:{claude} move to language support directory to keep the directory structure clean
class JavaScriptVariableNormalizer:
"""Normalizes JavaScript/TypeScript code for duplicate detection using tree-sitter.

@@ -24,13 +25,49 @@ class JavaScriptVariableNormalizer:
self.preserved_names: set[str] = set()
# Common JavaScript builtins
self.builtins = {
"console", "window", "document", "Math", "JSON", "Object", "Array",
"String", "Number", "Boolean", "Date", "RegExp", "Error", "Promise",
"Map", "Set", "WeakMap", "WeakSet", "Symbol", "Proxy", "Reflect",
"undefined", "null", "NaN", "Infinity", "globalThis", "parseInt",
"parseFloat", "isNaN", "isFinite", "eval", "setTimeout", "setInterval",
"clearTimeout", "clearInterval", "fetch", "require", "module", "exports",
"process", "__dirname", "__filename", "Buffer",
"console",
"window",
"document",
"Math",
"JSON",
"Object",
"Array",
"String",
"Number",
"Boolean",
"Date",
"RegExp",
"Error",
"Promise",
"Map",
"Set",
"WeakMap",
"WeakSet",
"Symbol",
"Proxy",
"Reflect",
"undefined",
"null",
"NaN",
"Infinity",
"globalThis",
"parseInt",
"parseFloat",
"isNaN",
"isFinite",
"eval",
"setTimeout",
"setInterval",
"clearTimeout",
"clearInterval",
"fetch",
"require",
"module",
"exports",
"process",
"__dirname",
"__filename",
"Buffer",
}

def get_normalized_name(self, name: str) -> str:
@@ -48,7 +85,7 @@ class JavaScriptVariableNormalizer:
if node.type in ("function_declaration", "function_expression", "method_definition", "arrow_function"):
name_node = node.child_by_field_name("name")
if name_node:
self.preserved_names.add(source_code[name_node.start_byte:name_node.end_byte].decode("utf-8"))
self.preserved_names.add(source_code[name_node.start_byte : name_node.end_byte].decode("utf-8"))
# Preserve parameters
params_node = node.child_by_field_name("parameters") or node.child_by_field_name("parameter")
if params_node:

@@ -58,7 +95,7 @@ class JavaScriptVariableNormalizer:
elif node.type == "class_declaration":
name_node = node.child_by_field_name("name")
if name_node:
self.preserved_names.add(source_code[name_node.start_byte:name_node.end_byte].decode("utf-8"))
self.preserved_names.add(source_code[name_node.start_byte : name_node.end_byte].decode("utf-8"))

# Import declarations
elif node.type in ("import_statement", "import_declaration"):

@@ -66,7 +103,7 @@ class JavaScriptVariableNormalizer:
if child.type == "import_clause":
self._collect_import_names(child, source_code)
elif child.type == "identifier":
self.preserved_names.add(source_code[child.start_byte:child.end_byte].decode("utf-8"))
self.preserved_names.add(source_code[child.start_byte : child.end_byte].decode("utf-8"))

# Recurse
for child in node.children:

@@ -76,11 +113,13 @@ class JavaScriptVariableNormalizer:
"""Collect parameter names from a parameters node."""
for child in node.children:
if child.type == "identifier":
self.preserved_names.add(source_code[child.start_byte:child.end_byte].decode("utf-8"))
self.preserved_names.add(source_code[child.start_byte : child.end_byte].decode("utf-8"))
elif child.type in ("required_parameter", "optional_parameter", "rest_parameter"):
pattern_node = child.child_by_field_name("pattern")
if pattern_node and pattern_node.type == "identifier":
self.preserved_names.add(source_code[pattern_node.start_byte:pattern_node.end_byte].decode("utf-8"))
self.preserved_names.add(
source_code[pattern_node.start_byte : pattern_node.end_byte].decode("utf-8")
)
# Recurse for nested patterns
self._collect_parameter_names(child, source_code)

@@ -88,15 +127,15 @@ class JavaScriptVariableNormalizer:
"""Collect imported names from import clause."""
for child in node.children:
if child.type == "identifier":
self.preserved_names.add(source_code[child.start_byte:child.end_byte].decode("utf-8"))
self.preserved_names.add(source_code[child.start_byte : child.end_byte].decode("utf-8"))
elif child.type == "import_specifier":
# Get the local name (alias or original)
alias_node = child.child_by_field_name("alias")
name_node = child.child_by_field_name("name")
if alias_node:
self.preserved_names.add(source_code[alias_node.start_byte:alias_node.end_byte].decode("utf-8"))
self.preserved_names.add(source_code[alias_node.start_byte : alias_node.end_byte].decode("utf-8"))
elif name_node:
self.preserved_names.add(source_code[name_node.start_byte:name_node.end_byte].decode("utf-8"))
self.preserved_names.add(source_code[name_node.start_byte : name_node.end_byte].decode("utf-8"))
self._collect_import_names(child, source_code)

def normalize_tree(self, node: Node, source_code: bytes) -> str:

@@ -113,14 +152,14 @@ class JavaScriptVariableNormalizer:

# Handle identifiers - normalize variable names
if node.type == "identifier":
name = source_code[node.start_byte:node.end_byte].decode("utf-8")
name = source_code[node.start_byte : node.end_byte].decode("utf-8")
normalized = self.get_normalized_name(name)
parts.append(normalized)
return

# Handle type identifiers (TypeScript) - preserve as-is
if node.type == "type_identifier":
parts.append(source_code[node.start_byte:node.end_byte].decode("utf-8"))
parts.append(source_code[node.start_byte : node.end_byte].decode("utf-8"))
return

# Handle string literals - normalize to placeholder

@@ -135,7 +174,7 @@ class JavaScriptVariableNormalizer:

# For leaf nodes, output the node type
if len(node.children) == 0:
text = source_code[node.start_byte:node.end_byte].decode("utf-8")
text = source_code[node.start_byte : node.end_byte].decode("utf-8")
parts.append(text)
return

@@ -191,14 +230,12 @@ class JavaScriptNormalizer(CodeNormalizer):

Returns:
Normalized representation of the code

"""
try:
from codeflash.languages.treesitter_utils import TreeSitterAnalyzer, TreeSitterLanguage

lang_map = {
"javascript": TreeSitterLanguage.JAVASCRIPT,
"typescript": TreeSitterLanguage.TYPESCRIPT,
}
lang_map = {"javascript": TreeSitterLanguage.JAVASCRIPT, "typescript": TreeSitterLanguage.TYPESCRIPT}
lang = lang_map.get(self._get_tree_sitter_language(), TreeSitterLanguage.JAVASCRIPT)
analyzer = TreeSitterAnalyzer(lang)
tree = analyzer.parse(code)

@@ -227,6 +264,7 @@ class JavaScriptNormalizer(CodeNormalizer):

Returns:
Normalized representation suitable for hashing

"""
return self.normalize(code)
@@ -3,13 +3,9 @@
from __future__ import annotations

import ast
from typing import TYPE_CHECKING

from codeflash.code_utils.normalizers.base import CodeNormalizer

if TYPE_CHECKING:
pass


class VariableNormalizer(ast.NodeTransformer):
"""Normalizes only local variable names in AST to canonical forms like var_0, var_1, etc.

@@ -197,6 +193,7 @@ class PythonNormalizer(CodeNormalizer):

Returns:
Normalized Python code as a string

"""
tree = ast.parse(code)

@@ -219,6 +216,7 @@ class PythonNormalizer(CodeNormalizer):

Returns:
AST dump string suitable for hashing

"""
tree = ast.parse(code)
_remove_docstrings_from_ast(tree)
@@ -207,6 +207,7 @@ def get_code_optimization_context(
preexisting_objects=preexisting_objects,
)


def get_code_optimization_context_for_language(
function_to_optimize: FunctionToOptimize,
project_root_path: Path,

@@ -356,6 +357,7 @@ def get_code_optimization_context_for_language(
preexisting_objects=set(), # Not implemented for non-Python yet
)


def extract_code_markdown_context_from_files(
helpers_of_fto: dict[Path, set[FunctionSource]],
helpers_of_helpers: dict[Path, set[FunctionSource]],
@@ -28,8 +28,8 @@ from codeflash.code_utils.code_utils import (
module_name_from_file_path,
)
from codeflash.code_utils.compat import SAFE_SYS_EXECUTABLE, codeflash_cache_db
from codeflash.languages import is_javascript, is_python
from codeflash.code_utils.shell_utils import get_cross_platform_subprocess_run_args
from codeflash.languages import is_javascript, is_python
from codeflash.models.models import CodePosition, FunctionCalledInTest, TestsInFile, TestType

if TYPE_CHECKING:
@@ -1,5 +1,4 @@
"""
Multi-language support for Codeflash.
"""Multi-language support for Codeflash.

This package provides the abstraction layer that allows Codeflash to support
multiple programming languages while keeping the core optimization pipeline

@@ -37,6 +36,11 @@ from codeflash.languages.current import (
reset_current_language,
set_current_language,
)
from codeflash.languages.javascript import JavaScriptSupport, TypeScriptSupport # noqa: F401

# Import language support modules to trigger auto-registration
# This ensures all supported languages are available when this package is imported
from codeflash.languages.python import PythonSupport # noqa: F401
from codeflash.languages.registry import (
detect_project_language,
get_language_support,

@@ -45,11 +49,6 @@ from codeflash.languages.registry import (
register_language,
)

# Import language support modules to trigger auto-registration
# This ensures all supported languages are available when this package is imported
from codeflash.languages.python import PythonSupport # noqa: F401
from codeflash.languages.javascript import JavaScriptSupport, TypeScriptSupport # noqa: F401

__all__ = [
# Base types
"CodeContext",
@@ -507,10 +507,7 @@ class LanguageSupport(Protocol):
# === Test Result Comparison ===

def compare_test_results(
self,
original_results_path: Path,
candidate_results_path: Path,
project_root: Path | None = None,
self, original_results_path: Path, candidate_results_path: Path, project_root: Path | None = None
) -> tuple[bool, list]:
"""Compare test results between original and candidate code.
@@ -23,6 +23,7 @@ def _get_compare_results_script(project_root: Path | None = None) -> Path | None

Returns:
Path to compare-results.js if found, None otherwise.

"""
search_dirs = []
if project_root:

@@ -59,6 +60,7 @@ def compare_test_results(

Returns:
Tuple of (all_equivalent, list of TestDiff objects).

"""
script_path = comparator_script or _get_compare_results_script(project_root)
@@ -75,9 +75,7 @@ class StandaloneCallTransformer:
# Pattern to match func_name( with optional leading await and optional object prefix
# Captures: (whitespace)(await )?(object.)*func_name(
# We'll filter out expect() and codeflash. cases in the transform loop
self._call_pattern = re.compile(
rf"(\s*)(await\s+)?((?:\w+\.)*){re.escape(func_name)}\s*\("
)
self._call_pattern = re.compile(rf"(\s*)(await\s+)?((?:\w+\.)*){re.escape(func_name)}\s*\(")

def transform(self, code: str) -> str:
"""Transform all standalone calls in the code."""

@@ -353,9 +351,7 @@ class ExpectCallTransformer:
self.invocation_counter = 0
# Pattern to match start of expect((object.)*func_name(
# Captures: (whitespace), (object prefix like calc. or this.)
self._expect_pattern = re.compile(
rf"(\s*)expect\s*\(\s*((?:\w+\.)*){re.escape(func_name)}\s*\("
)
self._expect_pattern = re.compile(rf"(\s*)expect\s*\(\s*((?:\w+\.)*){re.escape(func_name)}\s*\(")

def transform(self, code: str) -> str:
"""Transform all expect calls in the code."""
@@ -23,7 +23,12 @@ from codeflash.languages.base import (
TestResult,
)
from codeflash.languages.registry import register_language
from codeflash.languages.treesitter_utils import TreeSitterAnalyzer, TreeSitterLanguage, TypeDefinition, get_analyzer_for_file
from codeflash.languages.treesitter_utils import (
TreeSitterAnalyzer,
TreeSitterLanguage,
TypeDefinition,
get_analyzer_for_file,
)

if TYPE_CHECKING:
from collections.abc import Sequence

@@ -346,7 +351,9 @@ class JavaScriptSupport:

# Wrap the method in a class definition with context
if class_jsdoc:
target_code = f"{class_jsdoc}\n{class_indent}class {class_name} {{\n{class_body}{class_indent}}}\n"
target_code = (
f"{class_jsdoc}\n{class_indent}class {class_name} {{\n{class_body}{class_indent}}}\n"
)
else:
target_code = f"{class_indent}class {class_name} {{\n{class_body}{class_indent}}}\n"
else:

@@ -366,11 +373,7 @@ class JavaScriptSupport:

# Extract type definitions for function parameters and class fields
type_definitions_context, type_definition_names = self._extract_type_definitions_context(
function=function,
source=source,
analyzer=analyzer,
imports=imports,
module_root=module_root,
function=function, source=source, analyzer=analyzer, imports=imports, module_root=module_root
)

# Find module-level declarations (global variables/constants) referenced by the function

@@ -715,12 +718,7 @@ class JavaScriptSupport:
return "\n".join(global_lines)

def _extract_type_definitions_context(
self,
function: FunctionInfo,
source: str,
analyzer: TreeSitterAnalyzer,
imports: list[Any],
module_root: Path,
self, function: FunctionInfo, source: str, analyzer: TreeSitterAnalyzer, imports: list[Any], module_root: Path
) -> tuple[str, set[str]]:
"""Extract type definitions used by the function for read-only context.

@@ -805,11 +803,7 @@ class JavaScriptSupport:
return "\n\n".join(type_def_parts), found_type_names

def _find_imported_type_definitions(
self,
type_names: set[str],
imports: list[Any],
module_root: Path,
source_file_path: Path,
self, type_names: set[str], imports: list[Any], module_root: Path, source_file_path: Path
) -> list[TypeDefinition]:
"""Find type definitions in imported files.

@@ -1583,9 +1577,7 @@ class JavaScriptSupport:

return None

def get_module_path(
self, source_file: Path, project_root: Path, tests_root: Path | None = None
) -> str:
def get_module_path(self, source_file: Path, project_root: Path, tests_root: Path | None = None) -> str:
"""Get the module path for importing a JavaScript source file from tests.

For JavaScript, this returns a relative path from the tests directory to the source file
@@ -1,5 +1,4 @@
"""
Python language support for Codeflash.
"""Python language support for Codeflash.

This module provides the PythonSupport class which wraps the existing
Python-specific implementations (LibCST, Jedi, pytest, etc.) to conform

@@ -73,7 +73,7 @@ class PythonSupport:
"""
import libcst as cst

from codeflash.discovery.functions_to_optimize import FunctionParent, FunctionToOptimize, FunctionVisitor
from codeflash.discovery.functions_to_optimize import FunctionToOptimize, FunctionVisitor

criteria = filter_criteria or FunctionFilterCriteria()

@@ -339,12 +339,7 @@ class PythonSupport:
# Try ruff first
try:
result = subprocess.run(
["ruff", "format", "-"],
check=False,
input=source,
capture_output=True,
text=True,
timeout=30,
["ruff", "format", "-"], check=False, input=source, capture_output=True, text=True, timeout=30
)
if result.returncode == 0:
return result.stdout

@@ -356,12 +351,7 @@ class PythonSupport:
# Try black as fallback
try:
result = subprocess.run(
["black", "-q", "-"],
check=False,
input=source,
capture_output=True,
text=True,
timeout=30,
["black", "-q", "-"], check=False, input=source, capture_output=True, text=True, timeout=30
)
if result.returncode == 0:
return result.stdout

@@ -397,19 +387,11 @@ class PythonSupport:
junit_xml = output_dir / "pytest-results.xml"

# Build pytest command
cmd = [
"python",
"-m",
"pytest",
f"--junitxml={junit_xml}",
"-v",
]
cmd = ["python", "-m", "pytest", f"--junitxml={junit_xml}", "-v"]
cmd.extend(str(f) for f in test_files)

try:
result = subprocess.run(
cmd, check=False, cwd=cwd, env=env, capture_output=True, text=True, timeout=timeout
)
result = subprocess.run(cmd, check=False, cwd=cwd, env=env, capture_output=True, text=True, timeout=timeout)
results = self.parse_test_results(junit_xml, result.stdout)
return results, junit_xml

@@ -653,11 +635,7 @@ class PythonSupport:

"""
# Common test directory patterns for Python
test_dirs = [
project_root / "tests",
project_root / "test",
project_root / "spec",
]
test_dirs = [project_root / "tests", project_root / "test", project_root / "spec"]

for test_dir in test_dirs:
if test_dir.exists() and test_dir.is_dir():

@@ -669,9 +647,7 @@ class PythonSupport:

return None

def get_module_path(
self, source_file: Path, project_root: Path, tests_root: Path | None = None
) -> str:
def get_module_path(self, source_file: Path, project_root: Path, tests_root: Path | None = None) -> str:
"""Get the module path for importing a Python source file.

For Python, this returns a dot-separated module path (e.g., 'mypackage.mymodule').

@@ -778,4 +754,4 @@ class PythonSupport:
# Note: For Python, test execution is handled by the main test_runner.py
# which has special Python-specific logic. These methods are not called
# for Python as the test_runner checks is_python() and uses the existing path.
# They are defined here only for protocol compliance.
# They are defined here only for protocol compliance.
@@ -1239,9 +1239,16 @@ class TreeSitterAnalyzer:

return type_names

def _find_function_node(self, node: Node, source_bytes: bytes, function_name: str, function_line: int) -> Node | None:
def _find_function_node(
self, node: Node, source_bytes: bytes, function_name: str, function_line: int
) -> Node | None:
"""Find a function/method node by name and line number."""
if node.type in ("function_declaration", "method_definition", "function_expression", "generator_function_declaration"):
if node.type in (
"function_declaration",
"method_definition",
"function_expression",
"generator_function_declaration",
):
name_node = node.child_by_field_name("name")
if name_node:
name = self.get_node_text(name_node, source_bytes)

@@ -1306,14 +1313,40 @@ class TreeSitterAnalyzer:
if node.type == "type_identifier":
type_name = self.get_node_text(node, source_bytes)
# Skip primitive types
if type_name not in ("number", "string", "boolean", "void", "null", "undefined", "any", "never", "unknown", "object", "symbol", "bigint"):
if type_name not in (
"number",
"string",
"boolean",
"void",
"null",
"undefined",
"any",
"never",
"unknown",
"object",
"symbol",
"bigint",
):
type_names.add(type_name)
return

# Handle regular identifiers in type position (can happen in some contexts)
if node.type == "identifier" and node.parent and node.parent.type in ("type_annotation", "generic_type"):
type_name = self.get_node_text(node, source_bytes)
if type_name not in ("number", "string", "boolean", "void", "null", "undefined", "any", "never", "unknown", "object", "symbol", "bigint"):
if type_name not in (
"number",
"string",
"boolean",
"void",
"null",
"undefined",
"any",
"never",
"unknown",
"object",
"symbol",
"bigint",
):
type_names.add(type_name)
return

@@ -1360,7 +1393,12 @@ class TreeSitterAnalyzer:
# Handle export statements - unwrap to get the actual definition
if node.type == "export_statement":
for child in node.children:
if child.type in ("interface_declaration", "type_alias_declaration", "class_declaration", "enum_declaration"):
if child.type in (
"interface_declaration",
"type_alias_declaration",
"class_declaration",
"enum_declaration",
):
self._extract_type_definition(child, source_bytes, definitions, is_exported=True)
return
@@ -41,7 +41,6 @@ from codeflash.code_utils.code_utils import (
extract_unique_errors,
file_name_from_test_module_name,
get_run_tmp_file,
module_name_from_file_path,
normalize_by_max,
restore_conftest,
unified_diff_strings,

@@ -81,7 +80,6 @@ from codeflash.languages import is_python
from codeflash.languages.base import FunctionInfo, Language
from codeflash.languages.current import current_language_support, is_typescript
from codeflash.languages.javascript.module_system import detect_module_system
from codeflash.languages.registry import get_language_support
from codeflash.lsp.helpers import is_LSP_enabled, report_to_markdown_table, tree_to_markdown
from codeflash.lsp.lsp_message import LspCodeMessage, LspMarkdownMessage, LSPMessageId
from codeflash.models.ExperimentMetadata import ExperimentMetadata
@@ -31,7 +31,9 @@ def existing_tests_source_for(
optimized_runtimes_all: dict[InvocationId, list[int]],
test_files_registry: TestFiles | None = None,
) -> tuple[str, str, str]:
logger.debug(f"[PR-DEBUG] existing_tests_source_for called with func={function_qualified_name_with_modules_from_root}")
logger.debug(
f"[PR-DEBUG] existing_tests_source_for called with func={function_qualified_name_with_modules_from_root}"
)
logger.debug(f"[PR-DEBUG] function_to_tests keys: {list(function_to_tests.keys())}")
logger.debug(f"[PR-DEBUG] original_runtimes_all has {len(original_runtimes_all)} entries")
logger.debug(f"[PR-DEBUG] optimized_runtimes_all has {len(optimized_runtimes_all)} entries")

@@ -60,11 +62,17 @@ def existing_tests_source_for(
for tf in test_files_registry.test_files:
if tf.original_file_path:
if tf.instrumented_behavior_file_path:
instrumented_to_original[tf.instrumented_behavior_file_path.resolve()] = tf.original_file_path.resolve()
logger.debug(f"[PR-DEBUG] Mapping (behavior): {tf.instrumented_behavior_file_path.name} -> {tf.original_file_path.name}")
instrumented_to_original[tf.instrumented_behavior_file_path.resolve()] = (
tf.original_file_path.resolve()
)
logger.debug(
f"[PR-DEBUG] Mapping (behavior): {tf.instrumented_behavior_file_path.name} -> {tf.original_file_path.name}"
)
if tf.benchmarking_file_path:
instrumented_to_original[tf.benchmarking_file_path.resolve()] = tf.original_file_path.resolve()
logger.debug(f"[PR-DEBUG] Mapping (perf): {tf.benchmarking_file_path.name} -> {tf.original_file_path.name}")
logger.debug(
f"[PR-DEBUG] Mapping (perf): {tf.benchmarking_file_path.name} -> {tf.original_file_path.name}"
)

# Resolve all paths to absolute for consistent comparison
non_generated_tests: set[Path] = set()

@@ -84,9 +92,22 @@ def existing_tests_source_for(
# For Python, it's a module name (e.g., "tests.test_example") that needs conversion
test_module_path = invocation_id.test_module_path
# Jest test file extensions (including .test.ts, .spec.ts patterns)
jest_test_extensions = (".test.ts", ".test.js", ".test.tsx", ".test.jsx",
".spec.ts", ".spec.js", ".spec.tsx", ".spec.jsx",
".ts", ".js", ".tsx", ".jsx", ".mjs", ".mts")
jest_test_extensions = (
".test.ts",
".test.js",
".test.tsx",
".test.jsx",
".spec.ts",
".spec.js",
".spec.tsx",
".spec.jsx",
".ts",
".js",
".tsx",
".jsx",
".mjs",
".mts",
)
# Find the appropriate extension
matched_ext = None
for ext in jest_test_extensions:

@@ -96,7 +117,7 @@ def existing_tests_source_for(
if matched_ext:
# JavaScript/TypeScript: convert module-style path to file path
# "tests.fibonacci__perfinstrumented.test.ts" -> "tests/fibonacci__perfinstrumented.test.ts"
base_path = test_module_path[:-len(matched_ext)]
base_path = test_module_path[: -len(matched_ext)]
# Convert dots to path separators in the base path
file_path = base_path.replace(".", os.sep) + matched_ext
# Check if the module path includes the tests directory name
@@ -25,10 +25,7 @@ if TYPE_CHECKING:


def generate_concolic_tests(
test_cfg: TestConfig,
args: Namespace,
function_to_optimize: FunctionToOptimize,
function_to_optimize_ast: ast.AST,
test_cfg: TestConfig, args: Namespace, function_to_optimize: FunctionToOptimize, function_to_optimize_ast: ast.AST
) -> tuple[dict[str, set[FunctionCalledInTest]], str]:
"""Generate concolic tests using CrossHair (Python only).

@@ -43,6 +40,7 @@ def generate_concolic_tests(

Returns:
Tuple of (function_to_tests mapping, concolic test suite code)

"""
start_time = time.perf_counter()
function_to_concolic_tests = {}
@@ -20,16 +20,14 @@ if TYPE_CHECKING:

from codeflash.models.models import CodeOptimizationContext


# TODO:{self} Needs cleanup for jest logic check for coverage algorithm here and if we need to move it to /support
class JestCoverageUtils:
"""Coverage utils class for interfacing with Jest coverage output."""

@staticmethod
def load_from_jest_json(
coverage_json_path: Path,
function_name: str,
code_context: CodeOptimizationContext,
source_code_path: Path,
coverage_json_path: Path, function_name: str, code_context: CodeOptimizationContext, source_code_path: Path
) -> CoverageData:
"""Load coverage data from Jest's coverage-final.json file.

@@ -41,6 +39,7 @@ class JestCoverageUtils:

Returns:
CoverageData object with parsed coverage information

"""
if not coverage_json_path or not coverage_json_path.exists():
logger.debug(f"Jest coverage file not found: {coverage_json_path}")
@@ -18,6 +18,7 @@ reprlib_repr = reprlib.Repr()
reprlib_repr.maxstring = 1500
test_diff_repr = reprlib_repr.repr


def safe_repr(obj: object) -> str:
"""Safely get repr of an object, handling Mock objects with corrupted state."""
try:

@@ -25,6 +26,7 @@ def safe_repr(obj: object) -> str:
except (AttributeError, TypeError, RecursionError) as e:
return f"<repr failed: {type(e).__name__}: {e}>"


def compare_test_results(
original_results: TestResults,
candidate_results: TestResults,
@@ -246,20 +246,30 @@ def parse_jest_json_results(
# Check behavior path
if test_file.instrumented_behavior_file_path:
try:
rel_path = str(test_file.instrumented_behavior_file_path.relative_to(test_config.tests_project_rootdir))
rel_path = str(
test_file.instrumented_behavior_file_path.relative_to(test_config.tests_project_rootdir)
)
except ValueError:
rel_path = test_file.instrumented_behavior_file_path.name
if rel_path == expected_path or rel_path.replace("/", ".").replace(".js", "") == result_module_path:
if (
rel_path == expected_path
or rel_path.replace("/", ".").replace(".js", "") == result_module_path
):
test_file_path = test_file.instrumented_behavior_file_path
test_type = test_file.test_type
break
# Check benchmarking path
if test_file.benchmarking_file_path:
try:
rel_path = str(test_file.benchmarking_file_path.relative_to(test_config.tests_project_rootdir))
rel_path = str(
test_file.benchmarking_file_path.relative_to(test_config.tests_project_rootdir)
)
except ValueError:
rel_path = test_file.benchmarking_file_path.name
if rel_path == expected_path or rel_path.replace("/", ".").replace(".js", "") == result_module_path:
if (
rel_path == expected_path
or rel_path.replace("/", ".").replace(".js", "") == result_module_path
):
test_file_path = test_file.benchmarking_file_path
test_type = test_file.test_type
break

@@ -416,9 +426,22 @@ def parse_sqlite_test_results(sqlite_file_path: Path, test_files: TestFiles, tes
# For Python, it's a module path (e.g., "tests.test_foo") that needs conversion
if is_jest:
# Jest test file extensions (including .test.ts, .spec.ts patterns)
jest_test_extensions = (".test.ts", ".test.js", ".test.tsx", ".test.jsx",
".spec.ts", ".spec.js", ".spec.tsx", ".spec.jsx",
".ts", ".js", ".tsx", ".jsx", ".mjs", ".mts")
jest_test_extensions = (
".test.ts",
".test.js",
".test.tsx",
".test.jsx",
".spec.ts",
".spec.js",
".spec.tsx",
".spec.jsx",
".ts",
".js",
".tsx",
".jsx",
".mjs",
".mts",
)
# Check if it's a module-style path (no slashes, has dots beyond extension)
if "/" not in test_module_path and "\\" not in test_module_path:
# Find the appropriate extension to preserve

@@ -430,7 +453,7 @@ def parse_sqlite_test_results(sqlite_file_path: Path, test_files: TestFiles, tes
if extension:
# Convert module-style path to file path
# "tests.fibonacci__perfinstrumented.test.ts" -> "tests/fibonacci__perfinstrumented.test.ts"
base_path = test_module_path[:-len(extension)]
base_path = test_module_path[: -len(extension)]
file_path = base_path.replace(".", os.sep) + extension
# Check if the module path includes the tests directory name
tests_dir_name = test_config.tests_project_rootdir.name

@@ -560,6 +583,7 @@ def _extract_jest_console_output(suite_elem) -> str:

return raw_content


# TODO: {Claude} we need to move to the support directory.
def parse_jest_test_xml(
test_xml_file_path: Path,

@@ -611,10 +635,7 @@ def parse_jest_test_xml(
if test_file.instrumented_behavior_file_path:
# Store both the absolute path and resolved path as keys
abs_path = str(test_file.instrumented_behavior_file_path.resolve())
instrumented_path_lookup[abs_path] = (
test_file.instrumented_behavior_file_path,
test_file.test_type,
)
instrumented_path_lookup[abs_path] = (test_file.instrumented_behavior_file_path, test_file.test_type)
# Also store the string representation in case of minor path differences
instrumented_path_lookup[str(test_file.instrumented_behavior_file_path)] = (
test_file.instrumented_behavior_file_path,
@@ -9,9 +9,7 @@ from pydantic.dataclasses import dataclass
from codeflash.languages import current_language_support, is_javascript


def get_test_file_path(
test_dir: Path, function_name: str, iteration: int = 0, test_type: str = "unit"
) -> Path:
def get_test_file_path(test_dir: Path, function_name: str, iteration: int = 0, test_type: str = "unit") -> Path:
assert test_type in {"unit", "inspired", "replay", "perf"}
function_name = function_name.replace(".", "_")
# Use appropriate file extension based on language
@@ -7,7 +7,7 @@ from codeflash.models.models import FunctionParent
from codeflash.optimization.optimizer import Optimizer


def test_benchmark_extract(benchmark)->None:
def test_benchmark_extract(benchmark) -> None:
file_path = Path(__file__).parent.parent.parent.resolve() / "codeflash"
opt = Optimizer(
Namespace(

@@ -28,4 +28,4 @@ def test_benchmark_extract(benchmark)->None:
ending_line=None,
)

benchmark(get_code_optimization_context,function_to_optimize, opt.args.project_root)
benchmark(get_code_optimization_context, function_to_optimize, opt.args.project_root)

@@ -14,6 +14,8 @@ def test_benchmark_code_to_optimize_test_discovery(benchmark) -> None:
tests_project_rootdir=tests_path.parent,
)
benchmark(discover_unit_tests, test_config)


def test_benchmark_codeflash_test_discovery(benchmark) -> None:
project_path = Path(__file__).parent.parent.parent.resolve() / "codeflash"
tests_path = project_path / "tests"
@@ -60,11 +60,7 @@ def run_merge_benchmark(count=100):
test_results_xml, test_results_bin = generate_test_invocations(count)

# Perform the merge operation that will be benchmarked
merge_test_results(
xml_test_results=test_results_xml,
bin_test_results=test_results_bin,
test_framework="unittest"
)
merge_test_results(xml_test_results=test_results_xml, bin_test_results=test_results_bin, test_framework="unittest")


def test_benchmark_merge_test_results(benchmark):
@@ -2,15 +2,14 @@ from __future__ import annotations

import configparser
import os
import stat
from pathlib import Path
from unittest.mock import patch

import pytest
import tomlkit

from codeflash.code_utils.code_utils import custom_addopts


def test_custom_addopts_modifies_and_restores_dotini_file(tmp_path: Path) -> None:
"""Verify that custom_addopts correctly modifies and then restores a pytest.ini file."""
# Create a dummy pytest.ini file

@@ -32,6 +31,7 @@ def test_custom_addopts_modifies_and_restores_dotini_file(tmp_path: Path) -> Non
restored_content = config_file.read_text()
assert restored_content.strip() == original_content.strip()


def test_custom_addopts_modifies_and_restores_ini_file(tmp_path: Path) -> None:
"""Verify that custom_addopts correctly modifies and then restores a pytest.ini file."""
# Create a dummy pytest.ini file

@@ -60,9 +60,7 @@ def test_custom_addopts_modifies_and_restores_toml_file(tmp_path: Path) -> None:
config_file = tmp_path / "pyproject.toml"
os.chdir(tmp_path)
original_addopts = "-v --cov=./src --junitxml=report.xml"
original_content_dict = {
"tool": {"pytest": {"ini_options": {"addopts": original_addopts}}}
}
original_content_dict = {"tool": {"pytest": {"ini_options": {"addopts": original_addopts}}}}
original_content = tomlkit.dumps(original_content_dict)
config_file.write_text(original_content)

@@ -97,6 +95,7 @@ def test_custom_addopts_handles_no_addopts(tmp_path: Path) -> None:
content_after_context = config_file.read_text()
assert content_after_context == original_content


def test_custom_addopts_handles_no_relevant_files(tmp_path: Path) -> None:
"""Ensure custom_addopts runs without error when no config files are found."""
# No config files created in tmp_path

@@ -151,9 +150,7 @@ def test_custom_addopts_with_multiple_config_files(tmp_path: Path) -> None:
# Create pyproject.toml
toml_file = tmp_path / "pyproject.toml"
toml_original_addopts = "-s -n auto"
toml_original_content_dict = {
"tool": {"pytest": {"ini_options": {"addopts": toml_original_addopts}}}
}
toml_original_content_dict = {"tool": {"pytest": {"ini_options": {"addopts": toml_original_addopts}}}}
toml_original_content = tomlkit.dumps(toml_original_content_dict)
toml_file.write_text(toml_original_content)

@@ -182,9 +179,8 @@ def test_custom_addopts_restores_on_exception(tmp_path: Path) -> None:
config_file.write_text(original_content)

os.chdir(tmp_path)
with pytest.raises(ValueError, match="Test exception"):
with custom_addopts():
raise ValueError("Test exception")
with pytest.raises(ValueError, match="Test exception"), custom_addopts():
raise ValueError("Test exception")

restored_content = config_file.read_text()
assert restored_content.strip() == original_content.strip()
@@ -106,10 +106,7 @@ class TestDetectLanguage:

def test_detects_typescript_with_complex_tsconfig(self, tmp_path: Path) -> None:
"""Should detect TypeScript even with complex tsconfig."""
tsconfig = {
"compilerOptions": {"target": "ES2020", "module": "commonjs"},
"include": ["src/**/*"],
}
tsconfig = {"compilerOptions": {"target": "ES2020", "module": "commonjs"}, "include": ["src/**/*"]}
(tmp_path / "tsconfig.json").write_text(json.dumps(tsconfig))

result = detect_language(tmp_path)

@@ -784,11 +781,7 @@ class TestRealWorldPackageJsonExamples:
package_json = tmp_path / "package.json"
package_json.write_text(
json.dumps(
{
"name": "@myorg/core",
"main": "./packages/core/src/index.js",
"devDependencies": {"jest": "^29.0.0"},
}
{"name": "@myorg/core", "main": "./packages/core/src/index.js", "devDependencies": {"jest": "^29.0.0"}}
)
)
@@ -5,8 +5,6 @@ from __future__ import annotations
import json
from pathlib import Path

import pytest

from codeflash.cli_cmds.init_javascript import (
JsPackageManager,
get_js_codeflash_install_step,

@@ -14,4 +14,4 @@ def set_python_language():
reset_current_language()
yield
# Reset again after test to clean up any changes
reset_current_language()
reset_current_language()
@@ -1 +1 @@
MY_CONSTANT = 7
MY_CONSTANT = 7
@@ -5,9 +5,7 @@ from end_to_end_test_utilities import TestConfig, run_codeflash_command, run_wit


def run_test(expected_improvement_pct: int) -> bool:
config = TestConfig(
file_path="bubble_sort.py", function_name="sorter", min_improvement_x=0.30, no_gen_tests=True
)
config = TestConfig(file_path="bubble_sort.py", function_name="sorter", min_improvement_x=0.30, no_gen_tests=True)
cwd = (pathlib.Path(__file__).parent.parent.parent / "code_to_optimize").resolve()
return run_codeflash_command(cwd, config, expected_improvement_pct)

@@ -18,12 +18,7 @@ def run_test() -> bool:
expected_test_files=1, # At least one test file should be instrumented
)

cwd = (
pathlib.Path(__file__).parent.parent.parent
/ "code_to_optimize"
/ "js"
/ "code_to_optimize_js"
).resolve()
cwd = (pathlib.Path(__file__).parent.parent.parent / "code_to_optimize" / "js" / "code_to_optimize_js").resolve()

return run_js_codeflash_command(cwd, config)

@@ -20,10 +20,7 @@ def run_test() -> bool:
)

cwd = (
pathlib.Path(__file__).parent.parent.parent
/ "code_to_optimize"
/ "js"
/ "code_to_optimize_js_esm"
pathlib.Path(__file__).parent.parent.parent / "code_to_optimize" / "js" / "code_to_optimize_js_esm"
).resolve()

return run_js_codeflash_command(cwd, config)

@@ -18,12 +18,7 @@ def run_test() -> bool:
expected_test_files=1,
)

cwd = (
pathlib.Path(__file__).parent.parent.parent
/ "code_to_optimize"
/ "js"
/ "code_to_optimize_ts"
).resolve()
cwd = (pathlib.Path(__file__).parent.parent.parent / "code_to_optimize" / "js" / "code_to_optimize_ts").resolve()

return run_js_codeflash_command(cwd, config)
@ -86,12 +86,14 @@ def validate_coverage(stdout: str, expectations: list[CoverageExpectation]) -> b
|
|||
|
||||
return True
|
||||
|
||||
|
||||
def validate_no_gen_tests(stdout: str) -> bool:
|
||||
if "Generated '0' tests for" not in stdout:
|
||||
logging.error("Tests generated even when flag was on")
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def run_codeflash_command(
|
||||
cwd: pathlib.Path, config: TestConfig, expected_improvement_pct: int, expected_in_stdout: list[str] = None
|
||||
) -> bool:
|
||||
|
|
@ -106,9 +108,9 @@ def run_codeflash_command(
|
|||
|
||||
command = build_command(cwd, config, test_root, config.benchmarks_root if config.benchmarks_root else None)
|
||||
env = os.environ.copy()
|
||||
env['PYTHONIOENCODING'] = 'utf-8'
|
||||
env["PYTHONIOENCODING"] = "utf-8"
|
||||
process = subprocess.Popen(
|
||||
command, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True, cwd=str(cwd), env=env, encoding='utf-8'
|
||||
command, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True, cwd=str(cwd), env=env, encoding="utf-8"
|
||||
)
|
||||
|
||||
output = []
|
||||
|
|
@ -131,7 +133,7 @@ def run_codeflash_command(
|
|||
if not stdout_validated:
|
||||
logging.error("Failed to find expected output in candidate output")
|
||||
validated = False
|
||||
logging.info(f"Success: Expected output found in candidate output")
|
||||
logging.info("Success: Expected output found in candidate output")
|
||||
|
||||
return validated
|
||||
|
||||
|
|
@ -150,10 +152,9 @@ def build_command(
|
|||
pyproject_path = cwd / "pyproject.toml"
|
||||
has_codeflash_config = False
|
||||
if pyproject_path.exists():
|
||||
with contextlib.suppress(Exception):
|
||||
with open(pyproject_path, "rb") as f:
|
||||
pyproject_data = tomllib.load(f)
|
||||
has_codeflash_config = "tool" in pyproject_data and "codeflash" in pyproject_data["tool"]
|
||||
with contextlib.suppress(Exception), open(pyproject_path, "rb") as f:
|
||||
pyproject_data = tomllib.load(f)
|
||||
has_codeflash_config = "tool" in pyproject_data and "codeflash" in pyproject_data["tool"]
|
||||
|
||||
# Only pass --tests-root and --module-root if they're not configured in pyproject.toml
|
||||
if not has_codeflash_config:
|
||||
|
|
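As an aside for readers following the lint fixes: the hunk above collapses a nested contextlib.suppress/open pair into a single with statement. A minimal, self-contained sketch of that pattern is below; the file path and helper name are hypothetical, and only the combined with statement mirrors the diff.

    import contextlib
    import tomllib  # standard library on Python 3.11+; the tomli package offers the same API on older versions
    from pathlib import Path

    def has_codeflash_section(pyproject_path: Path) -> bool:
        # One with statement drives both context managers, as in the refactored code above.
        found = False
        with contextlib.suppress(Exception), open(pyproject_path, "rb") as f:
            data = tomllib.load(f)
            found = "tool" in data and "codeflash" in data["tool"]
        return found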
@ -206,7 +207,9 @@ def validate_output(stdout: str, return_code: int, expected_improvement_pct: int
|
|||
if config.expected_unit_tests_count is not None:
|
||||
# Match the global test discovery message from optimizer.py which counts test invocations
|
||||
# Format: "Discovered X existing unit tests and Y replay tests in Z.Zs at /path/to/tests"
|
||||
unit_test_match = re.search(r"Discovered (\d+) existing unit tests? and \d+ replay tests? in [\d.]+s at", stdout)
|
||||
unit_test_match = re.search(
|
||||
r"Discovered (\d+) existing unit tests? and \d+ replay tests? in [\d.]+s at", stdout
|
||||
)
|
||||
if not unit_test_match:
|
||||
logging.error("Could not find global unit test count")
|
||||
return False
|
||||
|
|
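To illustrate the discovery-message parsing reformatted above, this short sketch runs the same regular expression against a made-up log line; the sample text, counts, and path are invented, and only the pattern comes from the diff.

    import re

    sample = "Discovered 12 existing unit tests and 3 replay tests in 0.48s at /repo/tests"  # hypothetical line
    match = re.search(r"Discovered (\d+) existing unit tests? and \d+ replay tests? in [\d.]+s at", sample)
    if match:
        print(int(match.group(1)))  # prints 12, the global unit test count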
@ -250,9 +253,9 @@ def run_trace_test(cwd: pathlib.Path, config: TestConfig, expected_improvement_p
|
|||
clear_directory(test_root)
|
||||
command = ["uv", "run", "--no-project", "-m", "codeflash.main", "optimize", "workload.py"]
|
||||
env = os.environ.copy()
|
||||
env['PYTHONIOENCODING'] = 'utf-8'
|
||||
env["PYTHONIOENCODING"] = "utf-8"
|
||||
process = subprocess.Popen(
|
||||
command, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True, cwd=str(cwd), env=env, encoding='utf-8'
|
||||
command, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True, cwd=str(cwd), env=env, encoding="utf-8"
|
||||
)
|
||||
|
||||
output = []
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ import re
|
|||
import shutil
|
||||
import subprocess
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
|
||||
|
|
@ -42,36 +42,20 @@ def install_npm_dependencies(cwd: pathlib.Path) -> bool:
|
|||
node_modules = cwd / "node_modules"
|
||||
if not node_modules.exists():
|
||||
logging.info(f"Installing npm dependencies in {cwd}")
|
||||
result = subprocess.run(
|
||||
["npm", "install"],
|
||||
cwd=str(cwd),
|
||||
capture_output=True,
|
||||
text=True,
|
||||
)
|
||||
result = subprocess.run(["npm", "install"], cwd=str(cwd), capture_output=True, text=True)
|
||||
if result.returncode != 0:
|
||||
logging.error(f"npm install failed: {result.stderr}")
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def build_js_command(
|
||||
cwd: pathlib.Path,
|
||||
config: JSTestConfig,
|
||||
) -> list[str]:
|
||||
def build_js_command(cwd: pathlib.Path, config: JSTestConfig) -> list[str]:
|
||||
"""Build the codeflash CLI command for JS/TS optimization."""
|
||||
# JS projects are at code_to_optimize/js/code_to_optimize_*, which is 3 levels deep
|
||||
# So we need ../../../codeflash/main.py to get to the root
|
||||
python_path = "../../../codeflash/main.py"
|
||||
|
||||
base_command = [
|
||||
"uv",
|
||||
"run",
|
||||
"--no-project",
|
||||
python_path,
|
||||
"--file",
|
||||
str(config.file_path),
|
||||
"--no-pr",
|
||||
]
|
||||
base_command = ["uv", "run", "--no-project", python_path, "--file", str(config.file_path), "--no-pr"]
|
||||
|
||||
if config.function_name:
|
||||
base_command.extend(["--function", config.function_name])
|
||||
|
|
@ -79,11 +63,7 @@ def build_js_command(
|
|||
return base_command
|
||||
|
||||
|
||||
def validate_js_output(
|
||||
stdout: str,
|
||||
return_code: int,
|
||||
config: JSTestConfig,
|
||||
) -> bool:
|
||||
def validate_js_output(stdout: str, return_code: int, config: JSTestConfig) -> bool:
|
||||
"""Validate the output of a JS/TS optimization run."""
|
||||
if return_code != 0:
|
||||
logging.error(f"Command returned exit code {return_code} instead of 0")
|
||||
|
|
@ -104,15 +84,11 @@ def validate_js_output(
|
|||
logging.info(f"Performance improvement: {improvement_pct}%; Rate: {improvement_x}x")
|
||||
|
||||
if improvement_pct <= config.expected_improvement_pct:
|
||||
logging.error(
|
||||
f"Performance improvement {improvement_pct}% not above {config.expected_improvement_pct}%"
|
||||
)
|
||||
logging.error(f"Performance improvement {improvement_pct}% not above {config.expected_improvement_pct}%")
|
||||
return False
|
||||
|
||||
if improvement_x <= config.min_improvement_x:
|
||||
logging.error(
|
||||
f"Performance improvement rate {improvement_x}x not above {config.min_improvement_x}x"
|
||||
)
|
||||
logging.error(f"Performance improvement rate {improvement_x}x not above {config.min_improvement_x}x")
|
||||
return False
|
||||
|
||||
if config.expected_test_files is not None:
|
||||
|
|
@ -124,19 +100,14 @@ def validate_js_output(
|
|||
|
||||
num_test_files = int(test_files_match.group(1))
|
||||
if num_test_files < config.expected_test_files:
|
||||
logging.error(
|
||||
f"Expected at least {config.expected_test_files} test files, found {num_test_files}"
|
||||
)
|
||||
logging.error(f"Expected at least {config.expected_test_files} test files, found {num_test_files}")
|
||||
return False
|
||||
|
||||
logging.info(f"Success: Performance improvement is {improvement_pct}%")
|
||||
return True
|
||||
|
||||
|
||||
def run_js_codeflash_command(
|
||||
cwd: pathlib.Path,
|
||||
config: JSTestConfig,
|
||||
) -> bool:
|
||||
def run_js_codeflash_command(cwd: pathlib.Path, config: JSTestConfig) -> bool:
|
||||
"""Run codeflash optimization on a JavaScript/TypeScript project."""
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
|
|
@ -159,13 +130,7 @@ def run_js_codeflash_command(
|
|||
logging.info(f"Running: {' '.join(command)}")
|
||||
|
||||
process = subprocess.Popen(
|
||||
command,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.STDOUT,
|
||||
text=True,
|
||||
cwd=str(cwd),
|
||||
env=env,
|
||||
encoding="utf-8",
|
||||
command, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True, cwd=str(cwd), env=env, encoding="utf-8"
|
||||
)
|
||||
|
||||
output = []
|
||||
|
|
@ -210,4 +175,4 @@ def run_with_retries(test_func, *args, **kwargs) -> int:
|
|||
logging.error("Test failed after all retries")
|
||||
return 1
|
||||
|
||||
return 1
|
||||
return 1
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ Examples:
|
|||
python run_js_e2e_tests.py # Run all tests sequentially
|
||||
python run_js_e2e_tests.py --test fibonacci # Run only fibonacci tests
|
||||
python run_js_e2e_tests.py --parallel # Run tests in parallel
|
||||
|
||||
"""
|
||||
|
||||
import argparse
|
||||
|
|
@ -62,19 +63,14 @@ def run_single_test(test_file: str) -> TestResult:
|
|||
output = f"Error running test: {e}"
|
||||
|
||||
duration = time.time() - start_time
|
||||
return TestResult(
|
||||
name=test_file.replace(".py", ""),
|
||||
success=success,
|
||||
duration=duration,
|
||||
output=output,
|
||||
)
|
||||
return TestResult(name=test_file.replace(".py", ""), success=success, duration=duration, output=output)
|
||||
|
||||
|
||||
def run_tests_sequential(tests: list[str]) -> list[TestResult]:
|
||||
"""Run tests sequentially."""
|
||||
results = []
|
||||
for test in tests:
|
||||
print(f"\n{'='*60}")
|
||||
print(f"\n{'=' * 60}")
|
||||
print(f"Running: {test}")
|
||||
print("=" * 60)
|
||||
result = run_single_test(test)
|
||||
|
|
@ -124,22 +120,9 @@ def print_summary(results: list[TestResult]) -> None:
|
|||
|
||||
def main() -> int:
|
||||
parser = argparse.ArgumentParser(description="Run JS/TS e2e tests")
|
||||
parser.add_argument(
|
||||
"--test",
|
||||
type=str,
|
||||
help="Run only tests matching this pattern",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--parallel",
|
||||
action="store_true",
|
||||
help="Run tests in parallel",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--workers",
|
||||
type=int,
|
||||
default=4,
|
||||
help="Number of parallel workers (default: 4)",
|
||||
)
|
||||
parser.add_argument("--test", type=str, help="Run only tests matching this pattern")
|
||||
parser.add_argument("--parallel", action="store_true", help="Run tests in parallel")
|
||||
parser.add_argument("--workers", type=int, default=4, help="Number of parallel workers (default: 4)")
|
||||
args = parser.parse_args()
|
||||
|
||||
# Filter tests if pattern specified
|
||||
|
|
|
|||
|
|
@ -1,13 +1,18 @@
|
|||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
from codeflash.code_utils.code_extractor import add_needed_imports_from_module, find_preexisting_objects
|
||||
from codeflash.code_utils.code_replacer import replace_functions_and_add_imports
|
||||
|
||||
import tempfile
|
||||
from codeflash.code_utils.code_extractor import resolve_star_import, DottedImportCollector
|
||||
import libcst as cst
|
||||
|
||||
from codeflash.code_utils.code_extractor import (
|
||||
DottedImportCollector,
|
||||
add_needed_imports_from_module,
|
||||
find_preexisting_objects,
|
||||
resolve_star_import,
|
||||
)
|
||||
from codeflash.code_utils.code_replacer import replace_functions_and_add_imports
|
||||
from codeflash.models.models import FunctionParent
|
||||
|
||||
|
||||
def test_add_needed_imports_from_module0() -> None:
|
||||
src_module = '''import ast
|
||||
import logging
|
||||
|
|
@ -127,8 +132,9 @@ def belongs_to_function(name: Name, function_name: str) -> bool:
|
|||
new_module = add_needed_imports_from_module(src_module, dst_module, src_path, dst_path, project_root)
|
||||
assert new_module == expected
|
||||
|
||||
|
||||
def test_duplicated_imports() -> None:
|
||||
optim_code = '''from dataclasses import dataclass
|
||||
optim_code = """from dataclasses import dataclass
|
||||
from recce.adapter.base import BaseAdapter
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
|
|
@ -151,9 +157,9 @@ class DbtAdapter(BaseAdapter):
|
|||
parent_map[k] = [parent for parent in parents if parent in node_ids]
|
||||
|
||||
return parent_map
|
||||
'''
|
||||
"""
|
||||
|
||||
original_code = '''import json
|
||||
original_code = """import json
|
||||
import logging
|
||||
import os
|
||||
import uuid
|
||||
|
|
@ -244,8 +250,8 @@ class DbtAdapter(BaseAdapter):
|
|||
parent_map[k] = [parent for parent in parents if parent in node_ids]
|
||||
|
||||
return parent_map
|
||||
'''
|
||||
expected = '''import json
|
||||
"""
|
||||
expected = """import json
|
||||
import logging
|
||||
import os
|
||||
import uuid
|
||||
|
|
@ -340,7 +346,7 @@ class DbtAdapter(BaseAdapter):
|
|||
parent_map[k] = [parent for parent in parents if parent in node_ids]
|
||||
|
||||
return parent_map
|
||||
'''
|
||||
"""
|
||||
|
||||
function_name: str = "DbtAdapter.build_parent_map"
|
||||
preexisting_objects: set[tuple[str, tuple[FunctionParent, ...]]] = find_preexisting_objects(original_code)
|
||||
|
|
@ -355,14 +361,12 @@ class DbtAdapter(BaseAdapter):
|
|||
assert new_code == expected
|
||||
|
||||
|
||||
|
||||
|
||||
def test_resolve_star_import_with_all_defined():
|
||||
"""Test resolve_star_import when __all__ is explicitly defined."""
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
project_root = Path(tmpdir)
|
||||
test_module = project_root / 'test_module.py'
|
||||
|
||||
test_module = project_root / "test_module.py"
|
||||
|
||||
# Create a test module with __all__ definition
|
||||
test_module.write_text('''
|
||||
__all__ = ['public_function', 'PublicClass']
|
||||
|
|
@ -380,9 +384,9 @@ class AnotherPublicClass:
|
|||
"""Not in __all__ so should be excluded."""
|
||||
pass
|
||||
''')
|
||||
|
||||
symbols = resolve_star_import('test_module', project_root)
|
||||
expected_symbols = {'public_function', 'PublicClass'}
|
||||
|
||||
symbols = resolve_star_import("test_module", project_root)
|
||||
expected_symbols = {"public_function", "PublicClass"}
|
||||
assert symbols == expected_symbols
|
||||
|
||||
|
||||
|
|
@ -390,10 +394,10 @@ def test_resolve_star_import_without_all_defined():
|
|||
"""Test resolve_star_import when __all__ is not defined - should include all public symbols."""
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
project_root = Path(tmpdir)
|
||||
test_module = project_root / 'test_module.py'
|
||||
|
||||
test_module = project_root / "test_module.py"
|
||||
|
||||
# Create a test module without __all__ definition
|
||||
test_module.write_text('''
|
||||
test_module.write_text("""
|
||||
def public_func():
|
||||
pass
|
||||
|
||||
|
|
@ -405,10 +409,10 @@ class PublicClass:
|
|||
|
||||
PUBLIC_VAR = 42
|
||||
_private_var = 'secret'
|
||||
''')
|
||||
|
||||
symbols = resolve_star_import('test_module', project_root)
|
||||
expected_symbols = {'public_func', 'PublicClass', 'PUBLIC_VAR'}
|
||||
""")
|
||||
|
||||
symbols = resolve_star_import("test_module", project_root)
|
||||
expected_symbols = {"public_func", "PublicClass", "PUBLIC_VAR"}
|
||||
assert symbols == expected_symbols
|
||||
|
||||
|
||||
|
|
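The two tests above pin down the expected star-import behavior: with __all__ present only the listed names are exported, and without it every top-level name that does not start with an underscore is. The sketch below reproduces the second rule with the standard ast module; it illustrates the expected behavior only and is not the actual resolve_star_import implementation.

    import ast

    def public_top_level_names(source: str) -> set[str]:
        # Collect top-level functions, classes, and assigned names, then drop leading-underscore names.
        names: set[str] = set()
        for node in ast.parse(source).body:
            if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)):
                names.add(node.name)
            elif isinstance(node, ast.Assign):
                names.update(t.id for t in node.targets if isinstance(t, ast.Name))
        return {name for name in names if not name.startswith("_")}

    # Mirrors the module body used in the test above.
    print(public_top_level_names("def public_func(): ...\nclass PublicClass: ...\nPUBLIC_VAR = 42\n_private_var = 'x'\n"))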
@ -416,26 +420,26 @@ def test_resolve_star_import_nonexistent_module():
|
|||
"""Test resolve_star_import with non-existent module - should return empty set."""
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
project_root = Path(tmpdir)
|
||||
|
||||
symbols = resolve_star_import('nonexistent_module', project_root)
|
||||
|
||||
symbols = resolve_star_import("nonexistent_module", project_root)
|
||||
assert symbols == set()
|
||||
|
||||
|
||||
def test_dotted_import_collector_skips_star_imports():
|
||||
"""Test that DottedImportCollector correctly skips star imports."""
|
||||
code_with_star_import = '''
|
||||
code_with_star_import = """
|
||||
from typing import *
|
||||
from pathlib import Path
|
||||
from collections import defaultdict
|
||||
import os
|
||||
'''
|
||||
|
||||
"""
|
||||
|
||||
module = cst.parse_module(code_with_star_import)
|
||||
collector = DottedImportCollector()
|
||||
module.visit(collector)
|
||||
|
||||
|
||||
# Should collect regular imports but skip the star import
|
||||
expected_imports = {'collections.defaultdict', 'os', 'pathlib.Path'}
|
||||
expected_imports = {"collections.defaultdict", "os", "pathlib.Path"}
|
||||
assert collector.imports == expected_imports
|
||||
|
||||
|
||||
|
|
@ -443,10 +447,10 @@ def test_add_needed_imports_with_star_import_resolution():
|
|||
"""Test add_needed_imports_from_module correctly handles star imports by resolving them."""
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
project_root = Path(tmpdir)
|
||||
|
||||
|
||||
# Create a source module that exports symbols
|
||||
src_module = project_root / 'source_module.py'
|
||||
src_module.write_text('''
|
||||
src_module = project_root / "source_module.py"
|
||||
src_module.write_text("""
|
||||
__all__ = ['UtilFunction', 'HelperClass']
|
||||
|
||||
def UtilFunction():
|
||||
|
|
@ -454,40 +458,38 @@ def UtilFunction():
|
|||
|
||||
class HelperClass:
|
||||
pass
|
||||
''')
|
||||
|
||||
""")
|
||||
|
||||
# Create source code that uses star import
|
||||
src_code = '''
|
||||
src_code = """
|
||||
from source_module import *
|
||||
|
||||
def my_function():
|
||||
helper = HelperClass()
|
||||
UtilFunction()
|
||||
return helper
|
||||
'''
|
||||
|
||||
"""
|
||||
|
||||
# Destination code that needs the imports resolved
|
||||
dst_code = '''
|
||||
dst_code = """
|
||||
def my_function():
|
||||
helper = HelperClass()
|
||||
UtilFunction()
|
||||
return helper
|
||||
'''
|
||||
|
||||
src_path = project_root / 'src.py'
|
||||
dst_path = project_root / 'dst.py'
|
||||
"""
|
||||
|
||||
src_path = project_root / "src.py"
|
||||
dst_path = project_root / "dst.py"
|
||||
src_path.write_text(src_code)
|
||||
|
||||
result = add_needed_imports_from_module(
|
||||
src_code, dst_code, src_path, dst_path, project_root
|
||||
)
|
||||
|
||||
|
||||
result = add_needed_imports_from_module(src_code, dst_code, src_path, dst_path, project_root)
|
||||
|
||||
# The result should have individual imports instead of star import
|
||||
expected_result = '''from source_module import HelperClass, UtilFunction
|
||||
expected_result = """from source_module import HelperClass, UtilFunction
|
||||
|
||||
def my_function():
|
||||
helper = HelperClass()
|
||||
UtilFunction()
|
||||
return helper
|
||||
'''
|
||||
"""
|
||||
assert result == expected_result
|
||||
|
|
|
|||
File diff suppressed because it is too large
|
|
@ -157,10 +157,7 @@ class TestParseConcurrencyMetrics:
|
|||
!@######CONC:test_module:TestClass:test_func:my_async_func:1:50000000:10000000:5######@!
|
||||
More output here
|
||||
"""
|
||||
test_results = TestResults(
|
||||
test_results=[],
|
||||
perf_stdout=perf_stdout,
|
||||
)
|
||||
test_results = TestResults(test_results=[], perf_stdout=perf_stdout)
|
||||
|
||||
metrics = parse_concurrency_metrics(test_results, "my_async_func")
|
||||
|
||||
|
|
@ -177,10 +174,7 @@ More output here
|
|||
!@######CONC:test_module:TestClass:test_func:target_func:2:60000000:10000000:5######@!
|
||||
!@######CONC:test_module:TestClass:test_func:other_func:1:30000000:15000000:5######@!
|
||||
"""
|
||||
test_results = TestResults(
|
||||
test_results=[],
|
||||
perf_stdout=perf_stdout,
|
||||
)
|
||||
test_results = TestResults(test_results=[], perf_stdout=perf_stdout)
|
||||
|
||||
metrics = parse_concurrency_metrics(test_results, "target_func")
|
||||
|
||||
|
|
@ -195,10 +189,7 @@ More output here
|
|||
"""Test parsing when function name doesn't match."""
|
||||
perf_stdout = """!@######CONC:test_module:TestClass:test_func:other_func:1:50000000:10000000:5######@!
|
||||
"""
|
||||
test_results = TestResults(
|
||||
test_results=[],
|
||||
perf_stdout=perf_stdout,
|
||||
)
|
||||
test_results = TestResults(test_results=[], perf_stdout=perf_stdout)
|
||||
|
||||
metrics = parse_concurrency_metrics(test_results, "nonexistent_func")
|
||||
|
||||
|
|
@ -206,10 +197,7 @@ More output here
|
|||
|
||||
def test_parse_concurrency_metrics_empty_stdout(self):
|
||||
"""Test parsing with empty stdout."""
|
||||
test_results = TestResults(
|
||||
test_results=[],
|
||||
perf_stdout="",
|
||||
)
|
||||
test_results = TestResults(test_results=[], perf_stdout="")
|
||||
|
||||
metrics = parse_concurrency_metrics(test_results, "any_func")
|
||||
|
||||
|
|
@ -217,10 +205,7 @@ More output here
|
|||
|
||||
def test_parse_concurrency_metrics_none_stdout(self):
|
||||
"""Test parsing with None stdout."""
|
||||
test_results = TestResults(
|
||||
test_results=[],
|
||||
perf_stdout=None,
|
||||
)
|
||||
test_results = TestResults(test_results=[], perf_stdout=None)
|
||||
|
||||
metrics = parse_concurrency_metrics(test_results, "any_func")
|
||||
|
||||
|
|
@ -293,8 +278,7 @@ class TestConcurrencyRatioComparison:
|
|||
|
||||
# Non-blocking should have significantly higher concurrency ratio
|
||||
assert nonblocking_ratio > blocking_ratio, (
|
||||
f"Non-blocking ratio ({nonblocking_ratio:.2f}) should be greater than "
|
||||
f"blocking ratio ({blocking_ratio:.2f})"
|
||||
f"Non-blocking ratio ({nonblocking_ratio:.2f}) should be greater than blocking ratio ({blocking_ratio:.2f})"
|
||||
)
|
||||
|
||||
# The difference should be substantial (non-blocking should be at least 2x better)
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
import sys
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
import sys
|
||||
|
||||
import pytest
|
||||
|
||||
from codeflash.discovery.functions_to_optimize import (
|
||||
|
|
@ -31,13 +32,13 @@ async def async_function_without_return():
|
|||
def regular_function():
|
||||
return 10
|
||||
"""
|
||||
|
||||
|
||||
file_path = temp_dir / "test_file.py"
|
||||
file_path.write_text(async_function)
|
||||
functions_found = find_all_functions_in_file(file_path)
|
||||
|
||||
|
||||
function_names = [fn.function_name for fn in functions_found[file_path]]
|
||||
|
||||
|
||||
assert "async_function_with_return" in function_names
|
||||
assert "regular_function" in function_names
|
||||
assert "async_function_without_return" not in function_names
|
||||
|
|
@ -58,21 +59,21 @@ class AsyncClass:
|
|||
def sync_method(self):
|
||||
return "sync result"
|
||||
"""
|
||||
|
||||
|
||||
file_path = temp_dir / "test_file.py"
|
||||
file_path.write_text(code_with_async_method)
|
||||
functions_found = find_all_functions_in_file(file_path)
|
||||
|
||||
|
||||
found_functions = functions_found[file_path]
|
||||
function_names = [fn.function_name for fn in found_functions]
|
||||
qualified_names = [fn.qualified_name for fn in found_functions]
|
||||
|
||||
|
||||
assert "async_method" in function_names
|
||||
assert "AsyncClass.async_method" in qualified_names
|
||||
|
||||
|
||||
assert "sync_method" in function_names
|
||||
assert "AsyncClass.sync_method" in qualified_names
|
||||
|
||||
|
||||
assert "async_method_no_return" not in function_names
|
||||
|
||||
|
||||
|
|
@ -92,13 +93,13 @@ def outer_sync():
|
|||
|
||||
return inner_async
|
||||
"""
|
||||
|
||||
|
||||
file_path = temp_dir / "test_file.py"
|
||||
file_path.write_text(nested_async)
|
||||
functions_found = find_all_functions_in_file(file_path)
|
||||
|
||||
|
||||
function_names = [fn.function_name for fn in functions_found[file_path]]
|
||||
|
||||
|
||||
assert "outer_async" in function_names
|
||||
assert "outer_sync" in function_names
|
||||
assert "inner_async" not in function_names
|
||||
|
|
@ -122,16 +123,16 @@ class MyClass:
|
|||
async def async_property(self):
|
||||
return await self.get_value()
|
||||
"""
|
||||
|
||||
|
||||
file_path = temp_dir / "test_file.py"
|
||||
file_path.write_text(async_decorators)
|
||||
functions_found = find_all_functions_in_file(file_path)
|
||||
|
||||
|
||||
function_names = [fn.function_name for fn in functions_found[file_path]]
|
||||
|
||||
|
||||
assert "async_static_method" in function_names
|
||||
assert "async_class_method" in function_names
|
||||
|
||||
|
||||
assert "async_property" not in function_names
|
||||
|
||||
|
||||
|
|
@ -151,13 +152,13 @@ async def regular_async_with_return():
|
|||
result = await compute()
|
||||
return result
|
||||
"""
|
||||
|
||||
|
||||
file_path = temp_dir / "test_file.py"
|
||||
file_path.write_text(async_generators)
|
||||
functions_found = find_all_functions_in_file(file_path)
|
||||
|
||||
|
||||
function_names = [fn.function_name for fn in functions_found[file_path]]
|
||||
|
||||
|
||||
assert "async_generator_with_return" in function_names
|
||||
assert "regular_async_with_return" in function_names
|
||||
assert "async_generator_no_return" not in function_names
|
||||
|
|
@ -183,23 +184,23 @@ class AsyncContainer:
|
|||
async def async_classmethod(cls):
|
||||
return "classmethod"
|
||||
"""
|
||||
|
||||
|
||||
file_path = temp_dir / "test_file.py"
|
||||
file_path.write_text(code)
|
||||
|
||||
|
||||
result = inspect_top_level_functions_or_methods(file_path, "top_level_async")
|
||||
assert result.is_top_level
|
||||
|
||||
|
||||
result = inspect_top_level_functions_or_methods(file_path, "async_method", class_name="AsyncContainer")
|
||||
assert result.is_top_level
|
||||
|
||||
|
||||
result = inspect_top_level_functions_or_methods(file_path, "nested_async", class_name="AsyncContainer")
|
||||
assert not result.is_top_level
|
||||
|
||||
|
||||
result = inspect_top_level_functions_or_methods(file_path, "async_static", class_name="AsyncContainer")
|
||||
assert result.is_top_level
|
||||
assert result.is_staticmethod
|
||||
|
||||
|
||||
result = inspect_top_level_functions_or_methods(file_path, "async_classmethod", class_name="AsyncContainer")
|
||||
assert result.is_top_level
|
||||
assert result.is_classmethod
|
||||
|
|
@ -224,17 +225,14 @@ class MixedClass:
|
|||
def sync_method(self):
|
||||
return self.operation()
|
||||
"""
|
||||
|
||||
|
||||
file_path = temp_dir / "test_file.py"
|
||||
file_path.write_text(mixed_code)
|
||||
|
||||
|
||||
test_config = TestConfig(
|
||||
tests_root="tests",
|
||||
project_root_path=".",
|
||||
test_framework="pytest",
|
||||
tests_project_rootdir=Path()
|
||||
tests_root="tests", project_root_path=".", test_framework="pytest", tests_project_rootdir=Path()
|
||||
)
|
||||
|
||||
|
||||
functions, functions_count, _ = get_functions_to_optimize(
|
||||
optimize_all=None,
|
||||
replay_test=None,
|
||||
|
|
@ -245,15 +243,15 @@ class MixedClass:
|
|||
project_root=file_path.parent,
|
||||
module_root=file_path.parent,
|
||||
)
|
||||
|
||||
|
||||
assert functions_count == 4
|
||||
|
||||
|
||||
function_names = [fn.function_name for fn in functions[file_path]]
|
||||
assert "async_func_one" in function_names
|
||||
assert "sync_func_one" in function_names
|
||||
assert "async_method" in function_names
|
||||
assert "sync_method" in function_names
|
||||
|
||||
|
||||
assert "async_func_two" not in function_names
|
||||
|
||||
|
||||
|
|
@ -277,17 +275,14 @@ class MixedClass:
|
|||
def sync_method(self):
|
||||
return self.operation()
|
||||
"""
|
||||
|
||||
|
||||
file_path = temp_dir / "test_file.py"
|
||||
file_path.write_text(mixed_code)
|
||||
|
||||
|
||||
test_config = TestConfig(
|
||||
tests_root="tests",
|
||||
project_root_path=".",
|
||||
test_framework="pytest",
|
||||
tests_project_rootdir=Path()
|
||||
tests_root="tests", project_root_path=".", test_framework="pytest", tests_project_rootdir=Path()
|
||||
)
|
||||
|
||||
|
||||
functions, functions_count, _ = get_functions_to_optimize(
|
||||
optimize_all=None,
|
||||
replay_test=None,
|
||||
|
|
@ -298,10 +293,10 @@ class MixedClass:
|
|||
project_root=file_path.parent,
|
||||
module_root=file_path.parent,
|
||||
)
|
||||
|
||||
|
||||
# Now async functions are always included, so we expect 4 functions (not 2)
|
||||
assert functions_count == 4
|
||||
|
||||
|
||||
function_names = [fn.function_name for fn in functions[file_path]]
|
||||
assert "sync_func_one" in function_names
|
||||
assert "sync_method" in function_names
|
||||
|
|
@ -327,13 +322,13 @@ async def module_level_async():
|
|||
return 3
|
||||
return LocalClass()
|
||||
"""
|
||||
|
||||
|
||||
file_path = temp_dir / "test_file.py"
|
||||
file_path.write_text(complex_structure)
|
||||
functions_found = find_all_functions_in_file(file_path)
|
||||
|
||||
|
||||
found_functions = functions_found[file_path]
|
||||
|
||||
|
||||
for fn in found_functions:
|
||||
if fn.function_name == "outer_method":
|
||||
assert len(fn.parents) == 1
|
||||
|
|
@ -345,4 +340,4 @@ async def module_level_async():
|
|||
assert fn.parents[1].name == "InnerClass"
|
||||
elif fn.function_name == "module_level_async":
|
||||
assert len(fn.parents) == 0
|
||||
assert fn.qualified_name == "module_level_async"
|
||||
assert fn.qualified_name == "module_level_async"
|
||||
|
|
|
|||
|
|
@ -7,11 +7,15 @@ from pathlib import Path
|
|||
|
||||
import pytest
|
||||
|
||||
from codeflash.code_utils.instrument_existing_tests import (
|
||||
add_async_decorator_to_function,
|
||||
inject_profiling_into_existing_test,
|
||||
)
|
||||
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
|
||||
from codeflash.models.models import CodePosition, FunctionParent, TestFile, TestFiles, TestingMode, TestType
|
||||
from codeflash.optimization.optimizer import Optimizer
|
||||
from codeflash.verification.instrument_codeflash_capture import instrument_codeflash_capture
|
||||
from codeflash.code_utils.instrument_existing_tests import add_async_decorator_to_function, inject_profiling_into_existing_test
|
||||
|
||||
|
||||
@pytest.mark.skipif(sys.platform == "win32", reason="pending support for asyncio on windows")
|
||||
def test_async_bubble_sort_behavior_results() -> None:
|
||||
|
|
@ -51,15 +55,16 @@ async def test_async_sort():
|
|||
func = FunctionToOptimize(function_name="async_sorter", parents=[], file_path=Path(fto_path), is_async=True)
|
||||
|
||||
# For async functions, instrument the source module directly with decorators
|
||||
source_success = add_async_decorator_to_function(
|
||||
fto_path, func, TestingMode.BEHAVIOR
|
||||
)
|
||||
source_success = add_async_decorator_to_function(fto_path, func, TestingMode.BEHAVIOR)
|
||||
|
||||
assert source_success
|
||||
|
||||
|
||||
# Verify the file was modified
|
||||
instrumented_source = fto_path.read_text("utf-8")
|
||||
assert '''import asyncio\nfrom typing import List, Union\n\nfrom codeflash.code_utils.codeflash_wrap_decorator import \\\n codeflash_behavior_async\n\n\n@codeflash_behavior_async\nasync def async_sorter(lst: List[Union[int, float]]) -> List[Union[int, float]]:\n """\n Async bubble sort implementation for testing.\n """\n print("codeflash stdout: Async sorting list")\n \n await asyncio.sleep(0.01)\n \n n = len(lst)\n for i in range(n):\n for j in range(0, n - i - 1):\n if lst[j] > lst[j + 1]:\n lst[j], lst[j + 1] = lst[j + 1], lst[j]\n \n result = lst.copy()\n print(f"result: {result}")\n return result\n\n\nclass AsyncBubbleSorter:\n """Class with async sorting method for testing."""\n \n async def sorter(self, lst: List[Union[int, float]]) -> List[Union[int, float]]:\n """\n Async bubble sort implementation within a class.\n """\n print("codeflash stdout: AsyncBubbleSorter.sorter() called")\n \n # Add some async delay\n await asyncio.sleep(0.005)\n \n n = len(lst)\n for i in range(n):\n for j in range(0, n - i - 1):\n if lst[j] > lst[j + 1]:\n lst[j], lst[j + 1] = lst[j + 1], lst[j]\n \n result = lst.copy()\n return result\n''' in instrumented_source
|
||||
assert (
|
||||
'''import asyncio\nfrom typing import List, Union\n\nfrom codeflash.code_utils.codeflash_wrap_decorator import \\\n codeflash_behavior_async\n\n\n@codeflash_behavior_async\nasync def async_sorter(lst: List[Union[int, float]]) -> List[Union[int, float]]:\n """\n Async bubble sort implementation for testing.\n """\n print("codeflash stdout: Async sorting list")\n \n await asyncio.sleep(0.01)\n \n n = len(lst)\n for i in range(n):\n for j in range(0, n - i - 1):\n if lst[j] > lst[j + 1]:\n lst[j], lst[j + 1] = lst[j + 1], lst[j]\n \n result = lst.copy()\n print(f"result: {result}")\n return result\n\n\nclass AsyncBubbleSorter:\n """Class with async sorting method for testing."""\n \n async def sorter(self, lst: List[Union[int, float]]) -> List[Union[int, float]]:\n """\n Async bubble sort implementation within a class.\n """\n print("codeflash stdout: AsyncBubbleSorter.sorter() called")\n \n # Add some async delay\n await asyncio.sleep(0.005)\n \n n = len(lst)\n for i in range(n):\n for j in range(0, n - i - 1):\n if lst[j] > lst[j + 1]:\n lst[j], lst[j + 1] = lst[j + 1], lst[j]\n \n result = lst.copy()\n return result\n'''
|
||||
in instrumented_source
|
||||
)
|
||||
|
||||
# Add codeflash capture
|
||||
instrument_codeflash_capture(func, {}, tests_root)
|
||||
|
|
@ -122,7 +127,6 @@ async def test_async_sort():
|
|||
expected_stdout = "codeflash stdout: Async sorting list\nresult: [0, 1, 2, 3, 4, 5]\n"
|
||||
assert expected_stdout == results_list[0].stdout
|
||||
|
||||
|
||||
assert results_list[1].id.function_getting_tested == "async_sorter"
|
||||
assert results_list[1].id.test_function_name == "test_async_sort"
|
||||
assert results_list[1].did_pass
|
||||
|
|
@ -178,12 +182,10 @@ async def test_async_class_sort():
|
|||
is_async=True,
|
||||
)
|
||||
|
||||
source_success = add_async_decorator_to_function(
|
||||
fto_path, func, TestingMode.BEHAVIOR
|
||||
)
|
||||
source_success = add_async_decorator_to_function(fto_path, func, TestingMode.BEHAVIOR)
|
||||
|
||||
assert source_success
|
||||
|
||||
|
||||
# Verify the file was modified
|
||||
instrumented_source = fto_path.read_text("utf-8")
|
||||
assert "@codeflash_behavior_async" in instrumented_source
|
||||
|
|
@ -233,17 +235,17 @@ async def test_async_class_sort():
|
|||
testing_time=0.1,
|
||||
)
|
||||
|
||||
|
||||
assert test_results is not None
|
||||
assert test_results.test_results is not None
|
||||
|
||||
results_list = test_results.test_results
|
||||
assert len(results_list) == 2, f"Expected 2 results but got {len(results_list)}: {[r.id.function_getting_tested for r in results_list]}"
|
||||
assert len(results_list) == 2, (
|
||||
f"Expected 2 results but got {len(results_list)}: {[r.id.function_getting_tested for r in results_list]}"
|
||||
)
|
||||
|
||||
init_result = results_list[0]
|
||||
sorter_result = results_list[1]
|
||||
|
||||
|
||||
assert sorter_result.id.function_getting_tested == "sorter"
|
||||
assert sorter_result.id.test_class_name is None
|
||||
assert sorter_result.id.test_function_name == "test_async_class_sort"
|
||||
|
|
@ -292,15 +294,16 @@ async def test_async_perf():
|
|||
func = FunctionToOptimize(function_name="async_sorter", parents=[], file_path=Path(fto_path), is_async=True)
|
||||
|
||||
# Instrument the source module with async performance decorators
|
||||
source_success = add_async_decorator_to_function(
|
||||
fto_path, func, TestingMode.PERFORMANCE
|
||||
)
|
||||
source_success = add_async_decorator_to_function(fto_path, func, TestingMode.PERFORMANCE)
|
||||
|
||||
assert source_success
|
||||
|
||||
|
||||
# Verify the file was modified
|
||||
instrumented_source = fto_path.read_text("utf-8")
|
||||
assert '''import asyncio\nfrom typing import List, Union\n\nfrom codeflash.code_utils.codeflash_wrap_decorator import \\\n codeflash_performance_async\n\n\n@codeflash_performance_async\nasync def async_sorter(lst: List[Union[int, float]]) -> List[Union[int, float]]:\n """\n Async bubble sort implementation for testing.\n """\n print("codeflash stdout: Async sorting list")\n \n await asyncio.sleep(0.01)\n \n n = len(lst)\n for i in range(n):\n for j in range(0, n - i - 1):\n if lst[j] > lst[j + 1]:\n lst[j], lst[j + 1] = lst[j + 1], lst[j]\n \n result = lst.copy()\n print(f"result: {result}")\n return result\n\n\nclass AsyncBubbleSorter:\n """Class with async sorting method for testing."""\n \n async def sorter(self, lst: List[Union[int, float]]) -> List[Union[int, float]]:\n """\n Async bubble sort implementation within a class.\n """\n print("codeflash stdout: AsyncBubbleSorter.sorter() called")\n \n # Add some async delay\n await asyncio.sleep(0.005)\n \n n = len(lst)\n for i in range(n):\n for j in range(0, n - i - 1):\n if lst[j] > lst[j + 1]:\n lst[j], lst[j + 1] = lst[j + 1], lst[j]\n \n result = lst.copy()\n return result\n''' == instrumented_source
|
||||
assert (
|
||||
instrumented_source
|
||||
== '''import asyncio\nfrom typing import List, Union\n\nfrom codeflash.code_utils.codeflash_wrap_decorator import \\\n codeflash_performance_async\n\n\n@codeflash_performance_async\nasync def async_sorter(lst: List[Union[int, float]]) -> List[Union[int, float]]:\n """\n Async bubble sort implementation for testing.\n """\n print("codeflash stdout: Async sorting list")\n \n await asyncio.sleep(0.01)\n \n n = len(lst)\n for i in range(n):\n for j in range(0, n - i - 1):\n if lst[j] > lst[j + 1]:\n lst[j], lst[j + 1] = lst[j + 1], lst[j]\n \n result = lst.copy()\n print(f"result: {result}")\n return result\n\n\nclass AsyncBubbleSorter:\n """Class with async sorting method for testing."""\n \n async def sorter(self, lst: List[Union[int, float]]) -> List[Union[int, float]]:\n """\n Async bubble sort implementation within a class.\n """\n print("codeflash stdout: AsyncBubbleSorter.sorter() called")\n \n # Add some async delay\n await asyncio.sleep(0.005)\n \n n = len(lst)\n for i in range(n):\n for j in range(0, n - i - 1):\n if lst[j] > lst[j + 1]:\n lst[j], lst[j + 1] = lst[j + 1], lst[j]\n \n result = lst.copy()\n return result\n'''
|
||||
)
|
||||
|
||||
instrument_codeflash_capture(func, {}, tests_root)
|
||||
|
||||
|
|
@ -358,7 +361,6 @@ async def test_async_perf():
|
|||
test_path.unlink()
|
||||
|
||||
|
||||
|
||||
@pytest.mark.skipif(sys.platform == "win32", reason="pending support for asyncio on windows")
|
||||
def test_async_function_error_handling() -> None:
|
||||
test_code = """import asyncio
|
||||
|
|
@ -371,8 +373,12 @@ async def test_async_error():
|
|||
with pytest.raises(ValueError, match="Test error"):
|
||||
await async_error_function([1, 2, 3])"""
|
||||
|
||||
test_path = (Path(__file__).parent.resolve() / "../code_to_optimize/tests/pytest/test_async_error_temp.py").resolve()
|
||||
test_path_perf = (Path(__file__).parent.resolve() / "../code_to_optimize/tests/pytest/test_async_error_perf_temp.py").resolve()
|
||||
test_path = (
|
||||
Path(__file__).parent.resolve() / "../code_to_optimize/tests/pytest/test_async_error_temp.py"
|
||||
).resolve()
|
||||
test_path_perf = (
|
||||
Path(__file__).parent.resolve() / "../code_to_optimize/tests/pytest/test_async_error_perf_temp.py"
|
||||
).resolve()
|
||||
fto_path = (Path(__file__).parent.resolve() / "../code_to_optimize/async_bubble_sort.py").resolve()
|
||||
original_code = fto_path.read_text("utf-8")
|
||||
|
||||
|
|
@ -384,27 +390,27 @@ async def async_error_function(lst):
|
|||
await asyncio.sleep(0.001) # Small delay
|
||||
raise ValueError("Test error")
|
||||
"""
|
||||
|
||||
|
||||
modified_code = original_code + error_func_code
|
||||
fto_path.write_text(modified_code, "utf-8")
|
||||
|
||||
|
||||
with test_path.open("w") as f:
|
||||
f.write(test_code)
|
||||
|
||||
tests_root = (Path(__file__).parent.resolve() / "../code_to_optimize/tests/pytest/").resolve()
|
||||
project_root_path = (Path(__file__).parent / "..").resolve()
|
||||
|
||||
func = FunctionToOptimize(function_name="async_error_function", parents=[], file_path=Path(fto_path), is_async=True)
|
||||
|
||||
source_success = add_async_decorator_to_function(
|
||||
fto_path, func, TestingMode.BEHAVIOR
|
||||
func = FunctionToOptimize(
|
||||
function_name="async_error_function", parents=[], file_path=Path(fto_path), is_async=True
|
||||
)
|
||||
|
||||
source_success = add_async_decorator_to_function(fto_path, func, TestingMode.BEHAVIOR)
|
||||
|
||||
assert source_success
|
||||
|
||||
|
||||
# Verify the file was modified
|
||||
instrumented_source = fto_path.read_text("utf-8")
|
||||
|
||||
|
||||
expected_instrumented_source = """import asyncio
|
||||
from typing import List, Union
|
||||
|
||||
|
|
@ -508,7 +514,7 @@ async def async_error_function(lst):
|
|||
assert test_results is not None
|
||||
assert test_results.test_results is not None
|
||||
assert len(test_results.test_results) >= 1
|
||||
|
||||
|
||||
result = test_results.test_results[0]
|
||||
assert result.id.function_getting_tested == "async_error_function"
|
||||
assert result.did_pass
|
||||
|
|
@ -539,8 +545,12 @@ async def test_async_multi():
|
|||
output2 = await async_sorter(input2)
|
||||
assert output2 == [7, 9]"""
|
||||
|
||||
test_path = (Path(__file__).parent.resolve() / "../code_to_optimize/tests/pytest/test_async_multi_temp.py").resolve()
|
||||
test_path_perf = (Path(__file__).parent.resolve() / "../code_to_optimize/tests/pytest/test_async_multi_perf_temp.py").resolve()
|
||||
test_path = (
|
||||
Path(__file__).parent.resolve() / "../code_to_optimize/tests/pytest/test_async_multi_temp.py"
|
||||
).resolve()
|
||||
test_path_perf = (
|
||||
Path(__file__).parent.resolve() / "../code_to_optimize/tests/pytest/test_async_multi_perf_temp.py"
|
||||
).resolve()
|
||||
fto_path = (Path(__file__).parent.resolve() / "../code_to_optimize/async_bubble_sort.py").resolve()
|
||||
original_code = fto_path.read_text("utf-8")
|
||||
|
||||
|
|
@ -553,9 +563,7 @@ async def test_async_multi():
|
|||
|
||||
func = FunctionToOptimize(function_name="async_sorter", parents=[], file_path=Path(fto_path), is_async=True)
|
||||
|
||||
source_success = add_async_decorator_to_function(
|
||||
fto_path, func, TestingMode.BEHAVIOR
|
||||
)
|
||||
source_success = add_async_decorator_to_function(fto_path, func, TestingMode.BEHAVIOR)
|
||||
|
||||
assert source_success
|
||||
instrument_codeflash_capture(func, {}, tests_root)
|
||||
|
|
@ -606,17 +614,17 @@ async def test_async_multi():
|
|||
assert test_results is not None
|
||||
assert test_results.test_results is not None
|
||||
assert len(test_results.test_results) >= 2
|
||||
|
||||
|
||||
results_list = test_results.test_results
|
||||
function_calls = [r for r in results_list if r.id.function_getting_tested == "async_sorter"]
|
||||
assert len(function_calls) == 2
|
||||
|
||||
|
||||
first_call = function_calls[0]
|
||||
second_call = function_calls[1]
|
||||
|
||||
|
||||
assert first_call.stdout == "codeflash stdout: Async sorting list\nresult: [3, 4, 5]\n"
|
||||
assert second_call.stdout == "codeflash stdout: Async sorting list\nresult: [7, 9]\n"
|
||||
|
||||
|
||||
assert first_call.did_pass
|
||||
assert second_call.did_pass
|
||||
assert first_call.runtime is None or first_call.runtime >= 0
|
||||
|
|
@ -655,7 +663,9 @@ async def test_async_edge_cases():
|
|||
assert result_sorted == [1, 2, 3, 4]"""
|
||||
|
||||
test_path = (Path(__file__).parent.resolve() / "../code_to_optimize/tests/pytest/test_async_edge_temp.py").resolve()
|
||||
test_path_perf = (Path(__file__).parent.resolve() / "../code_to_optimize/tests/pytest/test_async_edge_perf_temp.py").resolve()
|
||||
test_path_perf = (
|
||||
Path(__file__).parent.resolve() / "../code_to_optimize/tests/pytest/test_async_edge_perf_temp.py"
|
||||
).resolve()
|
||||
fto_path = (Path(__file__).parent.resolve() / "../code_to_optimize/async_bubble_sort.py").resolve()
|
||||
original_code = fto_path.read_text("utf-8")
|
||||
|
||||
|
|
@ -668,9 +678,7 @@ async def test_async_edge_cases():
|
|||
|
||||
func = FunctionToOptimize(function_name="async_sorter", parents=[], file_path=Path(fto_path), is_async=True)
|
||||
|
||||
source_success = add_async_decorator_to_function(
|
||||
fto_path, func, TestingMode.BEHAVIOR
|
||||
)
|
||||
source_success = add_async_decorator_to_function(fto_path, func, TestingMode.BEHAVIOR)
|
||||
|
||||
assert source_success
|
||||
instrument_codeflash_capture(func, {}, tests_root)
|
||||
|
|
@ -721,20 +729,20 @@ async def test_async_edge_cases():
|
|||
assert test_results is not None
|
||||
assert test_results.test_results is not None
|
||||
assert len(test_results.test_results) >= 3 # 3 function calls for edge cases
|
||||
|
||||
|
||||
results_list = test_results.test_results
|
||||
function_calls = [r for r in results_list if r.id.function_getting_tested == "async_sorter"]
|
||||
assert len(function_calls) == 3
|
||||
|
||||
|
||||
# Verify all calls passed
|
||||
for call in function_calls:
|
||||
assert call.did_pass
|
||||
assert call.runtime is None or call.runtime >= 0
|
||||
|
||||
|
||||
empty_call = function_calls[0]
|
||||
single_call = function_calls[1]
|
||||
sorted_call = function_calls[2]
|
||||
|
||||
|
||||
assert empty_call.stdout == "codeflash stdout: Async sorting list\nresult: []\n"
|
||||
assert single_call.stdout == "codeflash stdout: Async sorting list\nresult: [42]\n"
|
||||
assert sorted_call.stdout == "codeflash stdout: Async sorting list\nresult: [1, 2, 3, 4]\n"
|
||||
|
|
@ -761,7 +769,7 @@ def test_sync_function_behavior_in_async_test_environment() -> None:
|
|||
print(f"result: {result}")
|
||||
return result
|
||||
"""
|
||||
|
||||
|
||||
test_code = """from code_to_optimize.sync_bubble_sort import sync_sorter
|
||||
|
||||
|
||||
|
|
@ -774,26 +782,32 @@ def test_sync_sort():
|
|||
output = sync_sorter(input)
|
||||
assert output == [0.0, 1.0, 2.0, 3.0, 4.0, 5.0]"""
|
||||
|
||||
test_path = (Path(__file__).parent.resolve() / "../code_to_optimize/tests/pytest/test_sync_in_async_temp.py").resolve()
|
||||
test_path_perf = (Path(__file__).parent.resolve() / "../code_to_optimize/tests/pytest/test_sync_in_async_perf_temp.py").resolve()
|
||||
test_path = (
|
||||
Path(__file__).parent.resolve() / "../code_to_optimize/tests/pytest/test_sync_in_async_temp.py"
|
||||
).resolve()
|
||||
test_path_perf = (
|
||||
Path(__file__).parent.resolve() / "../code_to_optimize/tests/pytest/test_sync_in_async_perf_temp.py"
|
||||
).resolve()
|
||||
sync_fto_path = (Path(__file__).parent.resolve() / "../code_to_optimize/sync_bubble_sort.py").resolve()
|
||||
|
||||
|
||||
try:
|
||||
with sync_fto_path.open("w") as f:
|
||||
f.write(sync_sorter_code)
|
||||
|
||||
|
||||
with test_path.open("w") as f:
|
||||
f.write(test_code)
|
||||
|
||||
tests_root = (Path(__file__).parent.resolve() / "../code_to_optimize/tests/pytest/").resolve()
|
||||
project_root_path = (Path(__file__).parent / "..").resolve()
|
||||
|
||||
func = FunctionToOptimize(function_name="sync_sorter", parents=[], file_path=Path(sync_fto_path), is_async=False)
|
||||
func = FunctionToOptimize(
|
||||
function_name="sync_sorter", parents=[], file_path=Path(sync_fto_path), is_async=False
|
||||
)
|
||||
|
||||
original_cwd = os.getcwd()
|
||||
run_cwd = project_root_path
|
||||
os.chdir(run_cwd)
|
||||
|
||||
|
||||
success, instrumented_test = inject_profiling_into_existing_test(
|
||||
test_path,
|
||||
[CodePosition(6, 13), CodePosition(10, 13)], # Lines where sync_sorter is called
|
||||
|
|
@ -802,10 +816,10 @@ def test_sync_sort():
|
|||
mode=TestingMode.BEHAVIOR,
|
||||
)
|
||||
os.chdir(original_cwd)
|
||||
|
||||
|
||||
assert success
|
||||
assert instrumented_test is not None
|
||||
|
||||
|
||||
with test_path.open("w") as f:
|
||||
f.write(instrumented_test)
|
||||
|
||||
|
|
@ -856,7 +870,7 @@ def test_sync_sort():
|
|||
|
||||
assert test_results is not None
|
||||
assert test_results.test_results is not None
|
||||
|
||||
|
||||
results_list = test_results.test_results
|
||||
assert results_list[0].id.function_getting_tested == "sync_sorter"
|
||||
assert results_list[0].id.iteration_id == "1_0"
|
||||
|
|
@ -935,7 +949,7 @@ async def async_merge_sort(lst: List[Union[int, float]]) -> List[Union[int, floa
|
|||
return result
|
||||
|
||||
"""
|
||||
|
||||
|
||||
test_code = """import asyncio
|
||||
import pytest
|
||||
from code_to_optimize.mixed_sort import sync_quick_sort, async_merge_sort
|
||||
|
|
@ -954,27 +968,29 @@ async def test_mixed_sorting():
|
|||
assert async_output == [2, 3, 5, 6, 9]"""
|
||||
|
||||
test_path = (Path(__file__).parent.resolve() / "../code_to_optimize/tests/pytest/test_mixed_sort_temp.py").resolve()
|
||||
test_path_perf = (Path(__file__).parent.resolve() / "../code_to_optimize/tests/pytest/test_mixed_sort_perf_temp.py").resolve()
|
||||
test_path_perf = (
|
||||
Path(__file__).parent.resolve() / "../code_to_optimize/tests/pytest/test_mixed_sort_perf_temp.py"
|
||||
).resolve()
|
||||
mixed_fto_path = (Path(__file__).parent.resolve() / "../code_to_optimize/mixed_sort.py").resolve()
|
||||
|
||||
|
||||
try:
|
||||
with mixed_fto_path.open("w") as f:
|
||||
f.write(mixed_module_code)
|
||||
|
||||
|
||||
with test_path.open("w") as f:
|
||||
f.write(test_code)
|
||||
|
||||
tests_root = (Path(__file__).parent.resolve() / "../code_to_optimize/tests/pytest/").resolve()
|
||||
project_root_path = (Path(__file__).parent / "..").resolve()
|
||||
|
||||
async_func = FunctionToOptimize(function_name="async_merge_sort", parents=[], file_path=Path(mixed_fto_path), is_async=True)
|
||||
|
||||
source_success = add_async_decorator_to_function(
|
||||
mixed_fto_path, async_func, TestingMode.BEHAVIOR
|
||||
async_func = FunctionToOptimize(
|
||||
function_name="async_merge_sort", parents=[], file_path=Path(mixed_fto_path), is_async=True
|
||||
)
|
||||
|
||||
source_success = add_async_decorator_to_function(mixed_fto_path, async_func, TestingMode.BEHAVIOR)
|
||||
|
||||
assert source_success
|
||||
|
||||
|
||||
# Verify the file was modified
|
||||
instrumented_source = mixed_fto_path.read_text("utf-8")
|
||||
assert "@codeflash_behavior_async" in instrumented_source
|
||||
|
|
@ -1027,11 +1043,11 @@ async def test_mixed_sorting():
|
|||
|
||||
assert test_results is not None
|
||||
assert test_results.test_results is not None
|
||||
|
||||
|
||||
results_list = test_results.test_results
|
||||
async_calls = [r for r in results_list if r.id.function_getting_tested == "async_merge_sort"]
|
||||
assert len(async_calls) >= 1
|
||||
|
||||
|
||||
for call in async_calls:
|
||||
assert call.did_pass
|
||||
assert call.runtime is None or call.runtime >= 0
|
||||
|
|
@ -1043,4 +1059,4 @@ async def test_mixed_sorting():
|
|||
if test_path.exists():
|
||||
test_path.unlink()
|
||||
if test_path_perf.exists():
|
||||
test_path_perf.unlink()
|
||||
test_path_perf.unlink()
|
||||
|
|
|
|||
|
|
@ -4,22 +4,17 @@ import asyncio
|
|||
import os
|
||||
import sqlite3
|
||||
import sys
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
import dill as pickle
|
||||
import pytest
|
||||
|
||||
from codeflash.code_utils.codeflash_wrap_decorator import (
|
||||
codeflash_behavior_async,
|
||||
codeflash_performance_async,
|
||||
)
|
||||
from codeflash.code_utils.codeflash_wrap_decorator import codeflash_behavior_async, codeflash_performance_async
|
||||
from codeflash.verification.codeflash_capture import VerificationType
|
||||
|
||||
|
||||
@pytest.mark.skipif(sys.platform == "win32", reason="pending support for asyncio on windows")
|
||||
class TestAsyncWrapperSQLiteValidation:
|
||||
|
||||
@pytest.fixture
|
||||
def test_env_setup(self, request):
|
||||
original_env = {}
|
||||
|
|
@ -31,13 +26,13 @@ class TestAsyncWrapperSQLiteValidation:
|
|||
"CODEFLASH_TEST_FUNCTION": request.node.name,
|
||||
"CODEFLASH_CURRENT_LINE_ID": "test_unit",
|
||||
}
|
||||
|
||||
|
||||
for key, value in test_env.items():
|
||||
original_env[key] = os.environ.get(key)
|
||||
os.environ[key] = value
|
||||
|
||||
|
||||
yield test_env
|
||||
|
||||
|
||||
for key, original_value in original_env.items():
|
||||
if original_value is None:
|
||||
os.environ.pop(key, None)
|
||||
|
|
@ -48,45 +43,54 @@ class TestAsyncWrapperSQLiteValidation:
|
|||
def temp_db_path(self, test_env_setup):
|
||||
iteration = test_env_setup["CODEFLASH_TEST_ITERATION"]
|
||||
from codeflash.code_utils.codeflash_wrap_decorator import get_run_tmp_file
|
||||
|
||||
db_path = get_run_tmp_file(Path(f"test_return_values_{iteration}.sqlite"))
|
||||
|
||||
|
||||
yield db_path
|
||||
|
||||
|
||||
if db_path.exists():
|
||||
db_path.unlink()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_behavior_async_basic_function(self, test_env_setup, temp_db_path):
|
||||
|
||||
@codeflash_behavior_async
|
||||
async def simple_async_add(a: int, b: int) -> int:
|
||||
await asyncio.sleep(0.001)
|
||||
return a + b
|
||||
|
||||
os.environ['CODEFLASH_CURRENT_LINE_ID'] = 'simple_async_add_59'
|
||||
os.environ["CODEFLASH_CURRENT_LINE_ID"] = "simple_async_add_59"
|
||||
result = await simple_async_add(5, 3)
|
||||
|
||||
|
||||
assert result == 8
|
||||
|
||||
|
||||
assert temp_db_path.exists()
|
||||
|
||||
|
||||
con = sqlite3.connect(temp_db_path)
|
||||
cur = con.cursor()
|
||||
|
||||
|
||||
cur.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='test_results'")
|
||||
assert cur.fetchone() is not None
|
||||
|
||||
|
||||
cur.execute("SELECT * FROM test_results")
|
||||
rows = cur.fetchall()
|
||||
|
||||
|
||||
assert len(rows) == 1
|
||||
row = rows[0]
|
||||
|
||||
(test_module_path, test_class_name, test_function_name, function_getting_tested,
|
||||
loop_index, iteration_id, runtime, return_value_blob, verification_type) = row
|
||||
|
||||
|
||||
(
|
||||
test_module_path,
|
||||
test_class_name,
|
||||
test_function_name,
|
||||
function_getting_tested,
|
||||
loop_index,
|
||||
iteration_id,
|
||||
runtime,
|
||||
return_value_blob,
|
||||
verification_type,
|
||||
) = row
|
||||
|
||||
assert test_module_path == __name__
|
||||
assert test_class_name == "TestAsyncWrapperSQLiteValidation"
|
||||
assert test_class_name == "TestAsyncWrapperSQLiteValidation"
|
||||
assert test_function_name == "test_behavior_async_basic_function"
|
||||
assert function_getting_tested == "simple_async_add"
|
||||
assert loop_index == 1
|
||||
|
|
@ -94,19 +98,18 @@ class TestAsyncWrapperSQLiteValidation:
|
|||
assert iteration_id.startswith("simple_async_add_") and iteration_id.endswith("_0")
|
||||
assert runtime > 0
|
||||
assert verification_type == VerificationType.FUNCTION_CALL.value
|
||||
|
||||
|
||||
unpickled_data = pickle.loads(return_value_blob)
|
||||
args, kwargs, return_val = unpickled_data
|
||||
|
||||
|
||||
assert args == (5, 3)
|
||||
assert kwargs == {}
|
||||
assert return_val == 8
|
||||
|
||||
|
||||
con.close()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_behavior_async_exception_handling(self, test_env_setup, temp_db_path):
|
||||
|
||||
@codeflash_behavior_async
|
||||
async def async_divide(a: int, b: int) -> float:
|
||||
await asyncio.sleep(0.001)
|
||||
|
|
@ -116,35 +119,35 @@ class TestAsyncWrapperSQLiteValidation:
|
|||
|
||||
result = await async_divide(10, 2)
|
||||
assert result == 5.0
|
||||
|
||||
|
||||
with pytest.raises(ValueError, match="Cannot divide by zero"):
|
||||
await async_divide(10, 0)
|
||||
|
||||
|
||||
con = sqlite3.connect(temp_db_path)
|
||||
cur = con.cursor()
|
||||
cur.execute("SELECT * FROM test_results ORDER BY iteration_id")
|
||||
rows = cur.fetchall()
|
||||
|
||||
|
||||
assert len(rows) == 2
|
||||
|
||||
|
||||
success_row = rows[0]
|
||||
success_data = pickle.loads(success_row[7]) # return_value_blob
|
||||
args, kwargs, return_val = success_data
|
||||
assert args == (10, 2)
|
||||
assert return_val == 5.0
|
||||
|
||||
|
||||
# Check exception record
|
||||
exception_row = rows[1]
|
||||
exception_data = pickle.loads(exception_row[7]) # return_value_blob
|
||||
assert isinstance(exception_data, ValueError)
|
||||
assert str(exception_data) == "Cannot divide by zero"
|
||||
|
||||
|
||||
con.close()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_performance_async_no_database_storage(self, test_env_setup, temp_db_path, capsys):
|
||||
"""Test performance async decorator doesn't store to database."""
|
||||
|
||||
|
||||
@codeflash_performance_async
|
||||
async def async_multiply(a: int, b: int) -> int:
|
||||
"""Async function for performance testing."""
|
||||
|
|
@ -152,27 +155,26 @@ class TestAsyncWrapperSQLiteValidation:
|
|||
return a * b
|
||||
|
||||
result = await async_multiply(4, 7)
|
||||
|
||||
|
||||
assert result == 28
|
||||
|
||||
|
||||
assert not temp_db_path.exists()
|
||||
|
||||
|
||||
captured = capsys.readouterr()
|
||||
output_lines = captured.out.strip().split('\n')
|
||||
|
||||
output_lines = captured.out.strip().split("\n")
|
||||
|
||||
assert len([line for line in output_lines if "!$######" in line]) == 1
|
||||
assert len([line for line in output_lines if "!######" in line and "######!" in line]) == 1
|
||||
|
||||
|
||||
closing_tag = [line for line in output_lines if "!######" in line and "######!" in line][0]
|
||||
assert "async_multiply" in closing_tag
|
||||
|
||||
|
||||
timing_part = closing_tag.split(":")[-1].replace("######!", "")
|
||||
timing_value = int(timing_part)
|
||||
assert timing_value > 0 # Should have positive timing
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_multiple_calls_indexing(self, test_env_setup, temp_db_path):
|
||||
|
||||
@codeflash_behavior_async
|
||||
async def async_increment(value: int) -> int:
|
||||
await asyncio.sleep(0.001)
|
||||
|
|
@ -183,105 +185,86 @@ class TestAsyncWrapperSQLiteValidation:
|
|||
for i in range(3):
|
||||
result = await async_increment(i)
|
||||
results.append(result)
|
||||
|
||||
|
||||
assert results == [1, 2, 3]
|
||||
|
||||
|
||||
con = sqlite3.connect(temp_db_path)
|
||||
cur = con.cursor()
|
||||
cur.execute("SELECT iteration_id, return_value FROM test_results ORDER BY iteration_id")
|
||||
rows = cur.fetchall()
|
||||
|
||||
|
||||
assert len(rows) == 3
|
||||
|
||||
|
||||
actual_ids = [row[0] for row in rows]
|
||||
assert len(actual_ids) == 3
|
||||
|
||||
base_pattern = actual_ids[0].rsplit('_', 1)[0] # e.g., "async_increment_199"
|
||||
|
||||
base_pattern = actual_ids[0].rsplit("_", 1)[0] # e.g., "async_increment_199"
|
||||
expected_pattern = [f"{base_pattern}_{i}" for i in range(3)]
|
||||
assert actual_ids == expected_pattern
|
||||
|
||||
|
||||
for i, (_, return_value_blob) in enumerate(rows):
|
||||
args, kwargs, return_val = pickle.loads(return_value_blob)
|
||||
assert args == (i,)
|
||||
assert return_val == i + 1
|
||||
|
||||
|
||||
con.close()
|
||||
|
||||
@pytest.mark.asyncio
async def test_complex_async_function_with_kwargs(self, test_env_setup, temp_db_path):

@codeflash_behavior_async
async def complex_async_func(
pos_arg: str,
*args: int,
keyword_arg: str = "default",
**kwargs: str
) -> dict:
async def complex_async_func(pos_arg: str, *args: int, keyword_arg: str = "default", **kwargs: str) -> dict:
await asyncio.sleep(0.001)
return {
"pos_arg": pos_arg,
"args": args,
"keyword_arg": keyword_arg,
"kwargs": kwargs,
}
return {"pos_arg": pos_arg, "args": args, "keyword_arg": keyword_arg, "kwargs": kwargs}

result = await complex_async_func("hello", 1, 2, 3, keyword_arg="custom", extra1="value1", extra2="value2")

result = await complex_async_func(
"hello",
1, 2, 3,
keyword_arg="custom",
extra1="value1",
extra2="value2"
)

expected_result = {
"pos_arg": "hello",
"args": (1, 2, 3),
"keyword_arg": "custom",
"kwargs": {"extra1": "value1", "extra2": "value2"}
"kwargs": {"extra1": "value1", "extra2": "value2"},
}

assert result == expected_result

con = sqlite3.connect(temp_db_path)
cur = con.cursor()
cur.execute("SELECT return_value FROM test_results")
row = cur.fetchone()

stored_args, stored_kwargs, stored_result = pickle.loads(row[0])

assert stored_args == ("hello", 1, 2, 3)
assert stored_kwargs == {"keyword_arg": "custom", "extra1": "value1", "extra2": "value2"}
assert stored_result == expected_result

con.close()

@pytest.mark.asyncio
async def test_database_schema_validation(self, test_env_setup, temp_db_path):

@codeflash_behavior_async
async def schema_test_func() -> str:
return "schema_test"

await schema_test_func()

con = sqlite3.connect(temp_db_path)
cur = con.cursor()

cur.execute("PRAGMA table_info(test_results)")
columns = cur.fetchall()

expected_columns = [
(0, 'test_module_path', 'TEXT', 0, None, 0),
(1, 'test_class_name', 'TEXT', 0, None, 0),
(2, 'test_function_name', 'TEXT', 0, None, 0),
(3, 'function_getting_tested', 'TEXT', 0, None, 0),
(4, 'loop_index', 'INTEGER', 0, None, 0),
(5, 'iteration_id', 'TEXT', 0, None, 0),
(6, 'runtime', 'INTEGER', 0, None, 0),
(7, 'return_value', 'BLOB', 0, None, 0),
(8, 'verification_type', 'TEXT', 0, None, 0)
(0, "test_module_path", "TEXT", 0, None, 0),
(1, "test_class_name", "TEXT", 0, None, 0),
(2, "test_function_name", "TEXT", 0, None, 0),
(3, "function_getting_tested", "TEXT", 0, None, 0),
(4, "loop_index", "INTEGER", 0, None, 0),
(5, "iteration_id", "TEXT", 0, None, 0),
(6, "runtime", "INTEGER", 0, None, 0),
(7, "return_value", "BLOB", 0, None, 0),
(8, "verification_type", "TEXT", 0, None, 0),
]

assert columns == expected_columns
con.close()

@ -3969,7 +3969,7 @@ def test_dependency_classes_kept_in_read_writable_context(tmp_path: Path) -> Non
as types or in match statements, those classes are included in the optimization
context, even though they don't contain any target functions.
"""
code = '''
code = """
import dataclasses
import enum
import typing as t

@ -4013,20 +4013,13 @@ def reify_channel_message(data: dict) -> MessageIn:
return MessageInBeginExfiltration()
case _:
raise ValueError(f"Unknown message kind: '{kind}'")
'''
"""
code_path = tmp_path / "message.py"
code_path.write_text(code, encoding="utf-8")

func_to_optimize = FunctionToOptimize(
function_name="reify_channel_message",
file_path=code_path,
parents=[],
)
func_to_optimize = FunctionToOptimize(function_name="reify_channel_message", file_path=code_path, parents=[])

code_ctx = get_code_optimization_context(
function_to_optimize=func_to_optimize,
project_root_path=tmp_path,
)
code_ctx = get_code_optimization_context(function_to_optimize=func_to_optimize, project_root_path=tmp_path)

expected_read_writable = """
```python:message.py

@ -4098,10 +4091,7 @@ class MyCustomDict(UserDict):
parents=[FunctionParent(name="MyCustomDict", type="ClassDef")],
)

code_ctx = get_code_optimization_context(
function_to_optimize=func_to_optimize,
project_root_path=tmp_path,
)
code_ctx = get_code_optimization_context(function_to_optimize=func_to_optimize, project_root_path=tmp_path)

# The testgen context should include the UserDict __init__ method
testgen_context = code_ctx.testgen_context.markdown

@ -4146,9 +4136,7 @@ def second_helper():
file_path.write_text(code, encoding="utf-8")

func_to_optimize = FunctionToOptimize(
function_name="target_method",
file_path=file_path,
parents=[FunctionParent(name="MyClass", type="ClassDef")],
function_name="target_method", file_path=file_path, parents=[FunctionParent(name="MyClass", type="ClassDef")]
)

# Use a small optim_token_limit that allows read-writable but not read-only

@ -4203,11 +4191,7 @@ def target_function(obj: TypeClass) -> int:
main_path = package_dir / "main.py"
main_path.write_text(main_code, encoding="utf-8")

func_to_optimize = FunctionToOptimize(
function_name="target_function",
file_path=main_path,
parents=[],
)
func_to_optimize = FunctionToOptimize(function_name="target_function", file_path=main_path, parents=[])

# Use a testgen_token_limit that:
# - Is exceeded by full context with imported class (~1500 tokens)

@ -4251,11 +4235,7 @@ def target_function():
file_path = tmp_path / "test_code.py"
file_path.write_text(code, encoding="utf-8")

func_to_optimize = FunctionToOptimize(
function_name="target_function",
file_path=file_path,
parents=[],
)
func_to_optimize = FunctionToOptimize(function_name="target_function", file_path=file_path, parents=[])

# Use a very small testgen_token_limit that cannot fit even the base function
with pytest.raises(ValueError, match="Testgen code context has exceeded token limit"):

@ -4383,15 +4363,10 @@ class MyClass:
file_path.write_text(code, encoding="utf-8")

func_to_optimize = FunctionToOptimize(
function_name="target_method",
file_path=file_path,
parents=[FunctionParent(name="MyClass", type="ClassDef")],
function_name="target_method", file_path=file_path, parents=[FunctionParent(name="MyClass", type="ClassDef")]
)

code_ctx = get_code_optimization_context(
function_to_optimize=func_to_optimize,
project_root_path=tmp_path,
)
code_ctx = get_code_optimization_context(function_to_optimize=func_to_optimize, project_root_path=tmp_path)

# CONFIG_VALUE should be in read-writable context since it's used by __init__
read_writable = code_ctx.read_writable_code.markdown

@ -4637,15 +4612,10 @@ class MyClass:
file_path.write_text(code, encoding="utf-8")

func_to_optimize = FunctionToOptimize(
function_name="target_method",
file_path=file_path,
parents=[FunctionParent(name="MyClass", type="ClassDef")],
function_name="target_method", file_path=file_path, parents=[FunctionParent(name="MyClass", type="ClassDef")]
)

code_ctx = get_code_optimization_context(
function_to_optimize=func_to_optimize,
project_root_path=tmp_path,
)
code_ctx = get_code_optimization_context(function_to_optimize=func_to_optimize, project_root_path=tmp_path)

# counter should be in context since __init__ uses it
read_writable = code_ctx.read_writable_code.markdown

@ -5,97 +5,97 @@ from codeflash.code_utils.code_extractor import add_needed_imports_from_module


def test_add_needed_imports_with_none_aliases():
source_code = '''
source_code = """
import json
from typing import Dict as MyDict, Optional
from collections import defaultdict
'''

target_code = '''
"""

target_code = """
def target_function():
pass
'''

expected_output = '''
"""

expected_output = """
def target_function():
pass
'''

"""

with tempfile.TemporaryDirectory() as temp_dir:
temp_path = Path(temp_dir)
src_path = temp_path / "source.py"
dst_path = temp_path / "target.py"

src_path.write_text(source_code)
dst_path.write_text(target_code)

result = add_needed_imports_from_module(
src_module_code=source_code,
dst_module_code=target_code,
src_path=src_path,
dst_path=dst_path,
project_root=temp_path
project_root=temp_path,
)

assert result.strip() == expected_output.strip()

def test_add_needed_imports_complex_aliases():
|
||||
source_code = '''
|
||||
source_code = """
|
||||
import os
|
||||
import sys as system
|
||||
from typing import Dict, List as MyList, Optional as Opt
|
||||
from collections import defaultdict as dd, Counter
|
||||
from pathlib import Path
|
||||
'''
|
||||
|
||||
target_code = '''
|
||||
"""
|
||||
|
||||
target_code = """
|
||||
def my_function():
|
||||
return "test"
|
||||
'''
|
||||
|
||||
expected_output = '''
|
||||
"""
|
||||
|
||||
expected_output = """
|
||||
def my_function():
|
||||
return "test"
|
||||
'''
|
||||
|
||||
"""
|
||||
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
temp_path = Path(temp_dir)
|
||||
src_path = temp_path / "source.py"
|
||||
dst_path = temp_path / "target.py"
|
||||
|
||||
|
||||
src_path.write_text(source_code)
|
||||
dst_path.write_text(target_code)
|
||||
|
||||
|
||||
result = add_needed_imports_from_module(
|
||||
src_module_code=source_code,
|
||||
dst_module_code=target_code,
|
||||
src_path=src_path,
|
||||
dst_path=dst_path,
|
||||
project_root=temp_path
|
||||
project_root=temp_path,
|
||||
)
|
||||
|
||||
|
||||
assert result.strip() == expected_output.strip()
|
||||
|
||||
|
||||
def test_add_needed_imports_with_usage():
|
||||
source_code = '''
|
||||
source_code = """
|
||||
import json
|
||||
from typing import Dict as MyDict, Optional
|
||||
from collections import defaultdict
|
||||
|
||||
'''
|
||||
|
||||
target_code = '''
|
||||
"""
|
||||
|
||||
target_code = """
|
||||
def target_function():
|
||||
data = json.loads('{"key": "value"}')
|
||||
my_dict: MyDict[str, str] = {}
|
||||
opt_value: Optional[str] = None
|
||||
dd = defaultdict(list)
|
||||
return data, my_dict, opt_value, dd
|
||||
'''
|
||||
|
||||
expected_output = '''import json
|
||||
"""
|
||||
|
||||
expected_output = """import json
|
||||
from typing import Dict as MyDict, Optional
|
||||
from collections import defaultdict
|
||||
|
||||
|
|
@ -105,30 +105,30 @@ def target_function():
|
|||
opt_value: Optional[str] = None
|
||||
dd = defaultdict(list)
|
||||
return data, my_dict, opt_value, dd
|
||||
'''
|
||||
|
||||
"""
|
||||
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
temp_path = Path(temp_dir)
|
||||
src_path = temp_path / "source.py"
|
||||
dst_path = temp_path / "target.py"
|
||||
|
||||
|
||||
src_path.write_text(source_code)
|
||||
dst_path.write_text(target_code)
|
||||
|
||||
|
||||
result = add_needed_imports_from_module(
|
||||
src_module_code=source_code,
|
||||
dst_module_code=target_code,
|
||||
src_path=src_path,
|
||||
dst_path=dst_path,
|
||||
project_root=temp_path
|
||||
project_root=temp_path,
|
||||
)
|
||||
|
||||
|
||||
# Assert exact expected output
|
||||
assert result.strip() == expected_output.strip()
|
||||
|
||||
|
||||
def test_litellm_router_style_imports():
|
||||
source_code = '''
|
||||
source_code = """
|
||||
import asyncio
|
||||
import copy
|
||||
import json
|
||||
|
|
@ -136,92 +136,92 @@ from collections import defaultdict
|
|||
from typing import Dict, List, Optional, Union
|
||||
from litellm.types.utils import ModelInfo
|
||||
from litellm.types.utils import ModelInfo as ModelMapInfo
|
||||
'''
|
||||
|
||||
"""
|
||||
|
||||
target_code = '''
|
||||
def target_function():
|
||||
"""Target function for testing."""
|
||||
pass
|
||||
'''
|
||||
|
||||
|
||||
expected_output = '''
|
||||
def target_function():
|
||||
"""Target function for testing."""
|
||||
pass
|
||||
'''
|
||||
|
||||
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
temp_path = Path(temp_dir)
|
||||
src_path = temp_path / "complex_source.py"
|
||||
dst_path = temp_path / "target.py"
|
||||
|
||||
|
||||
src_path.write_text(source_code)
|
||||
dst_path.write_text(target_code)
|
||||
|
||||
|
||||
result = add_needed_imports_from_module(
|
||||
src_module_code=source_code,
|
||||
dst_module_code=target_code,
|
||||
src_path=src_path,
|
||||
dst_path=dst_path,
|
||||
project_root=temp_path
|
||||
project_root=temp_path,
|
||||
)
|
||||
|
||||
|
||||
assert result.strip() == expected_output.strip()
|
||||
|
||||
|
||||
def test_edge_case_none_values_in_alias_pairs():
|
||||
source_code = '''
|
||||
source_code = """
|
||||
from typing import Dict as MyDict, List, Optional as Opt
|
||||
from collections import defaultdict, Counter as cnt
|
||||
from pathlib import Path
|
||||
'''
|
||||
|
||||
target_code = '''
|
||||
"""
|
||||
|
||||
target_code = """
|
||||
def my_test_function():
|
||||
return "test"
|
||||
'''
|
||||
|
||||
expected_output = '''
|
||||
"""
|
||||
|
||||
expected_output = """
|
||||
def my_test_function():
|
||||
return "test"
|
||||
'''
|
||||
|
||||
"""
|
||||
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
temp_path = Path(temp_dir)
|
||||
src_path = temp_path / "edge_case_source.py"
|
||||
dst_path = temp_path / "target.py"
|
||||
|
||||
|
||||
src_path.write_text(source_code)
|
||||
dst_path.write_text(target_code)
|
||||
|
||||
|
||||
result = add_needed_imports_from_module(
|
||||
src_module_code=source_code,
|
||||
dst_module_code=target_code,
|
||||
src_path=src_path,
|
||||
dst_path=dst_path,
|
||||
project_root=temp_path
|
||||
project_root=temp_path,
|
||||
)
|
||||
|
||||
|
||||
assert result.strip() == expected_output.strip()
|
||||
|
||||
|
||||
def test_partial_import_usage():
|
||||
source_code = '''
|
||||
source_code = """
|
||||
import os
|
||||
import sys
|
||||
from typing import Dict, List, Optional
|
||||
from collections import defaultdict, Counter
|
||||
'''
|
||||
|
||||
target_code = '''
|
||||
"""
|
||||
|
||||
target_code = """
|
||||
def use_some_imports():
|
||||
path = os.path.join("a", "b")
|
||||
my_dict: Dict[str, int] = {}
|
||||
counter = Counter([1, 2, 3])
|
||||
return path, my_dict, counter
|
||||
'''
|
||||
|
||||
expected_output = '''import os
|
||||
"""
|
||||
|
||||
expected_output = """import os
|
||||
from collections import Counter
|
||||
from typing import Dict
|
||||
|
||||
|
|
@ -230,42 +230,42 @@ def use_some_imports():
|
|||
my_dict: Dict[str, int] = {}
|
||||
counter = Counter([1, 2, 3])
|
||||
return path, my_dict, counter
|
||||
'''
|
||||
|
||||
"""
|
||||
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
temp_path = Path(temp_dir)
|
||||
src_path = temp_path / "source.py"
|
||||
dst_path = temp_path / "target.py"
|
||||
|
||||
|
||||
src_path.write_text(source_code)
|
||||
dst_path.write_text(target_code)
|
||||
|
||||
|
||||
result = add_needed_imports_from_module(
|
||||
src_module_code=source_code,
|
||||
dst_module_code=target_code,
|
||||
src_path=src_path,
|
||||
dst_path=dst_path,
|
||||
project_root=temp_path
|
||||
project_root=temp_path,
|
||||
)
|
||||
|
||||
|
||||
assert result.strip() == expected_output.strip()
|
||||
|
||||
|
||||
def test_alias_handling():
|
||||
source_code = '''
|
||||
source_code = """
|
||||
from typing import Dict as MyDict, List as MyList, Optional
|
||||
from collections import defaultdict as dd, Counter
|
||||
'''
|
||||
|
||||
target_code = '''
|
||||
"""
|
||||
|
||||
target_code = """
|
||||
def test_aliases():
|
||||
d: MyDict[str, int] = {}
|
||||
lst: MyList[str] = []
|
||||
dd_instance = dd(list)
|
||||
return d, lst, dd_instance
|
||||
'''
|
||||
|
||||
expected_output = '''from collections import defaultdict as dd
|
||||
"""
|
||||
|
||||
expected_output = """from collections import defaultdict as dd
|
||||
from typing import Dict as MyDict, List as MyList
|
||||
|
||||
def test_aliases():
|
||||
|
|
@ -273,59 +273,59 @@ def test_aliases():
|
|||
lst: MyList[str] = []
|
||||
dd_instance = dd(list)
|
||||
return d, lst, dd_instance
|
||||
'''
|
||||
|
||||
"""
|
||||
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
temp_path = Path(temp_dir)
|
||||
src_path = temp_path / "source.py"
|
||||
dst_path = temp_path / "target.py"
|
||||
|
||||
|
||||
src_path.write_text(source_code)
|
||||
dst_path.write_text(target_code)
|
||||
|
||||
|
||||
result = add_needed_imports_from_module(
|
||||
src_module_code=source_code,
|
||||
dst_module_code=target_code,
|
||||
src_path=src_path,
|
||||
dst_path=dst_path,
|
||||
project_root=temp_path
|
||||
project_root=temp_path,
|
||||
)
|
||||
|
||||
|
||||
assert result.strip() == expected_output.strip()
|
||||
|
||||
|
||||
def test_add_needed_imports_with_nonealiases():
|
||||
source_code = '''
|
||||
source_code = """
|
||||
import json
|
||||
from typing import Dict as MyDict, Optional
|
||||
from collections import defaultdict
|
||||
|
||||
'''
|
||||
|
||||
target_code = '''
|
||||
"""
|
||||
|
||||
target_code = """
|
||||
def target_function():
|
||||
pass
|
||||
'''
|
||||
|
||||
"""
|
||||
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
temp_path = Path(temp_dir)
|
||||
src_path = temp_path / "source.py"
|
||||
dst_path = temp_path / "target.py"
|
||||
|
||||
|
||||
src_path.write_text(source_code)
|
||||
dst_path.write_text(target_code)
|
||||
|
||||
|
||||
# This should not raise a TypeError
|
||||
result = add_needed_imports_from_module(
|
||||
src_module_code=source_code,
|
||||
dst_module_code=target_code,
|
||||
src_path=src_path,
|
||||
dst_path=dst_path,
|
||||
project_root=temp_path
|
||||
project_root=temp_path,
|
||||
)
|
||||
|
||||
|
||||
expected_output = '''
|
||||
expected_output = """
|
||||
def target_function():
|
||||
pass
|
||||
'''
|
||||
assert result.strip() == expected_output.strip()
|
||||
"""
|
||||
assert result.strip() == expected_output.strip()
File diff suppressed because it is too large
@ -279,21 +279,23 @@ def test_path_belongs_to_site_packages_with_relative_path(monkeypatch: pytest.Mo
assert path_belongs_to_site_packages(file_path) is False


def test_path_belongs_to_site_packages_with_symlinked_site_packages(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None:
def test_path_belongs_to_site_packages_with_symlinked_site_packages(
monkeypatch: pytest.MonkeyPatch, tmp_path: Path
) -> None:
real_site_packages = tmp_path / "real_site_packages"
real_site_packages.mkdir()

symlinked_site_packages = tmp_path / "symlinked_site_packages"
symlinked_site_packages.symlink_to(real_site_packages)

package_file = real_site_packages / "some_package" / "__init__.py"
package_file.parent.mkdir()
package_file.write_text("# package file")

monkeypatch.setattr(site, "getsitepackages", lambda: [str(symlinked_site_packages)])

assert path_belongs_to_site_packages(package_file) is True

symlinked_package_file = symlinked_site_packages / "some_package" / "__init__.py"
assert path_belongs_to_site_packages(symlinked_package_file) is True

@ -301,40 +303,42 @@ def test_path_belongs_to_site_packages_with_symlinked_site_packages(monkeypatch:
def test_path_belongs_to_site_packages_with_complex_symlinks(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None:
real_site_packages = tmp_path / "real" / "lib" / "python3.9" / "site-packages"
real_site_packages.mkdir(parents=True)

link1 = tmp_path / "link1"
link1.symlink_to(real_site_packages.parent.parent.parent)

link2 = tmp_path / "link2"

link2 = tmp_path / "link2"
link2.symlink_to(link1)

package_file = real_site_packages / "test_package" / "module.py"
package_file.parent.mkdir()
package_file.write_text("# test module")

site_packages_via_links = link2 / "lib" / "python3.9" / "site-packages"
monkeypatch.setattr(site, "getsitepackages", lambda: [str(site_packages_via_links)])

assert path_belongs_to_site_packages(package_file) is True

file_via_links = site_packages_via_links / "test_package" / "module.py"
assert path_belongs_to_site_packages(file_via_links) is True


def test_path_belongs_to_site_packages_resolved_paths_normalization(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None:
def test_path_belongs_to_site_packages_resolved_paths_normalization(
monkeypatch: pytest.MonkeyPatch, tmp_path: Path
) -> None:
site_packages_dir = tmp_path / "lib" / "python3.9" / "site-packages"
site_packages_dir.mkdir(parents=True)

package_dir = site_packages_dir / "mypackage"
package_dir.mkdir()
package_file = package_dir / "module.py"
package_file.write_text("# module")

complex_site_packages_path = tmp_path / "lib" / "python3.9" / "other" / ".." / "site-packages" / "."
monkeypatch.setattr(site, "getsitepackages", lambda: [str(complex_site_packages_path)])

assert path_belongs_to_site_packages(package_file) is True

complex_file_path = tmp_path / "lib" / "python3.9" / "site-packages" / "other" / ".." / "mypackage" / "module.py"
assert path_belongs_to_site_packages(complex_file_path) is True

@ -374,8 +378,9 @@ def my_function():
|
|||
def mock_code_context():
|
||||
"""Mock CodeOptimizationContext for testing extract_dependent_function."""
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from codeflash.models.models import CodeOptimizationContext
|
||||
|
||||
|
||||
context = MagicMock(spec=CodeOptimizationContext)
|
||||
context.preexisting_objects = []
|
||||
return context
|
||||
|
|
@ -393,7 +398,7 @@ def helper_function():
|
|||
```
|
||||
""")
|
||||
assert extract_dependent_function("main_function", mock_code_context) == "helper_function"
|
||||
|
||||
|
||||
# Test async function extraction
|
||||
mock_code_context.testgen_context = CodeStringsMarkdown.parse_markdown_code("""```python:file.py
|
||||
def main_function():
|
||||
|
|
@ -416,7 +421,7 @@ def main_function():
|
|||
```
|
||||
""")
|
||||
assert extract_dependent_function("main_function", mock_code_context) is False
|
||||
|
||||
|
||||
# Multiple dependent functions
|
||||
mock_code_context.testgen_context = CodeStringsMarkdown.parse_markdown_code("""```python:file.py
|
||||
def main_function():
|
||||
|
|
@ -443,7 +448,7 @@ def sync_helper():
|
|||
```
|
||||
""")
|
||||
assert extract_dependent_function("async_main", mock_code_context) == "sync_helper"
|
||||
|
||||
|
||||
# Only async functions
|
||||
mock_code_context.testgen_context = CodeStringsMarkdown.parse_markdown_code("""```python:file.py
|
||||
async def async_main():
|
||||
|
|
@ -500,7 +505,7 @@ def test_partial_module_name2(base_dir: Path) -> None:
|
|||
|
||||
def test_pytest_unittest_path_resolution_with_prefix(tmp_path: Path) -> None:
|
||||
"""Test path resolution when pytest includes parent directory in classname.
|
||||
|
||||
|
||||
This handles the case where pytest's base_dir is /path/to/tests but the
|
||||
classname includes the parent directory like "project.tests.unittest.test_file.TestClass".
|
||||
"""
|
||||
|
|
@ -509,34 +514,29 @@ def test_pytest_unittest_path_resolution_with_prefix(tmp_path: Path) -> None:
|
|||
tests_root = project_root / "tests"
|
||||
unittest_dir = tests_root / "unittest"
|
||||
unittest_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
# Create test files
|
||||
test_file = unittest_dir / "test_bubble_sort.py"
|
||||
test_file.touch()
|
||||
|
||||
|
||||
generated_test = unittest_dir / "test_sorter__unit_test_0.py"
|
||||
generated_test.touch()
|
||||
|
||||
|
||||
# Case 1: pytest reports classname with full path including "code_to_optimize.tests"
|
||||
# but base_dir is .../tests (not the project root)
|
||||
result = resolve_test_file_from_class_path(
|
||||
"code_to_optimize.tests.unittest.test_bubble_sort.TestPigLatin",
|
||||
tests_root
|
||||
"code_to_optimize.tests.unittest.test_bubble_sort.TestPigLatin", tests_root
|
||||
)
|
||||
assert result == test_file
|
||||
|
||||
|
||||
# Case 2: Generated test file with class name
|
||||
result = resolve_test_file_from_class_path(
|
||||
"code_to_optimize.tests.unittest.test_sorter__unit_test_0.TestSorter",
|
||||
tests_root
|
||||
"code_to_optimize.tests.unittest.test_sorter__unit_test_0.TestSorter", tests_root
|
||||
)
|
||||
assert result == generated_test
|
||||
|
||||
|
||||
# Case 3: Without the class name (just the module path)
|
||||
result = resolve_test_file_from_class_path(
|
||||
"code_to_optimize.tests.unittest.test_bubble_sort",
|
||||
tests_root
|
||||
)
|
||||
result = resolve_test_file_from_class_path("code_to_optimize.tests.unittest.test_bubble_sort", tests_root)
|
||||
assert result == test_file
|
||||
|
||||
|
||||
|
|
@ -546,23 +546,17 @@ def test_pytest_unittest_multiple_prefix_levels(tmp_path: Path) -> None:
|
|||
base = tmp_path / "org" / "project" / "src" / "tests"
|
||||
unit_dir = base / "unit"
|
||||
unit_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
test_file = unit_dir / "test_example.py"
|
||||
test_file.touch()
|
||||
|
||||
|
||||
# pytest might report: org.project.src.tests.unit.test_example.TestClass
|
||||
# with base_dir being .../src/tests or .../tests
|
||||
result = resolve_test_file_from_class_path(
|
||||
"org.project.src.tests.unit.test_example.TestClass",
|
||||
base
|
||||
)
|
||||
result = resolve_test_file_from_class_path("org.project.src.tests.unit.test_example.TestClass", base)
|
||||
assert result == test_file
|
||||
|
||||
|
||||
# Also test with base_dir at different level
|
||||
result = resolve_test_file_from_class_path(
|
||||
"project.src.tests.unit.test_example.TestClass",
|
||||
base
|
||||
)
|
||||
result = resolve_test_file_from_class_path("project.src.tests.unit.test_example.TestClass", base)
|
||||
assert result == test_file
|
||||
|
||||
|
||||
|
|
@ -570,15 +564,14 @@ def test_pytest_unittest_instrumented_files(tmp_path: Path) -> None:
|
|||
"""Test path resolution for instrumented test files."""
|
||||
tests_root = tmp_path / "tests" / "unittest"
|
||||
tests_root.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
# Create instrumented test file
|
||||
instrumented_file = tests_root / "test_bubble_sort__perfinstrumented.py"
|
||||
instrumented_file.touch()
|
||||
|
||||
|
||||
# pytest classname includes parent directories
|
||||
result = resolve_test_file_from_class_path(
|
||||
"code_to_optimize.tests.unittest.test_bubble_sort__perfinstrumented.TestPigLatin",
|
||||
tmp_path / "tests"
|
||||
"code_to_optimize.tests.unittest.test_bubble_sort__perfinstrumented.TestPigLatin", tmp_path / "tests"
|
||||
)
|
||||
assert result == instrumented_file
|
||||
|
||||
|
|
@ -587,15 +580,12 @@ def test_pytest_unittest_nested_classes(tmp_path: Path) -> None:
|
|||
"""Test path resolution with nested class names."""
|
||||
tests_root = tmp_path / "tests"
|
||||
tests_root.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
test_file = tests_root / "test_nested.py"
|
||||
test_file.touch()
|
||||
|
||||
|
||||
# Some unittest frameworks use nested classes
|
||||
result = resolve_test_file_from_class_path(
|
||||
"project.tests.test_nested.OuterClass.InnerClass",
|
||||
tests_root
|
||||
)
|
||||
result = resolve_test_file_from_class_path("project.tests.test_nested.OuterClass.InnerClass", tests_root)
|
||||
assert result == test_file
|
||||
|
||||
|
||||
|
|
@ -603,12 +593,9 @@ def test_pytest_unittest_no_match_returns_none(tmp_path: Path) -> None:
|
|||
"""Test that non-existent files return None even with prefix stripping."""
|
||||
tests_root = tmp_path / "tests"
|
||||
tests_root.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
# File doesn't exist
|
||||
result = resolve_test_file_from_class_path(
|
||||
"code_to_optimize.tests.unittest.nonexistent_test.TestClass",
|
||||
tests_root
|
||||
)
|
||||
result = resolve_test_file_from_class_path("code_to_optimize.tests.unittest.nonexistent_test.TestClass", tests_root)
|
||||
assert result is None
|
||||
|
||||
|
||||
|
|
@ -617,10 +604,10 @@ def test_pytest_unittest_single_component(tmp_path: Path) -> None:
|
|||
base_dir = tmp_path
|
||||
test_file = base_dir / "test_simple.py"
|
||||
test_file.touch()
|
||||
|
||||
|
||||
result = file_name_from_test_module_name("test_simple", base_dir)
|
||||
assert result == test_file
|
||||
|
||||
|
||||
# With class name
|
||||
result = file_name_from_test_module_name("test_simple.TestClass", base_dir)
|
||||
assert result == test_file
|
||||
|
|
@ -644,7 +631,7 @@ def test_generate_candidates() -> None:
|
|||
"Desktop/work/codeflash/cli/codeflash/code_utils/coverage_utils.py",
|
||||
"krrt7/Desktop/work/codeflash/cli/codeflash/code_utils/coverage_utils.py",
|
||||
"Users/krrt7/Desktop/work/codeflash/cli/codeflash/code_utils/coverage_utils.py",
|
||||
"/Users/krrt7/Desktop/work/codeflash/cli/codeflash/code_utils/coverage_utils.py"
|
||||
"/Users/krrt7/Desktop/work/codeflash/cli/codeflash/code_utils/coverage_utils.py",
|
||||
}
|
||||
assert generate_candidates(source_code_path) == expected_candidates
|
||||
|
||||
|
|
|
|||
|
|
@ -54,7 +54,9 @@ class MyClass:
with sample_code_path.open("w") as f:
f.write(sample_code)
result = execute_test_subprocess(
cwd=test_dir, cmd_list=[f"{SAFE_SYS_EXECUTABLE}", "-m", "pytest", test_file_name, "-s"], env=os.environ.copy()
cwd=test_dir,
cmd_list=[f"{SAFE_SYS_EXECUTABLE}", "-m", "pytest", test_file_name, "-s"],
env=os.environ.copy(),
)
assert not result.stderr
assert result.returncode == 0

@ -129,7 +131,9 @@ class MyClass:
with sample_code_path.open("w") as f:
f.write(sample_code)
result = execute_test_subprocess(
cwd=test_dir, cmd_list=[f"{SAFE_SYS_EXECUTABLE}", "-m", "pytest", test_file_name, "-s"], env=os.environ.copy()
cwd=test_dir,
cmd_list=[f"{SAFE_SYS_EXECUTABLE}", "-m", "pytest", test_file_name, "-s"],
env=os.environ.copy(),
)
assert not result.stderr
assert result.returncode == 0

@ -194,7 +198,9 @@ class MyClass:
with sample_code_path.open("w") as f:
f.write(sample_code)
result = execute_test_subprocess(
cwd=test_dir, cmd_list=[f"{SAFE_SYS_EXECUTABLE}", "-m", "pytest", test_file_name, "-s"], env=os.environ.copy()
cwd=test_dir,
cmd_list=[f"{SAFE_SYS_EXECUTABLE}", "-m", "pytest", test_file_name, "-s"],
env=os.environ.copy(),
)
assert not result.stderr
assert result.returncode == 0

@ -279,7 +285,9 @@ class MyClass:

# Run pytest as a subprocess
result = execute_test_subprocess(
cwd=test_dir, cmd_list=[f"{SAFE_SYS_EXECUTABLE}", "-m", "pytest", test_file_name, "-s"], env=os.environ.copy()
cwd=test_dir,
cmd_list=[f"{SAFE_SYS_EXECUTABLE}", "-m", "pytest", test_file_name, "-s"],
env=os.environ.copy(),
)

# Check for errors

@ -356,7 +364,9 @@ class MyClass:
with sample_code_path.open("w") as f:
f.write(sample_code)
result = execute_test_subprocess(
cwd=test_dir, cmd_list=[f"{SAFE_SYS_EXECUTABLE}", "-m", "pytest", test_file_name, "-s"], env=os.environ.copy()
cwd=test_dir,
cmd_list=[f"{SAFE_SYS_EXECUTABLE}", "-m", "pytest", test_file_name, "-s"],
env=os.environ.copy(),
)
assert not result.stderr
assert result.returncode == 0

@ -1184,6 +1194,7 @@ class MyClass:
helper_path_1.unlink(missing_ok=True)
helper_path_2.unlink(missing_ok=True)


def test_get_stack_info_env_var_fallback() -> None:
"""Test that get_test_info_from_stack falls back to environment variables when stack walking fails to find test_name.

@ -1421,8 +1432,7 @@ def calculate_portfolio_metrics(
f.write(test_code)

fto = FunctionToOptimize("calculate_portfolio_metrics", fto_file_path, parents=[])
file_path_to_helper_class = {
}
file_path_to_helper_class = {}
instrument_codeflash_capture(fto, file_path_to_helper_class, tests_root)
test_env = os.environ.copy()
test_env["CODEFLASH_TEST_ITERATION"] = "0"

@ -1453,8 +1463,7 @@ def calculate_portfolio_metrics(
candidate_helper_code = {}
for file_path in file_path_to_helper_class:
candidate_helper_code[file_path] = Path(file_path).read_text("utf-8")
file_path_to_helper_classes = {
}
file_path_to_helper_classes = {}
instrument_codeflash_capture(fto, file_path_to_helper_classes, tests_root)

test_results, coverage_data = func_optimizer.run_and_parse_tests(

@ -1692,4 +1701,4 @@ class SlotsClass:

finally:
test_path.unlink(missing_ok=True)
sample_code_path.unlink(missing_ok=True)
sample_code_path.unlink(missing_ok=True)

@ -3,6 +3,7 @@ import tempfile
from pathlib import Path

import pytest

from codeflash.code_utils.checkpoint import CodeflashRunCheckpoint, get_all_historical_functions


@ -1,6 +1,6 @@

from codeflash.benchmarking.codeflash_trace import codeflash_trace
from pathlib import Path
from codeflash.code_utils.code_utils import get_run_tmp_file


@codeflash_trace
def example_function(arr):

File diff suppressed because it is too large
@ -50,7 +50,9 @@ def test_speedup_critic() -> None:
|
|||
total_candidate_timing=12,
|
||||
)
|
||||
|
||||
assert speedup_critic(candidate_result, original_code_runtime, best_runtime_until_now, disable_gh_action_noise=True) # 20% improvement
|
||||
assert speedup_critic(
|
||||
candidate_result, original_code_runtime, best_runtime_until_now, disable_gh_action_noise=True
|
||||
) # 20% improvement
|
||||
|
||||
candidate_result = OptimizedCandidateResult(
|
||||
max_loop_count=5,
|
||||
|
|
@ -61,7 +63,9 @@ def test_speedup_critic() -> None:
|
|||
optimization_candidate_index=0,
|
||||
)
|
||||
|
||||
assert not speedup_critic(candidate_result, original_code_runtime, best_runtime_until_now, disable_gh_action_noise=True) # 6% improvement
|
||||
assert not speedup_critic(
|
||||
candidate_result, original_code_runtime, best_runtime_until_now, disable_gh_action_noise=True
|
||||
) # 6% improvement
|
||||
|
||||
original_code_runtime = 100000
|
||||
best_runtime_until_now = 100000
|
||||
|
|
@ -75,7 +79,9 @@ def test_speedup_critic() -> None:
|
|||
optimization_candidate_index=0,
|
||||
)
|
||||
|
||||
assert speedup_critic(candidate_result, original_code_runtime, best_runtime_until_now, disable_gh_action_noise=True) # 6% improvement
|
||||
assert speedup_critic(
|
||||
candidate_result, original_code_runtime, best_runtime_until_now, disable_gh_action_noise=True
|
||||
) # 6% improvement
|
||||
|
||||
|
||||
def test_generated_test_critic() -> None:
|
||||
|
|
@ -418,6 +424,7 @@ def test_coverage_critic() -> None:
|
|||
|
||||
assert coverage_critic(failing_coverage) is False
|
||||
|
||||
|
||||
def test_throughput_gain() -> None:
|
||||
"""Test throughput_gain calculation."""
|
||||
# Test basic throughput improvement
|
||||
|
|
@ -458,7 +465,7 @@ def test_speedup_critic_with_async_throughput() -> None:
|
|||
best_runtime_until_now=None,
|
||||
original_async_throughput=original_async_throughput,
|
||||
best_throughput_until_now=None,
|
||||
disable_gh_action_noise=True
|
||||
disable_gh_action_noise=True,
|
||||
)
|
||||
|
||||
# Test case 2: Runtime improves significantly, throughput doesn't meet threshold (should pass)
|
||||
|
|
@ -478,7 +485,7 @@ def test_speedup_critic_with_async_throughput() -> None:
|
|||
best_runtime_until_now=None,
|
||||
original_async_throughput=original_async_throughput,
|
||||
best_throughput_until_now=None,
|
||||
disable_gh_action_noise=True
|
||||
disable_gh_action_noise=True,
|
||||
)
|
||||
|
||||
# Test case 3: Throughput improves significantly, runtime doesn't meet threshold (should pass)
|
||||
|
|
@ -498,7 +505,7 @@ def test_speedup_critic_with_async_throughput() -> None:
|
|||
best_runtime_until_now=None,
|
||||
original_async_throughput=original_async_throughput,
|
||||
best_throughput_until_now=None,
|
||||
disable_gh_action_noise=True
|
||||
disable_gh_action_noise=True,
|
||||
)
|
||||
|
||||
# Test case 4: No throughput data - should fall back to runtime-only evaluation
|
||||
|
|
@ -518,7 +525,7 @@ def test_speedup_critic_with_async_throughput() -> None:
|
|||
best_runtime_until_now=None,
|
||||
original_async_throughput=None, # No original throughput data
|
||||
best_throughput_until_now=None,
|
||||
disable_gh_action_noise=True
|
||||
disable_gh_action_noise=True,
|
||||
)
|
||||
|
||||
# Test case 5: Test best_throughput_until_now comparison
|
||||
|
|
@ -539,7 +546,7 @@ def test_speedup_critic_with_async_throughput() -> None:
|
|||
best_runtime_until_now=None,
|
||||
original_async_throughput=original_async_throughput,
|
||||
best_throughput_until_now=None,
|
||||
disable_gh_action_noise=True
|
||||
disable_gh_action_noise=True,
|
||||
)
|
||||
|
||||
# Should fail when there's a better throughput already
|
||||
|
|
@ -549,7 +556,7 @@ def test_speedup_critic_with_async_throughput() -> None:
|
|||
best_runtime_until_now=7000, # Better runtime already exists
|
||||
original_async_throughput=original_async_throughput,
|
||||
best_throughput_until_now=120, # Better throughput already exists
|
||||
disable_gh_action_noise=True
|
||||
disable_gh_action_noise=True,
|
||||
)
|
||||
|
||||
# Test case 6: Zero original throughput (edge case)
|
||||
|
|
@ -570,7 +577,7 @@ def test_speedup_critic_with_async_throughput() -> None:
|
|||
best_runtime_until_now=None,
|
||||
original_async_throughput=0, # Zero original throughput
|
||||
best_throughput_until_now=None,
|
||||
disable_gh_action_noise=True
|
||||
disable_gh_action_noise=True,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -594,29 +601,20 @@ def test_concurrency_gain() -> None:
|
|||
|
||||
# Test no improvement
|
||||
same = ConcurrencyMetrics(
|
||||
sequential_time_ns=10_000_000,
|
||||
concurrent_time_ns=10_000_000,
|
||||
concurrency_factor=10,
|
||||
concurrency_ratio=1.0,
|
||||
sequential_time_ns=10_000_000, concurrent_time_ns=10_000_000, concurrency_factor=10, concurrency_ratio=1.0
|
||||
)
|
||||
assert concurrency_gain(original, same) == 0.0
|
||||
|
||||
# Test slight improvement
|
||||
slightly_better = ConcurrencyMetrics(
|
||||
sequential_time_ns=10_000_000,
|
||||
concurrent_time_ns=8_000_000,
|
||||
concurrency_factor=10,
|
||||
concurrency_ratio=1.25,
|
||||
sequential_time_ns=10_000_000, concurrent_time_ns=8_000_000, concurrency_factor=10, concurrency_ratio=1.25
|
||||
)
|
||||
# 25% improvement: (1.25 - 1.0) / 1.0 = 0.25
|
||||
assert concurrency_gain(original, slightly_better) == 0.25
|
||||
|
||||
# Test zero original ratio (edge case)
|
||||
zero_ratio = ConcurrencyMetrics(
|
||||
sequential_time_ns=0,
|
||||
concurrent_time_ns=1_000_000,
|
||||
concurrency_factor=10,
|
||||
concurrency_ratio=0.0,
|
||||
sequential_time_ns=0, concurrent_time_ns=1_000_000, concurrency_factor=10, concurrency_ratio=0.0
|
||||
)
|
||||
assert concurrency_gain(zero_ratio, optimized) == 0.0
|
||||
|
||||
|
|
@ -628,10 +626,7 @@ def test_speedup_critic_with_concurrency_metrics() -> None:
|
|||
|
||||
# Original concurrency metrics (blocking code - ratio ~= 1.0)
|
||||
original_concurrency = ConcurrencyMetrics(
|
||||
sequential_time_ns=10_000_000,
|
||||
concurrent_time_ns=10_000_000,
|
||||
concurrency_factor=10,
|
||||
concurrency_ratio=1.0,
|
||||
sequential_time_ns=10_000_000, concurrent_time_ns=10_000_000, concurrency_factor=10, concurrency_ratio=1.0
|
||||
)
|
||||
|
||||
# Test case 1: Concurrency improves significantly (blocking -> non-blocking)
|
||||
|
|
@ -731,10 +726,7 @@ def test_speedup_critic_with_concurrency_metrics() -> None:
|
|||
total_candidate_timing=10000,
|
||||
async_throughput=100,
|
||||
concurrency_metrics=ConcurrencyMetrics(
|
||||
sequential_time_ns=10_000_000,
|
||||
concurrent_time_ns=2_000_000,
|
||||
concurrency_factor=10,
|
||||
concurrency_ratio=5.0,
|
||||
sequential_time_ns=10_000_000, concurrent_time_ns=2_000_000, concurrency_factor=10, concurrency_ratio=5.0
|
||||
),
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,4 +1,3 @@
from unittest.mock import Mock
import contextlib
import os
import shutil

@ -6,6 +5,7 @@ import unittest
from dataclasses import dataclass
from pathlib import Path
from typing import Optional
from unittest.mock import Mock

from codeflash.result.create_pr import existing_tests_source_for

@ -30,46 +30,33 @@ class TestExistingTestsSourceFor:
|
|||
self.mock_function_called_in_test = Mock()
|
||||
self.mock_function_called_in_test.tests_in_file = Mock()
|
||||
self.mock_function_called_in_test.tests_in_file.test_file = Path(__file__).resolve().parent / "test_module.py"
|
||||
#Path to pyproject.toml
|
||||
# Path to pyproject.toml
|
||||
os.chdir(self.test_cfg.project_root_path)
|
||||
|
||||
|
||||
def test_no_test_files_returns_empty_string(self):
|
||||
"""Test that function returns empty string when no test files exist."""
|
||||
|
||||
function_to_tests = {}
|
||||
original_runtimes = {}
|
||||
optimized_runtimes = {}
|
||||
|
||||
result, _, _ = existing_tests_source_for(
|
||||
"module.function",
|
||||
function_to_tests,
|
||||
self.test_cfg,
|
||||
original_runtimes,
|
||||
optimized_runtimes
|
||||
"module.function", function_to_tests, self.test_cfg, original_runtimes, optimized_runtimes
|
||||
)
|
||||
|
||||
assert result == ""
|
||||
|
||||
def test_single_test_with_improvement(self):
|
||||
"""Test single test showing performance improvement."""
|
||||
|
||||
function_to_tests = {
|
||||
"module.function": {self.mock_function_called_in_test}
|
||||
}
|
||||
function_to_tests = {"module.function": {self.mock_function_called_in_test}}
|
||||
original_runtimes = {
|
||||
self.mock_invocation_id: [1000000] # 1ms in nanoseconds
|
||||
}
|
||||
optimized_runtimes = {
|
||||
self.mock_invocation_id: [500000] # 0.5ms in nanoseconds
|
||||
self.mock_invocation_id: [500000] # 0.5ms in nanoseconds
|
||||
}
|
||||
|
||||
result, _, _ = existing_tests_source_for(
|
||||
"module.function",
|
||||
function_to_tests,
|
||||
self.test_cfg,
|
||||
original_runtimes,
|
||||
optimized_runtimes
|
||||
"module.function", function_to_tests, self.test_cfg, original_runtimes, optimized_runtimes
|
||||
)
|
||||
|
||||
expected = """| Test File::Test Function | Original ⏱️ | Optimized ⏱️ | Speedup |
|
||||
|
|
@ -81,23 +68,16 @@ class TestExistingTestsSourceFor:
|
|||
|
||||
def test_single_test_with_regression(self):
|
||||
"""Test single test showing performance regression."""
|
||||
|
||||
function_to_tests = {
|
||||
"module.function": {self.mock_function_called_in_test}
|
||||
}
|
||||
function_to_tests = {"module.function": {self.mock_function_called_in_test}}
|
||||
original_runtimes = {
|
||||
self.mock_invocation_id: [500000] # 0.5ms in nanoseconds
|
||||
self.mock_invocation_id: [500000] # 0.5ms in nanoseconds
|
||||
}
|
||||
optimized_runtimes = {
|
||||
self.mock_invocation_id: [1000000] # 1ms in nanoseconds
|
||||
}
|
||||
|
||||
result, _, _ = existing_tests_source_for(
|
||||
"module.function",
|
||||
function_to_tests,
|
||||
self.test_cfg,
|
||||
original_runtimes,
|
||||
optimized_runtimes
|
||||
"module.function", function_to_tests, self.test_cfg, original_runtimes, optimized_runtimes
|
||||
)
|
||||
|
||||
expected = """| Test File::Test Function | Original ⏱️ | Optimized ⏱️ | Speedup |
|
||||
|
|
@ -109,28 +89,17 @@ class TestExistingTestsSourceFor:
|
|||
|
||||
def test_test_without_class_name(self):
|
||||
"""Test function without class name (standalone test function)."""
|
||||
|
||||
mock_invocation_no_class = Mock()
|
||||
mock_invocation_no_class.test_module_path = "tests.test_module"
|
||||
mock_invocation_no_class.test_class_name = None
|
||||
mock_invocation_no_class.test_function_name = "test_standalone"
|
||||
|
||||
function_to_tests = {
|
||||
"module.function": {self.mock_function_called_in_test}
|
||||
}
|
||||
original_runtimes = {
|
||||
mock_invocation_no_class: [1000000]
|
||||
}
|
||||
optimized_runtimes = {
|
||||
mock_invocation_no_class: [800000]
|
||||
}
|
||||
function_to_tests = {"module.function": {self.mock_function_called_in_test}}
|
||||
original_runtimes = {mock_invocation_no_class: [1000000]}
|
||||
optimized_runtimes = {mock_invocation_no_class: [800000]}
|
||||
|
||||
result, _, _ = existing_tests_source_for(
|
||||
"module.function",
|
||||
function_to_tests,
|
||||
self.test_cfg,
|
||||
original_runtimes,
|
||||
optimized_runtimes
|
||||
"module.function", function_to_tests, self.test_cfg, original_runtimes, optimized_runtimes
|
||||
)
|
||||
|
||||
expected = """| Test File::Test Function | Original ⏱️ | Optimized ⏱️ | Speedup |
|
||||
|
|
@ -142,21 +111,12 @@ class TestExistingTestsSourceFor:
|
|||
|
||||
def test_missing_original_runtime(self):
|
||||
"""Test when original runtime is missing (shows NaN)."""
|
||||
|
||||
function_to_tests = {
|
||||
"module.function": {self.mock_function_called_in_test}
|
||||
}
|
||||
function_to_tests = {"module.function": {self.mock_function_called_in_test}}
|
||||
original_runtimes = {}
|
||||
optimized_runtimes = {
|
||||
self.mock_invocation_id: [500000]
|
||||
}
|
||||
optimized_runtimes = {self.mock_invocation_id: [500000]}
|
||||
|
||||
result, _, _ = existing_tests_source_for(
|
||||
"module.function",
|
||||
function_to_tests,
|
||||
self.test_cfg,
|
||||
original_runtimes,
|
||||
optimized_runtimes
|
||||
"module.function", function_to_tests, self.test_cfg, original_runtimes, optimized_runtimes
|
||||
)
|
||||
|
||||
expected = ""
|
||||
|
|
@ -165,21 +125,12 @@ class TestExistingTestsSourceFor:
|
|||
|
||||
def test_missing_optimized_runtime(self):
|
||||
"""Test when optimized runtime is missing (shows NaN)."""
|
||||
|
||||
function_to_tests = {
|
||||
"module.function": {self.mock_function_called_in_test}
|
||||
}
|
||||
original_runtimes = {
|
||||
self.mock_invocation_id: [1000000]
|
||||
}
|
||||
function_to_tests = {"module.function": {self.mock_function_called_in_test}}
|
||||
original_runtimes = {self.mock_invocation_id: [1000000]}
|
||||
optimized_runtimes = {}
|
||||
|
||||
result, _, _ = existing_tests_source_for(
|
||||
"module.function",
|
||||
function_to_tests,
|
||||
self.test_cfg,
|
||||
original_runtimes,
|
||||
optimized_runtimes
|
||||
"module.function", function_to_tests, self.test_cfg, original_runtimes, optimized_runtimes
|
||||
)
|
||||
|
||||
expected = ""
|
||||
|
|
@ -189,7 +140,7 @@ class TestExistingTestsSourceFor:
|
|||
def test_multiple_tests_sorted_output(self):
|
||||
"""Test multiple tests with sorted output by filename and function name."""
|
||||
# Create second test file
|
||||
|
||||
|
||||
mock_function_called_2 = Mock()
|
||||
mock_function_called_2.tests_in_file = Mock()
|
||||
mock_function_called_2.tests_in_file.test_file = Path(__file__).resolve().parent / "test_another.py"
|
||||
|
|
@ -199,24 +150,12 @@ class TestExistingTestsSourceFor:
|
|||
mock_invocation_2.test_class_name = "TestAnother"
|
||||
mock_invocation_2.test_function_name = "test_another_function"
|
||||
|
||||
function_to_tests = {
|
||||
"module.function": {self.mock_function_called_in_test, mock_function_called_2}
|
||||
}
|
||||
original_runtimes = {
|
||||
self.mock_invocation_id: [1000000],
|
||||
mock_invocation_2: [2000000]
|
||||
}
|
||||
optimized_runtimes = {
|
||||
self.mock_invocation_id: [800000],
|
||||
mock_invocation_2: [1500000]
|
||||
}
|
||||
function_to_tests = {"module.function": {self.mock_function_called_in_test, mock_function_called_2}}
|
||||
original_runtimes = {self.mock_invocation_id: [1000000], mock_invocation_2: [2000000]}
|
||||
optimized_runtimes = {self.mock_invocation_id: [800000], mock_invocation_2: [1500000]}
|
||||
|
||||
result, _, _ = existing_tests_source_for(
|
||||
"module.function",
|
||||
function_to_tests,
|
||||
self.test_cfg,
|
||||
original_runtimes,
|
||||
optimized_runtimes
|
||||
"module.function", function_to_tests, self.test_cfg, original_runtimes, optimized_runtimes
|
||||
)
|
||||
|
||||
expected = """| Test File::Test Function | Original ⏱️ | Optimized ⏱️ | Speedup |
|
||||
|
|
@ -229,23 +168,16 @@ class TestExistingTestsSourceFor:
|
|||
|
||||
def test_multiple_runtimes_uses_minimum(self):
|
||||
"""Test that function uses minimum runtime when multiple measurements exist."""
|
||||
|
||||
function_to_tests = {
|
||||
"module.function": {self.mock_function_called_in_test}
|
||||
}
|
||||
function_to_tests = {"module.function": {self.mock_function_called_in_test}}
|
||||
original_runtimes = {
|
||||
self.mock_invocation_id: [1000000, 1200000, 800000] # min: 800000
|
||||
}
|
||||
optimized_runtimes = {
|
||||
self.mock_invocation_id: [600000, 700000, 500000] # min: 500000
|
||||
self.mock_invocation_id: [600000, 700000, 500000] # min: 500000
|
||||
}
|
||||
|
||||
result, _, _ = existing_tests_source_for(
|
||||
"module.function",
|
||||
function_to_tests,
|
||||
self.test_cfg,
|
||||
original_runtimes,
|
||||
optimized_runtimes
|
||||
"module.function", function_to_tests, self.test_cfg, original_runtimes, optimized_runtimes
|
||||
)
|
||||
|
||||
expected = """| Test File::Test Function | Original ⏱️ | Optimized ⏱️ | Speedup |
|
||||
|
|
@ -257,7 +189,6 @@ class TestExistingTestsSourceFor:
|
|||
|
||||
def test_complex_module_path_conversion(self):
|
||||
"""Test conversion of complex module paths to file paths."""
|
||||
|
||||
mock_invocation_complex = Mock()
|
||||
mock_invocation_complex.test_module_path = "tests.integration.test_complex_module"
|
||||
mock_invocation_complex.test_class_name = "TestComplex"
|
||||
|
|
@ -265,24 +196,16 @@ class TestExistingTestsSourceFor:
|
|||
|
||||
mock_function_complex = Mock()
|
||||
mock_function_complex.tests_in_file = Mock()
|
||||
mock_function_complex.tests_in_file.test_file = Path(__file__).resolve().parent / "integration/test_complex_module.py"
|
||||
mock_function_complex.tests_in_file.test_file = (
|
||||
Path(__file__).resolve().parent / "integration/test_complex_module.py"
|
||||
)
|
||||
|
||||
function_to_tests = {
|
||||
"module.function": {mock_function_complex}
|
||||
}
|
||||
original_runtimes = {
|
||||
mock_invocation_complex: [1000000]
|
||||
}
|
||||
optimized_runtimes = {
|
||||
mock_invocation_complex: [750000]
|
||||
}
|
||||
function_to_tests = {"module.function": {mock_function_complex}}
|
||||
original_runtimes = {mock_invocation_complex: [1000000]}
|
||||
optimized_runtimes = {mock_invocation_complex: [750000]}
|
||||
|
||||
result, _, _ = existing_tests_source_for(
|
||||
"module.function",
|
||||
function_to_tests,
|
||||
self.test_cfg,
|
||||
original_runtimes,
|
||||
optimized_runtimes
|
||||
"module.function", function_to_tests, self.test_cfg, original_runtimes, optimized_runtimes
|
||||
)
|
||||
|
||||
expected = """| Test File::Test Function | Original ⏱️ | Optimized ⏱️ | Speedup |
|
||||
|
|
@ -294,23 +217,12 @@ class TestExistingTestsSourceFor:
|
|||
|
||||
def test_zero_runtime_values(self):
|
||||
"""Test handling of zero runtime values."""
|
||||
|
||||
function_to_tests = {
|
||||
"module.function": {self.mock_function_called_in_test}
|
||||
}
|
||||
original_runtimes = {
|
||||
self.mock_invocation_id: [0]
|
||||
}
|
||||
optimized_runtimes = {
|
||||
self.mock_invocation_id: [0]
|
||||
}
|
||||
function_to_tests = {"module.function": {self.mock_function_called_in_test}}
|
||||
original_runtimes = {self.mock_invocation_id: [0]}
|
||||
optimized_runtimes = {self.mock_invocation_id: [0]}
|
||||
|
||||
result, _, _ = existing_tests_source_for(
|
||||
"module.function",
|
||||
function_to_tests,
|
||||
self.test_cfg,
|
||||
original_runtimes,
|
||||
optimized_runtimes
|
||||
"module.function", function_to_tests, self.test_cfg, original_runtimes, optimized_runtimes
|
||||
)
|
||||
|
||||
expected = ""
|
||||
|
|
@ -320,7 +232,7 @@ class TestExistingTestsSourceFor:
|
|||
def test_filters_out_generated_tests(self):
|
||||
"""Test that generated tests are filtered out and only non-generated tests are included."""
|
||||
# Create a test that would be filtered out (not in non_generated_tests)
|
||||
|
||||
|
||||
mock_generated_test = Mock()
|
||||
mock_generated_test.tests_in_file = Mock()
|
||||
mock_generated_test.tests_in_file.test_file = "/project/tests/generated_test.py"
|
||||
|
|
@ -330,24 +242,18 @@ class TestExistingTestsSourceFor:
|
|||
mock_generated_invocation.test_class_name = "TestGenerated"
|
||||
mock_generated_invocation.test_function_name = "test_generated"
|
||||
|
||||
function_to_tests = {
|
||||
"module.function": {self.mock_function_called_in_test}
|
||||
}
|
||||
function_to_tests = {"module.function": {self.mock_function_called_in_test}}
|
||||
original_runtimes = {
|
||||
self.mock_invocation_id: [1000000],
|
||||
mock_generated_invocation: [500000] # This should be filtered out
|
||||
mock_generated_invocation: [500000], # This should be filtered out
|
||||
}
|
||||
optimized_runtimes = {
|
||||
self.mock_invocation_id: [800000],
|
||||
mock_generated_invocation: [400000] # This should be filtered out
|
||||
mock_generated_invocation: [400000], # This should be filtered out
|
||||
}
|
||||
|
||||
result, _, _ = existing_tests_source_for(
|
||||
"module.function",
|
||||
function_to_tests,
|
||||
self.test_cfg,
|
||||
original_runtimes,
|
||||
optimized_runtimes
|
||||
"module.function", function_to_tests, self.test_cfg, original_runtimes, optimized_runtimes
|
||||
)
|
||||
|
||||
# Should only include the non-generated test
|
||||
|
|
@ -358,9 +264,11 @@ class TestExistingTestsSourceFor:
|
|||
|
||||
assert result == expected
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class MockInvocationId:
|
||||
"""Mocks codeflash.models.models.InvocationId"""
|
||||
|
||||
test_module_path: str
|
||||
test_function_name: str
|
||||
test_class_name: Optional[str] = None
|
||||
|
|
@ -369,6 +277,7 @@ class MockInvocationId:
|
|||
@dataclass(frozen=True)
|
||||
class MockTestsInFile:
|
||||
"""Mocks a part of codeflash.models.models.FunctionCalledInTest"""
|
||||
|
||||
test_file: Path
|
||||
test_type: str = "EXISTING_UNIT_TEST"
|
||||
|
||||
|
|
@ -376,12 +285,14 @@ class MockTestsInFile:
|
|||
@dataclass(frozen=True)
|
||||
class MockFunctionCalledInTest:
|
||||
"""Mocks codeflash.models.models.FunctionCalledInTest"""
|
||||
|
||||
tests_in_file: MockTestsInFile
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class MockTestConfig:
|
||||
"""Mocks codeflash.verification.verification_utils.TestConfig"""
|
||||
|
||||
tests_root: Path
|
||||
|
||||
|
||||
|
|
@ -429,11 +340,7 @@ class ExistingTestsSourceForTests(unittest.TestCase):
|
|||
|
||||
test_cfg = MockTestConfig(tests_root=tests_dir.resolve())
|
||||
function_to_tests = {
|
||||
self.func_qual_name: {
|
||||
MockFunctionCalledInTest(
|
||||
tests_in_file=MockTestsInFile(test_file=test_file_path)
|
||||
)
|
||||
}
|
||||
self.func_qual_name: {MockFunctionCalledInTest(tests_in_file=MockTestsInFile(test_file=test_file_path))}
|
||||
}
|
||||
existing, replay, concolic = existing_tests_source_for(
|
||||
function_qualified_name_with_modules_from_root=self.func_qual_name,
|
||||
|
|
@ -456,17 +363,11 @@ class ExistingTestsSourceForTests(unittest.TestCase):
|
|||
|
||||
test_cfg = MockTestConfig(tests_root=tests_dir.resolve())
|
||||
function_to_tests = {
|
||||
self.func_qual_name: {
|
||||
MockFunctionCalledInTest(
|
||||
tests_in_file=MockTestsInFile(test_file=test_file_path)
|
||||
)
|
||||
}
|
||||
self.func_qual_name: {MockFunctionCalledInTest(tests_in_file=MockTestsInFile(test_file=test_file_path))}
|
||||
}
|
||||
|
||||
invocation_id = MockInvocationId(
|
||||
test_module_path="tests.test_existing",
|
||||
test_class_name="TestMyStuff",
|
||||
test_function_name="test_one",
|
||||
test_module_path="tests.test_existing", test_class_name="TestMyStuff", test_function_name="test_one"
|
||||
)
|
||||
|
||||
original_runtimes = {invocation_id: [200_000_000]}
|
||||
|
|
@ -501,32 +402,20 @@ class ExistingTestsSourceForTests(unittest.TestCase):
|
|||
test_cfg = MockTestConfig(tests_root=tests_dir.resolve())
|
||||
function_to_tests = {
|
||||
self.func_qual_name: {
|
||||
MockFunctionCalledInTest(
|
||||
tests_in_file=MockTestsInFile(test_file=replay_test_path)
|
||||
),
|
||||
MockFunctionCalledInTest(
|
||||
tests_in_file=MockTestsInFile(test_file=concolic_test_path)
|
||||
),
|
||||
MockFunctionCalledInTest(tests_in_file=MockTestsInFile(test_file=replay_test_path)),
|
||||
MockFunctionCalledInTest(tests_in_file=MockTestsInFile(test_file=concolic_test_path)),
|
||||
}
|
||||
}
|
||||
|
||||
replay_inv_id = MockInvocationId(
|
||||
test_module_path="tests.__replay_test_abc",
|
||||
test_function_name="test_replay_one",
|
||||
test_module_path="tests.__replay_test_abc", test_function_name="test_replay_one"
|
||||
)
|
||||
concolic_inv_id = MockInvocationId(
|
||||
test_module_path="tests.codeflash_concolic_xyz",
|
||||
test_function_name="test_concolic_one",
|
||||
test_module_path="tests.codeflash_concolic_xyz", test_function_name="test_concolic_one"
|
||||
)
|
||||
|
||||
original_runtimes = {
|
||||
replay_inv_id: [100_000_000],
|
||||
concolic_inv_id: [150_000_000],
|
||||
}
|
||||
optimized_runtimes = {
|
||||
replay_inv_id: [200_000_000],
|
||||
concolic_inv_id: [300_000_000],
|
||||
}
|
||||
original_runtimes = {replay_inv_id: [100_000_000], concolic_inv_id: [150_000_000]}
|
||||
optimized_runtimes = {replay_inv_id: [200_000_000], concolic_inv_id: [300_000_000]}
|
||||
|
||||
existing, replay, concolic = existing_tests_source_for(
|
||||
function_qualified_name_with_modules_from_root=self.func_qual_name,
|
||||
|
|
@ -555,21 +444,13 @@ class ExistingTestsSourceForTests(unittest.TestCase):
|
|||
test_cfg = MockTestConfig(tests_root=tests_dir.resolve())
|
||||
function_to_tests = {
|
||||
self.func_qual_name: {
|
||||
MockFunctionCalledInTest(
|
||||
tests_in_file=MockTestsInFile(test_file=existing_test_path)
|
||||
),
|
||||
MockFunctionCalledInTest(
|
||||
tests_in_file=MockTestsInFile(test_file=replay_test_path)
|
||||
),
|
||||
MockFunctionCalledInTest(tests_in_file=MockTestsInFile(test_file=existing_test_path)),
|
||||
MockFunctionCalledInTest(tests_in_file=MockTestsInFile(test_file=replay_test_path)),
|
||||
}
|
||||
}
|
||||
|
||||
existing_inv_id = MockInvocationId(
|
||||
"tests.test_existing", "test_speedup", "TestExisting"
|
||||
)
|
||||
replay_inv_id = MockInvocationId(
|
||||
"tests.__replay_test_mixed", "test_slowdown"
|
||||
)
|
||||
existing_inv_id = MockInvocationId("tests.test_existing", "test_speedup", "TestExisting")
|
||||
replay_inv_id = MockInvocationId("tests.__replay_test_mixed", "test_slowdown")
|
||||
|
||||
original_runtimes = {
|
||||
existing_inv_id: [400_000_000, 500_000_000], # min is 400ms
|
||||
|
|
@ -581,11 +462,7 @@ class ExistingTestsSourceForTests(unittest.TestCase):
|
|||
}
|
||||
|
||||
existing, replay, concolic = existing_tests_source_for(
|
||||
self.func_qual_name,
|
||||
function_to_tests,
|
||||
test_cfg,
|
||||
original_runtimes,
|
||||
optimized_runtimes,
|
||||
self.func_qual_name, function_to_tests, test_cfg, original_runtimes, optimized_runtimes
|
||||
)
|
||||
|
||||
self.assertIn("`test_existing.py::TestExisting.test_speedup`", existing)
@ -3,8 +3,6 @@
from collections import Counter
from pathlib import Path

import pytest

from codeflash.models.models import FunctionTestInvocation, InvocationId, TestResults, TestType


@ -340,11 +338,13 @@ class TestFileToNoOfTests:
)

counter = test_results.file_to_no_of_tests(["test_remove"])
expected = Counter({
Path("/tmp/file1.py"): 1, # Only 1 GENERATED_REGRESSION test
Path("/tmp/file2.py"): 1, # Only test_keep (test_remove is excluded)
Path("/tmp/file3.py"): 3, # All 3 tests
})
expected = Counter(
{
Path("/tmp/file1.py"): 1, # Only 1 GENERATED_REGRESSION test
Path("/tmp/file2.py"): 1, # Only test_keep (test_remove is excluded)
Path("/tmp/file3.py"): 3, # All 3 tests
}
)
assert counter == expected

def test_case_sensitivity(self):

@ -438,7 +438,7 @@ class TestFileToNoOfTests:
)

counter = test_results.file_to_no_of_tests([])
expected = Counter({path: 1 for path in paths})
expected = Counter(dict.fromkeys(paths, 1))
assert counter == expected

def test_large_removal_list(self):

@ -470,4 +470,4 @@ class TestFileToNoOfTests:
)

counter = test_results.file_to_no_of_tests(removal_list)
assert counter == Counter({Path("/tmp/test_file.py"): 50}) # 50 kept, 50 removed
assert counter == Counter({Path("/tmp/test_file.py"): 50}) # 50 kept, 50 removed

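For reference, the Counter(dict.fromkeys(paths, 1)) rewrite above is behavior-preserving; a minimal standalone check (hypothetical paths, not part of the commit):

from collections import Counter
from pathlib import Path

paths = [Path("/tmp/file1.py"), Path("/tmp/file2.py")]
# Both constructions count each path exactly once.
assert Counter({path: 1 for path in paths}) == Counter(dict.fromkeys(paths, 1))
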
@ -1,24 +1,24 @@
import argparse
import os
import shutil
import tempfile
from pathlib import Path

import pytest
import shutil

from codeflash.code_utils.config_parser import parse_config_file
from codeflash.code_utils.formatter import format_code, format_generated_code, sort_imports

from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.models.models import CodeString, CodeStringsMarkdown
from codeflash.optimization.function_optimizer import FunctionOptimizer
from codeflash.verification.verification_utils import TestConfig


@pytest.fixture
def temp_dir():
with tempfile.TemporaryDirectory() as tmpdirname:
yield Path(tmpdirname)


def test_remove_duplicate_imports():
"""Test that duplicate imports are removed when should_sort_imports is True."""
original_code = "import os\nimport os\n"

@ -187,9 +187,7 @@ def foo():
temp_file = temp_dir / "test_file.py"
temp_file.write_text(original_code)

actual = format_code(
formatter_cmds=["ruff check --exit-zero --fix $file", "ruff format $file"], path=temp_file
)
actual = format_code(formatter_cmds=["ruff check --exit-zero --fix $file", "ruff format $file"], path=temp_file)
assert actual == expected

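The $file placeholder in these formatter commands stands for the file being formatted. A minimal sketch of how such a command template can be expanded and invoked (illustrative only; the helper below is an assumption, not the project's format_code implementation):

import shlex
import subprocess
from pathlib import Path

def run_formatter(cmd_template: str, path: Path) -> None:
    # Replace the $file placeholder with the target path, then run the tool.
    cmd = [part.replace("$file", str(path)) for part in shlex.split(cmd_template)]
    subprocess.run(cmd, check=False)

# Example, mirroring the commands used in the tests above:
# run_formatter("ruff check --exit-zero --fix $file", Path("test_file.py"))
# run_formatter("ruff format $file", Path("test_file.py"))
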
@ -208,7 +206,7 @@ def foo():
assert False, f"Shouldn't throw an exception even if the formatter is not found: {e}"


def _run_formatting_test(source_code: str, should_content_change: bool, expected = None, optimized_function: str = ""):
def _run_formatting_test(source_code: str, should_content_change: bool, expected=None, optimized_function: str = ""):
try:
import ruff # type: ignore
except ImportError:

@ -217,67 +215,50 @@ def _run_formatting_test(source_code: str, should_content_change: bool, expected
|
|||
with tempfile.TemporaryDirectory() as test_dir_str:
|
||||
test_dir = Path(test_dir_str)
|
||||
source_file = test_dir / "source.py"
|
||||
|
||||
|
||||
source_file.write_text(source_code)
|
||||
original = source_code
|
||||
target_path = test_dir / "target.py"
|
||||
|
||||
|
||||
shutil.copy2(source_file, target_path)
|
||||
|
||||
function_to_optimize = FunctionToOptimize(
|
||||
function_name="process_data",
|
||||
parents=[],
|
||||
file_path=target_path
|
||||
)
|
||||
function_to_optimize = FunctionToOptimize(function_name="process_data", parents=[], file_path=target_path)
|
||||
|
||||
test_cfg = TestConfig(
|
||||
tests_root=test_dir,
|
||||
project_root_path=test_dir,
|
||||
test_framework="pytest",
|
||||
tests_project_rootdir=test_dir,
|
||||
tests_root=test_dir, project_root_path=test_dir, test_framework="pytest", tests_project_rootdir=test_dir
|
||||
)
|
||||
|
||||
args = argparse.Namespace(
|
||||
disable_imports_sorting=False,
|
||||
formatter_cmds=[
|
||||
"ruff check --exit-zero --fix $file",
|
||||
"ruff format $file"
|
||||
],
|
||||
disable_imports_sorting=False, formatter_cmds=["ruff check --exit-zero --fix $file", "ruff format $file"]
|
||||
)
|
||||
|
||||
optimizer = FunctionOptimizer(
|
||||
function_to_optimize=function_to_optimize,
|
||||
test_cfg=test_cfg,
|
||||
args=args,
|
||||
)
|
||||
|
||||
optimizer = FunctionOptimizer(function_to_optimize=function_to_optimize, test_cfg=test_cfg, args=args)
|
||||
|
||||
optimizer.reformat_code_and_helpers(
|
||||
helper_functions=[],
|
||||
path=target_path,
|
||||
original_code=optimizer.function_to_optimize_source_code,
|
||||
optimized_context=CodeStringsMarkdown(code_strings=[
|
||||
CodeString(
|
||||
code=optimized_function,
|
||||
file_path=target_path.relative_to(test_dir)
|
||||
)
|
||||
]),
|
||||
optimized_context=CodeStringsMarkdown(
|
||||
code_strings=[CodeString(code=optimized_function, file_path=target_path.relative_to(test_dir))]
|
||||
),
|
||||
)
|
||||
|
||||
content = target_path.read_text(encoding="utf8")
|
||||
|
||||
if expected is not None:
|
||||
assert content == expected, f"Expected content to be \n===========\n{expected}\n===========\nbut got\n===========\n{content}\n===========\n"
|
||||
assert content == expected, (
|
||||
f"Expected content to be \n===========\n{expected}\n===========\nbut got\n===========\n{content}\n===========\n"
|
||||
)
|
||||
|
||||
if should_content_change:
|
||||
assert content != original, f"Expected content to change for source.py"
|
||||
assert content != original, "Expected content to change for source.py"
|
||||
else:
|
||||
assert content == original, f"Expected content to remain unchanged for source.py"
|
||||
|
||||
assert content == original, "Expected content to remain unchanged for source.py"
|
||||
|
||||
|
||||
def test_formatting_file_with_many_diffs():
|
||||
"""Test that files with many formatting errors are skipped (content unchanged)."""
|
||||
source_code = '''import os,sys,json,datetime,re
|
||||
source_code = """import os,sys,json,datetime,re
|
||||
from collections import defaultdict,OrderedDict
|
||||
import numpy as np,pandas as pd
|
||||
|
||||
|
|
@ -354,7 +335,7 @@ def main():
|
|||
else:print("Pipeline failed")
|
||||
|
||||
if __name__=='__main__':main()
|
||||
'''
|
||||
"""
|
||||
_run_formatting_test(source_code, False)
|
||||
|
||||
|
||||
|
|
@ -423,7 +404,7 @@ def process_data(data, config=None):
|
|||
|
||||
def test_formatting_extremely_messy_file():
|
||||
"""Test that extremely messy files with 100+ potential changes are skipped."""
|
||||
source_code = '''import os,sys,json,datetime,re,collections,itertools,functools,operator
|
||||
source_code = """import os,sys,json,datetime,re,collections,itertools,functools,operator
|
||||
from pathlib import Path
|
||||
from typing import Dict,List,Optional,Union,Any,Tuple
|
||||
import numpy as np,pandas as pd,matplotlib.pyplot as plt
|
||||
|
|
@ -554,25 +535,28 @@ def main():
|
|||
for error in processor.errors:print(f" - {error}")
|
||||
|
||||
if __name__=='__main__':main()
|
||||
'''
|
||||
"""
|
||||
_run_formatting_test(source_code, False)
|
||||
|
||||
|
||||
def test_formatting_edge_case_exactly_100_diffs():
|
||||
"""Test behavior when exactly at the threshold of 100 changes."""
|
||||
# Create a file with exactly 100 minor formatting issues
|
||||
snippet = '''import json\n''' + '''
|
||||
snippet = (
|
||||
"""import json\n"""
|
||||
"""
|
||||
def func_{i}():
|
||||
x=1;y=2;z=3
|
||||
return x+y+z
|
||||
'''
|
||||
"""
|
||||
)
|
||||
source_code = "".join([snippet.format(i=i) for i in range(100)])
|
||||
_run_formatting_test(source_code, False)
|
||||
|
||||
|
||||
def test_formatting_with_syntax_errors():
|
||||
"""Test that files with syntax errors are handled gracefully."""
|
||||
source_code = '''import json
|
||||
source_code = """import json
|
||||
|
||||
def process_data(data):
|
||||
if not data:
|
||||
|
|
@ -585,7 +569,7 @@ def process_data(data):
|
|||
result.append(item)
|
||||
|
||||
return result
|
||||
'''
|
||||
"""
|
||||
_run_formatting_test(source_code, False)
|
||||
|
||||
|
||||
|
|
@ -641,7 +625,7 @@ def another_function_with_long_line():
|
|||
|
||||
def test_formatting_class_with_methods():
|
||||
"""Test formatting of classes with multiple methods and minor issues."""
|
||||
source_code = '''class DataProcessor:
|
||||
source_code = """class DataProcessor:
|
||||
def __init__(self, config):
|
||||
self.config=config
|
||||
self.data=[]
|
||||
|
|
@ -660,13 +644,13 @@ def test_formatting_class_with_methods():
|
|||
'processed':True
|
||||
})
|
||||
return result
|
||||
'''
|
||||
"""
|
||||
_run_formatting_test(source_code, True)
|
||||
|
||||
|
||||
def test_formatting_with_complex_comprehensions():
|
||||
"""Test files with complex list/dict comprehensions and formatting."""
|
||||
source_code = '''def complex_comprehensions(data):
|
||||
source_code = """def complex_comprehensions(data):
|
||||
# Various comprehension styles with formatting issues
|
||||
result1=[item['value'] for item in data if item.get('active',True) and 'value' in item]
|
||||
|
||||
|
|
@ -683,13 +667,13 @@ def test_formatting_with_complex_comprehensions():
|
|||
'complex':result3,
|
||||
'nested':nested
|
||||
}
|
||||
'''
|
||||
"""
|
||||
_run_formatting_test(source_code, True)
|
||||
|
||||
|
||||
def test_formatting_with_decorators_and_async():
|
||||
"""Test files with decorators and async functions."""
|
||||
source_code = '''import asyncio
|
||||
source_code = """import asyncio
|
||||
from functools import wraps
|
||||
|
||||
def timer_decorator(func):
|
||||
|
|
@ -715,26 +699,26 @@ class AsyncProcessor:
|
|||
@staticmethod
|
||||
async def process_batch(batch):
|
||||
return [{'id':item['id'],'status':'done'} for item in batch if 'id' in item]
|
||||
'''
|
||||
"""
|
||||
_run_formatting_test(source_code, True)
|
||||
|
||||
|
||||
def test_formatting_threshold_configuration():
|
||||
"""Test that the diff threshold can be configured (if supported)."""
|
||||
# This test assumes the threshold might be configurable
|
||||
source_code = '''import json,os,sys
|
||||
source_code = """import json,os,sys
|
||||
def func1():x=1;y=2;return x+y
|
||||
def func2():a=1;b=2;return a+b
|
||||
def func3():c=1;d=2;return c+d
|
||||
'''
|
||||
"""
|
||||
# Test with a file that has moderate formatting issues
|
||||
_run_formatting_test(source_code, True, optimized_function="def func2():a=1;b=2;return a+b")
|
||||
|
||||
|
||||
def test_formatting_empty_file():
|
||||
"""Test formatting of empty or minimal files."""
|
||||
source_code = '''# Just a comment pass
|
||||
'''
|
||||
source_code = """# Just a comment pass
|
||||
"""
|
||||
_run_formatting_test(source_code, False)
|
||||
|
||||
|
||||
|
|
@ -798,6 +782,7 @@ class ProcessorWithDocs:
|
|||
return{'result':[item for item in data if self._is_valid(item)]}"""
|
||||
_run_formatting_test(source_code, True, optimized_function=optimization_function, expected=expected)
|
||||
|
||||
|
||||
def test_sort_imports_skip_file():
|
||||
"""Test that isort skips files with # isort:skip_file."""
|
||||
code = """# isort:skip_file
|
||||
|
|
@ -809,6 +794,7 @@ import sys, os, json # isort will ignore this file completely"""
|
|||
|
||||
# ==================== Tests for format_generated_code ====================
|
||||
|
||||
|
||||
def test_format_generated_code_disabled():
|
||||
"""Test that format_generated_code returns code with normalized newlines when formatter is disabled."""
|
||||
test_code = """import os
|
||||
|
|
@ -889,6 +875,7 @@ def test_function(x, y, z):
|
|||
result = format_generated_code(test_code, ["black $file"])
|
||||
assert result == expected
|
||||
|
||||
|
||||
def test_format_generated_code_with_inference():
|
||||
"""Test format_generated_code with ruff formatter."""
|
||||
try:
|
||||
|
|
@ -1154,6 +1141,7 @@ from inference.core.models.base import Model
|
|||
result = format_generated_code(test_code, ["ruff format $file"])
|
||||
assert result == expected
|
||||
|
||||
|
||||
def test_format_generated_code_with_ruff():
|
||||
"""Test format_generated_code with ruff formatter."""
|
||||
try:
|
||||
|
|
@ -1205,8 +1193,11 @@ def test_format_generated_code_invalid_formatter():
|
|||
|
||||
# Should handle gracefully and return code with normalized newlines
|
||||
result = format_generated_code(test_code, ["nonexistent_formatter $file"])
|
||||
assert result == """def test():
|
||||
assert (
|
||||
result
|
||||
== """def test():
|
||||
pass"""
|
||||
)
|
||||
|
||||
|
||||
def test_format_generated_code_syntax_error():
|
||||
|
|
@ -1217,8 +1208,11 @@ def test_format_generated_code_syntax_error():
|
|||
# Formatter should fail but function should handle it gracefully
|
||||
result = format_generated_code(test_code, ["black $file"])
|
||||
# Should return code with normalized newlines when formatting fails
|
||||
assert result == """def test(: # syntax error
|
||||
assert (
|
||||
result
|
||||
== """def test(: # syntax error
|
||||
pass"""
|
||||
)
|
||||
|
||||
|
||||
def test_format_generated_code_already_formatted():
|
||||
|
|
@ -1272,9 +1266,9 @@ def test_format_generated_code_trailing_whitespace():
|
|||
"""
|
||||
|
||||
result = format_generated_code(test_code, ["black $file"])
|
||||
lines = result.split('\n')
|
||||
lines = result.split("\n")
|
||||
for line in lines:
|
||||
assert line == line.rstrip(), f"Line has trailing whitespace: {repr(line)}"
|
||||
assert line == line.rstrip(), f"Line has trailing whitespace: {line!r}"
@ -1,5 +1,4 @@
import pathlib
from dataclasses import dataclass

import pytest

@ -90,6 +89,7 @@ def recursive_dependency_1(num):
num_1 = calculate_something(num)
return recursive_dependency_1(num) + num_1


from collections import defaultdict

@ -190,6 +190,7 @@ class Graph:
return stack"""
)


def test_recursive_function_context() -> None:
file_path = pathlib.Path(__file__).resolve()

@ -232,4 +233,4 @@ class C:
return 0
num_1 = self.calculate_something_3(num)
return self.recursive(num) + num_1"""
)
)

@ -1,18 +1,16 @@
import tempfile
from pathlib import Path
import os
import unittest.mock
from pathlib import Path

from codeflash.discovery.functions_to_optimize import (
filter_files_optimized,
filter_functions,
find_all_functions_in_file,
get_all_files_and_functions,
get_functions_to_optimize,
inspect_top_level_functions_or_methods,
filter_functions,
get_all_files_and_functions
)
from codeflash.verification.verification_utils import TestConfig
from codeflash.code_utils.compat import codeflash_temp_dir


def test_function_eligible_for_optimization() -> None:

@ -24,10 +22,10 @@ def test_function_eligible_for_optimization() -> None:
|
|||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
temp_dir_path = Path(temp_dir)
|
||||
file_path = temp_dir_path / "test_function.py"
|
||||
|
||||
|
||||
with file_path.open("w") as f:
|
||||
f.write(function)
|
||||
|
||||
|
||||
functions_found = find_all_functions_in_file(file_path)
|
||||
assert functions_found[file_path][0].function_name == "test_function_eligible_for_optimization"
|
||||
|
||||
|
|
@ -40,34 +38,31 @@ def test_function_eligible_for_optimization() -> None:
|
|||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
temp_dir_path = Path(temp_dir)
|
||||
file_path = temp_dir_path / "test_function.py"
|
||||
|
||||
|
||||
with file_path.open("w") as f:
|
||||
f.write(function)
|
||||
|
||||
|
||||
functions_found = find_all_functions_in_file(file_path)
|
||||
assert len(functions_found[file_path]) == 0
|
||||
|
||||
|
||||
# we want to trigger an error in the function discovery
|
||||
function = """def test_invalid_code():"""
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
temp_dir_path = Path(temp_dir)
|
||||
file_path = temp_dir_path / "test_function.py"
|
||||
|
||||
|
||||
with file_path.open("w") as f:
|
||||
f.write(function)
|
||||
|
||||
|
||||
functions_found = find_all_functions_in_file(file_path)
|
||||
assert functions_found == {}
|
||||
|
||||
|
||||
|
||||
|
||||
def test_find_top_level_function_or_method():
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
temp_dir_path = Path(temp_dir)
|
||||
file_path = temp_dir_path / "test_function.py"
|
||||
|
||||
|
||||
with file_path.open("w") as f:
|
||||
f.write(
|
||||
"""def functionA():
|
||||
|
|
@ -93,7 +88,7 @@ def non_classmethod_function(cls, name):
|
|||
return cls.name
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
assert inspect_top_level_functions_or_methods(file_path, "functionA").is_top_level
|
||||
assert not inspect_top_level_functions_or_methods(file_path, "functionB").is_top_level
|
||||
assert inspect_top_level_functions_or_methods(file_path, "functionC", class_name="A").is_top_level
|
||||
|
|
@ -117,20 +112,21 @@ def non_classmethod_function(cls, name):
|
|||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
temp_dir_path = Path(temp_dir)
|
||||
file_path = temp_dir_path / "test_function.py"
|
||||
|
||||
|
||||
with file_path.open("w") as f:
|
||||
f.write(
|
||||
"""def functionA():
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
assert not inspect_top_level_functions_or_methods(file_path, "functionA")
|
||||
|
||||
|
||||
def test_class_method_discovery():
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
temp_dir_path = Path(temp_dir)
|
||||
file_path = temp_dir_path / "test_function.py"
|
||||
|
||||
|
||||
with file_path.open("w") as f:
|
||||
f.write(
|
||||
"""class A:
|
||||
|
|
@ -146,7 +142,7 @@ class X:
|
|||
def functionA():
|
||||
return True"""
|
||||
)
|
||||
|
||||
|
||||
test_config = TestConfig(
|
||||
tests_root="tests", project_root_path=".", test_framework="pytest", tests_project_rootdir=Path()
|
||||
)
|
||||
|
|
@ -202,10 +198,10 @@ def test_nested_function():
|
|||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
temp_dir_path = Path(temp_dir)
|
||||
file_path = temp_dir_path / "test_function.py"
|
||||
|
||||
|
||||
with file_path.open("w") as f:
|
||||
f.write(
|
||||
"""
|
||||
"""
|
||||
import copy
|
||||
|
||||
def propagate_attributes(
|
||||
|
|
@ -249,7 +245,7 @@ def propagate_attributes(
|
|||
return modified_nodes
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
test_config = TestConfig(
|
||||
tests_root="tests", project_root_path=".", test_framework="pytest", tests_project_rootdir=Path()
|
||||
)
|
||||
|
|
@ -270,10 +266,10 @@ def propagate_attributes(
|
|||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
temp_dir_path = Path(temp_dir)
|
||||
file_path = temp_dir_path / "test_function.py"
|
||||
|
||||
|
||||
with file_path.open("w") as f:
|
||||
f.write(
|
||||
"""
|
||||
"""
|
||||
def outer_function():
|
||||
def inner_function():
|
||||
pass
|
||||
|
|
@ -281,7 +277,7 @@ def outer_function():
|
|||
return inner_function
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
test_config = TestConfig(
|
||||
tests_root="tests", project_root_path=".", test_framework="pytest", tests_project_rootdir=Path()
|
||||
)
|
||||
|
|
@ -302,10 +298,10 @@ def outer_function():
|
|||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
temp_dir_path = Path(temp_dir)
|
||||
file_path = temp_dir_path / "test_function.py"
|
||||
|
||||
|
||||
with file_path.open("w") as f:
|
||||
f.write(
|
||||
"""
|
||||
"""
|
||||
def outer_function():
|
||||
def inner_function():
|
||||
pass
|
||||
|
|
@ -315,7 +311,7 @@ def outer_function():
|
|||
return inner_function, another_inner_function
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
test_config = TestConfig(
|
||||
tests_root="tests", project_root_path=".", test_framework="pytest", tests_project_rootdir=Path()
|
||||
)
|
||||
|
|
@ -349,15 +345,16 @@ def test_filter_files_optimized():
|
|||
assert filter_files_optimized(file_path_different_level, tests_root, ignore_paths, module_root)
|
||||
assert not filter_files_optimized(file_path_above_level, tests_root, ignore_paths, module_root)
|
||||
|
||||
|
||||
def test_filter_functions():
|
||||
with tempfile.TemporaryDirectory() as temp_dir_str:
|
||||
temp_dir = Path(temp_dir_str)
|
||||
|
||||
|
||||
# Create a test file in the temporary directory
|
||||
test_file_path = temp_dir.joinpath("test_get_functions_to_optimize.py")
|
||||
with test_file_path.open("w") as f:
|
||||
f.write(
|
||||
"""
|
||||
"""
|
||||
import copy
|
||||
|
||||
def propagate_attributes(
|
||||
|
|
@ -407,14 +404,15 @@ def not_in_checkpoint_function():
|
|||
return "This function is not in the checkpoint."
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
discovered = find_all_functions_in_file(test_file_path)
|
||||
modified_functions = {test_file_path: discovered[test_file_path]}
|
||||
# Use an absolute path for tests_root that won't match the temp directory
|
||||
# This avoids path resolution issues in CI where the working directory might differ
|
||||
tests_root_absolute = (temp_dir.parent / "nonexistent_tests_dir").resolve()
|
||||
with unittest.mock.patch("codeflash.discovery.functions_to_optimize.get_blocklisted_functions", return_value={}):
|
||||
with unittest.mock.patch(
|
||||
"codeflash.discovery.functions_to_optimize.get_blocklisted_functions", return_value={}
|
||||
):
|
||||
filtered, count = filter_functions(
|
||||
modified_functions,
|
||||
tests_root=tests_root_absolute,
|
||||
|
|
@ -429,19 +427,19 @@ def not_in_checkpoint_function():
|
|||
# Create a tests directory inside our temp directory
|
||||
tests_root_dir = temp_dir.joinpath("tests")
|
||||
tests_root_dir.mkdir(exist_ok=True)
|
||||
|
||||
|
||||
test_file_path = tests_root_dir.joinpath("test_functions.py")
|
||||
with test_file_path.open("w") as f:
|
||||
f.write(
|
||||
"""
|
||||
"""
|
||||
def test_function_in_tests_dir():
|
||||
return "This function is in a test directory and should be filtered out."
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
discovered_test_file = find_all_functions_in_file(test_file_path)
|
||||
modified_functions_test = {test_file_path: discovered_test_file.get(test_file_path, [])}
|
||||
|
||||
|
||||
filtered_test_file, count_test_file = filter_functions(
|
||||
modified_functions_test,
|
||||
tests_root=tests_root_dir,
|
||||
|
|
@ -449,7 +447,7 @@ def test_function_in_tests_dir():
|
|||
project_root=temp_dir,
|
||||
module_root=temp_dir,
|
||||
)
|
||||
|
||||
|
||||
assert not filtered_test_file
|
||||
assert count_test_file == 0
|
||||
|
||||
|
|
@ -459,7 +457,7 @@ def test_function_in_tests_dir():
|
|||
ignored_file_path = ignored_dir.joinpath("ignored_file.py")
|
||||
with ignored_file_path.open("w") as f:
|
||||
f.write("def ignored_func(): return 1")
|
||||
|
||||
|
||||
discovered_ignored = find_all_functions_in_file(ignored_file_path)
|
||||
modified_functions_ignored = {ignored_file_path: discovered_ignored.get(ignored_file_path, [])}
|
||||
|
||||
|
|
@ -474,17 +472,19 @@ def test_function_in_tests_dir():
|
|||
assert count_ignored == 0
|
||||
|
||||
# Test submodule paths
|
||||
with unittest.mock.patch("codeflash.discovery.functions_to_optimize.ignored_submodule_paths",
|
||||
return_value=[str(temp_dir.joinpath("submodule_dir"))]):
|
||||
with unittest.mock.patch(
|
||||
"codeflash.discovery.functions_to_optimize.ignored_submodule_paths",
|
||||
return_value=[str(temp_dir.joinpath("submodule_dir"))],
|
||||
):
|
||||
submodule_dir = temp_dir.joinpath("submodule_dir")
|
||||
submodule_dir.mkdir(exist_ok=True)
|
||||
submodule_file_path = submodule_dir.joinpath("submodule_file.py")
|
||||
with submodule_file_path.open("w") as f:
|
||||
f.write("def submodule_func(): return 1")
|
||||
|
||||
|
||||
discovered_submodule = find_all_functions_in_file(submodule_file_path)
|
||||
modified_functions_submodule = {submodule_file_path: discovered_submodule.get(submodule_file_path, [])}
|
||||
|
||||
|
||||
filtered_submodule, count_submodule = filter_functions(
|
||||
modified_functions_submodule,
|
||||
tests_root=Path("tests"),
|
||||
|
|
@ -496,14 +496,17 @@ def test_function_in_tests_dir():
|
|||
assert count_submodule == 0
|
||||
|
||||
# Test site packages
|
||||
with unittest.mock.patch("codeflash.discovery.functions_to_optimize.path_belongs_to_site_packages",
|
||||
return_value=True):
|
||||
with unittest.mock.patch(
|
||||
"codeflash.discovery.functions_to_optimize.path_belongs_to_site_packages", return_value=True
|
||||
):
|
||||
site_package_file_path = temp_dir.joinpath("site_package_file.py")
|
||||
with site_package_file_path.open("w") as f:
|
||||
f.write("def site_package_func(): return 1")
|
||||
|
||||
discovered_site_package = find_all_functions_in_file(site_package_file_path)
|
||||
modified_functions_site_package = {site_package_file_path: discovered_site_package.get(site_package_file_path, [])}
|
||||
modified_functions_site_package = {
|
||||
site_package_file_path: discovered_site_package.get(site_package_file_path, [])
|
||||
}
|
||||
|
||||
filtered_site_package, count_site_package = filter_functions(
|
||||
modified_functions_site_package,
|
||||
|
|
@ -514,16 +517,18 @@ def test_function_in_tests_dir():
|
|||
)
|
||||
assert not filtered_site_package
|
||||
assert count_site_package == 0
|
||||
|
||||
|
||||
# Test outside module root
|
||||
parent_dir = temp_dir.parent
|
||||
outside_module_root_path = parent_dir.joinpath("outside_module_root_file.py")
|
||||
try:
|
||||
with outside_module_root_path.open("w") as f:
|
||||
f.write("def func_outside_module_root(): return 1")
|
||||
|
||||
|
||||
discovered_outside_module = find_all_functions_in_file(outside_module_root_path)
|
||||
modified_functions_outside_module = {outside_module_root_path: discovered_outside_module.get(outside_module_root_path, [])}
|
||||
modified_functions_outside_module = {
|
||||
outside_module_root_path: discovered_outside_module.get(outside_module_root_path, [])
|
||||
}
|
||||
|
||||
filtered_outside_module, count_outside_module = filter_functions(
|
||||
modified_functions_outside_module,
|
||||
|
|
@ -543,8 +548,10 @@ def test_function_in_tests_dir():
|
|||
f.write("def func_in_invalid_module(): return 1")
|
||||
|
||||
discovered_invalid_module = find_all_functions_in_file(invalid_module_file_path)
|
||||
modified_functions_invalid_module = {invalid_module_file_path: discovered_invalid_module.get(invalid_module_file_path, [])}
|
||||
|
||||
modified_functions_invalid_module = {
|
||||
invalid_module_file_path: discovered_invalid_module.get(invalid_module_file_path, [])
|
||||
}
|
||||
|
||||
filtered_invalid_module, count_invalid_module = filter_functions(
|
||||
modified_functions_invalid_module,
|
||||
tests_root=Path("tests"),
|
||||
|
|
@ -556,8 +563,10 @@ def test_function_in_tests_dir():
|
|||
assert count_invalid_module == 0
|
||||
|
||||
original_file_path = temp_dir.joinpath("test_get_functions_to_optimize.py")
|
||||
with unittest.mock.patch("codeflash.discovery.functions_to_optimize.get_blocklisted_functions",
|
||||
return_value={original_file_path.name: {"propagate_attributes", "other_blocklisted_function"}}):
|
||||
with unittest.mock.patch(
|
||||
"codeflash.discovery.functions_to_optimize.get_blocklisted_functions",
|
||||
return_value={original_file_path.name: {"propagate_attributes", "other_blocklisted_function"}},
|
||||
):
|
||||
filtered_funcs, count = filter_functions(
|
||||
modified_functions,
|
||||
tests_root=Path("tests"),
|
||||
|
|
@ -571,15 +580,20 @@ def test_function_in_tests_dir():
|
|||
module_name = "test_get_functions_to_optimize"
|
||||
qualified_name_for_checkpoint = f"{module_name}.propagate_attributes"
|
||||
other_qualified_name_for_checkpoint = f"{module_name}.vanilla_function"
|
||||
|
||||
with unittest.mock.patch("codeflash.discovery.functions_to_optimize.get_blocklisted_functions", return_value={}):
|
||||
|
||||
with unittest.mock.patch(
|
||||
"codeflash.discovery.functions_to_optimize.get_blocklisted_functions", return_value={}
|
||||
):
|
||||
filtered_checkpoint, count_checkpoint = filter_functions(
|
||||
modified_functions,
|
||||
tests_root=Path("tests"),
|
||||
ignore_paths=[],
|
||||
project_root=temp_dir,
|
||||
module_root=temp_dir,
|
||||
previous_checkpoint_functions={qualified_name_for_checkpoint: {"status": "optimized"}, other_qualified_name_for_checkpoint: {}}
|
||||
previous_checkpoint_functions={
|
||||
qualified_name_for_checkpoint: {"status": "optimized"},
|
||||
other_qualified_name_for_checkpoint: {},
|
||||
},
|
||||
)
|
||||
assert filtered_checkpoint.get(original_file_path)
|
||||
assert count_checkpoint == 1
|
||||
|
|
@ -589,4 +603,4 @@ def test_function_in_tests_dir():
|
|||
assert "propagate_attributes" not in remaining_functions
|
||||
assert "vanilla_function" not in remaining_functions
|
||||
files_and_funcs = get_all_files_and_functions(module_root_path=temp_dir, ignore_paths=[])
|
||||
assert len(files_and_funcs) == 6
|
||||
assert len(files_and_funcs) == 6
@ -1,10 +1,9 @@
import pytest
from pathlib import Path
from unittest.mock import patch

import pytest

from codeflash.benchmarking.function_ranker import FunctionRanker
from codeflash.discovery.functions_to_optimize import FunctionToOptimize, find_all_functions_in_file
from codeflash.models.models import FunctionParent
from codeflash.discovery.functions_to_optimize import find_all_functions_in_file


@pytest.fixture

@ -36,25 +35,32 @@ def test_function_ranker_initialization(trace_file):
|
|||
|
||||
def test_load_function_stats(function_ranker):
|
||||
assert len(function_ranker._function_stats) > 0
|
||||
|
||||
|
||||
# Check that funcA is loaded with expected structure
|
||||
func_a_key = None
|
||||
for key, stats in function_ranker._function_stats.items():
|
||||
if stats["function_name"] == "funcA":
|
||||
func_a_key = key
|
||||
break
|
||||
|
||||
|
||||
assert func_a_key is not None
|
||||
func_a_stats = function_ranker._function_stats[func_a_key]
|
||||
|
||||
|
||||
# Verify funcA stats structure
|
||||
expected_keys = {
|
||||
"filename", "function_name", "qualified_name", "class_name",
|
||||
"line_number", "call_count", "own_time_ns", "cumulative_time_ns",
|
||||
"time_in_callees_ns", "addressable_time_ns"
|
||||
"filename",
|
||||
"function_name",
|
||||
"qualified_name",
|
||||
"class_name",
|
||||
"line_number",
|
||||
"call_count",
|
||||
"own_time_ns",
|
||||
"cumulative_time_ns",
|
||||
"time_in_callees_ns",
|
||||
"addressable_time_ns",
|
||||
}
|
||||
assert set(func_a_stats.keys()) == expected_keys
|
||||
|
||||
|
||||
# Verify funcA specific values
|
||||
assert func_a_stats["function_name"] == "funcA"
|
||||
assert func_a_stats["call_count"] == 1
|
||||
|
|
@ -68,7 +74,7 @@ def test_get_function_addressable_time(function_ranker, workload_functions):
|
|||
if func.function_name == "funcA":
|
||||
func_a = func
|
||||
break
|
||||
|
||||
|
||||
assert func_a is not None
|
||||
addressable_time = function_ranker.get_function_addressable_time(func_a)
|
||||
|
||||
|
|
@ -79,15 +85,15 @@ def test_get_function_addressable_time(function_ranker, workload_functions):
|
|||
|
||||
def test_rank_functions(function_ranker, workload_functions):
|
||||
ranked_functions = function_ranker.rank_functions(workload_functions)
|
||||
|
||||
|
||||
# Should filter out functions below importance threshold and sort by addressable time
|
||||
assert len(ranked_functions) <= len(workload_functions)
|
||||
assert len(ranked_functions) > 0 # At least some functions should pass the threshold
|
||||
|
||||
|
||||
# funcA should pass the importance threshold
|
||||
func_a_in_results = any(f.function_name == "funcA" for f in ranked_functions)
|
||||
assert func_a_in_results
|
||||
|
||||
|
||||
# Verify functions are sorted by addressable time in descending order
|
||||
for i in range(len(ranked_functions) - 1):
|
||||
current_time = function_ranker.get_function_addressable_time(ranked_functions[i])
|
||||
|
|
@ -101,10 +107,10 @@ def test_get_function_stats_summary(function_ranker, workload_functions):
|
|||
if func.function_name == "funcA":
|
||||
func_a = func
|
||||
break
|
||||
|
||||
|
||||
assert func_a is not None
|
||||
stats = function_ranker.get_function_stats_summary(func_a)
|
||||
|
||||
|
||||
assert stats is not None
|
||||
assert stats["function_name"] == "funcA"
|
||||
assert stats["own_time_ns"] == 153000
|
||||
|
|
@ -112,24 +118,19 @@ def test_get_function_stats_summary(function_ranker, workload_functions):
|
|||
assert stats["addressable_time_ns"] == 1324000
|
||||
|
||||
|
||||
|
||||
|
||||
def test_importance_calculation(function_ranker):
|
||||
total_program_time = sum(
|
||||
s["own_time_ns"] for s in function_ranker._function_stats.values()
|
||||
if s.get("own_time_ns", 0) > 0
|
||||
s["own_time_ns"] for s in function_ranker._function_stats.values() if s.get("own_time_ns", 0) > 0
|
||||
)
|
||||
|
||||
|
||||
func_a_stats = None
|
||||
for stats in function_ranker._function_stats.values():
|
||||
if stats["function_name"] == "funcA":
|
||||
func_a_stats = stats
|
||||
break
|
||||
|
||||
|
||||
assert func_a_stats is not None
|
||||
importance = func_a_stats["own_time_ns"] / total_program_time
|
||||
|
||||
# funcA importance should be approximately 1.9% (153000/7958000)
|
||||
assert abs(importance - 0.019) < 0.01
@ -1,10 +1,12 @@
import tempfile
from pathlib import Path

import pytest

from codeflash.code_utils.code_extractor import get_code
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.models.models import FunctionParent
import pytest
from pathlib import Path


@pytest.fixture
def temp_dir():

@ -276,4 +278,4 @@ class CustomDataClass:
[FunctionToOptimize("name", f.name, [FunctionParent("CustomDataClass", "ClassDef")])]
)
assert new_code is None
assert contextual_dunder_methods == set()
assert contextual_dunder_methods == set()

@ -242,8 +242,8 @@ class _PersistentCache(Generic[_P, _R, _CacheBackendT]):
|
|||
code_context = ctx_result.unwrap()
|
||||
assert code_context.helper_functions[0].qualified_name == "AbstractCacheBackend.get_cache_or_call"
|
||||
assert (
|
||||
code_context.testgen_context.flat
|
||||
== f'''# file: {file_path.relative_to(project_root_path)}
|
||||
code_context.testgen_context.flat
|
||||
== f'''# file: {file_path.relative_to(project_root_path)}
|
||||
_P = ParamSpec("_P")
|
||||
_KEY_T = TypeVar("_KEY_T")
|
||||
_STORE_T = TypeVar("_STORE_T")
|
||||
|
|
@ -412,8 +412,8 @@ def test_bubble_sort_deps() -> None:
|
|||
pytest.fail()
|
||||
code_context = ctx_result.unwrap()
|
||||
assert (
|
||||
code_context.testgen_context.flat
|
||||
== f"""{get_code_block_splitter(Path("code_to_optimize/bubble_sort_dep1_helper.py"))}
|
||||
code_context.testgen_context.flat
|
||||
== f"""{get_code_block_splitter(Path("code_to_optimize/bubble_sort_dep1_helper.py"))}
|
||||
def dep1_comparer(arr, j: int) -> bool:
|
||||
return arr[j] > arr[j + 1]
|
||||
|
||||
|
|
@ -438,7 +438,7 @@ def sorter_deps(arr):
|
|||
)
|
||||
assert len(code_context.helper_functions) == 2
|
||||
assert (
|
||||
code_context.helper_functions[0].fully_qualified_name
|
||||
== "code_to_optimize.bubble_sort_dep1_helper.dep1_comparer"
|
||||
code_context.helper_functions[0].fully_qualified_name
|
||||
== "code_to_optimize.bubble_sort_dep1_helper.dep1_comparer"
|
||||
)
|
||||
assert code_context.helper_functions[1].fully_qualified_name == "code_to_optimize.bubble_sort_dep2_swap.dep2_swap"
|
||||
assert code_context.helper_functions[1].fully_qualified_name == "code_to_optimize.bubble_sort_dep2_swap.dep2_swap"
|
||||
|
|
|
|||
|
|
@ -73,7 +73,9 @@ def test_dunder_methods_remove_docstring() -> None:
|
|||
return f"Value: {self.x}"
|
||||
"""
|
||||
|
||||
output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_ONLY, {"TestClass.target_method"}, set(), remove_docstrings=True)
|
||||
output = parse_code_and_prune_cst(
|
||||
dedent(code), CodeContextType.READ_ONLY, {"TestClass.target_method"}, set(), remove_docstrings=True
|
||||
)
|
||||
assert dedent(expected).strip() == output.strip()
|
||||
|
||||
|
||||
|
|
@ -98,7 +100,9 @@ def test_class_remove_docstring() -> None:
|
|||
return f"Value: {self.x}"
|
||||
"""
|
||||
|
||||
output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_ONLY, {"TestClass.target_method"}, set(), remove_docstrings=True)
|
||||
output = parse_code_and_prune_cst(
|
||||
dedent(code), CodeContextType.READ_ONLY, {"TestClass.target_method"}, set(), remove_docstrings=True
|
||||
)
|
||||
assert dedent(expected).strip() == output.strip()
|
||||
|
||||
|
||||
|
|
@ -125,7 +129,9 @@ def test_mixed_remove_docstring() -> None:
|
|||
return f"Value: {self.x}"
|
||||
"""
|
||||
|
||||
output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_ONLY, {"TestClass.target_method"}, set(), remove_docstrings=True)
|
||||
output = parse_code_and_prune_cst(
|
||||
dedent(code), CodeContextType.READ_ONLY, {"TestClass.target_method"}, set(), remove_docstrings=True
|
||||
)
|
||||
assert dedent(expected).strip() == output.strip()
|
||||
|
||||
|
||||
|
|
@ -204,7 +210,9 @@ def test_multiple_top_level_targets() -> None:
|
|||
expected = """
|
||||
"""
|
||||
|
||||
output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_ONLY, {"TestClass.target1", "TestClass.target2"}, set())
|
||||
output = parse_code_and_prune_cst(
|
||||
dedent(code), CodeContextType.READ_ONLY, {"TestClass.target1", "TestClass.target2"}, set()
|
||||
)
|
||||
assert dedent(expected).strip() == output.strip()
|
||||
|
||||
|
||||
|
|
@ -665,7 +673,9 @@ def test_simplified_complete_implementation() -> None:
|
|||
pass
|
||||
"""
|
||||
|
||||
output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_ONLY, {"DataProcessor.target_method", "ResultHandler.target_method"}, set())
|
||||
output = parse_code_and_prune_cst(
|
||||
dedent(code), CodeContextType.READ_ONLY, {"DataProcessor.target_method", "ResultHandler.target_method"}, set()
|
||||
)
|
||||
assert dedent(expected).strip() == output.strip()
|
||||
|
||||
|
||||
|
|
@ -753,6 +763,10 @@ def test_simplified_complete_implementation_no_docstring() -> None:
|
|||
"""
|
||||
|
||||
output = parse_code_and_prune_cst(
|
||||
dedent(code), CodeContextType.READ_ONLY, {"DataProcessor.target_method", "ResultHandler.target_method"}, set(), remove_docstrings=True
|
||||
dedent(code),
|
||||
CodeContextType.READ_ONLY,
|
||||
{"DataProcessor.target_method", "ResultHandler.target_method"},
|
||||
set(),
|
||||
remove_docstrings=True,
|
||||
)
|
||||
assert dedent(expected).strip() == output.strip()
@ -1,7 +1,8 @@
from textwrap import dedent

import pytest
from codeflash.context.code_context_extractor import parse_code_and_prune_cst

from codeflash.context.code_context_extractor import parse_code_and_prune_cst
from codeflash.models.models import CodeContextType


@ -12,7 +13,7 @@ def test_simple_function() -> None:
|
|||
y = 2
|
||||
return x + y
|
||||
"""
|
||||
result = parse_code_and_prune_cst(dedent(code),CodeContextType.READ_WRITABLE, {"target_function"})
|
||||
result = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_WRITABLE, {"target_function"})
|
||||
|
||||
expected = dedent("""
|
||||
def target_function():
|
||||
|
|
@ -55,7 +56,7 @@ def test_class_with_attributes() -> None:
|
|||
def other_method(self):
|
||||
print("this should be excluded")
|
||||
"""
|
||||
result = parse_code_and_prune_cst(dedent(code),CodeContextType.READ_WRITABLE, {"MyClass.target_method"})
|
||||
result = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_WRITABLE, {"MyClass.target_method"})
|
||||
|
||||
expected = dedent("""
|
||||
class MyClass:
|
||||
|
|
@ -79,7 +80,7 @@ def test_basic_class_structure() -> None:
|
|||
def not_findable(self):
|
||||
return 42
|
||||
"""
|
||||
result = parse_code_and_prune_cst(dedent(code),CodeContextType.READ_WRITABLE, {"Outer.target_method"})
|
||||
result = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_WRITABLE, {"Outer.target_method"})
|
||||
|
||||
expected = dedent("""
|
||||
class Outer:
|
||||
|
|
@ -99,7 +100,7 @@ def test_top_level_targets() -> None:
|
|||
def target_function():
|
||||
return 42
|
||||
"""
|
||||
result = parse_code_and_prune_cst(dedent(code),CodeContextType.READ_WRITABLE, {"target_function"})
|
||||
result = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_WRITABLE, {"target_function"})
|
||||
|
||||
expected = dedent("""
|
||||
def target_function():
|
||||
|
|
@ -122,7 +123,7 @@ def test_multiple_top_level_classes() -> None:
|
|||
def process(self):
|
||||
return "C"
|
||||
"""
|
||||
result = parse_code_and_prune_cst(dedent(code),CodeContextType.READ_WRITABLE, {"ClassA.process", "ClassC.process"})
|
||||
result = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_WRITABLE, {"ClassA.process", "ClassC.process"})
|
||||
|
||||
expected = dedent("""
|
||||
class ClassA:
|
||||
|
|
@ -147,7 +148,7 @@ def test_try_except_structure() -> None:
|
|||
def handle_error(self):
|
||||
print("error")
|
||||
"""
|
||||
result = parse_code_and_prune_cst(dedent(code),CodeContextType.READ_WRITABLE, {"TargetClass.target_method"})
|
||||
result = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_WRITABLE, {"TargetClass.target_method"})
|
||||
|
||||
expected = dedent("""
|
||||
try:
|
||||
|
|
@ -174,7 +175,7 @@ def test_init_method() -> None:
|
|||
def target_method(self):
|
||||
return f"Value: {self.x}"
|
||||
"""
|
||||
result = parse_code_and_prune_cst(dedent(code),CodeContextType.READ_WRITABLE, {"MyClass.target_method"})
|
||||
result = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_WRITABLE, {"MyClass.target_method"})
|
||||
|
||||
expected = dedent("""
|
||||
class MyClass:
|
||||
|
|
@ -186,6 +187,7 @@ def test_init_method() -> None:
|
|||
""")
|
||||
assert result.strip() == expected.strip()
|
||||
|
||||
|
||||
def test_dunder_method() -> None:
|
||||
code = """
|
||||
class MyClass:
|
||||
|
|
@ -198,7 +200,7 @@ def test_dunder_method() -> None:
|
|||
def target_method(self):
|
||||
return f"Value: {self.x}"
|
||||
"""
|
||||
result = parse_code_and_prune_cst(dedent(code),CodeContextType.READ_WRITABLE, {"MyClass.target_method"})
|
||||
result = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_WRITABLE, {"MyClass.target_method"})
|
||||
|
||||
expected = dedent("""
|
||||
class MyClass:
|
||||
|
|
@ -208,6 +210,7 @@ def test_dunder_method() -> None:
|
|||
""")
|
||||
assert result.strip() == expected.strip()
|
||||
|
||||
|
||||
def test_no_targets_found() -> None:
|
||||
code = """
|
||||
class MyClass:
|
||||
|
|
@ -218,7 +221,7 @@ def test_no_targets_found() -> None:
|
|||
def target(self):
|
||||
pass
|
||||
"""
|
||||
result = parse_code_and_prune_cst(dedent(code),CodeContextType.READ_WRITABLE, {"MyClass.Inner.target"})
|
||||
result = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_WRITABLE, {"MyClass.Inner.target"})
|
||||
expected = dedent("""
|
||||
class MyClass:
|
||||
def method(self):
|
||||
|
|
@ -239,7 +242,7 @@ def test_no_targets_found_raises_for_nonexistent() -> None:
|
|||
pass
|
||||
"""
|
||||
with pytest.raises(ValueError, match="No target functions found in the provided code"):
|
||||
parse_code_and_prune_cst(dedent(code),CodeContextType.READ_WRITABLE, {"NonExistent.target"})
|
||||
parse_code_and_prune_cst(dedent(code), CodeContextType.READ_WRITABLE, {"NonExistent.target"})
|
||||
|
||||
|
||||
def test_module_var() -> None:
|
||||
|
|
@ -263,7 +266,5 @@ def test_module_var() -> None:
|
|||
var2 = "test"
|
||||
"""
|
||||
|
||||
output = parse_code_and_prune_cst(dedent(code),CodeContextType.READ_WRITABLE, {"target_function"})
|
||||
output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_WRITABLE, {"target_function"})
|
||||
assert dedent(expected).strip() == output.strip()
@ -2,8 +2,9 @@ from textwrap import dedent
import pytest

from codeflash.models.models import CodeContextType
from codeflash.context.code_context_extractor import parse_code_and_prune_cst
from codeflash.models.models import CodeContextType


def test_simple_function() -> None:
code = """

@ -22,6 +23,7 @@ def test_simple_function() -> None:
|
|||
"""
|
||||
assert dedent(expected).strip() == result.strip()
|
||||
|
||||
|
||||
def test_basic_class() -> None:
|
||||
code = """
|
||||
class TestClass:
|
||||
|
|
@ -45,6 +47,7 @@ def test_basic_class() -> None:
|
|||
output = parse_code_and_prune_cst(dedent(code), CodeContextType.TESTGEN, {"TestClass.target_method"}, set())
|
||||
assert dedent(expected).strip() == output.strip()
|
||||
|
||||
|
||||
def test_dunder_methods() -> None:
|
||||
code = """
|
||||
class TestClass:
|
||||
|
|
@ -102,7 +105,9 @@ def test_dunder_methods_remove_docstring() -> None:
|
|||
print("include me")
|
||||
"""
|
||||
|
||||
output = parse_code_and_prune_cst(dedent(code), CodeContextType.TESTGEN, {"TestClass.target_method"}, set(), remove_docstrings=True)
|
||||
output = parse_code_and_prune_cst(
|
||||
dedent(code), CodeContextType.TESTGEN, {"TestClass.target_method"}, set(), remove_docstrings=True
|
||||
)
|
||||
assert dedent(expected).strip() == output.strip()
|
||||
|
||||
|
||||
|
|
@ -132,7 +137,9 @@ def test_class_remove_docstring() -> None:
|
|||
print("include me")
|
||||
"""
|
||||
|
||||
output = parse_code_and_prune_cst(dedent(code), CodeContextType.TESTGEN, {"TestClass.target_method"}, set(), remove_docstrings=True)
|
||||
output = parse_code_and_prune_cst(
|
||||
dedent(code), CodeContextType.TESTGEN, {"TestClass.target_method"}, set(), remove_docstrings=True
|
||||
)
|
||||
assert dedent(expected).strip() == output.strip()
|
||||
|
||||
|
||||
|
|
@ -152,6 +159,7 @@ def test_target_in_nested_class() -> None:
|
|||
with pytest.raises(ValueError, match="No target functions found in the provided code"):
|
||||
parse_code_and_prune_cst(dedent(code), CodeContextType.TESTGEN, {"Outer.Inner.target_method"}, set())
|
||||
|
||||
|
||||
def test_method_signatures() -> None:
|
||||
code = """
|
||||
class TestClass:
|
||||
|
|
@ -175,6 +183,8 @@ def test_method_signatures() -> None:
|
|||
|
||||
output = parse_code_and_prune_cst(dedent(code), CodeContextType.TESTGEN, {"TestClass.target_method"}, set())
|
||||
assert dedent(expected).strip() == output.strip()
|
||||
|
||||
|
||||
def test_multiple_top_level_targets() -> None:
|
||||
code = """
|
||||
class TestClass:
|
||||
|
|
@ -203,7 +213,9 @@ def test_multiple_top_level_targets() -> None:
|
|||
self.x = 42
|
||||
"""
|
||||
|
||||
output = parse_code_and_prune_cst(dedent(code), CodeContextType.TESTGEN, {"TestClass.target1", "TestClass.target2"}, set())
|
||||
output = parse_code_and_prune_cst(
|
||||
dedent(code), CodeContextType.TESTGEN, {"TestClass.target1", "TestClass.target2"}, set()
|
||||
)
|
||||
assert dedent(expected).strip() == output.strip()
|
||||
|
||||
|
||||
|
|
@ -229,6 +241,7 @@ def test_class_annotations() -> None:
|
|||
output = parse_code_and_prune_cst(dedent(code), CodeContextType.TESTGEN, {"TestClass.target_method"}, set())
|
||||
assert dedent(expected).strip() == output.strip()
|
||||
|
||||
|
||||
def test_class_annotations_if() -> None:
|
||||
code = """
|
||||
if True:
|
||||
|
|
@ -345,6 +358,7 @@ def test_module_var() -> None:
|
|||
output = parse_code_and_prune_cst(dedent(code), CodeContextType.TESTGEN, {"target_function"}, set())
|
||||
assert dedent(expected).strip() == output.strip()
|
||||
|
||||
|
||||
def test_module_var_if() -> None:
|
||||
code = """
|
||||
def target_function(self) -> None:
|
||||
|
|
@ -374,6 +388,7 @@ def test_module_var_if() -> None:
|
|||
output = parse_code_and_prune_cst(dedent(code), CodeContextType.TESTGEN, {"target_function"}, set())
|
||||
assert dedent(expected).strip() == output.strip()
|
||||
|
||||
|
||||
def test_multiple_classes() -> None:
|
||||
code = """
|
||||
class ClassA:
|
||||
|
|
@ -399,7 +414,9 @@ def test_multiple_classes() -> None:
|
|||
return "C"
|
||||
"""
|
||||
|
||||
output = parse_code_and_prune_cst(dedent(code), CodeContextType.TESTGEN, {"ClassA.process", "ClassC.process"}, set())
|
||||
output = parse_code_and_prune_cst(
|
||||
dedent(code), CodeContextType.TESTGEN, {"ClassA.process", "ClassC.process"}, set()
|
||||
)
|
||||
assert dedent(expected).strip() == output.strip()
|
||||
|
||||
|
||||
|
|
@ -518,6 +535,7 @@ def test_async_with_try_except() -> None:
|
|||
output = parse_code_and_prune_cst(dedent(code), CodeContextType.TESTGEN, {"TestClass.target_method"}, set())
|
||||
assert dedent(expected).strip() == output.strip()
|
||||
|
||||
|
||||
def test_simplified_complete_implementation() -> None:
|
||||
code = """
|
||||
class DataProcessor:
|
||||
|
|
@ -639,7 +657,9 @@ def test_simplified_complete_implementation() -> None:
|
|||
raise RuntimeError(f"Failed to initialize: {self.error}")
|
||||
"""
|
||||
|
||||
output = parse_code_and_prune_cst(dedent(code), CodeContextType.TESTGEN, {"DataProcessor.target_method", "ResultHandler.target_method"}, set())
|
||||
output = parse_code_and_prune_cst(
|
||||
dedent(code), CodeContextType.TESTGEN, {"DataProcessor.target_method", "ResultHandler.target_method"}, set()
|
||||
)
|
||||
assert dedent(expected).strip() == output.strip()
|
||||
|
||||
|
||||
|
|
@ -740,6 +760,10 @@ def test_simplified_complete_implementation_no_docstring() -> None:
|
|||
"""
|
||||
|
||||
output = parse_code_and_prune_cst(
|
||||
dedent(code), CodeContextType.TESTGEN, {"DataProcessor.target_method", "ResultHandler.target_method"}, set(), remove_docstrings=True
|
||||
dedent(code),
|
||||
CodeContextType.TESTGEN,
|
||||
{"DataProcessor.target_method", "ResultHandler.target_method"},
|
||||
set(),
|
||||
remove_docstrings=True,
|
||||
)
|
||||
assert dedent(expected).strip() == output.strip()
@ -1,7 +1,7 @@
from codeflash.code_utils.time_utils import humanize_runtime, format_time
from codeflash.code_utils.time_utils import format_perf
import pytest

from codeflash.code_utils.time_utils import format_perf, format_time, humanize_runtime


def test_humanize_runtime():
assert humanize_runtime(0) == "0.00 nanoseconds"

@ -140,19 +140,22 @@ class TestFormatTime:
|
|||
assert format_time(3_600_000_000_000) == "3600s" # 1 hour
|
||||
assert format_time(86_400_000_000_000) == "86400s" # 1 day
|
||||
|
||||
@pytest.mark.parametrize("nanoseconds,expected", [
|
||||
(0, "0ns"),
|
||||
(42, "42ns"),
|
||||
(1_500, "1.50μs"),
|
||||
(25_000, "25.0μs"),
|
||||
(150_000, "150μs"),
|
||||
(2_500_000, "2.50ms"),
|
||||
(45_000_000, "45.0ms"),
|
||||
(200_000_000, "200ms"),
|
||||
(3_500_000_000, "3.50s"),
|
||||
(75_000_000_000, "75.0s"),
|
||||
(300_000_000_000, "300s"),
|
||||
])
|
||||
@pytest.mark.parametrize(
|
||||
"nanoseconds,expected",
|
||||
[
|
||||
(0, "0ns"),
|
||||
(42, "42ns"),
|
||||
(1_500, "1.50μs"),
|
||||
(25_000, "25.0μs"),
|
||||
(150_000, "150μs"),
|
||||
(2_500_000, "2.50ms"),
|
||||
(45_000_000, "45.0ms"),
|
||||
(200_000_000, "200ms"),
|
||||
(3_500_000_000, "3.50s"),
|
||||
(75_000_000_000, "75.0s"),
|
||||
(300_000_000_000, "300s"),
|
||||
],
|
||||
)
|
||||
def test_parametrized_examples(self, nanoseconds, expected):
|
||||
"""Parametrized test with various input/output combinations."""
|
||||
assert format_time(nanoseconds) == expected
|
||||
|
|
@ -272,4 +275,4 @@ class TestFormatPerf:
assert format_perf(100.4) == "100"
assert format_perf(10.54) == "10.5"
assert format_perf(1.554) == "1.55"
assert format_perf(0.1554) == "0.155"
assert format_perf(0.1554) == "0.155"

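The expected strings above are consistent with rounding to three significant figures; a standalone check of that rounding rule (a sketch only, not the format_perf implementation):

# Python's general format specifier at three significant figures reproduces the expected values.
for value, expected in [(100.4, "100"), (10.54, "10.5"), (1.554, "1.55"), (0.1554, "0.155")]:
    assert f"{value:.3g}" == expected
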
@ -9,8 +9,6 @@ from __future__ import annotations
import re
from pathlib import Path

import pytest

from codeflash.code_utils.instrument_existing_tests import (
detect_frameworks_from_code,
inject_profiling_into_existing_test,

@ -28,19 +26,11 @@ def normalize_instrumented_code(code: str) -> str:
generates double-quoted f-strings for compatibility with older versions).
"""
# Normalize database path
code = re.sub(
r"sqlite3\.connect\(f'[^']+'",
"sqlite3.connect(f'{CODEFLASH_DB_PATH}'",
code
)
code = re.sub(r"sqlite3\.connect\(f'[^']+'", "sqlite3.connect(f'{CODEFLASH_DB_PATH}'", code)
# Normalize f-string that contains the test_stdout_tag assignment
# This specific f-string has internal single quotes, so libcst uses double quotes
# on Python < 3.12, but single quotes on Python 3.12+
code = re.sub(
r'test_stdout_tag = f"([^"]+)"',
r"test_stdout_tag = f'\1'",
code
)
code = re.sub(r'test_stdout_tag = f"([^"]+)"', r"test_stdout_tag = f'\1'", code)
return code

@@ -1112,11 +1102,7 @@ def test_my_function():
test_file = tmp_path / "test_example.py"
test_file.write_text(code)

func = FunctionToOptimize(
function_name="my_function",
parents=[],
file_path=Path("mymodule.py"),
)
func = FunctionToOptimize(function_name="my_function", parents=[], file_path=Path("mymodule.py"))

success, instrumented_code = inject_profiling_into_existing_test(
test_path=test_file,
@@ -1142,11 +1128,7 @@ def test_my_function():
test_file = tmp_path / "test_example.py"
test_file.write_text(code)

func = FunctionToOptimize(
function_name="my_function",
parents=[],
file_path=Path("mymodule.py"),
)
func = FunctionToOptimize(function_name="my_function", parents=[], file_path=Path("mymodule.py"))

success, instrumented_code = inject_profiling_into_existing_test(
test_path=test_file,
@@ -1172,11 +1154,7 @@ def test_my_function():
test_file = tmp_path / "test_example.py"
test_file.write_text(code)

func = FunctionToOptimize(
function_name="my_function",
parents=[],
file_path=Path("mymodule.py"),
)
func = FunctionToOptimize(function_name="my_function", parents=[], file_path=Path("mymodule.py"))

success, instrumented_code = inject_profiling_into_existing_test(
test_path=test_file,
@@ -1202,11 +1180,7 @@ def test_my_function():
test_file = tmp_path / "test_example.py"
test_file.write_text(code)

func = FunctionToOptimize(
function_name="my_function",
parents=[],
file_path=Path("mymodule.py"),
)
func = FunctionToOptimize(function_name="my_function", parents=[], file_path=Path("mymodule.py"))

success, instrumented_code = inject_profiling_into_existing_test(
test_path=test_file,
@@ -1232,11 +1206,7 @@ def test_my_function():
test_file = tmp_path / "test_example.py"
test_file.write_text(code)

func = FunctionToOptimize(
function_name="my_function",
parents=[],
file_path=Path("mymodule.py"),
)
func = FunctionToOptimize(function_name="my_function", parents=[], file_path=Path("mymodule.py"))

success, instrumented_code = inject_profiling_into_existing_test(
test_path=test_file,
@@ -1262,11 +1232,7 @@ def test_my_function():
test_file = tmp_path / "test_example.py"
test_file.write_text(code)

func = FunctionToOptimize(
function_name="my_function",
parents=[],
file_path=Path("mymodule.py"),
)
func = FunctionToOptimize(function_name="my_function", parents=[], file_path=Path("mymodule.py"))

success, instrumented_code = inject_profiling_into_existing_test(
test_path=test_file,
@@ -1292,11 +1258,7 @@ def test_my_function():
test_file = tmp_path / "test_example.py"
test_file.write_text(code)

func = FunctionToOptimize(
function_name="my_function",
parents=[],
file_path=Path("mymodule.py"),
)
func = FunctionToOptimize(function_name="my_function", parents=[], file_path=Path("mymodule.py"))

success, instrumented_code = inject_profiling_into_existing_test(
test_path=test_file,
@@ -1322,11 +1284,7 @@ def test_my_function():
test_file = tmp_path / "test_example.py"
test_file.write_text(code)

func = FunctionToOptimize(
function_name="my_function",
parents=[],
file_path=Path("mymodule.py"),
)
func = FunctionToOptimize(function_name="my_function", parents=[], file_path=Path("mymodule.py"))

success, instrumented_code = inject_profiling_into_existing_test(
test_path=test_file,
@@ -1353,11 +1311,7 @@ def test_my_function():
test_file = tmp_path / "test_example.py"
test_file.write_text(code)

func = FunctionToOptimize(
function_name="my_function",
parents=[],
file_path=Path("mymodule.py"),
)
func = FunctionToOptimize(function_name="my_function", parents=[], file_path=Path("mymodule.py"))

success, instrumented_code = inject_profiling_into_existing_test(
test_path=test_file,
@@ -1385,11 +1339,7 @@ def test_my_function():
test_file = tmp_path / "test_example.py"
test_file.write_text(code)

func = FunctionToOptimize(
function_name="my_function",
parents=[],
file_path=Path("mymodule.py"),
)
func = FunctionToOptimize(function_name="my_function", parents=[], file_path=Path("mymodule.py"))

success, instrumented_code = inject_profiling_into_existing_test(
test_path=test_file,
@@ -1423,11 +1373,7 @@ def test_my_function():
test_file = tmp_path / "test_example.py"
test_file.write_text(code)

func = FunctionToOptimize(
function_name="my_function",
parents=[],
file_path=Path("mymodule.py"),
)
func = FunctionToOptimize(function_name="my_function", parents=[], file_path=Path("mymodule.py"))

success, instrumented_code = inject_profiling_into_existing_test(
test_path=test_file,
@@ -1453,11 +1399,7 @@ def test_my_function():
test_file = tmp_path / "test_example.py"
test_file.write_text(code)

func = FunctionToOptimize(
function_name="my_function",
parents=[],
file_path=Path("mymodule.py"),
)
func = FunctionToOptimize(function_name="my_function", parents=[], file_path=Path("mymodule.py"))

success, instrumented_code = inject_profiling_into_existing_test(
test_path=test_file,
@@ -1483,11 +1425,7 @@ def test_my_function():
test_file = tmp_path / "test_example.py"
test_file.write_text(code)

func = FunctionToOptimize(
function_name="my_function",
parents=[],
file_path=Path("mymodule.py"),
)
func = FunctionToOptimize(function_name="my_function", parents=[], file_path=Path("mymodule.py"))

success, instrumented_code = inject_profiling_into_existing_test(
test_path=test_file,
@@ -1513,11 +1451,7 @@ def test_my_function():
test_file = tmp_path / "test_example.py"
test_file.write_text(code)

func = FunctionToOptimize(
function_name="my_function",
parents=[],
file_path=Path("mymodule.py"),
)
func = FunctionToOptimize(function_name="my_function", parents=[], file_path=Path("mymodule.py"))

success, instrumented_code = inject_profiling_into_existing_test(
test_path=test_file,
@@ -1545,11 +1479,7 @@ def test_my_function():
test_file = tmp_path / "test_example.py"
test_file.write_text(code)

func = FunctionToOptimize(
function_name="my_function",
parents=[],
file_path=Path("mymodule.py"),
)
func = FunctionToOptimize(function_name="my_function", parents=[], file_path=Path("mymodule.py"))

success, instrumented_code = inject_profiling_into_existing_test(
test_path=test_file,

@@ -116,11 +116,7 @@ def test_sort():
func = FunctionToOptimize(function_name="sorter", parents=[], file_path=Path(fto_path))
os.chdir(run_cwd)
success, new_test = inject_profiling_into_existing_test(
test_path,
[CodePosition(6, 13), CodePosition(10, 13)],
func,
project_root_path,
mode=TestingMode.BEHAVIOR,
test_path, [CodePosition(6, 13), CodePosition(10, 13)], func, project_root_path, mode=TestingMode.BEHAVIOR
)
os.chdir(original_cwd)
assert success
@@ -552,7 +548,9 @@ def test_sort():
fto_path = (Path(__file__).parent.resolve() / "../code_to_optimize/bubble_sort_method.py").resolve()
original_code = fto_path.read_text("utf-8")
fto = FunctionToOptimize(
function_name="sorter_classmethod", parents=[FunctionParent(name="BubbleSorter", type="ClassDef")], file_path=Path(fto_path)
function_name="sorter_classmethod",
parents=[FunctionParent(name="BubbleSorter", type="ClassDef")],
file_path=Path(fto_path),
)
with tempfile.TemporaryDirectory() as tmpdirname:
tmp_test_path = Path(tmpdirname) / "test_classmethod_behavior_results_temp.py"
@@ -646,8 +644,11 @@ def test_sort():
)
assert test_results[1].runtime > 0
assert test_results[1].did_pass
assert test_results[1].stdout == """codeflash stdout : BubbleSorter.sorter_classmethod() called
assert (
test_results[1].stdout
== """codeflash stdout : BubbleSorter.sorter_classmethod() called
"""
)

results2, _ = func_optimizer.run_and_parse_tests(
testing_type=TestingMode.BEHAVIOR,
@@ -718,7 +719,9 @@ def test_sort():
fto_path = (Path(__file__).parent.resolve() / "../code_to_optimize/bubble_sort_method.py").resolve()
original_code = fto_path.read_text("utf-8")
fto = FunctionToOptimize(
function_name="sorter_staticmethod", parents=[FunctionParent(name="BubbleSorter", type="ClassDef")], file_path=Path(fto_path)
function_name="sorter_staticmethod",
parents=[FunctionParent(name="BubbleSorter", type="ClassDef")],
file_path=Path(fto_path),
)
with tempfile.TemporaryDirectory() as tmpdirname:
tmp_test_path = Path(tmpdirname) / "test_staticmethod_behavior_results_temp.py"
@@ -812,8 +815,11 @@ def test_sort():
)
assert test_results[1].runtime > 0
assert test_results[1].did_pass
assert test_results[1].stdout == """codeflash stdout : BubbleSorter.sorter_staticmethod() called
assert (
test_results[1].stdout
== """codeflash stdout : BubbleSorter.sorter_staticmethod() called
"""
)

results2, _ = func_optimizer.run_and_parse_tests(
testing_type=TestingMode.BEHAVIOR,
@@ -831,4 +837,4 @@ def test_sort():
finally:
fto_path.write_text(original_code, "utf-8")
test_path.unlink(missing_ok=True)
test_path_perf.unlink(missing_ok=True)
test_path_perf.unlink(missing_ok=True)

@@ -1,8 +1,7 @@
import tempfile
from pathlib import Path
import uuid
import os
import sys
import tempfile
from pathlib import Path

import pytest
@@ -81,9 +80,7 @@ async def async_function(x: int, y: int) -> int:
test_file = temp_dir / "test_async.py"
test_file.write_text(async_function_code)

func = FunctionToOptimize(
function_name="async_function", file_path=test_file, parents=[], is_async=True
)
func = FunctionToOptimize(function_name="async_function", file_path=test_file, parents=[], is_async=True)

decorator_added = add_async_decorator_to_function(test_file, func, TestingMode.BEHAVIOR)
@@ -120,9 +117,7 @@ async def async_function(x: int, y: int) -> int:
test_file = temp_dir / "test_async.py"
test_file.write_text(async_function_code)

func = FunctionToOptimize(
function_name="async_function", file_path=test_file, parents=[], is_async=True
)
func = FunctionToOptimize(function_name="async_function", file_path=test_file, parents=[], is_async=True)

decorator_added = add_async_decorator_to_function(test_file, func, TestingMode.PERFORMANCE)
@@ -160,9 +155,7 @@ async def async_function(x: int, y: int) -> int:
test_file = temp_dir / "test_async.py"
test_file.write_text(async_function_code)

func = FunctionToOptimize(
function_name="async_function", file_path=test_file, parents=[], is_async=True
)
func = FunctionToOptimize(function_name="async_function", file_path=test_file, parents=[], is_async=True)

decorator_added = add_async_decorator_to_function(test_file, func, TestingMode.CONCURRENCY)
@@ -243,9 +236,7 @@ async def async_function(x: int, y: int) -> int:
test_file = temp_dir / "test_async.py"
test_file.write_text(already_decorated_code)

func = FunctionToOptimize(
function_name="async_function", file_path=test_file, parents=[], is_async=True
)
func = FunctionToOptimize(function_name="async_function", file_path=test_file, parents=[], is_async=True)

decorator_added = add_async_decorator_to_function(test_file, func, TestingMode.BEHAVIOR)
@@ -290,12 +281,10 @@ async def test_async_function():
# First instrument the source module
from codeflash.code_utils.instrument_existing_tests import add_async_decorator_to_function

source_success = add_async_decorator_to_function(
source_file, func, TestingMode.BEHAVIOR
)
source_success = add_async_decorator_to_function(source_file, func, TestingMode.BEHAVIOR)

assert source_success is True

# Verify the file was modified
instrumented_source = source_file.read_text()
assert "@codeflash_behavior_async" in instrumented_source
@@ -347,12 +336,10 @@ async def test_async_function():
# First instrument the source module
from codeflash.code_utils.instrument_existing_tests import add_async_decorator_to_function

source_success = add_async_decorator_to_function(
source_file, func, TestingMode.PERFORMANCE
)
source_success = add_async_decorator_to_function(source_file, func, TestingMode.PERFORMANCE)

assert source_success is True

# Verify the file was modified
instrumented_source = source_file.read_text()
assert "@codeflash_performance_async" in instrumented_source
@@ -413,12 +400,10 @@ async def test_mixed_functions():

from codeflash.code_utils.instrument_existing_tests import add_async_decorator_to_function

source_success = add_async_decorator_to_function(
source_file, async_func, TestingMode.BEHAVIOR
)
source_success = add_async_decorator_to_function(source_file, async_func, TestingMode.BEHAVIOR)

assert source_success

# Verify the file was modified
instrumented_source = source_file.read_text()
assert "@codeflash_behavior_async" in instrumented_source
@@ -428,11 +413,7 @@ async def test_mixed_functions():
assert "def sync_function(x: int, y: int) -> int:" in instrumented_source

success, instrumented_test_code = inject_profiling_into_existing_test(
test_file,
[CodePosition(8, 18), CodePosition(11, 19)],
async_func,
temp_dir,
mode=TestingMode.BEHAVIOR,
test_file, [CodePosition(8, 18), CodePosition(11, 19)], async_func, temp_dir, mode=TestingMode.BEHAVIOR
)

# Async functions should not be instrumented at the test level
@@ -465,8 +446,7 @@ class OuterClass:

decorator_added = add_async_decorator_to_function(test_file, func, TestingMode.BEHAVIOR)

expected_output = (
"""import asyncio
expected_output = """import asyncio

from codeflash.code_utils.codeflash_wrap_decorator import \\
codeflash_behavior_async
@@ -480,7 +460,6 @@ class OuterClass:
await asyncio.sleep(0.001)
return x * 2
"""
)

assert decorator_added
modified_code = test_file.read_text()
@@ -510,9 +489,7 @@ async def async_function(x: int, y: int) -> int:
test_file = temp_dir / "test_async.py"
test_file.write_text(decorated_async_code)

func = FunctionToOptimize(
function_name="async_function", file_path=test_file, parents=[], is_async=True
)
func = FunctionToOptimize(function_name="async_function", file_path=test_file, parents=[], is_async=True)

decorator_added = add_async_decorator_to_function(test_file, func, TestingMode.BEHAVIOR)
@@ -538,22 +515,16 @@ def sync_function(x: int, y: int) -> int:
test_file = temp_dir / "test_sync.py"
test_file.write_text(sync_function_code)

sync_func = FunctionToOptimize(
function_name="sync_function",
file_path=test_file,
parents=[],
is_async=False,
)
sync_func = FunctionToOptimize(function_name="sync_function", file_path=test_file, parents=[], is_async=False)

decorator_added = add_async_decorator_to_function(
test_file, sync_func, TestingMode.BEHAVIOR
)
decorator_added = add_async_decorator_to_function(test_file, sync_func, TestingMode.BEHAVIOR)

assert not decorator_added
# File should not be modified for sync functions
modified_code = test_file.read_text()
assert modified_code == sync_function_code

@pytest.mark.skipif(sys.platform == "win32", reason="pending support for asyncio on windows")
def test_inject_profiling_async_multiple_calls_same_test(temp_dir):
"""Test that multiple async function calls within the same test function get correctly numbered 0, 1, 2, etc."""
@@ -599,12 +570,10 @@ async def test_multiple_calls():
# First instrument the source module with async decorators
from codeflash.code_utils.instrument_existing_tests import add_async_decorator_to_function

source_success = add_async_decorator_to_function(
source_file, func, TestingMode.BEHAVIOR
)
source_success = add_async_decorator_to_function(source_file, func, TestingMode.BEHAVIOR)

assert source_success

# Verify the file was modified
instrumented_source = source_file.read_text()
assert "@codeflash_behavior_async" in instrumented_source
@@ -636,18 +605,15 @@ async def test_multiple_calls():
line_id_1_count = instrumented_test_code.count("os.environ['CODEFLASH_CURRENT_LINE_ID'] = '1'")
line_id_2_count = instrumented_test_code.count("os.environ['CODEFLASH_CURRENT_LINE_ID'] = '2'")

assert line_id_0_count == 2, f"Expected 2 occurrences of line_id '0', got {line_id_0_count}"
assert line_id_1_count == 1, f"Expected 1 occurrence of line_id '1', got {line_id_1_count}"
assert line_id_2_count == 1, f"Expected 1 occurrence of line_id '2', got {line_id_2_count}"

@pytest.mark.skipif(sys.platform == "win32", reason="pending support for asyncio on windows")
def test_async_behavior_decorator_return_values_and_test_ids():
"""Test that async behavior decorator correctly captures return values, test IDs, and stores data in database."""
import asyncio
import os
import sqlite3
from pathlib import Path
@@ -684,7 +650,7 @@ def test_async_behavior_decorator_return_values_and_test_ids():

from codeflash.code_utils.codeflash_wrap_decorator import get_run_tmp_file

db_path = get_run_tmp_file(Path(f"test_return_values_2.sqlite"))
db_path = get_run_tmp_file(Path("test_return_values_2.sqlite"))

# Verify database exists and has data
assert db_path.exists(), f"Database file not created at {db_path}"
@@ -745,7 +711,6 @@ def test_async_behavior_decorator_return_values_and_test_ids():
@pytest.mark.skipif(sys.platform == "win32", reason="pending support for asyncio on windows")
def test_async_decorator_comprehensive_return_values_and_test_ids():
import asyncio
import os
import sqlite3
from pathlib import Path
@@ -793,7 +758,7 @@ def test_async_decorator_comprehensive_return_values_and_test_ids():
f"Expected {test_case['expected']}, got {result} for args {test_case['args']}, kwargs {test_case['kwargs']}"
)

db_path = get_run_tmp_file(Path(f"test_return_values_3.sqlite"))
db_path = get_run_tmp_file(Path("test_return_values_3.sqlite"))
assert db_path.exists(), f"Database not created at {db_path}"

con = sqlite3.connect(db_path)
@@ -837,7 +802,6 @@ def test_async_decorator_comprehensive_return_values_and_test_ids():
f"Row {i}: Expected iteration_id '{expected_iteration_id}', got '{iteration_id}'"
)

args, kwargs, actual_return_value = pickle.loads(return_value_blob)
expected_args = test_cases[i]["args"]
expected_kwargs = test_cases[i]["kwargs"]

@@ -3,8 +3,10 @@ from __future__ import annotations
import tempfile
from pathlib import Path

from codeflash.benchmarking.instrument_codeflash_trace import add_codeflash_decorator_to_code, \
instrument_codeflash_trace_decorator
from codeflash.benchmarking.instrument_codeflash_trace import (
add_codeflash_decorator_to_code,
instrument_codeflash_trace_decorator,
)
from codeflash.discovery.functions_to_optimize import FunctionParent, FunctionToOptimize
@@ -15,16 +17,9 @@ def normal_function():
return "Hello, World!"
"""

fto = FunctionToOptimize(
function_name="normal_function",
file_path=Path("dummy_path.py"),
parents=[]
)
fto = FunctionToOptimize(function_name="normal_function", file_path=Path("dummy_path.py"), parents=[])

modified_code = add_codeflash_decorator_to_code(
code=code,
functions_to_optimize=[fto]
)
modified_code = add_codeflash_decorator_to_code(code=code, functions_to_optimize=[fto])

expected_code = """
from codeflash.benchmarking.codeflash_trace import codeflash_trace
@@ -47,13 +42,10 @@ class TestClass:
fto = FunctionToOptimize(
function_name="normal_method",
file_path=Path("dummy_path.py"),
parents=[FunctionParent(name="TestClass", type="ClassDef")]
parents=[FunctionParent(name="TestClass", type="ClassDef")],
)

modified_code = add_codeflash_decorator_to_code(
code=code,
functions_to_optimize=[fto]
)
modified_code = add_codeflash_decorator_to_code(code=code, functions_to_optimize=[fto])

expected_code = """
from codeflash.benchmarking.codeflash_trace import codeflash_trace
@@ -78,13 +70,10 @@ class TestClass:
fto = FunctionToOptimize(
function_name="class_method",
file_path=Path("dummy_path.py"),
parents=[FunctionParent(name="TestClass", type="ClassDef")]
parents=[FunctionParent(name="TestClass", type="ClassDef")],
)

modified_code = add_codeflash_decorator_to_code(
code=code,
functions_to_optimize=[fto]
)
modified_code = add_codeflash_decorator_to_code(code=code, functions_to_optimize=[fto])

expected_code = """
from codeflash.benchmarking.codeflash_trace import codeflash_trace
@@ -110,13 +99,10 @@ class TestClass:
fto = FunctionToOptimize(
function_name="static_method",
file_path=Path("dummy_path.py"),
parents=[FunctionParent(name="TestClass", type="ClassDef")]
parents=[FunctionParent(name="TestClass", type="ClassDef")],
)

modified_code = add_codeflash_decorator_to_code(
code=code,
functions_to_optimize=[fto]
)
modified_code = add_codeflash_decorator_to_code(code=code, functions_to_optimize=[fto])

expected_code = """
from codeflash.benchmarking.codeflash_trace import codeflash_trace
@@ -141,13 +127,10 @@ class TestClass:
fto = FunctionToOptimize(
function_name="__init__",
file_path=Path("dummy_path.py"),
parents=[FunctionParent(name="TestClass", type="ClassDef")]
parents=[FunctionParent(name="TestClass", type="ClassDef")],
)

modified_code = add_codeflash_decorator_to_code(
code=code,
functions_to_optimize=[fto]
)
modified_code = add_codeflash_decorator_to_code(code=code, functions_to_optimize=[fto])

expected_code = """
from codeflash.benchmarking.codeflash_trace import codeflash_trace
@@ -173,13 +156,10 @@ class TestClass:
fto = FunctionToOptimize(
function_name="property_method",
file_path=Path("dummy_path.py"),
parents=[FunctionParent(name="TestClass", type="ClassDef")]
parents=[FunctionParent(name="TestClass", type="ClassDef")],
)

modified_code = add_codeflash_decorator_to_code(
code=code,
functions_to_optimize=[fto]
)
modified_code = add_codeflash_decorator_to_code(code=code, functions_to_optimize=[fto])

expected_code = """
from codeflash.benchmarking.codeflash_trace import codeflash_trace
@@ -209,13 +189,10 @@ class OtherClass:
fto = FunctionToOptimize(
function_name="test_method",
file_path=Path("dummy_path.py"),
parents=[FunctionParent(name="TestClass", type="ClassDef")]
parents=[FunctionParent(name="TestClass", type="ClassDef")],
)

modified_code = add_codeflash_decorator_to_code(
code=code,
functions_to_optimize=[fto]
)
modified_code = add_codeflash_decorator_to_code(code=code, functions_to_optimize=[fto])

expected_code = """
from codeflash.benchmarking.codeflash_trace import codeflash_trace
@@ -239,16 +216,9 @@ def existing_function():
return "This exists"
"""

fto = FunctionToOptimize(
function_name="nonexistent_function",
file_path=Path("dummy_path.py"),
parents=[]
)
fto = FunctionToOptimize(function_name="nonexistent_function", file_path=Path("dummy_path.py"), parents=[])

modified_code = add_codeflash_decorator_to_code(
code=code,
functions_to_optimize=[fto]
)
modified_code = add_codeflash_decorator_to_code(code=code, functions_to_optimize=[fto])

# Code should remain unchanged
assert modified_code.strip() == code.strip()
@@ -272,27 +242,16 @@ def function_two():
"""

functions_to_optimize = [
FunctionToOptimize(
function_name="function_one",
file_path=Path("dummy_path.py"),
parents=[]
),
FunctionToOptimize(function_name="function_one", file_path=Path("dummy_path.py"), parents=[]),
FunctionToOptimize(
function_name="method_two",
file_path=Path("dummy_path.py"),
parents=[FunctionParent(name="TestClass", type="ClassDef")]
parents=[FunctionParent(name="TestClass", type="ClassDef")],
),
FunctionToOptimize(
function_name="function_two",
file_path=Path("dummy_path.py"),
parents=[]
)
FunctionToOptimize(function_name="function_two", file_path=Path("dummy_path.py"), parents=[]),
]

modified_code = add_codeflash_decorator_to_code(
code=code,
functions_to_optimize=functions_to_optimize
)
modified_code = add_codeflash_decorator_to_code(code=code, functions_to_optimize=functions_to_optimize)

expected_code = """
from codeflash.benchmarking.codeflash_trace import codeflash_trace
@@ -339,16 +298,12 @@ def function_two():

# Define functions to optimize
functions_to_optimize = [
FunctionToOptimize(
function_name="function_one",
file_path=test_file_path,
parents=[]
),
FunctionToOptimize(function_name="function_one", file_path=test_file_path, parents=[]),
FunctionToOptimize(
function_name="method_two",
file_path=test_file_path,
parents=[FunctionParent(name="TestClass", type="ClassDef")]
)
parents=[FunctionParent(name="TestClass", type="ClassDef")],
),
]

# Execute the function being tested
@@ -399,7 +354,7 @@ class ClassA:

# Create second test Python file
test_file_2_path = Path(temp_dir) / "module_b.py"
test_file_2_content ="""
test_file_2_content = """
def function_b():
return "Function in module B"
@@ -412,20 +367,14 @@ class ClassB:

# Define functions to optimize
file_to_funcs_to_optimize = {
test_file_1_path: [
FunctionToOptimize(
function_name="function_a",
file_path=test_file_1_path,
parents=[]
)
],
test_file_1_path: [FunctionToOptimize(function_name="function_a", file_path=test_file_1_path, parents=[])],
test_file_2_path: [
FunctionToOptimize(
function_name="static_method_b",
file_path=test_file_2_path,
parents=[FunctionParent(name="ClassB", type="ClassDef")]
parents=[FunctionParent(name="ClassB", type="ClassDef")],
)
]
],
}

# Execute the function being tested
@@ -484,13 +433,10 @@ class OuterClass:
fto = FunctionToOptimize(
function_name="target_method",
file_path=Path("dummy_path.py"),
parents=[FunctionParent(name="OuterClass", type="ClassDef")]
parents=[FunctionParent(name="OuterClass", type="ClassDef")],
)

modified_code = add_codeflash_decorator_to_code(
code=code,
functions_to_optimize=[fto]
)
modified_code = add_codeflash_decorator_to_code(code=code, functions_to_optimize=[fto])

expected_code = """
from codeflash.benchmarking.codeflash_trace import codeflash_trace
@@ -520,16 +466,9 @@ def target_function():
return "Hello from target function after nested function"
"""

fto = FunctionToOptimize(
function_name="target_function",
file_path=Path("dummy_path.py"),
parents=[]
)
fto = FunctionToOptimize(function_name="target_function", file_path=Path("dummy_path.py"), parents=[])

modified_code = add_codeflash_decorator_to_code(
code=code,
functions_to_optimize=[fto]
)
modified_code = add_codeflash_decorator_to_code(code=code, functions_to_optimize=[fto])

expected_code = """
from codeflash.benchmarking.codeflash_trace import codeflash_trace
@@ -561,11 +500,7 @@ def some_function():
"""
test_file_path.write_text(original_content, encoding="utf-8")

fto = FunctionToOptimize(
function_name="some_function",
file_path=test_file_path,
parents=[]
)
fto = FunctionToOptimize(function_name="some_function", file_path=test_file_path, parents=[])

instrument_codeflash_trace_decorator({test_file_path: [fto]})
@@ -587,11 +522,7 @@ def patch_function():
"""
test_file_path.write_text(original_content, encoding="utf-8")

fto = FunctionToOptimize(
function_name="patch_function",
file_path=test_file_path,
parents=[]
)
fto = FunctionToOptimize(function_name="patch_function", file_path=test_file_path, parents=[])

instrument_codeflash_trace_decorator({test_file_path: [fto]})
@@ -616,11 +547,7 @@ def trace_func():
"""
test_file_path.write_text(original_content, encoding="utf-8")

fto = FunctionToOptimize(
function_name="trace_func",
file_path=test_file_path,
parents=[]
)
fto = FunctionToOptimize(function_name="trace_func", file_path=test_file_path, parents=[])

instrument_codeflash_trace_decorator({test_file_path: [fto]})
@@ -645,11 +572,7 @@ def util_func():
"""
test_file_path.write_text(original_content, encoding="utf-8")

fto = FunctionToOptimize(
function_name="util_func",
file_path=test_file_path,
parents=[]
)
fto = FunctionToOptimize(function_name="util_func", file_path=test_file_path, parents=[])

instrument_codeflash_trace_decorator({test_file_path: [fto]})
@@ -673,15 +596,11 @@ def main_func():
"""
test_file_path.write_text(original_content, encoding="utf-8")

fto = FunctionToOptimize(
function_name="main_func",
file_path=test_file_path,
parents=[]
)
fto = FunctionToOptimize(function_name="main_func", file_path=test_file_path, parents=[])

instrument_codeflash_trace_decorator({test_file_path: [fto]})

# File SHOULD be modified
modified_content = test_file_path.read_text(encoding="utf-8")
assert "codeflash_trace" in modified_content
assert "@codeflash_trace" in modified_content
assert "@codeflash_trace" in modified_content

@@ -2,8 +2,6 @@ import os
from pathlib import Path
from tempfile import TemporaryDirectory

import pytest

from codeflash.code_utils.line_profile_utils import add_decorator_imports, contains_jit_decorator
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.models.models import CodeOptimizationContext
@@ -26,7 +24,7 @@ def test_add_decorator_imports_helper_in_class():
func = FunctionToOptimize(function_name="sort_classmethod", parents=[], file_path=code_path)
func_optimizer = FunctionOptimizer(function_to_optimize=func, test_cfg=test_config)
os.chdir(run_cwd)
#func_optimizer = pass
# func_optimizer = pass
try:
ctx_result = func_optimizer.get_code_optimization_context()
code_context: CodeOptimizationContext = ctx_result.unwrap()
@@ -36,8 +34,7 @@ def test_add_decorator_imports_helper_in_class():
with helper_function_path.open(encoding="utf8") as f:
helper_code = f.read()
original_helper_code[helper_function_path] = helper_code
line_profiler_output_file = add_decorator_imports(
func_optimizer.function_to_optimize, code_context)
line_profiler_output_file = add_decorator_imports(func_optimizer.function_to_optimize, code_context)
expected_code_main = f"""from line_profiler import profile as codeflash_line_profile
codeflash_line_profile.enable(output_prefix='{line_profiler_output_file.as_posix()}')
@@ -77,11 +74,14 @@ class BubbleSortClass:
assert code_context.helper_functions[0].file_path.read_text("utf-8") == expected_code_helper
finally:
func_optimizer.write_code_and_helpers(
func_optimizer.function_to_optimize_source_code, original_helper_code, func_optimizer.function_to_optimize.file_path
func_optimizer.function_to_optimize_source_code,
original_helper_code,
func_optimizer.function_to_optimize.file_path,
)

def test_add_decorator_imports_helper_in_nested_class():
#Need to invert the assert once the helper detection is fixed
# Need to invert the assert once the helper detection is fixed
code_path = (Path(__file__).parent.resolve() / "../code_to_optimize/bubble_sort_nested_classmethod.py").resolve()
tests_root = Path(__file__).parent.resolve() / "../code_to_optimize/tests/pytest/"
project_root_path = (Path(__file__).parent / "..").resolve()
@@ -96,7 +96,7 @@ def test_add_decorator_imports_helper_in_nested_class():
func = FunctionToOptimize(function_name="sort_classmethod", parents=[], file_path=code_path)
func_optimizer = FunctionOptimizer(function_to_optimize=func, test_cfg=test_config)
os.chdir(run_cwd)
#func_optimizer = pass
# func_optimizer = pass
try:
ctx_result = func_optimizer.get_code_optimization_context()
code_context: CodeOptimizationContext = ctx_result.unwrap()
@@ -106,8 +106,7 @@ def test_add_decorator_imports_helper_in_nested_class():
with helper_function_path.open(encoding="utf8") as f:
helper_code = f.read()
original_helper_code[helper_function_path] = helper_code
line_profiler_output_file = add_decorator_imports(
func_optimizer.function_to_optimize, code_context)
line_profiler_output_file = add_decorator_imports(func_optimizer.function_to_optimize, code_context)
expected_code_main = f"""from line_profiler import profile as codeflash_line_profile
codeflash_line_profile.enable(output_prefix='{line_profiler_output_file.as_posix()}')
@@ -125,9 +124,12 @@ def sort_classmethod(x):
assert code_context.helper_functions[0].qualified_name == "WrapperClass.__init__"
finally:
func_optimizer.write_code_and_helpers(
func_optimizer.function_to_optimize_source_code, original_helper_code, func_optimizer.function_to_optimize.file_path
func_optimizer.function_to_optimize_source_code,
original_helper_code,
func_optimizer.function_to_optimize.file_path,
)

def test_add_decorator_imports_nodeps():
code_path = (Path(__file__).parent.resolve() / "../code_to_optimize/bubble_sort.py").resolve()
tests_root = Path(__file__).parent.resolve() / "../code_to_optimize/tests/pytest/"
@@ -143,7 +145,7 @@ def test_add_decorator_imports_nodeps():
func = FunctionToOptimize(function_name="sorter", parents=[], file_path=code_path)
func_optimizer = FunctionOptimizer(function_to_optimize=func, test_cfg=test_config)
os.chdir(run_cwd)
#func_optimizer = pass
# func_optimizer = pass
try:
ctx_result = func_optimizer.get_code_optimization_context()
code_context: CodeOptimizationContext = ctx_result.unwrap()
@@ -153,8 +155,7 @@ def test_add_decorator_imports_nodeps():
with helper_function_path.open(encoding="utf8") as f:
helper_code = f.read()
original_helper_code[helper_function_path] = helper_code
line_profiler_output_file = add_decorator_imports(
func_optimizer.function_to_optimize, code_context)
line_profiler_output_file = add_decorator_imports(func_optimizer.function_to_optimize, code_context)
expected_code_main = f"""from line_profiler import profile as codeflash_line_profile
codeflash_line_profile.enable(output_prefix='{line_profiler_output_file.as_posix()}')
@@ -174,9 +175,12 @@ def sorter(arr):
assert code_path.read_text("utf-8") == expected_code_main
finally:
func_optimizer.write_code_and_helpers(
func_optimizer.function_to_optimize_source_code, original_helper_code, func_optimizer.function_to_optimize.file_path
func_optimizer.function_to_optimize_source_code,
original_helper_code,
func_optimizer.function_to_optimize.file_path,
)

def test_add_decorator_imports_helper_outside():
code_path = (Path(__file__).parent.resolve() / "../code_to_optimize/bubble_sort_deps.py").resolve()
tests_root = Path(__file__).parent.resolve() / "../code_to_optimize/tests/pytest/"
@@ -192,7 +196,7 @@ def test_add_decorator_imports_helper_outside():
func = FunctionToOptimize(function_name="sorter_deps", parents=[], file_path=code_path)
func_optimizer = FunctionOptimizer(function_to_optimize=func, test_cfg=test_config)
os.chdir(run_cwd)
#func_optimizer = pass
# func_optimizer = pass
try:
ctx_result = func_optimizer.get_code_optimization_context()
code_context: CodeOptimizationContext = ctx_result.unwrap()
@@ -202,8 +206,7 @@ def test_add_decorator_imports_helper_outside():
with helper_function_path.open(encoding="utf8") as f:
helper_code = f.read()
original_helper_code[helper_function_path] = helper_code
line_profiler_output_file = add_decorator_imports(
func_optimizer.function_to_optimize, code_context)
line_profiler_output_file = add_decorator_imports(func_optimizer.function_to_optimize, code_context)
expected_code_main = f"""from line_profiler import profile as codeflash_line_profile
codeflash_line_profile.enable(output_prefix='{line_profiler_output_file.as_posix()}')
@@ -227,7 +230,7 @@ def sorter_deps(arr):
def dep1_comparer(arr, j: int) -> bool:
return arr[j] > arr[j + 1]
"""
expected_code_helper2="""from line_profiler import profile as codeflash_line_profile
expected_code_helper2 = """from line_profiler import profile as codeflash_line_profile

@codeflash_line_profile
@@ -241,9 +244,12 @@ def dep2_swap(arr, j):
assert code_context.helper_functions[1].file_path.read_text("utf-8") == expected_code_helper2
finally:
func_optimizer.write_code_and_helpers(
func_optimizer.function_to_optimize_source_code, original_helper_code, func_optimizer.function_to_optimize.file_path
func_optimizer.function_to_optimize_source_code,
original_helper_code,
func_optimizer.function_to_optimize.file_path,
)

def test_add_decorator_imports_helper_in_dunder_class():
code_str = """def sorter(arr):
ans = helper(arr)
@@ -253,7 +259,7 @@ class helper:
return arr.sort()"""
code_path = TemporaryDirectory()
code_write_path = Path(code_path.name) / "dunder_class.py"
code_write_path.write_text(code_str,"utf-8")
code_write_path.write_text(code_str, "utf-8")
tests_root = Path(__file__).parent.resolve() / "../code_to_optimize/tests/pytest/"
project_root_path = Path(code_path.name)
run_cwd = Path(__file__).parent.parent.resolve()
@@ -267,7 +273,7 @@ class helper:
func = FunctionToOptimize(function_name="sorter", parents=[], file_path=code_write_path)
func_optimizer = FunctionOptimizer(function_to_optimize=func, test_cfg=test_config)
os.chdir(run_cwd)
#func_optimizer = pass
# func_optimizer = pass
try:
ctx_result = func_optimizer.get_code_optimization_context()
code_context: CodeOptimizationContext = ctx_result.unwrap()
@@ -277,8 +283,7 @@ class helper:
with helper_function_path.open(encoding="utf8") as f:
helper_code = f.read()
original_helper_code[helper_function_path] = helper_code
line_profiler_output_file = add_decorator_imports(
func_optimizer.function_to_optimize, code_context)
line_profiler_output_file = add_decorator_imports(func_optimizer.function_to_optimize, code_context)
expected_code_main = f"""from line_profiler import profile as codeflash_line_profile
codeflash_line_profile.enable(output_prefix='{line_profiler_output_file.as_posix()}')

@@ -3,10 +3,13 @@ from __future__ import annotations
import ast
import math
import os
import platform
import sys
import tempfile
from pathlib import Path

import pytest

from codeflash.code_utils.code_utils import get_run_tmp_file
from codeflash.code_utils.instrument_existing_tests import (
FunctionImportedAsVisitor,
@@ -24,8 +27,6 @@ from codeflash.models.models import (
TestsInFile,
TestType,
)
import platform

from codeflash.optimization.function_optimizer import FunctionOptimizer
from codeflash.verification.verification_utils import TestConfig
@@ -114,12 +115,15 @@ import pytest"""
if extra_imports:
imports += "\n" + extra_imports
return imports

# create a temporary directory for the test results
@pytest.fixture
def tmp_dir():
with tempfile.TemporaryDirectory() as tmpdirname:
yield Path(tmpdirname)

def test_perfinjector_bubble_sort(tmp_dir) -> None:
code = """import unittest
@@ -150,12 +154,12 @@ import dill as pickle"""
# timeout_decorator no longer used since pytest handles timeouts

imports += "\n\nfrom code_to_optimize.bubble_sort import sorter"

wrapper_func = codeflash_wrap_string

test_class_header = "class TestPigLatin(unittest.TestCase):"
test_decorator = "" # pytest-timeout handles timeouts now, not timeout_decorator

expected = imports + "\n\n\n" + wrapper_func + "\n" + test_class_header + "\n\n"
if test_decorator:
expected += test_decorator + "\n"
@@ -190,10 +194,7 @@ import dill as pickle"""
run_cwd = Path(__file__).parent.parent.resolve()
os.chdir(run_cwd)
success, new_test = inject_profiling_into_existing_test(
Path(f.name),
[CodePosition(9, 17), CodePosition(13, 17), CodePosition(17, 17)],
func,
Path(f.name).parent,
Path(f.name), [CodePosition(9, 17), CodePosition(13, 17), CodePosition(17, 17)], func, Path(f.name).parent
)
os.chdir(original_cwd)
assert success
@@ -397,18 +398,14 @@ def test_sort():
func = FunctionToOptimize(function_name="sorter", parents=[], file_path=code_path)
os.chdir(run_cwd)
success, new_test = inject_profiling_into_existing_test(
test_path,
[CodePosition(8, 14), CodePosition(12, 14)],
func,
project_root_path,
mode=TestingMode.BEHAVIOR,
test_path, [CodePosition(8, 14), CodePosition(12, 14)], func, project_root_path, mode=TestingMode.BEHAVIOR
)
os.chdir(original_cwd)
assert success
assert new_test is not None
assert new_test.replace('"', "'") == expected.format(
module_path="code_to_optimize.tests.pytest.test_perfinjector_bubble_sort_results_temp",
tmp_dir_path=get_run_tmp_file(Path("test_return_values")).as_posix()
tmp_dir_path=get_run_tmp_file(Path("test_return_values")).as_posix(),
).replace('"', "'")

success, new_perf_test = inject_profiling_into_existing_test(
@@ -422,7 +419,7 @@ tmp_dir_path=get_run_tmp_file(Path("test_return_values")).as_posix()
assert new_perf_test is not None
assert new_perf_test.replace('"', "'") == expected_perfonly.format(
module_path="code_to_optimize.tests.pytest.test_perfinjector_bubble_sort_results_temp",
tmp_dir_path=get_run_tmp_file(Path("test_return_values")).as_posix()
tmp_dir_path=get_run_tmp_file(Path("test_return_values")).as_posix(),
).replace('"', "'")

with test_path.open("w") as f:
@@ -942,7 +939,7 @@ def test_sort_parametrized_loop(input, expected_output):
assert new_test is not None
assert new_test.replace('"', "'") == expected.format(
module_path="code_to_optimize.tests.pytest.test_perfinjector_bubble_sort_parametrized_loop_results_temp",
tmp_dir_path=get_run_tmp_file(Path("test_return_values")).as_posix()
tmp_dir_path=get_run_tmp_file(Path("test_return_values")).as_posix(),
).replace('"', "'")

# Overwrite old test with new instrumented test
@@ -951,7 +948,7 @@ tmp_dir_path=get_run_tmp_file(Path("test_return_values")).as_posix()

assert new_test_perf.replace('"', "'") == expected_perf.format(
module_path="code_to_optimize.tests.pytest.test_perfinjector_bubble_sort_parametrized_loop_results_temp",
tmp_dir_path=get_run_tmp_file(Path("test_return_values")).as_posix()
tmp_dir_path=get_run_tmp_file(Path("test_return_values")).as_posix(),
).replace('"', "'")

# Overwrite old test with new instrumented test
@@ -1301,12 +1298,12 @@ def test_sort():
assert new_test_behavior is not None
assert new_test_behavior.replace('"', "'") == expected.format(
module_path="code_to_optimize.tests.pytest.test_perfinjector_bubble_sort_loop_results_temp",
tmp_dir_path=get_run_tmp_file(Path("test_return_values")).as_posix()
tmp_dir_path=get_run_tmp_file(Path("test_return_values")).as_posix(),
).replace('"', "'")

assert new_test_perf.replace('"', "'") == expected_perf.format(
module_path="code_to_optimize.tests.pytest.test_perfinjector_bubble_sort_loop_results_temp",
tmp_dir_path=get_run_tmp_file(Path("test_return_values")).as_posix()
tmp_dir_path=get_run_tmp_file(Path("test_return_values")).as_posix(),
).replace('"', "'")

# Overwrite old test with new instrumented test
@@ -1477,7 +1474,6 @@ result: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 2

def test_perfinjector_bubble_sort_unittest_results() -> None:

code = """import unittest

from code_to_optimize.bubble_sort import sorter
@@ -1499,7 +1495,7 @@ class TestPigLatin(unittest.TestCase):
"""

is_windows = platform.system() == "Windows"

if is_windows:
expected = (
"""import gc
@@ -1685,11 +1681,11 @@ class TestPigLatin(unittest.TestCase):
assert new_test_behavior is not None
assert new_test_behavior.replace('"', "'") == expected.format(
module_path="code_to_optimize.tests.unittest.test_perfinjector_bubble_sort_unittest_results_temp",
tmp_dir_path=get_run_tmp_file(Path("test_return_values")).as_posix()
tmp_dir_path=get_run_tmp_file(Path("test_return_values")).as_posix(),
).replace('"', "'")
assert new_test_perf.replace('"', "'") == expected_perf.format(
module_path="code_to_optimize.tests.unittest.test_perfinjector_bubble_sort_unittest_results_temp",
tmp_dir_path=get_run_tmp_file(Path("test_return_values")).as_posix()
tmp_dir_path=get_run_tmp_file(Path("test_return_values")).as_posix(),
).replace('"', "'")
#
# Overwrite old test with new instrumented test
@@ -1852,7 +1848,7 @@ class TestPigLatin(unittest.TestCase):
# Build expected behavior output with platform-aware imports
imports_behavior = build_expected_unittest_imports("from parameterized import parameterized")
imports_behavior += "\n\nfrom code_to_optimize.bubble_sort import sorter"

test_decorator_behavior = "" # pytest-timeout handles timeouts now
test_class_behavior = """class TestPigLatin(unittest.TestCase):
@@ -1872,7 +1868,7 @@ class TestPigLatin(unittest.TestCase):
self.assertEqual(output, expected_output)
codeflash_con.close()
"""

expected_behavior = imports_behavior + "\n\n\n" + codeflash_wrap_string + "\n" + test_class_behavior
# Build expected perf output with platform-aware imports
imports_perf = """import gc
@@ -1882,7 +1878,7 @@ import unittest
"""
# pytest-timeout handles timeouts now, no timeout_decorator needed
imports_perf += "\nfrom parameterized import parameterized\n\nfrom code_to_optimize.bubble_sort import sorter"

test_decorator_perf = "" # pytest-timeout handles timeouts now
test_class_perf = """class TestPigLatin(unittest.TestCase):
@@ -1895,7 +1891,7 @@ import unittest
output = codeflash_wrap(sorter, '{module_path}', 'TestPigLatin', 'test_sort', 'sorter', '0', codeflash_loop_index, input)
self.assertEqual(output, expected_output)
"""

expected_perf = imports_perf + "\n\n\n" + codeflash_wrap_perfonly_string + "\n" + test_class_perf
code_path = (Path(__file__).parent.resolve() / "../code_to_optimize/bubble_sort.py").resolve()
test_path = (
@@ -1933,13 +1929,13 @@ import unittest
assert new_test_behavior is not None
assert new_test_behavior.replace('"', "'") == expected_behavior.format(
module_path="code_to_optimize.tests.unittest.test_perfinjector_bubble_sort_unittest_parametrized_results_temp",
tmp_dir_path=get_run_tmp_file(Path("test_return_values")).as_posix()
tmp_dir_path=get_run_tmp_file(Path("test_return_values")).as_posix(),
).replace('"', "'")

assert new_test_perf is not None
assert new_test_perf.replace('"', "'") == expected_perf.format(
module_path="code_to_optimize.tests.unittest.test_perfinjector_bubble_sort_unittest_parametrized_results_temp",
tmp_dir_path=get_run_tmp_file(Path("test_return_values")).as_posix()
tmp_dir_path=get_run_tmp_file(Path("test_return_values")).as_posix(),
).replace('"', "'")

#
@@ -2099,10 +2095,10 @@ class TestPigLatin(unittest.TestCase):
output = sorter(input)
self.assertEqual(output, expected_output)"""

# Build expected behavior output with platform-aware imports
# Build expected behavior output with platform-aware imports
imports_behavior = build_expected_unittest_imports()
imports_behavior += "\n\nfrom code_to_optimize.bubble_sort import sorter"

test_decorator_behavior = "" # pytest-timeout handles timeouts now
test_class_behavior = """class TestPigLatin(unittest.TestCase):
@@ -2137,7 +2133,7 @@ import unittest
"""
# pytest-timeout handles timeouts now, no timeout_decorator needed
imports_perf += "\nfrom code_to_optimize.bubble_sort import sorter"

test_decorator_perf = "" # pytest-timeout handles timeouts now
test_class_perf = """class TestPigLatin(unittest.TestCase):
@@ -2154,7 +2150,7 @@ import unittest
output = codeflash_wrap(sorter, '{module_path}', 'TestPigLatin', 'test_sort', 'sorter', '2_2', codeflash_loop_index, input)
self.assertEqual(output, expected_output)
"""

expected_perf = imports_perf + "\n\n\n" + codeflash_wrap_perfonly_string + "\n" + test_class_perf
code_path = (Path(__file__).parent.resolve() / "../code_to_optimize/bubble_sort.py").resolve()
test_path = (
@@ -2192,11 +2188,11 @@ import unittest
assert new_test_behavior is not None
assert new_test_behavior.replace('"', "'") == expected_behavior.format(
module_path="code_to_optimize.tests.unittest.test_perfinjector_bubble_sort_unittest_loop_results_temp",
tmp_dir_path=get_run_tmp_file(Path("test_return_values")).as_posix()
tmp_dir_path=get_run_tmp_file(Path("test_return_values")).as_posix(),
).replace('"', "'")
assert new_test_perf.replace('"', "'") == expected_perf.format(
module_path="code_to_optimize.tests.unittest.test_perfinjector_bubble_sort_unittest_loop_results_temp",
tmp_dir_path=get_run_tmp_file(Path("test_return_values")).as_posix()
tmp_dir_path=get_run_tmp_file(Path("test_return_values")).as_posix(),
).replace('"', "'")
#
# # Overwrite old test with new instrumented test
@@ -2361,7 +2357,7 @@ class TestPigLatin(unittest.TestCase):
# Build expected behavior output with platform-aware imports
imports_behavior = build_expected_unittest_imports("from parameterized import parameterized")
imports_behavior += "\n\nfrom code_to_optimize.bubble_sort import sorter"

test_decorator_behavior = "" # pytest-timeout handles timeouts now
test_class_behavior = """class TestPigLatin(unittest.TestCase):
@@ -2392,7 +2388,7 @@ import unittest
"""
# pytest-timeout handles timeouts now, no timeout_decorator needed
imports_perf += "\nfrom parameterized import parameterized\n\nfrom code_to_optimize.bubble_sort import sorter"

test_decorator_perf = "" # pytest-timeout handles timeouts now
test_class_perf = """class TestPigLatin(unittest.TestCase):
@@ -2406,7 +2402,7 @@ import unittest
output = codeflash_wrap(sorter, '{module_path}', 'TestPigLatin', 'test_sort', 'sorter', '0_0', codeflash_loop_index, input)
self.assertEqual(output, expected_output)
"""

expected_perf = imports_perf + "\n\n\n" + codeflash_wrap_perfonly_string + "\n" + test_class_perf
code_path = (Path(__file__).parent.resolve() / "../code_to_optimize/bubble_sort.py").resolve()
test_path = (
@@ -2442,11 +2438,11 @@ import unittest
assert new_test_behavior is not None
assert new_test_behavior.replace('"', "'") == expected_behavior.format(
module_path="code_to_optimize.tests.unittest.test_perfinjector_bubble_sort_unittest_parametrized_loop_results_temp",
tmp_dir_path=get_run_tmp_file(Path("test_return_values")).as_posix()
tmp_dir_path=get_run_tmp_file(Path("test_return_values")).as_posix(),
).replace('"', "'")
assert new_test_perf.replace('"', "'") == expected_perf.format(
module_path="code_to_optimize.tests.unittest.test_perfinjector_bubble_sort_unittest_parametrized_loop_results_temp",
tmp_dir_path=get_run_tmp_file(Path("test_return_values")).as_posix()
tmp_dir_path=get_run_tmp_file(Path("test_return_values")).as_posix(),
).replace('"', "'")
#
# Overwrite old test with new instrumented test
@@ -2888,7 +2884,7 @@ def test_sort():
assert new_test is not None
assert new_test.replace('"', "'") == expected.format(
module_path="tests.pytest.test_conditional_instrumentation_temp",
tmp_dir_path=get_run_tmp_file(Path("test_return_values")).as_posix()
tmp_dir_path=get_run_tmp_file(Path("test_return_values")).as_posix(),
).replace('"', "'")
finally:
test_path.unlink(missing_ok=True)
@@ -2970,7 +2966,7 @@ def test_sort():
assert success
formatted_expected = expected.format(
module_path="tests.pytest.test_perfinjector_bubble_sort_results_temp",
tmp_dir_path=get_run_tmp_file(Path("test_return_values")).as_posix()
tmp_dir_path=get_run_tmp_file(Path("test_return_values")).as_posix(),
)
assert new_test is not None
assert new_test.replace('"', "'") == formatted_expected.replace('"', "'")
@@ -3055,7 +3051,7 @@ def test_code_replacement10() -> None:

test_file_path = tmp_path / "test_class_method_instrumentation.py"
test_file_path.write_text(code, encoding="utf-8")

func = FunctionToOptimize(
function_name="get_code_optimization_context",
parents=[FunctionParent("Optimizer", "ClassDef")],
@@ -3210,7 +3206,7 @@ import unittest
"""
# pytest-timeout handles timeouts now, no timeout_decorator needed
imports += "\nfrom parameterized import parameterized\n\nfrom code_to_optimize.sleeptime import accurate_sleepfunc"

test_decorator = "" # pytest-timeout handles timeouts now
test_class = """class TestPigLatin(unittest.TestCase):
@@ -3222,7 +3218,7 @@ import unittest
codeflash_loop_index = int(os.environ['CODEFLASH_LOOP_INDEX'])
output = codeflash_wrap(accurate_sleepfunc, '{module_path}', 'TestPigLatin', 'test_sleepfunc_sequence_short', 'accurate_sleepfunc', '0', codeflash_loop_index, n)
"""

expected = imports + "\n\n\n" + codeflash_wrap_perfonly_string + "\n" + test_class
code_path = (Path(__file__).parent.resolve() / "../code_to_optimize/sleeptime.py").resolve()
test_path = (

@@ -6,6 +6,7 @@ from argparse import Namespace
from pathlib import Path

import isort

from code_to_optimize.bubble_sort_method import BubbleSorter
from codeflash.code_utils.code_utils import get_run_tmp_file
from codeflash.code_utils.formatter import sort_imports

@@ -403,7 +404,7 @@ class BubbleSorter:
assert test_results_mutated_attr[0].return_value[0] == {"x": 1}
assert test_results_mutated_attr[0].verification_type == VerificationType.INIT_STATE_FTO
assert test_results_mutated_attr[0].stdout == ""
match,_ = compare_test_results(
match, _ = compare_test_results(
test_results, test_results_mutated_attr
) # The test should fail because the instance attribute was mutated
assert not match

@@ -458,7 +459,7 @@ class BubbleSorter:
assert test_results_new_attr[0].stdout == ""
# assert test_results_new_attr[1].return_value[1]["self"].x == 0 TODO: add self as input
# assert test_results_new_attr[1].return_value[1]["self"].y == 2 TODO: add self as input
match,_ = compare_test_results(
match, _ = compare_test_results(
test_results, test_results_new_attr
) # The test should pass because the instance attribute was not mutated, only a new one was added
assert match
@@ -2,8 +2,6 @@

from unittest.mock import patch

import pytest

from codeflash.code_utils.code_extractor import is_numerical_code
@@ -6,14 +6,7 @@ covering all patterns that might be seen in the wild.

from __future__ import annotations

import pytest

from codeflash.languages.javascript.instrument import (
ExpectCallTransformer,
TestingMode,
instrument_generated_js_test,
transform_expect_calls,
)
from codeflash.languages.javascript.instrument import TestingMode, instrument_generated_js_test, transform_expect_calls


class TestExpectCallTransformer:

@@ -656,4 +649,4 @@ describe('calculatePi', () => {
});"""
result = instrument_generated_js_test(code, "calculatePi", "calculatePi", TestingMode.BEHAVIOR)
assert result.count("codeflash.capture(") == 2
assert ".toBeCloseTo(" not in result
assert ".toBeCloseTo(" not in result
@@ -1,18 +1,12 @@
"""
Tests for JavaScript function discovery in get_functions_to_optimize.
"""Tests for JavaScript function discovery in get_functions_to_optimize.

These tests verify that JavaScript functions are correctly discovered,
filtered, and returned from the function discovery pipeline.
"""

import tempfile
import unittest.mock
from pathlib import Path

import pytest

from codeflash.discovery.functions_to_optimize import (
FunctionToOptimize,
filter_functions,
find_all_functions_in_file,
get_all_files_and_functions,

@@ -233,11 +227,7 @@ function add(a, b) {
"codeflash.discovery.functions_to_optimize.get_blocklisted_functions", return_value={}
):
filtered, count = filter_functions(
functions,
tests_root=tmp_path / "tests",
ignore_paths=[],
project_root=tmp_path,
module_root=tmp_path,
functions, tests_root=tmp_path / "tests", ignore_paths=[], project_root=tmp_path, module_root=tmp_path
)

assert js_file in filtered

@@ -258,11 +248,7 @@ function testHelper() {
modified_functions = {test_file: functions.get(test_file, [])}

filtered, count = filter_functions(
modified_functions,
tests_root=tests_dir,
ignore_paths=[],
project_root=tmp_path,
module_root=tmp_path,
modified_functions, tests_root=tests_dir, ignore_paths=[], project_root=tmp_path, module_root=tmp_path
)

assert test_file not in filtered
@@ -1,5 +1,4 @@
"""
Extensive tests for the language abstraction base types.
"""Extensive tests for the language abstraction base types.

These tests verify that the core data structures work correctly
and maintain their contracts.

@@ -92,12 +91,7 @@ class TestFunctionInfo:

def test_function_info_creation_minimal(self):
"""Test creating FunctionInfo with minimal args."""
func = FunctionInfo(
name="add",
file_path=Path("/test/example.py"),
start_line=1,
end_line=3,
)
func = FunctionInfo(name="add", file_path=Path("/test/example.py"), start_line=1, end_line=3)
assert func.name == "add"
assert func.file_path == Path("/test/example.py")
assert func.start_line == 1

@@ -131,23 +125,13 @@ class TestFunctionInfo:

def test_function_info_frozen(self):
"""Test that FunctionInfo is immutable."""
func = FunctionInfo(
name="add",
file_path=Path("/test/example.py"),
start_line=1,
end_line=3,
)
func = FunctionInfo(name="add", file_path=Path("/test/example.py"), start_line=1, end_line=3)
with pytest.raises(AttributeError):
func.name = "new_name"

def test_qualified_name_no_parents(self):
"""Test qualified_name without parents."""
func = FunctionInfo(
name="add",
file_path=Path("/test/example.py"),
start_line=1,
end_line=3,
)
func = FunctionInfo(name="add", file_path=Path("/test/example.py"), start_line=1, end_line=3)
assert func.qualified_name == "add"

def test_qualified_name_with_class(self):

@@ -168,10 +152,7 @@ class TestFunctionInfo:
file_path=Path("/test/example.py"),
start_line=1,
end_line=3,
parents=(
ParentInfo(name="Outer", type="ClassDef"),
ParentInfo(name="Inner", type="ClassDef"),
),
parents=(ParentInfo(name="Outer", type="ClassDef"), ParentInfo(name="Inner", type="ClassDef")),
)
assert func.qualified_name == "Outer.Inner.inner"


@@ -188,12 +169,7 @@ class TestFunctionInfo:

def test_class_name_without_class(self):
"""Test class_name property without class parent."""
func = FunctionInfo(
name="add",
file_path=Path("/test/example.py"),
start_line=1,
end_line=3,
)
func = FunctionInfo(name="add", file_path=Path("/test/example.py"), start_line=1, end_line=3)
assert func.class_name is None

def test_class_name_nested_function(self):

@@ -214,22 +190,14 @@ class TestFunctionInfo:
file_path=Path("/test/example.py"),
start_line=1,
end_line=3,
parents=(
ParentInfo(name="Outer", type="ClassDef"),
ParentInfo(name="Inner", type="ClassDef"),
),
parents=(ParentInfo(name="Outer", type="ClassDef"), ParentInfo(name="Inner", type="ClassDef")),
)
# Should return the immediate parent class
assert func.class_name == "Inner"

def test_top_level_parent_name_no_parents(self):
"""Test top_level_parent_name without parents."""
func = FunctionInfo(
name="add",
file_path=Path("/test/example.py"),
start_line=1,
end_line=3,
)
func = FunctionInfo(name="add", file_path=Path("/test/example.py"), start_line=1, end_line=3)
assert func.top_level_parent_name == "add"

def test_top_level_parent_name_with_parents(self):

@@ -239,10 +207,7 @@ class TestFunctionInfo:
file_path=Path("/test/example.py"),
start_line=1,
end_line=3,
parents=(
ParentInfo(name="Outer", type="ClassDef"),
ParentInfo(name="Inner", type="ClassDef"),
),
parents=(ParentInfo(name="Outer", type="ClassDef"), ParentInfo(name="Inner", type="ClassDef")),
)
assert func.top_level_parent_name == "Outer"


@@ -285,10 +250,7 @@ class TestCodeContext:

def test_code_context_creation_minimal(self):
"""Test creating CodeContext with minimal args."""
ctx = CodeContext(
target_code="def add(a, b): return a + b",
target_file=Path("/test/example.py"),
)
ctx = CodeContext(target_code="def add(a, b): return a + b", target_file=Path("/test/example.py"))
assert ctx.target_code == "def add(a, b): return a + b"
assert ctx.target_file == Path("/test/example.py")
assert ctx.helper_functions == []

@@ -325,38 +287,24 @@ class TestTestInfo:

def test_test_info_creation(self):
"""Test creating TestInfo."""
info = TestInfo(
test_name="test_add",
test_file=Path("/tests/test_calc.py"),
test_class="TestCalculator",
)
info = TestInfo(test_name="test_add", test_file=Path("/tests/test_calc.py"), test_class="TestCalculator")
assert info.test_name == "test_add"
assert info.test_file == Path("/tests/test_calc.py")
assert info.test_class == "TestCalculator"

def test_test_info_without_class(self):
"""Test TestInfo without test class."""
info = TestInfo(
test_name="test_add",
test_file=Path("/tests/test_calc.py"),
)
info = TestInfo(test_name="test_add", test_file=Path("/tests/test_calc.py"))
assert info.test_class is None

def test_full_test_path_with_class(self):
"""Test full_test_path with class."""
info = TestInfo(
test_name="test_add",
test_file=Path("/tests/test_calc.py"),
test_class="TestCalculator",
)
info = TestInfo(test_name="test_add", test_file=Path("/tests/test_calc.py"), test_class="TestCalculator")
assert info.full_test_path == "/tests/test_calc.py::TestCalculator::test_add"

def test_full_test_path_without_class(self):
"""Test full_test_path without class."""
info = TestInfo(
test_name="test_add",
test_file=Path("/tests/test_calc.py"),
)
info = TestInfo(test_name="test_add", test_file=Path("/tests/test_calc.py"))
assert info.full_test_path == "/tests/test_calc.py::test_add"


@@ -448,10 +396,7 @@ class TestConvertParentsToTuple:
self.name = name
self.type = type_

parents = [
MockParent("Outer", "ClassDef"),
MockParent("inner", "FunctionDef"),
]
parents = [MockParent("Outer", "ClassDef"), MockParent("inner", "FunctionDef")]
result = convert_parents_to_tuple(parents)

assert len(result) == 2
@@ -1,5 +1,4 @@
"""
Tests for the integrated multi-language function discovery.
"""Tests for the integrated multi-language function discovery.

These tests verify that the function discovery in functions_to_optimize.py
correctly routes to language-specific implementations.

@@ -8,10 +7,7 @@ correctly routes to language-specific implementations.
import tempfile
from pathlib import Path

import pytest

from codeflash.discovery.functions_to_optimize import (
FunctionToOptimize,
find_all_functions_in_file,
get_all_files_and_functions,
get_files_for_language,
@@ -4,16 +4,10 @@ These tests verify that the ImportResolver correctly resolves import paths
to actual file paths, enabling multi-file context extraction.
"""

import pytest
from pathlib import Path
from unittest.mock import patch, MagicMock

from codeflash.languages.javascript.import_resolver import (
ImportResolver,
ResolvedImport,
MultiFileHelperFinder,
HelperSearchContext,
)
import pytest

from codeflash.languages.javascript.import_resolver import HelperSearchContext, ImportResolver, MultiFileHelperFinder
from codeflash.languages.treesitter_utils import ImportInfo
@@ -246,16 +240,16 @@ class TestMultiFileHelperFinder:
src_dir.mkdir()

# Main file that imports helper
(src_dir / "main.ts").write_text('''
(src_dir / "main.ts").write_text("""
import { helperFunc } from './helper';

export function mainFunc() {
return helperFunc() + 1;
}
''')
""")

# Helper file
(src_dir / "helper.ts").write_text('''
(src_dir / "helper.ts").write_text("""
export function helperFunc() {
return 42;
}

@@ -263,7 +257,7 @@ export function helperFunc() {
export function unusedHelper() {
return 0;
}
''')
""")

return tmp_path
@@ -293,6 +287,7 @@ class TestExportInfo:
def js_analyzer(self):
"""Create a JavaScript analyzer."""
from codeflash.languages.treesitter_utils import TreeSitterAnalyzer, TreeSitterLanguage

return TreeSitterAnalyzer(TreeSitterLanguage.JAVASCRIPT)

def test_find_named_export_function(self, js_analyzer):

@@ -394,6 +389,7 @@ class TestCommonJSRequire:
def js_analyzer(self):
"""Create a JavaScript analyzer."""
from codeflash.languages.treesitter_utils import TreeSitterAnalyzer, TreeSitterLanguage

return TreeSitterAnalyzer(TreeSitterLanguage.JAVASCRIPT)

def test_require_default_import(self, js_analyzer):

@@ -475,6 +471,7 @@ class TestCommonJSExports:
def js_analyzer(self):
"""Create a JavaScript analyzer."""
from codeflash.languages.treesitter_utils import TreeSitterAnalyzer, TreeSitterLanguage

return TreeSitterAnalyzer(TreeSitterLanguage.JAVASCRIPT)

def test_module_exports_function(self, js_analyzer):
@ -1,5 +1,4 @@
|
|||
"""
|
||||
End-to-end integration tests for JavaScript pipeline.
|
||||
"""End-to-end integration tests for JavaScript pipeline.
|
||||
|
||||
Tests the full optimization pipeline for JavaScript:
|
||||
- Function discovery
|
||||
|
|
@ -13,11 +12,7 @@ from pathlib import Path
|
|||
|
||||
import pytest
|
||||
|
||||
from codeflash.discovery.functions_to_optimize import (
|
||||
FunctionToOptimize,
|
||||
find_all_functions_in_file,
|
||||
get_files_for_language,
|
||||
)
|
||||
from codeflash.discovery.functions_to_optimize import find_all_functions_in_file, get_files_for_language
|
||||
from codeflash.languages.base import Language
|
||||
|
||||
|
||||
|
|
@ -149,11 +144,7 @@ function multiply(a, b) {
|
|||
|
||||
# Create FunctionInfo for the add function
|
||||
func_info = FunctionInfo(
|
||||
name="add",
|
||||
file_path=Path("/tmp/test.js"),
|
||||
start_line=2,
|
||||
end_line=4,
|
||||
language=Language.JAVASCRIPT,
|
||||
name="add", file_path=Path("/tmp/test.js"), start_line=2, end_line=4, language=Language.JAVASCRIPT
|
||||
)
|
||||
|
||||
result = js_support.replace_function(original_source, func_info, new_function)
|
||||
|
|
@ -189,11 +180,7 @@ class TestJavaScriptTestDiscovery:
|
|||
# Create FunctionInfo for fibonacci function
|
||||
fib_file = js_project_dir / "fibonacci.js"
|
||||
func_info = FunctionInfo(
|
||||
name="fibonacci",
|
||||
file_path=fib_file,
|
||||
start_line=11,
|
||||
end_line=16,
|
||||
language=Language.JAVASCRIPT,
|
||||
name="fibonacci", file_path=fib_file, start_line=11, end_line=16, language=Language.JAVASCRIPT
|
||||
)
|
||||
|
||||
# Discover tests
|
||||
|
|
@ -233,19 +220,13 @@ function standalone(x) {
|
|||
assert len(functions.get(file_path, [])) >= 3
|
||||
|
||||
# Check standalone function
|
||||
standalone_fn = next(
|
||||
(fn for fn in functions[file_path] if fn.function_name == "standalone"),
|
||||
None,
|
||||
)
|
||||
standalone_fn = next((fn for fn in functions[file_path] if fn.function_name == "standalone"), None)
|
||||
assert standalone_fn is not None
|
||||
assert standalone_fn.language == "javascript"
|
||||
assert len(standalone_fn.parents) == 0
|
||||
|
||||
# Check class method
|
||||
add_fn = next(
|
||||
(fn for fn in functions[file_path] if fn.function_name == "add"),
|
||||
None,
|
||||
)
|
||||
add_fn = next((fn for fn in functions[file_path] if fn.function_name == "add"), None)
|
||||
assert add_fn is not None
|
||||
assert add_fn.language == "javascript"
|
||||
assert len(add_fn.parents) == 1
|
||||
|
|
@ -258,9 +239,7 @@ function standalone(x) {
|
|||
code_strings = CodeStringsMarkdown(
|
||||
code_strings=[
|
||||
CodeString(
|
||||
code="function add(a, b) { return a + b; }",
|
||||
file_path=Path("test.js"),
|
||||
language="javascript",
|
||||
code="function add(a, b) { return a + b; }", file_path=Path("test.js"), language="javascript"
|
||||
)
|
||||
],
|
||||
language="javascript",
@ -1,5 +1,4 @@
|
|||
"""
|
||||
Tests for JavaScript instrumentation (line profiling and tracing).
|
||||
"""Tests for JavaScript instrumentation (line profiling and tracing).
|
||||
|
||||
This module tests the line profiling and tracing instrumentation for JavaScript code.
|
||||
"""
|
||||
|
|
@ -7,8 +6,6 @@ This module tests the line profiling and tracing instrumentation for JavaScript
|
|||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from codeflash.languages.base import FunctionInfo, Language
|
||||
from codeflash.languages.javascript.line_profiler import JavaScriptLineProfiler
|
||||
from codeflash.languages.javascript.tracer import JavaScriptTracer
|
||||
|
|
@ -52,11 +49,7 @@ function add(a, b) {
|
|||
file_path = Path(f.name)
|
||||
|
||||
func_info = FunctionInfo(
|
||||
name="add",
|
||||
file_path=file_path,
|
||||
start_line=2,
|
||||
end_line=5,
|
||||
language=Language.JAVASCRIPT,
|
||||
name="add", file_path=file_path, start_line=2, end_line=5, language=Language.JAVASCRIPT
|
||||
)
|
||||
|
||||
output_file = Path("/tmp/test_profile.json")
|
||||
|
|
@ -117,11 +110,7 @@ function multiply(x, y) {
|
|||
file_path = Path(f.name)
|
||||
|
||||
func_info = FunctionInfo(
|
||||
name="multiply",
|
||||
file_path=file_path,
|
||||
start_line=2,
|
||||
end_line=4,
|
||||
language=Language.JAVASCRIPT,
|
||||
name="multiply", file_path=file_path, start_line=2, end_line=4, language=Language.JAVASCRIPT
|
||||
)
|
||||
|
||||
output_db = Path("/tmp/test_traces.db")
|
||||
|
|
@ -165,17 +154,11 @@ function greet(name) {
|
|||
file_path = Path(f.name)
|
||||
|
||||
func_info = FunctionInfo(
|
||||
name="greet",
|
||||
file_path=file_path,
|
||||
start_line=2,
|
||||
end_line=4,
|
||||
language=Language.JAVASCRIPT,
|
||||
name="greet", file_path=file_path, start_line=2, end_line=4, language=Language.JAVASCRIPT
|
||||
)
|
||||
|
||||
output_file = file_path.parent / ".codeflash" / "traces.db"
|
||||
instrumented = js_support.instrument_for_behavior(
|
||||
source, [func_info], output_file=output_file
|
||||
)
|
||||
instrumented = js_support.instrument_for_behavior(source, [func_info], output_file=output_file)
|
||||
|
||||
assert "__codeflash_tracer__" in instrumented
|
||||
assert "wrap" in instrumented
|
||||
|
|
@ -202,11 +185,7 @@ function square(n) {
|
|||
file_path = Path(f.name)
|
||||
|
||||
func_info = FunctionInfo(
|
||||
name="square",
|
||||
file_path=file_path,
|
||||
start_line=2,
|
||||
end_line=5,
|
||||
language=Language.JAVASCRIPT,
|
||||
name="square", file_path=file_path, start_line=2, end_line=5, language=Language.JAVASCRIPT
|
||||
)
|
||||
|
||||
output_file = file_path.parent / ".codeflash" / "line_profile.json"
|
||||
|
|
@ -373,10 +352,7 @@ const result = calc.fibonacci(10);
|
|||
console.log(result);
|
||||
"""
|
||||
transformed, counter = transform_standalone_calls(
|
||||
code=code,
|
||||
func_name="fibonacci",
|
||||
qualified_name="Calculator.fibonacci",
|
||||
capture_func="capture",
|
||||
code=code, func_name="fibonacci", qualified_name="Calculator.fibonacci", capture_func="capture"
|
||||
)
|
||||
|
||||
# Should transform calc.fibonacci(10) to codeflash.capture(..., calc.fibonacci.bind(calc), 10)
|
||||
|
|
@ -395,10 +371,7 @@ test('fibonacci works', () => {
|
|||
});
|
||||
"""
|
||||
transformed, counter = transform_expect_calls(
|
||||
code=code,
|
||||
func_name="fibonacci",
|
||||
qualified_name="FibonacciCalculator.fibonacci",
|
||||
capture_func="capture",
|
||||
code=code, func_name="fibonacci", qualified_name="FibonacciCalculator.fibonacci", capture_func="capture"
|
||||
)
|
||||
|
||||
# Should transform expect(calc.fibonacci(10)) to
|
||||
|
|
@ -446,10 +419,7 @@ class FibonacciCalculator {
|
|||
}
|
||||
"""
|
||||
transformed, counter = transform_standalone_calls(
|
||||
code=code,
|
||||
func_name="fibonacci",
|
||||
qualified_name="FibonacciCalculator.fibonacci",
|
||||
capture_func="capture",
|
||||
code=code, func_name="fibonacci", qualified_name="FibonacciCalculator.fibonacci", capture_func="capture"
|
||||
)
|
||||
|
||||
# The method definition should NOT be transformed
|
||||
|
|
@ -468,10 +438,7 @@ FibonacciCalculator.prototype.fibonacci = function(n) {
|
|||
};
|
||||
"""
|
||||
transformed, counter = transform_standalone_calls(
|
||||
code=code,
|
||||
func_name="fibonacci",
|
||||
qualified_name="FibonacciCalculator.fibonacci",
|
||||
capture_func="capture",
|
||||
code=code, func_name="fibonacci", qualified_name="FibonacciCalculator.fibonacci", capture_func="capture"
|
||||
)
|
||||
|
||||
# The prototype assignment should NOT be transformed
|
||||
|
|
@ -489,10 +456,7 @@ const b = calc.fibonacci(10);
|
|||
const sum = a + b;
|
||||
"""
|
||||
transformed, counter = transform_standalone_calls(
|
||||
code=code,
|
||||
func_name="fibonacci",
|
||||
qualified_name="Calculator.fibonacci",
|
||||
capture_func="capture",
|
||||
code=code, func_name="fibonacci", qualified_name="Calculator.fibonacci", capture_func="capture"
|
||||
)
|
||||
|
||||
# Should transform both calls
|
||||
|
|
@ -511,10 +475,7 @@ class Wrapper {
|
|||
}
|
||||
"""
|
||||
transformed, counter = transform_standalone_calls(
|
||||
code=code,
|
||||
func_name="fibonacci",
|
||||
qualified_name="Wrapper.fibonacci",
|
||||
capture_func="capture",
|
||||
code=code, func_name="fibonacci", qualified_name="Wrapper.fibonacci", capture_func="capture"
|
||||
)
|
||||
|
||||
# Should transform this.fibonacci(n)
|
||||
|
|
@ -524,9 +485,9 @@ class Wrapper {
|
|||
|
||||
def test_full_instrumentation_produces_valid_syntax(self):
|
||||
"""Test that full instrumentation produces syntactically valid JavaScript."""
|
||||
from codeflash.languages.javascript.instrument import _instrument_js_test_code
|
||||
from codeflash.languages import get_language_support
|
||||
from codeflash.languages.base import Language
|
||||
from codeflash.languages.javascript.instrument import _instrument_js_test_code
|
||||
|
||||
js_support = get_language_support(Language.JAVASCRIPT)
|
||||
|
||||
|
|
@ -584,11 +545,7 @@ describe('Calculator', () => {
|
|||
});
|
||||
"""
|
||||
instrumented = _instrument_js_test_code(
|
||||
code=test_code,
|
||||
func_name="add",
|
||||
test_file_path="test.js",
|
||||
mode="behavior",
|
||||
qualified_name="Calculator.add",
|
||||
code=test_code, func_name="add", test_file_path="test.js", mode="behavior", qualified_name="Calculator.add"
|
||||
)
|
||||
|
||||
# describe and test structure should be preserved
|
||||
|
|
@ -610,10 +567,7 @@ const data = await api.fetchData('http://example.com');
|
|||
console.log(data);
|
||||
"""
|
||||
transformed, counter = transform_standalone_calls(
|
||||
code=code,
|
||||
func_name="fetchData",
|
||||
qualified_name="ApiClient.fetchData",
|
||||
capture_func="capture",
|
||||
code=code, func_name="fetchData", qualified_name="ApiClient.fetchData", capture_func="capture"
|
||||
)
|
||||
|
||||
# Should preserve await
|
||||
|
|
@ -632,10 +586,7 @@ class TestInstrumentationFullStringEquality:
|
|||
code = " calc.fibonacci(10);"
|
||||
|
||||
transformed, counter = transform_standalone_calls(
|
||||
code=code,
|
||||
func_name="fibonacci",
|
||||
qualified_name="Calculator.fibonacci",
|
||||
capture_func="capture",
|
||||
code=code, func_name="fibonacci", qualified_name="Calculator.fibonacci", capture_func="capture"
|
||||
)
|
||||
|
||||
expected = " codeflash.capture('Calculator.fibonacci', '1', calc.fibonacci.bind(calc), 10);"
|
||||
|
|
@ -649,10 +600,7 @@ class TestInstrumentationFullStringEquality:
|
|||
code = " expect(calc.fibonacci(10)).toBe(55);"
|
||||
|
||||
transformed, counter = transform_expect_calls(
|
||||
code=code,
|
||||
func_name="fibonacci",
|
||||
qualified_name="Calculator.fibonacci",
|
||||
capture_func="capture",
|
||||
code=code, func_name="fibonacci", qualified_name="Calculator.fibonacci", capture_func="capture"
|
||||
)
|
||||
|
||||
expected = " expect(codeflash.capture('Calculator.fibonacci', '1', calc.fibonacci.bind(calc), 10)).toBe(55);"
|
||||
|
|
@ -684,10 +632,7 @@ class TestInstrumentationFullStringEquality:
|
|||
code = " fibonacci(10);"
|
||||
|
||||
transformed, counter = transform_standalone_calls(
|
||||
code=code,
|
||||
func_name="fibonacci",
|
||||
qualified_name="fibonacci",
|
||||
capture_func="capture",
|
||||
code=code, func_name="fibonacci", qualified_name="fibonacci", capture_func="capture"
|
||||
)
|
||||
|
||||
expected = " codeflash.capture('fibonacci', '1', fibonacci, 10);"
|
||||
|
|
@ -701,12 +646,9 @@ class TestInstrumentationFullStringEquality:
|
|||
code = " return this.fibonacci(n - 1);"
|
||||
|
||||
transformed, counter = transform_standalone_calls(
|
||||
code=code,
|
||||
func_name="fibonacci",
|
||||
qualified_name="Class.fibonacci",
|
||||
capture_func="capture",
|
||||
code=code, func_name="fibonacci", qualified_name="Class.fibonacci", capture_func="capture"
|
||||
)
|
||||
|
||||
expected = " return codeflash.capture('Class.fibonacci', '1', this.fibonacci.bind(this), n - 1);"
|
||||
assert transformed == expected, f"Expected:\n{expected}\nGot:\n{transformed}"
|
||||
assert counter == 1
|
||||
assert counter == 1
@ -1,18 +1,11 @@
|
|||
"""
|
||||
Tests for JavaScript module system detection.
|
||||
"""Tests for JavaScript module system detection.
|
||||
"""
|
||||
|
||||
import json
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from codeflash.languages.javascript.module_system import (
|
||||
ModuleSystem,
|
||||
detect_module_system,
|
||||
get_import_statement,
|
||||
)
|
||||
from codeflash.languages.javascript.module_system import ModuleSystem, detect_module_system, get_import_statement
|
||||
|
||||
|
||||
class TestModuleSystemDetection:
|
||||
|
|
@ -73,9 +66,7 @@ class TestModuleSystemDetection:
|
|||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
project_root = Path(tmpdir)
|
||||
file_path = project_root / "module.js"
|
||||
file_path.write_text(
|
||||
"const foo = require('./bar');\nmodule.exports = { baz: 1 };"
|
||||
)
|
||||
file_path.write_text("const foo = require('./bar');\nmodule.exports = { baz: 1 };")
|
||||
|
||||
result = detect_module_system(project_root, file_path)
|
||||
assert result == ModuleSystem.COMMONJS
|
||||
|
|
@ -99,9 +90,7 @@ class TestImportStatementGeneration:
|
|||
target = tmpdir / "lib" / "utils.js"
|
||||
source = tmpdir / "tests" / "utils.test.js"
|
||||
|
||||
result = get_import_statement(
|
||||
ModuleSystem.COMMONJS, target, source, ["foo", "bar"]
|
||||
)
|
||||
result = get_import_statement(ModuleSystem.COMMONJS, target, source, ["foo", "bar"])
|
||||
|
||||
assert result == "const { foo, bar } = require('../lib/utils');"
|
||||
|
||||
|
|
@ -112,9 +101,7 @@ class TestImportStatementGeneration:
|
|||
target = tmpdir / "lib" / "utils.js"
|
||||
source = tmpdir / "tests" / "utils.test.js"
|
||||
|
||||
result = get_import_statement(
|
||||
ModuleSystem.ES_MODULE, target, source, ["foo", "bar"]
|
||||
)
|
||||
result = get_import_statement(ModuleSystem.ES_MODULE, target, source, ["foo", "bar"])
|
||||
|
||||
assert result == "import { foo, bar } from '../lib/utils';"
|
||||
|
||||
|
|
@ -147,9 +134,7 @@ class TestImportStatementGeneration:
|
|||
target = tmpdir / "utils.js"
|
||||
source = tmpdir / "index.js"
|
||||
|
||||
result = get_import_statement(
|
||||
ModuleSystem.COMMONJS, target, source, ["foo"]
|
||||
)
|
||||
result = get_import_statement(ModuleSystem.COMMONJS, target, source, ["foo"])
|
||||
|
||||
assert result == "const { foo } = require('./utils');"
|
||||
|
||||
|
|
@ -160,9 +145,7 @@ class TestImportStatementGeneration:
|
|||
target = tmpdir / "lib" / "helpers" / "utils.js"
|
||||
source = tmpdir / "tests" / "test.js"
|
||||
|
||||
result = get_import_statement(
|
||||
ModuleSystem.COMMONJS, target, source, ["foo"]
|
||||
)
|
||||
result = get_import_statement(ModuleSystem.COMMONJS, target, source, ["foo"])
|
||||
|
||||
assert result == "const { foo } = require('../lib/helpers/utils');"
|
||||
|
||||
|
|
@ -173,8 +156,6 @@ class TestImportStatementGeneration:
|
|||
target = tmpdir / "utils.js"
|
||||
source = tmpdir / "tests" / "unit" / "test.js"
|
||||
|
||||
result = get_import_statement(
|
||||
ModuleSystem.COMMONJS, target, source, ["foo"]
|
||||
)
|
||||
result = get_import_statement(ModuleSystem.COMMONJS, target, source, ["foo"])
|
||||
|
||||
assert result == "const { foo } = require('../../utils');"
|
||||
assert result == "const { foo } = require('../../utils');"
|
||||
|
|
|
|||
|
|
@ -1,5 +1,4 @@
|
|||
"""
|
||||
Extensive tests for the JavaScript language support implementation.
|
||||
"""Extensive tests for the JavaScript language support implementation.
|
||||
|
||||
These tests verify that JavaScriptSupport correctly discovers functions,
|
||||
replaces code, and integrates with the codeflash language abstraction.
|
||||
|
|
@ -10,12 +9,7 @@ from pathlib import Path
|
|||
|
||||
import pytest
|
||||
|
||||
from codeflash.languages.base import (
|
||||
FunctionFilterCriteria,
|
||||
FunctionInfo,
|
||||
Language,
|
||||
ParentInfo,
|
||||
)
|
||||
from codeflash.languages.base import FunctionFilterCriteria, FunctionInfo, Language, ParentInfo
|
||||
from codeflash.languages.javascript.support import JavaScriptSupport
|
||||
|
||||
|
||||
|
|
@ -322,12 +316,7 @@ function multiply(a, b) {
|
|||
return a * b;
|
||||
}
|
||||
"""
|
||||
func = FunctionInfo(
|
||||
name="add",
|
||||
file_path=Path("/test.js"),
|
||||
start_line=1,
|
||||
end_line=3,
|
||||
)
|
||||
func = FunctionInfo(name="add", file_path=Path("/test.js"), start_line=1, end_line=3)
|
||||
new_code = """function add(a, b) {
|
||||
// Optimized
|
||||
return (a + b) | 0;
|
||||
|
|
@ -354,12 +343,7 @@ function other() {
|
|||
|
||||
// Footer
|
||||
"""
|
||||
func = FunctionInfo(
|
||||
name="target",
|
||||
file_path=Path("/test.js"),
|
||||
start_line=4,
|
||||
end_line=6,
|
||||
)
|
||||
func = FunctionInfo(name="target", file_path=Path("/test.js"), start_line=4, end_line=6)
|
||||
new_code = """function target() {
|
||||
return 42;
|
||||
}
|
||||
|
|
@ -407,12 +391,7 @@ function other() {
|
|||
|
||||
const multiply = (x, y) => x * y;
|
||||
"""
|
||||
func = FunctionInfo(
|
||||
name="add",
|
||||
file_path=Path("/test.js"),
|
||||
start_line=1,
|
||||
end_line=3,
|
||||
)
|
||||
func = FunctionInfo(name="add", file_path=Path("/test.js"), start_line=1, end_line=3)
|
||||
new_code = """const add = (a, b) => {
|
||||
return (a + b) | 0;
|
||||
};
|
||||
|
|
@ -504,18 +483,9 @@ class TestExtractCodeContext:
|
|||
f.flush()
|
||||
file_path = Path(f.name)
|
||||
|
||||
func = FunctionInfo(
|
||||
name="add",
|
||||
file_path=file_path,
|
||||
start_line=1,
|
||||
end_line=3,
|
||||
)
|
||||
func = FunctionInfo(name="add", file_path=file_path, start_line=1, end_line=3)
|
||||
|
||||
context = js_support.extract_code_context(
|
||||
func,
|
||||
file_path.parent,
|
||||
file_path.parent,
|
||||
)
|
||||
context = js_support.extract_code_context(func, file_path.parent, file_path.parent)
|
||||
|
||||
assert "function add" in context.target_code
|
||||
assert "return a + b" in context.target_code
|
||||
|
|
@ -540,11 +510,7 @@ function main(a) {
|
|||
functions = js_support.discover_functions(file_path)
|
||||
main_func = next(f for f in functions if f.name == "main")
|
||||
|
||||
context = js_support.extract_code_context(
|
||||
main_func,
|
||||
file_path.parent,
|
||||
file_path.parent,
|
||||
)
|
||||
context = js_support.extract_code_context(main_func, file_path.parent, file_path.parent)
|
||||
|
||||
assert "function main" in context.target_code
|
||||
# Helper should be found
|
||||
|
|
@ -689,6 +655,7 @@ describe('Math functions', () => {
|
|||
|
||||
source = file_path.read_text()
|
||||
from codeflash.languages.treesitter_utils import get_analyzer_for_file
|
||||
|
||||
analyzer = get_analyzer_for_file(file_path)
|
||||
test_names = js_support._find_jest_tests(source, analyzer)
|
||||
|
||||
|
|
@ -1035,10 +1002,14 @@ class TestClassMethodReplacement:
|
|||
parents=(ParentInfo(name="Math", type="ClassDef"),),
|
||||
is_method=True,
|
||||
)
|
||||
source = js_support.replace_function(source, add_func, """ add(a, b) {
|
||||
source = js_support.replace_function(
|
||||
source,
|
||||
add_func,
|
||||
""" add(a, b) {
|
||||
return (a + b) | 0;
|
||||
}
|
||||
""")
|
||||
""",
|
||||
)
|
||||
|
||||
assert js_support.validate_syntax(source) is True
|
||||
|
||||
|
|
@ -1346,8 +1317,7 @@ class Counter {
|
|||
module.exports = { Counter };
|
||||
"""
|
||||
assert result == expected_result, (
|
||||
f"Replacement result does not match expected.\n"
|
||||
f"Expected:\n{expected_result}\n\nGot:\n{result}"
|
||||
f"Replacement result does not match expected.\nExpected:\n{expected_result}\n\nGot:\n{result}"
|
||||
)
|
||||
assert js_support.validate_syntax(result) is True
|
||||
|
||||
|
|
@ -1455,8 +1425,7 @@ class User {
|
|||
export { User };
|
||||
"""
|
||||
assert result == expected_result, (
|
||||
f"Replacement result does not match expected.\n"
|
||||
f"Expected:\n{expected_result}\n\nGot:\n{result}"
|
||||
f"Replacement result does not match expected.\nExpected:\n{expected_result}\n\nGot:\n{result}"
|
||||
)
|
||||
assert ts_support.validate_syntax(result) is True
|
||||
|
||||
|
|
@ -1544,8 +1513,7 @@ class Calculator {
|
|||
}
|
||||
"""
|
||||
assert result == expected_result, (
|
||||
f"Replacement result does not match expected.\n"
|
||||
f"Expected:\n{expected_result}\n\nGot:\n{result}"
|
||||
f"Replacement result does not match expected.\nExpected:\n{expected_result}\n\nGot:\n{result}"
|
||||
)
|
||||
assert js_support.validate_syntax(result) is True
|
||||
|
||||
|
|
@ -1631,7 +1599,6 @@ class MathUtils {
|
|||
module.exports = { MathUtils };
|
||||
"""
|
||||
assert result == expected_result, (
|
||||
f"Replacement result does not match expected.\n"
|
||||
f"Expected:\n{expected_result}\n\nGot:\n{result}"
|
||||
f"Replacement result does not match expected.\nExpected:\n{expected_result}\n\nGot:\n{result}"
|
||||
)
|
||||
assert js_support.validate_syntax(result) is True
@ -1,5 +1,4 @@
|
|||
"""
|
||||
Comprehensive tests for JavaScript test discovery functionality.
|
||||
"""Comprehensive tests for JavaScript test discovery functionality.
|
||||
|
||||
These tests verify that the JavaScript language support correctly discovers
|
||||
Jest tests and maps them to source functions, similar to Python's test discovery tests.
|
||||
|
|
@ -10,7 +9,6 @@ from pathlib import Path
|
|||
|
||||
import pytest
|
||||
|
||||
from codeflash.languages.base import FunctionInfo, Language
|
||||
from codeflash.languages.javascript.support import JavaScriptSupport
|
||||
|
||||
|
||||
|
|
@ -630,6 +628,7 @@ it('third test', () => {});
|
|||
|
||||
source = file_path.read_text()
|
||||
from codeflash.languages.treesitter_utils import get_analyzer_for_file
|
||||
|
||||
analyzer = get_analyzer_for_file(file_path)
|
||||
test_names = js_support._find_jest_tests(source, analyzer)
|
||||
|
||||
|
|
@ -654,6 +653,7 @@ describe('Suite B', () => {
|
|||
|
||||
source = file_path.read_text()
|
||||
from codeflash.languages.treesitter_utils import get_analyzer_for_file
|
||||
|
||||
analyzer = get_analyzer_for_file(file_path)
|
||||
test_names = js_support._find_jest_tests(source, analyzer)
|
||||
|
||||
|
|
@ -677,6 +677,7 @@ describe('Outer', () => {
|
|||
|
||||
source = file_path.read_text()
|
||||
from codeflash.languages.treesitter_utils import get_analyzer_for_file
|
||||
|
||||
analyzer = get_analyzer_for_file(file_path)
|
||||
test_names = js_support._find_jest_tests(source, analyzer)
|
||||
|
||||
|
|
@ -700,6 +701,7 @@ describe.skip('skipped describe', () => {
|
|||
|
||||
source = file_path.read_text()
|
||||
from codeflash.languages.treesitter_utils import get_analyzer_for_file
|
||||
|
||||
analyzer = get_analyzer_for_file(file_path)
|
||||
test_names = js_support._find_jest_tests(source, analyzer)
|
||||
|
||||
|
|
@ -720,6 +722,7 @@ describe.only('only describe', () => {
|
|||
|
||||
source = file_path.read_text()
|
||||
from codeflash.languages.treesitter_utils import get_analyzer_for_file
|
||||
|
||||
analyzer = get_analyzer_for_file(file_path)
|
||||
test_names = js_support._find_jest_tests(source, analyzer)
|
||||
|
||||
|
|
@ -737,6 +740,7 @@ describe('describe single', () => {});
|
|||
|
||||
source = file_path.read_text()
|
||||
from codeflash.languages.treesitter_utils import get_analyzer_for_file
|
||||
|
||||
analyzer = get_analyzer_for_file(file_path)
|
||||
test_names = js_support._find_jest_tests(source, analyzer)
|
||||
|
||||
|
|
@ -746,15 +750,16 @@ describe('describe single', () => {});
|
|||
def test_find_tests_with_double_quotes(self, js_support):
|
||||
"""Test finding tests with double-quoted names."""
|
||||
with tempfile.NamedTemporaryFile(suffix=".test.js", mode="w", delete=False) as f:
|
||||
f.write('''
|
||||
f.write("""
|
||||
test("double quotes", () => {});
|
||||
describe("describe double", () => {});
|
||||
''')
|
||||
""")
|
||||
f.flush()
|
||||
file_path = Path(f.name)
|
||||
|
||||
source = file_path.read_text()
|
||||
from codeflash.languages.treesitter_utils import get_analyzer_for_file
|
||||
|
||||
analyzer = get_analyzer_for_file(file_path)
|
||||
test_names = js_support._find_jest_tests(source, analyzer)
|
||||
|
||||
|
|
@ -770,6 +775,7 @@ describe("describe double", () => {});
|
|||
|
||||
source = file_path.read_text()
|
||||
from codeflash.languages.treesitter_utils import get_analyzer_for_file
|
||||
|
||||
analyzer = get_analyzer_for_file(file_path)
|
||||
test_names = js_support._find_jest_tests(source, analyzer)
|
||||
|
||||
|
|
@ -1015,6 +1021,7 @@ describe('日本語テスト', () => {
|
|||
|
||||
source = file_path.read_text(encoding="utf-8")
|
||||
from codeflash.languages.treesitter_utils import get_analyzer_for_file
|
||||
|
||||
analyzer = get_analyzer_for_file(file_path)
|
||||
test_names = js_support._find_jest_tests(source, analyzer)
|
||||
|
||||
|
|
@ -1042,6 +1049,7 @@ test.each([
|
|||
|
||||
source = file_path.read_text()
|
||||
from codeflash.languages.treesitter_utils import get_analyzer_for_file
|
||||
|
||||
analyzer = get_analyzer_for_file(file_path)
|
||||
test_names = js_support._find_jest_tests(source, analyzer)
|
||||
|
||||
|
|
@ -1067,6 +1075,7 @@ describe.each([
|
|||
|
||||
source = file_path.read_text()
|
||||
from codeflash.languages.treesitter_utils import get_analyzer_for_file
|
||||
|
||||
analyzer = get_analyzer_for_file(file_path)
|
||||
test_names = js_support._find_jest_tests(source, analyzer)
|
||||
|
||||
|
|
@ -1091,6 +1100,7 @@ describe('Math operations', () => {
|
|||
|
||||
source = file_path.read_text()
|
||||
from codeflash.languages.treesitter_utils import get_analyzer_for_file
|
||||
|
||||
analyzer = get_analyzer_for_file(file_path)
|
||||
test_names = js_support._find_jest_tests(source, analyzer)
|
||||
|
||||
|
|
@ -1178,8 +1188,7 @@ describe('formatName', () => {
|
|||
for test_list in tests.values():
|
||||
all_test_names.extend([t.test_name for t in test_list])
|
||||
|
||||
assert any("validateEmail" in name or "accepts valid email" in name
|
||||
for name in all_test_names)
|
||||
assert any("validateEmail" in name or "accepts valid email" in name for name in all_test_names)
|
||||
|
||||
def test_discovery_with_fixtures(self, js_support):
|
||||
"""Test discovery when test file uses beforeEach/afterEach."""
|
||||
|
|
@ -1449,6 +1458,7 @@ testCases.forEach(name => {
|
|||
|
||||
source = file_path.read_text()
|
||||
from codeflash.languages.treesitter_utils import get_analyzer_for_file
|
||||
|
||||
analyzer = get_analyzer_for_file(file_path)
|
||||
test_names = js_support._find_jest_tests(source, analyzer)
|
||||
|
||||
|
|
@ -1476,6 +1486,7 @@ describe('conditional tests', () => {
|
|||
|
||||
source = file_path.read_text()
|
||||
from codeflash.languages.treesitter_utils import get_analyzer_for_file
|
||||
|
||||
analyzer = get_analyzer_for_file(file_path)
|
||||
test_names = js_support._find_jest_tests(source, analyzer)
|
||||
|
||||
|
|
@ -1499,6 +1510,7 @@ test('slow test', () => {
|
|||
|
||||
source = file_path.read_text()
|
||||
from codeflash.languages.treesitter_utils import get_analyzer_for_file
|
||||
|
||||
analyzer = get_analyzer_for_file(file_path)
|
||||
test_names = js_support._find_jest_tests(source, analyzer)
|
||||
|
||||
|
|
@ -1521,6 +1533,7 @@ test.todo('also needs implementation');
|
|||
|
||||
source = file_path.read_text()
|
||||
from codeflash.languages.treesitter_utils import get_analyzer_for_file
|
||||
|
||||
analyzer = get_analyzer_for_file(file_path)
|
||||
test_names = js_support._find_jest_tests(source, analyzer)
|
||||
|
||||
|
|
@ -1543,6 +1556,7 @@ test.concurrent('concurrent test 2', async () => {
|
|||
|
||||
source = file_path.read_text()
|
||||
from codeflash.languages.treesitter_utils import get_analyzer_for_file
|
||||
|
||||
analyzer = get_analyzer_for_file(file_path)
|
||||
test_names = js_support._find_jest_tests(source, analyzer)
|
||||
|
||||
|
|
@ -1641,6 +1655,7 @@ describe('Array', function() {
|
|||
|
||||
source = file_path.read_text()
|
||||
from codeflash.languages.treesitter_utils import get_analyzer_for_file
|
||||
|
||||
analyzer = get_analyzer_for_file(file_path)
|
||||
test_names = js_support._find_jest_tests(source, analyzer)
|
||||
|
||||
|
|
@ -1671,6 +1686,7 @@ describe('User', () => {
|
|||
|
||||
source = file_path.read_text()
|
||||
from codeflash.languages.treesitter_utils import get_analyzer_for_file
|
||||
|
||||
analyzer = get_analyzer_for_file(file_path)
|
||||
test_names = js_support._find_jest_tests(source, analyzer)
|
||||
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@ import shutil
|
|||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
|
||||
from codeflash.languages.base import Language
|
||||
from codeflash.languages.javascript.support import JavaScriptSupport, TypeScriptSupport
|
||||
|
|
@ -86,9 +87,7 @@ class Calculator {
|
|||
|
||||
assert context.target_code is not None, "target_code should not be None"
|
||||
assert context.target_code.strip() == expected_code.strip(), (
|
||||
f"Extracted code does not match expected.\n"
|
||||
f"Expected:\n{expected_code}\n\n"
|
||||
f"Got:\n{context.target_code}"
|
||||
f"Extracted code does not match expected.\nExpected:\n{expected_code}\n\nGot:\n{context.target_code}"
|
||||
)
|
||||
|
||||
def test_extract_context_includes_direct_helpers(self, js_support, cjs_project):
|
||||
|
|
@ -175,9 +174,7 @@ class Calculator {
|
|||
}"""
|
||||
|
||||
assert context.target_code.strip() == expected_code.strip(), (
|
||||
f"Extracted code does not match expected.\n"
|
||||
f"Expected:\n{expected_code}\n\n"
|
||||
f"Got:\n{context.target_code}"
|
||||
f"Extracted code does not match expected.\nExpected:\n{expected_code}\n\nGot:\n{context.target_code}"
|
||||
)
|
||||
|
||||
def test_extract_compound_interest_helpers(self, js_support, cjs_project):
|
||||
|
|
@ -282,9 +279,7 @@ function validateInput(value, name) {
|
|||
|
||||
assert len(context.imports) == 2, f"Expected 2 imports, got {len(context.imports)}: {context.imports}"
|
||||
assert context.imports == expected_imports, (
|
||||
f"Imports do not match expected.\n"
|
||||
f"Expected:\n{expected_imports}\n\n"
|
||||
f"Got:\n{context.imports}"
|
||||
f"Imports do not match expected.\nExpected:\n{expected_imports}\n\nGot:\n{context.imports}"
|
||||
)
|
||||
|
||||
def test_extract_static_method(self, js_support, cjs_project):
|
||||
|
|
@ -314,9 +309,7 @@ class Calculator {
|
|||
}"""
|
||||
|
||||
assert context.target_code.strip() == expected_code.strip(), (
|
||||
f"Extracted code does not match expected.\n"
|
||||
f"Expected:\n{expected_code}\n\n"
|
||||
f"Got:\n{context.target_code}"
|
||||
f"Extracted code does not match expected.\nExpected:\n{expected_code}\n\nGot:\n{context.target_code}"
|
||||
)
|
||||
|
||||
# quickAdd uses add helper from math_utils
|
||||
|
|
@ -397,9 +390,7 @@ class Calculator {
|
|||
}"""
|
||||
|
||||
assert context.target_code.strip() == expected_code.strip(), (
|
||||
f"Extracted code does not match expected.\n"
|
||||
f"Expected:\n{expected_code}\n\n"
|
||||
f"Got:\n{context.target_code}"
|
||||
f"Extracted code does not match expected.\nExpected:\n{expected_code}\n\nGot:\n{context.target_code}"
|
||||
)
|
||||
|
||||
# ESM permutation uses factorial helper
|
||||
|
|
@ -460,9 +451,7 @@ class Calculator {
|
|||
}"""
|
||||
|
||||
assert context.target_code.strip() == expected_code.strip(), (
|
||||
f"Extracted code does not match expected.\n"
|
||||
f"Expected:\n{expected_code}\n\n"
|
||||
f"Got:\n{context.target_code}"
|
||||
f"Extracted code does not match expected.\nExpected:\n{expected_code}\n\nGot:\n{context.target_code}"
|
||||
)
|
||||
|
||||
expected_imports = [
|
||||
|
|
@ -472,9 +461,7 @@ class Calculator {
|
|||
|
||||
assert len(context.imports) == 2, f"Expected 2 imports, got {len(context.imports)}: {context.imports}"
|
||||
assert context.imports == expected_imports, (
|
||||
f"Imports do not match expected.\n"
|
||||
f"Expected:\n{expected_imports}\n\n"
|
||||
f"Got:\n{context.imports}"
|
||||
f"Imports do not match expected.\nExpected:\n{expected_imports}\n\nGot:\n{context.imports}"
|
||||
)
|
||||
|
||||
# ESM compound interest uses 4 helpers
|
||||
|
|
@ -593,9 +580,7 @@ class Calculator {
|
|||
}"""
|
||||
|
||||
assert context.target_code.strip() == expected_code.strip(), (
|
||||
f"Extracted code does not match expected.\n"
|
||||
f"Expected:\n{expected_code}\n\n"
|
||||
f"Got:\n{context.target_code}"
|
||||
f"Extracted code does not match expected.\nExpected:\n{expected_code}\n\nGot:\n{context.target_code}"
|
||||
)
|
||||
|
||||
# TypeScript permutation uses factorial helper
|
||||
|
|
@ -659,9 +644,7 @@ class Calculator {
|
|||
}"""
|
||||
|
||||
assert context.target_code.strip() == expected_code.strip(), (
|
||||
f"Extracted code does not match expected.\n"
|
||||
f"Expected:\n{expected_code}\n\n"
|
||||
f"Got:\n{context.target_code}"
|
||||
f"Extracted code does not match expected.\nExpected:\n{expected_code}\n\nGot:\n{context.target_code}"
|
||||
)
|
||||
|
||||
# TypeScript compound interest uses 4 helpers
|
||||
|
|
@ -890,9 +873,7 @@ module.exports = { Counter };
|
|||
functions = js_support.discover_functions(test_file)
|
||||
increment_func = next(f for f in functions if f.name == "increment")
|
||||
|
||||
context = js_support.extract_code_context(
|
||||
function=increment_func, project_root=tmp_path, module_root=tmp_path
|
||||
)
|
||||
context = js_support.extract_code_context(function=increment_func, project_root=tmp_path, module_root=tmp_path)
|
||||
|
||||
expected_code = """\
|
||||
class Counter {
|
||||
|
|
@ -907,9 +888,7 @@ class Counter {
|
|||
}"""
|
||||
|
||||
assert context.target_code.strip() == expected_code.strip(), (
|
||||
f"Extracted code does not match expected.\n"
|
||||
f"Expected:\n{expected_code}\n\n"
|
||||
f"Got:\n{context.target_code}"
|
||||
f"Extracted code does not match expected.\nExpected:\n{expected_code}\n\nGot:\n{context.target_code}"
|
||||
)
|
||||
|
||||
def test_method_extraction_class_without_constructor(self, js_support, tmp_path):
|
||||
|
|
@ -933,9 +912,7 @@ module.exports = { MathUtils };
|
|||
functions = js_support.discover_functions(test_file)
|
||||
add_func = next(f for f in functions if f.name == "add")
|
||||
|
||||
context = js_support.extract_code_context(
|
||||
function=add_func, project_root=tmp_path, module_root=tmp_path
|
||||
)
|
||||
context = js_support.extract_code_context(function=add_func, project_root=tmp_path, module_root=tmp_path)
|
||||
|
||||
expected_code = """\
|
||||
class MathUtils {
|
||||
|
|
@ -945,9 +922,7 @@ class MathUtils {
|
|||
}"""
|
||||
|
||||
assert context.target_code.strip() == expected_code.strip(), (
|
||||
f"Extracted code does not match expected.\n"
|
||||
f"Expected:\n{expected_code}\n\n"
|
||||
f"Got:\n{context.target_code}"
|
||||
f"Extracted code does not match expected.\nExpected:\n{expected_code}\n\nGot:\n{context.target_code}"
|
||||
)
|
||||
|
||||
def test_typescript_method_extraction_includes_fields(self, ts_support, tmp_path):
|
||||
|
|
@ -975,9 +950,7 @@ export { User };
|
|||
functions = ts_support.discover_functions(test_file)
|
||||
get_name_func = next(f for f in functions if f.name == "getName")
|
||||
|
||||
context = ts_support.extract_code_context(
|
||||
function=get_name_func, project_root=tmp_path, module_root=tmp_path
|
||||
)
|
||||
context = ts_support.extract_code_context(function=get_name_func, project_root=tmp_path, module_root=tmp_path)
|
||||
|
||||
expected_code = """\
|
||||
class User {
|
||||
|
|
@ -995,9 +968,7 @@ class User {
|
|||
}"""
|
||||
|
||||
assert context.target_code.strip() == expected_code.strip(), (
|
||||
f"Extracted code does not match expected.\n"
|
||||
f"Expected:\n{expected_code}\n\n"
|
||||
f"Got:\n{context.target_code}"
|
||||
f"Extracted code does not match expected.\nExpected:\n{expected_code}\n\nGot:\n{context.target_code}"
|
||||
)
|
||||
|
||||
def test_typescript_fields_only_no_constructor(self, ts_support, tmp_path):
|
||||
|
|
@ -1020,9 +991,7 @@ export { Config };
|
|||
functions = ts_support.discover_functions(test_file)
|
||||
get_url_func = next(f for f in functions if f.name == "getUrl")
|
||||
|
||||
context = ts_support.extract_code_context(
|
||||
function=get_url_func, project_root=tmp_path, module_root=tmp_path
|
||||
)
|
||||
context = ts_support.extract_code_context(function=get_url_func, project_root=tmp_path, module_root=tmp_path)
|
||||
|
||||
expected_code = """\
|
||||
class Config {
|
||||
|
|
@ -1035,9 +1004,7 @@ class Config {
|
|||
}"""
|
||||
|
||||
assert context.target_code.strip() == expected_code.strip(), (
|
||||
f"Extracted code does not match expected.\n"
|
||||
f"Expected:\n{expected_code}\n\n"
|
||||
f"Got:\n{context.target_code}"
|
||||
f"Extracted code does not match expected.\nExpected:\n{expected_code}\n\nGot:\n{context.target_code}"
|
||||
)
|
||||
|
||||
def test_constructor_with_jsdoc(self, js_support, tmp_path):
|
||||
|
|
@ -1065,9 +1032,7 @@ module.exports = { Logger };
|
|||
functions = js_support.discover_functions(test_file)
|
||||
get_prefix_func = next(f for f in functions if f.name == "getPrefix")
|
||||
|
||||
context = js_support.extract_code_context(
|
||||
function=get_prefix_func, project_root=tmp_path, module_root=tmp_path
|
||||
)
|
||||
context = js_support.extract_code_context(function=get_prefix_func, project_root=tmp_path, module_root=tmp_path)
|
||||
|
||||
expected_code = """\
|
||||
class Logger {
|
||||
|
|
@ -1085,9 +1050,7 @@ class Logger {
|
|||
}"""
|
||||
|
||||
assert context.target_code.strip() == expected_code.strip(), (
|
||||
f"Extracted code does not match expected.\n"
|
||||
f"Expected:\n{expected_code}\n\n"
|
||||
f"Got:\n{context.target_code}"
|
||||
f"Extracted code does not match expected.\nExpected:\n{expected_code}\n\nGot:\n{context.target_code}"
|
||||
)
|
||||
|
||||
def test_static_method_includes_constructor(self, js_support, tmp_path):
|
||||
|
|
@ -1111,9 +1074,7 @@ module.exports = { Factory };
functions = js_support.discover_functions(test_file)
create_func = next(f for f in functions if f.name == "create")

context = js_support.extract_code_context(
function=create_func, project_root=tmp_path, module_root=tmp_path
)
context = js_support.extract_code_context(function=create_func, project_root=tmp_path, module_root=tmp_path)

expected_code = """\
class Factory {

@ -1127,9 +1088,7 @@ class Factory {
}"""

assert context.target_code.strip() == expected_code.strip(), (
f"Extracted code does not match expected.\n"
f"Expected:\n{expected_code}\n\n"
f"Got:\n{context.target_code}"
f"Extracted code does not match expected.\nExpected:\n{expected_code}\n\nGot:\n{context.target_code}"
)

@ -1264,9 +1223,7 @@ export { distance };
functions = ts_support.discover_functions(test_file)
distance_func = next(f for f in functions if f.name == "distance")

context = ts_support.extract_code_context(
function=distance_func, project_root=tmp_path, module_root=tmp_path
)
context = ts_support.extract_code_context(function=distance_func, project_root=tmp_path, module_root=tmp_path)

# Type definition should be in read-only context with exact match
expected_read_only = """\
@ -1310,9 +1267,7 @@ export { processStatus };
functions = ts_support.discover_functions(test_file)
process_func = next(f for f in functions if f.name == "processStatus")

context = ts_support.extract_code_context(
function=process_func, project_root=tmp_path, module_root=tmp_path
)
context = ts_support.extract_code_context(function=process_func, project_root=tmp_path, module_root=tmp_path)

# Enum should be in read-only context with exact match
expected_read_only = """\

@ -1349,9 +1304,7 @@ export { compute };
functions = ts_support.discover_functions(test_file)
compute_func = next(f for f in functions if f.name == "compute")

context = ts_support.extract_code_context(
function=compute_func, project_root=tmp_path, module_root=tmp_path
)
context = ts_support.extract_code_context(function=compute_func, project_root=tmp_path, module_root=tmp_path)

# Type alias should be in read-only context with exact match
expected_read_only = """\

@ -1428,9 +1381,7 @@ export { add };
functions = ts_support.discover_functions(test_file)
add_func = next(f for f in functions if f.name == "add")

context = ts_support.extract_code_context(
function=add_func, project_root=tmp_path, module_root=tmp_path
)
context = ts_support.extract_code_context(function=add_func, project_root=tmp_path, module_root=tmp_path)

# No type definitions should be extracted for primitives - exact empty match
assert context.read_only_context == "", (
@ -1564,9 +1515,7 @@ export { greetUser };
functions = ts_support.discover_functions(test_file)
greet_func = next(f for f in functions if f.name == "greetUser")

context = ts_support.extract_code_context(
function=greet_func, project_root=tmp_path, module_root=tmp_path
)
context = ts_support.extract_code_context(function=greet_func, project_root=tmp_path, module_root=tmp_path)

# JSDoc should be included with the interface - exact match
expected_read_only = """\

@ -12,6 +12,7 @@ import shutil
from pathlib import Path

import pytest

from codeflash.languages.javascript.module_system import (
ModuleSystem,
convert_commonjs_to_esm,
@ -1,5 +1,4 @@
"""
Regression tests for Python/JavaScript language support parity.
"""Regression tests for Python/JavaScript language support parity.

These tests ensure that the JavaScript implementation maintains feature parity
with the Python implementation. Each test class tests equivalent functionality

@ -15,12 +14,7 @@ from typing import NamedTuple

import pytest

from codeflash.languages.base import (
FunctionFilterCriteria,
FunctionInfo,
Language,
ParentInfo,
)
from codeflash.languages.base import FunctionFilterCriteria, FunctionInfo, Language, ParentInfo
from codeflash.languages.javascript.support import JavaScriptSupport
from codeflash.languages.python.support import PythonSupport

@ -732,8 +726,8 @@ function other() {
js_method_line = next(l for l in js_lines if "add(a, b)" in l)

# Both should have indentation (4 spaces)
assert py_method_line.startswith(" "), f"Python method should be indented: {repr(py_method_line)}"
assert js_method_line.startswith(" "), f"JavaScript method should be indented: {repr(js_method_line)}"
assert py_method_line.startswith(" "), f"Python method should be indented: {py_method_line!r}"
assert js_method_line.startswith(" "), f"JavaScript method should be indented: {js_method_line!r}"

# ============================================================================
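The hunk above swaps repr() calls inside f-strings for the !r conversion flag, which applies repr() to the interpolated value. A small sketch with a hypothetical string, not taken from the test suite:

py_method_line = "    def add(self, a, b):"
# f"{x!r}" and f"{repr(x)}" render the same text; !r is the lint-preferred spelling.
assert f"{py_method_line!r}" == f"{repr(py_method_line)}"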
@ -79,6 +79,7 @@ function findMin(numbers) {
"""

from pathlib import Path

from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.languages.registry import get_language_support
from codeflash.models.models import CodeOptimizationContext, CodeStringsMarkdown

@ -90,6 +91,7 @@ class Args:
disable_imports_sorting = True
formatter_cmds = ["disabled"]


def test_js_replcement() -> None:
try:
root_dir = Path(__file__).parent.parent.parent.resolve()

@ -128,11 +130,14 @@ def test_js_replcement() -> None:
func_optimizer = FunctionOptimizer(function_to_optimize=func, test_cfg=test_config)
result = func_optimizer.get_code_optimization_context()
from codeflash.either import is_successful

if not is_successful(result):
import pytest
pytest.skip(f"Context extraction not fully implemented for JS: {result.failure() if hasattr(result, 'failure') else result}")
code_context: CodeOptimizationContext = result.unwrap()

pytest.skip(
f"Context extraction not fully implemented for JS: {result.failure() if hasattr(result, 'failure') else result}"
)
code_context: CodeOptimizationContext = result.unwrap()

original_helper_code: dict[Path, str] = {}
helper_function_paths = {hf.file_path for hf in code_context.helper_functions}

@ -143,7 +148,9 @@ def test_js_replcement() -> None:

func_optimizer.args = Args()
did_update = func_optimizer.replace_function_and_helpers_with_optimized_code(
code_context=code_context, optimized_code=CodeStringsMarkdown.parse_markdown_code(new_code), original_helper_code=original_helper_code
code_context=code_context,
optimized_code=CodeStringsMarkdown.parse_markdown_code(new_code),
original_helper_code=original_helper_code,
)

assert did_update, "Expected code to be updated"
@ -319,7 +326,9 @@ module.exports = {
"""

assert main_code == expected_main, f"Main file mismatch.\n\nActual:\n{main_code}\n\nExpected:\n{expected_main}"
assert helper_code == expected_helper, f"Helper file mismatch.\n\nActual:\n{helper_code}\n\nExpected:\n{expected_helper}"
assert helper_code == expected_helper, (
f"Helper file mismatch.\n\nActual:\n{helper_code}\n\nExpected:\n{expected_helper}"
)

finally:
main_file.write_text(original_main, encoding="utf-8")
@ -1,5 +1,4 @@
"""
Extensive tests for the Python language support implementation.
"""Extensive tests for the Python language support implementation.

These tests verify that PythonSupport correctly discovers functions,
replaces code, and integrates with existing codeflash functionality.

@ -10,12 +9,7 @@ from pathlib import Path

import pytest

from codeflash.languages.base import (
FunctionFilterCriteria,
FunctionInfo,
Language,
ParentInfo,
)
from codeflash.languages.base import FunctionFilterCriteria, FunctionInfo, Language, ParentInfo
from codeflash.languages.python.support import PythonSupport
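The docstring hunks in these test modules move the summary sentence onto the same line as the opening quotes instead of starting it on the second line. A minimal sketch, assuming a ruff/pydocstyle D212-style rule and hypothetical function names:

def before():
    """
    Extensive tests for something.
    """

def after():
    """Extensive tests for something.

    Further description continues on the following lines.
    """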
@ -269,12 +263,7 @@ class TestReplaceFunction:
def multiply(a, b):
return a * b
"""
func = FunctionInfo(
name="add",
file_path=Path("/test.py"),
start_line=1,
end_line=2,
)
func = FunctionInfo(name="add", file_path=Path("/test.py"), start_line=1, end_line=2)
new_code = """def add(a, b):
# Optimized
return (a + b) | 0
@ -298,12 +287,7 @@ def other():

# Footer
"""
func = FunctionInfo(
name="target",
file_path=Path("/test.py"),
start_line=4,
end_line=5,
)
func = FunctionInfo(name="target", file_path=Path("/test.py"), start_line=4, end_line=5)
new_code = """def target():
return 42
"""

@ -347,12 +331,7 @@ def other():
def second():
return 2
"""
func = FunctionInfo(
name="first",
file_path=Path("/test.py"),
start_line=1,
end_line=2,
)
func = FunctionInfo(name="first", file_path=Path("/test.py"), start_line=1, end_line=2)
new_code = """def first():
return 100
"""

@ -369,12 +348,7 @@ def second():
def last():
return 999
"""
func = FunctionInfo(
name="last",
file_path=Path("/test.py"),
start_line=4,
end_line=5,
)
func = FunctionInfo(name="last", file_path=Path("/test.py"), start_line=4, end_line=5)
new_code = """def last():
return 1000
"""
@ -388,12 +362,7 @@ def last():
source = """def only():
return 42
"""
func = FunctionInfo(
name="only",
file_path=Path("/test.py"),
start_line=1,
end_line=2,
)
func = FunctionInfo(name="only", file_path=Path("/test.py"), start_line=1, end_line=2)
new_code = """def only():
return 100
"""

@ -505,18 +474,9 @@ class TestExtractCodeContext:
f.flush()
file_path = Path(f.name)

func = FunctionInfo(
name="add",
file_path=file_path,
start_line=1,
end_line=2,
)
func = FunctionInfo(name="add", file_path=file_path, start_line=1, end_line=2)

context = python_support.extract_code_context(
func,
file_path.parent,
file_path.parent,
)
context = python_support.extract_code_context(func, file_path.parent, file_path.parent)

assert "def add" in context.target_code
assert "return a + b" in context.target_code
@ -1,5 +1,4 @@
"""
Extensive tests for the language registry module.
"""Extensive tests for the language registry module.

These tests verify that language registration, lookup, and detection
work correctly.

@ -10,7 +9,7 @@ from pathlib import Path

import pytest

from codeflash.languages.base import Language, LanguageSupport
from codeflash.languages.base import Language
from codeflash.languages.registry import (
UnsupportedLanguageError,
clear_cache,

@ -28,7 +27,6 @@ from codeflash.languages.registry import (
def setup_registry():
"""Ensure PythonSupport is registered before each test."""
# Import to trigger registration
from codeflash.languages.python import PythonSupport

yield
# Clear cache after each test to avoid side effects

@ -92,9 +90,7 @@ class TestGetLanguageSupport:
"""Test that unsupported extension raises UnsupportedLanguageError."""
with pytest.raises(UnsupportedLanguageError) as exc_info:
get_language_support(Path("/test/example.xyz"))
assert "xyz" in str(exc_info.value.identifier) or "example.xyz" in str(
exc_info.value.identifier
)
assert "xyz" in str(exc_info.value.identifier) or "example.xyz" in str(exc_info.value.identifier)

def test_unsupported_language_raises(self):
"""Test that unsupported language name raises UnsupportedLanguageError."""
@ -175,9 +171,8 @@ class TestDetectProjectLanguage:

def test_detect_empty_project_raises(self):
"""Test that empty project raises UnsupportedLanguageError."""
with tempfile.TemporaryDirectory() as tmpdir:
with pytest.raises(UnsupportedLanguageError):
detect_project_language(Path(tmpdir), Path(tmpdir))
with tempfile.TemporaryDirectory() as tmpdir, pytest.raises(UnsupportedLanguageError):
detect_project_language(Path(tmpdir), Path(tmpdir))

def test_detect_with_different_roots(self):
"""Test detection with different project and module roots."""
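The hunk above merges two nested context managers into a single with statement, the rewrite commonly flagged as SIM117 by flake8-simplify/ruff. A minimal runnable sketch; it assumes pytest is installed, and the FileNotFoundError is only an illustrative exception:

import tempfile
from pathlib import Path

import pytest

# One `with` line replaces the nested tempfile / pytest.raises blocks.
with tempfile.TemporaryDirectory() as tmpdir, pytest.raises(FileNotFoundError):
    (Path(tmpdir) / "missing.txt").read_text()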
@ -1,20 +1,14 @@
"""
Extensive tests for the tree-sitter utilities module.
"""Extensive tests for the tree-sitter utilities module.

These tests verify that the TreeSitterAnalyzer correctly parses and
analyzes JavaScript/TypeScript code.
"""

from pathlib import Path

import pytest

from codeflash.languages.treesitter_utils import (
FunctionNode,
ImportInfo,
TreeSitterAnalyzer,
TreeSitterLanguage,
get_analyzer_for_file,
)
from pathlib import Path

from codeflash.languages.treesitter_utils import TreeSitterAnalyzer, TreeSitterLanguage, get_analyzer_for_file


class TestTreeSitterLanguage:
@ -1,4 +1,5 @@
from pathlib import Path

from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.models.models import CodeOptimizationContext, CodeStringsMarkdown
from codeflash.optimization.function_optimizer import FunctionOptimizer

@ -9,11 +10,13 @@ class Args:
disable_imports_sorting = True
formatter_cmds = ["disabled"]


def test_multi_file_replcement01() -> None:
root_dir = Path(__file__).parent.parent.resolve()
helper_file = (root_dir / "code_to_optimize/temp_helper.py").resolve()

helper_file.write_text("""import re

helper_file.write_text(
"""import re
from collections.abc import Sequence

from pydantic_ai_slim.pydantic_ai.messages import BinaryContent, UserContent

@ -36,7 +39,9 @@ def _estimate_string_tokens(content: str | Sequence[UserContent]) -> int:
# TODO(Marcelo): We need to study how we can estimate the tokens for AudioUrl or ImageUrl.

return tokens
""", encoding="utf-8")
""",
encoding="utf-8",
)

main_file = (root_dir / "code_to_optimize/temp_main.py").resolve()

@ -93,8 +98,6 @@ def _get_string_usage(text: str) -> Usage:
```
"""


func = FunctionToOptimize(function_name="_get_string_usage", parents=[], file_path=main_file)
test_config = TestConfig(
tests_root=root_dir / "tests/pytest",

@ -106,8 +109,6 @@ def _get_string_usage(text: str) -> Usage:
func_optimizer = FunctionOptimizer(function_to_optimize=func, test_cfg=test_config)
code_context: CodeOptimizationContext = func_optimizer.get_code_optimization_context().unwrap()


original_helper_code: dict[Path, str] = {}
helper_function_paths = {hf.file_path for hf in code_context.helper_functions}
for helper_function_path in helper_function_paths:
@ -117,11 +118,13 @@ def _get_string_usage(text: str) -> Usage:

func_optimizer.args = Args()
func_optimizer.replace_function_and_helpers_with_optimized_code(
code_context=code_context, optimized_code=CodeStringsMarkdown.parse_markdown_code(optimized_code), original_helper_code=original_helper_code
code_context=code_context,
optimized_code=CodeStringsMarkdown.parse_markdown_code(optimized_code),
original_helper_code=original_helper_code,
)
new_code = main_file.read_text(encoding="utf-8")
new_helper_code = helper_file.read_text(encoding="utf-8")

helper_file.unlink(missing_ok=True)
main_file.unlink(missing_ok=True)

@ -160,5 +163,5 @@ def _estimate_string_tokens(content: str | Sequence[UserContent]) -> int:
return tokens
"""

assert new_code.rstrip() == original_main.rstrip() # No Change
assert new_helper_code.rstrip() == expected_helper.rstrip()
assert new_code.rstrip() == original_main.rstrip() # No Change
assert new_helper_code.rstrip() == expected_helper.rstrip()
@ -1,9 +1,8 @@

from codeflash.verification.parse_test_output import parse_test_failures_from_stdout


def test_extracting_single_pytest_error_from_stdout():
stdout = '''
stdout = """
F... [100%]
=================================== FAILURES ===================================
_______________________ test_calculate_portfolio_metrics _______________________

@ -38,11 +37,13 @@ FAILED code_to_optimize/tests/pytest/test_multiple_helpers.py::test_calculate_po
1 failed, 3 passed in 0.15s


'''
"""
errors = parse_test_failures_from_stdout(stdout)
assert errors
assert len(errors.keys()) == 1
assert errors['test_calculate_portfolio_metrics'] == '''
assert (
errors["test_calculate_portfolio_metrics"]
== """
def test_calculate_portfolio_metrics():
# Test case 1: Basic portfolio
investments = [

@ -68,13 +69,15 @@ E assert 4.109589046841222e-08 < 1e-10
E + where 4.109589046841222e-08 = abs((0.890411 - 0.8904109589041095))

code_to_optimize/tests/pytest/test_multiple_helpers.py:26: AssertionError
'''
"""
)


def test_extracting_no_pytest_failures():
stdout = '''
stdout = """
.... [100%]
4 passed in 0.12s
'''
"""
errors = parse_test_failures_from_stdout(stdout)
assert errors == {}
@ -82,7 +85,7 @@ def test_extracting_no_pytest_failures():
def test_extracting_multiple_pytest_failures_with_class_method():
print("hi")

stdout = '''
stdout = """
F.F [100%]
=================================== FAILURES ===================================
________________________ test_simple_failure ________________________

@ -105,38 +108,42 @@ code_to_optimize/tests/test_calculator.py:22: ZeroDivisionError
FAILED code_to_optimize/tests/test_simple.py::test_simple_failure
FAILED code_to_optimize/tests/test_calculator.py::TestCalculator::test_divide_by_zero
2 failed, 1 passed in 0.18s
'''
"""
errors = parse_test_failures_from_stdout(stdout)
print(errors)
assert len(errors) == 2

assert 'test_simple_failure' in errors
assert errors['test_simple_failure'] == '''
assert "test_simple_failure" in errors
assert (
errors["test_simple_failure"]
== """
def test_simple_failure():
x = 1 + 1
> assert x == 3
E assert 2 == 3

code_to_optimize/tests/test_simple.py:10: AssertionError
'''
"""
)

assert 'TestCalculator.test_divide_by_zero' in errors
assert '''
assert "TestCalculator.test_divide_by_zero" in errors
assert errors["TestCalculator.test_divide_by_zero"] == """
class TestCalculator:
def test_divide_by_zero(self):
> Calculator().divide(10, 0)
E ZeroDivisionError: division by zero

code_to_optimize/tests/test_calculator.py:22: ZeroDivisionError
''' == errors['TestCalculator.test_divide_by_zero']
"""


def test_extracting_from_invalid_pytest_stdout():
stdout = '''
stdout = """
Running tests...
Everything seems fine
No structured output here
Just some random logs
'''
"""

errors = parse_test_failures_from_stdout(stdout)
assert errors == {}
@ -3,6 +3,7 @@ import pickle
import shutil
import socket
import sqlite3
import time
from argparse import Namespace
from pathlib import Path

@ -18,7 +19,6 @@ from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.models.models import CodePosition, TestFile, TestFiles, TestingMode, TestsInFile, TestType
from codeflash.optimization.optimizer import Optimizer
from codeflash.verification.equivalence import compare_test_results
import time

try:
import sqlalchemy

@ -35,12 +35,8 @@ from codeflash.picklepatch.pickle_placeholder import PicklePlaceholder, PicklePl


def test_picklepatch_simple_nested():
"""Test that a simple nested data structure pickles and unpickles correctly.
"""
original_data = {
"numbers": [1, 2, 3],
"nested_dict": {"key": "value", "another": 42},
}
"""Test that a simple nested data structure pickles and unpickles correctly."""
original_data = {"numbers": [1, 2, 3], "nested_dict": {"key": "value", "another": 42}}

dumped = PicklePatcher.dumps(original_data)
reloaded = PicklePatcher.loads(dumped)
@ -56,10 +52,7 @@ def test_picklepatch_with_socket():
# Create a pair of connected sockets instead of a single socket
sock1, sock2 = socket.socketpair()

data_with_socket = {
"safe_value": 123,
"raw_socket": sock1,
}
data_with_socket = {"safe_value": 123, "raw_socket": sock1}

# Send a message through sock1, which can be received by sock2
sock1.send(b"Hello, world!")

@ -85,17 +78,11 @@ def test_picklepatch_with_socket():


def test_picklepatch_deeply_nested():
"""Test that deep nesting with unpicklable objects works correctly.
"""
"""Test that deep nesting with unpicklable objects works correctly."""
# Create a deeply nested structure with an unpicklable object
deep_nested = {
"level1": {
"level2": {
"level3": {
"normal": "value",
"socket": socket.socket(socket.AF_INET, socket.SOCK_STREAM)
}
}
"level2": {"level3": {"normal": "value", "socket": socket.socket(socket.AF_INET, socket.SOCK_STREAM)}}
}
}

@ -108,9 +95,10 @@ def test_picklepatch_deeply_nested():
# The socket should be replaced with a placeholder
assert isinstance(reloaded["level1"]["level2"]["level3"]["socket"], PicklePlaceholder)


def test_picklepatch_class_with_unpicklable_attr():
"""Test that a class with an unpicklable attribute works correctly.
"""
"""Test that a class with an unpicklable attribute works correctly."""

class TestClass:
def __init__(self):
self.normal = "normal value"

@ -128,8 +116,6 @@ def test_picklepatch_class_with_unpicklable_attr():
assert isinstance(reloaded.unpicklable, PicklePlaceholder)



def test_picklepatch_with_database_connection():
"""Test that a data structure containing a database connection is replaced
by PicklePlaceholder rather than raising an error.
@ -138,11 +124,7 @@ def test_picklepatch_with_database_connection():
conn = sqlite3.connect(":memory:")
cursor = conn.cursor()

data_with_db = {
"description": "Database connection",
"connection": conn,
"cursor": cursor,
}
data_with_db = {"description": "Database connection", "connection": conn, "cursor": cursor}

dumped = PicklePatcher.dumps(data_with_db)
reloaded = PicklePatcher.loads(dumped)

@ -158,7 +140,7 @@ def test_picklepatch_with_database_connection():
reloaded["connection"].execute("SELECT 1")

cursor.close()
conn.close()
conn.close()


def test_picklepatch_with_generator():

@ -175,11 +157,7 @@ def test_picklepatch_with_generator():
gen = simple_generator()

# Put it in a data structure
data_with_generator = {
"description": "Contains a generator",
"generator": gen,
"normal_list": [1, 2, 3]
}
data_with_generator = {"description": "Contains a generator", "generator": gen, "normal_list": [1, 2, 3]}

dumped = PicklePatcher.dumps(data_with_generator)
reloaded = PicklePatcher.loads(dumped)

@ -204,11 +182,7 @@ def test_picklepatch_loads_standard_pickle():
using the standard pickle module.
"""
# Create a simple data structure
original_data = {
"numbers": [1, 2, 3],
"nested_dict": {"key": "value", "another": 42},
"tuple": (1, "two", 3.0),
}
original_data = {"numbers": [1, 2, 3], "nested_dict": {"key": "value", "another": 42}, "tuple": (1, "two", 3.0)}

# Pickle it with standard pickle
pickled_data = pickle.dumps(original_data)
@ -231,13 +205,7 @@ def test_picklepatch_loads_dill_pickle():
"""
# Create a more complex data structure that includes a lambda function
# which dill can handle but standard pickle cannot
original_data = {
"numbers": [1, 2, 3],
"function": lambda x: x * 2,
"nested": {
"another_function": lambda y: y ** 2
}
}
original_data = {"numbers": [1, 2, 3], "function": lambda x: x * 2, "nested": {"another_function": lambda y: y**2}}

# Pickle it with dill
dilled_data = dill.dumps(original_data)

@ -253,6 +221,7 @@ def test_picklepatch_loads_dill_pickle():
assert reloaded["function"](5) == 10
assert reloaded["nested"]["another_function"](4) == 16


def test_run_and_parse_picklepatch() -> None:
"""Test the end to end functionality of picklepatch, from tracing benchmarks to running the replay tests.

@ -269,7 +238,9 @@ def test_run_and_parse_picklepatch() -> None:
benchmarks_root = project_root / "code_to_optimize" / "tests" / "pytest" / "benchmarks_socket_test"
replay_tests_dir = benchmarks_root / "codeflash_replay_tests"
output_file = (benchmarks_root / Path("test_trace_benchmarks.trace")).resolve()
fto_unused_socket_path = (project_root / "code_to_optimize" / "bubble_sort_picklepatch_test_unused_socket.py").resolve()
fto_unused_socket_path = (
project_root / "code_to_optimize" / "bubble_sort_picklepatch_test_unused_socket.py"
).resolve()
fto_used_socket_path = (project_root / "code_to_optimize" / "bubble_sort_picklepatch_test_used_socket.py").resolve()
original_fto_unused_socket_code = fto_unused_socket_path.read_text("utf-8")
original_fto_used_socket_code = fto_used_socket_path.read_text("utf-8")

@ -282,7 +253,8 @@ def test_run_and_parse_picklepatch() -> None:
cursor = conn.cursor()

cursor.execute(
"SELECT function_name, class_name, module_name, file_path, benchmark_function_name, benchmark_module_path, benchmark_line_number FROM benchmark_function_timings ORDER BY benchmark_module_path, benchmark_function_name, function_name")
"SELECT function_name, class_name, module_name, file_path, benchmark_function_name, benchmark_module_path, benchmark_line_number FROM benchmark_function_timings ORDER BY benchmark_module_path, benchmark_function_name, function_name"
)
function_calls = cursor.fetchall()

# Assert the length of function calls
@ -290,38 +262,61 @@ def test_run_and_parse_picklepatch() -> None:
function_benchmark_timings = codeflash_benchmark_plugin.get_function_benchmark_timings(output_file)
total_benchmark_timings = codeflash_benchmark_plugin.get_benchmark_timings(output_file)
function_to_results = validate_and_format_benchmark_table(function_benchmark_timings, total_benchmark_timings)
assert "code_to_optimize.bubble_sort_picklepatch_test_unused_socket.bubble_sort_with_unused_socket" in function_to_results

assert (
"code_to_optimize.bubble_sort_picklepatch_test_unused_socket.bubble_sort_with_unused_socket"
in function_to_results
)

# Close the connection to allow file cleanup on Windows
conn.close()
time.sleep(1)

# Handle the case where function runs too fast to be measured
unused_socket_results = function_to_results["code_to_optimize.bubble_sort_picklepatch_test_unused_socket.bubble_sort_with_unused_socket"]
unused_socket_results = function_to_results[
"code_to_optimize.bubble_sort_picklepatch_test_unused_socket.bubble_sort_with_unused_socket"
]
if unused_socket_results:
test_name, total_time, function_time, percent = unused_socket_results[0]
assert total_time >= 0.0
# Function might be too fast, so we allow 0.0 function_time
assert function_time >= 0.0
assert percent >= 0.0
used_socket_results = function_to_results["code_to_optimize.bubble_sort_picklepatch_test_used_socket.bubble_sort_with_used_socket"]
used_socket_results = function_to_results[
"code_to_optimize.bubble_sort_picklepatch_test_used_socket.bubble_sort_with_used_socket"
]
# on windows , if the socket is not used we might not have resultssss
if used_socket_results:
test_name, total_time, function_time, percent = used_socket_results[0]
assert total_time >= 0.0
assert function_time >= 0.0
assert function_time >= 0.0
assert percent >= 0.0

bubble_sort_unused_socket_path = (project_root / "code_to_optimize"/ "bubble_sort_picklepatch_test_unused_socket.py").as_posix()
bubble_sort_used_socket_path = (project_root / "code_to_optimize" / "bubble_sort_picklepatch_test_used_socket.py").as_posix()
bubble_sort_unused_socket_path = (
project_root / "code_to_optimize" / "bubble_sort_picklepatch_test_unused_socket.py"
).as_posix()
bubble_sort_used_socket_path = (
project_root / "code_to_optimize" / "bubble_sort_picklepatch_test_used_socket.py"
).as_posix()
# Expected function calls
expected_calls = [
("bubble_sort_with_unused_socket", "", "code_to_optimize.bubble_sort_picklepatch_test_unused_socket",
f"{bubble_sort_unused_socket_path}",
"test_socket_picklepatch", "code_to_optimize.tests.pytest.benchmarks_socket_test.test_socket", 12),
("bubble_sort_with_used_socket", "", "code_to_optimize.bubble_sort_picklepatch_test_used_socket",
f"{bubble_sort_used_socket_path}",
"test_used_socket_picklepatch", "code_to_optimize.tests.pytest.benchmarks_socket_test.test_socket", 20)
(
"bubble_sort_with_unused_socket",
"",
"code_to_optimize.bubble_sort_picklepatch_test_unused_socket",
f"{bubble_sort_unused_socket_path}",
"test_socket_picklepatch",
"code_to_optimize.tests.pytest.benchmarks_socket_test.test_socket",
12,
),
(
"bubble_sort_with_used_socket",
"",
"code_to_optimize.bubble_sort_picklepatch_test_used_socket",
f"{bubble_sort_used_socket_path}",
"test_used_socket_picklepatch",
"code_to_optimize.tests.pytest.benchmarks_socket_test.test_socket",
20,
),
]
for idx, (actual, expected) in enumerate(zip(function_calls, expected_calls)):
assert actual[0] == expected[0], f"Mismatch at index {idx} for function_name"
@ -332,29 +327,29 @@ def test_run_and_parse_picklepatch() -> None:
assert actual[5] == expected[5], f"Mismatch at index {idx} for benchmark_module_path"
assert actual[6] == expected[6], f"Mismatch at index {idx} for benchmark_line_number"
conn.close()


time.sleep(1)

# Generate replay test
generate_replay_test(output_file, replay_tests_dir)
replay_test_path = replay_tests_dir / Path(
"test_code_to_optimize_tests_pytest_benchmarks_socket_test_test_socket__replay_test_0.py")
"test_code_to_optimize_tests_pytest_benchmarks_socket_test_test_socket__replay_test_0.py"
)
replay_test_perf_path = replay_tests_dir / Path(
"test_code_to_optimize_tests_pytest_benchmarks_socket_test_test_socket__replay_test_0_perf.py")
"test_code_to_optimize_tests_pytest_benchmarks_socket_test_test_socket__replay_test_0_perf.py"
)
assert replay_test_path.exists()
original_replay_test_code = replay_test_path.read_text("utf-8")

# Instrument the replay test
func = FunctionToOptimize(function_name="bubble_sort_with_unused_socket", parents=[], file_path=Path(fto_unused_socket_path))
func = FunctionToOptimize(
    function_name="bubble_sort_with_unused_socket", parents=[], file_path=Path(fto_unused_socket_path)
)
original_cwd = Path.cwd()
run_cwd = project_root
os.chdir(run_cwd)
success, new_test = inject_profiling_into_existing_test(
replay_test_path,
[CodePosition(17, 15)],
func,
project_root,
mode=TestingMode.BEHAVIOR,
replay_test_path, [CodePosition(17, 15)], func, project_root, mode=TestingMode.BEHAVIOR
)
os.chdir(original_cwd)
assert success

@ -386,7 +381,14 @@ def test_run_and_parse_picklepatch() -> None:
test_type=test_type,
original_file_path=replay_test_path,
benchmarking_file_path=replay_test_perf_path,
tests_in_file=[TestsInFile(test_file=replay_test_path, test_class=None, test_function=replay_test_function, test_type=test_type)],
tests_in_file=[
TestsInFile(
test_file=replay_test_path,
test_class=None,
test_function=replay_test_function,
test_type=test_type,
)
],
)
]
)
@ -400,8 +402,14 @@ def test_run_and_parse_picklepatch() -> None:
testing_time=1.0,
)
assert len(test_results_unused_socket) == 1
assert test_results_unused_socket.test_results[0].id.test_module_path == "code_to_optimize.tests.pytest.benchmarks_socket_test.codeflash_replay_tests.test_code_to_optimize_tests_pytest_benchmarks_socket_test_test_socket__replay_test_0"
assert test_results_unused_socket.test_results[0].id.test_function_name == "test_code_to_optimize_bubble_sort_picklepatch_test_unused_socket_bubble_sort_with_unused_socket_test_socket_picklepatch"
assert (
test_results_unused_socket.test_results[0].id.test_module_path
== "code_to_optimize.tests.pytest.benchmarks_socket_test.codeflash_replay_tests.test_code_to_optimize_tests_pytest_benchmarks_socket_test_test_socket__replay_test_0"
)
assert (
test_results_unused_socket.test_results[0].id.test_function_name
== "test_code_to_optimize_bubble_sort_picklepatch_test_unused_socket_bubble_sort_with_unused_socket_test_socket_picklepatch"
)
assert test_results_unused_socket.test_results[0].did_pass == True

# Replace with optimized candidate

@ -431,13 +439,11 @@ def bubble_sort_with_unused_socket(data_container):
# Remove the previous instrumentation
replay_test_path.write_text(original_replay_test_code)
# Instrument the replay test
func = FunctionToOptimize(function_name="bubble_sort_with_used_socket", parents=[], file_path=Path(fto_used_socket_path))
func = FunctionToOptimize(
    function_name="bubble_sort_with_used_socket", parents=[], file_path=Path(fto_used_socket_path)
)
success, new_test = inject_profiling_into_existing_test(
replay_test_path,
[CodePosition(23,15)],
func,
project_root,
mode=TestingMode.BEHAVIOR,
replay_test_path, [CodePosition(23, 15)], func, project_root, mode=TestingMode.BEHAVIOR
)
os.chdir(original_cwd)
assert success

@ -449,8 +455,9 @@ def bubble_sort_with_unused_socket(data_container):
test_env["CODEFLASH_TEST_ITERATION"] = "0"
test_env["CODEFLASH_LOOP_INDEX"] = "1"
test_type = TestType.REPLAY_TEST
func = FunctionToOptimize(function_name="bubble_sort_with_used_socket", parents=[],
file_path=Path(fto_used_socket_path))
func = FunctionToOptimize(
    function_name="bubble_sort_with_used_socket", parents=[], file_path=Path(fto_used_socket_path)
)
replay_test_function = "test_code_to_optimize_bubble_sort_picklepatch_test_used_socket_bubble_sort_with_used_socket_test_used_socket_picklepatch"
func_optimizer = opt.create_function_optimizer(func)
func_optimizer.test_files = TestFiles(

@ -461,8 +468,13 @@ def bubble_sort_with_unused_socket(data_container):
original_file_path=replay_test_path,
benchmarking_file_path=replay_test_perf_path,
tests_in_file=[
TestsInFile(test_file=replay_test_path, test_class=None, test_function=replay_test_function,
test_type=test_type)],
TestsInFile(
test_file=replay_test_path,
test_class=None,
test_function=replay_test_function,
test_type=test_type,
)
],
)
]
)
@ -476,10 +488,14 @@ def bubble_sort_with_unused_socket(data_container):
testing_time=1.0,
)
assert len(test_results_used_socket) == 1
assert test_results_used_socket.test_results[
0].id.test_module_path == "code_to_optimize.tests.pytest.benchmarks_socket_test.codeflash_replay_tests.test_code_to_optimize_tests_pytest_benchmarks_socket_test_test_socket__replay_test_0"
assert test_results_used_socket.test_results[
0].id.test_function_name == "test_code_to_optimize_bubble_sort_picklepatch_test_used_socket_bubble_sort_with_used_socket_test_used_socket_picklepatch"
assert (
test_results_used_socket.test_results[0].id.test_module_path
== "code_to_optimize.tests.pytest.benchmarks_socket_test.codeflash_replay_tests.test_code_to_optimize_tests_pytest_benchmarks_socket_test_test_socket__replay_test_0"
)
assert (
test_results_used_socket.test_results[0].id.test_function_name
== "test_code_to_optimize_bubble_sort_picklepatch_test_used_socket_bubble_sort_with_used_socket_test_used_socket_picklepatch"
)
assert test_results_used_socket.test_results[0].did_pass is False
print("test results used socket")
print(test_results_used_socket)

@ -507,10 +523,14 @@ def bubble_sort_with_used_socket(data_container):
testing_time=1.0,
)
assert len(test_results_used_socket) == 1
assert test_results_used_socket.test_results[
0].id.test_module_path == "code_to_optimize.tests.pytest.benchmarks_socket_test.codeflash_replay_tests.test_code_to_optimize_tests_pytest_benchmarks_socket_test_test_socket__replay_test_0"
assert test_results_used_socket.test_results[
0].id.test_function_name == "test_code_to_optimize_bubble_sort_picklepatch_test_used_socket_bubble_sort_with_used_socket_test_used_socket_picklepatch"
assert (
test_results_used_socket.test_results[0].id.test_module_path
== "code_to_optimize.tests.pytest.benchmarks_socket_test.codeflash_replay_tests.test_code_to_optimize_tests_pytest_benchmarks_socket_test_test_socket__replay_test_0"
)
assert (
test_results_used_socket.test_results[0].id.test_function_name
== "test_code_to_optimize_bubble_sort_picklepatch_test_used_socket_bubble_sort_with_used_socket_test_used_socket_picklepatch"
)
assert test_results_used_socket.test_results[0].did_pass is False

# Even though tests threw the same error, we reject this as the behavior of the unpickleable object could not be determined.
@ -1,6 +1,7 @@
from pathlib import Path

import pytest

from codeflash.code_utils.edit_generated_tests import remove_functions_from_generated_tests
from codeflash.models.models import GeneratedTests, GeneratedTestsList

@ -1,14 +1,6 @@
import tempfile
from argparse import Namespace
from pathlib import Path

import libcst as cst

from codeflash.context.code_context_extractor import get_code_optimization_context
from codeflash.context.unused_definition_remover import remove_unused_definitions_by_function_names
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.models.models import FunctionParent
from codeflash.optimization.optimizer import Optimizer


def test_variable_removal_only() -> None:

@ -337,6 +329,7 @@ def unused_function():
result = remove_unused_definitions_by_function_names(code, qualified_functions)
assert result.strip() == expected.strip()


def test_base_class_inheritance() -> None:
"""Test that base classes used only for inheritance are preserved."""
code = """
Some files were not shown because too many files have changed in this diff.