diff --git a/codeflash/code_utils/code_replacer.py b/codeflash/code_utils/code_replacer.py index 6b6efb597..6bb9022b3 100644 --- a/codeflash/code_utils/code_replacer.py +++ b/codeflash/code_utils/code_replacer.py @@ -236,7 +236,16 @@ def merge_init_functions(original_init: cst.FunctionDef, new_init: cst.FunctionD original_stmts = {get_only_code_content(cst_to_code(stmt)) for stmt in original_init.body.body} # Filter new init body statements filtered_body = [] + for stmt in new_init.body.body: + # Filter out docstring of new init + if ( + isinstance(stmt, cst.SimpleStatementLine) + and len(stmt.body) == 1 + and isinstance(stmt.body[0], cst.Expr) + and isinstance(stmt.body[0].value, cst.SimpleString) + ): + continue # Filter out duplicate statements if get_only_code_content(cst_to_code(stmt)) in original_stmts: continue diff --git a/tests/test_code_context_extractor.py b/tests/test_code_context_extractor.py index 6eaab0157..7272ebbe5 100644 --- a/tests/test_code_context_extractor.py +++ b/tests/test_code_context_extractor.py @@ -73,11 +73,7 @@ def test_code_replacement10() -> None: ) code_ctx = get_code_optimization_context(function_to_optimize=func_top_optimize, project_root_path=file_path.parent) - read_write_context, read_only_context, testgen_context = ( - code_ctx.read_writable_code, - code_ctx.read_only_context_code, - code_ctx.testgen_context_code, - ) + read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code expected_read_write_context = """ from __future__ import annotations @@ -99,27 +95,9 @@ def test_code_replacement10() -> None: """ expected_read_only_context = """ """ - expected_testgen_context = """ - from __future__ import annotations - - class HelperClass: - def __init__(self, name): - self.name = name - def helper_method(self): - return self.name - - - class MainClass: - def __init__(self, name): - self.name = name - - def main_method(self): - return HelperClass(self.name).helper_method() - """ assert read_write_context.strip() == dedent(expected_read_write_context).strip() assert read_only_context.strip() == dedent(expected_read_only_context).strip() - assert testgen_context.strip() == dedent(expected_testgen_context).strip() def test_class_method_dependencies() -> None: diff --git a/tests/test_init_optimization.py b/tests/test_init_optimization.py index 16792221f..003b84d1d 100644 --- a/tests/test_init_optimization.py +++ b/tests/test_init_optimization.py @@ -124,7 +124,6 @@ def test_docstrings_and_comments() -> None: original = """ class MyClass: def __init__(self): - \"\"\"Original docstring.\"\"\" # Setup configuration self.config = {} # Empty config """ @@ -138,13 +137,10 @@ def test_docstrings_and_comments() -> None: result = merge_init_functions( cst.parse_module(dedent(original)).body[0].body.body[0], cst.parse_module(dedent(new)).body[0].body.body[0] ) - # TODO: handle docstrings differently expected = """ def __init__(self): - \"\"\"Original docstring.\"\"\" # Setup configuration self.config = {} # Empty config - \"\"\"New docstring.\"\"\" # Initialize database self.db = None # Database connection """