Inject perf/instrumentation for generated tests

This commit is contained in:
afik.cohen 2023-10-22 21:47:45 -07:00
parent 6158c38c02
commit c50c9c5356
4 changed files with 315 additions and 29 deletions

View file

@ -0,0 +1,200 @@
import ast
import os.path
class ReplaceCallNodeWithName(ast.NodeTransformer):
    """Replace every call to ``only_function_name`` with a bare ``ast.Name``.

    Used after the instrumentation pass has already evaluated the call and
    bound its result, so the original call expression can be substituted with
    a reference to that variable (``return_value`` by default) and the rest of
    the statement (assertions, arithmetic, ...) still runs unchanged.
    """
    def __init__(self, only_function_name, new_variable_name="return_value"):
        self.only_function_name = only_function_name
        self.new_variable_name = new_variable_name
    def visit_Call(self, node: ast.Call):
        # visit_Call only ever receives ast.Call nodes, so the original
        # isinstance(node, ast.Call) check was redundant. Match both plain
        # names (foo()) and attribute access (obj.foo()) explicitly instead
        # of duck-typing with hasattr.
        func = node.func
        if (isinstance(func, ast.Name) and func.id == self.only_function_name) or (
            isinstance(func, ast.Attribute) and func.attr == self.only_function_name
        ):
            return ast.Name(id=self.new_variable_name, ctx=ast.Load())
        # Not the target call: recurse into arguments in case it is nested.
        self.generic_visit(node)
        return node
class InjectPerfOnly(ast.NodeTransformer):
    """Wrap calls to a target function inside ``test_*`` functions with perf
    instrumentation.

    Each statement that calls the target function is replaced by a sequence
    that disables the GC, records ``time.perf_counter_ns()``, evaluates the
    call (binding its result to ``return_value``), computes the elapsed
    ``duration``, re-enables the GC, prints a
    ``#####<name>#####<duration>^^^^^`` marker for the timing parser, and then
    re-runs the original statement with the call replaced by ``return_value``.

    Assumes ``time`` and ``gc`` are importable in the instrumented module;
    the caller is responsible for injecting those imports.
    """
    def __init__(self, function_name):
        # Name (plain or attribute) of the function whose calls are timed.
        self.only_function_name = function_name
    def update_line_node(self, test_node, node_name, index: str):
        """Build the instrumented statement sequence for one statement.

        ``node_name`` is the enclosing test function's name and ``index`` the
        statement's position; both are baked into the printed marker so the
        timing output can be attributed to a specific call site. Returns the
        statement unchanged (in a one-element list) when no matching call is
        found inside it.
        """
        call_node = None
        # Walk the whole statement; if several calls match, the last one
        # visited wins.
        for node in ast.walk(test_node):
            if isinstance(node, ast.Call) and (
                (hasattr(node.func, "id") and node.func.id == self.only_function_name)
                or (hasattr(node.func, "attr") and node.func.attr == self.only_function_name)
            ):
                call_node = node
        if call_node is None:
            return [test_node]
        updated_nodes = [
            # gc.disable()  -- avoid GC pauses inside the timed region
            ast.Expr(
                value=ast.Call(
                    func=ast.Attribute(
                        value=ast.Name(id="gc", ctx=ast.Load()),
                        attr="disable",
                        ctx=ast.Load(),
                    ),
                    args=[],
                    keywords=[],
                ),
                lineno=test_node.lineno,
                col_offset=test_node.col_offset,
            ),
            # counter = time.perf_counter_ns()
            ast.Assign(
                targets=[ast.Name(id="counter", ctx=ast.Store())],
                value=ast.Call(
                    func=ast.Attribute(
                        value=ast.Name(id="time", ctx=ast.Load()),
                        attr="perf_counter_ns",
                        ctx=ast.Load(),
                    ),
                    args=[],
                    keywords=[],
                ),
                lineno=test_node.lineno + 1,
                col_offset=test_node.col_offset,
            ),
            # return_value = <original call to the target function>
            ast.Assign(
                targets=[ast.Name(id="return_value", ctx=ast.Store())],
                value=call_node,
                lineno=test_node.lineno + 2,
                col_offset=test_node.col_offset,
            ),
            # duration = time.perf_counter_ns() - counter
            ast.Assign(
                targets=[ast.Name(id="duration", ctx=ast.Store())],
                value=ast.BinOp(
                    left=ast.Call(
                        func=ast.Attribute(
                            value=ast.Name(id="time", ctx=ast.Load()),
                            attr="perf_counter_ns",
                            ctx=ast.Load(),
                        ),
                        args=[],
                        keywords=[],
                    ),
                    op=ast.Sub(),
                    right=ast.Name(id="counter", ctx=ast.Load()),
                ),
                lineno=test_node.lineno + 3,
                col_offset=test_node.col_offset,
            ),
            # gc.enable()
            ast.Expr(
                value=ast.Call(
                    func=ast.Attribute(
                        value=ast.Name(id="gc", ctx=ast.Load()),
                        attr="enable",
                        ctx=ast.Load(),
                    ),
                    args=[],
                    keywords=[],
                ),
                lineno=test_node.lineno + 4,
                col_offset=test_node.col_offset,
            ),
            # print(f'#####<function>_<test>_<index>#####{duration}^^^^^')
            # -- the marker line that parse_test_timing scans for.
            ast.Expr(
                value=ast.Call(
                    func=ast.Name(id="print", ctx=ast.Load()),
                    args=[
                        ast.JoinedStr(
                            values=[
                                ast.Constant(value="#####"),
                                ast.Constant(
                                    value=self.only_function_name + "_" + node_name + "_" + index
                                ),
                                ast.Constant(value="#####"),
                                ast.FormattedValue(
                                    value=ast.Name(id="duration", ctx=ast.Load()),
                                    conversion=-1,
                                ),
                                ast.Constant(value="^^^^^"),
                            ]
                        )
                    ],
                    keywords=[],
                ),
                lineno=test_node.lineno + 5,
                col_offset=test_node.col_offset,
            ),
        ]
        # Re-run the original statement with the call swapped for
        # 'return_value', so surrounding assertions/expressions still execute.
        subbed_node = ReplaceCallNodeWithName(self.only_function_name).visit(test_node)
        updated_nodes.append(subbed_node)
        return updated_nodes
    def is_target_function_line(self, line_node):
        """Return True when any call to the target function occurs in the node."""
        for node in ast.walk(line_node):
            if isinstance(node, ast.Call) and (
                (hasattr(node.func, "id") and node.func.id == self.only_function_name)
                or (hasattr(node.func, "attr") and node.func.attr == self.only_function_name)
            ):
                return True
        return False
    def visit_FunctionDef(self, node: ast.FunctionDef):
        """Instrument every statement of a ``test_*`` function that calls the target.

        Statements are walked in reverse so splicing the multi-statement
        replacement into the body does not shift the indices of statements
        not yet visited. NOTE(review): async test functions
        (AsyncFunctionDef) are not visited — confirm that is intended.
        """
        if node.name.startswith("test_"):
            i = len(node.body) - 1
            while i >= 0:
                line_node = node.body[i]
                # TODO: Validate if the functional call actually did not raise any exceptions
                # TODO : This does not work for 'with' context managers
                if isinstance(line_node, ast.With):
                    # One level of 'with' bodies is handled specially.
                    # NOTE(review): with_node comes from ast.walk, so
                    # update_line_node can receive a sub-expression rather
                    # than a statement, and the splice happens inside the
                    # walk loop — behavior with multiple matches in one
                    # statement should be confirmed.
                    j = len(line_node.body) - 1
                    while j >= 0:
                        with_line_node = line_node.body[j]
                        for with_node in ast.walk(with_line_node):
                            if self.is_target_function_line(with_node):
                                line_node.body[j : j + 1] = self.update_line_node(
                                    with_node, node.name, str(i) + "_" + str(j)
                                )
                        j -= 1
                else:
                    if self.is_target_function_line(line_node):
                        node.body[i : i + 1] = self.update_line_node(line_node, node.name, str(i))
                i -= 1
        return node
class FunctionImportedAsVisitor(ast.NodeVisitor):
    """Resolve the local name a function is imported under.

    For ``from numpy import array as np_array`` the code under test refers to
    ``np_array``, so that alias is the name we must instrument. When the
    function is not aliased (or not imported at all),
    ``imported_as_function_name`` stays equal to the original name.
    """
    def __init__(self, original_function_name):
        self.original_function_name = original_function_name
        self.imported_as_function_name = original_function_name
    # TODO: Validate if the function imported is actually from the right module
    def visit_ImportFrom(self, node: ast.ImportFrom):
        for alias in node.names:
            # ast.alias always has an 'asname' attribute (None when there is
            # no alias), so the previous hasattr() check was redundant.
            if alias.name == self.original_function_name and alias.asname is not None:
                self.imported_as_function_name = alias.asname
def inject_profiling_into_existing_test(test_path, function_name):
    """Rewrite an existing test file with perf instrumentation for ``function_name``.

    Parses the test file, resolves any import alias of the target function,
    injects timing/GC statements around every call to it (see InjectPerfOnly),
    prepends ``import time`` / ``import gc``, and writes the result next to
    the original as ``<name>__perfinstrumented<ext>``.

    Returns the path of the newly written instrumented test file.
    """
    # Read/write with an explicit encoding so results don't depend on the
    # platform's default locale encoding.
    with open(test_path, "r", encoding="utf-8") as f:
        test_code = f.read()
    tree = ast.parse(test_code)
    # The test may import the function under an alias; instrument that alias.
    import_visitor = FunctionImportedAsVisitor(function_name)
    import_visitor.visit(tree)
    function_name = import_visitor.imported_as_function_name
    tree = InjectPerfOnly(function_name).visit(tree)
    # The injected instrumentation references these modules at test runtime.
    new_imports = [
        ast.Import(names=[ast.alias(name="time")]),
        ast.Import(names=[ast.alias(name="gc")]),
    ]
    tree.body = new_imports + tree.body
    # Newly constructed nodes (e.g. the Import statements above) lack source
    # locations; normalize the tree before unparsing/compiling it.
    ast.fix_missing_locations(tree)
    base, ext = os.path.splitext(test_path)
    new_path = base + "__perfinstrumented" + ext
    with open(new_path, "w", encoding="utf-8") as f:
        f.write(ast.unparse(tree))
    return new_path

View file

@ -11,13 +11,19 @@ from codeflash.discovery.discover_unit_tests import discover_unit_tests
from codeflash.discovery.functions_to_optimize import get_functions_to_optimize
from codeflash.optimization.function_context import get_function_context_len_constrained
from codeflash.optimization.optimizer import optimize_python_code
from codeflash.verification.parse_test_output import (
parse_unittest_output,
parse_test_timing,
filter_out_failed_test_timing,
parse_test_return_values_bin,
)
from codeflash.verification.test_runner import run_tests
from codeflash.verification.verification_utils import (
get_test_file_path,
parse_test_return_values_bin,
)
from codeflash.verification.verifier import generate_tests
from codeflash.verification.equivalence import compare_results
from codeflash.instrumentation.instrument_existing_tests import inject_profiling_into_existing_test
@dataclass
@ -43,7 +49,7 @@ def main() -> None:
args.root, optimize_all=args.all, file=args.file, function=args.function
)
test_files_created = set()
existing_unit_tests = set()
instrumented_unittests_created = set()
found_atleast_one_optimization = False
try:
functions_to_tests_map = discover_unit_tests(
@ -77,10 +83,23 @@ def main() -> None:
)
print("CODE TO OPTIMIZE", code_to_optimize)
module_path = get_module_name_from_file(path, args.root)
unique_original_test_files = set()
for i, tests_in_file in enumerate(
functions_to_tests_map[module_path + "." + function_name]
):
existing_unit_tests.add(tests_in_file.test_file)
new_test_path = (
os.path.splitext(tests_in_file.test_file)[0]
+ "__perfinstrumented"
+ os.path.splitext(tests_in_file.test_file)[1]
)
injected_test = inject_profiling_into_existing_test(
tests_in_file.test_file,
function_name,
)
with open(new_test_path, "w") as f:
f.write(injected_test)
instrumented_unittests_created.add(new_test_path)
unique_original_test_files.add(tests_in_file.test_file)
new_tests = generate_tests(
source_code_being_tested=code_to_optimize,
function_name=function_name,
@ -97,18 +116,36 @@ def main() -> None:
all_original_test_time = []
times_run = 0
generated_tests_elapsed_time = 0.0
existing_unittest_results_original = {}
for i in range(MAX_TEST_RUN_ITERATIONS):
if generated_tests_elapsed_time > MAX_FUNCTION_TEST_SECONDS:
break
for test_file in existing_unit_tests:
stdout, stderr = run_tests(
instrumented_test_timing = []
for test_file in instrumented_unittests_created:
# TODO: If some test case times out then flag it and don't run it in subsequent tests, to save a lot of time. It doesn't add value anyway
# TODO: Add Support for PyTest too
std_output, stderr_output = run_tests(
test_file,
cwd=args.root,
test_framework=args.test_framework,
cwd=args.root,
pytest_timeout=INDIVIDUAL_TEST_TIMEOUT,
verbose=True,
)
print(stdout, stderr)
# TODO: Only consider the tests that passed, discard the timing of the tests that failed
if i == 0:
existing_unittest_results_original = {
**existing_unittest_results_original,
**parse_unittest_output(stderr_output),
}
timing_result = parse_test_timing(std_output)
timing_result = filter_out_failed_test_timing(
existing_unittest_results_original, timing_result
)
timing = sum(list(timing_result.values()))
instrumented_test_timing.append(timing)
start_time = time.time()
test_stdout, test_stderr = run_tests(

View file

@ -0,0 +1,70 @@
import os
import pickle
import re
def filter_out_failed_test_timing(test_result, timing_result):
    """Keep only the timing entries whose corresponding test passed.

    ``test_result`` maps ``"pkg.mod__perfinstrumented.Class:test_case"`` to a
    pass/fail bool; ``timing_result`` maps
    ``"pkg.mod:function:test_case:id"`` to a duration. Raises ValueError when
    a test path is missing the ``__perfinstrumented`` module suffix.
    """
    suffix = "__perfinstrumented"
    passed_by_key = {}
    for qualified_name, did_pass in test_result.items():
        module_path, case_name = qualified_name.split(":")
        parts = module_path.split(".")
        # The instrumented module name must carry the suffix; strip it so the
        # key lines up with the un-instrumented path used in timing entries.
        if not parts[-2].endswith(suffix):
            raise ValueError(
                "Didn't find the __perfinstrumented suffix for the perf instrumented test. Please check what test is being tested"
            )
        parts[-2] = parts[-2][: -len(suffix)]
        # Drop the trailing class segment; key on "pkg.mod:test_case".
        passed_by_key[".".join(parts[:-1]) + ":" + case_name] = did_pass
    kept = {}
    for compound_name, duration in timing_result.items():
        module_path, _function, case_name, _case_id = compound_name.split(":")
        if passed_by_key.get(module_path + ":" + case_name):
            kept[compound_name] = duration
    return kept
def parse_test_return_values_bin(file_location):
    """Parse the binary test-results file written by instrumented test runs.

    Record layout, repeated until EOF (all integers big-endian):
      4 bytes  length of the test-name field
      N bytes  ASCII test name
      8 bytes  runtime duration
      4 bytes  length of the pickled return value
      M bytes  pickled return value

    Returns ``{test_name: {"runtime": [duration], "results": [value]}}``,
    an empty dict for an empty file, or ``None`` when the file is missing.

    NOTE(review): this module previously used ``os`` and ``pickle`` without
    importing them (only ``re`` was imported), which raised NameError at
    runtime — both imports are now required at the top of the file.
    NOTE(review): ``pickle.loads`` executes arbitrary code; acceptable only
    because this file is produced by our own instrumentation, never by
    untrusted input.
    """
    test_results = {}
    if not os.path.exists(file_location):
        return None
    with open(file_location, "rb") as file:
        # The previous 'while file:' relied on the file object always being
        # truthy; 'while True' states the intent — we exit on a short read.
        while True:
            len_next = file.read(4)
            if not len_next:
                return test_results
            len_next = int.from_bytes(len_next, byteorder="big")
            test_name = file.read(len_next).decode("ascii")
            duration = int.from_bytes(file.read(8), byteorder="big")
            len_next = file.read(4)
            if not len_next:
                return test_results
            len_next = int.from_bytes(len_next, byteorder="big")
            test_pickle = pickle.loads(file.read(len_next))
            test_results[test_name] = {"runtime": [duration], "results": [test_pickle]}
def parse_unittest_output(output):
    """Parse verbose unittest output into ``{"scope:test_name": passed}``.

    Each line like ``test_foo (pkg.mod.TestX) ... ok`` produces an entry
    keyed ``"pkg.mod.TestX:test_foo"`` with True for ``ok`` and False for
    ``FAIL``/``ERROR``. Non-identifier test names are reported and skipped.
    """
    pattern = r"^(test\w+)\s\((.*?)\)\s\.\.\.\s.*?(ok|FAIL|ERROR)$"
    results = {}
    for test_method, test_scope, outcome in re.findall(
        pattern, output, re.MULTILINE | re.DOTALL
    ):
        if not test_method.isidentifier():
            print(f"Invalid test name {test_method}. Test names must be valid python identifiers")
            continue
        key = test_scope + ":" + test_method
        if outcome == "ok":
            results[key] = True
        elif outcome in ("FAIL", "ERROR"):
            results[key] = False
        else:
            # Unreachable in practice: the regex only captures the three
            # outcomes above; kept as a guard against regex edits.
            raise ValueError("Invalid test result, couldn't parse the test output")
    return results
def parse_test_timing(test_results):
    """Extract ``#####<name>#####<duration>^^^^^`` markers from test stdout.

    These markers are printed by the instrumentation injected into tests;
    returns ``{marker_name: duration_as_int}``.
    """
    timings = {}
    for marker_name, raw_duration in re.findall(
        r"#####([^#]*?)#####([\d\.]*?)\^\^\^\^\^", test_results
    ):
        timings[marker_name] = int(raw_duration)
    return timings

View file

@ -1,5 +1,6 @@
import os
import pickle
import re
def get_test_file_path(test_dir, function_name, iteration=0, test_type="unit"):
@ -9,25 +10,3 @@ def get_test_file_path(test_dir, function_name, iteration=0, test_type="unit"):
if os.path.exists(path):
return get_test_file_path(test_dir, function_name, iteration + 1)
return path
def parse_test_return_values_bin(file_location):
    # Parse the binary test-results file written by instrumented test runs.
    # Record layout, repeated until EOF (all integers big-endian):
    #   4-byte name length, ASCII test name, 8-byte duration,
    #   4-byte pickle length, pickled return value.
    # Returns None when the file does not exist, an empty dict for an empty
    # file, otherwise test name -> {"runtime": [duration], "results": [value]}.
    test_results = {}
    if not os.path.exists(file_location):
        return None
    with open(file_location, "rb") as file:
        # NOTE(review): a file object is always truthy, so this loop only
        # exits through the early returns below on an empty read.
        while file:
            len_next = file.read(4)
            if not len_next:
                return test_results
            len_next = int.from_bytes(len_next, byteorder="big")
            test_name = file.read(len_next).decode("ascii")
            len_next = file.read(8)
            duration = int.from_bytes(len_next, byteorder="big")
            len_next = file.read(4)
            if not len_next:
                return test_results
            len_next = int.from_bytes(len_next, byteorder="big")
            # NOTE(review): pickle.loads executes arbitrary code; acceptable
            # only because this file is produced by our own instrumentation.
            test_pickle = pickle.loads(file.read(len_next))
            test_results[test_name] = {"runtime": [duration], "results": [test_pickle]}
    return test_results