use markdown context for the testgen

This commit is contained in:
ali 2025-09-25 03:31:05 +03:00
parent 32e85ee67c
commit 5981d75b3e
No known key found for this signature in database
GPG key ID: 44F9B42770617B9B
7 changed files with 23 additions and 23 deletions

View file

@ -12,9 +12,10 @@ if TYPE_CHECKING:
def extract_dependent_function(main_function: str, code_context: CodeOptimizationContext) -> str | Literal[False]:
"""Extract the single dependent function from the code context excluding the main function."""
ast_tree = ast.parse(code_context.testgen_context_code)
dependent_functions = {node.name for node in ast_tree.body if isinstance(node, ast.FunctionDef)}
dependent_functions = set()
for code_string in code_context.testgen_context.code_strings:
ast_tree = ast.parse(code_string.code)
dependent_functions.update({node.name for node in ast_tree.body if isinstance(node, ast.FunctionDef)})
if main_function in dependent_functions:
dependent_functions.discard(main_function)

View file

@ -114,32 +114,32 @@ def get_code_optimization_context(
read_only_context_code = ""
# Extract code context for testgen
testgen_code_markdown = extract_code_string_context_from_files(
testgen_context = extract_code_markdown_context_from_files(
helpers_of_fto_dict,
helpers_of_helpers_dict,
project_root_path,
remove_docstrings=False,
code_context_type=CodeContextType.TESTGEN,
)
testgen_context_code = testgen_code_markdown.code
testgen_context_code_tokens = encoded_tokens_len(testgen_context_code)
if testgen_context_code_tokens > testgen_token_limit:
testgen_code_markdown = extract_code_string_context_from_files(
testgen_markdown_code = testgen_context.markdown
testgen_code_token_length = encoded_tokens_len(testgen_markdown_code)
if testgen_code_token_length > testgen_token_limit:
testgen_context = extract_code_markdown_context_from_files(
helpers_of_fto_dict,
helpers_of_helpers_dict,
project_root_path,
remove_docstrings=True,
code_context_type=CodeContextType.TESTGEN,
)
testgen_context_code = testgen_code_markdown.code
testgen_context_code_tokens = encoded_tokens_len(testgen_context_code)
if testgen_context_code_tokens > testgen_token_limit:
testgen_markdown_code = testgen_context.markdown
testgen_code_token_length = encoded_tokens_len(testgen_markdown_code)
if testgen_code_token_length > testgen_token_limit:
raise ValueError("Testgen code context has exceeded token limit, cannot proceed")
code_hash_context = hashing_code_context.markdown
code_hash = hashlib.sha256(code_hash_context.encode("utf-8")).hexdigest()
return CodeOptimizationContext(
testgen_context_code=testgen_context_code,
testgen_context=testgen_context,
read_writable_code=final_read_writable_code,
read_only_context_code=read_only_context_code,
hashing_code_context=code_hash_context,

View file

@ -253,7 +253,7 @@ class CodeStringsMarkdown(BaseModel):
class CodeOptimizationContext(BaseModel):
testgen_context_code: str = ""
testgen_context: CodeStringsMarkdown
read_writable_code: CodeStringsMarkdown
read_only_context_code: str = ""
hashing_code_context: str = ""

View file

@ -309,7 +309,7 @@ class FunctionOptimizer:
revert_to_print=bool(get_pr_number()),
):
generated_results = self.generate_tests_and_optimizations(
testgen_context_code=code_context.testgen_context_code,
testgen_context=code_context.testgen_context,
read_writable_code=code_context.read_writable_code,
read_only_context_code=code_context.read_only_context_code,
helper_functions=code_context.helper_functions,
@ -345,7 +345,6 @@ class FunctionOptimizer:
logger.info(f"Generated test {i + 1}/{count_tests}:")
code_print(generated_test.generated_original_test_source, file_name=f"test_{i + 1}.py")
if concolic_test_str:
# no concolic tests in lsp mode
logger.info(f"Generated test {count_tests}/{count_tests}:")
code_print(concolic_test_str)
@ -946,7 +945,7 @@ class FunctionOptimizer:
return Success(
CodeOptimizationContext(
testgen_context_code=new_code_ctx.testgen_context_code,
testgen_context=new_code_ctx.testgen_context,
read_writable_code=new_code_ctx.read_writable_code,
read_only_context_code=new_code_ctx.read_only_context_code,
hashing_code_context=new_code_ctx.hashing_code_context,
@ -1053,7 +1052,7 @@ class FunctionOptimizer:
def generate_tests_and_optimizations(
self,
testgen_context_code: str,
testgen_context: CodeStringsMarkdown,
read_writable_code: CodeStringsMarkdown,
read_only_context_code: str,
helper_functions: list[FunctionSource],
@ -1067,7 +1066,7 @@ class FunctionOptimizer:
# Submit the test generation task as future
future_tests = self.submit_test_generation_tasks(
self.executor,
testgen_context_code,
testgen_context.markdown,
[definition.fully_qualified_name for definition in helper_functions],
generated_test_paths,
generated_perf_test_paths,

View file

@ -827,7 +827,7 @@ class MainClass:
)
func_optimizer = FunctionOptimizer(function_to_optimize=func_top_optimize, test_cfg=test_config)
code_context = func_optimizer.get_code_optimization_context().unwrap()
assert code_context.testgen_context_code.rstrip() == get_code_output.rstrip()
assert code_context.testgen_context.rstrip() == get_code_output.rstrip()
def test_code_replacement11() -> None:

View file

@ -160,7 +160,7 @@ def test_class_method_dependencies() -> None:
)
assert code_context.helper_functions[0].qualified_name == "Graph.topologicalSortUtil"
assert (
code_context.testgen_context_code
code_context.testgen_context
== """from collections import defaultdict
class Graph:
@ -220,7 +220,7 @@ def test_recursive_function_context() -> None:
assert code_context.helper_functions[0].fully_qualified_name == "test_function_dependencies.C.calculate_something_3"
assert code_context.helper_functions[1].fully_qualified_name == "test_function_dependencies.C.recursive"
assert (
code_context.testgen_context_code
code_context.testgen_context
== """class C:
def calculate_something_3(self, num):
return num + 1

View file

@ -241,7 +241,7 @@ class _PersistentCache(Generic[_P, _R, _CacheBackendT]):
code_context = ctx_result.unwrap()
assert code_context.helper_functions[0].qualified_name == "AbstractCacheBackend.get_cache_or_call"
assert (
code_context.testgen_context_code
code_context.testgen_context
== f'''_P = ParamSpec("_P")
_KEY_T = TypeVar("_KEY_T")
_STORE_T = TypeVar("_STORE_T")
@ -409,7 +409,7 @@ def test_bubble_sort_deps() -> None:
pytest.fail()
code_context = ctx_result.unwrap()
assert (
code_context.testgen_context_code
code_context.testgen_context
== """from code_to_optimize.bubble_sort_dep1_helper import dep1_comparer
from code_to_optimize.bubble_sort_dep2_swap import dep2_swap