# codeflash-internal/django/aiservice/languages/js_ts/testgen.py
# Last modified: 2026-01-28 22:23:54 +02:00 (383 lines, 13 KiB, Python)
"""JavaScript/TypeScript test generation module.
This module generates Jest tests for JavaScript/TypeScript functions.
Instrumentation is handled by the codeflash CLI client, not here.
"""
from __future__ import annotations
import logging
import re
from pathlib import Path
from typing import TYPE_CHECKING
import sentry_sdk
import stamina
from ninja.errors import HttpError
from openai import OpenAIError
from openai.types.chat import ChatCompletionMessageParam
from aiservice.analytics.posthog import ph
from aiservice.common_utils import validate_trace_id
from aiservice.env_specific import debug_log_sensitive_data
from aiservice.llm import EXECUTE_MODEL, HAIKU_MODEL, OPENAI_MODEL, calculate_llm_cost, call_llm
from aiservice.validators.javascript_validator import validate_javascript_syntax, validate_typescript_syntax
from authapp.auth import AuthenticatedRequest
from log_features.log_event import update_optimization_cost
from log_features.log_features import log_features
from testgen.models import TestGenerationFailedError, TestGenErrorResponseSchema, TestGenResponseSchema, TestGenSchema
if TYPE_CHECKING:
from aiservice.llm import LLM
# Pattern matching Jest-style test declarations: test("...") / it('...')
_TEST_FUNC_RE = re.compile(r"(?:test|it)\s*\(\s*['\"]")

# Get the directory of the current file - prompts are now in languages/js_ts/prompts/testgen/
current_dir = Path(__file__).parent
JS_PROMPTS_DIR = current_dir / "prompts" / "testgen"
# Fallback to original location if prompts haven't been moved yet
if not JS_PROMPTS_DIR.exists():
    # Use original location for backward compatibility during migration
    JS_PROMPTS_DIR = Path(__file__).parent.parent.parent / "testgen" / "prompts" / "javascript"

# Load JavaScript prompts once at import time. Decode explicitly as UTF-8 so the
# prompt files read identically regardless of the host's locale encoding
# (read_text() without an encoding falls back to the platform default).
JS_EXECUTE_SYSTEM_PROMPT = (JS_PROMPTS_DIR / "execute_system_prompt.md").read_text(encoding="utf-8")
JS_EXECUTE_USER_PROMPT = (JS_PROMPTS_DIR / "execute_user_prompt.md").read_text(encoding="utf-8")
JS_EXECUTE_ASYNC_SYSTEM_PROMPT = (JS_PROMPTS_DIR / "execute_async_system_prompt.md").read_text(encoding="utf-8")
JS_EXECUTE_ASYNC_USER_PROMPT = (JS_PROMPTS_DIR / "execute_async_user_prompt.md").read_text(encoding="utf-8")

# Pattern to extract fenced code blocks with an optional js/ts language tag
JS_PATTERN = re.compile(r"^```(?:javascript|js|typescript|ts)?\s*\n(.*?)\n```", re.MULTILINE | re.DOTALL)
def build_javascript_prompt(
    function_name: str, function_code: str, module_path: str, test_framework: str, is_async: bool
) -> tuple[list[ChatCompletionMessageParam], str]:
    """Assemble the chat prompt for JavaScript test generation.

    Args:
        function_name: Name of the function under test.
        function_code: Source code of the function.
        module_path: Import path of the module containing the function.
        test_framework: Testing framework identifier (e.g. "jest").
        is_async: Whether the function under test is async.

    Returns:
        A tuple of (chat messages, PostHog event-name suffix).
    """
    # Async functions use dedicated prompt templates and a distinct event suffix.
    if is_async:
        system_template, user_template = JS_EXECUTE_ASYNC_SYSTEM_PROMPT, JS_EXECUTE_ASYNC_USER_PROMPT
        event_suffix = "async-"
    else:
        system_template, user_template = JS_EXECUTE_SYSTEM_PROMPT, JS_EXECUTE_USER_PROMPT
        event_suffix = ""

    messages: list[ChatCompletionMessageParam] = [
        {
            "role": "system",
            "content": system_template.format(function_name=function_name),
        },
        {
            "role": "user",
            "content": user_template.format(
                test_framework=test_framework,
                function_name=function_name,
                function_code=function_code,
                module_path=module_path,
                package_comment="",
            ),
        },
    ]
    return messages, event_suffix
def parse_and_validate_js_output(response_content: str) -> str:
    """Extract the JavaScript code block from an LLM response and validate it.

    Args:
        response_content: Raw text returned by the LLM.

    Returns:
        The validated JavaScript test code.

    Raises:
        ValueError: If no code block or no test functions are present.
        SyntaxError: If the extracted code is not valid JavaScript.
    """
    # A missing fence means the model ignored the output format entirely;
    # report a truncated sample to Sentry for prompt debugging.
    if "```" not in response_content:
        sentry_sdk.capture_message("LLM response did not contain a code block:\n" + response_content[:500])
        raise ValueError("LLM response did not contain a code block.")

    match = JS_PATTERN.search(response_content)
    if match is None:
        raise ValueError("No JavaScript code block found in the LLM response.")
    code = match.group(1).strip()

    # Reject code that does not parse as JavaScript.
    is_valid, error = validate_javascript_syntax(code)
    if not is_valid:
        raise SyntaxError(f"Invalid JavaScript code: {error}")

    # Syntactically valid output is still useless without test() / it() calls.
    if not _has_test_functions(code):
        raise ValueError("Generated code does not contain any test functions.")
    return code
def _has_test_functions(code: str) -> bool:
    """Return True if *code* contains at least one Jest-style test() or it() call."""
    return bool(_TEST_FUNC_RE.search(code))
@stamina.retry(on=(SyntaxError, ValueError, OpenAIError), attempts=2)
async def generate_and_validate_js_test_code(
    messages: list[ChatCompletionMessageParam],
    model: LLM,
    cost_tracker: list[float],
    user_id: str,
    posthog_event_suffix: str,
    trace_id: str = "",
    call_sequence: int | None = None,
) -> str:
    """Call the LLM and return validated JavaScript test code.

    Retried (via stamina) up to 2 attempts on parse/syntax failures and
    OpenAI errors.

    Args:
        messages: Chat prompt messages.
        model: LLM model to use.
        cost_tracker: Mutable list accumulating the per-call LLM cost.
        user_id: User ID for analytics and logging.
        posthog_event_suffix: Suffix embedded in PostHog event names.
        trace_id: Trace ID for logging.
        call_sequence: Optional call sequence number for observability.

    Returns:
        Validated JavaScript test code.

    Raises:
        SyntaxError: If the generated code is invalid JavaScript.
        ValueError: If the response contains no usable code block.
    """
    context: dict | None = None
    if call_sequence is not None:
        context = {"call_sequence": call_sequence}

    response = await call_llm(
        llm=model,
        messages=messages,
        call_type="test_generation",
        trace_id=trace_id,
        user_id=user_id,
        python_version="javascript",  # Reusing field for language
        context=context,
    )
    # Record cost before parsing so failed parses still count toward spend.
    cost_tracker.append(calculate_llm_cost(response.raw_response, model))

    debug_log_sensitive_data(
        f"JavaScript {posthog_event_suffix}execute response:\n{response.raw_response.model_dump_json(indent=2)}"
    )
    usage = response.raw_response.usage
    if usage:
        ph(
            user_id,
            f"aiservice-testgen-js-{posthog_event_suffix}execute-openai-usage",
            properties={"model": model.name, "usage": usage.model_dump_json()},
        )

    # Raises SyntaxError/ValueError on bad output, triggering the retry above.
    return parse_and_validate_js_output(response.content)
@stamina.retry(on=TestGenerationFailedError, attempts=2)
async def generate_javascript_tests_from_function(
    user_id: str,
    function_name: str,
    function_code: str,
    module_path: str,
    test_framework: str = "jest",
    execute_model: LLM = EXECUTE_MODEL,
    is_async: bool = False,
    trace_id: str = "",
    call_sequence: int | None = None,
    language: str = "javascript",
) -> tuple[str, str, str]:
    """Generate JavaScript tests for a function.

    Args:
        user_id: User ID
        function_name: Name of function to test
        function_code: Source code of function
        module_path: Import path for module
        test_framework: Testing framework (jest, mocha)
        execute_model: LLM model to use
        is_async: Whether function is async
        trace_id: Trace ID for logging
        call_sequence: Call sequence number
        language: Language of the function to test (javascript, typescript).
            Not read in this function; kept for interface compatibility with
            the endpoint handler.

    Returns:
        Tuple of (generated_tests, instrumented_behavior_tests, instrumented_perf_tests)

    Raises:
        TestGenerationFailedError: If test generation fails
    """
    messages, posthog_event_suffix = build_javascript_prompt(
        function_name=function_name,
        function_code=function_code,
        module_path=module_path,
        test_framework=test_framework,
        is_async=is_async,
    )
    cost_tracker: list[float] = []
    try:
        validated_code = await generate_and_validate_js_test_code(
            messages=messages,
            model=execute_model,
            cost_tracker=cost_tracker,
            user_id=user_id,
            posthog_event_suffix=posthog_event_suffix,
            trace_id=trace_id,
            call_sequence=call_sequence,
        )
        # Return uninstrumented code - instrumentation is handled by the codeflash CLI client
        # This ensures consistent instrumentation using the codeflash npm package
        return validated_code, validated_code, validated_code
    except (SyntaxError, ValueError) as e:
        msg = f"Failed to generate valid JavaScript test code after {len(cost_tracker)} tries. trace_id={trace_id}"
        logging.exception(msg)
        raise TestGenerationFailedError(msg) from e
    finally:
        # Fix: record accrued LLM spend on *every* exit path. Previously the
        # cost update ran only on success or on SyntaxError/ValueError, so
        # tokens consumed before an escaping OpenAIError were never recorded.
        await update_optimization_cost(trace_id=trace_id, cost=sum(cost_tracker), user_id=user_id)
def validate_javascript_testgen_request_data(data: TestGenSchema) -> None:
    """Validate a JavaScript/TypeScript test generation request.

    Args:
        data: Incoming request payload.

    Raises:
        HttpError: 400 with a descriptive message when a field is invalid.
    """
    # Only Jest is supported for JS/TS test generation.
    if data.test_framework != "jest":
        raise HttpError(400, "Invalid test framework for JavaScript/TypeScript. We only support jest.")
    if not data.function_to_optimize:
        raise HttpError(400, "Invalid function to optimize. It is empty.")
    if not validate_trace_id(data.trace_id):
        raise HttpError(400, "Invalid trace ID. Please provide a valid UUIDv4.")

    # Pick the syntax validator matching the declared language; anything that
    # is not explicitly "typescript" is checked as JavaScript.
    if data.language == "typescript":
        lang_name = "TypeScript"
        is_valid, error = validate_typescript_syntax(data.source_code_being_tested)
    else:
        lang_name = "JavaScript"
        is_valid, error = validate_javascript_syntax(data.source_code_being_tested)
    if not is_valid:
        raise HttpError(400, f"Invalid source code. It is not valid {lang_name}: {error}")
async def testgen_javascript(
    request: AuthenticatedRequest, data: TestGenSchema
) -> tuple[int, TestGenResponseSchema | TestGenErrorResponseSchema]:
    """Main endpoint handler for JavaScript test generation.

    Validates the request, alternates between LLM providers based on
    ``test_index``, generates the tests, and optionally logs features.

    Args:
        request: Authenticated request
        data: Test generation request data

    Returns:
        Tuple of (status_code, response)
    """
    language = data.language
    ph(request.user, "aiservice-testgen-called", properties={"language": language})
    try:
        validate_javascript_testgen_request_data(data)
    except HttpError as e:
        # Report validation failures to Sentry with extra context, then return
        # the error response instead of letting the exception propagate.
        e.add_note(f"JavaScript testgen request validation error: {e.status_code} {e.message}")
        sentry_sdk.capture_exception(e)
        return e.status_code, TestGenErrorResponseSchema(error=e.message)
    logging.info("/testgen: Generating JavaScript tests...")
    try:
        debug_log_sensitive_data(f"Generating JavaScript tests for function {data.function_to_optimize.function_name}")
        # Using different LLMs for different test_index values
        test_index = data.test_index if data.test_index is not None else 0
        # Even indices use OpenAI, odd ones Anthropic Haiku, so repeated
        # generations for the same function draw from both providers.
        if test_index % 2 == 0:
            execute_model = OPENAI_MODEL
            model_source = "OpenAI"
        else:
            execute_model = HAIKU_MODEL
            model_source = "Anthropic"
        logging.info(f"Using {model_source} model ({execute_model.name}) for JavaScript test_index {test_index}")
        (
            generated_test_source,
            instrumented_behavior_tests,
            instrumented_perf_tests,
        ) = await generate_javascript_tests_from_function(
            user_id=request.user,
            function_name=data.function_to_optimize.qualified_name,
            function_code=data.source_code_being_tested,
            module_path=data.module_path,
            test_framework=data.test_framework,
            is_async=data.is_async or False,
            trace_id=data.trace_id,
            call_sequence=data.call_sequence,
            execute_model=execute_model,
            language=data.language,
        )
        ph(request.user, "aiservice-testgen-tests-generated", properties={"language": language})
        # Feature logging is opt-in per request; hasattr guards against request
        # objects that never had the flag set (presumably set upstream by auth
        # middleware — TODO confirm).
        if hasattr(request, "should_log_features") and request.should_log_features:
            await log_features(
                trace_id=data.trace_id,
                user_id=request.user,
                generated_tests=[generated_test_source],
                instrumented_generated_tests=[instrumented_behavior_tests],
                test_framework=data.test_framework,
                metadata={
                    "test_timeout": data.test_timeout,
                    "function_to_optimize": data.function_to_optimize.function_name,
                    "language": language,
                },
            )
        # All three fields carry the same uninstrumented source; the codeflash
        # CLI client performs instrumentation (see generator's return value).
        return 200, TestGenResponseSchema(
            generated_tests=generated_test_source,
            instrumented_behavior_tests=instrumented_behavior_tests,
            instrumented_perf_tests=instrumented_perf_tests,
        )
    except Exception as e:
        # Top-level boundary: log, report, and return a generic 500 so internal
        # details never leak to the client.
        logging.exception(f"JavaScript test generation failed. trace_id={data.trace_id}")
        sentry_sdk.capture_exception(e)
        return 500, TestGenErrorResponseSchema(error="Error generating JavaScript tests. Internal server error.")