perf: unify agent loop and pre-build lookup maps for O(1) tool calls
Eliminate redundant API call by extracting text from the loop's final response directly instead of making a separate streaming call. Pre-build candidatesBySource, candidatesById, and testModelMap in indexTraceData() to replace repeated O(n) linear searches in tool calls and prompt building. Combine cost/token aggregation into a single pass.
This commit is contained in:
parent
b09262ccbc
commit
fbcc283e97
2 changed files with 96 additions and 81 deletions
|
|
@ -180,16 +180,25 @@ export async function POST(request: NextRequest): Promise<Response> {
|
|||
const keepalive = setInterval(() => enqueue(": keepalive\n\n"), KEEPALIVE_INTERVAL_MS)
|
||||
|
||||
try {
|
||||
// Tool resolution loop — uses stream() + finalMessage() to avoid the SDK's
|
||||
// non-streaming timeout limit (max_tokens 32k estimates >10min, which blocks
|
||||
// client.messages.create()). No text handler is attached, so no text reaches
|
||||
// the frontend — avoiding "stutter" from intermediate fragments.
|
||||
// Unified agent loop — each iteration either processes tool calls or
|
||||
// extracts the final text. Uses stream()+finalMessage() to avoid the
|
||||
// SDK's non-streaming timeout (max_tokens 32k estimates >10min).
|
||||
// No separate "final streaming call" — when the loop gets a non-tool
|
||||
// response, it extracts text directly, saving an extra API round-trip.
|
||||
let toolRounds = 0
|
||||
while (toolRounds < MAX_TOOL_ROUNDS) {
|
||||
const toolStream = client.messages.stream(baseParams(systemPrompt, conversationMessages))
|
||||
const response = await toolStream.finalMessage()
|
||||
while (toolRounds <= MAX_TOOL_ROUNDS) {
|
||||
const messageStream = client.messages.stream(baseParams(systemPrompt, conversationMessages))
|
||||
const response = await messageStream.finalMessage()
|
||||
|
||||
if (response.stop_reason !== "tool_use") break
|
||||
if (response.stop_reason !== "tool_use") {
|
||||
// Final response — extract text blocks and send to client
|
||||
for (const block of response.content) {
|
||||
if (block.type === "text") {
|
||||
enqueue(`data: ${JSON.stringify({ type: "text", text: block.text })}\n\n`)
|
||||
}
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
conversationMessages.push({ role: "assistant", content: response.content })
|
||||
const toolResults = await processToolCalls(response.content, indexed, enqueue)
|
||||
|
|
@ -197,29 +206,6 @@ export async function POST(request: NextRequest): Promise<Response> {
|
|||
toolRounds++
|
||||
}
|
||||
|
||||
// Stream the final response — this is the only text the user sees
|
||||
const messageStream = client.messages.stream(baseParams(systemPrompt, conversationMessages))
|
||||
|
||||
messageStream.on("text", (textDelta) => {
|
||||
enqueue(`data: ${JSON.stringify({ type: "text", text: textDelta })}\n\n`)
|
||||
})
|
||||
|
||||
const finalMessage = await messageStream.finalMessage()
|
||||
|
||||
// Edge case: final streaming response ended with tool_use. Process tools
|
||||
// and make one more streaming call without tools to get a text response.
|
||||
if (finalMessage.stop_reason === "tool_use") {
|
||||
conversationMessages.push({ role: "assistant", content: finalMessage.content })
|
||||
const toolResults = await processToolCalls(finalMessage.content, indexed, enqueue)
|
||||
conversationMessages.push({ role: "user", content: toolResults })
|
||||
|
||||
const followUp = client.messages.stream(baseParams(systemPrompt, conversationMessages))
|
||||
followUp.on("text", (textDelta) => {
|
||||
enqueue(`data: ${JSON.stringify({ type: "text", text: textDelta })}\n\n`)
|
||||
})
|
||||
await followUp.finalMessage()
|
||||
}
|
||||
|
||||
enqueue("data: [DONE]\n\n")
|
||||
} catch (err) {
|
||||
const message = err instanceof Anthropic.APIError
|
||||
|
|
|
|||
|
|
@ -4,6 +4,28 @@ import { execFile } from "node:child_process"
|
|||
import { prisma } from "@/lib/prisma"
|
||||
import type { TraceData } from "./get-trace-data"
|
||||
|
||||
/**
 * One optimization candidate as stored in the indexed trace data.
 *
 * The same object instances are shared between the flat `candidates` list
 * and the `candidatesBySource` / `candidatesById` lookup maps.
 */
interface Candidate {
  /** Unique identifier; key of `candidatesById` and the value matched against ranking entries. */
  id: string
  /** Candidate source code. */
  code: string
  /** Origin of the candidate; grouping key of `candidatesBySource`. */
  source: string
  /** Model name shown alongside the candidate; reported as "unknown" when absent. */
  model?: string
  /** Optional explanation text attached to the candidate. */
  explanation?: string
  /** Optional rank value. NOTE(review): ordering convention is not visible here — confirm. */
  rank?: number
  /** Whether this candidate is flagged as the best one. */
  isBest: boolean
  /** Identifier of a parent candidate, if any. NOTE(review): semantics not visible here — confirm. */
  parentId?: string | null
}
|
||||
|
||||
/**
 * Trimmed-down view of one recorded LLM call, holding only the fields that
 * the summary prompt and tool calls read.
 */
interface LlmCallSummary {
  /** Identifier of the call record. */
  id: string
  /** Kind of call; "test_generation" and "testgen" both denote test-generation calls. */
  call_type: string | null
  /** Name of the model that served the call, when recorded. */
  model_name: string | null
  /** Status string of the call. */
  status: string
  /** Latency in milliseconds, when recorded. */
  latency_ms: number | null
  /** Cost of the call; treated as 0 in totals when null. */
  llm_cost: number | null
  /** Token count of the call; treated as 0 in totals when null. */
  total_tokens: number | null
  /** Free-form call context; test-generation calls may carry a `test_index` entry. */
  context: Record<string, unknown> | null
}
|
||||
|
||||
export interface IndexedTraceData {
|
||||
functionName: string | null
|
||||
filePath: string | null
|
||||
|
|
@ -12,16 +34,12 @@ export interface IndexedTraceData {
|
|||
generatedTests: string[]
|
||||
instrumentedTests: string[]
|
||||
instrumentedPerfTests: string[]
|
||||
candidates: Array<{
|
||||
id: string
|
||||
code: string
|
||||
source: string
|
||||
model?: string
|
||||
explanation?: string
|
||||
rank?: number
|
||||
isBest: boolean
|
||||
parentId?: string | null
|
||||
}>
|
||||
candidates: Candidate[]
|
||||
candidatesBySource: Map<string, Candidate[]>
|
||||
candidatesById: Map<string, Candidate>
|
||||
testModelMap: Map<number, string | null>
|
||||
totalCost: number
|
||||
totalTokens: number
|
||||
ranking: { ranking: string[]; explanation: string } | null
|
||||
usedForPr: boolean
|
||||
errors: Array<{
|
||||
|
|
@ -30,16 +48,7 @@ export interface IndexedTraceData {
|
|||
error_message: string | null
|
||||
context: Record<string, unknown> | null
|
||||
}>
|
||||
llmCalls: Array<{
|
||||
id: string
|
||||
call_type: string | null
|
||||
model_name: string | null
|
||||
status: string
|
||||
latency_ms: number | null
|
||||
llm_cost: number | null
|
||||
total_tokens: number | null
|
||||
context: Record<string, unknown> | null
|
||||
}>
|
||||
llmCalls: LlmCallSummary[]
|
||||
}
|
||||
|
||||
export function indexTraceData(traceData: TraceData): IndexedTraceData {
|
||||
|
|
@ -92,6 +101,42 @@ export function indexTraceData(traceData: TraceData): IndexedTraceData {
|
|||
Object.keys(pullRequestRaw as Record<string, unknown>).length > 0,
|
||||
)
|
||||
|
||||
// Pre-build lookup maps for O(1) access in tool calls
|
||||
const candidatesBySource = new Map<string, typeof candidates>()
|
||||
const candidatesById = new Map<string, (typeof candidates)[0]>()
|
||||
for (const c of candidates) {
|
||||
candidatesById.set(c.id, c)
|
||||
const group = candidatesBySource.get(c.source)
|
||||
if (group) group.push(c)
|
||||
else candidatesBySource.set(c.source, [c])
|
||||
}
|
||||
|
||||
const llmCalls = rawLlmCalls.map((c) => ({
|
||||
id: c.id,
|
||||
call_type: c.call_type,
|
||||
model_name: c.model_name,
|
||||
status: c.status,
|
||||
latency_ms: c.latency_ms,
|
||||
llm_cost: c.llm_cost,
|
||||
total_tokens: c.total_tokens,
|
||||
context: c.context as Record<string, unknown> | null,
|
||||
}))
|
||||
|
||||
// Pre-compute the test group → model mapping and the cost/token totals in a
// single pass over the LLM calls.
//
// For each test index the FIRST matching test-generation call wins. This
// mirrors the Array.prototype.find() lookup the map replaces, so results do
// not change when several calls (e.g. retries) share a test index.
const testModelMap = new Map<number, string | null>()
let totalCost = 0
let totalTokens = 0
for (const call of llmCalls) {
  // Missing cost/token values contribute nothing to the totals.
  totalCost += call.llm_cost ?? 0
  totalTokens += call.total_tokens ?? 0

  if (call.call_type !== "test_generation" && call.call_type !== "testgen") continue

  const testIndex = call.context?.test_index
  // Lookups use numeric array positions, so only numeric indices are usable
  // keys; anything else could never be matched and is skipped.
  if (typeof testIndex === "number" && !testModelMap.has(testIndex)) {
    testModelMap.set(testIndex, call.model_name)
  }
}
|
||||
|
||||
return {
|
||||
functionName,
|
||||
filePath,
|
||||
|
|
@ -101,6 +146,11 @@ export function indexTraceData(traceData: TraceData): IndexedTraceData {
|
|||
instrumentedTests: optimizationFeatures?.instrumented_generated_test ?? [],
|
||||
instrumentedPerfTests: optimizationFeatures?.instrumented_perf_test ?? [],
|
||||
candidates,
|
||||
candidatesBySource,
|
||||
candidatesById,
|
||||
testModelMap,
|
||||
totalCost,
|
||||
totalTokens,
|
||||
ranking: rankingData?.ranking
|
||||
? { ranking: rankingData.ranking, explanation: rankingData.explanation ?? "" }
|
||||
: null,
|
||||
|
|
@ -111,30 +161,10 @@ export function indexTraceData(traceData: TraceData): IndexedTraceData {
|
|||
error_message: e.error_message,
|
||||
context: e.context as Record<string, unknown> | null,
|
||||
})),
|
||||
llmCalls: rawLlmCalls.map((c) => ({
|
||||
id: c.id,
|
||||
call_type: c.call_type,
|
||||
model_name: c.model_name,
|
||||
status: c.status,
|
||||
latency_ms: c.latency_ms,
|
||||
llm_cost: c.llm_cost,
|
||||
total_tokens: c.total_tokens,
|
||||
context: c.context as Record<string, unknown> | null,
|
||||
})),
|
||||
llmCalls,
|
||||
}
|
||||
}
|
||||
|
||||
/**
 * Finds the model that generated a given test group.
 *
 * Scans `data.llmCalls` in order and picks the first call whose `call_type`
 * is "test_generation" or "testgen" and whose `context.test_index` strictly
 * equals `testIndex`.
 *
 * @param testIndex - Zero-based position of the test group.
 * @param data - Indexed trace data whose `llmCalls` are searched.
 * @returns The matching call's `model_name`, or `null` when no call matches
 *   or the matching call has no model name recorded.
 */
function findModelForTestGroup(testIndex: number, data: IndexedTraceData): string | null {
  const match = data.llmCalls.find(
    (call) =>
      (call.call_type === "test_generation" || call.call_type === "testgen") &&
      // A null context or a missing/null test_index yields a non-number here,
      // which can never strictly equal the numeric testIndex.
      call.context?.test_index === testIndex,
  )
  return match?.model_name ?? null
}
|
||||
|
||||
// --- Codebase browsing helpers ---
|
||||
|
||||
function getRepoRoot(repo: string): string | null {
|
||||
|
|
@ -241,15 +271,13 @@ export function buildSummaryPrompt(data: IndexedTraceData): string {
|
|||
if (data.functionName) lines.push(`Function: ${data.functionName}`)
|
||||
if (data.filePath) lines.push(`File: ${data.filePath}`)
|
||||
|
||||
const totalCost = data.llmCalls.reduce((s, c) => s + (c.llm_cost ?? 0), 0)
|
||||
const totalTokens = data.llmCalls.reduce((s, c) => s + (c.total_tokens ?? 0), 0)
|
||||
lines.push(`LLM calls: ${data.llmCalls.length} (total cost: $${totalCost.toFixed(4)}, tokens: ${totalTokens})`)
|
||||
lines.push(`LLM calls: ${data.llmCalls.length} (total cost: $${data.totalCost.toFixed(4)}, tokens: ${data.totalTokens})`)
|
||||
|
||||
if (data.testFramework) lines.push(`Test framework: ${data.testFramework}`)
|
||||
if (data.generatedTests.length > 0) {
|
||||
lines.push(`Generated test groups: ${data.generatedTests.length}`)
|
||||
for (let i = 0; i < data.generatedTests.length; i++) {
|
||||
const model = findModelForTestGroup(i, data)
|
||||
const model = data.testModelMap.get(i)
|
||||
lines.push(` - Test group ${i + 1}: model=${model ?? "unknown"}`)
|
||||
}
|
||||
} else {
|
||||
|
|
@ -263,7 +291,7 @@ export function buildSummaryPrompt(data: IndexedTraceData): string {
|
|||
]
|
||||
|
||||
for (const { source, label, prefix } of candidateGroups) {
|
||||
const candidates = data.candidates.filter((c) => c.source === source)
|
||||
const candidates = data.candidatesBySource.get(source) ?? []
|
||||
if (candidates.length > 0) {
|
||||
lines.push(`${label}: ${candidates.length}`)
|
||||
for (let i = 0; i < candidates.length; i++) {
|
||||
|
|
@ -276,7 +304,8 @@ export function buildSummaryPrompt(data: IndexedTraceData): string {
|
|||
}
|
||||
|
||||
if (data.ranking) {
|
||||
lines.push(`Ranking: ${data.ranking.ranking.length} candidates ranked. Best: ${data.candidates.find((c) => c.isBest)?.source ?? "unknown"}`)
|
||||
const bestCandidate = data.candidatesById.get(data.ranking.ranking[0] ?? "")
|
||||
lines.push(`Ranking: ${data.ranking.ranking.length} candidates ranked. Best: ${bestCandidate?.source ?? "unknown"}`)
|
||||
lines.push(`Used for PR: ${data.usedForPr ? "Yes" : "No"}`)
|
||||
}
|
||||
|
||||
|
|
@ -511,7 +540,7 @@ export async function resolveToolCall(
|
|||
case "get_candidate_code": {
|
||||
const sourceType = args.source_type as string
|
||||
const index = args.index as number
|
||||
const filtered = data.candidates.filter((c) => c.source === sourceType)
|
||||
const filtered = data.candidatesBySource.get(sourceType) ?? []
|
||||
const candidate = filtered[index - 1]
|
||||
if (!candidate) {
|
||||
return `No ${sourceType} candidate at index ${index}. Available: ${filtered.length}`
|
||||
|
|
@ -534,7 +563,7 @@ export async function resolveToolCall(
|
|||
if (!gen && !instr && !instrPerf) {
|
||||
return `No test at index ${testIndex + 1}. Available: ${data.generatedTests.length}`
|
||||
}
|
||||
const model = findModelForTestGroup(testIndex, data)
|
||||
const model = data.testModelMap.get(testIndex)
|
||||
parts.push(`Model: ${model ?? "unknown"}`)
|
||||
if (gen) parts.push(`Generated test:\n\`\`\`python\n${gen}\n\`\`\``)
|
||||
if (instr) parts.push(`Instrumented test:\n\`\`\`python\n${instr}\n\`\`\``)
|
||||
|
|
@ -547,7 +576,7 @@ export async function resolveToolCall(
|
|||
const parts = [`Explanation: ${data.ranking.explanation}`, "", "Ordered ranking:"]
|
||||
for (let i = 0; i < data.ranking.ranking.length; i++) {
|
||||
const id = data.ranking.ranking[i]
|
||||
const cand = data.candidates.find((c) => c.id === id)
|
||||
const cand = data.candidatesById.get(id)
|
||||
const label = cand ? `${cand.source} (model: ${cand.model ?? "unknown"})` : id
|
||||
parts.push(` #${i + 1}: ${label}${i === 0 ? " [BEST]" : ""}`)
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in a new issue