Refactor main function into Optimizer class

afik.cohen 2023-11-16 15:18:05 -08:00
parent 205e360e3e
commit e0db936639
8 changed files with 476 additions and 445 deletions
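In outline, the commit moves the body of the old top-level main() into an Optimizer class: the constructor receives the parsed CLI arguments and builds the TestConfig, run() drives the discovery/benchmark/optimize loop, and test generation is pulled out into a new generate_test_files() method. A condensed sketch of the resulting shape (method bodies elided; see the per-file hunks below for the actual code):

# Condensed sketch of the refactored driver; not a verbatim copy of the diff.
from argparse import Namespace

from codeflash.models import TestConfig  # import as shown in the diff below


class Optimizer:
    def __init__(self, args: Namespace):
        # CLI arguments are injected instead of being parsed inside main(),
        # so the driver can also be constructed programmatically.
        self.args = args
        self.test_cfg = TestConfig(
            test_root=args.test_root,
            project_root_path=args.root,
            test_framework=args.test_framework,
        )

    def run(self) -> None:
        # Former body of main(): discover functions to optimize, instrument
        # and run existing tests, generate regression tests, benchmark the
        # original code, then evaluate candidate optimizations.
        ...

    def generate_test_files(self, function_to_optimize, function_dependencies,
                            code_to_optimize_with_dependents, module_path):
        # Test generation extracted from the run() loop; returns the path of
        # the generated test file, or None when no tests were generated.
        ...


if __name__ == "__main__":
    Optimizer(parse_args()).run()  # parse_args() is defined in the same module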

View file

@@ -7,8 +7,7 @@
<excludeFolder url="file://$MODULE_DIR$/.mypy_cache" />
<excludeFolder url="file://$MODULE_DIR$/.pytest_cache" />
</content>
<orderEntry type="jdk" jdkName="codeflash311" jdkType="Python SDK" />
<orderEntry type="jdk" jdkName="Poetry (codeflash) (3)" jdkType="Python SDK" />
<orderEntry type="jdk" jdkName="Poetry (codeflash) (4)" jdkType="Python SDK" />
<orderEntry type="sourceFolder" forTests="false" />
</component>
</module>

View file

@@ -2,5 +2,6 @@
<profile version="1.0">
<option name="myName" value="Project Default" />
<inspection_tool class="Eslint" enabled="true" level="WARNING" enabled_by_default="true" />
<inspection_tool class="Mypy" enabled="true" level="SERVER PROBLEM" enabled_by_default="true" editorAttributes="GENERIC_SERVER_ERROR_OR_WARNING" />
</profile>
</component>

View file

@@ -1,9 +1,10 @@
<component name="InspectionProjectProfileManager">
<settings>
<info color="dc00eb">
<option name="FOREGROUND" value="dc00eb" />
<option name="BACKGROUND" value="e8d9ff" />
<option name="EFFECT_COLOR" value="dc00eb" />
<option name="ERROR_STRIPE_COLOR" value="dc00eb" />
<option name="EFFECT_TYPE" value="2" />
<option name="myName" value="Type Hint" />
<option name="myVal" value="50" />
<option name="myExternalName" value="Type Hint" />
@@ -23,8 +24,8 @@
<item index="8" class="java.lang.String" itemvalue="WEAK WARNING" />
<item index="9" class="java.lang.String" itemvalue="INFO" />
<item index="10" class="java.lang.String" itemvalue="WARNING" />
<item index="11" class="java.lang.String" itemvalue="Type Hint" />
<item index="12" class="java.lang.String" itemvalue="ERROR" />
<item index="11" class="java.lang.String" itemvalue="ERROR" />
<item index="12" class="java.lang.String" itemvalue="Type Hint" />
</list>
</settings>
</component>

View file

@@ -3,8 +3,7 @@
<component name="Black">
<option name="enabledOnReformat" value="true" />
<option name="enabledOnSave" value="true" />
<option name="pathToExecutable" value="$USER_HOME$/mambaforge/bin/black" />
<option name="sdkName" value="Poetry (codeflash) (3)" />
<option name="sdkName" value="Poetry (codeflash) (4)" />
</component>
<component name="ProjectRootManager" version="2" project-jdk-name="Poetry (codeflash) (3)" project-jdk-type="Python SDK" />
<component name="ProjectRootManager" version="2" project-jdk-name="Poetry (codeflash) (4)" project-jdk-type="Python SDK" />
</project>

.idea/pydantic.xml Normal file
View file

@@ -0,0 +1,6 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="PydanticConfigService">
<option name="warnUntypedFields" value="true" />
</component>
</project>

View file

@@ -1,7 +1,6 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="SqlDialectMappings">
<file url="file://$PROJECT_DIR$/code_to_optimize/tests/pytest/test_bubble_sort_deleteme.py" dialect="GenericSQL" />
<file url="PROJECT" dialect="SQLite" />
</component>
</project>

View file

@@ -1,4 +1,5 @@
import logging
from typing import List
from codeflash.code_utils import env_utils
from codeflash.verification import EXPLAIN_MODEL
@@ -7,7 +8,7 @@ logging.basicConfig(level=logging.INFO)
import os
import subprocess
import time
from argparse import ArgumentParser, SUPPRESS
from argparse import ArgumentParser, SUPPRESS, Namespace
import libcst as cst
@@ -19,12 +20,16 @@ from codeflash.code_utils.code_utils import (
get_run_tmp_file,
)
from codeflash.code_utils.config_parser import parse_config_file
from codeflash.discovery.discover_unit_tests import discover_unit_tests
from codeflash.discovery.functions_to_optimize import get_functions_to_optimize_by_file
from codeflash.discovery.discover_unit_tests import discover_unit_tests, TestsInFile
from codeflash.discovery.functions_to_optimize import (
get_functions_to_optimize_by_file,
FunctionToOptimize,
)
from codeflash.instrumentation.instrument_existing_tests import inject_profiling_into_existing_test
from codeflash.models import TestConfig
from codeflash.optimization.function_context import (
get_constrained_function_context_and_dependent_functions,
Source,
)
from codeflash.optimization.optimizer import optimize_python_code
from codeflash.verification.equivalence import compare_results
@@ -40,7 +45,7 @@ from codeflash.verification.verification_utils import (
from codeflash.verification.verifier import generate_tests
def parse_args():
def parse_args() -> Namespace:
parser = ArgumentParser()
parser.add_argument("--file", help="Try to optimize only this file")
parser.add_argument(
@@ -76,7 +81,7 @@ def parse_args():
help="Use cached tests from a specified file for debugging.",
)
parser.add_argument("-v", "--verbose", action="store_true", help="Print verbose logs")
args = parser.parse_args()
args: Namespace = parser.parse_args()
if args.verbose:
logging.basicConfig(level=logging.DEBUG)
@@ -89,9 +94,7 @@ def parse_args():
if key in pyproject_config and getattr(args, key.replace("-", "_")) is None:
setattr(args, key.replace("-", "_"), pyproject_config[key])
assert os.path.isdir(args.root), "--root must be a valid directory"
assert os.path.isdir(
args.test_root
), f"--test-root must be a valid directory; {args.test_root} is not a directory"
assert os.path.isdir(args.test_root), "--test_root must be a valid directory"
args.root = os.path.realpath(args.root)
args.test_root = os.path.realpath(args.test_root)
if not hasattr(args, "all"):
@@ -111,27 +114,30 @@ N_CANDIDATES = 10
MIN_IMPROVEMENT_THRESHOLD = 0.05
def main():
logging.info("RUNNING THE OPTIMIZER")
args = parse_args()
env_utils.ensure_codeflash_api_key()
test_cfg = TestConfig(
class Optimizer:
def __init__(self, args: Namespace):
self.args = args
self.test_cfg = TestConfig(
test_root=args.test_root,
project_root_path=args.root,
test_framework=args.test_framework,
)
modified_functions, num_modified_functions = get_functions_to_optimize_by_file(
optimize_all=args.all,
file=args.file,
function=args.function,
test_cfg=test_cfg,
def run(self):
logging.info("RUNNING THE OPTIMIZER")
env_utils.ensure_codeflash_api_key()
file_to_funcs_to_optimize, num_modified_functions = get_functions_to_optimize_by_file(
optimize_all=self.args.all,
file=self.args.file,
function=self.args.function,
test_cfg=self.test_cfg,
)
test_files_created = set()
test_files_to_preserve = set()
instrumented_unittests_created = set()
found_atleast_one_optimization = False
self.found_atleast_one_optimization = False
if os.path.exists("/tmp/pr_comment_temp.txt"):
os.remove("/tmp/pr_comment_temp.txt")
@@ -140,14 +146,14 @@ def main():
if num_modified_functions == 0:
logging.info("No functions found to optimize. Exiting...")
return
functions_to_tests_map = discover_unit_tests(test_cfg)
for path in modified_functions:
function_to_tests: dict[str, list[TestsInFile]] = discover_unit_tests(self.test_cfg)
for path in file_to_funcs_to_optimize:
logging.info(f"Examining file {path} ...")
# TODO: Sequence the functions one goes through intelligently. If we are optimizing f(g(x)), then we might want to first
# optimize f rather than g because optimizing f would already optimize g as it is a dependency
with open(path, "r") as f:
original_code = f.read()
for function_to_optimize in modified_functions[path]:
for function_to_optimize in file_to_funcs_to_optimize[path]:
function_name = function_to_optimize.function_name
function_iterator_count += 1
logging.info(
@@ -173,58 +179,47 @@ def main():
function_dependencies,
) = get_constrained_function_context_and_dependent_functions(
function_to_optimize,
args.root,
self.args.root,
code_to_optimize,
max_tokens=EXPLAIN_MODEL.max_tokens,
)
logging.info("CODE TO OPTIMIZE %s", code_to_optimize_with_dependents)
module_path = module_name_from_file_path(path, args.root)
module_path = module_name_from_file_path(path, self.args.root)
unique_original_test_files = set()
if not module_path + "." + function_name in functions_to_tests_map:
full_module_function_path = module_path + "." + function_name
if full_module_function_path not in function_to_tests:
logging.warning(
"Could not find any pre-existing tests for '%s', will only use generated tests.",
module_path + "." + function_name,
full_module_function_path,
)
else:
for i, tests_in_file in enumerate(
functions_to_tests_map.get(module_path + "." + function_name)
):
for tests_in_file in function_to_tests.get(full_module_function_path):
if tests_in_file.test_file in unique_original_test_files:
continue
injected_test = inject_profiling_into_existing_test(
tests_in_file.test_file,
function_name,
self.args.root,
)
new_test_path = (
os.path.splitext(tests_in_file.test_file)[0]
+ "__perfinstrumented"
+ os.path.splitext(tests_in_file.test_file)[1]
)
injected_test = inject_profiling_into_existing_test(
tests_in_file.test_file,
function_name,
args.root,
)
with open(new_test_path, "w") as f:
f.write(injected_test)
instrumented_unittests_created.add(new_test_path)
unique_original_test_files.add(tests_in_file.test_file)
generated_tests_path = get_test_file_path(args.test_root, function_name, 0)
test_module_path = module_name_from_file_path(generated_tests_path, args.root)
new_tests = generate_tests(
source_code_being_tested=code_to_optimize_with_dependents,
function=function_to_optimize,
module_path=module_path,
test_module_path=test_module_path,
function_dependencies=function_dependencies,
test_framework=args.test_framework,
test_timeout=INDIVIDUAL_TEST_TIMEOUT,
use_cached_tests=args.use_cached_tests,
generated_tests_path = self.generate_test_files(
function_to_optimize,
function_dependencies,
code_to_optimize_with_dependents,
module_path,
)
if new_tests is None:
logging.error("/!\\ NO TESTS GENERATED for %s", function_name)
continue
test_files_created.add(generated_tests_path)
with open(generated_tests_path, "w") as file:
file.write(new_tests)
original_runtime = None
times_run = 0
# TODO : Dynamically determine the number of times to run the tests based on the runtime of the tests.
@@ -244,8 +239,8 @@ def main():
for test_file in instrumented_unittests_created:
result_file_path, run_result = run_tests(
test_file,
test_framework=args.test_framework,
cwd=args.root,
test_framework=self.args.test_framework,
cwd=self.args.root,
pytest_timeout=INDIVIDUAL_TEST_TIMEOUT,
verbose=True,
test_env=test_env,
@@ -253,7 +248,7 @@ def main():
unittest_results = parse_test_results(
test_xml_path=result_file_path,
test_py_path=test_file,
test_config=test_cfg,
test_config=self.test_cfg,
test_type=TestType.EXISTING_UNIT_TEST,
run_result=run_result,
optimization_iteration=0,
@@ -281,8 +276,8 @@ def main():
start_time = time.time()
result_file_path, run_result = run_tests(
test_path=generated_tests_path,
cwd=args.root,
test_framework=args.test_framework,
cwd=self.args.root,
test_framework=self.args.test_framework,
test_env=test_env,
pytest_timeout=INDIVIDUAL_TEST_TIMEOUT,
)
@@ -291,7 +286,7 @@ def main():
original_gen_results = parse_test_results(
result_file_path,
generated_tests_path,
test_cfg,
self.test_cfg,
test_type=TestType.GENERATED_REGRESSION,
run_result=run_result,
optimization_iteration=0,
@@ -341,7 +336,10 @@ def main():
logging.info(
f"Original overall test results = {original_test_results_iter.get_test_pass_fail_report_by_type()}"
)
if original_runtime is None or original_total_runtime_iter < original_runtime:
if (
original_runtime is None
or original_total_runtime_iter < original_runtime
):
original_runtime = best_runtime = original_total_runtime_iter
overall_original_test_results = original_test_results_iter
@@ -410,8 +408,8 @@ def main():
for instrumented_test_file in instrumented_unittests_created:
result_file_path, run_result = run_tests(
instrumented_test_file,
test_framework=args.test_framework,
cwd=args.root,
test_framework=self.args.test_framework,
cwd=self.args.root,
pytest_timeout=INDIVIDUAL_TEST_TIMEOUT,
verbose=True,
test_env=test_env,
@@ -420,7 +418,7 @@ def main():
unittest_results_optimized = parse_test_results(
test_xml_path=result_file_path,
test_py_path=instrumented_test_file,
test_config=test_cfg,
test_config=self.test_cfg,
test_type=TestType.EXISTING_UNIT_TEST,
run_result=run_result,
optimization_iteration=j,
@@ -467,8 +465,8 @@ def main():
start_time = time.time()
result_file_path, run_result = run_tests(
test_path=generated_tests_path,
test_framework=args.test_framework,
cwd=args.root,
test_framework=self.args.test_framework,
cwd=self.args.root,
test_env=test_env,
pytest_timeout=INDIVIDUAL_TEST_TIMEOUT,
)
@@ -478,7 +476,7 @@ def main():
test_py_path=generated_tests_path,
optimization_iteration=j,
test_type=TestType.GENERATED_REGRESSION,
test_config=test_cfg,
test_config=self.test_cfg,
run_result=run_result,
)
if test_index == 0:
@@ -545,9 +543,9 @@ def main():
logging.info("----------------")
logging.info(f"BEST OPTIMIZATION {best_optimization}")
if best_optimization:
found_atleast_one_optimization = True
self.found_atleast_one_optimization = True
logging.info(f"BEST OPTIMIZED CODE {best_optimization[0]}")
if not args.all:
if not self.args.all:
new_code = replace_function_in_file(
path,
function_name,
@@ -574,7 +572,7 @@ def main():
with open("/tmp/pr_comment_temp.txt", "a") as f:
f.write(explanation_final)
logging.info(f"EXPLANATION_FINAL {explanation_final}")
if args.all:
if self.args.all:
with open("optimizations_all.txt", "a") as f:
f.write(best_optimization[0])
f.write("\n\n")
@@ -592,7 +590,7 @@ def main():
# Delete it here to not cause a lot of clutter if we are optimizing with --all option
if os.path.exists(generated_tests_path):
os.remove(generated_tests_path)
if not found_atleast_one_optimization:
if not self.found_atleast_one_optimization:
try:
with open(os.environ["GITHUB_OUTPUT"], "w") as fh:
print("optimization_success=falsee", file=fh)
@@ -611,6 +609,34 @@ def main():
if hasattr(get_run_tmp_file, "tmpdir"):
get_run_tmp_file.tmpdir.cleanup()
def generate_test_files(
self,
function_to_optimize: FunctionToOptimize,
function_dependencies: List[Source],
code_to_optimize_with_dependents: str,
module_path: str,
) -> str | None:
generated_tests_path = get_test_file_path(
self.args.test_root, function_to_optimize.function_name, 0
)
test_module_path = module_name_from_file_path(generated_tests_path, self.args.root)
new_tests = generate_tests(
source_code_being_tested=code_to_optimize_with_dependents,
function=function_to_optimize,
module_path=module_path,
test_module_path=test_module_path,
function_dependencies=function_dependencies,
test_framework=self.args.test_framework,
test_timeout=INDIVIDUAL_TEST_TIMEOUT,
use_cached_tests=self.args.use_cached_tests,
)
if new_tests is None:
logging.error("/!\\ NO TESTS GENERATED for %s", function_to_optimize.function_name)
return None
with open(generated_tests_path, "w") as file:
file.write(new_tests)
return generated_tests_path
if __name__ == "__main__":
main()
Optimizer(parse_args()).run()
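
As an aside on the TODO carried through the refactor about sequencing functions intelligently when optimizing f(g(x)) (visiting callers before the functions they depend on): purely as an illustration, and not part of this commit, such an ordering could be derived from a call graph with the standard-library graphlib; the call-graph input below is a hypothetical stand-in.

# Illustrative only; not part of this commit.
from graphlib import TopologicalSorter


def order_callers_first(call_graph: dict[str, set[str]]) -> list[str]:
    # call_graph maps each function to the set of functions it calls,
    # e.g. {"f": {"g"}, "g": set()} for f(g(x)).
    ts = TopologicalSorter(call_graph)
    # static_order() yields dependencies (g) before their callers (f);
    # reversing it visits f first, as the TODO suggests.
    return list(reversed(list(ts.static_order())))


# order_callers_first({"f": {"g"}, "g": set()}) == ["f", "g"]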

View file

@@ -1,8 +1,8 @@
import os
import ast
import os
def get_test_file_path(test_dir, function_name, iteration=0, test_type="unit"):
def get_test_file_path(test_dir: str, function_name: str, iteration: int = 0, test_type: str = "unit") -> str:
assert test_type in ["unit", "inspired", "replay"]
function_name = function_name.replace(".", "_")
path = os.path.join(test_dir, f"test_{function_name}__{test_type}_test_{iteration}.py")
@@ -11,7 +11,7 @@ def get_test_file_path(test_dir, function_name, iteration=0, test_type="unit"):
return path
def delete_multiple_if_name_main(test_ast):
def delete_multiple_if_name_main(test_ast: ast.Module) -> ast.Module:
if_indexes = []
for index, node in enumerate(test_ast.body):
if isinstance(node, ast.If):
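
For orientation on the newly typed helpers in this last file, a small usage illustration (not from the commit; the function name is made up, and it assumes the target path is still free):

# Illustrative only; uses the two helpers defined in this module.
import ast

# get_test_file_path replaces "." in the function name and builds
# test_<name>__<type>_test_<iteration>.py under the test directory, e.g.
#   /repo/tests/test_BubbleSorter_sort__unit_test_0.py
path = get_test_file_path("/repo/tests", "BubbleSorter.sort", iteration=0)

# delete_multiple_if_name_main presumably strips `if __name__ == "__main__":`
# blocks from a parsed test module before it is rewritten.
src = "def test_ok():\n    assert True\n\nif __name__ == '__main__':\n    test_ok()\n"
tree = delete_multiple_if_name_main(ast.parse(src))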