codeflash/tests/test_remove_unused_definitions.py
Kevin Turcios 5671562da2 perf: eliminate redundant CST parsing in get_code_optimization_context
Parse each file once instead of up to 16 times by:
- Making remove_unused_definitions_by_function_names accept/return cst.Module
- Making parse_code_and_prune_cst and add_needed_imports_from_module accept cst.Module
- Threading the parsed Module through process_file_context
- Adding extract_all_contexts_from_files that processes all 4 context types
  (READ_WRITABLE, READ_ONLY, HASHING, TESTGEN) in a single per-file pass
2026-03-16 10:11:58 -06:00

558 lines
14 KiB
Python

from codeflash.languages.python.context.unused_definition_remover import remove_unused_definitions_by_function_names
def test_variable_removal_only() -> None:
    """Unused module-level variables are removed; function definitions never are.

    Only ``main_function`` is targeted, so ``UNUSED_CONSTANT`` is dropped even
    though the untargeted ``another_function`` still references it, while every
    function definition (targeted or not) stays in the output.
    """
    # Fixture snippets start at column 0 so they parse as module-level code.
    code = """
def main_function():
    return USED_CONSTANT + 10
def helper_function():
    return 42
USED_CONSTANT = 42
UNUSED_CONSTANT = 123
def another_function():
    return UNUSED_CONSTANT
"""
    expected = """
def main_function():
    return USED_CONSTANT + 10
def helper_function():
    return 42
USED_CONSTANT = 42
def another_function():
    return UNUSED_CONSTANT
"""
    qualified_functions = {"main_function"}
    result = remove_unused_definitions_by_function_names(code, qualified_functions)
    # Normalize whitespace for comparison
    assert result.code.strip() == expected.strip()


def test_class_variable_removal() -> None:
    """Class bodies are kept intact; only unused module-level variables go.

    With ``helper_function`` as the target, ``MyClass`` keeps every class
    variable and method (including ``CLASS_UNUSED`` and ``unused_method``),
    ``GLOBAL_USED`` is kept, and ``GLOBAL_UNUSED`` is the only line removed.
    """
    # Fixture snippets start at column 0 so they parse as module-level code.
    code = """
class MyClass:
    CLASS_USED = "used value"
    CLASS_UNUSED = "unused value"
    def __init__(self):
        self.value = self.CLASS_USED
        self.other = self.CLASS_UNUSED
    def used_method(self):
        return self.value
    def unused_method(self):
        return "Not used but not removed"
GLOBAL_USED = "global used"
GLOBAL_UNUSED = "global unused"
def helper_function():
    return MyClass().used_method() + GLOBAL_USED
"""
    expected = """
class MyClass:
    CLASS_USED = "used value"
    CLASS_UNUSED = "unused value"
    def __init__(self):
        self.value = self.CLASS_USED
        self.other = self.CLASS_UNUSED
    def used_method(self):
        return self.value
    def unused_method(self):
        return "Not used but not removed"
GLOBAL_USED = "global used"
def helper_function():
    return MyClass().used_method() + GLOBAL_USED
"""
    qualified_functions = {"helper_function"}
    result = remove_unused_definitions_by_function_names(code, qualified_functions)
    # Normalize whitespace for comparison
    assert result.code.strip() == expected.strip()


def test_complex_variable_dependencies() -> None:
    """Variables reached only through other variables are still kept.

    ``main_function`` uses ``DIRECT_DEPENDENCY``, whose definition in turn
    reads ``INDIRECT_DEPENDENCY``; both must survive. ``UNUSED_VARIABLE`` and
    the tuple assignment (referenced only by the untargeted ``tuple_user``)
    are removed.
    """
    # Fixture snippets start at column 0 so they parse as module-level code.
    code = """
def main_function():
    return DIRECT_DEPENDENCY
def unused_function():
    return "Not used but not removed"
DIRECT_DEPENDENCY = INDIRECT_DEPENDENCY + "_suffix"
INDIRECT_DEPENDENCY = "base value"
UNUSED_VARIABLE = "This should be removed"
TUPLE_USED, TUPLE_UNUSED = ("used", "unused")
def tuple_user():
    return TUPLE_USED
"""
    expected = """
def main_function():
    return DIRECT_DEPENDENCY
def unused_function():
    return "Not used but not removed"
DIRECT_DEPENDENCY = INDIRECT_DEPENDENCY + "_suffix"
INDIRECT_DEPENDENCY = "base value"
def tuple_user():
    return TUPLE_USED
"""
    qualified_functions = {"main_function"}
    result = remove_unused_definitions_by_function_names(code, qualified_functions)
    assert result.code.strip() == expected.strip()


def test_type_annotation_usage() -> None:
    """Names used only in a target's type annotations count as used.

    ``CustomType`` appears solely in ``main_function``'s parameter and return
    annotations and must be kept. ``UnusedType`` is referenced only by the
    untargeted ``unused_function`` and is removed, as is ``UNUSED_CONSTANT``.
    """
    # Fixture snippets start at column 0 so they parse as module-level code.
    code = """
# Type definition
CustomType = int
UnusedType = str
def main_function(param: CustomType) -> CustomType:
    return param + 10
def unused_function(param: UnusedType) -> UnusedType:
    return param + " suffix"
UNUSED_CONSTANT = 123
"""
    expected = """
# Type definition
CustomType = int
def main_function(param: CustomType) -> CustomType:
    return param + 10
def unused_function(param: UnusedType) -> UnusedType:
    return param + " suffix"
"""
    qualified_functions = {"main_function"}
    result = remove_unused_definitions_by_function_names(code, qualified_functions)
    # Normalize whitespace for comparison
    assert result.code.strip() == expected.strip()


def test_class_method_with_dunder_methods() -> None:
    """Targeting one method keeps the whole class and the globals it reads.

    The target is ``MyClass.target_method``. The expected output keeps the
    full class body (dunder methods, ``unused_method`` and class variables),
    keeps ``GLOBAL_VAR`` (read in ``__init__``) and ``GLOBAL_VAR_2`` (read by
    the class variable ``UNUSED_VAR``), and drops only ``UNUSED_GLOBAL``.
    """
    # Fixture snippets start at column 0 so they parse as module-level code.
    code = """
class MyClass:
    CLASS_VAR = "class variable"
    UNUSED_VAR = GLOBAL_VAR_2
    def __init__(self, value):
        self.value = GLOBAL_VAR
    def __str__(self):
        return f"MyClass({self.value})"
    def target_method(self):
        return self.value * 2
    def unused_method(self):
        return "Not used"
GLOBAL_VAR = "global"
GLOBAL_VAR_2 = "global"
UNUSED_GLOBAL = "unused global"
def helper_function():
    obj = MyClass(5)
    return obj.target_method()
"""
    expected = """
class MyClass:
    CLASS_VAR = "class variable"
    UNUSED_VAR = GLOBAL_VAR_2
    def __init__(self, value):
        self.value = GLOBAL_VAR
    def __str__(self):
        return f"MyClass({self.value})"
    def target_method(self):
        return self.value * 2
    def unused_method(self):
        return "Not used"
GLOBAL_VAR = "global"
GLOBAL_VAR_2 = "global"
def helper_function():
    obj = MyClass(5)
    return obj.target_method()
"""
    qualified_functions = {"MyClass.target_method"}
    result = remove_unused_definitions_by_function_names(code, qualified_functions)
    # Normalize whitespace for comparison
    assert result.code.strip() == expected.strip()


def test_complex_type_annotations() -> None:
    """Nested type aliases are followed transitively from annotations.

    ``process_data`` is annotated with ``ResultType``, which is built from
    ``ItemType``; both aliases are kept along with the ``typing`` import.
    ``UnusedType``, ``SAMPLE_DATA`` and ``UNUSED_DATA`` are removed, and the
    ``# Variables`` comment that preceded them is absent from the output.
    """
    # Fixture snippets start at column 0 so they parse as module-level code.
    code = """
from typing import List, Dict, Optional
# Type aliases
ItemType = Dict[str, int]
ResultType = List[ItemType]
UnusedType = Optional[str]
def process_data(items: ResultType) -> int:
    total = 0
    for item in items:
        for key, value in item.items():
            total += value
    return total
def unused_function(param: UnusedType) -> None:
    pass
# Variables
SAMPLE_DATA: ResultType = [{"a": 1, "b": 2}]
UNUSED_DATA: UnusedType = None
"""
    expected = """
from typing import List, Dict, Optional
# Type aliases
ItemType = Dict[str, int]
ResultType = List[ItemType]
def process_data(items: ResultType) -> int:
    total = 0
    for item in items:
        for key, value in item.items():
            total += value
    return total
def unused_function(param: UnusedType) -> None:
    pass
"""
    qualified_functions = {"process_data"}
    result = remove_unused_definitions_by_function_names(code, qualified_functions)
    assert result.code.strip() == expected.strip()


def test_try_except_finally_variables() -> None:
    """Unused assignments are pruned inside top-level try/except/finally.

    The targets ``use_constants`` and ``use_cleanup`` read ``MATH_CONSTANT``,
    ``USED_ERROR_MSG`` and ``CLEANUP_FLAG``, so those assignments stay in
    every branch. ``UNUSED_CONST`` and ``UNUSED_CLEANUP`` are removed from the
    branches that define them; the block structure and imports are kept.
    """
    # Fixture snippets start at column 0 so they parse as module-level code.
    code = """
import math
import os
# Top-level try-except that defines variables
try:
    MATH_CONSTANT = math.pi
    USED_ERROR_MSG = "An error occurred"
    UNUSED_CONST = 42
except ImportError:
    MATH_CONSTANT = 3.14
    USED_ERROR_MSG = "Math module not available"
    UNUSED_CONST = 0
finally:
    CLEANUP_FLAG = True
    UNUSED_CLEANUP = "Not used"
def use_constants():
    return f"Pi is approximately {MATH_CONSTANT}, message: {USED_ERROR_MSG}"
def use_cleanup():
    if CLEANUP_FLAG:
        return "Cleanup performed"
    return "No cleanup"
def unused_function():
    return UNUSED_CONST
"""
    expected = """
import math
import os
# Top-level try-except that defines variables
try:
    MATH_CONSTANT = math.pi
    USED_ERROR_MSG = "An error occurred"
except ImportError:
    MATH_CONSTANT = 3.14
    USED_ERROR_MSG = "Math module not available"
finally:
    CLEANUP_FLAG = True
def use_constants():
    return f"Pi is approximately {MATH_CONSTANT}, message: {USED_ERROR_MSG}"
def use_cleanup():
    if CLEANUP_FLAG:
        return "Cleanup performed"
    return "No cleanup"
def unused_function():
    return UNUSED_CONST
"""
    qualified_functions = {"use_constants", "use_cleanup"}
    result = remove_unused_definitions_by_function_names(code, qualified_functions)
    assert result.code.strip() == expected.strip()


def test_base_class_inheritance() -> None:
    """Base classes referenced only through inheritance are preserved.

    ``test_function`` instantiates ``ObjectDetectionLayoutDumper``, which
    derives from ``LayoutDumper``. The expected output is identical to the
    input: no class definition is removed, including ``UnusedClass``.
    """
    # Fixture snippets start at column 0 so they parse as module-level code.
    code = """
class LayoutDumper:
    def dump(self):
        raise NotImplementedError
class ObjectDetectionLayoutDumper(LayoutDumper):
    def __init__(self, data):
        self.data = data
    def dump(self):
        return self.data
class ExtractedLayoutDumper(LayoutDumper):
    def __init__(self, data):
        self.data = data
    def dump(self):
        return self.data
class UnusedClass:
    pass
def test_function():
    dumper = ObjectDetectionLayoutDumper({})
    return dumper.dump()
"""
    expected = """
class LayoutDumper:
    def dump(self):
        raise NotImplementedError
class ObjectDetectionLayoutDumper(LayoutDumper):
    def __init__(self, data):
        self.data = data
    def dump(self):
        return self.data
class ExtractedLayoutDumper(LayoutDumper):
    def __init__(self, data):
        self.data = data
    def dump(self):
        return self.data
class UnusedClass:
    pass
def test_function():
    dumper = ObjectDetectionLayoutDumper({})
    return dumper.dump()
"""
    qualified_functions = {"test_function"}
    result = remove_unused_definitions_by_function_names(code, qualified_functions)
    # LayoutDumper should be preserved because ObjectDetectionLayoutDumper inherits from it
    assert "class LayoutDumper" in result.code
    assert "class ObjectDetectionLayoutDumper" in result.code
    assert result.code.strip() == expected.strip()


def test_conditional_and_loop_variables() -> None:
    """Unused assignments are pruned inside top-level if/elif/else and while.

    The targets read ``OS_TYPE``, ``OS_SEP`` and ``LOOP_RESULT``. Those
    assignments, the ``counter`` bookkeeping and the surrounding control-flow
    statements are kept, while the ``UNUSED_*`` assignments in each branch and
    in the loop body are removed.
    """
    # Fixture snippets start at column 0 so they parse as module-level code.
    # NOTE(review): the Windows OS_SEP literal is an empty string here; it may
    # originally have been an escaped backslash - confirm against the real file.
    code = """
import sys
import platform
# Top-level if-else block defining variables
if sys.platform.startswith('win'):
    OS_TYPE = "Windows"
    OS_SEP = ""
    UNUSED_WIN_VAR = "Unused Windows variable"
elif sys.platform.startswith('linux'):
    OS_TYPE = "Linux"
    OS_SEP = "/"
    UNUSED_LINUX_VAR = "Unused Linux variable"
else:
    OS_TYPE = "Other"
    OS_SEP = "/"
    UNUSED_OTHER_VAR = "Unused other variable"
# While loop with variable definitions
counter = 0
while counter < 5:
    LOOP_RESULT = "Iteration " + str(counter)
    UNUSED_LOOP_VAR = "Unused loop " + str(counter)
    counter += 1
def get_platform_info():
    return "OS: " + OS_TYPE + ", Separator: " + OS_SEP
def get_loop_result():
    return LOOP_RESULT
def unused_function():
    result = ""
    if sys.platform.startswith('win'):
        result = UNUSED_WIN_VAR
    elif sys.platform.startswith('linux'):
        result = UNUSED_LINUX_VAR
    else:
        result = UNUSED_OTHER_VAR
    return result
"""
    expected = """
import sys
import platform
# Top-level if-else block defining variables
if sys.platform.startswith('win'):
    OS_TYPE = "Windows"
    OS_SEP = ""
elif sys.platform.startswith('linux'):
    OS_TYPE = "Linux"
    OS_SEP = "/"
else:
    OS_TYPE = "Other"
    OS_SEP = "/"
# While loop with variable definitions
counter = 0
while counter < 5:
    LOOP_RESULT = "Iteration " + str(counter)
    counter += 1
def get_platform_info():
    return "OS: " + OS_TYPE + ", Separator: " + OS_SEP
def get_loop_result():
    return LOOP_RESULT
def unused_function():
    result = ""
    if sys.platform.startswith('win'):
        result = UNUSED_WIN_VAR
    elif sys.platform.startswith('linux'):
        result = UNUSED_LINUX_VAR
    else:
        result = UNUSED_OTHER_VAR
    return result
"""
    qualified_functions = {"get_platform_info", "get_loop_result"}
    result = remove_unused_definitions_by_function_names(code, qualified_functions)
    assert result.code.strip() == expected.strip()


def test_enum_attribute_access_dependency() -> None:
    """Attribute access such as ``MessageKind.VALUE`` counts as a dependency.

    ``process_message`` refers to ``MessageKind`` only inside ``match``/``case``
    patterns; the class must still be kept. Both enum classes remain in the
    expected output, and only ``UNUSED_VAR`` is removed.
    """
    # Fixture snippets start at column 0 so they parse as module-level code.
    code = """
from enum import Enum
class MessageKind(Enum):
    VALUE = "value"
    OTHER = "other"
class UnusedEnum(Enum):
    UNUSED = "unused"
UNUSED_VAR = 123
def process_message(kind):
    match kind:
        case MessageKind.VALUE:
            return "got value"
        case MessageKind.OTHER:
            return "got other"
    return "unknown"
"""
    expected = """
from enum import Enum
class MessageKind(Enum):
    VALUE = "value"
    OTHER = "other"
class UnusedEnum(Enum):
    UNUSED = "unused"
def process_message(kind):
    match kind:
        case MessageKind.VALUE:
            return "got value"
        case MessageKind.OTHER:
            return "got other"
    return "unknown"
"""
    qualified_functions = {"process_message"}
    result = remove_unused_definitions_by_function_names(code, qualified_functions)
    # MessageKind should be preserved because process_message uses MessageKind.VALUE
    assert "class MessageKind" in result.code
    # UNUSED_VAR should be removed
    assert "UNUSED_VAR" not in result.code
    assert result.code.strip() == expected.strip()


def test_attribute_access_does_not_track_attr_name() -> None:
    """``self.x`` must not be mistaken for a reference to module-level ``x``.

    The targeted methods only touch the instance attribute ``self.x``, so the
    module-level ``x`` assignment is removed together with ``UNUSED_VAR``,
    leaving just the class definition.
    """
    # Fixture snippets start at column 0 so they parse as module-level code.
    code = """
x = "module_level_x"
UNUSED_VAR = "unused"
class MyClass:
    def __init__(self):
        self.x = 1 # This 'x' is an attribute, not a reference to module-level 'x'
    def get_x(self):
        return self.x # This 'x' is also an attribute access
"""
    expected = """
class MyClass:
    def __init__(self):
        self.x = 1 # This 'x' is an attribute, not a reference to module-level 'x'
    def get_x(self):
        return self.x # This 'x' is also an attribute access
"""
    qualified_functions = {"MyClass.get_x", "MyClass.__init__"}
    result = remove_unused_definitions_by_function_names(code, qualified_functions)
    # Module-level x should NOT be kept (self.x doesn't reference it)
    assert 'x = "module_level_x"' not in result.code
    # UNUSED_VAR should also be removed
    assert "UNUSED_VAR" not in result.code
    assert result.code.strip() == expected.strip()