format and lint all

This commit is contained in:
misrasaurabh1 2026-01-29 01:39:48 -08:00
parent 0e5ad411ec
commit 198487bf81
109 changed files with 2704 additions and 3211 deletions

View file

@ -129,7 +129,8 @@ class AiServiceClient:
experiment_metadata: ExperimentMetadata | None = None,
*,
language: str = "python",
language_version: str | None = None, # TODO:{claude} add language version to the language support and it should be cached
language_version: str
| None = None, # TODO:{claude} add language version to the language support and it should be cached
module_system: str | None = None,
is_async: bool = False,
n_candidates: int = 5,
@ -238,6 +239,7 @@ class AiServiceClient:
is_async=is_async,
n_candidates=n_candidates,
)
def get_jit_rewritten_code( # noqa: D417
self, source_code: str, trace_id: str
) -> list[OptimizedCandidate]:
@ -410,6 +412,7 @@ class AiServiceClient:
Returns:
List of refined optimization candidates
"""
payload = []
for opt in request:
@ -727,7 +730,7 @@ class AiServiceClient:
language: str = "python",
language_version: str | None = None,
module_system: str | None = None,
is_numerical_code: bool | None = None, # noqa: FBT001
is_numerical_code: bool | None = None,
) -> tuple[str, str, str] | None:
"""Generate regression tests for the given function by making a request to the Django endpoint.

View file

@ -14,7 +14,6 @@ from __future__ import annotations
from dataclasses import dataclass, field
from enum import Enum
from pathlib import Path
from typing import Any
@ -192,10 +191,7 @@ class TestGenRequest:
"""Convert to API payload dict, maintaining backward compatibility."""
payload = {
"source_code_being_tested": self.source_code,
"function_to_optimize": {
"function_name": self.function_name,
"is_async": self.is_async,
},
"function_to_optimize": {"function_name": self.function_name, "is_async": self.is_async},
"helper_function_names": self.helper_function_names,
"module_path": self.module_path,
"test_module_path": self.test_module_path,
@ -243,24 +239,16 @@ def python_language_info(version: str | None = None) -> LanguageInfo:
def javascript_language_info(
module_system: ModuleSystem = ModuleSystem.COMMONJS,
version: str = "ES2022",
module_system: ModuleSystem = ModuleSystem.COMMONJS, version: str = "ES2022"
) -> LanguageInfo:
"""Create LanguageInfo for JavaScript."""
ext = ".mjs" if module_system == ModuleSystem.ESM else ".js"
return LanguageInfo(
name="javascript",
version=version,
module_system=module_system,
file_extension=ext,
has_type_annotations=False,
name="javascript", version=version, module_system=module_system, file_extension=ext, has_type_annotations=False
)
def typescript_language_info(
module_system: ModuleSystem = ModuleSystem.ESM,
version: str = "ES2022",
) -> LanguageInfo:
def typescript_language_info(module_system: ModuleSystem = ModuleSystem.ESM, version: str = "ES2022") -> LanguageInfo:
"""Create LanguageInfo for TypeScript."""
return LanguageInfo(
name="typescript",

View file

@ -1516,6 +1516,7 @@ def _customize_python_workflow_content(
codeflash_cmd += " --benchmark"
return optimize_yml_content.replace("{{ codeflash_command }}", codeflash_cmd)
# TODO:{claude} Refactor and move to support for language specific
def _customize_js_workflow_content(
optimize_yml_content: str,

View file

@ -108,6 +108,7 @@ def code_print(
function_name: Optional function name for LSP
lsp_message_id: Optional LSP message ID
language: Programming language for syntax highlighting ('python', 'javascript', 'typescript')
"""
if is_LSP_enabled():
lsp_log(
@ -118,11 +119,7 @@ def code_print(
from rich.syntax import Syntax
# Map codeflash language names to rich/pygments lexer names
lexer_map = {
"python": "python",
"javascript": "javascript",
"typescript": "typescript",
}
lexer_map = {"python": "python", "javascript": "javascript", "typescript": "typescript"}
lexer = lexer_map.get(language, "python")
console.rule()

View file

@ -1,4 +1,5 @@
"""JavaScript/TypeScript project initialization for Codeflash."""
# TODO:{claude} move to language support directory
from __future__ import annotations
@ -125,8 +126,7 @@ def init_js_project(language: ProjectLanguage) -> None:
lang_panel = Panel(
Text(
f"📦 Detected {lang_name} project!\n\n"
"I'll help you set up Codeflash for your project.",
f"📦 Detected {lang_name} project!\n\nI'll help you set up Codeflash for your project.",
style="cyan",
justify="center",
),
@ -159,13 +159,13 @@ def init_js_project(language: ProjectLanguage) -> None:
usage_table.add_column("Command", style="cyan")
usage_table.add_column("Description", style="white")
usage_table.add_row(
"codeflash --file <path-to-file> --function <function-name>", "Optimize a specific function"
)
usage_table.add_row("codeflash --file <path-to-file> --function <function-name>", "Optimize a specific function")
usage_table.add_row("codeflash --all", "Optimize all functions in all files")
usage_table.add_row("codeflash --help", "See all available options")
completion_message = f"⚡️ Codeflash is now set up for your {lang_name} project!\n\nYou can now run any of these commands:"
completion_message = (
f"⚡️ Codeflash is now set up for your {lang_name} project!\n\nYou can now run any of these commands:"
)
if did_add_new_key:
completion_message += (
@ -265,10 +265,7 @@ def collect_js_setup_info(language: ProjectLanguage) -> JSSetupInfo:
detection_table.add_row("Formatter", formatter_display)
detection_panel = Panel(
Group(
Text(f"Auto-detected settings for your {lang_name} project:\n", style="cyan"),
detection_table,
),
Group(Text(f"Auto-detected settings for your {lang_name} project:\n", style="cyan"), detection_table),
title="🔍 Auto-Detection Results",
border_style="bright_blue",
)
@ -399,8 +396,7 @@ def _get_git_remote_for_setup() -> str:
git_panel = Panel(
Text(
"🔗 Configure Git Remote for Pull Requests.\n\n"
"Codeflash will use this remote to create pull requests.",
"🔗 Configure Git Remote for Pull Requests.\n\nCodeflash will use this remote to create pull requests.",
style="blue",
),
title="🔗 Git Remote Setup",

View file

@ -206,6 +206,7 @@ def find_package_json(config_file: Path | None = None) -> Path | None:
return None
def parse_package_json_config(package_json_path: Path) -> tuple[dict[str, Any], Path] | None:
"""Parse codeflash config from package.json with auto-detection.

View file

@ -8,14 +8,10 @@ from __future__ import annotations
import hashlib
import re
from typing import TYPE_CHECKING
from codeflash.code_utils.normalizers import get_normalizer
from codeflash.languages import current_language, is_python
if TYPE_CHECKING:
pass
def normalize_code(
code: str,
@ -35,6 +31,7 @@ def normalize_code(
Returns:
Normalized code as string
"""
if language is None:
language = current_language().value
@ -89,6 +86,7 @@ def get_code_fingerprint(code: str, language: str | None = None) -> str:
Returns:
SHA-256 hash of normalized code
"""
if language is None:
language = current_language().value
@ -112,6 +110,7 @@ def are_codes_duplicate(code1: str, code2: str, language: str | None = None) ->
Returns:
True if codes are structurally identical (ignoring local variable names)
"""
if language is None:
language = current_language().value
@ -127,8 +126,4 @@ def are_codes_duplicate(code1: str, code2: str, language: str | None = None) ->
# Re-export for backward compatibility
__all__ = [
"normalize_code",
"get_code_fingerprint",
"are_codes_duplicate",
]
__all__ = ["are_codes_duplicate", "get_code_fingerprint", "normalize_code"]

View file

@ -22,13 +22,10 @@ from codeflash.code_utils.normalizers.base import CodeNormalizer
from codeflash.code_utils.normalizers.javascript import JavaScriptNormalizer, TypeScriptNormalizer
from codeflash.code_utils.normalizers.python import PythonNormalizer
if TYPE_CHECKING:
pass
__all__ = [
"CodeNormalizer",
"PythonNormalizer",
"JavaScriptNormalizer",
"PythonNormalizer",
"TypeScriptNormalizer",
"get_normalizer",
"get_normalizer_for_extension",
@ -56,6 +53,7 @@ def get_normalizer(language: str) -> CodeNormalizer:
Raises:
ValueError: If no normalizer exists for the language
"""
language = language.lower()
@ -83,6 +81,7 @@ def get_normalizer_for_extension(extension: str) -> CodeNormalizer | None:
Returns:
CodeNormalizer instance if found, None otherwise
"""
extension = extension.lower()
if not extension.startswith("."):
@ -102,6 +101,7 @@ def register_normalizer(language: str, normalizer_class: type[CodeNormalizer]) -
Args:
language: Language name
normalizer_class: CodeNormalizer subclass
"""
_NORMALIZERS[language.lower()] = normalizer_class
# Clear cached instance if it exists

View file

@ -4,14 +4,11 @@ Code normalizers transform source code into a canonical form for duplicate detec
They normalize variable names, remove comments/docstrings, and produce consistent output
that can be compared across different implementations of the same algorithm.
"""
# TODO:{claude} move to base.py in language folder
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING
if TYPE_CHECKING:
pass
class CodeNormalizer(ABC):
@ -30,6 +27,7 @@ class CodeNormalizer(ABC):
>>> code2 = "def foo(x): z = x + 1; return z"
>>> normalizer.normalize(code1) == normalizer.normalize(code2)
True
"""
@property
@ -52,6 +50,7 @@ class CodeNormalizer(ABC):
Returns:
Normalized representation of the code
"""
...
@ -66,6 +65,7 @@ class CodeNormalizer(ABC):
Returns:
Normalized representation suitable for hashing
"""
...
@ -78,6 +78,7 @@ class CodeNormalizer(ABC):
Returns:
True if codes are structurally identical
"""
try:
normalized1 = self.normalize_for_hash(code1)
@ -94,6 +95,7 @@ class CodeNormalizer(ABC):
Returns:
SHA-256 hash of normalized code
"""
import hashlib

View file

@ -10,7 +10,8 @@ from codeflash.code_utils.normalizers.base import CodeNormalizer
if TYPE_CHECKING:
from tree_sitter import Node
# TODO:{claude} move to language support directory to keep the directory structure clean
# TODO:{claude} move to language support directory to keep the directory structure clean
class JavaScriptVariableNormalizer:
"""Normalizes JavaScript/TypeScript code for duplicate detection using tree-sitter.
@ -24,13 +25,49 @@ class JavaScriptVariableNormalizer:
self.preserved_names: set[str] = set()
# Common JavaScript builtins
self.builtins = {
"console", "window", "document", "Math", "JSON", "Object", "Array",
"String", "Number", "Boolean", "Date", "RegExp", "Error", "Promise",
"Map", "Set", "WeakMap", "WeakSet", "Symbol", "Proxy", "Reflect",
"undefined", "null", "NaN", "Infinity", "globalThis", "parseInt",
"parseFloat", "isNaN", "isFinite", "eval", "setTimeout", "setInterval",
"clearTimeout", "clearInterval", "fetch", "require", "module", "exports",
"process", "__dirname", "__filename", "Buffer",
"console",
"window",
"document",
"Math",
"JSON",
"Object",
"Array",
"String",
"Number",
"Boolean",
"Date",
"RegExp",
"Error",
"Promise",
"Map",
"Set",
"WeakMap",
"WeakSet",
"Symbol",
"Proxy",
"Reflect",
"undefined",
"null",
"NaN",
"Infinity",
"globalThis",
"parseInt",
"parseFloat",
"isNaN",
"isFinite",
"eval",
"setTimeout",
"setInterval",
"clearTimeout",
"clearInterval",
"fetch",
"require",
"module",
"exports",
"process",
"__dirname",
"__filename",
"Buffer",
}
def get_normalized_name(self, name: str) -> str:
@ -48,7 +85,7 @@ class JavaScriptVariableNormalizer:
if node.type in ("function_declaration", "function_expression", "method_definition", "arrow_function"):
name_node = node.child_by_field_name("name")
if name_node:
self.preserved_names.add(source_code[name_node.start_byte:name_node.end_byte].decode("utf-8"))
self.preserved_names.add(source_code[name_node.start_byte : name_node.end_byte].decode("utf-8"))
# Preserve parameters
params_node = node.child_by_field_name("parameters") or node.child_by_field_name("parameter")
if params_node:
@ -58,7 +95,7 @@ class JavaScriptVariableNormalizer:
elif node.type == "class_declaration":
name_node = node.child_by_field_name("name")
if name_node:
self.preserved_names.add(source_code[name_node.start_byte:name_node.end_byte].decode("utf-8"))
self.preserved_names.add(source_code[name_node.start_byte : name_node.end_byte].decode("utf-8"))
# Import declarations
elif node.type in ("import_statement", "import_declaration"):
@ -66,7 +103,7 @@ class JavaScriptVariableNormalizer:
if child.type == "import_clause":
self._collect_import_names(child, source_code)
elif child.type == "identifier":
self.preserved_names.add(source_code[child.start_byte:child.end_byte].decode("utf-8"))
self.preserved_names.add(source_code[child.start_byte : child.end_byte].decode("utf-8"))
# Recurse
for child in node.children:
@ -76,11 +113,13 @@ class JavaScriptVariableNormalizer:
"""Collect parameter names from a parameters node."""
for child in node.children:
if child.type == "identifier":
self.preserved_names.add(source_code[child.start_byte:child.end_byte].decode("utf-8"))
self.preserved_names.add(source_code[child.start_byte : child.end_byte].decode("utf-8"))
elif child.type in ("required_parameter", "optional_parameter", "rest_parameter"):
pattern_node = child.child_by_field_name("pattern")
if pattern_node and pattern_node.type == "identifier":
self.preserved_names.add(source_code[pattern_node.start_byte:pattern_node.end_byte].decode("utf-8"))
self.preserved_names.add(
source_code[pattern_node.start_byte : pattern_node.end_byte].decode("utf-8")
)
# Recurse for nested patterns
self._collect_parameter_names(child, source_code)
@ -88,15 +127,15 @@ class JavaScriptVariableNormalizer:
"""Collect imported names from import clause."""
for child in node.children:
if child.type == "identifier":
self.preserved_names.add(source_code[child.start_byte:child.end_byte].decode("utf-8"))
self.preserved_names.add(source_code[child.start_byte : child.end_byte].decode("utf-8"))
elif child.type == "import_specifier":
# Get the local name (alias or original)
alias_node = child.child_by_field_name("alias")
name_node = child.child_by_field_name("name")
if alias_node:
self.preserved_names.add(source_code[alias_node.start_byte:alias_node.end_byte].decode("utf-8"))
self.preserved_names.add(source_code[alias_node.start_byte : alias_node.end_byte].decode("utf-8"))
elif name_node:
self.preserved_names.add(source_code[name_node.start_byte:name_node.end_byte].decode("utf-8"))
self.preserved_names.add(source_code[name_node.start_byte : name_node.end_byte].decode("utf-8"))
self._collect_import_names(child, source_code)
def normalize_tree(self, node: Node, source_code: bytes) -> str:
@ -113,14 +152,14 @@ class JavaScriptVariableNormalizer:
# Handle identifiers - normalize variable names
if node.type == "identifier":
name = source_code[node.start_byte:node.end_byte].decode("utf-8")
name = source_code[node.start_byte : node.end_byte].decode("utf-8")
normalized = self.get_normalized_name(name)
parts.append(normalized)
return
# Handle type identifiers (TypeScript) - preserve as-is
if node.type == "type_identifier":
parts.append(source_code[node.start_byte:node.end_byte].decode("utf-8"))
parts.append(source_code[node.start_byte : node.end_byte].decode("utf-8"))
return
# Handle string literals - normalize to placeholder
@ -135,7 +174,7 @@ class JavaScriptVariableNormalizer:
# For leaf nodes, output the node type
if len(node.children) == 0:
text = source_code[node.start_byte:node.end_byte].decode("utf-8")
text = source_code[node.start_byte : node.end_byte].decode("utf-8")
parts.append(text)
return
@ -191,14 +230,12 @@ class JavaScriptNormalizer(CodeNormalizer):
Returns:
Normalized representation of the code
"""
try:
from codeflash.languages.treesitter_utils import TreeSitterAnalyzer, TreeSitterLanguage
lang_map = {
"javascript": TreeSitterLanguage.JAVASCRIPT,
"typescript": TreeSitterLanguage.TYPESCRIPT,
}
lang_map = {"javascript": TreeSitterLanguage.JAVASCRIPT, "typescript": TreeSitterLanguage.TYPESCRIPT}
lang = lang_map.get(self._get_tree_sitter_language(), TreeSitterLanguage.JAVASCRIPT)
analyzer = TreeSitterAnalyzer(lang)
tree = analyzer.parse(code)
@ -227,6 +264,7 @@ class JavaScriptNormalizer(CodeNormalizer):
Returns:
Normalized representation suitable for hashing
"""
return self.normalize(code)

View file

@ -3,13 +3,9 @@
from __future__ import annotations
import ast
from typing import TYPE_CHECKING
from codeflash.code_utils.normalizers.base import CodeNormalizer
if TYPE_CHECKING:
pass
class VariableNormalizer(ast.NodeTransformer):
"""Normalizes only local variable names in AST to canonical forms like var_0, var_1, etc.
@ -197,6 +193,7 @@ class PythonNormalizer(CodeNormalizer):
Returns:
Normalized Python code as a string
"""
tree = ast.parse(code)
@ -219,6 +216,7 @@ class PythonNormalizer(CodeNormalizer):
Returns:
AST dump string suitable for hashing
"""
tree = ast.parse(code)
_remove_docstrings_from_ast(tree)

View file

@ -207,6 +207,7 @@ def get_code_optimization_context(
preexisting_objects=preexisting_objects,
)
def get_code_optimization_context_for_language(
function_to_optimize: FunctionToOptimize,
project_root_path: Path,
@ -356,6 +357,7 @@ def get_code_optimization_context_for_language(
preexisting_objects=set(), # Not implemented for non-Python yet
)
def extract_code_markdown_context_from_files(
helpers_of_fto: dict[Path, set[FunctionSource]],
helpers_of_helpers: dict[Path, set[FunctionSource]],

View file

@ -28,8 +28,8 @@ from codeflash.code_utils.code_utils import (
module_name_from_file_path,
)
from codeflash.code_utils.compat import SAFE_SYS_EXECUTABLE, codeflash_cache_db
from codeflash.languages import is_javascript, is_python
from codeflash.code_utils.shell_utils import get_cross_platform_subprocess_run_args
from codeflash.languages import is_javascript, is_python
from codeflash.models.models import CodePosition, FunctionCalledInTest, TestsInFile, TestType
if TYPE_CHECKING:

View file

@ -1,5 +1,4 @@
"""
Multi-language support for Codeflash.
"""Multi-language support for Codeflash.
This package provides the abstraction layer that allows Codeflash to support
multiple programming languages while keeping the core optimization pipeline
@ -37,6 +36,11 @@ from codeflash.languages.current import (
reset_current_language,
set_current_language,
)
from codeflash.languages.javascript import JavaScriptSupport, TypeScriptSupport # noqa: F401
# Import language support modules to trigger auto-registration
# This ensures all supported languages are available when this package is imported
from codeflash.languages.python import PythonSupport # noqa: F401
from codeflash.languages.registry import (
detect_project_language,
get_language_support,
@ -45,11 +49,6 @@ from codeflash.languages.registry import (
register_language,
)
# Import language support modules to trigger auto-registration
# This ensures all supported languages are available when this package is imported
from codeflash.languages.python import PythonSupport # noqa: F401
from codeflash.languages.javascript import JavaScriptSupport, TypeScriptSupport # noqa: F401
__all__ = [
# Base types
"CodeContext",

View file

@ -507,10 +507,7 @@ class LanguageSupport(Protocol):
# === Test Result Comparison ===
def compare_test_results(
self,
original_results_path: Path,
candidate_results_path: Path,
project_root: Path | None = None,
self, original_results_path: Path, candidate_results_path: Path, project_root: Path | None = None
) -> tuple[bool, list]:
"""Compare test results between original and candidate code.

View file

@ -23,6 +23,7 @@ def _get_compare_results_script(project_root: Path | None = None) -> Path | None
Returns:
Path to compare-results.js if found, None otherwise.
"""
search_dirs = []
if project_root:
@ -59,6 +60,7 @@ def compare_test_results(
Returns:
Tuple of (all_equivalent, list of TestDiff objects).
"""
script_path = comparator_script or _get_compare_results_script(project_root)

View file

@ -75,9 +75,7 @@ class StandaloneCallTransformer:
# Pattern to match func_name( with optional leading await and optional object prefix
# Captures: (whitespace)(await )?(object.)*func_name(
# We'll filter out expect() and codeflash. cases in the transform loop
self._call_pattern = re.compile(
rf"(\s*)(await\s+)?((?:\w+\.)*){re.escape(func_name)}\s*\("
)
self._call_pattern = re.compile(rf"(\s*)(await\s+)?((?:\w+\.)*){re.escape(func_name)}\s*\(")
def transform(self, code: str) -> str:
"""Transform all standalone calls in the code."""
@ -353,9 +351,7 @@ class ExpectCallTransformer:
self.invocation_counter = 0
# Pattern to match start of expect((object.)*func_name(
# Captures: (whitespace), (object prefix like calc. or this.)
self._expect_pattern = re.compile(
rf"(\s*)expect\s*\(\s*((?:\w+\.)*){re.escape(func_name)}\s*\("
)
self._expect_pattern = re.compile(rf"(\s*)expect\s*\(\s*((?:\w+\.)*){re.escape(func_name)}\s*\(")
def transform(self, code: str) -> str:
"""Transform all expect calls in the code."""

View file

@ -23,7 +23,12 @@ from codeflash.languages.base import (
TestResult,
)
from codeflash.languages.registry import register_language
from codeflash.languages.treesitter_utils import TreeSitterAnalyzer, TreeSitterLanguage, TypeDefinition, get_analyzer_for_file
from codeflash.languages.treesitter_utils import (
TreeSitterAnalyzer,
TreeSitterLanguage,
TypeDefinition,
get_analyzer_for_file,
)
if TYPE_CHECKING:
from collections.abc import Sequence
@ -346,7 +351,9 @@ class JavaScriptSupport:
# Wrap the method in a class definition with context
if class_jsdoc:
target_code = f"{class_jsdoc}\n{class_indent}class {class_name} {{\n{class_body}{class_indent}}}\n"
target_code = (
f"{class_jsdoc}\n{class_indent}class {class_name} {{\n{class_body}{class_indent}}}\n"
)
else:
target_code = f"{class_indent}class {class_name} {{\n{class_body}{class_indent}}}\n"
else:
@ -366,11 +373,7 @@ class JavaScriptSupport:
# Extract type definitions for function parameters and class fields
type_definitions_context, type_definition_names = self._extract_type_definitions_context(
function=function,
source=source,
analyzer=analyzer,
imports=imports,
module_root=module_root,
function=function, source=source, analyzer=analyzer, imports=imports, module_root=module_root
)
# Find module-level declarations (global variables/constants) referenced by the function
@ -715,12 +718,7 @@ class JavaScriptSupport:
return "\n".join(global_lines)
def _extract_type_definitions_context(
self,
function: FunctionInfo,
source: str,
analyzer: TreeSitterAnalyzer,
imports: list[Any],
module_root: Path,
self, function: FunctionInfo, source: str, analyzer: TreeSitterAnalyzer, imports: list[Any], module_root: Path
) -> tuple[str, set[str]]:
"""Extract type definitions used by the function for read-only context.
@ -805,11 +803,7 @@ class JavaScriptSupport:
return "\n\n".join(type_def_parts), found_type_names
def _find_imported_type_definitions(
self,
type_names: set[str],
imports: list[Any],
module_root: Path,
source_file_path: Path,
self, type_names: set[str], imports: list[Any], module_root: Path, source_file_path: Path
) -> list[TypeDefinition]:
"""Find type definitions in imported files.
@ -1583,9 +1577,7 @@ class JavaScriptSupport:
return None
def get_module_path(
self, source_file: Path, project_root: Path, tests_root: Path | None = None
) -> str:
def get_module_path(self, source_file: Path, project_root: Path, tests_root: Path | None = None) -> str:
"""Get the module path for importing a JavaScript source file from tests.
For JavaScript, this returns a relative path from the tests directory to the source file

View file

@ -1,5 +1,4 @@
"""
Python language support for Codeflash.
"""Python language support for Codeflash.
This module provides the PythonSupport class which wraps the existing
Python-specific implementations (LibCST, Jedi, pytest, etc.) to conform

View file

@ -73,7 +73,7 @@ class PythonSupport:
"""
import libcst as cst
from codeflash.discovery.functions_to_optimize import FunctionParent, FunctionToOptimize, FunctionVisitor
from codeflash.discovery.functions_to_optimize import FunctionToOptimize, FunctionVisitor
criteria = filter_criteria or FunctionFilterCriteria()
@ -339,12 +339,7 @@ class PythonSupport:
# Try ruff first
try:
result = subprocess.run(
["ruff", "format", "-"],
check=False,
input=source,
capture_output=True,
text=True,
timeout=30,
["ruff", "format", "-"], check=False, input=source, capture_output=True, text=True, timeout=30
)
if result.returncode == 0:
return result.stdout
@ -356,12 +351,7 @@ class PythonSupport:
# Try black as fallback
try:
result = subprocess.run(
["black", "-q", "-"],
check=False,
input=source,
capture_output=True,
text=True,
timeout=30,
["black", "-q", "-"], check=False, input=source, capture_output=True, text=True, timeout=30
)
if result.returncode == 0:
return result.stdout
@ -397,19 +387,11 @@ class PythonSupport:
junit_xml = output_dir / "pytest-results.xml"
# Build pytest command
cmd = [
"python",
"-m",
"pytest",
f"--junitxml={junit_xml}",
"-v",
]
cmd = ["python", "-m", "pytest", f"--junitxml={junit_xml}", "-v"]
cmd.extend(str(f) for f in test_files)
try:
result = subprocess.run(
cmd, check=False, cwd=cwd, env=env, capture_output=True, text=True, timeout=timeout
)
result = subprocess.run(cmd, check=False, cwd=cwd, env=env, capture_output=True, text=True, timeout=timeout)
results = self.parse_test_results(junit_xml, result.stdout)
return results, junit_xml
@ -653,11 +635,7 @@ class PythonSupport:
"""
# Common test directory patterns for Python
test_dirs = [
project_root / "tests",
project_root / "test",
project_root / "spec",
]
test_dirs = [project_root / "tests", project_root / "test", project_root / "spec"]
for test_dir in test_dirs:
if test_dir.exists() and test_dir.is_dir():
@ -669,9 +647,7 @@ class PythonSupport:
return None
def get_module_path(
self, source_file: Path, project_root: Path, tests_root: Path | None = None
) -> str:
def get_module_path(self, source_file: Path, project_root: Path, tests_root: Path | None = None) -> str:
"""Get the module path for importing a Python source file.
For Python, this returns a dot-separated module path (e.g., 'mypackage.mymodule').
@ -778,4 +754,4 @@ class PythonSupport:
# Note: For Python, test execution is handled by the main test_runner.py
# which has special Python-specific logic. These methods are not called
# for Python as the test_runner checks is_python() and uses the existing path.
# They are defined here only for protocol compliance.
# They are defined here only for protocol compliance.

View file

@ -1239,9 +1239,16 @@ class TreeSitterAnalyzer:
return type_names
def _find_function_node(self, node: Node, source_bytes: bytes, function_name: str, function_line: int) -> Node | None:
def _find_function_node(
self, node: Node, source_bytes: bytes, function_name: str, function_line: int
) -> Node | None:
"""Find a function/method node by name and line number."""
if node.type in ("function_declaration", "method_definition", "function_expression", "generator_function_declaration"):
if node.type in (
"function_declaration",
"method_definition",
"function_expression",
"generator_function_declaration",
):
name_node = node.child_by_field_name("name")
if name_node:
name = self.get_node_text(name_node, source_bytes)
@ -1306,14 +1313,40 @@ class TreeSitterAnalyzer:
if node.type == "type_identifier":
type_name = self.get_node_text(node, source_bytes)
# Skip primitive types
if type_name not in ("number", "string", "boolean", "void", "null", "undefined", "any", "never", "unknown", "object", "symbol", "bigint"):
if type_name not in (
"number",
"string",
"boolean",
"void",
"null",
"undefined",
"any",
"never",
"unknown",
"object",
"symbol",
"bigint",
):
type_names.add(type_name)
return
# Handle regular identifiers in type position (can happen in some contexts)
if node.type == "identifier" and node.parent and node.parent.type in ("type_annotation", "generic_type"):
type_name = self.get_node_text(node, source_bytes)
if type_name not in ("number", "string", "boolean", "void", "null", "undefined", "any", "never", "unknown", "object", "symbol", "bigint"):
if type_name not in (
"number",
"string",
"boolean",
"void",
"null",
"undefined",
"any",
"never",
"unknown",
"object",
"symbol",
"bigint",
):
type_names.add(type_name)
return
@ -1360,7 +1393,12 @@ class TreeSitterAnalyzer:
# Handle export statements - unwrap to get the actual definition
if node.type == "export_statement":
for child in node.children:
if child.type in ("interface_declaration", "type_alias_declaration", "class_declaration", "enum_declaration"):
if child.type in (
"interface_declaration",
"type_alias_declaration",
"class_declaration",
"enum_declaration",
):
self._extract_type_definition(child, source_bytes, definitions, is_exported=True)
return

View file

@ -41,7 +41,6 @@ from codeflash.code_utils.code_utils import (
extract_unique_errors,
file_name_from_test_module_name,
get_run_tmp_file,
module_name_from_file_path,
normalize_by_max,
restore_conftest,
unified_diff_strings,
@ -81,7 +80,6 @@ from codeflash.languages import is_python
from codeflash.languages.base import FunctionInfo, Language
from codeflash.languages.current import current_language_support, is_typescript
from codeflash.languages.javascript.module_system import detect_module_system
from codeflash.languages.registry import get_language_support
from codeflash.lsp.helpers import is_LSP_enabled, report_to_markdown_table, tree_to_markdown
from codeflash.lsp.lsp_message import LspCodeMessage, LspMarkdownMessage, LSPMessageId
from codeflash.models.ExperimentMetadata import ExperimentMetadata

View file

@ -31,7 +31,9 @@ def existing_tests_source_for(
optimized_runtimes_all: dict[InvocationId, list[int]],
test_files_registry: TestFiles | None = None,
) -> tuple[str, str, str]:
logger.debug(f"[PR-DEBUG] existing_tests_source_for called with func={function_qualified_name_with_modules_from_root}")
logger.debug(
f"[PR-DEBUG] existing_tests_source_for called with func={function_qualified_name_with_modules_from_root}"
)
logger.debug(f"[PR-DEBUG] function_to_tests keys: {list(function_to_tests.keys())}")
logger.debug(f"[PR-DEBUG] original_runtimes_all has {len(original_runtimes_all)} entries")
logger.debug(f"[PR-DEBUG] optimized_runtimes_all has {len(optimized_runtimes_all)} entries")
@ -60,11 +62,17 @@ def existing_tests_source_for(
for tf in test_files_registry.test_files:
if tf.original_file_path:
if tf.instrumented_behavior_file_path:
instrumented_to_original[tf.instrumented_behavior_file_path.resolve()] = tf.original_file_path.resolve()
logger.debug(f"[PR-DEBUG] Mapping (behavior): {tf.instrumented_behavior_file_path.name} -> {tf.original_file_path.name}")
instrumented_to_original[tf.instrumented_behavior_file_path.resolve()] = (
tf.original_file_path.resolve()
)
logger.debug(
f"[PR-DEBUG] Mapping (behavior): {tf.instrumented_behavior_file_path.name} -> {tf.original_file_path.name}"
)
if tf.benchmarking_file_path:
instrumented_to_original[tf.benchmarking_file_path.resolve()] = tf.original_file_path.resolve()
logger.debug(f"[PR-DEBUG] Mapping (perf): {tf.benchmarking_file_path.name} -> {tf.original_file_path.name}")
logger.debug(
f"[PR-DEBUG] Mapping (perf): {tf.benchmarking_file_path.name} -> {tf.original_file_path.name}"
)
# Resolve all paths to absolute for consistent comparison
non_generated_tests: set[Path] = set()
@ -84,9 +92,22 @@ def existing_tests_source_for(
# For Python, it's a module name (e.g., "tests.test_example") that needs conversion
test_module_path = invocation_id.test_module_path
# Jest test file extensions (including .test.ts, .spec.ts patterns)
jest_test_extensions = (".test.ts", ".test.js", ".test.tsx", ".test.jsx",
".spec.ts", ".spec.js", ".spec.tsx", ".spec.jsx",
".ts", ".js", ".tsx", ".jsx", ".mjs", ".mts")
jest_test_extensions = (
".test.ts",
".test.js",
".test.tsx",
".test.jsx",
".spec.ts",
".spec.js",
".spec.tsx",
".spec.jsx",
".ts",
".js",
".tsx",
".jsx",
".mjs",
".mts",
)
# Find the appropriate extension
matched_ext = None
for ext in jest_test_extensions:
@ -96,7 +117,7 @@ def existing_tests_source_for(
if matched_ext:
# JavaScript/TypeScript: convert module-style path to file path
# "tests.fibonacci__perfinstrumented.test.ts" -> "tests/fibonacci__perfinstrumented.test.ts"
base_path = test_module_path[:-len(matched_ext)]
base_path = test_module_path[: -len(matched_ext)]
# Convert dots to path separators in the base path
file_path = base_path.replace(".", os.sep) + matched_ext
# Check if the module path includes the tests directory name

View file

@ -25,10 +25,7 @@ if TYPE_CHECKING:
def generate_concolic_tests(
test_cfg: TestConfig,
args: Namespace,
function_to_optimize: FunctionToOptimize,
function_to_optimize_ast: ast.AST,
test_cfg: TestConfig, args: Namespace, function_to_optimize: FunctionToOptimize, function_to_optimize_ast: ast.AST
) -> tuple[dict[str, set[FunctionCalledInTest]], str]:
"""Generate concolic tests using CrossHair (Python only).
@ -43,6 +40,7 @@ def generate_concolic_tests(
Returns:
Tuple of (function_to_tests mapping, concolic test suite code)
"""
start_time = time.perf_counter()
function_to_concolic_tests = {}

View file

@ -20,16 +20,14 @@ if TYPE_CHECKING:
from codeflash.models.models import CodeOptimizationContext
# TODO:{self} Needs cleanup for jest logic check for coverage algorithm here and if we need to move it to /support
class JestCoverageUtils:
"""Coverage utils class for interfacing with Jest coverage output."""
@staticmethod
def load_from_jest_json(
coverage_json_path: Path,
function_name: str,
code_context: CodeOptimizationContext,
source_code_path: Path,
coverage_json_path: Path, function_name: str, code_context: CodeOptimizationContext, source_code_path: Path
) -> CoverageData:
"""Load coverage data from Jest's coverage-final.json file.
@ -41,6 +39,7 @@ class JestCoverageUtils:
Returns:
CoverageData object with parsed coverage information
"""
if not coverage_json_path or not coverage_json_path.exists():
logger.debug(f"Jest coverage file not found: {coverage_json_path}")

View file

@ -18,6 +18,7 @@ reprlib_repr = reprlib.Repr()
reprlib_repr.maxstring = 1500
test_diff_repr = reprlib_repr.repr
def safe_repr(obj: object) -> str:
"""Safely get repr of an object, handling Mock objects with corrupted state."""
try:
@ -25,6 +26,7 @@ def safe_repr(obj: object) -> str:
except (AttributeError, TypeError, RecursionError) as e:
return f"<repr failed: {type(e).__name__}: {e}>"
def compare_test_results(
original_results: TestResults,
candidate_results: TestResults,

View file

@ -246,20 +246,30 @@ def parse_jest_json_results(
# Check behavior path
if test_file.instrumented_behavior_file_path:
try:
rel_path = str(test_file.instrumented_behavior_file_path.relative_to(test_config.tests_project_rootdir))
rel_path = str(
test_file.instrumented_behavior_file_path.relative_to(test_config.tests_project_rootdir)
)
except ValueError:
rel_path = test_file.instrumented_behavior_file_path.name
if rel_path == expected_path or rel_path.replace("/", ".").replace(".js", "") == result_module_path:
if (
rel_path == expected_path
or rel_path.replace("/", ".").replace(".js", "") == result_module_path
):
test_file_path = test_file.instrumented_behavior_file_path
test_type = test_file.test_type
break
# Check benchmarking path
if test_file.benchmarking_file_path:
try:
rel_path = str(test_file.benchmarking_file_path.relative_to(test_config.tests_project_rootdir))
rel_path = str(
test_file.benchmarking_file_path.relative_to(test_config.tests_project_rootdir)
)
except ValueError:
rel_path = test_file.benchmarking_file_path.name
if rel_path == expected_path or rel_path.replace("/", ".").replace(".js", "") == result_module_path:
if (
rel_path == expected_path
or rel_path.replace("/", ".").replace(".js", "") == result_module_path
):
test_file_path = test_file.benchmarking_file_path
test_type = test_file.test_type
break
@ -416,9 +426,22 @@ def parse_sqlite_test_results(sqlite_file_path: Path, test_files: TestFiles, tes
# For Python, it's a module path (e.g., "tests.test_foo") that needs conversion
if is_jest:
# Jest test file extensions (including .test.ts, .spec.ts patterns)
jest_test_extensions = (".test.ts", ".test.js", ".test.tsx", ".test.jsx",
".spec.ts", ".spec.js", ".spec.tsx", ".spec.jsx",
".ts", ".js", ".tsx", ".jsx", ".mjs", ".mts")
jest_test_extensions = (
".test.ts",
".test.js",
".test.tsx",
".test.jsx",
".spec.ts",
".spec.js",
".spec.tsx",
".spec.jsx",
".ts",
".js",
".tsx",
".jsx",
".mjs",
".mts",
)
# Check if it's a module-style path (no slashes, has dots beyond extension)
if "/" not in test_module_path and "\\" not in test_module_path:
# Find the appropriate extension to preserve
@ -430,7 +453,7 @@ def parse_sqlite_test_results(sqlite_file_path: Path, test_files: TestFiles, tes
if extension:
# Convert module-style path to file path
# "tests.fibonacci__perfinstrumented.test.ts" -> "tests/fibonacci__perfinstrumented.test.ts"
base_path = test_module_path[:-len(extension)]
base_path = test_module_path[: -len(extension)]
file_path = base_path.replace(".", os.sep) + extension
# Check if the module path includes the tests directory name
tests_dir_name = test_config.tests_project_rootdir.name
@ -560,6 +583,7 @@ def _extract_jest_console_output(suite_elem) -> str:
return raw_content
# TODO: {Claude} we need to move to the support directory.
def parse_jest_test_xml(
test_xml_file_path: Path,
@ -611,10 +635,7 @@ def parse_jest_test_xml(
if test_file.instrumented_behavior_file_path:
# Store both the absolute path and resolved path as keys
abs_path = str(test_file.instrumented_behavior_file_path.resolve())
instrumented_path_lookup[abs_path] = (
test_file.instrumented_behavior_file_path,
test_file.test_type,
)
instrumented_path_lookup[abs_path] = (test_file.instrumented_behavior_file_path, test_file.test_type)
# Also store the string representation in case of minor path differences
instrumented_path_lookup[str(test_file.instrumented_behavior_file_path)] = (
test_file.instrumented_behavior_file_path,

View file

@ -9,9 +9,7 @@ from pydantic.dataclasses import dataclass
from codeflash.languages import current_language_support, is_javascript
def get_test_file_path(
test_dir: Path, function_name: str, iteration: int = 0, test_type: str = "unit"
) -> Path:
def get_test_file_path(test_dir: Path, function_name: str, iteration: int = 0, test_type: str = "unit") -> Path:
assert test_type in {"unit", "inspired", "replay", "perf"}
function_name = function_name.replace(".", "_")
# Use appropriate file extension based on language

View file

@ -7,7 +7,7 @@ from codeflash.models.models import FunctionParent
from codeflash.optimization.optimizer import Optimizer
def test_benchmark_extract(benchmark)->None:
def test_benchmark_extract(benchmark) -> None:
file_path = Path(__file__).parent.parent.parent.resolve() / "codeflash"
opt = Optimizer(
Namespace(
@ -28,4 +28,4 @@ def test_benchmark_extract(benchmark)->None:
ending_line=None,
)
benchmark(get_code_optimization_context,function_to_optimize, opt.args.project_root)
benchmark(get_code_optimization_context, function_to_optimize, opt.args.project_root)

View file

@ -14,6 +14,8 @@ def test_benchmark_code_to_optimize_test_discovery(benchmark) -> None:
tests_project_rootdir=tests_path.parent,
)
benchmark(discover_unit_tests, test_config)
def test_benchmark_codeflash_test_discovery(benchmark) -> None:
project_path = Path(__file__).parent.parent.parent.resolve() / "codeflash"
tests_path = project_path / "tests"

View file

@ -60,11 +60,7 @@ def run_merge_benchmark(count=100):
test_results_xml, test_results_bin = generate_test_invocations(count)
# Perform the merge operation that will be benchmarked
merge_test_results(
xml_test_results=test_results_xml,
bin_test_results=test_results_bin,
test_framework="unittest"
)
merge_test_results(xml_test_results=test_results_xml, bin_test_results=test_results_bin, test_framework="unittest")
def test_benchmark_merge_test_results(benchmark):

View file

@ -2,15 +2,14 @@ from __future__ import annotations
import configparser
import os
import stat
from pathlib import Path
from unittest.mock import patch
import pytest
import tomlkit
from codeflash.code_utils.code_utils import custom_addopts
def test_custom_addopts_modifies_and_restores_dotini_file(tmp_path: Path) -> None:
"""Verify that custom_addopts correctly modifies and then restores a pytest.ini file."""
# Create a dummy pytest.ini file
@ -32,6 +31,7 @@ def test_custom_addopts_modifies_and_restores_dotini_file(tmp_path: Path) -> Non
restored_content = config_file.read_text()
assert restored_content.strip() == original_content.strip()
def test_custom_addopts_modifies_and_restores_ini_file(tmp_path: Path) -> None:
"""Verify that custom_addopts correctly modifies and then restores a pytest.ini file."""
# Create a dummy pytest.ini file
@ -60,9 +60,7 @@ def test_custom_addopts_modifies_and_restores_toml_file(tmp_path: Path) -> None:
config_file = tmp_path / "pyproject.toml"
os.chdir(tmp_path)
original_addopts = "-v --cov=./src --junitxml=report.xml"
original_content_dict = {
"tool": {"pytest": {"ini_options": {"addopts": original_addopts}}}
}
original_content_dict = {"tool": {"pytest": {"ini_options": {"addopts": original_addopts}}}}
original_content = tomlkit.dumps(original_content_dict)
config_file.write_text(original_content)
@ -97,6 +95,7 @@ def test_custom_addopts_handles_no_addopts(tmp_path: Path) -> None:
content_after_context = config_file.read_text()
assert content_after_context == original_content
def test_custom_addopts_handles_no_relevant_files(tmp_path: Path) -> None:
"""Ensure custom_addopts runs without error when no config files are found."""
# No config files created in tmp_path
@ -151,9 +150,7 @@ def test_custom_addopts_with_multiple_config_files(tmp_path: Path) -> None:
# Create pyproject.toml
toml_file = tmp_path / "pyproject.toml"
toml_original_addopts = "-s -n auto"
toml_original_content_dict = {
"tool": {"pytest": {"ini_options": {"addopts": toml_original_addopts}}}
}
toml_original_content_dict = {"tool": {"pytest": {"ini_options": {"addopts": toml_original_addopts}}}}
toml_original_content = tomlkit.dumps(toml_original_content_dict)
toml_file.write_text(toml_original_content)
@ -182,9 +179,8 @@ def test_custom_addopts_restores_on_exception(tmp_path: Path) -> None:
config_file.write_text(original_content)
os.chdir(tmp_path)
with pytest.raises(ValueError, match="Test exception"):
with custom_addopts():
raise ValueError("Test exception")
with pytest.raises(ValueError, match="Test exception"), custom_addopts():
raise ValueError("Test exception")
restored_content = config_file.read_text()
assert restored_content.strip() == original_content.strip()

View file

@ -106,10 +106,7 @@ class TestDetectLanguage:
def test_detects_typescript_with_complex_tsconfig(self, tmp_path: Path) -> None:
"""Should detect TypeScript even with complex tsconfig."""
tsconfig = {
"compilerOptions": {"target": "ES2020", "module": "commonjs"},
"include": ["src/**/*"],
}
tsconfig = {"compilerOptions": {"target": "ES2020", "module": "commonjs"}, "include": ["src/**/*"]}
(tmp_path / "tsconfig.json").write_text(json.dumps(tsconfig))
result = detect_language(tmp_path)
@ -784,11 +781,7 @@ class TestRealWorldPackageJsonExamples:
package_json = tmp_path / "package.json"
package_json.write_text(
json.dumps(
{
"name": "@myorg/core",
"main": "./packages/core/src/index.js",
"devDependencies": {"jest": "^29.0.0"},
}
{"name": "@myorg/core", "main": "./packages/core/src/index.js", "devDependencies": {"jest": "^29.0.0"}}
)
)

View file

@ -5,8 +5,6 @@ from __future__ import annotations
import json
from pathlib import Path
import pytest
from codeflash.cli_cmds.init_javascript import (
JsPackageManager,
get_js_codeflash_install_step,

View file

@ -14,4 +14,4 @@ def set_python_language():
reset_current_language()
yield
# Reset again after test to clean up any changes
reset_current_language()
reset_current_language()

View file

@ -1 +1 @@
MY_CONSTANT = 7
MY_CONSTANT = 7

View file

@ -5,9 +5,7 @@ from end_to_end_test_utilities import TestConfig, run_codeflash_command, run_wit
def run_test(expected_improvement_pct: int) -> bool:
config = TestConfig(
file_path="bubble_sort.py", function_name="sorter", min_improvement_x=0.30, no_gen_tests=True
)
config = TestConfig(file_path="bubble_sort.py", function_name="sorter", min_improvement_x=0.30, no_gen_tests=True)
cwd = (pathlib.Path(__file__).parent.parent.parent / "code_to_optimize").resolve()
return run_codeflash_command(cwd, config, expected_improvement_pct)

View file

@ -18,12 +18,7 @@ def run_test() -> bool:
expected_test_files=1, # At least one test file should be instrumented
)
cwd = (
pathlib.Path(__file__).parent.parent.parent
/ "code_to_optimize"
/ "js"
/ "code_to_optimize_js"
).resolve()
cwd = (pathlib.Path(__file__).parent.parent.parent / "code_to_optimize" / "js" / "code_to_optimize_js").resolve()
return run_js_codeflash_command(cwd, config)

View file

@ -20,10 +20,7 @@ def run_test() -> bool:
)
cwd = (
pathlib.Path(__file__).parent.parent.parent
/ "code_to_optimize"
/ "js"
/ "code_to_optimize_js_esm"
pathlib.Path(__file__).parent.parent.parent / "code_to_optimize" / "js" / "code_to_optimize_js_esm"
).resolve()
return run_js_codeflash_command(cwd, config)

View file

@ -18,12 +18,7 @@ def run_test() -> bool:
expected_test_files=1,
)
cwd = (
pathlib.Path(__file__).parent.parent.parent
/ "code_to_optimize"
/ "js"
/ "code_to_optimize_ts"
).resolve()
cwd = (pathlib.Path(__file__).parent.parent.parent / "code_to_optimize" / "js" / "code_to_optimize_ts").resolve()
return run_js_codeflash_command(cwd, config)

View file

@ -86,12 +86,14 @@ def validate_coverage(stdout: str, expectations: list[CoverageExpectation]) -> b
return True
def validate_no_gen_tests(stdout: str) -> bool:
if "Generated '0' tests for" not in stdout:
logging.error("Tests generated even when flag was on")
return False
return True
def run_codeflash_command(
cwd: pathlib.Path, config: TestConfig, expected_improvement_pct: int, expected_in_stdout: list[str] = None
) -> bool:
@ -106,9 +108,9 @@ def run_codeflash_command(
command = build_command(cwd, config, test_root, config.benchmarks_root if config.benchmarks_root else None)
env = os.environ.copy()
env['PYTHONIOENCODING'] = 'utf-8'
env["PYTHONIOENCODING"] = "utf-8"
process = subprocess.Popen(
command, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True, cwd=str(cwd), env=env, encoding='utf-8'
command, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True, cwd=str(cwd), env=env, encoding="utf-8"
)
output = []
@ -131,7 +133,7 @@ def run_codeflash_command(
if not stdout_validated:
logging.error("Failed to find expected output in candidate output")
validated = False
logging.info(f"Success: Expected output found in candidate output")
logging.info("Success: Expected output found in candidate output")
return validated
@ -150,10 +152,9 @@ def build_command(
pyproject_path = cwd / "pyproject.toml"
has_codeflash_config = False
if pyproject_path.exists():
with contextlib.suppress(Exception):
with open(pyproject_path, "rb") as f:
pyproject_data = tomllib.load(f)
has_codeflash_config = "tool" in pyproject_data and "codeflash" in pyproject_data["tool"]
with contextlib.suppress(Exception), open(pyproject_path, "rb") as f:
pyproject_data = tomllib.load(f)
has_codeflash_config = "tool" in pyproject_data and "codeflash" in pyproject_data["tool"]
# Only pass --tests-root and --module-root if they're not configured in pyproject.toml
if not has_codeflash_config:
@ -206,7 +207,9 @@ def validate_output(stdout: str, return_code: int, expected_improvement_pct: int
if config.expected_unit_tests_count is not None:
# Match the global test discovery message from optimizer.py which counts test invocations
# Format: "Discovered X existing unit tests and Y replay tests in Z.Zs at /path/to/tests"
unit_test_match = re.search(r"Discovered (\d+) existing unit tests? and \d+ replay tests? in [\d.]+s at", stdout)
unit_test_match = re.search(
r"Discovered (\d+) existing unit tests? and \d+ replay tests? in [\d.]+s at", stdout
)
if not unit_test_match:
logging.error("Could not find global unit test count")
return False
@ -250,9 +253,9 @@ def run_trace_test(cwd: pathlib.Path, config: TestConfig, expected_improvement_p
clear_directory(test_root)
command = ["uv", "run", "--no-project", "-m", "codeflash.main", "optimize", "workload.py"]
env = os.environ.copy()
env['PYTHONIOENCODING'] = 'utf-8'
env["PYTHONIOENCODING"] = "utf-8"
process = subprocess.Popen(
command, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True, cwd=str(cwd), env=env, encoding='utf-8'
command, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True, cwd=str(cwd), env=env, encoding="utf-8"
)
output = []

View file

@ -10,7 +10,7 @@ import re
import shutil
import subprocess
import time
from dataclasses import dataclass, field
from dataclasses import dataclass
from typing import Optional
@ -42,36 +42,20 @@ def install_npm_dependencies(cwd: pathlib.Path) -> bool:
node_modules = cwd / "node_modules"
if not node_modules.exists():
logging.info(f"Installing npm dependencies in {cwd}")
result = subprocess.run(
["npm", "install"],
cwd=str(cwd),
capture_output=True,
text=True,
)
result = subprocess.run(["npm", "install"], cwd=str(cwd), capture_output=True, text=True)
if result.returncode != 0:
logging.error(f"npm install failed: {result.stderr}")
return False
return True
def build_js_command(
cwd: pathlib.Path,
config: JSTestConfig,
) -> list[str]:
def build_js_command(cwd: pathlib.Path, config: JSTestConfig) -> list[str]:
"""Build the codeflash CLI command for JS/TS optimization."""
# JS projects are at code_to_optimize/js/code_to_optimize_*, which is 3 levels deep
# So we need ../../../codeflash/main.py to get to the root
python_path = "../../../codeflash/main.py"
base_command = [
"uv",
"run",
"--no-project",
python_path,
"--file",
str(config.file_path),
"--no-pr",
]
base_command = ["uv", "run", "--no-project", python_path, "--file", str(config.file_path), "--no-pr"]
if config.function_name:
base_command.extend(["--function", config.function_name])
@ -79,11 +63,7 @@ def build_js_command(
return base_command
def validate_js_output(
stdout: str,
return_code: int,
config: JSTestConfig,
) -> bool:
def validate_js_output(stdout: str, return_code: int, config: JSTestConfig) -> bool:
"""Validate the output of a JS/TS optimization run."""
if return_code != 0:
logging.error(f"Command returned exit code {return_code} instead of 0")
@ -104,15 +84,11 @@ def validate_js_output(
logging.info(f"Performance improvement: {improvement_pct}%; Rate: {improvement_x}x")
if improvement_pct <= config.expected_improvement_pct:
logging.error(
f"Performance improvement {improvement_pct}% not above {config.expected_improvement_pct}%"
)
logging.error(f"Performance improvement {improvement_pct}% not above {config.expected_improvement_pct}%")
return False
if improvement_x <= config.min_improvement_x:
logging.error(
f"Performance improvement rate {improvement_x}x not above {config.min_improvement_x}x"
)
logging.error(f"Performance improvement rate {improvement_x}x not above {config.min_improvement_x}x")
return False
if config.expected_test_files is not None:
@ -124,19 +100,14 @@ def validate_js_output(
num_test_files = int(test_files_match.group(1))
if num_test_files < config.expected_test_files:
logging.error(
f"Expected at least {config.expected_test_files} test files, found {num_test_files}"
)
logging.error(f"Expected at least {config.expected_test_files} test files, found {num_test_files}")
return False
logging.info(f"Success: Performance improvement is {improvement_pct}%")
return True
def run_js_codeflash_command(
cwd: pathlib.Path,
config: JSTestConfig,
) -> bool:
def run_js_codeflash_command(cwd: pathlib.Path, config: JSTestConfig) -> bool:
"""Run codeflash optimization on a JavaScript/TypeScript project."""
logging.basicConfig(level=logging.INFO)
@ -159,13 +130,7 @@ def run_js_codeflash_command(
logging.info(f"Running: {' '.join(command)}")
process = subprocess.Popen(
command,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
text=True,
cwd=str(cwd),
env=env,
encoding="utf-8",
command, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True, cwd=str(cwd), env=env, encoding="utf-8"
)
output = []
@ -210,4 +175,4 @@ def run_with_retries(test_func, *args, **kwargs) -> int:
logging.error("Test failed after all retries")
return 1
return 1
return 1

View file

@ -9,6 +9,7 @@ Examples:
python run_js_e2e_tests.py # Run all tests sequentially
python run_js_e2e_tests.py --test fibonacci # Run only fibonacci tests
python run_js_e2e_tests.py --parallel # Run tests in parallel
"""
import argparse
@ -62,19 +63,14 @@ def run_single_test(test_file: str) -> TestResult:
output = f"Error running test: {e}"
duration = time.time() - start_time
return TestResult(
name=test_file.replace(".py", ""),
success=success,
duration=duration,
output=output,
)
return TestResult(name=test_file.replace(".py", ""), success=success, duration=duration, output=output)
def run_tests_sequential(tests: list[str]) -> list[TestResult]:
"""Run tests sequentially."""
results = []
for test in tests:
print(f"\n{'='*60}")
print(f"\n{'=' * 60}")
print(f"Running: {test}")
print("=" * 60)
result = run_single_test(test)
@ -124,22 +120,9 @@ def print_summary(results: list[TestResult]) -> None:
def main() -> int:
parser = argparse.ArgumentParser(description="Run JS/TS e2e tests")
parser.add_argument(
"--test",
type=str,
help="Run only tests matching this pattern",
)
parser.add_argument(
"--parallel",
action="store_true",
help="Run tests in parallel",
)
parser.add_argument(
"--workers",
type=int,
default=4,
help="Number of parallel workers (default: 4)",
)
parser.add_argument("--test", type=str, help="Run only tests matching this pattern")
parser.add_argument("--parallel", action="store_true", help="Run tests in parallel")
parser.add_argument("--workers", type=int, default=4, help="Number of parallel workers (default: 4)")
args = parser.parse_args()
# Filter tests if pattern specified

View file

@ -1,13 +1,18 @@
import tempfile
from pathlib import Path
from codeflash.code_utils.code_extractor import add_needed_imports_from_module, find_preexisting_objects
from codeflash.code_utils.code_replacer import replace_functions_and_add_imports
import tempfile
from codeflash.code_utils.code_extractor import resolve_star_import, DottedImportCollector
import libcst as cst
from codeflash.code_utils.code_extractor import (
DottedImportCollector,
add_needed_imports_from_module,
find_preexisting_objects,
resolve_star_import,
)
from codeflash.code_utils.code_replacer import replace_functions_and_add_imports
from codeflash.models.models import FunctionParent
def test_add_needed_imports_from_module0() -> None:
src_module = '''import ast
import logging
@ -127,8 +132,9 @@ def belongs_to_function(name: Name, function_name: str) -> bool:
new_module = add_needed_imports_from_module(src_module, dst_module, src_path, dst_path, project_root)
assert new_module == expected
def test_duplicated_imports() -> None:
optim_code = '''from dataclasses import dataclass
optim_code = """from dataclasses import dataclass
from recce.adapter.base import BaseAdapter
from typing import Dict, List, Optional
@ -151,9 +157,9 @@ class DbtAdapter(BaseAdapter):
parent_map[k] = [parent for parent in parents if parent in node_ids]
return parent_map
'''
"""
original_code = '''import json
original_code = """import json
import logging
import os
import uuid
@ -244,8 +250,8 @@ class DbtAdapter(BaseAdapter):
parent_map[k] = [parent for parent in parents if parent in node_ids]
return parent_map
'''
expected = '''import json
"""
expected = """import json
import logging
import os
import uuid
@ -340,7 +346,7 @@ class DbtAdapter(BaseAdapter):
parent_map[k] = [parent for parent in parents if parent in node_ids]
return parent_map
'''
"""
function_name: str = "DbtAdapter.build_parent_map"
preexisting_objects: set[tuple[str, tuple[FunctionParent, ...]]] = find_preexisting_objects(original_code)
@ -355,14 +361,12 @@ class DbtAdapter(BaseAdapter):
assert new_code == expected
def test_resolve_star_import_with_all_defined():
"""Test resolve_star_import when __all__ is explicitly defined."""
with tempfile.TemporaryDirectory() as tmpdir:
project_root = Path(tmpdir)
test_module = project_root / 'test_module.py'
test_module = project_root / "test_module.py"
# Create a test module with __all__ definition
test_module.write_text('''
__all__ = ['public_function', 'PublicClass']
@ -380,9 +384,9 @@ class AnotherPublicClass:
"""Not in __all__ so should be excluded."""
pass
''')
symbols = resolve_star_import('test_module', project_root)
expected_symbols = {'public_function', 'PublicClass'}
symbols = resolve_star_import("test_module", project_root)
expected_symbols = {"public_function", "PublicClass"}
assert symbols == expected_symbols
@ -390,10 +394,10 @@ def test_resolve_star_import_without_all_defined():
"""Test resolve_star_import when __all__ is not defined - should include all public symbols."""
with tempfile.TemporaryDirectory() as tmpdir:
project_root = Path(tmpdir)
test_module = project_root / 'test_module.py'
test_module = project_root / "test_module.py"
# Create a test module without __all__ definition
test_module.write_text('''
test_module.write_text("""
def public_func():
pass
@ -405,10 +409,10 @@ class PublicClass:
PUBLIC_VAR = 42
_private_var = 'secret'
''')
symbols = resolve_star_import('test_module', project_root)
expected_symbols = {'public_func', 'PublicClass', 'PUBLIC_VAR'}
""")
symbols = resolve_star_import("test_module", project_root)
expected_symbols = {"public_func", "PublicClass", "PUBLIC_VAR"}
assert symbols == expected_symbols
@ -416,26 +420,26 @@ def test_resolve_star_import_nonexistent_module():
"""Test resolve_star_import with non-existent module - should return empty set."""
with tempfile.TemporaryDirectory() as tmpdir:
project_root = Path(tmpdir)
symbols = resolve_star_import('nonexistent_module', project_root)
symbols = resolve_star_import("nonexistent_module", project_root)
assert symbols == set()
def test_dotted_import_collector_skips_star_imports():
"""Test that DottedImportCollector correctly skips star imports."""
code_with_star_import = '''
code_with_star_import = """
from typing import *
from pathlib import Path
from collections import defaultdict
import os
'''
"""
module = cst.parse_module(code_with_star_import)
collector = DottedImportCollector()
module.visit(collector)
# Should collect regular imports but skip the star import
expected_imports = {'collections.defaultdict', 'os', 'pathlib.Path'}
expected_imports = {"collections.defaultdict", "os", "pathlib.Path"}
assert collector.imports == expected_imports
@ -443,10 +447,10 @@ def test_add_needed_imports_with_star_import_resolution():
"""Test add_needed_imports_from_module correctly handles star imports by resolving them."""
with tempfile.TemporaryDirectory() as tmpdir:
project_root = Path(tmpdir)
# Create a source module that exports symbols
src_module = project_root / 'source_module.py'
src_module.write_text('''
src_module = project_root / "source_module.py"
src_module.write_text("""
__all__ = ['UtilFunction', 'HelperClass']
def UtilFunction():
@ -454,40 +458,38 @@ def UtilFunction():
class HelperClass:
pass
''')
""")
# Create source code that uses star import
src_code = '''
src_code = """
from source_module import *
def my_function():
helper = HelperClass()
UtilFunction()
return helper
'''
"""
# Destination code that needs the imports resolved
dst_code = '''
dst_code = """
def my_function():
helper = HelperClass()
UtilFunction()
return helper
'''
src_path = project_root / 'src.py'
dst_path = project_root / 'dst.py'
"""
src_path = project_root / "src.py"
dst_path = project_root / "dst.py"
src_path.write_text(src_code)
result = add_needed_imports_from_module(
src_code, dst_code, src_path, dst_path, project_root
)
result = add_needed_imports_from_module(src_code, dst_code, src_path, dst_path, project_root)
# The result should have individual imports instead of star import
expected_result = '''from source_module import HelperClass, UtilFunction
expected_result = """from source_module import HelperClass, UtilFunction
def my_function():
helper = HelperClass()
UtilFunction()
return helper
'''
"""
assert result == expected_result

File diff suppressed because it is too large Load diff

View file

@ -157,10 +157,7 @@ class TestParseConcurrencyMetrics:
!@######CONC:test_module:TestClass:test_func:my_async_func:1:50000000:10000000:5######@!
More output here
"""
test_results = TestResults(
test_results=[],
perf_stdout=perf_stdout,
)
test_results = TestResults(test_results=[], perf_stdout=perf_stdout)
metrics = parse_concurrency_metrics(test_results, "my_async_func")
@ -177,10 +174,7 @@ More output here
!@######CONC:test_module:TestClass:test_func:target_func:2:60000000:10000000:5######@!
!@######CONC:test_module:TestClass:test_func:other_func:1:30000000:15000000:5######@!
"""
test_results = TestResults(
test_results=[],
perf_stdout=perf_stdout,
)
test_results = TestResults(test_results=[], perf_stdout=perf_stdout)
metrics = parse_concurrency_metrics(test_results, "target_func")
@ -195,10 +189,7 @@ More output here
"""Test parsing when function name doesn't match."""
perf_stdout = """!@######CONC:test_module:TestClass:test_func:other_func:1:50000000:10000000:5######@!
"""
test_results = TestResults(
test_results=[],
perf_stdout=perf_stdout,
)
test_results = TestResults(test_results=[], perf_stdout=perf_stdout)
metrics = parse_concurrency_metrics(test_results, "nonexistent_func")
@ -206,10 +197,7 @@ More output here
def test_parse_concurrency_metrics_empty_stdout(self):
"""Test parsing with empty stdout."""
test_results = TestResults(
test_results=[],
perf_stdout="",
)
test_results = TestResults(test_results=[], perf_stdout="")
metrics = parse_concurrency_metrics(test_results, "any_func")
@ -217,10 +205,7 @@ More output here
def test_parse_concurrency_metrics_none_stdout(self):
"""Test parsing with None stdout."""
test_results = TestResults(
test_results=[],
perf_stdout=None,
)
test_results = TestResults(test_results=[], perf_stdout=None)
metrics = parse_concurrency_metrics(test_results, "any_func")
@ -293,8 +278,7 @@ class TestConcurrencyRatioComparison:
# Non-blocking should have significantly higher concurrency ratio
assert nonblocking_ratio > blocking_ratio, (
f"Non-blocking ratio ({nonblocking_ratio:.2f}) should be greater than "
f"blocking ratio ({blocking_ratio:.2f})"
f"Non-blocking ratio ({nonblocking_ratio:.2f}) should be greater than blocking ratio ({blocking_ratio:.2f})"
)
# The difference should be substantial (non-blocking should be at least 2x better)

View file

@ -1,6 +1,7 @@
import sys
import tempfile
from pathlib import Path
import sys
import pytest
from codeflash.discovery.functions_to_optimize import (
@ -31,13 +32,13 @@ async def async_function_without_return():
def regular_function():
return 10
"""
file_path = temp_dir / "test_file.py"
file_path.write_text(async_function)
functions_found = find_all_functions_in_file(file_path)
function_names = [fn.function_name for fn in functions_found[file_path]]
assert "async_function_with_return" in function_names
assert "regular_function" in function_names
assert "async_function_without_return" not in function_names
@ -58,21 +59,21 @@ class AsyncClass:
def sync_method(self):
return "sync result"
"""
file_path = temp_dir / "test_file.py"
file_path.write_text(code_with_async_method)
functions_found = find_all_functions_in_file(file_path)
found_functions = functions_found[file_path]
function_names = [fn.function_name for fn in found_functions]
qualified_names = [fn.qualified_name for fn in found_functions]
assert "async_method" in function_names
assert "AsyncClass.async_method" in qualified_names
assert "sync_method" in function_names
assert "AsyncClass.sync_method" in qualified_names
assert "async_method_no_return" not in function_names
@ -92,13 +93,13 @@ def outer_sync():
return inner_async
"""
file_path = temp_dir / "test_file.py"
file_path.write_text(nested_async)
functions_found = find_all_functions_in_file(file_path)
function_names = [fn.function_name for fn in functions_found[file_path]]
assert "outer_async" in function_names
assert "outer_sync" in function_names
assert "inner_async" not in function_names
@ -122,16 +123,16 @@ class MyClass:
async def async_property(self):
return await self.get_value()
"""
file_path = temp_dir / "test_file.py"
file_path.write_text(async_decorators)
functions_found = find_all_functions_in_file(file_path)
function_names = [fn.function_name for fn in functions_found[file_path]]
assert "async_static_method" in function_names
assert "async_class_method" in function_names
assert "async_property" not in function_names
@ -151,13 +152,13 @@ async def regular_async_with_return():
result = await compute()
return result
"""
file_path = temp_dir / "test_file.py"
file_path.write_text(async_generators)
functions_found = find_all_functions_in_file(file_path)
function_names = [fn.function_name for fn in functions_found[file_path]]
assert "async_generator_with_return" in function_names
assert "regular_async_with_return" in function_names
assert "async_generator_no_return" not in function_names
@ -183,23 +184,23 @@ class AsyncContainer:
async def async_classmethod(cls):
return "classmethod"
"""
file_path = temp_dir / "test_file.py"
file_path.write_text(code)
result = inspect_top_level_functions_or_methods(file_path, "top_level_async")
assert result.is_top_level
result = inspect_top_level_functions_or_methods(file_path, "async_method", class_name="AsyncContainer")
assert result.is_top_level
result = inspect_top_level_functions_or_methods(file_path, "nested_async", class_name="AsyncContainer")
assert not result.is_top_level
result = inspect_top_level_functions_or_methods(file_path, "async_static", class_name="AsyncContainer")
assert result.is_top_level
assert result.is_staticmethod
result = inspect_top_level_functions_or_methods(file_path, "async_classmethod", class_name="AsyncContainer")
assert result.is_top_level
assert result.is_classmethod
@ -224,17 +225,14 @@ class MixedClass:
def sync_method(self):
return self.operation()
"""
file_path = temp_dir / "test_file.py"
file_path.write_text(mixed_code)
test_config = TestConfig(
tests_root="tests",
project_root_path=".",
test_framework="pytest",
tests_project_rootdir=Path()
tests_root="tests", project_root_path=".", test_framework="pytest", tests_project_rootdir=Path()
)
functions, functions_count, _ = get_functions_to_optimize(
optimize_all=None,
replay_test=None,
@ -245,15 +243,15 @@ class MixedClass:
project_root=file_path.parent,
module_root=file_path.parent,
)
assert functions_count == 4
function_names = [fn.function_name for fn in functions[file_path]]
assert "async_func_one" in function_names
assert "sync_func_one" in function_names
assert "async_method" in function_names
assert "sync_method" in function_names
assert "async_func_two" not in function_names
@ -277,17 +275,14 @@ class MixedClass:
def sync_method(self):
return self.operation()
"""
file_path = temp_dir / "test_file.py"
file_path.write_text(mixed_code)
test_config = TestConfig(
tests_root="tests",
project_root_path=".",
test_framework="pytest",
tests_project_rootdir=Path()
tests_root="tests", project_root_path=".", test_framework="pytest", tests_project_rootdir=Path()
)
functions, functions_count, _ = get_functions_to_optimize(
optimize_all=None,
replay_test=None,
@ -298,10 +293,10 @@ class MixedClass:
project_root=file_path.parent,
module_root=file_path.parent,
)
# Now async functions are always included, so we expect 4 functions (not 2)
assert functions_count == 4
function_names = [fn.function_name for fn in functions[file_path]]
assert "sync_func_one" in function_names
assert "sync_method" in function_names
@ -327,13 +322,13 @@ async def module_level_async():
return 3
return LocalClass()
"""
file_path = temp_dir / "test_file.py"
file_path.write_text(complex_structure)
functions_found = find_all_functions_in_file(file_path)
found_functions = functions_found[file_path]
for fn in found_functions:
if fn.function_name == "outer_method":
assert len(fn.parents) == 1
@ -345,4 +340,4 @@ async def module_level_async():
assert fn.parents[1].name == "InnerClass"
elif fn.function_name == "module_level_async":
assert len(fn.parents) == 0
assert fn.qualified_name == "module_level_async"
assert fn.qualified_name == "module_level_async"

View file

@ -7,11 +7,15 @@ from pathlib import Path
import pytest
from codeflash.code_utils.instrument_existing_tests import (
add_async_decorator_to_function,
inject_profiling_into_existing_test,
)
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.models.models import CodePosition, FunctionParent, TestFile, TestFiles, TestingMode, TestType
from codeflash.optimization.optimizer import Optimizer
from codeflash.verification.instrument_codeflash_capture import instrument_codeflash_capture
from codeflash.code_utils.instrument_existing_tests import add_async_decorator_to_function, inject_profiling_into_existing_test
@pytest.mark.skipif(sys.platform == "win32", reason="pending support for asyncio on windows")
def test_async_bubble_sort_behavior_results() -> None:
@ -51,15 +55,16 @@ async def test_async_sort():
func = FunctionToOptimize(function_name="async_sorter", parents=[], file_path=Path(fto_path), is_async=True)
# For async functions, instrument the source module directly with decorators
source_success = add_async_decorator_to_function(
fto_path, func, TestingMode.BEHAVIOR
)
source_success = add_async_decorator_to_function(fto_path, func, TestingMode.BEHAVIOR)
assert source_success
# Verify the file was modified
instrumented_source = fto_path.read_text("utf-8")
assert '''import asyncio\nfrom typing import List, Union\n\nfrom codeflash.code_utils.codeflash_wrap_decorator import \\\n codeflash_behavior_async\n\n\n@codeflash_behavior_async\nasync def async_sorter(lst: List[Union[int, float]]) -> List[Union[int, float]]:\n """\n Async bubble sort implementation for testing.\n """\n print("codeflash stdout: Async sorting list")\n \n await asyncio.sleep(0.01)\n \n n = len(lst)\n for i in range(n):\n for j in range(0, n - i - 1):\n if lst[j] > lst[j + 1]:\n lst[j], lst[j + 1] = lst[j + 1], lst[j]\n \n result = lst.copy()\n print(f"result: {result}")\n return result\n\n\nclass AsyncBubbleSorter:\n """Class with async sorting method for testing."""\n \n async def sorter(self, lst: List[Union[int, float]]) -> List[Union[int, float]]:\n """\n Async bubble sort implementation within a class.\n """\n print("codeflash stdout: AsyncBubbleSorter.sorter() called")\n \n # Add some async delay\n await asyncio.sleep(0.005)\n \n n = len(lst)\n for i in range(n):\n for j in range(0, n - i - 1):\n if lst[j] > lst[j + 1]:\n lst[j], lst[j + 1] = lst[j + 1], lst[j]\n \n result = lst.copy()\n return result\n''' in instrumented_source
assert (
'''import asyncio\nfrom typing import List, Union\n\nfrom codeflash.code_utils.codeflash_wrap_decorator import \\\n codeflash_behavior_async\n\n\n@codeflash_behavior_async\nasync def async_sorter(lst: List[Union[int, float]]) -> List[Union[int, float]]:\n """\n Async bubble sort implementation for testing.\n """\n print("codeflash stdout: Async sorting list")\n \n await asyncio.sleep(0.01)\n \n n = len(lst)\n for i in range(n):\n for j in range(0, n - i - 1):\n if lst[j] > lst[j + 1]:\n lst[j], lst[j + 1] = lst[j + 1], lst[j]\n \n result = lst.copy()\n print(f"result: {result}")\n return result\n\n\nclass AsyncBubbleSorter:\n """Class with async sorting method for testing."""\n \n async def sorter(self, lst: List[Union[int, float]]) -> List[Union[int, float]]:\n """\n Async bubble sort implementation within a class.\n """\n print("codeflash stdout: AsyncBubbleSorter.sorter() called")\n \n # Add some async delay\n await asyncio.sleep(0.005)\n \n n = len(lst)\n for i in range(n):\n for j in range(0, n - i - 1):\n if lst[j] > lst[j + 1]:\n lst[j], lst[j + 1] = lst[j + 1], lst[j]\n \n result = lst.copy()\n return result\n'''
in instrumented_source
)
# Add codeflash capture
instrument_codeflash_capture(func, {}, tests_root)
@ -122,7 +127,6 @@ async def test_async_sort():
expected_stdout = "codeflash stdout: Async sorting list\nresult: [0, 1, 2, 3, 4, 5]\n"
assert expected_stdout == results_list[0].stdout
assert results_list[1].id.function_getting_tested == "async_sorter"
assert results_list[1].id.test_function_name == "test_async_sort"
assert results_list[1].did_pass
@ -178,12 +182,10 @@ async def test_async_class_sort():
is_async=True,
)
source_success = add_async_decorator_to_function(
fto_path, func, TestingMode.BEHAVIOR
)
source_success = add_async_decorator_to_function(fto_path, func, TestingMode.BEHAVIOR)
assert source_success
# Verify the file was modified
instrumented_source = fto_path.read_text("utf-8")
assert "@codeflash_behavior_async" in instrumented_source
@ -233,17 +235,17 @@ async def test_async_class_sort():
testing_time=0.1,
)
assert test_results is not None
assert test_results.test_results is not None
results_list = test_results.test_results
assert len(results_list) == 2, f"Expected 2 results but got {len(results_list)}: {[r.id.function_getting_tested for r in results_list]}"
assert len(results_list) == 2, (
f"Expected 2 results but got {len(results_list)}: {[r.id.function_getting_tested for r in results_list]}"
)
init_result = results_list[0]
sorter_result = results_list[1]
assert sorter_result.id.function_getting_tested == "sorter"
assert sorter_result.id.test_class_name is None
assert sorter_result.id.test_function_name == "test_async_class_sort"
@ -292,15 +294,16 @@ async def test_async_perf():
func = FunctionToOptimize(function_name="async_sorter", parents=[], file_path=Path(fto_path), is_async=True)
# Instrument the source module with async performance decorators
source_success = add_async_decorator_to_function(
fto_path, func, TestingMode.PERFORMANCE
)
source_success = add_async_decorator_to_function(fto_path, func, TestingMode.PERFORMANCE)
assert source_success
# Verify the file was modified
instrumented_source = fto_path.read_text("utf-8")
assert '''import asyncio\nfrom typing import List, Union\n\nfrom codeflash.code_utils.codeflash_wrap_decorator import \\\n codeflash_performance_async\n\n\n@codeflash_performance_async\nasync def async_sorter(lst: List[Union[int, float]]) -> List[Union[int, float]]:\n """\n Async bubble sort implementation for testing.\n """\n print("codeflash stdout: Async sorting list")\n \n await asyncio.sleep(0.01)\n \n n = len(lst)\n for i in range(n):\n for j in range(0, n - i - 1):\n if lst[j] > lst[j + 1]:\n lst[j], lst[j + 1] = lst[j + 1], lst[j]\n \n result = lst.copy()\n print(f"result: {result}")\n return result\n\n\nclass AsyncBubbleSorter:\n """Class with async sorting method for testing."""\n \n async def sorter(self, lst: List[Union[int, float]]) -> List[Union[int, float]]:\n """\n Async bubble sort implementation within a class.\n """\n print("codeflash stdout: AsyncBubbleSorter.sorter() called")\n \n # Add some async delay\n await asyncio.sleep(0.005)\n \n n = len(lst)\n for i in range(n):\n for j in range(0, n - i - 1):\n if lst[j] > lst[j + 1]:\n lst[j], lst[j + 1] = lst[j + 1], lst[j]\n \n result = lst.copy()\n return result\n''' == instrumented_source
assert (
instrumented_source
== '''import asyncio\nfrom typing import List, Union\n\nfrom codeflash.code_utils.codeflash_wrap_decorator import \\\n codeflash_performance_async\n\n\n@codeflash_performance_async\nasync def async_sorter(lst: List[Union[int, float]]) -> List[Union[int, float]]:\n """\n Async bubble sort implementation for testing.\n """\n print("codeflash stdout: Async sorting list")\n \n await asyncio.sleep(0.01)\n \n n = len(lst)\n for i in range(n):\n for j in range(0, n - i - 1):\n if lst[j] > lst[j + 1]:\n lst[j], lst[j + 1] = lst[j + 1], lst[j]\n \n result = lst.copy()\n print(f"result: {result}")\n return result\n\n\nclass AsyncBubbleSorter:\n """Class with async sorting method for testing."""\n \n async def sorter(self, lst: List[Union[int, float]]) -> List[Union[int, float]]:\n """\n Async bubble sort implementation within a class.\n """\n print("codeflash stdout: AsyncBubbleSorter.sorter() called")\n \n # Add some async delay\n await asyncio.sleep(0.005)\n \n n = len(lst)\n for i in range(n):\n for j in range(0, n - i - 1):\n if lst[j] > lst[j + 1]:\n lst[j], lst[j + 1] = lst[j + 1], lst[j]\n \n result = lst.copy()\n return result\n'''
)
instrument_codeflash_capture(func, {}, tests_root)
@ -358,7 +361,6 @@ async def test_async_perf():
test_path.unlink()
@pytest.mark.skipif(sys.platform == "win32", reason="pending support for asyncio on windows")
def test_async_function_error_handling() -> None:
test_code = """import asyncio
@ -371,8 +373,12 @@ async def test_async_error():
with pytest.raises(ValueError, match="Test error"):
await async_error_function([1, 2, 3])"""
test_path = (Path(__file__).parent.resolve() / "../code_to_optimize/tests/pytest/test_async_error_temp.py").resolve()
test_path_perf = (Path(__file__).parent.resolve() / "../code_to_optimize/tests/pytest/test_async_error_perf_temp.py").resolve()
test_path = (
Path(__file__).parent.resolve() / "../code_to_optimize/tests/pytest/test_async_error_temp.py"
).resolve()
test_path_perf = (
Path(__file__).parent.resolve() / "../code_to_optimize/tests/pytest/test_async_error_perf_temp.py"
).resolve()
fto_path = (Path(__file__).parent.resolve() / "../code_to_optimize/async_bubble_sort.py").resolve()
original_code = fto_path.read_text("utf-8")
@ -384,27 +390,27 @@ async def async_error_function(lst):
await asyncio.sleep(0.001) # Small delay
raise ValueError("Test error")
"""
modified_code = original_code + error_func_code
fto_path.write_text(modified_code, "utf-8")
with test_path.open("w") as f:
f.write(test_code)
tests_root = (Path(__file__).parent.resolve() / "../code_to_optimize/tests/pytest/").resolve()
project_root_path = (Path(__file__).parent / "..").resolve()
func = FunctionToOptimize(function_name="async_error_function", parents=[], file_path=Path(fto_path), is_async=True)
source_success = add_async_decorator_to_function(
fto_path, func, TestingMode.BEHAVIOR
func = FunctionToOptimize(
function_name="async_error_function", parents=[], file_path=Path(fto_path), is_async=True
)
source_success = add_async_decorator_to_function(fto_path, func, TestingMode.BEHAVIOR)
assert source_success
# Verify the file was modified
instrumented_source = fto_path.read_text("utf-8")
expected_instrumented_source = """import asyncio
from typing import List, Union
@ -508,7 +514,7 @@ async def async_error_function(lst):
assert test_results is not None
assert test_results.test_results is not None
assert len(test_results.test_results) >= 1
result = test_results.test_results[0]
assert result.id.function_getting_tested == "async_error_function"
assert result.did_pass
@ -539,8 +545,12 @@ async def test_async_multi():
output2 = await async_sorter(input2)
assert output2 == [7, 9]"""
test_path = (Path(__file__).parent.resolve() / "../code_to_optimize/tests/pytest/test_async_multi_temp.py").resolve()
test_path_perf = (Path(__file__).parent.resolve() / "../code_to_optimize/tests/pytest/test_async_multi_perf_temp.py").resolve()
test_path = (
Path(__file__).parent.resolve() / "../code_to_optimize/tests/pytest/test_async_multi_temp.py"
).resolve()
test_path_perf = (
Path(__file__).parent.resolve() / "../code_to_optimize/tests/pytest/test_async_multi_perf_temp.py"
).resolve()
fto_path = (Path(__file__).parent.resolve() / "../code_to_optimize/async_bubble_sort.py").resolve()
original_code = fto_path.read_text("utf-8")
@ -553,9 +563,7 @@ async def test_async_multi():
func = FunctionToOptimize(function_name="async_sorter", parents=[], file_path=Path(fto_path), is_async=True)
source_success = add_async_decorator_to_function(
fto_path, func, TestingMode.BEHAVIOR
)
source_success = add_async_decorator_to_function(fto_path, func, TestingMode.BEHAVIOR)
assert source_success
instrument_codeflash_capture(func, {}, tests_root)
@ -606,17 +614,17 @@ async def test_async_multi():
assert test_results is not None
assert test_results.test_results is not None
assert len(test_results.test_results) >= 2
results_list = test_results.test_results
function_calls = [r for r in results_list if r.id.function_getting_tested == "async_sorter"]
assert len(function_calls) == 2
first_call = function_calls[0]
second_call = function_calls[1]
assert first_call.stdout == "codeflash stdout: Async sorting list\nresult: [3, 4, 5]\n"
assert second_call.stdout == "codeflash stdout: Async sorting list\nresult: [7, 9]\n"
assert first_call.did_pass
assert second_call.did_pass
assert first_call.runtime is None or first_call.runtime >= 0
@ -655,7 +663,9 @@ async def test_async_edge_cases():
assert result_sorted == [1, 2, 3, 4]"""
test_path = (Path(__file__).parent.resolve() / "../code_to_optimize/tests/pytest/test_async_edge_temp.py").resolve()
test_path_perf = (Path(__file__).parent.resolve() / "../code_to_optimize/tests/pytest/test_async_edge_perf_temp.py").resolve()
test_path_perf = (
Path(__file__).parent.resolve() / "../code_to_optimize/tests/pytest/test_async_edge_perf_temp.py"
).resolve()
fto_path = (Path(__file__).parent.resolve() / "../code_to_optimize/async_bubble_sort.py").resolve()
original_code = fto_path.read_text("utf-8")
@ -668,9 +678,7 @@ async def test_async_edge_cases():
func = FunctionToOptimize(function_name="async_sorter", parents=[], file_path=Path(fto_path), is_async=True)
source_success = add_async_decorator_to_function(
fto_path, func, TestingMode.BEHAVIOR
)
source_success = add_async_decorator_to_function(fto_path, func, TestingMode.BEHAVIOR)
assert source_success
instrument_codeflash_capture(func, {}, tests_root)
@ -721,20 +729,20 @@ async def test_async_edge_cases():
assert test_results is not None
assert test_results.test_results is not None
assert len(test_results.test_results) >= 3 # 3 function calls for edge cases
results_list = test_results.test_results
function_calls = [r for r in results_list if r.id.function_getting_tested == "async_sorter"]
assert len(function_calls) == 3
# Verify all calls passed
for call in function_calls:
assert call.did_pass
assert call.runtime is None or call.runtime >= 0
empty_call = function_calls[0]
single_call = function_calls[1]
sorted_call = function_calls[2]
assert empty_call.stdout == "codeflash stdout: Async sorting list\nresult: []\n"
assert single_call.stdout == "codeflash stdout: Async sorting list\nresult: [42]\n"
assert sorted_call.stdout == "codeflash stdout: Async sorting list\nresult: [1, 2, 3, 4]\n"
@ -761,7 +769,7 @@ def test_sync_function_behavior_in_async_test_environment() -> None:
print(f"result: {result}")
return result
"""
test_code = """from code_to_optimize.sync_bubble_sort import sync_sorter
@ -774,26 +782,32 @@ def test_sync_sort():
output = sync_sorter(input)
assert output == [0.0, 1.0, 2.0, 3.0, 4.0, 5.0]"""
test_path = (Path(__file__).parent.resolve() / "../code_to_optimize/tests/pytest/test_sync_in_async_temp.py").resolve()
test_path_perf = (Path(__file__).parent.resolve() / "../code_to_optimize/tests/pytest/test_sync_in_async_perf_temp.py").resolve()
test_path = (
Path(__file__).parent.resolve() / "../code_to_optimize/tests/pytest/test_sync_in_async_temp.py"
).resolve()
test_path_perf = (
Path(__file__).parent.resolve() / "../code_to_optimize/tests/pytest/test_sync_in_async_perf_temp.py"
).resolve()
sync_fto_path = (Path(__file__).parent.resolve() / "../code_to_optimize/sync_bubble_sort.py").resolve()
try:
with sync_fto_path.open("w") as f:
f.write(sync_sorter_code)
with test_path.open("w") as f:
f.write(test_code)
tests_root = (Path(__file__).parent.resolve() / "../code_to_optimize/tests/pytest/").resolve()
project_root_path = (Path(__file__).parent / "..").resolve()
func = FunctionToOptimize(function_name="sync_sorter", parents=[], file_path=Path(sync_fto_path), is_async=False)
func = FunctionToOptimize(
function_name="sync_sorter", parents=[], file_path=Path(sync_fto_path), is_async=False
)
original_cwd = os.getcwd()
run_cwd = project_root_path
os.chdir(run_cwd)
success, instrumented_test = inject_profiling_into_existing_test(
test_path,
[CodePosition(6, 13), CodePosition(10, 13)], # Lines where sync_sorter is called
@ -802,10 +816,10 @@ def test_sync_sort():
mode=TestingMode.BEHAVIOR,
)
os.chdir(original_cwd)
assert success
assert instrumented_test is not None
with test_path.open("w") as f:
f.write(instrumented_test)
@ -856,7 +870,7 @@ def test_sync_sort():
assert test_results is not None
assert test_results.test_results is not None
results_list = test_results.test_results
assert results_list[0].id.function_getting_tested == "sync_sorter"
assert results_list[0].id.iteration_id == "1_0"
@ -935,7 +949,7 @@ async def async_merge_sort(lst: List[Union[int, float]]) -> List[Union[int, floa
return result
"""
test_code = """import asyncio
import pytest
from code_to_optimize.mixed_sort import sync_quick_sort, async_merge_sort
@ -954,27 +968,29 @@ async def test_mixed_sorting():
assert async_output == [2, 3, 5, 6, 9]"""
test_path = (Path(__file__).parent.resolve() / "../code_to_optimize/tests/pytest/test_mixed_sort_temp.py").resolve()
test_path_perf = (Path(__file__).parent.resolve() / "../code_to_optimize/tests/pytest/test_mixed_sort_perf_temp.py").resolve()
test_path_perf = (
Path(__file__).parent.resolve() / "../code_to_optimize/tests/pytest/test_mixed_sort_perf_temp.py"
).resolve()
mixed_fto_path = (Path(__file__).parent.resolve() / "../code_to_optimize/mixed_sort.py").resolve()
try:
with mixed_fto_path.open("w") as f:
f.write(mixed_module_code)
with test_path.open("w") as f:
f.write(test_code)
tests_root = (Path(__file__).parent.resolve() / "../code_to_optimize/tests/pytest/").resolve()
project_root_path = (Path(__file__).parent / "..").resolve()
async_func = FunctionToOptimize(function_name="async_merge_sort", parents=[], file_path=Path(mixed_fto_path), is_async=True)
source_success = add_async_decorator_to_function(
mixed_fto_path, async_func, TestingMode.BEHAVIOR
async_func = FunctionToOptimize(
function_name="async_merge_sort", parents=[], file_path=Path(mixed_fto_path), is_async=True
)
source_success = add_async_decorator_to_function(mixed_fto_path, async_func, TestingMode.BEHAVIOR)
assert source_success
# Verify the file was modified
instrumented_source = mixed_fto_path.read_text("utf-8")
assert "@codeflash_behavior_async" in instrumented_source
@ -1027,11 +1043,11 @@ async def test_mixed_sorting():
assert test_results is not None
assert test_results.test_results is not None
results_list = test_results.test_results
async_calls = [r for r in results_list if r.id.function_getting_tested == "async_merge_sort"]
assert len(async_calls) >= 1
for call in async_calls:
assert call.did_pass
assert call.runtime is None or call.runtime >= 0
@ -1043,4 +1059,4 @@ async def test_mixed_sorting():
if test_path.exists():
test_path.unlink()
if test_path_perf.exists():
test_path_perf.unlink()
test_path_perf.unlink()

View file

@ -4,22 +4,17 @@ import asyncio
import os
import sqlite3
import sys
import tempfile
from pathlib import Path
import pytest
import dill as pickle
import pytest
from codeflash.code_utils.codeflash_wrap_decorator import (
codeflash_behavior_async,
codeflash_performance_async,
)
from codeflash.code_utils.codeflash_wrap_decorator import codeflash_behavior_async, codeflash_performance_async
from codeflash.verification.codeflash_capture import VerificationType
@pytest.mark.skipif(sys.platform == "win32", reason="pending support for asyncio on windows")
class TestAsyncWrapperSQLiteValidation:
@pytest.fixture
def test_env_setup(self, request):
original_env = {}
@ -31,13 +26,13 @@ class TestAsyncWrapperSQLiteValidation:
"CODEFLASH_TEST_FUNCTION": request.node.name,
"CODEFLASH_CURRENT_LINE_ID": "test_unit",
}
for key, value in test_env.items():
original_env[key] = os.environ.get(key)
os.environ[key] = value
yield test_env
for key, original_value in original_env.items():
if original_value is None:
os.environ.pop(key, None)
@ -48,45 +43,54 @@ class TestAsyncWrapperSQLiteValidation:
def temp_db_path(self, test_env_setup):
iteration = test_env_setup["CODEFLASH_TEST_ITERATION"]
from codeflash.code_utils.codeflash_wrap_decorator import get_run_tmp_file
db_path = get_run_tmp_file(Path(f"test_return_values_{iteration}.sqlite"))
yield db_path
if db_path.exists():
db_path.unlink()
@pytest.mark.asyncio
async def test_behavior_async_basic_function(self, test_env_setup, temp_db_path):
@codeflash_behavior_async
async def simple_async_add(a: int, b: int) -> int:
await asyncio.sleep(0.001)
return a + b
os.environ['CODEFLASH_CURRENT_LINE_ID'] = 'simple_async_add_59'
os.environ["CODEFLASH_CURRENT_LINE_ID"] = "simple_async_add_59"
result = await simple_async_add(5, 3)
assert result == 8
assert temp_db_path.exists()
con = sqlite3.connect(temp_db_path)
cur = con.cursor()
cur.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='test_results'")
assert cur.fetchone() is not None
cur.execute("SELECT * FROM test_results")
rows = cur.fetchall()
assert len(rows) == 1
row = rows[0]
(test_module_path, test_class_name, test_function_name, function_getting_tested,
loop_index, iteration_id, runtime, return_value_blob, verification_type) = row
(
test_module_path,
test_class_name,
test_function_name,
function_getting_tested,
loop_index,
iteration_id,
runtime,
return_value_blob,
verification_type,
) = row
assert test_module_path == __name__
assert test_class_name == "TestAsyncWrapperSQLiteValidation"
assert test_class_name == "TestAsyncWrapperSQLiteValidation"
assert test_function_name == "test_behavior_async_basic_function"
assert function_getting_tested == "simple_async_add"
assert loop_index == 1
@ -94,19 +98,18 @@ class TestAsyncWrapperSQLiteValidation:
assert iteration_id.startswith("simple_async_add_") and iteration_id.endswith("_0")
assert runtime > 0
assert verification_type == VerificationType.FUNCTION_CALL.value
unpickled_data = pickle.loads(return_value_blob)
args, kwargs, return_val = unpickled_data
assert args == (5, 3)
assert kwargs == {}
assert return_val == 8
con.close()
@pytest.mark.asyncio
async def test_behavior_async_exception_handling(self, test_env_setup, temp_db_path):
@codeflash_behavior_async
async def async_divide(a: int, b: int) -> float:
await asyncio.sleep(0.001)
@ -116,35 +119,35 @@ class TestAsyncWrapperSQLiteValidation:
result = await async_divide(10, 2)
assert result == 5.0
with pytest.raises(ValueError, match="Cannot divide by zero"):
await async_divide(10, 0)
con = sqlite3.connect(temp_db_path)
cur = con.cursor()
cur.execute("SELECT * FROM test_results ORDER BY iteration_id")
rows = cur.fetchall()
assert len(rows) == 2
success_row = rows[0]
success_data = pickle.loads(success_row[7]) # return_value_blob
args, kwargs, return_val = success_data
assert args == (10, 2)
assert return_val == 5.0
# Check exception record
exception_row = rows[1]
exception_data = pickle.loads(exception_row[7]) # return_value_blob
assert isinstance(exception_data, ValueError)
assert str(exception_data) == "Cannot divide by zero"
con.close()
@pytest.mark.asyncio
async def test_performance_async_no_database_storage(self, test_env_setup, temp_db_path, capsys):
"""Test performance async decorator doesn't store to database."""
@codeflash_performance_async
async def async_multiply(a: int, b: int) -> int:
"""Async function for performance testing."""
@ -152,27 +155,26 @@ class TestAsyncWrapperSQLiteValidation:
return a * b
result = await async_multiply(4, 7)
assert result == 28
assert not temp_db_path.exists()
captured = capsys.readouterr()
output_lines = captured.out.strip().split('\n')
output_lines = captured.out.strip().split("\n")
assert len([line for line in output_lines if "!$######" in line]) == 1
assert len([line for line in output_lines if "!######" in line and "######!" in line]) == 1
closing_tag = [line for line in output_lines if "!######" in line and "######!" in line][0]
assert "async_multiply" in closing_tag
timing_part = closing_tag.split(":")[-1].replace("######!", "")
timing_value = int(timing_part)
assert timing_value > 0 # Should have positive timing
@pytest.mark.asyncio
async def test_multiple_calls_indexing(self, test_env_setup, temp_db_path):
@codeflash_behavior_async
async def async_increment(value: int) -> int:
await asyncio.sleep(0.001)
@ -183,105 +185,86 @@ class TestAsyncWrapperSQLiteValidation:
for i in range(3):
result = await async_increment(i)
results.append(result)
assert results == [1, 2, 3]
con = sqlite3.connect(temp_db_path)
cur = con.cursor()
cur.execute("SELECT iteration_id, return_value FROM test_results ORDER BY iteration_id")
rows = cur.fetchall()
assert len(rows) == 3
actual_ids = [row[0] for row in rows]
assert len(actual_ids) == 3
base_pattern = actual_ids[0].rsplit('_', 1)[0] # e.g., "async_increment_199"
base_pattern = actual_ids[0].rsplit("_", 1)[0] # e.g., "async_increment_199"
expected_pattern = [f"{base_pattern}_{i}" for i in range(3)]
assert actual_ids == expected_pattern
for i, (_, return_value_blob) in enumerate(rows):
args, kwargs, return_val = pickle.loads(return_value_blob)
assert args == (i,)
assert return_val == i + 1
con.close()
@pytest.mark.asyncio
async def test_complex_async_function_with_kwargs(self, test_env_setup, temp_db_path):
@codeflash_behavior_async
async def complex_async_func(
pos_arg: str,
*args: int,
keyword_arg: str = "default",
**kwargs: str
) -> dict:
async def complex_async_func(pos_arg: str, *args: int, keyword_arg: str = "default", **kwargs: str) -> dict:
await asyncio.sleep(0.001)
return {
"pos_arg": pos_arg,
"args": args,
"keyword_arg": keyword_arg,
"kwargs": kwargs,
}
return {"pos_arg": pos_arg, "args": args, "keyword_arg": keyword_arg, "kwargs": kwargs}
result = await complex_async_func("hello", 1, 2, 3, keyword_arg="custom", extra1="value1", extra2="value2")
result = await complex_async_func(
"hello",
1, 2, 3,
keyword_arg="custom",
extra1="value1",
extra2="value2"
)
expected_result = {
"pos_arg": "hello",
"args": (1, 2, 3),
"keyword_arg": "custom",
"kwargs": {"extra1": "value1", "extra2": "value2"}
"kwargs": {"extra1": "value1", "extra2": "value2"},
}
assert result == expected_result
con = sqlite3.connect(temp_db_path)
cur = con.cursor()
cur.execute("SELECT return_value FROM test_results")
row = cur.fetchone()
stored_args, stored_kwargs, stored_result = pickle.loads(row[0])
assert stored_args == ("hello", 1, 2, 3)
assert stored_kwargs == {"keyword_arg": "custom", "extra1": "value1", "extra2": "value2"}
assert stored_result == expected_result
con.close()
@pytest.mark.asyncio
async def test_database_schema_validation(self, test_env_setup, temp_db_path):
@codeflash_behavior_async
async def schema_test_func() -> str:
return "schema_test"
await schema_test_func()
con = sqlite3.connect(temp_db_path)
cur = con.cursor()
cur.execute("PRAGMA table_info(test_results)")
columns = cur.fetchall()
expected_columns = [
(0, 'test_module_path', 'TEXT', 0, None, 0),
(1, 'test_class_name', 'TEXT', 0, None, 0),
(2, 'test_function_name', 'TEXT', 0, None, 0),
(3, 'function_getting_tested', 'TEXT', 0, None, 0),
(4, 'loop_index', 'INTEGER', 0, None, 0),
(5, 'iteration_id', 'TEXT', 0, None, 0),
(6, 'runtime', 'INTEGER', 0, None, 0),
(7, 'return_value', 'BLOB', 0, None, 0),
(8, 'verification_type', 'TEXT', 0, None, 0)
(0, "test_module_path", "TEXT", 0, None, 0),
(1, "test_class_name", "TEXT", 0, None, 0),
(2, "test_function_name", "TEXT", 0, None, 0),
(3, "function_getting_tested", "TEXT", 0, None, 0),
(4, "loop_index", "INTEGER", 0, None, 0),
(5, "iteration_id", "TEXT", 0, None, 0),
(6, "runtime", "INTEGER", 0, None, 0),
(7, "return_value", "BLOB", 0, None, 0),
(8, "verification_type", "TEXT", 0, None, 0),
]
assert columns == expected_columns
con.close()

View file

@ -3969,7 +3969,7 @@ def test_dependency_classes_kept_in_read_writable_context(tmp_path: Path) -> Non
as types or in match statements, those classes are included in the optimization
context, even though they don't contain any target functions.
"""
code = '''
code = """
import dataclasses
import enum
import typing as t
@ -4013,20 +4013,13 @@ def reify_channel_message(data: dict) -> MessageIn:
return MessageInBeginExfiltration()
case _:
raise ValueError(f"Unknown message kind: '{kind}'")
'''
"""
code_path = tmp_path / "message.py"
code_path.write_text(code, encoding="utf-8")
func_to_optimize = FunctionToOptimize(
function_name="reify_channel_message",
file_path=code_path,
parents=[],
)
func_to_optimize = FunctionToOptimize(function_name="reify_channel_message", file_path=code_path, parents=[])
code_ctx = get_code_optimization_context(
function_to_optimize=func_to_optimize,
project_root_path=tmp_path,
)
code_ctx = get_code_optimization_context(function_to_optimize=func_to_optimize, project_root_path=tmp_path)
expected_read_writable = """
```python:message.py
@ -4098,10 +4091,7 @@ class MyCustomDict(UserDict):
parents=[FunctionParent(name="MyCustomDict", type="ClassDef")],
)
code_ctx = get_code_optimization_context(
function_to_optimize=func_to_optimize,
project_root_path=tmp_path,
)
code_ctx = get_code_optimization_context(function_to_optimize=func_to_optimize, project_root_path=tmp_path)
# The testgen context should include the UserDict __init__ method
testgen_context = code_ctx.testgen_context.markdown
@ -4146,9 +4136,7 @@ def second_helper():
file_path.write_text(code, encoding="utf-8")
func_to_optimize = FunctionToOptimize(
function_name="target_method",
file_path=file_path,
parents=[FunctionParent(name="MyClass", type="ClassDef")],
function_name="target_method", file_path=file_path, parents=[FunctionParent(name="MyClass", type="ClassDef")]
)
# Use a small optim_token_limit that allows read-writable but not read-only
@ -4203,11 +4191,7 @@ def target_function(obj: TypeClass) -> int:
main_path = package_dir / "main.py"
main_path.write_text(main_code, encoding="utf-8")
func_to_optimize = FunctionToOptimize(
function_name="target_function",
file_path=main_path,
parents=[],
)
func_to_optimize = FunctionToOptimize(function_name="target_function", file_path=main_path, parents=[])
# Use a testgen_token_limit that:
# - Is exceeded by full context with imported class (~1500 tokens)
@ -4251,11 +4235,7 @@ def target_function():
file_path = tmp_path / "test_code.py"
file_path.write_text(code, encoding="utf-8")
func_to_optimize = FunctionToOptimize(
function_name="target_function",
file_path=file_path,
parents=[],
)
func_to_optimize = FunctionToOptimize(function_name="target_function", file_path=file_path, parents=[])
# Use a very small testgen_token_limit that cannot fit even the base function
with pytest.raises(ValueError, match="Testgen code context has exceeded token limit"):
@ -4383,15 +4363,10 @@ class MyClass:
file_path.write_text(code, encoding="utf-8")
func_to_optimize = FunctionToOptimize(
function_name="target_method",
file_path=file_path,
parents=[FunctionParent(name="MyClass", type="ClassDef")],
function_name="target_method", file_path=file_path, parents=[FunctionParent(name="MyClass", type="ClassDef")]
)
code_ctx = get_code_optimization_context(
function_to_optimize=func_to_optimize,
project_root_path=tmp_path,
)
code_ctx = get_code_optimization_context(function_to_optimize=func_to_optimize, project_root_path=tmp_path)
# CONFIG_VALUE should be in read-writable context since it's used by __init__
read_writable = code_ctx.read_writable_code.markdown
@ -4637,15 +4612,10 @@ class MyClass:
file_path.write_text(code, encoding="utf-8")
func_to_optimize = FunctionToOptimize(
function_name="target_method",
file_path=file_path,
parents=[FunctionParent(name="MyClass", type="ClassDef")],
function_name="target_method", file_path=file_path, parents=[FunctionParent(name="MyClass", type="ClassDef")]
)
code_ctx = get_code_optimization_context(
function_to_optimize=func_to_optimize,
project_root_path=tmp_path,
)
code_ctx = get_code_optimization_context(function_to_optimize=func_to_optimize, project_root_path=tmp_path)
# counter should be in context since __init__ uses it
read_writable = code_ctx.read_writable_code.markdown

View file

@ -5,97 +5,97 @@ from codeflash.code_utils.code_extractor import add_needed_imports_from_module
def test_add_needed_imports_with_none_aliases():
source_code = '''
source_code = """
import json
from typing import Dict as MyDict, Optional
from collections import defaultdict
'''
target_code = '''
"""
target_code = """
def target_function():
pass
'''
expected_output = '''
"""
expected_output = """
def target_function():
pass
'''
"""
with tempfile.TemporaryDirectory() as temp_dir:
temp_path = Path(temp_dir)
src_path = temp_path / "source.py"
dst_path = temp_path / "target.py"
src_path.write_text(source_code)
dst_path.write_text(target_code)
result = add_needed_imports_from_module(
src_module_code=source_code,
dst_module_code=target_code,
src_path=src_path,
dst_path=dst_path,
project_root=temp_path
project_root=temp_path,
)
assert result.strip() == expected_output.strip()
def test_add_needed_imports_complex_aliases():
source_code = '''
source_code = """
import os
import sys as system
from typing import Dict, List as MyList, Optional as Opt
from collections import defaultdict as dd, Counter
from pathlib import Path
'''
target_code = '''
"""
target_code = """
def my_function():
return "test"
'''
expected_output = '''
"""
expected_output = """
def my_function():
return "test"
'''
"""
with tempfile.TemporaryDirectory() as temp_dir:
temp_path = Path(temp_dir)
src_path = temp_path / "source.py"
dst_path = temp_path / "target.py"
src_path.write_text(source_code)
dst_path.write_text(target_code)
result = add_needed_imports_from_module(
src_module_code=source_code,
dst_module_code=target_code,
src_path=src_path,
dst_path=dst_path,
project_root=temp_path
project_root=temp_path,
)
assert result.strip() == expected_output.strip()
def test_add_needed_imports_with_usage():
source_code = '''
source_code = """
import json
from typing import Dict as MyDict, Optional
from collections import defaultdict
'''
target_code = '''
"""
target_code = """
def target_function():
data = json.loads('{"key": "value"}')
my_dict: MyDict[str, str] = {}
opt_value: Optional[str] = None
dd = defaultdict(list)
return data, my_dict, opt_value, dd
'''
expected_output = '''import json
"""
expected_output = """import json
from typing import Dict as MyDict, Optional
from collections import defaultdict
@ -105,30 +105,30 @@ def target_function():
opt_value: Optional[str] = None
dd = defaultdict(list)
return data, my_dict, opt_value, dd
'''
"""
with tempfile.TemporaryDirectory() as temp_dir:
temp_path = Path(temp_dir)
src_path = temp_path / "source.py"
dst_path = temp_path / "target.py"
src_path.write_text(source_code)
dst_path.write_text(target_code)
result = add_needed_imports_from_module(
src_module_code=source_code,
dst_module_code=target_code,
src_path=src_path,
dst_path=dst_path,
project_root=temp_path
project_root=temp_path,
)
# Assert exact expected output
assert result.strip() == expected_output.strip()
def test_litellm_router_style_imports():
source_code = '''
source_code = """
import asyncio
import copy
import json
@ -136,92 +136,92 @@ from collections import defaultdict
from typing import Dict, List, Optional, Union
from litellm.types.utils import ModelInfo
from litellm.types.utils import ModelInfo as ModelMapInfo
'''
"""
target_code = '''
def target_function():
"""Target function for testing."""
pass
'''
expected_output = '''
def target_function():
"""Target function for testing."""
pass
'''
with tempfile.TemporaryDirectory() as temp_dir:
temp_path = Path(temp_dir)
src_path = temp_path / "complex_source.py"
dst_path = temp_path / "target.py"
src_path.write_text(source_code)
dst_path.write_text(target_code)
result = add_needed_imports_from_module(
src_module_code=source_code,
dst_module_code=target_code,
src_path=src_path,
dst_path=dst_path,
project_root=temp_path
project_root=temp_path,
)
assert result.strip() == expected_output.strip()
def test_edge_case_none_values_in_alias_pairs():
source_code = '''
source_code = """
from typing import Dict as MyDict, List, Optional as Opt
from collections import defaultdict, Counter as cnt
from pathlib import Path
'''
target_code = '''
"""
target_code = """
def my_test_function():
return "test"
'''
expected_output = '''
"""
expected_output = """
def my_test_function():
return "test"
'''
"""
with tempfile.TemporaryDirectory() as temp_dir:
temp_path = Path(temp_dir)
src_path = temp_path / "edge_case_source.py"
dst_path = temp_path / "target.py"
src_path.write_text(source_code)
dst_path.write_text(target_code)
result = add_needed_imports_from_module(
src_module_code=source_code,
dst_module_code=target_code,
src_path=src_path,
dst_path=dst_path,
project_root=temp_path
project_root=temp_path,
)
assert result.strip() == expected_output.strip()
def test_partial_import_usage():
source_code = '''
source_code = """
import os
import sys
from typing import Dict, List, Optional
from collections import defaultdict, Counter
'''
target_code = '''
"""
target_code = """
def use_some_imports():
path = os.path.join("a", "b")
my_dict: Dict[str, int] = {}
counter = Counter([1, 2, 3])
return path, my_dict, counter
'''
expected_output = '''import os
"""
expected_output = """import os
from collections import Counter
from typing import Dict
@ -230,42 +230,42 @@ def use_some_imports():
my_dict: Dict[str, int] = {}
counter = Counter([1, 2, 3])
return path, my_dict, counter
'''
"""
with tempfile.TemporaryDirectory() as temp_dir:
temp_path = Path(temp_dir)
src_path = temp_path / "source.py"
dst_path = temp_path / "target.py"
src_path.write_text(source_code)
dst_path.write_text(target_code)
result = add_needed_imports_from_module(
src_module_code=source_code,
dst_module_code=target_code,
src_path=src_path,
dst_path=dst_path,
project_root=temp_path
project_root=temp_path,
)
assert result.strip() == expected_output.strip()
def test_alias_handling():
source_code = '''
source_code = """
from typing import Dict as MyDict, List as MyList, Optional
from collections import defaultdict as dd, Counter
'''
target_code = '''
"""
target_code = """
def test_aliases():
d: MyDict[str, int] = {}
lst: MyList[str] = []
dd_instance = dd(list)
return d, lst, dd_instance
'''
expected_output = '''from collections import defaultdict as dd
"""
expected_output = """from collections import defaultdict as dd
from typing import Dict as MyDict, List as MyList
def test_aliases():
@ -273,59 +273,59 @@ def test_aliases():
lst: MyList[str] = []
dd_instance = dd(list)
return d, lst, dd_instance
'''
"""
with tempfile.TemporaryDirectory() as temp_dir:
temp_path = Path(temp_dir)
src_path = temp_path / "source.py"
dst_path = temp_path / "target.py"
src_path.write_text(source_code)
dst_path.write_text(target_code)
result = add_needed_imports_from_module(
src_module_code=source_code,
dst_module_code=target_code,
src_path=src_path,
dst_path=dst_path,
project_root=temp_path
project_root=temp_path,
)
assert result.strip() == expected_output.strip()
def test_add_needed_imports_with_nonealiases():
source_code = '''
source_code = """
import json
from typing import Dict as MyDict, Optional
from collections import defaultdict
'''
target_code = '''
"""
target_code = """
def target_function():
pass
'''
"""
with tempfile.TemporaryDirectory() as temp_dir:
temp_path = Path(temp_dir)
src_path = temp_path / "source.py"
dst_path = temp_path / "target.py"
src_path.write_text(source_code)
dst_path.write_text(target_code)
# This should not raise a TypeError
result = add_needed_imports_from_module(
src_module_code=source_code,
dst_module_code=target_code,
src_path=src_path,
dst_path=dst_path,
project_root=temp_path
project_root=temp_path,
)
expected_output = '''
expected_output = """
def target_function():
pass
'''
assert result.strip() == expected_output.strip()
"""
assert result.strip() == expected_output.strip()

File diff suppressed because it is too large Load diff

View file

@ -279,21 +279,23 @@ def test_path_belongs_to_site_packages_with_relative_path(monkeypatch: pytest.Mo
assert path_belongs_to_site_packages(file_path) is False
def test_path_belongs_to_site_packages_with_symlinked_site_packages(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None:
def test_path_belongs_to_site_packages_with_symlinked_site_packages(
monkeypatch: pytest.MonkeyPatch, tmp_path: Path
) -> None:
real_site_packages = tmp_path / "real_site_packages"
real_site_packages.mkdir()
symlinked_site_packages = tmp_path / "symlinked_site_packages"
symlinked_site_packages.symlink_to(real_site_packages)
package_file = real_site_packages / "some_package" / "__init__.py"
package_file.parent.mkdir()
package_file.write_text("# package file")
monkeypatch.setattr(site, "getsitepackages", lambda: [str(symlinked_site_packages)])
assert path_belongs_to_site_packages(package_file) is True
symlinked_package_file = symlinked_site_packages / "some_package" / "__init__.py"
assert path_belongs_to_site_packages(symlinked_package_file) is True
@ -301,40 +303,42 @@ def test_path_belongs_to_site_packages_with_symlinked_site_packages(monkeypatch:
def test_path_belongs_to_site_packages_with_complex_symlinks(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None:
real_site_packages = tmp_path / "real" / "lib" / "python3.9" / "site-packages"
real_site_packages.mkdir(parents=True)
link1 = tmp_path / "link1"
link1.symlink_to(real_site_packages.parent.parent.parent)
link2 = tmp_path / "link2"
link2 = tmp_path / "link2"
link2.symlink_to(link1)
package_file = real_site_packages / "test_package" / "module.py"
package_file.parent.mkdir()
package_file.write_text("# test module")
site_packages_via_links = link2 / "lib" / "python3.9" / "site-packages"
monkeypatch.setattr(site, "getsitepackages", lambda: [str(site_packages_via_links)])
assert path_belongs_to_site_packages(package_file) is True
file_via_links = site_packages_via_links / "test_package" / "module.py"
assert path_belongs_to_site_packages(file_via_links) is True
def test_path_belongs_to_site_packages_resolved_paths_normalization(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None:
def test_path_belongs_to_site_packages_resolved_paths_normalization(
monkeypatch: pytest.MonkeyPatch, tmp_path: Path
) -> None:
site_packages_dir = tmp_path / "lib" / "python3.9" / "site-packages"
site_packages_dir.mkdir(parents=True)
package_dir = site_packages_dir / "mypackage"
package_dir.mkdir()
package_file = package_dir / "module.py"
package_file.write_text("# module")
complex_site_packages_path = tmp_path / "lib" / "python3.9" / "other" / ".." / "site-packages" / "."
monkeypatch.setattr(site, "getsitepackages", lambda: [str(complex_site_packages_path)])
assert path_belongs_to_site_packages(package_file) is True
complex_file_path = tmp_path / "lib" / "python3.9" / "site-packages" / "other" / ".." / "mypackage" / "module.py"
assert path_belongs_to_site_packages(complex_file_path) is True
@ -374,8 +378,9 @@ def my_function():
def mock_code_context():
"""Mock CodeOptimizationContext for testing extract_dependent_function."""
from unittest.mock import MagicMock
from codeflash.models.models import CodeOptimizationContext
context = MagicMock(spec=CodeOptimizationContext)
context.preexisting_objects = []
return context
@ -393,7 +398,7 @@ def helper_function():
```
""")
assert extract_dependent_function("main_function", mock_code_context) == "helper_function"
# Test async function extraction
mock_code_context.testgen_context = CodeStringsMarkdown.parse_markdown_code("""```python:file.py
def main_function():
@ -416,7 +421,7 @@ def main_function():
```
""")
assert extract_dependent_function("main_function", mock_code_context) is False
# Multiple dependent functions
mock_code_context.testgen_context = CodeStringsMarkdown.parse_markdown_code("""```python:file.py
def main_function():
@ -443,7 +448,7 @@ def sync_helper():
```
""")
assert extract_dependent_function("async_main", mock_code_context) == "sync_helper"
# Only async functions
mock_code_context.testgen_context = CodeStringsMarkdown.parse_markdown_code("""```python:file.py
async def async_main():
@ -500,7 +505,7 @@ def test_partial_module_name2(base_dir: Path) -> None:
def test_pytest_unittest_path_resolution_with_prefix(tmp_path: Path) -> None:
"""Test path resolution when pytest includes parent directory in classname.
This handles the case where pytest's base_dir is /path/to/tests but the
classname includes the parent directory like "project.tests.unittest.test_file.TestClass".
"""
@ -509,34 +514,29 @@ def test_pytest_unittest_path_resolution_with_prefix(tmp_path: Path) -> None:
tests_root = project_root / "tests"
unittest_dir = tests_root / "unittest"
unittest_dir.mkdir(parents=True, exist_ok=True)
# Create test files
test_file = unittest_dir / "test_bubble_sort.py"
test_file.touch()
generated_test = unittest_dir / "test_sorter__unit_test_0.py"
generated_test.touch()
# Case 1: pytest reports classname with full path including "code_to_optimize.tests"
# but base_dir is .../tests (not the project root)
result = resolve_test_file_from_class_path(
"code_to_optimize.tests.unittest.test_bubble_sort.TestPigLatin",
tests_root
"code_to_optimize.tests.unittest.test_bubble_sort.TestPigLatin", tests_root
)
assert result == test_file
# Case 2: Generated test file with class name
result = resolve_test_file_from_class_path(
"code_to_optimize.tests.unittest.test_sorter__unit_test_0.TestSorter",
tests_root
"code_to_optimize.tests.unittest.test_sorter__unit_test_0.TestSorter", tests_root
)
assert result == generated_test
# Case 3: Without the class name (just the module path)
result = resolve_test_file_from_class_path(
"code_to_optimize.tests.unittest.test_bubble_sort",
tests_root
)
result = resolve_test_file_from_class_path("code_to_optimize.tests.unittest.test_bubble_sort", tests_root)
assert result == test_file
@ -546,23 +546,17 @@ def test_pytest_unittest_multiple_prefix_levels(tmp_path: Path) -> None:
base = tmp_path / "org" / "project" / "src" / "tests"
unit_dir = base / "unit"
unit_dir.mkdir(parents=True, exist_ok=True)
test_file = unit_dir / "test_example.py"
test_file.touch()
# pytest might report: org.project.src.tests.unit.test_example.TestClass
# with base_dir being .../src/tests or .../tests
result = resolve_test_file_from_class_path(
"org.project.src.tests.unit.test_example.TestClass",
base
)
result = resolve_test_file_from_class_path("org.project.src.tests.unit.test_example.TestClass", base)
assert result == test_file
# Also test with base_dir at different level
result = resolve_test_file_from_class_path(
"project.src.tests.unit.test_example.TestClass",
base
)
result = resolve_test_file_from_class_path("project.src.tests.unit.test_example.TestClass", base)
assert result == test_file
@ -570,15 +564,14 @@ def test_pytest_unittest_instrumented_files(tmp_path: Path) -> None:
"""Test path resolution for instrumented test files."""
tests_root = tmp_path / "tests" / "unittest"
tests_root.mkdir(parents=True, exist_ok=True)
# Create instrumented test file
instrumented_file = tests_root / "test_bubble_sort__perfinstrumented.py"
instrumented_file.touch()
# pytest classname includes parent directories
result = resolve_test_file_from_class_path(
"code_to_optimize.tests.unittest.test_bubble_sort__perfinstrumented.TestPigLatin",
tmp_path / "tests"
"code_to_optimize.tests.unittest.test_bubble_sort__perfinstrumented.TestPigLatin", tmp_path / "tests"
)
assert result == instrumented_file
@ -587,15 +580,12 @@ def test_pytest_unittest_nested_classes(tmp_path: Path) -> None:
"""Test path resolution with nested class names."""
tests_root = tmp_path / "tests"
tests_root.mkdir(parents=True, exist_ok=True)
test_file = tests_root / "test_nested.py"
test_file.touch()
# Some unittest frameworks use nested classes
result = resolve_test_file_from_class_path(
"project.tests.test_nested.OuterClass.InnerClass",
tests_root
)
result = resolve_test_file_from_class_path("project.tests.test_nested.OuterClass.InnerClass", tests_root)
assert result == test_file
@ -603,12 +593,9 @@ def test_pytest_unittest_no_match_returns_none(tmp_path: Path) -> None:
"""Test that non-existent files return None even with prefix stripping."""
tests_root = tmp_path / "tests"
tests_root.mkdir(parents=True, exist_ok=True)
# File doesn't exist
result = resolve_test_file_from_class_path(
"code_to_optimize.tests.unittest.nonexistent_test.TestClass",
tests_root
)
result = resolve_test_file_from_class_path("code_to_optimize.tests.unittest.nonexistent_test.TestClass", tests_root)
assert result is None
@ -617,10 +604,10 @@ def test_pytest_unittest_single_component(tmp_path: Path) -> None:
base_dir = tmp_path
test_file = base_dir / "test_simple.py"
test_file.touch()
result = file_name_from_test_module_name("test_simple", base_dir)
assert result == test_file
# With class name
result = file_name_from_test_module_name("test_simple.TestClass", base_dir)
assert result == test_file
@ -644,7 +631,7 @@ def test_generate_candidates() -> None:
"Desktop/work/codeflash/cli/codeflash/code_utils/coverage_utils.py",
"krrt7/Desktop/work/codeflash/cli/codeflash/code_utils/coverage_utils.py",
"Users/krrt7/Desktop/work/codeflash/cli/codeflash/code_utils/coverage_utils.py",
"/Users/krrt7/Desktop/work/codeflash/cli/codeflash/code_utils/coverage_utils.py"
"/Users/krrt7/Desktop/work/codeflash/cli/codeflash/code_utils/coverage_utils.py",
}
assert generate_candidates(source_code_path) == expected_candidates

View file

@ -54,7 +54,9 @@ class MyClass:
with sample_code_path.open("w") as f:
f.write(sample_code)
result = execute_test_subprocess(
cwd=test_dir, cmd_list=[f"{SAFE_SYS_EXECUTABLE}", "-m", "pytest", test_file_name, "-s"], env=os.environ.copy()
cwd=test_dir,
cmd_list=[f"{SAFE_SYS_EXECUTABLE}", "-m", "pytest", test_file_name, "-s"],
env=os.environ.copy(),
)
assert not result.stderr
assert result.returncode == 0
@ -129,7 +131,9 @@ class MyClass:
with sample_code_path.open("w") as f:
f.write(sample_code)
result = execute_test_subprocess(
cwd=test_dir, cmd_list=[f"{SAFE_SYS_EXECUTABLE}", "-m", "pytest", test_file_name, "-s"], env=os.environ.copy()
cwd=test_dir,
cmd_list=[f"{SAFE_SYS_EXECUTABLE}", "-m", "pytest", test_file_name, "-s"],
env=os.environ.copy(),
)
assert not result.stderr
assert result.returncode == 0
@ -194,7 +198,9 @@ class MyClass:
with sample_code_path.open("w") as f:
f.write(sample_code)
result = execute_test_subprocess(
cwd=test_dir, cmd_list=[f"{SAFE_SYS_EXECUTABLE}", "-m", "pytest", test_file_name, "-s"], env=os.environ.copy()
cwd=test_dir,
cmd_list=[f"{SAFE_SYS_EXECUTABLE}", "-m", "pytest", test_file_name, "-s"],
env=os.environ.copy(),
)
assert not result.stderr
assert result.returncode == 0
@ -279,7 +285,9 @@ class MyClass:
# Run pytest as a subprocess
result = execute_test_subprocess(
cwd=test_dir, cmd_list=[f"{SAFE_SYS_EXECUTABLE}", "-m", "pytest", test_file_name, "-s"], env=os.environ.copy()
cwd=test_dir,
cmd_list=[f"{SAFE_SYS_EXECUTABLE}", "-m", "pytest", test_file_name, "-s"],
env=os.environ.copy(),
)
# Check for errors
@ -356,7 +364,9 @@ class MyClass:
with sample_code_path.open("w") as f:
f.write(sample_code)
result = execute_test_subprocess(
cwd=test_dir, cmd_list=[f"{SAFE_SYS_EXECUTABLE}", "-m", "pytest", test_file_name, "-s"], env=os.environ.copy()
cwd=test_dir,
cmd_list=[f"{SAFE_SYS_EXECUTABLE}", "-m", "pytest", test_file_name, "-s"],
env=os.environ.copy(),
)
assert not result.stderr
assert result.returncode == 0
@ -1184,6 +1194,7 @@ class MyClass:
helper_path_1.unlink(missing_ok=True)
helper_path_2.unlink(missing_ok=True)
def test_get_stack_info_env_var_fallback() -> None:
"""Test that get_test_info_from_stack falls back to environment variables when stack walking fails to find test_name.
@ -1421,8 +1432,7 @@ def calculate_portfolio_metrics(
f.write(test_code)
fto = FunctionToOptimize("calculate_portfolio_metrics", fto_file_path, parents=[])
file_path_to_helper_class = {
}
file_path_to_helper_class = {}
instrument_codeflash_capture(fto, file_path_to_helper_class, tests_root)
test_env = os.environ.copy()
test_env["CODEFLASH_TEST_ITERATION"] = "0"
@ -1453,8 +1463,7 @@ def calculate_portfolio_metrics(
candidate_helper_code = {}
for file_path in file_path_to_helper_class:
candidate_helper_code[file_path] = Path(file_path).read_text("utf-8")
file_path_to_helper_classes = {
}
file_path_to_helper_classes = {}
instrument_codeflash_capture(fto, file_path_to_helper_classes, tests_root)
test_results, coverage_data = func_optimizer.run_and_parse_tests(
@ -1692,4 +1701,4 @@ class SlotsClass:
finally:
test_path.unlink(missing_ok=True)
sample_code_path.unlink(missing_ok=True)
sample_code_path.unlink(missing_ok=True)

View file

@ -3,6 +3,7 @@ import tempfile
from pathlib import Path
import pytest
from codeflash.code_utils.checkpoint import CodeflashRunCheckpoint, get_all_historical_functions

View file

@ -1,6 +1,6 @@
from codeflash.benchmarking.codeflash_trace import codeflash_trace
from pathlib import Path
from codeflash.code_utils.code_utils import get_run_tmp_file
@codeflash_trace
def example_function(arr):

File diff suppressed because it is too large Load diff

View file

@ -50,7 +50,9 @@ def test_speedup_critic() -> None:
total_candidate_timing=12,
)
assert speedup_critic(candidate_result, original_code_runtime, best_runtime_until_now, disable_gh_action_noise=True) # 20% improvement
assert speedup_critic(
candidate_result, original_code_runtime, best_runtime_until_now, disable_gh_action_noise=True
) # 20% improvement
candidate_result = OptimizedCandidateResult(
max_loop_count=5,
@ -61,7 +63,9 @@ def test_speedup_critic() -> None:
optimization_candidate_index=0,
)
assert not speedup_critic(candidate_result, original_code_runtime, best_runtime_until_now, disable_gh_action_noise=True) # 6% improvement
assert not speedup_critic(
candidate_result, original_code_runtime, best_runtime_until_now, disable_gh_action_noise=True
) # 6% improvement
original_code_runtime = 100000
best_runtime_until_now = 100000
@ -75,7 +79,9 @@ def test_speedup_critic() -> None:
optimization_candidate_index=0,
)
assert speedup_critic(candidate_result, original_code_runtime, best_runtime_until_now, disable_gh_action_noise=True) # 6% improvement
assert speedup_critic(
candidate_result, original_code_runtime, best_runtime_until_now, disable_gh_action_noise=True
) # 6% improvement
def test_generated_test_critic() -> None:
@ -418,6 +424,7 @@ def test_coverage_critic() -> None:
assert coverage_critic(failing_coverage) is False
def test_throughput_gain() -> None:
"""Test throughput_gain calculation."""
# Test basic throughput improvement
@ -458,7 +465,7 @@ def test_speedup_critic_with_async_throughput() -> None:
best_runtime_until_now=None,
original_async_throughput=original_async_throughput,
best_throughput_until_now=None,
disable_gh_action_noise=True
disable_gh_action_noise=True,
)
# Test case 2: Runtime improves significantly, throughput doesn't meet threshold (should pass)
@ -478,7 +485,7 @@ def test_speedup_critic_with_async_throughput() -> None:
best_runtime_until_now=None,
original_async_throughput=original_async_throughput,
best_throughput_until_now=None,
disable_gh_action_noise=True
disable_gh_action_noise=True,
)
# Test case 3: Throughput improves significantly, runtime doesn't meet threshold (should pass)
@ -498,7 +505,7 @@ def test_speedup_critic_with_async_throughput() -> None:
best_runtime_until_now=None,
original_async_throughput=original_async_throughput,
best_throughput_until_now=None,
disable_gh_action_noise=True
disable_gh_action_noise=True,
)
# Test case 4: No throughput data - should fall back to runtime-only evaluation
@ -518,7 +525,7 @@ def test_speedup_critic_with_async_throughput() -> None:
best_runtime_until_now=None,
original_async_throughput=None, # No original throughput data
best_throughput_until_now=None,
disable_gh_action_noise=True
disable_gh_action_noise=True,
)
# Test case 5: Test best_throughput_until_now comparison
@ -539,7 +546,7 @@ def test_speedup_critic_with_async_throughput() -> None:
best_runtime_until_now=None,
original_async_throughput=original_async_throughput,
best_throughput_until_now=None,
disable_gh_action_noise=True
disable_gh_action_noise=True,
)
# Should fail when there's a better throughput already
@ -549,7 +556,7 @@ def test_speedup_critic_with_async_throughput() -> None:
best_runtime_until_now=7000, # Better runtime already exists
original_async_throughput=original_async_throughput,
best_throughput_until_now=120, # Better throughput already exists
disable_gh_action_noise=True
disable_gh_action_noise=True,
)
# Test case 6: Zero original throughput (edge case)
@ -570,7 +577,7 @@ def test_speedup_critic_with_async_throughput() -> None:
best_runtime_until_now=None,
original_async_throughput=0, # Zero original throughput
best_throughput_until_now=None,
disable_gh_action_noise=True
disable_gh_action_noise=True,
)
@ -594,29 +601,20 @@ def test_concurrency_gain() -> None:
# Test no improvement
same = ConcurrencyMetrics(
sequential_time_ns=10_000_000,
concurrent_time_ns=10_000_000,
concurrency_factor=10,
concurrency_ratio=1.0,
sequential_time_ns=10_000_000, concurrent_time_ns=10_000_000, concurrency_factor=10, concurrency_ratio=1.0
)
assert concurrency_gain(original, same) == 0.0
# Test slight improvement
slightly_better = ConcurrencyMetrics(
sequential_time_ns=10_000_000,
concurrent_time_ns=8_000_000,
concurrency_factor=10,
concurrency_ratio=1.25,
sequential_time_ns=10_000_000, concurrent_time_ns=8_000_000, concurrency_factor=10, concurrency_ratio=1.25
)
# 25% improvement: (1.25 - 1.0) / 1.0 = 0.25
assert concurrency_gain(original, slightly_better) == 0.25
# Test zero original ratio (edge case)
zero_ratio = ConcurrencyMetrics(
sequential_time_ns=0,
concurrent_time_ns=1_000_000,
concurrency_factor=10,
concurrency_ratio=0.0,
sequential_time_ns=0, concurrent_time_ns=1_000_000, concurrency_factor=10, concurrency_ratio=0.0
)
assert concurrency_gain(zero_ratio, optimized) == 0.0
@ -628,10 +626,7 @@ def test_speedup_critic_with_concurrency_metrics() -> None:
# Original concurrency metrics (blocking code - ratio ~= 1.0)
original_concurrency = ConcurrencyMetrics(
sequential_time_ns=10_000_000,
concurrent_time_ns=10_000_000,
concurrency_factor=10,
concurrency_ratio=1.0,
sequential_time_ns=10_000_000, concurrent_time_ns=10_000_000, concurrency_factor=10, concurrency_ratio=1.0
)
# Test case 1: Concurrency improves significantly (blocking -> non-blocking)
@ -731,10 +726,7 @@ def test_speedup_critic_with_concurrency_metrics() -> None:
total_candidate_timing=10000,
async_throughput=100,
concurrency_metrics=ConcurrencyMetrics(
sequential_time_ns=10_000_000,
concurrent_time_ns=2_000_000,
concurrency_factor=10,
concurrency_ratio=5.0,
sequential_time_ns=10_000_000, concurrent_time_ns=2_000_000, concurrency_factor=10, concurrency_ratio=5.0
),
)

View file

@ -1,4 +1,3 @@
from unittest.mock import Mock
import contextlib
import os
import shutil
@ -6,6 +5,7 @@ import unittest
from dataclasses import dataclass
from pathlib import Path
from typing import Optional
from unittest.mock import Mock
from codeflash.result.create_pr import existing_tests_source_for
@ -30,46 +30,33 @@ class TestExistingTestsSourceFor:
self.mock_function_called_in_test = Mock()
self.mock_function_called_in_test.tests_in_file = Mock()
self.mock_function_called_in_test.tests_in_file.test_file = Path(__file__).resolve().parent / "test_module.py"
#Path to pyproject.toml
# Path to pyproject.toml
os.chdir(self.test_cfg.project_root_path)
def test_no_test_files_returns_empty_string(self):
"""Test that function returns empty string when no test files exist."""
function_to_tests = {}
original_runtimes = {}
optimized_runtimes = {}
result, _, _ = existing_tests_source_for(
"module.function",
function_to_tests,
self.test_cfg,
original_runtimes,
optimized_runtimes
"module.function", function_to_tests, self.test_cfg, original_runtimes, optimized_runtimes
)
assert result == ""
def test_single_test_with_improvement(self):
"""Test single test showing performance improvement."""
function_to_tests = {
"module.function": {self.mock_function_called_in_test}
}
function_to_tests = {"module.function": {self.mock_function_called_in_test}}
original_runtimes = {
self.mock_invocation_id: [1000000] # 1ms in nanoseconds
}
optimized_runtimes = {
self.mock_invocation_id: [500000] # 0.5ms in nanoseconds
self.mock_invocation_id: [500000] # 0.5ms in nanoseconds
}
result, _, _ = existing_tests_source_for(
"module.function",
function_to_tests,
self.test_cfg,
original_runtimes,
optimized_runtimes
"module.function", function_to_tests, self.test_cfg, original_runtimes, optimized_runtimes
)
expected = """| Test File::Test Function | Original ⏱️ | Optimized ⏱️ | Speedup |
@ -81,23 +68,16 @@ class TestExistingTestsSourceFor:
def test_single_test_with_regression(self):
"""Test single test showing performance regression."""
function_to_tests = {
"module.function": {self.mock_function_called_in_test}
}
function_to_tests = {"module.function": {self.mock_function_called_in_test}}
original_runtimes = {
self.mock_invocation_id: [500000] # 0.5ms in nanoseconds
self.mock_invocation_id: [500000] # 0.5ms in nanoseconds
}
optimized_runtimes = {
self.mock_invocation_id: [1000000] # 1ms in nanoseconds
}
result, _, _ = existing_tests_source_for(
"module.function",
function_to_tests,
self.test_cfg,
original_runtimes,
optimized_runtimes
"module.function", function_to_tests, self.test_cfg, original_runtimes, optimized_runtimes
)
expected = """| Test File::Test Function | Original ⏱️ | Optimized ⏱️ | Speedup |
@ -109,28 +89,17 @@ class TestExistingTestsSourceFor:
def test_test_without_class_name(self):
"""Test function without class name (standalone test function)."""
mock_invocation_no_class = Mock()
mock_invocation_no_class.test_module_path = "tests.test_module"
mock_invocation_no_class.test_class_name = None
mock_invocation_no_class.test_function_name = "test_standalone"
function_to_tests = {
"module.function": {self.mock_function_called_in_test}
}
original_runtimes = {
mock_invocation_no_class: [1000000]
}
optimized_runtimes = {
mock_invocation_no_class: [800000]
}
function_to_tests = {"module.function": {self.mock_function_called_in_test}}
original_runtimes = {mock_invocation_no_class: [1000000]}
optimized_runtimes = {mock_invocation_no_class: [800000]}
result, _, _ = existing_tests_source_for(
"module.function",
function_to_tests,
self.test_cfg,
original_runtimes,
optimized_runtimes
"module.function", function_to_tests, self.test_cfg, original_runtimes, optimized_runtimes
)
expected = """| Test File::Test Function | Original ⏱️ | Optimized ⏱️ | Speedup |
@ -142,21 +111,12 @@ class TestExistingTestsSourceFor:
def test_missing_original_runtime(self):
"""Test when original runtime is missing (shows NaN)."""
function_to_tests = {
"module.function": {self.mock_function_called_in_test}
}
function_to_tests = {"module.function": {self.mock_function_called_in_test}}
original_runtimes = {}
optimized_runtimes = {
self.mock_invocation_id: [500000]
}
optimized_runtimes = {self.mock_invocation_id: [500000]}
result, _, _ = existing_tests_source_for(
"module.function",
function_to_tests,
self.test_cfg,
original_runtimes,
optimized_runtimes
"module.function", function_to_tests, self.test_cfg, original_runtimes, optimized_runtimes
)
expected = ""
@ -165,21 +125,12 @@ class TestExistingTestsSourceFor:
def test_missing_optimized_runtime(self):
"""Test when optimized runtime is missing (shows NaN)."""
function_to_tests = {
"module.function": {self.mock_function_called_in_test}
}
original_runtimes = {
self.mock_invocation_id: [1000000]
}
function_to_tests = {"module.function": {self.mock_function_called_in_test}}
original_runtimes = {self.mock_invocation_id: [1000000]}
optimized_runtimes = {}
result, _, _ = existing_tests_source_for(
"module.function",
function_to_tests,
self.test_cfg,
original_runtimes,
optimized_runtimes
"module.function", function_to_tests, self.test_cfg, original_runtimes, optimized_runtimes
)
expected = ""
@ -189,7 +140,7 @@ class TestExistingTestsSourceFor:
def test_multiple_tests_sorted_output(self):
"""Test multiple tests with sorted output by filename and function name."""
# Create second test file
mock_function_called_2 = Mock()
mock_function_called_2.tests_in_file = Mock()
mock_function_called_2.tests_in_file.test_file = Path(__file__).resolve().parent / "test_another.py"
@ -199,24 +150,12 @@ class TestExistingTestsSourceFor:
mock_invocation_2.test_class_name = "TestAnother"
mock_invocation_2.test_function_name = "test_another_function"
function_to_tests = {
"module.function": {self.mock_function_called_in_test, mock_function_called_2}
}
original_runtimes = {
self.mock_invocation_id: [1000000],
mock_invocation_2: [2000000]
}
optimized_runtimes = {
self.mock_invocation_id: [800000],
mock_invocation_2: [1500000]
}
function_to_tests = {"module.function": {self.mock_function_called_in_test, mock_function_called_2}}
original_runtimes = {self.mock_invocation_id: [1000000], mock_invocation_2: [2000000]}
optimized_runtimes = {self.mock_invocation_id: [800000], mock_invocation_2: [1500000]}
result, _, _ = existing_tests_source_for(
"module.function",
function_to_tests,
self.test_cfg,
original_runtimes,
optimized_runtimes
"module.function", function_to_tests, self.test_cfg, original_runtimes, optimized_runtimes
)
expected = """| Test File::Test Function | Original ⏱️ | Optimized ⏱️ | Speedup |
@ -229,23 +168,16 @@ class TestExistingTestsSourceFor:
def test_multiple_runtimes_uses_minimum(self):
"""Test that function uses minimum runtime when multiple measurements exist."""
function_to_tests = {
"module.function": {self.mock_function_called_in_test}
}
function_to_tests = {"module.function": {self.mock_function_called_in_test}}
original_runtimes = {
self.mock_invocation_id: [1000000, 1200000, 800000] # min: 800000
}
optimized_runtimes = {
self.mock_invocation_id: [600000, 700000, 500000] # min: 500000
self.mock_invocation_id: [600000, 700000, 500000] # min: 500000
}
result, _, _ = existing_tests_source_for(
"module.function",
function_to_tests,
self.test_cfg,
original_runtimes,
optimized_runtimes
"module.function", function_to_tests, self.test_cfg, original_runtimes, optimized_runtimes
)
expected = """| Test File::Test Function | Original ⏱️ | Optimized ⏱️ | Speedup |
@ -257,7 +189,6 @@ class TestExistingTestsSourceFor:
def test_complex_module_path_conversion(self):
"""Test conversion of complex module paths to file paths."""
mock_invocation_complex = Mock()
mock_invocation_complex.test_module_path = "tests.integration.test_complex_module"
mock_invocation_complex.test_class_name = "TestComplex"
@ -265,24 +196,16 @@ class TestExistingTestsSourceFor:
mock_function_complex = Mock()
mock_function_complex.tests_in_file = Mock()
mock_function_complex.tests_in_file.test_file = Path(__file__).resolve().parent / "integration/test_complex_module.py"
mock_function_complex.tests_in_file.test_file = (
Path(__file__).resolve().parent / "integration/test_complex_module.py"
)
function_to_tests = {
"module.function": {mock_function_complex}
}
original_runtimes = {
mock_invocation_complex: [1000000]
}
optimized_runtimes = {
mock_invocation_complex: [750000]
}
function_to_tests = {"module.function": {mock_function_complex}}
original_runtimes = {mock_invocation_complex: [1000000]}
optimized_runtimes = {mock_invocation_complex: [750000]}
result, _, _ = existing_tests_source_for(
"module.function",
function_to_tests,
self.test_cfg,
original_runtimes,
optimized_runtimes
"module.function", function_to_tests, self.test_cfg, original_runtimes, optimized_runtimes
)
expected = """| Test File::Test Function | Original ⏱️ | Optimized ⏱️ | Speedup |
@ -294,23 +217,12 @@ class TestExistingTestsSourceFor:
def test_zero_runtime_values(self):
"""Test handling of zero runtime values."""
function_to_tests = {
"module.function": {self.mock_function_called_in_test}
}
original_runtimes = {
self.mock_invocation_id: [0]
}
optimized_runtimes = {
self.mock_invocation_id: [0]
}
function_to_tests = {"module.function": {self.mock_function_called_in_test}}
original_runtimes = {self.mock_invocation_id: [0]}
optimized_runtimes = {self.mock_invocation_id: [0]}
result, _, _ = existing_tests_source_for(
"module.function",
function_to_tests,
self.test_cfg,
original_runtimes,
optimized_runtimes
"module.function", function_to_tests, self.test_cfg, original_runtimes, optimized_runtimes
)
expected = ""
@ -320,7 +232,7 @@ class TestExistingTestsSourceFor:
def test_filters_out_generated_tests(self):
"""Test that generated tests are filtered out and only non-generated tests are included."""
# Create a test that would be filtered out (not in non_generated_tests)
mock_generated_test = Mock()
mock_generated_test.tests_in_file = Mock()
mock_generated_test.tests_in_file.test_file = "/project/tests/generated_test.py"
@ -330,24 +242,18 @@ class TestExistingTestsSourceFor:
mock_generated_invocation.test_class_name = "TestGenerated"
mock_generated_invocation.test_function_name = "test_generated"
function_to_tests = {
"module.function": {self.mock_function_called_in_test}
}
function_to_tests = {"module.function": {self.mock_function_called_in_test}}
original_runtimes = {
self.mock_invocation_id: [1000000],
mock_generated_invocation: [500000] # This should be filtered out
mock_generated_invocation: [500000], # This should be filtered out
}
optimized_runtimes = {
self.mock_invocation_id: [800000],
mock_generated_invocation: [400000] # This should be filtered out
mock_generated_invocation: [400000], # This should be filtered out
}
result, _, _ = existing_tests_source_for(
"module.function",
function_to_tests,
self.test_cfg,
original_runtimes,
optimized_runtimes
"module.function", function_to_tests, self.test_cfg, original_runtimes, optimized_runtimes
)
# Should only include the non-generated test
@ -358,9 +264,11 @@ class TestExistingTestsSourceFor:
assert result == expected
@dataclass(frozen=True)
class MockInvocationId:
"""Mocks codeflash.models.models.InvocationId"""
test_module_path: str
test_function_name: str
test_class_name: Optional[str] = None
@ -369,6 +277,7 @@ class MockInvocationId:
@dataclass(frozen=True)
class MockTestsInFile:
"""Mocks a part of codeflash.models.models.FunctionCalledInTest"""
test_file: Path
test_type: str = "EXISTING_UNIT_TEST"
@ -376,12 +285,14 @@ class MockTestsInFile:
@dataclass(frozen=True)
class MockFunctionCalledInTest:
"""Mocks codeflash.models.models.FunctionCalledInTest"""
tests_in_file: MockTestsInFile
@dataclass(frozen=True)
class MockTestConfig:
"""Mocks codeflash.verification.verification_utils.TestConfig"""
tests_root: Path
@ -429,11 +340,7 @@ class ExistingTestsSourceForTests(unittest.TestCase):
test_cfg = MockTestConfig(tests_root=tests_dir.resolve())
function_to_tests = {
self.func_qual_name: {
MockFunctionCalledInTest(
tests_in_file=MockTestsInFile(test_file=test_file_path)
)
}
self.func_qual_name: {MockFunctionCalledInTest(tests_in_file=MockTestsInFile(test_file=test_file_path))}
}
existing, replay, concolic = existing_tests_source_for(
function_qualified_name_with_modules_from_root=self.func_qual_name,
@ -456,17 +363,11 @@ class ExistingTestsSourceForTests(unittest.TestCase):
test_cfg = MockTestConfig(tests_root=tests_dir.resolve())
function_to_tests = {
self.func_qual_name: {
MockFunctionCalledInTest(
tests_in_file=MockTestsInFile(test_file=test_file_path)
)
}
self.func_qual_name: {MockFunctionCalledInTest(tests_in_file=MockTestsInFile(test_file=test_file_path))}
}
invocation_id = MockInvocationId(
test_module_path="tests.test_existing",
test_class_name="TestMyStuff",
test_function_name="test_one",
test_module_path="tests.test_existing", test_class_name="TestMyStuff", test_function_name="test_one"
)
original_runtimes = {invocation_id: [200_000_000]}
@ -501,32 +402,20 @@ class ExistingTestsSourceForTests(unittest.TestCase):
test_cfg = MockTestConfig(tests_root=tests_dir.resolve())
function_to_tests = {
self.func_qual_name: {
MockFunctionCalledInTest(
tests_in_file=MockTestsInFile(test_file=replay_test_path)
),
MockFunctionCalledInTest(
tests_in_file=MockTestsInFile(test_file=concolic_test_path)
),
MockFunctionCalledInTest(tests_in_file=MockTestsInFile(test_file=replay_test_path)),
MockFunctionCalledInTest(tests_in_file=MockTestsInFile(test_file=concolic_test_path)),
}
}
replay_inv_id = MockInvocationId(
test_module_path="tests.__replay_test_abc",
test_function_name="test_replay_one",
test_module_path="tests.__replay_test_abc", test_function_name="test_replay_one"
)
concolic_inv_id = MockInvocationId(
test_module_path="tests.codeflash_concolic_xyz",
test_function_name="test_concolic_one",
test_module_path="tests.codeflash_concolic_xyz", test_function_name="test_concolic_one"
)
original_runtimes = {
replay_inv_id: [100_000_000],
concolic_inv_id: [150_000_000],
}
optimized_runtimes = {
replay_inv_id: [200_000_000],
concolic_inv_id: [300_000_000],
}
original_runtimes = {replay_inv_id: [100_000_000], concolic_inv_id: [150_000_000]}
optimized_runtimes = {replay_inv_id: [200_000_000], concolic_inv_id: [300_000_000]}
existing, replay, concolic = existing_tests_source_for(
function_qualified_name_with_modules_from_root=self.func_qual_name,
@ -555,21 +444,13 @@ class ExistingTestsSourceForTests(unittest.TestCase):
test_cfg = MockTestConfig(tests_root=tests_dir.resolve())
function_to_tests = {
self.func_qual_name: {
MockFunctionCalledInTest(
tests_in_file=MockTestsInFile(test_file=existing_test_path)
),
MockFunctionCalledInTest(
tests_in_file=MockTestsInFile(test_file=replay_test_path)
),
MockFunctionCalledInTest(tests_in_file=MockTestsInFile(test_file=existing_test_path)),
MockFunctionCalledInTest(tests_in_file=MockTestsInFile(test_file=replay_test_path)),
}
}
existing_inv_id = MockInvocationId(
"tests.test_existing", "test_speedup", "TestExisting"
)
replay_inv_id = MockInvocationId(
"tests.__replay_test_mixed", "test_slowdown"
)
existing_inv_id = MockInvocationId("tests.test_existing", "test_speedup", "TestExisting")
replay_inv_id = MockInvocationId("tests.__replay_test_mixed", "test_slowdown")
original_runtimes = {
existing_inv_id: [400_000_000, 500_000_000], # min is 400ms
@ -581,11 +462,7 @@ class ExistingTestsSourceForTests(unittest.TestCase):
}
existing, replay, concolic = existing_tests_source_for(
self.func_qual_name,
function_to_tests,
test_cfg,
original_runtimes,
optimized_runtimes,
self.func_qual_name, function_to_tests, test_cfg, original_runtimes, optimized_runtimes
)
self.assertIn("`test_existing.py::TestExisting.test_speedup`", existing)

View file

@ -3,8 +3,6 @@
from collections import Counter
from pathlib import Path
import pytest
from codeflash.models.models import FunctionTestInvocation, InvocationId, TestResults, TestType
@ -340,11 +338,13 @@ class TestFileToNoOfTests:
)
counter = test_results.file_to_no_of_tests(["test_remove"])
expected = Counter({
Path("/tmp/file1.py"): 1, # Only 1 GENERATED_REGRESSION test
Path("/tmp/file2.py"): 1, # Only test_keep (test_remove is excluded)
Path("/tmp/file3.py"): 3, # All 3 tests
})
expected = Counter(
{
Path("/tmp/file1.py"): 1, # Only 1 GENERATED_REGRESSION test
Path("/tmp/file2.py"): 1, # Only test_keep (test_remove is excluded)
Path("/tmp/file3.py"): 3, # All 3 tests
}
)
assert counter == expected
def test_case_sensitivity(self):
@ -438,7 +438,7 @@ class TestFileToNoOfTests:
)
counter = test_results.file_to_no_of_tests([])
expected = Counter({path: 1 for path in paths})
expected = Counter(dict.fromkeys(paths, 1))
assert counter == expected
def test_large_removal_list(self):
@ -470,4 +470,4 @@ class TestFileToNoOfTests:
)
counter = test_results.file_to_no_of_tests(removal_list)
assert counter == Counter({Path("/tmp/test_file.py"): 50}) # 50 kept, 50 removed
assert counter == Counter({Path("/tmp/test_file.py"): 50}) # 50 kept, 50 removed

View file

@ -1,24 +1,24 @@
import argparse
import os
import shutil
import tempfile
from pathlib import Path
import pytest
import shutil
from codeflash.code_utils.config_parser import parse_config_file
from codeflash.code_utils.formatter import format_code, format_generated_code, sort_imports
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.models.models import CodeString, CodeStringsMarkdown
from codeflash.optimization.function_optimizer import FunctionOptimizer
from codeflash.verification.verification_utils import TestConfig
@pytest.fixture
def temp_dir():
with tempfile.TemporaryDirectory() as tmpdirname:
yield Path(tmpdirname)
def test_remove_duplicate_imports():
"""Test that duplicate imports are removed when should_sort_imports is True."""
original_code = "import os\nimport os\n"
@ -187,9 +187,7 @@ def foo():
temp_file = temp_dir / "test_file.py"
temp_file.write_text(original_code)
actual = format_code(
formatter_cmds=["ruff check --exit-zero --fix $file", "ruff format $file"], path=temp_file
)
actual = format_code(formatter_cmds=["ruff check --exit-zero --fix $file", "ruff format $file"], path=temp_file)
assert actual == expected
@ -208,7 +206,7 @@ def foo():
assert False, f"Shouldn't throw an exception even if the formatter is not found: {e}"
def _run_formatting_test(source_code: str, should_content_change: bool, expected = None, optimized_function: str = ""):
def _run_formatting_test(source_code: str, should_content_change: bool, expected=None, optimized_function: str = ""):
try:
import ruff # type: ignore
except ImportError:
@ -217,67 +215,50 @@ def _run_formatting_test(source_code: str, should_content_change: bool, expected
with tempfile.TemporaryDirectory() as test_dir_str:
test_dir = Path(test_dir_str)
source_file = test_dir / "source.py"
source_file.write_text(source_code)
original = source_code
target_path = test_dir / "target.py"
shutil.copy2(source_file, target_path)
function_to_optimize = FunctionToOptimize(
function_name="process_data",
parents=[],
file_path=target_path
)
function_to_optimize = FunctionToOptimize(function_name="process_data", parents=[], file_path=target_path)
test_cfg = TestConfig(
tests_root=test_dir,
project_root_path=test_dir,
test_framework="pytest",
tests_project_rootdir=test_dir,
tests_root=test_dir, project_root_path=test_dir, test_framework="pytest", tests_project_rootdir=test_dir
)
args = argparse.Namespace(
disable_imports_sorting=False,
formatter_cmds=[
"ruff check --exit-zero --fix $file",
"ruff format $file"
],
disable_imports_sorting=False, formatter_cmds=["ruff check --exit-zero --fix $file", "ruff format $file"]
)
optimizer = FunctionOptimizer(
function_to_optimize=function_to_optimize,
test_cfg=test_cfg,
args=args,
)
optimizer = FunctionOptimizer(function_to_optimize=function_to_optimize, test_cfg=test_cfg, args=args)
optimizer.reformat_code_and_helpers(
helper_functions=[],
path=target_path,
original_code=optimizer.function_to_optimize_source_code,
optimized_context=CodeStringsMarkdown(code_strings=[
CodeString(
code=optimized_function,
file_path=target_path.relative_to(test_dir)
)
]),
optimized_context=CodeStringsMarkdown(
code_strings=[CodeString(code=optimized_function, file_path=target_path.relative_to(test_dir))]
),
)
content = target_path.read_text(encoding="utf8")
if expected is not None:
assert content == expected, f"Expected content to be \n===========\n{expected}\n===========\nbut got\n===========\n{content}\n===========\n"
assert content == expected, (
f"Expected content to be \n===========\n{expected}\n===========\nbut got\n===========\n{content}\n===========\n"
)
if should_content_change:
assert content != original, f"Expected content to change for source.py"
assert content != original, "Expected content to change for source.py"
else:
assert content == original, f"Expected content to remain unchanged for source.py"
assert content == original, "Expected content to remain unchanged for source.py"
def test_formatting_file_with_many_diffs():
"""Test that files with many formatting errors are skipped (content unchanged)."""
source_code = '''import os,sys,json,datetime,re
source_code = """import os,sys,json,datetime,re
from collections import defaultdict,OrderedDict
import numpy as np,pandas as pd
@ -354,7 +335,7 @@ def main():
else:print("Pipeline failed")
if __name__=='__main__':main()
'''
"""
_run_formatting_test(source_code, False)
@ -423,7 +404,7 @@ def process_data(data, config=None):
def test_formatting_extremely_messy_file():
"""Test that extremely messy files with 100+ potential changes are skipped."""
source_code = '''import os,sys,json,datetime,re,collections,itertools,functools,operator
source_code = """import os,sys,json,datetime,re,collections,itertools,functools,operator
from pathlib import Path
from typing import Dict,List,Optional,Union,Any,Tuple
import numpy as np,pandas as pd,matplotlib.pyplot as plt
@ -554,25 +535,28 @@ def main():
for error in processor.errors:print(f" - {error}")
if __name__=='__main__':main()
'''
"""
_run_formatting_test(source_code, False)
def test_formatting_edge_case_exactly_100_diffs():
"""Test behavior when exactly at the threshold of 100 changes."""
# Create a file with exactly 100 minor formatting issues
snippet = '''import json\n''' + '''
snippet = (
"""import json\n"""
"""
def func_{i}():
x=1;y=2;z=3
return x+y+z
'''
"""
)
source_code = "".join([snippet.format(i=i) for i in range(100)])
_run_formatting_test(source_code, False)
def test_formatting_with_syntax_errors():
"""Test that files with syntax errors are handled gracefully."""
source_code = '''import json
source_code = """import json
def process_data(data):
if not data:
@ -585,7 +569,7 @@ def process_data(data):
result.append(item)
return result
'''
"""
_run_formatting_test(source_code, False)
@ -641,7 +625,7 @@ def another_function_with_long_line():
def test_formatting_class_with_methods():
"""Test formatting of classes with multiple methods and minor issues."""
source_code = '''class DataProcessor:
source_code = """class DataProcessor:
def __init__(self, config):
self.config=config
self.data=[]
@ -660,13 +644,13 @@ def test_formatting_class_with_methods():
'processed':True
})
return result
'''
"""
_run_formatting_test(source_code, True)
def test_formatting_with_complex_comprehensions():
"""Test files with complex list/dict comprehensions and formatting."""
source_code = '''def complex_comprehensions(data):
source_code = """def complex_comprehensions(data):
# Various comprehension styles with formatting issues
result1=[item['value'] for item in data if item.get('active',True) and 'value' in item]
@ -683,13 +667,13 @@ def test_formatting_with_complex_comprehensions():
'complex':result3,
'nested':nested
}
'''
"""
_run_formatting_test(source_code, True)
def test_formatting_with_decorators_and_async():
"""Test files with decorators and async functions."""
source_code = '''import asyncio
source_code = """import asyncio
from functools import wraps
def timer_decorator(func):
@ -715,26 +699,26 @@ class AsyncProcessor:
@staticmethod
async def process_batch(batch):
return [{'id':item['id'],'status':'done'} for item in batch if 'id' in item]
'''
"""
_run_formatting_test(source_code, True)
def test_formatting_threshold_configuration():
"""Test that the diff threshold can be configured (if supported)."""
# This test assumes the threshold might be configurable
source_code = '''import json,os,sys
source_code = """import json,os,sys
def func1():x=1;y=2;return x+y
def func2():a=1;b=2;return a+b
def func3():c=1;d=2;return c+d
'''
"""
# Test with a file that has moderate formatting issues
_run_formatting_test(source_code, True, optimized_function="def func2():a=1;b=2;return a+b")
def test_formatting_empty_file():
"""Test formatting of empty or minimal files."""
source_code = '''# Just a comment pass
'''
source_code = """# Just a comment pass
"""
_run_formatting_test(source_code, False)
@ -798,6 +782,7 @@ class ProcessorWithDocs:
return{'result':[item for item in data if self._is_valid(item)]}"""
_run_formatting_test(source_code, True, optimized_function=optimization_function, expected=expected)
def test_sort_imports_skip_file():
"""Test that isort skips files with # isort:skip_file."""
code = """# isort:skip_file
@ -809,6 +794,7 @@ import sys, os, json # isort will ignore this file completely"""
# ==================== Tests for format_generated_code ====================
def test_format_generated_code_disabled():
"""Test that format_generated_code returns code with normalized newlines when formatter is disabled."""
test_code = """import os
@ -889,6 +875,7 @@ def test_function(x, y, z):
result = format_generated_code(test_code, ["black $file"])
assert result == expected
def test_format_generated_code_with_inference():
"""Test format_generated_code with ruff formatter."""
try:
@ -1154,6 +1141,7 @@ from inference.core.models.base import Model
result = format_generated_code(test_code, ["ruff format $file"])
assert result == expected
def test_format_generated_code_with_ruff():
"""Test format_generated_code with ruff formatter."""
try:
@ -1205,8 +1193,11 @@ def test_format_generated_code_invalid_formatter():
# Should handle gracefully and return code with normalized newlines
result = format_generated_code(test_code, ["nonexistent_formatter $file"])
assert result == """def test():
assert (
result
== """def test():
pass"""
)
def test_format_generated_code_syntax_error():
@ -1217,8 +1208,11 @@ def test_format_generated_code_syntax_error():
# Formatter should fail but function should handle it gracefully
result = format_generated_code(test_code, ["black $file"])
# Should return code with normalized newlines when formatting fails
assert result == """def test(: # syntax error
assert (
result
== """def test(: # syntax error
pass"""
)
def test_format_generated_code_already_formatted():
@ -1272,9 +1266,9 @@ def test_format_generated_code_trailing_whitespace():
"""
result = format_generated_code(test_code, ["black $file"])
lines = result.split('\n')
lines = result.split("\n")
for line in lines:
assert line == line.rstrip(), f"Line has trailing whitespace: {repr(line)}"
assert line == line.rstrip(), f"Line has trailing whitespace: {line!r}"
def test_format_generated_code_preserves_comments():

View file

@ -1,5 +1,4 @@
import pathlib
from dataclasses import dataclass
import pytest
@ -90,6 +89,7 @@ def recursive_dependency_1(num):
num_1 = calculate_something(num)
return recursive_dependency_1(num) + num_1
from collections import defaultdict
@ -190,6 +190,7 @@ class Graph:
return stack"""
)
def test_recursive_function_context() -> None:
file_path = pathlib.Path(__file__).resolve()
@ -232,4 +233,4 @@ class C:
return 0
num_1 = self.calculate_something_3(num)
return self.recursive(num) + num_1"""
)
)

View file

@ -1,18 +1,16 @@
import tempfile
from pathlib import Path
import os
import unittest.mock
from pathlib import Path
from codeflash.discovery.functions_to_optimize import (
filter_files_optimized,
filter_functions,
find_all_functions_in_file,
get_all_files_and_functions,
get_functions_to_optimize,
inspect_top_level_functions_or_methods,
filter_functions,
get_all_files_and_functions
)
from codeflash.verification.verification_utils import TestConfig
from codeflash.code_utils.compat import codeflash_temp_dir
def test_function_eligible_for_optimization() -> None:
@ -24,10 +22,10 @@ def test_function_eligible_for_optimization() -> None:
with tempfile.TemporaryDirectory() as temp_dir:
temp_dir_path = Path(temp_dir)
file_path = temp_dir_path / "test_function.py"
with file_path.open("w") as f:
f.write(function)
functions_found = find_all_functions_in_file(file_path)
assert functions_found[file_path][0].function_name == "test_function_eligible_for_optimization"
@ -40,34 +38,31 @@ def test_function_eligible_for_optimization() -> None:
with tempfile.TemporaryDirectory() as temp_dir:
temp_dir_path = Path(temp_dir)
file_path = temp_dir_path / "test_function.py"
with file_path.open("w") as f:
f.write(function)
functions_found = find_all_functions_in_file(file_path)
assert len(functions_found[file_path]) == 0
# we want to trigger an error in the function discovery
function = """def test_invalid_code():"""
with tempfile.TemporaryDirectory() as temp_dir:
temp_dir_path = Path(temp_dir)
file_path = temp_dir_path / "test_function.py"
with file_path.open("w") as f:
f.write(function)
functions_found = find_all_functions_in_file(file_path)
assert functions_found == {}
def test_find_top_level_function_or_method():
with tempfile.TemporaryDirectory() as temp_dir:
temp_dir_path = Path(temp_dir)
file_path = temp_dir_path / "test_function.py"
with file_path.open("w") as f:
f.write(
"""def functionA():
@ -93,7 +88,7 @@ def non_classmethod_function(cls, name):
return cls.name
"""
)
assert inspect_top_level_functions_or_methods(file_path, "functionA").is_top_level
assert not inspect_top_level_functions_or_methods(file_path, "functionB").is_top_level
assert inspect_top_level_functions_or_methods(file_path, "functionC", class_name="A").is_top_level
@ -117,20 +112,21 @@ def non_classmethod_function(cls, name):
with tempfile.TemporaryDirectory() as temp_dir:
temp_dir_path = Path(temp_dir)
file_path = temp_dir_path / "test_function.py"
with file_path.open("w") as f:
f.write(
"""def functionA():
"""
)
assert not inspect_top_level_functions_or_methods(file_path, "functionA")
def test_class_method_discovery():
with tempfile.TemporaryDirectory() as temp_dir:
temp_dir_path = Path(temp_dir)
file_path = temp_dir_path / "test_function.py"
with file_path.open("w") as f:
f.write(
"""class A:
@ -146,7 +142,7 @@ class X:
def functionA():
return True"""
)
test_config = TestConfig(
tests_root="tests", project_root_path=".", test_framework="pytest", tests_project_rootdir=Path()
)
@ -202,10 +198,10 @@ def test_nested_function():
with tempfile.TemporaryDirectory() as temp_dir:
temp_dir_path = Path(temp_dir)
file_path = temp_dir_path / "test_function.py"
with file_path.open("w") as f:
f.write(
"""
"""
import copy
def propagate_attributes(
@ -249,7 +245,7 @@ def propagate_attributes(
return modified_nodes
"""
)
test_config = TestConfig(
tests_root="tests", project_root_path=".", test_framework="pytest", tests_project_rootdir=Path()
)
@ -270,10 +266,10 @@ def propagate_attributes(
with tempfile.TemporaryDirectory() as temp_dir:
temp_dir_path = Path(temp_dir)
file_path = temp_dir_path / "test_function.py"
with file_path.open("w") as f:
f.write(
"""
"""
def outer_function():
def inner_function():
pass
@ -281,7 +277,7 @@ def outer_function():
return inner_function
"""
)
test_config = TestConfig(
tests_root="tests", project_root_path=".", test_framework="pytest", tests_project_rootdir=Path()
)
@ -302,10 +298,10 @@ def outer_function():
with tempfile.TemporaryDirectory() as temp_dir:
temp_dir_path = Path(temp_dir)
file_path = temp_dir_path / "test_function.py"
with file_path.open("w") as f:
f.write(
"""
"""
def outer_function():
def inner_function():
pass
@ -315,7 +311,7 @@ def outer_function():
return inner_function, another_inner_function
"""
)
test_config = TestConfig(
tests_root="tests", project_root_path=".", test_framework="pytest", tests_project_rootdir=Path()
)
@ -349,15 +345,16 @@ def test_filter_files_optimized():
assert filter_files_optimized(file_path_different_level, tests_root, ignore_paths, module_root)
assert not filter_files_optimized(file_path_above_level, tests_root, ignore_paths, module_root)
def test_filter_functions():
with tempfile.TemporaryDirectory() as temp_dir_str:
temp_dir = Path(temp_dir_str)
# Create a test file in the temporary directory
test_file_path = temp_dir.joinpath("test_get_functions_to_optimize.py")
with test_file_path.open("w") as f:
f.write(
"""
"""
import copy
def propagate_attributes(
@ -407,14 +404,15 @@ def not_in_checkpoint_function():
return "This function is not in the checkpoint."
"""
)
discovered = find_all_functions_in_file(test_file_path)
modified_functions = {test_file_path: discovered[test_file_path]}
# Use an absolute path for tests_root that won't match the temp directory
# This avoids path resolution issues in CI where the working directory might differ
tests_root_absolute = (temp_dir.parent / "nonexistent_tests_dir").resolve()
with unittest.mock.patch("codeflash.discovery.functions_to_optimize.get_blocklisted_functions", return_value={}):
with unittest.mock.patch(
"codeflash.discovery.functions_to_optimize.get_blocklisted_functions", return_value={}
):
filtered, count = filter_functions(
modified_functions,
tests_root=tests_root_absolute,
@ -429,19 +427,19 @@ def not_in_checkpoint_function():
# Create a tests directory inside our temp directory
tests_root_dir = temp_dir.joinpath("tests")
tests_root_dir.mkdir(exist_ok=True)
test_file_path = tests_root_dir.joinpath("test_functions.py")
with test_file_path.open("w") as f:
f.write(
"""
"""
def test_function_in_tests_dir():
return "This function is in a test directory and should be filtered out."
"""
)
discovered_test_file = find_all_functions_in_file(test_file_path)
modified_functions_test = {test_file_path: discovered_test_file.get(test_file_path, [])}
filtered_test_file, count_test_file = filter_functions(
modified_functions_test,
tests_root=tests_root_dir,
@ -449,7 +447,7 @@ def test_function_in_tests_dir():
project_root=temp_dir,
module_root=temp_dir,
)
assert not filtered_test_file
assert count_test_file == 0
@ -459,7 +457,7 @@ def test_function_in_tests_dir():
ignored_file_path = ignored_dir.joinpath("ignored_file.py")
with ignored_file_path.open("w") as f:
f.write("def ignored_func(): return 1")
discovered_ignored = find_all_functions_in_file(ignored_file_path)
modified_functions_ignored = {ignored_file_path: discovered_ignored.get(ignored_file_path, [])}
@ -474,17 +472,19 @@ def test_function_in_tests_dir():
assert count_ignored == 0
# Test submodule paths
with unittest.mock.patch("codeflash.discovery.functions_to_optimize.ignored_submodule_paths",
return_value=[str(temp_dir.joinpath("submodule_dir"))]):
with unittest.mock.patch(
"codeflash.discovery.functions_to_optimize.ignored_submodule_paths",
return_value=[str(temp_dir.joinpath("submodule_dir"))],
):
submodule_dir = temp_dir.joinpath("submodule_dir")
submodule_dir.mkdir(exist_ok=True)
submodule_file_path = submodule_dir.joinpath("submodule_file.py")
with submodule_file_path.open("w") as f:
f.write("def submodule_func(): return 1")
discovered_submodule = find_all_functions_in_file(submodule_file_path)
modified_functions_submodule = {submodule_file_path: discovered_submodule.get(submodule_file_path, [])}
filtered_submodule, count_submodule = filter_functions(
modified_functions_submodule,
tests_root=Path("tests"),
@ -496,14 +496,17 @@ def test_function_in_tests_dir():
assert count_submodule == 0
# Test site packages
with unittest.mock.patch("codeflash.discovery.functions_to_optimize.path_belongs_to_site_packages",
return_value=True):
with unittest.mock.patch(
"codeflash.discovery.functions_to_optimize.path_belongs_to_site_packages", return_value=True
):
site_package_file_path = temp_dir.joinpath("site_package_file.py")
with site_package_file_path.open("w") as f:
f.write("def site_package_func(): return 1")
discovered_site_package = find_all_functions_in_file(site_package_file_path)
modified_functions_site_package = {site_package_file_path: discovered_site_package.get(site_package_file_path, [])}
modified_functions_site_package = {
site_package_file_path: discovered_site_package.get(site_package_file_path, [])
}
filtered_site_package, count_site_package = filter_functions(
modified_functions_site_package,
@ -514,16 +517,18 @@ def test_function_in_tests_dir():
)
assert not filtered_site_package
assert count_site_package == 0
# Test outside module root
parent_dir = temp_dir.parent
outside_module_root_path = parent_dir.joinpath("outside_module_root_file.py")
try:
with outside_module_root_path.open("w") as f:
f.write("def func_outside_module_root(): return 1")
discovered_outside_module = find_all_functions_in_file(outside_module_root_path)
modified_functions_outside_module = {outside_module_root_path: discovered_outside_module.get(outside_module_root_path, [])}
modified_functions_outside_module = {
outside_module_root_path: discovered_outside_module.get(outside_module_root_path, [])
}
filtered_outside_module, count_outside_module = filter_functions(
modified_functions_outside_module,
@ -543,8 +548,10 @@ def test_function_in_tests_dir():
f.write("def func_in_invalid_module(): return 1")
discovered_invalid_module = find_all_functions_in_file(invalid_module_file_path)
modified_functions_invalid_module = {invalid_module_file_path: discovered_invalid_module.get(invalid_module_file_path, [])}
modified_functions_invalid_module = {
invalid_module_file_path: discovered_invalid_module.get(invalid_module_file_path, [])
}
filtered_invalid_module, count_invalid_module = filter_functions(
modified_functions_invalid_module,
tests_root=Path("tests"),
@ -556,8 +563,10 @@ def test_function_in_tests_dir():
assert count_invalid_module == 0
original_file_path = temp_dir.joinpath("test_get_functions_to_optimize.py")
with unittest.mock.patch("codeflash.discovery.functions_to_optimize.get_blocklisted_functions",
return_value={original_file_path.name: {"propagate_attributes", "other_blocklisted_function"}}):
with unittest.mock.patch(
"codeflash.discovery.functions_to_optimize.get_blocklisted_functions",
return_value={original_file_path.name: {"propagate_attributes", "other_blocklisted_function"}},
):
filtered_funcs, count = filter_functions(
modified_functions,
tests_root=Path("tests"),
@ -571,15 +580,20 @@ def test_function_in_tests_dir():
module_name = "test_get_functions_to_optimize"
qualified_name_for_checkpoint = f"{module_name}.propagate_attributes"
other_qualified_name_for_checkpoint = f"{module_name}.vanilla_function"
with unittest.mock.patch("codeflash.discovery.functions_to_optimize.get_blocklisted_functions", return_value={}):
with unittest.mock.patch(
"codeflash.discovery.functions_to_optimize.get_blocklisted_functions", return_value={}
):
filtered_checkpoint, count_checkpoint = filter_functions(
modified_functions,
tests_root=Path("tests"),
ignore_paths=[],
project_root=temp_dir,
module_root=temp_dir,
previous_checkpoint_functions={qualified_name_for_checkpoint: {"status": "optimized"}, other_qualified_name_for_checkpoint: {}}
previous_checkpoint_functions={
qualified_name_for_checkpoint: {"status": "optimized"},
other_qualified_name_for_checkpoint: {},
},
)
assert filtered_checkpoint.get(original_file_path)
assert count_checkpoint == 1
@ -589,4 +603,4 @@ def test_function_in_tests_dir():
assert "propagate_attributes" not in remaining_functions
assert "vanilla_function" not in remaining_functions
files_and_funcs = get_all_files_and_functions(module_root_path=temp_dir, ignore_paths=[])
assert len(files_and_funcs) == 6
assert len(files_and_funcs) == 6

View file

@ -1,10 +1,9 @@
import pytest
from pathlib import Path
from unittest.mock import patch
import pytest
from codeflash.benchmarking.function_ranker import FunctionRanker
from codeflash.discovery.functions_to_optimize import FunctionToOptimize, find_all_functions_in_file
from codeflash.models.models import FunctionParent
from codeflash.discovery.functions_to_optimize import find_all_functions_in_file
@pytest.fixture
@ -36,25 +35,32 @@ def test_function_ranker_initialization(trace_file):
def test_load_function_stats(function_ranker):
assert len(function_ranker._function_stats) > 0
# Check that funcA is loaded with expected structure
func_a_key = None
for key, stats in function_ranker._function_stats.items():
if stats["function_name"] == "funcA":
func_a_key = key
break
assert func_a_key is not None
func_a_stats = function_ranker._function_stats[func_a_key]
# Verify funcA stats structure
expected_keys = {
"filename", "function_name", "qualified_name", "class_name",
"line_number", "call_count", "own_time_ns", "cumulative_time_ns",
"time_in_callees_ns", "addressable_time_ns"
"filename",
"function_name",
"qualified_name",
"class_name",
"line_number",
"call_count",
"own_time_ns",
"cumulative_time_ns",
"time_in_callees_ns",
"addressable_time_ns",
}
assert set(func_a_stats.keys()) == expected_keys
# Verify funcA specific values
assert func_a_stats["function_name"] == "funcA"
assert func_a_stats["call_count"] == 1
@ -68,7 +74,7 @@ def test_get_function_addressable_time(function_ranker, workload_functions):
if func.function_name == "funcA":
func_a = func
break
assert func_a is not None
addressable_time = function_ranker.get_function_addressable_time(func_a)
@ -79,15 +85,15 @@ def test_get_function_addressable_time(function_ranker, workload_functions):
def test_rank_functions(function_ranker, workload_functions):
ranked_functions = function_ranker.rank_functions(workload_functions)
# Should filter out functions below importance threshold and sort by addressable time
assert len(ranked_functions) <= len(workload_functions)
assert len(ranked_functions) > 0 # At least some functions should pass the threshold
# funcA should pass the importance threshold
func_a_in_results = any(f.function_name == "funcA" for f in ranked_functions)
assert func_a_in_results
# Verify functions are sorted by addressable time in descending order
for i in range(len(ranked_functions) - 1):
current_time = function_ranker.get_function_addressable_time(ranked_functions[i])
@ -101,10 +107,10 @@ def test_get_function_stats_summary(function_ranker, workload_functions):
if func.function_name == "funcA":
func_a = func
break
assert func_a is not None
stats = function_ranker.get_function_stats_summary(func_a)
assert stats is not None
assert stats["function_name"] == "funcA"
assert stats["own_time_ns"] == 153000
@ -112,24 +118,19 @@ def test_get_function_stats_summary(function_ranker, workload_functions):
assert stats["addressable_time_ns"] == 1324000
def test_importance_calculation(function_ranker):
total_program_time = sum(
s["own_time_ns"] for s in function_ranker._function_stats.values()
if s.get("own_time_ns", 0) > 0
s["own_time_ns"] for s in function_ranker._function_stats.values() if s.get("own_time_ns", 0) > 0
)
func_a_stats = None
for stats in function_ranker._function_stats.values():
if stats["function_name"] == "funcA":
func_a_stats = stats
break
assert func_a_stats is not None
importance = func_a_stats["own_time_ns"] / total_program_time
# funcA importance should be approximately 1.9% (153000/7958000)
assert abs(importance - 0.019) < 0.01

View file

@ -1,10 +1,12 @@
import tempfile
from pathlib import Path
import pytest
from codeflash.code_utils.code_extractor import get_code
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.models.models import FunctionParent
import pytest
from pathlib import Path
@pytest.fixture
def temp_dir():
@ -276,4 +278,4 @@ class CustomDataClass:
[FunctionToOptimize("name", f.name, [FunctionParent("CustomDataClass", "ClassDef")])]
)
assert new_code is None
assert contextual_dunder_methods == set()
assert contextual_dunder_methods == set()

View file

@ -242,8 +242,8 @@ class _PersistentCache(Generic[_P, _R, _CacheBackendT]):
code_context = ctx_result.unwrap()
assert code_context.helper_functions[0].qualified_name == "AbstractCacheBackend.get_cache_or_call"
assert (
code_context.testgen_context.flat
== f'''# file: {file_path.relative_to(project_root_path)}
code_context.testgen_context.flat
== f'''# file: {file_path.relative_to(project_root_path)}
_P = ParamSpec("_P")
_KEY_T = TypeVar("_KEY_T")
_STORE_T = TypeVar("_STORE_T")
@ -412,8 +412,8 @@ def test_bubble_sort_deps() -> None:
pytest.fail()
code_context = ctx_result.unwrap()
assert (
code_context.testgen_context.flat
== f"""{get_code_block_splitter(Path("code_to_optimize/bubble_sort_dep1_helper.py"))}
code_context.testgen_context.flat
== f"""{get_code_block_splitter(Path("code_to_optimize/bubble_sort_dep1_helper.py"))}
def dep1_comparer(arr, j: int) -> bool:
return arr[j] > arr[j + 1]
@ -438,7 +438,7 @@ def sorter_deps(arr):
)
assert len(code_context.helper_functions) == 2
assert (
code_context.helper_functions[0].fully_qualified_name
== "code_to_optimize.bubble_sort_dep1_helper.dep1_comparer"
code_context.helper_functions[0].fully_qualified_name
== "code_to_optimize.bubble_sort_dep1_helper.dep1_comparer"
)
assert code_context.helper_functions[1].fully_qualified_name == "code_to_optimize.bubble_sort_dep2_swap.dep2_swap"
assert code_context.helper_functions[1].fully_qualified_name == "code_to_optimize.bubble_sort_dep2_swap.dep2_swap"

View file

@ -73,7 +73,9 @@ def test_dunder_methods_remove_docstring() -> None:
return f"Value: {self.x}"
"""
output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_ONLY, {"TestClass.target_method"}, set(), remove_docstrings=True)
output = parse_code_and_prune_cst(
dedent(code), CodeContextType.READ_ONLY, {"TestClass.target_method"}, set(), remove_docstrings=True
)
assert dedent(expected).strip() == output.strip()
@ -98,7 +100,9 @@ def test_class_remove_docstring() -> None:
return f"Value: {self.x}"
"""
output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_ONLY, {"TestClass.target_method"}, set(), remove_docstrings=True)
output = parse_code_and_prune_cst(
dedent(code), CodeContextType.READ_ONLY, {"TestClass.target_method"}, set(), remove_docstrings=True
)
assert dedent(expected).strip() == output.strip()
@ -125,7 +129,9 @@ def test_mixed_remove_docstring() -> None:
return f"Value: {self.x}"
"""
output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_ONLY, {"TestClass.target_method"}, set(), remove_docstrings=True)
output = parse_code_and_prune_cst(
dedent(code), CodeContextType.READ_ONLY, {"TestClass.target_method"}, set(), remove_docstrings=True
)
assert dedent(expected).strip() == output.strip()
@ -204,7 +210,9 @@ def test_multiple_top_level_targets() -> None:
expected = """
"""
output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_ONLY, {"TestClass.target1", "TestClass.target2"}, set())
output = parse_code_and_prune_cst(
dedent(code), CodeContextType.READ_ONLY, {"TestClass.target1", "TestClass.target2"}, set()
)
assert dedent(expected).strip() == output.strip()
@ -665,7 +673,9 @@ def test_simplified_complete_implementation() -> None:
pass
"""
output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_ONLY, {"DataProcessor.target_method", "ResultHandler.target_method"}, set())
output = parse_code_and_prune_cst(
dedent(code), CodeContextType.READ_ONLY, {"DataProcessor.target_method", "ResultHandler.target_method"}, set()
)
assert dedent(expected).strip() == output.strip()
@ -753,6 +763,10 @@ def test_simplified_complete_implementation_no_docstring() -> None:
"""
output = parse_code_and_prune_cst(
dedent(code), CodeContextType.READ_ONLY, {"DataProcessor.target_method", "ResultHandler.target_method"}, set(), remove_docstrings=True
dedent(code),
CodeContextType.READ_ONLY,
{"DataProcessor.target_method", "ResultHandler.target_method"},
set(),
remove_docstrings=True,
)
assert dedent(expected).strip() == output.strip()

View file

@ -1,7 +1,8 @@
from textwrap import dedent
import pytest
from codeflash.context.code_context_extractor import parse_code_and_prune_cst
from codeflash.context.code_context_extractor import parse_code_and_prune_cst
from codeflash.models.models import CodeContextType
@ -12,7 +13,7 @@ def test_simple_function() -> None:
y = 2
return x + y
"""
result = parse_code_and_prune_cst(dedent(code),CodeContextType.READ_WRITABLE, {"target_function"})
result = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_WRITABLE, {"target_function"})
expected = dedent("""
def target_function():
@ -55,7 +56,7 @@ def test_class_with_attributes() -> None:
def other_method(self):
print("this should be excluded")
"""
result = parse_code_and_prune_cst(dedent(code),CodeContextType.READ_WRITABLE, {"MyClass.target_method"})
result = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_WRITABLE, {"MyClass.target_method"})
expected = dedent("""
class MyClass:
@ -79,7 +80,7 @@ def test_basic_class_structure() -> None:
def not_findable(self):
return 42
"""
result = parse_code_and_prune_cst(dedent(code),CodeContextType.READ_WRITABLE, {"Outer.target_method"})
result = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_WRITABLE, {"Outer.target_method"})
expected = dedent("""
class Outer:
@ -99,7 +100,7 @@ def test_top_level_targets() -> None:
def target_function():
return 42
"""
result = parse_code_and_prune_cst(dedent(code),CodeContextType.READ_WRITABLE, {"target_function"})
result = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_WRITABLE, {"target_function"})
expected = dedent("""
def target_function():
@ -122,7 +123,7 @@ def test_multiple_top_level_classes() -> None:
def process(self):
return "C"
"""
result = parse_code_and_prune_cst(dedent(code),CodeContextType.READ_WRITABLE, {"ClassA.process", "ClassC.process"})
result = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_WRITABLE, {"ClassA.process", "ClassC.process"})
expected = dedent("""
class ClassA:
@ -147,7 +148,7 @@ def test_try_except_structure() -> None:
def handle_error(self):
print("error")
"""
result = parse_code_and_prune_cst(dedent(code),CodeContextType.READ_WRITABLE, {"TargetClass.target_method"})
result = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_WRITABLE, {"TargetClass.target_method"})
expected = dedent("""
try:
@ -174,7 +175,7 @@ def test_init_method() -> None:
def target_method(self):
return f"Value: {self.x}"
"""
result = parse_code_and_prune_cst(dedent(code),CodeContextType.READ_WRITABLE, {"MyClass.target_method"})
result = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_WRITABLE, {"MyClass.target_method"})
expected = dedent("""
class MyClass:
@ -186,6 +187,7 @@ def test_init_method() -> None:
""")
assert result.strip() == expected.strip()
def test_dunder_method() -> None:
code = """
class MyClass:
@ -198,7 +200,7 @@ def test_dunder_method() -> None:
def target_method(self):
return f"Value: {self.x}"
"""
result = parse_code_and_prune_cst(dedent(code),CodeContextType.READ_WRITABLE, {"MyClass.target_method"})
result = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_WRITABLE, {"MyClass.target_method"})
expected = dedent("""
class MyClass:
@ -208,6 +210,7 @@ def test_dunder_method() -> None:
""")
assert result.strip() == expected.strip()
def test_no_targets_found() -> None:
code = """
class MyClass:
@ -218,7 +221,7 @@ def test_no_targets_found() -> None:
def target(self):
pass
"""
result = parse_code_and_prune_cst(dedent(code),CodeContextType.READ_WRITABLE, {"MyClass.Inner.target"})
result = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_WRITABLE, {"MyClass.Inner.target"})
expected = dedent("""
class MyClass:
def method(self):
@ -239,7 +242,7 @@ def test_no_targets_found_raises_for_nonexistent() -> None:
pass
"""
with pytest.raises(ValueError, match="No target functions found in the provided code"):
parse_code_and_prune_cst(dedent(code),CodeContextType.READ_WRITABLE, {"NonExistent.target"})
parse_code_and_prune_cst(dedent(code), CodeContextType.READ_WRITABLE, {"NonExistent.target"})
def test_module_var() -> None:
@ -263,7 +266,5 @@ def test_module_var() -> None:
var2 = "test"
"""
output = parse_code_and_prune_cst(dedent(code),CodeContextType.READ_WRITABLE, {"target_function"})
output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_WRITABLE, {"target_function"})
assert dedent(expected).strip() == output.strip()

View file

@ -2,8 +2,9 @@ from textwrap import dedent
import pytest
from codeflash.models.models import CodeContextType
from codeflash.context.code_context_extractor import parse_code_and_prune_cst
from codeflash.models.models import CodeContextType
def test_simple_function() -> None:
code = """
@ -22,6 +23,7 @@ def test_simple_function() -> None:
"""
assert dedent(expected).strip() == result.strip()
def test_basic_class() -> None:
code = """
class TestClass:
@ -45,6 +47,7 @@ def test_basic_class() -> None:
output = parse_code_and_prune_cst(dedent(code), CodeContextType.TESTGEN, {"TestClass.target_method"}, set())
assert dedent(expected).strip() == output.strip()
def test_dunder_methods() -> None:
code = """
class TestClass:
@ -102,7 +105,9 @@ def test_dunder_methods_remove_docstring() -> None:
print("include me")
"""
output = parse_code_and_prune_cst(dedent(code), CodeContextType.TESTGEN, {"TestClass.target_method"}, set(), remove_docstrings=True)
output = parse_code_and_prune_cst(
dedent(code), CodeContextType.TESTGEN, {"TestClass.target_method"}, set(), remove_docstrings=True
)
assert dedent(expected).strip() == output.strip()
@ -132,7 +137,9 @@ def test_class_remove_docstring() -> None:
print("include me")
"""
output = parse_code_and_prune_cst(dedent(code), CodeContextType.TESTGEN, {"TestClass.target_method"}, set(), remove_docstrings=True)
output = parse_code_and_prune_cst(
dedent(code), CodeContextType.TESTGEN, {"TestClass.target_method"}, set(), remove_docstrings=True
)
assert dedent(expected).strip() == output.strip()
@ -152,6 +159,7 @@ def test_target_in_nested_class() -> None:
with pytest.raises(ValueError, match="No target functions found in the provided code"):
parse_code_and_prune_cst(dedent(code), CodeContextType.TESTGEN, {"Outer.Inner.target_method"}, set())
def test_method_signatures() -> None:
code = """
class TestClass:
@ -175,6 +183,8 @@ def test_method_signatures() -> None:
output = parse_code_and_prune_cst(dedent(code), CodeContextType.TESTGEN, {"TestClass.target_method"}, set())
assert dedent(expected).strip() == output.strip()
def test_multiple_top_level_targets() -> None:
code = """
class TestClass:
@ -203,7 +213,9 @@ def test_multiple_top_level_targets() -> None:
self.x = 42
"""
output = parse_code_and_prune_cst(dedent(code), CodeContextType.TESTGEN, {"TestClass.target1", "TestClass.target2"}, set())
output = parse_code_and_prune_cst(
dedent(code), CodeContextType.TESTGEN, {"TestClass.target1", "TestClass.target2"}, set()
)
assert dedent(expected).strip() == output.strip()
@ -229,6 +241,7 @@ def test_class_annotations() -> None:
output = parse_code_and_prune_cst(dedent(code), CodeContextType.TESTGEN, {"TestClass.target_method"}, set())
assert dedent(expected).strip() == output.strip()
def test_class_annotations_if() -> None:
code = """
if True:
@ -345,6 +358,7 @@ def test_module_var() -> None:
output = parse_code_and_prune_cst(dedent(code), CodeContextType.TESTGEN, {"target_function"}, set())
assert dedent(expected).strip() == output.strip()
def test_module_var_if() -> None:
code = """
def target_function(self) -> None:
@ -374,6 +388,7 @@ def test_module_var_if() -> None:
output = parse_code_and_prune_cst(dedent(code), CodeContextType.TESTGEN, {"target_function"}, set())
assert dedent(expected).strip() == output.strip()
def test_multiple_classes() -> None:
code = """
class ClassA:
@ -399,7 +414,9 @@ def test_multiple_classes() -> None:
return "C"
"""
output = parse_code_and_prune_cst(dedent(code), CodeContextType.TESTGEN, {"ClassA.process", "ClassC.process"}, set())
output = parse_code_and_prune_cst(
dedent(code), CodeContextType.TESTGEN, {"ClassA.process", "ClassC.process"}, set()
)
assert dedent(expected).strip() == output.strip()
@ -518,6 +535,7 @@ def test_async_with_try_except() -> None:
output = parse_code_and_prune_cst(dedent(code), CodeContextType.TESTGEN, {"TestClass.target_method"}, set())
assert dedent(expected).strip() == output.strip()
def test_simplified_complete_implementation() -> None:
code = """
class DataProcessor:
@ -639,7 +657,9 @@ def test_simplified_complete_implementation() -> None:
raise RuntimeError(f"Failed to initialize: {self.error}")
"""
output = parse_code_and_prune_cst(dedent(code), CodeContextType.TESTGEN, {"DataProcessor.target_method", "ResultHandler.target_method"}, set())
output = parse_code_and_prune_cst(
dedent(code), CodeContextType.TESTGEN, {"DataProcessor.target_method", "ResultHandler.target_method"}, set()
)
assert dedent(expected).strip() == output.strip()
@ -740,6 +760,10 @@ def test_simplified_complete_implementation_no_docstring() -> None:
"""
output = parse_code_and_prune_cst(
dedent(code), CodeContextType.TESTGEN, {"DataProcessor.target_method", "ResultHandler.target_method"}, set(), remove_docstrings=True
dedent(code),
CodeContextType.TESTGEN,
{"DataProcessor.target_method", "ResultHandler.target_method"},
set(),
remove_docstrings=True,
)
assert dedent(expected).strip() == output.strip()

View file

@ -1,7 +1,7 @@
from codeflash.code_utils.time_utils import humanize_runtime, format_time
from codeflash.code_utils.time_utils import format_perf
import pytest
from codeflash.code_utils.time_utils import format_perf, format_time, humanize_runtime
def test_humanize_runtime():
assert humanize_runtime(0) == "0.00 nanoseconds"
@ -140,19 +140,22 @@ class TestFormatTime:
assert format_time(3_600_000_000_000) == "3600s" # 1 hour
assert format_time(86_400_000_000_000) == "86400s" # 1 day
@pytest.mark.parametrize("nanoseconds,expected", [
(0, "0ns"),
(42, "42ns"),
(1_500, "1.50μs"),
(25_000, "25.0μs"),
(150_000, "150μs"),
(2_500_000, "2.50ms"),
(45_000_000, "45.0ms"),
(200_000_000, "200ms"),
(3_500_000_000, "3.50s"),
(75_000_000_000, "75.0s"),
(300_000_000_000, "300s"),
])
@pytest.mark.parametrize(
"nanoseconds,expected",
[
(0, "0ns"),
(42, "42ns"),
(1_500, "1.50μs"),
(25_000, "25.0μs"),
(150_000, "150μs"),
(2_500_000, "2.50ms"),
(45_000_000, "45.0ms"),
(200_000_000, "200ms"),
(3_500_000_000, "3.50s"),
(75_000_000_000, "75.0s"),
(300_000_000_000, "300s"),
],
)
def test_parametrized_examples(self, nanoseconds, expected):
"""Parametrized test with various input/output combinations."""
assert format_time(nanoseconds) == expected
@ -272,4 +275,4 @@ class TestFormatPerf:
assert format_perf(100.4) == "100"
assert format_perf(10.54) == "10.5"
assert format_perf(1.554) == "1.55"
assert format_perf(0.1554) == "0.155"
assert format_perf(0.1554) == "0.155"

View file

@ -9,8 +9,6 @@ from __future__ import annotations
import re
from pathlib import Path
import pytest
from codeflash.code_utils.instrument_existing_tests import (
detect_frameworks_from_code,
inject_profiling_into_existing_test,
@ -28,19 +26,11 @@ def normalize_instrumented_code(code: str) -> str:
generates double-quoted f-strings for compatibility with older versions).
"""
# Normalize database path
code = re.sub(
r"sqlite3\.connect\(f'[^']+'",
"sqlite3.connect(f'{CODEFLASH_DB_PATH}'",
code
)
code = re.sub(r"sqlite3\.connect\(f'[^']+'", "sqlite3.connect(f'{CODEFLASH_DB_PATH}'", code)
# Normalize f-string that contains the test_stdout_tag assignment
# This specific f-string has internal single quotes, so libcst uses double quotes
# on Python < 3.12, but single quotes on Python 3.12+
code = re.sub(
r'test_stdout_tag = f"([^"]+)"',
r"test_stdout_tag = f'\1'",
code
)
code = re.sub(r'test_stdout_tag = f"([^"]+)"', r"test_stdout_tag = f'\1'", code)
return code
@ -1112,11 +1102,7 @@ def test_my_function():
test_file = tmp_path / "test_example.py"
test_file.write_text(code)
func = FunctionToOptimize(
function_name="my_function",
parents=[],
file_path=Path("mymodule.py"),
)
func = FunctionToOptimize(function_name="my_function", parents=[], file_path=Path("mymodule.py"))
success, instrumented_code = inject_profiling_into_existing_test(
test_path=test_file,
@ -1142,11 +1128,7 @@ def test_my_function():
test_file = tmp_path / "test_example.py"
test_file.write_text(code)
func = FunctionToOptimize(
function_name="my_function",
parents=[],
file_path=Path("mymodule.py"),
)
func = FunctionToOptimize(function_name="my_function", parents=[], file_path=Path("mymodule.py"))
success, instrumented_code = inject_profiling_into_existing_test(
test_path=test_file,
@ -1172,11 +1154,7 @@ def test_my_function():
test_file = tmp_path / "test_example.py"
test_file.write_text(code)
func = FunctionToOptimize(
function_name="my_function",
parents=[],
file_path=Path("mymodule.py"),
)
func = FunctionToOptimize(function_name="my_function", parents=[], file_path=Path("mymodule.py"))
success, instrumented_code = inject_profiling_into_existing_test(
test_path=test_file,
@ -1202,11 +1180,7 @@ def test_my_function():
test_file = tmp_path / "test_example.py"
test_file.write_text(code)
func = FunctionToOptimize(
function_name="my_function",
parents=[],
file_path=Path("mymodule.py"),
)
func = FunctionToOptimize(function_name="my_function", parents=[], file_path=Path("mymodule.py"))
success, instrumented_code = inject_profiling_into_existing_test(
test_path=test_file,
@ -1232,11 +1206,7 @@ def test_my_function():
test_file = tmp_path / "test_example.py"
test_file.write_text(code)
func = FunctionToOptimize(
function_name="my_function",
parents=[],
file_path=Path("mymodule.py"),
)
func = FunctionToOptimize(function_name="my_function", parents=[], file_path=Path("mymodule.py"))
success, instrumented_code = inject_profiling_into_existing_test(
test_path=test_file,
@ -1262,11 +1232,7 @@ def test_my_function():
test_file = tmp_path / "test_example.py"
test_file.write_text(code)
func = FunctionToOptimize(
function_name="my_function",
parents=[],
file_path=Path("mymodule.py"),
)
func = FunctionToOptimize(function_name="my_function", parents=[], file_path=Path("mymodule.py"))
success, instrumented_code = inject_profiling_into_existing_test(
test_path=test_file,
@ -1292,11 +1258,7 @@ def test_my_function():
test_file = tmp_path / "test_example.py"
test_file.write_text(code)
func = FunctionToOptimize(
function_name="my_function",
parents=[],
file_path=Path("mymodule.py"),
)
func = FunctionToOptimize(function_name="my_function", parents=[], file_path=Path("mymodule.py"))
success, instrumented_code = inject_profiling_into_existing_test(
test_path=test_file,
@ -1322,11 +1284,7 @@ def test_my_function():
test_file = tmp_path / "test_example.py"
test_file.write_text(code)
func = FunctionToOptimize(
function_name="my_function",
parents=[],
file_path=Path("mymodule.py"),
)
func = FunctionToOptimize(function_name="my_function", parents=[], file_path=Path("mymodule.py"))
success, instrumented_code = inject_profiling_into_existing_test(
test_path=test_file,
@ -1353,11 +1311,7 @@ def test_my_function():
test_file = tmp_path / "test_example.py"
test_file.write_text(code)
func = FunctionToOptimize(
function_name="my_function",
parents=[],
file_path=Path("mymodule.py"),
)
func = FunctionToOptimize(function_name="my_function", parents=[], file_path=Path("mymodule.py"))
success, instrumented_code = inject_profiling_into_existing_test(
test_path=test_file,
@ -1385,11 +1339,7 @@ def test_my_function():
test_file = tmp_path / "test_example.py"
test_file.write_text(code)
func = FunctionToOptimize(
function_name="my_function",
parents=[],
file_path=Path("mymodule.py"),
)
func = FunctionToOptimize(function_name="my_function", parents=[], file_path=Path("mymodule.py"))
success, instrumented_code = inject_profiling_into_existing_test(
test_path=test_file,
@ -1423,11 +1373,7 @@ def test_my_function():
test_file = tmp_path / "test_example.py"
test_file.write_text(code)
func = FunctionToOptimize(
function_name="my_function",
parents=[],
file_path=Path("mymodule.py"),
)
func = FunctionToOptimize(function_name="my_function", parents=[], file_path=Path("mymodule.py"))
success, instrumented_code = inject_profiling_into_existing_test(
test_path=test_file,
@ -1453,11 +1399,7 @@ def test_my_function():
test_file = tmp_path / "test_example.py"
test_file.write_text(code)
func = FunctionToOptimize(
function_name="my_function",
parents=[],
file_path=Path("mymodule.py"),
)
func = FunctionToOptimize(function_name="my_function", parents=[], file_path=Path("mymodule.py"))
success, instrumented_code = inject_profiling_into_existing_test(
test_path=test_file,
@ -1483,11 +1425,7 @@ def test_my_function():
test_file = tmp_path / "test_example.py"
test_file.write_text(code)
func = FunctionToOptimize(
function_name="my_function",
parents=[],
file_path=Path("mymodule.py"),
)
func = FunctionToOptimize(function_name="my_function", parents=[], file_path=Path("mymodule.py"))
success, instrumented_code = inject_profiling_into_existing_test(
test_path=test_file,
@ -1513,11 +1451,7 @@ def test_my_function():
test_file = tmp_path / "test_example.py"
test_file.write_text(code)
func = FunctionToOptimize(
function_name="my_function",
parents=[],
file_path=Path("mymodule.py"),
)
func = FunctionToOptimize(function_name="my_function", parents=[], file_path=Path("mymodule.py"))
success, instrumented_code = inject_profiling_into_existing_test(
test_path=test_file,
@ -1545,11 +1479,7 @@ def test_my_function():
test_file = tmp_path / "test_example.py"
test_file.write_text(code)
func = FunctionToOptimize(
function_name="my_function",
parents=[],
file_path=Path("mymodule.py"),
)
func = FunctionToOptimize(function_name="my_function", parents=[], file_path=Path("mymodule.py"))
success, instrumented_code = inject_profiling_into_existing_test(
test_path=test_file,

View file

@ -116,11 +116,7 @@ def test_sort():
func = FunctionToOptimize(function_name="sorter", parents=[], file_path=Path(fto_path))
os.chdir(run_cwd)
success, new_test = inject_profiling_into_existing_test(
test_path,
[CodePosition(6, 13), CodePosition(10, 13)],
func,
project_root_path,
mode=TestingMode.BEHAVIOR,
test_path, [CodePosition(6, 13), CodePosition(10, 13)], func, project_root_path, mode=TestingMode.BEHAVIOR
)
os.chdir(original_cwd)
assert success
@ -552,7 +548,9 @@ def test_sort():
fto_path = (Path(__file__).parent.resolve() / "../code_to_optimize/bubble_sort_method.py").resolve()
original_code = fto_path.read_text("utf-8")
fto = FunctionToOptimize(
function_name="sorter_classmethod", parents=[FunctionParent(name="BubbleSorter", type="ClassDef")], file_path=Path(fto_path)
function_name="sorter_classmethod",
parents=[FunctionParent(name="BubbleSorter", type="ClassDef")],
file_path=Path(fto_path),
)
with tempfile.TemporaryDirectory() as tmpdirname:
tmp_test_path = Path(tmpdirname) / "test_classmethod_behavior_results_temp.py"
@ -646,8 +644,11 @@ def test_sort():
)
assert test_results[1].runtime > 0
assert test_results[1].did_pass
assert test_results[1].stdout == """codeflash stdout : BubbleSorter.sorter_classmethod() called
assert (
test_results[1].stdout
== """codeflash stdout : BubbleSorter.sorter_classmethod() called
"""
)
results2, _ = func_optimizer.run_and_parse_tests(
testing_type=TestingMode.BEHAVIOR,
@ -718,7 +719,9 @@ def test_sort():
fto_path = (Path(__file__).parent.resolve() / "../code_to_optimize/bubble_sort_method.py").resolve()
original_code = fto_path.read_text("utf-8")
fto = FunctionToOptimize(
function_name="sorter_staticmethod", parents=[FunctionParent(name="BubbleSorter", type="ClassDef")], file_path=Path(fto_path)
function_name="sorter_staticmethod",
parents=[FunctionParent(name="BubbleSorter", type="ClassDef")],
file_path=Path(fto_path),
)
with tempfile.TemporaryDirectory() as tmpdirname:
tmp_test_path = Path(tmpdirname) / "test_staticmethod_behavior_results_temp.py"
@ -812,8 +815,11 @@ def test_sort():
)
assert test_results[1].runtime > 0
assert test_results[1].did_pass
assert test_results[1].stdout == """codeflash stdout : BubbleSorter.sorter_staticmethod() called
assert (
test_results[1].stdout
== """codeflash stdout : BubbleSorter.sorter_staticmethod() called
"""
)
results2, _ = func_optimizer.run_and_parse_tests(
testing_type=TestingMode.BEHAVIOR,
@ -831,4 +837,4 @@ def test_sort():
finally:
fto_path.write_text(original_code, "utf-8")
test_path.unlink(missing_ok=True)
test_path_perf.unlink(missing_ok=True)
test_path_perf.unlink(missing_ok=True)

View file

@ -1,8 +1,7 @@
import tempfile
from pathlib import Path
import uuid
import os
import sys
import tempfile
from pathlib import Path
import pytest
@ -81,9 +80,7 @@ async def async_function(x: int, y: int) -> int:
test_file = temp_dir / "test_async.py"
test_file.write_text(async_function_code)
func = FunctionToOptimize(
function_name="async_function", file_path=test_file, parents=[], is_async=True
)
func = FunctionToOptimize(function_name="async_function", file_path=test_file, parents=[], is_async=True)
decorator_added = add_async_decorator_to_function(test_file, func, TestingMode.BEHAVIOR)
@ -120,9 +117,7 @@ async def async_function(x: int, y: int) -> int:
test_file = temp_dir / "test_async.py"
test_file.write_text(async_function_code)
func = FunctionToOptimize(
function_name="async_function", file_path=test_file, parents=[], is_async=True
)
func = FunctionToOptimize(function_name="async_function", file_path=test_file, parents=[], is_async=True)
decorator_added = add_async_decorator_to_function(test_file, func, TestingMode.PERFORMANCE)
@ -160,9 +155,7 @@ async def async_function(x: int, y: int) -> int:
test_file = temp_dir / "test_async.py"
test_file.write_text(async_function_code)
func = FunctionToOptimize(
function_name="async_function", file_path=test_file, parents=[], is_async=True
)
func = FunctionToOptimize(function_name="async_function", file_path=test_file, parents=[], is_async=True)
decorator_added = add_async_decorator_to_function(test_file, func, TestingMode.CONCURRENCY)
@ -243,9 +236,7 @@ async def async_function(x: int, y: int) -> int:
test_file = temp_dir / "test_async.py"
test_file.write_text(already_decorated_code)
func = FunctionToOptimize(
function_name="async_function", file_path=test_file, parents=[], is_async=True
)
func = FunctionToOptimize(function_name="async_function", file_path=test_file, parents=[], is_async=True)
decorator_added = add_async_decorator_to_function(test_file, func, TestingMode.BEHAVIOR)
@ -290,12 +281,10 @@ async def test_async_function():
# First instrument the source module
from codeflash.code_utils.instrument_existing_tests import add_async_decorator_to_function
source_success = add_async_decorator_to_function(
source_file, func, TestingMode.BEHAVIOR
)
source_success = add_async_decorator_to_function(source_file, func, TestingMode.BEHAVIOR)
assert source_success is True
# Verify the file was modified
instrumented_source = source_file.read_text()
assert "@codeflash_behavior_async" in instrumented_source
@ -347,12 +336,10 @@ async def test_async_function():
# First instrument the source module
from codeflash.code_utils.instrument_existing_tests import add_async_decorator_to_function
source_success = add_async_decorator_to_function(
source_file, func, TestingMode.PERFORMANCE
)
source_success = add_async_decorator_to_function(source_file, func, TestingMode.PERFORMANCE)
assert source_success is True
# Verify the file was modified
instrumented_source = source_file.read_text()
assert "@codeflash_performance_async" in instrumented_source
@ -413,12 +400,10 @@ async def test_mixed_functions():
from codeflash.code_utils.instrument_existing_tests import add_async_decorator_to_function
source_success = add_async_decorator_to_function(
source_file, async_func, TestingMode.BEHAVIOR
)
source_success = add_async_decorator_to_function(source_file, async_func, TestingMode.BEHAVIOR)
assert source_success
# Verify the file was modified
instrumented_source = source_file.read_text()
assert "@codeflash_behavior_async" in instrumented_source
@ -428,11 +413,7 @@ async def test_mixed_functions():
assert "def sync_function(x: int, y: int) -> int:" in instrumented_source
success, instrumented_test_code = inject_profiling_into_existing_test(
test_file,
[CodePosition(8, 18), CodePosition(11, 19)],
async_func,
temp_dir,
mode=TestingMode.BEHAVIOR,
test_file, [CodePosition(8, 18), CodePosition(11, 19)], async_func, temp_dir, mode=TestingMode.BEHAVIOR
)
# Async functions should not be instrumented at the test level
@ -465,8 +446,7 @@ class OuterClass:
decorator_added = add_async_decorator_to_function(test_file, func, TestingMode.BEHAVIOR)
expected_output = (
"""import asyncio
expected_output = """import asyncio
from codeflash.code_utils.codeflash_wrap_decorator import \\
codeflash_behavior_async
@ -480,7 +460,6 @@ class OuterClass:
await asyncio.sleep(0.001)
return x * 2
"""
)
assert decorator_added
modified_code = test_file.read_text()
@ -510,9 +489,7 @@ async def async_function(x: int, y: int) -> int:
test_file = temp_dir / "test_async.py"
test_file.write_text(decorated_async_code)
func = FunctionToOptimize(
function_name="async_function", file_path=test_file, parents=[], is_async=True
)
func = FunctionToOptimize(function_name="async_function", file_path=test_file, parents=[], is_async=True)
decorator_added = add_async_decorator_to_function(test_file, func, TestingMode.BEHAVIOR)
@ -538,22 +515,16 @@ def sync_function(x: int, y: int) -> int:
test_file = temp_dir / "test_sync.py"
test_file.write_text(sync_function_code)
sync_func = FunctionToOptimize(
function_name="sync_function",
file_path=test_file,
parents=[],
is_async=False,
)
sync_func = FunctionToOptimize(function_name="sync_function", file_path=test_file, parents=[], is_async=False)
decorator_added = add_async_decorator_to_function(
test_file, sync_func, TestingMode.BEHAVIOR
)
decorator_added = add_async_decorator_to_function(test_file, sync_func, TestingMode.BEHAVIOR)
assert not decorator_added
# File should not be modified for sync functions
modified_code = test_file.read_text()
assert modified_code == sync_function_code
@pytest.mark.skipif(sys.platform == "win32", reason="pending support for asyncio on windows")
def test_inject_profiling_async_multiple_calls_same_test(temp_dir):
"""Test that multiple async function calls within the same test function get correctly numbered 0, 1, 2, etc."""
@ -599,12 +570,10 @@ async def test_multiple_calls():
# First instrument the source module with async decorators
from codeflash.code_utils.instrument_existing_tests import add_async_decorator_to_function
source_success = add_async_decorator_to_function(
source_file, func, TestingMode.BEHAVIOR
)
source_success = add_async_decorator_to_function(source_file, func, TestingMode.BEHAVIOR)
assert source_success
# Verify the file was modified
instrumented_source = source_file.read_text()
assert "@codeflash_behavior_async" in instrumented_source
@ -636,18 +605,15 @@ async def test_multiple_calls():
line_id_1_count = instrumented_test_code.count("os.environ['CODEFLASH_CURRENT_LINE_ID'] = '1'")
line_id_2_count = instrumented_test_code.count("os.environ['CODEFLASH_CURRENT_LINE_ID'] = '2'")
assert line_id_0_count == 2, f"Expected 2 occurrences of line_id '0', got {line_id_0_count}"
assert line_id_1_count == 1, f"Expected 1 occurrence of line_id '1', got {line_id_1_count}"
assert line_id_2_count == 1, f"Expected 1 occurrence of line_id '2', got {line_id_2_count}"
@pytest.mark.skipif(sys.platform == "win32", reason="pending support for asyncio on windows")
def test_async_behavior_decorator_return_values_and_test_ids():
"""Test that async behavior decorator correctly captures return values, test IDs, and stores data in database."""
import asyncio
import os
import sqlite3
from pathlib import Path
@ -684,7 +650,7 @@ def test_async_behavior_decorator_return_values_and_test_ids():
from codeflash.code_utils.codeflash_wrap_decorator import get_run_tmp_file
db_path = get_run_tmp_file(Path(f"test_return_values_2.sqlite"))
db_path = get_run_tmp_file(Path("test_return_values_2.sqlite"))
# Verify database exists and has data
assert db_path.exists(), f"Database file not created at {db_path}"
@ -745,7 +711,6 @@ def test_async_behavior_decorator_return_values_and_test_ids():
@pytest.mark.skipif(sys.platform == "win32", reason="pending support for asyncio on windows")
def test_async_decorator_comprehensive_return_values_and_test_ids():
import asyncio
import os
import sqlite3
from pathlib import Path
@ -793,7 +758,7 @@ def test_async_decorator_comprehensive_return_values_and_test_ids():
f"Expected {test_case['expected']}, got {result} for args {test_case['args']}, kwargs {test_case['kwargs']}"
)
db_path = get_run_tmp_file(Path(f"test_return_values_3.sqlite"))
db_path = get_run_tmp_file(Path("test_return_values_3.sqlite"))
assert db_path.exists(), f"Database not created at {db_path}"
con = sqlite3.connect(db_path)
@ -837,7 +802,6 @@ def test_async_decorator_comprehensive_return_values_and_test_ids():
f"Row {i}: Expected iteration_id '{expected_iteration_id}', got '{iteration_id}'"
)
args, kwargs, actual_return_value = pickle.loads(return_value_blob)
expected_args = test_cases[i]["args"]
expected_kwargs = test_cases[i]["kwargs"]

View file

@ -3,8 +3,10 @@ from __future__ import annotations
import tempfile
from pathlib import Path
from codeflash.benchmarking.instrument_codeflash_trace import add_codeflash_decorator_to_code, \
instrument_codeflash_trace_decorator
from codeflash.benchmarking.instrument_codeflash_trace import (
add_codeflash_decorator_to_code,
instrument_codeflash_trace_decorator,
)
from codeflash.discovery.functions_to_optimize import FunctionParent, FunctionToOptimize
@ -15,16 +17,9 @@ def normal_function():
return "Hello, World!"
"""
fto = FunctionToOptimize(
function_name="normal_function",
file_path=Path("dummy_path.py"),
parents=[]
)
fto = FunctionToOptimize(function_name="normal_function", file_path=Path("dummy_path.py"), parents=[])
modified_code = add_codeflash_decorator_to_code(
code=code,
functions_to_optimize=[fto]
)
modified_code = add_codeflash_decorator_to_code(code=code, functions_to_optimize=[fto])
expected_code = """
from codeflash.benchmarking.codeflash_trace import codeflash_trace
@ -47,13 +42,10 @@ class TestClass:
fto = FunctionToOptimize(
function_name="normal_method",
file_path=Path("dummy_path.py"),
parents=[FunctionParent(name="TestClass", type="ClassDef")]
parents=[FunctionParent(name="TestClass", type="ClassDef")],
)
modified_code = add_codeflash_decorator_to_code(
code=code,
functions_to_optimize=[fto]
)
modified_code = add_codeflash_decorator_to_code(code=code, functions_to_optimize=[fto])
expected_code = """
from codeflash.benchmarking.codeflash_trace import codeflash_trace
@ -78,13 +70,10 @@ class TestClass:
fto = FunctionToOptimize(
function_name="class_method",
file_path=Path("dummy_path.py"),
parents=[FunctionParent(name="TestClass", type="ClassDef")]
parents=[FunctionParent(name="TestClass", type="ClassDef")],
)
modified_code = add_codeflash_decorator_to_code(
code=code,
functions_to_optimize=[fto]
)
modified_code = add_codeflash_decorator_to_code(code=code, functions_to_optimize=[fto])
expected_code = """
from codeflash.benchmarking.codeflash_trace import codeflash_trace
@ -110,13 +99,10 @@ class TestClass:
fto = FunctionToOptimize(
function_name="static_method",
file_path=Path("dummy_path.py"),
parents=[FunctionParent(name="TestClass", type="ClassDef")]
parents=[FunctionParent(name="TestClass", type="ClassDef")],
)
modified_code = add_codeflash_decorator_to_code(
code=code,
functions_to_optimize=[fto]
)
modified_code = add_codeflash_decorator_to_code(code=code, functions_to_optimize=[fto])
expected_code = """
from codeflash.benchmarking.codeflash_trace import codeflash_trace
@ -141,13 +127,10 @@ class TestClass:
fto = FunctionToOptimize(
function_name="__init__",
file_path=Path("dummy_path.py"),
parents=[FunctionParent(name="TestClass", type="ClassDef")]
parents=[FunctionParent(name="TestClass", type="ClassDef")],
)
modified_code = add_codeflash_decorator_to_code(
code=code,
functions_to_optimize=[fto]
)
modified_code = add_codeflash_decorator_to_code(code=code, functions_to_optimize=[fto])
expected_code = """
from codeflash.benchmarking.codeflash_trace import codeflash_trace
@ -173,13 +156,10 @@ class TestClass:
fto = FunctionToOptimize(
function_name="property_method",
file_path=Path("dummy_path.py"),
parents=[FunctionParent(name="TestClass", type="ClassDef")]
parents=[FunctionParent(name="TestClass", type="ClassDef")],
)
modified_code = add_codeflash_decorator_to_code(
code=code,
functions_to_optimize=[fto]
)
modified_code = add_codeflash_decorator_to_code(code=code, functions_to_optimize=[fto])
expected_code = """
from codeflash.benchmarking.codeflash_trace import codeflash_trace
@ -209,13 +189,10 @@ class OtherClass:
fto = FunctionToOptimize(
function_name="test_method",
file_path=Path("dummy_path.py"),
parents=[FunctionParent(name="TestClass", type="ClassDef")]
parents=[FunctionParent(name="TestClass", type="ClassDef")],
)
modified_code = add_codeflash_decorator_to_code(
code=code,
functions_to_optimize=[fto]
)
modified_code = add_codeflash_decorator_to_code(code=code, functions_to_optimize=[fto])
expected_code = """
from codeflash.benchmarking.codeflash_trace import codeflash_trace
@ -239,16 +216,9 @@ def existing_function():
return "This exists"
"""
fto = FunctionToOptimize(
function_name="nonexistent_function",
file_path=Path("dummy_path.py"),
parents=[]
)
fto = FunctionToOptimize(function_name="nonexistent_function", file_path=Path("dummy_path.py"), parents=[])
modified_code = add_codeflash_decorator_to_code(
code=code,
functions_to_optimize=[fto]
)
modified_code = add_codeflash_decorator_to_code(code=code, functions_to_optimize=[fto])
# Code should remain unchanged
assert modified_code.strip() == code.strip()
@ -272,27 +242,16 @@ def function_two():
"""
functions_to_optimize = [
FunctionToOptimize(
function_name="function_one",
file_path=Path("dummy_path.py"),
parents=[]
),
FunctionToOptimize(function_name="function_one", file_path=Path("dummy_path.py"), parents=[]),
FunctionToOptimize(
function_name="method_two",
file_path=Path("dummy_path.py"),
parents=[FunctionParent(name="TestClass", type="ClassDef")]
parents=[FunctionParent(name="TestClass", type="ClassDef")],
),
FunctionToOptimize(
function_name="function_two",
file_path=Path("dummy_path.py"),
parents=[]
)
FunctionToOptimize(function_name="function_two", file_path=Path("dummy_path.py"), parents=[]),
]
modified_code = add_codeflash_decorator_to_code(
code=code,
functions_to_optimize=functions_to_optimize
)
modified_code = add_codeflash_decorator_to_code(code=code, functions_to_optimize=functions_to_optimize)
expected_code = """
from codeflash.benchmarking.codeflash_trace import codeflash_trace
@ -339,16 +298,12 @@ def function_two():
# Define functions to optimize
functions_to_optimize = [
FunctionToOptimize(
function_name="function_one",
file_path=test_file_path,
parents=[]
),
FunctionToOptimize(function_name="function_one", file_path=test_file_path, parents=[]),
FunctionToOptimize(
function_name="method_two",
file_path=test_file_path,
parents=[FunctionParent(name="TestClass", type="ClassDef")]
)
parents=[FunctionParent(name="TestClass", type="ClassDef")],
),
]
# Execute the function being tested
@ -399,7 +354,7 @@ class ClassA:
# Create second test Python file
test_file_2_path = Path(temp_dir) / "module_b.py"
test_file_2_content ="""
test_file_2_content = """
def function_b():
return "Function in module B"
@ -412,20 +367,14 @@ class ClassB:
# Define functions to optimize
file_to_funcs_to_optimize = {
test_file_1_path: [
FunctionToOptimize(
function_name="function_a",
file_path=test_file_1_path,
parents=[]
)
],
test_file_1_path: [FunctionToOptimize(function_name="function_a", file_path=test_file_1_path, parents=[])],
test_file_2_path: [
FunctionToOptimize(
function_name="static_method_b",
file_path=test_file_2_path,
parents=[FunctionParent(name="ClassB", type="ClassDef")]
parents=[FunctionParent(name="ClassB", type="ClassDef")],
)
]
],
}
# Execute the function being tested
@ -484,13 +433,10 @@ class OuterClass:
fto = FunctionToOptimize(
function_name="target_method",
file_path=Path("dummy_path.py"),
parents=[FunctionParent(name="OuterClass", type="ClassDef")]
parents=[FunctionParent(name="OuterClass", type="ClassDef")],
)
modified_code = add_codeflash_decorator_to_code(
code=code,
functions_to_optimize=[fto]
)
modified_code = add_codeflash_decorator_to_code(code=code, functions_to_optimize=[fto])
expected_code = """
from codeflash.benchmarking.codeflash_trace import codeflash_trace
@ -520,16 +466,9 @@ def target_function():
return "Hello from target function after nested function"
"""
fto = FunctionToOptimize(
function_name="target_function",
file_path=Path("dummy_path.py"),
parents=[]
)
fto = FunctionToOptimize(function_name="target_function", file_path=Path("dummy_path.py"), parents=[])
modified_code = add_codeflash_decorator_to_code(
code=code,
functions_to_optimize=[fto]
)
modified_code = add_codeflash_decorator_to_code(code=code, functions_to_optimize=[fto])
expected_code = """
from codeflash.benchmarking.codeflash_trace import codeflash_trace
@ -561,11 +500,7 @@ def some_function():
"""
test_file_path.write_text(original_content, encoding="utf-8")
fto = FunctionToOptimize(
function_name="some_function",
file_path=test_file_path,
parents=[]
)
fto = FunctionToOptimize(function_name="some_function", file_path=test_file_path, parents=[])
instrument_codeflash_trace_decorator({test_file_path: [fto]})
@ -587,11 +522,7 @@ def patch_function():
"""
test_file_path.write_text(original_content, encoding="utf-8")
fto = FunctionToOptimize(
function_name="patch_function",
file_path=test_file_path,
parents=[]
)
fto = FunctionToOptimize(function_name="patch_function", file_path=test_file_path, parents=[])
instrument_codeflash_trace_decorator({test_file_path: [fto]})
@ -616,11 +547,7 @@ def trace_func():
"""
test_file_path.write_text(original_content, encoding="utf-8")
fto = FunctionToOptimize(
function_name="trace_func",
file_path=test_file_path,
parents=[]
)
fto = FunctionToOptimize(function_name="trace_func", file_path=test_file_path, parents=[])
instrument_codeflash_trace_decorator({test_file_path: [fto]})
@ -645,11 +572,7 @@ def util_func():
"""
test_file_path.write_text(original_content, encoding="utf-8")
fto = FunctionToOptimize(
function_name="util_func",
file_path=test_file_path,
parents=[]
)
fto = FunctionToOptimize(function_name="util_func", file_path=test_file_path, parents=[])
instrument_codeflash_trace_decorator({test_file_path: [fto]})
@ -673,15 +596,11 @@ def main_func():
"""
test_file_path.write_text(original_content, encoding="utf-8")
fto = FunctionToOptimize(
function_name="main_func",
file_path=test_file_path,
parents=[]
)
fto = FunctionToOptimize(function_name="main_func", file_path=test_file_path, parents=[])
instrument_codeflash_trace_decorator({test_file_path: [fto]})
# File SHOULD be modified
modified_content = test_file_path.read_text(encoding="utf-8")
assert "codeflash_trace" in modified_content
assert "@codeflash_trace" in modified_content
assert "@codeflash_trace" in modified_content

View file

@ -2,8 +2,6 @@ import os
from pathlib import Path
from tempfile import TemporaryDirectory
import pytest
from codeflash.code_utils.line_profile_utils import add_decorator_imports, contains_jit_decorator
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.models.models import CodeOptimizationContext
@ -26,7 +24,7 @@ def test_add_decorator_imports_helper_in_class():
func = FunctionToOptimize(function_name="sort_classmethod", parents=[], file_path=code_path)
func_optimizer = FunctionOptimizer(function_to_optimize=func, test_cfg=test_config)
os.chdir(run_cwd)
#func_optimizer = pass
# func_optimizer = pass
try:
ctx_result = func_optimizer.get_code_optimization_context()
code_context: CodeOptimizationContext = ctx_result.unwrap()
@ -36,8 +34,7 @@ def test_add_decorator_imports_helper_in_class():
with helper_function_path.open(encoding="utf8") as f:
helper_code = f.read()
original_helper_code[helper_function_path] = helper_code
line_profiler_output_file = add_decorator_imports(
func_optimizer.function_to_optimize, code_context)
line_profiler_output_file = add_decorator_imports(func_optimizer.function_to_optimize, code_context)
expected_code_main = f"""from line_profiler import profile as codeflash_line_profile
codeflash_line_profile.enable(output_prefix='{line_profiler_output_file.as_posix()}')
@ -77,11 +74,14 @@ class BubbleSortClass:
assert code_context.helper_functions[0].file_path.read_text("utf-8") == expected_code_helper
finally:
func_optimizer.write_code_and_helpers(
func_optimizer.function_to_optimize_source_code, original_helper_code, func_optimizer.function_to_optimize.file_path
func_optimizer.function_to_optimize_source_code,
original_helper_code,
func_optimizer.function_to_optimize.file_path,
)
def test_add_decorator_imports_helper_in_nested_class():
#Need to invert the assert once the helper detection is fixed
# Need to invert the assert once the helper detection is fixed
code_path = (Path(__file__).parent.resolve() / "../code_to_optimize/bubble_sort_nested_classmethod.py").resolve()
tests_root = Path(__file__).parent.resolve() / "../code_to_optimize/tests/pytest/"
project_root_path = (Path(__file__).parent / "..").resolve()
@ -96,7 +96,7 @@ def test_add_decorator_imports_helper_in_nested_class():
func = FunctionToOptimize(function_name="sort_classmethod", parents=[], file_path=code_path)
func_optimizer = FunctionOptimizer(function_to_optimize=func, test_cfg=test_config)
os.chdir(run_cwd)
#func_optimizer = pass
# func_optimizer = pass
try:
ctx_result = func_optimizer.get_code_optimization_context()
code_context: CodeOptimizationContext = ctx_result.unwrap()
@ -106,8 +106,7 @@ def test_add_decorator_imports_helper_in_nested_class():
with helper_function_path.open(encoding="utf8") as f:
helper_code = f.read()
original_helper_code[helper_function_path] = helper_code
line_profiler_output_file = add_decorator_imports(
func_optimizer.function_to_optimize, code_context)
line_profiler_output_file = add_decorator_imports(func_optimizer.function_to_optimize, code_context)
expected_code_main = f"""from line_profiler import profile as codeflash_line_profile
codeflash_line_profile.enable(output_prefix='{line_profiler_output_file.as_posix()}')
@ -125,9 +124,12 @@ def sort_classmethod(x):
assert code_context.helper_functions[0].qualified_name == "WrapperClass.__init__"
finally:
func_optimizer.write_code_and_helpers(
func_optimizer.function_to_optimize_source_code, original_helper_code, func_optimizer.function_to_optimize.file_path
func_optimizer.function_to_optimize_source_code,
original_helper_code,
func_optimizer.function_to_optimize.file_path,
)
def test_add_decorator_imports_nodeps():
code_path = (Path(__file__).parent.resolve() / "../code_to_optimize/bubble_sort.py").resolve()
tests_root = Path(__file__).parent.resolve() / "../code_to_optimize/tests/pytest/"
@ -143,7 +145,7 @@ def test_add_decorator_imports_nodeps():
func = FunctionToOptimize(function_name="sorter", parents=[], file_path=code_path)
func_optimizer = FunctionOptimizer(function_to_optimize=func, test_cfg=test_config)
os.chdir(run_cwd)
#func_optimizer = pass
# func_optimizer = pass
try:
ctx_result = func_optimizer.get_code_optimization_context()
code_context: CodeOptimizationContext = ctx_result.unwrap()
@ -153,8 +155,7 @@ def test_add_decorator_imports_nodeps():
with helper_function_path.open(encoding="utf8") as f:
helper_code = f.read()
original_helper_code[helper_function_path] = helper_code
line_profiler_output_file = add_decorator_imports(
func_optimizer.function_to_optimize, code_context)
line_profiler_output_file = add_decorator_imports(func_optimizer.function_to_optimize, code_context)
expected_code_main = f"""from line_profiler import profile as codeflash_line_profile
codeflash_line_profile.enable(output_prefix='{line_profiler_output_file.as_posix()}')
@ -174,9 +175,12 @@ def sorter(arr):
assert code_path.read_text("utf-8") == expected_code_main
finally:
func_optimizer.write_code_and_helpers(
func_optimizer.function_to_optimize_source_code, original_helper_code, func_optimizer.function_to_optimize.file_path
func_optimizer.function_to_optimize_source_code,
original_helper_code,
func_optimizer.function_to_optimize.file_path,
)
def test_add_decorator_imports_helper_outside():
code_path = (Path(__file__).parent.resolve() / "../code_to_optimize/bubble_sort_deps.py").resolve()
tests_root = Path(__file__).parent.resolve() / "../code_to_optimize/tests/pytest/"
@ -192,7 +196,7 @@ def test_add_decorator_imports_helper_outside():
func = FunctionToOptimize(function_name="sorter_deps", parents=[], file_path=code_path)
func_optimizer = FunctionOptimizer(function_to_optimize=func, test_cfg=test_config)
os.chdir(run_cwd)
#func_optimizer = pass
# func_optimizer = pass
try:
ctx_result = func_optimizer.get_code_optimization_context()
code_context: CodeOptimizationContext = ctx_result.unwrap()
@ -202,8 +206,7 @@ def test_add_decorator_imports_helper_outside():
with helper_function_path.open(encoding="utf8") as f:
helper_code = f.read()
original_helper_code[helper_function_path] = helper_code
line_profiler_output_file = add_decorator_imports(
func_optimizer.function_to_optimize, code_context)
line_profiler_output_file = add_decorator_imports(func_optimizer.function_to_optimize, code_context)
expected_code_main = f"""from line_profiler import profile as codeflash_line_profile
codeflash_line_profile.enable(output_prefix='{line_profiler_output_file.as_posix()}')
@ -227,7 +230,7 @@ def sorter_deps(arr):
def dep1_comparer(arr, j: int) -> bool:
return arr[j] > arr[j + 1]
"""
expected_code_helper2="""from line_profiler import profile as codeflash_line_profile
expected_code_helper2 = """from line_profiler import profile as codeflash_line_profile
@codeflash_line_profile
@ -241,9 +244,12 @@ def dep2_swap(arr, j):
assert code_context.helper_functions[1].file_path.read_text("utf-8") == expected_code_helper2
finally:
func_optimizer.write_code_and_helpers(
func_optimizer.function_to_optimize_source_code, original_helper_code, func_optimizer.function_to_optimize.file_path
func_optimizer.function_to_optimize_source_code,
original_helper_code,
func_optimizer.function_to_optimize.file_path,
)
def test_add_decorator_imports_helper_in_dunder_class():
code_str = """def sorter(arr):
ans = helper(arr)
@ -253,7 +259,7 @@ class helper:
return arr.sort()"""
code_path = TemporaryDirectory()
code_write_path = Path(code_path.name) / "dunder_class.py"
code_write_path.write_text(code_str,"utf-8")
code_write_path.write_text(code_str, "utf-8")
tests_root = Path(__file__).parent.resolve() / "../code_to_optimize/tests/pytest/"
project_root_path = Path(code_path.name)
run_cwd = Path(__file__).parent.parent.resolve()
@ -267,7 +273,7 @@ class helper:
func = FunctionToOptimize(function_name="sorter", parents=[], file_path=code_write_path)
func_optimizer = FunctionOptimizer(function_to_optimize=func, test_cfg=test_config)
os.chdir(run_cwd)
#func_optimizer = pass
# func_optimizer = pass
try:
ctx_result = func_optimizer.get_code_optimization_context()
code_context: CodeOptimizationContext = ctx_result.unwrap()
@ -277,8 +283,7 @@ class helper:
with helper_function_path.open(encoding="utf8") as f:
helper_code = f.read()
original_helper_code[helper_function_path] = helper_code
line_profiler_output_file = add_decorator_imports(
func_optimizer.function_to_optimize, code_context)
line_profiler_output_file = add_decorator_imports(func_optimizer.function_to_optimize, code_context)
expected_code_main = f"""from line_profiler import profile as codeflash_line_profile
codeflash_line_profile.enable(output_prefix='{line_profiler_output_file.as_posix()}')

View file

@ -3,10 +3,13 @@ from __future__ import annotations
import ast
import math
import os
import platform
import sys
import tempfile
from pathlib import Path
import pytest
from codeflash.code_utils.code_utils import get_run_tmp_file
from codeflash.code_utils.instrument_existing_tests import (
FunctionImportedAsVisitor,
@ -24,8 +27,6 @@ from codeflash.models.models import (
TestsInFile,
TestType,
)
import platform
from codeflash.optimization.function_optimizer import FunctionOptimizer
from codeflash.verification.verification_utils import TestConfig
@ -114,12 +115,15 @@ import pytest"""
if extra_imports:
imports += "\n" + extra_imports
return imports
# create a temporary directory for the test results
@pytest.fixture
def tmp_dir():
with tempfile.TemporaryDirectory() as tmpdirname:
yield Path(tmpdirname)
def test_perfinjector_bubble_sort(tmp_dir) -> None:
code = """import unittest
@ -150,12 +154,12 @@ import dill as pickle"""
# timeout_decorator no longer used since pytest handles timeouts
imports += "\n\nfrom code_to_optimize.bubble_sort import sorter"
wrapper_func = codeflash_wrap_string
test_class_header = "class TestPigLatin(unittest.TestCase):"
test_decorator = "" # pytest-timeout handles timeouts now, not timeout_decorator
expected = imports + "\n\n\n" + wrapper_func + "\n" + test_class_header + "\n\n"
if test_decorator:
expected += test_decorator + "\n"
@ -190,10 +194,7 @@ import dill as pickle"""
run_cwd = Path(__file__).parent.parent.resolve()
os.chdir(run_cwd)
success, new_test = inject_profiling_into_existing_test(
Path(f.name),
[CodePosition(9, 17), CodePosition(13, 17), CodePosition(17, 17)],
func,
Path(f.name).parent,
Path(f.name), [CodePosition(9, 17), CodePosition(13, 17), CodePosition(17, 17)], func, Path(f.name).parent
)
os.chdir(original_cwd)
assert success
@ -397,18 +398,14 @@ def test_sort():
func = FunctionToOptimize(function_name="sorter", parents=[], file_path=code_path)
os.chdir(run_cwd)
success, new_test = inject_profiling_into_existing_test(
test_path,
[CodePosition(8, 14), CodePosition(12, 14)],
func,
project_root_path,
mode=TestingMode.BEHAVIOR,
test_path, [CodePosition(8, 14), CodePosition(12, 14)], func, project_root_path, mode=TestingMode.BEHAVIOR
)
os.chdir(original_cwd)
assert success
assert new_test is not None
assert new_test.replace('"', "'") == expected.format(
module_path="code_to_optimize.tests.pytest.test_perfinjector_bubble_sort_results_temp",
tmp_dir_path=get_run_tmp_file(Path("test_return_values")).as_posix()
tmp_dir_path=get_run_tmp_file(Path("test_return_values")).as_posix(),
).replace('"', "'")
success, new_perf_test = inject_profiling_into_existing_test(
@ -422,7 +419,7 @@ tmp_dir_path=get_run_tmp_file(Path("test_return_values")).as_posix()
assert new_perf_test is not None
assert new_perf_test.replace('"', "'") == expected_perfonly.format(
module_path="code_to_optimize.tests.pytest.test_perfinjector_bubble_sort_results_temp",
tmp_dir_path=get_run_tmp_file(Path("test_return_values")).as_posix()
tmp_dir_path=get_run_tmp_file(Path("test_return_values")).as_posix(),
).replace('"', "'")
with test_path.open("w") as f:
@ -942,7 +939,7 @@ def test_sort_parametrized_loop(input, expected_output):
assert new_test is not None
assert new_test.replace('"', "'") == expected.format(
module_path="code_to_optimize.tests.pytest.test_perfinjector_bubble_sort_parametrized_loop_results_temp",
tmp_dir_path=get_run_tmp_file(Path("test_return_values")).as_posix()
tmp_dir_path=get_run_tmp_file(Path("test_return_values")).as_posix(),
).replace('"', "'")
# Overwrite old test with new instrumented test
@ -951,7 +948,7 @@ tmp_dir_path=get_run_tmp_file(Path("test_return_values")).as_posix()
assert new_test_perf.replace('"', "'") == expected_perf.format(
module_path="code_to_optimize.tests.pytest.test_perfinjector_bubble_sort_parametrized_loop_results_temp",
tmp_dir_path=get_run_tmp_file(Path("test_return_values")).as_posix()
tmp_dir_path=get_run_tmp_file(Path("test_return_values")).as_posix(),
).replace('"', "'")
# Overwrite old test with new instrumented test
@ -1301,12 +1298,12 @@ def test_sort():
assert new_test_behavior is not None
assert new_test_behavior.replace('"', "'") == expected.format(
module_path="code_to_optimize.tests.pytest.test_perfinjector_bubble_sort_loop_results_temp",
tmp_dir_path=get_run_tmp_file(Path("test_return_values")).as_posix()
tmp_dir_path=get_run_tmp_file(Path("test_return_values")).as_posix(),
).replace('"', "'")
assert new_test_perf.replace('"', "'") == expected_perf.format(
module_path="code_to_optimize.tests.pytest.test_perfinjector_bubble_sort_loop_results_temp",
tmp_dir_path=get_run_tmp_file(Path("test_return_values")).as_posix()
tmp_dir_path=get_run_tmp_file(Path("test_return_values")).as_posix(),
).replace('"', "'")
# Overwrite old test with new instrumented test
@ -1477,7 +1474,6 @@ result: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 2
def test_perfinjector_bubble_sort_unittest_results() -> None:
code = """import unittest
from code_to_optimize.bubble_sort import sorter
@ -1499,7 +1495,7 @@ class TestPigLatin(unittest.TestCase):
"""
is_windows = platform.system() == "Windows"
if is_windows:
expected = (
"""import gc
@ -1685,11 +1681,11 @@ class TestPigLatin(unittest.TestCase):
assert new_test_behavior is not None
assert new_test_behavior.replace('"', "'") == expected.format(
module_path="code_to_optimize.tests.unittest.test_perfinjector_bubble_sort_unittest_results_temp",
tmp_dir_path=get_run_tmp_file(Path("test_return_values")).as_posix()
tmp_dir_path=get_run_tmp_file(Path("test_return_values")).as_posix(),
).replace('"', "'")
assert new_test_perf.replace('"', "'") == expected_perf.format(
module_path="code_to_optimize.tests.unittest.test_perfinjector_bubble_sort_unittest_results_temp",
tmp_dir_path=get_run_tmp_file(Path("test_return_values")).as_posix()
tmp_dir_path=get_run_tmp_file(Path("test_return_values")).as_posix(),
).replace('"', "'")
#
# Overwrite old test with new instrumented test
@ -1852,7 +1848,7 @@ class TestPigLatin(unittest.TestCase):
# Build expected behavior output with platform-aware imports
imports_behavior = build_expected_unittest_imports("from parameterized import parameterized")
imports_behavior += "\n\nfrom code_to_optimize.bubble_sort import sorter"
test_decorator_behavior = "" # pytest-timeout handles timeouts now
test_class_behavior = """class TestPigLatin(unittest.TestCase):
@ -1872,7 +1868,7 @@ class TestPigLatin(unittest.TestCase):
self.assertEqual(output, expected_output)
codeflash_con.close()
"""
expected_behavior = imports_behavior + "\n\n\n" + codeflash_wrap_string + "\n" + test_class_behavior
# Build expected perf output with platform-aware imports
imports_perf = """import gc
@ -1882,7 +1878,7 @@ import unittest
"""
# pytest-timeout handles timeouts now, no timeout_decorator needed
imports_perf += "\nfrom parameterized import parameterized\n\nfrom code_to_optimize.bubble_sort import sorter"
test_decorator_perf = "" # pytest-timeout handles timeouts now
test_class_perf = """class TestPigLatin(unittest.TestCase):
@ -1895,7 +1891,7 @@ import unittest
output = codeflash_wrap(sorter, '{module_path}', 'TestPigLatin', 'test_sort', 'sorter', '0', codeflash_loop_index, input)
self.assertEqual(output, expected_output)
"""
expected_perf = imports_perf + "\n\n\n" + codeflash_wrap_perfonly_string + "\n" + test_class_perf
code_path = (Path(__file__).parent.resolve() / "../code_to_optimize/bubble_sort.py").resolve()
test_path = (
@ -1933,13 +1929,13 @@ import unittest
assert new_test_behavior is not None
assert new_test_behavior.replace('"', "'") == expected_behavior.format(
module_path="code_to_optimize.tests.unittest.test_perfinjector_bubble_sort_unittest_parametrized_results_temp",
tmp_dir_path=get_run_tmp_file(Path("test_return_values")).as_posix()
tmp_dir_path=get_run_tmp_file(Path("test_return_values")).as_posix(),
).replace('"', "'")
assert new_test_perf is not None
assert new_test_perf.replace('"', "'") == expected_perf.format(
module_path="code_to_optimize.tests.unittest.test_perfinjector_bubble_sort_unittest_parametrized_results_temp",
tmp_dir_path=get_run_tmp_file(Path("test_return_values")).as_posix()
tmp_dir_path=get_run_tmp_file(Path("test_return_values")).as_posix(),
).replace('"', "'")
#
@ -2099,10 +2095,10 @@ class TestPigLatin(unittest.TestCase):
output = sorter(input)
self.assertEqual(output, expected_output)"""
# Build expected behavior output with platform-aware imports
# Build expected behavior output with platform-aware imports
imports_behavior = build_expected_unittest_imports()
imports_behavior += "\n\nfrom code_to_optimize.bubble_sort import sorter"
test_decorator_behavior = "" # pytest-timeout handles timeouts now
test_class_behavior = """class TestPigLatin(unittest.TestCase):
@ -2137,7 +2133,7 @@ import unittest
"""
# pytest-timeout handles timeouts now, no timeout_decorator needed
imports_perf += "\nfrom code_to_optimize.bubble_sort import sorter"
test_decorator_perf = "" # pytest-timeout handles timeouts now
test_class_perf = """class TestPigLatin(unittest.TestCase):
@ -2154,7 +2150,7 @@ import unittest
output = codeflash_wrap(sorter, '{module_path}', 'TestPigLatin', 'test_sort', 'sorter', '2_2', codeflash_loop_index, input)
self.assertEqual(output, expected_output)
"""
expected_perf = imports_perf + "\n\n\n" + codeflash_wrap_perfonly_string + "\n" + test_class_perf
code_path = (Path(__file__).parent.resolve() / "../code_to_optimize/bubble_sort.py").resolve()
test_path = (
@ -2192,11 +2188,11 @@ import unittest
assert new_test_behavior is not None
assert new_test_behavior.replace('"', "'") == expected_behavior.format(
module_path="code_to_optimize.tests.unittest.test_perfinjector_bubble_sort_unittest_loop_results_temp",
tmp_dir_path=get_run_tmp_file(Path("test_return_values")).as_posix()
tmp_dir_path=get_run_tmp_file(Path("test_return_values")).as_posix(),
).replace('"', "'")
assert new_test_perf.replace('"', "'") == expected_perf.format(
module_path="code_to_optimize.tests.unittest.test_perfinjector_bubble_sort_unittest_loop_results_temp",
tmp_dir_path=get_run_tmp_file(Path("test_return_values")).as_posix()
tmp_dir_path=get_run_tmp_file(Path("test_return_values")).as_posix(),
).replace('"', "'")
#
# # Overwrite old test with new instrumented test
@ -2361,7 +2357,7 @@ class TestPigLatin(unittest.TestCase):
# Build expected behavior output with platform-aware imports
imports_behavior = build_expected_unittest_imports("from parameterized import parameterized")
imports_behavior += "\n\nfrom code_to_optimize.bubble_sort import sorter"
test_decorator_behavior = "" # pytest-timeout handles timeouts now
test_class_behavior = """class TestPigLatin(unittest.TestCase):
@ -2392,7 +2388,7 @@ import unittest
"""
# pytest-timeout handles timeouts now, no timeout_decorator needed
imports_perf += "\nfrom parameterized import parameterized\n\nfrom code_to_optimize.bubble_sort import sorter"
test_decorator_perf = "" # pytest-timeout handles timeouts now
test_class_perf = """class TestPigLatin(unittest.TestCase):
@ -2406,7 +2402,7 @@ import unittest
output = codeflash_wrap(sorter, '{module_path}', 'TestPigLatin', 'test_sort', 'sorter', '0_0', codeflash_loop_index, input)
self.assertEqual(output, expected_output)
"""
expected_perf = imports_perf + "\n\n\n" + codeflash_wrap_perfonly_string + "\n" + test_class_perf
code_path = (Path(__file__).parent.resolve() / "../code_to_optimize/bubble_sort.py").resolve()
test_path = (
@ -2442,11 +2438,11 @@ import unittest
assert new_test_behavior is not None
assert new_test_behavior.replace('"', "'") == expected_behavior.format(
module_path="code_to_optimize.tests.unittest.test_perfinjector_bubble_sort_unittest_parametrized_loop_results_temp",
tmp_dir_path=get_run_tmp_file(Path("test_return_values")).as_posix()
tmp_dir_path=get_run_tmp_file(Path("test_return_values")).as_posix(),
).replace('"', "'")
assert new_test_perf.replace('"', "'") == expected_perf.format(
module_path="code_to_optimize.tests.unittest.test_perfinjector_bubble_sort_unittest_parametrized_loop_results_temp",
tmp_dir_path=get_run_tmp_file(Path("test_return_values")).as_posix()
tmp_dir_path=get_run_tmp_file(Path("test_return_values")).as_posix(),
).replace('"', "'")
#
# Overwrite old test with new instrumented test
@ -2888,7 +2884,7 @@ def test_sort():
assert new_test is not None
assert new_test.replace('"', "'") == expected.format(
module_path="tests.pytest.test_conditional_instrumentation_temp",
tmp_dir_path=get_run_tmp_file(Path("test_return_values")).as_posix()
tmp_dir_path=get_run_tmp_file(Path("test_return_values")).as_posix(),
).replace('"', "'")
finally:
test_path.unlink(missing_ok=True)
@ -2970,7 +2966,7 @@ def test_sort():
assert success
formatted_expected = expected.format(
module_path="tests.pytest.test_perfinjector_bubble_sort_results_temp",
tmp_dir_path=get_run_tmp_file(Path("test_return_values")).as_posix()
tmp_dir_path=get_run_tmp_file(Path("test_return_values")).as_posix(),
)
assert new_test is not None
assert new_test.replace('"', "'") == formatted_expected.replace('"', "'")
@ -3055,7 +3051,7 @@ def test_code_replacement10() -> None:
test_file_path = tmp_path / "test_class_method_instrumentation.py"
test_file_path.write_text(code, encoding="utf-8")
func = FunctionToOptimize(
function_name="get_code_optimization_context",
parents=[FunctionParent("Optimizer", "ClassDef")],
@ -3210,7 +3206,7 @@ import unittest
"""
# pytest-timeout handles timeouts now, no timeout_decorator needed
imports += "\nfrom parameterized import parameterized\n\nfrom code_to_optimize.sleeptime import accurate_sleepfunc"
test_decorator = "" # pytest-timeout handles timeouts now
test_class = """class TestPigLatin(unittest.TestCase):
@ -3222,7 +3218,7 @@ import unittest
codeflash_loop_index = int(os.environ['CODEFLASH_LOOP_INDEX'])
output = codeflash_wrap(accurate_sleepfunc, '{module_path}', 'TestPigLatin', 'test_sleepfunc_sequence_short', 'accurate_sleepfunc', '0', codeflash_loop_index, n)
"""
expected = imports + "\n\n\n" + codeflash_wrap_perfonly_string + "\n" + test_class
code_path = (Path(__file__).parent.resolve() / "../code_to_optimize/sleeptime.py").resolve()
test_path = (

View file

@ -6,6 +6,7 @@ from argparse import Namespace
from pathlib import Path
import isort
from code_to_optimize.bubble_sort_method import BubbleSorter
from codeflash.code_utils.code_utils import get_run_tmp_file
from codeflash.code_utils.formatter import sort_imports
@ -403,7 +404,7 @@ class BubbleSorter:
assert test_results_mutated_attr[0].return_value[0] == {"x": 1}
assert test_results_mutated_attr[0].verification_type == VerificationType.INIT_STATE_FTO
assert test_results_mutated_attr[0].stdout == ""
match,_ = compare_test_results(
match, _ = compare_test_results(
test_results, test_results_mutated_attr
) # The test should fail because the instance attribute was mutated
assert not match
@ -458,7 +459,7 @@ class BubbleSorter:
assert test_results_new_attr[0].stdout == ""
# assert test_results_new_attr[1].return_value[1]["self"].x == 0 TODO: add self as input
# assert test_results_new_attr[1].return_value[1]["self"].y == 2 TODO: add self as input
match,_ = compare_test_results(
match, _ = compare_test_results(
test_results, test_results_new_attr
) # The test should pass because the instance attribute was not mutated, only a new one was added
assert match

View file

@ -2,8 +2,6 @@
from unittest.mock import patch
import pytest
from codeflash.code_utils.code_extractor import is_numerical_code

View file

@ -6,14 +6,7 @@ covering all patterns that might be seen in the wild.
from __future__ import annotations
import pytest
from codeflash.languages.javascript.instrument import (
ExpectCallTransformer,
TestingMode,
instrument_generated_js_test,
transform_expect_calls,
)
from codeflash.languages.javascript.instrument import TestingMode, instrument_generated_js_test, transform_expect_calls
class TestExpectCallTransformer:
@ -656,4 +649,4 @@ describe('calculatePi', () => {
});"""
result = instrument_generated_js_test(code, "calculatePi", "calculatePi", TestingMode.BEHAVIOR)
assert result.count("codeflash.capture(") == 2
assert ".toBeCloseTo(" not in result
assert ".toBeCloseTo(" not in result

View file

@ -1,18 +1,12 @@
"""
Tests for JavaScript function discovery in get_functions_to_optimize.
"""Tests for JavaScript function discovery in get_functions_to_optimize.
These tests verify that JavaScript functions are correctly discovered,
filtered, and returned from the function discovery pipeline.
"""
import tempfile
import unittest.mock
from pathlib import Path
import pytest
from codeflash.discovery.functions_to_optimize import (
FunctionToOptimize,
filter_functions,
find_all_functions_in_file,
get_all_files_and_functions,
@ -233,11 +227,7 @@ function add(a, b) {
"codeflash.discovery.functions_to_optimize.get_blocklisted_functions", return_value={}
):
filtered, count = filter_functions(
functions,
tests_root=tmp_path / "tests",
ignore_paths=[],
project_root=tmp_path,
module_root=tmp_path,
functions, tests_root=tmp_path / "tests", ignore_paths=[], project_root=tmp_path, module_root=tmp_path
)
assert js_file in filtered
@ -258,11 +248,7 @@ function testHelper() {
modified_functions = {test_file: functions.get(test_file, [])}
filtered, count = filter_functions(
modified_functions,
tests_root=tests_dir,
ignore_paths=[],
project_root=tmp_path,
module_root=tmp_path,
modified_functions, tests_root=tests_dir, ignore_paths=[], project_root=tmp_path, module_root=tmp_path
)
assert test_file not in filtered

View file

@ -1,5 +1,4 @@
"""
Extensive tests for the language abstraction base types.
"""Extensive tests for the language abstraction base types.
These tests verify that the core data structures work correctly
and maintain their contracts.
@ -92,12 +91,7 @@ class TestFunctionInfo:
def test_function_info_creation_minimal(self):
"""Test creating FunctionInfo with minimal args."""
func = FunctionInfo(
name="add",
file_path=Path("/test/example.py"),
start_line=1,
end_line=3,
)
func = FunctionInfo(name="add", file_path=Path("/test/example.py"), start_line=1, end_line=3)
assert func.name == "add"
assert func.file_path == Path("/test/example.py")
assert func.start_line == 1
@ -131,23 +125,13 @@ class TestFunctionInfo:
def test_function_info_frozen(self):
"""Test that FunctionInfo is immutable."""
func = FunctionInfo(
name="add",
file_path=Path("/test/example.py"),
start_line=1,
end_line=3,
)
func = FunctionInfo(name="add", file_path=Path("/test/example.py"), start_line=1, end_line=3)
with pytest.raises(AttributeError):
func.name = "new_name"
def test_qualified_name_no_parents(self):
"""Test qualified_name without parents."""
func = FunctionInfo(
name="add",
file_path=Path("/test/example.py"),
start_line=1,
end_line=3,
)
func = FunctionInfo(name="add", file_path=Path("/test/example.py"), start_line=1, end_line=3)
assert func.qualified_name == "add"
def test_qualified_name_with_class(self):
@ -168,10 +152,7 @@ class TestFunctionInfo:
file_path=Path("/test/example.py"),
start_line=1,
end_line=3,
parents=(
ParentInfo(name="Outer", type="ClassDef"),
ParentInfo(name="Inner", type="ClassDef"),
),
parents=(ParentInfo(name="Outer", type="ClassDef"), ParentInfo(name="Inner", type="ClassDef")),
)
assert func.qualified_name == "Outer.Inner.inner"
@ -188,12 +169,7 @@ class TestFunctionInfo:
def test_class_name_without_class(self):
"""Test class_name property without class parent."""
func = FunctionInfo(
name="add",
file_path=Path("/test/example.py"),
start_line=1,
end_line=3,
)
func = FunctionInfo(name="add", file_path=Path("/test/example.py"), start_line=1, end_line=3)
assert func.class_name is None
def test_class_name_nested_function(self):
@ -214,22 +190,14 @@ class TestFunctionInfo:
file_path=Path("/test/example.py"),
start_line=1,
end_line=3,
parents=(
ParentInfo(name="Outer", type="ClassDef"),
ParentInfo(name="Inner", type="ClassDef"),
),
parents=(ParentInfo(name="Outer", type="ClassDef"), ParentInfo(name="Inner", type="ClassDef")),
)
# Should return the immediate parent class
assert func.class_name == "Inner"
def test_top_level_parent_name_no_parents(self):
"""Test top_level_parent_name without parents."""
func = FunctionInfo(
name="add",
file_path=Path("/test/example.py"),
start_line=1,
end_line=3,
)
func = FunctionInfo(name="add", file_path=Path("/test/example.py"), start_line=1, end_line=3)
assert func.top_level_parent_name == "add"
def test_top_level_parent_name_with_parents(self):
@ -239,10 +207,7 @@ class TestFunctionInfo:
file_path=Path("/test/example.py"),
start_line=1,
end_line=3,
parents=(
ParentInfo(name="Outer", type="ClassDef"),
ParentInfo(name="Inner", type="ClassDef"),
),
parents=(ParentInfo(name="Outer", type="ClassDef"), ParentInfo(name="Inner", type="ClassDef")),
)
assert func.top_level_parent_name == "Outer"
@ -285,10 +250,7 @@ class TestCodeContext:
def test_code_context_creation_minimal(self):
"""Test creating CodeContext with minimal args."""
ctx = CodeContext(
target_code="def add(a, b): return a + b",
target_file=Path("/test/example.py"),
)
ctx = CodeContext(target_code="def add(a, b): return a + b", target_file=Path("/test/example.py"))
assert ctx.target_code == "def add(a, b): return a + b"
assert ctx.target_file == Path("/test/example.py")
assert ctx.helper_functions == []
@ -325,38 +287,24 @@ class TestTestInfo:
def test_test_info_creation(self):
"""Test creating TestInfo."""
info = TestInfo(
test_name="test_add",
test_file=Path("/tests/test_calc.py"),
test_class="TestCalculator",
)
info = TestInfo(test_name="test_add", test_file=Path("/tests/test_calc.py"), test_class="TestCalculator")
assert info.test_name == "test_add"
assert info.test_file == Path("/tests/test_calc.py")
assert info.test_class == "TestCalculator"
def test_test_info_without_class(self):
"""Test TestInfo without test class."""
info = TestInfo(
test_name="test_add",
test_file=Path("/tests/test_calc.py"),
)
info = TestInfo(test_name="test_add", test_file=Path("/tests/test_calc.py"))
assert info.test_class is None
def test_full_test_path_with_class(self):
"""Test full_test_path with class."""
info = TestInfo(
test_name="test_add",
test_file=Path("/tests/test_calc.py"),
test_class="TestCalculator",
)
info = TestInfo(test_name="test_add", test_file=Path("/tests/test_calc.py"), test_class="TestCalculator")
assert info.full_test_path == "/tests/test_calc.py::TestCalculator::test_add"
def test_full_test_path_without_class(self):
"""Test full_test_path without class."""
info = TestInfo(
test_name="test_add",
test_file=Path("/tests/test_calc.py"),
)
info = TestInfo(test_name="test_add", test_file=Path("/tests/test_calc.py"))
assert info.full_test_path == "/tests/test_calc.py::test_add"
@ -448,10 +396,7 @@ class TestConvertParentsToTuple:
self.name = name
self.type = type_
parents = [
MockParent("Outer", "ClassDef"),
MockParent("inner", "FunctionDef"),
]
parents = [MockParent("Outer", "ClassDef"), MockParent("inner", "FunctionDef")]
result = convert_parents_to_tuple(parents)
assert len(result) == 2

View file

@ -1,5 +1,4 @@
"""
Tests for the integrated multi-language function discovery.
"""Tests for the integrated multi-language function discovery.
These tests verify that the function discovery in functions_to_optimize.py
correctly routes to language-specific implementations.
@ -8,10 +7,7 @@ correctly routes to language-specific implementations.
import tempfile
from pathlib import Path
import pytest
from codeflash.discovery.functions_to_optimize import (
FunctionToOptimize,
find_all_functions_in_file,
get_all_files_and_functions,
get_files_for_language,

View file

@ -4,16 +4,10 @@ These tests verify that the ImportResolver correctly resolves import paths
to actual file paths, enabling multi-file context extraction.
"""
import pytest
from pathlib import Path
from unittest.mock import patch, MagicMock
from codeflash.languages.javascript.import_resolver import (
ImportResolver,
ResolvedImport,
MultiFileHelperFinder,
HelperSearchContext,
)
import pytest
from codeflash.languages.javascript.import_resolver import HelperSearchContext, ImportResolver, MultiFileHelperFinder
from codeflash.languages.treesitter_utils import ImportInfo
@ -246,16 +240,16 @@ class TestMultiFileHelperFinder:
src_dir.mkdir()
# Main file that imports helper
(src_dir / "main.ts").write_text('''
(src_dir / "main.ts").write_text("""
import { helperFunc } from './helper';
export function mainFunc() {
return helperFunc() + 1;
}
''')
""")
# Helper file
(src_dir / "helper.ts").write_text('''
(src_dir / "helper.ts").write_text("""
export function helperFunc() {
return 42;
}
@ -263,7 +257,7 @@ export function helperFunc() {
export function unusedHelper() {
return 0;
}
''')
""")
return tmp_path
@ -293,6 +287,7 @@ class TestExportInfo:
def js_analyzer(self):
"""Create a JavaScript analyzer."""
from codeflash.languages.treesitter_utils import TreeSitterAnalyzer, TreeSitterLanguage
return TreeSitterAnalyzer(TreeSitterLanguage.JAVASCRIPT)
def test_find_named_export_function(self, js_analyzer):
@ -394,6 +389,7 @@ class TestCommonJSRequire:
def js_analyzer(self):
"""Create a JavaScript analyzer."""
from codeflash.languages.treesitter_utils import TreeSitterAnalyzer, TreeSitterLanguage
return TreeSitterAnalyzer(TreeSitterLanguage.JAVASCRIPT)
def test_require_default_import(self, js_analyzer):
@ -475,6 +471,7 @@ class TestCommonJSExports:
def js_analyzer(self):
"""Create a JavaScript analyzer."""
from codeflash.languages.treesitter_utils import TreeSitterAnalyzer, TreeSitterLanguage
return TreeSitterAnalyzer(TreeSitterLanguage.JAVASCRIPT)
def test_module_exports_function(self, js_analyzer):

View file

@ -1,5 +1,4 @@
"""
End-to-end integration tests for JavaScript pipeline.
"""End-to-end integration tests for JavaScript pipeline.
Tests the full optimization pipeline for JavaScript:
- Function discovery
@ -13,11 +12,7 @@ from pathlib import Path
import pytest
from codeflash.discovery.functions_to_optimize import (
FunctionToOptimize,
find_all_functions_in_file,
get_files_for_language,
)
from codeflash.discovery.functions_to_optimize import find_all_functions_in_file, get_files_for_language
from codeflash.languages.base import Language
@ -149,11 +144,7 @@ function multiply(a, b) {
# Create FunctionInfo for the add function
func_info = FunctionInfo(
name="add",
file_path=Path("/tmp/test.js"),
start_line=2,
end_line=4,
language=Language.JAVASCRIPT,
name="add", file_path=Path("/tmp/test.js"), start_line=2, end_line=4, language=Language.JAVASCRIPT
)
result = js_support.replace_function(original_source, func_info, new_function)
@ -189,11 +180,7 @@ class TestJavaScriptTestDiscovery:
# Create FunctionInfo for fibonacci function
fib_file = js_project_dir / "fibonacci.js"
func_info = FunctionInfo(
name="fibonacci",
file_path=fib_file,
start_line=11,
end_line=16,
language=Language.JAVASCRIPT,
name="fibonacci", file_path=fib_file, start_line=11, end_line=16, language=Language.JAVASCRIPT
)
# Discover tests
@ -233,19 +220,13 @@ function standalone(x) {
assert len(functions.get(file_path, [])) >= 3
# Check standalone function
standalone_fn = next(
(fn for fn in functions[file_path] if fn.function_name == "standalone"),
None,
)
standalone_fn = next((fn for fn in functions[file_path] if fn.function_name == "standalone"), None)
assert standalone_fn is not None
assert standalone_fn.language == "javascript"
assert len(standalone_fn.parents) == 0
# Check class method
add_fn = next(
(fn for fn in functions[file_path] if fn.function_name == "add"),
None,
)
add_fn = next((fn for fn in functions[file_path] if fn.function_name == "add"), None)
assert add_fn is not None
assert add_fn.language == "javascript"
assert len(add_fn.parents) == 1
@ -258,9 +239,7 @@ function standalone(x) {
code_strings = CodeStringsMarkdown(
code_strings=[
CodeString(
code="function add(a, b) { return a + b; }",
file_path=Path("test.js"),
language="javascript",
code="function add(a, b) { return a + b; }", file_path=Path("test.js"), language="javascript"
)
],
language="javascript",

View file

@ -1,5 +1,4 @@
"""
Tests for JavaScript instrumentation (line profiling and tracing).
"""Tests for JavaScript instrumentation (line profiling and tracing).
This module tests the line profiling and tracing instrumentation for JavaScript code.
"""
@ -7,8 +6,6 @@ This module tests the line profiling and tracing instrumentation for JavaScript
import tempfile
from pathlib import Path
import pytest
from codeflash.languages.base import FunctionInfo, Language
from codeflash.languages.javascript.line_profiler import JavaScriptLineProfiler
from codeflash.languages.javascript.tracer import JavaScriptTracer
@ -52,11 +49,7 @@ function add(a, b) {
file_path = Path(f.name)
func_info = FunctionInfo(
name="add",
file_path=file_path,
start_line=2,
end_line=5,
language=Language.JAVASCRIPT,
name="add", file_path=file_path, start_line=2, end_line=5, language=Language.JAVASCRIPT
)
output_file = Path("/tmp/test_profile.json")
@ -117,11 +110,7 @@ function multiply(x, y) {
file_path = Path(f.name)
func_info = FunctionInfo(
name="multiply",
file_path=file_path,
start_line=2,
end_line=4,
language=Language.JAVASCRIPT,
name="multiply", file_path=file_path, start_line=2, end_line=4, language=Language.JAVASCRIPT
)
output_db = Path("/tmp/test_traces.db")
@ -165,17 +154,11 @@ function greet(name) {
file_path = Path(f.name)
func_info = FunctionInfo(
name="greet",
file_path=file_path,
start_line=2,
end_line=4,
language=Language.JAVASCRIPT,
name="greet", file_path=file_path, start_line=2, end_line=4, language=Language.JAVASCRIPT
)
output_file = file_path.parent / ".codeflash" / "traces.db"
instrumented = js_support.instrument_for_behavior(
source, [func_info], output_file=output_file
)
instrumented = js_support.instrument_for_behavior(source, [func_info], output_file=output_file)
assert "__codeflash_tracer__" in instrumented
assert "wrap" in instrumented
@ -202,11 +185,7 @@ function square(n) {
file_path = Path(f.name)
func_info = FunctionInfo(
name="square",
file_path=file_path,
start_line=2,
end_line=5,
language=Language.JAVASCRIPT,
name="square", file_path=file_path, start_line=2, end_line=5, language=Language.JAVASCRIPT
)
output_file = file_path.parent / ".codeflash" / "line_profile.json"
@ -373,10 +352,7 @@ const result = calc.fibonacci(10);
console.log(result);
"""
transformed, counter = transform_standalone_calls(
code=code,
func_name="fibonacci",
qualified_name="Calculator.fibonacci",
capture_func="capture",
code=code, func_name="fibonacci", qualified_name="Calculator.fibonacci", capture_func="capture"
)
# Should transform calc.fibonacci(10) to codeflash.capture(..., calc.fibonacci.bind(calc), 10)
@ -395,10 +371,7 @@ test('fibonacci works', () => {
});
"""
transformed, counter = transform_expect_calls(
code=code,
func_name="fibonacci",
qualified_name="FibonacciCalculator.fibonacci",
capture_func="capture",
code=code, func_name="fibonacci", qualified_name="FibonacciCalculator.fibonacci", capture_func="capture"
)
# Should transform expect(calc.fibonacci(10)) to
@ -446,10 +419,7 @@ class FibonacciCalculator {
}
"""
transformed, counter = transform_standalone_calls(
code=code,
func_name="fibonacci",
qualified_name="FibonacciCalculator.fibonacci",
capture_func="capture",
code=code, func_name="fibonacci", qualified_name="FibonacciCalculator.fibonacci", capture_func="capture"
)
# The method definition should NOT be transformed
@ -468,10 +438,7 @@ FibonacciCalculator.prototype.fibonacci = function(n) {
};
"""
transformed, counter = transform_standalone_calls(
code=code,
func_name="fibonacci",
qualified_name="FibonacciCalculator.fibonacci",
capture_func="capture",
code=code, func_name="fibonacci", qualified_name="FibonacciCalculator.fibonacci", capture_func="capture"
)
# The prototype assignment should NOT be transformed
@ -489,10 +456,7 @@ const b = calc.fibonacci(10);
const sum = a + b;
"""
transformed, counter = transform_standalone_calls(
code=code,
func_name="fibonacci",
qualified_name="Calculator.fibonacci",
capture_func="capture",
code=code, func_name="fibonacci", qualified_name="Calculator.fibonacci", capture_func="capture"
)
# Should transform both calls
@ -511,10 +475,7 @@ class Wrapper {
}
"""
transformed, counter = transform_standalone_calls(
code=code,
func_name="fibonacci",
qualified_name="Wrapper.fibonacci",
capture_func="capture",
code=code, func_name="fibonacci", qualified_name="Wrapper.fibonacci", capture_func="capture"
)
# Should transform this.fibonacci(n)
@ -524,9 +485,9 @@ class Wrapper {
def test_full_instrumentation_produces_valid_syntax(self):
"""Test that full instrumentation produces syntactically valid JavaScript."""
from codeflash.languages.javascript.instrument import _instrument_js_test_code
from codeflash.languages import get_language_support
from codeflash.languages.base import Language
from codeflash.languages.javascript.instrument import _instrument_js_test_code
js_support = get_language_support(Language.JAVASCRIPT)
@ -584,11 +545,7 @@ describe('Calculator', () => {
});
"""
instrumented = _instrument_js_test_code(
code=test_code,
func_name="add",
test_file_path="test.js",
mode="behavior",
qualified_name="Calculator.add",
code=test_code, func_name="add", test_file_path="test.js", mode="behavior", qualified_name="Calculator.add"
)
# describe and test structure should be preserved
@ -610,10 +567,7 @@ const data = await api.fetchData('http://example.com');
console.log(data);
"""
transformed, counter = transform_standalone_calls(
code=code,
func_name="fetchData",
qualified_name="ApiClient.fetchData",
capture_func="capture",
code=code, func_name="fetchData", qualified_name="ApiClient.fetchData", capture_func="capture"
)
# Should preserve await
@ -632,10 +586,7 @@ class TestInstrumentationFullStringEquality:
code = " calc.fibonacci(10);"
transformed, counter = transform_standalone_calls(
code=code,
func_name="fibonacci",
qualified_name="Calculator.fibonacci",
capture_func="capture",
code=code, func_name="fibonacci", qualified_name="Calculator.fibonacci", capture_func="capture"
)
expected = " codeflash.capture('Calculator.fibonacci', '1', calc.fibonacci.bind(calc), 10);"
@ -649,10 +600,7 @@ class TestInstrumentationFullStringEquality:
code = " expect(calc.fibonacci(10)).toBe(55);"
transformed, counter = transform_expect_calls(
code=code,
func_name="fibonacci",
qualified_name="Calculator.fibonacci",
capture_func="capture",
code=code, func_name="fibonacci", qualified_name="Calculator.fibonacci", capture_func="capture"
)
expected = " expect(codeflash.capture('Calculator.fibonacci', '1', calc.fibonacci.bind(calc), 10)).toBe(55);"
@ -684,10 +632,7 @@ class TestInstrumentationFullStringEquality:
code = " fibonacci(10);"
transformed, counter = transform_standalone_calls(
code=code,
func_name="fibonacci",
qualified_name="fibonacci",
capture_func="capture",
code=code, func_name="fibonacci", qualified_name="fibonacci", capture_func="capture"
)
expected = " codeflash.capture('fibonacci', '1', fibonacci, 10);"
@ -701,12 +646,9 @@ class TestInstrumentationFullStringEquality:
code = " return this.fibonacci(n - 1);"
transformed, counter = transform_standalone_calls(
code=code,
func_name="fibonacci",
qualified_name="Class.fibonacci",
capture_func="capture",
code=code, func_name="fibonacci", qualified_name="Class.fibonacci", capture_func="capture"
)
expected = " return codeflash.capture('Class.fibonacci', '1', this.fibonacci.bind(this), n - 1);"
assert transformed == expected, f"Expected:\n{expected}\nGot:\n{transformed}"
assert counter == 1
assert counter == 1

View file

@ -1,18 +1,11 @@
"""
Tests for JavaScript module system detection.
"""Tests for JavaScript module system detection.
"""
import json
import tempfile
from pathlib import Path
import pytest
from codeflash.languages.javascript.module_system import (
ModuleSystem,
detect_module_system,
get_import_statement,
)
from codeflash.languages.javascript.module_system import ModuleSystem, detect_module_system, get_import_statement
class TestModuleSystemDetection:
@ -73,9 +66,7 @@ class TestModuleSystemDetection:
with tempfile.TemporaryDirectory() as tmpdir:
project_root = Path(tmpdir)
file_path = project_root / "module.js"
file_path.write_text(
"const foo = require('./bar');\nmodule.exports = { baz: 1 };"
)
file_path.write_text("const foo = require('./bar');\nmodule.exports = { baz: 1 };")
result = detect_module_system(project_root, file_path)
assert result == ModuleSystem.COMMONJS
@ -99,9 +90,7 @@ class TestImportStatementGeneration:
target = tmpdir / "lib" / "utils.js"
source = tmpdir / "tests" / "utils.test.js"
result = get_import_statement(
ModuleSystem.COMMONJS, target, source, ["foo", "bar"]
)
result = get_import_statement(ModuleSystem.COMMONJS, target, source, ["foo", "bar"])
assert result == "const { foo, bar } = require('../lib/utils');"
@ -112,9 +101,7 @@ class TestImportStatementGeneration:
target = tmpdir / "lib" / "utils.js"
source = tmpdir / "tests" / "utils.test.js"
result = get_import_statement(
ModuleSystem.ES_MODULE, target, source, ["foo", "bar"]
)
result = get_import_statement(ModuleSystem.ES_MODULE, target, source, ["foo", "bar"])
assert result == "import { foo, bar } from '../lib/utils';"
@ -147,9 +134,7 @@ class TestImportStatementGeneration:
target = tmpdir / "utils.js"
source = tmpdir / "index.js"
result = get_import_statement(
ModuleSystem.COMMONJS, target, source, ["foo"]
)
result = get_import_statement(ModuleSystem.COMMONJS, target, source, ["foo"])
assert result == "const { foo } = require('./utils');"
@ -160,9 +145,7 @@ class TestImportStatementGeneration:
target = tmpdir / "lib" / "helpers" / "utils.js"
source = tmpdir / "tests" / "test.js"
result = get_import_statement(
ModuleSystem.COMMONJS, target, source, ["foo"]
)
result = get_import_statement(ModuleSystem.COMMONJS, target, source, ["foo"])
assert result == "const { foo } = require('../lib/helpers/utils');"
@ -173,8 +156,6 @@ class TestImportStatementGeneration:
target = tmpdir / "utils.js"
source = tmpdir / "tests" / "unit" / "test.js"
result = get_import_statement(
ModuleSystem.COMMONJS, target, source, ["foo"]
)
result = get_import_statement(ModuleSystem.COMMONJS, target, source, ["foo"])
assert result == "const { foo } = require('../../utils');"
assert result == "const { foo } = require('../../utils');"

View file

@ -1,5 +1,4 @@
"""
Extensive tests for the JavaScript language support implementation.
"""Extensive tests for the JavaScript language support implementation.
These tests verify that JavaScriptSupport correctly discovers functions,
replaces code, and integrates with the codeflash language abstraction.
@ -10,12 +9,7 @@ from pathlib import Path
import pytest
from codeflash.languages.base import (
FunctionFilterCriteria,
FunctionInfo,
Language,
ParentInfo,
)
from codeflash.languages.base import FunctionFilterCriteria, FunctionInfo, Language, ParentInfo
from codeflash.languages.javascript.support import JavaScriptSupport
@ -322,12 +316,7 @@ function multiply(a, b) {
return a * b;
}
"""
func = FunctionInfo(
name="add",
file_path=Path("/test.js"),
start_line=1,
end_line=3,
)
func = FunctionInfo(name="add", file_path=Path("/test.js"), start_line=1, end_line=3)
new_code = """function add(a, b) {
// Optimized
return (a + b) | 0;
@ -354,12 +343,7 @@ function other() {
// Footer
"""
func = FunctionInfo(
name="target",
file_path=Path("/test.js"),
start_line=4,
end_line=6,
)
func = FunctionInfo(name="target", file_path=Path("/test.js"), start_line=4, end_line=6)
new_code = """function target() {
return 42;
}
@ -407,12 +391,7 @@ function other() {
const multiply = (x, y) => x * y;
"""
func = FunctionInfo(
name="add",
file_path=Path("/test.js"),
start_line=1,
end_line=3,
)
func = FunctionInfo(name="add", file_path=Path("/test.js"), start_line=1, end_line=3)
new_code = """const add = (a, b) => {
return (a + b) | 0;
};
@ -504,18 +483,9 @@ class TestExtractCodeContext:
f.flush()
file_path = Path(f.name)
func = FunctionInfo(
name="add",
file_path=file_path,
start_line=1,
end_line=3,
)
func = FunctionInfo(name="add", file_path=file_path, start_line=1, end_line=3)
context = js_support.extract_code_context(
func,
file_path.parent,
file_path.parent,
)
context = js_support.extract_code_context(func, file_path.parent, file_path.parent)
assert "function add" in context.target_code
assert "return a + b" in context.target_code
@ -540,11 +510,7 @@ function main(a) {
functions = js_support.discover_functions(file_path)
main_func = next(f for f in functions if f.name == "main")
context = js_support.extract_code_context(
main_func,
file_path.parent,
file_path.parent,
)
context = js_support.extract_code_context(main_func, file_path.parent, file_path.parent)
assert "function main" in context.target_code
# Helper should be found
@ -689,6 +655,7 @@ describe('Math functions', () => {
source = file_path.read_text()
from codeflash.languages.treesitter_utils import get_analyzer_for_file
analyzer = get_analyzer_for_file(file_path)
test_names = js_support._find_jest_tests(source, analyzer)
@ -1035,10 +1002,14 @@ class TestClassMethodReplacement:
parents=(ParentInfo(name="Math", type="ClassDef"),),
is_method=True,
)
source = js_support.replace_function(source, add_func, """ add(a, b) {
source = js_support.replace_function(
source,
add_func,
""" add(a, b) {
return (a + b) | 0;
}
""")
""",
)
assert js_support.validate_syntax(source) is True
@ -1346,8 +1317,7 @@ class Counter {
module.exports = { Counter };
"""
assert result == expected_result, (
f"Replacement result does not match expected.\n"
f"Expected:\n{expected_result}\n\nGot:\n{result}"
f"Replacement result does not match expected.\nExpected:\n{expected_result}\n\nGot:\n{result}"
)
assert js_support.validate_syntax(result) is True
@ -1455,8 +1425,7 @@ class User {
export { User };
"""
assert result == expected_result, (
f"Replacement result does not match expected.\n"
f"Expected:\n{expected_result}\n\nGot:\n{result}"
f"Replacement result does not match expected.\nExpected:\n{expected_result}\n\nGot:\n{result}"
)
assert ts_support.validate_syntax(result) is True
@ -1544,8 +1513,7 @@ class Calculator {
}
"""
assert result == expected_result, (
f"Replacement result does not match expected.\n"
f"Expected:\n{expected_result}\n\nGot:\n{result}"
f"Replacement result does not match expected.\nExpected:\n{expected_result}\n\nGot:\n{result}"
)
assert js_support.validate_syntax(result) is True
@ -1631,7 +1599,6 @@ class MathUtils {
module.exports = { MathUtils };
"""
assert result == expected_result, (
f"Replacement result does not match expected.\n"
f"Expected:\n{expected_result}\n\nGot:\n{result}"
f"Replacement result does not match expected.\nExpected:\n{expected_result}\n\nGot:\n{result}"
)
assert js_support.validate_syntax(result) is True

View file

@ -1,5 +1,4 @@
"""
Comprehensive tests for JavaScript test discovery functionality.
"""Comprehensive tests for JavaScript test discovery functionality.
These tests verify that the JavaScript language support correctly discovers
Jest tests and maps them to source functions, similar to Python's test discovery tests.
@ -10,7 +9,6 @@ from pathlib import Path
import pytest
from codeflash.languages.base import FunctionInfo, Language
from codeflash.languages.javascript.support import JavaScriptSupport
@ -630,6 +628,7 @@ it('third test', () => {});
source = file_path.read_text()
from codeflash.languages.treesitter_utils import get_analyzer_for_file
analyzer = get_analyzer_for_file(file_path)
test_names = js_support._find_jest_tests(source, analyzer)
@ -654,6 +653,7 @@ describe('Suite B', () => {
source = file_path.read_text()
from codeflash.languages.treesitter_utils import get_analyzer_for_file
analyzer = get_analyzer_for_file(file_path)
test_names = js_support._find_jest_tests(source, analyzer)
@ -677,6 +677,7 @@ describe('Outer', () => {
source = file_path.read_text()
from codeflash.languages.treesitter_utils import get_analyzer_for_file
analyzer = get_analyzer_for_file(file_path)
test_names = js_support._find_jest_tests(source, analyzer)
@ -700,6 +701,7 @@ describe.skip('skipped describe', () => {
source = file_path.read_text()
from codeflash.languages.treesitter_utils import get_analyzer_for_file
analyzer = get_analyzer_for_file(file_path)
test_names = js_support._find_jest_tests(source, analyzer)
@ -720,6 +722,7 @@ describe.only('only describe', () => {
source = file_path.read_text()
from codeflash.languages.treesitter_utils import get_analyzer_for_file
analyzer = get_analyzer_for_file(file_path)
test_names = js_support._find_jest_tests(source, analyzer)
@ -737,6 +740,7 @@ describe('describe single', () => {});
source = file_path.read_text()
from codeflash.languages.treesitter_utils import get_analyzer_for_file
analyzer = get_analyzer_for_file(file_path)
test_names = js_support._find_jest_tests(source, analyzer)
@ -746,15 +750,16 @@ describe('describe single', () => {});
def test_find_tests_with_double_quotes(self, js_support):
"""Test finding tests with double-quoted names."""
with tempfile.NamedTemporaryFile(suffix=".test.js", mode="w", delete=False) as f:
f.write('''
f.write("""
test("double quotes", () => {});
describe("describe double", () => {});
''')
""")
f.flush()
file_path = Path(f.name)
source = file_path.read_text()
from codeflash.languages.treesitter_utils import get_analyzer_for_file
analyzer = get_analyzer_for_file(file_path)
test_names = js_support._find_jest_tests(source, analyzer)
@ -770,6 +775,7 @@ describe("describe double", () => {});
source = file_path.read_text()
from codeflash.languages.treesitter_utils import get_analyzer_for_file
analyzer = get_analyzer_for_file(file_path)
test_names = js_support._find_jest_tests(source, analyzer)
@ -1015,6 +1021,7 @@ describe('日本語テスト', () => {
source = file_path.read_text(encoding="utf-8")
from codeflash.languages.treesitter_utils import get_analyzer_for_file
analyzer = get_analyzer_for_file(file_path)
test_names = js_support._find_jest_tests(source, analyzer)
@ -1042,6 +1049,7 @@ test.each([
source = file_path.read_text()
from codeflash.languages.treesitter_utils import get_analyzer_for_file
analyzer = get_analyzer_for_file(file_path)
test_names = js_support._find_jest_tests(source, analyzer)
@ -1067,6 +1075,7 @@ describe.each([
source = file_path.read_text()
from codeflash.languages.treesitter_utils import get_analyzer_for_file
analyzer = get_analyzer_for_file(file_path)
test_names = js_support._find_jest_tests(source, analyzer)
@ -1091,6 +1100,7 @@ describe('Math operations', () => {
source = file_path.read_text()
from codeflash.languages.treesitter_utils import get_analyzer_for_file
analyzer = get_analyzer_for_file(file_path)
test_names = js_support._find_jest_tests(source, analyzer)
@ -1178,8 +1188,7 @@ describe('formatName', () => {
for test_list in tests.values():
all_test_names.extend([t.test_name for t in test_list])
assert any("validateEmail" in name or "accepts valid email" in name
for name in all_test_names)
assert any("validateEmail" in name or "accepts valid email" in name for name in all_test_names)
def test_discovery_with_fixtures(self, js_support):
"""Test discovery when test file uses beforeEach/afterEach."""
@ -1449,6 +1458,7 @@ testCases.forEach(name => {
source = file_path.read_text()
from codeflash.languages.treesitter_utils import get_analyzer_for_file
analyzer = get_analyzer_for_file(file_path)
test_names = js_support._find_jest_tests(source, analyzer)
@ -1476,6 +1486,7 @@ describe('conditional tests', () => {
source = file_path.read_text()
from codeflash.languages.treesitter_utils import get_analyzer_for_file
analyzer = get_analyzer_for_file(file_path)
test_names = js_support._find_jest_tests(source, analyzer)
@ -1499,6 +1510,7 @@ test('slow test', () => {
source = file_path.read_text()
from codeflash.languages.treesitter_utils import get_analyzer_for_file
analyzer = get_analyzer_for_file(file_path)
test_names = js_support._find_jest_tests(source, analyzer)
@ -1521,6 +1533,7 @@ test.todo('also needs implementation');
source = file_path.read_text()
from codeflash.languages.treesitter_utils import get_analyzer_for_file
analyzer = get_analyzer_for_file(file_path)
test_names = js_support._find_jest_tests(source, analyzer)
@ -1543,6 +1556,7 @@ test.concurrent('concurrent test 2', async () => {
source = file_path.read_text()
from codeflash.languages.treesitter_utils import get_analyzer_for_file
analyzer = get_analyzer_for_file(file_path)
test_names = js_support._find_jest_tests(source, analyzer)
@ -1641,6 +1655,7 @@ describe('Array', function() {
source = file_path.read_text()
from codeflash.languages.treesitter_utils import get_analyzer_for_file
analyzer = get_analyzer_for_file(file_path)
test_names = js_support._find_jest_tests(source, analyzer)
@ -1671,6 +1686,7 @@ describe('User', () => {
source = file_path.read_text()
from codeflash.languages.treesitter_utils import get_analyzer_for_file
analyzer = get_analyzer_for_file(file_path)
test_names = js_support._find_jest_tests(source, analyzer)

View file

@ -7,6 +7,7 @@ import shutil
from pathlib import Path
import pytest
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.languages.base import Language
from codeflash.languages.javascript.support import JavaScriptSupport, TypeScriptSupport
@ -86,9 +87,7 @@ class Calculator {
assert context.target_code is not None, "target_code should not be None"
assert context.target_code.strip() == expected_code.strip(), (
f"Extracted code does not match expected.\n"
f"Expected:\n{expected_code}\n\n"
f"Got:\n{context.target_code}"
f"Extracted code does not match expected.\nExpected:\n{expected_code}\n\nGot:\n{context.target_code}"
)
def test_extract_context_includes_direct_helpers(self, js_support, cjs_project):
@ -175,9 +174,7 @@ class Calculator {
}"""
assert context.target_code.strip() == expected_code.strip(), (
f"Extracted code does not match expected.\n"
f"Expected:\n{expected_code}\n\n"
f"Got:\n{context.target_code}"
f"Extracted code does not match expected.\nExpected:\n{expected_code}\n\nGot:\n{context.target_code}"
)
def test_extract_compound_interest_helpers(self, js_support, cjs_project):
@ -282,9 +279,7 @@ function validateInput(value, name) {
assert len(context.imports) == 2, f"Expected 2 imports, got {len(context.imports)}: {context.imports}"
assert context.imports == expected_imports, (
f"Imports do not match expected.\n"
f"Expected:\n{expected_imports}\n\n"
f"Got:\n{context.imports}"
f"Imports do not match expected.\nExpected:\n{expected_imports}\n\nGot:\n{context.imports}"
)
def test_extract_static_method(self, js_support, cjs_project):
@ -314,9 +309,7 @@ class Calculator {
}"""
assert context.target_code.strip() == expected_code.strip(), (
f"Extracted code does not match expected.\n"
f"Expected:\n{expected_code}\n\n"
f"Got:\n{context.target_code}"
f"Extracted code does not match expected.\nExpected:\n{expected_code}\n\nGot:\n{context.target_code}"
)
# quickAdd uses add helper from math_utils
@ -397,9 +390,7 @@ class Calculator {
}"""
assert context.target_code.strip() == expected_code.strip(), (
f"Extracted code does not match expected.\n"
f"Expected:\n{expected_code}\n\n"
f"Got:\n{context.target_code}"
f"Extracted code does not match expected.\nExpected:\n{expected_code}\n\nGot:\n{context.target_code}"
)
# ESM permutation uses factorial helper
@ -460,9 +451,7 @@ class Calculator {
}"""
assert context.target_code.strip() == expected_code.strip(), (
f"Extracted code does not match expected.\n"
f"Expected:\n{expected_code}\n\n"
f"Got:\n{context.target_code}"
f"Extracted code does not match expected.\nExpected:\n{expected_code}\n\nGot:\n{context.target_code}"
)
expected_imports = [
@ -472,9 +461,7 @@ class Calculator {
assert len(context.imports) == 2, f"Expected 2 imports, got {len(context.imports)}: {context.imports}"
assert context.imports == expected_imports, (
f"Imports do not match expected.\n"
f"Expected:\n{expected_imports}\n\n"
f"Got:\n{context.imports}"
f"Imports do not match expected.\nExpected:\n{expected_imports}\n\nGot:\n{context.imports}"
)
# ESM compound interest uses 4 helpers
@ -593,9 +580,7 @@ class Calculator {
}"""
assert context.target_code.strip() == expected_code.strip(), (
f"Extracted code does not match expected.\n"
f"Expected:\n{expected_code}\n\n"
f"Got:\n{context.target_code}"
f"Extracted code does not match expected.\nExpected:\n{expected_code}\n\nGot:\n{context.target_code}"
)
# TypeScript permutation uses factorial helper
@ -659,9 +644,7 @@ class Calculator {
}"""
assert context.target_code.strip() == expected_code.strip(), (
f"Extracted code does not match expected.\n"
f"Expected:\n{expected_code}\n\n"
f"Got:\n{context.target_code}"
f"Extracted code does not match expected.\nExpected:\n{expected_code}\n\nGot:\n{context.target_code}"
)
# TypeScript compound interest uses 4 helpers
@ -890,9 +873,7 @@ module.exports = { Counter };
functions = js_support.discover_functions(test_file)
increment_func = next(f for f in functions if f.name == "increment")
context = js_support.extract_code_context(
function=increment_func, project_root=tmp_path, module_root=tmp_path
)
context = js_support.extract_code_context(function=increment_func, project_root=tmp_path, module_root=tmp_path)
expected_code = """\
class Counter {
@ -907,9 +888,7 @@ class Counter {
}"""
assert context.target_code.strip() == expected_code.strip(), (
f"Extracted code does not match expected.\n"
f"Expected:\n{expected_code}\n\n"
f"Got:\n{context.target_code}"
f"Extracted code does not match expected.\nExpected:\n{expected_code}\n\nGot:\n{context.target_code}"
)
def test_method_extraction_class_without_constructor(self, js_support, tmp_path):
@ -933,9 +912,7 @@ module.exports = { MathUtils };
functions = js_support.discover_functions(test_file)
add_func = next(f for f in functions if f.name == "add")
context = js_support.extract_code_context(
function=add_func, project_root=tmp_path, module_root=tmp_path
)
context = js_support.extract_code_context(function=add_func, project_root=tmp_path, module_root=tmp_path)
expected_code = """\
class MathUtils {
@ -945,9 +922,7 @@ class MathUtils {
}"""
assert context.target_code.strip() == expected_code.strip(), (
f"Extracted code does not match expected.\n"
f"Expected:\n{expected_code}\n\n"
f"Got:\n{context.target_code}"
f"Extracted code does not match expected.\nExpected:\n{expected_code}\n\nGot:\n{context.target_code}"
)
def test_typescript_method_extraction_includes_fields(self, ts_support, tmp_path):
@ -975,9 +950,7 @@ export { User };
functions = ts_support.discover_functions(test_file)
get_name_func = next(f for f in functions if f.name == "getName")
context = ts_support.extract_code_context(
function=get_name_func, project_root=tmp_path, module_root=tmp_path
)
context = ts_support.extract_code_context(function=get_name_func, project_root=tmp_path, module_root=tmp_path)
expected_code = """\
class User {
@ -995,9 +968,7 @@ class User {
}"""
assert context.target_code.strip() == expected_code.strip(), (
f"Extracted code does not match expected.\n"
f"Expected:\n{expected_code}\n\n"
f"Got:\n{context.target_code}"
f"Extracted code does not match expected.\nExpected:\n{expected_code}\n\nGot:\n{context.target_code}"
)
def test_typescript_fields_only_no_constructor(self, ts_support, tmp_path):
@ -1020,9 +991,7 @@ export { Config };
functions = ts_support.discover_functions(test_file)
get_url_func = next(f for f in functions if f.name == "getUrl")
context = ts_support.extract_code_context(
function=get_url_func, project_root=tmp_path, module_root=tmp_path
)
context = ts_support.extract_code_context(function=get_url_func, project_root=tmp_path, module_root=tmp_path)
expected_code = """\
class Config {
@ -1035,9 +1004,7 @@ class Config {
}"""
assert context.target_code.strip() == expected_code.strip(), (
f"Extracted code does not match expected.\n"
f"Expected:\n{expected_code}\n\n"
f"Got:\n{context.target_code}"
f"Extracted code does not match expected.\nExpected:\n{expected_code}\n\nGot:\n{context.target_code}"
)
def test_constructor_with_jsdoc(self, js_support, tmp_path):
@ -1065,9 +1032,7 @@ module.exports = { Logger };
functions = js_support.discover_functions(test_file)
get_prefix_func = next(f for f in functions if f.name == "getPrefix")
context = js_support.extract_code_context(
function=get_prefix_func, project_root=tmp_path, module_root=tmp_path
)
context = js_support.extract_code_context(function=get_prefix_func, project_root=tmp_path, module_root=tmp_path)
expected_code = """\
class Logger {
@ -1085,9 +1050,7 @@ class Logger {
}"""
assert context.target_code.strip() == expected_code.strip(), (
f"Extracted code does not match expected.\n"
f"Expected:\n{expected_code}\n\n"
f"Got:\n{context.target_code}"
f"Extracted code does not match expected.\nExpected:\n{expected_code}\n\nGot:\n{context.target_code}"
)
def test_static_method_includes_constructor(self, js_support, tmp_path):
@ -1111,9 +1074,7 @@ module.exports = { Factory };
functions = js_support.discover_functions(test_file)
create_func = next(f for f in functions if f.name == "create")
context = js_support.extract_code_context(
function=create_func, project_root=tmp_path, module_root=tmp_path
)
context = js_support.extract_code_context(function=create_func, project_root=tmp_path, module_root=tmp_path)
expected_code = """\
class Factory {
@ -1127,9 +1088,7 @@ class Factory {
}"""
assert context.target_code.strip() == expected_code.strip(), (
f"Extracted code does not match expected.\n"
f"Expected:\n{expected_code}\n\n"
f"Got:\n{context.target_code}"
f"Extracted code does not match expected.\nExpected:\n{expected_code}\n\nGot:\n{context.target_code}"
)
@ -1264,9 +1223,7 @@ export { distance };
functions = ts_support.discover_functions(test_file)
distance_func = next(f for f in functions if f.name == "distance")
context = ts_support.extract_code_context(
function=distance_func, project_root=tmp_path, module_root=tmp_path
)
context = ts_support.extract_code_context(function=distance_func, project_root=tmp_path, module_root=tmp_path)
# Type definition should be in read-only context with exact match
expected_read_only = """\
@ -1310,9 +1267,7 @@ export { processStatus };
functions = ts_support.discover_functions(test_file)
process_func = next(f for f in functions if f.name == "processStatus")
context = ts_support.extract_code_context(
function=process_func, project_root=tmp_path, module_root=tmp_path
)
context = ts_support.extract_code_context(function=process_func, project_root=tmp_path, module_root=tmp_path)
# Enum should be in read-only context with exact match
expected_read_only = """\
@ -1349,9 +1304,7 @@ export { compute };
functions = ts_support.discover_functions(test_file)
compute_func = next(f for f in functions if f.name == "compute")
context = ts_support.extract_code_context(
function=compute_func, project_root=tmp_path, module_root=tmp_path
)
context = ts_support.extract_code_context(function=compute_func, project_root=tmp_path, module_root=tmp_path)
# Type alias should be in read-only context with exact match
expected_read_only = """\
@ -1428,9 +1381,7 @@ export { add };
functions = ts_support.discover_functions(test_file)
add_func = next(f for f in functions if f.name == "add")
context = ts_support.extract_code_context(
function=add_func, project_root=tmp_path, module_root=tmp_path
)
context = ts_support.extract_code_context(function=add_func, project_root=tmp_path, module_root=tmp_path)
# No type definitions should be extracted for primitives - exact empty match
assert context.read_only_context == "", (
@ -1564,9 +1515,7 @@ export { greetUser };
functions = ts_support.discover_functions(test_file)
greet_func = next(f for f in functions if f.name == "greetUser")
context = ts_support.extract_code_context(
function=greet_func, project_root=tmp_path, module_root=tmp_path
)
context = ts_support.extract_code_context(function=greet_func, project_root=tmp_path, module_root=tmp_path)
# JSDoc should be included with the interface - exact match
expected_read_only = """\

View file

@ -12,6 +12,7 @@ import shutil
from pathlib import Path
import pytest
from codeflash.languages.javascript.module_system import (
ModuleSystem,
convert_commonjs_to_esm,

View file

@ -1,5 +1,4 @@
"""
Regression tests for Python/JavaScript language support parity.
"""Regression tests for Python/JavaScript language support parity.
These tests ensure that the JavaScript implementation maintains feature parity
with the Python implementation. Each test class tests equivalent functionality
@ -15,12 +14,7 @@ from typing import NamedTuple
import pytest
from codeflash.languages.base import (
FunctionFilterCriteria,
FunctionInfo,
Language,
ParentInfo,
)
from codeflash.languages.base import FunctionFilterCriteria, FunctionInfo, Language, ParentInfo
from codeflash.languages.javascript.support import JavaScriptSupport
from codeflash.languages.python.support import PythonSupport
@ -732,8 +726,8 @@ function other() {
js_method_line = next(l for l in js_lines if "add(a, b)" in l)
# Both should have indentation (4 spaces)
assert py_method_line.startswith(" "), f"Python method should be indented: {repr(py_method_line)}"
assert js_method_line.startswith(" "), f"JavaScript method should be indented: {repr(js_method_line)}"
assert py_method_line.startswith(" "), f"Python method should be indented: {py_method_line!r}"
assert js_method_line.startswith(" "), f"JavaScript method should be indented: {js_method_line!r}"
# ============================================================================

View file

@ -79,6 +79,7 @@ function findMin(numbers) {
"""
from pathlib import Path
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.languages.registry import get_language_support
from codeflash.models.models import CodeOptimizationContext, CodeStringsMarkdown
@ -90,6 +91,7 @@ class Args:
disable_imports_sorting = True
formatter_cmds = ["disabled"]
def test_js_replcement() -> None:
try:
root_dir = Path(__file__).parent.parent.parent.resolve()
@ -128,11 +130,14 @@ def test_js_replcement() -> None:
func_optimizer = FunctionOptimizer(function_to_optimize=func, test_cfg=test_config)
result = func_optimizer.get_code_optimization_context()
from codeflash.either import is_successful
if not is_successful(result):
import pytest
pytest.skip(f"Context extraction not fully implemented for JS: {result.failure() if hasattr(result, 'failure') else result}")
code_context: CodeOptimizationContext = result.unwrap()
pytest.skip(
f"Context extraction not fully implemented for JS: {result.failure() if hasattr(result, 'failure') else result}"
)
code_context: CodeOptimizationContext = result.unwrap()
original_helper_code: dict[Path, str] = {}
helper_function_paths = {hf.file_path for hf in code_context.helper_functions}
@ -143,7 +148,9 @@ def test_js_replcement() -> None:
func_optimizer.args = Args()
did_update = func_optimizer.replace_function_and_helpers_with_optimized_code(
code_context=code_context, optimized_code=CodeStringsMarkdown.parse_markdown_code(new_code), original_helper_code=original_helper_code
code_context=code_context,
optimized_code=CodeStringsMarkdown.parse_markdown_code(new_code),
original_helper_code=original_helper_code,
)
assert did_update, "Expected code to be updated"
@ -319,7 +326,9 @@ module.exports = {
"""
assert main_code == expected_main, f"Main file mismatch.\n\nActual:\n{main_code}\n\nExpected:\n{expected_main}"
assert helper_code == expected_helper, f"Helper file mismatch.\n\nActual:\n{helper_code}\n\nExpected:\n{expected_helper}"
assert helper_code == expected_helper, (
f"Helper file mismatch.\n\nActual:\n{helper_code}\n\nExpected:\n{expected_helper}"
)
finally:
main_file.write_text(original_main, encoding="utf-8")

View file

@ -1,5 +1,4 @@
"""
Extensive tests for the Python language support implementation.
"""Extensive tests for the Python language support implementation.
These tests verify that PythonSupport correctly discovers functions,
replaces code, and integrates with existing codeflash functionality.
@ -10,12 +9,7 @@ from pathlib import Path
import pytest
from codeflash.languages.base import (
FunctionFilterCriteria,
FunctionInfo,
Language,
ParentInfo,
)
from codeflash.languages.base import FunctionFilterCriteria, FunctionInfo, Language, ParentInfo
from codeflash.languages.python.support import PythonSupport
@ -269,12 +263,7 @@ class TestReplaceFunction:
def multiply(a, b):
return a * b
"""
func = FunctionInfo(
name="add",
file_path=Path("/test.py"),
start_line=1,
end_line=2,
)
func = FunctionInfo(name="add", file_path=Path("/test.py"), start_line=1, end_line=2)
new_code = """def add(a, b):
# Optimized
return (a + b) | 0
@ -298,12 +287,7 @@ def other():
# Footer
"""
func = FunctionInfo(
name="target",
file_path=Path("/test.py"),
start_line=4,
end_line=5,
)
func = FunctionInfo(name="target", file_path=Path("/test.py"), start_line=4, end_line=5)
new_code = """def target():
return 42
"""
@ -347,12 +331,7 @@ def other():
def second():
return 2
"""
func = FunctionInfo(
name="first",
file_path=Path("/test.py"),
start_line=1,
end_line=2,
)
func = FunctionInfo(name="first", file_path=Path("/test.py"), start_line=1, end_line=2)
new_code = """def first():
return 100
"""
@ -369,12 +348,7 @@ def second():
def last():
return 999
"""
func = FunctionInfo(
name="last",
file_path=Path("/test.py"),
start_line=4,
end_line=5,
)
func = FunctionInfo(name="last", file_path=Path("/test.py"), start_line=4, end_line=5)
new_code = """def last():
return 1000
"""
@ -388,12 +362,7 @@ def last():
source = """def only():
return 42
"""
func = FunctionInfo(
name="only",
file_path=Path("/test.py"),
start_line=1,
end_line=2,
)
func = FunctionInfo(name="only", file_path=Path("/test.py"), start_line=1, end_line=2)
new_code = """def only():
return 100
"""
@ -505,18 +474,9 @@ class TestExtractCodeContext:
f.flush()
file_path = Path(f.name)
func = FunctionInfo(
name="add",
file_path=file_path,
start_line=1,
end_line=2,
)
func = FunctionInfo(name="add", file_path=file_path, start_line=1, end_line=2)
context = python_support.extract_code_context(
func,
file_path.parent,
file_path.parent,
)
context = python_support.extract_code_context(func, file_path.parent, file_path.parent)
assert "def add" in context.target_code
assert "return a + b" in context.target_code

View file

@ -1,5 +1,4 @@
"""
Extensive tests for the language registry module.
"""Extensive tests for the language registry module.
These tests verify that language registration, lookup, and detection
work correctly.
@ -10,7 +9,7 @@ from pathlib import Path
import pytest
from codeflash.languages.base import Language, LanguageSupport
from codeflash.languages.base import Language
from codeflash.languages.registry import (
UnsupportedLanguageError,
clear_cache,
@ -28,7 +27,6 @@ from codeflash.languages.registry import (
def setup_registry():
"""Ensure PythonSupport is registered before each test."""
# Import to trigger registration
from codeflash.languages.python import PythonSupport
yield
# Clear cache after each test to avoid side effects
@ -92,9 +90,7 @@ class TestGetLanguageSupport:
"""Test that unsupported extension raises UnsupportedLanguageError."""
with pytest.raises(UnsupportedLanguageError) as exc_info:
get_language_support(Path("/test/example.xyz"))
assert "xyz" in str(exc_info.value.identifier) or "example.xyz" in str(
exc_info.value.identifier
)
assert "xyz" in str(exc_info.value.identifier) or "example.xyz" in str(exc_info.value.identifier)
def test_unsupported_language_raises(self):
"""Test that unsupported language name raises UnsupportedLanguageError."""
@ -175,9 +171,8 @@ class TestDetectProjectLanguage:
def test_detect_empty_project_raises(self):
"""Test that empty project raises UnsupportedLanguageError."""
with tempfile.TemporaryDirectory() as tmpdir:
with pytest.raises(UnsupportedLanguageError):
detect_project_language(Path(tmpdir), Path(tmpdir))
with tempfile.TemporaryDirectory() as tmpdir, pytest.raises(UnsupportedLanguageError):
detect_project_language(Path(tmpdir), Path(tmpdir))
def test_detect_with_different_roots(self):
"""Test detection with different project and module roots."""

View file

@ -1,20 +1,14 @@
"""
Extensive tests for the tree-sitter utilities module.
"""Extensive tests for the tree-sitter utilities module.
These tests verify that the TreeSitterAnalyzer correctly parses and
analyzes JavaScript/TypeScript code.
"""
from pathlib import Path
import pytest
from codeflash.languages.treesitter_utils import (
FunctionNode,
ImportInfo,
TreeSitterAnalyzer,
TreeSitterLanguage,
get_analyzer_for_file,
)
from pathlib import Path
from codeflash.languages.treesitter_utils import TreeSitterAnalyzer, TreeSitterLanguage, get_analyzer_for_file
class TestTreeSitterLanguage:

View file

@ -1,4 +1,5 @@
from pathlib import Path
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.models.models import CodeOptimizationContext, CodeStringsMarkdown
from codeflash.optimization.function_optimizer import FunctionOptimizer
@ -9,11 +10,13 @@ class Args:
disable_imports_sorting = True
formatter_cmds = ["disabled"]
def test_multi_file_replcement01() -> None:
root_dir = Path(__file__).parent.parent.resolve()
helper_file = (root_dir / "code_to_optimize/temp_helper.py").resolve()
helper_file.write_text("""import re
helper_file.write_text(
"""import re
from collections.abc import Sequence
from pydantic_ai_slim.pydantic_ai.messages import BinaryContent, UserContent
@ -36,7 +39,9 @@ def _estimate_string_tokens(content: str | Sequence[UserContent]) -> int:
# TODO(Marcelo): We need to study how we can estimate the tokens for AudioUrl or ImageUrl.
return tokens
""", encoding="utf-8")
""",
encoding="utf-8",
)
main_file = (root_dir / "code_to_optimize/temp_main.py").resolve()
@ -93,8 +98,6 @@ def _get_string_usage(text: str) -> Usage:
```
"""
func = FunctionToOptimize(function_name="_get_string_usage", parents=[], file_path=main_file)
test_config = TestConfig(
tests_root=root_dir / "tests/pytest",
@ -106,8 +109,6 @@ def _get_string_usage(text: str) -> Usage:
func_optimizer = FunctionOptimizer(function_to_optimize=func, test_cfg=test_config)
code_context: CodeOptimizationContext = func_optimizer.get_code_optimization_context().unwrap()
original_helper_code: dict[Path, str] = {}
helper_function_paths = {hf.file_path for hf in code_context.helper_functions}
for helper_function_path in helper_function_paths:
@ -117,11 +118,13 @@ def _get_string_usage(text: str) -> Usage:
func_optimizer.args = Args()
func_optimizer.replace_function_and_helpers_with_optimized_code(
code_context=code_context, optimized_code=CodeStringsMarkdown.parse_markdown_code(optimized_code), original_helper_code=original_helper_code
code_context=code_context,
optimized_code=CodeStringsMarkdown.parse_markdown_code(optimized_code),
original_helper_code=original_helper_code,
)
new_code = main_file.read_text(encoding="utf-8")
new_helper_code = helper_file.read_text(encoding="utf-8")
helper_file.unlink(missing_ok=True)
main_file.unlink(missing_ok=True)
@ -160,5 +163,5 @@ def _estimate_string_tokens(content: str | Sequence[UserContent]) -> int:
return tokens
"""
assert new_code.rstrip() == original_main.rstrip() # No Change
assert new_helper_code.rstrip() == expected_helper.rstrip()
assert new_code.rstrip() == original_main.rstrip() # No Change
assert new_helper_code.rstrip() == expected_helper.rstrip()

View file

@ -1,9 +1,8 @@
from codeflash.verification.parse_test_output import parse_test_failures_from_stdout
def test_extracting_single_pytest_error_from_stdout():
stdout = '''
stdout = """
F... [100%]
=================================== FAILURES ===================================
_______________________ test_calculate_portfolio_metrics _______________________
@ -38,11 +37,13 @@ FAILED code_to_optimize/tests/pytest/test_multiple_helpers.py::test_calculate_po
1 failed, 3 passed in 0.15s
'''
"""
errors = parse_test_failures_from_stdout(stdout)
assert errors
assert len(errors.keys()) == 1
assert errors['test_calculate_portfolio_metrics'] == '''
assert (
errors["test_calculate_portfolio_metrics"]
== """
def test_calculate_portfolio_metrics():
# Test case 1: Basic portfolio
investments = [
@ -68,13 +69,15 @@ E assert 4.109589046841222e-08 < 1e-10
E + where 4.109589046841222e-08 = abs((0.890411 - 0.8904109589041095))
code_to_optimize/tests/pytest/test_multiple_helpers.py:26: AssertionError
'''
"""
)
def test_extracting_no_pytest_failures():
stdout = '''
stdout = """
.... [100%]
4 passed in 0.12s
'''
"""
errors = parse_test_failures_from_stdout(stdout)
assert errors == {}
@ -82,7 +85,7 @@ def test_extracting_no_pytest_failures():
def test_extracting_multiple_pytest_failures_with_class_method():
print("hi")
stdout = '''
stdout = """
F.F [100%]
=================================== FAILURES ===================================
________________________ test_simple_failure ________________________
@ -105,38 +108,42 @@ code_to_optimize/tests/test_calculator.py:22: ZeroDivisionError
FAILED code_to_optimize/tests/test_simple.py::test_simple_failure
FAILED code_to_optimize/tests/test_calculator.py::TestCalculator::test_divide_by_zero
2 failed, 1 passed in 0.18s
'''
"""
errors = parse_test_failures_from_stdout(stdout)
print(errors)
assert len(errors) == 2
assert 'test_simple_failure' in errors
assert errors['test_simple_failure'] == '''
assert "test_simple_failure" in errors
assert (
errors["test_simple_failure"]
== """
def test_simple_failure():
x = 1 + 1
> assert x == 3
E assert 2 == 3
code_to_optimize/tests/test_simple.py:10: AssertionError
'''
"""
)
assert 'TestCalculator.test_divide_by_zero' in errors
assert '''
assert "TestCalculator.test_divide_by_zero" in errors
assert errors["TestCalculator.test_divide_by_zero"] == """
class TestCalculator:
def test_divide_by_zero(self):
> Calculator().divide(10, 0)
E ZeroDivisionError: division by zero
code_to_optimize/tests/test_calculator.py:22: ZeroDivisionError
''' == errors['TestCalculator.test_divide_by_zero']
"""
def test_extracting_from_invalid_pytest_stdout():
stdout = '''
stdout = """
Running tests...
Everything seems fine
No structured output here
Just some random logs
'''
"""
errors = parse_test_failures_from_stdout(stdout)
assert errors == {}

View file

@ -3,6 +3,7 @@ import pickle
import shutil
import socket
import sqlite3
import time
from argparse import Namespace
from pathlib import Path
@ -18,7 +19,6 @@ from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.models.models import CodePosition, TestFile, TestFiles, TestingMode, TestsInFile, TestType
from codeflash.optimization.optimizer import Optimizer
from codeflash.verification.equivalence import compare_test_results
import time
try:
import sqlalchemy
@ -35,12 +35,8 @@ from codeflash.picklepatch.pickle_placeholder import PicklePlaceholder, PicklePl
def test_picklepatch_simple_nested():
"""Test that a simple nested data structure pickles and unpickles correctly.
"""
original_data = {
"numbers": [1, 2, 3],
"nested_dict": {"key": "value", "another": 42},
}
"""Test that a simple nested data structure pickles and unpickles correctly."""
original_data = {"numbers": [1, 2, 3], "nested_dict": {"key": "value", "another": 42}}
dumped = PicklePatcher.dumps(original_data)
reloaded = PicklePatcher.loads(dumped)
@ -56,10 +52,7 @@ def test_picklepatch_with_socket():
# Create a pair of connected sockets instead of a single socket
sock1, sock2 = socket.socketpair()
data_with_socket = {
"safe_value": 123,
"raw_socket": sock1,
}
data_with_socket = {"safe_value": 123, "raw_socket": sock1}
# Send a message through sock1, which can be received by sock2
sock1.send(b"Hello, world!")
@ -85,17 +78,11 @@ def test_picklepatch_with_socket():
def test_picklepatch_deeply_nested():
"""Test that deep nesting with unpicklable objects works correctly.
"""
"""Test that deep nesting with unpicklable objects works correctly."""
# Create a deeply nested structure with an unpicklable object
deep_nested = {
"level1": {
"level2": {
"level3": {
"normal": "value",
"socket": socket.socket(socket.AF_INET, socket.SOCK_STREAM)
}
}
"level2": {"level3": {"normal": "value", "socket": socket.socket(socket.AF_INET, socket.SOCK_STREAM)}}
}
}
@ -108,9 +95,10 @@ def test_picklepatch_deeply_nested():
# The socket should be replaced with a placeholder
assert isinstance(reloaded["level1"]["level2"]["level3"]["socket"], PicklePlaceholder)
def test_picklepatch_class_with_unpicklable_attr():
"""Test that a class with an unpicklable attribute works correctly.
"""
"""Test that a class with an unpicklable attribute works correctly."""
class TestClass:
def __init__(self):
self.normal = "normal value"
@ -128,8 +116,6 @@ def test_picklepatch_class_with_unpicklable_attr():
assert isinstance(reloaded.unpicklable, PicklePlaceholder)
def test_picklepatch_with_database_connection():
"""Test that a data structure containing a database connection is replaced
by PicklePlaceholder rather than raising an error.
@ -138,11 +124,7 @@ def test_picklepatch_with_database_connection():
conn = sqlite3.connect(":memory:")
cursor = conn.cursor()
data_with_db = {
"description": "Database connection",
"connection": conn,
"cursor": cursor,
}
data_with_db = {"description": "Database connection", "connection": conn, "cursor": cursor}
dumped = PicklePatcher.dumps(data_with_db)
reloaded = PicklePatcher.loads(dumped)
@ -158,7 +140,7 @@ def test_picklepatch_with_database_connection():
reloaded["connection"].execute("SELECT 1")
cursor.close()
conn.close()
conn.close()
def test_picklepatch_with_generator():
@ -175,11 +157,7 @@ def test_picklepatch_with_generator():
gen = simple_generator()
# Put it in a data structure
data_with_generator = {
"description": "Contains a generator",
"generator": gen,
"normal_list": [1, 2, 3]
}
data_with_generator = {"description": "Contains a generator", "generator": gen, "normal_list": [1, 2, 3]}
dumped = PicklePatcher.dumps(data_with_generator)
reloaded = PicklePatcher.loads(dumped)
@ -204,11 +182,7 @@ def test_picklepatch_loads_standard_pickle():
using the standard pickle module.
"""
# Create a simple data structure
original_data = {
"numbers": [1, 2, 3],
"nested_dict": {"key": "value", "another": 42},
"tuple": (1, "two", 3.0),
}
original_data = {"numbers": [1, 2, 3], "nested_dict": {"key": "value", "another": 42}, "tuple": (1, "two", 3.0)}
# Pickle it with standard pickle
pickled_data = pickle.dumps(original_data)
@ -231,13 +205,7 @@ def test_picklepatch_loads_dill_pickle():
"""
# Create a more complex data structure that includes a lambda function
# which dill can handle but standard pickle cannot
original_data = {
"numbers": [1, 2, 3],
"function": lambda x: x * 2,
"nested": {
"another_function": lambda y: y ** 2
}
}
original_data = {"numbers": [1, 2, 3], "function": lambda x: x * 2, "nested": {"another_function": lambda y: y**2}}
# Pickle it with dill
dilled_data = dill.dumps(original_data)
@ -253,6 +221,7 @@ def test_picklepatch_loads_dill_pickle():
assert reloaded["function"](5) == 10
assert reloaded["nested"]["another_function"](4) == 16
def test_run_and_parse_picklepatch() -> None:
"""Test the end to end functionality of picklepatch, from tracing benchmarks to running the replay tests.
@ -269,7 +238,9 @@ def test_run_and_parse_picklepatch() -> None:
benchmarks_root = project_root / "code_to_optimize" / "tests" / "pytest" / "benchmarks_socket_test"
replay_tests_dir = benchmarks_root / "codeflash_replay_tests"
output_file = (benchmarks_root / Path("test_trace_benchmarks.trace")).resolve()
fto_unused_socket_path = (project_root / "code_to_optimize" / "bubble_sort_picklepatch_test_unused_socket.py").resolve()
fto_unused_socket_path = (
project_root / "code_to_optimize" / "bubble_sort_picklepatch_test_unused_socket.py"
).resolve()
fto_used_socket_path = (project_root / "code_to_optimize" / "bubble_sort_picklepatch_test_used_socket.py").resolve()
original_fto_unused_socket_code = fto_unused_socket_path.read_text("utf-8")
original_fto_used_socket_code = fto_used_socket_path.read_text("utf-8")
@ -282,7 +253,8 @@ def test_run_and_parse_picklepatch() -> None:
cursor = conn.cursor()
cursor.execute(
"SELECT function_name, class_name, module_name, file_path, benchmark_function_name, benchmark_module_path, benchmark_line_number FROM benchmark_function_timings ORDER BY benchmark_module_path, benchmark_function_name, function_name")
"SELECT function_name, class_name, module_name, file_path, benchmark_function_name, benchmark_module_path, benchmark_line_number FROM benchmark_function_timings ORDER BY benchmark_module_path, benchmark_function_name, function_name"
)
function_calls = cursor.fetchall()
# Assert the length of function calls
@ -290,38 +262,61 @@ def test_run_and_parse_picklepatch() -> None:
function_benchmark_timings = codeflash_benchmark_plugin.get_function_benchmark_timings(output_file)
total_benchmark_timings = codeflash_benchmark_plugin.get_benchmark_timings(output_file)
function_to_results = validate_and_format_benchmark_table(function_benchmark_timings, total_benchmark_timings)
assert "code_to_optimize.bubble_sort_picklepatch_test_unused_socket.bubble_sort_with_unused_socket" in function_to_results
assert (
"code_to_optimize.bubble_sort_picklepatch_test_unused_socket.bubble_sort_with_unused_socket"
in function_to_results
)
# Close the connection to allow file cleanup on Windows
conn.close()
time.sleep(1)
# Handle the case where function runs too fast to be measured
unused_socket_results = function_to_results["code_to_optimize.bubble_sort_picklepatch_test_unused_socket.bubble_sort_with_unused_socket"]
unused_socket_results = function_to_results[
"code_to_optimize.bubble_sort_picklepatch_test_unused_socket.bubble_sort_with_unused_socket"
]
if unused_socket_results:
test_name, total_time, function_time, percent = unused_socket_results[0]
assert total_time >= 0.0
# Function might be too fast, so we allow 0.0 function_time
assert function_time >= 0.0
assert percent >= 0.0
used_socket_results = function_to_results["code_to_optimize.bubble_sort_picklepatch_test_used_socket.bubble_sort_with_used_socket"]
used_socket_results = function_to_results[
"code_to_optimize.bubble_sort_picklepatch_test_used_socket.bubble_sort_with_used_socket"
]
# on windows , if the socket is not used we might not have resultssss
if used_socket_results:
test_name, total_time, function_time, percent = used_socket_results[0]
assert total_time >= 0.0
assert function_time >= 0.0
assert function_time >= 0.0
assert percent >= 0.0
bubble_sort_unused_socket_path = (project_root / "code_to_optimize"/ "bubble_sort_picklepatch_test_unused_socket.py").as_posix()
bubble_sort_used_socket_path = (project_root / "code_to_optimize" / "bubble_sort_picklepatch_test_used_socket.py").as_posix()
bubble_sort_unused_socket_path = (
project_root / "code_to_optimize" / "bubble_sort_picklepatch_test_unused_socket.py"
).as_posix()
bubble_sort_used_socket_path = (
project_root / "code_to_optimize" / "bubble_sort_picklepatch_test_used_socket.py"
).as_posix()
# Expected function calls
expected_calls = [
("bubble_sort_with_unused_socket", "", "code_to_optimize.bubble_sort_picklepatch_test_unused_socket",
f"{bubble_sort_unused_socket_path}",
"test_socket_picklepatch", "code_to_optimize.tests.pytest.benchmarks_socket_test.test_socket", 12),
("bubble_sort_with_used_socket", "", "code_to_optimize.bubble_sort_picklepatch_test_used_socket",
f"{bubble_sort_used_socket_path}",
"test_used_socket_picklepatch", "code_to_optimize.tests.pytest.benchmarks_socket_test.test_socket", 20)
(
"bubble_sort_with_unused_socket",
"",
"code_to_optimize.bubble_sort_picklepatch_test_unused_socket",
f"{bubble_sort_unused_socket_path}",
"test_socket_picklepatch",
"code_to_optimize.tests.pytest.benchmarks_socket_test.test_socket",
12,
),
(
"bubble_sort_with_used_socket",
"",
"code_to_optimize.bubble_sort_picklepatch_test_used_socket",
f"{bubble_sort_used_socket_path}",
"test_used_socket_picklepatch",
"code_to_optimize.tests.pytest.benchmarks_socket_test.test_socket",
20,
),
]
for idx, (actual, expected) in enumerate(zip(function_calls, expected_calls)):
assert actual[0] == expected[0], f"Mismatch at index {idx} for function_name"
@ -332,29 +327,29 @@ def test_run_and_parse_picklepatch() -> None:
assert actual[5] == expected[5], f"Mismatch at index {idx} for benchmark_module_path"
assert actual[6] == expected[6], f"Mismatch at index {idx} for benchmark_line_number"
conn.close()
time.sleep(1)
# Generate replay test
generate_replay_test(output_file, replay_tests_dir)
replay_test_path = replay_tests_dir / Path(
"test_code_to_optimize_tests_pytest_benchmarks_socket_test_test_socket__replay_test_0.py")
"test_code_to_optimize_tests_pytest_benchmarks_socket_test_test_socket__replay_test_0.py"
)
replay_test_perf_path = replay_tests_dir / Path(
"test_code_to_optimize_tests_pytest_benchmarks_socket_test_test_socket__replay_test_0_perf.py")
"test_code_to_optimize_tests_pytest_benchmarks_socket_test_test_socket__replay_test_0_perf.py"
)
assert replay_test_path.exists()
original_replay_test_code = replay_test_path.read_text("utf-8")
# Instrument the replay test
func = FunctionToOptimize(function_name="bubble_sort_with_unused_socket", parents=[], file_path=Path(fto_unused_socket_path))
func = FunctionToOptimize(
function_name="bubble_sort_with_unused_socket", parents=[], file_path=Path(fto_unused_socket_path)
)
original_cwd = Path.cwd()
run_cwd = project_root
os.chdir(run_cwd)
success, new_test = inject_profiling_into_existing_test(
replay_test_path,
[CodePosition(17, 15)],
func,
project_root,
mode=TestingMode.BEHAVIOR,
replay_test_path, [CodePosition(17, 15)], func, project_root, mode=TestingMode.BEHAVIOR
)
os.chdir(original_cwd)
assert success
@ -386,7 +381,14 @@ def test_run_and_parse_picklepatch() -> None:
test_type=test_type,
original_file_path=replay_test_path,
benchmarking_file_path=replay_test_perf_path,
tests_in_file=[TestsInFile(test_file=replay_test_path, test_class=None, test_function=replay_test_function, test_type=test_type)],
tests_in_file=[
TestsInFile(
test_file=replay_test_path,
test_class=None,
test_function=replay_test_function,
test_type=test_type,
)
],
)
]
)
@ -400,8 +402,14 @@ def test_run_and_parse_picklepatch() -> None:
testing_time=1.0,
)
assert len(test_results_unused_socket) == 1
assert test_results_unused_socket.test_results[0].id.test_module_path == "code_to_optimize.tests.pytest.benchmarks_socket_test.codeflash_replay_tests.test_code_to_optimize_tests_pytest_benchmarks_socket_test_test_socket__replay_test_0"
assert test_results_unused_socket.test_results[0].id.test_function_name == "test_code_to_optimize_bubble_sort_picklepatch_test_unused_socket_bubble_sort_with_unused_socket_test_socket_picklepatch"
assert (
test_results_unused_socket.test_results[0].id.test_module_path
== "code_to_optimize.tests.pytest.benchmarks_socket_test.codeflash_replay_tests.test_code_to_optimize_tests_pytest_benchmarks_socket_test_test_socket__replay_test_0"
)
assert (
test_results_unused_socket.test_results[0].id.test_function_name
== "test_code_to_optimize_bubble_sort_picklepatch_test_unused_socket_bubble_sort_with_unused_socket_test_socket_picklepatch"
)
assert test_results_unused_socket.test_results[0].did_pass == True
# Replace with optimized candidate
@ -431,13 +439,11 @@ def bubble_sort_with_unused_socket(data_container):
# Remove the previous instrumentation
replay_test_path.write_text(original_replay_test_code)
# Instrument the replay test
func = FunctionToOptimize(function_name="bubble_sort_with_used_socket", parents=[], file_path=Path(fto_used_socket_path))
func = FunctionToOptimize(
function_name="bubble_sort_with_used_socket", parents=[], file_path=Path(fto_used_socket_path)
)
success, new_test = inject_profiling_into_existing_test(
replay_test_path,
[CodePosition(23,15)],
func,
project_root,
mode=TestingMode.BEHAVIOR,
replay_test_path, [CodePosition(23, 15)], func, project_root, mode=TestingMode.BEHAVIOR
)
os.chdir(original_cwd)
assert success
@ -449,8 +455,9 @@ def bubble_sort_with_unused_socket(data_container):
test_env["CODEFLASH_TEST_ITERATION"] = "0"
test_env["CODEFLASH_LOOP_INDEX"] = "1"
test_type = TestType.REPLAY_TEST
func = FunctionToOptimize(function_name="bubble_sort_with_used_socket", parents=[],
file_path=Path(fto_used_socket_path))
func = FunctionToOptimize(
function_name="bubble_sort_with_used_socket", parents=[], file_path=Path(fto_used_socket_path)
)
replay_test_function = "test_code_to_optimize_bubble_sort_picklepatch_test_used_socket_bubble_sort_with_used_socket_test_used_socket_picklepatch"
func_optimizer = opt.create_function_optimizer(func)
func_optimizer.test_files = TestFiles(
@ -461,8 +468,13 @@ def bubble_sort_with_unused_socket(data_container):
original_file_path=replay_test_path,
benchmarking_file_path=replay_test_perf_path,
tests_in_file=[
TestsInFile(test_file=replay_test_path, test_class=None, test_function=replay_test_function,
test_type=test_type)],
TestsInFile(
test_file=replay_test_path,
test_class=None,
test_function=replay_test_function,
test_type=test_type,
)
],
)
]
)
@ -476,10 +488,14 @@ def bubble_sort_with_unused_socket(data_container):
testing_time=1.0,
)
assert len(test_results_used_socket) == 1
assert test_results_used_socket.test_results[
0].id.test_module_path == "code_to_optimize.tests.pytest.benchmarks_socket_test.codeflash_replay_tests.test_code_to_optimize_tests_pytest_benchmarks_socket_test_test_socket__replay_test_0"
assert test_results_used_socket.test_results[
0].id.test_function_name == "test_code_to_optimize_bubble_sort_picklepatch_test_used_socket_bubble_sort_with_used_socket_test_used_socket_picklepatch"
assert (
test_results_used_socket.test_results[0].id.test_module_path
== "code_to_optimize.tests.pytest.benchmarks_socket_test.codeflash_replay_tests.test_code_to_optimize_tests_pytest_benchmarks_socket_test_test_socket__replay_test_0"
)
assert (
test_results_used_socket.test_results[0].id.test_function_name
== "test_code_to_optimize_bubble_sort_picklepatch_test_used_socket_bubble_sort_with_used_socket_test_used_socket_picklepatch"
)
assert test_results_used_socket.test_results[0].did_pass is False
print("test results used socket")
print(test_results_used_socket)
@ -507,10 +523,14 @@ def bubble_sort_with_used_socket(data_container):
testing_time=1.0,
)
assert len(test_results_used_socket) == 1
assert test_results_used_socket.test_results[
0].id.test_module_path == "code_to_optimize.tests.pytest.benchmarks_socket_test.codeflash_replay_tests.test_code_to_optimize_tests_pytest_benchmarks_socket_test_test_socket__replay_test_0"
assert test_results_used_socket.test_results[
0].id.test_function_name == "test_code_to_optimize_bubble_sort_picklepatch_test_used_socket_bubble_sort_with_used_socket_test_used_socket_picklepatch"
assert (
test_results_used_socket.test_results[0].id.test_module_path
== "code_to_optimize.tests.pytest.benchmarks_socket_test.codeflash_replay_tests.test_code_to_optimize_tests_pytest_benchmarks_socket_test_test_socket__replay_test_0"
)
assert (
test_results_used_socket.test_results[0].id.test_function_name
== "test_code_to_optimize_bubble_sort_picklepatch_test_used_socket_bubble_sort_with_used_socket_test_used_socket_picklepatch"
)
assert test_results_used_socket.test_results[0].did_pass is False
# Even though tests threw the same error, we reject this as the behavior of the unpickleable object could not be determined.

View file

@ -1,6 +1,7 @@
from pathlib import Path
import pytest
from codeflash.code_utils.edit_generated_tests import remove_functions_from_generated_tests
from codeflash.models.models import GeneratedTests, GeneratedTestsList

View file

@ -1,14 +1,6 @@
import tempfile
from argparse import Namespace
from pathlib import Path
import libcst as cst
from codeflash.context.code_context_extractor import get_code_optimization_context
from codeflash.context.unused_definition_remover import remove_unused_definitions_by_function_names
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.models.models import FunctionParent
from codeflash.optimization.optimizer import Optimizer
def test_variable_removal_only() -> None:
@ -337,6 +329,7 @@ def unused_function():
result = remove_unused_definitions_by_function_names(code, qualified_functions)
assert result.strip() == expected.strip()
def test_base_class_inheritance() -> None:
"""Test that base classes used only for inheritance are preserved."""
code = """

Some files were not shown because too many files have changed in this diff Show more