Add unittest discovery and instrumentation

This commit is contained in:
afik.cohen 2023-10-22 13:45:41 -07:00
parent 80dfb25d5d
commit fd7002235b
5 changed files with 250 additions and 74 deletions

View file

@ -1,31 +1,106 @@
import subprocess
import jedi
import re
import os
from collections import defaultdict
import re
import subprocess
import unittest
from collections import defaultdict, namedtuple
import jedi
TestsInFile = namedtuple("TestsInFile", ["test_file", "test_function", "test_suite"])
def discover_unit_tests(test_directory, project_root_path, test_framework="pytest"):
    """Discover tests under ``test_directory`` with the requested framework.

    Args:
        test_directory: Directory that is searched for tests.
        project_root_path: Root of the project; forwarded unchanged to the
            framework-specific discovery function.
        test_framework: ``"pytest"`` (default) or ``"unittest"``.

    Returns:
        Whatever the framework-specific discovery function returns.

    Raises:
        ValueError: If ``test_framework`` is not one of the supported values.
    """
    if test_framework == "pytest":
        return discover_tests_pytest(test_directory, project_root_path)
    if test_framework == "unittest":
        return discover_tests_unittest(test_directory, project_root_path)
    # Fail loudly: falling through would hand callers ``None`` instead of a
    # mapping and only blow up later, far away from the real mistake.
    raise ValueError(
        f"Invalid test framework {test_framework!r}, "
        "we only support Pytest and Unittest currently."
    )
def discover_tests_pytest(test_directory, project_root_path):
    """Collect pytest tests under ``test_directory`` and map them to project code.

    Runs ``pytest --co -q`` with ``test_directory`` as the working directory,
    groups the collected test names by test file, and passes that mapping to
    ``process_test_files``.

    Args:
        test_directory: Directory handed to pytest for collection.
        project_root_path: Root of the project, forwarded to
            ``process_test_files``.

    Returns:
        The result of ``process_test_files`` for the collected tests.

    Raises:
        ValueError: If a collected test id has no ``::`` separator, or if the
            test file it names cannot be found on disk.
    """
    pytest_result = subprocess.run(
        ["pytest", f"{test_directory}", "--co", "-q"],
        stdout=subprocess.PIPE,
        cwd=test_directory,
    )
    tests = parse_tests(pytest_result.stdout.decode("utf-8"))
    file_to_test_map = defaultdict(list)
    for test in tests:
        # A node id may have more than two "::"-separated parts (for example
        # "file.py::TestClass::test_method"); unpacking a plain split into
        # exactly two names would raise on those. Keep the file part and the
        # last component instead.
        # NOTE(review): assumes parse_tests yields raw pytest node ids - confirm.
        test_file, separator, remainder = test.partition("::")
        if not separator:
            raise ValueError(
                f"Pytest test discovery failed. Unexpected test id {test!r}."
            )
        function = remainder.rsplit("::", 1)[-1]
        test_file_path = os.path.join(test_directory, test_file)
        if not os.path.exists(test_file_path):
            # Seeing that in some circumstances pytest also returns the cwd as part of the test file path
            one_level_up_test_directory = os.path.abspath(os.path.join(test_directory, ".."))
            test_file_path = os.path.join(one_level_up_test_directory, test_file)
            if not os.path.exists(test_file_path):
                raise ValueError(
                    f"Pytest test discovery failed. Test file {test_file_path} does not exist."
                )
        file_to_test_map[test_file_path].append({"test_function": function})
    # Within these test files, find the project functions they are referring to and return their names/locations
    return process_test_files(file_to_test_map, project_root_path, test_framework="pytest")
def discover_tests_unittest(test_directory, project_root_path):
    """Collect unittest tests under ``test_directory`` and map them to project code.

    Uses ``unittest.TestLoader.discover`` and records, per test module file,
    each test method name together with the name of its test class. The
    mapping is then passed to ``process_test_files``.

    Args:
        test_directory: Directory passed to the unittest loader as start dir.
        project_root_path: Root of the project, forwarded to
            ``process_test_files``.

    Returns:
        The result of ``process_test_files`` for the discovered tests.
    """
    loader = unittest.TestLoader()
    discovered = loader.discover(str(test_directory))
    file_to_test_map = defaultdict(list)
    for module_suite in discovered:
        for class_suite in module_suite:
            if not isinstance(class_suite, unittest.TestSuite):
                # Not a suite of test cases (e.g. a placeholder for a module
                # that failed to load), so there is nothing to walk into.
                print("Didn't find tests for ", class_suite)
                continue
            for test in class_suite:
                if not isinstance(test, unittest.TestCase):
                    continue
                # Read the names from the test object instead of parsing
                # str(test): that text layout differs between Python
                # versions, so a regex over it can mis-assign module/class.
                test_class = type(test)
                test_function = test.id().rsplit(".", 1)[-1]
                test_module = test_class.__module__
                test_suite_name = test_class.__name__
                test_module_path = test_module.replace(".", os.sep)
                test_module_path = os.path.join(str(test_directory), test_module_path) + ".py"
                if not os.path.exists(test_module_path):
                    # Module does not live under test_directory; skip it.
                    continue
                file_to_test_map[test_module_path].append(
                    {"test_function": test_function, "test_suite_name": test_suite_name}
                )
    return process_test_files(file_to_test_map, project_root_path, test_framework="unittest")
def process_test_files(file_to_test_map, project_root_path, test_framework="pytest"):
function_to_test_map = defaultdict(list)
jedi_project = jedi.Project(path=project_root_path)
TestFunction = namedtuple("TestFunction", ["function_name", "test_suite_name"])
for test_file, functions in file_to_test_map.items():
script = jedi.Script(path=test_file, project=jedi_project)
test_functions = set()
top_level_names = script.get_names()
all_names = script.get_names(all_scopes=True, references=True)
all_defs = script.get_names(all_scopes=True, definitions=True)
for name in top_level_names:
if name.name in functions and name.type == "function":
test_functions.add(name.name)
if test_framework == "pytest":
functions_to_search = [elem["test_function"] for elem in functions]
if name.name in functions_to_search and name.type == "function":
test_functions.add(TestFunction(name.name, None))
if test_framework == "unittest":
functions_to_search = [elem["test_function"] for elem in functions]
test_suites = [elem["test_suite_name"] for elem in functions]
if name.name in test_suites and name.type == "class":
for def_name in all_defs:
if (
def_name.name in functions_to_search
and def_name.type == "function"
and def_name.full_name is not None
and f".{name.name}." in def_name.full_name
):
test_functions.add(TestFunction(def_name.name, name.name))
test_functions_list = list(test_functions)
test_functions_raw = [elem.function_name for elem in test_functions_list]
for name in all_names:
if name.full_name is None:
continue
@ -33,21 +108,30 @@ def discover_unit_tests(test_directory, project_root_path):
if not m:
continue
scope = m.group(1)
if scope in test_functions:
definition = script.goto(
line=name.line,
column=name.column,
follow_imports=True,
follow_builtin_imports=False,
)
index = test_functions_raw.index(scope) if scope in test_functions_raw else -1
if index >= 0:
scope_test_function = test_functions_list[index].function_name
scope_test_suite = test_functions_list[index].test_suite_name
try:
definition = script.goto(
line=name.line,
column=name.column,
follow_imports=True,
follow_builtin_imports=False,
)
except Exception as e:
print(str(e))
continue
if definition and definition[0].type == "function":
definition_path = str(definition[0].module_path)
# The definition is part of this project and not defined within the original function
if (
definition_path.startswith(project_root_path + os.sep)
definition_path.startswith(str(project_root_path) + os.sep)
and definition[0].module_name != name.module_name
):
function_to_test_map[definition[0].full_name].append((test_file, scope))
function_to_test_map[definition[0].full_name].append(
TestsInFile(test_file, scope_test_function, scope_test_suite)
)
deduped_function_to_test_map = {}
for function, tests in function_to_test_map.items():
deduped_function_to_test_map[function] = list(set(tests))

View file

@ -2,14 +2,24 @@ import ast
class InjectPerfAndLogging(ast.NodeTransformer):
def __init__(self, function_name, auxiliary_function_names):
def __init__(
self,
function_name: str,
auxillary_function_names,
test_framework="pytest",
test_timeout: int = 15,
):
self.function_name = function_name
self.class_name = None
self.only_function_name = self.function_name
self.test_framework = test_framework
self.individual_test_timeout = test_timeout
if len(function_name.split(".")) > 1:
self.class_name = function_name.split(".")[0]
self.only_function_name = function_name.split(".")[-1]
self.auxiliary_function_names = auxiliary_function_names # Other functional dependencies that were injected
self.auxillary_function_names = (
auxillary_function_names # Other functional dependencies that were injected
)
def visit_ImportFrom(self, node: ast.ImportFrom):
if any([name.name == self.function_name for name in node.names]):
@ -17,10 +27,12 @@ class InjectPerfAndLogging(ast.NodeTransformer):
def visit_ClassDef(self, node: ast.ClassDef):
# If the original class exists during testing, then remove it all together
if not self.class_name:
return node
if node.name == self.class_name:
return None # Remove the re-definition of the class and its dependencies from the test generation code
for inner_node in ast.walk(node):
if isinstance(inner_node, (ast.FunctionDef, ast.AsyncFunctionDef)):
inner_node = self.visit_FunctionDef(inner_node)
return node
def visit_Assert(self, node: ast.Assert):
# TODO : This does not work yet
@ -31,42 +43,90 @@ class InjectPerfAndLogging(ast.NodeTransformer):
return None
def is_target_function_node(self, node):
return isinstance(node, ast.Call) and ((hasattr(node.func,
'id') and node.func.id == self.only_function_name) or
(hasattr(node.func,
'attr') and node.func.attr == self.only_function_name))
return isinstance(node, ast.Call) and (
(hasattr(node.func, "id") and node.func.id == self.only_function_name)
or (hasattr(node.func, "attr") and node.func.attr == self.only_function_name)
)
def update_line_node(self, test_node, node_name, index: str):
return [
ast.Expr(value=ast.Call(
func=ast.Attribute(value=ast.Name(id='gc', ctx=ast.Load()), attr='disable', ctx=ast.Load()), args=[],
keywords=[]), lineno=test_node.lineno, col_offset=test_node.col_offset),
ast.Assign(targets=[ast.Name(id='counter', ctx=ast.Store())],
value=ast.Call(func=ast.Attribute(
value=ast.Name(id='time', ctx=ast.Load()),
attr='perf_counter_ns', ctx=ast.Load()), args=[], keywords=[]), lineno=test_node.lineno + 1,
col_offset=test_node.col_offset),
ast.Assign(targets=[ast.Name(id='return_value', ctx=ast.Store())],
value=test_node, lineno=test_node.lineno + 2, col_offset=test_node.col_offset),
ast.Assign(targets=[ast.Name(id='duration', ctx=ast.Store())],
value=ast.BinOp(
left=ast.Call(func=ast.Attribute(value=ast.Name(id='time', ctx=ast.Load()),
attr='perf_counter_ns', ctx=ast.Load()),
args=[], keywords=[]),
op=ast.Sub(), right=ast.Name(id='counter', ctx=ast.Load())),
lineno=test_node.lineno + 3, col_offset=test_node.col_offset),
ast.Expr(value=ast.Call(
func=ast.Attribute(value=ast.Name(id='gc', ctx=ast.Load()), attr='enable', ctx=ast.Load()), args=[],
keywords=[]), lineno=test_node.lineno + 4, col_offset=test_node.col_offset),
ast.Expr(ast.Call(func=ast.Name(id="___log_test_values", ctx=ast.Load()),
args=[ast.Name(id='return_value', ctx=ast.Load()),
ast.Name(id='duration', ctx=ast.Load()),
ast.Constant(value=self.only_function_name + "_" + node_name + "_" + index)],
keywords=[]), lineno=test_node.lineno + 5, col_offset=test_node.col_offset)]
return [
ast.Expr(
value=ast.Call(
func=ast.Attribute(
value=ast.Name(id="gc", ctx=ast.Load()), attr="disable", ctx=ast.Load()
),
args=[],
keywords=[],
),
lineno=test_node.lineno,
col_offset=test_node.col_offset,
),
ast.Assign(
targets=[ast.Name(id="counter", ctx=ast.Store())],
value=ast.Call(
func=ast.Attribute(
value=ast.Name(id="time", ctx=ast.Load()),
attr="perf_counter_ns",
ctx=ast.Load(),
),
args=[],
keywords=[],
),
lineno=test_node.lineno + 1,
col_offset=test_node.col_offset,
),
ast.Assign(
targets=[ast.Name(id="return_value", ctx=ast.Store())],
value=test_node,
lineno=test_node.lineno + 2,
col_offset=test_node.col_offset,
),
ast.Assign(
targets=[ast.Name(id="duration", ctx=ast.Store())],
value=ast.BinOp(
left=ast.Call(
func=ast.Attribute(
value=ast.Name(id="time", ctx=ast.Load()),
attr="perf_counter_ns",
ctx=ast.Load(),
),
args=[],
keywords=[],
),
op=ast.Sub(),
right=ast.Name(id="counter", ctx=ast.Load()),
),
lineno=test_node.lineno + 3,
col_offset=test_node.col_offset,
),
ast.Expr(
value=ast.Call(
func=ast.Attribute(
value=ast.Name(id="gc", ctx=ast.Load()), attr="enable", ctx=ast.Load()
),
args=[],
keywords=[],
),
lineno=test_node.lineno + 4,
col_offset=test_node.col_offset,
),
ast.Expr(
ast.Call(
func=ast.Name(id="_log__test__values", ctx=ast.Load()),
args=[
ast.Name(id="return_value", ctx=ast.Load()),
ast.Name(id="duration", ctx=ast.Load()),
ast.Constant(value=self.only_function_name + "_" + node_name + "_" + index),
],
keywords=[],
),
lineno=test_node.lineno + 5,
col_offset=test_node.col_offset,
),
]
def visit_FunctionDef(self, node: ast.FunctionDef):
only_function_name = self.function_name.split(".")[-1]
if node.name == self.function_name or node.name in self.auxiliary_function_names:
if node.name == self.function_name or node.name in self.auxillary_function_names:
return None # Remove the re-definition of the function and its dependencies from the test generation code
elif node.name.startswith("test_"):
i = len(node.body) - 1
@ -85,24 +145,39 @@ class InjectPerfAndLogging(ast.NodeTransformer):
while j >= 0:
with_line_node = line_node.body[j]
for with_node in ast.walk(with_line_node):
if self.is_target_function_node(with_node):
line_node.body[j:j+1] = self.update_line_node(with_node, node.name, str(i) + "_" + str(j))
line_node.body[j : j + 1] = self.update_line_node(
with_node, node.name, str(i) + "_" + str(j)
)
break
j -= 1
else:
for test_node in ast.walk(line_node):
if self.is_target_function_node(test_node):
node.body[i:i + 1] = self.update_line_node(test_node, node.name, str(i))
node.body[i : i + 1] = self.update_line_node(
test_node, node.name, str(i)
)
break
i -= 1
if self.test_framework == "unittest":
node.decorator_list.append(
ast.Call(
func=ast.Attribute(
value=ast.Name(id="timeout_decorator", ctx=ast.Load()),
attr="timeout",
ctx=ast.Load(),
),
args=[ast.Constant(value=self.individual_test_timeout)],
keywords=[],
)
)
return node
def inject_logging_code(test_code, iteration=0):
logging_code = f"""
import pickle
def ___log_test_values(values, duration, test_name):
def _log__test__values(values, duration, test_name):
with open(f'/tmp/test_return_values_{iteration}.bin', 'ab') as f:
return_bytes = pickle.dumps(values)
_test_name = f"{{test_name}}".encode("ascii")

View file

@ -22,6 +22,7 @@ class args:
file: Optional[str] = None
function: Optional[str] = None
tests_root: Optional[str] = "tests"
test_framework: Optional[str] = "pytest"
MAX_TEST_RUN_ITERATIONS = 5
@ -71,7 +72,7 @@ def main() -> None:
source_code_being_tested=code_to_optimize,
function_name=function_name,
module_path=module_path,
function_dependencies=function_dependencies
function_dependencies=function_dependencies,
)
print(new_tests)
generated_tests_path = get_test_file_path(args.tests_root, function_name, 0)

View file

@ -3,17 +3,33 @@ import subprocess
def run_tests(
    test_path,
    test_framework: str,
    cwd: str = None,
    test_env=None,
    pytest_timeout: int = None,
    verbose: bool = False,
):
    """Run the tests at ``test_path`` in a subprocess and capture the output.

    Args:
        test_path: Path (or test target) handed to the test runner.
        test_framework: ``"pytest"`` or ``"unittest"``.
        cwd: Working directory for the subprocess.
        test_env: Environment mapping for the subprocess.
        pytest_timeout: Value for pytest's ``--timeout`` option; the option
            is omitted when this is ``None``. Only used for pytest.
        verbose: Adds ``-v`` to the unittest command line. Only used for
            unittest.

    Returns:
        A ``(stdout, stderr)`` tuple of the runner's output decoded as UTF-8.

    Raises:
        ValueError: If ``test_framework`` is not a supported framework.
    """
    # Validate with a real exception rather than ``assert``: assertions are
    # stripped under ``python -O`` and would make the ValueError unreachable.
    if test_framework == "pytest":
        command = ["pytest", test_path, "-q"]
        if pytest_timeout is not None:
            # Avoid sending a literal "--timeout=None" to pytest.
            command.append(f"--timeout={pytest_timeout}")
    elif test_framework == "unittest":
        command = ["python", "-m", "unittest"] + (["-v"] if verbose else []) + [test_path]
    else:
        raise ValueError("Invalid test framework, we only support Pytest and Unittest currently.")

    results = subprocess.run(
        command,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        cwd=cwd,
        env=test_env,
    )
    stdout = results.stdout.decode("utf-8")
    stderr = results.stderr.decode("utf-8")
    return stdout, stderr

View file

@ -18,10 +18,10 @@ def generate_tests(source_code_being_tested, function_name, module_path, functio
module_node = ast.parse(generated_test_source)
print("REGRESSION TESTS SOURCE GENERATED", generated_test_source)
auxiliary_function_names = [
definition[0] for definition in function_dependencies
]
new_module_node = InjectPerfAndLogging(function_name, auxiliary_function_names=auxiliary_function_names).visit(module_node)
auxiliary_function_names = [definition[0] for definition in function_dependencies]
new_module_node = InjectPerfAndLogging(
function_name, auxiliary_function_names=auxiliary_function_names
).visit(module_node)
new_module_node.body = [
ast.Import(names=[ast.alias(name="time")]),
ast.ImportFrom(