Refactor verification_utils to add a function for parsing test return values from a binary file.
This commit is contained in:
parent 78c6e494cc
commit 783f6b66a6
3 changed files with 163 additions and 6 deletions
@@ -11,8 +11,12 @@ from codeflash.discovery.functions_to_optimize import get_functions_to_optimize
 from codeflash.optimization.function_context import get_function_context_len_constrained
 from codeflash.optimization.optimizer import optimize_python_code
 from codeflash.verification.test_runner import run_tests
-from codeflash.verification.verification_utils import get_test_file_path
+from codeflash.verification.verification_utils import (
+    get_test_file_path,
+    parse_test_return_values_bin,
+)
 from codeflash.verification.verifier import generate_tests
+from codeflash.verification.equivalence import compare_results


 @dataclass
@@ -75,12 +79,16 @@ def main() -> None:
            function_name=function_name,
            module_path=module_path,
            function_dependencies=function_dependencies,
            test_framework=args.test_framework,
            test_timeout=INDIVIDUAL_TEST_TIMEOUT,
        )
        print(new_tests)
        generated_tests_path = get_test_file_path(args.tests_root, function_name, 0)
        test_files_created.add(generated_tests_path)
        with open(generated_tests_path, "w") as file:
            file.write(new_tests)
        all_original_test_time = []
        times_run = 0
        generated_tests_elapsed_time = 0.0

        for i in range(MAX_TEST_RUN_ITERATIONS):
@@ -103,10 +111,35 @@ def main() -> None:
                pytest_timeout=INDIVIDUAL_TEST_TIMEOUT,
            )
            generated_tests_elapsed_time += time.time() - start_time
            print(test_stdout, test_stderr)
            # TODO: Implement the logic to disregard the timing info of the tests that ERRORed out. That is, remove test cases that failed to run.
            try:
                original_results = parse_test_return_values_bin(
                    "/tmp/test_return_values_0.bin"
                )
            except AttributeError as e:
                print(e)
                original_results = None
            if os.path.exists("/tmp/test_return_values_0.bin"):
                os.remove("/tmp/test_return_values_0.bin")

            all_original_test_time.append(
                sum(
                    [val["runtime"][0] for val in list(original_results.values())]
                    if original_results is not None
                    else []
                )
            )
            times_run += 1
        if times_run == 0:
            continue
        original_runtime = best_runtime = min(all_original_test_time)
        print(
            f"ORIGINAL CODE RUNTIME OVER {times_run} RUN{'S' if times_run > 1 else ''} = {original_runtime}"
        )
        print("OPTIMIZING CODE....")
        optimizations = optimize_python_code(code_to_optimize, n=N_CANDIDATES)
        for i, (optimized_code, explanation) in enumerate(optimizations):
            j = i + 1
            if optimized_code is None:
                continue
            print("optimized_candidate", optimized_code)
@@ -119,6 +152,37 @@ def main() -> None:
                break
            with open(path, "w") as f:
                f.write(new_code)
            test_env = os.environ.copy()
            test_env["CODEFLASH_TEST_ITERATION"] = str(j)
            for test_index in range(MAX_TEST_RUN_ITERATIONS):
                if os.path.exists(f"/tmp/test_return_values_{j}.bin"):
                    os.remove(f"/tmp/test_return_values_{j}.bin")
                start_time = time.time()
                test_stdout, test_stderr = run_tests(
                    test_path=generated_tests_path,
                    test_framework=args.test_framework,
                    cwd=args.root,
                    test_env=test_env,
                    pytest_timeout=INDIVIDUAL_TEST_TIMEOUT,
                )
                generated_tests_elapsed_time += time.time() - start_time
                try:
                    test_results = parse_test_return_values_bin(
                        f"/tmp/test_return_values_{j}.bin"
                    )
                except AttributeError as e:
                    print(e)
                    test_results = None
                if test_index == 0 and test_results is not None:
                    if compare_results(original_results, test_results):
                        equal_results = True
                        print("RESULTS MATCHED!")
                    else:
                        print("RESULTS DID NOT MATCH")
                        equal_results = False
                if not equal_results:
                    break

            with open(path, "w") as f:
                f.write(original_code)
            print("----------------")
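Note on the /tmp/test_return_values_{j}.bin files used above: the driver sets CODEFLASH_TEST_ITERATION before each candidate run, and the instrumented tests are expected to write their pickled return values to the matching file. A minimal sketch of how such a path could be derived on the test side (the helper name is hypothetical; the actual instrumentation is not part of this diff):

    import os

    def return_values_bin_path() -> str:
        # Hypothetical helper: mirrors the naming convention the driver uses when it
        # reads /tmp/test_return_values_{CODEFLASH_TEST_ITERATION}.bin after each run.
        iteration = os.environ.get("CODEFLASH_TEST_ITERATION", "0")
        return f"/tmp/test_return_values_{iteration}.bin"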
codeflash/verification/equivalence.py (new file, 72 additions)
@@ -0,0 +1,72 @@
try:
    import sqlalchemy

    HAS_SQLALCHEMY = True
except ImportError:
    HAS_SQLALCHEMY = False

from typing import Dict
import copy


def compare_results(original_result: Dict[str, Dict], test_result: Dict[str, Dict]):
    _test_result = copy.deepcopy(test_result)
    for key in original_result.keys():
        if key not in _test_result:
            _test_result[key] = {"runtime": [], "results": []}
        _test_result[key]["runtime"].extend(original_result[key]["runtime"])
        _test_result[key]["results"].extend(original_result[key]["results"])
        pass
    for key, value in _test_result.items():
        if len(value["results"]) != 2:
            return False
        if not comparator(value["results"][0], value["results"][1]):
            return False
    return True


def comparator(orig, new):
    if HAS_SQLALCHEMY:
        try:
            insp = sqlalchemy.inspection.inspect(orig)
            insp = sqlalchemy.inspection.inspect(new)
            orig_keys = orig.__dict__
            new_keys = new.__dict__
            for key in list(orig_keys.keys()):
                if key.startswith("_"):
                    continue
                if key not in new_keys or not comparator(orig_keys[key], new_keys[key]):
                    return False
            return True

        except sqlalchemy.exc.NoInspectionAvailable:
            pass
    if type(orig) != type(new):
        return False
    if isinstance(orig, list) and isinstance(new, list):
        if len(orig) != len(new):
            return False
        for elem1, elem2 in zip(orig, new):
            if not comparator(elem1, elem2):
                return False

    if isinstance(orig, str):
        return orig == new
    if isinstance(orig, int):
        return orig == new
    if isinstance(orig, bool):
        return orig == new
    if isinstance(orig, float):
        return orig == new
    if isinstance(orig, dict):
        if len(orig) != len(new):
            return False
        for key in orig:
            if key not in new:
                return False
            if not comparator(orig[key], new[key]):
                return False
        return True

    # TODO : Add other types here
    return True
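For illustration, compare_results can be exercised with result dicts shaped like the output of parse_test_return_values_bin; the test names and values below are made up:

    from codeflash.verification.equivalence import compare_results

    # {test_name: {"runtime": [duration], "results": [return_value]}}
    original = {"test_sort": {"runtime": [1200], "results": [[1, 2, 3]]}}
    candidate = {"test_sort": {"runtime": [900], "results": [[1, 2, 3]]}}
    print(compare_results(original, candidate))  # True: both runs returned [1, 2, 3]

    mismatch = {"test_sort": {"runtime": [900], "results": [[3, 2, 1]]}}
    print(compare_results(original, mismatch))   # False: return values differ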
codeflash/verification/verification_utils.py
@@ -1,12 +1,33 @@
 import os
+import pickle


 def get_test_file_path(test_dir, function_name, iteration=0, test_type="unit"):
     assert test_type in ["unit", "inspired", "replay"]
     function_name = function_name.replace(".", "_")
-    path = os.path.join(
-        test_dir, f"test_{function_name}__{test_type}_test_{iteration}.py"
-    )
+    path = os.path.join(test_dir, f"test_{function_name}__{test_type}_test_{iteration}.py")
     if os.path.exists(path):
         return get_test_file_path(test_dir, function_name, iteration + 1)
-    return path
+    return path
+
+
+def parse_test_return_values_bin(file_location):
+    test_results = {}
+    if not os.path.exists(file_location):
+        return None
+    with open(file_location, "rb") as file:
+        while file:
+            len_next = file.read(4)
+            if not len_next:
+                return test_results
+            len_next = int.from_bytes(len_next, byteorder="big")
+            test_name = file.read(len_next).decode("ascii")
+            len_next = file.read(8)
+            duration = int.from_bytes(len_next, byteorder="big")
+            len_next = file.read(4)
+            if not len_next:
+                return test_results
+            len_next = int.from_bytes(len_next, byteorder="big")
+            test_pickle = pickle.loads(file.read(len_next))
+            test_results[test_name] = {"runtime": [duration], "results": [test_pickle]}
+    return test_results
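parse_test_return_values_bin assumes a simple length-prefixed record layout: a 4-byte big-endian name length, the ASCII test name, an 8-byte big-endian duration, a 4-byte big-endian payload length, and a pickled return value. A minimal sketch of a writer that would produce such records (the function name is hypothetical; the instrumentation that actually writes these files is not part of this commit):

    import pickle

    def write_test_return_value(file_obj, test_name: str, duration: int, return_value) -> None:
        # Hypothetical writer mirroring the format read by parse_test_return_values_bin:
        # [4-byte BE name length][ASCII name][8-byte BE duration][4-byte BE payload length][pickle payload]
        name_bytes = test_name.encode("ascii")
        payload = pickle.dumps(return_value)
        file_obj.write(len(name_bytes).to_bytes(4, byteorder="big"))
        file_obj.write(name_bytes)
        file_obj.write(duration.to_bytes(8, byteorder="big"))
        file_obj.write(len(payload).to_bytes(4, byteorder="big"))
        file_obj.write(payload)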