test: add tests for base class extraction in imported dataclasses

- Update test_get_imported_class_definitions_includes_dataclass_decorators
  to expect both base class and derived class to be extracted
- Add test_get_imported_class_definitions_extracts_multilevel_inheritance
  to verify multi-level inheritance chains are fully extracted
This commit is contained in:
Kevin Turcios 2026-01-23 07:26:03 -05:00
parent 4e310324b4
commit ebf77033ba

View file

@ -3135,21 +3135,26 @@ class ConfigRegistry:
# Call get_imported_class_definitions
result = get_imported_class_definitions(context, tmp_path)
# Should extract LLMConfig
assert len(result.code_strings) == 1, "Should extract exactly one class (LLMConfig)"
extracted_code = result.code_strings[0].code
# Should extract both LLMConfigBase (base class) and LLMConfig
assert len(result.code_strings) == 2, "Should extract both LLMConfig and its base class LLMConfigBase"
# Verify the extracted code includes the @dataclass decorator
assert "@dataclass(frozen=True)" in extracted_code, (
"Should include @dataclass decorator - this is critical for LLM to understand constructor"
# Combine extracted code to check for all required elements
all_extracted_code = "\n".join(cs.code for cs in result.code_strings)
# Verify the base class is extracted first (for proper inheritance understanding)
base_class_idx = all_extracted_code.find("class LLMConfigBase")
derived_class_idx = all_extracted_code.find("class LLMConfig(")
assert base_class_idx < derived_class_idx, "Base class should appear before derived class"
# Verify both classes include @dataclass decorators
assert all_extracted_code.count("@dataclass(frozen=True)") == 2, (
"Should include @dataclass decorator for both classes"
)
assert "class LLMConfig" in extracted_code, "Should contain LLMConfig class definition"
assert "class LLMConfig" in all_extracted_code, "Should contain LLMConfig class definition"
assert "class LLMConfigBase" in all_extracted_code, "Should contain LLMConfigBase class definition"
# Verify imports are included for dataclass-related items
assert "from dataclasses import" in extracted_code, "Should include dataclasses import"
assert "Optional" in extracted_code or "from typing import" in extracted_code, (
"Should include type annotation imports"
)
assert "from dataclasses import" in all_extracted_code, "Should include dataclasses import"
def test_get_imported_class_definitions_extracts_imports_for_decorated_classes(tmp_path: Path) -> None:
@ -3399,3 +3404,99 @@ def sort_configs(configs: list[OrderedConfig]) -> list[OrderedConfig]:
assert "@total_ordering" in extracted_code, "Should include @total_ordering decorator"
assert "@dataclass" in extracted_code, "Should include @dataclass decorator"
assert "class OrderedConfig" in extracted_code
def test_get_imported_class_definitions_extracts_multilevel_inheritance(tmp_path: Path) -> None:
    """Test that base classes are recursively extracted for multi-level inheritance.

    This is critical for understanding dataclass constructor signatures, as fields
    from parent classes become required positional arguments in child classes.

    The fixture defines GrandParentConfig -> ParentConfig -> {ChildConfig, RouterConfig}
    in one module, while the code under analysis imports only the two leaf classes.
    The test expects all four class definitions back, with bases before derived classes.

    Args:
        tmp_path: pytest-provided temporary directory used as the project root.
    """
    # Create a package structure
    package_dir = tmp_path / "mypackage"
    package_dir.mkdir()
    (package_dir / "__init__.py").write_text("", encoding="utf-8")

    # Create a module with multi-level inheritance like skyvern's LLM models:
    # GrandParent -> Parent -> Child
    # NOTE: the text below is written to disk verbatim, so it must itself be
    # well-formed Python (class bodies indented).
    models_code = '''from dataclasses import dataclass, field
from typing import Optional, Literal


@dataclass(frozen=True)
class GrandParentConfig:
    """Base config with common fields."""

    model_name: str
    required_env_vars: list[str]


@dataclass(frozen=True)
class ParentConfig(GrandParentConfig):
    """Intermediate config adding vision support."""

    supports_vision: bool
    add_assistant_prefix: bool


@dataclass(frozen=True)
class ChildConfig(ParentConfig):
    """Full config with optional parameters."""

    litellm_params: Optional[dict] = field(default=None)
    max_tokens: int | None = None
    temperature: float | None = 0.7


@dataclass(frozen=True)
class RouterConfig(ParentConfig):
    """Router config branching from ParentConfig."""

    model_list: list
    main_model_group: str
    routing_strategy: Literal["simple", "least-busy"] = "simple"
'''
    models_path = package_dir / "models.py"
    models_path.write_text(models_code, encoding="utf-8")

    # Create code that imports only the child classes (not the base classes)
    code = '''from mypackage.models import ChildConfig, RouterConfig


class ConfigRegistry:
    def get_child_config(self) -> ChildConfig:
        pass

    def get_router_config(self) -> RouterConfig:
        pass
'''
    code_path = package_dir / "registry.py"
    code_path.write_text(code, encoding="utf-8")

    context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)])

    # Call get_imported_class_definitions
    result = get_imported_class_definitions(context, tmp_path)

    # Should extract 4 classes: GrandParentConfig, ParentConfig, ChildConfig, RouterConfig
    # (all classes needed to understand the full inheritance hierarchy)
    assert len(result.code_strings) == 4, (
        f"Should extract 4 classes (GrandParent, Parent, Child, Router), got {len(result.code_strings)}"
    )

    # Combine extracted code so presence and ordering can be checked on one string
    all_extracted_code = "\n".join(cs.code for cs in result.code_strings)

    # Verify all classes are extracted
    assert "class GrandParentConfig" in all_extracted_code, "Should extract GrandParentConfig base class"
    assert "class ParentConfig(GrandParentConfig)" in all_extracted_code, "Should extract ParentConfig"
    assert "class ChildConfig(ParentConfig)" in all_extracted_code, "Should extract ChildConfig"
    assert "class RouterConfig(ParentConfig)" in all_extracted_code, "Should extract RouterConfig"

    # Verify classes are ordered correctly (base classes before derived).
    # The trailing "(" on the derived-class needles keeps "class ParentConfig("
    # from being confused with other class names.
    grandparent_idx = all_extracted_code.find("class GrandParentConfig")
    parent_idx = all_extracted_code.find("class ParentConfig(")
    child_idx = all_extracted_code.find("class ChildConfig(")
    router_idx = all_extracted_code.find("class RouterConfig(")

    # str.find returns -1 on a miss; the explicit lower bound stops a missing
    # base class from trivially satisfying the "appears before" comparison.
    assert 0 <= grandparent_idx < parent_idx, "GrandParentConfig should appear before ParentConfig"
    assert 0 <= parent_idx < child_idx, "ParentConfig should appear before ChildConfig"
    assert 0 <= parent_idx < router_idx, "ParentConfig should appear before RouterConfig"

    # Verify the critical fields are visible for constructor understanding
    assert "model_name: str" in all_extracted_code, "Should include model_name field from GrandParent"
    assert "required_env_vars: list[str]" in all_extracted_code, "Should include required_env_vars field"
    assert "supports_vision: bool" in all_extracted_code, "Should include supports_vision field from Parent"
    assert "litellm_params:" in all_extracted_code, "Should include litellm_params field from Child"
    assert "model_list: list" in all_extracted_code, "Should include model_list field from Router"