mirror of
https://github.com/codeflash-ai/codeflash-agent.git
synced 2026-05-04 18:25:19 +00:00
Replace the fragile stdout tag protocol with a unified SQLite table (async_results) for all 3 async test modes. The new runtime decorators write behavior, performance, and concurrency results directly to the DB with zero stdout output. Test-file instrumentation now injects _codeflash_call_site.set() (contextvar) instead of os.environ assignments, which is correct for async task isolation. New modules: - runtime/_codeflash_async_decorators.py: self-contained decorators - testing/_async_data_parser.py: SQLite reader replacing stdout parsing Both at 100% test coverage (42 new tests).
1336 lines
48 KiB
Python
1336 lines
48 KiB
Python
"""Tests for _instrumentation — test instrumentation and AST transforms."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import ast
|
|
import textwrap
|
|
from pathlib import Path
|
|
|
|
import libcst as cst
|
|
|
|
from codeflash_python._model import (
|
|
FunctionParent,
|
|
FunctionToOptimize,
|
|
TestingMode,
|
|
VerificationType,
|
|
)
|
|
from codeflash_python.test_discovery.models import CodePosition
|
|
from codeflash_python.testing._instrumentation import (
|
|
ASYNC_HELPER_FILENAME,
|
|
AsyncCallInstrumenter,
|
|
AsyncDecoratorAdder,
|
|
FunctionCallNodeArguments,
|
|
FunctionImportedAsVisitor,
|
|
InjectPerfOnly,
|
|
add_async_decorator_to_function,
|
|
create_device_sync_precompute_statements,
|
|
create_device_sync_statements,
|
|
create_instrumented_source_module_path,
|
|
create_wrapper_function,
|
|
detect_frameworks_from_code,
|
|
get_call_arguments,
|
|
get_decorator_name_for_mode,
|
|
inject_async_profiling_into_existing_test,
|
|
inject_profiling_into_existing_test,
|
|
is_argument_name,
|
|
node_in_call_position,
|
|
sort_imports,
|
|
write_async_helper_file,
|
|
)
|
|
|
|
|
|
def make_function(
    name: str = "target_func",
    file_path: str = "module.py",
    parents: tuple[FunctionParent, ...] = (),
    *,
    is_async: bool = False,
) -> FunctionToOptimize:
    """Build a ``FunctionToOptimize`` fixture for the tests in this module.

    Args:
        name: Bare function name to record.
        file_path: Source file path; converted to a ``Path``.
        parents: Enclosing scopes of the function (e.g. a class).
        is_async: Whether the function is flagged as async.

    Returns:
        A ``FunctionToOptimize`` populated with the given values.
    """
    fields = {
        "function_name": name,
        "file_path": Path(file_path),
        "parents": parents,
        "is_async": is_async,
    }
    return FunctionToOptimize(**fields)
|
|
|
|
|
|
class TestTestingMode:
    """Checks on the TestingMode enum."""

    def test_enum_values(self) -> None:
        """Every mode maps to its expected string value."""
        expected_pairs = (
            (TestingMode.BEHAVIOR, "behavior"),
            (TestingMode.PERFORMANCE, "performance"),
            (TestingMode.LINE_PROFILE, "line_profile"),
            (TestingMode.CONCURRENCY, "concurrency"),
        )
        for mode, text in expected_pairs:
            assert text == mode.value

    def test_membership(self) -> None:
        """Exactly four modes are defined."""
        assert 4 == len(TestingMode)
|
|
|
|
|
|
class TestVerificationType:
    """Checks on the VerificationType string enum."""

    # (member, expected string value) pairs shared by the tests below.
    _EXPECTED = (
        (VerificationType.FUNCTION_CALL, "function_call"),
        (VerificationType.INIT_STATE_FTO, "init_state_fto"),
        (VerificationType.INIT_STATE_HELPER, "init_state_helper"),
    )

    def test_is_str_enum(self) -> None:
        """Each member is itself a ``str`` instance."""
        for member, _ in self._EXPECTED:
            assert isinstance(member, str)

    def test_enum_values(self) -> None:
        """Each member compares equal to its expected string."""
        for member, text in self._EXPECTED:
            assert text == member

    def test_membership(self) -> None:
        """Exactly three verification types are defined."""
        assert 3 == len(VerificationType)
|
|
|
|
|
|
class TestGetCallArguments:
    """Argument extraction from an ``ast.Call`` via get_call_arguments."""

    @staticmethod
    def _extract(expression: str) -> FunctionCallNodeArguments:
        """Parse a single call expression and run get_call_arguments on it."""
        statement = ast.parse(expression).body[0]
        return get_call_arguments(statement.value)  # type: ignore[attr-defined]

    def test_simple_call(self) -> None:
        """Positional args and keywords are both captured."""
        extracted = self._extract("func(1, 2, key='val')")
        assert isinstance(extracted, FunctionCallNodeArguments)
        assert 2 == len(extracted.args)
        assert 1 == len(extracted.keywords)

    def test_no_args(self) -> None:
        """A call without arguments yields two empty lists."""
        extracted = self._extract("func()")
        assert [] == extracted.args
        assert [] == extracted.keywords

    def test_only_keywords(self) -> None:
        """A keyword-only call produces no positional args."""
        extracted = self._extract("func(a=1, b=2)")
        assert [] == extracted.args
        assert 2 == len(extracted.keywords)
|
|
|
|
|
|
class TestNodeInCallPosition:
    """Position matching performed by node_in_call_position."""

    _CODE = "target_func()\n"

    def _call_node(self) -> ast.Call:
        """Return a freshly parsed Call node for ``target_func()`` (line 1, col 0)."""
        return ast.parse(self._CODE).body[0].value  # type: ignore[attr-defined]

    def test_matching_position(self) -> None:
        """A position equal to the call's location matches."""
        hits = [CodePosition(line_no=1, col_no=0)]
        assert node_in_call_position(self._call_node(), hits) is True

    def test_no_matching_position(self) -> None:
        """A position elsewhere in the file does not match."""
        misses = [CodePosition(line_no=99, col_no=99)]
        assert node_in_call_position(self._call_node(), misses) is False

    def test_empty_positions_matches_all(self) -> None:
        """An empty position list matches any call."""
        assert node_in_call_position(self._call_node(), []) is True

    def test_multiple_positions_one_match(self) -> None:
        """One matching entry among several is sufficient."""
        mixed = [
            CodePosition(line_no=50, col_no=0),
            CodePosition(line_no=1, col_no=0),
        ]
        assert node_in_call_position(self._call_node(), mixed) is True
|
|
|
|
|
|
class TestIsArgumentName:
    """Argument-name detection performed by is_argument_name."""

    @staticmethod
    def _arguments_of(signature: str) -> ast.arguments:
        """Parse a one-line ``def`` and return its ``ast.arguments`` node."""
        func_def = ast.parse(signature).body[0]
        return func_def.args  # type: ignore[attr-defined]

    def test_regular_arg(self) -> None:
        """A plain positional parameter is recognised."""
        params = self._arguments_of("def f(x, y): pass")
        assert is_argument_name("x", params) is True

    def test_kwonly_arg(self) -> None:
        """A keyword-only parameter is recognised."""
        params = self._arguments_of("def f(*, key): pass")
        assert is_argument_name("key", params) is True

    def test_no_match(self) -> None:
        """A name that is not a parameter is rejected."""
        params = self._arguments_of("def f(x, y): pass")
        assert is_argument_name("z", params) is False

    def test_vararg_not_matched(self) -> None:
        """``*args`` is not reported (vararg is not a list attribute)."""
        params = self._arguments_of("def f(*args): pass")
        assert is_argument_name("args", params) is False

    def test_kwarg_not_matched(self) -> None:
        """``**kwargs`` is not reported (kwarg is not a list attribute)."""
        params = self._arguments_of("def f(**kwargs): pass")
        assert is_argument_name("kwargs", params) is False
|
|
|
|
|
|
class TestDetectFrameworksFromCode:
    """Framework import detection performed by detect_frameworks_from_code."""

    def test_torch_import(self) -> None:
        """``import torch`` is reported under the ``torch`` key."""
        detected = detect_frameworks_from_code("import torch\n")
        assert "torch" in detected

    def test_tensorflow_import(self) -> None:
        """``import tensorflow`` is reported under the ``tensorflow`` key."""
        detected = detect_frameworks_from_code("import tensorflow\n")
        assert "tensorflow" in detected

    def test_jax_import(self) -> None:
        """``import jax`` is reported under the ``jax`` key."""
        detected = detect_frameworks_from_code("import jax\n")
        assert "jax" in detected

    def test_aliased_import(self) -> None:
        """The alias is stored as the value for the framework key."""
        detected = detect_frameworks_from_code("import torch as th\n")
        assert "torch" in detected
        assert "th" == detected["torch"]

    def test_no_frameworks(self) -> None:
        """Unrelated stdlib imports yield an empty mapping."""
        assert {} == detect_frameworks_from_code("import os\nimport sys\n")

    def test_from_import_submodule(self) -> None:
        """``from torch import nn`` still reports torch."""
        detected = detect_frameworks_from_code("from torch import nn\n")
        assert "torch" in detected

    def test_multiple_frameworks(self) -> None:
        """Several frameworks in one file are all reported."""
        detected = detect_frameworks_from_code("import torch\nimport jax\n")
        for framework in ("torch", "jax"):
            assert framework in detected
|
|
|
|
|
|
class TestGetDecoratorNameForMode:
    """Decorator selection performed by get_decorator_name_for_mode."""

    def test_behavior_mode(self) -> None:
        """BEHAVIOR maps to codeflash_behavior_async."""
        decorator = get_decorator_name_for_mode(TestingMode.BEHAVIOR)
        assert "codeflash_behavior_async" == decorator

    def test_performance_mode(self) -> None:
        """PERFORMANCE maps to codeflash_performance_async."""
        decorator = get_decorator_name_for_mode(TestingMode.PERFORMANCE)
        assert "codeflash_performance_async" == decorator

    def test_concurrency_mode(self) -> None:
        """CONCURRENCY maps to codeflash_concurrency_async."""
        decorator = get_decorator_name_for_mode(TestingMode.CONCURRENCY)
        assert "codeflash_concurrency_async" == decorator

    def test_line_profile_falls_through_to_performance(self) -> None:
        """LINE_PROFILE has no async-specific decorator and reuses performance."""
        decorator = get_decorator_name_for_mode(TestingMode.LINE_PROFILE)
        assert "codeflash_performance_async" == decorator
|
|
|
|
|
|
class TestCreateDeviceSyncPrecomputeStatements:
    """AST generation in create_device_sync_precompute_statements."""

    @staticmethod
    def _assert_nonempty_statements(frameworks: dict[str, str]) -> None:
        """Assert the generator yields at least one node and all are ``ast.stmt``."""
        generated = create_device_sync_precompute_statements(frameworks)
        assert len(generated) > 0
        assert all(isinstance(node, ast.stmt) for node in generated)

    def test_none_frameworks(self) -> None:
        """``None`` produces no statements."""
        assert [] == create_device_sync_precompute_statements(None)

    def test_empty_dict(self) -> None:
        """An empty mapping produces no statements."""
        assert [] == create_device_sync_precompute_statements({})

    def test_torch_produces_statements(self) -> None:
        """torch yields statement nodes."""
        self._assert_nonempty_statements({"torch": "torch"})

    def test_jax_produces_statements(self) -> None:
        """jax yields statement nodes."""
        self._assert_nonempty_statements({"jax": "jax"})

    def test_tensorflow_produces_statements(self) -> None:
        """tensorflow (aliased as tf) yields statement nodes."""
        self._assert_nonempty_statements({"tensorflow": "tf"})

    def test_combined_frameworks(self) -> None:
        """Several frameworks together still yield statements."""
        generated = create_device_sync_precompute_statements(
            {"torch": "torch", "jax": "jax"},
        )
        assert len(generated) > 0
|
|
|
|
|
|
class TestCreateDeviceSyncStatements:
    """AST generation in create_device_sync_statements."""

    @staticmethod
    def _assert_all_statements(generated: list[ast.stmt]) -> None:
        """Assert *generated* is non-empty and made only of ``ast.stmt`` nodes."""
        assert len(generated) > 0
        assert all(isinstance(node, ast.stmt) for node in generated)

    def test_none_frameworks(self) -> None:
        """``None`` produces no statements."""
        assert [] == create_device_sync_statements(None)

    def test_empty_dict(self) -> None:
        """An empty mapping produces no statements."""
        assert [] == create_device_sync_statements({})

    def test_torch_sync(self) -> None:
        """torch yields sync statements."""
        self._assert_all_statements(create_device_sync_statements({"torch": "torch"}))

    def test_for_return_value_flag(self) -> None:
        """``for_return_value=True`` still yields statements."""
        generated = create_device_sync_statements(
            {"jax": "jax"},
            for_return_value=True,
        )
        self._assert_all_statements(generated)

    def test_tensorflow_sync(self) -> None:
        """tensorflow (aliased as tf) yields sync statements."""
        self._assert_all_statements(create_device_sync_statements({"tensorflow": "tf"}))
|
|
|
|
|
|
class TestCreateWrapperFunction:
    """AST produced by create_wrapper_function."""

    @staticmethod
    def _param_names(mode: TestingMode) -> list[str]:
        """Return the positional parameter names of the wrapper built for *mode*."""
        wrapper = create_wrapper_function(mode)
        return [param.arg for param in wrapper.args.args]

    def test_returns_function_def(self) -> None:
        """The result is an ``ast.FunctionDef``."""
        wrapper = create_wrapper_function(TestingMode.BEHAVIOR)
        assert isinstance(wrapper, ast.FunctionDef)

    def test_function_name(self) -> None:
        """The wrapper is named ``codeflash_wrap``."""
        wrapper = create_wrapper_function(TestingMode.BEHAVIOR)
        assert "codeflash_wrap" == wrapper.name

    def test_behavior_mode_params(self) -> None:
        """The BEHAVIOR wrapper declares at least one parameter."""
        assert len(self._param_names(TestingMode.BEHAVIOR)) > 0

    def test_performance_mode_params(self) -> None:
        """The PERFORMANCE wrapper declares at least one parameter."""
        assert len(self._param_names(TestingMode.PERFORMANCE)) > 0

    def test_body_is_nonempty(self) -> None:
        """The wrapper body contains statements."""
        wrapper = create_wrapper_function(TestingMode.BEHAVIOR)
        assert len(wrapper.body) > 0

    def test_with_frameworks(self) -> None:
        """``used_frameworks`` is accepted and still yields a FunctionDef."""
        wrapper = create_wrapper_function(
            TestingMode.PERFORMANCE,
            used_frameworks={"torch": "torch"},
        )
        assert isinstance(wrapper, ast.FunctionDef)
|
|
|
|
|
|
class TestInjectPerfOnly:
    """InjectPerfOnly AST transformer."""

    @staticmethod
    def _instrument(
        code: str,
        positions: list[CodePosition] | None = None,
    ) -> str:
        """Run InjectPerfOnly in BEHAVIOR mode over *code* and unparse the result.

        When *positions* is None, the call in the first statement of the
        first function in *code* is used as the only call position.
        """
        tree = ast.parse(code)
        if positions is None:
            call = tree.body[0].body[0].value  # type: ignore[attr-defined]
            positions = [
                CodePosition(line_no=call.lineno, col_no=call.col_offset),
            ]
        injector = InjectPerfOnly(
            function=make_function("target_func", "module.py"),
            module_path="module",
            call_positions=positions,
            mode=TestingMode.BEHAVIOR,
        )
        return ast.unparse(injector.visit(tree))

    def test_wraps_name_call(self) -> None:
        """A bare ``target_func(...)`` call is wrapped with codeflash_wrap."""
        code = textwrap.dedent("""\
            def test_it():
                result = target_func(1, 2)
            """)
        assert "codeflash_wrap" in self._instrument(code)

    def test_wraps_attribute_call(self) -> None:
        """A ``module.target_func(...)`` call is wrapped with codeflash_wrap."""
        code = textwrap.dedent("""\
            def test_it():
                result = module.target_func(1, 2)
            """)
        assert "codeflash_wrap" in self._instrument(code)

    def test_no_wrap_without_matching_position(self) -> None:
        """Calls outside call_positions are left untouched."""
        code = textwrap.dedent("""\
            def test_it():
                result = target_func(1, 2)
            """)
        rendered = self._instrument(
            code,
            positions=[CodePosition(line_no=99, col_no=99)],
        )
        assert "codeflash_wrap" not in rendered
|
|
|
|
|
|
class TestAsyncCallInstrumenter:
    """AsyncCallInstrumenter AST transformer."""

    def _make_transformer(
        self,
        code: str,
        *,
        name: str = "target_func",
        parents: tuple[FunctionParent, ...] = (),
        positions: list[CodePosition] | None = None,
    ) -> tuple[AsyncCallInstrumenter, ast.Module]:
        """Parse *code* and build an instrumenter for the async function *name*.

        When *positions* is None, every call whose callee is ``name`` (either
        ``name(...)`` or ``obj.name(...)``) is recorded as a call position,
        whether or not it is awaited.

        Returns:
            The configured transformer and the parsed, not-yet-visited module.
        """
        tree = ast.parse(code)
        if positions is None:
            positions = []
            for node in ast.walk(tree):
                if not isinstance(node, ast.Call):
                    continue
                func_name = None
                if isinstance(node.func, ast.Name):
                    func_name = node.func.id
                elif isinstance(node.func, ast.Attribute):
                    func_name = node.func.attr
                if func_name == name:
                    positions.append(
                        CodePosition(
                            line_no=node.lineno,
                            col_no=node.col_offset,
                        )
                    )
        func = make_function(
            name,
            "module.py",
            parents=parents,
            is_async=True,
        )
        transformer = AsyncCallInstrumenter(
            function=func,
            call_positions=positions,
        )
        return transformer, tree

    def test_instruments_await_call(self) -> None:
        """Adds call-site contextvar set before an awaited target call."""
        code = textwrap.dedent("""\
            async def test_it():
                result = await target_func(1, 2)
            """)
        transformer, tree = self._make_transformer(code)
        source = ast.unparse(transformer.visit(tree))
        assert "_codeflash_call_site.set(" in source
        assert transformer.did_instrument is True

    def test_skips_non_test_async_functions(self) -> None:
        """Does not instrument async functions that don't start with test_."""
        code = textwrap.dedent("""\
            async def helper():
                result = await target_func(1, 2)
            """)
        transformer, tree = self._make_transformer(code)
        source = ast.unparse(transformer.visit(tree))
        assert "_codeflash_call_site" not in source
        assert transformer.did_instrument is False

    def test_skips_non_test_sync_functions(self) -> None:
        """Does not instrument sync functions that don't start with test_."""
        code = textwrap.dedent("""\
            def helper():
                result = await target_func(1, 2)
            """)
        transformer, tree = self._make_transformer(code)
        source = ast.unparse(transformer.visit(tree))
        assert "_codeflash_call_site" not in source
        # Mirror the async variant: nothing injected means no instrumentation.
        assert transformer.did_instrument is False

    def test_instruments_sync_test_with_await(self) -> None:
        """Instruments sync test_ functions that contain awaited calls."""
        # ast.parse accepts a bare ``await`` in a sync def; only compile rejects it.
        code = textwrap.dedent("""\
            def test_it():
                result = await target_func(1, 2)
            """)
        transformer, tree = self._make_transformer(code)
        source = ast.unparse(transformer.visit(tree))
        assert "_codeflash_call_site.set(" in source
        assert transformer.did_instrument is True

    def test_multiple_awaits_get_incrementing_ids(self) -> None:
        """Each awaited target call gets a unique incrementing counter."""
        code = textwrap.dedent("""\
            async def test_it():
                a = await target_func(1)
                b = await target_func(2)
                c = await target_func(3)
            """)
        transformer, tree = self._make_transformer(code)
        source = ast.unparse(transformer.visit(tree))
        assert "'0'" in source
        assert "'1'" in source
        assert "'2'" in source
        # Three awaited calls in test_it -> the per-test counter ends at 3.
        assert 3 == transformer.async_call_counter["test_it"]

    def test_attribute_style_call(self) -> None:
        """Instruments await obj.target_func() attribute-style calls."""
        code = textwrap.dedent("""\
            async def test_it():
                result = await obj.target_func(1, 2)
            """)
        transformer, tree = self._make_transformer(code)
        source = ast.unparse(transformer.visit(tree))
        assert "_codeflash_call_site.set(" in source
        assert transformer.did_instrument is True

    def test_recurses_into_class_body(self) -> None:
        """Finds and instruments test methods inside a class."""
        code = textwrap.dedent("""\
            class TestSuite:
                async def test_it(self):
                    result = await target_func(1)
            """)
        transformer, tree = self._make_transformer(code)
        source = ast.unparse(transformer.visit(tree))
        assert "_codeflash_call_site.set(" in source
        assert transformer.did_instrument is True

    def test_no_match_when_position_wrong(self) -> None:
        """Does not instrument when call positions don't match."""
        code = textwrap.dedent("""\
            async def test_it():
                result = await target_func(1, 2)
            """)
        transformer, tree = self._make_transformer(
            code,
            positions=[CodePosition(line_no=99, col_no=99)],
        )
        transformer.visit(tree)
        assert transformer.did_instrument is False

    def test_nested_await_in_conditional(self) -> None:
        """Finds awaited target calls nested inside if statements."""
        code = textwrap.dedent("""\
            async def test_it():
                if True:
                    result = await target_func(1)
            """)
        transformer, tree = self._make_transformer(code)
        source = ast.unparse(transformer.visit(tree))
        assert "_codeflash_call_site.set(" in source

    def test_ignores_non_target_awaits(self) -> None:
        """Does not instrument awaits of unrelated functions."""
        code = textwrap.dedent("""\
            async def test_it():
                result = await other_func(1, 2)
            """)
        transformer, tree = self._make_transformer(code)
        transformer.visit(tree)
        assert transformer.did_instrument is False

    def test_subscript_style_call_not_matched(self) -> None:
        """Await of a subscript-style call (funcs[0]()) is not matched."""
        code = textwrap.dedent("""\
            async def test_it():
                result = await funcs[0](1, 2)
            """)
        transformer, tree = self._make_transformer(code)
        transformer.visit(tree)
        assert transformer.did_instrument is False

    def test_counters_independent_per_test_function(self) -> None:
        """Each test function gets its own independent counter."""
        code = textwrap.dedent("""\
            async def test_a():
                await target_func(1)
                await target_func(2)
            async def test_b():
                await target_func(3)
            """)
        transformer, tree = self._make_transformer(code)
        transformer.visit(tree)
        assert 2 == transformer.async_call_counter["test_a"]
        assert 1 == transformer.async_call_counter["test_b"]
|
|
|
|
|
|
class TestFunctionImportedAsVisitor:
    """Alias detection performed by FunctionImportedAsVisitor."""

    @staticmethod
    def _resolve(code: str, func: FunctionToOptimize) -> FunctionToOptimize:
        """Visit *code* with a fresh visitor for *func* and return ``imported_as``."""
        visitor = FunctionImportedAsVisitor(func)
        visitor.visit(ast.parse(code))
        return visitor.imported_as

    def test_aliased_import(self) -> None:
        """An ``as`` alias replaces the function name."""
        code = textwrap.dedent("""\
            from module import target_func as tf
            """)
        original = make_function("target_func", "module.py")
        assert "tf" == self._resolve(code, original).function_name

    def test_non_aliased_import(self) -> None:
        """A plain import keeps the very same function object."""
        code = textwrap.dedent("""\
            from module import target_func
            """)
        original = make_function("target_func", "module.py")
        assert self._resolve(code, original) is original

    def test_class_method_aliased_import(self) -> None:
        """Aliasing the owning class renames the parent scope."""
        code = textwrap.dedent("""\
            from module import MyClass as MC
            """)
        owner = FunctionParent(name="MyClass", type="ClassDef")
        original = make_function("method", "module.py", parents=(owner,))
        assert "MC" == self._resolve(code, original).parents[0].name

    def test_no_import(self) -> None:
        """Without a matching import the original function is kept."""
        code = textwrap.dedent("""\
            import os
            """)
        original = make_function("target_func", "module.py")
        assert self._resolve(code, original) is original
|
|
|
|
|
|
class TestAsyncDecoratorAdder:
    """AsyncDecoratorAdder CST transformer."""

    @staticmethod
    def _decorate(
        code: str,
        mode: TestingMode = TestingMode.BEHAVIOR,
        parents: tuple[FunctionParent, ...] = (),
    ) -> tuple[str, AsyncDecoratorAdder]:
        """Run AsyncDecoratorAdder for async ``target_func`` over *code*.

        Returns:
            The rewritten source text and the adder (to inspect its flags).
        """
        target = make_function(
            "target_func",
            "module.py",
            parents=parents,
            is_async=True,
        )
        adder = AsyncDecoratorAdder(target, mode=mode)
        rewritten = cst.parse_module(code).visit(adder)
        return rewritten.code, adder

    def test_adds_decorator_to_async_function(self) -> None:
        """A matching async function receives the behavior decorator."""
        code = textwrap.dedent("""\
            async def target_func():
                pass
            """)
        output, adder = self._decorate(code)
        assert "@codeflash_behavior_async" in output
        assert adder.added_decorator is True

    def test_does_not_add_to_non_matching(self) -> None:
        """Functions with a different name are left alone."""
        code = textwrap.dedent("""\
            async def other_func():
                pass
            """)
        output, adder = self._decorate(code)
        assert "@" not in output
        assert adder.added_decorator is False

    def test_does_not_add_to_sync_function(self) -> None:
        """A sync function with the target name is skipped."""
        code = textwrap.dedent("""\
            def target_func():
                pass
            """)
        output, adder = self._decorate(code)
        assert "@" not in output
        assert adder.added_decorator is False

    def test_performance_mode_decorator(self) -> None:
        """PERFORMANCE mode applies codeflash_performance_async."""
        output, _ = self._decorate(
            "async def target_func():\n pass\n",
            mode=TestingMode.PERFORMANCE,
        )
        assert "@codeflash_performance_async" in output

    def test_concurrency_mode_decorator(self) -> None:
        """CONCURRENCY mode applies codeflash_concurrency_async."""
        output, _ = self._decorate(
            "async def target_func():\n pass\n",
            mode=TestingMode.CONCURRENCY,
        )
        assert "@codeflash_concurrency_async" in output

    def test_class_method_matching(self) -> None:
        """An async method is matched through its class-qualified name."""
        code = textwrap.dedent("""\
            class MyClass:
                async def target_func(self):
                    pass
            """)
        owner = FunctionParent(name="MyClass", type="ClassDef")
        output, adder = self._decorate(code, parents=(owner,))
        assert "@codeflash_behavior_async" in output
        assert adder.added_decorator is True

    def test_no_duplicate_when_already_decorated(self) -> None:
        """An existing bare decorator is not duplicated."""
        code = textwrap.dedent("""\
            @codeflash_behavior_async
            async def target_func():
                pass
            """)
        output, adder = self._decorate(code)
        assert output.count("codeflash_behavior_async") == 1
        assert adder.added_decorator is False

    def test_no_duplicate_when_call_style_decorator(self) -> None:
        """An existing ``@decorator()`` call form is also detected."""
        code = textwrap.dedent("""\
            @codeflash_behavior_async()
            async def target_func():
                pass
            """)
        output, adder = self._decorate(code)
        assert output.count("codeflash_behavior_async") == 1
        assert adder.added_decorator is False

    def test_preserves_existing_decorators(self) -> None:
        """Other decorators stay, with the codeflash one placed first."""
        code = textwrap.dedent("""\
            @staticmethod
            async def target_func():
                pass
            """)
        output, _ = self._decorate(code)
        assert "@codeflash_behavior_async" in output
        assert "@staticmethod" in output
        ours = output.index("@codeflash_behavior_async")
        theirs = output.index("@staticmethod")
        assert ours < theirs

    def test_attribute_decorator_not_matched(self) -> None:
        """``@mod.codeflash_behavior_async`` does not count as already decorated."""
        code = textwrap.dedent("""\
            @mod.codeflash_behavior_async
            async def target_func():
                pass
            """)
        output, adder = self._decorate(code)
        assert output.count("codeflash_behavior_async") == 2
        assert adder.added_decorator is True
|
|
|
|
|
|
class TestWriteAsyncHelperFile:
    """write_async_helper_file file creation."""

    def test_creates_file(self, tmp_path: Path) -> None:
        """Creates the async helper file in the target directory."""
        result = write_async_helper_file(tmp_path)
        assert result.exists()
        assert result.is_file()

    def test_file_name(self, tmp_path: Path) -> None:
        """The created file has the expected name."""
        result = write_async_helper_file(tmp_path)
        assert ASYNC_HELPER_FILENAME == result.name

    def test_file_content_is_valid_python(self, tmp_path: Path) -> None:
        """The copied file is valid, parseable Python."""
        result = write_async_helper_file(tmp_path)
        # Read with an explicit encoding so the check does not depend on the
        # platform's locale default (e.g. cp1252 on Windows).
        content = result.read_text(encoding="utf-8")
        assert len(content) > 0
        ast.parse(content)

    def test_idempotent_does_not_overwrite(self, tmp_path: Path) -> None:
        """Calling twice does not overwrite the existing file."""
        first = write_async_helper_file(tmp_path)
        first.write_text("sentinel", encoding="utf-8")
        second = write_async_helper_file(tmp_path)
        assert first == second
        # Decode with the same encoding used to write the sentinel.
        assert "sentinel" == second.read_text(encoding="utf-8")
|
|
|
|
|
|
class TestAsyncHelperConstants:
    """ASYNC_HELPER_FILENAME and runtime decorator source."""

    def test_filename_value(self) -> None:
        """ASYNC_HELPER_FILENAME has the expected value."""
        assert "codeflash_async_wrapper.py" == ASYNC_HELPER_FILENAME

    def test_runtime_decorator_is_self_contained(self) -> None:
        """Runtime decorator file has no internal codeflash imports."""
        from codeflash_python.testing._instrument_async import (
            _RUNTIME_DECORATOR_PATH,
        )

        # Check existence first so a missing file fails as an assertion
        # rather than as a FileNotFoundError from read_text().
        assert _RUNTIME_DECORATOR_PATH.exists()
        source = _RUNTIME_DECORATOR_PATH.read_text(encoding="utf-8")

        # Inspect real import statements via the AST so that text inside
        # docstrings or string literals cannot trigger false positives, and
        # indented/conditional imports are still found.
        imported_modules: list[str] = []
        for node in ast.walk(ast.parse(source)):
            if isinstance(node, ast.Import):
                imported_modules.extend(alias.name for alias in node.names)
            elif isinstance(node, ast.ImportFrom) and node.module:
                imported_modules.append(node.module)
        offenders = [
            module for module in imported_modules if "codeflash_python" in module
        ]
        assert [] == offenders
|
|
|
|
|
|
class TestSortImports:
    """sort_imports import sorting."""

    def test_sorts_unsorted_imports(self) -> None:
        """Sorts and deduplicates unsorted imports."""
        code = textwrap.dedent("""\
            import os
            import ast
            import os
            """)
        result = sort_imports(code)
        lines = result.strip().splitlines()
        assert "import ast" in lines[0]
        assert "import os" in lines[1]
        # Duplicate removed
        assert result.count("import os") == 1

    def test_syntax_error_returns_original(self) -> None:
        """Invalid Python does not raise and keeps the broken code intact.

        isort may tolerate the broken statement, or ``sort_imports`` may fall
        back to the input unchanged. Either way, the caller must get a string
        that still contains both the import and the unparseable statement.
        """
        code = "import os\ndef broken(:\n    pass\n"
        result = sort_imports(code)
        assert isinstance(result, str)
        assert "import os" in result
        assert "def broken(:" in result

    def test_kwargs_forwarded(self) -> None:
        """Forwards kwargs to isort.code (e.g. float_to_top)."""
        code = textwrap.dedent("""\
            from os.path import join

            x = 1

            import ast
            """)
        result = sort_imports(code, float_to_top=True)
        lines = result.strip().splitlines()
        import_lines = [i for i, line in enumerate(lines) if "import" in line]
        code_lines = [
            i for i, line in enumerate(lines) if line.strip() == "x = 1"
        ]
        # Both imports and the assignment must survive; without these checks
        # the ordering assertion below could pass vacuously.
        assert len(import_lines) == 2
        assert code_lines
        # With float_to_top, both imports must precede 'x = 1'.
        assert max(import_lines) < min(code_lines)
|
|
|
|
|
|
class TestInjectProfilingIntoExistingTest:
    """inject_profiling_into_existing_test orchestration."""

    def test_sync_function_instrumentation(self, tmp_path: Path) -> None:
        """Instruments a sync test file with codeflash_wrap and imports."""
        project_root = tmp_path / "project"
        project_root.mkdir()
        test_file = project_root / "test_example.py"
        test_code = textwrap.dedent("""\
            from module import target_func

            def test_something():
                result = target_func(1, 2)
                assert result == 3
            """)
        test_file.write_text(test_code, encoding="utf-8")

        func = make_function("target_func", "module.py")
        # target_func(1, 2) is on line 4, col 13
        positions = [CodePosition(line_no=4, col_no=13)]

        ok, source = inject_profiling_into_existing_test(
            test_file,
            positions,
            func,
            project_root,
            mode=TestingMode.PERFORMANCE,
        )
        assert ok is True
        assert source is not None
        assert "codeflash_wrap" in source
        assert "import time" in source
        assert "import gc" in source
        assert "import os" in source

    def test_async_delegation(self, tmp_path: Path) -> None:
        """Async targets are routed to the async instrumentation path.

        The async instrumenter imports ``_codeflash_call_site`` into the test
        module. Finding it in the output shows the call was delegated rather
        than handled by the sync wrapper path.
        """
        project_root = tmp_path / "project"
        project_root.mkdir()
        test_file = project_root / "test_async.py"
        test_code = textwrap.dedent("""\
            from module import target_func

            async def test_something():
                result = await target_func(1, 2)
            """)
        test_file.write_text(test_code, encoding="utf-8")

        func = make_function("target_func", "module.py", is_async=True)
        # col 25 falls inside the awaited target_func(1, 2) call on line 4.
        positions = [CodePosition(line_no=4, col_no=25)]

        ok, source = inject_profiling_into_existing_test(
            test_file,
            positions,
            func,
            project_root,
        )
        assert ok is True
        assert source is not None
        assert "_codeflash_call_site" in source

    def test_syntax_error_returns_false(self, tmp_path: Path) -> None:
        """Returns (False, None) for a file with invalid Python."""
        project_root = tmp_path / "project"
        project_root.mkdir()
        test_file = project_root / "test_bad.py"
        test_file.write_text(
            "def test_x(\n    not valid python !!!",
            encoding="utf-8",
        )

        func = make_function("target_func", "module.py")
        positions = [CodePosition(line_no=1, col_no=0)]

        ok, source = inject_profiling_into_existing_test(
            test_file,
            positions,
            func,
            project_root,
        )
        assert ok is False
        assert source is None

    def test_behavior_mode_extra_imports(self, tmp_path: Path) -> None:
        """BEHAVIOR mode adds inspect, sqlite3, and dill imports."""
        project_root = tmp_path / "project"
        project_root.mkdir()
        test_file = project_root / "test_behav.py"
        test_code = textwrap.dedent("""\
            from module import target_func

            def test_something():
                result = target_func(1, 2)
                assert result == 3
            """)
        test_file.write_text(test_code, encoding="utf-8")

        func = make_function("target_func", "module.py")
        positions = [CodePosition(line_no=4, col_no=13)]

        ok, source = inject_profiling_into_existing_test(
            test_file,
            positions,
            func,
            project_root,
            mode=TestingMode.BEHAVIOR,
        )
        assert ok is True
        assert source is not None
        assert "inspect" in source
        assert "sqlite3" in source
        assert "dill" in source
|
|
|
|
|
|
class TestInjectAsyncProfilingIntoExistingTest:
    """inject_async_profiling_into_existing_test orchestration."""

    def test_async_instrumentation(self, tmp_path: Path) -> None:
        """Instruments an async test file and imports the call-site contextvar."""
        test_file = tmp_path / "test_async_ex.py"
        test_code = textwrap.dedent("""\
            from module import target_func

            async def test_something():
                result = await target_func(1, 2)
            """)
        test_file.write_text(test_code, encoding="utf-8")

        func = make_function("target_func", "module.py", is_async=True)
        # col 25 falls inside the awaited target_func(1, 2) call on line 4.
        positions = [CodePosition(line_no=4, col_no=25)]

        ok, source = inject_async_profiling_into_existing_test(
            test_file,
            positions,
            func,
        )
        assert ok is True
        assert source is not None
        assert (
            "from codeflash_async_wrapper import _codeflash_call_site"
            in source
        )

    def test_no_instrumentation(self, tmp_path: Path) -> None:
        """Returns (False, None) when test does not call target."""
        test_file = tmp_path / "test_no_call.py"
        test_code = textwrap.dedent("""\
            def test_something():
                assert 1 == 1
            """)
        test_file.write_text(test_code, encoding="utf-8")

        func = make_function("target_func", "module.py", is_async=True)
        positions = [CodePosition(line_no=2, col_no=0)]

        ok, source = inject_async_profiling_into_existing_test(
            test_file,
            positions,
            func,
        )
        assert ok is False
        assert source is None

    def test_syntax_error_returns_false(self, tmp_path: Path) -> None:
        """Returns (False, None) for a file with invalid Python."""
        test_file = tmp_path / "test_bad.py"
        test_file.write_text(
            "async def test_x(\n    not valid !!!",
            encoding="utf-8",
        )
        func = make_function("target_func", "module.py", is_async=True)
        ok, source = inject_async_profiling_into_existing_test(
            test_file,
            [CodePosition(line_no=1, col_no=0)],
            func,
        )
        assert ok is False
        assert source is None

    def test_multiple_awaits_get_sequential_ids(
        self,
        tmp_path: Path,
    ) -> None:
        """Multiple awaited calls in one test get distinct call-site IDs."""
        test_file = tmp_path / "test_multi.py"
        test_code = textwrap.dedent("""\
            from module import target_func

            async def test_multi():
                a = await target_func(1)
                b = await target_func(2)
            """)
        test_file.write_text(test_code, encoding="utf-8")

        func = make_function("target_func", "module.py", is_async=True)
        positions = [
            CodePosition(line_no=4, col_no=22),
            CodePosition(line_no=5, col_no=22),
        ]
        ok, source = inject_async_profiling_into_existing_test(
            test_file,
            positions,
            func,
        )
        assert ok is True
        assert source is not None
        assert source.count("_codeflash_call_site.set(") == 2
        set_calls = [
            line.strip()
            for line in source.splitlines()
            if "_codeflash_call_site.set(" in line
        ]
        # Counting alone would pass even if both calls shared one ID; the two
        # injected statements must tag different call sites.
        assert len(set_calls) == 2
        assert set_calls[0] != set_calls[1]
|
|
|
|
|
|
class TestAddAsyncDecoratorToFunction:
    """add_async_decorator_to_function source rewriting."""

    def test_non_async_returns_false(self, tmp_path: Path) -> None:
        """Returns False for a non-async function without modifying."""
        source_file = tmp_path / "module.py"
        original_code = "def target_func():\n    pass\n"
        source_file.write_text(original_code, encoding="utf-8")
        func = make_function("target_func", str(source_file))
        result, originals = add_async_decorator_to_function(
            source_file,
            func,
            TestingMode.BEHAVIOR,
        )
        assert result is False
        assert {} == originals
        # "Without modifying" means byte-for-byte identical content.
        assert original_code == source_file.read_text(encoding="utf-8")

    def test_async_function_gets_decorator(self, tmp_path: Path) -> None:
        """Adds decorator, rewrites file, creates helper file."""
        source_file = tmp_path / "module.py"
        source_code = textwrap.dedent("""\
            async def target_func():
                pass
            """)
        source_file.write_text(source_code, encoding="utf-8")

        func = make_function(
            "target_func",
            str(source_file),
            is_async=True,
        )
        result, originals = add_async_decorator_to_function(
            source_file,
            func,
            TestingMode.BEHAVIOR,
        )
        assert result is True

        modified = source_file.read_text(encoding="utf-8")
        assert "codeflash_behavior_async" in modified

        helper = tmp_path / ASYNC_HELPER_FILENAME
        assert helper.exists()

    def test_originals_contains_pre_modification_source(
        self,
        tmp_path: Path,
    ) -> None:
        """The originals dict maps the file to its content before rewriting."""
        source_file = tmp_path / "module.py"
        original_code = "async def target_func():\n    pass\n"
        source_file.write_text(original_code, encoding="utf-8")

        func = make_function(
            "target_func",
            str(source_file),
            is_async=True,
        )
        _, originals = add_async_decorator_to_function(
            source_file,
            func,
            TestingMode.BEHAVIOR,
        )
        assert source_file in originals
        assert original_code == originals[source_file]

    def test_with_explicit_project_root(self, tmp_path: Path) -> None:
        """Writes helper file to project_root when specified."""
        src_dir = tmp_path / "src"
        src_dir.mkdir()
        source_file = src_dir / "module.py"
        source_code = textwrap.dedent("""\
            async def target_func():
                pass
            """)
        source_file.write_text(source_code, encoding="utf-8")

        project_root = tmp_path / "root"
        project_root.mkdir()

        func = make_function(
            "target_func",
            str(source_file),
            is_async=True,
        )
        result, _ = add_async_decorator_to_function(
            source_file,
            func,
            TestingMode.BEHAVIOR,
            project_root=project_root,
        )
        assert result is True
        assert (project_root / ASYNC_HELPER_FILENAME).exists()
        assert not (src_dir / ASYNC_HELPER_FILENAME).exists()

    def test_already_decorated_returns_false(self, tmp_path: Path) -> None:
        """Returns False when the function already has the decorator."""
        source_file = tmp_path / "module.py"
        source_code = textwrap.dedent("""\
            @codeflash_behavior_async
            async def target_func():
                pass
            """)
        source_file.write_text(source_code, encoding="utf-8")

        func = make_function(
            "target_func",
            str(source_file),
            is_async=True,
        )
        result, originals = add_async_decorator_to_function(
            source_file,
            func,
            TestingMode.BEHAVIOR,
        )
        assert result is False
        assert {} == originals
        # Nothing was recorded for restoration, so nothing may be rewritten.
        assert source_code == source_file.read_text(encoding="utf-8")

    def test_adds_import_for_decorator(self, tmp_path: Path) -> None:
        """The rewritten file includes the import for the decorator."""
        source_file = tmp_path / "module.py"
        source_file.write_text(
            "async def target_func():\n    pass\n",
            encoding="utf-8",
        )
        func = make_function(
            "target_func",
            str(source_file),
            is_async=True,
        )
        result, _ = add_async_decorator_to_function(
            source_file,
            func,
            TestingMode.PERFORMANCE,
        )
        assert result is True
        modified = source_file.read_text(encoding="utf-8")
        assert "from codeflash_async_wrapper import" in modified
        assert "codeflash_performance_async" in modified

    def test_cst_parse_error_returns_false(self, tmp_path: Path) -> None:
        """Returns (False, {}) when the source file has invalid syntax."""
        source_file = tmp_path / "module.py"
        broken_code = "async def target_func(\n    invalid !!!"
        source_file.write_text(broken_code, encoding="utf-8")
        func = make_function(
            "target_func",
            str(source_file),
            is_async=True,
        )
        result, originals = add_async_decorator_to_function(
            source_file,
            func,
            TestingMode.BEHAVIOR,
        )
        assert result is False
        assert {} == originals
        assert broken_code == source_file.read_text(encoding="utf-8")
|
|
|
|
|
|
class TestCreateInstrumentedSourceModulePath:
|
|
"""create_instrumented_source_module_path path construction."""
|
|
|
|
def test_basic_path(self, tmp_path: Path) -> None:
    """The instrumented path is ``<temp dir>/instrumented_<source name>``."""
    expected = tmp_path / "instrumented_test_foo.py"
    actual = create_instrumented_source_module_path(Path("test_foo.py"), tmp_path)
    assert actual == expected
|
|
|
|
def test_preserves_extension(self, tmp_path: Path) -> None:
    """The ``.py`` suffix survives and the file sits directly in the temp dir."""
    instrumented = create_instrumented_source_module_path(
        Path("my_module.py"), tmp_path
    )
    assert instrumented.name == "instrumented_my_module.py"
    assert instrumented.parent == tmp_path
|