mirror of
https://github.com/codeflash-ai/codeflash-agent.git
synced 2026-05-04 18:25:19 +00:00
Remove dead code (unused fields, hasattr guard, duplicate decorator set), rename _optimized_instrument_statement to _find_awaited_target_call, simplify AsyncDecoratorAdder init and leave_FunctionDef. Add 21 new unit tests covering all branches: non-test skipping, attribute calls, class body recursion, counter independence, decorator deduplication (name and call form), error handlers, and mode selection.
1346 lines
48 KiB
Python
1346 lines
48 KiB
Python
"""Tests for _instrumentation — test instrumentation and AST transforms."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import ast
|
|
import textwrap
|
|
from pathlib import Path
|
|
|
|
import libcst as cst
|
|
|
|
from codeflash_python._model import (
|
|
FunctionParent,
|
|
FunctionToOptimize,
|
|
TestingMode,
|
|
VerificationType,
|
|
)
|
|
from codeflash_python.test_discovery.models import CodePosition
|
|
from codeflash_python.testing._instrumentation import (
|
|
ASYNC_HELPER_FILENAME,
|
|
AsyncCallInstrumenter,
|
|
AsyncDecoratorAdder,
|
|
FunctionCallNodeArguments,
|
|
FunctionImportedAsVisitor,
|
|
InjectPerfOnly,
|
|
add_async_decorator_to_function,
|
|
create_device_sync_precompute_statements,
|
|
create_device_sync_statements,
|
|
create_instrumented_source_module_path,
|
|
create_wrapper_function,
|
|
detect_frameworks_from_code,
|
|
get_call_arguments,
|
|
get_decorator_name_for_mode,
|
|
inject_async_profiling_into_existing_test,
|
|
inject_profiling_into_existing_test,
|
|
is_argument_name,
|
|
node_in_call_position,
|
|
sort_imports,
|
|
write_async_helper_file,
|
|
)
|
|
|
|
|
|
def make_function(
    name: str = "target_func",
    file_path: str = "module.py",
    parents: tuple[FunctionParent, ...] = (),
    *,
    is_async: bool = False,
) -> FunctionToOptimize:
    """Build a ``FunctionToOptimize`` fixture for the tests in this module.

    Args:
        name: Name of the function under optimization.
        file_path: Source file of the function, given as a string and
            converted to a ``Path``.
        parents: Enclosing scopes of the function (e.g. a class).
        is_async: Whether the function is declared ``async``.

    Returns:
        A ``FunctionToOptimize`` populated from the arguments.
    """
    fields = {
        "function_name": name,
        "file_path": Path(file_path),
        "parents": parents,
        "is_async": is_async,
    }
    return FunctionToOptimize(**fields)
|
|
|
|
|
|
class TestTestingMode:
    """Value and size checks for the ``TestingMode`` enum."""

    def test_enum_values(self) -> None:
        """Every mode exposes its documented string value."""
        expectations = [
            ("behavior", TestingMode.BEHAVIOR),
            ("performance", TestingMode.PERFORMANCE),
            ("line_profile", TestingMode.LINE_PROFILE),
            ("concurrency", TestingMode.CONCURRENCY),
        ]
        for expected_value, member in expectations:
            assert expected_value == member.value

    def test_membership(self) -> None:
        """The enum defines exactly four modes."""
        member_count = len(TestingMode)
        assert 4 == member_count
|
|
|
|
|
|
class TestVerificationType:
    """Checks for the string-valued ``VerificationType`` enum."""

    def test_is_str_enum(self) -> None:
        """Each member is itself an instance of ``str``."""
        members = (
            VerificationType.FUNCTION_CALL,
            VerificationType.INIT_STATE_FTO,
            VerificationType.INIT_STATE_HELPER,
        )
        for member in members:
            assert isinstance(member, str)

    def test_enum_values(self) -> None:
        """Each member compares equal to its documented string."""
        expectations = [
            ("function_call", VerificationType.FUNCTION_CALL),
            ("init_state_fto", VerificationType.INIT_STATE_FTO),
            ("init_state_helper", VerificationType.INIT_STATE_HELPER),
        ]
        for text, member in expectations:
            assert text == member

    def test_membership(self) -> None:
        """The enum defines exactly three verification types."""
        member_count = len(VerificationType)
        assert 3 == member_count
|
|
|
|
|
|
class TestGetCallArguments:
    """Argument extraction from ``ast.Call`` nodes via get_call_arguments."""

    @staticmethod
    def _first_call(source: str) -> ast.Call:
        """Parse *source* and return the Call of its first expression statement."""
        first_statement = ast.parse(source).body[0]
        return first_statement.value  # type: ignore[attr-defined]

    def test_simple_call(self) -> None:
        """Both positional arguments and keywords are collected."""
        extracted = get_call_arguments(self._first_call("func(1, 2, key='val')"))
        assert isinstance(extracted, FunctionCallNodeArguments)
        assert 2 == len(extracted.args)
        assert 1 == len(extracted.keywords)

    def test_no_args(self) -> None:
        """An argument-less call yields two empty lists."""
        extracted = get_call_arguments(self._first_call("func()"))
        assert [] == extracted.args
        assert [] == extracted.keywords

    def test_only_keywords(self) -> None:
        """A keyword-only call yields no positional args."""
        extracted = get_call_arguments(self._first_call("func(a=1, b=2)"))
        assert [] == extracted.args
        assert 2 == len(extracted.keywords)
|
|
|
|
|
|
class TestNodeInCallPosition:
    """Matching of Call nodes against CodePosition lists."""

    @staticmethod
    def _target_call() -> ast.Call:
        """Return the Call node for a bare ``target_func()`` on line 1."""
        module = ast.parse("target_func()\n")
        return module.body[0].value  # type: ignore[attr-defined]

    def test_matching_position(self) -> None:
        """A position equal to the call's location is a match."""
        hit = CodePosition(line_no=1, col_no=0)
        assert node_in_call_position(self._target_call(), [hit]) is True

    def test_no_matching_position(self) -> None:
        """A position elsewhere in the file is not a match."""
        miss = CodePosition(line_no=99, col_no=99)
        assert node_in_call_position(self._target_call(), [miss]) is False

    def test_empty_positions_matches_all(self) -> None:
        """An empty position list matches every call."""
        assert node_in_call_position(self._target_call(), []) is True

    def test_multiple_positions_one_match(self) -> None:
        """A single matching entry among several is enough."""
        candidates = [
            CodePosition(line_no=50, col_no=0),
            CodePosition(line_no=1, col_no=0),
        ]
        assert node_in_call_position(self._target_call(), candidates) is True
|
|
|
|
|
|
class TestIsArgumentName:
    """Detection of parameter names via is_argument_name."""

    @staticmethod
    def _arguments_of(source: str) -> ast.arguments:
        """Parse a single ``def`` in *source* and return its ``arguments`` node."""
        function_def = ast.parse(source).body[0]
        return function_def.args  # type: ignore[attr-defined]

    def test_regular_arg(self) -> None:
        """A plain positional parameter is recognised."""
        assert is_argument_name("x", self._arguments_of("def f(x, y): pass")) is True

    def test_kwonly_arg(self) -> None:
        """A keyword-only parameter is recognised."""
        assert is_argument_name("key", self._arguments_of("def f(*, key): pass")) is True

    def test_no_match(self) -> None:
        """A name that is not a parameter is rejected."""
        assert is_argument_name("z", self._arguments_of("def f(x, y): pass")) is False

    def test_vararg_not_matched(self) -> None:
        """``*args`` is not reported (vararg is not a list attribute)."""
        assert is_argument_name("args", self._arguments_of("def f(*args): pass")) is False

    def test_kwarg_not_matched(self) -> None:
        """``**kwargs`` is not reported (kwarg is not a list attribute)."""
        assert is_argument_name("kwargs", self._arguments_of("def f(**kwargs): pass")) is False
|
|
|
|
|
|
class TestDetectFrameworksFromCode:
    """GPU framework detection from import statements."""

    def test_torch_import(self) -> None:
        """``import torch`` reports torch."""
        assert "torch" in detect_frameworks_from_code("import torch\n")

    def test_tensorflow_import(self) -> None:
        """``import tensorflow`` reports tensorflow."""
        assert "tensorflow" in detect_frameworks_from_code("import tensorflow\n")

    def test_jax_import(self) -> None:
        """``import jax`` reports jax."""
        assert "jax" in detect_frameworks_from_code("import jax\n")

    def test_aliased_import(self) -> None:
        """An alias is recorded as the value for the framework key."""
        detected = detect_frameworks_from_code("import torch as th\n")
        assert "torch" in detected
        assert "th" == detected["torch"]

    def test_no_frameworks(self) -> None:
        """Unrelated imports produce an empty mapping."""
        assert {} == detect_frameworks_from_code("import os\nimport sys\n")

    def test_from_import_submodule(self) -> None:
        """``from torch import nn`` still reports torch."""
        assert "torch" in detect_frameworks_from_code("from torch import nn\n")

    def test_multiple_frameworks(self) -> None:
        """Several frameworks in one snippet are all reported."""
        detected = detect_frameworks_from_code("import torch\nimport jax\n")
        for framework in ("torch", "jax"):
            assert framework in detected
|
|
|
|
|
|
class TestGetDecoratorNameForMode:
    """Async decorator name chosen for each TestingMode."""

    def test_behavior_mode(self) -> None:
        """BEHAVIOR selects codeflash_behavior_async."""
        chosen = get_decorator_name_for_mode(TestingMode.BEHAVIOR)
        assert "codeflash_behavior_async" == chosen

    def test_performance_mode(self) -> None:
        """PERFORMANCE selects codeflash_performance_async."""
        chosen = get_decorator_name_for_mode(TestingMode.PERFORMANCE)
        assert "codeflash_performance_async" == chosen

    def test_concurrency_mode(self) -> None:
        """CONCURRENCY selects codeflash_concurrency_async."""
        chosen = get_decorator_name_for_mode(TestingMode.CONCURRENCY)
        assert "codeflash_concurrency_async" == chosen

    def test_line_profile_falls_through_to_performance(self) -> None:
        """LINE_PROFILE has no async-specific decorator and uses performance."""
        chosen = get_decorator_name_for_mode(TestingMode.LINE_PROFILE)
        assert "codeflash_performance_async" == chosen
|
|
|
|
|
|
class TestCreateDeviceSyncPrecomputeStatements:
    """create_device_sync_precompute_statements AST generation."""

    @staticmethod
    def _assert_nonempty_statements(result: list[ast.stmt]) -> None:
        """Assert *result* is non-empty and contains only ``ast.stmt`` nodes."""
        assert len(result) > 0
        assert all(isinstance(s, ast.stmt) for s in result)

    def test_none_frameworks(self) -> None:
        """Returns empty list when frameworks is None."""
        assert [] == create_device_sync_precompute_statements(None)

    def test_empty_dict(self) -> None:
        """Returns empty list when frameworks dict is empty."""
        assert [] == create_device_sync_precompute_statements({})

    def test_torch_produces_statements(self) -> None:
        """Produces AST statements for torch."""
        self._assert_nonempty_statements(
            create_device_sync_precompute_statements({"torch": "torch"}),
        )

    def test_jax_produces_statements(self) -> None:
        """Produces AST statements for jax."""
        self._assert_nonempty_statements(
            create_device_sync_precompute_statements({"jax": "jax"}),
        )

    def test_tensorflow_produces_statements(self) -> None:
        """Produces AST statements for tensorflow."""
        self._assert_nonempty_statements(
            create_device_sync_precompute_statements({"tensorflow": "tf"}),
        )

    def test_combined_frameworks(self) -> None:
        """Produces valid AST statements for multiple frameworks at once."""
        # The combined case previously only checked length; hold it to the
        # same "every element is an ast.stmt" contract as the single cases.
        self._assert_nonempty_statements(
            create_device_sync_precompute_statements(
                {"torch": "torch", "jax": "jax"},
            ),
        )
|
|
|
|
|
|
class TestCreateDeviceSyncStatements:
    """AST generation by create_device_sync_statements."""

    def test_none_frameworks(self) -> None:
        """No frameworks (None) yields no statements."""
        assert [] == create_device_sync_statements(None)

    def test_empty_dict(self) -> None:
        """An empty framework mapping yields no statements."""
        assert [] == create_device_sync_statements({})

    def test_torch_sync(self) -> None:
        """torch produces a non-empty list of statements."""
        statements = create_device_sync_statements({"torch": "torch"})
        assert len(statements) > 0
        assert all(isinstance(stmt, ast.stmt) for stmt in statements)

    def test_for_return_value_flag(self) -> None:
        """for_return_value=True still produces statements."""
        statements = create_device_sync_statements(
            {"jax": "jax"},
            for_return_value=True,
        )
        assert len(statements) > 0
        assert all(isinstance(stmt, ast.stmt) for stmt in statements)

    def test_tensorflow_sync(self) -> None:
        """tensorflow produces a non-empty list of statements."""
        statements = create_device_sync_statements({"tensorflow": "tf"})
        assert len(statements) > 0
        assert all(isinstance(stmt, ast.stmt) for stmt in statements)
|
|
|
|
|
|
class TestCreateWrapperFunction:
    """Shape of the ``codeflash_wrap`` FunctionDef built by create_wrapper_function."""

    def test_returns_function_def(self) -> None:
        """The result is an ``ast.FunctionDef``."""
        wrapper = create_wrapper_function(TestingMode.BEHAVIOR)
        assert isinstance(wrapper, ast.FunctionDef)

    def test_function_name(self) -> None:
        """The wrapper is named ``codeflash_wrap``."""
        wrapper = create_wrapper_function(TestingMode.BEHAVIOR)
        assert "codeflash_wrap" == wrapper.name

    def test_behavior_mode_params(self) -> None:
        """The BEHAVIOR wrapper declares positional parameters."""
        wrapper = create_wrapper_function(TestingMode.BEHAVIOR)
        positional = [param.arg for param in wrapper.args.args]
        assert len(positional) > 0

    def test_performance_mode_params(self) -> None:
        """The PERFORMANCE wrapper declares positional parameters."""
        wrapper = create_wrapper_function(TestingMode.PERFORMANCE)
        positional = [param.arg for param in wrapper.args.args]
        assert len(positional) > 0

    def test_body_is_nonempty(self) -> None:
        """The wrapper body is not empty."""
        wrapper = create_wrapper_function(TestingMode.BEHAVIOR)
        assert len(wrapper.body) > 0

    def test_with_frameworks(self) -> None:
        """Passing used_frameworks still yields a FunctionDef."""
        wrapper = create_wrapper_function(
            TestingMode.PERFORMANCE,
            used_frameworks={"torch": "torch"},
        )
        assert isinstance(wrapper, ast.FunctionDef)
|
|
|
|
|
|
class TestInjectPerfOnly:
    """Call wrapping performed by the InjectPerfOnly AST transformer."""

    @staticmethod
    def _first_call_position(tree: ast.Module) -> CodePosition:
        """Return the position of the call on the first line of the first test body."""
        call = tree.body[0].body[0].value  # type: ignore[attr-defined]
        return CodePosition(line_no=call.lineno, col_no=call.col_offset)

    @staticmethod
    def _instrument(tree: ast.Module, positions: list[CodePosition]) -> str:
        """Run InjectPerfOnly in BEHAVIOR mode for ``target_func`` and unparse."""
        injector = InjectPerfOnly(
            function=make_function("target_func", "module.py"),
            module_path="module",
            call_positions=positions,
            mode=TestingMode.BEHAVIOR,
        )
        return ast.unparse(injector.visit(tree))

    def test_wraps_name_call(self) -> None:
        """A bare ``target_func(...)`` call is wrapped with codeflash_wrap."""
        tree = ast.parse(textwrap.dedent("""\
            def test_it():
                result = target_func(1, 2)
        """))
        rendered = self._instrument(tree, [self._first_call_position(tree)])
        assert "codeflash_wrap" in rendered

    def test_wraps_attribute_call(self) -> None:
        """A ``module.target_func(...)`` call is wrapped with codeflash_wrap."""
        tree = ast.parse(textwrap.dedent("""\
            def test_it():
                result = module.target_func(1, 2)
        """))
        rendered = self._instrument(tree, [self._first_call_position(tree)])
        assert "codeflash_wrap" in rendered

    def test_no_wrap_without_matching_position(self) -> None:
        """Calls outside the given positions are left untouched."""
        tree = ast.parse(textwrap.dedent("""\
            def test_it():
                result = target_func(1, 2)
        """))
        rendered = self._instrument(tree, [CodePosition(line_no=99, col_no=99)])
        assert "codeflash_wrap" not in rendered
|
|
|
|
|
|
class TestAsyncCallInstrumenter:
    """AsyncCallInstrumenter AST transformer."""

    @staticmethod
    def _callee_name(call: ast.Call) -> str | None:
        """Return the called name for ``f(...)`` or ``obj.f(...)``, else None."""
        if isinstance(call.func, ast.Name):
            return call.func.id
        if isinstance(call.func, ast.Attribute):
            return call.func.attr
        return None

    def _make_transformer(
        self,
        code: str,
        *,
        name: str = "target_func",
        parents: tuple[FunctionParent, ...] = (),
        positions: list[CodePosition] | None = None,
    ) -> tuple[AsyncCallInstrumenter, ast.Module]:
        """Parse *code* and build an AsyncCallInstrumenter for *name*.

        When *positions* is None, call positions are collected from every
        ``name(...)`` / ``obj.name(...)`` call in the parsed code, in
        ``ast.walk`` order.

        Returns:
            The transformer and the parsed (not yet transformed) module.
        """
        tree = ast.parse(code)
        if positions is None:
            positions = [
                CodePosition(line_no=node.lineno, col_no=node.col_offset)
                for node in ast.walk(tree)
                if isinstance(node, ast.Call)
                and self._callee_name(node) == name
            ]
        func = make_function(
            name,
            "module.py",
            parents=parents,
            is_async=True,
        )
        transformer = AsyncCallInstrumenter(
            function=func,
            module_path="module",
            call_positions=positions,
        )
        return transformer, tree

    def test_instruments_await_call(self) -> None:
        """Adds env var assignment before an awaited target call."""
        code = textwrap.dedent("""\
            async def test_it():
                result = await target_func(1, 2)
        """)
        transformer, tree = self._make_transformer(code)
        source = ast.unparse(transformer.visit(tree))
        assert "CODEFLASH_CURRENT_LINE_ID" in source
        assert transformer.did_instrument is True

    def test_skips_non_test_async_functions(self) -> None:
        """Does not instrument async functions that don't start with test_."""
        code = textwrap.dedent("""\
            async def helper():
                result = await target_func(1, 2)
        """)
        transformer, tree = self._make_transformer(code)
        source = ast.unparse(transformer.visit(tree))
        assert "CODEFLASH_CURRENT_LINE_ID" not in source
        assert transformer.did_instrument is False

    def test_skips_non_test_sync_functions(self) -> None:
        """Does not instrument sync functions that don't start with test_."""
        code = textwrap.dedent("""\
            def helper():
                result = await target_func(1, 2)
        """)
        transformer, tree = self._make_transformer(code)
        source = ast.unparse(transformer.visit(tree))
        assert "CODEFLASH_CURRENT_LINE_ID" not in source
        # Mirror the async variant: skipping must also leave the flag unset.
        assert transformer.did_instrument is False

    def test_instruments_sync_test_with_await(self) -> None:
        """Instruments sync test_ functions that contain awaited calls."""
        code = textwrap.dedent("""\
            def test_it():
                result = await target_func(1, 2)
        """)
        transformer, tree = self._make_transformer(code)
        source = ast.unparse(transformer.visit(tree))
        assert "CODEFLASH_CURRENT_LINE_ID" in source
        assert transformer.did_instrument is True

    def test_multiple_awaits_get_incrementing_ids(self) -> None:
        """Each awaited target call gets a unique incrementing counter."""
        code = textwrap.dedent("""\
            async def test_it():
                a = await target_func(1)
                b = await target_func(2)
                c = await target_func(3)
        """)
        transformer, tree = self._make_transformer(code)
        source = ast.unparse(transformer.visit(tree))
        assert "'0'" in source
        assert "'1'" in source
        assert "'2'" in source
        # The per-test counter must have advanced once per awaited call.
        assert 3 == transformer.async_call_counter["test_it"]

    def test_attribute_style_call(self) -> None:
        """Instruments await obj.target_func() attribute-style calls."""
        code = textwrap.dedent("""\
            async def test_it():
                result = await obj.target_func(1, 2)
        """)
        transformer, tree = self._make_transformer(code)
        source = ast.unparse(transformer.visit(tree))
        assert "CODEFLASH_CURRENT_LINE_ID" in source
        assert transformer.did_instrument is True

    def test_recurses_into_class_body(self) -> None:
        """Finds and instruments test methods inside a class."""
        code = textwrap.dedent("""\
            class TestSuite:
                async def test_it(self):
                    result = await target_func(1)
        """)
        transformer, tree = self._make_transformer(code)
        source = ast.unparse(transformer.visit(tree))
        assert "CODEFLASH_CURRENT_LINE_ID" in source
        assert transformer.did_instrument is True

    def test_no_match_when_position_wrong(self) -> None:
        """Does not instrument when call positions don't match."""
        code = textwrap.dedent("""\
            async def test_it():
                result = await target_func(1, 2)
        """)
        transformer, tree = self._make_transformer(
            code,
            positions=[CodePosition(line_no=99, col_no=99)],
        )
        transformer.visit(tree)
        assert transformer.did_instrument is False

    def test_nested_await_in_conditional(self) -> None:
        """Finds awaited target calls nested inside if statements."""
        code = textwrap.dedent("""\
            async def test_it():
                if True:
                    result = await target_func(1)
        """)
        transformer, tree = self._make_transformer(code)
        source = ast.unparse(transformer.visit(tree))
        assert "CODEFLASH_CURRENT_LINE_ID" in source
        assert transformer.did_instrument is True

    def test_ignores_non_target_awaits(self) -> None:
        """Does not instrument awaits of unrelated functions."""
        code = textwrap.dedent("""\
            async def test_it():
                result = await other_func(1, 2)
        """)
        transformer, tree = self._make_transformer(code)
        transformer.visit(tree)
        assert transformer.did_instrument is False

    def test_subscript_style_call_not_matched(self) -> None:
        """Await of a subscript-style call (funcs[0]()) is not matched."""
        code = textwrap.dedent("""\
            async def test_it():
                result = await funcs[0](1, 2)
        """)
        transformer, tree = self._make_transformer(code)
        transformer.visit(tree)
        assert transformer.did_instrument is False

    def test_counters_independent_per_test_function(self) -> None:
        """Each test function gets its own independent counter."""
        code = textwrap.dedent("""\
            async def test_a():
                await target_func(1)
                await target_func(2)
            async def test_b():
                await target_func(3)
        """)
        transformer, tree = self._make_transformer(code)
        transformer.visit(tree)
        assert 2 == transformer.async_call_counter["test_a"]
        assert 1 == transformer.async_call_counter["test_b"]
|
|
|
|
|
|
class TestFunctionImportedAsVisitor:
    """Alias tracking performed by FunctionImportedAsVisitor."""

    @staticmethod
    def _visit(
        source: str,
        function: FunctionToOptimize,
    ) -> FunctionImportedAsVisitor:
        """Parse *source*, run the visitor for *function*, and return it."""
        module = ast.parse(source)
        visitor = FunctionImportedAsVisitor(function)
        visitor.visit(module)
        return visitor

    def test_aliased_import(self) -> None:
        """``import ... as`` renames the tracked function."""
        visitor = self._visit(
            textwrap.dedent("""\
                from module import target_func as tf
            """),
            make_function("target_func", "module.py"),
        )
        assert "tf" == visitor.imported_as.function_name

    def test_non_aliased_import(self) -> None:
        """A plain import leaves the original function object in place."""
        original = make_function("target_func", "module.py")
        visitor = self._visit(
            textwrap.dedent("""\
                from module import target_func
            """),
            original,
        )
        assert visitor.imported_as is original

    def test_class_method_aliased_import(self) -> None:
        """Aliasing the enclosing class renames the method's parent."""
        owner = FunctionParent(name="MyClass", type="ClassDef")
        visitor = self._visit(
            textwrap.dedent("""\
                from module import MyClass as MC
            """),
            make_function("method", "module.py", parents=(owner,)),
        )
        assert "MC" == visitor.imported_as.parents[0].name

    def test_no_import(self) -> None:
        """Without a relevant import the original function is kept."""
        original = make_function("target_func", "module.py")
        visitor = self._visit(
            textwrap.dedent("""\
                import os
            """),
            original,
        )
        assert visitor.imported_as is original
|
|
|
|
|
|
class TestAsyncDecoratorAdder:
    """Decorator insertion performed by the AsyncDecoratorAdder CST transformer."""

    @staticmethod
    def _apply(
        source: str,
        *,
        mode: TestingMode | None = None,
        parents: tuple[FunctionParent, ...] = (),
    ) -> tuple[AsyncDecoratorAdder, str]:
        """Run AsyncDecoratorAdder for an async ``target_func`` over *source*.

        Args:
            source: Module source to transform.
            mode: Testing mode; ``None`` means ``TestingMode.BEHAVIOR``.
            parents: Enclosing scopes of ``target_func``.

        Returns:
            The transformer (for flag inspection) and the rewritten code.
        """
        module = cst.parse_module(source)
        target = make_function(
            "target_func",
            "module.py",
            parents=parents,
            is_async=True,
        )
        adder = AsyncDecoratorAdder(
            target,
            mode=TestingMode.BEHAVIOR if mode is None else mode,
        )
        rewritten = module.visit(adder)
        return adder, rewritten.code

    def test_adds_decorator_to_async_function(self) -> None:
        """A matching async def gains the behavior decorator."""
        adder, output = self._apply(textwrap.dedent("""\
            async def target_func():
                pass
        """))
        assert "@codeflash_behavior_async" in output
        assert adder.added_decorator is True

    def test_does_not_add_to_non_matching(self) -> None:
        """A differently named async def is left undecorated."""
        adder, output = self._apply(textwrap.dedent("""\
            async def other_func():
                pass
        """))
        assert "@" not in output
        assert adder.added_decorator is False

    def test_does_not_add_to_sync_function(self) -> None:
        """A sync def with the target name is left undecorated."""
        adder, output = self._apply(textwrap.dedent("""\
            def target_func():
                pass
        """))
        assert "@" not in output
        assert adder.added_decorator is False

    def test_performance_mode_decorator(self) -> None:
        """PERFORMANCE mode inserts codeflash_performance_async."""
        _, output = self._apply(
            "async def target_func():\n pass\n",
            mode=TestingMode.PERFORMANCE,
        )
        assert "@codeflash_performance_async" in output

    def test_concurrency_mode_decorator(self) -> None:
        """CONCURRENCY mode inserts codeflash_concurrency_async."""
        _, output = self._apply(
            "async def target_func():\n pass\n",
            mode=TestingMode.CONCURRENCY,
        )
        assert "@codeflash_concurrency_async" in output

    def test_class_method_matching(self) -> None:
        """An async method is matched through its class-qualified name."""
        owner = FunctionParent(name="MyClass", type="ClassDef")
        adder, output = self._apply(
            textwrap.dedent("""\
                class MyClass:
                    async def target_func(self):
                        pass
            """),
            parents=(owner,),
        )
        assert "@codeflash_behavior_async" in output
        assert adder.added_decorator is True

    def test_no_duplicate_when_already_decorated(self) -> None:
        """An existing bare decorator prevents a second one."""
        adder, output = self._apply(textwrap.dedent("""\
            @codeflash_behavior_async
            async def target_func():
                pass
        """))
        assert 1 == output.count("codeflash_behavior_async")
        assert adder.added_decorator is False

    def test_no_duplicate_when_call_style_decorator(self) -> None:
        """An existing ``@decorator()`` call form prevents a second one."""
        adder, output = self._apply(textwrap.dedent("""\
            @codeflash_behavior_async()
            async def target_func():
                pass
        """))
        assert 1 == output.count("codeflash_behavior_async")
        assert adder.added_decorator is False

    def test_preserves_existing_decorators(self) -> None:
        """Other decorators survive and the codeflash one goes first."""
        _, output = self._apply(textwrap.dedent("""\
            @staticmethod
            async def target_func():
                pass
        """))
        assert "@codeflash_behavior_async" in output
        assert "@staticmethod" in output
        assert output.index("@codeflash_behavior_async") < output.index("@staticmethod")

    def test_attribute_decorator_not_matched(self) -> None:
        """``@mod.codeflash_behavior_async`` does not count as already decorated."""
        adder, output = self._apply(textwrap.dedent("""\
            @mod.codeflash_behavior_async
            async def target_func():
                pass
        """))
        assert 2 == output.count("codeflash_behavior_async")
        assert adder.added_decorator is True
|
|
|
|
|
|
class TestWriteAsyncHelperFile:
    """write_async_helper_file file creation."""

    def test_creates_file(self, tmp_path: Path) -> None:
        """Creates the async helper file in the target directory."""
        result = write_async_helper_file(tmp_path)
        assert result.exists()
        assert result.is_file()

    def test_file_name(self, tmp_path: Path) -> None:
        """The created file has the expected name."""
        result = write_async_helper_file(tmp_path)
        assert ASYNC_HELPER_FILENAME == result.name

    def test_file_content_is_valid_python(self, tmp_path: Path) -> None:
        """The copied file is valid, parseable Python."""
        result = write_async_helper_file(tmp_path)
        # Read with an explicit encoding so the check does not depend on the
        # platform's locale default (matches the utf-8 write below).
        content = result.read_text(encoding="utf-8")
        assert len(content) > 0
        ast.parse(content)

    def test_idempotent_does_not_overwrite(self, tmp_path: Path) -> None:
        """Calling twice does not overwrite the existing file."""
        first = write_async_helper_file(tmp_path)
        first.write_text("sentinel", encoding="utf-8")
        second = write_async_helper_file(tmp_path)
        assert first == second
        assert "sentinel" == second.read_text(encoding="utf-8")
|
|
|
|
|
|
class TestAsyncHelperConstants:
    """ASYNC_HELPER_FILENAME and runtime decorator source."""

    def test_filename_value(self) -> None:
        """ASYNC_HELPER_FILENAME has the expected value."""
        assert "codeflash_async_wrapper.py" == ASYNC_HELPER_FILENAME

    def test_runtime_decorator_is_self_contained(self) -> None:
        """Runtime decorator file has no internal codeflash imports.

        Every ``import`` / ``from ... import`` statement in the file (at any
        nesting level) must avoid ``codeflash_python``, and the file must not
        use relative imports, which would also reach into the package.
        """
        from codeflash_python.testing._instrument_async import (
            _RUNTIME_DECORATOR_PATH,
        )

        # Check existence before reading so a missing file fails here rather
        # than as an unrelated FileNotFoundError from read_text.
        assert _RUNTIME_DECORATOR_PATH.exists()
        tree = ast.parse(_RUNTIME_DECORATOR_PATH.read_text("utf-8"))

        # Walk real import nodes instead of scanning text lines: this catches
        # relative imports and ignores docstrings that merely mention imports.
        for node in ast.walk(tree):
            if isinstance(node, ast.Import):
                names = [alias.name for alias in node.names]
            elif isinstance(node, ast.ImportFrom):
                assert 0 == node.level, (
                    f"relative import on line {node.lineno}"
                )
                names = [node.module or ""]
                names.extend(alias.name for alias in node.names)
            else:
                continue
            for imported in names:
                assert "codeflash_python" not in imported, (
                    f"internal import {imported!r} on line {node.lineno}"
                )
|
|
|
|
|
|
class TestSortImports:
    """sort_imports import sorting."""

    def test_sorts_unsorted_imports(self) -> None:
        """Sorts and deduplicates unsorted imports."""
        code = textwrap.dedent("""\
            import os
            import ast
            import os
            """)
        result = sort_imports(code)
        lines = result.strip().splitlines()
        assert "import ast" in lines[0]
        assert "import os" in lines[1]
        # Duplicate removed
        assert result.count("import os") == 1

    def test_syntax_error_returns_original(self) -> None:
        """Returns original code unchanged when isort declines to sort it.

        isort refuses to process content carrying a ``# isort: skip_file``
        action comment (isort 5 signals this by raising
        ``FileSkipComment``). Whether ``sort_imports`` falls back to its
        input on that error, or isort hands the text back untouched, the
        deliberately unsorted imports below must come back byte-for-byte.
        """
        code = "# isort: skip_file\nimport os\nimport ast\n"
        result = sort_imports(code)
        assert code == result

    def test_kwargs_forwarded(self) -> None:
        """Forwards kwargs to isort.code (e.g. float_to_top)."""
        code = textwrap.dedent("""\
            from os.path import join

            x = 1

            import ast
            """)
        result = sort_imports(code, float_to_top=True)
        lines = result.strip().splitlines()
        import_lines = [i for i, line in enumerate(lines) if "import" in line]
        code_lines = [
            i for i, line in enumerate(lines) if line.strip() == "x = 1"
        ]
        # Assert unconditionally: guarding the ordering check with
        # ``if import_lines and code_lines`` let it pass vacuously.
        assert len(import_lines) == 2
        assert len(code_lines) == 1
        # With float_to_top, both imports must appear before 'x = 1'.
        assert max(import_lines) < min(code_lines)
|
|
|
|
|
|
class TestInjectProfilingIntoExistingTest:
    """inject_profiling_into_existing_test orchestration."""

    def test_sync_function_instrumentation(self, tmp_path: Path) -> None:
        """Instruments a sync test file with codeflash_wrap and imports."""
        project_root = tmp_path / "project"
        project_root.mkdir()
        test_file = project_root / "test_example.py"
        test_code = textwrap.dedent("""\
            from module import target_func

            def test_something():
                result = target_func(1, 2)
                assert result == 3
            """)
        test_file.write_text(test_code, encoding="utf-8")

        func = make_function("target_func", "module.py")
        # target_func(1, 2) is on line 4, col 13
        positions = [CodePosition(line_no=4, col_no=13)]

        ok, source = inject_profiling_into_existing_test(
            test_file,
            positions,
            func,
            project_root,
            mode=TestingMode.PERFORMANCE,
        )
        assert ok is True
        assert source is not None
        assert "codeflash_wrap" in source
        assert "import time" in source
        assert "import gc" in source
        assert "import os" in source

    def test_async_delegation(self, tmp_path: Path) -> None:
        """Delegates to async handler for async functions.

        For this exact input, ``inject_async_profiling_into_existing_test``
        instruments the awaited call and adds ``import os``. Delegation
        must therefore yield the same successful outcome.
        """
        project_root = tmp_path / "project"
        project_root.mkdir()
        test_file = project_root / "test_async.py"
        test_code = textwrap.dedent("""\
            from module import target_func

            async def test_something():
                result = await target_func(1, 2)
            """)
        test_file.write_text(test_code, encoding="utf-8")

        func = make_function("target_func", "module.py", is_async=True)
        positions = [CodePosition(line_no=4, col_no=25)]

        ok, source = inject_profiling_into_existing_test(
            test_file,
            positions,
            func,
            project_root,
        )
        # Assert the outcome outright: the previous ``if ok:`` guard let
        # the test pass even when nothing was instrumented.
        assert ok is True
        assert source is not None
        assert "import os" in source

    def test_syntax_error_returns_false(self, tmp_path: Path) -> None:
        """Returns (False, None) for a file with invalid Python."""
        project_root = tmp_path / "project"
        project_root.mkdir()
        test_file = project_root / "test_bad.py"
        test_file.write_text(
            "def test_x(\n not valid python !!!",
            encoding="utf-8",
        )

        func = make_function("target_func", "module.py")
        positions = [CodePosition(line_no=1, col_no=0)]

        ok, source = inject_profiling_into_existing_test(
            test_file,
            positions,
            func,
            project_root,
        )
        assert ok is False
        assert source is None

    def test_behavior_mode_extra_imports(self, tmp_path: Path) -> None:
        """BEHAVIOR mode adds inspect, sqlite3, and dill imports."""
        project_root = tmp_path / "project"
        project_root.mkdir()
        test_file = project_root / "test_behav.py"
        test_code = textwrap.dedent("""\
            from module import target_func

            def test_something():
                result = target_func(1, 2)
                assert result == 3
            """)
        test_file.write_text(test_code, encoding="utf-8")

        func = make_function("target_func", "module.py")
        positions = [CodePosition(line_no=4, col_no=13)]

        ok, source = inject_profiling_into_existing_test(
            test_file,
            positions,
            func,
            project_root,
            mode=TestingMode.BEHAVIOR,
        )
        assert ok is True
        assert source is not None
        assert "inspect" in source
        assert "sqlite3" in source
        assert "dill" in source
|
|
|
|
|
|
class TestInjectAsyncProfilingIntoExistingTest:
    """Orchestration tests for inject_async_profiling_into_existing_test."""

    @staticmethod
    def _project_with_test(
        tmp_path: Path, filename: str, content: str
    ) -> tuple[Path, Path]:
        """Create ``<tmp_path>/project`` and write *content* to *filename*.

        Returns ``(project_root, test_path)``.
        """
        root = tmp_path / "project"
        root.mkdir()
        test_path = root / filename
        test_path.write_text(content, encoding="utf-8")
        return root, test_path

    def test_async_instrumentation(self, tmp_path: Path) -> None:
        """An awaited target call is instrumented and ``import os`` added."""
        root, test_path = self._project_with_test(
            tmp_path,
            "test_async_ex.py",
            textwrap.dedent("""\
                from module import target_func

                async def test_something():
                    result = await target_func(1, 2)
                """),
        )
        ok, instrumented = inject_async_profiling_into_existing_test(
            test_path,
            [CodePosition(line_no=4, col_no=25)],
            make_function("target_func", "module.py", is_async=True),
            root,
        )
        assert ok is True
        assert instrumented is not None
        assert "import os" in instrumented

    def test_no_instrumentation(self, tmp_path: Path) -> None:
        """A test that never calls the target yields (False, None)."""
        root, test_path = self._project_with_test(
            tmp_path,
            "test_no_call.py",
            textwrap.dedent("""\
                def test_something():
                    assert 1 == 1
                """),
        )
        ok, instrumented = inject_async_profiling_into_existing_test(
            test_path,
            [CodePosition(line_no=2, col_no=0)],
            make_function("target_func", "module.py", is_async=True),
            root,
        )
        assert ok is False
        assert instrumented is None

    def test_syntax_error_returns_false(self, tmp_path: Path) -> None:
        """Unparseable test source yields (False, None)."""
        root, test_path = self._project_with_test(
            tmp_path,
            "test_bad.py",
            "async def test_x(\n not valid !!!",
        )
        ok, instrumented = inject_async_profiling_into_existing_test(
            test_path,
            [CodePosition(line_no=1, col_no=0)],
            make_function("target_func", "module.py", is_async=True),
            root,
        )
        assert ok is False
        assert instrumented is None

    def test_multiple_awaits_get_sequential_ids(
        self,
        tmp_path: Path,
    ) -> None:
        """Each of two awaited target calls gets a line-ID assignment."""
        root, test_path = self._project_with_test(
            tmp_path,
            "test_multi.py",
            textwrap.dedent("""\
                from module import target_func

                async def test_multi():
                    a = await target_func(1)
                    b = await target_func(2)
                """),
        )
        call_sites = [
            CodePosition(line_no=4, col_no=22),
            CodePosition(line_no=5, col_no=22),
        ]
        ok, instrumented = inject_async_profiling_into_existing_test(
            test_path,
            call_sites,
            make_function("target_func", "module.py", is_async=True),
            root,
        )
        assert ok is True
        assert instrumented is not None
        assert instrumented.count("CODEFLASH_CURRENT_LINE_ID") == 2
|
|
|
|
|
|
class TestAddAsyncDecoratorToFunction:
    """add_async_decorator_to_function source rewriting."""

    def test_non_async_returns_false(self, tmp_path: Path) -> None:
        """Returns False for a non-async function without modifying."""
        source_file = tmp_path / "module.py"
        original_code = "def target_func():\n pass\n"
        source_file.write_text(original_code, encoding="utf-8")
        func = make_function("target_func", str(source_file))
        result, originals = add_async_decorator_to_function(
            source_file,
            func,
            TestingMode.BEHAVIOR,
        )
        assert result is False
        assert {} == originals
        # Compare the whole file. The old substring check for "decorator"
        # could never fail, because the inserted decorator is named
        # codeflash_behavior_async.
        assert original_code == source_file.read_text(encoding="utf-8")

    def test_async_function_gets_decorator(self, tmp_path: Path) -> None:
        """Adds decorator, rewrites file, creates helper file."""
        source_file = tmp_path / "module.py"
        source_code = textwrap.dedent("""\
            async def target_func():
                pass
            """)
        source_file.write_text(source_code, encoding="utf-8")

        func = make_function(
            "target_func",
            str(source_file),
            is_async=True,
        )
        result, originals = add_async_decorator_to_function(
            source_file,
            func,
            TestingMode.BEHAVIOR,
        )
        assert result is True

        modified = source_file.read_text(encoding="utf-8")
        assert "codeflash_behavior_async" in modified

        helper = tmp_path / ASYNC_HELPER_FILENAME
        assert helper.exists()

    def test_originals_contains_pre_modification_source(
        self,
        tmp_path: Path,
    ) -> None:
        """The originals dict maps the file to its content before rewriting."""
        source_file = tmp_path / "module.py"
        original_code = "async def target_func():\n pass\n"
        source_file.write_text(original_code, encoding="utf-8")

        func = make_function(
            "target_func",
            str(source_file),
            is_async=True,
        )
        _, originals = add_async_decorator_to_function(
            source_file,
            func,
            TestingMode.BEHAVIOR,
        )
        assert source_file in originals
        assert original_code == originals[source_file]

    def test_with_explicit_project_root(self, tmp_path: Path) -> None:
        """Writes helper file to project_root when specified."""
        src_dir = tmp_path / "src"
        src_dir.mkdir()
        source_file = src_dir / "module.py"
        source_code = textwrap.dedent("""\
            async def target_func():
                pass
            """)
        source_file.write_text(source_code, encoding="utf-8")

        project_root = tmp_path / "root"
        project_root.mkdir()

        func = make_function(
            "target_func",
            str(source_file),
            is_async=True,
        )
        result, _ = add_async_decorator_to_function(
            source_file,
            func,
            TestingMode.BEHAVIOR,
            project_root=project_root,
        )
        assert result is True
        assert (project_root / ASYNC_HELPER_FILENAME).exists()
        assert not (src_dir / ASYNC_HELPER_FILENAME).exists()

    def test_already_decorated_returns_false(self, tmp_path: Path) -> None:
        """Returns False when the function already has the decorator."""
        source_file = tmp_path / "module.py"
        source_code = textwrap.dedent("""\
            @codeflash_behavior_async
            async def target_func():
                pass
            """)
        source_file.write_text(source_code, encoding="utf-8")

        func = make_function(
            "target_func",
            str(source_file),
            is_async=True,
        )
        result, originals = add_async_decorator_to_function(
            source_file,
            func,
            TestingMode.BEHAVIOR,
        )
        assert result is False
        assert {} == originals

    def test_adds_import_for_decorator(self, tmp_path: Path) -> None:
        """The rewritten file includes the import for the decorator."""
        source_file = tmp_path / "module.py"
        source_file.write_text(
            "async def target_func():\n pass\n",
            encoding="utf-8",
        )
        func = make_function(
            "target_func",
            str(source_file),
            is_async=True,
        )
        add_async_decorator_to_function(
            source_file,
            func,
            TestingMode.PERFORMANCE,
        )
        modified = source_file.read_text(encoding="utf-8")
        assert "from codeflash_async_wrapper import" in modified
        assert "codeflash_performance_async" in modified

    def test_cst_parse_error_returns_false(self, tmp_path: Path) -> None:
        """Returns (False, {}) when the source file has invalid syntax."""
        source_file = tmp_path / "module.py"
        source_file.write_text(
            "async def target_func(\n invalid !!!",
            encoding="utf-8",
        )
        func = make_function(
            "target_func",
            str(source_file),
            is_async=True,
        )
        result, originals = add_async_decorator_to_function(
            source_file,
            func,
            TestingMode.BEHAVIOR,
        )
        assert result is False
        assert {} == originals
|
|
|
|
|
|
class TestCreateInstrumentedSourceModulePath:
|
|
"""create_instrumented_source_module_path path construction."""
|
|
|
|
def test_basic_path(self, tmp_path: Path) -> None:
|
|
"""Constructs instrumented path from source path and temp dir."""
|
|
result = create_instrumented_source_module_path(
|
|
Path("test_foo.py"), tmp_path
|
|
)
|
|
assert tmp_path / "instrumented_test_foo.py" == result
|
|
|
|
def test_preserves_extension(self, tmp_path: Path) -> None:
|
|
"""Preserves the .py extension in the instrumented filename."""
|
|
result = create_instrumented_source_module_path(
|
|
Path("my_module.py"), tmp_path
|
|
)
|
|
assert "instrumented_my_module.py" == result.name
|
|
assert tmp_path == result.parent
|