mirror of
https://github.com/codeflash-ai/codeflash-internal.git
synced 2026-05-04 18:25:18 +00:00
383 lines
13 KiB
Python
383 lines
13 KiB
Python
"""JavaScript/TypeScript test generation module.
|
|
|
|
This module generates Jest tests for JavaScript/TypeScript functions.
|
|
Instrumentation is handled by the codeflash CLI client, not here.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import logging
|
|
import re
|
|
from pathlib import Path
|
|
from typing import TYPE_CHECKING
|
|
|
|
import sentry_sdk
|
|
import stamina
|
|
from ninja.errors import HttpError
|
|
from openai import OpenAIError
|
|
from openai.types.chat import ChatCompletionMessageParam
|
|
|
|
from aiservice.analytics.posthog import ph
|
|
from aiservice.common_utils import validate_trace_id
|
|
from aiservice.env_specific import debug_log_sensitive_data
|
|
from aiservice.llm import EXECUTE_MODEL, HAIKU_MODEL, OPENAI_MODEL, calculate_llm_cost, call_llm
|
|
from aiservice.validators.javascript_validator import validate_javascript_syntax, validate_typescript_syntax
|
|
from authapp.auth import AuthenticatedRequest
|
|
from log_features.log_event import update_optimization_cost
|
|
from log_features.log_features import log_features
|
|
from testgen.models import TestGenerationFailedError, TestGenErrorResponseSchema, TestGenResponseSchema, TestGenSchema
|
|
|
|
if TYPE_CHECKING:
|
|
from aiservice.llm import LLM
|
|
|
|
# Matches a Jest-style test declaration: a `test(` or `it(` call whose first
# argument opens a string literal (single or double quote).
_TEST_FUNC_RE = re.compile(r"(?:test|it)\s*\(\s*['\"]")

# Get the directory of the current file - prompts are now in languages/js_ts/prompts/testgen/
current_dir = Path(__file__).parent
JS_PROMPTS_DIR = current_dir / "prompts" / "testgen"

# Fallback to original location if prompts haven't been moved yet
if not JS_PROMPTS_DIR.exists():
    # Use original location for backward compatibility during migration
    JS_PROMPTS_DIR = Path(__file__).parent.parent.parent / "testgen" / "prompts" / "javascript"

# Load JavaScript prompts at import time; a missing prompt file raises here so
# a broken deployment fails fast instead of at request time.
JS_EXECUTE_SYSTEM_PROMPT = (JS_PROMPTS_DIR / "execute_system_prompt.md").read_text()
JS_EXECUTE_USER_PROMPT = (JS_PROMPTS_DIR / "execute_user_prompt.md").read_text()
JS_EXECUTE_ASYNC_SYSTEM_PROMPT = (JS_PROMPTS_DIR / "execute_async_system_prompt.md").read_text()
JS_EXECUTE_ASYNC_USER_PROMPT = (JS_PROMPTS_DIR / "execute_async_user_prompt.md").read_text()

# Pattern to extract the first fenced JavaScript/TypeScript code block from an
# LLM response. MULTILINE anchors the opening fence at a line start; DOTALL
# lets the captured body span multiple lines (non-greedy, so it stops at the
# first closing fence).
JS_PATTERN = re.compile(r"^```(?:javascript|js|typescript|ts)?\s*\n(.*?)\n```", re.MULTILINE | re.DOTALL)
|
|
|
|
|
|
def build_javascript_prompt(
    function_name: str, function_code: str, module_path: str, test_framework: str, is_async: bool
) -> tuple[list[ChatCompletionMessageParam], str]:
    """Build the prompt messages for JavaScript test generation.

    Args:
        function_name: Name of the function to test
        function_code: Source code of the function
        module_path: Import path for the module
        test_framework: Testing framework (jest, mocha)
        is_async: Whether the function is async

    Returns:
        Tuple of (messages, posthog_event_suffix)

    """
    # Async functions use dedicated prompt templates, and their PostHog
    # events carry an "async-" suffix so the two flows can be tracked apart.
    if is_async:
        system_template = JS_EXECUTE_ASYNC_SYSTEM_PROMPT
        user_template = JS_EXECUTE_ASYNC_USER_PROMPT
        posthog_event_suffix = "async-"
    else:
        system_template = JS_EXECUTE_SYSTEM_PROMPT
        user_template = JS_EXECUTE_USER_PROMPT
        posthog_event_suffix = ""

    messages: list[ChatCompletionMessageParam] = [
        {
            "role": "system",
            "content": system_template.format(function_name=function_name),
        },
        {
            "role": "user",
            "content": user_template.format(
                test_framework=test_framework,
                function_name=function_name,
                function_code=function_code,
                module_path=module_path,
                # Placeholder slot in the template; currently always empty.
                package_comment="",
            ),
        },
    ]
    return messages, posthog_event_suffix
|
|
|
|
|
|
def parse_and_validate_js_output(response_content: str) -> str:
    """Parse and validate the LLM response for JavaScript code.

    Args:
        response_content: Raw LLM response

    Returns:
        Validated JavaScript code

    Raises:
        ValueError: If no valid code block found
        SyntaxError: If code has syntax errors

    """
    # Fast pre-check: with no fence at all there is nothing to extract. The
    # (truncated) response is captured to Sentry to aid prompt debugging.
    if "```" not in response_content:
        sentry_sdk.capture_message("LLM response did not contain a code block:\n" + response_content[:500])
        raise ValueError("LLM response did not contain a code block.")

    match = JS_PATTERN.search(response_content)
    if match is None:
        # A fence exists but it is not a recognized JS/TS block at line start.
        raise ValueError("No JavaScript code block found in the LLM response.")

    code = match.group(1).strip()

    # Reject code the JS parser cannot accept before any further checks.
    is_valid, error = validate_javascript_syntax(code)
    if not is_valid:
        raise SyntaxError(f"Invalid JavaScript code: {error}")

    # Syntactically valid output without any test()/it() calls would yield an
    # empty suite, so treat it as a generation failure too.
    if not _has_test_functions(code):
        raise ValueError("Generated code does not contain any test functions.")

    return code
|
|
|
|
|
|
def _has_test_functions(code: str) -> bool:
    """Return True if *code* registers at least one Jest-style test case."""
    # Jest tests are declared via test('...') or it('...') calls; code with
    # neither would produce an empty, useless test suite.
    return bool(_TEST_FUNC_RE.search(code))
|
|
|
|
|
|
# stamina retries the whole LLM round-trip once more on parse/validation
# failures or OpenAI errors; each attempt appends its cost to cost_tracker.
@stamina.retry(on=(SyntaxError, ValueError, OpenAIError), attempts=2)
async def generate_and_validate_js_test_code(
    messages: list[ChatCompletionMessageParam],
    model: LLM,
    cost_tracker: list[float],
    user_id: str,
    posthog_event_suffix: str,
    trace_id: str = "",
    call_sequence: int | None = None,
) -> str:
    """Generate and validate JavaScript test code using LLM.

    Args:
        messages: Prompt messages
        model: LLM model to use
        cost_tracker: Mutable list to track costs; one entry is appended per
            attempt, so the caller sees costs from every retry
        user_id: User ID
        posthog_event_suffix: Suffix for PostHog events ("async-" or "")
        trace_id: Trace ID for logging
        call_sequence: Call sequence number

    Returns:
        Validated JavaScript test code

    Raises:
        SyntaxError: If code is invalid
        ValueError: If no valid code found

    """
    # Only attach an observability context when a sequence number was given,
    # so downstream logging is not polluted with null fields.
    obs_context: dict | None = {"call_sequence": call_sequence} if call_sequence is not None else None

    response = await call_llm(
        llm=model,
        messages=messages,
        call_type="test_generation",
        trace_id=trace_id,
        user_id=user_id,
        python_version="javascript",  # Reusing field for language
        context=obs_context,
    )

    # Record this attempt's cost before validation: a failed parse still
    # consumed tokens and must be billed/accounted for.
    cost = calculate_llm_cost(response.raw_response, model)
    cost_tracker.append(cost)

    # Full raw response is sensitive; only logged in environments where
    # debug_log_sensitive_data permits it.
    debug_log_sensitive_data(
        f"JavaScript {posthog_event_suffix}execute response:\n{response.raw_response.model_dump_json(indent=2)}"
    )

    if response.raw_response.usage:
        ph(
            user_id,
            f"aiservice-testgen-js-{posthog_event_suffix}execute-openai-usage",
            properties={"model": model.name, "usage": response.raw_response.usage.model_dump_json()},
        )

    # Parse and validate; raises SyntaxError/ValueError, which triggers the
    # stamina retry above.
    validated_code = parse_and_validate_js_output(response.content)
    return validated_code
|
|
|
|
|
|
# One extra end-to-end retry if generation ultimately fails (the inner LLM
# call has its own retry for transient parse/API errors).
@stamina.retry(on=TestGenerationFailedError, attempts=2)
async def generate_javascript_tests_from_function(
    user_id: str,
    function_name: str,
    function_code: str,
    module_path: str,
    test_framework: str = "jest",
    execute_model: LLM = EXECUTE_MODEL,
    is_async: bool = False,
    trace_id: str = "",
    call_sequence: int | None = None,
    language: str = "javascript",
) -> tuple[str, str, str]:
    """Generate JavaScript tests for a function.

    Args:
        user_id: User ID
        function_name: Name of function to test
        function_code: Source code of function
        module_path: Import path for module
        test_framework: Testing framework (jest, mocha)
        execute_model: LLM model to use
        is_async: Whether function is async
        trace_id: Trace ID for logging
        call_sequence: Call sequence number
        language: Language of the function to test (javascript, typescript)

    Returns:
        Tuple of (generated_tests, instrumented_behavior_tests, instrumented_perf_tests)

    Raises:
        TestGenerationFailedError: If test generation fails

    """
    # NOTE(review): `language` is accepted but not read anywhere in this body
    # — presumably reserved for TS-specific prompting; confirm before removal.
    prompt_messages, event_suffix = build_javascript_prompt(
        function_name=function_name,
        function_code=function_code,
        module_path=module_path,
        test_framework=test_framework,
        is_async=is_async,
    )

    # Filled in by the inner call: one cost entry per LLM attempt.
    llm_costs: list[float] = []

    try:
        test_code = await generate_and_validate_js_test_code(
            messages=prompt_messages,
            model=execute_model,
            cost_tracker=llm_costs,
            user_id=user_id,
            posthog_event_suffix=event_suffix,
            trace_id=trace_id,
            call_sequence=call_sequence,
        )

        await update_optimization_cost(trace_id=trace_id, cost=sum(llm_costs), user_id=user_id)

        # Return uninstrumented code - instrumentation is handled by the codeflash CLI client
        # This ensures consistent instrumentation using the codeflash npm package
        return test_code, test_code, test_code

    except (SyntaxError, ValueError) as e:
        # Even on failure, tokens were spent — record the accumulated cost.
        await update_optimization_cost(trace_id=trace_id, cost=sum(llm_costs), user_id=user_id)
        msg = f"Failed to generate valid JavaScript test code after {len(llm_costs)} tries. trace_id={trace_id}"
        logging.exception(msg)
        raise TestGenerationFailedError(msg) from e
|
|
|
|
|
|
def validate_javascript_testgen_request_data(data: TestGenSchema) -> None:
    """Validate JavaScript/TypeScript test generation request data.

    Args:
        data: Request data

    Raises:
        HttpError: If validation fails (unsupported framework, empty target
            function, malformed trace ID, or source that does not parse)

    """
    # Only Jest is currently supported for JS/TS test generation.
    # (Direct equality instead of membership in a single-element list.)
    if data.test_framework != "jest":
        raise HttpError(400, "Invalid test framework for JavaScript/TypeScript. We only support jest.")
    if not data.function_to_optimize:
        raise HttpError(400, "Invalid function to optimize. It is empty.")
    if not validate_trace_id(data.trace_id):
        raise HttpError(400, "Invalid trace ID. Please provide a valid UUIDv4.")

    # Validate the submitted source with the parser matching its language;
    # anything not explicitly "typescript" is treated as JavaScript.
    if data.language == "typescript":
        is_valid, error = validate_typescript_syntax(data.source_code_being_tested)
        lang_name = "TypeScript"
    else:
        is_valid, error = validate_javascript_syntax(data.source_code_being_tested)
        lang_name = "JavaScript"

    if not is_valid:
        raise HttpError(400, f"Invalid source code. It is not valid {lang_name}: {error}")
|
|
|
|
|
|
async def testgen_javascript(
    request: AuthenticatedRequest, data: TestGenSchema
) -> tuple[int, TestGenResponseSchema | TestGenErrorResponseSchema]:
    """Main endpoint handler for JavaScript test generation.

    Args:
        request: Authenticated request
        data: Test generation request data

    Returns:
        Tuple of (status_code, response): 200 with generated tests, 400 with
        a validation error message, or 500 on any unexpected failure

    """
    language = data.language
    ph(request.user, "aiservice-testgen-called", properties={"language": language})

    # Request validation errors are reported to Sentry (with an explanatory
    # note) and surfaced to the caller as the HttpError's own status code.
    try:
        validate_javascript_testgen_request_data(data)
    except HttpError as e:
        e.add_note(f"JavaScript testgen request validation error: {e.status_code} {e.message}")
        sentry_sdk.capture_exception(e)
        return e.status_code, TestGenErrorResponseSchema(error=e.message)

    logging.info("/testgen: Generating JavaScript tests...")

    try:
        debug_log_sensitive_data(f"Generating JavaScript tests for function {data.function_to_optimize.function_name}")

        # Using different LLMs for different test_index values: even indices
        # go to OpenAI, odd to Anthropic, so multi-test requests get model
        # diversity. A missing test_index defaults to 0 (OpenAI).
        test_index = data.test_index if data.test_index is not None else 0
        if test_index % 2 == 0:
            execute_model = OPENAI_MODEL
            model_source = "OpenAI"
        else:
            execute_model = HAIKU_MODEL
            model_source = "Anthropic"

        logging.info(f"Using {model_source} model ({execute_model.name}) for JavaScript test_index {test_index}")

        # All three returned values are currently the same uninstrumented
        # source; instrumentation happens in the codeflash CLI client.
        (
            generated_test_source,
            instrumented_behavior_tests,
            instrumented_perf_tests,
        ) = await generate_javascript_tests_from_function(
            user_id=request.user,
            function_name=data.function_to_optimize.qualified_name,
            function_code=data.source_code_being_tested,
            module_path=data.module_path,
            test_framework=data.test_framework,
            is_async=data.is_async or False,
            trace_id=data.trace_id,
            call_sequence=data.call_sequence,
            execute_model=execute_model,
            language=data.language,
        )

        ph(request.user, "aiservice-testgen-tests-generated", properties={"language": language})

        # Feature logging is opt-in per request; hasattr guards requests that
        # do not carry the should_log_features attribute at all.
        if hasattr(request, "should_log_features") and request.should_log_features:
            await log_features(
                trace_id=data.trace_id,
                user_id=request.user,
                generated_tests=[generated_test_source],
                instrumented_generated_tests=[instrumented_behavior_tests],
                test_framework=data.test_framework,
                metadata={
                    "test_timeout": data.test_timeout,
                    "function_to_optimize": data.function_to_optimize.function_name,
                    "language": language,
                },
            )

        return 200, TestGenResponseSchema(
            generated_tests=generated_test_source,
            instrumented_behavior_tests=instrumented_behavior_tests,
            instrumented_perf_tests=instrumented_perf_tests,
        )

    # Broad catch is deliberate at this top-level boundary: log, report to
    # Sentry, and return a generic 500 rather than leak internals.
    except Exception as e:
        logging.exception(f"JavaScript test generation failed. trace_id={data.trace_id}")
        sentry_sdk.capture_exception(e)
        return 500, TestGenErrorResponseSchema(error="Error generating JavaScript tests. Internal server error.")
|