127 lines
5.3 KiB
Python
127 lines
5.3 KiB
Python
import os
|
|
import time
|
|
from dataclasses import dataclass
|
|
from typing import Optional
|
|
|
|
from codeflash.code_utils.code_extractor import get_code
|
|
from codeflash.code_utils.code_replacer import replace_file_with_new_function
|
|
from codeflash.code_utils.code_utils import get_module_name_from_file
|
|
from codeflash.discovery.discover_unit_tests import discover_unit_tests
|
|
from codeflash.discovery.functions_to_optimize import get_functions_to_optimize
|
|
from codeflash.optimization.function_context import get_function_context_len_constrained
|
|
from codeflash.optimization.optimizer import optimize_python_code
|
|
from codeflash.verification.test_runner import run_tests
|
|
from codeflash.verification.verification_utils import get_test_file_path
|
|
from codeflash.verification.verifier import generate_tests
|
|
|
|
|
|
@dataclass
class args:
    """Hard-coded run settings, read as class attributes (e.g. ``args.root``).

    The rest of this module never instantiates the class; it reads the
    defaults below straight off the class object.
    """

    # Project root: handed to function/test discovery and used as the
    # working directory when tests are run.
    root: Optional[str] = "codeflash"
    # Forwarded to get_functions_to_optimize as ``optimize_all``.
    all: bool = False
    # Forwarded to get_functions_to_optimize as ``file``.
    file: Optional[str] = None
    # Forwarded to get_functions_to_optimize as ``function``.
    function: Optional[str] = None
    # Test directory: used for test discovery, for naming generated test
    # files, and as the path prefix of files that are skipped.
    tests_root: Optional[str] = "tests"
    # NOTE(review): not read anywhere in the visible code; presumably
    # selects the test runner -- confirm against callers.
    test_framework: Optional[str] = "pytest"
|
|
|
|
|
|
# Upper bound on how many times the test suites are re-run per function.
MAX_TEST_RUN_ITERATIONS = 5
# Passed to run_tests as ``pytest_timeout`` (presumably seconds -- confirm).
INDIVIDUAL_TEST_TIMEOUT = 15
# Once the accumulated wall-clock time of the generated-test runs exceeds
# this many seconds, the re-run loop stops early.
MAX_FUNCTION_TEST_SECONDS = 60
# Number of candidates requested from optimize_python_code (its ``n``).
N_CANDIDATES = 10
|
|
|
|
|
|
def main() -> None:
    """Run the optimization pipeline over every discovered function.

    For each function reported by ``get_functions_to_optimize`` (files whose
    path starts with ``args.tests_root`` are skipped) this:

    1. extracts the function's code and its length-constrained context,
    2. records the test files already mapped to the function,
    3. generates new tests and writes them to the path returned by
       ``get_test_file_path``,
    4. runs the recorded and the generated tests up to
       ``MAX_TEST_RUN_ITERATIONS`` times, stopping early once the generated
       tests have used more than ``MAX_FUNCTION_TEST_SECONDS``,
    5. requests ``N_CANDIDATES`` optimizations and, for each candidate,
       writes it into the source file and then restores the original text.

    Any exception escaping the per-function work is printed and ends the
    run. Nothing is returned.
    """
    print("RUNNING THE OPTIMIZER")
    modified_functions = get_functions_to_optimize(
        args.root, optimize_all=args.all, file=args.file, function=args.function
    )
    # Paths of generated test files; recorded here but not consumed yet.
    test_files_created = set()
    # NOTE(review): this set is shared across functions, so test files
    # collected for an earlier function are re-run for later ones --
    # confirm that is intended.
    existing_unit_tests = set()
    try:
        functions_to_tests_map = discover_unit_tests(args.tests_root, args.root)
        print(functions_to_tests_map)
        for path in modified_functions:
            # TODO: Sequence the functions one goes through intelligently.
            # If we are optimizing f(g(x)), then we might want to first
            # optimize f rather than g because optimizing f would already
            # optimize g as it is a dependency
            for function_name in modified_functions[path]:
                if path.startswith(args.tests_root + os.sep):
                    print("SKIPPING OPTIMIZING TEST FILE")
                    continue
                # Keep the pristine file contents so every candidate can be
                # rolled back after it has been written.
                with open(path, "r") as f:
                    original_code = f.read()
                print(f"OPTIMIZING {function_name} IN {path}")
                code_to_optimize = get_code(path, function_name)
                if code_to_optimize is None:
                    print("Could not find function to optimize")
                    continue
                print(code_to_optimize)
                (
                    code_to_optimize,
                    function_dependencies,
                ) = get_function_context_len_constrained(
                    function_name, path, args.root, code_to_optimize, max_tokens=2000
                )
                print("CODE TO OPTIMIZE", code_to_optimize)
                module_path = get_module_name_from_file(path, args.root)

                # A function without discovered tests must not abort the
                # whole run: a missing key simply means "no existing tests".
                try:
                    tests_for_function = functions_to_tests_map[
                        module_path + "." + function_name
                    ]
                except KeyError:
                    tests_for_function = []
                for tests_in_file in tests_for_function:
                    existing_unit_tests.add(tests_in_file.test_file)

                new_tests = generate_tests(
                    source_code_being_tested=code_to_optimize,
                    function_name=function_name,
                    module_path=module_path,
                    function_dependencies=function_dependencies,
                )
                print(new_tests)
                generated_tests_path = get_test_file_path(args.tests_root, function_name, 0)
                test_files_created.add(generated_tests_path)
                with open(generated_tests_path, "w") as file:
                    file.write(new_tests)
                generated_tests_elapsed_time = 0.0

                for _ in range(MAX_TEST_RUN_ITERATIONS):
                    # Only the generated-test runs count towards the budget.
                    if generated_tests_elapsed_time > MAX_FUNCTION_TEST_SECONDS:
                        break
                    for test_file in existing_unit_tests:
                        stdout, stderr = run_tests(
                            test_file,
                            cwd=args.root,
                            pytest_timeout=INDIVIDUAL_TEST_TIMEOUT,
                        )
                        print(stdout, stderr)

                    start_time = time.time()
                    test_stdout, test_stderr = run_tests(
                        test_path=generated_tests_path,
                        cwd=args.root,
                        pytest_timeout=INDIVIDUAL_TEST_TIMEOUT,
                    )
                    generated_tests_elapsed_time += time.time() - start_time
                    print(test_stdout, test_stderr)

                optimizations = optimize_python_code(code_to_optimize, n=N_CANDIDATES)
                for optimized_code, _explanation in optimizations:
                    if optimized_code is None:
                        continue
                    print("optimized_candidate", optimized_code)
                    try:
                        new_code = replace_file_with_new_function(
                            path, function_name, optimized_code
                        )
                    except Exception as e:
                        # A failed splice stops the candidate loop for this
                        # function; the source file has not been touched.
                        print(e)
                        break
                    try:
                        with open(path, "w") as f:
                            f.write(new_code)
                    finally:
                        # Always put the original source back, even when
                        # writing the candidate failed part-way through.
                        with open(path, "w") as f:
                            f.write(original_code)
                    print("----------------")

    except Exception as e:
        # Top-level boundary: report the failure instead of crashing.
        print(e)
|
|
|
|
|
|
if __name__ == "__main__":
    # main() returns None, so call it directly; wrapping it in print()
    # only emitted a stray "None" at the end of every run.
    main()
|