Formatting clean-up.

This commit is contained in:
RD 2024-06-21 17:43:05 -07:00
parent f5101ffb6c
commit eb5168a3f8
3 changed files with 80 additions and 63 deletions

View file

@ -13,11 +13,11 @@ class OptimFunctionCollector(cst.CSTVisitor):
METADATA_DEPENDENCIES = (cst.metadata.ParentNodeProvider,)
def __init__(
self,
function_name: str,
class_name: str | None,
contextual_functions: set[tuple[str, str]],
preexisting_functions: list[tuple[str, list[FunctionParent]]] | None = None,
self,
function_name: str,
class_name: str | None,
contextual_functions: set[tuple[str, str]],
preexisting_functions: list[tuple[str, list[FunctionParent]]] | None = None,
) -> None:
super().__init__()
if preexisting_functions is None:
@ -43,32 +43,34 @@ class OptimFunctionCollector(cst.CSTVisitor):
if node.name.value == self.function_name:
self.optim_body = node
elif (
self.preexisting_functions
and (node.name.value, []) not in self.preexisting_functions
and (
isinstance(parent, cst.Module)
or (parent2 is not None and not isinstance(parent2, cst.ClassDef))
)
self.preexisting_functions
and (node.name.value, []) not in self.preexisting_functions
and (
isinstance(parent, cst.Module)
or (parent2 is not None and not isinstance(parent2, cst.ClassDef))
)
):
self.optim_new_functions.append(node)
def visit_ClassDef_body(self, node: cst.ClassDef) -> None:
parents = [FunctionParent(name=node.name.value, type="ClassDef")]
for child_node in node.body.body:
if isinstance(child_node, cst.FunctionDef) and (
node.name.value, child_node.name.value) not in self.contextual_functions and (
child_node.name.value, parents) not in self.preexisting_functions:
if (
isinstance(child_node, cst.FunctionDef)
and (node.name.value, child_node.name.value) not in self.contextual_functions
and (child_node.name.value, parents) not in self.preexisting_functions
):
self.optim_new_class_functions.append(child_node)
class OptimFunctionReplacer(cst.CSTTransformer):
def __init__(
self,
function_name: str,
optim_body: cst.FunctionDef,
optim_new_class_functions: list[cst.FunctionDef],
optim_new_functions: list[cst.FunctionDef],
class_name: str | None = None,
self,
function_name: str,
optim_body: cst.FunctionDef,
optim_new_class_functions: list[cst.FunctionDef],
optim_new_functions: list[cst.FunctionDef],
class_name: str | None = None,
) -> None:
super().__init__()
self.function_name = function_name
@ -83,12 +85,12 @@ class OptimFunctionReplacer(cst.CSTTransformer):
return False
def leave_FunctionDef(
self,
original_node: cst.FunctionDef,
updated_node: cst.FunctionDef,
self,
original_node: cst.FunctionDef,
updated_node: cst.FunctionDef,
) -> cst.FunctionDef:
if original_node.name.value == self.function_name and (
self.depth == 0 or (self.depth == 1 and self.in_class)
self.depth == 0 or (self.depth == 1 and self.in_class)
):
return self.optim_body
return updated_node
@ -101,9 +103,9 @@ class OptimFunctionReplacer(cst.CSTTransformer):
return self.in_class
def leave_ClassDef(
self,
original_node: cst.ClassDef,
updated_node: cst.ClassDef,
self,
original_node: cst.ClassDef,
updated_node: cst.ClassDef,
) -> cst.ClassDef:
self.depth -= 1
if self.in_class and (self.depth == 0) and (original_node.name.value == self.class_name):
@ -129,7 +131,7 @@ class OptimFunctionReplacer(cst.CSTTransformer):
body=(
*node.body[: max_function_index + 1],
*self.optim_new_functions,
*node.body[max_function_index + 1:],
*node.body[max_function_index + 1 :],
),
)
elif class_index is not None:
@ -137,7 +139,7 @@ class OptimFunctionReplacer(cst.CSTTransformer):
body=(
*node.body[: class_index + 1],
*self.optim_new_functions,
*node.body[class_index + 1:],
*node.body[class_index + 1 :],
),
)
else:
@ -146,11 +148,11 @@ class OptimFunctionReplacer(cst.CSTTransformer):
def replace_functions_in_file(
source_code: str,
original_function_names: list[str],
optimized_code: str,
preexisting_functions: list[tuple[str, list[FunctionParent]]],
contextual_functions: set[tuple[str, str]],
source_code: str,
original_function_names: list[str],
optimized_code: str,
preexisting_functions: list[tuple[str, list[FunctionParent]]],
contextual_functions: set[tuple[str, str]],
) -> str:
parsed_function_names = []
for original_function_name in original_function_names:
@ -193,14 +195,14 @@ def replace_functions_in_file(
def replace_functions_and_add_imports(
source_code: str,
function_names: list[str],
optimized_code: str,
file_path_of_module_with_function_to_optimize: str,
module_abspath: str,
preexisting_functions: list[tuple[str, list[FunctionParent]]],
contextual_functions: set[tuple[str, str]],
project_root_path: str,
source_code: str,
function_names: list[str],
optimized_code: str,
file_path_of_module_with_function_to_optimize: str,
module_abspath: str,
preexisting_functions: list[tuple[str, list[FunctionParent]]],
contextual_functions: set[tuple[str, str]],
project_root_path: str,
) -> str:
return add_needed_imports_from_module(
optimized_code,
@ -218,13 +220,13 @@ def replace_functions_and_add_imports(
def replace_function_definitions_in_module(
function_names: list[str],
optimized_code: str,
file_path_of_module_with_function_to_optimize: str,
module_abspath: str,
preexisting_functions: list[tuple[str, list[FunctionParent]]],
contextual_functions: set[tuple[str, str]],
project_root_path: str,
function_names: list[str],
optimized_code: str,
file_path_of_module_with_function_to_optimize: str,
module_abspath: str,
preexisting_functions: list[tuple[str, list[FunctionParent]]],
contextual_functions: set[tuple[str, str]],
project_root_path: str,
) -> None:
file: IO[str]
with open(module_abspath, encoding="utf8") as file:

View file

@ -420,7 +420,7 @@ class Optimizer:
)
if not did_update:
logging.warning(
"No functions were replaced in the optimized code. Skipping optimization candidate."
"No functions were replaced in the optimized code. Skipping optimization candidate.",
)
continue
except (
@ -630,11 +630,10 @@ class Optimizer:
if code_to_optimize is None:
return Failure("Could not find function to optimize.")
preexisting_functions: list[tuple[str, list[FunctionParent]]] = [
(name, [
FunctionParent(
name=class_name, type="ClassDef")]) for class_name, name in contextual_dunder_methods]
preexisting_functions.append(
(function_to_optimize.function_name, function_to_optimize.parents))
(name, [FunctionParent(name=class_name, type="ClassDef")])
for class_name, name in contextual_dunder_methods
]
preexisting_functions.append((function_to_optimize.function_name, function_to_optimize.parents))
(
helper_code,
helper_functions,
@ -676,9 +675,14 @@ class Optimizer:
function_to_optimize.file_path,
project_root,
)
preexisting_functions.extend([(qualified_name_list[-1], ([FunctionParent(name=qualified_name_list[-2], type="ClassDef")])) if len(
qualified_name_list := fn[0].full_name.split(".")) > 1 else (
qualified_name_list[-1], []) for fn in helper_functions])
preexisting_functions.extend(
[
(qualified_name_list[-1], ([FunctionParent(name=qualified_name_list[-2], type="ClassDef")]))
if len(qualified_name_list := fn[0].full_name.split(".")) > 1
else (qualified_name_list[-1], [])
for fn in helper_functions
],
)
contextual_dunder_methods.update(helper_dunder_methods)
return Success(
CodeOptimizationContext(

View file

@ -51,7 +51,8 @@ print("Hello world")
function_name: str = "NewClass.new_function"
preexisting_functions: list[tuple[str, list[FunctionParent]]] = [
("new_function", [FunctionParent(name="NewClass", type="ClassDef")])]
("new_function", [FunctionParent(name="NewClass", type="ClassDef")]),
]
contextual_functions: set[tuple[str, str]] = {("NewClass", "__init__")}
new_code: str = replace_functions_and_add_imports(
source_code=original_code,
@ -112,7 +113,10 @@ print("Hello world")
"""
function_name: str = "NewClass.new_function"
preexisting_functions: list[tuple[str, list[FunctionParent]]] = [("new_function", []), ("other_function", [])]
preexisting_functions: list[tuple[str, list[FunctionParent]]] = [
("new_function", []),
("other_function", []),
]
contextual_functions: set[tuple[str, str]] = {("NewClass", "__init__")}
new_code: str = replace_functions_and_add_imports(
source_code=original_code,
@ -575,7 +579,10 @@ class CacheConfig(BaseConfig):
"""
function_names: list[str] = ["CacheSimilarityEvalConfig.from_config"]
parents = [FunctionParent(name="CacheConfig", type="ClassDef")]
preexisting_functions: list[tuple[str, list[FunctionParent]]] = [("__init__", parents), ("from_config", parents)]
preexisting_functions: list[tuple[str, list[FunctionParent]]] = [
("__init__", parents),
("from_config", parents),
]
contextual_functions: set[tuple[str, str]] = {
("CacheSimilarityEvalConfig", "__init__"),
@ -653,7 +660,8 @@ def test_test_libcst_code_replacement8() -> None:
'''
function_names: list[str] = ["_EmbeddingDistanceChainMixin._hamming_distance"]
preexisting_functions: list[tuple[str, list[FunctionParent]]] = [
("_hamming_distance", [FunctionParent("_EmbeddingDistanceChainMixin", "ClassDef")])]
("_hamming_distance", [FunctionParent("_EmbeddingDistanceChainMixin", "ClassDef")]),
]
contextual_functions: set[tuple[str, str]] = set()
new_code: str = replace_functions_and_add_imports(
source_code=original_code,
@ -710,7 +718,10 @@ print("Hello world")
"""
parents = [FunctionParent(name="NewClass", type="ClassDef")]
function_name: str = "NewClass.__init__"
preexisting_functions: list[tuple[str, list[FunctionParent]]] = [("__init__", parents), ("__call__", parents)]
preexisting_functions: list[tuple[str, list[FunctionParent]]] = [
("__init__", parents),
("__call__", parents),
]
contextual_functions: set[tuple[str, str]] = {
("NewClass", "__init__"),
("NewClass", "__call__"),