Refactor rate limit key (#1560)
Adds environment variables for cf-api and Django: **RATE_LIMIT_WINDOW_MS** — the rate-limit window in milliseconds; **RATE_LIMIT_MAX** — the maximum number of requests allowed per window.
This commit is contained in:
parent
88d13a769a
commit
18de2c6889
11 changed files with 86 additions and 55 deletions
|
|
@ -2,3 +2,5 @@
|
|||
DATABASE_URL=
|
||||
SECRET_KEY=
|
||||
POSTHOG_API_KEY=
|
||||
RATE_LIMIT_WINDOW_MS=
|
||||
RATE_LIMIT_MAX=
|
||||
|
|
|
|||
36
django/aiservice/aiservice/middleware/auth_middleware.py
Normal file
36
django/aiservice/aiservice/middleware/auth_middleware.py
Normal file
|
|
@ -0,0 +1,36 @@
|
|||
from django.http import JsonResponse
|
||||
from django.utils.decorators import async_only_middleware
|
||||
from authapp.auth import AuthBearer
|
||||
from asgiref.sync import iscoroutinefunction, markcoroutinefunction
|
||||
|
||||
@async_only_middleware
class AuthMiddleware:
    """Async Django middleware that authenticates requests via a Bearer token.

    Reads the ``Authorization`` header, delegates token validation to
    ``AuthBearer``, and rejects requests with a missing or invalid token
    with a 403 JSON error. Runs before the rate-limit middleware so that
    downstream middleware can rely on an authenticated request.
    """

    def __init__(self, get_response):
        self.get_response = get_response
        self.auth_bearer = AuthBearer()

        # Check if the `get_response` function is a coroutine (async).
        if iscoroutinefunction(self.get_response):
            print("AuthMiddleware is async.")
            # Mark this middleware as async if get_response is async, so
            # Django invokes __call__ on the event loop.
            markcoroutinefunction(self)

    async def __call__(self, request):
        # Extract the token from the Authorization header.
        token = request.headers.get("Authorization")

        if token:
            try:
                # Strip the "Bearer " prefix only when it is actually a
                # prefix. The previous str.replace("Bearer ", "") would also
                # delete the substring anywhere inside the token itself;
                # this mirrors the startswith/[7:] handling used by the
                # rate-limit middleware.
                if token.startswith("Bearer "):
                    token = token[7:]

                # Authenticate the token using the AuthBearer class.
                await self.auth_bearer.authenticate(request, token)
                return await self.get_response(request)
            except Exception:
                # Any authentication failure maps to a generic 403 so the
                # response does not leak why validation failed.
                return JsonResponse({"error": "Invalid API key"}, status=403)

        # If no token is found, return Unauthorized error.
        return JsonResponse({"error": "Invalid API key"}, status=403)
|
||||
|
|
@ -3,6 +3,10 @@ from django.core.cache import cache
|
|||
from asgiref.sync import iscoroutinefunction, markcoroutinefunction
|
||||
from django.utils.decorators import async_only_middleware
|
||||
import sentry_sdk
|
||||
import os
|
||||
|
||||
RATE_LIMIT_WINDOW_MS = int(os.getenv("RATE_LIMIT_WINDOW_MS", "60000"))
|
||||
RATE_LIMIT_MAX = int(os.getenv("RATE_LIMIT_MAX", "40"))
|
||||
|
||||
|
||||
# Note: This rate limiting solution works only with a single server.
|
||||
|
|
@ -20,36 +24,32 @@ class RateLimitMiddleware:
|
|||
restricted_paths = ["/ai/optimize", "/ai/testgen", "/ai/log_features", "/ai/optimize-line-profiler"]
|
||||
|
||||
if any(request.path.startswith(path) for path in restricted_paths):
|
||||
# Get the user's IP address
|
||||
ip = request.META.get("REMOTE_ADDR", "")
|
||||
authorization_header = request.META.get("HTTP_AUTHORIZATION", "")
|
||||
token = None
|
||||
if authorization_header.startswith("Bearer "):
|
||||
token = authorization_header[7:] # Extract the token
|
||||
|
||||
# Use request.user as identifier
|
||||
user_id = getattr(request, "user", None)
|
||||
if not user_id:
|
||||
return JsonResponse({'error': 'Authentication required for rate-limited endpoint'}, status=401)
|
||||
|
||||
path = request.path
|
||||
user_key = f"ratelimit:user:{user_id}:{path}"
|
||||
|
||||
ip_key = f"ratelimit:{ip}:{path}" # Create a unique cache key based on IP and path
|
||||
token_key = f"ratelimit:{token}:{path}" # User-specific rate limit based on token
|
||||
request_count_user = cache.get(user_key, 0)
|
||||
|
||||
request_count_ip = cache.get(ip_key, 0)
|
||||
request_count_token = cache.get(token_key, 0)
|
||||
|
||||
if request_count_ip >= 40 or (token and request_count_token >= 40):
|
||||
if request_count_user >= RATE_LIMIT_MAX:
|
||||
|
||||
sentry_sdk.capture_message(
|
||||
"Rate limit exceeded",
|
||||
level="warning",
|
||||
extras={
|
||||
"ip": ip,
|
||||
"user_id": user_id,
|
||||
"path": path,
|
||||
"method": request.method,
|
||||
"limiterType": "token" if (token_key.startswith("ratelimit:") and token) else "ip",
|
||||
},
|
||||
"limiterType": "user",
|
||||
}
|
||||
)
|
||||
return JsonResponse({"error": "Rate limit exceeded"}, status=429)
|
||||
|
||||
cache.set(ip_key, request_count_ip + 1, timeout=60)
|
||||
|
||||
if token:
|
||||
cache.set(token_key, request_count_token + 1, timeout=60)
|
||||
cache.set(user_key, request_count_user + 1, timeout=RATE_LIMIT_WINDOW_MS // 1000)
|
||||
|
||||
return await self.get_response(request)
|
||||
|
|
|
|||
|
|
@ -62,6 +62,7 @@ MIDDLEWARE: list[str] = [
|
|||
"django.middleware.security.SecurityMiddleware",
|
||||
"django.middleware.common.CommonMiddleware",
|
||||
"aiservice.middleware.healthcheck.HealthCheckMiddleware",
|
||||
"aiservice.middleware.auth_middleware.AuthMiddleware",
|
||||
"aiservice.middleware.rate_limit.RateLimitMiddleware",
|
||||
]
|
||||
|
||||
|
|
|
|||
|
|
@ -4,13 +4,12 @@ import datetime as dt
|
|||
import logging
|
||||
from asyncio import Semaphore
|
||||
|
||||
from aiservice.common_utils import validate_trace_id
|
||||
from authapp.auth import AuthBearer
|
||||
from ninja import NinjaAPI, Schema
|
||||
|
||||
from aiservice.common_utils import validate_trace_id
|
||||
from log_features.models import OptimizationFeatures
|
||||
|
||||
features_api = NinjaAPI(auth=AuthBearer(), urls_namespace="log_features")
|
||||
features_api = NinjaAPI(urls_namespace="log_features")
|
||||
|
||||
|
||||
semaphores: dict[str, Semaphore] = {}
|
||||
|
|
|
|||
|
|
@ -16,7 +16,6 @@ from aiservice.env_specific import (
|
|||
debug_log_sensitive_data_from_callable,
|
||||
)
|
||||
from aiservice.models.aimodels import OPTIMIZE_MODEL
|
||||
from authapp.auth import AuthBearer
|
||||
from log_features.log_features import log_features
|
||||
from ninja import NinjaAPI, Schema
|
||||
from openai import OpenAIError
|
||||
|
|
@ -36,7 +35,7 @@ if TYPE_CHECKING:
|
|||
ChatCompletionToolMessageParam,
|
||||
)
|
||||
|
||||
optimize_api = NinjaAPI(auth=AuthBearer(), urls_namespace="optimize")
|
||||
optimize_api = NinjaAPI(urls_namespace="optimize")
|
||||
|
||||
# Get the directory of the current file
|
||||
current_dir = Path(__file__).parent
|
||||
|
|
|
|||
|
|
@ -16,7 +16,6 @@ from aiservice.env_specific import (
|
|||
debug_log_sensitive_data_from_callable,
|
||||
)
|
||||
from aiservice.models.aimodels import OPTIMIZE_MODEL
|
||||
from authapp.auth import AuthBearer
|
||||
from log_features.log_features import log_features
|
||||
from ninja import NinjaAPI, Schema
|
||||
from openai import OpenAIError
|
||||
|
|
@ -36,7 +35,10 @@ if TYPE_CHECKING:
|
|||
ChatCompletionToolMessageParam,
|
||||
)
|
||||
|
||||
optimize_line_profiler_api = NinjaAPI(auth=AuthBearer(), urls_namespace="optimize-line-profiler")
|
||||
from aiservice.models.aimodels import LLM
|
||||
|
||||
optimize_line_profiler_api = NinjaAPI(urls_namespace="optimize-line-profiler")
|
||||
|
||||
|
||||
# Get the directory of the current file
|
||||
current_dir = Path(__file__).parent
|
||||
|
|
|
|||
|
|
@ -15,12 +15,11 @@ from aiservice.common_utils import parse_python_version
|
|||
from aiservice.env_specific import create_openai_client, debug_log_sensitive_data
|
||||
from aiservice.models.aimodels import EXECUTE_MODEL, EXPLAIN_MODEL, LLM, PLAN_MODEL
|
||||
from aiservice.models.functions_to_optimize import FunctionToOptimize
|
||||
from authapp.auth import AuthBearer
|
||||
from log_features.log_features import log_features
|
||||
from testgen.instrumentation.edit_generated_test import parse_module_to_cst, replace_definition_with_import
|
||||
from testgen.instrumentation.instrument_new_tests import instrument_test_source
|
||||
|
||||
testgen_api = NinjaAPI(auth=AuthBearer(), urls_namespace="testgen")
|
||||
testgen_api = NinjaAPI(urls_namespace="testgen")
|
||||
|
||||
openai_client = create_openai_client()
|
||||
|
||||
|
|
@ -375,7 +374,7 @@ from aiservice.analytics.posthog import ph
|
|||
"/", response={200: TestGenResponseSchema, 400: TestGenErrorResponseSchema, 500: TestGenErrorResponseSchema}
|
||||
)
|
||||
async def testgen(
|
||||
request: AuthBearer, data: TestGenSchema
|
||||
request, data: TestGenSchema
|
||||
) -> tuple[int, TestGenResponseSchema | TestGenErrorResponseSchema]:
|
||||
ph(request.user, "aiservice-testgen-called")
|
||||
if data.test_framework not in ["unittest", "pytest"]:
|
||||
|
|
|
|||
|
|
@ -26,7 +26,7 @@ from testgen.postprocessing.code_validator import validate_testgen_code
|
|||
from testgen.postprocessing.postprocess_pipeline import postprocessing_testgen_pipeline
|
||||
from testgen.preprocessing.preprocess_pipeline import preprocessing_testgen_pipeline
|
||||
|
||||
testgen_api = NinjaAPI(auth=AuthBearer(), urls_namespace="testgen")
|
||||
testgen_api = NinjaAPI(urls_namespace="testgen")
|
||||
|
||||
|
||||
# Get the directory of the current file
|
||||
|
|
|
|||
|
|
@ -23,7 +23,7 @@ import * as Sentry from "@sentry/node"
|
|||
import { stripeWebhookHandler } from "./endpoints/stripe-webhook.js"
|
||||
import { trackUsage } from "./middlewares/track-usage.js"
|
||||
import { testSentry } from "./endpoints/sentry-test.js"
|
||||
import { tokenLimiter, ipLimiter } from "./middlewares/rate-limit.js"
|
||||
import { idLimiter } from "./middlewares/rate-limit.js"
|
||||
const port = process.env.PORT ?? 3001
|
||||
// Define a custom type for the wrapped Express app
|
||||
const app = express()
|
||||
|
|
@ -39,12 +39,6 @@ const appExpress = addAsync(express()) as any as AsyncExpressApp
|
|||
// Basic middleware
|
||||
appExpress.use(logRequestDetails)
|
||||
|
||||
// ip rate limiter globally (applies to all routes)
|
||||
appExpress.use(ipLimiter)
|
||||
|
||||
// token rate limiter
|
||||
appExpress.use(tokenLimiter)
|
||||
|
||||
// Mount the github webhook middleware onto the express application
|
||||
// MUST be mounted before express.json() middleware
|
||||
console.log(`Mounting GitHub webhook middleware at path: ${ghAppPathPrefix}`)
|
||||
|
|
@ -88,6 +82,9 @@ appExpress.postAsync("/cfapi/collect-email", collectEmail)
|
|||
// @ts-expect-error: TS2555 // Protected routes
|
||||
appExpress.useAsync(checkForValidAPIKey)
|
||||
|
||||
//rate limiter
|
||||
appExpress.use(idLimiter)
|
||||
|
||||
// Posthog tracks calls to all endpoints defined after this
|
||||
appExpress.use(trackEndpointCalls)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,10 +1,15 @@
|
|||
import rateLimit from "express-rate-limit"
|
||||
import * as Sentry from "@sentry/node"
|
||||
import { AuthorizedUserReq } from "types.js"
|
||||
|
||||
// Load values from environment or use defaults
|
||||
const RATE_LIMIT_WINDOW_MS = parseInt(process.env.RATE_LIMIT_WINDOW_MS || "60000", 10)
|
||||
const RATE_LIMIT_MAX = parseInt(process.env.RATE_LIMIT_MAX || "30", 10)
|
||||
|
||||
// Common configuration for rate limiters
|
||||
const baseRateLimitConfig = {
|
||||
windowMs: 60 * 1000, // 1 minute
|
||||
max: 20,
|
||||
windowMs: RATE_LIMIT_WINDOW_MS, // in milliseconds
|
||||
max: RATE_LIMIT_MAX,
|
||||
standardHeaders: true,
|
||||
legacyHeaders: false,
|
||||
handler: (req, res, next, options) => {
|
||||
|
|
@ -15,7 +20,7 @@ const baseRateLimitConfig = {
|
|||
ip: req.ip,
|
||||
path: req.path,
|
||||
method: req.method,
|
||||
limiterType: options.keyGenerator(req).startsWith("token") ? "token" : "ip",
|
||||
limiterType: "id",
|
||||
},
|
||||
})
|
||||
|
||||
|
|
@ -24,23 +29,14 @@ const baseRateLimitConfig = {
|
|||
},
|
||||
}
|
||||
|
||||
// TODO: If a load balancer is introduced, update the keyGenerator logic to account for real client IP (e.g., use req.headers['x-forwarded-for'] or similar).
|
||||
// IP-based rate limiter
|
||||
export const ipLimiter = rateLimit({
|
||||
// ID-based rate limiter
|
||||
export const idLimiter = rateLimit({
|
||||
...baseRateLimitConfig,
|
||||
keyGenerator: req => req.ip,
|
||||
})
|
||||
|
||||
// Token-based rate limiter
|
||||
export const tokenLimiter = rateLimit({
|
||||
...baseRateLimitConfig,
|
||||
skip: req => {
|
||||
const authHeader = req.headers.authorization || ""
|
||||
return !authHeader.startsWith("Bearer ") || !authHeader.split(" ")[1]
|
||||
skip: (req: AuthorizedUserReq) => {
|
||||
// Skip if no userId is set — typically means checkForValidAPIKey hasn't run yet
|
||||
return !req.userId
|
||||
},
|
||||
keyGenerator: req => {
|
||||
const authHeader = req.headers.authorization || ""
|
||||
const token = authHeader.split(" ")[1]
|
||||
return `token:${token}`
|
||||
keyGenerator: (req: AuthorizedUserReq) => {
|
||||
return `ratelimit:user:${req.userId}:${req.path}`
|
||||
},
|
||||
})
|
||||
|
|
|
|||
Loading…
Reference in a new issue