codeflash/codeflash/tracing/replay_test.py
2025-02-28 19:21:22 -08:00

163 lines
7.3 KiB
Python

from __future__ import annotations
import sqlite3
import textwrap
from collections.abc import Generator
from typing import Any, List, Optional
from codeflash.discovery.functions_to_optimize import FunctionProperties, inspect_top_level_functions_or_methods
from codeflash.tracing.tracing_utils import FunctionModules
def get_next_arg_and_return(
    trace_file: str, function_name: str, file_name: str, class_name: Optional[str] = None, num_to_get: int = 25
) -> Generator[Any, None, None]:
    """Yield the recorded argument payloads of traced calls to one function, oldest first.

    Rows are read from the ``function_calls`` table of the SQLite database at
    ``trace_file`` and ordered by their ``time_ns`` column.

    Args:
        trace_file: Path of the SQLite trace database.
        function_name: Value matched against the ``function`` column.
        file_name: Value matched against the ``filename`` column.
        class_name: When not ``None``, additionally matched against the
            ``classname`` column. When ``None``, no class filter is applied.
        num_to_get: Maximum number of rows to read (SQL ``LIMIT``).

    Yields:
        Column 7 of each matching row. The generated replay tests feed this
        value to ``pickle.loads``, so it is presumably the pickled call
        arguments -- confirm against the tracer's table schema.

    Raises:
        ValueError: If a matching row's first column is not ``"call"``.
    """
    if class_name is not None:
        query = "SELECT * FROM function_calls WHERE function = ? AND filename = ? AND classname = ? ORDER BY time_ns ASC LIMIT ?"
        params: tuple = (function_name, file_name, class_name, num_to_get)
    else:
        query = "SELECT * FROM function_calls WHERE function = ? AND filename = ? ORDER BY time_ns ASC LIMIT ?"
        params = (function_name, file_name, num_to_get)

    db = sqlite3.connect(trace_file)
    try:
        for row in db.cursor().execute(query, params):
            if row[0] != "call":
                raise ValueError("Invalid Trace event type")
            yield row[7]
    finally:
        # Runs when the generator is exhausted, closed early, or raises, so the
        # database handle is never left open.
        db.close()
def get_function_alias(module: str, function_name: str) -> str:
    """Return ``module`` with its dots turned into underscores, joined to ``function_name`` by ``_``.

    The result is used as the import alias (and as part of the test name) in
    the generated replay test, e.g. ``("pkg.mod", "fn")`` -> ``"pkg_mod_fn"``.
    """
    flattened_module = module.replace(".", "_")
    return f"{flattened_module}_{function_name}"
def create_trace_replay_test(
    trace_file: str, functions: List[FunctionModules], test_framework: str = "pytest", max_run_count: int = 100
) -> str:
    """Generate the source code of a test module that replays traced calls.

    The returned text consists of three parts:

    1. imports (``dill`` as ``pickle``, optionally ``unittest``,
       ``get_next_arg_and_return`` and one aliased import per replayable function),
    2. a metadata section defining ``functions`` and ``trace_file_path``,
    3. one ``test_<alias>`` function (pytest) or method of
       ``TestTracedFunctions`` (unittest) per replayable function. Each test
       loops over the recorded argument payloads, unpickles them and calls the
       target with them.

    Entries whose inspected properties report ``is_top_level`` as false are
    skipped, because they cannot be imported by name from the generated module.

    Args:
        trace_file: Path of the trace database; embedded verbatim as a raw
            string literal in the generated module.
        functions: Traced functions/methods to generate replay tests for.
        test_framework: ``"pytest"`` or ``"unittest"``.
        max_run_count: Passed as ``num_to_get`` in every generated
            ``get_next_arg_and_return`` call.

    Returns:
        The generated test module as a string.

    Raises:
        AssertionError: If ``test_framework`` is not ``"pytest"`` or ``"unittest"``.
    """
    assert test_framework in ["pytest", "unittest"]
    is_unittest = test_framework == "unittest"

    imports = f"""import dill as pickle
{"import unittest" if is_unittest else ""}
from codeflash.tracing.replay_test import get_next_arg_and_return
"""

    # TODO: Module can have "-" character if the module-root is ".". Need to handle that case
    function_properties: list[FunctionProperties] = [
        inspect_top_level_functions_or_methods(
            file_name=function.file_name,
            function_or_method_name=function.function_name,
            class_name=function.class_name,
            line_no=function.line_no,
        )
        for function in functions
    ]

    # Anything that is not top-level can't be imported and run in the replay test.
    replayable = [
        (function, function_property)
        for function, function_property in zip(functions, function_properties)
        if function_property.is_top_level
    ]

    # Methods are reached through their class, so the class is what gets
    # imported; plain functions are imported directly.
    function_imports = []
    for function, function_property in replayable:
        if function_property.is_staticmethod:
            imported_name = function_property.staticmethod_class_name
        elif function.class_name:
            imported_name = function.class_name
        else:
            imported_name = function.function_name
        function_imports.append(
            f"from {function.module_name} import {imported_name} as {get_function_alias(function.module_name, imported_name)}"
        )
    imports += "\n".join(function_imports)

    functions_to_optimize = [function.function_name for function in functions if function.function_name != "__init__"]
    metadata = f"""functions = {functions_to_optimize}
trace_file_path = r"{trace_file}"
"""  # trace_file_path path is parsed with regex later, format is important

    # The loop body lines in each template must be indented relative to the
    # `for` header, otherwise the generated test is not valid Python.
    test_function_body = textwrap.dedent(
        """\
        for arg_val_pkl in get_next_arg_and_return(trace_file=trace_file_path, function_name="{orig_function_name}", file_name=r"{file_name}", num_to_get={max_run_count}):
            args = pickle.loads(arg_val_pkl)
            ret = {function_name}({args})
        """
    )
    test_class_method_body = textwrap.dedent(
        """\
        for arg_val_pkl in get_next_arg_and_return(trace_file=trace_file_path, function_name="{orig_function_name}", file_name=r"{file_name}", class_name="{class_name}", num_to_get={max_run_count}):
            args = pickle.loads(arg_val_pkl){filter_variables}
            ret = {class_name_alias}{method_name}(**args)
        """
    )
    test_class_staticmethod_body = textwrap.dedent(
        """\
        for arg_val_pkl in get_next_arg_and_return(trace_file=trace_file_path, function_name="{orig_function_name}", file_name=r"{file_name}", num_to_get={max_run_count}):
            args = pickle.loads(arg_val_pkl){filter_variables}
            ret = {class_name_alias}{method_name}(**args)
        """
    )

    if is_unittest:
        self_param = "self"
        test_template = "\nclass TestTracedFunctions(unittest.TestCase):\n"
        def_prefix = "    "  # test methods live inside the TestCase class
        body_prefix = "        "
    else:
        self_param = ""
        test_template = ""
        def_prefix = ""
        body_prefix = "    "

    for func, func_property in replayable:
        # `__init__` is replayed by calling the class alias itself, so no
        # ".<name>" attribute access is appended for it.
        method_name = "." + func.function_name if func.function_name != "__init__" else ""

        if func_property.is_staticmethod:
            class_name_alias = get_function_alias(func.module_name, func_property.staticmethod_class_name)
            alias = get_function_alias(
                func.module_name, func_property.staticmethod_class_name + "_" + func.function_name
            )
            test_body = test_class_staticmethod_body.format(
                orig_function_name=func.function_name,
                file_name=func.file_name,
                class_name_alias=class_name_alias,
                method_name=method_name,
                max_run_count=max_run_count,
                filter_variables="",
            )
        elif func.class_name is None:
            alias = get_function_alias(func.module_name, func.function_name)
            test_body = test_function_body.format(
                function_name=alias,
                file_name=func.file_name,
                orig_function_name=func.function_name,
                max_run_count=max_run_count,
                args="**args" if func_property.has_args else "",
            )
        else:
            class_name_alias = get_function_alias(func.module_name, func.class_name)
            alias = get_function_alias(func.module_name, func.class_name + "_" + func.function_name)
            # Drop entries that must not be forwarded as keyword arguments
            # (presumably captured by the tracer alongside the real arguments).
            # The leading "\n    " keeps the extra statement inside the loop body.
            if func_property.is_classmethod:
                filter_variables = '\n    args.pop("cls", None)'
            elif func.function_name == "__init__":
                filter_variables = '\n    args.pop("__class__", None)'
            else:
                filter_variables = ""
            test_body = test_class_method_body.format(
                orig_function_name=func.function_name,
                file_name=func.file_name,
                class_name_alias=class_name_alias,
                class_name=func.class_name,
                method_name=method_name,
                max_run_count=max_run_count,
                filter_variables=filter_variables,
            )

        formatted_test_body = textwrap.indent(test_body, body_prefix)
        test_template += def_prefix
        test_template += f"def test_{alias}({self_param}):\n{formatted_test_body}\n"

    return imports + "\n" + metadata + "\n" + test_template