Fixing latest helper method dupe issues

This commit is contained in:
RD 2024-06-09 05:30:06 -07:00
parent 1f84426a9e
commit a55eeb417a
3 changed files with 112 additions and 40 deletions

View file

@ -2,7 +2,8 @@ import ast
import logging
import os
import re
from typing import Tuple, Union
from collections import defaultdict
from typing import Union
import jedi
import tiktoken
@ -39,9 +40,9 @@ class Source:
def get_type_annotation_context(
function: FunctionToOptimize,
jedi_script: jedi.Script,
project_root_path: str,
function: FunctionToOptimize,
jedi_script: jedi.Script,
project_root_path: str,
) -> list[tuple[Source, str, str]]:
function_name: str = function.function_name
file_path: str = function.file_path
@ -57,14 +58,14 @@ def get_type_annotation_context(
contextual_dunder_methods = set()
def get_annotation_source(
jedi_script: jedi.Script,
name: str,
node_parents,
line_no: int,
col_no: str,
) -> str:
j_script: jedi.Script,
name: str,
node_parents: list[FunctionParent],
line_no: int,
col_no: str,
) -> None:
try:
definition: list[Name] = jedi_script.goto(
definition: list[Name] = j_script.goto(
line=line_no,
column=col_no,
follow_imports=True,
@ -82,10 +83,10 @@ def get_type_annotation_context(
definition_path = str(definition[0].module_path)
# The definition is part of this project and not defined within the original function
if (
definition_path.startswith(project_root_path + os.sep)
and definition[0].full_name
and not path_belongs_to_site_packages(definition_path)
and not belongs_to_function(definition[0], function_name)
definition_path.startswith(project_root_path + os.sep)
and definition[0].full_name
and not path_belongs_to_site_packages(definition_path)
and not belongs_to_function(definition[0], function_name)
):
source_code = get_code(
[
@ -113,16 +114,16 @@ def get_type_annotation_context(
contextual_dunder_methods.update(source_code[1])
def visit_children(
node: Union[ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef, ast.Module],
node_parents: list[FunctionParent],
node: Union[ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef, ast.Module],
node_parents: list[FunctionParent],
) -> None:
child: Union[ast.AST, ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef, ast.Module]
for child in ast.iter_child_nodes(node):
visit(child, node_parents)
def visit_all_annotation_children(
node: Union[ast.Subscript, ast.Name, ast.BinOp],
node_parents: list[FunctionParent],
node: Union[ast.Subscript, ast.Name, ast.BinOp],
node_parents: list[FunctionParent],
) -> None:
if isinstance(node, ast.BinOp) and isinstance(node.op, ast.BitOr):
visit_all_annotation_children(node.left, node_parents)
@ -146,8 +147,8 @@ def get_type_annotation_context(
visit_all_annotation_children(node.value, node_parents)
def visit(
node: Union[ast.AST, ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef, ast.Module],
node_parents: list[FunctionParent],
node: Union[ast.AST, ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef, ast.Module],
node_parents: list[FunctionParent],
) -> None:
if isinstance(node, (ast.Module, ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)):
if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
@ -171,9 +172,9 @@ def get_type_annotation_context(
def get_function_variables_definitions(
function_to_optimize: FunctionToOptimize,
project_root_path: str,
) -> Tuple[list[tuple[Source, str, str]], set[tuple[str, str]]]:
function_to_optimize: FunctionToOptimize,
project_root_path: str,
) -> tuple[list[tuple[Source, str, str]], set[tuple[str, str]]]:
function_name = function_to_optimize.function_name
file_path = function_to_optimize.file_path
script = jedi.Script(path=file_path, project=jedi.Project(path=project_root_path))
@ -188,8 +189,8 @@ def get_function_variables_definitions(
if function_to_optimize.parents:
# Check if the reference belongs to the specified class when FunctionParent is provided
if belongs_to_class(
ref,
function_to_optimize.parents[-1].name,
ref,
function_to_optimize.parents[-1].name,
) and belongs_to_function(ref, function_name):
names.append(ref)
elif belongs_to_function(ref, function_name):
@ -214,10 +215,10 @@ def get_function_variables_definitions(
definition_path = str(definition.module_path)
# The definition is part of this project and not defined within the original function
if (
definition_path.startswith(project_root_path + os.sep)
and not path_belongs_to_site_packages(definition_path)
and definition.full_name
and not belongs_to_function(definition, function_name)
definition_path.startswith(project_root_path + os.sep)
and not path_belongs_to_site_packages(definition_path)
and definition.full_name
and not belongs_to_function(definition, function_name)
):
module_name = module_name_from_file_path(definition_path, project_root_path)
m = re.match(rf"{module_name}\.(.*)\.{definitions[0].name}", definitions[0].full_name)
@ -250,23 +251,29 @@ def get_function_variables_definitions(
)
sources[:0] = annotation_sources # prepend the annotation sources
contextual_dunder_methods.update(annotation_dunder_methods)
deduped_sources = []
existing_full_names = set()
no_parent_sources: dict[str, dict[str, set[tuple[Source, str, str]]]] = defaultdict(lambda: defaultdict(set))
parent_sources = set()
for source in sources:
if source[0].full_name not in existing_full_names:
deduped_sources.append(source)
existing_full_names.add(source[0].full_name)
return deduped_sources, contextual_dunder_methods
if (full_name := source[0].full_name) not in existing_full_names:
if not source[2].count("."):
no_parent_sources[source[1]][source[2]].add(source)
else:
parent_sources.add(source)
existing_full_names.add(full_name)
deduped_parent_sources = [source for source in parent_sources if source[1] not in no_parent_sources or source[2].rpartition('.')[0] not in no_parent_sources[source[1]]]
deduped_no_parent_sources = [source for k1 in no_parent_sources for k2 in no_parent_sources[k1] for source in no_parent_sources[k1][k2]]
return deduped_no_parent_sources + deduped_parent_sources, contextual_dunder_methods
MAX_PROMPT_TOKENS = 4096 # 128000 # gpt-4-128k
def get_constrained_function_context_and_helper_functions(
function_to_optimize: FunctionToOptimize,
project_root_path: str,
code_to_optimize: str,
max_tokens: int = MAX_PROMPT_TOKENS,
function_to_optimize: FunctionToOptimize,
project_root_path: str,
code_to_optimize: str,
max_tokens: int = MAX_PROMPT_TOKENS,
) -> tuple[str, list[tuple[Source, str, str]], set[tuple[str, str]]]:
# TODO: Not just do static analysis, but also find the datatypes of function arguments by running the existing
# unittests and inspecting the arguments to resolve the real definitions and dependencies.

View file

@ -1,8 +1,14 @@
from __future__ import annotations
import os
from argparse import Namespace
from pathlib import Path
from returns.pipeline import is_successful
from codeflash.code_utils.code_replacer import replace_functions_in_file
from codeflash.discovery.functions_to_optimize import FunctionToOptimize, FunctionParent
from codeflash.optimization.optimizer import Optimizer
os.environ["CODEFLASH_API_KEY"] = "cf-test-key"
@ -707,3 +713,62 @@ print("Hello world")
contextual_functions,
)
assert new_code == expected
class HelperClass:
def __init__(self, name):
self.name = name
def innocent_bystander(self):
pass
def helper_method(self):
return self.name
class MainClass:
def __init__(self, name):
self.name = name
def main_method(self):
return HelperClass(self.name).helper_method()
def test_code_replacement10() -> None:
get_code_output = """from __future__ import annotations
class HelperClass:
def __init__(self, name):
self.name = name
def innocent_bystander(self):
pass
def helper_method(self):
return self.name
class MainClass:
def __init__(self, name):
self.name = name
def main_method(self):
return HelperClass(self.name).helper_method()
"""
file_path = Path(__file__).resolve()
opt = Optimizer(
Namespace(
project_root=str(file_path.parent.resolve()),
disable_telemetry=True,
tests_root="tests",
test_framework="pytest",
pytest_cmd="pytest",
experiment_id=None,
),
)
func_top_optimize = FunctionToOptimize(function_name="main_method", file_path=str(file_path),
parents=[FunctionParent("MainClass", "ClassDef")])
with open(file_path) as f:
original_code = f.read()
code_context = opt.get_code_optimization_context(function_to_optimize=func_top_optimize,
project_root=str(file_path.parent),
original_source_code=original_code).unwrap()
assert code_context.code_to_optimize_with_helpers == get_code_output

View file

@ -82,8 +82,8 @@ def test_multiple_classes_dependencies():
# assert len(helper_functions) == 2
assert list(map(lambda x: x[0].full_name, helper_functions[0])) == [
"test_function_dependencies.C.calculate_something_3",
"test_function_dependencies.global_dependency_3",
"test_function_dependencies.C.calculate_something_3",
]