fix: unit tests

This commit is contained in:
ali 2025-09-25 04:10:25 +03:00
parent 5981d75b3e
commit 4b41ab73bf
No known key found for this signature in database
GPG key ID: 44F9B42770617B9B
3 changed files with 20 additions and 14 deletions

View file

@ -797,7 +797,8 @@ class MainClass:
def test_code_replacement10() -> None: def test_code_replacement10() -> None:
get_code_output = """from __future__ import annotations get_code_output = """# file: test_code_replacement.py
from __future__ import annotations
class HelperClass: class HelperClass:
def __init__(self, name): def __init__(self, name):
@ -827,7 +828,7 @@ class MainClass:
) )
func_optimizer = FunctionOptimizer(function_to_optimize=func_top_optimize, test_cfg=test_config) func_optimizer = FunctionOptimizer(function_to_optimize=func_top_optimize, test_cfg=test_config)
code_context = func_optimizer.get_code_optimization_context().unwrap() code_context = func_optimizer.get_code_optimization_context().unwrap()
assert code_context.testgen_context.rstrip() == get_code_output.rstrip() assert code_context.testgen_context.flat.rstrip() == get_code_output.rstrip()
def test_code_replacement11() -> None: def test_code_replacement11() -> None:

View file

@ -160,8 +160,9 @@ def test_class_method_dependencies() -> None:
) )
assert code_context.helper_functions[0].qualified_name == "Graph.topologicalSortUtil" assert code_context.helper_functions[0].qualified_name == "Graph.topologicalSortUtil"
assert ( assert (
code_context.testgen_context code_context.testgen_context.flat
== """from collections import defaultdict == """# file: test_function_dependencies.py
from collections import defaultdict
class Graph: class Graph:
def __init__(self, vertices): def __init__(self, vertices):
@ -220,8 +221,9 @@ def test_recursive_function_context() -> None:
assert code_context.helper_functions[0].fully_qualified_name == "test_function_dependencies.C.calculate_something_3" assert code_context.helper_functions[0].fully_qualified_name == "test_function_dependencies.C.calculate_something_3"
assert code_context.helper_functions[1].fully_qualified_name == "test_function_dependencies.C.recursive" assert code_context.helper_functions[1].fully_qualified_name == "test_function_dependencies.C.recursive"
assert ( assert (
code_context.testgen_context code_context.testgen_context.flat
== """class C: == """# file: test_function_dependencies.py
class C:
def calculate_something_3(self, num): def calculate_something_3(self, num):
return num + 1 return num + 1

View file

@ -241,8 +241,9 @@ class _PersistentCache(Generic[_P, _R, _CacheBackendT]):
code_context = ctx_result.unwrap() code_context = ctx_result.unwrap()
assert code_context.helper_functions[0].qualified_name == "AbstractCacheBackend.get_cache_or_call" assert code_context.helper_functions[0].qualified_name == "AbstractCacheBackend.get_cache_or_call"
assert ( assert (
code_context.testgen_context code_context.testgen_context.flat
== f'''_P = ParamSpec("_P") == f'''# file: {file_path.relative_to(project_root_path)}
_P = ParamSpec("_P")
_KEY_T = TypeVar("_KEY_T") _KEY_T = TypeVar("_KEY_T")
_STORE_T = TypeVar("_STORE_T") _STORE_T = TypeVar("_STORE_T")
class AbstractCacheBackend(CacheBackend, Protocol[_KEY_T, _STORE_T]): class AbstractCacheBackend(CacheBackend, Protocol[_KEY_T, _STORE_T]):
@ -394,10 +395,11 @@ def test_bubble_sort_deps() -> None:
function_to_optimize = FunctionToOptimize( function_to_optimize = FunctionToOptimize(
function_name="sorter_deps", file_path=file_path, parents=[], starting_line=None, ending_line=None function_name="sorter_deps", file_path=file_path, parents=[], starting_line=None, ending_line=None
) )
project_root = file_path.parent.parent.resolve()
test_config = TestConfig( test_config = TestConfig(
tests_root=str(file_path.parent / "tests"), tests_root=str(file_path.parent / "tests"),
tests_project_rootdir=file_path.parent.resolve(), tests_project_rootdir=file_path.parent.resolve(),
project_root_path=file_path.parent.parent.resolve(), project_root_path=project_root,
test_framework="pytest", test_framework="pytest",
pytest_cmd="pytest", pytest_cmd="pytest",
) )
@ -409,19 +411,20 @@ def test_bubble_sort_deps() -> None:
pytest.fail() pytest.fail()
code_context = ctx_result.unwrap() code_context = ctx_result.unwrap()
assert ( assert (
code_context.testgen_context code_context.testgen_context.flat
== """from code_to_optimize.bubble_sort_dep1_helper import dep1_comparer == f"""# file: code_to_optimize/bubble_sort_dep1_helper.py
from code_to_optimize.bubble_sort_dep2_swap import dep2_swap
def dep1_comparer(arr, j: int) -> bool: def dep1_comparer(arr, j: int) -> bool:
return arr[j] > arr[j + 1] return arr[j] > arr[j + 1]
# file: code_to_optimize/bubble_sort_dep2_swap.py
def dep2_swap(arr, j): def dep2_swap(arr, j):
temp = arr[j] temp = arr[j]
arr[j] = arr[j + 1] arr[j] = arr[j + 1]
arr[j + 1] = temp arr[j + 1] = temp
# file: code_to_optimize/bubble_sort_deps.py
from code_to_optimize.bubble_sort_dep1_helper import dep1_comparer
from code_to_optimize.bubble_sort_dep2_swap import dep2_swap
def sorter_deps(arr): def sorter_deps(arr):
for i in range(len(arr)): for i in range(len(arr)):