# Change summary:
# SHA-384 + base64url API-key hashing that matches the JS client.
# FastAPI dependencies require_auth, check_rate_limit, and track_usage, using
# the Annotated[..., Depends(...)] pattern.
# Per-user, per-endpoint rate limiting with an employee bypass.
# Atomic subscription usage tracking with enterprise-org and employee
# exemptions.
# A DB queries module using raw asyncpg SQL for the auth tables.
# 27 new tests covering the auth flow, rate limits, usage enforcement, and
# edge cases.
#
# Listing metadata: 343 lines, 9 KiB, Python.

from __future__ import annotations

from collections.abc import AsyncIterator
from datetime import datetime, timezone
from typing import Annotated
from unittest.mock import AsyncMock

import httpx
import pytest
from fastapi import Depends, FastAPI

from codeflash_api.auth._deps import (
    ENDPOINT_TOKEN_COST,
    check_rate_limit,
    require_auth,
    track_usage,
)
from codeflash_api.auth._keys import hash_api_key
from codeflash_api.auth.models import (
    APIKey,
    Organization,
    Subscription,
)


def _make_api_key(
    *,
    user_id: str = "github|12345",
    tier: str | None = None,
    organization_id: str | None = None,
) -> APIKey:
    """Return an ``APIKey`` record to serve as a mocked DB lookup result.

    The stored ``key`` is the hash of the raw token ``test-key``, which is
    the token the tests send in their ``Authorization`` headers.

    Args:
        user_id: Owner of the key. The default is the ordinary (non-employee)
            user that the rate-limit and usage tests rely on.
        tier: Value for the record's ``tier`` field.
        organization_id: Value for the record's ``organization_id`` field.

    Returns:
        A new ``APIKey`` whose ``created_at`` is the current UTC time and
        whose ``last_used`` is unset.
    """
    record_fields = {
        "id": 1,
        "key": hash_api_key("test-key"),
        "suffix": "tkey",
        "name": "Test Key",
        "user_id": user_id,
        "tier": tier,
        "organization_id": organization_id,
        "created_at": datetime.now(tz=timezone.utc),
        "last_used": None,
    }
    return APIKey(**record_fields)


def _make_subscription(
    *,
    user_id: str = "github|12345",
    plan_type: str = "FREE",
    used: int = 0,
    limit: int = 4000,
    status: str = "active",
) -> Subscription:
    """Return a ``Subscription`` record to serve as a mocked DB lookup result.

    Args:
        user_id: Owner of the subscription.
        plan_type: Plan name stored on the record.
        used: Stored as ``optimizations_used``.
        limit: Stored as ``optimizations_limit``.
        status: Stored as ``subscription_status``. The tests use
            ``"canceled"`` to model an inactive subscription.

    Returns:
        A ``Subscription`` with the fixed id ``"sub-1"``.
    """
    return Subscription(
        subscription_status=status,
        optimizations_limit=limit,
        optimizations_used=used,
        plan_type=plan_type,
        user_id=user_id,
        id="sub-1",
    )


def _build_app(
    mock_queries: AsyncMock,
) -> FastAPI:
    """Create a minimal FastAPI app wired to the auth dependencies under test.

    ``app.state.queries`` is set to ``mock_queries`` and
    ``app.state.rate_limit_cache`` starts as an empty dict.

    Routes:
        GET  /healthcheck  -- declares no dependencies.
        POST /ai/optimize  -- declares ``require_auth``, ``check_rate_limit``
                              and ``track_usage`` as dependencies.
    """
    app = FastAPI()
    app.state.rate_limit_cache = {}
    app.state.queries = mock_queries

    async def healthcheck() -> dict[str, str]:
        return {"status": "ok"}

    async def optimize(
        user: Annotated[object, Depends(require_auth)],
        _rate: Annotated[None, Depends(check_rate_limit)],
        _usage: Annotated[None, Depends(track_usage)],
    ) -> dict[str, str]:
        return {"result": "ok"}

    # Equivalent to decorating each handler; registration order is unchanged.
    app.get("/healthcheck")(healthcheck)
    app.post("/ai/optimize")(optimize)
    return app


@pytest.fixture(name="mock_queries")
def _mock_queries() -> AsyncMock:
    """Provide an ``AsyncMock`` standing in for the auth DB queries object.

    Default behavior:
    - the key lookup returns the standard test key;
    - the subscription lookup returns an active FREE subscription with zero
      usage;
    - the organization lookup returns ``None``;
    - ``increment_usage`` resolves to ``None``.

    Individual tests override these ``return_value`` attributes to model
    other scenarios.
    """
    defaults = {
        "get_api_key_by_hash.return_value": _make_api_key(),
        "get_subscription.return_value": _make_subscription(),
        "get_organization.return_value": None,
        "increment_usage.return_value": None,
    }
    queries = AsyncMock()
    queries.configure_mock(**defaults)
    return queries


@pytest.fixture(name="app")
def _app(mock_queries: AsyncMock) -> FastAPI:
    """Provide a fresh test app bound to this test's ``mock_queries`` mock."""
    test_app = _build_app(mock_queries)
    return test_app


@pytest.fixture(name="client")
async def _client(app: FastAPI) -> AsyncIterator[httpx.AsyncClient]:
    """Yield an HTTP client that calls ``app`` in-process through ASGI.

    ``httpx.ASGITransport`` dispatches requests directly into the app, so no
    server or network is involved. ``base_url`` only supplies the host part
    used to build absolute URLs. The client is closed during fixture
    teardown, when the ``async with`` block exits after the test.

    Yields:
        An ``httpx.AsyncClient`` bound to ``app``.
    """
    transport = httpx.ASGITransport(app=app)
    async with httpx.AsyncClient(
        transport=transport, base_url="http://test"
    ) as client:
        yield client


class TestRequireAuth:
    """Exercise ``require_auth`` through the ``POST /ai/optimize`` route."""

    @staticmethod
    async def _post_optimize(
        client: httpx.AsyncClient,
        authorization: str | None = None,
    ) -> httpx.Response:
        """POST to ``/ai/optimize``, sending ``Authorization`` only if given."""
        if authorization is None:
            return await client.post("/ai/optimize")
        return await client.post(
            "/ai/optimize",
            headers={"Authorization": authorization},
        )

    async def test_valid_key(self, client: httpx.AsyncClient) -> None:
        """
        A known key sent as a Bearer token is accepted.
        """
        response = await self._post_optimize(client, "Bearer test-key")

        assert 200 == response.status_code

    async def test_missing_header(self, client: httpx.AsyncClient) -> None:
        """
        A request with no Authorization header is rejected with 403.
        """
        response = await self._post_optimize(client)

        assert 403 == response.status_code
        assert "Invalid API key" in response.json()["detail"]

    async def test_invalid_key(
        self,
        client: httpx.AsyncClient,
        mock_queries: AsyncMock,
    ) -> None:
        """
        A token whose hash matches no stored key is rejected with 403.
        """
        # Simulate the DB lookup finding no row for the presented key.
        mock_queries.get_api_key_by_hash.return_value = None

        response = await self._post_optimize(client, "Bearer bad-key")

        assert 403 == response.status_code

    async def test_raw_token_without_bearer(
        self, client: httpx.AsyncClient
    ) -> None:
        """
        A bare token without the "Bearer " prefix is still accepted.
        """
        response = await self._post_optimize(client, "test-key")

        assert 200 == response.status_code


class TestRateLimit:
    """Exercise ``check_rate_limit`` by pre-seeding the app's rate-limit cache."""

    @staticmethod
    async def _post_with_test_key(
        client: httpx.AsyncClient,
    ) -> httpx.Response:
        """POST to ``/ai/optimize`` authenticated with the standard test key."""
        return await client.post(
            "/ai/optimize",
            headers={"Authorization": "Bearer test-key"},
        )

    async def test_under_limit(self, client: httpx.AsyncClient) -> None:
        """
        With an empty rate-limit cache the request is allowed.
        """
        response = await self._post_with_test_key(client)

        assert 200 == response.status_code

    async def test_exceeds_limit(
        self,
        app: FastAPI,
        client: httpx.AsyncClient,
    ) -> None:
        """
        With the user's counter for this endpoint pre-set to 40, the request
        is rejected with 429.
        """
        # Cache keys combine the user id and the endpoint path.
        key = "ratelimit:user:github|12345:/ai/optimize"
        app.state.rate_limit_cache[key] = 40

        response = await self._post_with_test_key(client)

        assert 429 == response.status_code

    async def test_employee_bypass(
        self,
        app: FastAPI,
        client: httpx.AsyncClient,
        mock_queries: AsyncMock,
    ) -> None:
        """
        Codeflash employees are not rate limited, even far past the limit.
        """
        # This GitHub user id is the one these tests treat as an employee.
        employee_id = "github|1271289"
        mock_queries.get_api_key_by_hash.return_value = _make_api_key(
            user_id=employee_id
        )
        mock_queries.get_subscription.return_value = _make_subscription(
            user_id=employee_id
        )
        key = "ratelimit:user:github|1271289:/ai/optimize"
        app.state.rate_limit_cache[key] = 999

        response = await self._post_with_test_key(client)

        assert 200 == response.status_code


class TestTrackUsage:
    """Tests for the track_usage dependency."""

    async def test_deducts_cost(
        self,
        client: httpx.AsyncClient,
        mock_queries: AsyncMock,
    ) -> None:
        """
        Successful request increments usage by endpoint cost.
        """
        resp = await client.post(
            "/ai/optimize",
            headers={"Authorization": "Bearer test-key"},
        )

        # Check that the request actually succeeded first, so a failure
        # elsewhere in the dependency chain is reported directly instead of
        # surfacing only as a mock-call mismatch.
        assert 200 == resp.status_code
        mock_queries.increment_usage.assert_awaited_once_with(
            user_id="github|12345",
            cost=ENDPOINT_TOKEN_COST["optimize"],
        )

    async def test_limit_exceeded(
        self,
        client: httpx.AsyncClient,
        mock_queries: AsyncMock,
    ) -> None:
        """
        A request whose cost would take usage past the limit returns 403.
        """
        limit = 4000
        # Leave one unit less remaining quota than the endpoint costs.
        # Deriving this from ENDPOINT_TOKEN_COST keeps the test on the
        # boundary without depending on the cost's current value.
        used = limit - ENDPOINT_TOKEN_COST["optimize"] + 1
        mock_queries.get_subscription.return_value = _make_subscription(
            used=used, limit=limit
        )

        resp = await client.post(
            "/ai/optimize",
            headers={"Authorization": "Bearer test-key"},
        )

        assert 403 == resp.status_code
        assert "Usage limit exceeded" in str(resp.json()["detail"])

    async def test_inactive_subscription(
        self,
        client: httpx.AsyncClient,
        mock_queries: AsyncMock,
    ) -> None:
        """
        Inactive subscription returns 403.
        """
        mock_queries.get_subscription.return_value = _make_subscription(
            status="canceled"
        )

        resp = await client.post(
            "/ai/optimize",
            headers={"Authorization": "Bearer test-key"},
        )

        assert 403 == resp.status_code
        assert "not active" in str(resp.json()["detail"])

    async def test_enterprise_org_bypasses_usage(
        self,
        client: httpx.AsyncClient,
        mock_queries: AsyncMock,
    ) -> None:
        """
        Enterprise org users bypass usage tracking.
        """
        mock_queries.get_api_key_by_hash.return_value = _make_api_key(
            organization_id="org-1"
        )
        mock_queries.get_organization.return_value = Organization(
            id="org-1",
            name="big-corp",
            subscription=True,
        )

        resp = await client.post(
            "/ai/optimize",
            headers={"Authorization": "Bearer test-key"},
        )

        assert 200 == resp.status_code
        mock_queries.increment_usage.assert_not_awaited()

    async def test_employee_bypasses_usage(
        self,
        client: httpx.AsyncClient,
        mock_queries: AsyncMock,
    ) -> None:
        """
        Codeflash employees bypass usage tracking.
        """
        mock_queries.get_api_key_by_hash.return_value = _make_api_key(
            user_id="github|1271289"
        )

        resp = await client.post(
            "/ai/optimize",
            headers={"Authorization": "Bearer test-key"},
        )

        assert 200 == resp.status_code
        mock_queries.increment_usage.assert_not_awaited()

    async def test_creates_backup_subscription(
        self,
        client: httpx.AsyncClient,
        mock_queries: AsyncMock,
    ) -> None:
        """
        Missing subscription creates a FREE backup.
        """
        mock_queries.get_subscription.return_value = None
        mock_queries.create_subscription.return_value = _make_subscription()

        resp = await client.post(
            "/ai/optimize",
            headers={"Authorization": "Bearer test-key"},
        )

        assert 200 == resp.status_code
        mock_queries.create_subscription.assert_awaited_once()