fix: track class __init__ as helper when class is instantiated
Ensures LLM sees constructor signatures for proper test generation.
This commit is contained in:
parent
0d84ab62dd
commit
3048ece4da
2 changed files with 206 additions and 21 deletions
|
|
@ -446,31 +446,45 @@ def get_function_sources_from_jedi(
|
||||||
definition_path = definition.module_path
|
definition_path = definition.module_path
|
||||||
|
|
||||||
# The definition is part of this project and not defined within the original function
|
# The definition is part of this project and not defined within the original function
|
||||||
if (
|
is_valid_definition = (
|
||||||
str(definition_path).startswith(str(project_root_path) + os.sep)
|
str(definition_path).startswith(str(project_root_path) + os.sep)
|
||||||
and not path_belongs_to_site_packages(definition_path)
|
and not path_belongs_to_site_packages(definition_path)
|
||||||
and definition.full_name
|
and definition.full_name
|
||||||
and definition.type == "function"
|
|
||||||
and not belongs_to_function_qualified(definition, qualified_function_name)
|
and not belongs_to_function_qualified(definition, qualified_function_name)
|
||||||
and definition.full_name.startswith(definition.module_name)
|
and definition.full_name.startswith(definition.module_name)
|
||||||
|
)
|
||||||
|
if is_valid_definition and definition.type == "function":
|
||||||
|
qualified_name = get_qualified_name(definition.module_name, definition.full_name)
|
||||||
# Avoid nested functions or classes. Only class.function is allowed
|
# Avoid nested functions or classes. Only class.function is allowed
|
||||||
and len(
|
if len(qualified_name.split(".")) <= 2:
|
||||||
(qualified_name := get_qualified_name(definition.module_name, definition.full_name)).split(
|
function_source = FunctionSource(
|
||||||
"."
|
file_path=definition_path,
|
||||||
|
qualified_name=qualified_name,
|
||||||
|
fully_qualified_name=definition.full_name,
|
||||||
|
only_function_name=definition.name,
|
||||||
|
source_code=definition.get_line_code(),
|
||||||
|
jedi_definition=definition,
|
||||||
)
|
)
|
||||||
|
file_path_to_function_source[definition_path].add(function_source)
|
||||||
|
function_source_list.append(function_source)
|
||||||
|
# When a class is instantiated (e.g., MyClass()), track its __init__ as a helper
|
||||||
|
# This ensures the class definition with constructor is included in testgen context
|
||||||
|
elif is_valid_definition and definition.type == "class":
|
||||||
|
init_qualified_name = get_qualified_name(
|
||||||
|
definition.module_name, f"{definition.full_name}.__init__"
|
||||||
)
|
)
|
||||||
<= 2
|
# Only include if it's a top-level class (not nested)
|
||||||
):
|
if len(init_qualified_name.split(".")) <= 2:
|
||||||
function_source = FunctionSource(
|
function_source = FunctionSource(
|
||||||
file_path=definition_path,
|
file_path=definition_path,
|
||||||
qualified_name=qualified_name,
|
qualified_name=init_qualified_name,
|
||||||
fully_qualified_name=definition.full_name,
|
fully_qualified_name=f"{definition.full_name}.__init__",
|
||||||
only_function_name=definition.name,
|
only_function_name="__init__",
|
||||||
source_code=definition.get_line_code(),
|
source_code=definition.get_line_code(),
|
||||||
jedi_definition=definition,
|
jedi_definition=definition,
|
||||||
)
|
)
|
||||||
file_path_to_function_source[definition_path].add(function_source)
|
file_path_to_function_source[definition_path].add(function_source)
|
||||||
function_source_list.append(function_source)
|
function_source_list.append(function_source)
|
||||||
|
|
||||||
return file_path_to_function_source, function_source_list
|
return file_path_to_function_source, function_source_list
|
||||||
|
|
||||||
|
|
@ -647,7 +661,10 @@ def prune_cst_for_code_hashing( # noqa: PLR0911
|
||||||
|
|
||||||
if isinstance(node, cst.FunctionDef):
|
if isinstance(node, cst.FunctionDef):
|
||||||
qualified_name = f"{prefix}.{node.name.value}" if prefix else node.name.value
|
qualified_name = f"{prefix}.{node.name.value}" if prefix else node.name.value
|
||||||
if qualified_name in target_functions:
|
# For hashing, exclude __init__ methods even if in target_functions
|
||||||
|
# because they don't affect the semantic behavior being hashed
|
||||||
|
# But include other dunder methods like __call__ which do affect behavior
|
||||||
|
if qualified_name in target_functions and node.name.value != "__init__":
|
||||||
new_body = remove_docstring_from_body(node.body) if isinstance(node.body, cst.IndentedBlock) else node.body
|
new_body = remove_docstring_from_body(node.body) if isinstance(node.body, cst.IndentedBlock) else node.body
|
||||||
return node.with_changes(body=new_body), True
|
return node.with_changes(body=new_body), True
|
||||||
return None, False
|
return None, False
|
||||||
|
|
@ -666,7 +683,9 @@ def prune_cst_for_code_hashing( # noqa: PLR0911
|
||||||
for stmt in node.body.body:
|
for stmt in node.body.body:
|
||||||
if isinstance(stmt, cst.FunctionDef):
|
if isinstance(stmt, cst.FunctionDef):
|
||||||
qualified_name = f"{class_prefix}.{stmt.name.value}"
|
qualified_name = f"{class_prefix}.{stmt.name.value}"
|
||||||
if qualified_name in target_functions:
|
# For hashing, exclude __init__ methods even if in target_functions
|
||||||
|
# but include other methods like __call__ which affect behavior
|
||||||
|
if qualified_name in target_functions and stmt.name.value != "__init__":
|
||||||
stmt_with_changes = stmt.with_changes(
|
stmt_with_changes = stmt.with_changes(
|
||||||
body=remove_docstring_from_body(cast("cst.IndentedBlock", stmt.body))
|
body=remove_docstring_from_body(cast("cst.IndentedBlock", stmt.body))
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -84,7 +84,8 @@ def test_code_replacement10() -> None:
|
||||||
|
|
||||||
code_ctx = get_code_optimization_context(function_to_optimize=func_top_optimize, project_root_path=file_path.parent)
|
code_ctx = get_code_optimization_context(function_to_optimize=func_top_optimize, project_root_path=file_path.parent)
|
||||||
qualified_names = {func.qualified_name for func in code_ctx.helper_functions}
|
qualified_names = {func.qualified_name for func in code_ctx.helper_functions}
|
||||||
assert qualified_names == {"HelperClass.helper_method"} # Nested method should not be in here
|
# HelperClass.__init__ is now tracked because HelperClass(self.name) instantiates the class
|
||||||
|
assert qualified_names == {"HelperClass.helper_method", "HelperClass.__init__"} # Nested method should not be in here
|
||||||
read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code
|
read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code
|
||||||
hashing_context = code_ctx.hashing_code_context
|
hashing_context = code_ctx.hashing_code_context
|
||||||
|
|
||||||
|
|
@ -570,6 +571,8 @@ _STORE_T = TypeVar("_STORE_T")
|
||||||
class AbstractCacheBackend(CacheBackend, Protocol[_KEY_T, _STORE_T]):
|
class AbstractCacheBackend(CacheBackend, Protocol[_KEY_T, _STORE_T]):
|
||||||
"""Interface for cache backends used by the persistent cache decorator."""
|
"""Interface for cache backends used by the persistent cache decorator."""
|
||||||
|
|
||||||
|
def __init__(self) -> None: ...
|
||||||
|
|
||||||
def hash_key(
|
def hash_key(
|
||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
|
|
@ -1296,6 +1299,8 @@ class DataProcessor:
|
||||||
```
|
```
|
||||||
```python:{path_to_transform_utils.relative_to(project_root)}
|
```python:{path_to_transform_utils.relative_to(project_root)}
|
||||||
class DataTransformer:
|
class DataTransformer:
|
||||||
|
def __init__(self):
|
||||||
|
self.data = None
|
||||||
|
|
||||||
def transform(self, data):
|
def transform(self, data):
|
||||||
self.data = data
|
self.data = data
|
||||||
|
|
@ -1599,7 +1604,11 @@ class DataProcessor:
|
||||||
\"\"\"Return a string representation of the DataProcessor.\"\"\"
|
\"\"\"Return a string representation of the DataProcessor.\"\"\"
|
||||||
return f"DataProcessor(default_prefix={{self.default_prefix!r}})"
|
return f"DataProcessor(default_prefix={{self.default_prefix!r}})"
|
||||||
```
|
```
|
||||||
|
```python:{path_to_transform_utils.relative_to(project_root)}
|
||||||
|
class DataTransformer:
|
||||||
|
def __init__(self):
|
||||||
|
self.data = None
|
||||||
|
```
|
||||||
"""
|
"""
|
||||||
expected_hashing_context = f"""
|
expected_hashing_context = f"""
|
||||||
```python:utils.py
|
```python:utils.py
|
||||||
|
|
@ -1705,6 +1714,7 @@ def test_direct_module_import() -> None:
|
||||||
|
|
||||||
expected_read_only_context = """
|
expected_read_only_context = """
|
||||||
```python:utils.py
|
```python:utils.py
|
||||||
|
import math
|
||||||
from transform_utils import DataTransformer
|
from transform_utils import DataTransformer
|
||||||
|
|
||||||
class DataProcessor:
|
class DataProcessor:
|
||||||
|
|
@ -1712,6 +1722,11 @@ class DataProcessor:
|
||||||
|
|
||||||
number = 1
|
number = 1
|
||||||
|
|
||||||
|
def __init__(self, default_prefix: str = "PREFIX_"):
|
||||||
|
\"\"\"Initialize the DataProcessor with a default prefix.\"\"\"
|
||||||
|
self.default_prefix = default_prefix
|
||||||
|
self.number += math.log(self.number)
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
\"\"\"Return a string representation of the DataProcessor.\"\"\"
|
\"\"\"Return a string representation of the DataProcessor.\"\"\"
|
||||||
return f"DataProcessor(default_prefix={self.default_prefix!r})"
|
return f"DataProcessor(default_prefix={self.default_prefix!r})"
|
||||||
|
|
@ -2727,3 +2742,154 @@ FINAL_ASSIGNMENT = {"data": "value"}
|
||||||
# Verify correct order
|
# Verify correct order
|
||||||
expected_order = ["GLOBAL_CONSTANT", "ANOTHER_CONSTANT", "FINAL_ASSIGNMENT"]
|
expected_order = ["GLOBAL_CONSTANT", "ANOTHER_CONSTANT", "FINAL_ASSIGNMENT"]
|
||||||
assert collector.assignment_order == expected_order
|
assert collector.assignment_order == expected_order
|
||||||
|
|
||||||
|
|
||||||
|
def test_class_instantiation_includes_init_as_helper(tmp_path: Path) -> None:
|
||||||
|
"""Test that when a class is instantiated, its __init__ method is tracked as a helper.
|
||||||
|
|
||||||
|
This test verifies the fix for the bug where class constructors were not
|
||||||
|
included in the context when only the class instantiation was called
|
||||||
|
(not any other methods). This caused LLMs to not know the constructor
|
||||||
|
signatures when generating tests.
|
||||||
|
"""
|
||||||
|
code = '''
|
||||||
|
class DataDumper:
|
||||||
|
"""A class that dumps data."""
|
||||||
|
|
||||||
|
def __init__(self, data):
|
||||||
|
"""Initialize with data."""
|
||||||
|
self.data = data
|
||||||
|
|
||||||
|
def dump(self):
|
||||||
|
"""Dump the data."""
|
||||||
|
return self.data
|
||||||
|
|
||||||
|
|
||||||
|
def target_function():
|
||||||
|
# Only instantiates DataDumper, doesn't call any other methods
|
||||||
|
dumper = DataDumper({"key": "value"})
|
||||||
|
return dumper
|
||||||
|
'''
|
||||||
|
file_path = tmp_path / "test_code.py"
|
||||||
|
file_path.write_text(code, encoding="utf-8")
|
||||||
|
opt = Optimizer(
|
||||||
|
Namespace(
|
||||||
|
project_root=file_path.parent.resolve(),
|
||||||
|
disable_telemetry=True,
|
||||||
|
tests_root="tests",
|
||||||
|
test_framework="pytest",
|
||||||
|
pytest_cmd="pytest",
|
||||||
|
experiment_id=None,
|
||||||
|
test_project_root=Path().resolve(),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
function_to_optimize = FunctionToOptimize(
|
||||||
|
function_name="target_function",
|
||||||
|
file_path=file_path,
|
||||||
|
parents=[],
|
||||||
|
starting_line=None,
|
||||||
|
ending_line=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
code_ctx = get_code_optimization_context(function_to_optimize, opt.args.project_root)
|
||||||
|
|
||||||
|
# The __init__ method should be tracked as a helper since DataDumper() instantiates the class
|
||||||
|
qualified_names = {func.qualified_name for func in code_ctx.helper_functions}
|
||||||
|
assert "DataDumper.__init__" in qualified_names, (
|
||||||
|
"DataDumper.__init__ should be tracked as a helper when the class is instantiated"
|
||||||
|
)
|
||||||
|
|
||||||
|
# The testgen context should contain the class with __init__ (critical for LLM to know constructor)
|
||||||
|
testgen_context = code_ctx.testgen_context.markdown
|
||||||
|
assert "class DataDumper:" in testgen_context, "DataDumper class should be in testgen context"
|
||||||
|
assert "def __init__(self, data):" in testgen_context, (
|
||||||
|
"__init__ method should be included in testgen context"
|
||||||
|
)
|
||||||
|
|
||||||
|
# The hashing context should NOT contain __init__ (excluded for stability)
|
||||||
|
hashing_context = code_ctx.hashing_code_context
|
||||||
|
assert "__init__" not in hashing_context, (
|
||||||
|
"__init__ should NOT be in hashing context (excluded for hash stability)"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_class_instantiation_preserves_full_class_in_testgen(tmp_path: Path) -> None:
|
||||||
|
"""Test that instantiated classes are fully preserved in testgen context.
|
||||||
|
|
||||||
|
This is specifically for the unstructured LayoutDumper bug where helper classes
|
||||||
|
that were instantiated but had no other methods called were being excluded
|
||||||
|
from the testgen context.
|
||||||
|
"""
|
||||||
|
code = '''
|
||||||
|
class LayoutDumper:
|
||||||
|
"""Base class for layout dumpers."""
|
||||||
|
layout_source: str = "unknown"
|
||||||
|
|
||||||
|
def __init__(self, layout):
|
||||||
|
self._layout = layout
|
||||||
|
|
||||||
|
def dump(self) -> dict:
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
|
||||||
|
class ObjectDetectionLayoutDumper(LayoutDumper):
|
||||||
|
"""Specific dumper for object detection layouts."""
|
||||||
|
|
||||||
|
def __init__(self, layout):
|
||||||
|
super().__init__(layout)
|
||||||
|
|
||||||
|
def dump(self) -> dict:
|
||||||
|
return {"type": "object_detection", "layout": self._layout}
|
||||||
|
|
||||||
|
|
||||||
|
def dump_layout(layout_type, layout):
|
||||||
|
"""Dump a layout based on its type."""
|
||||||
|
if layout_type == "object_detection":
|
||||||
|
dumper = ObjectDetectionLayoutDumper(layout)
|
||||||
|
else:
|
||||||
|
dumper = LayoutDumper(layout)
|
||||||
|
return dumper.dump()
|
||||||
|
'''
|
||||||
|
file_path = tmp_path / "test_code.py"
|
||||||
|
file_path.write_text(code, encoding="utf-8")
|
||||||
|
opt = Optimizer(
|
||||||
|
Namespace(
|
||||||
|
project_root=file_path.parent.resolve(),
|
||||||
|
disable_telemetry=True,
|
||||||
|
tests_root="tests",
|
||||||
|
test_framework="pytest",
|
||||||
|
pytest_cmd="pytest",
|
||||||
|
experiment_id=None,
|
||||||
|
test_project_root=Path().resolve(),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
function_to_optimize = FunctionToOptimize(
|
||||||
|
function_name="dump_layout",
|
||||||
|
file_path=file_path,
|
||||||
|
parents=[],
|
||||||
|
starting_line=None,
|
||||||
|
ending_line=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
code_ctx = get_code_optimization_context(function_to_optimize, opt.args.project_root)
|
||||||
|
qualified_names = {func.qualified_name for func in code_ctx.helper_functions}
|
||||||
|
|
||||||
|
# Both class __init__ methods should be tracked as helpers
|
||||||
|
assert "ObjectDetectionLayoutDumper.__init__" in qualified_names, (
|
||||||
|
"ObjectDetectionLayoutDumper.__init__ should be tracked"
|
||||||
|
)
|
||||||
|
assert "LayoutDumper.__init__" in qualified_names, (
|
||||||
|
"LayoutDumper.__init__ should be tracked"
|
||||||
|
)
|
||||||
|
|
||||||
|
# The testgen context should include both classes with their __init__ methods
|
||||||
|
testgen_context = code_ctx.testgen_context.markdown
|
||||||
|
assert "class LayoutDumper:" in testgen_context, "LayoutDumper should be in testgen context"
|
||||||
|
assert "class ObjectDetectionLayoutDumper" in testgen_context, (
|
||||||
|
"ObjectDetectionLayoutDumper should be in testgen context"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Both __init__ methods should be in the testgen context (so LLM knows constructor signatures)
|
||||||
|
assert testgen_context.count("def __init__") >= 2, (
|
||||||
|
"Both __init__ methods should be in testgen context"
|
||||||
|
)
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue