feat: add constructor notes for non-dataclass classes with __init__

The LLM prompt preprocessing now highlights __init__ signatures for
regular classes, not just @dataclass ones, reducing brute-force
constructor guessing and pytest.skip() fallbacks in generated tests.
This commit is contained in:
Kevin Turcios 2026-02-15 07:29:05 -05:00
parent 38eda0c2d6
commit d6a3c6254f
3 changed files with 578 additions and 1 deletions

View file

@ -0,0 +1,167 @@
"""Preprocessing step to extract constructor signatures from regular (non-dataclass) classes.
When the test context contains class definitions with explicit __init__ methods,
this module extracts the constructor signatures to help the LLM understand how
to properly instantiate them.
"""
from __future__ import annotations
from typing import NamedTuple
import libcst as cst
from libcst.metadata import MetadataWrapper
from aiservice.common.markdown_utils import extract_all_code_from_markdown
from core.languages.python.cst_utils import get_node_source_text, has_decorator, parse_module_to_cst
class ParamInfo(NamedTuple):
name: str
annotation: str | None
has_default: bool
default_repr: str | None
def _find_init_method(class_node: cst.ClassDef) -> cst.FunctionDef | None:
for stmt in class_node.body.body:
if isinstance(stmt, cst.FunctionDef) and stmt.name.value == "__init__":
return stmt
return None
def extract_init_params(
init_node: cst.FunctionDef, source_lines: list[str], wrapper: MetadataWrapper
) -> list[ParamInfo]:
params: list[ParamInfo] = []
parameters = init_node.params
for param in parameters.params:
name = param.name.value
if name == "self":
continue
annotation = None
if param.annotation is not None:
annotation = get_node_source_text(param.annotation, source_lines, wrapper)
has_default = param.default is not None
default_repr = None
if has_default:
default_repr = get_node_source_text(param.default, source_lines, wrapper)
if len(default_repr) > 50:
default_repr = default_repr[:47] + "..."
params.append(ParamInfo(name, annotation, has_default, default_repr))
if isinstance(parameters.star_arg, cst.Param):
name = f"*{parameters.star_arg.name.value}"
annotation = None
if parameters.star_arg.annotation is not None:
annotation = get_node_source_text(parameters.star_arg.annotation, source_lines, wrapper)
params.append(ParamInfo(name, annotation, has_default=False, default_repr=None))
for param in parameters.kwonly_params:
name = param.name.value
annotation = None
if param.annotation is not None:
annotation = get_node_source_text(param.annotation, source_lines, wrapper)
has_default = param.default is not None
default_repr = None
if has_default:
default_repr = get_node_source_text(param.default, source_lines, wrapper)
if len(default_repr) > 50:
default_repr = default_repr[:47] + "..."
params.append(ParamInfo(name, annotation, has_default, default_repr))
if parameters.star_kwarg is not None:
name = f"**{parameters.star_kwarg.name.value}"
annotation = None
if parameters.star_kwarg.annotation is not None:
annotation = get_node_source_text(parameters.star_kwarg.annotation, source_lines, wrapper)
params.append(ParamInfo(name, annotation, has_default=False, default_repr=None))
return params
class ClassInitCollector(cst.CSTVisitor):
def __init__(self) -> None:
self.classes: dict[str, cst.ClassDef] = {}
def visit_ClassDef(self, node: cst.ClassDef) -> bool:
if not has_decorator(node, "dataclass") and _find_init_method(node) is not None:
self.classes[node.name.value] = node
return True
def _find_all_classes_with_init(source_code: str) -> tuple[dict[str, cst.ClassDef], list[str], MetadataWrapper | None]:
try:
tree = parse_module_to_cst(source_code)
except cst.ParserSyntaxError:
return {}, [], None
wrapper = MetadataWrapper(tree)
collector = ClassInitCollector()
wrapper.visit(collector)
source_lines = source_code.split("\n")
return collector.classes, source_lines, wrapper
def format_init_signature(class_name: str, params: list[ParamInfo]) -> str:
if not params:
return f"{class_name}() - no parameters"
required = [p for p in params if not p.has_default and not p.name.startswith("*")]
optional = [p for p in params if p.has_default]
variadic = [p for p in params if p.name.startswith("*")]
lines = [f"Constructor signature for {class_name}:"]
if required:
lines.append(" Required (positional) arguments:")
for p in required:
if p.annotation:
lines.append(f" - {p.name}: {p.annotation}")
else:
lines.append(f" - {p.name}")
if optional:
lines.append(" Optional (keyword) arguments:")
for p in optional:
default_note = f" = {p.default_repr}" if p.default_repr else " = ..."
if p.annotation:
lines.append(f" - {p.name}: {p.annotation}{default_note}")
else:
lines.append(f" - {p.name}{default_note}")
if variadic:
lines.append(" Variadic arguments:")
for p in variadic:
if p.annotation:
lines.append(f" - {p.name}: {p.annotation}")
else:
lines.append(f" - {p.name}")
return "\n".join(lines)
def get_class_constructor_notes(test_context: str) -> list[str]:
code_to_analyze = extract_all_code_from_markdown(test_context) if "```python" in test_context else test_context
classes, source_lines, wrapper = _find_all_classes_with_init(code_to_analyze)
if not classes or wrapper is None:
return []
notes = []
for class_name, class_node in classes.items():
init_node = _find_init_method(class_node)
if init_node is None:
continue
params = extract_init_params(init_node, source_lines, wrapper)
signature = format_init_signature(class_name, params)
notes.append(signature)
return notes

View file

@ -1,11 +1,12 @@
from itertools import chain
from core.languages.python.testgen.preprocessing.class_constructor_notes import get_class_constructor_notes
from core.languages.python.testgen.preprocessing.dataclass_constructor_notes import get_dataclass_constructor_notes
from core.languages.python.testgen.preprocessing.torch_tensor_limit import get_tensor_size_note
# Preprocessing functions that analyze code context and return notes
# Each function takes test_context (str) and returns a list of notes (list[str])
_PREPROCESSING_FUNCTIONS = [get_tensor_size_note, get_dataclass_constructor_notes]
_PREPROCESSING_FUNCTIONS = [get_tensor_size_note, get_dataclass_constructor_notes, get_class_constructor_notes]
def preprocessing_testgen_pipeline(test_context: str) -> list[str]:

View file

@ -0,0 +1,409 @@
"""Tests for regular class constructor signature extraction."""
from core.languages.python.testgen.preprocessing.class_constructor_notes import (
ParamInfo,
_find_all_classes_with_init,
extract_init_params,
format_init_signature,
get_class_constructor_notes,
)
class TestClassInitCollector:
def test_collects_class_with_init(self) -> None:
code = """
class Foo:
def __init__(self, x: int):
self.x = x
"""
classes, _, wrapper = _find_all_classes_with_init(code)
assert "Foo" in classes
assert wrapper is not None
def test_skips_class_without_init(self) -> None:
code = """
class Foo:
def bar(self):
pass
"""
classes, _, _ = _find_all_classes_with_init(code)
assert "Foo" not in classes
def test_skips_dataclass(self) -> None:
code = """
@dataclass
class Foo:
x: int
def __init__(self, x: int):
self.x = x
"""
classes, _, _ = _find_all_classes_with_init(code)
assert "Foo" not in classes
def test_skips_dataclass_with_call(self) -> None:
code = """
@dataclass(frozen=True)
class Foo:
x: int
def __init__(self, x: int):
self.x = x
"""
classes, _, _ = _find_all_classes_with_init(code)
assert "Foo" not in classes
def test_collects_multiple_classes(self) -> None:
code = """
class Foo:
def __init__(self, x: int):
self.x = x
class Bar:
def __init__(self, y: str):
self.y = y
"""
classes, _, _ = _find_all_classes_with_init(code)
assert "Foo" in classes
assert "Bar" in classes
def test_handles_syntax_error(self) -> None:
code = "this is not valid python code {{{{"
classes, _, wrapper = _find_all_classes_with_init(code)
assert classes == {}
assert wrapper is None
class TestExtractInitParams:
def test_typed_params(self) -> None:
code = """
class Foo:
def __init__(self, x: int, y: str):
pass
"""
classes, source_lines, wrapper = _find_all_classes_with_init(code)
assert wrapper is not None
init_node = next(iter(classes["Foo"].body.body))
params = extract_init_params(init_node, source_lines, wrapper)
assert len(params) == 2
assert params[0].name == "x"
assert params[0].annotation == "int"
assert not params[0].has_default
assert params[1].name == "y"
assert params[1].annotation == "str"
assert not params[1].has_default
def test_params_with_defaults(self) -> None:
code = """
class Foo:
def __init__(self, x: int = 0, y: str = "hello"):
pass
"""
classes, source_lines, wrapper = _find_all_classes_with_init(code)
assert wrapper is not None
init_node = next(iter(classes["Foo"].body.body))
params = extract_init_params(init_node, source_lines, wrapper)
assert len(params) == 2
assert params[0].name == "x"
assert params[0].has_default
assert params[0].default_repr == "0"
assert params[1].name == "y"
assert params[1].has_default
assert params[1].default_repr == '"hello"'
def test_args_and_kwargs(self) -> None:
code = """
class Foo:
def __init__(self, x: int, *args, **kwargs):
pass
"""
classes, source_lines, wrapper = _find_all_classes_with_init(code)
assert wrapper is not None
init_node = next(iter(classes["Foo"].body.body))
params = extract_init_params(init_node, source_lines, wrapper)
assert len(params) == 3
assert params[0].name == "x"
assert params[1].name == "*args"
assert params[2].name == "**kwargs"
def test_skips_self(self) -> None:
code = """
class Foo:
def __init__(self):
pass
"""
classes, source_lines, wrapper = _find_all_classes_with_init(code)
assert wrapper is not None
init_node = next(iter(classes["Foo"].body.body))
params = extract_init_params(init_node, source_lines, wrapper)
assert len(params) == 0
def test_untyped_params(self) -> None:
code = """
class Foo:
def __init__(self, x, y=10):
pass
"""
classes, source_lines, wrapper = _find_all_classes_with_init(code)
assert wrapper is not None
init_node = next(iter(classes["Foo"].body.body))
params = extract_init_params(init_node, source_lines, wrapper)
assert len(params) == 2
assert params[0].name == "x"
assert params[0].annotation is None
assert params[1].name == "y"
assert params[1].annotation is None
assert params[1].has_default
assert params[1].default_repr == "10"
def test_kwonly_params(self) -> None:
code = """
class Foo:
def __init__(self, x: int, *, key: str = "default"):
pass
"""
classes, source_lines, wrapper = _find_all_classes_with_init(code)
assert wrapper is not None
init_node = next(iter(classes["Foo"].body.body))
params = extract_init_params(init_node, source_lines, wrapper)
assert len(params) == 2
assert params[0].name == "x"
assert params[1].name == "key"
assert params[1].annotation == "str"
assert params[1].has_default
assert params[1].default_repr == '"default"'
def test_long_default_truncated(self) -> None:
code = """
class Foo:
def __init__(self, x: list = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]):
pass
"""
classes, source_lines, wrapper = _find_all_classes_with_init(code)
assert wrapper is not None
init_node = next(iter(classes["Foo"].body.body))
params = extract_init_params(init_node, source_lines, wrapper)
assert len(params) == 1
assert params[0].has_default
assert params[0].default_repr is not None
assert len(params[0].default_repr) <= 50
assert params[0].default_repr.endswith("...")
class TestFormatInitSignature:
def test_required_only(self) -> None:
params = [
ParamInfo(name="x", annotation="int", has_default=False, default_repr=None),
ParamInfo(name="y", annotation="str", has_default=False, default_repr=None),
]
result = format_init_signature("Foo", params)
assert "Constructor signature for Foo:" in result
assert "Required (positional) arguments:" in result
assert "- x: int" in result
assert "- y: str" in result
assert "Optional" not in result
def test_optional_only(self) -> None:
params = [
ParamInfo(name="x", annotation="int", has_default=True, default_repr="0"),
ParamInfo(name="y", annotation="str", has_default=True, default_repr='"hi"'),
]
result = format_init_signature("Foo", params)
assert "Constructor signature for Foo:" in result
assert "Optional (keyword) arguments:" in result
assert "- x: int = 0" in result
assert '- y: str = "hi"' in result
assert "Required" not in result
def test_mixed_params(self) -> None:
params = [
ParamInfo(name="x", annotation="int", has_default=False, default_repr=None),
ParamInfo(name="y", annotation="str", has_default=True, default_repr='"default"'),
]
result = format_init_signature("Foo", params)
assert "Required (positional) arguments:" in result
assert "Optional (keyword) arguments:" in result
def test_no_params(self) -> None:
result = format_init_signature("Empty", [])
assert "Empty() - no parameters" in result
def test_variadic_params(self) -> None:
params = [
ParamInfo(name="x", annotation="int", has_default=False, default_repr=None),
ParamInfo(name="*args", annotation=None, has_default=False, default_repr=None),
ParamInfo(name="**kwargs", annotation=None, has_default=False, default_repr=None),
]
result = format_init_signature("Foo", params)
assert "Required (positional) arguments:" in result
assert "Variadic arguments:" in result
assert "- *args" in result
assert "- **kwargs" in result
def test_untyped_params(self) -> None:
params = [
ParamInfo(name="x", annotation=None, has_default=False, default_repr=None),
ParamInfo(name="y", annotation=None, has_default=True, default_repr="10"),
]
result = format_init_signature("Foo", params)
assert " - x\n" in result
assert " - y = 10" in result
class TestGetClassConstructorNotes:
def test_class_with_typed_init(self) -> None:
context = """
class LayoutElements:
def __init__(self, width: int, height: int, elements: list[str]):
self.width = width
self.height = height
self.elements = elements
"""
notes = get_class_constructor_notes(context)
assert len(notes) == 1
assert "Constructor signature for LayoutElements:" in notes[0]
assert "- width: int" in notes[0]
assert "- height: int" in notes[0]
assert "- elements: list[str]" in notes[0]
def test_class_with_defaults(self) -> None:
context = """
class TextRegions:
def __init__(self, text: str = "", max_len: int = 100):
self.text = text
self.max_len = max_len
"""
notes = get_class_constructor_notes(context)
assert len(notes) == 1
assert "Optional (keyword) arguments:" in notes[0]
assert '- text: str = ""' in notes[0]
assert "- max_len: int = 100" in notes[0]
def test_class_with_args_kwargs(self) -> None:
context = """
class Flexible:
def __init__(self, name: str, *args, **kwargs):
self.name = name
"""
notes = get_class_constructor_notes(context)
assert len(notes) == 1
assert "- name: str" in notes[0]
assert "- *args" in notes[0]
assert "- **kwargs" in notes[0]
def test_class_without_init_skipped(self) -> None:
context = """
class NoInit:
x = 10
def method(self):
pass
"""
notes = get_class_constructor_notes(context)
assert notes == []
def test_dataclass_skipped(self) -> None:
context = """
@dataclass
class Config:
name: str
value: int
"""
notes = get_class_constructor_notes(context)
assert notes == []
def test_multiple_classes(self) -> None:
context = """
class Foo:
def __init__(self, x: int):
self.x = x
class Bar:
def __init__(self, y: str):
self.y = y
"""
notes = get_class_constructor_notes(context)
assert len(notes) == 2
class_names = " ".join(notes)
assert "Foo" in class_names
assert "Bar" in class_names
def test_markdown_wrapped_code(self) -> None:
context = """
Some description text.
```python:models.py
class LayoutElements:
def __init__(self, width: int, height: int):
self.width = width
self.height = height
```
"""
notes = get_class_constructor_notes(context)
assert len(notes) == 1
assert "LayoutElements" in notes[0]
assert "- width: int" in notes[0]
def test_syntax_error_returns_empty(self) -> None:
context = "this is not valid python {{{{"
notes = get_class_constructor_notes(context)
assert notes == []
def test_init_with_only_self(self) -> None:
context = """
class Empty:
def __init__(self):
pass
"""
notes = get_class_constructor_notes(context)
assert len(notes) == 1
assert "Empty() - no parameters" in notes[0]
def test_mixed_dataclass_and_regular(self) -> None:
context = """
@dataclass
class Config:
name: str
class Service:
def __init__(self, config: Config, debug: bool = False):
self.config = config
self.debug = debug
"""
notes = get_class_constructor_notes(context)
assert len(notes) == 1
assert "Service" in notes[0]
assert "Config" not in notes[0].split("\n")[0]
def test_plain_code_without_markdown(self) -> None:
context = """
class MyClass:
def __init__(self, value: int):
self.value = value
"""
notes = get_class_constructor_notes(context)
assert len(notes) == 1
assert "MyClass" in notes[0]
def test_full_output_format(self) -> None:
context = """
class Server:
def __init__(self, host: str, port: int, debug: bool = False, workers: int = 4):
self.host = host
self.port = port
self.debug = debug
self.workers = workers
"""
notes = get_class_constructor_notes(context)
assert len(notes) == 1
expected = """Constructor signature for Server:
Required (positional) arguments:
- host: str
- port: int
Optional (keyword) arguments:
- debug: bool = False
- workers: int = 4"""
assert notes[0] == expected