codeflash-internal/experiments/rl_env/commit_resolver.py
2026-04-16 16:31:25 -07:00

636 lines
24 KiB
Python

"""Resolve pre-optimization commits and extract original/optimized code pairs.
For each catalog entry, determines:
1. The exact commit where the function exists in its pre-optimization state
2. The original function code at that commit
3. The optimized function code (from DB or git)
4. A git patch that transforms original → optimized
Usage:
python rl_env/commit_resolver.py \
--catalog optimization_catalog.json \
--inference-repo /path/to/inference \
--output resolved_catalog.json
"""
from __future__ import annotations

import argparse
import ast
import json
import logging
import os
import subprocess
import textwrap
from dataclasses import dataclass
from pathlib import Path

import psycopg2
logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")
log = logging.getLogger(__name__)

# Read-only Postgres connection settings. Each value can be overridden via the
# environment so the script can run against staging copies of the database.
# SECURITY: an earlier revision committed the readonly_user password directly
# in this file; that credential must be considered leaked and rotated. The
# password now comes exclusively from CODEFLASH_DB_PASSWORD.
DB_HOST = os.environ.get("CODEFLASH_DB_HOST", "codeflash-pgsql-db-prod.postgres.database.azure.com")
DB_NAME = os.environ.get("CODEFLASH_DB_NAME", "postgres")
DB_USER = os.environ.get("CODEFLASH_DB_USER", "readonly_user")
DB_PASSWORD = os.environ.get("CODEFLASH_DB_PASSWORD", "")
@dataclass
class ResolvedTask:
    """A fully resolved optimization task.

    Pairs a pinned repository state with the original/optimized code versions
    needed to construct a training example. Produced by resolve_entry().
    """

    trace_id: str  # DB trace identifier for the optimization run
    function_name: str  # dotted name, possibly "ClassName.method"
    file_path: str  # file containing the function, relative to the repo root
    speedup_x: float | None  # measured speedup multiplier from the catalog, if any
    review_quality: str | None  # human review label from the catalog, if any
    pre_optimization_commit: str  # repo state before optimization ("HEAD" when unpinned)
    original_function_code: str  # the function before optimization
    optimized_function_code: str  # the function after optimization (single file, fences stripped)
    optimized_raw_markdown: str  # raw DB format with ```python:filepath fences (for multi-file)
    optimized_function_only: str | None  # just the function extracted from optimized code
    original_file_content: str  # full file at pre-optimization commit
    patch: str | None  # git diff for solve.sh; None when no commit pair is known
    source: str  # catalog source tag, passed through from the entry
    resolution_method: str  # "git_parent", "db_code_match", or "db_code_only"
def get_db_connection() -> psycopg2.extensions.connection:
    """Open a TLS-required connection to the Codeflash Postgres database."""
    connect_kwargs = {
        "host": DB_HOST,
        "dbname": DB_NAME,
        "user": DB_USER,
        "password": DB_PASSWORD,
        "sslmode": "require",
    }
    return psycopg2.connect(**connect_kwargs)
def git_show_file(repo: Path, commit: str, file_path: str) -> str | None:
    """Return the content of *file_path* at *commit*, or None if unavailable."""
    proc = subprocess.run(
        ["git", "show", f"{commit}:{file_path}"],
        cwd=repo,
        capture_output=True,
        text=True,
        encoding="utf-8",
    )
    # A nonzero exit means the commit or path does not exist in this repo.
    return proc.stdout if proc.returncode == 0 else None
def git_diff(repo: Path, commit_a: str, commit_b: str, file_path: str) -> str | None:
    """Return the diff of *file_path* between two commits, or None on failure."""
    proc = subprocess.run(
        ["git", "diff", commit_a, commit_b, "--", file_path],
        cwd=repo,
        capture_output=True,
        text=True,
        encoding="utf-8",
    )
    if proc.returncode != 0:
        return None
    return proc.stdout
def extract_function_from_source(source: str, function_name: str) -> str | None:
"""Extract a function/method definition from source code using AST.
Handles both top-level functions and class methods.
"""
short_name = function_name.rsplit(".", maxsplit=1)[-1]
class_name = function_name.split(".")[-2] if "." in function_name else None
try:
tree = ast.parse(source)
except SyntaxError:
return None
lines = source.splitlines(keepends=True)
for node in ast.walk(tree):
if isinstance(node, ast.FunctionDef | ast.AsyncFunctionDef):
if node.name == short_name:
# If class_name specified, verify parent is the right class
if class_name:
parent_class = find_parent_class(tree, node)
if parent_class and parent_class.name != class_name:
continue
start = node.lineno - 1
end = node.end_lineno if node.end_lineno else start + 1
func_lines = lines[start:end]
return "".join(func_lines)
return None
def find_parent_class(tree: ast.Module, target_node: ast.AST) -> ast.ClassDef | None:
    """Return the ClassDef whose subtree contains *target_node* (identity match)."""
    for candidate in ast.walk(tree):
        if not isinstance(candidate, ast.ClassDef):
            continue
        # Identity (not equality) comparison: we want the exact node instance.
        if any(descendant is target_node for descendant in ast.walk(candidate)):
            return candidate
    return None
def get_optimized_code_raw_from_db(conn: psycopg2.extensions.connection, trace_id: str, best_opt_id: str | None) -> str | None:
    """Return the raw optimized code (markdown file fences intact) for the best candidate.

    Selection order: the caller-supplied candidate id, then the id recorded in
    the row's metadata, and finally the correct candidate with the highest
    recorded speedup.
    """
    query = (
        "SELECT f.optimizations_post, f.is_correct, f.speedup_ratio, "
        "f.metadata->>'best_optimization_id' as best_opt_id "
        "FROM optimization_features f WHERE f.trace_id = %s"
    )
    with conn.cursor() as cur:
        cur.execute(query, (trace_id,))
        row = cur.fetchone()
    if row is None:
        return None
    optimizations_post, is_correct, speedup_ratio, db_best_id = row
    if not optimizations_post:
        return None
    preferred = best_opt_id or db_best_id
    if preferred and preferred in optimizations_post:
        return optimizations_post[preferred]
    if not is_correct or not speedup_ratio:
        return None
    # Fall back to the correct candidate with the highest speedup. Only
    # speedups above -1.0 qualify, matching the original sentinel threshold;
    # max() keeps the first maximal element, preserving dict-order tie-breaks.
    usable = [
        opt_id
        for opt_id, correct in is_correct.items()
        if correct
        and opt_id in optimizations_post
        and speedup_ratio.get(opt_id) is not None
        and speedup_ratio[opt_id] > -1.0
    ]
    if not usable:
        return None
    winner = max(usable, key=lambda opt_id: speedup_ratio[opt_id])
    return optimizations_post[winner]
def get_optimized_code_from_db(
    conn: psycopg2.extensions.connection,
    trace_id: str,
    best_opt_id: str | None,
    function_name: str = "",
    file_path: str = "",
) -> str | None:
    """Return the optimized code for the best candidate from optimization_features.

    When *function_name* is given, multi-block markdown is narrowed to the
    fenced block containing that function; otherwise the first block is used
    with its fences stripped. Candidate selection mirrors
    get_optimized_code_raw_from_db.
    """
    with conn.cursor() as cur:
        cur.execute(
            """
            SELECT f.optimizations_post, f.is_correct, f.speedup_ratio,
                   f.metadata->>'best_optimization_id' as best_opt_id
            FROM optimization_features f
            WHERE f.trace_id = %s
            """,
            (trace_id,),
        )
        row = cur.fetchone()
    if row is None:
        return None
    optimizations_post, is_correct, speedup_ratio, db_best_id = row
    if not optimizations_post:
        return None

    def render(raw: str) -> str:
        # Prefer the specific fenced block containing the target function;
        # fall back to the first block with fences stripped.
        if function_name:
            block = find_block_containing_function(raw, function_name, file_path)
            if block:
                return block
        return strip_markdown_fencing(raw)

    # Use the provided best_opt_id, or the DB one, or find the best ourselves.
    target_id = best_opt_id or db_best_id
    if target_id and target_id in optimizations_post:
        return render(optimizations_post[target_id])
    if not is_correct or not speedup_ratio:
        return None
    # Correct candidates with a recorded speedup above the -1.0 sentinel;
    # max() keeps the first maximal element, preserving dict-order tie-breaks.
    candidates = [
        opt_id
        for opt_id, correct in is_correct.items()
        if correct
        and opt_id in optimizations_post
        and speedup_ratio.get(opt_id) is not None
        and speedup_ratio[opt_id] > -1.0
    ]
    if not candidates:
        return None
    winner = max(candidates, key=lambda opt_id: speedup_ratio[opt_id])
    return render(optimizations_post[winner])
def get_original_code_from_db(
    conn: psycopg2.extensions.connection, trace_id: str, function_name: str = "", file_path: str = ""
) -> str | None:
    """Return the original (pre-optimization) code from optimization_features.

    Multi-block markdown is narrowed to the block containing the target
    function when *function_name* is given; otherwise the first block is
    returned with its fences stripped.
    """
    with conn.cursor() as cur:
        cur.execute("SELECT f.original_code FROM optimization_features f WHERE f.trace_id = %s", (trace_id,))
        row = cur.fetchone()
    if row is None or not row[0]:
        return None
    raw = row[0]
    if function_name:
        matched = find_block_containing_function(raw, function_name, file_path)
        if matched:
            return matched
    # Fallback: first block only.
    return strip_markdown_fencing(raw)
def find_commit_with_matching_function(
    repo: Path, file_path: str, function_name: str, original_code: str, max_commits: int = 50
) -> str | None:
    """Search recent git history for a commit where *function_name* matches
    *original_code*, comparing only the target function rather than the file.

    Returns the first matching commit hash (newest first), or None.
    """
    log_proc = subprocess.run(
        ["git", "log", "--all", "--format=%H", f"-{max_commits}", "--", file_path],
        cwd=repo,
        capture_output=True,
        text=True,
        encoding="utf-8",
    )
    if log_proc.returncode != 0:
        return None
    # Compare just the function; if the original can't be parsed, fall back to
    # whole-snippet comparison after whitespace normalization.
    wanted = extract_function_from_source(original_code, function_name) or normalize_whitespace(original_code)
    wanted_norm = normalize_whitespace(wanted)
    commits = [c for c in log_proc.stdout.strip().split("\n") if c]
    for commit in commits:
        content = git_show_file(repo, commit, file_path)
        if not content:
            continue
        candidate = extract_function_from_source(content, function_name)
        if candidate and normalize_whitespace(candidate) == wanted_norm:
            return commit
    return None
def parse_markdown_code_blocks(raw: str) -> list[tuple[str | None, str]]:
    """Parse markdown code blocks from a DB code field.

    Returns a list of (filepath_hint, code) tuples. The filepath hint comes
    from fence headers of the form ```python:path/to/file.py.

    Text outside any fence is ignored when at least one fenced block exists;
    when there are no fences at all, the whole input is returned as a single
    hint-less block. A final fence that is opened but never closed is flushed
    as a normal block (keeping its hint) instead of being dropped.
    """
    blocks: list[tuple[str | None, str]] = []
    current_lines: list[str] = []
    current_hint: str | None = None
    in_block = False
    for line in raw.splitlines():
        stripped = line.strip()
        if stripped.startswith("```") and not in_block:
            in_block = True
            current_lines = []
            # Extract the filepath hint from headers like ```python:filepath
            current_hint = stripped.split(":", 1)[1].strip() if ":" in stripped else None
        elif stripped == "```" and in_block:
            in_block = False
            blocks.append((current_hint, "\n".join(current_lines)))
            current_lines = []
            current_hint = None
        elif in_block:
            current_lines.append(line)
        else:
            # Lines outside any fence — accumulated for the no-fence fallback.
            current_lines.append(line)
    # BUGFIX: an unterminated final fence used to lose its content entirely
    # when earlier blocks existed (and lose its hint otherwise); flush it.
    if in_block and current_lines:
        blocks.append((current_hint, "\n".join(current_lines)))
    # If there were no fences at all, return the whole thing as one block.
    if not blocks and current_lines:
        blocks.append((None, "\n".join(current_lines)))
    return blocks
def strip_markdown_fencing(code: str) -> str:
    """Strip markdown fencing, returning the FIRST code block's content.

    Returns *code* unchanged when no blocks can be parsed. For multi-block
    content, use parse_markdown_code_blocks instead.
    """
    parsed = parse_markdown_code_blocks(code)
    if not parsed:
        return code
    _hint, first_block = parsed[0]
    return first_block
def find_block_containing_function(raw: str, function_name: str, target_file_path: str | None = None) -> str | None:
    """From multi-block markdown, locate the block containing the target function.

    Matching order: the ```python:filepath fence hint first, then an AST search
    for a function/method definition with the matching short name.
    """
    blocks = parse_markdown_code_blocks(raw)
    short_name = function_name.rsplit(".", maxsplit=1)[-1]
    # Pass 1: match by file path hint from the fence header.
    if target_file_path:
        for hint, code in blocks:
            if not hint:
                continue
            if target_file_path.endswith(hint.lstrip("/")) or hint.endswith(target_file_path):
                return code
    # Pass 2: parse each block and look for a def with the matching name.
    for _hint, code in blocks:
        try:
            parsed = ast.parse(code)
        except SyntaxError:
            continue
        definitions = (
            n for n in ast.walk(parsed) if isinstance(n, ast.FunctionDef | ast.AsyncFunctionDef)
        )
        if any(n.name == short_name for n in definitions):
            return code
    return None
def normalize_whitespace(code: str) -> str:
    """Dedent and strip *code* so indentation/padding differences don't affect comparison."""
    dedented = textwrap.dedent(code)
    return dedented.strip()
def resolve_file_path(repo: Path, file_path: str) -> str:
"""If file_path doesn't exist in git, search for a similar file by basename."""
# Quick check: does it exist at HEAD?
result = subprocess.run(["git", "cat-file", "-e", f"HEAD:{file_path}"], cwd=repo, capture_output=True)
if result.returncode == 0:
return file_path
# Search for files with the same basename at HEAD
basename = Path(file_path).name
parent_dir = Path(file_path).parent.name # e.g., "byte_track" from ".../byte_track/v1.py"
result = subprocess.run(
["git", "ls-tree", "-r", "--name-only", "HEAD"], cwd=repo, capture_output=True, text=True, encoding="utf-8"
)
if result.returncode != 0:
return file_path
all_files = result.stdout.strip().split("\n")
# First try: match parent_dir/basename pattern (e.g., "byte_track*/v1.py")
# This handles renames like byte_track -> byte_tracker, area_measurement -> mask_area_measurement
candidates = [f for f in all_files if f.endswith(f"/{basename}") and parent_dir.rstrip("/") in Path(f).parent.name]
if len(candidates) == 1:
log.info(f" Resolved file path: {file_path} -> {candidates[0]}")
return candidates[0]
if len(candidates) > 1:
original_parts = set(Path(file_path).parts)
best = max(candidates, key=lambda c: len(set(Path(c).parts) & original_parts))
log.info(f" Resolved file path: {file_path} -> {best} (from {len(candidates)} candidates)")
return best
# Second try: broader basename match but only if few candidates
candidates = [f for f in all_files if f.endswith(f"/{basename}") or f == basename]
if len(candidates) == 1:
log.info(f" Resolved file path: {file_path} -> {candidates[0]}")
return candidates[0]
# Too many candidates with a generic name like v1.py — don't guess
if len(candidates) > 5:
log.warning(f" Cannot resolve {file_path}: {len(candidates)} candidates for basename '{basename}'")
return file_path
if len(candidates) > 1:
original_parts = set(Path(file_path).parts)
best = max(candidates, key=lambda c: len(set(Path(c).parts) & original_parts))
log.info(f" Resolved file path: {file_path} -> {best} (from {len(candidates)} candidates)")
return best
return file_path
def resolve_entry(entry: dict, repo: Path, conn: psycopg2.extensions.connection) -> ResolvedTask | None:
    """Resolve a catalog entry into a full task with code pairs.

    Tries three strategies in order, returning the first that succeeds:
    1. "git_parent": the entry already names pre/post commits; extract both
       versions from git and diff them.
    2. "db_code_match": take both code versions from the DB, then search git
       history for a commit whose version of the function matches the original.
    3. "db_code_only": use the DB code as-is against the file at HEAD
       (no commit pinning, so no patch can be produced).

    Returns None when none of the strategies can produce both code versions.
    """
    trace_id = entry["trace_id"]
    function_name = entry["function_name"]
    file_path = entry["file_path"]
    # Strategy 1: Git-based resolution (have pre/post commit)
    if entry.get("git_commit") and entry.get("pre_optimization_commit"):
        pre_commit = entry["pre_optimization_commit"]
        post_commit = entry["git_commit"]
        original_file = git_show_file(repo, pre_commit, file_path)
        optimized_file = git_show_file(repo, post_commit, file_path)
        if original_file and optimized_file:
            original_func = extract_function_from_source(original_file, function_name)
            optimized_func = extract_function_from_source(optimized_file, function_name)
            if original_func and optimized_func:
                patch = git_diff(repo, pre_commit, post_commit, file_path)
                return ResolvedTask(
                    trace_id=trace_id,
                    function_name=function_name,
                    file_path=file_path,
                    speedup_x=entry.get("speedup_x"),
                    review_quality=entry.get("review_quality"),
                    pre_optimization_commit=pre_commit,
                    original_function_code=original_func,
                    optimized_function_code=optimized_file,  # full file for git-based
                    optimized_raw_markdown="",
                    optimized_function_only=optimized_func,
                    original_file_content=original_file,
                    patch=patch,
                    source=entry.get("source", ""),
                    resolution_method="git_parent",
                )
    # Try to fix file_path if it doesn't exist in the repo.
    # NOTE: this only happens after strategy 1 — strategy 1 uses the catalog
    # path as-is against its named commits.
    file_path = resolve_file_path(repo, file_path)
    # Strategy 2: DB code + git history search
    original_code = get_original_code_from_db(conn, trace_id, function_name, file_path)
    optimized_code_full = get_optimized_code_from_db(
        conn, trace_id, entry.get("best_optimization_id"), function_name, file_path
    )
    # Also get the raw markdown (with file path fences) for multi-file replacement
    optimized_raw = get_optimized_code_raw_from_db(conn, trace_id, entry.get("best_optimization_id")) or ""
    if not original_code or not optimized_code_full:
        log.debug(f" Missing code in DB for {trace_id} ({function_name})")
        return None
    optimized_func_only = extract_function_from_source(optimized_code_full, function_name)
    optimized_code = optimized_code_full
    # Find commit where original code matches
    matching_commit = find_commit_with_matching_function(repo, file_path, function_name, original_code)
    if matching_commit:
        original_file = git_show_file(repo, matching_commit, file_path)
        original_func = extract_function_from_source(original_file, function_name) if original_file else None
        if original_file and original_func:
            return ResolvedTask(
                trace_id=trace_id,
                function_name=function_name,
                file_path=file_path,
                speedup_x=entry.get("speedup_x"),
                review_quality=entry.get("review_quality"),
                pre_optimization_commit=matching_commit,
                original_function_code=original_func,
                optimized_function_code=optimized_code,
                optimized_raw_markdown=optimized_raw,
                optimized_function_only=optimized_func_only,
                original_file_content=original_file,
                patch=None,
                source=entry.get("source", ""),
                resolution_method="db_code_match",
            )
    # Strategy 3: DB-only (use original_code directly, no commit pinning)
    # This means we have the code but can't pin it to a specific repo state.
    # We'll use the original_code as-is and the file from HEAD.
    head_file = git_show_file(repo, "HEAD", file_path)
    if head_file:
        original_func = extract_function_from_source(original_code, function_name)
        if original_func:
            return ResolvedTask(
                trace_id=trace_id,
                function_name=function_name,
                file_path=file_path,
                speedup_x=entry.get("speedup_x"),
                review_quality=entry.get("review_quality"),
                pre_optimization_commit="HEAD",
                original_function_code=original_func,
                optimized_function_code=optimized_code,
                optimized_raw_markdown=optimized_raw,
                optimized_function_only=optimized_func_only,
                # NOTE(review): this stores the DB original_code snippet, not
                # head_file, as the "file content" — presumably so downstream
                # works from the DB snippet. Confirm this is intentional.
                original_file_content=original_code,
                patch=None,
                source=entry.get("source", ""),
                resolution_method="db_code_only",
            )
    return None
def main() -> None:
    """CLI entry point: resolve catalog entries and write two output files.

    Writes <output> with lightweight per-task metadata and code previews, and
    a sibling "<output stem>_full.json" containing the complete code payloads
    (original/optimized code, raw markdown, patch) for task generation.
    """
    parser = argparse.ArgumentParser(description="Resolve catalog entries to full task data")
    parser.add_argument(
        "--catalog", type=Path, default=Path("optimization_catalog.json"), help="Input catalog from catalog_builder.py"
    )
    parser.add_argument(
        "--inference-repo",
        type=Path,
        default=Path("/Users/saurabh/Dropbox/hacks/inference"),
        help="Path to local inference repo clone",
    )
    parser.add_argument(
        "--output", type=Path, default=Path("resolved_catalog.json"), help="Output resolved catalog JSON"
    )
    parser.add_argument("--limit", type=int, default=0, help="Limit number of entries to process (0 = all)")
    args = parser.parse_args()
    catalog = json.loads(args.catalog.read_text(encoding="utf-8"))
    log.info(f"Loaded {len(catalog)} catalog entries")
    if args.limit > 0:
        catalog = catalog[: args.limit]
        log.info(f" Processing first {args.limit} entries")
    conn = get_db_connection()
    resolved = []
    failed = []
    # Tally of how each task was resolved (see resolve_entry strategies).
    by_method = {"git_parent": 0, "db_code_match": 0, "db_code_only": 0}
    for i, entry in enumerate(catalog):
        # Progress heartbeat every 20 entries.
        if (i + 1) % 20 == 0:
            log.info(f" Processing {i + 1}/{len(catalog)}...")
        task = resolve_entry(entry, args.inference_repo, conn)
        if task:
            resolved.append(task)
            by_method[task.resolution_method] += 1
        else:
            failed.append(entry)
    conn.close()
    log.info("=" * 60)
    log.info(f"RESOLUTION RESULTS: {len(resolved)}/{len(catalog)} resolved")
    log.info("=" * 60)
    for method, count in by_method.items():
        log.info(f" {method}: {count}")
    log.info(f" Failed: {len(failed)}")
    if failed:
        # Only show the first 10 failures to keep the log readable.
        log.info("\n Failed entries:")
        for e in failed[:10]:
            log.info(f" {e['function_name']} ({e['file_path']}) source={e['source']}")
    # Serialize (exclude large code fields from main output, save separately)
    output_entries = []
    for task in resolved:
        output_entries.append(
            {
                "trace_id": task.trace_id,
                "function_name": task.function_name,
                "file_path": task.file_path,
                "speedup_x": task.speedup_x,
                "review_quality": task.review_quality,
                "pre_optimization_commit": task.pre_optimization_commit,
                "resolution_method": task.resolution_method,
                "source": task.source,
                "original_function_code_len": len(task.original_function_code),
                "optimized_function_code_len": len(task.optimized_function_code),
                "has_patch": task.patch is not None,
                # 200-character previews so the summary file stays small.
                "original_function_preview": task.original_function_code[:200] + "..."
                if len(task.original_function_code) > 200
                else task.original_function_code,
                "optimized_function_only_preview": (task.optimized_function_only or "")[:200] + "..."
                if task.optimized_function_only and len(task.optimized_function_only) > 200
                else (task.optimized_function_only or "N/A"),
            }
        )
    args.output.write_text(json.dumps(output_entries, indent=2, default=str), encoding="utf-8")
    log.info(f"\nResolved catalog written to {args.output}")
    # Also save full data (with code) for task generation
    full_output = args.output.with_stem(args.output.stem + "_full")
    full_entries = []
    for task in resolved:
        full_entries.append(
            {
                "trace_id": task.trace_id,
                "function_name": task.function_name,
                "file_path": task.file_path,
                "speedup_x": task.speedup_x,
                "review_quality": task.review_quality,
                "pre_optimization_commit": task.pre_optimization_commit,
                "resolution_method": task.resolution_method,
                "source": task.source,
                "original_function_code": task.original_function_code,
                "optimized_function_code": task.optimized_function_code,
                "optimized_raw_markdown": task.optimized_raw_markdown,
                "optimized_function_only": task.optimized_function_only,
                "original_file_content": task.original_file_content,
                "patch": task.patch,
            }
        )
    full_output.write_text(json.dumps(full_entries, indent=2, default=str), encoding="utf-8")
    log.info(f"Full resolved catalog (with code) written to {full_output}")