mirror of
https://github.com/codeflash-ai/codeflash.git
synced 2026-05-04 18:25:17 +00:00
test: add tests for base class extraction in imported dataclasses
- Update test_get_imported_class_definitions_includes_dataclass_decorators to expect both base class and derived class to be extracted - Add test_get_imported_class_definitions_extracts_multilevel_inheritance to verify multi-level inheritance chains are fully extracted
This commit is contained in:
parent
4e310324b4
commit
ebf77033ba
1 changed files with 112 additions and 11 deletions
|
|
@ -3135,21 +3135,26 @@ class ConfigRegistry:
|
|||
# Call get_imported_class_definitions
|
||||
result = get_imported_class_definitions(context, tmp_path)
|
||||
|
||||
# Should extract LLMConfig
|
||||
assert len(result.code_strings) == 1, "Should extract exactly one class (LLMConfig)"
|
||||
extracted_code = result.code_strings[0].code
|
||||
# Should extract both LLMConfigBase (base class) and LLMConfig
|
||||
assert len(result.code_strings) == 2, "Should extract both LLMConfig and its base class LLMConfigBase"
|
||||
|
||||
# Verify the extracted code includes the @dataclass decorator
|
||||
assert "@dataclass(frozen=True)" in extracted_code, (
|
||||
"Should include @dataclass decorator - this is critical for LLM to understand constructor"
|
||||
# Combine extracted code to check for all required elements
|
||||
all_extracted_code = "\n".join(cs.code for cs in result.code_strings)
|
||||
|
||||
# Verify the base class is extracted first (for proper inheritance understanding)
|
||||
base_class_idx = all_extracted_code.find("class LLMConfigBase")
|
||||
derived_class_idx = all_extracted_code.find("class LLMConfig(")
|
||||
assert base_class_idx < derived_class_idx, "Base class should appear before derived class"
|
||||
|
||||
# Verify both classes include @dataclass decorators
|
||||
assert all_extracted_code.count("@dataclass(frozen=True)") == 2, (
|
||||
"Should include @dataclass decorator for both classes"
|
||||
)
|
||||
assert "class LLMConfig" in extracted_code, "Should contain LLMConfig class definition"
|
||||
assert "class LLMConfig" in all_extracted_code, "Should contain LLMConfig class definition"
|
||||
assert "class LLMConfigBase" in all_extracted_code, "Should contain LLMConfigBase class definition"
|
||||
|
||||
# Verify imports are included for dataclass-related items
|
||||
assert "from dataclasses import" in extracted_code, "Should include dataclasses import"
|
||||
assert "Optional" in extracted_code or "from typing import" in extracted_code, (
|
||||
"Should include type annotation imports"
|
||||
)
|
||||
assert "from dataclasses import" in all_extracted_code, "Should include dataclasses import"
|
||||
|
||||
|
||||
def test_get_imported_class_definitions_extracts_imports_for_decorated_classes(tmp_path: Path) -> None:
|
||||
|
|
@ -3399,3 +3404,99 @@ def sort_configs(configs: list[OrderedConfig]) -> list[OrderedConfig]:
|
|||
assert "@total_ordering" in extracted_code, "Should include @total_ordering decorator"
|
||||
assert "@dataclass" in extracted_code, "Should include @dataclass decorator"
|
||||
assert "class OrderedConfig" in extracted_code
|
||||
|
||||
|
||||
def test_get_imported_class_definitions_extracts_multilevel_inheritance(tmp_path: Path) -> None:
|
||||
"""Test that base classes are recursively extracted for multi-level inheritance.
|
||||
|
||||
This is critical for understanding dataclass constructor signatures, as fields
|
||||
from parent classes become required positional arguments in child classes.
|
||||
"""
|
||||
# Create a package structure
|
||||
package_dir = tmp_path / "mypackage"
|
||||
package_dir.mkdir()
|
||||
(package_dir / "__init__.py").write_text("", encoding="utf-8")
|
||||
|
||||
# Create a module with multi-level inheritance like skyvern's LLM models:
|
||||
# GrandParent -> Parent -> Child
|
||||
models_code = '''from dataclasses import dataclass, field
|
||||
from typing import Optional, Literal
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class GrandParentConfig:
|
||||
"""Base config with common fields."""
|
||||
model_name: str
|
||||
required_env_vars: list[str]
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ParentConfig(GrandParentConfig):
|
||||
"""Intermediate config adding vision support."""
|
||||
supports_vision: bool
|
||||
add_assistant_prefix: bool
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ChildConfig(ParentConfig):
|
||||
"""Full config with optional parameters."""
|
||||
litellm_params: Optional[dict] = field(default=None)
|
||||
max_tokens: int | None = None
|
||||
temperature: float | None = 0.7
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class RouterConfig(ParentConfig):
|
||||
"""Router config branching from ParentConfig."""
|
||||
model_list: list
|
||||
main_model_group: str
|
||||
routing_strategy: Literal["simple", "least-busy"] = "simple"
|
||||
'''
|
||||
models_path = package_dir / "models.py"
|
||||
models_path.write_text(models_code, encoding="utf-8")
|
||||
|
||||
# Create code that imports only the child classes (not the base classes)
|
||||
code = '''from mypackage.models import ChildConfig, RouterConfig
|
||||
|
||||
class ConfigRegistry:
|
||||
def get_child_config(self) -> ChildConfig:
|
||||
pass
|
||||
|
||||
def get_router_config(self) -> RouterConfig:
|
||||
pass
|
||||
'''
|
||||
code_path = package_dir / "registry.py"
|
||||
code_path.write_text(code, encoding="utf-8")
|
||||
|
||||
context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)])
|
||||
|
||||
# Call get_imported_class_definitions
|
||||
result = get_imported_class_definitions(context, tmp_path)
|
||||
|
||||
# Should extract 4 classes: GrandParentConfig, ParentConfig, ChildConfig, RouterConfig
|
||||
# (all classes needed to understand the full inheritance hierarchy)
|
||||
assert len(result.code_strings) == 4, (
|
||||
f"Should extract 4 classes (GrandParent, Parent, Child, Router), got {len(result.code_strings)}"
|
||||
)
|
||||
|
||||
# Combine extracted code
|
||||
all_extracted_code = "\n".join(cs.code for cs in result.code_strings)
|
||||
|
||||
# Verify all classes are extracted
|
||||
assert "class GrandParentConfig" in all_extracted_code, "Should extract GrandParentConfig base class"
|
||||
assert "class ParentConfig(GrandParentConfig)" in all_extracted_code, "Should extract ParentConfig"
|
||||
assert "class ChildConfig(ParentConfig)" in all_extracted_code, "Should extract ChildConfig"
|
||||
assert "class RouterConfig(ParentConfig)" in all_extracted_code, "Should extract RouterConfig"
|
||||
|
||||
# Verify classes are ordered correctly (base classes before derived)
|
||||
grandparent_idx = all_extracted_code.find("class GrandParentConfig")
|
||||
parent_idx = all_extracted_code.find("class ParentConfig(")
|
||||
child_idx = all_extracted_code.find("class ChildConfig(")
|
||||
router_idx = all_extracted_code.find("class RouterConfig(")
|
||||
|
||||
assert grandparent_idx < parent_idx, "GrandParentConfig should appear before ParentConfig"
|
||||
assert parent_idx < child_idx, "ParentConfig should appear before ChildConfig"
|
||||
assert parent_idx < router_idx, "ParentConfig should appear before RouterConfig"
|
||||
|
||||
# Verify the critical fields are visible for constructor understanding
|
||||
assert "model_name: str" in all_extracted_code, "Should include model_name field from GrandParent"
|
||||
assert "required_env_vars: list[str]" in all_extracted_code, "Should include required_env_vars field"
|
||||
assert "supports_vision: bool" in all_extracted_code, "Should include supports_vision field from Parent"
|
||||
assert "litellm_params:" in all_extracted_code, "Should include litellm_params field from Child"
|
||||
assert "model_list: list" in all_extracted_code, "Should include model_list field from Router"
|
||||
|
|
|
|||
Loading…
Reference in a new issue