diff --git a/codeflash/code_utils/code_extractor.py b/codeflash/code_utils/code_extractor.py index 6949bdbe1..32a5629df 100644 --- a/codeflash/code_utils/code_extractor.py +++ b/codeflash/code_utils/code_extractor.py @@ -5,9 +5,10 @@ import logging from typing import TYPE_CHECKING import libcst as cst +import libcst.matchers as m from libcst.codemod import CodemodContext from libcst.codemod.visitors import AddImportsVisitor, GatherImportsVisitor, RemoveImportsVisitor -from libcst.helpers import calculate_module_and_package +from libcst.helpers import calculate_module_and_package, get_full_name_for_node from codeflash.discovery.functions_to_optimize import FunctionParent @@ -18,6 +19,35 @@ if TYPE_CHECKING: from codeflash.models.models import FunctionSource +def get_first_import_from_node(module: cst.Module) -> cst.ImportFrom | None: + # + if ( + (module_body := module.body) + and (first_statement_body := module_body[0].body) + and not isinstance(first_statement_body, cst.BaseSuite) + and isinstance(first_base_small_statement := first_statement_body[0], cst.ImportFrom) + ): + return first_base_small_statement + return None + + +def remove_first_imported_aliased_objects( + module_code: str, + imported_module_name: str, +) -> tuple[str, cst.ImportFrom | None]: + tree: cst.Module = cst.parse_module(module_code) + first_import_from_node: cst.ImportFrom | None = get_first_import_from_node(tree) + return ( + ((tree.with_changes(body=tree.body[1:]) if tree.body else tree).code, first_import_from_node) + if first_import_from_node + and (first_import_from_node_module := first_import_from_node.module) + and get_full_name_for_node(first_import_from_node_module) == imported_module_name + and not isinstance(first_import_from_node_names := first_import_from_node.names, cst.ImportStar) + and any(m.matches(alias, m.ImportAlias(asname=m.AsName())) for alias in first_import_from_node_names) + else (module_code, None) + ) + + def add_needed_imports_from_module( src_module_code: str, dst_module_code: str, @@ -27,6 +57,7 @@ def add_needed_imports_from_module( helper_functions: list[FunctionSource] | None = None, ) -> str: """Add all needed and used source module code imports to the destination module code, and return it.""" + src_module_code, _ = remove_first_imported_aliased_objects(src_module_code, "__future__") if helper_functions is None: helper_functions = [] helper_functions_fqn = {f.fully_qualified_name for f in helper_functions} @@ -241,7 +272,7 @@ def extract_code( return edited_code, contextual_dunder_methods -def find_preexisting_objects(source_code: str): +def find_preexisting_objects(source_code: str) -> list[tuple[str, list[FunctionParent]]]: """Find all preexisting functions, classes or class methods in the source code""" preexisting_objects: list[tuple[str, list[FunctionParent]]] = [] try: diff --git a/tests/test_code_replacement.py b/tests/test_code_replacement.py index fe741f411..433391784 100644 --- a/tests/test_code_replacement.py +++ b/tests/test_code_replacement.py @@ -6,6 +6,8 @@ from argparse import Namespace from collections import defaultdict from pathlib import Path +import libcst as cst +from codeflash.code_utils.code_extractor import remove_first_imported_aliased_objects from codeflash.code_utils.code_replacer import replace_functions_and_add_imports, replace_functions_in_file from codeflash.discovery.functions_to_optimize import FunctionParent, FunctionToOptimize from codeflash.optimization.optimizer import Optimizer @@ -1253,7 +1255,7 @@ class TestResults(BaseModel): ) -def test_code_replacement_type_annotation(): +def test_code_replacement_type_annotation() -> None: original_code = '''import numpy as np from pydantic.dataclasses import dataclass from typing import List, Optional, Tuple, Union @@ -1499,3 +1501,49 @@ def cosine_similarity_top_k( return ret_idxs, scores ''' ) + + +def test_future_aliased_imports_removal() -> None: + module_code1 = """from __future__ import annotations as _annotations +print("Hello monde") +""" + + expected_code1 = """print("Hello monde") +""" + + assert remove_first_imported_aliased_objects(module_code1, "__future__")[0] == expected_code1 + + module_code2 = """from __future__ import annotations +print("Hello monde") +""" + + assert remove_first_imported_aliased_objects(module_code2, "__future__")[0] == module_code2 + + module_code3 = """from __future__ import annotations as _annotations +from __future__ import annotations +from past import autopasta as dood +print("Hello monde") +""" + + expected_code3 = """from __future__ import annotations +from past import autopasta as dood +print("Hello monde") +""" + + assert remove_first_imported_aliased_objects(module_code3, "__future__")[0] == expected_code3 + + module_code4 = """from __future__ import annotations +from __future__ import annotations as _annotations +from past import autopasta as dood +print("Hello monde") +""" + + assert remove_first_imported_aliased_objects(module_code4, "__future__")[0] == module_code4 + + module_code5 = """from future import annotations as _annotations +from __future__ import annotations as _annotations +from past import autopasta as dood +print("Hello monde") +""" + + assert remove_first_imported_aliased_objects(module_code5, "__future__")[0] == module_code5