"""Tests for _instrumentation — test instrumentation and AST transforms.""" from __future__ import annotations import ast import textwrap from pathlib import Path import libcst as cst from codeflash_python._model import ( FunctionParent, FunctionToOptimize, TestingMode, VerificationType, ) from codeflash_python.test_discovery.models import CodePosition from codeflash_python.testing._instrumentation import ( ASYNC_HELPER_FILENAME, AsyncCallInstrumenter, AsyncDecoratorAdder, FunctionCallNodeArguments, FunctionImportedAsVisitor, InjectPerfOnly, add_async_decorator_to_function, create_device_sync_precompute_statements, create_device_sync_statements, create_instrumented_source_module_path, create_wrapper_function, detect_frameworks_from_code, get_call_arguments, get_decorator_name_for_mode, inject_async_profiling_into_existing_test, inject_profiling_into_existing_test, is_argument_name, node_in_call_position, sort_imports, write_async_helper_file, ) def make_function( name: str = "target_func", file_path: str = "module.py", parents: tuple[FunctionParent, ...] 
= (), *, is_async: bool = False, ) -> FunctionToOptimize: """Create a FunctionToOptimize for testing.""" return FunctionToOptimize( function_name=name, file_path=Path(file_path), parents=parents, is_async=is_async, ) class TestTestingMode: """TestingMode enum values.""" def test_enum_values(self) -> None: """Each mode has the expected string value.""" assert "behavior" == TestingMode.BEHAVIOR.value assert "performance" == TestingMode.PERFORMANCE.value assert "line_profile" == TestingMode.LINE_PROFILE.value assert "concurrency" == TestingMode.CONCURRENCY.value def test_membership(self) -> None: """All four members are present.""" assert 4 == len(TestingMode) class TestVerificationType: """VerificationType str enum values.""" def test_is_str_enum(self) -> None: """VerificationType members are strings.""" assert isinstance(VerificationType.FUNCTION_CALL, str) assert isinstance(VerificationType.INIT_STATE_FTO, str) assert isinstance(VerificationType.INIT_STATE_HELPER, str) def test_enum_values(self) -> None: """Each type has the expected string value.""" assert "function_call" == VerificationType.FUNCTION_CALL assert "init_state_fto" == VerificationType.INIT_STATE_FTO assert "init_state_helper" == VerificationType.INIT_STATE_HELPER def test_membership(self) -> None: """All three members are present.""" assert 3 == len(VerificationType) class TestGetCallArguments: """get_call_arguments Call node extraction.""" def test_simple_call(self) -> None: """Extracts positional args and keywords from a Call node.""" tree = ast.parse("func(1, 2, key='val')") call_node = tree.body[0].value # type: ignore[attr-defined] result = get_call_arguments(call_node) assert isinstance(result, FunctionCallNodeArguments) assert 2 == len(result.args) assert 1 == len(result.keywords) def test_no_args(self) -> None: """Returns empty lists for a call with no arguments.""" tree = ast.parse("func()") call_node = tree.body[0].value # type: ignore[attr-defined] result = get_call_arguments(call_node) 
assert [] == result.args assert [] == result.keywords def test_only_keywords(self) -> None: """Returns keywords when only keyword args are present.""" tree = ast.parse("func(a=1, b=2)") call_node = tree.body[0].value # type: ignore[attr-defined] result = get_call_arguments(call_node) assert [] == result.args assert 2 == len(result.keywords) class TestNodeInCallPosition: """node_in_call_position position matching.""" def test_matching_position(self) -> None: """Returns True when Call node matches a listed position.""" code = "target_func()\n" tree = ast.parse(code) call_node = tree.body[0].value # type: ignore[attr-defined] positions = [CodePosition(line_no=1, col_no=0)] assert node_in_call_position(call_node, positions) is True def test_no_matching_position(self) -> None: """Returns False when Call node does not match any position.""" code = "target_func()\n" tree = ast.parse(code) call_node = tree.body[0].value # type: ignore[attr-defined] positions = [CodePosition(line_no=99, col_no=99)] assert node_in_call_position(call_node, positions) is False def test_empty_positions_matches_all(self) -> None: """Returns True for any Call when positions list is empty.""" code = "target_func()\n" tree = ast.parse(code) call_node = tree.body[0].value # type: ignore[attr-defined] assert node_in_call_position(call_node, []) is True def test_multiple_positions_one_match(self) -> None: """Returns True when one of several positions matches.""" code = "target_func()\n" tree = ast.parse(code) call_node = tree.body[0].value # type: ignore[attr-defined] positions = [ CodePosition(line_no=50, col_no=0), CodePosition(line_no=1, col_no=0), ] assert node_in_call_position(call_node, positions) is True class TestIsArgumentName: """is_argument_name argument detection.""" def test_regular_arg(self) -> None: """Returns True for a regular positional argument name.""" code = "def f(x, y): pass" tree = ast.parse(code) func_def = tree.body[0] assert is_argument_name("x", func_def.args) is True # 
type: ignore[attr-defined] def test_kwonly_arg(self) -> None: """Returns True for a keyword-only argument name.""" code = "def f(*, key): pass" tree = ast.parse(code) func_def = tree.body[0] assert is_argument_name("key", func_def.args) is True # type: ignore[attr-defined] def test_no_match(self) -> None: """Returns False when name is not an argument.""" code = "def f(x, y): pass" tree = ast.parse(code) func_def = tree.body[0] assert is_argument_name("z", func_def.args) is False # type: ignore[attr-defined] def test_vararg_not_matched(self) -> None: """Returns False for *args (vararg is not a list attribute).""" code = "def f(*args): pass" tree = ast.parse(code) func_def = tree.body[0] assert is_argument_name("args", func_def.args) is False # type: ignore[attr-defined] def test_kwarg_not_matched(self) -> None: """Returns False for **kwargs (kwarg is not a list attribute).""" code = "def f(**kwargs): pass" tree = ast.parse(code) func_def = tree.body[0] assert is_argument_name("kwargs", func_def.args) is False # type: ignore[attr-defined] class TestDetectFrameworksFromCode: """detect_frameworks_from_code import detection.""" def test_torch_import(self) -> None: """Detects torch from 'import torch'.""" code = "import torch\n" result = detect_frameworks_from_code(code) assert "torch" in result def test_tensorflow_import(self) -> None: """Detects tensorflow from 'import tensorflow'.""" code = "import tensorflow\n" result = detect_frameworks_from_code(code) assert "tensorflow" in result def test_jax_import(self) -> None: """Detects jax from 'import jax'.""" code = "import jax\n" result = detect_frameworks_from_code(code) assert "jax" in result def test_aliased_import(self) -> None: """Detects framework with alias from 'import torch as th'.""" code = "import torch as th\n" result = detect_frameworks_from_code(code) assert "torch" in result assert "th" == result["torch"] def test_no_frameworks(self) -> None: """Returns empty dict when no GPU frameworks are imported.""" 
code = "import os\nimport sys\n" result = detect_frameworks_from_code(code) assert {} == result def test_from_import_submodule(self) -> None: """Detects framework from 'from torch import nn'.""" code = "from torch import nn\n" result = detect_frameworks_from_code(code) assert "torch" in result def test_multiple_frameworks(self) -> None: """Detects multiple frameworks in the same code.""" code = "import torch\nimport jax\n" result = detect_frameworks_from_code(code) assert "torch" in result assert "jax" in result class TestGetDecoratorNameForMode: """get_decorator_name_for_mode decorator selection.""" def test_behavior_mode(self) -> None: """Returns codeflash_behavior_async for BEHAVIOR.""" assert "codeflash_behavior_async" == get_decorator_name_for_mode( TestingMode.BEHAVIOR, ) def test_performance_mode(self) -> None: """Returns codeflash_performance_async for PERFORMANCE.""" assert "codeflash_performance_async" == get_decorator_name_for_mode( TestingMode.PERFORMANCE, ) def test_concurrency_mode(self) -> None: """Returns codeflash_concurrency_async for CONCURRENCY.""" assert "codeflash_concurrency_async" == get_decorator_name_for_mode( TestingMode.CONCURRENCY, ) def test_line_profile_falls_through_to_performance(self) -> None: """LINE_PROFILE is not async-specific, falls through to performance.""" assert "codeflash_performance_async" == get_decorator_name_for_mode( TestingMode.LINE_PROFILE, ) class TestCreateDeviceSyncPrecomputeStatements: """create_device_sync_precompute_statements AST generation.""" def test_none_frameworks(self) -> None: """Returns empty list when frameworks is None.""" result = create_device_sync_precompute_statements(None) assert [] == result def test_empty_dict(self) -> None: """Returns empty list when frameworks dict is empty.""" result = create_device_sync_precompute_statements({}) assert [] == result def test_torch_produces_statements(self) -> None: """Produces AST statements for torch.""" result = create_device_sync_precompute_statements( 
{"torch": "torch"}, ) assert len(result) > 0 assert all(isinstance(s, ast.stmt) for s in result) def test_jax_produces_statements(self) -> None: """Produces AST statements for jax.""" result = create_device_sync_precompute_statements({"jax": "jax"}) assert len(result) > 0 assert all(isinstance(s, ast.stmt) for s in result) def test_tensorflow_produces_statements(self) -> None: """Produces AST statements for tensorflow.""" result = create_device_sync_precompute_statements( {"tensorflow": "tf"}, ) assert len(result) > 0 assert all(isinstance(s, ast.stmt) for s in result) def test_combined_frameworks(self) -> None: """Produces statements for multiple frameworks.""" result = create_device_sync_precompute_statements( {"torch": "torch", "jax": "jax"}, ) assert len(result) > 0 class TestCreateDeviceSyncStatements: """create_device_sync_statements AST generation.""" def test_none_frameworks(self) -> None: """Returns empty list when frameworks is None.""" result = create_device_sync_statements(None) assert [] == result def test_empty_dict(self) -> None: """Returns empty list when frameworks dict is empty.""" result = create_device_sync_statements({}) assert [] == result def test_torch_sync(self) -> None: """Produces sync statements for torch.""" result = create_device_sync_statements({"torch": "torch"}) assert len(result) > 0 assert all(isinstance(s, ast.stmt) for s in result) def test_for_return_value_flag(self) -> None: """Produces statements with for_return_value=True.""" result = create_device_sync_statements( {"jax": "jax"}, for_return_value=True, ) assert len(result) > 0 assert all(isinstance(s, ast.stmt) for s in result) def test_tensorflow_sync(self) -> None: """Produces sync statements for tensorflow.""" result = create_device_sync_statements({"tensorflow": "tf"}) assert len(result) > 0 assert all(isinstance(s, ast.stmt) for s in result) class TestCreateWrapperFunction: """create_wrapper_function AST generation.""" def test_returns_function_def(self) -> None: 
"""Returns an ast.FunctionDef node.""" result = create_wrapper_function(TestingMode.BEHAVIOR) assert isinstance(result, ast.FunctionDef) def test_function_name(self) -> None: """The generated function is named codeflash_wrap.""" result = create_wrapper_function(TestingMode.BEHAVIOR) assert "codeflash_wrap" == result.name def test_behavior_mode_params(self) -> None: """BEHAVIOR mode wrapper has expected parameters.""" result = create_wrapper_function(TestingMode.BEHAVIOR) arg_names = [a.arg for a in result.args.args] assert len(arg_names) > 0 def test_performance_mode_params(self) -> None: """PERFORMANCE mode wrapper has expected parameters.""" result = create_wrapper_function(TestingMode.PERFORMANCE) arg_names = [a.arg for a in result.args.args] assert len(arg_names) > 0 def test_body_is_nonempty(self) -> None: """The function body contains statements.""" result = create_wrapper_function(TestingMode.BEHAVIOR) assert len(result.body) > 0 def test_with_frameworks(self) -> None: """Accepts used_frameworks parameter without error.""" result = create_wrapper_function( TestingMode.PERFORMANCE, used_frameworks={"torch": "torch"}, ) assert isinstance(result, ast.FunctionDef) class TestInjectPerfOnly: """InjectPerfOnly AST transformer.""" def test_wraps_name_call(self) -> None: """Wraps a direct Name call with codeflash_wrap.""" code = textwrap.dedent("""\ def test_it(): result = target_func(1, 2) """) tree = ast.parse(code) call_node = tree.body[0].body[0].value # type: ignore[attr-defined] pos = CodePosition( line_no=call_node.lineno, col_no=call_node.col_offset, ) func = make_function("target_func", "module.py") transformer = InjectPerfOnly( function=func, module_path="module", call_positions=[pos], mode=TestingMode.BEHAVIOR, ) new_tree = transformer.visit(tree) source = ast.unparse(new_tree) assert "codeflash_wrap" in source def test_wraps_attribute_call(self) -> None: """Wraps a module.func() attribute call with codeflash_wrap.""" code = textwrap.dedent("""\ def 
test_it(): result = module.target_func(1, 2) """) tree = ast.parse(code) call_node = tree.body[0].body[0].value # type: ignore[attr-defined] pos = CodePosition( line_no=call_node.lineno, col_no=call_node.col_offset, ) func = make_function("target_func", "module.py") transformer = InjectPerfOnly( function=func, module_path="module", call_positions=[pos], mode=TestingMode.BEHAVIOR, ) new_tree = transformer.visit(tree) source = ast.unparse(new_tree) assert "codeflash_wrap" in source def test_no_wrap_without_matching_position(self) -> None: """Does not wrap calls that are not in call_positions.""" code = textwrap.dedent("""\ def test_it(): result = target_func(1, 2) """) tree = ast.parse(code) func = make_function("target_func", "module.py") transformer = InjectPerfOnly( function=func, module_path="module", call_positions=[CodePosition(line_no=99, col_no=99)], mode=TestingMode.BEHAVIOR, ) new_tree = transformer.visit(tree) source = ast.unparse(new_tree) assert "codeflash_wrap" not in source class TestAsyncCallInstrumenter: """AsyncCallInstrumenter AST transformer.""" def _make_transformer( self, code: str, *, name: str = "target_func", parents: tuple[FunctionParent, ...] 
= (), positions: list[CodePosition] | None = None, ) -> tuple[AsyncCallInstrumenter, ast.Module]: """Parse code and build a transformer with call positions from it.""" tree = ast.parse(code) if positions is None: positions = [] for node in ast.walk(tree): if isinstance(node, ast.Call): func_name = None if isinstance(node.func, ast.Name): func_name = node.func.id elif isinstance(node.func, ast.Attribute): func_name = node.func.attr if func_name == name: positions.append( CodePosition( line_no=node.lineno, col_no=node.col_offset, ) ) func = make_function( name, "module.py", parents=parents, is_async=True, ) transformer = AsyncCallInstrumenter( function=func, module_path="module", call_positions=positions, ) return transformer, tree def test_instruments_await_call(self) -> None: """Adds env var assignment before an awaited target call.""" code = textwrap.dedent("""\ async def test_it(): result = await target_func(1, 2) """) transformer, tree = self._make_transformer(code) new_tree = transformer.visit(tree) source = ast.unparse(new_tree) assert "CODEFLASH_CURRENT_LINE_ID" in source assert transformer.did_instrument is True def test_skips_non_test_async_functions(self) -> None: """Does not instrument async functions that don't start with test_.""" code = textwrap.dedent("""\ async def helper(): result = await target_func(1, 2) """) transformer, tree = self._make_transformer(code) new_tree = transformer.visit(tree) source = ast.unparse(new_tree) assert "CODEFLASH_CURRENT_LINE_ID" not in source assert transformer.did_instrument is False def test_skips_non_test_sync_functions(self) -> None: """Does not instrument sync functions that don't start with test_.""" code = textwrap.dedent("""\ def helper(): result = await target_func(1, 2) """) transformer, tree = self._make_transformer(code) new_tree = transformer.visit(tree) source = ast.unparse(new_tree) assert "CODEFLASH_CURRENT_LINE_ID" not in source def test_instruments_sync_test_with_await(self) -> None: """Instruments 
sync test_ functions that contain awaited calls.""" code = textwrap.dedent("""\ def test_it(): result = await target_func(1, 2) """) transformer, tree = self._make_transformer(code) new_tree = transformer.visit(tree) source = ast.unparse(new_tree) assert "CODEFLASH_CURRENT_LINE_ID" in source assert transformer.did_instrument is True def test_multiple_awaits_get_incrementing_ids(self) -> None: """Each awaited target call gets a unique incrementing counter.""" code = textwrap.dedent("""\ async def test_it(): a = await target_func(1) b = await target_func(2) c = await target_func(3) """) transformer, tree = self._make_transformer(code) new_tree = transformer.visit(tree) source = ast.unparse(new_tree) assert "'0'" in source assert "'1'" in source assert "'2'" in source def test_attribute_style_call(self) -> None: """Instruments await obj.target_func() attribute-style calls.""" code = textwrap.dedent("""\ async def test_it(): result = await obj.target_func(1, 2) """) transformer, tree = self._make_transformer(code) new_tree = transformer.visit(tree) source = ast.unparse(new_tree) assert "CODEFLASH_CURRENT_LINE_ID" in source assert transformer.did_instrument is True def test_recurses_into_class_body(self) -> None: """Finds and instruments test methods inside a class.""" code = textwrap.dedent("""\ class TestSuite: async def test_it(self): result = await target_func(1) """) transformer, tree = self._make_transformer(code) new_tree = transformer.visit(tree) source = ast.unparse(new_tree) assert "CODEFLASH_CURRENT_LINE_ID" in source assert transformer.did_instrument is True def test_no_match_when_position_wrong(self) -> None: """Does not instrument when call positions don't match.""" code = textwrap.dedent("""\ async def test_it(): result = await target_func(1, 2) """) transformer, tree = self._make_transformer( code, positions=[CodePosition(line_no=99, col_no=99)], ) new_tree = transformer.visit(tree) assert transformer.did_instrument is False def 
test_nested_await_in_conditional(self) -> None: """Finds awaited target calls nested inside if statements.""" code = textwrap.dedent("""\ async def test_it(): if True: result = await target_func(1) """) transformer, tree = self._make_transformer(code) new_tree = transformer.visit(tree) source = ast.unparse(new_tree) assert "CODEFLASH_CURRENT_LINE_ID" in source def test_ignores_non_target_awaits(self) -> None: """Does not instrument awaits of unrelated functions.""" code = textwrap.dedent("""\ async def test_it(): result = await other_func(1, 2) """) transformer, tree = self._make_transformer(code) new_tree = transformer.visit(tree) assert transformer.did_instrument is False def test_subscript_style_call_not_matched(self) -> None: """Await of a subscript-style call (funcs[0]()) is not matched.""" code = textwrap.dedent("""\ async def test_it(): result = await funcs[0](1, 2) """) transformer, tree = self._make_transformer(code) transformer.visit(tree) assert transformer.did_instrument is False def test_counters_independent_per_test_function(self) -> None: """Each test function gets its own independent counter.""" code = textwrap.dedent("""\ async def test_a(): await target_func(1) await target_func(2) async def test_b(): await target_func(3) """) transformer, tree = self._make_transformer(code) transformer.visit(tree) assert 2 == transformer.async_call_counter["test_a"] assert 1 == transformer.async_call_counter["test_b"] class TestFunctionImportedAsVisitor: """FunctionImportedAsVisitor alias detection.""" def test_aliased_import(self) -> None: """Updates imported_as with the aliased FunctionToOptimize.""" code = textwrap.dedent("""\ from module import target_func as tf """) tree = ast.parse(code) func = make_function("target_func", "module.py") visitor = FunctionImportedAsVisitor(func) visitor.visit(tree) assert "tf" == visitor.imported_as.function_name def test_non_aliased_import(self) -> None: """Keeps original function when imported without alias.""" code = 
textwrap.dedent("""\ from module import target_func """) tree = ast.parse(code) func = make_function("target_func", "module.py") visitor = FunctionImportedAsVisitor(func) visitor.visit(tree) assert visitor.imported_as is func def test_class_method_aliased_import(self) -> None: """Updates parent name when class is imported with alias.""" code = textwrap.dedent("""\ from module import MyClass as MC """) tree = ast.parse(code) parent = FunctionParent(name="MyClass", type="ClassDef") func = make_function( "method", "module.py", parents=(parent,), ) visitor = FunctionImportedAsVisitor(func) visitor.visit(tree) assert "MC" == visitor.imported_as.parents[0].name def test_no_import(self) -> None: """Keeps original function when not imported.""" code = textwrap.dedent("""\ import os """) tree = ast.parse(code) func = make_function("target_func", "module.py") visitor = FunctionImportedAsVisitor(func) visitor.visit(tree) assert visitor.imported_as is func class TestAsyncDecoratorAdder: """AsyncDecoratorAdder CST transformer.""" def test_adds_decorator_to_async_function(self) -> None: """Adds the async decorator to a matching async function.""" code = textwrap.dedent("""\ async def target_func(): pass """) tree = cst.parse_module(code) func = make_function("target_func", "module.py", is_async=True) transformer = AsyncDecoratorAdder(func, mode=TestingMode.BEHAVIOR) new_tree = tree.visit(transformer) output = new_tree.code assert "@codeflash_behavior_async" in output assert transformer.added_decorator is True def test_does_not_add_to_non_matching(self) -> None: """Does not add decorator to functions that do not match.""" code = textwrap.dedent("""\ async def other_func(): pass """) tree = cst.parse_module(code) func = make_function("target_func", "module.py", is_async=True) transformer = AsyncDecoratorAdder(func, mode=TestingMode.BEHAVIOR) new_tree = tree.visit(transformer) output = new_tree.code assert "@" not in output assert transformer.added_decorator is False def 
test_does_not_add_to_sync_function(self) -> None: """Skips sync functions even if the name matches.""" code = textwrap.dedent("""\ def target_func(): pass """) tree = cst.parse_module(code) func = make_function("target_func", "module.py", is_async=True) transformer = AsyncDecoratorAdder(func, mode=TestingMode.BEHAVIOR) new_tree = tree.visit(transformer) assert "@" not in new_tree.code assert transformer.added_decorator is False def test_performance_mode_decorator(self) -> None: """Uses codeflash_performance_async for PERFORMANCE mode.""" code = "async def target_func():\n pass\n" tree = cst.parse_module(code) func = make_function("target_func", "module.py", is_async=True) transformer = AsyncDecoratorAdder( func, mode=TestingMode.PERFORMANCE, ) new_tree = tree.visit(transformer) assert "@codeflash_performance_async" in new_tree.code def test_concurrency_mode_decorator(self) -> None: """Uses codeflash_concurrency_async for CONCURRENCY mode.""" code = "async def target_func():\n pass\n" tree = cst.parse_module(code) func = make_function("target_func", "module.py", is_async=True) transformer = AsyncDecoratorAdder( func, mode=TestingMode.CONCURRENCY, ) new_tree = tree.visit(transformer) assert "@codeflash_concurrency_async" in new_tree.code def test_class_method_matching(self) -> None: """Matches async method inside a class via qualified name.""" code = textwrap.dedent("""\ class MyClass: async def target_func(self): pass """) tree = cst.parse_module(code) parent = FunctionParent(name="MyClass", type="ClassDef") func = make_function( "target_func", "module.py", parents=(parent,), is_async=True, ) transformer = AsyncDecoratorAdder(func, mode=TestingMode.BEHAVIOR) new_tree = tree.visit(transformer) assert "@codeflash_behavior_async" in new_tree.code assert transformer.added_decorator is True def test_no_duplicate_when_already_decorated(self) -> None: """Does not add a second decorator when one is already present.""" code = textwrap.dedent("""\ @codeflash_behavior_async 
async def target_func(): pass """) tree = cst.parse_module(code) func = make_function("target_func", "module.py", is_async=True) transformer = AsyncDecoratorAdder(func, mode=TestingMode.BEHAVIOR) new_tree = tree.visit(transformer) assert new_tree.code.count("codeflash_behavior_async") == 1 assert transformer.added_decorator is False def test_no_duplicate_when_call_style_decorator(self) -> None: """Detects existing decorator even in @decorator() call form.""" code = textwrap.dedent("""\ @codeflash_behavior_async() async def target_func(): pass """) tree = cst.parse_module(code) func = make_function("target_func", "module.py", is_async=True) transformer = AsyncDecoratorAdder(func, mode=TestingMode.BEHAVIOR) new_tree = tree.visit(transformer) assert new_tree.code.count("codeflash_behavior_async") == 1 assert transformer.added_decorator is False def test_preserves_existing_decorators(self) -> None: """Keeps existing decorators and prepends the codeflash one.""" code = textwrap.dedent("""\ @staticmethod async def target_func(): pass """) tree = cst.parse_module(code) func = make_function("target_func", "module.py", is_async=True) transformer = AsyncDecoratorAdder(func, mode=TestingMode.BEHAVIOR) new_tree = tree.visit(transformer) output = new_tree.code assert "@codeflash_behavior_async" in output assert "@staticmethod" in output behavior_pos = output.index("@codeflash_behavior_async") static_pos = output.index("@staticmethod") assert behavior_pos < static_pos def test_attribute_decorator_not_matched(self) -> None: """Attribute-style decorators (mod.decorator) are not codeflash.""" code = textwrap.dedent("""\ @mod.codeflash_behavior_async async def target_func(): pass """) tree = cst.parse_module(code) func = make_function("target_func", "module.py", is_async=True) transformer = AsyncDecoratorAdder(func, mode=TestingMode.BEHAVIOR) new_tree = tree.visit(transformer) assert new_tree.code.count("codeflash_behavior_async") == 2 assert transformer.added_decorator is True 
class TestWriteAsyncHelperFile:
    """write_async_helper_file file creation."""

    def test_creates_file(self, tmp_path: Path) -> None:
        """Creates the async helper file in the target directory."""
        result = write_async_helper_file(tmp_path)
        assert result.exists()
        assert result.is_file()

    def test_file_name(self, tmp_path: Path) -> None:
        """The created file has the expected name."""
        result = write_async_helper_file(tmp_path)
        assert ASYNC_HELPER_FILENAME == result.name

    def test_file_content_is_valid_python(self, tmp_path: Path) -> None:
        """The copied file is non-empty, parseable Python."""
        result = write_async_helper_file(tmp_path)
        content = result.read_text(encoding="utf-8")
        assert content
        ast.parse(content)

    def test_idempotent_does_not_overwrite(self, tmp_path: Path) -> None:
        """Calling twice does not overwrite the existing file."""
        first = write_async_helper_file(tmp_path)
        first.write_text("sentinel", encoding="utf-8")
        second = write_async_helper_file(tmp_path)
        assert first == second
        assert "sentinel" == second.read_text(encoding="utf-8")


class TestAsyncHelperConstants:
    """ASYNC_HELPER_FILENAME and runtime decorator source."""

    def test_filename_value(self) -> None:
        """ASYNC_HELPER_FILENAME has the expected value."""
        assert "codeflash_async_wrapper.py" == ASYNC_HELPER_FILENAME

    def test_runtime_decorator_is_self_contained(self) -> None:
        """Runtime decorator file has no internal codeflash imports.

        The source is inspected through its AST, so every import statement
        (including ones nested in functions) is checked, while docstrings or
        comments that merely mention the package are ignored.
        """
        from codeflash_python.testing._instrument_async import (
            _RUNTIME_DECORATOR_PATH,
        )

        # Assert existence before reading: otherwise a missing file surfaces
        # as FileNotFoundError from read_text() and this check never runs.
        assert _RUNTIME_DECORATOR_PATH.exists()
        source = _RUNTIME_DECORATOR_PATH.read_text(encoding="utf-8")
        for node in ast.walk(ast.parse(source)):
            if isinstance(node, ast.Import):
                imported = [alias.name for alias in node.names]
            elif isinstance(node, ast.ImportFrom):
                imported = [
                    node.module or "",
                    *(alias.name for alias in node.names),
                ]
            else:
                continue
            for name in imported:
                assert "codeflash_python" not in name


class TestSortImports:
    """sort_imports import sorting."""

    def test_sorts_unsorted_imports(self) -> None:
        """Sorts and deduplicates unsorted imports."""
        code = textwrap.dedent("""\
            import os
            import ast
            import os
            """)
        result = sort_imports(code)
        lines = result.strip().splitlines()
        assert "import ast" in lines[0]
        assert "import os" in lines[1]
        # Duplicate removed
        assert result.count("import os") == 1

    def test_syntax_error_returns_original(self) -> None:
        """Returns original code unchanged when isort cannot process it.

        ``# isort: skip_file`` tells isort to leave the content alone
        (isort.code may signal this by raising); whichever path sort_imports
        takes, the deliberately unsorted imports must come back untouched.
        """
        code = "# isort: skip_file\nimport os\nimport ast\n"
        result = sort_imports(code)
        assert code == result

    def test_kwargs_forwarded(self) -> None:
        """Forwards kwargs to isort.code (e.g. float_to_top)."""
        code = textwrap.dedent("""\
            from os.path import join
            x = 1
            import ast
            """)
        result = sort_imports(code, float_to_top=True)
        lines = result.strip().splitlines()
        import_lines = [i for i, line in enumerate(lines) if "import" in line]
        code_lines = [
            i for i, line in enumerate(lines) if line.strip() == "x = 1"
        ]
        # Both groups must be present; an empty list would make the ordering
        # check below pass vacuously.
        assert import_lines
        assert code_lines
        # With float_to_top, both imports should appear before 'x = 1'.
        assert max(import_lines) < min(code_lines)


class TestInjectProfilingIntoExistingTest:
    """inject_profiling_into_existing_test orchestration."""

    def test_sync_function_instrumentation(self, tmp_path: Path) -> None:
        """Instruments a sync test file with codeflash_wrap and imports."""
        project_root = tmp_path / "project"
        project_root.mkdir()
        test_file = project_root / "test_example.py"
        test_code = textwrap.dedent("""\
            from module import target_func

            def test_something():
                result = target_func(1, 2)
                assert result == 3
            """)
        test_file.write_text(test_code, encoding="utf-8")
        func = make_function("target_func", "module.py")
        # target_func(1, 2) is on line 4, col 13
        positions = [CodePosition(line_no=4, col_no=13)]
        ok, source = inject_profiling_into_existing_test(
            test_file,
            positions,
            func,
            project_root,
            mode=TestingMode.PERFORMANCE,
        )
        assert ok is True
        assert source is not None
        assert "codeflash_wrap" in source
        assert "import time" in source
        assert "import gc" in source
        assert "import os" in source

    def test_async_delegation(self, tmp_path: Path) -> None:
        """Delegates to async handler for async functions without error."""
        project_root = tmp_path / "project"
        project_root.mkdir()
        test_file = project_root / "test_async.py"
        test_code = textwrap.dedent("""\
            from module import target_func

            async def test_something():
                result = await target_func(1, 2)
            """)
        test_file.write_text(test_code, encoding="utf-8")
        func = make_function("target_func", "module.py", is_async=True)
        # The awaited call is on line 4; col 25 falls inside target_func(1, 2).
        positions = [CodePosition(line_no=4, col_no=25)]
        ok, source = inject_profiling_into_existing_test(
            test_file,
            positions,
            func,
            project_root,
        )
        # The same input instrumented directly through
        # inject_async_profiling_into_existing_test succeeds and adds
        # `import os` (see TestInjectAsyncProfilingIntoExistingTest), so the
        # delegated call must produce the same success, not just any bool.
        assert ok is True
        assert source is not None
        assert "import os" in source

    def test_syntax_error_returns_false(self, tmp_path: Path) -> None:
        """Returns (False, None) for a file with invalid Python."""
        project_root = tmp_path / "project"
        project_root.mkdir()
        test_file = project_root / "test_bad.py"
        test_file.write_text(
            "def test_x(\n    not valid python !!!",
            encoding="utf-8",
        )
        func = make_function("target_func", "module.py")
        positions = [CodePosition(line_no=1, col_no=0)]
        ok, source = inject_profiling_into_existing_test(
            test_file,
            positions,
            func,
            project_root,
        )
        assert ok is False
        assert source is None

    def test_behavior_mode_extra_imports(self, tmp_path: Path) -> None:
        """BEHAVIOR mode adds inspect, sqlite3, and dill imports."""
        project_root = tmp_path / "project"
        project_root.mkdir()
        test_file = project_root / "test_behav.py"
        test_code = textwrap.dedent("""\
            from module import target_func

            def test_something():
                result = target_func(1, 2)
                assert result == 3
            """)
        test_file.write_text(test_code, encoding="utf-8")
        func = make_function("target_func", "module.py")
        positions = [CodePosition(line_no=4, col_no=13)]
        ok, source = inject_profiling_into_existing_test(
            test_file,
            positions,
            func,
            project_root,
            mode=TestingMode.BEHAVIOR,
        )
        assert ok is True
        assert source is not None
        assert "inspect" in source
        assert "sqlite3" in source
        assert "dill" in source


class TestInjectAsyncProfilingIntoExistingTest:
    """inject_async_profiling_into_existing_test orchestration."""

    def test_async_instrumentation(self, tmp_path: Path) -> None:
        """Instruments an async test file and adds import os."""
        project_root = tmp_path / "project"
        project_root.mkdir()
        test_file = project_root / "test_async_ex.py"
        test_code = textwrap.dedent("""\
            from module import target_func

            async def test_something():
                result = await target_func(1, 2)
            """)
        test_file.write_text(test_code, encoding="utf-8")
        func = make_function("target_func", "module.py", is_async=True)
        positions = [CodePosition(line_no=4, col_no=25)]
        ok, source = inject_async_profiling_into_existing_test(
            test_file,
            positions,
            func,
            project_root,
        )
        assert ok is True
        assert source is not None
        assert "import os" in source

    def test_no_instrumentation(self, tmp_path: Path) -> None:
        """Returns (False, None) when test does not call target."""
        project_root = tmp_path / "project"
        project_root.mkdir()
        test_file = project_root / "test_no_call.py"
        test_code = textwrap.dedent("""\
            def test_something():
                assert 1 == 1
            """)
        test_file.write_text(test_code, encoding="utf-8")
        func = make_function("target_func", "module.py", is_async=True)
        positions = [CodePosition(line_no=2, col_no=0)]
        ok, source = inject_async_profiling_into_existing_test(
            test_file,
            positions,
            func,
            project_root,
        )
        assert ok is False
        assert source is None

    def test_syntax_error_returns_false(self, tmp_path: Path) -> None:
        """Returns (False, None) for a file with invalid Python."""
        project_root = tmp_path / "project"
        project_root.mkdir()
        test_file = project_root / "test_bad.py"
        test_file.write_text(
            "async def test_x(\n    not valid !!!",
            encoding="utf-8",
        )
        func = make_function("target_func", "module.py", is_async=True)
        ok, source = inject_async_profiling_into_existing_test(
            test_file,
            [CodePosition(line_no=1, col_no=0)],
            func,
            project_root,
        )
        assert ok is False
        assert source is None

    def test_multiple_awaits_get_sequential_ids(
        self,
        tmp_path: Path,
    ) -> None:
        """Multiple awaited calls in one test get sequential counter IDs."""
        project_root = tmp_path / "project"
        project_root.mkdir()
        test_file = project_root / "test_multi.py"
        test_code = textwrap.dedent("""\
            from module import target_func

            async def test_multi():
                a = await target_func(1)
                b = await target_func(2)
            """)
        test_file.write_text(test_code, encoding="utf-8")
        func = make_function("target_func", "module.py", is_async=True)
        positions = [
            CodePosition(line_no=4, col_no=22),
            CodePosition(line_no=5, col_no=22),
        ]
        ok, source = inject_async_profiling_into_existing_test(
            test_file,
            positions,
            func,
            project_root,
        )
        assert ok is True
        assert source is not None
        assert source.count("CODEFLASH_CURRENT_LINE_ID") == 2


class TestAddAsyncDecoratorToFunction:
    """add_async_decorator_to_function source rewriting."""

    def test_non_async_returns_false(self, tmp_path: Path) -> None:
        """Returns False for a non-async function without modifying."""
        source_file = tmp_path / "module.py"
        original_code = "def target_func():\n    pass\n"
        source_file.write_text(original_code, encoding="utf-8")
        func = make_function("target_func", str(source_file))
        result, originals = add_async_decorator_to_function(
            source_file,
            func,
            TestingMode.BEHAVIOR,
        )
        assert result is False
        assert {} == originals
        # A rejected function must leave the file byte-for-byte intact.
        assert original_code == source_file.read_text(encoding="utf-8")

    def test_async_function_gets_decorator(self, tmp_path: Path) -> None:
        """Adds decorator, rewrites file, creates helper file."""
        source_file = tmp_path / "module.py"
        source_code = textwrap.dedent("""\
            async def target_func():
                pass
            """)
        source_file.write_text(source_code, encoding="utf-8")
        func = make_function(
            "target_func",
            str(source_file),
            is_async=True,
        )
        result, _ = add_async_decorator_to_function(
            source_file,
            func,
            TestingMode.BEHAVIOR,
        )
        assert result is True
        modified = source_file.read_text(encoding="utf-8")
        assert "codeflash_behavior_async" in modified
        helper = tmp_path / ASYNC_HELPER_FILENAME
        assert helper.exists()

    def test_originals_contains_pre_modification_source(
        self,
        tmp_path: Path,
    ) -> None:
        """The originals dict maps the file to its content before rewriting."""
        source_file = tmp_path / "module.py"
        original_code = "async def target_func():\n    pass\n"
        source_file.write_text(original_code, encoding="utf-8")
        func = make_function(
            "target_func",
            str(source_file),
            is_async=True,
        )
        _, originals = add_async_decorator_to_function(
            source_file,
            func,
            TestingMode.BEHAVIOR,
        )
        assert source_file in originals
        assert original_code == originals[source_file]

    def test_with_explicit_project_root(self, tmp_path: Path) -> None:
        """Writes helper file to project_root when specified."""
        src_dir = tmp_path / "src"
        src_dir.mkdir()
        source_file = src_dir / "module.py"
        source_code = textwrap.dedent("""\
            async def target_func():
                pass
            """)
        source_file.write_text(source_code, encoding="utf-8")
        project_root = tmp_path / "root"
        project_root.mkdir()
        func = make_function(
            "target_func",
            str(source_file),
            is_async=True,
        )
        result, _ = add_async_decorator_to_function(
            source_file,
            func,
            TestingMode.BEHAVIOR,
            project_root=project_root,
        )
        assert result is True
        assert (project_root / ASYNC_HELPER_FILENAME).exists()
        assert not (src_dir / ASYNC_HELPER_FILENAME).exists()

    def test_already_decorated_returns_false(self, tmp_path: Path) -> None:
        """Returns False when the function already has the decorator."""
        source_file = tmp_path / "module.py"
        source_code = textwrap.dedent("""\
            @codeflash_behavior_async
            async def target_func():
                pass
            """)
        source_file.write_text(source_code, encoding="utf-8")
        func = make_function(
            "target_func",
            str(source_file),
            is_async=True,
        )
        result, originals = add_async_decorator_to_function(
            source_file,
            func,
            TestingMode.BEHAVIOR,
        )
        assert result is False
        assert {} == originals

    def test_adds_import_for_decorator(self, tmp_path: Path) -> None:
        """The rewritten file includes the import for the decorator."""
        source_file = tmp_path / "module.py"
        source_file.write_text(
            "async def target_func():\n    pass\n",
            encoding="utf-8",
        )
        func = make_function(
            "target_func",
            str(source_file),
            is_async=True,
        )
        add_async_decorator_to_function(
            source_file,
            func,
            TestingMode.PERFORMANCE,
        )
        modified = source_file.read_text(encoding="utf-8")
        assert "from codeflash_async_wrapper import" in modified
        assert "codeflash_performance_async" in modified

    def test_cst_parse_error_returns_false(self, tmp_path: Path) -> None:
        """Returns (False, {}) when the source file has invalid syntax."""
        source_file = tmp_path / "module.py"
        source_file.write_text(
            "async def target_func(\n    invalid !!!",
            encoding="utf-8",
        )
        func = make_function(
            "target_func",
            str(source_file),
            is_async=True,
        )
        result, originals = add_async_decorator_to_function(
            source_file,
            func,
            TestingMode.BEHAVIOR,
        )
        assert result is False
        assert {} == originals


class TestCreateInstrumentedSourceModulePath:
    """create_instrumented_source_module_path path construction."""

    def test_basic_path(self, tmp_path: Path) -> None:
        """Constructs instrumented path from source path and temp dir."""
        result = create_instrumented_source_module_path(
            Path("test_foo.py"), tmp_path
        )
        assert tmp_path / "instrumented_test_foo.py" == result

    def test_preserves_extension(self, tmp_path: Path) -> None:
        """Preserves the .py extension in the instrumented filename."""
        result = create_instrumented_source_module_path(
            Path("my_module.py"), tmp_path
        )
        assert "instrumented_my_module.py" == result.name
        assert tmp_path == result.parent