From f993500610cd4644ac0d1ddfec4e9f8120100b93 Mon Sep 17 00:00:00 2001 From: RD <92499101+iusedmyimagination@users.noreply.github.com> Date: Fri, 26 Jul 2024 04:46:55 -0700 Subject: [PATCH] Copy and pickle, fix __future__ aliased imports below docstrings. --- codeflash/code_utils/code_extractor.py | 47 +++++++++++--------------- codeflash/tracer.py | 11 ++++-- tests/test_code_replacement.py | 32 +++++++++++++----- 3 files changed, 50 insertions(+), 40 deletions(-) diff --git a/codeflash/code_utils/code_extractor.py b/codeflash/code_utils/code_extractor.py index 32a5629df..5816f5f50 100644 --- a/codeflash/code_utils/code_extractor.py +++ b/codeflash/code_utils/code_extractor.py @@ -5,7 +5,6 @@ import logging from typing import TYPE_CHECKING import libcst as cst -import libcst.matchers as m from libcst.codemod import CodemodContext from libcst.codemod.visitors import AddImportsVisitor, GatherImportsVisitor, RemoveImportsVisitor from libcst.helpers import calculate_module_and_package, get_full_name_for_node @@ -19,33 +18,25 @@ if TYPE_CHECKING: from codeflash.models.models import FunctionSource -def get_first_import_from_node(module: cst.Module) -> cst.ImportFrom | None: - # - if ( - (module_body := module.body) - and (first_statement_body := module_body[0].body) - and not isinstance(first_statement_body, cst.BaseSuite) - and isinstance(first_base_small_statement := first_statement_body[0], cst.ImportFrom) - ): - return first_base_small_statement - return None - - -def remove_first_imported_aliased_objects( - module_code: str, - imported_module_name: str, -) -> tuple[str, cst.ImportFrom | None]: - tree: cst.Module = cst.parse_module(module_code) - first_import_from_node: cst.ImportFrom | None = get_first_import_from_node(tree) - return ( - ((tree.with_changes(body=tree.body[1:]) if tree.body else tree).code, first_import_from_node) - if first_import_from_node - and (first_import_from_node_module := first_import_from_node.module) - and get_full_name_for_node(first_import_from_node_module) == imported_module_name - and not isinstance(first_import_from_node_names := first_import_from_node.names, cst.ImportStar) - and any(m.matches(alias, m.ImportAlias(asname=m.AsName())) for alias in first_import_from_node_names) - else (module_code, None) +def delete_aliased___future___import(module_code: str) -> str: + module: cst.Module = cst.parse_module(module_code) + statement_number = next( + ( + i + for i, top_statement in enumerate(module.body) + if isinstance(top_statement, cst.SimpleStatementLine) + for base_small_statement in top_statement.body + if isinstance(base_small_statement, cst.ImportFrom) + and get_full_name_for_node(base_small_statement.module) == "__future__" + and not isinstance(base_small_statement_names := base_small_statement.names, cst.ImportStar) + for name in base_small_statement_names + if name.evaluated_alias + ), + None, ) + if statement_number is None: + return module_code + return module.with_changes(body=[s for i, s in enumerate(module.body) if i != statement_number]).code def add_needed_imports_from_module( @@ -57,7 +48,7 @@ def add_needed_imports_from_module( helper_functions: list[FunctionSource] | None = None, ) -> str: """Add all needed and used source module code imports to the destination module code, and return it.""" - src_module_code, _ = remove_first_imported_aliased_objects(src_module_code, "__future__") + src_module_code = delete_aliased___future___import(src_module_code) if helper_functions is None: helper_functions = [] helper_functions_fqn = {f.fully_qualified_name for f in helper_functions} diff --git a/codeflash/tracer.py b/codeflash/tracer.py index f03abe981..ae810ac03 100644 --- a/codeflash/tracer.py +++ b/codeflash/tracer.py @@ -280,9 +280,14 @@ class Tracer: try: # pickling can be a recursive operator, so we need to increase the recursion limit sys.setrecursionlimit(10000) - # We do not pickle self to avoid recursion errors, and will instead - if class_name and code.co_name == "__init__": - arguments = {k: v for k, v in arguments.items() if k != "self"} + # We do not pickle self for __init__ to avoid recursion errors, and instead instantiate its class + # directly with the rest of the arguments in the replay tests. We copy the arguments to avoid memory + # leaks, bad references or side-effects when unpickling. + arguments = ( + {k: v for k, v in arguments.items() if k != "self"} + if (class_name and code.co_name == "__init__") + else arguments.copy() + ) local_vars = pickle.dumps( arguments, protocol=pickle.HIGHEST_PROTOCOL, diff --git a/tests/test_code_replacement.py b/tests/test_code_replacement.py index d32cf825e..c7bbbc0be 100644 --- a/tests/test_code_replacement.py +++ b/tests/test_code_replacement.py @@ -7,7 +7,7 @@ from collections import defaultdict from pathlib import Path import libcst as cst -from codeflash.code_utils.code_extractor import remove_first_imported_aliased_objects +from codeflash.code_utils.code_extractor import delete_aliased___future___import from codeflash.code_utils.code_replacer import ( is_zero_diff, replace_functions_and_add_imports, @@ -1515,13 +1515,13 @@ print("Hello monde") expected_code1 = """print("Hello monde") """ - assert remove_first_imported_aliased_objects(module_code1, "__future__")[0] == expected_code1 + assert delete_aliased___future___import(module_code1) == expected_code1 module_code2 = """from __future__ import annotations print("Hello monde") """ - assert remove_first_imported_aliased_objects(module_code2, "__future__")[0] == module_code2 + assert delete_aliased___future___import(module_code2) == module_code2 module_code3 = """from __future__ import annotations as _annotations from __future__ import annotations @@ -1534,7 +1534,7 @@ from past import autopasta as dood print("Hello monde") """ - assert remove_first_imported_aliased_objects(module_code3, "__future__")[0] == expected_code3 + assert delete_aliased___future___import(module_code3) == expected_code3 module_code4 = """from __future__ import annotations from __future__ import annotations as _annotations @@ -1542,15 +1542,29 @@ from past import autopasta as dood print("Hello monde") """ - assert remove_first_imported_aliased_objects(module_code4, "__future__")[0] == module_code4 - - module_code5 = """from future import annotations as _annotations -from __future__ import annotations as _annotations + expected_module_code4 = """from __future__ import annotations from past import autopasta as dood print("Hello monde") """ - assert remove_first_imported_aliased_objects(module_code5, "__future__")[0] == module_code5 + assert delete_aliased___future___import(module_code4) == expected_module_code4 + + module_code5 = """from future import annotations as _annotations +from past import autopasta as dood +print("Hello monde") +""" + + assert delete_aliased___future___import(module_code5) == module_code5 + + module_code6 = '''"""Private logic for creating models.""" + +from __future__ import annotations as _annotations +''' + expected_code6 = '''"""Private logic for creating models.""" +''' + + assert delete_aliased___future___import(module_code6) == expected_code6 + def test_0_diff_code_replacement(): original_code = """from __future__ import annotations