refactor: extract shared create_prompt_env Jinja2 factory

Deduplicate the identical Environment(FileSystemLoader, StrictUndefined,
keep_trailing_newline=True) setup across JS testgen, Python testgen, and
Python explanations into core/shared/jinja_utils.py.

Also fix tests/testgen/test_testgen_javascript.py which had a stale
copy of build_javascript_prompt and loaded the now-deleted .md files.
This commit is contained in:
Kevin Turcios 2026-03-02 18:42:57 -05:00
parent 7820fb15e1
commit 5d2ad27d3f
5 changed files with 26 additions and 73 deletions

View file

@ -12,7 +12,6 @@ from pathlib import Path
from typing import TYPE_CHECKING
import sentry_sdk
from jinja2 import Environment, FileSystemLoader, StrictUndefined
import stamina
from ninja.errors import HttpError
from openai import OpenAIError
@ -26,6 +25,7 @@ from aiservice.validators.javascript_validator import validate_javascript_syntax
from authapp.auth import AuthenticatedRequest
from core.log_features.log_event import update_optimization_cost
from core.log_features.log_features import log_features
from core.shared.jinja_utils import create_prompt_env
from core.shared.testgen_models import (
TestGenerationFailedError,
TestGenErrorResponseSchema,
@ -42,7 +42,7 @@ _TEST_FUNC_RE = re.compile(r"(?:test|it)\s*\(\s*['\"]")
# Directory containing this module; the prompt templates are resolved relative to it.
current_dir = Path(__file__).parent
JS_PROMPTS_DIR = current_dir / "prompts" / "testgen"
# Earlier inline Jinja2 setup; the create_prompt_env call on the following line rebinds _jinja_env.
_jinja_env = Environment(loader=FileSystemLoader(JS_PROMPTS_DIR), keep_trailing_newline=True, undefined=StrictUndefined)
_jinja_env = create_prompt_env(JS_PROMPTS_DIR)
# Pattern to extract JavaScript code blocks: matches a fenced block whose opening fence is
# optionally tagged javascript/js/typescript/ts and captures the block body as group 1.
JS_PATTERN = re.compile(r"^```(?:javascript|js|typescript|ts)?\s*\n(.*?)\n```", re.MULTILINE | re.DOTALL)

View file

@ -5,7 +5,6 @@ import logging
from pathlib import Path
import sentry_sdk
from jinja2 import Environment, FileSystemLoader, StrictUndefined
from ninja import NinjaAPI
from openai.types.chat import (
ChatCompletionMessageParam,
@ -27,17 +26,12 @@ from core.languages.python.explanations.models import (
)
from core.log_features.log_event import update_optimization_cost
from core.log_features.log_features import log_features
from core.shared.jinja_utils import create_prompt_env
# Ninja API instance registered under the "explanations" URL namespace.
explanations_api = NinjaAPI(urls_namespace="explanations")
# Prompt templates are loaded from the "prompts" directory next to this module.
_PROMPT_DIR = Path(__file__).parent / "prompts"
# Earlier inline Jinja2 setup; the create_prompt_env call after it rebinds _jinja_env
# while still requesting trim_blocks and lstrip_blocks.
_jinja_env = Environment( # noqa: S701 — rendering LLM prompts, not HTML
loader=FileSystemLoader(_PROMPT_DIR),
undefined=StrictUndefined,
keep_trailing_newline=True,
trim_blocks=True,
lstrip_blocks=True,
)
_jinja_env = create_prompt_env(_PROMPT_DIR, trim_blocks=True, lstrip_blocks=True)
# Both templates are looked up at module level, so a missing template file surfaces when
# this module is loaded rather than when a prompt is rendered.
SYSTEM_PROMPT_TEMPLATE = _jinja_env.get_template("system_prompt.md.j2")
USER_PROMPT_TEMPLATE = _jinja_env.get_template("user_prompt.md.j2")

View file

@ -10,7 +10,6 @@ from typing import TYPE_CHECKING
import libcst as cst
import sentry_sdk
import stamina
from jinja2 import Environment, FileSystemLoader, StrictUndefined
from ninja.errors import HttpError
from openai import OpenAIError
@ -34,6 +33,7 @@ from core.languages.python.testgen.preprocessing.preprocess_pipeline import gene
from core.languages.python.testgen.validate import instrument_tests, validate_request_data
from core.log_features.log_event import update_optimization_cost
from core.log_features.log_features import log_features
from core.shared.jinja_utils import create_prompt_env
from core.shared.testgen_models import (
TestGenDebugInfo,
TestGenerationFailedError,
@ -58,7 +58,7 @@ def select_model_for_test(test_index: int) -> tuple[LLM, str]:
# Directory containing this module; the prompt templates live in its "prompts" subdirectory.
_current_dir = Path(__file__).parent
_prompts_dir = _current_dir / "prompts"
# Earlier inline Jinja2 setup; the create_prompt_env call on the following line rebinds _jinja_env.
_jinja_env = Environment(loader=FileSystemLoader(_prompts_dir), keep_trailing_newline=True, undefined=StrictUndefined) # noqa: S701 - rendering LLM prompts, not HTML
_jinja_env = create_prompt_env(_prompts_dir)
def build_prompt(

View file

@ -0,0 +1,18 @@
from pathlib import Path
from jinja2 import Environment, FileSystemLoader, StrictUndefined
def create_prompt_env(
    prompt_dir: Path,
    *,
    trim_blocks: bool = False,
    lstrip_blocks: bool = False,
) -> Environment:
    """Create the Jinja2 environment used to render prompt templates.

    Every environment produced here loads templates from ``prompt_dir``, uses
    ``StrictUndefined`` as its undefined type, and is created with
    ``keep_trailing_newline=True``. Only the whitespace-control switches vary
    between callers.

    Args:
        prompt_dir: Directory the ``FileSystemLoader`` reads templates from.
        trim_blocks: Passed through to ``Environment(trim_blocks=...)``.
        lstrip_blocks: Passed through to ``Environment(lstrip_blocks=...)``.

    Returns:
        A newly constructed ``Environment``; nothing is cached between calls.
    """
    template_loader = FileSystemLoader(prompt_dir)
    prompt_env = Environment(  # noqa: S701 — rendering LLM prompts, not HTML
        loader=template_loader,
        undefined=StrictUndefined,
        keep_trailing_newline=True,
        lstrip_blocks=lstrip_blocks,
        trim_blocks=trim_blocks,
    )
    return prompt_env

View file

@ -1,23 +1,13 @@
"""Tests for JavaScript test generation module.
Tests the prompt building and validation functions without importing
the full testgen module (to avoid LLM dependencies in tests).
Tests the prompt building and validation functions.
"""
import re
from pathlib import Path
import pytest
# Load prompts directly to avoid importing testgen_javascript.py
JS_PROMPTS_DIR = Path(__file__).parent.parent.parent / "core" / "languages" / "js_ts" / "prompts" / "testgen"
JS_EXECUTE_SYSTEM_PROMPT = (JS_PROMPTS_DIR / "execute_system_prompt.md").read_text()
JS_EXECUTE_USER_PROMPT = (JS_PROMPTS_DIR / "execute_user_prompt.md").read_text()
JS_EXECUTE_ASYNC_SYSTEM_PROMPT = (JS_PROMPTS_DIR / "execute_async_system_prompt.md").read_text()
JS_EXECUTE_ASYNC_USER_PROMPT = (JS_PROMPTS_DIR / "execute_async_user_prompt.md").read_text()
JS_PATTERN = re.compile(r"^```(?:javascript|js|typescript|ts)?\s*\n(.*?)\n```", re.MULTILINE | re.DOTALL)
from core.languages.js_ts.testgen import build_javascript_prompt, parse_and_validate_js_output
def _has_test_functions(code: str) -> bool:
@ -26,55 +16,6 @@ def _has_test_functions(code: str) -> bool:
return bool(re.search(test_pattern, code))
def build_javascript_prompt(
    function_name: str, function_code: str, module_path: str, test_framework: str, is_async: bool
) -> tuple[list[dict[str, str]], str]:
    """Assemble the chat messages used to request JavaScript tests.

    Args:
        function_name: Name of the function under test; substituted into both
            prompts and into the generated import line.
        function_code: Source of the function, substituted into the user prompt.
        module_path: Module specifier placed in the generated ``import`` statement.
        test_framework: Test framework name substituted into the user prompt.
        is_async: Selects the async prompt pair when true, the regular pair otherwise.

    Returns:
        A ``(messages, posthog_event_suffix)`` pair. ``messages`` holds the system
        message followed by the user message; the suffix is ``"async-"`` for async
        functions and ``""`` otherwise.
    """
    # A conditional expression evaluates only the chosen branch, so only the
    # selected pair of prompt constants is read.
    system_template, user_template, event_suffix = (
        (JS_EXECUTE_ASYNC_SYSTEM_PROMPT, JS_EXECUTE_ASYNC_USER_PROMPT, "async-")
        if is_async
        else (JS_EXECUTE_SYSTEM_PROMPT, JS_EXECUTE_USER_PROMPT, "")
    )

    system_content = system_template.format(function_name=function_name)
    user_content = user_template.format(
        test_framework=test_framework,
        function_name=function_name,
        function_code=function_code,
        # Doubled braces emit literal "{ }" around the imported name.
        import_statement=f"import {{ {function_name} }} from '{module_path}';",
        package_comment="",
    )

    return [
        {"role": "system", "content": system_content},
        {"role": "user", "content": user_content},
    ], event_suffix
def parse_and_validate_js_output(response_content: str) -> str:
    """Extract the JavaScript code block from an LLM response and validate it.

    Args:
        response_content: Raw text returned by the LLM.

    Returns:
        The whitespace-stripped contents of the first block matched by ``JS_PATTERN``.

    Raises:
        ValueError: If the response has no code fence at all, if no block matches
            ``JS_PATTERN``, or if ``_has_test_functions`` rejects the extracted code.
    """
    # Cheap substring check first so a fence-less response gets its own message.
    if "```" not in response_content:
        raise ValueError("LLM response did not contain a code block.")

    fenced_block = JS_PATTERN.search(response_content)
    if fenced_block is None:
        raise ValueError("No JavaScript code block found in the LLM response.")

    extracted = fenced_block.group(1).strip()
    if _has_test_functions(extracted):
        return extracted
    raise ValueError("Generated code does not contain any test functions.")
class TestHasTestFunctions:
"""Tests for detecting Jest test functions in code."""