Add unittest discovery and instrumentation
This commit is contained in:
parent
80dfb25d5d
commit
fd7002235b
5 changed files with 250 additions and 74 deletions
|
|
@ -1,31 +1,106 @@
|
|||
import subprocess
|
||||
import jedi
|
||||
import re
|
||||
import os
|
||||
from collections import defaultdict
|
||||
import re
|
||||
import subprocess
|
||||
import unittest
|
||||
from collections import defaultdict, namedtuple
|
||||
|
||||
import jedi
|
||||
|
||||
TestsInFile = namedtuple("TestsInFile", ["test_file", "test_function", "test_suite"])
|
||||
|
||||
|
||||
def discover_unit_tests(test_directory, project_root_path):
|
||||
def discover_unit_tests(test_directory, project_root_path, test_framework="pytest"):
|
||||
if test_framework == "pytest":
|
||||
return discover_tests_pytest(test_directory, project_root_path)
|
||||
elif test_framework == "unittest":
|
||||
return discover_tests_unittest(test_directory, project_root_path)
|
||||
|
||||
|
||||
def discover_tests_pytest(test_directory, project_root_path):
|
||||
pytest_result = subprocess.run(
|
||||
["pytest", f"{test_directory}", "--co", "-q"], stdout=subprocess.PIPE, cwd=project_root_path
|
||||
["pytest", f"{test_directory}", "--co", "-q"],
|
||||
stdout=subprocess.PIPE,
|
||||
cwd=test_directory,
|
||||
)
|
||||
tests = parse_tests(pytest_result.stdout.decode("utf-8"))
|
||||
file_to_test_map = defaultdict(list)
|
||||
function_to_test_map = defaultdict(list)
|
||||
|
||||
for test in tests:
|
||||
test_file, function = test.split("::")
|
||||
test_file = os.path.join(project_root_path, test_file)
|
||||
file_to_test_map[test_file].append(function)
|
||||
test_file_path = os.path.join(test_directory, test_file)
|
||||
if not os.path.exists(test_file_path):
|
||||
# Seeing that in some circumstances pytest also returns the cwd as part of the test file path
|
||||
one_level_up_test_directory = os.path.abspath(os.path.join(test_directory, ".."))
|
||||
test_file_path = os.path.join(one_level_up_test_directory, test_file)
|
||||
if not os.path.exists(test_file_path):
|
||||
raise ValueError(
|
||||
f"Pytest test discovery failed. Test file {test_file_path} does not exist."
|
||||
)
|
||||
file_to_test_map[test_file_path].append({"test_function": function})
|
||||
# Within these test files, find the project functions they are referring to and return their names/locations
|
||||
return process_test_files(file_to_test_map, project_root_path, test_framework="pytest")
|
||||
|
||||
|
||||
def discover_tests_unittest(test_directory, project_root_path):
|
||||
loader = unittest.TestLoader()
|
||||
tests = loader.discover(str(test_directory))
|
||||
file_to_test_map = defaultdict(list)
|
||||
for _test_suite in tests._tests:
|
||||
for test_suite_2 in _test_suite._tests:
|
||||
if not hasattr(test_suite_2, "_tests"):
|
||||
print("Didn't find tests for ", test_suite_2)
|
||||
continue
|
||||
for test in test_suite_2._tests:
|
||||
m = re.search("(.*)\s\((.*)\.(.*)\)", str(test))
|
||||
if not m:
|
||||
continue
|
||||
test_function, test_module, test_suite_name = (
|
||||
m.group(1),
|
||||
m.group(2),
|
||||
m.group(3),
|
||||
)
|
||||
test_module_path = test_module.replace(".", os.sep)
|
||||
test_module_path = os.path.join(str(test_directory), test_module_path) + ".py"
|
||||
if not os.path.exists(test_module_path):
|
||||
continue
|
||||
file_to_test_map[test_module_path].append(
|
||||
{"test_function": test_function, "test_suite_name": test_suite_name}
|
||||
)
|
||||
return process_test_files(file_to_test_map, project_root_path, test_framework="unittest")
|
||||
|
||||
|
||||
def process_test_files(file_to_test_map, project_root_path, test_framework="pytest"):
|
||||
function_to_test_map = defaultdict(list)
|
||||
jedi_project = jedi.Project(path=project_root_path)
|
||||
TestFunction = namedtuple("TestFunction", ["function_name", "test_suite_name"])
|
||||
|
||||
for test_file, functions in file_to_test_map.items():
|
||||
script = jedi.Script(path=test_file, project=jedi_project)
|
||||
test_functions = set()
|
||||
top_level_names = script.get_names()
|
||||
all_names = script.get_names(all_scopes=True, references=True)
|
||||
all_defs = script.get_names(all_scopes=True, definitions=True)
|
||||
|
||||
for name in top_level_names:
|
||||
if name.name in functions and name.type == "function":
|
||||
test_functions.add(name.name)
|
||||
if test_framework == "pytest":
|
||||
functions_to_search = [elem["test_function"] for elem in functions]
|
||||
if name.name in functions_to_search and name.type == "function":
|
||||
test_functions.add(TestFunction(name.name, None))
|
||||
if test_framework == "unittest":
|
||||
functions_to_search = [elem["test_function"] for elem in functions]
|
||||
test_suites = [elem["test_suite_name"] for elem in functions]
|
||||
if name.name in test_suites and name.type == "class":
|
||||
for def_name in all_defs:
|
||||
if (
|
||||
def_name.name in functions_to_search
|
||||
and def_name.type == "function"
|
||||
and def_name.full_name is not None
|
||||
and f".{name.name}." in def_name.full_name
|
||||
):
|
||||
test_functions.add(TestFunction(def_name.name, name.name))
|
||||
test_functions_list = list(test_functions)
|
||||
test_functions_raw = [elem.function_name for elem in test_functions_list]
|
||||
|
||||
for name in all_names:
|
||||
if name.full_name is None:
|
||||
continue
|
||||
|
|
@ -33,21 +108,30 @@ def discover_unit_tests(test_directory, project_root_path):
|
|||
if not m:
|
||||
continue
|
||||
scope = m.group(1)
|
||||
if scope in test_functions:
|
||||
definition = script.goto(
|
||||
line=name.line,
|
||||
column=name.column,
|
||||
follow_imports=True,
|
||||
follow_builtin_imports=False,
|
||||
)
|
||||
index = test_functions_raw.index(scope) if scope in test_functions_raw else -1
|
||||
if index >= 0:
|
||||
scope_test_function = test_functions_list[index].function_name
|
||||
scope_test_suite = test_functions_list[index].test_suite_name
|
||||
try:
|
||||
definition = script.goto(
|
||||
line=name.line,
|
||||
column=name.column,
|
||||
follow_imports=True,
|
||||
follow_builtin_imports=False,
|
||||
)
|
||||
except Exception as e:
|
||||
print(str(e))
|
||||
continue
|
||||
if definition and definition[0].type == "function":
|
||||
definition_path = str(definition[0].module_path)
|
||||
# The definition is part of this project and not defined within the original function
|
||||
if (
|
||||
definition_path.startswith(project_root_path + os.sep)
|
||||
definition_path.startswith(str(project_root_path) + os.sep)
|
||||
and definition[0].module_name != name.module_name
|
||||
):
|
||||
function_to_test_map[definition[0].full_name].append((test_file, scope))
|
||||
function_to_test_map[definition[0].full_name].append(
|
||||
TestsInFile(test_file, scope_test_function, scope_test_suite)
|
||||
)
|
||||
deduped_function_to_test_map = {}
|
||||
for function, tests in function_to_test_map.items():
|
||||
deduped_function_to_test_map[function] = list(set(tests))
|
||||
|
|
|
|||
|
|
@ -2,14 +2,24 @@ import ast
|
|||
|
||||
|
||||
class InjectPerfAndLogging(ast.NodeTransformer):
|
||||
def __init__(self, function_name, auxiliary_function_names):
|
||||
def __init__(
|
||||
self,
|
||||
function_name: str,
|
||||
auxillary_function_names,
|
||||
test_framework="pytest",
|
||||
test_timeout: int = 15,
|
||||
):
|
||||
self.function_name = function_name
|
||||
self.class_name = None
|
||||
self.only_function_name = self.function_name
|
||||
self.test_framework = test_framework
|
||||
self.individual_test_timeout = test_timeout
|
||||
if len(function_name.split(".")) > 1:
|
||||
self.class_name = function_name.split(".")[0]
|
||||
self.only_function_name = function_name.split(".")[-1]
|
||||
self.auxiliary_function_names = auxiliary_function_names # Other functional dependencies that were injected
|
||||
self.auxillary_function_names = (
|
||||
auxillary_function_names # Other functional dependencies that were injected
|
||||
)
|
||||
|
||||
def visit_ImportFrom(self, node: ast.ImportFrom):
|
||||
if any([name.name == self.function_name for name in node.names]):
|
||||
|
|
@ -17,10 +27,12 @@ class InjectPerfAndLogging(ast.NodeTransformer):
|
|||
|
||||
def visit_ClassDef(self, node: ast.ClassDef):
|
||||
# If the original class exists during testing, then remove it all together
|
||||
if not self.class_name:
|
||||
return node
|
||||
if node.name == self.class_name:
|
||||
return None # Remove the re-definition of the class and its dependencies from the test generation code
|
||||
for inner_node in ast.walk(node):
|
||||
if isinstance(inner_node, (ast.FunctionDef, ast.AsyncFunctionDef)):
|
||||
inner_node = self.visit_FunctionDef(inner_node)
|
||||
return node
|
||||
|
||||
def visit_Assert(self, node: ast.Assert):
|
||||
# TODO : This does not work yet
|
||||
|
|
@ -31,42 +43,90 @@ class InjectPerfAndLogging(ast.NodeTransformer):
|
|||
return None
|
||||
|
||||
def is_target_function_node(self, node):
|
||||
return isinstance(node, ast.Call) and ((hasattr(node.func,
|
||||
'id') and node.func.id == self.only_function_name) or
|
||||
(hasattr(node.func,
|
||||
'attr') and node.func.attr == self.only_function_name))
|
||||
return isinstance(node, ast.Call) and (
|
||||
(hasattr(node.func, "id") and node.func.id == self.only_function_name)
|
||||
or (hasattr(node.func, "attr") and node.func.attr == self.only_function_name)
|
||||
)
|
||||
|
||||
def update_line_node(self, test_node, node_name, index: str):
|
||||
return [
|
||||
ast.Expr(value=ast.Call(
|
||||
func=ast.Attribute(value=ast.Name(id='gc', ctx=ast.Load()), attr='disable', ctx=ast.Load()), args=[],
|
||||
keywords=[]), lineno=test_node.lineno, col_offset=test_node.col_offset),
|
||||
ast.Assign(targets=[ast.Name(id='counter', ctx=ast.Store())],
|
||||
value=ast.Call(func=ast.Attribute(
|
||||
value=ast.Name(id='time', ctx=ast.Load()),
|
||||
attr='perf_counter_ns', ctx=ast.Load()), args=[], keywords=[]), lineno=test_node.lineno + 1,
|
||||
col_offset=test_node.col_offset),
|
||||
ast.Assign(targets=[ast.Name(id='return_value', ctx=ast.Store())],
|
||||
value=test_node, lineno=test_node.lineno + 2, col_offset=test_node.col_offset),
|
||||
ast.Assign(targets=[ast.Name(id='duration', ctx=ast.Store())],
|
||||
value=ast.BinOp(
|
||||
left=ast.Call(func=ast.Attribute(value=ast.Name(id='time', ctx=ast.Load()),
|
||||
attr='perf_counter_ns', ctx=ast.Load()),
|
||||
args=[], keywords=[]),
|
||||
op=ast.Sub(), right=ast.Name(id='counter', ctx=ast.Load())),
|
||||
lineno=test_node.lineno + 3, col_offset=test_node.col_offset),
|
||||
ast.Expr(value=ast.Call(
|
||||
func=ast.Attribute(value=ast.Name(id='gc', ctx=ast.Load()), attr='enable', ctx=ast.Load()), args=[],
|
||||
keywords=[]), lineno=test_node.lineno + 4, col_offset=test_node.col_offset),
|
||||
ast.Expr(ast.Call(func=ast.Name(id="___log_test_values", ctx=ast.Load()),
|
||||
args=[ast.Name(id='return_value', ctx=ast.Load()),
|
||||
ast.Name(id='duration', ctx=ast.Load()),
|
||||
ast.Constant(value=self.only_function_name + "_" + node_name + "_" + index)],
|
||||
keywords=[]), lineno=test_node.lineno + 5, col_offset=test_node.col_offset)]
|
||||
return [
|
||||
ast.Expr(
|
||||
value=ast.Call(
|
||||
func=ast.Attribute(
|
||||
value=ast.Name(id="gc", ctx=ast.Load()), attr="disable", ctx=ast.Load()
|
||||
),
|
||||
args=[],
|
||||
keywords=[],
|
||||
),
|
||||
lineno=test_node.lineno,
|
||||
col_offset=test_node.col_offset,
|
||||
),
|
||||
ast.Assign(
|
||||
targets=[ast.Name(id="counter", ctx=ast.Store())],
|
||||
value=ast.Call(
|
||||
func=ast.Attribute(
|
||||
value=ast.Name(id="time", ctx=ast.Load()),
|
||||
attr="perf_counter_ns",
|
||||
ctx=ast.Load(),
|
||||
),
|
||||
args=[],
|
||||
keywords=[],
|
||||
),
|
||||
lineno=test_node.lineno + 1,
|
||||
col_offset=test_node.col_offset,
|
||||
),
|
||||
ast.Assign(
|
||||
targets=[ast.Name(id="return_value", ctx=ast.Store())],
|
||||
value=test_node,
|
||||
lineno=test_node.lineno + 2,
|
||||
col_offset=test_node.col_offset,
|
||||
),
|
||||
ast.Assign(
|
||||
targets=[ast.Name(id="duration", ctx=ast.Store())],
|
||||
value=ast.BinOp(
|
||||
left=ast.Call(
|
||||
func=ast.Attribute(
|
||||
value=ast.Name(id="time", ctx=ast.Load()),
|
||||
attr="perf_counter_ns",
|
||||
ctx=ast.Load(),
|
||||
),
|
||||
args=[],
|
||||
keywords=[],
|
||||
),
|
||||
op=ast.Sub(),
|
||||
right=ast.Name(id="counter", ctx=ast.Load()),
|
||||
),
|
||||
lineno=test_node.lineno + 3,
|
||||
col_offset=test_node.col_offset,
|
||||
),
|
||||
ast.Expr(
|
||||
value=ast.Call(
|
||||
func=ast.Attribute(
|
||||
value=ast.Name(id="gc", ctx=ast.Load()), attr="enable", ctx=ast.Load()
|
||||
),
|
||||
args=[],
|
||||
keywords=[],
|
||||
),
|
||||
lineno=test_node.lineno + 4,
|
||||
col_offset=test_node.col_offset,
|
||||
),
|
||||
ast.Expr(
|
||||
ast.Call(
|
||||
func=ast.Name(id="_log__test__values", ctx=ast.Load()),
|
||||
args=[
|
||||
ast.Name(id="return_value", ctx=ast.Load()),
|
||||
ast.Name(id="duration", ctx=ast.Load()),
|
||||
ast.Constant(value=self.only_function_name + "_" + node_name + "_" + index),
|
||||
],
|
||||
keywords=[],
|
||||
),
|
||||
lineno=test_node.lineno + 5,
|
||||
col_offset=test_node.col_offset,
|
||||
),
|
||||
]
|
||||
|
||||
def visit_FunctionDef(self, node: ast.FunctionDef):
|
||||
only_function_name = self.function_name.split(".")[-1]
|
||||
if node.name == self.function_name or node.name in self.auxiliary_function_names:
|
||||
if node.name == self.function_name or node.name in self.auxillary_function_names:
|
||||
return None # Remove the re-definition of the function and its dependencies from the test generation code
|
||||
elif node.name.startswith("test_"):
|
||||
i = len(node.body) - 1
|
||||
|
|
@ -85,24 +145,39 @@ class InjectPerfAndLogging(ast.NodeTransformer):
|
|||
while j >= 0:
|
||||
with_line_node = line_node.body[j]
|
||||
for with_node in ast.walk(with_line_node):
|
||||
|
||||
if self.is_target_function_node(with_node):
|
||||
line_node.body[j:j+1] = self.update_line_node(with_node, node.name, str(i) + "_" + str(j))
|
||||
line_node.body[j : j + 1] = self.update_line_node(
|
||||
with_node, node.name, str(i) + "_" + str(j)
|
||||
)
|
||||
break
|
||||
j -= 1
|
||||
else:
|
||||
for test_node in ast.walk(line_node):
|
||||
if self.is_target_function_node(test_node):
|
||||
node.body[i:i + 1] = self.update_line_node(test_node, node.name, str(i))
|
||||
node.body[i : i + 1] = self.update_line_node(
|
||||
test_node, node.name, str(i)
|
||||
)
|
||||
break
|
||||
i -= 1
|
||||
if self.test_framework == "unittest":
|
||||
node.decorator_list.append(
|
||||
ast.Call(
|
||||
func=ast.Attribute(
|
||||
value=ast.Name(id="timeout_decorator", ctx=ast.Load()),
|
||||
attr="timeout",
|
||||
ctx=ast.Load(),
|
||||
),
|
||||
args=[ast.Constant(value=self.individual_test_timeout)],
|
||||
keywords=[],
|
||||
)
|
||||
)
|
||||
return node
|
||||
|
||||
|
||||
def inject_logging_code(test_code, iteration=0):
|
||||
logging_code = f"""
|
||||
import pickle
|
||||
def ___log_test_values(values, duration, test_name):
|
||||
def _log__test__values(values, duration, test_name):
|
||||
with open(f'/tmp/test_return_values_{iteration}.bin', 'ab') as f:
|
||||
return_bytes = pickle.dumps(values)
|
||||
_test_name = f"{{test_name}}".encode("ascii")
|
||||
|
|
|
|||
|
|
@ -22,6 +22,7 @@ class args:
|
|||
file: Optional[str] = None
|
||||
function: Optional[str] = None
|
||||
tests_root: Optional[str] = "tests"
|
||||
test_framework: Optional[str] = "pytest"
|
||||
|
||||
|
||||
MAX_TEST_RUN_ITERATIONS = 5
|
||||
|
|
@ -71,7 +72,7 @@ def main() -> None:
|
|||
source_code_being_tested=code_to_optimize,
|
||||
function_name=function_name,
|
||||
module_path=module_path,
|
||||
function_dependencies=function_dependencies
|
||||
function_dependencies=function_dependencies,
|
||||
)
|
||||
print(new_tests)
|
||||
generated_tests_path = get_test_file_path(args.tests_root, function_name, 0)
|
||||
|
|
|
|||
|
|
@ -3,17 +3,33 @@ import subprocess
|
|||
|
||||
def run_tests(
|
||||
test_path,
|
||||
test_framework: str,
|
||||
cwd: str = None,
|
||||
test_env=None,
|
||||
pytest_timeout: int = None,
|
||||
verbose: bool = False,
|
||||
):
|
||||
pytest_results = subprocess.run(
|
||||
["pytest", test_path, "-q", f"--timeout={pytest_timeout}"],
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
cwd=cwd,
|
||||
env=test_env,
|
||||
)
|
||||
stdout = pytest_results.stdout.decode("utf-8")
|
||||
stderr = pytest_results.stderr.decode("utf-8")
|
||||
assert test_framework in ["pytest", "unittest"]
|
||||
if test_framework == "pytest":
|
||||
pytest_results = subprocess.run(
|
||||
["pytest", test_path, "-q", f"--timeout={pytest_timeout}"],
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
cwd=cwd,
|
||||
env=test_env,
|
||||
)
|
||||
stdout = pytest_results.stdout.decode("utf-8")
|
||||
stderr = pytest_results.stderr.decode("utf-8")
|
||||
elif test_framework == "unittest":
|
||||
unittest_results = subprocess.run(
|
||||
["python", "-m", "unittest"] + (["-v"] if verbose else []) + [test_path],
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
cwd=cwd,
|
||||
env=test_env,
|
||||
)
|
||||
stdout = unittest_results.stdout.decode("utf-8")
|
||||
stderr = unittest_results.stderr.decode("utf-8")
|
||||
else:
|
||||
raise ValueError("Invalid test framework, we only support Pytest and Unittest currently.")
|
||||
return stdout, stderr
|
||||
|
|
|
|||
|
|
@ -18,10 +18,10 @@ def generate_tests(source_code_being_tested, function_name, module_path, functio
|
|||
|
||||
module_node = ast.parse(generated_test_source)
|
||||
print("REGRESSION TESTS SOURCE GENERATED", generated_test_source)
|
||||
auxiliary_function_names = [
|
||||
definition[0] for definition in function_dependencies
|
||||
]
|
||||
new_module_node = InjectPerfAndLogging(function_name, auxiliary_function_names=auxiliary_function_names).visit(module_node)
|
||||
auxiliary_function_names = [definition[0] for definition in function_dependencies]
|
||||
new_module_node = InjectPerfAndLogging(
|
||||
function_name, auxiliary_function_names=auxiliary_function_names
|
||||
).visit(module_node)
|
||||
new_module_node.body = [
|
||||
ast.Import(names=[ast.alias(name="time")]),
|
||||
ast.ImportFrom(
|
||||
|
|
|
|||
Loading…
Reference in a new issue