Implemented testgen context retrieval. Context retrieved is the union of read-writable code and read-only code. Did some refactors to remove code_to_optimize_with_helpers, and updated tests.
This commit is contained in:
parent
680d0da5eb
commit
17a42a218c
18 changed files with 1571 additions and 947 deletions
|
|
@ -1,6 +1,3 @@
|
|||
def sorter(arr):
|
||||
arr.sort()
|
||||
return arr
|
||||
|
||||
|
||||
CACHED_TESTS = "import unittest\ndef sorter(arr):\n for i in range(len(arr)):\n for j in range(len(arr) - 1):\n if arr[j] > arr[j + 1]:\n temp = arr[j]\n arr[j] = arr[j + 1]\n arr[j + 1] = temp\n return arr\nclass SorterTestCase(unittest.TestCase):\n def test_empty_list(self):\n self.assertEqual(sorter([]), [])\n def test_single_element_list(self):\n self.assertEqual(sorter([5]), [5])\n def test_ascending_order_list(self):\n self.assertEqual(sorter([1, 2, 3, 4, 5]), [1, 2, 3, 4, 5])\n def test_descending_order_list(self):\n self.assertEqual(sorter([5, 4, 3, 2, 1]), [1, 2, 3, 4, 5])\n def test_random_order_list(self):\n self.assertEqual(sorter([3, 1, 4, 2, 5]), [1, 2, 3, 4, 5])\n def test_duplicate_elements_list(self):\n self.assertEqual(sorter([3, 1, 4, 2, 2, 5, 1]), [1, 1, 2, 2, 3, 4, 5])\n def test_negative_numbers_list(self):\n self.assertEqual(sorter([-5, -2, -8, -1, -3]), [-8, -5, -3, -2, -1])\n def test_mixed_data_types_list(self):\n self.assertEqual(sorter(['apple', 2, 'banana', 1, 'cherry']), [1, 2, 'apple', 'banana', 'cherry'])\n def test_large_input_list(self):\n self.assertEqual(sorter(list(range(1000, 0, -1))), list(range(1, 1001)))\n def test_list_with_none_values(self):\n self.assertEqual(sorter([None, 2, None, 1, None]), [None, None, None, 1, 2])\n def test_list_with_nan_values(self):\n self.assertEqual(sorter([float('nan'), 2, float('nan'), 1, float('nan')]), [1, 2, float('nan'), float('nan'), float('nan')])\n def test_list_with_complex_numbers(self):\n self.assertEqual(sorter([3 + 2j, 1 + 1j, 4 + 3j, 2 + 1j, 5 + 4j]), [1 + 1j, 2 + 1j, 3 + 2j, 4 + 3j, 5 + 4j])\n def test_list_with_custom_class_objects(self):\n class Person:\n def __init__(self, name, age):\n self.name = name\n self.age = age\n def __repr__(self):\n return f\"Person('{self.name}', {self.age})\"\n input_list = [Person('Alice', 25), Person('Bob', 30), Person('Charlie', 20)]\n expected_output = [Person('Charlie', 20), Person('Alice', 25), Person('Bob', 30)]\n 
self.assertEqual(sorter(input_list), expected_output)\n def test_list_with_uncomparable_elements(self):\n with self.assertRaises(TypeError):\n sorter([5, 'apple', 3, [1, 2, 3], 2])\n def test_list_with_custom_comparison_function(self):\n input_list = [5, 4, 3, 2, 1]\n expected_output = [5, 4, 3, 2, 1]\n self.assertEqual(sorter(input_list, reverse=True), expected_output)\nif __name__ == '__main__':\n unittest.main()"
|
||||
return arr
|
||||
|
|
@ -9,142 +9,3 @@ def sorter_deps(arr):
|
|||
dep2_swap(arr, j)
|
||||
return arr
|
||||
|
||||
|
||||
CACHED_TESTS = """import dill as pickle
|
||||
import os
|
||||
def _log__test__values(values, duration, test_name):
|
||||
iteration = os.environ["CODEFLASH_TEST_ITERATION"]
|
||||
with open(os.path.join(
|
||||
'/var/folders/ms/1tz2l1q55w5b7pp4wpdkbjq80000gn/T/codeflash_jk4pzz3w/',
|
||||
f'test_return_values_{iteration}.bin'), 'ab') as f:
|
||||
return_bytes = pickle.dumps(values)
|
||||
_test_name = f"{test_name}".encode("ascii")
|
||||
f.write(len(_test_name).to_bytes(4, byteorder='big'))
|
||||
f.write(_test_name)
|
||||
f.write(duration.to_bytes(8, byteorder='big'))
|
||||
f.write(len(return_bytes).to_bytes(4, byteorder='big'))
|
||||
f.write(return_bytes)
|
||||
import time
|
||||
import gc
|
||||
from code_to_optimize.bubble_sort_deps import sorter_deps
|
||||
import timeout_decorator
|
||||
import unittest
|
||||
|
||||
def dep1_comparer(arr, j: int) -> bool:
|
||||
return arr[j] > arr[j + 1]
|
||||
|
||||
def dep2_swap(arr, j):
|
||||
temp = arr[j]
|
||||
arr[j] = arr[j + 1]
|
||||
arr[j + 1] = temp
|
||||
|
||||
class TestSorterDeps(unittest.TestCase):
|
||||
|
||||
@timeout_decorator.timeout(15, use_signals=True)
|
||||
def test_integers(self):
|
||||
gc.disable()
|
||||
counter = time.perf_counter_ns()
|
||||
return_value = sorter_deps([5, 3, 2, 4, 1])
|
||||
duration = time.perf_counter_ns() - counter
|
||||
gc.enable()
|
||||
_log__test__values(
|
||||
return_value, duration,
|
||||
'code_to_optimize.tests.unittest.test_sorter_deps__unit_test_0:TestSorterDeps.test_integers:sorter_deps:0')
|
||||
gc.disable()
|
||||
counter = time.perf_counter_ns()
|
||||
return_value = sorter_deps([10, -3, 0, 2, 7])
|
||||
duration = time.perf_counter_ns() - counter
|
||||
gc.enable()
|
||||
_log__test__values(
|
||||
return_value, duration,
|
||||
('code_to_optimize.tests.unittest.test_sorter_deps__unit_test_0:'
|
||||
'TestSorterDeps.test_integers:sorter_deps:1'))
|
||||
|
||||
@timeout_decorator.timeout(15, use_signals=True)
|
||||
def test_floats(self):
|
||||
gc.disable()
|
||||
counter = time.perf_counter_ns()
|
||||
return_value = sorter_deps([3.2, 1.5, 2.7, 4.1, 1.0])
|
||||
duration = time.perf_counter_ns() - counter
|
||||
gc.enable()
|
||||
_log__test__values(return_value, duration,
|
||||
'code_to_optimize.tests.unittest.test_sorter_deps__unit_test_0:TestSorterDeps.test_floats:sorter_deps:0')
|
||||
gc.disable()
|
||||
counter = time.perf_counter_ns()
|
||||
return_value = sorter_deps([-1.1, 0.0, 3.14, 2.71, -0.5])
|
||||
duration = time.perf_counter_ns() - counter
|
||||
gc.enable()
|
||||
_log__test__values(return_value, duration,
|
||||
'code_to_optimize.tests.unittest.test_sorter_deps__unit_test_0:TestSorterDeps.test_floats:sorter_deps:1')
|
||||
|
||||
@timeout_decorator.timeout(15, use_signals=True)
|
||||
def test_identical_elements(self):
|
||||
gc.disable()
|
||||
counter = time.perf_counter_ns()
|
||||
return_value = sorter_deps([1, 1, 1, 1, 1])
|
||||
duration = time.perf_counter_ns() - counter
|
||||
gc.enable()
|
||||
_log__test__values(return_value, duration,
|
||||
('code_to_optimize.tests.unittest.test_sorter_deps__unit_test_0:'
|
||||
'TestSorterDeps.test_identical_elements:sorter_deps:0'))
|
||||
gc.disable()
|
||||
counter = time.perf_counter_ns()
|
||||
return_value = sorter_deps([3.14, 3.14, 3.14])
|
||||
duration = time.perf_counter_ns() - counter
|
||||
gc.enable()
|
||||
_log__test__values(return_value, duration,
|
||||
('code_to_optimize.tests.unittest.test_sorter_deps__unit_test_0:'
|
||||
'TestSorterDeps.test_identical_elements:sorter_deps:1'))
|
||||
|
||||
@timeout_decorator.timeout(15, use_signals=True)
|
||||
def test_single_element(self):
|
||||
gc.disable()
|
||||
counter = time.perf_counter_ns()
|
||||
return_value = sorter_deps([5])
|
||||
duration = time.perf_counter_ns() - counter
|
||||
gc.enable()
|
||||
_log__test__values(return_value, duration, 'code_to_optimize.tests.unittest.test_sorter_deps__unit_test_0:TestSorterDeps.test_single_element:sorter_deps:0')
|
||||
gc.disable()
|
||||
counter = time.perf_counter_ns()
|
||||
return_value = sorter_deps([-3.2])
|
||||
duration = time.perf_counter_ns() - counter
|
||||
gc.enable()
|
||||
_log__test__values(return_value, duration, 'code_to_optimize.tests.unittest.test_sorter_deps__unit_test_0:TestSorterDeps.test_single_element:sorter_deps:1')
|
||||
|
||||
@timeout_decorator.timeout(15, use_signals=True)
|
||||
def test_empty_array(self):
|
||||
gc.disable()
|
||||
counter = time.perf_counter_ns()
|
||||
return_value = sorter_deps([])
|
||||
duration = time.perf_counter_ns() - counter
|
||||
gc.enable()
|
||||
_log__test__values(return_value, duration, 'code_to_optimize.tests.unittest.test_sorter_deps__unit_test_0:TestSorterDeps.test_empty_array:sorter_deps:0')
|
||||
|
||||
@timeout_decorator.timeout(15, use_signals=True)
|
||||
def test_strings(self):
|
||||
gc.disable()
|
||||
counter = time.perf_counter_ns()
|
||||
return_value = sorter_deps(['apple', 'banana', 'cherry', 'date'])
|
||||
duration = time.perf_counter_ns() - counter
|
||||
gc.enable()
|
||||
_log__test__values(return_value, duration, 'code_to_optimize.tests.unittest.test_sorter_deps__unit_test_0:TestSorterDeps.test_strings:sorter_deps:0')
|
||||
gc.disable()
|
||||
counter = time.perf_counter_ns()
|
||||
return_value = sorter_deps(['dog', 'cat', 'elephant', 'ant'])
|
||||
duration = time.perf_counter_ns() - counter
|
||||
gc.enable()
|
||||
_log__test__values(return_value, duration, 'code_to_optimize.tests.unittest.test_sorter_deps__unit_test_0:TestSorterDeps.test_strings:sorter_deps:1')
|
||||
|
||||
@timeout_decorator.timeout(15, use_signals=True)
|
||||
def test_mixed_types(self):
|
||||
with self.assertRaises(TypeError):
|
||||
gc.disable()
|
||||
counter = time.perf_counter_ns()
|
||||
return_value = sorter_deps([1, 'two', 3.0, 'four'])
|
||||
duration = time.perf_counter_ns() - counter
|
||||
gc.enable()
|
||||
_log__test__values(return_value, duration, 'code_to_optimize.tests.unittest.test_sorter_deps__unit_test_0:TestSorterDeps.test_mixed_types:sorter_deps:0_0')
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -5,7 +5,4 @@ def sorter(arr):
|
|||
temp = arr[j]
|
||||
arr[j] = arr[j + 1]
|
||||
arr[j + 1] = temp
|
||||
return arr
|
||||
|
||||
|
||||
CACHED_TESTS = "import unittest\ndef sorter(arr):\n for i in range(len(arr)):\n for j in range(len(arr) - 1):\n if arr[j] > arr[j + 1]:\n temp = arr[j]\n arr[j] = arr[j + 1]\n arr[j + 1] = temp\n return arr\nclass SorterTestCase(unittest.TestCase):\n def test_empty_list(self):\n self.assertEqual(sorter([]), [])\n def test_single_element_list(self):\n self.assertEqual(sorter([5]), [5])\n def test_ascending_order_list(self):\n self.assertEqual(sorter([1, 2, 3, 4, 5]), [1, 2, 3, 4, 5])\n def test_descending_order_list(self):\n self.assertEqual(sorter([5, 4, 3, 2, 1]), [1, 2, 3, 4, 5])\n def test_random_order_list(self):\n self.assertEqual(sorter([3, 1, 4, 2, 5]), [1, 2, 3, 4, 5])\n def test_duplicate_elements_list(self):\n self.assertEqual(sorter([3, 1, 4, 2, 2, 5, 1]), [1, 1, 2, 2, 3, 4, 5])\n def test_negative_numbers_list(self):\n self.assertEqual(sorter([-5, -2, -8, -1, -3]), [-8, -5, -3, -2, -1])\n def test_mixed_data_types_list(self):\n self.assertEqual(sorter(['apple', 2, 'banana', 1, 'cherry']), [1, 2, 'apple', 'banana', 'cherry'])\n def test_large_input_list(self):\n self.assertEqual(sorter(list(range(1000, 0, -1))), list(range(1, 1001)))\n def test_list_with_none_values(self):\n self.assertEqual(sorter([None, 2, None, 1, None]), [None, None, None, 1, 2])\n def test_list_with_nan_values(self):\n self.assertEqual(sorter([float('nan'), 2, float('nan'), 1, float('nan')]), [1, 2, float('nan'), float('nan'), float('nan')])\n def test_list_with_complex_numbers(self):\n self.assertEqual(sorter([3 + 2j, 1 + 1j, 4 + 3j, 2 + 1j, 5 + 4j]), [1 + 1j, 2 + 1j, 3 + 2j, 4 + 3j, 5 + 4j])\n def test_list_with_custom_class_objects(self):\n class Person:\n def __init__(self, name, age):\n self.name = name\n self.age = age\n def __repr__(self):\n return f\"Person('{self.name}', {self.age})\"\n input_list = [Person('Alice', 25), Person('Bob', 30), Person('Charlie', 20)]\n expected_output = [Person('Charlie', 20), Person('Alice', 25), Person('Bob', 30)]\n 
self.assertEqual(sorter(input_list), expected_output)\n def test_list_with_uncomparable_elements(self):\n with self.assertRaises(TypeError):\n sorter([5, 'apple', 3, [1, 2, 3], 2])\n def test_list_with_custom_comparison_function(self):\n input_list = [5, 4, 3, 2, 1]\n expected_output = [5, 4, 3, 2, 1]\n self.assertEqual(sorter(input_list, reverse=True), expected_output)\nif __name__ == '__main__':\n unittest.main()"
|
||||
return arr
|
||||
|
|
@ -9,110 +9,4 @@ def use_cosine_similarity(
|
|||
top_k: Optional[int] = 5,
|
||||
score_threshold: Optional[float] = None,
|
||||
) -> Tuple[List[Tuple[int, int]], List[float]]:
|
||||
return cosine_similarity_top_k(X, Y, top_k, score_threshold)
|
||||
|
||||
|
||||
CACHED_TESTS = """import unittest
|
||||
import numpy as np
|
||||
from sklearn.metrics.pairwise import cosine_similarity
|
||||
from typing import List, Optional, Tuple, Union
|
||||
Matrix = Union[List[List[float]], List[np.ndarray], np.ndarray]
|
||||
def cosine_similarity_top_k(X: Matrix, Y: Matrix, top_k: Optional[int]=5, score_threshold: Optional[float]=None) -> Tuple[List[Tuple[int, int]], List[float]]:
|
||||
\"\"\"Row-wise cosine similarity with optional top-k and score threshold filtering.
|
||||
Args:
|
||||
X: Matrix.
|
||||
Y: Matrix, same width as X.
|
||||
top_k: Max number of results to return.
|
||||
score_threshold: Minimum cosine similarity of results.
|
||||
Returns:
|
||||
Tuple of two lists. First contains two-tuples of indices (X_idx, Y_idx),
|
||||
second contains corresponding cosine similarities.
|
||||
\"\"\"
|
||||
if len(X) == 0 or len(Y) == 0:
|
||||
return ([], [])
|
||||
score_array = cosine_similarity(X, Y)
|
||||
sorted_idxs = score_array.flatten().argsort()[::-1]
|
||||
top_k = top_k or len(sorted_idxs)
|
||||
top_idxs = sorted_idxs[:top_k]
|
||||
score_threshold = score_threshold or -1.0
|
||||
top_idxs = top_idxs[score_array.flatten()[top_idxs] > score_threshold]
|
||||
ret_idxs = [(x // score_array.shape[1], x % score_array.shape[1]) for x in top_idxs]
|
||||
scores = score_array.flatten()[top_idxs].tolist()
|
||||
return (ret_idxs, scores)
|
||||
def use_cosine_similarity(X: Matrix, Y: Matrix, top_k: Optional[int]=5, score_threshold: Optional[float]=None) -> Tuple[List[Tuple[int, int]], List[float]]:
|
||||
return cosine_similarity_top_k(X, Y, top_k, score_threshold)
|
||||
class TestUseCosineSimilarity(unittest.TestCase):
|
||||
def test_normal_scenario(self):
|
||||
X = [[1, 2, 3], [4, 5, 6]]
|
||||
Y = [[7, 8, 9], [10, 11, 12]]
|
||||
result = use_cosine_similarity(X, Y, top_k=1, score_threshold=0.5)
|
||||
self.assertEqual(result, ([(0, 1)], [0.9746318461970762]))
|
||||
def test_edge_case_empty_matrices(self):
|
||||
X = []
|
||||
Y = []
|
||||
result = use_cosine_similarity(X, Y)
|
||||
self.assertEqual(result, ([], []))
|
||||
def test_edge_case_different_widths(self):
|
||||
X = [[1, 2, 3]]
|
||||
Y = [[4, 5]]
|
||||
with self.assertRaises(ValueError):
|
||||
use_cosine_similarity(X, Y)
|
||||
def test_edge_case_negative_top_k(self):
|
||||
X = [[1, 2, 3]]
|
||||
Y = [[4, 5, 6]]
|
||||
with self.assertRaises(IndexError):
|
||||
use_cosine_similarity(X, Y, top_k=-1)
|
||||
def test_edge_case_zero_top_k(self):
|
||||
X = [[1, 2, 3]]
|
||||
Y = [[4, 5, 6]]
|
||||
result = use_cosine_similarity(X, Y, top_k=0)
|
||||
self.assertEqual(result, ([], []))
|
||||
def test_edge_case_negative_score_threshold(self):
|
||||
X = [[1, 2, 3]]
|
||||
Y = [[4, 5, 6]]
|
||||
result = use_cosine_similarity(X, Y, score_threshold=-1.0)
|
||||
self.assertEqual(result, ([(0, 0)], [0.9746318461970762]))
|
||||
def test_edge_case_large_score_threshold(self):
|
||||
X = [[1, 2, 3]]
|
||||
Y = [[4, 5, 6]]
|
||||
result = use_cosine_similarity(X, Y, score_threshold=2.0)
|
||||
self.assertEqual(result, ([], []))
|
||||
def test_exceptional_case_non_matrix_X(self):
|
||||
X = [1, 2, 3]
|
||||
Y = [[4, 5, 6]]
|
||||
with self.assertRaises(ValueError):
|
||||
use_cosine_similarity(X, Y)
|
||||
def test_exceptional_case_non_integer_top_k(self):
|
||||
X = [[1, 2, 3]]
|
||||
Y = [[4, 5, 6]]
|
||||
with self.assertRaises(TypeError):
|
||||
use_cosine_similarity(X, Y, top_k='5')
|
||||
def test_exceptional_case_non_float_score_threshold(self):
|
||||
X = [[1, 2, 3]]
|
||||
Y = [[4, 5, 6]]
|
||||
with self.assertRaises(TypeError):
|
||||
use_cosine_similarity(X, Y, score_threshold='0.5')
|
||||
def test_special_values_nan_in_matrices(self):
|
||||
X = [[1, 2, np.nan]]
|
||||
Y = [[4, 5, 6]]
|
||||
with self.assertRaises(ValueError):
|
||||
use_cosine_similarity(X, Y)
|
||||
def test_special_values_none_top_k(self):
|
||||
X = [[1, 2, 3]]
|
||||
Y = [[4, 5, 6]]
|
||||
result = use_cosine_similarity(X, Y, top_k=None)
|
||||
self.assertEqual(result, ([(0, 0)], [0.9746318461970762]))
|
||||
def test_special_values_none_score_threshold(self):
|
||||
X = [[1, 2, 3]]
|
||||
Y = [[4, 5, 6]]
|
||||
result = use_cosine_similarity(X, Y, score_threshold=None)
|
||||
self.assertEqual(result, ([(0, 0)], [0.9746318461970762]))
|
||||
def test_large_inputs(self):
|
||||
X = np.random.rand(1000, 1000)
|
||||
Y = np.random.rand(1000, 1000)
|
||||
result = use_cosine_similarity(X, Y, top_k=10, score_threshold=0.5)
|
||||
self.assertEqual(len(result[0]), 10)
|
||||
self.assertEqual(len(result[1]), 10)
|
||||
self.assertTrue(all((score > 0.5 for score in result[1])))
|
||||
if __name__ == '__main__':
|
||||
unittest.main()"""
|
||||
return cosine_similarity_top_k(X, Y, top_k, score_threshold)
|
||||
|
|
@ -12,7 +12,7 @@ if TYPE_CHECKING:
|
|||
|
||||
def extract_dependent_function(main_function: str, code_context: CodeOptimizationContext) -> str | Literal[False]:
|
||||
"""Extract the single dependent function from the code context excluding the main function."""
|
||||
ast_tree = ast.parse(code_context.code_to_optimize_with_helpers)
|
||||
ast_tree = ast.parse(code_context.testgen_context_code)
|
||||
|
||||
dependent_functions = {node.name for node in ast_tree.body if isinstance(node, ast.FunctionDef)}
|
||||
|
||||
|
|
|
|||
|
|
@ -15,12 +15,13 @@ from codeflash.cli_cmds.console import logger
|
|||
from codeflash.code_utils.code_extractor import add_needed_imports_from_module, find_preexisting_objects
|
||||
from codeflash.code_utils.code_utils import get_qualified_name, path_belongs_to_site_packages
|
||||
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
|
||||
from codeflash.models.models import CodeOptimizationContext, CodeString, CodeStringsMarkdown, FunctionSource
|
||||
from codeflash.models.models import CodeOptimizationContext, CodeString, CodeStringsMarkdown, FunctionSource, \
|
||||
CodeContextType
|
||||
from codeflash.optimization.function_context import belongs_to_function_qualified
|
||||
|
||||
|
||||
def get_code_optimization_context(
|
||||
function_to_optimize: FunctionToOptimize, project_root_path: Path, token_limit: int = 8000
|
||||
function_to_optimize: FunctionToOptimize, project_root_path: Path, optim_token_limit: int = 8000, testgen_token_limit: int = 8000
|
||||
) -> CodeOptimizationContext:
|
||||
# Get qualified names and fully qualified names(fqn) of helpers
|
||||
helpers_of_fto, helpers_of_fto_fqn, helpers_of_fto_obj_list = get_file_path_to_helper_functions_dict(
|
||||
|
|
@ -37,21 +38,22 @@ def get_code_optimization_context(
|
|||
function_to_optimize.qualified_name_with_modules_from_root(project_root_path)
|
||||
)
|
||||
|
||||
# Extract code
|
||||
final_read_writable_code = get_all_read_writable_code(helpers_of_fto, helpers_of_fto_fqn, project_root_path).code
|
||||
read_only_code_markdown = get_all_read_only_code_context(
|
||||
# Extract code context for optimization
|
||||
final_read_writable_code = extract_code_string_context_from_files(helpers_of_fto, helpers_of_fto_fqn, project_root_path).code
|
||||
read_only_code_markdown = extract_code_markdown_context_from_files(
|
||||
helpers_of_fto,
|
||||
helpers_of_fto_fqn,
|
||||
helpers_of_helpers,
|
||||
helpers_of_helpers_fqn,
|
||||
project_root_path,
|
||||
remove_docstrings=False,
|
||||
code_context_type=CodeContextType.READ_ONLY,
|
||||
)
|
||||
|
||||
# Handle token limits
|
||||
tokenizer = tiktoken.encoding_for_model("gpt-4o")
|
||||
final_read_writable_tokens = len(tokenizer.encode(final_read_writable_code))
|
||||
if final_read_writable_tokens > token_limit:
|
||||
if final_read_writable_tokens > optim_token_limit:
|
||||
raise ValueError("Read-writable code has exceeded token limit, cannot proceed")
|
||||
|
||||
# Setup preexisting objects for code replacer TODO: should remove duplicates
|
||||
|
|
@ -61,53 +63,82 @@ def get_code_optimization_context(
|
|||
*(find_preexisting_objects(codestring.code) for codestring in read_only_code_markdown.code_strings),
|
||||
)
|
||||
)
|
||||
read_only_code_markdown_tokens = len(tokenizer.encode(read_only_code_markdown.markdown))
|
||||
read_only_context_code = read_only_code_markdown.markdown
|
||||
|
||||
read_only_code_markdown_tokens = len(tokenizer.encode(read_only_context_code))
|
||||
total_tokens = final_read_writable_tokens + read_only_code_markdown_tokens
|
||||
if total_tokens <= token_limit:
|
||||
return CodeOptimizationContext(
|
||||
code_to_optimize_with_helpers="",
|
||||
read_writable_code=CodeString(code=final_read_writable_code).code,
|
||||
read_only_context_code=read_only_code_markdown.markdown,
|
||||
helper_functions=helpers_of_fto_obj_list,
|
||||
preexisting_objects=preexisting_objects,
|
||||
|
||||
if total_tokens > optim_token_limit:
|
||||
logger.debug("Code context has exceeded token limit, removing docstrings from read-only code")
|
||||
# Extract read only code without docstrings
|
||||
read_only_code_no_docstring_markdown = extract_code_markdown_context_from_files(
|
||||
helpers_of_fto,
|
||||
helpers_of_fto_fqn,
|
||||
helpers_of_helpers,
|
||||
helpers_of_helpers_fqn,
|
||||
project_root_path,
|
||||
remove_docstrings=True,
|
||||
)
|
||||
|
||||
logger.debug("Code context has exceeded token limit, removing docstrings from read-only code")
|
||||
|
||||
# Extract read only code without docstrings
|
||||
read_only_code_no_docstring_markdown = get_all_read_only_code_context(
|
||||
read_only_context_code = read_only_code_no_docstring_markdown.markdown
|
||||
read_only_code_no_docstring_markdown_tokens = len(tokenizer.encode(read_only_context_code))
|
||||
total_tokens = final_read_writable_tokens + read_only_code_no_docstring_markdown_tokens
|
||||
if total_tokens > optim_token_limit:
|
||||
logger.debug("Code context has exceeded token limit, removing read-only code")
|
||||
read_only_context_code = ""
|
||||
# Extract code context for testgen
|
||||
testgen_code_markdown = extract_code_markdown_context_from_files(
|
||||
helpers_of_fto,
|
||||
helpers_of_fto_fqn,
|
||||
helpers_of_helpers,
|
||||
helpers_of_helpers_fqn,
|
||||
project_root_path,
|
||||
remove_docstrings=True,
|
||||
remove_docstrings=False,
|
||||
code_context_type=CodeContextType.TESTGEN,
|
||||
)
|
||||
read_only_code_no_docstring_markdown_tokens = len(tokenizer.encode(read_only_code_no_docstring_markdown.markdown))
|
||||
total_tokens = final_read_writable_tokens + read_only_code_no_docstring_markdown_tokens
|
||||
if total_tokens <= token_limit:
|
||||
return CodeOptimizationContext(
|
||||
code_to_optimize_with_helpers="",
|
||||
read_writable_code=CodeString(code=final_read_writable_code).code,
|
||||
read_only_context_code=read_only_code_no_docstring_markdown.markdown,
|
||||
helper_functions=helpers_of_fto_obj_list,
|
||||
preexisting_objects=preexisting_objects,
|
||||
testgen_context_code = testgen_code_markdown.markdown
|
||||
testgen_context_code_tokens = len(tokenizer.encode(testgen_context_code))
|
||||
if testgen_context_code_tokens > testgen_token_limit:
|
||||
testgen_code_markdown = extract_code_markdown_context_from_files(
|
||||
helpers_of_fto,
|
||||
helpers_of_fto_fqn,
|
||||
helpers_of_helpers,
|
||||
helpers_of_helpers_fqn,
|
||||
project_root_path,
|
||||
remove_docstrings=True,
|
||||
code_context_type=CodeContextType.TESTGEN,
|
||||
)
|
||||
testgen_context_code = testgen_code_markdown.markdown
|
||||
testgen_context_code_tokens = len(tokenizer.encode(testgen_context_code))
|
||||
if testgen_context_code_tokens > testgen_token_limit:
|
||||
raise ValueError("Testgen code context has exceeded token limit, cannot proceed")
|
||||
|
||||
logger.debug("Code context has exceeded token limit, removing read-only code")
|
||||
return CodeOptimizationContext(
|
||||
code_to_optimize_with_helpers="",
|
||||
testgen_context_code = testgen_context_code,
|
||||
read_writable_code=CodeString(code=final_read_writable_code).code,
|
||||
read_only_context_code="",
|
||||
read_only_context_code=read_only_context_code,
|
||||
helper_functions=helpers_of_fto_obj_list,
|
||||
preexisting_objects=preexisting_objects,
|
||||
)
|
||||
|
||||
|
||||
def get_all_read_writable_code(
|
||||
def extract_code_string_context_from_files(
|
||||
helpers_of_fto: dict[Path, set[str]], helpers_of_fto_fqn: dict[Path, set[str]], project_root_path: Path
|
||||
) -> CodeString:
|
||||
"""Extract read-writable code context from files containing target functions and their helpers.
|
||||
|
||||
This function iterates through each file path that contains functions to optimize (fto) or
|
||||
their first-degree helpers, reads the original code, extracts relevant parts using CST parsing,
|
||||
and adds necessary imports from the original modules.
|
||||
|
||||
Args:
|
||||
helpers_of_fto: Dictionary mapping file paths to sets of qualified function names
|
||||
helpers_of_fto_fqn: Dictionary mapping file paths to sets of fully qualified names of functions
|
||||
project_root_path: Root path of the project for resolving relative imports
|
||||
|
||||
Returns:
|
||||
CodeString object containing the consolidated read-writable code with all necessary
|
||||
imports for the target functions and their helpers
|
||||
|
||||
"""
|
||||
final_read_writable_code = ""
|
||||
# Extract code from file paths that contain fto and first degree helpers
|
||||
for file_path, qualified_function_names in helpers_of_fto.items():
|
||||
|
|
@ -117,7 +148,7 @@ def get_all_read_writable_code(
|
|||
logger.exception(f"Error while parsing {file_path}: {e}")
|
||||
continue
|
||||
try:
|
||||
read_writable_code = get_read_writable_code(original_code, qualified_function_names)
|
||||
read_writable_code = parse_code_and_prune_cst(original_code, CodeContextType.READ_WRITABLE, qualified_function_names)
|
||||
except ValueError as e:
|
||||
logger.debug(f"Error while getting read-writable code: {e}")
|
||||
continue
|
||||
|
|
@ -133,16 +164,38 @@ def get_all_read_writable_code(
|
|||
helper_functions_fqn=helpers_of_fto_fqn[file_path],
|
||||
)
|
||||
return CodeString(code=final_read_writable_code)
|
||||
|
||||
|
||||
def get_all_read_only_code_context(
|
||||
def extract_code_markdown_context_from_files(
|
||||
helpers_of_fto: dict[Path, set[str]],
|
||||
helpers_of_fto_fqn: dict[Path, set[str]],
|
||||
helpers_of_helpers: dict[Path, set[str]],
|
||||
helpers_of_helpers_fqn: dict[Path, set[str]],
|
||||
project_root_path: Path,
|
||||
remove_docstrings: bool = False,
|
||||
code_context_type: CodeContextType = CodeContextType.READ_ONLY,
|
||||
) -> CodeStringsMarkdown:
|
||||
"""Extract code context from files containing target functions and their helpers, formatting them as markdown.
|
||||
|
||||
This function processes two sets of files:
|
||||
1. Files containing the function to optimize (fto) and their first-degree helpers
|
||||
2. Files containing only helpers of helpers (with no overlap with the first set)
|
||||
|
||||
For each file, it extracts relevant code based on the specified context type, adds necessary
|
||||
imports, and combines them into a structured markdown format.
|
||||
|
||||
Args:
|
||||
helpers_of_fto: Dictionary mapping file paths to sets of function names to be optimized
|
||||
helpers_of_fto_fqn: Dictionary mapping file paths to sets of fully qualified names of functions to be optimized
|
||||
helpers_of_helpers: Dictionary mapping file paths to sets of helper function names
|
||||
helpers_of_helpers_fqn: Dictionary mapping file paths to sets of fully qualified names of helper functions
|
||||
project_root_path: Root path of the project
|
||||
remove_docstrings: Whether to remove docstrings from the extracted code
|
||||
code_context_type: Type of code context to extract (READ_ONLY, READ_WRITABLE, or TESTGEN)
|
||||
|
||||
Returns:
|
||||
CodeStringsMarkdown containing the extracted code context with necessary imports,
|
||||
formatted for inclusion in markdown
|
||||
|
||||
"""
|
||||
# Rearrange to remove overlaps, so we only access each file path once
|
||||
helpers_of_helpers_no_overlap = defaultdict(set)
|
||||
helpers_of_helpers_no_overlap_fqn = defaultdict(set)
|
||||
|
|
@ -155,7 +208,7 @@ def get_all_read_only_code_context(
|
|||
helpers_of_helpers_no_overlap[file_path] = helpers_of_helpers[file_path]
|
||||
helpers_of_helpers_no_overlap_fqn[file_path] = helpers_of_helpers_fqn[file_path]
|
||||
|
||||
read_only_code_markdown = CodeStringsMarkdown()
|
||||
code_context_markdown = CodeStringsMarkdown()
|
||||
# Extract code from file paths that contain fto and first degree helpers. helpers of helpers may also be included if they are in the same files
|
||||
for file_path, qualified_function_names in helpers_of_fto.items():
|
||||
try:
|
||||
|
|
@ -164,17 +217,18 @@ def get_all_read_only_code_context(
|
|||
logger.exception(f"Error while parsing {file_path}: {e}")
|
||||
continue
|
||||
try:
|
||||
read_only_code = get_read_only_code(
|
||||
original_code, qualified_function_names, helpers_of_helpers.get(file_path, set()), remove_docstrings
|
||||
code_context = parse_code_and_prune_cst(
|
||||
original_code, code_context_type, qualified_function_names, helpers_of_helpers.get(file_path, set()), remove_docstrings
|
||||
)
|
||||
|
||||
except ValueError as e:
|
||||
logger.debug(f"Error while getting read-only code: {e}")
|
||||
continue
|
||||
if read_only_code.strip():
|
||||
read_only_code_with_imports = CodeString(
|
||||
if code_context.strip():
|
||||
code_context_with_imports = CodeString(
|
||||
code=add_needed_imports_from_module(
|
||||
src_module_code=original_code,
|
||||
dst_module_code=read_only_code,
|
||||
dst_module_code=code_context,
|
||||
src_path=file_path,
|
||||
dst_path=file_path,
|
||||
project_root=project_root_path,
|
||||
|
|
@ -182,7 +236,7 @@ def get_all_read_only_code_context(
|
|||
),
|
||||
file_path=file_path.relative_to(project_root_path),
|
||||
)
|
||||
read_only_code_markdown.code_strings.append(read_only_code_with_imports)
|
||||
code_context_markdown.code_strings.append(code_context_with_imports)
|
||||
|
||||
# Extract code from file paths containing helpers of helpers
|
||||
for file_path, qualified_helper_function_names in helpers_of_helpers_no_overlap.items():
|
||||
|
|
@ -192,18 +246,18 @@ def get_all_read_only_code_context(
|
|||
logger.exception(f"Error while parsing {file_path}: {e}")
|
||||
continue
|
||||
try:
|
||||
read_only_code = get_read_only_code(
|
||||
original_code, set(), qualified_helper_function_names, remove_docstrings
|
||||
code_context = parse_code_and_prune_cst(
|
||||
original_code, code_context_type, set(), qualified_helper_function_names, remove_docstrings
|
||||
)
|
||||
except ValueError as e:
|
||||
logger.debug(f"Error while getting read-only code: {e}")
|
||||
continue
|
||||
|
||||
if read_only_code.strip():
|
||||
read_only_code_with_imports = CodeString(
|
||||
if code_context.strip():
|
||||
code_context_with_imports = CodeString(
|
||||
code=add_needed_imports_from_module(
|
||||
src_module_code=original_code,
|
||||
dst_module_code=read_only_code,
|
||||
dst_module_code=code_context,
|
||||
src_path=file_path,
|
||||
dst_path=file_path,
|
||||
project_root=project_root_path,
|
||||
|
|
@ -211,8 +265,8 @@ def get_all_read_only_code_context(
|
|||
),
|
||||
file_path=file_path.relative_to(project_root_path),
|
||||
)
|
||||
read_only_code_markdown.code_strings.append(read_only_code_with_imports)
|
||||
return read_only_code_markdown
|
||||
code_context_markdown.code_strings.append(code_context_with_imports)
|
||||
return code_context_markdown
|
||||
|
||||
|
||||
def get_file_path_to_helper_functions_dict(
|
||||
|
|
@ -221,11 +275,11 @@ def get_file_path_to_helper_functions_dict(
|
|||
file_path_to_helper_function_qualified_names = defaultdict(set)
|
||||
file_path_to_helper_function_fqn = defaultdict(set)
|
||||
function_source_list: list[FunctionSource] = []
|
||||
for file_path in file_path_to_qualified_function_names:
|
||||
for file_path, qualified_function_names in file_path_to_qualified_function_names.items():
|
||||
script = jedi.Script(path=file_path, project=jedi.Project(path=project_root_path))
|
||||
file_refs = script.get_names(all_scopes=True, definitions=False, references=True)
|
||||
|
||||
for qualified_function_name in file_path_to_qualified_function_names[file_path]:
|
||||
for qualified_function_name in qualified_function_names:
|
||||
names = [
|
||||
ref
|
||||
for ref in file_refs
|
||||
|
|
@ -291,6 +345,29 @@ def remove_docstring_from_body(indented_block: cst.IndentedBlock) -> cst.CSTNode
|
|||
return indented_block.with_changes(body=indented_block.body[1:])
|
||||
return indented_block
|
||||
|
||||
def parse_code_and_prune_cst(
|
||||
code: str, code_context_type: CodeContextType, target_functions: set[str], helpers_of_helper_functions: set[str] = {}, remove_docstrings: bool = False
|
||||
) -> str:
|
||||
"""Create a read-only version of the code by parsing and filtering the code to keep only class contextual information, and other module scoped variables. """
|
||||
module = cst.parse_module(code)
|
||||
if code_context_type == CodeContextType.READ_WRITABLE:
|
||||
filtered_node, found_target = prune_cst_for_read_writable_code(module, target_functions)
|
||||
elif code_context_type == CodeContextType.READ_ONLY:
|
||||
filtered_node, found_target = prune_cst_for_read_only_code(
|
||||
module, target_functions, helpers_of_helper_functions, remove_docstrings=remove_docstrings
|
||||
)
|
||||
elif code_context_type == CodeContextType.TESTGEN:
|
||||
filtered_node, found_target = prune_cst_for_testgen_code(
|
||||
module, target_functions, helpers_of_helper_functions, remove_docstrings=remove_docstrings
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unknown code_context_type: {code_context_type}")
|
||||
|
||||
if not found_target:
|
||||
raise ValueError("No target functions found in the provided code")
|
||||
if filtered_node and isinstance(filtered_node, cst.Module):
|
||||
return str(filtered_node.code)
|
||||
return ""
|
||||
|
||||
def prune_cst_for_read_writable_code(
|
||||
node: cst.CSTNode, target_functions: set[str], prefix: str = ""
|
||||
|
|
@ -371,20 +448,6 @@ def prune_cst_for_read_writable_code(
|
|||
|
||||
return (node.with_changes(**updates) if updates else node), True
|
||||
|
||||
|
||||
def get_read_writable_code(code: str, target_functions: set[str]) -> str:
|
||||
"""Creates a read-writable code string by parsing and filtering the code to keep only
|
||||
target functions and the minimal surrounding structure.
|
||||
"""
|
||||
module = cst.parse_module(code)
|
||||
filtered_node, found_target = prune_cst_for_read_writable_code(module, target_functions)
|
||||
if not found_target:
|
||||
raise ValueError("No target functions found in the provided code")
|
||||
if filtered_node and isinstance(filtered_node, cst.Module):
|
||||
return str(filtered_node.code)
|
||||
return ""
|
||||
|
||||
|
||||
def prune_cst_for_read_only_code(
|
||||
node: cst.CSTNode,
|
||||
target_functions: set[str],
|
||||
|
|
@ -489,18 +552,107 @@ def prune_cst_for_read_only_code(
|
|||
return None, False
|
||||
|
||||
|
||||
def get_read_only_code(
|
||||
code: str, target_functions: set[str], helpers_of_helper_functions: set[str], remove_docstrings: bool = False
|
||||
) -> str:
|
||||
"""Creates a read-only version of the code by parsing and filtering the code to keep only
|
||||
class contextual information, and other module scoped variables.
|
||||
|
||||
def prune_cst_for_testgen_code(
|
||||
node: cst.CSTNode,
|
||||
target_functions: set[str],
|
||||
helpers_of_helper_functions: set[str],
|
||||
prefix: str = "",
|
||||
remove_docstrings: bool = False,
|
||||
) -> tuple[cst.CSTNode | None, bool]:
|
||||
"""Recursively filter the node for testgen context:
|
||||
|
||||
Returns:
|
||||
(filtered_node, found_target):
|
||||
filtered_node: The modified CST node or None if it should be removed.
|
||||
found_target: True if a target function was found in this node's subtree.
|
||||
|
||||
"""
|
||||
module = cst.parse_module(code)
|
||||
filtered_node, found_target = prune_cst_for_read_only_code(
|
||||
module, target_functions, helpers_of_helper_functions, remove_docstrings=remove_docstrings
|
||||
)
|
||||
if not found_target:
|
||||
raise ValueError("No target functions found in the provided code")
|
||||
if filtered_node and isinstance(filtered_node, cst.Module):
|
||||
return str(filtered_node.code)
|
||||
return ""
|
||||
if isinstance(node, (cst.Import, cst.ImportFrom)):
|
||||
return None, False
|
||||
|
||||
if isinstance(node, cst.FunctionDef):
|
||||
qualified_name = f"{prefix}.{node.name.value}" if prefix else node.name.value
|
||||
# If it's a target function, remove it but mark found_target = True
|
||||
if qualified_name in helpers_of_helper_functions or qualified_name in target_functions:
|
||||
if remove_docstrings and isinstance(node.body, cst.IndentedBlock):
|
||||
new_body = remove_docstring_from_body(node.body)
|
||||
return node.with_changes(body=new_body), True
|
||||
return node, True
|
||||
# Keep all dunder methods
|
||||
if is_dunder_method(node.name.value):
|
||||
if remove_docstrings and isinstance(node.body, cst.IndentedBlock):
|
||||
new_body = remove_docstring_from_body(node.body)
|
||||
return node.with_changes(body=new_body), False
|
||||
return node, False
|
||||
return None, False
|
||||
|
||||
if isinstance(node, cst.ClassDef):
|
||||
# Do not recurse into nested classes
|
||||
if prefix:
|
||||
return None, False
|
||||
# Assuming always an IndentedBlock
|
||||
if not isinstance(node.body, cst.IndentedBlock):
|
||||
raise ValueError("ClassDef body is not an IndentedBlock")
|
||||
|
||||
class_prefix = f"{prefix}.{node.name.value}" if prefix else node.name.value
|
||||
|
||||
# First pass: detect if there is a target function in the class
|
||||
found_in_class = False
|
||||
new_class_body: list[CSTNode] = []
|
||||
for stmt in node.body.body:
|
||||
filtered, found_target = prune_cst_for_testgen_code(
|
||||
stmt, target_functions, helpers_of_helper_functions, class_prefix, remove_docstrings=remove_docstrings
|
||||
)
|
||||
found_in_class |= found_target
|
||||
if filtered:
|
||||
new_class_body.append(filtered)
|
||||
|
||||
if not found_in_class:
|
||||
return None, False
|
||||
|
||||
if remove_docstrings:
|
||||
return node.with_changes(
|
||||
body=remove_docstring_from_body(node.body.with_changes(body=new_class_body))
|
||||
) if new_class_body else None, True
|
||||
return node.with_changes(body=node.body.with_changes(body=new_class_body)) if new_class_body else None, True
|
||||
|
||||
# For other nodes, keep the node and recursively filter children
|
||||
section_names = get_section_names(node)
|
||||
if not section_names:
|
||||
return node, False
|
||||
|
||||
updates: dict[str, list[cst.CSTNode] | cst.CSTNode] = {}
|
||||
found_any_target = False
|
||||
|
||||
for section in section_names:
|
||||
original_content = getattr(node, section, None)
|
||||
if isinstance(original_content, (list, tuple)):
|
||||
new_children = []
|
||||
section_found_target = False
|
||||
for child in original_content:
|
||||
filtered, found_target = prune_cst_for_testgen_code(
|
||||
child, target_functions, helpers_of_helper_functions, prefix, remove_docstrings=remove_docstrings
|
||||
)
|
||||
if filtered:
|
||||
new_children.append(filtered)
|
||||
section_found_target |= found_target
|
||||
|
||||
if section_found_target or new_children:
|
||||
found_any_target |= section_found_target
|
||||
updates[section] = new_children
|
||||
elif original_content is not None:
|
||||
filtered, found_target = prune_cst_for_testgen_code(
|
||||
original_content,
|
||||
target_functions,
|
||||
helpers_of_helper_functions,
|
||||
prefix,
|
||||
remove_docstrings=remove_docstrings,
|
||||
)
|
||||
found_any_target |= found_target
|
||||
if filtered:
|
||||
updates[section] = filtered
|
||||
if updates:
|
||||
return (node.with_changes(**updates), found_any_target)
|
||||
|
||||
return None, False
|
||||
|
|
|
|||
|
|
@ -72,6 +72,7 @@ class CodeStringsMarkdown(BaseModel):
|
|||
|
||||
@property
|
||||
def markdown(self) -> str:
|
||||
"""Returns the markdown representation of the code, including the file path where possible."""
|
||||
return "\n".join(
|
||||
[
|
||||
f"```python{':' + str(code_string.file_path) if code_string.file_path else ''}\n{code_string.code.strip()}\n```"
|
||||
|
|
@ -81,12 +82,18 @@ class CodeStringsMarkdown(BaseModel):
|
|||
|
||||
|
||||
class CodeOptimizationContext(BaseModel):
|
||||
code_to_optimize_with_helpers: str
|
||||
# code_to_optimize_with_helpers: str
|
||||
testgen_context_code: str = ""
|
||||
read_writable_code: str = Field(min_length=1)
|
||||
read_only_context_code: str = ""
|
||||
helper_functions: list[FunctionSource]
|
||||
preexisting_objects: list[tuple[str, list[FunctionParent]]]
|
||||
|
||||
class CodeContextType(str, Enum):
|
||||
READ_WRITABLE = "READ_WRITABLE"
|
||||
READ_ONLY = "READ_ONLY"
|
||||
TESTGEN = "TESTGEN"
|
||||
|
||||
|
||||
class OptimizedCandidateResult(BaseModel):
|
||||
max_loop_count: int
|
||||
|
|
|
|||
|
|
@ -60,240 +60,240 @@ def belongs_to_function_qualified(name: Name, qualified_function_name: str) -> b
|
|||
except ValueError:
|
||||
return False
|
||||
|
||||
|
||||
def get_type_annotation_context(
|
||||
function: FunctionToOptimize, jedi_script: jedi.Script, project_root_path: Path
|
||||
) -> tuple[list[FunctionSource], set[tuple[str, str]]]:
|
||||
function_name: str = function.function_name
|
||||
file_path: Path = function.file_path
|
||||
file_contents: str = file_path.read_text(encoding="utf8")
|
||||
try:
|
||||
module: ast.Module = ast.parse(file_contents)
|
||||
except SyntaxError as e:
|
||||
logger.exception(f"get_type_annotation_context - Syntax error in code: {e}")
|
||||
return [], set()
|
||||
sources: list[FunctionSource] = []
|
||||
ast_parents: list[FunctionParent] = []
|
||||
contextual_dunder_methods = set()
|
||||
|
||||
def get_annotation_source(
|
||||
j_script: jedi.Script, name: str, node_parents: list[FunctionParent], line_no: int, col_no: str
|
||||
) -> None:
|
||||
try:
|
||||
definition: list[Name] = j_script.goto(
|
||||
line=line_no, column=col_no, follow_imports=True, follow_builtin_imports=False
|
||||
)
|
||||
except Exception as ex:
|
||||
if hasattr(name, "full_name"):
|
||||
logger.exception(f"Error while getting definition for {name.full_name}: {ex}")
|
||||
else:
|
||||
logger.exception(f"Error while getting definition: {ex}")
|
||||
definition = []
|
||||
if definition: # TODO can be multiple definitions
|
||||
definition_path = definition[0].module_path
|
||||
|
||||
# The definition is part of this project and not defined within the original function
|
||||
if (
|
||||
str(definition_path).startswith(str(project_root_path) + os.sep)
|
||||
and definition[0].full_name
|
||||
and not path_belongs_to_site_packages(definition_path)
|
||||
and not belongs_to_function(definition[0], function_name)
|
||||
):
|
||||
source_code = get_code([FunctionToOptimize(definition[0].name, definition_path, node_parents[:-1])])
|
||||
if source_code[0]:
|
||||
sources.append(
|
||||
FunctionSource(
|
||||
fully_qualified_name=definition[0].full_name,
|
||||
jedi_definition=definition[0],
|
||||
source_code=source_code[0],
|
||||
file_path=definition_path,
|
||||
qualified_name=definition[0].full_name.removeprefix(definition[0].module_name + "."),
|
||||
only_function_name=definition[0].name,
|
||||
)
|
||||
)
|
||||
contextual_dunder_methods.update(source_code[1])
|
||||
|
||||
def visit_children(
|
||||
node: ast.FunctionDef | ast.AsyncFunctionDef | ast.ClassDef | ast.Module, node_parents: list[FunctionParent]
|
||||
) -> None:
|
||||
child: ast.AST | ast.FunctionDef | ast.AsyncFunctionDef | ast.ClassDef | ast.Module
|
||||
for child in ast.iter_child_nodes(node):
|
||||
visit(child, node_parents)
|
||||
|
||||
def visit_all_annotation_children(
|
||||
node: ast.Subscript | ast.Name | ast.BinOp, node_parents: list[FunctionParent]
|
||||
) -> None:
|
||||
if isinstance(node, ast.BinOp) and isinstance(node.op, ast.BitOr):
|
||||
visit_all_annotation_children(node.left, node_parents)
|
||||
visit_all_annotation_children(node.right, node_parents)
|
||||
if isinstance(node, ast.Name) and hasattr(node, "id"):
|
||||
name: str = node.id
|
||||
line_no: int = node.lineno
|
||||
col_no: int = node.col_offset
|
||||
get_annotation_source(jedi_script, name, node_parents, line_no, col_no)
|
||||
if isinstance(node, ast.Subscript):
|
||||
if hasattr(node, "slice"):
|
||||
if isinstance(node.slice, ast.Subscript):
|
||||
visit_all_annotation_children(node.slice, node_parents)
|
||||
elif isinstance(node.slice, ast.Tuple):
|
||||
for elt in node.slice.elts:
|
||||
if isinstance(elt, (ast.Name, ast.Subscript)):
|
||||
visit_all_annotation_children(elt, node_parents)
|
||||
elif isinstance(node.slice, ast.Name):
|
||||
visit_all_annotation_children(node.slice, node_parents)
|
||||
if hasattr(node, "value"):
|
||||
visit_all_annotation_children(node.value, node_parents)
|
||||
|
||||
def visit(
|
||||
node: ast.AST | ast.FunctionDef | ast.AsyncFunctionDef | ast.ClassDef | ast.Module,
|
||||
node_parents: list[FunctionParent],
|
||||
) -> None:
|
||||
if isinstance(node, (ast.Module, ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)):
|
||||
if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
|
||||
if node.name == function_name and node_parents == function.parents:
|
||||
arg: ast.arg
|
||||
for arg in node.args.args:
|
||||
if arg.annotation:
|
||||
visit_all_annotation_children(arg.annotation, node_parents)
|
||||
if node.returns:
|
||||
visit_all_annotation_children(node.returns, node_parents)
|
||||
|
||||
if not isinstance(node, ast.Module):
|
||||
node_parents.append(FunctionParent(node.name, type(node).__name__))
|
||||
visit_children(node, node_parents)
|
||||
if not isinstance(node, ast.Module):
|
||||
node_parents.pop()
|
||||
|
||||
visit(module, ast_parents)
|
||||
|
||||
return sources, contextual_dunder_methods
|
||||
#
|
||||
# def get_type_annotation_context(
|
||||
# function: FunctionToOptimize, jedi_script: jedi.Script, project_root_path: Path
|
||||
# ) -> tuple[list[FunctionSource], set[tuple[str, str]]]:
|
||||
# function_name: str = function.function_name
|
||||
# file_path: Path = function.file_path
|
||||
# file_contents: str = file_path.read_text(encoding="utf8")
|
||||
# try:
|
||||
# module: ast.Module = ast.parse(file_contents)
|
||||
# except SyntaxError as e:
|
||||
# logger.exception(f"get_type_annotation_context - Syntax error in code: {e}")
|
||||
# return [], set()
|
||||
# sources: list[FunctionSource] = []
|
||||
# ast_parents: list[FunctionParent] = []
|
||||
# contextual_dunder_methods = set()
|
||||
#
|
||||
# def get_annotation_source(
|
||||
# j_script: jedi.Script, name: str, node_parents: list[FunctionParent], line_no: int, col_no: str
|
||||
# ) -> None:
|
||||
# try:
|
||||
# definition: list[Name] = j_script.goto(
|
||||
# line=line_no, column=col_no, follow_imports=True, follow_builtin_imports=False
|
||||
# )
|
||||
# except Exception as ex:
|
||||
# if hasattr(name, "full_name"):
|
||||
# logger.exception(f"Error while getting definition for {name.full_name}: {ex}")
|
||||
# else:
|
||||
# logger.exception(f"Error while getting definition: {ex}")
|
||||
# definition = []
|
||||
# if definition: # TODO can be multiple definitions
|
||||
# definition_path = definition[0].module_path
|
||||
#
|
||||
# # The definition is part of this project and not defined within the original function
|
||||
# if (
|
||||
# str(definition_path).startswith(str(project_root_path) + os.sep)
|
||||
# and definition[0].full_name
|
||||
# and not path_belongs_to_site_packages(definition_path)
|
||||
# and not belongs_to_function(definition[0], function_name)
|
||||
# ):
|
||||
# source_code = get_code([FunctionToOptimize(definition[0].name, definition_path, node_parents[:-1])])
|
||||
# if source_code[0]:
|
||||
# sources.append(
|
||||
# FunctionSource(
|
||||
# fully_qualified_name=definition[0].full_name,
|
||||
# jedi_definition=definition[0],
|
||||
# source_code=source_code[0],
|
||||
# file_path=definition_path,
|
||||
# qualified_name=definition[0].full_name.removeprefix(definition[0].module_name + "."),
|
||||
# only_function_name=definition[0].name,
|
||||
# )
|
||||
# )
|
||||
# contextual_dunder_methods.update(source_code[1])
|
||||
#
|
||||
# def visit_children(
|
||||
# node: ast.FunctionDef | ast.AsyncFunctionDef | ast.ClassDef | ast.Module, node_parents: list[FunctionParent]
|
||||
# ) -> None:
|
||||
# child: ast.AST | ast.FunctionDef | ast.AsyncFunctionDef | ast.ClassDef | ast.Module
|
||||
# for child in ast.iter_child_nodes(node):
|
||||
# visit(child, node_parents)
|
||||
#
|
||||
# def visit_all_annotation_children(
|
||||
# node: ast.Subscript | ast.Name | ast.BinOp, node_parents: list[FunctionParent]
|
||||
# ) -> None:
|
||||
# if isinstance(node, ast.BinOp) and isinstance(node.op, ast.BitOr):
|
||||
# visit_all_annotation_children(node.left, node_parents)
|
||||
# visit_all_annotation_children(node.right, node_parents)
|
||||
# if isinstance(node, ast.Name) and hasattr(node, "id"):
|
||||
# name: str = node.id
|
||||
# line_no: int = node.lineno
|
||||
# col_no: int = node.col_offset
|
||||
# get_annotation_source(jedi_script, name, node_parents, line_no, col_no)
|
||||
# if isinstance(node, ast.Subscript):
|
||||
# if hasattr(node, "slice"):
|
||||
# if isinstance(node.slice, ast.Subscript):
|
||||
# visit_all_annotation_children(node.slice, node_parents)
|
||||
# elif isinstance(node.slice, ast.Tuple):
|
||||
# for elt in node.slice.elts:
|
||||
# if isinstance(elt, (ast.Name, ast.Subscript)):
|
||||
# visit_all_annotation_children(elt, node_parents)
|
||||
# elif isinstance(node.slice, ast.Name):
|
||||
# visit_all_annotation_children(node.slice, node_parents)
|
||||
# if hasattr(node, "value"):
|
||||
# visit_all_annotation_children(node.value, node_parents)
|
||||
#
|
||||
# def visit(
|
||||
# node: ast.AST | ast.FunctionDef | ast.AsyncFunctionDef | ast.ClassDef | ast.Module,
|
||||
# node_parents: list[FunctionParent],
|
||||
# ) -> None:
|
||||
# if isinstance(node, (ast.Module, ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)):
|
||||
# if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
|
||||
# if node.name == function_name and node_parents == function.parents:
|
||||
# arg: ast.arg
|
||||
# for arg in node.args.args:
|
||||
# if arg.annotation:
|
||||
# visit_all_annotation_children(arg.annotation, node_parents)
|
||||
# if node.returns:
|
||||
# visit_all_annotation_children(node.returns, node_parents)
|
||||
#
|
||||
# if not isinstance(node, ast.Module):
|
||||
# node_parents.append(FunctionParent(node.name, type(node).__name__))
|
||||
# visit_children(node, node_parents)
|
||||
# if not isinstance(node, ast.Module):
|
||||
# node_parents.pop()
|
||||
#
|
||||
# visit(module, ast_parents)
|
||||
#
|
||||
# return sources, contextual_dunder_methods
|
||||
|
||||
|
||||
def get_function_variables_definitions(
|
||||
function_to_optimize: FunctionToOptimize, project_root_path: Path
|
||||
) -> tuple[list[FunctionSource], set[tuple[str, str]]]:
|
||||
function_name = function_to_optimize.function_name
|
||||
file_path = function_to_optimize.file_path
|
||||
script = jedi.Script(path=file_path, project=jedi.Project(path=project_root_path))
|
||||
sources: list[FunctionSource] = []
|
||||
contextual_dunder_methods = set()
|
||||
# TODO: The function name condition can be stricter so that it does not clash with other class names etc.
|
||||
# TODO: The function could have been imported as some other name,
|
||||
# we should be checking for the translation as well. Also check for the original function name.
|
||||
names = []
|
||||
for ref in script.get_names(all_scopes=True, definitions=False, references=True):
|
||||
if ref.full_name:
|
||||
if function_to_optimize.parents:
|
||||
# Check if the reference belongs to the specified class when FunctionParent is provided
|
||||
if belongs_to_method(ref, function_to_optimize.parents[-1].name, function_name):
|
||||
names.append(ref)
|
||||
elif belongs_to_function(ref, function_name):
|
||||
names.append(ref)
|
||||
|
||||
for name in names:
|
||||
try:
|
||||
definitions: list[Name] = name.goto(follow_imports=True, follow_builtin_imports=False)
|
||||
except Exception as e:
|
||||
try:
|
||||
logger.exception(f"Error while getting definition for {name.full_name}: {e}")
|
||||
except Exception as e:
|
||||
# name.full_name can also throw exceptions sometimes
|
||||
logger.exception(f"Error while getting definition: {e}")
|
||||
definitions = []
|
||||
if definitions:
|
||||
# TODO: there can be multiple definitions, see how to handle such cases
|
||||
definition = definitions[0]
|
||||
definition_path = definition.module_path
|
||||
|
||||
# The definition is part of this project and not defined within the original function
|
||||
if (
|
||||
str(definition_path).startswith(str(project_root_path) + os.sep)
|
||||
and not path_belongs_to_site_packages(definition_path)
|
||||
and definition.full_name
|
||||
and not belongs_to_function(definition, function_name)
|
||||
):
|
||||
module_name = module_name_from_file_path(definition_path, project_root_path)
|
||||
m = re.match(rf"{module_name}\.(.*)\.{definitions[0].name}", definitions[0].full_name)
|
||||
parents = []
|
||||
if m:
|
||||
parents = [FunctionParent(m.group(1), "ClassDef")]
|
||||
|
||||
source_code = get_code(
|
||||
[FunctionToOptimize(function_name=definitions[0].name, file_path=definition_path, parents=parents)]
|
||||
)
|
||||
if source_code[0]:
|
||||
sources.append(
|
||||
FunctionSource(
|
||||
fully_qualified_name=definition.full_name,
|
||||
jedi_definition=definition,
|
||||
source_code=source_code[0],
|
||||
file_path=definition_path,
|
||||
qualified_name=definition.full_name.removeprefix(definition.module_name + "."),
|
||||
only_function_name=definition.name,
|
||||
)
|
||||
)
|
||||
contextual_dunder_methods.update(source_code[1])
|
||||
annotation_sources, annotation_dunder_methods = get_type_annotation_context(
|
||||
function_to_optimize, script, project_root_path
|
||||
)
|
||||
sources[:0] = annotation_sources # prepend the annotation sources
|
||||
contextual_dunder_methods.update(annotation_dunder_methods)
|
||||
existing_fully_qualified_names = set()
|
||||
no_parent_sources: dict[Path, dict[str, set[FunctionSource]]] = defaultdict(lambda: defaultdict(set))
|
||||
parent_sources = set()
|
||||
for source in sources:
|
||||
if (fully_qualified_name := source.fully_qualified_name) not in existing_fully_qualified_names:
|
||||
if not source.qualified_name.count("."):
|
||||
no_parent_sources[source.file_path][source.qualified_name].add(source)
|
||||
else:
|
||||
parent_sources.add(source)
|
||||
existing_fully_qualified_names.add(fully_qualified_name)
|
||||
deduped_parent_sources = [
|
||||
source
|
||||
for source in parent_sources
|
||||
if source.file_path not in no_parent_sources
|
||||
or source.qualified_name.rpartition(".")[0] not in no_parent_sources[source.file_path]
|
||||
]
|
||||
deduped_no_parent_sources = [
|
||||
source for k1 in no_parent_sources for k2 in no_parent_sources[k1] for source in no_parent_sources[k1][k2]
|
||||
]
|
||||
return deduped_no_parent_sources + deduped_parent_sources, contextual_dunder_methods
|
||||
|
||||
|
||||
MAX_PROMPT_TOKENS = 4096 # 128000 # gpt-4-128k
|
||||
|
||||
|
||||
def get_constrained_function_context_and_helper_functions(
|
||||
function_to_optimize: FunctionToOptimize,
|
||||
project_root_path: Path,
|
||||
code_to_optimize: str,
|
||||
max_tokens: int = MAX_PROMPT_TOKENS,
|
||||
) -> tuple[str, list[FunctionSource], set[tuple[str, str]]]:
|
||||
helper_functions, dunder_methods = get_function_variables_definitions(function_to_optimize, project_root_path)
|
||||
tokenizer = tiktoken.encoding_for_model("gpt-3.5-turbo")
|
||||
code_to_optimize_tokens = tokenizer.encode(code_to_optimize)
|
||||
|
||||
if not function_to_optimize.parents:
|
||||
helper_functions_sources = [function.source_code for function in helper_functions]
|
||||
else:
|
||||
helper_functions_sources = [
|
||||
function.source_code
|
||||
for function in helper_functions
|
||||
if not function.qualified_name.count(".")
|
||||
or function.qualified_name.split(".")[0] != function_to_optimize.parents[0].name
|
||||
]
|
||||
helper_functions_tokens = [len(tokenizer.encode(function)) for function in helper_functions_sources]
|
||||
|
||||
context_list = []
|
||||
context_len = len(code_to_optimize_tokens)
|
||||
logger.debug(f"ORIGINAL CODE TOKENS LENGTH: {context_len}")
|
||||
logger.debug(f"ALL DEPENDENCIES TOKENS LENGTH: {sum(helper_functions_tokens)}")
|
||||
for function_source, source_len in zip(helper_functions_sources, helper_functions_tokens):
|
||||
if context_len + source_len <= max_tokens:
|
||||
context_list.append(function_source)
|
||||
context_len += source_len
|
||||
else:
|
||||
break
|
||||
logger.debug(f"FINAL OPTIMIZATION CONTEXT TOKENS LENGTH: {context_len}")
|
||||
helper_code: str = "\n".join(context_list)
|
||||
return helper_code, helper_functions, dunder_methods
|
||||
# def get_function_variables_definitions(
|
||||
# function_to_optimize: FunctionToOptimize, project_root_path: Path
|
||||
# ) -> tuple[list[FunctionSource], set[tuple[str, str]]]:
|
||||
# function_name = function_to_optimize.function_name
|
||||
# file_path = function_to_optimize.file_path
|
||||
# script = jedi.Script(path=file_path, project=jedi.Project(path=project_root_path))
|
||||
# sources: list[FunctionSource] = []
|
||||
# contextual_dunder_methods = set()
|
||||
# # TODO: The function name condition can be stricter so that it does not clash with other class names etc.
|
||||
# # TODO: The function could have been imported as some other name,
|
||||
# # we should be checking for the translation as well. Also check for the original function name.
|
||||
# names = []
|
||||
# for ref in script.get_names(all_scopes=True, definitions=False, references=True):
|
||||
# if ref.full_name:
|
||||
# if function_to_optimize.parents:
|
||||
# # Check if the reference belongs to the specified class when FunctionParent is provided
|
||||
# if belongs_to_method(ref, function_to_optimize.parents[-1].name, function_name):
|
||||
# names.append(ref)
|
||||
# elif belongs_to_function(ref, function_name):
|
||||
# names.append(ref)
|
||||
#
|
||||
# for name in names:
|
||||
# try:
|
||||
# definitions: list[Name] = name.goto(follow_imports=True, follow_builtin_imports=False)
|
||||
# except Exception as e:
|
||||
# try:
|
||||
# logger.exception(f"Error while getting definition for {name.full_name}: {e}")
|
||||
# except Exception as e:
|
||||
# # name.full_name can also throw exceptions sometimes
|
||||
# logger.exception(f"Error while getting definition: {e}")
|
||||
# definitions = []
|
||||
# if definitions:
|
||||
# # TODO: there can be multiple definitions, see how to handle such cases
|
||||
# definition = definitions[0]
|
||||
# definition_path = definition.module_path
|
||||
#
|
||||
# # The definition is part of this project and not defined within the original function
|
||||
# if (
|
||||
# str(definition_path).startswith(str(project_root_path) + os.sep)
|
||||
# and not path_belongs_to_site_packages(definition_path)
|
||||
# and definition.full_name
|
||||
# and not belongs_to_function(definition, function_name)
|
||||
# ):
|
||||
# module_name = module_name_from_file_path(definition_path, project_root_path)
|
||||
# m = re.match(rf"{module_name}\.(.*)\.{definitions[0].name}", definitions[0].full_name)
|
||||
# parents = []
|
||||
# if m:
|
||||
# parents = [FunctionParent(m.group(1), "ClassDef")]
|
||||
#
|
||||
# source_code = get_code(
|
||||
# [FunctionToOptimize(function_name=definitions[0].name, file_path=definition_path, parents=parents)]
|
||||
# )
|
||||
# if source_code[0]:
|
||||
# sources.append(
|
||||
# FunctionSource(
|
||||
# fully_qualified_name=definition.full_name,
|
||||
# jedi_definition=definition,
|
||||
# source_code=source_code[0],
|
||||
# file_path=definition_path,
|
||||
# qualified_name=definition.full_name.removeprefix(definition.module_name + "."),
|
||||
# only_function_name=definition.name,
|
||||
# )
|
||||
# )
|
||||
# contextual_dunder_methods.update(source_code[1])
|
||||
# annotation_sources, annotation_dunder_methods = get_type_annotation_context(
|
||||
# function_to_optimize, script, project_root_path
|
||||
# )
|
||||
# sources[:0] = annotation_sources # prepend the annotation sources
|
||||
# contextual_dunder_methods.update(annotation_dunder_methods)
|
||||
# existing_fully_qualified_names = set()
|
||||
# no_parent_sources: dict[Path, dict[str, set[FunctionSource]]] = defaultdict(lambda: defaultdict(set))
|
||||
# parent_sources = set()
|
||||
# for source in sources:
|
||||
# if (fully_qualified_name := source.fully_qualified_name) not in existing_fully_qualified_names:
|
||||
# if not source.qualified_name.count("."):
|
||||
# no_parent_sources[source.file_path][source.qualified_name].add(source)
|
||||
# else:
|
||||
# parent_sources.add(source)
|
||||
# existing_fully_qualified_names.add(fully_qualified_name)
|
||||
# deduped_parent_sources = [
|
||||
# source
|
||||
# for source in parent_sources
|
||||
# if source.file_path not in no_parent_sources
|
||||
# or source.qualified_name.rpartition(".")[0] not in no_parent_sources[source.file_path]
|
||||
# ]
|
||||
# deduped_no_parent_sources = [
|
||||
# source for k1 in no_parent_sources for k2 in no_parent_sources[k1] for source in no_parent_sources[k1][k2]
|
||||
# ]
|
||||
# return deduped_no_parent_sources + deduped_parent_sources, contextual_dunder_methods
|
||||
#
|
||||
#
|
||||
# MAX_PROMPT_TOKENS = 4096 # 128000 # gpt-4-128k
|
||||
#
|
||||
#
|
||||
# def get_constrained_function_context_and_helper_functions(
|
||||
# function_to_optimize: FunctionToOptimize,
|
||||
# project_root_path: Path,
|
||||
# code_to_optimize: str,
|
||||
# max_tokens: int = MAX_PROMPT_TOKENS,
|
||||
# ) -> tuple[str, list[FunctionSource], set[tuple[str, str]]]:
|
||||
# helper_functions, dunder_methods = get_function_variables_definitions(function_to_optimize, project_root_path)
|
||||
# tokenizer = tiktoken.encoding_for_model("gpt-3.5-turbo")
|
||||
# code_to_optimize_tokens = tokenizer.encode(code_to_optimize)
|
||||
#
|
||||
# if not function_to_optimize.parents:
|
||||
# helper_functions_sources = [function.source_code for function in helper_functions]
|
||||
# else:
|
||||
# helper_functions_sources = [
|
||||
# function.source_code
|
||||
# for function in helper_functions
|
||||
# if not function.qualified_name.count(".")
|
||||
# or function.qualified_name.split(".")[0] != function_to_optimize.parents[0].name
|
||||
# ]
|
||||
# helper_functions_tokens = [len(tokenizer.encode(function)) for function in helper_functions_sources]
|
||||
#
|
||||
# context_list = []
|
||||
# context_len = len(code_to_optimize_tokens)
|
||||
# logger.debug(f"ORIGINAL CODE TOKENS LENGTH: {context_len}")
|
||||
# logger.debug(f"ALL DEPENDENCIES TOKENS LENGTH: {sum(helper_functions_tokens)}")
|
||||
# for function_source, source_len in zip(helper_functions_sources, helper_functions_tokens):
|
||||
# if context_len + source_len <= max_tokens:
|
||||
# context_list.append(function_source)
|
||||
# context_len += source_len
|
||||
# else:
|
||||
# break
|
||||
# logger.debug(f"FINAL OPTIMIZATION CONTEXT TOKENS LENGTH: {context_len}")
|
||||
# helper_code: str = "\n".join(context_list)
|
||||
# return helper_code, helper_functions, dunder_methods
|
||||
|
|
|
|||
|
|
@ -58,7 +58,7 @@ from codeflash.models.models import (
|
|||
TestFiles,
|
||||
TestingMode,
|
||||
)
|
||||
from codeflash.optimization.function_context import get_constrained_function_context_and_helper_functions
|
||||
# from codeflash.optimization.function_context import get_constrained_function_context_and_helper_functions
|
||||
from codeflash.result.create_pr import check_create_pr, existing_tests_source_for
|
||||
from codeflash.result.critic import coverage_critic, performance_gain, quantity_of_tests_critic, speedup_critic
|
||||
from codeflash.result.explanation import Explanation
|
||||
|
|
@ -140,14 +140,14 @@ class FunctionOptimizer:
|
|||
logger.info("Code to be optimized:")
|
||||
code_print(code_context.read_writable_code)
|
||||
|
||||
for module_abspath, helper_code_source in original_helper_code.items():
|
||||
code_context.code_to_optimize_with_helpers = add_needed_imports_from_module(
|
||||
helper_code_source,
|
||||
code_context.code_to_optimize_with_helpers,
|
||||
module_abspath,
|
||||
self.function_to_optimize.file_path,
|
||||
self.args.project_root,
|
||||
)
|
||||
# for module_abspath, helper_code_source in original_helper_code.items():
|
||||
# code_context.code_to_optimize_with_helpers = add_needed_imports_from_module(
|
||||
# helper_code_source,
|
||||
# code_context.code_to_optimize_with_helpers,
|
||||
# module_abspath,
|
||||
# self.function_to_optimize.file_path,
|
||||
# self.args.project_root,
|
||||
# )
|
||||
|
||||
generated_test_paths = [
|
||||
get_test_file_path(
|
||||
|
|
@ -167,7 +167,7 @@ class FunctionOptimizer:
|
|||
transient=True,
|
||||
):
|
||||
generated_results = self.generate_tests_and_optimizations(
|
||||
code_to_optimize_with_helpers=code_context.code_to_optimize_with_helpers,
|
||||
testgen_context_code=code_context.testgen_context_code,
|
||||
read_writable_code=code_context.read_writable_code,
|
||||
read_only_context_code=code_context.read_only_context_code,
|
||||
helper_functions=code_context.helper_functions,
|
||||
|
|
@ -556,49 +556,49 @@ class FunctionOptimizer:
|
|||
return did_update
|
||||
|
||||
def get_code_optimization_context(self) -> Result[CodeOptimizationContext, str]:
|
||||
code_to_optimize, contextual_dunder_methods = extract_code([self.function_to_optimize])
|
||||
if code_to_optimize is None:
|
||||
return Failure("Could not find function to optimize.")
|
||||
(helper_code, helper_functions, helper_dunder_methods) = get_constrained_function_context_and_helper_functions(
|
||||
self.function_to_optimize, self.project_root, code_to_optimize
|
||||
)
|
||||
if self.function_to_optimize.parents:
|
||||
function_class = self.function_to_optimize.parents[0].name
|
||||
same_class_helper_methods = [
|
||||
df
|
||||
for df in helper_functions
|
||||
if df.qualified_name.count(".") > 0 and df.qualified_name.split(".")[0] == function_class
|
||||
]
|
||||
optimizable_methods = [
|
||||
FunctionToOptimize(
|
||||
df.qualified_name.split(".")[-1],
|
||||
df.file_path,
|
||||
[FunctionParent(df.qualified_name.split(".")[0], "ClassDef")],
|
||||
None,
|
||||
None,
|
||||
)
|
||||
for df in same_class_helper_methods
|
||||
] + [self.function_to_optimize]
|
||||
dedup_optimizable_methods = []
|
||||
added_methods = set()
|
||||
for method in reversed(optimizable_methods):
|
||||
if f"{method.file_path}.{method.qualified_name}" not in added_methods:
|
||||
dedup_optimizable_methods.append(method)
|
||||
added_methods.add(f"{method.file_path}.{method.qualified_name}")
|
||||
if len(dedup_optimizable_methods) > 1:
|
||||
code_to_optimize, contextual_dunder_methods = extract_code(list(reversed(dedup_optimizable_methods)))
|
||||
if code_to_optimize is None:
|
||||
return Failure("Could not find function to optimize.")
|
||||
code_to_optimize_with_helpers = helper_code + "\n" + code_to_optimize
|
||||
|
||||
code_to_optimize_with_helpers_and_imports = add_needed_imports_from_module(
|
||||
self.function_to_optimize_source_code,
|
||||
code_to_optimize_with_helpers,
|
||||
self.function_to_optimize.file_path,
|
||||
self.function_to_optimize.file_path,
|
||||
self.project_root,
|
||||
helper_functions,
|
||||
)
|
||||
# code_to_optimize, contextual_dunder_methods = extract_code([self.function_to_optimize])
|
||||
# if code_to_optimize is None:
|
||||
# return Failure("Could not find function to optimize.")
|
||||
# (helper_code, helper_functions, helper_dunder_methods) = get_constrained_function_context_and_helper_functions(
|
||||
# self.function_to_optimize, self.project_root, code_to_optimize
|
||||
# )
|
||||
# if self.function_to_optimize.parents:
|
||||
# function_class = self.function_to_optimize.parents[0].name
|
||||
# same_class_helper_methods = [
|
||||
# df
|
||||
# for df in helper_functions
|
||||
# if df.qualified_name.count(".") > 0 and df.qualified_name.split(".")[0] == function_class
|
||||
# ]
|
||||
# optimizable_methods = [
|
||||
# FunctionToOptimize(
|
||||
# df.qualified_name.split(".")[-1],
|
||||
# df.file_path,
|
||||
# [FunctionParent(df.qualified_name.split(".")[0], "ClassDef")],
|
||||
# None,
|
||||
# None,
|
||||
# )
|
||||
# for df in same_class_helper_methods
|
||||
# ] + [self.function_to_optimize]
|
||||
# dedup_optimizable_methods = []
|
||||
# added_methods = set()
|
||||
# for method in reversed(optimizable_methods):
|
||||
# if f"{method.file_path}.{method.qualified_name}" not in added_methods:
|
||||
# dedup_optimizable_methods.append(method)
|
||||
# added_methods.add(f"{method.file_path}.{method.qualified_name}")
|
||||
# if len(dedup_optimizable_methods) > 1:
|
||||
# code_to_optimize, contextual_dunder_methods = extract_code(list(reversed(dedup_optimizable_methods)))
|
||||
# if code_to_optimize is None:
|
||||
# return Failure("Could not find function to optimize.")
|
||||
# code_to_optimize_with_helpers = helper_code + "\n" + code_to_optimize
|
||||
#
|
||||
# code_to_optimize_with_helpers_and_imports = add_needed_imports_from_module(
|
||||
# self.function_to_optimize_source_code,
|
||||
# code_to_optimize_with_helpers,
|
||||
# self.function_to_optimize.file_path,
|
||||
# self.function_to_optimize.file_path,
|
||||
# self.project_root,
|
||||
# helper_functions,
|
||||
# )
|
||||
|
||||
try:
|
||||
new_code_ctx = code_context_extractor.get_code_optimization_context(
|
||||
|
|
@ -609,7 +609,8 @@ class FunctionOptimizer:
|
|||
|
||||
return Success(
|
||||
CodeOptimizationContext(
|
||||
code_to_optimize_with_helpers=code_to_optimize_with_helpers_and_imports,
|
||||
# code_to_optimize_with_helpers=new_code_ctx.testgen_context_code, # Outdated, fix this!
|
||||
testgen_context_code=new_code_ctx.testgen_context_code,
|
||||
read_writable_code=new_code_ctx.read_writable_code,
|
||||
read_only_context_code=new_code_ctx.read_only_context_code,
|
||||
helper_functions=new_code_ctx.helper_functions, # only functions that are read writable
|
||||
|
|
@ -711,7 +712,7 @@ class FunctionOptimizer:
|
|||
|
||||
def generate_tests_and_optimizations(
|
||||
self,
|
||||
code_to_optimize_with_helpers: str,
|
||||
testgen_context_code: str,
|
||||
read_writable_code: str,
|
||||
read_only_context_code: str,
|
||||
helper_functions: list[FunctionSource],
|
||||
|
|
@ -726,7 +727,7 @@ class FunctionOptimizer:
|
|||
# Submit the test generation task as future
|
||||
future_tests = self.generate_and_instrument_tests(
|
||||
executor,
|
||||
code_to_optimize_with_helpers,
|
||||
testgen_context_code,
|
||||
[definition.fully_qualified_name for definition in helper_functions],
|
||||
generated_test_paths,
|
||||
generated_perf_test_paths,
|
||||
|
|
|
|||
|
|
@ -740,7 +740,7 @@ class HelperClass:
|
|||
ending_line=None,
|
||||
)
|
||||
|
||||
code_ctx = get_code_optimization_context(function_to_optimize, opt.args.project_root)
|
||||
code_ctx = get_code_optimization_context(function_to_optimize, opt.args.project_root, 8000, 100000)
|
||||
read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code
|
||||
# In this scenario, the read-only code context is too long even after removing docstrings, hence we remove it completely.
|
||||
expected_read_write_context = """
|
||||
|
|
@ -813,6 +813,57 @@ class HelperClass:
|
|||
with pytest.raises(ValueError, match="Read-writable code has exceeded token limit, cannot proceed"):
|
||||
code_ctx = get_code_optimization_context(function_to_optimize, opt.args.project_root)
|
||||
|
||||
def test_example_class_token_limit_4() -> None:
|
||||
string_filler = " ".join(
|
||||
["This is a long string that will be used to fill up the token limit." for _ in range(1000)]
|
||||
)
|
||||
code = f"""
|
||||
class MyClass:
|
||||
\"\"\"A class with a helper method. \"\"\"
|
||||
def __init__(self):
|
||||
self.x = 1
|
||||
def target_method(self):
|
||||
\"\"\"Docstring for target method\"\"\"
|
||||
y = HelperClass().helper_method()
|
||||
x = '{string_filler}'
|
||||
|
||||
class HelperClass:
|
||||
\"\"\"A helper class for MyClass.\"\"\"
|
||||
def __init__(self):
|
||||
\"\"\"Initialize the HelperClass.\"\"\"
|
||||
self.x = 1
|
||||
def __repr__(self):
|
||||
\"\"\"Return a string representation of the HelperClass.\"\"\"
|
||||
return "HelperClass" + str(self.x)
|
||||
def helper_method(self):
|
||||
return self.x
|
||||
"""
|
||||
with tempfile.NamedTemporaryFile(mode="w") as f:
|
||||
f.write(code)
|
||||
f.flush()
|
||||
file_path = Path(f.name).resolve()
|
||||
opt = Optimizer(
|
||||
Namespace(
|
||||
project_root=file_path.parent.resolve(),
|
||||
disable_telemetry=True,
|
||||
tests_root="tests",
|
||||
test_framework="pytest",
|
||||
pytest_cmd="pytest",
|
||||
experiment_id=None,
|
||||
test_project_root=Path().resolve(),
|
||||
)
|
||||
)
|
||||
function_to_optimize = FunctionToOptimize(
|
||||
function_name="target_method",
|
||||
file_path=file_path,
|
||||
parents=[FunctionParent(name="MyClass", type="ClassDef")],
|
||||
starting_line=None,
|
||||
ending_line=None,
|
||||
)
|
||||
|
||||
# In this scenario, the testgen code context is too long, so we abort.
|
||||
with pytest.raises(ValueError, match="Testgen code context has exceeded token limit, cannot proceed"):
|
||||
code_ctx = get_code_optimization_context(function_to_optimize, opt.args.project_root)
|
||||
|
||||
def test_repo_helper() -> None:
|
||||
project_root = Path(__file__).resolve().parent.parent / "code_to_optimize" / "code_directories" / "retriever"
|
||||
|
|
|
|||
|
|
@ -747,24 +747,28 @@ class MainClass:
|
|||
|
||||
|
||||
def test_code_replacement10() -> None:
|
||||
get_code_output = """from __future__ import annotations
|
||||
get_code_output = """```python:test_code_replacement.py
|
||||
from __future__ import annotations
|
||||
import os
|
||||
|
||||
os.environ["CODEFLASH_API_KEY"] = "cf-test-key"
|
||||
|
||||
|
||||
class HelperClass:
|
||||
def __init__(self, name):
|
||||
self.name = name
|
||||
|
||||
def innocent_bystander(self):
|
||||
pass
|
||||
|
||||
def helper_method(self):
|
||||
return self.name
|
||||
|
||||
|
||||
class MainClass:
|
||||
def __init__(self, name):
|
||||
self.name = name
|
||||
|
||||
def main_method(self):
|
||||
return HelperClass(self.name).helper_method()
|
||||
"""
|
||||
```"""
|
||||
file_path = Path(__file__).resolve()
|
||||
func_top_optimize = FunctionToOptimize(
|
||||
function_name="main_method", file_path=file_path, parents=[FunctionParent("MainClass", "ClassDef")]
|
||||
|
|
@ -778,7 +782,7 @@ class MainClass:
|
|||
)
|
||||
func_optimizer = FunctionOptimizer(function_to_optimize=func_top_optimize, test_cfg=test_config)
|
||||
code_context = func_optimizer.get_code_optimization_context().unwrap()
|
||||
assert code_context.code_to_optimize_with_helpers == get_code_output
|
||||
assert code_context.testgen_context_code == get_code_output
|
||||
|
||||
|
||||
def test_code_replacement11() -> None:
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ import pytest
|
|||
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
|
||||
from codeflash.either import is_successful
|
||||
from codeflash.models.models import FunctionParent
|
||||
from codeflash.optimization.function_context import get_function_variables_definitions
|
||||
# from codeflash.optimization.function_context import get_function_variables_definitions
|
||||
from codeflash.optimization.function_optimizer import FunctionOptimizer
|
||||
from codeflash.verification.verification_utils import TestConfig
|
||||
|
||||
|
|
@ -18,15 +18,6 @@ def simple_function_with_one_dep(data):
|
|||
return calculate_something(data)
|
||||
|
||||
|
||||
def test_simple_dependencies() -> None:
|
||||
file_path = pathlib.Path(__file__).resolve()
|
||||
helper_functions = get_function_variables_definitions(
|
||||
FunctionToOptimize("simple_function_with_one_dep", str(file_path), []), str(file_path.parent.resolve())
|
||||
)[0]
|
||||
assert len(helper_functions) == 1
|
||||
assert helper_functions[0].jedi_definition.full_name == "test_function_dependencies.calculate_something"
|
||||
|
||||
|
||||
def global_dependency_1(num):
|
||||
return num + 1
|
||||
|
||||
|
|
@ -93,63 +84,12 @@ class C:
|
|||
return self.recursive(num) + num_1
|
||||
|
||||
|
||||
def test_multiple_classes_dependencies() -> None:
|
||||
file_path = pathlib.Path(__file__).resolve()
|
||||
helper_functions = get_function_variables_definitions(
|
||||
FunctionToOptimize("run", str(file_path), [FunctionParent("C", "ClassDef")]), str(file_path.parent.resolve())
|
||||
)
|
||||
|
||||
assert len(helper_functions) == 2
|
||||
assert list(map(lambda x: x.fully_qualified_name, helper_functions[0])) == [
|
||||
"test_function_dependencies.global_dependency_3",
|
||||
"test_function_dependencies.C.calculate_something_3",
|
||||
]
|
||||
|
||||
|
||||
def recursive_dependency_1(num):
|
||||
if num == 0:
|
||||
return 0
|
||||
num_1 = calculate_something(num)
|
||||
return recursive_dependency_1(num) + num_1
|
||||
|
||||
|
||||
def test_recursive_dependency() -> None:
|
||||
file_path = pathlib.Path(__file__).resolve()
|
||||
helper_functions = get_function_variables_definitions(
|
||||
FunctionToOptimize("recursive_dependency_1", str(file_path), []), str(file_path.parent.resolve())
|
||||
)[0]
|
||||
assert len(helper_functions) == 1
|
||||
assert helper_functions[0].jedi_definition.full_name == "test_function_dependencies.calculate_something"
|
||||
assert helper_functions[0].fully_qualified_name == "test_function_dependencies.calculate_something"
|
||||
|
||||
|
||||
@dataclass
|
||||
class MyData:
|
||||
MyInt: int
|
||||
|
||||
|
||||
def calculate_something_ann(data):
|
||||
return data + 1
|
||||
|
||||
|
||||
def simple_function_with_one_dep_ann(data: MyData):
|
||||
return calculate_something_ann(data)
|
||||
|
||||
|
||||
def list_comprehension_dependency(data: MyData):
|
||||
return [calculate_something(data) for x in range(10)]
|
||||
|
||||
|
||||
def test_simple_dependencies_ann() -> None:
|
||||
file_path = pathlib.Path(__file__).resolve()
|
||||
helper_functions = get_function_variables_definitions(
|
||||
FunctionToOptimize("simple_function_with_one_dep_ann", str(file_path), []), str(file_path.parent.resolve())
|
||||
)[0]
|
||||
assert len(helper_functions) == 2
|
||||
assert helper_functions[0].jedi_definition.full_name == "test_function_dependencies.MyData"
|
||||
assert helper_functions[1].jedi_definition.full_name == "test_function_dependencies.calculate_something_ann"
|
||||
|
||||
|
||||
from collections import defaultdict
|
||||
|
||||
|
||||
|
|
@ -220,13 +160,15 @@ def test_class_method_dependencies() -> None:
|
|||
)
|
||||
assert code_context.helper_functions[0].qualified_name == "Graph.topologicalSortUtil"
|
||||
assert (
|
||||
code_context.code_to_optimize_with_helpers
|
||||
== """from collections import defaultdict
|
||||
code_context.testgen_context_code
|
||||
== """```python:test_function_dependencies.py
|
||||
from collections import defaultdict
|
||||
|
||||
class Graph:
|
||||
def __init__(self, vertices):
|
||||
self.graph = defaultdict(list)
|
||||
self.V = vertices # No. of vertices
|
||||
|
||||
def topologicalSortUtil(self, v, visited, stack):
|
||||
visited[v] = True
|
||||
|
||||
|
|
@ -235,6 +177,7 @@ class Graph:
|
|||
self.topologicalSortUtil(i, visited, stack)
|
||||
|
||||
stack.insert(0, v)
|
||||
|
||||
def topologicalSort(self):
|
||||
visited = [False] * self.V
|
||||
stack = []
|
||||
|
|
@ -245,39 +188,9 @@ class Graph:
|
|||
|
||||
# Print contents of stack
|
||||
return stack
|
||||
"""
|
||||
```"""
|
||||
)
|
||||
|
||||
|
||||
def calculate_something_else(data):
|
||||
return data + 1
|
||||
|
||||
|
||||
def imalittledecorator(func):
|
||||
def wrapper(data):
|
||||
return func(data)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
@imalittledecorator
|
||||
def simple_function_with_decorator_dep(data):
|
||||
return calculate_something_else(data)
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="no decorator dependency support")
|
||||
def test_decorator_dependencies() -> None:
|
||||
file_path = pathlib.Path(__file__).resolve()
|
||||
helper_functions = get_function_variables_definitions(
|
||||
FunctionToOptimize("simple_function_with_decorator_dep", str(file_path), []), str(file_path.parent.resolve())
|
||||
)[0]
|
||||
assert len(helper_functions) == 2
|
||||
assert {helper_functions[0][0].definition.full_name, helper_functions[1][0].definition.full_name} == {
|
||||
"test_function_dependencies.calculate_something",
|
||||
"test_function_dependencies.imalittledecorator",
|
||||
}
|
||||
|
||||
|
||||
def test_recursive_function_context() -> None:
|
||||
file_path = pathlib.Path(__file__).resolve()
|
||||
|
||||
|
|
@ -309,73 +222,16 @@ def test_recursive_function_context() -> None:
|
|||
assert code_context.helper_functions[0].fully_qualified_name == "test_function_dependencies.C.calculate_something_3"
|
||||
assert code_context.helper_functions[1].fully_qualified_name == "test_function_dependencies.C.recursive"
|
||||
assert (
|
||||
code_context.code_to_optimize_with_helpers
|
||||
== """class C:
|
||||
code_context.testgen_context_code
|
||||
== """```python:test_function_dependencies.py
|
||||
class C:
|
||||
def calculate_something_3(self, num):
|
||||
return num + 1
|
||||
|
||||
def recursive(self, num):
|
||||
if num == 0:
|
||||
return 0
|
||||
num_1 = self.calculate_something_3(num)
|
||||
return self.recursive(num) + num_1
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
def test_list_comprehension_dependency() -> None:
|
||||
file_path = pathlib.Path(__file__).resolve()
|
||||
helper_functions = get_function_variables_definitions(
|
||||
FunctionToOptimize("list_comprehension_dependency", str(file_path), []), str(file_path.parent.resolve())
|
||||
)[0]
|
||||
assert len(helper_functions) == 2
|
||||
assert helper_functions[0].jedi_definition.full_name == "test_function_dependencies.MyData"
|
||||
assert helper_functions[1].jedi_definition.full_name == "test_function_dependencies.calculate_something"
|
||||
|
||||
|
||||
def test_function_in_method_list_comprehension() -> None:
|
||||
file_path = pathlib.Path(__file__).resolve()
|
||||
function_to_optimize = FunctionToOptimize(
|
||||
function_name="function_in_list_comprehension",
|
||||
file_path=str(file_path),
|
||||
parents=[FunctionParent(name="A", type="ClassDef")],
|
||||
starting_line=None,
|
||||
ending_line=None,
|
||||
)
|
||||
|
||||
helper_functions = get_function_variables_definitions(function_to_optimize, str(file_path.parent.resolve()))[0]
|
||||
|
||||
assert len(helper_functions) == 1
|
||||
assert helper_functions[0].jedi_definition.full_name == "test_function_dependencies.global_dependency_3"
|
||||
|
||||
|
||||
def test_method_in_method_list_comprehension() -> None:
|
||||
file_path = pathlib.Path(__file__).resolve()
|
||||
function_to_optimize = FunctionToOptimize(
|
||||
function_name="method_in_list_comprehension",
|
||||
file_path=str(file_path),
|
||||
parents=[FunctionParent(name="A", type="ClassDef")],
|
||||
starting_line=None,
|
||||
ending_line=None,
|
||||
)
|
||||
|
||||
helper_functions = get_function_variables_definitions(function_to_optimize, str(file_path.parent.resolve()))[0]
|
||||
|
||||
assert len(helper_functions) == 1
|
||||
assert helper_functions[0].jedi_definition.full_name == "test_function_dependencies.A.add_two"
|
||||
|
||||
|
||||
def test_nested_method() -> None:
|
||||
file_path = pathlib.Path(__file__).resolve()
|
||||
function_to_optimize = FunctionToOptimize(
|
||||
function_name="nested_function",
|
||||
file_path=str(file_path),
|
||||
parents=[FunctionParent(name="A", type="ClassDef")],
|
||||
starting_line=None,
|
||||
ending_line=None,
|
||||
)
|
||||
|
||||
helper_functions = get_function_variables_definitions(function_to_optimize, str(file_path.parent.resolve()))[0]
|
||||
|
||||
# The nested function should be included in the helper functions
|
||||
assert len(helper_functions) == 1
|
||||
assert helper_functions[0].jedi_definition.full_name == "test_function_dependencies.A.add_two"
|
||||
```"""
|
||||
)
|
||||
|
|
@ -239,13 +239,37 @@ class _PersistentCache(Generic[_P, _R, _CacheBackendT]):
|
|||
pytest.fail()
|
||||
code_context = ctx_result.unwrap()
|
||||
assert code_context.helper_functions[0].qualified_name == "AbstractCacheBackend.get_cache_or_call"
|
||||
|
||||
assert (
|
||||
code_context.code_to_optimize_with_helpers
|
||||
== '''_R = TypeVar("_R")
|
||||
|
||||
code_context.testgen_context_code
|
||||
== f'''```python:{file_path.name}
|
||||
_P = ParamSpec("_P")
|
||||
_KEY_T = TypeVar("_KEY_T")
|
||||
_STORE_T = TypeVar("_STORE_T")
|
||||
class AbstractCacheBackend(CacheBackend, Protocol[_KEY_T, _STORE_T]):
|
||||
"""Interface for cache backends used by the persistent cache decorator."""
|
||||
|
||||
def __init__(self) -> None: ...
|
||||
|
||||
def hash_key(
|
||||
self,
|
||||
*,
|
||||
func: Callable[_P, Any],
|
||||
args: tuple[Any, ...],
|
||||
kwargs: dict[str, Any],
|
||||
) -> tuple[str, _KEY_T]: ...
|
||||
|
||||
def encode(self, *, data: Any) -> _STORE_T: # noqa: ANN401
|
||||
...
|
||||
|
||||
def decode(self, *, data: _STORE_T) -> Any: # noqa: ANN401
|
||||
...
|
||||
|
||||
def get(self, *, key: tuple[str, _KEY_T]) -> tuple[datetime.datetime, _STORE_T] | None: ...
|
||||
|
||||
def delete(self, *, key: tuple[str, _KEY_T]) -> None: ...
|
||||
|
||||
def put(self, *, key: tuple[str, _KEY_T], data: _STORE_T) -> None: ...
|
||||
|
||||
def get_cache_or_call(
|
||||
self,
|
||||
*,
|
||||
|
|
@ -300,7 +324,33 @@ class AbstractCacheBackend(CacheBackend, Protocol[_KEY_T, _STORE_T]):
|
|||
# If encoding fails, we should still return the result.
|
||||
return result
|
||||
|
||||
_P = ParamSpec("_P")
|
||||
_R = TypeVar("_R")
|
||||
_CacheBackendT = TypeVar("_CacheBackendT", bound=CacheBackend)
|
||||
|
||||
|
||||
class _PersistentCache(Generic[_P, _R, _CacheBackendT]):
|
||||
"""
|
||||
A decorator class that provides persistent caching functionality for a function.
|
||||
|
||||
Args:
|
||||
----
|
||||
func (Callable[_P, _R]): The function to be decorated.
|
||||
duration (datetime.timedelta): The duration for which the cached results should be considered valid.
|
||||
backend (_backend): The backend storage for the cached results.
|
||||
|
||||
Attributes:
|
||||
----------
|
||||
__wrapped__ (Callable[_P, _R]): The wrapped function.
|
||||
__duration__ (datetime.timedelta): The duration for which the cached results should be considered valid.
|
||||
__backend__ (_backend): The backend storage for the cached results.
|
||||
|
||||
""" # noqa: E501
|
||||
|
||||
__wrapped__: Callable[_P, _R]
|
||||
__duration__: datetime.timedelta
|
||||
__backend__: _CacheBackendT
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
func: Callable[_P, _R],
|
||||
|
|
@ -310,6 +360,7 @@ class _PersistentCache(Generic[_P, _R, _CacheBackendT]):
|
|||
self.__duration__ = duration
|
||||
self.__backend__ = AbstractCacheBackend()
|
||||
functools.update_wrapper(self, func)
|
||||
|
||||
def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _R:
|
||||
"""
|
||||
Calls the wrapped function, either using the cache or bypassing it based on environment variables.
|
||||
|
|
@ -333,7 +384,7 @@ class _PersistentCache(Generic[_P, _R, _CacheBackendT]):
|
|||
kwargs=kwargs,
|
||||
lifespan=self.__duration__,
|
||||
)
|
||||
'''
|
||||
```'''
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -358,14 +409,20 @@ def test_bubble_sort_deps() -> None:
|
|||
pytest.fail()
|
||||
code_context = ctx_result.unwrap()
|
||||
assert (
|
||||
code_context.code_to_optimize_with_helpers
|
||||
== """def dep1_comparer(arr, j: int) -> bool:
|
||||
code_context.testgen_context_code
|
||||
== """```python:code_to_optimize/bubble_sort_dep1_helper.py
|
||||
def dep1_comparer(arr, j: int) -> bool:
|
||||
return arr[j] > arr[j + 1]
|
||||
|
||||
```
|
||||
```python:code_to_optimize/bubble_sort_dep2_swap.py
|
||||
def dep2_swap(arr, j):
|
||||
temp = arr[j]
|
||||
arr[j] = arr[j + 1]
|
||||
arr[j + 1] = temp
|
||||
```
|
||||
```python:code_to_optimize/bubble_sort_deps.py
|
||||
from code_to_optimize.bubble_sort_dep1_helper import dep1_comparer
|
||||
from code_to_optimize.bubble_sort_dep2_swap import dep2_swap
|
||||
|
||||
def sorter_deps(arr):
|
||||
for i in range(len(arr)):
|
||||
|
|
@ -373,7 +430,7 @@ def sorter_deps(arr):
|
|||
if dep1_comparer(arr, j):
|
||||
dep2_swap(arr, j)
|
||||
return arr
|
||||
"""
|
||||
```"""
|
||||
)
|
||||
assert len(code_context.helper_functions) == 2
|
||||
assert (
|
||||
|
|
|
|||
|
|
@ -2,7 +2,8 @@ from textwrap import dedent
|
|||
|
||||
import pytest
|
||||
|
||||
from codeflash.context.code_context_extractor import get_read_only_code
|
||||
from codeflash.context.code_context_extractor import parse_code_and_prune_cst
|
||||
from codeflash.models.models import CodeContextType
|
||||
|
||||
|
||||
def test_basic_class() -> None:
|
||||
|
|
@ -22,7 +23,7 @@ def test_basic_class() -> None:
|
|||
class_var = "value"
|
||||
"""
|
||||
|
||||
output = get_read_only_code(dedent(code), {"TestClass.target_method"}, set())
|
||||
output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_ONLY, {"TestClass.target_method"}, set())
|
||||
assert dedent(expected).strip() == output.strip()
|
||||
|
||||
|
||||
|
|
@ -46,7 +47,7 @@ def test_dunder_methods() -> None:
|
|||
return f"Value: {self.x}"
|
||||
"""
|
||||
|
||||
output = get_read_only_code(dedent(code), {"TestClass.target_method"}, set())
|
||||
output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_ONLY, {"TestClass.target_method"}, set())
|
||||
assert dedent(expected).strip() == output.strip()
|
||||
|
||||
|
||||
|
|
@ -72,7 +73,7 @@ def test_dunder_methods_remove_docstring() -> None:
|
|||
return f"Value: {self.x}"
|
||||
"""
|
||||
|
||||
output = get_read_only_code(dedent(code), {"TestClass.target_method"}, set(), remove_docstrings=True)
|
||||
output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_ONLY, {"TestClass.target_method"}, set(), remove_docstrings=True)
|
||||
assert dedent(expected).strip() == output.strip()
|
||||
|
||||
|
||||
|
|
@ -97,7 +98,7 @@ def test_class_remove_docstring() -> None:
|
|||
return f"Value: {self.x}"
|
||||
"""
|
||||
|
||||
output = get_read_only_code(dedent(code), {"TestClass.target_method"}, set(), remove_docstrings=True)
|
||||
output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_ONLY, {"TestClass.target_method"}, set(), remove_docstrings=True)
|
||||
assert dedent(expected).strip() == output.strip()
|
||||
|
||||
|
||||
|
|
@ -124,7 +125,7 @@ def test_mixed_remove_docstring() -> None:
|
|||
return f"Value: {self.x}"
|
||||
"""
|
||||
|
||||
output = get_read_only_code(dedent(code), {"TestClass.target_method"}, set(), remove_docstrings=True)
|
||||
output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_ONLY, {"TestClass.target_method"}, set(), remove_docstrings=True)
|
||||
assert dedent(expected).strip() == output.strip()
|
||||
|
||||
|
||||
|
|
@ -142,7 +143,7 @@ def test_target_in_nested_class() -> None:
|
|||
"""
|
||||
|
||||
with pytest.raises(ValueError, match="No target functions found in the provided code"):
|
||||
get_read_only_code(dedent(code), {"Outer.Inner.target_method"}, set())
|
||||
parse_code_and_prune_cst(dedent(code), CodeContextType.READ_ONLY, {"Outer.Inner.target_method"}, set())
|
||||
|
||||
|
||||
def test_docstrings() -> None:
|
||||
|
|
@ -164,7 +165,7 @@ def test_docstrings() -> None:
|
|||
\"\"\"Class docstring.\"\"\"
|
||||
"""
|
||||
|
||||
output = get_read_only_code(dedent(code), {"TestClass.target_method"}, set())
|
||||
output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_ONLY, {"TestClass.target_method"}, set())
|
||||
assert dedent(expected).strip() == output.strip()
|
||||
|
||||
|
||||
|
|
@ -183,7 +184,7 @@ def test_method_signatures() -> None:
|
|||
|
||||
expected = """"""
|
||||
|
||||
output = get_read_only_code(dedent(code), {"TestClass.target_method"}, set())
|
||||
output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_ONLY, {"TestClass.target_method"}, set())
|
||||
assert dedent(expected).strip() == output.strip()
|
||||
|
||||
|
||||
|
|
@ -203,7 +204,7 @@ def test_multiple_top_level_targets() -> None:
|
|||
expected = """
|
||||
"""
|
||||
|
||||
output = get_read_only_code(dedent(code), {"TestClass.target1", "TestClass.target2"}, set())
|
||||
output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_ONLY, {"TestClass.target1", "TestClass.target2"}, set())
|
||||
assert dedent(expected).strip() == output.strip()
|
||||
|
||||
|
||||
|
|
@ -223,7 +224,7 @@ def test_class_annotations() -> None:
|
|||
var2: str
|
||||
"""
|
||||
|
||||
output = get_read_only_code(dedent(code), {"TestClass.target_method"}, set())
|
||||
output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_ONLY, {"TestClass.target_method"}, set())
|
||||
assert dedent(expected).strip() == output.strip()
|
||||
|
||||
|
||||
|
|
@ -245,7 +246,7 @@ def test_class_annotations_if() -> None:
|
|||
var2: str
|
||||
"""
|
||||
|
||||
output = get_read_only_code(dedent(code), {"TestClass.target_method"}, set())
|
||||
output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_ONLY, {"TestClass.target_method"}, set())
|
||||
assert dedent(expected).strip() == output.strip()
|
||||
|
||||
|
||||
|
|
@ -271,7 +272,7 @@ def test_class_annotations_try() -> None:
|
|||
continue
|
||||
"""
|
||||
|
||||
output = get_read_only_code(dedent(code), {"TestClass.target_method"}, set())
|
||||
output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_ONLY, {"TestClass.target_method"}, set())
|
||||
assert dedent(expected).strip() == output.strip()
|
||||
|
||||
|
||||
|
|
@ -307,7 +308,7 @@ def test_class_annotations_else() -> None:
|
|||
var2: str
|
||||
"""
|
||||
|
||||
output = get_read_only_code(dedent(code), {"TestClass.target_method"}, set())
|
||||
output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_ONLY, {"TestClass.target_method"}, set())
|
||||
assert dedent(expected).strip() == output.strip()
|
||||
|
||||
|
||||
|
|
@ -322,7 +323,7 @@ def test_top_level_functions() -> None:
|
|||
|
||||
expected = """"""
|
||||
|
||||
output = get_read_only_code(dedent(code), {"target_function"}, set())
|
||||
output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_ONLY, {"target_function"}, set())
|
||||
assert dedent(expected).strip() == output.strip()
|
||||
|
||||
|
||||
|
|
@ -341,7 +342,7 @@ def test_module_var() -> None:
|
|||
x = 5
|
||||
"""
|
||||
|
||||
output = get_read_only_code(dedent(code), {"target_function"}, set())
|
||||
output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_ONLY, {"target_function"}, set())
|
||||
assert dedent(expected).strip() == output.strip()
|
||||
|
||||
|
||||
|
|
@ -352,7 +353,7 @@ def test_module_var_if() -> None:
|
|||
|
||||
if y:
|
||||
x = 5
|
||||
else:
|
||||
else:
|
||||
z = 10
|
||||
def some_function():
|
||||
print("wow")
|
||||
|
|
@ -364,11 +365,11 @@ def test_module_var_if() -> None:
|
|||
expected = """
|
||||
if y:
|
||||
x = 5
|
||||
else:
|
||||
else:
|
||||
z = 10
|
||||
"""
|
||||
|
||||
output = get_read_only_code(dedent(code), {"target_function"}, set())
|
||||
output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_ONLY, {"target_function"}, set())
|
||||
assert dedent(expected).strip() == output.strip()
|
||||
|
||||
|
||||
|
|
@ -403,7 +404,7 @@ def test_conditional_class_definitions() -> None:
|
|||
platform = "other"
|
||||
"""
|
||||
|
||||
output = get_read_only_code(dedent(code), {"PlatformClass.target_method"}, set())
|
||||
output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_ONLY, {"PlatformClass.target_method"}, set())
|
||||
assert dedent(expected).strip() == output.strip()
|
||||
|
||||
|
||||
|
|
@ -462,7 +463,7 @@ def test_multiple_except_clauses() -> None:
|
|||
error_type = "cleanup"
|
||||
"""
|
||||
|
||||
output = get_read_only_code(dedent(code), {"TestClass.target_method"}, set())
|
||||
output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_ONLY, {"TestClass.target_method"}, set())
|
||||
assert dedent(expected).strip() == output.strip()
|
||||
|
||||
|
||||
|
|
@ -515,7 +516,7 @@ def test_with_statement_and_loops() -> None:
|
|||
context = "cleanup"
|
||||
"""
|
||||
|
||||
output = get_read_only_code(dedent(code), {"TestClass.target_method"}, set())
|
||||
output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_ONLY, {"TestClass.target_method"}, set())
|
||||
assert dedent(expected).strip() == output.strip()
|
||||
|
||||
|
||||
|
|
@ -564,7 +565,7 @@ def test_async_with_try_except() -> None:
|
|||
status = "cancelled"
|
||||
"""
|
||||
|
||||
output = get_read_only_code(dedent(code), {"TestClass.target_method"}, set())
|
||||
output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_ONLY, {"TestClass.target_method"}, set())
|
||||
assert dedent(expected).strip() == output.strip()
|
||||
|
||||
|
||||
|
|
@ -664,7 +665,7 @@ def test_simplified_complete_implementation() -> None:
|
|||
pass
|
||||
"""
|
||||
|
||||
output = get_read_only_code(dedent(code), {"DataProcessor.target_method", "ResultHandler.target_method"}, set())
|
||||
output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_ONLY, {"DataProcessor.target_method", "ResultHandler.target_method"}, set())
|
||||
assert dedent(expected).strip() == output.strip()
|
||||
|
||||
|
||||
|
|
@ -751,7 +752,7 @@ def test_simplified_complete_implementation_no_docstring() -> None:
|
|||
pass
|
||||
"""
|
||||
|
||||
output = get_read_only_code(
|
||||
dedent(code), {"DataProcessor.target_method", "ResultHandler.target_method"}, set(), remove_docstrings=True
|
||||
output = parse_code_and_prune_cst(
|
||||
dedent(code), CodeContextType.READ_ONLY, {"DataProcessor.target_method", "ResultHandler.target_method"}, set(), remove_docstrings=True
|
||||
)
|
||||
assert dedent(expected).strip() == output.strip()
|
||||
|
|
|
|||
|
|
@ -1,7 +1,8 @@
|
|||
from textwrap import dedent
|
||||
|
||||
import pytest
|
||||
from codeflash.context.code_context_extractor import get_read_writable_code
|
||||
from codeflash.context.code_context_extractor import parse_code_and_prune_cst
|
||||
from codeflash.models.models import CodeContextType
|
||||
|
||||
|
||||
def test_simple_function() -> None:
|
||||
|
|
@ -11,7 +12,7 @@ def test_simple_function() -> None:
|
|||
y = 2
|
||||
return x + y
|
||||
"""
|
||||
result = get_read_writable_code(dedent(code), {"target_function"})
|
||||
result = parse_code_and_prune_cst(dedent(code),CodeContextType.READ_WRITABLE, {"target_function"})
|
||||
|
||||
expected = dedent("""
|
||||
def target_function():
|
||||
|
|
@ -30,7 +31,7 @@ def test_class_method() -> None:
|
|||
y = 2
|
||||
return x + y
|
||||
"""
|
||||
result = get_read_writable_code(dedent(code), {"MyClass.target_function"})
|
||||
result = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_WRITABLE, {"MyClass.target_function"})
|
||||
|
||||
expected = dedent("""
|
||||
class MyClass:
|
||||
|
|
@ -54,7 +55,7 @@ def test_class_with_attributes() -> None:
|
|||
def other_method(self):
|
||||
print("this should be excluded")
|
||||
"""
|
||||
result = get_read_writable_code(dedent(code), {"MyClass.target_method"})
|
||||
result = parse_code_and_prune_cst(dedent(code),CodeContextType.READ_WRITABLE, {"MyClass.target_method"})
|
||||
|
||||
expected = dedent("""
|
||||
class MyClass:
|
||||
|
|
@ -78,7 +79,7 @@ def test_basic_class_structure() -> None:
|
|||
def not_findable(self):
|
||||
return 42
|
||||
"""
|
||||
result = get_read_writable_code(dedent(code), {"Outer.target_method"})
|
||||
result = parse_code_and_prune_cst(dedent(code),CodeContextType.READ_WRITABLE, {"Outer.target_method"})
|
||||
|
||||
expected = dedent("""
|
||||
class Outer:
|
||||
|
|
@ -98,7 +99,7 @@ def test_top_level_targets() -> None:
|
|||
def target_function():
|
||||
return 42
|
||||
"""
|
||||
result = get_read_writable_code(dedent(code), {"target_function"})
|
||||
result = parse_code_and_prune_cst(dedent(code),CodeContextType.READ_WRITABLE, {"target_function"})
|
||||
|
||||
expected = dedent("""
|
||||
def target_function():
|
||||
|
|
@ -121,7 +122,7 @@ def test_multiple_top_level_classes() -> None:
|
|||
def process(self):
|
||||
return "C"
|
||||
"""
|
||||
result = get_read_writable_code(dedent(code), {"ClassA.process", "ClassC.process"})
|
||||
result = parse_code_and_prune_cst(dedent(code),CodeContextType.READ_WRITABLE, {"ClassA.process", "ClassC.process"})
|
||||
|
||||
expected = dedent("""
|
||||
class ClassA:
|
||||
|
|
@ -146,7 +147,7 @@ def test_try_except_structure() -> None:
|
|||
def handle_error(self):
|
||||
print("error")
|
||||
"""
|
||||
result = get_read_writable_code(dedent(code), {"TargetClass.target_method"})
|
||||
result = parse_code_and_prune_cst(dedent(code),CodeContextType.READ_WRITABLE, {"TargetClass.target_method"})
|
||||
|
||||
expected = dedent("""
|
||||
try:
|
||||
|
|
@ -173,7 +174,7 @@ def test_init_method() -> None:
|
|||
def target_method(self):
|
||||
return f"Value: {self.x}"
|
||||
"""
|
||||
result = get_read_writable_code(dedent(code), {"MyClass.target_method"})
|
||||
result = parse_code_and_prune_cst(dedent(code),CodeContextType.READ_WRITABLE, {"MyClass.target_method"})
|
||||
|
||||
expected = dedent("""
|
||||
class MyClass:
|
||||
|
|
@ -197,7 +198,7 @@ def test_dunder_method() -> None:
|
|||
def target_method(self):
|
||||
return f"Value: {self.x}"
|
||||
"""
|
||||
result = get_read_writable_code(dedent(code), {"MyClass.target_method"})
|
||||
result = parse_code_and_prune_cst(dedent(code),CodeContextType.READ_WRITABLE, {"MyClass.target_method"})
|
||||
|
||||
expected = dedent("""
|
||||
class MyClass:
|
||||
|
|
@ -218,7 +219,7 @@ def test_no_targets_found() -> None:
|
|||
pass
|
||||
"""
|
||||
with pytest.raises(ValueError, match="No target functions found in the provided code"):
|
||||
get_read_writable_code(dedent(code), {"MyClass.Inner.target"})
|
||||
parse_code_and_prune_cst(dedent(code),CodeContextType.READ_WRITABLE, {"MyClass.Inner.target"})
|
||||
|
||||
|
||||
def test_module_var() -> None:
|
||||
|
|
@ -242,7 +243,7 @@ def test_module_var() -> None:
|
|||
var2 = "test"
|
||||
"""
|
||||
|
||||
output = get_read_writable_code(dedent(code), {"target_function"})
|
||||
output = parse_code_and_prune_cst(dedent(code),CodeContextType.READ_WRITABLE, {"target_function"})
|
||||
assert dedent(expected).strip() == output.strip()
|
||||
|
||||
|
||||
|
|
|
|||
745
tests/test_get_testgen_code.py
Normal file
745
tests/test_get_testgen_code.py
Normal file
|
|
@ -0,0 +1,745 @@
|
|||
from textwrap import dedent
|
||||
|
||||
import pytest
|
||||
|
||||
from codeflash.models.models import CodeContextType
|
||||
from codeflash.context.code_context_extractor import parse_code_and_prune_cst
|
||||
|
||||
def test_simple_function() -> None:
|
||||
code = """
|
||||
def target_function():
|
||||
x = 1
|
||||
y = 2
|
||||
return x + y
|
||||
"""
|
||||
result = parse_code_and_prune_cst(dedent(code), CodeContextType.TESTGEN, {"target_function"}, set())
|
||||
|
||||
expected = """
|
||||
def target_function():
|
||||
x = 1
|
||||
y = 2
|
||||
return x + y
|
||||
"""
|
||||
assert dedent(expected).strip() == result.strip()
|
||||
|
||||
def test_basic_class() -> None:
|
||||
code = """
|
||||
class TestClass:
|
||||
class_var = "value"
|
||||
|
||||
def target_method(self):
|
||||
print("This should be included")
|
||||
|
||||
def other_method(self):
|
||||
print("This too")
|
||||
"""
|
||||
|
||||
expected = """
|
||||
class TestClass:
|
||||
class_var = "value"
|
||||
|
||||
def target_method(self):
|
||||
print("This should be included")
|
||||
"""
|
||||
|
||||
output = parse_code_and_prune_cst(dedent(code), CodeContextType.TESTGEN, {"TestClass.target_method"}, set())
|
||||
assert dedent(expected).strip() == output.strip()
|
||||
|
||||
def test_dunder_methods() -> None:
|
||||
code = """
|
||||
class TestClass:
|
||||
def __init__(self):
|
||||
self.x = 42
|
||||
|
||||
def __str__(self):
|
||||
return f"Value: {self.x}"
|
||||
|
||||
def target_method(self):
|
||||
print("include me")
|
||||
"""
|
||||
|
||||
expected = """
|
||||
class TestClass:
|
||||
def __init__(self):
|
||||
self.x = 42
|
||||
|
||||
def __str__(self):
|
||||
return f"Value: {self.x}"
|
||||
|
||||
def target_method(self):
|
||||
print("include me")
|
||||
"""
|
||||
|
||||
output = parse_code_and_prune_cst(dedent(code), CodeContextType.TESTGEN, {"TestClass.target_method"}, set())
|
||||
assert dedent(expected).strip() == output.strip()
|
||||
|
||||
|
||||
def test_dunder_methods_remove_docstring() -> None:
|
||||
code = """
|
||||
class TestClass:
|
||||
def __init__(self):
|
||||
\"\"\"Constructor for TestClass.\"\"\"
|
||||
self.x = 42
|
||||
|
||||
def __str__(self):
|
||||
\"\"\"String representation of TestClass.\"\"\"
|
||||
return f"Value: {self.x}"
|
||||
|
||||
def target_method(self):
|
||||
\"\"\"Target method docstring.\"\"\"
|
||||
print("include me")
|
||||
"""
|
||||
|
||||
expected = """
|
||||
class TestClass:
|
||||
def __init__(self):
|
||||
self.x = 42
|
||||
|
||||
def __str__(self):
|
||||
return f"Value: {self.x}"
|
||||
|
||||
def target_method(self):
|
||||
print("include me")
|
||||
"""
|
||||
|
||||
output = parse_code_and_prune_cst(dedent(code), CodeContextType.TESTGEN, {"TestClass.target_method"}, set(), remove_docstrings=True)
|
||||
assert dedent(expected).strip() == output.strip()
|
||||
|
||||
|
||||
def test_class_remove_docstring() -> None:
|
||||
code = """
|
||||
class TestClass:
|
||||
\"\"\"Class docstring.\"\"\"
|
||||
def __init__(self):
|
||||
self.x = 42
|
||||
|
||||
def __str__(self):
|
||||
return f"Value: {self.x}"
|
||||
|
||||
def target_method(self):
|
||||
print("include me")
|
||||
"""
|
||||
|
||||
expected = """
|
||||
class TestClass:
|
||||
def __init__(self):
|
||||
self.x = 42
|
||||
|
||||
def __str__(self):
|
||||
return f"Value: {self.x}"
|
||||
|
||||
def target_method(self):
|
||||
print("include me")
|
||||
"""
|
||||
|
||||
output = parse_code_and_prune_cst(dedent(code), CodeContextType.TESTGEN, {"TestClass.target_method"}, set(), remove_docstrings=True)
|
||||
assert dedent(expected).strip() == output.strip()
|
||||
|
||||
|
||||
def test_target_in_nested_class() -> None:
|
||||
"""Test that attempting to find a target in a nested class raises an error."""
|
||||
code = """
|
||||
class Outer:
|
||||
outer_var = 1
|
||||
|
||||
class Inner:
|
||||
inner_var = 2
|
||||
|
||||
def target_method(self):
|
||||
print("include this")
|
||||
"""
|
||||
|
||||
with pytest.raises(ValueError, match="No target functions found in the provided code"):
|
||||
parse_code_and_prune_cst(dedent(code), CodeContextType.TESTGEN, {"Outer.Inner.target_method"}, set())
|
||||
|
||||
def test_method_signatures() -> None:
|
||||
code = """
|
||||
class TestClass:
|
||||
@property
|
||||
def target_method(self) -> str:
|
||||
\"\"\"Property docstring.\"\"\"
|
||||
return "value"
|
||||
|
||||
@classmethod
|
||||
def class_method(cls, param: int = 42) -> None:
|
||||
print("class method")
|
||||
"""
|
||||
|
||||
expected = """
|
||||
class TestClass:
|
||||
@property
|
||||
def target_method(self) -> str:
|
||||
\"\"\"Property docstring.\"\"\"
|
||||
return "value"
|
||||
"""
|
||||
|
||||
output = parse_code_and_prune_cst(dedent(code), CodeContextType.TESTGEN, {"TestClass.target_method"}, set())
|
||||
assert dedent(expected).strip() == output.strip()
|
||||
def test_multiple_top_level_targets() -> None:
|
||||
code = """
|
||||
class TestClass:
|
||||
def target1(self):
|
||||
print("include 1")
|
||||
|
||||
def target2(self):
|
||||
print("include 2")
|
||||
|
||||
def __init__(self):
|
||||
self.x = 42
|
||||
|
||||
def other_method(self):
|
||||
print("include other")
|
||||
"""
|
||||
|
||||
expected = """
|
||||
class TestClass:
|
||||
def target1(self):
|
||||
print("include 1")
|
||||
|
||||
def target2(self):
|
||||
print("include 2")
|
||||
|
||||
def __init__(self):
|
||||
self.x = 42
|
||||
"""
|
||||
|
||||
output = parse_code_and_prune_cst(dedent(code), CodeContextType.TESTGEN, {"TestClass.target1", "TestClass.target2"}, set())
|
||||
assert dedent(expected).strip() == output.strip()
|
||||
|
||||
|
||||
def test_class_annotations() -> None:
|
||||
code = """
|
||||
class TestClass:
|
||||
var1: int = 42
|
||||
var2: str
|
||||
|
||||
def target_method(self) -> None:
|
||||
self.var2 = "test"
|
||||
"""
|
||||
|
||||
expected = """
|
||||
class TestClass:
|
||||
var1: int = 42
|
||||
var2: str
|
||||
|
||||
def target_method(self) -> None:
|
||||
self.var2 = "test"
|
||||
"""
|
||||
|
||||
output = parse_code_and_prune_cst(dedent(code), CodeContextType.TESTGEN, {"TestClass.target_method"}, set())
|
||||
assert dedent(expected).strip() == output.strip()
|
||||
|
||||
def test_class_annotations_if() -> None:
|
||||
code = """
|
||||
if True:
|
||||
class TestClass:
|
||||
var1: int = 42
|
||||
var2: str
|
||||
|
||||
def target_method(self) -> None:
|
||||
self.var2 = "test"
|
||||
"""
|
||||
|
||||
expected = """
|
||||
if True:
|
||||
class TestClass:
|
||||
var1: int = 42
|
||||
var2: str
|
||||
|
||||
def target_method(self) -> None:
|
||||
self.var2 = "test"
|
||||
"""
|
||||
|
||||
output = parse_code_and_prune_cst(dedent(code), CodeContextType.TESTGEN, {"TestClass.target_method"}, set())
|
||||
assert dedent(expected).strip() == output.strip()
|
||||
|
||||
|
||||
def test_conditional_class_definitions() -> None:
|
||||
code = """
|
||||
if PLATFORM == "linux":
|
||||
class PlatformClass:
|
||||
platform = "linux"
|
||||
def target_method(self):
|
||||
print("linux")
|
||||
elif PLATFORM == "windows":
|
||||
class PlatformClass:
|
||||
platform = "windows"
|
||||
def target_method(self):
|
||||
print("windows")
|
||||
else:
|
||||
class PlatformClass:
|
||||
platform = "other"
|
||||
def target_method(self):
|
||||
print("other")
|
||||
"""
|
||||
|
||||
expected = """
|
||||
if PLATFORM == "linux":
|
||||
class PlatformClass:
|
||||
platform = "linux"
|
||||
def target_method(self):
|
||||
print("linux")
|
||||
elif PLATFORM == "windows":
|
||||
class PlatformClass:
|
||||
platform = "windows"
|
||||
def target_method(self):
|
||||
print("windows")
|
||||
else:
|
||||
class PlatformClass:
|
||||
platform = "other"
|
||||
def target_method(self):
|
||||
print("other")
|
||||
"""
|
||||
|
||||
output = parse_code_and_prune_cst(dedent(code), CodeContextType.TESTGEN, {"PlatformClass.target_method"}, set())
|
||||
assert dedent(expected).strip() == output.strip()
|
||||
|
||||
|
||||
def test_try_except_structure() -> None:
|
||||
code = """
|
||||
try:
|
||||
class TargetClass:
|
||||
attr = "value"
|
||||
def target_method(self):
|
||||
return 42
|
||||
except ValueError:
|
||||
class ErrorClass:
|
||||
def handle_error(self):
|
||||
print("error")
|
||||
"""
|
||||
|
||||
expected = """
|
||||
try:
|
||||
class TargetClass:
|
||||
attr = "value"
|
||||
def target_method(self):
|
||||
return 42
|
||||
except ValueError:
|
||||
class ErrorClass:
|
||||
def handle_error(self):
|
||||
print("error")
|
||||
"""
|
||||
|
||||
output = parse_code_and_prune_cst(dedent(code), CodeContextType.TESTGEN, {"TargetClass.target_method"}, set())
|
||||
assert dedent(expected).strip() == output.strip()
|
||||
|
||||
|
||||
def test_module_var() -> None:
|
||||
code = """
|
||||
def target_function(self) -> None:
|
||||
self.var2 = "test"
|
||||
|
||||
x = 5
|
||||
|
||||
def some_function():
|
||||
print("wow")
|
||||
"""
|
||||
|
||||
expected = """
|
||||
def target_function(self) -> None:
|
||||
self.var2 = "test"
|
||||
|
||||
x = 5
|
||||
"""
|
||||
|
||||
output = parse_code_and_prune_cst(dedent(code), CodeContextType.TESTGEN, {"target_function"}, set())
|
||||
assert dedent(expected).strip() == output.strip()
|
||||
|
||||
def test_module_var_if() -> None:
|
||||
code = """
|
||||
def target_function(self) -> None:
|
||||
var2 = "test"
|
||||
|
||||
if y:
|
||||
x = 5
|
||||
else:
|
||||
z = 10
|
||||
def some_function():
|
||||
print("wow")
|
||||
|
||||
def some_function():
|
||||
print("wow")
|
||||
"""
|
||||
|
||||
expected = """
|
||||
def target_function(self) -> None:
|
||||
var2 = "test"
|
||||
|
||||
if y:
|
||||
x = 5
|
||||
else:
|
||||
z = 10
|
||||
"""
|
||||
|
||||
output = parse_code_and_prune_cst(dedent(code), CodeContextType.TESTGEN, {"target_function"}, set())
|
||||
assert dedent(expected).strip() == output.strip()
|
||||
|
||||
def test_multiple_classes() -> None:
|
||||
code = """
|
||||
class ClassA:
|
||||
def process(self):
|
||||
return "A"
|
||||
|
||||
class ClassB:
|
||||
def process(self):
|
||||
return "B"
|
||||
|
||||
class ClassC:
|
||||
def process(self):
|
||||
return "C"
|
||||
"""
|
||||
|
||||
expected = """
|
||||
class ClassA:
|
||||
def process(self):
|
||||
return "A"
|
||||
|
||||
class ClassC:
|
||||
def process(self):
|
||||
return "C"
|
||||
"""
|
||||
|
||||
output = parse_code_and_prune_cst(dedent(code), CodeContextType.TESTGEN, {"ClassA.process", "ClassC.process"}, set())
|
||||
assert dedent(expected).strip() == output.strip()
|
||||
|
||||
|
||||
def test_with_statement_and_loops() -> None:
|
||||
code = """
|
||||
with context_manager() as ctx:
|
||||
while attempt_count < max_attempts:
|
||||
try:
|
||||
for item in items:
|
||||
if item.ready:
|
||||
class TestClass:
|
||||
context = "ready"
|
||||
def target_method(self):
|
||||
print("ready")
|
||||
else:
|
||||
class TestClass:
|
||||
context = "not_ready"
|
||||
def target_method(self):
|
||||
print("not ready")
|
||||
except ConnectionError:
|
||||
class TestClass:
|
||||
context = "connection_error"
|
||||
def target_method(self):
|
||||
print("connection error")
|
||||
continue
|
||||
finally:
|
||||
class TestClass:
|
||||
context = "cleanup"
|
||||
def target_method(self):
|
||||
print("cleanup")
|
||||
"""
|
||||
|
||||
expected = """
|
||||
with context_manager() as ctx:
|
||||
while attempt_count < max_attempts:
|
||||
try:
|
||||
for item in items:
|
||||
if item.ready:
|
||||
class TestClass:
|
||||
context = "ready"
|
||||
def target_method(self):
|
||||
print("ready")
|
||||
else:
|
||||
class TestClass:
|
||||
context = "not_ready"
|
||||
def target_method(self):
|
||||
print("not ready")
|
||||
except ConnectionError:
|
||||
class TestClass:
|
||||
context = "connection_error"
|
||||
def target_method(self):
|
||||
print("connection error")
|
||||
continue
|
||||
finally:
|
||||
class TestClass:
|
||||
context = "cleanup"
|
||||
def target_method(self):
|
||||
print("cleanup")
|
||||
"""
|
||||
|
||||
output = parse_code_and_prune_cst(dedent(code), CodeContextType.TESTGEN, {"TestClass.target_method"}, set())
|
||||
assert dedent(expected).strip() == output.strip()
|
||||
|
||||
|
||||
def test_async_with_try_except() -> None:
|
||||
code = """
|
||||
async with async_context() as ctx:
|
||||
try:
|
||||
async for item in items:
|
||||
if await item.is_valid():
|
||||
class TestClass:
|
||||
status = "valid"
|
||||
async def target_method(self):
|
||||
await self.process()
|
||||
elif await item.can_retry():
|
||||
continue
|
||||
else:
|
||||
break
|
||||
except AsyncIOError:
|
||||
class TestClass:
|
||||
status = "io_error"
|
||||
async def target_method(self):
|
||||
await self.handle_error()
|
||||
except CancelledError:
|
||||
class TestClass:
|
||||
status = "cancelled"
|
||||
async def target_method(self):
|
||||
await self.cleanup()
|
||||
"""
|
||||
|
||||
expected = """
|
||||
async with async_context() as ctx:
|
||||
try:
|
||||
async for item in items:
|
||||
if await item.is_valid():
|
||||
class TestClass:
|
||||
status = "valid"
|
||||
async def target_method(self):
|
||||
await self.process()
|
||||
elif await item.can_retry():
|
||||
continue
|
||||
else:
|
||||
break
|
||||
except AsyncIOError:
|
||||
class TestClass:
|
||||
status = "io_error"
|
||||
async def target_method(self):
|
||||
await self.handle_error()
|
||||
except CancelledError:
|
||||
class TestClass:
|
||||
status = "cancelled"
|
||||
async def target_method(self):
|
||||
await self.cleanup()
|
||||
"""
|
||||
|
||||
output = parse_code_and_prune_cst(dedent(code), CodeContextType.TESTGEN, {"TestClass.target_method"}, set())
|
||||
assert dedent(expected).strip() == output.strip()
|
||||
|
||||
def test_simplified_complete_implementation() -> None:
|
||||
code = """
|
||||
class DataProcessor:
|
||||
\"\"\"A simple data processing class.\"\"\"
|
||||
|
||||
def __init__(self, data: Dict[str, Any]) -> None:
|
||||
self.data = data
|
||||
self._processed = False
|
||||
self.result = None
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"DataProcessor(processed={self._processed})"
|
||||
|
||||
def target_method(self, key: str) -> Optional[Any]:
|
||||
\"\"\"Process and retrieve a specific key from the data.\"\"\"
|
||||
if not self._processed:
|
||||
self._process_data()
|
||||
return self.result.get(key) if self.result else None
|
||||
|
||||
def _process_data(self) -> None:
|
||||
\"\"\"Internal method to process the data.\"\"\"
|
||||
processed = {}
|
||||
for key, value in self.data.items():
|
||||
if isinstance(value, (int, float)):
|
||||
processed[key] = value * 2
|
||||
elif isinstance(value, str):
|
||||
processed[key] = value.upper()
|
||||
else:
|
||||
processed[key] = value
|
||||
self.result = processed
|
||||
self._processed = True
|
||||
|
||||
def to_json(self) -> str:
|
||||
\"\"\"Convert the processed data to JSON string.\"\"\"
|
||||
if not self._processed:
|
||||
self._process_data()
|
||||
return json.dumps(self.result)
|
||||
|
||||
try:
|
||||
sample_data = {"number": 42, "text": "hello"}
|
||||
processor = DataProcessor(sample_data)
|
||||
|
||||
class ResultHandler:
|
||||
def __init__(self, processor: DataProcessor):
|
||||
self.processor = processor
|
||||
self.cache = {}
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"ResultHandler(cache_size={len(self.cache)})"
|
||||
|
||||
def target_method(self, key: str) -> Optional[Any]:
|
||||
\"\"\"Retrieve and cache results for a key.\"\"\"
|
||||
if key not in self.cache:
|
||||
self.cache[key] = self.processor.target_method(key)
|
||||
return self.cache[key]
|
||||
|
||||
def clear_cache(self) -> None:
|
||||
\"\"\"Clear the internal cache.\"\"\"
|
||||
self.cache.clear()
|
||||
|
||||
def get_stats(self) -> Dict[str, int]:
|
||||
\"\"\"Get cache statistics.\"\"\"
|
||||
return {
|
||||
"cache_size": len(self.cache),
|
||||
"hits": sum(1 for v in self.cache.values() if v is not None)
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
class ResultHandler:
|
||||
def __init__(self):
|
||||
self.error = str(e)
|
||||
|
||||
def target_method(self, key: str) -> None:
|
||||
raise RuntimeError(f"Failed to initialize: {self.error}")
|
||||
"""
|
||||
|
||||
expected = """
|
||||
class DataProcessor:
|
||||
\"\"\"A simple data processing class.\"\"\"
|
||||
|
||||
def __init__(self, data: Dict[str, Any]) -> None:
|
||||
self.data = data
|
||||
self._processed = False
|
||||
self.result = None
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"DataProcessor(processed={self._processed})"
|
||||
|
||||
def target_method(self, key: str) -> Optional[Any]:
|
||||
\"\"\"Process and retrieve a specific key from the data.\"\"\"
|
||||
if not self._processed:
|
||||
self._process_data()
|
||||
return self.result.get(key) if self.result else None
|
||||
|
||||
try:
|
||||
sample_data = {"number": 42, "text": "hello"}
|
||||
processor = DataProcessor(sample_data)
|
||||
|
||||
class ResultHandler:
|
||||
def __init__(self, processor: DataProcessor):
|
||||
self.processor = processor
|
||||
self.cache = {}
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"ResultHandler(cache_size={len(self.cache)})"
|
||||
|
||||
def target_method(self, key: str) -> Optional[Any]:
|
||||
\"\"\"Retrieve and cache results for a key.\"\"\"
|
||||
if key not in self.cache:
|
||||
self.cache[key] = self.processor.target_method(key)
|
||||
return self.cache[key]
|
||||
|
||||
except Exception as e:
|
||||
class ResultHandler:
|
||||
def __init__(self):
|
||||
self.error = str(e)
|
||||
|
||||
def target_method(self, key: str) -> None:
|
||||
raise RuntimeError(f"Failed to initialize: {self.error}")
|
||||
"""
|
||||
|
||||
output = parse_code_and_prune_cst(dedent(code), CodeContextType.TESTGEN, {"DataProcessor.target_method", "ResultHandler.target_method"}, set())
|
||||
assert dedent(expected).strip() == output.strip()
|
||||
|
||||
|
||||
def test_simplified_complete_implementation_no_docstring() -> None:
|
||||
code = """
|
||||
class DataProcessor:
|
||||
\"\"\"A simple data processing class.\"\"\"
|
||||
def __repr__(self) -> str:
|
||||
return f"DataProcessor(processed={self._processed})"
|
||||
|
||||
def target_method(self, key: str) -> Optional[Any]:
|
||||
\"\"\"Process and retrieve a specific key from the data.\"\"\"
|
||||
if not self._processed:
|
||||
self._process_data()
|
||||
return self.result.get(key) if self.result else None
|
||||
|
||||
def _process_data(self) -> None:
|
||||
\"\"\"Internal method to process the data.\"\"\"
|
||||
processed = {}
|
||||
for key, value in self.data.items():
|
||||
if isinstance(value, (int, float)):
|
||||
processed[key] = value * 2
|
||||
elif isinstance(value, str):
|
||||
processed[key] = value.upper()
|
||||
else:
|
||||
processed[key] = value
|
||||
self.result = processed
|
||||
self._processed = True
|
||||
|
||||
def to_json(self) -> str:
|
||||
\"\"\"Convert the processed data to JSON string.\"\"\"
|
||||
if not self._processed:
|
||||
self._process_data()
|
||||
return json.dumps(self.result)
|
||||
|
||||
try:
|
||||
sample_data = {"number": 42, "text": "hello"}
|
||||
processor = DataProcessor(sample_data)
|
||||
|
||||
class ResultHandler:
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"ResultHandler(cache_size={len(self.cache)})"
|
||||
|
||||
def target_method(self, key: str) -> Optional[Any]:
|
||||
\"\"\"Retrieve and cache results for a key.\"\"\"
|
||||
if key not in self.cache:
|
||||
self.cache[key] = self.processor.target_method(key)
|
||||
return self.cache[key]
|
||||
|
||||
def clear_cache(self) -> None:
|
||||
\"\"\"Clear the internal cache.\"\"\"
|
||||
self.cache.clear()
|
||||
|
||||
def get_stats(self) -> Dict[str, int]:
|
||||
\"\"\"Get cache statistics.\"\"\"
|
||||
return {
|
||||
"cache_size": len(self.cache),
|
||||
"hits": sum(1 for v in self.cache.values() if v is not None)
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
class ResultHandler:
|
||||
|
||||
def target_method(self, key: str) -> None:
|
||||
raise RuntimeError(f"Failed to initialize: {self.error}")
|
||||
"""
|
||||
|
||||
expected = """
|
||||
class DataProcessor:
|
||||
def __repr__(self) -> str:
|
||||
return f"DataProcessor(processed={self._processed})"
|
||||
|
||||
def target_method(self, key: str) -> Optional[Any]:
|
||||
if not self._processed:
|
||||
self._process_data()
|
||||
return self.result.get(key) if self.result else None
|
||||
|
||||
try:
|
||||
sample_data = {"number": 42, "text": "hello"}
|
||||
processor = DataProcessor(sample_data)
|
||||
|
||||
class ResultHandler:
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"ResultHandler(cache_size={len(self.cache)})"
|
||||
|
||||
def target_method(self, key: str) -> Optional[Any]:
|
||||
if key not in self.cache:
|
||||
self.cache[key] = self.processor.target_method(key)
|
||||
return self.cache[key]
|
||||
|
||||
except Exception as e:
|
||||
class ResultHandler:
|
||||
|
||||
def target_method(self, key: str) -> None:
|
||||
raise RuntimeError(f"Failed to initialize: {self.error}")
|
||||
"""
|
||||
|
||||
output = parse_code_and_prune_cst(
|
||||
dedent(code), CodeContextType.TESTGEN, {"DataProcessor.target_method", "ResultHandler.target_method"}, set(), remove_docstrings=True
|
||||
)
|
||||
assert dedent(expected).strip() == output.strip()
|
||||
|
|
@ -2674,13 +2674,13 @@ def test_code_replacement10() -> None:
|
|||
project_root=str(file_path.parent),
|
||||
original_source_code=original_code,
|
||||
).unwrap()
|
||||
assert code_context.code_to_optimize_with_helpers == get_code_output
|
||||
assert code_context.testgen_context_code == get_code_output
|
||||
code_context = opt.get_code_optimization_context(
|
||||
function_to_optimize=func_top_optimize,
|
||||
project_root=str(file_path.parent),
|
||||
original_source_code=original_code,
|
||||
)
|
||||
assert code_context.code_to_optimize_with_helpers == get_code_output
|
||||
assert code_context.testgen_context_code == get_code_output
|
||||
"""
|
||||
|
||||
expected = """import gc
|
||||
|
|
@ -2739,9 +2739,9 @@ def test_code_replacement10() -> None:
|
|||
with open(file_path) as f:
|
||||
original_code = f.read()
|
||||
code_context = codeflash_wrap(opt.get_code_optimization_context, '{module_path}', None, 'test_code_replacement10', 'Optimizer.get_code_optimization_context', '4_1', codeflash_loop_index, codeflash_cur, codeflash_con, function_to_optimize=func_top_optimize, project_root=str(file_path.parent), original_source_code=original_code).unwrap()
|
||||
assert code_context.code_to_optimize_with_helpers == get_code_output
|
||||
assert code_context.testgen_context_code == get_code_output
|
||||
code_context = codeflash_wrap(opt.get_code_optimization_context, '{module_path}', None, 'test_code_replacement10', 'Optimizer.get_code_optimization_context', '4_3', codeflash_loop_index, codeflash_cur, codeflash_con, function_to_optimize=func_top_optimize, project_root=str(file_path.parent), original_source_code=original_code)
|
||||
assert code_context.code_to_optimize_with_helpers == get_code_output
|
||||
assert code_context.testgen_context_code == get_code_output
|
||||
codeflash_con.close()
|
||||
"""
|
||||
|
||||
|
|
|
|||
|
|
@ -1,103 +1,103 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import pathlib
|
||||
from dataclasses import dataclass, field
|
||||
from typing import List
|
||||
|
||||
from codeflash.code_utils.code_extractor import get_code
|
||||
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
|
||||
from codeflash.optimization.function_context import get_constrained_function_context_and_helper_functions
|
||||
|
||||
|
||||
class CustomType:
|
||||
def __init__(self) -> None:
|
||||
self.name = None
|
||||
self.data: List[int] = []
|
||||
|
||||
|
||||
@dataclass
|
||||
class CustomDataClass:
|
||||
name: str = ""
|
||||
data: List[int] = field(default_factory=list)
|
||||
|
||||
|
||||
def function_to_optimize(data: CustomType) -> CustomType:
|
||||
name = data.name
|
||||
data.data.sort()
|
||||
return data
|
||||
|
||||
|
||||
def function_to_optimize2(data: CustomDataClass) -> CustomType:
|
||||
name = data.name
|
||||
data.data.sort()
|
||||
return data
|
||||
|
||||
|
||||
def function_to_optimize3(data: dict[CustomDataClass, list[CustomDataClass]]) -> list[CustomType] | None:
    """Fixture exercising composite (generic/union) annotations.

    NOTE(review): the annotations say ``dict``/``list`` but the body reads
    ``data.name`` and ``data.data`` — the annotations exist only so the
    composite-type test can check helper discovery, not as a real contract.
    """
    _ = data.name  # attribute read preserved from the original fixture
    data.data.sort()
    return data
|
||||
|
||||
|
||||
def test_function_context_includes_type_annotation() -> None:
    """A parameter's annotated class must be surfaced as a helper function."""
    file_path = pathlib.Path(__file__).resolve()
    # Third argument is the source of ``function_to_optimize`` as it appears
    # in this module; the ``CustomType`` annotation should pull that class
    # into the helper context.  1000 is the context token/size budget —
    # presumably large enough to never truncate here; confirm against the API.
    a, helper_functions, dunder_methods = get_constrained_function_context_and_helper_functions(
        FunctionToOptimize("function_to_optimize", str(file_path), []),
        str(file_path.parent.resolve()),
        """def function_to_optimize(data: CustomType):
    name = data.name
    data.data.sort()
    return data""",
        1000,
    )

    # Only the annotated class qualifies as a helper.
    assert len(helper_functions) == 1
    assert helper_functions[0].fully_qualified_name == "test_type_annotation_context.CustomType"
|
||||
|
||||
|
||||
def test_function_context_includes_type_annotation_dataclass() -> None:
    """Both the dataclass parameter type and the class named in the return
    annotation must be discovered as helper functions."""
    file_path = pathlib.Path(__file__).resolve()
    # The fixture source annotates the parameter as ``CustomDataClass`` and
    # the return as ``CustomType``; both should appear below.
    a, helper_functions, dunder_methods = get_constrained_function_context_and_helper_functions(
        FunctionToOptimize("function_to_optimize2", str(file_path), []),
        str(file_path.parent.resolve()),
        """def function_to_optimize2(data: CustomDataClass) -> CustomType:
    name = data.name
    data.data.sort()
    return data""",
        1000,
    )

    # Helper order observed here: parameter annotation first, then return.
    assert len(helper_functions) == 2
    assert helper_functions[0].fully_qualified_name == "test_type_annotation_context.CustomDataClass"
    assert helper_functions[1].fully_qualified_name == "test_type_annotation_context.CustomType"
|
||||
|
||||
|
||||
def test_function_context_works_for_composite_types() -> None:
    """Classes nested inside composite annotations (``set[...]``,
    ``list[...]``, parameterized generics) must still be discovered."""
    file_path = pathlib.Path(__file__).resolve()
    # ``CustomDataClass`` appears only inside generic brackets and
    # ``CustomType`` only inside ``list[...]``; both must still be found.
    a, helper_functions, dunder_methods = get_constrained_function_context_and_helper_functions(
        FunctionToOptimize("function_to_optimize3", str(file_path), []),
        str(file_path.parent.resolve()),
        """def function_to_optimize3(data: set[CustomDataClass[CustomDataClass, int]]) -> list[CustomType]:
    name = data.name
    data.data.sort()
    return data""",
        1000,
    )

    # Each class is reported once, regardless of repetition in the annotation.
    assert len(helper_functions) == 2
    assert helper_functions[0].fully_qualified_name == "test_type_annotation_context.CustomDataClass"
    assert helper_functions[1].fully_qualified_name == "test_type_annotation_context.CustomType"
|
||||
|
||||
|
||||
def test_function_context_custom_datatype() -> None:
    """End-to-end check against a real project file: the ``Matrix`` type used
    by ``cosine_similarity`` in ``code_to_optimize/math_utils.py`` must be
    discovered as a helper."""
    project_path = pathlib.Path(__file__).parent.parent.resolve() / "code_to_optimize"
    file_path = project_path / "math_utils.py"
    # Extract the function's source from the file rather than hard-coding it.
    code, contextual_dunder_methods = get_code([FunctionToOptimize("cosine_similarity", str(file_path), [])])
    assert code is not None
    # ``cosine_similarity`` pulls in no contextual dunder methods.
    assert contextual_dunder_methods == set()
    a, helper_functions, dunder_methods = get_constrained_function_context_and_helper_functions(
        FunctionToOptimize("cosine_similarity", str(file_path), []), str(project_path), code, 1000
    )

    assert len(helper_functions) == 1
    assert helper_functions[0].fully_qualified_name == "math_utils.Matrix"
|
||||
# NOTE(review): everything below is a commented-out duplicate of the live
# code above in this file — it will drift out of sync; delete it rather
# than maintaining it by hand.
# from __future__ import annotations
|
||||
#
|
||||
# import pathlib
|
||||
# from dataclasses import dataclass, field
|
||||
# from typing import List
|
||||
#
|
||||
# from codeflash.code_utils.code_extractor import get_code
|
||||
# from codeflash.discovery.functions_to_optimize import FunctionToOptimize
|
||||
# from codeflash.optimization.function_context import get_constrained_function_context_and_helper_functions
|
||||
#
|
||||
#
|
||||
# class CustomType:
|
||||
# def __init__(self) -> None:
|
||||
# self.name = None
|
||||
# self.data: List[int] = []
|
||||
#
|
||||
#
|
||||
# @dataclass
|
||||
# class CustomDataClass:
|
||||
# name: str = ""
|
||||
# data: List[int] = field(default_factory=list)
|
||||
#
|
||||
#
|
||||
# def function_to_optimize(data: CustomType) -> CustomType:
|
||||
# name = data.name
|
||||
# data.data.sort()
|
||||
# return data
|
||||
#
|
||||
#
|
||||
# def function_to_optimize2(data: CustomDataClass) -> CustomType:
|
||||
# name = data.name
|
||||
# data.data.sort()
|
||||
# return data
|
||||
#
|
||||
#
|
||||
# def function_to_optimize3(data: dict[CustomDataClass, list[CustomDataClass]]) -> list[CustomType] | None:
|
||||
# name = data.name
|
||||
# data.data.sort()
|
||||
# return data
|
||||
#
|
||||
#
|
||||
# def test_function_context_includes_type_annotation() -> None:
|
||||
# file_path = pathlib.Path(__file__).resolve()
|
||||
# a, helper_functions, dunder_methods = get_constrained_function_context_and_helper_functions(
|
||||
# FunctionToOptimize("function_to_optimize", str(file_path), []),
|
||||
# str(file_path.parent.resolve()),
|
||||
# """def function_to_optimize(data: CustomType):
|
||||
# name = data.name
|
||||
# data.data.sort()
|
||||
# return data""",
|
||||
# 1000,
|
||||
# )
|
||||
#
|
||||
# assert len(helper_functions) == 1
|
||||
# assert helper_functions[0].fully_qualified_name == "test_type_annotation_context.CustomType"
|
||||
#
|
||||
#
|
||||
# def test_function_context_includes_type_annotation_dataclass() -> None:
|
||||
# file_path = pathlib.Path(__file__).resolve()
|
||||
# a, helper_functions, dunder_methods = get_constrained_function_context_and_helper_functions(
|
||||
# FunctionToOptimize("function_to_optimize2", str(file_path), []),
|
||||
# str(file_path.parent.resolve()),
|
||||
# """def function_to_optimize2(data: CustomDataClass) -> CustomType:
|
||||
# name = data.name
|
||||
# data.data.sort()
|
||||
# return data""",
|
||||
# 1000,
|
||||
# )
|
||||
#
|
||||
# assert len(helper_functions) == 2
|
||||
# assert helper_functions[0].fully_qualified_name == "test_type_annotation_context.CustomDataClass"
|
||||
# assert helper_functions[1].fully_qualified_name == "test_type_annotation_context.CustomType"
|
||||
#
|
||||
#
|
||||
# def test_function_context_works_for_composite_types() -> None:
|
||||
# file_path = pathlib.Path(__file__).resolve()
|
||||
# a, helper_functions, dunder_methods = get_constrained_function_context_and_helper_functions(
|
||||
# FunctionToOptimize("function_to_optimize3", str(file_path), []),
|
||||
# str(file_path.parent.resolve()),
|
||||
# """def function_to_optimize3(data: set[CustomDataClass[CustomDataClass, int]]) -> list[CustomType]:
|
||||
# name = data.name
|
||||
# data.data.sort()
|
||||
# return data""",
|
||||
# 1000,
|
||||
# )
|
||||
#
|
||||
# assert len(helper_functions) == 2
|
||||
# assert helper_functions[0].fully_qualified_name == "test_type_annotation_context.CustomDataClass"
|
||||
# assert helper_functions[1].fully_qualified_name == "test_type_annotation_context.CustomType"
|
||||
#
|
||||
#
|
||||
# def test_function_context_custom_datatype() -> None:
|
||||
# project_path = pathlib.Path(__file__).parent.parent.resolve() / "code_to_optimize"
|
||||
# file_path = project_path / "math_utils.py"
|
||||
# code, contextual_dunder_methods = get_code([FunctionToOptimize("cosine_similarity", str(file_path), [])])
|
||||
# assert code is not None
|
||||
# assert contextual_dunder_methods == set()
|
||||
# a, helper_functions, dunder_methods = get_constrained_function_context_and_helper_functions(
|
||||
# FunctionToOptimize("cosine_similarity", str(file_path), []), str(project_path), code, 1000
|
||||
# )
|
||||
#
|
||||
# assert len(helper_functions) == 1
|
||||
# assert helper_functions[0].fully_qualified_name == "math_utils.Matrix"
|
||||
|
|
|
|||
Loading…
Reference in a new issue