fix unit test (hopefully)
This commit is contained in:
parent
2cc39336aa
commit
ca4191268a
1 changed files with 35 additions and 94 deletions
|
|
@ -10,8 +10,7 @@ from typing import Any
|
|||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from codeflash.tracer import FakeCode, FakeFrame, Tracer
|
||||
from codeflash.tracing.tracing_new_process import FakeCode, FakeFrame, Tracer
|
||||
|
||||
|
||||
class TestFakeCode:
|
||||
|
|
@ -54,7 +53,7 @@ class TestTracer:
|
|||
temp_dir = Path(tempfile.mkdtemp())
|
||||
tests_dir = temp_dir / "tests"
|
||||
tests_dir.mkdir(exist_ok=True)
|
||||
|
||||
|
||||
# Use the current working directory as module root so test files are included
|
||||
current_dir = Path.cwd()
|
||||
|
||||
|
|
@ -69,6 +68,7 @@ ignore-paths = []
|
|||
config_path = Path(f.name)
|
||||
yield config_path
|
||||
import shutil
|
||||
|
||||
shutil.rmtree(temp_dir, ignore_errors=True)
|
||||
|
||||
@pytest.fixture
|
||||
|
|
@ -94,23 +94,18 @@ ignore-paths = []
|
|||
def test_tracer_disabled_by_environment(self, temp_config_file: Path, temp_trace_file: Path) -> None:
|
||||
"""Test that tracer is disabled when CODEFLASH_TRACER_DISABLE is set."""
|
||||
with patch.dict("os.environ", {"CODEFLASH_TRACER_DISABLE": "1"}):
|
||||
tracer = Tracer(
|
||||
output=str(temp_trace_file),
|
||||
config_file_path=temp_config_file
|
||||
)
|
||||
tracer = Tracer(output=str(temp_trace_file), config_file_path=temp_config_file)
|
||||
assert tracer.disable is True
|
||||
|
||||
def test_tracer_disabled_with_existing_profiler(self, temp_config_file: Path, temp_trace_file: Path) -> None:
|
||||
"""Test that tracer is disabled when another profiler is running."""
|
||||
|
||||
def dummy_profiler(_frame: object, _event: str, _arg: object) -> object:
|
||||
return dummy_profiler
|
||||
|
||||
sys.setprofile(dummy_profiler)
|
||||
try:
|
||||
tracer = Tracer(
|
||||
output=str(temp_trace_file),
|
||||
config_file_path=temp_config_file
|
||||
)
|
||||
tracer = Tracer(output=str(temp_trace_file), config_file_path=temp_config_file)
|
||||
assert tracer.disable is True
|
||||
finally:
|
||||
sys.setprofile(None)
|
||||
|
|
@ -122,7 +117,7 @@ ignore-paths = []
|
|||
functions=["test_func"],
|
||||
max_function_count=128,
|
||||
timeout=10,
|
||||
config_file_path=temp_config_file
|
||||
config_file_path=temp_config_file,
|
||||
)
|
||||
|
||||
assert tracer.disable is False
|
||||
|
|
@ -131,37 +126,23 @@ ignore-paths = []
|
|||
assert tracer.max_function_count == 128
|
||||
assert tracer.timeout == 10
|
||||
assert hasattr(tracer, "_db_lock")
|
||||
assert getattr(tracer, "_db_lock") is not None
|
||||
assert tracer._db_lock is not None
|
||||
|
||||
def test_tracer_timeout_validation(self, temp_config_file: Path, temp_trace_file: Path) -> None:
|
||||
with pytest.raises(AssertionError):
|
||||
Tracer(
|
||||
output=str(temp_trace_file),
|
||||
timeout=0,
|
||||
config_file_path=temp_config_file
|
||||
)
|
||||
Tracer(output=str(temp_trace_file), timeout=0, config_file_path=temp_config_file)
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
Tracer(
|
||||
output=str(temp_trace_file),
|
||||
timeout=-5,
|
||||
config_file_path=temp_config_file
|
||||
)
|
||||
Tracer(output=str(temp_trace_file), timeout=-5, config_file_path=temp_config_file)
|
||||
|
||||
def test_tracer_context_manager_disabled(self, temp_config_file: Path, temp_trace_file: Path) -> None:
|
||||
tracer = Tracer(
|
||||
output=str(temp_trace_file),
|
||||
disable=True,
|
||||
config_file_path=temp_config_file
|
||||
)
|
||||
tracer = Tracer(output=str(temp_trace_file), disable=True, config_file_path=temp_config_file)
|
||||
|
||||
with tracer:
|
||||
pass
|
||||
|
||||
assert not temp_trace_file.exists()
|
||||
|
||||
|
||||
|
||||
def test_tracer_function_filtering(self, temp_config_file: Path, temp_trace_file: Path) -> None:
|
||||
"""Test that tracer respects function filtering."""
|
||||
if hasattr(Tracer, "used_once"):
|
||||
|
|
@ -173,11 +154,7 @@ ignore-paths = []
|
|||
def other_function() -> int:
|
||||
return 24
|
||||
|
||||
tracer = Tracer(
|
||||
output=str(temp_trace_file),
|
||||
functions=["test_function"],
|
||||
config_file_path=temp_config_file
|
||||
)
|
||||
tracer = Tracer(output=str(temp_trace_file), functions=["test_function"], config_file_path=temp_config_file)
|
||||
|
||||
with tracer:
|
||||
test_function()
|
||||
|
|
@ -197,21 +174,16 @@ ignore-paths = []
|
|||
|
||||
con.close()
|
||||
|
||||
|
||||
def test_tracer_max_function_count(self, temp_config_file: Path, temp_trace_file: Path) -> None:
|
||||
def counting_function(n: int) -> int:
|
||||
return n * 2
|
||||
|
||||
tracer = Tracer(
|
||||
output=str(temp_trace_file),
|
||||
max_function_count=3,
|
||||
config_file_path=temp_config_file
|
||||
)
|
||||
tracer = Tracer(output=str(temp_trace_file), max_function_count=3, config_file_path=temp_config_file)
|
||||
|
||||
with tracer:
|
||||
for i in range(5):
|
||||
counting_function(i)
|
||||
|
||||
|
||||
assert tracer.trace_count <= 3, "Tracer should limit the number of traced functions to max_function_count"
|
||||
|
||||
def test_tracer_timeout_functionality(self, temp_config_file: Path, temp_trace_file: Path) -> None:
|
||||
|
|
@ -222,7 +194,7 @@ ignore-paths = []
|
|||
tracer = Tracer(
|
||||
output=str(temp_trace_file),
|
||||
timeout=1, # 1 second timeout
|
||||
config_file_path=temp_config_file
|
||||
config_file_path=temp_config_file,
|
||||
)
|
||||
|
||||
with tracer:
|
||||
|
|
@ -235,10 +207,7 @@ ignore-paths = []
|
|||
def thread_function(n: int) -> None:
|
||||
results.append(n * 2)
|
||||
|
||||
tracer = Tracer(
|
||||
output=str(temp_trace_file),
|
||||
config_file_path=temp_config_file
|
||||
)
|
||||
tracer = Tracer(output=str(temp_trace_file), config_file_path=temp_config_file)
|
||||
|
||||
with tracer:
|
||||
threads = []
|
||||
|
|
@ -254,29 +223,20 @@ ignore-paths = []
|
|||
|
||||
def test_simulate_call(self, temp_config_file: Path, temp_trace_file: Path) -> None:
|
||||
"""Test the simulate_call method."""
|
||||
tracer = Tracer(
|
||||
output=str(temp_trace_file),
|
||||
config_file_path=temp_config_file
|
||||
)
|
||||
tracer = Tracer(output=str(temp_trace_file), config_file_path=temp_config_file)
|
||||
|
||||
tracer.simulate_call("test_simulation")
|
||||
|
||||
def test_simulate_cmd_complete(self, temp_config_file: Path, temp_trace_file: Path) -> None:
|
||||
"""Test the simulate_cmd_complete method."""
|
||||
tracer = Tracer(
|
||||
output=str(temp_trace_file),
|
||||
config_file_path=temp_config_file
|
||||
)
|
||||
tracer = Tracer(output=str(temp_trace_file), config_file_path=temp_config_file)
|
||||
|
||||
tracer.simulate_call("test")
|
||||
tracer.simulate_cmd_complete()
|
||||
|
||||
def test_runctx_method(self, temp_config_file: Path, temp_trace_file: Path) -> None:
|
||||
"""Test the runctx method for executing code with tracing."""
|
||||
tracer = Tracer(
|
||||
output=str(temp_trace_file),
|
||||
config_file_path=temp_config_file
|
||||
)
|
||||
tracer = Tracer(output=str(temp_trace_file), config_file_path=temp_config_file)
|
||||
|
||||
global_vars = {"x": 10}
|
||||
local_vars = {}
|
||||
|
|
@ -291,7 +251,7 @@ ignore-paths = []
|
|||
# Ensure tracer hasn't been used yet in this test
|
||||
if hasattr(Tracer, "used_once"):
|
||||
delattr(Tracer, "used_once")
|
||||
|
||||
|
||||
class TestClass:
|
||||
def instance_method(self) -> str:
|
||||
return "instance"
|
||||
|
|
@ -304,32 +264,27 @@ ignore-paths = []
|
|||
def static_method() -> str:
|
||||
return "static"
|
||||
|
||||
tracer = Tracer(
|
||||
output=str(temp_trace_file),
|
||||
config_file_path=temp_config_file
|
||||
)
|
||||
tracer = Tracer(output=str(temp_trace_file), config_file_path=temp_config_file)
|
||||
|
||||
with tracer:
|
||||
obj = TestClass()
|
||||
instance_result = obj.instance_method()
|
||||
class_result = TestClass.class_method()
|
||||
static_result = TestClass.static_method()
|
||||
|
||||
|
||||
|
||||
if temp_trace_file.exists():
|
||||
con = sqlite3.connect(temp_trace_file)
|
||||
cursor = con.cursor()
|
||||
|
||||
|
||||
cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='function_calls'")
|
||||
if cursor.fetchone():
|
||||
# Query for all function calls
|
||||
cursor.execute("SELECT function, classname FROM function_calls")
|
||||
calls = cursor.fetchall()
|
||||
|
||||
|
||||
function_names = [call[0] for call in calls]
|
||||
class_names = [call[1] for call in calls if call[1] is not None]
|
||||
|
||||
|
||||
assert "instance_method" in function_names
|
||||
assert "class_method" in function_names
|
||||
assert "static_method" in function_names
|
||||
|
|
@ -338,46 +293,31 @@ ignore-paths = []
|
|||
pytest.fail("No function_calls table found in trace file")
|
||||
con.close()
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
def test_tracer_handles_exceptions_gracefully(self, temp_config_file: Path, temp_trace_file: Path) -> None:
|
||||
"""Test that tracer handles exceptions in traced code gracefully."""
|
||||
|
||||
def failing_function() -> None:
|
||||
raise ValueError("Test exception")
|
||||
|
||||
tracer = Tracer(
|
||||
output=str(temp_trace_file),
|
||||
config_file_path=temp_config_file
|
||||
)
|
||||
tracer = Tracer(output=str(temp_trace_file), config_file_path=temp_config_file)
|
||||
|
||||
with tracer, contextlib.suppress(ValueError):
|
||||
failing_function()
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
def test_tracer_with_complex_arguments(self, temp_config_file: Path, temp_trace_file: Path) -> None:
|
||||
def complex_function(data_dict: dict[str, Any], nested_list: list[list[int]], func_arg: object = lambda x: x) -> int:
|
||||
def complex_function(
|
||||
data_dict: dict[str, Any], nested_list: list[list[int]], func_arg: object = lambda x: x
|
||||
) -> int:
|
||||
return len(data_dict) + len(nested_list)
|
||||
|
||||
tracer = Tracer(
|
||||
output=str(temp_trace_file),
|
||||
config_file_path=temp_config_file
|
||||
)
|
||||
tracer = Tracer(output=str(temp_trace_file), config_file_path=temp_config_file)
|
||||
|
||||
expected_dict = {"key": "value", "nested": {"inner": "data"}}
|
||||
expected_list = [[1, 2], [3, 4], [5, 6]]
|
||||
expected_func = lambda x: x * 2
|
||||
|
||||
with tracer:
|
||||
complex_function(
|
||||
expected_dict,
|
||||
expected_list,
|
||||
func_arg=expected_func
|
||||
)
|
||||
complex_function(expected_dict, expected_list, func_arg=expected_func)
|
||||
|
||||
if temp_trace_file.exists():
|
||||
con = sqlite3.connect(temp_trace_file)
|
||||
|
|
@ -388,15 +328,16 @@ ignore-paths = []
|
|||
cursor.execute("SELECT args FROM function_calls WHERE function = 'complex_function'")
|
||||
result = cursor.fetchone()
|
||||
assert result is not None, "Function complex_function should be traced"
|
||||
|
||||
|
||||
# Deserialize the arguments
|
||||
import pickle
|
||||
|
||||
traced_args = pickle.loads(result[0])
|
||||
|
||||
|
||||
assert "data_dict" in traced_args
|
||||
assert "nested_list" in traced_args
|
||||
assert "func_arg" in traced_args
|
||||
|
||||
|
||||
assert traced_args["data_dict"] == expected_dict
|
||||
assert traced_args["nested_list"] == expected_list
|
||||
assert callable(traced_args["func_arg"])
|
||||
|
|
|
|||
Loading…
Reference in a new issue