mirror of
https://github.com/codeflash-ai/codeflash-agent.git
synced 2026-05-04 18:25:19 +00:00
refactor: clean up _instrument_async and add 100% test coverage
Remove dead code (unused fields, hasattr guard, duplicate decorator set), rename _optimized_instrument_statement to _find_awaited_target_call, simplify AsyncDecoratorAdder init and leave_FunctionDef. Add 21 new unit tests covering all branches: non-test skipping, attribute calls, class body recursion, counter independence, decorator deduplication (name and call form), error handlers, and mode selection.
This commit is contained in:
parent
2fd9d06e28
commit
c670d637c0
2 changed files with 520 additions and 138 deletions
|
|
@ -44,19 +44,10 @@ class AsyncCallInstrumenter(ast.NodeTransformer):
|
|||
mode: TestingMode = TestingMode.BEHAVIOR,
|
||||
) -> None:
|
||||
"""Initialize with the target async function and testing mode."""
|
||||
self.mode = mode
|
||||
self.function_object = function
|
||||
self.class_name: str | None = None
|
||||
self.only_function_name = function.function_name
|
||||
self.module_path = module_path
|
||||
self.call_positions = call_positions
|
||||
self.did_instrument = False
|
||||
self.async_call_counter: dict[str, int] = {}
|
||||
if (
|
||||
len(function.parents) == 1
|
||||
and function.parents[0].type == "ClassDef"
|
||||
):
|
||||
self.class_name = function.parents[0].name
|
||||
|
||||
def visit_ClassDef(self, node: ast.ClassDef) -> ast.ClassDef:
|
||||
"""Recurse into class bodies to find test methods."""
|
||||
|
|
@ -90,14 +81,10 @@ class AsyncCallInstrumenter(ast.NodeTransformer):
|
|||
|
||||
new_body: list[ast.stmt] = []
|
||||
|
||||
# Scan only relevant nodes instead of
|
||||
# full ast.walk in _instrument_statement
|
||||
for _i, stmt in enumerate(node.body):
|
||||
transformed_stmt, added_env_assignment = (
|
||||
self._optimized_instrument_statement(stmt)
|
||||
)
|
||||
for stmt in node.body:
|
||||
_, has_target = self._find_awaited_target_call(stmt)
|
||||
|
||||
if added_env_assignment:
|
||||
if has_target:
|
||||
current_call_index = self.async_call_counter[node.name]
|
||||
self.async_call_counter[node.name] += 1
|
||||
|
||||
|
|
@ -116,12 +103,12 @@ class AsyncCallInstrumenter(ast.NodeTransformer):
|
|||
)
|
||||
],
|
||||
value=ast.Constant(value=f"{current_call_index}"),
|
||||
lineno=stmt.lineno if hasattr(stmt, "lineno") else 1,
|
||||
lineno=stmt.lineno,
|
||||
)
|
||||
new_body.append(env_assignment)
|
||||
self.did_instrument = True
|
||||
|
||||
new_body.append(transformed_stmt)
|
||||
new_body.append(stmt)
|
||||
|
||||
node.body = new_body
|
||||
return node
|
||||
|
|
@ -134,35 +121,23 @@ class AsyncCallInstrumenter(ast.NodeTransformer):
|
|||
return call_node.func.attr == self.function_object.function_name
|
||||
return False
|
||||
|
||||
def _call_in_positions(self, call_node: ast.Call) -> bool:
|
||||
"""Return True if the call node is at one of the tracked positions."""
|
||||
if not hasattr(call_node, "lineno") or not hasattr(
|
||||
call_node, "col_offset"
|
||||
):
|
||||
return False
|
||||
|
||||
return node_in_call_position(call_node, self.call_positions)
|
||||
|
||||
# Optimized version: only walk child nodes for Await
|
||||
def _optimized_instrument_statement(
|
||||
self, stmt: ast.stmt
|
||||
def _find_awaited_target_call(
|
||||
self,
|
||||
stmt: ast.stmt,
|
||||
) -> tuple[ast.stmt, bool]:
|
||||
"""Stack-based search for awaited target calls in a statement."""
|
||||
# Stack-based DFS, manual for relevant Await nodes
|
||||
"""Search a statement for awaited calls to the target function."""
|
||||
stack: list[ast.AST] = [stmt]
|
||||
while stack:
|
||||
node = stack.pop()
|
||||
# Favor direct ast.Await detection
|
||||
if isinstance(node, ast.Await):
|
||||
val = node.value
|
||||
if (
|
||||
isinstance(val, ast.Call)
|
||||
and self._is_target_call(val)
|
||||
and self._call_in_positions(val)
|
||||
and node_in_call_position(val, self.call_positions)
|
||||
):
|
||||
return stmt, True
|
||||
# Use _fields instead of ast.walk for less allocations
|
||||
for fname in getattr(node, "_fields", ()):
|
||||
for fname in node._fields:
|
||||
child = getattr(node, fname, None)
|
||||
if isinstance(child, list):
|
||||
stack.extend(child)
|
||||
|
|
@ -171,6 +146,16 @@ class AsyncCallInstrumenter(ast.NodeTransformer):
|
|||
return stmt, False
|
||||
|
||||
|
||||
_CODEFLASH_ASYNC_DECORATORS = frozenset(
|
||||
{
|
||||
"codeflash_trace_async",
|
||||
"codeflash_behavior_async",
|
||||
"codeflash_performance_async",
|
||||
"codeflash_concurrency_async",
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
class AsyncDecoratorAdder(cst.CSTTransformer):
|
||||
"""Transformer that adds async decorator to async function definitions."""
|
||||
|
||||
|
|
@ -179,34 +164,18 @@ class AsyncDecoratorAdder(cst.CSTTransformer):
|
|||
function: FunctionToOptimize,
|
||||
mode: TestingMode = TestingMode.BEHAVIOR,
|
||||
) -> None:
|
||||
"""Initialize the transformer.
|
||||
|
||||
Args:
|
||||
----
|
||||
function: Target async function.
|
||||
mode: Testing mode for decorator.
|
||||
|
||||
"""
|
||||
"""Initialize the transformer."""
|
||||
super().__init__()
|
||||
self.function = function
|
||||
self.mode = mode
|
||||
self.qualified_name_parts = function.qualified_name.split(".")
|
||||
self.context_stack: list[str] = []
|
||||
self.added_decorator = False
|
||||
|
||||
# Choose decorator based on mode
|
||||
if mode == TestingMode.BEHAVIOR:
|
||||
self.decorator_name = "codeflash_behavior_async"
|
||||
elif mode == TestingMode.CONCURRENCY:
|
||||
self.decorator_name = "codeflash_concurrency_async"
|
||||
else:
|
||||
self.decorator_name = "codeflash_performance_async"
|
||||
self.decorator_name = get_decorator_name_for_mode(mode)
|
||||
|
||||
def visit_ClassDef( # noqa: N802
|
||||
self, node: cst.ClassDef
|
||||
self,
|
||||
node: cst.ClassDef,
|
||||
) -> None:
|
||||
"""Push class name onto the context stack."""
|
||||
# Track when we enter a class
|
||||
self.context_stack.append(node.name.value)
|
||||
|
||||
def leave_ClassDef( # noqa: N802
|
||||
|
|
@ -215,15 +184,14 @@ class AsyncDecoratorAdder(cst.CSTTransformer):
|
|||
updated_node: cst.ClassDef,
|
||||
) -> cst.ClassDef:
|
||||
"""Pop class name from the context stack."""
|
||||
# Pop the context when we leave a class
|
||||
self.context_stack.pop()
|
||||
return updated_node
|
||||
|
||||
def visit_FunctionDef( # noqa: N802
|
||||
self, node: cst.FunctionDef
|
||||
self,
|
||||
node: cst.FunctionDef,
|
||||
) -> None:
|
||||
"""Push function name onto the context stack."""
|
||||
# Track when we enter a function
|
||||
self.context_stack.append(node.name.value)
|
||||
|
||||
def leave_FunctionDef( # noqa: N802
|
||||
|
|
@ -232,55 +200,37 @@ class AsyncDecoratorAdder(cst.CSTTransformer):
|
|||
updated_node: cst.FunctionDef,
|
||||
) -> cst.FunctionDef:
|
||||
"""Add the async decorator if the function matches the target."""
|
||||
# Check if this is an async function and matches our target
|
||||
if (
|
||||
original_node.asynchronous is not None
|
||||
and self.context_stack == self.qualified_name_parts
|
||||
):
|
||||
# Check if the decorator is already present
|
||||
has_decorator = any(
|
||||
self._is_target_decorator(decorator.decorator)
|
||||
for decorator in original_node.decorators
|
||||
and not any(
|
||||
self._is_codeflash_decorator(d.decorator)
|
||||
for d in original_node.decorators
|
||||
)
|
||||
):
|
||||
new_decorator = cst.Decorator(
|
||||
decorator=cst.Name(value=self.decorator_name),
|
||||
)
|
||||
updated_node = updated_node.with_changes(
|
||||
decorators=(new_decorator, *updated_node.decorators),
|
||||
)
|
||||
self.added_decorator = True
|
||||
|
||||
# Only add the decorator if it's not already there
|
||||
if not has_decorator:
|
||||
new_decorator = cst.Decorator(
|
||||
decorator=cst.Name(value=self.decorator_name)
|
||||
)
|
||||
|
||||
# Add our new decorator to the existing decorators
|
||||
updated_decorators = [
|
||||
new_decorator,
|
||||
*list(updated_node.decorators),
|
||||
]
|
||||
updated_node = updated_node.with_changes(
|
||||
decorators=tuple(updated_decorators)
|
||||
)
|
||||
self.added_decorator = True
|
||||
|
||||
# Pop the context when we leave a function
|
||||
self.context_stack.pop()
|
||||
return updated_node
|
||||
|
||||
def _is_target_decorator(self, decorator_node: cst.BaseExpression) -> bool:
|
||||
"""Check if a decorator matches our target decorator name."""
|
||||
@staticmethod
|
||||
def _is_codeflash_decorator(
|
||||
decorator_node: cst.BaseExpression,
|
||||
) -> bool:
|
||||
"""Check if a decorator is one of the codeflash async decorators."""
|
||||
if isinstance(decorator_node, cst.Name):
|
||||
return decorator_node.value in {
|
||||
"codeflash_trace_async",
|
||||
"codeflash_behavior_async",
|
||||
"codeflash_performance_async",
|
||||
"codeflash_concurrency_async",
|
||||
}
|
||||
return decorator_node.value in _CODEFLASH_ASYNC_DECORATORS
|
||||
if isinstance(decorator_node, cst.Call) and isinstance(
|
||||
decorator_node.func, cst.Name
|
||||
decorator_node.func,
|
||||
cst.Name,
|
||||
):
|
||||
return decorator_node.func.value in {
|
||||
"codeflash_trace_async",
|
||||
"codeflash_behavior_async",
|
||||
"codeflash_performance_async",
|
||||
"codeflash_concurrency_async",
|
||||
}
|
||||
return decorator_node.func.value in _CODEFLASH_ASYNC_DECORATORS
|
||||
return False
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -248,23 +248,28 @@ class TestGetDecoratorNameForMode:
|
|||
"""get_decorator_name_for_mode decorator selection."""
|
||||
|
||||
def test_behavior_mode(self) -> None:
|
||||
"""Returns correct decorator for BEHAVIOR mode."""
|
||||
name = get_decorator_name_for_mode(TestingMode.BEHAVIOR)
|
||||
assert isinstance(name, str)
|
||||
assert len(name) > 0
|
||||
"""Returns codeflash_behavior_async for BEHAVIOR."""
|
||||
assert "codeflash_behavior_async" == get_decorator_name_for_mode(
|
||||
TestingMode.BEHAVIOR,
|
||||
)
|
||||
|
||||
def test_performance_mode(self) -> None:
|
||||
"""Returns correct decorator for PERFORMANCE mode."""
|
||||
name = get_decorator_name_for_mode(TestingMode.PERFORMANCE)
|
||||
assert isinstance(name, str)
|
||||
assert len(name) > 0
|
||||
"""Returns codeflash_performance_async for PERFORMANCE."""
|
||||
assert "codeflash_performance_async" == get_decorator_name_for_mode(
|
||||
TestingMode.PERFORMANCE,
|
||||
)
|
||||
|
||||
def test_all_modes_return_strings(self) -> None:
|
||||
"""All modes return non-empty string decorator names."""
|
||||
for mode in TestingMode:
|
||||
name = get_decorator_name_for_mode(mode)
|
||||
assert isinstance(name, str)
|
||||
assert len(name) > 0
|
||||
def test_concurrency_mode(self) -> None:
|
||||
"""Returns codeflash_concurrency_async for CONCURRENCY."""
|
||||
assert "codeflash_concurrency_async" == get_decorator_name_for_mode(
|
||||
TestingMode.CONCURRENCY,
|
||||
)
|
||||
|
||||
def test_line_profile_falls_through_to_performance(self) -> None:
|
||||
"""LINE_PROFILE is not async-specific, falls through to performance."""
|
||||
assert "codeflash_performance_async" == get_decorator_name_for_mode(
|
||||
TestingMode.LINE_PROFILE,
|
||||
)
|
||||
|
||||
|
||||
class TestCreateDeviceSyncPrecomputeStatements:
|
||||
|
|
@ -455,30 +460,190 @@ class TestInjectPerfOnly:
|
|||
class TestAsyncCallInstrumenter:
|
||||
"""AsyncCallInstrumenter AST transformer."""
|
||||
|
||||
def _make_transformer(
|
||||
self,
|
||||
code: str,
|
||||
*,
|
||||
name: str = "target_func",
|
||||
parents: tuple[FunctionParent, ...] = (),
|
||||
positions: list[CodePosition] | None = None,
|
||||
) -> tuple[AsyncCallInstrumenter, ast.Module]:
|
||||
"""Parse code and build a transformer with call positions from it."""
|
||||
tree = ast.parse(code)
|
||||
if positions is None:
|
||||
positions = []
|
||||
for node in ast.walk(tree):
|
||||
if isinstance(node, ast.Call):
|
||||
func_name = None
|
||||
if isinstance(node.func, ast.Name):
|
||||
func_name = node.func.id
|
||||
elif isinstance(node.func, ast.Attribute):
|
||||
func_name = node.func.attr
|
||||
if func_name == name:
|
||||
positions.append(
|
||||
CodePosition(
|
||||
line_no=node.lineno,
|
||||
col_no=node.col_offset,
|
||||
)
|
||||
)
|
||||
func = make_function(
|
||||
name,
|
||||
"module.py",
|
||||
parents=parents,
|
||||
is_async=True,
|
||||
)
|
||||
transformer = AsyncCallInstrumenter(
|
||||
function=func,
|
||||
module_path="module",
|
||||
call_positions=positions,
|
||||
)
|
||||
return transformer, tree
|
||||
|
||||
def test_instruments_await_call(self) -> None:
|
||||
"""Adds env var assignment for an async function call."""
|
||||
"""Adds env var assignment before an awaited target call."""
|
||||
code = textwrap.dedent("""\
|
||||
async def test_it():
|
||||
result = await target_func(1, 2)
|
||||
""")
|
||||
tree = ast.parse(code)
|
||||
call_node = (
|
||||
tree.body[0].body[0].value.value # type: ignore[attr-defined]
|
||||
)
|
||||
pos = CodePosition(
|
||||
line_no=call_node.lineno,
|
||||
col_no=call_node.col_offset,
|
||||
)
|
||||
func = make_function("target_func", "module.py", is_async=True)
|
||||
transformer = AsyncCallInstrumenter(
|
||||
function=func,
|
||||
module_path="module",
|
||||
call_positions=[pos],
|
||||
mode=TestingMode.BEHAVIOR,
|
||||
)
|
||||
transformer, tree = self._make_transformer(code)
|
||||
new_tree = transformer.visit(tree)
|
||||
source = ast.unparse(new_tree)
|
||||
assert "os.environ" in source or "CODEFLASH" in source
|
||||
assert "CODEFLASH_CURRENT_LINE_ID" in source
|
||||
assert transformer.did_instrument is True
|
||||
|
||||
def test_skips_non_test_async_functions(self) -> None:
|
||||
"""Does not instrument async functions that don't start with test_."""
|
||||
code = textwrap.dedent("""\
|
||||
async def helper():
|
||||
result = await target_func(1, 2)
|
||||
""")
|
||||
transformer, tree = self._make_transformer(code)
|
||||
new_tree = transformer.visit(tree)
|
||||
source = ast.unparse(new_tree)
|
||||
assert "CODEFLASH_CURRENT_LINE_ID" not in source
|
||||
assert transformer.did_instrument is False
|
||||
|
||||
def test_skips_non_test_sync_functions(self) -> None:
|
||||
"""Does not instrument sync functions that don't start with test_."""
|
||||
code = textwrap.dedent("""\
|
||||
def helper():
|
||||
result = await target_func(1, 2)
|
||||
""")
|
||||
transformer, tree = self._make_transformer(code)
|
||||
new_tree = transformer.visit(tree)
|
||||
source = ast.unparse(new_tree)
|
||||
assert "CODEFLASH_CURRENT_LINE_ID" not in source
|
||||
|
||||
def test_instruments_sync_test_with_await(self) -> None:
|
||||
"""Instruments sync test_ functions that contain awaited calls."""
|
||||
code = textwrap.dedent("""\
|
||||
def test_it():
|
||||
result = await target_func(1, 2)
|
||||
""")
|
||||
transformer, tree = self._make_transformer(code)
|
||||
new_tree = transformer.visit(tree)
|
||||
source = ast.unparse(new_tree)
|
||||
assert "CODEFLASH_CURRENT_LINE_ID" in source
|
||||
assert transformer.did_instrument is True
|
||||
|
||||
def test_multiple_awaits_get_incrementing_ids(self) -> None:
|
||||
"""Each awaited target call gets a unique incrementing counter."""
|
||||
code = textwrap.dedent("""\
|
||||
async def test_it():
|
||||
a = await target_func(1)
|
||||
b = await target_func(2)
|
||||
c = await target_func(3)
|
||||
""")
|
||||
transformer, tree = self._make_transformer(code)
|
||||
new_tree = transformer.visit(tree)
|
||||
source = ast.unparse(new_tree)
|
||||
assert "'0'" in source
|
||||
assert "'1'" in source
|
||||
assert "'2'" in source
|
||||
|
||||
def test_attribute_style_call(self) -> None:
|
||||
"""Instruments await obj.target_func() attribute-style calls."""
|
||||
code = textwrap.dedent("""\
|
||||
async def test_it():
|
||||
result = await obj.target_func(1, 2)
|
||||
""")
|
||||
transformer, tree = self._make_transformer(code)
|
||||
new_tree = transformer.visit(tree)
|
||||
source = ast.unparse(new_tree)
|
||||
assert "CODEFLASH_CURRENT_LINE_ID" in source
|
||||
assert transformer.did_instrument is True
|
||||
|
||||
def test_recurses_into_class_body(self) -> None:
|
||||
"""Finds and instruments test methods inside a class."""
|
||||
code = textwrap.dedent("""\
|
||||
class TestSuite:
|
||||
async def test_it(self):
|
||||
result = await target_func(1)
|
||||
""")
|
||||
transformer, tree = self._make_transformer(code)
|
||||
new_tree = transformer.visit(tree)
|
||||
source = ast.unparse(new_tree)
|
||||
assert "CODEFLASH_CURRENT_LINE_ID" in source
|
||||
assert transformer.did_instrument is True
|
||||
|
||||
def test_no_match_when_position_wrong(self) -> None:
|
||||
"""Does not instrument when call positions don't match."""
|
||||
code = textwrap.dedent("""\
|
||||
async def test_it():
|
||||
result = await target_func(1, 2)
|
||||
""")
|
||||
transformer, tree = self._make_transformer(
|
||||
code,
|
||||
positions=[CodePosition(line_no=99, col_no=99)],
|
||||
)
|
||||
new_tree = transformer.visit(tree)
|
||||
assert transformer.did_instrument is False
|
||||
|
||||
def test_nested_await_in_conditional(self) -> None:
|
||||
"""Finds awaited target calls nested inside if statements."""
|
||||
code = textwrap.dedent("""\
|
||||
async def test_it():
|
||||
if True:
|
||||
result = await target_func(1)
|
||||
""")
|
||||
transformer, tree = self._make_transformer(code)
|
||||
new_tree = transformer.visit(tree)
|
||||
source = ast.unparse(new_tree)
|
||||
assert "CODEFLASH_CURRENT_LINE_ID" in source
|
||||
|
||||
def test_ignores_non_target_awaits(self) -> None:
|
||||
"""Does not instrument awaits of unrelated functions."""
|
||||
code = textwrap.dedent("""\
|
||||
async def test_it():
|
||||
result = await other_func(1, 2)
|
||||
""")
|
||||
transformer, tree = self._make_transformer(code)
|
||||
new_tree = transformer.visit(tree)
|
||||
assert transformer.did_instrument is False
|
||||
|
||||
def test_subscript_style_call_not_matched(self) -> None:
|
||||
"""Await of a subscript-style call (funcs[0]()) is not matched."""
|
||||
code = textwrap.dedent("""\
|
||||
async def test_it():
|
||||
result = await funcs[0](1, 2)
|
||||
""")
|
||||
transformer, tree = self._make_transformer(code)
|
||||
transformer.visit(tree)
|
||||
assert transformer.did_instrument is False
|
||||
|
||||
def test_counters_independent_per_test_function(self) -> None:
|
||||
"""Each test function gets its own independent counter."""
|
||||
code = textwrap.dedent("""\
|
||||
async def test_a():
|
||||
await target_func(1)
|
||||
await target_func(2)
|
||||
async def test_b():
|
||||
await target_func(3)
|
||||
""")
|
||||
transformer, tree = self._make_transformer(code)
|
||||
transformer.visit(tree)
|
||||
assert 2 == transformer.async_call_counter["test_a"]
|
||||
assert 1 == transformer.async_call_counter["test_b"]
|
||||
|
||||
|
||||
class TestFunctionImportedAsVisitor:
|
||||
|
|
@ -549,6 +714,7 @@ class TestAsyncDecoratorAdder:
|
|||
new_tree = tree.visit(transformer)
|
||||
output = new_tree.code
|
||||
assert "@codeflash_behavior_async" in output
|
||||
assert transformer.added_decorator is True
|
||||
|
||||
def test_does_not_add_to_non_matching(self) -> None:
|
||||
"""Does not add decorator to functions that do not match."""
|
||||
|
|
@ -562,6 +728,124 @@ class TestAsyncDecoratorAdder:
|
|||
new_tree = tree.visit(transformer)
|
||||
output = new_tree.code
|
||||
assert "@" not in output
|
||||
assert transformer.added_decorator is False
|
||||
|
||||
def test_does_not_add_to_sync_function(self) -> None:
|
||||
"""Skips sync functions even if the name matches."""
|
||||
code = textwrap.dedent("""\
|
||||
def target_func():
|
||||
pass
|
||||
""")
|
||||
tree = cst.parse_module(code)
|
||||
func = make_function("target_func", "module.py", is_async=True)
|
||||
transformer = AsyncDecoratorAdder(func, mode=TestingMode.BEHAVIOR)
|
||||
new_tree = tree.visit(transformer)
|
||||
assert "@" not in new_tree.code
|
||||
assert transformer.added_decorator is False
|
||||
|
||||
def test_performance_mode_decorator(self) -> None:
|
||||
"""Uses codeflash_performance_async for PERFORMANCE mode."""
|
||||
code = "async def target_func():\n pass\n"
|
||||
tree = cst.parse_module(code)
|
||||
func = make_function("target_func", "module.py", is_async=True)
|
||||
transformer = AsyncDecoratorAdder(
|
||||
func,
|
||||
mode=TestingMode.PERFORMANCE,
|
||||
)
|
||||
new_tree = tree.visit(transformer)
|
||||
assert "@codeflash_performance_async" in new_tree.code
|
||||
|
||||
def test_concurrency_mode_decorator(self) -> None:
|
||||
"""Uses codeflash_concurrency_async for CONCURRENCY mode."""
|
||||
code = "async def target_func():\n pass\n"
|
||||
tree = cst.parse_module(code)
|
||||
func = make_function("target_func", "module.py", is_async=True)
|
||||
transformer = AsyncDecoratorAdder(
|
||||
func,
|
||||
mode=TestingMode.CONCURRENCY,
|
||||
)
|
||||
new_tree = tree.visit(transformer)
|
||||
assert "@codeflash_concurrency_async" in new_tree.code
|
||||
|
||||
def test_class_method_matching(self) -> None:
|
||||
"""Matches async method inside a class via qualified name."""
|
||||
code = textwrap.dedent("""\
|
||||
class MyClass:
|
||||
async def target_func(self):
|
||||
pass
|
||||
""")
|
||||
tree = cst.parse_module(code)
|
||||
parent = FunctionParent(name="MyClass", type="ClassDef")
|
||||
func = make_function(
|
||||
"target_func",
|
||||
"module.py",
|
||||
parents=(parent,),
|
||||
is_async=True,
|
||||
)
|
||||
transformer = AsyncDecoratorAdder(func, mode=TestingMode.BEHAVIOR)
|
||||
new_tree = tree.visit(transformer)
|
||||
assert "@codeflash_behavior_async" in new_tree.code
|
||||
assert transformer.added_decorator is True
|
||||
|
||||
def test_no_duplicate_when_already_decorated(self) -> None:
|
||||
"""Does not add a second decorator when one is already present."""
|
||||
code = textwrap.dedent("""\
|
||||
@codeflash_behavior_async
|
||||
async def target_func():
|
||||
pass
|
||||
""")
|
||||
tree = cst.parse_module(code)
|
||||
func = make_function("target_func", "module.py", is_async=True)
|
||||
transformer = AsyncDecoratorAdder(func, mode=TestingMode.BEHAVIOR)
|
||||
new_tree = tree.visit(transformer)
|
||||
assert new_tree.code.count("codeflash_behavior_async") == 1
|
||||
assert transformer.added_decorator is False
|
||||
|
||||
def test_no_duplicate_when_call_style_decorator(self) -> None:
|
||||
"""Detects existing decorator even in @decorator() call form."""
|
||||
code = textwrap.dedent("""\
|
||||
@codeflash_behavior_async()
|
||||
async def target_func():
|
||||
pass
|
||||
""")
|
||||
tree = cst.parse_module(code)
|
||||
func = make_function("target_func", "module.py", is_async=True)
|
||||
transformer = AsyncDecoratorAdder(func, mode=TestingMode.BEHAVIOR)
|
||||
new_tree = tree.visit(transformer)
|
||||
assert new_tree.code.count("codeflash_behavior_async") == 1
|
||||
assert transformer.added_decorator is False
|
||||
|
||||
def test_preserves_existing_decorators(self) -> None:
|
||||
"""Keeps existing decorators and prepends the codeflash one."""
|
||||
code = textwrap.dedent("""\
|
||||
@staticmethod
|
||||
async def target_func():
|
||||
pass
|
||||
""")
|
||||
tree = cst.parse_module(code)
|
||||
func = make_function("target_func", "module.py", is_async=True)
|
||||
transformer = AsyncDecoratorAdder(func, mode=TestingMode.BEHAVIOR)
|
||||
new_tree = tree.visit(transformer)
|
||||
output = new_tree.code
|
||||
assert "@codeflash_behavior_async" in output
|
||||
assert "@staticmethod" in output
|
||||
behavior_pos = output.index("@codeflash_behavior_async")
|
||||
static_pos = output.index("@staticmethod")
|
||||
assert behavior_pos < static_pos
|
||||
|
||||
def test_attribute_decorator_not_matched(self) -> None:
|
||||
"""Attribute-style decorators (mod.decorator) are not codeflash."""
|
||||
code = textwrap.dedent("""\
|
||||
@mod.codeflash_behavior_async
|
||||
async def target_func():
|
||||
pass
|
||||
""")
|
||||
tree = cst.parse_module(code)
|
||||
func = make_function("target_func", "module.py", is_async=True)
|
||||
transformer = AsyncDecoratorAdder(func, mode=TestingMode.BEHAVIOR)
|
||||
new_tree = tree.visit(transformer)
|
||||
assert new_tree.code.count("codeflash_behavior_async") == 2
|
||||
assert transformer.added_decorator is True
|
||||
|
||||
|
||||
class TestWriteAsyncHelperFile:
|
||||
|
|
@ -578,11 +862,20 @@ class TestWriteAsyncHelperFile:
|
|||
result = write_async_helper_file(tmp_path)
|
||||
assert ASYNC_HELPER_FILENAME == result.name
|
||||
|
||||
def test_file_content(self, tmp_path: Path) -> None:
|
||||
"""The file contains the inline code constant."""
|
||||
def test_file_content_is_valid_python(self, tmp_path: Path) -> None:
|
||||
"""The copied file is valid, parseable Python."""
|
||||
result = write_async_helper_file(tmp_path)
|
||||
content = result.read_text()
|
||||
assert len(content) > 0
|
||||
ast.parse(content)
|
||||
|
||||
def test_idempotent_does_not_overwrite(self, tmp_path: Path) -> None:
|
||||
"""Calling twice does not overwrite the existing file."""
|
||||
first = write_async_helper_file(tmp_path)
|
||||
first.write_text("sentinel", encoding="utf-8")
|
||||
second = write_async_helper_file(tmp_path)
|
||||
assert first == second
|
||||
assert "sentinel" == second.read_text()
|
||||
|
||||
|
||||
class TestAsyncHelperConstants:
|
||||
|
|
@ -784,7 +1077,6 @@ class TestInjectAsyncProfilingIntoExistingTest:
|
|||
test_file.write_text(test_code, encoding="utf-8")
|
||||
|
||||
func = make_function("target_func", "module.py", is_async=True)
|
||||
# await target_func(1, 2) — the Call node is at col 25
|
||||
positions = [CodePosition(line_no=4, col_no=25)]
|
||||
|
||||
ok, source = inject_async_profiling_into_existing_test(
|
||||
|
|
@ -820,6 +1112,57 @@ class TestInjectAsyncProfilingIntoExistingTest:
|
|||
assert ok is False
|
||||
assert source is None
|
||||
|
||||
def test_syntax_error_returns_false(self, tmp_path: Path) -> None:
|
||||
"""Returns (False, None) for a file with invalid Python."""
|
||||
project_root = tmp_path / "project"
|
||||
project_root.mkdir()
|
||||
test_file = project_root / "test_bad.py"
|
||||
test_file.write_text(
|
||||
"async def test_x(\n not valid !!!",
|
||||
encoding="utf-8",
|
||||
)
|
||||
func = make_function("target_func", "module.py", is_async=True)
|
||||
ok, source = inject_async_profiling_into_existing_test(
|
||||
test_file,
|
||||
[CodePosition(line_no=1, col_no=0)],
|
||||
func,
|
||||
project_root,
|
||||
)
|
||||
assert ok is False
|
||||
assert source is None
|
||||
|
||||
def test_multiple_awaits_get_sequential_ids(
|
||||
self,
|
||||
tmp_path: Path,
|
||||
) -> None:
|
||||
"""Multiple awaited calls in one test get sequential counter IDs."""
|
||||
project_root = tmp_path / "project"
|
||||
project_root.mkdir()
|
||||
test_file = project_root / "test_multi.py"
|
||||
test_code = textwrap.dedent("""\
|
||||
from module import target_func
|
||||
|
||||
async def test_multi():
|
||||
a = await target_func(1)
|
||||
b = await target_func(2)
|
||||
""")
|
||||
test_file.write_text(test_code, encoding="utf-8")
|
||||
|
||||
func = make_function("target_func", "module.py", is_async=True)
|
||||
positions = [
|
||||
CodePosition(line_no=4, col_no=22),
|
||||
CodePosition(line_no=5, col_no=22),
|
||||
]
|
||||
ok, source = inject_async_profiling_into_existing_test(
|
||||
test_file,
|
||||
positions,
|
||||
func,
|
||||
project_root,
|
||||
)
|
||||
assert ok is True
|
||||
assert source is not None
|
||||
assert source.count("CODEFLASH_CURRENT_LINE_ID") == 2
|
||||
|
||||
|
||||
class TestAddAsyncDecoratorToFunction:
|
||||
"""add_async_decorator_to_function source rewriting."""
|
||||
|
|
@ -832,11 +1175,13 @@ class TestAddAsyncDecoratorToFunction:
|
|||
encoding="utf-8",
|
||||
)
|
||||
func = make_function("target_func", str(source_file))
|
||||
result, _ = add_async_decorator_to_function(
|
||||
source_file, func, TestingMode.BEHAVIOR
|
||||
result, originals = add_async_decorator_to_function(
|
||||
source_file,
|
||||
func,
|
||||
TestingMode.BEHAVIOR,
|
||||
)
|
||||
assert result is False
|
||||
# File unchanged
|
||||
assert {} == originals
|
||||
assert "decorator" not in source_file.read_text()
|
||||
|
||||
def test_async_function_gets_decorator(self, tmp_path: Path) -> None:
|
||||
|
|
@ -853,18 +1198,41 @@ class TestAddAsyncDecoratorToFunction:
|
|||
str(source_file),
|
||||
is_async=True,
|
||||
)
|
||||
result, _ = add_async_decorator_to_function(
|
||||
source_file, func, TestingMode.BEHAVIOR
|
||||
result, originals = add_async_decorator_to_function(
|
||||
source_file,
|
||||
func,
|
||||
TestingMode.BEHAVIOR,
|
||||
)
|
||||
assert result is True
|
||||
|
||||
modified = source_file.read_text()
|
||||
assert "codeflash_behavior_async" in modified
|
||||
|
||||
# Helper file created in source dir (no project_root given)
|
||||
helper = tmp_path / ASYNC_HELPER_FILENAME
|
||||
assert helper.exists()
|
||||
|
||||
def test_originals_contains_pre_modification_source(
|
||||
self,
|
||||
tmp_path: Path,
|
||||
) -> None:
|
||||
"""The originals dict maps the file to its content before rewriting."""
|
||||
source_file = tmp_path / "module.py"
|
||||
original_code = "async def target_func():\n pass\n"
|
||||
source_file.write_text(original_code, encoding="utf-8")
|
||||
|
||||
func = make_function(
|
||||
"target_func",
|
||||
str(source_file),
|
||||
is_async=True,
|
||||
)
|
||||
_, originals = add_async_decorator_to_function(
|
||||
source_file,
|
||||
func,
|
||||
TestingMode.BEHAVIOR,
|
||||
)
|
||||
assert source_file in originals
|
||||
assert original_code == originals[source_file]
|
||||
|
||||
def test_with_explicit_project_root(self, tmp_path: Path) -> None:
|
||||
"""Writes helper file to project_root when specified."""
|
||||
src_dir = tmp_path / "src"
|
||||
|
|
@ -891,8 +1259,72 @@ class TestAddAsyncDecoratorToFunction:
|
|||
project_root=project_root,
|
||||
)
|
||||
assert result is True
|
||||
# Helper in project_root, not in src_dir
|
||||
assert (project_root / ASYNC_HELPER_FILENAME).exists()
|
||||
assert not (src_dir / ASYNC_HELPER_FILENAME).exists()
|
||||
|
||||
def test_already_decorated_returns_false(self, tmp_path: Path) -> None:
|
||||
"""Returns False when the function already has the decorator."""
|
||||
source_file = tmp_path / "module.py"
|
||||
source_code = textwrap.dedent("""\
|
||||
@codeflash_behavior_async
|
||||
async def target_func():
|
||||
pass
|
||||
""")
|
||||
source_file.write_text(source_code, encoding="utf-8")
|
||||
|
||||
func = make_function(
|
||||
"target_func",
|
||||
str(source_file),
|
||||
is_async=True,
|
||||
)
|
||||
result, originals = add_async_decorator_to_function(
|
||||
source_file,
|
||||
func,
|
||||
TestingMode.BEHAVIOR,
|
||||
)
|
||||
assert result is False
|
||||
assert {} == originals
|
||||
|
||||
def test_adds_import_for_decorator(self, tmp_path: Path) -> None:
|
||||
"""The rewritten file includes the import for the decorator."""
|
||||
source_file = tmp_path / "module.py"
|
||||
source_file.write_text(
|
||||
"async def target_func():\n pass\n",
|
||||
encoding="utf-8",
|
||||
)
|
||||
func = make_function(
|
||||
"target_func",
|
||||
str(source_file),
|
||||
is_async=True,
|
||||
)
|
||||
add_async_decorator_to_function(
|
||||
source_file,
|
||||
func,
|
||||
TestingMode.PERFORMANCE,
|
||||
)
|
||||
modified = source_file.read_text()
|
||||
assert "from codeflash_async_wrapper import" in modified
|
||||
assert "codeflash_performance_async" in modified
|
||||
|
||||
def test_cst_parse_error_returns_false(self, tmp_path: Path) -> None:
|
||||
"""Returns (False, {}) when the source file has invalid syntax."""
|
||||
source_file = tmp_path / "module.py"
|
||||
source_file.write_text(
|
||||
"async def target_func(\n invalid !!!",
|
||||
encoding="utf-8",
|
||||
)
|
||||
func = make_function(
|
||||
"target_func",
|
||||
str(source_file),
|
||||
is_async=True,
|
||||
)
|
||||
result, originals = add_async_decorator_to_function(
|
||||
source_file,
|
||||
func,
|
||||
TestingMode.BEHAVIOR,
|
||||
)
|
||||
assert result is False
|
||||
assert {} == originals
|
||||
|
||||
|
||||
class TestCreateInstrumentedSourceModulePath:
|
||||
|
|
|
|||
Loading…
Reference in a new issue