mirror of
https://github.com/codeflash-ai/codeflash-internal.git
synced 2026-05-04 18:25:18 +00:00
75 lines
2.8 KiB
Python
75 lines
2.8 KiB
Python
import enum
|
|
from typing import Self
|
|
|
|
from ninja import Schema
|
|
from pydantic import model_validator
|
|
|
|
from aiservice.models.functions_to_optimize import FunctionToOptimize
|
|
|
|
|
|
class TestingMode(enum.Enum):
    """The two testing modes: behavior and performance.

    Each member's value is its lowercase name, so a serialized mode can be
    turned back into a member with ``TestingMode("behavior")``.
    NOTE(review): presumably these mirror the instrumented behavior/perf test
    variants produced by test generation — confirm against callers.
    """

    # Value strings are part of the external contract; do not rename them.
    BEHAVIOR = "behavior"
    PERFORMANCE = "performance"
|
|
|
|
|
|
class TestGenSchema(Schema):
    """Validated input for a test-generation request.

    At least one of ``helper_function_names`` (preferred) or the legacy
    ``dependent_function_names`` must be supplied. After validation the names
    are always available in ``helper_function_names``.
    """

    source_code_being_tested: str
    function_to_optimize: FunctionToOptimize
    # Preferred field: the only one new clients should send.
    helper_function_names: list[str] | None = None
    # Legacy field kept only for backwards compatibility; the validator below
    # moves it into ``helper_function_names`` when that field is absent.
    dependent_function_names: list[str] | None = None
    module_path: str
    test_module_path: str
    test_framework: str  # e.g. "pytest", "jest"
    test_timeout: int
    trace_id: str
    python_version: str | None = None  # optional so non-Python languages may omit it
    language: str = "python"  # language identifier: python, javascript, typescript
    language_version: str | None = None  # e.g. "ES2022", "Node 20", or a Python version
    codeflash_version: str | None = None
    test_index: int | None = None
    is_async: bool | None = False
    call_sequence: int | None = None
    is_numerical_code: bool | None = None

    @model_validator(mode="after")
    def helper_function_names_validator(self) -> Self:
        """Fold the legacy ``dependent_function_names`` into ``helper_function_names``.

        Returns:
            The same (possibly mutated) model instance.

        Raises:
            ValueError: If neither name field was provided (pydantic reports
                this as a validation error).
        """
        if self.helper_function_names is None:
            legacy_names = self.dependent_function_names
            if legacy_names is None:
                raise ValueError("either field 'helper_function_names' or 'dependent_function_names' is required")
            # Only the legacy field was sent: migrate it and clear the old slot.
            self.helper_function_names = legacy_names
            self.dependent_function_names = None
        # When helper_function_names was supplied it takes precedence and
        # dependent_function_names is left exactly as received.
        return self
|
|
|
|
|
|
class TestGenResponseSchema(Schema):
    """Test-generation output: the generated tests and their instrumented variants."""

    # Field order is part of the serialized schema; keep it stable.
    generated_tests: str
    instrumented_behavior_tests: str  # variant instrumented for behavior checks
    instrumented_perf_tests: str  # variant instrumented for performance measurement
|
|
|
|
|
|
class TestGenDebugInfo(Schema):
    """Debug information for failed test generation.

    Only ``stage`` is required; every other field defaults to ``None`` and is
    filled in with whatever was available when the failure occurred.
    """

    # Stage that failed: "llm_generation", "code_validation",
    # "instrumentation", or "postprocessing".
    stage: str
    raw_llm_output: str | None = None  # raw LLM response text, before parsing
    initial_code: str | None = None  # code extracted from the LLM response
    fixed_code: str | None = None  # code after the isort / quote fixes
    final_code: str | None = None  # final code that failed validation
    lines_removed: int | None = None  # number of lines truncated during validation
    validation_error: str | None = None  # specific validation error message
|
|
|
|
|
|
class TestGenErrorResponseSchema(Schema):
    """Error payload for test generation, with optional trace id and debug details."""

    error: str  # human-readable error message
    trace_id: str | None = None
    # Structured failure context; None when no debug details were captured.
    debug_info: TestGenDebugInfo | None = None
|
|
|
|
|
|
class TestGenerationFailedError(Exception):
|
|
"""Exception for test generation failures with debug context."""
|
|
|
|
def __init__(self, message: str, debug_info: dict[str, str | int | None] | None = None) -> None:
|
|
super().__init__(message)
|
|
self.debug_info = debug_info or {}
|