mirror of
https://github.com/codeflash-ai/codeflash.git
synced 2026-05-04 18:25:17 +00:00
Restore and add tests, small fix.
This commit is contained in:
parent
3052c6e294
commit
b0ba6b384b
2 changed files with 84 additions and 1 deletions
|
|
@ -57,7 +57,8 @@ class OptimFunctionCollector(cst.CSTVisitor):
|
|||
parents = [FunctionParent(name=node.name.value, type="ClassDef")]
|
||||
for child_node in node.body.body:
|
||||
if (
|
||||
isinstance(child_node, cst.FunctionDef)
|
||||
self.preexisting_functions
|
||||
and isinstance(child_node, cst.FunctionDef)
|
||||
and (node.name.value, child_node.name.value) not in self.contextual_functions
|
||||
and (child_node.name.value, parents) not in self.preexisting_functions
|
||||
):
|
||||
|
|
|
|||
|
|
@ -844,3 +844,85 @@ def test_code_replacement11() -> None:
|
|||
contextual_functions=contextual_functions,
|
||||
)
|
||||
assert new_code == expected_code
|
||||
|
||||
|
||||
def test_code_replacement12() -> None:
|
||||
optim_code = '''class Fu():
|
||||
def foo(self) -> dict[str, str]:
|
||||
payload: dict[str, str] = {"bar": self.bar(), "real_bar": str(self.real_bar() + 1)}
|
||||
return payload
|
||||
|
||||
def real_bar(self) -> int:
|
||||
"""No abstract nonsense"""
|
||||
pass
|
||||
'''
|
||||
original_code = '''class Fu():
|
||||
def foo(self) -> dict[str, str]:
|
||||
payload: dict[str, str] = {"bar": self.bar(), "real_bar": str(self.real_bar())}
|
||||
return payload
|
||||
|
||||
def real_bar(self) -> int:
|
||||
"""No abstract nonsense"""
|
||||
return 0
|
||||
'''
|
||||
expected_code = '''class Fu():
|
||||
def foo(self) -> dict[str, str]:
|
||||
payload: dict[str, str] = {"bar": self.bar(), "real_bar": str(self.real_bar())}
|
||||
return payload
|
||||
|
||||
def real_bar(self) -> int:
|
||||
"""No abstract nonsense"""
|
||||
pass
|
||||
'''
|
||||
|
||||
preexisting_functions: list[tuple[str, list[FunctionParent]]] = []
|
||||
contextual_functions: set[tuple[str, str]] = set()
|
||||
new_code: str = replace_functions_in_file(
|
||||
source_code=original_code,
|
||||
original_function_names=["Fu.real_bar"],
|
||||
optimized_code=optim_code,
|
||||
preexisting_functions=preexisting_functions,
|
||||
contextual_functions=contextual_functions,
|
||||
)
|
||||
assert new_code == expected_code
|
||||
|
||||
|
||||
def test_test_libcst_code_replacement13() -> None:
|
||||
# Test if the dunder method is not modified
|
||||
optim_code = """class NewClass:
|
||||
def __init__(self, name):
|
||||
self.name = name
|
||||
self.new_attribute = "Sorry i modified a dunder method"
|
||||
def new_function(self, value):
|
||||
return other_function(self.name)
|
||||
def new_function2(value):
|
||||
return value
|
||||
def __call__(self, value):
|
||||
return self.new_attribute
|
||||
"""
|
||||
|
||||
original_code = """class NewClass:
|
||||
def __init__(self, name):
|
||||
self.name = name
|
||||
def new_function(self, value):
|
||||
return other_function(self.name)
|
||||
def new_function2(value):
|
||||
return value
|
||||
def __call__(self, value):
|
||||
return self.name
|
||||
"""
|
||||
|
||||
function_names: list[str] = ["module.yet_another_function", "module.other_function"]
|
||||
preexisting_functions: list[tuple[str, list[FunctionParent]]] = []
|
||||
contextual_functions: set[tuple[str, str]] = set()
|
||||
new_code: str = replace_functions_and_add_imports(
|
||||
source_code=original_code,
|
||||
function_names=function_names,
|
||||
optimized_code=optim_code,
|
||||
file_path_of_module_with_function_to_optimize=str(Path(__file__).resolve()),
|
||||
module_abspath=str(Path(__file__).resolve()),
|
||||
preexisting_functions=preexisting_functions,
|
||||
contextual_functions=contextual_functions,
|
||||
project_root_path=str(Path(__file__).resolve().parent.resolve()),
|
||||
)
|
||||
assert new_code == original_code
|
||||
|
|
|
|||
Loading…
Reference in a new issue