Change name of unmodifiable context, add test for dunder method to optimize.

This commit is contained in:
renaud 2024-04-09 18:54:55 -07:00
parent 06e0fe322e
commit 5557c0b13e
2 changed files with 79 additions and 21 deletions

View file

@ -13,7 +13,7 @@ class OptimFunctionCollector(cst.CSTVisitor):
self,
function_name: str,
class_name: str | None,
immutable_methods: set[tuple[str, str]],
contextual_functions: set[tuple[str, str]],
preexisting_functions: list[str] | None = None,
) -> None:
super().__init__()
@ -26,7 +26,9 @@ class OptimFunctionCollector(cst.CSTVisitor):
self.optim_new_functions: list[cst.FunctionDef] = []
self.optim_imports: list[cst.SimpleStatementLine] = []
self.preexisting_functions = preexisting_functions
self.immutable_methods = immutable_methods.union({(self.class_name, self.function_name)})
self.contextual_functions = contextual_functions.union(
{(self.class_name, self.function_name)},
)
def visit_FunctionDef(self, node: cst.FunctionDef) -> None:
parent = self.get_metadata(cst.metadata.ParentNodeProvider, node)
@ -56,7 +58,7 @@ class OptimFunctionCollector(cst.CSTVisitor):
node.name.value,
child_node.name.value,
)
not in self.immutable_methods
not in self.contextual_functions
):
self.optim_new_class_functions.append(child_node)
@ -181,7 +183,7 @@ def replace_functions_in_file(
original_function_names: list[str],
optimized_code: str,
preexisting_functions: list[str],
immutable_methods: set[tuple[str, str]],
contextual_functions: set[tuple[str, str]],
) -> str:
parsed_function_names = []
for original_function_name in original_function_names:
@ -199,7 +201,7 @@ def replace_functions_in_file(
visitor = OptimFunctionCollector(
function_name,
class_name,
immutable_methods,
contextual_functions,
preexisting_functions,
)
module.visit(visitor)
@ -230,7 +232,7 @@ def replace_function_definitions_in_module(
optimized_code: str,
module_abspath: str,
preexisting_functions: list[str],
immutable_methods: set[tuple[str, str]],
contextual_functions: set[tuple[str, str]],
) -> None:
file: IO[str]
with open(module_abspath, encoding="utf8") as file:
@ -240,7 +242,7 @@ def replace_function_definitions_in_module(
function_names,
optimized_code,
preexisting_functions,
immutable_methods,
contextual_functions,
)
with open(module_abspath, "w", encoding="utf8") as file:
file.write(new_code)

View file

@ -49,13 +49,13 @@ print("Hello world")
function_name: str = "NewClass.new_function"
preexisting_functions: list[str] = ["new_function"]
immutable_methods: set[tuple[str, str]] = {("NewClass", "__init__")}
contextual_functions: set[tuple[str, str]] = {("NewClass", "__init__")}
new_code: str = replace_functions_in_file(
original_code,
[function_name],
optim_code,
preexisting_functions,
immutable_methods,
contextual_functions,
)
assert new_code == expected
@ -109,13 +109,13 @@ print("Hello world")
function_name: str = "NewClass.new_function"
preexisting_functions: list[str] = ["new_function", "other_function"]
immutable_methods: set[tuple[str, str]] = {("NewClass", "__init__")}
contextual_functions: set[tuple[str, str]] = {("NewClass", "__init__")}
new_code: str = replace_functions_in_file(
original_code,
[function_name],
optim_code,
preexisting_functions,
immutable_methods,
contextual_functions,
)
assert new_code == expected
@ -170,13 +170,13 @@ print("Salut monde")
function_names: list[str] = ["module.other_function"]
preexisting_functions: list[str] = []
immutable_methods: set[tuple[str, str]] = set()
contextual_functions: set[tuple[str, str]] = set()
new_code: str = replace_functions_in_file(
original_code,
function_names,
optim_code,
preexisting_functions,
immutable_methods,
contextual_functions,
)
assert new_code == expected
@ -234,13 +234,13 @@ print("Salut monde")
function_names: list[str] = ["module.yet_another_function", "module.other_function"]
preexisting_functions: list[str] = []
immutable_methods: set[tuple[str, str]] = set()
contextual_functions: set[tuple[str, str]] = set()
new_code: str = replace_functions_in_file(
original_code,
function_names,
optim_code,
preexisting_functions,
immutable_methods,
contextual_functions,
)
assert new_code == expected
@ -284,13 +284,13 @@ def supersort(doink):
function_names: list[str] = ["sorter_deps"]
preexisting_functions: list[str] = ["sorter_deps"]
immutable_methods: set[tuple[str, str]] = set()
contextual_functions: set[tuple[str, str]] = set()
new_code: str = replace_functions_in_file(
original_code,
function_names,
optim_code,
preexisting_functions,
immutable_methods,
contextual_functions,
)
assert new_code == expected
@ -567,7 +567,7 @@ class CacheConfig(BaseConfig):
"__init__",
"from_config",
]
immutable_methods: set[tuple[str, str]] = {
contextual_functions: set[tuple[str, str]] = {
("CacheSimilarityEvalConfig", "__init__"),
("CacheConfig", "__init__"),
("CacheInitConfig", "__init__"),
@ -577,7 +577,7 @@ class CacheConfig(BaseConfig):
function_names,
optim_code,
preexisting_functions,
immutable_methods,
contextual_functions,
)
assert new_code == expected
@ -642,12 +642,68 @@ def test_test_libcst_code_replacement8() -> None:
preexisting_functions: list[str] = [
"_hamming_distance",
]
immutable_methods: set[tuple[str, str]] = set()
contextual_functions: set[tuple[str, str]] = set()
new_code: str = replace_functions_in_file(
original_code,
function_names,
optim_code,
preexisting_functions,
immutable_methods,
contextual_functions,
)
assert new_code == expected
def test_test_libcst_code_replacement9() -> None:
optim_code = """import libcst as cst
from typing import Optional
def totally_new_function(value):
return value
class NewClass:
def __init__(self, name):
self.name = str(name)
def __call__(self, value):
return self.name
def new_function2(value):
return value
"""
original_code = """class NewClass:
def __init__(self, name):
self.name = name
def __call__(self, value):
return "I am still old"
print("Hello world")
"""
expected = """import libcst as cst
from typing import Optional
class NewClass:
def __init__(self, name):
self.name = str(name)
def __call__(self, value):
return "I am still old"
def new_function2(value):
return value
def totally_new_function(value):
return value
print("Hello world")
"""
function_name: str = "NewClass.__init__"
preexisting_functions: list[str] = ["__init__", "__call__"]
contextual_functions: set[tuple[str, str]] = {
("NewClass", "__init__"),
("NewClass", "__call__"),
}
new_code: str = replace_functions_in_file(
original_code,
[function_name],
optim_code,
preexisting_functions,
contextual_functions,
)
assert new_code == expected