fix: track base class dependencies in unused definition removal
Added base class dependency tracking as a secondary safeguard. When a class inherits from another, the parent is now marked as a dependency, preventing incorrect removal.
This commit is contained in:
parent
b1819ffcae
commit
7943f8d782
2 changed files with 80 additions and 0 deletions
|
|
@ -172,6 +172,18 @@ class DependencyCollector(cst.CSTVisitor):
|
|||
self.current_class = class_name
|
||||
self.current_top_level_name = class_name
|
||||
|
||||
# Track base classes as dependencies
|
||||
for base in node.bases:
|
||||
if isinstance(base.value, cst.Name):
|
||||
base_name = base.value.value
|
||||
if base_name in self.definitions and class_name in self.definitions:
|
||||
self.definitions[class_name].dependencies.add(base_name)
|
||||
elif isinstance(base.value, cst.Attribute):
|
||||
# Handle cases like module.ClassName
|
||||
attr_name = base.value.attr.value
|
||||
if attr_name in self.definitions and class_name in self.definitions:
|
||||
self.definitions[class_name].dependencies.add(attr_name)
|
||||
|
||||
self.class_depth += 1
|
||||
|
||||
def leave_ClassDef(self, original_node: cst.ClassDef) -> None:
|
||||
|
|
|
|||
|
|
@ -1250,3 +1250,71 @@ class TestSorter(unittest.TestCase):
|
|||
|
||||
result = remove_unused_definitions_from_pytest_file(cst.parse_module(code)).code
|
||||
assert result.strip() == expected.strip()
|
||||
|
||||
|
||||
def test_abstract_base_class_inheritance() -> None:
|
||||
"""Test that abstract base classes used only for inheritance are preserved.
|
||||
|
||||
This mimics the real-world case where LLM generates mock classes that inherit
|
||||
from a base class, and the base class should not be removed even though it's
|
||||
only referenced in the inheritance declaration.
|
||||
"""
|
||||
code = """
|
||||
class LayoutDumper:
|
||||
layout_source: str = "unknown"
|
||||
def dump(self) -> dict:
|
||||
raise NotImplementedError()
|
||||
|
||||
class ObjectDetectionLayoutDumper(LayoutDumper):
|
||||
def __init__(self, layout):
|
||||
self._layout = layout
|
||||
def dump(self) -> dict:
|
||||
return self._layout
|
||||
|
||||
class ExtractedLayoutDumper(LayoutDumper):
|
||||
def __init__(self, layout):
|
||||
self._layout = layout
|
||||
def dump(self) -> dict:
|
||||
return self._layout
|
||||
|
||||
class UnusedClass:
|
||||
pass
|
||||
|
||||
def test_object_detection():
|
||||
dumper = ObjectDetectionLayoutDumper({})
|
||||
assert dumper.dump() == {}
|
||||
|
||||
def test_extracted():
|
||||
dumper = ExtractedLayoutDumper({"text": "hello"})
|
||||
assert dumper.dump() == {"text": "hello"}
|
||||
"""
|
||||
|
||||
expected = """
|
||||
class LayoutDumper:
|
||||
layout_source: str = "unknown"
|
||||
def dump(self) -> dict:
|
||||
raise NotImplementedError()
|
||||
|
||||
class ObjectDetectionLayoutDumper(LayoutDumper):
|
||||
def __init__(self, layout):
|
||||
self._layout = layout
|
||||
def dump(self) -> dict:
|
||||
return self._layout
|
||||
|
||||
class ExtractedLayoutDumper(LayoutDumper):
|
||||
def __init__(self, layout):
|
||||
self._layout = layout
|
||||
def dump(self) -> dict:
|
||||
return self._layout
|
||||
|
||||
def test_object_detection():
|
||||
dumper = ObjectDetectionLayoutDumper({})
|
||||
assert dumper.dump() == {}
|
||||
|
||||
def test_extracted():
|
||||
dumper = ExtractedLayoutDumper({"text": "hello"})
|
||||
assert dumper.dump() == {"text": "hello"}
|
||||
"""
|
||||
|
||||
result = remove_unused_definitions_from_pytest_file(cst.parse_module(code)).code
|
||||
assert result.strip() == expected.strip()
|
||||
|
|
|
|||
Loading…
Reference in a new issue