codeflash-internal/codeflash/main.py
2023-10-22 13:45:41 -07:00

127 lines
5.3 KiB
Python

import os
import time
import traceback
from dataclasses import dataclass
from typing import Optional

from codeflash.code_utils.code_extractor import get_code
from codeflash.code_utils.code_replacer import replace_file_with_new_function
from codeflash.code_utils.code_utils import get_module_name_from_file
from codeflash.discovery.discover_unit_tests import discover_unit_tests
from codeflash.discovery.functions_to_optimize import get_functions_to_optimize
from codeflash.optimization.function_context import get_function_context_len_constrained
from codeflash.optimization.optimizer import optimize_python_code
from codeflash.verification.test_runner import run_tests
from codeflash.verification.verification_utils import get_test_file_path
from codeflash.verification.verifier import generate_tests
@dataclass
class args:
    """Run configuration for the optimizer.

    ``main`` reads these values directly as class attributes (``args.root``),
    so the defaults below are the effective configuration unless they are
    overwritten on the class before ``main`` runs.
    """

    # Project root: forwarded to function/test discovery, used to derive module
    # names, and used as the working directory for test runs.
    root: Optional[str] = "codeflash"
    # Forwarded to get_functions_to_optimize as ``optimize_all``.
    all: bool = False
    # Optional file restriction, forwarded to get_functions_to_optimize.
    file: Optional[str] = None
    # Optional function restriction, forwarded to get_functions_to_optimize.
    function: Optional[str] = None
    # Directory for test discovery and generated tests; files under it are
    # never optimized.
    tests_root: Optional[str] = "tests"
    # NOTE(review): not read anywhere in this file; runs always go through
    # run_tests' pytest_timeout path.
    test_framework: Optional[str] = "pytest"
# Upper bound on how many passes over the test suites are made per function.
MAX_TEST_RUN_ITERATIONS: int = 5
# Forwarded to run_tests as ``pytest_timeout`` on every run (presumably seconds).
INDIVIDUAL_TEST_TIMEOUT: int = 15
# Wall-time budget, in seconds, for the generated-test runs of a single function.
MAX_FUNCTION_TEST_SECONDS: int = 60
# Passed as ``n`` to optimize_python_code — presumably how many candidates to request.
N_CANDIDATES: int = 10
def _run_tests_with_budget(existing_unit_tests, generated_tests_path):
    """Run the existing and generated tests up to MAX_TEST_RUN_ITERATIONS times.

    Stops early once the cumulative wall time spent running the generated tests
    exceeds MAX_FUNCTION_TEST_SECONDS. Time spent on existing tests does not count
    toward that budget.
    """
    generated_tests_elapsed_time = 0.0
    for _ in range(MAX_TEST_RUN_ITERATIONS):
        if generated_tests_elapsed_time > MAX_FUNCTION_TEST_SECONDS:
            break
        for test_file in existing_unit_tests:
            stdout, stderr = run_tests(
                test_file,
                cwd=args.root,
                pytest_timeout=INDIVIDUAL_TEST_TIMEOUT,
            )
            print(stdout, stderr)
        # perf_counter is monotonic, so clock adjustments cannot skew the budget.
        start_time = time.perf_counter()
        test_stdout, test_stderr = run_tests(
            test_path=generated_tests_path,
            cwd=args.root,
            pytest_timeout=INDIVIDUAL_TEST_TIMEOUT,
        )
        generated_tests_elapsed_time += time.perf_counter() - start_time
        print(test_stdout, test_stderr)


def _try_optimization_candidates(path, function_name, code_to_optimize, original_code):
    """Write each optimization candidate for ``function_name`` into ``path``.

    The source file is always rewritten with ``original_code`` afterwards, even
    if an exception escapes. A crash therefore never leaves a candidate in the
    user's code.
    """
    optimizations = optimize_python_code(code_to_optimize, n=N_CANDIDATES)
    try:
        for optimized_code, _explanation in optimizations:
            if optimized_code is None:
                continue
            print("optimized_candidate", optimized_code)
            try:
                new_code = replace_file_with_new_function(
                    path, function_name, optimized_code
                )
            except Exception as e:
                # A candidate that cannot be spliced in is skipped; the
                # remaining candidates are still tried.
                print(e)
                continue
            with open(path, "w", encoding="utf-8") as f:
                f.write(new_code)
            # TODO(review): the candidate is written but not yet tested or
            # benchmarked before moving on to the next one.
    finally:
        with open(path, "w", encoding="utf-8") as f:
            f.write(original_code)


def _optimize_function(path, function_name, functions_to_tests_map, test_files_created):
    """Generate tests for one function, run the tests, then try optimization candidates.

    ``test_files_created`` is updated in place with the generated test file's
    path, so the caller can remove it when the run ends.
    """
    with open(path, "r", encoding="utf-8") as f:
        original_code = f.read()
    print(f"OPTIMIZING {function_name} IN {path}")
    code_to_optimize = get_code(path, function_name)
    if code_to_optimize is None:
        print("Could not find function to optimize")
        return
    print(code_to_optimize)
    code_to_optimize, function_dependencies = get_function_context_len_constrained(
        function_name, path, args.root, code_to_optimize, max_tokens=2000
    )
    print("CODE TO OPTIMIZE", code_to_optimize)
    module_path = get_module_name_from_file(path, args.root)
    # Collect only the tests for *this* function. Using .get means a function
    # with no discovered tests is still optimized rather than raising KeyError.
    existing_unit_tests = {
        tests_in_file.test_file
        for tests_in_file in functions_to_tests_map.get(
            module_path + "." + function_name, []
        )
    }
    new_tests = generate_tests(
        source_code_being_tested=code_to_optimize,
        function_name=function_name,
        module_path=module_path,
        function_dependencies=function_dependencies,
    )
    print(new_tests)
    generated_tests_path = get_test_file_path(args.tests_root, function_name, 0)
    test_files_created.add(generated_tests_path)
    with open(generated_tests_path, "w", encoding="utf-8") as test_file:
        test_file.write(new_tests)
    _run_tests_with_budget(existing_unit_tests, generated_tests_path)
    _try_optimization_candidates(path, function_name, code_to_optimize, original_code)
    print("----------------")


def main() -> None:
    """Discover the functions to optimize and their tests, then optimize each function.

    Configuration is read from the module-level ``args`` class. If optimizing one
    function fails, the error is reported with a traceback and the run continues
    with the next function. Any generated test files are deleted when the run ends.
    """
    print("RUNNING THE OPTIMIZER")
    modified_functions = get_functions_to_optimize(
        args.root, optimize_all=args.all, file=args.file, function=args.function
    )
    test_files_created = set()
    try:
        functions_to_tests_map = discover_unit_tests(args.tests_root, args.root)
        print(functions_to_tests_map)
        tests_prefix = args.tests_root + os.sep
        for path, function_names in modified_functions.items():
            # TODO: Sequence the functions one goes through intelligently. If we are
            # optimizing f(g(x)), then we might want to first optimize f rather than g,
            # because optimizing f would already optimize g as it is a dependency.
            if path.startswith(tests_prefix):
                print("SKIPPING OPTIMIZING TEST FILE")
                continue
            for function_name in function_names:
                try:
                    _optimize_function(
                        path, function_name, functions_to_tests_map, test_files_created
                    )
                except Exception:
                    # One failing function should not abort the whole batch.
                    traceback.print_exc()
    except Exception:
        traceback.print_exc()
    finally:
        # Remove generated test files so they do not pollute the tests directory.
        for generated_test in test_files_created:
            try:
                os.remove(generated_test)
            except FileNotFoundError:
                pass
if __name__ == "__main__":
    # main() returns None; wrapping it in print() only emitted a stray "None".
    main()