Copy and pickle, fix __future__ aliased imports below docstrings.
This commit is contained in:
parent
d5fc6b0d12
commit
f993500610
3 changed files with 50 additions and 40 deletions
|
|
@ -5,7 +5,6 @@ import logging
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
import libcst as cst
|
import libcst as cst
|
||||||
import libcst.matchers as m
|
|
||||||
from libcst.codemod import CodemodContext
|
from libcst.codemod import CodemodContext
|
||||||
from libcst.codemod.visitors import AddImportsVisitor, GatherImportsVisitor, RemoveImportsVisitor
|
from libcst.codemod.visitors import AddImportsVisitor, GatherImportsVisitor, RemoveImportsVisitor
|
||||||
from libcst.helpers import calculate_module_and_package, get_full_name_for_node
|
from libcst.helpers import calculate_module_and_package, get_full_name_for_node
|
||||||
|
|
@ -19,33 +18,25 @@ if TYPE_CHECKING:
|
||||||
from codeflash.models.models import FunctionSource
|
from codeflash.models.models import FunctionSource
|
||||||
|
|
||||||
|
|
||||||
def get_first_import_from_node(module: cst.Module) -> cst.ImportFrom | None:
|
def delete_aliased___future___import(module_code: str) -> str:
|
||||||
#
|
module: cst.Module = cst.parse_module(module_code)
|
||||||
if (
|
statement_number = next(
|
||||||
(module_body := module.body)
|
(
|
||||||
and (first_statement_body := module_body[0].body)
|
i
|
||||||
and not isinstance(first_statement_body, cst.BaseSuite)
|
for i, top_statement in enumerate(module.body)
|
||||||
and isinstance(first_base_small_statement := first_statement_body[0], cst.ImportFrom)
|
if isinstance(top_statement, cst.SimpleStatementLine)
|
||||||
):
|
for base_small_statement in top_statement.body
|
||||||
return first_base_small_statement
|
if isinstance(base_small_statement, cst.ImportFrom)
|
||||||
return None
|
and get_full_name_for_node(base_small_statement.module) == "__future__"
|
||||||
|
and not isinstance(base_small_statement_names := base_small_statement.names, cst.ImportStar)
|
||||||
|
for name in base_small_statement_names
|
||||||
def remove_first_imported_aliased_objects(
|
if name.evaluated_alias
|
||||||
module_code: str,
|
),
|
||||||
imported_module_name: str,
|
None,
|
||||||
) -> tuple[str, cst.ImportFrom | None]:
|
|
||||||
tree: cst.Module = cst.parse_module(module_code)
|
|
||||||
first_import_from_node: cst.ImportFrom | None = get_first_import_from_node(tree)
|
|
||||||
return (
|
|
||||||
((tree.with_changes(body=tree.body[1:]) if tree.body else tree).code, first_import_from_node)
|
|
||||||
if first_import_from_node
|
|
||||||
and (first_import_from_node_module := first_import_from_node.module)
|
|
||||||
and get_full_name_for_node(first_import_from_node_module) == imported_module_name
|
|
||||||
and not isinstance(first_import_from_node_names := first_import_from_node.names, cst.ImportStar)
|
|
||||||
and any(m.matches(alias, m.ImportAlias(asname=m.AsName())) for alias in first_import_from_node_names)
|
|
||||||
else (module_code, None)
|
|
||||||
)
|
)
|
||||||
|
if statement_number is None:
|
||||||
|
return module_code
|
||||||
|
return module.with_changes(body=[s for i, s in enumerate(module.body) if i != statement_number]).code
|
||||||
|
|
||||||
|
|
||||||
def add_needed_imports_from_module(
|
def add_needed_imports_from_module(
|
||||||
|
|
@ -57,7 +48,7 @@ def add_needed_imports_from_module(
|
||||||
helper_functions: list[FunctionSource] | None = None,
|
helper_functions: list[FunctionSource] | None = None,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Add all needed and used source module code imports to the destination module code, and return it."""
|
"""Add all needed and used source module code imports to the destination module code, and return it."""
|
||||||
src_module_code, _ = remove_first_imported_aliased_objects(src_module_code, "__future__")
|
src_module_code = delete_aliased___future___import(src_module_code)
|
||||||
if helper_functions is None:
|
if helper_functions is None:
|
||||||
helper_functions = []
|
helper_functions = []
|
||||||
helper_functions_fqn = {f.fully_qualified_name for f in helper_functions}
|
helper_functions_fqn = {f.fully_qualified_name for f in helper_functions}
|
||||||
|
|
|
||||||
|
|
@ -280,9 +280,14 @@ class Tracer:
|
||||||
try:
|
try:
|
||||||
# pickling can be a recursive operator, so we need to increase the recursion limit
|
# pickling can be a recursive operator, so we need to increase the recursion limit
|
||||||
sys.setrecursionlimit(10000)
|
sys.setrecursionlimit(10000)
|
||||||
# We do not pickle self to avoid recursion errors, and will instead
|
# We do not pickle self for __init__ to avoid recursion errors, and instead instantiate its class
|
||||||
if class_name and code.co_name == "__init__":
|
# directly with the rest of the arguments in the replay tests. We copy the arguments to avoid memory
|
||||||
arguments = {k: v for k, v in arguments.items() if k != "self"}
|
# leaks, bad references or side-effects when unpickling.
|
||||||
|
arguments = (
|
||||||
|
{k: v for k, v in arguments.items() if k != "self"}
|
||||||
|
if (class_name and code.co_name == "__init__")
|
||||||
|
else arguments.copy()
|
||||||
|
)
|
||||||
local_vars = pickle.dumps(
|
local_vars = pickle.dumps(
|
||||||
arguments,
|
arguments,
|
||||||
protocol=pickle.HIGHEST_PROTOCOL,
|
protocol=pickle.HIGHEST_PROTOCOL,
|
||||||
|
|
|
||||||
|
|
@ -7,7 +7,7 @@ from collections import defaultdict
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import libcst as cst
|
import libcst as cst
|
||||||
from codeflash.code_utils.code_extractor import remove_first_imported_aliased_objects
|
from codeflash.code_utils.code_extractor import delete_aliased___future___import
|
||||||
from codeflash.code_utils.code_replacer import (
|
from codeflash.code_utils.code_replacer import (
|
||||||
is_zero_diff,
|
is_zero_diff,
|
||||||
replace_functions_and_add_imports,
|
replace_functions_and_add_imports,
|
||||||
|
|
@ -1515,13 +1515,13 @@ print("Hello monde")
|
||||||
expected_code1 = """print("Hello monde")
|
expected_code1 = """print("Hello monde")
|
||||||
"""
|
"""
|
||||||
|
|
||||||
assert remove_first_imported_aliased_objects(module_code1, "__future__")[0] == expected_code1
|
assert delete_aliased___future___import(module_code1) == expected_code1
|
||||||
|
|
||||||
module_code2 = """from __future__ import annotations
|
module_code2 = """from __future__ import annotations
|
||||||
print("Hello monde")
|
print("Hello monde")
|
||||||
"""
|
"""
|
||||||
|
|
||||||
assert remove_first_imported_aliased_objects(module_code2, "__future__")[0] == module_code2
|
assert delete_aliased___future___import(module_code2) == module_code2
|
||||||
|
|
||||||
module_code3 = """from __future__ import annotations as _annotations
|
module_code3 = """from __future__ import annotations as _annotations
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
@ -1534,7 +1534,7 @@ from past import autopasta as dood
|
||||||
print("Hello monde")
|
print("Hello monde")
|
||||||
"""
|
"""
|
||||||
|
|
||||||
assert remove_first_imported_aliased_objects(module_code3, "__future__")[0] == expected_code3
|
assert delete_aliased___future___import(module_code3) == expected_code3
|
||||||
|
|
||||||
module_code4 = """from __future__ import annotations
|
module_code4 = """from __future__ import annotations
|
||||||
from __future__ import annotations as _annotations
|
from __future__ import annotations as _annotations
|
||||||
|
|
@ -1542,15 +1542,29 @@ from past import autopasta as dood
|
||||||
print("Hello monde")
|
print("Hello monde")
|
||||||
"""
|
"""
|
||||||
|
|
||||||
assert remove_first_imported_aliased_objects(module_code4, "__future__")[0] == module_code4
|
expected_module_code4 = """from __future__ import annotations
|
||||||
|
|
||||||
module_code5 = """from future import annotations as _annotations
|
|
||||||
from __future__ import annotations as _annotations
|
|
||||||
from past import autopasta as dood
|
from past import autopasta as dood
|
||||||
print("Hello monde")
|
print("Hello monde")
|
||||||
"""
|
"""
|
||||||
|
|
||||||
assert remove_first_imported_aliased_objects(module_code5, "__future__")[0] == module_code5
|
assert delete_aliased___future___import(module_code4) == expected_module_code4
|
||||||
|
|
||||||
|
module_code5 = """from future import annotations as _annotations
|
||||||
|
from past import autopasta as dood
|
||||||
|
print("Hello monde")
|
||||||
|
"""
|
||||||
|
|
||||||
|
assert delete_aliased___future___import(module_code5) == module_code5
|
||||||
|
|
||||||
|
module_code6 = '''"""Private logic for creating models."""
|
||||||
|
|
||||||
|
from __future__ import annotations as _annotations
|
||||||
|
'''
|
||||||
|
expected_code6 = '''"""Private logic for creating models."""
|
||||||
|
'''
|
||||||
|
|
||||||
|
assert delete_aliased___future___import(module_code6) == expected_code6
|
||||||
|
|
||||||
|
|
||||||
def test_0_diff_code_replacement():
|
def test_0_diff_code_replacement():
|
||||||
original_code = """from __future__ import annotations
|
original_code = """from __future__ import annotations
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue