mirror of
https://github.com/codeflash-ai/codeflash.git
synced 2026-05-04 18:25:17 +00:00
Refactor PR creation process
This commit is contained in:
parent
52311a81f5
commit
2636646379
5 changed files with 172 additions and 95 deletions
6
code_to_optimize/bubble_sort_from_another_file.py
Normal file
6
code_to_optimize/bubble_sort_from_another_file.py
Normal file
|
|
@ -0,0 +1,6 @@
|
|||
from code_to_optimize.bubble_sort import sorter
|
||||
|
||||
|
||||
def sort_from_another_file(arr):
    """Return whatever the imported ``sorter`` produces for ``arr``."""
    return sorter(arr)
|
||||
|
|
@ -1,4 +1,5 @@
|
|||
"""Math utils."""
|
||||
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
|
|
@ -15,7 +16,7 @@ def cosine_similarity(X: Matrix, Y: Matrix) -> np.ndarray:
|
|||
if X.shape[1] != Y.shape[1]:
|
||||
raise ValueError(
|
||||
f"Number of columns in X and Y must be the same. X has shape {X.shape} "
|
||||
f"and Y has shape {Y.shape}."
|
||||
f"and Y has shape {Y.shape}.",
|
||||
)
|
||||
|
||||
X_norm = np.linalg.norm(X, axis=1)
|
||||
|
|
@ -34,14 +35,17 @@ def cosine_similarity_top_k(
|
|||
"""Row-wise cosine similarity with optional top-k and score threshold filtering.
|
||||
|
||||
Args:
|
||||
----
|
||||
X: Matrix.
|
||||
Y: Matrix, same width as X.
|
||||
top_k: Max number of results to return.
|
||||
score_threshold: Minimum cosine similarity of results.
|
||||
|
||||
Returns:
|
||||
-------
|
||||
Tuple of two lists. First contains two-tuples of indices (X_idx, Y_idx),
|
||||
second contains corresponding cosine similarities.
|
||||
|
||||
"""
|
||||
if len(X) == 0 or len(Y) == 0:
|
||||
return [], []
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
from typing import Optional, List, Tuple
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
from code_to_optimize.math_utils import cosine_similarity_top_k, Matrix
|
||||
from code_to_optimize.math_utils import Matrix, cosine_similarity_top_k
|
||||
|
||||
|
||||
def use_cosine_similarity(
|
||||
|
|
|
|||
|
|
@ -42,7 +42,7 @@ from codeflash.optimization.function_context import (
|
|||
Source,
|
||||
get_constrained_function_context_and_dependent_functions,
|
||||
)
|
||||
from codeflash.result.create_pr import check_create_pr
|
||||
from codeflash.result.create_pr import check_create_pr, existing_tests_source_for
|
||||
from codeflash.result.explanation import Explanation
|
||||
from codeflash.telemetry import posthog
|
||||
from codeflash.telemetry.posthog import ph
|
||||
|
|
@ -149,18 +149,14 @@ class Optimizer:
|
|||
f: IO[str]
|
||||
with open(path, encoding="utf8") as f:
|
||||
original_code: str = f.read()
|
||||
should_sort_imports = True
|
||||
if sort_imports(self.args.imports_sort_cmd, should_sort_imports, path) != original_code:
|
||||
should_sort_imports = False
|
||||
|
||||
function_to_optimize: FunctionToOptimize
|
||||
for function_to_optimize in file_to_funcs_to_optimize[path]:
|
||||
function_trace_id: str = str(uuid.uuid4())
|
||||
ph("cli-optimize-function-start", {"function_trace_id": function_trace_id})
|
||||
qualified_function_name: str = function_to_optimize.qualified_name
|
||||
function_iterator_count += 1
|
||||
logging.info(
|
||||
f"Optimizing function {function_iterator_count} of {num_optimizable_functions} - {qualified_function_name}",
|
||||
f"Optimizing function {function_iterator_count} of {num_optimizable_functions} - {function_to_optimize.qualified_name}",
|
||||
)
|
||||
winning_test_results = None
|
||||
self.cleanup_leftover_test_return_values()
|
||||
|
|
@ -168,7 +164,7 @@ class Optimizer:
|
|||
if not is_successful(ctx_result):
|
||||
logging.error(ctx_result.failure())
|
||||
continue
|
||||
code_context = ctx_result.unwrap()
|
||||
code_context: CodeOptimizationContext = ctx_result.unwrap()
|
||||
dependent_functions_by_module_abspath = defaultdict(set)
|
||||
for _, module_abspath, qualified_name in code_context.dependent_functions:
|
||||
dependent_functions_by_module_abspath[module_abspath].add(qualified_name)
|
||||
|
|
@ -181,7 +177,7 @@ class Optimizer:
|
|||
module_path = module_name_from_file_path(path, self.args.project_root)
|
||||
|
||||
instrumented_unittests_created_for_function = self.instrument_existing_tests(
|
||||
function_name=qualified_function_name,
|
||||
function_name=function_to_optimize.qualified_name,
|
||||
module_path=module_path,
|
||||
function_to_tests=function_to_tests,
|
||||
)
|
||||
|
|
@ -211,7 +207,7 @@ class Optimizer:
|
|||
|
||||
test_files_created.add(generated_tests_path)
|
||||
baseline_result = self.establish_original_code_baseline(
|
||||
qualified_function_name,
|
||||
function_to_optimize.qualified_name,
|
||||
instrumented_unittests_created_for_function,
|
||||
generated_tests_path,
|
||||
)
|
||||
|
|
@ -243,7 +239,7 @@ class Optimizer:
|
|||
logging.info(optimization.source_code)
|
||||
try:
|
||||
replace_function_definitions_in_module(
|
||||
[qualified_function_name],
|
||||
[function_to_optimize.qualified_name],
|
||||
optimization.source_code,
|
||||
path,
|
||||
code_context.preexisting_functions,
|
||||
|
|
@ -267,11 +263,12 @@ class Optimizer:
|
|||
AttributeError,
|
||||
) as e:
|
||||
logging.exception(e)
|
||||
with open(path, "w", encoding="utf8") as f:
|
||||
f.write(original_code)
|
||||
for module_abspath in dependent_functions_by_module_abspath.keys():
|
||||
with open(module_abspath, "w", encoding="utf8") as f:
|
||||
f.write(original_dependent_code[module_abspath])
|
||||
self.write_code_and_dependents(
|
||||
original_code,
|
||||
original_dependent_code,
|
||||
path,
|
||||
dependent_functions_by_module_abspath,
|
||||
)
|
||||
continue
|
||||
|
||||
run_results = self.run_optimized_candidate(
|
||||
|
|
@ -320,11 +317,12 @@ class Optimizer:
|
|||
)
|
||||
best_runtime = best_test_runtime
|
||||
winning_test_results = candidate_result.best_test_results
|
||||
with open(path, "w", encoding="utf8") as f:
|
||||
f.write(original_code)
|
||||
for module_abspath in dependent_functions_by_module_abspath.keys():
|
||||
with open(module_abspath, "w", encoding="utf8") as f:
|
||||
f.write(original_dependent_code[module_abspath])
|
||||
self.write_code_and_dependents(
|
||||
original_code,
|
||||
original_dependent_code,
|
||||
path,
|
||||
dependent_functions_by_module_abspath,
|
||||
)
|
||||
logging.info("----------------")
|
||||
log_results(
|
||||
function_trace_id=function_trace_id,
|
||||
|
|
@ -342,86 +340,69 @@ class Optimizer:
|
|||
)
|
||||
|
||||
optimized_code = best_optimization.source_code
|
||||
replace_function_definitions_in_module(
|
||||
[qualified_function_name],
|
||||
optimized_code,
|
||||
path,
|
||||
code_context.preexisting_functions,
|
||||
code_context.contextual_dunder_methods,
|
||||
)
|
||||
for (
|
||||
module_abspath,
|
||||
qualified_names,
|
||||
) in dependent_functions_by_module_abspath.items():
|
||||
replace_function_definitions_in_module(
|
||||
list(qualified_names),
|
||||
optimized_code,
|
||||
module_abspath,
|
||||
[],
|
||||
code_context.contextual_dunder_methods,
|
||||
)
|
||||
explanation_final = Explanation(
|
||||
explanation = Explanation(
|
||||
raw_explanation_message=best_optimization.explanation,
|
||||
winning_test_results=winning_test_results,
|
||||
original_runtime_ns=original_code_baseline.runtime,
|
||||
best_runtime_ns=best_runtime,
|
||||
function_name=qualified_function_name,
|
||||
function_name=function_to_optimize.qualified_name,
|
||||
path=path,
|
||||
)
|
||||
logging.info(f"Explanation: \n{explanation_final.to_console_string()}")
|
||||
|
||||
new_code = format_code(
|
||||
self.args.formatter_cmd,
|
||||
self.args.imports_sort_cmd,
|
||||
should_sort_imports,
|
||||
path,
|
||||
logging.info(
|
||||
f"⚡️ Optimization successful! 📄 {function_to_optimize.qualified_name} in {explanation.path}",
|
||||
)
|
||||
new_dependent_code: dict[str, str] = {
|
||||
module_abspath: format_code(
|
||||
self.args.formatter_cmd,
|
||||
self.args.imports_sort_cmd,
|
||||
should_sort_imports,
|
||||
module_abspath,
|
||||
)
|
||||
for module_abspath in dependent_functions_by_module_abspath.keys()
|
||||
}
|
||||
logging.info(f"📈 {explanation.perf_improvement_line}")
|
||||
logging.info(f"Explanation: \n{explanation.to_console_string()}")
|
||||
logging.info(
|
||||
f"Optimization was validated for correctness by running the following tests - "
|
||||
f"\n{tests_and_optimizations.generated_original_test_source}",
|
||||
)
|
||||
|
||||
logging.info(f"⚡️ Optimization successful! 📄 {qualified_function_name} in {path}")
|
||||
logging.info(f"📈 {explanation_final.perf_improvement_line}")
|
||||
|
||||
ph(
|
||||
"cli-optimize-success",
|
||||
{
|
||||
"function_trace_id": function_trace_id,
|
||||
"speedup_x": explanation_final.speedup_x,
|
||||
"speedup_pct": explanation_final.speedup_pct,
|
||||
"best_runtime": explanation_final.best_runtime_ns,
|
||||
"original_runtime": explanation_final.original_runtime_ns,
|
||||
"speedup_x": explanation.speedup_x,
|
||||
"speedup_pct": explanation.speedup_pct,
|
||||
"best_runtime": explanation.best_runtime_ns,
|
||||
"original_runtime": explanation.original_runtime_ns,
|
||||
"winning_test_results": {
|
||||
tt.to_name(): v
|
||||
for tt, v in explanation_final.winning_test_results.get_test_pass_fail_report_by_type().items()
|
||||
for tt, v in explanation.winning_test_results.get_test_pass_fail_report_by_type().items()
|
||||
},
|
||||
},
|
||||
)
|
||||
test_files = function_to_tests.get(module_path + "." + qualified_function_name)
|
||||
existing_tests = ""
|
||||
if test_files:
|
||||
for test_file in test_files:
|
||||
with open(test_file.test_file, encoding="utf8") as f:
|
||||
new_test = "".join(f.readlines())
|
||||
if new_test not in existing_tests:
|
||||
existing_tests += new_test
|
||||
|
||||
existing_tests = existing_tests_source_for(
|
||||
function_to_optimize.qualified_name,
|
||||
module_path,
|
||||
function_to_tests,
|
||||
)
|
||||
|
||||
self.replace_function_and_dependents_with_optimized_code(
|
||||
code_context,
|
||||
dependent_functions_by_module_abspath,
|
||||
explanation,
|
||||
optimized_code,
|
||||
function_to_optimize.qualified_name,
|
||||
)
|
||||
|
||||
new_code, new_dependent_code = self.reformat_code_and_dependents(
|
||||
dependent_functions_by_module_abspath,
|
||||
explanation.path,
|
||||
original_code,
|
||||
)
|
||||
|
||||
original_code_combined = original_dependent_code.copy()
|
||||
original_code_combined[explanation.path] = original_code
|
||||
new_code_combined = new_dependent_code.copy()
|
||||
new_code_combined[explanation.path] = new_code
|
||||
check_create_pr(
|
||||
optimize_all=self.args.all,
|
||||
path=path,
|
||||
original_code=original_dependent_code | {path: original_code},
|
||||
new_code=new_dependent_code | {path: new_code},
|
||||
explanation=explanation_final,
|
||||
original_code=original_code_combined,
|
||||
new_code=new_code_combined,
|
||||
explanation=explanation,
|
||||
existing_tests_source=existing_tests,
|
||||
generated_original_test_source=tests_and_optimizations.generated_original_test_source,
|
||||
)
|
||||
|
|
@ -430,11 +411,12 @@ class Optimizer:
|
|||
# a) Error propagation, where error in one function can cause the next optimization to fail
|
||||
# b) Performance estimates become unstable, as the runtime of an optimization might be
|
||||
# dependent on the runtime of the previous optimization
|
||||
with open(path, "w", encoding="utf8") as f:
|
||||
f.write(original_code)
|
||||
for module_abspath in dependent_functions_by_module_abspath.keys():
|
||||
with open(module_abspath, "w", encoding="utf8") as f:
|
||||
f.write(original_dependent_code[module_abspath])
|
||||
self.write_code_and_dependents(
|
||||
original_code,
|
||||
original_dependent_code,
|
||||
path,
|
||||
dependent_functions_by_module_abspath,
|
||||
)
|
||||
# Delete all the generated tests to not cause any clutter.
|
||||
pathlib.Path(generated_tests_path).unlink(missing_ok=True)
|
||||
for test_paths in instrumented_unittests_created_for_function:
|
||||
|
|
@ -453,6 +435,73 @@ class Optimizer:
|
|||
if hasattr(get_run_tmp_file, "tmpdir"):
|
||||
get_run_tmp_file.tmpdir.cleanup()
|
||||
|
||||
def write_code_and_dependents(
    self,
    original_code: str,
    original_dependent_code: Dict[str, str],
    path: str,
    dependent_functions_by_module_abspath: Dict[str, set[str]],
) -> None:
    """Overwrite ``path`` and every dependent module with the supplied source text.

    The callers visible in this file use it to put the original sources back
    on disk after an optimization candidate has been tried.

    Args:
        original_code: Text written to ``path``.
        original_dependent_code: Maps a dependent module's path to the text
            that is written to that module.
        path: File that receives ``original_code``.
        dependent_functions_by_module_abspath: Only the keys are used; each
            key is a module path that gets overwritten.

    Raises:
        KeyError: If a key of ``dependent_functions_by_module_abspath`` has no
            entry in ``original_dependent_code``. The affected module is left
            untouched in that case.
    """
    with open(path, "w", encoding="utf8") as f:
        f.write(original_code)
    for module_abspath in dependent_functions_by_module_abspath:
        # Resolve the text *before* opening: mode "w" truncates on open, so a
        # missing key would otherwise leave the module empty on disk.
        dependent_code = original_dependent_code[module_abspath]
        with open(module_abspath, "w", encoding="utf8") as f:
            f.write(dependent_code)
|
||||
|
||||
def reformat_code_and_dependents(
    self,
    dependent_functions_by_module_abspath: Dict[str, set[str]],
    path: str,
    original_code: str,
) -> Tuple[str, Dict[str, str]]:
    """Format the main file and each dependent module with the configured commands.

    Import sorting is switched on only when ``sort_imports`` for ``path``
    yields exactly ``original_code``; the same decision is applied to every
    dependent module.

    Args:
        dependent_functions_by_module_abspath: Only the keys are used; each
            key is a module path passed to ``format_code``.
        path: The main file passed to ``sort_imports`` and ``format_code``.
        original_code: Reference text compared against the ``sort_imports``
            result.

    Returns:
        A pair ``(new_code, new_dependent_code)``: the ``format_code`` result
        for ``path`` and a dict of ``format_code`` results keyed by module path.
    """
    # NOTE(review): in the calling code shown in this file this method runs
    # after the optimized code has been written to ``path``. If ``sort_imports``
    # works from the file on disk, this comparison may no longer tell whether
    # the *original* imports were sorted -- confirm against ``sort_imports``.
    should_sort_imports = (
        sort_imports(self.args.imports_sort_cmd, True, path) == original_code
    )

    def _formatted(target: str) -> str:
        return format_code(
            self.args.formatter_cmd,
            self.args.imports_sort_cmd,
            should_sort_imports,
            target,
        )

    new_code = _formatted(path)
    new_dependent_code: Dict[str, str] = {}
    for dependent_path in dependent_functions_by_module_abspath:
        new_dependent_code[dependent_path] = _formatted(dependent_path)
    return new_code, new_dependent_code
|
||||
|
||||
def replace_function_and_dependents_with_optimized_code(
    self,
    code_context: CodeOptimizationContext,
    dependent_functions_by_module_abspath: Dict[str, set[str]],
    explanation: Explanation,
    optimized_code: str,
    qualified_function_name: str,
) -> None:
    """Apply ``optimized_code`` to the target function and to its dependents.

    Args:
        code_context: Supplies ``preexisting_functions`` (used for the main
            file only) and ``contextual_dunder_methods`` (used for every file).
        dependent_functions_by_module_abspath: Maps a module path to the
            qualified names to replace in that module.
        explanation: Only ``explanation.path`` is read; it names the main file.
        optimized_code: Source text handed to every replacement call.
        qualified_function_name: The function replaced in the main file.
    """
    # Main file first, then each dependent module; dependents are given an
    # empty pre-existing-functions list.
    work_items = [
        ([qualified_function_name], explanation.path, code_context.preexisting_functions),
    ]
    work_items.extend(
        (list(names), dependent_path, [])
        for dependent_path, names in dependent_functions_by_module_abspath.items()
    )
    for names, target_path, preexisting in work_items:
        replace_function_definitions_in_module(
            names,
            optimized_code,
            target_path,
            preexisting,
            code_context.contextual_dunder_methods,
        )
|
||||
|
||||
def get_code_optimization_context(
|
||||
self,
|
||||
function_to_optimize: FunctionToOptimize,
|
||||
|
|
|
|||
|
|
@ -1,22 +1,38 @@
|
|||
import logging
|
||||
import os.path
|
||||
import pathlib
|
||||
from typing import Optional
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from codeflash.api import cfapi
|
||||
from codeflash.code_utils import env_utils
|
||||
from codeflash.code_utils.git_utils import (
|
||||
get_current_branch,
|
||||
get_repo_owner_and_name,
|
||||
git_root_dir,
|
||||
get_current_branch,
|
||||
)
|
||||
from codeflash.discovery.discover_unit_tests import TestsInFile
|
||||
from codeflash.github.PrComment import FileDiffContent, PrComment
|
||||
from codeflash.result.explanation import Explanation
|
||||
|
||||
|
||||
def existing_tests_source_for(
|
||||
qualified_function_name: str,
|
||||
module_path: str,
|
||||
function_to_tests: Dict[str, List[TestsInFile]],
|
||||
) -> str:
|
||||
test_files = function_to_tests.get(module_path + "." + qualified_function_name)
|
||||
existing_tests = ""
|
||||
if test_files:
|
||||
for test_file in test_files:
|
||||
with open(test_file.test_file, encoding="utf8") as f:
|
||||
new_test = "".join(f.readlines())
|
||||
if new_test not in existing_tests:
|
||||
existing_tests += new_test
|
||||
return existing_tests
|
||||
|
||||
|
||||
def check_create_pr(
|
||||
optimize_all: bool,
|
||||
path: str,
|
||||
original_code: dict[str, str],
|
||||
new_code: dict[str, str],
|
||||
explanation: Explanation,
|
||||
|
|
@ -28,16 +44,17 @@ def check_create_pr(
|
|||
if pr_number is not None:
|
||||
logging.info(f"Suggesting changes to PR #{pr_number} ...")
|
||||
owner, repo = get_repo_owner_and_name()
|
||||
relative_path = str(pathlib.Path(os.path.relpath(path, git_root_dir())).as_posix())
|
||||
relative_path = str(pathlib.Path(os.path.relpath(explanation.path, git_root_dir())).as_posix())
|
||||
response = cfapi.suggest_changes(
|
||||
owner=owner,
|
||||
repo=repo,
|
||||
pr_number=pr_number,
|
||||
file_changes={
|
||||
str(pathlib.Path(os.path.relpath(p, git_root_dir())).as_posix()): FileDiffContent(
|
||||
oldContent=original_code[p], newContent=new_code[p]
|
||||
oldContent=original_code[p],
|
||||
newContent=new_code[p],
|
||||
)
|
||||
for p in original_code.keys()
|
||||
for p in original_code
|
||||
},
|
||||
pr_comment=PrComment(
|
||||
optimization_explanation=explanation.explanation_message(),
|
||||
|
|
@ -53,18 +70,18 @@ def check_create_pr(
|
|||
generated_tests=generated_original_test_source,
|
||||
)
|
||||
if response.ok:
|
||||
logging.info("Suggestions were successfully made to PR #" + str(pr_number))
|
||||
logging.info(f"Suggestions were successfully made to PR #{pr_number}")
|
||||
else:
|
||||
logging.error(
|
||||
f"Optimization was successful, but I failed to suggest changes to PR #{pr_number}."
|
||||
f" Response from server was: {response.text}"
|
||||
f" Response from server was: {response.text}",
|
||||
)
|
||||
|
||||
elif optimize_all:
|
||||
logging.info("Creating a new PR with the optimized code...")
|
||||
owner, repo = get_repo_owner_and_name()
|
||||
|
||||
relative_path = str(pathlib.Path(os.path.relpath(path, git_root_dir())).as_posix())
|
||||
relative_path = str(pathlib.Path(os.path.relpath(explanation.path, git_root_dir())).as_posix())
|
||||
base_branch = get_current_branch()
|
||||
response = cfapi.create_pr(
|
||||
owner=owner,
|
||||
|
|
@ -72,9 +89,10 @@ def check_create_pr(
|
|||
base_branch=base_branch,
|
||||
file_changes={
|
||||
str(pathlib.Path(os.path.relpath(p, git_root_dir())).as_posix()): FileDiffContent(
|
||||
oldContent=original_code[p], newContent=new_code[p]
|
||||
oldContent=original_code[p],
|
||||
newContent=new_code[p],
|
||||
)
|
||||
for p in original_code.keys()
|
||||
for p in original_code
|
||||
},
|
||||
pr_comment=PrComment(
|
||||
optimization_explanation=explanation.explanation_message(),
|
||||
|
|
@ -94,5 +112,5 @@ def check_create_pr(
|
|||
else:
|
||||
logging.error(
|
||||
f"Optimization was successful, but I failed to create a PR with the optimized code."
|
||||
f" Response from server was: {response.text}"
|
||||
f" Response from server was: {response.text}",
|
||||
)
|
||||
|
|
|
|||
Loading…
Reference in a new issue