codeflash-agent/packages/codeflash-api/tests/test_auth_deps.py
Kevin Turcios d20b82762a Add auth layer: key hashing, rate limiting, usage tracking
SHA-384 + base64url key hashing matching the JS client. FastAPI
dependencies for require_auth, check_rate_limit, and track_usage
with Annotated[Depends()] pattern. Per-user per-endpoint rate
limiting with employee bypass. Atomic subscription usage tracking
with enterprise org and employee exemptions. DB queries module
with asyncpg raw SQL for auth tables. 27 new tests covering auth
flow, rate limits, usage enforcement, and edge cases.
2026-04-21 21:33:02 -05:00

343 lines
9 KiB
Python

from __future__ import annotations

from collections.abc import AsyncIterator
from datetime import datetime, timezone
from typing import Annotated
from unittest.mock import AsyncMock

import httpx
import pytest
from fastapi import Depends, FastAPI

from codeflash_api.auth._deps import (
    ENDPOINT_TOKEN_COST,
    check_rate_limit,
    require_auth,
    track_usage,
)
from codeflash_api.auth._keys import hash_api_key
from codeflash_api.auth.models import (
    APIKey,
    Organization,
    Subscription,
)
def _make_api_key(
    *,
    user_id: str = "github|12345",
    tier: str | None = None,
    organization_id: str | None = None,
) -> APIKey:
    """
    Build an APIKey record for use as a mocked query result.

    The ``key`` field holds ``hash_api_key("test-key")`` and
    ``created_at`` is the current UTC time; ``last_used`` is left
    unset. Only the owner, tier and organization can be overridden.
    """
    hashed_key = hash_api_key("test-key")
    created = datetime.now(tz=timezone.utc)
    return APIKey(
        id=1,
        key=hashed_key,
        suffix="tkey",
        name="Test Key",
        user_id=user_id,
        tier=tier,
        organization_id=organization_id,
        created_at=created,
        last_used=None,
    )
def _make_subscription(
    *,
    user_id: str = "github|12345",
    plan_type: str = "FREE",
    used: int = 0,
    limit: int = 4000,
    status: str = "active",
) -> Subscription:
    """
    Build a Subscription record for use as a mocked query result.

    ``used``, ``limit`` and ``status`` are shorthand for the model's
    ``optimizations_used``, ``optimizations_limit`` and
    ``subscription_status`` fields. The record id is always "sub-1".
    """
    fields = {
        "id": "sub-1",
        "user_id": user_id,
        "plan_type": plan_type,
        "optimizations_used": used,
        "optimizations_limit": limit,
        "subscription_status": status,
    }
    return Subscription(**fields)
def _build_app(
    mock_queries: AsyncMock,
) -> FastAPI:
    """
    Assemble a minimal FastAPI app that exercises the auth dependencies.

    The app exposes an unauthenticated ``GET /healthcheck`` and a
    ``POST /ai/optimize`` route guarded by ``require_auth``,
    ``check_rate_limit`` and ``track_usage``. ``mock_queries`` is stored
    on ``app.state.queries`` and an empty dict on
    ``app.state.rate_limit_cache`` (presumably where the dependencies
    look them up -- confirm against ``codeflash_api.auth._deps``).
    """
    api = FastAPI()

    @api.get("/healthcheck")
    async def healthcheck() -> dict[str, str]:
        return {"status": "ok"}

    @api.post("/ai/optimize")
    async def optimize(
        user: Annotated[object, Depends(require_auth)],
        _rate: Annotated[None, Depends(check_rate_limit)],
        _usage: Annotated[None, Depends(track_usage)],
    ) -> dict[str, str]:
        # The dependencies do the real work; the body is a fixed reply.
        return {"result": "ok"}

    # Shared state is attached once the routes are registered; nothing
    # reads it until a request is served.
    api.state.queries = mock_queries
    api.state.rate_limit_cache = {}
    return api
@pytest.fixture(name="mock_queries")
def _mock_queries() -> AsyncMock:
    """
    Provide an AsyncMock standing in for the DB queries object.

    Defaults: a known API key, an active FREE subscription, no
    organization, and an ``increment_usage`` that returns ``None``.
    Individual tests override these return values as needed.
    """
    mock = AsyncMock()
    mock.configure_mock(
        **{
            "get_api_key_by_hash.return_value": _make_api_key(),
            "get_subscription.return_value": _make_subscription(),
            "get_organization.return_value": None,
            "increment_usage.return_value": None,
        }
    )
    return mock
@pytest.fixture(name="app")
def _app(mock_queries: AsyncMock) -> FastAPI:
    """
    Provide the test FastAPI app backed by the ``mock_queries`` fixture.
    """
    application = _build_app(mock_queries)
    return application
@pytest.fixture(name="client")
async def _client(app: FastAPI) -> AsyncIterator[httpx.AsyncClient]:
    """
    Yield an httpx AsyncClient that sends requests to *app* in-process.

    Requests go through ``httpx.ASGITransport`` with the base URL
    ``http://test``. The client is closed when the ``async with``
    block exits at fixture teardown.

    The return annotation is ``AsyncIterator`` because this is an async
    generator: it yields the client rather than returning it.
    """
    transport = httpx.ASGITransport(app=app)
    async with httpx.AsyncClient(
        transport=transport, base_url="http://test"
    ) as c:
        yield c
class TestRequireAuth:
    """Tests for the require_auth dependency."""

    async def test_valid_key(self, client: httpx.AsyncClient) -> None:
        """
        A request carrying the known Bearer token is accepted.
        """
        auth_headers = {"Authorization": "Bearer test-key"}
        response = await client.post("/ai/optimize", headers=auth_headers)
        assert 200 == response.status_code

    async def test_missing_header(self, client: httpx.AsyncClient) -> None:
        """
        A request with no Authorization header is rejected with 403.
        """
        response = await client.post("/ai/optimize")
        assert 403 == response.status_code
        detail = response.json()["detail"]
        assert "Invalid API key" in detail

    async def test_invalid_key(
        self,
        client: httpx.AsyncClient,
        mock_queries: AsyncMock,
    ) -> None:
        """
        A key with no matching record is rejected with 403.
        """
        # Simulate the lookup finding nothing for the presented key.
        mock_queries.get_api_key_by_hash.return_value = None
        auth_headers = {"Authorization": "Bearer bad-key"}
        response = await client.post("/ai/optimize", headers=auth_headers)
        assert 403 == response.status_code

    async def test_raw_token_without_bearer(
        self, client: httpx.AsyncClient
    ) -> None:
        """
        A bare token (no "Bearer " prefix) is accepted as well.
        """
        auth_headers = {"Authorization": "test-key"}
        response = await client.post("/ai/optimize", headers=auth_headers)
        assert 200 == response.status_code
class TestRateLimit:
    """Tests for the check_rate_limit dependency."""

    async def test_under_limit(self, client: httpx.AsyncClient) -> None:
        """
        A first request, with an empty rate-limit cache, goes through.
        """
        auth_headers = {"Authorization": "Bearer test-key"}
        response = await client.post("/ai/optimize", headers=auth_headers)
        assert 200 == response.status_code

    async def test_exceeds_limit(
        self,
        app: FastAPI,
        client: httpx.AsyncClient,
    ) -> None:
        """
        A request made once the counter is at the max gets a 429.
        """
        # Pre-seed the counter for the default user and endpoint.
        # NOTE(review): 40 is presumably the configured maximum --
        # confirm against check_rate_limit.
        app.state.rate_limit_cache[
            "ratelimit:user:github|12345:/ai/optimize"
        ] = 40
        auth_headers = {"Authorization": "Bearer test-key"}
        response = await client.post("/ai/optimize", headers=auth_headers)
        assert 429 == response.status_code

    async def test_employee_bypass(
        self,
        app: FastAPI,
        client: httpx.AsyncClient,
        mock_queries: AsyncMock,
    ) -> None:
        """
        A Codeflash employee is let through despite a huge counter.
        """
        employee_id = "github|1271289"
        mock_queries.get_api_key_by_hash.return_value = _make_api_key(
            user_id=employee_id
        )
        mock_queries.get_subscription.return_value = _make_subscription(
            user_id=employee_id
        )
        # The counter is far above the value that earns other users a 429.
        app.state.rate_limit_cache[
            "ratelimit:user:github|1271289:/ai/optimize"
        ] = 999
        auth_headers = {"Authorization": "Bearer test-key"}
        response = await client.post("/ai/optimize", headers=auth_headers)
        assert 200 == response.status_code
class TestTrackUsage:
    """Tests for the track_usage dependency."""

    async def test_deducts_cost(
        self,
        client: httpx.AsyncClient,
        mock_queries: AsyncMock,
    ) -> None:
        """
        Successful request increments usage by endpoint cost.
        """
        resp = await client.post(
            "/ai/optimize",
            headers={"Authorization": "Bearer test-key"},
        )
        # Check the status first so a rejected request is reported as
        # such rather than as a missing increment_usage call.
        assert 200 == resp.status_code
        mock_queries.increment_usage.assert_awaited_once_with(
            user_id="github|12345",
            cost=ENDPOINT_TOKEN_COST["optimize"],
        )

    async def test_limit_exceeded(
        self,
        client: httpx.AsyncClient,
        mock_queries: AsyncMock,
    ) -> None:
        """
        Usage over the limit returns 403.
        """
        # NOTE(review): 3995 of 4000 used presumably leaves less than
        # ENDPOINT_TOKEN_COST["optimize"] -- confirm the cost exceeds 5.
        mock_queries.get_subscription.return_value = _make_subscription(
            used=3995, limit=4000
        )
        resp = await client.post(
            "/ai/optimize",
            headers={"Authorization": "Bearer test-key"},
        )
        assert 403 == resp.status_code
        assert "Usage limit exceeded" in str(resp.json()["detail"])

    async def test_inactive_subscription(
        self,
        client: httpx.AsyncClient,
        mock_queries: AsyncMock,
    ) -> None:
        """
        Inactive subscription returns 403.
        """
        mock_queries.get_subscription.return_value = _make_subscription(
            status="canceled"
        )
        resp = await client.post(
            "/ai/optimize",
            headers={"Authorization": "Bearer test-key"},
        )
        assert 403 == resp.status_code
        assert "not active" in str(resp.json()["detail"])

    async def test_enterprise_org_bypasses_usage(
        self,
        client: httpx.AsyncClient,
        mock_queries: AsyncMock,
    ) -> None:
        """
        Enterprise org users bypass usage tracking.
        """
        # The key belongs to an organization whose record is returned
        # with subscription=True.
        mock_queries.get_api_key_by_hash.return_value = _make_api_key(
            organization_id="org-1"
        )
        mock_queries.get_organization.return_value = Organization(
            id="org-1",
            name="big-corp",
            subscription=True,
        )
        resp = await client.post(
            "/ai/optimize",
            headers={"Authorization": "Bearer test-key"},
        )
        assert 200 == resp.status_code
        mock_queries.increment_usage.assert_not_awaited()

    async def test_employee_bypasses_usage(
        self,
        client: httpx.AsyncClient,
        mock_queries: AsyncMock,
    ) -> None:
        """
        Codeflash employees bypass usage tracking.
        """
        mock_queries.get_api_key_by_hash.return_value = _make_api_key(
            user_id="github|1271289"
        )
        resp = await client.post(
            "/ai/optimize",
            headers={"Authorization": "Bearer test-key"},
        )
        assert 200 == resp.status_code
        mock_queries.increment_usage.assert_not_awaited()

    async def test_creates_backup_subscription(
        self,
        client: httpx.AsyncClient,
        mock_queries: AsyncMock,
    ) -> None:
        """
        Missing subscription creates a FREE backup.
        """
        # No subscription on record; the creation call hands back a
        # default (FREE, active) one.
        mock_queries.get_subscription.return_value = None
        mock_queries.create_subscription.return_value = _make_subscription()
        resp = await client.post(
            "/ai/optimize",
            headers={"Authorization": "Bearer test-key"},
        )
        assert 200 == resp.status_code
        mock_queries.create_subscription.assert_awaited_once()