fix: preserve comment position by passing CST module directly to import adder

parse_code_and_prune_cst now returns cst.Module instead of str.
add_needed_imports_from_module accepts cst.Module | str, skipping re-parse
when a Module is passed. This eliminates the string round-trip that caused
comments to migrate from statement leading_lines to Module.header,
resulting in comments appearing above imports instead of at their
original position.
This commit is contained in:
Kevin Turcios 2026-02-23 01:08:39 -05:00
parent e82e4cf16e
commit b5fab57499
6 changed files with 90 additions and 70 deletions

View file

@ -355,7 +355,7 @@ def process_file_context(
try:
all_names = primary_qualified_names | secondary_qualified_names
code_without_unused_defs = remove_unused_definitions_by_function_names(original_code, all_names)
code_context = parse_code_and_prune_cst(
pruned_module = parse_code_and_prune_cst(
code_without_unused_defs,
code_context_type,
primary_qualified_names,
@ -366,11 +366,13 @@ def process_file_context(
logger.debug(f"Error while getting read-only code: {e}")
return None
if code_context.strip():
if code_context_type != CodeContextType.HASHING:
if pruned_module.code.strip():
if code_context_type == CodeContextType.HASHING:
code_context = ast.unparse(ast.parse(pruned_module.code))
else:
code_context = add_needed_imports_from_module(
src_module_code=original_code,
dst_module_code=code_context,
dst_module_code=pruned_module,
src_path=file_path,
dst_path=file_path,
project_root=project_root_path,
@ -1063,8 +1065,8 @@ def parse_code_and_prune_cst(
target_functions: set[str],
helpers_of_helper_functions: set[str] = set(), # noqa: B006
remove_docstrings: bool = False,
) -> str:
"""Create a read-only version of the code by parsing and filtering the code to keep only class contextual information, and other module scoped variables."""
) -> cst.Module:
"""Parse and filter the code CST, returning the pruned Module."""
module = cst.parse_module(code)
defs_with_usages = collect_top_level_defs_with_usages(module, target_functions | helpers_of_helper_functions)
@ -1100,11 +1102,8 @@ def parse_code_and_prune_cst(
if not found_target:
raise ValueError("No target functions found in the provided code")
if filtered_node and isinstance(filtered_node, cst.Module):
code = str(filtered_node.code)
if code_context_type == CodeContextType.HASHING:
code = ast.unparse(ast.parse(code)) # Makes it standard
return code
return ""
return filtered_node
raise ValueError("Pruning produced no module")
def _qualified_name(prefix: str, name: str) -> str:

View file

@ -684,7 +684,7 @@ def resolve_star_import(module_name: str, project_root: Path) -> set[str]:
def add_needed_imports_from_module(
src_module_code: str,
dst_module_code: str,
dst_module_code: str | cst.Module,
src_path: Path,
dst_path: Path,
project_root: Path,
@ -696,6 +696,8 @@ def add_needed_imports_from_module(
if not helper_functions_fqn:
helper_functions_fqn = {f.fully_qualified_name for f in (helper_functions or [])}
dst_code_fallback = dst_module_code if isinstance(dst_module_code, str) else dst_module_code.code
src_module_and_package: ModuleNameAndPackage = calculate_module_and_package(project_root, src_path)
dst_module_and_package: ModuleNameAndPackage = calculate_module_and_package(project_root, dst_path)
@ -715,15 +717,19 @@ def add_needed_imports_from_module(
cst.parse_module(src_module_code).visit(gatherer)
except Exception as e:
logger.error(f"Error parsing source module code: {e}")
return dst_module_code
return dst_code_fallback
dotted_import_collector = DottedImportCollector()
try:
parsed_dst_module = cst.parse_module(dst_module_code)
if isinstance(dst_module_code, cst.Module):
parsed_dst_module = dst_module_code
parsed_dst_module.visit(dotted_import_collector)
except cst.ParserSyntaxError as e:
logger.exception(f"Syntax error in destination module code: {e}")
return dst_module_code # Return the original code if there's a syntax error
else:
try:
parsed_dst_module = cst.parse_module(dst_module_code)
parsed_dst_module.visit(dotted_import_collector)
except cst.ParserSyntaxError as e:
logger.exception(f"Syntax error in destination module code: {e}")
return dst_code_fallback
try:
for mod in gatherer.module_imports:
@ -768,7 +774,7 @@ def add_needed_imports_from_module(
RemoveImportsVisitor.remove_unused_import(dst_context, mod, obj)
except Exception as e:
logger.exception(f"Error adding imports to destination module code: {e}")
return dst_module_code
return dst_code_fallback
for mod, asname in gatherer.module_aliases.items():
if not asname:
@ -796,7 +802,7 @@ def add_needed_imports_from_module(
return transformed_module.code.lstrip("\n")
except Exception as e:
logger.exception(f"Error adding imports to destination module code: {e}")
return dst_module_code
return dst_code_fallback
def get_code(functions_to_optimize: list[FunctionToOptimize]) -> tuple[str | None, set[tuple[str, str]]]:

View file

@ -104,6 +104,7 @@ def test_code_replacement10() -> None:
```python:{file_path.relative_to(file_path.parent)}
from __future__ import annotations
class HelperClass:
def __init__(self, name):
self.name = name
@ -164,6 +165,7 @@ def test_class_method_dependencies() -> None:
from __future__ import annotations
from collections import defaultdict
class Graph:
def __init__(self, vertices):
self.graph = defaultdict(list)
@ -243,6 +245,7 @@ def test_bubble_sort_helper() -> None:
```python:code_to_optimize/code_directories/retriever/bubble_sort_with_math.py
import math
def sorter(arr):
arr.sort()
x = math.sqrt(2)
@ -252,6 +255,7 @@ def sorter(arr):
```python:code_to_optimize/code_directories/retriever/bubble_sort_imported.py
from bubble_sort_with_math import sorter
def sort_from_another_file(arr):
sorted_arr = sorter(arr)
return sorted_arr
@ -1180,6 +1184,7 @@ API_URL = "https://api.example.com/data"
```python:{path_to_utils.relative_to(project_root)}
import math
class DataProcessor:
def __init__(self, default_prefix: str = "PREFIX_"):
@ -1200,6 +1205,7 @@ import requests
from globals import API_URL
from utils import DataProcessor
def fetch_and_process_data():
# Use the global variable for the request
response = requests.get(API_URL)
@ -1279,6 +1285,7 @@ API_URL = "https://api.example.com/data"
import math
from transform_utils import DataTransformer
class DataProcessor:
def __init__(self, default_prefix: str = "PREFIX_"):
@ -1299,6 +1306,7 @@ import requests
from globals import API_URL
from utils import DataProcessor
def fetch_and_transform_data():
# Use the global variable for the request
response = requests.get(API_URL)
@ -1387,6 +1395,7 @@ class DataTransformer:
import math
from transform_utils import DataTransformer
class DataProcessor:
def __init__(self, default_prefix: str = "PREFIX_"):
@ -1467,6 +1476,7 @@ class DataTransformer:
import math
from transform_utils import DataTransformer
class DataProcessor:
def __init__(self, default_prefix: str = "PREFIX_"):
@ -1598,6 +1608,7 @@ def test_repo_helper_circular_dependency() -> None:
import math
from transform_utils import DataTransformer
class DataProcessor:
def __init__(self, default_prefix: str = "PREFIX_"):
@ -1612,6 +1623,7 @@ class DataProcessor:
```python:{path_to_transform_utils.relative_to(project_root)}
from code_to_optimize.code_directories.retriever.utils import DataProcessor
class DataTransformer:
def __init__(self):
self.data = None
@ -1744,6 +1756,7 @@ def test_direct_module_import() -> None:
import math
from transform_utils import DataTransformer
class DataProcessor:
\"\"\"A class for processing data.\"\"\"
@ -1787,6 +1800,7 @@ import requests
from globals import API_URL
from utils import DataProcessor
def fetch_and_transform_data():
# Use the global variable for the request
response = requests.get(API_URL)
@ -4077,6 +4091,7 @@ import dataclasses
import enum
import typing as t
class MessageKind(enum.StrEnum):
ASK_FOR_CLIPBOARD_RESPONSE = "ask-for-clipboard-response"
BEGIN_EXFILTRATION = "begin-exfiltration"

View file

@ -23,7 +23,7 @@ def test_basic_class() -> None:
class_var = "value"
"""
output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_ONLY, {"TestClass.target_method"}, set())
output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_ONLY, {"TestClass.target_method"}, set()).code
assert dedent(expected).strip() == output.strip()
@ -47,7 +47,7 @@ def test_dunder_methods() -> None:
return f"Value: {self.x}"
"""
output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_ONLY, {"TestClass.target_method"}, set())
output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_ONLY, {"TestClass.target_method"}, set()).code
assert dedent(expected).strip() == output.strip()
@ -75,7 +75,7 @@ def test_dunder_methods_remove_docstring() -> None:
output = parse_code_and_prune_cst(
dedent(code), CodeContextType.READ_ONLY, {"TestClass.target_method"}, set(), remove_docstrings=True
)
).code
assert dedent(expected).strip() == output.strip()
@ -102,7 +102,7 @@ def test_class_remove_docstring() -> None:
output = parse_code_and_prune_cst(
dedent(code), CodeContextType.READ_ONLY, {"TestClass.target_method"}, set(), remove_docstrings=True
)
).code
assert dedent(expected).strip() == output.strip()
@ -131,7 +131,7 @@ def test_mixed_remove_docstring() -> None:
output = parse_code_and_prune_cst(
dedent(code), CodeContextType.READ_ONLY, {"TestClass.target_method"}, set(), remove_docstrings=True
)
).code
assert dedent(expected).strip() == output.strip()
@ -171,7 +171,7 @@ def test_docstrings() -> None:
\"\"\"Class docstring.\"\"\"
"""
output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_ONLY, {"TestClass.target_method"}, set())
output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_ONLY, {"TestClass.target_method"}, set()).code
assert dedent(expected).strip() == output.strip()
@ -190,7 +190,7 @@ def test_method_signatures() -> None:
expected = """"""
output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_ONLY, {"TestClass.target_method"}, set())
output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_ONLY, {"TestClass.target_method"}, set()).code
assert dedent(expected).strip() == output.strip()
@ -212,7 +212,7 @@ def test_multiple_top_level_targets() -> None:
output = parse_code_and_prune_cst(
dedent(code), CodeContextType.READ_ONLY, {"TestClass.target1", "TestClass.target2"}, set()
)
).code
assert dedent(expected).strip() == output.strip()
@ -232,7 +232,7 @@ def test_class_annotations() -> None:
var2: str
"""
output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_ONLY, {"TestClass.target_method"}, set())
output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_ONLY, {"TestClass.target_method"}, set()).code
assert dedent(expected).strip() == output.strip()
@ -254,7 +254,7 @@ def test_class_annotations_if() -> None:
var2: str
"""
output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_ONLY, {"TestClass.target_method"}, set())
output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_ONLY, {"TestClass.target_method"}, set()).code
assert dedent(expected).strip() == output.strip()
@ -280,7 +280,7 @@ def test_class_annotations_try() -> None:
continue
"""
output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_ONLY, {"TestClass.target_method"}, set())
output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_ONLY, {"TestClass.target_method"}, set()).code
assert dedent(expected).strip() == output.strip()
@ -316,7 +316,7 @@ def test_class_annotations_else() -> None:
var2: str
"""
output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_ONLY, {"TestClass.target_method"}, set())
output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_ONLY, {"TestClass.target_method"}, set()).code
assert dedent(expected).strip() == output.strip()
@ -331,7 +331,7 @@ def test_top_level_functions() -> None:
expected = """"""
output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_ONLY, {"target_function"}, set())
output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_ONLY, {"target_function"}, set()).code
assert dedent(expected).strip() == output.strip()
@ -350,7 +350,7 @@ def test_module_var() -> None:
x = 5
"""
output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_ONLY, {"target_function"}, set())
output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_ONLY, {"target_function"}, set()).code
assert dedent(expected).strip() == output.strip()
@ -377,7 +377,7 @@ def test_module_var_if() -> None:
z = 10
"""
output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_ONLY, {"target_function"}, set())
output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_ONLY, {"target_function"}, set()).code
assert dedent(expected).strip() == output.strip()
@ -412,7 +412,7 @@ def test_conditional_class_definitions() -> None:
platform = "other"
"""
output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_ONLY, {"PlatformClass.target_method"}, set())
output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_ONLY, {"PlatformClass.target_method"}, set()).code
assert dedent(expected).strip() == output.strip()
@ -471,7 +471,7 @@ def test_multiple_except_clauses() -> None:
error_type = "cleanup"
"""
output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_ONLY, {"TestClass.target_method"}, set())
output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_ONLY, {"TestClass.target_method"}, set()).code
assert dedent(expected).strip() == output.strip()
@ -524,7 +524,7 @@ def test_with_statement_and_loops() -> None:
context = "cleanup"
"""
output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_ONLY, {"TestClass.target_method"}, set())
output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_ONLY, {"TestClass.target_method"}, set()).code
assert dedent(expected).strip() == output.strip()
@ -573,7 +573,7 @@ def test_async_with_try_except() -> None:
status = "cancelled"
"""
output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_ONLY, {"TestClass.target_method"}, set())
output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_ONLY, {"TestClass.target_method"}, set()).code
assert dedent(expected).strip() == output.strip()
@ -675,7 +675,7 @@ def test_simplified_complete_implementation() -> None:
output = parse_code_and_prune_cst(
dedent(code), CodeContextType.READ_ONLY, {"DataProcessor.target_method", "ResultHandler.target_method"}, set()
)
).code
assert dedent(expected).strip() == output.strip()
@ -768,5 +768,5 @@ def test_simplified_complete_implementation_no_docstring() -> None:
{"DataProcessor.target_method", "ResultHandler.target_method"},
set(),
remove_docstrings=True,
)
).code
assert dedent(expected).strip() == output.strip()

View file

@ -13,7 +13,7 @@ def test_simple_function() -> None:
y = 2
return x + y
"""
result = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_WRITABLE, {"target_function"})
result = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_WRITABLE, {"target_function"}).code
expected = dedent("""
def target_function():
@ -32,7 +32,7 @@ def test_class_method() -> None:
y = 2
return x + y
"""
result = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_WRITABLE, {"MyClass.target_function"})
result = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_WRITABLE, {"MyClass.target_function"}).code
expected = dedent("""
class MyClass:
@ -56,7 +56,7 @@ def test_class_with_attributes() -> None:
def other_method(self):
print("this should be excluded")
"""
result = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_WRITABLE, {"MyClass.target_method"})
result = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_WRITABLE, {"MyClass.target_method"}).code
expected = dedent("""
class MyClass:
@ -80,7 +80,7 @@ def test_basic_class_structure() -> None:
def not_findable(self):
return 42
"""
result = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_WRITABLE, {"Outer.target_method"})
result = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_WRITABLE, {"Outer.target_method"}).code
expected = dedent("""
class Outer:
@ -100,7 +100,7 @@ def test_top_level_targets() -> None:
def target_function():
return 42
"""
result = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_WRITABLE, {"target_function"})
result = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_WRITABLE, {"target_function"}).code
expected = dedent("""
def target_function():
@ -123,7 +123,7 @@ def test_multiple_top_level_classes() -> None:
def process(self):
return "C"
"""
result = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_WRITABLE, {"ClassA.process", "ClassC.process"})
result = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_WRITABLE, {"ClassA.process", "ClassC.process"}).code
expected = dedent("""
class ClassA:
@ -148,7 +148,7 @@ def test_try_except_structure() -> None:
def handle_error(self):
print("error")
"""
result = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_WRITABLE, {"TargetClass.target_method"})
result = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_WRITABLE, {"TargetClass.target_method"}).code
expected = dedent("""
try:
@ -175,7 +175,7 @@ def test_init_method() -> None:
def target_method(self):
return f"Value: {self.x}"
"""
result = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_WRITABLE, {"MyClass.target_method"})
result = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_WRITABLE, {"MyClass.target_method"}).code
expected = dedent("""
class MyClass:
@ -200,7 +200,7 @@ def test_dunder_method() -> None:
def target_method(self):
return f"Value: {self.x}"
"""
result = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_WRITABLE, {"MyClass.target_method"})
result = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_WRITABLE, {"MyClass.target_method"}).code
expected = dedent("""
class MyClass:
@ -221,7 +221,7 @@ def test_no_targets_found() -> None:
def target(self):
pass
"""
result = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_WRITABLE, {"MyClass.Inner.target"})
result = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_WRITABLE, {"MyClass.Inner.target"}).code
expected = dedent("""
class MyClass:
def method(self):
@ -266,5 +266,5 @@ def test_module_var() -> None:
var2 = "test"
"""
output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_WRITABLE, {"target_function"})
output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_WRITABLE, {"target_function"}).code
assert dedent(expected).strip() == output.strip()

View file

@ -13,7 +13,7 @@ def test_simple_function() -> None:
y = 2
return x + y
"""
result = parse_code_and_prune_cst(dedent(code), CodeContextType.TESTGEN, {"target_function"}, set())
result = parse_code_and_prune_cst(dedent(code), CodeContextType.TESTGEN, {"target_function"}, set()).code
expected = """
def target_function():
@ -44,7 +44,7 @@ def test_basic_class() -> None:
print("This should be included")
"""
output = parse_code_and_prune_cst(dedent(code), CodeContextType.TESTGEN, {"TestClass.target_method"}, set())
output = parse_code_and_prune_cst(dedent(code), CodeContextType.TESTGEN, {"TestClass.target_method"}, set()).code
assert dedent(expected).strip() == output.strip()
@ -73,7 +73,7 @@ def test_dunder_methods() -> None:
print("include me")
"""
output = parse_code_and_prune_cst(dedent(code), CodeContextType.TESTGEN, {"TestClass.target_method"}, set())
output = parse_code_and_prune_cst(dedent(code), CodeContextType.TESTGEN, {"TestClass.target_method"}, set()).code
assert dedent(expected).strip() == output.strip()
@ -107,7 +107,7 @@ def test_dunder_methods_remove_docstring() -> None:
output = parse_code_and_prune_cst(
dedent(code), CodeContextType.TESTGEN, {"TestClass.target_method"}, set(), remove_docstrings=True
)
).code
assert dedent(expected).strip() == output.strip()
@ -139,7 +139,7 @@ def test_class_remove_docstring() -> None:
output = parse_code_and_prune_cst(
dedent(code), CodeContextType.TESTGEN, {"TestClass.target_method"}, set(), remove_docstrings=True
)
).code
assert dedent(expected).strip() == output.strip()
@ -181,7 +181,7 @@ def test_method_signatures() -> None:
return "value"
"""
output = parse_code_and_prune_cst(dedent(code), CodeContextType.TESTGEN, {"TestClass.target_method"}, set())
output = parse_code_and_prune_cst(dedent(code), CodeContextType.TESTGEN, {"TestClass.target_method"}, set()).code
assert dedent(expected).strip() == output.strip()
@ -215,7 +215,7 @@ def test_multiple_top_level_targets() -> None:
output = parse_code_and_prune_cst(
dedent(code), CodeContextType.TESTGEN, {"TestClass.target1", "TestClass.target2"}, set()
)
).code
assert dedent(expected).strip() == output.strip()
@ -238,7 +238,7 @@ def test_class_annotations() -> None:
self.var2 = "test"
"""
output = parse_code_and_prune_cst(dedent(code), CodeContextType.TESTGEN, {"TestClass.target_method"}, set())
output = parse_code_and_prune_cst(dedent(code), CodeContextType.TESTGEN, {"TestClass.target_method"}, set()).code
assert dedent(expected).strip() == output.strip()
@ -263,7 +263,7 @@ def test_class_annotations_if() -> None:
self.var2 = "test"
"""
output = parse_code_and_prune_cst(dedent(code), CodeContextType.TESTGEN, {"TestClass.target_method"}, set())
output = parse_code_and_prune_cst(dedent(code), CodeContextType.TESTGEN, {"TestClass.target_method"}, set()).code
assert dedent(expected).strip() == output.strip()
@ -304,7 +304,7 @@ def test_conditional_class_definitions() -> None:
print("other")
"""
output = parse_code_and_prune_cst(dedent(code), CodeContextType.TESTGEN, {"PlatformClass.target_method"}, set())
output = parse_code_and_prune_cst(dedent(code), CodeContextType.TESTGEN, {"PlatformClass.target_method"}, set()).code
assert dedent(expected).strip() == output.strip()
@ -333,7 +333,7 @@ def test_try_except_structure() -> None:
print("error")
"""
output = parse_code_and_prune_cst(dedent(code), CodeContextType.TESTGEN, {"TargetClass.target_method"}, set())
output = parse_code_and_prune_cst(dedent(code), CodeContextType.TESTGEN, {"TargetClass.target_method"}, set()).code
assert dedent(expected).strip() == output.strip()
@ -355,7 +355,7 @@ def test_module_var() -> None:
x = 5
"""
output = parse_code_and_prune_cst(dedent(code), CodeContextType.TESTGEN, {"target_function"}, set())
output = parse_code_and_prune_cst(dedent(code), CodeContextType.TESTGEN, {"target_function"}, set()).code
assert dedent(expected).strip() == output.strip()
@ -385,7 +385,7 @@ def test_module_var_if() -> None:
z = 10
"""
output = parse_code_and_prune_cst(dedent(code), CodeContextType.TESTGEN, {"target_function"}, set())
output = parse_code_and_prune_cst(dedent(code), CodeContextType.TESTGEN, {"target_function"}, set()).code
assert dedent(expected).strip() == output.strip()
@ -416,7 +416,7 @@ def test_multiple_classes() -> None:
output = parse_code_and_prune_cst(
dedent(code), CodeContextType.TESTGEN, {"ClassA.process", "ClassC.process"}, set()
)
).code
assert dedent(expected).strip() == output.strip()
@ -477,7 +477,7 @@ def test_with_statement_and_loops() -> None:
print("cleanup")
"""
output = parse_code_and_prune_cst(dedent(code), CodeContextType.TESTGEN, {"TestClass.target_method"}, set())
output = parse_code_and_prune_cst(dedent(code), CodeContextType.TESTGEN, {"TestClass.target_method"}, set()).code
assert dedent(expected).strip() == output.strip()
@ -532,7 +532,7 @@ def test_async_with_try_except() -> None:
await self.cleanup()
"""
output = parse_code_and_prune_cst(dedent(code), CodeContextType.TESTGEN, {"TestClass.target_method"}, set())
output = parse_code_and_prune_cst(dedent(code), CodeContextType.TESTGEN, {"TestClass.target_method"}, set()).code
assert dedent(expected).strip() == output.strip()
@ -659,7 +659,7 @@ def test_simplified_complete_implementation() -> None:
output = parse_code_and_prune_cst(
dedent(code), CodeContextType.TESTGEN, {"DataProcessor.target_method", "ResultHandler.target_method"}, set()
)
).code
assert dedent(expected).strip() == output.strip()
@ -765,5 +765,5 @@ def test_simplified_complete_implementation_no_docstring() -> None:
{"DataProcessor.target_method", "ResultHandler.target_method"},
set(),
remove_docstrings=True,
)
).code
assert dedent(expected).strip() == output.strip()