diff --git a/tests/test_code_replacement.py b/tests/test_code_replacement.py index 32de8bc4d..2d547e27e 100644 --- a/tests/test_code_replacement.py +++ b/tests/test_code_replacement.py @@ -797,7 +797,8 @@ class MainClass: def test_code_replacement10() -> None: - get_code_output = """from __future__ import annotations + get_code_output = """# file: test_code_replacement.py +from __future__ import annotations class HelperClass: def __init__(self, name): @@ -827,7 +828,7 @@ class MainClass: ) func_optimizer = FunctionOptimizer(function_to_optimize=func_top_optimize, test_cfg=test_config) code_context = func_optimizer.get_code_optimization_context().unwrap() - assert code_context.testgen_context.rstrip() == get_code_output.rstrip() + assert code_context.testgen_context.flat.rstrip() == get_code_output.rstrip() def test_code_replacement11() -> None: diff --git a/tests/test_function_dependencies.py b/tests/test_function_dependencies.py index 49f4fc30e..4a886ba8d 100644 --- a/tests/test_function_dependencies.py +++ b/tests/test_function_dependencies.py @@ -160,8 +160,9 @@ def test_class_method_dependencies() -> None: ) assert code_context.helper_functions[0].qualified_name == "Graph.topologicalSortUtil" assert ( - code_context.testgen_context - == """from collections import defaultdict + code_context.testgen_context.flat + == """# file: test_function_dependencies.py +from collections import defaultdict class Graph: def __init__(self, vertices): @@ -220,8 +221,9 @@ def test_recursive_function_context() -> None: assert code_context.helper_functions[0].fully_qualified_name == "test_function_dependencies.C.calculate_something_3" assert code_context.helper_functions[1].fully_qualified_name == "test_function_dependencies.C.recursive" assert ( - code_context.testgen_context - == """class C: + code_context.testgen_context.flat + == """# file: test_function_dependencies.py +class C: def calculate_something_3(self, num): return num + 1 diff --git a/tests/test_get_helper_code.py b/tests/test_get_helper_code.py index c3382e513..5cf2c963e 100644 --- a/tests/test_get_helper_code.py +++ b/tests/test_get_helper_code.py @@ -241,8 +241,9 @@ class _PersistentCache(Generic[_P, _R, _CacheBackendT]): code_context = ctx_result.unwrap() assert code_context.helper_functions[0].qualified_name == "AbstractCacheBackend.get_cache_or_call" assert ( - code_context.testgen_context - == f'''_P = ParamSpec("_P") + code_context.testgen_context.flat + == f'''# file: {file_path.relative_to(project_root_path)} +_P = ParamSpec("_P") _KEY_T = TypeVar("_KEY_T") _STORE_T = TypeVar("_STORE_T") class AbstractCacheBackend(CacheBackend, Protocol[_KEY_T, _STORE_T]): @@ -394,10 +395,11 @@ def test_bubble_sort_deps() -> None: function_to_optimize = FunctionToOptimize( function_name="sorter_deps", file_path=file_path, parents=[], starting_line=None, ending_line=None ) + project_root = file_path.parent.parent.resolve() test_config = TestConfig( tests_root=str(file_path.parent / "tests"), tests_project_rootdir=file_path.parent.resolve(), - project_root_path=file_path.parent.parent.resolve(), + project_root_path=project_root, test_framework="pytest", pytest_cmd="pytest", ) @@ -409,19 +411,20 @@ def test_bubble_sort_deps() -> None: pytest.fail() code_context = ctx_result.unwrap() assert ( - code_context.testgen_context - == """from code_to_optimize.bubble_sort_dep1_helper import dep1_comparer -from code_to_optimize.bubble_sort_dep2_swap import dep2_swap - + code_context.testgen_context.flat + == f"""# file: code_to_optimize/bubble_sort_dep1_helper.py def dep1_comparer(arr, j: int) -> bool: return arr[j] > arr[j + 1] +# file: code_to_optimize/bubble_sort_dep2_swap.py def dep2_swap(arr, j): temp = arr[j] arr[j] = arr[j + 1] arr[j + 1] = temp - +# file: code_to_optimize/bubble_sort_deps.py +from code_to_optimize.bubble_sort_dep1_helper import dep1_comparer +from code_to_optimize.bubble_sort_dep2_swap import dep2_swap def sorter_deps(arr): for i in range(len(arr)):