more tests

This commit is contained in:
Aseem Saxena 2025-03-31 18:36:32 -07:00
parent 4b2046586e
commit 00abefbe6f
2 changed files with 148 additions and 2 deletions

View file

@ -1,6 +1,6 @@
from code_to_optimize.bubble_sort_in_nested_class import WrapperClass
from code_to_optimize.bubble_sort_in_class import BubbleSortClass
def sort_classmethod(x):
y = WrapperClass.BubbleSortClass()
y = BubbleSortClass()
return y.sorter(x)

View file

@ -3139,3 +3139,149 @@ class TestPigLatin(unittest.TestCase):
finally:
test_path.unlink(missing_ok=True)
def test_add_decorator_imports_helper_in_class():
code_path = (Path(__file__).parent.resolve() / "../code_to_optimize/bubble_sort_classmethod.py").resolve()
tests_root = Path(__file__).parent.resolve() / "../code_to_optimize/tests/pytest/"
project_root_path = (Path(__file__).parent / "..").resolve()
original_cwd = Path.cwd()
run_cwd = Path(__file__).parent.parent.resolve()
test_config = TestConfig(
tests_root=tests_root,
tests_project_rootdir=project_root_path,
project_root_path=project_root_path,
test_framework="pytest",
pytest_cmd="pytest",
)
func = FunctionToOptimize(function_name="sort_classmethod", parents=[], file_path=code_path)
func_optimizer = FunctionOptimizer(function_to_optimize=func, test_cfg=test_config)
#func_optimizer = pass
try:
ctx_result = func_optimizer.get_code_optimization_context()
code_context: CodeOptimizationContext = ctx_result.unwrap()
original_helper_code: dict[Path, str] = {}
helper_function_paths = {hf.file_path for hf in code_context.helper_functions}
for helper_function_path in helper_function_paths:
with helper_function_path.open(encoding="utf8") as f:
helper_code = f.read()
original_helper_code[helper_function_path] = helper_code
computed_fn_opt = True
line_profiler_output_file = add_decorator_imports(
func_optimizer.function_to_optimize, code_context)
expected_code_main = f"""from code_to_optimize.bubble_sort_in_class import BubbleSortClass
from line_profiler import profile as codeflash_line_profile
codeflash_line_profile.enable(output_prefix='{line_profiler_output_file}')
@codeflash_line_profile
def sort_classmethod(x):
y = BubbleSortClass()
return y.sorter(x)
"""
expected_code_helper = """from line_profiler import profile as codeflash_line_profile
def hi():
pass
class BubbleSortClass:
def __init__(self):
pass
@codeflash_line_profile
def sorter(self, arr):
n = len(arr)
for i in range(n):
for j in range(0, n - i - 1):
if arr[j] > arr[j + 1]:
arr[j], arr[j + 1] = arr[j + 1], arr[j]
return arr
def helper(self, arr, j):
return arr[j] > arr[j + 1]
"""
assert code_path.read_text("utf-8") == expected_code_main
assert code_context.helper_functions[0].file_path.read_text("utf-8") == expected_code_helper
finally:
#if computed_fn_opt:
func_optimizer.write_code_and_helpers(
func_optimizer.function_to_optimize_source_code, original_helper_code, func_optimizer.function_to_optimize.file_path
)
def test_add_decorator_imports_helper_in_nested_class():
code_path = (Path(__file__).parent.resolve() / "../code_to_optimize/bubble_sort_nested_classmethod.py").resolve()
tests_root = Path(__file__).parent.resolve() / "../code_to_optimize/tests/pytest/"
project_root_path = (Path(__file__).parent / "..").resolve()
original_cwd = Path.cwd()
run_cwd = Path(__file__).parent.parent.resolve()
test_config = TestConfig(
tests_root=tests_root,
tests_project_rootdir=project_root_path,
project_root_path=project_root_path,
test_framework="pytest",
pytest_cmd="pytest",
)
func = FunctionToOptimize(function_name="sort_classmethod", parents=[], file_path=code_path)
func_optimizer = FunctionOptimizer(function_to_optimize=func, test_cfg=test_config)
#func_optimizer = pass
try:
ctx_result = func_optimizer.get_code_optimization_context()
code_context: CodeOptimizationContext = ctx_result.unwrap()
original_helper_code: dict[Path, str] = {}
helper_function_paths = {hf.file_path for hf in code_context.helper_functions}
for helper_function_path in helper_function_paths:
with helper_function_path.open(encoding="utf8") as f:
helper_code = f.read()
original_helper_code[helper_function_path] = helper_code
computed_fn_opt = True
line_profiler_output_file = add_decorator_imports(
func_optimizer.function_to_optimize, code_context)
expected_code_main = f"""from code_to_optimize.bubble_sort_in_nested_class import WrapperClass
from line_profiler import profile as codeflash_line_profile
codeflash_line_profile.enable(output_prefix='{line_profiler_output_file}')
@codeflash_line_profile
def sort_classmethod(x):
y = WrapperClass.BubbleSortClass()
return y.sorter(x)
"""
expected_code_helper = """from line_profiler import profile as codeflash_line_profile
def hi():
pass
class WrapperClass:
def __init__(self):
pass
class BubbleSortClass:
def __init__(self):
pass
@codeflash_line_profile
def sorter(self, arr):
def inner_helper(arr, j):
return arr[j] > arr[j + 1]
for i in range(len(arr)):
for j in range(len(arr) - 1):
if arr[j] > arr[j + 1]:
temp = arr[j]
arr[j] = arr[j + 1]
arr[j + 1] = temp
return arr
def helper(self, arr, j):
return arr[j] > arr[j + 1]
"""
assert code_path.read_text("utf-8") == expected_code_main
assert code_context.helper_functions[0].file_path.read_text("utf-8") == expected_code_helper
finally:
#if computed_fn_opt:
func_optimizer.write_code_and_helpers(
func_optimizer.function_to_optimize_source_code, original_helper_code, func_optimizer.function_to_optimize.file_path
)