Refactor verification_utils to add a function for parsing test return values from a binary file.

This commit is contained in:
afik.cohen 2023-10-22 18:27:58 -07:00
parent 78c6e494cc
commit 783f6b66a6
3 changed files with 163 additions and 6 deletions

View file

@ -11,8 +11,12 @@ from codeflash.discovery.functions_to_optimize import get_functions_to_optimize
from codeflash.optimization.function_context import get_function_context_len_constrained
from codeflash.optimization.optimizer import optimize_python_code
from codeflash.verification.test_runner import run_tests
from codeflash.verification.verification_utils import get_test_file_path
from codeflash.verification.verification_utils import (
get_test_file_path,
parse_test_return_values_bin,
)
from codeflash.verification.verifier import generate_tests
from codeflash.verification.equivalence import compare_results
@dataclass
@ -75,12 +79,16 @@ def main() -> None:
function_name=function_name,
module_path=module_path,
function_dependencies=function_dependencies,
test_framework=args.test_framework,
test_timeout=INDIVIDUAL_TEST_TIMEOUT,
)
print(new_tests)
generated_tests_path = get_test_file_path(args.tests_root, function_name, 0)
test_files_created.add(generated_tests_path)
with open(generated_tests_path, "w") as file:
file.write(new_tests)
all_original_test_time = []
times_run = 0
generated_tests_elapsed_time = 0.0
for i in range(MAX_TEST_RUN_ITERATIONS):
@ -103,10 +111,35 @@ def main() -> None:
pytest_timeout=INDIVIDUAL_TEST_TIMEOUT,
)
generated_tests_elapsed_time += time.time() - start_time
print(test_stdout, test_stderr)
# TODO: Implement the logic to disregard the timing info of the tests that ERRORed out. That is remove test cases that failed to run.
try:
original_results = parse_test_return_values_bin(
"/tmp/test_return_values_0.bin"
)
except AttributeError as e:
print(e)
original_results = None
if os.path.exists("/tmp/test_return_values_0.bin"):
os.remove("/tmp/test_return_values_0.bin")
all_original_test_time.append(
sum(
[val["runtime"][0] for val in list(original_results.values())]
if original_results is not None
else []
)
)
times_run += 1
if times_run == 0:
continue
original_runtime = best_runtime = min(all_original_test_time)
print(
f"ORIGINAL CODE RUNTIME OVER {times_run} RUN{'S' if times_run > 1 else ''} = {original_runtime}"
)
print("OPTIMIZING CODE....")
optimizations = optimize_python_code(code_to_optimize, n=N_CANDIDATES)
for i, (optimized_code, explanation) in enumerate(optimizations):
j = i + 1
if optimized_code is None:
continue
print("optimized_candidate", optimized_code)
@ -119,6 +152,37 @@ def main() -> None:
break
with open(path, "w") as f:
f.write(new_code)
test_env = os.environ.copy()
test_env["CODEFLASH_TEST_ITERATION"] = str(j)
for test_index in range(MAX_TEST_RUN_ITERATIONS):
if os.path.exists(f"/tmp/test_return_values_{j}.bin"):
os.remove(f"/tmp/test_return_values_{j}.bin")
start_time = time.time()
test_stdout, test_stderr = run_tests(
test_path=generated_tests_path,
test_framework=args.test_framework,
cwd=args.root,
test_env=test_env,
pytest_timeout=INDIVIDUAL_TEST_TIMEOUT,
)
generated_tests_elapsed_time += time.time() - start_time
try:
test_results = parse_test_return_values_bin(
f"/tmp/test_return_values_{j}.bin"
)
except AttributeError as e:
print(e)
test_results = None
if test_index == 0 and test_results is not None:
if compare_results(original_results, test_results):
equal_results = True
print("RESULTS MATCHED!")
else:
print("RESULTS DID NOT MATCH")
equal_results = False
if not equal_results:
break
with open(path, "w") as f:
f.write(original_code)
print("----------------")

View file

@ -0,0 +1,72 @@
try:
import sqlalchemy
HAS_SQLALCHEMY = True
except ImportError:
HAS_SQLALCHEMY = False
from typing import Dict
import copy
def compare_results(original_result: Dict[str, Dict], test_result: Dict[str, Dict]):
    """Check that an optimized run produced the same return values as the original run.

    Both arguments map test name -> {"runtime": [...], "results": [...]} as produced
    by parse_test_return_values_bin. The original entries are merged into a deep copy
    of test_result; every test must then hold exactly two results (one from each run)
    and those two results must compare equal via comparator().

    Returns True only when both runs covered the same tests and every pair of
    results matched; test_result itself is never mutated.
    """
    _test_result = copy.deepcopy(test_result)
    for key in original_result.keys():
        if key not in _test_result:
            # Test ran originally but not in the optimized run; the merged entry
            # will hold a single result and fail the length check below.
            _test_result[key] = {"runtime": [], "results": []}
        _test_result[key]["runtime"].extend(original_result[key]["runtime"])
        _test_result[key]["results"].extend(original_result[key]["results"])
    for key, value in _test_result.items():
        # Exactly one result from each run is required; anything else means a
        # test was missing from one side.
        if len(value["results"]) != 2:
            return False
        if not comparator(value["results"][0], value["results"][1]):
            return False
    return True
def comparator(orig, new):
    """Recursively compare two test return values for semantic equality.

    Handles SQLAlchemy model instances (by comparing their public attributes),
    plus str, int/bool, float, list, tuple, and dict. Values of any other type
    compare equal by default (best-effort semantics kept from the original).

    Returns True when the two values are considered equivalent.
    """
    try:
        # Import locally so this helper works whether or not sqlalchemy is
        # installed; the module cache makes repeated imports cheap.
        import sqlalchemy

        try:
            # Both values must be inspectable ORM objects; the calls are probes
            # and raise NoInspectionAvailable for plain Python values.
            sqlalchemy.inspection.inspect(orig)
            sqlalchemy.inspection.inspect(new)
            orig_attrs = orig.__dict__
            new_attrs = new.__dict__
            for attr in list(orig_attrs.keys()):
                # Skip SQLAlchemy internals such as _sa_instance_state.
                if attr.startswith("_"):
                    continue
                if attr not in new_attrs or not comparator(orig_attrs[attr], new_attrs[attr]):
                    return False
            return True
        except sqlalchemy.exc.NoInspectionAvailable:
            pass
    except ImportError:
        pass
    if type(orig) != type(new):
        return False
    if isinstance(orig, (list, tuple)):
        if len(orig) != len(new):
            return False
        return all(comparator(elem1, elem2) for elem1, elem2 in zip(orig, new))
    # bool is a subclass of int, so this covers both; the type check above
    # already rejects bool-vs-int mixes.
    if isinstance(orig, (str, int, float)):
        return orig == new
    if isinstance(orig, dict):
        if len(orig) != len(new):
            return False
        for key in orig:
            if key not in new:
                return False
            if not comparator(orig[key], new[key]):
                return False
        return True
    # TODO: Add other types here; unknown types are assumed equal for now.
    return True

View file

@ -1,12 +1,33 @@
import os
import pickle
def get_test_file_path(test_dir, function_name, iteration=0, test_type="unit"):
    """Return a path for a generated test file that does not exist yet.

    The name has the form test_<function_name>__<test_type>_test_<iteration>.py,
    with dots in function_name replaced by underscores. If the candidate path
    already exists, the iteration counter is incremented until a free path is
    found (the original recursive retry dropped test_type, silently falling
    back to "unit" naming on collisions).

    Raises ValueError for an unknown test_type (assert was unsafe: it is
    stripped under python -O).
    """
    if test_type not in ("unit", "inspired", "replay"):
        raise ValueError(f"Unknown test_type: {test_type!r}")
    function_name = function_name.replace(".", "_")
    while True:
        path = os.path.join(test_dir, f"test_{function_name}__{test_type}_test_{iteration}.py")
        if not os.path.exists(path):
            return path
        iteration += 1
def parse_test_return_values_bin(file_location):
    """Parse the binary file of test return values written by an instrumented run.

    Record format, repeated until EOF (all integers big-endian):
        4 bytes  length of the test name
        N bytes  test name (ascii)
        8 bytes  runtime measurement
        4 bytes  length of the pickled return value
        M bytes  pickle payload of the return value

    Returns a dict mapping test name -> {"runtime": [runtime], "results": [value]},
    or None when file_location does not exist. A truncated trailing record is
    dropped (the original looped on `while file:` — always true for a file
    object — and only terminated via mid-loop returns).
    """
    if not os.path.exists(file_location):
        return None
    test_results = {}
    with open(file_location, "rb") as file:
        while True:
            raw = file.read(4)
            if not raw:
                break  # clean EOF: no more records
            name_len = int.from_bytes(raw, byteorder="big")
            test_name = file.read(name_len).decode("ascii")
            duration = int.from_bytes(file.read(8), byteorder="big")
            raw = file.read(4)
            if not raw:
                break  # truncated record: drop the partial entry
            payload_len = int.from_bytes(raw, byteorder="big")
            # NOTE(review): pickle.loads is only safe here because the file is
            # produced by our own instrumentation — never parse untrusted files.
            test_value = pickle.loads(file.read(payload_len))
            test_results[test_name] = {"runtime": [duration], "results": [test_value]}
    return test_results