diff --git a/codeflash/code_utils/code_extractor.py b/codeflash/code_utils/code_extractor.py index 52ba0def6..0348ae79b 100644 --- a/codeflash/code_utils/code_extractor.py +++ b/codeflash/code_utils/code_extractor.py @@ -5,9 +5,10 @@ import logging from typing import TYPE_CHECKING import libcst as cst +import libcst.matchers as m from libcst.codemod import CodemodContext from libcst.codemod.visitors import AddImportsVisitor, GatherImportsVisitor, RemoveImportsVisitor -from libcst.helpers import calculate_module_and_package, get_full_name_for_node +from libcst.helpers import calculate_module_and_package from codeflash.discovery.functions_to_optimize import FunctionParent @@ -18,26 +19,25 @@ if TYPE_CHECKING: from codeflash.models.models import FunctionSource -def delete_aliased___future___import(module_code: str) -> str: - module: cst.Module = cst.parse_module(module_code) - statement_number = next( - ( - i - for i, top_statement in enumerate(module.body) - if isinstance(top_statement, cst.SimpleStatementLine) - for base_small_statement in top_statement.body - if isinstance(base_small_statement, cst.ImportFrom) - and get_full_name_for_node(base_small_statement.module) == "__future__" - and not isinstance(base_small_statement_names := base_small_statement.names, cst.ImportStar) - for name in base_small_statement_names - if name.evaluated_alias - ), - None, - ) - if statement_number is None: - return module_code - new_body = [statement for i, statement in enumerate(module.body) if i != statement_number] - return (module.with_changes(body=new_body) if module.body else module).code +class FutureAliasedImportTransformer(cst.CSTTransformer): + def leave_ImportFrom( + self, + original_node: cst.ImportFrom, + updated_node: cst.ImportFrom, + ) -> cst.BaseSmallStatement | cst.FlattenSentinel[cst.BaseSmallStatement] | cst.RemovalSentinel: + if ( + (updated_node_module := updated_node.module) + and updated_node_module.value == "__future__" + and all(m.matches(name, m.ImportAlias()) for name in updated_node.names) + ): + if names := [name for name in updated_node.names if name.asname is None]: + return updated_node.with_changes(names=names) + return cst.RemoveFromParent() + return updated_node + + +def delete___future___aliased_imports(module_code: str) -> str: + return cst.parse_module(module_code).visit(FutureAliasedImportTransformer()).code def add_needed_imports_from_module( @@ -49,7 +49,7 @@ def add_needed_imports_from_module( helper_functions: list[FunctionSource] | None = None, ) -> str: """Add all needed and used source module code imports to the destination module code, and return it.""" - src_module_code = delete_aliased___future___import(src_module_code) + src_module_code = delete___future___aliased_imports(src_module_code) if helper_functions is None: helper_functions = [] helper_functions_fqn = {f.fully_qualified_name for f in helper_functions} diff --git a/codeflash/tracer.py b/codeflash/tracer.py index ae810ac03..c2ec1f40c 100644 --- a/codeflash/tracer.py +++ b/codeflash/tracer.py @@ -283,11 +283,9 @@ class Tracer: # We do not pickle self for __init__ to avoid recursion errors, and instead instantiate its class # directly with the rest of the arguments in the replay tests. We copy the arguments to avoid memory # leaks, bad references or side-effects when unpickling. - arguments = ( - {k: v for k, v in arguments.items() if k != "self"} - if (class_name and code.co_name == "__init__") - else arguments.copy() - ) + arguments = {k: v for k, v in arguments.items()} + if class_name and code.co_name == "__init__": + del arguments["self"] local_vars = pickle.dumps( arguments, protocol=pickle.HIGHEST_PROTOCOL, diff --git a/tests/test_code_replacement.py b/tests/test_code_replacement.py index c7bbbc0be..aca3a99e9 100644 --- a/tests/test_code_replacement.py +++ b/tests/test_code_replacement.py @@ -6,8 +6,7 @@ from argparse import Namespace from collections import defaultdict from pathlib import Path -import libcst as cst -from codeflash.code_utils.code_extractor import delete_aliased___future___import +from codeflash.code_utils.code_extractor import delete___future___aliased_imports from codeflash.code_utils.code_replacer import ( is_zero_diff, replace_functions_and_add_imports, @@ -1515,13 +1514,13 @@ print("Hello monde") expected_code1 = """print("Hello monde") """ - assert delete_aliased___future___import(module_code1) == expected_code1 + assert delete___future___aliased_imports(module_code1) == expected_code1 module_code2 = """from __future__ import annotations print("Hello monde") """ - assert delete_aliased___future___import(module_code2) == module_code2 + assert delete___future___aliased_imports(module_code2) == module_code2 module_code3 = """from __future__ import annotations as _annotations from __future__ import annotations @@ -1534,7 +1533,7 @@ from past import autopasta as dood print("Hello monde") """ - assert delete_aliased___future___import(module_code3) == expected_code3 + assert delete___future___aliased_imports(module_code3) == expected_code3 module_code4 = """from __future__ import annotations from __future__ import annotations as _annotations @@ -1547,14 +1546,14 @@ from past import autopasta as dood print("Hello monde") """ - assert delete_aliased___future___import(module_code4) == expected_module_code4 + assert delete___future___aliased_imports(module_code4) == expected_module_code4 module_code5 = """from future import annotations as _annotations from past import autopasta as dood print("Hello monde") """ - assert delete_aliased___future___import(module_code5) == module_code5 + assert delete___future___aliased_imports(module_code5) == module_code5 module_code6 = '''"""Private logic for creating models.""" @@ -1563,7 +1562,7 @@ from __future__ import annotations as _annotations expected_code6 = '''"""Private logic for creating models.""" ''' - assert delete_aliased___future___import(module_code6) == expected_code6 + assert delete___future___aliased_imports(module_code6) == expected_code6 def test_0_diff_code_replacement():