mirror of
https://github.com/codeflash-ai/codeflash.git
synced 2026-05-04 18:25:17 +00:00
now enable instrumentation for helper classes
This commit is contained in:
parent
0aa7ca4ea4
commit
7a37e6e0eb
15 changed files with 722 additions and 450 deletions
|
|
@ -1,7 +1,6 @@
|
|||
|
||||
class BubbleSorter:
|
||||
def __init__(self):
|
||||
self.x = 1
|
||||
def __init__(self, x=0):
|
||||
self.x = x
|
||||
|
||||
def sorter(self, arr):
|
||||
for i in range(len(arr)):
|
||||
|
|
@ -11,5 +10,3 @@ class BubbleSorter:
|
|||
arr[j] = arr[j + 1]
|
||||
arr[j + 1] = temp
|
||||
return arr
|
||||
|
||||
|
||||
|
|
@ -119,7 +119,7 @@ class OptimFunctionReplacer(cst.CSTTransformer):
|
|||
cst_to_code(self.modified_init_functions[self.current_class])
|
||||
):
|
||||
return original_node # Code was unchanged, so don't modify docstrings / comments
|
||||
return merge_init_functions(updated_node, self.modified_init_functions[self.current_class])
|
||||
return self.modified_init_functions[self.current_class]
|
||||
|
||||
return updated_node
|
||||
|
||||
|
|
@ -162,105 +162,105 @@ class OptimFunctionReplacer(cst.CSTTransformer):
|
|||
return node
|
||||
|
||||
|
||||
class AttributeCollector(cst.CSTVisitor):
|
||||
"""Collects all self.attribute mentions in a CST."""
|
||||
# class AttributeCollector(cst.CSTVisitor):
|
||||
# """Collects all self.attribute mentions in a CST."""
|
||||
#
|
||||
# def __init__(self):
|
||||
# super().__init__()
|
||||
# self.attributes: set[str] = set()
|
||||
#
|
||||
# def visit_Attribute(self, node: cst.Attribute) -> bool:
|
||||
# """Record any self.attribute access."""
|
||||
# if isinstance(node.value, cst.Name) and node.value.value == "self":
|
||||
# self.attributes.add(node.attr.value)
|
||||
# return True
|
||||
#
|
||||
#
|
||||
# class AssignmentCollector(cst.CSTVisitor):
|
||||
# """Collects attributes being assigned to in a CST."""
|
||||
#
|
||||
# def __init__(self):
|
||||
# super().__init__()
|
||||
# self.assigned_attrs: set[str] = set()
|
||||
#
|
||||
# def visit_Assign(self, node: cst.Assign) -> bool:
|
||||
# """Check regular assignments like self.x = ..."""
|
||||
# for target in node.targets:
|
||||
# if (
|
||||
# isinstance(target.target, cst.Attribute)
|
||||
# and isinstance(target.target.value, cst.Name)
|
||||
# and target.target.value.value == "self"
|
||||
# ):
|
||||
# self.assigned_attrs.add(target.target.attr.value)
|
||||
# return True
|
||||
#
|
||||
# def visit_AnnAssign(self, node: cst.AnnAssign) -> bool:
|
||||
# """Check annotated assignments like self.x: str = ..."""
|
||||
# if (
|
||||
# isinstance(node.target, cst.Attribute)
|
||||
# and isinstance(node.target.value, cst.Name)
|
||||
# and node.target.value.value == "self"
|
||||
# ):
|
||||
# self.assigned_attrs.add(node.target.attr.value)
|
||||
# return True
|
||||
#
|
||||
# def visit_AugAssign(self, node: cst.AugAssign) -> bool:
|
||||
# """Check augmented assignments like self.x += ..."""
|
||||
# if (
|
||||
# isinstance(node.target, cst.Attribute)
|
||||
# and isinstance(node.target.value, cst.Name)
|
||||
# and node.target.value.value == "self"
|
||||
# ):
|
||||
# self.assigned_attrs.add(node.target.attr.value)
|
||||
# return True
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.attributes: set[str] = set()
|
||||
|
||||
def visit_Attribute(self, node: cst.Attribute) -> bool:
|
||||
"""Record any self.attribute access."""
|
||||
if isinstance(node.value, cst.Name) and node.value.value == "self":
|
||||
self.attributes.add(node.attr.value)
|
||||
return True
|
||||
|
||||
|
||||
class AssignmentCollector(cst.CSTVisitor):
|
||||
"""Collects attributes being assigned to in a CST."""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.assigned_attrs: set[str] = set()
|
||||
|
||||
def visit_Assign(self, node: cst.Assign) -> bool:
|
||||
"""Check regular assignments like self.x = ..."""
|
||||
for target in node.targets:
|
||||
if (
|
||||
isinstance(target.target, cst.Attribute)
|
||||
and isinstance(target.target.value, cst.Name)
|
||||
and target.target.value.value == "self"
|
||||
):
|
||||
self.assigned_attrs.add(target.target.attr.value)
|
||||
return True
|
||||
|
||||
def visit_AnnAssign(self, node: cst.AnnAssign) -> bool:
|
||||
"""Check annotated assignments like self.x: str = ..."""
|
||||
if (
|
||||
isinstance(node.target, cst.Attribute)
|
||||
and isinstance(node.target.value, cst.Name)
|
||||
and node.target.value.value == "self"
|
||||
):
|
||||
self.assigned_attrs.add(node.target.attr.value)
|
||||
return True
|
||||
|
||||
def visit_AugAssign(self, node: cst.AugAssign) -> bool:
|
||||
"""Check augmented assignments like self.x += ..."""
|
||||
if (
|
||||
isinstance(node.target, cst.Attribute)
|
||||
and isinstance(node.target.value, cst.Name)
|
||||
and node.target.value.value == "self"
|
||||
):
|
||||
self.assigned_attrs.add(node.target.attr.value)
|
||||
return True
|
||||
|
||||
|
||||
def merge_init_functions(original_init: cst.FunctionDef, new_init: cst.FunctionDef) -> cst.FunctionDef:
|
||||
"""Merges two __init__ function definitions. Collects all self.attribute mentions
|
||||
from the original init, then filters out statements from the new init that
|
||||
assign to those attributes (but allows reading them).
|
||||
|
||||
Args:
|
||||
original_init: The original __init__ function to preserve
|
||||
new_init: The new __init__ function whose body will be filtered and appended
|
||||
|
||||
Returns:
|
||||
A merged FunctionDef
|
||||
|
||||
"""
|
||||
# Collect all self.attribute mentions from original init
|
||||
collector = AttributeCollector()
|
||||
original_init.visit(collector)
|
||||
existing_attrs = collector.attributes
|
||||
# Get set of existing statements, without comments
|
||||
original_stmts = {get_only_code_content(cst_to_code(stmt)) for stmt in original_init.body.body}
|
||||
# Filter new init body statements
|
||||
filtered_body = []
|
||||
|
||||
for stmt in new_init.body.body:
|
||||
# Filter out docstring of new init
|
||||
if (
|
||||
isinstance(stmt, cst.SimpleStatementLine)
|
||||
and len(stmt.body) == 1
|
||||
and isinstance(stmt.body[0], cst.Expr)
|
||||
and isinstance(stmt.body[0].value, cst.SimpleString)
|
||||
):
|
||||
continue
|
||||
# Filter out duplicate statements
|
||||
if get_only_code_content(cst_to_code(stmt)) in original_stmts:
|
||||
continue
|
||||
# Check for assignments to existing attributes
|
||||
assign_collector = AssignmentCollector()
|
||||
stmt.visit(assign_collector)
|
||||
|
||||
# Keep statement if it doesn't assign to any existing attributes
|
||||
if not assign_collector.assigned_attrs.intersection(existing_attrs):
|
||||
filtered_body.append(stmt)
|
||||
|
||||
# Merge bodies using with_changes
|
||||
return original_init.with_changes(
|
||||
body=original_init.body.with_changes(body=original_init.body.body + tuple(filtered_body))
|
||||
)
|
||||
#
|
||||
# def merge_init_functions(original_init: cst.FunctionDef, new_init: cst.FunctionDef) -> cst.FunctionDef:
|
||||
# """Merges two __init__ function definitions. Collects all self.attribute mentions
|
||||
# from the original init, then filters out statements from the new init that
|
||||
# assign to those attributes (but allows reading them).
|
||||
#
|
||||
# Args:
|
||||
# original_init: The original __init__ function to preserve
|
||||
# new_init: The new __init__ function whose body will be filtered and appended
|
||||
#
|
||||
# Returns:
|
||||
# A merged FunctionDef
|
||||
#
|
||||
# """
|
||||
# # Collect all self.attribute mentions from original init
|
||||
# collector = AttributeCollector()
|
||||
# original_init.visit(collector)
|
||||
# existing_attrs = collector.attributes
|
||||
# # Get set of existing statements, without comments
|
||||
# original_stmts = {get_only_code_content(cst_to_code(stmt)) for stmt in original_init.body.body}
|
||||
# # Filter new init body statements
|
||||
# filtered_body = []
|
||||
#
|
||||
# for stmt in new_init.body.body:
|
||||
# # Filter out docstring of new init
|
||||
# if (
|
||||
# isinstance(stmt, cst.SimpleStatementLine)
|
||||
# and len(stmt.body) == 1
|
||||
# and isinstance(stmt.body[0], cst.Expr)
|
||||
# and isinstance(stmt.body[0].value, cst.SimpleString)
|
||||
# ):
|
||||
# continue
|
||||
# # Filter out duplicate statements
|
||||
# if get_only_code_content(cst_to_code(stmt)) in original_stmts:
|
||||
# continue
|
||||
# # Check for assignments to existing attributes
|
||||
# assign_collector = AssignmentCollector()
|
||||
# stmt.visit(assign_collector)
|
||||
#
|
||||
# # Keep statement if it doesn't assign to any existing attributes
|
||||
# if not assign_collector.assigned_attrs.intersection(existing_attrs):
|
||||
# filtered_body.append(stmt)
|
||||
#
|
||||
# # Merge bodies using with_changes
|
||||
# return original_init.with_changes(
|
||||
# body=original_init.body.with_changes(body=original_init.body.body + tuple(filtered_body))
|
||||
# )
|
||||
|
||||
|
||||
def replace_functions_in_file(
|
||||
|
|
|
|||
|
|
@ -9,7 +9,8 @@ import isort
|
|||
from codeflash.cli_cmds.console import logger
|
||||
from codeflash.code_utils.code_utils import get_run_tmp_file, module_name_from_file_path
|
||||
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
|
||||
from codeflash.models.models import FunctionParent, TestingMode, VerificationType
|
||||
from codeflash.models.models import FunctionParent, TestingMode
|
||||
from codeflash.verification.test_results import VerificationType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Iterable
|
||||
|
|
|
|||
|
|
@ -70,6 +70,7 @@ def get_code_optimization_context(
|
|||
read_only_context_code=read_only_code_markdown.markdown,
|
||||
helper_functions=helpers_of_fto_obj_list,
|
||||
preexisting_objects=preexisting_objects,
|
||||
|
||||
)
|
||||
|
||||
logger.debug("Code context has exceeded token limit, removing docstrings from read-only code")
|
||||
|
|
|
|||
|
|
@ -84,11 +84,6 @@ class CodeOptimizationContext(BaseModel):
|
|||
preexisting_objects: list[tuple[str, list[FunctionParent]]]
|
||||
|
||||
|
||||
class VerificationType(str, Enum):
|
||||
FUNCTION_TO_OPTIMIZE = "function_to_optimize"
|
||||
INSTANCE_STATE = "instance_state"
|
||||
|
||||
|
||||
class OptimizedCandidateResult(BaseModel):
|
||||
max_loop_count: int
|
||||
best_test_runtime: int
|
||||
|
|
|
|||
|
|
@ -337,8 +337,15 @@ class Optimizer:
|
|||
)
|
||||
|
||||
# Instrument code
|
||||
original_code = validated_original_code[function_to_optimize.file_path].source_code
|
||||
instrument_code(function_to_optimize)
|
||||
# Get a dict of file_path_to_classes of fto and helpers_of_fto
|
||||
file_path_to_helper_classes = defaultdict(set)
|
||||
for function_source in code_context.helper_functions:
|
||||
if (
|
||||
function_source.qualified_name != function_to_optimize.qualified_name
|
||||
and "." in function_source.qualified_name
|
||||
):
|
||||
file_path_to_helper_classes[function_source.file_path].add(function_source.qualified_name.split(".")[0])
|
||||
instrument_code(function_to_optimize, file_path_to_helper_classes)
|
||||
|
||||
baseline_result = self.establish_original_code_baseline( # this needs better typing
|
||||
function_name=function_to_optimize_qualified_name,
|
||||
|
|
@ -347,7 +354,11 @@ class Optimizer:
|
|||
)
|
||||
|
||||
# Remove instrumentation
|
||||
self.write_code_and_helpers(original_code, {}, function_to_optimize.file_path)
|
||||
self.write_code_and_helpers(
|
||||
validated_original_code[function_to_optimize.file_path].source_code,
|
||||
original_helper_code,
|
||||
function_to_optimize.file_path,
|
||||
)
|
||||
|
||||
console.rule()
|
||||
paths_to_cleanup = (
|
||||
|
|
@ -514,7 +525,6 @@ class Optimizer:
|
|||
optimized_code=candidate.source_code,
|
||||
qualified_function_name=function_to_optimize.qualified_name,
|
||||
)
|
||||
# If init was modified, instrument the code with codeflash capture
|
||||
|
||||
if not did_update:
|
||||
logger.warning(
|
||||
|
|
@ -527,6 +537,8 @@ class Optimizer:
|
|||
self.write_code_and_helpers(original_code, original_helper_code, function_to_optimize.file_path)
|
||||
continue
|
||||
|
||||
# Instrument codeflash capture
|
||||
|
||||
run_results = self.run_optimized_candidate(
|
||||
optimization_candidate_index=candidate_index, baseline_results=original_code_baseline
|
||||
)
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ import time
|
|||
|
||||
import dill as pickle
|
||||
|
||||
from codeflash.models.models import VerificationType
|
||||
from codeflash.verification.test_results import VerificationType
|
||||
|
||||
|
||||
def get_test_info_from_stack() -> tuple[str, str | None, str, str]:
|
||||
|
|
@ -37,7 +37,7 @@ def get_test_info_from_stack() -> tuple[str, str | None, str, str]:
|
|||
return test_module_name, test_class_name, test_name, line_id
|
||||
|
||||
|
||||
def codeflash_capture(function_name: str, tmp_dir_path: str):
|
||||
def codeflash_capture(function_name: str, tmp_dir_path: str, is_fto: bool = False):
|
||||
"""Defines decorator to be instrumented onto the init function in the code. Collects info of the test that called this, and captures the state of the instance."""
|
||||
|
||||
def decorator(wrapped):
|
||||
|
|
@ -120,7 +120,7 @@ def codeflash_capture(function_name: str, tmp_dir_path: str):
|
|||
invocation_id,
|
||||
codeflash_duration,
|
||||
pickled_return_value,
|
||||
VerificationType.INSTANCE_STATE,
|
||||
VerificationType.INSTANCE_STATE_FTO if is_fto else VerificationType.INSTANCE_STATE_HELPER,
|
||||
),
|
||||
)
|
||||
codeflash_con.commit()
|
||||
|
|
|
|||
|
|
@ -23,7 +23,7 @@ def compare_test_results(original_results: TestResults, candidate_results: TestR
|
|||
cdd_test_result = candidate_results.get_by_unique_invocation_loop_id(test_id)
|
||||
if cdd_test_result is not None and original_test_result is None:
|
||||
continue
|
||||
|
||||
# If helper function instance_state verification is not present, that's ok. continue
|
||||
if original_test_result is None or cdd_test_result is None:
|
||||
are_equal = False
|
||||
break
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import ast
|
||||
from pathlib import Path
|
||||
|
||||
|
|
@ -5,38 +7,56 @@ from codeflash.code_utils.code_utils import get_run_tmp_file
|
|||
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
|
||||
|
||||
|
||||
def instrument_code(function_to_optimize: FunctionToOptimize) -> None:
|
||||
def instrument_code(function_to_optimize: FunctionToOptimize, file_path_to_helper_class: dict[Path, set[str]]) -> None:
|
||||
"""Instrument __init__ function with codeflash_capture decorator if it's in a class."""
|
||||
# Find the class parent
|
||||
if len(function_to_optimize.parents) == 1 and function_to_optimize.parents[0].type == "ClassDef":
|
||||
class_parent = function_to_optimize.parents[0]
|
||||
else:
|
||||
return
|
||||
|
||||
# Read code from file
|
||||
# Remove duplicate fto class from helper classes
|
||||
if function_to_optimize.file_path in file_path_to_helper_class:
|
||||
file_path_to_helper_class[function_to_optimize.file_path].remove(class_parent.name)
|
||||
# Instrument fto class
|
||||
with open(function_to_optimize.file_path) as f:
|
||||
original_code = f.read()
|
||||
|
||||
# Add decorator to init
|
||||
modified_code = add_codeflash_capture_to_init(
|
||||
class_name=class_parent.name,
|
||||
function_name=function_to_optimize.function_name,
|
||||
target_classes={class_parent.name},
|
||||
fto_name=function_to_optimize.function_name,
|
||||
tmp_dir_path=str(get_run_tmp_file(Path("test_return_values"))),
|
||||
code=original_code,
|
||||
is_fto=True,
|
||||
)
|
||||
|
||||
# Write modified code back to file
|
||||
with open(function_to_optimize.file_path, "w") as f:
|
||||
f.write(modified_code)
|
||||
|
||||
# Instrument helper classes
|
||||
for file_path, helper_classes in file_path_to_helper_class.items():
|
||||
with open(file_path) as f:
|
||||
original_code = f.read()
|
||||
modified_code = add_codeflash_capture_to_init(
|
||||
target_classes=helper_classes,
|
||||
fto_name=function_to_optimize.function_name,
|
||||
tmp_dir_path=str(get_run_tmp_file(Path("test_return_values"))),
|
||||
code=original_code,
|
||||
is_fto=False,
|
||||
)
|
||||
with open(file_path, "w") as f:
|
||||
f.write(modified_code)
|
||||
|
||||
def add_codeflash_capture_to_init(class_name: str, function_name: str, tmp_dir_path: str, code: str) -> str:
|
||||
|
||||
def add_codeflash_capture_to_init(
|
||||
target_classes: set[str], fto_name: str, tmp_dir_path: str, code: str, is_fto: bool = False
|
||||
) -> str:
|
||||
"""Add codeflash_capture decorator to __init__ function in the specified class."""
|
||||
# Parse the code into an AST
|
||||
tree = ast.parse(code)
|
||||
|
||||
# Apply our transformation
|
||||
transformer = InitDecorator(class_name, function_name, tmp_dir_path)
|
||||
transformer = InitDecorator(target_classes, fto_name, tmp_dir_path, is_fto)
|
||||
modified_tree = transformer.visit(tree)
|
||||
if transformer.inserted_decorator:
|
||||
ast.fix_missing_locations(modified_tree)
|
||||
|
|
@ -48,10 +68,11 @@ def add_codeflash_capture_to_init(class_name: str, function_name: str, tmp_dir_p
|
|||
class InitDecorator(ast.NodeTransformer):
|
||||
"""AST transformer that adds codeflash_capture decorator to specific class's __init__."""
|
||||
|
||||
def __init__(self, target_class_name: str, function_name: str, tmp_dir_path: str):
|
||||
self.target_class_name = target_class_name
|
||||
self.function_name = function_name
|
||||
def __init__(self, target_classes: set[str], fto_name: str, tmp_dir_path: str, is_fto=False) -> None:
|
||||
self.target_classes = target_classes
|
||||
self.fto_name = fto_name
|
||||
self.tmp_dir_path = tmp_dir_path
|
||||
self.is_fto = is_fto
|
||||
self.has_import = False
|
||||
self.inserted_decorator = False
|
||||
|
||||
|
|
@ -74,7 +95,7 @@ class InitDecorator(ast.NodeTransformer):
|
|||
|
||||
def visit_ClassDef(self, node: ast.ClassDef) -> ast.ClassDef:
|
||||
# Only modify the target class
|
||||
if node.name != self.target_class_name:
|
||||
if node.name not in self.target_classes:
|
||||
return node
|
||||
|
||||
# Look for __init__ method
|
||||
|
|
@ -85,8 +106,9 @@ class InitDecorator(ast.NodeTransformer):
|
|||
func=ast.Name(id="codeflash_capture", ctx=ast.Load()),
|
||||
args=[],
|
||||
keywords=[
|
||||
ast.keyword(arg="function_name", value=ast.Constant(value=self.function_name)),
|
||||
ast.keyword(arg="function_name", value=ast.Constant(value=self.fto_name)),
|
||||
ast.keyword(arg="tmp_dir_path", value=ast.Constant(value=self.tmp_dir_path)),
|
||||
ast.keyword(arg="is_fto", value=ast.Constant(value=self.is_fto)),
|
||||
],
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -21,7 +21,7 @@ from codeflash.code_utils.code_utils import (
|
|||
)
|
||||
from codeflash.discovery.discover_unit_tests import discover_parameters_unittest
|
||||
from codeflash.models.models import CoverageData, TestFiles
|
||||
from codeflash.verification.test_results import FunctionTestInvocation, InvocationId, TestResults
|
||||
from codeflash.verification.test_results import FunctionTestInvocation, InvocationId, TestResults, VerificationType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import subprocess
|
||||
|
|
@ -92,6 +92,7 @@ def parse_test_return_values_bin(file_location: Path, test_files: TestFiles, tes
|
|||
test_type=test_type,
|
||||
return_value=test_pickle,
|
||||
timed_out=False,
|
||||
verification_type=VerificationType.FUNCTION_TO_OPTIMIZE,
|
||||
)
|
||||
)
|
||||
return test_results
|
||||
|
|
@ -141,7 +142,7 @@ def parse_sqlite_test_results(sqlite_file_path: Path, test_files: TestFiles, tes
|
|||
test_type=test_type,
|
||||
return_value=ret_val,
|
||||
timed_out=False,
|
||||
verification_type=verification_type if verification_type else None,
|
||||
verification_type=VerificationType(verification_type) if verification_type else None,
|
||||
)
|
||||
)
|
||||
except Exception:
|
||||
|
|
@ -376,6 +377,7 @@ def merge_test_results(
|
|||
test_type=xml_result.test_type,
|
||||
return_value=result_bin.return_value,
|
||||
timed_out=xml_result.timed_out,
|
||||
verification_type=VerificationType(result_bin.verification_type),
|
||||
)
|
||||
)
|
||||
elif xml_results.test_results[0].id.iteration_id is not None:
|
||||
|
|
@ -402,6 +404,7 @@ def merge_test_results(
|
|||
timed_out=xml_result.timed_out
|
||||
if bin_result.runtime is None
|
||||
else False, # If runtime was measured in the bin file, then the testcase did not time out
|
||||
verification_type=VerificationType(bin_result.verification_type),
|
||||
)
|
||||
)
|
||||
else:
|
||||
|
|
@ -425,6 +428,7 @@ def merge_test_results(
|
|||
test_type=bin_result.test_type,
|
||||
return_value=bin_result.return_value,
|
||||
timed_out=xml_result.timed_out, # only the xml gets the timed_out flag
|
||||
verification_type=VerificationType(bin_result.verification_type),
|
||||
)
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -13,6 +13,19 @@ from codeflash.cli_cmds.console import DEBUG_MODE, logger
|
|||
from codeflash.verification.comparator import comparator
|
||||
|
||||
|
||||
class VerificationType(str, Enum):
|
||||
FUNCTION_TO_OPTIMIZE = (
|
||||
"function_to_optimize" # Correctness verification for fto, checks input values and output values
|
||||
)
|
||||
INSTANCE_STATE_FTO = "instance_state_fto" # Correctness verification for instance state of fto, checks instance attributes right after __init__ is called
|
||||
INSTANCE_STATE_HELPER = "instance_state_helper" # Correctness verification for instance state of helper classes, checks instance attributes right after __init__ is called
|
||||
|
||||
def __new__(cls, value: str) -> VerificationType | None:
|
||||
obj = str.__new__(cls, value)
|
||||
obj._value_ = value if value != "" else None
|
||||
return obj
|
||||
|
||||
|
||||
class TestType(Enum):
|
||||
EXISTING_UNIT_TEST = 1
|
||||
INSPIRED_REGRESSION = 2
|
||||
|
|
@ -74,7 +87,7 @@ class FunctionTestInvocation:
|
|||
test_type: TestType
|
||||
return_value: Optional[object] # The return value of the function invocation
|
||||
timed_out: Optional[bool]
|
||||
verification_type: Optional[str]
|
||||
verification_type: str = VerificationType.FUNCTION_TO_OPTIMIZE
|
||||
|
||||
@property
|
||||
def unique_invocation_loop_id(self) -> str:
|
||||
|
|
|
|||
|
|
@ -744,14 +744,15 @@ def test_superset():
|
|||
def __init__(self):
|
||||
self.a = 1
|
||||
|
||||
class B(A):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.b = 2
|
||||
obj = A()
|
||||
obj.x = 3
|
||||
|
||||
assert comparator(A(), B(), superset_obj=True)
|
||||
assert not comparator(B(), A(), superset_obj=True)
|
||||
assert not comparator(A(), B())
|
||||
assert comparator(A(), obj, superset_obj=True)
|
||||
assert not comparator(obj, A(), superset_obj=True)
|
||||
assert not comparator(A(), obj)
|
||||
assert not comparator(obj, A())
|
||||
assert comparator(obj, obj, superset_obj=True)
|
||||
assert comparator(obj, obj)
|
||||
|
||||
|
||||
def test_compare_results_fn():
|
||||
|
|
|
|||
|
|
@ -1,231 +1,231 @@
|
|||
from textwrap import dedent
|
||||
|
||||
import libcst as cst
|
||||
from codeflash.code_utils.code_replacer import merge_init_functions, replace_functions_in_file
|
||||
from codeflash.models.models import FunctionParent
|
||||
|
||||
|
||||
def test_basic_merge() -> None:
|
||||
original = """
|
||||
class MyClass:
|
||||
def __init__(self, a, b):
|
||||
self.a = a
|
||||
self.b = b
|
||||
"""
|
||||
new = """
|
||||
class MyClass:
|
||||
def __init__(self, x, y):
|
||||
self.x = x
|
||||
self.y = y
|
||||
"""
|
||||
result = merge_init_functions(
|
||||
cst.parse_module(dedent(original)).body[0].body.body[0], cst.parse_module(dedent(new)).body[0].body.body[0]
|
||||
)
|
||||
|
||||
expected = """
|
||||
def __init__(self, a, b):
|
||||
self.a = a
|
||||
self.b = b
|
||||
self.x = x
|
||||
self.y = y
|
||||
"""
|
||||
assert cst.Module([result]).code.strip() == dedent(expected).strip()
|
||||
|
||||
|
||||
def test_prevent_duplication() -> None:
|
||||
original = """
|
||||
class MyClass:
|
||||
def __init__(self):
|
||||
self.x = 1
|
||||
print("init")
|
||||
self.setup()
|
||||
"""
|
||||
new = """
|
||||
class MyClass:
|
||||
def __init__(self):
|
||||
print("init")
|
||||
self.setup()
|
||||
self.y = 2
|
||||
"""
|
||||
result = merge_init_functions(
|
||||
cst.parse_module(dedent(original)).body[0].body.body[0], cst.parse_module(dedent(new)).body[0].body.body[0]
|
||||
)
|
||||
|
||||
expected = """
|
||||
def __init__(self):
|
||||
self.x = 1
|
||||
print("init")
|
||||
self.setup()
|
||||
self.y = 2
|
||||
"""
|
||||
assert cst.Module([result]).code.strip() == dedent(expected).strip()
|
||||
|
||||
|
||||
def test_prevent_overwrite() -> None:
|
||||
original = """
|
||||
class MyClass:
|
||||
def __init__(self):
|
||||
self.x = 1
|
||||
self.y = 2
|
||||
"""
|
||||
new = """
|
||||
class MyClass:
|
||||
def __init__(self):
|
||||
self.x = 2
|
||||
"""
|
||||
result = merge_init_functions(
|
||||
cst.parse_module(dedent(original)).body[0].body.body[0], cst.parse_module(dedent(new)).body[0].body.body[0]
|
||||
)
|
||||
|
||||
expected = """
|
||||
def __init__(self):
|
||||
self.x = 1
|
||||
self.y = 2
|
||||
"""
|
||||
assert cst.Module([result]).code.strip() == dedent(expected).strip()
|
||||
|
||||
|
||||
def test_complex_control_flow() -> None:
|
||||
original = """
|
||||
class MyClass:
|
||||
def __init__(self):
|
||||
with self.lock:
|
||||
self.setup()
|
||||
if self.debug:
|
||||
self.enable_logging()
|
||||
"""
|
||||
new = """
|
||||
class MyClass:
|
||||
def __init__(self):
|
||||
try:
|
||||
self.connect()
|
||||
except ConnectionError:
|
||||
self.fallback()
|
||||
"""
|
||||
result = merge_init_functions(
|
||||
cst.parse_module(dedent(original)).body[0].body.body[0], cst.parse_module(dedent(new)).body[0].body.body[0]
|
||||
)
|
||||
|
||||
expected = """
|
||||
def __init__(self):
|
||||
with self.lock:
|
||||
self.setup()
|
||||
if self.debug:
|
||||
self.enable_logging()
|
||||
try:
|
||||
self.connect()
|
||||
except ConnectionError:
|
||||
self.fallback()
|
||||
"""
|
||||
assert cst.Module([result]).code.strip() == dedent(expected).strip()
|
||||
|
||||
|
||||
def test_docstrings_and_comments() -> None:
|
||||
original = """
|
||||
class MyClass:
|
||||
def __init__(self):
|
||||
# Setup configuration
|
||||
self.config = {} # Empty config
|
||||
"""
|
||||
new = """
|
||||
class MyClass:
|
||||
def __init__(self):
|
||||
\"\"\"New docstring.\"\"\"
|
||||
# Initialize database
|
||||
self.db = None # Database connection
|
||||
"""
|
||||
result = merge_init_functions(
|
||||
cst.parse_module(dedent(original)).body[0].body.body[0], cst.parse_module(dedent(new)).body[0].body.body[0]
|
||||
)
|
||||
expected = """
|
||||
def __init__(self):
|
||||
# Setup configuration
|
||||
self.config = {} # Empty config
|
||||
# Initialize database
|
||||
self.db = None # Database connection
|
||||
"""
|
||||
assert cst.Module([result]).code.strip() == dedent(expected).strip()
|
||||
|
||||
|
||||
def test_type_annotations() -> None:
|
||||
original = """
|
||||
class MyClass:
|
||||
def __init__(self) -> None:
|
||||
self.x: int = 1
|
||||
self.y: str = "hello"
|
||||
"""
|
||||
new = """
|
||||
class MyClass:
|
||||
def __init__(self):
|
||||
self.y: str = "new hello"
|
||||
self.z: float = 2.0
|
||||
"""
|
||||
result = merge_init_functions(
|
||||
cst.parse_module(dedent(original)).body[0].body.body[0], cst.parse_module(dedent(new)).body[0].body.body[0]
|
||||
)
|
||||
|
||||
expected = """
|
||||
def __init__(self) -> None:
|
||||
self.x: int = 1
|
||||
self.y: str = "hello"
|
||||
self.z: float = 2.0
|
||||
"""
|
||||
assert cst.Module([result]).code.strip() == dedent(expected).strip()
|
||||
|
||||
|
||||
# Tests for code replacement with init
|
||||
def test_merge_init_methods() -> None:
|
||||
optim_code = """class MyClass:
|
||||
def __init__(self):
|
||||
self.y = 2
|
||||
self.z = 3
|
||||
"""
|
||||
|
||||
original_code = """class MyClass:
|
||||
def __init__(self):
|
||||
self.y = 1
|
||||
self.setup()
|
||||
"""
|
||||
|
||||
expected = """class MyClass:
|
||||
def __init__(self):
|
||||
self.y = 1
|
||||
self.setup()
|
||||
self.z = 3
|
||||
"""
|
||||
|
||||
result = replace_functions_in_file(
|
||||
source_code=original_code,
|
||||
original_function_names=[],
|
||||
optimized_code=optim_code,
|
||||
preexisting_objects=[("__init__", [FunctionParent(name="MyClass", type="ClassDef")])],
|
||||
)
|
||||
assert result == expected
|
||||
|
||||
|
||||
def test_init_is_function_to_optimize() -> None:
|
||||
optim_code = """class MyClass:
|
||||
def __init__(self):
|
||||
self.y = 2
|
||||
self.z = 3
|
||||
"""
|
||||
|
||||
original_code = """class MyClass:
|
||||
def __init__(self):
|
||||
self.y = 1
|
||||
self.setup()
|
||||
"""
|
||||
|
||||
expected = """class MyClass:
|
||||
def __init__(self):
|
||||
self.y = 2
|
||||
self.z = 3
|
||||
"""
|
||||
# In this scenario, we leave the mutation check to the usual FTO behaviour check.
|
||||
result = replace_functions_in_file(
|
||||
source_code=original_code,
|
||||
original_function_names=["MyClass.__init__"],
|
||||
optimized_code=optim_code,
|
||||
preexisting_objects=[("__init__", [FunctionParent(name="MyClass", type="ClassDef")])],
|
||||
)
|
||||
assert result == expected
|
||||
# from textwrap import dedent
|
||||
#
|
||||
# import libcst as cst
|
||||
# from codeflash.code_utils.code_replacer import merge_init_functions, replace_functions_in_file
|
||||
# from codeflash.models.models import FunctionParent
|
||||
#
|
||||
#
|
||||
# def test_basic_merge() -> None:
|
||||
# original = """
|
||||
# class MyClass:
|
||||
# def __init__(self, a, b):
|
||||
# self.a = a
|
||||
# self.b = b
|
||||
# """
|
||||
# new = """
|
||||
# class MyClass:
|
||||
# def __init__(self, x, y):
|
||||
# self.x = x
|
||||
# self.y = y
|
||||
# """
|
||||
# result = merge_init_functions(
|
||||
# cst.parse_module(dedent(original)).body[0].body.body[0], cst.parse_module(dedent(new)).body[0].body.body[0]
|
||||
# )
|
||||
#
|
||||
# expected = """
|
||||
# def __init__(self, a, b):
|
||||
# self.a = a
|
||||
# self.b = b
|
||||
# self.x = x
|
||||
# self.y = y
|
||||
# """
|
||||
# assert cst.Module([result]).code.strip() == dedent(expected).strip()
|
||||
#
|
||||
#
|
||||
# def test_prevent_duplication() -> None:
|
||||
# original = """
|
||||
# class MyClass:
|
||||
# def __init__(self):
|
||||
# self.x = 1
|
||||
# print("init")
|
||||
# self.setup()
|
||||
# """
|
||||
# new = """
|
||||
# class MyClass:
|
||||
# def __init__(self):
|
||||
# print("init")
|
||||
# self.setup()
|
||||
# self.y = 2
|
||||
# """
|
||||
# result = merge_init_functions(
|
||||
# cst.parse_module(dedent(original)).body[0].body.body[0], cst.parse_module(dedent(new)).body[0].body.body[0]
|
||||
# )
|
||||
#
|
||||
# expected = """
|
||||
# def __init__(self):
|
||||
# self.x = 1
|
||||
# print("init")
|
||||
# self.setup()
|
||||
# self.y = 2
|
||||
# """
|
||||
# assert cst.Module([result]).code.strip() == dedent(expected).strip()
|
||||
#
|
||||
#
|
||||
# def test_prevent_overwrite() -> None:
|
||||
# original = """
|
||||
# class MyClass:
|
||||
# def __init__(self):
|
||||
# self.x = 1
|
||||
# self.y = 2
|
||||
# """
|
||||
# new = """
|
||||
# class MyClass:
|
||||
# def __init__(self):
|
||||
# self.x = 2
|
||||
# """
|
||||
# result = merge_init_functions(
|
||||
# cst.parse_module(dedent(original)).body[0].body.body[0], cst.parse_module(dedent(new)).body[0].body.body[0]
|
||||
# )
|
||||
#
|
||||
# expected = """
|
||||
# def __init__(self):
|
||||
# self.x = 1
|
||||
# self.y = 2
|
||||
# """
|
||||
# assert cst.Module([result]).code.strip() == dedent(expected).strip()
|
||||
#
|
||||
#
|
||||
# def test_complex_control_flow() -> None:
|
||||
# original = """
|
||||
# class MyClass:
|
||||
# def __init__(self):
|
||||
# with self.lock:
|
||||
# self.setup()
|
||||
# if self.debug:
|
||||
# self.enable_logging()
|
||||
# """
|
||||
# new = """
|
||||
# class MyClass:
|
||||
# def __init__(self):
|
||||
# try:
|
||||
# self.connect()
|
||||
# except ConnectionError:
|
||||
# self.fallback()
|
||||
# """
|
||||
# result = merge_init_functions(
|
||||
# cst.parse_module(dedent(original)).body[0].body.body[0], cst.parse_module(dedent(new)).body[0].body.body[0]
|
||||
# )
|
||||
#
|
||||
# expected = """
|
||||
# def __init__(self):
|
||||
# with self.lock:
|
||||
# self.setup()
|
||||
# if self.debug:
|
||||
# self.enable_logging()
|
||||
# try:
|
||||
# self.connect()
|
||||
# except ConnectionError:
|
||||
# self.fallback()
|
||||
# """
|
||||
# assert cst.Module([result]).code.strip() == dedent(expected).strip()
|
||||
#
|
||||
#
|
||||
# def test_docstrings_and_comments() -> None:
|
||||
# original = """
|
||||
# class MyClass:
|
||||
# def __init__(self):
|
||||
# # Setup configuration
|
||||
# self.config = {} # Empty config
|
||||
# """
|
||||
# new = """
|
||||
# class MyClass:
|
||||
# def __init__(self):
|
||||
# \"\"\"New docstring.\"\"\"
|
||||
# # Initialize database
|
||||
# self.db = None # Database connection
|
||||
# """
|
||||
# result = merge_init_functions(
|
||||
# cst.parse_module(dedent(original)).body[0].body.body[0], cst.parse_module(dedent(new)).body[0].body.body[0]
|
||||
# )
|
||||
# expected = """
|
||||
# def __init__(self):
|
||||
# # Setup configuration
|
||||
# self.config = {} # Empty config
|
||||
# # Initialize database
|
||||
# self.db = None # Database connection
|
||||
# """
|
||||
# assert cst.Module([result]).code.strip() == dedent(expected).strip()
|
||||
#
|
||||
#
|
||||
# def test_type_annotations() -> None:
|
||||
# original = """
|
||||
# class MyClass:
|
||||
# def __init__(self) -> None:
|
||||
# self.x: int = 1
|
||||
# self.y: str = "hello"
|
||||
# """
|
||||
# new = """
|
||||
# class MyClass:
|
||||
# def __init__(self):
|
||||
# self.y: str = "new hello"
|
||||
# self.z: float = 2.0
|
||||
# """
|
||||
# result = merge_init_functions(
|
||||
# cst.parse_module(dedent(original)).body[0].body.body[0], cst.parse_module(dedent(new)).body[0].body.body[0]
|
||||
# )
|
||||
#
|
||||
# expected = """
|
||||
# def __init__(self) -> None:
|
||||
# self.x: int = 1
|
||||
# self.y: str = "hello"
|
||||
# self.z: float = 2.0
|
||||
# """
|
||||
# assert cst.Module([result]).code.strip() == dedent(expected).strip()
|
||||
#
|
||||
#
|
||||
# # Tests for code replacement with init
|
||||
# def test_merge_init_methods() -> None:
|
||||
# optim_code = """class MyClass:
|
||||
# def __init__(self):
|
||||
# self.y = 2
|
||||
# self.z = 3
|
||||
# """
|
||||
#
|
||||
# original_code = """class MyClass:
|
||||
# def __init__(self):
|
||||
# self.y = 1
|
||||
# self.setup()
|
||||
# """
|
||||
#
|
||||
# expected = """class MyClass:
|
||||
# def __init__(self):
|
||||
# self.y = 1
|
||||
# self.setup()
|
||||
# self.z = 3
|
||||
# """
|
||||
#
|
||||
# result = replace_functions_in_file(
|
||||
# source_code=original_code,
|
||||
# original_function_names=[],
|
||||
# optimized_code=optim_code,
|
||||
# preexisting_objects=[("__init__", [FunctionParent(name="MyClass", type="ClassDef")])],
|
||||
# )
|
||||
# assert result == expected
|
||||
#
|
||||
#
|
||||
# def test_init_is_function_to_optimize() -> None:
|
||||
# optim_code = """class MyClass:
|
||||
# def __init__(self):
|
||||
# self.y = 2
|
||||
# self.z = 3
|
||||
# """
|
||||
#
|
||||
# original_code = """class MyClass:
|
||||
# def __init__(self):
|
||||
# self.y = 1
|
||||
# self.setup()
|
||||
# """
|
||||
#
|
||||
# expected = """class MyClass:
|
||||
# def __init__(self):
|
||||
# self.y = 2
|
||||
# self.z = 3
|
||||
# """
|
||||
# # In this scenario, we leave the mutation check to the usual FTO behaviour check.
|
||||
# result = replace_functions_in_file(
|
||||
# source_code=original_code,
|
||||
# original_function_names=["MyClass.__init__"],
|
||||
# optimized_code=optim_code,
|
||||
# preexisting_objects=[("__init__", [FunctionParent(name="MyClass", type="ClassDef")])],
|
||||
# )
|
||||
# assert result == expected
|
||||
|
|
|
|||
|
|
@ -7,7 +7,6 @@ from codeflash.verification.instrument_code import instrument_code
|
|||
|
||||
|
||||
def test_add_codeflash_capture():
|
||||
# Test input code
|
||||
original_code = """
|
||||
class MyClass:
|
||||
def __init__(self):
|
||||
|
|
@ -22,39 +21,30 @@ from codeflash.verification.codeflash_capture import codeflash_capture
|
|||
|
||||
class MyClass:
|
||||
|
||||
@codeflash_capture(function_name='target_function', tmp_dir_path='{get_run_tmp_file(Path("test_return_values"))!s}')
|
||||
@codeflash_capture(function_name='target_function', tmp_dir_path='{get_run_tmp_file(Path("test_return_values"))!s}', is_fto=True)
|
||||
def __init__(self):
|
||||
self.x = 1
|
||||
|
||||
def target_function(self):
|
||||
return self.x + 1
|
||||
"""
|
||||
|
||||
# Create and modify test file
|
||||
test_file = Path("test_file.py")
|
||||
test_file.write_text(original_code)
|
||||
test_path = (Path(__file__).parent.resolve() / "../code_to_optimize/tests/pytest/test_file.py").resolve()
|
||||
test_path.write_text(original_code)
|
||||
|
||||
function = FunctionToOptimize(
|
||||
function_name="target_function",
|
||||
file_path=Path("test_file.py"),
|
||||
parents=[FunctionParent(type="ClassDef", name="MyClass")],
|
||||
function_name="target_function", file_path=test_path, parents=[FunctionParent(type="ClassDef", name="MyClass")]
|
||||
)
|
||||
|
||||
try:
|
||||
# Run the instrumentation
|
||||
instrument_code(function)
|
||||
|
||||
# Check the result
|
||||
modified_code = test_file.read_text()
|
||||
instrument_code(function, {})
|
||||
modified_code = test_path.read_text()
|
||||
assert modified_code.strip() == expected.strip()
|
||||
|
||||
finally:
|
||||
# Cleanup
|
||||
test_file.unlink(missing_ok=True)
|
||||
test_path.unlink(missing_ok=True)
|
||||
|
||||
|
||||
def test_add_codeflash_capture_no_parent():
|
||||
# Test input code
|
||||
original_code = """
|
||||
class MyClass:
|
||||
|
||||
|
|
@ -68,24 +58,17 @@ class MyClass:
|
|||
def target_function(self):
|
||||
return self.x + 1
|
||||
"""
|
||||
test_path = (Path(__file__).parent.resolve() / "../code_to_optimize/tests/pytest/test_file.py").resolve()
|
||||
test_path.write_text(original_code)
|
||||
|
||||
# Create and modify test file
|
||||
test_file = Path("test_file.py")
|
||||
test_file.write_text(original_code)
|
||||
|
||||
function = FunctionToOptimize(function_name="target_function", file_path=Path("test_file.py"), parents=[])
|
||||
function = FunctionToOptimize(function_name="target_function", file_path=test_path, parents=[])
|
||||
|
||||
try:
|
||||
# Run the instrumentation
|
||||
instrument_code(function)
|
||||
|
||||
# Check the result
|
||||
modified_code = test_file.read_text()
|
||||
instrument_code(function, {})
|
||||
modified_code = test_path.read_text()
|
||||
assert modified_code.strip() == expected.strip()
|
||||
|
||||
finally:
|
||||
# Cleanup
|
||||
test_file.unlink(missing_ok=True)
|
||||
test_path.unlink(missing_ok=True)
|
||||
|
||||
|
||||
def test_add_codeflash_capture_no_init():
|
||||
|
|
@ -102,32 +85,263 @@ from codeflash.verification.codeflash_capture import codeflash_capture
|
|||
|
||||
class MyClass(ParentClass):
|
||||
|
||||
@codeflash_capture(function_name='target_function', tmp_dir_path='{get_run_tmp_file(Path("test_return_values"))!s}')
|
||||
@codeflash_capture(function_name='target_function', tmp_dir_path='{get_run_tmp_file(Path("test_return_values"))!s}', is_fto=True)
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def target_function(self):
|
||||
return self.x + 1
|
||||
"""
|
||||
|
||||
# Create and modify test file
|
||||
test_file = Path("test_file.py")
|
||||
test_file.write_text(original_code)
|
||||
test_path = (Path(__file__).parent.resolve() / "../code_to_optimize/tests/pytest/test_file.py").resolve()
|
||||
test_path.write_text(original_code)
|
||||
|
||||
function = FunctionToOptimize(
|
||||
function_name="target_function",
|
||||
file_path=Path("test_file.py"),
|
||||
parents=[FunctionParent(type="ClassDef", name="MyClass")],
|
||||
function_name="target_function", file_path=test_path, parents=[FunctionParent(type="ClassDef", name="MyClass")]
|
||||
)
|
||||
|
||||
try:
|
||||
# Run the instrumentation
|
||||
instrument_code(function)
|
||||
|
||||
# Check the result
|
||||
modified_code = test_file.read_text()
|
||||
instrument_code(function, {})
|
||||
modified_code = test_path.read_text()
|
||||
assert modified_code.strip() == expected.strip()
|
||||
|
||||
finally:
|
||||
# Cleanup
|
||||
test_file.unlink(missing_ok=True)
|
||||
test_path.unlink(missing_ok=True)
|
||||
|
||||
|
||||
def test_add_codeflash_capture_with_helpers():
|
||||
# Test input code
|
||||
original_code = """
|
||||
class MyClass:
|
||||
def __init__(self):
|
||||
self.x = 1
|
||||
|
||||
def target_function(self):
|
||||
return helper() + 1
|
||||
|
||||
def helper(self):
|
||||
return self.x
|
||||
"""
|
||||
|
||||
expected = f"""
|
||||
from codeflash.verification.codeflash_capture import codeflash_capture
|
||||
|
||||
class MyClass:
|
||||
|
||||
@codeflash_capture(function_name='target_function', tmp_dir_path='{get_run_tmp_file(Path("test_return_values"))!s}', is_fto=True)
|
||||
def __init__(self):
|
||||
self.x = 1
|
||||
|
||||
def target_function(self):
|
||||
return helper() + 1
|
||||
|
||||
def helper(self):
|
||||
return self.x
|
||||
"""
|
||||
test_path = (Path(__file__).parent.resolve() / "../code_to_optimize/tests/pytest/test_file.py").resolve()
|
||||
test_path.write_text(original_code)
|
||||
|
||||
function = FunctionToOptimize(
|
||||
function_name="target_function", file_path=test_path, parents=[FunctionParent(type="ClassDef", name="MyClass")]
|
||||
)
|
||||
|
||||
try:
|
||||
instrument_code(
|
||||
function, {test_path: {"MyClass"}}
|
||||
) # MyClass was removed from the file_path_to_helper_class as it shares class with FTO
|
||||
modified_code = test_path.read_text()
|
||||
assert modified_code.strip() == expected.strip()
|
||||
|
||||
finally:
|
||||
test_path.unlink(missing_ok=True)
|
||||
|
||||
|
||||
def test_add_codeflash_capture_with_helpers_2():
|
||||
# Test input code
|
||||
original_code = """
|
||||
from test_helper_file import HelperClass
|
||||
|
||||
class MyClass:
|
||||
def __init__(self):
|
||||
self.x = 1
|
||||
|
||||
def target_function(self):
|
||||
return HelperClass().helper() + 1
|
||||
"""
|
||||
original_helper = """
|
||||
class HelperClass:
|
||||
def __init__(self):
|
||||
self.y = 1
|
||||
def helper(self):
|
||||
return 1
|
||||
"""
|
||||
expected = f"""
|
||||
from codeflash.verification.codeflash_capture import codeflash_capture
|
||||
from test_helper_file import HelperClass
|
||||
|
||||
class MyClass:
|
||||
|
||||
@codeflash_capture(function_name='target_function', tmp_dir_path='{get_run_tmp_file(Path("test_return_values"))!s}', is_fto=True)
|
||||
def __init__(self):
|
||||
self.x = 1
|
||||
|
||||
def target_function(self):
|
||||
return HelperClass().helper() + 1
|
||||
"""
|
||||
expected_helper = f"""
|
||||
from codeflash.verification.codeflash_capture import codeflash_capture
|
||||
|
||||
class HelperClass:
|
||||
@codeflash_capture(function_name='helper', tmp_dir_path='{get_run_tmp_file(Path("test_return_values"))!s}', is_fto=False)
|
||||
def __init__(self):
|
||||
self.y = 1
|
||||
def helper(self):
|
||||
return 1
|
||||
"""
|
||||
test_path = (Path(__file__).parent.resolve() / "../code_to_optimize/tests/pytest/test_file.py").resolve()
|
||||
test_path.write_text(original_code)
|
||||
helper_path = (Path(__file__).parent.resolve() / "../code_to_optimize/tests/pytest/test_helper_file.py").resolve()
|
||||
helper_path.write_text(original_helper)
|
||||
|
||||
function = FunctionToOptimize(
|
||||
function_name="target_function", file_path=test_path, parents=[FunctionParent(type="ClassDef", name="MyClass")]
|
||||
)
|
||||
|
||||
try:
|
||||
instrument_code(function, {helper_path: {"HelperClass"}})
|
||||
modified_code = test_path.read_text()
|
||||
assert modified_code.strip() == expected.strip()
|
||||
|
||||
finally:
|
||||
test_path.unlink(missing_ok=True)
|
||||
helper_path.unlink(missing_ok=True)
|
||||
|
||||
|
||||
def test_add_codeflash_capture_with_multiple_helpers():
|
||||
# Test input code with imports from two helper files
|
||||
original_code = """
|
||||
from helper_file_1 import HelperClass1
|
||||
from helper_file_2 import HelperClass2, AnotherHelperClass
|
||||
|
||||
class MyClass:
|
||||
def __init__(self):
|
||||
self.x = 1
|
||||
|
||||
def target_function(self):
|
||||
helper1 = HelperClass1().helper1()
|
||||
helper2 = HelperClass2().helper2()
|
||||
another = AnotherHelperClass().another_helper()
|
||||
return helper1 + helper2 + another
|
||||
"""
|
||||
|
||||
# First helper file content
|
||||
original_helper1 = """
|
||||
class HelperClass1:
|
||||
def __init__(self):
|
||||
self.y = 1
|
||||
def helper1(self):
|
||||
return 1
|
||||
"""
|
||||
|
||||
# Second helper file content
|
||||
original_helper2 = """
|
||||
class HelperClass2:
|
||||
def __init__(self):
|
||||
self.z = 2
|
||||
def helper2(self):
|
||||
return 2
|
||||
|
||||
class AnotherHelperClass:
|
||||
def another_helper(self):
|
||||
return 3
|
||||
"""
|
||||
|
||||
# Expected output code with decorators
|
||||
expected = f"""
|
||||
from codeflash.verification.codeflash_capture import codeflash_capture
|
||||
from helper_file_1 import HelperClass1
|
||||
from helper_file_2 import HelperClass2, AnotherHelperClass
|
||||
|
||||
class MyClass:
|
||||
|
||||
@codeflash_capture(function_name='target_function', tmp_dir_path='{get_run_tmp_file(Path("test_return_values"))!s}', is_fto=True)
|
||||
def __init__(self):
|
||||
self.x = 1
|
||||
|
||||
def target_function(self):
|
||||
helper1 = HelperClass1().helper1()
|
||||
helper2 = HelperClass2().helper2()
|
||||
another = AnotherHelperClass().another_helper()
|
||||
return helper1 + helper2 + another
|
||||
"""
|
||||
|
||||
# Expected output for first helper file
|
||||
expected_helper1 = f"""
|
||||
from codeflash.verification.codeflash_capture import codeflash_capture
|
||||
|
||||
class HelperClass1:
|
||||
|
||||
@codeflash_capture(function_name='target_function', tmp_dir_path='{get_run_tmp_file(Path("test_return_values"))!s}', is_fto=False)
|
||||
def __init__(self):
|
||||
self.y = 1
|
||||
|
||||
def helper1(self):
|
||||
return 1
|
||||
"""
|
||||
|
||||
# Expected output for second helper file
|
||||
expected_helper2 = f"""
|
||||
from codeflash.verification.codeflash_capture import codeflash_capture
|
||||
|
||||
class HelperClass2:
|
||||
|
||||
@codeflash_capture(function_name='target_function', tmp_dir_path='{get_run_tmp_file(Path("test_return_values"))!s}', is_fto=False)
|
||||
def __init__(self):
|
||||
self.z = 2
|
||||
|
||||
def helper2(self):
|
||||
return 2
|
||||
|
||||
class AnotherHelperClass:
|
||||
|
||||
@codeflash_capture(function_name='target_function', tmp_dir_path='{get_run_tmp_file(Path("test_return_values"))!s}', is_fto=False)
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def another_helper(self):
|
||||
return 3
|
||||
"""
|
||||
|
||||
# Set up test files
|
||||
test_path = (Path(__file__).parent.resolve() / "../code_to_optimize/tests/pytest/test_file.py").resolve()
|
||||
helper1_path = (Path(__file__).parent.resolve() / "../code_to_optimize/tests/pytest/helper_file_1.py").resolve()
|
||||
helper2_path = (Path(__file__).parent.resolve() / "../code_to_optimize/tests/pytest/helper_file_2.py").resolve()
|
||||
|
||||
# Write original content to files
|
||||
test_path.write_text(original_code)
|
||||
helper1_path.write_text(original_helper1)
|
||||
helper2_path.write_text(original_helper2)
|
||||
|
||||
# Create FunctionToOptimize instance
|
||||
function = FunctionToOptimize(
|
||||
function_name="target_function", file_path=test_path, parents=[FunctionParent(type="ClassDef", name="MyClass")]
|
||||
)
|
||||
|
||||
try:
|
||||
# Instrument code with multiple helper files
|
||||
helper_classes = {helper1_path: {"HelperClass1"}, helper2_path: {"HelperClass2", "AnotherHelperClass"}}
|
||||
instrument_code(function, helper_classes)
|
||||
|
||||
# Verify the modifications
|
||||
modified_code = test_path.read_text()
|
||||
modified_helper1 = helper1_path.read_text()
|
||||
modified_helper2 = helper2_path.read_text()
|
||||
|
||||
assert modified_code.strip() == expected.strip()
|
||||
assert modified_helper1.strip() == expected_helper1.strip()
|
||||
assert modified_helper2.strip() == expected_helper2.strip()
|
||||
|
||||
finally:
|
||||
# Clean up test files
|
||||
test_path.unlink(missing_ok=True)
|
||||
helper1_path.unlink(missing_ok=True)
|
||||
helper2_path.unlink(missing_ok=True)
|
||||
|
|
|
|||
|
|
@ -7,10 +7,12 @@ from pathlib import Path
|
|||
import isort
|
||||
from code_to_optimize.bubble_sort_method import BubbleSorter
|
||||
from codeflash.code_utils.code_utils import get_run_tmp_file
|
||||
from codeflash.models.models import TestFile, TestFiles, TestingMode
|
||||
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
|
||||
from codeflash.models.models import FunctionParent, TestFile, TestFiles, TestingMode
|
||||
from codeflash.optimization.optimizer import Optimizer
|
||||
from codeflash.verification.comparator import comparator
|
||||
from codeflash.verification.equivalence import compare_test_results
|
||||
from codeflash.verification.instrument_code import instrument_code
|
||||
from codeflash.verification.test_results import TestType
|
||||
|
||||
# Used by aiservice instrumentation
|
||||
|
|
@ -85,7 +87,7 @@ def codeflash_wrap(
|
|||
"""
|
||||
|
||||
|
||||
def test_aiservice_class_method_behavior_results() -> None:
|
||||
def test_class_method_behavior_results() -> None:
|
||||
test_source = """import pytest
|
||||
from code_to_optimize.bubble_sort import BubbleSorter
|
||||
|
||||
|
|
@ -191,12 +193,10 @@ def test_single_element_list():
|
|||
# Replace with optimized code that mutated instance attribute
|
||||
optimized_code_mutated_attr = """
|
||||
class BubbleSorter:
|
||||
z = 1234
|
||||
|
||||
def __init__(self, x=1):
|
||||
self.x = x
|
||||
def new_sorter():
|
||||
pass
|
||||
|
||||
def sorter(self, arr):
|
||||
for i in range(len(arr)):
|
||||
for j in range(len(arr) - 1):
|
||||
|
|
@ -216,16 +216,13 @@ class BubbleSorter:
|
|||
pytest_max_loops=1,
|
||||
testing_time=0.1,
|
||||
)
|
||||
print(test_results_mutated_attr[0].return_value[1])
|
||||
assert test_results_mutated_attr[0].return_value[1]["self"].x == 1
|
||||
assert test_results_mutated_attr[0].return_value[1]["self"].z != 1234
|
||||
assert compare_test_results(
|
||||
assert not compare_test_results(
|
||||
test_results, test_results_mutated_attr
|
||||
) # The test should fail because the instance attribute was mutated
|
||||
# Replace with optimized code that did not mutate existing instance attribute, but added a new one
|
||||
optimized_code_new_attr = """
|
||||
class BubbleSorter:
|
||||
z = 0
|
||||
def __init__(self, x=0):
|
||||
self.x = x
|
||||
self.y = 2
|
||||
|
|
@ -249,34 +246,35 @@ class BubbleSorter:
|
|||
pytest_max_loops=1,
|
||||
testing_time=0.1,
|
||||
)
|
||||
# assert compare_test_results(
|
||||
# test_results, test_results_new_attr
|
||||
# ) # The test should pass because the instance attribute was not mutated, only a new one was added
|
||||
assert not compare_test_results(
|
||||
test_results, test_results_new_attr
|
||||
) # The test should pass because the instance attribute was not mutated, only a new one was added
|
||||
finally:
|
||||
fto_path.write_text(original_code, "utf-8")
|
||||
test_path.unlink(missing_ok=True)
|
||||
test_path_perf.unlink(missing_ok=True)
|
||||
|
||||
|
||||
def test_aiservice_class_method_behavior_results_2() -> None:
|
||||
def test_class_method_behavior_results_with_codeflash_capture() -> None:
|
||||
test_source = """import pytest
|
||||
from code_to_optimize.bubble_sort import BubbleSorter
|
||||
obj = BubbleSorter()
|
||||
|
||||
def test_single_element_list():
|
||||
# Test that a single element list returns the same single element
|
||||
test_list = [42]
|
||||
codeflash_output = obj.sorter(test_list)
|
||||
obj = BubbleSorter()
|
||||
codeflash_output = obj.sorter([3,2,1])
|
||||
"""
|
||||
instrumented_behavior_test_source = (
|
||||
behavior_logging_code
|
||||
+ """
|
||||
import pytest
|
||||
from code_to_optimize.bubble_sort_method import BubbleSorter
|
||||
obj = BubbleSorter()
|
||||
|
||||
|
||||
def test_single_element_list():
|
||||
codeflash_loop_index = int(os.environ["CODEFLASH_LOOP_INDEX"])
|
||||
test_list = [42]
|
||||
_call__bound__arguments = inspect.signature(BubbleSorter.sorter).bind(obj,[42])
|
||||
obj = BubbleSorter()
|
||||
_call__bound__arguments = inspect.signature(BubbleSorter.sorter).bind(obj,[3,2,1])
|
||||
_call__bound__arguments.apply_defaults()
|
||||
|
||||
codeflash_return_value = codeflash_wrap(
|
||||
|
|
@ -309,6 +307,8 @@ def test_single_element_list():
|
|||
os.chdir(run_cwd)
|
||||
fto_path = (Path(__file__).parent.resolve() / "../code_to_optimize/bubble_sort_method.py").resolve()
|
||||
original_code = fto_path.read_text("utf-8")
|
||||
function_to_optimize = FunctionToOptimize("sorter", fto_path, [FunctionParent("BubbleSorter", "ClassDef")])
|
||||
|
||||
try:
|
||||
temp_run_dir = get_run_tmp_file(Path()).as_posix()
|
||||
instrumented_behavior_test_source = instrumented_behavior_test_source.replace(
|
||||
|
|
@ -316,7 +316,8 @@ def test_single_element_list():
|
|||
)
|
||||
with test_path.open("w") as f:
|
||||
f.write(instrumented_behavior_test_source)
|
||||
|
||||
# Add codeflash capture decorator
|
||||
instrument_code(function_to_optimize)
|
||||
opt = Optimizer(
|
||||
Namespace(
|
||||
project_root=project_root_path,
|
||||
|
|
@ -328,7 +329,6 @@ def test_single_element_list():
|
|||
test_project_root=project_root_path,
|
||||
)
|
||||
)
|
||||
|
||||
test_env = os.environ.copy()
|
||||
test_env["CODEFLASH_TEST_ITERATION"] = "0"
|
||||
test_env["CODEFLASH_LOOP_INDEX"] = "1"
|
||||
|
|
@ -352,17 +352,28 @@ def test_single_element_list():
|
|||
pytest_max_loops=1,
|
||||
testing_time=0.1,
|
||||
)
|
||||
# Verify instance_state result, which checks instance state right after __init__, using codeflash_capture
|
||||
assert test_results[0].id.function_getting_tested == "sorter"
|
||||
assert test_results[0].id.test_function_name == "test_single_element_list"
|
||||
assert test_results[0].did_pass
|
||||
assert test_results[0].return_value[1]["arr"] == [42]
|
||||
assert comparator(test_results[0].return_value[1]["self"], BubbleSorter())
|
||||
assert not comparator(test_results[0].return_value[1]["self"], BubbleSorter(42))
|
||||
assert test_results[0].return_value[2] == [42]
|
||||
assert test_results[0].return_value[0] == {"x": 0}
|
||||
|
||||
# Verify function_to_optimize result
|
||||
assert test_results[1].id.function_getting_tested == "sorter"
|
||||
assert test_results[1].id.test_function_name == "test_single_element_list"
|
||||
assert test_results[1].did_pass
|
||||
|
||||
# Checks input values to the function to see if they have mutated
|
||||
assert comparator(test_results[1].return_value[1]["self"], BubbleSorter())
|
||||
assert test_results[1].return_value[1]["arr"] == [1, 2, 3]
|
||||
|
||||
# Check function return value
|
||||
assert test_results[1].return_value[2] == [1, 2, 3]
|
||||
|
||||
# Replace with optimized code that mutated instance attribute
|
||||
optimized_code_mutated_attr = """
|
||||
class BubbleSorter:
|
||||
z = 3
|
||||
|
||||
def __init__(self, x=1):
|
||||
self.x = x
|
||||
|
||||
|
|
@ -374,8 +385,9 @@ class BubbleSorter:
|
|||
arr[j] = arr[j + 1]
|
||||
arr[j + 1] = temp
|
||||
return arr
|
||||
"""
|
||||
"""
|
||||
fto_path.write_text(optimized_code_mutated_attr, "utf-8")
|
||||
instrument_code(function_to_optimize)
|
||||
test_results_mutated_attr, coverage_data = opt.run_and_parse_tests(
|
||||
testing_type=TestingMode.BEHAVIOR,
|
||||
test_env=test_env,
|
||||
|
|
@ -385,14 +397,13 @@ class BubbleSorter:
|
|||
pytest_max_loops=1,
|
||||
testing_time=0.1,
|
||||
)
|
||||
|
||||
assert test_results_mutated_attr[1].return_value[1]["self"].x == 1
|
||||
assert not compare_test_results(
|
||||
test_results, test_results_mutated_attr
|
||||
) # The test should fail because the instance attribute was mutated
|
||||
# Replace with optimized code that did not mutate existing instance attribute, but added a new one
|
||||
optimized_code_new_attr = """
|
||||
class BubbleSorter:
|
||||
z = 3
|
||||
def __init__(self, x=0):
|
||||
self.x = x
|
||||
self.y = 2
|
||||
|
|
@ -405,8 +416,9 @@ class BubbleSorter:
|
|||
arr[j] = arr[j + 1]
|
||||
arr[j + 1] = temp
|
||||
return arr
|
||||
"""
|
||||
"""
|
||||
fto_path.write_text(optimized_code_new_attr, "utf-8")
|
||||
instrument_code(function_to_optimize)
|
||||
test_results_new_attr, coverage_data = opt.run_and_parse_tests(
|
||||
testing_type=TestingMode.BEHAVIOR,
|
||||
test_env=test_env,
|
||||
|
|
@ -416,11 +428,11 @@ class BubbleSorter:
|
|||
pytest_max_loops=1,
|
||||
testing_time=0.1,
|
||||
)
|
||||
assert compare_test_results(
|
||||
assert test_results_new_attr[1].return_value[1]["self"].x == 0
|
||||
assert test_results_new_attr[1].return_value[1]["self"].y == 2
|
||||
assert not compare_test_results(
|
||||
test_results, test_results_new_attr
|
||||
) # The test should pass because the instance attribute was not mutated, only a new one was added
|
||||
|
||||
finally:
|
||||
fto_path.write_text(original_code, "utf-8")
|
||||
test_path.unlink(missing_ok=True)
|
||||
test_path_perf.unlink(missing_ok=True)
|
||||
|
|
|
|||
Loading…
Reference in a new issue