mirror of
https://github.com/codeflash-ai/codeflash-internal.git
synced 2026-05-04 18:25:18 +00:00
550 lines
22 KiB
Python
550 lines
22 KiB
Python
from __future__ import annotations
|
|
|
|
import ast
|
|
import os
|
|
import random
|
|
from _ast import AsyncFunctionDef, ClassDef, FunctionDef
|
|
from collections import defaultdict
|
|
from functools import lru_cache
|
|
from pathlib import Path
|
|
from typing import TYPE_CHECKING, Optional
|
|
|
|
import git
|
|
import libcst as cst
|
|
from pydantic.dataclasses import dataclass
|
|
|
|
from codeflash.api.cfapi import get_blocklisted_functions
|
|
from codeflash.cli_cmds.console import logger
|
|
from codeflash.code_utils.code_utils import (
|
|
is_class_defined_in_file,
|
|
module_name_from_file_path,
|
|
path_belongs_to_site_packages,
|
|
)
|
|
from codeflash.code_utils.git_utils import get_git_diff
|
|
from codeflash.discovery.discover_unit_tests import discover_unit_tests
|
|
from codeflash.telemetry.posthog_cf import ph
|
|
|
|
if TYPE_CHECKING:
|
|
from libcst import CSTNode
|
|
from libcst.metadata import CodeRange
|
|
|
|
from codeflash.verification.verification_utils import TestConfig
|
|
|
|
|
|
@dataclass(frozen=True)
|
|
class FunctionProperties:
|
|
is_top_level: bool
|
|
has_args: Optional[bool]
|
|
is_staticmethod: Optional[bool]
|
|
is_classmethod: Optional[bool]
|
|
staticmethod_class_name: Optional[str]
|
|
|
|
|
|
class ReturnStatementVisitor(cst.CSTVisitor):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.has_return_statement: bool = False
|
|
|
|
def visit_Return(self, node: cst.Return) -> None:
|
|
self.has_return_statement = True
|
|
|
|
|
|
class FunctionVisitor(cst.CSTVisitor):
|
|
METADATA_DEPENDENCIES = (cst.metadata.PositionProvider, cst.metadata.ParentNodeProvider)
|
|
|
|
def __init__(self, file_path: str) -> None:
|
|
super().__init__()
|
|
self.file_path: str = file_path
|
|
self.functions: list[FunctionToOptimize] = []
|
|
|
|
def visit_FunctionDef(self, node: cst.FunctionDef) -> None:
|
|
return_visitor: ReturnStatementVisitor = ReturnStatementVisitor()
|
|
node.visit(return_visitor)
|
|
if return_visitor.has_return_statement:
|
|
pos: CodeRange = self.get_metadata(cst.metadata.PositionProvider, node)
|
|
parents: CSTNode | None = self.get_metadata(cst.metadata.ParentNodeProvider, node)
|
|
ast_parents: list[FunctionParent] = []
|
|
while parents is not None:
|
|
if isinstance(parents, (cst.FunctionDef, cst.ClassDef)):
|
|
ast_parents.append(FunctionParent(parents.name.value, parents.__class__.__name__))
|
|
parents = self.get_metadata(cst.metadata.ParentNodeProvider, parents, default=None)
|
|
self.functions.append(
|
|
FunctionToOptimize(
|
|
function_name=node.name.value,
|
|
file_path=self.file_path,
|
|
parents=list(reversed(ast_parents)),
|
|
starting_line=pos.start.line,
|
|
ending_line=pos.end.line,
|
|
),
|
|
)
|
|
|
|
|
|
class FunctionWithReturnStatement(ast.NodeVisitor):
|
|
def __init__(self, file_path: str) -> None:
|
|
self.functions: list[FunctionToOptimize] = []
|
|
self.ast_path: list[FunctionParent] = []
|
|
self.file_path: str = file_path
|
|
|
|
def visit_FunctionDef(self, node: FunctionDef) -> None:
|
|
# Check if the function has a return statement and add it to the list
|
|
if function_has_return_statement(node):
|
|
self.functions.append(
|
|
FunctionToOptimize(
|
|
function_name=node.name,
|
|
file_path=self.file_path,
|
|
parents=self.ast_path[:],
|
|
),
|
|
)
|
|
# Continue visiting the body of the function to find nested functions
|
|
self.generic_visit(node)
|
|
|
|
def generic_visit(self, node: ast.AST) -> None:
|
|
if isinstance(node, (FunctionDef, AsyncFunctionDef, ClassDef)):
|
|
self.ast_path.append(FunctionParent(node.name, node.__class__.__name__))
|
|
super().generic_visit(node)
|
|
if isinstance(node, (FunctionDef, AsyncFunctionDef, ClassDef)):
|
|
self.ast_path.pop()
|
|
|
|
|
|
@dataclass(frozen=True)
|
|
class FunctionParent:
|
|
name: str
|
|
type: str
|
|
|
|
|
|
@dataclass(frozen=True, config={"arbitrary_types_allowed": True})
|
|
class FunctionToOptimize:
|
|
"""Represents a function that is a candidate for optimization.
|
|
|
|
Attributes
|
|
----------
|
|
function_name: The name of the function.
|
|
file_path: The absolute file path where the function is located.
|
|
parents: A list of parent scopes, which could be classes or functions.
|
|
starting_line: The starting line number of the function in the file.
|
|
ending_line: The ending line number of the function in the file.
|
|
|
|
The qualified_name property provides the full name of the function, including
|
|
any parent class or function names. The qualified_name_with_modules_from_root
|
|
method extends this with the module name from the project root.
|
|
|
|
"""
|
|
|
|
function_name: str
|
|
file_path: Path
|
|
parents: list[FunctionParent] # list[ClassDef | FunctionDef | AsyncFunctionDef]
|
|
starting_line: Optional[int] = None
|
|
ending_line: Optional[int] = None
|
|
|
|
@property
|
|
def top_level_parent_name(self) -> str:
|
|
return self.function_name if not self.parents else self.parents[0].name
|
|
|
|
def __str__(self) -> str:
|
|
return (
|
|
f"{self.file_path}:{'.'.join([p.name for p in self.parents])}"
|
|
f"{'.' if self.parents else ''}{self.function_name}"
|
|
)
|
|
|
|
@property
|
|
def qualified_name(self) -> str:
|
|
return self.function_name if self.parents == [] else f"{self.parents[0].name}.{self.function_name}"
|
|
|
|
def qualified_name_with_modules_from_root(self, project_root_path: Path) -> str:
|
|
return f"{module_name_from_file_path(self.file_path, project_root_path)}.{self.qualified_name}"
|
|
|
|
|
|
def get_functions_to_optimize(
|
|
optimize_all: str | None,
|
|
replay_test: str | None,
|
|
file: Path | None,
|
|
only_get_this_function: str | None,
|
|
test_cfg: TestConfig,
|
|
ignore_paths: list[Path],
|
|
project_root: Path,
|
|
module_root: Path,
|
|
) -> tuple[dict[Path, list[FunctionToOptimize]], int]:
|
|
assert (
|
|
sum(
|
|
[ # Ensure only one of the options is provided
|
|
bool(optimize_all),
|
|
bool(replay_test),
|
|
bool(file),
|
|
],
|
|
)
|
|
<= 1
|
|
), "Only one of optimize_all, replay_test, or file should be provided"
|
|
functions: dict[str, list[FunctionToOptimize]]
|
|
if optimize_all:
|
|
logger.info("Finding all functions in the module '%s' ...", optimize_all)
|
|
functions = get_all_files_and_functions(Path(optimize_all))
|
|
elif replay_test is not None:
|
|
functions = get_all_replay_test_functions(
|
|
replay_test=replay_test,
|
|
test_cfg=test_cfg,
|
|
project_root_path=project_root,
|
|
)
|
|
|
|
elif file is not None:
|
|
logger.info("Finding all functions in the file '%s' ...", file)
|
|
functions = find_all_functions_in_file(file)
|
|
if only_get_this_function is not None:
|
|
split_function = only_get_this_function.split(".")
|
|
if len(split_function) > 2:
|
|
msg = "Function name should be in the format 'function_name' or 'class_name.function_name'"
|
|
raise ValueError(msg)
|
|
if len(split_function) == 2:
|
|
class_name, only_function_name = split_function
|
|
else:
|
|
class_name = None
|
|
only_function_name = split_function[0]
|
|
found_function = None
|
|
for fn in functions.get(str(file), []):
|
|
if only_function_name == fn.function_name and (
|
|
class_name is None or class_name == fn.top_level_parent_name
|
|
):
|
|
found_function = fn
|
|
if found_function is None:
|
|
msg = f"Function {only_function_name} not found in file {file} or the function does not have a 'return' statement."
|
|
raise ValueError(msg)
|
|
functions[str(file)] = [found_function]
|
|
else:
|
|
logger.info("Finding all functions modified in the current git diff ...")
|
|
ph("cli-optimizing-git-diff")
|
|
functions = get_functions_within_git_diff()
|
|
filtered_modified_functions, functions_count = filter_functions(
|
|
functions,
|
|
test_cfg.tests_root,
|
|
ignore_paths,
|
|
project_root,
|
|
module_root,
|
|
)
|
|
logger.info(f"Found {functions_count} function{'s' if functions_count > 1 else ''} to optimize")
|
|
return filtered_modified_functions, functions_count
|
|
|
|
|
|
def get_functions_within_git_diff() -> dict[str, list[FunctionToOptimize]]:
|
|
modified_lines: dict[str, list[int]] = get_git_diff(uncommitted_changes=False)
|
|
modified_functions: dict[str, list[FunctionToOptimize]] = {}
|
|
for path_str in modified_lines:
|
|
path = Path(path_str)
|
|
if not path.exists():
|
|
continue
|
|
with path.open(encoding="utf8") as f:
|
|
file_content = f.read()
|
|
try:
|
|
wrapper = cst.metadata.MetadataWrapper(cst.parse_module(file_content))
|
|
except Exception as e:
|
|
logger.exception(e)
|
|
continue
|
|
function_lines = FunctionVisitor(file_path=str(path))
|
|
wrapper.visit(function_lines)
|
|
modified_functions[str(path)] = [
|
|
function_to_optimize
|
|
for function_to_optimize in function_lines.functions
|
|
if (start_line := function_to_optimize.starting_line) is not None
|
|
and (end_line := function_to_optimize.ending_line) is not None
|
|
and any(start_line <= line <= end_line for line in modified_lines[path_str])
|
|
]
|
|
return modified_functions
|
|
|
|
|
|
def get_all_files_and_functions(module_root_path: Path) -> dict[str, list[FunctionToOptimize]]:
|
|
functions: dict[str, list[FunctionToOptimize]] = {}
|
|
module_root_path = Path(module_root_path)
|
|
for file_path in module_root_path.rglob("*.py"):
|
|
# Find all the functions in the file
|
|
functions.update(find_all_functions_in_file(file_path).items())
|
|
# Randomize the order of the files to optimize to avoid optimizing the same file in the same order every time.
|
|
# Helpful if an optimize-all run is stuck and we restart it.
|
|
files_list = list(functions.items())
|
|
random.shuffle(files_list)
|
|
return dict(files_list)
|
|
|
|
|
|
def find_all_functions_in_file(file_path: Path) -> dict[str, list[FunctionToOptimize]]:
|
|
functions: dict[str, list[FunctionToOptimize]] = {}
|
|
with file_path.open(encoding="utf8") as f:
|
|
try:
|
|
ast_module = ast.parse(f.read())
|
|
except Exception as e:
|
|
logger.exception(e)
|
|
return functions
|
|
function_name_visitor = FunctionWithReturnStatement(str(file_path))
|
|
function_name_visitor.visit(ast_module)
|
|
functions[str(file_path)] = function_name_visitor.functions
|
|
return functions
|
|
|
|
|
|
def get_all_replay_test_functions(
|
|
replay_test: str,
|
|
test_cfg: TestConfig,
|
|
project_root_path: Path,
|
|
) -> dict[str, list[FunctionToOptimize]]:
|
|
function_tests = discover_unit_tests(test_cfg, discover_only_these_tests=[replay_test])
|
|
# Get the absolute file paths for each function, excluding class name if present
|
|
filtered_valid_functions = defaultdict(list)
|
|
file_to_functions_map = defaultdict(list)
|
|
# below logic can be cleaned up with a better data structure to store the function paths
|
|
for function in function_tests:
|
|
parts = function.split(".")
|
|
module_path_parts = parts[:-1] # Exclude the function or method name
|
|
function_name = parts[-1]
|
|
# Check if the second-to-last part is a class name
|
|
class_name = (
|
|
module_path_parts[-1]
|
|
if module_path_parts
|
|
and is_class_defined_in_file(
|
|
module_path_parts[-1],
|
|
Path(project_root_path, *module_path_parts[:-1]).with_suffix(".py"),
|
|
)
|
|
else None
|
|
)
|
|
if class_name:
|
|
# If there is a class name, append it to the module path
|
|
function = class_name + "." + function_name
|
|
file_path_parts = module_path_parts[:-1] # Exclude the class name
|
|
else:
|
|
function = function_name
|
|
file_path_parts = module_path_parts
|
|
file_path = Path(project_root_path, *file_path_parts).with_suffix(".py")
|
|
file_to_functions_map[file_path].append((function, function_name, class_name))
|
|
for file_path, functions in file_to_functions_map.items():
|
|
all_valid_functions: dict[str, list[FunctionToOptimize]] = find_all_functions_in_file(
|
|
file_path=file_path,
|
|
)
|
|
filtered_list = []
|
|
for function in functions:
|
|
function_name, function_name_only, class_name = function
|
|
filtered_list.extend(
|
|
[
|
|
valid_function
|
|
for valid_function in all_valid_functions[file_path]
|
|
if valid_function.qualified_name == function_name
|
|
],
|
|
)
|
|
if len(filtered_list):
|
|
filtered_valid_functions[file_path] = filtered_list
|
|
|
|
return filtered_valid_functions
|
|
|
|
|
|
def is_git_repo(file_path: str) -> bool:
|
|
try:
|
|
git.Repo(file_path, search_parent_directories=True)
|
|
return True
|
|
except git.InvalidGitRepositoryError:
|
|
return False
|
|
|
|
|
|
@lru_cache(maxsize=None)
|
|
def ignored_submodule_paths(module_root: str) -> list[str]:
|
|
if is_git_repo(module_root):
|
|
git_repo = git.Repo(module_root, search_parent_directories=True)
|
|
return [
|
|
Path(git_repo.working_tree_dir, submodule.path).resolve() for submodule in git_repo.submodules
|
|
]
|
|
return []
|
|
|
|
|
|
class TopLevelFunctionOrMethodVisitor(ast.NodeVisitor):
|
|
def __init__(
|
|
self,
|
|
file_name: Path,
|
|
function_or_method_name: str,
|
|
class_name: str | None = None,
|
|
line_no: int | None = None,
|
|
) -> None:
|
|
self.file_name = file_name
|
|
self.class_name = class_name
|
|
self.function_name = function_or_method_name
|
|
self.is_top_level = False
|
|
self.function_has_args = None
|
|
self.line_no = line_no
|
|
self.is_staticmethod = False
|
|
self.is_classmethod = False
|
|
|
|
def visit_FunctionDef(self, node: ast.FunctionDef) -> None:
|
|
if self.class_name is None and node.name == self.function_name:
|
|
self.is_top_level = True
|
|
self.function_has_args = any(
|
|
(
|
|
bool(node.args.args),
|
|
bool(node.args.kwonlyargs),
|
|
bool(node.args.kwarg),
|
|
bool(node.args.posonlyargs),
|
|
bool(node.args.vararg),
|
|
),
|
|
)
|
|
|
|
def visit_ClassDef(self, node: ast.ClassDef) -> None:
|
|
# iterate over the class methods
|
|
if node.name == self.class_name:
|
|
for body_node in node.body:
|
|
if isinstance(body_node, ast.FunctionDef) and body_node.name == self.function_name:
|
|
self.is_top_level = True
|
|
if any(
|
|
isinstance(decorator, ast.Name) and decorator.id == "classmethod"
|
|
for decorator in body_node.decorator_list
|
|
):
|
|
self.is_classmethod = True
|
|
return
|
|
else:
|
|
# search if the class has a staticmethod with the same name and on the same line number
|
|
for body_node in node.body:
|
|
if (
|
|
isinstance(body_node, ast.FunctionDef)
|
|
and body_node.name == self.function_name
|
|
and body_node.lineno in {self.line_no, self.line_no + 1}
|
|
and any(
|
|
isinstance(decorator, ast.Name) and decorator.id == "staticmethod"
|
|
for decorator in body_node.decorator_list
|
|
)
|
|
):
|
|
self.is_staticmethod = True
|
|
self.is_top_level = True
|
|
self.class_name = node.name
|
|
return
|
|
|
|
return
|
|
|
|
|
|
def inspect_top_level_functions_or_methods(
|
|
file_name: Path,
|
|
function_or_method_name: str,
|
|
class_name: str | None = None,
|
|
line_no: int | None = None,
|
|
) -> FunctionProperties:
|
|
with open(file_name, encoding="utf8") as file:
|
|
try:
|
|
ast_module = ast.parse(file.read())
|
|
except Exception as e:
|
|
logger.exception(e)
|
|
return False
|
|
visitor = TopLevelFunctionOrMethodVisitor(
|
|
file_name=file_name,
|
|
function_or_method_name=function_or_method_name,
|
|
class_name=class_name,
|
|
line_no=line_no,
|
|
)
|
|
visitor.visit(ast_module)
|
|
staticmethod_class_name = visitor.class_name if visitor.is_staticmethod else None
|
|
return FunctionProperties(
|
|
is_top_level=visitor.is_top_level,
|
|
has_args=visitor.function_has_args,
|
|
is_staticmethod=visitor.is_staticmethod,
|
|
is_classmethod=visitor.is_classmethod,
|
|
staticmethod_class_name=staticmethod_class_name,
|
|
)
|
|
|
|
|
|
def filter_functions(
|
|
modified_functions: dict[str, list[FunctionToOptimize]],
|
|
tests_root: Path,
|
|
ignore_paths: list[Path],
|
|
project_root: Path,
|
|
module_root: Path,
|
|
disable_logs: bool = False,
|
|
) -> tuple[dict[Path, list[FunctionToOptimize]], int]:
|
|
blocklist_funcs = get_blocklisted_functions()
|
|
# Remove any function that we don't want to optimize
|
|
|
|
# Ignore files with submodule path, cache the submodule paths
|
|
submodule_paths = ignored_submodule_paths(module_root)
|
|
|
|
filtered_modified_functions: dict[str, list[FunctionToOptimize]] = {}
|
|
functions_count: int = 0
|
|
test_functions_removed_count: int = 0
|
|
non_modules_removed_count: int = 0
|
|
site_packages_removed_count: int = 0
|
|
ignore_paths_removed_count: int = 0
|
|
malformed_paths_count: int = 0
|
|
submodule_ignored_paths_count: int = 0
|
|
tests_root_str = str(tests_root)
|
|
module_root_str = str(module_root)
|
|
# We desperately need Python 3.10+ only support to make this code readable with structural pattern matching
|
|
for file_path, functions in modified_functions.items():
|
|
if file_path.startswith(tests_root_str + os.sep):
|
|
test_functions_removed_count += len(functions)
|
|
continue
|
|
if file_path in ignore_paths or any(
|
|
# file_path.startswith(ignore_path + os.sep) for ignore_path in ignore_paths if ignore_path
|
|
file_path.startswith(str(ignore_path) + os.sep)
|
|
for ignore_path in ignore_paths
|
|
):
|
|
ignore_paths_removed_count += 1
|
|
continue
|
|
if file_path in submodule_paths or any(
|
|
file_path.startswith(str(submodule_path) + os.sep) for submodule_path in submodule_paths
|
|
):
|
|
submodule_ignored_paths_count += 1
|
|
continue
|
|
if path_belongs_to_site_packages(Path(file_path)):
|
|
site_packages_removed_count += len(functions)
|
|
continue
|
|
if not file_path.startswith(module_root_str + os.sep):
|
|
non_modules_removed_count += len(functions)
|
|
continue
|
|
try:
|
|
ast.parse(f"import {module_name_from_file_path(Path(file_path), project_root)}")
|
|
except SyntaxError:
|
|
malformed_paths_count += 1
|
|
continue
|
|
if blocklist_funcs:
|
|
for function in functions.copy():
|
|
path = Path(function.file_path).name
|
|
if path in blocklist_funcs and function.function_name in blocklist_funcs[path]:
|
|
functions.remove(function)
|
|
logger.debug(
|
|
f"Skipping {function.function_name} in {path} as it has already been optimized",
|
|
)
|
|
continue
|
|
|
|
filtered_modified_functions[file_path] = functions
|
|
functions_count += len(functions)
|
|
if not disable_logs:
|
|
log_info = {
|
|
f"{test_functions_removed_count} test function{'s' if test_functions_removed_count != 1 else ''}": test_functions_removed_count,
|
|
f"{site_packages_removed_count} site-package function{'s' if site_packages_removed_count != 1 else ''}": site_packages_removed_count,
|
|
f"{malformed_paths_count} non-importable file path{'s' if malformed_paths_count != 1 else ''}": malformed_paths_count,
|
|
f"{non_modules_removed_count} function{'s' if non_modules_removed_count != 1 else ''} outside module-root": non_modules_removed_count,
|
|
f"{ignore_paths_removed_count} file{'s' if ignore_paths_removed_count != 1 else ''} from ignored paths": ignore_paths_removed_count,
|
|
f"{submodule_ignored_paths_count} file{'s' if submodule_ignored_paths_count != 1 else ''} from ignored submodules": submodule_ignored_paths_count,
|
|
}
|
|
log_string: str
|
|
if log_string := "\n".join([k for k, v in log_info.items() if v > 0]):
|
|
logger.info(f"Ignoring:\n{log_string}")
|
|
return {Path(k): v for k, v in filtered_modified_functions.items() if v}, functions_count
|
|
|
|
|
|
def filter_files_optimized(
|
|
file_path: Path,
|
|
tests_root: Path,
|
|
ignore_paths: list[Path],
|
|
module_root: Path,
|
|
) -> bool:
|
|
"""Optimized version of the filter_functions function above.
|
|
|
|
Takes in file paths and returns the count of files that are to be optimized.
|
|
"""
|
|
submodule_paths = None
|
|
if file_path.is_relative_to(tests_root):
|
|
return False
|
|
if file_path in ignore_paths or any(
|
|
file_path.is_relative_to(ignore_path) for ignore_path in ignore_paths
|
|
):
|
|
return False
|
|
if path_belongs_to_site_packages(file_path):
|
|
return False
|
|
if not file_path.is_relative_to(module_root):
|
|
return False
|
|
if submodule_paths is None:
|
|
submodule_paths = ignored_submodule_paths(module_root)
|
|
return not (
|
|
file_path in submodule_paths
|
|
or any(file_path.is_relative_to(submodule_path) for submodule_path in submodule_paths)
|
|
)
|
|
|
|
|
|
def function_has_return_statement(function_node: FunctionDef | AsyncFunctionDef) -> bool:
|
|
return any(isinstance(node, ast.Return) for node in ast.walk(function_node))
|