Refactor PR creation process

This commit is contained in:
afik.cohen 2024-04-12 17:38:45 -07:00
parent 52311a81f5
commit 2636646379
5 changed files with 172 additions and 95 deletions

View file

@ -0,0 +1,6 @@
from code_to_optimize.bubble_sort import sorter
def sort_from_another_file(arr):
    """Sort *arr* by delegating to the bubble-sort ``sorter`` helper."""
    return sorter(arr)

View file

@ -1,4 +1,5 @@
"""Math utils."""
from typing import List, Optional, Tuple, Union
import numpy as np
@ -15,7 +16,7 @@ def cosine_similarity(X: Matrix, Y: Matrix) -> np.ndarray:
if X.shape[1] != Y.shape[1]:
raise ValueError(
f"Number of columns in X and Y must be the same. X has shape {X.shape} "
f"and Y has shape {Y.shape}."
f"and Y has shape {Y.shape}.",
)
X_norm = np.linalg.norm(X, axis=1)
@ -34,14 +35,17 @@ def cosine_similarity_top_k(
"""Row-wise cosine similarity with optional top-k and score threshold filtering.
Args:
----
X: Matrix.
Y: Matrix, same width as X.
top_k: Max number of results to return.
score_threshold: Minimum cosine similarity of results.
Returns:
-------
Tuple of two lists. First contains two-tuples of indices (X_idx, Y_idx),
second contains corresponding cosine similarities.
"""
if len(X) == 0 or len(Y) == 0:
return [], []

View file

@ -1,6 +1,6 @@
from typing import Optional, List, Tuple
from typing import List, Optional, Tuple
from code_to_optimize.math_utils import cosine_similarity_top_k, Matrix
from code_to_optimize.math_utils import Matrix, cosine_similarity_top_k
def use_cosine_similarity(

View file

@ -42,7 +42,7 @@ from codeflash.optimization.function_context import (
Source,
get_constrained_function_context_and_dependent_functions,
)
from codeflash.result.create_pr import check_create_pr
from codeflash.result.create_pr import check_create_pr, existing_tests_source_for
from codeflash.result.explanation import Explanation
from codeflash.telemetry import posthog
from codeflash.telemetry.posthog import ph
@ -149,18 +149,14 @@ class Optimizer:
f: IO[str]
with open(path, encoding="utf8") as f:
original_code: str = f.read()
should_sort_imports = True
if sort_imports(self.args.imports_sort_cmd, should_sort_imports, path) != original_code:
should_sort_imports = False
function_to_optimize: FunctionToOptimize
for function_to_optimize in file_to_funcs_to_optimize[path]:
function_trace_id: str = str(uuid.uuid4())
ph("cli-optimize-function-start", {"function_trace_id": function_trace_id})
qualified_function_name: str = function_to_optimize.qualified_name
function_iterator_count += 1
logging.info(
f"Optimizing function {function_iterator_count} of {num_optimizable_functions} - {qualified_function_name}",
f"Optimizing function {function_iterator_count} of {num_optimizable_functions} - {function_to_optimize.qualified_name}",
)
winning_test_results = None
self.cleanup_leftover_test_return_values()
@ -168,7 +164,7 @@ class Optimizer:
if not is_successful(ctx_result):
logging.error(ctx_result.failure())
continue
code_context = ctx_result.unwrap()
code_context: CodeOptimizationContext = ctx_result.unwrap()
dependent_functions_by_module_abspath = defaultdict(set)
for _, module_abspath, qualified_name in code_context.dependent_functions:
dependent_functions_by_module_abspath[module_abspath].add(qualified_name)
@ -181,7 +177,7 @@ class Optimizer:
module_path = module_name_from_file_path(path, self.args.project_root)
instrumented_unittests_created_for_function = self.instrument_existing_tests(
function_name=qualified_function_name,
function_name=function_to_optimize.qualified_name,
module_path=module_path,
function_to_tests=function_to_tests,
)
@ -211,7 +207,7 @@ class Optimizer:
test_files_created.add(generated_tests_path)
baseline_result = self.establish_original_code_baseline(
qualified_function_name,
function_to_optimize.qualified_name,
instrumented_unittests_created_for_function,
generated_tests_path,
)
@ -243,7 +239,7 @@ class Optimizer:
logging.info(optimization.source_code)
try:
replace_function_definitions_in_module(
[qualified_function_name],
[function_to_optimize.qualified_name],
optimization.source_code,
path,
code_context.preexisting_functions,
@ -267,11 +263,12 @@ class Optimizer:
AttributeError,
) as e:
logging.exception(e)
with open(path, "w", encoding="utf8") as f:
f.write(original_code)
for module_abspath in dependent_functions_by_module_abspath.keys():
with open(module_abspath, "w", encoding="utf8") as f:
f.write(original_dependent_code[module_abspath])
self.write_code_and_dependents(
original_code,
original_dependent_code,
path,
dependent_functions_by_module_abspath,
)
continue
run_results = self.run_optimized_candidate(
@ -320,11 +317,12 @@ class Optimizer:
)
best_runtime = best_test_runtime
winning_test_results = candidate_result.best_test_results
with open(path, "w", encoding="utf8") as f:
f.write(original_code)
for module_abspath in dependent_functions_by_module_abspath.keys():
with open(module_abspath, "w", encoding="utf8") as f:
f.write(original_dependent_code[module_abspath])
self.write_code_and_dependents(
original_code,
original_dependent_code,
path,
dependent_functions_by_module_abspath,
)
logging.info("----------------")
log_results(
function_trace_id=function_trace_id,
@ -342,86 +340,69 @@ class Optimizer:
)
optimized_code = best_optimization.source_code
replace_function_definitions_in_module(
[qualified_function_name],
optimized_code,
path,
code_context.preexisting_functions,
code_context.contextual_dunder_methods,
)
for (
module_abspath,
qualified_names,
) in dependent_functions_by_module_abspath.items():
replace_function_definitions_in_module(
list(qualified_names),
optimized_code,
module_abspath,
[],
code_context.contextual_dunder_methods,
)
explanation_final = Explanation(
explanation = Explanation(
raw_explanation_message=best_optimization.explanation,
winning_test_results=winning_test_results,
original_runtime_ns=original_code_baseline.runtime,
best_runtime_ns=best_runtime,
function_name=qualified_function_name,
function_name=function_to_optimize.qualified_name,
path=path,
)
logging.info(f"Explanation: \n{explanation_final.to_console_string()}")
new_code = format_code(
self.args.formatter_cmd,
self.args.imports_sort_cmd,
should_sort_imports,
path,
logging.info(
f"⚡️ Optimization successful! 📄 {function_to_optimize.qualified_name} in {explanation.path}",
)
new_dependent_code: dict[str, str] = {
module_abspath: format_code(
self.args.formatter_cmd,
self.args.imports_sort_cmd,
should_sort_imports,
module_abspath,
)
for module_abspath in dependent_functions_by_module_abspath.keys()
}
logging.info(f"📈 {explanation.perf_improvement_line}")
logging.info(f"Explanation: \n{explanation.to_console_string()}")
logging.info(
f"Optimization was validated for correctness by running the following tests - "
f"\n{tests_and_optimizations.generated_original_test_source}",
)
logging.info(f"⚡️ Optimization successful! 📄 {qualified_function_name} in {path}")
logging.info(f"📈 {explanation_final.perf_improvement_line}")
ph(
"cli-optimize-success",
{
"function_trace_id": function_trace_id,
"speedup_x": explanation_final.speedup_x,
"speedup_pct": explanation_final.speedup_pct,
"best_runtime": explanation_final.best_runtime_ns,
"original_runtime": explanation_final.original_runtime_ns,
"speedup_x": explanation.speedup_x,
"speedup_pct": explanation.speedup_pct,
"best_runtime": explanation.best_runtime_ns,
"original_runtime": explanation.original_runtime_ns,
"winning_test_results": {
tt.to_name(): v
for tt, v in explanation_final.winning_test_results.get_test_pass_fail_report_by_type().items()
for tt, v in explanation.winning_test_results.get_test_pass_fail_report_by_type().items()
},
},
)
test_files = function_to_tests.get(module_path + "." + qualified_function_name)
existing_tests = ""
if test_files:
for test_file in test_files:
with open(test_file.test_file, encoding="utf8") as f:
new_test = "".join(f.readlines())
if new_test not in existing_tests:
existing_tests += new_test
existing_tests = existing_tests_source_for(
function_to_optimize.qualified_name,
module_path,
function_to_tests,
)
self.replace_function_and_dependents_with_optimized_code(
code_context,
dependent_functions_by_module_abspath,
explanation,
optimized_code,
function_to_optimize.qualified_name,
)
new_code, new_dependent_code = self.reformat_code_and_dependents(
dependent_functions_by_module_abspath,
explanation.path,
original_code,
)
original_code_combined = original_dependent_code.copy()
original_code_combined[explanation.path] = original_code
new_code_combined = new_dependent_code.copy()
new_code_combined[explanation.path] = new_code
check_create_pr(
optimize_all=self.args.all,
path=path,
original_code=original_dependent_code | {path: original_code},
new_code=new_dependent_code | {path: new_code},
explanation=explanation_final,
original_code=original_code_combined,
new_code=new_code_combined,
explanation=explanation,
existing_tests_source=existing_tests,
generated_original_test_source=tests_and_optimizations.generated_original_test_source,
)
@ -430,11 +411,12 @@ class Optimizer:
# a) Error propagation, where error in one function can cause the next optimization to fail
# b) Performance estimates become unstable, as the runtime of an optimization might be
# dependent on the runtime of the previous optimization
with open(path, "w", encoding="utf8") as f:
f.write(original_code)
for module_abspath in dependent_functions_by_module_abspath.keys():
with open(module_abspath, "w", encoding="utf8") as f:
f.write(original_dependent_code[module_abspath])
self.write_code_and_dependents(
original_code,
original_dependent_code,
path,
dependent_functions_by_module_abspath,
)
# Delete all the generated tests to not cause any clutter.
pathlib.Path(generated_tests_path).unlink(missing_ok=True)
for test_paths in instrumented_unittests_created_for_function:
@ -453,6 +435,73 @@ class Optimizer:
if hasattr(get_run_tmp_file, "tmpdir"):
get_run_tmp_file.tmpdir.cleanup()
def write_code_and_dependents(
    self,
    original_code: str,
    original_dependent_code: Dict[str, str],
    path: str,
    dependent_functions_by_module_abspath: Dict[str, set[str]],
) -> None:
    """Restore the optimized file and every dependent module to their original text.

    Args:
        original_code: Original source of the file at ``path``.
        original_dependent_code: Original source keyed by dependent module path.
        path: The file whose contents are restored first.
        dependent_functions_by_module_abspath: Dependent module paths (values unused here).
    """
    # Collect every (file, contents) pair, main file first, then write them all.
    restores = {path: original_code}
    for module_abspath in dependent_functions_by_module_abspath:
        restores[module_abspath] = original_dependent_code[module_abspath]
    for file_path, contents in restores.items():
        with open(file_path, "w", encoding="utf8") as out:
            out.write(contents)
def reformat_code_and_dependents(
    self,
    dependent_functions_by_module_abspath: Dict[str, set[str]],
    path: str,
    original_code: str,
) -> Tuple[str, Dict[str, str]]:
    """Format the optimized file and all dependent modules.

    Import sorting is only kept enabled when sorting the main file leaves
    its original contents unchanged (i.e. the project already keeps imports
    sorted); otherwise formatting runs with sorting disabled.

    Returns:
        The formatted main-file source and a mapping of dependent module
        path to its formatted source.
    """
    # Probe once: keep sorting only if it is a no-op on the original file.
    should_sort_imports = (
        sort_imports(self.args.imports_sort_cmd, True, path) == original_code
    )
    formatted_main = format_code(
        self.args.formatter_cmd,
        self.args.imports_sort_cmd,
        should_sort_imports,
        path,
    )
    formatted_dependents: dict[str, str] = {}
    for module_abspath in dependent_functions_by_module_abspath:
        formatted_dependents[module_abspath] = format_code(
            self.args.formatter_cmd,
            self.args.imports_sort_cmd,
            should_sort_imports,
            module_abspath,
        )
    return formatted_main, formatted_dependents
def replace_function_and_dependents_with_optimized_code(
    self,
    code_context: CodeOptimizationContext,
    dependent_functions_by_module_abspath: Dict[str, set[str]],
    explanation: Explanation,
    optimized_code: str,
    qualified_function_name: str,
) -> None:
    """Write the optimized source into the target module and every dependent module.

    The target module keeps its preexisting functions; dependent modules only
    have the referenced qualified names rewritten (empty preexisting list).
    """
    # (names, module path, preexisting functions) — target module first.
    targets = [
        ([qualified_function_name], explanation.path, code_context.preexisting_functions),
    ]
    targets.extend(
        (list(names), module_abspath, [])
        for module_abspath, names in dependent_functions_by_module_abspath.items()
    )
    for names, module_path, preexisting in targets:
        replace_function_definitions_in_module(
            names,
            optimized_code,
            module_path,
            preexisting,
            code_context.contextual_dunder_methods,
        )
def get_code_optimization_context(
self,
function_to_optimize: FunctionToOptimize,

View file

@ -1,22 +1,38 @@
import logging
import os.path
import pathlib
from typing import Optional
from typing import Dict, List, Optional
from codeflash.api import cfapi
from codeflash.code_utils import env_utils
from codeflash.code_utils.git_utils import (
get_current_branch,
get_repo_owner_and_name,
git_root_dir,
get_current_branch,
)
from codeflash.discovery.discover_unit_tests import TestsInFile
from codeflash.github.PrComment import FileDiffContent, PrComment
from codeflash.result.explanation import Explanation
def existing_tests_source_for(
    qualified_function_name: str,
    module_path: str,
    function_to_tests: Dict[str, List[TestsInFile]],
) -> str:
    """Concatenate the source of every discovered test file covering a function.

    Args:
        qualified_function_name: Qualified name of the function within its module.
        module_path: Dotted module path; joined with the function name to key
            into ``function_to_tests``.
        function_to_tests: Mapping of fully-qualified function name to the
            test files that exercise it.

    Returns:
        The deduplicated, concatenated test-file sources, or "" when no tests
        reference the function.
    """
    existing_tests = ""
    # `or []` preserves the original guard: missing key or a falsy value both
    # yield no iterations instead of crashing on None.
    for test_file in function_to_tests.get(f"{module_path}.{qualified_function_name}") or []:
        # f.read() replaces the needlessly indirect "".join(f.readlines()).
        with open(test_file.test_file, encoding="utf8") as f:
            new_test = f.read()
        # Substring check avoids appending the same file contents twice.
        if new_test not in existing_tests:
            existing_tests += new_test
    return existing_tests
def check_create_pr(
optimize_all: bool,
path: str,
original_code: dict[str, str],
new_code: dict[str, str],
explanation: Explanation,
@ -28,16 +44,17 @@ def check_create_pr(
if pr_number is not None:
logging.info(f"Suggesting changes to PR #{pr_number} ...")
owner, repo = get_repo_owner_and_name()
relative_path = str(pathlib.Path(os.path.relpath(path, git_root_dir())).as_posix())
relative_path = str(pathlib.Path(os.path.relpath(explanation.path, git_root_dir())).as_posix())
response = cfapi.suggest_changes(
owner=owner,
repo=repo,
pr_number=pr_number,
file_changes={
str(pathlib.Path(os.path.relpath(p, git_root_dir())).as_posix()): FileDiffContent(
oldContent=original_code[p], newContent=new_code[p]
oldContent=original_code[p],
newContent=new_code[p],
)
for p in original_code.keys()
for p in original_code
},
pr_comment=PrComment(
optimization_explanation=explanation.explanation_message(),
@ -53,18 +70,18 @@ def check_create_pr(
generated_tests=generated_original_test_source,
)
if response.ok:
logging.info("Suggestions were successfully made to PR #" + str(pr_number))
logging.info(f"Suggestions were successfully made to PR #{pr_number}")
else:
logging.error(
f"Optimization was successful, but I failed to suggest changes to PR #{pr_number}."
f" Response from server was: {response.text}"
f" Response from server was: {response.text}",
)
elif optimize_all:
logging.info("Creating a new PR with the optimized code...")
owner, repo = get_repo_owner_and_name()
relative_path = str(pathlib.Path(os.path.relpath(path, git_root_dir())).as_posix())
relative_path = str(pathlib.Path(os.path.relpath(explanation.path, git_root_dir())).as_posix())
base_branch = get_current_branch()
response = cfapi.create_pr(
owner=owner,
@ -72,9 +89,10 @@ def check_create_pr(
base_branch=base_branch,
file_changes={
str(pathlib.Path(os.path.relpath(p, git_root_dir())).as_posix()): FileDiffContent(
oldContent=original_code[p], newContent=new_code[p]
oldContent=original_code[p],
newContent=new_code[p],
)
for p in original_code.keys()
for p in original_code
},
pr_comment=PrComment(
optimization_explanation=explanation.explanation_message(),
@ -94,5 +112,5 @@ def check_create_pr(
else:
logging.error(
f"Optimization was successful, but I failed to create a PR with the optimized code."
f" Response from server was: {response.text}"
f" Response from server was: {response.text}",
)