Add unit tests for _benchmark_worker subprocess script

5 tests covering module-level argv parsing, project_root derivation,
benchmark plugin and trace decorator imports, and __main__ guard.
This commit is contained in:
Kevin Turcios 2026-04-23 02:31:38 -05:00
parent e2135e39b2
commit dd7d2db451

View file

@ -0,0 +1,60 @@
"""Tests for the benchmark execution subprocess worker."""
from __future__ import annotations
import importlib
import sys
from pathlib import Path
from unittest.mock import MagicMock, patch
import pytest
@pytest.fixture(autouse=True)
def _patch_argv(monkeypatch, tmp_path):
"""Ensure sys.argv has enough entries for module-level reads."""
monkeypatch.setattr(
sys,
"argv",
[
"_benchmark_worker.py",
str(tmp_path / "benchmarks"),
str(tmp_path / "tests"),
str(tmp_path / "trace.db"),
],
)
@pytest.fixture
def _worker_mod():
"""Import the benchmark worker module (needs patched argv)."""
mod_name = "codeflash_python.benchmarking._benchmark_worker"
if mod_name in sys.modules:
return importlib.reload(sys.modules[mod_name])
return importlib.import_module(mod_name)
class TestBenchmarkWorkerModule:
"""Tests for the benchmark worker module structure."""
def test_module_level_args_parsed(self, _worker_mod, tmp_path) -> None:
"""Module-level argv reads produce correct paths."""
assert str(tmp_path / "benchmarks") == _worker_mod.benchmarks_root
assert str(tmp_path / "tests") == _worker_mod.tests_root
assert str(tmp_path / "trace.db") == _worker_mod.trace_file
def test_project_root_is_cwd(self, _worker_mod) -> None:
"""project_root defaults to Path.cwd()."""
assert Path.cwd() == _worker_mod.project_root
def test_imports_benchmark_plugin(self, _worker_mod) -> None:
"""Module imports the codeflash benchmark plugin."""
assert hasattr(_worker_mod, "codeflash_benchmark_plugin")
def test_imports_codeflash_trace(self, _worker_mod) -> None:
"""Module imports the codeflash trace decorator."""
assert hasattr(_worker_mod, "codeflash_trace")
def test_not_main_skips_pytest(self, _worker_mod) -> None:
"""When not __main__, pytest.main is not called."""
assert _worker_mod.__name__ != "__main__"