Add diff layer: SEARCH/REPLACE and V4A patch application

Faithfully ported from Django aiservice. V4A uses 3-tier fuzzy
context matching (exact/rstrip/strip) with EOF penalties and scope
markers. Per-file lint ignores for ported complexity.
This commit is contained in:
Kevin Turcios 2026-04-21 21:55:28 -05:00
parent 2acebdbf51
commit 5c6b82050a
5 changed files with 984 additions and 0 deletions

View file

@@ -0,0 +1,42 @@
from __future__ import annotations
import enum
from abc import abstractmethod
class DiffError(ValueError):
    """
    Any problem detected while parsing or applying a patch.

    Subclasses ValueError so generic callers that already catch
    ValueError keep working without change.
    """
class DiffMethod(enum.Enum):
    """Identifiers for the supported diff-application strategies."""

    NO_DIFF = "no_diff"
    V4A = "v4a"  # "*** Begin Patch" style patches (see V4ADiff)
    SEARCH_AND_REPLACE = "search_and_replace"  # conflict-marker blocks (see SearchAndReplaceDiff)
class Diff:
    """
    Base class for diff application strategies.

    Subclasses implement run() to transform ``source_code`` according to
    the diff payload carried in ``content``.
    """

    def __init__(
        self,
        content: str,
        source_code: dict[str, str],
        *,
        match_files_when_having_single_patch: bool = True,
    ) -> None:
        # Raw model output that carries the diff payload.
        self.content = content
        # Mapping of file path -> current file contents.
        self.source_code = source_code
        # Cache slot for subclasses that pre-extract per-file diffs.
        self.extracted_diff: dict[str, str] | None = None
        # When False, a lone patch may be paired with the lone source file
        # even if their paths disagree (see SearchAndReplaceDiff.run).
        flag = match_files_when_having_single_patch
        self.match_files_when_having_single_patch = flag

    @abstractmethod
    def run(self) -> dict[str, str]:
        """
        Apply the diff to the source code and return modified files.
        """

View file

@@ -0,0 +1,187 @@
from __future__ import annotations
import re
import attrs
from codeflash_api.diff._base import Diff
# Captures the inner payload of each <replace_in_file>...</replace_in_file>
# tag in the model response.
REPLACE_IN_FILE_TAGS_RE = re.compile(
    r"<replace_in_file>(.*?)</replace_in_file>",
    re.DOTALL | re.MULTILINE,
)
# Captures (path, diff) pairs from one or more complete
# <replace_in_file><path>...</path><diff>...</diff></replace_in_file> tags.
MULTI_REPLACE_IN_FILE_RE = re.compile(
    r"<replace_in_file>\s*<path>(.*?)</path>"
    r"\s*<diff>([\s\S]*?)</diff>\s*</replace_in_file>",
    re.MULTILINE,
)
# Conflict-style markers delimiting the two halves of a SEARCH/REPLACE block.
SEARCH_MARKER = "<<<<<<< SEARCH"
DELIMITER = "======="
REPLACE_MARKER = ">>>>>>> REPLACE"
@attrs.frozen
class SearchReplaceBlock:
    """
    One SEARCH/REPLACE pair.

    ``search`` is the (rstripped) text to locate in the file; ``replace``
    is the (rstripped) text that takes its place. An empty ``search``
    means "append ``replace``" (see apply_patches).
    """

    search: str
    replace: str
def parse_diff(diff: str) -> list[SearchReplaceBlock]:
    """
    Parse SEARCH/REPLACE blocks from *diff*.

    Scans line by line for the SEARCH / ======= / REPLACE markers
    (compared with trailing whitespace ignored) and collects the text
    between them, rstripped, into SearchReplaceBlock objects.

    Raises ValueError for empty input, a block missing its '======='
    or '>>>>>>> REPLACE' marker, or input with no blocks at all.
    """
    if not diff or not diff.strip():
        msg = "Empty or invalid diff string provided"
        raise ValueError(msg)
    raw_lines = diff.splitlines(keepends=True)
    parsed: list[SearchReplaceBlock] = []
    open_idx: int | None = None   # line index of the last SEARCH marker
    split_idx: int | None = None  # line index of ======= for the open block
    for pos, raw in enumerate(raw_lines):
        marker = raw.rstrip()
        if marker == SEARCH_MARKER:
            open_idx = pos
        elif marker == DELIMITER and open_idx is not None:
            split_idx = pos
        elif marker == REPLACE_MARKER and split_idx is not None:
            search_text = "".join(raw_lines[open_idx + 1 : split_idx]).rstrip()
            replace_text = "".join(raw_lines[split_idx + 1 : pos]).rstrip()
            parsed.append(
                SearchReplaceBlock(search=search_text, replace=replace_text)
            )
            open_idx = None
            split_idx = None
    if open_idx is not None and split_idx is None:
        msg = "Invalid diff format: Missing '=======' marker"
        raise ValueError(msg)
    if split_idx is not None:
        msg = "Invalid diff format: Missing '>>>>>>> REPLACE' marker"
        raise ValueError(msg)
    if not parsed:
        msg = "No valid SEARCH/REPLACE blocks found in the diff"
        raise ValueError(msg)
    return parsed
def group_diff_patches_by_path(
    replace_tags_str: str,
) -> dict[str, str]:
    """
    Group multiple diffs by file path from XML tag content.

    Repeated paths have their diff bodies concatenated with a newline,
    preserving first-seen order.
    """
    grouped: dict[str, str] = {}
    for raw_path, body in MULTI_REPLACE_IN_FILE_RE.findall(replace_tags_str):
        key = raw_path.strip()
        if key in grouped:
            grouped[key] = grouped[key] + "\n" + body
        else:
            grouped[key] = body
    return grouped
def extract_patches(content: str) -> dict[str, str]:
    """
    Extract all SEARCH/REPLACE patches from *content*.

    Returns a mapping of file path -> diff text (diffs for the same path
    concatenated), or an empty dict when *content* holds no
    <replace_in_file> tags.
    """
    matches = REPLACE_IN_FILE_TAGS_RE.findall(content)
    if not matches:
        return {}
    # Re-wrap EVERY captured tag body before grouping. Previously only
    # matches[0] was wrapped, which silently dropped any additional
    # <replace_in_file> tags present in the response.
    wrapped = "".join(
        f"<replace_in_file>{inner}</replace_in_file>" for inner in matches
    )
    return group_diff_patches_by_path(wrapped)
def find_with_whitespace_flexibility(
search: str, content: str
) -> tuple[int, int] | None:
"""
Match *search* in *content* with flexible whitespace.
"""
parts = re.split(r"(\s+)", search)
pattern_parts: list[str] = []
for part in parts:
if not part:
continue
if re.match(r"^\s+$", part):
pattern_parts.append(r"\s+")
else:
pattern_parts.append(re.escape(part))
pattern = "".join(pattern_parts)
match = re.search(pattern, content, re.MULTILINE | re.DOTALL)
if match:
return (match.start(), match.end())
return None
def apply_patches(diff_str: str, content: str) -> str:
    """
    Parse and apply SEARCH/REPLACE blocks to *content*.

    Best-effort: an unparseable diff returns *content* unchanged, an
    empty SEARCH appends its REPLACE text, and a block whose SEARCH text
    cannot be located (even with flexible whitespace) is skipped.
    """
    try:
        parsed = parse_diff(diff_str)
    except ValueError:
        # Unparseable diff: leave the content untouched.
        return content
    for blk in parsed:
        if not blk.search and blk.replace:
            # Append mode: empty SEARCH adds the text at the end.
            content += blk.replace
        elif blk.search:
            pos = content.find(blk.search)
            if pos >= 0:
                content = content.replace(blk.search, blk.replace, 1)
                continue
            # Exact text not present: retry with whitespace flexibility.
            located = find_with_whitespace_flexibility(blk.search, content)
            if located is not None:
                a, b = located
                content = content[:a] + blk.replace + content[b:]
    return content
class SearchAndReplaceDiff(Diff):
    """
    Apply SEARCH/REPLACE style diffs.
    """

    def run(self) -> dict[str, str]:
        """
        Apply patches to source files and return {path: new_content}.

        Files without a matching patch pass through unchanged. When the
        single-patch matching override is disabled and there is exactly
        one source file and one patch, they are paired regardless of
        their paths.
        """
        patches = extract_patches(self.content)
        single_forced = (
            not self.match_files_when_having_single_patch
            and len(self.source_code) == 1
            and len(patches) == 1
        )
        if single_forced:
            # One file + one patch: pair them up without comparing paths.
            ((path, code),) = self.source_code.items()
            (diff_body,) = patches.values()
            return {path: apply_patches(diff_body, code)}
        out: dict[str, str] = {}
        for path, code in self.source_code.items():
            if path in patches:
                out[path] = apply_patches(patches[path], code)
            else:
                out[path] = code
        return out

View file

@@ -0,0 +1,415 @@
from __future__ import annotations
import re
import attrs
from codeflash_api.diff._base import Diff, DiffError
# Header line that introduces each per-file section of a V4A patch.
UPDATE_FILE_PREFIX = "*** Update File: "
@attrs.define
class Chunk:
    """
    A single add/delete change at a specific line index.
    """

    # Line index where the chunk applies: starts relative to its
    # section's context block, shifted to an absolute file index once
    # the context is located (see _parse_update_file_sections).
    orig_index: int = -1
    # Lines removed from the original file (without their "-" prefix).
    del_lines: list[str] = attrs.Factory(list)
    # Lines inserted in their place (without their "+" prefix).
    ins_lines: list[str] = attrs.Factory(list)
@attrs.define
class PatchAction:
    """
    All chunks for one file.
    """

    # Target file path; filled in by parse_patch_text.
    path: str = ""
    chunks: list[Chunk] = attrs.Factory(list)
@attrs.define
class Patch:
    """
    Parsed patch with actions per file.
    """

    actions: dict[str, PatchAction] = attrs.Factory(dict)
    # Accumulated match fuzziness across all sections (0 == exact).
    fuzz: int = 0
def _norm(line: str) -> str:
"""
Strip CR for LF/CRLF compatibility.
"""
return line.replace("\r", "")
def find_context_core(
    lines: list[str],
    context: list[str],
    start: int,
) -> tuple[int, int]:
    """
    Find *context* in *lines* starting at *start* with 3-tier fuzz.

    Tiers are tried in order and return (index, fuzz):
    exact match -> fuzz 0, trailing-whitespace-insensitive -> fuzz 1,
    fully whitespace-insensitive -> fuzz 100. Returns (-1, 0) when the
    context is not found at all.
    """
    width = len(context)
    stop = len(lines) - width + 1
    tiers = (
        (lambda s: s, 0),     # exact
        (str.rstrip, 1),      # ignore trailing whitespace
        (str.strip, 100),     # ignore leading + trailing whitespace
    )
    for canon, fuzz in tiers:
        wanted = [canon(s) for s in context]
        for i in range(start, stop):
            if [canon(s) for s in lines[i : i + width]] == wanted:
                return (i, fuzz)
    return (-1, 0)
def find_context(
    lines: list[str],
    context: list[str],
    start: int,
    *,
    eof: bool,
) -> tuple[int, int]:
    """
    Find *context* with special EOF marker handling.

    With eof=True the context is first sought anchored at the very end
    of the file; if that fails, a normal scan from *start* is used with
    a 10_000 fuzz penalty. With eof=False this is a plain
    find_context_core call.
    """
    if not eof:
        return find_context_core(lines, context, start)
    # Prefer a match anchored at the end of the file.
    anchored = find_context_core(lines, context, len(lines) - len(context))
    if anchored[0] >= 0:
        return anchored
    # Fall back to an ordinary scan, heavily penalised.
    fallback_idx, fallback_fuzz = find_context_core(lines, context, start)
    if fallback_idx >= 0:
        return (fallback_idx, fallback_fuzz + 10_000)
    return (-1, 0)
def peek_next_section(
    lines: list[str], index: int
) -> tuple[list[str], list[Chunk], int, bool]:
    """
    Parse one context/change section from an UPDATE block.

    Walks *lines* from *index* until the next ``@@`` scope marker or an
    ``*** End of File`` sentinel, classifying each line by prefix:
    space/empty -> context, ``-`` -> deletion, ``+`` -> insertion.
    Consecutive -/+ runs are folded into Chunk objects whose
    ``orig_index`` is relative to the start of the returned context.

    Returns (context_lines, chunks, next_index, is_eof); context_lines
    reflects the ORIGINAL file text (context plus deleted lines,
    insertions excluded) and next_index is the first unconsumed line.

    Raises DiffError for a line with an unrecognised prefix.
    """
    context_lines: list[str] = []
    chunks: list[Chunk] = []
    del_lines: list[str] = []
    ins_lines: list[str] = []
    mode = "keep"  # current run type: "keep", "delete", or "add"
    is_eof = False
    i = index
    while i < len(lines):
        line = _norm(lines[i])
        if line.startswith("@@"):
            # Next scope marker: this section is finished.
            break
        if line.startswith("*** "):
            tag = line.strip()
            if tag == "*** End of File":
                is_eof = True
                i += 1
                break
            # NOTE(review): any other "*** " line falls through to the
            # prefix checks below and raises DiffError; callers appear to
            # pre-split the patch per file, so this should not occur —
            # confirm before relying on multi-file input here.
        if line.startswith(" ") or line == "":
            if mode in ("add", "delete"):
                # A context line closes the pending -/+ run into a chunk;
                # context_lines already includes the deleted lines, so this
                # index points at where the deletions start.
                chunks.append(
                    Chunk(
                        orig_index=len(context_lines) - len(del_lines),
                        del_lines=del_lines,
                        ins_lines=ins_lines,
                    )
                )
                del_lines = []
                ins_lines = []
                mode = "keep"
            ctx = line[1:] if line.startswith(" ") else ""
            context_lines.append(ctx)
            i += 1
        elif line.startswith("-"):
            if mode == "add":
                # "-" after "+" starts a new chunk: within one chunk,
                # deletions always precede insertions.
                chunks.append(
                    Chunk(
                        orig_index=len(context_lines) - len(del_lines),
                        del_lines=del_lines,
                        ins_lines=ins_lines,
                    )
                )
                del_lines = []
                ins_lines = []
            mode = "delete"
            stripped = line[1:]
            del_lines.append(stripped)
            # Deleted lines belong to the original file, so they also
            # serve as context for locating this section.
            context_lines.append(stripped)
            i += 1
        elif line.startswith("+"):
            mode = "add"
            ins_lines.append(line[1:])
            i += 1
        else:
            msg = f"Invalid patch line found: {line!r}"
            raise DiffError(msg)
    if mode in ("add", "delete"):
        # Flush a trailing -/+ run not closed by a context line.
        chunks.append(
            Chunk(
                orig_index=len(context_lines) - len(del_lines),
                del_lines=del_lines,
                ins_lines=ins_lines,
            )
        )
    return (context_lines, chunks, i, is_eof)
def _parse_update_file_sections(
    lines: list[str],
    index: int,
    file_content: str,
) -> tuple[PatchAction, int, int]:
    """
    Parse all @@ scope markers and sections for one file.

    *lines* is the full patch line list, *index* points just past the
    "*** Update File:" header, and *file_content* is the current text of
    the target file, used to resolve scope markers and anchor context
    blocks to absolute line indices.

    Returns (action, next_index, total_fuzz) where next_index is the
    first unconsumed line and total_fuzz accumulates match fuzziness.

    Raises DiffError when a scope marker or a context block cannot be
    located in the file.
    """
    action = PatchAction()
    total_fuzz = 0
    file_lines = [_norm(ln) for ln in file_content.splitlines()]
    current_file_idx = 0  # search cursor into file_lines
    i = index
    while i < len(lines):
        line = _norm(lines[i])
        if not line.strip():
            i += 1
            continue
        if line.startswith("*** ") and not line.startswith("*** End of File"):
            # Header of the next file: stop consuming.
            break
        if line.startswith("@@"):
            # Collect a run of consecutive @@ scope markers.
            scope_lines: list[str] = []
            while i < len(lines) and _norm(lines[i]).startswith("@@"):
                scope_lines.append(_norm(lines[i]).strip().strip("@").strip())
                i += 1
            found = False
            # First pass: stripped file line must equal the scope text.
            for scope in scope_lines:
                for j in range(current_file_idx, len(file_lines)):
                    if file_lines[j].strip() == scope:
                        current_file_idx = j
                        found = True
                        break
                if found:
                    break
            if not found:
                # Fallback pass with a fuzz penalty. NOTE(review): scopes
                # were already stripped above, so scope.strip() == scope
                # and this pass repeats the first — kept as in the
                # original port; confirm before simplifying.
                for scope in scope_lines:
                    for j in range(current_file_idx, len(file_lines)):
                        if file_lines[j].strip() == scope.strip():
                            current_file_idx = j
                            total_fuzz += 1
                            found = True
                            break
                    if found:
                        break
            if not found:
                msg = f"Could not find scope: {scope_lines!r}"
                raise DiffError(msg)
            continue
        context_block, chunks, next_i, is_eof = peek_next_section(lines, i)
        i = next_i
        if not context_block and not chunks:
            continue
        if context_block:
            # Anchor this section's chunks at the located context.
            found_idx, fuzz = find_context(
                file_lines,
                context_block,
                current_file_idx,
                eof=is_eof,
            )
            if found_idx < 0:
                msg = (
                    "Could not find context block"
                    f" starting with: {context_block[:3]!r}"
                )
                raise DiffError(msg)
            total_fuzz += fuzz
            for chunk in chunks:
                chunk.orig_index += found_idx
            current_file_idx = found_idx + len(context_block)
        else:
            # No context: chunks apply at the current cursor position.
            for chunk in chunks:
                chunk.orig_index += current_file_idx
        action.chunks.extend(chunks)
    return (action, i, total_fuzz)
def parse_patch_text(
    lines: list[str],
    current_files: dict[str, str],
) -> Patch:
    """
    Parse full patch content into a Patch object.

    *current_files* maps path -> current contents and is used both to
    validate that each "*** Update File:" header names a known file and
    to resolve its sections against the real text.

    Raises DiffError for an unknown path or an unrecognised line.
    """
    patch = Patch()
    header = UPDATE_FILE_PREFIX.strip()
    i = 0
    while i < len(lines):
        raw = _norm(lines[i])
        stripped = raw.strip()
        if not stripped:
            i += 1
            continue
        if not stripped.startswith(header):
            msg = f"Unknown or misplaced line: {raw!r}"
            raise DiffError(msg)
        path = stripped[len(header) :].strip()
        if path not in current_files:
            msg = f"File not found: {path!r}"
            raise DiffError(msg)
        action, i, fuzz = _parse_update_file_sections(
            lines, i + 1, current_files[path]
        )
        action.path = path
        existing = patch.actions.get(path)
        if existing is None:
            patch.actions[path] = action
        else:
            # Repeated headers for one file merge their chunks.
            existing.chunks.extend(action.chunks)
        patch.fuzz += fuzz
    return patch
# Captures the body between "*** Begin Patch" and "*** End Patch".
_PATCH_BODY_RE = re.compile(
    r"\*\*\* Begin Patch\s*(.*?)\s*\*\*\* End Patch",
    re.DOTALL,
)
# Splits the patch body at each per-file "*** Update File:" header.
_UPDATE_SPLIT_RE = re.compile(r"^\*\*\* Update File:\s*", re.MULTILINE)
def extract_update_sections(
    content: str,
) -> dict[str, str]:
    """
    Extract UPDATE sections from patch content.

    Returns {path: section_text} where each section is re-prefixed with
    its "*** Update File:" header so it can be parsed on its own, or an
    empty dict when no Begin/End Patch markers are present.
    """
    body_match = _PATCH_BODY_RE.search(content)
    if body_match is None:
        return {}
    sections: dict[str, str] = {}
    for segment in _UPDATE_SPLIT_RE.split(body_match.group(1))[1:]:
        if not segment.strip():
            continue
        newline_at = segment.find("\n")
        if newline_at < 0:
            # A header with no body has nothing to apply.
            continue
        path = segment[:newline_at].strip()
        rest = segment[newline_at + 1 :]
        sections[path] = f"{UPDATE_FILE_PREFIX}{path}\n{rest}"
    return sections
def apply_update(
    text: str,
    action: PatchAction,
    path: str,
) -> str:
    """
    Apply UPDATE chunks to file content.

    Chunks are applied in ascending orig_index order; each deletion is
    verified against the original lines (CR-stripped, whitespace-
    insensitive) before being replaced. *path* is used in error
    messages only.

    Raises DiffError on overlapping chunks or a deleted-line mismatch.
    """
    src = text.splitlines()
    out: list[str] = []
    cursor = 0  # next original line not yet copied to output
    for chunk in sorted(action.chunks, key=lambda c: c.orig_index):
        if chunk.orig_index < cursor:
            msg = f"Overlapping chunks in {path} at line {chunk.orig_index}"
            raise DiffError(msg)
        # Copy untouched lines up to the chunk.
        out.extend(src[cursor : chunk.orig_index])
        stop = chunk.orig_index + len(chunk.del_lines)
        want = [_norm(s).strip() for s in chunk.del_lines]
        have = [_norm(s).strip() for s in src[chunk.orig_index : stop]]
        if want != have:
            msg = (
                f"Deleted lines mismatch in {path}"
                f" at line {chunk.orig_index}:"
                f" expected {want!r},"
                f" got {have!r}"
            )
            raise DiffError(msg)
        out.extend(chunk.ins_lines)
        cursor = stop
    out.extend(src[cursor:])
    joined = "\n".join(out)
    # Preserve a trailing newline whenever there is any content at all.
    if joined or text:
        joined += "\n"
    return joined
class V4ADiff(Diff):
    """
    Apply V4A unified diff patches.
    """

    def run(self) -> dict[str, str]:
        """
        Apply V4A patches to source files and return {path: new_content}.

        Files without a matching UPDATE section, or whose section yields
        no action, pass through unchanged.
        """
        sections = extract_update_sections(self.content)
        out: dict[str, str] = {}
        for path, code in self.source_code.items():
            section = sections.get(path)
            if section is None:
                out[path] = code
                continue
            # Each file's section is parsed in isolation.
            patch = parse_patch_text(
                section.strip().splitlines(),
                {path: code},
            )
            target = patch.actions.get(path)
            if target is None:
                out[path] = code
            else:
                out[path] = apply_update(code, target, path)
        return out

View file

@@ -0,0 +1,331 @@
from __future__ import annotations
import pytest
from codeflash_api.diff._base import DiffError, DiffMethod
from codeflash_api.diff._search_replace import (
SearchAndReplaceDiff,
SearchReplaceBlock,
apply_patches,
extract_patches,
find_with_whitespace_flexibility,
parse_diff,
)
from codeflash_api.diff._v4a import (
V4ADiff,
_norm,
apply_update,
extract_update_sections,
find_context,
find_context_core,
)
class TestDiffMethod:
    """Tests for DiffMethod enum."""

    def test_values(self) -> None:
        """
        All three diff methods exist.
        """
        expected = {
            DiffMethod.NO_DIFF: "no_diff",
            DiffMethod.V4A: "v4a",
            DiffMethod.SEARCH_AND_REPLACE: "search_and_replace",
        }
        for member, value in expected.items():
            assert member.value == value
class TestSearchReplaceParseDiff:
    """Tests for parse_diff."""

    def test_single_block(self) -> None:
        """
        Single SEARCH/REPLACE block is parsed correctly.
        """
        payload = "<<<<<<< SEARCH\nold code\n=======\nnew code\n>>>>>>> REPLACE\n"
        parsed = parse_diff(payload)
        assert len(parsed) == 1
        assert parsed[0].search == "old code"
        assert parsed[0].replace == "new code"

    def test_multiple_blocks(self) -> None:
        """
        Multiple SEARCH/REPLACE blocks are all parsed.
        """
        payload = (
            "<<<<<<< SEARCH\nfoo\n=======\nbar\n>>>>>>> REPLACE\n"
            "<<<<<<< SEARCH\nbaz\n=======\nqux\n>>>>>>> REPLACE\n"
        )
        assert len(parse_diff(payload)) == 2

    def test_empty_search(self) -> None:
        """
        Empty search content (append mode) is valid.
        """
        payload = "<<<<<<< SEARCH\n=======\nnew stuff\n>>>>>>> REPLACE\n"
        first = parse_diff(payload)[0]
        assert first.search == ""
        assert first.replace == "new stuff"

    def test_empty_input_raises(self) -> None:
        """
        Empty string raises ValueError.
        """
        with pytest.raises(ValueError, match="Empty"):
            parse_diff("")

    def test_missing_delimiter_raises(self) -> None:
        """
        Missing ======= marker raises ValueError.
        """
        with pytest.raises(ValueError, match="======="):
            parse_diff("<<<<<<< SEARCH\nfoo\n>>>>>>> REPLACE\n")

    def test_missing_replace_marker_raises(self) -> None:
        """
        Missing >>>>>>> REPLACE marker raises ValueError.
        """
        with pytest.raises(ValueError, match="REPLACE"):
            parse_diff("<<<<<<< SEARCH\nfoo\n=======\nbar\n")
class TestFindWithWhitespaceFlexibility:
    """Tests for find_with_whitespace_flexibility."""

    def test_exact_match(self) -> None:
        """
        Exact content is found.
        """
        span = find_with_whitespace_flexibility(
            "hello world", "prefix hello world suffix"
        )
        assert span is not None
        assert span[0] == 7

    def test_flexible_whitespace(self) -> None:
        """
        Tabs match spaces.
        """
        span = find_with_whitespace_flexibility("hello world", "hello\tworld")
        assert span is not None

    def test_no_match(self) -> None:
        """
        Non-existent content returns None.
        """
        assert find_with_whitespace_flexibility("xyz", "abc") is None
class TestApplyPatches:
    """Tests for apply_patches."""

    def test_simple_replace(self) -> None:
        """
        Exact match is replaced.
        """
        payload = "<<<<<<< SEARCH\nold\n=======\nnew\n>>>>>>> REPLACE\n"
        assert apply_patches(payload, "before old after") == "before new after"

    def test_append_mode(self) -> None:
        """
        Empty search appends to content.
        """
        payload = "<<<<<<< SEARCH\n=======\n# appended\n>>>>>>> REPLACE\n"
        assert apply_patches(payload, "existing") == "existing# appended"

    def test_parse_error_returns_original(self) -> None:
        """
        Invalid diff returns content unchanged.
        """
        assert apply_patches("not a diff", "original") == "original"
class TestSearchAndReplaceDiff:
    """Tests for SearchAndReplaceDiff.run."""

    def test_single_file_patch(self) -> None:
        """
        Patch applied to matching file.
        """
        payload = (
            "<replace_in_file>"
            "<path>foo.py</path>"
            "<diff>\n<<<<<<< SEARCH\nold\n=======\nnew\n>>>>>>> REPLACE\n</diff>"
            "</replace_in_file>"
        )
        applied = SearchAndReplaceDiff(
            content=payload,
            source_code={"foo.py": "old code"},
        ).run()
        assert applied["foo.py"] == "new code"
class TestNorm:
    """Tests for _norm."""

    def test_strips_cr(self) -> None:
        """
        Carriage returns are removed.
        """
        assert _norm("hello\r") == "hello"

    def test_no_cr(self) -> None:
        """
        Lines without CR are unchanged.
        """
        assert _norm("hello") == "hello"
class TestFindContextCore:
    """Tests for find_context_core."""

    def test_exact_match(self) -> None:
        """
        Exact match returns fuzz=0.
        """
        haystack = ["a", "b", "c", "d"]
        assert find_context_core(haystack, ["b", "c"], 0) == (1, 0)

    def test_rstrip_match(self) -> None:
        """
        Trailing whitespace match returns fuzz=1.
        """
        assert find_context_core(["a ", "b "], ["a", "b"], 0) == (0, 1)

    def test_strip_match(self) -> None:
        """
        Leading+trailing whitespace match returns fuzz=100.
        """
        assert find_context_core([" a ", " b "], ["a", "b"], 0) == (0, 100)

    def test_not_found(self) -> None:
        """
        Missing context returns -1.
        """
        assert find_context_core(["a", "b"], ["x", "y"], 0)[0] == -1
class TestFindContext:
    """Tests for find_context with EOF handling."""

    def test_eof_penalty(self) -> None:
        """
        EOF match not at end gets +10000 fuzz penalty.
        """
        index, fuzz = find_context(
            ["a", "b", "c", "d"], ["a", "b"], 0, eof=True
        )
        assert index == 0
        assert fuzz >= 10_000
class TestExtractUpdateSections:
    """Tests for extract_update_sections."""

    def test_single_file(self) -> None:
        """
        Single UPDATE section is extracted.
        """
        patch_text = (
            "*** Begin Patch\n"
            "*** Update File: foo.py\n"
            " context\n"
            "-old\n"
            "+new\n"
            "*** End Patch\n"
        )
        assert "foo.py" in extract_update_sections(patch_text)

    def test_no_patch_markers(self) -> None:
        """
        Missing markers return empty dict.
        """
        assert extract_update_sections("no patch here") == {}
class TestApplyUpdate:
    """Tests for apply_update."""

    def test_simple_replacement(self) -> None:
        """
        Single chunk replaces lines correctly.
        """
        from codeflash_api.diff._v4a import Chunk, PatchAction

        change = PatchAction(
            path="test.py",
            chunks=[
                Chunk(
                    orig_index=1,
                    del_lines=["old_line"],
                    ins_lines=["new_line"],
                )
            ],
        )
        applied = apply_update("first\nold_line\nlast", change, "test.py")
        assert applied == "first\nnew_line\nlast\n"

    def test_mismatch_raises(self) -> None:
        """
        Deleted line mismatch raises DiffError.
        """
        from codeflash_api.diff._v4a import Chunk, PatchAction

        change = PatchAction(
            path="test.py",
            chunks=[
                Chunk(
                    orig_index=0,
                    del_lines=["expected"],
                    ins_lines=["new"],
                )
            ],
        )
        with pytest.raises(DiffError, match="mismatch"):
            apply_update("actual", change, "test.py")
class TestV4ADiff:
    """Tests for V4ADiff.run."""

    def test_no_changes(self) -> None:
        """
        No matching sections returns original code.
        """
        empty_patch = V4ADiff(
            content="*** Begin Patch\n*** End Patch\n",
            source_code={"foo.py": "original"},
        )
        assert empty_patch.run()["foo.py"] == "original"

View file

@@ -81,6 +81,15 @@ ignore = [
"S104", # binding to 0.0.0.0 is the dev default, overridden in production
"S105", # dev placeholder secret_key, overridden via env var in production
]
"packages/codeflash-api/src/codeflash_api/diff/_v4a.py" = [
"C901", # peek_next_section and _parse_update_file_sections faithfully ported from Django
"PLR0912", # too many branches in faithfully ported _parse_update_file_sections
"PLR0915", # too many statements in faithfully ported _parse_update_file_sections
]
"packages/codeflash-api/src/codeflash_api/llm/_client.py" = [
"PLC0415", # conditional imports for event loop safety (clients recreated on loop change)
"TRY301", # raise inside try is the intended pattern for cost-tracking on unsupported model type
]
"packages/codeflash-core/src/codeflash_core/_model.py" = [
"C901", # humanize_runtime is complex but faithfully ported
"PLR2004", # magic values in humanize_runtime thresholds