Merge branch 'main' into cf-1083-java-jit-awareness-prompt

This commit is contained in:
mashraf-222 2026-04-08 00:05:34 +02:00 committed by GitHub
commit c91b81da63
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
146 changed files with 10715 additions and 5290 deletions

View file

@ -0,0 +1,145 @@
# Quality gates for the cf-webapp package: type-check, unit tests, and a
# production build, with a summary comment posted back to the pull request.
name: cf-webapp Quality Gates
on:
  pull_request:
    paths:
      - "js/cf-webapp/**"
permissions:
  contents: read
  packages: read
  pull-requests: write
# Cancel superseded runs on the same ref to save CI minutes.
concurrency:
  group: ${{ github.workflow }}-${{ github.ref }}
  cancel-in-progress: true
jobs:
  # Re-check which paths changed (defense in depth next to `on.paths`, and
  # gives downstream jobs an explicit output to gate on).
  check-changes:
    runs-on: ubuntu-latest
    outputs:
      should-run: ${{ steps.filter.outputs.webapp }}
    steps:
      - uses: actions/checkout@v4
        with:
          fetch-depth: 0
      - uses: dorny/paths-filter@v3
        id: filter
        with:
          filters: |
            webapp:
              - 'js/cf-webapp/**'
  # Report a green job when nothing relevant changed (keeps required checks satisfied).
  skip:
    needs: check-changes
    if: needs.check-changes.outputs.should-run != 'true'
    runs-on: ubuntu-latest
    steps:
      - run: echo "No cf-webapp changes, skipping."
  benchmark:
    needs: check-changes
    if: needs.check-changes.outputs.should-run == 'true'
    runs-on: ubuntu-latest
    env:
      NODE_AUTH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
    steps:
      - uses: actions/checkout@v4
      - uses: actions/setup-node@v4
        with:
          node-version: "20"
          cache: npm
          cache-dependency-path: js/cf-webapp/package-lock.json
          registry-url: https://npm.pkg.github.com
          scope: "@codeflash-ai"
      - name: Install dependencies
        working-directory: js/cf-webapp
        # --ignore-scripts: don't run dependencies' postinstall scripts in CI.
        run: npm ci --ignore-scripts
      - name: Generate Prisma client
        working-directory: js/cf-webapp
        run: npx prisma generate
      # The three checks below use continue-on-error so all of them run and the
      # PR comment can report every failure at once; the final step restores a
      # failing exit code if any check failed.
      - name: Type-check
        id: typecheck
        working-directory: js/cf-webapp
        run: npx tsc --noEmit
        continue-on-error: true
      - name: Tests
        id: tests
        working-directory: js/cf-webapp
        # Actions' default bash shell sets pipefail, so vitest's exit code survives the tee.
        run: npx vitest run --reporter=verbose 2>&1 | tee test-output.txt
        continue-on-error: true
      - name: Build
        id: build
        working-directory: js/cf-webapp
        run: npx next build 2>&1 | tee build-output.txt
        continue-on-error: true
      - name: Extract results
        id: results
        working-directory: js/cf-webapp
        run: |
          # Type-check status
          if [ "${{ steps.typecheck.outcome }}" = "success" ]; then
            echo "typecheck_status=✅ Pass" >> "$GITHUB_OUTPUT"
          else
            echo "typecheck_status=❌ Fail" >> "$GITHUB_OUTPUT"
          fi
          # Test summary (last "Tests  N ..." line of the vitest output)
          if [ "${{ steps.tests.outcome }}" = "success" ]; then
            TESTS_SUMMARY=$(grep -E "Tests\s+[0-9]+" test-output.txt | tail -1 || echo "passed")
            echo "tests_status=✅ ${TESTS_SUMMARY}" >> "$GITHUB_OUTPUT"
          else
            echo "tests_status=❌ Tests failed" >> "$GITHUB_OUTPUT"
          fi
          # Build status
          if [ "${{ steps.build.outcome }}" = "success" ]; then
            echo "build_status=✅ Success" >> "$GITHUB_OUTPUT"
          else
            echo "build_status=❌ Fail" >> "$GITHUB_OUTPUT"
          fi
          # Route-size table from the Next.js build output (multiline output)
          ROUTES=$(sed -n '/Route.*Size.*First Load/,/^$/p' build-output.txt | head -30 || echo "No route data")
          {
            echo "routes<<ROUTES_EOF"
            echo "$ROUTES"
            echo "ROUTES_EOF"
          } >> "$GITHUB_OUTPUT"
      - name: Post PR comment
        if: github.event_name == 'pull_request'
        env:
          GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
          # Pass step outputs via the environment rather than interpolating
          # `${{ }}` into the script body: build/test output is influenced by
          # PR content, and inlining it would allow shell-script injection.
          TYPECHECK_STATUS: ${{ steps.results.outputs.typecheck_status }}
          TESTS_STATUS: ${{ steps.results.outputs.tests_status }}
          BUILD_STATUS: ${{ steps.results.outputs.build_status }}
          ROUTES: ${{ steps.results.outputs.routes }}
          PR_NUMBER: ${{ github.event.pull_request.number }}
        run: |
          {
            echo "## cf-webapp Quality Report"
            echo "| Check | Result |"
            echo "|-------|--------|"
            echo "| Type-check | ${TYPECHECK_STATUS} |"
            echo "| Tests | ${TESTS_STATUS} |"
            echo "| Build | ${BUILD_STATUS} |"
            echo "<details>"
            echo "<summary>Route Sizes</summary>"
            echo ""
            echo '```'
            echo "${ROUTES}"
            echo '```'
            echo "</details>"
          } > comment.md
          gh pr comment "${PR_NUMBER}" --repo "${{ github.repository }}" --body-file comment.md
      - name: Fail if any check failed
        if: steps.typecheck.outcome == 'failure' || steps.tests.outcome == 'failure' || steps.build.outcome == 'failure'
        run: exit 1

View file

@ -0,0 +1,34 @@
from __future__ import annotations
import asyncio
import logging
from typing import Any, Coroutine
logger = logging.getLogger(__name__)
_background_tasks: set[asyncio.Task[Any]] = set()
def fire_and_forget(coro: Coroutine[Any, Any, Any]) -> asyncio.Task[Any]:
"""Schedule a coroutine as a background task without blocking the caller.
Holds a strong reference so the task is not garbage-collected before completion.
"""
task = asyncio.create_task(coro)
_background_tasks.add(task)
def _on_done(t: asyncio.Task[Any]) -> None:
_background_tasks.discard(t)
if t.cancelled():
return
if exc := t.exception():
logger.warning("Background task failed: %s", exc)
task.add_done_callback(_on_done)
return task
async def drain() -> None:
"""Await all pending background tasks. Call during shutdown or test teardown."""
while _background_tasks:
await asyncio.gather(*_background_tasks, return_exceptions=True)

View file

@ -13,13 +13,13 @@ from aiservice.common.llm_output_utils import truncate_pathological_output
# Matches both ```python and ```python:filepath blocks, captures content only
MARKDOWN_CODE_BLOCK_PATTERN = re.compile(r"```python(?::[^\n]*)?\n(.*?)```", re.DOTALL)
# Matches first ```python block (no filepath), captures content.
# Matches first ```python block (with optional :filepath), captures content.
# Uses greedy (.*) to handle LLM outputs with nested code fences (e.g. ```python:filepath
# blocks inside the main block). Requires closing ``` to be alone on its line.
FIRST_CODE_BLOCK_PATTERN = re.compile(r"^```python\s*\n(.*)\n```[ \t]*$", re.MULTILINE | re.DOTALL)
FIRST_CODE_BLOCK_PATTERN = re.compile(r"^```python(?::[^\n]*)?\s*\n(.*)\n```[ \t]*$", re.MULTILINE | re.DOTALL)
# Fallback for incomplete code blocks (missing closing ```)
FIRST_CODE_BLOCK_FALLBACK_PATTERN = re.compile(r"^```python\s*\n(.*)", re.MULTILINE | re.DOTALL)
FIRST_CODE_BLOCK_FALLBACK_PATTERN = re.compile(r"^```python(?::[^\n]*)?\s*\n(.*)", re.MULTILINE | re.DOTALL)
def extract_all_code_from_markdown(markdown: str) -> str:

View file

@ -9,22 +9,18 @@ import time
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any
import stamina
import sentry_sdk
from anthropic import (
APIConnectionError as AnthropicConnectionError,
APITimeoutError as AnthropicTimeoutError,
AsyncAnthropicBedrock,
InternalServerError as AnthropicServerError,
RateLimitError as AnthropicRateLimitError,
)
from openai import (
APIConnectionError as OpenAIConnectionError,
APITimeoutError as OpenAITimeoutError,
AsyncAzureOpenAI,
InternalServerError as OpenAIServerError,
RateLimitError as OpenAIRateLimitError,
)
import stamina
from anthropic import APIConnectionError as AnthropicConnectionError
from anthropic import APITimeoutError as AnthropicTimeoutError
from anthropic import AsyncAnthropicBedrock
from anthropic import InternalServerError as AnthropicServerError
from anthropic import RateLimitError as AnthropicRateLimitError
from openai import APIConnectionError as OpenAIConnectionError
from openai import APITimeoutError as OpenAITimeoutError
from openai import AsyncAzureOpenAI
from openai import InternalServerError as OpenAIServerError
from openai import RateLimitError as OpenAIRateLimitError
from aiservice.llm_models import has_anthropic, has_openai
@ -36,6 +32,9 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
_ANTHROPIC_MAX_INPUT_TOKENS = 195_000
_CHARS_PER_TOKEN_ESTIMATE = 4
_TRANSIENT_LLM_ERRORS = (
AnthropicConnectionError,
AnthropicTimeoutError,
@ -48,6 +47,14 @@ _TRANSIENT_LLM_ERRORS = (
)
class LLMOutputUnparseable(Exception):
    """The LLM replied, but its output could not be parsed into the expected format.

    Carries the spend already incurred for the failed call so callers can
    still account for LLM cost when retrying or reporting the error.
    """

    def __init__(self, message: str = "LLM output could not be parsed", *, cost: float = 0.0) -> None:
        super().__init__(message)
        # Cost attributed to the unparseable call (defaults to nothing billed).
        self.cost = cost
@dataclass
class LLMUsage:
input_tokens: int
@ -67,7 +74,7 @@ class LLMClient:
self.openai_client: AsyncAzureOpenAI | None = None
self.anthropic_client: AsyncAnthropicBedrock | None = None
self.client_loop: asyncio.AbstractEventLoop | None = None
self.background_tasks: set[asyncio.Task] = set()
self.background_tasks: set[asyncio.Task[Any]] = set()
async def call(
self,
@ -85,6 +92,23 @@ class LLMClient:
# Recreate provider clients when the event loop changes (stale connections)
loop = asyncio.get_running_loop()
if loop is not self.client_loop:
# Close old clients to prevent connection leaks and event loop closure errors
# Ignore errors if the client is already closed or the transport is in a bad state
if self.openai_client is not None:
try:
await self.openai_client.close()
except Exception as e:
logger.debug(
"Failed to close OpenAI client (already closed or transport error): %s", type(e).__name__
)
if self.anthropic_client is not None:
try:
await self.anthropic_client.close()
except Exception as e:
logger.debug(
"Failed to close Anthropic client (already closed or transport error): %s", type(e).__name__
)
self.client_loop = loop
self.background_tasks = set()
self.openai_client = AsyncAzureOpenAI() if has_openai else None
@ -159,6 +183,11 @@ class LLMClient:
async def call_anthropic(
self, llm: LLM, messages: list[ChatCompletionMessageParam], max_tokens: int
) -> LLMResponse:
estimated_tokens = sum(len(str(m["content"])) for m in messages) // _CHARS_PER_TOKEN_ESTIMATE
if estimated_tokens > _ANTHROPIC_MAX_INPUT_TOKENS:
msg = f"Prompt too large (~{estimated_tokens} tokens estimated, limit {_ANTHROPIC_MAX_INPUT_TOKENS})"
raise ValueError(msg)
system_prompt = next((m["content"] for m in messages if m["role"] == "system"), None)
non_system = [{"role": m["role"], "content": m["content"]} for m in messages if m["role"] != "system"]

View file

@ -141,6 +141,12 @@ class TrackUsageMiddleware:
return JsonResponse({"error": "Failed to initialize user subscription"}, status=500)
if subscription.subscription_status != "active":
logging.warning(
"403 subscription inactive: user_id=%s, status=%s, endpoint=%s",
user_id,
subscription.subscription_status,
endpoint,
)
return JsonResponse(
{"error": "Subscription is not active", "status": subscription.subscription_status}, status=403
)
@ -150,6 +156,14 @@ class TrackUsageMiddleware:
current_used = subscription.optimizations_used or 0
if current_used + cost > subscription.optimizations_limit:
logging.warning(
"403 usage limit exceeded: user_id=%s, used=%s, limit=%s, tier=%s, endpoint=%s",
user_id,
current_used,
subscription.optimizations_limit,
subscription.plan_type,
endpoint,
)
return JsonResponse(
{
"error": "Usage limit exceeded",

View file

@ -82,7 +82,7 @@ assert "DATABASE_URL" in os.environ, "DATABASE_URL environment variable not set"
DATABASES = {"default": dj_database_url.config(conn_max_age=0)}
# psycopg3 native connection pooling — replaces persistent connections (conn_max_age)
# which don't work correctly under ASGI. The pool is managed by psycopg_pool.ConnectionPool.
DATABASES["default"]["OPTIONS"] = {"pool": {"min_size": 2, "max_size": 10}}
DATABASES["default"]["OPTIONS"] = {"pool": {"min_size": 2, "max_size": 100}}
# Password validation
@ -118,11 +118,26 @@ STATIC_URL: str = "static/"
DEFAULT_AUTO_FIELD: str = "django.db.models.BigAutoField"
# Logging — explicit config prevents unbounded record buffering.
# Django's default adds AdminEmailHandler which buffers; Sentry's LoggingIntegration
# adds handlers that capture every record. StreamHandler writes + flushes immediately.
# django.request at ERROR skips 4xx WARNING logs whose args pin the full ASGIRequest
# (headers, body, payload) in memory for the lifetime of the LogRecord.
LOGGING: dict[str, object] = {
"version": 1,
"disable_existing_loggers": False,
"handlers": {"console": {"class": "logging.StreamHandler"}},
"root": {"handlers": ["console"], "level": "WARNING"},
"loggers": {
"django": {"handlers": ["console"], "level": "INFO", "propagate": False},
"django.request": {"level": "ERROR"},
},
}
# Sentry
if os.environ.get("ENVIRONMENT", default="") == "PRODUCTION":
sentry_sdk.init(
dsn="https://8a857cbf974ca889a46c1b39173db44b@o4506833230561280.ingest.sentry.io/4506833234493440",
traces_sample_rate=0.1,
profiles_sample_rate=0.01,
enable_logs=True,
)

View file

@ -0,0 +1,121 @@
"""
Test for LLM client close() error handling.
This test verifies that the LLMClient handles exceptions gracefully when
closing clients during event loop changes, preventing 500 errors.
"""
import asyncio
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
class TestLLMClientClose:
    """Test LLMClient handles close() failures gracefully.

    Both tests share one driver: the only difference between the original
    copy-pasted test bodies was the error message injected into close(),
    so the scenario lives in ``_assert_close_error_swallowed``.
    """

    async def _assert_close_error_swallowed(self, close_error: str) -> None:
        """Drive two LLMClient.call()s across a simulated event-loop change.

        Makes the stale OpenAI client's close() raise ``RuntimeError(close_error)``
        and fails the test if that error propagates out of the second call()
        (in production this surfaced as 500s on /ai/optimization_review).
        """
        from aiservice.llm import LLMClient
        from aiservice.llm_models import OpenAI_GPT_4_1

        # Mock provider client whose close() always fails with the given message.
        mock_openai = AsyncMock()
        mock_openai.close = AsyncMock(side_effect=RuntimeError(close_error))

        with (
            patch("aiservice.llm.AsyncAzureOpenAI", return_value=mock_openai),
            patch("aiservice.llm.has_openai", True),
            patch("aiservice.llm.has_anthropic", False),
        ):
            client = LLMClient()

            # First call - creates the provider clients.
            with patch.object(client, "call_openai", AsyncMock()):
                try:
                    await client.call(
                        llm=OpenAI_GPT_4_1(),
                        messages=[{"role": "user", "content": "test"}],
                        call_type="test",
                        trace_id="test-trace-1",
                    )
                except Exception:
                    pass  # Ignore call failures, we only care about close()

            # Force event-loop-change detection on the next call.
            client.client_loop = None

            # Second call - must close the old client and swallow its error.
            with patch.object(client, "call_openai", AsyncMock()):
                try:
                    await client.call(
                        llm=OpenAI_GPT_4_1(),
                        messages=[{"role": "user", "content": "test"}],
                        call_type="test",
                        trace_id="test-trace-2",
                    )
                except RuntimeError as e:
                    if close_error in str(e):
                        pytest.fail(
                            f"close() error was not handled: {e}. This causes 500 errors in production."
                        )
                    raise

    @pytest.mark.asyncio
    async def test_close_handles_transport_errors(self) -> None:
        """httpx transport errors raised by close() must not escape call()."""
        await self._assert_close_error_swallowed("Transport already closed")

    @pytest.mark.asyncio
    async def test_close_handles_event_loop_closed_error(self) -> None:
        """'Event loop is closed' raised by close() must not escape call()."""
        await self._assert_close_error_swallowed("Event loop is closed")

View file

@ -5,6 +5,7 @@ from django.db.models.functions import Now
from ninja.errors import HttpError
from ninja.security import HttpBearer
from aiservice.background import fire_and_forget
from authapp.auth_utils import hash_api_key
from authapp.models import CFAPIKeys, Organizations, Subscriptions
@ -58,7 +59,7 @@ class AuthBearer(HttpBearer):
api_key_instance = await CFAPIKeys.objects.filter(key=hashed_token).afirst()
if api_key_instance is None:
raise HttpError(403, "Invalid API key")
await CFAPIKeys.objects.filter(id=api_key_instance.id).aupdate(last_used=Now())
fire_and_forget(CFAPIKeys.objects.filter(id=api_key_instance.id).aupdate(last_used=Now()))
request.user = api_key_instance.user_id
request.tier = api_key_instance.tier
request.api_key_id = api_key_instance.id

View file

@ -16,14 +16,14 @@ from openai.types.chat import ChatCompletionSystemMessageParam, ChatCompletionUs
from aiservice.analytics.posthog import ph
from aiservice.common_utils import validate_trace_id
from aiservice.env_specific import debug_log_sensitive_data
from aiservice.env_specific import debug_log_sensitive_data, debug_log_sensitive_data_from_callable
from aiservice.llm import llm_client
from aiservice.llm_models import LLM, OPTIMIZE_MODEL
from authapp.auth import AuthenticatedRequest
from authapp.user import get_user_by_id
from core.languages.java.prompts.optimizer import get_system_prompt, get_user_prompt
from core.log_features.log_event import get_or_create_optimization_event
from core.log_features.log_features import log_features
from core.log_features.log_features import safe_log_features
from core.shared.context_helpers import extract_code_and_explanation, group_code, split_markdown_code
from core.shared.optimizer_config import MAX_OPTIMIZER_CALLS, get_model_distribution
from core.shared.optimizer_models import OptimizeSchema
@ -125,7 +125,9 @@ async def optimize_java_code_single(
llm_cost = output.cost
debug_log_sensitive_data(f"LLM optimization response:\n{output.raw_response.model_dump_json(indent=2)}")
debug_log_sensitive_data_from_callable(
lambda: f"LLM optimization response:\n{output.raw_response.model_dump_json(indent=2)}"
)
if output.raw_response.usage is not None:
ph(
@ -285,7 +287,7 @@ async def optimize_java(
)
if hasattr(request, "should_log_features") and request.should_log_features:
await log_features(
await safe_log_features(
trace_id=data.trace_id,
user_id=request.user,
original_code=data.source_code,

View file

@ -16,7 +16,7 @@ import sentry_sdk
from openai.types.chat import ChatCompletionSystemMessageParam, ChatCompletionUserMessageParam
from aiservice.analytics.posthog import ph
from aiservice.env_specific import debug_log_sensitive_data
from aiservice.env_specific import debug_log_sensitive_data, debug_log_sensitive_data_from_callable
from aiservice.llm import llm_client
from aiservice.llm_models import OPTIMIZE_MODEL
from aiservice.validators.java_validator import validate_java_syntax
@ -172,7 +172,9 @@ Here is the code to optimize:
llm_cost = output.cost
debug_log_sensitive_data(f"LLM optimization response:\n{output.raw_response.model_dump_json(indent=2)}")
debug_log_sensitive_data_from_callable(
lambda: f"LLM optimization response:\n{output.raw_response.model_dump_json(indent=2)}"
)
if output.raw_response.usage is not None:
ph(

View file

@ -342,7 +342,7 @@ async def generate_and_validate_java_test_code(
SyntaxError: If generated code has syntax errors
"""
obs_context: dict | None = (
obs_context: dict[str, object] | None = (
{"call_sequence": call_sequence, "test_index": test_index} if call_sequence is not None else None
)
output = await llm_client.call(
@ -623,11 +623,11 @@ async def testgen_java(
except TestGenerationFailedError as e:
logging.warning("Java test generation failed: %s", e)
sentry_sdk.capture_exception(e)
return 400, TestGenErrorResponseSchema(error=str(e))
return 422, TestGenErrorResponseSchema(error=str(e))
except (ValueError, SyntaxError) as e:
logging.warning("Java test generation error: %s", e)
sentry_sdk.capture_exception(e)
return 400, TestGenErrorResponseSchema(error=f"Failed to generate valid tests: {e}")
return 422, TestGenErrorResponseSchema(error=f"Failed to generate valid tests: {e}")
except Exception as e:
logging.exception("Unexpected error in Java test generation")
sentry_sdk.capture_exception(e)

View file

@ -27,7 +27,7 @@ from authapp.user import get_user_by_id
from core.languages.js_ts.context_helpers import is_multi_context_js, is_multi_context_ts
from core.languages.js_ts.prompts.optimizer import get_system_prompt, get_user_prompt
from core.log_features.log_event import get_or_create_optimization_event
from core.log_features.log_features import log_features
from core.log_features.log_features import safe_log_features
from core.shared.context_helpers import extract_code_and_explanation, group_code, normalize_c_style_code
from core.shared.optimizer_config import MAX_OPTIMIZER_CALLS, get_model_distribution
from core.shared.optimizer_models import OptimizedCandidateSource, OptimizeSchema
@ -161,7 +161,9 @@ You MUST output the target file. You may also output helper files if you optimiz
llm_cost = output.cost
debug_log_sensitive_data(f"LLM optimization response:\n{output.raw_response.model_dump_json(indent=2)}")
debug_log_sensitive_data_from_callable(
lambda: f"LLM optimization response:\n{output.raw_response.model_dump_json(indent=2)}"
)
if output.raw_response.usage is not None:
ph(
@ -464,7 +466,7 @@ async def optimize_javascript(
)
)
tg.create_task(
log_features(
safe_log_features(
trace_id=data.trace_id,
user_id=request.user,
original_code=data.source_code,
@ -491,11 +493,12 @@ async def optimize_javascript(
response = OptimizeResponseSchema(optimizations=optimization_response_items)
def log_response() -> None:
debug_log_sensitive_data(f"JavaScript Response:\n{response.model_dump_json()}")
def log_response() -> str:
parts = [f"JavaScript Response:\n{response.model_dump_json()}"]
for opt in response.optimizations:
debug_log_sensitive_data(f"Optimized JavaScript source:\n{opt.source_code}")
debug_log_sensitive_data(f"JavaScript optimization explanation:\n{opt.explanation}")
parts.append(f"Optimized JavaScript source:\n{opt.source_code}")
parts.append(f"JavaScript optimization explanation:\n{opt.explanation}")
return "\n".join(parts)
debug_log_sensitive_data_from_callable(log_response)
ph(request.user, "aiservice-optimize-successful", properties={"language": language})

View file

@ -17,7 +17,7 @@ from openai.types.chat import ChatCompletionSystemMessageParam, ChatCompletionUs
from aiservice.analytics.posthog import ph
from aiservice.common.markdown_utils import split_markdown_code
from aiservice.env_specific import debug_log_sensitive_data
from aiservice.env_specific import debug_log_sensitive_data, debug_log_sensitive_data_from_callable
from aiservice.llm import llm_client
from aiservice.llm_models import OPTIMIZE_MODEL
from aiservice.validators.javascript_validator import validate_javascript_syntax, validate_typescript_syntax
@ -173,7 +173,9 @@ Here is the code to optimize:
llm_cost = output.cost
debug_log_sensitive_data(f"LLM optimization response:\n{output.raw_response.model_dump_json(indent=2)}")
debug_log_sensitive_data_from_callable(
lambda: f"LLM optimization response:\n{output.raw_response.model_dump_json(indent=2)}"
)
if output.raw_response.usage is not None:
ph(

View file

@ -23,7 +23,7 @@ from aiservice.llm_models import EXECUTE_MODEL, HAIKU_MODEL, OPENAI_MODEL
from aiservice.validators.javascript_validator import validate_javascript_syntax, validate_typescript_syntax
from authapp.auth import AuthenticatedRequest
from core.log_features.log_event import update_optimization_cost
from core.log_features.log_features import log_features
from core.log_features.log_features import safe_log_features
from core.shared.jinja_utils import create_prompt_env
from core.shared.testgen_models import (
TestGenerationFailedError,
@ -35,32 +35,8 @@ from core.shared.testgen_models import (
if TYPE_CHECKING:
from aiservice.llm_models import LLM
_TEST_FUNC_RE = re.compile(r"(?:test|it)\s*\(\s*['\"]")
# Get the directory of the current file
current_dir = Path(__file__).parent
JS_PROMPTS_DIR = current_dir / "prompts" / "testgen"
_jinja_env = create_prompt_env(JS_PROMPTS_DIR)
# Pattern to extract JavaScript code blocks
JS_PATTERN = re.compile(r"^```(?:javascript|js|typescript|ts)?\s*\n(.*?)\n```", re.MULTILINE | re.DOTALL)
# JavaScript identifier pattern: starts with letter, underscore, or $,
# followed by letters, digits, underscores, or $
JS_IDENTIFIER_PATTERN = re.compile(r"^[a-zA-Z_$][a-zA-Z0-9_$]*$")
def _is_valid_js_identifier(name: str) -> bool:
"""Check if a name is a valid JavaScript identifier (not a reserved word).
JavaScript identifiers can start with a letter, underscore, or $,
followed by letters, digits, underscores, or $. This differs from
Python identifiers (e.g., '$handler' is valid in JS but not Python).
"""
# Reserved words that cannot be used as variable names
reserved_words = {
_JS_RESERVED_WORDS = frozenset(
{
"module",
"exports",
"require",
@ -115,7 +91,45 @@ def _is_valid_js_identifier(name: str) -> bool:
"true",
"false",
}
return bool(JS_IDENTIFIER_PATTERN.match(name)) and name not in reserved_words
)
_DEFAULT_CLASS_PATTERN = re.compile(r"\bexport\s+default\s+class\s+([a-zA-Z_$][a-zA-Z0-9_$]*)\b")
_DEFAULT_FUNC_PATTERN = re.compile(r"\bexport\s+default\s+function\s+([a-zA-Z_$][a-zA-Z0-9_$]*)\b")
_NAMED_CLASS_PATTERN = re.compile(r"\bexport\s+class\s+([a-zA-Z_$][a-zA-Z0-9_$]*)\b")
_NAMED_FUNC_PATTERN = re.compile(r"\bexport\s+function\s+([a-zA-Z_$][a-zA-Z0-9_$]*)\b")
_NAMED_CONST_PATTERN = re.compile(r"\bexport\s+const\s+([a-zA-Z_$][a-zA-Z0-9_$]*)\b")
_NAMED_BRACES_PATTERN = re.compile(r"\bexport\s*\{[^}]*\b([a-zA-Z_$][a-zA-Z0-9_$]*)\b[^{]*\}")
_TEST_FUNC_RE = re.compile(r"(?:test|it)\s*\(\s*['\"]")
# Get the directory of the current file
current_dir = Path(__file__).parent
JS_PROMPTS_DIR = current_dir / "prompts" / "testgen"
_jinja_env = create_prompt_env(JS_PROMPTS_DIR)
# Pattern to extract JavaScript code blocks
JS_PATTERN = re.compile(r"^```(?:javascript|js|typescript|ts)?\s*\n(.*?)\n```", re.MULTILINE | re.DOTALL)
# JavaScript identifier pattern: starts with letter, underscore, or $,
# followed by letters, digits, underscores, or $
JS_IDENTIFIER_PATTERN = re.compile(r"^[a-zA-Z_$][a-zA-Z0-9_$]*$")
def _is_valid_js_identifier(name: str) -> bool:
"""Check if a name is a valid JavaScript identifier (not a reserved word).
JavaScript identifiers can start with a letter, underscore, or $,
followed by letters, digits, underscores, or $. This differs from
Python identifiers (e.g., '$handler' is valid in JS but not Python).
"""
return bool(JS_IDENTIFIER_PATTERN.match(name)) and name not in _JS_RESERVED_WORDS
# Patterns to strip file extensions from import paths
@ -127,12 +141,15 @@ _REQUIRE_EXTENSION_PATTERN = re.compile(
_JEST_MOCK_EXTENSION_PATTERN = re.compile(
r"""(jest\.(?:mock|doMock|unmock|requireActual|requireMock)\s*\(\s*['"])(\.{0,2}/[^'"]+?)(\.(?:js|ts|tsx|jsx|mjs|mts))(['"])"""
)
_VITEST_MOCK_EXTENSION_PATTERN = re.compile(
r"""(vi\.(?:mock|doMock|unmock|requireActual|requireMock|importActual|importMock)\s*\(\s*['"])(\.{0,2}/[^'"]+?)(\.(?:js|ts|tsx|jsx|mjs|mts))(['"])"""
)
def strip_js_extensions(source: str) -> str:
"""Strip .js/.ts/.tsx/.jsx extensions from relative import paths.
TypeScript and Jest's module resolution automatically resolve file extensions,
TypeScript and Jest/Vitest module resolution automatically resolve file extensions,
so adding them explicitly can cause "Cannot find module" errors when the LLM
adds incorrect extensions (e.g., .js to a .ts file).
@ -140,6 +157,7 @@ def strip_js_extensions(source: str) -> str:
- ES module imports: import { x } from '../path/file.js'
- CommonJS requires: require('../path/file.js')
- Jest mocks: jest.mock('../path/file.js')
- Vitest mocks: vi.mock('../path/file.js')
Args:
source: The test source code.
@ -150,19 +168,70 @@ def strip_js_extensions(source: str) -> str:
"""
source = _JS_EXTENSION_PATTERN.sub(r"\1\2\4", source)
source = _REQUIRE_EXTENSION_PATTERN.sub(r"\1\2\4", source)
return _JEST_MOCK_EXTENSION_PATTERN.sub(r"\1\2\4", source)
source = _JEST_MOCK_EXTENSION_PATTERN.sub(r"\1\2\4", source)
return _VITEST_MOCK_EXTENSION_PATTERN.sub(r"\1\2\4", source)
def _resolve_import(function_name: str, module_path: str) -> tuple[str, str, str]:
def _detect_export_style(source_code: str, identifier: str) -> str | None:
    """Detect whether *identifier* is exported from *source_code* and how.

    Args:
        source_code: The JS/TS source code to analyze.
        identifier: The class/function name to look for.

    Returns:
        "default" for `export default class/function <identifier>`,
        "named" for `export class/function/const <identifier>` or
        `export { ... identifier ... }`, and None when not exported.
    """
    # Collapse whitespace so multi-line export statements match the
    # single-line regex patterns.
    normalized = re.sub(r"\s+", " ", source_code)

    # Default-export patterns are checked before named-export ones,
    # preserving the original precedence.
    checks = (
        ("default", _DEFAULT_CLASS_PATTERN),
        ("default", _DEFAULT_FUNC_PATTERN),
        ("named", _NAMED_CLASS_PATTERN),
        ("named", _NAMED_FUNC_PATTERN),
        ("named", _NAMED_CONST_PATTERN),
        ("named", _NAMED_BRACES_PATTERN),
    )
    for style, pattern in checks:
        if any(m.group(1) == identifier for m in pattern.finditer(normalized)):
            return style
    return None
def _resolve_import(function_name: str, module_path: str, source_code: str = "") -> tuple[str, str, str]:
"""Determine import style and binding name for a JS/TS function.
Analyzes the function name to decide how to import it:
- 'Validator.validateRequest' default import of Validator
- 'execMongoEval' named/destructuring import
- 'Constructor.prototype.method' namespace import (fallback)
Analyzes the function name and source code to decide how to import it:
- 'ClassName.method' with 'export class ClassName' → named import
- 'ClassName.method' with 'export default class ClassName' → default import
- 'ClassName.method' with no export → named import (will fail, surfacing the issue)
- 'funcName' → named import
- Complex patterns → namespace import (fallback)
The actual ESM vs CJS formatting is handled by the Jinja2 js_import macro.
Args:
function_name: Name of the function to test (e.g., "ClassName.method")
module_path: Import path for the module
source_code: Source code to analyze for export detection (optional)
Returns:
Tuple of (import_style, import_name, function_accessor) where
import_style is "default", "named", or "namespace".
@ -170,14 +239,26 @@ def _resolve_import(function_name: str, module_path: str) -> tuple[str, str, str
"""
parts = function_name.split(".")
# Handle ClassName.method pattern
if len(parts) == 2:
class_name, method_name = parts
if _is_valid_js_identifier(class_name) and method_name.isidentifier():
return "default", class_name, function_name
# Detect export style from source code if provided
if source_code:
export_style = _detect_export_style(source_code, class_name)
if export_style == "default":
return "default", class_name, function_name
# For both "named" and None (not exported), use named import
# If not exported, the test will fail, surfacing the issue that the class needs to be exported
return "named", class_name, function_name
# Fallback to named import if no source code provided (safer than assuming default)
return "named", class_name, function_name
# Handle standalone function
if len(parts) == 1 and parts[0].isidentifier():
return "named", function_name, function_name
# Fallback to namespace import for complex patterns
module_name = module_path.rstrip("/").split("/")[-1].replace("-", "_").replace(".", "_")
if not module_name.isidentifier():
module_name = "mod"
@ -213,7 +294,7 @@ def build_javascript_prompt(
system_template = f"{async_prefix}{framework}system.md.j2"
posthog_event_suffix = "async-" if is_async else ""
import_style, import_name, function_accessor = _resolve_import(function_name, module_path)
import_style, import_name, function_accessor = _resolve_import(function_name, module_path, function_code)
system_message: ChatCompletionMessageParam = {
"role": "system",
@ -524,12 +605,15 @@ async def testgen_javascript(
)
# Strip incorrect file extensions from import paths (LLMs sometimes add .js to .ts imports)
# Must strip from ALL three test outputs since CLI uses instrumented versions
generated_test_source = strip_js_extensions(generated_test_source)
instrumented_behavior_tests = strip_js_extensions(instrumented_behavior_tests)
instrumented_perf_tests = strip_js_extensions(instrumented_perf_tests)
ph(request.user, "aiservice-testgen-tests-generated", properties={"language": language})
if hasattr(request, "should_log_features") and request.should_log_features:
await log_features(
await safe_log_features(
trace_id=data.trace_id,
user_id=request.user,
generated_tests=[generated_test_source],
@ -550,6 +634,10 @@ async def testgen_javascript(
instrumented_perf_tests=instrumented_perf_tests,
)
except TestGenerationFailedError as e:
logging.warning("JavaScript test generation failed: %s trace_id=%s", e, data.trace_id)
sentry_sdk.capture_exception(e)
return 422, TestGenErrorResponseSchema(error=str(e))
except Exception as e:
logging.exception("JavaScript test generation failed. trace_id=%s", data.trace_id)
sentry_sdk.capture_exception(e)

View file

@ -1,28 +1,33 @@
from __future__ import annotations
import asyncio
import logging
from pathlib import Path
from typing import TYPE_CHECKING
import libcst as cst
import sentry_sdk
import stamina
from ninja import NinjaAPI, Schema
from openai.types.chat import ChatCompletionSystemMessageParam, ChatCompletionUserMessageParam
from pydantic import ValidationError
from aiservice.analytics.posthog import ph
from aiservice.background import fire_and_forget
from aiservice.common_utils import validate_trace_id
from aiservice.env_specific import debug_log_sensitive_data
from aiservice.llm import llm_client
from aiservice.llm import LLMOutputUnparseable, llm_client
from aiservice.llm_models import ADAPTIVE_OPTIMIZE_MODEL
from authapp.auth import AuthenticatedRequest
from core.log_features.log_event import update_optimization_cost
from core.log_features.log_features import log_features
from core.log_features.log_features import safe_log_features
from core.shared.optimizer_models import OptimizedCandidateSource
from core.shared.optimizer_schemas import OptimizeResponseItemSchema
from .adaptive_optimizer_context import AdaptiveOptContext, AdaptiveOptContextData, AdaptiveOptRequestSchema
logger = logging.getLogger(__name__)
if TYPE_CHECKING:
from openai.types.chat import ChatCompletionMessageParam
@ -49,9 +54,10 @@ SYSTEM_PROMPT = (current_dir / "ADAPTIVE_OPTIMIZER_SYSTEM_PROMPT.md").read_text(
USER_PROMPT = (current_dir / "ADAPTIVE_OPTIMIZER_USER_PROMPT.md").read_text()
@stamina.retry(on=LLMOutputUnparseable, attempts=2)
async def perform_adaptive_optimize(
user_id: str, ctx: AdaptiveOptContext, trace_id: str = "", optimize_model: LLM = ADAPTIVE_OPTIMIZE_MODEL
) -> tuple[OptimizeResponseItemSchema | None, float | None, AdaptiveOptErrorResponseSchema | None]:
) -> tuple[OptimizeResponseItemSchema, float]:
system_prompt = ctx.get_system_prompt()
user_prompt = ctx.get_user_prompt()
@ -65,8 +71,9 @@ async def perform_adaptive_optimize(
)
llm_cost = output.cost
except Exception as e:
logger.exception("adaptive_optimize LLM call failed: trace_id=%s, user_id=%s", trace_id, user_id)
debug_log_sensitive_data(f"Failed to generate code for source:\n{ctx.data.original_source_code}")
return None, None, AdaptiveOptErrorResponseSchema(error=str(e))
raise LLMOutputUnparseable(str(e)) from e
debug_log_sensitive_data(f"ClaudeClient optimization response:\n{output.content}")
if output.usage is not None:
ph(
@ -83,20 +90,22 @@ async def perform_adaptive_optimize(
try:
ctx.extract_code_and_explanation_from_llm_res(llm_res)
new_opt = ctx.parse_and_generate_candidate_schema()
new_opt = await asyncio.to_thread(ctx.parse_and_generate_candidate_schema)
if not new_opt or not ctx.is_valid_code():
extracted_code = ctx.extracted_code_and_expl.code if ctx.extracted_code_and_expl else None
return (None, None, AdaptiveOptErrorResponseSchema(error="Invalid code generated " + str(extracted_code)))
logger.error("adaptive_optimize invalid code: trace_id=%s, user_id=%s", trace_id, user_id)
raise LLMOutputUnparseable("Invalid code generated " + str(extracted_code), cost=llm_cost)
# the parent is the last candidate in the previous optimizations
last_optimization_id = ctx.data.attempts[-1].optimization_id
new_opt.parent_id = last_optimization_id
return new_opt, llm_cost, None # noqa: TRY300
return new_opt, llm_cost # noqa: TRY300
except (ValueError, ValidationError, cst.ParserSyntaxError) as exc:
logger.exception("adaptive_optimize parsing failed: trace_id=%s, user_id=%s", trace_id, user_id)
sentry_sdk.capture_exception(exc)
debug_log_sensitive_data(f"{type(exc).__name__} for source:\n{ctx.data.original_source_code}")
debug_log_sensitive_data(f"Traceback: {exc}")
return None, None, AdaptiveOptErrorResponseSchema(error=str(exc))
raise LLMOutputUnparseable(str(exc), cost=llm_cost) from exc
@adaptive_optimize_api.post(
@ -104,6 +113,7 @@ async def perform_adaptive_optimize(
response={
200: OptimizeResponseItemSchema,
400: AdaptiveOptErrorResponseSchema,
422: AdaptiveOptErrorResponseSchema,
500: AdaptiveOptErrorResponseSchema,
},
)
@ -118,37 +128,35 @@ async def adaptive_optimize(
trace_id = data.trace_id
if not validate_trace_id(trace_id):
return 400, AdaptiveOptErrorResponseSchema(error="Invalid trace ID. Please provide a valid UUIDv4.")
try:
adaptive_optimization_candidate, llm_cost = await perform_adaptive_optimize(
user_id=request.user, ctx=ctx, trace_id=trace_id
)
if adaptive_optimization_candidate is None:
logger.error("adaptive_optimize endpoint returning 500: trace_id=%s, no candidate generated", trace_id)
return 500, AdaptiveOptErrorResponseSchema(error="Failed to generate optimization candidate")
adaptive_optimization_candidate, llm_cost, error = await perform_adaptive_optimize(
user_id=request.user, ctx=ctx, trace_id=trace_id
)
total_llm_cost = 0.0
if error:
return 500, error
if adaptive_optimization_candidate is None:
return 500, AdaptiveOptErrorResponseSchema(error="Failed to generate optimization candidate")
if llm_cost is not None:
total_llm_cost += llm_cost
except LLMOutputUnparseable as e:
return 422, AdaptiveOptErrorResponseSchema(error=str(e))
async with asyncio.TaskGroup() as tg:
tg.create_task(update_optimization_cost(trace_id=trace_id, cost=total_llm_cost, user_id=request.user))
if hasattr(request, "should_log_features") and request.should_log_features:
tg.create_task(
log_features(
trace_id=data.trace_id,
user_id=request.user,
optimizations_raw={
adaptive_optimization_candidate.optimization_id: adaptive_optimization_candidate.source_code
},
explanations_raw={
adaptive_optimization_candidate.optimization_id: adaptive_optimization_candidate.explanation
},
optimizations_origin={
adaptive_optimization_candidate.optimization_id: {
"source": OptimizedCandidateSource.ADAPTIVE,
"parent": adaptive_optimization_candidate.parent_id,
}
},
)
fire_and_forget(update_optimization_cost(trace_id=trace_id, cost=llm_cost, user_id=request.user))
if hasattr(request, "should_log_features") and request.should_log_features:
fire_and_forget(
safe_log_features(
trace_id=data.trace_id,
user_id=request.user,
optimizations_raw={
adaptive_optimization_candidate.optimization_id: adaptive_optimization_candidate.source_code
},
explanations_raw={
adaptive_optimization_candidate.optimization_id: adaptive_optimization_candidate.explanation
},
optimizations_origin={
adaptive_optimization_candidate.optimization_id: {
"source": OptimizedCandidateSource.ADAPTIVE,
"parent": adaptive_optimization_candidate.parent_id,
}
},
)
)
return 200, adaptive_optimization_candidate

View file

@ -7,6 +7,7 @@ from typing import TYPE_CHECKING
import libcst as cst
import sentry_sdk
import stamina
from ninja import NinjaAPI, Schema
from openai.types.chat import ChatCompletionSystemMessageParam, ChatCompletionUserMessageParam
from pydantic import ValidationError
@ -14,7 +15,7 @@ from pydantic import ValidationError
from aiservice.analytics.posthog import ph
from aiservice.common_utils import validate_trace_id
from aiservice.env_specific import debug_log_sensitive_data
from aiservice.llm import llm_client
from aiservice.llm import LLMOutputUnparseable, llm_client
from aiservice.llm_models import CODE_REPAIR_MODEL
from authapp.auth import AuthenticatedRequest
from core.languages.python.code_repair.code_repair_context import (
@ -23,7 +24,7 @@ from core.languages.python.code_repair.code_repair_context import (
CodeRepairRequestSchema,
)
from core.log_features.log_event import update_optimization_cost
from core.log_features.log_features import log_features
from core.log_features.log_features import safe_log_features
from core.shared.optimizer_models import OptimizedCandidateSource
from core.shared.optimizer_schemas import OptimizeResponseItemSchema
@ -41,8 +42,13 @@ SYSTEM_PROMPT = (current_dir / "CODE_REPAIR_SYSTEM_PROMPT.md").read_text()
USER_PROMPT = (current_dir / "CODE_REPAIR_USER_PROMPT.md").read_text()
@stamina.retry(on=LLMOutputUnparseable, attempts=2)
async def code_repair( # noqa: D417
user_id: str, optimization_id: str, ctx: CodeRepairContext, optimize_model: LLM = CODE_REPAIR_MODEL
user_id: str,
optimization_id: str,
ctx: CodeRepairContext,
trace_id: str = "",
optimize_model: LLM = CODE_REPAIR_MODEL,
) -> CodeRepairIntermediateResponseItemschema | CodeRepairErrorResponseSchema:
"""Repair the given candidate to match the behaviour of the original code.
@ -58,7 +64,7 @@ async def code_repair( # noqa: D417
Returns
-------
CodeRepairIntermediateResponseItemschema or CodeRepairErrorResponseSchema
CodeRepairIntermediateResponseItemschema
"""
system_prompt = ctx.get_system_prompt()
@ -70,11 +76,13 @@ async def code_repair( # noqa: D417
messages: list[ChatCompletionMessageParam] = [system_message, user_message]
debug_log_sensitive_data(f"This was the user prompt\n {user_prompt}\n")
try:
output = await llm_client.call(llm=optimize_model, messages=messages)
output = await llm_client.call(
llm=optimize_model, messages=messages, call_type="code_repair", trace_id=trace_id, user_id=user_id
)
llm_cost = output.cost
except Exception as e:
debug_log_sensitive_data(f"Failed to generate code for source:\n{ctx.data.original_source_code}")
return CodeRepairErrorResponseSchema(error=str(e))
raise LLMOutputUnparseable(str(e)) from e
debug_log_sensitive_data(f"ClaudeClient optimization response:\n{output.content}")
if output.usage is not None:
ph(
@ -96,10 +104,10 @@ async def code_repair( # noqa: D417
sentry_sdk.capture_exception(exc)
debug_log_sensitive_data(f"{type(exc).__name__} for source:\n{ctx.data.modified_source_code}")
debug_log_sensitive_data(f"Traceback: {exc}")
repaired_optimization = ""
raise LLMOutputUnparseable(str(exc), cost=llm_cost) from exc
if not ctx.is_valid(repaired_optimization):
repaired_optimization = ""
raise LLMOutputUnparseable("Repaired code is not valid", cost=llm_cost)
return CodeRepairIntermediateResponseItemschema(
optimization_id=new_op_id,
@ -125,9 +133,14 @@ class CodeRepairIntermediateResponseItemschema(Schema):
@code_repair_api.post(
"/",
response={200: OptimizeResponseItemSchema, 400: CodeRepairErrorResponseSchema, 500: CodeRepairErrorResponseSchema},
response={
200: OptimizeResponseItemSchema,
400: CodeRepairErrorResponseSchema,
422: CodeRepairErrorResponseSchema,
500: CodeRepairErrorResponseSchema,
},
)
async def repair( # noqa: PLR0911
async def repair(
request: AuthenticatedRequest, data: CodeRepairRequestSchema
) -> tuple[int, OptimizeResponseItemSchema | CodeRepairErrorResponseSchema]:
ph(request.user, "aiservice-code_repair-called")
@ -151,11 +164,15 @@ async def repair( # noqa: PLR0911
if result is not None:
return 200, result
code_repair_data = await code_repair(user_id=request.user, optimization_id=data.optimization_id, ctx=ctx)
total_llm_cost = 0.0
if isinstance(code_repair_data, CodeRepairErrorResponseSchema):
return 500, code_repair_data
total_llm_cost += code_repair_data.llm_cost
try:
code_repair_data = await code_repair(
user_id=request.user, optimization_id=data.optimization_id, ctx=ctx, trace_id=trace_id
)
if isinstance(code_repair_data, CodeRepairErrorResponseSchema):
return 500, code_repair_data
except LLMOutputUnparseable as e:
return 422, CodeRepairErrorResponseSchema(error=str(e))
llm_cost = code_repair_data.llm_cost
try:
ctx.validate_module()
except cst.ParserSyntaxError as e:
@ -163,19 +180,19 @@ async def repair( # noqa: PLR0911
sentry_sdk.capture_exception(e)
debug_log_sensitive_data(f"ParserSyntaxError for source:\n{code_repair_data.source_code}")
debug_log_sensitive_data(f"Traceback: {e}")
return 500, CodeRepairErrorResponseSchema(error=str(e))
return 422, CodeRepairErrorResponseSchema(error=str(e))
except (ValueError, ValidationError) as exc:
# Another one bites the Pydantic validation dust
sentry_sdk.capture_exception(exc)
debug_log_sensitive_data(f"{type(exc).__name__} for source:\n{code_repair_data.source_code}")
debug_log_sensitive_data(f"Traceback: {exc}")
return 500, CodeRepairErrorResponseSchema(error=str(exc))
return 422, CodeRepairErrorResponseSchema(error=str(exc))
async with asyncio.TaskGroup() as tg:
tg.create_task(update_optimization_cost(trace_id=trace_id, cost=total_llm_cost, user_id=request.user))
tg.create_task(update_optimization_cost(trace_id=trace_id, cost=llm_cost, user_id=request.user))
if hasattr(request, "should_log_features") and request.should_log_features:
tg.create_task(
log_features(
safe_log_features(
trace_id=data.trace_id,
user_id=request.user,
optimizations_raw={code_repair_data.optimization_id: code_repair_data.source_code},

View file

@ -20,7 +20,9 @@ from aiservice.common.markdown_utils import wrap_code_in_markdown
from aiservice.common.xml_utils import extract_xml_tag
from aiservice.common_utils import validate_trace_id
from aiservice.env_specific import IS_PRODUCTION, debug_log_sensitive_data
from aiservice.llm import llm_client
import stamina
from aiservice.llm import LLMOutputUnparseable, llm_client
from aiservice.llm_models import EXPLANATIONS_MODEL, LLM
from core.languages.python.explanations.models import (
ExplanationsErrorResponseSchema,
@ -42,9 +44,10 @@ SYSTEM_PROMPT_TEMPLATE = _jinja_env.get_template("system_prompt.md.j2")
USER_PROMPT_TEMPLATE = _jinja_env.get_template("user_prompt.md.j2")
@stamina.retry(on=LLMOutputUnparseable, attempts=2)
async def explain_optimizations(
user_id: str, data: ExplanationsSchema, explanations_model: LLM = EXPLANATIONS_MODEL
) -> tuple[ExplanationsResponseSchema, float] | ExplanationsErrorResponseSchema:
) -> tuple[ExplanationsResponseSchema, float]:
# Avoid building potentially very large debug strings when logging is disabled.
if not IS_PRODUCTION:
debug_log_sensitive_data(f"Generating an explanation for {user_id}:\n{data.optimized_code}")
@ -111,7 +114,7 @@ async def explain_optimizations(
context=obs_context,
)
except Exception as e:
return ExplanationsErrorResponseSchema(error=str(e))
raise LLMOutputUnparseable(str(e)) from e
if not IS_PRODUCTION:
debug_log_sensitive_data(f"AIClient optimization response:\n{output.content}")
if output.usage is not None:
@ -132,6 +135,7 @@ async def explain_optimizations(
response={
200: ExplanationsResponseSchema,
400: ExplanationsErrorResponseSchema,
422: ExplanationsErrorResponseSchema,
500: ExplanationsErrorResponseSchema,
},
)
@ -141,25 +145,25 @@ async def explain(
ph(request.user, "aiservice-explain-called")
if not validate_trace_id(data.trace_id):
return 400, ExplanationsErrorResponseSchema(error="Invalid trace ID. Please provide a valid UUIDv4.")
result = await explain_optimizations(request.user, data)
if isinstance(result, ExplanationsErrorResponseSchema):
try:
explanation_response, llm_cost = await explain_optimizations(request.user, data)
except LLMOutputUnparseable:
ph(request.user, "Explanation not generated, revert to old explanation")
debug_log_sensitive_data("No explanation was generated")
return 500, ExplanationsErrorResponseSchema(error="Error generating optimizations. Internal server error.")
explanation_response, llm_cost = result
return 422, ExplanationsErrorResponseSchema(error="Error generating optimizations. Internal server error.")
ph(request.user, "explanation generated", properties={"explanation": explanation_response})
# parse xml tag for explanation
explanation = extract_xml_tag(explanation_response.explanation, "explain")
if not explanation:
return 500, ExplanationsErrorResponseSchema(error="Failed to parse explanation from LLM response.")
return 422, ExplanationsErrorResponseSchema(error="Failed to parse explanation from LLM response.")
coros: list[Coroutine[Any, Any, Any]] = [
update_optimization_cost(trace_id=data.trace_id, cost=llm_cost, user_id=request.user)
]
if hasattr(request, "should_log_features") and request.should_log_features:
coros.append(log_features(trace_id=data.trace_id, user_id=request.user, final_explanation=explanation))
results = await asyncio.gather(*coros, return_exceptions=True)
for result in results:
if isinstance(result, BaseException):
sentry_sdk.capture_exception(result)
for coro_result in results:
if isinstance(coro_result, BaseException):
sentry_sdk.capture_exception(coro_result)
response = ExplanationsResponseSchema(explanation=explanation)
return 200, response

View file

@ -13,6 +13,7 @@ from openai.types.chat import ChatCompletionSystemMessageParam, ChatCompletionUs
from pydantic import ValidationError
from aiservice.analytics.posthog import ph
from aiservice.background import fire_and_forget
from aiservice.common_utils import parse_python_version, validate_trace_id
from aiservice.env_specific import debug_log_sensitive_data, debug_log_sensitive_data_from_callable
from aiservice.llm import llm_client
@ -23,7 +24,7 @@ from core.languages.python.optimizer.context_utils.optimizer_context import Base
from core.languages.python.optimizer.diff_patches_utils.diff import DiffMethod
from core.languages.python.optimizer.models import JitRewriteOptimizeSchema
from core.log_features.log_event import get_or_create_optimization_event
from core.log_features.log_features import log_features
from core.log_features.log_features import safe_log_features
from core.shared.optimizer_config import MAX_OPTIMIZER_CALLS, get_model_distribution
from core.shared.optimizer_models import OptimizedCandidateSource
from core.shared.optimizer_schemas import (
@ -77,7 +78,9 @@ async def jit_rewrite_python_code_single(
debug_log_sensitive_data(f"Failed to generate code for source:\n{ctx.source_code}")
return None, None, jit_rewrite_model.name
llm_cost = output.cost
debug_log_sensitive_data(f"OpenAIClient jit rewrite response:\n{output.raw_response.model_dump_json(indent=2)}")
debug_log_sensitive_data_from_callable(
lambda: f"OpenAIClient jit rewrite response:\n{output.raw_response.model_dump_json(indent=2)}"
)
if output.raw_response.usage is not None:
ph(
user_id,
@ -86,7 +89,7 @@ async def jit_rewrite_python_code_single(
)
ctx.extract_code_and_explanation_from_llm_res(output.content)
try:
res = ctx.parse_and_generate_candidate_schema()
res = await asyncio.to_thread(ctx.parse_and_generate_candidate_schema)
if res is not None and ctx.is_valid_code():
return res, llm_cost, jit_rewrite_model.name
except (ValueError, ValidationError, cst.ParserSyntaxError) as e:
@ -248,55 +251,53 @@ async def jit_rewrite(
"aiservice-jit-rewrite-optimizations-found",
properties={"num_optimizations": len(jit_rewrite_response_items)},
)
async with asyncio.TaskGroup() as tg:
event_task = tg.create_task(
get_or_create_optimization_event(
event_type="no-pr",
user_id=request.user,
current_username=data.current_username,
repo_owner=data.repo_owner,
repo_name=data.repo_name,
trace_id=data.trace_id,
api_key_id=request.api_key_id,
metadata={
"codeflash_version": data.codeflash_version,
"num_optimizations": len(jit_rewrite_response_items),
"experiment_metadata": data.experiment_metadata,
},
llm_cost=llm_cost,
)
fire_and_forget(
safe_log_features(
trace_id=data.trace_id,
user_id=request.user,
original_code=data.source_code,
dependency_code=data.dependency_code,
optimizations_raw={op_id: cei["code"] for op_id, cei in code_and_explanations.items()},
optimizations_post={cei.optimization_id: cei.source_code for cei in jit_rewrite_response_items},
explanations_raw={op_id: cei["explanation"] for op_id, cei in code_and_explanations.items()},
explanations_post={cei.optimization_id: cei.explanation for cei in jit_rewrite_response_items},
experiment_metadata=data.experiment_metadata or None,
optimizations_origin={
cei.optimization_id: {
"source": OptimizedCandidateSource.JIT_REWRITE,
"parent": None,
"model": jit_rewrite_models.get(cei.optimization_id, "unknown"),
}
for cei in jit_rewrite_response_items
},
)
tg.create_task(
log_features(
trace_id=data.trace_id,
user_id=request.user,
original_code=data.source_code,
dependency_code=data.dependency_code,
optimizations_raw={op_id: cei["code"] for op_id, cei in code_and_explanations.items()},
optimizations_post={cei.optimization_id: cei.source_code for cei in jit_rewrite_response_items},
explanations_raw={op_id: cei["explanation"] for op_id, cei in code_and_explanations.items()},
explanations_post={cei.optimization_id: cei.explanation for cei in jit_rewrite_response_items},
experiment_metadata=data.experiment_metadata or None,
optimizations_origin={
cei.optimization_id: {
"source": OptimizedCandidateSource.JIT_REWRITE,
"parent": None,
"model": jit_rewrite_models.get(cei.optimization_id, "unknown"),
}
for cei in jit_rewrite_response_items
},
)
)
event, _created = event_task.result()
)
event, _created = await get_or_create_optimization_event(
event_type="no-pr",
user_id=request.user,
current_username=data.current_username,
repo_owner=data.repo_owner,
repo_name=data.repo_name,
trace_id=data.trace_id,
api_key_id=request.api_key_id,
metadata={
"codeflash_version": data.codeflash_version,
"num_optimizations": len(jit_rewrite_response_items),
"experiment_metadata": data.experiment_metadata,
},
llm_cost=llm_cost,
)
for item in jit_rewrite_response_items:
item.optimization_event_id = str(event.id) if event else None
response = OptimizeResponseSchema(optimizations=jit_rewrite_response_items)
def log_response() -> None:
debug_log_sensitive_data(f"Response:\n{response.model_dump_json()}")
def log_response() -> str:
parts = [f"Response:\n{response.model_dump_json()}"]
for opt in response.optimizations:
debug_log_sensitive_data(f"Optimized source:\n{opt.source_code}")
debug_log_sensitive_data(f"Optimization explanation:\n{opt.explanation}")
parts.append(f"Optimized source:\n{opt.source_code}")
parts.append(f"Optimization explanation:\n{opt.explanation}")
return "\n".join(parts)
debug_log_sensitive_data_from_callable(log_response)
ph(request.user, "aiservice-jit-rewrite-successful")

View file

@ -15,7 +15,9 @@ from packaging import version
from aiservice.analytics.posthog import ph
from aiservice.common.markdown_utils import extract_code_block_with_context, wrap_code_in_markdown
from aiservice.env_specific import debug_log_sensitive_data
from aiservice.llm import llm_client
import stamina
from aiservice.llm import LLMOutputUnparseable, llm_client
from aiservice.llm_models import OPTIMIZATION_REVIEW_MODEL
from authapp.auth import AuthenticatedRequest
from core.log_features.log_event import update_optimization_cost, update_optimization_features_review
@ -173,6 +175,7 @@ Please analyze this optimization following the structured assessment process and
return [system_message, user_message]
@stamina.retry(on=LLMOutputUnparseable, attempts=2)
async def get_optimization_review(
request: AuthenticatedRequest,
data: OptimizationReviewSchema,
@ -250,12 +253,14 @@ async def get_optimization_review(
logging.exception("Invalid optimization review response")
sentry_sdk.capture_exception(e)
debug_log_sensitive_data(f"Invalid response : {e}")
return 500, OptimizationReviewErrorSchema(error="Invalid response"), cost
raise LLMOutputUnparseable("Invalid response", cost=cost) from e
else:
ph(request.user, "aiservice-optimization-review-successful")
return 200, review, cost
else:
return 500, OptimizationReviewErrorSchema(error="Invalid response"), cost
raise LLMOutputUnparseable("Invalid response", cost=cost)
except LLMOutputUnparseable:
raise
except Exception as e:
logging.exception("Error in optimization_review")
sentry_sdk.capture_exception(e)
@ -267,13 +272,17 @@ async def get_optimization_review(
response={
200: OptimizationReviewResponseSchema,
400: OptimizationReviewErrorSchema,
422: OptimizationReviewErrorSchema,
500: OptimizationReviewErrorSchema,
},
)
async def optimization_review(
request: AuthenticatedRequest, data: OptimizationReviewSchema
) -> tuple[int, OptimizationReviewResponseSchema | OptimizationReviewErrorSchema]:
response_code, output, llm_cost = await get_optimization_review(request, data)
try:
response_code, output, llm_cost = await get_optimization_review(request, data)
except LLMOutputUnparseable as e:
return 422, OptimizationReviewErrorSchema(error="Invalid response")
if isinstance(output, OptimizationReviewResponseSchema):
review_event = output.review.value
review_explanation = output.review_explanation

View file

@ -1,8 +1,8 @@
from __future__ import annotations
from dataclasses import dataclass
import libcst
from pydantic import field_validator
from pydantic.dataclasses import dataclass
from core.shared.optimizer_models import OptimizeSchema
@ -12,19 +12,6 @@ class CodeAndExplanation:
cst_module: libcst.Module | None
explanation: str
@field_validator("cst_module")
def validate_cst_module(cls, v):
if not isinstance(v, libcst.Module):
raise ValueError("cst_module must be an instance of libcst.Module")
try:
# Unparse the CST module to get the source code
source_code = v.code
# Compile the source code to check for syntax errors
compile(source_code, "<string>", "exec")
except Exception as e:
raise ValueError(f"Invalid cst_module, compilation error: {e}")
return v
@dataclass(frozen=True)
class CodeExplanationAndID:
@ -32,19 +19,6 @@ class CodeExplanationAndID:
explanation: str
id: str
@field_validator("cst_module")
def validate_cst_module(cls, v):
if not isinstance(v, libcst.Module):
raise ValueError("cst_module must be an instance of libcst.Module")
try:
# Unparse the CST module to get the source code
source_code = v.code
# Compile the source code to check for syntax errors
compile(source_code, "<string>", "exec")
except Exception as e:
raise ValueError(f"Invalid cst_module, compilation error: {e}")
return v
class JitRewriteOptimizeSchema(OptimizeSchema):
n_candidates: int = 1 # default value for backward compatibility

View file

@ -12,6 +12,7 @@ from openai.types.chat import ChatCompletionSystemMessageParam, ChatCompletionUs
from pydantic import ValidationError
from aiservice.analytics.posthog import ph
from aiservice.background import fire_and_forget
from aiservice.common_utils import parse_python_version, validate_trace_id
from aiservice.env_specific import debug_log_sensitive_data, debug_log_sensitive_data_from_callable
from aiservice.llm import llm_client
@ -20,7 +21,7 @@ from authapp.user import get_user_by_id
from core.languages.python.optimizer.context_utils.optimizer_context import BaseOptimizerContext
from core.languages.python.optimizer.diff_patches_utils.diff import DiffMethod
from core.log_features.log_event import get_or_create_optimization_event
from core.log_features.log_features import log_features
from core.log_features.log_features import safe_log_features
from core.shared.optimizer_config import MAX_OPTIMIZER_CALLS, get_model_distribution
from core.shared.optimizer_models import OptimizedCandidateSource
from core.shared.optimizer_schemas import (
@ -87,7 +88,9 @@ async def generate_optimization_candidate(
llm_cost = output.cost
debug_log_sensitive_data(f"LLM optimization response:\n{output.raw_response.model_dump_json(indent=2)}")
debug_log_sensitive_data_from_callable(
lambda: f"LLM optimization response:\n{output.raw_response.model_dump_json(indent=2)}"
)
if output.raw_response.usage is not None:
ph(
@ -98,7 +101,7 @@ async def generate_optimization_candidate(
ctx.extract_code_and_explanation_from_llm_res(output.content)
try:
res = ctx.parse_and_generate_candidate_schema()
res = await asyncio.to_thread(ctx.parse_and_generate_candidate_schema)
if res is not None and ctx.is_valid_code():
return res, llm_cost, optimize_model.name
except (ValueError, ValidationError, cst.ParserSyntaxError) as e:
@ -290,65 +293,62 @@ async def optimize_python(
data.n_candidates,
len(data.source_code) if data.source_code else 0,
)
return 500, OptimizeErrorResponseSchema(error="Could not generate any optimizations. Please try again.")
return 422, OptimizeErrorResponseSchema(error="Could not generate any optimizations. Please try again.")
ph(
request.user,
"aiservice-optimize-optimizations-found",
properties={"num_optimizations": len(optimization_response_items)},
)
async with asyncio.TaskGroup() as tg:
event_task = tg.create_task(
get_or_create_optimization_event(
event_type="no-pr",
user_id=request.user,
current_username=data.current_username,
repo_owner=data.repo_owner,
repo_name=data.repo_name,
trace_id=data.trace_id,
api_key_id=request.api_key_id,
metadata={
"codeflash_version": data.codeflash_version,
"num_optimizations": len(optimization_response_items),
"experiment_metadata": data.experiment_metadata,
},
llm_cost=llm_cost,
)
)
tg.create_task(
log_features(
trace_id=data.trace_id,
user_id=request.user,
original_code=data.source_code,
dependency_code=data.dependency_code,
optimizations_raw={op_id: cei["code"] for op_id, cei in code_and_explanations.items()},
optimizations_post={cei.optimization_id: cei.source_code for cei in optimization_response_items},
explanations_raw={op_id: cei["explanation"] for op_id, cei in code_and_explanations.items()},
explanations_post={cei.optimization_id: cei.explanation for cei in optimization_response_items},
experiment_metadata=data.experiment_metadata or None,
optimizations_origin={
cei.optimization_id: {
"source": OptimizedCandidateSource.OPTIMIZE,
"parent": None,
"model": optimization_models.get(cei.optimization_id, "unknown"),
}
for cei in optimization_response_items
},
)
fire_and_forget(
safe_log_features(
trace_id=data.trace_id,
user_id=request.user,
original_code=data.source_code,
dependency_code=data.dependency_code,
optimizations_raw={op_id: cei["code"] for op_id, cei in code_and_explanations.items()},
optimizations_post={cei.optimization_id: cei.source_code for cei in optimization_response_items},
explanations_raw={op_id: cei["explanation"] for op_id, cei in code_and_explanations.items()},
explanations_post={cei.optimization_id: cei.explanation for cei in optimization_response_items},
experiment_metadata=data.experiment_metadata or None,
optimizations_origin={
cei.optimization_id: {
"source": OptimizedCandidateSource.OPTIMIZE,
"parent": None,
"model": optimization_models.get(cei.optimization_id, "unknown"),
}
for cei in optimization_response_items
},
)
)
event, _created = event_task.result()
event, _created = await get_or_create_optimization_event(
event_type="no-pr",
user_id=request.user,
current_username=data.current_username,
repo_owner=data.repo_owner,
repo_name=data.repo_name,
trace_id=data.trace_id,
api_key_id=request.api_key_id,
metadata={
"codeflash_version": data.codeflash_version,
"num_optimizations": len(optimization_response_items),
"experiment_metadata": data.experiment_metadata,
},
llm_cost=llm_cost,
)
for item in optimization_response_items:
item.optimization_event_id = str(event.id) if event else None
response = OptimizeResponseSchema(optimizations=optimization_response_items)
def log_response() -> None:
debug_log_sensitive_data(f"Response:\n{response.model_dump_json()}")
def log_response() -> str:
parts = [f"Response:\n{response.model_dump_json()}"]
for opt in response.optimizations:
debug_log_sensitive_data(f"Optimized source:\n{opt.source_code}")
debug_log_sensitive_data(f"Optimization explanation:\n{opt.explanation}")
parts.append(f"Optimized source:\n{opt.source_code}")
parts.append(f"Optimization explanation:\n{opt.explanation}")
return "\n".join(parts)
debug_log_sensitive_data_from_callable(log_response)
ph(request.user, "aiservice-optimize-successful")

View file

@ -3,12 +3,13 @@ from __future__ import annotations
import asyncio
import logging
from pathlib import Path
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any
from ninja import NinjaAPI
from openai.types.chat import ChatCompletionSystemMessageParam, ChatCompletionUserMessageParam
from aiservice.analytics.posthog import ph
from aiservice.background import fire_and_forget
from aiservice.common.markdown_utils import split_markdown_code
from aiservice.common_utils import parse_python_version, validate_trace_id
from aiservice.env_specific import debug_log_sensitive_data, debug_log_sensitive_data_from_callable
@ -21,7 +22,7 @@ from core.languages.js_ts.optimizer_lp import optimize_javascript_code_line_prof
from core.languages.python.optimizer.context_utils.optimizer_context import BaseOptimizerContext
from core.languages.python.optimizer.diff_patches_utils.diff import DiffMethod
from core.log_features.log_event import update_optimization_cost
from core.log_features.log_features import log_features
from core.log_features.log_features import safe_log_features
from core.shared.optimizer_config import MAX_OPTIMIZER_LP_CALLS, get_model_distribution
from core.shared.optimizer_models import OptimizedCandidateSource, OptimizeSchemaLP
from core.shared.optimizer_schemas import (
@ -34,6 +35,7 @@ if TYPE_CHECKING:
from openai.types.chat import ChatCompletionMessageParam
from aiservice.llm_models import LLM
from authapp.auth import AuthenticatedRequest
optimize_line_profiler_api = NinjaAPI(urls_namespace="optimize-line-profiler")
@ -74,7 +76,7 @@ async def optimize_python_code_line_profiler_single(
debug_log_sensitive_data(f"This was the user prompt\n {user_prompt}\n")
# TODO: Verify if the context window length is within the model capability
obs_context: dict = {}
obs_context: dict[str, Any] = {}
if call_sequence is not None:
obs_context["call_sequence"] = call_sequence
@ -94,7 +96,9 @@ async def optimize_python_code_line_profiler_single(
llm_cost = output.cost
debug_log_sensitive_data(f"OpenAIClient optimization response:\n{output.raw_response.model_dump_json(indent=2)}")
debug_log_sensitive_data_from_callable(
lambda: f"OpenAIClient optimization response:\n{output.raw_response.model_dump_json(indent=2)}"
)
if output.raw_response.usage is not None:
ph(
@ -104,7 +108,7 @@ async def optimize_python_code_line_profiler_single(
)
ctx.extract_code_and_explanation_from_llm_res(output.content)
res = ctx.parse_and_generate_candidate_schema()
res = await asyncio.to_thread(ctx.parse_and_generate_candidate_schema)
if res is not None and ctx.is_valid_code():
return res, llm_cost, optimize_model.name
@ -120,7 +124,7 @@ async def optimize_python_code_line_profiler(
dependency_code: str | None = None,
n_candidates: int = 0,
python_version: tuple[int, int, int] = (3, 12, 9),
) -> tuple[list[OptimizeResponseItemSchema], float, dict[str, dict], dict[str, str]]:
) -> tuple[list[OptimizeResponseItemSchema], float, dict[str, dict[str, str]], dict[str, str]]:
"""Run parallel line profiler optimizations with multiple models.
Returns:
@ -167,7 +171,7 @@ async def optimize_python_code_line_profiler(
# Collect results
optimization_results: list[OptimizeResponseItemSchema] = []
total_cost = 0.0
code_and_explanations: dict[str, dict] = {}
code_and_explanations: dict[str, dict[str, str]] = {}
optimization_models: dict[str, str] = {}
for task, task_ctx in tasks:
@ -185,9 +189,17 @@ async def optimize_python_code_line_profiler(
@optimize_line_profiler_api.post(
"/", response={200: OptimizeResponseSchema, 400: OptimizeErrorResponseSchema, 500: OptimizeErrorResponseSchema}
"/",
response={
200: OptimizeResponseSchema,
400: OptimizeErrorResponseSchema,
422: OptimizeErrorResponseSchema,
500: OptimizeErrorResponseSchema,
},
)
async def optimize(request, data: OptimizeSchemaLP) -> tuple[int, OptimizeResponseSchema | OptimizeErrorResponseSchema]: # noqa: ANN001
async def optimize(
request: AuthenticatedRequest, data: OptimizeSchemaLP
) -> tuple[int, OptimizeResponseSchema | OptimizeErrorResponseSchema]:
if data.rerun_trace_id:
from core.shared.replay import get_rerun_record, rerun_optimize # noqa: PLC0415
@ -278,7 +290,7 @@ async def optimize(request, data: OptimizeSchemaLP) -> tuple[int, OptimizeRespon
n_candidates=data.n_candidates,
)
# JavaScript path doesn't have code_and_explanations dict like Python
code_and_explanations: dict[str, dict] = {}
code_and_explanations: dict[str, dict[str, str]] = {}
elif language == "java":
# Java path
@ -329,11 +341,9 @@ async def optimize(request, data: OptimizeSchemaLP) -> tuple[int, OptimizeRespon
system_prompt = SYSTEM_PROMPT
if data.is_numerical_code:
system_prompt += f"\n{JIT_INSTRUCTIONS}\n"
ctx: BaseOptimizerContext = BaseOptimizerContext.get_dynamic_context(
system_prompt, USER_PROMPT, data.source_code, DiffMethod.NO_DIFF
)
ctx = BaseOptimizerContext.get_dynamic_context(system_prompt, USER_PROMPT, data.source_code, DiffMethod.NO_DIFF)
try:
python_version: tuple[int, int, int] = parse_python_version(data.python_version or "3.12.0")
python_version = parse_python_version(data.python_version or "3.12.0")
except: # noqa: E722
return 400, OptimizeErrorResponseSchema(
error="Invalid Python version, it should look like 3.x.x. We only support Python 3.9 and above."
@ -372,46 +382,46 @@ async def optimize(request, data: OptimizeSchemaLP) -> tuple[int, OptimizeRespon
len(data.source_code) if data.source_code else 0,
bool(data.line_profiler_results),
)
return 500, OptimizeErrorResponseSchema(error="Error generating optimizations. Internal server error.")
return 422, OptimizeErrorResponseSchema(error="Error generating optimizations. Internal server error.")
ph(
request.user,
"aiservice-optimize-optimizations-found",
properties={"num_optimizations": len(optimization_response_items), "language": language},
)
async with asyncio.TaskGroup() as tg:
tg.create_task(update_optimization_cost(trace_id=data.trace_id, cost=llm_cost, user_id=request.user))
if hasattr(request, "should_log_features") and request.should_log_features:
tg.create_task(
log_features(
trace_id=data.trace_id,
user_id=request.user,
original_code=data.source_code,
dependency_code=data.dependency_code,
line_profiler_results=data.line_profiler_results,
optimizations_raw={op_id: cei["code"] for op_id, cei in code_and_explanations.items()},
optimizations_post={cei.optimization_id: cei.source_code for cei in optimization_response_items},
explanations_raw={op_id: cei["explanation"] for op_id, cei in code_and_explanations.items()},
explanations_post={cei.optimization_id: cei.explanation for cei in optimization_response_items},
experiment_metadata=data.experiment_metadata or None,
optimizations_origin={
cei.optimization_id: {
"source": OptimizedCandidateSource.OPTIMIZE_LP,
"parent": None,
"model": optimization_models.get(cei.optimization_id, "unknown"),
}
for cei in optimization_response_items
},
)
fire_and_forget(update_optimization_cost(trace_id=data.trace_id, cost=llm_cost, user_id=request.user))
if hasattr(request, "should_log_features") and request.should_log_features:
fire_and_forget(
safe_log_features(
trace_id=data.trace_id,
user_id=request.user,
original_code=data.source_code,
dependency_code=data.dependency_code,
line_profiler_results=data.line_profiler_results,
optimizations_raw={op_id: cei["code"] for op_id, cei in code_and_explanations.items()},
optimizations_post={cei.optimization_id: cei.source_code for cei in optimization_response_items},
explanations_raw={op_id: cei["explanation"] for op_id, cei in code_and_explanations.items()},
explanations_post={cei.optimization_id: cei.explanation for cei in optimization_response_items},
experiment_metadata=data.experiment_metadata or None,
optimizations_origin={
cei.optimization_id: {
"source": OptimizedCandidateSource.OPTIMIZE_LP,
"parent": None,
"model": optimization_models.get(cei.optimization_id, "unknown"),
}
for cei in optimization_response_items
},
)
)
response = OptimizeResponseSchema(optimizations=optimization_response_items)
def log_response() -> None:
debug_log_sensitive_data(f"Response:\n{response.model_dump_json()}")
def log_response() -> str:
parts = [f"Response:\n{response.model_dump_json()}"]
for opt in response.optimizations:
debug_log_sensitive_data(f"Optimized source:\n{opt.source_code}")
debug_log_sensitive_data(f"Optimization explanation:\n{opt.explanation}")
parts.append(f"Optimized source:\n{opt.source_code}")
parts.append(f"Optimization explanation:\n{opt.explanation}")
return "\n".join(parts)
debug_log_sensitive_data_from_callable(log_response)
ph(request.user, "aiservice-optimize-successful")

View file

@ -12,8 +12,8 @@ import libcst as cst
import sentry_sdk
from libcst import BaseStatement, CSTTransformer, CSTVisitor, Expr, IndentedBlock, SimpleStatementLine, SimpleString
from core.languages.python.cst_utils import compare_unparsed_ast_to_source, parse_module_to_cst, unparse_parse_source
from aiservice.common_utils import safe_isort
from core.languages.python.cst_utils import compare_unparsed_ast_to_source, parse_module_to_cst, unparse_parse_source
from core.languages.python.optimizer.models import CodeExplanationAndID
from core.languages.python.testgen.postprocessing.add_missing_imports import add_future_annotations_import
@ -43,29 +43,27 @@ def deduplicate_optimizations(
List[CodeExplanationAndID]: A list of CodeExplanationAndID objects with duplicates removed.
"""
seen_asts = set()
seen_asts: set[str] = set()
unique_optimizations = []
for optimization in optimized_code_and_explanations:
code_ast = ast.parse(optimization.cst_module.code)
code_ast_tuple = ast.dump(code_ast, annotate_fields=False)
if code_ast_tuple not in seen_asts:
seen_asts.add(code_ast_tuple)
normalized = ast.unparse(ast.parse(optimization.cst_module.code))
if normalized not in seen_asts:
seen_asts.add(normalized)
unique_optimizations.append(optimization)
return unique_optimizations
def equality_check(
original_module: cst.Module, optimized_code_and_explanations: list[CodeExplanationAndID]
original_module: cst.Module,
optimized_code_and_explanations: list[CodeExplanationAndID],
*,
original_code: str | None = None,
) -> list[CodeExplanationAndID]:
original_source_code = original_module.code
original_source_code = original_code if original_code is not None else original_module.code
try:
original_source_ast = unparse_parse_source(original_source_code)
except Exception:
return [
CodeExplanationAndID(cst_module=ce.cst_module, explanation=ce.explanation, id=ce.id)
for ce in optimized_code_and_explanations
if ce.cst_module.code != original_source_code
]
return [ce for ce in optimized_code_and_explanations if ce.cst_module.code != original_source_code]
filtered_optimizations = []
for ce in optimized_code_and_explanations:
try:
@ -147,13 +145,13 @@ class DocstringTransformer(CSTTransformer):
if not updated_node.get_docstring(clean=False):
new_body: list[BaseStatement] = [
SimpleStatementLine(body=[Expr(value=SimpleString(f'"""{original_docstring}"""'))]),
*cast(list[BaseStatement], list(updated_node.body.body)),
*cast("list[BaseStatement]", list(updated_node.body.body)),
]
updated_node = updated_node.with_changes(body=IndentedBlock(body=new_body))
else:
new_body = [
SimpleStatementLine(body=[Expr(value=SimpleString(f'"""{original_docstring}"""'))]),
*cast(list[BaseStatement], list(updated_node.body.body[1:])),
*cast("list[BaseStatement]", list(updated_node.body.body[1:])),
]
updated_node = updated_node.with_changes(body=IndentedBlock(body=new_body))
self.class_name = None
@ -167,13 +165,13 @@ class DocstringTransformer(CSTTransformer):
if not updated_node.get_docstring(clean=False):
new_body: list[BaseStatement] = [
SimpleStatementLine(body=[Expr(value=SimpleString(f'"""{original_docstring}"""'))]),
*cast(list[BaseStatement], list(updated_node.body.body)),
*cast("list[BaseStatement]", list(updated_node.body.body)),
]
updated_node = updated_node.with_changes(body=IndentedBlock(body=new_body))
else:
new_body = [
SimpleStatementLine(body=[Expr(value=SimpleString(f'"""{original_docstring}"""'))]),
*cast(list[BaseStatement], list(updated_node.body.body[1:])),
*cast("list[BaseStatement]", list(updated_node.body.body[1:])),
]
updated_node = updated_node.with_changes(body=IndentedBlock(body=new_body))
return updated_node
@ -207,13 +205,18 @@ def dedup_and_sort_imports(
new_optimized_code_and_explanations = []
for ce in optimized_code_and_explanations:
try:
# Use isort to sort and deduplicate the imports
sorted_code = safe_isort(ce.cst_module.code, disregard_skip=True)
original_code = ce.cst_module.code
sorted_code = safe_isort(original_code, disregard_skip=True)
except Exception:
sorted_code = ce.cst_module.code
new_optimized_code_and_explanations.append(
CodeExplanationAndID(cst_module=parse_module_to_cst(sorted_code), explanation=ce.explanation, id=ce.id)
)
new_optimized_code_and_explanations.append(ce)
continue
# Skip re-parse if isort didn't change anything
if sorted_code == original_code:
new_optimized_code_and_explanations.append(ce)
else:
new_optimized_code_and_explanations.append(
CodeExplanationAndID(cst_module=parse_module_to_cst(sorted_code), explanation=ce.explanation, id=ce.id)
)
return new_optimized_code_and_explanations
@ -299,7 +302,13 @@ def _strip_comments_from_code(code: str) -> str:
return code
def clean_extraneous_comments(original_module: cst.Module, optimized_module: cst.Module) -> cst.Module:
def clean_extraneous_comments(
original_module: cst.Module,
optimized_module: cst.Module,
*,
orig_code: str | None = None,
orig_code_stripped: str | None = None,
) -> cst.Module:
"""Clean extraneous comments from optimized code using difflib.
Uses diff-based approach on code (without comments) to identify which lines
@ -309,6 +318,8 @@ def clean_extraneous_comments(original_module: cst.Module, optimized_module: cst
----
original_module: The original CST module.
optimized_module: The optimized CST module with potential extra comments.
orig_code: Pre-computed original module code (avoids redundant codegen).
orig_code_stripped: Pre-computed comment-stripped original code.
Returns:
-------
@ -316,14 +327,19 @@ def clean_extraneous_comments(original_module: cst.Module, optimized_module: cst
"""
try:
# Get line-by-line representation
orig_lines = original_module.code.splitlines(keepends=True)
opt_lines = optimized_module.code.splitlines(keepends=True)
# Get line-by-line representation, reusing pre-computed strings when available
if orig_code is None:
orig_code = original_module.code
opt_code_full = optimized_module.code
orig_lines = orig_code.splitlines(keepends=True)
opt_lines = opt_code_full.splitlines(keepends=True)
# Strip comments from entire code to identify code changes
# This properly handles # symbols inside strings
orig_code_stripped = _strip_comments_from_code(original_module.code)
opt_code_stripped = _strip_comments_from_code(optimized_module.code)
if orig_code_stripped is None:
orig_code_stripped = _strip_comments_from_code(orig_code)
opt_code_stripped = _strip_comments_from_code(opt_code_full)
# Split stripped versions into lines
orig_code_only = orig_code_stripped.splitlines(keepends=True)
@ -523,18 +539,29 @@ def clean_extraneous_comments(original_module: cst.Module, optimized_module: cst
def clean_extraneous_comments_pipeline(
original_module: cst.Module, optimized_code_and_explanations: list[CodeExplanationAndID]
original_module: cst.Module,
optimized_code_and_explanations: list[CodeExplanationAndID],
*,
orig_code: str | None = None,
orig_code_stripped: str | None = None,
) -> list[CodeExplanationAndID]:
"""Pipeline wrapper for comment cleaning.
Cleans extraneous comments from all optimized code variants.
Pre-computes original code and stripped version once for all candidates.
"""
try:
cleaned_results = []
if orig_code is None:
orig_code = original_module.code
if orig_code_stripped is None:
orig_code_stripped = _strip_comments_from_code(orig_code)
for ce in optimized_code_and_explanations:
try:
cleaned_module = clean_extraneous_comments(original_module, ce.cst_module)
cleaned_module = clean_extraneous_comments(
original_module, ce.cst_module, orig_code=orig_code, orig_code_stripped=orig_code_stripped
)
cleaned_results.append(
CodeExplanationAndID(cst_module=cleaned_module, explanation=ce.explanation, id=ce.id)
)
@ -581,17 +608,23 @@ def fix_forward_references(
def optimizations_postprocessing_pipeline(
original_module: cst.Module, optimized_code_and_explanations: list[CodeExplanationAndID]
) -> list[CodeExplanationAndID]:
pipeline = [
fix_missing_docstring, # We want to deduplicate with the fixed docstrings included
clean_extraneous_comments_pipeline, # Clean comments added to unchanged code
fix_forward_references, # Add future annotations for forward references
deduplicate_optimizations,
equality_check,
dedup_and_sort_imports,
cleanup_explanations,
filter_ellipsis_containing_code,
]
# Pre-compute original code string once — avoids redundant CST codegen across steps
original_code = original_module.code
original_code_stripped = _strip_comments_from_code(original_code)
for pipeline_fn in pipeline:
optimized_code_and_explanations = pipeline_fn(original_module, optimized_code_and_explanations)
optimized_code_and_explanations = fix_missing_docstring(original_module, optimized_code_and_explanations)
optimized_code_and_explanations = clean_extraneous_comments_pipeline(
original_module,
optimized_code_and_explanations,
orig_code=original_code,
orig_code_stripped=original_code_stripped,
)
optimized_code_and_explanations = fix_forward_references(original_module, optimized_code_and_explanations)
optimized_code_and_explanations = deduplicate_optimizations(original_module, optimized_code_and_explanations)
optimized_code_and_explanations = equality_check(
original_module, optimized_code_and_explanations, original_code=original_code
)
optimized_code_and_explanations = dedup_and_sort_imports(original_module, optimized_code_and_explanations)
optimized_code_and_explanations = cleanup_explanations(original_module, optimized_code_and_explanations)
optimized_code_and_explanations = filter_ellipsis_containing_code(original_module, optimized_code_and_explanations)
return optimized_code_and_explanations

View file

@ -13,7 +13,7 @@ from ninja.errors import HttpError
from aiservice.analytics.posthog import ph
from aiservice.common.markdown_utils import extract_code_block, split_markdown_code
from aiservice.common_utils import safe_isort
from aiservice.env_specific import debug_log_sensitive_data
from aiservice.env_specific import debug_log_sensitive_data, debug_log_sensitive_data_from_callable
from aiservice.llm import llm_client
from aiservice.llm_models import EXECUTE_MODEL, HAIKU_MODEL, OPENAI_MODEL
from aiservice.models.functions_to_optimize import FunctionToOptimize
@ -30,7 +30,7 @@ from core.languages.python.testgen.postprocessing.postprocess_pipeline import po
from core.languages.python.testgen.preprocessing.preprocess_pipeline import generate_notes
from core.languages.python.testgen.validate import instrument_tests, validate_request_data
from core.log_features.log_event import update_optimization_cost
from core.log_features.log_features import log_features
from core.log_features.log_features import safe_log_features
from core.shared.jinja_utils import create_prompt_env
from core.shared.testgen_models import (
TestGenDebugInfo,
@ -197,7 +197,9 @@ async def generate_and_validate_test_code(
cost = response.cost
cost_tracker.add(cost)
debug_log_sensitive_data(f"LLM {error_context}execute response:\n{response.raw_response.model_dump_json(indent=2)}")
debug_log_sensitive_data_from_callable(
lambda: f"LLM {error_context}execute response:\n{response.raw_response.model_dump_json(indent=2)}"
)
if response.raw_response.usage:
ph(
@ -342,7 +344,7 @@ async def generate_regression_tests_from_function(
"validation_error": str(e),
},
) from e
except (SyntaxError, ValueError) as e:
except (SyntaxError, ValueError, cst.ParserSyntaxError) as e:
msg = f"Failed to generate valid {error_context}test code after {cost_tracker.calls} tries. trace_id={trace_id}"
logging.exception(msg)
raise TestGenerationFailedError(msg, debug_info={"stage": "unknown", "validation_error": str(e)}) from e
@ -402,7 +404,7 @@ async def testgen_python(
ph(request.user, "aiservice-testgen-tests-generated")
if hasattr(request, "should_log_features") and request.should_log_features:
await log_features(
await safe_log_features(
trace_id=data.trace_id,
user_id=request.user,
generated_tests=[raw_display_source],
@ -440,7 +442,7 @@ async def testgen_python(
lines_removed=int(lines_removed) if lines_removed is not None else None,
validation_error=str(v) if (v := e.debug_info.get("validation_error")) else None,
)
return 500, TestGenErrorResponseSchema(error=str(e), trace_id=data.trace_id, debug_info=debug_info)
return 422, TestGenErrorResponseSchema(error=str(e), trace_id=data.trace_id, debug_info=debug_info)
except Exception as e:
logging.exception("Test generation failed. trace_id=%s", data.trace_id)
sentry_sdk.capture_exception(e)

View file

@ -26,7 +26,7 @@ def _positional_list(items: list[str] | None, index: int | None) -> list[str | N
return result
@sync_to_async(thread_sensitive=True)
@sync_to_async(thread_sensitive=False)
@transaction.atomic
def log_features(
trace_id: str,
@ -56,32 +56,6 @@ def log_features(
ranking: dict[str, Any] | None = None,
optimizations_origin: dict[str, dict[str, str]] | None = None,
) -> None:
"""Log features of a code optimization run to the database.
:rtype: None
:param optimized_line_profiler_results: mapping of optimization candidate trace ids to line profiler results
:param trace_id: The client generated UUID of the optimization run. This is used to link the features together.
:param user_id: The user ID of the user who ran the optimization.
:param original_code: The original code that the LLM is allowed to modify.
:param dependency_code: The dependency code that the LLM is not allowed to modify.
:param line_profiler_results: The line profiling results for the original code.
:param optimizations_raw: The raw optimizations that were generated by the language model.
:param optimizations_post: The final optimizations that were generated by the optimization endpoint.
:param explanations_raw: Raw Explanations generated by the language model.
:param explanations_post: Final Explanations for the optimizations.
:param speedup_ratio: Speedups in fractions achieved by the optimizations.
:param original_runtime: The time taken in ns to run the original code.
:param optimized_runtime: The time taken in ns to run the optimized code.
:param is_correct: Behavioural correctness of the optimized code.
:param generated_tests: Generated tests for the optimized code, output of the aiservice test generation endpoint.
:param instrumented_generated_tests: Behavior instrumented tests for the optimized code.
:param instrumented_perf_tests: Performance instrumented tests for the optimized code.
:param test_framework: The test framework used to generate the tests.
:param datetime: The datetime of the optimization run. Should be calculated by the aiservice.
:param aiservice_commit: The commit hash of the AIService code used for the feature logging. hopefully should be the same for the entire run.
:param metadata: Additional metadata to log.
:param final_explanation: the final explanation in the PR
"""
f, created = OptimizationFeatures.objects.select_for_update().get_or_create(
trace_id=trace_id,
defaults={
@ -117,7 +91,6 @@ def log_features(
update_fields: list[str] = []
# Simple override fields - set if new value is provided
simple_updates = [
("user_id", user_id),
("original_code", original_code),
@ -138,7 +111,6 @@ def log_features(
setattr(f, field, value)
update_fields.append(field)
# Dict merge fields - merge new values into existing (only if new value provided)
dict_merge_updates = [
("optimizations_raw", f.optimizations_raw, optimizations_raw),
("optimizations_post", f.optimizations_post, optimizations_post),
@ -153,12 +125,10 @@ def log_features(
setattr(f, field, merge_dict(existing, new))
update_fields.append(field)
# Nested dict merge for optimizations_origin
if optimizations_origin is not None:
f.optimizations_origin = merge_dicts(f.optimizations_origin or {}, optimizations_origin)
update_fields.append("optimizations_origin")
# List fields — positional insertion when test_index is given, else append
list_updates = [
("generated_test", f.generated_test, generated_tests),
("instrumented_generated_test", f.instrumented_generated_test, instrumented_generated_tests),
@ -181,8 +151,14 @@ def log_features(
f.save(update_fields=update_fields)
async def safe_log_features(**kwargs: Any) -> None:  # noqa: ANN401
    """Best-effort wrapper around ``log_features``.

    Forwards all keyword arguments to ``log_features``; any exception raised
    while logging is captured to Sentry instead of propagating, so feature
    logging can never fail the caller's request.
    """
    try:
        await log_features(**kwargs)
    except Exception:  # noqa: BLE001
        # Feature logging is non-critical — report the failure and swallow it.
        sentry_sdk.capture_exception()
def merge_dict(existing: dict | None, new: dict | None) -> dict | None:
"""Merge new dict into existing, returning existing if new is None."""
if new is None:
return existing
return (existing or {}) | new
@ -198,7 +174,6 @@ def merge_dicts(a: dict[str, dict[str, str]], b: dict[str, dict[str, str]]) -> d
if key not in result:
result[key] = inner.copy()
else:
# b overrides a
result[key].update(inner)
return result

View file

@ -18,7 +18,13 @@ optimize_api = NinjaAPI(urls_namespace="optimize")
@optimize_api.post(
"/", response={200: OptimizeResponseSchema, 400: OptimizeErrorResponseSchema, 500: OptimizeErrorResponseSchema}
"/",
response={
200: OptimizeResponseSchema,
400: OptimizeErrorResponseSchema,
422: OptimizeErrorResponseSchema,
500: OptimizeErrorResponseSchema,
},
)
async def optimize(
request: AuthenticatedRequest, data: OptimizeSchema

View file

@ -15,7 +15,9 @@ from openai.types.chat import ChatCompletionSystemMessageParam, ChatCompletionUs
from aiservice.analytics.posthog import ph
from aiservice.common_utils import validate_trace_id
from aiservice.env_specific import debug_log_sensitive_data
from aiservice.llm import llm_client
import stamina
from aiservice.llm import LLMOutputUnparseable, llm_client
from aiservice.llm_models import LLM, RANKING_MODEL
from authapp.auth import AuthenticatedRequest
from core.log_features.log_event import update_optimization_cost
@ -319,9 +321,10 @@ def _scores_to_ranking(scores: CandidateScores) -> list[int]:
return sorted_candidates
@stamina.retry(on=LLMOutputUnparseable, attempts=2)
async def rank_optimizations( # noqa: D417
user_id: str, data: RankInputSchema, rank_model: LLM = RANKING_MODEL
) -> tuple[RankResponseSchema | RankErrorResponseSchema, float]:
) -> tuple[RankResponseSchema, float]:
"""Rank optimization candidates using multi-dimensional scoring.
Parameters
@ -363,9 +366,9 @@ async def rank_optimizations( # noqa: D417
"python_version": data.python_version,
},
)
except Exception:
except Exception as exc:
logging.exception("Ranking failed for trace_id=%s", data.trace_id)
return RankErrorResponseSchema(error="Failed to rank optimizations. Please try again."), 0.0
raise LLMOutputUnparseable("Failed to rank optimizations") from exc
debug_log_sensitive_data(f"AIClient optimization response:\n{output}")
if output.raw_response.usage is not None:
@ -432,7 +435,7 @@ async def rank_optimizations( # noqa: D417
logging.info("Derived ranking from scores")
else:
logging.warning("No valid ranking found")
return RankErrorResponseSchema(error="No ranking found"), output.cost
raise LLMOutputUnparseable("No valid ranking found", cost=output.cost)
return RankResponseSchema(ranking=ranking, explanation=explanation, scores=scores), output.cost
@ -456,18 +459,27 @@ class RankErrorResponseSchema(Schema):
error: str
@ranker_api.post("/", response={200: RankResponseSchema, 400: RankErrorResponseSchema, 500: RankErrorResponseSchema})
@ranker_api.post(
"/",
response={
200: RankResponseSchema,
400: RankErrorResponseSchema,
422: RankErrorResponseSchema,
500: RankErrorResponseSchema,
},
)
async def rank(
request: AuthenticatedRequest, data: RankInputSchema
) -> tuple[int, RankResponseSchema | RankErrorResponseSchema]:
ph(request.user, "aiservice-rank-called")
if not validate_trace_id(data.trace_id):
return 400, RankErrorResponseSchema(error="Invalid trace ID. Please provide a valid UUIDv4.")
ranking_response, llm_cost = await rank_optimizations(request.user, data)
if isinstance(ranking_response, RankErrorResponseSchema):
try:
ranking_response, llm_cost = await rank_optimizations(request.user, data)
except LLMOutputUnparseable:
ph(request.user, "Invalid Ranking, fallback to default")
debug_log_sensitive_data("No valid ranking was generated")
return 500, RankErrorResponseSchema(error="Error generating ranking. Internal server error.")
return 422, RankErrorResponseSchema(error="Error generating ranking. Internal server error.")
ph(request.user, "ranking generated", properties={"ranking": ranking_response})
ranking_0_idx = [x - 1 for x in ranking_response.ranking]
coros: list[Coroutine[Any, Any, Any]] = [

View file

@ -47,17 +47,28 @@ def rerun_optimize(record: OptimizationFeatures, source_filter: str) -> Optimize
def rerun_testgen(record: OptimizationFeatures, test_index: int) -> TestGenResponseSchema | None:
generated: list[str] = cast("list[str]", record.generated_test) or []
instrumented: list[str] = cast("list[str]", record.instrumented_generated_test) or []
perf: list[str] = cast("list[str]", record.instrumented_perf_test) or []
generated: list[str | None] = cast("list[str | None]", record.generated_test) or []
instrumented: list[str | None] = cast("list[str | None]", record.instrumented_generated_test) or []
perf: list[str | None] = cast("list[str | None]", record.instrumented_perf_test) or []
if test_index >= len(generated) or test_index >= len(instrumented):
return None
# Check if values at the index are None (can happen with NULL in database arrays)
generated_val = generated[test_index]
instrumented_val = instrumented[test_index]
if generated_val is None or instrumented_val is None:
return None
perf_val = perf[test_index] if test_index < len(perf) else None
# Default to empty string if perf value is None
perf_val = perf_val if perf_val is not None else ""
return TestGenResponseSchema(
generated_tests=generated[test_index],
instrumented_behavior_tests=instrumented[test_index],
instrumented_perf_tests=perf[test_index] if test_index < len(perf) else "",
generated_tests=generated_val,
instrumented_behavior_tests=instrumented_val,
instrumented_perf_tests=perf_val,
raw_generated_tests=None,
)

View file

@ -17,7 +17,13 @@ testgen_api = NinjaAPI(urls_namespace="testgen")
@testgen_api.post(
"/", response={200: TestGenResponseSchema, 400: TestGenErrorResponseSchema, 500: TestGenErrorResponseSchema}
"/",
response={
200: TestGenResponseSchema,
400: TestGenErrorResponseSchema,
422: TestGenErrorResponseSchema,
500: TestGenErrorResponseSchema,
},
)
async def testgen(
request: AuthenticatedRequest, data: TestGenSchema

View file

@ -0,0 +1,156 @@
"""Tests for LLM client event loop handling and connection cleanup."""
from __future__ import annotations
import asyncio
from typing import cast
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from aiservice.llm import LLMClient
class TestLLMClientEventLoopHandling:
"""Test that LLMClient properly handles event loop changes."""
    @pytest.mark.asyncio
    async def test_old_clients_are_closed_when_event_loop_changes(self) -> None:
        """Test that old HTTP clients are closed when switching event loops.

        This test reproduces the bug where old AsyncAzureOpenAI and
        AsyncAnthropicBedrock clients are not closed when the event loop
        changes, causing connection leaks and 'Event loop is closed' errors.
        """
        from aiservice.llm_models import LLM

        client = LLMClient()

        # Mock the API key check so we can test the cleanup logic
        with (
            patch("aiservice.llm.has_openai", True),
            patch("aiservice.llm.has_anthropic", True),
            patch("aiservice.llm.AsyncAzureOpenAI") as mock_openai_class,
            patch("aiservice.llm.AsyncAnthropicBedrock") as mock_anthropic_class,
        ):
            # Create mock instances with close methods
            # Need to create new instances on each call to simulate proper client creation
            def create_openai_mock(*args: object, **kwargs: object) -> MagicMock:
                # Factory: fresh mock OpenAI client whose chat.completions.create
                # resolves to a minimal completion-shaped response and whose
                # close() is awaitable so cleanup can be asserted on.
                mock = MagicMock()
                mock.close = AsyncMock()
                mock.chat = MagicMock()
                mock.chat.completions = MagicMock()
                mock.chat.completions.create = AsyncMock(
                    return_value=MagicMock(
                        choices=[MagicMock(message=MagicMock(content="test"))],
                        usage=MagicMock(prompt_tokens=10, completion_tokens=20),
                    )
                )
                return mock

            def create_anthropic_mock(*args: object, **kwargs: object) -> MagicMock:
                # Factory: fresh mock Anthropic client with an awaitable close().
                mock = MagicMock()
                mock.close = AsyncMock()
                return mock

            mock_openai_class.side_effect = create_openai_mock
            mock_anthropic_class.side_effect = create_anthropic_mock

            # Make first call in event loop 1
            test_llm = LLM(name="gpt-4.1", model_type="openai", input_cost=2.0, output_cost=8.0, cached_input_cost=None)
            await client.call(
                llm=test_llm, messages=[{"role": "user", "content": "test"}], call_type="test", trace_id="test-trace-1"
            )

            # Save reference to first clients
            first_openai_client = cast(MagicMock, client.openai_client)
            first_anthropic_client = cast(MagicMock, client.anthropic_client)
            first_loop = client.client_loop
            assert first_openai_client is not None
            assert first_anthropic_client is not None
            assert first_loop == asyncio.get_running_loop()

            # Simulate event loop change by creating a new loop and running in it
            # In Django/ASGI, this happens when requests are handled by different workers
            def make_call_in_new_loop() -> None:
                async def inner() -> None:
                    # Create a fresh LLM object in the new loop
                    new_llm = LLM(
                        name="gpt-4.1", model_type="openai", input_cost=2.0, output_cost=8.0, cached_input_cost=None
                    )
                    await client.call(
                        llm=new_llm,
                        messages=[{"role": "user", "content": "test2"}],
                        call_type="test",
                        trace_id="test-trace-2",
                    )

                new_loop = asyncio.new_event_loop()
                try:
                    new_loop.run_until_complete(inner())
                finally:
                    new_loop.close()

            # Make call in new event loop
            await asyncio.to_thread(make_call_in_new_loop)

            # Check that old clients were closed
            # THIS WILL FAIL with the current buggy code - old clients are NOT closed
            first_openai_client.close.assert_called_once()
            first_anthropic_client.close.assert_called_once()

            # Verify new clients were created
            assert client.openai_client is not first_openai_client
            assert client.anthropic_client is not first_anthropic_client
            assert client.client_loop != first_loop
@pytest.mark.asyncio
async def test_clients_are_not_recreated_in_same_event_loop(self) -> None:
"""Test that clients are reused when called in the same event loop."""
from aiservice.llm_models import LLM
client = LLMClient()
with (
patch("aiservice.llm.has_openai", True),
patch("aiservice.llm.has_anthropic", True),
patch("aiservice.llm.AsyncAzureOpenAI") as mock_openai_class,
patch("aiservice.llm.AsyncAnthropicBedrock") as mock_anthropic_class,
):
mock_openai_instance = MagicMock()
mock_openai_instance.chat = MagicMock()
mock_openai_instance.chat.completions = MagicMock()
mock_openai_instance.chat.completions.create = AsyncMock(
return_value=MagicMock(
choices=[MagicMock(message=MagicMock(content="test"))],
usage=MagicMock(prompt_tokens=10, completion_tokens=20),
)
)
mock_openai_class.return_value = mock_openai_instance
mock_anthropic_class.return_value = MagicMock()
test_llm = LLM(name="gpt-4.1", model_type="openai", input_cost=2.0, output_cost=8.0, cached_input_cost=None)
# Make first call
await client.call(
llm=test_llm, messages=[{"role": "user", "content": "test1"}], call_type="test", trace_id="test-trace-1"
)
first_openai_client = client.openai_client
first_anthropic_client = client.anthropic_client
# Make second call in same event loop
await client.call(
llm=test_llm, messages=[{"role": "user", "content": "test2"}], call_type="test", trace_id="test-trace-2"
)
# Clients should be the same instances (not recreated)
assert client.openai_client is first_openai_client
assert client.anthropic_client is first_anthropic_client
# Constructor should only be called once
assert mock_openai_class.call_count == 1
assert mock_anthropic_class.call_count == 1

View file

@ -181,6 +181,18 @@ x ="""
assert result == expected
def test_extract_code_block_with_filepath_annotation() -> None:
text = "```python:src/main.py\ndef foo(): pass\n```"
result = extract_code_block(text)
assert result == "def foo(): pass"
def test_extract_code_block_with_filepath_annotation_fallback() -> None:
text = "```python:src/main.py\ndef foo(): pass"
result = extract_code_block(text)
assert result == "def foo(): pass"
def test_extract_code_block_nested_code_fence_in_triple_quote() -> None:
# LLM embeds function definition in a triple-quoted string containing ```
text = '```python\nimport pytest\n_source = """```python:file.py\ndef foo(): pass\n```"""\ndef test_foo():\n assert True\n```'

View file

@ -1,4 +1,17 @@
import ast
from collections.abc import Generator
import pytest
from aiservice.background import _background_tasks
@pytest.fixture(autouse=True)
def _drain_background_tasks() -> Generator[None, None, None]:
    """Autouse teardown: cancel and forget background tasks left over by a test.

    Prevents stray tasks in the module-level ``_background_tasks`` registry
    from leaking between tests or outliving their event loop.
    """
    yield
    # Snapshot the registry first: cancelling may mutate the underlying set.
    leftover = list(_background_tasks)
    for pending_task in leftover:
        pending_task.cancel()
    _background_tasks.clear()
def normalize_code(code: str | None) -> str | None:

View file

@ -1,86 +1,58 @@
"""Tests for database connection handling in log_features.
This module tests that log_features properly handles concurrent database operations
without exhausting the connection pool.
IMPORTANT: This test verifies the fix for the PostgreSQL connection pool exhaustion bug.
The bug occurs when @sync_to_async(thread_sensitive=False) is used, causing each call
to grab a separate connection. With thread_sensitive=True, connections are properly reused.
Bug trace IDs: a0d8dab6-6524-47dc-9c82-5fa92e6390fb, 62f5c35b-7161-4ab0-958a-4865231f5188
Error: psycopg_pool.PoolTimeout: couldn't get a connection after 30.00 sec
"""
"""Tests for database connection handling in log_features."""
import ast
import inspect
from pathlib import Path
from core.log_features.log_features import safe_log_features
def test_log_features_uses_thread_sensitive_true() -> None:
"""Test that log_features uses thread_sensitive=True to prevent pool exhaustion.
This test verifies that the @sync_to_async decorator on log_features
has thread_sensitive=True, which is required to properly handle database
connections across async/sync boundaries without exhausting the pool.
def test_log_features_uses_thread_sensitive_false() -> None:
"""Verify thread_sensitive=False so concurrent calls get their own threads.
With thread_sensitive=False (the bug), each call would grab a separate connection.
With thread_sensitive=True (the fix), Django reuses connections properly.
This test parses the source code directly to check the decorator parameters.
thread_sensitive=True serializes all calls through one thread, creating a bottleneck.
thread_sensitive=False allows parallel execution with one DB connection per thread.
"""
# Read the source file
log_features_path = Path(__file__).parent.parent.parent / "core" / "log_features" / "log_features.py"
source_code = log_features_path.read_text()
tree = ast.parse(log_features_path.read_text())
# Parse the source code
tree = ast.parse(source_code)
# Find the log_features function
log_features_func = None
for node in ast.walk(tree):
if isinstance(node, ast.FunctionDef) and node.name == "log_features":
log_features_func = node
break
else:
raise AssertionError("Could not find log_features function")
assert log_features_func is not None, "Could not find log_features function"
# Check that it has decorators
assert len(log_features_func.decorator_list) > 0, "log_features should have decorators"
# Find the @sync_to_async decorator
sync_to_async_decorator = None
for decorator in log_features_func.decorator_list:
for decorator in node.decorator_list:
if isinstance(decorator, ast.Call):
if isinstance(decorator.func, ast.Name) and decorator.func.id == "sync_to_async":
sync_to_async_decorator = decorator
break
elif isinstance(decorator.func, ast.Attribute) and decorator.func.attr == "sync_to_async":
func = decorator.func
if (isinstance(func, ast.Name) and func.id == "sync_to_async") or (
isinstance(func, ast.Attribute) and func.attr == "sync_to_async"
):
sync_to_async_decorator = decorator
break
assert sync_to_async_decorator is not None, (
"log_features should be decorated with @sync_to_async. "
"Decorators found: " + ", ".join(ast.unparse(d) for d in log_features_func.decorator_list)
"Decorators found: " + ", ".join(ast.unparse(d) for d in node.decorator_list)
)
# Check the keyword arguments
thread_sensitive_arg = None
for keyword in sync_to_async_decorator.keywords:
if keyword.arg == "thread_sensitive":
thread_sensitive_arg = keyword.value
assert isinstance(keyword.value, ast.Constant), (
f"thread_sensitive should be a boolean constant, got {ast.unparse(keyword.value)}"
)
assert keyword.value.value is False, (
f"thread_sensitive should be False, got {keyword.value.value!r}. "
"True serializes all calls through one thread, causing a bottleneck."
)
break
assert thread_sensitive_arg is not None, (
"@sync_to_async should have thread_sensitive parameter. "
"Current decorator: " + ast.unparse(sync_to_async_decorator)
)
# Verify thread_sensitive=True (not False)
if isinstance(thread_sensitive_arg, ast.Constant):
assert thread_sensitive_arg.value is True, (
f"thread_sensitive should be True, got {thread_sensitive_arg.value!r}. "
"This causes PostgreSQL connection pool exhaustion!"
)
else:
# If it's not a constant, fail with a clear message
assert False, f"thread_sensitive should be a boolean constant True, got {ast.unparse(thread_sensitive_arg)}"
raise AssertionError(
"@sync_to_async should have thread_sensitive=False. "
"Current decorator: " + ast.unparse(sync_to_async_decorator)
)
def test_safe_log_features_is_async() -> None:
assert inspect.iscoroutinefunction(safe_log_features), "safe_log_features should be an async def"

View file

@ -38,13 +38,20 @@ class TestResolveImport:
def test_class_method(self) -> None:
assert _resolve_import("Validator.validateRequest", "../middlewares/Validator") == (
"default",
"named",
"Validator",
"Validator.validateRequest",
)
def test_class_method_default_export(self) -> None:
assert _resolve_import(
"Validator.validateRequest",
"../middlewares/Validator",
"export default class Validator { validateRequest() {} }",
) == ("default", "Validator", "Validator.validateRequest")
def test_static_method(self) -> None:
assert _resolve_import("Utils.formatDate", "../utils/Utils") == ("default", "Utils", "Utils.formatDate")
assert _resolve_import("Utils.formatDate", "../utils/Utils") == ("named", "Utils", "Utils.formatDate")
def test_function_with_underscore(self) -> None:
assert _resolve_import("get_git_root", "../utils/git") == ("named", "get_git_root", "get_git_root")
@ -263,7 +270,7 @@ class TestBuildJavascriptPrompt:
def test_esm_default_import(self) -> None:
messages, _ = build_javascript_prompt(
function_name="Svc.run",
function_code="class Svc { run() {} }",
function_code="export default class Svc { run() {} }",
module_path="../services/Svc",
test_framework="vitest",
is_async=False,
@ -292,7 +299,7 @@ class TestBuildJavascriptPrompt:
def test_cjs_class_method(self) -> None:
messages, _ = build_javascript_prompt(
function_name="Validator.validateRequest",
function_code="class Validator { validateRequest() {} }",
function_code="export default class Validator { validateRequest() {} }",
module_path="../middlewares/Validator",
test_framework="jest",
is_async=False,
@ -306,7 +313,7 @@ class TestBuildJavascriptPrompt:
def test_cjs_class_method_async(self) -> None:
messages, suffix = build_javascript_prompt(
function_name="Controller.asyncMethod",
function_code="class Controller { async asyncMethod() {} }",
function_code="export default class Controller { async asyncMethod() {} }",
module_path="../controllers/Controller",
test_framework="jest",
is_async=True,

View file

@ -1,11 +1,12 @@
import unittest
import libcst as cst
from django.test import TestCase
from core.languages.python.cst_utils import any_ellipsis_in_cst, ellipsis_in_cst_not_types
from core.languages.python.testgen.generate import did_generate_ellipsis
class TestEllipsisInCst(TestCase):
class TestEllipsisInCst(unittest.TestCase):
def test_ellipsis_in_cst_with_ellipsis(self) -> None:
"""Test that ellipsis_in_cst detects an ellipsis in the function body."""
code = """
@ -91,7 +92,7 @@ def other_function():
assert ellipsis_in_cst_not_types(module) is True
class TestContextDidGenerateEllipsis(TestCase):
class TestContextDidGenerateEllipsis(unittest.TestCase):
def test_single_context_flags_generated_ellipsis(self) -> None:
"""Single context should flag ellipsis when original had none."""
source_code = "def foo():\n return 42"

View file

@ -0,0 +1,121 @@
"""Tests for JavaScript/TypeScript import resolution.
Tests the _resolve_import function to ensure it correctly detects export styles
from source code and generates appropriate import statements.
"""
import pytest
from core.languages.js_ts.testgen import _resolve_import
class TestResolveImport:
    """Tests for _resolve_import function.

    Each test feeds a TS/JS source snippet plus a dotted function name and
    asserts the (import_style, import_name, function_accessor) triple.
    NOTE(review): interior indentation of the snippet literals follows the
    rendered source; presumably _resolve_import is whitespace-insensitive —
    confirm against the implementation.
    """

    def test_standalone_function_uses_named_import(self) -> None:
        """Test that standalone functions use named import style."""
        function_name = "execMongoEval"
        module_path = "../utils/mongo"
        source_code = """
export function execMongoEval(code: string) {
return eval(code);
}
"""
        import_style, import_name, function_accessor = _resolve_import(function_name, module_path, source_code)
        assert import_style == "named"
        assert import_name == "execMongoEval"
        assert function_accessor == "execMongoEval"

    def test_class_method_with_named_export_uses_named_import(self) -> None:
        """Test that class methods with 'export class' use named import style."""
        function_name = "ModulesContainer.getById"
        module_path = "../../injector/modules-container"
        source_code = """
export class ModulesContainer {
public getById(id: string): Module | undefined {
return Array.from(this.values()).find(m => m.id === id);
}
}
"""
        import_style, import_name, function_accessor = _resolve_import(function_name, module_path, source_code)
        # Should use NAMED import for named export
        assert import_style == "named", f"Expected 'named' import for 'export class', got '{import_style}'"
        assert import_name == "ModulesContainer"
        assert function_accessor == "ModulesContainer.getById"

    def test_class_method_with_default_export_uses_default_import(self) -> None:
        """Test that class methods with 'export default class' use default import style."""
        function_name = "Validator.validateRequest"
        module_path = "../validators/validator"
        source_code = """
export default class Validator {
validateRequest(req: Request): boolean {
return req.method === 'POST';
}
}
"""
        import_style, import_name, function_accessor = _resolve_import(function_name, module_path, source_code)
        # Should use DEFAULT import for default export
        assert import_style == "default"
        assert import_name == "Validator"
        assert function_accessor == "Validator.validateRequest"

    def test_class_method_without_export_uses_named_import(self) -> None:
        """Test that non-exported class methods default to named import (will fail, surfacing the issue)."""
        function_name = "PartialGraphHost.toJSON"
        module_path = "../inspector/partial-graph"
        source_code = """
class PartialGraphHost {
static toJSON() {
return this.partialGraph?.toJSON();
}
}
"""
        import_style, import_name, function_accessor = _resolve_import(function_name, module_path, source_code)
        # Should use NAMED import (even though not exported - this will fail at runtime,
        # which surfaces the issue that the class needs to be exported)
        assert import_style == "named", f"Expected 'named' import for non-exported class, got '{import_style}'"
        assert import_name == "PartialGraphHost"
        assert function_accessor == "PartialGraphHost.toJSON"

    def test_static_method_with_named_export(self) -> None:
        """Test that static methods on named exports use named import."""
        function_name = "ServerFactory.create"
        module_path = "../factories/server-factory"
        source_code = """
export class ServerFactory {
public static create<T>(server: T): ServerHost<T> {
return { server };
}
}
"""
        import_style, import_name, function_accessor = _resolve_import(function_name, module_path, source_code)
        assert import_style == "named"
        assert import_name == "ServerFactory"
        assert function_accessor == "ServerFactory.create"

    def test_namespace_fallback_for_complex_patterns(self) -> None:
        """Test that complex patterns fall back to namespace import."""
        function_name = "Constructor.prototype.method"
        module_path = "../utils/complex"
        source_code = """
function Constructor() {}
Constructor.prototype.method = function() {
return 42;
};
"""
        import_style, import_name, function_accessor = _resolve_import(function_name, module_path, source_code)
        # Prototype-based accessors have no clean named/default mapping.
        assert import_style == "namespace"
        assert function_accessor == "Constructor.prototype.method"

View file

@ -7,7 +7,7 @@ import re
import pytest
from core.languages.js_ts.testgen import build_javascript_prompt, parse_and_validate_js_output
from core.languages.js_ts.testgen import build_javascript_prompt, parse_and_validate_js_output, strip_js_extensions
def _has_test_functions(code: str) -> bool:
@ -343,3 +343,131 @@ class TestMochaPromptContent:
assert isinstance(user_content, str)
assert "vitest" in user_content.lower()
assert "expect" in user_content.lower()
class TestStripJsExtensions:
    """Tests for stripping .js/.ts extensions from import paths.

    Covers ES imports, CommonJS require(), jest.mock(), and the vi.mock()
    family; bare package specifiers must pass through untouched.
    """

    def test_strips_extensions_from_imports(self) -> None:
        """Test that extensions are stripped from ES module imports."""
        source = "import { x } from '../path/file.js';"
        result = strip_js_extensions(source)
        assert result == "import { x } from '../path/file';"

    def test_strips_extensions_from_require(self) -> None:
        """Test that extensions are stripped from require() calls."""
        source = "const x = require('../path/file.js');"
        result = strip_js_extensions(source)
        assert result == "const x = require('../path/file');"

    def test_strips_extensions_from_jest_mock(self) -> None:
        """Test that extensions are stripped from jest.mock() calls."""
        source = "jest.mock('../path/file.js', () => {});"
        result = strip_js_extensions(source)
        assert result == "jest.mock('../path/file', () => {});"

    def test_strips_extensions_from_vi_mock(self) -> None:
        """Test that extensions are stripped from vi.mock() calls (Vitest).

        This is a regression test for the bug where vi.mock() paths retained
        .js extensions while imports had them stripped, causing mock/import
        path mismatch in Vitest ESM mode.

        Trace IDs affected:
        - 0fe99c9f-b348-4f0a-b051-0ea9455231ba
        - 127cdaec-a343-4918-a86a-b646dd4d79cf
        - 2b6c896e-20d7-4505-8bf4-e4a2f20b37fc
        """
        source = "vi.mock('../config/paths.js', () => {});"
        result = strip_js_extensions(source)
        # This test will FAIL until the bug is fixed
        assert result == "vi.mock('../config/paths', () => {});"

    def test_strips_extensions_from_complex_vi_mock(self) -> None:
        """Test extension stripping for complex vi.mock() with multiline callback."""
        source = """vi.mock('../config/paths.js', () => {
return {
resolveCredentialsDir: vi.fn(() => '/mock/credentials'),
};
});"""
        result = strip_js_extensions(source)
        assert "vi.mock('../config/paths'" in result
        assert "vi.mock('../config/paths.js'" not in result

    def test_strips_all_vi_mock_variants(self) -> None:
        """Test that all vi.mock variants are handled."""
        source = """
vi.mock('../a.js', () => {});
vi.doMock('../b.js', () => {});
vi.unmock('../c.js');
"""
        result = strip_js_extensions(source)
        assert "../a'" in result
        assert "../b'" in result
        assert "../c'" in result
        assert ".js" not in result

    def test_preserves_node_modules_paths(self) -> None:
        """Test that node_modules paths (without ./) are not modified."""
        source = "import { x } from 'some-package';"
        result = strip_js_extensions(source)
        assert result == source

    def test_handles_mixed_mocks_and_imports(self) -> None:
        """Test realistic scenario with both vi.mock() and imports."""
        source = """vi.mock('../config/paths.js', () => {
return {
resolveCredentialsDir: vi.fn(() => '/mock/credentials'),
};
});
import { resolveChannelAllowFromPath } from './pairing/pairing-store.js';
import { resolveCredentialsDir } from '../config/paths.js';"""
        result = strip_js_extensions(source)
        # All .js extensions should be removed
        assert "vi.mock('../config/paths'" in result
        assert "from './pairing/pairing-store'" in result
        assert "from '../config/paths'" in result
        # No .js should remain
        assert ".js" not in result
class TestInstrumentedTestsExtensionStripping:
    """Tests for ensuring .js extensions are stripped from ALL test outputs."""

    def test_strip_extensions_on_all_outputs(self) -> None:
        """Test that .js extensions should be stripped from instrumented tests too.

        This is a regression test for the bug where strip_js_extensions() was only
        called on generated_test_source but not on instrumented_behavior_tests
        and instrumented_perf_tests, causing "Cannot find module" errors in the CLI.
        """
        # Simulated LLM output with .js extensions (what comes back from LLM)
        llm_generated_test = """import { buildVerifyFn } from '../../google.js';
import { authenticate } from '../../sso.js';

test('should create verify function', () => {
const fn = buildVerifyFn(mockSave);
expect(fn).toBeDefined();
});"""
        # All three test outputs should have extensions stripped
        # (in practice, instrumented tests have capture() calls added, but for this test we're checking extension stripping)
        expected_stripped = """import { buildVerifyFn } from '../../google';
import { authenticate } from '../../sso';

test('should create verify function', () => {
const fn = buildVerifyFn(mockSave);
expect(fn).toBeDefined();
});"""
        # Verify that strip_js_extensions works
        result = strip_js_extensions(llm_generated_test)
        assert result == expected_stripped, "strip_js_extensions should remove .js extensions"
        # Regression test: verifies strip_js_extensions() is applied correctly.
        # For full end-to-end coverage, an integration test calling testgen_javascript()
        # and asserting all three return values would be ideal.

View file

@ -1,12 +0,0 @@
node_modules/
dist/
build/
coverage/
*.config.js
.eslintrc.mjs
.eslintrc.json
postcss.config.js
tailwind.config.js
// Comment out the ESLint line temporarily to allow for the build to pass
**/*.ts
**/*.js

View file

@ -1,20 +0,0 @@
module.exports = {
root: true,
extends: ["next/core-web-vitals", "plugin:@typescript-eslint/recommended", "prettier"],
parser: "@typescript-eslint/parser",
parserOptions: {
project: "./tsconfig.json",
tsconfigRootDir: __dirname,
ecmaVersion: "latest",
sourceType: "module",
ecmaFeatures: {
jsx: true,
},
},
plugins: ["@typescript-eslint", "react"],
ignorePatterns: ["dist/**", "node_modules/**", "*.config.js", "*.config.mjs", ".eslintrc.js"],
rules: {
"react/react-in-jsx-scope": "off",
"@typescript-eslint/no-explicit-any": "warn",
},
}

View file

@ -0,0 +1,34 @@
- generic [ref=e2]:
- generic [ref=e4]:
- img [ref=e6]
- generic [ref=e12]:
- heading "Get started with Codeflash" [level=1] [ref=e13]
- paragraph [ref=e14]: Make all your code optimal
- button "Continue with GitHub" [ref=e15] [cursor=pointer]:
- img [ref=e16]
- generic [ref=e18]: Continue with GitHub
- generic [ref=e20]:
- link "Terms" [ref=e21] [cursor=pointer]:
- /url: https://www.codeflash.ai/terms-of-service
- link "Privacy" [ref=e22] [cursor=pointer]:
- /url: https://www.codeflash.ai/privacy-policy
- link "Documentation" [ref=e23] [cursor=pointer]:
- /url: https://docs.codeflash.ai
- generic [ref=e25]:
- heading "Always Ship Optimal Code" [level=2] [ref=e27]
- generic [ref=e28]:
- generic [ref=e29]:
- img [ref=e31]
- paragraph [ref=e34]: VS Code/Cursor Extension to optimize all code locally
- generic [ref=e35]:
- img [ref=e37]
- paragraph [ref=e40]: Set it as a GitHub action to automate optimization
- generic [ref=e41]:
- img [ref=e43]
- paragraph [ref=e46]: Codeflash finds 2-55x performance improvements automatically
- generic [ref=e47]:
- img [ref=e49]
- paragraph [ref=e52]: Confidently merge the tested and proven optimizations
- generic [ref=e53]:
- img [ref=e55]
- paragraph [ref=e58]: Start free. No credit card, no lock-in

View file

@ -0,0 +1,44 @@
import nextConfig from "eslint-config-next"
import prettier from "eslint-config-prettier"
// Flat ESLint config: Next.js presets + Prettier, plus project-specific tweaks.

// Locate the Next.js config entry that registers the @typescript-eslint plugin
// and layer our custom rule onto it; every other entry passes through as-is.
const withCustomTsRule = config => {
  if (!config.plugins?.["@typescript-eslint"]) {
    return config
  }
  return {
    ...config,
    rules: {
      ...config.rules,
      "@typescript-eslint/no-explicit-any": "warn",
    },
  }
}

const eslintConfig = [
  ...nextConfig.map(withCustomTsRule),
  prettier,
  {
    rules: {
      "react/react-in-jsx-scope": "off",
      // Downgrade React Compiler rules to warnings: pre-existing patterns, fix incrementally
      "react-hooks/set-state-in-effect": "warn",
      "react-hooks/error-boundaries": "warn",
      "react-hooks/immutability": "warn",
      "react-hooks/preserve-manual-memoization": "warn",
      "react-hooks/purity": "warn",
      "react-hooks/refs": "warn",
      "react-hooks/static-components": "warn",
    },
  },
  {
    // Build artifacts and tooling output are never linted.
    ignores: ["dist/**", "node_modules/**", "*.config.js", "*.config.mjs", ".next/**"],
  },
]

export default eslintConfig

View file

@ -1,3 +1,13 @@
import bundleAnalyzer from "@next/bundle-analyzer"
import { dirname } from "path"
import { fileURLToPath } from "url"
const withBundleAnalyzer = bundleAnalyzer({
enabled: process.env.ANALYZE === "true",
})
const __dirname = dirname(fileURLToPath(import.meta.url))
/** @type {import("next").NextConfig} */
const nextConfig = {
transpilePackages: ["@codeflash-ai/common"],
@ -25,12 +35,22 @@ const nextConfig = {
return config
},
turbopack: {
root: __dirname,
resolveAlias: {
// Stub Node.js built-ins that web-tree-sitter tries to import in the browser.
// Uses { browser: ... } so aliases only apply to client bundles, not SSR.
'fs': { browser: './src/lib/empty-shim.js' },
'fs/promises': { browser: './src/lib/empty-shim.js' },
'path': { browser: './src/lib/empty-shim.js' },
'module': { browser: './src/lib/empty-shim.js' },
},
},
experimental: {
serverActions: {
allowedOrigins: ["app.codeflash.ai", "localhost:3000"],
bodySizeLimit: '5mb', // Increased from default 1mb to handle large PR creation payloads
},
instrumentationHook: true,
},
typescript: {
ignoreBuildErrors: false,
@ -56,7 +76,7 @@ const nextConfig = {
import { withSentryConfig } from "@sentry/nextjs"
export default withSentryConfig(
export default withBundleAnalyzer(withSentryConfig(
nextConfig,
{
// For all available options, see:
@ -86,4 +106,4 @@ export default withSentryConfig(
// Disable automatic instrumentation that might cause issues
automaticVercelMonitors: false,
},
)
))

File diff suppressed because it is too large Load diff

View file

@ -7,10 +7,11 @@
"build": " npm install --loglevel verbose && npx prisma generate && npx next build",
"deploy": "az webapp up -n codeflash-webapp-2 --sku P1V2 --runtime NODE:20-lts",
"start": "node_modules/next/dist/bin/next start",
"lint": "next lint --fix",
"lint:check": "next lint",
"lint": "eslint --fix .",
"lint:check": "eslint .",
"test": "vitest",
"type-check": "tsc --noEmit",
"analyze": "ANALYZE=true next build",
"prisma:generate": "npx prisma generate",
"prisma:migrate": "npx prisma migrate dev",
"prepare": "simple-git-hooks",
@ -20,12 +21,14 @@
},
"dependencies": {
"@anthropic-ai/sdk": "^0.74.0",
"@auth0/nextjs-auth0": "^3.3.0",
"@azure/msal-node": "^3.7.3",
"@auth0/nextjs-auth0": "^4",
"@codeflash-ai/common": "^1.0.30",
"@hookform/resolvers": "^3.3.2",
"@monaco-editor/react": "^4.7.0",
"@opentelemetry/auto-instrumentations-node": "^0.72.0",
"@opentelemetry/sdk-node": "^0.214.0",
"@prisma/client": "^6.7.0",
"@prisma/instrumentation": "^7.6.0",
"@radix-ui/react-dialog": "^1.0.5",
"@radix-ui/react-dropdown-menu": "^2.0.6",
"@radix-ui/react-label": "^2.0.2",
@ -37,37 +40,38 @@
"@radix-ui/react-tabs": "^1.1.3",
"@radix-ui/react-toast": "^1.1.5",
"@radix-ui/react-tooltip": "^1.1.4",
"@sentry/nextjs": "^9.34.0",
"@sentry/nextjs": "^10.38.0",
"@sentry/opentelemetry": "^10.47.0",
"@types/node": "^24.3.0",
"@types/pg": "^8.10.9",
"@types/react": "^18",
"@types/react-dom": "^18",
"@types/react": "19.2.13",
"@types/react-dom": "19.2.3",
"@types/react-syntax-highlighter": "^15.5.13",
"chart.js": "^4.4.9",
"chartjs-plugin-datalabels": "^2.2.0",
"class-variance-authority": "^0.7.0",
"clsx": "^2.0.0",
"date-fns": "^4.1.0",
"diff": "^8.0.2",
"framer-motion": "^12.12.1",
"github-markdown-css": "^5.4.0",
"jsonwebtoken": "^9.0.2",
"lucide-react": "^0.381.0",
"lucide-react": "^0.563.0",
"marked": "^16.1.1",
"next": "^14.2.32",
"next-themes": "^0.3.0",
"motion": "^12.38.0",
"next": "16.1.6",
"next-themes": "^0.4.6",
"node-ts-cache": "^4.4.0",
"node-ts-cache-storage-memory": "^4.4.0",
"papaparse": "^5.5.3",
"pg": "^8.11.3",
"postcss": "^8",
"posthog-js": "1.127.0",
"posthog-node": "^4.0.1",
"prism-react-renderer": "^2.4.1",
"react": "^18",
"react": "19.2.4",
"react-chartjs-2": "^5.3.0",
"react-dom": "^18",
"react-dom": "19.2.4",
"react-hook-form": "^7.48.2",
"react-markdown": "^9.0.1",
"react-papaparse": "^4.4.0",
"react-resizable-panels": "^4.6.4",
"react-syntax-highlighter": "^16.1.0",
"remark-gfm": "^4.0.0",
@ -80,20 +84,16 @@
"zod": "^3.22.4"
},
"devDependencies": {
"@next/bundle-analyzer": "^16.2.2",
"@testing-library/react": "^16.0.0",
"@types/jsonwebtoken": "^9.0.10",
"@typescript-eslint/eslint-plugin": "^8.50.1",
"@typescript-eslint/parser": "^8.50.1",
"@types/papaparse": "^5.5.2",
"@vitejs/plugin-react": "^4.3.1",
"autoprefixer": "^10.0.1",
"baseline-browser-mapping": "^2.9.11",
"eslint": "^8.57.0",
"eslint-config-next": "15.5.2",
"eslint": "^9",
"eslint-config-next": "16.1.6",
"eslint-config-prettier": "^10.1.8",
"eslint-plugin-import": "^2.29.0",
"eslint-plugin-n": "^16.6.2",
"eslint-plugin-promise": "^6.1.1",
"eslint-plugin-react": "^7.33.2",
"jsdom": "^24.1.0",
"lint-staged": "^15.4.3",
"prettier": "3.2.5",
@ -117,5 +117,9 @@
"**/*.{json,md}": [
"prettier --write"
]
},
"overrides": {
"@types/react": "19.2.13",
"@types/react-dom": "19.2.3"
}
}

BIN
js/cf-webapp/roadmap.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 511 KiB

View file

@ -11,8 +11,10 @@ Sentry.init({
? "https://0fa0f40b2d709e4f1eb9aac76ff9e6be@o4506833230561280.ingest.us.sentry.io/4506833279582208"
: undefined,
// Adjust this value in production, or use tracesSampler for greater control
tracesSampleRate: 1,
tracesSampleRate: isProduction ? 0.1 : 1,
// Let the custom OTel setup in src/instrumentation.ts manage OpenTelemetry
skipOpenTelemetrySetup: true,
// Setting this option to true will print useful information to the console while you're setting up Sentry.
debug: false,

View file

@ -57,8 +57,8 @@ export async function fetchUserInfo(): Promise<{
error?: string
}> {
try {
const { getSession } = await import("@auth0/nextjs-auth0")
const session = await getSession()
const { auth0 } = await import("@/lib/auth0")
const session = await auth0.getSession()
if (!session?.user) {
return { error: "Unauthorized" }

View file

@ -1,5 +1,5 @@
import { redirect } from "next/navigation"
import { getSession } from "@auth0/nextjs-auth0"
import { auth0 } from "@/lib/auth0"
import Link from "next/link"
import { type JSX } from "react"
import { APP_ROUTES } from "@/lib/types"
@ -12,12 +12,13 @@ function isValidReturnUrl(url: string): boolean {
return false
}
export default async function AuthenticationPage({
searchParams,
}: {
searchParams: { returnTo?: string; error?: string }
}): Promise<JSX.Element> {
const session = await getSession()
export default async function AuthenticationPage(
props: {
searchParams: Promise<{ returnTo?: string; error?: string }>
}
): Promise<JSX.Element> {
const searchParams = await props.searchParams;
const session = await auth0.getSession()
if (session) {
// User is already logged in
@ -35,7 +36,7 @@ export default async function AuthenticationPage({
<h2 className="text-2xl font-bold">Login Error</h2>
<p className="mt-2">There was an error during login. Please try again.</p>
<Link
href="/api/auth/login"
href="/auth/login"
className="mt-4 inline-block rounded bg-blue-500 px-4 py-2 text-white"
>
Try Again
@ -52,6 +53,6 @@ export default async function AuthenticationPage({
: APP_ROUTES.BASE
console.log(`[Login Page] Redirecting to Auth0 with returnTo: ${returnTo}`)
const loginUrl = `/api/auth/login?returnTo=${encodeURIComponent(returnTo)}`
const loginUrl = `/auth/login?returnTo=${encodeURIComponent(returnTo)}`
redirect(loginUrl)
}

View file

@ -1,6 +1,6 @@
"use server"
import { getSession } from "@auth0/nextjs-auth0"
import { auth0 } from "@/lib/auth0"
import { markUserCompletedOnboarding, submitOnboardingQuestions } from "@codeflash-ai/common"
import PostHogClient from "@/lib/posthog"
import { redirect } from "next/navigation"
@ -11,7 +11,7 @@ export async function SubmitFirstOnboardingPage(
selectedOptions: string[],
customOptionInput: string,
): Promise<void> {
const session = await getSession()
const session = await auth0.getSession()
if (session == null) {
console.log("No session, redirecting to login")
redirect("/login")
@ -39,15 +39,16 @@ export async function SubmitFirstOnboardingPage(
custom_pain_point: customOptionInput,
},
})
await posthog?.shutdown()
await posthog?.flush()
await submitOnboardingQuestions(user_id, email)
// Check for saved redirect URL after onboarding completion
const returnUrl = cookies().get("returnAfterOnboarding")?.value
const cookieStore = await cookies()
const returnUrl = cookieStore.get("returnAfterOnboarding")?.value
console.log("Checking for saved returnUrl:", returnUrl)
if (returnUrl) {
console.log("Found saved returnUrl, redirecting to:", returnUrl)
cookies().delete("returnAfterOnboarding")
cookieStore.delete("returnAfterOnboarding")
redirect(returnUrl)
} else {
console.log("No saved returnUrl, redirecting to /app/gettingstarted")
@ -56,7 +57,7 @@ export async function SubmitFirstOnboardingPage(
}
export async function SubmitSkipOnboardingPage(): Promise<void> {
const session = await getSession()
const session = await auth0.getSession()
if (session == null) {
console.log("No session, redirecting to login")
redirect("/login")
@ -80,15 +81,16 @@ export async function SubmitSkipOnboardingPage(): Promise<void> {
username: nickname,
},
})
await posthog?.shutdown()
await posthog?.flush()
await markUserCompletedOnboarding(user_id)
// Checking for saved redirect URL after onboarding completion
const returnUrl = cookies().get("returnAfterOnboarding")?.value
const cookieStore = await cookies()
const returnUrl = cookieStore.get("returnAfterOnboarding")?.value
console.log(`Checking for saved returnTo URL: ${returnUrl}`)
if (returnUrl) {
console.log("Found saved returnUrl, redirecting to:", returnUrl)
cookies().delete("returnAfterOnboarding")
cookieStore.delete("returnAfterOnboarding")
redirect(returnUrl)
} else {
console.log("No saved returnUrl, redirecting to /app/gettingstarted")

View file

@ -1,6 +1,6 @@
"use server"
import { getSession } from "@auth0/nextjs-auth0"
import { auth0 } from "@/lib/auth0"
import PostHogClient from "@/lib/posthog"
import { redirect } from "next/navigation"
@ -10,7 +10,7 @@ export async function SubmitSecondOnboardingPage(
pythonLibraries: string[] | null,
colleagueInviteEmail: string | null,
): Promise<void> {
const session = await getSession()
const session = await auth0.getSession()
if (session == null) {
console.log("No session, redirecting to login")
redirect("/login")
@ -31,5 +31,5 @@ export async function SubmitSecondOnboardingPage(
...(colleagueInviteEmail && { colleague_invite_email: colleagueInviteEmail }),
},
})
await posthog?.shutdown()
await posthog?.flush()
}

View file

@ -2,7 +2,7 @@
import { useMemo, useState, useEffect, type ReactNode } from "react"
import { useRouter } from "next/navigation"
import { AnimatePresence, motion } from "framer-motion"
import { AnimatePresence, motion } from "motion/react"
import {
ArrowRight,
ArrowRightCircle,

View file

@ -6,13 +6,13 @@ import {
getUserReferralData,
} from "@codeflash-ai/common"
import PostHogClient from "@/lib/posthog"
import { getSession } from "@auth0/nextjs-auth0"
import { auth0 } from "@/lib/auth0"
export async function upsertReferralSource(
referralSource: string,
additionalComments?: string,
): Promise<any> {
const session = await getSession()
const session = await auth0.getSession()
if (session != null) {
setUserReferralData(session.user.sub, referralSource, additionalComments)
const posthog = PostHogClient()

View file

@ -1,4 +1,5 @@
"use client"
import { type JSX } from "react"
import { Button } from "@/components/ui/button"
import { Trash2 } from "lucide-react"
import { type cf_api_keys } from "@prisma/client"

View file

@ -116,7 +116,6 @@ export function CreateApiKeyDialog(): React.JSX.Element {
</DialogTrigger>
<DialogContent className="sm:max-w-[425px]">
<Form {...form}>
{/* eslint-disable-next-line @typescript-eslint/no-misused-promises */}
<form onSubmit={form.handleSubmit(onSubmit)} className="space-y-4">
<DialogHeader>
<DialogTitle>Create new API key</DialogTitle>

View file

@ -0,0 +1,20 @@
import { Skeleton } from "@/components/ui/skeleton"
/**
 * Route-level loading skeleton for the API keys page: a heading bar plus
 * three placeholder key rows shown while the server component streams in.
 */
export default function ApiKeysLoading() {
  // Three placeholder rows; the index is only used as a stable React key.
  const placeholderRows = Array.from({ length: 3 }, (_, rowIndex) => (
    <div key={rowIndex} className="flex items-center gap-4 p-4 rounded-lg border">
      <div className="flex-1 space-y-2">
        <Skeleton className="h-4 w-40" />
        <Skeleton className="h-3 w-56" />
      </div>
      <Skeleton className="h-8 w-16 rounded-md" />
    </div>
  ))

  return (
    <div className="py-6 sm:py-8 px-4 sm:px-6 max-w-[1400px] mx-auto">
      <Skeleton className="h-8 w-32 mb-6" />
      <div className="space-y-4">{placeholderRows}</div>
    </div>
  )
}

View file

@ -1,13 +1,13 @@
"use server"
import { getSession } from "@auth0/nextjs-auth0"
import { type JSX } from "react"
import { auth0 } from "@/lib/auth0"
import { CreateApiKeyDialog } from "./dialog-create-api-key"
import { Separator } from "@/components/ui/separator"
import { ApiKeyTable } from "./api-key-table"
import { type cf_api_keys, PrismaClient } from "@prisma/client"
import { type cf_api_keys } from "@prisma/client"
import PostHogClient from "@/lib/posthog"
import { VS_CODE_KEY_NAME } from "@codeflash-ai/common"
const prisma = new PrismaClient()
import { prisma } from "@/lib/prisma"
interface ApiKeyWithOrg extends cf_api_keys {
organization?: {
@ -23,7 +23,7 @@ interface ApiKeyWithOrg extends cf_api_keys {
}
export default async function APIKeyGenerator(): Promise<JSX.Element> {
const session = await getSession()
const session = await auth0.getSession()
// Auth handled by middleware + layout
if (!session?.user) {
throw new Error("Authentication required")
@ -40,10 +40,7 @@ export default async function APIKeyGenerator(): Promise<JSX.Element> {
// Fetch personal keys (no organization) and keys from user's organizations
const apiKeys: ApiKeyWithOrg[] = await prisma.cf_api_keys.findMany({
where: {
OR: [
{ user_id: userId, organization_id: null },
{ organization_id: { in: userOrgIds } },
],
OR: [{ user_id: userId, organization_id: null }, { organization_id: { in: userOrgIds } }],
},
include: {
organization: {
@ -68,7 +65,7 @@ export default async function APIKeyGenerator(): Promise<JSX.Element> {
event: "webapp-loaded-api-keys",
})
await posthog?.shutdown()
await posthog?.flush()
return (
<div>

View file

@ -1,5 +1,5 @@
"use server"
import { getSession } from "@auth0/nextjs-auth0"
import { auth0 } from "@/lib/auth0"
import { redirect } from "next/navigation"
import {
deleteAPIKeyById,
@ -8,15 +8,14 @@ import {
VS_CODE_KEY_NAME,
} from "@codeflash-ai/common"
import { TokenLimitExceededError } from "./token-error"
import { PrismaClient } from "@prisma/client"
const prisma = new PrismaClient()
import { prisma } from "@/lib/prisma"
import { trackApiKeyCreated } from "@/lib/analytics/tracking"
export async function generateToken(
keyName: string,
organizationId?: string,
): Promise<{ success: boolean; token: string | undefined; err: string | undefined }> {
const user = await getSession()
const user = await auth0.getSession()
if (user == null) {
redirect("/login")
}
@ -24,12 +23,16 @@ export async function generateToken(
try {
const token: string = await safeGenAndStoreAPITokenHash(keyName, userId, organizationId)
await trackApiKeyCreated(userId, { keyName, organizationId })
return { success: true, token, err: undefined }
} catch (error) {
if (error instanceof Error && error.message === "Token limit exceeded") {
return { success: false, err: new TokenLimitExceededError().message, token: undefined }
}
if (error instanceof Error && error.message === "User is not a member of the specified organization") {
if (
error instanceof Error &&
error.message === "User is not a member of the specified organization"
) {
return { success: false, err: error.message, token: undefined }
}
return {
@ -63,7 +66,7 @@ export async function generateTokenForVsCode(
}
}
export async function deleteAPIKey(id: number): Promise<void> {
const user = await getSession()
const user = await auth0.getSession()
if (user == null) {
redirect("/login")
return

View file

@ -0,0 +1,20 @@
import { Skeleton } from "@/components/ui/skeleton"
export default function BillingLoading() {
return (
<div className="py-6 sm:py-8 px-4 sm:px-6 max-w-[1400px] mx-auto">
<Skeleton className="h-8 w-32 mb-6" />
<div className="grid gap-6">
<div className="rounded-xl border p-6 space-y-4">
<Skeleton className="h-5 w-40" />
<Skeleton className="h-8 w-24" />
<Skeleton className="h-4 w-64" />
</div>
<div className="rounded-xl border p-6 space-y-4">
<Skeleton className="h-5 w-48" />
<Skeleton className="h-32 w-full rounded-md" />
</div>
</div>
</div>
)
}

View file

@ -1,22 +1,16 @@
"use server"
import { getSession } from "@auth0/nextjs-auth0"
import { auth0 } from "@/lib/auth0"
import { BillingView } from "./billing-view"
import PostHogClient from "@/lib/posthog"
import { trackBillingPageViewed } from "@/lib/analytics/tracking"
import { SUBSCRIPTION_PLANS, checkAndResetSubscriptionPeriod } from "@codeflash-ai/common"
export default async function BillingPage() {
const session = await getSession()
const session = await auth0.getSession()
if (!session?.user) return null
const userId = session.user.sub
try {
// Track page view
const posthog = PostHogClient()
posthog?.capture({
distinctId: userId,
properties: { username: session.user.nickname },
event: "webapp-loaded-billing-page",
})
await posthog?.shutdown()
await trackBillingPageViewed(userId, { username: session.user.nickname })
// Get subscription info from database with lazy reset
const subscription = (await checkAndResetSubscriptionPeriod(userId)) || {

View file

@ -1,9 +1,9 @@
import { getSession } from "@auth0/nextjs-auth0"
import { auth0 } from "@/lib/auth0"
import PostHogClient from "@/lib/posthog"
import GettingStartedClient from "./getting-started-client"
export default async function GettingStarted() {
const session = await getSession()
const session = await auth0.getSession()
if (!session) return null
const userId = session.user.sub
@ -14,7 +14,7 @@ export default async function GettingStarted() {
event: "webapp-loaded-getting-started",
})
await posthog?.shutdown()
await posthog?.flush()
return <GettingStartedClient />
}

View file

@ -1,10 +1,10 @@
import { getSession } from "@auth0/nextjs-auth0"
import { auth0 } from "@/lib/auth0"
import { redirect } from "next/navigation"
import { ReactNode } from "react"
import { hasCompletedOnboarding } from "@codeflash-ai/common"
export default async function DashboardLayout({ children }: { children: ReactNode }) {
const session = await getSession()
const session = await auth0.getSession()
if (!session) return null
const completedOnboarding = await hasCompletedOnboarding(session.user.sub)

View file

@ -0,0 +1,118 @@
import { describe, it, expect, vi, beforeEach } from "vitest"
import { prisma } from "@codeflash-ai/common"
// Replace withTiming with a pass-through so tests call the wrapped action body directly.
vi.mock("@/lib/server-action-timing", () => ({
  withTiming: vi.fn((_name: string, fn: Function) => fn),
}))
// Stub analytics so no tracking side effects fire during tests.
vi.mock("@/lib/analytics/tracking", () => ({
  trackMemberInvited: vi.fn(),
}))
// Fixture: an organization with one admin (user-1) and one plain member (user-2).
// NOTE(review): prisma itself appears to be mocked elsewhere (vi.mocked(...) is
// used on it below) — presumably in a shared vitest setup file; confirm.
const mockOrg = {
  id: "org-1",
  organization_members: [
    {
      id: "member-1",
      user_id: "user-1",
      role: "admin",
      added_at: new Date("2024-01-15"),
      user: {
        github_username: "alice",
        name: "Alice Smith",
        email: "alice@example.com",
      },
    },
    {
      id: "member-2",
      user_id: "user-2",
      role: "member",
      added_at: new Date("2024-02-01"),
      user: {
        github_username: "bob",
        name: "Bob Jones",
        email: "bob@example.com",
      },
    },
  ],
}
describe("getOrganizationMembers", () => {
  let getOrganizationMembers: typeof import("../action").getOrganizationMembers
  beforeEach(async () => {
    // Dynamic import so the action module is loaded after the vi.mock factories apply.
    const mod = await import("../action")
    getOrganizationMembers = mod.getOrganizationMembers
  })
  describe("successful retrieval", () => {
    it("returns members when user has access", async () => {
      vi.mocked(prisma.organizations.findFirst).mockResolvedValue(mockOrg as any)
      const result = await getOrganizationMembers("user-1", "org-1")
      expect(result.success).toBe(true)
      expect(result.data).toHaveLength(2)
    })
    it("maps nested organization_members to flat Member structure", async () => {
      vi.mocked(prisma.organizations.findFirst).mockResolvedValue(mockOrg as any)
      const result = await getOrganizationMembers("user-1", "org-1")
      const member = result.data![0]
      // Pins the flattening contract: github_username -> username, and the
      // avatar URL is derived from the GitHub username.
      expect(member).toEqual({
        id: "member-1",
        user_id: "user-1",
        username: "alice",
        name: "Alice Smith",
        email: "alice@example.com",
        role: "admin",
        added_at: new Date("2024-01-15"),
        avatarUrl: "https://github.com/alice.png",
      })
    })
  })
  describe("access control", () => {
    it("returns error when organization not found", async () => {
      vi.mocked(prisma.organizations.findFirst).mockResolvedValue(null)
      const result = await getOrganizationMembers("user-1", "org-1")
      expect(result.success).toBe(false)
      expect(result.error).toBe("Organization not found")
    })
    it("returns error when user is not in organization members", async () => {
      // Caller is not among mockOrg's members, so access is denied.
      vi.mocked(prisma.organizations.findFirst).mockResolvedValue(mockOrg as any)
      const result = await getOrganizationMembers("unknown-user", "org-1")
      expect(result.success).toBe(false)
      expect(result.error).toBe("You don't have access to this organization")
    })
  })
  describe("error handling", () => {
    it("returns error response when Prisma throws", async () => {
      // Error instances surface their message to the caller.
      vi.mocked(prisma.organizations.findFirst).mockRejectedValue(
        new Error("Connection failed"),
      )
      const result = await getOrganizationMembers("user-1", "org-1")
      expect(result.success).toBe(false)
      expect(result.error).toBe("Connection failed")
    })
    it("uses fallback message for non-Error exceptions", async () => {
      // Non-Error rejections fall back to a generic message.
      vi.mocked(prisma.organizations.findFirst).mockRejectedValue("string error")
      const result = await getOrganizationMembers("user-1", "org-1")
      expect(result.success).toBe(false)
      expect(result.error).toBe("Failed to get members")
    })
  })
})

View file

@ -8,14 +8,18 @@ import {
organizationMemberRepository,
prisma,
} from "@codeflash-ai/common"
import { withTiming } from "@/lib/server-action-timing"
import { trackMemberInvited } from "@/lib/analytics/tracking"
/**
* Get organization members
*/
export async function getOrganizationMembers(
currentUserId: string,
organizationId: string,
): Promise<ActionResponse<Member[]>> {
export const getOrganizationMembers = withTiming(
"getOrganizationMembers",
async (
currentUserId: string,
organizationId: string,
): Promise<ActionResponse<Member[]>> => {
try {
const org = await prisma.organizations.findFirst({
where: { id: organizationId },
@ -58,7 +62,8 @@ export async function getOrganizationMembers(
console.error("Failed to get organization members:", error)
return createErrorResponse(error instanceof Error ? error.message : "Failed to get members")
}
}
},
)
/**
* Add a member to organization
@ -121,6 +126,14 @@ export async function addOrganizationMember(
added_by: currentUserId,
},
})
trackMemberInvited(currentUserId, {
invitedUsername: invitedUser.username,
role,
scope: "organization",
targetId: organizationId,
})
return createSuccessResponse({
id: newMember.id,
user_id: newMember.user_id,

View file

@ -0,0 +1,21 @@
import { Skeleton } from "@/components/ui/skeleton"
/**
 * Route-level loading skeleton for the members page: a heading bar plus
 * five placeholder member rows (avatar, two text lines, role badge).
 */
export default function MembersLoading() {
  const ROW_COUNT = 5

  return (
    <div className="py-6 sm:py-8 px-4 sm:px-6 max-w-[1400px] mx-auto">
      <Skeleton className="h-8 w-40 mb-6" />
      <div className="space-y-4">
        {Array.from({ length: ROW_COUNT }, (_, rowIndex) => (
          <div key={rowIndex} className="flex items-center gap-4 p-4 rounded-lg border">
            {/* Avatar placeholder */}
            <Skeleton className="h-10 w-10 rounded-full" />
            <div className="flex-1 space-y-2">
              <Skeleton className="h-4 w-32" />
              <Skeleton className="h-3 w-48" />
            </div>
            {/* Role badge placeholder */}
            <Skeleton className="h-6 w-20 rounded-full" />
          </div>
        ))}
      </div>
    </div>
  )
}

View file

@ -59,13 +59,14 @@ function OrganizationMembers() {
setCurrentUserId(data.userId)
const roleResult = await getCurrentUserRole(data.userId, currentOrg?.id)
const [roleResult, result] = await Promise.all([
getCurrentUserRole(data.userId, currentOrg?.id),
getOrganizationMembers(data.userId, currentOrg?.id),
])
if (roleResult.success && roleResult.data) {
setCurrentUserRole(roleResult.data.role)
}
const result = await getOrganizationMembers(data.userId, currentOrg?.id)
if (result.success && result.data) {
setMembers(result.data)
} else {
@ -103,10 +104,7 @@ function OrganizationMembers() {
setSuccess("Member added successfully!")
}
const handleUserAdd = async (
user: GitHubUserSearchResult,
role: "admin" | "member",
) => {
const handleUserAdd = async (user: GitHubUserSearchResult, role: "admin" | "member") => {
if (!currentOrg?.id) {
return { success: false, error: "No organization selected" }
}

View file

@ -0,0 +1,163 @@
import { describe, it, expect, vi, beforeEach } from "vitest"
import { prisma } from "@codeflash-ai/common"
import { getRepositoriesForAccountCached } from "@/lib/services/repository-utils"
import { trackRepositoryConnected } from "@/lib/analytics/tracking"
// Replace withTiming with a pass-through so tests call the wrapped action body directly.
vi.mock("@/lib/server-action-timing", () => ({
  withTiming: vi.fn((_name: string, fn: Function) => fn),
}))
// Stub the repo-authorization lookup; individual tests set its resolved value.
vi.mock("@/lib/services/repository-utils", () => ({
  getRepositoriesForAccountCached: vi.fn(),
}))
// Stub analytics so no tracking side effects fire during tests.
vi.mock("@/lib/analytics/tracking", () => ({
  trackMemberInvited: vi.fn(),
  trackRepositoryConnected: vi.fn(),
}))
// Fixture: a repository row as returned by prisma, including the
// repository_members relation used for the membersCount mapping.
const mockRepo = {
  id: "repo-1",
  github_repo_id: "12345",
  name: "my-repo",
  full_name: "myorg/my-repo",
  is_private: false,
  has_github_action: true,
  created_at: new Date("2024-01-01"),
  last_optimized: new Date("2024-06-01"),
  optimizations_limit: 100,
  optimizations_used: 50,
  repository_members: [{ id: "rm-1" }, { id: "rm-2" }],
}
// Fixture: a user-style AccountPayload (has a userId field).
const mockPayload = { userId: "user-1", username: "testuser" }
describe("getRepositoryById", () => {
  let getRepositoryById: typeof import("../action").getRepositoryById
  beforeEach(async () => {
    // Dynamic import so the action module is loaded after the vi.mock factories apply.
    const mod = await import("../action")
    getRepositoryById = mod.getRepositoryById
  })
  describe("parallel fetch", () => {
    it("fetches repo and authorized repoIds concurrently", async () => {
      vi.mocked(prisma.repositories.findFirst).mockResolvedValue(mockRepo as any)
      vi.mocked(getRepositoriesForAccountCached).mockResolvedValue({
        repoIds: ["repo-1"],
        repos: [],
      } as any)
      vi.mocked(prisma.optimization_events.count).mockResolvedValue(5)
      await getRepositoryById(mockPayload as any, "repo-1")
      // Both lookups fire exactly once and the auth lookup gets the raw payload.
      expect(prisma.repositories.findFirst).toHaveBeenCalledTimes(1)
      expect(getRepositoriesForAccountCached).toHaveBeenCalledWith(mockPayload)
    })
    it("returns null when repo is not found", async () => {
      vi.mocked(prisma.repositories.findFirst).mockResolvedValue(null)
      vi.mocked(getRepositoriesForAccountCached).mockResolvedValue({
        repoIds: ["repo-1"],
        repos: [],
      } as any)
      const result = await getRepositoryById(mockPayload as any, "repo-1")
      expect(result).toBeNull()
    })
    it("returns null when repo is not in authorized list", async () => {
      // Repo exists but caller isn't authorized for it.
      vi.mocked(prisma.repositories.findFirst).mockResolvedValue(mockRepo as any)
      vi.mocked(getRepositoriesForAccountCached).mockResolvedValue({
        repoIds: ["other-repo"],
        repos: [],
      } as any)
      const result = await getRepositoryById(mockPayload as any, "repo-1")
      expect(result).toBeNull()
    })
  })
  describe("successful retrieval", () => {
    beforeEach(() => {
      // Happy path: repo found and authorized.
      vi.mocked(prisma.repositories.findFirst).mockResolvedValue(mockRepo as any)
      vi.mocked(getRepositoriesForAccountCached).mockResolvedValue({
        repoIds: ["repo-1"],
        repos: [],
      } as any)
    })
    it("returns RepositoryWithUsage with all required fields", async () => {
      vi.mocked(prisma.optimization_events.count).mockResolvedValue(3)
      const result = await getRepositoryById(mockPayload as any, "repo-1")
      // Pins the mapping contract: org/avatar derived from full_name,
      // membersCount from the relation, is_active from recent event count.
      expect(result).toEqual({
        id: "repo-1",
        github_repo_id: "12345",
        name: "my-repo",
        full_name: "myorg/my-repo",
        is_private: false,
        is_active: true,
        has_github_action: true,
        created_at: new Date("2024-01-01"),
        last_optimized: new Date("2024-06-01"),
        optimizations_limit: 100,
        optimizations_used: 50,
        organization: "myorg",
        avatarUrl: "https://github.com/myorg.png",
        membersCount: 2,
      })
    })
    it("sets is_active to false when no recent events", async () => {
      vi.mocked(prisma.optimization_events.count).mockResolvedValue(0)
      const result = await getRepositoryById(mockPayload as any, "repo-1")
      expect(result!.is_active).toBe(false)
    })
    it("sets is_active to true when recent events exist", async () => {
      vi.mocked(prisma.optimization_events.count).mockResolvedValue(10)
      const result = await getRepositoryById(mockPayload as any, "repo-1")
      expect(result!.is_active).toBe(true)
    })
  })
  describe("analytics tracking", () => {
    beforeEach(() => {
      vi.mocked(prisma.repositories.findFirst).mockResolvedValue(mockRepo as any)
      vi.mocked(getRepositoriesForAccountCached).mockResolvedValue({
        repoIds: ["repo-1"],
        repos: [],
      } as any)
      vi.mocked(prisma.optimization_events.count).mockResolvedValue(1)
    })
    it("calls trackRepositoryConnected for user payloads", async () => {
      // Payloads with a userId trigger the engagement-tracking call.
      await getRepositoryById(mockPayload as any, "repo-1")
      expect(trackRepositoryConnected).toHaveBeenCalledWith("user-1", {
        repositoryId: "repo-1",
        repositoryName: "myorg/my-repo",
      })
    })
  })
  describe("error handling", () => {
    it("returns null and logs when Prisma throws", async () => {
      // Silence the action's console.error so test output stays clean.
      vi.spyOn(console, "error").mockImplementation(() => {})
      vi.mocked(prisma.repositories.findFirst).mockRejectedValue(
        new Error("timeout"),
      )
      vi.mocked(getRepositoriesForAccountCached).mockResolvedValue({
        repoIds: ["repo-1"],
        repos: [],
      } as any)
      const result = await getRepositoryById(mockPayload as any, "repo-1")
      expect(result).toBeNull()
    })
  })
})

View file

@ -1,12 +1,14 @@
"use server"
import * as Sentry from "@sentry/node"
import * as Sentry from "@sentry/nextjs"
import { AccountPayload, createOrUpdateUser, getUserById, prisma } from "@codeflash-ai/common"
import { eachDayOfInterval, startOfDay } from "date-fns"
import { GitHubUserSearchResult, Member, UserRole } from "@/lib/types"
import { ActionResponse, createErrorResponse, createSuccessResponse } from "@/lib/action-response"
import { RepositoryWithUsage } from "@/app/dashboard/action"
import { getRepositoriesForAccountCached } from "@/lib/services/repository-utils"
import { withTiming } from "@/lib/server-action-timing"
import { trackMemberInvited, trackRepositoryConnected } from "@/lib/analytics/tracking"
export async function getOptimizationsTimeSeriesData(repoId: string, onlySuccessful?: boolean) {
try {
@ -158,53 +160,61 @@ export async function getActiveUserLeaderboardLast30DaysForRepo(
}))
}
export async function getRepositoryById(
payload: AccountPayload,
repoId: string,
): Promise<RepositoryWithUsage | null> {
try {
const repo = await prisma.repositories.findFirst({
where: {
id: repoId,
},
include: {
repository_members: true,
},
})
const repoIds = await (await getRepositoriesForAccountCached(payload)).repoIds
export const getRepositoryById = withTiming(
"getRepositoryById",
async (payload: AccountPayload, repoId: string): Promise<RepositoryWithUsage | null> => {
try {
// Fetch repo and authorized repoIds in parallel
const [repo, { repoIds }] = await Promise.all([
prisma.repositories.findFirst({
where: { id: repoId },
include: { repository_members: true },
}),
getRepositoriesForAccountCached(payload),
])
if (!repo || !repoIds.includes(repo.id)) return null
if (!repo || !repoIds.includes(repo.id)) return null
const recentEventCount = await prisma.optimization_events.count({
where: {
repository_id: repo.id,
created_at: {
gte: new Date(Date.now() - 30 * 24 * 60 * 60 * 1000),
const recentEventCount = await prisma.optimization_events.count({
where: {
repository_id: repo.id,
created_at: {
gte: new Date(Date.now() - 30 * 24 * 60 * 60 * 1000),
},
},
},
})
})
return {
id: repo.id,
github_repo_id: repo.github_repo_id,
name: repo.name,
full_name: repo.full_name,
is_private: repo.is_private,
is_active: recentEventCount > 0,
has_github_action: repo.has_github_action,
created_at: repo.created_at,
last_optimized: repo.last_optimized,
optimizations_limit: repo.optimizations_limit,
optimizations_used: repo.optimizations_used,
organization: repo.full_name.split("/")[0],
avatarUrl: `https://github.com/${repo.full_name.split("/")[0]}.png`,
membersCount: repo.repository_members.length,
// Track repository view as a connection/engagement signal
const userId = "userId" in payload ? payload.userId : undefined
if (userId) {
trackRepositoryConnected(userId, {
repositoryId: repo.id,
repositoryName: repo.full_name,
})
}
return {
id: repo.id,
github_repo_id: repo.github_repo_id,
name: repo.name,
full_name: repo.full_name,
is_private: repo.is_private,
is_active: recentEventCount > 0,
has_github_action: repo.has_github_action,
created_at: repo.created_at,
last_optimized: repo.last_optimized,
optimizations_limit: repo.optimizations_limit,
optimizations_used: repo.optimizations_used,
organization: repo.full_name.split("/")[0],
avatarUrl: `https://github.com/${repo.full_name.split("/")[0]}.png`,
membersCount: repo.repository_members.length,
}
} catch (error) {
console.error("Failed to fetch repository by ID:", error)
return null
}
} catch (error) {
console.error("Failed to fetch repository by ID:", error)
return null
}
}
},
)
export async function addRepositoryMemberById(
currentUserId: string,
@ -265,6 +275,13 @@ export async function addRepositoryMemberById(
},
})
trackMemberInvited(currentUserId, {
invitedUsername: invitedUser.username,
role,
scope: "repository",
targetId: repoId,
})
return createSuccessResponse({
id: newMember.id,
user_id: newMember.user_id,

View file

@ -576,9 +576,22 @@ function RepositoryDetail() {
setRepository(currentRepo)
const totalAttempts = await getUserOptimizationCountByRepo(repositoryId)
const successfulAttempts = await getUserOptimizationSuccessfulCountByRepo(repositoryId)
const optimizationsOverTime = await getOptimizationsTimeSeriesData(repositoryId, false)
// Fetch all statistics in parallel - these are all independent queries
const [
totalAttempts,
successfulAttempts,
optimizationsOverTime,
successfulOptimizationsOverTime,
prData,
leaderboardData,
] = await Promise.all([
getUserOptimizationCountByRepo(repositoryId),
getUserOptimizationSuccessfulCountByRepo(repositoryId),
getOptimizationsTimeSeriesData(repositoryId, false),
getOptimizationsTimeSeriesData(repositoryId, true),
getPullRequestEventTimeSeriesData(selectedPrYear, repositoryId),
getActiveUserLeaderboardLast30DaysForRepo(repositoryId),
])
if (Array.isArray(optimizationsOverTime) && optimizationsOverTime.length > 0) {
const optimizationValues = optimizationsOverTime.map(item => item?.count || 0)
@ -590,11 +603,6 @@ function RepositoryDetail() {
setOptimizationsTrendDates([])
}
const successfulOptimizationsOverTime = await getOptimizationsTimeSeriesData(
repositoryId,
true,
)
if (
Array.isArray(successfulOptimizationsOverTime) &&
successfulOptimizationsOverTime.length > 0
@ -608,16 +616,12 @@ function RepositoryDetail() {
setSuccessfulOptimizationsTrendDates([])
}
const prData = await getPullRequestEventTimeSeriesData(selectedPrYear, repositoryId)
if (Array.isArray(prData)) {
setPrActivityData(prData)
} else {
setPrActivityData([])
}
const leaderboardData = await getActiveUserLeaderboardLast30DaysForRepo(repositoryId)
if (Array.isArray(leaderboardData)) {
setActiveUsersData(leaderboardData)
} else {

View file

@ -0,0 +1,438 @@
"use client"
import { useState, useMemo } from "react"
import {
Clock,
GitPullRequest,
Search,
ChevronDown,
X,
RefreshCw,
Filter,
ArrowUpDown,
BookOpen,
} from "lucide-react"
import Image from "next/image"
import { Card } from "@/components/ui/card"
import Link from "next/link"
import { useRouter } from "next/navigation"
import type { RepositoryWithUsage } from "@/app/dashboard/action"
/** Serialized version for server→client boundary (Dates become ISO strings) */
type SerializedRepository = Omit<RepositoryWithUsage, "created_at" | "last_optimized"> & {
created_at: string
last_optimized: string | null
}
import { useOutsideClick } from "@/components/hooks/useOutsideClick"
/**
 * Controlled search input for filtering the repository list.
 *
 * Renders a magnifier icon, the text input, and — only while a query is
 * present — a clear ("X") button that resets the query to the empty string.
 */
function SearchBar({
  searchQuery,
  setSearchQuery,
}: {
  searchQuery: string
  setSearchQuery: (value: string) => void
}) {
  return (
    <div className="relative flex-1 group">
      <div className="relative">
        <div className="absolute inset-y-0 left-0 flex items-center pl-3 pointer-events-none">
          <Search
            size={18}
            className="text-muted-foreground/70 group-focus-within:text-primary transition-colors"
          />
        </div>
        <input
          type="text"
          placeholder="Search repositories..."
          value={searchQuery}
          onChange={e => setSearchQuery(e.target.value)}
          className="block w-full rounded-xl border border-border bg-background/60 p-3 pl-10 text-foreground focus:border-primary focus:ring-1 focus:ring-primary transition-all duration-200"
        />
        {searchQuery && (
          <button
            // type="button" prevents accidental form submission (buttons
            // default to type="submit"); aria-label names the icon-only control
            // for assistive technology.
            type="button"
            aria-label="Clear search"
            className="absolute inset-y-0 right-0 flex items-center pr-3 text-muted-foreground hover:text-foreground"
            onClick={() => setSearchQuery("")}
          >
            <X size={18} className="opacity-70 hover:opacity-100" />
          </button>
        )}
      </div>
    </div>
  )
}
/**
 * Card summarizing one repository: org avatar (with first-letter fallback),
 * name plus public/private badge, status/Action/member-count pills, and the
 * last-optimized date when available. The whole card links to the repo's
 * detail page at /repositories/{id}.
 */
function RepositoryCard({ repo }: { repo: SerializedRepository }) {
  return (
    <Link href={`/repositories/${repo.id}`}>
      {/* NOTE(review): key on a non-list element is inert; harmless but unnecessary. */}
      <Card
        key={repo.id}
        className="bg-card bg-muted/5 rounded-xl border border-border hover:border-primary/30 hover:shadow-md hover:shadow-primary/5 transition-all duration-300 overflow-hidden group"
      >
        <div className="p-5">
          <div className="flex items-start">
            <div className="mr-3 flex-shrink-0">
              {/* Org avatar, or a gradient circle with the repo's first letter when no avatar URL */}
              {repo.avatarUrl ? (
                <div className="w-9 h-9 sm:w-11 sm:h-11 rounded-full overflow-hidden border-2 border-border/50 group-hover:border-primary/20 transition-colors">
                  <Image
                    src={repo.avatarUrl}
                    alt={`${repo.organization} avatar`}
                    width={44}
                    height={44}
                    className="object-cover w-full h-full"
                  />
                </div>
              ) : (
                <div className="w-9 h-9 sm:w-11 sm:h-11 rounded-full bg-gradient-to-br from-primary/10 to-primary/30 flex items-center justify-center border-2 border-border group-hover:from-primary/20 group-hover:to-primary/40 transition-colors">
                  <span className="text-primary font-semibold">
                    {repo.name?.substring(0, 1).toUpperCase() || "?"}
                  </span>
                </div>
              )}
            </div>
            <div className="flex-1 min-w-0">
              <div className="flex items-center flex-wrap gap-1">
                <h3 className="text-sm sm:text-base font-semibold text-primary hover:underline truncate">
                  {repo.name || "Unknown Repository"}
                </h3>
                <span
                  className={`ml-1 px-1.5 sm:px-2 py-0.5 text-xs font-medium rounded-full ${repo.is_private ? "bg-amber-100 text-amber-700" : "bg-emerald-100 text-emerald-700"}`}
                >
                  {repo.is_private ? "Private" : "Public"}
                </span>
              </div>
              <p className="text-xs sm:text-sm text-muted-foreground mb-2">
                {repo.full_name || repo.name}
              </p>
              <div className="flex items-center flex-wrap gap-1.5 sm:gap-2">
                {/* Active/Inactive status pill with colored dot */}
                <span
                  className={`inline-flex items-center px-1.5 sm:px-2 py-0.5 sm:py-1 rounded-full text-xs ${repo.is_active ? "bg-green-100 text-green-700" : "bg-gray-100 text-gray-600"}`}
                >
                  <span
                    className={`inline-block w-1.5 sm:w-2 h-1.5 sm:h-2 rounded-full ${repo.is_active ? "bg-green-500" : "bg-gray-400"} mr-1 sm:mr-1.5`}
                  ></span>
                  <span>{repo.is_active ? "Active" : "Inactive"}</span>
                </span>
                {/* GitHub Action badge (inline octocat mark) */}
                {repo.has_github_action && (
                  <span className="inline-flex items-center px-1.5 sm:px-2 py-0.5 sm:py-1 rounded-full bg-blue-50 text-xs text-blue-700">
                    <svg
                      xmlns="http://www.w3.org/2000/svg"
                      viewBox="0 0 16 16"
                      width="10"
                      height="10"
                      className="mr-1 fill-current sm:w-3 sm:h-3"
                    >
                      <path d="M8 0C3.58 0 0 3.58 0 8c0 3.54 2.29 6.53 5.47 7.59.4.07.55-.17.55-.38 0-.19-.01-.82-.01-1.49-2.01.37-2.53-.49-2.69-.94-.09-.23-.48-.94-.82-1.13-.28-.15-.68-.52-.01-.53.63-.01 1.08.58 1.23.82.72 1.21 1.87.87 2.33.66.07-.52.28-.87.51-1.07-1.78-.2-3.64-.89-3.64-3.95 0-.87.31-1.59.82-2.15-.08-.2-.36-1.02.08-2.12 0 0 .67-.21 2.2.82.64-.18 1.32-.27 2-.27.68 0 1.36.09 2 .27 1.53-1.04 2.2-.82 2.2-.82.44 1.1.16 1.92.08 2.12.51.56.82 1.27.82 2.15 0 3.07-1.87 3.75-3.65 3.95.29.25.54.73.54 1.48 0 1.07-.01 1.93-.01 2.2 0 .21.15.46.55.38A8.013 8.013 0 0016 8c0-4.42-3.58-8-8-8z"></path>
                    </svg>
                    Action
                  </span>
                )}
                {/* Member count pill (inline people icon), only when count > 0 */}
                {repo.membersCount !== undefined && repo.membersCount > 0 && (
                  <span className="inline-flex items-center px-1.5 sm:px-2 py-0.5 sm:py-1 rounded-full bg-indigo-50 text-xs text-indigo-700">
                    <svg
                      xmlns="http://www.w3.org/2000/svg"
                      width="10"
                      height="10"
                      viewBox="0 0 24 24"
                      fill="none"
                      stroke="currentColor"
                      strokeWidth="2"
                      strokeLinecap="round"
                      strokeLinejoin="round"
                      className="mr-1 sm:w-3 sm:h-3"
                    >
                      <path d="M17 21v-2a4 4 0 0 0-4-4H5a4 4 0 0 0-4 4v2"></path>
                      <circle cx="9" cy="7" r="4"></circle>
                      <path d="M23 21v-2a4 4 0 0 0-3-3.87"></path>
                      <path d="M16 3.13a4 4 0 0 1 0 7.75"></path>
                    </svg>
                    {repo.membersCount}
                  </span>
                )}
              </div>
            </div>
          </div>
          {/* last_optimized arrives serialized as an ISO string — parse before formatting */}
          {repo.last_optimized && (
            <div className="mt-3 sm:mt-4 text-xs text-muted-foreground flex items-center">
              <Clock size={10} className="mr-1 sm:w-3 sm:h-3" />
              Last optimized: {new Date(repo.last_optimized).toLocaleDateString()}
            </div>
          )}
        </div>
      </Card>
    </Link>
  )
}
// Client component: searchable, filterable, sortable grid of repository cards.
// `repositories` arrives already serialized for the client boundary (see the
// `SerializedRepository` type — dates are expected as strings, not Date objects).
export function RepositoryList({ repositories }: { repositories: SerializedRepository[] }) {
  const router = useRouter()
  // UI state: free-text search, visibility/activity filter, sort key, and
  // open/closed state for the two dropdown menus.
  const [searchQuery, setSearchQuery] = useState("")
  const [filter, setFilter] = useState<"all" | "active" | "public" | "private">("all")
  const [isFilterDropdownOpen, setIsFilterDropdownOpen] = useState(false)
  const [sortBy, setSortBy] = useState<"name">("name")
  const [isSortDropdownOpen, setIsSortDropdownOpen] = useState(false)
  const [isRefreshing, setIsRefreshing] = useState(false)
  // Close each dropdown when the user clicks anywhere outside of it.
  const filterDropdownRef = useOutsideClick(() => setIsFilterDropdownOpen(false))
  const sortDropdownRef = useOutsideClick(() => setIsSortDropdownOpen(false))
  // Human-readable label for the sort button.
  // NOTE(review): the default branch ("Last Optimized") is currently
  // unreachable — `sortBy` is typed as the single literal "name" above.
  // Either widen the union or drop the dead branch when touching this next.
  const getSortLabel = (sortType: string) => {
    switch (sortType) {
      case "name":
        return "Name"
      default:
        return "Last Optimized"
    }
  }
  // Re-fetch server data via the router; guarded so repeated clicks while a
  // refresh is in flight are ignored.
  // NOTE(review): the 2s timeout is not cleared on unmount, so setIsRefreshing
  // can fire after unmount — harmless in React 18 but worth a cleanup ref.
  const handleRefresh = () => {
    if (isRefreshing) return
    setIsRefreshing(true)
    router.refresh()
    // Reset after a short delay since router.refresh() doesn't provide a completion callback
    setTimeout(() => setIsRefreshing(false), 2000)
  }
  // Derived list: defensively tolerates a missing/non-array prop and null
  // entries, then applies search (name / full_name / organization), the
  // visibility/activity filter, and finally the sort.
  const filteredRepositories = useMemo(() => {
    if (!repositories || !Array.isArray(repositories)) {
      return []
    }
    let repos = repositories.filter(repo => {
      if (!repo) return false
      // Case-insensitive substring match over the three searchable fields;
      // an empty query matches everything.
      const matchesSearch =
        searchQuery === "" ||
        (repo.name && repo.name.toLowerCase().includes(searchQuery.toLowerCase())) ||
        (repo.full_name && repo.full_name.toLowerCase().includes(searchQuery.toLowerCase())) ||
        (repo.organization && repo.organization.toLowerCase().includes(searchQuery.toLowerCase()))
      if (!matchesSearch) return false
      switch (filter) {
        case "active":
          return repo.is_active
        case "public":
          return !repo.is_private
        case "private":
          return repo.is_private
        default:
          return true
      }
    })
    // Sorting mutates `repos` in place, which is safe here because `filter`
    // above always produces a fresh array.
    switch (sortBy) {
      case "name":
        repos = repos.sort((a, b) => {
          const nameA = a?.name || ""
          const nameB = b?.name || ""
          return nameA.localeCompare(nameB)
        })
        break
    }
    return repos
  }, [repositories, searchQuery, filter, sortBy])
  return (
    <>
      {/* Toolbar header: section title + refresh button */}
      <div className="flex justify-between items-center mb-4 sm:mb-6">
        <h2 className="text-base sm:text-lg font-semibold flex items-center">
          <BookOpen size={18} className="mr-2 text-primary" />
          Repository List
        </h2>
        <button
          onClick={handleRefresh}
          disabled={isRefreshing}
          className={`flex items-center gap-1 sm:gap-2 px-3 sm:px-4 py-1.5 sm:py-2 rounded-lg text-xs sm:text-sm border border-border bg-background hover:bg-muted/50 transition-colors ${
            isRefreshing ? "opacity-50 cursor-not-allowed" : ""
          }`}
        >
          <RefreshCw
            size={12}
            className={`text-muted-foreground sm:w-4 sm:h-4 ${isRefreshing ? "animate-spin" : ""}`}
          />
          {isRefreshing ? "Refreshing..." : "Refresh"}
        </button>
      </div>
      {/* Search box + filter/sort dropdown controls */}
      <div className="flex flex-col md:flex-row gap-3 sm:gap-4 mb-5 sm:mb-6">
        <SearchBar searchQuery={searchQuery} setSearchQuery={setSearchQuery} />
        <div className="flex gap-2">
          {/* Filter dropdown (all / active / public / private) */}
          <div className="relative" ref={filterDropdownRef}>
            <button
              onClick={() => setIsFilterDropdownOpen(!isFilterDropdownOpen)}
              className="flex items-center gap-1 sm:gap-2 px-3 sm:px-4 py-2.5 sm:py-3 text-xs sm:text-sm bg-background border border-border rounded-xl hover:border-primary/50 transition-colors focus:outline-none focus:ring-1 focus:ring-primary"
            >
              <Filter size={14} className="text-muted-foreground sm:w-4 sm:h-4" />
              <span>
                {filter === "all"
                  ? "All"
                  : filter.charAt(0).toUpperCase() + filter.slice(1).replace(/-/g, " ")}
              </span>
              <ChevronDown
                size={14}
                className={`transition-transform text-muted-foreground sm:w-4 sm:h-4 ${isFilterDropdownOpen ? "rotate-180" : ""}`}
              />
            </button>
            {isFilterDropdownOpen && (
              <div className="absolute z-10 mt-2 w-48 sm:w-52 bg-card rounded-xl shadow-lg overflow-hidden border border-border animate-in fade-in-50 slide-in-from-top-5">
                <div className="py-1">
                  <button
                    onClick={() => {
                      setFilter("all")
                      setIsFilterDropdownOpen(false)
                    }}
                    className={`w-full px-3 sm:px-4 py-2 sm:py-2.5 text-left hover:bg-muted flex items-center ${filter === "all" ? "bg-primary/10 text-primary font-medium" : ""}`}
                  >
                    <span className="w-4 sm:w-5 h-4 sm:h-5 mr-1.5 sm:mr-2 flex items-center justify-center">
                      {filter === "all" && (
                        <span className="w-1.5 h-1.5 rounded-full bg-primary"></span>
                      )}
                    </span>
                    All repositories
                  </button>
                </div>
                <div className="border-t border-border py-1">
                  <button
                    onClick={() => {
                      setFilter("active")
                      setIsFilterDropdownOpen(false)
                    }}
                    className={`w-full px-3 sm:px-4 py-2 sm:py-2.5 text-left hover:bg-muted flex items-center ${filter === "active" ? "bg-primary/10 text-primary font-medium" : ""}`}
                  >
                    <span className="w-4 sm:w-5 h-4 sm:h-5 mr-1.5 sm:mr-2 flex items-center justify-center">
                      {filter === "active" && (
                        <span className="w-1.5 h-1.5 rounded-full bg-primary"></span>
                      )}
                    </span>
                    Active
                  </button>
                </div>
                {/* Public and private share one bordered group */}
                <div className="border-t border-border py-1">
                  <button
                    onClick={() => {
                      setFilter("public")
                      setIsFilterDropdownOpen(false)
                    }}
                    className={`w-full px-3 sm:px-4 py-2 sm:py-2.5 text-left hover:bg-muted flex items-center ${filter === "public" ? "bg-primary/10 text-primary font-medium" : ""}`}
                  >
                    <span className="w-4 sm:w-5 h-4 sm:h-5 mr-1.5 sm:mr-2 flex items-center justify-center">
                      {filter === "public" && (
                        <span className="w-1.5 h-1.5 rounded-full bg-primary"></span>
                      )}
                    </span>
                    Public
                  </button>
                  <button
                    onClick={() => {
                      setFilter("private")
                      setIsFilterDropdownOpen(false)
                    }}
                    className={`w-full px-3 sm:px-4 py-2 sm:py-2.5 text-left hover:bg-muted flex items-center ${filter === "private" ? "bg-primary/10 text-primary font-medium" : ""}`}
                  >
                    <span className="w-4 sm:w-5 h-4 sm:h-5 mr-1.5 sm:mr-2 flex items-center justify-center">
                      {filter === "private" && (
                        <span className="w-1.5 h-1.5 rounded-full bg-primary"></span>
                      )}
                    </span>
                    Private
                  </button>
                </div>
              </div>
            )}
          </div>
          {/* Sort dropdown (currently only "name" is offered) */}
          <div className="relative" ref={sortDropdownRef}>
            <button
              onClick={() => setIsSortDropdownOpen(!isSortDropdownOpen)}
              className="flex items-center gap-1 sm:gap-2 px-3 sm:px-4 py-2.5 sm:py-3 text-xs sm:text-sm bg-background border border-border rounded-xl hover:border-primary/50 transition-colors focus:outline-none focus:ring-1 focus:ring-primary"
            >
              <ArrowUpDown size={14} className="text-muted-foreground sm:w-4 sm:h-4" />
              <span>Sort: {getSortLabel(sortBy)}</span>
              <ChevronDown
                size={14}
                className={`transition-transform text-muted-foreground sm:w-4 sm:h-4 ${isSortDropdownOpen ? "rotate-180" : ""}`}
              />
            </button>
            {isSortDropdownOpen && (
              <div className="absolute right-0 z-10 mt-2 w-48 sm:w-52 bg-card rounded-xl shadow-lg overflow-hidden border border-border animate-in fade-in-50 slide-in-from-top-5">
                <div className="py-1">
                  <button
                    onClick={() => {
                      setSortBy("name")
                      setIsSortDropdownOpen(false)
                    }}
                    className={`w-full px-3 sm:px-4 py-2 sm:py-2.5 text-left hover:bg-muted flex items-center ${sortBy === "name" ? "bg-primary/10 text-primary font-medium" : ""}`}
                  >
                    <span className="w-4 sm:w-5 h-4 sm:h-5 mr-1.5 sm:mr-2 flex items-center justify-center">
                      {sortBy === "name" && (
                        <span className="w-1.5 h-1.5 rounded-full bg-primary"></span>
                      )}
                    </span>
                    Name
                  </button>
                </div>
              </div>
            )}
          </div>
        </div>
      </div>
      {/* Active-search chip with a one-click clear */}
      {searchQuery && (
        <div className="flex items-center mb-4 sm:mb-5 ml-1">
          <span className="text-xs sm:text-sm text-muted-foreground mr-1 sm:mr-2">
            Searching for:
          </span>
          <div className="bg-primary/10 text-primary px-2 sm:px-3 py-1 sm:py-1.5 rounded-full text-xs sm:text-sm flex items-center gap-1 sm:gap-1.5">
            <span>{searchQuery}</span>
            <button
              onClick={() => setSearchQuery("")}
              className="text-primary hover:text-primary/80"
            >
              <X size={12} className="sm:w-4 sm:h-4" />
            </button>
          </div>
        </div>
      )}
      {/* Empty state (with filter reset) vs. the responsive card grid */}
      {filteredRepositories.length === 0 ? (
        <div className="text-center py-16 sm:py-20 bg-card/50 rounded-xl border border-dashed border-border">
          <div className="inline-flex items-center justify-center w-16 h-16 sm:w-20 sm:h-20 rounded-full bg-muted/40 mb-3 sm:mb-4">
            <Search size={24} className="text-muted-foreground sm:w-7 sm:h-7" />
          </div>
          <h3 className="text-base sm:text-lg font-medium mb-1 sm:mb-2">No repositories found</h3>
          <p className="text-xs sm:text-sm text-muted-foreground max-w-md mx-auto">
            {
              "We couldn't find any repositories matching your search criteria. Try adjusting your filters or search term."
            }
          </p>
          <button
            onClick={() => {
              setSearchQuery("")
              setFilter("all")
            }}
            className="mt-4 sm:mt-5 px-4 sm:px-5 py-2 sm:py-2.5 bg-primary text-primary-foreground rounded-lg sm:rounded-xl text-xs sm:text-sm hover:bg-primary/90 transition-colors"
          >
            Clear filters
          </button>
        </div>
      ) : (
        <div className="grid grid-cols-1 md:grid-cols-2 lg:grid-cols-3 gap-3 sm:gap-5">
          {filteredRepositories.map(repo => (
            <RepositoryCard key={repo.id} repo={repo} />
          ))}
        </div>
      )}
    </>
  )
}

View file

@ -0,0 +1,22 @@
"use client"
import { RefreshCw } from "lucide-react"
// Route-level error UI for the repositories segment (Next.js `error.tsx`
// convention). Receives the thrown `error` (currently unused in the UI) and a
// `reset` callback that re-renders the segment when the user retries.
export default function RepositoriesError({ reset }: { error: Error; reset: () => void }) {
  // Styling hoisted out of the markup so the JSX reads as structure only.
  const panelClass =
    "bg-red-50 text-red-800 p-6 sm:p-8 rounded-xl max-w-md border border-red-200"
  const retryButtonClass =
    "flex items-center gap-1 sm:gap-2 w-full justify-center px-3 sm:px-4 py-2 sm:py-2.5 bg-red-100 hover:bg-red-200 text-red-800 rounded-lg text-xs sm:text-sm font-medium transition-colors"

  return (
    <div className="flex justify-center items-center h-[70vh]">
      <div className={panelClass}>
        <h3 className="text-base sm:text-lg font-medium mb-2 sm:mb-3">Something went wrong</h3>
        <p className="mb-3 sm:mb-4 text-sm sm:text-base">
          There was an error loading the repositories page.
        </p>
        {/* Invokes the Next.js-provided reset to retry rendering the segment */}
        <button onClick={reset} className={retryButtonClass}>
          <RefreshCw size={14} className="sm:w-4 sm:h-4" /> Try Again
        </button>
      </div>
    </div>
  )
}

View file

@ -0,0 +1,5 @@
import { RepositoriesSkeleton } from "@/components/repositories/RepositoriesSkeleton"
// Route-level loading UI (Next.js `loading.tsx` convention): shown while the
// repositories route segment is rendering/fetching; delegates entirely to the
// shared skeleton component.
export default function RepositoriesLoading() {
  return <RepositoriesSkeleton />
}

View file

@ -1,692 +1,33 @@
"use client"
import { GitPullRequest } from "lucide-react"
import { getAccountContext } from "@/lib/server/get-account-context"
import { getAllRepositories } from "@/app/dashboard/action"
import { RepositoryList } from "./_components/RepositoryList"
import React, { useState, useMemo, useEffect, useCallback, useRef } from "react"
import {
Clock,
GitPullRequest,
Search,
ChevronDown,
X,
RefreshCw,
Filter,
ArrowUpDown,
BookOpen,
} from "lucide-react"
import Image from "next/image"
import { Card } from "@/components/ui/card"
import { getUserIdAndUsername } from "@/app/utils/auth"
import Link from "next/link"
import { getAllRepositories, RepositoryWithUsage } from "@/app/dashboard/action"
import { useViewMode } from "@/app/app/ViewModeContext"
// Error Boundary Component
class RepositoryErrorBoundary extends React.Component<
{ children: React.ReactNode },
{ hasError: boolean; error?: Error }
> {
constructor(props: { children: React.ReactNode }) {
super(props)
this.state = { hasError: false }
}
static getDerivedStateFromError(error: Error) {
return { hasError: true, error }
}
componentDidCatch(error: Error, errorInfo: React.ErrorInfo) {
console.error("Repository page error:", error, errorInfo)
}
render() {
if (this.state.hasError) {
return (
<div className="flex justify-center items-center h-[70vh]">
<div className="bg-red-50 text-red-800 p-6 sm:p-8 rounded-xl max-w-md border border-red-200">
<h3 className="text-base sm:text-lg font-medium mb-2 sm:mb-3">Something went wrong</h3>
<p className="mb-3 sm:mb-4 text-sm sm:text-base">
There was an error loading the repositories page.
</p>
<button
onClick={() => window.location.reload()}
className="flex items-center gap-1 sm:gap-2 w-full justify-center px-3 sm:px-4 py-2 sm:py-2.5 bg-red-100 hover:bg-red-200 text-red-800 rounded-lg text-xs sm:text-sm font-medium transition-colors"
>
<RefreshCw size={14} className="sm:w-4 sm:h-4" /> Reload Page
</button>
</div>
</div>
)
}
return this.props.children
}
}
// Custom hook for debouncing
const useDebounce = (callback: () => void, delay: number) => {
const timeoutRef = useRef<NodeJS.Timeout>()
return useCallback(() => {
if (timeoutRef.current) {
clearTimeout(timeoutRef.current)
}
timeoutRef.current = setTimeout(callback, delay)
}, [callback, delay])
}
// Custom hook for detecting clicks outside of an element
const useOutsideClick = (callback: () => void) => {
const ref = React.useRef<HTMLDivElement>(null)
useEffect(() => {
const handleClickOutside = (event: MouseEvent) => {
if (ref.current && !ref.current.contains(event.target as Node)) {
callback()
}
}
document.addEventListener("mousedown", handleClickOutside)
return () => {
document.removeEventListener("mousedown", handleClickOutside)
}
}, [callback])
return ref
}
// Enhanced search component
const SearchBar = ({
searchQuery,
setSearchQuery,
}: {
searchQuery: string
setSearchQuery: (value: string) => void
}) => {
return (
<div className="relative flex-1 group">
<div className="relative">
<div className="absolute inset-y-0 left-0 flex items-center pl-3 pointer-events-none">
<Search
size={18}
className="text-muted-foreground/70 group-focus-within:text-primary transition-colors"
/>
</div>
<input
type="text"
placeholder="Search repositories..."
value={searchQuery}
onChange={e => setSearchQuery(e.target.value)}
className="block w-full rounded-xl border border-border bg-background/60 p-3 pl-10 text-foreground focus:border-primary focus:ring-1 focus:ring-primary transition-all duration-200"
/>
{searchQuery && (
<button
className="absolute inset-y-0 right-0 flex items-center pr-3 text-muted-foreground hover:text-foreground"
onClick={() => setSearchQuery("")}
>
<X size={18} className="opacity-70 hover:opacity-100" />
</button>
)}
</div>
</div>
)
}
// GitHub-style Repository Card Component
const RepositoryCard = ({ repo }: { repo: RepositoryWithUsage }) => (
<Link href={`/repositories/${repo.id}`}>
<Card
key={repo.id}
className="bg-card bg-muted/5 rounded-xl border border-border hover:border-primary/30 hover:shadow-md hover:shadow-primary/5 transition-all duration-300 overflow-hidden group"
>
<div className="p-5">
<div className="flex items-start">
{/* Circular avatar for organization */}
<div className="mr-3 flex-shrink-0">
{repo.avatarUrl ? (
<div className="w-9 h-9 sm:w-11 sm:h-11 rounded-full overflow-hidden border-2 border-border/50 group-hover:border-primary/20 transition-colors">
<Image
src={repo.avatarUrl}
alt={`${repo.organization} avatar`}
width={44}
height={44}
className="object-cover w-full h-full"
/>
</div>
) : (
<div className="w-9 h-9 sm:w-11 sm:h-11 rounded-full bg-gradient-to-br from-primary/10 to-primary/30 flex items-center justify-center border-2 border-border group-hover:from-primary/20 group-hover:to-primary/40 transition-colors">
<span className="text-primary font-semibold">
{repo.name?.substring(0, 1).toUpperCase() || "?"}
</span>
</div>
)}
</div>
<div className="flex-1 min-w-0">
{/* Repository name with visibility badge */}
<div className="flex items-center flex-wrap gap-1">
<h3 className="text-sm sm:text-base font-semibold text-primary hover:underline truncate">
{repo.name || "Unknown Repository"}
</h3>
<span
className={`ml-1 px-1.5 sm:px-2 py-0.5 text-xs font-medium rounded-full ${repo.is_private ? "bg-amber-100 text-amber-700" : "bg-emerald-100 text-emerald-700"}`}
>
{repo.is_private ? "Private" : "Public"}
</span>
</div>
{/* Organization/full name */}
<p className="text-xs sm:text-sm text-muted-foreground mb-2">
{repo.full_name || repo.name}
</p>
{/* Repository stats - matching schema data */}
<div className="flex items-center flex-wrap gap-1.5 sm:gap-2">
{/* Active status */}
<span
className={`inline-flex items-center px-1.5 sm:px-2 py-0.5 sm:py-1 rounded-full text-xs ${repo.is_active ? "bg-green-100 text-green-700" : "bg-gray-100 text-gray-600"}`}
>
<span
className={`inline-block w-1.5 sm:w-2 h-1.5 sm:h-2 rounded-full ${repo.is_active ? "bg-green-500" : "bg-gray-400"} mr-1 sm:mr-1.5`}
></span>
<span>{repo.is_active ? "Active" : "Inactive"}</span>
</span>
{/* GitHub Action */}
{repo.has_github_action && (
<span className="inline-flex items-center px-1.5 sm:px-2 py-0.5 sm:py-1 rounded-full bg-blue-50 text-xs text-blue-700">
<svg
xmlns="http://www.w3.org/2000/svg"
viewBox="0 0 16 16"
width="10"
height="10"
className="mr-1 fill-current sm:w-3 sm:h-3"
>
<path d="M8 0C3.58 0 0 3.58 0 8c0 3.54 2.29 6.53 5.47 7.59.4.07.55-.17.55-.38 0-.19-.01-.82-.01-1.49-2.01.37-2.53-.49-2.69-.94-.09-.23-.48-.94-.82-1.13-.28-.15-.68-.52-.01-.53.63-.01 1.08.58 1.23.82.72 1.21 1.87.87 2.33.66.07-.52.28-.87.51-1.07-1.78-.2-3.64-.89-3.64-3.95 0-.87.31-1.59.82-2.15-.08-.2-.36-1.02.08-2.12 0 0 .67-.21 2.2.82.64-.18 1.32-.27 2-.27.68 0 1.36.09 2 .27 1.53-1.04 2.2-.82 2.2-.82.44 1.1.16 1.92.08 2.12.51.56.82 1.27.82 2.15 0 3.07-1.87 3.75-3.65 3.95.29.25.54.73.54 1.48 0 1.07-.01 1.93-.01 2.2 0 .21.15.46.55.38A8.013 8.013 0 0016 8c0-4.42-3.58-8-8-8z"></path>
</svg>
Action
</span>
)}
{/* Members count if available */}
{repo.membersCount !== undefined && repo.membersCount > 0 && (
<span className="inline-flex items-center px-1.5 sm:px-2 py-0.5 sm:py-1 rounded-full bg-indigo-50 text-xs text-indigo-700">
<svg
xmlns="http://www.w3.org/2000/svg"
width="10"
height="10"
viewBox="0 0 24 24"
fill="none"
stroke="currentColor"
strokeWidth="2"
strokeLinecap="round"
strokeLinejoin="round"
className="mr-1 sm:w-3 sm:h-3"
>
<path d="M17 21v-2a4 4 0 0 0-4-4H5a4 4 0 0 0-4 4v2"></path>
<circle cx="9" cy="7" r="4"></circle>
<path d="M23 21v-2a4 4 0 0 0-3-3.87"></path>
<path d="M16 3.13a4 4 0 0 1 0 7.75"></path>
</svg>
{repo.membersCount}
</span>
)}
</div>
</div>
</div>
{/* Last optimized date */}
{repo.last_optimized && (
<div className="mt-3 sm:mt-4 text-xs text-muted-foreground flex items-center">
<Clock size={10} className="mr-1 sm:w-3 sm:h-3" />
Last optimized: {new Date(repo.last_optimized).toLocaleDateString()}
</div>
)}
</div>
</Card>
</Link>
)
// Import skeleton loaders
import {
RepositoriesSkeleton,
RepositoriesRefreshingSkeleton,
} from "@/components/repositories/RepositoriesSkeleton"
// Loading State Component (now using skeleton loaders)
const RepositoriesLoading = ({ isRefreshing = false }: { isRefreshing?: boolean }) =>
isRefreshing ? (
<RepositoriesRefreshingSkeleton />
) : (
<RepositoriesSkeleton message="Loading repositories..." />
)
// Page Header Component
const PageHeader = ({ totalCount }: { totalCount: number }) => (
<div className="mb-6 sm:mb-8">
<div className="flex items-center gap-3 mb-2">
<h1 className="text-xl sm:text-2xl font-bold">Repositories</h1>
<div className="px-2 py-0.5 sm:px-2.5 sm:py-1 bg-primary/10 text-primary rounded-full text-xs sm:text-sm font-medium">
{totalCount} total
</div>
</div>
</div>
)
// Main component for repository list with filters
const RepositoryList = ({ repositories }: { repositories: RepositoryWithUsage[] }) => {
const [searchQuery, setSearchQuery] = useState("")
const [filter, setFilter] = useState<"all" | "active" | "public" | "private">("all")
const [isFilterDropdownOpen, setIsFilterDropdownOpen] = useState(false)
const [sortBy, setSortBy] = useState<"name">("name")
const [isSortDropdownOpen, setIsSortDropdownOpen] = useState(false)
const filterDropdownRef = useOutsideClick(() => setIsFilterDropdownOpen(false))
const sortDropdownRef = useOutsideClick(() => setIsSortDropdownOpen(false))
const getSortLabel = (sortType: string) => {
switch (sortType) {
case "name":
return "Name"
default:
return "Last Optimized"
}
}
const filteredRepositories = useMemo(() => {
// Add safety check for repositories array
if (!repositories || !Array.isArray(repositories)) {
return []
}
let repos = repositories.filter(repo => {
// Add safety checks for repo properties
if (!repo) return false
// Search in name and full_name with safety checks
const matchesSearch =
searchQuery === "" ||
(repo.name && repo.name.toLowerCase().includes(searchQuery.toLowerCase())) ||
(repo.full_name && repo.full_name.toLowerCase().includes(searchQuery.toLowerCase())) ||
(repo.organization && repo.organization.toLowerCase().includes(searchQuery.toLowerCase()))
if (!matchesSearch) return false
switch (filter) {
case "active":
return repo.is_active
case "public":
return !repo.is_private
case "private":
return repo.is_private
default:
return true
}
})
// Sort repositories with safety check
switch (sortBy) {
case "name":
repos = repos.sort((a, b) => {
const nameA = a?.name || ""
const nameB = b?.name || ""
return nameA.localeCompare(nameB)
})
break
}
return repos
}, [repositories, searchQuery, filter, sortBy])
return (
<>
<div className="flex flex-col md:flex-row gap-3 sm:gap-4 mb-5 sm:mb-6">
<SearchBar searchQuery={searchQuery} setSearchQuery={setSearchQuery} />
<div className="flex gap-2">
<div className="relative" ref={filterDropdownRef}>
<button
onClick={() => setIsFilterDropdownOpen(!isFilterDropdownOpen)}
className="flex items-center gap-1 sm:gap-2 px-3 sm:px-4 py-2.5 sm:py-3 text-xs sm:text-sm bg-background border border-border rounded-xl hover:border-primary/50 transition-colors focus:outline-none focus:ring-1 focus:ring-primary"
>
<Filter size={14} className="text-muted-foreground sm:w-4 sm:h-4" />
<span>
{filter === "all"
? "All"
: filter.charAt(0).toUpperCase() + filter.slice(1).replace(/-/g, " ")}
</span>
<ChevronDown
size={14}
className={`transition-transform text-muted-foreground sm:w-4 sm:h-4 ${isFilterDropdownOpen ? "rotate-180" : ""}`}
/>
</button>
{isFilterDropdownOpen && (
<div className="absolute z-10 mt-2 w-48 sm:w-52 bg-card rounded-xl shadow-lg overflow-hidden border border-border animate-in fade-in-50 slide-in-from-top-5">
<div className="py-1">
<button
onClick={() => {
setFilter("all")
setIsFilterDropdownOpen(false)
}}
className={`w-full px-3 sm:px-4 py-2 sm:py-2.5 text-left hover:bg-muted flex items-center ${filter === "all" ? "bg-primary/10 text-primary font-medium" : ""}`}
>
<span className="w-4 sm:w-5 h-4 sm:h-5 mr-1.5 sm:mr-2 flex items-center justify-center">
{filter === "all" && (
<span className="w-1.5 h-1.5 rounded-full bg-primary"></span>
)}
</span>
All repositories
</button>
</div>
<div className="border-t border-border py-1">
<button
onClick={() => {
setFilter("active")
setIsFilterDropdownOpen(false)
}}
className={`w-full px-3 sm:px-4 py-2 sm:py-2.5 text-left hover:bg-muted flex items-center ${filter === "active" ? "bg-primary/10 text-primary font-medium" : ""}`}
>
<span className="w-4 sm:w-5 h-4 sm:h-5 mr-1.5 sm:mr-2 flex items-center justify-center">
{filter === "active" && (
<span className="w-1.5 h-1.5 rounded-full bg-primary"></span>
)}
</span>
Active
</button>
</div>
<div className="border-t border-border py-1">
<button
onClick={() => {
setFilter("public")
setIsFilterDropdownOpen(false)
}}
className={`w-full px-3 sm:px-4 py-2 sm:py-2.5 text-left hover:bg-muted flex items-center ${filter === "public" ? "bg-primary/10 text-primary font-medium" : ""}`}
>
<span className="w-4 sm:w-5 h-4 sm:h-5 mr-1.5 sm:mr-2 flex items-center justify-center">
{filter === "public" && (
<span className="w-1.5 h-1.5 rounded-full bg-primary"></span>
)}
</span>
Public
</button>
<button
onClick={() => {
setFilter("private")
setIsFilterDropdownOpen(false)
}}
className={`w-full px-3 sm:px-4 py-2 sm:py-2.5 text-left hover:bg-muted flex items-center ${filter === "private" ? "bg-primary/10 text-primary font-medium" : ""}`}
>
<span className="w-4 sm:w-5 h-4 sm:h-5 mr-1.5 sm:mr-2 flex items-center justify-center">
{filter === "private" && (
<span className="w-1.5 h-1.5 rounded-full bg-primary"></span>
)}
</span>
Private
</button>
</div>
</div>
)}
</div>
<div className="relative" ref={sortDropdownRef}>
<button
onClick={() => setIsSortDropdownOpen(!isSortDropdownOpen)}
className="flex items-center gap-1 sm:gap-2 px-3 sm:px-4 py-2.5 sm:py-3 text-xs sm:text-sm bg-background border border-border rounded-xl hover:border-primary/50 transition-colors focus:outline-none focus:ring-1 focus:ring-primary"
>
<ArrowUpDown size={14} className="text-muted-foreground sm:w-4 sm:h-4" />
<span>Sort: {getSortLabel(sortBy)}</span>
<ChevronDown
size={14}
className={`transition-transform text-muted-foreground sm:w-4 sm:h-4 ${isSortDropdownOpen ? "rotate-180" : ""}`}
/>
</button>
{isSortDropdownOpen && (
<div className="absolute right-0 z-10 mt-2 w-48 sm:w-52 bg-card rounded-xl shadow-lg overflow-hidden border border-border animate-in fade-in-50 slide-in-from-top-5">
<div className="py-1">
<button
onClick={() => {
setSortBy("name")
setIsSortDropdownOpen(false)
}}
className={`w-full px-3 sm:px-4 py-2 sm:py-2.5 text-left hover:bg-muted flex items-center ${sortBy === "name" ? "bg-primary/10 text-primary font-medium" : ""}`}
>
<span className="w-4 sm:w-5 h-4 sm:h-5 mr-1.5 sm:mr-2 flex items-center justify-center">
{sortBy === "name" && (
<span className="w-1.5 h-1.5 rounded-full bg-primary"></span>
)}
</span>
Name
</button>
</div>
</div>
)}
</div>
</div>
</div>
{/* Search indicator */}
{searchQuery && (
<div className="flex items-center mb-4 sm:mb-5 ml-1">
<span className="text-xs sm:text-sm text-muted-foreground mr-1 sm:mr-2">
Searching for:
</span>
<div className="bg-primary/10 text-primary px-2 sm:px-3 py-1 sm:py-1.5 rounded-full text-xs sm:text-sm flex items-center gap-1 sm:gap-1.5">
<span>{searchQuery}</span>
<button
onClick={() => setSearchQuery("")}
className="text-primary hover:text-primary/80"
>
<X size={12} className="sm:w-4 sm:h-4" />
</button>
</div>
</div>
)}
{filteredRepositories.length === 0 ? (
<div className="text-center py-16 sm:py-20 bg-card/50 rounded-xl border border-dashed border-border">
<div className="inline-flex items-center justify-center w-16 h-16 sm:w-20 sm:h-20 rounded-full bg-muted/40 mb-3 sm:mb-4">
<Search size={24} className="text-muted-foreground sm:w-7 sm:h-7" />
</div>
<h3 className="text-base sm:text-lg font-medium mb-1 sm:mb-2">No repositories found</h3>
<p className="text-xs sm:text-sm text-muted-foreground max-w-md mx-auto">
{
"We couldn't find any repositories matching your search criteria. Try adjusting your filters or search term."
}
</p>
<button
onClick={() => {
setSearchQuery("")
setFilter("all")
}}
className="mt-4 sm:mt-5 px-4 sm:px-5 py-2 sm:py-2.5 bg-primary text-primary-foreground rounded-lg sm:rounded-xl text-xs sm:text-sm hover:bg-primary/90 transition-colors"
>
Clear filters
</button>
</div>
) : (
<div className="grid grid-cols-1 md:grid-cols-2 lg:grid-cols-3 gap-3 sm:gap-5">
{filteredRepositories.map(repo => (
<RepositoryCard key={repo.id} repo={repo} />
))}
</div>
)}
</>
)
}
// Main page component
const maxRetries = 3
function RepositoriesPage() {
const [repositories, setRepositories] = useState<RepositoryWithUsage[]>([])
const [loading, setLoading] = useState(true)
const [error, setError] = useState<string | null>(null)
const [isRefreshing, setIsRefreshing] = useState(false)
const [retryCount, setRetryCount] = useState(0)
const { currentOrg } = useViewMode()
const fetchRepositories = useCallback(
async (attempt = 0) => {
try {
setLoading(attempt === 0)
setError(null)
// Add a small delay for rapid refreshes and retries
if (attempt > 0) {
await new Promise(resolve => setTimeout(resolve, Math.pow(2, attempt) * 1000))
}
const data = await getUserIdAndUsername()
if (!data || !data.userId || !data.username) {
throw new Error("User authentication data not found")
}
const repos = await getAllRepositories(
currentOrg ? { orgId: currentOrg.id } : { userId: data.userId, username: data.username },
)
if (Array.isArray(repos)) {
setRepositories(repos)
setRetryCount(0) // Reset retry count on success
} else {
console.warn("Received non-array repositories data:", repos)
setRepositories([])
}
} catch (err) {
console.error(`Failed to fetch repositories (attempt ${attempt + 1}):`, err)
// If it's an auth error and we haven't exceeded retries, try again
if (
attempt < maxRetries &&
err instanceof Error &&
(err.message.includes("authentication") ||
err.message.includes("User authentication data not found") ||
err.message.includes("Unauthorized") ||
err.message.includes("No valid session found"))
) {
setRetryCount(attempt + 1)
return fetchRepositories(attempt + 1)
}
setError("Failed to load repositories. Please try again later.")
setRepositories([])
} finally {
setLoading(false)
setIsRefreshing(false)
}
},
[currentOrg],
)
// Debounced refresh to prevent rapid successive calls
const debouncedRefresh = useDebounce(() => {
setIsRefreshing(true)
fetchRepositories()
}, 300)
const handleRefresh = () => {
if (!isRefreshing && !loading) {
debouncedRefresh()
}
}
// Handle browser refresh with beforeunload
useEffect(() => {
const handleBeforeUnload = () => {
// Clear any pending timeouts
return null
}
window.addEventListener("beforeunload", handleBeforeUnload)
return () => window.removeEventListener("beforeunload", handleBeforeUnload)
}, [])
useEffect(() => {
// Check if user was recently authenticated
const lastAuthCheck = localStorage.getItem("lastAuthCheck")
const now = Date.now()
// If last auth check was less than 2 seconds ago, wait a bit
if (lastAuthCheck && now - parseInt(lastAuthCheck) < 2000) {
const delay = 2000 - (now - parseInt(lastAuthCheck))
setTimeout(() => {
fetchRepositories()
}, delay)
} else {
// Add a small delay to prevent race conditions on rapid refreshes
const timeoutId = setTimeout(() => {
fetchRepositories()
}, 100)
const cleanup = () => clearTimeout(timeoutId)
return cleanup
}
// Update last auth check time
localStorage.setItem("lastAuthCheck", now.toString())
}, [fetchRepositories])
// Refresh Button Component
const RefreshButton = () => (
<button
onClick={handleRefresh}
disabled={isRefreshing || loading}
className={`flex items-center gap-1 sm:gap-2 px-3 sm:px-4 py-1.5 sm:py-2 rounded-lg text-xs sm:text-sm border border-border bg-background hover:bg-muted/50 transition-colors ${
isRefreshing || loading ? "opacity-50 cursor-not-allowed" : ""
}`}
>
<RefreshCw
size={12}
className={`text-muted-foreground sm:w-4 sm:h-4 ${isRefreshing || loading ? "animate-spin" : ""}`}
/>
{isRefreshing ? "Refreshing..." : "Refresh"}
</button>
)
if (loading) {
return <RepositoriesLoading isRefreshing={isRefreshing} />
}
if (error) {
return (
<div className="flex justify-center items-center h-[70vh]">
<div className="bg-red-50 text-red-800 p-6 sm:p-8 rounded-xl max-w-md border border-red-200">
<h3 className="text-base sm:text-lg font-medium mb-2 sm:mb-3">
Unable to Load Repositories
</h3>
<p className="mb-3 sm:mb-4 text-sm sm:text-base">{error}</p>
{retryCount > 0 && (
<p className="mb-3 text-xs text-red-600">
Retry attempt: {retryCount}/{maxRetries}
</p>
)}
<button
onClick={() => fetchRepositories()}
className="flex items-center gap-1 sm:gap-2 w-full justify-center px-3 sm:px-4 py-2 sm:py-2.5 bg-red-100 hover:bg-red-200 text-red-800 rounded-lg text-xs sm:text-sm font-medium transition-colors"
>
<RefreshCw size={14} className="sm:w-4 sm:h-4" /> Try Again
</button>
</div>
</div>
)
}
export default async function RepositoriesPage() {
const accountPayload = await getAccountContext()
const repos = await getAllRepositories(accountPayload)
// Serialize Date objects for client component boundary
const repositories = (Array.isArray(repos) ? repos : []).map(repo => ({
...repo,
created_at: repo.created_at instanceof Date ? repo.created_at.toISOString() : repo.created_at,
last_optimized:
repo.last_optimized instanceof Date ? repo.last_optimized.toISOString() : repo.last_optimized,
}))
return (
<div className="flex-1 bg-background">
<div className="h-screen py-6 sm:py-8 px-4 sm:px-6 max-w-[1400px] mx-auto">
<PageHeader totalCount={repositories?.length || 0} />
<div className="min-h-screen py-6 sm:py-8 px-4 sm:px-6 max-w-[1400px] mx-auto">
<div className="mb-6 sm:mb-8">
<div className="flex items-center gap-3 mb-2">
<h1 className="text-xl sm:text-2xl font-bold">Repositories</h1>
<div className="px-2 py-0.5 sm:px-2.5 sm:py-1 bg-primary/10 text-primary rounded-full text-xs sm:text-sm font-medium">
{repositories.length} total
</div>
</div>
</div>
<div className="bg-card p-4 sm:p-6 rounded-xl border border-border">
<div className="flex justify-between items-center mb-4 sm:mb-6">
<h2 className="text-base sm:text-lg font-semibold flex items-center">
<BookOpen size={18} className="mr-2 text-primary" />
Repository List
</h2>
<RefreshButton />
</div>
{!repositories || repositories.length === 0 ? (
{repositories.length === 0 ? (
<div className="flex justify-center items-center min-h-[300px] sm:min-h-[400px] w-full">
<div className="text-center py-12 sm:py-16 bg-muted/10 rounded-xl border border-dashed border-border max-w-lg w-full px-5 sm:px-8">
<div className="inline-flex items-center justify-center w-12 h-12 sm:w-16 sm:h-16 rounded-full bg-muted/20 mb-3 sm:mb-4">
@ -727,12 +68,3 @@ function RepositoriesPage() {
</div>
)
}
// Main export with error boundary
export default function RepositoriesPageWrapper() {
return (
<RepositoryErrorBoundary>
<RepositoriesPage />
</RepositoryErrorBoundary>
)
}

View file

@ -3,9 +3,10 @@
import { CF_API } from "@/app/api/const"
import { ActionResponse, createErrorResponse, createSuccessResponse } from "@/lib/action-response"
import { getRepositoriesForAccountCached } from "@/lib/services/repository-utils"
import { getAccessToken } from "@auth0/nextjs-auth0"
import { auth0 } from "@/lib/auth0"
import { AccountPayload, buildOptimizationOrCondition, prisma } from "@codeflash-ai/common"
import * as Sentry from "@sentry/nextjs"
import { trackOptimizationReviewed } from "@/lib/analytics/tracking"
export interface DiffContent {
oldContent: string
@ -32,9 +33,9 @@ export interface GetStagingCodeParams {
export async function getStagingCodeFromApi(params: GetStagingCodeParams): Promise<ActionResponse<StagingCodeResponse>> {
const cfapiUrl = process.env.CODEFLASH_CFAPI_URL
const session = await getAccessToken({ refresh: true })
const session = await auth0.getAccessToken()
if (!cfapiUrl || !session?.accessToken) {
if (!cfapiUrl || !session?.token) {
return createErrorResponse("Please sign in to continue")
}
@ -43,7 +44,7 @@ export async function getStagingCodeFromApi(params: GetStagingCodeParams): Promi
method: "POST",
headers: {
"Content-Type": "application/json",
Authorization: `Bearer ${session.accessToken}`,
Authorization: `Bearer ${session.token}`,
"X-CodeFlash-Source": "webapp",
},
body: JSON.stringify(params),
@ -92,9 +93,9 @@ export async function commitStagingCode(
commitMessage?: string,
): Promise<ActionResponse<CommitStagingCodeResponse>> {
const cfapiUrl = process.env.CODEFLASH_CFAPI_URL
const session = await getAccessToken({ refresh: true })
const session = await auth0.getAccessToken()
if (!cfapiUrl || !session?.accessToken) {
if (!cfapiUrl || !session?.token) {
return createErrorResponse("Please sign in to continue")
}
@ -103,7 +104,7 @@ export async function commitStagingCode(
method: "POST",
headers: {
"Content-Type": "application/json",
Authorization: `Bearer ${session.accessToken}`,
Authorization: `Bearer ${session.token}`,
"X-CodeFlash-Source": "webapp",
},
body: JSON.stringify({
@ -156,31 +157,44 @@ export async function getOptimizationEventById({
trace_id,
...buildOptimizationOrCondition(payload, repoIds),
}
const event = await prisma.optimization_events.findFirst({
where,
include: {
repository: true,
},
})
if (event) {
// Fetch review_quality and review_explanation from optimization_features
const features = await prisma.optimization_features.findUnique({
where: { trace_id: event.trace_id },
// Fire both queries in parallel — features only needs trace_id, not the event result
const [event, features] = await Promise.all([
prisma.optimization_events.findFirst({
where,
include: {
repository: true,
},
}),
prisma.optimization_features.findUnique({
where: { trace_id },
select: {
review_quality: true,
review_explanation: true,
},
})
}),
])
return {
...event,
review_quality: features?.review_quality || null,
review_explanation: features?.review_explanation || null,
}
if (!event) {
return null
}
return event
// Track that this optimization was reviewed
const userId = "userId" in payload ? payload.userId : undefined
if (userId) {
trackOptimizationReviewed(userId, {
traceId: event.trace_id,
functionName: event.function_name,
repositoryName: event.repository?.full_name ?? null,
status: event.status,
})
}
return {
...event,
review_quality: features?.review_quality || null,
review_explanation: features?.review_explanation || null,
}
}
export async function saveOptimizationChanges({
eventId,
@ -275,9 +289,9 @@ export async function createPullRequest({
optimizedLineProfiler?: string
}): Promise<ActionResponse> {
const cfapiUrl = process.env.CODEFLASH_CFAPI_URL
const session = await getAccessToken({ refresh: true })
const session = await auth0.getAccessToken()
if (!cfapiUrl || !session?.accessToken) {
if (!cfapiUrl || !session?.token) {
return createErrorResponse("Please sign in to continue")
}
@ -291,7 +305,7 @@ export async function createPullRequest({
method: "POST",
headers: {
"Content-Type": "application/json",
Authorization: `Bearer ${session.accessToken}`,
Authorization: `Bearer ${session.token}`,
"X-CodeFlash-Source": "webapp",
},
body: JSON.stringify({

View file

@ -3,7 +3,15 @@
import { useEffect, useState, useCallback, useRef } from "react"
import { useParams, useRouter } from "next/navigation"
import Image from "next/image"
import { Zap, CheckCircle, XCircle, MessageSquare, Loader2, GitCommit, BarChart3 } from "lucide-react"
import {
Zap,
CheckCircle,
XCircle,
MessageSquare,
Loader2,
GitCommit,
BarChart3,
} from "lucide-react"
import {
createPullRequest,
getOptimizationEventById,
@ -15,7 +23,12 @@ import {
commitStagingCode,
} from "./action"
import { getUserIdAndUsername } from "@/app/utils/auth"
import MonacoDiffEditorGithub from "@/components/Editor/monaco-diff-editor-github"
import dynamic from "next/dynamic"
const MonacoDiffEditorGithub = dynamic(
() => import("@/components/Editor/monaco-diff-editor-github"),
{ ssr: false },
)
import { toast } from "sonner"
import { MarkdownEditor } from "@/components/markdwon/markdown-editor"
import { MarkdownViewer } from "@/components/markdwon/markdown-viewer"
@ -653,8 +666,7 @@ export default function OptimizationReviewPage() {
// Check if we have empty diffContents for git_branch storage type (merged PR in privacy mode)
const isPrivacyModeWithNoDiff =
event.staging_storage_type === "git_branch" &&
Object.keys(diffContents).length === 0
event.staging_storage_type === "git_branch" && Object.keys(diffContents).length === 0
return (
<div className="min-h-screen bg-background">

View file

@ -5,7 +5,16 @@ import { useParams, useRouter } from "next/navigation"
import { ArrowLeft, Zap, Loader2, AlertTriangle } from "lucide-react"
import { getOptimizationEventById } from "../action"
import { getUserIdAndUsername } from "@/app/utils/auth"
import { LineProfilerView } from "@/components/LineProfiler"
import dynamic from "next/dynamic"
import { Skeleton } from "@/components/ui/skeleton"
const LineProfilerView = dynamic(
() => import("@/components/LineProfiler").then(mod => mod.LineProfilerView),
{
ssr: false,
loading: () => <Skeleton className="h-full w-full" />,
},
)
import { useViewMode } from "@/app/app/ViewModeContext"
import { toast } from "sonner"

View file

@ -0,0 +1,312 @@
import { describe, it, expect, vi, beforeEach } from "vitest"
import { prisma, buildOptimizationOrCondition } from "@codeflash-ai/common"
import { getRepositoriesForAccountCached } from "@/lib/services/repository-utils"
// Replace the timing wrapper with a pass-through so server actions run
// un-instrumented under test.
vi.mock("@/lib/server-action-timing", () => ({
  withTiming: vi.fn((_name: string, fn: Function) => fn),
}))
// Stub repository resolution; each test controls the repoIds it returns.
vi.mock("@/lib/services/repository-utils", () => ({
  getRepositoriesForAccountCached: vi.fn(),
}))
// NOTE(review): prisma and buildOptimizationOrCondition are driven via
// vi.mocked(...) below — assumes @codeflash-ai/common is auto-mocked in a
// shared vitest setup file; verify that setup exists.
// Shared fixtures: one account, two repositories.
const mockPayload = { userId: "user-1", username: "testuser" }
const mockRepoIds = ["repo-1", "repo-2"]
// Two staged optimization events, one per repository.
const mockEvents = [
  {
    id: "evt-1",
    trace_id: "trace-1",
    function_name: "calculate",
    file_path: "src/utils.py",
    repository_id: "repo-1",
    status: "approved",
    is_staging: true,
    created_at: new Date("2024-06-01"),
    repository: { id: "repo-1", full_name: "org/repo", name: "repo" },
  },
  {
    id: "evt-2",
    trace_id: "trace-2",
    function_name: "process",
    file_path: "src/main.py",
    repository_id: "repo-2",
    status: "pending",
    is_staging: true,
    created_at: new Date("2024-06-02"),
    repository: { id: "repo-2", full_name: "org/repo2", name: "repo2" },
  },
]
// Review features exist only for trace-1; trace-2 exercises the "no
// features row" path.
const mockFeatures = [
  {
    trace_id: "trace-1",
    review_quality: "high",
    review_explanation: "Great optimization",
  },
]
// Test suite for getAllOptimizationEvents, which has two query paths:
//   Path B — standard Prisma findMany/count plus a batched features lookup;
//   Path A — raw SQL (used when sorting/filtering on review_quality, which
//            lives in a joined table Prisma can't order by directly).
describe("getAllOptimizationEvents", () => {
  let getAllOptimizationEvents: typeof import("../action").getAllOptimizationEvents
  beforeEach(async () => {
    // Default mocks: account resolves to two repos, OR-condition is a no-op.
    vi.mocked(getRepositoriesForAccountCached).mockResolvedValue({
      repoIds: mockRepoIds,
      repos: [],
    } as any)
    vi.mocked(buildOptimizationOrCondition).mockReturnValue({})
    // Import after mocks are registered so the module picks them up.
    const mod = await import("../action")
    getAllOptimizationEvents = mod.getAllOptimizationEvents
  })
  describe("Path B: standard Prisma query", () => {
    it("calls findMany and count in parallel", async () => {
      vi.mocked(prisma.optimization_events.findMany).mockResolvedValue(mockEvents as any)
      vi.mocked(prisma.optimization_events.count).mockResolvedValue(2)
      vi.mocked(prisma.optimization_features.findMany).mockResolvedValue([])
      await getAllOptimizationEvents({ payload: mockPayload as any })
      expect(prisma.optimization_events.findMany).toHaveBeenCalledTimes(1)
      expect(prisma.optimization_events.count).toHaveBeenCalledTimes(1)
    })
    it("batch-fetches optimization_features by trace_id array (not N+1)", async () => {
      vi.mocked(prisma.optimization_events.findMany).mockResolvedValue(mockEvents as any)
      vi.mocked(prisma.optimization_events.count).mockResolvedValue(2)
      vi.mocked(prisma.optimization_features.findMany).mockResolvedValue(mockFeatures as any)
      await getAllOptimizationEvents({ payload: mockPayload as any })
      // Single batch query with all trace IDs — NOT one per event
      expect(prisma.optimization_features.findMany).toHaveBeenCalledTimes(1)
      expect(prisma.optimization_features.findMany).toHaveBeenCalledWith({
        where: { trace_id: { in: ["trace-1", "trace-2"] } },
        select: {
          trace_id: true,
          review_quality: true,
          review_explanation: true,
        },
      })
    })
    it("merges review_quality into events", async () => {
      vi.mocked(prisma.optimization_events.findMany).mockResolvedValue(mockEvents as any)
      vi.mocked(prisma.optimization_events.count).mockResolvedValue(2)
      vi.mocked(prisma.optimization_features.findMany).mockResolvedValue(mockFeatures as any)
      const result = await getAllOptimizationEvents({ payload: mockPayload as any })
      // trace-1 has a features row; trace-2 falls back to null quality.
      expect(result.events[0].review_quality).toBe("high")
      expect(result.events[0].review_explanation).toBe("Great optimization")
      expect(result.events[1].review_quality).toBeNull()
    })
    it("returns totalCount from count query", async () => {
      vi.mocked(prisma.optimization_events.findMany).mockResolvedValue([])
      vi.mocked(prisma.optimization_events.count).mockResolvedValue(42)
      vi.mocked(prisma.optimization_features.findMany).mockResolvedValue([])
      const result = await getAllOptimizationEvents({ payload: mockPayload as any })
      expect(result.totalCount).toBe(42)
    })
    it("applies pagination with skip and take", async () => {
      vi.mocked(prisma.optimization_events.findMany).mockResolvedValue([])
      vi.mocked(prisma.optimization_events.count).mockResolvedValue(0)
      vi.mocked(prisma.optimization_features.findMany).mockResolvedValue([])
      await getAllOptimizationEvents({
        payload: mockPayload as any,
        page: 3,
        pageSize: 25,
      })
      expect(prisma.optimization_events.findMany).toHaveBeenCalledWith(
        expect.objectContaining({
          skip: 50, // (3 - 1) * 25
          take: 25,
        }),
      )
    })
    it("uses default sort (created_at desc) when no sort provided", async () => {
      vi.mocked(prisma.optimization_events.findMany).mockResolvedValue([])
      vi.mocked(prisma.optimization_events.count).mockResolvedValue(0)
      vi.mocked(prisma.optimization_features.findMany).mockResolvedValue([])
      await getAllOptimizationEvents({ payload: mockPayload as any })
      expect(prisma.optimization_events.findMany).toHaveBeenCalledWith(
        expect.objectContaining({
          orderBy: { created_at: "desc" },
        }),
      )
    })
    it("applies search filter", async () => {
      vi.mocked(prisma.optimization_events.findMany).mockResolvedValue([])
      vi.mocked(prisma.optimization_events.count).mockResolvedValue(0)
      vi.mocked(prisma.optimization_features.findMany).mockResolvedValue([])
      await getAllOptimizationEvents({
        payload: mockPayload as any,
        search: "calc",
      })
      // Inspect the generated where clause rather than the returned rows.
      const callArgs = vi.mocked(prisma.optimization_events.findMany).mock.calls[0][0] as any
      const andClause = callArgs.where.AND
      expect(andClause).toBeDefined()
      expect(andClause.length).toBeGreaterThan(0)
      // Search should include OR across function_name, file_path, repository.full_name
      const orClause = andClause.find((c: any) => c.OR)?.OR
      expect(orClause).toHaveLength(3)
      expect(orClause[0]).toEqual({
        function_name: { contains: "calc", mode: "insensitive" },
      })
    })
    it("applies repository_id filter", async () => {
      vi.mocked(prisma.optimization_events.findMany).mockResolvedValue([])
      vi.mocked(prisma.optimization_events.count).mockResolvedValue(0)
      vi.mocked(prisma.optimization_features.findMany).mockResolvedValue([])
      await getAllOptimizationEvents({
        payload: mockPayload as any,
        filter: { repository_id: "repo-1" },
      })
      const callArgs = vi.mocked(prisma.optimization_events.findMany).mock.calls[0][0] as any
      const andClause = callArgs.where.AND
      expect(andClause).toBeDefined()
      expect(andClause).toContainEqual({ repository_id: "repo-1" })
    })
  })
  describe("Path A: raw SQL query (review_quality sort/filter)", () => {
    it("triggers when sort includes review_quality", async () => {
      // Raw path issues two queries: rows, then a count.
      vi.mocked(prisma.$queryRawUnsafe)
        .mockResolvedValueOnce([]) // events
        .mockResolvedValueOnce([{ count: BigInt(0) }]) // count
      await getAllOptimizationEvents({
        payload: mockPayload as any,
        sort: { review_quality: "desc" },
      })
      expect(prisma.$queryRawUnsafe).toHaveBeenCalledTimes(2)
      // Should NOT use standard Prisma findMany
      expect(prisma.optimization_events.findMany).not.toHaveBeenCalled()
    })
    it("triggers when filter includes review_quality", async () => {
      vi.mocked(prisma.$queryRawUnsafe)
        .mockResolvedValueOnce([])
        .mockResolvedValueOnce([{ count: BigInt(0) }])
      await getAllOptimizationEvents({
        payload: mockPayload as any,
        filter: { review_quality: "high" },
      })
      expect(prisma.$queryRawUnsafe).toHaveBeenCalledTimes(2)
    })
    it("returns correct totalCount from BigInt conversion", async () => {
      // Postgres COUNT(*) comes back as BigInt; must be converted to number.
      vi.mocked(prisma.$queryRawUnsafe)
        .mockResolvedValueOnce([])
        .mockResolvedValueOnce([{ count: BigInt(99) }])
      const result = await getAllOptimizationEvents({
        payload: mockPayload as any,
        sort: { review_quality: "asc" },
      })
      expect(result.totalCount).toBe(99)
    })
    it("maps JOIN results to include repository object", async () => {
      // Raw rows come back flat (repo_* columns); the action reshapes them
      // into a nested repository object to match the Prisma path.
      const rawEvents = [
        {
          id: "evt-1",
          trace_id: "trace-1",
          review_quality: "high",
          review_explanation: "Good",
          repo_full_name: "org/repo",
          repo_name: "repo",
          repo_id: "repo-1",
        },
      ]
      vi.mocked(prisma.$queryRawUnsafe)
        .mockResolvedValueOnce(rawEvents)
        .mockResolvedValueOnce([{ count: BigInt(1) }])
      const result = await getAllOptimizationEvents({
        payload: mockPayload as any,
        sort: { review_quality: "desc" },
      })
      expect(result.events[0].repository).toEqual({
        id: "repo-1",
        full_name: "org/repo",
        name: "repo",
      })
    })
    it("sets repository to null when repo_id is missing", async () => {
      const rawEvents = [
        {
          id: "evt-1",
          trace_id: "trace-1",
          review_quality: null,
          review_explanation: null,
          repo_full_name: null,
          repo_name: null,
          repo_id: null,
        },
      ]
      vi.mocked(prisma.$queryRawUnsafe)
        .mockResolvedValueOnce(rawEvents)
        .mockResolvedValueOnce([{ count: BigInt(1) }])
      const result = await getAllOptimizationEvents({
        payload: mockPayload as any,
        sort: { review_quality: "desc" },
      })
      expect(result.events[0].repository).toBeNull()
    })
    it("includes LEFT JOIN in raw SQL queries", async () => {
      vi.mocked(prisma.$queryRawUnsafe)
        .mockResolvedValueOnce([])
        .mockResolvedValueOnce([{ count: BigInt(0) }])
      await getAllOptimizationEvents({
        payload: mockPayload as any,
        sort: { review_quality: "desc" },
      })
      const sql = vi.mocked(prisma.$queryRawUnsafe).mock.calls[0][0] as string
      expect(sql).toContain("LEFT JOIN optimization_features")
      expect(sql).toContain("LEFT JOIN repositories")
    })
  })
  describe("edge cases", () => {
    it("handles empty repoIds", async () => {
      // Account with no accessible repositories should yield no events.
      vi.mocked(getRepositoriesForAccountCached).mockResolvedValue({
        repoIds: [],
        repos: [],
      } as any)
      vi.mocked(prisma.optimization_events.findMany).mockResolvedValue([])
      vi.mocked(prisma.optimization_events.count).mockResolvedValue(0)
      vi.mocked(prisma.optimization_features.findMany).mockResolvedValue([])
      const result = await getAllOptimizationEvents({ payload: mockPayload as any })
      expect(result.events).toEqual([])
    })
  })
})

View file

@ -0,0 +1,958 @@
"use client"
import { useState, useCallback, useRef, useEffect } from "react"
import { Input } from "@/components/ui/input"
import { Badge } from "@/components/ui/badge"
import {
Table,
TableBody,
TableCell,
TableHead,
TableHeader,
TableRow,
} from "@/components/ui/table"
import {
Select,
SelectContent,
SelectItem,
SelectTrigger,
SelectValue,
} from "@/components/ui/select"
import {
Search,
FileCode2,
Zap,
Clock,
ChevronLeft,
ChevronRight,
Filter,
ArrowUpDown,
ArrowUp,
ArrowDown,
} from "lucide-react"
import { formatDistanceToNow } from "date-fns"
import { useRouter } from "next/navigation"
import { Button } from "@/components/ui/button"
import { getAllOptimizationEvents } from "../action"
import Image from "next/image"
import { ReviewQualityBadge } from "@/components/ui/quality_badge"
import type { AccountPayload } from "@codeflash-ai/common"
// Minimal repository info attached to an optimization event row.
interface Repository {
  id: string
  full_name?: string
}
// Old/new contents for one changed file in an optimization's diff.
interface DiffContent {
  oldContent: string
  newContent: string
}
// Event metadata payload; diffContents maps file path -> diff pair.
// The index signature keeps the rest of the payload open-ended.
interface EventMetadata {
  diffContents?: Record<string, DiffContent>
  [key: string]: unknown
}
// One optimization event row as rendered by the table.
// NOTE(review): rows are navigated by trace_id, not id — see row click handling.
interface OptimizationEvent {
  id: string
  function_name?: string
  file_path?: string
  repository?: Repository | null | undefined
  speedup_x?: number
  speedup_pct?: number
  metadata?: EventMetadata | null | undefined
  created_at: string
  status?: string
  event_type?: string
  trace_id: string
  review_quality: string
}
// Client-side filter/sort/pagination state for the table.
interface FilterState {
  search: string
  repositoryId: string | null
  status: string
  eventType: string
  reviewQuality: string
  sortBy: string
  page: number
}
// Server-provided initial data plus the account context used for refetches.
interface OptimizationsTableProps {
  initialEvents: OptimizationEvent[]
  initialTotalCount: number
  availableRepositories: Array<{ id: string; full_name: string }>
  accountPayload: AccountPayload
}
/**
 * Loading placeholder for the optimizations table: renders five skeleton
 * rows whose pulsing cells mirror the real table's eight-column layout
 * (function/file, repository, review, status, quality, speedup, changes,
 * created). Purely presentational — no props, no state.
 */
function TableSkeleton() {
  return (
    <>
      {Array.from({ length: 5 }).map((_, index) => (
        <TableRow key={index}>
          <TableCell>
            <div className="flex items-start gap-3">
              <div className="h-4 w-4 bg-muted animate-pulse rounded mt-1 flex-shrink-0" />
              <div className="flex-1 min-w-0 space-y-2">
                <div className="h-4 bg-muted animate-pulse rounded w-3/4" />
                <div className="h-3 bg-muted animate-pulse rounded w-1/2" />
              </div>
            </div>
          </TableCell>
          <TableCell>
            <div className="flex items-center gap-3">
              <div className="h-8 w-8 bg-muted animate-pulse rounded-full flex-shrink-0" />
              <div className="h-4 bg-muted animate-pulse rounded w-32" />
            </div>
          </TableCell>
          <TableCell className="text-center">
            <div className="h-6 bg-muted animate-pulse rounded-full w-20 mx-auto" />
          </TableCell>
          <TableCell className="text-center">
            <div className="h-6 bg-muted animate-pulse rounded-full w-24 mx-auto" />
          </TableCell>
          <TableCell className="text-center">
            <div className="h-6 bg-muted animate-pulse rounded-full w-16 mx-auto" />
          </TableCell>
          <TableCell className="text-center">
            <div className="h-6 bg-muted animate-pulse rounded-full w-24 mx-auto" />
          </TableCell>
          <TableCell className="text-center">
            <div className="flex items-center justify-center gap-2">
              <div className="h-6 bg-muted animate-pulse rounded-full w-12" />
              <div className="h-6 bg-muted animate-pulse rounded-full w-12" />
            </div>
          </TableCell>
          <TableCell className="text-right">
            <div className="flex items-center justify-end gap-2">
              <div className="h-3 w-3 bg-muted animate-pulse rounded flex-shrink-0" />
              <div className="h-4 bg-muted animate-pulse rounded w-24" />
            </div>
          </TableCell>
        </TableRow>
      ))}
    </>
  )
}
/**
 * Approximate added/deleted line counts across a set of file diffs.
 *
 * Lines are compared by trimmed content; blank and whitespace-only lines
 * are ignored. For each file, every distinct line's occurrence count in the
 * old and new content is compared, and the surplus on either side is
 * tallied as deletions or additions. This is a cheap multiset diff — not a
 * real LCS diff — so moved lines cancel out and duplicates are counted by
 * frequency.
 *
 * @param diffContents map of file path -> { oldContent, newContent }
 * @returns aggregate { totalAdditions, totalDeletions } across all files
 */
function calculateDiffStats(
  diffContents: Record<string, { oldContent: string; newContent: string }>,
) {
  // Count non-blank lines by trimmed content.
  const countLines = (text: string): Map<string, number> => {
    const counts = new Map<string, number>()
    for (const line of text.split("\n")) {
      const trimmed = line.trim()
      if (trimmed !== "") {
        counts.set(trimmed, (counts.get(trimmed) ?? 0) + 1)
      }
    }
    return counts
  }
  let totalAdditions = 0
  let totalDeletions = 0
  // File paths (the record keys) are irrelevant to the totals.
  for (const { oldContent, newContent } of Object.values(diffContents)) {
    const oldCounts = countLines(oldContent)
    const newCounts = countLines(newContent)
    // Lines occurring more often in the old content were deleted.
    for (const [line, oldCount] of oldCounts) {
      const newCount = newCounts.get(line) ?? 0
      if (oldCount > newCount) {
        totalDeletions += oldCount - newCount
      }
    }
    // Lines occurring more often in the new content were added.
    for (const [line, newCount] of newCounts) {
      const oldCount = oldCounts.get(line) ?? 0
      if (newCount > oldCount) {
        totalAdditions += newCount - oldCount
      }
    }
  }
  return { totalAdditions, totalDeletions }
}
/**
 * Table row that navigates to an optimization's review page when clicked.
 * Clicks landing inside an external link (`a[href^="http"]` descendant) are
 * ignored so the link's own navigation wins.
 *
 * NOTE(review): the callback param is named `eventId` but is invoked with
 * `event.trace_id` — review pages appear keyed by trace, so the name is
 * misleading; consider renaming at the call sites.
 * NOTE(review): the `key` prop set on TableRow below is redundant — React
 * only uses `key` where the component is instantiated, and callers already
 * pass one there.
 */
function ClickableTableRow({
  event,
  children,
  onRowClick,
}: {
  event: OptimizationEvent
  children: React.ReactNode
  onRowClick: (eventId: string) => void
}) {
  const handleRowClick = useCallback(
    (e: React.MouseEvent) => {
      // Let clicks on embedded external links behave normally.
      if ((e.target as HTMLElement).closest('a[href^="http"]')) {
        return
      }
      onRowClick(event.trace_id)
    },
    [event.trace_id, onRowClick],
  )
  return (
    <TableRow
      key={event.id}
      className="group cursor-pointer hover:bg-muted"
      onClick={handleRowClick}
    >
      {children}
    </TableRow>
  )
}
export function OptimizationsTable({
initialEvents,
initialTotalCount,
availableRepositories,
accountPayload,
}: OptimizationsTableProps) {
const router = useRouter()
const [events, setEvents] = useState<OptimizationEvent[]>(initialEvents)
const [totalCount, setTotalCount] = useState(initialTotalCount)
const [isLoading, setIsLoading] = useState(false)
const [error, setError] = useState<string | null>(null)
const [filters, setFilters] = useState<FilterState>({
search: "",
repositoryId: null,
status: "all",
eventType: "all",
reviewQuality: "all",
sortBy: "created_at_desc",
page: 1,
})
const pageSize = 10
const isInitialMount = useRef(true)
const debounceTimer = useRef<NodeJS.Timeout>(undefined)
const loadEvents = useCallback(async () => {
setIsLoading(true)
setError(null)
try {
const filter: Record<string, string | null | { not: null }> = {}
if (filters.repositoryId === "none") {
filter.repository_id = null
} else if (filters.repositoryId) {
filter.repository_id = filters.repositoryId
}
if (filters.status !== "all") {
filter.status = filters.status
}
if (filters.eventType !== "all") {
filter.event_type = filters.eventType
}
if (filters.reviewQuality !== "all") {
filter.review_quality = filters.reviewQuality
}
const [sortField, sortDirection] = filters.sortBy.split("_").reduce(
(acc, part, index, arr) => {
if (index === arr.length - 1 && (part === "asc" || part === "desc")) {
return [acc[0], part]
}
return [acc[0] ? `${acc[0]}_${part}` : part, acc[1]]
},
["", "desc"] as [string, string],
)
const sort: Record<string, "asc" | "desc"> = {
[sortField]: sortDirection as "asc" | "desc",
}
const data = await getAllOptimizationEvents({
payload: accountPayload,
search: filters.search,
filter,
sort,
page: filters.page,
pageSize,
})
type RawEvent = OptimizationEvent & {
repository?: { id: string; full_name?: string; name?: string } | null
}
const transformedEvents: OptimizationEvent[] = (data?.events || []).map(
(event: RawEvent) => ({
...event,
metadata: event.metadata as EventMetadata | null | undefined,
repository: event.repository
? {
id: event.repository.id,
full_name: event.repository.full_name || event.repository.name,
}
: null,
}),
)
setEvents(transformedEvents)
setTotalCount(data?.totalCount || 0)
} catch (err) {
setError(err instanceof Error ? err.message : "Failed to load events")
} finally {
setIsLoading(false)
}
}, [filters, accountPayload, pageSize])
// Load events when filters change (skip initial mount — server provided that data)
useEffect(() => {
if (isInitialMount.current) {
isInitialMount.current = false
return
}
if (debounceTimer.current) {
clearTimeout(debounceTimer.current)
}
const hasSearchChanged = filters.search !== ""
if (hasSearchChanged) {
debounceTimer.current = setTimeout(() => {
loadEvents()
}, 300)
} else {
loadEvents()
}
return () => {
if (debounceTimer.current) {
clearTimeout(debounceTimer.current)
}
}
}, [filters, loadEvents])
const handleRowClick = useCallback(
(traceId: string) => {
router.push(`/review-optimizations/${traceId}`)
},
[router],
)
const updateFilter = useCallback((key: keyof FilterState, value: string | number | null) => {
setFilters(prev => ({
...prev,
[key]: value,
...(key !== "page" && { page: 1 }),
}))
}, [])
const clearFilters = useCallback(() => {
setFilters({
search: "",
repositoryId: null,
status: "all",
eventType: "all",
reviewQuality: "all",
sortBy: "created_at_desc",
page: 1,
})
}, [])
const hasActiveFilters =
filters.search ||
filters.repositoryId !== null ||
filters.status !== "all" ||
filters.eventType !== "all" ||
filters.reviewQuality !== "all" ||
filters.sortBy !== "created_at_desc"
const totalPages = Math.ceil(totalCount / pageSize)
const handlePageChange = useCallback(
(newPage: number) => {
if (newPage >= 1 && newPage <= totalPages) {
updateFilter("page", newPage)
}
},
[totalPages, updateFilter],
)
const getSortIcon = useCallback(
(field: string) => {
if (filters.sortBy.startsWith(field)) {
return filters.sortBy.endsWith("_asc") ? (
<ArrowUp className="h-4 w-4" />
) : (
<ArrowDown className="h-4 w-4" />
)
}
return <ArrowUpDown className="h-4 w-4 opacity-50" />
},
[filters.sortBy],
)
const toggleSort = useCallback(
(field: string) => {
const newSort = filters.sortBy.startsWith(field)
? filters.sortBy === `${field}_desc`
? `${field}_asc`
: `${field}_desc`
: `${field}_desc`
updateFilter("sortBy", newSort)
},
[filters.sortBy, updateFilter],
)
const getSpeedupBadge = useCallback((speedup?: number, speedupPct?: number) => {
if (typeof speedup !== "number" || typeof speedupPct !== "number") return null
const clamp = (v: number, min: number, max: number) => Math.min(Math.max(v, min), max)
const x = clamp(speedup, 1, 300)
const t = (x - 1) / 299
const hue = 158
const lightness = 95 - t * (95 - 35)
const saturation = 45 + t * (75 - 45)
const textColor = lightness < 60 ? "#fff" : "#047857"
const borderLightness = lightness > 60 ? lightness - 8 : lightness + 8
const borderSaturation = saturation > 70 ? saturation - 10 : saturation + 10
const bgColor = `hsl(${hue}, ${saturation}%, ${lightness}%)`
const borderColor = `hsl(${hue}, ${borderSaturation}%, ${borderLightness}%)`
return (
<Badge
variant="default"
className="font-mono text-[11px] px-2 py-0.5 whitespace-nowrap font-medium"
style={{
backgroundColor: bgColor,
color: textColor,
border: `1px solid ${borderColor}`,
}}
>
{speedup.toFixed(2)}x ({speedupPct.toFixed(0).replace(/\B(?=(\d{3})+(?!\d))/g, ",")}%)
</Badge>
)
}, [])
const getStatusBadge = useCallback((status?: string) => {
if (!status) return null
const variants: Record<string, { className: string; label: string }> = {
approved: {
className:
"bg-green-100 text-green-800 border-green-300 dark:bg-green-900/30 dark:text-green-100 dark:border-green-700",
label: "Approved",
},
rejected: {
className:
"bg-red-100 text-red-800 border-red-300 dark:bg-red-900/30 dark:text-red-100 dark:border-red-700",
label: "Rejected",
},
}
const variant = variants[status] || {
className: "bg-gray-100 text-gray-800 dark:bg-gray-800 dark:text-gray-100",
label: status,
}
return (
<Badge variant="secondary" className={variant.className}>
{variant.label}
</Badge>
)
}, [])
const getEventTypeBadge = useCallback((eventType?: string) => {
if (!eventType) return null
const variants: Record<string, { className: string; label: string }> = {
pr_created: {
className:
"bg-blue-100 text-blue-800 border-blue-300 dark:bg-blue-900/30 dark:text-blue-100 dark:border-blue-700",
label: "PR Created",
},
pr_merged: {
className:
"bg-purple-100 text-purple-800 border-purple-300 dark:bg-purple-900/30 dark:text-purple-100 dark:border-purple-700",
label: "PR Merged",
},
pr_closed: {
className:
"bg-orange-100 text-orange-800 border-orange-300 dark:bg-orange-900/30 dark:text-orange-100 dark:border-orange-700",
label: "PR Closed",
},
"no-pr": {
className:
"bg-gray-100 text-gray-800 border-gray-300 dark:bg-gray-800 dark:text-gray-100 dark:border-gray-600",
label: "Staged Changes",
},
}
const variant = variants[eventType] || {
className: "bg-gray-100 text-gray-800 dark:bg-gray-800 dark:text-gray-100",
label: eventType,
}
return (
<Badge variant="secondary" className={variant.className}>
{variant.label}
</Badge>
)
}, [])
return (
<div className="py-8 px-4">
<div className="mb-8">
<h1 className="text-3xl font-bold mb-2">Review Optimizations</h1>
</div>
{/* Search and Filters */}
<div className="mb-6">
<div className="flex flex-wrap items-center gap-3">
<div className="relative flex-1 min-w-[200px] max-w-md">
<Search className="absolute left-3 top-1/2 transform -translate-y-1/2 text-muted-foreground h-4 w-4" />
<Input
placeholder="Search by function name, file path, or repository name..."
value={filters.search}
onChange={e => updateFilter("search", e.target.value)}
className="pl-10 w-full"
/>
</div>
<div className="flex items-center gap-2 text-sm text-muted-foreground">
<Filter className="h-4 w-4" />
<span className="hidden sm:inline">Filters:</span>
</div>
<Select
value={filters.repositoryId || "all"}
onValueChange={value =>
updateFilter(
"repositoryId",
value === "all" ? null : value === "none" ? "none" : value,
)
}
>
<SelectTrigger
className={`w-[180px] sm:w-[220px] ${
filters.repositoryId === null
? "text-muted-foreground [&>span]:text-muted-foreground"
: ""
}`}
>
<SelectValue placeholder="All Repositories" />
</SelectTrigger>
<SelectContent>
<SelectItem value="all">All Repositories</SelectItem>
<SelectItem value="none">Without Repository</SelectItem>
{availableRepositories.map(repo => (
<SelectItem key={repo.id} value={repo.id}>
{repo.full_name}
</SelectItem>
))}
</SelectContent>
</Select>
<Select value={filters.status} onValueChange={value => updateFilter("status", value)}>
<SelectTrigger
className={`w-[120px] sm:w-[150px] ${
filters.status === "all"
? "text-muted-foreground [&>span]:text-muted-foreground"
: ""
}`}
>
<SelectValue placeholder="Status" />
</SelectTrigger>
<SelectContent>
<SelectItem value="all">Reviews</SelectItem>
<SelectItem value="approved">Approved</SelectItem>
<SelectItem value="rejected">Rejected</SelectItem>
</SelectContent>
</Select>
<Select
value={filters.eventType}
onValueChange={value => updateFilter("eventType", value)}
>
<SelectTrigger
className={`w-[120px] sm:w-[150px] ${
filters.eventType === "all"
? "text-muted-foreground [&>span]:text-muted-foreground"
: ""
}`}
>
<SelectValue placeholder="Event Type" />
</SelectTrigger>
<SelectContent>
<SelectItem value="all">Status</SelectItem>
<SelectItem value="pr_created">PR Created</SelectItem>
<SelectItem value="pr_merged">PR Merged</SelectItem>
<SelectItem value="pr_closed">PR Closed</SelectItem>
<SelectItem value="no-pr">Staged Changes</SelectItem>
</SelectContent>
</Select>
<Select
value={filters.reviewQuality}
onValueChange={value => updateFilter("reviewQuality", value)}
>
<SelectTrigger
className={`w-[120px] sm:w-[150px] ${
filters.reviewQuality === "all"
? "text-muted-foreground [&>span]:text-muted-foreground"
: ""
}`}
>
<SelectValue placeholder="Quality" />
</SelectTrigger>
<SelectContent>
<SelectItem value="all">All Quality</SelectItem>
<SelectItem value="high">High</SelectItem>
<SelectItem value="medium">Medium</SelectItem>
<SelectItem value="low">Low</SelectItem>
</SelectContent>
</Select>
<Select value={filters.sortBy} onValueChange={value => updateFilter("sortBy", value)}>
<SelectTrigger
className={`w-[140px] sm:w-[200px] ${
filters.sortBy === "created_at_desc"
? "text-muted-foreground [&>span]:text-muted-foreground"
: ""
}`}
>
<SelectValue placeholder="Sort by" />
</SelectTrigger>
<SelectContent>
<SelectItem value="created_at_desc">Newest</SelectItem>
<SelectItem value="created_at_asc">Oldest</SelectItem>
<SelectItem value="speedup_x_desc">Speedup (Highest)</SelectItem>
<SelectItem value="speedup_x_asc">Speedup (Lowest)</SelectItem>
<SelectItem value="review_quality_desc">Quality (High to Low)</SelectItem>
<SelectItem value="review_quality_asc">Quality (Low to High)</SelectItem>
</SelectContent>
</Select>
{hasActiveFilters && (
<Button
variant="ghost"
size="sm"
onClick={clearFilters}
className="text-muted-foreground hover:text-foreground"
>
Clear
</Button>
)}
</div>
</div>
{error && (
<div className="mb-6 p-4 rounded-lg bg-destructive/10 border border-destructive/20">
<p className="text-destructive text-sm">Error: {error}</p>
<Button
variant="outline"
size="sm"
onClick={loadEvents}
className="mt-2"
disabled={isLoading}
>
Retry
</Button>
</div>
)}
<div className="rounded-lg border bg-card">
<Table>
<TableHeader>
<TableRow>
<TableHead className="w-[25%]">FUNCTION / FILE</TableHead>
<TableHead className="w-[18%]">REPOSITORY</TableHead>
<TableHead className="text-center">REVIEW</TableHead>
<TableHead className="text-center">STATUS</TableHead>
<TableHead
className="text-center cursor-pointer hover:bg-muted/50"
onClick={() => toggleSort("review_quality")}
>
<div className="flex items-center justify-center gap-1">
<span>QUALITY</span>
{getSortIcon("review_quality")}
</div>
</TableHead>
<TableHead
className="text-center cursor-pointer hover:bg-muted/50"
onClick={() => toggleSort("speedup_x")}
>
<div className="flex items-center justify-center gap-1">
<span>SPEEDUP</span>
{getSortIcon("speedup_x")}
</div>
</TableHead>
<TableHead className="text-center">CHANGES</TableHead>
<TableHead
className="text-right cursor-pointer hover:bg-muted/50"
onClick={() => toggleSort("created_at")}
>
<div className="flex items-center justify-end gap-1">
<span>CREATED</span>
{getSortIcon("created_at")}
</div>
</TableHead>
</TableRow>
</TableHeader>
<TableBody>
{isLoading ? (
<TableSkeleton />
) : events.length === 0 ? (
<TableRow>
<TableCell colSpan={8} className="py-12">
<div className="mx-auto max-w-lg text-center bg-muted/5 border border-dashed border-border rounded-xl px-6 py-10">
<div className="inline-flex items-center justify-center w-14 h-14 rounded-full bg-muted/30 mb-4">
<Zap className="h-8 w-8 text-muted-foreground" />
</div>
<h3 className="text-base sm:text-lg font-semibold mb-2">
No optimizations to review yet
</h3>
<p className="text-xs sm:text-sm text-muted-foreground mb-3">
Run `codeflash --all with --staging-review` in a repository or install the VS
Code extension to trigger your first optimization review.
</p>
{hasActiveFilters ? (
<p className="text-xs sm:text-sm text-muted-foreground">
Filters are currently hiding resultsclear them to see everything.
</p>
) : (
<div className="flex flex-col sm:flex-row items-center justify-center gap-2 sm:gap-3 mt-4">
<a
href="/onboarding"
className="w-full sm:w-auto inline-flex items-center justify-center rounded-lg bg-primary px-4 py-2 text-xs sm:text-sm font-medium text-primary-foreground hover:bg-primary/90 transition-colors"
>
View setup steps
</a>
<a
href="https://docs.codeflash.ai/editor-plugins/vscode"
target="_blank"
rel="noopener noreferrer"
className="w-full sm:w-auto inline-flex items-center justify-center rounded-lg border border-border px-4 py-2 text-xs sm:text-sm font-medium text-foreground hover:bg-muted/40 transition-colors"
>
Install VS Code extension
</a>
</div>
)}
</div>
</TableCell>
</TableRow>
) : (
events.map((event: OptimizationEvent) => {
let diffStats: { totalAdditions: number; totalDeletions: number } = {
totalAdditions: 0,
totalDeletions: 0,
}
if (
event.metadata &&
typeof event.metadata === "object" &&
event.metadata !== null &&
typeof event.metadata.diffContents === "object" &&
event.metadata.diffContents !== null
) {
const diffContentsRaw = event.metadata.diffContents
if (diffContentsRaw && typeof diffContentsRaw === "object") {
let valid = true
for (const value of Object.values(diffContentsRaw as object)) {
if (
!value ||
typeof value !== "object" ||
typeof (value as DiffContent).oldContent !== "string" ||
typeof (value as DiffContent).newContent !== "string"
) {
valid = false
break
}
}
if (valid) {
const diffContents = diffContentsRaw as Record<
string,
{ oldContent: string; newContent: string }
>
diffStats = calculateDiffStats(diffContents)
}
}
}
return (
<ClickableTableRow key={event.id} event={event} onRowClick={handleRowClick}>
<TableCell className="w-auto min-w-0">
<div className="flex items-start gap-3">
<FileCode2 className="h-4 w-4 text-muted-foreground mt-1 flex-shrink-0" />
<div className="flex-1 min-w-0 overflow-hidden">
<div className="font-mono text-sm font-medium truncate">
{event.function_name || "Unknown Function"}
</div>
<div className="text-xs text-muted-foreground truncate">
{event.file_path || "No file path"}
</div>
</div>
</div>
</TableCell>
<TableCell className="w-auto min-w-0">
<div className="flex items-center gap-3">
{event.repository ? (
<>
<div className="relative h-8 w-8 flex-shrink-0">
{event.repository.full_name && (
<Image
src={`https://github.com/${event.repository.full_name.split("/")[0]}.png`}
alt={event.repository.full_name}
fill
className="rounded-full object-cover"
onError={e => {
e.currentTarget.style.display = "none"
}}
/>
)}
</div>
<div className="flex-1 min-w-0 overflow-hidden">
<div className="flex items-center gap-1">
<span className="text-sm font-medium truncate">
{event.repository.full_name || "Unknown Repository"}
</span>
</div>
</div>
</>
) : (
<>
<div className="h-8 w-8 rounded-full bg-muted flex items-center justify-center flex-shrink-0">
<Zap className="h-4 w-4 text-muted-foreground" />
</div>
<span className="text-sm text-muted-foreground">
Untracked repository
</span>
</>
)}
</div>
</TableCell>
<TableCell className="text-center">{getStatusBadge(event.status)}</TableCell>
<TableCell className="text-center">
{getEventTypeBadge(event.event_type)}
</TableCell>
<TableCell className="text-center">
<ReviewQualityBadge quality={event.review_quality} />
</TableCell>
<TableCell className="text-center">
{getSpeedupBadge(event.speedup_x, event.speedup_pct)}
</TableCell>
<TableCell className="text-center">
<div className="flex items-center justify-center gap-2 flex-wrap">
{diffStats.totalAdditions > 0 && (
<Badge
variant="secondary"
className="bg-emerald-50 text-emerald-700 border-emerald-200 dark:bg-emerald-950/50 dark:text-emerald-400 dark:border-emerald-800 whitespace-nowrap font-medium"
>
+{diffStats.totalAdditions}
</Badge>
)}
{diffStats.totalDeletions > 0 && (
<Badge
variant="secondary"
className="bg-rose-50 text-rose-700 border-rose-200 dark:bg-rose-950/50 dark:text-rose-400 dark:border-rose-800 whitespace-nowrap font-medium"
>
-{diffStats.totalDeletions}
</Badge>
)}
{diffStats.totalAdditions === 0 && diffStats.totalDeletions === 0 && (
<span className="text-muted-foreground text-xs"></span>
)}
</div>
</TableCell>
<TableCell className="text-right">
<div className="flex items-center justify-end gap-2">
<Clock className="h-3 w-3 text-muted-foreground flex-shrink-0" />
<span className="text-sm text-muted-foreground whitespace-nowrap">
{formatDistanceToNow(new Date(event.created_at), {
addSuffix: true,
})}
</span>
</div>
</TableCell>
</ClickableTableRow>
)
})
)}
</TableBody>
</Table>
</div>
{!isLoading && totalPages > 1 && (
<div className="flex items-center justify-between mt-6">
<p className="text-sm text-muted-foreground">
Showing {(filters.page - 1) * pageSize + 1} to{" "}
{Math.min(filters.page * pageSize, totalCount)} of {totalCount} events
</p>
<div className="flex items-center gap-2">
<Button
variant="outline"
size="sm"
disabled={filters.page === 1}
onClick={() => handlePageChange(filters.page - 1)}
>
<ChevronLeft className="h-4 w-4" />
Previous
</Button>
<div className="flex items-center gap-1">
{Array.from({ length: Math.min(5, totalPages) }, (_, i) => {
let pageNum: number
if (totalPages <= 5) {
pageNum = i + 1
} else if (filters.page <= 3) {
pageNum = i + 1
} else if (filters.page >= totalPages - 2) {
pageNum = totalPages - 4 + i
} else {
pageNum = filters.page - 2 + i
}
return (
<Button
key={i}
variant={filters.page === pageNum ? "default" : "outline"}
size="sm"
className="w-8 h-8 p-0"
onClick={() => handlePageChange(pageNum)}
>
{pageNum}
</Button>
)
})}
{totalPages > 5 && filters.page < totalPages - 2 && <span className="px-2">...</span>}
{totalPages > 5 && filters.page < totalPages - 2 && (
<Button
variant="outline"
size="sm"
className="w-8 h-8 p-0"
onClick={() => handlePageChange(totalPages)}
>
{totalPages}
</Button>
)}
</div>
<Button
variant="outline"
size="sm"
disabled={filters.page === totalPages}
onClick={() => handlePageChange(filters.page + 1)}
>
Next
<ChevronRight className="h-4 w-4" />
</Button>
</div>
</div>
)}
</div>
)
}

View file

@ -1,11 +1,12 @@
"use server"
import { getRepositoriesForAccountCached } from "@/lib/services/repository-utils"
import { withTiming } from "@/lib/server-action-timing"
import { AccountPayload, buildOptimizationOrCondition, prisma } from "@codeflash-ai/common"
export async function getRepositoriesWithStagingEvents(
payload: AccountPayload,
): Promise<Array<{ id: string; full_name: string }>> {
const { repoIds, repos: allRepos } = await getRepositoriesForAccountCached(payload)
export const getRepositoriesWithStagingEvents = withTiming(
"getRepositoriesWithStagingEvents",
async (payload: AccountPayload): Promise<Array<{ id: string; full_name: string }>> => {
const { repoIds, repos: allRepos } = await getRepositoriesForAccountCached(payload)
if (repoIds.length === 0) {
return []
@ -29,23 +30,26 @@ export async function getRepositoriesWithStagingEvents(
full_name: repo.full_name,
}))
.sort((a, b) => a.full_name.localeCompare(b.full_name))
}
},
)
export async function getAllOptimizationEvents({
payload,
search,
filter,
sort,
page = 1,
pageSize = 10,
}: {
payload: AccountPayload
search?: string
filter?: Record<string, any>
sort?: { [key: string]: "asc" | "desc" }
page?: number
pageSize?: number
}) {
export const getAllOptimizationEvents = withTiming(
"getAllOptimizationEvents",
async ({
payload,
search,
filter,
sort,
page = 1,
pageSize = 10,
}: {
payload: AccountPayload
search?: string
filter?: Record<string, any>
sort?: { [key: string]: "asc" | "desc" }
page?: number
pageSize?: number
}) => {
const repoIds = (await getRepositoriesForAccountCached(payload)).repoIds
const where: any = {
@ -168,83 +172,84 @@ export async function getAllOptimizationEvents({
orderByClauses.push("oe.created_at DESC")
}
const orderByClause = orderByClauses.join(", ")
const events = await prisma.$queryRawUnsafe<any[]>(
`
SELECT
oe.*,
of.review_quality,
of.review_explanation
FROM optimization_events oe
LEFT JOIN optimization_features of ON oe.trace_id = of.trace_id
LEFT JOIN repositories r ON oe.repository_id = r.id
WHERE ${whereClause}
ORDER BY ${orderByClause}
LIMIT $${paramIndex} OFFSET $${paramIndex + 1}
`,
...params,
pageSize,
(page - 1) * pageSize,
)
// Get total count
const countResult = await prisma.$queryRawUnsafe<[{ count: bigint }]>(
`
SELECT COUNT(*) as count
FROM optimization_events oe
LEFT JOIN optimization_features of ON oe.trace_id = of.trace_id
LEFT JOIN repositories r ON oe.repository_id = r.id
WHERE ${whereClause}
`,
...params,
)
const [events, countResult] = await Promise.all([
prisma.$queryRawUnsafe<any[]>(
`
SELECT
oe.*,
of.review_quality,
of.review_explanation,
r.full_name as repo_full_name,
r.name as repo_name,
r.id as repo_id
FROM optimization_events oe
LEFT JOIN optimization_features of ON oe.trace_id = of.trace_id
LEFT JOIN repositories r ON oe.repository_id = r.id
WHERE ${whereClause}
ORDER BY ${orderByClause}
LIMIT $${paramIndex} OFFSET $${paramIndex + 1}
`,
...params,
pageSize,
(page - 1) * pageSize,
),
prisma.$queryRawUnsafe<[{ count: bigint }]>(
`
SELECT COUNT(*) as count
FROM optimization_events oe
LEFT JOIN optimization_features of ON oe.trace_id = of.trace_id
LEFT JOIN repositories r ON oe.repository_id = r.id
WHERE ${whereClause}
`,
...params,
),
])
const totalCount = Number(countResult[0].count)
// Fetch repository data for the events
const eventsWithRepo = await Promise.all(
events.map(async event => {
if (event.repository_id) {
const repository = await prisma.repositories.findUnique({
where: { id: event.repository_id },
})
return { ...event, repository }
}
return { ...event, repository: null }
}),
)
// Repository data is already included from the JOIN
const eventsWithRepo = events.map(event => ({
...event,
repository: event.repo_id ? { id: event.repo_id, full_name: event.repo_full_name, name: event.repo_name } : null,
}))
return { events: eventsWithRepo, totalCount }
} else {
// Standard Prisma query with native orderBy
const orderBy = sort || { created_at: "desc" }
const events = await prisma.optimization_events.findMany({
where,
orderBy,
skip: (page - 1) * pageSize,
take: pageSize,
include: {
repository: true,
const [events, totalCount] = await Promise.all([
prisma.optimization_events.findMany({
where,
orderBy,
skip: (page - 1) * pageSize,
take: pageSize,
include: {
repository: true,
},
}),
prisma.optimization_events.count({ where }),
])
// Batch-fetch review data for all events in a single query
const traceIds = events.map(e => e.trace_id)
const features = await prisma.optimization_features.findMany({
where: { trace_id: { in: traceIds } },
select: {
trace_id: true,
review_quality: true,
review_explanation: true,
},
})
const featuresMap = new Map(features.map(f => [f.trace_id, f]))
// Fetch review_quality and review_explanation for each event
const eventsWithReviewData = await Promise.all(
events.map(async event => {
const features = await prisma.optimization_features.findUnique({
where: { trace_id: event.trace_id },
select: {
review_quality: true,
review_explanation: true,
},
})
return {
...event,
review_quality: features?.review_quality || null,
review_explanation: features?.review_explanation || null,
}
}),
)
const totalCount = await prisma.optimization_events.count({ where })
const eventsWithReviewData = events.map(event => {
const f = featuresMap.get(event.trace_id)
return {
...event,
review_quality: f?.review_quality || null,
review_explanation: f?.review_explanation || null,
}
})
return { events: eventsWithReviewData, totalCount }
}
}
},
)

View file

@ -0,0 +1,22 @@
"use client"
import { RefreshCw } from "lucide-react"
export default function ReviewOptimizationsError({ reset }: { error: Error; reset: () => void }) {
return (
<div className="flex justify-center items-center h-[70vh]">
<div className="bg-red-50 text-red-800 p-6 sm:p-8 rounded-xl max-w-md border border-red-200">
<h3 className="text-base sm:text-lg font-medium mb-2 sm:mb-3">Something went wrong</h3>
<p className="mb-3 sm:mb-4 text-sm sm:text-base">
There was an error loading the review optimizations page.
</p>
<button
onClick={reset}
className="flex items-center gap-1 sm:gap-2 w-full justify-center px-3 sm:px-4 py-2 sm:py-2.5 bg-red-100 hover:bg-red-200 text-red-800 rounded-lg text-xs sm:text-sm font-medium transition-colors"
>
<RefreshCw size={14} className="sm:w-4 sm:h-4" /> Try Again
</button>
</div>
</div>
)
}

View file

@ -0,0 +1,59 @@
import { Skeleton } from "@/components/ui/skeleton"
/** Loading skeleton mirroring the review-optimizations filter bar and table layout. */
export default function ReviewOptimizationsLoading() {
  // Header cell placeholders, one per table column. Kept as complete literal
  // class strings (not templated) so Tailwind's scanner still picks them up.
  const headerCells = [
    "h-4 w-[25%]",
    "h-4 w-[18%]",
    "h-4 w-[10%]",
    "h-4 w-[10%]",
    "h-4 w-[8%]",
    "h-4 w-[10%]",
    "h-4 w-[8%]",
    "h-4 w-[11%]",
  ]
  // Badge/date placeholders that follow the two composite cells in each body row.
  const rowTrailingCells = [
    "h-6 w-[10%] rounded-full",
    "h-6 w-[10%] rounded-full",
    "h-6 w-[8%] rounded-full",
    "h-6 w-[10%] rounded-full",
    "h-6 w-[8%] rounded-full",
    "h-4 w-[11%]",
  ]
  return (
    <div className="py-8 px-4">
      <div className="mb-8">
        <Skeleton className="h-9 w-64 mb-2" />
      </div>
      {/* Filter bar skeleton */}
      <div className="mb-6">
        <div className="flex flex-wrap items-center gap-3">
          <Skeleton className="h-10 flex-1 min-w-[200px] max-w-md rounded-md" />
          <Skeleton className="h-10 w-[180px] rounded-md" />
          <Skeleton className="h-10 w-[120px] rounded-md" />
          <Skeleton className="h-10 w-[120px] rounded-md" />
          <Skeleton className="h-10 w-[120px] rounded-md" />
          <Skeleton className="h-10 w-[140px] rounded-md" />
        </div>
      </div>
      {/* Table skeleton: header row plus five placeholder data rows */}
      <div className="rounded-lg border bg-card">
        <div className="p-4 border-b">
          <div className="flex gap-4">
            {headerCells.map((cls, idx) => (
              <Skeleton key={idx} className={cls} />
            ))}
          </div>
        </div>
        {Array.from({ length: 5 }).map((_, rowIdx) => (
          <div key={rowIdx} className="p-4 border-b last:border-0">
            <div className="flex items-center gap-4">
              {/* Function name + file path column */}
              <div className="w-[25%] space-y-2">
                <Skeleton className="h-4 w-3/4" />
                <Skeleton className="h-3 w-1/2" />
              </div>
              {/* Repository avatar + name column */}
              <div className="w-[18%] flex items-center gap-2">
                <Skeleton className="h-8 w-8 rounded-full" />
                <Skeleton className="h-4 w-24" />
              </div>
              {rowTrailingCells.map((cls, idx) => (
                <Skeleton key={idx} className={cls} />
              ))}
            </div>
          </div>
        ))}
      </div>
    </div>
  )
}

File diff suppressed because it is too large Load diff

View file

@ -1,7 +1,7 @@
"use client"
import { usePathname, useSearchParams } from "next/navigation"
import { useEffect } from "react"
import { type JSX, useEffect } from "react"
import { usePostHog } from "posthog-js/react"
export default function PostHogPageView(): JSX.Element | null {

View file

@ -1,175 +0,0 @@
import {
type AfterCallbackAppRoute,
type AppRouteHandlerFnContext,
getSession,
handleAuth,
handleCallback,
handleLogin,
handleLogout,
type Session,
} from "@auth0/nextjs-auth0"
import { type NextRequest, NextResponse } from "next/server"
import { createOrUpdateUser, hasCompletedOnboarding } from "@codeflash-ai/common"
import { trackUserLogin } from "@/lib/analytics/tracking"
import { cookies } from "next/headers"
import { APP_ROUTES } from "@/lib/types"
// Logout redirect target; overridable via env vars, e.g. to point at a marketing campaign.
const LOGOUT_REDIRECT_URL =
process.env.CODEFLASH_LOGOUT_REDIRECT_URL ??
process.env.CODEFLASH_MARKETING_URL ??
"https://codeflash.ai"
// afterCallback performs all post-login processing (user upsert, analytics, onboarding check).
// Runs after Auth0 completes the code exchange: persists the user, records the
// login for analytics, checks onboarding state, and stores the post-login
// destination on session.returnTo for the callback handler to redirect to.
// Any failure here is swallowed so the login itself never fails.
const afterCallback: AfterCallbackAppRoute = async (req: NextRequest, session: Session) => {
  // No authenticated user on the session: nothing to process.
  if (!session.user) {
    return session
  }
  const user = session.user
  console.log(`[Auth] Processing login for user: ${user.sub}`)
  // sub and nickname are required by createOrUpdateUser below; bail out (but
  // still complete the login) if either is missing.
  if (!user.sub || !user.nickname) {
    console.error("[Auth] Missing required user fields")
    return session
  }
  try {
    // 1. Upsert the user record in the database.
    console.log("[Auth] Saving user to database...")
    await createOrUpdateUser(user.sub, user.nickname, user.email ?? null, user.name ?? null)
    console.log("[Auth] User saved successfully")
    // 2. Record the login event for analytics.
    await trackUserLogin({
      userId: user.sub,
      username: user.nickname,
      email: user.email,
      name: user.name,
    })
    // 3. Check whether the user has finished onboarding (drives the redirect below).
    const completedOnboarding = await hasCompletedOnboarding(user.sub)
    console.log(`[Auth] Onboarding completed: ${completedOnboarding}`)
    // 4. Work out where to send the user after login, defaulting to the app root.
    let intendedDestination = APP_ROUTES.BASE
    const url = new URL(req.url)
    // Source 1: an explicit returnTo query param on the callback URL.
    const returnToParam = url.searchParams.get("returnTo")
    if (returnToParam) {
      intendedDestination = returnToParam
      console.log(`[Auth] Found returnTo in URL params: ${intendedDestination}`)
    }
    // Source 2 (fallback): a returnTo embedded in the OAuth state parameter.
    // Assumes state is base64-encoded JSON carrying a returnTo field — parse
    // failures are logged and ignored. TODO(review): confirm this state format
    // against the Auth0 SDK version in use.
    const stateParam = url.searchParams.get("state")
    if (stateParam && !returnToParam) {
      try {
        const state = JSON.parse(Buffer.from(stateParam, "base64").toString("utf-8"))
        if (state.returnTo) {
          intendedDestination = state.returnTo
          console.log(`[Auth] Found returnTo in state: ${intendedDestination}`)
        }
      } catch (e) {
        console.warn("[Auth] Failed to parse state:", e)
      }
    }
    // Destinations under /codeflash/auth (presumably the token-based CLI/editor
    // auth flow) must never be diverted to onboarding. NOTE(review): the two
    // checks are redundant — startsWith implies includes.
    const isAuthPath =
      intendedDestination.startsWith("/codeflash/auth") ||
      intendedDestination.includes("/codeflash/auth")
    console.log(`[Auth] isAuthPath: ${isAuthPath}`)
    // Users who have not completed onboarding are sent there first.
    if (!completedOnboarding && !isAuthPath) {
      session.returnTo = "/onboarding"
    } else {
      session.returnTo = intendedDestination
    }
  } catch (error) {
    console.error("[Auth] Error in afterCallback:", error)
    // Never fail the login because of our bookkeeping; fall back to the app root.
    session.returnTo = APP_ROUTES.BASE
  }
  return session
}
// Route handlers for the Auth0 /api/auth/* endpoints.
// Auth0 route handler (mounted via handleAuth) wiring custom login, logout,
// and callback behavior on top of @auth0/nextjs-auth0.
export const GET = handleAuth({
  // Login: forwards the caller's returnTo query param to Auth0 so the user
  // lands back where they started after authenticating.
  login: async (request: any, response: any) => {
    console.log("Logging in")
    try {
      const req = request as NextRequest
      const url = new URL(req.url)
      // Default to the app root when the caller did not request a specific page.
      const returnTo = url.searchParams.get("returnTo") || APP_ROUTES.BASE
      console.log(`[Auth] Login with returnTo: ${returnTo}`)
      return await handleLogin(req, response as AppRouteHandlerFnContext, {
        returnTo,
        authorizationParams: {
          // offline_access requests a refresh token on top of the standard OIDC scopes.
          scope: "openid profile email offline_access",
        },
      })
    } catch (error) {
      console.error("Error logging in:", error)
      return NextResponse.json({ error: "Failed to initiate login" }, { status: 500 })
    }
  },
  // Logout: always lands the user on LOGOUT_REDIRECT_URL, even if Auth0's
  // logout handler throws.
  logout: async (request: any, response: any) => {
    console.log("Logging out")
    try {
      return await handleLogout(request as NextRequest, response as AppRouteHandlerFnContext, {
        returnTo: LOGOUT_REDIRECT_URL,
      })
    } catch (error) {
      console.error("Error logging out:", error)
      return NextResponse.redirect(LOGOUT_REDIRECT_URL)
    }
  },
  // Callback: completes the code exchange, runs afterCallback (user upsert,
  // tracking, onboarding check), then redirects based on session.returnTo.
  callback: async (req: any, res: any) => {
    try {
      const response = (await handleCallback(req as NextRequest, res as AppRouteHandlerFnContext, {
        afterCallback,
      })) as NextResponse
      const session = await getSession(req as NextRequest, response)
      if (session != null) {
        // afterCallback stored the destination on the session; fall back to the app root.
        const returnTo = session.returnTo || APP_ROUTES.BASE
        // Treat returnTo as absolute only when it already contains the app's base URL.
        const isAbsolute = returnTo.includes(process.env.AUTH0_BASE_URL ?? "")
        const redirectUrl = isAbsolute ? returnTo : `${process.env.AUTH0_BASE_URL}${returnTo}`
        return NextResponse.redirect(redirectUrl, response)
      } else {
        // No session after the callback: the user is not allowed in; show the waitlist.
        return NextResponse.redirect(`${process.env.AUTH0_BASE_URL}/waitlist`, response)
      }
    } catch (error: any) {
      console.error("Error in callback:", error)
      // Allowlist rejections arrive as a 400 whose message embeds
      // "allowlist-fail <userId> <nickname>)"; extract both to personalize the
      // waitlist page. NOTE(review): String.prototype.search coerces its string
      // argument to a RegExp — message.includes("allowlist-fail") would express
      // the intent more directly.
      if (error.status === 400 && error.message.search("allowlist-fail") !== -1) {
        const re = /allowlist-fail\s(.*)\s(.*)\)/
        const match = error.message.match(re)
        if (match != null) {
          const userId = match[1]
          const userNickname = match[2]
          return NextResponse.redirect(
            `${process.env.AUTH0_BASE_URL}/waitlist?username=${userNickname}&userid=${userId}`,
          )
        }
      }
      // Any other failure: bounce back to login with a generic error flag.
      return NextResponse.redirect(`${process.env.AUTH0_BASE_URL}/login?error=callback_failed`)
    }
  },
})

View file

@ -57,7 +57,7 @@ function summarizeToolResult(toolName: string, result: string): string {
}
case "get_errors": {
if (result === "No errors in this trace.") return "No errors"
const count = lines.filter((l) => l.startsWith("[")).length
const count = lines.filter(l => l.startsWith("[")).length
return `Found ${count} errors`
}
case "get_llm_call_detail":
@ -92,7 +92,7 @@ async function processToolCalls(
}
const results = await Promise.all(
toolUseBlocks.map(async (block) => {
toolUseBlocks.map(async block => {
const result = await resolveToolCall(
block.name,
(block.input as Record<string, unknown>) ?? {},
@ -120,10 +120,7 @@ async function processToolCalls(
return results
}
function baseParams(
systemPrompt: string,
conversationMessages: Anthropic.MessageParam[],
) {
function baseParams(systemPrompt: string, conversationMessages: Anthropic.MessageParam[]) {
return {
model: "claude-opus-4-6" as const,
max_tokens: 32_000,
@ -154,10 +151,7 @@ export async function POST(request: NextRequest): Promise<Response> {
const { traceId, messages } = body
if (!traceId || !messages?.length) {
return Response.json(
{ error: "traceId and messages are required" },
{ status: 400 },
)
return Response.json({ error: "traceId and messages are required" }, { status: 400 })
}
const tracePrefix = traceId.substring(0, 33)
@ -168,7 +162,7 @@ export async function POST(request: NextRequest): Promise<Response> {
const indexed = indexTraceData(traceData)
const systemPrompt = buildSummaryPrompt(indexed)
const conversationMessages: Anthropic.MessageParam[] = messages.map((m) => ({
const conversationMessages: Anthropic.MessageParam[] = messages.map(m => ({
role: m.role,
content: m.content,
}))
@ -183,21 +177,25 @@ export async function POST(request: NextRequest): Promise<Response> {
try {
let toolRounds = 0
let emittedText = false
// eslint-disable-next-line no-constant-condition
while (true) {
enqueue(`data: ${JSON.stringify({ type: "status", message: toolRounds === 0 ? "Thinking…" : "Analyzing…" })}\n\n`)
enqueue(
`data: ${JSON.stringify({ type: "status", message: toolRounds === 0 ? "Thinking…" : "Analyzing…" })}\n\n`,
)
// Redact thinking blocks from prior rounds (each can be 10-50KB)
for (const msg of conversationMessages) {
if (msg.role !== "assistant" || !Array.isArray(msg.content)) continue
for (const block of msg.content) {
if ((block as { type: string }).type === "thinking") {
(block as { thinking: string }).thinking = ""
;(block as { thinking: string }).thinking = ""
}
}
}
const messageStream = client.messages.stream(baseParams(systemPrompt, conversationMessages))
const messageStream = client.messages.stream(
baseParams(systemPrompt, conversationMessages),
)
const timeout = setTimeout(() => messageStream.abort(), ROUND_TIMEOUT_MS)
let response: Anthropic.Message
try {
@ -230,7 +228,7 @@ export async function POST(request: NextRequest): Promise<Response> {
if (msg.role !== "assistant" || !Array.isArray(msg.content)) continue
for (const block of msg.content) {
if ((block as { type: string }).type === "thinking") {
(block as { thinking: string }).thinking = ""
;(block as { thinking: string }).thinking = ""
}
}
}
@ -252,9 +250,12 @@ export async function POST(request: NextRequest): Promise<Response> {
enqueue("data: [DONE]\n\n")
} catch (err) {
const message = err instanceof Anthropic.APIError
? `API error: ${err.status} ${err.message}`
: err instanceof Error ? err.message : "Stream error"
const message =
err instanceof Anthropic.APIError
? `API error: ${err.status} ${err.message}`
: err instanceof Error
? err.message
: "Stream error"
enqueue(`data: ${JSON.stringify({ error: message })}\n\n`)
} finally {
clearInterval(keepalive)

View file

@ -1,9 +1,8 @@
import { NextRequest, NextResponse } from "next/server"
import { PrismaClient } from "@prisma/client"
import { prisma } from "@/lib/prisma"
const prisma = new PrismaClient()
export async function POST(request: NextRequest, { params }: { params: { trace_id: string } }) {
export async function POST(request: NextRequest, props: { params: Promise<{ trace_id: string }> }) {
const params = await props.params
try {
const { trace_id } = params
const body = await request.json()

View file

@ -1,7 +1,9 @@
"use client"
import { getUserOrganizations } from "@/components/dashboard/action"
import { UserProfile } from "@auth0/nextjs-auth0/client"
import { setOrgCookie } from "./org-cookie-action"
import type { User as UserProfile } from "@auth0/nextjs-auth0/types"
import { useRouter } from "next/navigation"
import React, {
createContext,
useContext,
@ -44,6 +46,7 @@ export function ViewModeProvider({
children: React.ReactNode
user?: UserProfile
}) {
const router = useRouter()
const [mode, setMode] = useState<ViewMode>("personal")
const [loading, setIsLoading] = useState<boolean>(true)
const [orgs, setOrgs] = useState<Organization[]>([])
@ -67,6 +70,10 @@ export function ViewModeProvider({
const finalOrgs = fetchedOrgs || orgsRef.current
setLocalStorageMode(newMode, orgId)
// Sync org cookie so server components can read it
const cookieOrgId = newMode === "organization" && orgId ? orgId : null
await setOrgCookie(cookieOrgId)
if (newMode === "organization" && orgId) {
const org = finalOrgs.find(o => o.id === orgId)
if (org) {
@ -80,8 +87,11 @@ export function ViewModeProvider({
setMode("personal")
setCurrentOrg(null)
}
// Trigger server re-render so server components pick up the new cookie
router.refresh()
},
[user?.sub, setLocalStorageMode],
[user?.sub, setLocalStorageMode, router],
)
useEffect(() => {

View file

@ -4,7 +4,8 @@ import { redirect } from "next/navigation"
* Catch-all route for legacy /app/* URLs
* Redirects to the corresponding route without the /app prefix
*/
export default function LegacyAppCatchAll({ params }: { params: { slug: string[] } }) {
const newPath = `/${params.slug.join("/")}`
redirect(newPath)
export default async function LegacyAppCatchAll(props: { params: Promise<{ slug: string[] }> }) {
  // Route params arrive as a Promise here; resolve before use.
  const { slug } = await props.params
  redirect(`/${slug.join("/")}`)
}

View file

@ -0,0 +1,17 @@
"use server"
import { cookies } from "next/headers"
/**
 * Server action: persists the active organization id in an httpOnly cookie so
 * server-side code can read the current view mode. Passing null (or an empty
 * id) clears the cookie instead.
 */
export async function setOrgCookie(orgId: string | null) {
  const store = await cookies()
  if (!orgId) {
    store.delete("currentOrganizationId")
    return
  }
  store.set("currentOrganizationId", orgId, {
    path: "/",
    httpOnly: true,
    sameSite: "lax",
    maxAge: 60 * 60 * 24 * 365, // 1 year
  })
}

View file

@ -0,0 +1,77 @@
"use client"
import { useCallback, useMemo, useRef, useState } from "react"
import { CalendarDays, ChevronDown } from "lucide-react"
import { useRouter, useSearchParams } from "next/navigation"
import { useOutsideClick } from "@/components/hooks/useOutsideClick"
/**
 * Year picker dropdown: writes the selection into the `year` query param
 * (omitted when it equals the current year, the implicit default) so the
 * page's server components can read it from searchParams.
 */
export function YearSelector({ selectedYear }: { selectedYear: number }) {
  const router = useRouter()
  const searchParams = useSearchParams()
  const [isOpen, setIsOpen] = useState(false)
  // Close the dropdown on any click outside of it.
  const dropdownRef = useOutsideClick(() => setIsOpen(false))
  const currentYear = new Date().getFullYear()
  // Years 2025..currentYear. If the clock reads earlier than 2025 the filter
  // leaves this empty, which renders the trigger disabled below.
  const availableYears = useMemo(() => {
    const baseYear = 2025
    return Array.from(
      { length: Math.max(1, currentYear - baseYear + 1) },
      (_, i) => baseYear + i,
    ).filter(year => year <= currentYear)
  }, [currentYear])
  const handleYearChange = useCallback(
    (year: number) => {
      setIsOpen(false)
      const params = new URLSearchParams(searchParams.toString())
      // The current year is the default, so it needs no query param.
      if (year === currentYear) {
        params.delete("year")
      } else {
        params.set("year", String(year))
      }
      const query = params.toString()
      // NOTE(review): falls back to the hard-coded "/dashboard" path when no
      // params remain — confirm this component is only used on /dashboard.
      router.push(query ? `?${query}` : "/dashboard", { scroll: false })
    },
    [router, searchParams, currentYear],
  )
  return (
    <div className="relative" ref={dropdownRef}>
      {/* Trigger button: shows the selected year; disabled with no chevron
          when there is only one (or no) year to choose from. */}
      <button
        onClick={() => setIsOpen(!isOpen)}
        className="flex items-center gap-1 px-2 py-1 text-xs bg-background border border-border rounded-md hover:border-primary/50 transition-colors"
        disabled={availableYears.length <= 1}
      >
        <CalendarDays size={12} className="text-muted-foreground" />
        <span>{selectedYear}</span>
        {availableYears.length > 1 && (
          <ChevronDown
            size={12}
            className={`transition-transform text-muted-foreground ${isOpen ? "rotate-180" : ""}`}
          />
        )}
      </button>
      {isOpen && availableYears.length > 1 && (
        <div className="absolute right-0 z-10 mt-1 w-32 bg-card rounded-md shadow-lg overflow-hidden border border-border animate-in fade-in-50 slide-in-from-top-5">
          <div className="py-1">
            {availableYears.map(year => (
              <button
                key={year}
                onClick={() => handleYearChange(year)}
                className={`w-full px-3 py-1.5 text-left hover:bg-muted flex items-center ${selectedYear === year ? "bg-primary/10 text-primary font-medium" : ""}`}
              >
                {/* Dot marker flags the currently selected year. */}
                <span className="w-4 h-4 mr-1.5 flex items-center justify-center">
                  {selectedYear === year && (
                    <span className="w-1.5 h-1.5 rounded-full bg-primary"></span>
                  )}
                </span>
                {year}
              </button>
            ))}
          </div>
        </div>
      )}
    </div>
  )
}

View file

@ -1,10 +1,10 @@
import { getSession } from "@auth0/nextjs-auth0"
import { auth0 } from "@/lib/auth0"
import { redirect } from "next/navigation"
import { ReactNode } from "react"
import { hasCompletedOnboarding } from "@codeflash-ai/common"
export default async function DashboardLayout({ children }: { children: ReactNode }) {
const session = await getSession()
const session = await auth0.getSession()
if (!session) return null
const completedOnboarding = await hasCompletedOnboarding(session.user.sub)

View file

@ -0,0 +1,5 @@
import { DashboardSkeleton } from "@/components/dashboard/DashboardSkeleton"
// Loading state for the dashboard route; delegates entirely to DashboardSkeleton.
export default function DashboardLoading() {
  return <DashboardSkeleton />
}

View file

@ -1,270 +1,72 @@
"use client"
import React, { useState, useMemo, useEffect, useCallback, memo, useRef } from "react"
import {
Lock,
Globe,
RefreshCw,
Zap,
Gauge,
FolderGit2,
BookOpen,
CalendarDays,
ChevronDown,
} from "lucide-react"
import { getDashboardData, RepositoryWithUsage } from "./action"
import { getUserIdAndUsername } from "@/app/utils/auth"
import { format, subDays } from "date-fns"
import { Suspense } from "react"
import { Lock, Globe, Zap, Gauge, FolderGit2, BookOpen } from "lucide-react"
import { getDashboardData } from "./action"
import { getAccountContext } from "@/lib/server/get-account-context"
import { ActiveUsersLeaderboard } from "@/components/dashboard/ActiveUsersLeaderboard"
import { CompactPullRequestActivityCard } from "@/components/dashboard/CompactPullRequestActivityCard"
import { DashboardErrorBoundary } from "@/components/dashboard/DashboardErrorBoundary"
import { MetricCard } from "@/components/dashboard/MetricCard"
import { OptimizationPRsTable } from "@/components/dashboard/OptimizationPRsTable"
import { DashboardSkeleton } from "@/components/dashboard/DashboardSkeleton"
import { useViewMode } from "../app/ViewModeContext"
import { useOutsideClick } from "@/components/hooks/useOutsideClick"
import { AccountPayload } from "@codeflash-ai/common"
import { YearSelector } from "./_components/YearSelector"
import { format, subDays } from "date-fns"
const ErrorDisplay = memo(({ error, onRetry }: { error: string; onRetry: () => void }) => (
<div className="flex justify-center items-center h-[70vh]">
<div className="bg-red-50 text-red-800 p-6 sm:p-8 rounded-xl max-w-md border border-red-200">
<h3 className="text-base sm:text-lg font-medium mb-2 sm:mb-3">Unable to Load Dashboard</h3>
<p className="mb-3 sm:mb-4 text-sm sm:text-base">{error}</p>
<button
onClick={onRetry}
className="flex items-center gap-1 sm:gap-2 w-full justify-center px-3 sm:px-4 py-2 sm:py-2.5 bg-red-100 hover:bg-red-200 text-red-800 rounded-lg text-xs sm:text-sm font-medium transition-colors"
>
<RefreshCw size={14} className="sm:w-4 sm:h-4" /> Try Again
</button>
</div>
</div>
))
ErrorDisplay.displayName = "ErrorDisplay"
function getDateRangeDisplay(): string {
const now = new Date()
const last30DaysStart = subDays(now, 30)
const startMonth = format(last30DaysStart, "MMMM")
const endMonth = format(now, "MMMM")
const startYear = format(last30DaysStart, "yyyy")
const endYear = format(now, "yyyy")
interface OptimizationStats {
totalAttempts: number
successfulAttempts: number
activeReposLast30Days: number
if (startMonth === endMonth && startYear === endYear) {
return `${startMonth} ${format(last30DaysStart, "d")}-${format(now, "d")}, ${startYear}`
}
if (startYear === endYear) {
return `${format(last30DaysStart, "MMMM d")} - ${format(now, "MMMM d")}, ${startYear}`
}
return `${format(last30DaysStart, "MMMM d, yyyy")} - ${format(now, "MMMM d, yyyy")}`
}
interface PrActivityData {
month: string
pr_created: number
pr_merged: number
pr_closed: number
}
interface ActiveUserData {
username: string
eventCount: number
avatarUrl: string
}
function Dashboard() {
const { currentOrg } = useViewMode()
export default async function DashboardPage({
searchParams,
}: {
searchParams: Promise<{ year?: string }>
}) {
const params = await searchParams
const currentYear = new Date().getFullYear()
const parsedYear = params.year ? parseInt(params.year, 10) : currentYear
const selectedYear = Number.isNaN(parsedYear) ? currentYear : parsedYear
const [repositories, setRepositories] = useState<RepositoryWithUsage[]>([])
const [loading, setLoading] = useState(true)
const [error, setError] = useState<string | null>(null)
const accountPayload = await getAccountContext()
const { stats, repos } = await getDashboardData(accountPayload, selectedYear)
const [optimizationStats, setOptimizationStats] = useState<OptimizationStats>({
totalAttempts: 0,
successfulAttempts: 0,
activeReposLast30Days: 0,
})
const repositories = Array.isArray(repos) ? repos : []
const privateRepos = repositories.filter(repo => repo?.is_private).length
const publicRepos = repositories.length - privateRepos
const totalRepos = repositories.length
const [prActivityData, setPrActivityData] = useState<PrActivityData[]>([])
const [selectedYear, setSelectedYear] = useState<number>(currentYear)
const [isYearDropdownOpen, setIsYearDropdownOpen] = useState(false)
const yearDropdownRef = useOutsideClick(() => setIsYearDropdownOpen(false))
const dateRangeDisplay = getDateRangeDisplay()
const [activeUsersData, setActiveUsersData] = useState<ActiveUserData[]>([])
const [optimizationsTrend, setOptimizationsTrend] = useState<number[]>([])
const [optimizationsTrendDates, setOptimizationsTrendDates] = useState<string[]>([])
const [successfulOptimizationsTrend, setSuccessfulOptimizationsTrend] = useState<number[]>([])
const [successfulOptimizationsTrendDates, setSuccessfulOptimizationsTrendDates] = useState<
string[]
>([])
const [accountPayload, setAccountPayload] = useState<AccountPayload | null>(null)
const [isMobile, setIsMobile] = useState<boolean>(false)
const dateValues = useMemo(() => {
const now = new Date()
const last30DaysStart = subDays(now, 30)
const startMonth = format(last30DaysStart, "MMMM")
const endMonth = format(now, "MMMM")
const startYear = format(last30DaysStart, "yyyy")
const endYear = format(now, "yyyy")
function getDateRangeDisplay(): string {
if (startMonth === endMonth && startYear === endYear) {
return `${startMonth} ${format(last30DaysStart, "d")}-${format(now, "d")}, ${startYear}`
}
if (startYear === endYear) {
return `${format(last30DaysStart, "MMMM d")} - ${format(now, "MMMM d")}, ${startYear}`
}
return `${format(last30DaysStart, "MMMM d, yyyy")} - ${format(now, "MMMM d, yyyy")}`
}
return { now, last30DaysStart, dateRangeDisplay: getDateRangeDisplay() }
}, [])
const repoCounts = useMemo(() => {
if (!Array.isArray(repositories) || repositories.length === 0) {
return { privateRepos: 0, publicRepos: 0, totalRepos: 0 }
}
const privateRepos = repositories.filter(repo => repo?.is_private).length
const publicRepos = repositories.length - privateRepos
return { privateRepos, publicRepos, totalRepos: repositories.length }
}, [repositories])
const availableYears = useMemo(() => {
const baseYear = 2025
return Array.from(
{ length: Math.max(1, currentYear - baseYear + 1) },
(_, i) => baseYear + i,
).filter(year => year <= currentYear)
}, [currentYear])
useEffect(() => {
let timeoutId: NodeJS.Timeout
const handleResize = () => {
clearTimeout(timeoutId)
timeoutId = setTimeout(() => setIsMobile(window.innerWidth < 640), 150)
}
if (typeof window !== "undefined") {
setIsMobile(window.innerWidth < 640)
window.addEventListener("resize", handleResize)
return () => {
clearTimeout(timeoutId)
window.removeEventListener("resize", handleResize)
}
}
}, [])
const currentOrgId = currentOrg?.id
const fetchingRef = useRef(false)
const fetchDashboardData = useCallback(async () => {
if (fetchingRef.current) return
fetchingRef.current = true
try {
setLoading(true)
setError(null)
const currentUser = await getUserIdAndUsername()
if (!currentUser?.userId || !currentUser?.username) {
throw new Error("User authentication data not found")
}
const payload: AccountPayload = currentOrgId
? { orgId: currentOrgId }
: { userId: currentUser.userId, username: currentUser.username }
// Store payload for the PR table component
setAccountPayload(payload)
const { stats, repos } = await getDashboardData(payload, selectedYear)
setRepositories(Array.isArray(repos) ? repos : [])
setOptimizationStats({
totalAttempts: stats.optimizations.total,
successfulAttempts: stats.optimizations.successful,
activeReposLast30Days: stats.activeReposLast30Days.length,
})
const optimizationValues = stats.optimizations.timeSeries.map(item => item.count)
const optimizationDates = stats.optimizations.timeSeries.map(item => item.date)
setOptimizationsTrend(optimizationValues)
setOptimizationsTrendDates(optimizationDates)
const successfulValues = stats.optimizations.successfulTimeSeries.map(item => item.count)
const successfulDates = stats.optimizations.successfulTimeSeries.map(item => item.date)
setSuccessfulOptimizationsTrend(successfulValues)
setSuccessfulOptimizationsTrendDates(successfulDates)
setPrActivityData(stats.pullRequests)
setActiveUsersData(stats.activeUsersLast30Days)
} catch (err) {
console.error("Dashboard data fetch error:", err)
setError("Failed to load dashboard data. Please try again later.")
setRepositories([])
setPrActivityData([])
setActiveUsersData([])
setOptimizationsTrend([])
setOptimizationsTrendDates([])
setSuccessfulOptimizationsTrend([])
setSuccessfulOptimizationsTrendDates([])
} finally {
setLoading(false)
fetchingRef.current = false
}
}, [selectedYear, currentOrgId])
useEffect(() => {
fetchDashboardData()
}, [fetchDashboardData])
const handleYearChange = useCallback((year: number) => {
setSelectedYear(year)
setIsYearDropdownOpen(false)
}, [])
if (loading) return <DashboardSkeleton />
if (error) return <ErrorDisplay error={error} onRetry={fetchDashboardData} />
const optimizationsTrend = stats.optimizations.timeSeries.map(item => item.count)
const optimizationsTrendDates = stats.optimizations.timeSeries.map(item => item.date)
const successfulOptimizationsTrend = stats.optimizations.successfulTimeSeries.map(
item => item.count,
)
const successfulOptimizationsTrendDates = stats.optimizations.successfulTimeSeries.map(
item => item.date,
)
return (
<div className="min-h-screen pb-8 py-6 sm:py-8 px-4 sm:px-6 max-w-[1400px] mx-auto">
<div className="mb-6 sm:mb-8">
<div className="flex items-center justify-between mb-2">
<h1 className="text-xl sm:text-2xl font-bold">Dashboard</h1>
<div className="relative" ref={yearDropdownRef}>
<button
onClick={() => setIsYearDropdownOpen(!isYearDropdownOpen)}
className="flex items-center gap-1 px-2 py-1 text-xs bg-background border border-border rounded-md hover:border-primary/50 transition-colors"
disabled={availableYears.length <= 1}
>
<CalendarDays size={12} className="text-muted-foreground" />
<span>{selectedYear}</span>
{availableYears.length > 1 && (
<ChevronDown
size={12}
className={`transition-transform text-muted-foreground ${isYearDropdownOpen ? "rotate-180" : ""}`}
/>
)}
</button>
{isYearDropdownOpen && availableYears.length > 1 && (
<div className="absolute right-0 z-10 mt-1 w-32 bg-card rounded-md shadow-lg overflow-hidden border border-border animate-in fade-in-50 slide-in-from-top-5">
<div className="py-1">
{availableYears.map(year => (
<button
key={year}
onClick={() => handleYearChange(year)}
className={`w-full px-3 py-1.5 text-left hover:bg-muted flex items-center ${selectedYear === year ? "bg-primary/10 text-primary font-medium" : ""}`}
>
<span className="w-4 h-4 mr-1.5 flex items-center justify-center">
{selectedYear === year && (
<span className="w-1.5 h-1.5 rounded-full bg-primary"></span>
)}
</span>
{year}
</button>
))}
</div>
</div>
)}
</div>
<Suspense>
<YearSelector selectedYear={selectedYear} />
</Suspense>
</div>
</div>
{repoCounts.totalRepos === 0 && (
{totalRepos === 0 && (
<div className="mb-6 sm:mb-8">
<div className="rounded-xl border border-dashed border-border bg-muted/10 px-5 py-4 sm:px-6 sm:py-5 flex flex-col sm:flex-row sm:items-center sm:justify-between gap-3">
<div>
@ -298,19 +100,17 @@ function Dashboard() {
</div>
)}
{/* Optimization PRs Table - Positioned at the top */}
{accountPayload && (
<div className="mb-6 sm:mb-8">
<OptimizationPRsTable payload={accountPayload} />
</div>
)}
{/* Optimization PRs Table */}
<div className="mb-6 sm:mb-8">
<OptimizationPRsTable payload={accountPayload} />
</div>
<div className="grid grid-cols-1 gap-3 sm:gap-5 mb-6 sm:mb-8">
<div className="grid grid-cols-1 sm:grid-cols-2 gap-3 sm:gap-5">
<MetricCard
title="Optimization Attempts"
value={optimizationStats.totalAttempts}
icon={<Zap size={isMobile ? 16 : 20} />}
value={stats.optimizations.total}
icon={<Zap />}
gradientFrom="bg-gradient-to-br from-blue-500/20"
gradientTo="to-blue-600/20"
iconColor="text-blue-500"
@ -324,9 +124,9 @@ function Dashboard() {
/>
<MetricCard
title="Optimizations Found"
value={optimizationStats.successfulAttempts}
value={stats.optimizations.successful}
subtitle=""
icon={<Gauge size={isMobile ? 16 : 20} />}
icon={<Gauge />}
gradientFrom="bg-gradient-to-br from-emerald-500/20"
gradientTo="to-emerald-600/20"
iconColor="text-emerald-500"
@ -344,8 +144,8 @@ function Dashboard() {
<div className="grid grid-cols-1 sm:grid-cols-2 md:grid-cols-4 gap-3 sm:gap-5">
<MetricCard
title="Total Repositories"
value={repoCounts.totalRepos}
icon={<BookOpen size={isMobile ? 16 : 20} />}
value={totalRepos}
icon={<BookOpen />}
gradientFrom="bg-gradient-to-br from-blue-500/20"
gradientTo="to-blue-600/20"
iconColor="text-blue-500"
@ -355,20 +155,20 @@ function Dashboard() {
<MetricCard
title="Active Repositories"
value={optimizationStats.activeReposLast30Days}
value={stats.activeReposLast30Days.length}
subtitle="last 30 days"
icon={<FolderGit2 size={isMobile ? 16 : 20} />}
icon={<FolderGit2 />}
gradientFrom="bg-gradient-to-br from-purple-500/20"
gradientTo="to-purple-600/20"
iconColor="text-purple-500"
timeText={dateValues.dateRangeDisplay}
timeText={dateRangeDisplay}
showChart={false}
/>
<MetricCard
title="Private Repositories"
value={repoCounts.privateRepos}
icon={<Lock size={isMobile ? 16 : 20} />}
value={privateRepos}
icon={<Lock />}
gradientFrom="bg-gradient-to-br from-amber-500/20"
gradientTo="to-amber-600/20"
iconColor="text-amber-500"
@ -378,8 +178,8 @@ function Dashboard() {
<MetricCard
title="Public Repositories"
value={repoCounts.publicRepos}
icon={<Globe size={isMobile ? 16 : 20} />}
value={publicRepos}
icon={<Globe />}
gradientFrom="bg-gradient-to-br from-violet-500/20"
gradientTo="to-violet-600/20"
iconColor="text-violet-500"
@ -391,27 +191,15 @@ function Dashboard() {
<div className="grid grid-cols-1 md:grid-cols-2 gap-3 sm:gap-5 mb-6 sm:mb-8 h-96 md:h-[500px]">
<CompactPullRequestActivityCard
prData={prActivityData}
prData={stats.pullRequests}
selectedYear={selectedYear}
onYearChange={handleYearChange}
className="h-full"
/>
<div className="h-full">
<ActiveUsersLeaderboard leaderboardData={activeUsersData} />
<ActiveUsersLeaderboard leaderboardData={stats.activeUsersLast30Days} />
</div>
</div>
</div>
)
}
const MemoizedDashboard = memo(Dashboard)
MemoizedDashboard.displayName = "Dashboard"
export default function DashboardWrapper() {
return (
<DashboardErrorBoundary>
<MemoizedDashboard />
</DashboardErrorBoundary>
)
}

View file

@ -1,10 +1,10 @@
/* Scoped observability theme - only affects pages wrapped with .obs-v2 */
@import "../styles/obs-theme.css";
@tailwind base;
@tailwind components;
@tailwind utilities;
/* Scoped observability theme - only affects pages wrapped with .obs-v2 */
@import "../styles/obs-theme.css";
@layer base {
:root {
/* Background and foreground */

View file

@ -1,23 +1,20 @@
import { type JSX } from "react"
import type { Metadata } from "next"
import { Inter as FontSans, JetBrains_Mono } from "next/font/google"
import "./globals.css"
import { cn } from "@/lib/utils"
import { ThemeProvider } from "@/components/theme-provider"
import { UserProvider } from "@auth0/nextjs-auth0/client"
import { Auth0Provider } from "@auth0/nextjs-auth0"
import { Toaster } from "@/components/ui/toaster"
import { Toaster as SonnerToaster } from "sonner"
import { getSession } from "@auth0/nextjs-auth0"
import { auth0 } from "@/lib/auth0"
import Script from "next/script"
import { PHProvider } from "./providers"
import dynamic from "next/dynamic"
import PostHogPageView from "./PostHogPageView"
import { ViewModeProvider } from "./app/ViewModeContext"
import { PrivacyModeProvider } from "./app/PrivacyModeContext"
import { ConditionalLayout } from "@/components/conditional-layout"
const PostHogPageView = dynamic(async () => await import("./PostHogPageView"), {
ssr: false,
})
const fontSans = FontSans({
subsets: ["latin"],
variable: "--font-sans",
@ -40,7 +37,7 @@ export default async function RootLayout({
}: {
children: React.ReactNode
}): Promise<JSX.Element> {
const session = await getSession()
const session = await auth0.getSession()
let intercomSnippet: string = `var APP_ID = "ljxo1nzr";
(function(){var w=window;var ic=w.Intercom;if(typeof ic==="function"){ic('reattach_activator');ic('update',w.intercomSettings);}else{var d=document;var i=function(){i.c(arguments);};i.q=[];i.c=function(args){i.q.push(args);};w.Intercom=i;var l=function(){var s=d.createElement('script');s.type='text/javascript';s.async=true;s.src='https://widget.intercom.io/widget/' + APP_ID;var x=d.getElementsByTagName('script')[0];x.parentNode.insertBefore(s, x);};if(document.readyState==='complete'){l();}else if(w.attachEvent){w.attachEvent('onload',l);}else{w.addEventListener('load',l,false);}}})();
`
@ -100,7 +97,7 @@ export default async function RootLayout({
</head>
<body className={cn("min-h-screen bg-background font-sans antialiased", fontSans.variable, jetbrainsMono.variable)}>
<PostHogPageView />
<UserProvider>
<Auth0Provider>
<ThemeProvider
attribute="class"
defaultTheme="system"
@ -115,7 +112,7 @@ export default async function RootLayout({
<Toaster />
<SonnerToaster position="top-right" richColors />
</ThemeProvider>
</UserProvider>
</Auth0Provider>
</body>
</PHProvider>
</html>

View file

@ -0,0 +1,757 @@
import { canAccessMembench } from "@/app/utils/auth"
import { redirect } from "next/navigation"
import { MembenchToggle } from "@/components/membench/membench-toggle"
import {
PeakMemoryChart,
AllocatorChart,
HeadroomChart,
MaxAllocChart,
} from "@/components/membench/membench-charts"
// Next.js route metadata — sets the browser tab title for this page.
export const metadata = { title: "Memory Benchmark — Unstructured" }
/* ── static data ────────────────────────────────────────────────────── */
// Aggregate memray results for the 18 common partition tests.
// peak_gb = high-water mark, total_gb = cumulative bytes allocated,
// max_alloc_mb = largest single allocation, wall_s = suite wall time.
// NOTE(review): all figures on this page are hard-coded from the April 2026
// run; nothing is fetched at request time.
const SUITE = {
  baseline: {
    peak_gb: 1.66,
    total_gb: 16.398,
    allocs: 5_585_979,
    wall_s: 76.0,
    max_alloc_mb: 268,
    tests: 18,
    passed: 13,
    failed: 5,
  },
  current: {
    peak_gb: 1.473,
    total_gb: 20.239,
    allocs: 6_210_809,
    wall_s: 86.0,
    max_alloc_mb: 134,
    tests: 18,
    passed: 13,
    failed: 5,
  },
}
// Percentage change in peak RAM (negative = improvement).
const PEAK_DELTA_PCT =
  ((SUITE.current.peak_gb - SUITE.baseline.peak_gb) / SUITE.baseline.peak_gb) * 100
// Absolute peak reduction, converted GB → MB.
const PEAK_DELTA_MB = Math.abs((SUITE.current.peak_gb - SUITE.baseline.peak_gb) * 1024)
// Percentage change in the largest single allocation.
const MAX_ALLOC_DELTA_PCT =
  ((SUITE.current.max_alloc_mb - SUITE.baseline.max_alloc_mb) / SUITE.baseline.max_alloc_mb) * 100
// Knative production pod RAM limit, used by the headroom visualisation.
const POD_RAM_LIMIT_GB = 32
// Top allocating functions as [function name, total GB allocated].
const TOP_ALLOC_BASELINE: [string, number][] = [
  ["_create_inference_session", 1.386],
  ["PIL Image.tobytes", 1.188],
  ["PIL Image.new", 1.001],
  ["load_prepare", 0.751],
  ["render", 0.649],
]
const TOP_ALLOC_CURRENT: [string, number][] = [
  ["PIL Image.tobytes", 1.802],
  ["PIL Image.new", 1.556],
  ["_create_inference_session", 1.328],
  ["load_prepare", 1.172],
  ["PIL Image.tobytes (2)", 0.889],
]
// Per-scenario comparison rows rendered by the EngView table.
const SCENARIO_TABLE = [
  {
    scenario: "Full Suite (18 common tests)",
    bl_peak: "1.660 GB",
    cu_peak: "1.473 GB",
    delta: "-11.3%",
    bl_time: "76.0s",
    cu_time: "86.0s",
  },
  {
    scenario: "API hi_res (layout-parser-paper, 16p)",
    bl_peak: "1.515 GB",
    cu_peak: "1.419 GB",
    delta: "-6.3%",
    bl_time: "53.6s",
    cu_time: "60.2s",
  },
  {
    scenario: "od_only (Seeda Case Study)",
    bl_peak: "1.127 GB",
    cu_peak: "1.046 GB",
    delta: "-7.2%",
    bl_time: "1.96s",
    cu_time: "2.13s",
  },
]
// [parameter, value] rows describing the benchmark environment.
const ENV_TABLE = [
  ["VM", "Azure Standard_D8s_v5 (8 vCPU, 32 GB RAM)"],
  ["OS", "Ubuntu 20.04"],
  ["Python", "3.12"],
  ["Profiler", "memray --native (captures C/C++ malloc, mmap)"],
  ["Test Runner", "memray run --native -o {out}.bin --force -m pytest -v"],
  ["Baseline Env", "/home/krrt7/bench/baseline-core + baseline-env (pre-Feb 2026)"],
  ["Current Env", "/home/krrt7/bench/current-core + current-env (main)"],
  ["Pre-run Protocol", "VM reboot + 5-min idle wait (clean Azure telemetry window)"],
  ["Production Target", "Knative pods, 1 CPU / 32 GB RAM, Standard_D48s_v5 nodes"],
  ["Test Scope", "18 common partition tests (od_only, hi_res, pptx, docx)"],
]
/* ── page ───────────────────────────────────────────────────────────── */
/**
 * Memory-benchmark report page (April 2026 run).
 *
 * Access-gated: redirects to "/" unless `canAccessMembench()` grants access.
 * Renders a hero banner, four headline metric cards derived from the SUITE
 * constants, a toggle between the executive and engineering views, and a
 * footer. All figures are static constants defined above.
 */
export default async function MembenchPage() {
  const allowed = await canAccessMembench()
  if (!allowed) redirect("/")
  // Short aliases used in the markup below.
  const b = SUITE.baseline
  const c = SUITE.current
  return (
    <div className="min-h-screen bg-white dark:bg-zinc-950 font-sans text-zinc-900 dark:text-zinc-200">
      {/* ── Hero ── */}
      <div className="border-b border-zinc-200 dark:border-zinc-800 bg-gradient-to-br from-zinc-100 via-green-50/40 to-zinc-100 dark:from-zinc-950 dark:via-[#0c1a0f] dark:to-zinc-950 px-6 pb-14 pt-16 text-center">
        <div
          className="text-[13px] font-bold text-green-600 dark:text-green-400 mb-3"
          style={{ letterSpacing: "0.15em" }}
        >
          UNSTRUCTURED
        </div>
        <h1
          className="text-[36px] font-extrabold tracking-tight text-zinc-900 dark:text-zinc-50"
          style={{ letterSpacing: "-0.02em" }}
        >
          Core Product Memory Benchmark
        </h1>
        <p className="mx-auto mt-3 max-w-[640px] text-[17px] text-zinc-500 dark:text-zinc-400">
          Peak RAM reduction measured with memray --native across the partition test suite
        </p>
        <div className="mt-6 flex flex-wrap items-center justify-center gap-6 text-[13px] text-zinc-400 dark:text-zinc-500">
          <span>April 2026</span>
          <span>|</span>
          <span>Baseline: pre-Feb 2026</span>
          <span>|</span>
          <span>18 common partition tests</span>
          <span>|</span>
          <span>Azure Standard_D8s_v5 VM</span>
        </div>
      </div>
      {/* ── Hero Metrics ── */}
      {/* Negative z-offset card grid overlapping the hero banner. */}
      <div
        className="mx-auto -mt-10 grid max-w-4xl grid-cols-2 gap-5 px-6 lg:grid-cols-4"
        style={{ position: "relative", zIndex: 1 }}
      >
        {[
          {
            value: `${PEAK_DELTA_PCT.toFixed(1)}%`,
            label: "Peak RAM",
            detail: `${b.peak_gb.toFixed(2)} GB → ${c.peak_gb.toFixed(2)} GB`,
          },
          {
            value: `${Math.round(PEAK_DELTA_MB)} MB`,
            label: "Absolute Reduction",
            detail: "Peak high-water mark savings",
          },
          {
            value: `${MAX_ALLOC_DELTA_PCT.toFixed(0)}%`,
            label: "Max Single Allocation",
            detail: `${b.max_alloc_mb} MB → ${c.max_alloc_mb} MB`,
          },
          {
            value: "0",
            label: "New Regressions",
            detail: `Same ${c.passed}/${c.tests} pass rate on both`,
          },
        ].map(m => (
          <div
            key={m.label}
            className="rounded-xl border border-zinc-200 dark:border-zinc-800 bg-white dark:bg-zinc-900 px-6 py-8 text-center shadow-sm dark:shadow-none"
          >
            <div
              className="font-mono leading-none text-green-600 dark:text-green-400"
              style={{ fontSize: "42px", fontWeight: 800, letterSpacing: "-0.02em" }}
            >
              {m.value}
            </div>
            <div className="mt-2 text-[15px] font-semibold text-zinc-800 dark:text-zinc-200">
              {m.label}
            </div>
            <div className="mt-1 text-[13px] text-zinc-500 dark:text-zinc-400">{m.detail}</div>
          </div>
        ))}
      </div>
      {/* ── Toggle + Views ── */}
      <div className="mx-auto max-w-4xl px-6 pb-20">
        <MembenchToggle execView={<ExecView />} engView={<EngView />} />
      </div>
      {/* ── Footer ── */}
      <div className="border-t border-zinc-200 dark:border-zinc-800 py-8 text-center">
        <div
          className="text-[11px] font-bold text-zinc-400 dark:text-zinc-500 mb-1"
          style={{ letterSpacing: "0.15em" }}
        >
          UNSTRUCTURED
        </div>
        <p className="text-xs text-zinc-400 dark:text-zinc-500">
          Core Product Memory Benchmark April 2026
        </p>
      </div>
    </div>
  )
}
/*
EXECUTIVE VIEW
*/
/**
 * Executive-summary view: peak-memory chart, plain-language impact cards,
 * pod-headroom visualisation, suite-level comparison table, and an
 * implications/next-steps checklist. Purely presentational — reads only the
 * static SUITE / delta constants defined above.
 */
function ExecView() {
  const b = SUITE.baseline
  const c = SUITE.current
  return (
    <div className="space-y-14">
      {/* ── Peak Memory by Scenario ── */}
      <Section
        title="Peak Memory by Scenario"
        subtitle="High-water mark during document processing — the metric that determines OOM risk."
      >
        <Card>
          <PeakMemoryChart />
        </Card>
      </Section>
      {/* ── What Does This Mean? ── */}
      <Section title="What Does This Mean?" subtitle="How these numbers affect production pods.">
        <div className="flex flex-wrap gap-5">
          <Card className="flex-1 min-w-[260px]">
            <h4 className="text-base font-bold text-zinc-900 dark:text-zinc-200">Lower OOM risk</h4>
            <p className="mt-2 text-sm leading-relaxed text-zinc-600 dark:text-zinc-400">
              Peak memory during the full partition suite dropped from {b.peak_gb.toFixed(2)} GB to{" "}
              {c.peak_gb.toFixed(2)} GB a {Math.round(PEAK_DELTA_MB)} MB reduction. For Knative
              pods with a 32 GB RAM limit, this means more headroom before the OOM killer terminates
              the container.
            </p>
          </Card>
          <Card className="flex-1 min-w-[260px]">
            <h4 className="text-base font-bold text-zinc-900 dark:text-zinc-200">
              Halved largest allocation
            </h4>
            <p className="mt-2 text-sm leading-relaxed text-zinc-600 dark:text-zinc-400">
              The single largest memory allocation dropped from {b.max_alloc_mb} MB to{" "}
              {c.max_alloc_mb} MB a 50% reduction. Large contiguous allocations are the primary
              cause of memory fragmentation and allocation failures even when total free memory
              appears sufficient.
            </p>
          </Card>
          <Card className="flex-1 min-w-[260px]">
            <h4 className="text-base font-bold text-zinc-900 dark:text-zinc-200">
              Zero regressions
            </h4>
            <p className="mt-2 text-sm leading-relaxed text-zinc-600 dark:text-zinc-400">
              Both environments pass the same {c.passed} of {c.tests} partition tests. The{" "}
              {c.failed} failures are pre-existing docx edge cases present in the baseline not
              regressions from the optimization work.
            </p>
          </Card>
        </div>
      </Section>
      {/* ── Pod Headroom ── */}
      <Section
        title="Production Pod Headroom"
        subtitle="Current peak usage vs. the Knative pod RAM limit of 32 GB."
      >
        <Card>
          <HeadroomChart />
          <div className="flex flex-wrap gap-6 mt-4">
            <div>
              <div
                className="font-mono leading-none text-green-600 dark:text-green-400"
                style={{ fontSize: "32px", fontWeight: 800 }}
              >
                {((c.peak_gb / POD_RAM_LIMIT_GB) * 100).toFixed(1)}%
              </div>
              <div className="mt-1 text-[13px] text-zinc-500 dark:text-zinc-400">
                of pod limit used (current)
              </div>
            </div>
            <div>
              <div
                className="font-mono leading-none text-zinc-800 dark:text-zinc-200"
                style={{ fontSize: "32px", fontWeight: 800 }}
              >
                {(POD_RAM_LIMIT_GB - c.peak_gb).toFixed(1)} GB
              </div>
              <div className="mt-1 text-[13px] text-zinc-500 dark:text-zinc-400">
                headroom remaining
              </div>
            </div>
          </div>
          <p className="mt-4 rounded-lg bg-zinc-50 dark:bg-white/[0.03] p-3 text-xs text-zinc-400 dark:text-zinc-500">
            Note: Peak memory is measured per-process during document processing. Actual pod usage
            includes OS overhead, model weights in shared memory, and other sidecar containers.
            These figures represent the process-level high-water mark.
          </p>
        </Card>
      </Section>
      {/* ── Largest Single Allocation ── */}
      <Section
        title="Largest Single Allocation"
        subtitle="The biggest contiguous block requested in a single malloc/mmap call."
      >
        <Card>
          <MaxAllocChart />
        </Card>
      </Section>
      {/* ── Suite-Level Comparison ── */}
      <Section
        title="Suite-Level Comparison"
        subtitle="Aggregate metrics from 18 common partition tests."
      >
        <Card>
          <div className="overflow-x-auto">
            {/* header */}
            <div className="flex gap-4 border-b-2 border-zinc-200 dark:border-zinc-800 pb-2.5 text-[11px] font-bold uppercase tracking-wider text-zinc-600 dark:text-zinc-300">
              <div className="flex-1">Metric</div>
              <div className="w-36 text-right">Baseline</div>
              <div className="w-36 text-right">Current</div>
              <div className="w-20 text-center">Delta</div>
            </div>
            <StatRow
              label="Peak Memory"
              baseline={b.peak_gb}
              current={c.peak_gb}
              unit="GB"
              format={v => v.toFixed(3)}
              better="lower"
            />
            <StatRow
              label="Total Allocated"
              baseline={b.total_gb}
              current={c.total_gb}
              unit="GB"
              format={v => v.toFixed(1)}
              better="lower"
            />
            <StatRow
              label="Allocation Count"
              baseline={b.allocs}
              current={c.allocs}
              unit=""
              format={v => v.toLocaleString()}
              better="lower"
            />
            <StatRow
              label="Max Single Alloc"
              baseline={b.max_alloc_mb}
              current={c.max_alloc_mb}
              unit="MB"
              format={v => v.toFixed(0)}
              better="lower"
            />
            <StatRow
              label="Wall Time"
              baseline={b.wall_s}
              current={c.wall_s}
              unit="s"
              format={v => v.toFixed(1)}
              better="lower"
            />
            <StatRow
              label="Tests Passed"
              baseline={b.passed}
              current={c.passed}
              unit={`/ ${b.tests}`}
              format={v => v.toFixed(0)}
              better="higher"
            />
          </div>
          <p className="mt-4 text-xs text-zinc-400 dark:text-zinc-500">
            Total allocated increased because current uses more frequent smaller allocations peak
            (the OOM-risk metric) still decreased. This pattern indicates better memory recycling.
          </p>
        </Card>
      </Section>
      {/* ── Implications & Next Steps ── */}
      <Section title="Implications & Next Steps">
        <Card>
          <ActionItem
            text="Peak RAM reduced 11.3% — same workload fits in a smaller memory footprint"
            done
          />
          <ActionItem
            text="Max single allocation halved — lower fragmentation risk under memory pressure"
            done
          />
          <ActionItem text="Zero test regressions — safe to deploy without functional risk" done />
          <ActionItem text="Run full E2E suite benchmarks (all tests, not just common set) for comprehensive coverage" />
          <ActionItem text="Profile top allocators (PIL Image ops, ONNX sessions) for further reduction opportunities" />
          <ActionItem text="Evaluate reducing pod memory request from 32 GB based on production telemetry" />
        </Card>
      </Section>
    </div>
  )
}
/*
ENGINEERING VIEW
*/
/**
 * Engineering detail view: per-scenario results table, top-allocator lists
 * for baseline vs. current, key observations, benchmark environment table,
 * methodology steps, and engineering action items. Purely presentational —
 * reads only the static *_TABLE / TOP_ALLOC_* constants defined above.
 */
function EngView() {
  return (
    <div className="space-y-14">
      {/* ── Per-Scenario Results ── */}
      <Section
        title="Per-Scenario Results"
        subtitle="Individual test scenarios measured with memray --native. VM rebooted + 5-min idle before each run."
      >
        <Card>
          <div className="overflow-x-auto">
            <table className="w-full text-sm">
              <thead>
                <tr className="border-b-2 border-zinc-200 dark:border-zinc-800 text-left text-[11px] font-bold uppercase tracking-wider text-green-600 dark:text-green-400">
                  <th className="py-2.5 pr-4">Scenario</th>
                  <th className="py-2.5 w-24 text-right">Baseline Peak</th>
                  <th className="py-2.5 w-24 text-right">Current Peak</th>
                  <th className="py-2.5 w-20 text-center">Delta</th>
                  <th className="py-2.5 w-20 text-right">BL Time</th>
                  <th className="py-2.5 w-20 text-right">CU Time</th>
                </tr>
              </thead>
              <tbody>
                {SCENARIO_TABLE.map((r, i) => (
                  <tr
                    key={r.scenario}
                    className={`border-b border-zinc-100 dark:border-zinc-800 ${i % 2 ? "bg-zinc-50/50 dark:bg-white/[0.02]" : ""}`}
                  >
                    <td className="py-2.5 pr-4 text-zinc-800 dark:text-zinc-200">{r.scenario}</td>
                    <td className="py-2.5 w-24 text-right font-mono text-zinc-500 dark:text-zinc-400">
                      {r.bl_peak}
                    </td>
                    <td className="py-2.5 w-24 text-right font-mono font-semibold text-zinc-800 dark:text-zinc-200">
                      {r.cu_peak}
                    </td>
                    <td className="py-2.5 w-20 text-center font-bold text-green-600 dark:text-green-400">
                      {r.delta}
                    </td>
                    <td className="py-2.5 w-20 text-right font-mono text-zinc-500 dark:text-zinc-400">
                      {r.bl_time}
                    </td>
                    <td className="py-2.5 w-20 text-right font-mono text-zinc-500 dark:text-zinc-400">
                      {r.cu_time}
                    </td>
                  </tr>
                ))}
              </tbody>
            </table>
          </div>
        </Card>
      </Section>
      {/* ── Top Memory Allocators ── */}
      <Section
        title="Top Memory Allocators"
        subtitle="Functions with highest total allocated bytes over the 18-test suite lifetime (memray --native)."
      >
        <Card>
          <AllocatorChart />
        </Card>
        <div className="flex flex-wrap gap-5 mt-5">
          <Card className="flex-1 min-w-[300px]">
            <div
              className="text-[11px] font-bold text-amber-600 dark:text-amber-400 mb-4"
              style={{ letterSpacing: "0.1em" }}
            >
              BASELINE TOP 5
            </div>
            {TOP_ALLOC_BASELINE.map(([name, size], i) => (
              <div
                key={name}
                className="flex items-center gap-2 border-b border-zinc-100 dark:border-zinc-800 py-2 last:border-b-0"
              >
                <span className="font-mono text-[13px] text-zinc-400 dark:text-zinc-500">
                  {i + 1}.
                </span>
                <span className="text-sm font-semibold text-zinc-800 dark:text-zinc-200">
                  {name}
                </span>
                <span className="font-mono text-[13px] text-zinc-500 dark:text-zinc-400 ml-auto">
                  {size.toFixed(3)} GB
                </span>
              </div>
            ))}
          </Card>
          <Card className="flex-1 min-w-[300px]">
            <div
              className="text-[11px] font-bold text-green-600 dark:text-green-400 mb-4"
              style={{ letterSpacing: "0.1em" }}
            >
              CURRENT TOP 5
            </div>
            {TOP_ALLOC_CURRENT.map(([name, size], i) => (
              <div
                key={name}
                className="flex items-center gap-2 border-b border-zinc-100 dark:border-zinc-800 py-2 last:border-b-0"
              >
                <span className="font-mono text-[13px] text-zinc-400 dark:text-zinc-500">
                  {i + 1}.
                </span>
                <span className="text-sm font-semibold text-zinc-800 dark:text-zinc-200">
                  {name}
                </span>
                <span className="font-mono text-[13px] text-zinc-500 dark:text-zinc-400 ml-auto">
                  {size.toFixed(3)} GB
                </span>
              </div>
            ))}
          </Card>
        </div>
      </Section>
      {/* ── Key Observations ── */}
      <Section title="Key Observations">
        <ObservationCard
          title="ONNX Session Overhead Down"
          badge="Improved"
          badgeColor="bg-green-500 text-zinc-950"
          borderColor="border-l-green-400"
          body="_create_inference_session dropped from #1 allocator (1.386 GB) in baseline to #3 (1.328 GB) in current. The ONNX Runtime session creation path allocates less overall, contributing to the peak reduction."
        />
        <ObservationCard
          title="PIL Image Operations Increased"
          badge="Expected"
          badgeColor="bg-amber-400 text-zinc-950"
          borderColor="border-l-amber-400"
          body="PIL Image.tobytes and Image.new increased in total allocation (e.g. tobytes: 1.188 → 1.802 GB). This reflects more frequent smaller image operations rather than fewer large ones — the pattern that reduces peak memory while increasing total throughput."
        />
        <ObservationCard
          title="Pre-Existing Test Failures (5 docx)"
          badge="Not a regression"
          badgeColor="bg-zinc-400 text-zinc-950"
          borderColor="border-l-zinc-400"
          body="5 of 18 tests fail on both baseline and current — all docx/pptx edge cases (but_not_when_the_partitioning_strategy_is_fast, PIL cannot recognize, Pillow can_only_read_the_image_on_Windows). These were excluded from the common test set count but still run; they are not related to the memory optimization work."
        />
      </Section>
      {/* ── Benchmark Environment ── */}
      <Section title="Benchmark Environment">
        <Card>
          <DataTable columns={["Parameter", "Value"]} rows={ENV_TABLE} />
        </Card>
      </Section>
      {/* ── Methodology ── */}
      <Section title="Methodology">
        <Card>
          <ol className="list-decimal pl-5 space-y-3 text-sm leading-relaxed text-zinc-700 dark:text-zinc-200">
            <li>
              VM rebooted before each environment&apos;s run to ensure clean memory state and enable
              Azure telemetry correlation
            </li>
            <li>
              5-minute idle wait after reboot for OS caches, Azure agents, and background processes
              to stabilize
            </li>
            <li>
              Each test suite runs under memray run --native, which instruments both Python
              allocations and native C/C++ allocations (malloc, calloc, realloc, mmap) via
              LD_PRELOAD
            </li>
            <li>
              memray stats extracts peak memory (high-water mark), total allocated, allocation
              count, wall time, and top allocating functions from the binary trace
            </li>
            <li>
              Common test set: 18 partition tests that exist in both baseline and current codebases,
              ensuring apples-to-apples comparison
            </li>
            <li>
              Identical test deselection applied to both: 6 baseline-only tests (docx/pptx edge
              cases not present in current) excluded via pytest -k filters
            </li>
          </ol>
        </Card>
      </Section>
      {/* ── Engineering Action Items ── */}
      <Section title="Engineering Action Items">
        <Card>
          <ActionItem text="Run full E2E benchmarks with all tests (not just common set) for both environments" />
          <ActionItem text="Add API pipeline tests (test_api_hi_res, test_api_od_only) to the standard benchmark suite" />
          <ActionItem text="Profile PIL Image.tobytes hot path — largest allocator in current (1.802 GB total)" />
          <ActionItem text="Investigate load_prepare growth (0.751 → 1.172 GB) for optimization opportunities" />
          <ActionItem text="Measure with production-scale documents (100+ page PDFs) to validate scaling behavior" />
          <ActionItem text="Correlate memray peaks with Azure pod-level metrics for production memory modeling" />
          <ActionItem text="Evaluate reducing Knative pod memory request from 32 GB based on observed peaks" />
        </Card>
      </Section>
    </div>
  )
}
/* ── shared components ──────────────────────────────────────────────── */
function Section({
title,
subtitle,
children,
}: {
title: string
subtitle?: string
children: React.ReactNode
}) {
return (
<div>
<h2
className="text-[22px] font-bold text-zinc-900 dark:text-zinc-200"
style={{ letterSpacing: "-0.01em" }}
>
{title}
</h2>
{subtitle && (
<p className="mt-1.5 text-sm leading-relaxed text-zinc-500 dark:text-zinc-400">
{subtitle}
</p>
)}
<div className="mt-6">{children}</div>
</div>
)
}
/** Rounded, bordered container with fixed inner padding; extra classes appended via `className`. */
function Card({ children, className = "" }: { children: React.ReactNode; className?: string }) {
  const baseClass =
    "rounded-xl border border-zinc-200 dark:border-zinc-800 bg-white dark:bg-zinc-900 shadow-sm dark:shadow-none"
  return (
    <div className={`${baseClass} ${className}`} style={{ padding: "28px 32px" }}>
      {children}
    </div>
  )
}
/** Single checklist row: filled (done) or hollow bullet followed by the action text. */
function ActionItem({ text, done = false }: { text: string; done?: boolean }) {
  const bulletColor = done
    ? "text-green-500 dark:text-green-400"
    : "text-zinc-400 dark:text-zinc-500"
  const bullet = done ? "●" : "○"
  return (
    <div className="flex items-center gap-3 border-b border-zinc-100 dark:border-zinc-800 py-2.5 last:border-b-0">
      <span className={`text-sm ${bulletColor}`}>{bullet}</span>
      <span className="text-sm text-zinc-800 dark:text-zinc-200">{text}</span>
    </div>
  )
}
/**
 * Metric comparison row: label, baseline and current values (rendered with
 * `format` + `unit`), and a percentage-delta badge.
 *
 * `better` decides which direction counts as an improvement ("lower" by default,
 * since most metrics here are memory/time where less is better).
 *
 * Fix: the delta was `(current - baseline) / baseline`, which rendered
 * "+Infinity%" (or "NaN%" when current is also 0) for a zero baseline.
 * A zero baseline now shows a neutral "n/a" badge instead.
 */
function StatRow({
  label,
  baseline,
  current,
  unit,
  format,
  better = "lower",
}: {
  label: string
  baseline: number
  current: number
  unit: string
  format: (v: number) => string
  better?: "lower" | "higher"
}) {
  // Percentage change is undefined for a zero baseline; use null as the sentinel.
  const delta = baseline === 0 ? null : ((current - baseline) / baseline) * 100
  const improved = delta !== null && (better === "lower" ? delta < 0 : delta > 0)
  const deltaText = delta === null ? "n/a" : `${delta > 0 ? "+" : ""}${delta.toFixed(1)}%`
  // Neutral styling for "no change" and "n/a"; green/red otherwise by `improved`.
  const badgeClass =
    delta === null || delta === 0
      ? "bg-zinc-100 text-zinc-500 dark:bg-zinc-800 dark:text-zinc-400"
      : improved
        ? "bg-green-100 text-green-700 dark:bg-green-400/10 dark:text-green-400"
        : "bg-red-100 text-red-700 dark:bg-red-400/10 dark:text-red-400"
  return (
    <div className="flex items-center gap-4 border-b border-zinc-100 dark:border-zinc-800 py-3 last:border-b-0">
      <div className="flex-1 text-sm font-semibold text-zinc-800 dark:text-zinc-200">{label}</div>
      <div className="w-36 text-right font-mono text-sm text-zinc-500 dark:text-zinc-400">
        {format(baseline)} {unit}
      </div>
      <div className="w-36 text-right font-mono text-sm font-semibold text-zinc-800 dark:text-zinc-200">
        {format(current)} {unit}
      </div>
      <div className="w-20 text-center">
        <span className={`inline-block rounded-md px-2 py-0.5 text-xs font-bold ${badgeClass}`}>
          {deltaText}
        </span>
      </div>
    </div>
  )
}
/** Callout card: colored left accent border, title + pill badge header, body paragraph. */
function ObservationCard({
  title,
  badge,
  badgeColor,
  borderColor,
  body,
}: {
  title: string
  badge: string
  badgeColor: string
  borderColor: string
  body: string
}) {
  const cardClass = `mb-4 rounded-xl border border-zinc-200 dark:border-zinc-800 border-l-4 ${borderColor} bg-white dark:bg-zinc-900 p-5 shadow-sm dark:shadow-none`
  const badgeClass = `inline-block rounded-full ${badgeColor} px-2.5 py-0.5 text-xs font-semibold`
  return (
    <div className={cardClass}>
      <div className="mb-3 flex flex-wrap items-center gap-3">
        <span className="text-base font-bold text-zinc-900 dark:text-zinc-200">{title}</span>
        <span className={badgeClass}>{badge}</span>
      </div>
      <p className="text-sm leading-relaxed text-zinc-600 dark:text-zinc-400">{body}</p>
    </div>
  )
}
/** Horizontally scrollable striped table: `columns` become the header, `rows` the body cells. */
function DataTable({ columns, rows }: { columns: string[]; rows: string[][] }) {
  const headerCells = columns.map(label => (
    <th key={label} className="py-2.5 pr-4">
      {label}
    </th>
  ))
  return (
    <div className="overflow-x-auto">
      <table className="w-full text-sm">
        <thead>
          <tr className="border-b-2 border-zinc-200 dark:border-zinc-800 text-left text-[11px] font-bold uppercase tracking-wider text-green-600 dark:text-green-400">
            {headerCells}
          </tr>
        </thead>
        <tbody>
          {rows.map((cells, rowIndex) => {
            // Zebra-stripe odd rows; index keys are acceptable for this static report data.
            const stripe = rowIndex % 2 ? "bg-zinc-50/50 dark:bg-white/[0.02]" : ""
            return (
              <tr
                key={rowIndex}
                className={`border-b border-zinc-100 dark:border-zinc-800 ${stripe}`}
              >
                {cells.map((cell, cellIndex) => (
                  <td key={cellIndex} className="py-2.5 pr-4 text-zinc-700 dark:text-zinc-200">
                    {cell}
                  </td>
                ))}
              </tr>
            )
          })}
        </tbody>
      </table>
    </div>
  )
}

View file

@ -1,98 +1,84 @@
"use client"
import dynamic from "next/dynamic"
import { SyntaxHighlighter } from "@/lib/syntax-highlighter"
import { memo } from "react"
const SyntaxHighlighter = dynamic(
() => import("react-syntax-highlighter").then(m => m.Prism),
{
ssr: false,
loading: () => (
<div className="animate-pulse bg-zinc-800 rounded p-4 min-h-[100px]">
<div className="h-4 bg-zinc-700 rounded w-3/4 mb-2" />
<div className="h-4 bg-zinc-700 rounded w-1/2 mb-2" />
<div className="h-4 bg-zinc-700 rounded w-2/3" />
</div>
),
}
)
export const zincDarkTheme = {
'code[class*="language-"]': {
color: 'rgb(250, 250, 250)',
background: 'none',
fontFamily: 'var(--font-mono)',
fontSize: '1em',
textAlign: 'left',
whiteSpace: 'pre',
wordSpacing: 'normal',
wordBreak: 'normal',
wordWrap: 'normal',
lineHeight: '1.5',
color: "rgb(250, 250, 250)",
background: "none",
fontFamily: "var(--font-mono)",
fontSize: "1em",
textAlign: "left",
whiteSpace: "pre",
wordSpacing: "normal",
wordBreak: "normal",
wordWrap: "normal",
lineHeight: "1.5",
tabSize: 4,
hyphens: 'none',
hyphens: "none",
},
'pre[class*="language-"]': {
color: 'rgb(250, 250, 250)',
background: 'rgb(24, 24, 27)',
fontFamily: 'var(--font-mono)',
fontSize: '1em',
textAlign: 'left',
whiteSpace: 'pre',
wordSpacing: 'normal',
wordBreak: 'normal',
wordWrap: 'normal',
lineHeight: '1.5',
color: "rgb(250, 250, 250)",
background: "rgb(24, 24, 27)",
fontFamily: "var(--font-mono)",
fontSize: "1em",
textAlign: "left",
whiteSpace: "pre",
wordSpacing: "normal",
wordBreak: "normal",
wordWrap: "normal",
lineHeight: "1.5",
tabSize: 4,
hyphens: 'none',
padding: '1em',
margin: '0',
overflow: 'auto',
hyphens: "none",
padding: "1em",
margin: "0",
overflow: "auto",
},
comment: {
color: 'rgb(113, 113, 122)',
fontStyle: 'italic',
color: "rgb(113, 113, 122)",
fontStyle: "italic",
},
prolog: { color: 'rgb(113, 113, 122)' },
doctype: { color: 'rgb(113, 113, 122)' },
cdata: { color: 'rgb(113, 113, 122)' },
keyword: { color: 'rgb(96, 165, 250)' },
'control-flow': { color: 'rgb(96, 165, 250)' },
string: { color: 'rgb(134, 239, 172)' },
'attr-value': { color: 'rgb(134, 239, 172)' },
function: { color: 'rgb(253, 224, 71)' },
'class-name': { color: 'rgb(253, 224, 71)' },
number: { color: 'rgb(251, 146, 60)' },
boolean: { color: 'rgb(251, 146, 60)' },
operator: { color: 'rgb(161, 161, 170)' },
punctuation: { color: 'rgb(161, 161, 170)' },
variable: { color: 'rgb(250, 250, 250)' },
property: { color: 'rgb(250, 250, 250)' },
tag: { color: 'rgb(96, 165, 250)' },
'attr-name': { color: 'rgb(250, 250, 250)' },
prolog: { color: "rgb(113, 113, 122)" },
doctype: { color: "rgb(113, 113, 122)" },
cdata: { color: "rgb(113, 113, 122)" },
keyword: { color: "rgb(96, 165, 250)" },
"control-flow": { color: "rgb(96, 165, 250)" },
string: { color: "rgb(134, 239, 172)" },
"attr-value": { color: "rgb(134, 239, 172)" },
function: { color: "rgb(253, 224, 71)" },
"class-name": { color: "rgb(253, 224, 71)" },
number: { color: "rgb(251, 146, 60)" },
boolean: { color: "rgb(251, 146, 60)" },
operator: { color: "rgb(161, 161, 170)" },
punctuation: { color: "rgb(161, 161, 170)" },
variable: { color: "rgb(250, 250, 250)" },
property: { color: "rgb(250, 250, 250)" },
tag: { color: "rgb(96, 165, 250)" },
"attr-name": { color: "rgb(250, 250, 250)" },
namespace: { opacity: 0.7 },
selector: { color: 'rgb(253, 224, 71)' },
selector: { color: "rgb(253, 224, 71)" },
important: {
color: 'rgb(251, 146, 60)',
fontWeight: 'bold',
color: "rgb(251, 146, 60)",
fontWeight: "bold",
},
atrule: { color: 'rgb(96, 165, 250)' },
builtin: { color: 'rgb(253, 224, 71)' },
atrule: { color: "rgb(96, 165, 250)" },
builtin: { color: "rgb(253, 224, 71)" },
entity: {
color: 'rgb(250, 250, 250)',
cursor: 'help',
color: "rgb(250, 250, 250)",
cursor: "help",
},
url: {
color: 'rgb(96, 165, 250)',
textDecoration: 'underline',
color: "rgb(96, 165, 250)",
textDecoration: "underline",
},
inserted: {
color: 'rgb(134, 239, 172)',
background: 'rgba(134, 239, 172, 0.1)',
color: "rgb(134, 239, 172)",
background: "rgba(134, 239, 172, 0.1)",
},
deleted: {
color: 'rgb(248, 113, 113)',
background: 'rgba(248, 113, 113, 0.1)',
color: "rgb(248, 113, 113)",
background: "rgba(248, 113, 113, 0.1)",
},
} as const
@ -101,7 +87,7 @@ export const CODE_STYLE = {
padding: "1rem",
fontSize: "0.875rem",
lineHeight: 1.5,
background: 'rgb(24, 24, 27)',
background: "rgb(24, 24, 27)",
} as const
export const CODE_STYLE_RELAXED = {
@ -109,7 +95,7 @@ export const CODE_STYLE_RELAXED = {
padding: "1rem",
fontSize: "0.875rem",
lineHeight: 1.6,
background: 'rgb(24, 24, 27)',
background: "rgb(24, 24, 27)",
} as const
export const CODE_STYLE_SMALL = {
@ -117,7 +103,7 @@ export const CODE_STYLE_SMALL = {
padding: "1rem",
fontSize: "0.8125rem",
lineHeight: 1.5,
background: 'rgb(24, 24, 27)',
background: "rgb(24, 24, 27)",
} as const
interface CodeHighlighterProps {
@ -129,13 +115,13 @@ interface CodeHighlighterProps {
}
const highlightStyle = {
backgroundColor: 'rgba(250, 204, 21, 0.15)',
display: 'block',
marginLeft: '-1rem',
marginRight: '-1rem',
paddingLeft: '1rem',
paddingRight: '1rem',
borderLeft: '3px solid rgb(250, 204, 21)',
backgroundColor: "rgba(250, 204, 21, 0.15)",
display: "block",
marginLeft: "-1rem",
marginRight: "-1rem",
paddingLeft: "1rem",
paddingRight: "1rem",
borderLeft: "3px solid rgb(250, 204, 21)",
}
export const CodeHighlighter = memo(function CodeHighlighter({
@ -152,8 +138,8 @@ export const CodeHighlighter = memo(function CodeHighlighter({
return (lineNumber: number) => {
const isHighlighted = highlightSet.has(lineNumber)
return {
style: isHighlighted ? highlightStyle : { display: 'block' },
'data-highlighted': isHighlighted ? 'true' : undefined,
style: isHighlighted ? highlightStyle : { display: "block" },
"data-highlighted": isHighlighted ? "true" : undefined,
}
}
}
@ -173,4 +159,4 @@ export const CodeHighlighter = memo(function CodeHighlighter({
{code}
</SyntaxHighlighter>
)
})
})

View file

@ -1,6 +1,6 @@
"use client"
import { memo } from "react"
import React, { memo } from "react"
import {
Clock,
CheckCircle2,
@ -14,7 +14,7 @@ import { CandidateContent } from "./candidate-content"
import { RankingContent, SummaryContent } from "./ranking-content"
import type { TimelineSection } from "./timeline-types"
function getStatusIcon(status: string): JSX.Element {
function getStatusIcon(status: string): React.JSX.Element {
switch (status) {
case "success":
return <CheckCircle2 className="h-4 w-4 text-green-500" />

View file

@ -16,34 +16,36 @@ import { CopyButton } from "@/components/observability/copy-button"
import { ParsedResponseView } from "@/components/observability/parsed-response-view"
interface LLMCallDetailPageProps {
params: {
params: Promise<{
id: string
}
}>
}
export async function generateMetadata({ params }: LLMCallDetailPageProps): Promise<Metadata> {
export async function generateMetadata(props: LLMCallDetailPageProps): Promise<Metadata> {
const params = await props.params;
return {
title: `LLM Call ${params.id.substring(0, 8)} - Observability`,
description: "View LLM call details for prompt engineering analysis",
}
}
export default async function LLMCallDetailPage({ params }: LLMCallDetailPageProps) {
// Fetch LLM call details
const llmCall = await prisma.llm_calls.findUnique({
where: { id: params.id },
})
export default async function LLMCallDetailPage(props: LLMCallDetailPageProps) {
const params = await props.params;
// Fetch LLM call details and related errors in parallel
const [llmCall, relatedErrors] = await Promise.all([
prisma.llm_calls.findUnique({
where: { id: params.id },
}),
prisma.optimization_errors.findMany({
where: { llm_call_id: params.id },
orderBy: { created_at: "desc" },
}),
])
if (!llmCall) {
notFound()
}
// Fetch related errors
const relatedErrors = await prisma.optimization_errors.findMany({
where: { llm_call_id: params.id },
orderBy: { created_at: "desc" },
})
return (
<div className="container mx-auto px-4 py-8">
{/* Header */}

View file

@ -69,7 +69,8 @@ const getModels = unstable_cache(
{ revalidate: 300 }, // 5 minutes
)
export default async function LLMCallsPage({ searchParams }: { searchParams: SearchParams }) {
export default async function LLMCallsPage(props: { searchParams: Promise<SearchParams> }) {
const searchParams = await props.searchParams;
try {
const page = parseInt(searchParams.page || "1")
const pageSize = 50

View file

@ -20,12 +20,13 @@ import { CopyButton } from "@/components/observability/copy-button"
export const revalidate = 60
interface TracePageProps {
params: {
params: Promise<{
trace_id: string
}
}>
}
export default async function TracePage({ params }: TracePageProps) {
export default async function TracePage(props: TracePageProps) {
const params = await props.params
const { trace_id } = params
// Use prefix matching (first 33 chars) to group multi-model calls that share the same base trace_id

Some files were not shown because too many files have changed in this diff Show more