mirror of
https://github.com/codeflash-ai/codeflash.git
synced 2026-05-04 18:25:17 +00:00
fix: track class __init__ as helper when class is instantiated
Ensures LLM sees constructor signatures for proper test generation.
This commit is contained in:
parent
0d84ab62dd
commit
3048ece4da
2 changed files with 206 additions and 21 deletions
|
|
@ -446,31 +446,45 @@ def get_function_sources_from_jedi(
|
|||
definition_path = definition.module_path
|
||||
|
||||
# The definition is part of this project and not defined within the original function
|
||||
if (
|
||||
is_valid_definition = (
|
||||
str(definition_path).startswith(str(project_root_path) + os.sep)
|
||||
and not path_belongs_to_site_packages(definition_path)
|
||||
and definition.full_name
|
||||
and definition.type == "function"
|
||||
and not belongs_to_function_qualified(definition, qualified_function_name)
|
||||
and definition.full_name.startswith(definition.module_name)
|
||||
)
|
||||
if is_valid_definition and definition.type == "function":
|
||||
qualified_name = get_qualified_name(definition.module_name, definition.full_name)
|
||||
# Avoid nested functions or classes. Only class.function is allowed
|
||||
and len(
|
||||
(qualified_name := get_qualified_name(definition.module_name, definition.full_name)).split(
|
||||
"."
|
||||
if len(qualified_name.split(".")) <= 2:
|
||||
function_source = FunctionSource(
|
||||
file_path=definition_path,
|
||||
qualified_name=qualified_name,
|
||||
fully_qualified_name=definition.full_name,
|
||||
only_function_name=definition.name,
|
||||
source_code=definition.get_line_code(),
|
||||
jedi_definition=definition,
|
||||
)
|
||||
file_path_to_function_source[definition_path].add(function_source)
|
||||
function_source_list.append(function_source)
|
||||
# When a class is instantiated (e.g., MyClass()), track its __init__ as a helper
|
||||
# This ensures the class definition with constructor is included in testgen context
|
||||
elif is_valid_definition and definition.type == "class":
|
||||
init_qualified_name = get_qualified_name(
|
||||
definition.module_name, f"{definition.full_name}.__init__"
|
||||
)
|
||||
<= 2
|
||||
):
|
||||
function_source = FunctionSource(
|
||||
file_path=definition_path,
|
||||
qualified_name=qualified_name,
|
||||
fully_qualified_name=definition.full_name,
|
||||
only_function_name=definition.name,
|
||||
source_code=definition.get_line_code(),
|
||||
jedi_definition=definition,
|
||||
)
|
||||
file_path_to_function_source[definition_path].add(function_source)
|
||||
function_source_list.append(function_source)
|
||||
# Only include if it's a top-level class (not nested)
|
||||
if len(init_qualified_name.split(".")) <= 2:
|
||||
function_source = FunctionSource(
|
||||
file_path=definition_path,
|
||||
qualified_name=init_qualified_name,
|
||||
fully_qualified_name=f"{definition.full_name}.__init__",
|
||||
only_function_name="__init__",
|
||||
source_code=definition.get_line_code(),
|
||||
jedi_definition=definition,
|
||||
)
|
||||
file_path_to_function_source[definition_path].add(function_source)
|
||||
function_source_list.append(function_source)
|
||||
|
||||
return file_path_to_function_source, function_source_list
|
||||
|
||||
|
|
@ -647,7 +661,10 @@ def prune_cst_for_code_hashing( # noqa: PLR0911
|
|||
|
||||
if isinstance(node, cst.FunctionDef):
|
||||
qualified_name = f"{prefix}.{node.name.value}" if prefix else node.name.value
|
||||
if qualified_name in target_functions:
|
||||
# For hashing, exclude __init__ methods even if in target_functions
|
||||
# because they don't affect the semantic behavior being hashed
|
||||
# But include other dunder methods like __call__ which do affect behavior
|
||||
if qualified_name in target_functions and node.name.value != "__init__":
|
||||
new_body = remove_docstring_from_body(node.body) if isinstance(node.body, cst.IndentedBlock) else node.body
|
||||
return node.with_changes(body=new_body), True
|
||||
return None, False
|
||||
|
|
@ -666,7 +683,9 @@ def prune_cst_for_code_hashing( # noqa: PLR0911
|
|||
for stmt in node.body.body:
|
||||
if isinstance(stmt, cst.FunctionDef):
|
||||
qualified_name = f"{class_prefix}.{stmt.name.value}"
|
||||
if qualified_name in target_functions:
|
||||
# For hashing, exclude __init__ methods even if in target_functions
|
||||
# but include other methods like __call__ which affect behavior
|
||||
if qualified_name in target_functions and stmt.name.value != "__init__":
|
||||
stmt_with_changes = stmt.with_changes(
|
||||
body=remove_docstring_from_body(cast("cst.IndentedBlock", stmt.body))
|
||||
)
|
||||
|
|
|
|||
|
|
@ -84,7 +84,8 @@ def test_code_replacement10() -> None:
|
|||
|
||||
code_ctx = get_code_optimization_context(function_to_optimize=func_top_optimize, project_root_path=file_path.parent)
|
||||
qualified_names = {func.qualified_name for func in code_ctx.helper_functions}
|
||||
assert qualified_names == {"HelperClass.helper_method"} # Nested method should not be in here
|
||||
# HelperClass.__init__ is now tracked because HelperClass(self.name) instantiates the class
|
||||
assert qualified_names == {"HelperClass.helper_method", "HelperClass.__init__"} # Nested method should not be in here
|
||||
read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code
|
||||
hashing_context = code_ctx.hashing_code_context
|
||||
|
||||
|
|
@ -570,6 +571,8 @@ _STORE_T = TypeVar("_STORE_T")
|
|||
class AbstractCacheBackend(CacheBackend, Protocol[_KEY_T, _STORE_T]):
|
||||
"""Interface for cache backends used by the persistent cache decorator."""
|
||||
|
||||
def __init__(self) -> None: ...
|
||||
|
||||
def hash_key(
|
||||
self,
|
||||
*,
|
||||
|
|
@ -1296,6 +1299,8 @@ class DataProcessor:
|
|||
```
|
||||
```python:{path_to_transform_utils.relative_to(project_root)}
|
||||
class DataTransformer:
|
||||
def __init__(self):
|
||||
self.data = None
|
||||
|
||||
def transform(self, data):
|
||||
self.data = data
|
||||
|
|
@ -1599,7 +1604,11 @@ class DataProcessor:
|
|||
\"\"\"Return a string representation of the DataProcessor.\"\"\"
|
||||
return f"DataProcessor(default_prefix={{self.default_prefix!r}})"
|
||||
```
|
||||
|
||||
```python:{path_to_transform_utils.relative_to(project_root)}
|
||||
class DataTransformer:
|
||||
def __init__(self):
|
||||
self.data = None
|
||||
```
|
||||
"""
|
||||
expected_hashing_context = f"""
|
||||
```python:utils.py
|
||||
|
|
@ -1705,6 +1714,7 @@ def test_direct_module_import() -> None:
|
|||
|
||||
expected_read_only_context = """
|
||||
```python:utils.py
|
||||
import math
|
||||
from transform_utils import DataTransformer
|
||||
|
||||
class DataProcessor:
|
||||
|
|
@ -1712,6 +1722,11 @@ class DataProcessor:
|
|||
|
||||
number = 1
|
||||
|
||||
def __init__(self, default_prefix: str = "PREFIX_"):
|
||||
\"\"\"Initialize the DataProcessor with a default prefix.\"\"\"
|
||||
self.default_prefix = default_prefix
|
||||
self.number += math.log(self.number)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
\"\"\"Return a string representation of the DataProcessor.\"\"\"
|
||||
return f"DataProcessor(default_prefix={self.default_prefix!r})"
|
||||
|
|
@ -2727,3 +2742,154 @@ FINAL_ASSIGNMENT = {"data": "value"}
|
|||
# Verify correct order
|
||||
expected_order = ["GLOBAL_CONSTANT", "ANOTHER_CONSTANT", "FINAL_ASSIGNMENT"]
|
||||
assert collector.assignment_order == expected_order
|
||||
|
||||
|
||||
def test_class_instantiation_includes_init_as_helper(tmp_path: Path) -> None:
|
||||
"""Test that when a class is instantiated, its __init__ method is tracked as a helper.
|
||||
|
||||
This test verifies the fix for the bug where class constructors were not
|
||||
included in the context when only the class instantiation was called
|
||||
(not any other methods). This caused LLMs to not know the constructor
|
||||
signatures when generating tests.
|
||||
"""
|
||||
code = '''
|
||||
class DataDumper:
|
||||
"""A class that dumps data."""
|
||||
|
||||
def __init__(self, data):
|
||||
"""Initialize with data."""
|
||||
self.data = data
|
||||
|
||||
def dump(self):
|
||||
"""Dump the data."""
|
||||
return self.data
|
||||
|
||||
|
||||
def target_function():
|
||||
# Only instantiates DataDumper, doesn't call any other methods
|
||||
dumper = DataDumper({"key": "value"})
|
||||
return dumper
|
||||
'''
|
||||
file_path = tmp_path / "test_code.py"
|
||||
file_path.write_text(code, encoding="utf-8")
|
||||
opt = Optimizer(
|
||||
Namespace(
|
||||
project_root=file_path.parent.resolve(),
|
||||
disable_telemetry=True,
|
||||
tests_root="tests",
|
||||
test_framework="pytest",
|
||||
pytest_cmd="pytest",
|
||||
experiment_id=None,
|
||||
test_project_root=Path().resolve(),
|
||||
)
|
||||
)
|
||||
function_to_optimize = FunctionToOptimize(
|
||||
function_name="target_function",
|
||||
file_path=file_path,
|
||||
parents=[],
|
||||
starting_line=None,
|
||||
ending_line=None,
|
||||
)
|
||||
|
||||
code_ctx = get_code_optimization_context(function_to_optimize, opt.args.project_root)
|
||||
|
||||
# The __init__ method should be tracked as a helper since DataDumper() instantiates the class
|
||||
qualified_names = {func.qualified_name for func in code_ctx.helper_functions}
|
||||
assert "DataDumper.__init__" in qualified_names, (
|
||||
"DataDumper.__init__ should be tracked as a helper when the class is instantiated"
|
||||
)
|
||||
|
||||
# The testgen context should contain the class with __init__ (critical for LLM to know constructor)
|
||||
testgen_context = code_ctx.testgen_context.markdown
|
||||
assert "class DataDumper:" in testgen_context, "DataDumper class should be in testgen context"
|
||||
assert "def __init__(self, data):" in testgen_context, (
|
||||
"__init__ method should be included in testgen context"
|
||||
)
|
||||
|
||||
# The hashing context should NOT contain __init__ (excluded for stability)
|
||||
hashing_context = code_ctx.hashing_code_context
|
||||
assert "__init__" not in hashing_context, (
|
||||
"__init__ should NOT be in hashing context (excluded for hash stability)"
|
||||
)
|
||||
|
||||
|
||||
def test_class_instantiation_preserves_full_class_in_testgen(tmp_path: Path) -> None:
|
||||
"""Test that instantiated classes are fully preserved in testgen context.
|
||||
|
||||
This is specifically for the unstructured LayoutDumper bug where helper classes
|
||||
that were instantiated but had no other methods called were being excluded
|
||||
from the testgen context.
|
||||
"""
|
||||
code = '''
|
||||
class LayoutDumper:
|
||||
"""Base class for layout dumpers."""
|
||||
layout_source: str = "unknown"
|
||||
|
||||
def __init__(self, layout):
|
||||
self._layout = layout
|
||||
|
||||
def dump(self) -> dict:
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class ObjectDetectionLayoutDumper(LayoutDumper):
|
||||
"""Specific dumper for object detection layouts."""
|
||||
|
||||
def __init__(self, layout):
|
||||
super().__init__(layout)
|
||||
|
||||
def dump(self) -> dict:
|
||||
return {"type": "object_detection", "layout": self._layout}
|
||||
|
||||
|
||||
def dump_layout(layout_type, layout):
|
||||
"""Dump a layout based on its type."""
|
||||
if layout_type == "object_detection":
|
||||
dumper = ObjectDetectionLayoutDumper(layout)
|
||||
else:
|
||||
dumper = LayoutDumper(layout)
|
||||
return dumper.dump()
|
||||
'''
|
||||
file_path = tmp_path / "test_code.py"
|
||||
file_path.write_text(code, encoding="utf-8")
|
||||
opt = Optimizer(
|
||||
Namespace(
|
||||
project_root=file_path.parent.resolve(),
|
||||
disable_telemetry=True,
|
||||
tests_root="tests",
|
||||
test_framework="pytest",
|
||||
pytest_cmd="pytest",
|
||||
experiment_id=None,
|
||||
test_project_root=Path().resolve(),
|
||||
)
|
||||
)
|
||||
function_to_optimize = FunctionToOptimize(
|
||||
function_name="dump_layout",
|
||||
file_path=file_path,
|
||||
parents=[],
|
||||
starting_line=None,
|
||||
ending_line=None,
|
||||
)
|
||||
|
||||
code_ctx = get_code_optimization_context(function_to_optimize, opt.args.project_root)
|
||||
qualified_names = {func.qualified_name for func in code_ctx.helper_functions}
|
||||
|
||||
# Both class __init__ methods should be tracked as helpers
|
||||
assert "ObjectDetectionLayoutDumper.__init__" in qualified_names, (
|
||||
"ObjectDetectionLayoutDumper.__init__ should be tracked"
|
||||
)
|
||||
assert "LayoutDumper.__init__" in qualified_names, (
|
||||
"LayoutDumper.__init__ should be tracked"
|
||||
)
|
||||
|
||||
# The testgen context should include both classes with their __init__ methods
|
||||
testgen_context = code_ctx.testgen_context.markdown
|
||||
assert "class LayoutDumper:" in testgen_context, "LayoutDumper should be in testgen context"
|
||||
assert "class ObjectDetectionLayoutDumper" in testgen_context, (
|
||||
"ObjectDetectionLayoutDumper should be in testgen context"
|
||||
)
|
||||
|
||||
# Both __init__ methods should be in the testgen context (so LLM knows constructor signatures)
|
||||
assert testgen_context.count("def __init__") >= 2, (
|
||||
"Both __init__ methods should be in testgen context"
|
||||
)
|
||||
|
|
|
|||
Loading…
Reference in a new issue