mirror of
https://github.com/codeflash-ai/codeflash-agent.git
synced 2026-05-04 18:25:19 +00:00
feat: add codeflash_behavior_sync decorator
This follows the same pattern as the async behavior decorator: it decorates the function under test directly and records return values, timing (wall-clock and CPU), and captured stdout in the shared async_results SQLite table. This is the first step toward replacing the AST-injected codeflash_wrap approach for sync functions.
This commit is contained in:
parent
c9f65aba6b
commit
8c218038e9
3 changed files with 218 additions and 1 deletions
|
|
@ -129,6 +129,98 @@ def _close_all_connections() -> None:
|
|||
atexit.register(_close_all_connections)
|
||||
|
||||
|
||||
def codeflash_behavior_sync(func: F) -> F:
    """
    Capture sync return values, timing, and stdout for behavioral tests.

    Each call to the decorated function is timed (wall clock and per-thread
    CPU time) with the garbage collector paused and ``sys.stdout``
    redirected to an in-memory buffer. One row per call is inserted into
    the ``async_results`` table of ``async_results_<iteration>.sqlite``,
    where ``<iteration>`` comes from ``CODEFLASH_TEST_ITERATION``
    (default ``"0"``).

    On success the row stores the pickled ``(args, kwargs, return_value)``
    tuple; when ``func`` raises an ``Exception`` the row stores the pickled
    exception, which is re-raised after the row has been committed.

    Raises:
        KeyError: if ``CODEFLASH_LOOP_INDEX`` is not set in the environment.
        ValueError: if ``CODEFLASH_LOOP_INDEX`` is not an integer.
        Exception: whatever ``func`` itself raised.

    """

    @wraps(func)
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        function_name = func.__name__
        call_site = _codeflash_call_site.get()
        loop_index = int(os.environ["CODEFLASH_LOOP_INDEX"])
        (
            test_module_name,
            test_class_name,
            test_name,
        ) = extract_test_context_from_env()

        test_id = (
            f"{test_module_name}:{test_class_name}"
            f":{test_name}:{call_site}:{loop_index}"
        )

        # Per-wrapper call counter, kept on the function object as
        # ``wrapper.index`` and created lazily on the first call. The first
        # call for a given test_id gets index 0, the next 1, and so on.
        call_counts = wrapper.__dict__.setdefault("index", {})
        call_index = call_counts.get(test_id, -1) + 1
        call_counts[test_id] = call_index
        invocation_id = f"{call_site}_{call_index}"

        iteration = os.environ.get("CODEFLASH_TEST_ITERATION", "0")
        db_path = get_run_tmp_file(Path(f"async_results_{iteration}.sqlite"))
        conn, cur = _get_async_db(db_path)

        exception = None
        return_value = None
        captured_stdout = io.StringIO()
        old_stdout = sys.stdout
        # Remember the collector state so a caller that had GC disabled
        # does not get it silently switched back on by this wrapper.
        gc_was_enabled = gc.isenabled()
        sys.stdout = captured_stdout
        gc.disable()
        try:
            counter = time.perf_counter_ns()
            cpu_counter = time.thread_time_ns()
            try:
                return_value = func(*args, **kwargs)
            except Exception as e:
                # Recorded below and re-raised once the row is committed.
                exception = e
            wall_time = time.perf_counter_ns() - counter
            cpu_time = time.thread_time_ns() - cpu_counter
        finally:
            if gc_was_enabled:
                gc.enable()
            sys.stdout = old_stdout

        stdout_text = captured_stdout.getvalue()

        # Compare against None explicitly: an exception type may define
        # __bool__/__len__ and evaluate as falsy.
        # NOTE: pickle.dumps is not guarded, so an unpicklable value makes
        # the wrapper raise the pickling error instead.
        pickled = (
            pickle.dumps(exception)
            if exception is not None
            else pickle.dumps((args, kwargs, return_value))
        )
        cur.execute(
            "INSERT INTO async_results VALUES "
            "(?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
            (
                test_module_name,
                test_class_name,
                test_name,
                function_name,
                loop_index,
                invocation_id,
                "behavior",
                wall_time,
                pickled,
                VerificationType.FUNCTION_CALL.value,
                cpu_time,
                None,  # not populated by the sync decorator
                None,  # not populated by the sync decorator
                stdout_text,
            ),
        )
        conn.commit()

        if exception is not None:
            raise exception
        return return_value

    return wrapper  # type: ignore[return-value]
|
||||
|
||||
|
||||
def codeflash_behavior_async(func: F) -> F:
|
||||
"""
|
||||
Capture async return values and timing for behavioral tests.
|
||||
|
|
@ -374,6 +466,7 @@ __all__ = [
|
|||
"VerificationType",
|
||||
"_codeflash_call_site",
|
||||
"codeflash_behavior_async",
|
||||
"codeflash_behavior_sync",
|
||||
"codeflash_concurrency_async",
|
||||
"codeflash_performance_async",
|
||||
"extract_test_context_from_env",
|
||||
|
|
|
|||
|
|
@ -43,7 +43,7 @@ def parse_sqlite_test_results(
|
|||
" function_getting_tested, loop_index,"
|
||||
" iteration_id, runtime,"
|
||||
" return_value, verification_type,"
|
||||
" cpu_runtime"
|
||||
" cpu_runtime, stdout"
|
||||
" FROM test_results"
|
||||
).fetchall()
|
||||
except sqlite3.Error:
|
||||
|
|
@ -101,6 +101,7 @@ def _process_sqlite_row_inner(
|
|||
runtime = val[6]
|
||||
verification_type = val[8]
|
||||
cpu_runtime = val[9]
|
||||
stdout_text = val[10] if len(val) > 10 else None
|
||||
|
||||
test_file_path = file_path_from_module_name(
|
||||
test_module_path, # type: ignore[arg-type]
|
||||
|
|
|
|||
|
|
@ -20,6 +20,7 @@ from codeflash_python.runtime._codeflash_async_decorators import (
|
|||
_connections,
|
||||
_get_async_db,
|
||||
codeflash_behavior_async,
|
||||
codeflash_behavior_sync,
|
||||
codeflash_concurrency_async,
|
||||
codeflash_performance_async,
|
||||
extract_test_context_from_env,
|
||||
|
|
@ -261,6 +262,128 @@ class TestBehaviorAsync:
|
|||
con.close()
|
||||
|
||||
|
||||
class TestBehaviorSync:
    """codeflash_behavior_sync decorator."""

    @staticmethod
    def _fetch_rows(db_path, query: str) -> list:
        """Run *query* on the SQLite file at *db_path* and return all rows.

        The connection is closed even when the query fails, so a broken
        test cannot leave the database file open.
        """
        con = sqlite3.connect(db_path)
        try:
            return con.execute(query).fetchall()
        finally:
            con.close()

    def test_returns_correct_value(self, env_setup, async_db_path) -> None:
        """Decorated function returns the original return value."""

        @codeflash_behavior_sync
        def add(a: int, b: int) -> int:
            return a + b

        _codeflash_call_site.set("0")
        result = add(3, 4)
        assert 7 == result

    def test_writes_to_sqlite(self, env_setup, async_db_path) -> None:
        """Writes behavior result to async_results table."""

        @codeflash_behavior_sync
        def multiply(a: int, b: int) -> int:
            return a * b

        _codeflash_call_site.set("0")
        multiply(5, 6)
        # Flush and release the decorator's cached connection before
        # reading the file back.
        _close_all_connections()

        assert async_db_path.exists()
        rows = self._fetch_rows(
            async_db_path, "SELECT * FROM async_results"
        )
        assert 1 == len(rows)
        row = rows[0]
        assert "behavior" == row[6]
        assert 0 < row[7]

        data = pickle.loads(row[8])
        args, kwargs, ret = data
        assert (5, 6) == args
        assert {} == kwargs
        assert 30 == ret
        assert VerificationType.FUNCTION_CALL.value == row[9]

    def test_exception_handling(self, env_setup, async_db_path) -> None:
        """Re-raises exceptions and stores them pickled."""

        @codeflash_behavior_sync
        def fail() -> None:
            raise ValueError("boom")

        _codeflash_call_site.set("0")
        with pytest.raises(ValueError, match="boom"):
            fail()

        _close_all_connections()
        rows = self._fetch_rows(
            async_db_path, "SELECT return_value FROM async_results"
        )
        exc = pickle.loads(rows[0][0])
        assert isinstance(exc, ValueError)
        assert "boom" == str(exc)

    def test_captures_stdout_in_sqlite(
        self, env_setup, async_db_path
    ) -> None:
        """Captures print output into the stdout column."""

        @codeflash_behavior_sync
        def greeter(name: str) -> str:
            print(f"hello {name}")
            return f"hi {name}"

        _codeflash_call_site.set("0")
        greeter("world")
        _close_all_connections()

        rows = self._fetch_rows(
            async_db_path, "SELECT stdout FROM async_results"
        )
        assert "hello world\n" == rows[0][0]

    def test_no_stdout_leak(
        self, env_setup, async_db_path, capsys
    ) -> None:
        """Sync decorator does not leak stdout to outer scope."""

        @codeflash_behavior_sync
        def quiet() -> int:
            print("captured")
            return 1

        _codeflash_call_site.set("0")
        quiet()
        captured = capsys.readouterr()
        assert "" == captured.out

    def test_records_cpu_time(self, env_setup, async_db_path) -> None:
        """Records cpu_time in the sequential_time_ns column."""

        @codeflash_behavior_sync
        def work() -> int:
            return sum(range(1000))

        _codeflash_call_site.set("0")
        work()
        _close_all_connections()

        rows = self._fetch_rows(
            async_db_path,
            "SELECT sequential_time_ns FROM async_results",
        )
        cpu_time = rows[0][0]
        assert cpu_time is not None
        assert 0 <= cpu_time
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
sys.platform == "win32",
|
||||
reason="pending support for asyncio on windows",
|
||||
|
|
|
|||
Loading…
Reference in a new issue