From 0aa7ca4ea474b7fafe59ccb2bac2bbf274147aa8 Mon Sep 17 00:00:00 2001 From: Alvin Ryanputra Date: Tue, 21 Jan 2025 12:34:19 -0800 Subject: [PATCH] added verificationtype to sqlite in test results, initial work on comparator superset --- code_to_optimize/bubble_sort_method.py | 15 ++++++ .../code_utils/instrument_existing_tests.py | 7 +-- codeflash/models/models.py | 5 ++ codeflash/optimization/optimizer.py | 5 ++ codeflash/verification/codeflash_capture.py | 7 ++- codeflash/verification/comparator.py | 49 ++----------------- codeflash/verification/equivalence.py | 2 + codeflash/verification/parse_test_output.py | 4 +- codeflash/verification/test_results.py | 1 + tests/test_comparator.py | 15 ++++++ 10 files changed, 60 insertions(+), 50 deletions(-) create mode 100644 code_to_optimize/bubble_sort_method.py diff --git a/code_to_optimize/bubble_sort_method.py b/code_to_optimize/bubble_sort_method.py new file mode 100644 index 000000000..eec327a6b --- /dev/null +++ b/code_to_optimize/bubble_sort_method.py @@ -0,0 +1,15 @@ + +class BubbleSorter: + def __init__(self): + self.x = 1 + + def sorter(self, arr): + for i in range(len(arr)): + for j in range(len(arr) - 1): + if arr[j] > arr[j + 1]: + temp = arr[j] + arr[j] = arr[j + 1] + arr[j + 1] = temp + return arr + + \ No newline at end of file diff --git a/codeflash/code_utils/instrument_existing_tests.py b/codeflash/code_utils/instrument_existing_tests.py index af3f5b40a..eac553e78 100644 --- a/codeflash/code_utils/instrument_existing_tests.py +++ b/codeflash/code_utils/instrument_existing_tests.py @@ -9,7 +9,7 @@ import isort from codeflash.cli_cmds.console import logger from codeflash.code_utils.code_utils import get_run_tmp_file, module_name_from_file_path from codeflash.discovery.functions_to_optimize import FunctionToOptimize -from codeflash.models.models import FunctionParent, TestingMode +from codeflash.models.models import FunctionParent, TestingMode, VerificationType if TYPE_CHECKING: from collections.abc import Iterable @@ -251,7 +251,7 @@ class InjectPerfOnly(ast.NodeTransformer): ast.Constant( value="CREATE TABLE IF NOT EXISTS test_results (test_module_path TEXT," " test_class_name TEXT, test_function_name TEXT, function_getting_tested TEXT," - " loop_index INTEGER, iteration_id TEXT, runtime INTEGER, return_value BLOB)" + " loop_index INTEGER, iteration_id TEXT, runtime INTEGER, return_value BLOB, verification_type TEXT)" ) ], keywords=[], @@ -684,7 +684,7 @@ def create_wrapper_function(mode: TestingMode = TestingMode.BEHAVIOR) -> ast.Fun value=ast.Name(id="codeflash_cur", ctx=ast.Load()), attr="execute", ctx=ast.Load() ), args=[ - ast.Constant(value="INSERT INTO test_results VALUES (?, ?, ?, ?, ?, ?, ?, ?)"), + ast.Constant(value="INSERT INTO test_results VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)"), ast.Tuple( elts=[ ast.Name(id="test_module_name", ctx=ast.Load()), @@ -695,6 +695,7 @@ def create_wrapper_function(mode: TestingMode = TestingMode.BEHAVIOR) -> ast.Fun ast.Name(id="invocation_id", ctx=ast.Load()), ast.Name(id="codeflash_duration", ctx=ast.Load()), ast.Name(id="pickled_return_value", ctx=ast.Load()), + ast.Constant(value=VerificationType.FUNCTION_TO_OPTIMIZE), ], ctx=ast.Load(), ), diff --git a/codeflash/models/models.py b/codeflash/models/models.py index df7a03400..75eaca62d 100644 --- a/codeflash/models/models.py +++ b/codeflash/models/models.py @@ -84,6 +84,11 @@ class CodeOptimizationContext(BaseModel): preexisting_objects: list[tuple[str, list[FunctionParent]]] +class VerificationType(str, Enum): + FUNCTION_TO_OPTIMIZE = "function_to_optimize" + INSTANCE_STATE = "instance_state" + + class OptimizedCandidateResult(BaseModel): max_loop_count: int best_test_runtime: int diff --git a/codeflash/optimization/optimizer.py b/codeflash/optimization/optimizer.py index 263579fee..cbea7539d 100644 --- a/codeflash/optimization/optimizer.py +++ b/codeflash/optimization/optimizer.py @@ -514,6 +514,8 @@ class Optimizer: optimized_code=candidate.source_code, qualified_function_name=function_to_optimize.qualified_name, ) + # If init was modified, instrument the code with codeflash capture + if not did_update: logger.warning( "No functions were replaced in the optimized code. Skipping optimization candidate." @@ -529,6 +531,9 @@ class Optimizer: optimization_candidate_index=candidate_index, baseline_results=original_code_baseline ) console.rule() + + # Remove codeflash capture + if not is_successful(run_results): optimized_runtimes[candidate.optimization_id] = None is_correct[candidate.optimization_id] = False diff --git a/codeflash/verification/codeflash_capture.py b/codeflash/verification/codeflash_capture.py index ee21c27c3..b0fbd6fa1 100644 --- a/codeflash/verification/codeflash_capture.py +++ b/codeflash/verification/codeflash_capture.py @@ -9,6 +9,8 @@ import time import dill as pickle +from codeflash.models.models import VerificationType + def get_test_info_from_stack() -> tuple[str, str | None, str, str]: """Extract test information from the call stack.""" @@ -92,7 +94,7 @@ def codeflash_capture(function_name: str, tmp_dir_path: str): instance_state = args[0].__dict__ # self is always the first argument print(instance_state) codeflash_cur.execute( - "CREATE TABLE IF NOT EXISTS test_results (test_module_path TEXT, test_class_name TEXT, test_function_name TEXT, function_getting_tested TEXT, loop_index INTEGER, iteration_id TEXT, runtime INTEGER, return_value BLOB)" + "CREATE TABLE IF NOT EXISTS test_results (test_module_path TEXT, test_class_name TEXT, test_function_name TEXT, function_getting_tested TEXT, loop_index INTEGER, iteration_id TEXT, runtime INTEGER, return_value BLOB, verification_type TEXT)" ) # Write to sqlite @@ -108,7 +110,7 @@ def codeflash_capture(function_name: str, tmp_dir_path: str): pickled_return_value, ) codeflash_cur.execute( - "INSERT INTO test_results VALUES (?, ?, ?, ?, ?, ?, ?, ?)", + "INSERT INTO test_results VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)", ( test_module_name, test_class_name, @@ -118,6 +120,7 @@ def codeflash_capture(function_name: str, tmp_dir_path: str): invocation_id, codeflash_duration, pickled_return_value, + VerificationType.INSTANCE_STATE, ), ) codeflash_con.commit() diff --git a/codeflash/verification/comparator.py b/codeflash/verification/comparator.py index 5aa52a7ee..d7c40282b 100644 --- a/codeflash/verification/comparator.py +++ b/codeflash/verification/comparator.py @@ -44,43 +44,7 @@ except ImportError: HAS_PYRSISTENT = False -def type_comparator(orig: Any, new: Any) -> bool: - """Custom type comparator for comparing two objects of the same type.""" - if isinstance( - orig, - ( - str, - int, - float, - bool, - complex, - type(None), - decimal.Decimal, - set, - list, - tuple, - bytes, - bytearray, - memoryview, - frozenset, - enum.Enum, - type, - ), - ): - return type(orig) == type(new) - # Compare attributes of the type object, not the type itself. Reimported class objects have different memory addresses. - type_obj = type(orig) - new_type_obj = type(new) - if ( - type_obj.__name__ != new_type_obj.__name__ - or type_obj.__qualname__ != new_type_obj.__qualname__ - or type_obj.__bases__ != new_type_obj.__bases__ - ): - return False - return True - - -def comparator(orig: Any, new: Any) -> bool: +def comparator(orig: Any, new: Any, superset_obj=False) -> bool: try: # if not type_comparator(orig, new): if type(orig) is not type(new): @@ -239,14 +203,11 @@ def comparator(orig: Any, new: Any) -> bool: new_keys = dict(new_keys) orig_keys = {k: v for k, v in orig_keys.items() if not k.startswith("__")} new_keys = {k: v for k, v in new_keys.items() if not k.startswith("__")} + + if superset_obj: + # allow new object to be a superset of the original object + return all(k in new_keys and comparator(v, new_keys[k]) for k, v in orig_keys.items()) return comparator(orig_keys, new_keys) - # # Check that all original attributes exist and match in new object - # for key, value in orig_keys.items(): - # if key not in new_keys: - # return False - # if not comparator(value, new_keys[key]): - # return False - # return True # Allow additional attributes in new_keys if type(orig) in [types.BuiltinFunctionType, types.BuiltinMethodType]: return new == orig diff --git a/codeflash/verification/equivalence.py b/codeflash/verification/equivalence.py index 1c4f54e7a..8ba4aeed8 100644 --- a/codeflash/verification/equivalence.py +++ b/codeflash/verification/equivalence.py @@ -30,6 +30,8 @@ def compare_test_results(original_results: TestResults, candidate_results: TestR did_all_timeout = did_all_timeout and original_test_result.timed_out if original_test_result.timed_out: continue + # if original_test_result.verification_type and original_test_result.verification_type == VerificationType.INSTANCE_STATE: + # Do superset comparator if not comparator(original_test_result.return_value, cdd_test_result.return_value): are_equal = False break diff --git a/codeflash/verification/parse_test_output.py b/codeflash/verification/parse_test_output.py index 72346e284..3ee0ca323 100644 --- a/codeflash/verification/parse_test_output.py +++ b/codeflash/verification/parse_test_output.py @@ -108,7 +108,7 @@ def parse_sqlite_test_results(sqlite_file_path: Path, test_files: TestFiles, tes cur = db.cursor() data = cur.execute( "SELECT test_module_path, test_class_name, test_function_name, " - "function_getting_tested, loop_index, iteration_id, runtime, return_value FROM test_results" + "function_getting_tested, loop_index, iteration_id, runtime, return_value,verification_type FROM test_results" ).fetchall() finally: db.close() @@ -119,6 +119,7 @@ def parse_sqlite_test_results(sqlite_file_path: Path, test_files: TestFiles, tes # TODO : this is because sqlite writes original file module path. Should make it consistent test_type = test_files.get_test_type_by_original_file_path(test_file_path) loop_index = val[4] + verification_type = val[8] try: ret_val = (pickle.loads(val[7]) if loop_index == 1 else None,) except Exception: @@ -140,6 +141,7 @@ def parse_sqlite_test_results(sqlite_file_path: Path, test_files: TestFiles, tes test_type=test_type, return_value=ret_val, timed_out=False, + verification_type=verification_type if verification_type else None, ) ) except Exception: diff --git a/codeflash/verification/test_results.py b/codeflash/verification/test_results.py index 3dfbca120..cf6b5264c 100644 --- a/codeflash/verification/test_results.py +++ b/codeflash/verification/test_results.py @@ -74,6 +74,7 @@ class FunctionTestInvocation: test_type: TestType return_value: Optional[object] # The return value of the function invocation timed_out: Optional[bool] + verification_type: Optional[str] @property def unique_invocation_loop_id(self) -> str: diff --git a/tests/test_comparator.py b/tests/test_comparator.py index 90a12aa72..7dc2e62d4 100644 --- a/tests/test_comparator.py +++ b/tests/test_comparator.py @@ -739,6 +739,21 @@ class BubbleSorter: fto_path.write_text(original_code, "utf-8") +def test_superset(): + class A: + def __init__(self): + self.a = 1 + + class B(A): + def __init__(self): + super().__init__() + self.b = 2 + + assert comparator(A(), B(), superset_obj=True) + assert not comparator(B(), A(), superset_obj=True) + assert not comparator(A(), B()) + + def test_compare_results_fn(): original_results = TestResults() original_results.add(