mirror of
https://github.com/codeflash-ai/codeflash-agent.git
synced 2026-05-04 18:25:19 +00:00
refactor: eliminate inline async decorator duplication and fix 10-column test gaps
Replace 218-line ASYNC_HELPER_INLINE_CODE string with shutil.copy2 of the runtime decorator file. Update remaining test files for 10-column SQLite schema (cpu_runtime). Add cpu_runtime assertions to async E2E tests.
This commit is contained in:
parent
eb6a0be717
commit
2fd9d06e28
6 changed files with 33 additions and 254 deletions
|
|
@ -2,14 +2,16 @@
|
|||
|
||||
Provides ``AsyncCallInstrumenter`` for injecting ``CODEFLASH_CURRENT_LINE_ID``
|
||||
assignments before ``await`` calls, ``AsyncDecoratorAdder`` for adding
|
||||
async performance/behavior decorators via libcst, the inline async helper
|
||||
code, and high-level functions for instrumenting async test and source files.
|
||||
async performance/behavior decorators via libcst, and high-level functions
|
||||
for instrumenting async test and source files.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import ast
|
||||
import logging
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import libcst as cst
|
||||
|
|
@ -26,8 +28,6 @@ from ._instrument_core import (
|
|||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pathlib import Path
|
||||
|
||||
from ..test_discovery.models import CodePosition
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
|
@ -126,25 +126,6 @@ class AsyncCallInstrumenter(ast.NodeTransformer):
|
|||
node.body = new_body
|
||||
return node
|
||||
|
||||
def _instrument_statement(
|
||||
self, stmt: ast.stmt, _node_name: str
|
||||
) -> tuple[ast.stmt, bool]:
|
||||
"""Check whether a statement contains an awaited target call."""
|
||||
for node in ast.walk(stmt):
|
||||
if (
|
||||
isinstance(node, ast.Await)
|
||||
and isinstance(node.value, ast.Call)
|
||||
and self._is_target_call(node.value)
|
||||
and self._call_in_positions(node.value)
|
||||
):
|
||||
# Check if this call is in one of our target positions
|
||||
return (
|
||||
stmt,
|
||||
True,
|
||||
) # Return original statement but signal we added env var
|
||||
|
||||
return stmt, False
|
||||
|
||||
def _is_target_call(self, call_node: ast.Call) -> bool:
|
||||
"""Check if this call node is calling our target async function."""
|
||||
if isinstance(call_node.func, ast.Name):
|
||||
|
|
@ -303,228 +284,14 @@ class AsyncDecoratorAdder(cst.CSTTransformer):
|
|||
return False
|
||||
|
||||
|
||||
ASYNC_HELPER_INLINE_CODE = """import asyncio
|
||||
import gc
|
||||
import os
|
||||
import sqlite3
|
||||
import time
|
||||
from functools import wraps
|
||||
from pathlib import Path
|
||||
from tempfile import TemporaryDirectory
|
||||
|
||||
import dill as pickle
|
||||
|
||||
|
||||
def get_run_tmp_file(file_path):
|
||||
if not hasattr(get_run_tmp_file, "tmpdir"):
|
||||
get_run_tmp_file.tmpdir = TemporaryDirectory(prefix="codeflash_")
|
||||
return Path(get_run_tmp_file.tmpdir.name) / file_path
|
||||
|
||||
|
||||
def extract_test_context_from_env():
|
||||
test_module = os.environ["CODEFLASH_TEST_MODULE"]
|
||||
test_class = os.environ.get("CODEFLASH_TEST_CLASS", None)
|
||||
test_function = os.environ["CODEFLASH_TEST_FUNCTION"]
|
||||
if test_module and test_function:
|
||||
return (test_module, test_class if test_class else None, test_function)
|
||||
raise RuntimeError(
|
||||
"Test context environment variables not set"
|
||||
" - ensure tests are run through"
|
||||
" codeflash test runner"
|
||||
)
|
||||
|
||||
|
||||
def codeflash_behavior_async(func):
|
||||
@wraps(func)
|
||||
async def async_wrapper(*args, **kwargs):
|
||||
loop = asyncio.get_running_loop()
|
||||
function_name = func.__name__
|
||||
line_id = os.environ["CODEFLASH_CURRENT_LINE_ID"]
|
||||
loop_index = int(os.environ["CODEFLASH_LOOP_INDEX"])
|
||||
(test_module_name, test_class_name,
|
||||
test_name) = extract_test_context_from_env()
|
||||
test_id = (
|
||||
f"{test_module_name}:{test_class_name}"
|
||||
f":{test_name}:{line_id}:{loop_index}"
|
||||
)
|
||||
if not hasattr(async_wrapper, "index"):
|
||||
async_wrapper.index = {}
|
||||
if test_id in async_wrapper.index:
|
||||
async_wrapper.index[test_id] += 1
|
||||
else:
|
||||
async_wrapper.index[test_id] = 0
|
||||
codeflash_test_index = async_wrapper.index[test_id]
|
||||
invocation_id = f"{line_id}_{codeflash_test_index}"
|
||||
class_prefix = (
|
||||
(test_class_name + ".") if test_class_name else ""
|
||||
)
|
||||
test_stdout_tag = (
|
||||
f"{test_module_name}:{class_prefix}"
|
||||
f"{test_name}:{function_name}"
|
||||
f":{loop_index}:{invocation_id}"
|
||||
)
|
||||
print(f"!$######{test_stdout_tag}######$!")
|
||||
iteration = os.environ.get(
|
||||
"CODEFLASH_TEST_ITERATION", "0"
|
||||
)
|
||||
db_path = get_run_tmp_file(
|
||||
Path(f"test_return_values_{iteration}.sqlite")
|
||||
)
|
||||
codeflash_con = sqlite3.connect(db_path)
|
||||
codeflash_cur = codeflash_con.cursor()
|
||||
codeflash_cur.execute(
|
||||
"CREATE TABLE IF NOT EXISTS test_results"
|
||||
" (test_module_path TEXT,"
|
||||
" test_class_name TEXT,"
|
||||
" test_function_name TEXT,"
|
||||
" function_getting_tested TEXT,"
|
||||
" loop_index INTEGER,"
|
||||
" iteration_id TEXT,"
|
||||
" runtime INTEGER,"
|
||||
" return_value BLOB,"
|
||||
" verification_type TEXT,"
|
||||
" cpu_runtime INTEGER)"
|
||||
)
|
||||
exception = None
|
||||
counter = loop.time()
|
||||
gc.disable()
|
||||
try:
|
||||
ret = func(*args, **kwargs)
|
||||
counter = loop.time()
|
||||
return_value = await ret
|
||||
codeflash_duration = int(
|
||||
(loop.time() - counter) * 1_000_000_000
|
||||
)
|
||||
except Exception as e:
|
||||
codeflash_duration = int(
|
||||
(loop.time() - counter) * 1_000_000_000
|
||||
)
|
||||
exception = e
|
||||
finally:
|
||||
gc.enable()
|
||||
print(f"!######{test_stdout_tag}######!")
|
||||
pickled_return_value = (
|
||||
pickle.dumps(exception) if exception
|
||||
else pickle.dumps(
|
||||
(args, kwargs, return_value)
|
||||
)
|
||||
)
|
||||
codeflash_cur.execute(
|
||||
"INSERT INTO test_results VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
|
||||
(
|
||||
test_module_name,
|
||||
test_class_name,
|
||||
test_name,
|
||||
function_name,
|
||||
loop_index,
|
||||
invocation_id,
|
||||
codeflash_duration,
|
||||
pickled_return_value,
|
||||
"function_call",
|
||||
0,
|
||||
),
|
||||
)
|
||||
codeflash_con.commit()
|
||||
codeflash_con.close()
|
||||
if exception:
|
||||
raise exception
|
||||
return return_value
|
||||
return async_wrapper
|
||||
|
||||
|
||||
def codeflash_performance_async(func):
|
||||
@wraps(func)
|
||||
async def async_wrapper(*args, **kwargs):
|
||||
loop = asyncio.get_running_loop()
|
||||
function_name = func.__name__
|
||||
line_id = os.environ["CODEFLASH_CURRENT_LINE_ID"]
|
||||
loop_index = int(os.environ["CODEFLASH_LOOP_INDEX"])
|
||||
(test_module_name, test_class_name,
|
||||
test_name) = extract_test_context_from_env()
|
||||
test_id = (
|
||||
f"{test_module_name}:{test_class_name}"
|
||||
f":{test_name}:{line_id}:{loop_index}"
|
||||
)
|
||||
if not hasattr(async_wrapper, "index"):
|
||||
async_wrapper.index = {}
|
||||
if test_id in async_wrapper.index:
|
||||
async_wrapper.index[test_id] += 1
|
||||
else:
|
||||
async_wrapper.index[test_id] = 0
|
||||
codeflash_test_index = async_wrapper.index[test_id]
|
||||
invocation_id = f"{line_id}_{codeflash_test_index}"
|
||||
class_prefix = (
|
||||
(test_class_name + ".") if test_class_name else ""
|
||||
)
|
||||
test_stdout_tag = (
|
||||
f"{test_module_name}:{class_prefix}"
|
||||
f"{test_name}:{function_name}"
|
||||
f":{loop_index}:{invocation_id}"
|
||||
)
|
||||
print(f"!$######{test_stdout_tag}######$!")
|
||||
exception = None
|
||||
counter = loop.time()
|
||||
gc.disable()
|
||||
try:
|
||||
ret = func(*args, **kwargs)
|
||||
counter = loop.time()
|
||||
return_value = await ret
|
||||
codeflash_duration = int((loop.time() - counter) * 1_000_000_000)
|
||||
except Exception as e:
|
||||
codeflash_duration = int((loop.time() - counter) * 1_000_000_000)
|
||||
exception = e
|
||||
finally:
|
||||
gc.enable()
|
||||
print(f"!######{test_stdout_tag}:{codeflash_duration}######!")
|
||||
if exception:
|
||||
raise exception
|
||||
return return_value
|
||||
return async_wrapper
|
||||
|
||||
|
||||
def codeflash_concurrency_async(func):
|
||||
@wraps(func)
|
||||
async def async_wrapper(*args, **kwargs):
|
||||
function_name = func.__name__
|
||||
concurrency_factor = int(os.environ.get(
|
||||
"CODEFLASH_CONCURRENCY_FACTOR", "10"
|
||||
))
|
||||
test_module_name = os.environ.get("CODEFLASH_TEST_MODULE", "")
|
||||
test_class_name = os.environ.get("CODEFLASH_TEST_CLASS", "")
|
||||
test_function = os.environ.get("CODEFLASH_TEST_FUNCTION", "")
|
||||
loop_index = os.environ.get("CODEFLASH_LOOP_INDEX", "0")
|
||||
gc.disable()
|
||||
try:
|
||||
seq_start = time.perf_counter_ns()
|
||||
for _ in range(concurrency_factor):
|
||||
result = await func(*args, **kwargs)
|
||||
sequential_time = time.perf_counter_ns() - seq_start
|
||||
finally:
|
||||
gc.enable()
|
||||
gc.disable()
|
||||
try:
|
||||
conc_start = time.perf_counter_ns()
|
||||
tasks = [func(*args, **kwargs) for _ in range(concurrency_factor)]
|
||||
await asyncio.gather(*tasks)
|
||||
concurrent_time = time.perf_counter_ns() - conc_start
|
||||
finally:
|
||||
gc.enable()
|
||||
tag = (
|
||||
f"{test_module_name}:{test_class_name}"
|
||||
f":{test_function}:{function_name}"
|
||||
f":{loop_index}"
|
||||
)
|
||||
print(
|
||||
f"!@######CONC:{tag}"
|
||||
f":{sequential_time}:{concurrent_time}"
|
||||
f":{concurrency_factor}######@!"
|
||||
)
|
||||
return result
|
||||
return async_wrapper
|
||||
"""
|
||||
|
||||
ASYNC_HELPER_FILENAME = "codeflash_async_wrapper.py"
|
||||
|
||||
_RUNTIME_DECORATOR_PATH = (
|
||||
Path(__file__).resolve().parent.parent
|
||||
/ "runtime"
|
||||
/ "_codeflash_wrap_decorator.py"
|
||||
)
|
||||
|
||||
|
||||
def get_decorator_name_for_mode(
|
||||
mode: TestingMode,
|
||||
|
|
@ -540,10 +307,10 @@ def get_decorator_name_for_mode(
|
|||
def write_async_helper_file(
|
||||
target_dir: Path,
|
||||
) -> Path:
|
||||
"""Write the async decorator helper file to the target directory."""
|
||||
"""Copy the runtime async decorator module to the target directory."""
|
||||
helper_path = target_dir / ASYNC_HELPER_FILENAME
|
||||
if not helper_path.exists():
|
||||
helper_path.write_text(ASYNC_HELPER_INLINE_CODE, "utf-8")
|
||||
shutil.copy2(_RUNTIME_DECORATOR_PATH, helper_path)
|
||||
return helper_path
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -28,9 +28,6 @@ from ..runtime._codeflash_wrap_decorator import (
|
|||
from ._instrument_async import (
|
||||
ASYNC_HELPER_FILENAME as ASYNC_HELPER_FILENAME, # noqa: PLC0414
|
||||
)
|
||||
from ._instrument_async import (
|
||||
ASYNC_HELPER_INLINE_CODE as ASYNC_HELPER_INLINE_CODE, # noqa: PLC0414
|
||||
)
|
||||
from ._instrument_async import (
|
||||
AsyncCallInstrumenter as AsyncCallInstrumenter, # noqa: PLC0414
|
||||
)
|
||||
|
|
|
|||
|
|
@ -155,6 +155,7 @@ async def test_async_sort():
|
|||
assert results_list[0].id.test_function_name == "test_async_sort"
|
||||
assert results_list[0].did_pass
|
||||
assert results_list[0].runtime is None or results_list[0].runtime >= 0
|
||||
assert isinstance(results_list[0].cpu_runtime, int)
|
||||
|
||||
expected_stdout = "codeflash stdout: Async sorting list\nresult: [0, 1, 2, 3, 4, 5]\n"
|
||||
assert expected_stdout == results_list[0].stdout
|
||||
|
|
@ -162,6 +163,7 @@ async def test_async_sort():
|
|||
assert results_list[1].id.function_getting_tested == "async_sorter"
|
||||
assert results_list[1].id.test_function_name == "test_async_sort"
|
||||
assert results_list[1].did_pass
|
||||
assert isinstance(results_list[1].cpu_runtime, int)
|
||||
|
||||
expected_stdout2 = "codeflash stdout: Async sorting list\nresult: [0.0, 1.0, 2.0, 3.0, 4.0, 5.0]\n"
|
||||
assert expected_stdout2 == results_list[1].stdout
|
||||
|
|
@ -300,6 +302,7 @@ async def test_async_class_sort():
|
|||
assert sorter_result.id.test_function_name == "test_async_class_sort"
|
||||
assert sorter_result.did_pass
|
||||
assert sorter_result.runtime is None or sorter_result.runtime >= 0
|
||||
assert isinstance(sorter_result.cpu_runtime, int)
|
||||
|
||||
expected_stdout = (
|
||||
"codeflash stdout: AsyncBubbleSorter.sorter() called\n"
|
||||
|
|
@ -308,6 +311,7 @@ async def test_async_class_sort():
|
|||
|
||||
assert ".__init__" in init_result.id.function_getting_tested
|
||||
assert init_result.did_pass
|
||||
assert isinstance(init_result.cpu_runtime, int)
|
||||
|
||||
finally:
|
||||
fto_path.write_text(original_code, "utf-8")
|
||||
|
|
|
|||
|
|
@ -98,6 +98,7 @@ class TestAsyncWrapperSQLiteValidation:
|
|||
runtime,
|
||||
return_value_blob,
|
||||
verification_type,
|
||||
cpu_runtime,
|
||||
) = row
|
||||
|
||||
assert test_module_path == __name__
|
||||
|
|
@ -111,6 +112,7 @@ class TestAsyncWrapperSQLiteValidation:
|
|||
) and iteration_id.endswith("_0")
|
||||
assert runtime > 0
|
||||
assert verification_type == VerificationType.FUNCTION_CALL.value
|
||||
assert isinstance(cpu_runtime, int)
|
||||
|
||||
unpickled_data = pickle.loads(return_value_blob)
|
||||
args, kwargs, return_val = unpickled_data
|
||||
|
|
@ -324,6 +326,7 @@ class TestAsyncWrapperSQLiteValidation:
|
|||
(6, "runtime", "INTEGER", 0, None, 0),
|
||||
(7, "return_value", "BLOB", 0, None, 0),
|
||||
(8, "verification_type", "TEXT", 0, None, 0),
|
||||
(9, "cpu_runtime", "INTEGER", 0, None, 0),
|
||||
]
|
||||
|
||||
assert columns == expected_columns
|
||||
|
|
|
|||
|
|
@ -845,6 +845,7 @@ def test_async_behavior_decorator_return_values_and_test_ids():
|
|||
runtime,
|
||||
return_value_blob,
|
||||
verification_type,
|
||||
cpu_runtime,
|
||||
) = row
|
||||
|
||||
assert test_module == "test_module", (
|
||||
|
|
|
|||
|
|
@ -17,7 +17,6 @@ from codeflash_python._model import (
|
|||
from codeflash_python.test_discovery.models import CodePosition
|
||||
from codeflash_python.testing._instrumentation import (
|
||||
ASYNC_HELPER_FILENAME,
|
||||
ASYNC_HELPER_INLINE_CODE,
|
||||
AsyncCallInstrumenter,
|
||||
AsyncDecoratorAdder,
|
||||
FunctionCallNodeArguments,
|
||||
|
|
@ -587,16 +586,24 @@ class TestWriteAsyncHelperFile:
|
|||
|
||||
|
||||
class TestAsyncHelperConstants:
|
||||
"""ASYNC_HELPER_FILENAME and ASYNC_HELPER_INLINE_CODE constants."""
|
||||
"""ASYNC_HELPER_FILENAME and runtime decorator source."""
|
||||
|
||||
def test_filename_value(self) -> None:
|
||||
"""ASYNC_HELPER_FILENAME has the expected value."""
|
||||
assert "codeflash_async_wrapper.py" == ASYNC_HELPER_FILENAME
|
||||
|
||||
def test_inline_code_nonempty(self) -> None:
|
||||
"""ASYNC_HELPER_INLINE_CODE is a non-empty string."""
|
||||
assert isinstance(ASYNC_HELPER_INLINE_CODE, str)
|
||||
assert len(ASYNC_HELPER_INLINE_CODE) > 0
|
||||
def test_runtime_decorator_is_self_contained(self) -> None:
|
||||
"""Runtime decorator file has no internal codeflash imports."""
|
||||
from codeflash_python.testing._instrument_async import (
|
||||
_RUNTIME_DECORATOR_PATH,
|
||||
)
|
||||
|
||||
source = _RUNTIME_DECORATOR_PATH.read_text("utf-8")
|
||||
assert _RUNTIME_DECORATOR_PATH.exists()
|
||||
for line in source.splitlines():
|
||||
stripped = line.strip()
|
||||
if stripped.startswith(("import ", "from ")):
|
||||
assert "codeflash_python" not in stripped
|
||||
|
||||
|
||||
class TestSortImports:
|
||||
|
|
|
|||
Loading…
Reference in a new issue