Multi-file PR schema changes.

This commit is contained in:
renaud 2024-02-08 15:52:49 -08:00
parent 26a1ee87f6
commit 22c6cb2017
4 changed files with 54 additions and 12 deletions

View file

@ -1,10 +1,9 @@
from typing import Union
from codeflash.verification.test_results import TestResults
from pydantic import BaseModel
from pydantic.dataclasses import dataclass
from codeflash.verification.test_results import TestResults
@dataclass(frozen=True, config={"arbitrary_types_allowed": True})
class PrComment:
@ -12,7 +11,7 @@ class PrComment:
best_runtime: int
original_runtime: int
function_name: str
relative_file_path: list[str]
relative_file_path: str
speedup_x: str
speedup_pct: str
winning_test_results: TestResults

View file

@ -9,9 +9,8 @@ from argparse import ArgumentParser, SUPPRESS, Namespace
from collections import defaultdict
from typing import Tuple, Union
import libcst as cst
import codeflash.cli_cmds.logging_config # intializes logging, has to be the first non-system import # noqa
import libcst as cst
from codeflash.api.aiservice import optimize_python_code
from codeflash.cli_cmds.cli import process_cmd_args
from codeflash.cli_cmds.cmd_init import CODEFLASH_LOGO
@ -401,6 +400,7 @@ class Optimizer:
check_create_pr(
optimize_all=self.args.all,
path=path,
original_code=original_dependent_code
| {path: original_code},
new_code=new_dependent_code | {path: new_code},

View file

@ -15,6 +15,7 @@ from codeflash.result.explanation import Explanation
def check_create_pr(
optimize_all: bool,
path: str,
original_code: dict[str, str],
new_code: dict[str, str],
explanation: Explanation,
@ -25,13 +26,12 @@ def check_create_pr(
if pr_number is not None:
logging.info(f"Suggesting changes to PR #{pr_number} ...")
owner, repo = get_repo_owner_and_name()
relative_file_path = [
os.path.relpath(p, git_root_dir()) for p in original_code.keys()
]
relative_path = os.path.relpath(path, git_root_dir())
response = cfapi.suggest_changes(
owner=owner,
repo=repo,
pr_number=pr_number,
relative_path=relative_path,
file_changes={
os.path.relpath(p, git_root_dir()): FileDiffContent(
oldContent=original_code[p], newContent=new_code[p]
@ -43,7 +43,7 @@ def check_create_pr(
best_runtime=explanation.best_runtime_ns,
original_runtime=explanation.original_runtime_ns,
function_name=explanation.function_name,
relative_file_path=relative_file_path,
relative_file_path=relative_path,
speedup_x=explanation.speedup_x,
speedup_pct=explanation.speedup_pct,
winning_test_results=explanation.winning_test_results,
@ -62,9 +62,7 @@ def check_create_pr(
logging.info("Creating a new PR with the optimized code...")
owner, repo = get_repo_owner_and_name()
relative_path = [
os.path.relpath(p, git_root_dir()) for p in original_code.keys()
]
relative_path = os.path.relpath(path, git_root_dir())
base_branch = get_current_branch()
response = cfapi.create_pr(
owner=owner,

View file

@ -220,3 +220,48 @@ print("Salut monde")
original_code, function_names, optim_code, preexisting_functions, "module"
)
assert new_code == expected
def test_test_libcst_code_replacement5():
optim_code = """def sorter_deps(arr):
supersort(badsort(arr))
return arr
def badsort(ploc):
donothing(ploc)
def supersort(doink):
for i in range(len(doink)):
fix(doink, i)
"""
original_code = """from code_to_optimize.bubble_sort_dep1_helper import dep1_comparer
from code_to_optimize.bubble_sort_dep2_swap import dep2_swap
def sorter_deps(arr):
for i in range(len(arr)):
for j in range(len(arr) - 1):
if dep1_comparer(arr, j):
dep2_swap(arr, j)
return arr
"""
expected = """from code_to_optimize.bubble_sort_dep1_helper import dep1_comparer
from code_to_optimize.bubble_sort_dep2_swap import dep2_swap
def sorter_deps(arr):
supersort(badsort(arr))
return arr
def badsort(ploc):
donothing(ploc)
def supersort(doink):
for i in range(len(doink)):
fix(doink, i)
"""
function_names: list[str] = ["sorter_deps"]
preexisting_functions: list[str] = ["sorter_deps"]
new_code: str = replace_functions_in_file(
original_code, function_names, optim_code, preexisting_functions, "module"
)
assert new_code == expected