fix: eliminate redundant DB queries in middleware and unblock LLM responses

Auth now attaches fetched organization/subscription to the request so
TrackUsageMiddleware reuses them instead of re-querying. RateLimitMiddleware
caches restricted_paths at init and uses async cache methods. LLM call
recording is fire-and-forget via asyncio.create_task to avoid blocking
responses on DB writes.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
aseembits93 2026-02-23 20:43:18 +05:30
parent 05aecd6fbd
commit 7f824ce101
4 changed files with 35 additions and 35 deletions

View file

@ -2,6 +2,7 @@
from __future__ import annotations
import asyncio
import logging
import os
import time
@ -233,8 +234,8 @@ async def call_llm(
finally:
latency_ms = int((time.time() - start_time) * 1000)
try:
await record_llm_call(
task = asyncio.create_task(
record_llm_call(
trace_id=trace_id,
call_type=call_type,
model_name=llm.name,
@ -247,8 +248,13 @@ async def call_llm(
llm_cost=calculate_llm_cost(result.raw_response, llm) if result else None,
latency_ms=latency_ms,
)
except Exception as e:
logger.warning(f"Tracing: Failed to record LLM call: {e}")
)
def _log_record_failure(t: asyncio.Task) -> None:
if exc := t.exception():
logger.warning(f"Tracing: Failed to record LLM call: {exc}")
task.add_done_callback(_log_record_failure)
# =============================================================================

View file

@ -31,15 +31,14 @@ class RateLimitMiddleware:
if iscoroutinefunction(self.get_response):
print("RateLimitMiddleware is async.")
markcoroutinefunction(self) # type: ignore[arg-type]
async def __call__(self, request):
restricted_paths = [
self.restricted_paths = [
"/" + str(pattern.pattern)
for pattern in get_resolver().url_patterns
if str(pattern.pattern).startswith("ai/")
]
if any(request.path.startswith(path) for path in restricted_paths):
# Use request.user as identifier
async def __call__(self, request):
if any(request.path.startswith(path) for path in self.restricted_paths):
user_id = getattr(request, "user", None)
if not user_id:
return JsonResponse({"error": "Authentication required for rate-limited endpoint"}, status=401)
@ -49,7 +48,7 @@ class RateLimitMiddleware:
path = request.path
user_key = f"ratelimit:user:{user_id}:{path}"
request_count_user = cache.get(user_key, 0)
request_count_user = await cache.aget(user_key, 0)
if request_count_user >= RATE_LIMIT_MAX:
sentry_sdk.capture_message(
"Rate limit exceeded",
@ -58,6 +57,6 @@ class RateLimitMiddleware:
)
return JsonResponse({"error": "Rate limit exceeded"}, status=429)
cache.set(user_key, request_count_user + 1, timeout=RATE_LIMIT_WINDOW_MS // 1000)
await cache.aset(user_key, request_count_user + 1, timeout=RATE_LIMIT_WINDOW_MS // 1000)
return await self.get_response(request)

View file

@ -44,11 +44,12 @@ def get_next_subscription_period(period_end: datetime.date):
return {"start": start, "end": end}
async def check_and_reset_subscription_period(user_id: int):
async def check_and_reset_subscription_period(user_id: int, subscription: Subscriptions | None = None):
"""Check if the subscription period has ended.
If yes, reset usage and move to the next billing cycle.
"""
subscription = await Subscriptions.objects.filter(user_id=user_id).afirst()
if subscription is None:
subscription = await Subscriptions.objects.filter(user_id=user_id).afirst()
if not subscription:
return None
@ -94,7 +95,9 @@ class TrackUsageMiddleware:
# Check if API key is linked to an organization with subscription
organization_id = getattr(request, "organization_id", None)
if organization_id:
org = await Organizations.objects.filter(id=organization_id).afirst()
org = getattr(request, "organization", None)
if org is None:
org = await Organizations.objects.filter(id=organization_id).afirst()
if org and org.subscription:
request.subscription_info = {
"userId": user_id,
@ -106,8 +109,10 @@ class TrackUsageMiddleware:
return await self.get_response(request)
try:
# Get or create subscription
subscription = await Subscriptions.objects.filter(user_id=user_id).afirst()
# Get or create subscription (reuse from auth if available)
subscription = getattr(request, "subscription", None)
if subscription is None:
subscription = await Subscriptions.objects.filter(user_id=user_id).afirst()
if not subscription:
# Subscription is now created during login in cf-webapp
@ -141,7 +146,7 @@ class TrackUsageMiddleware:
)
# Lazy reset monthly usage
subscription = await check_and_reset_subscription_period(user_id)
subscription = await check_and_reset_subscription_period(user_id, subscription=subscription)
current_used = subscription.optimizations_used or 0
if current_used + cost > subscription.optimizations_limit:
@ -163,9 +168,8 @@ class TrackUsageMiddleware:
total_lifetime_optimizations=F("total_lifetime_optimizations") + cost,
)
# Re-read to get the actual updated value for the response
updated_subscription = await Subscriptions.objects.filter(user_id=user_id).afirst()
new_used = updated_subscription.optimizations_used if updated_subscription else current_used + cost
# Compute new usage from the known current value instead of re-reading from DB
new_used = current_used + cost
logging.debug(
f"track_usage_middleware.py|__call__|Atomic update completed: "

View file

@ -21,40 +21,31 @@ class AuthenticatedRequest(Protocol):
should_log_features: bool # whether to log optimization features
async def check_subscription_status(request, user_id, tier, organization_id=None) -> bool:
    """Decide whether optimization features should be logged for this caller.

    As a side effect, rows fetched here are attached to ``request`` so
    downstream middleware can reuse them without re-querying:

    * ``request.organization`` is set only when ``organization_id`` is truthy
      (it may be ``None`` if no matching row exists).
    * ``request.subscription`` is set only when the organization check did not
      already decide the result.

    Neither attribute is set when ``tier`` is already known, so consumers must
    read them defensively (e.g. ``getattr(request, "subscription", None)``).

    Args:
        request: The in-flight request object that receives the fetched rows.
        user_id: The ID of the user to check.
        tier: The user's tier if already available.
        organization_id: The ID of the user's organization if available.

    Returns:
        False if features should not be logged (tier already set, paid
        organization, "pro"/"enterprise" plan, or the lookup failed);
        True otherwise.
    """
    # If tier is already set, no need to check subscription.
    if tier is not None:
        return False

    try:
        if organization_id:
            org = await Organizations.objects.filter(id=organization_id).afirst()
            request.organization = org
            # The "codeflash-ai" organization is checked before the paid-org
            # rule, so it always gets feature logging.
            if org and org.name == "codeflash-ai":
                return True
            if org and org.subscription:
                # Paid organization - don't log features.
                return False

        subscription = await Subscriptions.objects.filter(user_id=user_id).afirst()
        request.subscription = subscription
        if subscription and subscription.plan_type.lower() in ("pro", "enterprise"):
            # Premium users - don't log features.
            return False
    except Exception as e:
        # Best-effort check: report the failure and default to not logging.
        print(f"Error checking subscription: {e!s}")
        sentry_sdk.capture_exception(e)
        return False

    return True
@ -76,7 +67,7 @@ class AuthBearer(HttpBearer):
request.api_key_id = api_key_instance.id
request.organization_id = api_key_instance.organization_id
request.should_log_features = await check_subscription_status(
user_id=request.user, tier=request.tier, organization_id=request.organization_id
request, user_id=request.user, tier=request.tier, organization_id=request.organization_id
)
return token