fix: add defense-in-depth SQL interpolation guards to dashboard queries
All raw SQL in dashboard/action.ts interpolated user-controlled values (userId, username, repoIds, eventTypeFilter, repositoryId, year, pageSize) with at most single-quote escaping. While these values come from authenticated sessions and database lookups, there was no format validation before interpolation.

Add typed validation functions that throw early if a value doesn't match its expected format:

- sqlUuid(): validates UUID format (hex + hyphens only)
- sqlUserId(): validates Auth0 ID format (provider|numeric_id)
- sqlUsername(): validates GitHub username format (alphanumeric + hyphens)
- sqlEventType(): validates against allowlist of known event types
- Math.trunc() for all numeric interpolations (year, pageSize, offset)

Applied consistently across buildBaseEventsCte(), statistics(), and getOptimizationPRs() — every string entering SQL interpolation now passes through its format-specific validator.
This commit is contained in:
parent
26307af804
commit
817e588425
1 changed file with 71 additions and 30 deletions
|
|
@ -9,6 +9,49 @@ import {
|
|||
} from "@codeflash-ai/common"
|
||||
import { dedup } from "@/lib/request-dedup"
|
||||
|
||||
// ── SQL interpolation guards ──────────────────────────────────────────
|
||||
// These throw early if a value doesn't match the expected format,
|
||||
// preventing malformed SQL from reaching the database even if an
|
||||
// upstream caller passes unexpected data.
|
||||
|
||||
/** Canonical 8-4-4-4-12 hexadecimal UUID, matched case-insensitively. */
const UUID_RE = /^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$/i

/**
 * Validate that `value` is a UUID and return it unchanged, ready to be placed
 * between single quotes in a SQL string.
 *
 * No escaping is applied (or needed): a string that passes the check can only
 * contain hex digits and hyphens.
 *
 * @param value - Candidate UUID; upper- and lower-case hex are both accepted.
 * @returns The same string that was passed in.
 * @throws Error if `value` is not a string or does not match the UUID pattern.
 */
function sqlUuid(value: string): string {
  // The `string` annotation is erased at runtime. Without the typeof check a
  // one-element array such as ["<uuid>"] would be coerced to a string by
  // RegExp.prototype.test() and slip through the guard.
  if (typeof value !== "string" || !UUID_RE.test(value)) {
    throw new Error(`Invalid UUID for SQL interpolation: ${String(value)}`)
  }
  return value
}
|
||||
|
||||
/**
 * Validate an Auth0-style user ID and return it ready to be placed between
 * single quotes in a SQL string.
 *
 * Accepted shape: `<provider>|<id>` with exactly one `|`, where each side is a
 * non-empty run of ASCII letters, digits, `_` or `-` (e.g. "github|12345678").
 * The id part is not required to be numeric. IDs containing more than one `|`
 * or any other character are rejected.
 *
 * @param value - Candidate user ID.
 * @returns The validated ID with single quotes doubled.
 * @throws Error if `value` is not a string or does not match the shape above.
 */
function sqlUserId(value: string): string {
  // The `string` annotation is erased at runtime; reject non-strings here so
  // they fail with the validation error rather than a TypeError from .replace().
  if (typeof value !== "string" || !/^[a-zA-Z0-9_-]+\|[a-zA-Z0-9_-]+$/.test(value)) {
    throw new Error(`Invalid user ID for SQL interpolation: ${String(value)}`)
  }
  // A validated ID cannot contain a quote, so this is a no-op today. It is kept
  // as a second line of defence in case the pattern above is ever loosened.
  return value.replace(/'/g, "''")
}
|
||||
|
||||
/**
 * Validate a GitHub-style username and return it ready to be placed between
 * single quotes in a SQL string.
 *
 * Accepted shape: 1-39 characters, ASCII letters, digits and hyphens only,
 * and neither the first nor the last character may be a hyphen.
 * NOTE: consecutive hyphens ("a--b") are accepted by this check.
 *
 * @param value - Candidate username.
 * @returns The validated username with single quotes doubled.
 * @throws Error if `value` is not a string or does not match the shape above.
 */
function sqlUsername(value: string): string {
  // typeof guard: the `string` annotation is erased at runtime.
  // The length check runs first so the regex only ever sees short input.
  if (
    typeof value !== "string" ||
    value.length > 39 ||
    !/^[a-zA-Z0-9](?:[a-zA-Z0-9-]*[a-zA-Z0-9])?$/.test(value)
  ) {
    throw new Error(`Invalid username for SQL interpolation: ${String(value)}`)
  }
  // A validated username cannot contain a quote, so this is a no-op today. It
  // is kept as a second line of defence in case the pattern is ever loosened.
  return value.replace(/'/g, "''")
}
|
||||
|
||||
/**
 * Allowlist of values accepted by sqlEventType().
 *
 * NOTE(review): "all" is included, yet the call sites visible in this change
 * treat "all" as "no filter" and never pass it to sqlEventType() — confirm
 * whether it needs to stay in the allowlist.
 */
const VALID_EVENT_TYPES: ReadonlySet<string> = new Set(["pr_created", "pr_merged", "pr_closed", "no-pr", "all"])

/**
 * Validate an event_type filter value against the allowlist and return it
 * unchanged, ready to be placed between single quotes in a SQL string.
 *
 * No escaping is applied (or needed): only the fixed literals in
 * VALID_EVENT_TYPES can pass, and none of them contains a quote.
 *
 * @param value - Candidate event type; compared case-sensitively.
 * @returns The same string that was passed in.
 * @throws Error if `value` is not one of the allowlisted event types.
 */
function sqlEventType(value: string): string {
  if (!VALID_EVENT_TYPES.has(value)) {
    throw new Error(`Invalid event type for SQL interpolation: ${value}`)
  }
  return value
}
|
||||
|
||||
export interface RepositoryWithUsage {
|
||||
id: string
|
||||
github_repo_id: string
|
||||
|
|
@ -84,8 +127,9 @@ export async function getAllRepositories(
|
|||
* each hits its own composite index independently.
|
||||
*/
|
||||
function buildBaseEventsCte(payload: AccountPayload, repoIds: string[], year?: number): string {
|
||||
const repoIdsString = repoIds.map(id => `'${id}'`).join(",")
|
||||
const yearCondition = year ? `AND EXTRACT(YEAR FROM created_at) = ${year}` : ""
|
||||
const repoIdsString = repoIds.map(id => `'${sqlUuid(id)}'`).join(",")
|
||||
const safeYear = year != null ? Math.trunc(year) : undefined
|
||||
const yearCondition = safeYear ? `AND EXTRACT(YEAR FROM created_at) = ${safeYear}` : ""
|
||||
|
||||
const selectCols = `
|
||||
created_at,
|
||||
|
|
@ -104,8 +148,8 @@ function buildBaseEventsCte(payload: AccountPayload, repoIds: string[], year?: n
|
|||
|
||||
// Personal account: UNION three index-backed scans, then deduplicate.
|
||||
// Each branch can seek on its leading index column.
|
||||
const userId = payload.userId.replace(/'/g, "''")
|
||||
const username = payload.username.replace(/'/g, "''")
|
||||
const userId = sqlUserId(payload.userId)
|
||||
const username = sqlUsername(payload.username)
|
||||
|
||||
return `base_events AS (
|
||||
SELECT ${selectCols}
|
||||
|
|
@ -124,6 +168,7 @@ function buildBaseEventsCte(payload: AccountPayload, repoIds: string[], year?: n
|
|||
|
||||
export async function statistics(payload: AccountPayload, year: number) {
|
||||
try {
|
||||
const safeYear = Math.trunc(year) // ensure integer for SQL interpolation
|
||||
const { repoIds } = await getRepositoriesForAccountCached(payload)
|
||||
|
||||
if (repoIds.length === 0) {
|
||||
|
|
@ -136,7 +181,7 @@ export async function statistics(payload: AccountPayload, year: number) {
|
|||
}
|
||||
|
||||
const since = new Date(Date.now() - 30 * 24 * 60 * 60 * 1000)
|
||||
const baseEventsCte = buildBaseEventsCte(payload, repoIds, year)
|
||||
const baseEventsCte = buildBaseEventsCte(payload, repoIds, safeYear)
|
||||
|
||||
const sinceFormatted = since.toISOString()
|
||||
const result = await prisma.$queryRawUnsafe<
|
||||
|
|
@ -164,7 +209,7 @@ export async function statistics(payload: AccountPayload, year: number) {
|
|||
event_type,
|
||||
DATE(created_at) as event_date,
|
||||
created_at >= '${sinceFormatted}'::timestamp as is_recent,
|
||||
EXTRACT(YEAR FROM created_at)::int = ${year} as is_target_year,
|
||||
EXTRACT(YEAR FROM created_at)::int = ${safeYear} as is_target_year,
|
||||
EXTRACT(MONTH FROM created_at)::int as event_month
|
||||
FROM base_events
|
||||
),
|
||||
|
|
@ -428,15 +473,16 @@ export async function getOptimizationPRs(
|
|||
}
|
||||
}
|
||||
|
||||
// Build WHERE conditions with parameterized queries
|
||||
const repoIdsString = repoIds.map(id => `'${id.replace(/'/g, "''")}'`).join(",")
|
||||
// Build WHERE conditions — validate all interpolated values
|
||||
const repoIdsString = repoIds.map(id => `'${sqlUuid(id)}'`).join(",")
|
||||
const safeRepoId = repositoryId ? sqlUuid(repositoryId) : undefined
|
||||
|
||||
let accountCondition: string
|
||||
if ("orgId" in payload) {
|
||||
accountCondition = `oe.repository_id IN (${repoIdsString})`
|
||||
} else {
|
||||
const userId = payload.userId.replace(/'/g, "''")
|
||||
const username = payload.username.replace(/'/g, "''")
|
||||
const userId = sqlUserId(payload.userId)
|
||||
const username = sqlUsername(payload.username)
|
||||
accountCondition = `(
|
||||
oe.repository_id IN (${repoIdsString})
|
||||
OR oe.user_id = '${userId}'
|
||||
|
|
@ -446,12 +492,10 @@ export async function getOptimizationPRs(
|
|||
|
||||
const eventTypeCondition =
|
||||
eventTypeFilter && eventTypeFilter !== "all"
|
||||
? `AND oe.event_type = '${String(eventTypeFilter).replace(/'/g, "''")}'`
|
||||
? `AND oe.event_type = '${sqlEventType(eventTypeFilter)}'`
|
||||
: `AND oe.event_type IN ('pr_created','pr_merged','pr_closed')`
|
||||
|
||||
const repositoryCondition = repositoryId
|
||||
? `AND oe.repository_id = '${String(repositoryId).replace(/'/g, "''")}'`
|
||||
: ""
|
||||
const repositoryCondition = safeRepoId ? `AND oe.repository_id = '${safeRepoId}'` : ""
|
||||
|
||||
// Separate WHERE clauses: the count query uses EXISTS to avoid joining the
|
||||
// large optimization_features table when oe.pr_url already satisfies the
|
||||
|
|
@ -487,7 +531,8 @@ export async function getOptimizationPRs(
|
|||
)
|
||||
`
|
||||
|
||||
const offset = (page - 1) * pageSize
|
||||
const safePageSize = Math.trunc(pageSize)
|
||||
const offset = Math.trunc((page - 1) * safePageSize)
|
||||
|
||||
// Build count query — for personal accounts, rewrite 3-way OR as UNION
|
||||
// so each branch uses its optimal composite index independently instead
|
||||
|
|
@ -501,15 +546,13 @@ export async function getOptimizationPRs(
|
|||
WHERE ${countWhereClause}
|
||||
`
|
||||
} else {
|
||||
const uid = payload.userId.replace(/'/g, "''")
|
||||
const uname = payload.username.replace(/'/g, "''")
|
||||
const uid = sqlUserId(payload.userId)
|
||||
const uname = sqlUsername(payload.username)
|
||||
const eventFilter =
|
||||
eventTypeFilter && eventTypeFilter !== "all"
|
||||
? `event_type = '${String(eventTypeFilter).replace(/'/g, "''")}'`
|
||||
? `event_type = '${sqlEventType(eventTypeFilter)}'`
|
||||
: `event_type IN ('pr_created','pr_merged','pr_closed')`
|
||||
const repoFilter = repositoryId
|
||||
? `AND repository_id = '${String(repositoryId).replace(/'/g, "''")}'`
|
||||
: ""
|
||||
const repoFilter = safeRepoId ? `AND repository_id = '${safeRepoId}'` : ""
|
||||
const branchFilters = `AND ${eventFilter} AND is_optimization_found = true ${repoFilter}`
|
||||
|
||||
countSql = `
|
||||
|
|
@ -601,20 +644,18 @@ export async function getOptimizationPRs(
|
|||
LEFT JOIN repositories r ON oe.repository_id = r.id
|
||||
WHERE ${dataWhereClause}
|
||||
ORDER BY oe.created_at DESC
|
||||
LIMIT ${pageSize} OFFSET ${offset}
|
||||
LIMIT ${safePageSize} OFFSET ${offset}
|
||||
`
|
||||
} else {
|
||||
// Personal: CTE with UNION to identify candidate event IDs via index
|
||||
// scans, then JOIN for the data fields (only for the LIMIT'd set).
|
||||
const uid = payload.userId.replace(/'/g, "''")
|
||||
const uname = payload.username.replace(/'/g, "''")
|
||||
const uid = sqlUserId(payload.userId)
|
||||
const uname = sqlUsername(payload.username)
|
||||
const eventFilter =
|
||||
eventTypeFilter && eventTypeFilter !== "all"
|
||||
? `event_type = '${String(eventTypeFilter).replace(/'/g, "''")}'`
|
||||
? `event_type = '${sqlEventType(eventTypeFilter)}'`
|
||||
: `event_type IN ('pr_created','pr_merged','pr_closed')`
|
||||
const repoFilter = repositoryId
|
||||
? `AND repository_id = '${String(repositoryId).replace(/'/g, "''")}'`
|
||||
: ""
|
||||
const repoFilter = safeRepoId ? `AND repository_id = '${safeRepoId}'` : ""
|
||||
const branchFilters = `AND ${eventFilter} AND is_optimization_found = true ${repoFilter}`
|
||||
|
||||
dataSql = `
|
||||
|
|
@ -635,7 +676,7 @@ export async function getOptimizationPRs(
|
|||
LEFT JOIN repositories r ON oe.repository_id = r.id
|
||||
WHERE (oe.pr_url IS NOT NULL OR of.pull_request IS NOT NULL)
|
||||
ORDER BY oe.created_at DESC
|
||||
LIMIT ${pageSize} OFFSET ${offset}
|
||||
LIMIT ${safePageSize} OFFSET ${offset}
|
||||
`
|
||||
}
|
||||
|
||||
|
|
@ -661,7 +702,7 @@ export async function getOptimizationPRs(
|
|||
])
|
||||
|
||||
const totalCount = Number(countRows?.[0]?.count ?? 0)
|
||||
const totalPages = Math.ceil(totalCount / pageSize)
|
||||
const totalPages = Math.ceil(totalCount / safePageSize)
|
||||
|
||||
return {
|
||||
events: events.map(e => ({
|
||||
|
|
|
|||
Loading…
Reference in a new issue