From ca71d0c8a08cf52c80b9112ffef866841cb158d6 Mon Sep 17 00:00:00 2001 From: Kevin Turcios Date: Fri, 20 Feb 2026 06:50:11 -0500 Subject: [PATCH] refactor: remove constructor notes preprocessing from testgen pipeline Full class source is now included in the client-side testgen context, making the server-side constructor signature extraction redundant. --- .../preprocessing/class_constructor_notes.py | 165 ------ .../dataclass_constructor_notes.py | 176 ------ .../preprocessing/preprocess_pipeline.py | 4 +- .../testgen/test_class_constructor_notes.py | 418 -------------- .../test_dataclass_constructor_notes.py | 529 ------------------ 5 files changed, 1 insertion(+), 1291 deletions(-) delete mode 100644 django/aiservice/core/languages/python/testgen/preprocessing/class_constructor_notes.py delete mode 100644 django/aiservice/core/languages/python/testgen/preprocessing/dataclass_constructor_notes.py delete mode 100644 django/aiservice/tests/testgen/test_class_constructor_notes.py delete mode 100644 django/aiservice/tests/testgen/test_dataclass_constructor_notes.py diff --git a/django/aiservice/core/languages/python/testgen/preprocessing/class_constructor_notes.py b/django/aiservice/core/languages/python/testgen/preprocessing/class_constructor_notes.py deleted file mode 100644 index 5da191dc7..000000000 --- a/django/aiservice/core/languages/python/testgen/preprocessing/class_constructor_notes.py +++ /dev/null @@ -1,165 +0,0 @@ -"""Preprocessing step to extract constructor signatures from regular (non-dataclass) classes. - -When the test context contains class definitions with explicit __init__ methods, -this module extracts the constructor signatures to help the LLM understand how -to properly instantiate them. -""" - -from __future__ import annotations - -from typing import NamedTuple - -import libcst as cst -from libcst.metadata import MetadataWrapper - -from aiservice.common.markdown_utils import extract_all_code_from_markdown -from core.languages.python.cst_utils import get_node_source_text, has_decorator, parse_module_to_cst - - -class ParamInfo(NamedTuple): - name: str - annotation: str | None - has_default: bool - default_repr: str | None - - -def _find_init_method(class_node: cst.ClassDef) -> cst.FunctionDef | None: - for stmt in class_node.body.body: - if isinstance(stmt, cst.FunctionDef) and stmt.name.value == "__init__": - return stmt - return None - - -def extract_init_params( - init_node: cst.FunctionDef, source_lines: list[str], wrapper: MetadataWrapper -) -> list[ParamInfo]: - params: list[ParamInfo] = [] - parameters = init_node.params - - for param in parameters.params: - name = param.name.value - if name == "self": - continue - - annotation = None - if param.annotation is not None: - annotation = get_node_source_text(param.annotation, source_lines, wrapper) - - default_repr = None - if param.default is not None: - default_repr = get_node_source_text(param.default, source_lines, wrapper) - if len(default_repr) > 50: - default_repr = default_repr[:47] + "..." - - params.append(ParamInfo(name, annotation, has_default=param.default is not None, default_repr=default_repr)) - - if isinstance(parameters.star_arg, cst.Param): - name = f"*{parameters.star_arg.name.value}" - annotation = None - if parameters.star_arg.annotation is not None: - annotation = get_node_source_text(parameters.star_arg.annotation, source_lines, wrapper) - params.append(ParamInfo(name, annotation, has_default=False, default_repr=None)) - - for param in parameters.kwonly_params: - name = param.name.value - annotation = None - if param.annotation is not None: - annotation = get_node_source_text(param.annotation, source_lines, wrapper) - - default_repr = None - if param.default is not None: - default_repr = get_node_source_text(param.default, source_lines, wrapper) - if len(default_repr) > 50: - default_repr = default_repr[:47] + "..." - - params.append(ParamInfo(name, annotation, has_default=param.default is not None, default_repr=default_repr)) - - if parameters.star_kwarg is not None: - name = f"**{parameters.star_kwarg.name.value}" - annotation = None - if parameters.star_kwarg.annotation is not None: - annotation = get_node_source_text(parameters.star_kwarg.annotation, source_lines, wrapper) - params.append(ParamInfo(name, annotation, has_default=False, default_repr=None)) - - return params - - -class ClassInitCollector(cst.CSTVisitor): - def __init__(self) -> None: - self.classes: dict[str, cst.ClassDef] = {} - - def visit_ClassDef(self, node: cst.ClassDef) -> bool: - if not has_decorator(node, "dataclass") and _find_init_method(node) is not None: - self.classes[node.name.value] = node - return True - - -def _find_all_classes_with_init(source_code: str) -> tuple[dict[str, cst.ClassDef], list[str], MetadataWrapper | None]: - try: - tree = parse_module_to_cst(source_code) - except cst.ParserSyntaxError: - return {}, [], None - - wrapper = MetadataWrapper(tree) - collector = ClassInitCollector() - wrapper.visit(collector) - source_lines = source_code.split("\n") - return collector.classes, source_lines, wrapper - - -def format_init_signature(class_name: str, params: list[ParamInfo]) -> str: - if not params: - return f"{class_name}() - no parameters" - - required = [p for p in params if not p.has_default and not p.name.startswith("*")] - optional = [p for p in params if p.has_default] - variadic = [p for p in params if p.name.startswith("*")] - - lines = [f"Constructor signature for {class_name}:"] - - if required: - lines.append(" Required (positional) arguments:") - for p in required: - if p.annotation: - lines.append(f" - {p.name}: {p.annotation}") - else: - lines.append(f" - {p.name}") - - if optional: - lines.append(" Optional (keyword) arguments:") - for p in optional: - default_note = f" = {p.default_repr}" if p.default_repr else " = ..." - if p.annotation: - lines.append(f" - {p.name}: {p.annotation}{default_note}") - else: - lines.append(f" - {p.name}{default_note}") - - if variadic: - lines.append(" Variadic arguments:") - for p in variadic: - if p.annotation: - lines.append(f" - {p.name}: {p.annotation}") - else: - lines.append(f" - {p.name}") - - return "\n".join(lines) - - -def get_class_constructor_notes(test_context: str) -> list[str]: - code_to_analyze = extract_all_code_from_markdown(test_context) if "```python" in test_context else test_context - - classes, source_lines, wrapper = _find_all_classes_with_init(code_to_analyze) - - if not classes or wrapper is None: - return [] - - notes = [] - for class_name, class_node in classes.items(): - init_node = _find_init_method(class_node) - if init_node is None: - continue - params = extract_init_params(init_node, source_lines, wrapper) - signature = format_init_signature(class_name, params) - notes.append(signature) - - return notes diff --git a/django/aiservice/core/languages/python/testgen/preprocessing/dataclass_constructor_notes.py b/django/aiservice/core/languages/python/testgen/preprocessing/dataclass_constructor_notes.py deleted file mode 100644 index b1c38a54b..000000000 --- a/django/aiservice/core/languages/python/testgen/preprocessing/dataclass_constructor_notes.py +++ /dev/null @@ -1,176 +0,0 @@ -"""Preprocessing step to extract dataclass constructor signatures. - -When the test context contains dataclass definitions, this module extracts -the constructor signatures to help the LLM understand how to properly -instantiate them, including inherited fields from parent dataclasses. -""" - -from __future__ import annotations - -from typing import NamedTuple - -import libcst as cst -from libcst.metadata import MetadataWrapper - -from core.languages.python.cst_utils import ( - get_base_class_name, - get_node_source_text, - has_decorator, - parse_module_to_cst, -) -from aiservice.common.markdown_utils import extract_all_code_from_markdown - - -class FieldInfo(NamedTuple): - """Information about a dataclass field.""" - - name: str - annotation: str - has_default: bool - default_repr: str | None - source: str # "inherited" or "own" - - -def extract_dataclass_fields( - class_node: cst.ClassDef, source_lines: list[str], wrapper: MetadataWrapper -) -> list[FieldInfo]: - """Extract fields from a dataclass definition. - - Returns list of FieldInfo with (field_name, type_annotation, has_default, default_value). - """ - fields: list[FieldInfo] = [] - for stmt in class_node.body.body: - if not isinstance(stmt, cst.SimpleStatementLine): - continue - for node in stmt.body: - if not (isinstance(node, cst.AnnAssign) and isinstance(node.target, cst.Name)): - continue - - field_name = node.target.value - annotation = get_node_source_text(node.annotation, source_lines, wrapper) - - default_repr = None - if node.value is not None: - default_repr = get_node_source_text(node.value, source_lines, wrapper) - # Truncate long defaults - if len(default_repr) > 50: - default_repr = default_repr[:47] + "..." - - fields.append(FieldInfo(field_name, annotation, node.value is not None, default_repr, "own")) - return fields - - -class DataclassCollector(cst.CSTVisitor): - """Visitor to collect all dataclass definitions.""" - - def __init__(self) -> None: - self.dataclasses: dict[str, cst.ClassDef] = {} - - def visit_ClassDef(self, node: cst.ClassDef) -> bool: - if has_decorator(node, "dataclass"): - self.dataclasses[node.name.value] = node - return True # Continue visiting nested classes - - -def find_all_dataclasses(source_code: str) -> tuple[dict[str, cst.ClassDef], list[str], MetadataWrapper | None]: - """Find all dataclass definitions in source code. - - Returns tuple of (dict mapping class name to class_node, source_lines, metadata_wrapper). - """ - try: - tree = parse_module_to_cst(source_code) - except cst.ParserSyntaxError: - return {}, [], None - - wrapper = MetadataWrapper(tree) - collector = DataclassCollector() - wrapper.visit(collector) - source_lines = source_code.split("\n") - return collector.dataclasses, source_lines, wrapper - - -def get_all_fields_with_inheritance( - class_name: str, dataclasses: dict[str, cst.ClassDef], source_lines: list[str], wrapper: MetadataWrapper -) -> list[FieldInfo]: - """Get all fields for a dataclass, including inherited ones. - - Fields are returned in order: inherited first, then own fields. - """ - if class_name not in dataclasses: - return [] - - class_node = dataclasses[class_name] - - # First, collect inherited fields from parent dataclasses - inherited_fields: list[FieldInfo] = [] - if class_node.bases: - for base in class_node.bases: - base_name = get_base_class_name(base) - if base_name and base_name in dataclasses: - parent_fields = get_all_fields_with_inheritance(base_name, dataclasses, source_lines, wrapper) - # Mark these as inherited - inherited_fields.extend( - FieldInfo(field.name, field.annotation, field.has_default, field.default_repr, "inherited") - for field in parent_fields - ) - - # Then collect own fields - own_fields = extract_dataclass_fields(class_node, source_lines, wrapper) - - return inherited_fields + own_fields - - -def format_constructor_signature(class_name: str, fields: list[FieldInfo]) -> str: - """Format a constructor signature for documentation.""" - if not fields: - return f"{class_name}() - no fields" - - # Separate required and optional fields - required = [f for f in fields if not f.has_default] - optional = [f for f in fields if f.has_default] - - lines = [f"Constructor signature for {class_name}:"] - - if required: - lines.append(" Required (positional) arguments:") - for f in required: - source_note = " (from parent class)" if f.source == "inherited" else "" - lines.append(f" - {f.name}: {f.annotation}{source_note}") - - if optional: - lines.append(" Optional (keyword) arguments:") - for f in optional: - default_note = f" = {f.default_repr}" if f.default_repr else " = ..." - source_note = " (from parent class)" if f.source == "inherited" else "" - lines.append(f" - {f.name}: {f.annotation}{default_note}{source_note}") - - return "\n".join(lines) - - -def get_dataclass_constructor_notes(test_context: str) -> list[str]: - """Generate notes about dataclass constructor signatures. - - Analyzes the test context for dataclass definitions and generates - explicit notes about their constructor signatures, including inherited fields. - - Args: - test_context: The source code context to analyze for dataclass definitions. - - """ - # Extract actual code from markdown if present - code_to_analyze = extract_all_code_from_markdown(test_context) if "```python" in test_context else test_context - - # Find all dataclasses in the context - dataclasses, source_lines, wrapper = find_all_dataclasses(code_to_analyze) - - if not dataclasses or wrapper is None: - return [] - - notes = [] - for class_name in dataclasses: - fields = get_all_fields_with_inheritance(class_name, dataclasses, source_lines, wrapper) - if fields: - signature = format_constructor_signature(class_name, fields) - notes.append(signature) - - return notes diff --git a/django/aiservice/core/languages/python/testgen/preprocessing/preprocess_pipeline.py b/django/aiservice/core/languages/python/testgen/preprocessing/preprocess_pipeline.py index 1e31c8244..e830c5082 100644 --- a/django/aiservice/core/languages/python/testgen/preprocessing/preprocess_pipeline.py +++ b/django/aiservice/core/languages/python/testgen/preprocessing/preprocess_pipeline.py @@ -1,12 +1,10 @@ from itertools import chain -from core.languages.python.testgen.preprocessing.class_constructor_notes import get_class_constructor_notes -from core.languages.python.testgen.preprocessing.dataclass_constructor_notes import get_dataclass_constructor_notes from core.languages.python.testgen.preprocessing.torch_tensor_limit import get_tensor_size_note # Preprocessing functions that analyze code context and return notes # Each function takes test_context (str) and returns a list of notes (list[str]) -_PREPROCESSING_FUNCTIONS = [get_tensor_size_note, get_dataclass_constructor_notes, get_class_constructor_notes] +_PREPROCESSING_FUNCTIONS = [get_tensor_size_note] def preprocessing_testgen_pipeline(test_context: str) -> list[str]: diff --git a/django/aiservice/tests/testgen/test_class_constructor_notes.py b/django/aiservice/tests/testgen/test_class_constructor_notes.py deleted file mode 100644 index ba81cb545..000000000 --- a/django/aiservice/tests/testgen/test_class_constructor_notes.py +++ /dev/null @@ -1,418 +0,0 @@ -"""Tests for regular class constructor signature extraction.""" - -import libcst as cst - -from core.languages.python.testgen.preprocessing.class_constructor_notes import ( - ParamInfo, - _find_all_classes_with_init, - extract_init_params, - format_init_signature, - get_class_constructor_notes, -) - - -class TestClassInitCollector: - def test_collects_class_with_init(self) -> None: - code = """ -class Foo: - def __init__(self, x: int): - self.x = x -""" - classes, _, wrapper = _find_all_classes_with_init(code) - assert "Foo" in classes - assert wrapper is not None - - def test_skips_class_without_init(self) -> None: - code = """ -class Foo: - def bar(self): - pass -""" - classes, _, _ = _find_all_classes_with_init(code) - assert "Foo" not in classes - - def test_skips_dataclass(self) -> None: - code = """ -@dataclass -class Foo: - x: int - def __init__(self, x: int): - self.x = x -""" - classes, _, _ = _find_all_classes_with_init(code) - assert "Foo" not in classes - - def test_skips_dataclass_with_call(self) -> None: - code = """ -@dataclass(frozen=True) -class Foo: - x: int - def __init__(self, x: int): - self.x = x -""" - classes, _, _ = _find_all_classes_with_init(code) - assert "Foo" not in classes - - def test_collects_multiple_classes(self) -> None: - code = """ -class Foo: - def __init__(self, x: int): - self.x = x - -class Bar: - def __init__(self, y: str): - self.y = y -""" - classes, _, _ = _find_all_classes_with_init(code) - assert "Foo" in classes - assert "Bar" in classes - - def test_handles_syntax_error(self) -> None: - code = "this is not valid python code {{{{" - classes, _, wrapper = _find_all_classes_with_init(code) - assert classes == {} - assert wrapper is None - - -class TestExtractInitParams: - def test_typed_params(self) -> None: - code = """ -class Foo: - def __init__(self, x: int, y: str): - pass -""" - classes, source_lines, wrapper = _find_all_classes_with_init(code) - assert wrapper is not None - init_node = next(iter(classes["Foo"].body.body)) - assert isinstance(init_node, cst.FunctionDef) - params = extract_init_params(init_node, source_lines, wrapper) - - assert len(params) == 2 - assert params[0].name == "x" - assert params[0].annotation == "int" - assert not params[0].has_default - assert params[1].name == "y" - assert params[1].annotation == "str" - assert not params[1].has_default - - def test_params_with_defaults(self) -> None: - code = """ -class Foo: - def __init__(self, x: int = 0, y: str = "hello"): - pass -""" - classes, source_lines, wrapper = _find_all_classes_with_init(code) - assert wrapper is not None - init_node = next(iter(classes["Foo"].body.body)) - assert isinstance(init_node, cst.FunctionDef) - params = extract_init_params(init_node, source_lines, wrapper) - - assert len(params) == 2 - assert params[0].name == "x" - assert params[0].has_default - assert params[0].default_repr == "0" - assert params[1].name == "y" - assert params[1].has_default - assert params[1].default_repr == '"hello"' - - def test_args_and_kwargs(self) -> None: - code = """ -class Foo: - def __init__(self, x: int, *args, **kwargs): - pass -""" - classes, source_lines, wrapper = _find_all_classes_with_init(code) - assert wrapper is not None - init_node = next(iter(classes["Foo"].body.body)) - assert isinstance(init_node, cst.FunctionDef) - params = extract_init_params(init_node, source_lines, wrapper) - - assert len(params) == 3 - assert params[0].name == "x" - assert params[1].name == "*args" - assert params[2].name == "**kwargs" - - def test_skips_self(self) -> None: - code = """ -class Foo: - def __init__(self): - pass -""" - classes, source_lines, wrapper = _find_all_classes_with_init(code) - assert wrapper is not None - init_node = next(iter(classes["Foo"].body.body)) - assert isinstance(init_node, cst.FunctionDef) - params = extract_init_params(init_node, source_lines, wrapper) - - assert len(params) == 0 - - def test_untyped_params(self) -> None: - code = """ -class Foo: - def __init__(self, x, y=10): - pass -""" - classes, source_lines, wrapper = _find_all_classes_with_init(code) - assert wrapper is not None - init_node = next(iter(classes["Foo"].body.body)) - assert isinstance(init_node, cst.FunctionDef) - params = extract_init_params(init_node, source_lines, wrapper) - - assert len(params) == 2 - assert params[0].name == "x" - assert params[0].annotation is None - assert params[1].name == "y" - assert params[1].annotation is None - assert params[1].has_default - assert params[1].default_repr == "10" - - def test_kwonly_params(self) -> None: - code = """ -class Foo: - def __init__(self, x: int, *, key: str = "default"): - pass -""" - classes, source_lines, wrapper = _find_all_classes_with_init(code) - assert wrapper is not None - init_node = next(iter(classes["Foo"].body.body)) - assert isinstance(init_node, cst.FunctionDef) - params = extract_init_params(init_node, source_lines, wrapper) - - assert len(params) == 2 - assert params[0].name == "x" - assert params[1].name == "key" - assert params[1].annotation == "str" - assert params[1].has_default - assert params[1].default_repr == '"default"' - - def test_long_default_truncated(self) -> None: - code = """ -class Foo: - def __init__(self, x: list = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]): - pass -""" - classes, source_lines, wrapper = _find_all_classes_with_init(code) - assert wrapper is not None - init_node = next(iter(classes["Foo"].body.body)) - assert isinstance(init_node, cst.FunctionDef) - params = extract_init_params(init_node, source_lines, wrapper) - - assert len(params) == 1 - assert params[0].has_default - assert params[0].default_repr is not None - assert len(params[0].default_repr) <= 50 - assert params[0].default_repr.endswith("...") - - -class TestFormatInitSignature: - def test_required_only(self) -> None: - params = [ - ParamInfo(name="x", annotation="int", has_default=False, default_repr=None), - ParamInfo(name="y", annotation="str", has_default=False, default_repr=None), - ] - result = format_init_signature("Foo", params) - assert "Constructor signature for Foo:" in result - assert "Required (positional) arguments:" in result - assert "- x: int" in result - assert "- y: str" in result - assert "Optional" not in result - - def test_optional_only(self) -> None: - params = [ - ParamInfo(name="x", annotation="int", has_default=True, default_repr="0"), - ParamInfo(name="y", annotation="str", has_default=True, default_repr='"hi"'), - ] - result = format_init_signature("Foo", params) - assert "Constructor signature for Foo:" in result - assert "Optional (keyword) arguments:" in result - assert "- x: int = 0" in result - assert '- y: str = "hi"' in result - assert "Required" not in result - - def test_mixed_params(self) -> None: - params = [ - ParamInfo(name="x", annotation="int", has_default=False, default_repr=None), - ParamInfo(name="y", annotation="str", has_default=True, default_repr='"default"'), - ] - result = format_init_signature("Foo", params) - assert "Required (positional) arguments:" in result - assert "Optional (keyword) arguments:" in result - - def test_no_params(self) -> None: - result = format_init_signature("Empty", []) - assert "Empty() - no parameters" in result - - def test_variadic_params(self) -> None: - params = [ - ParamInfo(name="x", annotation="int", has_default=False, default_repr=None), - ParamInfo(name="*args", annotation=None, has_default=False, default_repr=None), - ParamInfo(name="**kwargs", annotation=None, has_default=False, default_repr=None), - ] - result = format_init_signature("Foo", params) - assert "Required (positional) arguments:" in result - assert "Variadic arguments:" in result - assert "- *args" in result - assert "- **kwargs" in result - - def test_untyped_params(self) -> None: - params = [ - ParamInfo(name="x", annotation=None, has_default=False, default_repr=None), - ParamInfo(name="y", annotation=None, has_default=True, default_repr="10"), - ] - result = format_init_signature("Foo", params) - assert " - x\n" in result - assert " - y = 10" in result - - -class TestGetClassConstructorNotes: - def test_class_with_typed_init(self) -> None: - context = """ -class LayoutElements: - def __init__(self, width: int, height: int, elements: list[str]): - self.width = width - self.height = height - self.elements = elements -""" - notes = get_class_constructor_notes(context) - assert len(notes) == 1 - assert "Constructor signature for LayoutElements:" in notes[0] - assert "- width: int" in notes[0] - assert "- height: int" in notes[0] - assert "- elements: list[str]" in notes[0] - - def test_class_with_defaults(self) -> None: - context = """ -class TextRegions: - def __init__(self, text: str = "", max_len: int = 100): - self.text = text - self.max_len = max_len -""" - notes = get_class_constructor_notes(context) - assert len(notes) == 1 - assert "Optional (keyword) arguments:" in notes[0] - assert '- text: str = ""' in notes[0] - assert "- max_len: int = 100" in notes[0] - - def test_class_with_args_kwargs(self) -> None: - context = """ -class Flexible: - def __init__(self, name: str, *args, **kwargs): - self.name = name -""" - notes = get_class_constructor_notes(context) - assert len(notes) == 1 - assert "- name: str" in notes[0] - assert "- *args" in notes[0] - assert "- **kwargs" in notes[0] - - def test_class_without_init_skipped(self) -> None: - context = """ -class NoInit: - x = 10 - def method(self): - pass -""" - notes = get_class_constructor_notes(context) - assert notes == [] - - def test_dataclass_skipped(self) -> None: - context = """ -@dataclass -class Config: - name: str - value: int -""" - notes = get_class_constructor_notes(context) - assert notes == [] - - def test_multiple_classes(self) -> None: - context = """ -class Foo: - def __init__(self, x: int): - self.x = x - -class Bar: - def __init__(self, y: str): - self.y = y -""" - notes = get_class_constructor_notes(context) - assert len(notes) == 2 - class_names = " ".join(notes) - assert "Foo" in class_names - assert "Bar" in class_names - - def test_markdown_wrapped_code(self) -> None: - context = """ -Some description text. -```python:models.py -class LayoutElements: - def __init__(self, width: int, height: int): - self.width = width - self.height = height -``` -""" - notes = get_class_constructor_notes(context) - assert len(notes) == 1 - assert "LayoutElements" in notes[0] - assert "- width: int" in notes[0] - - def test_syntax_error_returns_empty(self) -> None: - context = "this is not valid python {{{{" - notes = get_class_constructor_notes(context) - assert notes == [] - - def test_init_with_only_self(self) -> None: - context = """ -class Empty: - def __init__(self): - pass -""" - notes = get_class_constructor_notes(context) - assert len(notes) == 1 - assert "Empty() - no parameters" in notes[0] - - def test_mixed_dataclass_and_regular(self) -> None: - context = """ -@dataclass -class Config: - name: str - -class Service: - def __init__(self, config: Config, debug: bool = False): - self.config = config - self.debug = debug -""" - notes = get_class_constructor_notes(context) - assert len(notes) == 1 - assert "Service" in notes[0] - assert "Config" not in notes[0].split("\n")[0] - - def test_plain_code_without_markdown(self) -> None: - context = """ -class MyClass: - def __init__(self, value: int): - self.value = value -""" - notes = get_class_constructor_notes(context) - assert len(notes) == 1 - assert "MyClass" in notes[0] - - def test_full_output_format(self) -> None: - context = """ -class Server: - def __init__(self, host: str, port: int, debug: bool = False, workers: int = 4): - self.host = host - self.port = port - self.debug = debug - self.workers = workers -""" - notes = get_class_constructor_notes(context) - assert len(notes) == 1 - expected = """Constructor signature for Server: - Required (positional) arguments: - - host: str - - port: int - Optional (keyword) arguments: - - debug: bool = False - - workers: int = 4""" - assert notes[0] == expected diff --git a/django/aiservice/tests/testgen/test_dataclass_constructor_notes.py b/django/aiservice/tests/testgen/test_dataclass_constructor_notes.py deleted file mode 100644 index d7ebfd788..000000000 --- a/django/aiservice/tests/testgen/test_dataclass_constructor_notes.py +++ /dev/null @@ -1,529 +0,0 @@ -"""Tests for dataclass constructor signature extraction.""" - -import libcst as cst - -from core.languages.python.cst_utils import get_base_class_name, has_decorator -from aiservice.common.markdown_utils import extract_all_code_from_markdown -from core.languages.python.testgen.preprocessing.dataclass_constructor_notes import ( - FieldInfo, - extract_dataclass_fields, - find_all_dataclasses, - format_constructor_signature, - get_all_fields_with_inheritance, - get_dataclass_constructor_notes, -) - - -class TestHasDecorator: - def test_simple_dataclass_decorator(self) -> None: - code = """ -@dataclass -class Foo: - x: int -""" - tree = cst.parse_module(code) - class_node = tree.body[0] - assert isinstance(class_node, cst.ClassDef) - assert has_decorator(class_node, "dataclass") - - def test_dataclass_with_call(self) -> None: - code = """ -@dataclass(frozen=True) -class Foo: - x: int -""" - tree = cst.parse_module(code) - class_node = tree.body[0] - assert isinstance(class_node, cst.ClassDef) - assert has_decorator(class_node, "dataclass") - - def test_not_a_dataclass(self) -> None: - code = """ -class Foo: - x: int -""" - tree = cst.parse_module(code) - class_node = tree.body[0] - assert isinstance(class_node, cst.ClassDef) - assert not has_decorator(class_node, "dataclass") - - def test_other_decorator(self) -> None: - code = """ -@other_decorator -class Foo: - x: int -""" - tree = cst.parse_module(code) - class_node = tree.body[0] - assert isinstance(class_node, cst.ClassDef) - assert not has_decorator(class_node, "dataclass") - - -class TestGetBaseClassName: - def test_simple_name(self) -> None: - code = "class Foo(Bar): pass" - tree = cst.parse_module(code) - class_node = tree.body[0] - assert isinstance(class_node, cst.ClassDef) - assert get_base_class_name(class_node.bases[0]) == "Bar" - - def test_attribute(self) -> None: - code = "class Foo(module.Bar): pass" - tree = cst.parse_module(code) - class_node = tree.body[0] - assert isinstance(class_node, cst.ClassDef) - assert get_base_class_name(class_node.bases[0]) == "Bar" - - -class TestExtractDataclassFields: - def test_simple_fields(self) -> None: - code = """ -@dataclass -class Foo: - x: int - y: str -""" - dataclasses, source_lines, wrapper = find_all_dataclasses(code) - assert wrapper is not None - class_node = dataclasses["Foo"] - fields = extract_dataclass_fields(class_node, source_lines, wrapper) - - assert len(fields) == 2 - assert fields[0].name == "x" - assert fields[0].annotation == "int" - assert not fields[0].has_default - assert fields[1].name == "y" - assert fields[1].annotation == "str" - assert not fields[1].has_default - - def test_fields_with_defaults(self) -> None: - code = """ -@dataclass -class Foo: - x: int - y: str = "default" -""" - dataclasses, source_lines, wrapper = find_all_dataclasses(code) - assert wrapper is not None - class_node = dataclasses["Foo"] - fields = extract_dataclass_fields(class_node, source_lines, wrapper) - - assert len(fields) == 2 - assert fields[0].name == "x" - assert not fields[0].has_default - assert fields[1].name == "y" - assert fields[1].has_default - assert fields[1].default_repr == '"default"' - - def test_complex_annotation(self) -> None: - code = """ -@dataclass -class Foo: - items: list[str] - mapping: dict[str, int] | None = None -""" - dataclasses, source_lines, wrapper = find_all_dataclasses(code) - assert wrapper is not None - class_node = dataclasses["Foo"] - fields = extract_dataclass_fields(class_node, source_lines, wrapper) - - assert len(fields) == 2 - assert fields[0].name == "items" - assert fields[0].annotation == "list[str]" - assert fields[1].name == "mapping" - assert "dict[str, int]" in fields[1].annotation - - -class TestFindAllDataclasses: - def test_finds_single_dataclass(self) -> None: - code = """ -from dataclasses import dataclass - -@dataclass -class Foo: - x: int -""" - result, source_lines, wrapper = find_all_dataclasses(code) - assert "Foo" in result - assert wrapper is not None - assert len(source_lines) > 0 - - def test_finds_multiple_dataclasses(self) -> None: - code = """ -from dataclasses import dataclass - -@dataclass -class Foo: - x: int - -@dataclass -class Bar: - y: str -""" - result, _source_lines, wrapper = find_all_dataclasses(code) - assert "Foo" in result - assert "Bar" in result - assert wrapper is not None - - def test_ignores_non_dataclasses(self) -> None: - code = """ -class NotADataclass: - x: int - -@dataclass -class IsADataclass: - y: str -""" - result, _source_lines, wrapper = find_all_dataclasses(code) - assert "NotADataclass" not in result - assert "IsADataclass" in result - assert wrapper is not None - - def test_handles_syntax_error(self) -> None: - code = "this is not valid python code {{{{" - result, _source_lines, wrapper = find_all_dataclasses(code) - assert result == {} - assert wrapper is None - - -class TestGetAllFieldsWithInheritance: - def test_no_inheritance(self) -> None: - code = """ -@dataclass -class Foo: - x: int - y: str -""" - dataclasses, source_lines, wrapper = find_all_dataclasses(code) - assert wrapper is not None - fields = get_all_fields_with_inheritance("Foo", dataclasses, source_lines, wrapper) - - assert len(fields) == 2 - assert all(f.source == "own" for f in fields) - - def test_simple_inheritance(self) -> None: - code = """ -@dataclass -class Base: - a: int - b: str - -@dataclass -class Child(Base): - c: float -""" - dataclasses, source_lines, wrapper = find_all_dataclasses(code) - assert wrapper is not None - fields = get_all_fields_with_inheritance("Child", dataclasses, source_lines, wrapper) - - assert len(fields) == 3 - assert fields[0].name == "a" - assert fields[0].source == "inherited" - assert fields[1].name == "b" - assert fields[1].source == "inherited" - assert fields[2].name == "c" - assert fields[2].source == "own" - - def test_multi_level_inheritance(self) -> None: - code = """ -@dataclass -class GrandParent: - a: int - -@dataclass -class Parent(GrandParent): - b: str - -@dataclass -class Child(Parent): - c: float -""" - dataclasses, source_lines, wrapper = find_all_dataclasses(code) - assert wrapper is not None - fields = get_all_fields_with_inheritance("Child", dataclasses, source_lines, wrapper) - - assert len(fields) == 3 - field_names = [f.name for f in fields] - assert "a" in field_names - assert "b" in field_names - assert "c" in field_names - - def test_unknown_class(self) -> None: - code = """ -@dataclass -class Foo: - x: int -""" - dataclasses, source_lines, wrapper = find_all_dataclasses(code) - assert wrapper is not None - fields = get_all_fields_with_inheritance("Unknown", dataclasses, source_lines, wrapper) - assert fields == [] - - -class TestFormatConstructorSignature: - def test_required_only(self) -> None: - fields = [ - FieldInfo(name="x", annotation="int", has_default=False, default_repr=None, source="own"), - FieldInfo(name="y", annotation="str", has_default=False, default_repr=None, source="own"), - ] - result = format_constructor_signature("Foo", fields) - assert "Constructor signature for Foo:" in result - assert "Required (positional) arguments:" in result - assert "- x: int" in result - assert "- y: str" in result - assert "Optional" not in result - - def test_optional_only(self) -> None: - fields = [ - FieldInfo(name="x", annotation="int", has_default=True, default_repr="0", source="own"), - FieldInfo(name="y", annotation="str", has_default=True, default_repr='"default"', source="own"), - ] - result = format_constructor_signature("Foo", fields) - assert "Constructor signature for Foo:" in result - assert "Optional (keyword) arguments:" in result - assert "- x: int = 0" in result - assert '- y: str = "default"' in result - assert "Required" not in result - - def test_mixed_fields(self) -> None: - fields = [ - FieldInfo(name="a", annotation="int", has_default=False, default_repr=None, source="inherited"), - FieldInfo(name="b", annotation="str", has_default=False, default_repr=None, source="own"), - FieldInfo(name="c", annotation="float", has_default=True, default_repr="0.0", source="own"), - ] - result = format_constructor_signature("Foo", fields) - assert "Required (positional) arguments:" in result - assert "Optional (keyword) arguments:" in result - assert "(from parent class)" in result - - def test_no_fields(self) -> None: - result = format_constructor_signature("Empty", []) - assert "no fields" in result - - -class TestExtractCodeFromMarkdown: - def test_simple_code_block(self) -> None: - markdown = """ -Some text -```python -def foo(): - pass -``` -More text -""" - result = extract_all_code_from_markdown(markdown) - assert "def foo():" in result - assert "pass" in result - assert "Some text" not in result - - def test_code_block_with_filepath(self) -> None: - markdown = """ -```python:path/to/file.py -class Foo: - x: int -``` -""" - result = extract_all_code_from_markdown(markdown) - assert "class Foo:" in result - assert "x: int" in result - - def test_multiple_code_blocks(self) -> None: - markdown = """ -```python:file1.py -@dataclass -class A: - x: int -``` - -```python:file2.py -@dataclass -class B: - y: str -``` -""" - result = extract_all_code_from_markdown(markdown) - assert "class A:" in result - assert "class B:" in result - - -class TestGetDataclassConstructorNotes: - def test_simple_dataclass(self) -> None: - context = """ -```python:models.py -from dataclasses import dataclass - -@dataclass -class Config: - name: str - value: int -``` -""" - notes = get_dataclass_constructor_notes(context) - assert len(notes) == 1 - assert "Constructor signature for Config:" in notes[0] - assert "- name: str" in notes[0] - assert "- value: int" in notes[0] - - def test_dataclass_with_inheritance(self) -> None: - context = """ -```python:models.py -from dataclasses import dataclass - -@dataclass -class BaseConfig: - model_name: str - required_env_vars: list[str] - -@dataclass -class ExtendedConfig(BaseConfig): - extra_param: int = 0 -``` -""" - notes = get_dataclass_constructor_notes(context) - assert len(notes) == 2 - - # Find the note for ExtendedConfig - extended_note = next(n for n in notes if "ExtendedConfig" in n) - assert "- model_name: str (from parent class)" in extended_note - assert "- required_env_vars: list[str] (from parent class)" in extended_note - assert "- extra_param: int = 0" in extended_note - - def test_no_dataclasses(self) -> None: - context = """ -def regular_function(): - pass -""" - notes = get_dataclass_constructor_notes(context) - assert notes == [] - - def test_plain_code_without_markdown(self) -> None: - context = """ -from dataclasses import dataclass - -@dataclass -class SimpleConfig: - name: str -""" - notes = get_dataclass_constructor_notes(context) - assert len(notes) == 1 - assert "SimpleConfig" in notes[0] - - def test_llm_config_like_structure(self) -> None: - """Test with a structure similar to the skyvern LLMConfig.""" - context = """ -```python:skyvern/forge/sdk/api/llm/models.py -from dataclasses import dataclass, field -from typing import Optional - -@dataclass(frozen=True) -class LLMConfigBase: - model_name: str - required_env_vars: list[str] - supports_vision: bool - add_assistant_prefix: bool - -@dataclass(frozen=True) -class LLMConfig(LLMConfigBase): - litellm_params: Optional[dict] = field(default=None) - max_tokens: int | None = 4096 -``` -""" - notes = get_dataclass_constructor_notes(context) - - # Should have notes for both classes - assert len(notes) == 2 - - # Find the LLMConfig note - llm_config_note = next(n for n in notes if "LLMConfig:" in n) - - # Should show inherited fields from LLMConfigBase - assert "- model_name: str (from parent class)" in llm_config_note - assert "- required_env_vars: list[str] (from parent class)" in llm_config_note - assert "- supports_vision: bool (from parent class)" in llm_config_note - assert "- add_assistant_prefix: bool (from parent class)" in llm_config_note - - # Should also show own fields - assert "- litellm_params:" in llm_config_note - assert "- max_tokens:" in llm_config_note - - def test_full_output_for_complex_inheritance(self) -> None: - """Test full generated notes output for complex dataclass inheritance.""" - context = """ -```python:skyvern/forge/sdk/api/llm/models.py -from dataclasses import dataclass, field -from typing import Optional - -@dataclass(frozen=True) -class LLMConfigBase: - model_name: str - required_env_vars: list[str] - supports_vision: bool - add_assistant_prefix: bool - -@dataclass(frozen=True) -class LLMConfig(LLMConfigBase): - litellm_params: Optional[dict] = field(default=None) - max_tokens: int | None = 4096 - max_completion_tokens: int | None = None - temperature: float | None = 0.7 - reasoning_effort: str | None = None - -@dataclass(frozen=True) -class LLMRouterConfig(LLMConfigBase): - model_list: list - main_model_group: str - redis_host: str | None = None - redis_port: int | None = None - fallback_model_group: str | None = None - routing_strategy: str = "usage-based-routing" - num_retries: int = 1 -``` -""" - notes = get_dataclass_constructor_notes(context) - - # Should have notes for all 3 classes - assert len(notes) == 3 - - # Verify LLMConfigBase note (no inheritance) - base_note = next(n for n in notes if "LLMConfigBase:" in n) - expected_base = """Constructor signature for LLMConfigBase: - Required (positional) arguments: - - model_name: str - - required_env_vars: list[str] - - supports_vision: bool - - add_assistant_prefix: bool""" - assert base_note == expected_base - - # Verify LLMConfig note (with inheritance) - config_note = next(n for n in notes if "LLMConfig:" in n) - expected_config = """Constructor signature for LLMConfig: - Required (positional) arguments: - - model_name: str (from parent class) - - required_env_vars: list[str] (from parent class) - - supports_vision: bool (from parent class) - - add_assistant_prefix: bool (from parent class) - Optional (keyword) arguments: - - litellm_params: Optional[dict] = field(default=None) - - max_tokens: int | None = 4096 - - max_completion_tokens: int | None = None - - temperature: float | None = 0.7 - - reasoning_effort: str | None = None""" - assert config_note == expected_config - - # Verify LLMRouterConfig note (with inheritance + own required fields) - router_note = next(n for n in notes if "LLMRouterConfig:" in n) - expected_router = """Constructor signature for LLMRouterConfig: - Required (positional) arguments: - - model_name: str (from parent class) - - required_env_vars: list[str] (from parent class) - - supports_vision: bool (from parent class) - - add_assistant_prefix: bool (from parent class) - - model_list: list - - main_model_group: str - Optional (keyword) arguments: - - redis_host: str | None = None - - redis_port: int | None = None - - fallback_model_group: str | None = None - - routing_strategy: str = "usage-based-routing" - - num_retries: int = 1""" - assert router_note == expected_router