Inject perf/instrumentation for generated tests

parent 6158c38c02
commit c50c9c5356

4 changed files with 315 additions and 29 deletions

codeflash/instrumentation/instrument_existing_tests.py (new file, +200)
@@ -0,0 +1,200 @@
+import ast
+import os.path
+
+
+class ReplaceCallNodeWithName(ast.NodeTransformer):
+    def __init__(self, only_function_name, new_variable_name="return_value"):
+        self.only_function_name = only_function_name
+        self.new_variable_name = new_variable_name
+
+    def visit_Call(self, node: ast.Call):
+        if isinstance(node, ast.Call) and (
+            (hasattr(node.func, "id") and node.func.id == self.only_function_name)
+            or (hasattr(node.func, "attr") and node.func.attr == self.only_function_name)
+        ):
+            return ast.Name(id=self.new_variable_name, ctx=ast.Load())
+        self.generic_visit(node)
+        return node
+
+
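A minimal sketch of what this transformer does, using a hypothetical test line that calls foo (ast.unparse needs Python 3.9+, which the rest of this commit already relies on):

    import ast

    tree = ast.parse("assert foo(1) == 2")
    tree = ReplaceCallNodeWithName("foo").visit(tree)
    print(ast.unparse(tree))  # assert return_value == 2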
+class InjectPerfOnly(ast.NodeTransformer):
+    def __init__(self, function_name):
+        self.only_function_name = function_name
+
+    def update_line_node(self, test_node, node_name, index: str):
+        call_node = None
+        for node in ast.walk(test_node):
+            if isinstance(node, ast.Call) and (
+                (hasattr(node.func, "id") and node.func.id == self.only_function_name)
+                or (hasattr(node.func, "attr") and node.func.attr == self.only_function_name)
+            ):
+                call_node = node
+        if call_node is None:
+            return [test_node]
+        updated_nodes = [
+            ast.Expr(
+                value=ast.Call(
+                    func=ast.Attribute(
+                        value=ast.Name(id="gc", ctx=ast.Load()),
+                        attr="disable",
+                        ctx=ast.Load(),
+                    ),
+                    args=[],
+                    keywords=[],
+                ),
+                lineno=test_node.lineno,
+                col_offset=test_node.col_offset,
+            ),
+            ast.Assign(
+                targets=[ast.Name(id="counter", ctx=ast.Store())],
+                value=ast.Call(
+                    func=ast.Attribute(
+                        value=ast.Name(id="time", ctx=ast.Load()),
+                        attr="perf_counter_ns",
+                        ctx=ast.Load(),
+                    ),
+                    args=[],
+                    keywords=[],
+                ),
+                lineno=test_node.lineno + 1,
+                col_offset=test_node.col_offset,
+            ),
+            ast.Assign(
+                targets=[ast.Name(id="return_value", ctx=ast.Store())],
+                value=call_node,
+                lineno=test_node.lineno + 2,
+                col_offset=test_node.col_offset,
+            ),
+            ast.Assign(
+                targets=[ast.Name(id="duration", ctx=ast.Store())],
+                value=ast.BinOp(
+                    left=ast.Call(
+                        func=ast.Attribute(
+                            value=ast.Name(id="time", ctx=ast.Load()),
+                            attr="perf_counter_ns",
+                            ctx=ast.Load(),
+                        ),
+                        args=[],
+                        keywords=[],
+                    ),
+                    op=ast.Sub(),
+                    right=ast.Name(id="counter", ctx=ast.Load()),
+                ),
+                lineno=test_node.lineno + 3,
+                col_offset=test_node.col_offset,
+            ),
+            ast.Expr(
+                value=ast.Call(
+                    func=ast.Attribute(
+                        value=ast.Name(id="gc", ctx=ast.Load()),
+                        attr="enable",
+                        ctx=ast.Load(),
+                    ),
+                    args=[],
+                    keywords=[],
+                ),
+                lineno=test_node.lineno + 4,
+                col_offset=test_node.col_offset,
+            ),
+            ast.Expr(
+                value=ast.Call(
+                    func=ast.Name(id="print", ctx=ast.Load()),
+                    args=[
+                        ast.JoinedStr(
+                            values=[
+                                ast.Constant(value="#####"),
+                                ast.Constant(
+                                    value=self.only_function_name + "_" + node_name + "_" + index
+                                ),
+                                ast.Constant(value="#####"),
+                                ast.FormattedValue(
+                                    value=ast.Name(id="duration", ctx=ast.Load()),
+                                    conversion=-1,
+                                ),
+                                ast.Constant(value="^^^^^"),
+                            ]
+                        )
+                    ],
+                    keywords=[],
+                ),
+                lineno=test_node.lineno + 5,
+                col_offset=test_node.col_offset,
+            ),
+        ]
+        subbed_node = ReplaceCallNodeWithName(self.only_function_name).visit(test_node)
+
+        updated_nodes.append(subbed_node)
+        return updated_nodes
+
+    def is_target_function_line(self, line_node):
+        for node in ast.walk(line_node):
+            if isinstance(node, ast.Call) and (
+                (hasattr(node.func, "id") and node.func.id == self.only_function_name)
+                or (hasattr(node.func, "attr") and node.func.attr == self.only_function_name)
+            ):
+                return True
+        return False
+
+    def visit_FunctionDef(self, node: ast.FunctionDef):
+        if node.name.startswith("test_"):
+            i = len(node.body) - 1
+            while i >= 0:
+                line_node = node.body[i]
+                # TODO: Validate that the function call actually did not raise any exceptions
+                # TODO: This does not work for 'with' context managers
+
+                if isinstance(line_node, ast.With):
+                    j = len(line_node.body) - 1
+                    while j >= 0:
+                        with_line_node = line_node.body[j]
+                        for with_node in ast.walk(with_line_node):
+                            if self.is_target_function_line(with_node):
+                                line_node.body[j : j + 1] = self.update_line_node(
+                                    with_node, node.name, str(i) + "_" + str(j)
+                                )
+                        j -= 1
+                else:
+                    if self.is_target_function_line(line_node):
+                        node.body[i : i + 1] = self.update_line_node(line_node, node.name, str(i))
+                i -= 1
+        return node
+
+
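Taken together, update_line_node and visit_FunctionDef rewrite each test statement that calls the target function. A hypothetical one-line test body result = foo(5) inside def test_foo(self): (statement index 0) unparses to roughly:

    gc.disable()
    counter = time.perf_counter_ns()
    return_value = foo(5)
    duration = time.perf_counter_ns() - counter
    gc.enable()
    print(f'#####foo_test_foo_0#####{duration}^^^^^')
    result = return_value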
+class FunctionImportedAsVisitor(ast.NodeVisitor):
+    """Checks whether a function has been imported under an alias; we only care about the alias then.
+
+    from numpy import array as np_array
+
+    np_array is what we want.
+    """
+
+    def __init__(self, original_function_name):
+        self.original_function_name = original_function_name
+        self.imported_as_function_name = original_function_name
+
+    # TODO: Validate that the imported function actually comes from the right module
+    def visit_ImportFrom(self, node: ast.ImportFrom):
+        for alias in node.names:
+            if alias.name == self.original_function_name:
+                if hasattr(alias, "asname") and alias.asname is not None:
+                    self.imported_as_function_name = alias.asname
+
+
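The docstring's own example, run through the visitor (a small sketch):

    tree = ast.parse("from numpy import array as np_array")
    visitor = FunctionImportedAsVisitor("array")
    visitor.visit(tree)
    visitor.imported_as_function_name  # 'np_array'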
+def inject_profiling_into_existing_test(test_path, function_name):
+    with open(test_path, "r") as f:
+        test_code = f.read()
+    tree = ast.parse(test_code)
+    import_visitor = FunctionImportedAsVisitor(function_name)
+    import_visitor.visit(tree)
+    function_name = import_visitor.imported_as_function_name
+
+    tree = InjectPerfOnly(function_name).visit(tree)
+    new_imports = [
+        ast.Import(names=[ast.alias(name="time")]),
+        ast.Import(names=[ast.alias(name="gc")]),
+    ]
+    new_path = (
+        os.path.splitext(test_path)[0] + "__perfinstrumented" + os.path.splitext(test_path)[1]
+    )
+    tree.body = new_imports + tree.body
+
+    with open(new_path, "w") as f:
+        f.write(ast.unparse(tree))
+    return new_path
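Usage sketch with a hypothetical test file; note that the function both writes the instrumented copy and returns its path:

    new_path = inject_profiling_into_existing_test("tests/test_foo.py", "foo")
    # new_path == 'tests/test_foo__perfinstrumented.py', with
    # 'import time' and 'import gc' prepended to the rewritten test code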

@@ -11,13 +11,19 @@ from codeflash.discovery.discover_unit_tests import discover_unit_tests
 from codeflash.discovery.functions_to_optimize import get_functions_to_optimize
 from codeflash.optimization.function_context import get_function_context_len_constrained
 from codeflash.optimization.optimizer import optimize_python_code
+from codeflash.verification.parse_test_output import (
+    parse_unittest_output,
+    parse_test_timing,
+    filter_out_failed_test_timing,
+    parse_test_return_values_bin,
+)
 from codeflash.verification.test_runner import run_tests
 from codeflash.verification.verification_utils import (
     get_test_file_path,
-    parse_test_return_values_bin,
 )
 from codeflash.verification.verifier import generate_tests
 from codeflash.verification.equivalence import compare_results
+from codeflash.instrumentation.instrument_existing_tests import inject_profiling_into_existing_test
 
 
 @dataclass
@@ -43,7 +49,7 @@ def main() -> None:
         args.root, optimize_all=args.all, file=args.file, function=args.function
     )
     test_files_created = set()
-    existing_unit_tests = set()
+    instrumented_unittests_created = set()
    found_atleast_one_optimization = False
     try:
         functions_to_tests_map = discover_unit_tests(
@@ -77,10 +83,23 @@
             )
             print("CODE TO OPTIMIZE", code_to_optimize)
             module_path = get_module_name_from_file(path, args.root)
+            unique_original_test_files = set()
             for i, tests_in_file in enumerate(
                 functions_to_tests_map[module_path + "." + function_name]
             ):
-                existing_unit_tests.add(tests_in_file.test_file)
+                new_test_path = (
+                    os.path.splitext(tests_in_file.test_file)[0]
+                    + "__perfinstrumented"
+                    + os.path.splitext(tests_in_file.test_file)[1]
+                )
+                injected_test = inject_profiling_into_existing_test(
+                    tests_in_file.test_file,
+                    function_name,
+                )
+                with open(new_test_path, "w") as f:
+                    f.write(injected_test)
+                instrumented_unittests_created.add(new_test_path)
+                unique_original_test_files.add(tests_in_file.test_file)
             new_tests = generate_tests(
                 source_code_being_tested=code_to_optimize,
                 function_name=function_name,
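The path derivation above mirrors the one inside inject_profiling_into_existing_test, e.g. for a hypothetical tests/test_foo.py:

    base, ext = os.path.splitext("tests/test_foo.py")
    new_test_path = base + "__perfinstrumented" + ext
    # 'tests/test_foo__perfinstrumented.py'

Note that inject_profiling_into_existing_test returns the new file's path (it writes the instrumented source itself), so f.write(injected_test) here overwrites that file with the path string rather than test source.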
@@ -97,18 +116,36 @@
             all_original_test_time = []
             times_run = 0
             generated_tests_elapsed_time = 0.0
+            existing_unittest_results_original = {}
 
             for i in range(MAX_TEST_RUN_ITERATIONS):
                 if generated_tests_elapsed_time > MAX_FUNCTION_TEST_SECONDS:
                     break
-                for test_file in existing_unit_tests:
-                    stdout, stderr = run_tests(
+                instrumented_test_timing = []
+                for test_file in instrumented_unittests_created:
+                    # TODO: If some test case times out, flag it and don't run it in subsequent iterations, to save a lot of time. It doesn't add value anyway.
+                    # TODO: Add support for pytest too
+                    std_output, stderr_output = run_tests(
                         test_file,
-                        cwd=args.root,
                         test_framework=args.test_framework,
+                        cwd=args.root,
+                        pytest_timeout=INDIVIDUAL_TEST_TIMEOUT,
+                        verbose=True,
                     )
-                    print(stdout, stderr)
 
+                    # TODO: Only consider the tests that passed; discard the timing of the tests that failed
+                    if i == 0:
+                        existing_unittest_results_original = {
+                            **existing_unittest_results_original,
+                            **parse_unittest_output(stderr_output),
+                        }
+                    timing_result = parse_test_timing(std_output)
+                    timing_result = filter_out_failed_test_timing(
+                        existing_unittest_results_original, timing_result
+                    )
+                    timing = sum(list(timing_result.values()))
+
+                    instrumented_test_timing.append(timing)
 
                 start_time = time.time()
                 test_stdout, test_stderr = run_tests(
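The timing aggregation in this hunk reduces to the following (hypothetical parsed values, in nanoseconds):

    timing_result = {"tests.test_foo:foo:test_bar:0": 1200,
                     "tests.test_foo:foo:test_baz:1": 900}
    timing = sum(list(timing_result.values()))  # 2100 ns for this instrumented file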
codeflash/verification/parse_test_output.py (new file, +70)

@@ -0,0 +1,70 @@
+import os
+import pickle
+import re
+
+
+def filter_out_failed_test_timing(test_result, timing_result):
+    final_timing_result = {}
+    processed_test_result = {}
+    for test_name, test_passed in test_result.items():
+        test_path, test_case = test_name.split(":")
+        split_test_path = test_path.split(".")
+        if not split_test_path[-2].endswith("__perfinstrumented"):
+            raise ValueError(
+                "Didn't find the __perfinstrumented suffix for the perf-instrumented test. Please check which test is being run."
+            )
+        split_test_path[-2] = split_test_path[-2][: -len("__perfinstrumented")]
+        processed_test_result[".".join(split_test_path[:-1]) + ":" + test_case] = test_passed
+
+    for test_name_compound, time_taken in timing_result.items():
+        test_path, function_tested, test_case, test_case_id = test_name_compound.split(":")
+        search_key = test_path + ":" + test_case
+        if search_key in processed_test_result and processed_test_result[search_key]:
+            final_timing_result[test_name_compound] = time_taken
+    return final_timing_result
+
+
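A worked example of the filter, with key shapes implied by the two loops (names hypothetical):

    test_result = {"tests.test_foo__perfinstrumented.TestFoo:test_bar": True,
                   "tests.test_foo__perfinstrumented.TestFoo:test_baz": False}
    timing_result = {"tests.test_foo:foo:test_bar:0": 1200,
                     "tests.test_foo:foo:test_baz:1": 900}
    filter_out_failed_test_timing(test_result, timing_result)
    # {'tests.test_foo:foo:test_bar:0': 1200}   (the failing test's timing is dropped)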
+def parse_test_return_values_bin(file_location):
+    test_results = {}
+    if not os.path.exists(file_location):
+        return None
+    with open(file_location, "rb") as file:
+        while file:
+            len_next = file.read(4)
+            if not len_next:
+                return test_results
+            len_next = int.from_bytes(len_next, byteorder="big")
+            test_name = file.read(len_next).decode("ascii")
+            len_next = file.read(8)
+            duration = int.from_bytes(len_next, byteorder="big")
+            len_next = file.read(4)
+            if not len_next:
+                return test_results
+            len_next = int.from_bytes(len_next, byteorder="big")
+            test_pickle = pickle.loads(file.read(len_next))
+            test_results[test_name] = {"runtime": [duration], "results": [test_pickle]}
+    return test_results
+
+
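The reader implies this record layout: 4-byte big-endian name length, ASCII test name, 8-byte duration, 4-byte pickle length, pickled return value. A sketch of the writing side (file name and values hypothetical):

    import pickle

    with open("return_values.bin", "wb") as out:
        name = b"test_foo:0"
        out.write(len(name).to_bytes(4, byteorder="big"))
        out.write(name)
        out.write((1200).to_bytes(8, byteorder="big"))
        payload = pickle.dumps(42)
        out.write(len(payload).to_bytes(4, byteorder="big"))
        out.write(payload)

    parse_test_return_values_bin("return_values.bin")
    # {'test_foo:0': {'runtime': [1200], 'results': [42]}}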
+def parse_unittest_output(output):
+    re_pattern = r"^(test\w+)\s\((.*?)\)\s\.\.\.\s.*?(ok|FAIL|ERROR)$"
+    matches = re.findall(re_pattern, output, re.MULTILINE | re.DOTALL)
+    test_results = {}
+    for match in matches:
+        if not str.isidentifier(match[0]):
+            print(f"Invalid test name {match[0]}. Test names must be valid Python identifiers")
+            continue
+        if match[2] == "ok":
+            test_results[match[1] + ":" + match[0]] = True
+        elif match[2] in ["FAIL", "ERROR"]:
+            test_results[match[1] + ":" + match[0]] = False
+        else:
+            raise ValueError("Invalid test result; couldn't parse the test output")
+    return test_results
+
+
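For reference, the unittest verbose output this regex expects looks like (hypothetical run):

    output = (
        "test_bar (tests.test_foo__perfinstrumented.TestFoo) ... ok\n"
        "test_baz (tests.test_foo__perfinstrumented.TestFoo) ... FAIL"
    )
    parse_unittest_output(output)
    # {'tests.test_foo__perfinstrumented.TestFoo:test_bar': True,
    #  'tests.test_foo__perfinstrumented.TestFoo:test_baz': False}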
+def parse_test_timing(test_results):
+    m = re.findall(r"#####([^#]*?)#####([\d\.]*?)\^\^\^\^\^", test_results)
+    parsed_results = {}
+    for test_name, time_taken in m:
+        time_taken = int(time_taken)
+        parsed_results[test_name] = time_taken
+    return parsed_results
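This consumes the #####...##### and ^^^^^ markers that the instrumented tests print; a round trip:

    parse_test_timing("#####foo_test_foo_0#####12345^^^^^")
    # {'foo_test_foo_0': 12345}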

codeflash/verification/verification_utils.py

@@ -1,5 +1,6 @@
 import os
 import pickle
+import re
 
 
 def get_test_file_path(test_dir, function_name, iteration=0, test_type="unit"):
@@ -9,25 +10,3 @@ def get_test_file_path(test_dir, function_name, iteration=0, test_type="unit"):
     if os.path.exists(path):
         return get_test_file_path(test_dir, function_name, iteration + 1)
     return path
-
-
-def parse_test_return_values_bin(file_location):
-    test_results = {}
-    if not os.path.exists(file_location):
-        return None
-    with open(file_location, "rb") as file:
-        while file:
-            len_next = file.read(4)
-            if not len_next:
-                return test_results
-            len_next = int.from_bytes(len_next, byteorder="big")
-            test_name = file.read(len_next).decode("ascii")
-            len_next = file.read(8)
-            duration = int.from_bytes(len_next, byteorder="big")
-            len_next = file.read(4)
-            if not len_next:
-                return test_results
-            len_next = int.from_bytes(len_next, byteorder="big")
-            test_pickle = pickle.loads(file.read(len_next))
-            test_results[test_name] = {"runtime": [duration], "results": [test_pickle]}
-    return test_results