libcst hates aliased __future__ imports
This commit is contained in:
parent
0ddef46323
commit
c2e4e03b20
2 changed files with 82 additions and 3 deletions
|
|
@ -5,9 +5,10 @@ import logging
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
import libcst as cst
|
import libcst as cst
|
||||||
|
import libcst.matchers as m
|
||||||
from libcst.codemod import CodemodContext
|
from libcst.codemod import CodemodContext
|
||||||
from libcst.codemod.visitors import AddImportsVisitor, GatherImportsVisitor, RemoveImportsVisitor
|
from libcst.codemod.visitors import AddImportsVisitor, GatherImportsVisitor, RemoveImportsVisitor
|
||||||
from libcst.helpers import calculate_module_and_package
|
from libcst.helpers import calculate_module_and_package, get_full_name_for_node
|
||||||
|
|
||||||
from codeflash.discovery.functions_to_optimize import FunctionParent
|
from codeflash.discovery.functions_to_optimize import FunctionParent
|
||||||
|
|
||||||
|
|
@ -18,6 +19,35 @@ if TYPE_CHECKING:
|
||||||
from codeflash.models.models import FunctionSource
|
from codeflash.models.models import FunctionSource
|
||||||
|
|
||||||
|
|
||||||
|
def get_first_import_from_node(module: cst.Module) -> cst.ImportFrom | None:
|
||||||
|
#
|
||||||
|
if (
|
||||||
|
(module_body := module.body)
|
||||||
|
and (first_statement_body := module_body[0].body)
|
||||||
|
and not isinstance(first_statement_body, cst.BaseSuite)
|
||||||
|
and isinstance(first_base_small_statement := first_statement_body[0], cst.ImportFrom)
|
||||||
|
):
|
||||||
|
return first_base_small_statement
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def remove_first_imported_aliased_objects(
|
||||||
|
module_code: str,
|
||||||
|
imported_module_name: str,
|
||||||
|
) -> tuple[str, cst.ImportFrom | None]:
|
||||||
|
tree: cst.Module = cst.parse_module(module_code)
|
||||||
|
first_import_from_node: cst.ImportFrom | None = get_first_import_from_node(tree)
|
||||||
|
return (
|
||||||
|
((tree.with_changes(body=tree.body[1:]) if tree.body else tree).code, first_import_from_node)
|
||||||
|
if first_import_from_node
|
||||||
|
and (first_import_from_node_module := first_import_from_node.module)
|
||||||
|
and get_full_name_for_node(first_import_from_node_module) == imported_module_name
|
||||||
|
and not isinstance(first_import_from_node_names := first_import_from_node.names, cst.ImportStar)
|
||||||
|
and any(m.matches(alias, m.ImportAlias(asname=m.AsName())) for alias in first_import_from_node_names)
|
||||||
|
else (module_code, None)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def add_needed_imports_from_module(
|
def add_needed_imports_from_module(
|
||||||
src_module_code: str,
|
src_module_code: str,
|
||||||
dst_module_code: str,
|
dst_module_code: str,
|
||||||
|
|
@ -27,6 +57,7 @@ def add_needed_imports_from_module(
|
||||||
helper_functions: list[FunctionSource] | None = None,
|
helper_functions: list[FunctionSource] | None = None,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Add all needed and used source module code imports to the destination module code, and return it."""
|
"""Add all needed and used source module code imports to the destination module code, and return it."""
|
||||||
|
src_module_code, _ = remove_first_imported_aliased_objects(src_module_code, "__future__")
|
||||||
if helper_functions is None:
|
if helper_functions is None:
|
||||||
helper_functions = []
|
helper_functions = []
|
||||||
helper_functions_fqn = {f.fully_qualified_name for f in helper_functions}
|
helper_functions_fqn = {f.fully_qualified_name for f in helper_functions}
|
||||||
|
|
@ -241,7 +272,7 @@ def extract_code(
|
||||||
return edited_code, contextual_dunder_methods
|
return edited_code, contextual_dunder_methods
|
||||||
|
|
||||||
|
|
||||||
def find_preexisting_objects(source_code: str):
|
def find_preexisting_objects(source_code: str) -> list[tuple[str, list[FunctionParent]]]:
|
||||||
"""Find all preexisting functions, classes or class methods in the source code"""
|
"""Find all preexisting functions, classes or class methods in the source code"""
|
||||||
preexisting_objects: list[tuple[str, list[FunctionParent]]] = []
|
preexisting_objects: list[tuple[str, list[FunctionParent]]] = []
|
||||||
try:
|
try:
|
||||||
|
|
|
||||||
|
|
@ -6,6 +6,8 @@ from argparse import Namespace
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
import libcst as cst
|
||||||
|
from codeflash.code_utils.code_extractor import remove_first_imported_aliased_objects
|
||||||
from codeflash.code_utils.code_replacer import replace_functions_and_add_imports, replace_functions_in_file
|
from codeflash.code_utils.code_replacer import replace_functions_and_add_imports, replace_functions_in_file
|
||||||
from codeflash.discovery.functions_to_optimize import FunctionParent, FunctionToOptimize
|
from codeflash.discovery.functions_to_optimize import FunctionParent, FunctionToOptimize
|
||||||
from codeflash.optimization.optimizer import Optimizer
|
from codeflash.optimization.optimizer import Optimizer
|
||||||
|
|
@ -1253,7 +1255,7 @@ class TestResults(BaseModel):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_code_replacement_type_annotation():
|
def test_code_replacement_type_annotation() -> None:
|
||||||
original_code = '''import numpy as np
|
original_code = '''import numpy as np
|
||||||
from pydantic.dataclasses import dataclass
|
from pydantic.dataclasses import dataclass
|
||||||
from typing import List, Optional, Tuple, Union
|
from typing import List, Optional, Tuple, Union
|
||||||
|
|
@ -1499,3 +1501,49 @@ def cosine_similarity_top_k(
|
||||||
return ret_idxs, scores
|
return ret_idxs, scores
|
||||||
'''
|
'''
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_future_aliased_imports_removal() -> None:
|
||||||
|
module_code1 = """from __future__ import annotations as _annotations
|
||||||
|
print("Hello monde")
|
||||||
|
"""
|
||||||
|
|
||||||
|
expected_code1 = """print("Hello monde")
|
||||||
|
"""
|
||||||
|
|
||||||
|
assert remove_first_imported_aliased_objects(module_code1, "__future__")[0] == expected_code1
|
||||||
|
|
||||||
|
module_code2 = """from __future__ import annotations
|
||||||
|
print("Hello monde")
|
||||||
|
"""
|
||||||
|
|
||||||
|
assert remove_first_imported_aliased_objects(module_code2, "__future__")[0] == module_code2
|
||||||
|
|
||||||
|
module_code3 = """from __future__ import annotations as _annotations
|
||||||
|
from __future__ import annotations
|
||||||
|
from past import autopasta as dood
|
||||||
|
print("Hello monde")
|
||||||
|
"""
|
||||||
|
|
||||||
|
expected_code3 = """from __future__ import annotations
|
||||||
|
from past import autopasta as dood
|
||||||
|
print("Hello monde")
|
||||||
|
"""
|
||||||
|
|
||||||
|
assert remove_first_imported_aliased_objects(module_code3, "__future__")[0] == expected_code3
|
||||||
|
|
||||||
|
module_code4 = """from __future__ import annotations
|
||||||
|
from __future__ import annotations as _annotations
|
||||||
|
from past import autopasta as dood
|
||||||
|
print("Hello monde")
|
||||||
|
"""
|
||||||
|
|
||||||
|
assert remove_first_imported_aliased_objects(module_code4, "__future__")[0] == module_code4
|
||||||
|
|
||||||
|
module_code5 = """from future import annotations as _annotations
|
||||||
|
from __future__ import annotations as _annotations
|
||||||
|
from past import autopasta as dood
|
||||||
|
print("Hello monde")
|
||||||
|
"""
|
||||||
|
|
||||||
|
assert remove_first_imported_aliased_objects(module_code5, "__future__")[0] == module_code5
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue