codeflash-internal/cli/codeflash/code_utils/code_replacer.py

245 lines
9.1 KiB
Python

from __future__ import annotations
import ast
from functools import lru_cache
from typing import TYPE_CHECKING, TypeVar
import libcst as cst
from codeflash.code_utils.code_extractor import add_needed_imports_from_module
from codeflash.discovery.functions_to_optimize import FunctionParent
if TYPE_CHECKING:
from pathlib import Path
from libcst import FunctionDef
ASTNodeT = TypeVar("ASTNodeT", bound=ast.AST)
def normalize_node(node: ASTNodeT) -> ASTNodeT:
if isinstance(node, (ast.Module, ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)) and ast.get_docstring(node):
node.body = node.body[1:]
if hasattr(node, "body"):
node.body = [normalize_node(node) for node in node.body if not isinstance(node, (ast.Import, ast.ImportFrom))]
return node
@lru_cache(maxsize=3)
def normalize_code(code: str) -> str:
return ast.unparse(normalize_node(ast.parse(code)))
class OptimFunctionCollector(cst.CSTVisitor):
METADATA_DEPENDENCIES = (cst.metadata.ParentNodeProvider,)
def __init__(
self,
function_name: str,
class_name: str | None,
contextual_functions: set[tuple[str, str]],
preexisting_objects: list[tuple[str, list[FunctionParent]]] | None = None,
) -> None:
super().__init__()
if preexisting_objects is None:
preexisting_objects = []
self.function_name = function_name
self.class_name = class_name
self.optim_body: FunctionDef | None = None
self.optim_new_class_functions: list[cst.FunctionDef] = []
self.optim_new_functions: list[cst.FunctionDef] = []
self.preexisting_objects = preexisting_objects
self.contextual_functions = contextual_functions.union({(self.class_name, self.function_name)})
def visit_FunctionDef(self, node: cst.FunctionDef) -> None:
parent = self.get_metadata(cst.metadata.ParentNodeProvider, node)
parent2 = None
try:
if parent is not None and isinstance(parent, cst.Module):
parent2 = self.get_metadata(cst.metadata.ParentNodeProvider, parent)
except:
pass
if node.name.value == self.function_name:
self.optim_body = node
elif (
self.preexisting_objects
and (node.name.value, []) not in self.preexisting_objects
and (isinstance(parent, cst.Module) or (parent2 is not None and not isinstance(parent2, cst.ClassDef)))
):
self.optim_new_functions.append(node)
def visit_ClassDef_body(self, node: cst.ClassDef) -> None:
parents = [FunctionParent(name=node.name.value, type="ClassDef")]
for child_node in node.body.body:
if (
self.preexisting_objects
and isinstance(child_node, cst.FunctionDef)
and (node.name.value, child_node.name.value) not in self.contextual_functions
and (child_node.name.value, parents) not in self.preexisting_objects
):
self.optim_new_class_functions.append(child_node)
class OptimFunctionReplacer(cst.CSTTransformer):
def __init__(
self,
function_name: str,
optim_body: cst.FunctionDef,
optim_new_class_functions: list[cst.FunctionDef],
optim_new_functions: list[cst.FunctionDef],
class_name: str | None = None,
) -> None:
super().__init__()
self.function_name = function_name
self.optim_body = optim_body
self.optim_new_class_functions = optim_new_class_functions
self.optim_new_functions = optim_new_functions
self.class_name = class_name
self.depth: int = 0
self.in_class: bool = False
def visit_FunctionDef(self, node: cst.FunctionDef) -> bool:
return False
def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef) -> cst.FunctionDef:
if original_node.name.value == self.function_name and (self.depth == 0 or (self.depth == 1 and self.in_class)):
return updated_node.with_changes(body=self.optim_body.body, decorators=self.optim_body.decorators)
return updated_node
def visit_ClassDef(self, node: cst.ClassDef) -> bool:
self.depth += 1
if self.in_class:
return False
self.in_class = (self.depth == 1) and (node.name.value == self.class_name)
return self.in_class
def leave_ClassDef(self, original_node: cst.ClassDef, updated_node: cst.ClassDef) -> cst.ClassDef:
self.depth -= 1
if self.in_class and (self.depth == 0) and (original_node.name.value == self.class_name):
self.in_class = False
return updated_node.with_changes(
body=updated_node.body.with_changes(
body=(list(updated_node.body.body) + self.optim_new_class_functions)
)
)
return updated_node
def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module:
node = updated_node
max_function_index = None
class_index = None
for index, _node in enumerate(node.body):
if isinstance(_node, cst.FunctionDef):
max_function_index = index
if isinstance(_node, cst.ClassDef):
class_index = index
if max_function_index is not None:
node = node.with_changes(
body=(
*node.body[: max_function_index + 1],
*self.optim_new_functions,
*node.body[max_function_index + 1 :],
)
)
elif class_index is not None:
node = node.with_changes(
body=(*node.body[: class_index + 1], *self.optim_new_functions, *node.body[class_index + 1 :])
)
else:
node = node.with_changes(body=(*self.optim_new_functions, *node.body))
return node
def replace_functions_in_file(
source_code: str,
original_function_names: list[str],
optimized_code: str,
preexisting_objects: list[tuple[str, list[FunctionParent]]],
contextual_functions: set[tuple[str, str]],
) -> str:
parsed_function_names = []
for original_function_name in original_function_names:
if original_function_name.count(".") == 0:
class_name, function_name = None, original_function_name
elif original_function_name.count(".") == 1:
class_name, function_name = original_function_name.split(".")
else:
msg = f"Don't know how to find {original_function_name} yet!"
raise ValueError(msg)
parsed_function_names.append((function_name, class_name))
module = cst.metadata.MetadataWrapper(cst.parse_module(optimized_code))
for function_name, class_name in parsed_function_names:
visitor = OptimFunctionCollector(function_name, class_name, contextual_functions, preexisting_objects)
module.visit(visitor)
if visitor.optim_body is None and not preexisting_objects:
continue
if visitor.optim_body is None:
msg = f"Did not find the function {function_name} in the optimized code"
raise ValueError(msg)
transformer = OptimFunctionReplacer(
visitor.function_name,
visitor.optim_body,
visitor.optim_new_class_functions,
visitor.optim_new_functions,
class_name=class_name,
)
original_module = cst.parse_module(source_code)
modified_tree = original_module.visit(transformer)
source_code = modified_tree.code
return source_code
def replace_functions_and_add_imports(
source_code: str,
function_names: list[str],
optimized_code: str,
file_path_of_module_with_function_to_optimize: Path,
module_abspath: Path,
preexisting_objects: list[tuple[str, list[FunctionParent]]],
contextual_functions: set[tuple[str, str]],
project_root_path: Path,
) -> str:
return add_needed_imports_from_module(
optimized_code,
replace_functions_in_file(
source_code, function_names, optimized_code, preexisting_objects, contextual_functions
),
file_path_of_module_with_function_to_optimize,
module_abspath,
project_root_path,
)
def replace_function_definitions_in_module(
function_names: list[str],
optimized_code: str,
file_path_of_module_with_function_to_optimize: Path,
module_abspath: Path,
preexisting_objects: list[tuple[str, list[FunctionParent]]],
contextual_functions: set[tuple[str, str]],
project_root_path: Path,
) -> bool:
source_code: str = module_abspath.read_text(encoding="utf8")
new_code: str = replace_functions_and_add_imports(
source_code,
function_names,
optimized_code,
file_path_of_module_with_function_to_optimize,
module_abspath,
preexisting_objects,
contextual_functions,
project_root_path,
)
if is_zero_diff(source_code, new_code):
return False
module_abspath.write_text(new_code, encoding="utf8")
return True
def is_zero_diff(original_code: str, new_code: str) -> bool:
return normalize_code(original_code) == normalize_code(new_code)