fix: add defense-in-depth SQL interpolation guards to dashboard queries

The raw SQL in dashboard/action.ts interpolated caller-supplied values
(userId, username, repoIds, eventTypeFilter, repositoryId, year,
pageSize) with, at most, single-quote escaping. While these values come
from authenticated sessions and database lookups, there was no format
validation before interpolation.

Add typed validation functions that throw early if a value doesn't
match its expected format:
- sqlUuid(): validates UUID format (hex + hyphens only)
- sqlUserId(): validates Auth0 ID format (provider|id; letters, digits, "_" and "-" only)
- sqlUsername(): validates GitHub username format (alphanumeric + hyphens)
- sqlEventType(): validates against allowlist of known event types
- Math.trunc() for all numeric interpolations (year, pageSize, offset)

Applied consistently across buildBaseEventsCte(), statistics(), and
getOptimizationPRs() — every string entering SQL interpolation now
passes through its format-specific validator.
This commit is contained in:
Kevin Turcios 2026-04-11 04:01:31 -05:00
parent 26307af804
commit 817e588425

View file

@ -9,6 +9,49 @@ import {
} from "@codeflash-ai/common"
import { dedup } from "@/lib/request-dedup"
// ── SQL interpolation guards ──────────────────────────────────────────
// These throw early if a value doesn't match the expected format,
// preventing malformed SQL from reaching the database even if an
// upstream caller passes unexpected data.
// 8-4-4-4-12 groups of hexadecimal digits separated by hyphens; the `i` flag
// accepts both upper- and lower-case letters.
const UUID_RE = /^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$/i

/**
 * Checks that a string is a well-formed UUID before it is interpolated into SQL.
 *
 * @param value - candidate UUID string
 * @returns `value` unchanged; a string matching `UUID_RE` consists solely of
 *   hex digits and hyphens, so no quote escaping is applied
 * @throws Error if `value` does not match `UUID_RE`
 */
function sqlUuid(value: string): string {
  if (UUID_RE.test(value)) {
    return value
  }
  throw new Error(`Invalid UUID for SQL interpolation: ${value}`)
}
/**
 * Checks an Auth0-style user ID (e.g. "github|12345678") before it is
 * interpolated into SQL.
 *
 * The accepted shape is two non-empty segments joined by exactly one "|",
 * where each segment contains only ASCII letters, digits, "_" or "-".
 *
 * @param value - candidate user ID
 * @returns the validated ID with any single quotes doubled
 * @throws Error if `value` does not have the accepted shape
 */
function sqlUserId(value: string): string {
  const wellFormed = /^[a-zA-Z0-9_-]+\|[a-zA-Z0-9_-]+$/.test(value)
  if (!wellFormed) {
    throw new Error(`Invalid user ID for SQL interpolation: ${value}`)
  }
  // The pattern above already excludes quotes; doubling them regardless keeps
  // the returned text safe inside a SQL string literal if the pattern is ever relaxed.
  return value.split("'").join("''")
}
/**
 * Checks a GitHub username before it is interpolated into SQL.
 *
 * Accepted names are 1-39 characters long and made of ASCII alphanumeric
 * segments separated by single hyphens: no leading hyphen, no trailing
 * hyphen, and no consecutive hyphens.
 *
 * @param value - candidate username
 * @returns the validated username with any single quotes doubled
 * @throws Error if `value` is empty, longer than 39 characters, or does not
 *   match the accepted shape
 */
function sqlUsername(value: string): string {
  // Length is checked first so oversized input is rejected without running the regex.
  // The pattern matches hyphen-separated alphanumeric runs, which rules out "--"
  // as well as leading and trailing hyphens.
  const wellFormed = value.length <= 39 && /^[a-zA-Z0-9]+(?:-[a-zA-Z0-9]+)*$/.test(value)
  if (!wellFormed) {
    throw new Error(`Invalid username for SQL interpolation: ${value}`)
  }
  // The pattern already excludes quotes; the escape is kept as a second line of defense.
  return value.replace(/'/g, "''")
}
// Allowlist of event_type filter values. "all" is accepted here, although the
// visible call sites handle "all" themselves before calling sqlEventType().
const VALID_EVENT_TYPES = new Set(["pr_created", "pr_merged", "pr_closed", "no-pr", "all"])

/**
 * Checks an event_type filter value against the allowlist before it is
 * interpolated into SQL.
 *
 * @param value - candidate event type
 * @returns `value` unchanged when it is a member of `VALID_EVENT_TYPES`
 * @throws Error if `value` is not in the allowlist (matching is exact and case-sensitive)
 */
function sqlEventType(value: string): string {
  if (VALID_EVENT_TYPES.has(value)) {
    return value
  }
  throw new Error(`Invalid event type for SQL interpolation: ${value}`)
}
export interface RepositoryWithUsage {
id: string
github_repo_id: string
@ -84,8 +127,9 @@ export async function getAllRepositories(
* each hits its own composite index independently.
*/
function buildBaseEventsCte(payload: AccountPayload, repoIds: string[], year?: number): string {
const repoIdsString = repoIds.map(id => `'${id}'`).join(",")
const yearCondition = year ? `AND EXTRACT(YEAR FROM created_at) = ${year}` : ""
const repoIdsString = repoIds.map(id => `'${sqlUuid(id)}'`).join(",")
const safeYear = year != null ? Math.trunc(year) : undefined
const yearCondition = safeYear ? `AND EXTRACT(YEAR FROM created_at) = ${safeYear}` : ""
const selectCols = `
created_at,
@ -104,8 +148,8 @@ function buildBaseEventsCte(payload: AccountPayload, repoIds: string[], year?: n
// Personal account: UNION three index-backed scans, then deduplicate.
// Each branch can seek on its leading index column.
const userId = payload.userId.replace(/'/g, "''")
const username = payload.username.replace(/'/g, "''")
const userId = sqlUserId(payload.userId)
const username = sqlUsername(payload.username)
return `base_events AS (
SELECT ${selectCols}
@ -124,6 +168,7 @@ function buildBaseEventsCte(payload: AccountPayload, repoIds: string[], year?: n
export async function statistics(payload: AccountPayload, year: number) {
try {
const safeYear = Math.trunc(year) // ensure integer for SQL interpolation
const { repoIds } = await getRepositoriesForAccountCached(payload)
if (repoIds.length === 0) {
@ -136,7 +181,7 @@ export async function statistics(payload: AccountPayload, year: number) {
}
const since = new Date(Date.now() - 30 * 24 * 60 * 60 * 1000)
const baseEventsCte = buildBaseEventsCte(payload, repoIds, year)
const baseEventsCte = buildBaseEventsCte(payload, repoIds, safeYear)
const sinceFormatted = since.toISOString()
const result = await prisma.$queryRawUnsafe<
@ -164,7 +209,7 @@ export async function statistics(payload: AccountPayload, year: number) {
event_type,
DATE(created_at) as event_date,
created_at >= '${sinceFormatted}'::timestamp as is_recent,
EXTRACT(YEAR FROM created_at)::int = ${year} as is_target_year,
EXTRACT(YEAR FROM created_at)::int = ${safeYear} as is_target_year,
EXTRACT(MONTH FROM created_at)::int as event_month
FROM base_events
),
@ -428,15 +473,16 @@ export async function getOptimizationPRs(
}
}
// Build WHERE conditions with parameterized queries
const repoIdsString = repoIds.map(id => `'${id.replace(/'/g, "''")}'`).join(",")
// Build WHERE conditions — validate all interpolated values
const repoIdsString = repoIds.map(id => `'${sqlUuid(id)}'`).join(",")
const safeRepoId = repositoryId ? sqlUuid(repositoryId) : undefined
let accountCondition: string
if ("orgId" in payload) {
accountCondition = `oe.repository_id IN (${repoIdsString})`
} else {
const userId = payload.userId.replace(/'/g, "''")
const username = payload.username.replace(/'/g, "''")
const userId = sqlUserId(payload.userId)
const username = sqlUsername(payload.username)
accountCondition = `(
oe.repository_id IN (${repoIdsString})
OR oe.user_id = '${userId}'
@ -446,12 +492,10 @@ export async function getOptimizationPRs(
const eventTypeCondition =
eventTypeFilter && eventTypeFilter !== "all"
? `AND oe.event_type = '${String(eventTypeFilter).replace(/'/g, "''")}'`
? `AND oe.event_type = '${sqlEventType(eventTypeFilter)}'`
: `AND oe.event_type IN ('pr_created','pr_merged','pr_closed')`
const repositoryCondition = repositoryId
? `AND oe.repository_id = '${String(repositoryId).replace(/'/g, "''")}'`
: ""
const repositoryCondition = safeRepoId ? `AND oe.repository_id = '${safeRepoId}'` : ""
// Separate WHERE clauses: the count query uses EXISTS to avoid joining the
// large optimization_features table when oe.pr_url already satisfies the
@ -487,7 +531,8 @@ export async function getOptimizationPRs(
)
`
const offset = (page - 1) * pageSize
const safePageSize = Math.trunc(pageSize)
const offset = Math.trunc((page - 1) * safePageSize)
// Build count query — for personal accounts, rewrite 3-way OR as UNION
// so each branch uses its optimal composite index independently instead
@ -501,15 +546,13 @@ export async function getOptimizationPRs(
WHERE ${countWhereClause}
`
} else {
const uid = payload.userId.replace(/'/g, "''")
const uname = payload.username.replace(/'/g, "''")
const uid = sqlUserId(payload.userId)
const uname = sqlUsername(payload.username)
const eventFilter =
eventTypeFilter && eventTypeFilter !== "all"
? `event_type = '${String(eventTypeFilter).replace(/'/g, "''")}'`
? `event_type = '${sqlEventType(eventTypeFilter)}'`
: `event_type IN ('pr_created','pr_merged','pr_closed')`
const repoFilter = repositoryId
? `AND repository_id = '${String(repositoryId).replace(/'/g, "''")}'`
: ""
const repoFilter = safeRepoId ? `AND repository_id = '${safeRepoId}'` : ""
const branchFilters = `AND ${eventFilter} AND is_optimization_found = true ${repoFilter}`
countSql = `
@ -601,20 +644,18 @@ export async function getOptimizationPRs(
LEFT JOIN repositories r ON oe.repository_id = r.id
WHERE ${dataWhereClause}
ORDER BY oe.created_at DESC
LIMIT ${pageSize} OFFSET ${offset}
LIMIT ${safePageSize} OFFSET ${offset}
`
} else {
// Personal: CTE with UNION to identify candidate event IDs via index
// scans, then JOIN for the data fields (only for the LIMIT'd set).
const uid = payload.userId.replace(/'/g, "''")
const uname = payload.username.replace(/'/g, "''")
const uid = sqlUserId(payload.userId)
const uname = sqlUsername(payload.username)
const eventFilter =
eventTypeFilter && eventTypeFilter !== "all"
? `event_type = '${String(eventTypeFilter).replace(/'/g, "''")}'`
? `event_type = '${sqlEventType(eventTypeFilter)}'`
: `event_type IN ('pr_created','pr_merged','pr_closed')`
const repoFilter = repositoryId
? `AND repository_id = '${String(repositoryId).replace(/'/g, "''")}'`
: ""
const repoFilter = safeRepoId ? `AND repository_id = '${safeRepoId}'` : ""
const branchFilters = `AND ${eventFilter} AND is_optimization_found = true ${repoFilter}`
dataSql = `
@ -635,7 +676,7 @@ export async function getOptimizationPRs(
LEFT JOIN repositories r ON oe.repository_id = r.id
WHERE (oe.pr_url IS NOT NULL OR of.pull_request IS NOT NULL)
ORDER BY oe.created_at DESC
LIMIT ${pageSize} OFFSET ${offset}
LIMIT ${safePageSize} OFFSET ${offset}
`
}
@ -661,7 +702,7 @@ export async function getOptimizationPRs(
])
const totalCount = Number(countRows?.[0]?.count ?? 0)
const totalPages = Math.ceil(totalCount / pageSize)
const totalPages = Math.ceil(totalCount / safePageSize)
return {
events: events.map(e => ({