diff --git a/.idea/codeflash.iml b/.idea/codeflash.iml
index 1cbadf231..0b638381e 100644
--- a/.idea/codeflash.iml
+++ b/.idea/codeflash.iml
@@ -7,8 +7,7 @@
-
-
+
\ No newline at end of file
diff --git a/.idea/inspectionProfiles/Project_Default.xml b/.idea/inspectionProfiles/Project_Default.xml
index 03d9549ea..a58fa0a6f 100644
--- a/.idea/inspectionProfiles/Project_Default.xml
+++ b/.idea/inspectionProfiles/Project_Default.xml
@@ -2,5 +2,6 @@
\ No newline at end of file
diff --git a/.idea/inspectionProfiles/profiles_settings.xml b/.idea/inspectionProfiles/profiles_settings.xml
index ad1389744..5e2daa8c5 100644
--- a/.idea/inspectionProfiles/profiles_settings.xml
+++ b/.idea/inspectionProfiles/profiles_settings.xml
@@ -1,9 +1,10 @@
+
\ No newline at end of file
diff --git a/.idea/misc.xml b/.idea/misc.xml
index 8b442f7e5..9e5b17a3a 100644
--- a/.idea/misc.xml
+++ b/.idea/misc.xml
@@ -3,8 +3,7 @@
-
+
\ No newline at end of file
diff --git a/.idea/pydantic.xml b/.idea/pydantic.xml
new file mode 100644
index 000000000..e6e3ec67b
--- /dev/null
+++ b/.idea/pydantic.xml
@@ -0,0 +1,6 @@
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/sqldialects.xml b/.idea/sqldialects.xml
index 2c3db7023..c0e01cabc 100644
--- a/.idea/sqldialects.xml
+++ b/.idea/sqldialects.xml
@@ -1,7 +1,6 @@
-
\ No newline at end of file
diff --git a/codeflash/main.py b/codeflash/main.py
index 57ef8a468..2038dea38 100644
--- a/codeflash/main.py
+++ b/codeflash/main.py
@@ -1,4 +1,5 @@
 import logging
+from typing import List

 from codeflash.code_utils import env_utils
 from codeflash.verification import EXPLAIN_MODEL
@@ -7,7 +8,7 @@ logging.basicConfig(level=logging.INFO)
 import os
 import subprocess
 import time
-from argparse import ArgumentParser, SUPPRESS
+from argparse import ArgumentParser, SUPPRESS, Namespace

 import libcst as cst

@@ -19,12 +20,16 @@ from codeflash.code_utils.code_utils import (
     get_run_tmp_file,
 )
 from codeflash.code_utils.config_parser import parse_config_file
-from codeflash.discovery.discover_unit_tests import discover_unit_tests
-from codeflash.discovery.functions_to_optimize import get_functions_to_optimize_by_file
+from codeflash.discovery.discover_unit_tests import discover_unit_tests, TestsInFile
+from codeflash.discovery.functions_to_optimize import (
+    get_functions_to_optimize_by_file,
+    FunctionToOptimize,
+)
 from codeflash.instrumentation.instrument_existing_tests import inject_profiling_into_existing_test
 from codeflash.models import TestConfig
 from codeflash.optimization.function_context import (
     get_constrained_function_context_and_dependent_functions,
+    Source,
 )
 from codeflash.optimization.optimizer import optimize_python_code
 from codeflash.verification.equivalence import compare_results
@@ -40,7 +45,7 @@ from codeflash.verification.verification_utils import (
 )
 from codeflash.verification.verifier import generate_tests

-def parse_args():
+def parse_args() -> Namespace:
     parser = ArgumentParser()
     parser.add_argument("--file", help="Try to optimize only this file")
     parser.add_argument(
@@ -76,7 +81,7 @@ def parse_args():
         help="Use cached tests from a specified file for debugging.",
     )
     parser.add_argument("-v", "--verbose", action="store_true", help="Print verbose logs")
-    args = parser.parse_args()
+    args: Namespace = parser.parse_args()

     if args.verbose:
         logging.basicConfig(level=logging.DEBUG)
@@ -89,9 +94,7 @@ def parse_args():
         if key in pyproject_config and getattr(args, key.replace("-", "_")) is None:
             setattr(args, key.replace("-", "_"), pyproject_config[key])
     assert
os.path.isdir(args.root), "--root must be a valid directory" - assert os.path.isdir( - args.test_root - ), f"--test-root must be a valid directory; {args.test_root} is not a directory" + assert os.path.isdir(args.test_root), "--test_root must be a valid directory" args.root = os.path.realpath(args.root) args.test_root = os.path.realpath(args.test_root) if not hasattr(args, "all"): @@ -111,322 +114,147 @@ N_CANDIDATES = 10 MIN_IMPROVEMENT_THRESHOLD = 0.05 -def main(): - logging.info("RUNNING THE OPTIMIZER") - args = parse_args() - env_utils.ensure_codeflash_api_key() - test_cfg = TestConfig( - test_root=args.test_root, - project_root_path=args.root, - test_framework=args.test_framework, - ) +class Optimizer: + def __init__(self, args: Namespace): + self.args = args + self.test_cfg = TestConfig( + test_root=args.test_root, + project_root_path=args.root, + test_framework=args.test_framework, + ) - modified_functions, num_modified_functions = get_functions_to_optimize_by_file( - optimize_all=args.all, - file=args.file, - function=args.function, - test_cfg=test_cfg, - ) + def run(self): + logging.info("RUNNING THE OPTIMIZER") + env_utils.ensure_codeflash_api_key() - test_files_created = set() - test_files_to_preserve = set() - instrumented_unittests_created = set() - found_atleast_one_optimization = False + file_to_funcs_to_optimize, num_modified_functions = get_functions_to_optimize_by_file( + optimize_all=self.args.all, + file=self.args.file, + function=self.args.function, + test_cfg=self.test_cfg, + ) - if os.path.exists("/tmp/pr_comment_temp.txt"): - os.remove("/tmp/pr_comment_temp.txt") - function_iterator_count = 0 - try: - if num_modified_functions == 0: - logging.info("No functions found to optimize. Exiting...") - return - functions_to_tests_map = discover_unit_tests(test_cfg) - for path in modified_functions: - logging.info(f"Examining file {path} ...") - # TODO: Sequence the functions one goes through intelligently. 
If we are optimizing f(g(x)), then we might want to first - # optimize f rather than g because optimizing f would already optimize g as it is a dependency - with open(path, "r") as f: - original_code = f.read() - for function_to_optimize in modified_functions[path]: - function_name = function_to_optimize.function_name - function_iterator_count += 1 - logging.info( - f"Optimizing function {function_iterator_count} of {num_modified_functions} - {function_name}" - ) - explanation_final = "" - winning_test_results = None - overall_original_test_results = None - if os.path.exists(get_run_tmp_file("test_return_values_0.bin")): - # remove left overs from previous run - os.remove(get_run_tmp_file("test_return_values_0.bin")) - if os.path.exists(get_run_tmp_file("test_return_values_0.sqlite")): - os.remove(get_run_tmp_file("test_return_values_0.sqlite")) - code_to_optimize = get_code(function_to_optimize) - if code_to_optimize is None: - logging.error("Could not find function to optimize") - continue + test_files_created = set() + test_files_to_preserve = set() + instrumented_unittests_created = set() + self.found_atleast_one_optimization = False - preexisting_functions = get_all_function_names(code_to_optimize) - - ( - code_to_optimize_with_dependents, - function_dependencies, - ) = get_constrained_function_context_and_dependent_functions( - function_to_optimize, - args.root, - code_to_optimize, - max_tokens=EXPLAIN_MODEL.max_tokens, - ) - logging.info("CODE TO OPTIMIZE %s", code_to_optimize_with_dependents) - module_path = module_name_from_file_path(path, args.root) - unique_original_test_files = set() - - if not module_path + "." + function_name in functions_to_tests_map: - logging.warning( - "Could not find any pre-existing tests for '%s', will only use generated tests.", - module_path + "." + function_name, + if os.path.exists("/tmp/pr_comment_temp.txt"): + os.remove("/tmp/pr_comment_temp.txt") + function_iterator_count = 0 + try: + if num_modified_functions == 0: + logging.info("No functions found to optimize. Exiting...") + return + function_to_tests: dict[str, list[TestsInFile]] = discover_unit_tests(self.test_cfg) + for path in file_to_funcs_to_optimize: + logging.info(f"Examining file {path} ...") + # TODO: Sequence the functions one goes through intelligently. If we are optimizing f(g(x)), then we might want to first + # optimize f rather than g because optimizing f would already optimize g as it is a dependency + with open(path, "r") as f: + original_code = f.read() + for function_to_optimize in file_to_funcs_to_optimize[path]: + function_name = function_to_optimize.function_name + function_iterator_count += 1 + logging.info( + f"Optimizing function {function_iterator_count} of {num_modified_functions} - {function_name}" ) - else: - for i, tests_in_file in enumerate( - functions_to_tests_map.get(module_path + "." 
+ function_name) - ): - if tests_in_file.test_file in unique_original_test_files: - continue - new_test_path = ( - os.path.splitext(tests_in_file.test_file)[0] - + "__perfinstrumented" - + os.path.splitext(tests_in_file.test_file)[1] - ) - injected_test = inject_profiling_into_existing_test( - tests_in_file.test_file, - function_name, - args.root, - ) - with open(new_test_path, "w") as f: - f.write(injected_test) - instrumented_unittests_created.add(new_test_path) - unique_original_test_files.add(tests_in_file.test_file) - generated_tests_path = get_test_file_path(args.test_root, function_name, 0) - test_module_path = module_name_from_file_path(generated_tests_path, args.root) - new_tests = generate_tests( - source_code_being_tested=code_to_optimize_with_dependents, - function=function_to_optimize, - module_path=module_path, - test_module_path=test_module_path, - function_dependencies=function_dependencies, - test_framework=args.test_framework, - test_timeout=INDIVIDUAL_TEST_TIMEOUT, - use_cached_tests=args.use_cached_tests, - ) - if new_tests is None: - logging.error("/!\\ NO TESTS GENERATED for %s", function_name) - continue - - test_files_created.add(generated_tests_path) - with open(generated_tests_path, "w") as file: - file.write(new_tests) - original_runtime = None - times_run = 0 - # TODO : Dynamically determine the number of times to run the tests based on the runtime of the tests. - # Keep the runtime in some acceptable range - generated_tests_elapsed_time = 0.0 - - # For the original function - run the tests and get the runtime - # TODO: Compare the function return values over the multiple runs and check if they are any different, - # if they are different, then we can't optimize this function because it is a non-deterministic function - test_env = os.environ.copy() - test_env["CODEFLASH_TEST_ITERATION"] = str(0) - for i in range(MAX_TEST_RUN_ITERATIONS): - if generated_tests_elapsed_time > MAX_FUNCTION_TEST_SECONDS: - break - instrumented_test_timing = [] - original_test_results_iter = TestResults() - for test_file in instrumented_unittests_created: - result_file_path, run_result = run_tests( - test_file, - test_framework=args.test_framework, - cwd=args.root, - pytest_timeout=INDIVIDUAL_TEST_TIMEOUT, - verbose=True, - test_env=test_env, - ) - unittest_results = parse_test_results( - test_xml_path=result_file_path, - test_py_path=test_file, - test_config=test_cfg, - test_type=TestType.EXISTING_UNIT_TEST, - run_result=run_result, - optimization_iteration=0, - ) - - for result in unittest_results: - if result.did_pass and result.runtime is None: - logging.debug( - f"Ignoring test case that passed but had no runtime -> {result.id}" - ) - - timing = sum( - [ - result.runtime - for result in unittest_results - if (result.did_pass and result.runtime is not None) - ] - ) - original_test_results_iter.merge(unittest_results) - instrumented_test_timing.append(timing) - if i == 0: - logging.info( - f"original code, existing unit test results -> {original_test_results_iter.get_test_pass_fail_report()}" - ) - start_time = time.time() - result_file_path, run_result = run_tests( - test_path=generated_tests_path, - cwd=args.root, - test_framework=args.test_framework, - test_env=test_env, - pytest_timeout=INDIVIDUAL_TEST_TIMEOUT, - ) - generated_tests_elapsed_time += time.time() - start_time - # TODO: Implement the logic to disregard the timing info of the tests that ERRORed out. That is remove test cases that failed to run. 
- original_gen_results = parse_test_results( - result_file_path, - generated_tests_path, - test_cfg, - test_type=TestType.GENERATED_REGRESSION, - run_result=run_result, - optimization_iteration=0, - ) - - if not original_gen_results and len(instrumented_test_timing) == 0: - logging.warning( - f"Couldn't run any tests for original function {function_name}. SKIPPING OPTIMIZING THIS FUNCTION." - ) - - break - # TODO: Doing a simple sum of test runtime, Improve it by looking at test by test runtime, or a better scheme - # TODO: If the runtime is None, that happens in the case where an exception is expected and is successfully - # caught by the test framework. This makes the test pass, but we can't find runtime because the exception caused - # the execution to not reach the runtime measurement part. We are currently ignoring such tests, because the performance - # for such a execution that raises an exception should not matter. - for result in original_gen_results: - if result.did_pass and result.runtime is None: - logging.debug( - f"Ignoring test case that passed but had no runtime -> {result.id}" - ) - if i == 0: - logging.info( - f"original generated tests results -> {original_gen_results.get_test_pass_fail_report()}" - ) - - original_total_runtime_iter = sum( - ( - [ - result.runtime - for result in original_gen_results - if (result.did_pass and result.runtime is not None) - ] - if original_gen_results is not None - else [] - ) - + instrumented_test_timing - ) - if original_total_runtime_iter == 0: - logging.warning( - f"The overall test runtime of the original function is 0, trying again..." - ) - logging.warning(original_gen_results.test_results) - continue - original_test_results_iter.merge(original_gen_results) - if i == 0: - logging.info( - f"Original overall test results = {original_test_results_iter.get_test_pass_fail_report_by_type()}" - ) - if original_runtime is None or original_total_runtime_iter < original_runtime: - original_runtime = best_runtime = original_total_runtime_iter - overall_original_test_results = original_test_results_iter - - times_run += 1 - - if times_run == 0: - logging.warning( - "Failed to run the tests for the original function, skipping optimization" - ) - continue - logging.info( - f"ORIGINAL CODE RUNTIME OVER {times_run} RUN{'S' if times_run > 1 else ''} = {original_runtime}ns" - ) - logging.info("OPTIMIZING CODE....") - # TODO: Postprocess the optimized function to include the original docstring and such - optimizations = optimize_python_code( - code_to_optimize_with_dependents, n=N_CANDIDATES - ) - best_optimization = [] - for i, (optimized_code, explanation) in enumerate(optimizations): - j = i + 1 - if optimized_code is None: - continue - if os.path.exists(get_run_tmp_file(f"test_return_values_{j}.bin")): + explanation_final = "" + winning_test_results = None + overall_original_test_results = None + if os.path.exists(get_run_tmp_file("test_return_values_0.bin")): # remove left overs from previous run - os.remove(get_run_tmp_file(f"test_return_values_{j}.bin")) - if os.path.exists(get_run_tmp_file(f"test_return_values_{j}.sqlite")): - os.remove(get_run_tmp_file(f"test_return_values_{j}.sqlite")) - logging.info(f"optimized_candidate {optimized_code}") - try: - new_code = replace_function_in_file( - path, - function_name, - optimized_code, - preexisting_functions, - # test_cfg.project_root_path, - # function_dependencies, - ) - except ( - ValueError, - SyntaxError, - cst.ParserSyntaxError, - AttributeError, - ) as e: - logging.error(e) + 
os.remove(get_run_tmp_file("test_return_values_0.bin")) + if os.path.exists(get_run_tmp_file("test_return_values_0.sqlite")): + os.remove(get_run_tmp_file("test_return_values_0.sqlite")) + code_to_optimize = get_code(function_to_optimize) + if code_to_optimize is None: + logging.error("Could not find function to optimize") continue - with open(path, "w") as f: - f.write(new_code) - all_test_times = [] - equal_results = True + + preexisting_functions = get_all_function_names(code_to_optimize) + + ( + code_to_optimize_with_dependents, + function_dependencies, + ) = get_constrained_function_context_and_dependent_functions( + function_to_optimize, + self.args.root, + code_to_optimize, + max_tokens=EXPLAIN_MODEL.max_tokens, + ) + logging.info("CODE TO OPTIMIZE %s", code_to_optimize_with_dependents) + module_path = module_name_from_file_path(path, self.args.root) + unique_original_test_files = set() + + full_module_function_path = module_path + "." + function_name + if full_module_function_path not in function_to_tests: + logging.warning( + "Could not find any pre-existing tests for '%s', will only use generated tests.", + full_module_function_path, + ) + else: + for tests_in_file in function_to_tests.get(full_module_function_path): + if tests_in_file.test_file in unique_original_test_files: + continue + injected_test = inject_profiling_into_existing_test( + tests_in_file.test_file, + function_name, + self.args.root, + ) + new_test_path = ( + os.path.splitext(tests_in_file.test_file)[0] + + "__perfinstrumented" + + os.path.splitext(tests_in_file.test_file)[1] + ) + with open(new_test_path, "w") as f: + f.write(injected_test) + instrumented_unittests_created.add(new_test_path) + unique_original_test_files.add(tests_in_file.test_file) + + generated_tests_path = self.generate_test_files( + function_to_optimize, + function_dependencies, + code_to_optimize_with_dependents, + module_path, + ) + + test_files_created.add(generated_tests_path) + original_runtime = None + times_run = 0 + # TODO : Dynamically determine the number of times to run the tests based on the runtime of the tests. 
+ # Keep the runtime in some acceptable range generated_tests_elapsed_time = 0.0 - times_run = 0 + # For the original function - run the tests and get the runtime + # TODO: Compare the function return values over the multiple runs and check if they are any different, + # if they are different, then we can't optimize this function because it is a non-deterministic function test_env = os.environ.copy() - test_env["CODEFLASH_TEST_ITERATION"] = str(j) - for test_index in range(MAX_TEST_RUN_ITERATIONS): - if os.path.exists(get_run_tmp_file(f"test_return_values_{j}.bin")): - os.remove(get_run_tmp_file(f"test_return_values_{j}.bin")) - if os.path.exists(get_run_tmp_file(f"test_return_values_{j}.sqlite")): - os.remove(get_run_tmp_file(f"test_return_values_{j}.sqlite")) + test_env["CODEFLASH_TEST_ITERATION"] = str(0) + for i in range(MAX_TEST_RUN_ITERATIONS): if generated_tests_elapsed_time > MAX_FUNCTION_TEST_SECONDS: break - - optimized_test_results_iter = TestResults() instrumented_test_timing = [] - for instrumented_test_file in instrumented_unittests_created: + original_test_results_iter = TestResults() + for test_file in instrumented_unittests_created: result_file_path, run_result = run_tests( - instrumented_test_file, - test_framework=args.test_framework, - cwd=args.root, + test_file, + test_framework=self.args.test_framework, + cwd=self.args.root, pytest_timeout=INDIVIDUAL_TEST_TIMEOUT, verbose=True, test_env=test_env, ) - - unittest_results_optimized = parse_test_results( + unittest_results = parse_test_results( test_xml_path=result_file_path, - test_py_path=instrumented_test_file, - test_config=test_cfg, + test_py_path=test_file, + test_config=self.test_cfg, test_type=TestType.EXISTING_UNIT_TEST, run_result=run_result, - optimization_iteration=j, + optimization_iteration=0, ) - for result in unittest_results_optimized: + for result in unittest_results: if result.did_pass and result.runtime is None: logging.debug( f"Ignoring test case that passed but had no runtime -> {result.id}" @@ -435,182 +263,380 @@ def main(): timing = sum( [ result.runtime - for result in unittest_results_optimized + for result in unittest_results if (result.did_pass and result.runtime is not None) ] ) - optimized_test_results_iter.merge(unittest_results_optimized) + original_test_results_iter.merge(unittest_results) instrumented_test_timing.append(timing) - if test_index == 0: - equal_results = True + if i == 0: logging.info( - f"optimized existing unit tests result -> {optimized_test_results_iter.get_test_pass_fail_report()}" + f"original code, existing unit test results -> {original_test_results_iter.get_test_pass_fail_report()}" ) - for test_invocation in optimized_test_results_iter: - if ( - overall_original_test_results.get_by_id(test_invocation.id) - is None - or test_invocation.did_pass - != overall_original_test_results.get_by_id( - test_invocation.id - ).did_pass - ): - logging.info("RESULTS DID NOT MATCH") - logging.info( - f"Test {test_invocation.id} failed on the optimized code. 
Skipping this optimization" - ) - equal_results = False - break - if not equal_results: - break - start_time = time.time() result_file_path, run_result = run_tests( test_path=generated_tests_path, - test_framework=args.test_framework, - cwd=args.root, + cwd=self.args.root, + test_framework=self.args.test_framework, test_env=test_env, pytest_timeout=INDIVIDUAL_TEST_TIMEOUT, ) generated_tests_elapsed_time += time.time() - start_time - test_results = parse_test_results( - test_xml_path=result_file_path, - test_py_path=generated_tests_path, - optimization_iteration=j, + # TODO: Implement the logic to disregard the timing info of the tests that ERRORed out. That is remove test cases that failed to run. + original_gen_results = parse_test_results( + result_file_path, + generated_tests_path, + self.test_cfg, test_type=TestType.GENERATED_REGRESSION, - test_config=test_cfg, run_result=run_result, + optimization_iteration=0, ) - if test_index == 0: - logging.info( - f"generated test_results optimized -> {test_results.get_test_pass_fail_report()}" + + if not original_gen_results and len(instrumented_test_timing) == 0: + logging.warning( + f"Couldn't run any tests for original function {function_name}. SKIPPING OPTIMIZING THIS FUNCTION." ) - if test_results: - if compare_results(original_gen_results, test_results): - equal_results = True - logging.info("RESULTS MATCHED!") - else: - logging.info("RESULTS DID NOT MATCH") - equal_results = False - if not equal_results: + break - for result in test_results: + # TODO: Doing a simple sum of test runtime, Improve it by looking at test by test runtime, or a better scheme + # TODO: If the runtime is None, that happens in the case where an exception is expected and is successfully + # caught by the test framework. This makes the test pass, but we can't find runtime because the exception caused + # the execution to not reach the runtime measurement part. We are currently ignoring such tests, because the performance + # for such a execution that raises an exception should not matter. + for result in original_gen_results: if result.did_pass and result.runtime is None: logging.debug( f"Ignoring test case that passed but had no runtime -> {result.id}" ) - test_runtime = sum( + if i == 0: + logging.info( + f"original generated tests results -> {original_gen_results.get_test_pass_fail_report()}" + ) + + original_total_runtime_iter = sum( ( [ result.runtime - for result in test_results + for result in original_gen_results if (result.did_pass and result.runtime is not None) ] - if test_results is not None + if original_gen_results is not None else [] ) + instrumented_test_timing ) - if test_runtime == 0: + if original_total_runtime_iter == 0: logging.warning( - f"The overall test runtime of the optimized function is 0, trying again..." + f"The overall test runtime of the original function is 0, trying again..." 
) + logging.warning(original_gen_results.test_results) continue - all_test_times.append(test_runtime) - optimized_test_results_iter.merge(test_results) - times_run += 1 - if os.path.exists(get_run_tmp_file(f"test_return_values_{j}.bin")): - os.remove(get_run_tmp_file(f"test_return_values_{j}.bin")) - if os.path.exists(get_run_tmp_file(f"test_return_values_{j}.sqlite")): - os.remove(get_run_tmp_file(f"test_return_values_{j}.sqlite")) - if equal_results and times_run > 0: - # TODO: Make the runtime more human readable by using humanize - new_test_time = min(all_test_times) - logging.info( - f"NEW CODE RUNTIME OVER {times_run} RUN{'S' if times_run > 1 else ''} = {new_test_time}ns, SPEEDUP RATIO = {((original_runtime - new_test_time) / new_test_time):.3f}" - ) - if ( - ((original_runtime - new_test_time) / new_test_time) - > MIN_IMPROVEMENT_THRESHOLD - ) and new_test_time < best_runtime: - logging.info("THIS IS BETTER!") + original_test_results_iter.merge(original_gen_results) + if i == 0: logging.info( - f"original_test_time={original_runtime} new_test_time={new_test_time}, FASTER RATIO = {((original_runtime - new_test_time) / new_test_time)}" + f"Original overall test results = {original_test_results_iter.get_test_pass_fail_report_by_type()}" ) - best_optimization = [optimized_code, explanation] - best_runtime = new_test_time - winning_test_results = optimized_test_results_iter - with open(path, "w") as f: - f.write(original_code) - logging.info("----------------") - logging.info(f"BEST OPTIMIZATION {best_optimization}") - if best_optimization: - found_atleast_one_optimization = True - logging.info(f"BEST OPTIMIZED CODE {best_optimization[0]}") - if not args.all: - new_code = replace_function_in_file( - path, - function_name, - best_optimization[0], - preexisting_functions, - # test_cfg.project_root_path, - # function_dependencies, + if ( + original_runtime is None + or original_total_runtime_iter < original_runtime + ): + original_runtime = best_runtime = original_total_runtime_iter + overall_original_test_results = original_test_results_iter + + times_run += 1 + + if times_run == 0: + logging.warning( + "Failed to run the tests for the original function, skipping optimization" ) + continue + logging.info( + f"ORIGINAL CODE RUNTIME OVER {times_run} RUN{'S' if times_run > 1 else ''} = {original_runtime}ns" + ) + logging.info("OPTIMIZING CODE....") + # TODO: Postprocess the optimized function to include the original docstring and such + optimizations = optimize_python_code( + code_to_optimize_with_dependents, n=N_CANDIDATES + ) + best_optimization = [] + for i, (optimized_code, explanation) in enumerate(optimizations): + j = i + 1 + if optimized_code is None: + continue + if os.path.exists(get_run_tmp_file(f"test_return_values_{j}.bin")): + # remove left overs from previous run + os.remove(get_run_tmp_file(f"test_return_values_{j}.bin")) + if os.path.exists(get_run_tmp_file(f"test_return_values_{j}.sqlite")): + os.remove(get_run_tmp_file(f"test_return_values_{j}.sqlite")) + logging.info(f"optimized_candidate {optimized_code}") + try: + new_code = replace_function_in_file( + path, + function_name, + optimized_code, + preexisting_functions, + # test_cfg.project_root_path, + # function_dependencies, + ) + except ( + ValueError, + SyntaxError, + cst.ParserSyntaxError, + AttributeError, + ) as e: + logging.error(e) + continue with open(path, "w") as f: f.write(new_code) - # TODO: After doing the best optimization, remove the test cases that errored on the new code, because they might be failing 
because of syntax errors and such. - speedup = (original_runtime / best_runtime) - 1 - # TODO: Sometimes the explanation says something similar to "This is the code that was optimized", remove such parts - # TODO: Use python package humanize to make the runtime more human readable - explanation_final += ( - f"Function {function_name} in file {path}:\n" - f"Performance went up by {speedup:.2f}x ({speedup * 100:.2f}%). Runtime went down from {(original_runtime / 1000):.2f}μs to {(best_runtime / 1000):.2f}μs \n\n" - + "Optimization explanation:\n" - + best_optimization[1] - + " \n\n" - + "The code has been tested for correctness\n" - + f"Test Result for the best optimized code:- {winning_test_results.get_test_pass_fail_report_by_type()}\n" - ) - with open("/tmp/pr_comment_temp.txt", "a") as f: - f.write(explanation_final) - logging.info(f"EXPLANATION_FINAL {explanation_final}") - if args.all: - with open("optimizations_all.txt", "a") as f: - f.write(best_optimization[0]) - f.write("\n\n") + all_test_times = [] + equal_results = True + generated_tests_elapsed_time = 0.0 + + times_run = 0 + test_env = os.environ.copy() + test_env["CODEFLASH_TEST_ITERATION"] = str(j) + for test_index in range(MAX_TEST_RUN_ITERATIONS): + if os.path.exists(get_run_tmp_file(f"test_return_values_{j}.bin")): + os.remove(get_run_tmp_file(f"test_return_values_{j}.bin")) + if os.path.exists(get_run_tmp_file(f"test_return_values_{j}.sqlite")): + os.remove(get_run_tmp_file(f"test_return_values_{j}.sqlite")) + if generated_tests_elapsed_time > MAX_FUNCTION_TEST_SECONDS: + break + + optimized_test_results_iter = TestResults() + instrumented_test_timing = [] + for instrumented_test_file in instrumented_unittests_created: + result_file_path, run_result = run_tests( + instrumented_test_file, + test_framework=self.args.test_framework, + cwd=self.args.root, + pytest_timeout=INDIVIDUAL_TEST_TIMEOUT, + verbose=True, + test_env=test_env, + ) + + unittest_results_optimized = parse_test_results( + test_xml_path=result_file_path, + test_py_path=instrumented_test_file, + test_config=self.test_cfg, + test_type=TestType.EXISTING_UNIT_TEST, + run_result=run_result, + optimization_iteration=j, + ) + + for result in unittest_results_optimized: + if result.did_pass and result.runtime is None: + logging.debug( + f"Ignoring test case that passed but had no runtime -> {result.id}" + ) + + timing = sum( + [ + result.runtime + for result in unittest_results_optimized + if (result.did_pass and result.runtime is not None) + ] + ) + optimized_test_results_iter.merge(unittest_results_optimized) + instrumented_test_timing.append(timing) + if test_index == 0: + equal_results = True + logging.info( + f"optimized existing unit tests result -> {optimized_test_results_iter.get_test_pass_fail_report()}" + ) + for test_invocation in optimized_test_results_iter: + if ( + overall_original_test_results.get_by_id(test_invocation.id) + is None + or test_invocation.did_pass + != overall_original_test_results.get_by_id( + test_invocation.id + ).did_pass + ): + logging.info("RESULTS DID NOT MATCH") + logging.info( + f"Test {test_invocation.id} failed on the optimized code. 
Skipping this optimization" + ) + equal_results = False + break + if not equal_results: + break + + start_time = time.time() + result_file_path, run_result = run_tests( + test_path=generated_tests_path, + test_framework=self.args.test_framework, + cwd=self.args.root, + test_env=test_env, + pytest_timeout=INDIVIDUAL_TEST_TIMEOUT, + ) + generated_tests_elapsed_time += time.time() - start_time + test_results = parse_test_results( + test_xml_path=result_file_path, + test_py_path=generated_tests_path, + optimization_iteration=j, + test_type=TestType.GENERATED_REGRESSION, + test_config=self.test_cfg, + run_result=run_result, + ) + if test_index == 0: + logging.info( + f"generated test_results optimized -> {test_results.get_test_pass_fail_report()}" + ) + if test_results: + if compare_results(original_gen_results, test_results): + equal_results = True + logging.info("RESULTS MATCHED!") + else: + logging.info("RESULTS DID NOT MATCH") + equal_results = False + if not equal_results: + break + for result in test_results: + if result.did_pass and result.runtime is None: + logging.debug( + f"Ignoring test case that passed but had no runtime -> {result.id}" + ) + test_runtime = sum( + ( + [ + result.runtime + for result in test_results + if (result.did_pass and result.runtime is not None) + ] + if test_results is not None + else [] + ) + + instrumented_test_timing + ) + if test_runtime == 0: + logging.warning( + f"The overall test runtime of the optimized function is 0, trying again..." + ) + continue + all_test_times.append(test_runtime) + optimized_test_results_iter.merge(test_results) + times_run += 1 + if os.path.exists(get_run_tmp_file(f"test_return_values_{j}.bin")): + os.remove(get_run_tmp_file(f"test_return_values_{j}.bin")) + if os.path.exists(get_run_tmp_file(f"test_return_values_{j}.sqlite")): + os.remove(get_run_tmp_file(f"test_return_values_{j}.sqlite")) + if equal_results and times_run > 0: + # TODO: Make the runtime more human readable by using humanize + new_test_time = min(all_test_times) + logging.info( + f"NEW CODE RUNTIME OVER {times_run} RUN{'S' if times_run > 1 else ''} = {new_test_time}ns, SPEEDUP RATIO = {((original_runtime - new_test_time) / new_test_time):.3f}" + ) + if ( + ((original_runtime - new_test_time) / new_test_time) + > MIN_IMPROVEMENT_THRESHOLD + ) and new_test_time < best_runtime: + logging.info("THIS IS BETTER!") + logging.info( + f"original_test_time={original_runtime} new_test_time={new_test_time}, FASTER RATIO = {((original_runtime - new_test_time) / new_test_time)}" + ) + best_optimization = [optimized_code, explanation] + best_runtime = new_test_time + winning_test_results = optimized_test_results_iter + with open(path, "w") as f: + f.write(original_code) + logging.info("----------------") + logging.info(f"BEST OPTIMIZATION {best_optimization}") + if best_optimization: + self.found_atleast_one_optimization = True + logging.info(f"BEST OPTIMIZED CODE {best_optimization[0]}") + if not self.args.all: + new_code = replace_function_in_file( + path, + function_name, + best_optimization[0], + preexisting_functions, + # test_cfg.project_root_path, + # function_dependencies, + ) + with open(path, "w") as f: + f.write(new_code) + # TODO: After doing the best optimization, remove the test cases that errored on the new code, because they might be failing because of syntax errors and such. 
+ speedup = (original_runtime / best_runtime) - 1 + # TODO: Sometimes the explanation says something similar to "This is the code that was optimized", remove such parts + # TODO: Use python package humanize to make the runtime more human readable + explanation_final += ( + f"Function {function_name} in file {path}:\n" + f"Performance went up by {speedup:.2f}x ({speedup * 100:.2f}%). Runtime went down from {(original_runtime / 1000):.2f}μs to {(best_runtime / 1000):.2f}μs \n\n" + + "Optimization explanation:\n" + + best_optimization[1] + + " \n\n" + + "The code has been tested for correctness\n" + + f"Test Result for the best optimized code:- {winning_test_results.get_test_pass_fail_report_by_type()}\n" + ) + with open("/tmp/pr_comment_temp.txt", "a") as f: f.write(explanation_final) - f.write("\n---------\n") + logging.info(f"EXPLANATION_FINAL {explanation_final}") + if self.args.all: + with open("optimizations_all.txt", "a") as f: + f.write(best_optimization[0]) + f.write("\n\n") + f.write(explanation_final) + f.write("\n---------\n") - subprocess.run(["black", path], stdout=subprocess.PIPE) - test_files_to_preserve.add(generated_tests_path) - try: - with open(os.environ["GITHUB_OUTPUT"], "w") as fh: - print("optimization_success=truee", file=fh) - except KeyError: - os.environ["GITHUB_OUTPUT"] = "optimization_success=truee" - else: - # Delete it here to not cause a lot of clutter if we are optimizing with --all option - if os.path.exists(generated_tests_path): - os.remove(generated_tests_path) - if not found_atleast_one_optimization: - try: - with open(os.environ["GITHUB_OUTPUT"], "w") as fh: - print("optimization_success=falsee", file=fh) - except KeyError: - os.environ["GITHUB_OUTPUT"] = "optimization_success=falsee" + subprocess.run(["black", path], stdout=subprocess.PIPE) + test_files_to_preserve.add(generated_tests_path) + try: + with open(os.environ["GITHUB_OUTPUT"], "w") as fh: + print("optimization_success=truee", file=fh) + except KeyError: + os.environ["GITHUB_OUTPUT"] = "optimization_success=truee" + else: + # Delete it here to not cause a lot of clutter if we are optimizing with --all option + if os.path.exists(generated_tests_path): + os.remove(generated_tests_path) + if not self.found_atleast_one_optimization: + try: + with open(os.environ["GITHUB_OUTPUT"], "w") as fh: + print("optimization_success=falsee", file=fh) + except KeyError: + os.environ["GITHUB_OUTPUT"] = "optimization_success=falsee" - finally: - # TODO: Also revert the file/function being optimized if the process did not succeed - for test_file in instrumented_unittests_created: - if os.path.exists(test_file): - os.remove(test_file) - for test_file in test_files_created: - if test_file not in test_files_to_preserve: + finally: + # TODO: Also revert the file/function being optimized if the process did not succeed + for test_file in instrumented_unittests_created: if os.path.exists(test_file): os.remove(test_file) - if hasattr(get_run_tmp_file, "tmpdir"): - get_run_tmp_file.tmpdir.cleanup() + for test_file in test_files_created: + if test_file not in test_files_to_preserve: + if os.path.exists(test_file): + os.remove(test_file) + if hasattr(get_run_tmp_file, "tmpdir"): + get_run_tmp_file.tmpdir.cleanup() + + def generate_test_files( + self, + function_to_optimize: FunctionToOptimize, + function_dependencies: List[Source], + code_to_optimize_with_dependents: str, + module_path: str, + ) -> str | None: + generated_tests_path = get_test_file_path( + self.args.test_root, function_to_optimize.function_name, 0 + ) + 
+        test_module_path = module_name_from_file_path(generated_tests_path, self.args.root)
+        new_tests = generate_tests(
+            source_code_being_tested=code_to_optimize_with_dependents,
+            function=function_to_optimize,
+            module_path=module_path,
+            test_module_path=test_module_path,
+            function_dependencies=function_dependencies,
+            test_framework=self.args.test_framework,
+            test_timeout=INDIVIDUAL_TEST_TIMEOUT,
+            use_cached_tests=self.args.use_cached_tests,
+        )
+        if new_tests is None:
+            logging.error("/!\\ NO TESTS GENERATED for %s", function_to_optimize.function_name)
+            return None
+        with open(generated_tests_path, "w") as file:
+            file.write(new_tests)
+        return generated_tests_path


 if __name__ == "__main__":
-    main()
+    Optimizer(parse_args()).run()
diff --git a/codeflash/verification/verification_utils.py b/codeflash/verification/verification_utils.py
index e69b748e9..02299a37d 100644
--- a/codeflash/verification/verification_utils.py
+++ b/codeflash/verification/verification_utils.py
@@ -1,8 +1,8 @@
-import os
 import ast
+import os


-def get_test_file_path(test_dir, function_name, iteration=0, test_type="unit"):
+def get_test_file_path(test_dir: str, function_name: str, iteration: int = 0, test_type: str = "unit") -> str:
     assert test_type in ["unit", "inspired", "replay"]
     function_name = function_name.replace(".", "_")
     path = os.path.join(test_dir, f"test_{function_name}__{test_type}_test_{iteration}.py")
@@ -11,7 +11,7 @@ def get_test_file_path(test_dir, function_name, iteration=0, test_type="unit"):
     return path


-def delete_multiple_if_name_main(test_ast):
+def delete_multiple_if_name_main(test_ast: ast.Module) -> ast.Module:
     if_indexes = []
     for index, node in enumerate(test_ast.body):
         if isinstance(node, ast.If):