Refactored to add a new category for line profiler tests

This commit is contained in:
Aseem Saxena 2025-03-25 17:46:05 -07:00
parent 81a6b78973
commit 89e72b7346
8 changed files with 181 additions and 100 deletions

View file

@ -0,0 +1,6 @@
from code_to_optimize.bubble_sort_in_class import BubbleSortClass
def sort_classmethod(x):
    """Sort *x* by delegating to BubbleSortClass.sorter on a fresh instance."""
    sorter_instance = BubbleSortClass()
    return sorter_instance.sorter(x)

View file

@ -3,22 +3,60 @@ import libcst as cst
from pathlib import Path
from codeflash.code_utils.code_utils import get_run_tmp_file
def add_decorator_cst(module_node, function_path, decorator_name):
    """Add a decorator to a function or method definition in a LibCST module node.

    Args:
        module_node: LibCST ``Module`` node to transform.
        function_path: Dotted path to the target (e.g. ``'function_name'`` or
            ``'ClassName.method_name'``).
        decorator_name: Name of the decorator to add (inserted first, so it is
            the outermost decorator).

    Returns:
        The transformed LibCST module. Idempotent: if the decorator is already
        present on the target, the node is returned unchanged.
    """
    path_parts = function_path.split('.')

    class AddDecoratorTransformer(cst.CSTTransformer):
        def __init__(self):
            super().__init__()
            # Name of the class we are currently inside, or None at module level.
            self.current_class = None

        def visit_ClassDef(self, node):
            # Track entry into the class named in a 'Class.method' path.
            if len(path_parts) > 1 and node.name.value == path_parts[0]:
                self.current_class = node.name.value
            return True

        def leave_ClassDef(self, original_node, updated_node):
            # Reset class tracking when leaving the tracked class.
            if self.current_class == original_node.name.value:
                self.current_class = None
            return updated_node

        def leave_FunctionDef(self, original_node, updated_node):
            # Standalone function: single-part path, and not inside any class
            # (avoids decorating a same-named method by accident).
            if (len(path_parts) == 1
                    and original_node.name.value == path_parts[0]
                    and self.current_class is None):
                return self._add_decorator(updated_node)
            # Class method: two-part path matching enclosing class + method name.
            if (len(path_parts) == 2
                    and self.current_class == path_parts[0]
                    and original_node.name.value == path_parts[1]):
                return self._add_decorator(updated_node)
            return updated_node

        def _add_decorator(self, node):
            # Skip if the decorator already exists on this node (idempotence).
            for decorator in node.decorators:
                if (isinstance(decorator.decorator, cst.Name) and
                        decorator.decorator.value == decorator_name):
                    return node
            new_decorator = cst.Decorator(
                decorator=cst.Name(value=decorator_name)
            )
            updated_decorators = [new_decorator, *node.decorators]
            return node.with_changes(decorators=updated_decorators)

    transformer = AddDecoratorTransformer()
    return module_node.visit(transformer)
@ -83,13 +121,23 @@ class ImportAdder(cst.CSTTransformer):
self.has_import = True
def add_decorator_imports(file_paths, fn_list, db_file):
def add_decorator_imports(function_to_optimize, code_context):
#self.function_to_optimize, file_path_to_helper_classes, self.test_cfg.tests_root
#todo change function signature to get filepaths of fn, helpers and db
# modify libcst parser to visit with qualified name
file_paths = list()
fn_list = list()
db_file = get_run_tmp_file(Path("baseline"))
file_paths.append(function_to_optimize.file_path)
fn_list.append(function_to_optimize.qualified_name)
for elem in code_context.helper_functions:
file_paths.append(elem.file_path)
fn_list.append(elem.qualified_name)
"""Adds a decorator to a function in a Python file."""
for file_path, fn_name in zip(file_paths, fn_list):
#open file
with open(file_path, "r", encoding="utf-8") as file:
file_contents = file.read()
# parse to cst
module_node = cst.parse_module(file_contents)
# add decorator
@ -97,7 +145,6 @@ def add_decorator_imports(file_paths, fn_list, db_file):
# add imports
# Create a transformer to add the import
transformer = ImportAdder("from line_profiler import profile")
# Apply the transformer to add the import
module_node = module_node.visit(transformer)
modified_code = isort.code(module_node.code, float_to_top=True)
@ -110,6 +157,7 @@ def add_decorator_imports(file_paths, fn_list, db_file):
modified_code = add_profile_enable(file_contents,db_file)
with open(file_paths[0],'w') as f:
f.write(modified_code)
return db_file
def prepare_lprofiler_files(prefix: str = "") -> tuple[Path]:

View file

@ -216,9 +216,9 @@ class FunctionParent:
class OriginalCodeBaseline(BaseModel):
    """Baseline measurements for the original (pre-optimization) code."""

    behavioral_test_results: TestResults      # correctness test outcomes
    benchmarking_test_results: TestResults    # performance test outcomes
    lprofiler_test_results: str               # rendered line-profiler output
    runtime: int                              # measured baseline runtime
    coverage_results: Optional[CoverageData]  # None when coverage was not collected
class CoverageStatus(Enum):
@ -512,3 +512,4 @@ class FunctionCoverage:
class TestingMode(enum.Enum):
    """Kinds of test runs the optimizer can dispatch."""

    BEHAVIOR = "behavior"        # correctness / behavioral verification run
    PERFORMANCE = "performance"  # runtime benchmarking run
    LPROF = "lprof"              # line-profiler instrumentation run

View file

@ -37,10 +37,10 @@ from codeflash.code_utils.config_consts import (
)
from codeflash.code_utils.formatter import format_code, sort_imports
from codeflash.code_utils.instrument_existing_tests import inject_profiling_into_existing_test
from codeflash.code_utils.lprof_utils import add_decorator_imports
from codeflash.code_utils.remove_generated_tests import remove_functions_from_generated_tests
from codeflash.code_utils.static_analysis import get_first_top_level_function_or_method_ast
from codeflash.code_utils.time_utils import humanize_runtime
from codeflash.code_utils.lprof_utils import add_decorator_imports, prepare_lprofiler_files
from codeflash.context import code_context_extractor
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.either import Failure, Success, is_successful
@ -65,10 +65,10 @@ from codeflash.telemetry.posthog_cf import ph
from codeflash.verification.concolic_testing import generate_concolic_tests
from codeflash.verification.equivalence import compare_test_results
from codeflash.verification.instrument_codeflash_capture import instrument_codeflash_capture
from codeflash.verification.parse_test_output import parse_test_results
from codeflash.verification.parse_lprof_test_output import parse_lprof_results
from codeflash.verification.parse_test_output import parse_test_results
from codeflash.verification.test_results import TestResults, TestType
from codeflash.verification.test_runner import run_behavioral_tests, run_benchmarking_tests
from codeflash.verification.test_runner import run_behavioral_tests, run_benchmarking_tests, run_lprof_tests
from codeflash.verification.verification_utils import get_test_file_path
from codeflash.verification.verifier import generate_tests
@ -78,7 +78,7 @@ if TYPE_CHECKING:
from codeflash.either import Result
from codeflash.models.models import CoverageData, FunctionSource, OptimizedCandidate
from codeflash.verification.verification_utils import TestConfig
from collections import deque
class FunctionOptimizer:
def __init__(
@ -209,6 +209,7 @@ class FunctionOptimizer:
and "." in function_source.qualified_name
):
file_path_to_helper_classes[function_source.file_path].add(function_source.qualified_name.split(".")[0])
baseline_result = self.establish_original_code_baseline( # this needs better typing
code_context=code_context,
original_helper_code=original_helper_code,
@ -232,27 +233,7 @@ class FunctionOptimizer:
return Failure("The threshold for test coverage was not met.")
best_optimization = None
lprof_generated_results = []
logger.info(f"Adding more candidates based on lineprof info, calling ai service")
with concurrent.futures.ThreadPoolExecutor(max_workers= N_TESTS_TO_GENERATE + 2) as executor:
future_optimization_candidates_lp = executor.submit(self.aiservice_client.optimize_python_code_line_profiler,
source_code=code_context.read_writable_code,
dependency_code=code_context.read_only_context_code,
trace_id=self.function_trace_id,
line_profiler_results=original_code_baseline.lprof_results,
num_candidates = 10,
experiment_metadata = None)
future = [future_optimization_candidates_lp]
concurrent.futures.wait(future)
lprof_generated_results = future[0].result()
if len(lprof_generated_results)==0:
logger.info(f"Generated tests with line profiler failed.")
else:
logger.info(f"Generated tests with line profiler succeeded. Appending to optimization candidates.")
logger.info(f"initial optimization candidates: {len(optimizations_set.control)}")
optimizations_set.control.extend(lprof_generated_results)
logger.info(f"After adding optimization candidates: {len(optimizations_set.control)}")
#append to optimization candidates
for _u, candidates in enumerate([optimizations_set.control, optimizations_set.experiment]):
if candidates is None:
continue
@ -782,7 +763,7 @@ class FunctionOptimizer:
with progress_bar(f"Establishing original code baseline for {self.function_to_optimize.function_name}"):
assert (test_framework := self.args.test_framework) in ["pytest", "unittest"]
success = True
lprof_results = ''
test_env = os.environ.copy()
test_env["CODEFLASH_TEST_ITERATION"] = "0"
test_env["CODEFLASH_TRACER_DISABLE"] = "1"
@ -793,7 +774,6 @@ class FunctionOptimizer:
test_env["PYTHONPATH"] += os.pathsep + str(self.args.project_root)
coverage_results = None
lprofiler_results = None
# Instrument codeflash capture
try:
instrument_codeflash_capture(
@ -806,7 +786,6 @@ class FunctionOptimizer:
optimization_iteration=0,
testing_time=TOTAL_LOOPING_TIME,
enable_coverage=test_framework == "pytest",
enable_lprofiler=False,
code_context=code_context,
)
finally:
@ -822,42 +801,30 @@ class FunctionOptimizer:
return Failure("Failed to establish a baseline for the original code - bevhavioral tests failed.")
if not coverage_critic(coverage_results, self.args.test_framework):
return Failure("The threshold for test coverage was not met.")
#Running lprof now
try:
#add decorator here and import too
lprofiler_database_file = prepare_lprofiler_files("baseline")
#add decorator config to file, need to delete afterwards
files_to_instrument = [self.function_to_optimize.file_path]
fns_to_instrument = [self.function_to_optimize.function_name]
for helper_obj in code_context.helper_functions:
files_to_instrument.append(helper_obj.file_path)
fns_to_instrument.append(helper_obj.qualified_name)
add_decorator_imports(files_to_instrument,fns_to_instrument, lprofiler_database_file)
#output doesn't matter, just need to run it
lprof_cmd_results, _ = self.run_and_parse_tests(
testing_type=TestingMode.BEHAVIOR,
test_env=test_env,
test_files=self.test_files,
optimization_iteration=0,
testing_time=TOTAL_LOOPING_TIME,
enable_coverage=False,
enable_lprofiler=test_framework == "pytest",
code_context=code_context,
lprofiler_database_file=lprofiler_database_file,
)
#real magic happens here
lprof_results = parse_lprof_results(lprofiler_database_file)
except Exception as e:
logger.warning(f"Failed to run lprof for {self.function_to_optimize.function_name}. SKIPPING OPTIMIZING THIS FUNCTION.")
console.rule()
console.print(f"Failed to run lprof for {self.function_to_optimize.function_name}")
console.rule()
finally:
# Remove decorators and lineprof import
self.write_code_and_helpers(
self.function_to_optimize_source_code, original_helper_code, self.function_to_optimize.file_path
)
if test_framework == "pytest":
try:
lprofiler_database_file = add_decorator_imports(
self.function_to_optimize, code_context)
lprof_results, _ = self.run_and_parse_tests(
testing_type=TestingMode.LPROF,
test_env=test_env,
test_files=self.test_files,
optimization_iteration=0,
testing_time=TOTAL_LOOPING_TIME,
enable_coverage=False,
code_context=code_context,
lprofiler_database_file=lprofiler_database_file,
)
finally:
# Remove codeflash capture
self.write_code_and_helpers(
self.function_to_optimize_source_code, original_helper_code, self.function_to_optimize.file_path
)
if not lprof_results:
logger.warning(
f"Couldn't run line profiler for original function {self.function_to_optimize.function_name}"
)
console.rule()
benchmarking_results, _ = self.run_and_parse_tests(
testing_type=TestingMode.PERFORMANCE,
test_env=test_env,
@ -867,6 +834,7 @@ class FunctionOptimizer:
enable_coverage=False,
code_context=code_context,
)
else:
benchmarking_results = TestResults()
start_time: float = time.time()
@ -928,7 +896,7 @@ class FunctionOptimizer:
benchmarking_test_results=benchmarking_results,
runtime=total_timing,
coverage_results=coverage_results,
lprof_results=lprof_results,
lprofiler_test_results=lprof_results,
),
functions_to_remove,
)
@ -1060,12 +1028,11 @@ class FunctionOptimizer:
testing_time: float = TOTAL_LOOPING_TIME,
*,
enable_coverage: bool = False,
enable_lprofiler: bool = False,
pytest_min_loops: int = 5,
pytest_max_loops: int = 100_000,
code_context: CodeOptimizationContext | None = None,
unittest_loop_index: int | None = None,
lprofiler_database_file: str | None = None,
lprofiler_database_file: Path | None = None,
) -> tuple[TestResults, CoverageData | None]:
coverage_database_file = None
coverage_config_file = None
@ -1079,7 +1046,19 @@ class FunctionOptimizer:
pytest_timeout=INDIVIDUAL_TESTCASE_TIMEOUT,
verbose=True,
enable_coverage=enable_coverage,
enable_lprofiler=enable_lprofiler,
)
elif testing_type == TestingMode.LPROF:
result_file_path, run_result = run_lprof_tests(
test_files,
cwd=self.project_root,
test_env=test_env,
pytest_cmd=self.test_cfg.pytest_cmd,
pytest_timeout=INDIVIDUAL_TESTCASE_TIMEOUT,
pytest_target_runtime_seconds=testing_time,
pytest_min_loops=pytest_min_loops,
pytest_max_loops=pytest_max_loops,
test_framework=self.test_cfg.test_framework,
lprofiler_database_file=lprofiler_database_file
)
elif testing_type == TestingMode.PERFORMANCE:
result_file_path, run_result = run_benchmarking_tests(
@ -1108,7 +1087,7 @@ class FunctionOptimizer:
f"stdout: {run_result.stdout}\n"
f"stderr: {run_result.stderr}\n"
)
if not enable_lprofiler:
if testing_type in [TestingMode.BEHAVIOR, TestingMode.PERFORMANCE]:
results, coverage_results = parse_test_results(
test_xml_path=result_file_path,
test_files=test_files,
@ -1122,11 +1101,9 @@ class FunctionOptimizer:
coverage_database_file=coverage_database_file,
coverage_config_file=coverage_config_file,
)
return results, coverage_results
else:
#maintaining the function signature for the lprofiler
return TestResults(), None
results, coverage_results = parse_lprof_results(lprofiler_database_file=lprofiler_database_file)
return results, coverage_results
def generate_and_instrument_tests(
self,

View file

@ -107,4 +107,4 @@ def parse_lprof_results(lprofiler_database_file: Path | None) -> str:
else:
with open(lprofiler_database_file,'rb') as f:
stats = pickle.load(f)
return show_text(stats)
return show_text(stats), None

View file

@ -21,6 +21,7 @@ from codeflash.code_utils.code_utils import (
)
from codeflash.discovery.discover_unit_tests import discover_parameters_unittest
from codeflash.models.models import CoverageData, TestFiles
from codeflash.verification.parse_lprof_test_output import parse_lprof_results
from codeflash.verification.test_results import (
FunctionTestInvocation,
InvocationId,

View file

@ -12,7 +12,6 @@ import time
import warnings
from typing import TYPE_CHECKING, Any, Callable
from unittest import TestCase
import line_profiler
# PyTest Imports
import pytest

View file

@ -38,7 +38,6 @@ def run_behavioral_tests(
verbose: bool = False,
pytest_target_runtime_seconds: int = TOTAL_LOOPING_TIME,
enable_coverage: bool = False,
enable_lprofiler: bool = False,
) -> tuple[Path, subprocess.CompletedProcess, Path | None, Path | None]:
if test_framework == "pytest":
test_files: list[str] = []
@ -98,17 +97,6 @@ def run_behavioral_tests(
f"Result return code: {results.returncode}, "
f"{'Result stderr:' + str(results.stderr) if results.stderr else ''}"
)
elif enable_lprofiler:
pytest_test_env["LINE_PROFILE"]="1"
cmd = [SAFE_SYS_EXECUTABLE,"-m","pytest"]
results = execute_test_subprocess(
cmd+test_files, cwd=cwd, env=pytest_test_env, timeout=600
)
logger.debug(
f"Result return code: {results.returncode}, "
f"{'Result stderr:' + str(results.stderr) if results.stderr else ''}"
)
else:
blocklist_args = [f"-p no:{plugin}" for plugin in BEHAVIORAL_BLOCKLISTED_PLUGINS]
results = execute_test_subprocess(
@ -138,6 +126,67 @@ def run_behavioral_tests(
return result_file_path, results, coverage_database_file if enable_coverage else None, coverage_config_file if enable_coverage else None
def run_lprof_tests(
    test_paths: TestFiles,
    pytest_cmd: str,
    test_env: dict[str, str],
    cwd: Path,
    test_framework: str,
    *,
    pytest_target_runtime_seconds: float = TOTAL_LOOPING_TIME,
    verbose: bool = False,
    pytest_timeout: int | None = None,
    pytest_min_loops: int = 5,
    pytest_max_loops: int = 100_000,
    lprofiler_database_file: Path | None = None,
) -> tuple[Path, subprocess.CompletedProcess]:
    """Run the test files under line_profiler (pytest only).

    Sets ``LINE_PROFILE=1`` in the subprocess environment so line_profiler's
    pytest autoprofiling collects stats for decorated functions; the stats are
    expected to land in *lprofiler_database_file* (written by the instrumented
    code, not by this function -- TODO confirm against add_decorator_imports).

    Args:
        test_paths: Test files to execute.
        pytest_cmd: Configured pytest command ("pytest" or a custom invocation).
        test_env: Base environment for the test subprocess.
        cwd: Working directory for the subprocess.
        test_framework: Must be "pytest"; anything else raises.
        lprofiler_database_file: Path where profiling stats are expected.

    Returns:
        ``(lprofiler_database_file, completed_process)``.

    Raises:
        ValueError: If *test_framework* is not "pytest".

    NOTE: pytest_target_runtime_seconds / pytest_timeout / pytest_min_loops /
    pytest_max_loops / verbose are currently unused; kept for signature parity
    with the other run_*_tests helpers.
    """
    if test_framework != "pytest":
        msg = f"Unsupported test framework: {test_framework}"
        raise ValueError(msg)

    pytest_cmd_list = (
        shlex.split(f"{SAFE_SYS_EXECUTABLE} -m pytest", posix=IS_POSIX)
        if pytest_cmd == "pytest"
        else shlex.split(pytest_cmd)
    )

    test_files: list[str] = []
    for file in test_paths.test_files:
        if file.test_type in [TestType.REPLAY_TEST, TestType.EXISTING_UNIT_TEST] and file.tests_in_file:
            # Address individual tests as file::Class::function, stripping any
            # parametrization suffix ("[...]") so pytest runs all parameter sets.
            test_files.extend(
                [
                    str(file.benchmarking_file_path)
                    + "::"
                    + (test.test_class + "::" if test.test_class else "")
                    + (test.test_function.split("[", 1)[0] if "[" in test.test_function else test.test_function)
                    for test in file.tests_in_file
                ]
            )
        else:
            test_files.append(str(file.benchmarking_file_path))
    # De-duplicate repeated targets while keeping first-seen order deterministic.
    test_files = list(dict.fromkeys(test_files))

    pytest_test_env = test_env.copy()
    # Enables line_profiler's global profiling hook in the child process.
    pytest_test_env["LINE_PROFILE"] = "1"
    results = execute_test_subprocess(
        pytest_cmd_list + test_files,
        cwd=cwd,
        env=pytest_test_env,
        timeout=600,  # TODO: Make this dynamic
    )
    return lprofiler_database_file, results
def run_benchmarking_tests(
test_paths: TestFiles,