mirror of
https://github.com/codeflash-ai/codeflash.git
synced 2026-05-04 18:25:17 +00:00
Change name of unmodifiable context, add test for dunder method to optimize.
This commit is contained in:
parent
06e0fe322e
commit
5557c0b13e
2 changed files with 79 additions and 21 deletions
|
|
@ -13,7 +13,7 @@ class OptimFunctionCollector(cst.CSTVisitor):
|
|||
self,
|
||||
function_name: str,
|
||||
class_name: str | None,
|
||||
immutable_methods: set[tuple[str, str]],
|
||||
contextual_functions: set[tuple[str, str]],
|
||||
preexisting_functions: list[str] | None = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
|
@ -26,7 +26,9 @@ class OptimFunctionCollector(cst.CSTVisitor):
|
|||
self.optim_new_functions: list[cst.FunctionDef] = []
|
||||
self.optim_imports: list[cst.SimpleStatementLine] = []
|
||||
self.preexisting_functions = preexisting_functions
|
||||
self.immutable_methods = immutable_methods.union({(self.class_name, self.function_name)})
|
||||
self.contextual_functions = contextual_functions.union(
|
||||
{(self.class_name, self.function_name)},
|
||||
)
|
||||
|
||||
def visit_FunctionDef(self, node: cst.FunctionDef) -> None:
|
||||
parent = self.get_metadata(cst.metadata.ParentNodeProvider, node)
|
||||
|
|
@ -56,7 +58,7 @@ class OptimFunctionCollector(cst.CSTVisitor):
|
|||
node.name.value,
|
||||
child_node.name.value,
|
||||
)
|
||||
not in self.immutable_methods
|
||||
not in self.contextual_functions
|
||||
):
|
||||
self.optim_new_class_functions.append(child_node)
|
||||
|
||||
|
|
@ -181,7 +183,7 @@ def replace_functions_in_file(
|
|||
original_function_names: list[str],
|
||||
optimized_code: str,
|
||||
preexisting_functions: list[str],
|
||||
immutable_methods: set[tuple[str, str]],
|
||||
contextual_functions: set[tuple[str, str]],
|
||||
) -> str:
|
||||
parsed_function_names = []
|
||||
for original_function_name in original_function_names:
|
||||
|
|
@ -199,7 +201,7 @@ def replace_functions_in_file(
|
|||
visitor = OptimFunctionCollector(
|
||||
function_name,
|
||||
class_name,
|
||||
immutable_methods,
|
||||
contextual_functions,
|
||||
preexisting_functions,
|
||||
)
|
||||
module.visit(visitor)
|
||||
|
|
@ -230,7 +232,7 @@ def replace_function_definitions_in_module(
|
|||
optimized_code: str,
|
||||
module_abspath: str,
|
||||
preexisting_functions: list[str],
|
||||
immutable_methods: set[tuple[str, str]],
|
||||
contextual_functions: set[tuple[str, str]],
|
||||
) -> None:
|
||||
file: IO[str]
|
||||
with open(module_abspath, encoding="utf8") as file:
|
||||
|
|
@ -240,7 +242,7 @@ def replace_function_definitions_in_module(
|
|||
function_names,
|
||||
optimized_code,
|
||||
preexisting_functions,
|
||||
immutable_methods,
|
||||
contextual_functions,
|
||||
)
|
||||
with open(module_abspath, "w", encoding="utf8") as file:
|
||||
file.write(new_code)
|
||||
|
|
|
|||
|
|
@ -49,13 +49,13 @@ print("Hello world")
|
|||
|
||||
function_name: str = "NewClass.new_function"
|
||||
preexisting_functions: list[str] = ["new_function"]
|
||||
immutable_methods: set[tuple[str, str]] = {("NewClass", "__init__")}
|
||||
contextual_functions: set[tuple[str, str]] = {("NewClass", "__init__")}
|
||||
new_code: str = replace_functions_in_file(
|
||||
original_code,
|
||||
[function_name],
|
||||
optim_code,
|
||||
preexisting_functions,
|
||||
immutable_methods,
|
||||
contextual_functions,
|
||||
)
|
||||
assert new_code == expected
|
||||
|
||||
|
|
@ -109,13 +109,13 @@ print("Hello world")
|
|||
|
||||
function_name: str = "NewClass.new_function"
|
||||
preexisting_functions: list[str] = ["new_function", "other_function"]
|
||||
immutable_methods: set[tuple[str, str]] = {("NewClass", "__init__")}
|
||||
contextual_functions: set[tuple[str, str]] = {("NewClass", "__init__")}
|
||||
new_code: str = replace_functions_in_file(
|
||||
original_code,
|
||||
[function_name],
|
||||
optim_code,
|
||||
preexisting_functions,
|
||||
immutable_methods,
|
||||
contextual_functions,
|
||||
)
|
||||
assert new_code == expected
|
||||
|
||||
|
|
@ -170,13 +170,13 @@ print("Salut monde")
|
|||
|
||||
function_names: list[str] = ["module.other_function"]
|
||||
preexisting_functions: list[str] = []
|
||||
immutable_methods: set[tuple[str, str]] = set()
|
||||
contextual_functions: set[tuple[str, str]] = set()
|
||||
new_code: str = replace_functions_in_file(
|
||||
original_code,
|
||||
function_names,
|
||||
optim_code,
|
||||
preexisting_functions,
|
||||
immutable_methods,
|
||||
contextual_functions,
|
||||
)
|
||||
assert new_code == expected
|
||||
|
||||
|
|
@ -234,13 +234,13 @@ print("Salut monde")
|
|||
|
||||
function_names: list[str] = ["module.yet_another_function", "module.other_function"]
|
||||
preexisting_functions: list[str] = []
|
||||
immutable_methods: set[tuple[str, str]] = set()
|
||||
contextual_functions: set[tuple[str, str]] = set()
|
||||
new_code: str = replace_functions_in_file(
|
||||
original_code,
|
||||
function_names,
|
||||
optim_code,
|
||||
preexisting_functions,
|
||||
immutable_methods,
|
||||
contextual_functions,
|
||||
)
|
||||
assert new_code == expected
|
||||
|
||||
|
|
@ -284,13 +284,13 @@ def supersort(doink):
|
|||
|
||||
function_names: list[str] = ["sorter_deps"]
|
||||
preexisting_functions: list[str] = ["sorter_deps"]
|
||||
immutable_methods: set[tuple[str, str]] = set()
|
||||
contextual_functions: set[tuple[str, str]] = set()
|
||||
new_code: str = replace_functions_in_file(
|
||||
original_code,
|
||||
function_names,
|
||||
optim_code,
|
||||
preexisting_functions,
|
||||
immutable_methods,
|
||||
contextual_functions,
|
||||
)
|
||||
assert new_code == expected
|
||||
|
||||
|
|
@ -567,7 +567,7 @@ class CacheConfig(BaseConfig):
|
|||
"__init__",
|
||||
"from_config",
|
||||
]
|
||||
immutable_methods: set[tuple[str, str]] = {
|
||||
contextual_functions: set[tuple[str, str]] = {
|
||||
("CacheSimilarityEvalConfig", "__init__"),
|
||||
("CacheConfig", "__init__"),
|
||||
("CacheInitConfig", "__init__"),
|
||||
|
|
@ -577,7 +577,7 @@ class CacheConfig(BaseConfig):
|
|||
function_names,
|
||||
optim_code,
|
||||
preexisting_functions,
|
||||
immutable_methods,
|
||||
contextual_functions,
|
||||
)
|
||||
assert new_code == expected
|
||||
|
||||
|
|
@ -642,12 +642,68 @@ def test_test_libcst_code_replacement8() -> None:
|
|||
preexisting_functions: list[str] = [
|
||||
"_hamming_distance",
|
||||
]
|
||||
immutable_methods: set[tuple[str, str]] = set()
|
||||
contextual_functions: set[tuple[str, str]] = set()
|
||||
new_code: str = replace_functions_in_file(
|
||||
original_code,
|
||||
function_names,
|
||||
optim_code,
|
||||
preexisting_functions,
|
||||
immutable_methods,
|
||||
contextual_functions,
|
||||
)
|
||||
assert new_code == expected
|
||||
|
||||
|
||||
def test_test_libcst_code_replacement9() -> None:
|
||||
optim_code = """import libcst as cst
|
||||
from typing import Optional
|
||||
|
||||
def totally_new_function(value):
|
||||
return value
|
||||
|
||||
class NewClass:
|
||||
def __init__(self, name):
|
||||
self.name = str(name)
|
||||
def __call__(self, value):
|
||||
return self.name
|
||||
def new_function2(value):
|
||||
return value
|
||||
"""
|
||||
|
||||
original_code = """class NewClass:
|
||||
def __init__(self, name):
|
||||
self.name = name
|
||||
def __call__(self, value):
|
||||
return "I am still old"
|
||||
|
||||
print("Hello world")
|
||||
"""
|
||||
expected = """import libcst as cst
|
||||
from typing import Optional
|
||||
class NewClass:
|
||||
def __init__(self, name):
|
||||
self.name = str(name)
|
||||
def __call__(self, value):
|
||||
return "I am still old"
|
||||
def new_function2(value):
|
||||
return value
|
||||
|
||||
def totally_new_function(value):
|
||||
return value
|
||||
|
||||
print("Hello world")
|
||||
"""
|
||||
|
||||
function_name: str = "NewClass.__init__"
|
||||
preexisting_functions: list[str] = ["__init__", "__call__"]
|
||||
contextual_functions: set[tuple[str, str]] = {
|
||||
("NewClass", "__init__"),
|
||||
("NewClass", "__call__"),
|
||||
}
|
||||
new_code: str = replace_functions_in_file(
|
||||
original_code,
|
||||
[function_name],
|
||||
optim_code,
|
||||
preexisting_functions,
|
||||
contextual_functions,
|
||||
)
|
||||
assert new_code == expected
|
||||
|
|
|
|||
Loading…
Reference in a new issue