"""Data models for test execution and results.""" from __future__ import annotations import logging from collections import Counter, defaultdict from pathlib import Path from typing import TYPE_CHECKING import attrs import libcst as cst from .._model import VerificationType from ..test_discovery.models import TestType if TYPE_CHECKING: from collections.abc import Iterator from ..benchmarking.models import BenchmarkKey from ..test_discovery.models import TestsInFile log = logging.getLogger(__name__) @attrs.frozen class InvocationId: """Identifies a specific test function invocation.""" test_module_path: str test_class_name: str | None test_function_name: str | None function_getting_tested: str iteration_id: str | None def id(self) -> str: """Return a unique string identifier for this invocation.""" class_prefix = ( f"{self.test_class_name}." if self.test_class_name else "" ) return ( f"{self.test_module_path}:{class_prefix}" f"{self.test_function_name}:" f"{self.function_getting_tested}:{self.iteration_id}" ) def test_fn_qualified_name(self) -> str: """Return *ClassName.test_function* or just *test_function*.""" if self.test_class_name: return f"{self.test_class_name}.{self.test_function_name}" return str(self.test_function_name) @staticmethod def find_func_in_class( class_node: cst.ClassDef, func_name: str, ) -> cst.FunctionDef | None: """Find a function definition inside a class node.""" for stmt in class_node.body.body: if ( isinstance(stmt, cst.FunctionDef) and stmt.name.value == func_name ): return stmt return None def get_src_code(self, test_path: Path) -> str | None: """Extract the source code of this test function from *test_path*.""" if not test_path.exists(): return None try: test_src = test_path.read_text(encoding="utf-8") module_node = cst.parse_module(test_src) except (cst.ParserSyntaxError, UnicodeDecodeError): return ( f"# Test: {self.test_function_name}\n" f"# File: {test_path.name}\n" f"# Testing function: {self.function_getting_tested}" ) if self.test_class_name: for stmt in module_node.body: if ( isinstance(stmt, cst.ClassDef) and stmt.name.value == self.test_class_name ): func_node = self.find_func_in_class( stmt, self.test_function_name or "", ) if func_node: return module_node.code_for_node( func_node, ).strip() return None for stmt in module_node.body: if ( isinstance(stmt, cst.FunctionDef) and stmt.name.value == self.test_function_name ): return module_node.code_for_node(stmt).strip() return None @staticmethod def from_str_id( string_id: str, iteration_id: str | None = None, ) -> InvocationId: """Parse an invocation id from its string form.""" components = string_id.split(":") if len(components) != 4: # noqa: PLR2004 msg = ( f"Expected 4 colon-separated components, " f"got {len(components)}: {string_id!r}" ) raise ValueError(msg) second_components = components[1].split(".") if len(second_components) == 1: test_class_name = None test_function_name = second_components[0] else: test_class_name = second_components[0] test_function_name = second_components[1] return InvocationId( test_module_path=components[0], test_class_name=test_class_name, test_function_name=test_function_name, function_getting_tested=components[2], iteration_id=(iteration_id or components[3]), ) @attrs.frozen class FunctionTestInvocation: """A single function invocation result from a test run.""" loop_index: int id: InvocationId file_name: Path = attrs.field(converter=Path) did_pass: bool runtime: int | None test_framework: str test_type: TestType return_value: object | None cpu_runtime: int timed_out: bool | None verification_type: str | None = VerificationType.FUNCTION_CALL stdout: str | None = None @property def unique_invocation_loop_id(self) -> str: """Return a unique id incorporating the loop index.""" return f"{self.loop_index}:{self.id.id()}" @attrs.define class TestResults: """Collection of test invocation results.""" __test__ = False test_results: list[FunctionTestInvocation] = attrs.Factory(list) test_result_idx: dict[str, int] = attrs.Factory(dict) perf_stdout: str | None = None test_failures: dict[str, str] | None = None def add( self, function_test_invocation: FunctionTestInvocation, ) -> None: """Add an invocation, skipping duplicates.""" uid = function_test_invocation.unique_invocation_loop_id if uid in self.test_result_idx: log.debug("Test result with id %s already exists, skipping", uid) return self.test_result_idx[uid] = len(self.test_results) self.test_results.append(function_test_invocation) def merge(self, other: TestResults) -> None: """Merge another *TestResults* into this one.""" offset = len(self.test_results) self.test_results.extend(other.test_results) for key, idx in other.test_result_idx.items(): if key in self.test_result_idx: msg = f"Duplicate test result id: {key}" raise ValueError(msg) self.test_result_idx[key] = idx + offset def get_by_unique_invocation_loop_id( self, uid: str, ) -> FunctionTestInvocation | None: """Look up an invocation by its unique loop id.""" try: return self.test_results[self.test_result_idx[uid]] except (IndexError, KeyError): return None def number_of_loops(self) -> int: """Return the maximum loop index across all results.""" if not self.test_results: return 0 return max(r.loop_index for r in self.test_results) def usable_runtime_data_by_test_case( self, ) -> dict[InvocationId, list[int]]: """Return runtimes grouped by invocation id (passing only).""" by_id: dict[InvocationId, list[int]] = {} for result in self.test_results: if result.did_pass and result.runtime is not None: by_id.setdefault(result.id, []).append(result.runtime) return by_id def total_passed_runtime(self) -> int: """Sum of minimum runtimes across all passing test cases. Each test case's runtime is the minimum across all loop iterations. Returns nanoseconds. """ return sum( min(runtimes) for runtimes in self.usable_runtime_data_by_test_case().values() ) def file_to_no_of_tests( self, test_functions_to_remove: list[str], ) -> Counter[Path]: """Count generated regression results per file, excluding *test_functions_to_remove*.""" counts: Counter[Path] = Counter() for result in self.test_results: if ( result.test_type == TestType.GENERATED_REGRESSION and result.id.test_function_name not in test_functions_to_remove ): counts[result.file_name] += 1 return counts def __iter__(self) -> Iterator[FunctionTestInvocation]: """Iterate over test invocation results.""" return iter(self.test_results) def __len__(self) -> int: """Return the number of test invocation results.""" return len(self.test_results) def __getitem__(self, index: int) -> FunctionTestInvocation: """Return the test invocation result at the given index.""" return self.test_results[index] def __bool__(self) -> bool: """Return True if there are any test results.""" return bool(self.test_results) def __contains__( self, value: object, ) -> bool: """Check if a test invocation result is in this collection.""" return value in self.test_results def get_all_unique_invocation_loop_ids(self) -> set[str]: """Return the set of all unique invocation loop ids.""" return { result.unique_invocation_loop_id for result in self.test_results } def get_test_pass_fail_report_by_type( self, ) -> dict[TestType, dict[str, int]]: """Count passed/failed tests grouped by test type.""" report: dict[TestType, dict[str, int]] = { tt: {"passed": 0, "failed": 0} for tt in TestType } for result in self.test_results: if result.loop_index != 1: continue if result.did_pass: report[result.test_type]["passed"] += 1 else: report[result.test_type]["failed"] += 1 return report def group_by_benchmarks( self, benchmark_keys: list[BenchmarkKey], benchmark_replay_test_dir: Path, project_root: Path, ) -> dict[BenchmarkKey, TestResults]: """Group replay test results by benchmark key. Each benchmark key maps to the :class:`TestResults` whose replay test module path starts with the expected prefix derived from the benchmark's module path. """ from ..test_discovery.linking import ( # noqa: PLC0415 module_name_from_file_path, ) test_results_by_benchmark: dict[BenchmarkKey, TestResults] = ( defaultdict(TestResults) ) benchmark_module_path: dict[BenchmarkKey, str] = {} for benchmark_key in benchmark_keys: benchmark_module_path[benchmark_key] = module_name_from_file_path( benchmark_replay_test_dir.resolve() / ( "test_" + benchmark_key.module_path.replace(".", "_") + "__replay_test_" ), project_root, ) for test_result in self.test_results: if test_result.test_type == TestType.REPLAY_TEST: for bk, mod_path in benchmark_module_path.items(): if test_result.id.test_module_path.startswith( mod_path, ): test_results_by_benchmark[bk].add(test_result) return test_results_by_benchmark @attrs.frozen class TestFile: """A test file ready for execution.""" __test__ = False original_file_path: Path = attrs.field(converter=Path) instrumented_behavior_file_path: Path | None = None benchmarking_file_path: Path | None = None test_type: TestType = TestType.EXISTING_UNIT_TEST tests_in_file: tuple[TestsInFile, ...] = () @attrs.define class TestFiles: """Collection of test files for a test run.""" __test__ = False test_files: list[TestFile] = attrs.Factory(list) def get_test_type_by_instrumented_file_path( self, path: Path, ) -> TestType | None: """Find the test type for an instrumented file path.""" resolved = path.resolve() for tf in self.test_files: if ( tf.instrumented_behavior_file_path and tf.instrumented_behavior_file_path.resolve() == resolved ): return tf.test_type if ( tf.benchmarking_file_path and tf.benchmarking_file_path.resolve() == resolved ): return tf.test_type return None def get_test_type_by_original_file_path( self, path: Path, ) -> TestType | None: """Find the test type for an original file path.""" resolved = path.resolve() for tf in self.test_files: if tf.original_file_path.resolve() == resolved: return tf.test_type return None @attrs.frozen class TestConfig: """Configuration for test execution.""" __test__ = False tests_project_rootdir: Path = attrs.field(converter=Path) test_framework: str = "pytest" pytest_cmd: str = "pytest" tests_root: str | Path = "tests" project_root_path: str | Path = "." use_cache: bool = True module_root: Path | None = None