498 lines
22 KiB
Python
498 lines
22 KiB
Python
from __future__ import annotations
|
|
|
|
import os
|
|
import re
|
|
import sqlite3
|
|
import subprocess
|
|
from collections import defaultdict
|
|
from pathlib import Path
|
|
from typing import TYPE_CHECKING
|
|
|
|
import dill as pickle
|
|
from junitparser.xunit2 import JUnitXml
|
|
from lxml.etree import XMLParser, parse
|
|
|
|
from codeflash.cli_cmds.console import DEBUG_MODE, console, logger
|
|
from codeflash.code_utils.code_utils import (
|
|
file_name_from_test_module_name,
|
|
file_path_from_module_name,
|
|
get_run_tmp_file,
|
|
module_name_from_file_path,
|
|
)
|
|
from codeflash.discovery.discover_unit_tests import discover_parameters_unittest
|
|
from codeflash.models.models import CoverageData, TestFiles
|
|
from codeflash.verification.test_results import FunctionTestInvocation, InvocationId, TestResults, VerificationType
|
|
|
|
if TYPE_CHECKING:
|
|
import subprocess
|
|
|
|
from codeflash.models.models import CodeOptimizationContext
|
|
from codeflash.verification.verification_utils import TestConfig
|
|
|
|
|
|
def parse_func(file_path: Path) -> XMLParser:
|
|
"""Parse the XML file with lxml.etree.XMLParser as the backend."""
|
|
xml_parser = XMLParser(huge_tree=True)
|
|
return parse(file_path, xml_parser)
|
|
|
|
|
|
def parse_test_return_values_bin(file_location: Path, test_files: TestFiles, test_config: TestConfig) -> TestResults:
|
|
test_results = TestResults()
|
|
if not file_location.exists():
|
|
logger.warning(f"No test results for {file_location} found.")
|
|
console.rule()
|
|
return test_results
|
|
|
|
with file_location.open("rb") as file:
|
|
while file:
|
|
len_next_bytes = file.read(4)
|
|
if not len_next_bytes:
|
|
return test_results
|
|
len_next = int.from_bytes(len_next_bytes, byteorder="big")
|
|
encoded_test_name = file.read(len_next).decode("ascii")
|
|
duration_bytes = file.read(8)
|
|
duration = int.from_bytes(duration_bytes, byteorder="big")
|
|
len_next_bytes = file.read(4)
|
|
if not len_next_bytes:
|
|
return test_results
|
|
len_next = int.from_bytes(len_next_bytes, byteorder="big")
|
|
try:
|
|
test_pickle_bin = file.read(len_next)
|
|
except Exception as e:
|
|
if DEBUG_MODE:
|
|
logger.exception(f"Failed to load pickle file. Exception: {e}")
|
|
return test_results
|
|
loop_index_bytes = file.read(8)
|
|
loop_index = int.from_bytes(loop_index_bytes, byteorder="big")
|
|
len_next_bytes = file.read(4)
|
|
len_next = int.from_bytes(len_next_bytes, byteorder="big")
|
|
invocation_id = file.read(len_next).decode("ascii")
|
|
|
|
invocation_id_object = InvocationId.from_str_id(encoded_test_name, invocation_id)
|
|
test_file_path = file_path_from_module_name(
|
|
invocation_id_object.test_module_path, test_config.tests_project_rootdir
|
|
)
|
|
|
|
test_type = test_files.get_test_type_by_instrumented_file_path(test_file_path)
|
|
try:
|
|
test_pickle = pickle.loads(test_pickle_bin) if loop_index == 1 else None
|
|
except Exception as e:
|
|
if DEBUG_MODE:
|
|
logger.exception(f"Failed to load pickle file for {encoded_test_name} Exception: {e}")
|
|
continue
|
|
assert test_type is not None, f"Test type not found for {test_file_path}"
|
|
test_results.add(
|
|
function_test_invocation=FunctionTestInvocation(
|
|
loop_index=loop_index,
|
|
id=invocation_id_object,
|
|
file_name=test_file_path,
|
|
did_pass=True,
|
|
runtime=duration,
|
|
test_framework=test_config.test_framework,
|
|
test_type=test_type,
|
|
return_value=test_pickle,
|
|
timed_out=False,
|
|
verification_type=VerificationType.FUNCTION_TO_OPTIMIZE,
|
|
)
|
|
)
|
|
return test_results
|
|
|
|
|
|
def parse_sqlite_test_results(sqlite_file_path: Path, test_files: TestFiles, test_config: TestConfig) -> TestResults:
|
|
test_results = TestResults()
|
|
if not sqlite_file_path.exists():
|
|
logger.warning(f"No test results for {sqlite_file_path} found.")
|
|
console.rule()
|
|
return test_results
|
|
try:
|
|
db = sqlite3.connect(sqlite_file_path)
|
|
cur = db.cursor()
|
|
data = cur.execute(
|
|
"SELECT test_module_path, test_class_name, test_function_name, "
|
|
"function_getting_tested, loop_index, iteration_id, runtime, return_value,verification_type FROM test_results"
|
|
).fetchall()
|
|
finally:
|
|
db.close()
|
|
for val in data:
|
|
try:
|
|
test_module_path = val[0]
|
|
test_file_path = file_path_from_module_name(test_module_path, test_config.tests_project_rootdir)
|
|
# TODO : this is because sqlite writes original file module path. Should make it consistent
|
|
test_type = test_files.get_test_type_by_original_file_path(test_file_path)
|
|
loop_index = val[4]
|
|
verification_type = val[8]
|
|
try:
|
|
ret_val = (pickle.loads(val[7]) if loop_index == 1 else None,)
|
|
except Exception:
|
|
continue
|
|
test_results.add(
|
|
function_test_invocation=FunctionTestInvocation(
|
|
loop_index=loop_index,
|
|
id=InvocationId(
|
|
test_module_path=val[0],
|
|
test_class_name=val[1],
|
|
test_function_name=val[2],
|
|
function_getting_tested=val[3],
|
|
iteration_id=val[5],
|
|
),
|
|
file_name=test_file_path,
|
|
did_pass=True,
|
|
runtime=val[6],
|
|
test_framework=test_config.test_framework,
|
|
test_type=test_type,
|
|
return_value=ret_val,
|
|
timed_out=False,
|
|
verification_type=VerificationType(verification_type) if verification_type else None,
|
|
)
|
|
)
|
|
except Exception:
|
|
logger.exception(f"Failed to parse sqlite test results for {sqlite_file_path}")
|
|
# Hardcoding the test result to True because the test did execute and we are only interested in the return values,
|
|
# the did_pass comes from the xml results file
|
|
return test_results
|
|
|
|
|
|
def parse_test_xml(
|
|
test_xml_file_path: Path,
|
|
test_files: TestFiles,
|
|
test_config: TestConfig,
|
|
run_result: subprocess.CompletedProcess | None = None,
|
|
unittest_loop_index: int | None = None,
|
|
) -> TestResults:
|
|
test_results = TestResults()
|
|
# Parse unittest output
|
|
if not test_xml_file_path.exists():
|
|
logger.warning(f"No test results for {test_xml_file_path} found.")
|
|
console.rule()
|
|
return test_results
|
|
try:
|
|
xml = JUnitXml.fromfile(str(test_xml_file_path), parse_func=parse_func)
|
|
except Exception as e:
|
|
logger.warning(f"Failed to parse {test_xml_file_path} as JUnitXml. Exception: {e}")
|
|
return test_results
|
|
base_dir = (
|
|
test_config.tests_project_rootdir if test_config.test_framework == "pytest" else test_config.project_root_path
|
|
)
|
|
for suite in xml:
|
|
for testcase in suite:
|
|
class_name = testcase.classname
|
|
test_file_name = suite._elem.attrib.get("file")
|
|
if (
|
|
test_file_name == f"unittest{os.sep}loader.py"
|
|
and class_name == "unittest.loader._FailedTest"
|
|
and suite.errors == 1
|
|
and suite.tests == 1
|
|
):
|
|
# This means that the test failed to load, so we don't want to crash on it
|
|
logger.info("Test failed to load, skipping it.")
|
|
if run_result is not None:
|
|
if isinstance(run_result.stdout, str) and isinstance(run_result.stderr, str):
|
|
logger.info(f"Test log - STDOUT : {run_result.stdout} \n STDERR : {run_result.stderr}")
|
|
else:
|
|
logger.info(
|
|
f"Test log - STDOUT : {run_result.stdout.decode()} \n STDERR : {run_result.stderr.decode()}"
|
|
)
|
|
return test_results
|
|
|
|
test_class_path = testcase.classname
|
|
try:
|
|
test_function = testcase.name.split("[", 1)[0] if "[" in testcase.name else testcase.name
|
|
except (AttributeError, TypeError) as e:
|
|
msg = (
|
|
f"Accessing testcase.name in parse_test_xml for testcase {testcase!r} in file"
|
|
f" {test_xml_file_path} has exception: {e}"
|
|
)
|
|
logger.exception(msg)
|
|
continue
|
|
if test_file_name is None:
|
|
if test_class_path:
|
|
# TODO : This might not be true if the test is organized under a class
|
|
test_file_path = file_name_from_test_module_name(test_class_path, base_dir)
|
|
if test_file_path is None:
|
|
logger.warning(f"Could not find the test for file name - {test_class_path} ")
|
|
continue
|
|
else:
|
|
test_file_path = file_path_from_module_name(test_function, base_dir)
|
|
else:
|
|
test_file_path = base_dir / test_file_name
|
|
assert test_file_path, f"Test file path not found for {test_file_name}"
|
|
|
|
if not test_file_path.exists():
|
|
logger.warning(f"Could not find the test for file name - {test_file_path} ")
|
|
continue
|
|
test_type = test_files.get_test_type_by_instrumented_file_path(test_file_path)
|
|
assert test_type is not None, f"Test type not found for {test_file_path}"
|
|
test_module_path = module_name_from_file_path(test_file_path, test_config.tests_project_rootdir)
|
|
result = testcase.is_passed # TODO: See for the cases of ERROR and SKIPPED
|
|
test_class = None
|
|
if class_name is not None and class_name.startswith(test_module_path):
|
|
test_class = class_name[len(test_module_path) + 1 :] # +1 for the dot, gets Unittest class name
|
|
|
|
loop_index = unittest_loop_index if unittest_loop_index is not None else 1
|
|
|
|
timed_out = False
|
|
if test_config.test_framework == "pytest":
|
|
loop_index = int(testcase.name.split("[ ")[-1][:-2]) if "[" in testcase.name else 1
|
|
if len(testcase.result) > 1:
|
|
logger.warning(f"!!!!!Multiple results for {testcase.name} in {test_xml_file_path}!!!")
|
|
if len(testcase.result) == 1:
|
|
message = testcase.result[0].message.lower()
|
|
if "failed: timeout >" in message:
|
|
timed_out = True
|
|
else:
|
|
if len(testcase.result) > 1:
|
|
logger.warning(f"!!!!!Multiple results for {testcase.name} in {test_xml_file_path}!!!")
|
|
if len(testcase.result) == 1:
|
|
message = testcase.result[0].message.lower()
|
|
if "timed out" in message:
|
|
timed_out = True
|
|
matches = re.findall(r"!######(.*?):(.*?)([^\.:]*?):(.*?):(.*?):(.*?)######!", testcase.system_out or "")
|
|
if not matches or not len(matches):
|
|
test_results.add(
|
|
FunctionTestInvocation(
|
|
loop_index=loop_index,
|
|
id=InvocationId(
|
|
test_module_path=test_module_path,
|
|
test_class_name=test_class,
|
|
test_function_name=test_function,
|
|
function_getting_tested="", # FIXME
|
|
iteration_id="",
|
|
),
|
|
file_name=test_file_path,
|
|
runtime=None,
|
|
test_framework=test_config.test_framework,
|
|
did_pass=result,
|
|
test_type=test_type,
|
|
return_value=None,
|
|
timed_out=timed_out,
|
|
)
|
|
)
|
|
|
|
else:
|
|
for match in matches:
|
|
split_val = match[5].split(":")
|
|
if len(split_val) > 1:
|
|
iteration_id = split_val[0]
|
|
runtime = int(split_val[1])
|
|
else:
|
|
iteration_id, runtime = split_val[0], None
|
|
test_results.add(
|
|
FunctionTestInvocation(
|
|
loop_index=int(match[4]),
|
|
id=InvocationId(
|
|
test_module_path=match[0],
|
|
test_class_name=None if match[1] == "" else match[1][:-1],
|
|
test_function_name=match[2],
|
|
function_getting_tested=match[3],
|
|
iteration_id=iteration_id,
|
|
),
|
|
file_name=test_file_path,
|
|
runtime=runtime,
|
|
test_framework=test_config.test_framework,
|
|
did_pass=result,
|
|
test_type=test_type,
|
|
return_value=None,
|
|
timed_out=timed_out,
|
|
)
|
|
)
|
|
|
|
if not test_results:
|
|
logger.info(
|
|
f"Tests '{[test_file.original_file_path for test_file in test_files.test_files]}' failed to run, skipping"
|
|
)
|
|
if run_result is not None:
|
|
stdout, stderr = "", ""
|
|
try:
|
|
stdout = run_result.stdout.decode()
|
|
stderr = run_result.stderr.decode()
|
|
except AttributeError:
|
|
stdout = run_result.stderr
|
|
logger.debug(f"Test log - STDOUT : {stdout} \n STDERR : {stderr}")
|
|
return test_results
|
|
|
|
|
|
def merge_test_results(
|
|
xml_test_results: TestResults, bin_test_results: TestResults, test_framework: str
|
|
) -> TestResults:
|
|
merged_test_results = TestResults()
|
|
|
|
grouped_xml_results: defaultdict[str, TestResults] = defaultdict(TestResults)
|
|
grouped_bin_results: defaultdict[str, TestResults] = defaultdict(TestResults)
|
|
|
|
# This is done to match the right iteration_id which might not be available in the xml
|
|
for result in xml_test_results:
|
|
if test_framework == "pytest":
|
|
if result.id.test_function_name.endswith("]") and "[" in result.id.test_function_name: # parameterized test
|
|
test_function_name = result.id.test_function_name[: result.id.test_function_name.index("[")]
|
|
else:
|
|
test_function_name = result.id.test_function_name
|
|
|
|
if test_framework == "unittest":
|
|
test_function_name = result.id.test_function_name
|
|
is_parameterized, new_test_function_name, _ = discover_parameters_unittest(test_function_name)
|
|
if is_parameterized: # handle parameterized test
|
|
test_function_name = new_test_function_name
|
|
|
|
grouped_xml_results[
|
|
result.id.test_module_path
|
|
+ ":"
|
|
+ (result.id.test_class_name or "")
|
|
+ ":"
|
|
+ test_function_name
|
|
+ ":"
|
|
+ str(result.loop_index)
|
|
].add(result)
|
|
|
|
for result in bin_test_results:
|
|
grouped_bin_results[
|
|
result.id.test_module_path
|
|
+ ":"
|
|
+ (result.id.test_class_name or "")
|
|
+ ":"
|
|
+ result.id.test_function_name
|
|
+ ":"
|
|
+ str(result.loop_index)
|
|
].add(result)
|
|
|
|
for result_id in grouped_xml_results:
|
|
xml_results = grouped_xml_results[result_id]
|
|
bin_results = grouped_bin_results.get(result_id)
|
|
if not bin_results:
|
|
merged_test_results.merge(xml_results)
|
|
continue
|
|
|
|
if len(xml_results) == 1:
|
|
xml_result = xml_results[0]
|
|
# This means that we only have one FunctionTestInvocation for this test xml. Match them to the bin results
|
|
# Either a whole test function fails or passes.
|
|
for result_bin in bin_results:
|
|
merged_test_results.add(
|
|
FunctionTestInvocation(
|
|
loop_index=xml_result.loop_index,
|
|
id=result_bin.id,
|
|
file_name=xml_result.file_name,
|
|
runtime=result_bin.runtime,
|
|
test_framework=xml_result.test_framework,
|
|
did_pass=xml_result.did_pass,
|
|
test_type=xml_result.test_type,
|
|
return_value=result_bin.return_value,
|
|
timed_out=xml_result.timed_out,
|
|
verification_type=VerificationType(result_bin.verification_type),
|
|
)
|
|
)
|
|
elif xml_results.test_results[0].id.iteration_id is not None:
|
|
# This means that we have multiple iterations of the same test function
|
|
# We need to match the iteration_id to the bin results
|
|
for xml_result in xml_results.test_results:
|
|
try:
|
|
bin_result = bin_results.get_by_unique_invocation_loop_id(xml_result.unique_invocation_loop_id)
|
|
except AttributeError:
|
|
bin_result = None
|
|
if bin_result is None:
|
|
merged_test_results.add(xml_result)
|
|
continue
|
|
merged_test_results.add(
|
|
FunctionTestInvocation(
|
|
loop_index=xml_result.loop_index,
|
|
id=xml_result.id,
|
|
file_name=xml_result.file_name,
|
|
runtime=bin_result.runtime,
|
|
test_framework=xml_result.test_framework,
|
|
did_pass=bin_result.did_pass,
|
|
test_type=xml_result.test_type,
|
|
return_value=bin_result.return_value,
|
|
timed_out=xml_result.timed_out
|
|
if bin_result.runtime is None
|
|
else False, # If runtime was measured in the bin file, then the testcase did not time out
|
|
verification_type=VerificationType(bin_result.verification_type),
|
|
)
|
|
)
|
|
else:
|
|
# Should happen only if the xml did not have any test invocation id info
|
|
for i, bin_result in enumerate(bin_results.test_results):
|
|
try:
|
|
xml_result = xml_results.test_results[i]
|
|
except IndexError:
|
|
xml_result = None
|
|
if xml_result is None:
|
|
merged_test_results.add(bin_result)
|
|
continue
|
|
merged_test_results.add(
|
|
FunctionTestInvocation(
|
|
loop_index=bin_result.loop_index,
|
|
id=bin_result.id,
|
|
file_name=bin_result.file_name,
|
|
runtime=bin_result.runtime,
|
|
test_framework=bin_result.test_framework,
|
|
did_pass=bin_result.did_pass,
|
|
test_type=bin_result.test_type,
|
|
return_value=bin_result.return_value,
|
|
timed_out=xml_result.timed_out, # only the xml gets the timed_out flag
|
|
verification_type=VerificationType(bin_result.verification_type),
|
|
)
|
|
)
|
|
|
|
return merged_test_results
|
|
|
|
|
|
def parse_test_results(
|
|
test_xml_path: Path,
|
|
test_files: TestFiles,
|
|
test_config: TestConfig,
|
|
optimization_iteration: int,
|
|
function_name: str | None,
|
|
source_file: Path | None,
|
|
coverage_file: Path | None,
|
|
code_context: CodeOptimizationContext | None = None,
|
|
run_result: subprocess.CompletedProcess | None = None,
|
|
unittest_loop_index: int | None = None,
|
|
) -> tuple[TestResults, CoverageData | None]:
|
|
test_results_xml = parse_test_xml(
|
|
test_xml_path,
|
|
test_files=test_files,
|
|
test_config=test_config,
|
|
run_result=run_result,
|
|
unittest_loop_index=unittest_loop_index,
|
|
)
|
|
try:
|
|
bin_results_file = get_run_tmp_file(Path(f"test_return_values_{optimization_iteration}.bin"))
|
|
test_results_bin_file = (
|
|
parse_test_return_values_bin(bin_results_file, test_files=test_files, test_config=test_config)
|
|
if bin_results_file.exists()
|
|
else TestResults()
|
|
)
|
|
except AttributeError as e:
|
|
logger.exception(e)
|
|
test_results_bin_file = TestResults()
|
|
get_run_tmp_file(Path(f"test_return_values_{optimization_iteration}.bin")).unlink(missing_ok=True)
|
|
|
|
try:
|
|
sql_results_file = get_run_tmp_file(Path(f"test_return_values_{optimization_iteration}.sqlite"))
|
|
if sql_results_file.exists():
|
|
test_results_sqlite_file = parse_sqlite_test_results(
|
|
sqlite_file_path=sql_results_file, test_files=test_files, test_config=test_config
|
|
)
|
|
test_results_bin_file.merge(test_results_sqlite_file)
|
|
except AttributeError as e:
|
|
logger.exception(e)
|
|
|
|
get_run_tmp_file(Path(f"test_return_values_{optimization_iteration}.bin")).unlink(missing_ok=True)
|
|
|
|
get_run_tmp_file(Path("pytest_results.xml")).unlink(missing_ok=True)
|
|
get_run_tmp_file(Path("unittest_results.xml")).unlink(missing_ok=True)
|
|
get_run_tmp_file(Path(f"test_return_values_{optimization_iteration}.sqlite")).unlink(missing_ok=True)
|
|
results = merge_test_results(test_results_xml, test_results_bin_file, test_config.test_framework)
|
|
|
|
all_args = False
|
|
if coverage_file and coverage_file.exists() and source_file and code_context and function_name:
|
|
all_args = True
|
|
coverage = CoverageData.load_from_coverage_file(
|
|
coverage_file_path=coverage_file,
|
|
source_code_path=source_file,
|
|
code_context=code_context,
|
|
function_name=function_name,
|
|
)
|
|
coverage_file.unlink(missing_ok=True)
|
|
Path(".coverage").unlink(missing_ok=True)
|
|
coverage.log_coverage()
|
|
return results, coverage if all_args else None
|