fix: track base class dependencies in unused definition removal

Added base class dependency tracking as a secondary safeguard.
When a class inherits from another, the parent is now marked as
a dependency, preventing incorrect removal.
This commit is contained in:
Kevin Turcios 2026-01-01 02:02:57 -05:00
parent b1819ffcae
commit 7943f8d782
2 changed files with 80 additions and 0 deletions

View file

@ -172,6 +172,18 @@ class DependencyCollector(cst.CSTVisitor):
self.current_class = class_name
self.current_top_level_name = class_name
# Track base classes as dependencies
for base in node.bases:
if isinstance(base.value, cst.Name):
base_name = base.value.value
if base_name in self.definitions and class_name in self.definitions:
self.definitions[class_name].dependencies.add(base_name)
elif isinstance(base.value, cst.Attribute):
# Handle cases like module.ClassName
attr_name = base.value.attr.value
if attr_name in self.definitions and class_name in self.definitions:
self.definitions[class_name].dependencies.add(attr_name)
self.class_depth += 1
def leave_ClassDef(self, original_node: cst.ClassDef) -> None:

View file

@ -1250,3 +1250,71 @@ class TestSorter(unittest.TestCase):
result = remove_unused_definitions_from_pytest_file(cst.parse_module(code)).code
assert result.strip() == expected.strip()
def test_abstract_base_class_inheritance() -> None:
"""Test that abstract base classes used only for inheritance are preserved.
This mimics the real-world case where LLM generates mock classes that inherit
from a base class, and the base class should not be removed even though it's
only referenced in the inheritance declaration.
"""
code = """
class LayoutDumper:
layout_source: str = "unknown"
def dump(self) -> dict:
raise NotImplementedError()
class ObjectDetectionLayoutDumper(LayoutDumper):
def __init__(self, layout):
self._layout = layout
def dump(self) -> dict:
return self._layout
class ExtractedLayoutDumper(LayoutDumper):
def __init__(self, layout):
self._layout = layout
def dump(self) -> dict:
return self._layout
class UnusedClass:
pass
def test_object_detection():
dumper = ObjectDetectionLayoutDumper({})
assert dumper.dump() == {}
def test_extracted():
dumper = ExtractedLayoutDumper({"text": "hello"})
assert dumper.dump() == {"text": "hello"}
"""
expected = """
class LayoutDumper:
layout_source: str = "unknown"
def dump(self) -> dict:
raise NotImplementedError()
class ObjectDetectionLayoutDumper(LayoutDumper):
def __init__(self, layout):
self._layout = layout
def dump(self) -> dict:
return self._layout
class ExtractedLayoutDumper(LayoutDumper):
def __init__(self, layout):
self._layout = layout
def dump(self) -> dict:
return self._layout
def test_object_detection():
dumper = ObjectDetectionLayoutDumper({})
assert dumper.dump() == {}
def test_extracted():
dumper = ExtractedLayoutDumper({"text": "hello"})
assert dumper.dump() == {"text": "hello"}
"""
result = remove_unused_definitions_from_pytest_file(cst.parse_module(code)).code
assert result.strip() == expected.strip()