Fix call-site IDs to use source line numbers instead of sequential counter

Restore the old InjectPerfOnly behavior where call-site identifiers
are the source line number of the instrumented statement. Also fix
the sync integration test to properly apply the decorator and write
the helper file, and remove dead imports from test_instrumentation.
This commit is contained in:
Kevin Turcios 2026-04-24 07:12:45 -05:00
parent 5b20981cd4
commit 2c9f2ad8de
7 changed files with 74 additions and 223 deletions

View file

@ -44,7 +44,6 @@ class AsyncCallInstrumenter(ast.NodeTransformer):
self.function_object = function self.function_object = function
self.call_positions = call_positions self.call_positions = call_positions
self.did_instrument = False self.did_instrument = False
self.async_call_counter: dict[str, int] = {}
def visit_ClassDef(self, node: ast.ClassDef) -> ast.ClassDef: def visit_ClassDef(self, node: ast.ClassDef) -> ast.ClassDef:
"""Recurse into class bodies to find test methods.""" """Recurse into class bodies to find test methods."""
@ -72,18 +71,12 @@ class AsyncCallInstrumenter(ast.NodeTransformer):
node: ast.AsyncFunctionDef | ast.FunctionDef, node: ast.AsyncFunctionDef | ast.FunctionDef,
) -> ast.AsyncFunctionDef | ast.FunctionDef: ) -> ast.AsyncFunctionDef | ast.FunctionDef:
"""Add _codeflash_call_site.set() calls before target await calls.""" """Add _codeflash_call_site.set() calls before target await calls."""
if node.name not in self.async_call_counter:
self.async_call_counter[node.name] = 0
new_body: list[ast.stmt] = [] new_body: list[ast.stmt] = []
for stmt in node.body: for stmt in node.body:
_, has_target = self._find_awaited_target_call(stmt) _, has_target = self._find_awaited_target_call(stmt)
if has_target: if has_target:
current_call_index = self.async_call_counter[node.name]
self.async_call_counter[node.name] += 1
call_site_set = ast.Expr( call_site_set = ast.Expr(
value=ast.Call( value=ast.Call(
func=ast.Attribute( func=ast.Attribute(
@ -96,7 +89,7 @@ class AsyncCallInstrumenter(ast.NodeTransformer):
), ),
args=[ args=[
ast.Constant( ast.Constant(
value=f"{current_call_index}", value=f"{stmt.lineno}",
), ),
], ],
keywords=[], keywords=[],

View file

@ -53,7 +53,6 @@ class SyncCallInstrumenter(ast.NodeTransformer):
self.function_object = function self.function_object = function
self.call_positions = call_positions self.call_positions = call_positions
self.did_instrument = False self.did_instrument = False
self.call_counter: dict[str, int] = {}
def visit_ClassDef(self, node: ast.ClassDef) -> ast.ClassDef: def visit_ClassDef(self, node: ast.ClassDef) -> ast.ClassDef:
"""Recurse into class bodies to find test methods.""" """Recurse into class bodies to find test methods."""
@ -78,18 +77,12 @@ class SyncCallInstrumenter(ast.NodeTransformer):
node: ast.FunctionDef | ast.AsyncFunctionDef, node: ast.FunctionDef | ast.AsyncFunctionDef,
) -> ast.FunctionDef | ast.AsyncFunctionDef: ) -> ast.FunctionDef | ast.AsyncFunctionDef:
"""Add _codeflash_call_site.set() calls before target function calls.""" """Add _codeflash_call_site.set() calls before target function calls."""
if node.name not in self.call_counter:
self.call_counter[node.name] = 0
new_body: list[ast.stmt] = [] new_body: list[ast.stmt] = []
for stmt in node.body: for stmt in node.body:
_, has_target = self._find_target_call(stmt) call_node, has_target = self._find_target_call(stmt)
if has_target: if has_target:
current_call_index = self.call_counter[node.name]
self.call_counter[node.name] += 1
call_site_set = ast.Expr( call_site_set = ast.Expr(
value=ast.Call( value=ast.Call(
func=ast.Attribute( func=ast.Attribute(
@ -102,7 +95,7 @@ class SyncCallInstrumenter(ast.NodeTransformer):
), ),
args=[ args=[
ast.Constant( ast.Constant(
value=f"{current_call_index}", value=f"{stmt.lineno}",
), ),
], ],
keywords=[], keywords=[],

View file

@ -1,41 +1,48 @@
import sys import sys
from codeflash_async_wrapper import codeflash_behavior_sync
from codeflash_python.runtime._codeflash_capture import codeflash_capture
class BubbleSorter: class BubbleSorter:
@codeflash_capture(function_name='BubbleSorter.__init__', tmp_dir_path='/var/folders/mg/k_c0twcj37q_gph3cfy3zlt80000gn/T/codeflash_l3k89hc3/codeflash_results', tests_root='/Users/krrt7/Desktop/work/cf_org/codeflash-agent/.claude/worktrees/jaunty-sauteeing-dolphin/packages/codeflash-python/tests/code_to_optimize/tests/pytest', is_fto=True)
def __init__(self, x=0): def __init__(self, x=0):
self.x = x self.x = x
@codeflash_behavior_sync
def sorter(self, arr): def sorter(self, arr):
print("codeflash stdout : BubbleSorter.sorter() called") print('codeflash stdout : BubbleSorter.sorter() called')
for i in range(len(arr)): for i in range(len(arr)):
for j in range(len(arr) - 1): for j in range(len(arr) - 1):
if arr[j] > arr[j + 1]: if arr[j] > arr[j + 1]:
temp = arr[j] temp = arr[j]
arr[j] = arr[j + 1] arr[j] = arr[j + 1]
arr[j + 1] = temp arr[j + 1] = temp
print("stderr test", file=sys.stderr) print('stderr test', file=sys.stderr)
return arr return arr
@classmethod @classmethod
def sorter_classmethod(cls, arr): def sorter_classmethod(cls, arr):
print("codeflash stdout : BubbleSorter.sorter_classmethod() called") print('codeflash stdout : BubbleSorter.sorter_classmethod() called')
for i in range(len(arr)): for i in range(len(arr)):
for j in range(len(arr) - 1): for j in range(len(arr) - 1):
if arr[j] > arr[j + 1]: if arr[j] > arr[j + 1]:
temp = arr[j] temp = arr[j]
arr[j] = arr[j + 1] arr[j] = arr[j + 1]
arr[j + 1] = temp arr[j + 1] = temp
print("stderr test classmethod", file=sys.stderr) print('stderr test classmethod', file=sys.stderr)
return arr return arr
@staticmethod @staticmethod
def sorter_staticmethod(arr): def sorter_staticmethod(arr):
print("codeflash stdout : BubbleSorter.sorter_staticmethod() called") print('codeflash stdout : BubbleSorter.sorter_staticmethod() called')
for i in range(len(arr)): for i in range(len(arr)):
for j in range(len(arr) - 1): for j in range(len(arr) - 1):
if arr[j] > arr[j + 1]: if arr[j] > arr[j + 1]:
temp = arr[j] temp = arr[j]
arr[j] = arr[j + 1] arr[j] = arr[j + 1]
arr[j + 1] = temp arr[j + 1] = temp
print("stderr test staticmethod", file=sys.stderr) print('stderr test staticmethod', file=sys.stderr)
return arr return arr

View file

@ -21,6 +21,9 @@ from codeflash_python.testing._instrument_async import (
from codeflash_python.testing._instrument_capture import ( from codeflash_python.testing._instrument_capture import (
instrument_codeflash_capture, instrument_codeflash_capture,
) )
from codeflash_python.testing._instrument_sync import (
add_sync_decorator_to_function,
)
from codeflash_python.testing._instrumentation import ( from codeflash_python.testing._instrumentation import (
inject_profiling_into_existing_test, inject_profiling_into_existing_test,
) )
@ -968,6 +971,14 @@ def test_sync_sort():
with test_path.open("w") as f: with test_path.open("w") as f:
f.write(instrumented_test) f.write(instrumented_test)
added, sync_originals = add_sync_decorator_to_function(
sync_fto_path,
func,
mode=TestingMode.BEHAVIOR,
project_root=project_root,
)
assert added
instrument_codeflash_capture(func, {}, tests_root) instrument_codeflash_capture(func, {}, tests_root)
test_env = os.environ.copy() test_env = os.environ.copy()
@ -1018,7 +1029,7 @@ def test_sync_sort():
results_list = test_results.test_results results_list = test_results.test_results
assert results_list[0].id.function_getting_tested == "sync_sorter" assert results_list[0].id.function_getting_tested == "sync_sorter"
assert results_list[0].id.iteration_id == "1_0" assert results_list[0].id.iteration_id == "6_0"
assert results_list[0].id.test_class_name is None assert results_list[0].id.test_class_name is None
assert results_list[0].id.test_function_name == "test_sync_sort" assert results_list[0].id.test_function_name == "test_sync_sort"
assert results_list[0].did_pass assert results_list[0].did_pass
@ -1031,7 +1042,7 @@ def test_sync_sort():
if len(results_list) > 1: if len(results_list) > 1:
assert results_list[1].id.function_getting_tested == "sync_sorter" assert results_list[1].id.function_getting_tested == "sync_sorter"
assert results_list[1].id.iteration_id == "4_0" assert results_list[1].id.iteration_id == "10_0"
assert results_list[1].id.test_function_name == "test_sync_sort" assert results_list[1].id.test_function_name == "test_sync_sort"
assert results_list[1].did_pass assert results_list[1].did_pass
@ -1045,6 +1056,9 @@ def test_sync_sort():
test_path.unlink() test_path.unlink()
if test_path_perf.exists(): if test_path_perf.exists():
test_path_perf.unlink() test_path_perf.unlink()
helper_path = project_root / ASYNC_HELPER_FILENAME
if helper_path.exists():
helper_path.unlink()
@pytest.mark.skipif( @pytest.mark.skipif(

View file

@ -750,29 +750,22 @@ async def test_multiple_calls():
assert success assert success
assert instrumented_test_code is not None assert instrumented_test_code is not None
assert ( assert "_codeflash_call_site.set('8')" in instrumented_test_code
"_codeflash_call_site.set('0')" assert "_codeflash_call_site.set('13')" in instrumented_test_code
in instrumented_test_code assert "_codeflash_call_site.set('14')" in instrumented_test_code
) assert "_codeflash_call_site.set('15')" in instrumented_test_code
line_id_0_count = instrumented_test_code.count( assert 1 == instrumented_test_code.count(
"_codeflash_call_site.set('0')" "_codeflash_call_site.set('8')"
) )
line_id_1_count = instrumented_test_code.count( assert 1 == instrumented_test_code.count(
"_codeflash_call_site.set('1')" "_codeflash_call_site.set('13')"
) )
line_id_2_count = instrumented_test_code.count( assert 1 == instrumented_test_code.count(
"_codeflash_call_site.set('2')" "_codeflash_call_site.set('14')"
) )
assert 1 == instrumented_test_code.count(
assert 2 == line_id_0_count, ( "_codeflash_call_site.set('15')"
f"Expected 2 occurrences of line_id '0', got {line_id_0_count}"
)
assert 1 == line_id_1_count, (
f"Expected 1 occurrence of line_id '1', got {line_id_1_count}"
)
assert 1 == line_id_2_count, (
f"Expected 1 occurrence of line_id '2', got {line_id_2_count}"
) )

View file

@ -148,7 +148,7 @@ class TestSyncCallInstrumenter:
assert instrumenter.did_instrument assert instrumenter.did_instrument
source = ast.unparse(tree) source = ast.unparse(tree)
assert "_codeflash_call_site.set('0')" in source assert "_codeflash_call_site.set('2')" in source
def test_instruments_method_call(self) -> None: def test_instruments_method_call(self) -> None:
"""Injects call-site set before obj.method() style calls.""" """Injects call-site set before obj.method() style calls."""
@ -171,10 +171,10 @@ class TestSyncCallInstrumenter:
assert instrumenter.did_instrument assert instrumenter.did_instrument
source = ast.unparse(tree) source = ast.unparse(tree)
assert "_codeflash_call_site.set('0')" in source assert "_codeflash_call_site.set('2')" in source
def test_multiple_calls_get_incremented_indices(self) -> None: def test_multiple_calls_get_line_number_indices(self) -> None:
"""Multiple calls in the same test get sequential indices.""" """Multiple calls use their source line numbers as call-site IDs."""
test_code = ( test_code = (
"def test_example():\n" "def test_example():\n"
" a = my_func(1)\n" " a = my_func(1)\n"
@ -200,9 +200,9 @@ class TestSyncCallInstrumenter:
assert instrumenter.did_instrument assert instrumenter.did_instrument
source = ast.unparse(tree) source = ast.unparse(tree)
assert "_codeflash_call_site.set('0')" in source
assert "_codeflash_call_site.set('1')" in source
assert "_codeflash_call_site.set('2')" in source assert "_codeflash_call_site.set('2')" in source
assert "_codeflash_call_site.set('3')" in source
assert "_codeflash_call_site.set('4')" in source
def test_skips_non_test_functions(self) -> None: def test_skips_non_test_functions(self) -> None:
"""Does not instrument functions that don't start with test_.""" """Does not instrument functions that don't start with test_."""
@ -245,7 +245,7 @@ class TestSyncCallInstrumenter:
assert instrumenter.did_instrument assert instrumenter.did_instrument
source = ast.unparse(tree) source = ast.unparse(tree)
assert "_codeflash_call_site.set('0')" in source assert "_codeflash_call_site.set('3')" in source
def test_no_match_when_position_wrong(self) -> None: def test_no_match_when_position_wrong(self) -> None:
"""Does not instrument if code position doesn't match.""" """Does not instrument if code position doesn't match."""
@ -366,7 +366,7 @@ class TestAddSyncDecoratorToFunction:
assert not (subdir / ASYNC_HELPER_FILENAME).exists() assert not (subdir / ASYNC_HELPER_FILENAME).exists()
def test_preserves_existing_decorators(self, temp_dir) -> None: def test_preserves_existing_decorators(self, temp_dir) -> None:
"""Adds codeflash decorator above existing decorators.""" """Adds codeflash decorator below @staticmethod/@classmethod."""
source_code = ( source_code = (
"@staticmethod\n" "@staticmethod\n"
"def my_func(x: int) -> int:\n" "def my_func(x: int) -> int:\n"
@ -390,7 +390,7 @@ class TestAddSyncDecoratorToFunction:
modified = source_file.read_text() modified = source_file.read_text()
cf_pos = modified.find("@codeflash_behavior_sync") cf_pos = modified.find("@codeflash_behavior_sync")
sm_pos = modified.find("@staticmethod") sm_pos = modified.find("@staticmethod")
assert cf_pos < sm_pos assert sm_pos < cf_pos
def test_no_duplicate_decorator(self, temp_dir) -> None: def test_no_duplicate_decorator(self, temp_dir) -> None:
"""Does not add decorator if already present.""" """Does not add decorator if already present."""
@ -449,11 +449,11 @@ class TestInjectSyncProfilingIntoExistingTest:
assert success assert success
assert instrumented is not None assert instrumented is not None
assert "_codeflash_call_site.set('0')" in instrumented assert "_codeflash_call_site.set('4')" in instrumented
assert "from codeflash_async_wrapper import _codeflash_call_site" in instrumented assert "from codeflash_async_wrapper import _codeflash_call_site" in instrumented
def test_multiple_calls_numbered_sequentially(self, temp_dir) -> None: def test_multiple_calls_use_line_numbers(self, temp_dir) -> None:
"""Multiple calls get sequential call-site indices.""" """Multiple calls use their source line numbers as call-site IDs."""
source_code = "def my_func(x: int) -> int:\n return x + 1\n" source_code = "def my_func(x: int) -> int:\n return x + 1\n"
source_file = temp_dir / "my_module.py" source_file = temp_dir / "my_module.py"
source_file.write_text(source_code) source_file.write_text(source_code)
@ -487,9 +487,9 @@ class TestInjectSyncProfilingIntoExistingTest:
assert success assert success
assert instrumented is not None assert instrumented is not None
assert "_codeflash_call_site.set('0')" in instrumented assert "_codeflash_call_site.set('4')" in instrumented
assert "_codeflash_call_site.set('1')" in instrumented assert "_codeflash_call_site.set('5')" in instrumented
assert "_codeflash_call_site.set('2')" in instrumented assert "_codeflash_call_site.set('6')" in instrumented
def test_returns_false_for_syntax_error(self, temp_dir) -> None: def test_returns_false_for_syntax_error(self, temp_dir) -> None:
"""Returns (False, None) when the test file has a syntax error.""" """Returns (False, None) when the test file has a syntax error."""

View file

@ -29,13 +29,8 @@ from codeflash_python.testing._instrument_capture import (
create_instrumented_source_module_path, create_instrumented_source_module_path,
) )
from codeflash_python.testing._instrument_core import ( from codeflash_python.testing._instrument_core import (
FunctionCallNodeArguments,
FunctionImportedAsVisitor, FunctionImportedAsVisitor,
create_device_sync_precompute_statements,
create_device_sync_statements,
detect_frameworks_from_code, detect_frameworks_from_code,
get_call_arguments,
is_argument_name,
node_in_call_position, node_in_call_position,
) )
from codeflash_python.testing._instrumentation import ( from codeflash_python.testing._instrumentation import (
@ -94,35 +89,6 @@ class TestVerificationType:
assert 3 == len(VerificationType) assert 3 == len(VerificationType)
class TestGetCallArguments:
"""get_call_arguments Call node extraction."""
def test_simple_call(self) -> None:
"""Extracts positional args and keywords from a Call node."""
tree = ast.parse("func(1, 2, key='val')")
call_node = tree.body[0].value # type: ignore[attr-defined]
result = get_call_arguments(call_node)
assert isinstance(result, FunctionCallNodeArguments)
assert 2 == len(result.args)
assert 1 == len(result.keywords)
def test_no_args(self) -> None:
"""Returns empty lists for a call with no arguments."""
tree = ast.parse("func()")
call_node = tree.body[0].value # type: ignore[attr-defined]
result = get_call_arguments(call_node)
assert [] == result.args
assert [] == result.keywords
def test_only_keywords(self) -> None:
"""Returns keywords when only keyword args are present."""
tree = ast.parse("func(a=1, b=2)")
call_node = tree.body[0].value # type: ignore[attr-defined]
result = get_call_arguments(call_node)
assert [] == result.args
assert 2 == len(result.keywords)
class TestNodeInCallPosition: class TestNodeInCallPosition:
"""node_in_call_position position matching.""" """node_in_call_position position matching."""
@ -161,45 +127,6 @@ class TestNodeInCallPosition:
assert node_in_call_position(call_node, positions) is True assert node_in_call_position(call_node, positions) is True
class TestIsArgumentName:
"""is_argument_name argument detection."""
def test_regular_arg(self) -> None:
"""Returns True for a regular positional argument name."""
code = "def f(x, y): pass"
tree = ast.parse(code)
func_def = tree.body[0]
assert is_argument_name("x", func_def.args) is True # type: ignore[attr-defined]
def test_kwonly_arg(self) -> None:
"""Returns True for a keyword-only argument name."""
code = "def f(*, key): pass"
tree = ast.parse(code)
func_def = tree.body[0]
assert is_argument_name("key", func_def.args) is True # type: ignore[attr-defined]
def test_no_match(self) -> None:
"""Returns False when name is not an argument."""
code = "def f(x, y): pass"
tree = ast.parse(code)
func_def = tree.body[0]
assert is_argument_name("z", func_def.args) is False # type: ignore[attr-defined]
def test_vararg_not_matched(self) -> None:
"""Returns False for *args (vararg is not a list attribute)."""
code = "def f(*args): pass"
tree = ast.parse(code)
func_def = tree.body[0]
assert is_argument_name("args", func_def.args) is False # type: ignore[attr-defined]
def test_kwarg_not_matched(self) -> None:
"""Returns False for **kwargs (kwarg is not a list attribute)."""
code = "def f(**kwargs): pass"
tree = ast.parse(code)
func_def = tree.body[0]
assert is_argument_name("kwargs", func_def.args) is False # type: ignore[attr-defined]
class TestDetectFrameworksFromCode: class TestDetectFrameworksFromCode:
"""detect_frameworks_from_code import detection.""" """detect_frameworks_from_code import detection."""
@ -276,84 +203,6 @@ class TestGetDecoratorNameForMode:
) )
class TestCreateDeviceSyncPrecomputeStatements:
"""create_device_sync_precompute_statements AST generation."""
def test_none_frameworks(self) -> None:
"""Returns empty list when frameworks is None."""
result = create_device_sync_precompute_statements(None)
assert [] == result
def test_empty_dict(self) -> None:
"""Returns empty list when frameworks dict is empty."""
result = create_device_sync_precompute_statements({})
assert [] == result
def test_torch_produces_statements(self) -> None:
"""Produces AST statements for torch."""
result = create_device_sync_precompute_statements(
{"torch": "torch"},
)
assert len(result) > 0
assert all(isinstance(s, ast.stmt) for s in result)
def test_jax_produces_statements(self) -> None:
"""Produces AST statements for jax."""
result = create_device_sync_precompute_statements({"jax": "jax"})
assert len(result) > 0
assert all(isinstance(s, ast.stmt) for s in result)
def test_tensorflow_produces_statements(self) -> None:
"""Produces AST statements for tensorflow."""
result = create_device_sync_precompute_statements(
{"tensorflow": "tf"},
)
assert len(result) > 0
assert all(isinstance(s, ast.stmt) for s in result)
def test_combined_frameworks(self) -> None:
"""Produces statements for multiple frameworks."""
result = create_device_sync_precompute_statements(
{"torch": "torch", "jax": "jax"},
)
assert len(result) > 0
class TestCreateDeviceSyncStatements:
"""create_device_sync_statements AST generation."""
def test_none_frameworks(self) -> None:
"""Returns empty list when frameworks is None."""
result = create_device_sync_statements(None)
assert [] == result
def test_empty_dict(self) -> None:
"""Returns empty list when frameworks dict is empty."""
result = create_device_sync_statements({})
assert [] == result
def test_torch_sync(self) -> None:
"""Produces sync statements for torch."""
result = create_device_sync_statements({"torch": "torch"})
assert len(result) > 0
assert all(isinstance(s, ast.stmt) for s in result)
def test_for_return_value_flag(self) -> None:
"""Produces statements with for_return_value=True."""
result = create_device_sync_statements(
{"jax": "jax"},
for_return_value=True,
)
assert len(result) > 0
assert all(isinstance(s, ast.stmt) for s in result)
def test_tensorflow_sync(self) -> None:
"""Produces sync statements for tensorflow."""
result = create_device_sync_statements({"tensorflow": "tf"})
assert len(result) > 0
assert all(isinstance(s, ast.stmt) for s in result)
class TestAsyncCallInstrumenter: class TestAsyncCallInstrumenter:
"""AsyncCallInstrumenter AST transformer.""" """AsyncCallInstrumenter AST transformer."""
@ -439,11 +288,11 @@ class TestAsyncCallInstrumenter:
transformer, tree = self._make_transformer(code) transformer, tree = self._make_transformer(code)
new_tree = transformer.visit(tree) new_tree = transformer.visit(tree)
source = ast.unparse(new_tree) source = ast.unparse(new_tree)
assert "_codeflash_call_site.set(" in source assert "_codeflash_call_site.set('2')" in source
assert transformer.did_instrument is True assert transformer.did_instrument is True
def test_multiple_awaits_get_incrementing_ids(self) -> None: def test_multiple_awaits_get_line_number_ids(self) -> None:
"""Each awaited target call gets a unique incrementing counter.""" """Each awaited target call uses its source line number as call-site ID."""
code = textwrap.dedent("""\ code = textwrap.dedent("""\
async def test_it(): async def test_it():
a = await target_func(1) a = await target_func(1)
@ -453,9 +302,9 @@ class TestAsyncCallInstrumenter:
transformer, tree = self._make_transformer(code) transformer, tree = self._make_transformer(code)
new_tree = transformer.visit(tree) new_tree = transformer.visit(tree)
source = ast.unparse(new_tree) source = ast.unparse(new_tree)
assert "'0'" in source assert "_codeflash_call_site.set('2')" in source
assert "'1'" in source assert "_codeflash_call_site.set('3')" in source
assert "'2'" in source assert "_codeflash_call_site.set('4')" in source
def test_attribute_style_call(self) -> None: def test_attribute_style_call(self) -> None:
"""Instruments await obj.target_func() attribute-style calls.""" """Instruments await obj.target_func() attribute-style calls."""
@ -466,7 +315,7 @@ class TestAsyncCallInstrumenter:
transformer, tree = self._make_transformer(code) transformer, tree = self._make_transformer(code)
new_tree = transformer.visit(tree) new_tree = transformer.visit(tree)
source = ast.unparse(new_tree) source = ast.unparse(new_tree)
assert "_codeflash_call_site.set(" in source assert "_codeflash_call_site.set('2')" in source
assert transformer.did_instrument is True assert transformer.did_instrument is True
def test_recurses_into_class_body(self) -> None: def test_recurses_into_class_body(self) -> None:
@ -479,7 +328,7 @@ class TestAsyncCallInstrumenter:
transformer, tree = self._make_transformer(code) transformer, tree = self._make_transformer(code)
new_tree = transformer.visit(tree) new_tree = transformer.visit(tree)
source = ast.unparse(new_tree) source = ast.unparse(new_tree)
assert "_codeflash_call_site.set(" in source assert "_codeflash_call_site.set('3')" in source
assert transformer.did_instrument is True assert transformer.did_instrument is True
def test_no_match_when_position_wrong(self) -> None: def test_no_match_when_position_wrong(self) -> None:
@ -527,8 +376,8 @@ class TestAsyncCallInstrumenter:
transformer.visit(tree) transformer.visit(tree)
assert transformer.did_instrument is False assert transformer.did_instrument is False
def test_counters_independent_per_test_function(self) -> None: def test_line_numbers_independent_per_test_function(self) -> None:
"""Each test function gets its own independent counter.""" """Each call uses its own line number regardless of test function."""
code = textwrap.dedent("""\ code = textwrap.dedent("""\
async def test_a(): async def test_a():
await target_func(1) await target_func(1)
@ -537,9 +386,11 @@ class TestAsyncCallInstrumenter:
await target_func(3) await target_func(3)
""") """)
transformer, tree = self._make_transformer(code) transformer, tree = self._make_transformer(code)
transformer.visit(tree) new_tree = transformer.visit(tree)
assert 2 == transformer.async_call_counter["test_a"] source = ast.unparse(new_tree)
assert 1 == transformer.async_call_counter["test_b"] assert "_codeflash_call_site.set('2')" in source
assert "_codeflash_call_site.set('3')" in source
assert "_codeflash_call_site.set('5')" in source
class TestFunctionImportedAsVisitor: class TestFunctionImportedAsVisitor: