Add diff layer: SEARCH/REPLACE and V4A patch application
Faithfully ported from Django aiservice. V4A uses 3-tier fuzzy context matching (exact/rstrip/strip) with EOF penalties and scope markers. Per-file lint ignores for ported complexity.
This commit is contained in:
parent
2acebdbf51
commit
5c6b82050a
5 changed files with 984 additions and 0 deletions
42
packages/codeflash-api/src/codeflash_api/diff/_base.py
Normal file
42
packages/codeflash-api/src/codeflash_api/diff/_base.py
Normal file
|
|
@@ -0,0 +1,42 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import enum
|
||||
from abc import abstractmethod
|
||||
|
||||
|
||||
class DiffError(ValueError):
    """Raised for any problem detected while parsing or applying a patch."""
|
||||
|
||||
|
||||
class DiffMethod(enum.Enum):
    """Supported strategies for applying model-generated diffs."""

    NO_DIFF = "no_diff"
    V4A = "v4a"
    SEARCH_AND_REPLACE = "search_and_replace"
|
||||
|
||||
|
||||
class Diff:
    """Base class for diff application strategies.

    Subclasses implement :meth:`run`, which applies ``content`` (raw diff
    text) to ``source_code`` (mapping of file path -> file text).

    NOTE(review): ``@abstractmethod`` has no effect without an ``abc.ABC``
    base, so instantiating ``Diff`` directly is not prevented — presumably
    intentional to keep the class lightweight; confirm.
    """

    def __init__(
        self,
        content: str,
        source_code: dict[str, str],
        *,
        match_files_when_having_single_patch: bool = True,
    ) -> None:
        # Raw diff/patch text to be applied.
        self.content = content
        # Mapping of file path -> current file contents.
        self.source_code = source_code
        # Slot for cached per-file diffs; not written by the subclasses
        # visible in this module.
        self.extracted_diff: dict[str, str] | None = None
        # When False (and exactly one file plus one patch exist), a patch
        # may be applied to the single source file even if paths differ.
        self.match_files_when_having_single_patch = (
            match_files_when_having_single_patch
        )

    @abstractmethod
    def run(self) -> dict[str, str]:
        """Apply the diff to the source code and return the modified files."""
|
||||
187
packages/codeflash-api/src/codeflash_api/diff/_search_replace.py
Normal file
187
packages/codeflash-api/src/codeflash_api/diff/_search_replace.py
Normal file
|
|
@@ -0,0 +1,187 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
|
||||
import attrs
|
||||
|
||||
from codeflash_api.diff._base import Diff
|
||||
|
||||
# Captures the inner content of each <replace_in_file>...</replace_in_file>
# tag (non-greedy, spanning newlines).
REPLACE_IN_FILE_TAGS_RE = re.compile(
    r"<replace_in_file>(.*?)</replace_in_file>",
    re.DOTALL | re.MULTILINE,
)
# Captures (path, diff) pairs from one or more complete <replace_in_file>
# blocks; [\s\S] makes the diff group span newlines without DOTALL.
MULTI_REPLACE_IN_FILE_RE = re.compile(
    r"<replace_in_file>\s*<path>(.*?)</path>"
    r"\s*<diff>([\s\S]*?)</diff>\s*</replace_in_file>",
    re.MULTILINE,
)

# Conflict-style markers delimiting one SEARCH/REPLACE block.
SEARCH_MARKER = "<<<<<<< SEARCH"
DELIMITER = "======="
REPLACE_MARKER = ">>>>>>> REPLACE"
|
||||
|
||||
|
||||
@attrs.frozen
class SearchReplaceBlock:
    """An immutable SEARCH/REPLACE pair parsed from a diff."""

    # Text to locate in the target file.
    search: str
    # Text that replaces the located span.
    replace: str
|
||||
|
||||
|
||||
def parse_diff(diff: str) -> list[SearchReplaceBlock]:
    """Parse SEARCH/REPLACE blocks out of *diff*.

    A block has the shape::

        <<<<<<< SEARCH
        ...text to find...
        =======
        ...replacement text...
        >>>>>>> REPLACE

    Both sides are right-stripped. Markers are compared after stripping
    trailing whitespace from each line.

    Returns:
        One :class:`SearchReplaceBlock` per complete block, in order.

    Raises:
        ValueError: If *diff* is empty/blank, a marker is missing, or no
            complete block is found.
    """
    if not diff or not diff.strip():
        raise ValueError("Empty or invalid diff string provided")

    raw_lines = diff.splitlines(keepends=True)
    parsed: list[SearchReplaceBlock] = []
    search_at: int | None = None
    sep_at: int | None = None

    for idx, raw in enumerate(raw_lines):
        marker = raw.rstrip()
        if marker == SEARCH_MARKER:
            search_at = idx
        elif marker == DELIMITER and search_at is not None:
            # NOTE(review): a later "=======" line inside the replacement
            # text would also match here and shift the split point; kept
            # as-is to preserve behavior.
            sep_at = idx
        elif marker == REPLACE_MARKER and sep_at is not None:
            search_text = "".join(raw_lines[search_at + 1 : sep_at]).rstrip()
            replace_text = "".join(raw_lines[sep_at + 1 : idx]).rstrip()
            parsed.append(
                SearchReplaceBlock(search=search_text, replace=replace_text)
            )
            search_at = None
            sep_at = None

    if search_at is not None and sep_at is None:
        raise ValueError("Invalid diff format: Missing '=======' marker")
    if sep_at is not None:
        raise ValueError("Invalid diff format: Missing '>>>>>>> REPLACE' marker")
    if not parsed:
        raise ValueError("No valid SEARCH/REPLACE blocks found in the diff")

    return parsed
|
||||
|
||||
|
||||
def group_diff_patches_by_path(
    replace_tags_str: str,
) -> dict[str, str]:
    """Group diff payloads by file path from ``<replace_in_file>`` content.

    Multiple blocks targeting the same path are concatenated with a
    newline, preserving their order of appearance.
    """
    grouped: dict[str, str] = {}
    for raw_path, body in MULTI_REPLACE_IN_FILE_RE.findall(replace_tags_str):
        key = raw_path.strip()
        grouped[key] = f"{grouped[key]}\n{body}" if key in grouped else body
    return grouped
|
||||
|
||||
|
||||
def extract_patches(content: str) -> dict[str, str]:
    """Extract all SEARCH/REPLACE patches from *content*, grouped by path.

    Returns:
        Mapping of file path -> concatenated diff text, or ``{}`` when no
        ``<replace_in_file>`` tags are present.
    """
    matches = REPLACE_IN_FILE_TAGS_RE.findall(content)
    if not matches:
        return {}
    # BUG FIX: previously only matches[0] was re-wrapped, which silently
    # dropped every <replace_in_file> block after the first. Re-wrap all of
    # them so multi-file responses are fully extracted; single-block input
    # behaves exactly as before.
    wrapped = "".join(f"<replace_in_file>{m}</replace_in_file>" for m in matches)
    return group_diff_patches_by_path(wrapped)
|
||||
|
||||
|
||||
def find_with_whitespace_flexibility(
|
||||
search: str, content: str
|
||||
) -> tuple[int, int] | None:
|
||||
"""
|
||||
Match *search* in *content* with flexible whitespace.
|
||||
"""
|
||||
parts = re.split(r"(\s+)", search)
|
||||
pattern_parts: list[str] = []
|
||||
for part in parts:
|
||||
if not part:
|
||||
continue
|
||||
if re.match(r"^\s+$", part):
|
||||
pattern_parts.append(r"\s+")
|
||||
else:
|
||||
pattern_parts.append(re.escape(part))
|
||||
|
||||
pattern = "".join(pattern_parts)
|
||||
match = re.search(pattern, content, re.MULTILINE | re.DOTALL)
|
||||
if match:
|
||||
return (match.start(), match.end())
|
||||
return None
|
||||
|
||||
|
||||
def apply_patches(diff_str: str, content: str) -> str:
    """Apply every SEARCH/REPLACE block in *diff_str* to *content*.

    Best-effort semantics: an unparseable diff or an unmatched search
    leaves *content* untouched. A block with an empty search and a
    non-empty replace appends the replacement to the end of the content.
    """
    try:
        parsed = parse_diff(diff_str)
    except ValueError:
        # Malformed diff: apply no changes at all.
        return content

    for patch in parsed:
        if not patch.search and patch.replace:
            # Append mode: empty search means "add at end".
            content += patch.replace
        elif patch.search:
            position = content.find(patch.search)
            if position >= 0:
                # Exact match: replace only the first occurrence.
                content = content.replace(patch.search, patch.replace, 1)
                continue
            # Fall back to whitespace-flexible matching.
            located = find_with_whitespace_flexibility(patch.search, content)
            if located is not None:
                lo, hi = located
                content = content[:lo] + patch.replace + content[hi:]

    return content
|
||||
|
||||
|
||||
class SearchAndReplaceDiff(Diff):
    """Apply SEARCH/REPLACE diffs wrapped in ``<replace_in_file>`` tags."""

    def run(self) -> dict[str, str]:
        """Apply extracted patches and return path -> new contents.

        When ``match_files_when_having_single_patch`` is False and there is
        exactly one source file and one patch, the patch is applied to that
        file regardless of the path it names; otherwise patches are matched
        to files by path and unmatched files pass through unchanged.
        """
        patches = extract_patches(self.content)

        single_file_shortcut = (
            not self.match_files_when_having_single_patch
            and len(self.source_code) == 1
            and len(patches) == 1
        )
        if single_file_shortcut:
            path = next(iter(self.source_code))
            diff_text = next(iter(patches.values()))
            return {path: apply_patches(diff_text, self.source_code[path])}

        return {
            path: apply_patches(patches[path], code) if path in patches else code
            for path, code in self.source_code.items()
        }
|
||||
415
packages/codeflash-api/src/codeflash_api/diff/_v4a.py
Normal file
415
packages/codeflash-api/src/codeflash_api/diff/_v4a.py
Normal file
|
|
@@ -0,0 +1,415 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
|
||||
import attrs
|
||||
|
||||
from codeflash_api.diff._base import Diff, DiffError
|
||||
|
||||
# Header line introducing each per-file section in a V4A patch.
UPDATE_FILE_PREFIX = "*** Update File: "
|
||||
|
||||
|
||||
@attrs.define
class Chunk:
    """A single add/delete edit anchored at a line of the original file."""

    # Line index in the original file where the edit applies; -1 until the
    # surrounding context has been located.
    orig_index: int = -1
    # Lines removed from the original file.
    del_lines: list[str] = attrs.field(factory=list)
    # Lines inserted in place of the deleted ones.
    ins_lines: list[str] = attrs.field(factory=list)
|
||||
|
||||
|
||||
@attrs.define
class PatchAction:
    """All chunks targeting a single file."""

    # Path of the file the chunks apply to.
    path: str = ""
    # Edits in file order.
    chunks: list[Chunk] = attrs.field(factory=list)
|
||||
|
||||
|
||||
@attrs.define
class Patch:
    """A fully parsed patch: per-file actions plus accumulated fuzz."""

    # Mapping of file path -> the action to apply to that file.
    actions: dict[str, PatchAction] = attrs.field(factory=dict)
    # Total fuzziness accumulated while matching contexts; higher means
    # the patch matched the file less precisely.
    fuzz: int = 0
|
||||
|
||||
|
||||
def _norm(line: str) -> str:
|
||||
"""
|
||||
Strip CR for LF/CRLF compatibility.
|
||||
"""
|
||||
return line.replace("\r", "")
|
||||
|
||||
|
||||
def find_context_core(
    lines: list[str],
    context: list[str],
    start: int,
) -> tuple[int, int]:
    """Search for *context* in *lines* at or after *start* with 3-tier fuzz.

    Tiers, tried in order: exact match (fuzz 0), trailing-whitespace-
    insensitive (fuzz 1), fully whitespace-insensitive (fuzz 100).

    Returns:
        ``(index, fuzz)`` of the first match, or ``(-1, 0)`` when absent.
    """
    span = len(context)
    limit = len(lines) - span + 1

    def scan(canon) -> int:
        # Compare canonicalized windows of len(context) lines.
        target = [canon(s) for s in context]
        for pos in range(start, limit):
            if [canon(s) for s in lines[pos : pos + span]] == target:
                return pos
        return -1

    for canon, penalty in (
        (lambda s: s, 0),
        (str.rstrip, 1),
        (str.strip, 100),
    ):
        pos = scan(canon)
        if pos >= 0:
            return (pos, penalty)

    return (-1, 0)
|
||||
|
||||
|
||||
def find_context(
    lines: list[str],
    context: list[str],
    start: int,
    *,
    eof: bool,
) -> tuple[int, int]:
    """Locate *context*, preferring the end of the file when *eof* is set.

    An ``*** End of File`` section should match at the very end of the
    file; when it only matches earlier, a large penalty (10,000) is added
    to the fuzz so callers can rank such matches as highly suspect.
    """
    if not eof:
        return find_context_core(lines, context, start)

    # First try anchored at the end of the file.
    at_end, fuzz = find_context_core(lines, context, len(lines) - len(context))
    if at_end >= 0:
        return (at_end, fuzz)
    # Otherwise search normally, but penalize heavily.
    anywhere, fuzz = find_context_core(lines, context, start)
    if anywhere >= 0:
        return (anywhere, fuzz + 10_000)
    return (-1, 0)
|
||||
|
||||
|
||||
def peek_next_section(
    lines: list[str], index: int
) -> tuple[list[str], list[Chunk], int, bool]:
    """
    Parse one context/change section from an UPDATE block.

    Consumes lines from *lines* starting at *index* until the next ``@@``
    scope marker, an ``*** End of File`` marker, or end of input.

    Returns:
        ``(context_lines, chunks, next_index, is_eof)`` where
        *context_lines* contains every line expected to exist in the
        original file (kept lines plus deleted lines), *chunks* carry
        ``orig_index`` values relative to the start of the context block,
        *next_index* is the first unconsumed line, and *is_eof* reports
        whether an ``*** End of File`` marker ended the section.

    Raises:
        DiffError: On a line that is neither context (leading space or
            blank), deletion (``-``), addition (``+``), nor a recognized
            marker. Note an unrecognized ``*** `` line falls through to
            this error rather than terminating the section.
    """
    context_lines: list[str] = []
    chunks: list[Chunk] = []
    del_lines: list[str] = []
    ins_lines: list[str] = []
    # State machine: "keep" between edits, "delete"/"add" while inside one.
    mode = "keep"
    is_eof = False

    i = index
    while i < len(lines):
        line = _norm(lines[i])

        if line.startswith("@@"):
            # Next scope marker: section is over; do not consume the line.
            break
        if line.startswith("*** "):
            tag = line.strip()
            if tag == "*** End of File":
                is_eof = True
                i += 1
                break

        if line.startswith(" ") or line == "":
            # A kept line closes any in-progress chunk.
            if mode in ("add", "delete"):
                chunks.append(
                    Chunk(
                        # Offset of the chunk's first deleted line within
                        # the context block.
                        orig_index=len(context_lines) - len(del_lines),
                        del_lines=del_lines,
                        ins_lines=ins_lines,
                    )
                )
                del_lines = []
                ins_lines = []
                mode = "keep"

            # Strip the single leading space; a blank line is kept as "".
            ctx = line[1:] if line.startswith(" ") else ""
            context_lines.append(ctx)
            i += 1

        elif line.startswith("-"):
            # A deletion after additions starts a new chunk.
            if mode == "add":
                chunks.append(
                    Chunk(
                        orig_index=len(context_lines) - len(del_lines),
                        del_lines=del_lines,
                        ins_lines=ins_lines,
                    )
                )
                del_lines = []
                ins_lines = []
            mode = "delete"
            stripped = line[1:]
            del_lines.append(stripped)
            # Deleted lines are also context: they must exist in the file.
            context_lines.append(stripped)
            i += 1

        elif line.startswith("+"):
            mode = "add"
            ins_lines.append(line[1:])
            i += 1

        else:
            msg = f"Invalid patch line found: {line!r}"
            raise DiffError(msg)

    # Flush a chunk left open at the end of the section.
    if mode in ("add", "delete"):
        chunks.append(
            Chunk(
                orig_index=len(context_lines) - len(del_lines),
                del_lines=del_lines,
                ins_lines=ins_lines,
            )
        )

    return (context_lines, chunks, i, is_eof)
|
||||
|
||||
|
||||
def _parse_update_file_sections(
    lines: list[str],
    index: int,
    file_content: str,
) -> tuple[PatchAction, int, int]:
    """
    Parse all @@ scope markers and sections for one file.

    Walks *lines* from *index* until the next ``*** `` file header,
    resolving each ``@@`` scope marker against *file_content* and
    anchoring every context block via :func:`find_context`.

    Returns:
        ``(action, next_index, total_fuzz)``: the positioned chunks, the
        first unconsumed line index, and the accumulated match fuzz.

    Raises:
        DiffError: When a scope marker or a context block cannot be
            located in the file.
    """
    action = PatchAction()
    total_fuzz = 0
    file_lines = [_norm(ln) for ln in file_content.splitlines()]
    # Cursor into file_lines; scopes and contexts are searched forward
    # from here so sections apply in file order.
    current_file_idx = 0

    i = index
    while i < len(lines):
        line = _norm(lines[i])
        if not line.strip():
            i += 1
            continue
        # A new "*** " header ends this file's sections; "*** End of File"
        # is left for peek_next_section to consume.
        if line.startswith("*** ") and not line.startswith("*** End of File"):
            break

        if line.startswith("@@"):
            # Collect a run of consecutive @@ scope markers, stripping the
            # @ decorations.
            scope_lines: list[str] = []
            while i < len(lines) and _norm(lines[i]).startswith("@@"):
                scope_lines.append(_norm(lines[i]).strip().strip("@").strip())
                i += 1

            # First pass: match a scope against whitespace-stripped file
            # lines, scanning forward from the cursor.
            found = False
            for scope in scope_lines:
                for j in range(current_file_idx, len(file_lines)):
                    if file_lines[j].strip() == scope:
                        current_file_idx = j
                        found = True
                        break
                if found:
                    break

            if not found:
                # Fuzzy fallback (fuzz +1).
                # NOTE(review): `scope` is already stripped when built
                # above, so `scope.strip()` equals `scope` and this pass
                # compares exactly the same values as the first one — it
                # can never match anything new. Presumably the first pass
                # was meant to compare the *unstripped* file line; confirm
                # against the upstream Django implementation.
                for scope in scope_lines:
                    for j in range(current_file_idx, len(file_lines)):
                        if file_lines[j].strip() == scope.strip():
                            current_file_idx = j
                            total_fuzz += 1
                            found = True
                            break
                    if found:
                        break

            if not found:
                msg = f"Could not find scope: {scope_lines!r}"
                raise DiffError(msg)
            continue

        # Parse the next context/change section starting at line i.
        context_block, chunks, next_i, is_eof = peek_next_section(lines, i)
        i = next_i

        if not context_block and not chunks:
            continue

        if context_block:
            # Anchor the section's context in the file (3-tier fuzzy
            # match, with EOF preference when the section ended with
            # "*** End of File").
            found_idx, fuzz = find_context(
                file_lines,
                context_block,
                current_file_idx,
                eof=is_eof,
            )
            if found_idx < 0:
                msg = (
                    "Could not find context block"
                    f" starting with: {context_block[:3]!r}"
                )
                raise DiffError(msg)
            total_fuzz += fuzz
            # Rebase chunk offsets from section-relative to file-absolute.
            for chunk in chunks:
                chunk.orig_index += found_idx
            current_file_idx = found_idx + len(context_block)
        else:
            # Context-free chunks apply at the current cursor position.
            for chunk in chunks:
                chunk.orig_index += current_file_idx

        action.chunks.extend(chunks)

    return (action, i, total_fuzz)
|
||||
|
||||
|
||||
def parse_patch_text(
    lines: list[str],
    current_files: dict[str, str],
) -> Patch:
    """Parse full patch content into a :class:`Patch`.

    Only ``*** Update File:`` sections are supported, and each named file
    must exist in *current_files*. Repeated sections for the same path
    have their chunks merged.

    Raises:
        DiffError: On an unknown top-level line or a missing file.
    """
    update_prefix = UPDATE_FILE_PREFIX.strip()
    patch = Patch()
    cursor = 0
    while cursor < len(lines):
        text = _norm(lines[cursor]).strip()
        if not text:
            cursor += 1
            continue

        if not text.startswith(update_prefix):
            raise DiffError(f"Unknown or misplaced line: {_norm(lines[cursor])!r}")

        path = text[len(update_prefix) :].strip()
        if path not in current_files:
            raise DiffError(f"File not found: {path!r}")

        action, cursor, fuzz = _parse_update_file_sections(
            lines, cursor + 1, current_files[path]
        )
        action.path = path
        existing = patch.actions.get(path)
        if existing is None:
            patch.actions[path] = action
        else:
            existing.chunks.extend(action.chunks)
        patch.fuzz += fuzz

    return patch
|
||||
|
||||
|
||||
# Captures everything between the "*** Begin Patch" / "*** End Patch"
# envelope markers (non-greedy, spanning newlines).
_PATCH_BODY_RE = re.compile(
    r"\*\*\* Begin Patch\s*(.*?)\s*\*\*\* End Patch",
    re.DOTALL,
)
# Splits the patch body at each per-file "*** Update File:" header.
_UPDATE_SPLIT_RE = re.compile(r"^\*\*\* Update File:\s*", re.MULTILINE)
|
||||
|
||||
|
||||
def extract_update_sections(
    content: str,
) -> dict[str, str]:
    """Split patch *content* into per-file UPDATE sections.

    Each value is re-prefixed with ``*** Update File: <path>`` so it can
    be parsed standalone by :func:`parse_patch_text`.

    Returns:
        Mapping of file path -> section text; ``{}`` when no
        ``*** Begin Patch``/``*** End Patch`` envelope is present.
    """
    envelope = _PATCH_BODY_RE.search(content)
    if envelope is None:
        return {}

    sections: dict[str, str] = {}
    # The text before the first header (split index 0) is not a section.
    for piece in _UPDATE_SPLIT_RE.split(envelope.group(1))[1:]:
        if not piece.strip():
            continue
        header, sep, body = piece.partition("\n")
        if not sep:
            # A header with no following line carries no edits.
            continue
        path = header.strip()
        sections[path] = f"{UPDATE_FILE_PREFIX}{path}\n{body}"

    return sections
|
||||
|
||||
|
||||
def apply_update(
    text: str,
    action: PatchAction,
    path: str,
) -> str:
    """Apply the UPDATE chunks in *action* to *text* and return the result.

    Chunks are applied in ascending ``orig_index`` order. Deleted lines
    are verified against the original whitespace-insensitively. The
    result carries a trailing newline unless both input and output are
    empty.

    Raises:
        DiffError: On overlapping chunks or a deleted-lines mismatch.
    """
    source = text.splitlines()
    rebuilt: list[str] = []
    cursor = 0

    for chunk in sorted(action.chunks, key=lambda c: c.orig_index):
        if chunk.orig_index < cursor:
            msg = f"Overlapping chunks in {path} at line {chunk.orig_index}"
            raise DiffError(msg)

        # Copy untouched lines up to the edit point.
        rebuilt.extend(source[cursor : chunk.orig_index])

        # Verify the lines being deleted actually match the file
        # (ignoring surrounding whitespace).
        window = source[
            chunk.orig_index : chunk.orig_index + len(chunk.del_lines)
        ]
        expected = [_norm(s).strip() for s in chunk.del_lines]
        actual = [_norm(s).strip() for s in window]
        if expected != actual:
            msg = (
                f"Deleted lines mismatch in {path}"
                f" at line {chunk.orig_index}:"
                f" expected {expected!r},"
                f" got {actual!r}"
            )
            raise DiffError(msg)

        rebuilt.extend(chunk.ins_lines)
        cursor = chunk.orig_index + len(chunk.del_lines)

    rebuilt.extend(source[cursor:])

    joined = "\n".join(rebuilt)
    if joined or text:
        joined += "\n"
    return joined
|
||||
|
||||
|
||||
class V4ADiff(Diff):
    """Apply V4A (``*** Begin Patch`` / ``*** Update File:``) patches."""

    def run(self) -> dict[str, str]:
        """Apply the V4A patch to each source file it targets.

        Files without a matching UPDATE section are returned unchanged.
        """
        sections = extract_update_sections(self.content)
        updated: dict[str, str] = {}

        for path, code in self.source_code.items():
            section = sections.get(path)
            if section is None:
                updated[path] = code
                continue
            # Parse this file's section in isolation and apply its action.
            parsed = parse_patch_text(
                section.strip().splitlines(),
                {path: code},
            )
            action = parsed.actions.get(path)
            updated[path] = (
                apply_update(code, action, path) if action is not None else code
            )

        return updated
|
||||
331
packages/codeflash-api/tests/test_diff.py
Normal file
331
packages/codeflash-api/tests/test_diff.py
Normal file
|
|
@@ -0,0 +1,331 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from codeflash_api.diff._base import DiffError, DiffMethod
|
||||
from codeflash_api.diff._search_replace import (
|
||||
SearchAndReplaceDiff,
|
||||
SearchReplaceBlock,
|
||||
apply_patches,
|
||||
extract_patches,
|
||||
find_with_whitespace_flexibility,
|
||||
parse_diff,
|
||||
)
|
||||
from codeflash_api.diff._v4a import (
|
||||
V4ADiff,
|
||||
_norm,
|
||||
apply_update,
|
||||
extract_update_sections,
|
||||
find_context,
|
||||
find_context_core,
|
||||
)
|
||||
|
||||
|
||||
class TestDiffMethod:
    """Tests for the DiffMethod enum."""

    def test_values(self) -> None:
        """The three diff strategies expose their expected string values."""
        assert DiffMethod.NO_DIFF.value == "no_diff"
        assert DiffMethod.V4A.value == "v4a"
        assert DiffMethod.SEARCH_AND_REPLACE.value == "search_and_replace"
|
||||
|
||||
|
||||
class TestSearchReplaceParseDiff:
    """Tests for parse_diff."""

    def test_single_block(self) -> None:
        """A lone SEARCH/REPLACE block yields one parsed pair."""
        blocks = parse_diff(
            "<<<<<<< SEARCH\nold code\n=======\nnew code\n>>>>>>> REPLACE\n"
        )

        assert len(blocks) == 1
        assert blocks[0].search == "old code"
        assert blocks[0].replace == "new code"

    def test_multiple_blocks(self) -> None:
        """Back-to-back blocks are all collected."""
        blocks = parse_diff(
            "<<<<<<< SEARCH\nfoo\n=======\nbar\n>>>>>>> REPLACE\n"
            "<<<<<<< SEARCH\nbaz\n=======\nqux\n>>>>>>> REPLACE\n"
        )

        assert len(blocks) == 2

    def test_empty_search(self) -> None:
        """An empty search section (append mode) parses cleanly."""
        blocks = parse_diff(
            "<<<<<<< SEARCH\n=======\nnew stuff\n>>>>>>> REPLACE\n"
        )

        assert blocks[0].search == ""
        assert blocks[0].replace == "new stuff"

    def test_empty_input_raises(self) -> None:
        """An empty diff string is rejected."""
        with pytest.raises(ValueError, match="Empty"):
            parse_diff("")

    def test_missing_delimiter_raises(self) -> None:
        """A block without its '=======' separator is rejected."""
        with pytest.raises(ValueError, match="======="):
            parse_diff("<<<<<<< SEARCH\nfoo\n>>>>>>> REPLACE\n")

    def test_missing_replace_marker_raises(self) -> None:
        """A block without its closing REPLACE marker is rejected."""
        with pytest.raises(ValueError, match="REPLACE"):
            parse_diff("<<<<<<< SEARCH\nfoo\n=======\nbar\n")
|
||||
|
||||
|
||||
class TestFindWithWhitespaceFlexibility:
    """Tests for find_with_whitespace_flexibility."""

    def test_exact_match(self) -> None:
        """An exact substring is located at the right offset."""
        span = find_with_whitespace_flexibility(
            "hello world", "prefix hello world suffix"
        )

        assert span is not None
        assert span[0] == 7

    def test_flexible_whitespace(self) -> None:
        """A space in the needle matches a tab in the haystack."""
        span = find_with_whitespace_flexibility("hello world", "hello\tworld")

        assert span is not None

    def test_no_match(self) -> None:
        """A missing needle yields None."""
        assert find_with_whitespace_flexibility("xyz", "abc") is None
|
||||
|
||||
|
||||
class TestApplyPatches:
    """Tests for apply_patches."""

    def test_simple_replace(self) -> None:
        """An exactly-matching search is swapped for its replacement."""
        patched = apply_patches(
            "<<<<<<< SEARCH\nold\n=======\nnew\n>>>>>>> REPLACE\n",
            "before old after",
        )

        assert patched == "before new after"

    def test_append_mode(self) -> None:
        """An empty search appends the replacement to the content."""
        patched = apply_patches(
            "<<<<<<< SEARCH\n=======\n# appended\n>>>>>>> REPLACE\n",
            "existing",
        )

        assert patched == "existing# appended"

    def test_parse_error_returns_original(self) -> None:
        """Unparseable input leaves the content untouched."""
        assert apply_patches("not a diff", "original") == "original"
|
||||
|
||||
|
||||
class TestSearchAndReplaceDiff:
    """Tests for SearchAndReplaceDiff.run."""

    def test_single_file_patch(self) -> None:
        """A patch addressed to a file is applied to that file."""
        payload = (
            "<replace_in_file>"
            "<path>foo.py</path>"
            "<diff>\n<<<<<<< SEARCH\nold\n=======\nnew\n>>>>>>> REPLACE\n</diff>"
            "</replace_in_file>"
        )
        result = SearchAndReplaceDiff(
            content=payload,
            source_code={"foo.py": "old code"},
        ).run()

        assert result["foo.py"] == "new code"
|
||||
|
||||
|
||||
class TestNorm:
    """Tests for _norm."""

    def test_strips_cr(self) -> None:
        """Carriage returns are dropped."""
        assert _norm("hello\r") == "hello"

    def test_no_cr(self) -> None:
        """CR-free input passes through unchanged."""
        assert _norm("hello") == "hello"
|
||||
|
||||
|
||||
class TestFindContextCore:
    """Tests for find_context_core."""

    def test_exact_match(self) -> None:
        """A verbatim match costs zero fuzz."""
        idx, fuzz = find_context_core(["a", "b", "c", "d"], ["b", "c"], 0)

        assert idx == 1
        assert fuzz == 0

    def test_rstrip_match(self) -> None:
        """A trailing-whitespace-only difference costs fuzz 1."""
        idx, fuzz = find_context_core(["a ", "b "], ["a", "b"], 0)

        assert idx == 0
        assert fuzz == 1

    def test_strip_match(self) -> None:
        """Leading and trailing whitespace differences cost fuzz 100."""
        idx, fuzz = find_context_core([" a ", " b "], ["a", "b"], 0)

        assert idx == 0
        assert fuzz == 100

    def test_not_found(self) -> None:
        """An absent context reports index -1."""
        idx, _fuzz = find_context_core(["a", "b"], ["x", "y"], 0)

        assert idx == -1
|
||||
|
||||
|
||||
class TestFindContext:
    """Tests for find_context's EOF handling."""

    def test_eof_penalty(self) -> None:
        """An EOF context matching away from the end is heavily penalized."""
        idx, fuzz = find_context(["a", "b", "c", "d"], ["a", "b"], 0, eof=True)

        assert idx == 0
        assert fuzz >= 10_000
|
||||
|
||||
|
||||
class TestExtractUpdateSections:
    """Tests for extract_update_sections."""

    def test_single_file(self) -> None:
        """A single UPDATE section is keyed by its file path."""
        sections = extract_update_sections(
            "*** Begin Patch\n"
            "*** Update File: foo.py\n"
            " context\n"
            "-old\n"
            "+new\n"
            "*** End Patch\n"
        )

        assert "foo.py" in sections

    def test_no_patch_markers(self) -> None:
        """Content without the patch envelope yields an empty mapping."""
        assert extract_update_sections("no patch here") == {}
|
||||
|
||||
|
||||
class TestApplyUpdate:
    """Tests for apply_update."""

    def test_simple_replacement(self) -> None:
        """A single chunk swaps the targeted line."""
        from codeflash_api.diff._v4a import Chunk, PatchAction

        action = PatchAction(
            path="test.py",
            chunks=[
                Chunk(
                    orig_index=1,
                    del_lines=["old_line"],
                    ins_lines=["new_line"],
                )
            ],
        )

        assert apply_update("first\nold_line\nlast", action, "test.py") == (
            "first\nnew_line\nlast\n"
        )

    def test_mismatch_raises(self) -> None:
        """A chunk whose deleted lines disagree with the file is rejected."""
        from codeflash_api.diff._v4a import Chunk, PatchAction

        action = PatchAction(
            path="test.py",
            chunks=[
                Chunk(
                    orig_index=0,
                    del_lines=["expected"],
                    ins_lines=["new"],
                )
            ],
        )

        with pytest.raises(DiffError, match="mismatch"):
            apply_update("actual", action, "test.py")
|
||||
|
||||
|
||||
class TestV4ADiff:
    """Tests for V4ADiff.run."""

    def test_no_changes(self) -> None:
        """An empty patch returns the source unchanged."""
        result = V4ADiff(
            content="*** Begin Patch\n*** End Patch\n",
            source_code={"foo.py": "original"},
        ).run()

        assert result["foo.py"] == "original"
|
||||
|
|
@@ -81,6 +81,15 @@ ignore = [
|
|||
"S104", # binding to 0.0.0.0 is the dev default, overridden in production
|
||||
"S105", # dev placeholder secret_key, overridden via env var in production
|
||||
]
|
||||
"packages/codeflash-api/src/codeflash_api/diff/_v4a.py" = [
|
||||
"C901", # peek_next_section and _parse_update_file_sections faithfully ported from Django
|
||||
"PLR0912", # too many branches in faithfully ported _parse_update_file_sections
|
||||
"PLR0915", # too many statements in faithfully ported _parse_update_file_sections
|
||||
]
|
||||
"packages/codeflash-api/src/codeflash_api/llm/_client.py" = [
|
||||
"PLC0415", # conditional imports for event loop safety (clients recreated on loop change)
|
||||
"TRY301", # raise inside try is the intended pattern for cost-tracking on unsupported model type
|
||||
]
|
||||
"packages/codeflash-core/src/codeflash_core/_model.py" = [
|
||||
"C901", # humanize_runtime is complex but faithfully ported
|
||||
"PLR2004", # magic values in humanize_runtime thresholds
|
||||
|
|
|
|||
Loading…
Reference in a new issue