mirror of
https://github.com/codeflash-ai/codeflash-internal.git
synced 2026-05-04 18:25:18 +00:00
245 lines
9.1 KiB
Python
from __future__ import annotations
|
|
|
|
import ast
|
|
from functools import lru_cache
|
|
from typing import TYPE_CHECKING, TypeVar
|
|
|
|
import libcst as cst
|
|
|
|
from codeflash.code_utils.code_extractor import add_needed_imports_from_module
|
|
from codeflash.discovery.functions_to_optimize import FunctionParent
|
|
|
|
if TYPE_CHECKING:
|
|
from pathlib import Path
|
|
|
|
from libcst import FunctionDef
|
|
|
|
# Lets ``normalize_node`` be typed as returning the same kind of AST node it receives.
ASTNodeT = TypeVar("ASTNodeT", bound=ast.AST)

# Node types whose first body statement can be a docstring (the types ``ast.get_docstring`` accepts).
_DOCSTRING_HOLDERS = (ast.Module, ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)


def normalize_node(node: ASTNodeT) -> ASTNodeT:
    """Strip docstrings and import statements from ``node`` and its nested statement bodies.

    The node is modified in place and also returned, so two sources can be compared while
    ignoring differences in docstrings and imports.

    Only ``body`` lists are walked recursively; statements in ``orelse``, ``finalbody`` and
    exception handlers are left untouched.
    """
    # ``clean=False`` plus an explicit ``is not None`` check so that an empty docstring
    # (``""``) is stripped as well; its cleaned text is falsy.
    if isinstance(node, _DOCSTRING_HOLDERS) and ast.get_docstring(node, clean=False) is not None:
        node.body = node.body[1:]
    if hasattr(node, "body"):
        node.body = [
            normalize_node(child)
            for child in node.body
            if not isinstance(child, (ast.Import, ast.ImportFrom))
        ]
    return node
|
|
|
|
|
|
@lru_cache(maxsize=3)
def normalize_code(code: str) -> str:
    """Return ``code`` re-rendered from its AST with docstrings and imports removed.

    Going through ``ast.parse``/``ast.unparse`` also discards comments and formatting.
    A small LRU cache (3 entries) avoids re-normalizing recently seen sources.
    """
    tree = ast.parse(code)
    stripped_tree = normalize_node(tree)
    return ast.unparse(stripped_tree)
|
|
|
|
|
|
class OptimFunctionCollector(cst.CSTVisitor):
    """Collect the function definitions in the optimized code that must go into the original module.

    Run it through a ``cst.metadata.MetadataWrapper`` (it depends on ``ParentNodeProvider``).
    Afterwards:

    * ``optim_body`` is the last ``FunctionDef`` named ``function_name`` found anywhere in the
      tree (methods and nested functions included), or ``None`` if there is none;
    * ``optim_new_functions`` holds module-level functions whose ``(name, [])`` entry is not in
      ``preexisting_objects``;
    * ``optim_new_class_functions`` holds methods that are neither in ``contextual_functions``
      (extended with ``(class_name, function_name)``) nor in ``preexisting_objects``.

    Nothing is added to the two "new" lists when ``preexisting_objects`` is empty or ``None``.
    """

    METADATA_DEPENDENCIES = (cst.metadata.ParentNodeProvider,)

    def __init__(
        self,
        function_name: str,
        class_name: str | None,
        contextual_functions: set[tuple[str, str]],
        preexisting_objects: list[tuple[str, list[FunctionParent]]] | None = None,
    ) -> None:
        super().__init__()
        self.function_name = function_name
        self.class_name = class_name
        self.optim_body: FunctionDef | None = None
        self.optim_new_class_functions: list[cst.FunctionDef] = []
        self.optim_new_functions: list[cst.FunctionDef] = []
        self.preexisting_objects = preexisting_objects if preexisting_objects is not None else []
        # The target itself must never be reported as a "new" method of its class.
        self.contextual_functions = contextual_functions.union({(self.class_name, self.function_name)})

    def visit_FunctionDef(self, node: cst.FunctionDef) -> None:
        """Record the target function, or remember ``node`` if it is a new module-level function."""
        if node.name.value == self.function_name:
            self.optim_body = node
            return
        if not self.preexisting_objects or (node.name.value, []) in self.preexisting_objects:
            return
        # Only direct children of the module are module-level functions; methods and functions
        # nested inside other statements have a different parent and are ignored here.
        parent = self.get_metadata(cst.metadata.ParentNodeProvider, node)
        if isinstance(parent, cst.Module):
            self.optim_new_functions.append(node)

    def visit_ClassDef_body(self, node: cst.ClassDef) -> None:
        """Collect the methods of ``node`` that do not exist in the original code."""
        if not self.preexisting_objects:
            return
        class_name = node.name.value
        parents = [FunctionParent(name=class_name, type="ClassDef")]
        for child_node in node.body.body:
            if not isinstance(child_node, cst.FunctionDef):
                continue
            method_name = child_node.name.value
            if (class_name, method_name) in self.contextual_functions:
                continue
            if (method_name, parents) in self.preexisting_objects:
                continue
            self.optim_new_class_functions.append(child_node)
|
|
|
|
|
|
class OptimFunctionReplacer(cst.CSTTransformer):
    """Rewrite the original module so that it uses the optimized version of one function.

    * The target gets the body and decorators of ``optim_body``; its parameters and return
      annotation are kept. When ``class_name`` is ``None``, the target is the module-level
      function ``function_name``. Otherwise it is the method ``function_name`` defined directly
      in the top-level class ``class_name``.
    * ``optim_new_class_functions`` are appended to the body of that top-level class.
    * ``optim_new_functions`` are inserted at module level (see ``leave_Module``).
    """

    def __init__(
        self,
        function_name: str,
        optim_body: cst.FunctionDef,
        optim_new_class_functions: list[cst.FunctionDef],
        optim_new_functions: list[cst.FunctionDef],
        class_name: str | None = None,
    ) -> None:
        super().__init__()
        self.function_name = function_name
        self.optim_body = optim_body
        self.optim_new_class_functions = optim_new_class_functions
        self.optim_new_functions = optim_new_functions
        self.class_name = class_name
        # Nesting depth of the ClassDef currently being visited (0 = module level).
        self.depth: int = 0
        # True while inside the top-level class named ``class_name``.
        self.in_class: bool = False

    def visit_FunctionDef(self, node: cst.FunctionDef) -> bool:
        # Never descend into function bodies: nested definitions are not replacement targets.
        return False

    def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef) -> cst.FunctionDef:
        if original_node.name.value != self.function_name:
            return updated_node
        # A module-level function is only the target when no class was requested. Without the
        # ``class_name is None`` check, a same-named module-level function would be overwritten
        # with the body of the requested method.
        is_target_function = self.depth == 0 and self.class_name is None
        is_target_method = self.depth == 1 and self.in_class
        if not (is_target_function or is_target_method):
            return updated_node
        return updated_node.with_changes(body=self.optim_body.body, decorators=self.optim_body.decorators)

    def visit_ClassDef(self, node: cst.ClassDef) -> bool:
        self.depth += 1
        if self.in_class:
            # A class nested inside the target class: its methods are not targets.
            return False
        self.in_class = (self.depth == 1) and (node.name.value == self.class_name)
        # Only the target class's body is visited; every other class is left untouched.
        return self.in_class

    def leave_ClassDef(self, original_node: cst.ClassDef, updated_node: cst.ClassDef) -> cst.ClassDef:
        self.depth -= 1
        if self.in_class and (self.depth == 0) and (original_node.name.value == self.class_name):
            self.in_class = False
            return updated_node.with_changes(
                body=updated_node.body.with_changes(
                    body=(list(updated_node.body.body) + self.optim_new_class_functions)
                )
            )
        return updated_node

    def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module:
        """Insert ``optim_new_functions`` into the module body.

        They go right after the last module-level function. If there is none, they go after the
        last module-level class. Otherwise they go after the module docstring and the last
        top-level import, which keeps ``from __future__`` imports at the top of the file.
        """
        if not self.optim_new_functions:
            return updated_node
        body = updated_node.body
        insert_at = self._insertion_index(body)
        return updated_node.with_changes(body=(*body[:insert_at], *self.optim_new_functions, *body[insert_at:]))

    @staticmethod
    def _insertion_index(body: Sequence[cst.BaseStatement]) -> int:
        """Return the index in the module ``body`` at which new functions are inserted."""
        last_function = None
        last_class = None
        after_header = 0  # one past the module docstring / last top-level import statement
        for index, statement in enumerate(body):
            if isinstance(statement, cst.FunctionDef):
                last_function = index
            elif isinstance(statement, cst.ClassDef):
                last_class = index
            elif OptimFunctionReplacer._is_import_or_docstring(statement, index):
                after_header = index + 1
        if last_function is not None:
            return last_function + 1
        if last_class is not None:
            return last_class + 1
        return after_header

    @staticmethod
    def _is_import_or_docstring(statement: cst.BaseStatement, index: int) -> bool:
        """Tell whether ``statement`` is an import line, or the module docstring when ``index`` is 0."""
        if not isinstance(statement, cst.SimpleStatementLine):
            return False
        if all(isinstance(small, (cst.Import, cst.ImportFrom)) for small in statement.body):
            return True
        # A module docstring can only be the very first statement of the module.
        return (
            index == 0
            and len(statement.body) == 1
            and isinstance(statement.body[0], cst.Expr)
            and isinstance(statement.body[0].value, (cst.SimpleString, cst.ConcatenatedString))
        )
|
|
|
|
|
|
def replace_functions_in_file(
    source_code: str,
    original_function_names: list[str],
    optimized_code: str,
    preexisting_objects: list[tuple[str, list[FunctionParent]]],
    contextual_functions: set[tuple[str, str]],
) -> str:
    """Replace each named function in ``source_code`` with its version from ``optimized_code``.

    Each entry of ``original_function_names`` is either ``"func"`` (module-level) or
    ``"Class.method"``. Names with more than one dot raise ``ValueError`` before any code is
    parsed. A function missing from the optimized code is skipped when ``preexisting_objects``
    is empty, and raises ``ValueError`` otherwise. Returns the updated source code.
    """
    targets: list[tuple[str, str | None]] = []
    for qualified_name in original_function_names:
        dot_count = qualified_name.count(".")
        if dot_count > 1:
            msg = f"Don't know how to find {qualified_name} yet!"
            raise ValueError(msg)
        if dot_count == 1:
            owner, name = qualified_name.split(".")
        else:
            owner, name = None, qualified_name
        targets.append((name, owner))

    wrapped_optimized = cst.metadata.MetadataWrapper(cst.parse_module(optimized_code))

    updated_source = source_code
    for name, owner in targets:
        collector = OptimFunctionCollector(name, owner, contextual_functions, preexisting_objects)
        wrapped_optimized.visit(collector)

        if collector.optim_body is None:
            if not preexisting_objects:
                # Without preexisting_objects a missing function is skipped, not an error.
                continue
            msg = f"Did not find the function {name} in the optimized code"
            raise ValueError(msg)

        replacer = OptimFunctionReplacer(
            collector.function_name,
            collector.optim_body,
            collector.optim_new_class_functions,
            collector.optim_new_functions,
            class_name=owner,
        )
        # Re-parse every time: each replacement builds on the previous one's output.
        updated_source = cst.parse_module(updated_source).visit(replacer).code

    return updated_source
|
|
|
|
|
|
def replace_functions_and_add_imports(
    source_code: str,
    function_names: list[str],
    optimized_code: str,
    file_path_of_module_with_function_to_optimize: Path,
    module_abspath: Path,
    preexisting_objects: list[tuple[str, list[FunctionParent]]],
    contextual_functions: set[tuple[str, str]],
    project_root_path: Path,
) -> str:
    """Swap the optimized functions into ``source_code``, then add the imports they need.

    The function replacement is done by ``replace_functions_in_file``. The result is passed to
    ``add_needed_imports_from_module``, together with ``optimized_code`` and the two module
    paths, and that call's return value is returned.
    """
    replaced_source = replace_functions_in_file(
        source_code, function_names, optimized_code, preexisting_objects, contextual_functions
    )
    return add_needed_imports_from_module(
        optimized_code,
        replaced_source,
        file_path_of_module_with_function_to_optimize,
        module_abspath,
        project_root_path,
    )
|
|
|
|
|
|
def replace_function_definitions_in_module(
    function_names: list[str],
    optimized_code: str,
    file_path_of_module_with_function_to_optimize: Path,
    module_abspath: Path,
    preexisting_objects: list[tuple[str, list[FunctionParent]]],
    contextual_functions: set[tuple[str, str]],
    project_root_path: Path,
) -> bool:
    """Apply the optimized functions to the file at ``module_abspath``.

    Returns ``True`` when the file was rewritten. Returns ``False`` when ``is_zero_diff``
    considers the new code equivalent to the current contents; in that case the file is not
    written.
    """
    original_source: str = module_abspath.read_text(encoding="utf8")
    updated_source: str = replace_functions_and_add_imports(
        original_source,
        function_names,
        optimized_code,
        file_path_of_module_with_function_to_optimize,
        module_abspath,
        preexisting_objects,
        contextual_functions,
        project_root_path,
    )
    has_changes = not is_zero_diff(original_source, updated_source)
    if has_changes:
        module_abspath.write_text(updated_source, encoding="utf8")
    return has_changes
|
|
|
|
|
|
def is_zero_diff(original_code: str, new_code: str) -> bool:
    """Tell whether two sources are the same once normalized by ``normalize_code``.

    Normalization is an AST round-trip with docstrings and import statements removed, so
    differences in comments, formatting, docstrings and imports are ignored.
    """
    normalized_original = normalize_code(original_code)
    normalized_new = normalize_code(new_code)
    return normalized_original == normalized_new
|