fix unit test (hopefully)

This commit is contained in:
Saurabh Misra 2025-07-10 22:16:15 -07:00
parent 2cc39336aa
commit ca4191268a

View file

@ -10,8 +10,7 @@ from typing import Any
from unittest.mock import patch
import pytest
from codeflash.tracer import FakeCode, FakeFrame, Tracer
from codeflash.tracing.tracing_new_process import FakeCode, FakeFrame, Tracer
class TestFakeCode:
@ -54,7 +53,7 @@ class TestTracer:
temp_dir = Path(tempfile.mkdtemp())
tests_dir = temp_dir / "tests"
tests_dir.mkdir(exist_ok=True)
# Use the current working directory as module root so test files are included
current_dir = Path.cwd()
@ -69,6 +68,7 @@ ignore-paths = []
config_path = Path(f.name)
yield config_path
import shutil
shutil.rmtree(temp_dir, ignore_errors=True)
@pytest.fixture
@ -94,23 +94,18 @@ ignore-paths = []
def test_tracer_disabled_by_environment(self, temp_config_file: Path, temp_trace_file: Path) -> None:
"""Test that tracer is disabled when CODEFLASH_TRACER_DISABLE is set."""
with patch.dict("os.environ", {"CODEFLASH_TRACER_DISABLE": "1"}):
tracer = Tracer(
output=str(temp_trace_file),
config_file_path=temp_config_file
)
tracer = Tracer(output=str(temp_trace_file), config_file_path=temp_config_file)
assert tracer.disable is True
def test_tracer_disabled_with_existing_profiler(self, temp_config_file: Path, temp_trace_file: Path) -> None:
"""Test that tracer is disabled when another profiler is running."""
def dummy_profiler(_frame: object, _event: str, _arg: object) -> object:
return dummy_profiler
sys.setprofile(dummy_profiler)
try:
tracer = Tracer(
output=str(temp_trace_file),
config_file_path=temp_config_file
)
tracer = Tracer(output=str(temp_trace_file), config_file_path=temp_config_file)
assert tracer.disable is True
finally:
sys.setprofile(None)
@ -122,7 +117,7 @@ ignore-paths = []
functions=["test_func"],
max_function_count=128,
timeout=10,
config_file_path=temp_config_file
config_file_path=temp_config_file,
)
assert tracer.disable is False
@ -131,37 +126,23 @@ ignore-paths = []
assert tracer.max_function_count == 128
assert tracer.timeout == 10
assert hasattr(tracer, "_db_lock")
assert getattr(tracer, "_db_lock") is not None
assert tracer._db_lock is not None
def test_tracer_timeout_validation(self, temp_config_file: Path, temp_trace_file: Path) -> None:
with pytest.raises(AssertionError):
Tracer(
output=str(temp_trace_file),
timeout=0,
config_file_path=temp_config_file
)
Tracer(output=str(temp_trace_file), timeout=0, config_file_path=temp_config_file)
with pytest.raises(AssertionError):
Tracer(
output=str(temp_trace_file),
timeout=-5,
config_file_path=temp_config_file
)
Tracer(output=str(temp_trace_file), timeout=-5, config_file_path=temp_config_file)
def test_tracer_context_manager_disabled(self, temp_config_file: Path, temp_trace_file: Path) -> None:
tracer = Tracer(
output=str(temp_trace_file),
disable=True,
config_file_path=temp_config_file
)
tracer = Tracer(output=str(temp_trace_file), disable=True, config_file_path=temp_config_file)
with tracer:
pass
assert not temp_trace_file.exists()
def test_tracer_function_filtering(self, temp_config_file: Path, temp_trace_file: Path) -> None:
"""Test that tracer respects function filtering."""
if hasattr(Tracer, "used_once"):
@ -173,11 +154,7 @@ ignore-paths = []
def other_function() -> int:
return 24
tracer = Tracer(
output=str(temp_trace_file),
functions=["test_function"],
config_file_path=temp_config_file
)
tracer = Tracer(output=str(temp_trace_file), functions=["test_function"], config_file_path=temp_config_file)
with tracer:
test_function()
@ -197,21 +174,16 @@ ignore-paths = []
con.close()
def test_tracer_max_function_count(self, temp_config_file: Path, temp_trace_file: Path) -> None:
def counting_function(n: int) -> int:
return n * 2
tracer = Tracer(
output=str(temp_trace_file),
max_function_count=3,
config_file_path=temp_config_file
)
tracer = Tracer(output=str(temp_trace_file), max_function_count=3, config_file_path=temp_config_file)
with tracer:
for i in range(5):
counting_function(i)
assert tracer.trace_count <= 3, "Tracer should limit the number of traced functions to max_function_count"
def test_tracer_timeout_functionality(self, temp_config_file: Path, temp_trace_file: Path) -> None:
@ -222,7 +194,7 @@ ignore-paths = []
tracer = Tracer(
output=str(temp_trace_file),
timeout=1, # 1 second timeout
config_file_path=temp_config_file
config_file_path=temp_config_file,
)
with tracer:
@ -235,10 +207,7 @@ ignore-paths = []
def thread_function(n: int) -> None:
results.append(n * 2)
tracer = Tracer(
output=str(temp_trace_file),
config_file_path=temp_config_file
)
tracer = Tracer(output=str(temp_trace_file), config_file_path=temp_config_file)
with tracer:
threads = []
@ -254,29 +223,20 @@ ignore-paths = []
def test_simulate_call(self, temp_config_file: Path, temp_trace_file: Path) -> None:
"""Test the simulate_call method."""
tracer = Tracer(
output=str(temp_trace_file),
config_file_path=temp_config_file
)
tracer = Tracer(output=str(temp_trace_file), config_file_path=temp_config_file)
tracer.simulate_call("test_simulation")
def test_simulate_cmd_complete(self, temp_config_file: Path, temp_trace_file: Path) -> None:
"""Test the simulate_cmd_complete method."""
tracer = Tracer(
output=str(temp_trace_file),
config_file_path=temp_config_file
)
tracer = Tracer(output=str(temp_trace_file), config_file_path=temp_config_file)
tracer.simulate_call("test")
tracer.simulate_cmd_complete()
def test_runctx_method(self, temp_config_file: Path, temp_trace_file: Path) -> None:
"""Test the runctx method for executing code with tracing."""
tracer = Tracer(
output=str(temp_trace_file),
config_file_path=temp_config_file
)
tracer = Tracer(output=str(temp_trace_file), config_file_path=temp_config_file)
global_vars = {"x": 10}
local_vars = {}
@ -291,7 +251,7 @@ ignore-paths = []
# Ensure tracer hasn't been used yet in this test
if hasattr(Tracer, "used_once"):
delattr(Tracer, "used_once")
class TestClass:
def instance_method(self) -> str:
return "instance"
@ -304,32 +264,27 @@ ignore-paths = []
def static_method() -> str:
return "static"
tracer = Tracer(
output=str(temp_trace_file),
config_file_path=temp_config_file
)
tracer = Tracer(output=str(temp_trace_file), config_file_path=temp_config_file)
with tracer:
obj = TestClass()
instance_result = obj.instance_method()
class_result = TestClass.class_method()
static_result = TestClass.static_method()
if temp_trace_file.exists():
con = sqlite3.connect(temp_trace_file)
cursor = con.cursor()
cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='function_calls'")
if cursor.fetchone():
# Query for all function calls
cursor.execute("SELECT function, classname FROM function_calls")
calls = cursor.fetchall()
function_names = [call[0] for call in calls]
class_names = [call[1] for call in calls if call[1] is not None]
assert "instance_method" in function_names
assert "class_method" in function_names
assert "static_method" in function_names
@ -338,46 +293,31 @@ ignore-paths = []
pytest.fail("No function_calls table found in trace file")
con.close()
def test_tracer_handles_exceptions_gracefully(self, temp_config_file: Path, temp_trace_file: Path) -> None:
"""Test that tracer handles exceptions in traced code gracefully."""
def failing_function() -> None:
raise ValueError("Test exception")
tracer = Tracer(
output=str(temp_trace_file),
config_file_path=temp_config_file
)
tracer = Tracer(output=str(temp_trace_file), config_file_path=temp_config_file)
with tracer, contextlib.suppress(ValueError):
failing_function()
def test_tracer_with_complex_arguments(self, temp_config_file: Path, temp_trace_file: Path) -> None:
def complex_function(data_dict: dict[str, Any], nested_list: list[list[int]], func_arg: object = lambda x: x) -> int:
def complex_function(
data_dict: dict[str, Any], nested_list: list[list[int]], func_arg: object = lambda x: x
) -> int:
return len(data_dict) + len(nested_list)
tracer = Tracer(
output=str(temp_trace_file),
config_file_path=temp_config_file
)
tracer = Tracer(output=str(temp_trace_file), config_file_path=temp_config_file)
expected_dict = {"key": "value", "nested": {"inner": "data"}}
expected_list = [[1, 2], [3, 4], [5, 6]]
expected_func = lambda x: x * 2
with tracer:
complex_function(
expected_dict,
expected_list,
func_arg=expected_func
)
complex_function(expected_dict, expected_list, func_arg=expected_func)
if temp_trace_file.exists():
con = sqlite3.connect(temp_trace_file)
@ -388,15 +328,16 @@ ignore-paths = []
cursor.execute("SELECT args FROM function_calls WHERE function = 'complex_function'")
result = cursor.fetchone()
assert result is not None, "Function complex_function should be traced"
# Deserialize the arguments
import pickle
traced_args = pickle.loads(result[0])
assert "data_dict" in traced_args
assert "nested_list" in traced_args
assert "func_arg" in traced_args
assert traced_args["data_dict"] == expected_dict
assert traced_args["nested_list"] == expected_list
assert callable(traced_args["func_arg"])