Add support for parameterized tests

This commit is contained in:
ihitamandal 2024-01-29 15:01:27 -08:00
parent 647d66992b
commit 30dfc55d69
2 changed files with 38 additions and 5 deletions

View file

@ -0,0 +1,15 @@
from code_to_optimize.bubble_sort import sorter
import pytest
@pytest.mark.parametrize(
"input, expected_output",
[
([5, 4, 3, 2, 1, 0], [0, 1, 2, 3, 4, 5]),
([5.0, 4.0, 3.0, 2.0, 1.0, 0.0], [0.0, 1.0, 2.0, 3.0, 4.0, 5.0]),
(list(reversed(range(5000))), list(range(5000))),
],
)
def test_sort_parametrized(input, expected_output):
output = sorter(input)
assert output == expected_output

View file

@ -44,6 +44,7 @@ class TestsInFile:
class TestFunction:
function_name: str
test_suite_name: Optional[str]
parameters: Optional[str]
def discover_unit_tests(cfg: TestConfig) -> Dict[str, List[TestsInFile]]:
@ -78,7 +79,13 @@ def discover_tests_pytest(cfg: TestConfig) -> Dict[str, List[TestsInFile]]:
file_to_test_map = defaultdict(list)
for test in tests:
file_to_test_map[test.test_file].append({"test_function": test.test_function})
test_function = test.test_function
parameters = None
if "[" in test_function:
parameters = re.findall("\[(.*?)\]", test_function)[0]
file_to_test_map[test.test_file].append(
{"test_function": test.test_function, "parameters": parameters}
)
# Within these test files, find the project functions they are referring to and return their names/locations
return process_test_files(file_to_test_map, cfg)
@ -129,8 +136,16 @@ def process_test_files(
for name in top_level_names:
if test_framework == "pytest":
functions_to_search = [elem["test_function"] for elem in functions]
if name.name in functions_to_search and name.type == "function":
test_functions.add(TestFunction(name.name, None))
for function in functions_to_search:
if "[" in function:
function_name = re.split(r"\[|\]", function)[0]
parameters = re.split(r"\[|\]", function)[1]
if name.name == function_name and name.type == "function":
test_functions.add(TestFunction(name.name, None, parameters))
else:
if name.name == function and name.type == "function":
test_functions.add(TestFunction(name.name, None, None))
break
if test_framework == "unittest":
functions_to_search = [elem["test_function"] for elem in functions]
test_suites = [elem["test_suite_name"] for elem in functions]
@ -153,10 +168,11 @@ def process_test_files(
if not m:
continue
scope = m.group(1)
index = test_functions_raw.index(scope) if scope in test_functions_raw else -1
if index >= 0:
indices = [i for i, x in enumerate(test_functions_raw) if x == scope]
for index in indices:
scope_test_function = test_functions_list[index].function_name
scope_test_suite = test_functions_list[index].test_suite_name
scope_parameters = test_functions_list[index].parameters
try:
definition = script.goto(
line=name.line,
@ -174,6 +190,8 @@ def process_test_files(
definition_path.startswith(str(project_root_path) + os.sep)
and definition[0].module_name != name.module_name
):
if scope_parameters is not None:
scope_test_function += "[" + scope_parameters + "]"
function_to_test_map[definition[0].full_name].append(
TestsInFile(test_file, None, scope_test_function, scope_test_suite)
)