mirror of
https://github.com/codeflash-ai/codeflash.git
synced 2026-05-04 18:25:17 +00:00
Multi-file PR schema changes.
This commit is contained in:
parent
26a1ee87f6
commit
22c6cb2017
4 changed files with 54 additions and 12 deletions
|
|
@ -1,10 +1,9 @@
|
|||
from typing import Union
|
||||
|
||||
from codeflash.verification.test_results import TestResults
|
||||
from pydantic import BaseModel
|
||||
from pydantic.dataclasses import dataclass
|
||||
|
||||
from codeflash.verification.test_results import TestResults
|
||||
|
||||
|
||||
@dataclass(frozen=True, config={"arbitrary_types_allowed": True})
|
||||
class PrComment:
|
||||
|
|
@ -12,7 +11,7 @@ class PrComment:
|
|||
best_runtime: int
|
||||
original_runtime: int
|
||||
function_name: str
|
||||
relative_file_path: list[str]
|
||||
relative_file_path: str
|
||||
speedup_x: str
|
||||
speedup_pct: str
|
||||
winning_test_results: TestResults
|
||||
|
|
|
|||
|
|
@ -9,9 +9,8 @@ from argparse import ArgumentParser, SUPPRESS, Namespace
|
|||
from collections import defaultdict
|
||||
from typing import Tuple, Union
|
||||
|
||||
import libcst as cst
|
||||
|
||||
import codeflash.cli_cmds.logging_config # intializes logging, has to be the first non-system import # noqa
|
||||
import libcst as cst
|
||||
from codeflash.api.aiservice import optimize_python_code
|
||||
from codeflash.cli_cmds.cli import process_cmd_args
|
||||
from codeflash.cli_cmds.cmd_init import CODEFLASH_LOGO
|
||||
|
|
@ -401,6 +400,7 @@ class Optimizer:
|
|||
|
||||
check_create_pr(
|
||||
optimize_all=self.args.all,
|
||||
path=path,
|
||||
original_code=original_dependent_code
|
||||
| {path: original_code},
|
||||
new_code=new_dependent_code | {path: new_code},
|
||||
|
|
|
|||
|
|
@ -15,6 +15,7 @@ from codeflash.result.explanation import Explanation
|
|||
|
||||
def check_create_pr(
|
||||
optimize_all: bool,
|
||||
path: str,
|
||||
original_code: dict[str, str],
|
||||
new_code: dict[str, str],
|
||||
explanation: Explanation,
|
||||
|
|
@ -25,13 +26,12 @@ def check_create_pr(
|
|||
if pr_number is not None:
|
||||
logging.info(f"Suggesting changes to PR #{pr_number} ...")
|
||||
owner, repo = get_repo_owner_and_name()
|
||||
relative_file_path = [
|
||||
os.path.relpath(p, git_root_dir()) for p in original_code.keys()
|
||||
]
|
||||
relative_path = os.path.relpath(path, git_root_dir())
|
||||
response = cfapi.suggest_changes(
|
||||
owner=owner,
|
||||
repo=repo,
|
||||
pr_number=pr_number,
|
||||
relative_path=relative_path,
|
||||
file_changes={
|
||||
os.path.relpath(p, git_root_dir()): FileDiffContent(
|
||||
oldContent=original_code[p], newContent=new_code[p]
|
||||
|
|
@ -43,7 +43,7 @@ def check_create_pr(
|
|||
best_runtime=explanation.best_runtime_ns,
|
||||
original_runtime=explanation.original_runtime_ns,
|
||||
function_name=explanation.function_name,
|
||||
relative_file_path=relative_file_path,
|
||||
relative_file_path=relative_path,
|
||||
speedup_x=explanation.speedup_x,
|
||||
speedup_pct=explanation.speedup_pct,
|
||||
winning_test_results=explanation.winning_test_results,
|
||||
|
|
@ -62,9 +62,7 @@ def check_create_pr(
|
|||
logging.info("Creating a new PR with the optimized code...")
|
||||
owner, repo = get_repo_owner_and_name()
|
||||
|
||||
relative_path = [
|
||||
os.path.relpath(p, git_root_dir()) for p in original_code.keys()
|
||||
]
|
||||
relative_path = os.path.relpath(path, git_root_dir())
|
||||
base_branch = get_current_branch()
|
||||
response = cfapi.create_pr(
|
||||
owner=owner,
|
||||
|
|
|
|||
|
|
@ -220,3 +220,48 @@ print("Salut monde")
|
|||
original_code, function_names, optim_code, preexisting_functions, "module"
|
||||
)
|
||||
assert new_code == expected
|
||||
|
||||
|
||||
def test_test_libcst_code_replacement5():
|
||||
optim_code = """def sorter_deps(arr):
|
||||
supersort(badsort(arr))
|
||||
return arr
|
||||
|
||||
def badsort(ploc):
|
||||
donothing(ploc)
|
||||
|
||||
def supersort(doink):
|
||||
for i in range(len(doink)):
|
||||
fix(doink, i)
|
||||
"""
|
||||
|
||||
original_code = """from code_to_optimize.bubble_sort_dep1_helper import dep1_comparer
|
||||
from code_to_optimize.bubble_sort_dep2_swap import dep2_swap
|
||||
|
||||
def sorter_deps(arr):
|
||||
for i in range(len(arr)):
|
||||
for j in range(len(arr) - 1):
|
||||
if dep1_comparer(arr, j):
|
||||
dep2_swap(arr, j)
|
||||
return arr
|
||||
"""
|
||||
expected = """from code_to_optimize.bubble_sort_dep1_helper import dep1_comparer
|
||||
from code_to_optimize.bubble_sort_dep2_swap import dep2_swap
|
||||
def sorter_deps(arr):
|
||||
supersort(badsort(arr))
|
||||
return arr
|
||||
|
||||
def badsort(ploc):
|
||||
donothing(ploc)
|
||||
|
||||
def supersort(doink):
|
||||
for i in range(len(doink)):
|
||||
fix(doink, i)
|
||||
"""
|
||||
|
||||
function_names: list[str] = ["sorter_deps"]
|
||||
preexisting_functions: list[str] = ["sorter_deps"]
|
||||
new_code: str = replace_functions_in_file(
|
||||
original_code, function_names, optim_code, preexisting_functions, "module"
|
||||
)
|
||||
assert new_code == expected
|
||||
|
|
|
|||
Loading…
Reference in a new issue