Rename the term "dependent_function" to "helper_function"

This commit is contained in:
Saurabh Misra 2024-05-18 22:00:23 -04:00
parent 328035464f
commit 131333ebc5
7 changed files with 122 additions and 124 deletions

View file

@ -156,7 +156,7 @@ class AiServiceClient:
self,
source_code_being_tested: str,
function_to_optimize: FunctionToOptimize,
dependent_function_names: list[str],
helper_function_names: list[str],
module_path: str,
test_module_path: str,
test_framework: str,
@ -169,7 +169,7 @@ class AiServiceClient:
----------
- source_code_being_tested (str): The source code of the function being tested.
- function_to_optimize (FunctionToOptimize): The function to optimize.
- dependent_function_names (list[Source]): List of dependent function names.
- helper_function_names (list[Source]): List of helper function names.
- module_path (str): The module path where the function is located.
- test_module_path (str): The module path for the test code.
- test_framework (str): The test framework to use, e.g., "pytest".
@ -187,7 +187,7 @@ class AiServiceClient:
payload = {
"source_code_being_tested": source_code_being_tested,
"function_to_optimize": function_to_optimize,
"dependent_function_names": dependent_function_names,
"helper_function_names": helper_function_names,
"module_path": module_path,
"test_module_path": test_module_path,
"test_framework": test_framework,

View file

@ -242,7 +242,7 @@ def get_function_variables_definitions(
MAX_PROMPT_TOKENS = 4096 # 128000 # gpt-4-128k
def get_constrained_function_context_and_dependent_functions(
def get_constrained_function_context_and_helper_functions(
function_to_optimize: FunctionToOptimize,
project_root_path: str,
code_to_optimize: str,
@ -250,7 +250,7 @@ def get_constrained_function_context_and_dependent_functions(
) -> tuple[str, list[tuple[Source, str, str]]]:
# TODO: Not just do static analysis, but also find the datatypes of function arguments by running the existing
# unittests and inspecting the arguments to resolve the real definitions and dependencies.
dependent_functions: list[tuple[Source, str, str]] = get_function_variables_definitions(
helper_functions: list[tuple[Source, str, str]] = get_function_variables_definitions(
function_to_optimize,
project_root_path,
)
@ -258,25 +258,25 @@ def get_constrained_function_context_and_dependent_functions(
code_to_optimize_tokens = tokenizer.encode(code_to_optimize)
if not function_to_optimize.parents:
dependent_functions_sources = [function[0].source_code for function in dependent_functions]
helper_functions_sources = [function[0].source_code for function in helper_functions]
else:
dependent_functions_sources = [
helper_functions_sources = [
function[0].source_code
for function in dependent_functions
for function in helper_functions
if not function[2].count(".") or function[2].split(".")[0] != function_to_optimize.parents[0].name
]
dependent_functions_tokens = [len(tokenizer.encode(function)) for function in dependent_functions_sources]
helper_functions_tokens = [len(tokenizer.encode(function)) for function in helper_functions_sources]
context_list = []
context_len = len(code_to_optimize_tokens)
logging.debug(f"ORIGINAL CODE TOKENS LENGTH: {context_len}")
logging.debug(f"ALL DEPENDENCIES TOKENS LENGTH: {sum(dependent_functions_tokens)}")
for function_source, source_len in zip(dependent_functions_sources, dependent_functions_tokens):
logging.debug(f"ALL DEPENDENCIES TOKENS LENGTH: {sum(helper_functions_tokens)}")
for function_source, source_len in zip(helper_functions_sources, helper_functions_tokens):
if context_len + source_len <= max_tokens:
context_list.append(function_source)
context_len += source_len
else:
break
logging.debug("FINAL OPTIMIZATION CONTEXT TOKENS LENGTH:", context_len)
dependent_code: str = "\n".join(context_list)
return dependent_code, dependent_functions
helper_code: str = "\n".join(context_list)
return helper_code, helper_functions

View file

@ -54,7 +54,7 @@ from codeflash.discovery.functions_to_optimize import (
from codeflash.models.ExperimentMetadata import ExperimentMetadata
from codeflash.optimization.function_context import (
Source,
get_constrained_function_context_and_dependent_functions,
get_constrained_function_context_and_helper_functions,
)
from codeflash.result.create_pr import check_create_pr, existing_tests_source_for
from codeflash.result.explanation import Explanation
@ -94,15 +94,15 @@ class GeneratedTests(BaseModel):
class BestOptimization(BaseModel):
candidate: OptimizedCandidate
dependent_functions: list[tuple[Source, str, str]]
helper_functions: list[tuple[Source, str, str]]
runtime: int
winning_test_results: TestResults
class CodeOptimizationContext(BaseModel):
code_to_optimize_with_dependents: str
code_to_optimize_with_helpers: str
contextual_dunder_methods: set[tuple[str, str]]
dependent_functions: list[tuple[Source, str, str]]
helper_functions: list[tuple[Source, str, str]]
preexisting_functions: list[str]
@ -232,21 +232,21 @@ class Optimizer:
if not is_successful(ctx_result):
return Failure(ctx_result.failure())
code_context: CodeOptimizationContext = ctx_result.unwrap()
dependent_functions_by_module_abspath = defaultdict(set)
for _, module_abspath, qualified_name in code_context.dependent_functions:
dependent_functions_by_module_abspath[module_abspath].add(qualified_name)
original_dependent_code = {}
for module_abspath in dependent_functions_by_module_abspath:
helper_functions_by_module_abspath = defaultdict(set)
for _, module_abspath, qualified_name in code_context.helper_functions:
helper_functions_by_module_abspath[module_abspath].add(qualified_name)
original_helper_code = {}
for module_abspath in helper_functions_by_module_abspath:
with pathlib.Path(module_abspath).open(encoding="utf8") as f:
dependent_code = f.read()
original_dependent_code[module_abspath] = dependent_code
logging.info(f"Code to be optimized:\n{code_context.code_to_optimize_with_dependents}")
helper_code = f.read()
original_helper_code[module_abspath] = helper_code
logging.info(f"Code to be optimized:\n{code_context.code_to_optimize_with_helpers}")
module_path = module_name_from_file_path(function_to_optimize.file_path, self.args.project_root)
for module_abspath in original_dependent_code:
code_context.code_to_optimize_with_dependents = add_needed_imports_from_module(
original_dependent_code[module_abspath],
code_context.code_to_optimize_with_dependents,
for module_abspath in original_helper_code:
code_context.code_to_optimize_with_helpers = add_needed_imports_from_module(
original_helper_code[module_abspath],
code_context.code_to_optimize_with_helpers,
module_abspath,
function_to_optimize.file_path,
self.args.project_root,
@ -259,9 +259,9 @@ class Optimizer:
self.instrumented_unittests_created.update(instrumented_unittests_created_for_function)
generated_results = self.generate_tests_and_optimizations(
code_context.code_to_optimize_with_dependents,
code_context.code_to_optimize_with_helpers,
function_to_optimize,
code_context.dependent_functions,
code_context.helper_functions,
module_path,
function_trace_id,
run_experiment=should_run_experiment,
@ -308,13 +308,13 @@ class Optimizer:
best_optimization = self.determine_best_candidate(
candidates,
code_context,
dependent_functions_by_module_abspath,
helper_functions_by_module_abspath,
function_to_optimize,
generated_tests_path,
instrumented_unittests_created_for_function,
original_code,
original_code_baseline,
original_dependent_code,
original_helper_code,
function_trace_id[:-4] + f"EXP{u}" if should_run_experiment else function_trace_id,
tests_in_file,
)
@ -341,16 +341,16 @@ class Optimizer:
generated_tests,
)
self.replace_function_and_dependents_with_optimized_code(
self.replace_function_and_helpers_with_optimized_code(
code_context,
dependent_functions_by_module_abspath,
helper_functions_by_module_abspath,
explanation,
best_optimization.candidate.source_code,
function_to_optimize.qualified_name,
)
new_code, new_dependent_code = self.reformat_code_and_dependents(
dependent_functions_by_module_abspath,
new_code, new_helper_code = self.reformat_code_and_helpers(
helper_functions_by_module_abspath,
explanation.file_path,
original_code,
)
@ -361,9 +361,9 @@ class Optimizer:
tests_root=self.test_cfg.tests_root,
)
original_code_combined = original_dependent_code.copy()
original_code_combined = original_helper_code.copy()
original_code_combined[explanation.file_path] = original_code
new_code_combined = new_dependent_code.copy()
new_code_combined = new_helper_code.copy()
new_code_combined[explanation.file_path] = new_code
if not self.args.no_pr:
check_create_pr(
@ -378,11 +378,11 @@ class Optimizer:
# a) Error propagation, where error in one function can cause the next optimization to fail
# b) Performance estimates become unstable, as the runtime of an optimization might be
# dependent on the runtime of the previous optimization
self.write_code_and_dependents(
self.write_code_and_helpers(
original_code,
original_dependent_code,
original_helper_code,
function_to_optimize.file_path,
dependent_functions_by_module_abspath,
helper_functions_by_module_abspath,
)
# Delete all the generated tests to not cause any clutter.
pathlib.Path(generated_tests_path).unlink(missing_ok=True)
@ -396,13 +396,13 @@ class Optimizer:
self,
candidates: list[OptimizedCandidate],
code_context: CodeOptimizationContext,
dependent_functions_by_module_abspath: dict[str, set[str]],
helper_functions_by_module_abspath: dict[str, set[str]],
function_to_optimize: FunctionToOptimize,
generated_tests_path: str,
instrumented_unittests_created_for_function: set[str],
original_code: str,
original_code_baseline: OriginalCodeBaseline,
original_dependent_code: dict[str, str],
original_helper_code: dict[str, str],
function_trace_id: str,
only_run_this_test_function: list[TestsInFile] | None = None,
) -> BestOptimization | None:
@ -442,7 +442,7 @@ class Optimizer:
for (
module_abspath,
qualified_names,
) in dependent_functions_by_module_abspath.items():
) in helper_functions_by_module_abspath.items():
replace_function_definitions_in_module(
function_names=list(qualified_names),
optimized_code=candidate.source_code,
@ -459,11 +459,11 @@ class Optimizer:
AttributeError,
) as e:
logging.error(e) # noqa: TRY400
self.write_code_and_dependents(
self.write_code_and_helpers(
original_code,
original_dependent_code,
original_helper_code,
function_to_optimize.file_path,
dependent_functions_by_module_abspath,
helper_functions_by_module_abspath,
)
continue
@ -519,16 +519,16 @@ class Optimizer:
)
best_optimization = BestOptimization(
candidate=candidate,
dependent_functions=code_context.dependent_functions,
helper_functions=code_context.helper_functions,
runtime=best_test_runtime,
winning_test_results=candidate_result.best_test_results,
)
best_runtime_until_now = best_test_runtime
self.write_code_and_dependents(
self.write_code_and_helpers(
original_code,
original_dependent_code,
original_helper_code,
function_to_optimize.file_path,
dependent_functions_by_module_abspath,
helper_functions_by_module_abspath,
)
logging.info("----------------")
self.aiservice_client.log_results(
@ -571,22 +571,22 @@ class Optimizer:
},
)
def write_code_and_dependents(
def write_code_and_helpers(
self,
original_code: str,
original_dependent_code: dict[str, str],
original_helper_code: dict[str, str],
path: str,
dependent_functions_by_module_abspath: dict[str, set[str]],
helper_functions_by_module_abspath: dict[str, set[str]],
) -> None:
with pathlib.Path(path).open("w", encoding="utf8") as f:
f.write(original_code)
for module_abspath in dependent_functions_by_module_abspath:
for module_abspath in helper_functions_by_module_abspath:
with pathlib.Path(module_abspath).open("w", encoding="utf8") as f:
f.write(original_dependent_code[module_abspath])
f.write(original_helper_code[module_abspath])
def reformat_code_and_dependents(
def reformat_code_and_helpers(
self,
dependent_functions_by_module_abspath: dict[str, set[str]],
helper_functions_by_module_abspath: dict[str, set[str]],
path: str,
original_code: str,
) -> tuple[str, dict[str, str]]:
@ -600,21 +600,21 @@ class Optimizer:
should_sort_imports,
path,
)
new_dependent_code: dict[str, str] = {
new_helper_code: dict[str, str] = {
module_abspath: format_code(
self.args.formatter_cmd,
self.args.imports_sort_cmd,
should_sort_imports,
module_abspath,
)
for module_abspath in dependent_functions_by_module_abspath
for module_abspath in helper_functions_by_module_abspath
}
return new_code, new_dependent_code
return new_code, new_helper_code
def replace_function_and_dependents_with_optimized_code(
def replace_function_and_helpers_with_optimized_code(
self,
code_context: CodeOptimizationContext,
dependent_functions_by_module_abspath: dict[str, set[str]],
helper_functions_by_module_abspath: dict[str, set[str]],
explanation: Explanation,
optimized_code: str,
qualified_function_name: str,
@ -631,7 +631,7 @@ class Optimizer:
for (
module_abspath,
qualified_names,
) in dependent_functions_by_module_abspath.items():
) in helper_functions_by_module_abspath.items():
replace_function_definitions_in_module(
function_names=list(qualified_names),
optimized_code=optimized_code,
@ -657,18 +657,18 @@ class Optimizer:
if not success:
return Failure("Error in parsing the code, skipping optimization.")
(
dependent_code,
dependent_functions,
) = get_constrained_function_context_and_dependent_functions(
helper_code,
helper_functions,
) = get_constrained_function_context_and_helper_functions(
function_to_optimize,
self.args.project_root,
code_to_optimize,
)
if function_to_optimize.parents:
function_class = function_to_optimize.parents[0].name
dependent_methods = [
helper_methods = [
df
for df in dependent_functions
for df in helper_functions
if df[2].count(".") > 0 and df[2].split(".")[0] == function_class
]
optimizable_methods = [function_to_optimize] + [
@ -679,7 +679,7 @@ class Optimizer:
None,
None,
)
for df in dependent_methods
for df in helper_methods
]
if len(optimizable_methods) > 1:
code_to_optimize, contextual_dunder_methods = extract_code(
@ -687,23 +687,23 @@ class Optimizer:
)
if code_to_optimize is None:
return Failure("Could not find function to optimize.")
code_to_optimize_with_dependents = dependent_code + "\n" + code_to_optimize
code_to_optimize_with_helpers = helper_code + "\n" + code_to_optimize
code_to_optimize_with_dependents_and_imports = add_needed_imports_from_module(
code_to_optimize_with_helpers_and_imports = add_needed_imports_from_module(
original_source_code,
code_to_optimize_with_dependents,
code_to_optimize_with_helpers,
function_to_optimize.file_path,
function_to_optimize.file_path,
project_root,
)
preexisting_functions.extend(
[fn[0].full_name.split(".")[-1] for fn in dependent_functions],
[fn[0].full_name.split(".")[-1] for fn in helper_functions],
)
return Success(
CodeOptimizationContext(
code_to_optimize_with_dependents=code_to_optimize_with_dependents_and_imports,
code_to_optimize_with_helpers=code_to_optimize_with_helpers_and_imports,
contextual_dunder_methods=contextual_dunder_methods,
dependent_functions=dependent_functions,
helper_functions=helper_functions,
preexisting_functions=preexisting_functions,
),
)
@ -758,9 +758,9 @@ class Optimizer:
def generate_tests_and_optimizations(
self,
code_to_optimize_with_dependents: str,
code_to_optimize_with_helpers: str,
function_to_optimize: FunctionToOptimize,
dependent_functions: list[tuple[Source, str, str]],
helper_functions: list[tuple[Source, str, str]],
module_path: str,
function_trace_id: str,
run_experiment: bool = False,
@ -771,15 +771,15 @@ class Optimizer:
with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
future_tests = executor.submit(
self.generate_and_instrument_tests,
code_to_optimize_with_dependents,
code_to_optimize_with_helpers,
function_to_optimize,
[definition[0].full_name for definition in dependent_functions],
[definition[0].full_name for definition in helper_functions],
module_path,
function_trace_id[:-4] + "EXP0" if run_experiment else function_trace_id,
)
future_optimization_candidates = executor.submit(
self.aiservice_client.optimize_python_code,
code_to_optimize_with_dependents,
code_to_optimize_with_helpers,
function_trace_id[:-4] + "EXP0" if run_experiment else function_trace_id,
N_CANDIDATES,
ExperimentMetadata(id=self.experiment_id, group="control") if run_experiment else None,
@ -787,7 +787,7 @@ class Optimizer:
if run_experiment:
future_candidates_exp = executor.submit(
self.local_aiservice_client.optimize_python_code,
code_to_optimize_with_dependents,
code_to_optimize_with_helpers,
function_trace_id[:-4] + "EXP1",
N_CANDIDATES,
ExperimentMetadata(id=self.experiment_id, group="experiment"),
@ -1181,7 +1181,7 @@ class Optimizer:
self,
source_code_being_tested: str,
function_to_optimize: FunctionToOptimize,
dependent_function_names: list[str],
helper_function_names: list[str],
module_path: str,
function_trace_id: str,
) -> Union[Tuple[str, str], None]:
@ -1189,7 +1189,7 @@ class Optimizer:
self.aiservice_client,
source_code_being_tested=source_code_being_tested,
function_to_optimize=function_to_optimize,
dependent_function_names=dependent_function_names,
helper_function_names=helper_function_names,
module_path=module_path,
test_cfg=self.test_cfg,
test_timeout=INDIVIDUAL_TEST_TIMEOUT,

View file

@ -17,7 +17,7 @@ def generate_tests(
aiservice_client: AiServiceClient,
source_code_being_tested: str,
function_to_optimize: FunctionToOptimize,
dependent_function_names: list[str],
helper_function_names: list[str],
module_path: str,
test_cfg: TestConfig,
test_timeout: int,
@ -47,7 +47,7 @@ def generate_tests(
response = aiservice_client.generate_regression_tests(
source_code_being_tested=source_code_being_tested,
function_to_optimize=function_to_optimize,
dependent_function_names=dependent_function_names,
helper_function_names=helper_function_names,
module_path=module_path,
test_module_path=test_module_path,
test_framework=test_cfg.test_framework,

View file

@ -307,7 +307,7 @@ def blob(st):
"""
original_code_main = """import libcst as cst
from typing import Mandatory
from dependent import blob
from helper import blob
print("Au revoir")
@ -320,7 +320,7 @@ def other_function(st):
print("Salut monde")
"""
original_code_dependent = """import numpy as np
original_code_helper = """import numpy as np
print("Cool")
@ -336,7 +336,7 @@ print("Not cool")
from typing import Optional
import libcst as cst
from typing import Mandatory
from dependent import blob
from helper import blob
print("Au revoir")
@ -349,7 +349,7 @@ def other_function(st):
print("Salut monde")
"""
expected_dependent = """import libcst as cst
expected_helper = """import libcst as cst
from typing import Optional
import numpy as np
@ -372,14 +372,14 @@ print("Not cool")
)
assert new_main_code == expected_main
new_dependent_code: str = replace_functions_in_file(
original_code_dependent,
new_helper_code: str = replace_functions_in_file(
original_code_helper,
["blob"],
optim_code,
[],
set(),
)
assert new_dependent_code == expected_dependent
assert new_helper_code == expected_helper
def test_test_libcst_code_replacement7() -> None:

View file

@ -19,12 +19,12 @@ def simple_function_with_one_dep(data):
def test_simple_dependencies():
file_path = pathlib.Path(__file__).resolve()
dependent_functions = get_function_variables_definitions(
helper_functions = get_function_variables_definitions(
FunctionToOptimize("simple_function_with_one_dep", str(file_path), []),
str(file_path.parent.resolve()),
)
assert len(dependent_functions) == 1
assert dependent_functions[0][0].definition.full_name == "test_function_dependencies.calculate_something"
assert len(helper_functions) == 1
assert helper_functions[0][0].definition.full_name == "test_function_dependencies.calculate_something"
def global_dependency_1(num):
@ -75,13 +75,13 @@ class C:
def test_multiple_classes_dependencies():
# TODO: Check if C.run only gets calculate_something_3 as dependency and likewise for other classes
file_path = pathlib.Path(__file__).resolve()
dependent_functions = get_function_variables_definitions(
helper_functions = get_function_variables_definitions(
FunctionToOptimize("run", str(file_path), [FunctionParent("C", "ClassDef")]),
str(file_path.parent.resolve()),
)
# assert len(dependent_functions) == 2
assert list(map(lambda x: x[0].full_name, dependent_functions)) == [
# assert len(helper_functions) == 2
assert list(map(lambda x: x[0].full_name, helper_functions)) == [
"test_function_dependencies.C.calculate_something_3",
"test_function_dependencies.global_dependency_3",
]
@ -96,12 +96,12 @@ def recursive_dependency_1(num):
def test_recursive_dependency():
file_path = pathlib.Path(__file__).resolve()
dependent_functions = get_function_variables_definitions(
helper_functions = get_function_variables_definitions(
FunctionToOptimize("recursive_dependency_1", str(file_path), []),
str(file_path.parent.resolve()),
)
assert len(dependent_functions) == 1
assert dependent_functions[0][0].definition.full_name == "test_function_dependencies.calculate_something"
assert len(helper_functions) == 1
assert helper_functions[0][0].definition.full_name == "test_function_dependencies.calculate_something"
@dataclass
@ -119,15 +119,13 @@ def simple_function_with_one_dep_ann(data: MyData):
def test_simple_dependencies_ann():
file_path = pathlib.Path(__file__).resolve()
dependent_functions = get_function_variables_definitions(
helper_functions = get_function_variables_definitions(
FunctionToOptimize("simple_function_with_one_dep_ann", str(file_path), []),
str(file_path.parent.resolve()),
)
assert len(dependent_functions) == 2
assert dependent_functions[0][0].definition.full_name == "test_function_dependencies.MyData"
assert (
dependent_functions[1][0].definition.full_name == "test_function_dependencies.calculate_something_ann"
)
assert len(helper_functions) == 2
assert helper_functions[0][0].definition.full_name == "test_function_dependencies.MyData"
assert helper_functions[1][0].definition.full_name == "test_function_dependencies.calculate_something_ann"
from collections import defaultdict

View file

@ -7,7 +7,7 @@ from typing import List
from codeflash.code_utils.code_extractor import get_code
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.optimization.function_context import (
get_constrained_function_context_and_dependent_functions,
get_constrained_function_context_and_helper_functions,
)
@ -43,7 +43,7 @@ def function_to_optimize3(data: dict[CustomDataClass, list[CustomDataClass]]) ->
def test_function_context_includes_type_annotation() -> None:
file_path = pathlib.Path(__file__).resolve()
a, dependent_functions = get_constrained_function_context_and_dependent_functions(
a, helper_functions = get_constrained_function_context_and_helper_functions(
FunctionToOptimize("function_to_optimize", str(file_path), []),
str(file_path.parent.resolve()),
"""def function_to_optimize(data: CustomType):
@ -53,13 +53,13 @@ def test_function_context_includes_type_annotation() -> None:
1000,
)
assert len(dependent_functions) == 1
assert dependent_functions[0][0].full_name == "CustomType"
assert len(helper_functions) == 1
assert helper_functions[0][0].full_name == "CustomType"
def test_function_context_includes_type_annotation_dataclass() -> None:
file_path = pathlib.Path(__file__).resolve()
a, dependent_functions = get_constrained_function_context_and_dependent_functions(
a, helper_functions = get_constrained_function_context_and_helper_functions(
FunctionToOptimize("function_to_optimize2", str(file_path), []),
str(file_path.parent.resolve()),
"""def function_to_optimize2(data: CustomDataClass) -> CustomType:
@ -69,14 +69,14 @@ def test_function_context_includes_type_annotation_dataclass() -> None:
1000,
)
assert len(dependent_functions) == 2
assert dependent_functions[0][0].full_name == "CustomDataClass"
assert dependent_functions[1][0].full_name == "CustomType"
assert len(helper_functions) == 2
assert helper_functions[0][0].full_name == "CustomDataClass"
assert helper_functions[1][0].full_name == "CustomType"
def test_function_context_works_for_composite_types() -> None:
file_path = pathlib.Path(__file__).resolve()
a, dependent_functions = get_constrained_function_context_and_dependent_functions(
a, helper_functions = get_constrained_function_context_and_helper_functions(
FunctionToOptimize("function_to_optimize3", str(file_path), []),
str(file_path.parent.resolve()),
"""def function_to_optimize3(data: set[CustomDataClass[CustomDataClass, int]]) -> list[CustomType]:
@ -86,9 +86,9 @@ def test_function_context_works_for_composite_types() -> None:
1000,
)
assert len(dependent_functions) == 2
assert dependent_functions[0][0].full_name == "CustomDataClass"
assert dependent_functions[1][0].full_name == "CustomType"
assert len(helper_functions) == 2
assert helper_functions[0][0].full_name == "CustomDataClass"
assert helper_functions[1][0].full_name == "CustomType"
def test_function_context_custom_datatype() -> None:
@ -99,12 +99,12 @@ def test_function_context_custom_datatype() -> None:
)
assert code is not None
assert contextual_dunder_methods == set()
a, dependent_functions = get_constrained_function_context_and_dependent_functions(
a, helper_functions = get_constrained_function_context_and_helper_functions(
FunctionToOptimize("cosine_similarity", str(file_path), []),
str(project_path),
code,
1000,
)
assert len(dependent_functions) == 1
assert dependent_functions[0][0].full_name == "Matrix"
assert len(helper_functions) == 1
assert helper_functions[0][0].full_name == "Matrix"