"""Resolve pre-optimization commits and extract original/optimized code pairs. For each catalog entry, determines: 1. The exact commit where the function exists in its pre-optimization state 2. The original function code at that commit 3. The optimized function code (from DB or git) 4. A git patch that transforms original → optimized Usage: python rl_env/commit_resolver.py \ --catalog optimization_catalog.json \ --inference-repo /path/to/inference \ --output resolved_catalog.json """ from __future__ import annotations import argparse import ast import json import logging import subprocess import textwrap from dataclasses import dataclass from pathlib import Path import psycopg2 logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s") log = logging.getLogger(__name__) DB_HOST = "codeflash-pgsql-db-prod.postgres.database.azure.com" DB_NAME = "postgres" DB_USER = "readonly_user" DB_PASSWORD = "3U@L5d2B^z1+,ohC[]dP" @dataclass class ResolvedTask: trace_id: str function_name: str file_path: str speedup_x: float | None review_quality: str | None pre_optimization_commit: str # repo state before optimization original_function_code: str # the function before optimization optimized_function_code: str # the function after optimization (single file, fences stripped) optimized_raw_markdown: str # raw DB format with ```python:filepath fences (for multi-file) optimized_function_only: str | None # just the function extracted from optimized code original_file_content: str # full file at pre-optimization commit patch: str | None # git diff for solve.sh source: str resolution_method: str # "git_parent", "db_code_match", "db_code_only" def get_db_connection() -> psycopg2.extensions.connection: return psycopg2.connect(host=DB_HOST, dbname=DB_NAME, user=DB_USER, password=DB_PASSWORD, sslmode="require") def git_show_file(repo: Path, commit: str, file_path: str) -> str | None: """Get file content at a specific commit.""" result = subprocess.run( ["git", "show", f"{commit}:{file_path}"], cwd=repo, capture_output=True, text=True, encoding="utf-8" ) if result.returncode != 0: return None return result.stdout def git_diff(repo: Path, commit_a: str, commit_b: str, file_path: str) -> str | None: """Get diff between two commits for a specific file.""" result = subprocess.run( ["git", "diff", commit_a, commit_b, "--", file_path], cwd=repo, capture_output=True, text=True, encoding="utf-8" ) if result.returncode != 0: return None return result.stdout def extract_function_from_source(source: str, function_name: str) -> str | None: """Extract a function/method definition from source code using AST. Handles both top-level functions and class methods. """ short_name = function_name.rsplit(".", maxsplit=1)[-1] class_name = function_name.split(".")[-2] if "." in function_name else None try: tree = ast.parse(source) except SyntaxError: return None lines = source.splitlines(keepends=True) for node in ast.walk(tree): if isinstance(node, ast.FunctionDef | ast.AsyncFunctionDef): if node.name == short_name: # If class_name specified, verify parent is the right class if class_name: parent_class = find_parent_class(tree, node) if parent_class and parent_class.name != class_name: continue start = node.lineno - 1 end = node.end_lineno if node.end_lineno else start + 1 func_lines = lines[start:end] return "".join(func_lines) return None def find_parent_class(tree: ast.Module, target_node: ast.AST) -> ast.ClassDef | None: """Find the parent class of a function node.""" for node in ast.walk(tree): if isinstance(node, ast.ClassDef): for child in ast.walk(node): if child is target_node: return node return None def get_optimized_code_raw_from_db(conn: psycopg2.extensions.connection, trace_id: str, best_opt_id: str | None) -> str | None: """Get the raw optimized code (with markdown file fences) for the best candidate.""" with conn.cursor() as cur: cur.execute( "SELECT f.optimizations_post, f.is_correct, f.speedup_ratio, " "f.metadata->>'best_optimization_id' as best_opt_id " "FROM optimization_features f WHERE f.trace_id = %s", (trace_id,), ) row = cur.fetchone() if not row: return None optimizations_post, is_correct, speedup_ratio, db_best_id = row if not optimizations_post: return None target_id = best_opt_id or db_best_id if target_id and target_id in optimizations_post: return optimizations_post[target_id] if not is_correct or not speedup_ratio: return None best_id = None best_speedup = -1.0 for opt_id, correct in is_correct.items(): if not correct or opt_id not in optimizations_post: continue speedup = speedup_ratio.get(opt_id) if speedup is not None and speedup > best_speedup: best_speedup = speedup best_id = opt_id if best_id: return optimizations_post[best_id] return None def get_optimized_code_from_db( conn: psycopg2.extensions.connection, trace_id: str, best_opt_id: str | None, function_name: str = "", file_path: str = "", ) -> str | None: """Get the optimized code for the best candidate from optimization_features.""" with conn.cursor() as cur: cur.execute( """ SELECT f.optimizations_post, f.is_correct, f.speedup_ratio, f.metadata->>'best_optimization_id' as best_opt_id FROM optimization_features f WHERE f.trace_id = %s """, (trace_id,), ) row = cur.fetchone() if not row: return None optimizations_post, is_correct, speedup_ratio, db_best_id = row if not optimizations_post: return None # Use the provided best_opt_id, or the DB one, or find the best ourselves target_id = best_opt_id or db_best_id if target_id and target_id in optimizations_post: raw = optimizations_post[target_id] if function_name: block = find_block_containing_function(raw, function_name, file_path) if block: return block return strip_markdown_fencing(raw) # Find best correct candidate with highest speedup that exists in optimizations_post if not is_correct or not speedup_ratio: return None best_id = None best_speedup = -1.0 for opt_id, correct in is_correct.items(): if not correct: continue if opt_id not in optimizations_post: continue speedup = speedup_ratio.get(opt_id) if speedup is not None and speedup > best_speedup: best_speedup = speedup best_id = opt_id if best_id: raw = optimizations_post[best_id] if function_name: block = find_block_containing_function(raw, function_name, file_path) if block: return block return strip_markdown_fencing(raw) return None def get_original_code_from_db( conn: psycopg2.extensions.connection, trace_id: str, function_name: str = "", file_path: str = "" ) -> str | None: """Get the original code from optimization_features. Handles multi-block markdown by finding the block containing the target function. """ with conn.cursor() as cur: cur.execute("SELECT f.original_code FROM optimization_features f WHERE f.trace_id = %s", (trace_id,)) row = cur.fetchone() if not row or not row[0]: return None raw = row[0] # Try to find the specific block containing our function if function_name: block = find_block_containing_function(raw, function_name, file_path) if block: return block # Fallback: return first block return strip_markdown_fencing(raw) def find_commit_with_matching_function( repo: Path, file_path: str, function_name: str, original_code: str, max_commits: int = 50 ) -> str | None: """Search git history for a commit where the function matches the original code. Only compares the target function, not the entire file. """ result = subprocess.run( ["git", "log", "--all", "--format=%H", f"-{max_commits}", "--", file_path], cwd=repo, capture_output=True, text=True, encoding="utf-8", ) if result.returncode != 0: return None # Extract just the function from original_code for comparison original_func = extract_function_from_source(original_code, function_name) if not original_func: # If we can't parse, try normalizing the whole thing original_func = normalize_whitespace(original_code) for commit in result.stdout.strip().split("\n"): if not commit: continue file_content = git_show_file(repo, commit, file_path) if not file_content: continue func_at_commit = extract_function_from_source(file_content, function_name) if not func_at_commit: continue if normalize_whitespace(func_at_commit) == normalize_whitespace(original_func): return commit return None def parse_markdown_code_blocks(raw: str) -> list[tuple[str | None, str]]: """Parse multiple markdown code blocks from DB code field. Returns list of (filepath_hint, code) tuples. The filepath_hint comes from ```python:filepath headers. """ blocks: list[tuple[str | None, str]] = [] current_lines: list[str] = [] current_hint: str | None = None in_block = False for line in raw.splitlines(): stripped = line.strip() if stripped.startswith("```") and not in_block: in_block = True current_lines = [] # Extract filepath hint from ```python:filepath if ":" in stripped: current_hint = stripped.split(":", 1)[1].strip() else: current_hint = None elif stripped == "```" and in_block: in_block = False blocks.append((current_hint, "\n".join(current_lines))) current_lines = [] current_hint = None elif in_block: current_lines.append(line) else: # Lines outside any fence — treat as a single block current_lines.append(line) # If there were no fences at all, return the whole thing as one block if not blocks and current_lines: blocks.append((None, "\n".join(current_lines))) return blocks def strip_markdown_fencing(code: str) -> str: """Strip markdown fencing — returns the FIRST code block's content. For multi-block content, use parse_markdown_code_blocks instead. """ blocks = parse_markdown_code_blocks(code) if blocks: return blocks[0][1] return code def find_block_containing_function(raw: str, function_name: str, target_file_path: str | None = None) -> str | None: """From multi-block markdown, find the block containing the target function. Tries file path hint matching first, then falls back to AST search. """ blocks = parse_markdown_code_blocks(raw) short_name = function_name.rsplit(".", maxsplit=1)[-1] # First pass: try to match by file path hint if target_file_path: for hint, code in blocks: if hint and target_file_path.endswith(hint.lstrip("/")): return code if hint and hint.endswith(target_file_path): return code # Second pass: find the block where the function is defined for _hint, code in blocks: try: tree = ast.parse(code) except SyntaxError: continue for node in ast.walk(tree): if isinstance(node, ast.FunctionDef | ast.AsyncFunctionDef): if node.name == short_name: return code return None def normalize_whitespace(code: str) -> str: """Normalize whitespace for comparison.""" return textwrap.dedent(code).strip() def resolve_file_path(repo: Path, file_path: str) -> str: """If file_path doesn't exist in git, search for a similar file by basename.""" # Quick check: does it exist at HEAD? result = subprocess.run(["git", "cat-file", "-e", f"HEAD:{file_path}"], cwd=repo, capture_output=True) if result.returncode == 0: return file_path # Search for files with the same basename at HEAD basename = Path(file_path).name parent_dir = Path(file_path).parent.name # e.g., "byte_track" from ".../byte_track/v1.py" result = subprocess.run( ["git", "ls-tree", "-r", "--name-only", "HEAD"], cwd=repo, capture_output=True, text=True, encoding="utf-8" ) if result.returncode != 0: return file_path all_files = result.stdout.strip().split("\n") # First try: match parent_dir/basename pattern (e.g., "byte_track*/v1.py") # This handles renames like byte_track -> byte_tracker, area_measurement -> mask_area_measurement candidates = [f for f in all_files if f.endswith(f"/{basename}") and parent_dir.rstrip("/") in Path(f).parent.name] if len(candidates) == 1: log.info(f" Resolved file path: {file_path} -> {candidates[0]}") return candidates[0] if len(candidates) > 1: original_parts = set(Path(file_path).parts) best = max(candidates, key=lambda c: len(set(Path(c).parts) & original_parts)) log.info(f" Resolved file path: {file_path} -> {best} (from {len(candidates)} candidates)") return best # Second try: broader basename match but only if few candidates candidates = [f for f in all_files if f.endswith(f"/{basename}") or f == basename] if len(candidates) == 1: log.info(f" Resolved file path: {file_path} -> {candidates[0]}") return candidates[0] # Too many candidates with a generic name like v1.py — don't guess if len(candidates) > 5: log.warning(f" Cannot resolve {file_path}: {len(candidates)} candidates for basename '{basename}'") return file_path if len(candidates) > 1: original_parts = set(Path(file_path).parts) best = max(candidates, key=lambda c: len(set(Path(c).parts) & original_parts)) log.info(f" Resolved file path: {file_path} -> {best} (from {len(candidates)} candidates)") return best return file_path def resolve_entry(entry: dict, repo: Path, conn: psycopg2.extensions.connection) -> ResolvedTask | None: """Resolve a catalog entry into a full task with code pairs.""" trace_id = entry["trace_id"] function_name = entry["function_name"] file_path = entry["file_path"] # Strategy 1: Git-based resolution (have pre/post commit) if entry.get("git_commit") and entry.get("pre_optimization_commit"): pre_commit = entry["pre_optimization_commit"] post_commit = entry["git_commit"] original_file = git_show_file(repo, pre_commit, file_path) optimized_file = git_show_file(repo, post_commit, file_path) if original_file and optimized_file: original_func = extract_function_from_source(original_file, function_name) optimized_func = extract_function_from_source(optimized_file, function_name) if original_func and optimized_func: patch = git_diff(repo, pre_commit, post_commit, file_path) return ResolvedTask( trace_id=trace_id, function_name=function_name, file_path=file_path, speedup_x=entry.get("speedup_x"), review_quality=entry.get("review_quality"), pre_optimization_commit=pre_commit, original_function_code=original_func, optimized_function_code=optimized_file, # full file for git-based optimized_raw_markdown="", optimized_function_only=optimized_func, original_file_content=original_file, patch=patch, source=entry.get("source", ""), resolution_method="git_parent", ) # Try to fix file_path if it doesn't exist in the repo file_path = resolve_file_path(repo, file_path) # Strategy 2: DB code + git history search original_code = get_original_code_from_db(conn, trace_id, function_name, file_path) optimized_code_full = get_optimized_code_from_db( conn, trace_id, entry.get("best_optimization_id"), function_name, file_path ) # Also get the raw markdown (with file path fences) for multi-file replacement optimized_raw = get_optimized_code_raw_from_db(conn, trace_id, entry.get("best_optimization_id")) or "" if not original_code or not optimized_code_full: log.debug(f" Missing code in DB for {trace_id} ({function_name})") return None optimized_func_only = extract_function_from_source(optimized_code_full, function_name) optimized_code = optimized_code_full # Find commit where original code matches matching_commit = find_commit_with_matching_function(repo, file_path, function_name, original_code) if matching_commit: original_file = git_show_file(repo, matching_commit, file_path) original_func = extract_function_from_source(original_file, function_name) if original_file else None if original_file and original_func: return ResolvedTask( trace_id=trace_id, function_name=function_name, file_path=file_path, speedup_x=entry.get("speedup_x"), review_quality=entry.get("review_quality"), pre_optimization_commit=matching_commit, original_function_code=original_func, optimized_function_code=optimized_code, optimized_raw_markdown=optimized_raw, optimized_function_only=optimized_func_only, original_file_content=original_file, patch=None, source=entry.get("source", ""), resolution_method="db_code_match", ) # Strategy 3: DB-only (use original_code directly, no commit pinning) # This means we have the code but can't pin it to a specific repo state. # We'll use the original_code as-is and the file from HEAD. head_file = git_show_file(repo, "HEAD", file_path) if head_file: original_func = extract_function_from_source(original_code, function_name) if original_func: return ResolvedTask( trace_id=trace_id, function_name=function_name, file_path=file_path, speedup_x=entry.get("speedup_x"), review_quality=entry.get("review_quality"), pre_optimization_commit="HEAD", original_function_code=original_func, optimized_function_code=optimized_code, optimized_raw_markdown=optimized_raw, optimized_function_only=optimized_func_only, original_file_content=original_code, patch=None, source=entry.get("source", ""), resolution_method="db_code_only", ) return None def main() -> None: parser = argparse.ArgumentParser(description="Resolve catalog entries to full task data") parser.add_argument( "--catalog", type=Path, default=Path("optimization_catalog.json"), help="Input catalog from catalog_builder.py" ) parser.add_argument( "--inference-repo", type=Path, default=Path("/Users/saurabh/Dropbox/hacks/inference"), help="Path to local inference repo clone", ) parser.add_argument( "--output", type=Path, default=Path("resolved_catalog.json"), help="Output resolved catalog JSON" ) parser.add_argument("--limit", type=int, default=0, help="Limit number of entries to process (0 = all)") args = parser.parse_args() catalog = json.loads(args.catalog.read_text(encoding="utf-8")) log.info(f"Loaded {len(catalog)} catalog entries") if args.limit > 0: catalog = catalog[: args.limit] log.info(f" Processing first {args.limit} entries") conn = get_db_connection() resolved = [] failed = [] by_method = {"git_parent": 0, "db_code_match": 0, "db_code_only": 0} for i, entry in enumerate(catalog): if (i + 1) % 20 == 0: log.info(f" Processing {i + 1}/{len(catalog)}...") task = resolve_entry(entry, args.inference_repo, conn) if task: resolved.append(task) by_method[task.resolution_method] += 1 else: failed.append(entry) conn.close() log.info("=" * 60) log.info(f"RESOLUTION RESULTS: {len(resolved)}/{len(catalog)} resolved") log.info("=" * 60) for method, count in by_method.items(): log.info(f" {method}: {count}") log.info(f" Failed: {len(failed)}") if failed: log.info("\n Failed entries:") for e in failed[:10]: log.info(f" {e['function_name']} ({e['file_path']}) source={e['source']}") # Serialize (exclude large code fields from main output, save separately) output_entries = [] for task in resolved: output_entries.append( { "trace_id": task.trace_id, "function_name": task.function_name, "file_path": task.file_path, "speedup_x": task.speedup_x, "review_quality": task.review_quality, "pre_optimization_commit": task.pre_optimization_commit, "resolution_method": task.resolution_method, "source": task.source, "original_function_code_len": len(task.original_function_code), "optimized_function_code_len": len(task.optimized_function_code), "has_patch": task.patch is not None, "original_function_preview": task.original_function_code[:200] + "..." if len(task.original_function_code) > 200 else task.original_function_code, "optimized_function_only_preview": (task.optimized_function_only or "")[:200] + "..." if task.optimized_function_only and len(task.optimized_function_only) > 200 else (task.optimized_function_only or "N/A"), } ) args.output.write_text(json.dumps(output_entries, indent=2, default=str), encoding="utf-8") log.info(f"\nResolved catalog written to {args.output}") # Also save full data (with code) for task generation full_output = args.output.with_stem(args.output.stem + "_full") full_entries = [] for task in resolved: full_entries.append( { "trace_id": task.trace_id, "function_name": task.function_name, "file_path": task.file_path, "speedup_x": task.speedup_x, "review_quality": task.review_quality, "pre_optimization_commit": task.pre_optimization_commit, "resolution_method": task.resolution_method, "source": task.source, "original_function_code": task.original_function_code, "optimized_function_code": task.optimized_function_code, "optimized_raw_markdown": task.optimized_raw_markdown, "optimized_function_only": task.optimized_function_only, "original_file_content": task.original_file_content, "patch": task.patch, } ) full_output.write_text(json.dumps(full_entries, indent=2, default=str), encoding="utf-8") log.info(f"Full resolved catalog (with code) written to {full_output}") if __name__ == "__main__": main()