Implemented testgen context retrieval. The retrieved context is the union of the read-writable code and the read-only code. Also refactored out code_to_optimize_with_helpers and updated the tests.

This commit is contained in:
Alvin Ryanputra 2025-03-05 16:40:23 -08:00
parent 680d0da5eb
commit 17a42a218c
18 changed files with 1571 additions and 947 deletions

View file

@@ -1,6 +1,3 @@
def sorter(arr):
arr.sort()
return arr
CACHED_TESTS = "import unittest\ndef sorter(arr):\n for i in range(len(arr)):\n for j in range(len(arr) - 1):\n if arr[j] > arr[j + 1]:\n temp = arr[j]\n arr[j] = arr[j + 1]\n arr[j + 1] = temp\n return arr\nclass SorterTestCase(unittest.TestCase):\n def test_empty_list(self):\n self.assertEqual(sorter([]), [])\n def test_single_element_list(self):\n self.assertEqual(sorter([5]), [5])\n def test_ascending_order_list(self):\n self.assertEqual(sorter([1, 2, 3, 4, 5]), [1, 2, 3, 4, 5])\n def test_descending_order_list(self):\n self.assertEqual(sorter([5, 4, 3, 2, 1]), [1, 2, 3, 4, 5])\n def test_random_order_list(self):\n self.assertEqual(sorter([3, 1, 4, 2, 5]), [1, 2, 3, 4, 5])\n def test_duplicate_elements_list(self):\n self.assertEqual(sorter([3, 1, 4, 2, 2, 5, 1]), [1, 1, 2, 2, 3, 4, 5])\n def test_negative_numbers_list(self):\n self.assertEqual(sorter([-5, -2, -8, -1, -3]), [-8, -5, -3, -2, -1])\n def test_mixed_data_types_list(self):\n self.assertEqual(sorter(['apple', 2, 'banana', 1, 'cherry']), [1, 2, 'apple', 'banana', 'cherry'])\n def test_large_input_list(self):\n self.assertEqual(sorter(list(range(1000, 0, -1))), list(range(1, 1001)))\n def test_list_with_none_values(self):\n self.assertEqual(sorter([None, 2, None, 1, None]), [None, None, None, 1, 2])\n def test_list_with_nan_values(self):\n self.assertEqual(sorter([float('nan'), 2, float('nan'), 1, float('nan')]), [1, 2, float('nan'), float('nan'), float('nan')])\n def test_list_with_complex_numbers(self):\n self.assertEqual(sorter([3 + 2j, 1 + 1j, 4 + 3j, 2 + 1j, 5 + 4j]), [1 + 1j, 2 + 1j, 3 + 2j, 4 + 3j, 5 + 4j])\n def test_list_with_custom_class_objects(self):\n class Person:\n def __init__(self, name, age):\n self.name = name\n self.age = age\n def __repr__(self):\n return f\"Person('{self.name}', {self.age})\"\n input_list = [Person('Alice', 25), Person('Bob', 30), Person('Charlie', 20)]\n expected_output = [Person('Charlie', 20), Person('Alice', 25), Person('Bob', 30)]\n 
self.assertEqual(sorter(input_list), expected_output)\n def test_list_with_uncomparable_elements(self):\n with self.assertRaises(TypeError):\n sorter([5, 'apple', 3, [1, 2, 3], 2])\n def test_list_with_custom_comparison_function(self):\n input_list = [5, 4, 3, 2, 1]\n expected_output = [5, 4, 3, 2, 1]\n self.assertEqual(sorter(input_list, reverse=True), expected_output)\nif __name__ == '__main__':\n unittest.main()"
return arr

View file

@@ -9,142 +9,3 @@ def sorter_deps(arr):
dep2_swap(arr, j)
return arr
CACHED_TESTS = """import dill as pickle
import os
def _log__test__values(values, duration, test_name):
iteration = os.environ["CODEFLASH_TEST_ITERATION"]
with open(os.path.join(
'/var/folders/ms/1tz2l1q55w5b7pp4wpdkbjq80000gn/T/codeflash_jk4pzz3w/',
f'test_return_values_{iteration}.bin'), 'ab') as f:
return_bytes = pickle.dumps(values)
_test_name = f"{test_name}".encode("ascii")
f.write(len(_test_name).to_bytes(4, byteorder='big'))
f.write(_test_name)
f.write(duration.to_bytes(8, byteorder='big'))
f.write(len(return_bytes).to_bytes(4, byteorder='big'))
f.write(return_bytes)
import time
import gc
from code_to_optimize.bubble_sort_deps import sorter_deps
import timeout_decorator
import unittest
def dep1_comparer(arr, j: int) -> bool:
return arr[j] > arr[j + 1]
def dep2_swap(arr, j):
temp = arr[j]
arr[j] = arr[j + 1]
arr[j + 1] = temp
class TestSorterDeps(unittest.TestCase):
@timeout_decorator.timeout(15, use_signals=True)
def test_integers(self):
gc.disable()
counter = time.perf_counter_ns()
return_value = sorter_deps([5, 3, 2, 4, 1])
duration = time.perf_counter_ns() - counter
gc.enable()
_log__test__values(
return_value, duration,
'code_to_optimize.tests.unittest.test_sorter_deps__unit_test_0:TestSorterDeps.test_integers:sorter_deps:0')
gc.disable()
counter = time.perf_counter_ns()
return_value = sorter_deps([10, -3, 0, 2, 7])
duration = time.perf_counter_ns() - counter
gc.enable()
_log__test__values(
return_value, duration,
('code_to_optimize.tests.unittest.test_sorter_deps__unit_test_0:'
'TestSorterDeps.test_integers:sorter_deps:1'))
@timeout_decorator.timeout(15, use_signals=True)
def test_floats(self):
gc.disable()
counter = time.perf_counter_ns()
return_value = sorter_deps([3.2, 1.5, 2.7, 4.1, 1.0])
duration = time.perf_counter_ns() - counter
gc.enable()
_log__test__values(return_value, duration,
'code_to_optimize.tests.unittest.test_sorter_deps__unit_test_0:TestSorterDeps.test_floats:sorter_deps:0')
gc.disable()
counter = time.perf_counter_ns()
return_value = sorter_deps([-1.1, 0.0, 3.14, 2.71, -0.5])
duration = time.perf_counter_ns() - counter
gc.enable()
_log__test__values(return_value, duration,
'code_to_optimize.tests.unittest.test_sorter_deps__unit_test_0:TestSorterDeps.test_floats:sorter_deps:1')
@timeout_decorator.timeout(15, use_signals=True)
def test_identical_elements(self):
gc.disable()
counter = time.perf_counter_ns()
return_value = sorter_deps([1, 1, 1, 1, 1])
duration = time.perf_counter_ns() - counter
gc.enable()
_log__test__values(return_value, duration,
('code_to_optimize.tests.unittest.test_sorter_deps__unit_test_0:'
'TestSorterDeps.test_identical_elements:sorter_deps:0'))
gc.disable()
counter = time.perf_counter_ns()
return_value = sorter_deps([3.14, 3.14, 3.14])
duration = time.perf_counter_ns() - counter
gc.enable()
_log__test__values(return_value, duration,
('code_to_optimize.tests.unittest.test_sorter_deps__unit_test_0:'
'TestSorterDeps.test_identical_elements:sorter_deps:1'))
@timeout_decorator.timeout(15, use_signals=True)
def test_single_element(self):
gc.disable()
counter = time.perf_counter_ns()
return_value = sorter_deps([5])
duration = time.perf_counter_ns() - counter
gc.enable()
_log__test__values(return_value, duration, 'code_to_optimize.tests.unittest.test_sorter_deps__unit_test_0:TestSorterDeps.test_single_element:sorter_deps:0')
gc.disable()
counter = time.perf_counter_ns()
return_value = sorter_deps([-3.2])
duration = time.perf_counter_ns() - counter
gc.enable()
_log__test__values(return_value, duration, 'code_to_optimize.tests.unittest.test_sorter_deps__unit_test_0:TestSorterDeps.test_single_element:sorter_deps:1')
@timeout_decorator.timeout(15, use_signals=True)
def test_empty_array(self):
gc.disable()
counter = time.perf_counter_ns()
return_value = sorter_deps([])
duration = time.perf_counter_ns() - counter
gc.enable()
_log__test__values(return_value, duration, 'code_to_optimize.tests.unittest.test_sorter_deps__unit_test_0:TestSorterDeps.test_empty_array:sorter_deps:0')
@timeout_decorator.timeout(15, use_signals=True)
def test_strings(self):
gc.disable()
counter = time.perf_counter_ns()
return_value = sorter_deps(['apple', 'banana', 'cherry', 'date'])
duration = time.perf_counter_ns() - counter
gc.enable()
_log__test__values(return_value, duration, 'code_to_optimize.tests.unittest.test_sorter_deps__unit_test_0:TestSorterDeps.test_strings:sorter_deps:0')
gc.disable()
counter = time.perf_counter_ns()
return_value = sorter_deps(['dog', 'cat', 'elephant', 'ant'])
duration = time.perf_counter_ns() - counter
gc.enable()
_log__test__values(return_value, duration, 'code_to_optimize.tests.unittest.test_sorter_deps__unit_test_0:TestSorterDeps.test_strings:sorter_deps:1')
@timeout_decorator.timeout(15, use_signals=True)
def test_mixed_types(self):
with self.assertRaises(TypeError):
gc.disable()
counter = time.perf_counter_ns()
return_value = sorter_deps([1, 'two', 3.0, 'four'])
duration = time.perf_counter_ns() - counter
gc.enable()
_log__test__values(return_value, duration, 'code_to_optimize.tests.unittest.test_sorter_deps__unit_test_0:TestSorterDeps.test_mixed_types:sorter_deps:0_0')
if __name__ == '__main__':
unittest.main()
"""

View file

@@ -5,7 +5,4 @@ def sorter(arr):
temp = arr[j]
arr[j] = arr[j + 1]
arr[j + 1] = temp
return arr
CACHED_TESTS = "import unittest\ndef sorter(arr):\n for i in range(len(arr)):\n for j in range(len(arr) - 1):\n if arr[j] > arr[j + 1]:\n temp = arr[j]\n arr[j] = arr[j + 1]\n arr[j + 1] = temp\n return arr\nclass SorterTestCase(unittest.TestCase):\n def test_empty_list(self):\n self.assertEqual(sorter([]), [])\n def test_single_element_list(self):\n self.assertEqual(sorter([5]), [5])\n def test_ascending_order_list(self):\n self.assertEqual(sorter([1, 2, 3, 4, 5]), [1, 2, 3, 4, 5])\n def test_descending_order_list(self):\n self.assertEqual(sorter([5, 4, 3, 2, 1]), [1, 2, 3, 4, 5])\n def test_random_order_list(self):\n self.assertEqual(sorter([3, 1, 4, 2, 5]), [1, 2, 3, 4, 5])\n def test_duplicate_elements_list(self):\n self.assertEqual(sorter([3, 1, 4, 2, 2, 5, 1]), [1, 1, 2, 2, 3, 4, 5])\n def test_negative_numbers_list(self):\n self.assertEqual(sorter([-5, -2, -8, -1, -3]), [-8, -5, -3, -2, -1])\n def test_mixed_data_types_list(self):\n self.assertEqual(sorter(['apple', 2, 'banana', 1, 'cherry']), [1, 2, 'apple', 'banana', 'cherry'])\n def test_large_input_list(self):\n self.assertEqual(sorter(list(range(1000, 0, -1))), list(range(1, 1001)))\n def test_list_with_none_values(self):\n self.assertEqual(sorter([None, 2, None, 1, None]), [None, None, None, 1, 2])\n def test_list_with_nan_values(self):\n self.assertEqual(sorter([float('nan'), 2, float('nan'), 1, float('nan')]), [1, 2, float('nan'), float('nan'), float('nan')])\n def test_list_with_complex_numbers(self):\n self.assertEqual(sorter([3 + 2j, 1 + 1j, 4 + 3j, 2 + 1j, 5 + 4j]), [1 + 1j, 2 + 1j, 3 + 2j, 4 + 3j, 5 + 4j])\n def test_list_with_custom_class_objects(self):\n class Person:\n def __init__(self, name, age):\n self.name = name\n self.age = age\n def __repr__(self):\n return f\"Person('{self.name}', {self.age})\"\n input_list = [Person('Alice', 25), Person('Bob', 30), Person('Charlie', 20)]\n expected_output = [Person('Charlie', 20), Person('Alice', 25), Person('Bob', 30)]\n 
self.assertEqual(sorter(input_list), expected_output)\n def test_list_with_uncomparable_elements(self):\n with self.assertRaises(TypeError):\n sorter([5, 'apple', 3, [1, 2, 3], 2])\n def test_list_with_custom_comparison_function(self):\n input_list = [5, 4, 3, 2, 1]\n expected_output = [5, 4, 3, 2, 1]\n self.assertEqual(sorter(input_list, reverse=True), expected_output)\nif __name__ == '__main__':\n unittest.main()"
return arr

View file

@@ -9,110 +9,4 @@ def use_cosine_similarity(
top_k: Optional[int] = 5,
score_threshold: Optional[float] = None,
) -> Tuple[List[Tuple[int, int]], List[float]]:
return cosine_similarity_top_k(X, Y, top_k, score_threshold)
CACHED_TESTS = """import unittest
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity
from typing import List, Optional, Tuple, Union
Matrix = Union[List[List[float]], List[np.ndarray], np.ndarray]
def cosine_similarity_top_k(X: Matrix, Y: Matrix, top_k: Optional[int]=5, score_threshold: Optional[float]=None) -> Tuple[List[Tuple[int, int]], List[float]]:
\"\"\"Row-wise cosine similarity with optional top-k and score threshold filtering.
Args:
X: Matrix.
Y: Matrix, same width as X.
top_k: Max number of results to return.
score_threshold: Minimum cosine similarity of results.
Returns:
Tuple of two lists. First contains two-tuples of indices (X_idx, Y_idx),
second contains corresponding cosine similarities.
\"\"\"
if len(X) == 0 or len(Y) == 0:
return ([], [])
score_array = cosine_similarity(X, Y)
sorted_idxs = score_array.flatten().argsort()[::-1]
top_k = top_k or len(sorted_idxs)
top_idxs = sorted_idxs[:top_k]
score_threshold = score_threshold or -1.0
top_idxs = top_idxs[score_array.flatten()[top_idxs] > score_threshold]
ret_idxs = [(x // score_array.shape[1], x % score_array.shape[1]) for x in top_idxs]
scores = score_array.flatten()[top_idxs].tolist()
return (ret_idxs, scores)
def use_cosine_similarity(X: Matrix, Y: Matrix, top_k: Optional[int]=5, score_threshold: Optional[float]=None) -> Tuple[List[Tuple[int, int]], List[float]]:
return cosine_similarity_top_k(X, Y, top_k, score_threshold)
class TestUseCosineSimilarity(unittest.TestCase):
def test_normal_scenario(self):
X = [[1, 2, 3], [4, 5, 6]]
Y = [[7, 8, 9], [10, 11, 12]]
result = use_cosine_similarity(X, Y, top_k=1, score_threshold=0.5)
self.assertEqual(result, ([(0, 1)], [0.9746318461970762]))
def test_edge_case_empty_matrices(self):
X = []
Y = []
result = use_cosine_similarity(X, Y)
self.assertEqual(result, ([], []))
def test_edge_case_different_widths(self):
X = [[1, 2, 3]]
Y = [[4, 5]]
with self.assertRaises(ValueError):
use_cosine_similarity(X, Y)
def test_edge_case_negative_top_k(self):
X = [[1, 2, 3]]
Y = [[4, 5, 6]]
with self.assertRaises(IndexError):
use_cosine_similarity(X, Y, top_k=-1)
def test_edge_case_zero_top_k(self):
X = [[1, 2, 3]]
Y = [[4, 5, 6]]
result = use_cosine_similarity(X, Y, top_k=0)
self.assertEqual(result, ([], []))
def test_edge_case_negative_score_threshold(self):
X = [[1, 2, 3]]
Y = [[4, 5, 6]]
result = use_cosine_similarity(X, Y, score_threshold=-1.0)
self.assertEqual(result, ([(0, 0)], [0.9746318461970762]))
def test_edge_case_large_score_threshold(self):
X = [[1, 2, 3]]
Y = [[4, 5, 6]]
result = use_cosine_similarity(X, Y, score_threshold=2.0)
self.assertEqual(result, ([], []))
def test_exceptional_case_non_matrix_X(self):
X = [1, 2, 3]
Y = [[4, 5, 6]]
with self.assertRaises(ValueError):
use_cosine_similarity(X, Y)
def test_exceptional_case_non_integer_top_k(self):
X = [[1, 2, 3]]
Y = [[4, 5, 6]]
with self.assertRaises(TypeError):
use_cosine_similarity(X, Y, top_k='5')
def test_exceptional_case_non_float_score_threshold(self):
X = [[1, 2, 3]]
Y = [[4, 5, 6]]
with self.assertRaises(TypeError):
use_cosine_similarity(X, Y, score_threshold='0.5')
def test_special_values_nan_in_matrices(self):
X = [[1, 2, np.nan]]
Y = [[4, 5, 6]]
with self.assertRaises(ValueError):
use_cosine_similarity(X, Y)
def test_special_values_none_top_k(self):
X = [[1, 2, 3]]
Y = [[4, 5, 6]]
result = use_cosine_similarity(X, Y, top_k=None)
self.assertEqual(result, ([(0, 0)], [0.9746318461970762]))
def test_special_values_none_score_threshold(self):
X = [[1, 2, 3]]
Y = [[4, 5, 6]]
result = use_cosine_similarity(X, Y, score_threshold=None)
self.assertEqual(result, ([(0, 0)], [0.9746318461970762]))
def test_large_inputs(self):
X = np.random.rand(1000, 1000)
Y = np.random.rand(1000, 1000)
result = use_cosine_similarity(X, Y, top_k=10, score_threshold=0.5)
self.assertEqual(len(result[0]), 10)
self.assertEqual(len(result[1]), 10)
self.assertTrue(all((score > 0.5 for score in result[1])))
if __name__ == '__main__':
unittest.main()"""
return cosine_similarity_top_k(X, Y, top_k, score_threshold)

View file

@@ -12,7 +12,7 @@ if TYPE_CHECKING:
def extract_dependent_function(main_function: str, code_context: CodeOptimizationContext) -> str | Literal[False]:
"""Extract the single dependent function from the code context excluding the main function."""
ast_tree = ast.parse(code_context.code_to_optimize_with_helpers)
ast_tree = ast.parse(code_context.testgen_context_code)
dependent_functions = {node.name for node in ast_tree.body if isinstance(node, ast.FunctionDef)}

View file

@ -15,12 +15,13 @@ from codeflash.cli_cmds.console import logger
from codeflash.code_utils.code_extractor import add_needed_imports_from_module, find_preexisting_objects
from codeflash.code_utils.code_utils import get_qualified_name, path_belongs_to_site_packages
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.models.models import CodeOptimizationContext, CodeString, CodeStringsMarkdown, FunctionSource
from codeflash.models.models import CodeOptimizationContext, CodeString, CodeStringsMarkdown, FunctionSource, \
CodeContextType
from codeflash.optimization.function_context import belongs_to_function_qualified
def get_code_optimization_context(
function_to_optimize: FunctionToOptimize, project_root_path: Path, token_limit: int = 8000
function_to_optimize: FunctionToOptimize, project_root_path: Path, optim_token_limit: int = 8000, testgen_token_limit: int = 8000
) -> CodeOptimizationContext:
# Get qualified names and fully qualified names(fqn) of helpers
helpers_of_fto, helpers_of_fto_fqn, helpers_of_fto_obj_list = get_file_path_to_helper_functions_dict(
@@ -37,21 +38,22 @@ def get_code_optimization_context(
function_to_optimize.qualified_name_with_modules_from_root(project_root_path)
)
# Extract code
final_read_writable_code = get_all_read_writable_code(helpers_of_fto, helpers_of_fto_fqn, project_root_path).code
read_only_code_markdown = get_all_read_only_code_context(
# Extract code context for optimization
final_read_writable_code = extract_code_string_context_from_files(helpers_of_fto, helpers_of_fto_fqn, project_root_path).code
read_only_code_markdown = extract_code_markdown_context_from_files(
helpers_of_fto,
helpers_of_fto_fqn,
helpers_of_helpers,
helpers_of_helpers_fqn,
project_root_path,
remove_docstrings=False,
code_context_type=CodeContextType.READ_ONLY,
)
# Handle token limits
tokenizer = tiktoken.encoding_for_model("gpt-4o")
final_read_writable_tokens = len(tokenizer.encode(final_read_writable_code))
if final_read_writable_tokens > token_limit:
if final_read_writable_tokens > optim_token_limit:
raise ValueError("Read-writable code has exceeded token limit, cannot proceed")
# Setup preexisting objects for code replacer TODO: should remove duplicates
@@ -61,53 +63,82 @@ def get_code_optimization_context(
*(find_preexisting_objects(codestring.code) for codestring in read_only_code_markdown.code_strings),
)
)
read_only_code_markdown_tokens = len(tokenizer.encode(read_only_code_markdown.markdown))
read_only_context_code = read_only_code_markdown.markdown
read_only_code_markdown_tokens = len(tokenizer.encode(read_only_context_code))
total_tokens = final_read_writable_tokens + read_only_code_markdown_tokens
if total_tokens <= token_limit:
return CodeOptimizationContext(
code_to_optimize_with_helpers="",
read_writable_code=CodeString(code=final_read_writable_code).code,
read_only_context_code=read_only_code_markdown.markdown,
helper_functions=helpers_of_fto_obj_list,
preexisting_objects=preexisting_objects,
if total_tokens > optim_token_limit:
logger.debug("Code context has exceeded token limit, removing docstrings from read-only code")
# Extract read only code without docstrings
read_only_code_no_docstring_markdown = extract_code_markdown_context_from_files(
helpers_of_fto,
helpers_of_fto_fqn,
helpers_of_helpers,
helpers_of_helpers_fqn,
project_root_path,
remove_docstrings=True,
)
logger.debug("Code context has exceeded token limit, removing docstrings from read-only code")
# Extract read only code without docstrings
read_only_code_no_docstring_markdown = get_all_read_only_code_context(
read_only_context_code = read_only_code_no_docstring_markdown.markdown
read_only_code_no_docstring_markdown_tokens = len(tokenizer.encode(read_only_context_code))
total_tokens = final_read_writable_tokens + read_only_code_no_docstring_markdown_tokens
if total_tokens > optim_token_limit:
logger.debug("Code context has exceeded token limit, removing read-only code")
read_only_context_code = ""
# Extract code context for testgen
testgen_code_markdown = extract_code_markdown_context_from_files(
helpers_of_fto,
helpers_of_fto_fqn,
helpers_of_helpers,
helpers_of_helpers_fqn,
project_root_path,
remove_docstrings=True,
remove_docstrings=False,
code_context_type=CodeContextType.TESTGEN,
)
read_only_code_no_docstring_markdown_tokens = len(tokenizer.encode(read_only_code_no_docstring_markdown.markdown))
total_tokens = final_read_writable_tokens + read_only_code_no_docstring_markdown_tokens
if total_tokens <= token_limit:
return CodeOptimizationContext(
code_to_optimize_with_helpers="",
read_writable_code=CodeString(code=final_read_writable_code).code,
read_only_context_code=read_only_code_no_docstring_markdown.markdown,
helper_functions=helpers_of_fto_obj_list,
preexisting_objects=preexisting_objects,
testgen_context_code = testgen_code_markdown.markdown
testgen_context_code_tokens = len(tokenizer.encode(testgen_context_code))
if testgen_context_code_tokens > testgen_token_limit:
testgen_code_markdown = extract_code_markdown_context_from_files(
helpers_of_fto,
helpers_of_fto_fqn,
helpers_of_helpers,
helpers_of_helpers_fqn,
project_root_path,
remove_docstrings=True,
code_context_type=CodeContextType.TESTGEN,
)
testgen_context_code = testgen_code_markdown.markdown
testgen_context_code_tokens = len(tokenizer.encode(testgen_context_code))
if testgen_context_code_tokens > testgen_token_limit:
raise ValueError("Testgen code context has exceeded token limit, cannot proceed")
logger.debug("Code context has exceeded token limit, removing read-only code")
return CodeOptimizationContext(
code_to_optimize_with_helpers="",
testgen_context_code = testgen_context_code,
read_writable_code=CodeString(code=final_read_writable_code).code,
read_only_context_code="",
read_only_context_code=read_only_context_code,
helper_functions=helpers_of_fto_obj_list,
preexisting_objects=preexisting_objects,
)
def get_all_read_writable_code(
def extract_code_string_context_from_files(
helpers_of_fto: dict[Path, set[str]], helpers_of_fto_fqn: dict[Path, set[str]], project_root_path: Path
) -> CodeString:
"""Extract read-writable code context from files containing target functions and their helpers.
This function iterates through each file path that contains functions to optimize (fto) or
their first-degree helpers, reads the original code, extracts relevant parts using CST parsing,
and adds necessary imports from the original modules.
Args:
helpers_of_fto: Dictionary mapping file paths to sets of qualified function names
helpers_of_fto_fqn: Dictionary mapping file paths to sets of fully qualified names of functions
project_root_path: Root path of the project for resolving relative imports
Returns:
CodeString object containing the consolidated read-writable code with all necessary
imports for the target functions and their helpers
"""
final_read_writable_code = ""
# Extract code from file paths that contain fto and first degree helpers
for file_path, qualified_function_names in helpers_of_fto.items():
@@ -117,7 +148,7 @@ def get_all_read_writable_code(
logger.exception(f"Error while parsing {file_path}: {e}")
continue
try:
read_writable_code = get_read_writable_code(original_code, qualified_function_names)
read_writable_code = parse_code_and_prune_cst(original_code, CodeContextType.READ_WRITABLE, qualified_function_names)
except ValueError as e:
logger.debug(f"Error while getting read-writable code: {e}")
continue
@@ -133,16 +164,38 @@ def get_all_read_writable_code(
helper_functions_fqn=helpers_of_fto_fqn[file_path],
)
return CodeString(code=final_read_writable_code)
def get_all_read_only_code_context(
def extract_code_markdown_context_from_files(
helpers_of_fto: dict[Path, set[str]],
helpers_of_fto_fqn: dict[Path, set[str]],
helpers_of_helpers: dict[Path, set[str]],
helpers_of_helpers_fqn: dict[Path, set[str]],
project_root_path: Path,
remove_docstrings: bool = False,
code_context_type: CodeContextType = CodeContextType.READ_ONLY,
) -> CodeStringsMarkdown:
"""Extract code context from files containing target functions and their helpers, formatting them as markdown.
This function processes two sets of files:
1. Files containing the function to optimize (fto) and their first-degree helpers
2. Files containing only helpers of helpers (with no overlap with the first set)
For each file, it extracts relevant code based on the specified context type, adds necessary
imports, and combines them into a structured markdown format.
Args:
helpers_of_fto: Dictionary mapping file paths to sets of function names to be optimized
helpers_of_fto_fqn: Dictionary mapping file paths to sets of fully qualified names of functions to be optimized
helpers_of_helpers: Dictionary mapping file paths to sets of helper function names
helpers_of_helpers_fqn: Dictionary mapping file paths to sets of fully qualified names of helper functions
project_root_path: Root path of the project
remove_docstrings: Whether to remove docstrings from the extracted code
code_context_type: Type of code context to extract (READ_ONLY, READ_WRITABLE, or TESTGEN)
Returns:
CodeStringsMarkdown containing the extracted code context with necessary imports,
formatted for inclusion in markdown
"""
# Rearrange to remove overlaps, so we only access each file path once
helpers_of_helpers_no_overlap = defaultdict(set)
helpers_of_helpers_no_overlap_fqn = defaultdict(set)
@@ -155,7 +208,7 @@ def get_all_read_only_code_context(
helpers_of_helpers_no_overlap[file_path] = helpers_of_helpers[file_path]
helpers_of_helpers_no_overlap_fqn[file_path] = helpers_of_helpers_fqn[file_path]
read_only_code_markdown = CodeStringsMarkdown()
code_context_markdown = CodeStringsMarkdown()
# Extract code from file paths that contain fto and first degree helpers. helpers of helpers may also be included if they are in the same files
for file_path, qualified_function_names in helpers_of_fto.items():
try:
@@ -164,17 +217,18 @@ def get_all_read_only_code_context(
logger.exception(f"Error while parsing {file_path}: {e}")
continue
try:
read_only_code = get_read_only_code(
original_code, qualified_function_names, helpers_of_helpers.get(file_path, set()), remove_docstrings
code_context = parse_code_and_prune_cst(
original_code, code_context_type, qualified_function_names, helpers_of_helpers.get(file_path, set()), remove_docstrings
)
except ValueError as e:
logger.debug(f"Error while getting read-only code: {e}")
continue
if read_only_code.strip():
read_only_code_with_imports = CodeString(
if code_context.strip():
code_context_with_imports = CodeString(
code=add_needed_imports_from_module(
src_module_code=original_code,
dst_module_code=read_only_code,
dst_module_code=code_context,
src_path=file_path,
dst_path=file_path,
project_root=project_root_path,
@@ -182,7 +236,7 @@ def get_all_read_only_code_context(
),
file_path=file_path.relative_to(project_root_path),
)
read_only_code_markdown.code_strings.append(read_only_code_with_imports)
code_context_markdown.code_strings.append(code_context_with_imports)
# Extract code from file paths containing helpers of helpers
for file_path, qualified_helper_function_names in helpers_of_helpers_no_overlap.items():
@@ -192,18 +246,18 @@ def get_all_read_only_code_context(
logger.exception(f"Error while parsing {file_path}: {e}")
continue
try:
read_only_code = get_read_only_code(
original_code, set(), qualified_helper_function_names, remove_docstrings
code_context = parse_code_and_prune_cst(
original_code, code_context_type, set(), qualified_helper_function_names, remove_docstrings
)
except ValueError as e:
logger.debug(f"Error while getting read-only code: {e}")
continue
if read_only_code.strip():
read_only_code_with_imports = CodeString(
if code_context.strip():
code_context_with_imports = CodeString(
code=add_needed_imports_from_module(
src_module_code=original_code,
dst_module_code=read_only_code,
dst_module_code=code_context,
src_path=file_path,
dst_path=file_path,
project_root=project_root_path,
@@ -211,8 +265,8 @@ def get_all_read_only_code_context(
),
file_path=file_path.relative_to(project_root_path),
)
read_only_code_markdown.code_strings.append(read_only_code_with_imports)
return read_only_code_markdown
code_context_markdown.code_strings.append(code_context_with_imports)
return code_context_markdown
def get_file_path_to_helper_functions_dict(
@@ -221,11 +275,11 @@ def get_file_path_to_helper_functions_dict(
file_path_to_helper_function_qualified_names = defaultdict(set)
file_path_to_helper_function_fqn = defaultdict(set)
function_source_list: list[FunctionSource] = []
for file_path in file_path_to_qualified_function_names:
for file_path, qualified_function_names in file_path_to_qualified_function_names.items():
script = jedi.Script(path=file_path, project=jedi.Project(path=project_root_path))
file_refs = script.get_names(all_scopes=True, definitions=False, references=True)
for qualified_function_name in file_path_to_qualified_function_names[file_path]:
for qualified_function_name in qualified_function_names:
names = [
ref
for ref in file_refs
@@ -291,6 +345,29 @@ def remove_docstring_from_body(indented_block: cst.IndentedBlock) -> cst.CSTNode
return indented_block.with_changes(body=indented_block.body[1:])
return indented_block
def parse_code_and_prune_cst(
    code: str,
    code_context_type: CodeContextType,
    target_functions: set[str],
    helpers_of_helper_functions: set[str] | None = None,
    remove_docstrings: bool = False,
) -> str:
    """Parse ``code`` and prune its CST down to the requested context type.

    Dispatches to the pruning strategy matching ``code_context_type``
    (READ_WRITABLE, READ_ONLY, or TESTGEN) and renders the surviving
    module back to source text.

    Args:
        code: Source code of a whole module.
        code_context_type: Which pruning strategy to apply.
        target_functions: Qualified names of the functions being optimized.
        helpers_of_helper_functions: Qualified names of second-degree helper
            functions (unused for READ_WRITABLE). Defaults to an empty set.
        remove_docstrings: Whether to strip docstrings from the retained
            code (unused for READ_WRITABLE).

    Returns:
        The pruned module source, or "" if nothing remains after pruning.

    Raises:
        ValueError: If no target function is found in ``code``, or if
            ``code_context_type`` is not a known context type.
    """
    # Original default was `{}` — a mutable *dict* literal despite the set[str]
    # annotation, shared across calls. Use a None sentinel instead.
    if helpers_of_helper_functions is None:
        helpers_of_helper_functions = set()
    module = cst.parse_module(code)
    if code_context_type == CodeContextType.READ_WRITABLE:
        filtered_node, found_target = prune_cst_for_read_writable_code(module, target_functions)
    elif code_context_type == CodeContextType.READ_ONLY:
        filtered_node, found_target = prune_cst_for_read_only_code(
            module, target_functions, helpers_of_helper_functions, remove_docstrings=remove_docstrings
        )
    elif code_context_type == CodeContextType.TESTGEN:
        filtered_node, found_target = prune_cst_for_testgen_code(
            module, target_functions, helpers_of_helper_functions, remove_docstrings=remove_docstrings
        )
    else:
        raise ValueError(f"Unknown code_context_type: {code_context_type}")
    if not found_target:
        raise ValueError("No target functions found in the provided code")
    if filtered_node and isinstance(filtered_node, cst.Module):
        return str(filtered_node.code)
    return ""
def prune_cst_for_read_writable_code(
node: cst.CSTNode, target_functions: set[str], prefix: str = ""
@@ -371,20 +448,6 @@ def prune_cst_for_read_writable_code(
return (node.with_changes(**updates) if updates else node), True
def get_read_writable_code(code: str, target_functions: set[str]) -> str:
"""Creates a read-writable code string by parsing and filtering the code to keep only
target functions and the minimal surrounding structure.
"""
module = cst.parse_module(code)
filtered_node, found_target = prune_cst_for_read_writable_code(module, target_functions)
if not found_target:
raise ValueError("No target functions found in the provided code")
if filtered_node and isinstance(filtered_node, cst.Module):
return str(filtered_node.code)
return ""
def prune_cst_for_read_only_code(
node: cst.CSTNode,
target_functions: set[str],
@@ -489,18 +552,107 @@ def prune_cst_for_read_only_code(
return None, False
def get_read_only_code(
code: str, target_functions: set[str], helpers_of_helper_functions: set[str], remove_docstrings: bool = False
) -> str:
"""Creates a read-only version of the code by parsing and filtering the code to keep only
class contextual information, and other module scoped variables.
def prune_cst_for_testgen_code(
node: cst.CSTNode,
target_functions: set[str],
helpers_of_helper_functions: set[str],
prefix: str = "",
remove_docstrings: bool = False,
) -> tuple[cst.CSTNode | None, bool]:
"""Recursively filter the node for testgen context:
Returns:
(filtered_node, found_target):
filtered_node: The modified CST node or None if it should be removed.
found_target: True if a target function was found in this node's subtree.
"""
module = cst.parse_module(code)
filtered_node, found_target = prune_cst_for_read_only_code(
module, target_functions, helpers_of_helper_functions, remove_docstrings=remove_docstrings
)
if not found_target:
raise ValueError("No target functions found in the provided code")
if filtered_node and isinstance(filtered_node, cst.Module):
return str(filtered_node.code)
return ""
if isinstance(node, (cst.Import, cst.ImportFrom)):
return None, False
if isinstance(node, cst.FunctionDef):
qualified_name = f"{prefix}.{node.name.value}" if prefix else node.name.value
# If it's a target function, remove it but mark found_target = True
if qualified_name in helpers_of_helper_functions or qualified_name in target_functions:
if remove_docstrings and isinstance(node.body, cst.IndentedBlock):
new_body = remove_docstring_from_body(node.body)
return node.with_changes(body=new_body), True
return node, True
# Keep all dunder methods
if is_dunder_method(node.name.value):
if remove_docstrings and isinstance(node.body, cst.IndentedBlock):
new_body = remove_docstring_from_body(node.body)
return node.with_changes(body=new_body), False
return node, False
return None, False
if isinstance(node, cst.ClassDef):
# Do not recurse into nested classes
if prefix:
return None, False
# Assuming always an IndentedBlock
if not isinstance(node.body, cst.IndentedBlock):
raise ValueError("ClassDef body is not an IndentedBlock")
class_prefix = f"{prefix}.{node.name.value}" if prefix else node.name.value
# First pass: detect if there is a target function in the class
found_in_class = False
new_class_body: list[CSTNode] = []
for stmt in node.body.body:
filtered, found_target = prune_cst_for_testgen_code(
stmt, target_functions, helpers_of_helper_functions, class_prefix, remove_docstrings=remove_docstrings
)
found_in_class |= found_target
if filtered:
new_class_body.append(filtered)
if not found_in_class:
return None, False
if remove_docstrings:
return node.with_changes(
body=remove_docstring_from_body(node.body.with_changes(body=new_class_body))
) if new_class_body else None, True
return node.with_changes(body=node.body.with_changes(body=new_class_body)) if new_class_body else None, True
# For other nodes, keep the node and recursively filter children
section_names = get_section_names(node)
if not section_names:
return node, False
updates: dict[str, list[cst.CSTNode] | cst.CSTNode] = {}
found_any_target = False
for section in section_names:
original_content = getattr(node, section, None)
if isinstance(original_content, (list, tuple)):
new_children = []
section_found_target = False
for child in original_content:
filtered, found_target = prune_cst_for_testgen_code(
child, target_functions, helpers_of_helper_functions, prefix, remove_docstrings=remove_docstrings
)
if filtered:
new_children.append(filtered)
section_found_target |= found_target
if section_found_target or new_children:
found_any_target |= section_found_target
updates[section] = new_children
elif original_content is not None:
filtered, found_target = prune_cst_for_testgen_code(
original_content,
target_functions,
helpers_of_helper_functions,
prefix,
remove_docstrings=remove_docstrings,
)
found_any_target |= found_target
if filtered:
updates[section] = filtered
if updates:
return (node.with_changes(**updates), found_any_target)
return None, False

View file

@ -72,6 +72,7 @@ class CodeStringsMarkdown(BaseModel):
@property
def markdown(self) -> str:
"""Returns the markdown representation of the code, including the file path where possible."""
return "\n".join(
[
f"```python{':' + str(code_string.file_path) if code_string.file_path else ''}\n{code_string.code.strip()}\n```"
@ -81,12 +82,18 @@ class CodeStringsMarkdown(BaseModel):
class CodeOptimizationContext(BaseModel):
code_to_optimize_with_helpers: str
# code_to_optimize_with_helpers: str
testgen_context_code: str = ""
read_writable_code: str = Field(min_length=1)
read_only_context_code: str = ""
helper_functions: list[FunctionSource]
preexisting_objects: list[tuple[str, list[FunctionParent]]]
class CodeContextType(str, Enum):
READ_WRITABLE = "READ_WRITABLE"
READ_ONLY = "READ_ONLY"
TESTGEN = "TESTGEN"
class OptimizedCandidateResult(BaseModel):
max_loop_count: int

View file

@ -60,240 +60,240 @@ def belongs_to_function_qualified(name: Name, qualified_function_name: str) -> b
except ValueError:
return False
def get_type_annotation_context(
    function: FunctionToOptimize, jedi_script: jedi.Script, project_root_path: Path
) -> tuple[list[FunctionSource], set[tuple[str, str]]]:
    """Collect project-local definitions referenced by *function*'s type annotations.

    Walks the AST of the file containing *function*; for every name appearing in
    the target function's parameter or return annotations, resolves it with jedi
    and records its source when the definition lives inside the project (not in
    site-packages, and not inside the target function itself).

    Returns:
        (sources, contextual_dunder_methods): the resolved annotation sources,
        and the dunder methods gathered while extracting their code. Returns
        ([], set()) if the file fails to parse.
    """
    function_name: str = function.function_name
    file_path: Path = function.file_path
    file_contents: str = file_path.read_text(encoding="utf8")
    try:
        module: ast.Module = ast.parse(file_contents)
    except SyntaxError as e:
        logger.exception(f"get_type_annotation_context - Syntax error in code: {e}")
        return [], set()
    sources: list[FunctionSource] = []
    ast_parents: list[FunctionParent] = []
    contextual_dunder_methods: set[tuple[str, str]] = set()

    def get_annotation_source(
        j_script: jedi.Script, name: str, node_parents: list[FunctionParent], line_no: int, col_no: int
    ) -> None:
        # Resolve the name at (line_no, col_no) with jedi and, if it is a
        # project-local definition, append its source to the shared `sources`.
        try:
            definition: list[Name] = j_script.goto(
                line=line_no, column=col_no, follow_imports=True, follow_builtin_imports=False
            )
        except Exception as ex:
            # NOTE(review): `name` is a str here, so hasattr(name, "full_name")
            # is always False and the first branch looks unreachable — confirm.
            if hasattr(name, "full_name"):
                logger.exception(f"Error while getting definition for {name.full_name}: {ex}")
            else:
                logger.exception(f"Error while getting definition: {ex}")
            definition = []
        if definition:  # TODO can be multiple definitions
            definition_path = definition[0].module_path
            # The definition is part of this project and not defined within the original function
            if (
                str(definition_path).startswith(str(project_root_path) + os.sep)
                and definition[0].full_name
                and not path_belongs_to_site_packages(definition_path)
                and not belongs_to_function(definition[0], function_name)
            ):
                source_code = get_code([FunctionToOptimize(definition[0].name, definition_path, node_parents[:-1])])
                if source_code[0]:
                    sources.append(
                        FunctionSource(
                            fully_qualified_name=definition[0].full_name,
                            jedi_definition=definition[0],
                            source_code=source_code[0],
                            file_path=definition_path,
                            qualified_name=definition[0].full_name.removeprefix(definition[0].module_name + "."),
                            only_function_name=definition[0].name,
                        )
                    )
                    contextual_dunder_methods.update(source_code[1])

    def visit_children(
        node: ast.FunctionDef | ast.AsyncFunctionDef | ast.ClassDef | ast.Module, node_parents: list[FunctionParent]
    ) -> None:
        # Recurse into each direct AST child of *node*.
        child: ast.AST | ast.FunctionDef | ast.AsyncFunctionDef | ast.ClassDef | ast.Module
        for child in ast.iter_child_nodes(node):
            visit(child, node_parents)

    def visit_all_annotation_children(
        node: ast.Subscript | ast.Name | ast.BinOp, node_parents: list[FunctionParent]
    ) -> None:
        # Walk an annotation expression, resolving every bare name in it.
        # Handles PEP 604 unions (X | Y), plain names, and subscripted
        # generics (including tuple slices like dict[str, int]).
        if isinstance(node, ast.BinOp) and isinstance(node.op, ast.BitOr):
            visit_all_annotation_children(node.left, node_parents)
            visit_all_annotation_children(node.right, node_parents)
        if isinstance(node, ast.Name) and hasattr(node, "id"):
            name: str = node.id
            line_no: int = node.lineno
            col_no: int = node.col_offset
            get_annotation_source(jedi_script, name, node_parents, line_no, col_no)
        if isinstance(node, ast.Subscript):
            if hasattr(node, "slice"):
                if isinstance(node.slice, ast.Subscript):
                    visit_all_annotation_children(node.slice, node_parents)
                elif isinstance(node.slice, ast.Tuple):
                    for elt in node.slice.elts:
                        if isinstance(elt, (ast.Name, ast.Subscript)):
                            visit_all_annotation_children(elt, node_parents)
                elif isinstance(node.slice, ast.Name):
                    visit_all_annotation_children(node.slice, node_parents)
            if hasattr(node, "value"):
                visit_all_annotation_children(node.value, node_parents)

    def visit(
        node: ast.AST | ast.FunctionDef | ast.AsyncFunctionDef | ast.ClassDef | ast.Module,
        node_parents: list[FunctionParent],
    ) -> None:
        # Depth-first traversal; when the target function is reached, resolve
        # all of its parameter and return annotations. `node_parents` tracks
        # the enclosing class/function chain to match `function.parents`.
        if isinstance(node, (ast.Module, ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)):
            if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
                if node.name == function_name and node_parents == function.parents:
                    arg: ast.arg
                    for arg in node.args.args:
                        if arg.annotation:
                            visit_all_annotation_children(arg.annotation, node_parents)
                    if node.returns:
                        visit_all_annotation_children(node.returns, node_parents)
            if not isinstance(node, ast.Module):
                node_parents.append(FunctionParent(node.name, type(node).__name__))
            visit_children(node, node_parents)
            if not isinstance(node, ast.Module):
                node_parents.pop()

    visit(module, ast_parents)
    return sources, contextual_dunder_methods
#
# def get_type_annotation_context(
# function: FunctionToOptimize, jedi_script: jedi.Script, project_root_path: Path
# ) -> tuple[list[FunctionSource], set[tuple[str, str]]]:
# function_name: str = function.function_name
# file_path: Path = function.file_path
# file_contents: str = file_path.read_text(encoding="utf8")
# try:
# module: ast.Module = ast.parse(file_contents)
# except SyntaxError as e:
# logger.exception(f"get_type_annotation_context - Syntax error in code: {e}")
# return [], set()
# sources: list[FunctionSource] = []
# ast_parents: list[FunctionParent] = []
# contextual_dunder_methods = set()
#
# def get_annotation_source(
# j_script: jedi.Script, name: str, node_parents: list[FunctionParent], line_no: int, col_no: str
# ) -> None:
# try:
# definition: list[Name] = j_script.goto(
# line=line_no, column=col_no, follow_imports=True, follow_builtin_imports=False
# )
# except Exception as ex:
# if hasattr(name, "full_name"):
# logger.exception(f"Error while getting definition for {name.full_name}: {ex}")
# else:
# logger.exception(f"Error while getting definition: {ex}")
# definition = []
# if definition: # TODO can be multiple definitions
# definition_path = definition[0].module_path
#
# # The definition is part of this project and not defined within the original function
# if (
# str(definition_path).startswith(str(project_root_path) + os.sep)
# and definition[0].full_name
# and not path_belongs_to_site_packages(definition_path)
# and not belongs_to_function(definition[0], function_name)
# ):
# source_code = get_code([FunctionToOptimize(definition[0].name, definition_path, node_parents[:-1])])
# if source_code[0]:
# sources.append(
# FunctionSource(
# fully_qualified_name=definition[0].full_name,
# jedi_definition=definition[0],
# source_code=source_code[0],
# file_path=definition_path,
# qualified_name=definition[0].full_name.removeprefix(definition[0].module_name + "."),
# only_function_name=definition[0].name,
# )
# )
# contextual_dunder_methods.update(source_code[1])
#
# def visit_children(
# node: ast.FunctionDef | ast.AsyncFunctionDef | ast.ClassDef | ast.Module, node_parents: list[FunctionParent]
# ) -> None:
# child: ast.AST | ast.FunctionDef | ast.AsyncFunctionDef | ast.ClassDef | ast.Module
# for child in ast.iter_child_nodes(node):
# visit(child, node_parents)
#
# def visit_all_annotation_children(
# node: ast.Subscript | ast.Name | ast.BinOp, node_parents: list[FunctionParent]
# ) -> None:
# if isinstance(node, ast.BinOp) and isinstance(node.op, ast.BitOr):
# visit_all_annotation_children(node.left, node_parents)
# visit_all_annotation_children(node.right, node_parents)
# if isinstance(node, ast.Name) and hasattr(node, "id"):
# name: str = node.id
# line_no: int = node.lineno
# col_no: int = node.col_offset
# get_annotation_source(jedi_script, name, node_parents, line_no, col_no)
# if isinstance(node, ast.Subscript):
# if hasattr(node, "slice"):
# if isinstance(node.slice, ast.Subscript):
# visit_all_annotation_children(node.slice, node_parents)
# elif isinstance(node.slice, ast.Tuple):
# for elt in node.slice.elts:
# if isinstance(elt, (ast.Name, ast.Subscript)):
# visit_all_annotation_children(elt, node_parents)
# elif isinstance(node.slice, ast.Name):
# visit_all_annotation_children(node.slice, node_parents)
# if hasattr(node, "value"):
# visit_all_annotation_children(node.value, node_parents)
#
# def visit(
# node: ast.AST | ast.FunctionDef | ast.AsyncFunctionDef | ast.ClassDef | ast.Module,
# node_parents: list[FunctionParent],
# ) -> None:
# if isinstance(node, (ast.Module, ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)):
# if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
# if node.name == function_name and node_parents == function.parents:
# arg: ast.arg
# for arg in node.args.args:
# if arg.annotation:
# visit_all_annotation_children(arg.annotation, node_parents)
# if node.returns:
# visit_all_annotation_children(node.returns, node_parents)
#
# if not isinstance(node, ast.Module):
# node_parents.append(FunctionParent(node.name, type(node).__name__))
# visit_children(node, node_parents)
# if not isinstance(node, ast.Module):
# node_parents.pop()
#
# visit(module, ast_parents)
#
# return sources, contextual_dunder_methods
def get_function_variables_definitions(
    function_to_optimize: FunctionToOptimize, project_root_path: Path
) -> tuple[list[FunctionSource], set[tuple[str, str]]]:
    """Resolve the project-local definitions referenced inside *function_to_optimize*.

    Uses jedi to find every name referenced from the function's body (or from
    the method, when the function has a class parent), resolves each reference
    to its definition, and keeps only definitions that belong to this project.
    Annotation-derived sources are prepended, and the result is deduplicated so
    a method is dropped when its whole enclosing top-level symbol from the same
    file is already included.

    Returns:
        (sources, contextual_dunder_methods): the deduplicated helper sources
        (top-level symbols first, then class members) and the dunder methods
        collected while extracting their code.
    """
    function_name = function_to_optimize.function_name
    file_path = function_to_optimize.file_path
    script = jedi.Script(path=file_path, project=jedi.Project(path=project_root_path))
    sources: list[FunctionSource] = []
    contextual_dunder_methods: set[tuple[str, str]] = set()
    # TODO: The function name condition can be stricter so that it does not clash with other class names etc.
    # TODO: The function could have been imported as some other name,
    #  we should be checking for the translation as well. Also check for the original function name.
    names = []
    for ref in script.get_names(all_scopes=True, definitions=False, references=True):
        if ref.full_name:
            if function_to_optimize.parents:
                # Check if the reference belongs to the specified class when FunctionParent is provided
                if belongs_to_method(ref, function_to_optimize.parents[-1].name, function_name):
                    names.append(ref)
            elif belongs_to_function(ref, function_name):
                names.append(ref)
    for name in names:
        try:
            definitions: list[Name] = name.goto(follow_imports=True, follow_builtin_imports=False)
        except Exception as e:
            try:
                logger.exception(f"Error while getting definition for {name.full_name}: {e}")
            except Exception as e:
                # name.full_name can also throw exceptions sometimes
                logger.exception(f"Error while getting definition: {e}")
            definitions = []
        if definitions:
            # TODO: there can be multiple definitions, see how to handle such cases
            definition = definitions[0]
            definition_path = definition.module_path
            # The definition is part of this project and not defined within the original function
            if (
                str(definition_path).startswith(str(project_root_path) + os.sep)
                and not path_belongs_to_site_packages(definition_path)
                and definition.full_name
                and not belongs_to_function(definition, function_name)
            ):
                module_name = module_name_from_file_path(definition_path, project_root_path)
                # Recover the enclosing class (if any) from the fully qualified name,
                # e.g. "pkg.mod.MyClass.method" -> parent class "MyClass".
                m = re.match(rf"{module_name}\.(.*)\.{definitions[0].name}", definitions[0].full_name)
                parents = []
                if m:
                    parents = [FunctionParent(m.group(1), "ClassDef")]
                source_code = get_code(
                    [FunctionToOptimize(function_name=definitions[0].name, file_path=definition_path, parents=parents)]
                )
                if source_code[0]:
                    sources.append(
                        FunctionSource(
                            fully_qualified_name=definition.full_name,
                            jedi_definition=definition,
                            source_code=source_code[0],
                            file_path=definition_path,
                            qualified_name=definition.full_name.removeprefix(definition.module_name + "."),
                            only_function_name=definition.name,
                        )
                    )
                    contextual_dunder_methods.update(source_code[1])
    annotation_sources, annotation_dunder_methods = get_type_annotation_context(
        function_to_optimize, script, project_root_path
    )
    sources[:0] = annotation_sources  # prepend the annotation sources
    contextual_dunder_methods.update(annotation_dunder_methods)
    # Deduplicate by fully qualified name; split into top-level symbols
    # ("no parent") and class members ("parent") so members can be dropped
    # when their enclosing top-level symbol from the same file is present.
    existing_fully_qualified_names = set()
    no_parent_sources: dict[Path, dict[str, set[FunctionSource]]] = defaultdict(lambda: defaultdict(set))
    parent_sources = set()
    for source in sources:
        if (fully_qualified_name := source.fully_qualified_name) not in existing_fully_qualified_names:
            if not source.qualified_name.count("."):
                no_parent_sources[source.file_path][source.qualified_name].add(source)
            else:
                parent_sources.add(source)
            existing_fully_qualified_names.add(fully_qualified_name)
    deduped_parent_sources = [
        source
        for source in parent_sources
        if source.file_path not in no_parent_sources
        or source.qualified_name.rpartition(".")[0] not in no_parent_sources[source.file_path]
    ]
    deduped_no_parent_sources = [
        source for k1 in no_parent_sources for k2 in no_parent_sources[k1] for source in no_parent_sources[k1][k2]
    ]
    return deduped_no_parent_sources + deduped_parent_sources, contextual_dunder_methods
# Token budget for the optimization context. 128000 would match gpt-4-128k.
MAX_PROMPT_TOKENS = 4096  # 128000 # gpt-4-128k


def get_constrained_function_context_and_helper_functions(
    function_to_optimize: FunctionToOptimize,
    project_root_path: Path,
    code_to_optimize: str,
    max_tokens: int = MAX_PROMPT_TOKENS,
) -> tuple[str, list[FunctionSource], set[tuple[str, str]]]:
    """Assemble helper-function context for *function_to_optimize* under a token budget.

    Resolves the function's helper definitions, then greedily appends helper
    sources in order until adding the next one would exceed *max_tokens*
    (counted with the gpt-3.5-turbo tokenizer, starting from the size of
    *code_to_optimize* itself).

    Returns:
        (helper_code, helper_functions, dunder_methods): the concatenated
        helper sources that fit the budget, the full list of resolved helper
        functions (not truncated), and the collected dunder methods.
    """
    helper_functions, dunder_methods = get_function_variables_definitions(function_to_optimize, project_root_path)
    tokenizer = tiktoken.encoding_for_model("gpt-3.5-turbo")
    code_to_optimize_tokens = tokenizer.encode(code_to_optimize)
    if not function_to_optimize.parents:
        helper_functions_sources = [function.source_code for function in helper_functions]
    else:
        # For methods, skip helpers defined on the same class as the target:
        # their source is already part of the code being optimized.
        helper_functions_sources = [
            function.source_code
            for function in helper_functions
            if not function.qualified_name.count(".")
            or function.qualified_name.split(".")[0] != function_to_optimize.parents[0].name
        ]
    helper_functions_tokens = [len(tokenizer.encode(function)) for function in helper_functions_sources]
    context_list = []
    context_len = len(code_to_optimize_tokens)
    logger.debug(f"ORIGINAL CODE TOKENS LENGTH: {context_len}")
    logger.debug(f"ALL DEPENDENCIES TOKENS LENGTH: {sum(helper_functions_tokens)}")
    # Greedy fill: stop at the first helper that would overflow the budget
    # (later, smaller helpers are not considered).
    for function_source, source_len in zip(helper_functions_sources, helper_functions_tokens):
        if context_len + source_len <= max_tokens:
            context_list.append(function_source)
            context_len += source_len
        else:
            break
    logger.debug(f"FINAL OPTIMIZATION CONTEXT TOKENS LENGTH: {context_len}")
    helper_code: str = "\n".join(context_list)
    return helper_code, helper_functions, dunder_methods
# def get_function_variables_definitions(
# function_to_optimize: FunctionToOptimize, project_root_path: Path
# ) -> tuple[list[FunctionSource], set[tuple[str, str]]]:
# function_name = function_to_optimize.function_name
# file_path = function_to_optimize.file_path
# script = jedi.Script(path=file_path, project=jedi.Project(path=project_root_path))
# sources: list[FunctionSource] = []
# contextual_dunder_methods = set()
# # TODO: The function name condition can be stricter so that it does not clash with other class names etc.
# # TODO: The function could have been imported as some other name,
# # we should be checking for the translation as well. Also check for the original function name.
# names = []
# for ref in script.get_names(all_scopes=True, definitions=False, references=True):
# if ref.full_name:
# if function_to_optimize.parents:
# # Check if the reference belongs to the specified class when FunctionParent is provided
# if belongs_to_method(ref, function_to_optimize.parents[-1].name, function_name):
# names.append(ref)
# elif belongs_to_function(ref, function_name):
# names.append(ref)
#
# for name in names:
# try:
# definitions: list[Name] = name.goto(follow_imports=True, follow_builtin_imports=False)
# except Exception as e:
# try:
# logger.exception(f"Error while getting definition for {name.full_name}: {e}")
# except Exception as e:
# # name.full_name can also throw exceptions sometimes
# logger.exception(f"Error while getting definition: {e}")
# definitions = []
# if definitions:
# # TODO: there can be multiple definitions, see how to handle such cases
# definition = definitions[0]
# definition_path = definition.module_path
#
# # The definition is part of this project and not defined within the original function
# if (
# str(definition_path).startswith(str(project_root_path) + os.sep)
# and not path_belongs_to_site_packages(definition_path)
# and definition.full_name
# and not belongs_to_function(definition, function_name)
# ):
# module_name = module_name_from_file_path(definition_path, project_root_path)
# m = re.match(rf"{module_name}\.(.*)\.{definitions[0].name}", definitions[0].full_name)
# parents = []
# if m:
# parents = [FunctionParent(m.group(1), "ClassDef")]
#
# source_code = get_code(
# [FunctionToOptimize(function_name=definitions[0].name, file_path=definition_path, parents=parents)]
# )
# if source_code[0]:
# sources.append(
# FunctionSource(
# fully_qualified_name=definition.full_name,
# jedi_definition=definition,
# source_code=source_code[0],
# file_path=definition_path,
# qualified_name=definition.full_name.removeprefix(definition.module_name + "."),
# only_function_name=definition.name,
# )
# )
# contextual_dunder_methods.update(source_code[1])
# annotation_sources, annotation_dunder_methods = get_type_annotation_context(
# function_to_optimize, script, project_root_path
# )
# sources[:0] = annotation_sources # prepend the annotation sources
# contextual_dunder_methods.update(annotation_dunder_methods)
# existing_fully_qualified_names = set()
# no_parent_sources: dict[Path, dict[str, set[FunctionSource]]] = defaultdict(lambda: defaultdict(set))
# parent_sources = set()
# for source in sources:
# if (fully_qualified_name := source.fully_qualified_name) not in existing_fully_qualified_names:
# if not source.qualified_name.count("."):
# no_parent_sources[source.file_path][source.qualified_name].add(source)
# else:
# parent_sources.add(source)
# existing_fully_qualified_names.add(fully_qualified_name)
# deduped_parent_sources = [
# source
# for source in parent_sources
# if source.file_path not in no_parent_sources
# or source.qualified_name.rpartition(".")[0] not in no_parent_sources[source.file_path]
# ]
# deduped_no_parent_sources = [
# source for k1 in no_parent_sources for k2 in no_parent_sources[k1] for source in no_parent_sources[k1][k2]
# ]
# return deduped_no_parent_sources + deduped_parent_sources, contextual_dunder_methods
#
#
# MAX_PROMPT_TOKENS = 4096 # 128000 # gpt-4-128k
#
#
# def get_constrained_function_context_and_helper_functions(
# function_to_optimize: FunctionToOptimize,
# project_root_path: Path,
# code_to_optimize: str,
# max_tokens: int = MAX_PROMPT_TOKENS,
# ) -> tuple[str, list[FunctionSource], set[tuple[str, str]]]:
# helper_functions, dunder_methods = get_function_variables_definitions(function_to_optimize, project_root_path)
# tokenizer = tiktoken.encoding_for_model("gpt-3.5-turbo")
# code_to_optimize_tokens = tokenizer.encode(code_to_optimize)
#
# if not function_to_optimize.parents:
# helper_functions_sources = [function.source_code for function in helper_functions]
# else:
# helper_functions_sources = [
# function.source_code
# for function in helper_functions
# if not function.qualified_name.count(".")
# or function.qualified_name.split(".")[0] != function_to_optimize.parents[0].name
# ]
# helper_functions_tokens = [len(tokenizer.encode(function)) for function in helper_functions_sources]
#
# context_list = []
# context_len = len(code_to_optimize_tokens)
# logger.debug(f"ORIGINAL CODE TOKENS LENGTH: {context_len}")
# logger.debug(f"ALL DEPENDENCIES TOKENS LENGTH: {sum(helper_functions_tokens)}")
# for function_source, source_len in zip(helper_functions_sources, helper_functions_tokens):
# if context_len + source_len <= max_tokens:
# context_list.append(function_source)
# context_len += source_len
# else:
# break
# logger.debug(f"FINAL OPTIMIZATION CONTEXT TOKENS LENGTH: {context_len}")
# helper_code: str = "\n".join(context_list)
# return helper_code, helper_functions, dunder_methods

View file

@ -58,7 +58,7 @@ from codeflash.models.models import (
TestFiles,
TestingMode,
)
from codeflash.optimization.function_context import get_constrained_function_context_and_helper_functions
# from codeflash.optimization.function_context import get_constrained_function_context_and_helper_functions
from codeflash.result.create_pr import check_create_pr, existing_tests_source_for
from codeflash.result.critic import coverage_critic, performance_gain, quantity_of_tests_critic, speedup_critic
from codeflash.result.explanation import Explanation
@ -140,14 +140,14 @@ class FunctionOptimizer:
logger.info("Code to be optimized:")
code_print(code_context.read_writable_code)
for module_abspath, helper_code_source in original_helper_code.items():
code_context.code_to_optimize_with_helpers = add_needed_imports_from_module(
helper_code_source,
code_context.code_to_optimize_with_helpers,
module_abspath,
self.function_to_optimize.file_path,
self.args.project_root,
)
# for module_abspath, helper_code_source in original_helper_code.items():
# code_context.code_to_optimize_with_helpers = add_needed_imports_from_module(
# helper_code_source,
# code_context.code_to_optimize_with_helpers,
# module_abspath,
# self.function_to_optimize.file_path,
# self.args.project_root,
# )
generated_test_paths = [
get_test_file_path(
@ -167,7 +167,7 @@ class FunctionOptimizer:
transient=True,
):
generated_results = self.generate_tests_and_optimizations(
code_to_optimize_with_helpers=code_context.code_to_optimize_with_helpers,
testgen_context_code=code_context.testgen_context_code,
read_writable_code=code_context.read_writable_code,
read_only_context_code=code_context.read_only_context_code,
helper_functions=code_context.helper_functions,
@ -556,49 +556,49 @@ class FunctionOptimizer:
return did_update
def get_code_optimization_context(self) -> Result[CodeOptimizationContext, str]:
code_to_optimize, contextual_dunder_methods = extract_code([self.function_to_optimize])
if code_to_optimize is None:
return Failure("Could not find function to optimize.")
(helper_code, helper_functions, helper_dunder_methods) = get_constrained_function_context_and_helper_functions(
self.function_to_optimize, self.project_root, code_to_optimize
)
if self.function_to_optimize.parents:
function_class = self.function_to_optimize.parents[0].name
same_class_helper_methods = [
df
for df in helper_functions
if df.qualified_name.count(".") > 0 and df.qualified_name.split(".")[0] == function_class
]
optimizable_methods = [
FunctionToOptimize(
df.qualified_name.split(".")[-1],
df.file_path,
[FunctionParent(df.qualified_name.split(".")[0], "ClassDef")],
None,
None,
)
for df in same_class_helper_methods
] + [self.function_to_optimize]
dedup_optimizable_methods = []
added_methods = set()
for method in reversed(optimizable_methods):
if f"{method.file_path}.{method.qualified_name}" not in added_methods:
dedup_optimizable_methods.append(method)
added_methods.add(f"{method.file_path}.{method.qualified_name}")
if len(dedup_optimizable_methods) > 1:
code_to_optimize, contextual_dunder_methods = extract_code(list(reversed(dedup_optimizable_methods)))
if code_to_optimize is None:
return Failure("Could not find function to optimize.")
code_to_optimize_with_helpers = helper_code + "\n" + code_to_optimize
code_to_optimize_with_helpers_and_imports = add_needed_imports_from_module(
self.function_to_optimize_source_code,
code_to_optimize_with_helpers,
self.function_to_optimize.file_path,
self.function_to_optimize.file_path,
self.project_root,
helper_functions,
)
# code_to_optimize, contextual_dunder_methods = extract_code([self.function_to_optimize])
# if code_to_optimize is None:
# return Failure("Could not find function to optimize.")
# (helper_code, helper_functions, helper_dunder_methods) = get_constrained_function_context_and_helper_functions(
# self.function_to_optimize, self.project_root, code_to_optimize
# )
# if self.function_to_optimize.parents:
# function_class = self.function_to_optimize.parents[0].name
# same_class_helper_methods = [
# df
# for df in helper_functions
# if df.qualified_name.count(".") > 0 and df.qualified_name.split(".")[0] == function_class
# ]
# optimizable_methods = [
# FunctionToOptimize(
# df.qualified_name.split(".")[-1],
# df.file_path,
# [FunctionParent(df.qualified_name.split(".")[0], "ClassDef")],
# None,
# None,
# )
# for df in same_class_helper_methods
# ] + [self.function_to_optimize]
# dedup_optimizable_methods = []
# added_methods = set()
# for method in reversed(optimizable_methods):
# if f"{method.file_path}.{method.qualified_name}" not in added_methods:
# dedup_optimizable_methods.append(method)
# added_methods.add(f"{method.file_path}.{method.qualified_name}")
# if len(dedup_optimizable_methods) > 1:
# code_to_optimize, contextual_dunder_methods = extract_code(list(reversed(dedup_optimizable_methods)))
# if code_to_optimize is None:
# return Failure("Could not find function to optimize.")
# code_to_optimize_with_helpers = helper_code + "\n" + code_to_optimize
#
# code_to_optimize_with_helpers_and_imports = add_needed_imports_from_module(
# self.function_to_optimize_source_code,
# code_to_optimize_with_helpers,
# self.function_to_optimize.file_path,
# self.function_to_optimize.file_path,
# self.project_root,
# helper_functions,
# )
try:
new_code_ctx = code_context_extractor.get_code_optimization_context(
@ -609,7 +609,8 @@ class FunctionOptimizer:
return Success(
CodeOptimizationContext(
code_to_optimize_with_helpers=code_to_optimize_with_helpers_and_imports,
# code_to_optimize_with_helpers=new_code_ctx.testgen_context_code, # Outdated, fix this!
testgen_context_code=new_code_ctx.testgen_context_code,
read_writable_code=new_code_ctx.read_writable_code,
read_only_context_code=new_code_ctx.read_only_context_code,
helper_functions=new_code_ctx.helper_functions, # only functions that are read writable
@ -711,7 +712,7 @@ class FunctionOptimizer:
def generate_tests_and_optimizations(
self,
code_to_optimize_with_helpers: str,
testgen_context_code: str,
read_writable_code: str,
read_only_context_code: str,
helper_functions: list[FunctionSource],
@ -726,7 +727,7 @@ class FunctionOptimizer:
# Submit the test generation task as future
future_tests = self.generate_and_instrument_tests(
executor,
code_to_optimize_with_helpers,
testgen_context_code,
[definition.fully_qualified_name for definition in helper_functions],
generated_test_paths,
generated_perf_test_paths,

View file

@ -740,7 +740,7 @@ class HelperClass:
ending_line=None,
)
code_ctx = get_code_optimization_context(function_to_optimize, opt.args.project_root)
code_ctx = get_code_optimization_context(function_to_optimize, opt.args.project_root, 8000, 100000)
read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code
# In this scenario, the read-only code context is too long even after removing docstrings, hence we remove it completely.
expected_read_write_context = """
@ -813,6 +813,57 @@ class HelperClass:
with pytest.raises(ValueError, match="Read-writable code has exceeded token limit, cannot proceed"):
code_ctx = get_code_optimization_context(function_to_optimize, opt.args.project_root)
def test_example_class_token_limit_4() -> None:
    """An oversized testgen context must make the extractor abort with ValueError.

    The target method embeds a very long string literal so that the testgen
    context (union of read-writable and read-only code) exceeds the token
    limit; get_code_optimization_context is expected to raise rather than
    emit a truncated context.
    """
    string_filler = " ".join(
        ["This is a long string that will be used to fill up the token limit." for _ in range(1000)]
    )
    code = f"""
class MyClass:
    \"\"\"A class with a helper method. \"\"\"
    def __init__(self):
        self.x = 1
    def target_method(self):
        \"\"\"Docstring for target method\"\"\"
        y = HelperClass().helper_method()
        x = '{string_filler}'
class HelperClass:
    \"\"\"A helper class for MyClass.\"\"\"
    def __init__(self):
        \"\"\"Initialize the HelperClass.\"\"\"
        self.x = 1
    def __repr__(self):
        \"\"\"Return a string representation of the HelperClass.\"\"\"
        return "HelperClass" + str(self.x)
    def helper_method(self):
        return self.x
"""
    with tempfile.NamedTemporaryFile(mode="w") as f:
        f.write(code)
        f.flush()
        file_path = Path(f.name).resolve()
        opt = Optimizer(
            Namespace(
                project_root=file_path.parent.resolve(),
                disable_telemetry=True,
                tests_root="tests",
                test_framework="pytest",
                pytest_cmd="pytest",
                experiment_id=None,
                test_project_root=Path().resolve(),
            )
        )
        function_to_optimize = FunctionToOptimize(
            function_name="target_method",
            file_path=file_path,
            parents=[FunctionParent(name="MyClass", type="ClassDef")],
            starting_line=None,
            ending_line=None,
        )
        # In this scenario, the testgen code context is too long, so we abort.
        with pytest.raises(ValueError, match="Testgen code context has exceeded token limit, cannot proceed"):
            code_ctx = get_code_optimization_context(function_to_optimize, opt.args.project_root)
def test_repo_helper() -> None:
project_root = Path(__file__).resolve().parent.parent / "code_to_optimize" / "code_directories" / "retriever"

View file

@ -747,24 +747,28 @@ class MainClass:
def test_code_replacement10() -> None:
get_code_output = """from __future__ import annotations
get_code_output = """```python:test_code_replacement.py
from __future__ import annotations
import os
os.environ["CODEFLASH_API_KEY"] = "cf-test-key"
class HelperClass:
def __init__(self, name):
self.name = name
def innocent_bystander(self):
pass
def helper_method(self):
return self.name
class MainClass:
def __init__(self, name):
self.name = name
def main_method(self):
return HelperClass(self.name).helper_method()
"""
```"""
file_path = Path(__file__).resolve()
func_top_optimize = FunctionToOptimize(
function_name="main_method", file_path=file_path, parents=[FunctionParent("MainClass", "ClassDef")]
@ -778,7 +782,7 @@ class MainClass:
)
func_optimizer = FunctionOptimizer(function_to_optimize=func_top_optimize, test_cfg=test_config)
code_context = func_optimizer.get_code_optimization_context().unwrap()
assert code_context.code_to_optimize_with_helpers == get_code_output
assert code_context.testgen_context_code == get_code_output
def test_code_replacement11() -> None:

View file

@ -5,7 +5,7 @@ import pytest
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.either import is_successful
from codeflash.models.models import FunctionParent
from codeflash.optimization.function_context import get_function_variables_definitions
# from codeflash.optimization.function_context import get_function_variables_definitions
from codeflash.optimization.function_optimizer import FunctionOptimizer
from codeflash.verification.verification_utils import TestConfig
@ -18,15 +18,6 @@ def simple_function_with_one_dep(data):
return calculate_something(data)
def test_simple_dependencies() -> None:
file_path = pathlib.Path(__file__).resolve()
helper_functions = get_function_variables_definitions(
FunctionToOptimize("simple_function_with_one_dep", str(file_path), []), str(file_path.parent.resolve())
)[0]
assert len(helper_functions) == 1
assert helper_functions[0].jedi_definition.full_name == "test_function_dependencies.calculate_something"
def global_dependency_1(num):
return num + 1
@ -93,63 +84,12 @@ class C:
return self.recursive(num) + num_1
def test_multiple_classes_dependencies() -> None:
file_path = pathlib.Path(__file__).resolve()
helper_functions = get_function_variables_definitions(
FunctionToOptimize("run", str(file_path), [FunctionParent("C", "ClassDef")]), str(file_path.parent.resolve())
)
assert len(helper_functions) == 2
assert list(map(lambda x: x.fully_qualified_name, helper_functions[0])) == [
"test_function_dependencies.global_dependency_3",
"test_function_dependencies.C.calculate_something_3",
]
def recursive_dependency_1(num):
if num == 0:
return 0
num_1 = calculate_something(num)
return recursive_dependency_1(num) + num_1
def test_recursive_dependency() -> None:
file_path = pathlib.Path(__file__).resolve()
helper_functions = get_function_variables_definitions(
FunctionToOptimize("recursive_dependency_1", str(file_path), []), str(file_path.parent.resolve())
)[0]
assert len(helper_functions) == 1
assert helper_functions[0].jedi_definition.full_name == "test_function_dependencies.calculate_something"
assert helper_functions[0].fully_qualified_name == "test_function_dependencies.calculate_something"
@dataclass
class MyData:
MyInt: int
def calculate_something_ann(data):
return data + 1
def simple_function_with_one_dep_ann(data: MyData):
return calculate_something_ann(data)
def list_comprehension_dependency(data: MyData):
return [calculate_something(data) for x in range(10)]
def test_simple_dependencies_ann() -> None:
file_path = pathlib.Path(__file__).resolve()
helper_functions = get_function_variables_definitions(
FunctionToOptimize("simple_function_with_one_dep_ann", str(file_path), []), str(file_path.parent.resolve())
)[0]
assert len(helper_functions) == 2
assert helper_functions[0].jedi_definition.full_name == "test_function_dependencies.MyData"
assert helper_functions[1].jedi_definition.full_name == "test_function_dependencies.calculate_something_ann"
from collections import defaultdict
@ -220,13 +160,15 @@ def test_class_method_dependencies() -> None:
)
assert code_context.helper_functions[0].qualified_name == "Graph.topologicalSortUtil"
assert (
code_context.code_to_optimize_with_helpers
== """from collections import defaultdict
code_context.testgen_context_code
== """```python:test_function_dependencies.py
from collections import defaultdict
class Graph:
def __init__(self, vertices):
self.graph = defaultdict(list)
self.V = vertices # No. of vertices
def topologicalSortUtil(self, v, visited, stack):
visited[v] = True
@ -235,6 +177,7 @@ class Graph:
self.topologicalSortUtil(i, visited, stack)
stack.insert(0, v)
def topologicalSort(self):
visited = [False] * self.V
stack = []
@ -245,39 +188,9 @@ class Graph:
# Print contents of stack
return stack
"""
```"""
)
def calculate_something_else(data):
return data + 1
def imalittledecorator(func):
def wrapper(data):
return func(data)
return wrapper
@imalittledecorator
def simple_function_with_decorator_dep(data):
return calculate_something_else(data)
@pytest.mark.skip(reason="no decorator dependency support")
def test_decorator_dependencies() -> None:
file_path = pathlib.Path(__file__).resolve()
helper_functions = get_function_variables_definitions(
FunctionToOptimize("simple_function_with_decorator_dep", str(file_path), []), str(file_path.parent.resolve())
)[0]
assert len(helper_functions) == 2
assert {helper_functions[0][0].definition.full_name, helper_functions[1][0].definition.full_name} == {
"test_function_dependencies.calculate_something",
"test_function_dependencies.imalittledecorator",
}
def test_recursive_function_context() -> None:
file_path = pathlib.Path(__file__).resolve()
@ -309,73 +222,16 @@ def test_recursive_function_context() -> None:
assert code_context.helper_functions[0].fully_qualified_name == "test_function_dependencies.C.calculate_something_3"
assert code_context.helper_functions[1].fully_qualified_name == "test_function_dependencies.C.recursive"
assert (
code_context.code_to_optimize_with_helpers
== """class C:
code_context.testgen_context_code
== """```python:test_function_dependencies.py
class C:
def calculate_something_3(self, num):
return num + 1
def recursive(self, num):
if num == 0:
return 0
num_1 = self.calculate_something_3(num)
return self.recursive(num) + num_1
"""
)
def test_list_comprehension_dependency() -> None:
file_path = pathlib.Path(__file__).resolve()
helper_functions = get_function_variables_definitions(
FunctionToOptimize("list_comprehension_dependency", str(file_path), []), str(file_path.parent.resolve())
)[0]
assert len(helper_functions) == 2
assert helper_functions[0].jedi_definition.full_name == "test_function_dependencies.MyData"
assert helper_functions[1].jedi_definition.full_name == "test_function_dependencies.calculate_something"
def test_function_in_method_list_comprehension() -> None:
file_path = pathlib.Path(__file__).resolve()
function_to_optimize = FunctionToOptimize(
function_name="function_in_list_comprehension",
file_path=str(file_path),
parents=[FunctionParent(name="A", type="ClassDef")],
starting_line=None,
ending_line=None,
)
helper_functions = get_function_variables_definitions(function_to_optimize, str(file_path.parent.resolve()))[0]
assert len(helper_functions) == 1
assert helper_functions[0].jedi_definition.full_name == "test_function_dependencies.global_dependency_3"
def test_method_in_method_list_comprehension() -> None:
file_path = pathlib.Path(__file__).resolve()
function_to_optimize = FunctionToOptimize(
function_name="method_in_list_comprehension",
file_path=str(file_path),
parents=[FunctionParent(name="A", type="ClassDef")],
starting_line=None,
ending_line=None,
)
helper_functions = get_function_variables_definitions(function_to_optimize, str(file_path.parent.resolve()))[0]
assert len(helper_functions) == 1
assert helper_functions[0].jedi_definition.full_name == "test_function_dependencies.A.add_two"
def test_nested_method() -> None:
file_path = pathlib.Path(__file__).resolve()
function_to_optimize = FunctionToOptimize(
function_name="nested_function",
file_path=str(file_path),
parents=[FunctionParent(name="A", type="ClassDef")],
starting_line=None,
ending_line=None,
)
helper_functions = get_function_variables_definitions(function_to_optimize, str(file_path.parent.resolve()))[0]
# The nested function should be included in the helper functions
assert len(helper_functions) == 1
assert helper_functions[0].jedi_definition.full_name == "test_function_dependencies.A.add_two"
```"""
)

View file

@ -239,13 +239,37 @@ class _PersistentCache(Generic[_P, _R, _CacheBackendT]):
pytest.fail()
code_context = ctx_result.unwrap()
assert code_context.helper_functions[0].qualified_name == "AbstractCacheBackend.get_cache_or_call"
assert (
code_context.code_to_optimize_with_helpers
== '''_R = TypeVar("_R")
code_context.testgen_context_code
== f'''```python:{file_path.name}
_P = ParamSpec("_P")
_KEY_T = TypeVar("_KEY_T")
_STORE_T = TypeVar("_STORE_T")
class AbstractCacheBackend(CacheBackend, Protocol[_KEY_T, _STORE_T]):
"""Interface for cache backends used by the persistent cache decorator."""
def __init__(self) -> None: ...
def hash_key(
self,
*,
func: Callable[_P, Any],
args: tuple[Any, ...],
kwargs: dict[str, Any],
) -> tuple[str, _KEY_T]: ...
def encode(self, *, data: Any) -> _STORE_T: # noqa: ANN401
...
def decode(self, *, data: _STORE_T) -> Any: # noqa: ANN401
...
def get(self, *, key: tuple[str, _KEY_T]) -> tuple[datetime.datetime, _STORE_T] | None: ...
def delete(self, *, key: tuple[str, _KEY_T]) -> None: ...
def put(self, *, key: tuple[str, _KEY_T], data: _STORE_T) -> None: ...
def get_cache_or_call(
self,
*,
@ -300,7 +324,33 @@ class AbstractCacheBackend(CacheBackend, Protocol[_KEY_T, _STORE_T]):
# If encoding fails, we should still return the result.
return result
_P = ParamSpec("_P")
_R = TypeVar("_R")
_CacheBackendT = TypeVar("_CacheBackendT", bound=CacheBackend)
class _PersistentCache(Generic[_P, _R, _CacheBackendT]):
"""
A decorator class that provides persistent caching functionality for a function.
Args:
----
func (Callable[_P, _R]): The function to be decorated.
duration (datetime.timedelta): The duration for which the cached results should be considered valid.
backend (_backend): The backend storage for the cached results.
Attributes:
----------
__wrapped__ (Callable[_P, _R]): The wrapped function.
__duration__ (datetime.timedelta): The duration for which the cached results should be considered valid.
__backend__ (_backend): The backend storage for the cached results.
""" # noqa: E501
__wrapped__: Callable[_P, _R]
__duration__: datetime.timedelta
__backend__: _CacheBackendT
def __init__(
self,
func: Callable[_P, _R],
@ -310,6 +360,7 @@ class _PersistentCache(Generic[_P, _R, _CacheBackendT]):
self.__duration__ = duration
self.__backend__ = AbstractCacheBackend()
functools.update_wrapper(self, func)
def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _R:
"""
Calls the wrapped function, either using the cache or bypassing it based on environment variables.
@ -333,7 +384,7 @@ class _PersistentCache(Generic[_P, _R, _CacheBackendT]):
kwargs=kwargs,
lifespan=self.__duration__,
)
'''
```'''
)
@ -358,14 +409,20 @@ def test_bubble_sort_deps() -> None:
pytest.fail()
code_context = ctx_result.unwrap()
assert (
code_context.code_to_optimize_with_helpers
== """def dep1_comparer(arr, j: int) -> bool:
code_context.testgen_context_code
== """```python:code_to_optimize/bubble_sort_dep1_helper.py
def dep1_comparer(arr, j: int) -> bool:
return arr[j] > arr[j + 1]
```
```python:code_to_optimize/bubble_sort_dep2_swap.py
def dep2_swap(arr, j):
temp = arr[j]
arr[j] = arr[j + 1]
arr[j + 1] = temp
```
```python:code_to_optimize/bubble_sort_deps.py
from code_to_optimize.bubble_sort_dep1_helper import dep1_comparer
from code_to_optimize.bubble_sort_dep2_swap import dep2_swap
def sorter_deps(arr):
for i in range(len(arr)):
@ -373,7 +430,7 @@ def sorter_deps(arr):
if dep1_comparer(arr, j):
dep2_swap(arr, j)
return arr
"""
```"""
)
assert len(code_context.helper_functions) == 2
assert (

View file

@ -2,7 +2,8 @@ from textwrap import dedent
import pytest
from codeflash.context.code_context_extractor import get_read_only_code
from codeflash.context.code_context_extractor import parse_code_and_prune_cst
from codeflash.models.models import CodeContextType
def test_basic_class() -> None:
@ -22,7 +23,7 @@ def test_basic_class() -> None:
class_var = "value"
"""
output = get_read_only_code(dedent(code), {"TestClass.target_method"}, set())
output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_ONLY, {"TestClass.target_method"}, set())
assert dedent(expected).strip() == output.strip()
@ -46,7 +47,7 @@ def test_dunder_methods() -> None:
return f"Value: {self.x}"
"""
output = get_read_only_code(dedent(code), {"TestClass.target_method"}, set())
output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_ONLY, {"TestClass.target_method"}, set())
assert dedent(expected).strip() == output.strip()
@ -72,7 +73,7 @@ def test_dunder_methods_remove_docstring() -> None:
return f"Value: {self.x}"
"""
output = get_read_only_code(dedent(code), {"TestClass.target_method"}, set(), remove_docstrings=True)
output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_ONLY, {"TestClass.target_method"}, set(), remove_docstrings=True)
assert dedent(expected).strip() == output.strip()
@ -97,7 +98,7 @@ def test_class_remove_docstring() -> None:
return f"Value: {self.x}"
"""
output = get_read_only_code(dedent(code), {"TestClass.target_method"}, set(), remove_docstrings=True)
output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_ONLY, {"TestClass.target_method"}, set(), remove_docstrings=True)
assert dedent(expected).strip() == output.strip()
@ -124,7 +125,7 @@ def test_mixed_remove_docstring() -> None:
return f"Value: {self.x}"
"""
output = get_read_only_code(dedent(code), {"TestClass.target_method"}, set(), remove_docstrings=True)
output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_ONLY, {"TestClass.target_method"}, set(), remove_docstrings=True)
assert dedent(expected).strip() == output.strip()
@ -142,7 +143,7 @@ def test_target_in_nested_class() -> None:
"""
with pytest.raises(ValueError, match="No target functions found in the provided code"):
get_read_only_code(dedent(code), {"Outer.Inner.target_method"}, set())
parse_code_and_prune_cst(dedent(code), CodeContextType.READ_ONLY, {"Outer.Inner.target_method"}, set())
def test_docstrings() -> None:
@ -164,7 +165,7 @@ def test_docstrings() -> None:
\"\"\"Class docstring.\"\"\"
"""
output = get_read_only_code(dedent(code), {"TestClass.target_method"}, set())
output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_ONLY, {"TestClass.target_method"}, set())
assert dedent(expected).strip() == output.strip()
@ -183,7 +184,7 @@ def test_method_signatures() -> None:
expected = """"""
output = get_read_only_code(dedent(code), {"TestClass.target_method"}, set())
output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_ONLY, {"TestClass.target_method"}, set())
assert dedent(expected).strip() == output.strip()
@ -203,7 +204,7 @@ def test_multiple_top_level_targets() -> None:
expected = """
"""
output = get_read_only_code(dedent(code), {"TestClass.target1", "TestClass.target2"}, set())
output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_ONLY, {"TestClass.target1", "TestClass.target2"}, set())
assert dedent(expected).strip() == output.strip()
@ -223,7 +224,7 @@ def test_class_annotations() -> None:
var2: str
"""
output = get_read_only_code(dedent(code), {"TestClass.target_method"}, set())
output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_ONLY, {"TestClass.target_method"}, set())
assert dedent(expected).strip() == output.strip()
@ -245,7 +246,7 @@ def test_class_annotations_if() -> None:
var2: str
"""
output = get_read_only_code(dedent(code), {"TestClass.target_method"}, set())
output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_ONLY, {"TestClass.target_method"}, set())
assert dedent(expected).strip() == output.strip()
@ -271,7 +272,7 @@ def test_class_annotations_try() -> None:
continue
"""
output = get_read_only_code(dedent(code), {"TestClass.target_method"}, set())
output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_ONLY, {"TestClass.target_method"}, set())
assert dedent(expected).strip() == output.strip()
@ -307,7 +308,7 @@ def test_class_annotations_else() -> None:
var2: str
"""
output = get_read_only_code(dedent(code), {"TestClass.target_method"}, set())
output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_ONLY, {"TestClass.target_method"}, set())
assert dedent(expected).strip() == output.strip()
@ -322,7 +323,7 @@ def test_top_level_functions() -> None:
expected = """"""
output = get_read_only_code(dedent(code), {"target_function"}, set())
output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_ONLY, {"target_function"}, set())
assert dedent(expected).strip() == output.strip()
@ -341,7 +342,7 @@ def test_module_var() -> None:
x = 5
"""
output = get_read_only_code(dedent(code), {"target_function"}, set())
output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_ONLY, {"target_function"}, set())
assert dedent(expected).strip() == output.strip()
@ -352,7 +353,7 @@ def test_module_var_if() -> None:
if y:
x = 5
else:
else:
z = 10
def some_function():
print("wow")
@ -364,11 +365,11 @@ def test_module_var_if() -> None:
expected = """
if y:
x = 5
else:
else:
z = 10
"""
output = get_read_only_code(dedent(code), {"target_function"}, set())
output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_ONLY, {"target_function"}, set())
assert dedent(expected).strip() == output.strip()
@ -403,7 +404,7 @@ def test_conditional_class_definitions() -> None:
platform = "other"
"""
output = get_read_only_code(dedent(code), {"PlatformClass.target_method"}, set())
output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_ONLY, {"PlatformClass.target_method"}, set())
assert dedent(expected).strip() == output.strip()
@ -462,7 +463,7 @@ def test_multiple_except_clauses() -> None:
error_type = "cleanup"
"""
output = get_read_only_code(dedent(code), {"TestClass.target_method"}, set())
output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_ONLY, {"TestClass.target_method"}, set())
assert dedent(expected).strip() == output.strip()
@ -515,7 +516,7 @@ def test_with_statement_and_loops() -> None:
context = "cleanup"
"""
output = get_read_only_code(dedent(code), {"TestClass.target_method"}, set())
output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_ONLY, {"TestClass.target_method"}, set())
assert dedent(expected).strip() == output.strip()
@ -564,7 +565,7 @@ def test_async_with_try_except() -> None:
status = "cancelled"
"""
output = get_read_only_code(dedent(code), {"TestClass.target_method"}, set())
output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_ONLY, {"TestClass.target_method"}, set())
assert dedent(expected).strip() == output.strip()
@ -664,7 +665,7 @@ def test_simplified_complete_implementation() -> None:
pass
"""
output = get_read_only_code(dedent(code), {"DataProcessor.target_method", "ResultHandler.target_method"}, set())
output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_ONLY, {"DataProcessor.target_method", "ResultHandler.target_method"}, set())
assert dedent(expected).strip() == output.strip()
@ -751,7 +752,7 @@ def test_simplified_complete_implementation_no_docstring() -> None:
pass
"""
output = get_read_only_code(
dedent(code), {"DataProcessor.target_method", "ResultHandler.target_method"}, set(), remove_docstrings=True
output = parse_code_and_prune_cst(
dedent(code), CodeContextType.READ_ONLY, {"DataProcessor.target_method", "ResultHandler.target_method"}, set(), remove_docstrings=True
)
assert dedent(expected).strip() == output.strip()

View file

@ -1,7 +1,8 @@
from textwrap import dedent
import pytest
from codeflash.context.code_context_extractor import get_read_writable_code
from codeflash.context.code_context_extractor import parse_code_and_prune_cst
from codeflash.models.models import CodeContextType
def test_simple_function() -> None:
@ -11,7 +12,7 @@ def test_simple_function() -> None:
y = 2
return x + y
"""
result = get_read_writable_code(dedent(code), {"target_function"})
result = parse_code_and_prune_cst(dedent(code),CodeContextType.READ_WRITABLE, {"target_function"})
expected = dedent("""
def target_function():
@ -30,7 +31,7 @@ def test_class_method() -> None:
y = 2
return x + y
"""
result = get_read_writable_code(dedent(code), {"MyClass.target_function"})
result = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_WRITABLE, {"MyClass.target_function"})
expected = dedent("""
class MyClass:
@ -54,7 +55,7 @@ def test_class_with_attributes() -> None:
def other_method(self):
print("this should be excluded")
"""
result = get_read_writable_code(dedent(code), {"MyClass.target_method"})
result = parse_code_and_prune_cst(dedent(code),CodeContextType.READ_WRITABLE, {"MyClass.target_method"})
expected = dedent("""
class MyClass:
@ -78,7 +79,7 @@ def test_basic_class_structure() -> None:
def not_findable(self):
return 42
"""
result = get_read_writable_code(dedent(code), {"Outer.target_method"})
result = parse_code_and_prune_cst(dedent(code),CodeContextType.READ_WRITABLE, {"Outer.target_method"})
expected = dedent("""
class Outer:
@ -98,7 +99,7 @@ def test_top_level_targets() -> None:
def target_function():
return 42
"""
result = get_read_writable_code(dedent(code), {"target_function"})
result = parse_code_and_prune_cst(dedent(code),CodeContextType.READ_WRITABLE, {"target_function"})
expected = dedent("""
def target_function():
@ -121,7 +122,7 @@ def test_multiple_top_level_classes() -> None:
def process(self):
return "C"
"""
result = get_read_writable_code(dedent(code), {"ClassA.process", "ClassC.process"})
result = parse_code_and_prune_cst(dedent(code),CodeContextType.READ_WRITABLE, {"ClassA.process", "ClassC.process"})
expected = dedent("""
class ClassA:
@ -146,7 +147,7 @@ def test_try_except_structure() -> None:
def handle_error(self):
print("error")
"""
result = get_read_writable_code(dedent(code), {"TargetClass.target_method"})
result = parse_code_and_prune_cst(dedent(code),CodeContextType.READ_WRITABLE, {"TargetClass.target_method"})
expected = dedent("""
try:
@ -173,7 +174,7 @@ def test_init_method() -> None:
def target_method(self):
return f"Value: {self.x}"
"""
result = get_read_writable_code(dedent(code), {"MyClass.target_method"})
result = parse_code_and_prune_cst(dedent(code),CodeContextType.READ_WRITABLE, {"MyClass.target_method"})
expected = dedent("""
class MyClass:
@ -197,7 +198,7 @@ def test_dunder_method() -> None:
def target_method(self):
return f"Value: {self.x}"
"""
result = get_read_writable_code(dedent(code), {"MyClass.target_method"})
result = parse_code_and_prune_cst(dedent(code),CodeContextType.READ_WRITABLE, {"MyClass.target_method"})
expected = dedent("""
class MyClass:
@ -218,7 +219,7 @@ def test_no_targets_found() -> None:
pass
"""
with pytest.raises(ValueError, match="No target functions found in the provided code"):
get_read_writable_code(dedent(code), {"MyClass.Inner.target"})
parse_code_and_prune_cst(dedent(code),CodeContextType.READ_WRITABLE, {"MyClass.Inner.target"})
def test_module_var() -> None:
@ -242,7 +243,7 @@ def test_module_var() -> None:
var2 = "test"
"""
output = get_read_writable_code(dedent(code), {"target_function"})
output = parse_code_and_prune_cst(dedent(code),CodeContextType.READ_WRITABLE, {"target_function"})
assert dedent(expected).strip() == output.strip()

View file

@ -0,0 +1,745 @@
from textwrap import dedent
import pytest
from codeflash.models.models import CodeContextType
from codeflash.context.code_context_extractor import parse_code_and_prune_cst
def test_simple_function() -> None:
code = """
def target_function():
x = 1
y = 2
return x + y
"""
result = parse_code_and_prune_cst(dedent(code), CodeContextType.TESTGEN, {"target_function"}, set())
expected = """
def target_function():
x = 1
y = 2
return x + y
"""
assert dedent(expected).strip() == result.strip()
def test_basic_class() -> None:
code = """
class TestClass:
class_var = "value"
def target_method(self):
print("This should be included")
def other_method(self):
print("This too")
"""
expected = """
class TestClass:
class_var = "value"
def target_method(self):
print("This should be included")
"""
output = parse_code_and_prune_cst(dedent(code), CodeContextType.TESTGEN, {"TestClass.target_method"}, set())
assert dedent(expected).strip() == output.strip()
def test_dunder_methods() -> None:
code = """
class TestClass:
def __init__(self):
self.x = 42
def __str__(self):
return f"Value: {self.x}"
def target_method(self):
print("include me")
"""
expected = """
class TestClass:
def __init__(self):
self.x = 42
def __str__(self):
return f"Value: {self.x}"
def target_method(self):
print("include me")
"""
output = parse_code_and_prune_cst(dedent(code), CodeContextType.TESTGEN, {"TestClass.target_method"}, set())
assert dedent(expected).strip() == output.strip()
def test_dunder_methods_remove_docstring() -> None:
code = """
class TestClass:
def __init__(self):
\"\"\"Constructor for TestClass.\"\"\"
self.x = 42
def __str__(self):
\"\"\"String representation of TestClass.\"\"\"
return f"Value: {self.x}"
def target_method(self):
\"\"\"Target method docstring.\"\"\"
print("include me")
"""
expected = """
class TestClass:
def __init__(self):
self.x = 42
def __str__(self):
return f"Value: {self.x}"
def target_method(self):
print("include me")
"""
output = parse_code_and_prune_cst(dedent(code), CodeContextType.TESTGEN, {"TestClass.target_method"}, set(), remove_docstrings=True)
assert dedent(expected).strip() == output.strip()
def test_class_remove_docstring() -> None:
code = """
class TestClass:
\"\"\"Class docstring.\"\"\"
def __init__(self):
self.x = 42
def __str__(self):
return f"Value: {self.x}"
def target_method(self):
print("include me")
"""
expected = """
class TestClass:
def __init__(self):
self.x = 42
def __str__(self):
return f"Value: {self.x}"
def target_method(self):
print("include me")
"""
output = parse_code_and_prune_cst(dedent(code), CodeContextType.TESTGEN, {"TestClass.target_method"}, set(), remove_docstrings=True)
assert dedent(expected).strip() == output.strip()
def test_target_in_nested_class() -> None:
"""Test that attempting to find a target in a nested class raises an error."""
code = """
class Outer:
outer_var = 1
class Inner:
inner_var = 2
def target_method(self):
print("include this")
"""
with pytest.raises(ValueError, match="No target functions found in the provided code"):
parse_code_and_prune_cst(dedent(code), CodeContextType.TESTGEN, {"Outer.Inner.target_method"}, set())
def test_method_signatures() -> None:
code = """
class TestClass:
@property
def target_method(self) -> str:
\"\"\"Property docstring.\"\"\"
return "value"
@classmethod
def class_method(cls, param: int = 42) -> None:
print("class method")
"""
expected = """
class TestClass:
@property
def target_method(self) -> str:
\"\"\"Property docstring.\"\"\"
return "value"
"""
output = parse_code_and_prune_cst(dedent(code), CodeContextType.TESTGEN, {"TestClass.target_method"}, set())
assert dedent(expected).strip() == output.strip()
def test_multiple_top_level_targets() -> None:
code = """
class TestClass:
def target1(self):
print("include 1")
def target2(self):
print("include 2")
def __init__(self):
self.x = 42
def other_method(self):
print("include other")
"""
expected = """
class TestClass:
def target1(self):
print("include 1")
def target2(self):
print("include 2")
def __init__(self):
self.x = 42
"""
output = parse_code_and_prune_cst(dedent(code), CodeContextType.TESTGEN, {"TestClass.target1", "TestClass.target2"}, set())
assert dedent(expected).strip() == output.strip()
def test_class_annotations() -> None:
code = """
class TestClass:
var1: int = 42
var2: str
def target_method(self) -> None:
self.var2 = "test"
"""
expected = """
class TestClass:
var1: int = 42
var2: str
def target_method(self) -> None:
self.var2 = "test"
"""
output = parse_code_and_prune_cst(dedent(code), CodeContextType.TESTGEN, {"TestClass.target_method"}, set())
assert dedent(expected).strip() == output.strip()
def test_class_annotations_if() -> None:
code = """
if True:
class TestClass:
var1: int = 42
var2: str
def target_method(self) -> None:
self.var2 = "test"
"""
expected = """
if True:
class TestClass:
var1: int = 42
var2: str
def target_method(self) -> None:
self.var2 = "test"
"""
output = parse_code_and_prune_cst(dedent(code), CodeContextType.TESTGEN, {"TestClass.target_method"}, set())
assert dedent(expected).strip() == output.strip()
def test_conditional_class_definitions() -> None:
code = """
if PLATFORM == "linux":
class PlatformClass:
platform = "linux"
def target_method(self):
print("linux")
elif PLATFORM == "windows":
class PlatformClass:
platform = "windows"
def target_method(self):
print("windows")
else:
class PlatformClass:
platform = "other"
def target_method(self):
print("other")
"""
expected = """
if PLATFORM == "linux":
class PlatformClass:
platform = "linux"
def target_method(self):
print("linux")
elif PLATFORM == "windows":
class PlatformClass:
platform = "windows"
def target_method(self):
print("windows")
else:
class PlatformClass:
platform = "other"
def target_method(self):
print("other")
"""
output = parse_code_and_prune_cst(dedent(code), CodeContextType.TESTGEN, {"PlatformClass.target_method"}, set())
assert dedent(expected).strip() == output.strip()
def test_try_except_structure() -> None:
code = """
try:
class TargetClass:
attr = "value"
def target_method(self):
return 42
except ValueError:
class ErrorClass:
def handle_error(self):
print("error")
"""
expected = """
try:
class TargetClass:
attr = "value"
def target_method(self):
return 42
except ValueError:
class ErrorClass:
def handle_error(self):
print("error")
"""
output = parse_code_and_prune_cst(dedent(code), CodeContextType.TESTGEN, {"TargetClass.target_method"}, set())
assert dedent(expected).strip() == output.strip()
def test_module_var() -> None:
code = """
def target_function(self) -> None:
self.var2 = "test"
x = 5
def some_function():
print("wow")
"""
expected = """
def target_function(self) -> None:
self.var2 = "test"
x = 5
"""
output = parse_code_and_prune_cst(dedent(code), CodeContextType.TESTGEN, {"target_function"}, set())
assert dedent(expected).strip() == output.strip()
def test_module_var_if() -> None:
code = """
def target_function(self) -> None:
var2 = "test"
if y:
x = 5
else:
z = 10
def some_function():
print("wow")
def some_function():
print("wow")
"""
expected = """
def target_function(self) -> None:
var2 = "test"
if y:
x = 5
else:
z = 10
"""
output = parse_code_and_prune_cst(dedent(code), CodeContextType.TESTGEN, {"target_function"}, set())
assert dedent(expected).strip() == output.strip()
def test_multiple_classes() -> None:
code = """
class ClassA:
def process(self):
return "A"
class ClassB:
def process(self):
return "B"
class ClassC:
def process(self):
return "C"
"""
expected = """
class ClassA:
def process(self):
return "A"
class ClassC:
def process(self):
return "C"
"""
output = parse_code_and_prune_cst(dedent(code), CodeContextType.TESTGEN, {"ClassA.process", "ClassC.process"}, set())
assert dedent(expected).strip() == output.strip()
def test_with_statement_and_loops() -> None:
    """Every conditional redefinition of the target class inside with/while/try is kept."""
    code = """
    with context_manager() as ctx:
        while attempt_count < max_attempts:
            try:
                for item in items:
                    if item.ready:
                        class TestClass:
                            context = "ready"
                            def target_method(self):
                                print("ready")
                    else:
                        class TestClass:
                            context = "not_ready"
                            def target_method(self):
                                print("not ready")
            except ConnectionError:
                class TestClass:
                    context = "connection_error"
                    def target_method(self):
                        print("connection error")
                continue
            finally:
                class TestClass:
                    context = "cleanup"
                    def target_method(self):
                        print("cleanup")
    """
    expected = """
    with context_manager() as ctx:
        while attempt_count < max_attempts:
            try:
                for item in items:
                    if item.ready:
                        class TestClass:
                            context = "ready"
                            def target_method(self):
                                print("ready")
                    else:
                        class TestClass:
                            context = "not_ready"
                            def target_method(self):
                                print("not ready")
            except ConnectionError:
                class TestClass:
                    context = "connection_error"
                    def target_method(self):
                        print("connection error")
                continue
            finally:
                class TestClass:
                    context = "cleanup"
                    def target_method(self):
                        print("cleanup")
    """
    targets = {"TestClass.target_method"}
    output = parse_code_and_prune_cst(dedent(code), CodeContextType.TESTGEN, targets, set())
    assert dedent(expected).strip() == output.strip()
def test_async_with_try_except() -> None:
    """Async constructs (async with/for, awaits) are preserved around the target class."""
    code = """
    async with async_context() as ctx:
        try:
            async for item in items:
                if await item.is_valid():
                    class TestClass:
                        status = "valid"
                        async def target_method(self):
                            await self.process()
                elif await item.can_retry():
                    continue
                else:
                    break
        except AsyncIOError:
            class TestClass:
                status = "io_error"
                async def target_method(self):
                    await self.handle_error()
        except CancelledError:
            class TestClass:
                status = "cancelled"
                async def target_method(self):
                    await self.cleanup()
    """
    expected = """
    async with async_context() as ctx:
        try:
            async for item in items:
                if await item.is_valid():
                    class TestClass:
                        status = "valid"
                        async def target_method(self):
                            await self.process()
                elif await item.can_retry():
                    continue
                else:
                    break
        except AsyncIOError:
            class TestClass:
                status = "io_error"
                async def target_method(self):
                    await self.handle_error()
        except CancelledError:
            class TestClass:
                status = "cancelled"
                async def target_method(self):
                    await self.cleanup()
    """
    targets = {"TestClass.target_method"}
    output = parse_code_and_prune_cst(dedent(code), CodeContextType.TESTGEN, targets, set())
    assert dedent(expected).strip() == output.strip()
def test_simplified_complete_implementation() -> None:
    """End-to-end pruning: target methods and dunders survive, other helpers are dropped."""
    code = """
    class DataProcessor:
        \"\"\"A simple data processing class.\"\"\"
        def __init__(self, data: Dict[str, Any]) -> None:
            self.data = data
            self._processed = False
            self.result = None
        def __repr__(self) -> str:
            return f"DataProcessor(processed={self._processed})"
        def target_method(self, key: str) -> Optional[Any]:
            \"\"\"Process and retrieve a specific key from the data.\"\"\"
            if not self._processed:
                self._process_data()
            return self.result.get(key) if self.result else None
        def _process_data(self) -> None:
            \"\"\"Internal method to process the data.\"\"\"
            processed = {}
            for key, value in self.data.items():
                if isinstance(value, (int, float)):
                    processed[key] = value * 2
                elif isinstance(value, str):
                    processed[key] = value.upper()
                else:
                    processed[key] = value
            self.result = processed
            self._processed = True
        def to_json(self) -> str:
            \"\"\"Convert the processed data to JSON string.\"\"\"
            if not self._processed:
                self._process_data()
            return json.dumps(self.result)
    try:
        sample_data = {"number": 42, "text": "hello"}
        processor = DataProcessor(sample_data)
        class ResultHandler:
            def __init__(self, processor: DataProcessor):
                self.processor = processor
                self.cache = {}
            def __str__(self) -> str:
                return f"ResultHandler(cache_size={len(self.cache)})"
            def target_method(self, key: str) -> Optional[Any]:
                \"\"\"Retrieve and cache results for a key.\"\"\"
                if key not in self.cache:
                    self.cache[key] = self.processor.target_method(key)
                return self.cache[key]
            def clear_cache(self) -> None:
                \"\"\"Clear the internal cache.\"\"\"
                self.cache.clear()
            def get_stats(self) -> Dict[str, int]:
                \"\"\"Get cache statistics.\"\"\"
                return {
                    "cache_size": len(self.cache),
                    "hits": sum(1 for v in self.cache.values() if v is not None)
                }
    except Exception as e:
        class ResultHandler:
            def __init__(self):
                self.error = str(e)
            def target_method(self, key: str) -> None:
                raise RuntimeError(f"Failed to initialize: {self.error}")
    """
    expected = """
    class DataProcessor:
        \"\"\"A simple data processing class.\"\"\"
        def __init__(self, data: Dict[str, Any]) -> None:
            self.data = data
            self._processed = False
            self.result = None
        def __repr__(self) -> str:
            return f"DataProcessor(processed={self._processed})"
        def target_method(self, key: str) -> Optional[Any]:
            \"\"\"Process and retrieve a specific key from the data.\"\"\"
            if not self._processed:
                self._process_data()
            return self.result.get(key) if self.result else None
    try:
        sample_data = {"number": 42, "text": "hello"}
        processor = DataProcessor(sample_data)
        class ResultHandler:
            def __init__(self, processor: DataProcessor):
                self.processor = processor
                self.cache = {}
            def __str__(self) -> str:
                return f"ResultHandler(cache_size={len(self.cache)})"
            def target_method(self, key: str) -> Optional[Any]:
                \"\"\"Retrieve and cache results for a key.\"\"\"
                if key not in self.cache:
                    self.cache[key] = self.processor.target_method(key)
                return self.cache[key]
    except Exception as e:
        class ResultHandler:
            def __init__(self):
                self.error = str(e)
            def target_method(self, key: str) -> None:
                raise RuntimeError(f"Failed to initialize: {self.error}")
    """
    targets = {"DataProcessor.target_method", "ResultHandler.target_method"}
    output = parse_code_and_prune_cst(dedent(code), CodeContextType.TESTGEN, targets, set())
    assert dedent(expected).strip() == output.strip()
def test_simplified_complete_implementation_no_docstring() -> None:
    """Same as the full-implementation case, but remove_docstrings=True strips docstrings."""
    code = """
    class DataProcessor:
        \"\"\"A simple data processing class.\"\"\"
        def __repr__(self) -> str:
            return f"DataProcessor(processed={self._processed})"
        def target_method(self, key: str) -> Optional[Any]:
            \"\"\"Process and retrieve a specific key from the data.\"\"\"
            if not self._processed:
                self._process_data()
            return self.result.get(key) if self.result else None
        def _process_data(self) -> None:
            \"\"\"Internal method to process the data.\"\"\"
            processed = {}
            for key, value in self.data.items():
                if isinstance(value, (int, float)):
                    processed[key] = value * 2
                elif isinstance(value, str):
                    processed[key] = value.upper()
                else:
                    processed[key] = value
            self.result = processed
            self._processed = True
        def to_json(self) -> str:
            \"\"\"Convert the processed data to JSON string.\"\"\"
            if not self._processed:
                self._process_data()
            return json.dumps(self.result)
    try:
        sample_data = {"number": 42, "text": "hello"}
        processor = DataProcessor(sample_data)
        class ResultHandler:
            def __str__(self) -> str:
                return f"ResultHandler(cache_size={len(self.cache)})"
            def target_method(self, key: str) -> Optional[Any]:
                \"\"\"Retrieve and cache results for a key.\"\"\"
                if key not in self.cache:
                    self.cache[key] = self.processor.target_method(key)
                return self.cache[key]
            def clear_cache(self) -> None:
                \"\"\"Clear the internal cache.\"\"\"
                self.cache.clear()
            def get_stats(self) -> Dict[str, int]:
                \"\"\"Get cache statistics.\"\"\"
                return {
                    "cache_size": len(self.cache),
                    "hits": sum(1 for v in self.cache.values() if v is not None)
                }
    except Exception as e:
        class ResultHandler:
            def target_method(self, key: str) -> None:
                raise RuntimeError(f"Failed to initialize: {self.error}")
    """
    expected = """
    class DataProcessor:
        def __repr__(self) -> str:
            return f"DataProcessor(processed={self._processed})"
        def target_method(self, key: str) -> Optional[Any]:
            if not self._processed:
                self._process_data()
            return self.result.get(key) if self.result else None
    try:
        sample_data = {"number": 42, "text": "hello"}
        processor = DataProcessor(sample_data)
        class ResultHandler:
            def __str__(self) -> str:
                return f"ResultHandler(cache_size={len(self.cache)})"
            def target_method(self, key: str) -> Optional[Any]:
                if key not in self.cache:
                    self.cache[key] = self.processor.target_method(key)
                return self.cache[key]
    except Exception as e:
        class ResultHandler:
            def target_method(self, key: str) -> None:
                raise RuntimeError(f"Failed to initialize: {self.error}")
    """
    targets = {"DataProcessor.target_method", "ResultHandler.target_method"}
    output = parse_code_and_prune_cst(
        dedent(code), CodeContextType.TESTGEN, targets, set(), remove_docstrings=True
    )
    assert dedent(expected).strip() == output.strip()

View file

@ -2674,13 +2674,13 @@ def test_code_replacement10() -> None:
project_root=str(file_path.parent),
original_source_code=original_code,
).unwrap()
assert code_context.code_to_optimize_with_helpers == get_code_output
assert code_context.testgen_context_code == get_code_output
code_context = opt.get_code_optimization_context(
function_to_optimize=func_top_optimize,
project_root=str(file_path.parent),
original_source_code=original_code,
)
assert code_context.code_to_optimize_with_helpers == get_code_output
assert code_context.testgen_context_code == get_code_output
"""
expected = """import gc
@ -2739,9 +2739,9 @@ def test_code_replacement10() -> None:
with open(file_path) as f:
original_code = f.read()
code_context = codeflash_wrap(opt.get_code_optimization_context, '{module_path}', None, 'test_code_replacement10', 'Optimizer.get_code_optimization_context', '4_1', codeflash_loop_index, codeflash_cur, codeflash_con, function_to_optimize=func_top_optimize, project_root=str(file_path.parent), original_source_code=original_code).unwrap()
assert code_context.code_to_optimize_with_helpers == get_code_output
assert code_context.testgen_context_code == get_code_output
code_context = codeflash_wrap(opt.get_code_optimization_context, '{module_path}', None, 'test_code_replacement10', 'Optimizer.get_code_optimization_context', '4_3', codeflash_loop_index, codeflash_cur, codeflash_con, function_to_optimize=func_top_optimize, project_root=str(file_path.parent), original_source_code=original_code)
assert code_context.code_to_optimize_with_helpers == get_code_output
assert code_context.testgen_context_code == get_code_output
codeflash_con.close()
"""

View file

@ -1,103 +1,103 @@
from __future__ import annotations
import pathlib
from dataclasses import dataclass, field
from typing import List
from codeflash.code_utils.code_extractor import get_code
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.optimization.function_context import get_constrained_function_context_and_helper_functions
class CustomType:
    """Plain-class fixture: a ``name`` plus a sortable integer list."""

    def __init__(self) -> None:
        # Start empty; tests mutate these attributes directly.
        self.name = None
        self.data: List[int] = []
@dataclass
class CustomDataClass:
    """Dataclass fixture mirroring CustomType for annotation-discovery tests."""

    name: str = ""
    # default_factory prevents a shared mutable default across instances
    data: List[int] = field(default_factory=list)
def function_to_optimize(data: CustomType) -> CustomType:
    """Fixture target: reads ``data.name`` and sorts ``data.data`` in place."""
    name = data.name  # attribute read kept so context analysis sees the access
    data.data.sort()
    return data
def function_to_optimize2(data: CustomDataClass) -> CustomType:
    """Fixture target annotated with the dataclass parameter type."""
    name = data.name  # attribute read kept so context analysis sees the access
    data.data.sort()
    return data
def function_to_optimize3(data: dict[CustomDataClass, list[CustomDataClass]]) -> list[CustomType] | None:
    """Fixture target with composite/union annotations for discovery tests."""
    name = data.name  # attribute read kept so context analysis sees the access
    data.data.sort()
    return data
def test_function_context_includes_type_annotation() -> None:
    """A type used only in an annotation is still discovered as a helper."""
    file_path = pathlib.Path(__file__).resolve()
    a, helper_functions, dunder_methods = get_constrained_function_context_and_helper_functions(
        FunctionToOptimize("function_to_optimize", str(file_path), []),
        str(file_path.parent.resolve()),
        """def function_to_optimize(data: CustomType):
    name = data.name
    data.data.sort()
    return data""",
        1000,
    )
    assert len(helper_functions) == 1
    assert helper_functions[0].fully_qualified_name == "test_type_annotation_context.CustomType"
def test_function_context_includes_type_annotation_dataclass() -> None:
    """Both the parameter dataclass and the return type are picked up as helpers."""
    file_path = pathlib.Path(__file__).resolve()
    a, helper_functions, dunder_methods = get_constrained_function_context_and_helper_functions(
        FunctionToOptimize("function_to_optimize2", str(file_path), []),
        str(file_path.parent.resolve()),
        """def function_to_optimize2(data: CustomDataClass) -> CustomType:
    name = data.name
    data.data.sort()
    return data""",
        1000,
    )
    assert len(helper_functions) == 2
    assert helper_functions[0].fully_qualified_name == "test_type_annotation_context.CustomDataClass"
    assert helper_functions[1].fully_qualified_name == "test_type_annotation_context.CustomType"
def test_function_context_works_for_composite_types() -> None:
    """Types nested inside generics (set[...], list[...]) are still discovered."""
    file_path = pathlib.Path(__file__).resolve()
    a, helper_functions, dunder_methods = get_constrained_function_context_and_helper_functions(
        FunctionToOptimize("function_to_optimize3", str(file_path), []),
        str(file_path.parent.resolve()),
        """def function_to_optimize3(data: set[CustomDataClass[CustomDataClass, int]]) -> list[CustomType]:
    name = data.name
    data.data.sort()
    return data""",
        1000,
    )
    assert len(helper_functions) == 2
    assert helper_functions[0].fully_qualified_name == "test_type_annotation_context.CustomDataClass"
    assert helper_functions[1].fully_qualified_name == "test_type_annotation_context.CustomType"
def test_function_context_custom_datatype() -> None:
    """Helper discovery against a real project file (math_utils.Matrix)."""
    project_path = pathlib.Path(__file__).parent.parent.resolve() / "code_to_optimize"
    file_path = project_path / "math_utils.py"
    code, contextual_dunder_methods = get_code([FunctionToOptimize("cosine_similarity", str(file_path), [])])
    assert code is not None
    assert contextual_dunder_methods == set()
    a, helper_functions, dunder_methods = get_constrained_function_context_and_helper_functions(
        FunctionToOptimize("cosine_similarity", str(file_path), []), str(project_path), code, 1000
    )
    assert len(helper_functions) == 1
    assert helper_functions[0].fully_qualified_name == "math_utils.Matrix"
# from __future__ import annotations
#
# import pathlib
# from dataclasses import dataclass, field
# from typing import List
#
# from codeflash.code_utils.code_extractor import get_code
# from codeflash.discovery.functions_to_optimize import FunctionToOptimize
# from codeflash.optimization.function_context import get_constrained_function_context_and_helper_functions
#
#
# class CustomType:
# def __init__(self) -> None:
# self.name = None
# self.data: List[int] = []
#
#
# @dataclass
# class CustomDataClass:
# name: str = ""
# data: List[int] = field(default_factory=list)
#
#
# def function_to_optimize(data: CustomType) -> CustomType:
# name = data.name
# data.data.sort()
# return data
#
#
# def function_to_optimize2(data: CustomDataClass) -> CustomType:
# name = data.name
# data.data.sort()
# return data
#
#
# def function_to_optimize3(data: dict[CustomDataClass, list[CustomDataClass]]) -> list[CustomType] | None:
# name = data.name
# data.data.sort()
# return data
#
#
# def test_function_context_includes_type_annotation() -> None:
# file_path = pathlib.Path(__file__).resolve()
# a, helper_functions, dunder_methods = get_constrained_function_context_and_helper_functions(
# FunctionToOptimize("function_to_optimize", str(file_path), []),
# str(file_path.parent.resolve()),
# """def function_to_optimize(data: CustomType):
# name = data.name
# data.data.sort()
# return data""",
# 1000,
# )
#
# assert len(helper_functions) == 1
# assert helper_functions[0].fully_qualified_name == "test_type_annotation_context.CustomType"
#
#
# def test_function_context_includes_type_annotation_dataclass() -> None:
# file_path = pathlib.Path(__file__).resolve()
# a, helper_functions, dunder_methods = get_constrained_function_context_and_helper_functions(
# FunctionToOptimize("function_to_optimize2", str(file_path), []),
# str(file_path.parent.resolve()),
# """def function_to_optimize2(data: CustomDataClass) -> CustomType:
# name = data.name
# data.data.sort()
# return data""",
# 1000,
# )
#
# assert len(helper_functions) == 2
# assert helper_functions[0].fully_qualified_name == "test_type_annotation_context.CustomDataClass"
# assert helper_functions[1].fully_qualified_name == "test_type_annotation_context.CustomType"
#
#
# def test_function_context_works_for_composite_types() -> None:
# file_path = pathlib.Path(__file__).resolve()
# a, helper_functions, dunder_methods = get_constrained_function_context_and_helper_functions(
# FunctionToOptimize("function_to_optimize3", str(file_path), []),
# str(file_path.parent.resolve()),
# """def function_to_optimize3(data: set[CustomDataClass[CustomDataClass, int]]) -> list[CustomType]:
# name = data.name
# data.data.sort()
# return data""",
# 1000,
# )
#
# assert len(helper_functions) == 2
# assert helper_functions[0].fully_qualified_name == "test_type_annotation_context.CustomDataClass"
# assert helper_functions[1].fully_qualified_name == "test_type_annotation_context.CustomType"
#
#
# def test_function_context_custom_datatype() -> None:
# project_path = pathlib.Path(__file__).parent.parent.resolve() / "code_to_optimize"
# file_path = project_path / "math_utils.py"
# code, contextual_dunder_methods = get_code([FunctionToOptimize("cosine_similarity", str(file_path), [])])
# assert code is not None
# assert contextual_dunder_methods == set()
# a, helper_functions, dunder_methods = get_constrained_function_context_and_helper_functions(
# FunctionToOptimize("cosine_similarity", str(file_path), []), str(project_path), code, 1000
# )
#
# assert len(helper_functions) == 1
# assert helper_functions[0].fully_qualified_name == "math_utils.Matrix"