mirror of
https://github.com/codeflash-ai/codeflash-internal.git
synced 2026-05-04 18:25:18 +00:00
281 lines
13 KiB
Python
281 lines
13 KiB
Python
from __future__ import annotations
|
|
|
|
import ast
|
|
import os
|
|
import re
|
|
from collections import defaultdict
|
|
from typing import TYPE_CHECKING
|
|
|
|
import jedi
|
|
import tiktoken
|
|
from jedi.api.classes import Name
|
|
|
|
from codeflash.cli_cmds.console import logger
|
|
from codeflash.code_utils.code_extractor import get_code
|
|
from codeflash.code_utils.code_utils import module_name_from_file_path, path_belongs_to_site_packages
|
|
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
|
|
from codeflash.models.models import FunctionParent, FunctionSource
|
|
|
|
if TYPE_CHECKING:
|
|
from pathlib import Path
|
|
|
|
|
|
def belongs_to_method(name: Name, class_name: str, method_name: str) -> bool:
|
|
"""Check if the given name belongs to the specified method."""
|
|
return belongs_to_function(name, method_name) and belongs_to_class(name, class_name)
|
|
|
|
|
|
def belongs_to_function(name: Name, function_name: str) -> bool:
|
|
"""Check if the given jedi Name is a direct child of the specified function"""
|
|
if name.name == function_name: # Handles function definition and recursive function calls
|
|
return False
|
|
if name := name.parent():
|
|
if name.type == "function":
|
|
return name.name == function_name
|
|
return False
|
|
|
|
|
|
def belongs_to_class(name: Name, class_name: str) -> bool:
|
|
"""Check if given jedi Name is a direct child of the specified class"""
|
|
while name := name.parent():
|
|
if name.type == "class":
|
|
return name.name == class_name
|
|
return False
|
|
|
|
|
|
def get_type_annotation_context(
|
|
function: FunctionToOptimize, jedi_script: jedi.Script, project_root_path: Path
|
|
) -> tuple[list[FunctionSource], set[tuple[str, str]]]:
|
|
function_name: str = function.function_name
|
|
file_path: Path = function.file_path
|
|
file_contents: str = file_path.read_text(encoding="utf8")
|
|
try:
|
|
module: ast.Module = ast.parse(file_contents)
|
|
except SyntaxError as e:
|
|
logger.exception(f"get_type_annotation_context - Syntax error in code: {e}")
|
|
return [], set()
|
|
sources: list[FunctionSource] = []
|
|
ast_parents: list[FunctionParent] = []
|
|
contextual_dunder_methods = set()
|
|
|
|
def get_annotation_source(
|
|
j_script: jedi.Script, name: str, node_parents: list[FunctionParent], line_no: int, col_no: str
|
|
) -> None:
|
|
try:
|
|
definition: list[Name] = j_script.goto(
|
|
line=line_no, column=col_no, follow_imports=True, follow_builtin_imports=False
|
|
)
|
|
except Exception as ex:
|
|
if hasattr(name, "full_name"):
|
|
logger.exception(f"Error while getting definition for {name.full_name}: {ex}")
|
|
else:
|
|
logger.exception(f"Error while getting definition: {ex}")
|
|
definition = []
|
|
if definition: # TODO can be multiple definitions
|
|
definition_path = definition[0].module_path
|
|
|
|
# The definition is part of this project and not defined within the original function
|
|
if (
|
|
str(definition_path).startswith(str(project_root_path) + os.sep)
|
|
and definition[0].full_name
|
|
and not path_belongs_to_site_packages(definition_path)
|
|
and not belongs_to_function(definition[0], function_name)
|
|
):
|
|
source_code = get_code([FunctionToOptimize(definition[0].name, definition_path, node_parents[:-1])])
|
|
if source_code[0]:
|
|
sources.append(
|
|
FunctionSource(
|
|
fully_qualified_name=definition[0].full_name,
|
|
jedi_definition=definition[0],
|
|
source_code=source_code[0],
|
|
file_path=definition_path,
|
|
qualified_name=definition[0].full_name.removeprefix(definition[0].module_name + "."),
|
|
only_function_name=definition[0].name,
|
|
)
|
|
)
|
|
contextual_dunder_methods.update(source_code[1])
|
|
|
|
def visit_children(
|
|
node: ast.FunctionDef | ast.AsyncFunctionDef | ast.ClassDef | ast.Module, node_parents: list[FunctionParent]
|
|
) -> None:
|
|
child: ast.AST | ast.FunctionDef | ast.AsyncFunctionDef | ast.ClassDef | ast.Module
|
|
for child in ast.iter_child_nodes(node):
|
|
visit(child, node_parents)
|
|
|
|
def visit_all_annotation_children(
|
|
node: ast.Subscript | ast.Name | ast.BinOp, node_parents: list[FunctionParent]
|
|
) -> None:
|
|
if isinstance(node, ast.BinOp) and isinstance(node.op, ast.BitOr):
|
|
visit_all_annotation_children(node.left, node_parents)
|
|
visit_all_annotation_children(node.right, node_parents)
|
|
if isinstance(node, ast.Name) and hasattr(node, "id"):
|
|
name: str = node.id
|
|
line_no: int = node.lineno
|
|
col_no: int = node.col_offset
|
|
get_annotation_source(jedi_script, name, node_parents, line_no, col_no)
|
|
if isinstance(node, ast.Subscript):
|
|
if hasattr(node, "slice"):
|
|
if isinstance(node.slice, ast.Subscript):
|
|
visit_all_annotation_children(node.slice, node_parents)
|
|
elif isinstance(node.slice, ast.Tuple):
|
|
for elt in node.slice.elts:
|
|
if isinstance(elt, (ast.Name, ast.Subscript)):
|
|
visit_all_annotation_children(elt, node_parents)
|
|
elif isinstance(node.slice, ast.Name):
|
|
visit_all_annotation_children(node.slice, node_parents)
|
|
if hasattr(node, "value"):
|
|
visit_all_annotation_children(node.value, node_parents)
|
|
|
|
def visit(
|
|
node: ast.AST | ast.FunctionDef | ast.AsyncFunctionDef | ast.ClassDef | ast.Module,
|
|
node_parents: list[FunctionParent],
|
|
) -> None:
|
|
if isinstance(node, (ast.Module, ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)):
|
|
if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
|
|
if node.name == function_name and node_parents == function.parents:
|
|
arg: ast.arg
|
|
for arg in node.args.args:
|
|
if arg.annotation:
|
|
visit_all_annotation_children(arg.annotation, node_parents)
|
|
if node.returns:
|
|
visit_all_annotation_children(node.returns, node_parents)
|
|
|
|
if not isinstance(node, ast.Module):
|
|
node_parents.append(FunctionParent(node.name, type(node).__name__))
|
|
visit_children(node, node_parents)
|
|
if not isinstance(node, ast.Module):
|
|
node_parents.pop()
|
|
|
|
visit(module, ast_parents)
|
|
|
|
return sources, contextual_dunder_methods
|
|
|
|
|
|
def get_function_variables_definitions(
|
|
function_to_optimize: FunctionToOptimize, project_root_path: Path
|
|
) -> tuple[list[FunctionSource], set[tuple[str, str]]]:
|
|
function_name = function_to_optimize.function_name
|
|
file_path = function_to_optimize.file_path
|
|
script = jedi.Script(path=file_path, project=jedi.Project(path=project_root_path))
|
|
sources: list[FunctionSource] = []
|
|
contextual_dunder_methods = set()
|
|
# TODO: The function name condition can be stricter so that it does not clash with other class names etc.
|
|
# TODO: The function could have been imported as some other name,
|
|
# we should be checking for the translation as well. Also check for the original function name.
|
|
names = []
|
|
for ref in script.get_names(all_scopes=True, definitions=False, references=True):
|
|
if ref.full_name:
|
|
if function_to_optimize.parents:
|
|
# Check if the reference belongs to the specified class when FunctionParent is provided
|
|
if belongs_to_method(ref, function_to_optimize.parents[-1].name, function_name):
|
|
names.append(ref)
|
|
elif belongs_to_function(ref, function_name):
|
|
names.append(ref)
|
|
|
|
for name in names:
|
|
try:
|
|
definitions: list[Name] = name.goto(follow_imports=True, follow_builtin_imports=False)
|
|
except Exception as e:
|
|
try:
|
|
logger.exception(f"Error while getting definition for {name.full_name}: {e}")
|
|
except Exception as e:
|
|
# name.full_name can also throw exceptions sometimes
|
|
logger.exception(f"Error while getting definition: {e}")
|
|
definitions = []
|
|
if definitions:
|
|
# TODO: there can be multiple definitions, see how to handle such cases
|
|
definition = definitions[0]
|
|
definition_path = definition.module_path
|
|
|
|
# The definition is part of this project and not defined within the original function
|
|
if (
|
|
str(definition_path).startswith(str(project_root_path) + os.sep)
|
|
and not path_belongs_to_site_packages(definition_path)
|
|
and definition.full_name
|
|
and not belongs_to_function(definition, function_name)
|
|
):
|
|
module_name = module_name_from_file_path(definition_path, project_root_path)
|
|
m = re.match(rf"{module_name}\.(.*)\.{definitions[0].name}", definitions[0].full_name)
|
|
parents = []
|
|
if m:
|
|
parents = [FunctionParent(m.group(1), "ClassDef")]
|
|
|
|
source_code = get_code(
|
|
[FunctionToOptimize(function_name=definitions[0].name, file_path=definition_path, parents=parents)]
|
|
)
|
|
if source_code[0]:
|
|
sources.append(
|
|
FunctionSource(
|
|
fully_qualified_name=definition.full_name,
|
|
jedi_definition=definition,
|
|
source_code=source_code[0],
|
|
file_path=definition_path,
|
|
qualified_name=definition.full_name.removeprefix(definition.module_name + "."),
|
|
only_function_name=definition.name,
|
|
)
|
|
)
|
|
contextual_dunder_methods.update(source_code[1])
|
|
annotation_sources, annotation_dunder_methods = get_type_annotation_context(
|
|
function_to_optimize, script, project_root_path
|
|
)
|
|
sources[:0] = annotation_sources # prepend the annotation sources
|
|
contextual_dunder_methods.update(annotation_dunder_methods)
|
|
existing_fully_qualified_names = set()
|
|
no_parent_sources: dict[Path, dict[str, set[FunctionSource]]] = defaultdict(lambda: defaultdict(set))
|
|
parent_sources = set()
|
|
for source in sources:
|
|
if (fully_qualified_name := source.fully_qualified_name) not in existing_fully_qualified_names:
|
|
if not source.qualified_name.count("."):
|
|
no_parent_sources[source.file_path][source.qualified_name].add(source)
|
|
else:
|
|
parent_sources.add(source)
|
|
existing_fully_qualified_names.add(fully_qualified_name)
|
|
deduped_parent_sources = [
|
|
source
|
|
for source in parent_sources
|
|
if source.file_path not in no_parent_sources
|
|
or source.qualified_name.rpartition(".")[0] not in no_parent_sources[source.file_path]
|
|
]
|
|
deduped_no_parent_sources = [
|
|
source for k1 in no_parent_sources for k2 in no_parent_sources[k1] for source in no_parent_sources[k1][k2]
|
|
]
|
|
return deduped_no_parent_sources + deduped_parent_sources, contextual_dunder_methods
|
|
|
|
|
|
MAX_PROMPT_TOKENS = 4096 # 128000 # gpt-4-128k
|
|
|
|
|
|
def get_constrained_function_context_and_helper_functions(
|
|
function_to_optimize: FunctionToOptimize,
|
|
project_root_path: Path,
|
|
code_to_optimize: str,
|
|
max_tokens: int = MAX_PROMPT_TOKENS,
|
|
) -> tuple[str, list[FunctionSource], set[tuple[str, str]]]:
|
|
helper_functions, dunder_methods = get_function_variables_definitions(function_to_optimize, project_root_path)
|
|
tokenizer = tiktoken.encoding_for_model("gpt-3.5-turbo")
|
|
code_to_optimize_tokens = tokenizer.encode(code_to_optimize)
|
|
|
|
if not function_to_optimize.parents:
|
|
helper_functions_sources = [function.source_code for function in helper_functions]
|
|
else:
|
|
helper_functions_sources = [
|
|
function.source_code
|
|
for function in helper_functions
|
|
if not function.qualified_name.count(".")
|
|
or function.qualified_name.split(".")[0] != function_to_optimize.parents[0].name
|
|
]
|
|
helper_functions_tokens = [len(tokenizer.encode(function)) for function in helper_functions_sources]
|
|
|
|
context_list = []
|
|
context_len = len(code_to_optimize_tokens)
|
|
logger.debug(f"ORIGINAL CODE TOKENS LENGTH: {context_len}")
|
|
logger.debug(f"ALL DEPENDENCIES TOKENS LENGTH: {sum(helper_functions_tokens)}")
|
|
for function_source, source_len in zip(helper_functions_sources, helper_functions_tokens):
|
|
if context_len + source_len <= max_tokens:
|
|
context_list.append(function_source)
|
|
context_len += source_len
|
|
else:
|
|
break
|
|
logger.debug(f"FINAL OPTIMIZATION CONTEXT TOKENS LENGTH: {context_len}")
|
|
helper_code: str = "\n".join(context_list)
|
|
return helper_code, helper_functions, dunder_methods
|