# codeflash-internal/experiments/rl_env/discover_existing_tests.py
# Snapshot metadata: 2026-04-16 16:31:25 -07:00 — 288 lines, 9.7 KiB, Python

"""Precompute a map of which optimization tasks have existing unit tests in the repo.
Uses codeflash's Jedi-based unit test discovery. Must run in an environment where the
target project's dependencies are installed (i.e., inside Docker).
Two modes:
1. Docker mode (default): Runs discovery inside Docker container for each commit
2. Local mode (--local): Runs directly, requires project deps installed
Output: existing_tests_map.json
Usage:
# Docker mode (recommended):
python rl_env/discover_existing_tests.py \
--resolved-catalog rl_env/resolved_catalog_full.json
# Local mode (if deps are installed):
python rl_env/discover_existing_tests.py --local \
--resolved-catalog rl_env/resolved_catalog_full.json
"""
from __future__ import annotations
import argparse
import json
import logging
import subprocess
import tempfile
from dataclasses import asdict, dataclass
from pathlib import Path
logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")
log = logging.getLogger(__name__)
INFERENCE_REPO = Path("/Users/saurabh/Dropbox/hacks/inference")
DOCKER_IMAGE = "codeflash-inference-base:latest"
@dataclass
class ExistingTest:
    """One discovered unit test that exercises an optimization target function."""

    test_file: str  # test file path, relative to repo root
    test_class: str | None  # enclosing test class name; None for module-level tests
    test_function: str  # name of the test function/method
@dataclass
class TaskTestMap:
    """Maps one optimization task (identified by trace_id) to its existing unit tests."""

    trace_id: str  # unique id of the optimization task
    function_name: str  # possibly dotted qualified name, e.g. "Class.method"
    file_path: str  # target source file, relative to the project root
    pre_optimization_commit: str  # git commit at which discovery ran
    existing_tests: list[ExistingTest]  # discovered tests (may be empty)
# This script runs INSIDE Docker to discover tests for a single function
DOCKER_DISCOVERY_SCRIPT = """
import json
import sys
from pathlib import Path
function_name = sys.argv[1]
file_path = sys.argv[2]
project_root = Path(sys.argv[3])
tests_root = Path(sys.argv[4])
from codeflash.discovery.discover_unit_tests import discover_unit_tests
from codeflash.models.function_types import FunctionParent, FunctionToOptimize
from codeflash.verification.verification_utils import TestConfig
abs_file = project_root / file_path
short_name = function_name.rsplit(".", maxsplit=1)[-1]
parents = []
if "." in function_name:
class_name = function_name.rsplit(".", maxsplit=1)[0]
parents = [FunctionParent(name=class_name, type="ClassDef")]
fto = FunctionToOptimize(
function_name=short_name, file_path=abs_file, parents=parents, language="python",
)
cfg = TestConfig(
tests_root=tests_root, project_root_path=project_root,
tests_project_rootdir=project_root, use_cache=False,
)
try:
function_to_tests, num_tests, _ = discover_unit_tests(cfg, file_to_funcs_to_optimize={abs_file: [fto]})
except Exception as e:
print(json.dumps({"error": str(e), "tests": []}))
sys.exit(0)
results = []
for qualified_name, test_set in function_to_tests.items():
if short_name in qualified_name:
for fct in test_set:
tif = fct.tests_in_file
rel = str(tif.test_file.relative_to(project_root)) if tif.test_file.is_absolute() else str(tif.test_file)
results.append({
"test_file": rel,
"test_class": tif.test_class,
"test_function": tif.test_function,
})
print(json.dumps({"tests": results}))
"""
def discover_via_docker(commit: str, entries: list[dict], tests_root_rel: str) -> dict[str, list[ExistingTest]]:
    """Run codeflash test discovery inside Docker for a batch of functions at one commit.

    Args:
        commit: Git SHA to check out inside the container; "HEAD" skips checkout.
        entries: Catalog entries, each with "function_name", "file_path", "trace_id".
        tests_root_rel: Tests root directory, relative to the project root in the container.

    Returns:
        Mapping of trace_id -> list of discovered ExistingTest. Empty dict when
        the container fails or produces no parseable JSON.

    Raises:
        subprocess.TimeoutExpired: if the container exceeds the 10-minute budget
            (callers are expected to catch this).
    """
    # One batch script per commit amortizes container start + git checkout cost
    # across all functions at that commit.
    batch_entries = json.dumps(
        [{"function_name": e["function_name"], "file_path": e["file_path"], "trace_id": e["trace_id"]} for e in entries]
    )
    # Embed the JSON payload as a Python string literal via !r (repr). This is
    # robust for any quotes/backslashes in the data, and avoids the original
    # `.replace("'", "\\'")` trick, which (a) breaks on backslashes in the
    # payload and (b) is a SyntaxError on Python < 3.12 because f-string
    # replacement fields could not contain backslashes before PEP 701.
    batch_script = f"""
import json, sys
from pathlib import Path
entries = json.loads({batch_entries!r})
project_root = Path("/workspace/inference")
tests_root = project_root / "{tests_root_rel}"
from codeflash.discovery.discover_unit_tests import discover_unit_tests
from codeflash.models.function_types import FunctionParent, FunctionToOptimize
from codeflash.verification.verification_utils import TestConfig
results = {{}}
for entry in entries:
    function_name = entry["function_name"]
    file_path = entry["file_path"]
    trace_id = entry["trace_id"]
    abs_file = project_root / file_path
    if not abs_file.exists():
        results[trace_id] = []
        continue
    short_name = function_name.rsplit(".", maxsplit=1)[-1]
    parents = []
    if "." in function_name:
        class_name = function_name.rsplit(".", maxsplit=1)[0]
        parents = [FunctionParent(name=class_name, type="ClassDef")]
    fto = FunctionToOptimize(
        function_name=short_name, file_path=abs_file, parents=parents, language="python",
    )
    cfg = TestConfig(
        tests_root=tests_root, project_root_path=project_root,
        tests_project_rootdir=project_root, use_cache=False,
    )
    try:
        function_to_tests, _, _ = discover_unit_tests(cfg, file_to_funcs_to_optimize={{abs_file: [fto]}})
    except Exception:
        results[trace_id] = []
        continue
    tests = []
    for qualified_name, test_set in function_to_tests.items():
        if short_name in qualified_name:
            for fct in test_set:
                tif = fct.tests_in_file
                rel = str(tif.test_file.relative_to(project_root)) if tif.test_file.is_absolute() else str(tif.test_file)
                tests.append({{"test_file": rel, "test_class": tif.test_class, "test_function": tif.test_function}})
    results[trace_id] = tests
print(json.dumps(results))
"""
    # Mount the script into the container read-only rather than passing it on
    # the command line, sidestepping shell-quoting issues.
    with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False, encoding="utf-8") as f:
        f.write(batch_script)
        script_path = Path(f.name)
    checkout_cmd = f"git checkout {commit} 2>/dev/null && " if commit != "HEAD" else ""
    try:
        result = subprocess.run(
            [
                "docker",
                "run",
                "--rm",
                "-v",
                f"{script_path}:/tmp/discover.py:ro",
                DOCKER_IMAGE,
                "bash",
                "-c",
                f"cd /workspace/inference && {checkout_cmd}python /tmp/discover.py",
            ],
            capture_output=True,
            text=True,
            encoding="utf-8",
            timeout=600,
        )
    finally:
        # Fix: always remove the temp script, including when subprocess.run
        # raises TimeoutExpired (the original leaked the file in that case).
        script_path.unlink(missing_ok=True)
    if result.returncode != 0:
        log.warning(f" Docker discovery failed for commit {commit[:12]}: {result.stderr[-300:]}")
        return {}
    # stdout may contain git-checkout noise; take the first line that parses as
    # a JSON object.
    for line in result.stdout.strip().splitlines():
        line = line.strip()
        if line.startswith("{"):
            try:
                raw = json.loads(line)
                return {tid: [ExistingTest(**t) for t in tests] for tid, tests in raw.items()}
            except json.JSONDecodeError:
                continue
    log.warning(f" No JSON output from Docker for commit {commit[:12]}")
    return {}
def main() -> None:
    """CLI entry point: discover existing tests for every catalog task and write a JSON map.

    Reads the resolved catalog, groups tasks by pre-optimization commit so each
    unique commit costs exactly one Docker run, then writes the per-task test
    map to --output.
    """
    parser = argparse.ArgumentParser(description="Discover existing unit tests for optimization tasks")
    parser.add_argument("--resolved-catalog", type=Path, default=Path("rl_env/resolved_catalog_full.json"))
    # NOTE(review): accepted but not read anywhere below — kept for CLI
    # compatibility; confirm whether local mode was meant to use it.
    parser.add_argument("--inference-repo", type=Path, default=INFERENCE_REPO)
    parser.add_argument("--output", type=Path, default=Path("rl_env/existing_tests_map.json"))
    parser.add_argument("--limit", type=int, default=0)  # 0 = process the full catalog
    args = parser.parse_args()

    catalog = json.loads(args.resolved_catalog.read_text(encoding="utf-8"))
    log.info(f"Loaded {len(catalog)} resolved entries")
    if args.limit > 0:
        catalog = catalog[: args.limit]

    tests_root_rel = "tests/inference/unit_tests"

    # Discovery runs at the pre-optimization commit for each task.
    # Group by commit to batch Docker runs (one container per commit).
    by_commit: dict[str, list[dict]] = {}
    for entry in catalog:
        by_commit.setdefault(entry["pre_optimization_commit"], []).append(entry)
    log.info(f"Processing {len(catalog)} entries across {len(by_commit)} unique commits (via Docker)")

    all_results: list[TaskTestMap] = []
    total_with_tests = 0
    total_tests_found = 0
    for commit_idx, (commit, entries) in enumerate(by_commit.items()):
        log.info(f" [{commit_idx + 1}/{len(by_commit)}] Commit {commit[:12]} ({len(entries)} functions)...")
        try:
            discovered = discover_via_docker(commit, entries, tests_root_rel)
        except subprocess.TimeoutExpired:
            log.warning(f" Timeout for commit {commit[:12]}, skipping {len(entries)} entries")
            discovered = {}
        # Record every entry even when discovery failed, so the output map is
        # complete (failed entries simply get an empty test list).
        for entry in entries:
            trace_id = entry["trace_id"]
            tests = discovered.get(trace_id, [])
            all_results.append(
                TaskTestMap(
                    trace_id=trace_id,
                    function_name=entry["function_name"],
                    file_path=entry["file_path"],
                    pre_optimization_commit=commit,
                    existing_tests=tests,
                )
            )
            if tests:
                total_with_tests += 1
                total_tests_found += len(tests)

    log.info("=" * 60)
    log.info("DISCOVERY RESULTS:")
    log.info(f" Total tasks: {len(all_results)}")
    log.info(f" Tasks with existing tests: {total_with_tests}")
    log.info(f" Total existing tests found: {total_tests_found}")
    with_tests = [r for r in all_results if r.existing_tests]
    for r in with_tests[:10]:  # print a small sample for manual inspection
        log.info(f" {r.function_name}: {len(r.existing_tests)} tests")
        for t in r.existing_tests[:3]:
            log.info(f" {t.test_file}::{t.test_class or ''}.{t.test_function}")

    output_data = [asdict(r) for r in all_results]
    # Robustness fix: create the output directory if it doesn't exist yet;
    # the original crashed on write_text with FileNotFoundError in that case.
    args.output.parent.mkdir(parents=True, exist_ok=True)
    args.output.write_text(json.dumps(output_data, indent=2), encoding="utf-8")
    log.info(f"\nWrote {args.output}")


if __name__ == "__main__":
    main()