Async subscription reset (#1703)

ToDO: Test update and sql migrations.

Upstream dependent PR:
https://github.com/codeflash-ai/codeflash-internal/pull/1700

### How to Test?
1. Install Stripe: `curl -fsSL https://stripe.com/install.sh | sudo
bash`
2. Create Stripe credentials and save them in your environment
variables.
3. Run cf-api and start Stripe listening with: ` stripe listen
--forward-to localhost:3001/webhook`
4.	Save the webhook secret in your environment variables.
5.	Use the CLI to optimize and check the quota.
6.	Call the AI endpoint in Insomnia and verify the quota in cf-webAPP.

---------

Co-authored-by: HeshamHM28 <HeshamMohamedFathy@outlook.com>
This commit is contained in:
Sarthak Agarwal 2025-09-15 21:23:38 +05:30 committed by GitHub
parent d1bade4dce
commit d6167aa9ac
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
16 changed files with 445 additions and 65 deletions

2
.gitignore vendored
View file

@ -251,4 +251,6 @@ fabric.properties
# Android studio 3.1+ serialized cache file
.idea/caches/build_file_checksums.ser
*/node_modules/*

View file

@ -0,0 +1,168 @@
from authapp.models import Subscriptions
from django.http import JsonResponse
from django.utils.decorators import async_only_middleware
from asgiref.sync import iscoroutinefunction, markcoroutinefunction
import logging
import sentry_sdk
SUBSCRIPTION_PLANS = {
"FREE": {
"name": "Free",
"optimizations": 4000,
"price": 0,
},
"PRO": {
"name": "Pro",
"optimizations": 100000,
"monthlyPrice": 3000, # $30.00
"yearlyPrice": 30000, # $300.00
},
"ENTERPRISE": {
"name": "Enterprise",
"optimizations": "Custom",
"price": "Custom",
},
}
PRICE_ID_KEYS = {
"PRO_MONTHLY": "STRIPE_PRO_PRICE_MONTHLY_ID",
"PRO_YEARLY": "STRIPE_PRO_PRICE_YEARLY_ID",
}
# subscription_utils.py
import datetime
from django.utils.timezone import now
def add_months_safe(date: datetime.date, months: int) -> datetime.date:
"""
Safely add months to a date, preserving the day if possible.
"""
year = date.year + (date.month + months - 1) // 12
month = (date.month + months - 1) % 12 + 1
day = date.day
# Get the last valid day of the new month
last_day = (datetime.date(year, month % 12 + 1, 1) - datetime.timedelta(days=1)).day
return datetime.date(year, month, min(day, last_day))
def get_next_subscription_period(period_end: datetime.date):
start = add_months_safe(period_end, 1)
end = add_months_safe(start, 1)
return {"start": start, "end": end}
async def check_and_reset_subscription_period(user_id: int):
"""
Check if the subscription period has ended.
If yes, reset usage and move to the next billing cycle.
"""
subscription = await Subscriptions.objects.filter(user_id=user_id).afirst()
if not subscription:
return None
current_time = now()
period_end = subscription.current_period_end
# Pro User Stripe handles billing cycle, so we only need to manage FREE plan here
if subscription.plan_type == "FREE":
if period_end and current_time > period_end:
# Reset usage and roll period
next_period = get_next_subscription_period(period_end)
subscription.optimizations_used = 0
subscription.current_period_start = next_period["start"]
subscription.current_period_end = next_period["end"]
await subscription.asave(
update_fields=["optimizations_used", "current_period_start", "current_period_end"]
)
return subscription
return subscription
ENDPOINT_TOKEN_COST = {
"optimize": 10,
"optimize-line-profiler": 10,
"testgen": 20,
"refinement": 20,
"explain": 10,
}
@async_only_middleware
class TrackUsageMiddleware:
def __init__(self, get_response) -> None:
self.get_response = get_response
if iscoroutinefunction(self.get_response):
print("TrackUsageMiddleware is async.")
markcoroutinefunction(self)
async def __call__(self, request):
endpoint = request.path.replace("/ai/", "").split("?")[0]
cost = ENDPOINT_TOKEN_COST.get(endpoint, 0)
user_id = getattr(request, "user", None)
if not user_id:
return JsonResponse({"error": "Authentication required for rate-limited endpoint"}, status=401)
try:
# Get or create subscription
subscription = await Subscriptions.objects.filter(user_id=user_id).afirst()
if not subscription:
subscription = await Subscriptions.objects.acreate(
user_id=user_id,
plan_type="free",
optimizations_limit=SUBSCRIPTION_PLANS["FREE"]["optimizations"],
subscription_status="active",
optimizations_used=cost,
)
request.subscription_info = {
"userId": user_id,
"tier": subscription.plan_type,
"used": cost,
"limit": subscription.optimizations_limit,
}
return await self.get_response(request)
if subscription.subscription_status != "active":
return JsonResponse(
{
"error": "Subscription is not active",
"status": subscription.subscription_status,
},
status=403,
)
# Lazy reset monthly usage
subscription = await check_and_reset_subscription_period(user_id)
current_used = subscription.optimizations_used or 0
if current_used + cost > subscription.optimizations_limit:
return JsonResponse(
{
"error": "Usage limit exceeded",
"used": current_used,
"limit": subscription.optimizations_limit,
"tier": subscription.plan_type,
},
status=403,
)
# Increment usage
subscription.optimizations_used = current_used + cost
subscription.total_lifetime_optimizations += cost
await subscription.asave(update_fields=["optimizations_used", "total_lifetime_optimizations"])
# Attach subscription info to request
request.subscription_info = {
"userId": user_id,
"tier": subscription.plan_type,
"used": current_used + cost,
"limit": subscription.optimizations_limit,
}
return await self.get_response(request)
except Exception as e:
sentry_sdk.capture_exception(e)
logging.exception("Error tracking usage")
return JsonResponse({"error": "Internal server error"}, status=500)

View file

@ -65,6 +65,7 @@ MIDDLEWARE: list[str] = [
"aiservice.middleware.healthcheck.HealthCheckMiddleware",
"aiservice.middleware.auth_middleware.AuthMiddleware",
"aiservice.middleware.rate_limit.RateLimitMiddleware",
"aiservice.middleware.track_usage_middleware.TrackUsageMiddleware"
]
ROOT_URLCONF: str = "aiservice.urls"

View file

@ -0,0 +1,160 @@
import pytest
from django.http import JsonResponse
from aiservice.middleware.track_usage_middleware import TrackUsageMiddleware
def parse_json(response: JsonResponse):
import json
return json.loads(response.content.decode())
class FakeSubscription:
def __init__(
self,
status="active",
used=0,
limit=100,
lifetime=0,
plan_type="free",
current_period_end=None,
):
from datetime import datetime, timedelta, timezone
self.subscription_status = status
self.optimizations_used = used
self.optimizations_limit = limit
self.total_lifetime_optimizations = lifetime
self.plan_type = plan_type
self.current_period_end = current_period_end or (
datetime.now(timezone.utc) + timedelta(days=30)
)
async def asave(self, *args, **kwargs):
return None
@pytest.fixture
def middleware():
async def get_response(request):
return JsonResponse({"ok": True})
return TrackUsageMiddleware(get_response)
@pytest.mark.asyncio
async def test_requires_authentication(middleware, rf):
request = rf.get("/ai/optimize")
request.user = None
response = await middleware(request)
data = parse_json(response)
assert response.status_code == 401
assert "Authentication required" in data["error"]
@pytest.mark.asyncio
async def test_creates_subscription_for_new_user(middleware, rf, monkeypatch):
request = rf.get("/ai/optimize")
request.user = type("User", (), {"id": 1})()
class FakeFilter:
async def afirst(self): return None
monkeypatch.setattr(
"aiservice.middleware.track_usage_middleware.Subscriptions.objects.filter",
lambda **kwargs: FakeFilter(),
)
fake_sub = FakeSubscription()
async def fake_acreate(**kwargs): return fake_sub
monkeypatch.setattr(
"aiservice.middleware.track_usage_middleware.Subscriptions.objects.acreate",
fake_acreate,
)
response = await middleware(request)
data = parse_json(response)
assert response.status_code == 200
assert data["ok"] is True
@pytest.mark.asyncio
async def test_blocks_inactive_subscription(middleware, rf, monkeypatch):
request = rf.get("/ai/optimize")
request.user = type("User", (), {"id": 1})()
fake_sub = FakeSubscription(status="inactive")
class FakeFilter:
async def afirst(self): return fake_sub
monkeypatch.setattr(
"aiservice.middleware.track_usage_middleware.Subscriptions.objects.filter",
lambda **kwargs: FakeFilter(),
)
response = await middleware(request)
data = parse_json(response)
assert response.status_code == 403
assert data["status"] == "inactive"
@pytest.mark.asyncio
async def test_usage_limit_exceeded(middleware, rf, monkeypatch):
request = rf.get("/ai/optimize")
request.user = type("User", (), {"id": 1})()
fake_sub = FakeSubscription(used=101, limit=100)
class FakeFilter:
async def afirst(self): return fake_sub
monkeypatch.setattr(
"aiservice.middleware.track_usage_middleware.Subscriptions.objects.filter",
lambda **kwargs: FakeFilter(),
)
response = await middleware(request)
data = parse_json(response)
assert response.status_code == 403
assert "Usage limit exceeded" in data["error"]
@pytest.mark.asyncio
async def test_allows_usage_within_limit(middleware, rf, monkeypatch):
request = rf.get("/ai/optimize")
request.user = type("User", (), {"id": 1})()
fake_sub = FakeSubscription(used=50, limit=100, lifetime=10)
class FakeFilter:
async def afirst(self): return fake_sub
monkeypatch.setattr(
"aiservice.middleware.track_usage_middleware.Subscriptions.objects.filter",
lambda **kwargs: FakeFilter(),
)
response = await middleware(request)
data = parse_json(response)
assert response.status_code == 200
assert data["ok"] is True
@pytest.mark.asyncio
async def test_internal_error_handled(middleware, rf, monkeypatch):
request = rf.get("/ai/optimize")
request.user = type("User", (), {"id": 1})()
class FakeFilter:
async def afirst(self): raise Exception("DB down")
monkeypatch.setattr(
"aiservice.middleware.track_usage_middleware.Subscriptions.objects.filter",
lambda **kwargs: FakeFilter(),
)
response = await middleware(request)
data = parse_json(response)
assert response.status_code == 500
assert "Internal server error" in data["error"]

View file

@ -40,6 +40,7 @@ class Subscriptions(models.Model):
stripe_subscription_id = models.TextField(null=True, blank=True)
plan_type = models.TextField()
optimizations_used = models.IntegerField(default=0)
total_lifetime_optimizations = models.IntegerField(default=0)
optimizations_limit = models.IntegerField()
subscription_status = models.TextField()
current_period_start = models.DateTimeField(null=True, blank=True)

View file

@ -1,5 +1,5 @@
import { Request, Response } from "express"
import { stripe } from "@codeflash-ai/common"
import { addMonthsSafe, stripe, SUBSCRIPTION_PLANS } from "@codeflash-ai/common"
import { prisma } from "@codeflash-ai/common"
import * as Sentry from "@sentry/node"
@ -163,17 +163,31 @@ export async function handleSubscriptionUpdate(subscription: any) {
try {
const priceId = subscription.items.data[0].price.id
const price = await dependencies.stripe.prices.retrieve(priceId)
console.log(`Updating subscription for user ${userId}`)
const currentPeriodStart = new Date()
let currentPeriodEnd
// Adjust currentPeriodEnd based on interval and interval_count if needed
if (price.recurring) {
const interval = price.recurring.interval
const interval_count = price.recurring.interval_count || 1
if (interval === "month") {
currentPeriodEnd = addMonthsSafe(currentPeriodStart, interval_count)
} else if (interval === "year") {
currentPeriodEnd = addMonthsSafe(currentPeriodStart, interval_count * 12)
}
}
// Prepare update data
const updateData: any = {
stripe_subscription_id: subscription.id,
plan_type: price.metadata?.tier || "pro",
optimizations_limit: parseInt(price.metadata?.optimizations || "500"),
optimizations_limit: SUBSCRIPTION_PLANS.PRO.optimizations,
subscription_status: subscription.status,
current_period_start: new Date(subscription.current_period_start * 1000),
current_period_end: new Date(subscription.current_period_end * 1000),
current_period_start: currentPeriodStart,
current_period_end: currentPeriodEnd,
optimizations_used: 0,
updated_at: new Date(),
}
@ -201,8 +215,8 @@ export async function handleSubscriptionUpdate(subscription: any) {
optimizations_limit: parseInt(price.metadata?.optimizations || "500"),
optimizations_used: 0,
subscription_status: subscription.status,
current_period_start: new Date(subscription.current_period_start * 1000),
current_period_end: new Date(subscription.current_period_end * 1000),
current_period_start: currentPeriodStart,
current_period_end: currentPeriodEnd,
created_at: new Date(),
updated_at: new Date(),
cancel_at_period_end: subscription.cancel_at_period_end || false,
@ -228,7 +242,7 @@ export async function handleSubscriptionCancellation(subscription: any) {
data: {
subscription_status: "canceled",
plan_type: "free",
optimizations_limit: 100,
optimizations_limit: SUBSCRIPTION_PLANS.FREE.optimizations,
stripe_subscription_id: null,
cancel_at_period_end: false,
cancellation_request_date: null,

View file

@ -2,6 +2,7 @@ import { Request, Response, NextFunction } from "express"
import {
prisma,
createCheckoutSession,
getSubscription as fetchSubscription,
cancelSubscription as cancelStripeSubscription,
} from "@codeflash-ai/common"
import * as Sentry from "@sentry/node"
@ -51,10 +52,8 @@ export async function getSubscription(req: Request, res: Response, next: NextFun
}
try {
// Get subscription with usage data
const subscription = await dependencies.prisma.subscriptions.findUnique({
where: { user_id: userId },
})
// Get subscription with usage data (includes lazy reset)
const subscription = await fetchSubscription(userId)
if (!subscription) {
return res.status(404).json({ error: "Subscription not found" })
@ -66,6 +65,7 @@ export async function getSubscription(req: Request, res: Response, next: NextFun
usageCount: subscription.optimizations_used,
usageLimit: subscription.optimizations_limit,
renewalDate: subscription.current_period_end,
totalLifetimeOptimizations: subscription.total_lifetime_optimizations,
})
} catch (error) {
console.error("Error getting subscription:", error)

View file

@ -1,16 +1,14 @@
import { Request, Response, NextFunction } from "express"
import { prisma } from "@codeflash-ai/common"
import { Response, NextFunction } from "express"
import { prisma, checkAndResetSubscriptionPeriod, SUBSCRIPTION_PLANS } from "@codeflash-ai/common"
import * as Sentry from "@sentry/node"
import { AuthorizedUserReq } from "../types.js"
export async function trackUsage(req: AuthorizedUserReq, res: Response, next: NextFunction) {
//ToDO: Sarthak: resolve errors for tslint
const userId = req.userId // Get userId from the request object (set by checkForValidAPIKey)
const userId = req.userId
if (!userId) {
return res.status(401).json({ error: "User ID is missing" })
}
try {
// Get subscription info for the user
const subscription = await prisma.subscriptions.findUnique({
@ -23,9 +21,9 @@ export async function trackUsage(req: AuthorizedUserReq, res: Response, next: Ne
data: {
user_id: userId,
plan_type: "free",
optimizations_limit: 100,
optimizations_limit: SUBSCRIPTION_PLANS.FREE.optimizations,
subscription_status: "active",
optimizations_used: 1,
optimizations_used: 0,
},
})
@ -48,26 +46,15 @@ export async function trackUsage(req: AuthorizedUserReq, res: Response, next: Ne
})
}
if (subscription.optimizations_used >= subscription.optimizations_limit) {
return res.status(403).json({
error: "Usage limit exceeded",
used: subscription.optimizations_used,
limit: subscription.optimizations_limit,
tier: subscription.plan_type,
})
}
// Increment usage
await prisma.subscriptions.update({
where: { user_id: userId },
data: { optimizations_used: { increment: 1 } },
})
// Check if we need to reset monthly usage (lazy reset)
const currentSubscription = await checkAndResetSubscriptionPeriod(userId)
const currentOptimizationsUsed = currentSubscription?.optimizations_used || 0
// Add subscription info to request for later use
req["subscriptionInfo"] = {
userId: userId,
tier: subscription.plan_type,
used: subscription.optimizations_used,
used: currentOptimizationsUsed,
limit: subscription.optimizations_limit,
}
next()

View file

@ -2,7 +2,6 @@
import { useEffect, useState } from "react"
import { Button } from "@/components/ui/button"
import { Card, CardHeader, CardTitle, CardContent, CardFooter } from "@/components/ui/card"
import { Progress } from "@/components/ui/progress"
import {
upgradeSubscription,
cancelSubscription,
@ -92,9 +91,6 @@ export function BillingView({ userId, subscription: initialSubscription, plans }
}
}
// Calculate usage percentage
const usagePercent = (subscription.optimizations_used / subscription.optimizations_limit) * 100
// Calculate time remaining on cancelled subscription
const daysRemaining = subscription.current_period_end
? Math.ceil(
@ -151,7 +147,6 @@ export function BillingView({ userId, subscription: initialSubscription, plans }
)}
</CardTitle>
</CardHeader>
<CardContent className="space-y-4">
{/* Display cancellation notice if applicable */}
{subscription.cancel_at_period_end && (
@ -180,13 +175,6 @@ export function BillingView({ userId, subscription: initialSubscription, plans }
</Button>
</div>
)}
<div>
<p className="text-sm text-gray-500 mb-1">Optimizations Usage</p>
<Progress value={usagePercent} className="h-2" />
<p className="text-sm mt-1">
{subscription.optimizations_used} / {subscription.optimizations_limit} optimizations
</p>
</div>
{subscription.current_period_end && (
<div>

View file

@ -1,11 +1,8 @@
"use server"
import { getSession } from "@auth0/nextjs-auth0"
import { PrismaClient } from "@prisma/client"
import { BillingView } from "./billing-view"
import PostHogClient from "@/lib/posthog"
import { SUBSCRIPTION_PLANS } from "@codeflash-ai/common"
const prisma = new PrismaClient()
import { SUBSCRIPTION_PLANS, checkAndResetSubscriptionPeriod } from "@codeflash-ai/common"
export default async function BillingPage() {
const session = await getSession()
@ -24,13 +21,11 @@ export default async function BillingPage() {
})
await posthog.shutdown()
// Get subscription info from database
const subscription = (await prisma.subscriptions.findUnique({
where: { user_id: userId },
})) || {
// Get subscription info from database with lazy reset
const subscription = (await checkAndResetSubscriptionPeriod(userId)) || {
plan_type: "free",
optimizations_used: 0,
optimizations_limit: 100,
optimizations_limit: SUBSCRIPTION_PLANS.FREE.optimizations,
subscription_status: "active",
}
@ -42,7 +37,7 @@ export default async function BillingPage() {
const fallbackSubscription = {
plan_type: "free",
optimizations_used: 0,
optimizations_limit: 100,
optimizations_limit: SUBSCRIPTION_PLANS.FREE.optimizations,
subscription_status: "active",
}

View file

@ -0,0 +1,5 @@
-- AlterTable
ALTER TABLE "subscriptions" ADD COLUMN "total_lifetime_optimizations" INTEGER NOT NULL DEFAULT 0;
-- Update existing rows to have total_lifetime_optimizations equal to their current optimizations_used
UPDATE "subscriptions" SET "total_lifetime_optimizations" = "optimizations_used" WHERE "total_lifetime_optimizations" = 0;

View file

@ -83,6 +83,7 @@ model subscriptions {
plan_type String
optimizations_used Int @default(0)
optimizations_limit Int
total_lifetime_optimizations Int @default(0)
subscription_status String
current_period_start DateTime?
current_period_end DateTime?

View file

@ -16,6 +16,7 @@ export { prisma } from "./prisma-client"
// Use prismaClient instead of new PrismaClient()
export * from "./subscription-functions"
export * from "./subscription-config"
export * from "./subscription-utils"
export * from "./stripe-client"
export * from "./optimization-event"
export * from "./cf-app-installations-functions"

View file

@ -1,13 +1,14 @@
// common/src/subscription-config.ts
// The optimizations: number of optimization tokens (units for the endpoints) included per plan
export const SUBSCRIPTION_PLANS = {
FREE: {
name: "Free",
optimizations: 100,
optimizations: 4000,
price: 0,
},
PRO: {
name: "Pro",
optimizations: 500,
optimizations: 100000,
monthlyPrice: 3000, // $30.00
yearlyPrice: 30000, // $300.00
},

View file

@ -1,6 +1,7 @@
import { stripe } from "./stripe-client"
import { prisma } from "./prisma-client"
import { SUBSCRIPTION_PLANS } from "./subscription-config"
import { checkAndResetSubscriptionPeriod } from "./subscription-utils"
/**
* Create a checkout session for a subscription
@ -110,9 +111,7 @@ export async function checkSubscriptionLimits(userId: string): Promise<{
tier: string
}> {
try {
let subscription = await prisma.subscriptions.findUnique({
where: { user_id: userId },
})
let subscription = await checkAndResetSubscriptionPeriod(userId)
// Create a free tier subscription if none exists
if (!subscription) {
@ -169,19 +168,24 @@ export async function resetMonthlyUsageCounts(): Promise<number> {
export async function getSubscription(userId: string) {
try {
// Directly fetch subscription from DB
const subscription = await prisma.subscriptions.findUnique({
where: { user_id: userId },
})
return (
subscription || {
if (!subscription) {
return {
plan_type: "free",
optimizations_used: 0,
optimizations_limit: 100,
optimizations_limit: SUBSCRIPTION_PLANS.FREE.optimizations,
subscription_status: "active",
cancel_at_period_end: false,
current_period_end: null,
total_lifetime_optimizations: 0,
}
)
}
return subscription
} catch (error) {
console.error("Error fetching subscription:", error)
return null

View file

@ -0,0 +1,52 @@
import { prisma } from "./prisma-client"
export async function checkAndResetSubscriptionPeriod(userId: string) {
const subscription = await prisma.subscriptions.findUnique({
where: { user_id: userId },
})
if (!subscription) {
return null
}
if (subscription.plan_type === "free") {
const now = new Date()
const periodEnd = subscription.current_period_end
if (periodEnd && now > periodEnd) {
// Period has ended, reset usage and update period
const { start, end } = getNextSubscriptionPeriod(subscription.current_period_end)
const updatedSubscription = await prisma.subscriptions.update({
where: { user_id: userId },
data: {
optimizations_used: 0,
current_period_start: start,
current_period_end: end,
},
})
return updatedSubscription
}
}
return subscription
}
// period-utils.ts
export function addMonthsSafe(date: Date, months: number) {
const result = new Date(date)
const day = result.getDate()
result.setDate(1)
result.setMonth(result.getMonth() + months)
const lastDay = new Date(result.getFullYear(), result.getMonth() + 1, 0).getDate()
result.setDate(Math.min(day, lastDay))
return result
}
export function getNextSubscriptionPeriod(periodEnd: Date) {
const start = addMonthsSafe(periodEnd, 1)
const end = addMonthsSafe(start, 1)
return { start, end }
}