mirror of
https://github.com/codeflash-ai/codeflash.git
synced 2026-05-04 18:25:17 +00:00
added verificationtype to sqlite in test results, initial work on comparator superset
This commit is contained in:
parent
79d75ea577
commit
0aa7ca4ea4
10 changed files with 60 additions and 50 deletions
15
code_to_optimize/bubble_sort_method.py
Normal file
15
code_to_optimize/bubble_sort_method.py
Normal file
|
|
@ -0,0 +1,15 @@
|
|||
|
||||
class BubbleSorter:
|
||||
def __init__(self):
|
||||
self.x = 1
|
||||
|
||||
def sorter(self, arr):
|
||||
for i in range(len(arr)):
|
||||
for j in range(len(arr) - 1):
|
||||
if arr[j] > arr[j + 1]:
|
||||
temp = arr[j]
|
||||
arr[j] = arr[j + 1]
|
||||
arr[j + 1] = temp
|
||||
return arr
|
||||
|
||||
|
||||
|
|
@ -9,7 +9,7 @@ import isort
|
|||
from codeflash.cli_cmds.console import logger
|
||||
from codeflash.code_utils.code_utils import get_run_tmp_file, module_name_from_file_path
|
||||
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
|
||||
from codeflash.models.models import FunctionParent, TestingMode
|
||||
from codeflash.models.models import FunctionParent, TestingMode, VerificationType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Iterable
|
||||
|
|
@ -251,7 +251,7 @@ class InjectPerfOnly(ast.NodeTransformer):
|
|||
ast.Constant(
|
||||
value="CREATE TABLE IF NOT EXISTS test_results (test_module_path TEXT,"
|
||||
" test_class_name TEXT, test_function_name TEXT, function_getting_tested TEXT,"
|
||||
" loop_index INTEGER, iteration_id TEXT, runtime INTEGER, return_value BLOB)"
|
||||
" loop_index INTEGER, iteration_id TEXT, runtime INTEGER, return_value BLOB, verification_type TEXT)"
|
||||
)
|
||||
],
|
||||
keywords=[],
|
||||
|
|
@ -684,7 +684,7 @@ def create_wrapper_function(mode: TestingMode = TestingMode.BEHAVIOR) -> ast.Fun
|
|||
value=ast.Name(id="codeflash_cur", ctx=ast.Load()), attr="execute", ctx=ast.Load()
|
||||
),
|
||||
args=[
|
||||
ast.Constant(value="INSERT INTO test_results VALUES (?, ?, ?, ?, ?, ?, ?, ?)"),
|
||||
ast.Constant(value="INSERT INTO test_results VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)"),
|
||||
ast.Tuple(
|
||||
elts=[
|
||||
ast.Name(id="test_module_name", ctx=ast.Load()),
|
||||
|
|
@ -695,6 +695,7 @@ def create_wrapper_function(mode: TestingMode = TestingMode.BEHAVIOR) -> ast.Fun
|
|||
ast.Name(id="invocation_id", ctx=ast.Load()),
|
||||
ast.Name(id="codeflash_duration", ctx=ast.Load()),
|
||||
ast.Name(id="pickled_return_value", ctx=ast.Load()),
|
||||
ast.Constant(value=VerificationType.FUNCTION_TO_OPTIMIZE),
|
||||
],
|
||||
ctx=ast.Load(),
|
||||
),
|
||||
|
|
|
|||
|
|
@ -84,6 +84,11 @@ class CodeOptimizationContext(BaseModel):
|
|||
preexisting_objects: list[tuple[str, list[FunctionParent]]]
|
||||
|
||||
|
||||
class VerificationType(str, Enum):
|
||||
FUNCTION_TO_OPTIMIZE = "function_to_optimize"
|
||||
INSTANCE_STATE = "instance_state"
|
||||
|
||||
|
||||
class OptimizedCandidateResult(BaseModel):
|
||||
max_loop_count: int
|
||||
best_test_runtime: int
|
||||
|
|
|
|||
|
|
@ -514,6 +514,8 @@ class Optimizer:
|
|||
optimized_code=candidate.source_code,
|
||||
qualified_function_name=function_to_optimize.qualified_name,
|
||||
)
|
||||
# If init was modified, instrument the code with codeflash capture
|
||||
|
||||
if not did_update:
|
||||
logger.warning(
|
||||
"No functions were replaced in the optimized code. Skipping optimization candidate."
|
||||
|
|
@ -529,6 +531,9 @@ class Optimizer:
|
|||
optimization_candidate_index=candidate_index, baseline_results=original_code_baseline
|
||||
)
|
||||
console.rule()
|
||||
|
||||
# Remove codeflash capture
|
||||
|
||||
if not is_successful(run_results):
|
||||
optimized_runtimes[candidate.optimization_id] = None
|
||||
is_correct[candidate.optimization_id] = False
|
||||
|
|
|
|||
|
|
@ -9,6 +9,8 @@ import time
|
|||
|
||||
import dill as pickle
|
||||
|
||||
from codeflash.models.models import VerificationType
|
||||
|
||||
|
||||
def get_test_info_from_stack() -> tuple[str, str | None, str, str]:
|
||||
"""Extract test information from the call stack."""
|
||||
|
|
@ -92,7 +94,7 @@ def codeflash_capture(function_name: str, tmp_dir_path: str):
|
|||
instance_state = args[0].__dict__ # self is always the first argument
|
||||
print(instance_state)
|
||||
codeflash_cur.execute(
|
||||
"CREATE TABLE IF NOT EXISTS test_results (test_module_path TEXT, test_class_name TEXT, test_function_name TEXT, function_getting_tested TEXT, loop_index INTEGER, iteration_id TEXT, runtime INTEGER, return_value BLOB)"
|
||||
"CREATE TABLE IF NOT EXISTS test_results (test_module_path TEXT, test_class_name TEXT, test_function_name TEXT, function_getting_tested TEXT, loop_index INTEGER, iteration_id TEXT, runtime INTEGER, return_value BLOB, verification_type TEXT)"
|
||||
)
|
||||
|
||||
# Write to sqlite
|
||||
|
|
@ -108,7 +110,7 @@ def codeflash_capture(function_name: str, tmp_dir_path: str):
|
|||
pickled_return_value,
|
||||
)
|
||||
codeflash_cur.execute(
|
||||
"INSERT INTO test_results VALUES (?, ?, ?, ?, ?, ?, ?, ?)",
|
||||
"INSERT INTO test_results VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)",
|
||||
(
|
||||
test_module_name,
|
||||
test_class_name,
|
||||
|
|
@ -118,6 +120,7 @@ def codeflash_capture(function_name: str, tmp_dir_path: str):
|
|||
invocation_id,
|
||||
codeflash_duration,
|
||||
pickled_return_value,
|
||||
VerificationType.INSTANCE_STATE,
|
||||
),
|
||||
)
|
||||
codeflash_con.commit()
|
||||
|
|
|
|||
|
|
@ -44,43 +44,7 @@ except ImportError:
|
|||
HAS_PYRSISTENT = False
|
||||
|
||||
|
||||
def type_comparator(orig: Any, new: Any) -> bool:
|
||||
"""Custom type comparator for comparing two objects of the same type."""
|
||||
if isinstance(
|
||||
orig,
|
||||
(
|
||||
str,
|
||||
int,
|
||||
float,
|
||||
bool,
|
||||
complex,
|
||||
type(None),
|
||||
decimal.Decimal,
|
||||
set,
|
||||
list,
|
||||
tuple,
|
||||
bytes,
|
||||
bytearray,
|
||||
memoryview,
|
||||
frozenset,
|
||||
enum.Enum,
|
||||
type,
|
||||
),
|
||||
):
|
||||
return type(orig) == type(new)
|
||||
# Compare attributes of the type object, not the type itself. Reimported class objects have different memory addresses.
|
||||
type_obj = type(orig)
|
||||
new_type_obj = type(new)
|
||||
if (
|
||||
type_obj.__name__ != new_type_obj.__name__
|
||||
or type_obj.__qualname__ != new_type_obj.__qualname__
|
||||
or type_obj.__bases__ != new_type_obj.__bases__
|
||||
):
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def comparator(orig: Any, new: Any) -> bool:
|
||||
def comparator(orig: Any, new: Any, superset_obj=False) -> bool:
|
||||
try:
|
||||
# if not type_comparator(orig, new):
|
||||
if type(orig) is not type(new):
|
||||
|
|
@ -239,14 +203,11 @@ def comparator(orig: Any, new: Any) -> bool:
|
|||
new_keys = dict(new_keys)
|
||||
orig_keys = {k: v for k, v in orig_keys.items() if not k.startswith("__")}
|
||||
new_keys = {k: v for k, v in new_keys.items() if not k.startswith("__")}
|
||||
|
||||
if superset_obj:
|
||||
# allow new object to be a superset of the original object
|
||||
return all(k in new_keys and comparator(v, new_keys[k]) for k, v in orig_keys.items())
|
||||
return comparator(orig_keys, new_keys)
|
||||
# # Check that all original attributes exist and match in new object
|
||||
# for key, value in orig_keys.items():
|
||||
# if key not in new_keys:
|
||||
# return False
|
||||
# if not comparator(value, new_keys[key]):
|
||||
# return False
|
||||
# return True # Allow additional attributes in new_keys
|
||||
|
||||
if type(orig) in [types.BuiltinFunctionType, types.BuiltinMethodType]:
|
||||
return new == orig
|
||||
|
|
|
|||
|
|
@ -30,6 +30,8 @@ def compare_test_results(original_results: TestResults, candidate_results: TestR
|
|||
did_all_timeout = did_all_timeout and original_test_result.timed_out
|
||||
if original_test_result.timed_out:
|
||||
continue
|
||||
# if original_test_result.verification_type and original_test_result.verification_type == VerificationType.INSTANCE_STATE:
|
||||
# Do superset comparator
|
||||
if not comparator(original_test_result.return_value, cdd_test_result.return_value):
|
||||
are_equal = False
|
||||
break
|
||||
|
|
|
|||
|
|
@ -108,7 +108,7 @@ def parse_sqlite_test_results(sqlite_file_path: Path, test_files: TestFiles, tes
|
|||
cur = db.cursor()
|
||||
data = cur.execute(
|
||||
"SELECT test_module_path, test_class_name, test_function_name, "
|
||||
"function_getting_tested, loop_index, iteration_id, runtime, return_value FROM test_results"
|
||||
"function_getting_tested, loop_index, iteration_id, runtime, return_value,verification_type FROM test_results"
|
||||
).fetchall()
|
||||
finally:
|
||||
db.close()
|
||||
|
|
@ -119,6 +119,7 @@ def parse_sqlite_test_results(sqlite_file_path: Path, test_files: TestFiles, tes
|
|||
# TODO : this is because sqlite writes original file module path. Should make it consistent
|
||||
test_type = test_files.get_test_type_by_original_file_path(test_file_path)
|
||||
loop_index = val[4]
|
||||
verification_type = val[8]
|
||||
try:
|
||||
ret_val = (pickle.loads(val[7]) if loop_index == 1 else None,)
|
||||
except Exception:
|
||||
|
|
@ -140,6 +141,7 @@ def parse_sqlite_test_results(sqlite_file_path: Path, test_files: TestFiles, tes
|
|||
test_type=test_type,
|
||||
return_value=ret_val,
|
||||
timed_out=False,
|
||||
verification_type=verification_type if verification_type else None,
|
||||
)
|
||||
)
|
||||
except Exception:
|
||||
|
|
|
|||
|
|
@ -74,6 +74,7 @@ class FunctionTestInvocation:
|
|||
test_type: TestType
|
||||
return_value: Optional[object] # The return value of the function invocation
|
||||
timed_out: Optional[bool]
|
||||
verification_type: Optional[str]
|
||||
|
||||
@property
|
||||
def unique_invocation_loop_id(self) -> str:
|
||||
|
|
|
|||
|
|
@ -739,6 +739,21 @@ class BubbleSorter:
|
|||
fto_path.write_text(original_code, "utf-8")
|
||||
|
||||
|
||||
def test_superset():
|
||||
class A:
|
||||
def __init__(self):
|
||||
self.a = 1
|
||||
|
||||
class B(A):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.b = 2
|
||||
|
||||
assert comparator(A(), B(), superset_obj=True)
|
||||
assert not comparator(B(), A(), superset_obj=True)
|
||||
assert not comparator(A(), B())
|
||||
|
||||
|
||||
def test_compare_results_fn():
|
||||
original_results = TestResults()
|
||||
original_results.add(
|
||||
|
|
|
|||
Loading…
Reference in a new issue