mirror of
https://github.com/codeflash-ai/codeflash-agent.git
synced 2026-05-04 18:25:19 +00:00
Fix call-site IDs to use source line numbers instead of sequential counter
Restore the old InjectPerfOnly behavior where call-site identifiers are the source line number of the instrumented statement. Also fix the sync integration test to properly apply the decorator and write the helper file, and remove dead imports from test_instrumentation.
This commit is contained in:
parent
5b20981cd4
commit
2c9f2ad8de
7 changed files with 74 additions and 223 deletions
|
|
@ -44,7 +44,6 @@ class AsyncCallInstrumenter(ast.NodeTransformer):
|
||||||
self.function_object = function
|
self.function_object = function
|
||||||
self.call_positions = call_positions
|
self.call_positions = call_positions
|
||||||
self.did_instrument = False
|
self.did_instrument = False
|
||||||
self.async_call_counter: dict[str, int] = {}
|
|
||||||
|
|
||||||
def visit_ClassDef(self, node: ast.ClassDef) -> ast.ClassDef:
|
def visit_ClassDef(self, node: ast.ClassDef) -> ast.ClassDef:
|
||||||
"""Recurse into class bodies to find test methods."""
|
"""Recurse into class bodies to find test methods."""
|
||||||
|
|
@ -72,18 +71,12 @@ class AsyncCallInstrumenter(ast.NodeTransformer):
|
||||||
node: ast.AsyncFunctionDef | ast.FunctionDef,
|
node: ast.AsyncFunctionDef | ast.FunctionDef,
|
||||||
) -> ast.AsyncFunctionDef | ast.FunctionDef:
|
) -> ast.AsyncFunctionDef | ast.FunctionDef:
|
||||||
"""Add _codeflash_call_site.set() calls before target await calls."""
|
"""Add _codeflash_call_site.set() calls before target await calls."""
|
||||||
if node.name not in self.async_call_counter:
|
|
||||||
self.async_call_counter[node.name] = 0
|
|
||||||
|
|
||||||
new_body: list[ast.stmt] = []
|
new_body: list[ast.stmt] = []
|
||||||
|
|
||||||
for stmt in node.body:
|
for stmt in node.body:
|
||||||
_, has_target = self._find_awaited_target_call(stmt)
|
_, has_target = self._find_awaited_target_call(stmt)
|
||||||
|
|
||||||
if has_target:
|
if has_target:
|
||||||
current_call_index = self.async_call_counter[node.name]
|
|
||||||
self.async_call_counter[node.name] += 1
|
|
||||||
|
|
||||||
call_site_set = ast.Expr(
|
call_site_set = ast.Expr(
|
||||||
value=ast.Call(
|
value=ast.Call(
|
||||||
func=ast.Attribute(
|
func=ast.Attribute(
|
||||||
|
|
@ -96,7 +89,7 @@ class AsyncCallInstrumenter(ast.NodeTransformer):
|
||||||
),
|
),
|
||||||
args=[
|
args=[
|
||||||
ast.Constant(
|
ast.Constant(
|
||||||
value=f"{current_call_index}",
|
value=f"{stmt.lineno}",
|
||||||
),
|
),
|
||||||
],
|
],
|
||||||
keywords=[],
|
keywords=[],
|
||||||
|
|
|
||||||
|
|
@ -53,7 +53,6 @@ class SyncCallInstrumenter(ast.NodeTransformer):
|
||||||
self.function_object = function
|
self.function_object = function
|
||||||
self.call_positions = call_positions
|
self.call_positions = call_positions
|
||||||
self.did_instrument = False
|
self.did_instrument = False
|
||||||
self.call_counter: dict[str, int] = {}
|
|
||||||
|
|
||||||
def visit_ClassDef(self, node: ast.ClassDef) -> ast.ClassDef:
|
def visit_ClassDef(self, node: ast.ClassDef) -> ast.ClassDef:
|
||||||
"""Recurse into class bodies to find test methods."""
|
"""Recurse into class bodies to find test methods."""
|
||||||
|
|
@ -78,18 +77,12 @@ class SyncCallInstrumenter(ast.NodeTransformer):
|
||||||
node: ast.FunctionDef | ast.AsyncFunctionDef,
|
node: ast.FunctionDef | ast.AsyncFunctionDef,
|
||||||
) -> ast.FunctionDef | ast.AsyncFunctionDef:
|
) -> ast.FunctionDef | ast.AsyncFunctionDef:
|
||||||
"""Add _codeflash_call_site.set() calls before target function calls."""
|
"""Add _codeflash_call_site.set() calls before target function calls."""
|
||||||
if node.name not in self.call_counter:
|
|
||||||
self.call_counter[node.name] = 0
|
|
||||||
|
|
||||||
new_body: list[ast.stmt] = []
|
new_body: list[ast.stmt] = []
|
||||||
|
|
||||||
for stmt in node.body:
|
for stmt in node.body:
|
||||||
_, has_target = self._find_target_call(stmt)
|
call_node, has_target = self._find_target_call(stmt)
|
||||||
|
|
||||||
if has_target:
|
if has_target:
|
||||||
current_call_index = self.call_counter[node.name]
|
|
||||||
self.call_counter[node.name] += 1
|
|
||||||
|
|
||||||
call_site_set = ast.Expr(
|
call_site_set = ast.Expr(
|
||||||
value=ast.Call(
|
value=ast.Call(
|
||||||
func=ast.Attribute(
|
func=ast.Attribute(
|
||||||
|
|
@ -102,7 +95,7 @@ class SyncCallInstrumenter(ast.NodeTransformer):
|
||||||
),
|
),
|
||||||
args=[
|
args=[
|
||||||
ast.Constant(
|
ast.Constant(
|
||||||
value=f"{current_call_index}",
|
value=f"{stmt.lineno}",
|
||||||
),
|
),
|
||||||
],
|
],
|
||||||
keywords=[],
|
keywords=[],
|
||||||
|
|
|
||||||
|
|
@ -1,41 +1,48 @@
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
|
from codeflash_async_wrapper import codeflash_behavior_sync
|
||||||
|
|
||||||
|
from codeflash_python.runtime._codeflash_capture import codeflash_capture
|
||||||
|
|
||||||
|
|
||||||
class BubbleSorter:
|
class BubbleSorter:
|
||||||
|
|
||||||
|
@codeflash_capture(function_name='BubbleSorter.__init__', tmp_dir_path='/var/folders/mg/k_c0twcj37q_gph3cfy3zlt80000gn/T/codeflash_l3k89hc3/codeflash_results', tests_root='/Users/krrt7/Desktop/work/cf_org/codeflash-agent/.claude/worktrees/jaunty-sauteeing-dolphin/packages/codeflash-python/tests/code_to_optimize/tests/pytest', is_fto=True)
|
||||||
def __init__(self, x=0):
|
def __init__(self, x=0):
|
||||||
self.x = x
|
self.x = x
|
||||||
|
|
||||||
|
@codeflash_behavior_sync
|
||||||
def sorter(self, arr):
|
def sorter(self, arr):
|
||||||
print("codeflash stdout : BubbleSorter.sorter() called")
|
print('codeflash stdout : BubbleSorter.sorter() called')
|
||||||
for i in range(len(arr)):
|
for i in range(len(arr)):
|
||||||
for j in range(len(arr) - 1):
|
for j in range(len(arr) - 1):
|
||||||
if arr[j] > arr[j + 1]:
|
if arr[j] > arr[j + 1]:
|
||||||
temp = arr[j]
|
temp = arr[j]
|
||||||
arr[j] = arr[j + 1]
|
arr[j] = arr[j + 1]
|
||||||
arr[j + 1] = temp
|
arr[j + 1] = temp
|
||||||
print("stderr test", file=sys.stderr)
|
print('stderr test', file=sys.stderr)
|
||||||
return arr
|
return arr
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def sorter_classmethod(cls, arr):
|
def sorter_classmethod(cls, arr):
|
||||||
print("codeflash stdout : BubbleSorter.sorter_classmethod() called")
|
print('codeflash stdout : BubbleSorter.sorter_classmethod() called')
|
||||||
for i in range(len(arr)):
|
for i in range(len(arr)):
|
||||||
for j in range(len(arr) - 1):
|
for j in range(len(arr) - 1):
|
||||||
if arr[j] > arr[j + 1]:
|
if arr[j] > arr[j + 1]:
|
||||||
temp = arr[j]
|
temp = arr[j]
|
||||||
arr[j] = arr[j + 1]
|
arr[j] = arr[j + 1]
|
||||||
arr[j + 1] = temp
|
arr[j + 1] = temp
|
||||||
print("stderr test classmethod", file=sys.stderr)
|
print('stderr test classmethod', file=sys.stderr)
|
||||||
return arr
|
return arr
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def sorter_staticmethod(arr):
|
def sorter_staticmethod(arr):
|
||||||
print("codeflash stdout : BubbleSorter.sorter_staticmethod() called")
|
print('codeflash stdout : BubbleSorter.sorter_staticmethod() called')
|
||||||
for i in range(len(arr)):
|
for i in range(len(arr)):
|
||||||
for j in range(len(arr) - 1):
|
for j in range(len(arr) - 1):
|
||||||
if arr[j] > arr[j + 1]:
|
if arr[j] > arr[j + 1]:
|
||||||
temp = arr[j]
|
temp = arr[j]
|
||||||
arr[j] = arr[j + 1]
|
arr[j] = arr[j + 1]
|
||||||
arr[j + 1] = temp
|
arr[j + 1] = temp
|
||||||
print("stderr test staticmethod", file=sys.stderr)
|
print('stderr test staticmethod', file=sys.stderr)
|
||||||
return arr
|
return arr
|
||||||
|
|
|
||||||
|
|
@ -21,6 +21,9 @@ from codeflash_python.testing._instrument_async import (
|
||||||
from codeflash_python.testing._instrument_capture import (
|
from codeflash_python.testing._instrument_capture import (
|
||||||
instrument_codeflash_capture,
|
instrument_codeflash_capture,
|
||||||
)
|
)
|
||||||
|
from codeflash_python.testing._instrument_sync import (
|
||||||
|
add_sync_decorator_to_function,
|
||||||
|
)
|
||||||
from codeflash_python.testing._instrumentation import (
|
from codeflash_python.testing._instrumentation import (
|
||||||
inject_profiling_into_existing_test,
|
inject_profiling_into_existing_test,
|
||||||
)
|
)
|
||||||
|
|
@ -968,6 +971,14 @@ def test_sync_sort():
|
||||||
with test_path.open("w") as f:
|
with test_path.open("w") as f:
|
||||||
f.write(instrumented_test)
|
f.write(instrumented_test)
|
||||||
|
|
||||||
|
added, sync_originals = add_sync_decorator_to_function(
|
||||||
|
sync_fto_path,
|
||||||
|
func,
|
||||||
|
mode=TestingMode.BEHAVIOR,
|
||||||
|
project_root=project_root,
|
||||||
|
)
|
||||||
|
assert added
|
||||||
|
|
||||||
instrument_codeflash_capture(func, {}, tests_root)
|
instrument_codeflash_capture(func, {}, tests_root)
|
||||||
|
|
||||||
test_env = os.environ.copy()
|
test_env = os.environ.copy()
|
||||||
|
|
@ -1018,7 +1029,7 @@ def test_sync_sort():
|
||||||
|
|
||||||
results_list = test_results.test_results
|
results_list = test_results.test_results
|
||||||
assert results_list[0].id.function_getting_tested == "sync_sorter"
|
assert results_list[0].id.function_getting_tested == "sync_sorter"
|
||||||
assert results_list[0].id.iteration_id == "1_0"
|
assert results_list[0].id.iteration_id == "6_0"
|
||||||
assert results_list[0].id.test_class_name is None
|
assert results_list[0].id.test_class_name is None
|
||||||
assert results_list[0].id.test_function_name == "test_sync_sort"
|
assert results_list[0].id.test_function_name == "test_sync_sort"
|
||||||
assert results_list[0].did_pass
|
assert results_list[0].did_pass
|
||||||
|
|
@ -1031,7 +1042,7 @@ def test_sync_sort():
|
||||||
|
|
||||||
if len(results_list) > 1:
|
if len(results_list) > 1:
|
||||||
assert results_list[1].id.function_getting_tested == "sync_sorter"
|
assert results_list[1].id.function_getting_tested == "sync_sorter"
|
||||||
assert results_list[1].id.iteration_id == "4_0"
|
assert results_list[1].id.iteration_id == "10_0"
|
||||||
assert results_list[1].id.test_function_name == "test_sync_sort"
|
assert results_list[1].id.test_function_name == "test_sync_sort"
|
||||||
assert results_list[1].did_pass
|
assert results_list[1].did_pass
|
||||||
|
|
||||||
|
|
@ -1045,6 +1056,9 @@ def test_sync_sort():
|
||||||
test_path.unlink()
|
test_path.unlink()
|
||||||
if test_path_perf.exists():
|
if test_path_perf.exists():
|
||||||
test_path_perf.unlink()
|
test_path_perf.unlink()
|
||||||
|
helper_path = project_root / ASYNC_HELPER_FILENAME
|
||||||
|
if helper_path.exists():
|
||||||
|
helper_path.unlink()
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(
|
@pytest.mark.skipif(
|
||||||
|
|
|
||||||
|
|
@ -750,29 +750,22 @@ async def test_multiple_calls():
|
||||||
assert success
|
assert success
|
||||||
assert instrumented_test_code is not None
|
assert instrumented_test_code is not None
|
||||||
|
|
||||||
assert (
|
assert "_codeflash_call_site.set('8')" in instrumented_test_code
|
||||||
"_codeflash_call_site.set('0')"
|
assert "_codeflash_call_site.set('13')" in instrumented_test_code
|
||||||
in instrumented_test_code
|
assert "_codeflash_call_site.set('14')" in instrumented_test_code
|
||||||
)
|
assert "_codeflash_call_site.set('15')" in instrumented_test_code
|
||||||
|
|
||||||
line_id_0_count = instrumented_test_code.count(
|
assert 1 == instrumented_test_code.count(
|
||||||
"_codeflash_call_site.set('0')"
|
"_codeflash_call_site.set('8')"
|
||||||
)
|
)
|
||||||
line_id_1_count = instrumented_test_code.count(
|
assert 1 == instrumented_test_code.count(
|
||||||
"_codeflash_call_site.set('1')"
|
"_codeflash_call_site.set('13')"
|
||||||
)
|
)
|
||||||
line_id_2_count = instrumented_test_code.count(
|
assert 1 == instrumented_test_code.count(
|
||||||
"_codeflash_call_site.set('2')"
|
"_codeflash_call_site.set('14')"
|
||||||
)
|
)
|
||||||
|
assert 1 == instrumented_test_code.count(
|
||||||
assert 2 == line_id_0_count, (
|
"_codeflash_call_site.set('15')"
|
||||||
f"Expected 2 occurrences of line_id '0', got {line_id_0_count}"
|
|
||||||
)
|
|
||||||
assert 1 == line_id_1_count, (
|
|
||||||
f"Expected 1 occurrence of line_id '1', got {line_id_1_count}"
|
|
||||||
)
|
|
||||||
assert 1 == line_id_2_count, (
|
|
||||||
f"Expected 1 occurrence of line_id '2', got {line_id_2_count}"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -148,7 +148,7 @@ class TestSyncCallInstrumenter:
|
||||||
|
|
||||||
assert instrumenter.did_instrument
|
assert instrumenter.did_instrument
|
||||||
source = ast.unparse(tree)
|
source = ast.unparse(tree)
|
||||||
assert "_codeflash_call_site.set('0')" in source
|
assert "_codeflash_call_site.set('2')" in source
|
||||||
|
|
||||||
def test_instruments_method_call(self) -> None:
|
def test_instruments_method_call(self) -> None:
|
||||||
"""Injects call-site set before obj.method() style calls."""
|
"""Injects call-site set before obj.method() style calls."""
|
||||||
|
|
@ -171,10 +171,10 @@ class TestSyncCallInstrumenter:
|
||||||
|
|
||||||
assert instrumenter.did_instrument
|
assert instrumenter.did_instrument
|
||||||
source = ast.unparse(tree)
|
source = ast.unparse(tree)
|
||||||
assert "_codeflash_call_site.set('0')" in source
|
assert "_codeflash_call_site.set('2')" in source
|
||||||
|
|
||||||
def test_multiple_calls_get_incremented_indices(self) -> None:
|
def test_multiple_calls_get_line_number_indices(self) -> None:
|
||||||
"""Multiple calls in the same test get sequential indices."""
|
"""Multiple calls use their source line numbers as call-site IDs."""
|
||||||
test_code = (
|
test_code = (
|
||||||
"def test_example():\n"
|
"def test_example():\n"
|
||||||
" a = my_func(1)\n"
|
" a = my_func(1)\n"
|
||||||
|
|
@ -200,9 +200,9 @@ class TestSyncCallInstrumenter:
|
||||||
|
|
||||||
assert instrumenter.did_instrument
|
assert instrumenter.did_instrument
|
||||||
source = ast.unparse(tree)
|
source = ast.unparse(tree)
|
||||||
assert "_codeflash_call_site.set('0')" in source
|
|
||||||
assert "_codeflash_call_site.set('1')" in source
|
|
||||||
assert "_codeflash_call_site.set('2')" in source
|
assert "_codeflash_call_site.set('2')" in source
|
||||||
|
assert "_codeflash_call_site.set('3')" in source
|
||||||
|
assert "_codeflash_call_site.set('4')" in source
|
||||||
|
|
||||||
def test_skips_non_test_functions(self) -> None:
|
def test_skips_non_test_functions(self) -> None:
|
||||||
"""Does not instrument functions that don't start with test_."""
|
"""Does not instrument functions that don't start with test_."""
|
||||||
|
|
@ -245,7 +245,7 @@ class TestSyncCallInstrumenter:
|
||||||
|
|
||||||
assert instrumenter.did_instrument
|
assert instrumenter.did_instrument
|
||||||
source = ast.unparse(tree)
|
source = ast.unparse(tree)
|
||||||
assert "_codeflash_call_site.set('0')" in source
|
assert "_codeflash_call_site.set('3')" in source
|
||||||
|
|
||||||
def test_no_match_when_position_wrong(self) -> None:
|
def test_no_match_when_position_wrong(self) -> None:
|
||||||
"""Does not instrument if code position doesn't match."""
|
"""Does not instrument if code position doesn't match."""
|
||||||
|
|
@ -366,7 +366,7 @@ class TestAddSyncDecoratorToFunction:
|
||||||
assert not (subdir / ASYNC_HELPER_FILENAME).exists()
|
assert not (subdir / ASYNC_HELPER_FILENAME).exists()
|
||||||
|
|
||||||
def test_preserves_existing_decorators(self, temp_dir) -> None:
|
def test_preserves_existing_decorators(self, temp_dir) -> None:
|
||||||
"""Adds codeflash decorator above existing decorators."""
|
"""Adds codeflash decorator below @staticmethod/@classmethod."""
|
||||||
source_code = (
|
source_code = (
|
||||||
"@staticmethod\n"
|
"@staticmethod\n"
|
||||||
"def my_func(x: int) -> int:\n"
|
"def my_func(x: int) -> int:\n"
|
||||||
|
|
@ -390,7 +390,7 @@ class TestAddSyncDecoratorToFunction:
|
||||||
modified = source_file.read_text()
|
modified = source_file.read_text()
|
||||||
cf_pos = modified.find("@codeflash_behavior_sync")
|
cf_pos = modified.find("@codeflash_behavior_sync")
|
||||||
sm_pos = modified.find("@staticmethod")
|
sm_pos = modified.find("@staticmethod")
|
||||||
assert cf_pos < sm_pos
|
assert sm_pos < cf_pos
|
||||||
|
|
||||||
def test_no_duplicate_decorator(self, temp_dir) -> None:
|
def test_no_duplicate_decorator(self, temp_dir) -> None:
|
||||||
"""Does not add decorator if already present."""
|
"""Does not add decorator if already present."""
|
||||||
|
|
@ -449,11 +449,11 @@ class TestInjectSyncProfilingIntoExistingTest:
|
||||||
|
|
||||||
assert success
|
assert success
|
||||||
assert instrumented is not None
|
assert instrumented is not None
|
||||||
assert "_codeflash_call_site.set('0')" in instrumented
|
assert "_codeflash_call_site.set('4')" in instrumented
|
||||||
assert "from codeflash_async_wrapper import _codeflash_call_site" in instrumented
|
assert "from codeflash_async_wrapper import _codeflash_call_site" in instrumented
|
||||||
|
|
||||||
def test_multiple_calls_numbered_sequentially(self, temp_dir) -> None:
|
def test_multiple_calls_use_line_numbers(self, temp_dir) -> None:
|
||||||
"""Multiple calls get sequential call-site indices."""
|
"""Multiple calls use their source line numbers as call-site IDs."""
|
||||||
source_code = "def my_func(x: int) -> int:\n return x + 1\n"
|
source_code = "def my_func(x: int) -> int:\n return x + 1\n"
|
||||||
source_file = temp_dir / "my_module.py"
|
source_file = temp_dir / "my_module.py"
|
||||||
source_file.write_text(source_code)
|
source_file.write_text(source_code)
|
||||||
|
|
@ -487,9 +487,9 @@ class TestInjectSyncProfilingIntoExistingTest:
|
||||||
|
|
||||||
assert success
|
assert success
|
||||||
assert instrumented is not None
|
assert instrumented is not None
|
||||||
assert "_codeflash_call_site.set('0')" in instrumented
|
assert "_codeflash_call_site.set('4')" in instrumented
|
||||||
assert "_codeflash_call_site.set('1')" in instrumented
|
assert "_codeflash_call_site.set('5')" in instrumented
|
||||||
assert "_codeflash_call_site.set('2')" in instrumented
|
assert "_codeflash_call_site.set('6')" in instrumented
|
||||||
|
|
||||||
def test_returns_false_for_syntax_error(self, temp_dir) -> None:
|
def test_returns_false_for_syntax_error(self, temp_dir) -> None:
|
||||||
"""Returns (False, None) when the test file has a syntax error."""
|
"""Returns (False, None) when the test file has a syntax error."""
|
||||||
|
|
|
||||||
|
|
@ -29,13 +29,8 @@ from codeflash_python.testing._instrument_capture import (
|
||||||
create_instrumented_source_module_path,
|
create_instrumented_source_module_path,
|
||||||
)
|
)
|
||||||
from codeflash_python.testing._instrument_core import (
|
from codeflash_python.testing._instrument_core import (
|
||||||
FunctionCallNodeArguments,
|
|
||||||
FunctionImportedAsVisitor,
|
FunctionImportedAsVisitor,
|
||||||
create_device_sync_precompute_statements,
|
|
||||||
create_device_sync_statements,
|
|
||||||
detect_frameworks_from_code,
|
detect_frameworks_from_code,
|
||||||
get_call_arguments,
|
|
||||||
is_argument_name,
|
|
||||||
node_in_call_position,
|
node_in_call_position,
|
||||||
)
|
)
|
||||||
from codeflash_python.testing._instrumentation import (
|
from codeflash_python.testing._instrumentation import (
|
||||||
|
|
@ -94,35 +89,6 @@ class TestVerificationType:
|
||||||
assert 3 == len(VerificationType)
|
assert 3 == len(VerificationType)
|
||||||
|
|
||||||
|
|
||||||
class TestGetCallArguments:
|
|
||||||
"""get_call_arguments Call node extraction."""
|
|
||||||
|
|
||||||
def test_simple_call(self) -> None:
|
|
||||||
"""Extracts positional args and keywords from a Call node."""
|
|
||||||
tree = ast.parse("func(1, 2, key='val')")
|
|
||||||
call_node = tree.body[0].value # type: ignore[attr-defined]
|
|
||||||
result = get_call_arguments(call_node)
|
|
||||||
assert isinstance(result, FunctionCallNodeArguments)
|
|
||||||
assert 2 == len(result.args)
|
|
||||||
assert 1 == len(result.keywords)
|
|
||||||
|
|
||||||
def test_no_args(self) -> None:
|
|
||||||
"""Returns empty lists for a call with no arguments."""
|
|
||||||
tree = ast.parse("func()")
|
|
||||||
call_node = tree.body[0].value # type: ignore[attr-defined]
|
|
||||||
result = get_call_arguments(call_node)
|
|
||||||
assert [] == result.args
|
|
||||||
assert [] == result.keywords
|
|
||||||
|
|
||||||
def test_only_keywords(self) -> None:
|
|
||||||
"""Returns keywords when only keyword args are present."""
|
|
||||||
tree = ast.parse("func(a=1, b=2)")
|
|
||||||
call_node = tree.body[0].value # type: ignore[attr-defined]
|
|
||||||
result = get_call_arguments(call_node)
|
|
||||||
assert [] == result.args
|
|
||||||
assert 2 == len(result.keywords)
|
|
||||||
|
|
||||||
|
|
||||||
class TestNodeInCallPosition:
|
class TestNodeInCallPosition:
|
||||||
"""node_in_call_position position matching."""
|
"""node_in_call_position position matching."""
|
||||||
|
|
||||||
|
|
@ -161,45 +127,6 @@ class TestNodeInCallPosition:
|
||||||
assert node_in_call_position(call_node, positions) is True
|
assert node_in_call_position(call_node, positions) is True
|
||||||
|
|
||||||
|
|
||||||
class TestIsArgumentName:
|
|
||||||
"""is_argument_name argument detection."""
|
|
||||||
|
|
||||||
def test_regular_arg(self) -> None:
|
|
||||||
"""Returns True for a regular positional argument name."""
|
|
||||||
code = "def f(x, y): pass"
|
|
||||||
tree = ast.parse(code)
|
|
||||||
func_def = tree.body[0]
|
|
||||||
assert is_argument_name("x", func_def.args) is True # type: ignore[attr-defined]
|
|
||||||
|
|
||||||
def test_kwonly_arg(self) -> None:
|
|
||||||
"""Returns True for a keyword-only argument name."""
|
|
||||||
code = "def f(*, key): pass"
|
|
||||||
tree = ast.parse(code)
|
|
||||||
func_def = tree.body[0]
|
|
||||||
assert is_argument_name("key", func_def.args) is True # type: ignore[attr-defined]
|
|
||||||
|
|
||||||
def test_no_match(self) -> None:
|
|
||||||
"""Returns False when name is not an argument."""
|
|
||||||
code = "def f(x, y): pass"
|
|
||||||
tree = ast.parse(code)
|
|
||||||
func_def = tree.body[0]
|
|
||||||
assert is_argument_name("z", func_def.args) is False # type: ignore[attr-defined]
|
|
||||||
|
|
||||||
def test_vararg_not_matched(self) -> None:
|
|
||||||
"""Returns False for *args (vararg is not a list attribute)."""
|
|
||||||
code = "def f(*args): pass"
|
|
||||||
tree = ast.parse(code)
|
|
||||||
func_def = tree.body[0]
|
|
||||||
assert is_argument_name("args", func_def.args) is False # type: ignore[attr-defined]
|
|
||||||
|
|
||||||
def test_kwarg_not_matched(self) -> None:
|
|
||||||
"""Returns False for **kwargs (kwarg is not a list attribute)."""
|
|
||||||
code = "def f(**kwargs): pass"
|
|
||||||
tree = ast.parse(code)
|
|
||||||
func_def = tree.body[0]
|
|
||||||
assert is_argument_name("kwargs", func_def.args) is False # type: ignore[attr-defined]
|
|
||||||
|
|
||||||
|
|
||||||
class TestDetectFrameworksFromCode:
|
class TestDetectFrameworksFromCode:
|
||||||
"""detect_frameworks_from_code import detection."""
|
"""detect_frameworks_from_code import detection."""
|
||||||
|
|
||||||
|
|
@ -276,84 +203,6 @@ class TestGetDecoratorNameForMode:
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class TestCreateDeviceSyncPrecomputeStatements:
|
|
||||||
"""create_device_sync_precompute_statements AST generation."""
|
|
||||||
|
|
||||||
def test_none_frameworks(self) -> None:
|
|
||||||
"""Returns empty list when frameworks is None."""
|
|
||||||
result = create_device_sync_precompute_statements(None)
|
|
||||||
assert [] == result
|
|
||||||
|
|
||||||
def test_empty_dict(self) -> None:
|
|
||||||
"""Returns empty list when frameworks dict is empty."""
|
|
||||||
result = create_device_sync_precompute_statements({})
|
|
||||||
assert [] == result
|
|
||||||
|
|
||||||
def test_torch_produces_statements(self) -> None:
|
|
||||||
"""Produces AST statements for torch."""
|
|
||||||
result = create_device_sync_precompute_statements(
|
|
||||||
{"torch": "torch"},
|
|
||||||
)
|
|
||||||
assert len(result) > 0
|
|
||||||
assert all(isinstance(s, ast.stmt) for s in result)
|
|
||||||
|
|
||||||
def test_jax_produces_statements(self) -> None:
|
|
||||||
"""Produces AST statements for jax."""
|
|
||||||
result = create_device_sync_precompute_statements({"jax": "jax"})
|
|
||||||
assert len(result) > 0
|
|
||||||
assert all(isinstance(s, ast.stmt) for s in result)
|
|
||||||
|
|
||||||
def test_tensorflow_produces_statements(self) -> None:
|
|
||||||
"""Produces AST statements for tensorflow."""
|
|
||||||
result = create_device_sync_precompute_statements(
|
|
||||||
{"tensorflow": "tf"},
|
|
||||||
)
|
|
||||||
assert len(result) > 0
|
|
||||||
assert all(isinstance(s, ast.stmt) for s in result)
|
|
||||||
|
|
||||||
def test_combined_frameworks(self) -> None:
|
|
||||||
"""Produces statements for multiple frameworks."""
|
|
||||||
result = create_device_sync_precompute_statements(
|
|
||||||
{"torch": "torch", "jax": "jax"},
|
|
||||||
)
|
|
||||||
assert len(result) > 0
|
|
||||||
|
|
||||||
|
|
||||||
class TestCreateDeviceSyncStatements:
|
|
||||||
"""create_device_sync_statements AST generation."""
|
|
||||||
|
|
||||||
def test_none_frameworks(self) -> None:
|
|
||||||
"""Returns empty list when frameworks is None."""
|
|
||||||
result = create_device_sync_statements(None)
|
|
||||||
assert [] == result
|
|
||||||
|
|
||||||
def test_empty_dict(self) -> None:
|
|
||||||
"""Returns empty list when frameworks dict is empty."""
|
|
||||||
result = create_device_sync_statements({})
|
|
||||||
assert [] == result
|
|
||||||
|
|
||||||
def test_torch_sync(self) -> None:
|
|
||||||
"""Produces sync statements for torch."""
|
|
||||||
result = create_device_sync_statements({"torch": "torch"})
|
|
||||||
assert len(result) > 0
|
|
||||||
assert all(isinstance(s, ast.stmt) for s in result)
|
|
||||||
|
|
||||||
def test_for_return_value_flag(self) -> None:
|
|
||||||
"""Produces statements with for_return_value=True."""
|
|
||||||
result = create_device_sync_statements(
|
|
||||||
{"jax": "jax"},
|
|
||||||
for_return_value=True,
|
|
||||||
)
|
|
||||||
assert len(result) > 0
|
|
||||||
assert all(isinstance(s, ast.stmt) for s in result)
|
|
||||||
|
|
||||||
def test_tensorflow_sync(self) -> None:
|
|
||||||
"""Produces sync statements for tensorflow."""
|
|
||||||
result = create_device_sync_statements({"tensorflow": "tf"})
|
|
||||||
assert len(result) > 0
|
|
||||||
assert all(isinstance(s, ast.stmt) for s in result)
|
|
||||||
|
|
||||||
|
|
||||||
class TestAsyncCallInstrumenter:
|
class TestAsyncCallInstrumenter:
|
||||||
"""AsyncCallInstrumenter AST transformer."""
|
"""AsyncCallInstrumenter AST transformer."""
|
||||||
|
|
||||||
|
|
@ -439,11 +288,11 @@ class TestAsyncCallInstrumenter:
|
||||||
transformer, tree = self._make_transformer(code)
|
transformer, tree = self._make_transformer(code)
|
||||||
new_tree = transformer.visit(tree)
|
new_tree = transformer.visit(tree)
|
||||||
source = ast.unparse(new_tree)
|
source = ast.unparse(new_tree)
|
||||||
assert "_codeflash_call_site.set(" in source
|
assert "_codeflash_call_site.set('2')" in source
|
||||||
assert transformer.did_instrument is True
|
assert transformer.did_instrument is True
|
||||||
|
|
||||||
def test_multiple_awaits_get_incrementing_ids(self) -> None:
|
def test_multiple_awaits_get_line_number_ids(self) -> None:
|
||||||
"""Each awaited target call gets a unique incrementing counter."""
|
"""Each awaited target call uses its source line number as call-site ID."""
|
||||||
code = textwrap.dedent("""\
|
code = textwrap.dedent("""\
|
||||||
async def test_it():
|
async def test_it():
|
||||||
a = await target_func(1)
|
a = await target_func(1)
|
||||||
|
|
@ -453,9 +302,9 @@ class TestAsyncCallInstrumenter:
|
||||||
transformer, tree = self._make_transformer(code)
|
transformer, tree = self._make_transformer(code)
|
||||||
new_tree = transformer.visit(tree)
|
new_tree = transformer.visit(tree)
|
||||||
source = ast.unparse(new_tree)
|
source = ast.unparse(new_tree)
|
||||||
assert "'0'" in source
|
assert "_codeflash_call_site.set('2')" in source
|
||||||
assert "'1'" in source
|
assert "_codeflash_call_site.set('3')" in source
|
||||||
assert "'2'" in source
|
assert "_codeflash_call_site.set('4')" in source
|
||||||
|
|
||||||
def test_attribute_style_call(self) -> None:
|
def test_attribute_style_call(self) -> None:
|
||||||
"""Instruments await obj.target_func() attribute-style calls."""
|
"""Instruments await obj.target_func() attribute-style calls."""
|
||||||
|
|
@ -466,7 +315,7 @@ class TestAsyncCallInstrumenter:
|
||||||
transformer, tree = self._make_transformer(code)
|
transformer, tree = self._make_transformer(code)
|
||||||
new_tree = transformer.visit(tree)
|
new_tree = transformer.visit(tree)
|
||||||
source = ast.unparse(new_tree)
|
source = ast.unparse(new_tree)
|
||||||
assert "_codeflash_call_site.set(" in source
|
assert "_codeflash_call_site.set('2')" in source
|
||||||
assert transformer.did_instrument is True
|
assert transformer.did_instrument is True
|
||||||
|
|
||||||
def test_recurses_into_class_body(self) -> None:
|
def test_recurses_into_class_body(self) -> None:
|
||||||
|
|
@ -479,7 +328,7 @@ class TestAsyncCallInstrumenter:
|
||||||
transformer, tree = self._make_transformer(code)
|
transformer, tree = self._make_transformer(code)
|
||||||
new_tree = transformer.visit(tree)
|
new_tree = transformer.visit(tree)
|
||||||
source = ast.unparse(new_tree)
|
source = ast.unparse(new_tree)
|
||||||
assert "_codeflash_call_site.set(" in source
|
assert "_codeflash_call_site.set('3')" in source
|
||||||
assert transformer.did_instrument is True
|
assert transformer.did_instrument is True
|
||||||
|
|
||||||
def test_no_match_when_position_wrong(self) -> None:
|
def test_no_match_when_position_wrong(self) -> None:
|
||||||
|
|
@ -527,8 +376,8 @@ class TestAsyncCallInstrumenter:
|
||||||
transformer.visit(tree)
|
transformer.visit(tree)
|
||||||
assert transformer.did_instrument is False
|
assert transformer.did_instrument is False
|
||||||
|
|
||||||
def test_counters_independent_per_test_function(self) -> None:
|
def test_line_numbers_independent_per_test_function(self) -> None:
|
||||||
"""Each test function gets its own independent counter."""
|
"""Each call uses its own line number regardless of test function."""
|
||||||
code = textwrap.dedent("""\
|
code = textwrap.dedent("""\
|
||||||
async def test_a():
|
async def test_a():
|
||||||
await target_func(1)
|
await target_func(1)
|
||||||
|
|
@ -537,9 +386,11 @@ class TestAsyncCallInstrumenter:
|
||||||
await target_func(3)
|
await target_func(3)
|
||||||
""")
|
""")
|
||||||
transformer, tree = self._make_transformer(code)
|
transformer, tree = self._make_transformer(code)
|
||||||
transformer.visit(tree)
|
new_tree = transformer.visit(tree)
|
||||||
assert 2 == transformer.async_call_counter["test_a"]
|
source = ast.unparse(new_tree)
|
||||||
assert 1 == transformer.async_call_counter["test_b"]
|
assert "_codeflash_call_site.set('2')" in source
|
||||||
|
assert "_codeflash_call_site.set('3')" in source
|
||||||
|
assert "_codeflash_call_site.set('5')" in source
|
||||||
|
|
||||||
|
|
||||||
class TestFunctionImportedAsVisitor:
|
class TestFunctionImportedAsVisitor:
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue