libcst hates aliased __future__ imports

This commit is contained in:
RD 2024-07-11 05:28:56 -07:00
parent 0ddef46323
commit c2e4e03b20
2 changed files with 82 additions and 3 deletions

View file

@ -5,9 +5,10 @@ import logging
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
import libcst as cst import libcst as cst
import libcst.matchers as m
from libcst.codemod import CodemodContext from libcst.codemod import CodemodContext
from libcst.codemod.visitors import AddImportsVisitor, GatherImportsVisitor, RemoveImportsVisitor from libcst.codemod.visitors import AddImportsVisitor, GatherImportsVisitor, RemoveImportsVisitor
from libcst.helpers import calculate_module_and_package from libcst.helpers import calculate_module_and_package, get_full_name_for_node
from codeflash.discovery.functions_to_optimize import FunctionParent from codeflash.discovery.functions_to_optimize import FunctionParent
@ -18,6 +19,35 @@ if TYPE_CHECKING:
from codeflash.models.models import FunctionSource from codeflash.models.models import FunctionSource
def get_first_import_from_node(module: cst.Module) -> cst.ImportFrom | None:
    """Return the ``from ... import ...`` statement that opens ``module``, if any.

    Only the very first statement of the module is inspected. It must be a
    simple statement line whose first small statement is a ``cst.ImportFrom``;
    otherwise ``None`` is returned. An empty module also yields ``None``.

    Args:
        module: The parsed libcst module to inspect.

    Returns:
        The leading ``cst.ImportFrom`` node, or ``None`` when the module does
        not start with one.
    """
    if not module.body:
        return None
    first_statement = module.body[0]
    # Check for a simple statement line directly instead of reading ``.body``
    # and then excluding suites: this makes no assumption that every kind of
    # statement node exposes a ``body`` attribute.
    if not isinstance(first_statement, cst.SimpleStatementLine) or not first_statement.body:
        return None
    first_small_statement = first_statement.body[0]
    if isinstance(first_small_statement, cst.ImportFrom):
        return first_small_statement
    return None
def remove_first_imported_aliased_objects(
    module_code: str,
    imported_module_name: str,
) -> tuple[str, cst.ImportFrom | None]:
    """Strip a leading aliased ``from <module> import ... as ...`` statement.

    The import is removed only when all of the following hold:

    * the first statement of ``module_code`` is a ``from ... import ...``;
    * it imports from exactly ``imported_module_name`` (relative imports with
      no module name never match);
    * it is not a star import;
    * at least one imported name carries an ``as`` alias.

    Args:
        module_code: Source code of the module to process.
        imported_module_name: Fully qualified module name the leading import
            must refer to, e.g. ``"__future__"``.

    Returns:
        A ``(code, node)`` tuple. When the import was removed, ``code`` is the
        module source without it and ``node`` is the removed ``cst.ImportFrom``.
        Otherwise ``code`` is ``module_code`` unchanged and ``node`` is ``None``.
    """
    tree: cst.Module = cst.parse_module(module_code)
    first_import_from_node: cst.ImportFrom | None = get_first_import_from_node(tree)
    if first_import_from_node is None:
        return module_code, None

    imported_from = first_import_from_node.module
    if imported_from is None or get_full_name_for_node(imported_from) != imported_module_name:
        return module_code, None

    imported_names = first_import_from_node.names
    if isinstance(imported_names, cst.ImportStar):
        return module_code, None
    if not any(m.matches(alias, m.ImportAlias(asname=m.AsName())) for alias in imported_names):
        return module_code, None

    # The import is the first small statement of the first line. Drop only that
    # statement, so anything sharing the line via ``;`` (for instance
    # ``from __future__ import x as y; import os``) is kept instead of being
    # discarded together with the import.
    first_line = tree.body[0]
    statements_after_import = first_line.body[1:]
    following_lines = tree.body[1:]
    if statements_after_import:
        new_body = [first_line.with_changes(body=statements_after_import), *following_lines]
    else:
        new_body = following_lines
    return tree.with_changes(body=new_body).code, first_import_from_node
def add_needed_imports_from_module( def add_needed_imports_from_module(
src_module_code: str, src_module_code: str,
dst_module_code: str, dst_module_code: str,
@ -27,6 +57,7 @@ def add_needed_imports_from_module(
helper_functions: list[FunctionSource] | None = None, helper_functions: list[FunctionSource] | None = None,
) -> str: ) -> str:
"""Add all needed and used source module code imports to the destination module code, and return it.""" """Add all needed and used source module code imports to the destination module code, and return it."""
src_module_code, _ = remove_first_imported_aliased_objects(src_module_code, "__future__")
if helper_functions is None: if helper_functions is None:
helper_functions = [] helper_functions = []
helper_functions_fqn = {f.fully_qualified_name for f in helper_functions} helper_functions_fqn = {f.fully_qualified_name for f in helper_functions}
@ -241,7 +272,7 @@ def extract_code(
return edited_code, contextual_dunder_methods return edited_code, contextual_dunder_methods
def find_preexisting_objects(source_code: str): def find_preexisting_objects(source_code: str) -> list[tuple[str, list[FunctionParent]]]:
"""Find all preexisting functions, classes or class methods in the source code""" """Find all preexisting functions, classes or class methods in the source code"""
preexisting_objects: list[tuple[str, list[FunctionParent]]] = [] preexisting_objects: list[tuple[str, list[FunctionParent]]] = []
try: try:

View file

@ -6,6 +6,8 @@ from argparse import Namespace
from collections import defaultdict from collections import defaultdict
from pathlib import Path from pathlib import Path
import libcst as cst
from codeflash.code_utils.code_extractor import remove_first_imported_aliased_objects
from codeflash.code_utils.code_replacer import replace_functions_and_add_imports, replace_functions_in_file from codeflash.code_utils.code_replacer import replace_functions_and_add_imports, replace_functions_in_file
from codeflash.discovery.functions_to_optimize import FunctionParent, FunctionToOptimize from codeflash.discovery.functions_to_optimize import FunctionParent, FunctionToOptimize
from codeflash.optimization.optimizer import Optimizer from codeflash.optimization.optimizer import Optimizer
@ -1253,7 +1255,7 @@ class TestResults(BaseModel):
) )
def test_code_replacement_type_annotation(): def test_code_replacement_type_annotation() -> None:
original_code = '''import numpy as np original_code = '''import numpy as np
from pydantic.dataclasses import dataclass from pydantic.dataclasses import dataclass
from typing import List, Optional, Tuple, Union from typing import List, Optional, Tuple, Union
@ -1499,3 +1501,49 @@ def cosine_similarity_top_k(
return ret_idxs, scores return ret_idxs, scores
''' '''
) )
def test_future_aliased_imports_removal() -> None:
    """Check which leading imports ``remove_first_imported_aliased_objects`` strips.

    Each case pairs a module source with the code expected back when
    ``"__future__"`` is the targeted module. Only the returned code (element 0
    of the result tuple) is compared.
    """
    aliased_only = """from __future__ import annotations as _annotations
print("Hello monde")
"""
    aliased_only_stripped = """print("Hello monde")
"""

    unaliased_only = """from __future__ import annotations
print("Hello monde")
"""

    aliased_then_unaliased = """from __future__ import annotations as _annotations
from __future__ import annotations
from past import autopasta as dood
print("Hello monde")
"""
    aliased_then_unaliased_stripped = """from __future__ import annotations
from past import autopasta as dood
print("Hello monde")
"""

    unaliased_then_aliased = """from __future__ import annotations
from __future__ import annotations as _annotations
from past import autopasta as dood
print("Hello monde")
"""

    other_module_first = """from future import annotations as _annotations
from __future__ import annotations as _annotations
from past import autopasta as dood
print("Hello monde")
"""

    cases = [
        # Aliased import in first position: removed.
        (aliased_only, aliased_only_stripped),
        # Un-aliased import in first position: code returned unchanged.
        (unaliased_only, unaliased_only),
        # Only the first (aliased) import goes; later imports stay.
        (aliased_then_unaliased, aliased_then_unaliased_stripped),
        # Aliased import is not the first statement: code returned unchanged.
        (unaliased_then_aliased, unaliased_then_aliased),
        # First import targets a different module: code returned unchanged.
        (other_module_first, other_module_first),
    ]
    for source, expected in cases:
        assert remove_first_imported_aliased_objects(source, "__future__")[0] == expected