Revert helper functions definitions when they are not used anymore in the optimized FTO

This commit is contained in:
Saurabh Misra 2025-06-05 22:40:09 -07:00
parent 19dcbfb312
commit 14409f7a23
4 changed files with 1657 additions and 11 deletions

View file

@ -1,10 +1,17 @@
from __future__ import annotations
import ast
from collections import defaultdict
from dataclasses import dataclass, field
from pathlib import Path
from typing import Optional
import libcst as cst
from codeflash.cli_cmds.console import logger
from codeflash.code_utils.code_replacer import replace_function_definitions_in_module
from codeflash.models.models import CodeOptimizationContext, FunctionSource
@dataclass
class UsageInfo:
@ -480,3 +487,210 @@ def print_definitions(definitions: dict[str, UsageInfo]) -> None:
print(f" Used by qualified function: {info.used_by_qualified_function}")
print(f" Dependencies: {', '.join(sorted(info.dependencies)) if info.dependencies else 'None'}")
print()
def revert_unused_helper_functions(
project_root, unused_helpers: list[FunctionSource], original_helper_code: dict[Path, str]
) -> None:
"""Revert unused helper functions back to their original definitions.
Args:
unused_helpers: List of unused helper functions to revert
original_helper_code: Dictionary mapping file paths to their original code
"""
if not unused_helpers:
return
logger.info(f"Reverting {len(unused_helpers)} unused helper function(s) to original definitions")
# Group unused helpers by file path
unused_helpers_by_file = defaultdict(list)
for helper in unused_helpers:
unused_helpers_by_file[helper.file_path].append(helper)
# For each file, revert the unused helper functions to their original definitions
for file_path, helpers_in_file in unused_helpers_by_file.items():
if file_path in original_helper_code:
try:
# Read current file content
current_code = file_path.read_text(encoding="utf8")
# Get original code for this file
original_code = original_helper_code[file_path]
# Use the code replacer to selectively revert only the unused helper functions
helper_names = [helper.qualified_name for helper in helpers_in_file]
reverted_code = replace_function_definitions_in_module(
function_names=helper_names,
optimized_code=original_code, # Use original code as the "optimized" code to revert
module_abspath=file_path,
preexisting_objects=set(), # Empty set since we're reverting
project_root_path=project_root,
)
if reverted_code:
logger.debug(f"Reverted unused helpers in {file_path}: {', '.join(helper_names)}")
except Exception as e:
logger.error(f"Error reverting unused helpers in {file_path}: {e}")
def _analyze_imports_in_optimized_code(
optimized_ast: ast.AST, code_context: CodeOptimizationContext
) -> dict[str, set[str]]:
"""Analyze import statements in optimized code to map imported names to qualified helper names.
Args:
optimized_ast: The AST of the optimized code
code_context: The code optimization context containing helper functions
Returns:
Dictionary mapping imported names to sets of possible qualified helper names
"""
imported_names_map = defaultdict(set)
# Create a lookup of helper functions by their simple names and file paths
helpers_by_name = defaultdict(list)
helpers_by_file = defaultdict(list)
for helper in code_context.helper_functions:
if helper.jedi_definition.type != "class":
helpers_by_name[helper.only_function_name].append(helper)
module_name = helper.file_path.stem
helpers_by_file[module_name].append(helper)
# Analyze import statements in the optimized code
for node in ast.walk(optimized_ast):
if isinstance(node, ast.ImportFrom):
# Handle "from module import function" statements
if node.module:
module_name = node.module
for alias in node.names:
imported_name = alias.asname if alias.asname else alias.name
original_name = alias.name
# Find helpers that match this import
for helper in helpers_by_file.get(module_name, []):
if helper.only_function_name == original_name:
imported_names_map[imported_name].add(helper.qualified_name)
imported_names_map[imported_name].add(helper.fully_qualified_name)
elif isinstance(node, ast.Import):
# Handle "import module" statements
for alias in node.names:
imported_name = alias.asname if alias.asname else alias.name
module_name = alias.name
# For "import module" statements, functions would be called as module.function
for helper in helpers_by_file.get(module_name, []):
full_call = f"{imported_name}.{helper.only_function_name}"
imported_names_map[full_call].add(helper.qualified_name)
imported_names_map[full_call].add(helper.fully_qualified_name)
return dict(imported_names_map)
def detect_unused_helper_functions(
function_to_optimize, code_context: CodeOptimizationContext, optimized_code: str
) -> list[FunctionSource]:
"""Detect helper functions that are no longer called by the optimized entrypoint function.
Args:
code_context: The code optimization context containing helper functions
optimized_code: The optimized code to analyze
Returns:
List of FunctionSource objects representing unused helper functions
"""
try:
# Parse the optimized code to analyze function calls and imports
optimized_ast = ast.parse(optimized_code)
# Find the optimized entrypoint function
entrypoint_function_ast = None
for node in ast.walk(optimized_ast):
if isinstance(node, ast.FunctionDef) and node.name == function_to_optimize.function_name:
entrypoint_function_ast = node
break
if not entrypoint_function_ast:
logger.debug(f"Could not find entrypoint function {function_to_optimize.function_name} in optimized code")
return []
# First, analyze imports to build a mapping of imported names to their original qualified names
imported_names_map = _analyze_imports_in_optimized_code(optimized_ast, code_context)
# Extract all function calls in the entrypoint function
called_function_names = set()
for node in ast.walk(entrypoint_function_ast):
if isinstance(node, ast.Call):
if isinstance(node.func, ast.Name):
# Regular function call: function_name()
called_name = node.func.id
called_function_names.add(called_name)
# Also add the qualified name if this is an imported function
if called_name in imported_names_map:
called_function_names.update(imported_names_map[called_name])
elif isinstance(node.func, ast.Attribute):
# Method call: obj.method() or self.method() or module.function()
if isinstance(node.func.value, ast.Name):
if node.func.value.id == "self":
# self.method_name() -> add both method_name and ClassName.method_name
called_function_names.add(node.func.attr)
# For class methods, also add the qualified name
if hasattr(function_to_optimize, "parents") and function_to_optimize.parents:
class_name = function_to_optimize.parents[0].name
called_function_names.add(f"{class_name}.{node.func.attr}")
else:
# obj.method() or module.function()
attr_name = node.func.attr
called_function_names.add(attr_name)
called_function_names.add(f"{node.func.value.id}.{attr_name}")
# Check if this is a module.function call that maps to a helper
full_call = f"{node.func.value.id}.{attr_name}"
if full_call in imported_names_map:
called_function_names.update(imported_names_map[full_call])
# Handle nested attribute access like obj.attr.method()
else:
called_function_names.add(node.func.attr)
logger.debug(f"Functions called in optimized entrypoint: {called_function_names}")
logger.debug(f"Imported names mapping: {imported_names_map}")
# Find helper functions that are no longer called
unused_helpers = []
for helper_function in code_context.helper_functions:
if helper_function.jedi_definition.type != "class":
# Check if the helper function is called using multiple name variants
helper_qualified_name = helper_function.qualified_name
helper_simple_name = helper_function.only_function_name
helper_fully_qualified_name = helper_function.fully_qualified_name
# Create a set of all possible names this helper might be called by
possible_call_names = {helper_qualified_name, helper_simple_name, helper_fully_qualified_name}
# For cross-file helpers, also consider module-based calls
if helper_function.file_path != function_to_optimize.file_path:
# Add potential module.function combinations
module_name = helper_function.file_path.stem
possible_call_names.add(f"{module_name}.{helper_simple_name}")
# Check if any of the possible names are in the called functions
is_called = bool(possible_call_names.intersection(called_function_names))
if not is_called:
unused_helpers.append(helper_function)
logger.debug(f"Helper function {helper_qualified_name} is not called in optimized code")
logger.debug(f" Checked names: {possible_call_names}")
else:
logger.debug(f"Helper function {helper_qualified_name} is still called in optimized code")
logger.debug(f" Called via: {possible_call_names.intersection(called_function_names)}")
return unused_helpers
except Exception as e:
logger.debug(f"Error detecting unused helper functions: {e}")
return []

View file

@ -43,6 +43,7 @@ from codeflash.code_utils.remove_generated_tests import remove_functions_from_ge
from codeflash.code_utils.static_analysis import get_first_top_level_function_or_method_ast
from codeflash.code_utils.time_utils import humanize_runtime
from codeflash.context import code_context_extractor
from codeflash.context.unused_definition_remover import detect_unused_helper_functions, revert_unused_helper_functions
from codeflash.either import Failure, Success, is_successful
from codeflash.models.ExperimentMetadata import ExperimentMetadata
from codeflash.models.models import (
@ -298,7 +299,9 @@ class FunctionOptimizer:
self.log_successful_optimization(explanation, generated_tests, exp_type)
self.replace_function_and_helpers_with_optimized_code(
code_context=code_context, optimized_code=best_optimization.candidate.source_code
code_context=code_context,
optimized_code=best_optimization.candidate.source_code,
original_helper_code=original_helper_code,
)
new_code, new_helper_code = self.reformat_code_and_helpers(
@ -612,7 +615,7 @@ class FunctionOptimizer:
return new_code, new_helper_code
def replace_function_and_helpers_with_optimized_code(
self, code_context: CodeOptimizationContext, optimized_code: str
self, code_context: CodeOptimizationContext, optimized_code: str, original_helper_code: str
) -> bool:
did_update = False
read_writable_functions_by_file_path = defaultdict(set)
@ -630,6 +633,12 @@ class FunctionOptimizer:
preexisting_objects=code_context.preexisting_objects,
project_root_path=self.project_root,
)
unused_helpers = detect_unused_helper_functions(self.function_to_optimize, code_context, optimized_code)
# Revert unused helper functions to their original definitions
if unused_helpers:
revert_unused_helper_functions(self.project_root, unused_helpers, original_helper_code)
return did_update
def get_code_optimization_context(self) -> Result[CodeOptimizationContext, str]:

View file

@ -69,7 +69,7 @@ def sorter(arr):
original_helper_code[helper_function_path] = helper_code
func_optimizer.args = Args()
func_optimizer.replace_function_and_helpers_with_optimized_code(
code_context=code_context, optimized_code=optimized_code
code_context=code_context, optimized_code=optimized_code, original_helper_code=original_helper_code
)
final_output = code_path.read_text(encoding="utf-8")
assert "inconsequential_var = '123'" in final_output
@ -804,7 +804,8 @@ class MainClass:
self.name = name
def main_method(self):
return HelperClass(self.name).helper_method()"""
return HelperClass(self.name).helper_method()
"""
file_path = Path(__file__).resolve()
func_top_optimize = FunctionToOptimize(
function_name="main_method", file_path=file_path, parents=[FunctionParent("MainClass", "ClassDef")]
@ -1662,6 +1663,7 @@ print("Hello world")
)
assert new_code == original_code
def test_global_reassignment() -> None:
original_code = """a=1
print("Hello world")
@ -1733,7 +1735,7 @@ class NewClass:
original_helper_code[helper_function_path] = helper_code
func_optimizer.args = Args()
func_optimizer.replace_function_and_helpers_with_optimized_code(
code_context=code_context, optimized_code=optimized_code
code_context=code_context, optimized_code=optimized_code, original_helper_code=original_helper_code
)
new_code = code_path.read_text(encoding="utf-8")
code_path.unlink(missing_ok=True)
@ -1809,7 +1811,7 @@ a=2
original_helper_code[helper_function_path] = helper_code
func_optimizer.args = Args()
func_optimizer.replace_function_and_helpers_with_optimized_code(
code_context=code_context, optimized_code=optimized_code
code_context=code_context, optimized_code=optimized_code, original_helper_code=original_helper_code
)
new_code = code_path.read_text(encoding="utf-8")
code_path.unlink(missing_ok=True)
@ -1886,7 +1888,7 @@ class NewClass:
original_helper_code[helper_function_path] = helper_code
func_optimizer.args = Args()
func_optimizer.replace_function_and_helpers_with_optimized_code(
code_context=code_context, optimized_code=optimized_code
code_context=code_context, optimized_code=optimized_code, original_helper_code=original_helper_code
)
new_code = code_path.read_text(encoding="utf-8")
code_path.unlink(missing_ok=True)
@ -1962,7 +1964,7 @@ class NewClass:
original_helper_code[helper_function_path] = helper_code
func_optimizer.args = Args()
func_optimizer.replace_function_and_helpers_with_optimized_code(
code_context=code_context, optimized_code=optimized_code
code_context=code_context, optimized_code=optimized_code, original_helper_code=original_helper_code
)
new_code = code_path.read_text(encoding="utf-8")
code_path.unlink(missing_ok=True)
@ -2039,7 +2041,7 @@ class NewClass:
original_helper_code[helper_function_path] = helper_code
func_optimizer.args = Args()
func_optimizer.replace_function_and_helpers_with_optimized_code(
code_context=code_context, optimized_code=optimized_code
code_context=code_context, optimized_code=optimized_code, original_helper_code=original_helper_code
)
new_code = code_path.read_text(encoding="utf-8")
code_path.unlink(missing_ok=True)
@ -2127,8 +2129,8 @@ a = 6
original_helper_code[helper_function_path] = helper_code
func_optimizer.args = Args()
func_optimizer.replace_function_and_helpers_with_optimized_code(
code_context=code_context, optimized_code=optimized_code
code_context=code_context, optimized_code=optimized_code, original_helper_code=original_helper_code
)
new_code = code_path.read_text(encoding="utf-8")
code_path.unlink(missing_ok=True)
assert new_code.rstrip() == expected_code.rstrip()
assert new_code.rstrip() == expected_code.rstrip()

File diff suppressed because it is too large Load diff