more tests with different positions of global variables and multiple reassignments

This commit is contained in:
aseembits93 2025-05-01 15:25:50 -07:00
parent 28596b7b55
commit 0cbd20451f

View file

@ -1695,6 +1695,159 @@ print("Hello world")
a=2
print("Hello world")
def some_fn():
a=np.zeros(10)
print("did something")
class NewClass:
def __init__(self, name):
self.name = name
def __call__(self, value):
return "I am still old"
def new_function2(value):
return cst.ensure_type(value, str)
def __init__(self, name):
self.name = name
def __call__(self, value):
return "I am still old"
def new_function2(value):
return cst.ensure_type(value, str)
"""
code_path = (Path(__file__).parent.resolve() / "../code_to_optimize/global_var_original.py").resolve()
code_path.write_text(original_code, encoding="utf-8")
tests_root = Path("/Users/codeflash/Downloads/codeflash-dev/codeflash/code_to_optimize/tests/pytest/")
project_root_path = (Path(__file__).parent / "..").resolve()
func = FunctionToOptimize(function_name="some_fn", parents=[], file_path=code_path)
test_config = TestConfig(
tests_root=tests_root,
tests_project_rootdir=project_root_path,
project_root_path=project_root_path,
test_framework="pytest",
pytest_cmd="pytest",
)
func_optimizer = FunctionOptimizer(function_to_optimize=func, test_cfg=test_config)
code_context: CodeOptimizationContext = func_optimizer.get_code_optimization_context().unwrap()
original_helper_code: dict[Path, str] = {}
helper_function_paths = {hf.file_path for hf in code_context.helper_functions}
for helper_function_path in helper_function_paths:
with helper_function_path.open(encoding="utf8") as f:
helper_code = f.read()
original_helper_code[helper_function_path] = helper_code
func_optimizer.args = Args()
func_optimizer.replace_function_and_helpers_with_optimized_code(
code_context=code_context, optimized_code=optimized_code
)
new_code = code_path.read_text(encoding="utf-8")
code_path.unlink(missing_ok=True)
assert new_code.rstrip() == expected_code.rstrip()
original_code = """print("Hello world")
def some_fn():
print("did noting")
class NewClass:
def __init__(self, name):
self.name = name
def __call__(self, value):
return "I am still old"
def new_function2(value):
return cst.ensure_type(value, str)
a=1
"""
optimized_code = """a=2
import numpy as np
def some_fn():
a=np.zeros(10)
print("did something")
class NewClass:
def __init__(self, name):
self.name = name
def __call__(self, value):
return "I am still old"
def new_function2(value):
return cst.ensure_type(value, str)
print("Hello world")
"""
expected_code = """import numpy as np
print("Hello world")
print("Hello world")
def some_fn():
a=np.zeros(10)
print("did something")
class NewClass:
def __init__(self, name):
self.name = name
def __call__(self, value):
return "I am still old"
def new_function2(value):
return cst.ensure_type(value, str)
def __init__(self, name):
self.name = name
def __call__(self, value):
return "I am still old"
def new_function2(value):
return cst.ensure_type(value, str)
a=2
"""
code_path = (Path(__file__).parent.resolve() / "../code_to_optimize/global_var_original.py").resolve()
code_path.write_text(original_code, encoding="utf-8")
tests_root = Path("/Users/codeflash/Downloads/codeflash-dev/codeflash/code_to_optimize/tests/pytest/")
project_root_path = (Path(__file__).parent / "..").resolve()
func = FunctionToOptimize(function_name="some_fn", parents=[], file_path=code_path)
test_config = TestConfig(
tests_root=tests_root,
tests_project_rootdir=project_root_path,
project_root_path=project_root_path,
test_framework="pytest",
pytest_cmd="pytest",
)
func_optimizer = FunctionOptimizer(function_to_optimize=func, test_cfg=test_config)
code_context: CodeOptimizationContext = func_optimizer.get_code_optimization_context().unwrap()
original_helper_code: dict[Path, str] = {}
helper_function_paths = {hf.file_path for hf in code_context.helper_functions}
for helper_function_path in helper_function_paths:
with helper_function_path.open(encoding="utf8") as f:
helper_code = f.read()
original_helper_code[helper_function_path] = helper_code
func_optimizer.args = Args()
func_optimizer.replace_function_and_helpers_with_optimized_code(
code_context=code_context, optimized_code=optimized_code
)
new_code = code_path.read_text(encoding="utf-8")
code_path.unlink(missing_ok=True)
assert new_code.rstrip() == expected_code.rstrip()
original_code = """a=1
print("Hello world")
def some_fn():
print("did noting")
class NewClass:
def __init__(self, name):
self.name = name
def __call__(self, value):
return "I am still old"
def new_function2(value):
return cst.ensure_type(value, str)
"""
optimized_code = """import numpy as np
a=2
def some_fn():
a=np.zeros(10)
print("did something")
class NewClass:
def __init__(self, name):
self.name = name
def __call__(self, value):
return "I am still old"
def new_function2(value):
return cst.ensure_type(value, str)
a=3
print("Hello world")
"""
expected_code = """import numpy as np
print("Hello world")
a=3
print("Hello world")
def some_fn():
a=np.zeros(10)
print("did something")