Refactor rate limit key (#1560)

Env vars for cf-api and Django
**RATE_LIMIT_WINDOW_MS**: the rate-limit window duration, in milliseconds
**RATE_LIMIT_MAX**: the maximum number of requests allowed per window
This commit is contained in:
HeshamHM28 2025-04-22 21:29:07 +02:00 committed by GitHub
parent 88d13a769a
commit 18de2c6889
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
11 changed files with 86 additions and 55 deletions

View file

@ -2,3 +2,5 @@
DATABASE_URL=
SECRET_KEY=
POSTHOG_API_KEY=
RATE_LIMIT_WINDOW_MS=
RATE_LIMIT_MAX=

View file

@ -0,0 +1,36 @@
from django.http import JsonResponse
from django.utils.decorators import async_only_middleware
from authapp.auth import AuthBearer
from asgiref.sync import iscoroutinefunction, markcoroutinefunction
@async_only_middleware
class AuthMiddleware:
    """Async middleware that authenticates every request via a Bearer token.

    Reads the ``Authorization`` header, delegates validation to
    ``AuthBearer.authenticate``, and rejects the request with a 403 JSON
    error when the header is missing or the token is invalid.
    """

    def __init__(self, get_response):
        self.get_response = get_response
        self.auth_bearer = AuthBearer()
        # Propagate async capability: if get_response is a coroutine
        # function, mark this middleware instance as async as well.
        if iscoroutinefunction(self.get_response):
            print("AuthMiddleware is async.")
            markcoroutinefunction(self)

    async def __call__(self, request):
        # Extract the token from the Authorization header.
        token = request.headers.get("Authorization")
        if not token:
            # If no token is found, return Unauthorized error.
            return JsonResponse({"error": "Invalid API key"}, status=403)
        try:
            # Strip only a LEADING "Bearer " prefix if present.
            # (str.replace would also mangle a token that happens to
            # contain the substring "Bearer " anywhere inside it.)
            token = token.removeprefix("Bearer ")
            # Authenticate the token using the AuthBearer class.
            await self.auth_bearer.authenticate(request, token)
            return await self.get_response(request)
        except Exception:
            # Any authentication failure gets the same opaque response so
            # the reply does not leak why validation failed.
            return JsonResponse({"error": "Invalid API key"}, status=403)

View file

@ -3,6 +3,10 @@ from django.core.cache import cache
from asgiref.sync import iscoroutinefunction, markcoroutinefunction
from django.utils.decorators import async_only_middleware
import sentry_sdk
import os
# Rate-limit configuration, overridable via environment variables.
# Defaults: a 60 000 ms (60 s) window allowing at most 40 requests.
RATE_LIMIT_WINDOW_MS = int(os.getenv("RATE_LIMIT_WINDOW_MS", "60000"))
RATE_LIMIT_MAX = int(os.getenv("RATE_LIMIT_MAX", "40"))
# Note: This rate limiting solution works only with a single server.
@ -20,36 +24,32 @@ class RateLimitMiddleware:
restricted_paths = ["/ai/optimize", "/ai/testgen", "/ai/log_features", "/ai/optimize-line-profiler"]
if any(request.path.startswith(path) for path in restricted_paths):
# Get the user's IP address
ip = request.META.get("REMOTE_ADDR", "")
authorization_header = request.META.get("HTTP_AUTHORIZATION", "")
token = None
if authorization_header.startswith("Bearer "):
token = authorization_header[7:] # Extract the token
# Use request.user as identifier
user_id = getattr(request, "user", None)
if not user_id:
return JsonResponse({'error': 'Authentication required for rate-limited endpoint'}, status=401)
path = request.path
user_key = f"ratelimit:user:{user_id}:{path}"
ip_key = f"ratelimit:{ip}:{path}" # Create a unique cache key based on IP and path
token_key = f"ratelimit:{token}:{path}" # User-specific rate limit based on token
request_count_user = cache.get(user_key, 0)
request_count_ip = cache.get(ip_key, 0)
request_count_token = cache.get(token_key, 0)
if request_count_ip >= 40 or (token and request_count_token >= 40):
if request_count_user >= RATE_LIMIT_MAX:
sentry_sdk.capture_message(
"Rate limit exceeded",
level="warning",
extras={
"ip": ip,
"user_id": user_id,
"path": path,
"method": request.method,
"limiterType": "token" if (token_key.startswith("ratelimit:") and token) else "ip",
},
"limiterType": "user",
}
)
return JsonResponse({"error": "Rate limit exceeded"}, status=429)
cache.set(ip_key, request_count_ip + 1, timeout=60)
if token:
cache.set(token_key, request_count_token + 1, timeout=60)
cache.set(user_key, request_count_user + 1, timeout=RATE_LIMIT_WINDOW_MS // 1000)
return await self.get_response(request)

View file

@ -62,6 +62,7 @@ MIDDLEWARE: list[str] = [
"django.middleware.security.SecurityMiddleware",
"django.middleware.common.CommonMiddleware",
"aiservice.middleware.healthcheck.HealthCheckMiddleware",
"aiservice.middleware.auth_middleware.AuthMiddleware",
"aiservice.middleware.rate_limit.RateLimitMiddleware",
]

View file

@ -4,13 +4,12 @@ import datetime as dt
import logging
from asyncio import Semaphore
from aiservice.common_utils import validate_trace_id
from authapp.auth import AuthBearer
from ninja import NinjaAPI, Schema
from aiservice.common_utils import validate_trace_id
from log_features.models import OptimizationFeatures
features_api = NinjaAPI(auth=AuthBearer(), urls_namespace="log_features")
features_api = NinjaAPI(urls_namespace="log_features")
semaphores: dict[str, Semaphore] = {}

View file

@ -16,7 +16,6 @@ from aiservice.env_specific import (
debug_log_sensitive_data_from_callable,
)
from aiservice.models.aimodels import OPTIMIZE_MODEL
from authapp.auth import AuthBearer
from log_features.log_features import log_features
from ninja import NinjaAPI, Schema
from openai import OpenAIError
@ -36,7 +35,7 @@ if TYPE_CHECKING:
ChatCompletionToolMessageParam,
)
optimize_api = NinjaAPI(auth=AuthBearer(), urls_namespace="optimize")
optimize_api = NinjaAPI(urls_namespace="optimize")
# Get the directory of the current file
current_dir = Path(__file__).parent

View file

@ -16,7 +16,6 @@ from aiservice.env_specific import (
debug_log_sensitive_data_from_callable,
)
from aiservice.models.aimodels import OPTIMIZE_MODEL
from authapp.auth import AuthBearer
from log_features.log_features import log_features
from ninja import NinjaAPI, Schema
from openai import OpenAIError
@ -36,7 +35,10 @@ if TYPE_CHECKING:
ChatCompletionToolMessageParam,
)
optimize_line_profiler_api = NinjaAPI(auth=AuthBearer(), urls_namespace="optimize-line-profiler")
from aiservice.models.aimodels import LLM
optimize_line_profiler_api = NinjaAPI(urls_namespace="optimize-line-profiler")
# Get the directory of the current file
current_dir = Path(__file__).parent

View file

@ -15,12 +15,11 @@ from aiservice.common_utils import parse_python_version
from aiservice.env_specific import create_openai_client, debug_log_sensitive_data
from aiservice.models.aimodels import EXECUTE_MODEL, EXPLAIN_MODEL, LLM, PLAN_MODEL
from aiservice.models.functions_to_optimize import FunctionToOptimize
from authapp.auth import AuthBearer
from log_features.log_features import log_features
from testgen.instrumentation.edit_generated_test import parse_module_to_cst, replace_definition_with_import
from testgen.instrumentation.instrument_new_tests import instrument_test_source
testgen_api = NinjaAPI(auth=AuthBearer(), urls_namespace="testgen")
testgen_api = NinjaAPI(urls_namespace="testgen")
openai_client = create_openai_client()
@ -375,7 +374,7 @@ from aiservice.analytics.posthog import ph
"/", response={200: TestGenResponseSchema, 400: TestGenErrorResponseSchema, 500: TestGenErrorResponseSchema}
)
async def testgen(
request: AuthBearer, data: TestGenSchema
request, data: TestGenSchema
) -> tuple[int, TestGenResponseSchema | TestGenErrorResponseSchema]:
ph(request.user, "aiservice-testgen-called")
if data.test_framework not in ["unittest", "pytest"]:

View file

@ -26,7 +26,7 @@ from testgen.postprocessing.code_validator import validate_testgen_code
from testgen.postprocessing.postprocess_pipeline import postprocessing_testgen_pipeline
from testgen.preprocessing.preprocess_pipeline import preprocessing_testgen_pipeline
testgen_api = NinjaAPI(auth=AuthBearer(), urls_namespace="testgen")
testgen_api = NinjaAPI(urls_namespace="testgen")
# Get the directory of the current file

View file

@ -23,7 +23,7 @@ import * as Sentry from "@sentry/node"
import { stripeWebhookHandler } from "./endpoints/stripe-webhook.js"
import { trackUsage } from "./middlewares/track-usage.js"
import { testSentry } from "./endpoints/sentry-test.js"
import { tokenLimiter, ipLimiter } from "./middlewares/rate-limit.js"
import { idLimiter } from "./middlewares/rate-limit.js"
const port = process.env.PORT ?? 3001
// Define a custom type for the wrapped Express app
const app = express()
@ -39,12 +39,6 @@ const appExpress = addAsync(express()) as any as AsyncExpressApp
// Basic middleware
appExpress.use(logRequestDetails)
// ip rate limiter globally (applies to all routes)
appExpress.use(ipLimiter)
// token rate limiter
appExpress.use(tokenLimiter)
// Mount the github webhook middleware onto the express application
// MUST be mounted before express.json() middleware
console.log(`Mounting GitHub webhook middleware at path: ${ghAppPathPrefix}`)
@ -88,6 +82,9 @@ appExpress.postAsync("/cfapi/collect-email", collectEmail)
// @ts-expect-error: TS2555 // Protected routes
appExpress.useAsync(checkForValidAPIKey)
//rate limiter
appExpress.use(idLimiter)
// Posthog tracks calls to all endpoints defined after this
appExpress.use(trackEndpointCalls)

View file

@ -1,10 +1,15 @@
import rateLimit from "express-rate-limit"
import * as Sentry from "@sentry/node"
import { AuthorizedUserReq } from "types.js"
// Load values from environment or use defaults
const RATE_LIMIT_WINDOW_MS = parseInt(process.env.RATE_LIMIT_WINDOW_MS || "60000", 10)
const RATE_LIMIT_MAX = parseInt(process.env.RATE_LIMIT_MAX || "30", 10)
// Common configuration for rate limiters
const baseRateLimitConfig = {
windowMs: 60 * 1000, // 1 minute
max: 20,
windowMs: RATE_LIMIT_WINDOW_MS, // in milliseconds
max: RATE_LIMIT_MAX,
standardHeaders: true,
legacyHeaders: false,
handler: (req, res, next, options) => {
@ -15,7 +20,7 @@ const baseRateLimitConfig = {
ip: req.ip,
path: req.path,
method: req.method,
limiterType: options.keyGenerator(req).startsWith("token") ? "token" : "ip",
limiterType: "id",
},
})
@ -24,23 +29,14 @@ const baseRateLimitConfig = {
},
}
// TODO: If a load balancer is introduced, update the keyGenerator logic to account for real client IP (e.g., use req.headers['x-forwarded-for'] or similar).
// IP-based rate limiter
export const ipLimiter = rateLimit({
// ID-based rate limiter
export const idLimiter = rateLimit({
...baseRateLimitConfig,
keyGenerator: req => req.ip,
})
// Token-based rate limiter
export const tokenLimiter = rateLimit({
...baseRateLimitConfig,
skip: req => {
const authHeader = req.headers.authorization || ""
return !authHeader.startsWith("Bearer ") || !authHeader.split(" ")[1]
skip: (req: AuthorizedUserReq) => {
// Skip if no userId is set — typically means checkForValidAPIKey hasn't run yet
return !req.userId
},
keyGenerator: req => {
const authHeader = req.headers.authorization || ""
const token = authHeader.split(" ")[1]
return `token:${token}`
keyGenerator: (req: AuthorizedUserReq) => {
return `ratelimit:user:${req.userId}:${req.path}`
},
})