refactor: clean up _instrument_async and add 100% test coverage

Remove dead code (unused fields, hasattr guard, duplicate decorator
set), rename _optimized_instrument_statement to _find_awaited_target_call,
simplify AsyncDecoratorAdder init and leave_FunctionDef. Add 21 new
unit tests covering all branches: non-test skipping, attribute calls,
class body recursion, counter independence, decorator deduplication
(name and call form), error handlers, and mode selection.
This commit is contained in:
Kevin Turcios 2026-04-24 02:45:07 -05:00
parent 2fd9d06e28
commit c670d637c0
2 changed files with 520 additions and 138 deletions

View file

@ -44,19 +44,10 @@ class AsyncCallInstrumenter(ast.NodeTransformer):
mode: TestingMode = TestingMode.BEHAVIOR,
) -> None:
"""Initialize with the target async function and testing mode."""
self.mode = mode
self.function_object = function
self.class_name: str | None = None
self.only_function_name = function.function_name
self.module_path = module_path
self.call_positions = call_positions
self.did_instrument = False
self.async_call_counter: dict[str, int] = {}
if (
len(function.parents) == 1
and function.parents[0].type == "ClassDef"
):
self.class_name = function.parents[0].name
def visit_ClassDef(self, node: ast.ClassDef) -> ast.ClassDef:
"""Recurse into class bodies to find test methods."""
@ -90,14 +81,10 @@ class AsyncCallInstrumenter(ast.NodeTransformer):
new_body: list[ast.stmt] = []
# Scan only relevant nodes instead of
# full ast.walk in _instrument_statement
for _i, stmt in enumerate(node.body):
transformed_stmt, added_env_assignment = (
self._optimized_instrument_statement(stmt)
)
for stmt in node.body:
_, has_target = self._find_awaited_target_call(stmt)
if added_env_assignment:
if has_target:
current_call_index = self.async_call_counter[node.name]
self.async_call_counter[node.name] += 1
@ -116,12 +103,12 @@ class AsyncCallInstrumenter(ast.NodeTransformer):
)
],
value=ast.Constant(value=f"{current_call_index}"),
lineno=stmt.lineno if hasattr(stmt, "lineno") else 1,
lineno=stmt.lineno,
)
new_body.append(env_assignment)
self.did_instrument = True
new_body.append(transformed_stmt)
new_body.append(stmt)
node.body = new_body
return node
@ -134,35 +121,23 @@ class AsyncCallInstrumenter(ast.NodeTransformer):
return call_node.func.attr == self.function_object.function_name
return False
def _call_in_positions(self, call_node: ast.Call) -> bool:
"""Return True if the call node is at one of the tracked positions."""
if not hasattr(call_node, "lineno") or not hasattr(
call_node, "col_offset"
):
return False
return node_in_call_position(call_node, self.call_positions)
# Optimized version: only walk child nodes for Await
def _optimized_instrument_statement(
self, stmt: ast.stmt
def _find_awaited_target_call(
self,
stmt: ast.stmt,
) -> tuple[ast.stmt, bool]:
"""Stack-based search for awaited target calls in a statement."""
# Stack-based DFS, manual for relevant Await nodes
"""Search a statement for awaited calls to the target function."""
stack: list[ast.AST] = [stmt]
while stack:
node = stack.pop()
# Favor direct ast.Await detection
if isinstance(node, ast.Await):
val = node.value
if (
isinstance(val, ast.Call)
and self._is_target_call(val)
and self._call_in_positions(val)
and node_in_call_position(val, self.call_positions)
):
return stmt, True
# Use _fields instead of ast.walk for less allocations
for fname in getattr(node, "_fields", ()):
for fname in node._fields:
child = getattr(node, fname, None)
if isinstance(child, list):
stack.extend(child)
@ -171,6 +146,16 @@ class AsyncCallInstrumenter(ast.NodeTransformer):
return stmt, False
_CODEFLASH_ASYNC_DECORATORS = frozenset(
{
"codeflash_trace_async",
"codeflash_behavior_async",
"codeflash_performance_async",
"codeflash_concurrency_async",
}
)
class AsyncDecoratorAdder(cst.CSTTransformer):
"""Transformer that adds async decorator to async function definitions."""
@ -179,34 +164,18 @@ class AsyncDecoratorAdder(cst.CSTTransformer):
function: FunctionToOptimize,
mode: TestingMode = TestingMode.BEHAVIOR,
) -> None:
"""Initialize the transformer.
Args:
----
function: Target async function.
mode: Testing mode for decorator.
"""
"""Initialize the transformer."""
super().__init__()
self.function = function
self.mode = mode
self.qualified_name_parts = function.qualified_name.split(".")
self.context_stack: list[str] = []
self.added_decorator = False
# Choose decorator based on mode
if mode == TestingMode.BEHAVIOR:
self.decorator_name = "codeflash_behavior_async"
elif mode == TestingMode.CONCURRENCY:
self.decorator_name = "codeflash_concurrency_async"
else:
self.decorator_name = "codeflash_performance_async"
self.decorator_name = get_decorator_name_for_mode(mode)
def visit_ClassDef( # noqa: N802
self, node: cst.ClassDef
self,
node: cst.ClassDef,
) -> None:
"""Push class name onto the context stack."""
# Track when we enter a class
self.context_stack.append(node.name.value)
def leave_ClassDef( # noqa: N802
@ -215,15 +184,14 @@ class AsyncDecoratorAdder(cst.CSTTransformer):
updated_node: cst.ClassDef,
) -> cst.ClassDef:
"""Pop class name from the context stack."""
# Pop the context when we leave a class
self.context_stack.pop()
return updated_node
def visit_FunctionDef( # noqa: N802
self, node: cst.FunctionDef
self,
node: cst.FunctionDef,
) -> None:
"""Push function name onto the context stack."""
# Track when we enter a function
self.context_stack.append(node.name.value)
def leave_FunctionDef( # noqa: N802
@ -232,55 +200,37 @@ class AsyncDecoratorAdder(cst.CSTTransformer):
updated_node: cst.FunctionDef,
) -> cst.FunctionDef:
"""Add the async decorator if the function matches the target."""
# Check if this is an async function and matches our target
if (
original_node.asynchronous is not None
and self.context_stack == self.qualified_name_parts
):
# Check if the decorator is already present
has_decorator = any(
self._is_target_decorator(decorator.decorator)
for decorator in original_node.decorators
and not any(
self._is_codeflash_decorator(d.decorator)
for d in original_node.decorators
)
):
new_decorator = cst.Decorator(
decorator=cst.Name(value=self.decorator_name),
)
updated_node = updated_node.with_changes(
decorators=(new_decorator, *updated_node.decorators),
)
self.added_decorator = True
# Only add the decorator if it's not already there
if not has_decorator:
new_decorator = cst.Decorator(
decorator=cst.Name(value=self.decorator_name)
)
# Add our new decorator to the existing decorators
updated_decorators = [
new_decorator,
*list(updated_node.decorators),
]
updated_node = updated_node.with_changes(
decorators=tuple(updated_decorators)
)
self.added_decorator = True
# Pop the context when we leave a function
self.context_stack.pop()
return updated_node
def _is_target_decorator(self, decorator_node: cst.BaseExpression) -> bool:
"""Check if a decorator matches our target decorator name."""
@staticmethod
def _is_codeflash_decorator(
decorator_node: cst.BaseExpression,
) -> bool:
"""Check if a decorator is one of the codeflash async decorators."""
if isinstance(decorator_node, cst.Name):
return decorator_node.value in {
"codeflash_trace_async",
"codeflash_behavior_async",
"codeflash_performance_async",
"codeflash_concurrency_async",
}
return decorator_node.value in _CODEFLASH_ASYNC_DECORATORS
if isinstance(decorator_node, cst.Call) and isinstance(
decorator_node.func, cst.Name
decorator_node.func,
cst.Name,
):
return decorator_node.func.value in {
"codeflash_trace_async",
"codeflash_behavior_async",
"codeflash_performance_async",
"codeflash_concurrency_async",
}
return decorator_node.func.value in _CODEFLASH_ASYNC_DECORATORS
return False

View file

@ -248,23 +248,28 @@ class TestGetDecoratorNameForMode:
"""get_decorator_name_for_mode decorator selection."""
def test_behavior_mode(self) -> None:
"""Returns correct decorator for BEHAVIOR mode."""
name = get_decorator_name_for_mode(TestingMode.BEHAVIOR)
assert isinstance(name, str)
assert len(name) > 0
"""Returns codeflash_behavior_async for BEHAVIOR."""
assert "codeflash_behavior_async" == get_decorator_name_for_mode(
TestingMode.BEHAVIOR,
)
def test_performance_mode(self) -> None:
"""Returns correct decorator for PERFORMANCE mode."""
name = get_decorator_name_for_mode(TestingMode.PERFORMANCE)
assert isinstance(name, str)
assert len(name) > 0
"""Returns codeflash_performance_async for PERFORMANCE."""
assert "codeflash_performance_async" == get_decorator_name_for_mode(
TestingMode.PERFORMANCE,
)
def test_all_modes_return_strings(self) -> None:
"""All modes return non-empty string decorator names."""
for mode in TestingMode:
name = get_decorator_name_for_mode(mode)
assert isinstance(name, str)
assert len(name) > 0
def test_concurrency_mode(self) -> None:
"""Returns codeflash_concurrency_async for CONCURRENCY."""
assert "codeflash_concurrency_async" == get_decorator_name_for_mode(
TestingMode.CONCURRENCY,
)
def test_line_profile_falls_through_to_performance(self) -> None:
"""LINE_PROFILE is not async-specific, falls through to performance."""
assert "codeflash_performance_async" == get_decorator_name_for_mode(
TestingMode.LINE_PROFILE,
)
class TestCreateDeviceSyncPrecomputeStatements:
@ -455,30 +460,190 @@ class TestInjectPerfOnly:
class TestAsyncCallInstrumenter:
"""AsyncCallInstrumenter AST transformer."""
def _make_transformer(
self,
code: str,
*,
name: str = "target_func",
parents: tuple[FunctionParent, ...] = (),
positions: list[CodePosition] | None = None,
) -> tuple[AsyncCallInstrumenter, ast.Module]:
"""Parse code and build a transformer with call positions from it."""
tree = ast.parse(code)
if positions is None:
positions = []
for node in ast.walk(tree):
if isinstance(node, ast.Call):
func_name = None
if isinstance(node.func, ast.Name):
func_name = node.func.id
elif isinstance(node.func, ast.Attribute):
func_name = node.func.attr
if func_name == name:
positions.append(
CodePosition(
line_no=node.lineno,
col_no=node.col_offset,
)
)
func = make_function(
name,
"module.py",
parents=parents,
is_async=True,
)
transformer = AsyncCallInstrumenter(
function=func,
module_path="module",
call_positions=positions,
)
return transformer, tree
def test_instruments_await_call(self) -> None:
"""Adds env var assignment for an async function call."""
"""Adds env var assignment before an awaited target call."""
code = textwrap.dedent("""\
async def test_it():
result = await target_func(1, 2)
""")
tree = ast.parse(code)
call_node = (
tree.body[0].body[0].value.value # type: ignore[attr-defined]
)
pos = CodePosition(
line_no=call_node.lineno,
col_no=call_node.col_offset,
)
func = make_function("target_func", "module.py", is_async=True)
transformer = AsyncCallInstrumenter(
function=func,
module_path="module",
call_positions=[pos],
mode=TestingMode.BEHAVIOR,
)
transformer, tree = self._make_transformer(code)
new_tree = transformer.visit(tree)
source = ast.unparse(new_tree)
assert "os.environ" in source or "CODEFLASH" in source
assert "CODEFLASH_CURRENT_LINE_ID" in source
assert transformer.did_instrument is True
def test_skips_non_test_async_functions(self) -> None:
"""Does not instrument async functions that don't start with test_."""
code = textwrap.dedent("""\
async def helper():
result = await target_func(1, 2)
""")
transformer, tree = self._make_transformer(code)
new_tree = transformer.visit(tree)
source = ast.unparse(new_tree)
assert "CODEFLASH_CURRENT_LINE_ID" not in source
assert transformer.did_instrument is False
def test_skips_non_test_sync_functions(self) -> None:
"""Does not instrument sync functions that don't start with test_."""
code = textwrap.dedent("""\
def helper():
result = await target_func(1, 2)
""")
transformer, tree = self._make_transformer(code)
new_tree = transformer.visit(tree)
source = ast.unparse(new_tree)
assert "CODEFLASH_CURRENT_LINE_ID" not in source
def test_instruments_sync_test_with_await(self) -> None:
"""Instruments sync test_ functions that contain awaited calls."""
code = textwrap.dedent("""\
def test_it():
result = await target_func(1, 2)
""")
transformer, tree = self._make_transformer(code)
new_tree = transformer.visit(tree)
source = ast.unparse(new_tree)
assert "CODEFLASH_CURRENT_LINE_ID" in source
assert transformer.did_instrument is True
def test_multiple_awaits_get_incrementing_ids(self) -> None:
"""Each awaited target call gets a unique incrementing counter."""
code = textwrap.dedent("""\
async def test_it():
a = await target_func(1)
b = await target_func(2)
c = await target_func(3)
""")
transformer, tree = self._make_transformer(code)
new_tree = transformer.visit(tree)
source = ast.unparse(new_tree)
assert "'0'" in source
assert "'1'" in source
assert "'2'" in source
def test_attribute_style_call(self) -> None:
"""Instruments await obj.target_func() attribute-style calls."""
code = textwrap.dedent("""\
async def test_it():
result = await obj.target_func(1, 2)
""")
transformer, tree = self._make_transformer(code)
new_tree = transformer.visit(tree)
source = ast.unparse(new_tree)
assert "CODEFLASH_CURRENT_LINE_ID" in source
assert transformer.did_instrument is True
def test_recurses_into_class_body(self) -> None:
"""Finds and instruments test methods inside a class."""
code = textwrap.dedent("""\
class TestSuite:
async def test_it(self):
result = await target_func(1)
""")
transformer, tree = self._make_transformer(code)
new_tree = transformer.visit(tree)
source = ast.unparse(new_tree)
assert "CODEFLASH_CURRENT_LINE_ID" in source
assert transformer.did_instrument is True
def test_no_match_when_position_wrong(self) -> None:
"""Does not instrument when call positions don't match."""
code = textwrap.dedent("""\
async def test_it():
result = await target_func(1, 2)
""")
transformer, tree = self._make_transformer(
code,
positions=[CodePosition(line_no=99, col_no=99)],
)
new_tree = transformer.visit(tree)
assert transformer.did_instrument is False
def test_nested_await_in_conditional(self) -> None:
"""Finds awaited target calls nested inside if statements."""
code = textwrap.dedent("""\
async def test_it():
if True:
result = await target_func(1)
""")
transformer, tree = self._make_transformer(code)
new_tree = transformer.visit(tree)
source = ast.unparse(new_tree)
assert "CODEFLASH_CURRENT_LINE_ID" in source
def test_ignores_non_target_awaits(self) -> None:
"""Does not instrument awaits of unrelated functions."""
code = textwrap.dedent("""\
async def test_it():
result = await other_func(1, 2)
""")
transformer, tree = self._make_transformer(code)
new_tree = transformer.visit(tree)
assert transformer.did_instrument is False
def test_subscript_style_call_not_matched(self) -> None:
"""Await of a subscript-style call (funcs[0]()) is not matched."""
code = textwrap.dedent("""\
async def test_it():
result = await funcs[0](1, 2)
""")
transformer, tree = self._make_transformer(code)
transformer.visit(tree)
assert transformer.did_instrument is False
def test_counters_independent_per_test_function(self) -> None:
"""Each test function gets its own independent counter."""
code = textwrap.dedent("""\
async def test_a():
await target_func(1)
await target_func(2)
async def test_b():
await target_func(3)
""")
transformer, tree = self._make_transformer(code)
transformer.visit(tree)
assert 2 == transformer.async_call_counter["test_a"]
assert 1 == transformer.async_call_counter["test_b"]
class TestFunctionImportedAsVisitor:
@ -549,6 +714,7 @@ class TestAsyncDecoratorAdder:
new_tree = tree.visit(transformer)
output = new_tree.code
assert "@codeflash_behavior_async" in output
assert transformer.added_decorator is True
def test_does_not_add_to_non_matching(self) -> None:
"""Does not add decorator to functions that do not match."""
@ -562,6 +728,124 @@ class TestAsyncDecoratorAdder:
new_tree = tree.visit(transformer)
output = new_tree.code
assert "@" not in output
assert transformer.added_decorator is False
def test_does_not_add_to_sync_function(self) -> None:
"""Skips sync functions even if the name matches."""
code = textwrap.dedent("""\
def target_func():
pass
""")
tree = cst.parse_module(code)
func = make_function("target_func", "module.py", is_async=True)
transformer = AsyncDecoratorAdder(func, mode=TestingMode.BEHAVIOR)
new_tree = tree.visit(transformer)
assert "@" not in new_tree.code
assert transformer.added_decorator is False
def test_performance_mode_decorator(self) -> None:
"""Uses codeflash_performance_async for PERFORMANCE mode."""
code = "async def target_func():\n pass\n"
tree = cst.parse_module(code)
func = make_function("target_func", "module.py", is_async=True)
transformer = AsyncDecoratorAdder(
func,
mode=TestingMode.PERFORMANCE,
)
new_tree = tree.visit(transformer)
assert "@codeflash_performance_async" in new_tree.code
def test_concurrency_mode_decorator(self) -> None:
"""Uses codeflash_concurrency_async for CONCURRENCY mode."""
code = "async def target_func():\n pass\n"
tree = cst.parse_module(code)
func = make_function("target_func", "module.py", is_async=True)
transformer = AsyncDecoratorAdder(
func,
mode=TestingMode.CONCURRENCY,
)
new_tree = tree.visit(transformer)
assert "@codeflash_concurrency_async" in new_tree.code
def test_class_method_matching(self) -> None:
"""Matches async method inside a class via qualified name."""
code = textwrap.dedent("""\
class MyClass:
async def target_func(self):
pass
""")
tree = cst.parse_module(code)
parent = FunctionParent(name="MyClass", type="ClassDef")
func = make_function(
"target_func",
"module.py",
parents=(parent,),
is_async=True,
)
transformer = AsyncDecoratorAdder(func, mode=TestingMode.BEHAVIOR)
new_tree = tree.visit(transformer)
assert "@codeflash_behavior_async" in new_tree.code
assert transformer.added_decorator is True
def test_no_duplicate_when_already_decorated(self) -> None:
"""Does not add a second decorator when one is already present."""
code = textwrap.dedent("""\
@codeflash_behavior_async
async def target_func():
pass
""")
tree = cst.parse_module(code)
func = make_function("target_func", "module.py", is_async=True)
transformer = AsyncDecoratorAdder(func, mode=TestingMode.BEHAVIOR)
new_tree = tree.visit(transformer)
assert new_tree.code.count("codeflash_behavior_async") == 1
assert transformer.added_decorator is False
def test_no_duplicate_when_call_style_decorator(self) -> None:
"""Detects existing decorator even in @decorator() call form."""
code = textwrap.dedent("""\
@codeflash_behavior_async()
async def target_func():
pass
""")
tree = cst.parse_module(code)
func = make_function("target_func", "module.py", is_async=True)
transformer = AsyncDecoratorAdder(func, mode=TestingMode.BEHAVIOR)
new_tree = tree.visit(transformer)
assert new_tree.code.count("codeflash_behavior_async") == 1
assert transformer.added_decorator is False
def test_preserves_existing_decorators(self) -> None:
"""Keeps existing decorators and prepends the codeflash one."""
code = textwrap.dedent("""\
@staticmethod
async def target_func():
pass
""")
tree = cst.parse_module(code)
func = make_function("target_func", "module.py", is_async=True)
transformer = AsyncDecoratorAdder(func, mode=TestingMode.BEHAVIOR)
new_tree = tree.visit(transformer)
output = new_tree.code
assert "@codeflash_behavior_async" in output
assert "@staticmethod" in output
behavior_pos = output.index("@codeflash_behavior_async")
static_pos = output.index("@staticmethod")
assert behavior_pos < static_pos
def test_attribute_decorator_not_matched(self) -> None:
"""Attribute-style decorators (mod.decorator) are not codeflash."""
code = textwrap.dedent("""\
@mod.codeflash_behavior_async
async def target_func():
pass
""")
tree = cst.parse_module(code)
func = make_function("target_func", "module.py", is_async=True)
transformer = AsyncDecoratorAdder(func, mode=TestingMode.BEHAVIOR)
new_tree = tree.visit(transformer)
assert new_tree.code.count("codeflash_behavior_async") == 2
assert transformer.added_decorator is True
class TestWriteAsyncHelperFile:
@ -578,11 +862,20 @@ class TestWriteAsyncHelperFile:
result = write_async_helper_file(tmp_path)
assert ASYNC_HELPER_FILENAME == result.name
def test_file_content(self, tmp_path: Path) -> None:
"""The file contains the inline code constant."""
def test_file_content_is_valid_python(self, tmp_path: Path) -> None:
"""The copied file is valid, parseable Python."""
result = write_async_helper_file(tmp_path)
content = result.read_text()
assert len(content) > 0
ast.parse(content)
def test_idempotent_does_not_overwrite(self, tmp_path: Path) -> None:
"""Calling twice does not overwrite the existing file."""
first = write_async_helper_file(tmp_path)
first.write_text("sentinel", encoding="utf-8")
second = write_async_helper_file(tmp_path)
assert first == second
assert "sentinel" == second.read_text()
class TestAsyncHelperConstants:
@ -784,7 +1077,6 @@ class TestInjectAsyncProfilingIntoExistingTest:
test_file.write_text(test_code, encoding="utf-8")
func = make_function("target_func", "module.py", is_async=True)
# await target_func(1, 2) — the Call node is at col 25
positions = [CodePosition(line_no=4, col_no=25)]
ok, source = inject_async_profiling_into_existing_test(
@ -820,6 +1112,57 @@ class TestInjectAsyncProfilingIntoExistingTest:
assert ok is False
assert source is None
def test_syntax_error_returns_false(self, tmp_path: Path) -> None:
"""Returns (False, None) for a file with invalid Python."""
project_root = tmp_path / "project"
project_root.mkdir()
test_file = project_root / "test_bad.py"
test_file.write_text(
"async def test_x(\n not valid !!!",
encoding="utf-8",
)
func = make_function("target_func", "module.py", is_async=True)
ok, source = inject_async_profiling_into_existing_test(
test_file,
[CodePosition(line_no=1, col_no=0)],
func,
project_root,
)
assert ok is False
assert source is None
def test_multiple_awaits_get_sequential_ids(
self,
tmp_path: Path,
) -> None:
"""Multiple awaited calls in one test get sequential counter IDs."""
project_root = tmp_path / "project"
project_root.mkdir()
test_file = project_root / "test_multi.py"
test_code = textwrap.dedent("""\
from module import target_func
async def test_multi():
a = await target_func(1)
b = await target_func(2)
""")
test_file.write_text(test_code, encoding="utf-8")
func = make_function("target_func", "module.py", is_async=True)
positions = [
CodePosition(line_no=4, col_no=22),
CodePosition(line_no=5, col_no=22),
]
ok, source = inject_async_profiling_into_existing_test(
test_file,
positions,
func,
project_root,
)
assert ok is True
assert source is not None
assert source.count("CODEFLASH_CURRENT_LINE_ID") == 2
class TestAddAsyncDecoratorToFunction:
"""add_async_decorator_to_function source rewriting."""
@ -832,11 +1175,13 @@ class TestAddAsyncDecoratorToFunction:
encoding="utf-8",
)
func = make_function("target_func", str(source_file))
result, _ = add_async_decorator_to_function(
source_file, func, TestingMode.BEHAVIOR
result, originals = add_async_decorator_to_function(
source_file,
func,
TestingMode.BEHAVIOR,
)
assert result is False
# File unchanged
assert {} == originals
assert "decorator" not in source_file.read_text()
def test_async_function_gets_decorator(self, tmp_path: Path) -> None:
@ -853,18 +1198,41 @@ class TestAddAsyncDecoratorToFunction:
str(source_file),
is_async=True,
)
result, _ = add_async_decorator_to_function(
source_file, func, TestingMode.BEHAVIOR
result, originals = add_async_decorator_to_function(
source_file,
func,
TestingMode.BEHAVIOR,
)
assert result is True
modified = source_file.read_text()
assert "codeflash_behavior_async" in modified
# Helper file created in source dir (no project_root given)
helper = tmp_path / ASYNC_HELPER_FILENAME
assert helper.exists()
def test_originals_contains_pre_modification_source(
self,
tmp_path: Path,
) -> None:
"""The originals dict maps the file to its content before rewriting."""
source_file = tmp_path / "module.py"
original_code = "async def target_func():\n pass\n"
source_file.write_text(original_code, encoding="utf-8")
func = make_function(
"target_func",
str(source_file),
is_async=True,
)
_, originals = add_async_decorator_to_function(
source_file,
func,
TestingMode.BEHAVIOR,
)
assert source_file in originals
assert original_code == originals[source_file]
def test_with_explicit_project_root(self, tmp_path: Path) -> None:
"""Writes helper file to project_root when specified."""
src_dir = tmp_path / "src"
@ -891,8 +1259,72 @@ class TestAddAsyncDecoratorToFunction:
project_root=project_root,
)
assert result is True
# Helper in project_root, not in src_dir
assert (project_root / ASYNC_HELPER_FILENAME).exists()
assert not (src_dir / ASYNC_HELPER_FILENAME).exists()
def test_already_decorated_returns_false(self, tmp_path: Path) -> None:
"""Returns False when the function already has the decorator."""
source_file = tmp_path / "module.py"
source_code = textwrap.dedent("""\
@codeflash_behavior_async
async def target_func():
pass
""")
source_file.write_text(source_code, encoding="utf-8")
func = make_function(
"target_func",
str(source_file),
is_async=True,
)
result, originals = add_async_decorator_to_function(
source_file,
func,
TestingMode.BEHAVIOR,
)
assert result is False
assert {} == originals
def test_adds_import_for_decorator(self, tmp_path: Path) -> None:
"""The rewritten file includes the import for the decorator."""
source_file = tmp_path / "module.py"
source_file.write_text(
"async def target_func():\n pass\n",
encoding="utf-8",
)
func = make_function(
"target_func",
str(source_file),
is_async=True,
)
add_async_decorator_to_function(
source_file,
func,
TestingMode.PERFORMANCE,
)
modified = source_file.read_text()
assert "from codeflash_async_wrapper import" in modified
assert "codeflash_performance_async" in modified
def test_cst_parse_error_returns_false(self, tmp_path: Path) -> None:
"""Returns (False, {}) when the source file has invalid syntax."""
source_file = tmp_path / "module.py"
source_file.write_text(
"async def target_func(\n invalid !!!",
encoding="utf-8",
)
func = make_function(
"target_func",
str(source_file),
is_async=True,
)
result, originals = add_async_decorator_to_function(
source_file,
func,
TestingMode.BEHAVIOR,
)
assert result is False
assert {} == originals
class TestCreateInstrumentedSourceModulePath: