mirror of
https://github.com/codeflash-ai/codeflash-internal.git
synced 2026-05-04 18:25:18 +00:00
fix: eliminate redundant DB queries in middleware and unblock LLM responses
Auth now attaches fetched organization/subscription to the request so TrackUsageMiddleware reuses them instead of re-querying. RateLimitMiddleware caches restricted_paths at init and uses async cache methods. LLM call recording is fire-and-forget via asyncio.create_task to avoid blocking responses on DB writes. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
05aecd6fbd
commit
7f824ce101
4 changed files with 35 additions and 35 deletions
|
|
@ -2,6 +2,7 @@
|
|||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
|
|
@ -233,8 +234,8 @@ async def call_llm(
|
|||
|
||||
finally:
|
||||
latency_ms = int((time.time() - start_time) * 1000)
|
||||
try:
|
||||
await record_llm_call(
|
||||
task = asyncio.create_task(
|
||||
record_llm_call(
|
||||
trace_id=trace_id,
|
||||
call_type=call_type,
|
||||
model_name=llm.name,
|
||||
|
|
@ -247,8 +248,13 @@ async def call_llm(
|
|||
llm_cost=calculate_llm_cost(result.raw_response, llm) if result else None,
|
||||
latency_ms=latency_ms,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Tracing: Failed to record LLM call: {e}")
|
||||
)
|
||||
|
||||
def _log_record_failure(t: asyncio.Task) -> None:
|
||||
if exc := t.exception():
|
||||
logger.warning(f"Tracing: Failed to record LLM call: {exc}")
|
||||
|
||||
task.add_done_callback(_log_record_failure)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
|
|
|
|||
|
|
@ -31,15 +31,14 @@ class RateLimitMiddleware:
|
|||
if iscoroutinefunction(self.get_response):
|
||||
print("RateLimitMiddleware is async.")
|
||||
markcoroutinefunction(self) # type: ignore[arg-type]
|
||||
|
||||
async def __call__(self, request):
|
||||
restricted_paths = [
|
||||
self.restricted_paths = [
|
||||
"/" + str(pattern.pattern)
|
||||
for pattern in get_resolver().url_patterns
|
||||
if str(pattern.pattern).startswith("ai/")
|
||||
]
|
||||
if any(request.path.startswith(path) for path in restricted_paths):
|
||||
# Use request.user as identifier
|
||||
|
||||
async def __call__(self, request):
|
||||
if any(request.path.startswith(path) for path in self.restricted_paths):
|
||||
user_id = getattr(request, "user", None)
|
||||
if not user_id:
|
||||
return JsonResponse({"error": "Authentication required for rate-limited endpoint"}, status=401)
|
||||
|
|
@ -49,7 +48,7 @@ class RateLimitMiddleware:
|
|||
path = request.path
|
||||
user_key = f"ratelimit:user:{user_id}:{path}"
|
||||
|
||||
request_count_user = cache.get(user_key, 0)
|
||||
request_count_user = await cache.aget(user_key, 0)
|
||||
if request_count_user >= RATE_LIMIT_MAX:
|
||||
sentry_sdk.capture_message(
|
||||
"Rate limit exceeded",
|
||||
|
|
@ -58,6 +57,6 @@ class RateLimitMiddleware:
|
|||
)
|
||||
return JsonResponse({"error": "Rate limit exceeded"}, status=429)
|
||||
|
||||
cache.set(user_key, request_count_user + 1, timeout=RATE_LIMIT_WINDOW_MS // 1000)
|
||||
await cache.aset(user_key, request_count_user + 1, timeout=RATE_LIMIT_WINDOW_MS // 1000)
|
||||
|
||||
return await self.get_response(request)
|
||||
|
|
|
|||
|
|
@ -44,11 +44,12 @@ def get_next_subscription_period(period_end: datetime.date):
|
|||
return {"start": start, "end": end}
|
||||
|
||||
|
||||
async def check_and_reset_subscription_period(user_id: int):
|
||||
async def check_and_reset_subscription_period(user_id: int, subscription: Subscriptions | None = None):
|
||||
"""Check if the subscription period has ended.
|
||||
If yes, reset usage and move to the next billing cycle.
|
||||
"""
|
||||
subscription = await Subscriptions.objects.filter(user_id=user_id).afirst()
|
||||
if subscription is None:
|
||||
subscription = await Subscriptions.objects.filter(user_id=user_id).afirst()
|
||||
if not subscription:
|
||||
return None
|
||||
|
||||
|
|
@ -94,7 +95,9 @@ class TrackUsageMiddleware:
|
|||
# Check if API key is linked to an organization with subscription
|
||||
organization_id = getattr(request, "organization_id", None)
|
||||
if organization_id:
|
||||
org = await Organizations.objects.filter(id=organization_id).afirst()
|
||||
org = getattr(request, "organization", None)
|
||||
if org is None:
|
||||
org = await Organizations.objects.filter(id=organization_id).afirst()
|
||||
if org and org.subscription:
|
||||
request.subscription_info = {
|
||||
"userId": user_id,
|
||||
|
|
@ -106,8 +109,10 @@ class TrackUsageMiddleware:
|
|||
return await self.get_response(request)
|
||||
|
||||
try:
|
||||
# Get or create subscription
|
||||
subscription = await Subscriptions.objects.filter(user_id=user_id).afirst()
|
||||
# Get or create subscription (reuse from auth if available)
|
||||
subscription = getattr(request, "subscription", None)
|
||||
if subscription is None:
|
||||
subscription = await Subscriptions.objects.filter(user_id=user_id).afirst()
|
||||
|
||||
if not subscription:
|
||||
# Subscription is now created during login in cf-webapp
|
||||
|
|
@ -141,7 +146,7 @@ class TrackUsageMiddleware:
|
|||
)
|
||||
|
||||
# Lazy reset monthly usage
|
||||
subscription = await check_and_reset_subscription_period(user_id)
|
||||
subscription = await check_and_reset_subscription_period(user_id, subscription=subscription)
|
||||
current_used = subscription.optimizations_used or 0
|
||||
|
||||
if current_used + cost > subscription.optimizations_limit:
|
||||
|
|
@ -163,9 +168,8 @@ class TrackUsageMiddleware:
|
|||
total_lifetime_optimizations=F("total_lifetime_optimizations") + cost,
|
||||
)
|
||||
|
||||
# Re-read to get the actual updated value for the response
|
||||
updated_subscription = await Subscriptions.objects.filter(user_id=user_id).afirst()
|
||||
new_used = updated_subscription.optimizations_used if updated_subscription else current_used + cost
|
||||
# Compute new usage from the known current value instead of re-reading from DB
|
||||
new_used = current_used + cost
|
||||
|
||||
logging.debug(
|
||||
f"track_usage_middleware.py|__call__|Atomic update completed: "
|
||||
|
|
|
|||
|
|
@ -21,40 +21,31 @@ class AuthenticatedRequest(Protocol):
|
|||
should_log_features: bool # whether to log optimization features
|
||||
|
||||
|
||||
async def check_subscription_status(user_id, tier, organization_id=None) -> bool:
|
||||
async def check_subscription_status(request, user_id, tier, organization_id=None) -> bool:
|
||||
"""Check if a user has a premium subscription that doesn't require feature logging.
|
||||
|
||||
Args:
|
||||
user_id: The ID of the user to check
|
||||
tier: The user's tier if already available
|
||||
organization_id: The ID of the user's organization if available
|
||||
|
||||
Returns:
|
||||
bool: False if features should not be logged (premium user or paid org), True otherwise
|
||||
|
||||
Attaches fetched organization and subscription objects to the request
|
||||
so downstream middleware can reuse them without re-querying.
|
||||
"""
|
||||
# If tier is already set, no need to check subscription
|
||||
if tier is not None:
|
||||
return False
|
||||
|
||||
try:
|
||||
# Check if user belongs to a paid organization
|
||||
if organization_id:
|
||||
org = await Organizations.objects.filter(id=organization_id).afirst()
|
||||
request.organization = org
|
||||
if org and org.name == "codeflash-ai":
|
||||
return True
|
||||
if org and org.subscription:
|
||||
# Paid organization - don't log features
|
||||
return False
|
||||
|
||||
subscription = await Subscriptions.objects.filter(user_id=user_id).afirst()
|
||||
request.subscription = subscription
|
||||
if subscription and subscription.plan_type.lower() in ["pro", "enterprise"]:
|
||||
# Premium users for CF- don't log features
|
||||
return False
|
||||
except Exception as e:
|
||||
print(f"Error checking subscription: {e!s}")
|
||||
sentry_sdk.capture_exception(e)
|
||||
# Default to not logging
|
||||
return False
|
||||
return True
|
||||
|
||||
|
|
@ -76,7 +67,7 @@ class AuthBearer(HttpBearer):
|
|||
request.api_key_id = api_key_instance.id
|
||||
request.organization_id = api_key_instance.organization_id
|
||||
request.should_log_features = await check_subscription_status(
|
||||
user_id=request.user, tier=request.tier, organization_id=request.organization_id
|
||||
request, user_id=request.user, tier=request.tier, organization_id=request.organization_id
|
||||
)
|
||||
return token
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue