mirror of
https://github.com/codeflash-ai/codeflash.git
synced 2026-05-04 18:25:17 +00:00
Fixing latest helper method dupe issues
This commit is contained in:
parent
1f84426a9e
commit
a55eeb417a
3 changed files with 112 additions and 40 deletions
|
|
@ -2,7 +2,8 @@ import ast
|
|||
import logging
|
||||
import os
|
||||
import re
|
||||
from typing import Tuple, Union
|
||||
from collections import defaultdict
|
||||
from typing import Union
|
||||
|
||||
import jedi
|
||||
import tiktoken
|
||||
|
|
@ -39,9 +40,9 @@ class Source:
|
|||
|
||||
|
||||
def get_type_annotation_context(
|
||||
function: FunctionToOptimize,
|
||||
jedi_script: jedi.Script,
|
||||
project_root_path: str,
|
||||
function: FunctionToOptimize,
|
||||
jedi_script: jedi.Script,
|
||||
project_root_path: str,
|
||||
) -> list[tuple[Source, str, str]]:
|
||||
function_name: str = function.function_name
|
||||
file_path: str = function.file_path
|
||||
|
|
@ -57,14 +58,14 @@ def get_type_annotation_context(
|
|||
contextual_dunder_methods = set()
|
||||
|
||||
def get_annotation_source(
|
||||
jedi_script: jedi.Script,
|
||||
name: str,
|
||||
node_parents,
|
||||
line_no: int,
|
||||
col_no: str,
|
||||
) -> str:
|
||||
j_script: jedi.Script,
|
||||
name: str,
|
||||
node_parents: list[FunctionParent],
|
||||
line_no: int,
|
||||
col_no: str,
|
||||
) -> None:
|
||||
try:
|
||||
definition: list[Name] = jedi_script.goto(
|
||||
definition: list[Name] = j_script.goto(
|
||||
line=line_no,
|
||||
column=col_no,
|
||||
follow_imports=True,
|
||||
|
|
@ -82,10 +83,10 @@ def get_type_annotation_context(
|
|||
definition_path = str(definition[0].module_path)
|
||||
# The definition is part of this project and not defined within the original function
|
||||
if (
|
||||
definition_path.startswith(project_root_path + os.sep)
|
||||
and definition[0].full_name
|
||||
and not path_belongs_to_site_packages(definition_path)
|
||||
and not belongs_to_function(definition[0], function_name)
|
||||
definition_path.startswith(project_root_path + os.sep)
|
||||
and definition[0].full_name
|
||||
and not path_belongs_to_site_packages(definition_path)
|
||||
and not belongs_to_function(definition[0], function_name)
|
||||
):
|
||||
source_code = get_code(
|
||||
[
|
||||
|
|
@ -113,16 +114,16 @@ def get_type_annotation_context(
|
|||
contextual_dunder_methods.update(source_code[1])
|
||||
|
||||
def visit_children(
|
||||
node: Union[ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef, ast.Module],
|
||||
node_parents: list[FunctionParent],
|
||||
node: Union[ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef, ast.Module],
|
||||
node_parents: list[FunctionParent],
|
||||
) -> None:
|
||||
child: Union[ast.AST, ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef, ast.Module]
|
||||
for child in ast.iter_child_nodes(node):
|
||||
visit(child, node_parents)
|
||||
|
||||
def visit_all_annotation_children(
|
||||
node: Union[ast.Subscript, ast.Name, ast.BinOp],
|
||||
node_parents: list[FunctionParent],
|
||||
node: Union[ast.Subscript, ast.Name, ast.BinOp],
|
||||
node_parents: list[FunctionParent],
|
||||
) -> None:
|
||||
if isinstance(node, ast.BinOp) and isinstance(node.op, ast.BitOr):
|
||||
visit_all_annotation_children(node.left, node_parents)
|
||||
|
|
@ -146,8 +147,8 @@ def get_type_annotation_context(
|
|||
visit_all_annotation_children(node.value, node_parents)
|
||||
|
||||
def visit(
|
||||
node: Union[ast.AST, ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef, ast.Module],
|
||||
node_parents: list[FunctionParent],
|
||||
node: Union[ast.AST, ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef, ast.Module],
|
||||
node_parents: list[FunctionParent],
|
||||
) -> None:
|
||||
if isinstance(node, (ast.Module, ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)):
|
||||
if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
|
||||
|
|
@ -171,9 +172,9 @@ def get_type_annotation_context(
|
|||
|
||||
|
||||
def get_function_variables_definitions(
|
||||
function_to_optimize: FunctionToOptimize,
|
||||
project_root_path: str,
|
||||
) -> Tuple[list[tuple[Source, str, str]], set[tuple[str, str]]]:
|
||||
function_to_optimize: FunctionToOptimize,
|
||||
project_root_path: str,
|
||||
) -> tuple[list[tuple[Source, str, str]], set[tuple[str, str]]]:
|
||||
function_name = function_to_optimize.function_name
|
||||
file_path = function_to_optimize.file_path
|
||||
script = jedi.Script(path=file_path, project=jedi.Project(path=project_root_path))
|
||||
|
|
@ -188,8 +189,8 @@ def get_function_variables_definitions(
|
|||
if function_to_optimize.parents:
|
||||
# Check if the reference belongs to the specified class when FunctionParent is provided
|
||||
if belongs_to_class(
|
||||
ref,
|
||||
function_to_optimize.parents[-1].name,
|
||||
ref,
|
||||
function_to_optimize.parents[-1].name,
|
||||
) and belongs_to_function(ref, function_name):
|
||||
names.append(ref)
|
||||
elif belongs_to_function(ref, function_name):
|
||||
|
|
@ -214,10 +215,10 @@ def get_function_variables_definitions(
|
|||
definition_path = str(definition.module_path)
|
||||
# The definition is part of this project and not defined within the original function
|
||||
if (
|
||||
definition_path.startswith(project_root_path + os.sep)
|
||||
and not path_belongs_to_site_packages(definition_path)
|
||||
and definition.full_name
|
||||
and not belongs_to_function(definition, function_name)
|
||||
definition_path.startswith(project_root_path + os.sep)
|
||||
and not path_belongs_to_site_packages(definition_path)
|
||||
and definition.full_name
|
||||
and not belongs_to_function(definition, function_name)
|
||||
):
|
||||
module_name = module_name_from_file_path(definition_path, project_root_path)
|
||||
m = re.match(rf"{module_name}\.(.*)\.{definitions[0].name}", definitions[0].full_name)
|
||||
|
|
@ -250,23 +251,29 @@ def get_function_variables_definitions(
|
|||
)
|
||||
sources[:0] = annotation_sources # prepend the annotation sources
|
||||
contextual_dunder_methods.update(annotation_dunder_methods)
|
||||
deduped_sources = []
|
||||
existing_full_names = set()
|
||||
no_parent_sources: dict[str, dict[str, set[tuple[Source, str, str]]]] = defaultdict(lambda: defaultdict(set))
|
||||
parent_sources = set()
|
||||
for source in sources:
|
||||
if source[0].full_name not in existing_full_names:
|
||||
deduped_sources.append(source)
|
||||
existing_full_names.add(source[0].full_name)
|
||||
return deduped_sources, contextual_dunder_methods
|
||||
if (full_name := source[0].full_name) not in existing_full_names:
|
||||
if not source[2].count("."):
|
||||
no_parent_sources[source[1]][source[2]].add(source)
|
||||
else:
|
||||
parent_sources.add(source)
|
||||
existing_full_names.add(full_name)
|
||||
deduped_parent_sources = [source for source in parent_sources if source[1] not in no_parent_sources or source[2].rpartition('.')[0] not in no_parent_sources[source[1]]]
|
||||
deduped_no_parent_sources = [source for k1 in no_parent_sources for k2 in no_parent_sources[k1] for source in no_parent_sources[k1][k2]]
|
||||
return deduped_no_parent_sources + deduped_parent_sources, contextual_dunder_methods
|
||||
|
||||
|
||||
MAX_PROMPT_TOKENS = 4096 # 128000 # gpt-4-128k
|
||||
|
||||
|
||||
def get_constrained_function_context_and_helper_functions(
|
||||
function_to_optimize: FunctionToOptimize,
|
||||
project_root_path: str,
|
||||
code_to_optimize: str,
|
||||
max_tokens: int = MAX_PROMPT_TOKENS,
|
||||
function_to_optimize: FunctionToOptimize,
|
||||
project_root_path: str,
|
||||
code_to_optimize: str,
|
||||
max_tokens: int = MAX_PROMPT_TOKENS,
|
||||
) -> tuple[str, list[tuple[Source, str, str]], set[tuple[str, str]]]:
|
||||
# TODO: Go beyond static analysis: also determine the data types of function arguments by running the existing
|
||||
# unittests and inspecting the arguments to resolve the real definitions and dependencies.
|
||||
|
|
|
|||
|
|
@ -1,8 +1,14 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from argparse import Namespace
|
||||
from pathlib import Path
|
||||
|
||||
from returns.pipeline import is_successful
|
||||
|
||||
from codeflash.code_utils.code_replacer import replace_functions_in_file
|
||||
from codeflash.discovery.functions_to_optimize import FunctionToOptimize, FunctionParent
|
||||
from codeflash.optimization.optimizer import Optimizer
|
||||
|
||||
os.environ["CODEFLASH_API_KEY"] = "cf-test-key"
|
||||
|
||||
|
|
@ -707,3 +713,62 @@ print("Hello world")
|
|||
contextual_functions,
|
||||
)
|
||||
assert new_code == expected
|
||||
|
||||
|
||||
class HelperClass:
|
||||
def __init__(self, name):
|
||||
self.name = name
|
||||
|
||||
def innocent_bystander(self):
|
||||
pass
|
||||
|
||||
def helper_method(self):
|
||||
return self.name
|
||||
|
||||
|
||||
class MainClass:
|
||||
def __init__(self, name):
|
||||
self.name = name
|
||||
|
||||
def main_method(self):
|
||||
return HelperClass(self.name).helper_method()
|
||||
|
||||
|
||||
def test_code_replacement10() -> None:
|
||||
get_code_output = """from __future__ import annotations
|
||||
|
||||
class HelperClass:
|
||||
def __init__(self, name):
|
||||
self.name = name
|
||||
|
||||
def innocent_bystander(self):
|
||||
pass
|
||||
|
||||
def helper_method(self):
|
||||
return self.name
|
||||
|
||||
class MainClass:
|
||||
def __init__(self, name):
|
||||
self.name = name
|
||||
def main_method(self):
|
||||
return HelperClass(self.name).helper_method()
|
||||
"""
|
||||
file_path = Path(__file__).resolve()
|
||||
opt = Optimizer(
|
||||
Namespace(
|
||||
project_root=str(file_path.parent.resolve()),
|
||||
disable_telemetry=True,
|
||||
tests_root="tests",
|
||||
test_framework="pytest",
|
||||
pytest_cmd="pytest",
|
||||
experiment_id=None,
|
||||
),
|
||||
)
|
||||
func_top_optimize = FunctionToOptimize(function_name="main_method", file_path=str(file_path),
|
||||
parents=[FunctionParent("MainClass", "ClassDef")])
|
||||
with open(file_path) as f:
|
||||
original_code = f.read()
|
||||
code_context = opt.get_code_optimization_context(function_to_optimize=func_top_optimize,
|
||||
project_root=str(file_path.parent),
|
||||
original_source_code=original_code).unwrap()
|
||||
assert code_context.code_to_optimize_with_helpers == get_code_output
|
||||
|
|
|
|||
|
|
@ -82,8 +82,8 @@ def test_multiple_classes_dependencies():
|
|||
|
||||
# assert len(helper_functions) == 2
|
||||
assert list(map(lambda x: x[0].full_name, helper_functions[0])) == [
|
||||
"test_function_dependencies.C.calculate_something_3",
|
||||
"test_function_dependencies.global_dependency_3",
|
||||
"test_function_dependencies.C.calculate_something_3",
|
||||
]
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue