Fix all import zombies.
This commit is contained in:
parent
d7a42cf78b
commit
7dbb3fda05
3 changed files with 32 additions and 35 deletions
|
|
@ -5,9 +5,10 @@ import logging
|
|||
from typing import TYPE_CHECKING
|
||||
|
||||
import libcst as cst
|
||||
import libcst.matchers as m
|
||||
from libcst.codemod import CodemodContext
|
||||
from libcst.codemod.visitors import AddImportsVisitor, GatherImportsVisitor, RemoveImportsVisitor
|
||||
from libcst.helpers import calculate_module_and_package, get_full_name_for_node
|
||||
from libcst.helpers import calculate_module_and_package
|
||||
|
||||
from codeflash.discovery.functions_to_optimize import FunctionParent
|
||||
|
||||
|
|
@ -18,26 +19,25 @@ if TYPE_CHECKING:
|
|||
from codeflash.models.models import FunctionSource
|
||||
|
||||
|
||||
def delete_aliased___future___import(module_code: str) -> str:
|
||||
module: cst.Module = cst.parse_module(module_code)
|
||||
statement_number = next(
|
||||
(
|
||||
i
|
||||
for i, top_statement in enumerate(module.body)
|
||||
if isinstance(top_statement, cst.SimpleStatementLine)
|
||||
for base_small_statement in top_statement.body
|
||||
if isinstance(base_small_statement, cst.ImportFrom)
|
||||
and get_full_name_for_node(base_small_statement.module) == "__future__"
|
||||
and not isinstance(base_small_statement_names := base_small_statement.names, cst.ImportStar)
|
||||
for name in base_small_statement_names
|
||||
if name.evaluated_alias
|
||||
),
|
||||
None,
|
||||
)
|
||||
if statement_number is None:
|
||||
return module_code
|
||||
new_body = [statement for i, statement in enumerate(module.body) if i != statement_number]
|
||||
return (module.with_changes(body=new_body) if module.body else module).code
|
||||
class FutureAliasedImportTransformer(cst.CSTTransformer):
|
||||
def leave_ImportFrom(
|
||||
self,
|
||||
original_node: cst.ImportFrom,
|
||||
updated_node: cst.ImportFrom,
|
||||
) -> cst.BaseSmallStatement | cst.FlattenSentinel[cst.BaseSmallStatement] | cst.RemovalSentinel:
|
||||
if (
|
||||
(updated_node_module := updated_node.module)
|
||||
and updated_node_module.value == "__future__"
|
||||
and all(m.matches(name, m.ImportAlias()) for name in updated_node.names)
|
||||
):
|
||||
if names := [name for name in updated_node.names if name.asname is None]:
|
||||
return updated_node.with_changes(names=names)
|
||||
return cst.RemoveFromParent()
|
||||
return updated_node
|
||||
|
||||
|
||||
def delete___future___aliased_imports(module_code: str) -> str:
|
||||
return cst.parse_module(module_code).visit(FutureAliasedImportTransformer()).code
|
||||
|
||||
|
||||
def add_needed_imports_from_module(
|
||||
|
|
@ -49,7 +49,7 @@ def add_needed_imports_from_module(
|
|||
helper_functions: list[FunctionSource] | None = None,
|
||||
) -> str:
|
||||
"""Add all needed and used source module code imports to the destination module code, and return it."""
|
||||
src_module_code = delete_aliased___future___import(src_module_code)
|
||||
src_module_code = delete___future___aliased_imports(src_module_code)
|
||||
if helper_functions is None:
|
||||
helper_functions = []
|
||||
helper_functions_fqn = {f.fully_qualified_name for f in helper_functions}
|
||||
|
|
|
|||
|
|
@ -283,11 +283,9 @@ class Tracer:
|
|||
# We do not pickle self for __init__ to avoid recursion errors, and instead instantiate its class
|
||||
# directly with the rest of the arguments in the replay tests. We copy the arguments to avoid memory
|
||||
# leaks, bad references or side-effects when unpickling.
|
||||
arguments = (
|
||||
{k: v for k, v in arguments.items() if k != "self"}
|
||||
if (class_name and code.co_name == "__init__")
|
||||
else arguments.copy()
|
||||
)
|
||||
arguments = {k: v for k, v in arguments.items()}
|
||||
if class_name and code.co_name == "__init__":
|
||||
del arguments["self"]
|
||||
local_vars = pickle.dumps(
|
||||
arguments,
|
||||
protocol=pickle.HIGHEST_PROTOCOL,
|
||||
|
|
|
|||
|
|
@ -6,8 +6,7 @@ from argparse import Namespace
|
|||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
|
||||
import libcst as cst
|
||||
from codeflash.code_utils.code_extractor import delete_aliased___future___import
|
||||
from codeflash.code_utils.code_extractor import delete___future___aliased_imports
|
||||
from codeflash.code_utils.code_replacer import (
|
||||
is_zero_diff,
|
||||
replace_functions_and_add_imports,
|
||||
|
|
@ -1515,13 +1514,13 @@ print("Hello monde")
|
|||
expected_code1 = """print("Hello monde")
|
||||
"""
|
||||
|
||||
assert delete_aliased___future___import(module_code1) == expected_code1
|
||||
assert delete___future___aliased_imports(module_code1) == expected_code1
|
||||
|
||||
module_code2 = """from __future__ import annotations
|
||||
print("Hello monde")
|
||||
"""
|
||||
|
||||
assert delete_aliased___future___import(module_code2) == module_code2
|
||||
assert delete___future___aliased_imports(module_code2) == module_code2
|
||||
|
||||
module_code3 = """from __future__ import annotations as _annotations
|
||||
from __future__ import annotations
|
||||
|
|
@ -1534,7 +1533,7 @@ from past import autopasta as dood
|
|||
print("Hello monde")
|
||||
"""
|
||||
|
||||
assert delete_aliased___future___import(module_code3) == expected_code3
|
||||
assert delete___future___aliased_imports(module_code3) == expected_code3
|
||||
|
||||
module_code4 = """from __future__ import annotations
|
||||
from __future__ import annotations as _annotations
|
||||
|
|
@ -1547,14 +1546,14 @@ from past import autopasta as dood
|
|||
print("Hello monde")
|
||||
"""
|
||||
|
||||
assert delete_aliased___future___import(module_code4) == expected_module_code4
|
||||
assert delete___future___aliased_imports(module_code4) == expected_module_code4
|
||||
|
||||
module_code5 = """from future import annotations as _annotations
|
||||
from past import autopasta as dood
|
||||
print("Hello monde")
|
||||
"""
|
||||
|
||||
assert delete_aliased___future___import(module_code5) == module_code5
|
||||
assert delete___future___aliased_imports(module_code5) == module_code5
|
||||
|
||||
module_code6 = '''"""Private logic for creating models."""
|
||||
|
||||
|
|
@ -1563,7 +1562,7 @@ from __future__ import annotations as _annotations
|
|||
expected_code6 = '''"""Private logic for creating models."""
|
||||
'''
|
||||
|
||||
assert delete_aliased___future___import(module_code6) == expected_code6
|
||||
assert delete___future___aliased_imports(module_code6) == expected_code6
|
||||
|
||||
|
||||
def test_0_diff_code_replacement():
|
||||
|
|
|
|||
Loading…
Reference in a new issue