mirror of
https://github.com/codeflash-ai/codeflash-internal.git
synced 2026-05-04 18:25:18 +00:00
Merge pull request #1369 from codeflash-ai/helper_fn_discovery_list_compre
fixed function discovery in list comprehension
This commit is contained in:
commit
5525a67598
2 changed files with 80 additions and 34 deletions
|
|
@ -20,36 +20,27 @@ if TYPE_CHECKING:
|
|||
from pathlib import Path
|
||||
|
||||
|
||||
def belongs_to_class(name: Name, class_name: str) -> bool:
|
||||
"""Check if the given name belongs to the specified class."""
|
||||
return bool(name.full_name and name.full_name.startswith(f"{name.module_name}.{class_name}."))
|
||||
def belongs_to_method(name: Name, class_name: str, method_name: str) -> bool:
|
||||
"""Check if the given name belongs to the specified method."""
|
||||
return belongs_to_function(name, method_name) and belongs_to_class(name, class_name)
|
||||
|
||||
|
||||
def belongs_to_function(name: Name, function_name: str) -> bool:
|
||||
"""Check if the given name belongs to the specified function"""
|
||||
if name.full_name and name.full_name.startswith(name.module_name):
|
||||
subname: str = name.full_name.replace(name.module_name, "", 1)
|
||||
# The name is defined inside the function or is the function itself
|
||||
if f".{function_name}." in subname or f".{function_name}" == subname:
|
||||
return True
|
||||
return bool(name_in_listcomp_in_function(name, function_name))
|
||||
"""Check if the given jedi Name is a direct child of the specified function"""
|
||||
if name.name == function_name: # Handles function definition and recursive function calls
|
||||
return False
|
||||
if name := name.parent():
|
||||
if name.type == "function":
|
||||
return name.name == function_name
|
||||
return False
|
||||
|
||||
|
||||
def name_in_listcomp_in_function(name: Name, function_name: str) -> bool:
|
||||
"""Check if the given name is in a list comprehension in the specified function
|
||||
Special case because jedi has a bug https://github.com/davidhalter/jedi/issues/1944
|
||||
"""
|
||||
try:
|
||||
parent_node = name._name.parent_context.tree_node.parent
|
||||
if hasattr(parent_node, "type") and parent_node.type == "testlist_comp":
|
||||
while parent_node := parent_node.parent:
|
||||
if parent_node.type == "funcdef":
|
||||
return parent_node.name.value == function_name
|
||||
return False
|
||||
except Exception:
|
||||
# don't want to handle conformance with 3rd party library private attribute access exception types
|
||||
return False
|
||||
def belongs_to_class(name: Name, class_name: str) -> bool:
|
||||
"""Check if given jedi Name is a direct child of the specified class"""
|
||||
while name := name.parent():
|
||||
if name.type == "class":
|
||||
return name.name == class_name
|
||||
return False
|
||||
|
||||
|
||||
def get_type_annotation_context(
|
||||
|
|
@ -176,9 +167,7 @@ def get_function_variables_definitions(
|
|||
if ref.full_name:
|
||||
if function_to_optimize.parents:
|
||||
# Check if the reference belongs to the specified class when FunctionParent is provided
|
||||
if belongs_to_class(ref, function_to_optimize.parents[-1].name) and belongs_to_function(
|
||||
ref, function_name
|
||||
):
|
||||
if belongs_to_method(ref, function_to_optimize.parents[-1].name, function_name):
|
||||
names.append(ref)
|
||||
elif belongs_to_function(ref, function_name):
|
||||
names.append(ref)
|
||||
|
|
|
|||
|
|
@ -49,6 +49,19 @@ class A:
|
|||
c = global_dependency_1(b)
|
||||
return c
|
||||
|
||||
def function_in_list_comprehension(self):
|
||||
return [global_dependency_3(1) for x in range(10)]
|
||||
|
||||
def add_two(self, num):
|
||||
return num + 2
|
||||
|
||||
def method_in_list_comprehension(self):
|
||||
return [self.add_two(1) for x in range(10)]
|
||||
|
||||
def nested_function(self):
|
||||
def nested():
|
||||
return global_dependency_3(1)
|
||||
return nested() + self.add_two(3)
|
||||
|
||||
class B:
|
||||
def calculate_something_2(self, num):
|
||||
|
|
@ -79,13 +92,12 @@ class C:
|
|||
|
||||
|
||||
def test_multiple_classes_dependencies() -> None:
|
||||
# TODO: Check if C.run only gets calculate_something_3 as dependency and likewise for other classes
|
||||
file_path = pathlib.Path(__file__).resolve()
|
||||
helper_functions = get_function_variables_definitions(
|
||||
FunctionToOptimize("run", str(file_path), [FunctionParent("C", "ClassDef")]), str(file_path.parent.resolve())
|
||||
)
|
||||
|
||||
# assert len(helper_functions) == 2
|
||||
assert len(helper_functions) == 2
|
||||
assert list(map(lambda x: x.fully_qualified_name, helper_functions[0])) == [
|
||||
"test_function_dependencies.global_dependency_3",
|
||||
"test_function_dependencies.C.calculate_something_3",
|
||||
|
|
@ -291,11 +303,8 @@ def test_recursive_function_context() -> None:
|
|||
if not is_successful(ctx_result):
|
||||
pytest.fail()
|
||||
code_context = ctx_result.unwrap()
|
||||
# The code_context above should have the topologicalSortUtil function in it
|
||||
assert len(code_context.helper_functions) == 2
|
||||
assert set(
|
||||
[code_context.helper_functions[1].fully_qualified_name, code_context.helper_functions[0].fully_qualified_name]
|
||||
) == set(["test_function_dependencies.C.calculate_something_3", "test_function_dependencies.C.recursive"])
|
||||
assert len(code_context.helper_functions) == 1
|
||||
assert code_context.helper_functions[0].fully_qualified_name == "test_function_dependencies.C.calculate_something_3"
|
||||
assert (
|
||||
code_context.code_to_optimize_with_helpers
|
||||
== """class C:
|
||||
|
|
@ -318,3 +327,51 @@ def test_list_comprehension_dependency() -> None:
|
|||
assert len(helper_functions) == 2
|
||||
assert helper_functions[0].jedi_definition.full_name == "test_function_dependencies.MyData"
|
||||
assert helper_functions[1].jedi_definition.full_name == "test_function_dependencies.calculate_something"
|
||||
|
||||
|
||||
def test_function_in_method_list_comprehension() -> None:
|
||||
file_path = pathlib.Path(__file__).resolve()
|
||||
function_to_optimize = FunctionToOptimize(
|
||||
function_name="function_in_list_comprehension",
|
||||
file_path=str(file_path),
|
||||
parents=[FunctionParent(name="A", type="ClassDef")],
|
||||
starting_line=None,
|
||||
ending_line=None,
|
||||
)
|
||||
|
||||
helper_functions = get_function_variables_definitions(function_to_optimize, str(file_path.parent.resolve()))[0]
|
||||
|
||||
assert len(helper_functions) == 1
|
||||
assert helper_functions[0].jedi_definition.full_name == "test_function_dependencies.global_dependency_3"
|
||||
|
||||
|
||||
def test_method_in_method_list_comprehension() -> None:
|
||||
file_path = pathlib.Path(__file__).resolve()
|
||||
function_to_optimize = FunctionToOptimize(
|
||||
function_name="method_in_list_comprehension",
|
||||
file_path=str(file_path),
|
||||
parents=[FunctionParent(name="A", type="ClassDef")],
|
||||
starting_line=None,
|
||||
ending_line=None,
|
||||
)
|
||||
|
||||
helper_functions = get_function_variables_definitions(function_to_optimize, str(file_path.parent.resolve()))[0]
|
||||
|
||||
assert len(helper_functions) == 1
|
||||
assert helper_functions[0].jedi_definition.full_name == "test_function_dependencies.A.add_two"
|
||||
|
||||
def test_nested_method() -> None:
|
||||
file_path = pathlib.Path(__file__).resolve()
|
||||
function_to_optimize = FunctionToOptimize(
|
||||
function_name="nested_function",
|
||||
file_path=str(file_path),
|
||||
parents=[FunctionParent(name="A", type="ClassDef")],
|
||||
starting_line=None,
|
||||
ending_line=None,
|
||||
)
|
||||
|
||||
helper_functions = get_function_variables_definitions(function_to_optimize, str(file_path.parent.resolve()))[0]
|
||||
|
||||
# The nested function should be included in the helper functions
|
||||
assert len(helper_functions) == 1
|
||||
assert helper_functions[0].jedi_definition.full_name == "test_function_dependencies.A.add_two"
|
||||
|
|
|
|||
Loading…
Reference in a new issue