added verificationtype to sqlite in test results, initial work on comparator superset

This commit is contained in:
Alvin Ryanputra 2025-01-21 12:34:19 -08:00
parent 79d75ea577
commit 0aa7ca4ea4
10 changed files with 60 additions and 50 deletions

View file

@ -0,0 +1,15 @@
class BubbleSorter:
def __init__(self):
self.x = 1
def sorter(self, arr):
for i in range(len(arr)):
for j in range(len(arr) - 1):
if arr[j] > arr[j + 1]:
temp = arr[j]
arr[j] = arr[j + 1]
arr[j + 1] = temp
return arr

View file

@ -9,7 +9,7 @@ import isort
from codeflash.cli_cmds.console import logger
from codeflash.code_utils.code_utils import get_run_tmp_file, module_name_from_file_path
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.models.models import FunctionParent, TestingMode
from codeflash.models.models import FunctionParent, TestingMode, VerificationType
if TYPE_CHECKING:
from collections.abc import Iterable
@ -251,7 +251,7 @@ class InjectPerfOnly(ast.NodeTransformer):
ast.Constant(
value="CREATE TABLE IF NOT EXISTS test_results (test_module_path TEXT,"
" test_class_name TEXT, test_function_name TEXT, function_getting_tested TEXT,"
" loop_index INTEGER, iteration_id TEXT, runtime INTEGER, return_value BLOB)"
" loop_index INTEGER, iteration_id TEXT, runtime INTEGER, return_value BLOB, verification_type TEXT)"
)
],
keywords=[],
@ -684,7 +684,7 @@ def create_wrapper_function(mode: TestingMode = TestingMode.BEHAVIOR) -> ast.Fun
value=ast.Name(id="codeflash_cur", ctx=ast.Load()), attr="execute", ctx=ast.Load()
),
args=[
ast.Constant(value="INSERT INTO test_results VALUES (?, ?, ?, ?, ?, ?, ?, ?)"),
ast.Constant(value="INSERT INTO test_results VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)"),
ast.Tuple(
elts=[
ast.Name(id="test_module_name", ctx=ast.Load()),
@ -695,6 +695,7 @@ def create_wrapper_function(mode: TestingMode = TestingMode.BEHAVIOR) -> ast.Fun
ast.Name(id="invocation_id", ctx=ast.Load()),
ast.Name(id="codeflash_duration", ctx=ast.Load()),
ast.Name(id="pickled_return_value", ctx=ast.Load()),
ast.Constant(value=VerificationType.FUNCTION_TO_OPTIMIZE),
],
ctx=ast.Load(),
),

View file

@ -84,6 +84,11 @@ class CodeOptimizationContext(BaseModel):
preexisting_objects: list[tuple[str, list[FunctionParent]]]
class VerificationType(str, Enum):
FUNCTION_TO_OPTIMIZE = "function_to_optimize"
INSTANCE_STATE = "instance_state"
class OptimizedCandidateResult(BaseModel):
max_loop_count: int
best_test_runtime: int

View file

@ -514,6 +514,8 @@ class Optimizer:
optimized_code=candidate.source_code,
qualified_function_name=function_to_optimize.qualified_name,
)
# If init was modified, instrument the code with codeflash capture
if not did_update:
logger.warning(
"No functions were replaced in the optimized code. Skipping optimization candidate."
@ -529,6 +531,9 @@ class Optimizer:
optimization_candidate_index=candidate_index, baseline_results=original_code_baseline
)
console.rule()
# Remove codeflash capture
if not is_successful(run_results):
optimized_runtimes[candidate.optimization_id] = None
is_correct[candidate.optimization_id] = False

View file

@ -9,6 +9,8 @@ import time
import dill as pickle
from codeflash.models.models import VerificationType
def get_test_info_from_stack() -> tuple[str, str | None, str, str]:
"""Extract test information from the call stack."""
@ -92,7 +94,7 @@ def codeflash_capture(function_name: str, tmp_dir_path: str):
instance_state = args[0].__dict__ # self is always the first argument
print(instance_state)
codeflash_cur.execute(
"CREATE TABLE IF NOT EXISTS test_results (test_module_path TEXT, test_class_name TEXT, test_function_name TEXT, function_getting_tested TEXT, loop_index INTEGER, iteration_id TEXT, runtime INTEGER, return_value BLOB)"
"CREATE TABLE IF NOT EXISTS test_results (test_module_path TEXT, test_class_name TEXT, test_function_name TEXT, function_getting_tested TEXT, loop_index INTEGER, iteration_id TEXT, runtime INTEGER, return_value BLOB, verification_type TEXT)"
)
# Write to sqlite
@ -108,7 +110,7 @@ def codeflash_capture(function_name: str, tmp_dir_path: str):
pickled_return_value,
)
codeflash_cur.execute(
"INSERT INTO test_results VALUES (?, ?, ?, ?, ?, ?, ?, ?)",
"INSERT INTO test_results VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)",
(
test_module_name,
test_class_name,
@ -118,6 +120,7 @@ def codeflash_capture(function_name: str, tmp_dir_path: str):
invocation_id,
codeflash_duration,
pickled_return_value,
VerificationType.INSTANCE_STATE,
),
)
codeflash_con.commit()

View file

@ -44,43 +44,7 @@ except ImportError:
HAS_PYRSISTENT = False
def type_comparator(orig: Any, new: Any) -> bool:
"""Custom type comparator for comparing two objects of the same type."""
if isinstance(
orig,
(
str,
int,
float,
bool,
complex,
type(None),
decimal.Decimal,
set,
list,
tuple,
bytes,
bytearray,
memoryview,
frozenset,
enum.Enum,
type,
),
):
return type(orig) == type(new)
# Compare attributes of the type object, not the type itself. Reimported class objects have different memory addresses.
type_obj = type(orig)
new_type_obj = type(new)
if (
type_obj.__name__ != new_type_obj.__name__
or type_obj.__qualname__ != new_type_obj.__qualname__
or type_obj.__bases__ != new_type_obj.__bases__
):
return False
return True
def comparator(orig: Any, new: Any) -> bool:
def comparator(orig: Any, new: Any, superset_obj=False) -> bool:
try:
# if not type_comparator(orig, new):
if type(orig) is not type(new):
@ -239,14 +203,11 @@ def comparator(orig: Any, new: Any) -> bool:
new_keys = dict(new_keys)
orig_keys = {k: v for k, v in orig_keys.items() if not k.startswith("__")}
new_keys = {k: v for k, v in new_keys.items() if not k.startswith("__")}
if superset_obj:
# allow new object to be a superset of the original object
return all(k in new_keys and comparator(v, new_keys[k]) for k, v in orig_keys.items())
return comparator(orig_keys, new_keys)
# # Check that all original attributes exist and match in new object
# for key, value in orig_keys.items():
# if key not in new_keys:
# return False
# if not comparator(value, new_keys[key]):
# return False
# return True # Allow additional attributes in new_keys
if type(orig) in [types.BuiltinFunctionType, types.BuiltinMethodType]:
return new == orig

View file

@ -30,6 +30,8 @@ def compare_test_results(original_results: TestResults, candidate_results: TestR
did_all_timeout = did_all_timeout and original_test_result.timed_out
if original_test_result.timed_out:
continue
# if original_test_result.verification_type and original_test_result.verification_type == VerificationType.INSTANCE_STATE:
# Do superset comparator
if not comparator(original_test_result.return_value, cdd_test_result.return_value):
are_equal = False
break

View file

@ -108,7 +108,7 @@ def parse_sqlite_test_results(sqlite_file_path: Path, test_files: TestFiles, tes
cur = db.cursor()
data = cur.execute(
"SELECT test_module_path, test_class_name, test_function_name, "
"function_getting_tested, loop_index, iteration_id, runtime, return_value FROM test_results"
"function_getting_tested, loop_index, iteration_id, runtime, return_value,verification_type FROM test_results"
).fetchall()
finally:
db.close()
@ -119,6 +119,7 @@ def parse_sqlite_test_results(sqlite_file_path: Path, test_files: TestFiles, tes
# TODO : this is because sqlite writes original file module path. Should make it consistent
test_type = test_files.get_test_type_by_original_file_path(test_file_path)
loop_index = val[4]
verification_type = val[8]
try:
ret_val = (pickle.loads(val[7]) if loop_index == 1 else None,)
except Exception:
@ -140,6 +141,7 @@ def parse_sqlite_test_results(sqlite_file_path: Path, test_files: TestFiles, tes
test_type=test_type,
return_value=ret_val,
timed_out=False,
verification_type=verification_type if verification_type else None,
)
)
except Exception:

View file

@ -74,6 +74,7 @@ class FunctionTestInvocation:
test_type: TestType
return_value: Optional[object] # The return value of the function invocation
timed_out: Optional[bool]
verification_type: Optional[str]
@property
def unique_invocation_loop_id(self) -> str:

View file

@ -739,6 +739,21 @@ class BubbleSorter:
fto_path.write_text(original_code, "utf-8")
def test_superset():
class A:
def __init__(self):
self.a = 1
class B(A):
def __init__(self):
super().__init__()
self.b = 2
assert comparator(A(), B(), superset_obj=True)
assert not comparator(B(), A(), superset_obj=True)
assert not comparator(A(), B())
def test_compare_results_fn():
original_results = TestResults()
original_results.add(