takes care of docstring of new init function

This commit is contained in:
Alvin Ryanputra 2025-01-14 10:54:43 -08:00
parent 24ad3915bd
commit 17c354a466
3 changed files with 10 additions and 27 deletions

View file

@ -236,7 +236,16 @@ def merge_init_functions(original_init: cst.FunctionDef, new_init: cst.FunctionD
original_stmts = {get_only_code_content(cst_to_code(stmt)) for stmt in original_init.body.body}
# Filter new init body statements
filtered_body = []
for stmt in new_init.body.body:
# Filter out docstring of new init
if (
isinstance(stmt, cst.SimpleStatementLine)
and len(stmt.body) == 1
and isinstance(stmt.body[0], cst.Expr)
and isinstance(stmt.body[0].value, cst.SimpleString)
):
continue
# Filter out duplicate statements
if get_only_code_content(cst_to_code(stmt)) in original_stmts:
continue

View file

@ -73,11 +73,7 @@ def test_code_replacement10() -> None:
)
code_ctx = get_code_optimization_context(function_to_optimize=func_top_optimize, project_root_path=file_path.parent)
read_write_context, read_only_context, testgen_context = (
code_ctx.read_writable_code,
code_ctx.read_only_context_code,
code_ctx.testgen_context_code,
)
read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code
expected_read_write_context = """
from __future__ import annotations
@ -99,27 +95,9 @@ def test_code_replacement10() -> None:
"""
expected_read_only_context = """
"""
expected_testgen_context = """
from __future__ import annotations
class HelperClass:
def __init__(self, name):
self.name = name
def helper_method(self):
return self.name
class MainClass:
def __init__(self, name):
self.name = name
def main_method(self):
return HelperClass(self.name).helper_method()
"""
assert read_write_context.strip() == dedent(expected_read_write_context).strip()
assert read_only_context.strip() == dedent(expected_read_only_context).strip()
assert testgen_context.strip() == dedent(expected_testgen_context).strip()
def test_class_method_dependencies() -> None:

View file

@ -124,7 +124,6 @@ def test_docstrings_and_comments() -> None:
original = """
class MyClass:
def __init__(self):
\"\"\"Original docstring.\"\"\"
# Setup configuration
self.config = {} # Empty config
"""
@ -138,13 +137,10 @@ def test_docstrings_and_comments() -> None:
result = merge_init_functions(
cst.parse_module(dedent(original)).body[0].body.body[0], cst.parse_module(dedent(new)).body[0].body.body[0]
)
# TODO: handle docstrings differently
expected = """
def __init__(self):
\"\"\"Original docstring.\"\"\"
# Setup configuration
self.config = {} # Empty config
\"\"\"New docstring.\"\"\"
# Initialize database
self.db = None # Database connection
"""