Copy and pickle, fix __future__ aliased imports below docstrings.

This commit is contained in:
RD 2024-07-26 04:46:55 -07:00
parent d5fc6b0d12
commit f993500610
3 changed files with 50 additions and 40 deletions

View file

@ -5,7 +5,6 @@ import logging
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
import libcst as cst import libcst as cst
import libcst.matchers as m
from libcst.codemod import CodemodContext from libcst.codemod import CodemodContext
from libcst.codemod.visitors import AddImportsVisitor, GatherImportsVisitor, RemoveImportsVisitor from libcst.codemod.visitors import AddImportsVisitor, GatherImportsVisitor, RemoveImportsVisitor
from libcst.helpers import calculate_module_and_package, get_full_name_for_node from libcst.helpers import calculate_module_and_package, get_full_name_for_node
@ -19,33 +18,25 @@ if TYPE_CHECKING:
from codeflash.models.models import FunctionSource from codeflash.models.models import FunctionSource
def get_first_import_from_node(module: cst.Module) -> cst.ImportFrom | None: def delete_aliased___future___import(module_code: str) -> str:
# module: cst.Module = cst.parse_module(module_code)
if ( statement_number = next(
(module_body := module.body) (
and (first_statement_body := module_body[0].body) i
and not isinstance(first_statement_body, cst.BaseSuite) for i, top_statement in enumerate(module.body)
and isinstance(first_base_small_statement := first_statement_body[0], cst.ImportFrom) if isinstance(top_statement, cst.SimpleStatementLine)
): for base_small_statement in top_statement.body
return first_base_small_statement if isinstance(base_small_statement, cst.ImportFrom)
return None and get_full_name_for_node(base_small_statement.module) == "__future__"
and not isinstance(base_small_statement_names := base_small_statement.names, cst.ImportStar)
for name in base_small_statement_names
def remove_first_imported_aliased_objects( if name.evaluated_alias
module_code: str, ),
imported_module_name: str, None,
) -> tuple[str, cst.ImportFrom | None]:
tree: cst.Module = cst.parse_module(module_code)
first_import_from_node: cst.ImportFrom | None = get_first_import_from_node(tree)
return (
((tree.with_changes(body=tree.body[1:]) if tree.body else tree).code, first_import_from_node)
if first_import_from_node
and (first_import_from_node_module := first_import_from_node.module)
and get_full_name_for_node(first_import_from_node_module) == imported_module_name
and not isinstance(first_import_from_node_names := first_import_from_node.names, cst.ImportStar)
and any(m.matches(alias, m.ImportAlias(asname=m.AsName())) for alias in first_import_from_node_names)
else (module_code, None)
) )
if statement_number is None:
return module_code
return module.with_changes(body=[s for i, s in enumerate(module.body) if i != statement_number]).code
def add_needed_imports_from_module( def add_needed_imports_from_module(
@ -57,7 +48,7 @@ def add_needed_imports_from_module(
helper_functions: list[FunctionSource] | None = None, helper_functions: list[FunctionSource] | None = None,
) -> str: ) -> str:
"""Add all needed and used source module code imports to the destination module code, and return it.""" """Add all needed and used source module code imports to the destination module code, and return it."""
src_module_code, _ = remove_first_imported_aliased_objects(src_module_code, "__future__") src_module_code = delete_aliased___future___import(src_module_code)
if helper_functions is None: if helper_functions is None:
helper_functions = [] helper_functions = []
helper_functions_fqn = {f.fully_qualified_name for f in helper_functions} helper_functions_fqn = {f.fully_qualified_name for f in helper_functions}

View file

@ -280,9 +280,14 @@ class Tracer:
try: try:
# pickling can be a recursive operator, so we need to increase the recursion limit # pickling can be a recursive operator, so we need to increase the recursion limit
sys.setrecursionlimit(10000) sys.setrecursionlimit(10000)
# We do not pickle self to avoid recursion errors, and will instead # We do not pickle self for __init__ to avoid recursion errors, and instead instantiate its class
if class_name and code.co_name == "__init__": # directly with the rest of the arguments in the replay tests. We copy the arguments to avoid memory
arguments = {k: v for k, v in arguments.items() if k != "self"} # leaks, bad references or side-effects when unpickling.
arguments = (
{k: v for k, v in arguments.items() if k != "self"}
if (class_name and code.co_name == "__init__")
else arguments.copy()
)
local_vars = pickle.dumps( local_vars = pickle.dumps(
arguments, arguments,
protocol=pickle.HIGHEST_PROTOCOL, protocol=pickle.HIGHEST_PROTOCOL,

View file

@ -7,7 +7,7 @@ from collections import defaultdict
from pathlib import Path from pathlib import Path
import libcst as cst import libcst as cst
from codeflash.code_utils.code_extractor import remove_first_imported_aliased_objects from codeflash.code_utils.code_extractor import delete_aliased___future___import
from codeflash.code_utils.code_replacer import ( from codeflash.code_utils.code_replacer import (
is_zero_diff, is_zero_diff,
replace_functions_and_add_imports, replace_functions_and_add_imports,
@ -1515,13 +1515,13 @@ print("Hello monde")
expected_code1 = """print("Hello monde") expected_code1 = """print("Hello monde")
""" """
assert remove_first_imported_aliased_objects(module_code1, "__future__")[0] == expected_code1 assert delete_aliased___future___import(module_code1) == expected_code1
module_code2 = """from __future__ import annotations module_code2 = """from __future__ import annotations
print("Hello monde") print("Hello monde")
""" """
assert remove_first_imported_aliased_objects(module_code2, "__future__")[0] == module_code2 assert delete_aliased___future___import(module_code2) == module_code2
module_code3 = """from __future__ import annotations as _annotations module_code3 = """from __future__ import annotations as _annotations
from __future__ import annotations from __future__ import annotations
@ -1534,7 +1534,7 @@ from past import autopasta as dood
print("Hello monde") print("Hello monde")
""" """
assert remove_first_imported_aliased_objects(module_code3, "__future__")[0] == expected_code3 assert delete_aliased___future___import(module_code3) == expected_code3
module_code4 = """from __future__ import annotations module_code4 = """from __future__ import annotations
from __future__ import annotations as _annotations from __future__ import annotations as _annotations
@ -1542,15 +1542,29 @@ from past import autopasta as dood
print("Hello monde") print("Hello monde")
""" """
assert remove_first_imported_aliased_objects(module_code4, "__future__")[0] == module_code4 expected_module_code4 = """from __future__ import annotations
module_code5 = """from future import annotations as _annotations
from __future__ import annotations as _annotations
from past import autopasta as dood from past import autopasta as dood
print("Hello monde") print("Hello monde")
""" """
assert remove_first_imported_aliased_objects(module_code5, "__future__")[0] == module_code5 assert delete_aliased___future___import(module_code4) == expected_module_code4
module_code5 = """from future import annotations as _annotations
from past import autopasta as dood
print("Hello monde")
"""
assert delete_aliased___future___import(module_code5) == module_code5
module_code6 = '''"""Private logic for creating models."""
from __future__ import annotations as _annotations
'''
expected_code6 = '''"""Private logic for creating models."""
'''
assert delete_aliased___future___import(module_code6) == expected_code6
def test_0_diff_code_replacement(): def test_0_diff_code_replacement():
original_code = """from __future__ import annotations original_code = """from __future__ import annotations