use pytest as the execution engine for all tests (#951)

* first pass

restore

restore this too

Revert "first pass"

This reverts commit b507770b2c79cc948b33222d8877fb784bfe108a.

* continue

* Update uv.lock

* refresh lockfile

* bugfix

* temp

* fix these

* pytest changes

* formatting

* set up test env properly here too

* ruff

* make ruff happy

* Update e2e-bubblesort-unittest.yaml

* with pytest

* bugfix

* oops
This commit is contained in:
Kevin Turcios 2025-12-06 23:40:25 -05:00 committed by GitHub
parent c9e1483cda
commit 33437d39e3
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
10 changed files with 824 additions and 765 deletions

View file

@ -61,6 +61,7 @@ jobs:
- name: Install dependencies (CLI)
run: |
uv sync
uv add timeout_decorator
- name: Run Codeflash to optimize code
id: optimize_code

View file

@ -1,7 +1,6 @@
from __future__ import annotations
import ast
import platform
from dataclasses import dataclass
from pathlib import Path
from typing import TYPE_CHECKING
@ -329,17 +328,6 @@ class InjectPerfOnly(ast.NodeTransformer):
def visit_FunctionDef(self, node: ast.FunctionDef, test_class_name: str | None = None) -> ast.FunctionDef:
if node.name.startswith("test_"):
did_update = False
if self.test_framework == "unittest" and platform.system() != "Windows":
# Only add timeout decorator on non-Windows platforms
# Windows doesn't support SIGALRM signal required by timeout_decorator
node.decorator_list.append(
ast.Call(
func=ast.Name(id="timeout_decorator.timeout", ctx=ast.Load()),
args=[ast.Constant(value=15)],
keywords=[],
)
)
i = len(node.body) - 1
while i >= 0:
line_node = node.body[i]
@ -505,25 +493,6 @@ class AsyncCallInstrumenter(ast.NodeTransformer):
self.class_name = function.top_level_parent_name
def visit_ClassDef(self, node: ast.ClassDef) -> ast.ClassDef:
# Add timeout decorator for unittest test classes if needed
if self.test_framework == "unittest":
timeout_decorator = ast.Call(
func=ast.Name(id="timeout_decorator.timeout", ctx=ast.Load()),
args=[ast.Constant(value=15)],
keywords=[],
)
for item in node.body:
if (
isinstance(item, ast.FunctionDef)
and item.name.startswith("test_")
and not any(
isinstance(d, ast.Call)
and isinstance(d.func, ast.Name)
and d.func.id == "timeout_decorator.timeout"
for d in item.decorator_list
)
):
item.decorator_list.append(timeout_decorator)
return self.generic_visit(node)
def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> ast.AsyncFunctionDef:
@ -542,25 +511,6 @@ class AsyncCallInstrumenter(ast.NodeTransformer):
def _process_test_function(
self, node: ast.AsyncFunctionDef | ast.FunctionDef
) -> ast.AsyncFunctionDef | ast.FunctionDef:
# Optimize the search for decorator presence
if self.test_framework == "unittest":
found_timeout = False
for d in node.decorator_list:
# Avoid isinstance(d.func, ast.Name) if d is not ast.Call
if isinstance(d, ast.Call):
f = d.func
# Avoid attribute lookup if f is not ast.Name
if isinstance(f, ast.Name) and f.id == "timeout_decorator.timeout":
found_timeout = True
break
if not found_timeout:
timeout_decorator = ast.Call(
func=ast.Name(id="timeout_decorator.timeout", ctx=ast.Load()),
args=[ast.Constant(value=15)],
keywords=[],
)
node.decorator_list.append(timeout_decorator)
# Initialize counter for this test function
if node.name not in self.async_call_counter:
self.async_call_counter[node.name] = 0
@ -715,8 +665,6 @@ def inject_async_profiling_into_existing_test(
# Add necessary imports
new_imports = [ast.Import(names=[ast.alias(name="os")])]
if test_framework == "unittest":
new_imports.append(ast.Import(names=[ast.alias(name="timeout_decorator")]))
tree.body = [*new_imports, *tree.body]
return True, sort_imports(ast.unparse(tree), float_to_top=True)
@ -762,8 +710,6 @@ def inject_profiling_into_existing_test(
ast.Import(names=[ast.alias(name="dill", asname="pickle")]),
]
)
if test_framework == "unittest" and platform.system() != "Windows":
new_imports.append(ast.Import(names=[ast.alias(name="timeout_decorator")]))
additional_functions = [create_wrapper_function(mode)]
tree.body = [*new_imports, *additional_functions, *tree.body]

View file

@ -6,7 +6,6 @@ import os
import queue
import random
import subprocess
import time
import uuid
from collections import defaultdict
from pathlib import Path
@ -1641,58 +1640,35 @@ class FunctionOptimizer:
f"Test coverage is {coverage_results.coverage}%, which is below the required threshold of {COVERAGE_THRESHOLD}%."
)
if test_framework == "pytest":
with progress_bar("Running line profiler to identify performance bottlenecks..."):
line_profile_results = self.line_profiler_step(
code_context=code_context, original_helper_code=original_helper_code, candidate_index=0
with progress_bar("Running line profiler to identify performance bottlenecks..."):
line_profile_results = self.line_profiler_step(
code_context=code_context, original_helper_code=original_helper_code, candidate_index=0
)
console.rule()
with progress_bar("Running performance benchmarks..."):
if self.function_to_optimize.is_async:
from codeflash.code_utils.instrument_existing_tests import add_async_decorator_to_function
add_async_decorator_to_function(
self.function_to_optimize.file_path, self.function_to_optimize, TestingMode.PERFORMANCE
)
console.rule()
with progress_bar("Running performance benchmarks..."):
try:
benchmarking_results, _ = self.run_and_parse_tests(
testing_type=TestingMode.PERFORMANCE,
test_env=test_env,
test_files=self.test_files,
optimization_iteration=0,
testing_time=total_looping_time,
enable_coverage=False,
code_context=code_context,
)
finally:
if self.function_to_optimize.is_async:
from codeflash.code_utils.instrument_existing_tests import add_async_decorator_to_function
add_async_decorator_to_function(
self.function_to_optimize.file_path, self.function_to_optimize, TestingMode.PERFORMANCE
self.write_code_and_helpers(
self.function_to_optimize_source_code, original_helper_code, self.function_to_optimize.file_path
)
try:
benchmarking_results, _ = self.run_and_parse_tests(
testing_type=TestingMode.PERFORMANCE,
test_env=test_env,
test_files=self.test_files,
optimization_iteration=0,
testing_time=total_looping_time,
enable_coverage=False,
code_context=code_context,
)
finally:
if self.function_to_optimize.is_async:
self.write_code_and_helpers(
self.function_to_optimize_source_code,
original_helper_code,
self.function_to_optimize.file_path,
)
else:
benchmarking_results = TestResults()
start_time: float = time.time()
for i in range(100):
if i >= 5 and time.time() - start_time >= total_looping_time * 1.5:
# * 1.5 to give unittest a bit more time to run
break
test_env["CODEFLASH_LOOP_INDEX"] = str(i + 1)
with progress_bar("Running performance benchmarks..."):
unittest_loop_results, _ = self.run_and_parse_tests(
testing_type=TestingMode.PERFORMANCE,
test_env=test_env,
test_files=self.test_files,
optimization_iteration=0,
testing_time=total_looping_time,
enable_coverage=False,
code_context=code_context,
unittest_loop_index=i + 1,
)
benchmarking_results.merge(unittest_loop_results)
console.print(
TestResults.report_to_tree(
behavioral_results.get_test_pass_fail_report_by_type(), title="Overall test results for original code"
@ -1760,8 +1736,6 @@ class FunctionOptimizer:
original_helper_code: dict[Path, str],
file_path_to_helper_classes: dict[Path, set[str]],
) -> Result[OptimizedCandidateResult, str]:
assert (test_framework := self.args.test_framework) in {"pytest", "unittest"} # noqa: RUF018
with progress_bar("Testing optimization candidate"):
test_env = self.get_test_env(
codeflash_loop_index=0,
@ -1818,59 +1792,34 @@ class FunctionOptimizer:
logger.info(f"loading|Running performance tests for candidate {optimization_candidate_index}...")
if test_framework == "pytest":
# For async functions, instrument at definition site for performance benchmarking
if self.function_to_optimize.is_async:
from codeflash.code_utils.instrument_existing_tests import add_async_decorator_to_function
# For async functions, instrument at definition site for performance benchmarking
if self.function_to_optimize.is_async:
from codeflash.code_utils.instrument_existing_tests import add_async_decorator_to_function
add_async_decorator_to_function(
self.function_to_optimize.file_path, self.function_to_optimize, TestingMode.PERFORMANCE
)
try:
candidate_benchmarking_results, _ = self.run_and_parse_tests(
testing_type=TestingMode.PERFORMANCE,
test_env=test_env,
test_files=self.test_files,
optimization_iteration=optimization_candidate_index,
testing_time=total_looping_time,
enable_coverage=False,
)
finally:
# Restore original source if we instrumented it
if self.function_to_optimize.is_async:
self.write_code_and_helpers(
candidate_fto_code, candidate_helper_code, self.function_to_optimize.file_path
)
loop_count = (
max(all_loop_indices)
if (
all_loop_indices := {
result.loop_index for result in candidate_benchmarking_results.test_results
}
)
else 0
add_async_decorator_to_function(
self.function_to_optimize.file_path, self.function_to_optimize, TestingMode.PERFORMANCE
)
else:
candidate_benchmarking_results = TestResults()
start_time: float = time.time()
loop_count = 0
for i in range(100):
if i >= 5 and time.time() - start_time >= TOTAL_LOOPING_TIME_EFFECTIVE * 1.5:
# * 1.5 to give unittest a bit more time to run
break
test_env["CODEFLASH_LOOP_INDEX"] = str(i + 1)
unittest_loop_results, _cov = self.run_and_parse_tests(
testing_type=TestingMode.PERFORMANCE,
test_env=test_env,
test_files=self.test_files,
optimization_iteration=optimization_candidate_index,
testing_time=TOTAL_LOOPING_TIME_EFFECTIVE,
unittest_loop_index=i + 1,
try:
candidate_benchmarking_results, _ = self.run_and_parse_tests(
testing_type=TestingMode.PERFORMANCE,
test_env=test_env,
test_files=self.test_files,
optimization_iteration=optimization_candidate_index,
testing_time=total_looping_time,
enable_coverage=False,
)
finally:
# Restore original source if we instrumented it
if self.function_to_optimize.is_async:
self.write_code_and_helpers(
candidate_fto_code, candidate_helper_code, self.function_to_optimize.file_path
)
loop_count = i + 1
candidate_benchmarking_results.merge(unittest_loop_results)
loop_count = (
max(all_loop_indices)
if (all_loop_indices := {result.loop_index for result in candidate_benchmarking_results.test_results})
else 0
)
if (total_candidate_timing := candidate_benchmarking_results.total_passed_runtime()) == 0:
logger.warning("The overall test runtime of the optimized function is 0, couldn't run tests.")
@ -1920,7 +1869,6 @@ class FunctionOptimizer:
pytest_min_loops: int = 5,
pytest_max_loops: int = 250,
code_context: CodeOptimizationContext | None = None,
unittest_loop_index: int | None = None,
line_profiler_output_file: Path | None = None,
) -> tuple[TestResults | dict, CoverageData | None]:
coverage_database_file = None
@ -1933,7 +1881,6 @@ class FunctionOptimizer:
cwd=self.project_root,
test_env=test_env,
pytest_timeout=INDIVIDUAL_TESTCASE_TIMEOUT,
verbose=True,
enable_coverage=enable_coverage,
)
elif testing_type == TestingMode.LINE_PROFILE:
@ -1947,7 +1894,6 @@ class FunctionOptimizer:
pytest_min_loops=1,
pytest_max_loops=1,
test_framework=self.test_cfg.test_framework,
line_profiler_output_file=line_profiler_output_file,
)
elif testing_type == TestingMode.PERFORMANCE:
result_file_path, run_result = run_benchmarking_tests(
@ -1996,7 +1942,6 @@ class FunctionOptimizer:
test_config=self.test_cfg,
optimization_iteration=optimization_iteration,
run_result=run_result,
unittest_loop_index=unittest_loop_index,
function_name=self.function_to_optimize.function_name,
source_file=self.function_to_optimize.file_path,
code_context=code_context,

View file

@ -64,6 +64,56 @@ def calculate_function_throughput_from_test_results(test_results: TestResults, f
return function_throughput
def resolve_test_file_from_class_path(test_class_path: str, base_dir: Path) -> Path | None:
    """Resolve a test file path from pytest's dotted test class path.

    Handles JUnit-XML ``classname`` values where pytest includes parent
    directories that may already be part of ``base_dir``.

    Args:
        test_class_path: Full dotted class path from pytest
            (e.g. ``"project.tests.test_file.TestClass"``).
        base_dir: Base directory for tests (tests project root).

    Returns:
        Path to the test file if found, ``None`` otherwise.

    Examples:
        >>> # base_dir = "/path/to/tests"
        >>> # test_class_path = "code_to_optimize.tests.unittest.test_file.TestClass"
        >>> # Should find: /path/to/tests/unittest/test_file.py
    """
    segments = test_class_path.split(".")

    def _candidates():
        # 1. The full dotted path as given.
        yield test_class_path
        # 2. Without the trailing component (likely a class name, not a module).
        if len(segments) > 1:
            yield ".".join(segments[:-1])
        # 3. Progressively drop leading components — covers classnames that
        #    repeat directories already contained in base_dir — trying each
        #    suffix both with and without its final (class-name) component.
        for start in range(1, len(segments)):
            suffix = segments[start:]
            yield ".".join(suffix)
            if len(suffix) > 1:
                yield ".".join(suffix[:-1])

    for candidate in _candidates():
        resolved = file_name_from_test_module_name(candidate, base_dir)
        if resolved:
            return resolved
    return None
def parse_test_return_values_bin(file_location: Path, test_files: TestFiles, test_config: TestConfig) -> TestResults:
test_results = TestResults()
if not file_location.exists():
@ -198,7 +248,6 @@ def parse_test_xml(
test_files: TestFiles,
test_config: TestConfig,
run_result: subprocess.CompletedProcess | None = None,
unittest_loop_index: int | None = None,
) -> TestResults:
test_results = TestResults()
# Parse unittest output
@ -211,9 +260,8 @@ def parse_test_xml(
except Exception as e:
logger.warning(f"Failed to parse {test_xml_file_path} as JUnitXml. Exception: {e}")
return test_results
base_dir = (
test_config.tests_project_rootdir if test_config.test_framework == "pytest" else test_config.project_root_path
)
# Always use tests_project_rootdir since pytest is now the test runner for all frameworks
base_dir = test_config.tests_project_rootdir
for suite in xml:
for testcase in suite:
class_name = testcase.classname
@ -253,7 +301,8 @@ def parse_test_xml(
if test_file_name is None:
if test_class_path:
# TODO : This might not be true if the test is organized under a class
test_file_path = file_name_from_test_module_name(test_class_path, base_dir)
test_file_path = resolve_test_file_from_class_path(test_class_path, base_dir)
if test_file_path is None:
logger.warning(f"Could not find the test for file name - {test_class_path} ")
continue
@ -274,24 +323,15 @@ def parse_test_xml(
if class_name is not None and class_name.startswith(test_module_path):
test_class = class_name[len(test_module_path) + 1 :] # +1 for the dot, gets Unittest class name
loop_index = unittest_loop_index if unittest_loop_index is not None else 1
loop_index = int(testcase.name.split("[ ")[-1][:-2]) if testcase.name and "[" in testcase.name else 1
timed_out = False
if test_config.test_framework == "pytest":
loop_index = int(testcase.name.split("[ ")[-1][:-2]) if testcase.name and "[" in testcase.name else 1
if len(testcase.result) > 1:
logger.debug(f"!!!!!Multiple results for {testcase.name or '<None>'} in {test_xml_file_path}!!!")
if len(testcase.result) == 1:
message = testcase.result[0].message.lower()
if "failed: timeout >" in message:
timed_out = True
else:
if len(testcase.result) > 1:
logger.debug(f"!!!!!Multiple results for {testcase.name or '<None>'} in {test_xml_file_path}!!!")
if len(testcase.result) == 1:
message = testcase.result[0].message.lower()
if "timed out" in message:
timed_out = True
if len(testcase.result) > 1:
logger.debug(f"!!!!!Multiple results for {testcase.name or '<None>'} in {test_xml_file_path}!!!")
if len(testcase.result) == 1:
message = testcase.result[0].message.lower()
if "failed: timeout >" in message or "timed out" in message:
timed_out = True
sys_stdout = testcase.system_out or ""
begin_matches = list(matches_re_start.finditer(sys_stdout))
@ -523,14 +563,9 @@ def parse_test_results(
coverage_config_file: Path | None,
code_context: CodeOptimizationContext | None = None,
run_result: subprocess.CompletedProcess | None = None,
unittest_loop_index: int | None = None,
) -> tuple[TestResults, CoverageData | None]:
test_results_xml = parse_test_xml(
test_xml_path,
test_files=test_files,
test_config=test_config,
run_result=run_result,
unittest_loop_index=unittest_loop_index,
test_xml_path, test_files=test_files, test_config=test_config, run_result=run_result
)
try:
bin_results_file = get_run_tmp_file(Path(f"test_return_values_{optimization_iteration}.bin"))

View file

@ -36,15 +36,14 @@ def run_behavioral_tests(
*,
pytest_timeout: int | None = None,
pytest_cmd: str = "pytest",
verbose: bool = False,
pytest_target_runtime_seconds: int = TOTAL_LOOPING_TIME_EFFECTIVE,
enable_coverage: bool = False,
) -> tuple[Path, subprocess.CompletedProcess, Path | None, Path | None]:
if test_framework == "pytest":
if test_framework in {"pytest", "unittest"}:
test_files: list[str] = []
for file in test_paths.test_files:
if file.test_type == TestType.REPLAY_TEST:
# TODO: Does this work for unittest framework?
# Replay tests need specific test targeting because one file contains tests for multiple functions
test_files.extend(
[
str(file.instrumented_behavior_file_path) + "::" + test.test_function
@ -61,13 +60,14 @@ def run_behavioral_tests(
test_files = list(set(test_files)) # remove multiple calls in the same test function
common_pytest_args = [
"--capture=tee-sys",
f"--timeout={pytest_timeout}",
"-q",
"--codeflash_loops_scope=session",
"--codeflash_min_loops=1",
"--codeflash_max_loops=1",
f"--codeflash_seconds={pytest_target_runtime_seconds}",
]
if pytest_timeout is not None:
common_pytest_args.insert(1, f"--timeout={pytest_timeout}")
result_file_path = get_run_tmp_file(Path("pytest_results.xml"))
result_args = [f"--junitxml={result_file_path.as_posix()}", "-o", "junit_logging=all"]
@ -120,18 +120,6 @@ def run_behavioral_tests(
logger.debug(
f"""Result return code: {results.returncode}, {"Result stderr:" + str(results.stderr) if results.stderr else ""}"""
)
elif test_framework == "unittest":
if enable_coverage:
msg = "Coverage is not supported yet for unittest framework"
raise ValueError(msg)
test_env["CODEFLASH_LOOP_INDEX"] = "1"
test_files = [file.instrumented_behavior_file_path for file in test_paths.test_files]
result_file_path, results = run_unittest_tests(
verbose=verbose, test_file_paths=test_files, test_env=test_env, cwd=cwd
)
logger.debug(
f"""Result return code: {results.returncode}, {"Result stderr:" + str(results.stderr) if results.stderr else ""}"""
)
else:
msg = f"Unsupported test framework: {test_framework}"
raise ValueError(msg)
@ -152,42 +140,30 @@ def run_line_profile_tests(
test_framework: str,
*,
pytest_target_runtime_seconds: float = TOTAL_LOOPING_TIME_EFFECTIVE,
verbose: bool = False,
pytest_timeout: int | None = None,
pytest_min_loops: int = 5, # noqa: ARG001
pytest_max_loops: int = 100_000, # noqa: ARG001
line_profiler_output_file: Path | None = None,
) -> tuple[Path, subprocess.CompletedProcess]:
if test_framework == "pytest":
if test_framework in {"pytest", "unittest"}: # pytest runs both pytest and unittest tests
pytest_cmd_list = (
shlex.split(f"{SAFE_SYS_EXECUTABLE} -m pytest", posix=IS_POSIX)
if pytest_cmd == "pytest"
else shlex.split(pytest_cmd)
)
test_files: list[str] = []
for file in test_paths.test_files:
if file.test_type in {TestType.REPLAY_TEST, TestType.EXISTING_UNIT_TEST} and file.tests_in_file:
test_files.extend(
[
str(file.benchmarking_file_path)
+ "::"
+ (test.test_class + "::" if test.test_class else "")
+ (test.test_function.split("[", 1)[0] if "[" in test.test_function else test.test_function)
for test in file.tests_in_file
]
)
else:
test_files.append(str(file.benchmarking_file_path))
test_files = list(set(test_files)) # remove multiple calls in the same test function
# Always use file path - pytest discovers all tests including parametrized ones
test_files: list[str] = list(
{str(file.benchmarking_file_path) for file in test_paths.test_files}
) # remove multiple calls in the same test function
pytest_args = [
"--capture=tee-sys",
f"--timeout={pytest_timeout}",
"-q",
"--codeflash_loops_scope=session",
"--codeflash_min_loops=1",
"--codeflash_max_loops=1",
f"--codeflash_seconds={pytest_target_runtime_seconds}",
]
if pytest_timeout is not None:
pytest_args.insert(1, f"--timeout={pytest_timeout}")
result_file_path = get_run_tmp_file(Path("pytest_results.xml"))
result_args = [f"--junitxml={result_file_path.as_posix()}", "-o", "junit_logging=all"]
pytest_test_env = test_env.copy()
@ -200,34 +176,10 @@ def run_line_profile_tests(
env=pytest_test_env,
timeout=600, # TODO: Make this dynamic
)
elif test_framework == "unittest":
test_env["CODEFLASH_LOOP_INDEX"] = "1"
test_env["LINE_PROFILE"] = "1"
test_files: list[str] = []
for file in test_paths.test_files:
if file.test_type in {TestType.REPLAY_TEST, TestType.EXISTING_UNIT_TEST} and file.tests_in_file:
test_files.extend(
[
str(file.benchmarking_file_path)
+ "::"
+ (test.test_class + "::" if test.test_class else "")
+ (test.test_function.split("[", 1)[0] if "[" in test.test_function else test.test_function)
for test in file.tests_in_file
]
)
else:
test_files.append(str(file.benchmarking_file_path))
test_files = list(set(test_files)) # remove multiple calls in the same test function
line_profiler_output_file, results = run_unittest_tests(
verbose=verbose, test_file_paths=[Path(file) for file in test_files], test_env=test_env, cwd=cwd
)
logger.debug(
f"""Result return code: {results.returncode}, {"Result stderr:" + str(results.stderr) if results.stderr else ""}"""
)
else:
msg = f"Unsupported test framework: {test_framework}"
raise ValueError(msg)
return line_profiler_output_file, results
return result_file_path, results
def run_benchmarking_tests(
@ -238,41 +190,30 @@ def run_benchmarking_tests(
test_framework: str,
*,
pytest_target_runtime_seconds: float = TOTAL_LOOPING_TIME_EFFECTIVE,
verbose: bool = False,
pytest_timeout: int | None = None,
pytest_min_loops: int = 5,
pytest_max_loops: int = 100_000,
) -> tuple[Path, subprocess.CompletedProcess]:
if test_framework == "pytest":
if test_framework in {"pytest", "unittest"}: # pytest runs both pytest and unittest tests
pytest_cmd_list = (
shlex.split(f"{SAFE_SYS_EXECUTABLE} -m pytest", posix=IS_POSIX)
if pytest_cmd == "pytest"
else shlex.split(pytest_cmd)
)
test_files: list[str] = []
for file in test_paths.test_files:
if file.test_type in {TestType.REPLAY_TEST, TestType.EXISTING_UNIT_TEST} and file.tests_in_file:
test_files.extend(
[
str(file.benchmarking_file_path)
+ "::"
+ (test.test_class + "::" if test.test_class else "")
+ (test.test_function.split("[", 1)[0] if "[" in test.test_function else test.test_function)
for test in file.tests_in_file
]
)
else:
test_files.append(str(file.benchmarking_file_path))
test_files = list(set(test_files)) # remove multiple calls in the same test function
# Always use file path - pytest discovers all tests including parametrized ones
test_files: list[str] = list(
{str(file.benchmarking_file_path) for file in test_paths.test_files}
) # remove multiple calls in the same test function
pytest_args = [
"--capture=tee-sys",
f"--timeout={pytest_timeout}",
"-q",
"--codeflash_loops_scope=session",
f"--codeflash_min_loops={pytest_min_loops}",
f"--codeflash_max_loops={pytest_max_loops}",
f"--codeflash_seconds={pytest_target_runtime_seconds}",
]
if pytest_timeout is not None:
pytest_args.insert(1, f"--timeout={pytest_timeout}")
result_file_path = get_run_tmp_file(Path("pytest_results.xml"))
result_args = [f"--junitxml={result_file_path.as_posix()}", "-o", "junit_logging=all"]
pytest_test_env = test_env.copy()
@ -284,26 +225,7 @@ def run_benchmarking_tests(
env=pytest_test_env,
timeout=600, # TODO: Make this dynamic
)
elif test_framework == "unittest":
test_files = [file.benchmarking_file_path for file in test_paths.test_files]
result_file_path, results = run_unittest_tests(
verbose=verbose, test_file_paths=test_files, test_env=test_env, cwd=cwd
)
else:
msg = f"Unsupported test framework: {test_framework}"
raise ValueError(msg)
return result_file_path, results
def run_unittest_tests(
*, verbose: bool, test_file_paths: list[Path], test_env: dict[str, str], cwd: Path
) -> tuple[Path, subprocess.CompletedProcess]:
result_file_path = get_run_tmp_file(Path("unittest_results.xml"))
unittest_cmd_list = [SAFE_SYS_EXECUTABLE, "-m", "xmlrunner"]
log_level = ["-v"] if verbose else []
files = [str(file) for file in test_file_paths]
output_file = ["--output-file", str(result_file_path)]
results = execute_test_subprocess(
unittest_cmd_list + log_level + files + output_file, cwd=cwd, env=test_env, timeout=600
)
return result_file_path, results

View file

@ -21,10 +21,8 @@ dependencies = [
"gitpython>=3.1.31",
"libcst>=1.0.1",
"jedi>=0.19.1",
"timeout-decorator>=0.5.0",
"pytest-timeout>=2.1.0",
"tomlkit>=0.11.7",
"unittest-xml-reporting>=3.2.0",
"junitparser>=3.1.0",
"pydantic>=1.10.1",
"humanize>=4.0.0",

View file

@ -22,6 +22,7 @@ from codeflash.code_utils.code_utils import (
from codeflash.code_utils.concolic_utils import clean_concolic_tests
from codeflash.code_utils.coverage_utils import extract_dependent_function, generate_candidates, prepare_coverage_files
from codeflash.models.models import CodeStringsMarkdown
from codeflash.verification.parse_test_output import resolve_test_file_from_class_path
@pytest.fixture
@ -497,6 +498,134 @@ def test_partial_module_name2(base_dir: Path) -> None:
assert result == base_dir / "subdir" / "test_submodule.py"
def test_pytest_unittest_path_resolution_with_prefix(tmp_path: Path) -> None:
    """Resolution succeeds when pytest's classname repeats directories inside base_dir.

    Covers the case where base_dir is ``/path/to/tests`` but the reported
    classname includes the parent directory, e.g.
    ``"project.tests.unittest.test_file.TestClass"``.
    """
    # Directory layout: <tmp>/code_to_optimize/tests/unittest/
    suite_root = tmp_path / "code_to_optimize" / "tests"
    case_dir = suite_root / "unittest"
    case_dir.mkdir(parents=True, exist_ok=True)

    bubble_sort_file = case_dir / "test_bubble_sort.py"
    bubble_sort_file.touch()
    generated_file = case_dir / "test_sorter__unit_test_0.py"
    generated_file.touch()

    expectations = [
        # Case 1: classname carries the full path including "code_to_optimize.tests",
        # while base_dir is .../tests (not the project root).
        ("code_to_optimize.tests.unittest.test_bubble_sort.TestPigLatin", bubble_sort_file),
        # Case 2: generated test file, classname includes the test class.
        ("code_to_optimize.tests.unittest.test_sorter__unit_test_0.TestSorter", generated_file),
        # Case 3: bare module path with no class name.
        ("code_to_optimize.tests.unittest.test_bubble_sort", bubble_sort_file),
    ]
    for class_path, expected in expectations:
        assert resolve_test_file_from_class_path(class_path, suite_root) == expected
def test_pytest_unittest_multiple_prefix_levels(tmp_path: Path) -> None:
    """Prefix stripping works when several leading components must be removed."""
    # Layout: <tmp>/org/project/src/tests/unit/
    suite_root = tmp_path / "org" / "project" / "src" / "tests"
    unit_dir = suite_root / "unit"
    unit_dir.mkdir(parents=True, exist_ok=True)
    expected_file = unit_dir / "test_example.py"
    expected_file.touch()

    # pytest may report the module path with varying numbers of leading
    # directories depending on where base_dir sits; both must resolve.
    for class_path in (
        "org.project.src.tests.unit.test_example.TestClass",
        "project.src.tests.unit.test_example.TestClass",
    ):
        assert resolve_test_file_from_class_path(class_path, suite_root) == expected_file
def test_pytest_unittest_instrumented_files(tmp_path: Path) -> None:
    """Instrumented (``*__perfinstrumented``) test files resolve like any other module."""
    suite_dir = tmp_path / "tests" / "unittest"
    suite_dir.mkdir(parents=True, exist_ok=True)
    instrumented = suite_dir / "test_bubble_sort__perfinstrumented.py"
    instrumented.touch()

    # classname includes parent directories that overlap base_dir.
    resolved = resolve_test_file_from_class_path(
        "code_to_optimize.tests.unittest.test_bubble_sort__perfinstrumented.TestPigLatin",
        tmp_path / "tests",
    )
    assert resolved == instrumented
def test_pytest_unittest_nested_classes(tmp_path: Path) -> None:
    """A classname containing nested classes still maps back to the module's file."""
    suite_dir = tmp_path / "tests"
    suite_dir.mkdir(parents=True, exist_ok=True)
    module_file = suite_dir / "test_nested.py"
    module_file.touch()

    # Some unittest suites nest test classes; both class components must be stripped.
    resolved = resolve_test_file_from_class_path(
        "project.tests.test_nested.OuterClass.InnerClass", suite_dir
    )
    assert resolved == module_file
def test_pytest_unittest_no_match_returns_none(tmp_path: Path) -> None:
    """Prefix stripping must not invent a path: missing modules resolve to None."""
    suite_dir = tmp_path / "tests"
    suite_dir.mkdir(parents=True, exist_ok=True)

    # No matching file exists anywhere under base_dir.
    assert (
        resolve_test_file_from_class_path(
            "code_to_optimize.tests.unittest.nonexistent_test.TestClass", suite_dir
        )
        is None
    )
def test_pytest_unittest_single_component(tmp_path: Path) -> None:
    """A bare module name (no package prefix) still resolves under base_dir."""
    module_file = tmp_path / "test_simple.py"
    module_file.touch()

    # Plain module name, and module name with a trailing test class.
    assert file_name_from_test_module_name("test_simple", tmp_path) == module_file
    assert file_name_from_test_module_name("test_simple.TestClass", tmp_path) == module_file
def test_cleanup_paths(multiple_existing_and_non_existing_files: list[Path]) -> None:
cleanup_paths(multiple_existing_and_non_existing_files)
for file in multiple_existing_and_non_existing_files:

View file

@ -98,8 +98,7 @@ import time
import unittest
import dill as pickle"""
if platform.system() != "Windows":
imports += "\nimport timeout_decorator"
# timeout_decorator no longer used since pytest handles timeouts
if extra_imports:
imports += "\n" + extra_imports
return imports
@ -148,15 +147,14 @@ import time
import unittest
import dill as pickle"""
if platform.system() != "Windows":
imports += "\nimport timeout_decorator"
# timeout_decorator no longer used since pytest handles timeouts
imports += "\n\nfrom code_to_optimize.bubble_sort import sorter"
wrapper_func = codeflash_wrap_string
test_class_header = "class TestPigLatin(unittest.TestCase):"
test_decorator = " @timeout_decorator.timeout(15)" if platform.system() != "Windows" else ""
test_decorator = "" # pytest-timeout handles timeouts now, not timeout_decorator
expected = imports + "\n\n\n" + wrapper_func + "\n" + test_class_header + "\n\n"
if test_decorator:
@ -1585,7 +1583,6 @@ import time
import unittest
import dill as pickle
import timeout_decorator
from code_to_optimize.bubble_sort import sorter
@ -1595,7 +1592,6 @@ from code_to_optimize.bubble_sort import sorter
+ """
class TestPigLatin(unittest.TestCase):
@timeout_decorator.timeout(15)
def test_sort(self):
codeflash_loop_index = int(os.environ['CODEFLASH_LOOP_INDEX'])
codeflash_iteration = os.environ['CODEFLASH_TEST_ITERATION']
@ -1626,8 +1622,6 @@ import os
import time
import unittest
import timeout_decorator
from code_to_optimize.bubble_sort import sorter
@ -1636,7 +1630,6 @@ from code_to_optimize.bubble_sort import sorter
+ """
class TestPigLatin(unittest.TestCase):
@timeout_decorator.timeout(15)
def test_sort(self):
codeflash_loop_index = int(os.environ['CODEFLASH_LOOP_INDEX'])
input = [5, 4, 3, 2, 1, 0]
@ -1865,7 +1858,7 @@ class TestPigLatin(unittest.TestCase):
imports_behavior = build_expected_unittest_imports("from parameterized import parameterized")
imports_behavior += "\n\nfrom code_to_optimize.bubble_sort import sorter"
test_decorator_behavior = " @timeout_decorator.timeout(15)" if platform.system() != "Windows" else ""
test_decorator_behavior = "" # pytest-timeout handles timeouts now
test_class_behavior = """class TestPigLatin(unittest.TestCase):
@parameterized.expand([([5, 4, 3, 2, 1, 0], [0, 1, 2, 3, 4, 5]), ([5.0, 4.0, 3.0, 2.0, 1.0, 0.0], [0.0, 1.0, 2.0, 3.0, 4.0, 5.0]), (list(reversed(range(50))), list(range(50)))])
@ -1892,11 +1885,10 @@ import os
import time
import unittest
"""
if platform.system() != "Windows":
imports_perf += "\nimport timeout_decorator"
# pytest-timeout handles timeouts now, no timeout_decorator needed
imports_perf += "\nfrom parameterized import parameterized\n\nfrom code_to_optimize.bubble_sort import sorter"
test_decorator_perf = " @timeout_decorator.timeout(15)" if platform.system() != "Windows" else ""
test_decorator_perf = "" # pytest-timeout handles timeouts now
test_class_perf = """class TestPigLatin(unittest.TestCase):
@parameterized.expand([([5, 4, 3, 2, 1, 0], [0, 1, 2, 3, 4, 5]), ([5.0, 4.0, 3.0, 2.0, 1.0, 0.0], [0.0, 1.0, 2.0, 3.0, 4.0, 5.0]), (list(reversed(range(50))), list(range(50)))])
@ -2116,7 +2108,7 @@ class TestPigLatin(unittest.TestCase):
imports_behavior = build_expected_unittest_imports()
imports_behavior += "\n\nfrom code_to_optimize.bubble_sort import sorter"
test_decorator_behavior = " @timeout_decorator.timeout(15)" if platform.system() != "Windows" else ""
test_decorator_behavior = "" # pytest-timeout handles timeouts now
test_class_behavior = """class TestPigLatin(unittest.TestCase):
"""
@ -2148,13 +2140,10 @@ import os
import time
import unittest
"""
if platform.system() != "Windows":
imports_perf += "\nimport timeout_decorator"
imports_perf += "\n\nfrom code_to_optimize.bubble_sort import sorter"
else:
imports_perf += "\nfrom code_to_optimize.bubble_sort import sorter"
# pytest-timeout handles timeouts now, no timeout_decorator needed
imports_perf += "\nfrom code_to_optimize.bubble_sort import sorter"
test_decorator_perf = " @timeout_decorator.timeout(15)" if platform.system() != "Windows" else ""
test_decorator_perf = "" # pytest-timeout handles timeouts now
test_class_perf = """class TestPigLatin(unittest.TestCase):
"""
@ -2378,7 +2367,7 @@ class TestPigLatin(unittest.TestCase):
imports_behavior = build_expected_unittest_imports("from parameterized import parameterized")
imports_behavior += "\n\nfrom code_to_optimize.bubble_sort import sorter"
test_decorator_behavior = " @timeout_decorator.timeout(15)" if platform.system() != "Windows" else ""
test_decorator_behavior = "" # pytest-timeout handles timeouts now
test_class_behavior = """class TestPigLatin(unittest.TestCase):
@parameterized.expand([([5, 4, 3, 2, 1, 0], [0, 1, 2, 3, 4, 5]), ([5.0, 4.0, 3.0, 2.0, 1.0, 0.0], [0.0, 1.0, 2.0, 3.0, 4.0, 5.0]), (list(reversed(range(50))), list(range(50)))])
@ -2406,11 +2395,10 @@ import os
import time
import unittest
"""
if platform.system() != "Windows":
imports_perf += "\nimport timeout_decorator"
# pytest-timeout handles timeouts now, no timeout_decorator needed
imports_perf += "\nfrom parameterized import parameterized\n\nfrom code_to_optimize.bubble_sort import sorter"
test_decorator_perf = " @timeout_decorator.timeout(15)" if platform.system() != "Windows" else ""
test_decorator_perf = "" # pytest-timeout handles timeouts now
test_class_perf = """class TestPigLatin(unittest.TestCase):
@parameterized.expand([([5, 4, 3, 2, 1, 0], [0, 1, 2, 3, 4, 5]), ([5.0, 4.0, 3.0, 2.0, 1.0, 0.0], [0.0, 1.0, 2.0, 3.0, 4.0, 5.0]), (list(reversed(range(50))), list(range(50)))])
@ -3225,11 +3213,10 @@ import os
import time
import unittest
"""
if platform.system() != "Windows":
imports += "\nimport timeout_decorator"
# pytest-timeout handles timeouts now, no timeout_decorator needed
imports += "\nfrom parameterized import parameterized\n\nfrom code_to_optimize.sleeptime import accurate_sleepfunc"
test_decorator = " @timeout_decorator.timeout(15)" if platform.system() != "Windows" else ""
test_decorator = "" # pytest-timeout handles timeouts now
test_class = """class TestPigLatin(unittest.TestCase):
@parameterized.expand([(0.01, 0.01), (0.02, 0.02)])
@ -3307,6 +3294,8 @@ import unittest
test_env=test_env,
test_files=test_files,
optimization_iteration=0,
pytest_min_loops=1,
pytest_max_loops=1,
testing_time=0.1,
)

View file

@ -34,6 +34,14 @@ class TestUnittestRunnerSorter(unittest.TestCase):
tests_project_rootdir=cur_dir_path.parent,
)
test_env = os.environ.copy()
test_env["CODEFLASH_TEST_ITERATION"] = "0"
test_env["CODEFLASH_TRACER_DISABLE"] = "1"
if "PYTHONPATH" not in test_env:
test_env["PYTHONPATH"] = str(config.project_root_path)
else:
test_env["PYTHONPATH"] += os.pathsep + str(config.project_root_path)
with tempfile.TemporaryDirectory(dir=cur_dir_path) as temp_dir:
test_file_path = Path(temp_dir) / "test_xx.py"
test_files = TestFiles(
@ -44,7 +52,7 @@ class TestUnittestRunnerSorter(unittest.TestCase):
test_files,
test_framework=config.test_framework,
cwd=Path(config.project_root_path),
test_env=os.environ.copy(),
test_env=test_env,
)
results = parse_test_xml(result_file, test_files, config, process)
assert results[0].did_pass, "Test did not pass as expected"

990
uv.lock

File diff suppressed because it is too large Load diff