fix: track class __init__ as helper when class is instantiated

Ensures LLM sees constructor signatures for proper test generation.
This commit is contained in:
Kevin Turcios 2026-01-01 02:25:10 -05:00
parent 0d84ab62dd
commit 3048ece4da
2 changed files with 206 additions and 21 deletions

View file

@ -446,31 +446,45 @@ def get_function_sources_from_jedi(
definition_path = definition.module_path definition_path = definition.module_path
# The definition is part of this project and not defined within the original function # The definition is part of this project and not defined within the original function
if ( is_valid_definition = (
str(definition_path).startswith(str(project_root_path) + os.sep) str(definition_path).startswith(str(project_root_path) + os.sep)
and not path_belongs_to_site_packages(definition_path) and not path_belongs_to_site_packages(definition_path)
and definition.full_name and definition.full_name
and definition.type == "function"
and not belongs_to_function_qualified(definition, qualified_function_name) and not belongs_to_function_qualified(definition, qualified_function_name)
and definition.full_name.startswith(definition.module_name) and definition.full_name.startswith(definition.module_name)
)
if is_valid_definition and definition.type == "function":
qualified_name = get_qualified_name(definition.module_name, definition.full_name)
# Avoid nested functions or classes. Only class.function is allowed # Avoid nested functions or classes. Only class.function is allowed
and len( if len(qualified_name.split(".")) <= 2:
(qualified_name := get_qualified_name(definition.module_name, definition.full_name)).split( function_source = FunctionSource(
"." file_path=definition_path,
qualified_name=qualified_name,
fully_qualified_name=definition.full_name,
only_function_name=definition.name,
source_code=definition.get_line_code(),
jedi_definition=definition,
) )
file_path_to_function_source[definition_path].add(function_source)
function_source_list.append(function_source)
# When a class is instantiated (e.g., MyClass()), track its __init__ as a helper
# This ensures the class definition with constructor is included in testgen context
elif is_valid_definition and definition.type == "class":
init_qualified_name = get_qualified_name(
definition.module_name, f"{definition.full_name}.__init__"
) )
<= 2 # Only include if it's a top-level class (not nested)
): if len(init_qualified_name.split(".")) <= 2:
function_source = FunctionSource( function_source = FunctionSource(
file_path=definition_path, file_path=definition_path,
qualified_name=qualified_name, qualified_name=init_qualified_name,
fully_qualified_name=definition.full_name, fully_qualified_name=f"{definition.full_name}.__init__",
only_function_name=definition.name, only_function_name="__init__",
source_code=definition.get_line_code(), source_code=definition.get_line_code(),
jedi_definition=definition, jedi_definition=definition,
) )
file_path_to_function_source[definition_path].add(function_source) file_path_to_function_source[definition_path].add(function_source)
function_source_list.append(function_source) function_source_list.append(function_source)
return file_path_to_function_source, function_source_list return file_path_to_function_source, function_source_list
@ -647,7 +661,10 @@ def prune_cst_for_code_hashing( # noqa: PLR0911
if isinstance(node, cst.FunctionDef): if isinstance(node, cst.FunctionDef):
qualified_name = f"{prefix}.{node.name.value}" if prefix else node.name.value qualified_name = f"{prefix}.{node.name.value}" if prefix else node.name.value
if qualified_name in target_functions: # For hashing, exclude __init__ methods even if in target_functions
# because they don't affect the semantic behavior being hashed
# But include other dunder methods like __call__ which do affect behavior
if qualified_name in target_functions and node.name.value != "__init__":
new_body = remove_docstring_from_body(node.body) if isinstance(node.body, cst.IndentedBlock) else node.body new_body = remove_docstring_from_body(node.body) if isinstance(node.body, cst.IndentedBlock) else node.body
return node.with_changes(body=new_body), True return node.with_changes(body=new_body), True
return None, False return None, False
@ -666,7 +683,9 @@ def prune_cst_for_code_hashing( # noqa: PLR0911
for stmt in node.body.body: for stmt in node.body.body:
if isinstance(stmt, cst.FunctionDef): if isinstance(stmt, cst.FunctionDef):
qualified_name = f"{class_prefix}.{stmt.name.value}" qualified_name = f"{class_prefix}.{stmt.name.value}"
if qualified_name in target_functions: # For hashing, exclude __init__ methods even if in target_functions
# but include other methods like __call__ which affect behavior
if qualified_name in target_functions and stmt.name.value != "__init__":
stmt_with_changes = stmt.with_changes( stmt_with_changes = stmt.with_changes(
body=remove_docstring_from_body(cast("cst.IndentedBlock", stmt.body)) body=remove_docstring_from_body(cast("cst.IndentedBlock", stmt.body))
) )

View file

@ -84,7 +84,8 @@ def test_code_replacement10() -> None:
code_ctx = get_code_optimization_context(function_to_optimize=func_top_optimize, project_root_path=file_path.parent) code_ctx = get_code_optimization_context(function_to_optimize=func_top_optimize, project_root_path=file_path.parent)
qualified_names = {func.qualified_name for func in code_ctx.helper_functions} qualified_names = {func.qualified_name for func in code_ctx.helper_functions}
assert qualified_names == {"HelperClass.helper_method"} # Nested method should not be in here # HelperClass.__init__ is now tracked because HelperClass(self.name) instantiates the class
assert qualified_names == {"HelperClass.helper_method", "HelperClass.__init__"} # Nested method should not be in here
read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code
hashing_context = code_ctx.hashing_code_context hashing_context = code_ctx.hashing_code_context
@ -570,6 +571,8 @@ _STORE_T = TypeVar("_STORE_T")
class AbstractCacheBackend(CacheBackend, Protocol[_KEY_T, _STORE_T]): class AbstractCacheBackend(CacheBackend, Protocol[_KEY_T, _STORE_T]):
"""Interface for cache backends used by the persistent cache decorator.""" """Interface for cache backends used by the persistent cache decorator."""
def __init__(self) -> None: ...
def hash_key( def hash_key(
self, self,
*, *,
@ -1296,6 +1299,8 @@ class DataProcessor:
``` ```
```python:{path_to_transform_utils.relative_to(project_root)} ```python:{path_to_transform_utils.relative_to(project_root)}
class DataTransformer: class DataTransformer:
def __init__(self):
self.data = None
def transform(self, data): def transform(self, data):
self.data = data self.data = data
@ -1599,7 +1604,11 @@ class DataProcessor:
\"\"\"Return a string representation of the DataProcessor.\"\"\" \"\"\"Return a string representation of the DataProcessor.\"\"\"
return f"DataProcessor(default_prefix={{self.default_prefix!r}})" return f"DataProcessor(default_prefix={{self.default_prefix!r}})"
``` ```
```python:{path_to_transform_utils.relative_to(project_root)}
class DataTransformer:
def __init__(self):
self.data = None
```
""" """
expected_hashing_context = f""" expected_hashing_context = f"""
```python:utils.py ```python:utils.py
@ -1705,6 +1714,7 @@ def test_direct_module_import() -> None:
expected_read_only_context = """ expected_read_only_context = """
```python:utils.py ```python:utils.py
import math
from transform_utils import DataTransformer from transform_utils import DataTransformer
class DataProcessor: class DataProcessor:
@ -1712,6 +1722,11 @@ class DataProcessor:
number = 1 number = 1
def __init__(self, default_prefix: str = "PREFIX_"):
\"\"\"Initialize the DataProcessor with a default prefix.\"\"\"
self.default_prefix = default_prefix
self.number += math.log(self.number)
def __repr__(self) -> str: def __repr__(self) -> str:
\"\"\"Return a string representation of the DataProcessor.\"\"\" \"\"\"Return a string representation of the DataProcessor.\"\"\"
return f"DataProcessor(default_prefix={self.default_prefix!r})" return f"DataProcessor(default_prefix={self.default_prefix!r})"
@ -2727,3 +2742,154 @@ FINAL_ASSIGNMENT = {"data": "value"}
# Verify correct order # Verify correct order
expected_order = ["GLOBAL_CONSTANT", "ANOTHER_CONSTANT", "FINAL_ASSIGNMENT"] expected_order = ["GLOBAL_CONSTANT", "ANOTHER_CONSTANT", "FINAL_ASSIGNMENT"]
assert collector.assignment_order == expected_order assert collector.assignment_order == expected_order
def test_class_instantiation_includes_init_as_helper(tmp_path: Path) -> None:
"""Test that when a class is instantiated, its __init__ method is tracked as a helper.
This test verifies the fix for the bug where class constructors were not
included in the context when only the class instantiation was called
(not any other methods). This caused LLMs to not know the constructor
signatures when generating tests.
"""
code = '''
class DataDumper:
"""A class that dumps data."""
def __init__(self, data):
"""Initialize with data."""
self.data = data
def dump(self):
"""Dump the data."""
return self.data
def target_function():
# Only instantiates DataDumper, doesn't call any other methods
dumper = DataDumper({"key": "value"})
return dumper
'''
file_path = tmp_path / "test_code.py"
file_path.write_text(code, encoding="utf-8")
opt = Optimizer(
Namespace(
project_root=file_path.parent.resolve(),
disable_telemetry=True,
tests_root="tests",
test_framework="pytest",
pytest_cmd="pytest",
experiment_id=None,
test_project_root=Path().resolve(),
)
)
function_to_optimize = FunctionToOptimize(
function_name="target_function",
file_path=file_path,
parents=[],
starting_line=None,
ending_line=None,
)
code_ctx = get_code_optimization_context(function_to_optimize, opt.args.project_root)
# The __init__ method should be tracked as a helper since DataDumper() instantiates the class
qualified_names = {func.qualified_name for func in code_ctx.helper_functions}
assert "DataDumper.__init__" in qualified_names, (
"DataDumper.__init__ should be tracked as a helper when the class is instantiated"
)
# The testgen context should contain the class with __init__ (critical for LLM to know constructor)
testgen_context = code_ctx.testgen_context.markdown
assert "class DataDumper:" in testgen_context, "DataDumper class should be in testgen context"
assert "def __init__(self, data):" in testgen_context, (
"__init__ method should be included in testgen context"
)
# The hashing context should NOT contain __init__ (excluded for stability)
hashing_context = code_ctx.hashing_code_context
assert "__init__" not in hashing_context, (
"__init__ should NOT be in hashing context (excluded for hash stability)"
)
def test_class_instantiation_preserves_full_class_in_testgen(tmp_path: Path) -> None:
"""Test that instantiated classes are fully preserved in testgen context.
This is specifically for the unstructured LayoutDumper bug where helper classes
that were instantiated but had no other methods called were being excluded
from the testgen context.
"""
code = '''
class LayoutDumper:
"""Base class for layout dumpers."""
layout_source: str = "unknown"
def __init__(self, layout):
self._layout = layout
def dump(self) -> dict:
raise NotImplementedError()
class ObjectDetectionLayoutDumper(LayoutDumper):
"""Specific dumper for object detection layouts."""
def __init__(self, layout):
super().__init__(layout)
def dump(self) -> dict:
return {"type": "object_detection", "layout": self._layout}
def dump_layout(layout_type, layout):
"""Dump a layout based on its type."""
if layout_type == "object_detection":
dumper = ObjectDetectionLayoutDumper(layout)
else:
dumper = LayoutDumper(layout)
return dumper.dump()
'''
file_path = tmp_path / "test_code.py"
file_path.write_text(code, encoding="utf-8")
opt = Optimizer(
Namespace(
project_root=file_path.parent.resolve(),
disable_telemetry=True,
tests_root="tests",
test_framework="pytest",
pytest_cmd="pytest",
experiment_id=None,
test_project_root=Path().resolve(),
)
)
function_to_optimize = FunctionToOptimize(
function_name="dump_layout",
file_path=file_path,
parents=[],
starting_line=None,
ending_line=None,
)
code_ctx = get_code_optimization_context(function_to_optimize, opt.args.project_root)
qualified_names = {func.qualified_name for func in code_ctx.helper_functions}
# Both class __init__ methods should be tracked as helpers
assert "ObjectDetectionLayoutDumper.__init__" in qualified_names, (
"ObjectDetectionLayoutDumper.__init__ should be tracked"
)
assert "LayoutDumper.__init__" in qualified_names, (
"LayoutDumper.__init__ should be tracked"
)
# The testgen context should include both classes with their __init__ methods
testgen_context = code_ctx.testgen_context.markdown
assert "class LayoutDumper:" in testgen_context, "LayoutDumper should be in testgen context"
assert "class ObjectDetectionLayoutDumper" in testgen_context, (
"ObjectDetectionLayoutDumper should be in testgen context"
)
# Both __init__ methods should be in the testgen context (so LLM knows constructor signatures)
assert testgen_context.count("def __init__") >= 2, (
"Both __init__ methods should be in testgen context"
)