# codeflash-internal/codeflash/main.py
import logging
import os
import subprocess
import time
from argparse import SUPPRESS, ArgumentParser, Namespace
from typing import List

import libcst as cst

from codeflash.code_utils import env_utils
from codeflash.code_utils.code_extractor import get_code
from codeflash.code_utils.code_replacer import replace_function_in_file
from codeflash.code_utils.code_utils import (
    get_all_function_names,
    get_run_tmp_file,
    module_name_from_file_path,
)
from codeflash.code_utils.config_parser import parse_config_file
from codeflash.discovery.discover_unit_tests import TestsInFile, discover_unit_tests
from codeflash.discovery.functions_to_optimize import (
    FunctionToOptimize,
    get_functions_to_optimize_by_file,
)
from codeflash.instrumentation.instrument_existing_tests import (
    inject_profiling_into_existing_test,
)
from codeflash.models import TestConfig
from codeflash.optimization.function_context import (
    Source,
    get_constrained_function_context_and_dependent_functions,
)
from codeflash.optimization.optimizer import optimize_python_code
from codeflash.verification import EXPLAIN_MODEL
from codeflash.verification.equivalence import compare_results
from codeflash.verification.parse_test_output import TestType, parse_test_results
from codeflash.verification.test_results import TestResults
from codeflash.verification.test_runner import run_tests
from codeflash.verification.verification_utils import get_test_file_path
from codeflash.verification.verifier import generate_tests

logging.basicConfig(level=logging.INFO)


def parse_args() -> Namespace:
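    """Parse CLI arguments, fill unset values from the pyproject.toml codeflash config, and validate the paths."""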
parser = ArgumentParser()
parser.add_argument("--file", help="Try to optimize only this file")
parser.add_argument(
"--function",
help="Try to optimize only this function within the given file path",
)
    parser.add_argument(
        "--all",
        help="Try to optimize all functions in the project. This can take a very long time. Optionally pass a"
        " starting directory to restrict optimization to code under it; with no value (just --all), all code in"
        " the project is optimized.",
        nargs="?",
        const="",
        default=SUPPRESS,
    )
parser.add_argument(
"--config-file",
type=str,
help="Path to the pyproject.toml with codeflash configs.",
)
    parser.add_argument(
        "--root",
        type=str,
        help="Path to the project root, from which your Python modules are imported",
    )
    parser.add_argument(
        "--test-root",
        type=str,
        help="Path to the directory containing your tests, where generated tests are also written",
    )
parser.add_argument("--test-framework", choices=["pytest", "unittest"])
parser.add_argument(
"--use-cached-tests",
action="store_true",
help="Use cached tests from a specified file for debugging.",
)
parser.add_argument("-v", "--verbose", action="store_true", help="Print verbose logs")
args: Namespace = parser.parse_args()
    if args.verbose:
        # logging.basicConfig was already called at import time, so a second call is a no-op;
        # raise the root logger's level directly instead.
        logging.getLogger().setLevel(logging.DEBUG)
if args.function and not args.file:
raise ValueError("If you specify a --function, you must specify the --file it is in")
pyproject_config = parse_config_file(args.config_file)
supported_keys = ["root", "test_root", "test_framework"]
for key in supported_keys:
if key in pyproject_config and getattr(args, key.replace("-", "_")) is None:
setattr(args, key.replace("-", "_"), pyproject_config[key])
assert os.path.isdir(args.root), "--root must be a valid directory"
    assert os.path.isdir(args.test_root), "--test-root must be a valid directory"
args.root = os.path.realpath(args.root)
args.test_root = os.path.realpath(args.test_root)
if not hasattr(args, "all"):
setattr(args, "all", None)
elif args.all == "":
# The default behavior of --all is to optimize everything in args.root
args.all = args.root
else:
args.all = os.path.realpath(args.all)
return args
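

# Tuning knobs for the optimization loop: the number of timing runs per candidate, the per-test
# timeout in seconds passed to pytest, the total time budget in seconds for a function's generated
# tests, the number of optimization candidates to request, and the minimum relative speedup a
# candidate must show to count as an improvement.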
MAX_TEST_RUN_ITERATIONS = 5
INDIVIDUAL_TEST_TIMEOUT = 15
MAX_FUNCTION_TEST_SECONDS = 60
N_CANDIDATES = 10
MIN_IMPROVEMENT_THRESHOLD = 0.05
class Optimizer:
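    """Coordinates the end-to-end optimization workflow for every function selected by the CLI arguments."""
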
def __init__(self, args: Namespace):
self.args = args
self.test_cfg = TestConfig(
test_root=args.test_root,
project_root_path=args.root,
test_framework=args.test_framework,
)
def run(self):
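        """Benchmark each discovered function, try optimization candidates, and keep the best verified one."""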
logging.info("RUNNING THE OPTIMIZER")
env_utils.ensure_codeflash_api_key()
file_to_funcs_to_optimize, num_modified_functions = get_functions_to_optimize_by_file(
optimize_all=self.args.all,
file=self.args.file,
function=self.args.function,
test_cfg=self.test_cfg,
)
test_files_created = set()
test_files_to_preserve = set()
instrumented_unittests_created = set()
self.found_atleast_one_optimization = False
if os.path.exists("/tmp/pr_comment_temp.txt"):
os.remove("/tmp/pr_comment_temp.txt")
function_iterator_count = 0
try:
if num_modified_functions == 0:
logging.info("No functions found to optimize. Exiting...")
return
function_to_tests: dict[str, list[TestsInFile]] = discover_unit_tests(self.test_cfg)
for path in file_to_funcs_to_optimize:
logging.info(f"Examining file {path} ...")
# TODO: Sequence the functions one goes through intelligently. If we are optimizing f(g(x)), then we might want to first
# optimize f rather than g because optimizing f would already optimize g as it is a dependency
with open(path, "r") as f:
original_code = f.read()
for function_to_optimize in file_to_funcs_to_optimize[path]:
function_name = function_to_optimize.function_name
function_iterator_count += 1
logging.info(
f"Optimizing function {function_iterator_count} of {num_modified_functions} - {function_name}"
)
explanation_final = ""
winning_test_results = None
overall_original_test_results = None
if os.path.exists(get_run_tmp_file("test_return_values_0.bin")):
                        # remove leftovers from the previous run
os.remove(get_run_tmp_file("test_return_values_0.bin"))
if os.path.exists(get_run_tmp_file("test_return_values_0.sqlite")):
os.remove(get_run_tmp_file("test_return_values_0.sqlite"))
code_to_optimize = get_code(function_to_optimize)
if code_to_optimize is None:
logging.error("Could not find function to optimize")
continue
preexisting_functions = get_all_function_names(code_to_optimize)
(
code_to_optimize_with_dependents,
function_dependencies,
) = get_constrained_function_context_and_dependent_functions(
function_to_optimize,
self.args.root,
code_to_optimize,
max_tokens=EXPLAIN_MODEL.max_tokens,
)
logging.info("CODE TO OPTIMIZE %s", code_to_optimize_with_dependents)
module_path = module_name_from_file_path(path, self.args.root)
unique_original_test_files = set()
full_module_function_path = module_path + "." + function_name
if full_module_function_path not in function_to_tests:
logging.warning(
"Could not find any pre-existing tests for '%s', will only use generated tests.",
full_module_function_path,
)
else:
for tests_in_file in function_to_tests.get(full_module_function_path):
if tests_in_file.test_file in unique_original_test_files:
continue
injected_test = inject_profiling_into_existing_test(
tests_in_file.test_file,
function_name,
self.args.root,
)
new_test_path = (
os.path.splitext(tests_in_file.test_file)[0]
+ "__perfinstrumented"
+ os.path.splitext(tests_in_file.test_file)[1]
)
with open(new_test_path, "w") as f:
f.write(injected_test)
instrumented_unittests_created.add(new_test_path)
unique_original_test_files.add(tests_in_file.test_file)
generated_tests_path = self.generate_test_files(
function_to_optimize,
function_dependencies,
code_to_optimize_with_dependents,
module_path,
)
                    if generated_tests_path is None:
                        # Test generation failed; there is nothing new to verify against, so skip this function.
                        continue
                    test_files_created.add(generated_tests_path)
original_runtime = None
times_run = 0
# TODO : Dynamically determine the number of times to run the tests based on the runtime of the tests.
# Keep the runtime in some acceptable range
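                    # One possible scheme (not implemented here): after the first timing run, derive the remaining
                    # run count from the measured per-run wall-clock time so the total stays within budget, e.g.
                    #   runs = max(1, min(MAX_TEST_RUN_ITERATIONS, int(MAX_FUNCTION_TEST_SECONDS // first_run_seconds)))
                    # where first_run_seconds would be measured around the first run_tests call below (hypothetical name).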
generated_tests_elapsed_time = 0.0
# For the original function - run the tests and get the runtime
# TODO: Compare the function return values over the multiple runs and check if they are any different,
# if they are different, then we can't optimize this function because it is a non-deterministic function
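                    # A minimal sketch of such a check (hypothetical, not implemented): keep a digest of the
                    # serialized return values captured for each timing run and treat the function as
                    # non-deterministic if the digests ever disagree, e.g.
                    #   seen = {hashlib.sha256(pickle.dumps(values)).hexdigest() for values in per_run_return_values}
                    #   deterministic = len(seen) <= 1
                    # where per_run_return_values would come from the captured test_return_values artifacts.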
test_env = os.environ.copy()
test_env["CODEFLASH_TEST_ITERATION"] = str(0)
for i in range(MAX_TEST_RUN_ITERATIONS):
if generated_tests_elapsed_time > MAX_FUNCTION_TEST_SECONDS:
break
instrumented_test_timing = []
original_test_results_iter = TestResults()
for test_file in instrumented_unittests_created:
result_file_path, run_result = run_tests(
test_file,
test_framework=self.args.test_framework,
cwd=self.args.root,
pytest_timeout=INDIVIDUAL_TEST_TIMEOUT,
verbose=True,
test_env=test_env,
)
unittest_results = parse_test_results(
test_xml_path=result_file_path,
test_py_path=test_file,
test_config=self.test_cfg,
test_type=TestType.EXISTING_UNIT_TEST,
run_result=run_result,
optimization_iteration=0,
)
for result in unittest_results:
if result.did_pass and result.runtime is None:
logging.debug(
f"Ignoring test case that passed but had no runtime -> {result.id}"
)
timing = sum(
[
result.runtime
for result in unittest_results
if (result.did_pass and result.runtime is not None)
]
)
original_test_results_iter.merge(unittest_results)
instrumented_test_timing.append(timing)
if i == 0:
logging.info(
f"original code, existing unit test results -> {original_test_results_iter.get_test_pass_fail_report()}"
)
start_time = time.time()
result_file_path, run_result = run_tests(
test_path=generated_tests_path,
cwd=self.args.root,
test_framework=self.args.test_framework,
test_env=test_env,
pytest_timeout=INDIVIDUAL_TEST_TIMEOUT,
)
generated_tests_elapsed_time += time.time() - start_time
# TODO: Implement the logic to disregard the timing info of the tests that ERRORed out. That is remove test cases that failed to run.
original_gen_results = parse_test_results(
result_file_path,
generated_tests_path,
self.test_cfg,
test_type=TestType.GENERATED_REGRESSION,
run_result=run_result,
optimization_iteration=0,
)
if not original_gen_results and len(instrumented_test_timing) == 0:
logging.warning(
f"Couldn't run any tests for original function {function_name}. SKIPPING OPTIMIZING THIS FUNCTION."
)
break
# TODO: Doing a simple sum of test runtime, Improve it by looking at test by test runtime, or a better scheme
# TODO: If the runtime is None, that happens in the case where an exception is expected and is successfully
# caught by the test framework. This makes the test pass, but we can't find runtime because the exception caused
# the execution to not reach the runtime measurement part. We are currently ignoring such tests, because the performance
                        # for such an execution that raises an exception should not matter.
for result in original_gen_results:
if result.did_pass and result.runtime is None:
logging.debug(
f"Ignoring test case that passed but had no runtime -> {result.id}"
)
if i == 0:
logging.info(
f"original generated tests results -> {original_gen_results.get_test_pass_fail_report()}"
)
original_total_runtime_iter = sum(
(
[
result.runtime
for result in original_gen_results
if (result.did_pass and result.runtime is not None)
]
if original_gen_results is not None
else []
)
+ instrumented_test_timing
)
if original_total_runtime_iter == 0:
logging.warning(
f"The overall test runtime of the original function is 0, trying again..."
)
logging.warning(original_gen_results.test_results)
continue
original_test_results_iter.merge(original_gen_results)
if i == 0:
logging.info(
f"Original overall test results = {original_test_results_iter.get_test_pass_fail_report_by_type()}"
)
if (
original_runtime is None
or original_total_runtime_iter < original_runtime
):
original_runtime = best_runtime = original_total_runtime_iter
overall_original_test_results = original_test_results_iter
times_run += 1
if times_run == 0:
logging.warning(
"Failed to run the tests for the original function, skipping optimization"
)
continue
logging.info(
f"ORIGINAL CODE RUNTIME OVER {times_run} RUN{'S' if times_run > 1 else ''} = {original_runtime}ns"
)
logging.info("OPTIMIZING CODE....")
# TODO: Postprocess the optimized function to include the original docstring and such
optimizations = optimize_python_code(
code_to_optimize_with_dependents, n=N_CANDIDATES
)
best_optimization = []
for i, (optimized_code, explanation) in enumerate(optimizations):
j = i + 1
if optimized_code is None:
continue
if os.path.exists(get_run_tmp_file(f"test_return_values_{j}.bin")):
                            # remove leftovers from the previous run
os.remove(get_run_tmp_file(f"test_return_values_{j}.bin"))
if os.path.exists(get_run_tmp_file(f"test_return_values_{j}.sqlite")):
os.remove(get_run_tmp_file(f"test_return_values_{j}.sqlite"))
logging.info(f"optimized_candidate {optimized_code}")
try:
new_code = replace_function_in_file(
path,
function_name,
optimized_code,
preexisting_functions,
# test_cfg.project_root_path,
# function_dependencies,
)
except (
ValueError,
SyntaxError,
cst.ParserSyntaxError,
AttributeError,
) as e:
logging.error(e)
continue
with open(path, "w") as f:
f.write(new_code)
all_test_times = []
equal_results = True
generated_tests_elapsed_time = 0.0
times_run = 0
test_env = os.environ.copy()
test_env["CODEFLASH_TEST_ITERATION"] = str(j)
for test_index in range(MAX_TEST_RUN_ITERATIONS):
if os.path.exists(get_run_tmp_file(f"test_return_values_{j}.bin")):
os.remove(get_run_tmp_file(f"test_return_values_{j}.bin"))
if os.path.exists(get_run_tmp_file(f"test_return_values_{j}.sqlite")):
os.remove(get_run_tmp_file(f"test_return_values_{j}.sqlite"))
if generated_tests_elapsed_time > MAX_FUNCTION_TEST_SECONDS:
break
optimized_test_results_iter = TestResults()
instrumented_test_timing = []
for instrumented_test_file in instrumented_unittests_created:
result_file_path, run_result = run_tests(
instrumented_test_file,
test_framework=self.args.test_framework,
cwd=self.args.root,
pytest_timeout=INDIVIDUAL_TEST_TIMEOUT,
verbose=True,
test_env=test_env,
)
unittest_results_optimized = parse_test_results(
test_xml_path=result_file_path,
test_py_path=instrumented_test_file,
test_config=self.test_cfg,
test_type=TestType.EXISTING_UNIT_TEST,
run_result=run_result,
optimization_iteration=j,
)
for result in unittest_results_optimized:
if result.did_pass and result.runtime is None:
logging.debug(
f"Ignoring test case that passed but had no runtime -> {result.id}"
)
timing = sum(
[
result.runtime
for result in unittest_results_optimized
if (result.did_pass and result.runtime is not None)
]
)
optimized_test_results_iter.merge(unittest_results_optimized)
instrumented_test_timing.append(timing)
if test_index == 0:
equal_results = True
logging.info(
f"optimized existing unit tests result -> {optimized_test_results_iter.get_test_pass_fail_report()}"
)
for test_invocation in optimized_test_results_iter:
if (
overall_original_test_results.get_by_id(test_invocation.id)
is None
or test_invocation.did_pass
!= overall_original_test_results.get_by_id(
test_invocation.id
).did_pass
):
logging.info("RESULTS DID NOT MATCH")
logging.info(
f"Test {test_invocation.id} failed on the optimized code. Skipping this optimization"
)
equal_results = False
break
if not equal_results:
break
start_time = time.time()
result_file_path, run_result = run_tests(
test_path=generated_tests_path,
test_framework=self.args.test_framework,
cwd=self.args.root,
test_env=test_env,
pytest_timeout=INDIVIDUAL_TEST_TIMEOUT,
)
generated_tests_elapsed_time += time.time() - start_time
test_results = parse_test_results(
test_xml_path=result_file_path,
test_py_path=generated_tests_path,
optimization_iteration=j,
test_type=TestType.GENERATED_REGRESSION,
test_config=self.test_cfg,
run_result=run_result,
)
if test_index == 0:
logging.info(
f"generated test_results optimized -> {test_results.get_test_pass_fail_report()}"
)
if test_results:
if compare_results(original_gen_results, test_results):
equal_results = True
logging.info("RESULTS MATCHED!")
else:
logging.info("RESULTS DID NOT MATCH")
equal_results = False
if not equal_results:
break
for result in test_results:
if result.did_pass and result.runtime is None:
logging.debug(
f"Ignoring test case that passed but had no runtime -> {result.id}"
)
test_runtime = sum(
(
[
result.runtime
for result in test_results
if (result.did_pass and result.runtime is not None)
]
if test_results is not None
else []
)
+ instrumented_test_timing
)
if test_runtime == 0:
logging.warning(
f"The overall test runtime of the optimized function is 0, trying again..."
)
continue
all_test_times.append(test_runtime)
optimized_test_results_iter.merge(test_results)
times_run += 1
if os.path.exists(get_run_tmp_file(f"test_return_values_{j}.bin")):
os.remove(get_run_tmp_file(f"test_return_values_{j}.bin"))
if os.path.exists(get_run_tmp_file(f"test_return_values_{j}.sqlite")):
os.remove(get_run_tmp_file(f"test_return_values_{j}.sqlite"))
if equal_results and times_run > 0:
# TODO: Make the runtime more human readable by using humanize
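                            # One possible approach (humanize is not currently imported here): something like
                            #   humanize.precisedelta(datetime.timedelta(microseconds=new_test_time / 1000), minimum_unit="milliseconds")
                            # would render the nanosecond totals as a readable duration.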
new_test_time = min(all_test_times)
logging.info(
f"NEW CODE RUNTIME OVER {times_run} RUN{'S' if times_run > 1 else ''} = {new_test_time}ns, SPEEDUP RATIO = {((original_runtime - new_test_time) / new_test_time):.3f}"
)
if (
((original_runtime - new_test_time) / new_test_time)
> MIN_IMPROVEMENT_THRESHOLD
) and new_test_time < best_runtime:
logging.info("THIS IS BETTER!")
logging.info(
f"original_test_time={original_runtime} new_test_time={new_test_time}, FASTER RATIO = {((original_runtime - new_test_time) / new_test_time)}"
)
best_optimization = [optimized_code, explanation]
best_runtime = new_test_time
winning_test_results = optimized_test_results_iter
with open(path, "w") as f:
f.write(original_code)
logging.info("----------------")
logging.info(f"BEST OPTIMIZATION {best_optimization}")
if best_optimization:
self.found_atleast_one_optimization = True
logging.info(f"BEST OPTIMIZED CODE {best_optimization[0]}")
if not self.args.all:
new_code = replace_function_in_file(
path,
function_name,
best_optimization[0],
preexisting_functions,
# test_cfg.project_root_path,
# function_dependencies,
)
with open(path, "w") as f:
f.write(new_code)
# TODO: After doing the best optimization, remove the test cases that errored on the new code, because they might be failing because of syntax errors and such.
speedup = (original_runtime / best_runtime) - 1
# TODO: Sometimes the explanation says something similar to "This is the code that was optimized", remove such parts
# TODO: Use python package humanize to make the runtime more human readable
explanation_final += (
f"Function {function_name} in file {path}:\n"
f"Performance went up by {speedup:.2f}x ({speedup * 100:.2f}%). Runtime went down from {(original_runtime / 1000):.2f}μs to {(best_runtime / 1000):.2f}μs \n\n"
+ "Optimization explanation:\n"
+ best_optimization[1]
+ " \n\n"
+ "The code has been tested for correctness\n"
+ f"Test Result for the best optimized code:- {winning_test_results.get_test_pass_fail_report_by_type()}\n"
)
with open("/tmp/pr_comment_temp.txt", "a") as f:
f.write(explanation_final)
logging.info(f"EXPLANATION_FINAL {explanation_final}")
if self.args.all:
with open("optimizations_all.txt", "a") as f:
f.write(best_optimization[0])
f.write("\n\n")
f.write(explanation_final)
f.write("\n---------\n")
subprocess.run(["black", path], stdout=subprocess.PIPE)
test_files_to_preserve.add(generated_tests_path)
try:
                        with open(os.environ["GITHUB_OUTPUT"], "w") as fh:
                            print("optimization_success=true", file=fh)
                    except KeyError:
                        os.environ["GITHUB_OUTPUT"] = "optimization_success=true"
else:
                        # Delete the generated tests here to avoid clutter when optimizing with the --all option
if os.path.exists(generated_tests_path):
os.remove(generated_tests_path)
if not self.found_atleast_one_optimization:
try:
                with open(os.environ["GITHUB_OUTPUT"], "w") as fh:
                    print("optimization_success=false", file=fh)
            except KeyError:
                os.environ["GITHUB_OUTPUT"] = "optimization_success=false"
finally:
# TODO: Also revert the file/function being optimized if the process did not succeed
for test_file in instrumented_unittests_created:
if os.path.exists(test_file):
os.remove(test_file)
for test_file in test_files_created:
if test_file not in test_files_to_preserve:
if os.path.exists(test_file):
os.remove(test_file)
if hasattr(get_run_tmp_file, "tmpdir"):
get_run_tmp_file.tmpdir.cleanup()
def generate_test_files(
self,
function_to_optimize: FunctionToOptimize,
function_dependencies: List[Source],
code_to_optimize_with_dependents: str,
module_path: str,
) -> str | None:
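        """Ask the test generator for regression tests for the given function and write them under the test root.

        Returns the path to the generated test file, or None if no tests could be generated.
        """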
generated_tests_path = get_test_file_path(
self.args.test_root, function_to_optimize.function_name, 0
)
test_module_path = module_name_from_file_path(generated_tests_path, self.args.root)
new_tests = generate_tests(
source_code_being_tested=code_to_optimize_with_dependents,
function=function_to_optimize,
module_path=module_path,
test_module_path=test_module_path,
function_dependencies=function_dependencies,
test_framework=self.args.test_framework,
test_timeout=INDIVIDUAL_TEST_TIMEOUT,
use_cached_tests=self.args.use_cached_tests,
)
if new_tests is None:
logging.error("/!\\ NO TESTS GENERATED for %s", function_to_optimize.function_name)
return None
with open(generated_tests_path, "w") as file:
file.write(new_tests)
return generated_tests_path
if __name__ == "__main__":
Optimizer(parse_args()).run()