From 3048ece4da9ef9585684645a995ec92c8f36539b Mon Sep 17 00:00:00 2001 From: Kevin Turcios Date: Thu, 1 Jan 2026 02:25:10 -0500 Subject: [PATCH] fix: track class __init__ as helper when class is instantiated Ensures LLM sees constructor signatures for proper test generation. --- codeflash/context/code_context_extractor.py | 57 ++++--- tests/test_code_context_extractor.py | 170 +++++++++++++++++++- 2 files changed, 206 insertions(+), 21 deletions(-) diff --git a/codeflash/context/code_context_extractor.py b/codeflash/context/code_context_extractor.py index 14d549633..a411bafac 100644 --- a/codeflash/context/code_context_extractor.py +++ b/codeflash/context/code_context_extractor.py @@ -446,31 +446,45 @@ def get_function_sources_from_jedi( definition_path = definition.module_path # The definition is part of this project and not defined within the original function - if ( + is_valid_definition = ( str(definition_path).startswith(str(project_root_path) + os.sep) and not path_belongs_to_site_packages(definition_path) and definition.full_name - and definition.type == "function" and not belongs_to_function_qualified(definition, qualified_function_name) and definition.full_name.startswith(definition.module_name) + ) + if is_valid_definition and definition.type == "function": + qualified_name = get_qualified_name(definition.module_name, definition.full_name) # Avoid nested functions or classes. Only class.function is allowed - and len( - (qualified_name := get_qualified_name(definition.module_name, definition.full_name)).split( - "." + if len(qualified_name.split(".")) <= 2: + function_source = FunctionSource( + file_path=definition_path, + qualified_name=qualified_name, + fully_qualified_name=definition.full_name, + only_function_name=definition.name, + source_code=definition.get_line_code(), + jedi_definition=definition, ) + file_path_to_function_source[definition_path].add(function_source) + function_source_list.append(function_source) + # When a class is instantiated (e.g., MyClass()), track its __init__ as a helper + # This ensures the class definition with constructor is included in testgen context + elif is_valid_definition and definition.type == "class": + init_qualified_name = get_qualified_name( + definition.module_name, f"{definition.full_name}.__init__" ) - <= 2 - ): - function_source = FunctionSource( - file_path=definition_path, - qualified_name=qualified_name, - fully_qualified_name=definition.full_name, - only_function_name=definition.name, - source_code=definition.get_line_code(), - jedi_definition=definition, - ) - file_path_to_function_source[definition_path].add(function_source) - function_source_list.append(function_source) + # Only include if it's a top-level class (not nested) + if len(init_qualified_name.split(".")) <= 2: + function_source = FunctionSource( + file_path=definition_path, + qualified_name=init_qualified_name, + fully_qualified_name=f"{definition.full_name}.__init__", + only_function_name="__init__", + source_code=definition.get_line_code(), + jedi_definition=definition, + ) + file_path_to_function_source[definition_path].add(function_source) + function_source_list.append(function_source) return file_path_to_function_source, function_source_list @@ -647,7 +661,10 @@ def prune_cst_for_code_hashing( # noqa: PLR0911 if isinstance(node, cst.FunctionDef): qualified_name = f"{prefix}.{node.name.value}" if prefix else node.name.value - if qualified_name in target_functions: + # For hashing, exclude __init__ methods even if in target_functions + # because they don't affect the semantic behavior being hashed + # But include other dunder methods like __call__ which do affect behavior + if qualified_name in target_functions and node.name.value != "__init__": new_body = remove_docstring_from_body(node.body) if isinstance(node.body, cst.IndentedBlock) else node.body return node.with_changes(body=new_body), True return None, False @@ -666,7 +683,9 @@ def prune_cst_for_code_hashing( # noqa: PLR0911 for stmt in node.body.body: if isinstance(stmt, cst.FunctionDef): qualified_name = f"{class_prefix}.{stmt.name.value}" - if qualified_name in target_functions: + # For hashing, exclude __init__ methods even if in target_functions + # but include other methods like __call__ which affect behavior + if qualified_name in target_functions and stmt.name.value != "__init__": stmt_with_changes = stmt.with_changes( body=remove_docstring_from_body(cast("cst.IndentedBlock", stmt.body)) ) diff --git a/tests/test_code_context_extractor.py b/tests/test_code_context_extractor.py index aa4e2880f..b7cce0869 100644 --- a/tests/test_code_context_extractor.py +++ b/tests/test_code_context_extractor.py @@ -84,7 +84,8 @@ def test_code_replacement10() -> None: code_ctx = get_code_optimization_context(function_to_optimize=func_top_optimize, project_root_path=file_path.parent) qualified_names = {func.qualified_name for func in code_ctx.helper_functions} - assert qualified_names == {"HelperClass.helper_method"} # Nested method should not be in here + # HelperClass.__init__ is now tracked because HelperClass(self.name) instantiates the class + assert qualified_names == {"HelperClass.helper_method", "HelperClass.__init__"} # Nested method should not be in here read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code hashing_context = code_ctx.hashing_code_context @@ -570,6 +571,8 @@ _STORE_T = TypeVar("_STORE_T") class AbstractCacheBackend(CacheBackend, Protocol[_KEY_T, _STORE_T]): """Interface for cache backends used by the persistent cache decorator.""" + def __init__(self) -> None: ... + def hash_key( self, *, @@ -1296,6 +1299,8 @@ class DataProcessor: ``` ```python:{path_to_transform_utils.relative_to(project_root)} class DataTransformer: + def __init__(self): + self.data = None def transform(self, data): self.data = data @@ -1599,7 +1604,11 @@ class DataProcessor: \"\"\"Return a string representation of the DataProcessor.\"\"\" return f"DataProcessor(default_prefix={{self.default_prefix!r}})" ``` - +```python:{path_to_transform_utils.relative_to(project_root)} +class DataTransformer: + def __init__(self): + self.data = None +``` """ expected_hashing_context = f""" ```python:utils.py @@ -1705,6 +1714,7 @@ def test_direct_module_import() -> None: expected_read_only_context = """ ```python:utils.py +import math from transform_utils import DataTransformer class DataProcessor: @@ -1712,6 +1722,11 @@ class DataProcessor: number = 1 + def __init__(self, default_prefix: str = "PREFIX_"): + \"\"\"Initialize the DataProcessor with a default prefix.\"\"\" + self.default_prefix = default_prefix + self.number += math.log(self.number) + def __repr__(self) -> str: \"\"\"Return a string representation of the DataProcessor.\"\"\" return f"DataProcessor(default_prefix={self.default_prefix!r})" @@ -2727,3 +2742,154 @@ FINAL_ASSIGNMENT = {"data": "value"} # Verify correct order expected_order = ["GLOBAL_CONSTANT", "ANOTHER_CONSTANT", "FINAL_ASSIGNMENT"] assert collector.assignment_order == expected_order + + +def test_class_instantiation_includes_init_as_helper(tmp_path: Path) -> None: + """Test that when a class is instantiated, its __init__ method is tracked as a helper. + + This test verifies the fix for the bug where class constructors were not + included in the context when only the class instantiation was called + (not any other methods). This caused LLMs to not know the constructor + signatures when generating tests. + """ + code = ''' +class DataDumper: + """A class that dumps data.""" + + def __init__(self, data): + """Initialize with data.""" + self.data = data + + def dump(self): + """Dump the data.""" + return self.data + + +def target_function(): + # Only instantiates DataDumper, doesn't call any other methods + dumper = DataDumper({"key": "value"}) + return dumper +''' + file_path = tmp_path / "test_code.py" + file_path.write_text(code, encoding="utf-8") + opt = Optimizer( + Namespace( + project_root=file_path.parent.resolve(), + disable_telemetry=True, + tests_root="tests", + test_framework="pytest", + pytest_cmd="pytest", + experiment_id=None, + test_project_root=Path().resolve(), + ) + ) + function_to_optimize = FunctionToOptimize( + function_name="target_function", + file_path=file_path, + parents=[], + starting_line=None, + ending_line=None, + ) + + code_ctx = get_code_optimization_context(function_to_optimize, opt.args.project_root) + + # The __init__ method should be tracked as a helper since DataDumper() instantiates the class + qualified_names = {func.qualified_name for func in code_ctx.helper_functions} + assert "DataDumper.__init__" in qualified_names, ( + "DataDumper.__init__ should be tracked as a helper when the class is instantiated" + ) + + # The testgen context should contain the class with __init__ (critical for LLM to know constructor) + testgen_context = code_ctx.testgen_context.markdown + assert "class DataDumper:" in testgen_context, "DataDumper class should be in testgen context" + assert "def __init__(self, data):" in testgen_context, ( + "__init__ method should be included in testgen context" + ) + + # The hashing context should NOT contain __init__ (excluded for stability) + hashing_context = code_ctx.hashing_code_context + assert "__init__" not in hashing_context, ( + "__init__ should NOT be in hashing context (excluded for hash stability)" + ) + + +def test_class_instantiation_preserves_full_class_in_testgen(tmp_path: Path) -> None: + """Test that instantiated classes are fully preserved in testgen context. + + This is specifically for the unstructured LayoutDumper bug where helper classes + that were instantiated but had no other methods called were being excluded + from the testgen context. + """ + code = ''' +class LayoutDumper: + """Base class for layout dumpers.""" + layout_source: str = "unknown" + + def __init__(self, layout): + self._layout = layout + + def dump(self) -> dict: + raise NotImplementedError() + + +class ObjectDetectionLayoutDumper(LayoutDumper): + """Specific dumper for object detection layouts.""" + + def __init__(self, layout): + super().__init__(layout) + + def dump(self) -> dict: + return {"type": "object_detection", "layout": self._layout} + + +def dump_layout(layout_type, layout): + """Dump a layout based on its type.""" + if layout_type == "object_detection": + dumper = ObjectDetectionLayoutDumper(layout) + else: + dumper = LayoutDumper(layout) + return dumper.dump() +''' + file_path = tmp_path / "test_code.py" + file_path.write_text(code, encoding="utf-8") + opt = Optimizer( + Namespace( + project_root=file_path.parent.resolve(), + disable_telemetry=True, + tests_root="tests", + test_framework="pytest", + pytest_cmd="pytest", + experiment_id=None, + test_project_root=Path().resolve(), + ) + ) + function_to_optimize = FunctionToOptimize( + function_name="dump_layout", + file_path=file_path, + parents=[], + starting_line=None, + ending_line=None, + ) + + code_ctx = get_code_optimization_context(function_to_optimize, opt.args.project_root) + qualified_names = {func.qualified_name for func in code_ctx.helper_functions} + + # Both class __init__ methods should be tracked as helpers + assert "ObjectDetectionLayoutDumper.__init__" in qualified_names, ( + "ObjectDetectionLayoutDumper.__init__ should be tracked" + ) + assert "LayoutDumper.__init__" in qualified_names, ( + "LayoutDumper.__init__ should be tracked" + ) + + # The testgen context should include both classes with their __init__ methods + testgen_context = code_ctx.testgen_context.markdown + assert "class LayoutDumper:" in testgen_context, "LayoutDumper should be in testgen context" + assert "class ObjectDetectionLayoutDumper" in testgen_context, ( + "ObjectDetectionLayoutDumper should be in testgen context" + ) + + # Both __init__ methods should be in the testgen context (so LLM knows constructor signatures) + assert testgen_context.count("def __init__") >= 2, ( + "Both __init__ methods should be in testgen context" + )