use markdown context for the testgen
This commit is contained in:
parent
32e85ee67c
commit
5981d75b3e
7 changed files with 23 additions and 23 deletions
|
|
@ -12,9 +12,10 @@ if TYPE_CHECKING:
|
|||
|
||||
def extract_dependent_function(main_function: str, code_context: CodeOptimizationContext) -> str | Literal[False]:
|
||||
"""Extract the single dependent function from the code context excluding the main function."""
|
||||
ast_tree = ast.parse(code_context.testgen_context_code)
|
||||
|
||||
dependent_functions = {node.name for node in ast_tree.body if isinstance(node, ast.FunctionDef)}
|
||||
dependent_functions = set()
|
||||
for code_string in code_context.testgen_context.code_strings:
|
||||
ast_tree = ast.parse(code_string.code)
|
||||
dependent_functions.update({node.name for node in ast_tree.body if isinstance(node, ast.FunctionDef)})
|
||||
|
||||
if main_function in dependent_functions:
|
||||
dependent_functions.discard(main_function)
|
||||
|
|
|
|||
|
|
@ -114,32 +114,32 @@ def get_code_optimization_context(
|
|||
read_only_context_code = ""
|
||||
|
||||
# Extract code context for testgen
|
||||
testgen_code_markdown = extract_code_string_context_from_files(
|
||||
testgen_context = extract_code_markdown_context_from_files(
|
||||
helpers_of_fto_dict,
|
||||
helpers_of_helpers_dict,
|
||||
project_root_path,
|
||||
remove_docstrings=False,
|
||||
code_context_type=CodeContextType.TESTGEN,
|
||||
)
|
||||
testgen_context_code = testgen_code_markdown.code
|
||||
testgen_context_code_tokens = encoded_tokens_len(testgen_context_code)
|
||||
if testgen_context_code_tokens > testgen_token_limit:
|
||||
testgen_code_markdown = extract_code_string_context_from_files(
|
||||
testgen_markdown_code = testgen_context.markdown
|
||||
testgen_code_token_length = encoded_tokens_len(testgen_markdown_code)
|
||||
if testgen_code_token_length > testgen_token_limit:
|
||||
testgen_context = extract_code_markdown_context_from_files(
|
||||
helpers_of_fto_dict,
|
||||
helpers_of_helpers_dict,
|
||||
project_root_path,
|
||||
remove_docstrings=True,
|
||||
code_context_type=CodeContextType.TESTGEN,
|
||||
)
|
||||
testgen_context_code = testgen_code_markdown.code
|
||||
testgen_context_code_tokens = encoded_tokens_len(testgen_context_code)
|
||||
if testgen_context_code_tokens > testgen_token_limit:
|
||||
testgen_markdown_code = testgen_context.markdown
|
||||
testgen_code_token_length = encoded_tokens_len(testgen_markdown_code)
|
||||
if testgen_code_token_length > testgen_token_limit:
|
||||
raise ValueError("Testgen code context has exceeded token limit, cannot proceed")
|
||||
code_hash_context = hashing_code_context.markdown
|
||||
code_hash = hashlib.sha256(code_hash_context.encode("utf-8")).hexdigest()
|
||||
|
||||
return CodeOptimizationContext(
|
||||
testgen_context_code=testgen_context_code,
|
||||
testgen_context=testgen_context,
|
||||
read_writable_code=final_read_writable_code,
|
||||
read_only_context_code=read_only_context_code,
|
||||
hashing_code_context=code_hash_context,
|
||||
|
|
|
|||
|
|
@ -253,7 +253,7 @@ class CodeStringsMarkdown(BaseModel):
|
|||
|
||||
|
||||
class CodeOptimizationContext(BaseModel):
|
||||
testgen_context_code: str = ""
|
||||
testgen_context: CodeStringsMarkdown
|
||||
read_writable_code: CodeStringsMarkdown
|
||||
read_only_context_code: str = ""
|
||||
hashing_code_context: str = ""
|
||||
|
|
|
|||
|
|
@ -309,7 +309,7 @@ class FunctionOptimizer:
|
|||
revert_to_print=bool(get_pr_number()),
|
||||
):
|
||||
generated_results = self.generate_tests_and_optimizations(
|
||||
testgen_context_code=code_context.testgen_context_code,
|
||||
testgen_context=code_context.testgen_context,
|
||||
read_writable_code=code_context.read_writable_code,
|
||||
read_only_context_code=code_context.read_only_context_code,
|
||||
helper_functions=code_context.helper_functions,
|
||||
|
|
@ -345,7 +345,6 @@ class FunctionOptimizer:
|
|||
logger.info(f"Generated test {i + 1}/{count_tests}:")
|
||||
code_print(generated_test.generated_original_test_source, file_name=f"test_{i + 1}.py")
|
||||
if concolic_test_str:
|
||||
# no concolic tests in lsp mode
|
||||
logger.info(f"Generated test {count_tests}/{count_tests}:")
|
||||
code_print(concolic_test_str)
|
||||
|
||||
|
|
@ -946,7 +945,7 @@ class FunctionOptimizer:
|
|||
|
||||
return Success(
|
||||
CodeOptimizationContext(
|
||||
testgen_context_code=new_code_ctx.testgen_context_code,
|
||||
testgen_context=new_code_ctx.testgen_context,
|
||||
read_writable_code=new_code_ctx.read_writable_code,
|
||||
read_only_context_code=new_code_ctx.read_only_context_code,
|
||||
hashing_code_context=new_code_ctx.hashing_code_context,
|
||||
|
|
@ -1053,7 +1052,7 @@ class FunctionOptimizer:
|
|||
|
||||
def generate_tests_and_optimizations(
|
||||
self,
|
||||
testgen_context_code: str,
|
||||
testgen_context: CodeStringsMarkdown,
|
||||
read_writable_code: CodeStringsMarkdown,
|
||||
read_only_context_code: str,
|
||||
helper_functions: list[FunctionSource],
|
||||
|
|
@ -1067,7 +1066,7 @@ class FunctionOptimizer:
|
|||
# Submit the test generation task as future
|
||||
future_tests = self.submit_test_generation_tasks(
|
||||
self.executor,
|
||||
testgen_context_code,
|
||||
testgen_context.markdown,
|
||||
[definition.fully_qualified_name for definition in helper_functions],
|
||||
generated_test_paths,
|
||||
generated_perf_test_paths,
|
||||
|
|
|
|||
|
|
@ -827,7 +827,7 @@ class MainClass:
|
|||
)
|
||||
func_optimizer = FunctionOptimizer(function_to_optimize=func_top_optimize, test_cfg=test_config)
|
||||
code_context = func_optimizer.get_code_optimization_context().unwrap()
|
||||
assert code_context.testgen_context_code.rstrip() == get_code_output.rstrip()
|
||||
assert code_context.testgen_context.rstrip() == get_code_output.rstrip()
|
||||
|
||||
|
||||
def test_code_replacement11() -> None:
|
||||
|
|
|
|||
|
|
@ -160,7 +160,7 @@ def test_class_method_dependencies() -> None:
|
|||
)
|
||||
assert code_context.helper_functions[0].qualified_name == "Graph.topologicalSortUtil"
|
||||
assert (
|
||||
code_context.testgen_context_code
|
||||
code_context.testgen_context
|
||||
== """from collections import defaultdict
|
||||
|
||||
class Graph:
|
||||
|
|
@ -220,7 +220,7 @@ def test_recursive_function_context() -> None:
|
|||
assert code_context.helper_functions[0].fully_qualified_name == "test_function_dependencies.C.calculate_something_3"
|
||||
assert code_context.helper_functions[1].fully_qualified_name == "test_function_dependencies.C.recursive"
|
||||
assert (
|
||||
code_context.testgen_context_code
|
||||
code_context.testgen_context
|
||||
== """class C:
|
||||
def calculate_something_3(self, num):
|
||||
return num + 1
|
||||
|
|
|
|||
|
|
@ -241,7 +241,7 @@ class _PersistentCache(Generic[_P, _R, _CacheBackendT]):
|
|||
code_context = ctx_result.unwrap()
|
||||
assert code_context.helper_functions[0].qualified_name == "AbstractCacheBackend.get_cache_or_call"
|
||||
assert (
|
||||
code_context.testgen_context_code
|
||||
code_context.testgen_context
|
||||
== f'''_P = ParamSpec("_P")
|
||||
_KEY_T = TypeVar("_KEY_T")
|
||||
_STORE_T = TypeVar("_STORE_T")
|
||||
|
|
@ -409,7 +409,7 @@ def test_bubble_sort_deps() -> None:
|
|||
pytest.fail()
|
||||
code_context = ctx_result.unwrap()
|
||||
assert (
|
||||
code_context.testgen_context_code
|
||||
code_context.testgen_context
|
||||
== """from code_to_optimize.bubble_sort_dep1_helper import dep1_comparer
|
||||
from code_to_optimize.bubble_sort_dep2_swap import dep2_swap
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue