Restore and add tests, small fix.

This commit is contained in:
RD 2024-06-22 19:39:15 -07:00
parent 3052c6e294
commit b0ba6b384b
2 changed files with 84 additions and 1 deletions

View file

@ -57,7 +57,8 @@ class OptimFunctionCollector(cst.CSTVisitor):
parents = [FunctionParent(name=node.name.value, type="ClassDef")]
for child_node in node.body.body:
if (
isinstance(child_node, cst.FunctionDef)
self.preexisting_functions
and isinstance(child_node, cst.FunctionDef)
and (node.name.value, child_node.name.value) not in self.contextual_functions
and (child_node.name.value, parents) not in self.preexisting_functions
):

View file

@ -844,3 +844,85 @@ def test_code_replacement11() -> None:
contextual_functions=contextual_functions,
)
assert new_code == expected_code
def test_code_replacement12() -> None:
optim_code = '''class Fu():
def foo(self) -> dict[str, str]:
payload: dict[str, str] = {"bar": self.bar(), "real_bar": str(self.real_bar() + 1)}
return payload
def real_bar(self) -> int:
"""No abstract nonsense"""
pass
'''
original_code = '''class Fu():
def foo(self) -> dict[str, str]:
payload: dict[str, str] = {"bar": self.bar(), "real_bar": str(self.real_bar())}
return payload
def real_bar(self) -> int:
"""No abstract nonsense"""
return 0
'''
expected_code = '''class Fu():
def foo(self) -> dict[str, str]:
payload: dict[str, str] = {"bar": self.bar(), "real_bar": str(self.real_bar())}
return payload
def real_bar(self) -> int:
"""No abstract nonsense"""
pass
'''
preexisting_functions: list[tuple[str, list[FunctionParent]]] = []
contextual_functions: set[tuple[str, str]] = set()
new_code: str = replace_functions_in_file(
source_code=original_code,
original_function_names=["Fu.real_bar"],
optimized_code=optim_code,
preexisting_functions=preexisting_functions,
contextual_functions=contextual_functions,
)
assert new_code == expected_code
def test_test_libcst_code_replacement13() -> None:
# Test if the dunder method is not modified
optim_code = """class NewClass:
def __init__(self, name):
self.name = name
self.new_attribute = "Sorry i modified a dunder method"
def new_function(self, value):
return other_function(self.name)
def new_function2(value):
return value
def __call__(self, value):
return self.new_attribute
"""
original_code = """class NewClass:
def __init__(self, name):
self.name = name
def new_function(self, value):
return other_function(self.name)
def new_function2(value):
return value
def __call__(self, value):
return self.name
"""
function_names: list[str] = ["module.yet_another_function", "module.other_function"]
preexisting_functions: list[tuple[str, list[FunctionParent]]] = []
contextual_functions: set[tuple[str, str]] = set()
new_code: str = replace_functions_and_add_imports(
source_code=original_code,
function_names=function_names,
optimized_code=optim_code,
file_path_of_module_with_function_to_optimize=str(Path(__file__).resolve()),
module_abspath=str(Path(__file__).resolve()),
preexisting_functions=preexisting_functions,
contextual_functions=contextual_functions,
project_root_path=str(Path(__file__).resolve().parent.resolve()),
)
assert new_code == original_code