# codeflash/tests/scripts/end_to_end_test_utilities.py
from __future__ import annotations

import contextlib
import logging
import os
import pathlib
import re
import shutil
import subprocess
import time
from dataclasses import dataclass, field
from typing import Optional

try:
    import tomllib  # Python 3.11+
except ImportError:
    import tomli as tomllib  # fallback for older interpreters


@dataclass
class CoverageExpectation:
    """Expected coverage for one function, matched against the CoverageData block in stdout."""

    function_name: str
    expected_coverage: float = 100.0
    expected_lines: list[int] = field(default_factory=list)


@dataclass
class TestConfig:
    """Options for a single end-to-end codeflash run and the results it must produce."""

    # file_path may be omitted in trace mode, which runs against workload.py instead
    file_path: Optional[pathlib.Path] = None
    function_name: Optional[str] = None
    # Global count: "Discovered X existing unit tests and Y replay tests in Z.Zs at /path"
    expected_unit_tests_count: Optional[int] = None
    # Per-function count: "Discovered X existing unit test files, Y replay test files, and Z concolic..."
    expected_unit_test_files: Optional[int] = None
    min_improvement_x: float = 0.1
    trace_mode: bool = False
    coverage_expectations: list[CoverageExpectation] = field(default_factory=list)
    benchmarks_root: Optional[pathlib.Path] = None
    use_worktree: bool = False
    no_gen_tests: bool = False
    expected_acceptance_reason: Optional[str] = None  # "runtime", "throughput", or "concurrency"
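
# A construction sketch (file, function, and line numbers are hypothetical):
#
#     TestConfig(
#         file_path=pathlib.Path("bubble_sort.py"),
#         function_name="sorter",
#         min_improvement_x=0.3,
#         coverage_expectations=[
#             CoverageExpectation(function_name="sorter", expected_coverage=100.0, expected_lines=[2, 3, 4])
#         ],
#     )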


def clear_directory(directory_path: str | pathlib.Path) -> None:
    """Delete every file and subdirectory under the given directory.

    Retries must start from a clean tests root; leftover files would skew the
    count of functions to be tested.
    """
    dir_path = pathlib.Path(directory_path)
    if not dir_path.exists():
        print(f"The directory {directory_path} does not exist.")
        return
for item in dir_path.iterdir():
try:
if item.is_file() or item.is_symlink():
item.unlink() # Remove the file or symbolic link
elif item.is_dir():
shutil.rmtree(item) # Remove the subdirectory
except Exception as e:
print(f"Failed to delete {item}. Reason: {e}")


def validate_coverage(stdout: str, expectations: list[CoverageExpectation]) -> bool:
if not expectations:
return True
assert "CoverageData(" in stdout, "Failed to find CoverageData in stdout"
for expect in expectations:
pattern = rf"""
(?:main|dependent)_func_coverage=FunctionCoverage\(
\s+name='{expect.function_name}',
\s+coverage=([\d.]+),
\s+executed_lines=\[(.+?)\],
"""
coverage_match = re.search(pattern, stdout, re.VERBOSE)
assert coverage_match, f"Failed to find coverage data for {expect.function_name}"
coverage = float(coverage_match.group(1))
assert coverage == expect.expected_coverage, (
f"Coverage was {coverage} instead of {expect.expected_coverage} for function: {expect.function_name}"
)
executed_lines = list(map(int, coverage_match.group(2).split(", ")))
assert executed_lines == expect.expected_lines, (
f"Executed lines were {executed_lines} instead of {expect.expected_lines} for function: {expect.function_name}"
)
return True


def validate_no_gen_tests(stdout: str) -> bool:
    if "Generated '0' tests for" not in stdout:
        logging.error("Tests were generated even though --no-gen-tests was set")
        return False
    return True


def run_codeflash_command(
    cwd: pathlib.Path, config: TestConfig, expected_improvement_pct: int, expected_in_stdout: Optional[list[str]] = None
) -> bool:
logging.basicConfig(level=logging.INFO)
if config.trace_mode:
return run_trace_test(cwd, config, expected_improvement_pct)
path_to_file = cwd / config.file_path
file_contents = path_to_file.read_text("utf-8")
pytest_dir = cwd / "tests" / "pytest"
test_root = pytest_dir if pytest_dir.is_dir() else cwd / "tests"
    command = build_command(cwd, config, test_root, config.benchmarks_root)
env = os.environ.copy()
env["PYTHONIOENCODING"] = "utf-8"
process = subprocess.Popen(
command, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True, cwd=str(cwd), env=env, encoding="utf-8"
)
output = []
for line in process.stdout:
logging.info(line.strip())
output.append(line)
return_code = process.wait()
stdout = "".join(output)
validated = validate_output(stdout, return_code, expected_improvement_pct, config)
if not validated:
# Write original file contents back to file
path_to_file.write_text(file_contents, "utf-8")
logging.info("Codeflash run did not meet expected requirements for testing, reverting file changes.")
return False
    if expected_in_stdout:
        if not validate_stdout_in_candidate(stdout, expected_in_stdout):
            logging.error("Failed to find expected output in candidate output")
            return False
        logging.info("Success: Expected output found in candidate output")
    return validated


def build_command(
    cwd: pathlib.Path, config: TestConfig, test_root: pathlib.Path, benchmarks_root: pathlib.Path | None = None
) -> list[str]:
    repo_root = pathlib.Path(__file__).parent.parent.parent
    python_path = os.path.relpath(repo_root / "codeflash" / "main.py", cwd)
    base_command = ["uv", "run", "--no-project", python_path, "--file", str(config.file_path), "--no-pr"]
    if config.function_name:
        base_command.extend(["--function", config.function_name])
    # A JVM build file (pom.xml / build.gradle[.kts]) or a [tool.codeflash] section in
    # pyproject.toml (checked below) counts as existing codeflash config; don't override it
    has_codeflash_config = (
        (cwd / "pom.xml").exists() or (cwd / "build.gradle").exists() or (cwd / "build.gradle.kts").exists()
    )
if not has_codeflash_config:
pyproject_path = cwd / "pyproject.toml"
if pyproject_path.exists():
with contextlib.suppress(Exception), open(pyproject_path, "rb") as f:
pyproject_data = tomllib.load(f)
has_codeflash_config = "tool" in pyproject_data and "codeflash" in pyproject_data["tool"]
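                # A pyproject.toml counts as configured when it carries a section like
                # the following (keys illustrative):
                #
                #     [tool.codeflash]
                #     module-root = "."
                #     tests-root = "tests"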
# Only pass --tests-root and --module-root if they're not configured in config files
if not has_codeflash_config:
base_command.extend(["--tests-root", str(test_root), "--module-root", str(cwd)])
if benchmarks_root:
base_command.extend(["--benchmark", "--benchmarks-root", str(benchmarks_root)])
if config.use_worktree:
base_command.append("--worktree")
if config.no_gen_tests:
base_command.append("--no-gen-tests")
return base_command
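    # For a fully specified config the assembled command looks roughly like
    # (paths illustrative):
    #   ["uv", "run", "--no-project", "../codeflash/main.py", "--file", "bubble_sort.py",
    #    "--no-pr", "--function", "sorter", "--tests-root", "/proj/tests", "--module-root", "/proj"]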


def validate_output(stdout: str, return_code: int, expected_improvement_pct: int, config: TestConfig) -> bool:
if return_code != 0:
logging.error(f"Command returned exit code {return_code} instead of 0")
return False
if "⚡️ Optimization successful! 📄 " not in stdout:
logging.error("Failed to find performance improvement message")
return False
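    # The summary line looks like "📈 1,250% runtime improvement" (values illustrative);
    # the optional word before "improvement" is the acceptance reason.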
improvement_match = re.search(r"📈 ([\d,]+)% (?:(\w+) )?improvement", stdout)
if not improvement_match:
logging.error("Could not find improvement percentage in output")
return False
improvement_pct = int(improvement_match.group(1).replace(",", ""))
improvement_x = float(improvement_pct) / 100
print("Performance improvement:", improvement_pct, "; Performance improvement rate:", improvement_x)
if improvement_pct <= expected_improvement_pct:
logging.error(f"Performance improvement {improvement_pct}% not above {expected_improvement_pct}%")
return False
if improvement_x <= config.min_improvement_x:
logging.error(f"Performance improvement rate {improvement_x}x not above {config.min_improvement_x}x")
return False
if config.expected_acceptance_reason is not None:
actual_reason = improvement_match.group(2)
if not actual_reason:
logging.error("Could not find acceptance reason type in output")
return False
if actual_reason != config.expected_acceptance_reason:
logging.error(f"Expected acceptance reason '{config.expected_acceptance_reason}', got '{actual_reason}'")
return False
if config.expected_unit_tests_count is not None:
# Match the global test discovery message from optimizer.py which counts test invocations
# Format: "Discovered X existing unit tests and Y replay tests in Z.Zs at /path/to/tests"
unit_test_match = re.search(
r"Discovered (\d+) existing unit tests? and \d+ replay tests? in [\d.]+s at", stdout
)
if not unit_test_match:
logging.error("Could not find global unit test count")
return False
num_tests = int(unit_test_match.group(1))
if num_tests != config.expected_unit_tests_count:
logging.error(f"Expected {config.expected_unit_tests_count} global unit tests, found {num_tests}")
return False
if config.expected_unit_test_files is not None:
# Match the per-function discovery message from function_optimizer.py
# Format: "Discovered X existing unit test files, Y replay test files, and Z concolic..."
unit_test_files_match = re.search(r"Discovered (\d+) existing unit test files?", stdout)
if not unit_test_files_match:
logging.error("Could not find per-function unit test file count")
return False
num_test_files = int(unit_test_files_match.group(1))
if num_test_files != config.expected_unit_test_files:
logging.error(f"Expected {config.expected_unit_test_files} unit test files, found {num_test_files}")
return False
    if config.coverage_expectations:
        # validate_coverage raises AssertionError on any mismatch
        validate_coverage(stdout, config.coverage_expectations)
    if config.no_gen_tests and not validate_no_gen_tests(stdout):
        return False
logging.info(f"Success: Performance improvement is {improvement_pct}%")
return True


def validate_stdout_in_candidate(stdout: str, expected_in_stdout: list[str]) -> bool:
    # Limit the search to the best-candidate section of the log
    candidate_output = stdout[stdout.find("INFO Best candidate") : stdout.find("Best Candidate Explanation")]
    return all(expected in candidate_output for expected in expected_in_stdout)


def run_trace_test(cwd: pathlib.Path, config: TestConfig, expected_improvement_pct: int) -> bool:
pytest_dir = cwd / "tests" / "pytest"
test_root = pytest_dir if pytest_dir.is_dir() else cwd / "tests"
clear_directory(test_root)
command = ["uv", "run", "--no-project", "-m", "codeflash.main", "optimize", "workload.py"]
env = os.environ.copy()
env["PYTHONIOENCODING"] = "utf-8"
process = subprocess.Popen(
command, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True, cwd=str(cwd), env=env, encoding="utf-8"
)
output = []
for line in process.stdout:
logging.info(line.strip())
output.append(line)
return_code = process.wait()
stdout = "".join(output)
if return_code != 0:
logging.error(f"Tracer with optimization command returned exit code {return_code}")
return False
functions_traced = re.search(r"Traced (\d+) function calls successfully", stdout)
logging.info(functions_traced.groups() if functions_traced else "No functions traced")
if not functions_traced:
logging.error("Failed to find traced functions in output")
return False
    # The trace-mode workload is expected to record exactly 8 function calls
    if int(functions_traced.group(1)) != 8:
        logging.error(functions_traced.groups())
        logging.error("Expected 8 traced functions")
        return False
# Validate optimization results (from optimization phase)
return validate_output(stdout, return_code, expected_improvement_pct, config)


def run_with_retries(test_func, *args, **kwargs) -> int:
    """Run test_func up to MAX_RETRIES times; return 0 on success, 1 on failure (exit-code semantics)."""
    max_retries = int(os.getenv("MAX_RETRIES", "3"))
    retry_delay = int(os.getenv("RETRY_DELAY", "5"))
    log = logging.getLogger()
    log.setLevel(logging.DEBUG)
    for attempt in range(1, max_retries + 1):
        logging.info(f"\n=== Attempt {attempt} of {max_retries} ===")
        if test_func(*args, **kwargs):
            logging.info(f"Test passed on attempt {attempt}")
            return 0
        logging.error(f"Test failed on attempt {attempt}")
        if attempt < max_retries:
            logging.info(f"Retrying in {retry_delay} seconds...")
            time.sleep(retry_delay)
        else:
            logging.error("Test failed after all retries")
            return 1
    return 1
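

if __name__ == "__main__":  # pragma: no cover
    # Minimal manual smoke run; the target project, file, and threshold below are hypothetical.
    _config = TestConfig(file_path=pathlib.Path("bubble_sort.py"), function_name="sorter")
    raise SystemExit(run_with_retries(run_codeflash_command, pathlib.Path("code_to_optimize"), _config, 5))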