codeflash/tests/test_tracer.py
2026-01-29 01:39:48 -08:00

443 lines
16 KiB
Python

import contextlib
import dataclasses
import pickle
import sqlite3
import sys
import threading
import time
from collections.abc import Generator
from pathlib import Path
from typing import Any
from unittest.mock import patch
import pytest
from codeflash.code_utils.config_parser import parse_config_file
from codeflash.tracing.tracing_new_process import FakeCode, FakeFrame, Tracer
class TestFakeCode:
    """Checks for the attributes and ``repr`` of ``FakeCode``."""

    def test_fake_code_initialization(self) -> None:
        """Constructor arguments are exposed through the ``co_*`` attributes."""
        code = FakeCode("test.py", 10, "test_function")

        assert code.co_filename == "test.py"
        assert code.co_line == 10
        assert code.co_name == "test_function"
        # co_firstlineno is not taken from the constructor arguments; it is expected to be 0.
        assert code.co_firstlineno == 0

    def test_fake_code_repr(self) -> None:
        """``repr`` matches the repr of a ``(filename, line, name, None)`` tuple."""
        code = FakeCode("test.py", 10, "test_function")

        assert repr(code) == repr(("test.py", 10, "test_function", None))
class TestFakeFrame:
    """Checks for how ``FakeFrame`` stores its code object and prior frame."""

    def test_fake_frame_initialization(self) -> None:
        """A frame built without a prior frame has ``f_back`` None and empty ``f_locals``."""
        code = FakeCode("test.py", 10, "test_function")
        frame = FakeFrame(code, None)

        assert frame.f_code == code
        assert frame.f_back is None
        assert frame.f_locals == {}

    def test_fake_frame_with_prior(self) -> None:
        """A frame built with a prior frame exposes that frame as ``f_back``."""
        outer_code = FakeCode("test1.py", 5, "func1")
        inner_code = FakeCode("test2.py", 10, "func2")

        outer_frame = FakeFrame(outer_code, None)
        inner_frame = FakeFrame(inner_code, outer_frame)

        assert inner_frame.f_code == inner_code
        assert inner_frame.f_back == outer_frame
@dataclasses.dataclass
class TraceConfig:
    """Bundle of the values the tracer tests need to construct a tracer."""

    # Path of the trace file.
    trace_file: Path
    # Parsed configuration mapping handed to the tracer as its ``config``.
    trace_config: dict[str, Any]
    # Path of the result pickle file handed to the tracer.
    result_pickle_file_path: Path
    # Project root directory handed to the tracer.
    project_root: Path
    # Command string handed to the tracer.
    command: str
class TestTracer:
@pytest.fixture
def trace_config(self, tmp_path: Path) -> TraceConfig:
    """Build a ``TraceConfig`` backed by a temporary ``pyproject.toml``.

    A ``tests`` directory and a ``pyproject.toml`` with a ``[tool.codeflash]``
    section are created under ``tmp_path``; the file is parsed with
    ``parse_config_file`` and the parsed settings are bundled with the trace and
    result-pickle paths (both under ``tmp_path``).

    Args:
        tmp_path: pytest's per-test temporary directory.

    Returns:
        The populated ``TraceConfig``. The fixture returns the value directly
        (it does not yield), so the return annotation is the value type.
    """
    # Create a temporary directory structure
    tests_dir = tmp_path / "tests"
    tests_dir.mkdir(exist_ok=True)

    # Use the current working directory as module root so test files are included
    current_dir = Path.cwd()

    config_path = tmp_path / "pyproject.toml"
    config_path.write_text(
        f"""
[tool.codeflash]
module-root = "{current_dir.as_posix()}"
tests-root = "{tests_dir.as_posix()}"
test-framework = "pytest"
ignore-paths = []
""",
        encoding="utf-8",
    )

    # Only the parsed settings are used; the second element of the returned pair is not needed.
    config, _ = parse_config_file(config_path)

    return TraceConfig(
        trace_file=tmp_path / "trace_file.trace",
        trace_config=config,
        result_pickle_file_path=tmp_path / "replay_test.pkl",
        project_root=current_dir,
        command="pytest random",
    )
@pytest.fixture(autouse=True)
def reset_tracer_state(self) -> Generator[None, None, None]:
    """Clear ``Tracer.used_once`` both before and after each test (autouse)."""

    def _drop_used_once_flag() -> None:
        # Only delete the class attribute when it is actually present.
        if hasattr(Tracer, "used_once"):
            delattr(Tracer, "used_once")

    # Clean slate before the test body runs ...
    _drop_used_once_flag()
    yield
    # ... and again once the test has finished.
    _drop_used_once_flag()
def test_tracer_disabled_by_environment(self, trace_config: TraceConfig) -> None:
    """A tracer built while ``CODEFLASH_TRACER_DISABLE`` is set reports ``disable`` as True."""
    disabling_env = {"CODEFLASH_TRACER_DISABLE": "1"}

    # patch.dict puts the variable into os.environ only for the duration of the block.
    with patch.dict("os.environ", disabling_env):
        env_disabled_tracer = Tracer(
            command=trace_config.command,
            result_pickle_file_path=trace_config.result_pickle_file_path,
            project_root=trace_config.project_root,
            config=trace_config.trace_config,
        )
        assert env_disabled_tracer.disable is True
def test_tracer_disabled_with_existing_profiler(self, trace_config: TraceConfig) -> None:
    """Test that tracer is disabled when another profiler is running.

    A dummy profile function is installed with ``sys.setprofile`` before the
    tracer is constructed; the tracer is then expected to report
    ``disable is True``.
    """

    def dummy_profiler(_frame: object, _event: str, _arg: object) -> object:
        return dummy_profiler

    # Remember whichever profile hook was active so it can be put back afterwards,
    # instead of unconditionally clearing it and clobbering an outer profiler.
    previous_profiler = sys.getprofile()
    sys.setprofile(dummy_profiler)
    try:
        tracer = Tracer(
            config=trace_config.trace_config,
            project_root=trace_config.project_root,
            result_pickle_file_path=trace_config.result_pickle_file_path,
            command=trace_config.command,
        )
        assert tracer.disable is True
    finally:
        sys.setprofile(previous_profiler)
def test_tracer_initialization_normal(self, trace_config: TraceConfig) -> None:
    """Constructor options are stored on the tracer and a ``_db_lock`` is set up."""
    configured_tracer = Tracer(
        timeout=10,
        max_function_count=128,
        functions=["test_func"],
        command=trace_config.command,
        result_pickle_file_path=trace_config.result_pickle_file_path,
        project_root=trace_config.project_root,
        config=trace_config.trace_config,
    )

    # Not disabled, and every explicit option is reflected on the instance.
    assert configured_tracer.disable is False
    assert configured_tracer.functions == ["test_func"]
    assert configured_tracer.max_function_count == 128
    assert configured_tracer.timeout == 10

    # The lock attribute must both exist and be populated.
    assert hasattr(configured_tracer, "_db_lock")
    assert configured_tracer._db_lock is not None
def test_tracer_timeout_validation(self, trace_config: TraceConfig) -> None:
    """Zero and negative timeouts make the ``Tracer`` constructor raise ``AssertionError``."""
    # Checked in the same order as before: zero first, then a negative value.
    for invalid_timeout in (0, -5):
        with pytest.raises(AssertionError):
            Tracer(
                config=trace_config.trace_config,
                project_root=trace_config.project_root,
                result_pickle_file_path=trace_config.result_pickle_file_path,
                command=trace_config.command,
                timeout=invalid_timeout,
            )
def test_tracer_context_manager_disabled(self, trace_config: TraceConfig) -> None:
    """An explicitly disabled tracer can be entered and exited and stays disabled."""
    disabled_tracer = Tracer(
        disable=True,
        command=trace_config.command,
        result_pickle_file_path=trace_config.result_pickle_file_path,
        project_root=trace_config.project_root,
        config=trace_config.trace_config,
    )

    with disabled_tracer:
        pass

    # A disabled tracer does not create a trace file, and its output_file attribute
    # won't exist, so the disable flag itself is what gets checked here.
    assert disabled_tracer.disable is True
def test_tracer_function_filtering(self, trace_config: TraceConfig) -> None:
    """Test that tracer respects function filtering.

    The tracer is created with ``functions=["test_function"]``, two local
    functions are called under it, and the ``function_calls`` table of the
    resulting trace file (when present) is queried for both names.
    """
    if hasattr(Tracer, "used_once"):
        delattr(Tracer, "used_once")

    def test_function() -> int:
        return 42

    def other_function() -> int:
        return 24

    tracer = Tracer(
        config=trace_config.trace_config,
        project_root=trace_config.project_root,
        result_pickle_file_path=trace_config.result_pickle_file_path,
        command=trace_config.command,
        functions=["test_function"],
    )

    with tracer:
        test_function()
        other_function()

    if tracer.output_file.exists():
        # closing() guarantees the connection is released even if a query raises.
        with contextlib.closing(sqlite3.connect(tracer.output_file)) as con:
            cursor = con.cursor()
            cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='function_calls'")
            if cursor.fetchone():
                # NOTE(review): both result sets are fetched but never asserted on, so this
                # test only proves the queries run. Consider asserting on the rows once the
                # expected filtering outcome is confirmed.
                cursor.execute("SELECT function FROM function_calls WHERE function = 'test_function'")
                cursor.fetchall()
                cursor.execute("SELECT function FROM function_calls WHERE function = 'other_function'")
                cursor.fetchall()
def test_tracer_max_function_count(self, trace_config: TraceConfig) -> None:
    """``trace_count`` stays at or below ``max_function_count`` after five calls."""

    def counting_function(n: int) -> int:
        return n * 2

    capped_tracer = Tracer(
        max_function_count=3,
        command=trace_config.command,
        result_pickle_file_path=trace_config.result_pickle_file_path,
        project_root=trace_config.project_root,
        config=trace_config.trace_config,
    )

    # Call the function more often than the configured cap of 3.
    with capped_tracer:
        for value in range(5):
            counting_function(value)

    assert capped_tracer.trace_count <= 3, "Tracer should limit the number of traced functions to max_function_count"
def test_tracer_timeout_functionality(self, trace_config: TraceConfig) -> None:
    """A 0.1 s function runs to completion under a tracer with a 1 second timeout."""

    def slow_function() -> str:
        time.sleep(0.1)
        return "done"

    # timeout=1 is a one-second timeout, well above the 0.1 s the function sleeps.
    timed_tracer = Tracer(
        timeout=1,
        command=trace_config.command,
        result_pickle_file_path=trace_config.result_pickle_file_path,
        project_root=trace_config.project_root,
        config=trace_config.trace_config,
    )

    with timed_tracer:
        slow_function()
def test_tracer_threading_safety(self, trace_config: TraceConfig) -> None:
    """Three threads started under the tracer each get to append their result."""
    doubled = []

    def thread_function(n: int) -> None:
        doubled.append(n * 2)

    threaded_tracer = Tracer(
        command=trace_config.command,
        result_pickle_file_path=trace_config.result_pickle_file_path,
        project_root=trace_config.project_root,
        config=trace_config.trace_config,
    )

    with threaded_tracer:
        workers = []
        # Each worker is started as soon as it is created ...
        for index in range(3):
            worker = threading.Thread(target=thread_function, args=(index,))
            workers.append(worker)
            worker.start()
        # ... and all of them are joined before the tracer exits.
        for worker in workers:
            worker.join()

    assert len(doubled) == 3
def test_simulate_call(self, trace_config: TraceConfig) -> None:
    """``simulate_call`` accepts a name and returns without raising."""
    sim_tracer = Tracer(
        command=trace_config.command,
        result_pickle_file_path=trace_config.result_pickle_file_path,
        project_root=trace_config.project_root,
        config=trace_config.trace_config,
    )

    # No assertion: the test passes as long as the call does not raise.
    sim_tracer.simulate_call("test_simulation")
def test_simulate_cmd_complete(self, trace_config: TraceConfig) -> None:
    """``simulate_cmd_complete`` runs without raising after a ``simulate_call``."""
    sim_tracer = Tracer(
        command=trace_config.command,
        result_pickle_file_path=trace_config.result_pickle_file_path,
        project_root=trace_config.project_root,
        config=trace_config.trace_config,
    )

    # Simulate a call first, then signal that the command completed.
    sim_tracer.simulate_call("test")
    sim_tracer.simulate_cmd_complete()
def test_runctx_method(self, trace_config: TraceConfig) -> None:
    """``runctx`` executes a statement in the given namespaces and returns the tracer."""
    ctx_tracer = Tracer(
        command=trace_config.command,
        result_pickle_file_path=trace_config.result_pickle_file_path,
        project_root=trace_config.project_root,
        config=trace_config.trace_config,
    )

    namespace_globals = {"x": 10}
    namespace_locals = {}

    returned = ctx_tracer.runctx("y = x * 2", namespace_globals, namespace_locals)

    # The tracer itself comes back, and the assignment landed in the locals mapping.
    assert returned == ctx_tracer
    assert namespace_locals["y"] == 20
def test_tracer_handles_class_methods(self, trace_config: TraceConfig) -> None:
    """Test that tracer correctly handles class methods.

    An instance method, a classmethod and a staticmethod of a local class are
    called under the tracer. Their return values are checked, and, when the
    trace file exists, the ``function_calls`` table must list all three
    function names and the class name.
    """
    # Ensure tracer hasn't been used yet in this test
    if hasattr(Tracer, "used_once"):
        delattr(Tracer, "used_once")

    class TestClass:
        def instance_method(self) -> str:
            return "instance"

        @classmethod
        def class_method(cls) -> str:
            return "class"

        @staticmethod
        def static_method() -> str:
            return "static"

    tracer = Tracer(
        config=trace_config.trace_config,
        project_root=trace_config.project_root,
        result_pickle_file_path=trace_config.result_pickle_file_path,
        command=trace_config.command,
    )

    with tracer:
        obj = TestClass()
        instance_result = obj.instance_method()
        class_result = TestClass.class_method()
        static_result = TestClass.static_method()

    # The calls must still return their normal values while being traced.
    assert instance_result == "instance"
    assert class_result == "class"
    assert static_result == "static"

    if tracer.output_file.exists():
        # closing() releases the connection even when an assertion or pytest.fail fires.
        with contextlib.closing(sqlite3.connect(tracer.output_file)) as con:
            cursor = con.cursor()
            cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='function_calls'")
            if cursor.fetchone():
                # Query for all function calls
                cursor.execute("SELECT function, classname FROM function_calls")
                calls = cursor.fetchall()

                function_names = [call[0] for call in calls]
                class_names = [call[1] for call in calls if call[1] is not None]

                assert "instance_method" in function_names
                assert "class_method" in function_names
                assert "static_method" in function_names
                assert "TestClass" in class_names
            else:
                pytest.fail("No function_calls table found in trace file")
def test_tracer_handles_exceptions_gracefully(self, trace_config: TraceConfig) -> None:
    """A ``ValueError`` raised by traced code is suppressed and the test completes."""

    def failing_function() -> None:
        raise ValueError("Test exception")

    resilient_tracer = Tracer(
        command=trace_config.command,
        result_pickle_file_path=trace_config.result_pickle_file_path,
        project_root=trace_config.project_root,
        config=trace_config.trace_config,
    )

    # suppress() is the inner context, so the ValueError is swallowed before the tracer exits.
    with resilient_tracer, contextlib.suppress(ValueError):
        failing_function()
def test_tracer_with_complex_arguments(self, trace_config: TraceConfig) -> None:
    """Check that dict, nested-list and callable arguments are recorded by the tracer.

    After tracing one call, the result pickle file must exist and point at an
    existing replay test file. When the trace file exists, the arguments
    recorded for ``complex_function`` are unpickled and compared with the
    values that were passed in.
    """

    def complex_function(
        data_dict: dict[str, Any], nested_list: list[list[int]], func_arg: object = lambda x: x
    ) -> int:
        return len(data_dict) + len(nested_list)

    tracer = Tracer(
        config=trace_config.trace_config,
        project_root=trace_config.project_root,
        result_pickle_file_path=trace_config.result_pickle_file_path,
        command=trace_config.command,
    )

    expected_dict = {"key": "value", "nested": {"inner": "data"}}
    expected_list = [[1, 2], [3, 4], [5, 6]]
    expected_func = lambda x: x * 2

    with tracer:
        complex_function(expected_dict, expected_list, func_arg=expected_func)

    assert trace_config.result_pickle_file_path.exists()
    # Open the pickle through a context manager so the file handle is closed again.
    with trace_config.result_pickle_file_path.open("rb") as result_file:
        pickled = pickle.load(result_file)
    assert pickled["replay_test_file_path"].exists()

    if tracer.output_file.exists():
        # closing() releases the connection even when an assertion or pytest.fail fires.
        with contextlib.closing(sqlite3.connect(tracer.output_file)) as con:
            cursor = con.cursor()
            cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='function_calls'")
            if cursor.fetchone():
                cursor.execute("SELECT args FROM function_calls WHERE function = 'complex_function'")
                result = cursor.fetchone()
                assert result is not None, "Function complex_function should be traced"

                # Deserialize the arguments (data produced by this very test run).
                traced_args = pickle.loads(result[0])

                assert "data_dict" in traced_args
                assert "nested_list" in traced_args
                assert "func_arg" in traced_args

                assert traced_args["data_dict"] == expected_dict
                assert traced_args["nested_list"] == expected_list
                assert callable(traced_args["func_arg"])
                assert traced_args["func_arg"](2) == 4
                assert len(traced_args["nested_list"]) == 3
            else:
                pytest.fail("No function_calls table found in trace file")