fix tests

test_class_method_instrumentation
test_time_correction_instrumentation
test_perfinjector_only_replay_test
test_perfinjector_bubble_sort
ruff refactor
This commit is contained in:
Kevin Turcios 2024-12-10 10:00:07 -05:00
parent 3d3943908d
commit 14f94b9eb0
2 changed files with 159 additions and 122 deletions

View file

@ -2,23 +2,34 @@ from __future__ import annotations
import ast
from pathlib import Path
from typing import Iterable
from typing import TYPE_CHECKING, Iterable
import isort
from codeflash.cli_cmds.console import logger
from codeflash.code_utils.code_utils import get_run_tmp_file, module_name_from_file_path
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.models.models import CodePosition, FunctionParent
from codeflash.models.models import FunctionParent
if TYPE_CHECKING:
from codeflash.models.models import CodePosition
def node_in_call_position(node: ast.stmt, call_positions: list[CodePosition]) -> bool:
def node_in_call_position(node: ast.AST, call_positions: list[CodePosition]) -> bool:
if isinstance(node, ast.Call) and hasattr(node, "lineno") and hasattr(node, "col_offset"):
for pos in call_positions:
if node.lineno <= pos.line_no <= node.end_lineno:
if (
pos.line_no is not None
and node.end_lineno is not None
and node.lineno <= pos.line_no <= node.end_lineno
):
if pos.line_no == node.lineno and node.col_offset <= pos.col_no:
return True
if pos.line_no == node.end_lineno and node.end_col_offset >= pos.col_no:
if (
pos.line_no == node.end_lineno
and node.end_col_offset is not None
and node.end_col_offset >= pos.col_no
):
return True
if node.lineno < pos.line_no < node.end_lineno:
return True
@ -53,9 +64,9 @@ class InjectPerfOnly(ast.NodeTransformer):
) -> Iterable[ast.stmt] | None:
call_node = None
for node in ast.walk(test_node):
if node_in_call_position(node, self.call_positions):
if isinstance(node, ast.Call) and node_in_call_position(node, self.call_positions):
call_node = node
if hasattr(node.func, "id"):
if isinstance(node.func, ast.Name):
function_name = node.func.id
node.func = ast.Name(id="codeflash_wrap", ctx=ast.Load())
node.args = [
@ -69,11 +80,10 @@ class InjectPerfOnly(ast.NodeTransformer):
ast.Name(id="codeflash_cur", ctx=ast.Load()),
ast.Name(id="codeflash_con", ctx=ast.Load()),
*call_node.args,
*call_node.keywords,
]
node.keywords = []
node.keywords = call_node.keywords
break
if hasattr(node.func, "attr"):
if isinstance(node.func, ast.Attribute):
function_to_test = node.func.attr
if function_to_test == self.function_object.function_name:
function_name = ast.unparse(node.func)
@ -89,16 +99,15 @@ class InjectPerfOnly(ast.NodeTransformer):
ast.Name(id="codeflash_cur", ctx=ast.Load()),
ast.Name(id="codeflash_con", ctx=ast.Load()),
*call_node.args,
*call_node.keywords,
]
node.keywords = []
node.keywords = call_node.keywords
break
if call_node is None:
return None
return [test_node]
def visit_ClassDef(self, node: ast.ClassDef) -> ast.ClassDef:
def visit_ClassDef(self, node: ast.ClassDef) -> ast.ClassDef: # noqa: N802
# TODO: Ensure that this class inherits from unittest.TestCase. Don't modify non unittest.TestCase classes.
for inner_node in ast.walk(node):
if isinstance(inner_node, ast.FunctionDef):
@ -106,7 +115,7 @@ class InjectPerfOnly(ast.NodeTransformer):
return node
def visit_FunctionDef(self, node: ast.FunctionDef, test_class_name: str | None = None) -> ast.FunctionDef:
def visit_FunctionDef(self, node: ast.FunctionDef, test_class_name: str | None = None) -> ast.FunctionDef: # noqa: N802
if node.name.startswith("test_"):
did_update = False
if self.test_framework == "unittest":
@ -144,103 +153,99 @@ class InjectPerfOnly(ast.NodeTransformer):
did_update = True
i -= 1
if did_update:
node.body = (
[
ast.Assign(
targets=[ast.Name(id="codeflash_iteration", ctx=ast.Store())],
value=ast.Subscript(
value=ast.Attribute(
value=ast.Name(id="os", ctx=ast.Load()), attr="environ", ctx=ast.Load()
),
slice=ast.Constant(value="CODEFLASH_TEST_ITERATION"),
ctx=ast.Load(),
node.body = [
ast.Assign(
targets=[ast.Name(id="codeflash_iteration", ctx=ast.Store())],
value=ast.Subscript(
value=ast.Attribute(
value=ast.Name(id="os", ctx=ast.Load()), attr="environ", ctx=ast.Load()
),
lineno=node.lineno + 1,
col_offset=node.col_offset,
slice=ast.Constant(value="CODEFLASH_TEST_ITERATION"),
ctx=ast.Load(),
),
ast.Assign(
targets=[ast.Name(id="codeflash_loop_index", ctx=ast.Store())],
value=ast.Call(
func=ast.Name(id="int", ctx=ast.Load()),
args=[
ast.Subscript(
value=ast.Attribute(
value=ast.Name(id="os", ctx=ast.Load()), attr="environ", ctx=ast.Load()
lineno=node.lineno + 1,
col_offset=node.col_offset,
),
ast.Assign(
targets=[ast.Name(id="codeflash_loop_index", ctx=ast.Store())],
value=ast.Call(
func=ast.Name(id="int", ctx=ast.Load()),
args=[
ast.Subscript(
value=ast.Attribute(
value=ast.Name(id="os", ctx=ast.Load()), attr="environ", ctx=ast.Load()
),
slice=ast.Constant(value="CODEFLASH_LOOP_INDEX"),
ctx=ast.Load(),
)
],
keywords=[],
),
lineno=node.lineno + 2,
col_offset=node.col_offset,
),
ast.Assign(
targets=[ast.Name(id="codeflash_con", ctx=ast.Store())],
value=ast.Call(
func=ast.Attribute(
value=ast.Name(id="sqlite3", ctx=ast.Load()), attr="connect", ctx=ast.Load()
),
args=[
ast.JoinedStr(
values=[
ast.Constant(value=f"{get_run_tmp_file(Path('test_return_values_'))}"),
ast.FormattedValue(
value=ast.Name(id="codeflash_iteration", ctx=ast.Load()), conversion=-1
),
slice=ast.Constant(value="CODEFLASH_LOOP_INDEX"),
ctx=ast.Load(),
)
],
keywords=[],
),
lineno=node.lineno + 2,
col_offset=node.col_offset,
ast.Constant(value=".sqlite"),
]
)
],
keywords=[],
),
ast.Assign(
targets=[ast.Name(id="codeflash_con", ctx=ast.Store())],
value=ast.Call(
func=ast.Attribute(
value=ast.Name(id="sqlite3", ctx=ast.Load()), attr="connect", ctx=ast.Load()
),
args=[
ast.JoinedStr(
values=[
ast.Constant(value=f"{get_run_tmp_file('test_return_values_')}"),
ast.FormattedValue(
value=ast.Name(id="codeflash_iteration", ctx=ast.Load()), conversion=-1
),
ast.Constant(value=".sqlite"),
]
)
],
keywords=[],
lineno=node.lineno + 3,
col_offset=node.col_offset,
),
ast.Assign(
targets=[ast.Name(id="codeflash_cur", ctx=ast.Store())],
value=ast.Call(
func=ast.Attribute(
value=ast.Name(id="codeflash_con", ctx=ast.Load()), attr="cursor", ctx=ast.Load()
),
lineno=node.lineno + 3,
col_offset=node.col_offset,
args=[],
keywords=[],
),
ast.Assign(
targets=[ast.Name(id="codeflash_cur", ctx=ast.Store())],
value=ast.Call(
func=ast.Attribute(
value=ast.Name(id="codeflash_con", ctx=ast.Load()), attr="cursor", ctx=ast.Load()
),
args=[],
keywords=[],
lineno=node.lineno + 4,
col_offset=node.col_offset,
),
ast.Expr(
value=ast.Call(
func=ast.Attribute(
value=ast.Name(id="codeflash_cur", ctx=ast.Load()), attr="execute", ctx=ast.Load()
),
lineno=node.lineno + 4,
col_offset=node.col_offset,
args=[
ast.Constant(
value="CREATE TABLE IF NOT EXISTS test_results (test_module_path TEXT,"
" test_class_name TEXT, test_function_name TEXT, function_getting_tested TEXT,"
" loop_index INTEGER, iteration_id TEXT, runtime INTEGER, return_value BLOB)"
)
],
keywords=[],
),
ast.Expr(
value=ast.Call(
func=ast.Attribute(
value=ast.Name(id="codeflash_cur", ctx=ast.Load()), attr="execute", ctx=ast.Load()
),
args=[
ast.Constant(
value="CREATE TABLE IF NOT EXISTS test_results (test_module_path TEXT,"
" test_class_name TEXT, test_function_name TEXT, function_getting_tested TEXT,"
" loop_index INTEGER, iteration_id TEXT, runtime INTEGER, return_value BLOB)"
)
],
keywords=[],
lineno=node.lineno + 5,
col_offset=node.col_offset,
),
*node.body,
ast.Expr(
value=ast.Call(
func=ast.Attribute(
value=ast.Name(id="codeflash_con", ctx=ast.Load()), attr="close", ctx=ast.Load()
),
lineno=node.lineno + 5,
col_offset=node.col_offset,
),
]
+ node.body
+ [
ast.Expr(
value=ast.Call(
func=ast.Attribute(
value=ast.Name(id="codeflash_con", ctx=ast.Load()), attr="close", ctx=ast.Load()
),
args=[],
keywords=[],
)
args=[],
keywords=[],
)
]
)
),
]
return node
@ -261,7 +266,7 @@ class FunctionImportedAsVisitor(ast.NodeVisitor):
self.to_match = function.function_name
# TODO: Validate if the function imported is actually from the right module
def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
def visit_ImportFrom(self, node: ast.ImportFrom) -> None: # noqa: N802
for alias in node.names:
if alias.name == self.to_match and hasattr(alias, "asname") and alias.asname is not None:
if self.function.parents:

View file

@ -30,16 +30,24 @@ codeflash_wrap_string = """def codeflash_wrap(wrapped, test_module_name, test_cl
invocation_id = f'{{line_id}}_{{codeflash_test_index}}'
print(f"!######{{test_module_name}}:{{(test_class_name + '.' if test_class_name else '')}}{{test_name}}:{{function_name}}:{{loop_index}}:{{invocation_id}}######!")
gc.disable()
counter = time.perf_counter_ns()
return_value = wrapped(*args, **kwargs)
codeflash_duration = time.perf_counter_ns() - counter
gc.enable()
exception = None
try:
counter = time.perf_counter_ns()
return_value = wrapped(*args, **kwargs)
codeflash_duration = time.perf_counter_ns() - counter
except Exception as e:
exception = e
codeflash_duration = time.perf_counter_ns() - counter
finally:
gc.enable()
if loop_index == 1:
pickled_return_value = pickle.dumps(return_value)
pickled_return_value = pickle.dumps(exception) if exception else pickle.dumps(return_value)
else:
pickled_return_value = None
codeflash_cur.execute('INSERT INTO test_results VALUES (?, ?, ?, ?, ?, ?, ?, ?)', (test_module_name, test_class_name, test_name, function_name, loop_index, invocation_id, codeflash_duration, pickled_return_value))
codeflash_con.commit()
if exception:
raise exception
return return_value
"""
@ -92,16 +100,24 @@ def codeflash_wrap(wrapped, test_module_name, test_class_name, test_name, functi
expected += """print(f'!######{{test_module_name}}:{{(test_class_name + '.' if test_class_name else '')}}{{test_name}}:{{function_name}}:{{loop_index}}:{{invocation_id}}######!')"""
expected += """
gc.disable()
counter = time.perf_counter_ns()
return_value = wrapped(*args, **kwargs)
codeflash_duration = time.perf_counter_ns() - counter
gc.enable()
exception = None
try:
counter = time.perf_counter_ns()
return_value = wrapped(*args, **kwargs)
codeflash_duration = time.perf_counter_ns() - counter
except Exception as e:
exception = e
codeflash_duration = time.perf_counter_ns() - counter
finally:
gc.enable()
if loop_index == 1:
pickled_return_value = pickle.dumps(return_value)
pickled_return_value = pickle.dumps(exception) if exception else pickle.dumps(return_value)
else:
pickled_return_value = None
codeflash_cur.execute('INSERT INTO test_results VALUES (?, ?, ?, ?, ?, ?, ?, ?)', (test_module_name, test_class_name, test_name, function_name, loop_index, invocation_id, codeflash_duration, pickled_return_value))
codeflash_con.commit()
if exception:
raise exception
return return_value
class TestPigLatin(unittest.TestCase):
@ -189,16 +205,24 @@ def codeflash_wrap(wrapped, test_module_name, test_class_name, test_name, functi
expected += """print(f'!######{{test_module_name}}:{{(test_class_name + '.' if test_class_name else '')}}{{test_name}}:{{function_name}}:{{loop_index}}:{{invocation_id}}######!')"""
expected += """
gc.disable()
counter = time.perf_counter_ns()
return_value = wrapped(*args, **kwargs)
codeflash_duration = time.perf_counter_ns() - counter
gc.enable()
exception = None
try:
counter = time.perf_counter_ns()
return_value = wrapped(*args, **kwargs)
codeflash_duration = time.perf_counter_ns() - counter
except Exception as e:
exception = e
codeflash_duration = time.perf_counter_ns() - counter
finally:
gc.enable()
if loop_index == 1:
pickled_return_value = pickle.dumps(return_value)
pickled_return_value = pickle.dumps(exception) if exception else pickle.dumps(return_value)
else:
pickled_return_value = None
codeflash_cur.execute('INSERT INTO test_results VALUES (?, ?, ?, ?, ?, ?, ?, ?)', (test_module_name, test_class_name, test_name, function_name, loop_index, invocation_id, codeflash_duration, pickled_return_value))
codeflash_con.commit()
if exception:
raise exception
return return_value
def test_prepare_image_for_yolo():
@ -1809,16 +1833,24 @@ def codeflash_wrap(wrapped, test_module_name, test_class_name, test_name, functi
expected += """ print(f'!######{{test_module_name}}:{{(test_class_name + '.' if test_class_name else '')}}{{test_name}}:{{function_name}}:{{loop_index}}:{{invocation_id}}######!')"""
expected += """
gc.disable()
counter = time.perf_counter_ns()
return_value = wrapped(*args, **kwargs)
codeflash_duration = time.perf_counter_ns() - counter
gc.enable()
exception = None
try:
counter = time.perf_counter_ns()
return_value = wrapped(*args, **kwargs)
codeflash_duration = time.perf_counter_ns() - counter
except Exception as e:
exception = e
codeflash_duration = time.perf_counter_ns() - counter
finally:
gc.enable()
if loop_index == 1:
pickled_return_value = pickle.dumps(return_value)
pickled_return_value = pickle.dumps(exception) if exception else pickle.dumps(return_value)
else:
pickled_return_value = None
codeflash_cur.execute('INSERT INTO test_results VALUES (?, ?, ?, ?, ?, ?, ?, ?)', (test_module_name, test_class_name, test_name, function_name, loop_index, invocation_id, codeflash_duration, pickled_return_value))
codeflash_con.commit()
if exception:
raise exception
return return_value
def test_code_replacement10() -> None: