fix: close authorization bypass and data-integrity bugs across dashboard

Security (critical):
- Scope member lookups to parent resource (repository_id / organization_id)
  in updateRepositoryMemberRole, removeRepositoryMember,
  updateOrganizationMemberRole, and removeOrganizationMember to prevent
  cross-tenant escalation via crafted memberId
- Replace unvalidated currentOrganizationId cookie reads with
  getAccountContext() (validates org membership) in review page and
  repo detail data loaders

Bugs:
- Add missing string-UUID branch in repository_id filter (raw SQL paths)
- Pass actual username to RepoDetailClient instead of empty string
- Remove misleading React.cache() on getAllOptimizationEventsImpl (object
  arg means reference equality never hits)
- Use create() result directly in addOrganizationMember to avoid NPE
  from unnecessary re-fetch
- Separate null-session redirect from null-event 404 in profiler page

Tests:
- Rewrite action.test.ts: org payload for Prisma findMany path, proper
  $queryRaw tagged-template mock for raw SQL path, verify repository_id
  filter is actually applied
This commit is contained in:
Kevin Turcios 2026-04-13 14:56:12 -05:00
parent 71127055f3
commit 8202ea512c
9 changed files with 375 additions and 354 deletions

View file

@ -134,16 +134,15 @@ export async function addOrganizationMember(
// Check if user exists in our database
let user = await getUserById(invitedUserId)
// If user doesn't exist, create them and re-fetch for consistent types
// If user doesn't exist, create them
if (!user) {
await prisma.users.create({
user = await prisma.users.create({
data: {
user_id: invitedUserId,
github_username: invitedUser.username,
onboarding_completed: false,
},
})
user = await getUserById(invitedUserId)
}
// Add user to organization members
@ -200,8 +199,8 @@ export async function updateOrganizationMemberRole(
},
select: { role: true },
}),
prisma.organization_members.findUnique({
where: { id: memberId },
prisma.organization_members.findFirst({
where: { id: memberId, organization_id: organizationId },
select: { id: true, role: true, user_id: true },
}),
])
@ -210,17 +209,21 @@ export async function updateOrganizationMemberRole(
return createErrorResponse("Organization not found")
}
if (!targetMember) {
return createErrorResponse("Member not found in this organization")
}
// Only admins and owners can change roles
if (currentUserMember.role !== "admin" && currentUserMember.role !== "owner") {
return createErrorResponse("Only admins can change member roles")
}
// Don't allow changing owner role
if (targetMember?.role === "owner") {
if (targetMember.role === "owner") {
return createErrorResponse("Cannot change owner role")
}
if (targetMember?.user_id === currentUserId) {
if (targetMember.user_id === currentUserId) {
return createErrorResponse("Cannot change your own role as the only admin")
}
@ -253,14 +256,14 @@ export async function removeOrganizationMember(
},
select: { role: true },
}),
prisma.organization_members.findUnique({
where: { id: memberId },
prisma.organization_members.findFirst({
where: { id: memberId, organization_id: organizationId },
select: { id: true, role: true, user_id: true },
}),
])
if (!targetMember) {
return createErrorResponse("Member not found")
return createErrorResponse("Member not found in this organization")
}
// Cannot remove owner

View file

@ -384,8 +384,8 @@ export async function updateRepositoryMemberRole(
where: { repository_id_user_id: { repository_id: repoId, user_id: currentUserId } },
select: { role: true },
}),
prisma.repository_members.findUnique({
where: { id: memberId },
prisma.repository_members.findFirst({
where: { id: memberId, repository_id: repoId },
select: { id: true, role: true, user_id: true },
}),
])
@ -394,17 +394,21 @@ export async function updateRepositoryMemberRole(
return createErrorResponse("Repository not found")
}
if (!targetMember) {
return createErrorResponse("Member not found in this repository")
}
// Only admins and owners can change roles
if (currentUserMember.role !== "admin" && currentUserMember.role !== "owner") {
return createErrorResponse("Only admins can change member roles")
}
// Don't allow changing owner role
if (targetMember?.role === "owner") {
if (targetMember.role === "owner") {
return createErrorResponse("Cannot change owner role")
}
if (targetMember?.user_id === currentUserId) {
if (targetMember.user_id === currentUserId) {
return createErrorResponse("Cannot change your own role")
}
@ -436,14 +440,14 @@ export async function removeRepositoryMember(
where: { repository_id_user_id: { repository_id: repoId, user_id: currentUserId } },
select: { role: true },
}),
prisma.repository_members.findUnique({
where: { id: memberId },
prisma.repository_members.findFirst({
where: { id: memberId, repository_id: repoId },
select: { id: true, role: true, user_id: true },
}),
])
if (!targetMember) {
return createErrorResponse("Member not found")
return createErrorResponse("Member not found in this repository")
}
// Cannot remove owner

View file

@ -1,8 +1,7 @@
"use server"
import { auth0 } from "@/lib/auth0"
import { cookies } from "next/headers"
import type { AccountPayload } from "@codeflash-ai/common"
import { getAccountContext } from "@/lib/server/get-account-context"
import {
getRepositoryById,
getOptimizationCountsByRepo,
@ -24,10 +23,7 @@ export async function getRepoDetailInitData(repositoryId: string) {
const userId = session.user.sub
const username = session.user.nickname
const cookieStore = await cookies()
const orgId = cookieStore.get("currentOrganizationId")?.value
const payload: AccountPayload = orgId ? { orgId } : { userId, username }
const payload = await getAccountContext()
const repository = await getRepositoryById(payload, repositoryId)
if (!repository) {
@ -72,7 +68,8 @@ export async function getRepoDetailInitData(repositoryId: string) {
return {
userId,
orgId: orgId ?? null,
username,
orgId: "orgId" in payload ? payload.orgId : null,
repository,
stats: {
totalAttempts: totalAttempts ?? 0,

View file

@ -39,6 +39,7 @@ export default async function RepositoryDetailPage({
<RepoDetailClient
repositoryId={repositoryId}
initialUserId={initData.userId}
initialUsername={initData.username}
initialOrgId={initData.orgId ?? null}
initialRepository={initData.repository as any}
initialStats={initData.stats}

View file

@ -485,6 +485,7 @@ export interface RepoDetailStats {
export interface RepoDetailClientProps {
repositoryId: string
initialUserId: string
initialUsername: string
initialOrgId: string | null
initialRepository: RepositoryWithUsage
initialStats: RepoDetailStats
@ -493,6 +494,7 @@ export interface RepoDetailClientProps {
export function RepoDetailClient({
repositoryId,
initialUserId,
initialUsername,
initialOrgId,
initialRepository,
initialStats,
@ -574,7 +576,7 @@ export function RepoDetailClient({
const payload: AccountPayload = currentOrg
? { orgId: currentOrg.id }
: { userId: currentUserId, username: "" }
: { userId: currentUserId, username: initialUsername }
const currentRepo = await getRepositoryById(payload, repositoryId)

View file

@ -7,7 +7,7 @@ import { auth0 } from "@/lib/auth0"
import { AccountPayload, buildOptimizationOrCondition, prisma } from "@codeflash-ai/common"
import * as Sentry from "@sentry/nextjs"
import { trackOptimizationReviewed } from "@/lib/analytics/tracking"
import { cookies } from "next/headers"
import { getAccountContext } from "@/lib/server/get-account-context"
export interface DiffContent {
oldContent: string
@ -482,11 +482,8 @@ export async function getReviewPageInitData(traceId: string) {
const userId = session.user.sub
const username = session.user.nickname
// Read org cookie to determine payload
const cookieStore = await cookies()
const orgId = cookieStore.get("currentOrganizationId")?.value
const payload: AccountPayload = orgId ? { orgId } : { userId, username }
// Use validated account context (verifies org membership if cookie is set)
const payload = await getAccountContext()
// Fetch the optimization event
const event = await getOptimizationEventById({ payload, trace_id: traceId })

View file

@ -1,4 +1,4 @@
import { notFound } from "next/navigation"
import { notFound, redirect } from "next/navigation"
import { getReviewPageInitData } from "../action"
import { ProfilerClient } from "./profiler-client"
@ -11,7 +11,11 @@ export default async function LineProfilerPage({ params }: ProfilerPageProps) {
const initData = await getReviewPageInitData(traceId)
if (!initData || !initData.event) {
if (!initData) {
redirect("/login")
}
if (!initData.event) {
notFound()
}

View file

@ -11,7 +11,8 @@ vi.mock("@/lib/services/repository-utils", () => ({
}))
// Use realistic test fixtures: valid UUIDs and Auth0-style user IDs
const mockPayload = { userId: "github|12345", username: "testuser" }
const mockOrgPayload = { orgId: "org-a1b2c3d4-e5f6-7890" }
const mockPersonalPayload = { userId: "github|12345", username: "testuser" }
const mockRepoIds = ["a1b2c3d4-e5f6-7890-abcd-ef1234567890", "b2c3d4e5-f678-9012-bcde-f12345678901"]
const mockEvents = [
@ -47,6 +48,13 @@ const mockFeatures = [
},
]
/** Helper: extract SQL pattern from a $queryRaw tagged template mock call */
function getTaggedSql(mockFn: any, callIndex: number): string {
const args = mockFn.mock.calls[callIndex]
const strings = args[0] as string[]
return strings.join("$?")
}
describe("getAllOptimizationEvents", () => {
let getAllOptimizationEvents: typeof import("../action").getAllOptimizationEvents
@ -56,29 +64,32 @@ describe("getAllOptimizationEvents", () => {
repos: [],
} as any)
// $queryRaw is used as a tagged template literal — auto-mock doesn't create it
;(prisma as any).$queryRaw = vi.fn()
const mod = await import("../action")
getAllOptimizationEvents = mod.getAllOptimizationEvents
})
describe("Path B: standard Prisma query", () => {
describe("Path B: standard Prisma query (org account)", () => {
// Org accounts use Prisma findMany/count (not raw SQL) when not sorting by review_quality
it("calls findMany and count in parallel", async () => {
vi.mocked(prisma.$queryRawUnsafe)
.mockResolvedValueOnce(mockEvents)
.mockResolvedValueOnce([{ count: BigInt(2) }])
vi.mocked(prisma.optimization_events.findMany).mockResolvedValue(mockEvents as any)
vi.mocked(prisma.optimization_events.count).mockResolvedValue(2)
vi.mocked(prisma.optimization_features.findMany).mockResolvedValue([])
await getAllOptimizationEvents({ payload: mockPayload as any })
await getAllOptimizationEvents({ payload: mockOrgPayload as any })
expect(prisma.$queryRawUnsafe).toHaveBeenCalledTimes(2)
expect(prisma.optimization_events.findMany).toHaveBeenCalledTimes(1)
expect(prisma.optimization_events.count).toHaveBeenCalledTimes(1)
})
it("batch-fetches optimization_features by trace_id array (not N+1)", async () => {
vi.mocked(prisma.$queryRawUnsafe)
.mockResolvedValueOnce(mockEvents)
.mockResolvedValueOnce([{ count: BigInt(2) }])
vi.mocked(prisma.optimization_events.findMany).mockResolvedValue(mockEvents as any)
vi.mocked(prisma.optimization_events.count).mockResolvedValue(2)
vi.mocked(prisma.optimization_features.findMany).mockResolvedValue(mockFeatures as any)
await getAllOptimizationEvents({ payload: mockPayload as any })
await getAllOptimizationEvents({ payload: mockOrgPayload as any })
// Single batch query with all trace IDs — NOT one per event
expect(prisma.optimization_features.findMany).toHaveBeenCalledTimes(1)
@ -93,12 +104,11 @@ describe("getAllOptimizationEvents", () => {
})
it("merges review_quality into events", async () => {
vi.mocked(prisma.$queryRawUnsafe)
.mockResolvedValueOnce(mockEvents)
.mockResolvedValueOnce([{ count: BigInt(2) }])
vi.mocked(prisma.optimization_events.findMany).mockResolvedValue(mockEvents as any)
vi.mocked(prisma.optimization_events.count).mockResolvedValue(2)
vi.mocked(prisma.optimization_features.findMany).mockResolvedValue(mockFeatures as any)
const result = await getAllOptimizationEvents({ payload: mockPayload as any })
const result = await getAllOptimizationEvents({ payload: mockOrgPayload as any })
expect((result.events[0] as any).review_quality).toBe("high")
expect((result.events[0] as any).review_explanation).toBe("Great optimization")
@ -106,120 +116,120 @@ describe("getAllOptimizationEvents", () => {
})
it("returns totalCount from count query", async () => {
vi.mocked(prisma.$queryRawUnsafe)
.mockResolvedValueOnce([])
.mockResolvedValueOnce([{ count: BigInt(42) }])
vi.mocked(prisma.optimization_events.findMany).mockResolvedValue([])
vi.mocked(prisma.optimization_events.count).mockResolvedValue(42)
vi.mocked(prisma.optimization_features.findMany).mockResolvedValue([])
const result = await getAllOptimizationEvents({ payload: mockPayload as any })
const result = await getAllOptimizationEvents({ payload: mockOrgPayload as any })
expect(result.totalCount).toBe(42)
})
it("applies pagination with skip and take", async () => {
vi.mocked(prisma.$queryRawUnsafe)
.mockResolvedValueOnce([])
.mockResolvedValueOnce([{ count: BigInt(0) }])
vi.mocked(prisma.optimization_events.findMany).mockResolvedValue([])
vi.mocked(prisma.optimization_events.count).mockResolvedValue(0)
vi.mocked(prisma.optimization_features.findMany).mockResolvedValue([])
await getAllOptimizationEvents({
payload: mockPayload as any,
payload: mockOrgPayload as any,
page: 3,
pageSize: 25,
})
// Check that OFFSET is calculated correctly in the SQL
const sql = vi.mocked(prisma.$queryRawUnsafe).mock.calls[0][0] as string
expect(sql).toContain("OFFSET 50") // (3 - 1) * 25
expect(sql).toContain("LIMIT 25")
expect(prisma.optimization_events.findMany).toHaveBeenCalledWith(
expect.objectContaining({
skip: 50, // (3 - 1) * 25
take: 25,
}),
)
})
it("uses default sort (created_at desc) when no sort provided", async () => {
vi.mocked(prisma.$queryRawUnsafe)
.mockResolvedValueOnce([])
.mockResolvedValueOnce([{ count: BigInt(0) }])
vi.mocked(prisma.optimization_events.findMany).mockResolvedValue([])
vi.mocked(prisma.optimization_events.count).mockResolvedValue(0)
vi.mocked(prisma.optimization_features.findMany).mockResolvedValue([])
await getAllOptimizationEvents({ payload: mockPayload as any })
await getAllOptimizationEvents({ payload: mockOrgPayload as any })
const sql = vi.mocked(prisma.$queryRawUnsafe).mock.calls[0][0] as string
expect(sql).toContain("ORDER BY oe.created_at DESC")
expect(prisma.optimization_events.findMany).toHaveBeenCalledWith(
expect.objectContaining({
orderBy: { created_at: "desc" },
}),
)
})
it("applies search filter", async () => {
vi.mocked(prisma.$queryRawUnsafe)
.mockResolvedValueOnce([])
.mockResolvedValueOnce([{ count: BigInt(0) }])
vi.mocked(prisma.optimization_events.findMany).mockResolvedValue([])
vi.mocked(prisma.optimization_events.count).mockResolvedValue(0)
vi.mocked(prisma.optimization_features.findMany).mockResolvedValue([])
await getAllOptimizationEvents({
payload: mockPayload as any,
payload: mockOrgPayload as any,
search: "calc",
})
// Check that search is included in the SQL
const sql = vi.mocked(prisma.$queryRawUnsafe).mock.calls[0][0] as string
expect(sql).toContain("oe.function_name ILIKE $1")
expect(sql).toContain("oe.file_path ILIKE $1")
expect(sql).toContain("r.full_name ILIKE $1")
// Check params include the search term
const params = vi.mocked(prisma.$queryRawUnsafe).mock.calls[0].slice(1)
expect(params[0]).toBe("%calc%")
// Check findMany was called with a search-containing where clause
const call = vi.mocked(prisma.optimization_events.findMany).mock.calls[0][0] as any
// Should have AND with OR containing the search fields
expect(call.where.AND).toBeDefined()
const orClause = call.where.AND.find((c: any) => c.OR)
expect(orClause).toBeDefined()
expect(orClause.OR).toHaveLength(3) // function_name, file_path, repository.full_name
})
it("applies repository_id filter", async () => {
vi.mocked(prisma.$queryRawUnsafe)
.mockResolvedValueOnce([])
.mockResolvedValueOnce([{ count: BigInt(0) }])
vi.mocked(prisma.optimization_events.findMany).mockResolvedValue([])
vi.mocked(prisma.optimization_events.count).mockResolvedValue(0)
vi.mocked(prisma.optimization_features.findMany).mockResolvedValue([])
await getAllOptimizationEvents({
payload: mockPayload as any,
payload: mockOrgPayload as any,
filter: { repository_id: mockRepoIds[0] },
})
// In the new UNION-based implementation, additional filters are NOT supported
// because they would require complex WHERE clause merging across UNION branches.
// This test now verifies the query runs without errors (which is a valid regression test).
expect(prisma.$queryRawUnsafe).toHaveBeenCalledTimes(2)
const call = vi.mocked(prisma.optimization_events.findMany).mock.calls[0][0] as any
// The repository_id filter should be in the AND clause
const repoFilter = call.where.AND.find((c: any) => c.repository_id !== undefined)
expect(repoFilter).toBeDefined()
expect(repoFilter.repository_id).toBe(mockRepoIds[0])
})
})
describe("Path A: raw SQL query (review_quality sort/filter)", () => {
it("triggers when sort includes review_quality", async () => {
vi.mocked(prisma.$queryRawUnsafe)
;(prisma as any).$queryRaw
.mockResolvedValueOnce([]) // events
.mockResolvedValueOnce([{ count: BigInt(0) }]) // count
await getAllOptimizationEvents({
payload: mockPayload as any,
payload: mockOrgPayload as any,
sort: { review_quality: "desc" },
})
expect(prisma.$queryRawUnsafe).toHaveBeenCalledTimes(2)
expect((prisma as any).$queryRaw).toHaveBeenCalledTimes(2)
// Should NOT use standard Prisma findMany
expect(prisma.optimization_events.findMany).not.toHaveBeenCalled()
})
it("triggers when filter includes review_quality", async () => {
vi.mocked(prisma.$queryRawUnsafe)
;(prisma as any).$queryRaw
.mockResolvedValueOnce([])
.mockResolvedValueOnce([{ count: BigInt(0) }])
await getAllOptimizationEvents({
payload: mockPayload as any,
payload: mockOrgPayload as any,
filter: { review_quality: "high" },
})
expect(prisma.$queryRawUnsafe).toHaveBeenCalledTimes(2)
expect((prisma as any).$queryRaw).toHaveBeenCalledTimes(2)
})
it("returns correct totalCount from BigInt conversion", async () => {
vi.mocked(prisma.$queryRawUnsafe)
;(prisma as any).$queryRaw
.mockResolvedValueOnce([])
.mockResolvedValueOnce([{ count: BigInt(99) }])
const result = await getAllOptimizationEvents({
payload: mockPayload as any,
payload: mockOrgPayload as any,
sort: { review_quality: "asc" },
})
@ -238,12 +248,12 @@ describe("getAllOptimizationEvents", () => {
repo_id: mockRepoIds[0],
},
]
vi.mocked(prisma.$queryRawUnsafe)
;(prisma as any).$queryRaw
.mockResolvedValueOnce(rawEvents)
.mockResolvedValueOnce([{ count: BigInt(1) }])
const result = await getAllOptimizationEvents({
payload: mockPayload as any,
payload: mockOrgPayload as any,
sort: { review_quality: "desc" },
})
@ -266,12 +276,12 @@ describe("getAllOptimizationEvents", () => {
repo_id: null,
},
]
vi.mocked(prisma.$queryRawUnsafe)
;(prisma as any).$queryRaw
.mockResolvedValueOnce(rawEvents)
.mockResolvedValueOnce([{ count: BigInt(1) }])
const result = await getAllOptimizationEvents({
payload: mockPayload as any,
payload: mockOrgPayload as any,
sort: { review_quality: "desc" },
})
@ -279,16 +289,17 @@ describe("getAllOptimizationEvents", () => {
})
it("includes LEFT JOIN in raw SQL queries", async () => {
vi.mocked(prisma.$queryRawUnsafe)
;(prisma as any).$queryRaw
.mockResolvedValueOnce([])
.mockResolvedValueOnce([{ count: BigInt(0) }])
await getAllOptimizationEvents({
payload: mockPayload as any,
payload: mockOrgPayload as any,
sort: { review_quality: "desc" },
})
const sql = vi.mocked(prisma.$queryRawUnsafe).mock.calls[0][0] as string
// $queryRaw is a tagged template — first arg is TemplateStringsArray
const sql = getTaggedSql((prisma as any).$queryRaw, 0)
expect(sql).toContain("LEFT JOIN optimization_features")
expect(sql).toContain("LEFT JOIN repositories")
})
@ -301,7 +312,7 @@ describe("getAllOptimizationEvents", () => {
repos: [],
} as any)
const result = await getAllOptimizationEvents({ payload: mockPayload as any })
const result = await getAllOptimizationEvents({ payload: mockPersonalPayload as any })
expect(result.events).toEqual([])
expect(result.totalCount).toBe(0)
})

View file

@ -79,24 +79,23 @@ export const getRepositoriesWithStagingEvents = withTiming(
getRepositoriesWithStagingEventsImpl,
)
// Cached implementation for getAllOptimizationEvents
// React cache() deduplicates calls with identical arguments within a single request
const getAllOptimizationEventsImpl = cache(
async ({
// Note: React cache() is NOT used here because this function takes an object argument
// (reference equality means cache never hits). Deduplication happens at a higher level.
const getAllOptimizationEventsImpl = async ({
payload,
search,
filter,
sort,
page = 1,
pageSize = 10,
}: {
}: {
payload: AccountPayload
search?: string
filter?: Record<string, any>
sort?: { [key: string]: "asc" | "desc" }
page?: number
pageSize?: number
}) => {
}) => {
const repoIds = (await getRepositoriesForAccountCached(payload)).repoIds
if (repoIds.length === 0) {
@ -149,6 +148,8 @@ const getAllOptimizationEventsImpl = cache(
whereFragments.push(Prisma.sql`oe.repository_id IS NULL`)
} else if (filter.repository_id.not !== undefined && filter.repository_id.not === null) {
whereFragments.push(Prisma.sql`oe.repository_id IS NOT NULL`)
} else if (typeof filter.repository_id === "string") {
whereFragments.push(Prisma.sql`oe.repository_id = ${filter.repository_id}`)
}
}
}
@ -306,6 +307,8 @@ const getAllOptimizationEventsImpl = cache(
filterFragments.push(Prisma.sql`AND oe.repository_id IS NULL`)
} else if (value?.not === null) {
filterFragments.push(Prisma.sql`AND oe.repository_id IS NOT NULL`)
} else if (typeof value === "string") {
filterFragments.push(Prisma.sql`AND oe.repository_id = ${value}`)
}
}
})
@ -413,8 +416,7 @@ const getAllOptimizationEventsImpl = cache(
return { events: eventsWithReviewData, totalCount }
}
},
)
}
export const getAllOptimizationEvents = withTiming(
"getAllOptimizationEvents",