fix tests
test_class_method_instrumentation test_time_correction_instrumentation test_perfinjector_only_replay_test test_perfinjector_bubble_sort ruff refactor
This commit is contained in:
parent
3d3943908d
commit
14f94b9eb0
2 changed files with 159 additions and 122 deletions
|
|
@ -2,23 +2,34 @@ from __future__ import annotations
|
|||
|
||||
import ast
|
||||
from pathlib import Path
|
||||
from typing import Iterable
|
||||
from typing import TYPE_CHECKING, Iterable
|
||||
|
||||
import isort
|
||||
|
||||
from codeflash.cli_cmds.console import logger
|
||||
from codeflash.code_utils.code_utils import get_run_tmp_file, module_name_from_file_path
|
||||
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
|
||||
from codeflash.models.models import CodePosition, FunctionParent
|
||||
from codeflash.models.models import FunctionParent
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from codeflash.models.models import CodePosition
|
||||
|
||||
|
||||
def node_in_call_position(node: ast.stmt, call_positions: list[CodePosition]) -> bool:
|
||||
def node_in_call_position(node: ast.AST, call_positions: list[CodePosition]) -> bool:
|
||||
if isinstance(node, ast.Call) and hasattr(node, "lineno") and hasattr(node, "col_offset"):
|
||||
for pos in call_positions:
|
||||
if node.lineno <= pos.line_no <= node.end_lineno:
|
||||
if (
|
||||
pos.line_no is not None
|
||||
and node.end_lineno is not None
|
||||
and node.lineno <= pos.line_no <= node.end_lineno
|
||||
):
|
||||
if pos.line_no == node.lineno and node.col_offset <= pos.col_no:
|
||||
return True
|
||||
if pos.line_no == node.end_lineno and node.end_col_offset >= pos.col_no:
|
||||
if (
|
||||
pos.line_no == node.end_lineno
|
||||
and node.end_col_offset is not None
|
||||
and node.end_col_offset >= pos.col_no
|
||||
):
|
||||
return True
|
||||
if node.lineno < pos.line_no < node.end_lineno:
|
||||
return True
|
||||
|
|
@ -53,9 +64,9 @@ class InjectPerfOnly(ast.NodeTransformer):
|
|||
) -> Iterable[ast.stmt] | None:
|
||||
call_node = None
|
||||
for node in ast.walk(test_node):
|
||||
if node_in_call_position(node, self.call_positions):
|
||||
if isinstance(node, ast.Call) and node_in_call_position(node, self.call_positions):
|
||||
call_node = node
|
||||
if hasattr(node.func, "id"):
|
||||
if isinstance(node.func, ast.Name):
|
||||
function_name = node.func.id
|
||||
node.func = ast.Name(id="codeflash_wrap", ctx=ast.Load())
|
||||
node.args = [
|
||||
|
|
@ -69,11 +80,10 @@ class InjectPerfOnly(ast.NodeTransformer):
|
|||
ast.Name(id="codeflash_cur", ctx=ast.Load()),
|
||||
ast.Name(id="codeflash_con", ctx=ast.Load()),
|
||||
*call_node.args,
|
||||
*call_node.keywords,
|
||||
]
|
||||
node.keywords = []
|
||||
node.keywords = call_node.keywords
|
||||
break
|
||||
if hasattr(node.func, "attr"):
|
||||
if isinstance(node.func, ast.Attribute):
|
||||
function_to_test = node.func.attr
|
||||
if function_to_test == self.function_object.function_name:
|
||||
function_name = ast.unparse(node.func)
|
||||
|
|
@ -89,16 +99,15 @@ class InjectPerfOnly(ast.NodeTransformer):
|
|||
ast.Name(id="codeflash_cur", ctx=ast.Load()),
|
||||
ast.Name(id="codeflash_con", ctx=ast.Load()),
|
||||
*call_node.args,
|
||||
*call_node.keywords,
|
||||
]
|
||||
node.keywords = []
|
||||
node.keywords = call_node.keywords
|
||||
break
|
||||
|
||||
if call_node is None:
|
||||
return None
|
||||
return [test_node]
|
||||
|
||||
def visit_ClassDef(self, node: ast.ClassDef) -> ast.ClassDef:
|
||||
def visit_ClassDef(self, node: ast.ClassDef) -> ast.ClassDef: # noqa: N802
|
||||
# TODO: Ensure that this class inherits from unittest.TestCase. Don't modify non unittest.TestCase classes.
|
||||
for inner_node in ast.walk(node):
|
||||
if isinstance(inner_node, ast.FunctionDef):
|
||||
|
|
@ -106,7 +115,7 @@ class InjectPerfOnly(ast.NodeTransformer):
|
|||
|
||||
return node
|
||||
|
||||
def visit_FunctionDef(self, node: ast.FunctionDef, test_class_name: str | None = None) -> ast.FunctionDef:
|
||||
def visit_FunctionDef(self, node: ast.FunctionDef, test_class_name: str | None = None) -> ast.FunctionDef: # noqa: N802
|
||||
if node.name.startswith("test_"):
|
||||
did_update = False
|
||||
if self.test_framework == "unittest":
|
||||
|
|
@ -144,103 +153,99 @@ class InjectPerfOnly(ast.NodeTransformer):
|
|||
did_update = True
|
||||
i -= 1
|
||||
if did_update:
|
||||
node.body = (
|
||||
[
|
||||
ast.Assign(
|
||||
targets=[ast.Name(id="codeflash_iteration", ctx=ast.Store())],
|
||||
value=ast.Subscript(
|
||||
value=ast.Attribute(
|
||||
value=ast.Name(id="os", ctx=ast.Load()), attr="environ", ctx=ast.Load()
|
||||
),
|
||||
slice=ast.Constant(value="CODEFLASH_TEST_ITERATION"),
|
||||
ctx=ast.Load(),
|
||||
node.body = [
|
||||
ast.Assign(
|
||||
targets=[ast.Name(id="codeflash_iteration", ctx=ast.Store())],
|
||||
value=ast.Subscript(
|
||||
value=ast.Attribute(
|
||||
value=ast.Name(id="os", ctx=ast.Load()), attr="environ", ctx=ast.Load()
|
||||
),
|
||||
lineno=node.lineno + 1,
|
||||
col_offset=node.col_offset,
|
||||
slice=ast.Constant(value="CODEFLASH_TEST_ITERATION"),
|
||||
ctx=ast.Load(),
|
||||
),
|
||||
ast.Assign(
|
||||
targets=[ast.Name(id="codeflash_loop_index", ctx=ast.Store())],
|
||||
value=ast.Call(
|
||||
func=ast.Name(id="int", ctx=ast.Load()),
|
||||
args=[
|
||||
ast.Subscript(
|
||||
value=ast.Attribute(
|
||||
value=ast.Name(id="os", ctx=ast.Load()), attr="environ", ctx=ast.Load()
|
||||
lineno=node.lineno + 1,
|
||||
col_offset=node.col_offset,
|
||||
),
|
||||
ast.Assign(
|
||||
targets=[ast.Name(id="codeflash_loop_index", ctx=ast.Store())],
|
||||
value=ast.Call(
|
||||
func=ast.Name(id="int", ctx=ast.Load()),
|
||||
args=[
|
||||
ast.Subscript(
|
||||
value=ast.Attribute(
|
||||
value=ast.Name(id="os", ctx=ast.Load()), attr="environ", ctx=ast.Load()
|
||||
),
|
||||
slice=ast.Constant(value="CODEFLASH_LOOP_INDEX"),
|
||||
ctx=ast.Load(),
|
||||
)
|
||||
],
|
||||
keywords=[],
|
||||
),
|
||||
lineno=node.lineno + 2,
|
||||
col_offset=node.col_offset,
|
||||
),
|
||||
ast.Assign(
|
||||
targets=[ast.Name(id="codeflash_con", ctx=ast.Store())],
|
||||
value=ast.Call(
|
||||
func=ast.Attribute(
|
||||
value=ast.Name(id="sqlite3", ctx=ast.Load()), attr="connect", ctx=ast.Load()
|
||||
),
|
||||
args=[
|
||||
ast.JoinedStr(
|
||||
values=[
|
||||
ast.Constant(value=f"{get_run_tmp_file(Path('test_return_values_'))}"),
|
||||
ast.FormattedValue(
|
||||
value=ast.Name(id="codeflash_iteration", ctx=ast.Load()), conversion=-1
|
||||
),
|
||||
slice=ast.Constant(value="CODEFLASH_LOOP_INDEX"),
|
||||
ctx=ast.Load(),
|
||||
)
|
||||
],
|
||||
keywords=[],
|
||||
),
|
||||
lineno=node.lineno + 2,
|
||||
col_offset=node.col_offset,
|
||||
ast.Constant(value=".sqlite"),
|
||||
]
|
||||
)
|
||||
],
|
||||
keywords=[],
|
||||
),
|
||||
ast.Assign(
|
||||
targets=[ast.Name(id="codeflash_con", ctx=ast.Store())],
|
||||
value=ast.Call(
|
||||
func=ast.Attribute(
|
||||
value=ast.Name(id="sqlite3", ctx=ast.Load()), attr="connect", ctx=ast.Load()
|
||||
),
|
||||
args=[
|
||||
ast.JoinedStr(
|
||||
values=[
|
||||
ast.Constant(value=f"{get_run_tmp_file('test_return_values_')}"),
|
||||
ast.FormattedValue(
|
||||
value=ast.Name(id="codeflash_iteration", ctx=ast.Load()), conversion=-1
|
||||
),
|
||||
ast.Constant(value=".sqlite"),
|
||||
]
|
||||
)
|
||||
],
|
||||
keywords=[],
|
||||
lineno=node.lineno + 3,
|
||||
col_offset=node.col_offset,
|
||||
),
|
||||
ast.Assign(
|
||||
targets=[ast.Name(id="codeflash_cur", ctx=ast.Store())],
|
||||
value=ast.Call(
|
||||
func=ast.Attribute(
|
||||
value=ast.Name(id="codeflash_con", ctx=ast.Load()), attr="cursor", ctx=ast.Load()
|
||||
),
|
||||
lineno=node.lineno + 3,
|
||||
col_offset=node.col_offset,
|
||||
args=[],
|
||||
keywords=[],
|
||||
),
|
||||
ast.Assign(
|
||||
targets=[ast.Name(id="codeflash_cur", ctx=ast.Store())],
|
||||
value=ast.Call(
|
||||
func=ast.Attribute(
|
||||
value=ast.Name(id="codeflash_con", ctx=ast.Load()), attr="cursor", ctx=ast.Load()
|
||||
),
|
||||
args=[],
|
||||
keywords=[],
|
||||
lineno=node.lineno + 4,
|
||||
col_offset=node.col_offset,
|
||||
),
|
||||
ast.Expr(
|
||||
value=ast.Call(
|
||||
func=ast.Attribute(
|
||||
value=ast.Name(id="codeflash_cur", ctx=ast.Load()), attr="execute", ctx=ast.Load()
|
||||
),
|
||||
lineno=node.lineno + 4,
|
||||
col_offset=node.col_offset,
|
||||
args=[
|
||||
ast.Constant(
|
||||
value="CREATE TABLE IF NOT EXISTS test_results (test_module_path TEXT,"
|
||||
" test_class_name TEXT, test_function_name TEXT, function_getting_tested TEXT,"
|
||||
" loop_index INTEGER, iteration_id TEXT, runtime INTEGER, return_value BLOB)"
|
||||
)
|
||||
],
|
||||
keywords=[],
|
||||
),
|
||||
ast.Expr(
|
||||
value=ast.Call(
|
||||
func=ast.Attribute(
|
||||
value=ast.Name(id="codeflash_cur", ctx=ast.Load()), attr="execute", ctx=ast.Load()
|
||||
),
|
||||
args=[
|
||||
ast.Constant(
|
||||
value="CREATE TABLE IF NOT EXISTS test_results (test_module_path TEXT,"
|
||||
" test_class_name TEXT, test_function_name TEXT, function_getting_tested TEXT,"
|
||||
" loop_index INTEGER, iteration_id TEXT, runtime INTEGER, return_value BLOB)"
|
||||
)
|
||||
],
|
||||
keywords=[],
|
||||
lineno=node.lineno + 5,
|
||||
col_offset=node.col_offset,
|
||||
),
|
||||
*node.body,
|
||||
ast.Expr(
|
||||
value=ast.Call(
|
||||
func=ast.Attribute(
|
||||
value=ast.Name(id="codeflash_con", ctx=ast.Load()), attr="close", ctx=ast.Load()
|
||||
),
|
||||
lineno=node.lineno + 5,
|
||||
col_offset=node.col_offset,
|
||||
),
|
||||
]
|
||||
+ node.body
|
||||
+ [
|
||||
ast.Expr(
|
||||
value=ast.Call(
|
||||
func=ast.Attribute(
|
||||
value=ast.Name(id="codeflash_con", ctx=ast.Load()), attr="close", ctx=ast.Load()
|
||||
),
|
||||
args=[],
|
||||
keywords=[],
|
||||
)
|
||||
args=[],
|
||||
keywords=[],
|
||||
)
|
||||
]
|
||||
)
|
||||
),
|
||||
]
|
||||
return node
|
||||
|
||||
|
||||
|
|
@ -261,7 +266,7 @@ class FunctionImportedAsVisitor(ast.NodeVisitor):
|
|||
self.to_match = function.function_name
|
||||
|
||||
# TODO: Validate if the function imported is actually from the right module
|
||||
def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
|
||||
def visit_ImportFrom(self, node: ast.ImportFrom) -> None: # noqa: N802
|
||||
for alias in node.names:
|
||||
if alias.name == self.to_match and hasattr(alias, "asname") and alias.asname is not None:
|
||||
if self.function.parents:
|
||||
|
|
|
|||
|
|
@ -30,16 +30,24 @@ codeflash_wrap_string = """def codeflash_wrap(wrapped, test_module_name, test_cl
|
|||
invocation_id = f'{{line_id}}_{{codeflash_test_index}}'
|
||||
print(f"!######{{test_module_name}}:{{(test_class_name + '.' if test_class_name else '')}}{{test_name}}:{{function_name}}:{{loop_index}}:{{invocation_id}}######!")
|
||||
gc.disable()
|
||||
counter = time.perf_counter_ns()
|
||||
return_value = wrapped(*args, **kwargs)
|
||||
codeflash_duration = time.perf_counter_ns() - counter
|
||||
gc.enable()
|
||||
exception = None
|
||||
try:
|
||||
counter = time.perf_counter_ns()
|
||||
return_value = wrapped(*args, **kwargs)
|
||||
codeflash_duration = time.perf_counter_ns() - counter
|
||||
except Exception as e:
|
||||
exception = e
|
||||
codeflash_duration = time.perf_counter_ns() - counter
|
||||
finally:
|
||||
gc.enable()
|
||||
if loop_index == 1:
|
||||
pickled_return_value = pickle.dumps(return_value)
|
||||
pickled_return_value = pickle.dumps(exception) if exception else pickle.dumps(return_value)
|
||||
else:
|
||||
pickled_return_value = None
|
||||
codeflash_cur.execute('INSERT INTO test_results VALUES (?, ?, ?, ?, ?, ?, ?, ?)', (test_module_name, test_class_name, test_name, function_name, loop_index, invocation_id, codeflash_duration, pickled_return_value))
|
||||
codeflash_con.commit()
|
||||
if exception:
|
||||
raise exception
|
||||
return return_value
|
||||
"""
|
||||
|
||||
|
|
@ -92,16 +100,24 @@ def codeflash_wrap(wrapped, test_module_name, test_class_name, test_name, functi
|
|||
expected += """print(f'!######{{test_module_name}}:{{(test_class_name + '.' if test_class_name else '')}}{{test_name}}:{{function_name}}:{{loop_index}}:{{invocation_id}}######!')"""
|
||||
expected += """
|
||||
gc.disable()
|
||||
counter = time.perf_counter_ns()
|
||||
return_value = wrapped(*args, **kwargs)
|
||||
codeflash_duration = time.perf_counter_ns() - counter
|
||||
gc.enable()
|
||||
exception = None
|
||||
try:
|
||||
counter = time.perf_counter_ns()
|
||||
return_value = wrapped(*args, **kwargs)
|
||||
codeflash_duration = time.perf_counter_ns() - counter
|
||||
except Exception as e:
|
||||
exception = e
|
||||
codeflash_duration = time.perf_counter_ns() - counter
|
||||
finally:
|
||||
gc.enable()
|
||||
if loop_index == 1:
|
||||
pickled_return_value = pickle.dumps(return_value)
|
||||
pickled_return_value = pickle.dumps(exception) if exception else pickle.dumps(return_value)
|
||||
else:
|
||||
pickled_return_value = None
|
||||
codeflash_cur.execute('INSERT INTO test_results VALUES (?, ?, ?, ?, ?, ?, ?, ?)', (test_module_name, test_class_name, test_name, function_name, loop_index, invocation_id, codeflash_duration, pickled_return_value))
|
||||
codeflash_con.commit()
|
||||
if exception:
|
||||
raise exception
|
||||
return return_value
|
||||
|
||||
class TestPigLatin(unittest.TestCase):
|
||||
|
|
@ -189,16 +205,24 @@ def codeflash_wrap(wrapped, test_module_name, test_class_name, test_name, functi
|
|||
expected += """print(f'!######{{test_module_name}}:{{(test_class_name + '.' if test_class_name else '')}}{{test_name}}:{{function_name}}:{{loop_index}}:{{invocation_id}}######!')"""
|
||||
expected += """
|
||||
gc.disable()
|
||||
counter = time.perf_counter_ns()
|
||||
return_value = wrapped(*args, **kwargs)
|
||||
codeflash_duration = time.perf_counter_ns() - counter
|
||||
gc.enable()
|
||||
exception = None
|
||||
try:
|
||||
counter = time.perf_counter_ns()
|
||||
return_value = wrapped(*args, **kwargs)
|
||||
codeflash_duration = time.perf_counter_ns() - counter
|
||||
except Exception as e:
|
||||
exception = e
|
||||
codeflash_duration = time.perf_counter_ns() - counter
|
||||
finally:
|
||||
gc.enable()
|
||||
if loop_index == 1:
|
||||
pickled_return_value = pickle.dumps(return_value)
|
||||
pickled_return_value = pickle.dumps(exception) if exception else pickle.dumps(return_value)
|
||||
else:
|
||||
pickled_return_value = None
|
||||
codeflash_cur.execute('INSERT INTO test_results VALUES (?, ?, ?, ?, ?, ?, ?, ?)', (test_module_name, test_class_name, test_name, function_name, loop_index, invocation_id, codeflash_duration, pickled_return_value))
|
||||
codeflash_con.commit()
|
||||
if exception:
|
||||
raise exception
|
||||
return return_value
|
||||
|
||||
def test_prepare_image_for_yolo():
|
||||
|
|
@ -1809,16 +1833,24 @@ def codeflash_wrap(wrapped, test_module_name, test_class_name, test_name, functi
|
|||
expected += """ print(f'!######{{test_module_name}}:{{(test_class_name + '.' if test_class_name else '')}}{{test_name}}:{{function_name}}:{{loop_index}}:{{invocation_id}}######!')"""
|
||||
expected += """
|
||||
gc.disable()
|
||||
counter = time.perf_counter_ns()
|
||||
return_value = wrapped(*args, **kwargs)
|
||||
codeflash_duration = time.perf_counter_ns() - counter
|
||||
gc.enable()
|
||||
exception = None
|
||||
try:
|
||||
counter = time.perf_counter_ns()
|
||||
return_value = wrapped(*args, **kwargs)
|
||||
codeflash_duration = time.perf_counter_ns() - counter
|
||||
except Exception as e:
|
||||
exception = e
|
||||
codeflash_duration = time.perf_counter_ns() - counter
|
||||
finally:
|
||||
gc.enable()
|
||||
if loop_index == 1:
|
||||
pickled_return_value = pickle.dumps(return_value)
|
||||
pickled_return_value = pickle.dumps(exception) if exception else pickle.dumps(return_value)
|
||||
else:
|
||||
pickled_return_value = None
|
||||
codeflash_cur.execute('INSERT INTO test_results VALUES (?, ?, ?, ?, ?, ?, ?, ?)', (test_module_name, test_class_name, test_name, function_name, loop_index, invocation_id, codeflash_duration, pickled_return_value))
|
||||
codeflash_con.commit()
|
||||
if exception:
|
||||
raise exception
|
||||
return return_value
|
||||
|
||||
def test_code_replacement10() -> None:
|
||||
|
|
|
|||
Loading…
Reference in a new issue