mirror of
https://github.com/codeflash-ai/codeflash.git
synced 2026-05-04 18:25:17 +00:00
Revert helper function definitions when they are no longer used in the optimized FTO (function to optimize)
This commit is contained in:
parent
19dcbfb312
commit
14409f7a23
4 changed files with 1657 additions and 11 deletions
|
|
@ -1,10 +1,17 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import ast
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import libcst as cst
|
||||
|
||||
from codeflash.cli_cmds.console import logger
|
||||
from codeflash.code_utils.code_replacer import replace_function_definitions_in_module
|
||||
from codeflash.models.models import CodeOptimizationContext, FunctionSource
|
||||
|
||||
|
||||
@dataclass
|
||||
class UsageInfo:
|
||||
|
|
@ -480,3 +487,210 @@ def print_definitions(definitions: dict[str, UsageInfo]) -> None:
|
|||
print(f" Used by qualified function: {info.used_by_qualified_function}")
|
||||
print(f" Dependencies: {', '.join(sorted(info.dependencies)) if info.dependencies else 'None'}")
|
||||
print()
|
||||
|
||||
|
||||
def revert_unused_helper_functions(
    project_root, unused_helpers: list[FunctionSource], original_helper_code: dict[Path, str]
) -> None:
    """Revert unused helper functions back to their original definitions.

    Helpers are grouped by the file that defines them, and each file is
    processed independently: a failure while reverting one file is logged and
    does not prevent the remaining files from being processed.

    Args:
        project_root: Project root, forwarded unchanged to the code replacer
            as ``project_root_path``.
        unused_helpers: List of unused helper functions to revert
        original_helper_code: Dictionary mapping file paths to their original code.
            Helpers whose file has no entry here are left untouched.

    """
    if not unused_helpers:
        return

    logger.info(f"Reverting {len(unused_helpers)} unused helper function(s) to original definitions")

    # Group unused helpers by file path so each file is rewritten only once
    unused_helpers_by_file = defaultdict(list)
    for helper in unused_helpers:
        unused_helpers_by_file[helper.file_path].append(helper)

    for file_path, helpers_in_file in unused_helpers_by_file.items():
        # Without the original source for this file there is nothing to revert to
        if file_path not in original_helper_code:
            continue

        try:
            helper_names = [helper.qualified_name for helper in helpers_in_file]

            # The replacer is given the module path itself, so the current file
            # content is not read here. The original code is passed as the
            # "optimized" code, which swaps only the named helpers back.
            reverted_code = replace_function_definitions_in_module(
                function_names=helper_names,
                optimized_code=original_helper_code[file_path],
                module_abspath=file_path,
                preexisting_objects=set(),  # Empty set since we're reverting
                project_root_path=project_root,
            )

            if reverted_code:
                logger.debug(f"Reverted unused helpers in {file_path}: {', '.join(helper_names)}")

        except Exception as e:
            # Best effort: report the failure and continue with the next file
            logger.error(f"Error reverting unused helpers in {file_path}: {e}")
|
||||
|
||||
|
||||
def _analyze_imports_in_optimized_code(
    optimized_ast: ast.AST, code_context: CodeOptimizationContext
) -> dict[str, set[str]]:
    """Analyze import statements in optimized code to map imported names to qualified helper names.

    Helpers are matched to imports by the stem of the file that defines them,
    compared against the last component of the imported module path. This
    means ``from pkg.utils import helper`` and ``from utils import helper``
    both resolve to a helper defined in ``utils.py``.

    Args:
        optimized_ast: The AST of the optimized code
        code_context: The code optimization context containing helper functions

    Returns:
        Dictionary mapping imported names to sets of possible qualified helper names.
        For ``from m import f [as g]`` the key is the bound name (``g`` or ``f``);
        for ``import m [as n]`` the key is ``"<bound module name>.<function>"``.

    """
    imported_names_map: dict[str, set[str]] = defaultdict(set)

    # Index non-class helpers by the stem of their defining file
    helpers_by_module = defaultdict(list)
    for helper in code_context.helper_functions:
        if helper.jedi_definition.type != "class":
            helpers_by_module[helper.file_path.stem].append(helper)

    for node in ast.walk(optimized_ast):
        if isinstance(node, ast.ImportFrom):
            # "from . import x" has no module name and cannot be matched to a file stem
            if not node.module:
                continue
            # Only the last dotted component can match a file stem
            module_stem = node.module.rpartition(".")[2]
            for alias in node.names:
                imported_name = alias.asname or alias.name
                for helper in helpers_by_module.get(module_stem, []):
                    if helper.only_function_name == alias.name:
                        imported_names_map[imported_name].add(helper.qualified_name)
                        imported_names_map[imported_name].add(helper.fully_qualified_name)

        elif isinstance(node, ast.Import):
            # For "import module" statements, functions would be called as module.function
            for alias in node.names:
                imported_name = alias.asname or alias.name
                module_stem = alias.name.rpartition(".")[2]
                for helper in helpers_by_module.get(module_stem, []):
                    full_call = f"{imported_name}.{helper.only_function_name}"
                    imported_names_map[full_call].add(helper.qualified_name)
                    imported_names_map[full_call].add(helper.fully_qualified_name)

    return dict(imported_names_map)


def detect_unused_helper_functions(
    function_to_optimize, code_context: CodeOptimizationContext, optimized_code: str
) -> list[FunctionSource]:
    """Detect helper functions that are no longer called by the optimized entrypoint function.

    Only calls that appear inside the entrypoint's own body are considered.
    Any error during analysis (including a syntax error in ``optimized_code``)
    is logged at debug level and reported as "no unused helpers".

    Args:
        function_to_optimize: The entrypoint being optimized. Its ``function_name``,
            ``file_path`` and (if present) ``parents`` attributes are read.
        code_context: The code optimization context containing helper functions
        optimized_code: The optimized code to analyze

    Returns:
        List of FunctionSource objects representing unused helper functions

    """
    try:
        # Parse the optimized code to analyze function calls and imports
        optimized_ast = ast.parse(optimized_code)

        # Find the optimized entrypoint function; it may be sync or async
        entrypoint_function_ast = None
        for node in ast.walk(optimized_ast):
            if (
                isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef))
                and node.name == function_to_optimize.function_name
            ):
                entrypoint_function_ast = node
                break

        if entrypoint_function_ast is None:
            logger.debug(f"Could not find entrypoint function {function_to_optimize.function_name} in optimized code")
            return []

        # First, analyze imports to build a mapping of imported names to their original qualified names
        imported_names_map = _analyze_imports_in_optimized_code(optimized_ast, code_context)

        # Extract all function calls in the entrypoint function
        called_function_names = set()
        for node in ast.walk(entrypoint_function_ast):
            if not isinstance(node, ast.Call):
                continue
            func = node.func

            if isinstance(func, ast.Name):
                # Regular function call: function_name()
                called_function_names.add(func.id)
                # Also add the qualified names if this is an imported (possibly aliased) function
                called_function_names.update(imported_names_map.get(func.id, ()))

            elif isinstance(func, ast.Attribute):
                if not isinstance(func.value, ast.Name):
                    # Nested attribute access like obj.attr.method(): only the final name is known
                    called_function_names.add(func.attr)
                elif func.value.id == "self":
                    # self.method_name() -> add both method_name and ClassName.method_name
                    called_function_names.add(func.attr)
                    if hasattr(function_to_optimize, "parents") and function_to_optimize.parents:
                        class_name = function_to_optimize.parents[0].name
                        called_function_names.add(f"{class_name}.{func.attr}")
                else:
                    # obj.method() or module.function()
                    full_call = f"{func.value.id}.{func.attr}"
                    called_function_names.add(func.attr)
                    called_function_names.add(full_call)
                    # Check if this is a module.function call that maps to a helper
                    called_function_names.update(imported_names_map.get(full_call, ()))

        logger.debug(f"Functions called in optimized entrypoint: {called_function_names}")
        logger.debug(f"Imported names mapping: {imported_names_map}")

        # Find helper functions that are no longer called
        unused_helpers = []
        for helper_function in code_context.helper_functions:
            if helper_function.jedi_definition.type == "class":
                continue

            helper_qualified_name = helper_function.qualified_name
            helper_simple_name = helper_function.only_function_name

            # Create a set of all possible names this helper might be called by
            possible_call_names = {
                helper_qualified_name,
                helper_simple_name,
                helper_function.fully_qualified_name,
            }

            # For cross-file helpers, also consider module-based calls
            if helper_function.file_path != function_to_optimize.file_path:
                possible_call_names.add(f"{helper_function.file_path.stem}.{helper_simple_name}")

            matched_names = possible_call_names.intersection(called_function_names)
            if matched_names:
                logger.debug(f"Helper function {helper_qualified_name} is still called in optimized code")
                logger.debug(f"  Called via: {matched_names}")
            else:
                unused_helpers.append(helper_function)
                logger.debug(f"Helper function {helper_qualified_name} is not called in optimized code")
                logger.debug(f"  Checked names: {possible_call_names}")

        return unused_helpers

    except Exception as e:
        # Best effort: on any analysis failure report nothing as unused so no helper is reverted
        logger.debug(f"Error detecting unused helper functions: {e}")
        return []
|
||||
|
|
|
|||
|
|
@ -43,6 +43,7 @@ from codeflash.code_utils.remove_generated_tests import remove_functions_from_ge
|
|||
from codeflash.code_utils.static_analysis import get_first_top_level_function_or_method_ast
|
||||
from codeflash.code_utils.time_utils import humanize_runtime
|
||||
from codeflash.context import code_context_extractor
|
||||
from codeflash.context.unused_definition_remover import detect_unused_helper_functions, revert_unused_helper_functions
|
||||
from codeflash.either import Failure, Success, is_successful
|
||||
from codeflash.models.ExperimentMetadata import ExperimentMetadata
|
||||
from codeflash.models.models import (
|
||||
|
|
@ -298,7 +299,9 @@ class FunctionOptimizer:
|
|||
self.log_successful_optimization(explanation, generated_tests, exp_type)
|
||||
|
||||
self.replace_function_and_helpers_with_optimized_code(
|
||||
code_context=code_context, optimized_code=best_optimization.candidate.source_code
|
||||
code_context=code_context,
|
||||
optimized_code=best_optimization.candidate.source_code,
|
||||
original_helper_code=original_helper_code,
|
||||
)
|
||||
|
||||
new_code, new_helper_code = self.reformat_code_and_helpers(
|
||||
|
|
@ -612,7 +615,7 @@ class FunctionOptimizer:
|
|||
return new_code, new_helper_code
|
||||
|
||||
def replace_function_and_helpers_with_optimized_code(
|
||||
self, code_context: CodeOptimizationContext, optimized_code: str
|
||||
self, code_context: CodeOptimizationContext, optimized_code: str, original_helper_code: str
|
||||
) -> bool:
|
||||
did_update = False
|
||||
read_writable_functions_by_file_path = defaultdict(set)
|
||||
|
|
@ -630,6 +633,12 @@ class FunctionOptimizer:
|
|||
preexisting_objects=code_context.preexisting_objects,
|
||||
project_root_path=self.project_root,
|
||||
)
|
||||
unused_helpers = detect_unused_helper_functions(self.function_to_optimize, code_context, optimized_code)
|
||||
|
||||
# Revert unused helper functions to their original definitions
|
||||
if unused_helpers:
|
||||
revert_unused_helper_functions(self.project_root, unused_helpers, original_helper_code)
|
||||
|
||||
return did_update
|
||||
|
||||
def get_code_optimization_context(self) -> Result[CodeOptimizationContext, str]:
|
||||
|
|
|
|||
|
|
@ -69,7 +69,7 @@ def sorter(arr):
|
|||
original_helper_code[helper_function_path] = helper_code
|
||||
func_optimizer.args = Args()
|
||||
func_optimizer.replace_function_and_helpers_with_optimized_code(
|
||||
code_context=code_context, optimized_code=optimized_code
|
||||
code_context=code_context, optimized_code=optimized_code, original_helper_code=original_helper_code
|
||||
)
|
||||
final_output = code_path.read_text(encoding="utf-8")
|
||||
assert "inconsequential_var = '123'" in final_output
|
||||
|
|
@ -804,7 +804,8 @@ class MainClass:
|
|||
self.name = name
|
||||
|
||||
def main_method(self):
|
||||
return HelperClass(self.name).helper_method()"""
|
||||
return HelperClass(self.name).helper_method()
|
||||
"""
|
||||
file_path = Path(__file__).resolve()
|
||||
func_top_optimize = FunctionToOptimize(
|
||||
function_name="main_method", file_path=file_path, parents=[FunctionParent("MainClass", "ClassDef")]
|
||||
|
|
@ -1662,6 +1663,7 @@ print("Hello world")
|
|||
)
|
||||
assert new_code == original_code
|
||||
|
||||
|
||||
def test_global_reassignment() -> None:
|
||||
original_code = """a=1
|
||||
print("Hello world")
|
||||
|
|
@ -1733,7 +1735,7 @@ class NewClass:
|
|||
original_helper_code[helper_function_path] = helper_code
|
||||
func_optimizer.args = Args()
|
||||
func_optimizer.replace_function_and_helpers_with_optimized_code(
|
||||
code_context=code_context, optimized_code=optimized_code
|
||||
code_context=code_context, optimized_code=optimized_code, original_helper_code=original_helper_code
|
||||
)
|
||||
new_code = code_path.read_text(encoding="utf-8")
|
||||
code_path.unlink(missing_ok=True)
|
||||
|
|
@ -1809,7 +1811,7 @@ a=2
|
|||
original_helper_code[helper_function_path] = helper_code
|
||||
func_optimizer.args = Args()
|
||||
func_optimizer.replace_function_and_helpers_with_optimized_code(
|
||||
code_context=code_context, optimized_code=optimized_code
|
||||
code_context=code_context, optimized_code=optimized_code, original_helper_code=original_helper_code
|
||||
)
|
||||
new_code = code_path.read_text(encoding="utf-8")
|
||||
code_path.unlink(missing_ok=True)
|
||||
|
|
@ -1886,7 +1888,7 @@ class NewClass:
|
|||
original_helper_code[helper_function_path] = helper_code
|
||||
func_optimizer.args = Args()
|
||||
func_optimizer.replace_function_and_helpers_with_optimized_code(
|
||||
code_context=code_context, optimized_code=optimized_code
|
||||
code_context=code_context, optimized_code=optimized_code, original_helper_code=original_helper_code
|
||||
)
|
||||
new_code = code_path.read_text(encoding="utf-8")
|
||||
code_path.unlink(missing_ok=True)
|
||||
|
|
@ -1962,7 +1964,7 @@ class NewClass:
|
|||
original_helper_code[helper_function_path] = helper_code
|
||||
func_optimizer.args = Args()
|
||||
func_optimizer.replace_function_and_helpers_with_optimized_code(
|
||||
code_context=code_context, optimized_code=optimized_code
|
||||
code_context=code_context, optimized_code=optimized_code, original_helper_code=original_helper_code
|
||||
)
|
||||
new_code = code_path.read_text(encoding="utf-8")
|
||||
code_path.unlink(missing_ok=True)
|
||||
|
|
@ -2039,7 +2041,7 @@ class NewClass:
|
|||
original_helper_code[helper_function_path] = helper_code
|
||||
func_optimizer.args = Args()
|
||||
func_optimizer.replace_function_and_helpers_with_optimized_code(
|
||||
code_context=code_context, optimized_code=optimized_code
|
||||
code_context=code_context, optimized_code=optimized_code, original_helper_code=original_helper_code
|
||||
)
|
||||
new_code = code_path.read_text(encoding="utf-8")
|
||||
code_path.unlink(missing_ok=True)
|
||||
|
|
@ -2127,8 +2129,8 @@ a = 6
|
|||
original_helper_code[helper_function_path] = helper_code
|
||||
func_optimizer.args = Args()
|
||||
func_optimizer.replace_function_and_helpers_with_optimized_code(
|
||||
code_context=code_context, optimized_code=optimized_code
|
||||
code_context=code_context, optimized_code=optimized_code, original_helper_code=original_helper_code
|
||||
)
|
||||
new_code = code_path.read_text(encoding="utf-8")
|
||||
code_path.unlink(missing_ok=True)
|
||||
assert new_code.rstrip() == expected_code.rstrip()
|
||||
assert new_code.rstrip() == expected_code.rstrip()
|
||||
|
|
|
|||
1421
tests/test_unused_helper_revert.py
Normal file
1421
tests/test_unused_helper_revert.py
Normal file
File diff suppressed because it is too large
Load diff
Loading…
Reference in a new issue