Formatting clean-up.
This commit is contained in:
parent
f5101ffb6c
commit
eb5168a3f8
3 changed files with 80 additions and 63 deletions
|
|
@ -13,11 +13,11 @@ class OptimFunctionCollector(cst.CSTVisitor):
|
|||
METADATA_DEPENDENCIES = (cst.metadata.ParentNodeProvider,)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
function_name: str,
|
||||
class_name: str | None,
|
||||
contextual_functions: set[tuple[str, str]],
|
||||
preexisting_functions: list[tuple[str, list[FunctionParent]]] | None = None,
|
||||
self,
|
||||
function_name: str,
|
||||
class_name: str | None,
|
||||
contextual_functions: set[tuple[str, str]],
|
||||
preexisting_functions: list[tuple[str, list[FunctionParent]]] | None = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
if preexisting_functions is None:
|
||||
|
|
@ -43,32 +43,34 @@ class OptimFunctionCollector(cst.CSTVisitor):
|
|||
if node.name.value == self.function_name:
|
||||
self.optim_body = node
|
||||
elif (
|
||||
self.preexisting_functions
|
||||
and (node.name.value, []) not in self.preexisting_functions
|
||||
and (
|
||||
isinstance(parent, cst.Module)
|
||||
or (parent2 is not None and not isinstance(parent2, cst.ClassDef))
|
||||
)
|
||||
self.preexisting_functions
|
||||
and (node.name.value, []) not in self.preexisting_functions
|
||||
and (
|
||||
isinstance(parent, cst.Module)
|
||||
or (parent2 is not None and not isinstance(parent2, cst.ClassDef))
|
||||
)
|
||||
):
|
||||
self.optim_new_functions.append(node)
|
||||
|
||||
def visit_ClassDef_body(self, node: cst.ClassDef) -> None:
|
||||
parents = [FunctionParent(name=node.name.value, type="ClassDef")]
|
||||
for child_node in node.body.body:
|
||||
if isinstance(child_node, cst.FunctionDef) and (
|
||||
node.name.value, child_node.name.value) not in self.contextual_functions and (
|
||||
child_node.name.value, parents) not in self.preexisting_functions:
|
||||
if (
|
||||
isinstance(child_node, cst.FunctionDef)
|
||||
and (node.name.value, child_node.name.value) not in self.contextual_functions
|
||||
and (child_node.name.value, parents) not in self.preexisting_functions
|
||||
):
|
||||
self.optim_new_class_functions.append(child_node)
|
||||
|
||||
|
||||
class OptimFunctionReplacer(cst.CSTTransformer):
|
||||
def __init__(
|
||||
self,
|
||||
function_name: str,
|
||||
optim_body: cst.FunctionDef,
|
||||
optim_new_class_functions: list[cst.FunctionDef],
|
||||
optim_new_functions: list[cst.FunctionDef],
|
||||
class_name: str | None = None,
|
||||
self,
|
||||
function_name: str,
|
||||
optim_body: cst.FunctionDef,
|
||||
optim_new_class_functions: list[cst.FunctionDef],
|
||||
optim_new_functions: list[cst.FunctionDef],
|
||||
class_name: str | None = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.function_name = function_name
|
||||
|
|
@ -83,12 +85,12 @@ class OptimFunctionReplacer(cst.CSTTransformer):
|
|||
return False
|
||||
|
||||
def leave_FunctionDef(
|
||||
self,
|
||||
original_node: cst.FunctionDef,
|
||||
updated_node: cst.FunctionDef,
|
||||
self,
|
||||
original_node: cst.FunctionDef,
|
||||
updated_node: cst.FunctionDef,
|
||||
) -> cst.FunctionDef:
|
||||
if original_node.name.value == self.function_name and (
|
||||
self.depth == 0 or (self.depth == 1 and self.in_class)
|
||||
self.depth == 0 or (self.depth == 1 and self.in_class)
|
||||
):
|
||||
return self.optim_body
|
||||
return updated_node
|
||||
|
|
@ -101,9 +103,9 @@ class OptimFunctionReplacer(cst.CSTTransformer):
|
|||
return self.in_class
|
||||
|
||||
def leave_ClassDef(
|
||||
self,
|
||||
original_node: cst.ClassDef,
|
||||
updated_node: cst.ClassDef,
|
||||
self,
|
||||
original_node: cst.ClassDef,
|
||||
updated_node: cst.ClassDef,
|
||||
) -> cst.ClassDef:
|
||||
self.depth -= 1
|
||||
if self.in_class and (self.depth == 0) and (original_node.name.value == self.class_name):
|
||||
|
|
@ -129,7 +131,7 @@ class OptimFunctionReplacer(cst.CSTTransformer):
|
|||
body=(
|
||||
*node.body[: max_function_index + 1],
|
||||
*self.optim_new_functions,
|
||||
*node.body[max_function_index + 1:],
|
||||
*node.body[max_function_index + 1 :],
|
||||
),
|
||||
)
|
||||
elif class_index is not None:
|
||||
|
|
@ -137,7 +139,7 @@ class OptimFunctionReplacer(cst.CSTTransformer):
|
|||
body=(
|
||||
*node.body[: class_index + 1],
|
||||
*self.optim_new_functions,
|
||||
*node.body[class_index + 1:],
|
||||
*node.body[class_index + 1 :],
|
||||
),
|
||||
)
|
||||
else:
|
||||
|
|
@ -146,11 +148,11 @@ class OptimFunctionReplacer(cst.CSTTransformer):
|
|||
|
||||
|
||||
def replace_functions_in_file(
|
||||
source_code: str,
|
||||
original_function_names: list[str],
|
||||
optimized_code: str,
|
||||
preexisting_functions: list[tuple[str, list[FunctionParent]]],
|
||||
contextual_functions: set[tuple[str, str]],
|
||||
source_code: str,
|
||||
original_function_names: list[str],
|
||||
optimized_code: str,
|
||||
preexisting_functions: list[tuple[str, list[FunctionParent]]],
|
||||
contextual_functions: set[tuple[str, str]],
|
||||
) -> str:
|
||||
parsed_function_names = []
|
||||
for original_function_name in original_function_names:
|
||||
|
|
@ -193,14 +195,14 @@ def replace_functions_in_file(
|
|||
|
||||
|
||||
def replace_functions_and_add_imports(
|
||||
source_code: str,
|
||||
function_names: list[str],
|
||||
optimized_code: str,
|
||||
file_path_of_module_with_function_to_optimize: str,
|
||||
module_abspath: str,
|
||||
preexisting_functions: list[tuple[str, list[FunctionParent]]],
|
||||
contextual_functions: set[tuple[str, str]],
|
||||
project_root_path: str,
|
||||
source_code: str,
|
||||
function_names: list[str],
|
||||
optimized_code: str,
|
||||
file_path_of_module_with_function_to_optimize: str,
|
||||
module_abspath: str,
|
||||
preexisting_functions: list[tuple[str, list[FunctionParent]]],
|
||||
contextual_functions: set[tuple[str, str]],
|
||||
project_root_path: str,
|
||||
) -> str:
|
||||
return add_needed_imports_from_module(
|
||||
optimized_code,
|
||||
|
|
@ -218,13 +220,13 @@ def replace_functions_and_add_imports(
|
|||
|
||||
|
||||
def replace_function_definitions_in_module(
|
||||
function_names: list[str],
|
||||
optimized_code: str,
|
||||
file_path_of_module_with_function_to_optimize: str,
|
||||
module_abspath: str,
|
||||
preexisting_functions: list[tuple[str, list[FunctionParent]]],
|
||||
contextual_functions: set[tuple[str, str]],
|
||||
project_root_path: str,
|
||||
function_names: list[str],
|
||||
optimized_code: str,
|
||||
file_path_of_module_with_function_to_optimize: str,
|
||||
module_abspath: str,
|
||||
preexisting_functions: list[tuple[str, list[FunctionParent]]],
|
||||
contextual_functions: set[tuple[str, str]],
|
||||
project_root_path: str,
|
||||
) -> None:
|
||||
file: IO[str]
|
||||
with open(module_abspath, encoding="utf8") as file:
|
||||
|
|
|
|||
|
|
@ -420,7 +420,7 @@ class Optimizer:
|
|||
)
|
||||
if not did_update:
|
||||
logging.warning(
|
||||
"No functions were replaced in the optimized code. Skipping optimization candidate."
|
||||
"No functions were replaced in the optimized code. Skipping optimization candidate.",
|
||||
)
|
||||
continue
|
||||
except (
|
||||
|
|
@ -630,11 +630,10 @@ class Optimizer:
|
|||
if code_to_optimize is None:
|
||||
return Failure("Could not find function to optimize.")
|
||||
preexisting_functions: list[tuple[str, list[FunctionParent]]] = [
|
||||
(name, [
|
||||
FunctionParent(
|
||||
name=class_name, type="ClassDef")]) for class_name, name in contextual_dunder_methods]
|
||||
preexisting_functions.append(
|
||||
(function_to_optimize.function_name, function_to_optimize.parents))
|
||||
(name, [FunctionParent(name=class_name, type="ClassDef")])
|
||||
for class_name, name in contextual_dunder_methods
|
||||
]
|
||||
preexisting_functions.append((function_to_optimize.function_name, function_to_optimize.parents))
|
||||
(
|
||||
helper_code,
|
||||
helper_functions,
|
||||
|
|
@ -676,9 +675,14 @@ class Optimizer:
|
|||
function_to_optimize.file_path,
|
||||
project_root,
|
||||
)
|
||||
preexisting_functions.extend([(qualified_name_list[-1], ([FunctionParent(name=qualified_name_list[-2], type="ClassDef")])) if len(
|
||||
qualified_name_list := fn[0].full_name.split(".")) > 1 else (
|
||||
qualified_name_list[-1], []) for fn in helper_functions])
|
||||
preexisting_functions.extend(
|
||||
[
|
||||
(qualified_name_list[-1], ([FunctionParent(name=qualified_name_list[-2], type="ClassDef")]))
|
||||
if len(qualified_name_list := fn[0].full_name.split(".")) > 1
|
||||
else (qualified_name_list[-1], [])
|
||||
for fn in helper_functions
|
||||
],
|
||||
)
|
||||
contextual_dunder_methods.update(helper_dunder_methods)
|
||||
return Success(
|
||||
CodeOptimizationContext(
|
||||
|
|
|
|||
|
|
@ -51,7 +51,8 @@ print("Hello world")
|
|||
|
||||
function_name: str = "NewClass.new_function"
|
||||
preexisting_functions: list[tuple[str, list[FunctionParent]]] = [
|
||||
("new_function", [FunctionParent(name="NewClass", type="ClassDef")])]
|
||||
("new_function", [FunctionParent(name="NewClass", type="ClassDef")]),
|
||||
]
|
||||
contextual_functions: set[tuple[str, str]] = {("NewClass", "__init__")}
|
||||
new_code: str = replace_functions_and_add_imports(
|
||||
source_code=original_code,
|
||||
|
|
@ -112,7 +113,10 @@ print("Hello world")
|
|||
"""
|
||||
|
||||
function_name: str = "NewClass.new_function"
|
||||
preexisting_functions: list[tuple[str, list[FunctionParent]]] = [("new_function", []), ("other_function", [])]
|
||||
preexisting_functions: list[tuple[str, list[FunctionParent]]] = [
|
||||
("new_function", []),
|
||||
("other_function", []),
|
||||
]
|
||||
contextual_functions: set[tuple[str, str]] = {("NewClass", "__init__")}
|
||||
new_code: str = replace_functions_and_add_imports(
|
||||
source_code=original_code,
|
||||
|
|
@ -575,7 +579,10 @@ class CacheConfig(BaseConfig):
|
|||
"""
|
||||
function_names: list[str] = ["CacheSimilarityEvalConfig.from_config"]
|
||||
parents = [FunctionParent(name="CacheConfig", type="ClassDef")]
|
||||
preexisting_functions: list[tuple[str, list[FunctionParent]]] = [("__init__", parents), ("from_config", parents)]
|
||||
preexisting_functions: list[tuple[str, list[FunctionParent]]] = [
|
||||
("__init__", parents),
|
||||
("from_config", parents),
|
||||
]
|
||||
|
||||
contextual_functions: set[tuple[str, str]] = {
|
||||
("CacheSimilarityEvalConfig", "__init__"),
|
||||
|
|
@ -653,7 +660,8 @@ def test_test_libcst_code_replacement8() -> None:
|
|||
'''
|
||||
function_names: list[str] = ["_EmbeddingDistanceChainMixin._hamming_distance"]
|
||||
preexisting_functions: list[tuple[str, list[FunctionParent]]] = [
|
||||
("_hamming_distance", [FunctionParent("_EmbeddingDistanceChainMixin", "ClassDef")])]
|
||||
("_hamming_distance", [FunctionParent("_EmbeddingDistanceChainMixin", "ClassDef")]),
|
||||
]
|
||||
contextual_functions: set[tuple[str, str]] = set()
|
||||
new_code: str = replace_functions_and_add_imports(
|
||||
source_code=original_code,
|
||||
|
|
@ -710,7 +718,10 @@ print("Hello world")
|
|||
"""
|
||||
parents = [FunctionParent(name="NewClass", type="ClassDef")]
|
||||
function_name: str = "NewClass.__init__"
|
||||
preexisting_functions: list[tuple[str, list[FunctionParent]]] = [("__init__", parents), ("__call__", parents)]
|
||||
preexisting_functions: list[tuple[str, list[FunctionParent]]] = [
|
||||
("__init__", parents),
|
||||
("__call__", parents),
|
||||
]
|
||||
contextual_functions: set[tuple[str, str]] = {
|
||||
("NewClass", "__init__"),
|
||||
("NewClass", "__call__"),
|
||||
|
|
|
|||
Loading…
Reference in a new issue