now enable instrumentation for helper classes

This commit is contained in:
Alvin Ryanputra 2025-01-21 16:22:38 -08:00
parent 0aa7ca4ea4
commit 7a37e6e0eb
15 changed files with 722 additions and 450 deletions

View file

@ -1,7 +1,6 @@
class BubbleSorter:
def __init__(self):
self.x = 1
def __init__(self, x=0):
self.x = x
def sorter(self, arr):
for i in range(len(arr)):
@ -11,5 +10,3 @@ class BubbleSorter:
arr[j] = arr[j + 1]
arr[j + 1] = temp
return arr

View file

@ -119,7 +119,7 @@ class OptimFunctionReplacer(cst.CSTTransformer):
cst_to_code(self.modified_init_functions[self.current_class])
):
return original_node # Code was unchanged, so don't modify docstrings / comments
return merge_init_functions(updated_node, self.modified_init_functions[self.current_class])
return self.modified_init_functions[self.current_class]
return updated_node
@ -162,105 +162,105 @@ class OptimFunctionReplacer(cst.CSTTransformer):
return node
class AttributeCollector(cst.CSTVisitor):
"""Collects all self.attribute mentions in a CST."""
# class AttributeCollector(cst.CSTVisitor):
# """Collects all self.attribute mentions in a CST."""
#
# def __init__(self):
# super().__init__()
# self.attributes: set[str] = set()
#
# def visit_Attribute(self, node: cst.Attribute) -> bool:
# """Record any self.attribute access."""
# if isinstance(node.value, cst.Name) and node.value.value == "self":
# self.attributes.add(node.attr.value)
# return True
#
#
# class AssignmentCollector(cst.CSTVisitor):
# """Collects attributes being assigned to in a CST."""
#
# def __init__(self):
# super().__init__()
# self.assigned_attrs: set[str] = set()
#
# def visit_Assign(self, node: cst.Assign) -> bool:
# """Check regular assignments like self.x = ..."""
# for target in node.targets:
# if (
# isinstance(target.target, cst.Attribute)
# and isinstance(target.target.value, cst.Name)
# and target.target.value.value == "self"
# ):
# self.assigned_attrs.add(target.target.attr.value)
# return True
#
# def visit_AnnAssign(self, node: cst.AnnAssign) -> bool:
# """Check annotated assignments like self.x: str = ..."""
# if (
# isinstance(node.target, cst.Attribute)
# and isinstance(node.target.value, cst.Name)
# and node.target.value.value == "self"
# ):
# self.assigned_attrs.add(node.target.attr.value)
# return True
#
# def visit_AugAssign(self, node: cst.AugAssign) -> bool:
# """Check augmented assignments like self.x += ..."""
# if (
# isinstance(node.target, cst.Attribute)
# and isinstance(node.target.value, cst.Name)
# and node.target.value.value == "self"
# ):
# self.assigned_attrs.add(node.target.attr.value)
# return True
def __init__(self):
super().__init__()
self.attributes: set[str] = set()
def visit_Attribute(self, node: cst.Attribute) -> bool:
"""Record any self.attribute access."""
if isinstance(node.value, cst.Name) and node.value.value == "self":
self.attributes.add(node.attr.value)
return True
class AssignmentCollector(cst.CSTVisitor):
    """Collects attributes being assigned to in a CST.

    After visiting a node, ``assigned_attrs`` holds the name of every
    attribute that appears as a ``self.<name>`` target of a plain,
    annotated or augmented assignment.
    """

    def __init__(self) -> None:
        super().__init__()
        # Names of attributes assigned through ``self.<name>``.
        self.assigned_attrs: set[str] = set()

    def _record_if_self_attribute(self, target: cst.BaseExpression) -> None:
        """Record the attribute name when ``target`` has the form ``self.<name>``."""
        if (
            isinstance(target, cst.Attribute)
            and isinstance(target.value, cst.Name)
            and target.value.value == "self"
        ):
            self.assigned_attrs.add(target.attr.value)

    def visit_Assign(self, node: cst.Assign) -> bool:
        """Check regular assignments like self.x = ..."""
        # A single Assign can carry several targets (a = self.x = ...), so check each one.
        for assign_target in node.targets:
            self._record_if_self_attribute(assign_target.target)
        return True

    def visit_AnnAssign(self, node: cst.AnnAssign) -> bool:
        """Check annotated assignments like self.x: str = ..."""
        self._record_if_self_attribute(node.target)
        return True

    def visit_AugAssign(self, node: cst.AugAssign) -> bool:
        """Check augmented assignments like self.x += ..."""
        self._record_if_self_attribute(node.target)
        return True
def merge_init_functions(original_init: cst.FunctionDef, new_init: cst.FunctionDef) -> cst.FunctionDef:
    """Append the non-conflicting statements of ``new_init`` to ``original_init``.

    The original ``__init__`` is kept as is. A statement from the new
    ``__init__`` is appended to it unless it is a bare string expression
    (the new docstring), its code already occurs in the original, or it
    assigns to a ``self`` attribute that the original mentions. Reading
    such attributes is allowed.

    Args:
        original_init: The original __init__ function to preserve
        new_init: The new __init__ function whose body will be filtered and appended

    Returns:
        A merged FunctionDef

    """
    # Every self.<attr> mentioned anywhere in the original init.
    attr_visitor = AttributeCollector()
    original_init.visit(attr_visitor)
    protected_attrs = attr_visitor.attributes

    # Code of each original statement (without comments), used to spot duplicates.
    seen_code = {get_only_code_content(cst_to_code(node)) for node in original_init.body.body}

    extra_statements = []
    for candidate in new_init.body.body:
        # Drop the docstring of the new init.
        is_bare_string = (
            isinstance(candidate, cst.SimpleStatementLine)
            and len(candidate.body) == 1
            and isinstance(candidate.body[0], cst.Expr)
            and isinstance(candidate.body[0].value, cst.SimpleString)
        )
        if is_bare_string:
            continue

        # Drop statements the original already contains.
        if get_only_code_content(cst_to_code(candidate)) in seen_code:
            continue

        # Drop statements that would overwrite an attribute the original mentions.
        assignment_visitor = AssignmentCollector()
        candidate.visit(assignment_visitor)
        if assignment_visitor.assigned_attrs.intersection(protected_attrs):
            continue

        extra_statements.append(candidate)

    merged_body = original_init.body.with_changes(body=original_init.body.body + tuple(extra_statements))
    return original_init.with_changes(body=merged_body)
#
# def merge_init_functions(original_init: cst.FunctionDef, new_init: cst.FunctionDef) -> cst.FunctionDef:
# """Merges two __init__ function definitions. Collects all self.attribute mentions
# from the original init, then filters out statements from the new init that
# assign to those attributes (but allows reading them).
#
# Args:
# original_init: The original __init__ function to preserve
# new_init: The new __init__ function whose body will be filtered and appended
#
# Returns:
# A merged FunctionDef
#
# """
# # Collect all self.attribute mentions from original init
# collector = AttributeCollector()
# original_init.visit(collector)
# existing_attrs = collector.attributes
# # Get set of existing statements, without comments
# original_stmts = {get_only_code_content(cst_to_code(stmt)) for stmt in original_init.body.body}
# # Filter new init body statements
# filtered_body = []
#
# for stmt in new_init.body.body:
# # Filter out docstring of new init
# if (
# isinstance(stmt, cst.SimpleStatementLine)
# and len(stmt.body) == 1
# and isinstance(stmt.body[0], cst.Expr)
# and isinstance(stmt.body[0].value, cst.SimpleString)
# ):
# continue
# # Filter out duplicate statements
# if get_only_code_content(cst_to_code(stmt)) in original_stmts:
# continue
# # Check for assignments to existing attributes
# assign_collector = AssignmentCollector()
# stmt.visit(assign_collector)
#
# # Keep statement if it doesn't assign to any existing attributes
# if not assign_collector.assigned_attrs.intersection(existing_attrs):
# filtered_body.append(stmt)
#
# # Merge bodies using with_changes
# return original_init.with_changes(
# body=original_init.body.with_changes(body=original_init.body.body + tuple(filtered_body))
# )
def replace_functions_in_file(

View file

@ -9,7 +9,8 @@ import isort
from codeflash.cli_cmds.console import logger
from codeflash.code_utils.code_utils import get_run_tmp_file, module_name_from_file_path
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.models.models import FunctionParent, TestingMode, VerificationType
from codeflash.models.models import FunctionParent, TestingMode
from codeflash.verification.test_results import VerificationType
if TYPE_CHECKING:
from collections.abc import Iterable

View file

@ -70,6 +70,7 @@ def get_code_optimization_context(
read_only_context_code=read_only_code_markdown.markdown,
helper_functions=helpers_of_fto_obj_list,
preexisting_objects=preexisting_objects,
)
logger.debug("Code context has exceeded token limit, removing docstrings from read-only code")

View file

@ -84,11 +84,6 @@ class CodeOptimizationContext(BaseModel):
preexisting_objects: list[tuple[str, list[FunctionParent]]]
class VerificationType(str, Enum):
FUNCTION_TO_OPTIMIZE = "function_to_optimize"
INSTANCE_STATE = "instance_state"
class OptimizedCandidateResult(BaseModel):
max_loop_count: int
best_test_runtime: int

View file

@ -337,8 +337,15 @@ class Optimizer:
)
# Instrument code
original_code = validated_original_code[function_to_optimize.file_path].source_code
instrument_code(function_to_optimize)
# Get a dict of file_path_to_classes of fto and helpers_of_fto
file_path_to_helper_classes = defaultdict(set)
for function_source in code_context.helper_functions:
if (
function_source.qualified_name != function_to_optimize.qualified_name
and "." in function_source.qualified_name
):
file_path_to_helper_classes[function_source.file_path].add(function_source.qualified_name.split(".")[0])
instrument_code(function_to_optimize, file_path_to_helper_classes)
baseline_result = self.establish_original_code_baseline( # this needs better typing
function_name=function_to_optimize_qualified_name,
@ -347,7 +354,11 @@ class Optimizer:
)
# Remove instrumentation
self.write_code_and_helpers(original_code, {}, function_to_optimize.file_path)
self.write_code_and_helpers(
validated_original_code[function_to_optimize.file_path].source_code,
original_helper_code,
function_to_optimize.file_path,
)
console.rule()
paths_to_cleanup = (
@ -514,7 +525,6 @@ class Optimizer:
optimized_code=candidate.source_code,
qualified_function_name=function_to_optimize.qualified_name,
)
# If init was modified, instrument the code with codeflash capture
if not did_update:
logger.warning(
@ -527,6 +537,8 @@ class Optimizer:
self.write_code_and_helpers(original_code, original_helper_code, function_to_optimize.file_path)
continue
# Instrument codeflash capture
run_results = self.run_optimized_candidate(
optimization_candidate_index=candidate_index, baseline_results=original_code_baseline
)

View file

@ -9,7 +9,7 @@ import time
import dill as pickle
from codeflash.models.models import VerificationType
from codeflash.verification.test_results import VerificationType
def get_test_info_from_stack() -> tuple[str, str | None, str, str]:
@ -37,7 +37,7 @@ def get_test_info_from_stack() -> tuple[str, str | None, str, str]:
return test_module_name, test_class_name, test_name, line_id
def codeflash_capture(function_name: str, tmp_dir_path: str):
def codeflash_capture(function_name: str, tmp_dir_path: str, is_fto: bool = False):
"""Defines decorator to be instrumented onto the init function in the code. Collects info of the test that called this, and captures the state of the instance."""
def decorator(wrapped):
@ -120,7 +120,7 @@ def codeflash_capture(function_name: str, tmp_dir_path: str):
invocation_id,
codeflash_duration,
pickled_return_value,
VerificationType.INSTANCE_STATE,
VerificationType.INSTANCE_STATE_FTO if is_fto else VerificationType.INSTANCE_STATE_HELPER,
),
)
codeflash_con.commit()

View file

@ -23,7 +23,7 @@ def compare_test_results(original_results: TestResults, candidate_results: TestR
cdd_test_result = candidate_results.get_by_unique_invocation_loop_id(test_id)
if cdd_test_result is not None and original_test_result is None:
continue
# If helper function instance_state verification is not present, that's ok. continue
if original_test_result is None or cdd_test_result is None:
are_equal = False
break

View file

@ -1,3 +1,5 @@
from __future__ import annotations
import ast
from pathlib import Path
@ -5,38 +7,56 @@ from codeflash.code_utils.code_utils import get_run_tmp_file
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
def instrument_code(function_to_optimize: FunctionToOptimize, file_path_to_helper_class: dict[Path, set[str]]) -> None:
    """Instrument __init__ functions with the codeflash_capture decorator.

    The class that owns the function to optimize (fto) is instrumented with
    ``is_fto=True``; every class listed in ``file_path_to_helper_class`` is
    instrumented with ``is_fto=False``. Each affected file is rewritten in
    place. Nothing is done, for the fto or for the helpers, unless the fto
    has exactly one parent and that parent is a ``ClassDef``.

    Args:
        function_to_optimize: The function whose owning class is instrumented.
        file_path_to_helper_class: Maps a file path to the names of the helper
            classes in that file. Mutated in place: the fto's own class name is
            dropped from the entry for the fto's file so that the class is not
            instrumented a second time as a helper.

    """
    # Find the class parent
    parents = function_to_optimize.parents
    if len(parents) != 1 or parents[0].type != "ClassDef":
        return
    class_parent = parents[0]

    # Remove duplicate fto class from helper classes.
    # discard() instead of remove(): the fto's file may only list helpers from
    # other classes, in which case remove() would raise KeyError.
    if function_to_optimize.file_path in file_path_to_helper_class:
        file_path_to_helper_class[function_to_optimize.file_path].discard(class_parent.name)

    # Instrument fto class
    with open(function_to_optimize.file_path) as f:
        original_code = f.read()
    # Add decorator to init
    modified_code = add_codeflash_capture_to_init(
        target_classes={class_parent.name},
        fto_name=function_to_optimize.function_name,
        tmp_dir_path=str(get_run_tmp_file(Path("test_return_values"))),
        code=original_code,
        is_fto=True,
    )
    # Write modified code back to file
    with open(function_to_optimize.file_path, "w") as f:
        f.write(modified_code)

    # Instrument helper classes. The fto's file is re-read here, so the
    # decorator written above is preserved when helpers share that file.
    for file_path, helper_classes in file_path_to_helper_class.items():
        with open(file_path) as f:
            original_code = f.read()
        modified_code = add_codeflash_capture_to_init(
            target_classes=helper_classes,
            fto_name=function_to_optimize.function_name,
            tmp_dir_path=str(get_run_tmp_file(Path("test_return_values"))),
            code=original_code,
            is_fto=False,
        )
        with open(file_path, "w") as f:
            f.write(modified_code)
def add_codeflash_capture_to_init(class_name: str, function_name: str, tmp_dir_path: str, code: str) -> str:
def add_codeflash_capture_to_init(
target_classes: set[str], fto_name: str, tmp_dir_path: str, code: str, is_fto: bool = False
) -> str:
"""Add codeflash_capture decorator to __init__ function in the specified class."""
# Parse the code into an AST
tree = ast.parse(code)
# Apply our transformation
transformer = InitDecorator(class_name, function_name, tmp_dir_path)
transformer = InitDecorator(target_classes, fto_name, tmp_dir_path, is_fto)
modified_tree = transformer.visit(tree)
if transformer.inserted_decorator:
ast.fix_missing_locations(modified_tree)
@ -48,10 +68,11 @@ def add_codeflash_capture_to_init(class_name: str, function_name: str, tmp_dir_p
class InitDecorator(ast.NodeTransformer):
"""AST transformer that adds codeflash_capture decorator to specific class's __init__."""
def __init__(self, target_class_name: str, function_name: str, tmp_dir_path: str):
self.target_class_name = target_class_name
self.function_name = function_name
def __init__(self, target_classes: set[str], fto_name: str, tmp_dir_path: str, is_fto=False) -> None:
self.target_classes = target_classes
self.fto_name = fto_name
self.tmp_dir_path = tmp_dir_path
self.is_fto = is_fto
self.has_import = False
self.inserted_decorator = False
@ -74,7 +95,7 @@ class InitDecorator(ast.NodeTransformer):
def visit_ClassDef(self, node: ast.ClassDef) -> ast.ClassDef:
# Only modify the target class
if node.name != self.target_class_name:
if node.name not in self.target_classes:
return node
# Look for __init__ method
@ -85,8 +106,9 @@ class InitDecorator(ast.NodeTransformer):
func=ast.Name(id="codeflash_capture", ctx=ast.Load()),
args=[],
keywords=[
ast.keyword(arg="function_name", value=ast.Constant(value=self.function_name)),
ast.keyword(arg="function_name", value=ast.Constant(value=self.fto_name)),
ast.keyword(arg="tmp_dir_path", value=ast.Constant(value=self.tmp_dir_path)),
ast.keyword(arg="is_fto", value=ast.Constant(value=self.is_fto)),
],
)

View file

@ -21,7 +21,7 @@ from codeflash.code_utils.code_utils import (
)
from codeflash.discovery.discover_unit_tests import discover_parameters_unittest
from codeflash.models.models import CoverageData, TestFiles
from codeflash.verification.test_results import FunctionTestInvocation, InvocationId, TestResults
from codeflash.verification.test_results import FunctionTestInvocation, InvocationId, TestResults, VerificationType
if TYPE_CHECKING:
import subprocess
@ -92,6 +92,7 @@ def parse_test_return_values_bin(file_location: Path, test_files: TestFiles, tes
test_type=test_type,
return_value=test_pickle,
timed_out=False,
verification_type=VerificationType.FUNCTION_TO_OPTIMIZE,
)
)
return test_results
@ -141,7 +142,7 @@ def parse_sqlite_test_results(sqlite_file_path: Path, test_files: TestFiles, tes
test_type=test_type,
return_value=ret_val,
timed_out=False,
verification_type=verification_type if verification_type else None,
verification_type=VerificationType(verification_type) if verification_type else None,
)
)
except Exception:
@ -376,6 +377,7 @@ def merge_test_results(
test_type=xml_result.test_type,
return_value=result_bin.return_value,
timed_out=xml_result.timed_out,
verification_type=VerificationType(result_bin.verification_type),
)
)
elif xml_results.test_results[0].id.iteration_id is not None:
@ -402,6 +404,7 @@ def merge_test_results(
timed_out=xml_result.timed_out
if bin_result.runtime is None
else False, # If runtime was measured in the bin file, then the testcase did not time out
verification_type=VerificationType(bin_result.verification_type),
)
)
else:
@ -425,6 +428,7 @@ def merge_test_results(
test_type=bin_result.test_type,
return_value=bin_result.return_value,
timed_out=xml_result.timed_out, # only the xml gets the timed_out flag
verification_type=VerificationType(bin_result.verification_type),
)
)

View file

@ -13,6 +13,19 @@ from codeflash.cli_cmds.console import DEBUG_MODE, logger
from codeflash.verification.comparator import comparator
class VerificationType(str, Enum):
    """Kind of correctness verification a recorded test invocation belongs to.

    Members are also ``str`` instances, so they compare equal to their raw
    string values.
    """

    # Correctness verification for fto, checks input values and output values
    FUNCTION_TO_OPTIMIZE = "function_to_optimize"
    # Correctness verification for instance state of fto, checks instance
    # attributes right after __init__ is called
    INSTANCE_STATE_FTO = "instance_state_fto"
    # Correctness verification for instance state of helper classes, checks
    # instance attributes right after __init__ is called
    INSTANCE_STATE_HELPER = "instance_state_helper"

    # The return annotation is quoted because the class name is not bound yet
    # while the class body executes; it is not optional, since a member object
    # is always returned.
    def __new__(cls, value: str) -> "VerificationType":
        obj = str.__new__(cls, value)
        # An empty string would be stored with a ``None`` value; none of the
        # members declared above use an empty string.
        obj._value_ = value if value != "" else None
        return obj
class TestType(Enum):
EXISTING_UNIT_TEST = 1
INSPIRED_REGRESSION = 2
@ -74,7 +87,7 @@ class FunctionTestInvocation:
test_type: TestType
return_value: Optional[object] # The return value of the function invocation
timed_out: Optional[bool]
verification_type: Optional[str]
verification_type: str = VerificationType.FUNCTION_TO_OPTIMIZE
@property
def unique_invocation_loop_id(self) -> str:

View file

@ -744,14 +744,15 @@ def test_superset():
def __init__(self):
self.a = 1
class B(A):
def __init__(self):
super().__init__()
self.b = 2
obj = A()
obj.x = 3
assert comparator(A(), B(), superset_obj=True)
assert not comparator(B(), A(), superset_obj=True)
assert not comparator(A(), B())
assert comparator(A(), obj, superset_obj=True)
assert not comparator(obj, A(), superset_obj=True)
assert not comparator(A(), obj)
assert not comparator(obj, A())
assert comparator(obj, obj, superset_obj=True)
assert comparator(obj, obj)
def test_compare_results_fn():

View file

@ -1,231 +1,231 @@
from textwrap import dedent
import libcst as cst
from codeflash.code_utils.code_replacer import merge_init_functions, replace_functions_in_file
from codeflash.models.models import FunctionParent
def test_basic_merge() -> None:
original = """
class MyClass:
def __init__(self, a, b):
self.a = a
self.b = b
"""
new = """
class MyClass:
def __init__(self, x, y):
self.x = x
self.y = y
"""
result = merge_init_functions(
cst.parse_module(dedent(original)).body[0].body.body[0], cst.parse_module(dedent(new)).body[0].body.body[0]
)
expected = """
def __init__(self, a, b):
self.a = a
self.b = b
self.x = x
self.y = y
"""
assert cst.Module([result]).code.strip() == dedent(expected).strip()
def test_prevent_duplication() -> None:
original = """
class MyClass:
def __init__(self):
self.x = 1
print("init")
self.setup()
"""
new = """
class MyClass:
def __init__(self):
print("init")
self.setup()
self.y = 2
"""
result = merge_init_functions(
cst.parse_module(dedent(original)).body[0].body.body[0], cst.parse_module(dedent(new)).body[0].body.body[0]
)
expected = """
def __init__(self):
self.x = 1
print("init")
self.setup()
self.y = 2
"""
assert cst.Module([result]).code.strip() == dedent(expected).strip()
def test_prevent_overwrite() -> None:
original = """
class MyClass:
def __init__(self):
self.x = 1
self.y = 2
"""
new = """
class MyClass:
def __init__(self):
self.x = 2
"""
result = merge_init_functions(
cst.parse_module(dedent(original)).body[0].body.body[0], cst.parse_module(dedent(new)).body[0].body.body[0]
)
expected = """
def __init__(self):
self.x = 1
self.y = 2
"""
assert cst.Module([result]).code.strip() == dedent(expected).strip()
def test_complex_control_flow() -> None:
original = """
class MyClass:
def __init__(self):
with self.lock:
self.setup()
if self.debug:
self.enable_logging()
"""
new = """
class MyClass:
def __init__(self):
try:
self.connect()
except ConnectionError:
self.fallback()
"""
result = merge_init_functions(
cst.parse_module(dedent(original)).body[0].body.body[0], cst.parse_module(dedent(new)).body[0].body.body[0]
)
expected = """
def __init__(self):
with self.lock:
self.setup()
if self.debug:
self.enable_logging()
try:
self.connect()
except ConnectionError:
self.fallback()
"""
assert cst.Module([result]).code.strip() == dedent(expected).strip()
def test_docstrings_and_comments() -> None:
original = """
class MyClass:
def __init__(self):
# Setup configuration
self.config = {} # Empty config
"""
new = """
class MyClass:
def __init__(self):
\"\"\"New docstring.\"\"\"
# Initialize database
self.db = None # Database connection
"""
result = merge_init_functions(
cst.parse_module(dedent(original)).body[0].body.body[0], cst.parse_module(dedent(new)).body[0].body.body[0]
)
expected = """
def __init__(self):
# Setup configuration
self.config = {} # Empty config
# Initialize database
self.db = None # Database connection
"""
assert cst.Module([result]).code.strip() == dedent(expected).strip()
def test_type_annotations() -> None:
original = """
class MyClass:
def __init__(self) -> None:
self.x: int = 1
self.y: str = "hello"
"""
new = """
class MyClass:
def __init__(self):
self.y: str = "new hello"
self.z: float = 2.0
"""
result = merge_init_functions(
cst.parse_module(dedent(original)).body[0].body.body[0], cst.parse_module(dedent(new)).body[0].body.body[0]
)
expected = """
def __init__(self) -> None:
self.x: int = 1
self.y: str = "hello"
self.z: float = 2.0
"""
assert cst.Module([result]).code.strip() == dedent(expected).strip()
# Tests for code replacement with init
def test_merge_init_methods() -> None:
optim_code = """class MyClass:
def __init__(self):
self.y = 2
self.z = 3
"""
original_code = """class MyClass:
def __init__(self):
self.y = 1
self.setup()
"""
expected = """class MyClass:
def __init__(self):
self.y = 1
self.setup()
self.z = 3
"""
result = replace_functions_in_file(
source_code=original_code,
original_function_names=[],
optimized_code=optim_code,
preexisting_objects=[("__init__", [FunctionParent(name="MyClass", type="ClassDef")])],
)
assert result == expected
def test_init_is_function_to_optimize() -> None:
optim_code = """class MyClass:
def __init__(self):
self.y = 2
self.z = 3
"""
original_code = """class MyClass:
def __init__(self):
self.y = 1
self.setup()
"""
expected = """class MyClass:
def __init__(self):
self.y = 2
self.z = 3
"""
# In this scenario, we leave the mutation check to the usual FTO behaviour check.
result = replace_functions_in_file(
source_code=original_code,
original_function_names=["MyClass.__init__"],
optimized_code=optim_code,
preexisting_objects=[("__init__", [FunctionParent(name="MyClass", type="ClassDef")])],
)
assert result == expected
# from textwrap import dedent
#
# import libcst as cst
# from codeflash.code_utils.code_replacer import merge_init_functions, replace_functions_in_file
# from codeflash.models.models import FunctionParent
#
#
# def test_basic_merge() -> None:
# original = """
# class MyClass:
# def __init__(self, a, b):
# self.a = a
# self.b = b
# """
# new = """
# class MyClass:
# def __init__(self, x, y):
# self.x = x
# self.y = y
# """
# result = merge_init_functions(
# cst.parse_module(dedent(original)).body[0].body.body[0], cst.parse_module(dedent(new)).body[0].body.body[0]
# )
#
# expected = """
# def __init__(self, a, b):
# self.a = a
# self.b = b
# self.x = x
# self.y = y
# """
# assert cst.Module([result]).code.strip() == dedent(expected).strip()
#
#
# def test_prevent_duplication() -> None:
# original = """
# class MyClass:
# def __init__(self):
# self.x = 1
# print("init")
# self.setup()
# """
# new = """
# class MyClass:
# def __init__(self):
# print("init")
# self.setup()
# self.y = 2
# """
# result = merge_init_functions(
# cst.parse_module(dedent(original)).body[0].body.body[0], cst.parse_module(dedent(new)).body[0].body.body[0]
# )
#
# expected = """
# def __init__(self):
# self.x = 1
# print("init")
# self.setup()
# self.y = 2
# """
# assert cst.Module([result]).code.strip() == dedent(expected).strip()
#
#
# def test_prevent_overwrite() -> None:
# original = """
# class MyClass:
# def __init__(self):
# self.x = 1
# self.y = 2
# """
# new = """
# class MyClass:
# def __init__(self):
# self.x = 2
# """
# result = merge_init_functions(
# cst.parse_module(dedent(original)).body[0].body.body[0], cst.parse_module(dedent(new)).body[0].body.body[0]
# )
#
# expected = """
# def __init__(self):
# self.x = 1
# self.y = 2
# """
# assert cst.Module([result]).code.strip() == dedent(expected).strip()
#
#
# def test_complex_control_flow() -> None:
# original = """
# class MyClass:
# def __init__(self):
# with self.lock:
# self.setup()
# if self.debug:
# self.enable_logging()
# """
# new = """
# class MyClass:
# def __init__(self):
# try:
# self.connect()
# except ConnectionError:
# self.fallback()
# """
# result = merge_init_functions(
# cst.parse_module(dedent(original)).body[0].body.body[0], cst.parse_module(dedent(new)).body[0].body.body[0]
# )
#
# expected = """
# def __init__(self):
# with self.lock:
# self.setup()
# if self.debug:
# self.enable_logging()
# try:
# self.connect()
# except ConnectionError:
# self.fallback()
# """
# assert cst.Module([result]).code.strip() == dedent(expected).strip()
#
#
# def test_docstrings_and_comments() -> None:
# original = """
# class MyClass:
# def __init__(self):
# # Setup configuration
# self.config = {} # Empty config
# """
# new = """
# class MyClass:
# def __init__(self):
# \"\"\"New docstring.\"\"\"
# # Initialize database
# self.db = None # Database connection
# """
# result = merge_init_functions(
# cst.parse_module(dedent(original)).body[0].body.body[0], cst.parse_module(dedent(new)).body[0].body.body[0]
# )
# expected = """
# def __init__(self):
# # Setup configuration
# self.config = {} # Empty config
# # Initialize database
# self.db = None # Database connection
# """
# assert cst.Module([result]).code.strip() == dedent(expected).strip()
#
#
# def test_type_annotations() -> None:
# original = """
# class MyClass:
# def __init__(self) -> None:
# self.x: int = 1
# self.y: str = "hello"
# """
# new = """
# class MyClass:
# def __init__(self):
# self.y: str = "new hello"
# self.z: float = 2.0
# """
# result = merge_init_functions(
# cst.parse_module(dedent(original)).body[0].body.body[0], cst.parse_module(dedent(new)).body[0].body.body[0]
# )
#
# expected = """
# def __init__(self) -> None:
# self.x: int = 1
# self.y: str = "hello"
# self.z: float = 2.0
# """
# assert cst.Module([result]).code.strip() == dedent(expected).strip()
#
#
# # Tests for code replacement with init
# def test_merge_init_methods() -> None:
# optim_code = """class MyClass:
# def __init__(self):
# self.y = 2
# self.z = 3
# """
#
# original_code = """class MyClass:
# def __init__(self):
# self.y = 1
# self.setup()
# """
#
# expected = """class MyClass:
# def __init__(self):
# self.y = 1
# self.setup()
# self.z = 3
# """
#
# result = replace_functions_in_file(
# source_code=original_code,
# original_function_names=[],
# optimized_code=optim_code,
# preexisting_objects=[("__init__", [FunctionParent(name="MyClass", type="ClassDef")])],
# )
# assert result == expected
#
#
# def test_init_is_function_to_optimize() -> None:
# optim_code = """class MyClass:
# def __init__(self):
# self.y = 2
# self.z = 3
# """
#
# original_code = """class MyClass:
# def __init__(self):
# self.y = 1
# self.setup()
# """
#
# expected = """class MyClass:
# def __init__(self):
# self.y = 2
# self.z = 3
# """
# # In this scenario, we leave the mutation check to the usual FTO behaviour check.
# result = replace_functions_in_file(
# source_code=original_code,
# original_function_names=["MyClass.__init__"],
# optimized_code=optim_code,
# preexisting_objects=[("__init__", [FunctionParent(name="MyClass", type="ClassDef")])],
# )
# assert result == expected

View file

@ -7,7 +7,6 @@ from codeflash.verification.instrument_code import instrument_code
def test_add_codeflash_capture():
# Test input code
original_code = """
class MyClass:
def __init__(self):
@ -22,39 +21,30 @@ from codeflash.verification.codeflash_capture import codeflash_capture
class MyClass:
@codeflash_capture(function_name='target_function', tmp_dir_path='{get_run_tmp_file(Path("test_return_values"))!s}')
@codeflash_capture(function_name='target_function', tmp_dir_path='{get_run_tmp_file(Path("test_return_values"))!s}', is_fto=True)
def __init__(self):
self.x = 1
def target_function(self):
return self.x + 1
"""
# Create and modify test file
test_file = Path("test_file.py")
test_file.write_text(original_code)
test_path = (Path(__file__).parent.resolve() / "../code_to_optimize/tests/pytest/test_file.py").resolve()
test_path.write_text(original_code)
function = FunctionToOptimize(
function_name="target_function",
file_path=Path("test_file.py"),
parents=[FunctionParent(type="ClassDef", name="MyClass")],
function_name="target_function", file_path=test_path, parents=[FunctionParent(type="ClassDef", name="MyClass")]
)
try:
# Run the instrumentation
instrument_code(function)
# Check the result
modified_code = test_file.read_text()
instrument_code(function, {})
modified_code = test_path.read_text()
assert modified_code.strip() == expected.strip()
finally:
# Cleanup
test_file.unlink(missing_ok=True)
test_path.unlink(missing_ok=True)
def test_add_codeflash_capture_no_parent():
# Test input code
original_code = """
class MyClass:
@ -68,24 +58,17 @@ class MyClass:
def target_function(self):
return self.x + 1
"""
test_path = (Path(__file__).parent.resolve() / "../code_to_optimize/tests/pytest/test_file.py").resolve()
test_path.write_text(original_code)
# Create and modify test file
test_file = Path("test_file.py")
test_file.write_text(original_code)
function = FunctionToOptimize(function_name="target_function", file_path=Path("test_file.py"), parents=[])
function = FunctionToOptimize(function_name="target_function", file_path=test_path, parents=[])
try:
# Run the instrumentation
instrument_code(function)
# Check the result
modified_code = test_file.read_text()
instrument_code(function, {})
modified_code = test_path.read_text()
assert modified_code.strip() == expected.strip()
finally:
# Cleanup
test_file.unlink(missing_ok=True)
test_path.unlink(missing_ok=True)
def test_add_codeflash_capture_no_init():
@ -102,32 +85,263 @@ from codeflash.verification.codeflash_capture import codeflash_capture
class MyClass(ParentClass):
@codeflash_capture(function_name='target_function', tmp_dir_path='{get_run_tmp_file(Path("test_return_values"))!s}')
@codeflash_capture(function_name='target_function', tmp_dir_path='{get_run_tmp_file(Path("test_return_values"))!s}', is_fto=True)
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def target_function(self):
return self.x + 1
"""
# Create and modify test file
test_file = Path("test_file.py")
test_file.write_text(original_code)
test_path = (Path(__file__).parent.resolve() / "../code_to_optimize/tests/pytest/test_file.py").resolve()
test_path.write_text(original_code)
function = FunctionToOptimize(
function_name="target_function",
file_path=Path("test_file.py"),
parents=[FunctionParent(type="ClassDef", name="MyClass")],
function_name="target_function", file_path=test_path, parents=[FunctionParent(type="ClassDef", name="MyClass")]
)
try:
# Run the instrumentation
instrument_code(function)
# Check the result
modified_code = test_file.read_text()
instrument_code(function, {})
modified_code = test_path.read_text()
assert modified_code.strip() == expected.strip()
finally:
# Cleanup
test_file.unlink(missing_ok=True)
test_path.unlink(missing_ok=True)
def test_add_codeflash_capture_with_helpers():
# Instruments a file where the helper-class map names the very class that owns the
# function being optimized, and checks the rewritten file text against `expected`.
# Test input code
# NOTE(review): the fixture's target_function calls helper() rather than self.helper();
# this test only compares file text, so it presumably does not matter here - confirm.
original_code = """
class MyClass:
def __init__(self):
self.x = 1
def target_function(self):
return helper() + 1
def helper(self):
return self.x
"""
# Expected output: the codeflash_capture import is added and MyClass.__init__ carries a
# single decorator marked is_fto=True (no second, is_fto=False decorator).
expected = f"""
from codeflash.verification.codeflash_capture import codeflash_capture
class MyClass:
@codeflash_capture(function_name='target_function', tmp_dir_path='{get_run_tmp_file(Path("test_return_values"))!s}', is_fto=True)
def __init__(self):
self.x = 1
def target_function(self):
return helper() + 1
def helper(self):
return self.x
"""
# The fixture is written next to the project's pytest tests and removed in `finally`.
test_path = (Path(__file__).parent.resolve() / "../code_to_optimize/tests/pytest/test_file.py").resolve()
test_path.write_text(original_code)
function = FunctionToOptimize(
function_name="target_function", file_path=test_path, parents=[FunctionParent(type="ClassDef", name="MyClass")]
)
try:
instrument_code(
function, {test_path: {"MyClass"}}
) # MyClass is passed as a helper class but is also the class of the function to optimize (FTO); the expected text shows it decorated only once, as the FTO class
modified_code = test_path.read_text()
assert modified_code.strip() == expected.strip()
finally:
# Always delete the temporary fixture, even when the assertion fails.
test_path.unlink(missing_ok=True)
def test_add_codeflash_capture_with_helpers_2():
    """Instrument a target class whose helper class lives in a separate file.

    Two fixture files are written to disk, ``instrument_code`` is run with a
    helper map pointing at the second file, and the resulting text of BOTH
    files is compared against the expected decorated source. The fixtures are
    always removed afterwards.
    """
    # Test input code
    original_code = """
from test_helper_file import HelperClass
class MyClass:
def __init__(self):
self.x = 1
def target_function(self):
return HelperClass().helper() + 1
"""
    original_helper = """
class HelperClass:
def __init__(self):
self.y = 1
def helper(self):
return 1
"""
    expected = f"""
from codeflash.verification.codeflash_capture import codeflash_capture
from test_helper_file import HelperClass
class MyClass:
@codeflash_capture(function_name='target_function', tmp_dir_path='{get_run_tmp_file(Path("test_return_values"))!s}', is_fto=True)
def __init__(self):
self.x = 1
def target_function(self):
return HelperClass().helper() + 1
"""
    # Helper classes are tagged with the name of the function being optimized
    # (not the helper method's own name), matching the asserted expectations in
    # test_add_codeflash_capture_with_multiple_helpers.
    expected_helper = f"""
from codeflash.verification.codeflash_capture import codeflash_capture
class HelperClass:
@codeflash_capture(function_name='target_function', tmp_dir_path='{get_run_tmp_file(Path("test_return_values"))!s}', is_fto=False)
def __init__(self):
self.y = 1
def helper(self):
return 1
"""
    test_path = (Path(__file__).parent.resolve() / "../code_to_optimize/tests/pytest/test_file.py").resolve()
    helper_path = (Path(__file__).parent.resolve() / "../code_to_optimize/tests/pytest/test_helper_file.py").resolve()
    function = FunctionToOptimize(
        function_name="target_function", file_path=test_path, parents=[FunctionParent(type="ClassDef", name="MyClass")]
    )
    try:
        # Write the fixtures inside the try so a failure on the second write
        # still cleans up the first file.
        test_path.write_text(original_code)
        helper_path.write_text(original_helper)
        instrument_code(function, {helper_path: {"HelperClass"}})
        modified_code = test_path.read_text()
        modified_helper = helper_path.read_text()
        assert modified_code.strip() == expected.strip()
        # Previously expected_helper was built but never checked, so the helper
        # file's instrumentation went unverified.
        assert modified_helper.strip() == expected_helper.strip()
    finally:
        test_path.unlink(missing_ok=True)
        helper_path.unlink(missing_ok=True)
def test_add_codeflash_capture_with_multiple_helpers():
# Instruments one target file plus two helper files (three helper classes in total)
# and checks the rewritten text of all three files.
# Test input code with imports from two helper files
original_code = """
from helper_file_1 import HelperClass1
from helper_file_2 import HelperClass2, AnotherHelperClass
class MyClass:
def __init__(self):
self.x = 1
def target_function(self):
helper1 = HelperClass1().helper1()
helper2 = HelperClass2().helper2()
another = AnotherHelperClass().another_helper()
return helper1 + helper2 + another
"""
# First helper file content
original_helper1 = """
class HelperClass1:
def __init__(self):
self.y = 1
def helper1(self):
return 1
"""
# Second helper file content (AnotherHelperClass deliberately has no __init__)
original_helper2 = """
class HelperClass2:
def __init__(self):
self.z = 2
def helper2(self):
return 2
class AnotherHelperClass:
def another_helper(self):
return 3
"""
# Expected output code with decorators: the target class is marked is_fto=True
expected = f"""
from codeflash.verification.codeflash_capture import codeflash_capture
from helper_file_1 import HelperClass1
from helper_file_2 import HelperClass2, AnotherHelperClass
class MyClass:
@codeflash_capture(function_name='target_function', tmp_dir_path='{get_run_tmp_file(Path("test_return_values"))!s}', is_fto=True)
def __init__(self):
self.x = 1
def target_function(self):
helper1 = HelperClass1().helper1()
helper2 = HelperClass2().helper2()
another = AnotherHelperClass().another_helper()
return helper1 + helper2 + another
"""
# Expected output for first helper file: helper classes are marked is_fto=False and
# carry the target function's name, not the helper method's name
expected_helper1 = f"""
from codeflash.verification.codeflash_capture import codeflash_capture
class HelperClass1:
@codeflash_capture(function_name='target_function', tmp_dir_path='{get_run_tmp_file(Path("test_return_values"))!s}', is_fto=False)
def __init__(self):
self.y = 1
def helper1(self):
return 1
"""
# Expected output for second helper file: AnotherHelperClass, which had no __init__,
# is expected to gain a decorated pass-through __init__ that forwards to super()
expected_helper2 = f"""
from codeflash.verification.codeflash_capture import codeflash_capture
class HelperClass2:
@codeflash_capture(function_name='target_function', tmp_dir_path='{get_run_tmp_file(Path("test_return_values"))!s}', is_fto=False)
def __init__(self):
self.z = 2
def helper2(self):
return 2
class AnotherHelperClass:
@codeflash_capture(function_name='target_function', tmp_dir_path='{get_run_tmp_file(Path("test_return_values"))!s}', is_fto=False)
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def another_helper(self):
return 3
"""
# Set up test files
test_path = (Path(__file__).parent.resolve() / "../code_to_optimize/tests/pytest/test_file.py").resolve()
helper1_path = (Path(__file__).parent.resolve() / "../code_to_optimize/tests/pytest/helper_file_1.py").resolve()
helper2_path = (Path(__file__).parent.resolve() / "../code_to_optimize/tests/pytest/helper_file_2.py").resolve()
# Write original content to files
test_path.write_text(original_code)
helper1_path.write_text(original_helper1)
helper2_path.write_text(original_helper2)
# Create FunctionToOptimize instance
function = FunctionToOptimize(
function_name="target_function", file_path=test_path, parents=[FunctionParent(type="ClassDef", name="MyClass")]
)
try:
# Instrument code with multiple helper files
# The helper map goes from each helper file's path to the set of class names to decorate in it
helper_classes = {helper1_path: {"HelperClass1"}, helper2_path: {"HelperClass2", "AnotherHelperClass"}}
instrument_code(function, helper_classes)
# Verify the modifications
modified_code = test_path.read_text()
modified_helper1 = helper1_path.read_text()
modified_helper2 = helper2_path.read_text()
assert modified_code.strip() == expected.strip()
assert modified_helper1.strip() == expected_helper1.strip()
assert modified_helper2.strip() == expected_helper2.strip()
finally:
# Clean up test files
test_path.unlink(missing_ok=True)
helper1_path.unlink(missing_ok=True)
helper2_path.unlink(missing_ok=True)

View file

@ -7,10 +7,12 @@ from pathlib import Path
import isort
from code_to_optimize.bubble_sort_method import BubbleSorter
from codeflash.code_utils.code_utils import get_run_tmp_file
from codeflash.models.models import TestFile, TestFiles, TestingMode
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.models.models import FunctionParent, TestFile, TestFiles, TestingMode
from codeflash.optimization.optimizer import Optimizer
from codeflash.verification.comparator import comparator
from codeflash.verification.equivalence import compare_test_results
from codeflash.verification.instrument_code import instrument_code
from codeflash.verification.test_results import TestType
# Used by aiservice instrumentation
@ -85,7 +87,7 @@ def codeflash_wrap(
"""
def test_aiservice_class_method_behavior_results() -> None:
def test_class_method_behavior_results() -> None:
test_source = """import pytest
from code_to_optimize.bubble_sort import BubbleSorter
@ -191,12 +193,10 @@ def test_single_element_list():
# Replace with optimized code that mutated instance attribute
optimized_code_mutated_attr = """
class BubbleSorter:
z = 1234
def __init__(self, x=1):
self.x = x
def new_sorter():
pass
def sorter(self, arr):
for i in range(len(arr)):
for j in range(len(arr) - 1):
@ -216,16 +216,13 @@ class BubbleSorter:
pytest_max_loops=1,
testing_time=0.1,
)
print(test_results_mutated_attr[0].return_value[1])
assert test_results_mutated_attr[0].return_value[1]["self"].x == 1
assert test_results_mutated_attr[0].return_value[1]["self"].z != 1234
assert compare_test_results(
assert not compare_test_results(
test_results, test_results_mutated_attr
) # The test should fail because the instance attribute was mutated
# Replace with optimized code that did not mutate existing instance attribute, but added a new one
optimized_code_new_attr = """
class BubbleSorter:
z = 0
def __init__(self, x=0):
self.x = x
self.y = 2
@ -249,34 +246,35 @@ class BubbleSorter:
pytest_max_loops=1,
testing_time=0.1,
)
# assert compare_test_results(
# test_results, test_results_new_attr
# ) # The test should pass because the instance attribute was not mutated, only a new one was added
assert not compare_test_results(
test_results, test_results_new_attr
) # Expected to compare unequal: this variant's __init__ sets an extra instance attribute (self.y)
finally:
fto_path.write_text(original_code, "utf-8")
test_path.unlink(missing_ok=True)
test_path_perf.unlink(missing_ok=True)
def test_aiservice_class_method_behavior_results_2() -> None:
def test_class_method_behavior_results_with_codeflash_capture() -> None:
test_source = """import pytest
from code_to_optimize.bubble_sort import BubbleSorter
obj = BubbleSorter()
def test_single_element_list():
# Test that a single element list returns the same single element
test_list = [42]
codeflash_output = obj.sorter(test_list)
obj = BubbleSorter()
codeflash_output = obj.sorter([3,2,1])
"""
instrumented_behavior_test_source = (
behavior_logging_code
+ """
import pytest
from code_to_optimize.bubble_sort_method import BubbleSorter
obj = BubbleSorter()
def test_single_element_list():
codeflash_loop_index = int(os.environ["CODEFLASH_LOOP_INDEX"])
test_list = [42]
_call__bound__arguments = inspect.signature(BubbleSorter.sorter).bind(obj,[42])
obj = BubbleSorter()
_call__bound__arguments = inspect.signature(BubbleSorter.sorter).bind(obj,[3,2,1])
_call__bound__arguments.apply_defaults()
codeflash_return_value = codeflash_wrap(
@ -309,6 +307,8 @@ def test_single_element_list():
os.chdir(run_cwd)
fto_path = (Path(__file__).parent.resolve() / "../code_to_optimize/bubble_sort_method.py").resolve()
original_code = fto_path.read_text("utf-8")
function_to_optimize = FunctionToOptimize("sorter", fto_path, [FunctionParent("BubbleSorter", "ClassDef")])
try:
temp_run_dir = get_run_tmp_file(Path()).as_posix()
instrumented_behavior_test_source = instrumented_behavior_test_source.replace(
@ -316,7 +316,8 @@ def test_single_element_list():
)
with test_path.open("w") as f:
f.write(instrumented_behavior_test_source)
# Add codeflash capture decorator
instrument_code(function_to_optimize)
opt = Optimizer(
Namespace(
project_root=project_root_path,
@ -328,7 +329,6 @@ def test_single_element_list():
test_project_root=project_root_path,
)
)
test_env = os.environ.copy()
test_env["CODEFLASH_TEST_ITERATION"] = "0"
test_env["CODEFLASH_LOOP_INDEX"] = "1"
@ -352,17 +352,28 @@ def test_single_element_list():
pytest_max_loops=1,
testing_time=0.1,
)
# Verify instance_state result, which checks instance state right after __init__, using codeflash_capture
assert test_results[0].id.function_getting_tested == "sorter"
assert test_results[0].id.test_function_name == "test_single_element_list"
assert test_results[0].did_pass
assert test_results[0].return_value[1]["arr"] == [42]
assert comparator(test_results[0].return_value[1]["self"], BubbleSorter())
assert not comparator(test_results[0].return_value[1]["self"], BubbleSorter(42))
assert test_results[0].return_value[2] == [42]
assert test_results[0].return_value[0] == {"x": 0}
# Verify function_to_optimize result
assert test_results[1].id.function_getting_tested == "sorter"
assert test_results[1].id.test_function_name == "test_single_element_list"
assert test_results[1].did_pass
# Checks input values to the function to see if they have mutated
assert comparator(test_results[1].return_value[1]["self"], BubbleSorter())
assert test_results[1].return_value[1]["arr"] == [1, 2, 3]
# Check function return value
assert test_results[1].return_value[2] == [1, 2, 3]
# Replace with optimized code that mutated instance attribute
optimized_code_mutated_attr = """
class BubbleSorter:
z = 3
def __init__(self, x=1):
self.x = x
@ -374,8 +385,9 @@ class BubbleSorter:
arr[j] = arr[j + 1]
arr[j + 1] = temp
return arr
"""
"""
fto_path.write_text(optimized_code_mutated_attr, "utf-8")
instrument_code(function_to_optimize)
test_results_mutated_attr, coverage_data = opt.run_and_parse_tests(
testing_type=TestingMode.BEHAVIOR,
test_env=test_env,
@ -385,14 +397,13 @@ class BubbleSorter:
pytest_max_loops=1,
testing_time=0.1,
)
assert test_results_mutated_attr[1].return_value[1]["self"].x == 1
assert not compare_test_results(
test_results, test_results_mutated_attr
) # The test should fail because the instance attribute was mutated
# Replace with optimized code that did not mutate existing instance attribute, but added a new one
optimized_code_new_attr = """
class BubbleSorter:
z = 3
def __init__(self, x=0):
self.x = x
self.y = 2
@ -405,8 +416,9 @@ class BubbleSorter:
arr[j] = arr[j + 1]
arr[j + 1] = temp
return arr
"""
"""
fto_path.write_text(optimized_code_new_attr, "utf-8")
instrument_code(function_to_optimize)
test_results_new_attr, coverage_data = opt.run_and_parse_tests(
testing_type=TestingMode.BEHAVIOR,
test_env=test_env,
@ -416,11 +428,11 @@ class BubbleSorter:
pytest_max_loops=1,
testing_time=0.1,
)
assert compare_test_results(
assert test_results_new_attr[1].return_value[1]["self"].x == 0
assert test_results_new_attr[1].return_value[1]["self"].y == 2
assert not compare_test_results(
test_results, test_results_new_attr
) # Expected to compare unequal: this variant's __init__ sets an extra instance attribute (self.y)
finally:
fto_path.write_text(original_code, "utf-8")
test_path.unlink(missing_ok=True)
test_path_perf.unlink(missing_ok=True)