initial implementation

This commit is contained in:
Alvin Ryanputra 2025-04-16 14:14:05 -04:00
parent 9e0aa9c3fb
commit 9b4ede56a3
5 changed files with 923 additions and 27 deletions

View file

@ -0,0 +1,5 @@
import code_to_optimize.code_directories.retriever.main
def function_to_optimize():
    """Return the result of ``fetch_and_transform_data`` reached through its fully qualified module path.

    The helper is referenced as an attribute chain on the imported package
    (``code_to_optimize.code_directories.retriever.main``) rather than through
    a ``from ... import`` name.
    """
    return code_to_optimize.code_directories.retriever.main.fetch_and_transform_data()

View file

@ -14,6 +14,7 @@ from libcst import CSTNode
from codeflash.cli_cmds.console import logger
from codeflash.code_utils.code_extractor import add_needed_imports_from_module, find_preexisting_objects
from codeflash.code_utils.code_utils import get_qualified_name, path_belongs_to_site_packages
from codeflash.context.unused_definition_remover import remove_unused_definitions_by_function_names
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.models.models import (
CodeContextType,
@ -189,7 +190,7 @@ def extract_code_string_context_from_files(
helpers_of_helpers_qualified_names,
remove_docstrings,
)
code_context = remove_unused_definitions_by_function_names(code_context, qualified_function_names | helpers_of_helpers_qualified_names)
except ValueError as e:
logger.debug(f"Error while getting read-only code: {e}")
continue
@ -217,6 +218,7 @@ def extract_code_string_context_from_files(
code_context = parse_code_and_prune_cst(
original_code, code_context_type, set(), qualified_helper_function_names, remove_docstrings
)
code_context = remove_unused_definitions_by_function_names(code_context, qualified_helper_function_names)
except ValueError as e:
logger.debug(f"Error while getting read-only code: {e}")
continue
@ -290,6 +292,9 @@ def extract_code_markdown_context_from_files(
helpers_of_helpers_qualified_names,
remove_docstrings,
)
code_context = remove_unused_definitions_by_function_names(
code_context, qualified_function_names | helpers_of_helpers_qualified_names
)
except ValueError as e:
logger.debug(f"Error while getting read-only code: {e}")
@ -321,6 +326,9 @@ def extract_code_markdown_context_from_files(
code_context = parse_code_and_prune_cst(
original_code, code_context_type, set(), qualified_helper_function_names, remove_docstrings
)
code_context = remove_unused_definitions_by_function_names(
code_context, qualified_helper_function_names
)
except ValueError as e:
logger.debug(f"Error while getting read-only code: {e}")
continue

View file

@ -0,0 +1,476 @@
from __future__ import annotations
from dataclasses import dataclass, field
import libcst as cst
@dataclass
class UsageInfo:
    """Usage bookkeeping for one collected definition.

    One instance is stored per collected name (function, class, method or
    variable) while deciding which definitions may be dropped.
    """

    # Name of the definition; methods are stored as "ClassName.method_name".
    name: str
    # Set to True once the definition is one of the requested functions or is
    # reachable from one through ``dependencies``.
    used_by_qualified_function: bool = False
    # Names of other tracked definitions that this definition references.
    dependencies: set[str] = field(default_factory=set)
def extract_names_from_targets(target: cst.CSTNode) -> list[str]:
    """Return every variable name bound by an assignment target.

    Args:
        target: The target node of an assignment; may be a plain name, a
            wrapper exposing ``value`` (e.g. a starred element) or a
            container exposing ``elements`` (tuple/list unpacking).

    Returns:
        The names found, in source order. Empty when the node is none of the
        recognised shapes.
    """
    # A bare name is the base case of the recursion.
    if isinstance(target, cst.Name):
        return [target.value]

    # Wrappers delegate to whatever they wrap.
    if hasattr(target, "value"):
        return extract_names_from_targets(target.value)

    # Containers contribute the names of each element, left to right.
    if hasattr(target, "elements"):
        collected: list[str] = []
        for item in target.elements:
            collected += extract_names_from_targets(item)
        return collected

    return []
def collect_top_level_definitions(
    node: cst.CSTNode, definitions: dict[str, UsageInfo] | None = None
) -> dict[str, UsageInfo]:
    """Recursively collect all top-level variable, function, and class definitions.

    Functions and classes are recorded under their own name, methods declared
    directly in a class body under ``"ClassName.method_name"``, and assignment
    targets (plain, annotated and augmented) under each bound name. The search
    does not descend into function or class bodies, but it does descend into
    the ``body``/``orelse``/``finalbody``/``handlers`` sections of other nodes,
    so assignments nested in module-level ``if``/``try``/``while`` blocks are found.

    Args:
        node: The CST node to scan (normally the module).
        definitions: Mapping to fill in; a new one is created when omitted.

    Returns:
        The same mapping that was passed in (or the newly created one), with
        one fresh ``UsageInfo`` per definition found.
    """
    if definitions is None:
        definitions = {}

    # Functions: record the name and stop; nested definitions are not top level.
    if isinstance(node, cst.FunctionDef):
        function_name = node.name.value
        definitions[function_name] = UsageInfo(name=function_name)
        return definitions

    # Classes: record the class plus the methods declared directly in its body.
    if isinstance(node, cst.ClassDef):
        class_name = node.name.value
        definitions[class_name] = UsageInfo(name=class_name)
        if isinstance(node.body, cst.IndentedBlock):
            for statement in node.body.body:
                if isinstance(statement, cst.FunctionDef):
                    method_name = f"{class_name}.{statement.name.value}"
                    definitions[method_name] = UsageInfo(name=method_name)
        return definitions

    # Plain assignments may have several targets (a = b = 1), each possibly unpacking.
    if isinstance(node, cst.Assign):
        for target in node.targets:
            for name in extract_names_from_targets(target.target):
                definitions[name] = UsageInfo(name=name)
        return definitions

    # A tuple (not `A | B`) keeps isinstance working on interpreters older than 3.10
    # and matches the tuple form used elsewhere in this module.
    if isinstance(node, (cst.AnnAssign, cst.AugAssign)):
        for name in extract_names_from_targets(node.target):
            definitions[name] = UsageInfo(name=name)
        return definitions

    # Any other node: walk its statement-holding sections so that assignments
    # inside if/else/while/for/try blocks are collected as well.
    for section in get_section_names(node):
        content = getattr(node, section, None)
        if isinstance(content, (list, tuple)):
            for child in content:
                collect_top_level_definitions(child, definitions)
        elif content is not None:
            collect_top_level_definitions(content, definitions)

    return definitions
def get_section_names(node: cst.CSTNode) -> list[str]:
    """List which statement-holding attributes the given node exposes.

    The attributes checked are ``body``, ``orelse``, ``finalbody`` and
    ``handlers``; the result keeps that order and contains only the ones
    present on ``node``.
    """
    present: list[str] = []
    for attribute in ("body", "orelse", "finalbody", "handlers"):
        if hasattr(node, attribute):
            present.append(attribute)
    return present
class DependencyCollector(cst.CSTVisitor):
    """Collects dependencies between definitions using the visitor pattern with depth tracking.

    While walking the module, every ``Name`` node that matches a tracked
    definition is recorded as a dependency of the definition currently being
    visited (the "owner"):

    * inside a top-level function, the owner is that function;
    * inside a method of a top-level class, the owner is ``"Class.method"``
      when that key is tracked, otherwise the class itself;
    * elsewhere inside a top-level class (bases, decorators, class-level
      statements), the owner is the class;
    * inside a module-level assignment (plain, annotated or augmented), the
      owners are all tracked names bound by that assignment.

    Names seen while no owner is active are ignored.
    """

    def __init__(self, definitions: dict[str, UsageInfo]) -> None:
        super().__init__()
        self.definitions = definitions
        # Nesting depth of function and class definitions around the current node.
        self.function_depth = 0
        self.class_depth = 0
        # Owner of the dependencies currently being recorded ("" means none).
        self.current_top_level_name = ""
        self.current_class = ""
        # State for a module-level assignment that is currently being visited.
        self.processing_variable = False
        self.current_variable_names: set[str] = set()

    # ------------------------------------------------------------------ helpers

    def _record_dependency(self, name: str) -> None:
        """Record ``name`` as a dependency of the current owner(s), if both are tracked."""
        if name not in self.definitions:
            return
        if self.processing_variable:
            # The assigned names themselves are not dependencies of the assignment.
            if name in self.current_variable_names:
                return
            # Every name bound by the statement depends on the referenced name:
            # the statement is kept or dropped as a whole.
            owners = tuple(self.current_variable_names)
        else:
            owners = (self.current_top_level_name,)
        for owner in owners:
            # The membership check avoids a KeyError for owners that were never
            # collected (e.g. a method that is not declared directly in a class body).
            if owner and owner != name and owner in self.definitions:
                self.definitions[owner].dependencies.add(name)

    def _begin_top_level_variable(self, names: list[str]) -> None:
        """Start tracking a module-level assignment that binds ``names``."""
        if self.function_depth > 0 or self.class_depth > 0:
            return
        tracked_names = [name for name in names if name in self.definitions]
        if not tracked_names:
            return
        self.processing_variable = True
        self.current_variable_names.update(tracked_names)
        self.current_top_level_name = tracked_names[0]

    def _end_top_level_variable(self) -> None:
        """Stop tracking the module-level assignment started by ``_begin_top_level_variable``."""
        if self.processing_variable:
            self.processing_variable = False
            self.current_variable_names.clear()
            self.current_top_level_name = ""

    def _collect_annotation_dependencies(self, annotation: cst.Annotation) -> None:
        """Extract dependencies from a type annotation wrapper."""
        if hasattr(annotation, "annotation"):
            self._extract_names_from_annotation(annotation.annotation)

    def _extract_names_from_annotation(self, node: cst.CSTNode) -> None:
        """Record every tracked name referenced by a type annotation expression."""
        if isinstance(node, cst.Name):
            self._record_dependency(node.value)
        elif isinstance(node, cst.Subscript):
            # e.g. List[CustomType]: visit both the container and its arguments.
            self._extract_names_from_annotation(node.value)
            for slice_item in node.slice:
                self._extract_names_from_annotation(slice_item.slice)
        elif isinstance(node, cst.Index):
            # Subscript arguments are wrapped in Index; the expression is in ``value``.
            self._extract_names_from_annotation(node.value)
        elif isinstance(node, (cst.Tuple, cst.List)):
            # e.g. Callable[[A, B], C] or Dict[(K, V)]
            for element in node.elements:
                self._extract_names_from_annotation(element.value)
        elif isinstance(node, cst.BinaryOperation):
            # e.g. A | B
            self._extract_names_from_annotation(node.left)
            self._extract_names_from_annotation(node.right)
        elif isinstance(node, cst.Attribute):
            # module.Type: only the base can be a top-level definition.
            self._extract_names_from_annotation(node.value)

    # ------------------------------------------------------- functions / classes

    def visit_FunctionDef(self, node: cst.FunctionDef) -> None:
        if self.function_depth == 0:
            function_name = node.name.value
            if self.class_depth > 0:
                qualified_name = f"{self.current_class}.{function_name}"
                if self.class_depth == 1 and qualified_name in self.definitions:
                    self.current_top_level_name = qualified_name
                else:
                    # Methods that were not collected (nested classes, conditionally
                    # defined methods) are attributed to the enclosing top-level class.
                    self.current_top_level_name = self.current_class
            else:
                self.current_top_level_name = function_name
            # Parameter annotations are dependencies of the function itself.
            for param in node.params.params:
                if param.annotation:
                    self._collect_annotation_dependencies(param.annotation)
        self.function_depth += 1

    def leave_FunctionDef(self, original_node: cst.FunctionDef) -> None:
        self.function_depth -= 1
        if self.function_depth == 0:
            # Back at class level the class becomes the owner again, so statements
            # that follow a method are not attributed to that method.
            self.current_top_level_name = self.current_class if self.class_depth > 0 else ""

    def visit_ClassDef(self, node: cst.ClassDef) -> None:
        # A class defined inside a function is part of that function; it must not
        # replace the function as the current owner.
        if self.class_depth == 0 and self.function_depth == 0:
            class_name = node.name.value
            self.current_class = class_name
            self.current_top_level_name = class_name
        self.class_depth += 1

    def leave_ClassDef(self, original_node: cst.ClassDef) -> None:
        self.class_depth -= 1
        if self.class_depth == 0 and self.function_depth == 0:
            self.current_class = ""
            self.current_top_level_name = ""

    # ---------------------------------------------------- module-level variables

    def visit_Assign(self, node: cst.Assign) -> None:
        names: list[str] = []
        for target in node.targets:
            names.extend(extract_names_from_targets(target.target))
        self._begin_top_level_variable(names)

    def leave_Assign(self, original_node: cst.Assign) -> None:
        self._end_top_level_variable()

    def visit_AnnAssign(self, node: cst.AnnAssign) -> None:
        # The annotation and the value are both children of this node, so the
        # names inside them are recorded by visit_Name while it is the owner.
        self._begin_top_level_variable(extract_names_from_targets(node.target))

    def leave_AnnAssign(self, original_node: cst.AnnAssign) -> None:
        self._end_top_level_variable()

    def visit_AugAssign(self, node: cst.AugAssign) -> None:
        self._begin_top_level_variable(extract_names_from_targets(node.target))

    def leave_AugAssign(self, original_node: cst.AugAssign) -> None:
        self._end_top_level_variable()

    # --------------------------------------------------------------- references

    def visit_Name(self, node: cst.Name) -> None:
        self._record_dependency(node.value)
class QualifiedFunctionUsageMarker:
    """Marks definitions that are used by specific qualified functions.

    Args:
        definitions: Mapping of definition name to its usage record; the
            records are updated in place.
        qualified_function_names: Names of the functions to keep. Methods are
            given as ``"ClassName.method_name"``.
    """

    def __init__(self, definitions: dict[str, UsageInfo], qualified_function_names: set[str]) -> None:
        self.definitions = definitions
        self.qualified_function_names = qualified_function_names
        self.expanded_qualified_functions = self._expand_qualified_functions()

    def _expand_qualified_functions(self) -> set[str]:
        """Return the requested names plus, for each method, its class and that class's dunder methods."""
        expanded = set(self.qualified_function_names)
        for qualified_name in self.qualified_function_names:
            class_name, separator, _ = qualified_name.partition(".")
            if not separator:
                continue  # plain function, nothing to expand
            expanded.add(class_name)
            dunder_prefix = f"{class_name}.__"
            expanded.update(
                name for name in self.definitions if name.startswith(dunder_prefix) and name.endswith("__")
            )
        return expanded

    def mark_used_definitions(self) -> None:
        """Mark every requested (and expanded) definition and everything it depends on as used."""
        for func_name in self.expanded_qualified_functions:
            info = self.definitions.get(func_name)
            if info is None:
                continue
            info.used_by_qualified_function = True
            for dep in info.dependencies:
                self.mark_as_used_recursively(dep)

    def mark_as_used_recursively(self, name: str) -> None:
        """Mark ``name`` and all of its transitive dependencies as used.

        Unknown names are ignored, and traversal stops at definitions that are
        already marked. An explicit work list is used instead of recursion so
        that long dependency chains cannot hit the interpreter's recursion limit.
        """
        pending = [name]
        while pending:
            info = self.definitions.get(pending.pop())
            if info is None or info.used_by_qualified_function:
                continue
            info.used_by_qualified_function = True
            pending.extend(info.dependencies)
def remove_unused_definitions_recursively(
    node: cst.CSTNode, definitions: dict[str, UsageInfo]
) -> tuple[cst.CSTNode | None, bool]:
    """Recursively filter the node to remove unused definitions.

    Imports, function definitions and class definitions are always kept
    unchanged. Assignments (plain, annotated, augmented) are kept only when at
    least one name they bind is marked as used. Every other node is rebuilt
    from the filtered contents of its ``body``/``orelse``/``finalbody``/
    ``handlers`` sections and dropped when none of them holds anything used.

    NOTE: class bodies are deliberately not pruned. The previous per-statement
    class-variable filter could never match, because the statements in a class
    body are wrapped in statement-line nodes rather than being bare assignment
    nodes; the observable behaviour (the whole class is kept) is now explicit.

    Args:
        node: The CST node to process.
        definitions: Usage information per definition name.

    Returns:
        (filtered_node, used_by_function):
            filtered_node: The node to keep (possibly rebuilt), or None if it should be removed.
            used_by_function: True if this node or any child must be preserved.
    """
    # Imports, functions and classes are never removed. Tuples are used instead of
    # `A | B` so the isinstance checks also work on interpreters older than 3.10.
    if isinstance(node, (cst.Import, cst.ImportFrom, cst.FunctionDef, cst.ClassDef)):
        return node, True

    # Variable definitions: keep the statement if any bound name is used.
    if isinstance(node, (cst.Assign, cst.AnnAssign, cst.AugAssign)):
        if isinstance(node, cst.Assign):
            names = [name for target in node.targets for name in extract_names_from_targets(target.target)]
        else:
            names = extract_names_from_targets(node.target)
        if any(name in definitions and definitions[name].used_by_qualified_function for name in names):
            return node, True
        return None, False

    # Leaf statements without nested sections are kept as-is but do not, on
    # their own, make the enclosing block count as used.
    section_names = get_section_names(node)
    if not section_names:
        return node, False

    updates = {}
    found_used = False
    for section in section_names:
        original_content = getattr(node, section, None)
        if isinstance(original_content, (list, tuple)):
            new_children = []
            section_found_used = False
            for child in original_content:
                filtered, used = remove_unused_definitions_recursively(child, definitions)
                if filtered:
                    new_children.append(filtered)
                section_found_used |= used
            # A section whose children were all removed is left untouched rather
            # than replaced by an empty list.
            if new_children or section_found_used:
                found_used |= section_found_used
                updates[section] = new_children
        elif original_content is not None:
            filtered, used = remove_unused_definitions_recursively(original_content, definitions)
            found_used |= used
            if filtered:
                updates[section] = filtered

    # Nothing below this node is needed: drop it entirely.
    if not found_used:
        return None, False
    if updates:
        return node.with_changes(**updates), found_used
    return node, False
def remove_unused_definitions_by_function_names(code: str, qualified_function_names: set[str]) -> str:
    """Strip top-level definitions that the given functions do not need.

    The source is parsed, its top-level definitions (classes, variables,
    functions) and the dependencies between them are collected, everything
    reachable from the requested functions is marked as used, and the tree is
    then filtered accordingly.

    Args:
        code: The source code to process.
        qualified_function_names: Names of the functions to keep. Methods use
            the form 'classname.methodname'.

    Returns:
        The filtered source code, or an empty string when filtering leaves no module.
    """
    tree = cst.parse_module(code)

    # Step 1: find what is defined at the top level.
    usage_by_name = collect_top_level_definitions(tree)

    # Step 2: work out which definition refers to which.
    tree.visit(DependencyCollector(usage_by_name))

    # Step 3: flag the requested functions and everything they reach.
    QualifiedFunctionUsageMarker(usage_by_name, qualified_function_names).mark_used_definitions()

    # Step 4: drop whatever stayed unflagged.
    pruned_tree, _ = remove_unused_definitions_recursively(tree, usage_by_name)
    if not pruned_tree:
        return ""
    return pruned_tree.code
def print_definitions(definitions: dict[str, UsageInfo]) -> None:
    """Dump each definition's name, usage flag and dependencies to stdout (debugging aid).

    Entries are listed in ascending name order, each followed by a blank line.
    """
    print(f"Found {len(definitions)} definitions:")
    for name in sorted(definitions):
        info = definitions[name]
        if info.dependencies:
            dependency_text = ", ".join(sorted(info.dependencies))
        else:
            dependency_text = "None"
        print(f" - Name: {name}")
        print(f" Used by qualified function: {info.used_by_qualified_function}")
        print(f" Dependencies: {dependency_text}")
        print()

View file

@ -929,9 +929,6 @@ def fetch_and_process_data():
"""
expected_read_only_context = f"""
```python:{path_to_utils.relative_to(project_root)}
GLOBAL_VAR = 10
class DataProcessor:
\"\"\"A class for processing data.\"\"\"
@ -941,11 +938,6 @@ class DataProcessor:
\"\"\"Return a string representation of the DataProcessor.\"\"\"
return f"DataProcessor(default_prefix={{self.default_prefix!r}})"
```
```python:{path_to_file.relative_to(project_root)}
if __name__ == "__main__":
result = fetch_and_process_data()
print("Processed data:", result)
```
"""
assert read_write_context.strip() == expected_read_write_context.strip()
assert read_only_context.strip() == expected_read_only_context.strip()
@ -1006,9 +998,6 @@ def fetch_and_transform_data():
"""
expected_read_only_context = f"""
```python:{path_to_utils.relative_to(project_root)}
GLOBAL_VAR = 10
class DataProcessor:
\"\"\"A class for processing data.\"\"\"
@ -1018,11 +1007,6 @@ class DataProcessor:
\"\"\"Return a string representation of the DataProcessor.\"\"\"
return f"DataProcessor(default_prefix={{self.default_prefix!r}})"
```
```python:{path_to_file.relative_to(project_root)}
if __name__ == "__main__":
result = fetch_and_process_data()
print("Processed data:", result)
```
```python:{path_to_transform_utils.relative_to(project_root)}
class DataTransformer:
@ -1084,9 +1068,6 @@ class DataTransformer:
return self.data
```
```python:{path_to_utils.relative_to(project_root)}
GLOBAL_VAR = 10
class DataProcessor:
\"\"\"A class for processing data.\"\"\"
@ -1147,9 +1128,6 @@ def update_data(data):
return data + " updated"
```
```python:{path_to_utils.relative_to(project_root)}
GLOBAL_VAR = 10
class DataProcessor:
\"\"\"A class for processing data.\"\"\"
@ -1252,9 +1230,6 @@ class DataTransformer:
"""
expected_read_only_context = f"""
```python:{path_to_utils.relative_to(project_root)}
GLOBAL_VAR = 10
class DataProcessor:
\"\"\"A class for processing data.\"\"\"
@ -1322,4 +1297,20 @@ def outside_method():
```
"""
assert read_write_context.strip() == expected_read_write_context.strip()
assert read_only_context.strip() == expected_read_only_context.strip()
assert read_only_context.strip() == expected_read_only_context.strip()
def test_direct_module_import() -> None:
    """Build the optimization context for ``function_to_optimize`` in ``import_test.py``.

    NOTE(review): this test contains no assertions - it only prints the
    read-only context, so it can fail only if context extraction raises.
    Presumably expected-context assertions are still to be added; confirm.
    """
    project_root = Path(__file__).resolve().parent.parent / "code_to_optimize" / "code_directories" / "retriever"
    # NOTE(review): path_to_main is not used below.
    path_to_main = project_root / "main.py"
    path_to_fto = project_root / "import_test.py"
    function_to_optimize = FunctionToOptimize(
        function_name="function_to_optimize",
        file_path=str(path_to_fto),
        parents=[],
        starting_line=None,
        ending_line=None,
    )
    code_ctx = get_code_optimization_context(function_to_optimize, project_root)
    # NOTE(review): read_write_context is unpacked but never checked.
    read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code
    print(read_only_context.strip())

View file

@ -0,0 +1,416 @@
import libcst as cst
from codeflash.context.unused_definition_remover import remove_unused_definitions_by_function_names
def test_variable_removal_only() -> None:
"""Test that only variables not used by specified functions are removed, not functions."""
code = """
def main_function():
return USED_CONSTANT + 10
def helper_function():
return 42
USED_CONSTANT = 42
UNUSED_CONSTANT = 123
def another_function():
return UNUSED_CONSTANT
"""
expected = """
def main_function():
return USED_CONSTANT + 10
def helper_function():
return 42
USED_CONSTANT = 42
def another_function():
return UNUSED_CONSTANT
"""
qualified_functions = {"main_function"}
result = remove_unused_definitions_by_function_names(code, qualified_functions)
# Normalize whitespace for comparison
assert result.strip() == expected.strip()
def test_class_variable_removal() -> None:
"""Test that only class variables not used by specified functions are removed, not methods."""
code = """
class MyClass:
CLASS_USED = "used value"
CLASS_UNUSED = "unused value"
def __init__(self):
self.value = self.CLASS_USED
self.other = self.CLASS_UNUSED
def used_method(self):
return self.value
def unused_method(self):
return "Not used but not removed"
GLOBAL_USED = "global used"
GLOBAL_UNUSED = "global unused"
def helper_function():
return MyClass().used_method() + GLOBAL_USED
"""
expected = """
class MyClass:
CLASS_USED = "used value"
CLASS_UNUSED = "unused value"
def __init__(self):
self.value = self.CLASS_USED
self.other = self.CLASS_UNUSED
def used_method(self):
return self.value
def unused_method(self):
return "Not used but not removed"
GLOBAL_USED = "global used"
def helper_function():
return MyClass().used_method() + GLOBAL_USED
"""
qualified_functions = {"helper_function"}
result = remove_unused_definitions_by_function_names(code, qualified_functions)
# Normalize whitespace for comparison
assert result.strip() == expected.strip()
def test_complex_variable_dependencies() -> None:
"""Test that only variables with complex dependencies are properly handled."""
code = """
def main_function():
return DIRECT_DEPENDENCY
def unused_function():
return "Not used but not removed"
DIRECT_DEPENDENCY = INDIRECT_DEPENDENCY + "_suffix"
INDIRECT_DEPENDENCY = "base value"
UNUSED_VARIABLE = "This should be removed"
TUPLE_USED, TUPLE_UNUSED = ("used", "unused")
def tuple_user():
return TUPLE_USED
"""
expected = """
def main_function():
return DIRECT_DEPENDENCY
def unused_function():
return "Not used but not removed"
DIRECT_DEPENDENCY = INDIRECT_DEPENDENCY + "_suffix"
INDIRECT_DEPENDENCY = "base value"
def tuple_user():
return TUPLE_USED
"""
qualified_functions = {"main_function"}
result = remove_unused_definitions_by_function_names(code, qualified_functions)
assert result.strip() == expected.strip()
def test_type_annotation_usage() -> None:
"""Test that variables used in type annotations are considered used."""
code = """
# Type definition
CustomType = int
UnusedType = str
def main_function(param: CustomType) -> CustomType:
return param + 10
def unused_function(param: UnusedType) -> UnusedType:
return param + " suffix"
UNUSED_CONSTANT = 123
"""
expected = """
# Type definition
CustomType = int
def main_function(param: CustomType) -> CustomType:
return param + 10
def unused_function(param: UnusedType) -> UnusedType:
return param + " suffix"
"""
qualified_functions = {"main_function"}
result = remove_unused_definitions_by_function_names(code, qualified_functions)
# Normalize whitespace for comparison
assert result.strip() == expected.strip()
def test_class_method_with_dunder_methods() -> None:
"""Test that when a class method is used, dunder methods of that class are preserved."""
code = """
class MyClass:
CLASS_VAR = "class variable"
UNUSED_VAR = GLOBAL_VAR_2
def __init__(self, value):
self.value = GLOBAL_VAR
def __str__(self):
return f"MyClass({self.value})"
def target_method(self):
return self.value * 2
def unused_method(self):
return "Not used"
GLOBAL_VAR = "global"
GLOBAL_VAR_2 = "global"
UNUSED_GLOBAL = "unused global"
def helper_function():
obj = MyClass(5)
return obj.target_method()
"""
expected = """
class MyClass:
CLASS_VAR = "class variable"
UNUSED_VAR = GLOBAL_VAR_2
def __init__(self, value):
self.value = GLOBAL_VAR
def __str__(self):
return f"MyClass({self.value})"
def target_method(self):
return self.value * 2
def unused_method(self):
return "Not used"
GLOBAL_VAR = "global"
GLOBAL_VAR_2 = "global"
def helper_function():
obj = MyClass(5)
return obj.target_method()
"""
qualified_functions = {"MyClass.target_method"}
result = remove_unused_definitions_by_function_names(code, qualified_functions)
# Normalize whitespace for comparison
assert result.strip() == expected.strip()
def test_complex_type_annotations() -> None:
"""Test complex type annotations with nested types."""
code = """
from typing import List, Dict, Optional
# Type aliases
ItemType = Dict[str, int]
ResultType = List[ItemType]
UnusedType = Optional[str]
def process_data(items: ResultType) -> int:
total = 0
for item in items:
for key, value in item.items():
total += value
return total
def unused_function(param: UnusedType) -> None:
pass
# Variables
SAMPLE_DATA: ResultType = [{"a": 1, "b": 2}]
UNUSED_DATA: UnusedType = None
"""
expected = """
from typing import List, Dict, Optional
# Type aliases
ItemType = Dict[str, int]
ResultType = List[ItemType]
def process_data(items: ResultType) -> int:
total = 0
for item in items:
for key, value in item.items():
total += value
return total
def unused_function(param: UnusedType) -> None:
pass
"""
qualified_functions = {"process_data"}
result = remove_unused_definitions_by_function_names(code, qualified_functions)
assert result.strip() == expected.strip()
def test_try_except_finally_variables() -> None:
"""Test handling of variables defined in try-except-finally blocks."""
code = """
import math
import os
# Top-level try-except that defines variables
try:
MATH_CONSTANT = math.pi
USED_ERROR_MSG = "An error occurred"
UNUSED_CONST = 42
except ImportError:
MATH_CONSTANT = 3.14
USED_ERROR_MSG = "Math module not available"
UNUSED_CONST = 0
finally:
CLEANUP_FLAG = True
UNUSED_CLEANUP = "Not used"
def use_constants():
return f"Pi is approximately {MATH_CONSTANT}, message: {USED_ERROR_MSG}"
def use_cleanup():
if CLEANUP_FLAG:
return "Cleanup performed"
return "No cleanup"
def unused_function():
return UNUSED_CONST
"""
expected = """
import math
import os
# Top-level try-except that defines variables
try:
MATH_CONSTANT = math.pi
USED_ERROR_MSG = "An error occurred"
except ImportError:
MATH_CONSTANT = 3.14
USED_ERROR_MSG = "Math module not available"
finally:
CLEANUP_FLAG = True
def use_constants():
return f"Pi is approximately {MATH_CONSTANT}, message: {USED_ERROR_MSG}"
def use_cleanup():
if CLEANUP_FLAG:
return "Cleanup performed"
return "No cleanup"
def unused_function():
return UNUSED_CONST
"""
qualified_functions = {"use_constants", "use_cleanup"}
result = remove_unused_definitions_by_function_names(code, qualified_functions)
assert result.strip() == expected.strip()
def test_conditional_and_loop_variables() -> None:
"""Test handling of variables defined in if-else and while loops."""
code = """
import sys
import platform
# Top-level if-else block defining variables
if sys.platform.startswith('win'):
OS_TYPE = "Windows"
OS_SEP = ""
UNUSED_WIN_VAR = "Unused Windows variable"
elif sys.platform.startswith('linux'):
OS_TYPE = "Linux"
OS_SEP = "/"
UNUSED_LINUX_VAR = "Unused Linux variable"
else:
OS_TYPE = "Other"
OS_SEP = "/"
UNUSED_OTHER_VAR = "Unused other variable"
# While loop with variable definitions
counter = 0
while counter < 5:
LOOP_RESULT = "Iteration " + str(counter)
UNUSED_LOOP_VAR = "Unused loop " + str(counter)
counter += 1
def get_platform_info():
return "OS: " + OS_TYPE + ", Separator: " + OS_SEP
def get_loop_result():
return LOOP_RESULT
def unused_function():
result = ""
if sys.platform.startswith('win'):
result = UNUSED_WIN_VAR
elif sys.platform.startswith('linux'):
result = UNUSED_LINUX_VAR
else:
result = UNUSED_OTHER_VAR
return result
"""
expected = """
import sys
import platform
# Top-level if-else block defining variables
if sys.platform.startswith('win'):
OS_TYPE = "Windows"
OS_SEP = ""
elif sys.platform.startswith('linux'):
OS_TYPE = "Linux"
OS_SEP = "/"
else:
OS_TYPE = "Other"
OS_SEP = "/"
# While loop with variable definitions
counter = 0
while counter < 5:
LOOP_RESULT = "Iteration " + str(counter)
counter += 1
def get_platform_info():
return "OS: " + OS_TYPE + ", Separator: " + OS_SEP
def get_loop_result():
return LOOP_RESULT
def unused_function():
result = ""
if sys.platform.startswith('win'):
result = UNUSED_WIN_VAR
elif sys.platform.startswith('linux'):
result = UNUSED_LINUX_VAR
else:
result = UNUSED_OTHER_VAR
return result
"""
qualified_functions = {"get_platform_info", "get_loop_result"}
result = remove_unused_definitions_by_function_names(code, qualified_functions)
assert result.strip() == expected.strip()