First deployment ready version

This commit is contained in:
Saurabh Misra 2023-12-20 17:07:24 -08:00
parent bcff8d6508
commit 121c02b286
6 changed files with 73 additions and 24 deletions

View file

@ -14,6 +14,7 @@ import os
from pathlib import Path
import dj_database_url
import dotenv
# Build paths inside the project like this: BASE_DIR / 'subdir'.
BASE_DIR = Path(__file__).resolve().parent.parent
@ -23,17 +24,25 @@ BASE_DIR = Path(__file__).resolve().parent.parent
# See https://docs.djangoproject.com/en/5.0/howto/deployment/checklist/
# SECURITY WARNING: keep the secret key used in production secret!
SECRET_KEY = "django-insecure-g+z1*fom2&5^$7q20a^s@677+h*hc((xrx9g1_*935t&o+hven"
# SECURITY WARNING: don't run with debug turned on in production!
DEBUG = True
if os.environ.get("ENVIRONMENT") == "PRODUCTION":
SECRET_KEY = os.environ.get("SECRET_KEY")
DEBUG = False
else:
env_vars = dotenv.dotenv_values()
SECRET_KEY = env_vars["SECRET_KEY"]
DATABASE_URL = env_vars["DATABASE_URL"]
DEBUG = True
assert "DATABASE_URL" in os.environ, "DATABASE_URL environment variable not set"
assert "SECRET_KEY" in os.environ, "SECRET_KEY environment variable not set"
ALLOWED_HOSTS = []
# Application definition
INSTALLED_APPS = [
"testgen.apps.TestgenConfig",
"optimizer.apps.OptimizerConfig",
"authapp.apps.AuthAppConfig",
"django.contrib.admin",
@ -78,7 +87,7 @@ WSGI_APPLICATION = "aiservice.wsgi.application"
# Database
# https://docs.djangoproject.com/en/5.0/ref/settings/#databases
# Requires DATABASE_URL environment variable to be set
assert "DATABASE_URL" in os.environ, "DATABASE_URL environment variable not set"
DATABASES = {
"default": dj_database_url.config(
conn_max_age=600,

View file

@ -16,9 +16,12 @@ Including another URLconf
"""
from django.contrib import admin
from django.urls import path
from optimizer.optimizer import optimize_api
from testgen.testgen import testgen_api
urlpatterns = [
path("admin/", admin.site.urls),
path("ai/", optimize_api.urls),
path("ai/optimize", optimize_api.urls),
path("ai/testgen", testgen_api.urls),
]

View file

@ -1,10 +1,11 @@
import base64
import hashlib
from authapp.models import CFAPIKeys
from django.shortcuts import aget_object_or_404
from django.db.models.functions import Now
from ninja.security import HttpBearer
from authapp.models import CFAPIKeys
class AuthBearer(HttpBearer):
async def authenticate(self, request, token):
@ -13,11 +14,17 @@ class AuthBearer(HttpBearer):
hashlib.sha384(token.encode("utf-8")).digest()
).decode("utf-8")
try:
# TODO: Convert this to update the last_updated field when accessed
user = await aget_object_or_404(
CFAPIKeys.objects.only("user_id"), key=hashed_token
num_users = await CFAPIKeys.objects.filter(key=hashed_token).aupdate(
last_used=Now()
)
if num_users == 0:
return
elif num_users == 1:
return token
else:
print(
"THIS SHOULD NOT HAPPEN! More than one users found in the db with the same api key!"
)
return token
except CFAPIKeys.DoesNotExist:
raise Exception("Invalid API Key")
if user.user_id is not None:
return token
return

View file

@ -25,7 +25,10 @@ class CFAPIKeys(models.Model):
suffix = models.CharField(max_length=4)
name = models.CharField(max_length=255)
created_at = models.DateTimeField(auto_now_add=True)
last_used = models.DateTimeField(null=True, blank=True)
last_used = models.DateTimeField(
null=True,
blank=True,
)
user_id = models.TextField(null=True, blank=True)
org_id = models.TextField(null=True, blank=True)

View file

@ -2,12 +2,13 @@ import os
import re
from typing import List, Tuple
from authapp.auth import AuthBearer
from dotenv import load_dotenv
from ninja import NinjaAPI, Schema
from openai import AsyncOpenAI, APIError
optimize_api = NinjaAPI(auth=AuthBearer())
from authapp.auth import AuthBearer
optimize_api = NinjaAPI(auth=AuthBearer(), urls_namespace="optimize")
if os.environ.get("ENVIRONMENT") != "PRODUCTION":
load_dotenv()
@ -73,7 +74,7 @@ class OptimizeResponseSchema(Schema):
explanation: str
@optimize_api.post("/optimize")
@optimize_api.post("/")
async def optimize(request, data: OptimizeSchema):
optimizations = await optimize_python_code(data.source_code, n=10)
if len(optimizations) == 0 or optimizations[0][0] == "":

View file

@ -11,7 +11,7 @@ from openai import AsyncOpenAI
from authapp.auth import AuthBearer
from testgen.aimodels import EXPLAIN_MODEL, PLAN_MODEL, EXECUTE_MODEL, LLM
testgen_api = NinjaAPI(auth=AuthBearer())
testgen_api = NinjaAPI(auth=AuthBearer(), urls_namespace="testgen")
if os.environ.get("ENVIRONMENT") != "PRODUCTION":
load_dotenv()
@ -92,7 +92,7 @@ async def regression_tests_from_function(
print_messages(explain_messages)
try:
explanation_response = await openai_client.with_options(
max_retries=3
max_retries=2
).chat.completions.create(
model=explain_model.name, messages=explain_messages, temperature=temperature
)
@ -126,7 +126,7 @@ To help unit test the function above, list diverse scenarios that the function s
print_messages([plan_user_message])
try:
plan_response = await openai_client.with_options(
max_retries=3
max_retries=2
).chat.completions.create(
model=plan_model.name, messages=plan_messages, temperature=temperature
)
@ -157,7 +157,7 @@ To help unit test the function above, list diverse scenarios that the function s
print_messages([elaboration_user_message])
try:
elaboration_response = await openai_client.with_options(
max_retries=3
max_retries=2
).chat.completions.create(
model=plan_model.name,
messages=elaboration_messages,
@ -214,7 +214,7 @@ import {unit_test_package} # used for our unit tests
while tries > 0:
try:
execute_response = await openai_client.with_options(
max_retries=3
max_retries=2
).chat.completions.create(
model=execute_model.name,
messages=execute_messages,
@ -233,6 +233,7 @@ import {unit_test_package} # used for our unit tests
# If the test generator is generating ellipsis, it is punting on generating
# the concrete test cases and we should re-generate
raise SyntaxError("Ellipsis in generated test code, regenerating...")
break
except SyntaxError as e:
tries -= 1
logging.warning(f"Syntax error in generated code: {e}")
@ -256,8 +257,33 @@ class TestGenResponseSchema(Schema):
code: str
@testgen_api.post("/testgen")
async def testgen(request, data: TestGenSchema):
class TestGenErrorResponseSchema(Schema):
detail: str
@testgen_api.post(
"/", response={200: TestGenResponseSchema, 400: TestGenErrorResponseSchema}
)
async def testgen(
request,
data: TestGenSchema,
):
if data.test_framework not in ["unittest", "pytest"]:
return 400, {
"detail": "Invalid test framework. We only support unittest and pytest."
}
if data.function_name == "":
# TODO: Add a validation check here to see if the function_name is actually present in
# the source_code_being_tested. Parse ast
return 400, {"detail": "Invalid function name."}
if data.source_code_being_tested == "":
return 400, {"detail": "Invalid source code. It is empty."}
try:
ast.parse(data.source_code_being_tested)
except SyntaxError as e:
return 400, {
"detail": "Invalid source code. It is not valid Python code. Please check syntax of your code."
}
try:
generated_test_source = await regression_tests_from_function(
function_code=data.source_code_being_tested,