mirror of
https://github.com/codeflash-ai/codeflash-internal.git
synced 2026-05-04 18:25:18 +00:00
Async subscription reset (#1703)
ToDO: Test update and sql migrations. Upstream dependent PR: https://github.com/codeflash-ai/codeflash-internal/pull/1700 ### How to Test? 1. Install Stripe: `curl -fsSL https://stripe.com/install.sh | sudo bash` 2. Create Stripe credentials and save them in your environment variables. 3. Run cf-api and start Stripe listening with: ` stripe listen --forward-to localhost:3001/webhook` 4. Save the webhook secret in your environment variables. 5. Use the CLI to optimize and check the quota. 6. Call the AI endpoint in Insomnia and verify the quota in cf-webAPP. --------- Co-authored-by: HeshamHM28 <HeshamMohamedFathy@outlook.com>
This commit is contained in:
parent
d1bade4dce
commit
d6167aa9ac
16 changed files with 445 additions and 65 deletions
2
.gitignore
vendored
2
.gitignore
vendored
|
|
@ -251,4 +251,6 @@ fabric.properties
|
|||
|
||||
# Android studio 3.1+ serialized cache file
|
||||
.idea/caches/build_file_checksums.ser
|
||||
|
||||
*/node_modules/*
|
||||
|
||||
|
|
|
|||
168
django/aiservice/aiservice/middleware/track_usage_middleware.py
Normal file
168
django/aiservice/aiservice/middleware/track_usage_middleware.py
Normal file
|
|
@ -0,0 +1,168 @@
|
|||
|
||||
from authapp.models import Subscriptions
|
||||
from django.http import JsonResponse
|
||||
from django.utils.decorators import async_only_middleware
|
||||
from asgiref.sync import iscoroutinefunction, markcoroutinefunction
|
||||
import logging
|
||||
import sentry_sdk
|
||||
|
||||
SUBSCRIPTION_PLANS = {
|
||||
"FREE": {
|
||||
"name": "Free",
|
||||
"optimizations": 4000,
|
||||
"price": 0,
|
||||
},
|
||||
"PRO": {
|
||||
"name": "Pro",
|
||||
"optimizations": 100000,
|
||||
"monthlyPrice": 3000, # $30.00
|
||||
"yearlyPrice": 30000, # $300.00
|
||||
},
|
||||
"ENTERPRISE": {
|
||||
"name": "Enterprise",
|
||||
"optimizations": "Custom",
|
||||
"price": "Custom",
|
||||
},
|
||||
}
|
||||
|
||||
PRICE_ID_KEYS = {
|
||||
"PRO_MONTHLY": "STRIPE_PRO_PRICE_MONTHLY_ID",
|
||||
"PRO_YEARLY": "STRIPE_PRO_PRICE_YEARLY_ID",
|
||||
}
|
||||
# subscription_utils.py
|
||||
import datetime
|
||||
from django.utils.timezone import now
|
||||
|
||||
|
||||
def add_months_safe(date: datetime.date, months: int) -> datetime.date:
|
||||
"""
|
||||
Safely add months to a date, preserving the day if possible.
|
||||
"""
|
||||
year = date.year + (date.month + months - 1) // 12
|
||||
month = (date.month + months - 1) % 12 + 1
|
||||
day = date.day
|
||||
|
||||
# Get the last valid day of the new month
|
||||
last_day = (datetime.date(year, month % 12 + 1, 1) - datetime.timedelta(days=1)).day
|
||||
return datetime.date(year, month, min(day, last_day))
|
||||
|
||||
|
||||
def get_next_subscription_period(period_end: datetime.date):
|
||||
start = add_months_safe(period_end, 1)
|
||||
end = add_months_safe(start, 1)
|
||||
return {"start": start, "end": end}
|
||||
|
||||
|
||||
async def check_and_reset_subscription_period(user_id: int):
|
||||
"""
|
||||
Check if the subscription period has ended.
|
||||
If yes, reset usage and move to the next billing cycle.
|
||||
"""
|
||||
subscription = await Subscriptions.objects.filter(user_id=user_id).afirst()
|
||||
if not subscription:
|
||||
return None
|
||||
|
||||
current_time = now()
|
||||
period_end = subscription.current_period_end
|
||||
# Pro User Stripe handles billing cycle, so we only need to manage FREE plan here
|
||||
if subscription.plan_type == "FREE":
|
||||
|
||||
if period_end and current_time > period_end:
|
||||
# Reset usage and roll period
|
||||
next_period = get_next_subscription_period(period_end)
|
||||
|
||||
subscription.optimizations_used = 0
|
||||
subscription.current_period_start = next_period["start"]
|
||||
subscription.current_period_end = next_period["end"]
|
||||
await subscription.asave(
|
||||
update_fields=["optimizations_used", "current_period_start", "current_period_end"]
|
||||
)
|
||||
return subscription
|
||||
|
||||
return subscription
|
||||
ENDPOINT_TOKEN_COST = {
|
||||
"optimize": 10,
|
||||
"optimize-line-profiler": 10,
|
||||
"testgen": 20,
|
||||
"refinement": 20,
|
||||
"explain": 10,
|
||||
}
|
||||
@async_only_middleware
|
||||
class TrackUsageMiddleware:
|
||||
def __init__(self, get_response) -> None:
|
||||
self.get_response = get_response
|
||||
if iscoroutinefunction(self.get_response):
|
||||
print("TrackUsageMiddleware is async.")
|
||||
markcoroutinefunction(self)
|
||||
|
||||
async def __call__(self, request):
|
||||
endpoint = request.path.replace("/ai/", "").split("?")[0]
|
||||
cost = ENDPOINT_TOKEN_COST.get(endpoint, 0)
|
||||
|
||||
user_id = getattr(request, "user", None)
|
||||
if not user_id:
|
||||
return JsonResponse({"error": "Authentication required for rate-limited endpoint"}, status=401)
|
||||
|
||||
try:
|
||||
# Get or create subscription
|
||||
subscription = await Subscriptions.objects.filter(user_id=user_id).afirst()
|
||||
|
||||
if not subscription:
|
||||
subscription = await Subscriptions.objects.acreate(
|
||||
user_id=user_id,
|
||||
plan_type="free",
|
||||
optimizations_limit=SUBSCRIPTION_PLANS["FREE"]["optimizations"],
|
||||
subscription_status="active",
|
||||
optimizations_used=cost,
|
||||
)
|
||||
request.subscription_info = {
|
||||
"userId": user_id,
|
||||
"tier": subscription.plan_type,
|
||||
"used": cost,
|
||||
"limit": subscription.optimizations_limit,
|
||||
}
|
||||
return await self.get_response(request)
|
||||
|
||||
if subscription.subscription_status != "active":
|
||||
return JsonResponse(
|
||||
{
|
||||
"error": "Subscription is not active",
|
||||
"status": subscription.subscription_status,
|
||||
},
|
||||
status=403,
|
||||
)
|
||||
|
||||
# Lazy reset monthly usage
|
||||
subscription = await check_and_reset_subscription_period(user_id)
|
||||
current_used = subscription.optimizations_used or 0
|
||||
|
||||
if current_used + cost > subscription.optimizations_limit:
|
||||
return JsonResponse(
|
||||
{
|
||||
"error": "Usage limit exceeded",
|
||||
"used": current_used,
|
||||
"limit": subscription.optimizations_limit,
|
||||
"tier": subscription.plan_type,
|
||||
},
|
||||
status=403,
|
||||
)
|
||||
|
||||
# Increment usage
|
||||
subscription.optimizations_used = current_used + cost
|
||||
subscription.total_lifetime_optimizations += cost
|
||||
await subscription.asave(update_fields=["optimizations_used", "total_lifetime_optimizations"])
|
||||
|
||||
# Attach subscription info to request
|
||||
request.subscription_info = {
|
||||
"userId": user_id,
|
||||
"tier": subscription.plan_type,
|
||||
"used": current_used + cost,
|
||||
"limit": subscription.optimizations_limit,
|
||||
}
|
||||
|
||||
return await self.get_response(request)
|
||||
|
||||
except Exception as e:
|
||||
sentry_sdk.capture_exception(e)
|
||||
logging.exception("Error tracking usage")
|
||||
return JsonResponse({"error": "Internal server error"}, status=500)
|
||||
|
|
@ -65,6 +65,7 @@ MIDDLEWARE: list[str] = [
|
|||
"aiservice.middleware.healthcheck.HealthCheckMiddleware",
|
||||
"aiservice.middleware.auth_middleware.AuthMiddleware",
|
||||
"aiservice.middleware.rate_limit.RateLimitMiddleware",
|
||||
"aiservice.middleware.track_usage_middleware.TrackUsageMiddleware"
|
||||
]
|
||||
|
||||
ROOT_URLCONF: str = "aiservice.urls"
|
||||
|
|
|
|||
160
django/aiservice/aiservice/tests/test_track_usage_middleware.py
Normal file
160
django/aiservice/aiservice/tests/test_track_usage_middleware.py
Normal file
|
|
@ -0,0 +1,160 @@
|
|||
import pytest
|
||||
from django.http import JsonResponse
|
||||
from aiservice.middleware.track_usage_middleware import TrackUsageMiddleware
|
||||
|
||||
|
||||
def parse_json(response: JsonResponse):
|
||||
import json
|
||||
return json.loads(response.content.decode())
|
||||
|
||||
|
||||
class FakeSubscription:
|
||||
def __init__(
|
||||
self,
|
||||
status="active",
|
||||
used=0,
|
||||
limit=100,
|
||||
lifetime=0,
|
||||
plan_type="free",
|
||||
current_period_end=None,
|
||||
):
|
||||
from datetime import datetime, timedelta, timezone
|
||||
self.subscription_status = status
|
||||
self.optimizations_used = used
|
||||
self.optimizations_limit = limit
|
||||
self.total_lifetime_optimizations = lifetime
|
||||
self.plan_type = plan_type
|
||||
self.current_period_end = current_period_end or (
|
||||
datetime.now(timezone.utc) + timedelta(days=30)
|
||||
)
|
||||
|
||||
async def asave(self, *args, **kwargs):
|
||||
return None
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def middleware():
|
||||
async def get_response(request):
|
||||
return JsonResponse({"ok": True})
|
||||
return TrackUsageMiddleware(get_response)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_requires_authentication(middleware, rf):
|
||||
request = rf.get("/ai/optimize")
|
||||
request.user = None
|
||||
|
||||
response = await middleware(request)
|
||||
data = parse_json(response)
|
||||
|
||||
assert response.status_code == 401
|
||||
assert "Authentication required" in data["error"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_creates_subscription_for_new_user(middleware, rf, monkeypatch):
|
||||
request = rf.get("/ai/optimize")
|
||||
request.user = type("User", (), {"id": 1})()
|
||||
|
||||
class FakeFilter:
|
||||
async def afirst(self): return None
|
||||
monkeypatch.setattr(
|
||||
"aiservice.middleware.track_usage_middleware.Subscriptions.objects.filter",
|
||||
lambda **kwargs: FakeFilter(),
|
||||
)
|
||||
|
||||
fake_sub = FakeSubscription()
|
||||
|
||||
async def fake_acreate(**kwargs): return fake_sub
|
||||
monkeypatch.setattr(
|
||||
"aiservice.middleware.track_usage_middleware.Subscriptions.objects.acreate",
|
||||
fake_acreate,
|
||||
)
|
||||
|
||||
response = await middleware(request)
|
||||
data = parse_json(response)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert data["ok"] is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_blocks_inactive_subscription(middleware, rf, monkeypatch):
|
||||
request = rf.get("/ai/optimize")
|
||||
request.user = type("User", (), {"id": 1})()
|
||||
|
||||
fake_sub = FakeSubscription(status="inactive")
|
||||
|
||||
class FakeFilter:
|
||||
async def afirst(self): return fake_sub
|
||||
monkeypatch.setattr(
|
||||
"aiservice.middleware.track_usage_middleware.Subscriptions.objects.filter",
|
||||
lambda **kwargs: FakeFilter(),
|
||||
)
|
||||
|
||||
response = await middleware(request)
|
||||
data = parse_json(response)
|
||||
|
||||
assert response.status_code == 403
|
||||
assert data["status"] == "inactive"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_usage_limit_exceeded(middleware, rf, monkeypatch):
|
||||
request = rf.get("/ai/optimize")
|
||||
request.user = type("User", (), {"id": 1})()
|
||||
|
||||
fake_sub = FakeSubscription(used=101, limit=100)
|
||||
|
||||
class FakeFilter:
|
||||
async def afirst(self): return fake_sub
|
||||
monkeypatch.setattr(
|
||||
"aiservice.middleware.track_usage_middleware.Subscriptions.objects.filter",
|
||||
lambda **kwargs: FakeFilter(),
|
||||
)
|
||||
|
||||
response = await middleware(request)
|
||||
data = parse_json(response)
|
||||
|
||||
assert response.status_code == 403
|
||||
assert "Usage limit exceeded" in data["error"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_allows_usage_within_limit(middleware, rf, monkeypatch):
|
||||
request = rf.get("/ai/optimize")
|
||||
request.user = type("User", (), {"id": 1})()
|
||||
|
||||
fake_sub = FakeSubscription(used=50, limit=100, lifetime=10)
|
||||
|
||||
class FakeFilter:
|
||||
async def afirst(self): return fake_sub
|
||||
monkeypatch.setattr(
|
||||
"aiservice.middleware.track_usage_middleware.Subscriptions.objects.filter",
|
||||
lambda **kwargs: FakeFilter(),
|
||||
)
|
||||
|
||||
response = await middleware(request)
|
||||
data = parse_json(response)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert data["ok"] is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_internal_error_handled(middleware, rf, monkeypatch):
|
||||
request = rf.get("/ai/optimize")
|
||||
request.user = type("User", (), {"id": 1})()
|
||||
|
||||
class FakeFilter:
|
||||
async def afirst(self): raise Exception("DB down")
|
||||
monkeypatch.setattr(
|
||||
"aiservice.middleware.track_usage_middleware.Subscriptions.objects.filter",
|
||||
lambda **kwargs: FakeFilter(),
|
||||
)
|
||||
|
||||
response = await middleware(request)
|
||||
data = parse_json(response)
|
||||
|
||||
assert response.status_code == 500
|
||||
assert "Internal server error" in data["error"]
|
||||
|
|
@ -40,6 +40,7 @@ class Subscriptions(models.Model):
|
|||
stripe_subscription_id = models.TextField(null=True, blank=True)
|
||||
plan_type = models.TextField()
|
||||
optimizations_used = models.IntegerField(default=0)
|
||||
total_lifetime_optimizations = models.IntegerField(default=0)
|
||||
optimizations_limit = models.IntegerField()
|
||||
subscription_status = models.TextField()
|
||||
current_period_start = models.DateTimeField(null=True, blank=True)
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
import { Request, Response } from "express"
|
||||
import { stripe } from "@codeflash-ai/common"
|
||||
import { addMonthsSafe, stripe, SUBSCRIPTION_PLANS } from "@codeflash-ai/common"
|
||||
import { prisma } from "@codeflash-ai/common"
|
||||
import * as Sentry from "@sentry/node"
|
||||
|
||||
|
|
@ -163,17 +163,31 @@ export async function handleSubscriptionUpdate(subscription: any) {
|
|||
try {
|
||||
const priceId = subscription.items.data[0].price.id
|
||||
const price = await dependencies.stripe.prices.retrieve(priceId)
|
||||
|
||||
console.log(`Updating subscription for user ${userId}`)
|
||||
|
||||
const currentPeriodStart = new Date()
|
||||
let currentPeriodEnd
|
||||
// Adjust currentPeriodEnd based on interval and interval_count if needed
|
||||
if (price.recurring) {
|
||||
const interval = price.recurring.interval
|
||||
const interval_count = price.recurring.interval_count || 1
|
||||
|
||||
if (interval === "month") {
|
||||
currentPeriodEnd = addMonthsSafe(currentPeriodStart, interval_count)
|
||||
} else if (interval === "year") {
|
||||
currentPeriodEnd = addMonthsSafe(currentPeriodStart, interval_count * 12)
|
||||
}
|
||||
}
|
||||
|
||||
// Prepare update data
|
||||
const updateData: any = {
|
||||
stripe_subscription_id: subscription.id,
|
||||
plan_type: price.metadata?.tier || "pro",
|
||||
optimizations_limit: parseInt(price.metadata?.optimizations || "500"),
|
||||
optimizations_limit: SUBSCRIPTION_PLANS.PRO.optimizations,
|
||||
subscription_status: subscription.status,
|
||||
current_period_start: new Date(subscription.current_period_start * 1000),
|
||||
current_period_end: new Date(subscription.current_period_end * 1000),
|
||||
current_period_start: currentPeriodStart,
|
||||
current_period_end: currentPeriodEnd,
|
||||
optimizations_used: 0,
|
||||
updated_at: new Date(),
|
||||
}
|
||||
|
||||
|
|
@ -201,8 +215,8 @@ export async function handleSubscriptionUpdate(subscription: any) {
|
|||
optimizations_limit: parseInt(price.metadata?.optimizations || "500"),
|
||||
optimizations_used: 0,
|
||||
subscription_status: subscription.status,
|
||||
current_period_start: new Date(subscription.current_period_start * 1000),
|
||||
current_period_end: new Date(subscription.current_period_end * 1000),
|
||||
current_period_start: currentPeriodStart,
|
||||
current_period_end: currentPeriodEnd,
|
||||
created_at: new Date(),
|
||||
updated_at: new Date(),
|
||||
cancel_at_period_end: subscription.cancel_at_period_end || false,
|
||||
|
|
@ -228,7 +242,7 @@ export async function handleSubscriptionCancellation(subscription: any) {
|
|||
data: {
|
||||
subscription_status: "canceled",
|
||||
plan_type: "free",
|
||||
optimizations_limit: 100,
|
||||
optimizations_limit: SUBSCRIPTION_PLANS.FREE.optimizations,
|
||||
stripe_subscription_id: null,
|
||||
cancel_at_period_end: false,
|
||||
cancellation_request_date: null,
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@ import { Request, Response, NextFunction } from "express"
|
|||
import {
|
||||
prisma,
|
||||
createCheckoutSession,
|
||||
getSubscription as fetchSubscription,
|
||||
cancelSubscription as cancelStripeSubscription,
|
||||
} from "@codeflash-ai/common"
|
||||
import * as Sentry from "@sentry/node"
|
||||
|
|
@ -51,10 +52,8 @@ export async function getSubscription(req: Request, res: Response, next: NextFun
|
|||
}
|
||||
|
||||
try {
|
||||
// Get subscription with usage data
|
||||
const subscription = await dependencies.prisma.subscriptions.findUnique({
|
||||
where: { user_id: userId },
|
||||
})
|
||||
// Get subscription with usage data (includes lazy reset)
|
||||
const subscription = await fetchSubscription(userId)
|
||||
|
||||
if (!subscription) {
|
||||
return res.status(404).json({ error: "Subscription not found" })
|
||||
|
|
@ -66,6 +65,7 @@ export async function getSubscription(req: Request, res: Response, next: NextFun
|
|||
usageCount: subscription.optimizations_used,
|
||||
usageLimit: subscription.optimizations_limit,
|
||||
renewalDate: subscription.current_period_end,
|
||||
totalLifetimeOptimizations: subscription.total_lifetime_optimizations,
|
||||
})
|
||||
} catch (error) {
|
||||
console.error("Error getting subscription:", error)
|
||||
|
|
|
|||
|
|
@ -1,16 +1,14 @@
|
|||
import { Request, Response, NextFunction } from "express"
|
||||
import { prisma } from "@codeflash-ai/common"
|
||||
import { Response, NextFunction } from "express"
|
||||
import { prisma, checkAndResetSubscriptionPeriod, SUBSCRIPTION_PLANS } from "@codeflash-ai/common"
|
||||
import * as Sentry from "@sentry/node"
|
||||
import { AuthorizedUserReq } from "../types.js"
|
||||
|
||||
export async function trackUsage(req: AuthorizedUserReq, res: Response, next: NextFunction) {
|
||||
//ToDO: Sarthak: resolve errors for tslint
|
||||
const userId = req.userId // Get userId from the request object (set by checkForValidAPIKey)
|
||||
const userId = req.userId
|
||||
|
||||
if (!userId) {
|
||||
return res.status(401).json({ error: "User ID is missing" })
|
||||
}
|
||||
|
||||
try {
|
||||
// Get subscription info for the user
|
||||
const subscription = await prisma.subscriptions.findUnique({
|
||||
|
|
@ -23,9 +21,9 @@ export async function trackUsage(req: AuthorizedUserReq, res: Response, next: Ne
|
|||
data: {
|
||||
user_id: userId,
|
||||
plan_type: "free",
|
||||
optimizations_limit: 100,
|
||||
optimizations_limit: SUBSCRIPTION_PLANS.FREE.optimizations,
|
||||
subscription_status: "active",
|
||||
optimizations_used: 1,
|
||||
optimizations_used: 0,
|
||||
},
|
||||
})
|
||||
|
||||
|
|
@ -48,26 +46,15 @@ export async function trackUsage(req: AuthorizedUserReq, res: Response, next: Ne
|
|||
})
|
||||
}
|
||||
|
||||
if (subscription.optimizations_used >= subscription.optimizations_limit) {
|
||||
return res.status(403).json({
|
||||
error: "Usage limit exceeded",
|
||||
used: subscription.optimizations_used,
|
||||
limit: subscription.optimizations_limit,
|
||||
tier: subscription.plan_type,
|
||||
})
|
||||
}
|
||||
|
||||
// Increment usage
|
||||
await prisma.subscriptions.update({
|
||||
where: { user_id: userId },
|
||||
data: { optimizations_used: { increment: 1 } },
|
||||
})
|
||||
// Check if we need to reset monthly usage (lazy reset)
|
||||
const currentSubscription = await checkAndResetSubscriptionPeriod(userId)
|
||||
const currentOptimizationsUsed = currentSubscription?.optimizations_used || 0
|
||||
|
||||
// Add subscription info to request for later use
|
||||
req["subscriptionInfo"] = {
|
||||
userId: userId,
|
||||
tier: subscription.plan_type,
|
||||
used: subscription.optimizations_used,
|
||||
used: currentOptimizationsUsed,
|
||||
limit: subscription.optimizations_limit,
|
||||
}
|
||||
next()
|
||||
|
|
|
|||
|
|
@ -2,7 +2,6 @@
|
|||
import { useEffect, useState } from "react"
|
||||
import { Button } from "@/components/ui/button"
|
||||
import { Card, CardHeader, CardTitle, CardContent, CardFooter } from "@/components/ui/card"
|
||||
import { Progress } from "@/components/ui/progress"
|
||||
import {
|
||||
upgradeSubscription,
|
||||
cancelSubscription,
|
||||
|
|
@ -92,9 +91,6 @@ export function BillingView({ userId, subscription: initialSubscription, plans }
|
|||
}
|
||||
}
|
||||
|
||||
// Calculate usage percentage
|
||||
const usagePercent = (subscription.optimizations_used / subscription.optimizations_limit) * 100
|
||||
|
||||
// Calculate time remaining on cancelled subscription
|
||||
const daysRemaining = subscription.current_period_end
|
||||
? Math.ceil(
|
||||
|
|
@ -151,7 +147,6 @@ export function BillingView({ userId, subscription: initialSubscription, plans }
|
|||
)}
|
||||
</CardTitle>
|
||||
</CardHeader>
|
||||
|
||||
<CardContent className="space-y-4">
|
||||
{/* Display cancellation notice if applicable */}
|
||||
{subscription.cancel_at_period_end && (
|
||||
|
|
@ -180,13 +175,6 @@ export function BillingView({ userId, subscription: initialSubscription, plans }
|
|||
</Button>
|
||||
</div>
|
||||
)}
|
||||
<div>
|
||||
<p className="text-sm text-gray-500 mb-1">Optimizations Usage</p>
|
||||
<Progress value={usagePercent} className="h-2" />
|
||||
<p className="text-sm mt-1">
|
||||
{subscription.optimizations_used} / {subscription.optimizations_limit} optimizations
|
||||
</p>
|
||||
</div>
|
||||
|
||||
{subscription.current_period_end && (
|
||||
<div>
|
||||
|
|
|
|||
|
|
@ -1,11 +1,8 @@
|
|||
"use server"
|
||||
import { getSession } from "@auth0/nextjs-auth0"
|
||||
import { PrismaClient } from "@prisma/client"
|
||||
import { BillingView } from "./billing-view"
|
||||
import PostHogClient from "@/lib/posthog"
|
||||
import { SUBSCRIPTION_PLANS } from "@codeflash-ai/common"
|
||||
|
||||
const prisma = new PrismaClient()
|
||||
import { SUBSCRIPTION_PLANS, checkAndResetSubscriptionPeriod } from "@codeflash-ai/common"
|
||||
|
||||
export default async function BillingPage() {
|
||||
const session = await getSession()
|
||||
|
|
@ -24,13 +21,11 @@ export default async function BillingPage() {
|
|||
})
|
||||
await posthog.shutdown()
|
||||
|
||||
// Get subscription info from database
|
||||
const subscription = (await prisma.subscriptions.findUnique({
|
||||
where: { user_id: userId },
|
||||
})) || {
|
||||
// Get subscription info from database with lazy reset
|
||||
const subscription = (await checkAndResetSubscriptionPeriod(userId)) || {
|
||||
plan_type: "free",
|
||||
optimizations_used: 0,
|
||||
optimizations_limit: 100,
|
||||
optimizations_limit: SUBSCRIPTION_PLANS.FREE.optimizations,
|
||||
subscription_status: "active",
|
||||
}
|
||||
|
||||
|
|
@ -42,7 +37,7 @@ export default async function BillingPage() {
|
|||
const fallbackSubscription = {
|
||||
plan_type: "free",
|
||||
optimizations_used: 0,
|
||||
optimizations_limit: 100,
|
||||
optimizations_limit: SUBSCRIPTION_PLANS.FREE.optimizations,
|
||||
subscription_status: "active",
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,5 @@
|
|||
-- AlterTable
|
||||
ALTER TABLE "subscriptions" ADD COLUMN "total_lifetime_optimizations" INTEGER NOT NULL DEFAULT 0;
|
||||
|
||||
-- Update existing rows to have total_lifetime_optimizations equal to their current optimizations_used
|
||||
UPDATE "subscriptions" SET "total_lifetime_optimizations" = "optimizations_used" WHERE "total_lifetime_optimizations" = 0;
|
||||
|
|
@ -83,6 +83,7 @@ model subscriptions {
|
|||
plan_type String
|
||||
optimizations_used Int @default(0)
|
||||
optimizations_limit Int
|
||||
total_lifetime_optimizations Int @default(0)
|
||||
subscription_status String
|
||||
current_period_start DateTime?
|
||||
current_period_end DateTime?
|
||||
|
|
|
|||
|
|
@ -16,6 +16,7 @@ export { prisma } from "./prisma-client"
|
|||
// Use prismaClient instead of new PrismaClient()
|
||||
export * from "./subscription-functions"
|
||||
export * from "./subscription-config"
|
||||
export * from "./subscription-utils"
|
||||
export * from "./stripe-client"
|
||||
export * from "./optimization-event"
|
||||
export * from "./cf-app-installations-functions"
|
||||
|
|
|
|||
|
|
@ -1,13 +1,14 @@
|
|||
// common/src/subscription-config.ts
|
||||
// The optimizations: number of optimization tokens (units for the endpoints) included per plan
|
||||
export const SUBSCRIPTION_PLANS = {
|
||||
FREE: {
|
||||
name: "Free",
|
||||
optimizations: 100,
|
||||
optimizations: 4000,
|
||||
price: 0,
|
||||
},
|
||||
PRO: {
|
||||
name: "Pro",
|
||||
optimizations: 500,
|
||||
optimizations: 100000,
|
||||
monthlyPrice: 3000, // $30.00
|
||||
yearlyPrice: 30000, // $300.00
|
||||
},
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
import { stripe } from "./stripe-client"
|
||||
import { prisma } from "./prisma-client"
|
||||
import { SUBSCRIPTION_PLANS } from "./subscription-config"
|
||||
import { checkAndResetSubscriptionPeriod } from "./subscription-utils"
|
||||
|
||||
/**
|
||||
* Create a checkout session for a subscription
|
||||
|
|
@ -110,9 +111,7 @@ export async function checkSubscriptionLimits(userId: string): Promise<{
|
|||
tier: string
|
||||
}> {
|
||||
try {
|
||||
let subscription = await prisma.subscriptions.findUnique({
|
||||
where: { user_id: userId },
|
||||
})
|
||||
let subscription = await checkAndResetSubscriptionPeriod(userId)
|
||||
|
||||
// Create a free tier subscription if none exists
|
||||
if (!subscription) {
|
||||
|
|
@ -169,19 +168,24 @@ export async function resetMonthlyUsageCounts(): Promise<number> {
|
|||
|
||||
export async function getSubscription(userId: string) {
|
||||
try {
|
||||
// Directly fetch subscription from DB
|
||||
const subscription = await prisma.subscriptions.findUnique({
|
||||
where: { user_id: userId },
|
||||
})
|
||||
|
||||
return (
|
||||
subscription || {
|
||||
if (!subscription) {
|
||||
return {
|
||||
plan_type: "free",
|
||||
optimizations_used: 0,
|
||||
optimizations_limit: 100,
|
||||
optimizations_limit: SUBSCRIPTION_PLANS.FREE.optimizations,
|
||||
subscription_status: "active",
|
||||
cancel_at_period_end: false,
|
||||
current_period_end: null,
|
||||
total_lifetime_optimizations: 0,
|
||||
}
|
||||
)
|
||||
}
|
||||
|
||||
return subscription
|
||||
} catch (error) {
|
||||
console.error("Error fetching subscription:", error)
|
||||
return null
|
||||
|
|
|
|||
52
js/common/src/subscription-utils.ts
Normal file
52
js/common/src/subscription-utils.ts
Normal file
|
|
@ -0,0 +1,52 @@
|
|||
import { prisma } from "./prisma-client"
|
||||
|
||||
export async function checkAndResetSubscriptionPeriod(userId: string) {
|
||||
const subscription = await prisma.subscriptions.findUnique({
|
||||
where: { user_id: userId },
|
||||
})
|
||||
|
||||
if (!subscription) {
|
||||
return null
|
||||
}
|
||||
if (subscription.plan_type === "free") {
|
||||
const now = new Date()
|
||||
const periodEnd = subscription.current_period_end
|
||||
|
||||
if (periodEnd && now > periodEnd) {
|
||||
// Period has ended, reset usage and update period
|
||||
const { start, end } = getNextSubscriptionPeriod(subscription.current_period_end)
|
||||
|
||||
const updatedSubscription = await prisma.subscriptions.update({
|
||||
where: { user_id: userId },
|
||||
data: {
|
||||
optimizations_used: 0,
|
||||
current_period_start: start,
|
||||
current_period_end: end,
|
||||
},
|
||||
})
|
||||
|
||||
return updatedSubscription
|
||||
}
|
||||
}
|
||||
return subscription
|
||||
}
|
||||
|
||||
// period-utils.ts
|
||||
export function addMonthsSafe(date: Date, months: number) {
|
||||
const result = new Date(date)
|
||||
const day = result.getDate()
|
||||
|
||||
result.setDate(1)
|
||||
result.setMonth(result.getMonth() + months)
|
||||
|
||||
const lastDay = new Date(result.getFullYear(), result.getMonth() + 1, 0).getDate()
|
||||
result.setDate(Math.min(day, lastDay))
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
export function getNextSubscriptionPeriod(periodEnd: Date) {
|
||||
const start = addMonthsSafe(periodEnd, 1)
|
||||
const end = addMonthsSafe(start, 1)
|
||||
return { start, end }
|
||||
}
|
||||
Loading…
Reference in a new issue