mirror of
https://github.com/codeflash-ai/codeflash-internal.git
synced 2026-05-04 18:25:18 +00:00
Refactor main function into Optimizer class
This commit is contained in:
parent
205e360e3e
commit
e0db936639
8 changed files with 476 additions and 445 deletions
|
|
@ -7,8 +7,7 @@
|
|||
<excludeFolder url="file://$MODULE_DIR$/.mypy_cache" />
|
||||
<excludeFolder url="file://$MODULE_DIR$/.pytest_cache" />
|
||||
</content>
|
||||
<orderEntry type="jdk" jdkName="codeflash311" jdkType="Python SDK" />
|
||||
<orderEntry type="jdk" jdkName="Poetry (codeflash) (3)" jdkType="Python SDK" />
|
||||
<orderEntry type="jdk" jdkName="Poetry (codeflash) (4)" jdkType="Python SDK" />
|
||||
<orderEntry type="sourceFolder" forTests="false" />
|
||||
</component>
|
||||
</module>
|
||||
|
|
@ -2,5 +2,6 @@
|
|||
<profile version="1.0">
|
||||
<option name="myName" value="Project Default" />
|
||||
<inspection_tool class="Eslint" enabled="true" level="WARNING" enabled_by_default="true" />
|
||||
<inspection_tool class="Mypy" enabled="true" level="SERVER PROBLEM" enabled_by_default="true" editorAttributes="GENERIC_SERVER_ERROR_OR_WARNING" />
|
||||
</profile>
|
||||
</component>
|
||||
|
|
@ -1,9 +1,10 @@
|
|||
<component name="InspectionProjectProfileManager">
|
||||
<settings>
|
||||
<info color="dc00eb">
|
||||
<option name="FOREGROUND" value="dc00eb" />
|
||||
<option name="BACKGROUND" value="e8d9ff" />
|
||||
<option name="EFFECT_COLOR" value="dc00eb" />
|
||||
<option name="ERROR_STRIPE_COLOR" value="dc00eb" />
|
||||
<option name="EFFECT_TYPE" value="2" />
|
||||
<option name="myName" value="Type Hint" />
|
||||
<option name="myVal" value="50" />
|
||||
<option name="myExternalName" value="Type Hint" />
|
||||
|
|
@ -23,8 +24,8 @@
|
|||
<item index="8" class="java.lang.String" itemvalue="WEAK WARNING" />
|
||||
<item index="9" class="java.lang.String" itemvalue="INFO" />
|
||||
<item index="10" class="java.lang.String" itemvalue="WARNING" />
|
||||
<item index="11" class="java.lang.String" itemvalue="Type Hint" />
|
||||
<item index="12" class="java.lang.String" itemvalue="ERROR" />
|
||||
<item index="11" class="java.lang.String" itemvalue="ERROR" />
|
||||
<item index="12" class="java.lang.String" itemvalue="Type Hint" />
|
||||
</list>
|
||||
</settings>
|
||||
</component>
|
||||
|
|
@ -3,8 +3,7 @@
|
|||
<component name="Black">
|
||||
<option name="enabledOnReformat" value="true" />
|
||||
<option name="enabledOnSave" value="true" />
|
||||
<option name="pathToExecutable" value="$USER_HOME$/mambaforge/bin/black" />
|
||||
<option name="sdkName" value="Poetry (codeflash) (3)" />
|
||||
<option name="sdkName" value="Poetry (codeflash) (4)" />
|
||||
</component>
|
||||
<component name="ProjectRootManager" version="2" project-jdk-name="Poetry (codeflash) (3)" project-jdk-type="Python SDK" />
|
||||
<component name="ProjectRootManager" version="2" project-jdk-name="Poetry (codeflash) (4)" project-jdk-type="Python SDK" />
|
||||
</project>
|
||||
6
.idea/pydantic.xml
Normal file
6
.idea/pydantic.xml
Normal file
|
|
@ -0,0 +1,6 @@
|
|||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="PydanticConfigService">
|
||||
<option name="warnUntypedFields" value="true" />
|
||||
</component>
|
||||
</project>
|
||||
|
|
@ -1,7 +1,6 @@
|
|||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="SqlDialectMappings">
|
||||
<file url="file://$PROJECT_DIR$/code_to_optimize/tests/pytest/test_bubble_sort_deleteme.py" dialect="GenericSQL" />
|
||||
<file url="PROJECT" dialect="SQLite" />
|
||||
</component>
|
||||
</project>
|
||||
|
|
@ -1,4 +1,5 @@
|
|||
import logging
|
||||
from typing import List
|
||||
|
||||
from codeflash.code_utils import env_utils
|
||||
from codeflash.verification import EXPLAIN_MODEL
|
||||
|
|
@ -7,7 +8,7 @@ logging.basicConfig(level=logging.INFO)
|
|||
import os
|
||||
import subprocess
|
||||
import time
|
||||
from argparse import ArgumentParser, SUPPRESS
|
||||
from argparse import ArgumentParser, SUPPRESS, Namespace
|
||||
|
||||
import libcst as cst
|
||||
|
||||
|
|
@ -19,12 +20,16 @@ from codeflash.code_utils.code_utils import (
|
|||
get_run_tmp_file,
|
||||
)
|
||||
from codeflash.code_utils.config_parser import parse_config_file
|
||||
from codeflash.discovery.discover_unit_tests import discover_unit_tests
|
||||
from codeflash.discovery.functions_to_optimize import get_functions_to_optimize_by_file
|
||||
from codeflash.discovery.discover_unit_tests import discover_unit_tests, TestsInFile
|
||||
from codeflash.discovery.functions_to_optimize import (
|
||||
get_functions_to_optimize_by_file,
|
||||
FunctionToOptimize,
|
||||
)
|
||||
from codeflash.instrumentation.instrument_existing_tests import inject_profiling_into_existing_test
|
||||
from codeflash.models import TestConfig
|
||||
from codeflash.optimization.function_context import (
|
||||
get_constrained_function_context_and_dependent_functions,
|
||||
Source,
|
||||
)
|
||||
from codeflash.optimization.optimizer import optimize_python_code
|
||||
from codeflash.verification.equivalence import compare_results
|
||||
|
|
@ -40,7 +45,7 @@ from codeflash.verification.verification_utils import (
|
|||
from codeflash.verification.verifier import generate_tests
|
||||
|
||||
|
||||
def parse_args():
|
||||
def parse_args() -> Namespace:
|
||||
parser = ArgumentParser()
|
||||
parser.add_argument("--file", help="Try to optimize only this file")
|
||||
parser.add_argument(
|
||||
|
|
@ -76,7 +81,7 @@ def parse_args():
|
|||
help="Use cached tests from a specified file for debugging.",
|
||||
)
|
||||
parser.add_argument("-v", "--verbose", action="store_true", help="Print verbose logs")
|
||||
args = parser.parse_args()
|
||||
args: Namespace = parser.parse_args()
|
||||
if args.verbose:
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
|
||||
|
|
@ -89,9 +94,7 @@ def parse_args():
|
|||
if key in pyproject_config and getattr(args, key.replace("-", "_")) is None:
|
||||
setattr(args, key.replace("-", "_"), pyproject_config[key])
|
||||
assert os.path.isdir(args.root), "--root must be a valid directory"
|
||||
assert os.path.isdir(
|
||||
args.test_root
|
||||
), f"--test-root must be a valid directory; {args.test_root} is not a directory"
|
||||
assert os.path.isdir(args.test_root), "--test_root must be a valid directory"
|
||||
args.root = os.path.realpath(args.root)
|
||||
args.test_root = os.path.realpath(args.test_root)
|
||||
if not hasattr(args, "all"):
|
||||
|
|
@ -111,27 +114,30 @@ N_CANDIDATES = 10
|
|||
MIN_IMPROVEMENT_THRESHOLD = 0.05
|
||||
|
||||
|
||||
def main():
|
||||
logging.info("RUNNING THE OPTIMIZER")
|
||||
args = parse_args()
|
||||
env_utils.ensure_codeflash_api_key()
|
||||
test_cfg = TestConfig(
|
||||
class Optimizer:
|
||||
def __init__(self, args: Namespace):
|
||||
self.args = args
|
||||
self.test_cfg = TestConfig(
|
||||
test_root=args.test_root,
|
||||
project_root_path=args.root,
|
||||
test_framework=args.test_framework,
|
||||
)
|
||||
|
||||
modified_functions, num_modified_functions = get_functions_to_optimize_by_file(
|
||||
optimize_all=args.all,
|
||||
file=args.file,
|
||||
function=args.function,
|
||||
test_cfg=test_cfg,
|
||||
def run(self):
|
||||
logging.info("RUNNING THE OPTIMIZER")
|
||||
env_utils.ensure_codeflash_api_key()
|
||||
|
||||
file_to_funcs_to_optimize, num_modified_functions = get_functions_to_optimize_by_file(
|
||||
optimize_all=self.args.all,
|
||||
file=self.args.file,
|
||||
function=self.args.function,
|
||||
test_cfg=self.test_cfg,
|
||||
)
|
||||
|
||||
test_files_created = set()
|
||||
test_files_to_preserve = set()
|
||||
instrumented_unittests_created = set()
|
||||
found_atleast_one_optimization = False
|
||||
self.found_atleast_one_optimization = False
|
||||
|
||||
if os.path.exists("/tmp/pr_comment_temp.txt"):
|
||||
os.remove("/tmp/pr_comment_temp.txt")
|
||||
|
|
@ -140,14 +146,14 @@ def main():
|
|||
if num_modified_functions == 0:
|
||||
logging.info("No functions found to optimize. Exiting...")
|
||||
return
|
||||
functions_to_tests_map = discover_unit_tests(test_cfg)
|
||||
for path in modified_functions:
|
||||
function_to_tests: dict[str, list[TestsInFile]] = discover_unit_tests(self.test_cfg)
|
||||
for path in file_to_funcs_to_optimize:
|
||||
logging.info(f"Examining file {path} ...")
|
||||
# TODO: Sequence the functions one goes through intelligently. If we are optimizing f(g(x)), then we might want to first
|
||||
# optimize f rather than g because optimizing f would already optimize g as it is a dependency
|
||||
with open(path, "r") as f:
|
||||
original_code = f.read()
|
||||
for function_to_optimize in modified_functions[path]:
|
||||
for function_to_optimize in file_to_funcs_to_optimize[path]:
|
||||
function_name = function_to_optimize.function_name
|
||||
function_iterator_count += 1
|
||||
logging.info(
|
||||
|
|
@ -173,58 +179,47 @@ def main():
|
|||
function_dependencies,
|
||||
) = get_constrained_function_context_and_dependent_functions(
|
||||
function_to_optimize,
|
||||
args.root,
|
||||
self.args.root,
|
||||
code_to_optimize,
|
||||
max_tokens=EXPLAIN_MODEL.max_tokens,
|
||||
)
|
||||
logging.info("CODE TO OPTIMIZE %s", code_to_optimize_with_dependents)
|
||||
module_path = module_name_from_file_path(path, args.root)
|
||||
module_path = module_name_from_file_path(path, self.args.root)
|
||||
unique_original_test_files = set()
|
||||
|
||||
if not module_path + "." + function_name in functions_to_tests_map:
|
||||
full_module_function_path = module_path + "." + function_name
|
||||
if full_module_function_path not in function_to_tests:
|
||||
logging.warning(
|
||||
"Could not find any pre-existing tests for '%s', will only use generated tests.",
|
||||
module_path + "." + function_name,
|
||||
full_module_function_path,
|
||||
)
|
||||
else:
|
||||
for i, tests_in_file in enumerate(
|
||||
functions_to_tests_map.get(module_path + "." + function_name)
|
||||
):
|
||||
for tests_in_file in function_to_tests.get(full_module_function_path):
|
||||
if tests_in_file.test_file in unique_original_test_files:
|
||||
continue
|
||||
injected_test = inject_profiling_into_existing_test(
|
||||
tests_in_file.test_file,
|
||||
function_name,
|
||||
self.args.root,
|
||||
)
|
||||
new_test_path = (
|
||||
os.path.splitext(tests_in_file.test_file)[0]
|
||||
+ "__perfinstrumented"
|
||||
+ os.path.splitext(tests_in_file.test_file)[1]
|
||||
)
|
||||
injected_test = inject_profiling_into_existing_test(
|
||||
tests_in_file.test_file,
|
||||
function_name,
|
||||
args.root,
|
||||
)
|
||||
with open(new_test_path, "w") as f:
|
||||
f.write(injected_test)
|
||||
instrumented_unittests_created.add(new_test_path)
|
||||
unique_original_test_files.add(tests_in_file.test_file)
|
||||
generated_tests_path = get_test_file_path(args.test_root, function_name, 0)
|
||||
test_module_path = module_name_from_file_path(generated_tests_path, args.root)
|
||||
new_tests = generate_tests(
|
||||
source_code_being_tested=code_to_optimize_with_dependents,
|
||||
function=function_to_optimize,
|
||||
module_path=module_path,
|
||||
test_module_path=test_module_path,
|
||||
function_dependencies=function_dependencies,
|
||||
test_framework=args.test_framework,
|
||||
test_timeout=INDIVIDUAL_TEST_TIMEOUT,
|
||||
use_cached_tests=args.use_cached_tests,
|
||||
|
||||
generated_tests_path = self.generate_test_files(
|
||||
function_to_optimize,
|
||||
function_dependencies,
|
||||
code_to_optimize_with_dependents,
|
||||
module_path,
|
||||
)
|
||||
if new_tests is None:
|
||||
logging.error("/!\\ NO TESTS GENERATED for %s", function_name)
|
||||
continue
|
||||
|
||||
test_files_created.add(generated_tests_path)
|
||||
with open(generated_tests_path, "w") as file:
|
||||
file.write(new_tests)
|
||||
original_runtime = None
|
||||
times_run = 0
|
||||
# TODO : Dynamically determine the number of times to run the tests based on the runtime of the tests.
|
||||
|
|
@ -244,8 +239,8 @@ def main():
|
|||
for test_file in instrumented_unittests_created:
|
||||
result_file_path, run_result = run_tests(
|
||||
test_file,
|
||||
test_framework=args.test_framework,
|
||||
cwd=args.root,
|
||||
test_framework=self.args.test_framework,
|
||||
cwd=self.args.root,
|
||||
pytest_timeout=INDIVIDUAL_TEST_TIMEOUT,
|
||||
verbose=True,
|
||||
test_env=test_env,
|
||||
|
|
@ -253,7 +248,7 @@ def main():
|
|||
unittest_results = parse_test_results(
|
||||
test_xml_path=result_file_path,
|
||||
test_py_path=test_file,
|
||||
test_config=test_cfg,
|
||||
test_config=self.test_cfg,
|
||||
test_type=TestType.EXISTING_UNIT_TEST,
|
||||
run_result=run_result,
|
||||
optimization_iteration=0,
|
||||
|
|
@ -281,8 +276,8 @@ def main():
|
|||
start_time = time.time()
|
||||
result_file_path, run_result = run_tests(
|
||||
test_path=generated_tests_path,
|
||||
cwd=args.root,
|
||||
test_framework=args.test_framework,
|
||||
cwd=self.args.root,
|
||||
test_framework=self.args.test_framework,
|
||||
test_env=test_env,
|
||||
pytest_timeout=INDIVIDUAL_TEST_TIMEOUT,
|
||||
)
|
||||
|
|
@ -291,7 +286,7 @@ def main():
|
|||
original_gen_results = parse_test_results(
|
||||
result_file_path,
|
||||
generated_tests_path,
|
||||
test_cfg,
|
||||
self.test_cfg,
|
||||
test_type=TestType.GENERATED_REGRESSION,
|
||||
run_result=run_result,
|
||||
optimization_iteration=0,
|
||||
|
|
@ -341,7 +336,10 @@ def main():
|
|||
logging.info(
|
||||
f"Original overall test results = {original_test_results_iter.get_test_pass_fail_report_by_type()}"
|
||||
)
|
||||
if original_runtime is None or original_total_runtime_iter < original_runtime:
|
||||
if (
|
||||
original_runtime is None
|
||||
or original_total_runtime_iter < original_runtime
|
||||
):
|
||||
original_runtime = best_runtime = original_total_runtime_iter
|
||||
overall_original_test_results = original_test_results_iter
|
||||
|
||||
|
|
@ -410,8 +408,8 @@ def main():
|
|||
for instrumented_test_file in instrumented_unittests_created:
|
||||
result_file_path, run_result = run_tests(
|
||||
instrumented_test_file,
|
||||
test_framework=args.test_framework,
|
||||
cwd=args.root,
|
||||
test_framework=self.args.test_framework,
|
||||
cwd=self.args.root,
|
||||
pytest_timeout=INDIVIDUAL_TEST_TIMEOUT,
|
||||
verbose=True,
|
||||
test_env=test_env,
|
||||
|
|
@ -420,7 +418,7 @@ def main():
|
|||
unittest_results_optimized = parse_test_results(
|
||||
test_xml_path=result_file_path,
|
||||
test_py_path=instrumented_test_file,
|
||||
test_config=test_cfg,
|
||||
test_config=self.test_cfg,
|
||||
test_type=TestType.EXISTING_UNIT_TEST,
|
||||
run_result=run_result,
|
||||
optimization_iteration=j,
|
||||
|
|
@ -467,8 +465,8 @@ def main():
|
|||
start_time = time.time()
|
||||
result_file_path, run_result = run_tests(
|
||||
test_path=generated_tests_path,
|
||||
test_framework=args.test_framework,
|
||||
cwd=args.root,
|
||||
test_framework=self.args.test_framework,
|
||||
cwd=self.args.root,
|
||||
test_env=test_env,
|
||||
pytest_timeout=INDIVIDUAL_TEST_TIMEOUT,
|
||||
)
|
||||
|
|
@ -478,7 +476,7 @@ def main():
|
|||
test_py_path=generated_tests_path,
|
||||
optimization_iteration=j,
|
||||
test_type=TestType.GENERATED_REGRESSION,
|
||||
test_config=test_cfg,
|
||||
test_config=self.test_cfg,
|
||||
run_result=run_result,
|
||||
)
|
||||
if test_index == 0:
|
||||
|
|
@ -545,9 +543,9 @@ def main():
|
|||
logging.info("----------------")
|
||||
logging.info(f"BEST OPTIMIZATION {best_optimization}")
|
||||
if best_optimization:
|
||||
found_atleast_one_optimization = True
|
||||
self.found_atleast_one_optimization = True
|
||||
logging.info(f"BEST OPTIMIZED CODE {best_optimization[0]}")
|
||||
if not args.all:
|
||||
if not self.args.all:
|
||||
new_code = replace_function_in_file(
|
||||
path,
|
||||
function_name,
|
||||
|
|
@ -574,7 +572,7 @@ def main():
|
|||
with open("/tmp/pr_comment_temp.txt", "a") as f:
|
||||
f.write(explanation_final)
|
||||
logging.info(f"EXPLANATION_FINAL {explanation_final}")
|
||||
if args.all:
|
||||
if self.args.all:
|
||||
with open("optimizations_all.txt", "a") as f:
|
||||
f.write(best_optimization[0])
|
||||
f.write("\n\n")
|
||||
|
|
@ -592,7 +590,7 @@ def main():
|
|||
# Delete it here to not cause a lot of clutter if we are optimizing with --all option
|
||||
if os.path.exists(generated_tests_path):
|
||||
os.remove(generated_tests_path)
|
||||
if not found_atleast_one_optimization:
|
||||
if not self.found_atleast_one_optimization:
|
||||
try:
|
||||
with open(os.environ["GITHUB_OUTPUT"], "w") as fh:
|
||||
print("optimization_success=falsee", file=fh)
|
||||
|
|
@ -611,6 +609,34 @@ def main():
|
|||
if hasattr(get_run_tmp_file, "tmpdir"):
|
||||
get_run_tmp_file.tmpdir.cleanup()
|
||||
|
||||
def generate_test_files(
|
||||
self,
|
||||
function_to_optimize: FunctionToOptimize,
|
||||
function_dependencies: List[Source],
|
||||
code_to_optimize_with_dependents: str,
|
||||
module_path: str,
|
||||
) -> str | None:
|
||||
generated_tests_path = get_test_file_path(
|
||||
self.args.test_root, function_to_optimize.function_name, 0
|
||||
)
|
||||
test_module_path = module_name_from_file_path(generated_tests_path, self.args.root)
|
||||
new_tests = generate_tests(
|
||||
source_code_being_tested=code_to_optimize_with_dependents,
|
||||
function=function_to_optimize,
|
||||
module_path=module_path,
|
||||
test_module_path=test_module_path,
|
||||
function_dependencies=function_dependencies,
|
||||
test_framework=self.args.test_framework,
|
||||
test_timeout=INDIVIDUAL_TEST_TIMEOUT,
|
||||
use_cached_tests=self.args.use_cached_tests,
|
||||
)
|
||||
if new_tests is None:
|
||||
logging.error("/!\\ NO TESTS GENERATED for %s", function_to_optimize.function_name)
|
||||
return None
|
||||
with open(generated_tests_path, "w") as file:
|
||||
file.write(new_tests)
|
||||
return generated_tests_path
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Optimizer(parse_args()).run()
|
||||
|
|
|
|||
|
|
@ -1,8 +1,8 @@
|
|||
import os
|
||||
import ast
|
||||
import os
|
||||
|
||||
|
||||
def get_test_file_path(test_dir, function_name, iteration=0, test_type="unit"):
|
||||
def get_test_file_path(test_dir: str, function_name: str, iteration: int = 0, test_type: str = "unit") -> str:
|
||||
assert test_type in ["unit", "inspired", "replay"]
|
||||
function_name = function_name.replace(".", "_")
|
||||
path = os.path.join(test_dir, f"test_{function_name}__{test_type}_test_{iteration}.py")
|
||||
|
|
@ -11,7 +11,7 @@ def get_test_file_path(test_dir, function_name, iteration=0, test_type="unit"):
|
|||
return path
|
||||
|
||||
|
||||
def delete_multiple_if_name_main(test_ast):
|
||||
def delete_multiple_if_name_main(test_ast: ast.Module) -> ast.Module:
|
||||
if_indexes = []
|
||||
for index, node in enumerate(test_ast.body):
|
||||
if isinstance(node, ast.If):
|
||||
|
|
|
|||
Loading…
Reference in a new issue