Merge pull request #1369 from codeflash-ai/helper_fn_discovery_list_compre

fixed function discovery in list comprehension
This commit is contained in:
Alvin Ryanputra 2024-12-26 15:03:15 -08:00 committed by GitHub
commit 5525a67598
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 80 additions and 34 deletions

View file

@ -20,36 +20,27 @@ if TYPE_CHECKING:
from pathlib import Path
def belongs_to_class(name: Name, class_name: str) -> bool:
"""Check if the given name belongs to the specified class."""
return bool(name.full_name and name.full_name.startswith(f"{name.module_name}.{class_name}."))
def belongs_to_method(name: Name, class_name: str, method_name: str) -> bool:
"""Check if the given name belongs to the specified method."""
return belongs_to_function(name, method_name) and belongs_to_class(name, class_name)
def belongs_to_function(name: Name, function_name: str) -> bool:
"""Check if the given name belongs to the specified function"""
if name.full_name and name.full_name.startswith(name.module_name):
subname: str = name.full_name.replace(name.module_name, "", 1)
# The name is defined inside the function or is the function itself
if f".{function_name}." in subname or f".{function_name}" == subname:
return True
return bool(name_in_listcomp_in_function(name, function_name))
"""Check if the given jedi Name is a direct child of the specified function"""
if name.name == function_name: # Handles function definition and recursive function calls
return False
if name := name.parent():
if name.type == "function":
return name.name == function_name
return False
def name_in_listcomp_in_function(name: Name, function_name: str) -> bool:
"""Check if the given name is in a list comprehension in the specified function
Special case because jedi has a bug https://github.com/davidhalter/jedi/issues/1944
"""
try:
parent_node = name._name.parent_context.tree_node.parent
if hasattr(parent_node, "type") and parent_node.type == "testlist_comp":
while parent_node := parent_node.parent:
if parent_node.type == "funcdef":
return parent_node.name.value == function_name
return False
except Exception:
# don't want to handle conformance with 3rd party library private attribute access exception types
return False
def belongs_to_class(name: Name, class_name: str) -> bool:
"""Check if given jedi Name is a direct child of the specified class"""
while name := name.parent():
if name.type == "class":
return name.name == class_name
return False
def get_type_annotation_context(
@ -176,9 +167,7 @@ def get_function_variables_definitions(
if ref.full_name:
if function_to_optimize.parents:
# Check if the reference belongs to the specified class when FunctionParent is provided
if belongs_to_class(ref, function_to_optimize.parents[-1].name) and belongs_to_function(
ref, function_name
):
if belongs_to_method(ref, function_to_optimize.parents[-1].name, function_name):
names.append(ref)
elif belongs_to_function(ref, function_name):
names.append(ref)

View file

@ -49,6 +49,19 @@ class A:
c = global_dependency_1(b)
return c
def function_in_list_comprehension(self):
return [global_dependency_3(1) for x in range(10)]
def add_two(self, num):
return num + 2
def method_in_list_comprehension(self):
return [self.add_two(1) for x in range(10)]
def nested_function(self):
def nested():
return global_dependency_3(1)
return nested() + self.add_two(3)
class B:
def calculate_something_2(self, num):
@ -79,13 +92,12 @@ class C:
def test_multiple_classes_dependencies() -> None:
# TODO: Check if C.run only gets calculate_something_3 as dependency and likewise for other classes
file_path = pathlib.Path(__file__).resolve()
helper_functions = get_function_variables_definitions(
FunctionToOptimize("run", str(file_path), [FunctionParent("C", "ClassDef")]), str(file_path.parent.resolve())
)
# assert len(helper_functions) == 2
assert len(helper_functions) == 2
assert list(map(lambda x: x.fully_qualified_name, helper_functions[0])) == [
"test_function_dependencies.global_dependency_3",
"test_function_dependencies.C.calculate_something_3",
@ -291,11 +303,8 @@ def test_recursive_function_context() -> None:
if not is_successful(ctx_result):
pytest.fail()
code_context = ctx_result.unwrap()
# The code_context above should have the topologicalSortUtil function in it
assert len(code_context.helper_functions) == 2
assert set(
[code_context.helper_functions[1].fully_qualified_name, code_context.helper_functions[0].fully_qualified_name]
) == set(["test_function_dependencies.C.calculate_something_3", "test_function_dependencies.C.recursive"])
assert len(code_context.helper_functions) == 1
assert code_context.helper_functions[0].fully_qualified_name == "test_function_dependencies.C.calculate_something_3"
assert (
code_context.code_to_optimize_with_helpers
== """class C:
@ -318,3 +327,51 @@ def test_list_comprehension_dependency() -> None:
assert len(helper_functions) == 2
assert helper_functions[0].jedi_definition.full_name == "test_function_dependencies.MyData"
assert helper_functions[1].jedi_definition.full_name == "test_function_dependencies.calculate_something"
def test_function_in_method_list_comprehension() -> None:
file_path = pathlib.Path(__file__).resolve()
function_to_optimize = FunctionToOptimize(
function_name="function_in_list_comprehension",
file_path=str(file_path),
parents=[FunctionParent(name="A", type="ClassDef")],
starting_line=None,
ending_line=None,
)
helper_functions = get_function_variables_definitions(function_to_optimize, str(file_path.parent.resolve()))[0]
assert len(helper_functions) == 1
assert helper_functions[0].jedi_definition.full_name == "test_function_dependencies.global_dependency_3"
def test_method_in_method_list_comprehension() -> None:
file_path = pathlib.Path(__file__).resolve()
function_to_optimize = FunctionToOptimize(
function_name="method_in_list_comprehension",
file_path=str(file_path),
parents=[FunctionParent(name="A", type="ClassDef")],
starting_line=None,
ending_line=None,
)
helper_functions = get_function_variables_definitions(function_to_optimize, str(file_path.parent.resolve()))[0]
assert len(helper_functions) == 1
assert helper_functions[0].jedi_definition.full_name == "test_function_dependencies.A.add_two"
def test_nested_method() -> None:
file_path = pathlib.Path(__file__).resolve()
function_to_optimize = FunctionToOptimize(
function_name="nested_function",
file_path=str(file_path),
parents=[FunctionParent(name="A", type="ClassDef")],
starting_line=None,
ending_line=None,
)
helper_functions = get_function_variables_definitions(function_to_optimize, str(file_path.parent.resolve()))[0]
# The nested function should be included in the helper functions
assert len(helper_functions) == 1
assert helper_functions[0].jedi_definition.full_name == "test_function_dependencies.A.add_two"