mirror of
https://github.com/codeflash-ai/codeflash-internal.git
synced 2026-05-04 18:25:18 +00:00
refactor: remove constructor notes preprocessing from testgen pipeline
Full class source is now included in the client-side testgen context, making the server-side constructor signature extraction redundant.
This commit is contained in:
parent
bfd9f2cd04
commit
ca71d0c8a0
5 changed files with 1 additions and 1291 deletions
|
|
@ -1,165 +0,0 @@
|
|||
"""Preprocessing step to extract constructor signatures from regular (non-dataclass) classes.
|
||||
|
||||
When the test context contains class definitions with explicit __init__ methods,
|
||||
this module extracts the constructor signatures to help the LLM understand how
|
||||
to properly instantiate them.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import NamedTuple
|
||||
|
||||
import libcst as cst
|
||||
from libcst.metadata import MetadataWrapper
|
||||
|
||||
from aiservice.common.markdown_utils import extract_all_code_from_markdown
|
||||
from core.languages.python.cst_utils import get_node_source_text, has_decorator, parse_module_to_cst
|
||||
|
||||
|
||||
class ParamInfo(NamedTuple):
|
||||
name: str
|
||||
annotation: str | None
|
||||
has_default: bool
|
||||
default_repr: str | None
|
||||
|
||||
|
||||
def _find_init_method(class_node: cst.ClassDef) -> cst.FunctionDef | None:
|
||||
for stmt in class_node.body.body:
|
||||
if isinstance(stmt, cst.FunctionDef) and stmt.name.value == "__init__":
|
||||
return stmt
|
||||
return None
|
||||
|
||||
|
||||
def extract_init_params(
|
||||
init_node: cst.FunctionDef, source_lines: list[str], wrapper: MetadataWrapper
|
||||
) -> list[ParamInfo]:
|
||||
params: list[ParamInfo] = []
|
||||
parameters = init_node.params
|
||||
|
||||
for param in parameters.params:
|
||||
name = param.name.value
|
||||
if name == "self":
|
||||
continue
|
||||
|
||||
annotation = None
|
||||
if param.annotation is not None:
|
||||
annotation = get_node_source_text(param.annotation, source_lines, wrapper)
|
||||
|
||||
default_repr = None
|
||||
if param.default is not None:
|
||||
default_repr = get_node_source_text(param.default, source_lines, wrapper)
|
||||
if len(default_repr) > 50:
|
||||
default_repr = default_repr[:47] + "..."
|
||||
|
||||
params.append(ParamInfo(name, annotation, has_default=param.default is not None, default_repr=default_repr))
|
||||
|
||||
if isinstance(parameters.star_arg, cst.Param):
|
||||
name = f"*{parameters.star_arg.name.value}"
|
||||
annotation = None
|
||||
if parameters.star_arg.annotation is not None:
|
||||
annotation = get_node_source_text(parameters.star_arg.annotation, source_lines, wrapper)
|
||||
params.append(ParamInfo(name, annotation, has_default=False, default_repr=None))
|
||||
|
||||
for param in parameters.kwonly_params:
|
||||
name = param.name.value
|
||||
annotation = None
|
||||
if param.annotation is not None:
|
||||
annotation = get_node_source_text(param.annotation, source_lines, wrapper)
|
||||
|
||||
default_repr = None
|
||||
if param.default is not None:
|
||||
default_repr = get_node_source_text(param.default, source_lines, wrapper)
|
||||
if len(default_repr) > 50:
|
||||
default_repr = default_repr[:47] + "..."
|
||||
|
||||
params.append(ParamInfo(name, annotation, has_default=param.default is not None, default_repr=default_repr))
|
||||
|
||||
if parameters.star_kwarg is not None:
|
||||
name = f"**{parameters.star_kwarg.name.value}"
|
||||
annotation = None
|
||||
if parameters.star_kwarg.annotation is not None:
|
||||
annotation = get_node_source_text(parameters.star_kwarg.annotation, source_lines, wrapper)
|
||||
params.append(ParamInfo(name, annotation, has_default=False, default_repr=None))
|
||||
|
||||
return params
|
||||
|
||||
|
||||
class ClassInitCollector(cst.CSTVisitor):
|
||||
def __init__(self) -> None:
|
||||
self.classes: dict[str, cst.ClassDef] = {}
|
||||
|
||||
def visit_ClassDef(self, node: cst.ClassDef) -> bool:
|
||||
if not has_decorator(node, "dataclass") and _find_init_method(node) is not None:
|
||||
self.classes[node.name.value] = node
|
||||
return True
|
||||
|
||||
|
||||
def _find_all_classes_with_init(source_code: str) -> tuple[dict[str, cst.ClassDef], list[str], MetadataWrapper | None]:
|
||||
try:
|
||||
tree = parse_module_to_cst(source_code)
|
||||
except cst.ParserSyntaxError:
|
||||
return {}, [], None
|
||||
|
||||
wrapper = MetadataWrapper(tree)
|
||||
collector = ClassInitCollector()
|
||||
wrapper.visit(collector)
|
||||
source_lines = source_code.split("\n")
|
||||
return collector.classes, source_lines, wrapper
|
||||
|
||||
|
||||
def format_init_signature(class_name: str, params: list[ParamInfo]) -> str:
|
||||
if not params:
|
||||
return f"{class_name}() - no parameters"
|
||||
|
||||
required = [p for p in params if not p.has_default and not p.name.startswith("*")]
|
||||
optional = [p for p in params if p.has_default]
|
||||
variadic = [p for p in params if p.name.startswith("*")]
|
||||
|
||||
lines = [f"Constructor signature for {class_name}:"]
|
||||
|
||||
if required:
|
||||
lines.append(" Required (positional) arguments:")
|
||||
for p in required:
|
||||
if p.annotation:
|
||||
lines.append(f" - {p.name}: {p.annotation}")
|
||||
else:
|
||||
lines.append(f" - {p.name}")
|
||||
|
||||
if optional:
|
||||
lines.append(" Optional (keyword) arguments:")
|
||||
for p in optional:
|
||||
default_note = f" = {p.default_repr}" if p.default_repr else " = ..."
|
||||
if p.annotation:
|
||||
lines.append(f" - {p.name}: {p.annotation}{default_note}")
|
||||
else:
|
||||
lines.append(f" - {p.name}{default_note}")
|
||||
|
||||
if variadic:
|
||||
lines.append(" Variadic arguments:")
|
||||
for p in variadic:
|
||||
if p.annotation:
|
||||
lines.append(f" - {p.name}: {p.annotation}")
|
||||
else:
|
||||
lines.append(f" - {p.name}")
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
def get_class_constructor_notes(test_context: str) -> list[str]:
|
||||
code_to_analyze = extract_all_code_from_markdown(test_context) if "```python" in test_context else test_context
|
||||
|
||||
classes, source_lines, wrapper = _find_all_classes_with_init(code_to_analyze)
|
||||
|
||||
if not classes or wrapper is None:
|
||||
return []
|
||||
|
||||
notes = []
|
||||
for class_name, class_node in classes.items():
|
||||
init_node = _find_init_method(class_node)
|
||||
if init_node is None:
|
||||
continue
|
||||
params = extract_init_params(init_node, source_lines, wrapper)
|
||||
signature = format_init_signature(class_name, params)
|
||||
notes.append(signature)
|
||||
|
||||
return notes
|
||||
|
|
@ -1,176 +0,0 @@
|
|||
"""Preprocessing step to extract dataclass constructor signatures.
|
||||
|
||||
When the test context contains dataclass definitions, this module extracts
|
||||
the constructor signatures to help the LLM understand how to properly
|
||||
instantiate them, including inherited fields from parent dataclasses.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import NamedTuple
|
||||
|
||||
import libcst as cst
|
||||
from libcst.metadata import MetadataWrapper
|
||||
|
||||
from core.languages.python.cst_utils import (
|
||||
get_base_class_name,
|
||||
get_node_source_text,
|
||||
has_decorator,
|
||||
parse_module_to_cst,
|
||||
)
|
||||
from aiservice.common.markdown_utils import extract_all_code_from_markdown
|
||||
|
||||
|
||||
class FieldInfo(NamedTuple):
|
||||
"""Information about a dataclass field."""
|
||||
|
||||
name: str
|
||||
annotation: str
|
||||
has_default: bool
|
||||
default_repr: str | None
|
||||
source: str # "inherited" or "own"
|
||||
|
||||
|
||||
def extract_dataclass_fields(
|
||||
class_node: cst.ClassDef, source_lines: list[str], wrapper: MetadataWrapper
|
||||
) -> list[FieldInfo]:
|
||||
"""Extract fields from a dataclass definition.
|
||||
|
||||
Returns list of FieldInfo with (field_name, type_annotation, has_default, default_value).
|
||||
"""
|
||||
fields: list[FieldInfo] = []
|
||||
for stmt in class_node.body.body:
|
||||
if not isinstance(stmt, cst.SimpleStatementLine):
|
||||
continue
|
||||
for node in stmt.body:
|
||||
if not (isinstance(node, cst.AnnAssign) and isinstance(node.target, cst.Name)):
|
||||
continue
|
||||
|
||||
field_name = node.target.value
|
||||
annotation = get_node_source_text(node.annotation, source_lines, wrapper)
|
||||
|
||||
default_repr = None
|
||||
if node.value is not None:
|
||||
default_repr = get_node_source_text(node.value, source_lines, wrapper)
|
||||
# Truncate long defaults
|
||||
if len(default_repr) > 50:
|
||||
default_repr = default_repr[:47] + "..."
|
||||
|
||||
fields.append(FieldInfo(field_name, annotation, node.value is not None, default_repr, "own"))
|
||||
return fields
|
||||
|
||||
|
||||
class DataclassCollector(cst.CSTVisitor):
|
||||
"""Visitor to collect all dataclass definitions."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.dataclasses: dict[str, cst.ClassDef] = {}
|
||||
|
||||
def visit_ClassDef(self, node: cst.ClassDef) -> bool:
|
||||
if has_decorator(node, "dataclass"):
|
||||
self.dataclasses[node.name.value] = node
|
||||
return True # Continue visiting nested classes
|
||||
|
||||
|
||||
def find_all_dataclasses(source_code: str) -> tuple[dict[str, cst.ClassDef], list[str], MetadataWrapper | None]:
|
||||
"""Find all dataclass definitions in source code.
|
||||
|
||||
Returns tuple of (dict mapping class name to class_node, source_lines, metadata_wrapper).
|
||||
"""
|
||||
try:
|
||||
tree = parse_module_to_cst(source_code)
|
||||
except cst.ParserSyntaxError:
|
||||
return {}, [], None
|
||||
|
||||
wrapper = MetadataWrapper(tree)
|
||||
collector = DataclassCollector()
|
||||
wrapper.visit(collector)
|
||||
source_lines = source_code.split("\n")
|
||||
return collector.dataclasses, source_lines, wrapper
|
||||
|
||||
|
||||
def get_all_fields_with_inheritance(
|
||||
class_name: str, dataclasses: dict[str, cst.ClassDef], source_lines: list[str], wrapper: MetadataWrapper
|
||||
) -> list[FieldInfo]:
|
||||
"""Get all fields for a dataclass, including inherited ones.
|
||||
|
||||
Fields are returned in order: inherited first, then own fields.
|
||||
"""
|
||||
if class_name not in dataclasses:
|
||||
return []
|
||||
|
||||
class_node = dataclasses[class_name]
|
||||
|
||||
# First, collect inherited fields from parent dataclasses
|
||||
inherited_fields: list[FieldInfo] = []
|
||||
if class_node.bases:
|
||||
for base in class_node.bases:
|
||||
base_name = get_base_class_name(base)
|
||||
if base_name and base_name in dataclasses:
|
||||
parent_fields = get_all_fields_with_inheritance(base_name, dataclasses, source_lines, wrapper)
|
||||
# Mark these as inherited
|
||||
inherited_fields.extend(
|
||||
FieldInfo(field.name, field.annotation, field.has_default, field.default_repr, "inherited")
|
||||
for field in parent_fields
|
||||
)
|
||||
|
||||
# Then collect own fields
|
||||
own_fields = extract_dataclass_fields(class_node, source_lines, wrapper)
|
||||
|
||||
return inherited_fields + own_fields
|
||||
|
||||
|
||||
def format_constructor_signature(class_name: str, fields: list[FieldInfo]) -> str:
|
||||
"""Format a constructor signature for documentation."""
|
||||
if not fields:
|
||||
return f"{class_name}() - no fields"
|
||||
|
||||
# Separate required and optional fields
|
||||
required = [f for f in fields if not f.has_default]
|
||||
optional = [f for f in fields if f.has_default]
|
||||
|
||||
lines = [f"Constructor signature for {class_name}:"]
|
||||
|
||||
if required:
|
||||
lines.append(" Required (positional) arguments:")
|
||||
for f in required:
|
||||
source_note = " (from parent class)" if f.source == "inherited" else ""
|
||||
lines.append(f" - {f.name}: {f.annotation}{source_note}")
|
||||
|
||||
if optional:
|
||||
lines.append(" Optional (keyword) arguments:")
|
||||
for f in optional:
|
||||
default_note = f" = {f.default_repr}" if f.default_repr else " = ..."
|
||||
source_note = " (from parent class)" if f.source == "inherited" else ""
|
||||
lines.append(f" - {f.name}: {f.annotation}{default_note}{source_note}")
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
def get_dataclass_constructor_notes(test_context: str) -> list[str]:
|
||||
"""Generate notes about dataclass constructor signatures.
|
||||
|
||||
Analyzes the test context for dataclass definitions and generates
|
||||
explicit notes about their constructor signatures, including inherited fields.
|
||||
|
||||
Args:
|
||||
test_context: The source code context to analyze for dataclass definitions.
|
||||
|
||||
"""
|
||||
# Extract actual code from markdown if present
|
||||
code_to_analyze = extract_all_code_from_markdown(test_context) if "```python" in test_context else test_context
|
||||
|
||||
# Find all dataclasses in the context
|
||||
dataclasses, source_lines, wrapper = find_all_dataclasses(code_to_analyze)
|
||||
|
||||
if not dataclasses or wrapper is None:
|
||||
return []
|
||||
|
||||
notes = []
|
||||
for class_name in dataclasses:
|
||||
fields = get_all_fields_with_inheritance(class_name, dataclasses, source_lines, wrapper)
|
||||
if fields:
|
||||
signature = format_constructor_signature(class_name, fields)
|
||||
notes.append(signature)
|
||||
|
||||
return notes
|
||||
|
|
@ -1,12 +1,10 @@
|
|||
from itertools import chain
|
||||
|
||||
from core.languages.python.testgen.preprocessing.class_constructor_notes import get_class_constructor_notes
|
||||
from core.languages.python.testgen.preprocessing.dataclass_constructor_notes import get_dataclass_constructor_notes
|
||||
from core.languages.python.testgen.preprocessing.torch_tensor_limit import get_tensor_size_note
|
||||
|
||||
# Preprocessing functions that analyze code context and return notes
|
||||
# Each function takes test_context (str) and returns a list of notes (list[str])
|
||||
_PREPROCESSING_FUNCTIONS = [get_tensor_size_note, get_dataclass_constructor_notes, get_class_constructor_notes]
|
||||
_PREPROCESSING_FUNCTIONS = [get_tensor_size_note]
|
||||
|
||||
|
||||
def preprocessing_testgen_pipeline(test_context: str) -> list[str]:
|
||||
|
|
|
|||
|
|
@ -1,418 +0,0 @@
|
|||
"""Tests for regular class constructor signature extraction."""
|
||||
|
||||
import libcst as cst
|
||||
|
||||
from core.languages.python.testgen.preprocessing.class_constructor_notes import (
|
||||
ParamInfo,
|
||||
_find_all_classes_with_init,
|
||||
extract_init_params,
|
||||
format_init_signature,
|
||||
get_class_constructor_notes,
|
||||
)
|
||||
|
||||
|
||||
class TestClassInitCollector:
|
||||
def test_collects_class_with_init(self) -> None:
|
||||
code = """
|
||||
class Foo:
|
||||
def __init__(self, x: int):
|
||||
self.x = x
|
||||
"""
|
||||
classes, _, wrapper = _find_all_classes_with_init(code)
|
||||
assert "Foo" in classes
|
||||
assert wrapper is not None
|
||||
|
||||
def test_skips_class_without_init(self) -> None:
|
||||
code = """
|
||||
class Foo:
|
||||
def bar(self):
|
||||
pass
|
||||
"""
|
||||
classes, _, _ = _find_all_classes_with_init(code)
|
||||
assert "Foo" not in classes
|
||||
|
||||
def test_skips_dataclass(self) -> None:
|
||||
code = """
|
||||
@dataclass
|
||||
class Foo:
|
||||
x: int
|
||||
def __init__(self, x: int):
|
||||
self.x = x
|
||||
"""
|
||||
classes, _, _ = _find_all_classes_with_init(code)
|
||||
assert "Foo" not in classes
|
||||
|
||||
def test_skips_dataclass_with_call(self) -> None:
|
||||
code = """
|
||||
@dataclass(frozen=True)
|
||||
class Foo:
|
||||
x: int
|
||||
def __init__(self, x: int):
|
||||
self.x = x
|
||||
"""
|
||||
classes, _, _ = _find_all_classes_with_init(code)
|
||||
assert "Foo" not in classes
|
||||
|
||||
def test_collects_multiple_classes(self) -> None:
|
||||
code = """
|
||||
class Foo:
|
||||
def __init__(self, x: int):
|
||||
self.x = x
|
||||
|
||||
class Bar:
|
||||
def __init__(self, y: str):
|
||||
self.y = y
|
||||
"""
|
||||
classes, _, _ = _find_all_classes_with_init(code)
|
||||
assert "Foo" in classes
|
||||
assert "Bar" in classes
|
||||
|
||||
def test_handles_syntax_error(self) -> None:
|
||||
code = "this is not valid python code {{{{"
|
||||
classes, _, wrapper = _find_all_classes_with_init(code)
|
||||
assert classes == {}
|
||||
assert wrapper is None
|
||||
|
||||
|
||||
class TestExtractInitParams:
|
||||
def test_typed_params(self) -> None:
|
||||
code = """
|
||||
class Foo:
|
||||
def __init__(self, x: int, y: str):
|
||||
pass
|
||||
"""
|
||||
classes, source_lines, wrapper = _find_all_classes_with_init(code)
|
||||
assert wrapper is not None
|
||||
init_node = next(iter(classes["Foo"].body.body))
|
||||
assert isinstance(init_node, cst.FunctionDef)
|
||||
params = extract_init_params(init_node, source_lines, wrapper)
|
||||
|
||||
assert len(params) == 2
|
||||
assert params[0].name == "x"
|
||||
assert params[0].annotation == "int"
|
||||
assert not params[0].has_default
|
||||
assert params[1].name == "y"
|
||||
assert params[1].annotation == "str"
|
||||
assert not params[1].has_default
|
||||
|
||||
def test_params_with_defaults(self) -> None:
|
||||
code = """
|
||||
class Foo:
|
||||
def __init__(self, x: int = 0, y: str = "hello"):
|
||||
pass
|
||||
"""
|
||||
classes, source_lines, wrapper = _find_all_classes_with_init(code)
|
||||
assert wrapper is not None
|
||||
init_node = next(iter(classes["Foo"].body.body))
|
||||
assert isinstance(init_node, cst.FunctionDef)
|
||||
params = extract_init_params(init_node, source_lines, wrapper)
|
||||
|
||||
assert len(params) == 2
|
||||
assert params[0].name == "x"
|
||||
assert params[0].has_default
|
||||
assert params[0].default_repr == "0"
|
||||
assert params[1].name == "y"
|
||||
assert params[1].has_default
|
||||
assert params[1].default_repr == '"hello"'
|
||||
|
||||
def test_args_and_kwargs(self) -> None:
|
||||
code = """
|
||||
class Foo:
|
||||
def __init__(self, x: int, *args, **kwargs):
|
||||
pass
|
||||
"""
|
||||
classes, source_lines, wrapper = _find_all_classes_with_init(code)
|
||||
assert wrapper is not None
|
||||
init_node = next(iter(classes["Foo"].body.body))
|
||||
assert isinstance(init_node, cst.FunctionDef)
|
||||
params = extract_init_params(init_node, source_lines, wrapper)
|
||||
|
||||
assert len(params) == 3
|
||||
assert params[0].name == "x"
|
||||
assert params[1].name == "*args"
|
||||
assert params[2].name == "**kwargs"
|
||||
|
||||
def test_skips_self(self) -> None:
|
||||
code = """
|
||||
class Foo:
|
||||
def __init__(self):
|
||||
pass
|
||||
"""
|
||||
classes, source_lines, wrapper = _find_all_classes_with_init(code)
|
||||
assert wrapper is not None
|
||||
init_node = next(iter(classes["Foo"].body.body))
|
||||
assert isinstance(init_node, cst.FunctionDef)
|
||||
params = extract_init_params(init_node, source_lines, wrapper)
|
||||
|
||||
assert len(params) == 0
|
||||
|
||||
def test_untyped_params(self) -> None:
|
||||
code = """
|
||||
class Foo:
|
||||
def __init__(self, x, y=10):
|
||||
pass
|
||||
"""
|
||||
classes, source_lines, wrapper = _find_all_classes_with_init(code)
|
||||
assert wrapper is not None
|
||||
init_node = next(iter(classes["Foo"].body.body))
|
||||
assert isinstance(init_node, cst.FunctionDef)
|
||||
params = extract_init_params(init_node, source_lines, wrapper)
|
||||
|
||||
assert len(params) == 2
|
||||
assert params[0].name == "x"
|
||||
assert params[0].annotation is None
|
||||
assert params[1].name == "y"
|
||||
assert params[1].annotation is None
|
||||
assert params[1].has_default
|
||||
assert params[1].default_repr == "10"
|
||||
|
||||
def test_kwonly_params(self) -> None:
|
||||
code = """
|
||||
class Foo:
|
||||
def __init__(self, x: int, *, key: str = "default"):
|
||||
pass
|
||||
"""
|
||||
classes, source_lines, wrapper = _find_all_classes_with_init(code)
|
||||
assert wrapper is not None
|
||||
init_node = next(iter(classes["Foo"].body.body))
|
||||
assert isinstance(init_node, cst.FunctionDef)
|
||||
params = extract_init_params(init_node, source_lines, wrapper)
|
||||
|
||||
assert len(params) == 2
|
||||
assert params[0].name == "x"
|
||||
assert params[1].name == "key"
|
||||
assert params[1].annotation == "str"
|
||||
assert params[1].has_default
|
||||
assert params[1].default_repr == '"default"'
|
||||
|
||||
def test_long_default_truncated(self) -> None:
|
||||
code = """
|
||||
class Foo:
|
||||
def __init__(self, x: list = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]):
|
||||
pass
|
||||
"""
|
||||
classes, source_lines, wrapper = _find_all_classes_with_init(code)
|
||||
assert wrapper is not None
|
||||
init_node = next(iter(classes["Foo"].body.body))
|
||||
assert isinstance(init_node, cst.FunctionDef)
|
||||
params = extract_init_params(init_node, source_lines, wrapper)
|
||||
|
||||
assert len(params) == 1
|
||||
assert params[0].has_default
|
||||
assert params[0].default_repr is not None
|
||||
assert len(params[0].default_repr) <= 50
|
||||
assert params[0].default_repr.endswith("...")
|
||||
|
||||
|
||||
class TestFormatInitSignature:
|
||||
def test_required_only(self) -> None:
|
||||
params = [
|
||||
ParamInfo(name="x", annotation="int", has_default=False, default_repr=None),
|
||||
ParamInfo(name="y", annotation="str", has_default=False, default_repr=None),
|
||||
]
|
||||
result = format_init_signature("Foo", params)
|
||||
assert "Constructor signature for Foo:" in result
|
||||
assert "Required (positional) arguments:" in result
|
||||
assert "- x: int" in result
|
||||
assert "- y: str" in result
|
||||
assert "Optional" not in result
|
||||
|
||||
def test_optional_only(self) -> None:
|
||||
params = [
|
||||
ParamInfo(name="x", annotation="int", has_default=True, default_repr="0"),
|
||||
ParamInfo(name="y", annotation="str", has_default=True, default_repr='"hi"'),
|
||||
]
|
||||
result = format_init_signature("Foo", params)
|
||||
assert "Constructor signature for Foo:" in result
|
||||
assert "Optional (keyword) arguments:" in result
|
||||
assert "- x: int = 0" in result
|
||||
assert '- y: str = "hi"' in result
|
||||
assert "Required" not in result
|
||||
|
||||
def test_mixed_params(self) -> None:
|
||||
params = [
|
||||
ParamInfo(name="x", annotation="int", has_default=False, default_repr=None),
|
||||
ParamInfo(name="y", annotation="str", has_default=True, default_repr='"default"'),
|
||||
]
|
||||
result = format_init_signature("Foo", params)
|
||||
assert "Required (positional) arguments:" in result
|
||||
assert "Optional (keyword) arguments:" in result
|
||||
|
||||
def test_no_params(self) -> None:
|
||||
result = format_init_signature("Empty", [])
|
||||
assert "Empty() - no parameters" in result
|
||||
|
||||
def test_variadic_params(self) -> None:
|
||||
params = [
|
||||
ParamInfo(name="x", annotation="int", has_default=False, default_repr=None),
|
||||
ParamInfo(name="*args", annotation=None, has_default=False, default_repr=None),
|
||||
ParamInfo(name="**kwargs", annotation=None, has_default=False, default_repr=None),
|
||||
]
|
||||
result = format_init_signature("Foo", params)
|
||||
assert "Required (positional) arguments:" in result
|
||||
assert "Variadic arguments:" in result
|
||||
assert "- *args" in result
|
||||
assert "- **kwargs" in result
|
||||
|
||||
def test_untyped_params(self) -> None:
|
||||
params = [
|
||||
ParamInfo(name="x", annotation=None, has_default=False, default_repr=None),
|
||||
ParamInfo(name="y", annotation=None, has_default=True, default_repr="10"),
|
||||
]
|
||||
result = format_init_signature("Foo", params)
|
||||
assert " - x\n" in result
|
||||
assert " - y = 10" in result
|
||||
|
||||
|
||||
class TestGetClassConstructorNotes:
|
||||
def test_class_with_typed_init(self) -> None:
|
||||
context = """
|
||||
class LayoutElements:
|
||||
def __init__(self, width: int, height: int, elements: list[str]):
|
||||
self.width = width
|
||||
self.height = height
|
||||
self.elements = elements
|
||||
"""
|
||||
notes = get_class_constructor_notes(context)
|
||||
assert len(notes) == 1
|
||||
assert "Constructor signature for LayoutElements:" in notes[0]
|
||||
assert "- width: int" in notes[0]
|
||||
assert "- height: int" in notes[0]
|
||||
assert "- elements: list[str]" in notes[0]
|
||||
|
||||
def test_class_with_defaults(self) -> None:
|
||||
context = """
|
||||
class TextRegions:
|
||||
def __init__(self, text: str = "", max_len: int = 100):
|
||||
self.text = text
|
||||
self.max_len = max_len
|
||||
"""
|
||||
notes = get_class_constructor_notes(context)
|
||||
assert len(notes) == 1
|
||||
assert "Optional (keyword) arguments:" in notes[0]
|
||||
assert '- text: str = ""' in notes[0]
|
||||
assert "- max_len: int = 100" in notes[0]
|
||||
|
||||
def test_class_with_args_kwargs(self) -> None:
|
||||
context = """
|
||||
class Flexible:
|
||||
def __init__(self, name: str, *args, **kwargs):
|
||||
self.name = name
|
||||
"""
|
||||
notes = get_class_constructor_notes(context)
|
||||
assert len(notes) == 1
|
||||
assert "- name: str" in notes[0]
|
||||
assert "- *args" in notes[0]
|
||||
assert "- **kwargs" in notes[0]
|
||||
|
||||
def test_class_without_init_skipped(self) -> None:
|
||||
context = """
|
||||
class NoInit:
|
||||
x = 10
|
||||
def method(self):
|
||||
pass
|
||||
"""
|
||||
notes = get_class_constructor_notes(context)
|
||||
assert notes == []
|
||||
|
||||
def test_dataclass_skipped(self) -> None:
|
||||
context = """
|
||||
@dataclass
|
||||
class Config:
|
||||
name: str
|
||||
value: int
|
||||
"""
|
||||
notes = get_class_constructor_notes(context)
|
||||
assert notes == []
|
||||
|
||||
def test_multiple_classes(self) -> None:
|
||||
context = """
|
||||
class Foo:
|
||||
def __init__(self, x: int):
|
||||
self.x = x
|
||||
|
||||
class Bar:
|
||||
def __init__(self, y: str):
|
||||
self.y = y
|
||||
"""
|
||||
notes = get_class_constructor_notes(context)
|
||||
assert len(notes) == 2
|
||||
class_names = " ".join(notes)
|
||||
assert "Foo" in class_names
|
||||
assert "Bar" in class_names
|
||||
|
||||
def test_markdown_wrapped_code(self) -> None:
|
||||
context = """
|
||||
Some description text.
|
||||
```python:models.py
|
||||
class LayoutElements:
|
||||
def __init__(self, width: int, height: int):
|
||||
self.width = width
|
||||
self.height = height
|
||||
```
|
||||
"""
|
||||
notes = get_class_constructor_notes(context)
|
||||
assert len(notes) == 1
|
||||
assert "LayoutElements" in notes[0]
|
||||
assert "- width: int" in notes[0]
|
||||
|
||||
def test_syntax_error_returns_empty(self) -> None:
|
||||
context = "this is not valid python {{{{"
|
||||
notes = get_class_constructor_notes(context)
|
||||
assert notes == []
|
||||
|
||||
def test_init_with_only_self(self) -> None:
|
||||
context = """
|
||||
class Empty:
|
||||
def __init__(self):
|
||||
pass
|
||||
"""
|
||||
notes = get_class_constructor_notes(context)
|
||||
assert len(notes) == 1
|
||||
assert "Empty() - no parameters" in notes[0]
|
||||
|
||||
def test_mixed_dataclass_and_regular(self) -> None:
|
||||
context = """
|
||||
@dataclass
|
||||
class Config:
|
||||
name: str
|
||||
|
||||
class Service:
|
||||
def __init__(self, config: Config, debug: bool = False):
|
||||
self.config = config
|
||||
self.debug = debug
|
||||
"""
|
||||
notes = get_class_constructor_notes(context)
|
||||
assert len(notes) == 1
|
||||
assert "Service" in notes[0]
|
||||
assert "Config" not in notes[0].split("\n")[0]
|
||||
|
||||
def test_plain_code_without_markdown(self) -> None:
|
||||
context = """
|
||||
class MyClass:
|
||||
def __init__(self, value: int):
|
||||
self.value = value
|
||||
"""
|
||||
notes = get_class_constructor_notes(context)
|
||||
assert len(notes) == 1
|
||||
assert "MyClass" in notes[0]
|
||||
|
||||
def test_full_output_format(self) -> None:
|
||||
context = """
|
||||
class Server:
|
||||
def __init__(self, host: str, port: int, debug: bool = False, workers: int = 4):
|
||||
self.host = host
|
||||
self.port = port
|
||||
self.debug = debug
|
||||
self.workers = workers
|
||||
"""
|
||||
notes = get_class_constructor_notes(context)
|
||||
assert len(notes) == 1
|
||||
expected = """Constructor signature for Server:
|
||||
Required (positional) arguments:
|
||||
- host: str
|
||||
- port: int
|
||||
Optional (keyword) arguments:
|
||||
- debug: bool = False
|
||||
- workers: int = 4"""
|
||||
assert notes[0] == expected
|
||||
|
|
@ -1,529 +0,0 @@
|
|||
"""Tests for dataclass constructor signature extraction."""
|
||||
|
||||
import libcst as cst
|
||||
|
||||
from core.languages.python.cst_utils import get_base_class_name, has_decorator
|
||||
from aiservice.common.markdown_utils import extract_all_code_from_markdown
|
||||
from core.languages.python.testgen.preprocessing.dataclass_constructor_notes import (
|
||||
FieldInfo,
|
||||
extract_dataclass_fields,
|
||||
find_all_dataclasses,
|
||||
format_constructor_signature,
|
||||
get_all_fields_with_inheritance,
|
||||
get_dataclass_constructor_notes,
|
||||
)
|
||||
|
||||
|
||||
class TestHasDecorator:
|
||||
def test_simple_dataclass_decorator(self) -> None:
|
||||
code = """
|
||||
@dataclass
|
||||
class Foo:
|
||||
x: int
|
||||
"""
|
||||
tree = cst.parse_module(code)
|
||||
class_node = tree.body[0]
|
||||
assert isinstance(class_node, cst.ClassDef)
|
||||
assert has_decorator(class_node, "dataclass")
|
||||
|
||||
def test_dataclass_with_call(self) -> None:
|
||||
code = """
|
||||
@dataclass(frozen=True)
|
||||
class Foo:
|
||||
x: int
|
||||
"""
|
||||
tree = cst.parse_module(code)
|
||||
class_node = tree.body[0]
|
||||
assert isinstance(class_node, cst.ClassDef)
|
||||
assert has_decorator(class_node, "dataclass")
|
||||
|
||||
def test_not_a_dataclass(self) -> None:
|
||||
code = """
|
||||
class Foo:
|
||||
x: int
|
||||
"""
|
||||
tree = cst.parse_module(code)
|
||||
class_node = tree.body[0]
|
||||
assert isinstance(class_node, cst.ClassDef)
|
||||
assert not has_decorator(class_node, "dataclass")
|
||||
|
||||
def test_other_decorator(self) -> None:
|
||||
code = """
|
||||
@other_decorator
|
||||
class Foo:
|
||||
x: int
|
||||
"""
|
||||
tree = cst.parse_module(code)
|
||||
class_node = tree.body[0]
|
||||
assert isinstance(class_node, cst.ClassDef)
|
||||
assert not has_decorator(class_node, "dataclass")
|
||||
|
||||
|
||||
class TestGetBaseClassName:
|
||||
def test_simple_name(self) -> None:
|
||||
code = "class Foo(Bar): pass"
|
||||
tree = cst.parse_module(code)
|
||||
class_node = tree.body[0]
|
||||
assert isinstance(class_node, cst.ClassDef)
|
||||
assert get_base_class_name(class_node.bases[0]) == "Bar"
|
||||
|
||||
def test_attribute(self) -> None:
|
||||
code = "class Foo(module.Bar): pass"
|
||||
tree = cst.parse_module(code)
|
||||
class_node = tree.body[0]
|
||||
assert isinstance(class_node, cst.ClassDef)
|
||||
assert get_base_class_name(class_node.bases[0]) == "Bar"
|
||||
|
||||
|
||||
class TestExtractDataclassFields:
|
||||
def test_simple_fields(self) -> None:
|
||||
code = """
|
||||
@dataclass
|
||||
class Foo:
|
||||
x: int
|
||||
y: str
|
||||
"""
|
||||
dataclasses, source_lines, wrapper = find_all_dataclasses(code)
|
||||
assert wrapper is not None
|
||||
class_node = dataclasses["Foo"]
|
||||
fields = extract_dataclass_fields(class_node, source_lines, wrapper)
|
||||
|
||||
assert len(fields) == 2
|
||||
assert fields[0].name == "x"
|
||||
assert fields[0].annotation == "int"
|
||||
assert not fields[0].has_default
|
||||
assert fields[1].name == "y"
|
||||
assert fields[1].annotation == "str"
|
||||
assert not fields[1].has_default
|
||||
|
||||
def test_fields_with_defaults(self) -> None:
|
||||
code = """
|
||||
@dataclass
|
||||
class Foo:
|
||||
x: int
|
||||
y: str = "default"
|
||||
"""
|
||||
dataclasses, source_lines, wrapper = find_all_dataclasses(code)
|
||||
assert wrapper is not None
|
||||
class_node = dataclasses["Foo"]
|
||||
fields = extract_dataclass_fields(class_node, source_lines, wrapper)
|
||||
|
||||
assert len(fields) == 2
|
||||
assert fields[0].name == "x"
|
||||
assert not fields[0].has_default
|
||||
assert fields[1].name == "y"
|
||||
assert fields[1].has_default
|
||||
assert fields[1].default_repr == '"default"'
|
||||
|
||||
def test_complex_annotation(self) -> None:
|
||||
code = """
|
||||
@dataclass
|
||||
class Foo:
|
||||
items: list[str]
|
||||
mapping: dict[str, int] | None = None
|
||||
"""
|
||||
dataclasses, source_lines, wrapper = find_all_dataclasses(code)
|
||||
assert wrapper is not None
|
||||
class_node = dataclasses["Foo"]
|
||||
fields = extract_dataclass_fields(class_node, source_lines, wrapper)
|
||||
|
||||
assert len(fields) == 2
|
||||
assert fields[0].name == "items"
|
||||
assert fields[0].annotation == "list[str]"
|
||||
assert fields[1].name == "mapping"
|
||||
assert "dict[str, int]" in fields[1].annotation
|
||||
|
||||
|
||||
class TestFindAllDataclasses:
|
||||
def test_finds_single_dataclass(self) -> None:
|
||||
code = """
|
||||
from dataclasses import dataclass
|
||||
|
||||
@dataclass
|
||||
class Foo:
|
||||
x: int
|
||||
"""
|
||||
result, source_lines, wrapper = find_all_dataclasses(code)
|
||||
assert "Foo" in result
|
||||
assert wrapper is not None
|
||||
assert len(source_lines) > 0
|
||||
|
||||
def test_finds_multiple_dataclasses(self) -> None:
|
||||
code = """
|
||||
from dataclasses import dataclass
|
||||
|
||||
@dataclass
|
||||
class Foo:
|
||||
x: int
|
||||
|
||||
@dataclass
|
||||
class Bar:
|
||||
y: str
|
||||
"""
|
||||
result, _source_lines, wrapper = find_all_dataclasses(code)
|
||||
assert "Foo" in result
|
||||
assert "Bar" in result
|
||||
assert wrapper is not None
|
||||
|
||||
def test_ignores_non_dataclasses(self) -> None:
|
||||
code = """
|
||||
class NotADataclass:
|
||||
x: int
|
||||
|
||||
@dataclass
|
||||
class IsADataclass:
|
||||
y: str
|
||||
"""
|
||||
result, _source_lines, wrapper = find_all_dataclasses(code)
|
||||
assert "NotADataclass" not in result
|
||||
assert "IsADataclass" in result
|
||||
assert wrapper is not None
|
||||
|
||||
def test_handles_syntax_error(self) -> None:
|
||||
code = "this is not valid python code {{{{"
|
||||
result, _source_lines, wrapper = find_all_dataclasses(code)
|
||||
assert result == {}
|
||||
assert wrapper is None
|
||||
|
||||
|
||||
class TestGetAllFieldsWithInheritance:
|
||||
def test_no_inheritance(self) -> None:
|
||||
code = """
|
||||
@dataclass
|
||||
class Foo:
|
||||
x: int
|
||||
y: str
|
||||
"""
|
||||
dataclasses, source_lines, wrapper = find_all_dataclasses(code)
|
||||
assert wrapper is not None
|
||||
fields = get_all_fields_with_inheritance("Foo", dataclasses, source_lines, wrapper)
|
||||
|
||||
assert len(fields) == 2
|
||||
assert all(f.source == "own" for f in fields)
|
||||
|
||||
def test_simple_inheritance(self) -> None:
|
||||
code = """
|
||||
@dataclass
|
||||
class Base:
|
||||
a: int
|
||||
b: str
|
||||
|
||||
@dataclass
|
||||
class Child(Base):
|
||||
c: float
|
||||
"""
|
||||
dataclasses, source_lines, wrapper = find_all_dataclasses(code)
|
||||
assert wrapper is not None
|
||||
fields = get_all_fields_with_inheritance("Child", dataclasses, source_lines, wrapper)
|
||||
|
||||
assert len(fields) == 3
|
||||
assert fields[0].name == "a"
|
||||
assert fields[0].source == "inherited"
|
||||
assert fields[1].name == "b"
|
||||
assert fields[1].source == "inherited"
|
||||
assert fields[2].name == "c"
|
||||
assert fields[2].source == "own"
|
||||
|
||||
def test_multi_level_inheritance(self) -> None:
|
||||
code = """
|
||||
@dataclass
|
||||
class GrandParent:
|
||||
a: int
|
||||
|
||||
@dataclass
|
||||
class Parent(GrandParent):
|
||||
b: str
|
||||
|
||||
@dataclass
|
||||
class Child(Parent):
|
||||
c: float
|
||||
"""
|
||||
dataclasses, source_lines, wrapper = find_all_dataclasses(code)
|
||||
assert wrapper is not None
|
||||
fields = get_all_fields_with_inheritance("Child", dataclasses, source_lines, wrapper)
|
||||
|
||||
assert len(fields) == 3
|
||||
field_names = [f.name for f in fields]
|
||||
assert "a" in field_names
|
||||
assert "b" in field_names
|
||||
assert "c" in field_names
|
||||
|
||||
def test_unknown_class(self) -> None:
|
||||
code = """
|
||||
@dataclass
|
||||
class Foo:
|
||||
x: int
|
||||
"""
|
||||
dataclasses, source_lines, wrapper = find_all_dataclasses(code)
|
||||
assert wrapper is not None
|
||||
fields = get_all_fields_with_inheritance("Unknown", dataclasses, source_lines, wrapper)
|
||||
assert fields == []
|
||||
|
||||
|
||||
class TestFormatConstructorSignature:
|
||||
def test_required_only(self) -> None:
|
||||
fields = [
|
||||
FieldInfo(name="x", annotation="int", has_default=False, default_repr=None, source="own"),
|
||||
FieldInfo(name="y", annotation="str", has_default=False, default_repr=None, source="own"),
|
||||
]
|
||||
result = format_constructor_signature("Foo", fields)
|
||||
assert "Constructor signature for Foo:" in result
|
||||
assert "Required (positional) arguments:" in result
|
||||
assert "- x: int" in result
|
||||
assert "- y: str" in result
|
||||
assert "Optional" not in result
|
||||
|
||||
def test_optional_only(self) -> None:
|
||||
fields = [
|
||||
FieldInfo(name="x", annotation="int", has_default=True, default_repr="0", source="own"),
|
||||
FieldInfo(name="y", annotation="str", has_default=True, default_repr='"default"', source="own"),
|
||||
]
|
||||
result = format_constructor_signature("Foo", fields)
|
||||
assert "Constructor signature for Foo:" in result
|
||||
assert "Optional (keyword) arguments:" in result
|
||||
assert "- x: int = 0" in result
|
||||
assert '- y: str = "default"' in result
|
||||
assert "Required" not in result
|
||||
|
||||
def test_mixed_fields(self) -> None:
|
||||
fields = [
|
||||
FieldInfo(name="a", annotation="int", has_default=False, default_repr=None, source="inherited"),
|
||||
FieldInfo(name="b", annotation="str", has_default=False, default_repr=None, source="own"),
|
||||
FieldInfo(name="c", annotation="float", has_default=True, default_repr="0.0", source="own"),
|
||||
]
|
||||
result = format_constructor_signature("Foo", fields)
|
||||
assert "Required (positional) arguments:" in result
|
||||
assert "Optional (keyword) arguments:" in result
|
||||
assert "(from parent class)" in result
|
||||
|
||||
def test_no_fields(self) -> None:
|
||||
result = format_constructor_signature("Empty", [])
|
||||
assert "no fields" in result
|
||||
|
||||
|
||||
class TestExtractCodeFromMarkdown:
|
||||
def test_simple_code_block(self) -> None:
|
||||
markdown = """
|
||||
Some text
|
||||
```python
|
||||
def foo():
|
||||
pass
|
||||
```
|
||||
More text
|
||||
"""
|
||||
result = extract_all_code_from_markdown(markdown)
|
||||
assert "def foo():" in result
|
||||
assert "pass" in result
|
||||
assert "Some text" not in result
|
||||
|
||||
def test_code_block_with_filepath(self) -> None:
|
||||
markdown = """
|
||||
```python:path/to/file.py
|
||||
class Foo:
|
||||
x: int
|
||||
```
|
||||
"""
|
||||
result = extract_all_code_from_markdown(markdown)
|
||||
assert "class Foo:" in result
|
||||
assert "x: int" in result
|
||||
|
||||
def test_multiple_code_blocks(self) -> None:
|
||||
markdown = """
|
||||
```python:file1.py
|
||||
@dataclass
|
||||
class A:
|
||||
x: int
|
||||
```
|
||||
|
||||
```python:file2.py
|
||||
@dataclass
|
||||
class B:
|
||||
y: str
|
||||
```
|
||||
"""
|
||||
result = extract_all_code_from_markdown(markdown)
|
||||
assert "class A:" in result
|
||||
assert "class B:" in result
|
||||
|
||||
|
||||
class TestGetDataclassConstructorNotes:
|
||||
def test_simple_dataclass(self) -> None:
|
||||
context = """
|
||||
```python:models.py
|
||||
from dataclasses import dataclass
|
||||
|
||||
@dataclass
|
||||
class Config:
|
||||
name: str
|
||||
value: int
|
||||
```
|
||||
"""
|
||||
notes = get_dataclass_constructor_notes(context)
|
||||
assert len(notes) == 1
|
||||
assert "Constructor signature for Config:" in notes[0]
|
||||
assert "- name: str" in notes[0]
|
||||
assert "- value: int" in notes[0]
|
||||
|
||||
def test_dataclass_with_inheritance(self) -> None:
|
||||
context = """
|
||||
```python:models.py
|
||||
from dataclasses import dataclass
|
||||
|
||||
@dataclass
|
||||
class BaseConfig:
|
||||
model_name: str
|
||||
required_env_vars: list[str]
|
||||
|
||||
@dataclass
|
||||
class ExtendedConfig(BaseConfig):
|
||||
extra_param: int = 0
|
||||
```
|
||||
"""
|
||||
notes = get_dataclass_constructor_notes(context)
|
||||
assert len(notes) == 2
|
||||
|
||||
# Find the note for ExtendedConfig
|
||||
extended_note = next(n for n in notes if "ExtendedConfig" in n)
|
||||
assert "- model_name: str (from parent class)" in extended_note
|
||||
assert "- required_env_vars: list[str] (from parent class)" in extended_note
|
||||
assert "- extra_param: int = 0" in extended_note
|
||||
|
||||
def test_no_dataclasses(self) -> None:
|
||||
context = """
|
||||
def regular_function():
|
||||
pass
|
||||
"""
|
||||
notes = get_dataclass_constructor_notes(context)
|
||||
assert notes == []
|
||||
|
||||
def test_plain_code_without_markdown(self) -> None:
|
||||
context = """
|
||||
from dataclasses import dataclass
|
||||
|
||||
@dataclass
|
||||
class SimpleConfig:
|
||||
name: str
|
||||
"""
|
||||
notes = get_dataclass_constructor_notes(context)
|
||||
assert len(notes) == 1
|
||||
assert "SimpleConfig" in notes[0]
|
||||
|
||||
def test_llm_config_like_structure(self) -> None:
|
||||
"""Test with a structure similar to the skyvern LLMConfig."""
|
||||
context = """
|
||||
```python:skyvern/forge/sdk/api/llm/models.py
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class LLMConfigBase:
|
||||
model_name: str
|
||||
required_env_vars: list[str]
|
||||
supports_vision: bool
|
||||
add_assistant_prefix: bool
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class LLMConfig(LLMConfigBase):
|
||||
litellm_params: Optional[dict] = field(default=None)
|
||||
max_tokens: int | None = 4096
|
||||
```
|
||||
"""
|
||||
notes = get_dataclass_constructor_notes(context)
|
||||
|
||||
# Should have notes for both classes
|
||||
assert len(notes) == 2
|
||||
|
||||
# Find the LLMConfig note
|
||||
llm_config_note = next(n for n in notes if "LLMConfig:" in n)
|
||||
|
||||
# Should show inherited fields from LLMConfigBase
|
||||
assert "- model_name: str (from parent class)" in llm_config_note
|
||||
assert "- required_env_vars: list[str] (from parent class)" in llm_config_note
|
||||
assert "- supports_vision: bool (from parent class)" in llm_config_note
|
||||
assert "- add_assistant_prefix: bool (from parent class)" in llm_config_note
|
||||
|
||||
# Should also show own fields
|
||||
assert "- litellm_params:" in llm_config_note
|
||||
assert "- max_tokens:" in llm_config_note
|
||||
|
||||
def test_full_output_for_complex_inheritance(self) -> None:
|
||||
"""Test full generated notes output for complex dataclass inheritance."""
|
||||
context = """
|
||||
```python:skyvern/forge/sdk/api/llm/models.py
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class LLMConfigBase:
|
||||
model_name: str
|
||||
required_env_vars: list[str]
|
||||
supports_vision: bool
|
||||
add_assistant_prefix: bool
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class LLMConfig(LLMConfigBase):
|
||||
litellm_params: Optional[dict] = field(default=None)
|
||||
max_tokens: int | None = 4096
|
||||
max_completion_tokens: int | None = None
|
||||
temperature: float | None = 0.7
|
||||
reasoning_effort: str | None = None
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class LLMRouterConfig(LLMConfigBase):
|
||||
model_list: list
|
||||
main_model_group: str
|
||||
redis_host: str | None = None
|
||||
redis_port: int | None = None
|
||||
fallback_model_group: str | None = None
|
||||
routing_strategy: str = "usage-based-routing"
|
||||
num_retries: int = 1
|
||||
```
|
||||
"""
|
||||
notes = get_dataclass_constructor_notes(context)
|
||||
|
||||
# Should have notes for all 3 classes
|
||||
assert len(notes) == 3
|
||||
|
||||
# Verify LLMConfigBase note (no inheritance)
|
||||
base_note = next(n for n in notes if "LLMConfigBase:" in n)
|
||||
expected_base = """Constructor signature for LLMConfigBase:
|
||||
Required (positional) arguments:
|
||||
- model_name: str
|
||||
- required_env_vars: list[str]
|
||||
- supports_vision: bool
|
||||
- add_assistant_prefix: bool"""
|
||||
assert base_note == expected_base
|
||||
|
||||
# Verify LLMConfig note (with inheritance)
|
||||
config_note = next(n for n in notes if "LLMConfig:" in n)
|
||||
expected_config = """Constructor signature for LLMConfig:
|
||||
Required (positional) arguments:
|
||||
- model_name: str (from parent class)
|
||||
- required_env_vars: list[str] (from parent class)
|
||||
- supports_vision: bool (from parent class)
|
||||
- add_assistant_prefix: bool (from parent class)
|
||||
Optional (keyword) arguments:
|
||||
- litellm_params: Optional[dict] = field(default=None)
|
||||
- max_tokens: int | None = 4096
|
||||
- max_completion_tokens: int | None = None
|
||||
- temperature: float | None = 0.7
|
||||
- reasoning_effort: str | None = None"""
|
||||
assert config_note == expected_config
|
||||
|
||||
# Verify LLMRouterConfig note (with inheritance + own required fields)
|
||||
router_note = next(n for n in notes if "LLMRouterConfig:" in n)
|
||||
expected_router = """Constructor signature for LLMRouterConfig:
|
||||
Required (positional) arguments:
|
||||
- model_name: str (from parent class)
|
||||
- required_env_vars: list[str] (from parent class)
|
||||
- supports_vision: bool (from parent class)
|
||||
- add_assistant_prefix: bool (from parent class)
|
||||
- model_list: list
|
||||
- main_model_group: str
|
||||
Optional (keyword) arguments:
|
||||
- redis_host: str | None = None
|
||||
- redis_port: int | None = None
|
||||
- fallback_model_group: str | None = None
|
||||
- routing_strategy: str = "usage-based-routing"
|
||||
- num_retries: int = 1"""
|
||||
assert router_note == expected_router
|
||||
Loading…
Reference in a new issue