initial implementation
This commit is contained in:
parent
9e0aa9c3fb
commit
9b4ede56a3
5 changed files with 923 additions and 27 deletions
|
|
@ -0,0 +1,5 @@
|
|||
|
||||
import code_to_optimize.code_directories.retriever.main
|
||||
|
||||
def function_to_optimize():
    """Return whatever ``fetch_and_transform_data`` from the retriever ``main`` module returns.

    The callee is reached through its fully dotted module path rather than a
    ``from ... import`` alias.
    """
    return code_to_optimize.code_directories.retriever.main.fetch_and_transform_data()
|
||||
|
|
@ -14,6 +14,7 @@ from libcst import CSTNode
|
|||
from codeflash.cli_cmds.console import logger
|
||||
from codeflash.code_utils.code_extractor import add_needed_imports_from_module, find_preexisting_objects
|
||||
from codeflash.code_utils.code_utils import get_qualified_name, path_belongs_to_site_packages
|
||||
from codeflash.context.unused_definition_remover import remove_unused_definitions_by_function_names
|
||||
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
|
||||
from codeflash.models.models import (
|
||||
CodeContextType,
|
||||
|
|
@ -189,7 +190,7 @@ def extract_code_string_context_from_files(
|
|||
helpers_of_helpers_qualified_names,
|
||||
remove_docstrings,
|
||||
)
|
||||
|
||||
code_context = remove_unused_definitions_by_function_names(code_context, qualified_function_names | helpers_of_helpers_qualified_names)
|
||||
except ValueError as e:
|
||||
logger.debug(f"Error while getting read-only code: {e}")
|
||||
continue
|
||||
|
|
@ -217,6 +218,7 @@ def extract_code_string_context_from_files(
|
|||
code_context = parse_code_and_prune_cst(
|
||||
original_code, code_context_type, set(), qualified_helper_function_names, remove_docstrings
|
||||
)
|
||||
code_context = remove_unused_definitions_by_function_names(code_context, qualified_helper_function_names)
|
||||
except ValueError as e:
|
||||
logger.debug(f"Error while getting read-only code: {e}")
|
||||
continue
|
||||
|
|
@ -290,6 +292,9 @@ def extract_code_markdown_context_from_files(
|
|||
helpers_of_helpers_qualified_names,
|
||||
remove_docstrings,
|
||||
)
|
||||
code_context = remove_unused_definitions_by_function_names(
|
||||
code_context, qualified_function_names | helpers_of_helpers_qualified_names
|
||||
)
|
||||
|
||||
except ValueError as e:
|
||||
logger.debug(f"Error while getting read-only code: {e}")
|
||||
|
|
@ -321,6 +326,9 @@ def extract_code_markdown_context_from_files(
|
|||
code_context = parse_code_and_prune_cst(
|
||||
original_code, code_context_type, set(), qualified_helper_function_names, remove_docstrings
|
||||
)
|
||||
code_context = remove_unused_definitions_by_function_names(
|
||||
code_context, qualified_helper_function_names
|
||||
)
|
||||
except ValueError as e:
|
||||
logger.debug(f"Error while getting read-only code: {e}")
|
||||
continue
|
||||
|
|
|
|||
476
codeflash/context/unused_definition_remover.py
Normal file
476
codeflash/context/unused_definition_remover.py
Normal file
|
|
@ -0,0 +1,476 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
import libcst as cst
|
||||
|
||||
|
||||
@dataclass
class UsageInfo:
    """Bookkeeping record for a single tracked definition.

    One instance is stored per key of the ``definitions`` mapping that the
    collection, dependency and marking passes in this module share.
    """

    # Key under which the definition is tracked ("func", "Class" or "Class.method").
    name: str
    # Starts False; set to True by the marking pass when a requested function needs it.
    used_by_qualified_function: bool = False
    # Names of other tracked definitions this one refers to (fresh set per instance).
    dependencies: set[str] = field(default_factory=set)
|
||||
|
||||
|
||||
def extract_names_from_targets(target: cst.CSTNode) -> list[str]:
    """Return every identifier bound by an assignment target, in source order.

    Args:
        target: The target node of an assignment (or one of its nested parts).

    Returns:
        A new list of names. A plain ``cst.Name`` yields its own value; a node
        exposing ``value`` is unwrapped and searched; a node exposing ``elements``
        contributes the names of each element. Anything else yields an empty list.

    """
    # A bare identifier is the base case of the recursion.
    if isinstance(target, cst.Name):
        return [target.value]

    # Wrapper-like nodes carry the interesting part in ``value``; this check
    # deliberately takes precedence over ``elements``.
    if hasattr(target, "value"):
        return extract_names_from_targets(target.value)

    # Container-like nodes (tuple/list unpacking): flatten the names of every element.
    if hasattr(target, "elements"):
        return [found for element in target.elements for found in extract_names_from_targets(element)]

    return []
|
||||
|
||||
|
||||
def collect_top_level_definitions(
    node: cst.CSTNode, definitions: dict[str, UsageInfo] | None = None
) -> dict[str, UsageInfo]:
    """Recursively collect all top-level variable, function, and class definitions.

    Args:
        node: The node to inspect; callers normally pass the parsed module.
        definitions: Mapping to fill in. A new dict is created when omitted.

    Returns:
        The ``definitions`` mapping (the same object that was passed in, if any),
        keyed by name. Methods declared directly in a class body are stored as
        ``"ClassName.method"``. A name that is defined more than once ends up with
        the ``UsageInfo`` created for its last definition.

    """
    if definitions is None:
        definitions = {}

    # Functions are recorded but not descended into: their locals are not top-level.
    if isinstance(node, cst.FunctionDef):
        name = node.name.value
        definitions[name] = UsageInfo(name=name)
        return definitions

    # Classes are recorded together with the methods found directly in their body.
    if isinstance(node, cst.ClassDef):
        class_name = node.name.value
        definitions[class_name] = UsageInfo(name=class_name)
        if isinstance(node.body, cst.IndentedBlock):
            for statement in node.body.body:
                if isinstance(statement, cst.FunctionDef):
                    method_name = f"{class_name}.{statement.name.value}"
                    definitions[method_name] = UsageInfo(name=method_name)
        return definitions

    # ``a = b = value`` has several targets, and each target may unpack several names.
    if isinstance(node, cst.Assign):
        for target in node.targets:
            for name in extract_names_from_targets(target.target):
                definitions[name] = UsageInfo(name=name)
        return definitions

    # Tuple form of isinstance: ``A | B`` as classinfo only works on Python 3.10+.
    if isinstance(node, (cst.AnnAssign, cst.AugAssign)):
        for name in extract_names_from_targets(node.target):
            definitions[name] = UsageInfo(name=name)
        return definitions

    # Anything else: descend into its sections so that assignments nested in
    # if/else/while/for/try blocks are still picked up.
    for section in get_section_names(node):
        content = getattr(node, section, None)
        if isinstance(content, (list, tuple)):
            for child in content:
                collect_top_level_definitions(child, definitions)
        elif content is not None:
            collect_top_level_definitions(content, definitions)

    return definitions
|
||||
|
||||
|
||||
def get_section_names(node: cst.CSTNode) -> list[str]:
    """List which statement-holding attributes ``node`` exposes.

    Args:
        node: Any object; only attribute presence is checked.

    Returns:
        The subset of ``body``, ``orelse``, ``finalbody`` and ``handlers`` that
        exists on ``node``, always in that fixed order.

    """
    present = []
    for candidate in ("body", "orelse", "finalbody", "handlers"):
        if hasattr(node, candidate):
            present.append(candidate)
    return present
|
||||
|
||||
|
||||
class DependencyCollector(cst.CSTVisitor):
    """Collect dependencies between tracked definitions while walking a module.

    The visitor maintains an "owner" in ``current_top_level_name``: the function,
    ``Class.method``, class or module-level variable whose code is currently being
    walked. Every name that matches a key of ``definitions`` (other than the owner
    itself) is added to ``definitions[owner].dependencies``. Nothing is recorded
    while there is no owner or while the owner is not a tracked definition.

    NOTE(review): when one assignment binds several tracked names (tuple unpacking
    or chained targets), only a single name becomes the owner, so the remaining
    names receive no dependencies.
    """

    def __init__(self, definitions: dict[str, UsageInfo]) -> None:
        """Store the shared ``definitions`` mapping and reset all traversal state."""
        super().__init__()
        self.definitions = definitions
        # Nesting depth of function / class definitions around the current node.
        self.function_depth = 0
        self.class_depth = 0
        # Owner of the code being walked, and the outermost class being walked.
        self.current_top_level_name = ""
        self.current_class = ""
        # State for a module-level variable assignment that is being walked.
        self.processing_variable = False
        self.current_variable_names: set[str] = set()

    def _record_dependency(self, name: str) -> None:
        """Add ``name`` to the current owner's dependencies when both are tracked."""
        owner = self.current_top_level_name
        # Owners such as methods of nested classes are never collected as
        # definitions; indexing ``self.definitions`` with them would raise KeyError.
        if not owner or owner not in self.definitions:
            return
        if name in self.definitions and name != owner:
            self.definitions[owner].dependencies.add(name)

    def _start_variable(self, names: list[str]) -> None:
        """Make the first tracked name in ``names`` the owner of a module-level assignment."""
        # Assignments inside functions or classes belong to that function/class.
        if self.function_depth != 0 or self.class_depth != 0:
            return
        tracked_names = [name for name in names if name in self.definitions]
        if tracked_names:
            self.processing_variable = True
            self.current_variable_names.update(tracked_names)
            # Use the first tracked name as the owner for dependency tracking.
            self.current_top_level_name = tracked_names[0]

    def _finish_variable(self) -> None:
        """Clear the state set by ``_start_variable`` once the assignment has been walked."""
        if self.processing_variable:
            self.processing_variable = False
            self.current_variable_names.clear()
            self.current_top_level_name = ""

    def visit_FunctionDef(self, node: cst.FunctionDef) -> None:
        """Make an outermost function or method the owner and scan its parameter annotations."""
        function_name = node.name.value

        if self.function_depth == 0:
            if self.class_depth > 0:
                # A method: tracked under its class-qualified name.
                self.current_top_level_name = f"{self.current_class}.{function_name}"
            else:
                # Regular top-level function.
                self.current_top_level_name = function_name

        # Parameter annotations are dependencies of the current owner.
        for param in node.params.params:
            if param.annotation:
                self._collect_annotation_dependencies(param.annotation)

        self.function_depth += 1

    def _collect_annotation_dependencies(self, annotation: cst.Annotation) -> None:
        """Extract dependencies from a type annotation wrapper."""
        if hasattr(annotation, "annotation"):
            # The wrapped expression can be a Name, Attribute, Subscript, etc.
            self._extract_names_from_annotation(annotation.annotation)

    def _extract_names_from_annotation(self, node: cst.CSTNode) -> None:
        """Record tracked names that appear in an annotation expression.

        Handles plain names, subscripted generics and attribute access. Other
        expression kinds are ignored here; names inside them are still seen by
        ``visit_Name`` if the traversal reaches them.
        """
        if isinstance(node, cst.Name):
            # Simple reference like 'int', 'str', or a custom type.
            self._record_dependency(node.value)
        elif isinstance(node, cst.Subscript):
            # Compound annotation like List[int] or Dict[str, CustomType].
            self._extract_names_from_annotation(node.value)
            for slice_item in node.slice:
                if hasattr(slice_item, "slice"):
                    self._extract_names_from_annotation(slice_item.slice)
        elif isinstance(node, cst.Index):
            # Subscript elements wrap their expression in an Index; unwrap it so
            # the inner type is actually inspected.
            self._extract_names_from_annotation(node.value)
        elif isinstance(node, cst.Attribute):
            # module.Type: only the leftmost name can be a top-level definition.
            self._extract_names_from_annotation(node.value)

    def leave_FunctionDef(self, original_node: cst.FunctionDef) -> None:
        """Restore the owner after an outermost function or method has been walked."""
        self.function_depth -= 1

        if self.function_depth == 0:
            # Back in a class body: later class-level statements belong to the
            # class, not to the method that was just left. At module level there
            # is no owner.
            self.current_top_level_name = self.current_class if self.class_depth > 0 else ""

    def visit_ClassDef(self, node: cst.ClassDef) -> None:
        """Make an outermost, module-level class the owner."""
        # A class defined inside a function must not replace the enclosing
        # function as owner.
        if self.class_depth == 0 and self.function_depth == 0:
            self.current_class = node.name.value
            self.current_top_level_name = node.name.value

        self.class_depth += 1

    def leave_ClassDef(self, original_node: cst.ClassDef) -> None:
        """Clear the owner after an outermost, module-level class has been walked."""
        self.class_depth -= 1

        if self.class_depth == 0 and self.function_depth == 0:
            self.current_class = ""
            self.current_top_level_name = ""

    def visit_Assign(self, node: cst.Assign) -> None:
        """Start tracking a module-level assignment to one or more tracked names."""
        for target in node.targets:
            self._start_variable(extract_names_from_targets(target.target))

    def leave_Assign(self, original_node: cst.Assign) -> None:
        """Stop tracking the module-level assignment, if one was being tracked."""
        self._finish_variable()

    def visit_AnnAssign(self, node: cst.AnnAssign) -> None:
        """Track an annotated assignment and record the names in its annotation."""
        # At module level the variable becomes the owner, so both its annotation
        # and its value contribute dependencies. Inside a function or class the
        # enclosing owner is kept.
        self._start_variable(extract_names_from_targets(node.target))
        if node.annotation:
            self._collect_annotation_dependencies(node.annotation)

    def leave_AnnAssign(self, original_node: cst.AnnAssign) -> None:
        """Stop tracking the module-level annotated assignment, if any."""
        self._finish_variable()

    def visit_AugAssign(self, node: cst.AugAssign) -> None:
        """Track a module-level augmented assignment (``X += ...``) like a plain one."""
        self._start_variable(extract_names_from_targets(node.target))

    def leave_AugAssign(self, original_node: cst.AugAssign) -> None:
        """Stop tracking the module-level augmented assignment, if any."""
        self._finish_variable()

    def visit_Name(self, node: cst.Name) -> None:
        """Record ``node`` as a dependency of the current owner when it is a tracked name."""
        name = node.value

        # The variable being assigned is not a dependency of itself.
        if self.processing_variable and name in self.current_variable_names:
            return

        self._record_dependency(name)
|
||||
|
||||
|
||||
class QualifiedFunctionUsageMarker:
    """Marks definitions that are used by specific qualified functions.

    ``definitions`` maps names to records exposing ``used_by_qualified_function``
    and ``dependencies``; marking sets the flag on every requested definition and
    on everything reachable through ``dependencies``.
    """

    def __init__(self, definitions: dict[str, UsageInfo], qualified_function_names: set[str]) -> None:
        """Store the inputs and pre-compute the expanded set of names to mark.

        Args:
            definitions: Mapping of definition name to its usage record (mutated by marking).
            qualified_function_names: Requested names; methods use ``"ClassName.method"``.

        """
        self.definitions = definitions
        self.qualified_function_names = qualified_function_names
        self.expanded_qualified_functions = self._expand_qualified_functions()

    def _expand_qualified_functions(self) -> set[str]:
        """Return the requested names plus, for each method, its class and that class's dunder methods."""
        expanded = set(self.qualified_function_names)

        for qualified_name in self.qualified_function_names:
            if "." not in qualified_name:
                continue

            # Everything before the first dot is treated as the class name.
            class_name = qualified_name.split(".", 1)[0]
            expanded.add(class_name)

            # Dunder methods of that class (``Class.__x__``) are always kept with it.
            dunder_prefix = f"{class_name}.__"
            expanded.update(
                name for name in self.definitions if name.startswith(dunder_prefix) and name.endswith("__")
            )

        return expanded

    def mark_used_definitions(self) -> None:
        """Mark every expanded name that is a known definition, and all of its dependencies, as used."""
        functions_to_mark = [name for name in self.expanded_qualified_functions if name in self.definitions]

        for func_name in functions_to_mark:
            self.definitions[func_name].used_by_qualified_function = True
            for dep in self.definitions[func_name].dependencies:
                self.mark_as_used_recursively(dep)

    def mark_as_used_recursively(self, name: str) -> None:
        """Mark ``name`` and everything reachable through its dependencies as used.

        Unknown names are ignored and definitions that are already marked are not
        revisited, so dependency cycles terminate. The walk uses an explicit
        worklist instead of recursion so that very long dependency chains cannot
        exhaust the interpreter's recursion limit.
        """
        pending = [name]
        while pending:
            info = self.definitions.get(pending.pop())
            if info is None or info.used_by_qualified_function:
                continue
            info.used_by_qualified_function = True
            pending.extend(info.dependencies)
|
||||
|
||||
|
||||
def _is_marked_used(name: str, definitions: dict[str, UsageInfo]) -> bool:
    """Return True when ``name`` is a tracked definition flagged as used."""
    info = definitions.get(name)
    return info is not None and info.used_by_qualified_function


def _filter_class_definition(node: cst.ClassDef, definitions: dict[str, UsageInfo]) -> tuple[cst.CSTNode, bool]:
    """Filter the direct body statements of a class; the class itself is never removed."""
    class_name = node.name.value
    class_used = _is_marked_used(class_name, definitions)

    # A class without an indented block body is returned untouched.
    if not isinstance(node.body, cst.IndentedBlock):
        return node, class_used

    kept_statements = []
    for statement in node.body.body:
        # NOTE(review): statements of a parsed class body appear to be wrapped in
        # SimpleStatementLine nodes, in which case this branch never matches and
        # every class statement is kept - confirm against libcst before relying on it.
        if isinstance(statement, (cst.Assign, cst.AnnAssign, cst.AugAssign)):
            if isinstance(statement, cst.Assign):
                names = [name for target in statement.targets for name in extract_names_from_targets(target.target)]
            else:
                names = extract_names_from_targets(statement.target)
            var_used = any(_is_marked_used(f"{class_name}.{name}", definitions) for name in names)
            if not (var_used or class_used):
                continue
        # Methods and all other statements are always kept.
        kept_statements.append(statement)

    return node.with_changes(body=node.body.with_changes(body=kept_statements)), True


def remove_unused_definitions_recursively(
    node: cst.CSTNode, definitions: dict[str, UsageInfo]
) -> tuple[cst.CSTNode | None, bool]:
    """Recursively filter the node to remove unused definitions.

    Imports, function definitions and class definitions are always kept. A
    variable assignment is kept only when at least one of the names it binds is
    flagged as used in ``definitions``. Any other node is rebuilt from its
    filtered sections and dropped entirely when nothing inside it is used.

    Args:
        node: The CST node to process
        definitions: Dictionary of definition info

    Returns:
        (filtered_node, used_by_function):
          filtered_node: The modified CST node or None if it should be removed
          used_by_function: True if this node or any child is used by qualified functions

    """
    # Tuple form of isinstance throughout: ``A | B`` as classinfo needs Python 3.10+.
    # Imports and function definitions are never removed.
    if isinstance(node, (cst.Import, cst.ImportFrom, cst.FunctionDef)):
        return node, True

    # Class definitions are never removed either; only their body is filtered.
    if isinstance(node, cst.ClassDef):
        return _filter_class_definition(node, definitions)

    # Variable assignments survive only if one of the bound names is used.
    if isinstance(node, cst.Assign):
        for target in node.targets:
            if any(_is_marked_used(name, definitions) for name in extract_names_from_targets(target.target)):
                return node, True
        return None, False

    if isinstance(node, (cst.AnnAssign, cst.AugAssign)):
        if any(_is_marked_used(name, definitions) for name in extract_names_from_targets(node.target)):
            return node, True
        return None, False

    # For other nodes, recursively process children.
    section_names = get_section_names(node)
    if not section_names:
        return node, False

    updates = {}
    found_used = False

    for section in section_names:
        original_content = getattr(node, section, None)
        if isinstance(original_content, (list, tuple)):
            new_children = []
            section_found_used = False

            for child in original_content:
                filtered, used = remove_unused_definitions_recursively(child, definitions)
                if filtered is not None:
                    new_children.append(filtered)
                section_found_used |= used

            # A section that would end up empty and unused keeps its original
            # content; the enclosing node is then dropped unless another section
            # is used.
            if new_children or section_found_used:
                found_used |= section_found_used
                updates[section] = new_children
        elif original_content is not None:
            filtered, used = remove_unused_definitions_recursively(original_content, definitions)
            found_used |= used
            if filtered is not None:
                updates[section] = filtered

    if not found_used:
        return None, False
    if updates:
        return node.with_changes(**updates), found_used

    return node, found_used
|
||||
|
||||
|
||||
def remove_unused_definitions_by_function_names(code: str, qualified_function_names: set[str]) -> str:
    """Strip top-level definitions that the given functions do not need.

    "Top-level definitions" here means classes, variables and functions only.
    When a class is referenced by one of the given functions, the entire class
    is retained.

    Args:
        code: Source text of the module to prune
        qualified_function_names: Names of the functions to keep; a method is written as 'classname.methodname'

    Returns:
        The pruned source text, or an empty string when the filtering pass leaves no module.

    """
    parsed = cst.parse_module(code)

    # Pass 1: index every top-level class, variable and function.
    usage = collect_top_level_definitions(parsed)

    # Pass 2: record which definitions refer to which.
    parsed.visit(DependencyCollector(usage))

    # Pass 3: flag the requested functions plus everything they transitively depend on.
    QualifiedFunctionUsageMarker(usage, qualified_function_names).mark_used_definitions()

    # Pass 4: rebuild the tree without the definitions that were never flagged.
    pruned, _ = remove_unused_definitions_recursively(parsed, usage)
    if not pruned:
        return ""
    return pruned.code
|
||||
|
||||
|
||||
def print_definitions(definitions: dict[str, UsageInfo]) -> None:
    """Dump a readable summary of every definition to stdout (debugging aid).

    Entries are listed in sorted name order; each shows its usage flag and its
    dependencies sorted alphabetically, or ``None`` when it has none.
    """
    print(f"Found {len(definitions)} definitions:")
    for key in sorted(definitions):
        entry = definitions[key]
        dependency_text = ", ".join(sorted(entry.dependencies)) if entry.dependencies else "None"
        print(f" - Name: {key}")
        print(f" Used by qualified function: {entry.used_by_qualified_function}")
        print(f" Dependencies: {dependency_text}")
        # Blank separator line after each entry.
        print()
|
||||
|
|
@ -929,9 +929,6 @@ def fetch_and_process_data():
|
|||
"""
|
||||
expected_read_only_context = f"""
|
||||
```python:{path_to_utils.relative_to(project_root)}
|
||||
GLOBAL_VAR = 10
|
||||
|
||||
|
||||
class DataProcessor:
|
||||
\"\"\"A class for processing data.\"\"\"
|
||||
|
||||
|
|
@ -941,11 +938,6 @@ class DataProcessor:
|
|||
\"\"\"Return a string representation of the DataProcessor.\"\"\"
|
||||
return f"DataProcessor(default_prefix={{self.default_prefix!r}})"
|
||||
```
|
||||
```python:{path_to_file.relative_to(project_root)}
|
||||
if __name__ == "__main__":
|
||||
result = fetch_and_process_data()
|
||||
print("Processed data:", result)
|
||||
```
|
||||
"""
|
||||
assert read_write_context.strip() == expected_read_write_context.strip()
|
||||
assert read_only_context.strip() == expected_read_only_context.strip()
|
||||
|
|
@ -1006,9 +998,6 @@ def fetch_and_transform_data():
|
|||
"""
|
||||
expected_read_only_context = f"""
|
||||
```python:{path_to_utils.relative_to(project_root)}
|
||||
GLOBAL_VAR = 10
|
||||
|
||||
|
||||
class DataProcessor:
|
||||
\"\"\"A class for processing data.\"\"\"
|
||||
|
||||
|
|
@ -1018,11 +1007,6 @@ class DataProcessor:
|
|||
\"\"\"Return a string representation of the DataProcessor.\"\"\"
|
||||
return f"DataProcessor(default_prefix={{self.default_prefix!r}})"
|
||||
```
|
||||
```python:{path_to_file.relative_to(project_root)}
|
||||
if __name__ == "__main__":
|
||||
result = fetch_and_process_data()
|
||||
print("Processed data:", result)
|
||||
```
|
||||
```python:{path_to_transform_utils.relative_to(project_root)}
|
||||
class DataTransformer:
|
||||
|
||||
|
|
@ -1084,9 +1068,6 @@ class DataTransformer:
|
|||
return self.data
|
||||
```
|
||||
```python:{path_to_utils.relative_to(project_root)}
|
||||
GLOBAL_VAR = 10
|
||||
|
||||
|
||||
class DataProcessor:
|
||||
\"\"\"A class for processing data.\"\"\"
|
||||
|
||||
|
|
@ -1147,9 +1128,6 @@ def update_data(data):
|
|||
return data + " updated"
|
||||
```
|
||||
```python:{path_to_utils.relative_to(project_root)}
|
||||
GLOBAL_VAR = 10
|
||||
|
||||
|
||||
class DataProcessor:
|
||||
\"\"\"A class for processing data.\"\"\"
|
||||
|
||||
|
|
@ -1252,9 +1230,6 @@ class DataTransformer:
|
|||
"""
|
||||
expected_read_only_context = f"""
|
||||
```python:{path_to_utils.relative_to(project_root)}
|
||||
GLOBAL_VAR = 10
|
||||
|
||||
|
||||
class DataProcessor:
|
||||
\"\"\"A class for processing data.\"\"\"
|
||||
|
||||
|
|
@ -1322,4 +1297,20 @@ def outside_method():
|
|||
```
|
||||
"""
|
||||
assert read_write_context.strip() == expected_read_write_context.strip()
|
||||
assert read_only_context.strip() == expected_read_only_context.strip()
|
||||
assert read_only_context.strip() == expected_read_only_context.strip()
|
||||
|
||||
def test_direct_module_import() -> None:
    """Build the optimization context for a function that reaches its helper via a dotted module import.

    NOTE(review): this test only prints the read-only context and asserts nothing,
    so it can fail only if context extraction raises.
    """
    project_root = Path(__file__).resolve().parent.parent / "code_to_optimize" / "code_directories" / "retriever"
    path_to_main = project_root / "main.py"
    path_to_fto = project_root / "import_test.py"
    function_to_optimize = FunctionToOptimize(
        function_name="function_to_optimize",
        file_path=str(path_to_fto),
        parents=[],
        starting_line=None,
        ending_line=None,
    )

    code_ctx = get_code_optimization_context(function_to_optimize, project_root)
    read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code
    print(read_only_context.strip())
|
||||
416
tests/test_remove_unused_definitions.py
Normal file
416
tests/test_remove_unused_definitions.py
Normal file
|
|
@ -0,0 +1,416 @@
|
|||
import libcst as cst
|
||||
|
||||
from codeflash.context.unused_definition_remover import remove_unused_definitions_by_function_names
|
||||
|
||||
|
||||
def test_variable_removal_only() -> None:
    """Test that only variables not used by specified functions are removed, not functions."""
    # Input module: main_function reads USED_CONSTANT; UNUSED_CONSTANT is read only by another_function.
    code = """
def main_function():
    return USED_CONSTANT + 10

def helper_function():
    return 42

USED_CONSTANT = 42
UNUSED_CONSTANT = 123

def another_function():
    return UNUSED_CONSTANT
"""

    # Only the UNUSED_CONSTANT assignment disappears; every function stays, even unrequested ones.
    expected = """
def main_function():
    return USED_CONSTANT + 10

def helper_function():
    return 42

USED_CONSTANT = 42

def another_function():
    return UNUSED_CONSTANT
"""

    qualified_functions = {"main_function"}
    result = remove_unused_definitions_by_function_names(code, qualified_functions)
    # Normalize whitespace for comparison
    assert result.strip() == expected.strip()
|
||||
|
||||
|
||||
def test_class_variable_removal() -> None:
    """Test that only class variables not used by specified functions are removed, not methods."""
    # Input module: helper_function instantiates MyClass and reads GLOBAL_USED.
    code = """
class MyClass:
    CLASS_USED = "used value"
    CLASS_UNUSED = "unused value"

    def __init__(self):
        self.value = self.CLASS_USED
        self.other = self.CLASS_UNUSED

    def used_method(self):
        return self.value

    def unused_method(self):
        return "Not used but not removed"

GLOBAL_USED = "global used"
GLOBAL_UNUSED = "global unused"

def helper_function():
    return MyClass().used_method() + GLOBAL_USED
"""

    # The referenced class is kept whole (including CLASS_UNUSED); only GLOBAL_UNUSED is dropped.
    expected = """
class MyClass:
    CLASS_USED = "used value"
    CLASS_UNUSED = "unused value"

    def __init__(self):
        self.value = self.CLASS_USED
        self.other = self.CLASS_UNUSED

    def used_method(self):
        return self.value

    def unused_method(self):
        return "Not used but not removed"

GLOBAL_USED = "global used"

def helper_function():
    return MyClass().used_method() + GLOBAL_USED
"""

    qualified_functions = {"helper_function"}
    result = remove_unused_definitions_by_function_names(code, qualified_functions)
    # Normalize whitespace for comparison
    assert result.strip() == expected.strip()
|
||||
|
||||
|
||||
def test_complex_variable_dependencies() -> None:
    """Test that only variables with complex dependencies are properly handled."""
    # Input module: main_function -> DIRECT_DEPENDENCY -> INDIRECT_DEPENDENCY is a transitive chain;
    # the tuple-unpacked names are read only by tuple_user, which is not requested.
    code = """
def main_function():
    return DIRECT_DEPENDENCY

def unused_function():
    return "Not used but not removed"

DIRECT_DEPENDENCY = INDIRECT_DEPENDENCY + "_suffix"
INDIRECT_DEPENDENCY = "base value"
UNUSED_VARIABLE = "This should be removed"

TUPLE_USED, TUPLE_UNUSED = ("used", "unused")

def tuple_user():
    return TUPLE_USED
"""

    # Both links of the chain survive; UNUSED_VARIABLE and the tuple assignment are removed.
    expected = """
def main_function():
    return DIRECT_DEPENDENCY

def unused_function():
    return "Not used but not removed"

DIRECT_DEPENDENCY = INDIRECT_DEPENDENCY + "_suffix"
INDIRECT_DEPENDENCY = "base value"

def tuple_user():
    return TUPLE_USED
"""

    qualified_functions = {"main_function"}
    result = remove_unused_definitions_by_function_names(code, qualified_functions)
    assert result.strip() == expected.strip()
|
||||
|
||||
|
||||
def test_type_annotation_usage() -> None:
    """Test that variables used in type annotations are considered used."""
    # Input module: CustomType appears only in main_function's annotations,
    # UnusedType only in unused_function's.
    code = """
# Type definition
CustomType = int
UnusedType = str

def main_function(param: CustomType) -> CustomType:
    return param + 10

def unused_function(param: UnusedType) -> UnusedType:
    return param + " suffix"

UNUSED_CONSTANT = 123
"""

    # CustomType is kept because of the annotations; UnusedType and UNUSED_CONSTANT are removed.
    expected = """
# Type definition
CustomType = int

def main_function(param: CustomType) -> CustomType:
    return param + 10

def unused_function(param: UnusedType) -> UnusedType:
    return param + " suffix"

"""

    qualified_functions = {"main_function"}
    result = remove_unused_definitions_by_function_names(code, qualified_functions)
    # Normalize whitespace for comparison
    assert result.strip() == expected.strip()
|
||||
|
||||
|
||||
def test_class_method_with_dunder_methods() -> None:
    """Test that when a class method is used, dunder methods of that class are preserved."""
    # Input module: the requested method is MyClass.target_method; GLOBAL_VAR is read in __init__
    # and GLOBAL_VAR_2 in the class body, while UNUSED_GLOBAL is read nowhere.
    code = """
class MyClass:
    CLASS_VAR = "class variable"
    UNUSED_VAR = GLOBAL_VAR_2

    def __init__(self, value):
        self.value = GLOBAL_VAR

    def __str__(self):
        return f"MyClass({self.value})"

    def target_method(self):
        return self.value * 2

    def unused_method(self):
        return "Not used"

GLOBAL_VAR = "global"
GLOBAL_VAR_2 = "global"
UNUSED_GLOBAL = "unused global"

def helper_function():
    obj = MyClass(5)
    return obj.target_method()
"""

    # The class is kept whole, the globals reached through it survive, and UNUSED_GLOBAL is removed.
    expected = """
class MyClass:
    CLASS_VAR = "class variable"
    UNUSED_VAR = GLOBAL_VAR_2

    def __init__(self, value):
        self.value = GLOBAL_VAR

    def __str__(self):
        return f"MyClass({self.value})"

    def target_method(self):
        return self.value * 2

    def unused_method(self):
        return "Not used"

GLOBAL_VAR = "global"
GLOBAL_VAR_2 = "global"

def helper_function():
    obj = MyClass(5)
    return obj.target_method()
"""

    qualified_functions = {"MyClass.target_method"}
    result = remove_unused_definitions_by_function_names(code, qualified_functions)
    # Normalize whitespace for comparison
    assert result.strip() == expected.strip()
|
||||
|
||||
|
||||
def test_complex_type_annotations() -> None:
    """Type aliases reached through a kept function's annotations must survive.

    ``process_data`` is annotated with ``ResultType``, which is built from
    ``ItemType``; both aliases stay. ``UnusedType`` and the two annotated
    module-level variables are expected to be removed.
    """
    source = """
from typing import List, Dict, Optional

# Type aliases
ItemType = Dict[str, int]
ResultType = List[ItemType]
UnusedType = Optional[str]

def process_data(items: ResultType) -> int:
    total = 0
    for item in items:
        for key, value in item.items():
            total += value
    return total

def unused_function(param: UnusedType) -> None:
    pass

# Variables
SAMPLE_DATA: ResultType = [{"a": 1, "b": 2}]
UNUSED_DATA: UnusedType = None
"""

    # The UnusedType alias and the trailing "Variables" section are gone.
    wanted = """
from typing import List, Dict, Optional

# Type aliases
ItemType = Dict[str, int]
ResultType = List[ItemType]

def process_data(items: ResultType) -> int:
    total = 0
    for item in items:
        for key, value in item.items():
            total += value
    return total

def unused_function(param: UnusedType) -> None:
    pass
"""

    pruned = remove_unused_definitions_by_function_names(source, {"process_data"})
    assert pruned.strip() == wanted.strip()
|
||||
|
||||
|
||||
def test_try_except_finally_variables() -> None:
    """Assignments inside a top-level try/except/finally are pruned per variable.

    Variables read by the targeted functions stay in every branch where they
    are assigned; ``UNUSED_CONST`` and ``UNUSED_CLEANUP`` are expected to be
    removed from each branch.
    """
    source = """
import math
import os

# Top-level try-except that defines variables
try:
    MATH_CONSTANT = math.pi
    USED_ERROR_MSG = "An error occurred"
    UNUSED_CONST = 42
except ImportError:
    MATH_CONSTANT = 3.14
    USED_ERROR_MSG = "Math module not available"
    UNUSED_CONST = 0
finally:
    CLEANUP_FLAG = True
    UNUSED_CLEANUP = "Not used"

def use_constants():
    return f"Pi is approximately {MATH_CONSTANT}, message: {USED_ERROR_MSG}"

def use_cleanup():
    if CLEANUP_FLAG:
        return "Cleanup performed"
    return "No cleanup"

def unused_function():
    return UNUSED_CONST
"""

    # Each branch keeps only the names read by use_constants / use_cleanup.
    wanted = """
import math
import os

# Top-level try-except that defines variables
try:
    MATH_CONSTANT = math.pi
    USED_ERROR_MSG = "An error occurred"
except ImportError:
    MATH_CONSTANT = 3.14
    USED_ERROR_MSG = "Math module not available"
finally:
    CLEANUP_FLAG = True

def use_constants():
    return f"Pi is approximately {MATH_CONSTANT}, message: {USED_ERROR_MSG}"

def use_cleanup():
    if CLEANUP_FLAG:
        return "Cleanup performed"
    return "No cleanup"

def unused_function():
    return UNUSED_CONST
"""

    targets = {"use_constants", "use_cleanup"}
    pruned = remove_unused_definitions_by_function_names(source, targets)
    assert pruned.strip() == wanted.strip()
|
||||
|
||||
def test_conditional_and_loop_variables() -> None:
    """Assignments inside top-level if/elif/else and while blocks are pruned per variable.

    ``OS_TYPE``, ``OS_SEP`` and ``LOOP_RESULT`` are read by the targeted
    functions and stay; the ``UNUSED_*`` assignments in each branch and in the
    loop body are expected to be removed. The loop counter is kept.
    """
    source = """
import sys
import platform

# Top-level if-else block defining variables
if sys.platform.startswith('win'):
    OS_TYPE = "Windows"
    OS_SEP = ""
    UNUSED_WIN_VAR = "Unused Windows variable"
elif sys.platform.startswith('linux'):
    OS_TYPE = "Linux"
    OS_SEP = "/"
    UNUSED_LINUX_VAR = "Unused Linux variable"
else:
    OS_TYPE = "Other"
    OS_SEP = "/"
    UNUSED_OTHER_VAR = "Unused other variable"

# While loop with variable definitions
counter = 0
while counter < 5:
    LOOP_RESULT = "Iteration " + str(counter)
    UNUSED_LOOP_VAR = "Unused loop " + str(counter)
    counter += 1

def get_platform_info():
    return "OS: " + OS_TYPE + ", Separator: " + OS_SEP

def get_loop_result():
    return LOOP_RESULT

def unused_function():
    result = ""
    if sys.platform.startswith('win'):
        result = UNUSED_WIN_VAR
    elif sys.platform.startswith('linux'):
        result = UNUSED_LINUX_VAR
    else:
        result = UNUSED_OTHER_VAR
    return result
"""

    # Only the UNUSED_* assignments disappear; function bodies are untouched.
    wanted = """
import sys
import platform

# Top-level if-else block defining variables
if sys.platform.startswith('win'):
    OS_TYPE = "Windows"
    OS_SEP = ""
elif sys.platform.startswith('linux'):
    OS_TYPE = "Linux"
    OS_SEP = "/"
else:
    OS_TYPE = "Other"
    OS_SEP = "/"

# While loop with variable definitions
counter = 0
while counter < 5:
    LOOP_RESULT = "Iteration " + str(counter)
    counter += 1

def get_platform_info():
    return "OS: " + OS_TYPE + ", Separator: " + OS_SEP

def get_loop_result():
    return LOOP_RESULT

def unused_function():
    result = ""
    if sys.platform.startswith('win'):
        result = UNUSED_WIN_VAR
    elif sys.platform.startswith('linux'):
        result = UNUSED_LINUX_VAR
    else:
        result = UNUSED_OTHER_VAR
    return result
"""

    targets = {"get_platform_info", "get_loop_result"}
    pruned = remove_unused_definitions_by_function_names(source, targets)
    assert pruned.strip() == wanted.strip()
|
||||
Loading…
Reference in a new issue