mirror of
https://github.com/codeflash-ai/codeflash.git
synced 2026-05-04 18:25:17 +00:00
fix: preserve comment position by passing CST module directly to import adder
parse_code_and_prune_cst now returns cst.Module instead of str. add_needed_imports_from_module accepts cst.Module | str, skipping re-parse when a Module is passed. This eliminates the string round-trip that caused comments to migrate from statement leading_lines to Module.header, resulting in comments appearing above imports instead of at their original position.
This commit is contained in:
parent
1689a7bbb5
commit
6c4378db51
6 changed files with 90 additions and 70 deletions
|
|
@ -368,7 +368,7 @@ def process_file_context(
|
|||
try:
|
||||
all_names = primary_qualified_names | secondary_qualified_names
|
||||
code_without_unused_defs = remove_unused_definitions_by_function_names(original_code, all_names)
|
||||
code_context = parse_code_and_prune_cst(
|
||||
pruned_module = parse_code_and_prune_cst(
|
||||
code_without_unused_defs,
|
||||
code_context_type,
|
||||
primary_qualified_names,
|
||||
|
|
@ -379,11 +379,13 @@ def process_file_context(
|
|||
logger.debug(f"Error while getting read-only code: {e}")
|
||||
return None
|
||||
|
||||
if code_context.strip():
|
||||
if code_context_type != CodeContextType.HASHING:
|
||||
if pruned_module.code.strip():
|
||||
if code_context_type == CodeContextType.HASHING:
|
||||
code_context = ast.unparse(ast.parse(pruned_module.code))
|
||||
else:
|
||||
code_context = add_needed_imports_from_module(
|
||||
src_module_code=original_code,
|
||||
dst_module_code=code_context,
|
||||
dst_module_code=pruned_module,
|
||||
src_path=file_path,
|
||||
dst_path=file_path,
|
||||
project_root=project_root_path,
|
||||
|
|
@ -1280,8 +1282,8 @@ def parse_code_and_prune_cst(
|
|||
target_functions: set[str],
|
||||
helpers_of_helper_functions: set[str] = set(), # noqa: B006
|
||||
remove_docstrings: bool = False,
|
||||
) -> str:
|
||||
"""Create a read-only version of the code by parsing and filtering the code to keep only class contextual information, and other module scoped variables."""
|
||||
) -> cst.Module:
|
||||
"""Parse and filter the code CST, returning the pruned Module."""
|
||||
module = cst.parse_module(code)
|
||||
defs_with_usages = collect_top_level_defs_with_usages(module, target_functions | helpers_of_helper_functions)
|
||||
|
||||
|
|
@ -1317,11 +1319,8 @@ def parse_code_and_prune_cst(
|
|||
if not found_target:
|
||||
raise ValueError("No target functions found in the provided code")
|
||||
if filtered_node and isinstance(filtered_node, cst.Module):
|
||||
code = str(filtered_node.code)
|
||||
if code_context_type == CodeContextType.HASHING:
|
||||
code = ast.unparse(ast.parse(code)) # Makes it standard
|
||||
return code
|
||||
return ""
|
||||
return filtered_node
|
||||
raise ValueError("Pruning produced no module")
|
||||
|
||||
|
||||
def prune_cst(
|
||||
|
|
|
|||
|
|
@ -684,7 +684,7 @@ def resolve_star_import(module_name: str, project_root: Path) -> set[str]:
|
|||
|
||||
def add_needed_imports_from_module(
|
||||
src_module_code: str,
|
||||
dst_module_code: str,
|
||||
dst_module_code: str | cst.Module,
|
||||
src_path: Path,
|
||||
dst_path: Path,
|
||||
project_root: Path,
|
||||
|
|
@ -696,6 +696,8 @@ def add_needed_imports_from_module(
|
|||
if not helper_functions_fqn:
|
||||
helper_functions_fqn = {f.fully_qualified_name for f in (helper_functions or [])}
|
||||
|
||||
dst_code_fallback = dst_module_code if isinstance(dst_module_code, str) else dst_module_code.code
|
||||
|
||||
src_module_and_package: ModuleNameAndPackage = calculate_module_and_package(project_root, src_path)
|
||||
dst_module_and_package: ModuleNameAndPackage = calculate_module_and_package(project_root, dst_path)
|
||||
|
||||
|
|
@ -715,15 +717,19 @@ def add_needed_imports_from_module(
|
|||
cst.parse_module(src_module_code).visit(gatherer)
|
||||
except Exception as e:
|
||||
logger.error(f"Error parsing source module code: {e}")
|
||||
return dst_module_code
|
||||
return dst_code_fallback
|
||||
|
||||
dotted_import_collector = DottedImportCollector()
|
||||
try:
|
||||
parsed_dst_module = cst.parse_module(dst_module_code)
|
||||
if isinstance(dst_module_code, cst.Module):
|
||||
parsed_dst_module = dst_module_code
|
||||
parsed_dst_module.visit(dotted_import_collector)
|
||||
except cst.ParserSyntaxError as e:
|
||||
logger.exception(f"Syntax error in destination module code: {e}")
|
||||
return dst_module_code # Return the original code if there's a syntax error
|
||||
else:
|
||||
try:
|
||||
parsed_dst_module = cst.parse_module(dst_module_code)
|
||||
parsed_dst_module.visit(dotted_import_collector)
|
||||
except cst.ParserSyntaxError as e:
|
||||
logger.exception(f"Syntax error in destination module code: {e}")
|
||||
return dst_code_fallback
|
||||
|
||||
try:
|
||||
for mod in gatherer.module_imports:
|
||||
|
|
@ -768,7 +774,7 @@ def add_needed_imports_from_module(
|
|||
RemoveImportsVisitor.remove_unused_import(dst_context, mod, obj)
|
||||
except Exception as e:
|
||||
logger.exception(f"Error adding imports to destination module code: {e}")
|
||||
return dst_module_code
|
||||
return dst_code_fallback
|
||||
|
||||
for mod, asname in gatherer.module_aliases.items():
|
||||
if not asname:
|
||||
|
|
@ -796,7 +802,7 @@ def add_needed_imports_from_module(
|
|||
return transformed_module.code.lstrip("\n")
|
||||
except Exception as e:
|
||||
logger.exception(f"Error adding imports to destination module code: {e}")
|
||||
return dst_module_code
|
||||
return dst_code_fallback
|
||||
|
||||
|
||||
def get_code(functions_to_optimize: list[FunctionToOptimize]) -> tuple[str | None, set[tuple[str, str]]]:
|
||||
|
|
|
|||
|
|
@ -105,6 +105,7 @@ def test_code_replacement10() -> None:
|
|||
```python:{file_path.relative_to(file_path.parent)}
|
||||
from __future__ import annotations
|
||||
|
||||
|
||||
class HelperClass:
|
||||
def __init__(self, name):
|
||||
self.name = name
|
||||
|
|
@ -165,6 +166,7 @@ def test_class_method_dependencies() -> None:
|
|||
from __future__ import annotations
|
||||
from collections import defaultdict
|
||||
|
||||
|
||||
class Graph:
|
||||
def __init__(self, vertices):
|
||||
self.graph = defaultdict(list)
|
||||
|
|
@ -244,6 +246,7 @@ def test_bubble_sort_helper() -> None:
|
|||
```python:code_to_optimize/code_directories/retriever/bubble_sort_with_math.py
|
||||
import math
|
||||
|
||||
|
||||
def sorter(arr):
|
||||
arr.sort()
|
||||
x = math.sqrt(2)
|
||||
|
|
@ -253,6 +256,7 @@ def sorter(arr):
|
|||
```python:code_to_optimize/code_directories/retriever/bubble_sort_imported.py
|
||||
from bubble_sort_with_math import sorter
|
||||
|
||||
|
||||
def sort_from_another_file(arr):
|
||||
sorted_arr = sorter(arr)
|
||||
return sorted_arr
|
||||
|
|
@ -1181,6 +1185,7 @@ API_URL = "https://api.example.com/data"
|
|||
```python:{path_to_utils.relative_to(project_root)}
|
||||
import math
|
||||
|
||||
|
||||
class DataProcessor:
|
||||
|
||||
def __init__(self, default_prefix: str = "PREFIX_"):
|
||||
|
|
@ -1201,6 +1206,7 @@ import requests
|
|||
from globals import API_URL
|
||||
from utils import DataProcessor
|
||||
|
||||
|
||||
def fetch_and_process_data():
|
||||
# Use the global variable for the request
|
||||
response = requests.get(API_URL)
|
||||
|
|
@ -1280,6 +1286,7 @@ API_URL = "https://api.example.com/data"
|
|||
import math
|
||||
from transform_utils import DataTransformer
|
||||
|
||||
|
||||
class DataProcessor:
|
||||
|
||||
def __init__(self, default_prefix: str = "PREFIX_"):
|
||||
|
|
@ -1300,6 +1307,7 @@ import requests
|
|||
from globals import API_URL
|
||||
from utils import DataProcessor
|
||||
|
||||
|
||||
def fetch_and_transform_data():
|
||||
# Use the global variable for the request
|
||||
response = requests.get(API_URL)
|
||||
|
|
@ -1388,6 +1396,7 @@ class DataTransformer:
|
|||
import math
|
||||
from transform_utils import DataTransformer
|
||||
|
||||
|
||||
class DataProcessor:
|
||||
|
||||
def __init__(self, default_prefix: str = "PREFIX_"):
|
||||
|
|
@ -1468,6 +1477,7 @@ class DataTransformer:
|
|||
import math
|
||||
from transform_utils import DataTransformer
|
||||
|
||||
|
||||
class DataProcessor:
|
||||
|
||||
def __init__(self, default_prefix: str = "PREFIX_"):
|
||||
|
|
@ -1599,6 +1609,7 @@ def test_repo_helper_circular_dependency() -> None:
|
|||
import math
|
||||
from transform_utils import DataTransformer
|
||||
|
||||
|
||||
class DataProcessor:
|
||||
|
||||
def __init__(self, default_prefix: str = "PREFIX_"):
|
||||
|
|
@ -1613,6 +1624,7 @@ class DataProcessor:
|
|||
```python:{path_to_transform_utils.relative_to(project_root)}
|
||||
from code_to_optimize.code_directories.retriever.utils import DataProcessor
|
||||
|
||||
|
||||
class DataTransformer:
|
||||
def __init__(self):
|
||||
self.data = None
|
||||
|
|
@ -1745,6 +1757,7 @@ def test_direct_module_import() -> None:
|
|||
import math
|
||||
from transform_utils import DataTransformer
|
||||
|
||||
|
||||
class DataProcessor:
|
||||
\"\"\"A class for processing data.\"\"\"
|
||||
|
||||
|
|
@ -1788,6 +1801,7 @@ import requests
|
|||
from globals import API_URL
|
||||
from utils import DataProcessor
|
||||
|
||||
|
||||
def fetch_and_transform_data():
|
||||
# Use the global variable for the request
|
||||
response = requests.get(API_URL)
|
||||
|
|
@ -3890,6 +3904,7 @@ import dataclasses
|
|||
import enum
|
||||
import typing as t
|
||||
|
||||
|
||||
class MessageKind(enum.StrEnum):
|
||||
ASK_FOR_CLIPBOARD_RESPONSE = "ask-for-clipboard-response"
|
||||
BEGIN_EXFILTRATION = "begin-exfiltration"
|
||||
|
|
|
|||
|
|
@ -23,7 +23,7 @@ def test_basic_class() -> None:
|
|||
class_var = "value"
|
||||
"""
|
||||
|
||||
output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_ONLY, {"TestClass.target_method"}, set())
|
||||
output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_ONLY, {"TestClass.target_method"}, set()).code
|
||||
assert dedent(expected).strip() == output.strip()
|
||||
|
||||
|
||||
|
|
@ -47,7 +47,7 @@ def test_dunder_methods() -> None:
|
|||
return f"Value: {self.x}"
|
||||
"""
|
||||
|
||||
output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_ONLY, {"TestClass.target_method"}, set())
|
||||
output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_ONLY, {"TestClass.target_method"}, set()).code
|
||||
assert dedent(expected).strip() == output.strip()
|
||||
|
||||
|
||||
|
|
@ -75,7 +75,7 @@ def test_dunder_methods_remove_docstring() -> None:
|
|||
|
||||
output = parse_code_and_prune_cst(
|
||||
dedent(code), CodeContextType.READ_ONLY, {"TestClass.target_method"}, set(), remove_docstrings=True
|
||||
)
|
||||
).code
|
||||
assert dedent(expected).strip() == output.strip()
|
||||
|
||||
|
||||
|
|
@ -102,7 +102,7 @@ def test_class_remove_docstring() -> None:
|
|||
|
||||
output = parse_code_and_prune_cst(
|
||||
dedent(code), CodeContextType.READ_ONLY, {"TestClass.target_method"}, set(), remove_docstrings=True
|
||||
)
|
||||
).code
|
||||
assert dedent(expected).strip() == output.strip()
|
||||
|
||||
|
||||
|
|
@ -131,7 +131,7 @@ def test_mixed_remove_docstring() -> None:
|
|||
|
||||
output = parse_code_and_prune_cst(
|
||||
dedent(code), CodeContextType.READ_ONLY, {"TestClass.target_method"}, set(), remove_docstrings=True
|
||||
)
|
||||
).code
|
||||
assert dedent(expected).strip() == output.strip()
|
||||
|
||||
|
||||
|
|
@ -171,7 +171,7 @@ def test_docstrings() -> None:
|
|||
\"\"\"Class docstring.\"\"\"
|
||||
"""
|
||||
|
||||
output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_ONLY, {"TestClass.target_method"}, set())
|
||||
output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_ONLY, {"TestClass.target_method"}, set()).code
|
||||
assert dedent(expected).strip() == output.strip()
|
||||
|
||||
|
||||
|
|
@ -190,7 +190,7 @@ def test_method_signatures() -> None:
|
|||
|
||||
expected = """"""
|
||||
|
||||
output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_ONLY, {"TestClass.target_method"}, set())
|
||||
output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_ONLY, {"TestClass.target_method"}, set()).code
|
||||
assert dedent(expected).strip() == output.strip()
|
||||
|
||||
|
||||
|
|
@ -212,7 +212,7 @@ def test_multiple_top_level_targets() -> None:
|
|||
|
||||
output = parse_code_and_prune_cst(
|
||||
dedent(code), CodeContextType.READ_ONLY, {"TestClass.target1", "TestClass.target2"}, set()
|
||||
)
|
||||
).code
|
||||
assert dedent(expected).strip() == output.strip()
|
||||
|
||||
|
||||
|
|
@ -232,7 +232,7 @@ def test_class_annotations() -> None:
|
|||
var2: str
|
||||
"""
|
||||
|
||||
output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_ONLY, {"TestClass.target_method"}, set())
|
||||
output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_ONLY, {"TestClass.target_method"}, set()).code
|
||||
assert dedent(expected).strip() == output.strip()
|
||||
|
||||
|
||||
|
|
@ -254,7 +254,7 @@ def test_class_annotations_if() -> None:
|
|||
var2: str
|
||||
"""
|
||||
|
||||
output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_ONLY, {"TestClass.target_method"}, set())
|
||||
output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_ONLY, {"TestClass.target_method"}, set()).code
|
||||
assert dedent(expected).strip() == output.strip()
|
||||
|
||||
|
||||
|
|
@ -280,7 +280,7 @@ def test_class_annotations_try() -> None:
|
|||
continue
|
||||
"""
|
||||
|
||||
output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_ONLY, {"TestClass.target_method"}, set())
|
||||
output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_ONLY, {"TestClass.target_method"}, set()).code
|
||||
assert dedent(expected).strip() == output.strip()
|
||||
|
||||
|
||||
|
|
@ -316,7 +316,7 @@ def test_class_annotations_else() -> None:
|
|||
var2: str
|
||||
"""
|
||||
|
||||
output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_ONLY, {"TestClass.target_method"}, set())
|
||||
output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_ONLY, {"TestClass.target_method"}, set()).code
|
||||
assert dedent(expected).strip() == output.strip()
|
||||
|
||||
|
||||
|
|
@ -331,7 +331,7 @@ def test_top_level_functions() -> None:
|
|||
|
||||
expected = """"""
|
||||
|
||||
output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_ONLY, {"target_function"}, set())
|
||||
output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_ONLY, {"target_function"}, set()).code
|
||||
assert dedent(expected).strip() == output.strip()
|
||||
|
||||
|
||||
|
|
@ -350,7 +350,7 @@ def test_module_var() -> None:
|
|||
x = 5
|
||||
"""
|
||||
|
||||
output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_ONLY, {"target_function"}, set())
|
||||
output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_ONLY, {"target_function"}, set()).code
|
||||
assert dedent(expected).strip() == output.strip()
|
||||
|
||||
|
||||
|
|
@ -377,7 +377,7 @@ def test_module_var_if() -> None:
|
|||
z = 10
|
||||
"""
|
||||
|
||||
output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_ONLY, {"target_function"}, set())
|
||||
output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_ONLY, {"target_function"}, set()).code
|
||||
assert dedent(expected).strip() == output.strip()
|
||||
|
||||
|
||||
|
|
@ -412,7 +412,7 @@ def test_conditional_class_definitions() -> None:
|
|||
platform = "other"
|
||||
"""
|
||||
|
||||
output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_ONLY, {"PlatformClass.target_method"}, set())
|
||||
output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_ONLY, {"PlatformClass.target_method"}, set()).code
|
||||
assert dedent(expected).strip() == output.strip()
|
||||
|
||||
|
||||
|
|
@ -471,7 +471,7 @@ def test_multiple_except_clauses() -> None:
|
|||
error_type = "cleanup"
|
||||
"""
|
||||
|
||||
output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_ONLY, {"TestClass.target_method"}, set())
|
||||
output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_ONLY, {"TestClass.target_method"}, set()).code
|
||||
assert dedent(expected).strip() == output.strip()
|
||||
|
||||
|
||||
|
|
@ -524,7 +524,7 @@ def test_with_statement_and_loops() -> None:
|
|||
context = "cleanup"
|
||||
"""
|
||||
|
||||
output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_ONLY, {"TestClass.target_method"}, set())
|
||||
output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_ONLY, {"TestClass.target_method"}, set()).code
|
||||
assert dedent(expected).strip() == output.strip()
|
||||
|
||||
|
||||
|
|
@ -573,7 +573,7 @@ def test_async_with_try_except() -> None:
|
|||
status = "cancelled"
|
||||
"""
|
||||
|
||||
output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_ONLY, {"TestClass.target_method"}, set())
|
||||
output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_ONLY, {"TestClass.target_method"}, set()).code
|
||||
assert dedent(expected).strip() == output.strip()
|
||||
|
||||
|
||||
|
|
@ -675,7 +675,7 @@ def test_simplified_complete_implementation() -> None:
|
|||
|
||||
output = parse_code_and_prune_cst(
|
||||
dedent(code), CodeContextType.READ_ONLY, {"DataProcessor.target_method", "ResultHandler.target_method"}, set()
|
||||
)
|
||||
).code
|
||||
assert dedent(expected).strip() == output.strip()
|
||||
|
||||
|
||||
|
|
@ -768,5 +768,5 @@ def test_simplified_complete_implementation_no_docstring() -> None:
|
|||
{"DataProcessor.target_method", "ResultHandler.target_method"},
|
||||
set(),
|
||||
remove_docstrings=True,
|
||||
)
|
||||
).code
|
||||
assert dedent(expected).strip() == output.strip()
|
||||
|
|
|
|||
|
|
@ -13,7 +13,7 @@ def test_simple_function() -> None:
|
|||
y = 2
|
||||
return x + y
|
||||
"""
|
||||
result = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_WRITABLE, {"target_function"})
|
||||
result = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_WRITABLE, {"target_function"}).code
|
||||
|
||||
expected = dedent("""
|
||||
def target_function():
|
||||
|
|
@ -32,7 +32,7 @@ def test_class_method() -> None:
|
|||
y = 2
|
||||
return x + y
|
||||
"""
|
||||
result = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_WRITABLE, {"MyClass.target_function"})
|
||||
result = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_WRITABLE, {"MyClass.target_function"}).code
|
||||
|
||||
expected = dedent("""
|
||||
class MyClass:
|
||||
|
|
@ -56,7 +56,7 @@ def test_class_with_attributes() -> None:
|
|||
def other_method(self):
|
||||
print("this should be excluded")
|
||||
"""
|
||||
result = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_WRITABLE, {"MyClass.target_method"})
|
||||
result = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_WRITABLE, {"MyClass.target_method"}).code
|
||||
|
||||
expected = dedent("""
|
||||
class MyClass:
|
||||
|
|
@ -80,7 +80,7 @@ def test_basic_class_structure() -> None:
|
|||
def not_findable(self):
|
||||
return 42
|
||||
"""
|
||||
result = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_WRITABLE, {"Outer.target_method"})
|
||||
result = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_WRITABLE, {"Outer.target_method"}).code
|
||||
|
||||
expected = dedent("""
|
||||
class Outer:
|
||||
|
|
@ -100,7 +100,7 @@ def test_top_level_targets() -> None:
|
|||
def target_function():
|
||||
return 42
|
||||
"""
|
||||
result = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_WRITABLE, {"target_function"})
|
||||
result = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_WRITABLE, {"target_function"}).code
|
||||
|
||||
expected = dedent("""
|
||||
def target_function():
|
||||
|
|
@ -123,7 +123,7 @@ def test_multiple_top_level_classes() -> None:
|
|||
def process(self):
|
||||
return "C"
|
||||
"""
|
||||
result = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_WRITABLE, {"ClassA.process", "ClassC.process"})
|
||||
result = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_WRITABLE, {"ClassA.process", "ClassC.process"}).code
|
||||
|
||||
expected = dedent("""
|
||||
class ClassA:
|
||||
|
|
@ -148,7 +148,7 @@ def test_try_except_structure() -> None:
|
|||
def handle_error(self):
|
||||
print("error")
|
||||
"""
|
||||
result = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_WRITABLE, {"TargetClass.target_method"})
|
||||
result = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_WRITABLE, {"TargetClass.target_method"}).code
|
||||
|
||||
expected = dedent("""
|
||||
try:
|
||||
|
|
@ -175,7 +175,7 @@ def test_init_method() -> None:
|
|||
def target_method(self):
|
||||
return f"Value: {self.x}"
|
||||
"""
|
||||
result = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_WRITABLE, {"MyClass.target_method"})
|
||||
result = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_WRITABLE, {"MyClass.target_method"}).code
|
||||
|
||||
expected = dedent("""
|
||||
class MyClass:
|
||||
|
|
@ -200,7 +200,7 @@ def test_dunder_method() -> None:
|
|||
def target_method(self):
|
||||
return f"Value: {self.x}"
|
||||
"""
|
||||
result = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_WRITABLE, {"MyClass.target_method"})
|
||||
result = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_WRITABLE, {"MyClass.target_method"}).code
|
||||
|
||||
expected = dedent("""
|
||||
class MyClass:
|
||||
|
|
@ -221,7 +221,7 @@ def test_no_targets_found() -> None:
|
|||
def target(self):
|
||||
pass
|
||||
"""
|
||||
result = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_WRITABLE, {"MyClass.Inner.target"})
|
||||
result = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_WRITABLE, {"MyClass.Inner.target"}).code
|
||||
expected = dedent("""
|
||||
class MyClass:
|
||||
def method(self):
|
||||
|
|
@ -266,5 +266,5 @@ def test_module_var() -> None:
|
|||
var2 = "test"
|
||||
"""
|
||||
|
||||
output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_WRITABLE, {"target_function"})
|
||||
output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_WRITABLE, {"target_function"}).code
|
||||
assert dedent(expected).strip() == output.strip()
|
||||
|
|
|
|||
|
|
@ -13,7 +13,7 @@ def test_simple_function() -> None:
|
|||
y = 2
|
||||
return x + y
|
||||
"""
|
||||
result = parse_code_and_prune_cst(dedent(code), CodeContextType.TESTGEN, {"target_function"}, set())
|
||||
result = parse_code_and_prune_cst(dedent(code), CodeContextType.TESTGEN, {"target_function"}, set()).code
|
||||
|
||||
expected = """
|
||||
def target_function():
|
||||
|
|
@ -44,7 +44,7 @@ def test_basic_class() -> None:
|
|||
print("This should be included")
|
||||
"""
|
||||
|
||||
output = parse_code_and_prune_cst(dedent(code), CodeContextType.TESTGEN, {"TestClass.target_method"}, set())
|
||||
output = parse_code_and_prune_cst(dedent(code), CodeContextType.TESTGEN, {"TestClass.target_method"}, set()).code
|
||||
assert dedent(expected).strip() == output.strip()
|
||||
|
||||
|
||||
|
|
@ -73,7 +73,7 @@ def test_dunder_methods() -> None:
|
|||
print("include me")
|
||||
"""
|
||||
|
||||
output = parse_code_and_prune_cst(dedent(code), CodeContextType.TESTGEN, {"TestClass.target_method"}, set())
|
||||
output = parse_code_and_prune_cst(dedent(code), CodeContextType.TESTGEN, {"TestClass.target_method"}, set()).code
|
||||
assert dedent(expected).strip() == output.strip()
|
||||
|
||||
|
||||
|
|
@ -107,7 +107,7 @@ def test_dunder_methods_remove_docstring() -> None:
|
|||
|
||||
output = parse_code_and_prune_cst(
|
||||
dedent(code), CodeContextType.TESTGEN, {"TestClass.target_method"}, set(), remove_docstrings=True
|
||||
)
|
||||
).code
|
||||
assert dedent(expected).strip() == output.strip()
|
||||
|
||||
|
||||
|
|
@ -139,7 +139,7 @@ def test_class_remove_docstring() -> None:
|
|||
|
||||
output = parse_code_and_prune_cst(
|
||||
dedent(code), CodeContextType.TESTGEN, {"TestClass.target_method"}, set(), remove_docstrings=True
|
||||
)
|
||||
).code
|
||||
assert dedent(expected).strip() == output.strip()
|
||||
|
||||
|
||||
|
|
@ -181,7 +181,7 @@ def test_method_signatures() -> None:
|
|||
return "value"
|
||||
"""
|
||||
|
||||
output = parse_code_and_prune_cst(dedent(code), CodeContextType.TESTGEN, {"TestClass.target_method"}, set())
|
||||
output = parse_code_and_prune_cst(dedent(code), CodeContextType.TESTGEN, {"TestClass.target_method"}, set()).code
|
||||
assert dedent(expected).strip() == output.strip()
|
||||
|
||||
|
||||
|
|
@ -215,7 +215,7 @@ def test_multiple_top_level_targets() -> None:
|
|||
|
||||
output = parse_code_and_prune_cst(
|
||||
dedent(code), CodeContextType.TESTGEN, {"TestClass.target1", "TestClass.target2"}, set()
|
||||
)
|
||||
).code
|
||||
assert dedent(expected).strip() == output.strip()
|
||||
|
||||
|
||||
|
|
@ -238,7 +238,7 @@ def test_class_annotations() -> None:
|
|||
self.var2 = "test"
|
||||
"""
|
||||
|
||||
output = parse_code_and_prune_cst(dedent(code), CodeContextType.TESTGEN, {"TestClass.target_method"}, set())
|
||||
output = parse_code_and_prune_cst(dedent(code), CodeContextType.TESTGEN, {"TestClass.target_method"}, set()).code
|
||||
assert dedent(expected).strip() == output.strip()
|
||||
|
||||
|
||||
|
|
@ -263,7 +263,7 @@ def test_class_annotations_if() -> None:
|
|||
self.var2 = "test"
|
||||
"""
|
||||
|
||||
output = parse_code_and_prune_cst(dedent(code), CodeContextType.TESTGEN, {"TestClass.target_method"}, set())
|
||||
output = parse_code_and_prune_cst(dedent(code), CodeContextType.TESTGEN, {"TestClass.target_method"}, set()).code
|
||||
assert dedent(expected).strip() == output.strip()
|
||||
|
||||
|
||||
|
|
@ -304,7 +304,7 @@ def test_conditional_class_definitions() -> None:
|
|||
print("other")
|
||||
"""
|
||||
|
||||
output = parse_code_and_prune_cst(dedent(code), CodeContextType.TESTGEN, {"PlatformClass.target_method"}, set())
|
||||
output = parse_code_and_prune_cst(dedent(code), CodeContextType.TESTGEN, {"PlatformClass.target_method"}, set()).code
|
||||
assert dedent(expected).strip() == output.strip()
|
||||
|
||||
|
||||
|
|
@ -333,7 +333,7 @@ def test_try_except_structure() -> None:
|
|||
print("error")
|
||||
"""
|
||||
|
||||
output = parse_code_and_prune_cst(dedent(code), CodeContextType.TESTGEN, {"TargetClass.target_method"}, set())
|
||||
output = parse_code_and_prune_cst(dedent(code), CodeContextType.TESTGEN, {"TargetClass.target_method"}, set()).code
|
||||
assert dedent(expected).strip() == output.strip()
|
||||
|
||||
|
||||
|
|
@ -355,7 +355,7 @@ def test_module_var() -> None:
|
|||
x = 5
|
||||
"""
|
||||
|
||||
output = parse_code_and_prune_cst(dedent(code), CodeContextType.TESTGEN, {"target_function"}, set())
|
||||
output = parse_code_and_prune_cst(dedent(code), CodeContextType.TESTGEN, {"target_function"}, set()).code
|
||||
assert dedent(expected).strip() == output.strip()
|
||||
|
||||
|
||||
|
|
@ -385,7 +385,7 @@ def test_module_var_if() -> None:
|
|||
z = 10
|
||||
"""
|
||||
|
||||
output = parse_code_and_prune_cst(dedent(code), CodeContextType.TESTGEN, {"target_function"}, set())
|
||||
output = parse_code_and_prune_cst(dedent(code), CodeContextType.TESTGEN, {"target_function"}, set()).code
|
||||
assert dedent(expected).strip() == output.strip()
|
||||
|
||||
|
||||
|
|
@ -416,7 +416,7 @@ def test_multiple_classes() -> None:
|
|||
|
||||
output = parse_code_and_prune_cst(
|
||||
dedent(code), CodeContextType.TESTGEN, {"ClassA.process", "ClassC.process"}, set()
|
||||
)
|
||||
).code
|
||||
assert dedent(expected).strip() == output.strip()
|
||||
|
||||
|
||||
|
|
@ -477,7 +477,7 @@ def test_with_statement_and_loops() -> None:
|
|||
print("cleanup")
|
||||
"""
|
||||
|
||||
output = parse_code_and_prune_cst(dedent(code), CodeContextType.TESTGEN, {"TestClass.target_method"}, set())
|
||||
output = parse_code_and_prune_cst(dedent(code), CodeContextType.TESTGEN, {"TestClass.target_method"}, set()).code
|
||||
assert dedent(expected).strip() == output.strip()
|
||||
|
||||
|
||||
|
|
@ -532,7 +532,7 @@ def test_async_with_try_except() -> None:
|
|||
await self.cleanup()
|
||||
"""
|
||||
|
||||
output = parse_code_and_prune_cst(dedent(code), CodeContextType.TESTGEN, {"TestClass.target_method"}, set())
|
||||
output = parse_code_and_prune_cst(dedent(code), CodeContextType.TESTGEN, {"TestClass.target_method"}, set()).code
|
||||
assert dedent(expected).strip() == output.strip()
|
||||
|
||||
|
||||
|
|
@ -659,7 +659,7 @@ def test_simplified_complete_implementation() -> None:
|
|||
|
||||
output = parse_code_and_prune_cst(
|
||||
dedent(code), CodeContextType.TESTGEN, {"DataProcessor.target_method", "ResultHandler.target_method"}, set()
|
||||
)
|
||||
).code
|
||||
assert dedent(expected).strip() == output.strip()
|
||||
|
||||
|
||||
|
|
@ -765,5 +765,5 @@ def test_simplified_complete_implementation_no_docstring() -> None:
|
|||
{"DataProcessor.target_method", "ResultHandler.target_method"},
|
||||
set(),
|
||||
remove_docstrings=True,
|
||||
)
|
||||
).code
|
||||
assert dedent(expected).strip() == output.strip()
|
||||
|
|
|
|||
Loading…
Reference in a new issue