perf: unify agent loop and pre-build lookup maps for O(1) tool calls

Eliminate redundant API call by extracting text from the loop's final
response directly instead of making a separate streaming call. Pre-build
candidatesBySource, candidatesById, and testModelMap in indexTraceData()
to replace repeated O(n) linear searches in tool calls and prompt
building. Combine cost/token aggregation into a single pass.
This commit is contained in:
Kevin Turcios 2026-02-15 01:36:14 -05:00
parent b09262ccbc
commit fbcc283e97
2 changed files with 96 additions and 81 deletions

View file

@ -180,16 +180,25 @@ export async function POST(request: NextRequest): Promise<Response> {
const keepalive = setInterval(() => enqueue(": keepalive\n\n"), KEEPALIVE_INTERVAL_MS)
try {
// Tool resolution loop — uses stream() + finalMessage() to avoid the SDK's
// non-streaming timeout limit (max_tokens 32k estimates >10min, which blocks
// client.messages.create()). No text handler is attached, so no text reaches
// the frontend — avoiding "stutter" from intermediate fragments.
// Unified agent loop — each iteration either processes tool calls or
// extracts the final text. Uses stream()+finalMessage() to avoid the
// SDK's non-streaming timeout (max_tokens 32k estimates >10min).
// No separate "final streaming call" — when the loop gets a non-tool
// response, it extracts text directly, saving an extra API round-trip.
let toolRounds = 0
while (toolRounds < MAX_TOOL_ROUNDS) {
const toolStream = client.messages.stream(baseParams(systemPrompt, conversationMessages))
const response = await toolStream.finalMessage()
while (toolRounds <= MAX_TOOL_ROUNDS) {
const messageStream = client.messages.stream(baseParams(systemPrompt, conversationMessages))
const response = await messageStream.finalMessage()
if (response.stop_reason !== "tool_use") break
if (response.stop_reason !== "tool_use") {
// Final response — extract text blocks and send to client
for (const block of response.content) {
if (block.type === "text") {
enqueue(`data: ${JSON.stringify({ type: "text", text: block.text })}\n\n`)
}
}
break
}
conversationMessages.push({ role: "assistant", content: response.content })
const toolResults = await processToolCalls(response.content, indexed, enqueue)
@ -197,29 +206,6 @@ export async function POST(request: NextRequest): Promise<Response> {
toolRounds++
}
// Stream the final response — this is the only text the user sees
const messageStream = client.messages.stream(baseParams(systemPrompt, conversationMessages))
messageStream.on("text", (textDelta) => {
enqueue(`data: ${JSON.stringify({ type: "text", text: textDelta })}\n\n`)
})
const finalMessage = await messageStream.finalMessage()
// Edge case: final streaming response ended with tool_use. Process tools
// and make one more streaming call without tools to get a text response.
if (finalMessage.stop_reason === "tool_use") {
conversationMessages.push({ role: "assistant", content: finalMessage.content })
const toolResults = await processToolCalls(finalMessage.content, indexed, enqueue)
conversationMessages.push({ role: "user", content: toolResults })
const followUp = client.messages.stream(baseParams(systemPrompt, conversationMessages))
followUp.on("text", (textDelta) => {
enqueue(`data: ${JSON.stringify({ type: "text", text: textDelta })}\n\n`)
})
await followUp.finalMessage()
}
enqueue("data: [DONE]\n\n")
} catch (err) {
const message = err instanceof Anthropic.APIError

View file

@ -4,6 +4,28 @@ import { execFile } from "node:child_process"
import { prisma } from "@/lib/prisma"
import type { TraceData } from "./get-trace-data"
/**
 * One optimization candidate extracted from the trace.
 * Candidates are indexed twice in indexTraceData(): by `source` into
 * candidatesBySource and by `id` into candidatesById for O(1) lookup.
 */
interface Candidate {
  /** Unique candidate id — key of candidatesById; ranking lists reference candidates by this id. */
  id: string
  /** Candidate source code. NOTE(review): assumed to be the full code text — confirm against TraceData. */
  code: string
  /** Origin/group label of the candidate — key of candidatesBySource and used for display labels. */
  source: string
  /** Model that produced the candidate, when known (shown as "model: unknown" otherwise). */
  model?: string
  /** Optional human-readable explanation attached to the candidate. */
  explanation?: string
  /** Optional rank position. NOTE(review): relationship to ranking.ranking order not visible here — confirm. */
  rank?: number
  /** True when this candidate was selected as the best one. */
  isBest: boolean
  /** Id of the parent candidate when derived from another candidate; null/absent for roots. */
  parentId?: string | null
}
/**
 * Flattened summary of a single LLM call, as exposed on IndexedTraceData.llmCalls.
 * Built in indexTraceData() by projecting the raw call rows.
 */
interface LlmCallSummary {
  /** Unique id of the call row. */
  id: string
  /** Call category; "test_generation" / "testgen" identify test-generation calls. */
  call_type: string | null
  /** Model used for the call, when recorded. */
  model_name: string | null
  /** Call status string. NOTE(review): value set (e.g. "success"/"error") not visible here — confirm. */
  status: string
  /** Wall-clock latency in milliseconds, when recorded. */
  latency_ms: number | null
  /** Cost of the call; summed (with `?? 0`) into IndexedTraceData.totalCost. */
  llm_cost: number | null
  /** Token count of the call; summed (with `?? 0`) into IndexedTraceData.totalTokens. */
  total_tokens: number | null
  /** Raw call context; for test-generation calls may carry `test_index` used to build testModelMap. */
  context: Record<string, unknown> | null
}
export interface IndexedTraceData {
functionName: string | null
filePath: string | null
@ -12,16 +34,12 @@ export interface IndexedTraceData {
generatedTests: string[]
instrumentedTests: string[]
instrumentedPerfTests: string[]
candidates: Array<{
id: string
code: string
source: string
model?: string
explanation?: string
rank?: number
isBest: boolean
parentId?: string | null
}>
candidates: Candidate[]
candidatesBySource: Map<string, Candidate[]>
candidatesById: Map<string, Candidate>
testModelMap: Map<number, string | null>
totalCost: number
totalTokens: number
ranking: { ranking: string[]; explanation: string } | null
usedForPr: boolean
errors: Array<{
@ -30,16 +48,7 @@ export interface IndexedTraceData {
error_message: string | null
context: Record<string, unknown> | null
}>
llmCalls: Array<{
id: string
call_type: string | null
model_name: string | null
status: string
latency_ms: number | null
llm_cost: number | null
total_tokens: number | null
context: Record<string, unknown> | null
}>
llmCalls: LlmCallSummary[]
}
export function indexTraceData(traceData: TraceData): IndexedTraceData {
@ -92,6 +101,42 @@ export function indexTraceData(traceData: TraceData): IndexedTraceData {
Object.keys(pullRequestRaw as Record<string, unknown>).length > 0,
)
// Pre-build lookup maps for O(1) access in tool calls
const candidatesBySource = new Map<string, typeof candidates>()
const candidatesById = new Map<string, (typeof candidates)[0]>()
for (const c of candidates) {
candidatesById.set(c.id, c)
const group = candidatesBySource.get(c.source)
if (group) group.push(c)
else candidatesBySource.set(c.source, [c])
}
const llmCalls = rawLlmCalls.map((c) => ({
id: c.id,
call_type: c.call_type,
model_name: c.model_name,
status: c.status,
latency_ms: c.latency_ms,
llm_cost: c.llm_cost,
total_tokens: c.total_tokens,
context: c.context as Record<string, unknown> | null,
}))
// Pre-compute test group → model mapping and cost/token totals in single passes
const testModelMap = new Map<number, string | null>()
let totalCost = 0
let totalTokens = 0
for (const c of llmCalls) {
totalCost += c.llm_cost ?? 0
totalTokens += c.total_tokens ?? 0
if (c.call_type === "test_generation" || c.call_type === "testgen") {
const ctx = c.context
if (ctx?.test_index != null) {
testModelMap.set(ctx.test_index as number, c.model_name)
}
}
}
return {
functionName,
filePath,
@ -101,6 +146,11 @@ export function indexTraceData(traceData: TraceData): IndexedTraceData {
instrumentedTests: optimizationFeatures?.instrumented_generated_test ?? [],
instrumentedPerfTests: optimizationFeatures?.instrumented_perf_test ?? [],
candidates,
candidatesBySource,
candidatesById,
testModelMap,
totalCost,
totalTokens,
ranking: rankingData?.ranking
? { ranking: rankingData.ranking, explanation: rankingData.explanation ?? "" }
: null,
@ -111,30 +161,10 @@ export function indexTraceData(traceData: TraceData): IndexedTraceData {
error_message: e.error_message,
context: e.context as Record<string, unknown> | null,
})),
llmCalls: rawLlmCalls.map((c) => ({
id: c.id,
call_type: c.call_type,
model_name: c.model_name,
status: c.status,
latency_ms: c.latency_ms,
llm_cost: c.llm_cost,
total_tokens: c.total_tokens,
context: c.context as Record<string, unknown> | null,
})),
llmCalls,
}
}
/**
 * Look up the model that generated the test group at `testIndex`.
 *
 * Scans llmCalls for the first test-generation call ("test_generation" or
 * "testgen") whose context carries a matching `test_index`, and returns that
 * call's model name, or null when no such call exists.
 */
function findModelForTestGroup(testIndex: number, data: IndexedTraceData): string | null {
  for (const call of data.llmCalls) {
    // Only test-generation calls participate in the mapping.
    if (call.call_type !== "test_generation" && call.call_type !== "testgen") continue
    const ctx = call.context
    // Calls without a context or without a recorded test_index cannot match.
    if (!ctx || ctx.test_index == null) continue
    if (ctx.test_index === testIndex) return call.model_name ?? null
  }
  return null
}
// --- Codebase browsing helpers ---
function getRepoRoot(repo: string): string | null {
@ -241,15 +271,13 @@ export function buildSummaryPrompt(data: IndexedTraceData): string {
if (data.functionName) lines.push(`Function: ${data.functionName}`)
if (data.filePath) lines.push(`File: ${data.filePath}`)
const totalCost = data.llmCalls.reduce((s, c) => s + (c.llm_cost ?? 0), 0)
const totalTokens = data.llmCalls.reduce((s, c) => s + (c.total_tokens ?? 0), 0)
lines.push(`LLM calls: ${data.llmCalls.length} (total cost: $${totalCost.toFixed(4)}, tokens: ${totalTokens})`)
lines.push(`LLM calls: ${data.llmCalls.length} (total cost: $${data.totalCost.toFixed(4)}, tokens: ${data.totalTokens})`)
if (data.testFramework) lines.push(`Test framework: ${data.testFramework}`)
if (data.generatedTests.length > 0) {
lines.push(`Generated test groups: ${data.generatedTests.length}`)
for (let i = 0; i < data.generatedTests.length; i++) {
const model = findModelForTestGroup(i, data)
const model = data.testModelMap.get(i)
lines.push(` - Test group ${i + 1}: model=${model ?? "unknown"}`)
}
} else {
@ -263,7 +291,7 @@ export function buildSummaryPrompt(data: IndexedTraceData): string {
]
for (const { source, label, prefix } of candidateGroups) {
const candidates = data.candidates.filter((c) => c.source === source)
const candidates = data.candidatesBySource.get(source) ?? []
if (candidates.length > 0) {
lines.push(`${label}: ${candidates.length}`)
for (let i = 0; i < candidates.length; i++) {
@ -276,7 +304,8 @@ export function buildSummaryPrompt(data: IndexedTraceData): string {
}
if (data.ranking) {
lines.push(`Ranking: ${data.ranking.ranking.length} candidates ranked. Best: ${data.candidates.find((c) => c.isBest)?.source ?? "unknown"}`)
const bestCandidate = data.candidatesById.get(data.ranking.ranking[0] ?? "")
lines.push(`Ranking: ${data.ranking.ranking.length} candidates ranked. Best: ${bestCandidate?.source ?? "unknown"}`)
lines.push(`Used for PR: ${data.usedForPr ? "Yes" : "No"}`)
}
@ -511,7 +540,7 @@ export async function resolveToolCall(
case "get_candidate_code": {
const sourceType = args.source_type as string
const index = args.index as number
const filtered = data.candidates.filter((c) => c.source === sourceType)
const filtered = data.candidatesBySource.get(sourceType) ?? []
const candidate = filtered[index - 1]
if (!candidate) {
return `No ${sourceType} candidate at index ${index}. Available: ${filtered.length}`
@ -534,7 +563,7 @@ export async function resolveToolCall(
if (!gen && !instr && !instrPerf) {
return `No test at index ${testIndex + 1}. Available: ${data.generatedTests.length}`
}
const model = findModelForTestGroup(testIndex, data)
const model = data.testModelMap.get(testIndex)
parts.push(`Model: ${model ?? "unknown"}`)
if (gen) parts.push(`Generated test:\n\`\`\`python\n${gen}\n\`\`\``)
if (instr) parts.push(`Instrumented test:\n\`\`\`python\n${instr}\n\`\`\``)
@ -547,7 +576,7 @@ export async function resolveToolCall(
const parts = [`Explanation: ${data.ranking.explanation}`, "", "Ordered ranking:"]
for (let i = 0; i < data.ranking.ranking.length; i++) {
const id = data.ranking.ranking[i]
const cand = data.candidates.find((c) => c.id === id)
const cand = data.candidatesById.get(id)
const label = cand ? `${cand.source} (model: ${cand.model ?? "unknown"})` : id
parts.push(` #${i + 1}: ${label}${i === 0 ? " [BEST]" : ""}`)
}