diff --git a/.idea/codeflash.iml b/.idea/codeflash.iml
index 1cbadf231..0b638381e 100644
--- a/.idea/codeflash.iml
+++ b/.idea/codeflash.iml
@@ -7,8 +7,7 @@
-
-
+
\ No newline at end of file
diff --git a/.idea/inspectionProfiles/Project_Default.xml b/.idea/inspectionProfiles/Project_Default.xml
index 03d9549ea..a58fa0a6f 100644
--- a/.idea/inspectionProfiles/Project_Default.xml
+++ b/.idea/inspectionProfiles/Project_Default.xml
@@ -2,5 +2,6 @@
+
\ No newline at end of file
diff --git a/.idea/inspectionProfiles/profiles_settings.xml b/.idea/inspectionProfiles/profiles_settings.xml
index ad1389744..5e2daa8c5 100644
--- a/.idea/inspectionProfiles/profiles_settings.xml
+++ b/.idea/inspectionProfiles/profiles_settings.xml
@@ -1,9 +1,10 @@
+
+
-
@@ -23,8 +24,8 @@
-
-
+
+
\ No newline at end of file
diff --git a/.idea/misc.xml b/.idea/misc.xml
index 8b442f7e5..9e5b17a3a 100644
--- a/.idea/misc.xml
+++ b/.idea/misc.xml
@@ -3,8 +3,7 @@
-
-
+
-
+
\ No newline at end of file
diff --git a/.idea/pydantic.xml b/.idea/pydantic.xml
new file mode 100644
index 000000000..e6e3ec67b
--- /dev/null
+++ b/.idea/pydantic.xml
@@ -0,0 +1,6 @@
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/sqldialects.xml b/.idea/sqldialects.xml
index 2c3db7023..c0e01cabc 100644
--- a/.idea/sqldialects.xml
+++ b/.idea/sqldialects.xml
@@ -1,7 +1,6 @@
-
\ No newline at end of file
diff --git a/codeflash/main.py b/codeflash/main.py
index 57ef8a468..2038dea38 100644
--- a/codeflash/main.py
+++ b/codeflash/main.py
@@ -1,4 +1,5 @@
import logging
+from typing import Dict, List, Optional
from codeflash.code_utils import env_utils
from codeflash.verification import EXPLAIN_MODEL
@@ -7,7 +8,7 @@ logging.basicConfig(level=logging.INFO)
import os
import subprocess
import time
-from argparse import ArgumentParser, SUPPRESS
+from argparse import ArgumentParser, SUPPRESS, Namespace
import libcst as cst
@@ -19,12 +20,16 @@ from codeflash.code_utils.code_utils import (
get_run_tmp_file,
)
from codeflash.code_utils.config_parser import parse_config_file
-from codeflash.discovery.discover_unit_tests import discover_unit_tests
-from codeflash.discovery.functions_to_optimize import get_functions_to_optimize_by_file
+from codeflash.discovery.discover_unit_tests import discover_unit_tests, TestsInFile
+from codeflash.discovery.functions_to_optimize import (
+ get_functions_to_optimize_by_file,
+ FunctionToOptimize,
+)
from codeflash.instrumentation.instrument_existing_tests import inject_profiling_into_existing_test
from codeflash.models import TestConfig
from codeflash.optimization.function_context import (
get_constrained_function_context_and_dependent_functions,
+ Source,
)
from codeflash.optimization.optimizer import optimize_python_code
from codeflash.verification.equivalence import compare_results
@@ -40,7 +45,7 @@ from codeflash.verification.verification_utils import (
from codeflash.verification.verifier import generate_tests
-def parse_args():
+def parse_args() -> Namespace:
parser = ArgumentParser()
parser.add_argument("--file", help="Try to optimize only this file")
parser.add_argument(
@@ -76,7 +81,7 @@ def parse_args():
help="Use cached tests from a specified file for debugging.",
)
parser.add_argument("-v", "--verbose", action="store_true", help="Print verbose logs")
- args = parser.parse_args()
+ args: Namespace = parser.parse_args()
if args.verbose:
logging.basicConfig(level=logging.DEBUG)
@@ -89,9 +94,7 @@ def parse_args():
if key in pyproject_config and getattr(args, key.replace("-", "_")) is None:
setattr(args, key.replace("-", "_"), pyproject_config[key])
assert os.path.isdir(args.root), "--root must be a valid directory"
- assert os.path.isdir(
- args.test_root
- ), f"--test-root must be a valid directory; {args.test_root} is not a directory"
+    assert os.path.isdir(args.test_root), "--test-root must be a valid directory"
args.root = os.path.realpath(args.root)
args.test_root = os.path.realpath(args.test_root)
if not hasattr(args, "all"):
@@ -111,322 +114,150 @@ N_CANDIDATES = 10
MIN_IMPROVEMENT_THRESHOLD = 0.05
-def main():
- logging.info("RUNNING THE OPTIMIZER")
- args = parse_args()
- env_utils.ensure_codeflash_api_key()
- test_cfg = TestConfig(
- test_root=args.test_root,
- project_root_path=args.root,
- test_framework=args.test_framework,
- )
+class Optimizer:
+ def __init__(self, args: Namespace):
+ self.args = args
+ self.test_cfg = TestConfig(
+ test_root=args.test_root,
+ project_root_path=args.root,
+ test_framework=args.test_framework,
+ )
- modified_functions, num_modified_functions = get_functions_to_optimize_by_file(
- optimize_all=args.all,
- file=args.file,
- function=args.function,
- test_cfg=test_cfg,
- )
+ def run(self):
+ logging.info("RUNNING THE OPTIMIZER")
+ env_utils.ensure_codeflash_api_key()
- test_files_created = set()
- test_files_to_preserve = set()
- instrumented_unittests_created = set()
- found_atleast_one_optimization = False
+ file_to_funcs_to_optimize, num_modified_functions = get_functions_to_optimize_by_file(
+ optimize_all=self.args.all,
+ file=self.args.file,
+ function=self.args.function,
+ test_cfg=self.test_cfg,
+ )
- if os.path.exists("/tmp/pr_comment_temp.txt"):
- os.remove("/tmp/pr_comment_temp.txt")
- function_iterator_count = 0
- try:
- if num_modified_functions == 0:
- logging.info("No functions found to optimize. Exiting...")
- return
- functions_to_tests_map = discover_unit_tests(test_cfg)
- for path in modified_functions:
- logging.info(f"Examining file {path} ...")
- # TODO: Sequence the functions one goes through intelligently. If we are optimizing f(g(x)), then we might want to first
- # optimize f rather than g because optimizing f would already optimize g as it is a dependency
- with open(path, "r") as f:
- original_code = f.read()
- for function_to_optimize in modified_functions[path]:
- function_name = function_to_optimize.function_name
- function_iterator_count += 1
- logging.info(
- f"Optimizing function {function_iterator_count} of {num_modified_functions} - {function_name}"
- )
- explanation_final = ""
- winning_test_results = None
- overall_original_test_results = None
- if os.path.exists(get_run_tmp_file("test_return_values_0.bin")):
- # remove left overs from previous run
- os.remove(get_run_tmp_file("test_return_values_0.bin"))
- if os.path.exists(get_run_tmp_file("test_return_values_0.sqlite")):
- os.remove(get_run_tmp_file("test_return_values_0.sqlite"))
- code_to_optimize = get_code(function_to_optimize)
- if code_to_optimize is None:
- logging.error("Could not find function to optimize")
- continue
+ test_files_created = set()
+ test_files_to_preserve = set()
+ instrumented_unittests_created = set()
+        self.found_at_least_one_optimization = False
- preexisting_functions = get_all_function_names(code_to_optimize)
-
- (
- code_to_optimize_with_dependents,
- function_dependencies,
- ) = get_constrained_function_context_and_dependent_functions(
- function_to_optimize,
- args.root,
- code_to_optimize,
- max_tokens=EXPLAIN_MODEL.max_tokens,
- )
- logging.info("CODE TO OPTIMIZE %s", code_to_optimize_with_dependents)
- module_path = module_name_from_file_path(path, args.root)
- unique_original_test_files = set()
-
- if not module_path + "." + function_name in functions_to_tests_map:
- logging.warning(
- "Could not find any pre-existing tests for '%s', will only use generated tests.",
- module_path + "." + function_name,
+ if os.path.exists("/tmp/pr_comment_temp.txt"):
+ os.remove("/tmp/pr_comment_temp.txt")
+ function_iterator_count = 0
+ try:
+ if num_modified_functions == 0:
+ logging.info("No functions found to optimize. Exiting...")
+ return
+            function_to_tests: Dict[str, List[TestsInFile]] = discover_unit_tests(self.test_cfg)
+ for path in file_to_funcs_to_optimize:
+ logging.info(f"Examining file {path} ...")
+                # TODO: Intelligently sequence the order in which functions are optimized. If we are optimizing f(g(x)), we might want
+                # to optimize f first rather than g, because optimizing f would already optimize g, as g is a dependency of f
+ with open(path, "r") as f:
+ original_code = f.read()
+ for function_to_optimize in file_to_funcs_to_optimize[path]:
+ function_name = function_to_optimize.function_name
+ function_iterator_count += 1
+ logging.info(
+ f"Optimizing function {function_iterator_count} of {num_modified_functions} - {function_name}"
)
- else:
- for i, tests_in_file in enumerate(
- functions_to_tests_map.get(module_path + "." + function_name)
- ):
- if tests_in_file.test_file in unique_original_test_files:
- continue
- new_test_path = (
- os.path.splitext(tests_in_file.test_file)[0]
- + "__perfinstrumented"
- + os.path.splitext(tests_in_file.test_file)[1]
- )
- injected_test = inject_profiling_into_existing_test(
- tests_in_file.test_file,
- function_name,
- args.root,
- )
- with open(new_test_path, "w") as f:
- f.write(injected_test)
- instrumented_unittests_created.add(new_test_path)
- unique_original_test_files.add(tests_in_file.test_file)
- generated_tests_path = get_test_file_path(args.test_root, function_name, 0)
- test_module_path = module_name_from_file_path(generated_tests_path, args.root)
- new_tests = generate_tests(
- source_code_being_tested=code_to_optimize_with_dependents,
- function=function_to_optimize,
- module_path=module_path,
- test_module_path=test_module_path,
- function_dependencies=function_dependencies,
- test_framework=args.test_framework,
- test_timeout=INDIVIDUAL_TEST_TIMEOUT,
- use_cached_tests=args.use_cached_tests,
- )
- if new_tests is None:
- logging.error("/!\\ NO TESTS GENERATED for %s", function_name)
- continue
-
- test_files_created.add(generated_tests_path)
- with open(generated_tests_path, "w") as file:
- file.write(new_tests)
- original_runtime = None
- times_run = 0
- # TODO : Dynamically determine the number of times to run the tests based on the runtime of the tests.
- # Keep the runtime in some acceptable range
- generated_tests_elapsed_time = 0.0
-
- # For the original function - run the tests and get the runtime
- # TODO: Compare the function return values over the multiple runs and check if they are any different,
- # if they are different, then we can't optimize this function because it is a non-deterministic function
- test_env = os.environ.copy()
- test_env["CODEFLASH_TEST_ITERATION"] = str(0)
- for i in range(MAX_TEST_RUN_ITERATIONS):
- if generated_tests_elapsed_time > MAX_FUNCTION_TEST_SECONDS:
- break
- instrumented_test_timing = []
- original_test_results_iter = TestResults()
- for test_file in instrumented_unittests_created:
- result_file_path, run_result = run_tests(
- test_file,
- test_framework=args.test_framework,
- cwd=args.root,
- pytest_timeout=INDIVIDUAL_TEST_TIMEOUT,
- verbose=True,
- test_env=test_env,
- )
- unittest_results = parse_test_results(
- test_xml_path=result_file_path,
- test_py_path=test_file,
- test_config=test_cfg,
- test_type=TestType.EXISTING_UNIT_TEST,
- run_result=run_result,
- optimization_iteration=0,
- )
-
- for result in unittest_results:
- if result.did_pass and result.runtime is None:
- logging.debug(
- f"Ignoring test case that passed but had no runtime -> {result.id}"
- )
-
- timing = sum(
- [
- result.runtime
- for result in unittest_results
- if (result.did_pass and result.runtime is not None)
- ]
- )
- original_test_results_iter.merge(unittest_results)
- instrumented_test_timing.append(timing)
- if i == 0:
- logging.info(
- f"original code, existing unit test results -> {original_test_results_iter.get_test_pass_fail_report()}"
- )
- start_time = time.time()
- result_file_path, run_result = run_tests(
- test_path=generated_tests_path,
- cwd=args.root,
- test_framework=args.test_framework,
- test_env=test_env,
- pytest_timeout=INDIVIDUAL_TEST_TIMEOUT,
- )
- generated_tests_elapsed_time += time.time() - start_time
- # TODO: Implement the logic to disregard the timing info of the tests that ERRORed out. That is remove test cases that failed to run.
- original_gen_results = parse_test_results(
- result_file_path,
- generated_tests_path,
- test_cfg,
- test_type=TestType.GENERATED_REGRESSION,
- run_result=run_result,
- optimization_iteration=0,
- )
-
- if not original_gen_results and len(instrumented_test_timing) == 0:
- logging.warning(
- f"Couldn't run any tests for original function {function_name}. SKIPPING OPTIMIZING THIS FUNCTION."
- )
-
- break
- # TODO: Doing a simple sum of test runtime, Improve it by looking at test by test runtime, or a better scheme
- # TODO: If the runtime is None, that happens in the case where an exception is expected and is successfully
- # caught by the test framework. This makes the test pass, but we can't find runtime because the exception caused
- # the execution to not reach the runtime measurement part. We are currently ignoring such tests, because the performance
- # for such a execution that raises an exception should not matter.
- for result in original_gen_results:
- if result.did_pass and result.runtime is None:
- logging.debug(
- f"Ignoring test case that passed but had no runtime -> {result.id}"
- )
- if i == 0:
- logging.info(
- f"original generated tests results -> {original_gen_results.get_test_pass_fail_report()}"
- )
-
- original_total_runtime_iter = sum(
- (
- [
- result.runtime
- for result in original_gen_results
- if (result.did_pass and result.runtime is not None)
- ]
- if original_gen_results is not None
- else []
- )
- + instrumented_test_timing
- )
- if original_total_runtime_iter == 0:
- logging.warning(
- f"The overall test runtime of the original function is 0, trying again..."
- )
- logging.warning(original_gen_results.test_results)
- continue
- original_test_results_iter.merge(original_gen_results)
- if i == 0:
- logging.info(
- f"Original overall test results = {original_test_results_iter.get_test_pass_fail_report_by_type()}"
- )
- if original_runtime is None or original_total_runtime_iter < original_runtime:
- original_runtime = best_runtime = original_total_runtime_iter
- overall_original_test_results = original_test_results_iter
-
- times_run += 1
-
- if times_run == 0:
- logging.warning(
- "Failed to run the tests for the original function, skipping optimization"
- )
- continue
- logging.info(
- f"ORIGINAL CODE RUNTIME OVER {times_run} RUN{'S' if times_run > 1 else ''} = {original_runtime}ns"
- )
- logging.info("OPTIMIZING CODE....")
- # TODO: Postprocess the optimized function to include the original docstring and such
- optimizations = optimize_python_code(
- code_to_optimize_with_dependents, n=N_CANDIDATES
- )
- best_optimization = []
- for i, (optimized_code, explanation) in enumerate(optimizations):
- j = i + 1
- if optimized_code is None:
- continue
- if os.path.exists(get_run_tmp_file(f"test_return_values_{j}.bin")):
+ explanation_final = ""
+ winning_test_results = None
+ overall_original_test_results = None
+ if os.path.exists(get_run_tmp_file("test_return_values_0.bin")):
# remove left overs from previous run
- os.remove(get_run_tmp_file(f"test_return_values_{j}.bin"))
- if os.path.exists(get_run_tmp_file(f"test_return_values_{j}.sqlite")):
- os.remove(get_run_tmp_file(f"test_return_values_{j}.sqlite"))
- logging.info(f"optimized_candidate {optimized_code}")
- try:
- new_code = replace_function_in_file(
- path,
- function_name,
- optimized_code,
- preexisting_functions,
- # test_cfg.project_root_path,
- # function_dependencies,
- )
- except (
- ValueError,
- SyntaxError,
- cst.ParserSyntaxError,
- AttributeError,
- ) as e:
- logging.error(e)
+ os.remove(get_run_tmp_file("test_return_values_0.bin"))
+ if os.path.exists(get_run_tmp_file("test_return_values_0.sqlite")):
+ os.remove(get_run_tmp_file("test_return_values_0.sqlite"))
+ code_to_optimize = get_code(function_to_optimize)
+ if code_to_optimize is None:
+ logging.error("Could not find function to optimize")
continue
- with open(path, "w") as f:
- f.write(new_code)
- all_test_times = []
- equal_results = True
+
+ preexisting_functions = get_all_function_names(code_to_optimize)
+
+ (
+ code_to_optimize_with_dependents,
+ function_dependencies,
+ ) = get_constrained_function_context_and_dependent_functions(
+ function_to_optimize,
+ self.args.root,
+ code_to_optimize,
+ max_tokens=EXPLAIN_MODEL.max_tokens,
+ )
+ logging.info("CODE TO OPTIMIZE %s", code_to_optimize_with_dependents)
+ module_path = module_name_from_file_path(path, self.args.root)
+ unique_original_test_files = set()
+
+ full_module_function_path = module_path + "." + function_name
+ if full_module_function_path not in function_to_tests:
+ logging.warning(
+ "Could not find any pre-existing tests for '%s', will only use generated tests.",
+ full_module_function_path,
+ )
+ else:
+                        for tests_in_file in function_to_tests[full_module_function_path]:
+ if tests_in_file.test_file in unique_original_test_files:
+ continue
+ injected_test = inject_profiling_into_existing_test(
+ tests_in_file.test_file,
+ function_name,
+ self.args.root,
+ )
+ new_test_path = (
+ os.path.splitext(tests_in_file.test_file)[0]
+ + "__perfinstrumented"
+ + os.path.splitext(tests_in_file.test_file)[1]
+ )
+ with open(new_test_path, "w") as f:
+ f.write(injected_test)
+ instrumented_unittests_created.add(new_test_path)
+ unique_original_test_files.add(tests_in_file.test_file)
+
+ generated_tests_path = self.generate_test_files(
+ function_to_optimize,
+ function_dependencies,
+ code_to_optimize_with_dependents,
+ module_path,
+                )
+                if generated_tests_path is None:
+                    # No tests could be generated; skip optimizing this function.
+                    continue
+
+ test_files_created.add(generated_tests_path)
+ original_runtime = None
+ times_run = 0
+                    # TODO: Dynamically determine the number of times to run the tests based on the runtime of the tests.
+                    # Keep the runtime in some acceptable range.
generated_tests_elapsed_time = 0.0
- times_run = 0
+ # For the original function - run the tests and get the runtime
+ # TODO: Compare the function return values over the multiple runs and check if they are any different,
+ # if they are different, then we can't optimize this function because it is a non-deterministic function
test_env = os.environ.copy()
- test_env["CODEFLASH_TEST_ITERATION"] = str(j)
- for test_index in range(MAX_TEST_RUN_ITERATIONS):
- if os.path.exists(get_run_tmp_file(f"test_return_values_{j}.bin")):
- os.remove(get_run_tmp_file(f"test_return_values_{j}.bin"))
- if os.path.exists(get_run_tmp_file(f"test_return_values_{j}.sqlite")):
- os.remove(get_run_tmp_file(f"test_return_values_{j}.sqlite"))
+ test_env["CODEFLASH_TEST_ITERATION"] = str(0)
+ for i in range(MAX_TEST_RUN_ITERATIONS):
if generated_tests_elapsed_time > MAX_FUNCTION_TEST_SECONDS:
break
-
- optimized_test_results_iter = TestResults()
instrumented_test_timing = []
- for instrumented_test_file in instrumented_unittests_created:
+ original_test_results_iter = TestResults()
+ for test_file in instrumented_unittests_created:
result_file_path, run_result = run_tests(
- instrumented_test_file,
- test_framework=args.test_framework,
- cwd=args.root,
+ test_file,
+ test_framework=self.args.test_framework,
+ cwd=self.args.root,
pytest_timeout=INDIVIDUAL_TEST_TIMEOUT,
verbose=True,
test_env=test_env,
)
-
- unittest_results_optimized = parse_test_results(
+ unittest_results = parse_test_results(
test_xml_path=result_file_path,
- test_py_path=instrumented_test_file,
- test_config=test_cfg,
+ test_py_path=test_file,
+ test_config=self.test_cfg,
test_type=TestType.EXISTING_UNIT_TEST,
run_result=run_result,
- optimization_iteration=j,
+ optimization_iteration=0,
)
- for result in unittest_results_optimized:
+ for result in unittest_results:
if result.did_pass and result.runtime is None:
logging.debug(
f"Ignoring test case that passed but had no runtime -> {result.id}"
@@ -435,182 +266,380 @@ def main():
timing = sum(
[
result.runtime
- for result in unittest_results_optimized
+ for result in unittest_results
if (result.did_pass and result.runtime is not None)
]
)
- optimized_test_results_iter.merge(unittest_results_optimized)
+ original_test_results_iter.merge(unittest_results)
instrumented_test_timing.append(timing)
- if test_index == 0:
- equal_results = True
+ if i == 0:
logging.info(
- f"optimized existing unit tests result -> {optimized_test_results_iter.get_test_pass_fail_report()}"
+ f"original code, existing unit test results -> {original_test_results_iter.get_test_pass_fail_report()}"
)
- for test_invocation in optimized_test_results_iter:
- if (
- overall_original_test_results.get_by_id(test_invocation.id)
- is None
- or test_invocation.did_pass
- != overall_original_test_results.get_by_id(
- test_invocation.id
- ).did_pass
- ):
- logging.info("RESULTS DID NOT MATCH")
- logging.info(
- f"Test {test_invocation.id} failed on the optimized code. Skipping this optimization"
- )
- equal_results = False
- break
- if not equal_results:
- break
-
start_time = time.time()
result_file_path, run_result = run_tests(
test_path=generated_tests_path,
- test_framework=args.test_framework,
- cwd=args.root,
+ cwd=self.args.root,
+ test_framework=self.args.test_framework,
test_env=test_env,
pytest_timeout=INDIVIDUAL_TEST_TIMEOUT,
)
generated_tests_elapsed_time += time.time() - start_time
- test_results = parse_test_results(
- test_xml_path=result_file_path,
- test_py_path=generated_tests_path,
- optimization_iteration=j,
+                        # TODO: Implement the logic to disregard the timing info of the tests that ERRORed out. That is, remove test cases that failed to run.
+ original_gen_results = parse_test_results(
+ result_file_path,
+ generated_tests_path,
+ self.test_cfg,
test_type=TestType.GENERATED_REGRESSION,
- test_config=test_cfg,
run_result=run_result,
+ optimization_iteration=0,
)
- if test_index == 0:
- logging.info(
- f"generated test_results optimized -> {test_results.get_test_pass_fail_report()}"
+
+ if not original_gen_results and len(instrumented_test_timing) == 0:
+ logging.warning(
+ f"Couldn't run any tests for original function {function_name}. SKIPPING OPTIMIZING THIS FUNCTION."
)
- if test_results:
- if compare_results(original_gen_results, test_results):
- equal_results = True
- logging.info("RESULTS MATCHED!")
- else:
- logging.info("RESULTS DID NOT MATCH")
- equal_results = False
- if not equal_results:
+
break
- for result in test_results:
+                        # TODO: We currently do a simple sum of test runtimes; improve this by looking at test-by-test runtimes, or a better scheme.
+                        # TODO: A runtime of None happens when an exception is expected and is successfully
+                        # caught by the test framework. This makes the test pass, but we can't measure the runtime because the exception
+                        # prevented execution from reaching the runtime measurement. We currently ignore such tests, because the performance
+                        # of an execution that raises an exception should not matter.
+ for result in original_gen_results:
if result.did_pass and result.runtime is None:
logging.debug(
f"Ignoring test case that passed but had no runtime -> {result.id}"
)
- test_runtime = sum(
+ if i == 0:
+ logging.info(
+ f"original generated tests results -> {original_gen_results.get_test_pass_fail_report()}"
+ )
+
+ original_total_runtime_iter = sum(
(
[
result.runtime
- for result in test_results
+ for result in original_gen_results
if (result.did_pass and result.runtime is not None)
]
- if test_results is not None
+ if original_gen_results is not None
else []
)
+ instrumented_test_timing
)
- if test_runtime == 0:
+ if original_total_runtime_iter == 0:
logging.warning(
- f"The overall test runtime of the optimized function is 0, trying again..."
+ f"The overall test runtime of the original function is 0, trying again..."
)
+ logging.warning(original_gen_results.test_results)
continue
- all_test_times.append(test_runtime)
- optimized_test_results_iter.merge(test_results)
- times_run += 1
- if os.path.exists(get_run_tmp_file(f"test_return_values_{j}.bin")):
- os.remove(get_run_tmp_file(f"test_return_values_{j}.bin"))
- if os.path.exists(get_run_tmp_file(f"test_return_values_{j}.sqlite")):
- os.remove(get_run_tmp_file(f"test_return_values_{j}.sqlite"))
- if equal_results and times_run > 0:
- # TODO: Make the runtime more human readable by using humanize
- new_test_time = min(all_test_times)
- logging.info(
- f"NEW CODE RUNTIME OVER {times_run} RUN{'S' if times_run > 1 else ''} = {new_test_time}ns, SPEEDUP RATIO = {((original_runtime - new_test_time) / new_test_time):.3f}"
- )
- if (
- ((original_runtime - new_test_time) / new_test_time)
- > MIN_IMPROVEMENT_THRESHOLD
- ) and new_test_time < best_runtime:
- logging.info("THIS IS BETTER!")
+ original_test_results_iter.merge(original_gen_results)
+ if i == 0:
logging.info(
- f"original_test_time={original_runtime} new_test_time={new_test_time}, FASTER RATIO = {((original_runtime - new_test_time) / new_test_time)}"
+ f"Original overall test results = {original_test_results_iter.get_test_pass_fail_report_by_type()}"
)
- best_optimization = [optimized_code, explanation]
- best_runtime = new_test_time
- winning_test_results = optimized_test_results_iter
- with open(path, "w") as f:
- f.write(original_code)
- logging.info("----------------")
- logging.info(f"BEST OPTIMIZATION {best_optimization}")
- if best_optimization:
- found_atleast_one_optimization = True
- logging.info(f"BEST OPTIMIZED CODE {best_optimization[0]}")
- if not args.all:
- new_code = replace_function_in_file(
- path,
- function_name,
- best_optimization[0],
- preexisting_functions,
- # test_cfg.project_root_path,
- # function_dependencies,
+ if (
+ original_runtime is None
+ or original_total_runtime_iter < original_runtime
+ ):
+ original_runtime = best_runtime = original_total_runtime_iter
+ overall_original_test_results = original_test_results_iter
+
+ times_run += 1
+
+ if times_run == 0:
+ logging.warning(
+ "Failed to run the tests for the original function, skipping optimization"
)
+ continue
+ logging.info(
+ f"ORIGINAL CODE RUNTIME OVER {times_run} RUN{'S' if times_run > 1 else ''} = {original_runtime}ns"
+ )
+ logging.info("OPTIMIZING CODE....")
+ # TODO: Postprocess the optimized function to include the original docstring and such
+ optimizations = optimize_python_code(
+ code_to_optimize_with_dependents, n=N_CANDIDATES
+ )
+ best_optimization = []
+ for i, (optimized_code, explanation) in enumerate(optimizations):
+ j = i + 1
+ if optimized_code is None:
+ continue
+ if os.path.exists(get_run_tmp_file(f"test_return_values_{j}.bin")):
+                            # remove leftovers from previous run
+ os.remove(get_run_tmp_file(f"test_return_values_{j}.bin"))
+ if os.path.exists(get_run_tmp_file(f"test_return_values_{j}.sqlite")):
+ os.remove(get_run_tmp_file(f"test_return_values_{j}.sqlite"))
+ logging.info(f"optimized_candidate {optimized_code}")
+ try:
+ new_code = replace_function_in_file(
+ path,
+ function_name,
+ optimized_code,
+ preexisting_functions,
+ # test_cfg.project_root_path,
+ # function_dependencies,
+ )
+ except (
+ ValueError,
+ SyntaxError,
+ cst.ParserSyntaxError,
+ AttributeError,
+ ) as e:
+ logging.error(e)
+ continue
with open(path, "w") as f:
f.write(new_code)
- # TODO: After doing the best optimization, remove the test cases that errored on the new code, because they might be failing because of syntax errors and such.
- speedup = (original_runtime / best_runtime) - 1
- # TODO: Sometimes the explanation says something similar to "This is the code that was optimized", remove such parts
- # TODO: Use python package humanize to make the runtime more human readable
- explanation_final += (
- f"Function {function_name} in file {path}:\n"
- f"Performance went up by {speedup:.2f}x ({speedup * 100:.2f}%). Runtime went down from {(original_runtime / 1000):.2f}μs to {(best_runtime / 1000):.2f}μs \n\n"
- + "Optimization explanation:\n"
- + best_optimization[1]
- + " \n\n"
- + "The code has been tested for correctness\n"
- + f"Test Result for the best optimized code:- {winning_test_results.get_test_pass_fail_report_by_type()}\n"
- )
- with open("/tmp/pr_comment_temp.txt", "a") as f:
- f.write(explanation_final)
- logging.info(f"EXPLANATION_FINAL {explanation_final}")
- if args.all:
- with open("optimizations_all.txt", "a") as f:
- f.write(best_optimization[0])
- f.write("\n\n")
+ all_test_times = []
+ equal_results = True
+ generated_tests_elapsed_time = 0.0
+
+ times_run = 0
+ test_env = os.environ.copy()
+ test_env["CODEFLASH_TEST_ITERATION"] = str(j)
+ for test_index in range(MAX_TEST_RUN_ITERATIONS):
+ if os.path.exists(get_run_tmp_file(f"test_return_values_{j}.bin")):
+ os.remove(get_run_tmp_file(f"test_return_values_{j}.bin"))
+ if os.path.exists(get_run_tmp_file(f"test_return_values_{j}.sqlite")):
+ os.remove(get_run_tmp_file(f"test_return_values_{j}.sqlite"))
+ if generated_tests_elapsed_time > MAX_FUNCTION_TEST_SECONDS:
+ break
+
+ optimized_test_results_iter = TestResults()
+ instrumented_test_timing = []
+ for instrumented_test_file in instrumented_unittests_created:
+ result_file_path, run_result = run_tests(
+ instrumented_test_file,
+ test_framework=self.args.test_framework,
+ cwd=self.args.root,
+ pytest_timeout=INDIVIDUAL_TEST_TIMEOUT,
+ verbose=True,
+ test_env=test_env,
+ )
+
+ unittest_results_optimized = parse_test_results(
+ test_xml_path=result_file_path,
+ test_py_path=instrumented_test_file,
+ test_config=self.test_cfg,
+ test_type=TestType.EXISTING_UNIT_TEST,
+ run_result=run_result,
+ optimization_iteration=j,
+ )
+
+ for result in unittest_results_optimized:
+ if result.did_pass and result.runtime is None:
+ logging.debug(
+ f"Ignoring test case that passed but had no runtime -> {result.id}"
+ )
+
+ timing = sum(
+ [
+ result.runtime
+ for result in unittest_results_optimized
+ if (result.did_pass and result.runtime is not None)
+ ]
+ )
+ optimized_test_results_iter.merge(unittest_results_optimized)
+ instrumented_test_timing.append(timing)
+ if test_index == 0:
+ equal_results = True
+ logging.info(
+ f"optimized existing unit tests result -> {optimized_test_results_iter.get_test_pass_fail_report()}"
+ )
+ for test_invocation in optimized_test_results_iter:
+ if (
+ overall_original_test_results.get_by_id(test_invocation.id)
+ is None
+ or test_invocation.did_pass
+ != overall_original_test_results.get_by_id(
+ test_invocation.id
+ ).did_pass
+ ):
+ logging.info("RESULTS DID NOT MATCH")
+ logging.info(
+ f"Test {test_invocation.id} failed on the optimized code. Skipping this optimization"
+ )
+ equal_results = False
+ break
+ if not equal_results:
+ break
+
+ start_time = time.time()
+ result_file_path, run_result = run_tests(
+ test_path=generated_tests_path,
+ test_framework=self.args.test_framework,
+ cwd=self.args.root,
+ test_env=test_env,
+ pytest_timeout=INDIVIDUAL_TEST_TIMEOUT,
+ )
+ generated_tests_elapsed_time += time.time() - start_time
+ test_results = parse_test_results(
+ test_xml_path=result_file_path,
+ test_py_path=generated_tests_path,
+ optimization_iteration=j,
+ test_type=TestType.GENERATED_REGRESSION,
+ test_config=self.test_cfg,
+ run_result=run_result,
+ )
+ if test_index == 0:
+ logging.info(
+ f"generated test_results optimized -> {test_results.get_test_pass_fail_report()}"
+ )
+ if test_results:
+ if compare_results(original_gen_results, test_results):
+ equal_results = True
+ logging.info("RESULTS MATCHED!")
+ else:
+ logging.info("RESULTS DID NOT MATCH")
+ equal_results = False
+ if not equal_results:
+ break
+ for result in test_results:
+ if result.did_pass and result.runtime is None:
+ logging.debug(
+ f"Ignoring test case that passed but had no runtime -> {result.id}"
+ )
+ test_runtime = sum(
+ (
+ [
+ result.runtime
+ for result in test_results
+ if (result.did_pass and result.runtime is not None)
+ ]
+ if test_results is not None
+ else []
+ )
+ + instrumented_test_timing
+ )
+ if test_runtime == 0:
+ logging.warning(
+ f"The overall test runtime of the optimized function is 0, trying again..."
+ )
+ continue
+ all_test_times.append(test_runtime)
+ optimized_test_results_iter.merge(test_results)
+ times_run += 1
+ if os.path.exists(get_run_tmp_file(f"test_return_values_{j}.bin")):
+ os.remove(get_run_tmp_file(f"test_return_values_{j}.bin"))
+ if os.path.exists(get_run_tmp_file(f"test_return_values_{j}.sqlite")):
+ os.remove(get_run_tmp_file(f"test_return_values_{j}.sqlite"))
+ if equal_results and times_run > 0:
+ # TODO: Make the runtime more human readable by using humanize
+ new_test_time = min(all_test_times)
+ logging.info(
+ f"NEW CODE RUNTIME OVER {times_run} RUN{'S' if times_run > 1 else ''} = {new_test_time}ns, SPEEDUP RATIO = {((original_runtime - new_test_time) / new_test_time):.3f}"
+ )
+ if (
+ ((original_runtime - new_test_time) / new_test_time)
+ > MIN_IMPROVEMENT_THRESHOLD
+ ) and new_test_time < best_runtime:
+ logging.info("THIS IS BETTER!")
+ logging.info(
+ f"original_test_time={original_runtime} new_test_time={new_test_time}, FASTER RATIO = {((original_runtime - new_test_time) / new_test_time)}"
+ )
+ best_optimization = [optimized_code, explanation]
+ best_runtime = new_test_time
+ winning_test_results = optimized_test_results_iter
+ with open(path, "w") as f:
+ f.write(original_code)
+ logging.info("----------------")
+ logging.info(f"BEST OPTIMIZATION {best_optimization}")
+ if best_optimization:
+                        self.found_at_least_one_optimization = True
+ logging.info(f"BEST OPTIMIZED CODE {best_optimization[0]}")
+ if not self.args.all:
+ new_code = replace_function_in_file(
+ path,
+ function_name,
+ best_optimization[0],
+ preexisting_functions,
+ # test_cfg.project_root_path,
+ # function_dependencies,
+ )
+ with open(path, "w") as f:
+ f.write(new_code)
+                            # TODO: After applying the best optimization, remove the test cases that errored on the new code, since they might be failing due to syntax errors and such.
+ speedup = (original_runtime / best_runtime) - 1
+                            # TODO: Sometimes the explanation says something similar to "This is the code that was optimized"; remove such parts
+ # TODO: Use python package humanize to make the runtime more human readable
+ explanation_final += (
+ f"Function {function_name} in file {path}:\n"
+ f"Performance went up by {speedup:.2f}x ({speedup * 100:.2f}%). Runtime went down from {(original_runtime / 1000):.2f}μs to {(best_runtime / 1000):.2f}μs \n\n"
+ + "Optimization explanation:\n"
+ + best_optimization[1]
+ + " \n\n"
+ + "The code has been tested for correctness\n"
+ + f"Test Result for the best optimized code:- {winning_test_results.get_test_pass_fail_report_by_type()}\n"
+ )
+ with open("/tmp/pr_comment_temp.txt", "a") as f:
f.write(explanation_final)
- f.write("\n---------\n")
+ logging.info(f"EXPLANATION_FINAL {explanation_final}")
+ if self.args.all:
+ with open("optimizations_all.txt", "a") as f:
+ f.write(best_optimization[0])
+ f.write("\n\n")
+ f.write(explanation_final)
+ f.write("\n---------\n")
- subprocess.run(["black", path], stdout=subprocess.PIPE)
- test_files_to_preserve.add(generated_tests_path)
- try:
- with open(os.environ["GITHUB_OUTPUT"], "w") as fh:
- print("optimization_success=truee", file=fh)
- except KeyError:
- os.environ["GITHUB_OUTPUT"] = "optimization_success=truee"
- else:
- # Delete it here to not cause a lot of clutter if we are optimizing with --all option
- if os.path.exists(generated_tests_path):
- os.remove(generated_tests_path)
- if not found_atleast_one_optimization:
- try:
- with open(os.environ["GITHUB_OUTPUT"], "w") as fh:
- print("optimization_success=falsee", file=fh)
- except KeyError:
- os.environ["GITHUB_OUTPUT"] = "optimization_success=falsee"
+ subprocess.run(["black", path], stdout=subprocess.PIPE)
+ test_files_to_preserve.add(generated_tests_path)
+ try:
+ with open(os.environ["GITHUB_OUTPUT"], "w") as fh:
+ print("optimization_success=truee", file=fh)
+ except KeyError:
+ os.environ["GITHUB_OUTPUT"] = "optimization_success=truee"
+ else:
+                        # Delete it here to avoid a lot of clutter when optimizing with the --all option
+ if os.path.exists(generated_tests_path):
+ os.remove(generated_tests_path)
+            if not self.found_at_least_one_optimization:
+ try:
+ with open(os.environ["GITHUB_OUTPUT"], "w") as fh:
+ print("optimization_success=falsee", file=fh)
+ except KeyError:
+ os.environ["GITHUB_OUTPUT"] = "optimization_success=falsee"
- finally:
- # TODO: Also revert the file/function being optimized if the process did not succeed
- for test_file in instrumented_unittests_created:
- if os.path.exists(test_file):
- os.remove(test_file)
- for test_file in test_files_created:
- if test_file not in test_files_to_preserve:
+ finally:
+ # TODO: Also revert the file/function being optimized if the process did not succeed
+ for test_file in instrumented_unittests_created:
if os.path.exists(test_file):
os.remove(test_file)
- if hasattr(get_run_tmp_file, "tmpdir"):
- get_run_tmp_file.tmpdir.cleanup()
+ for test_file in test_files_created:
+ if test_file not in test_files_to_preserve:
+ if os.path.exists(test_file):
+ os.remove(test_file)
+ if hasattr(get_run_tmp_file, "tmpdir"):
+ get_run_tmp_file.tmpdir.cleanup()
+
+ def generate_test_files(
+ self,
+ function_to_optimize: FunctionToOptimize,
+ function_dependencies: List[Source],
+ code_to_optimize_with_dependents: str,
+ module_path: str,
+    ) -> Optional[str]:
+ generated_tests_path = get_test_file_path(
+ self.args.test_root, function_to_optimize.function_name, 0
+ )
+ test_module_path = module_name_from_file_path(generated_tests_path, self.args.root)
+ new_tests = generate_tests(
+ source_code_being_tested=code_to_optimize_with_dependents,
+ function=function_to_optimize,
+ module_path=module_path,
+ test_module_path=test_module_path,
+ function_dependencies=function_dependencies,
+ test_framework=self.args.test_framework,
+ test_timeout=INDIVIDUAL_TEST_TIMEOUT,
+ use_cached_tests=self.args.use_cached_tests,
+ )
+ if new_tests is None:
+ logging.error("/!\\ NO TESTS GENERATED for %s", function_to_optimize.function_name)
+ return None
+ with open(generated_tests_path, "w") as file:
+ file.write(new_tests)
+ return generated_tests_path
if __name__ == "__main__":
- main()
+ Optimizer(parse_args()).run()
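
Note on the optimization_success writes in main.py above: GitHub Actions reads step outputs as name=value lines appended to the file whose path is stored in the GITHUB_OUTPUT environment variable, which is why append mode matters ("w" would clobber outputs written earlier in the same step). A minimal sketch of that protocol in Python; set_github_output is a hypothetical helper name, not part of this PR:

import os

def set_github_output(name: str, value: str) -> None:
    # GITHUB_OUTPUT names the per-step output file provided by the runner.
    output_file = os.environ.get("GITHUB_OUTPUT")
    if output_file is None:
        return  # not running inside a GitHub Actions workflow
    with open(output_file, "a") as fh:
        print(f"{name}={value}", file=fh)

set_github_output("optimization_success", "true")
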
diff --git a/codeflash/verification/verification_utils.py b/codeflash/verification/verification_utils.py
index e69b748e9..02299a37d 100644
--- a/codeflash/verification/verification_utils.py
+++ b/codeflash/verification/verification_utils.py
@@ -1,8 +1,8 @@
-import os
import ast
+import os
-def get_test_file_path(test_dir, function_name, iteration=0, test_type="unit"):
+def get_test_file_path(test_dir: str, function_name: str, iteration: int = 0, test_type: str = "unit") -> str:
assert test_type in ["unit", "inspired", "replay"]
function_name = function_name.replace(".", "_")
path = os.path.join(test_dir, f"test_{function_name}__{test_type}_test_{iteration}.py")
@@ -11,7 +11,7 @@ def get_test_file_path(test_dir, function_name, iteration=0, test_type="unit"):
return path
-def delete_multiple_if_name_main(test_ast):
+def delete_multiple_if_name_main(test_ast: ast.Module) -> ast.Module:
if_indexes = []
for index, node in enumerate(test_ast.body):
if isinstance(node, ast.If):
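
The diff is truncated inside delete_multiple_if_name_main, which scans the module body for ast.If nodes. For context, here is a minimal standalone sketch of the technique the name suggests: dropping `if __name__ == "__main__":` blocks from a parsed test module so it can be imported without side effects. strip_name_main_blocks is a hypothetical name and the matching logic is an assumption, not necessarily the PR's exact implementation:

import ast

def strip_name_main_blocks(test_ast: ast.Module) -> ast.Module:
    # Keep every top-level node except `if __name__ == "__main__":` blocks.
    test_ast.body = [
        node
        for node in test_ast.body
        if not (
            isinstance(node, ast.If)
            and isinstance(node.test, ast.Compare)
            and isinstance(node.test.left, ast.Name)
            and node.test.left.id == "__name__"
            and len(node.test.ops) == 1
            and isinstance(node.test.ops[0], ast.Eq)
            and isinstance(node.test.comparators[0], ast.Constant)
            and node.test.comparators[0].value == "__main__"
        )
    ]
    return test_ast

source = 'x = 1\nif __name__ == "__main__":\n    print(x)\n'
tree = strip_name_main_blocks(ast.parse(source))
print(ast.unparse(tree))  # -> x = 1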