Mirror of https://github.com/codeflash-ai/codeflash-internal.git (synced 2026-05-04 18:25:18 +00:00) — 86 lines, 2.9 KiB, Python.
from __future__ import annotations
|
|
|
|
import enum
|
|
|
|
import libcst
|
|
from ninja import Schema
|
|
from pydantic import field_validator
|
|
from pydantic.dataclasses import dataclass
|
|
|
|
|
|
class OptimizedCandidateSource(str, enum.Enum):
    """Closed set of labels for the source of an optimized candidate.

    Because the enum mixes in ``str``, every member is itself a string and
    compares equal to its value, e.g. ``OptimizedCandidateSource.REFINE == "REFINE"``.
    Member values are identical to member names.
    """

    # Declaration order below is also the enum's iteration order.
    OPTIMIZE = "OPTIMIZE"
    # NOTE(review): presumably the line-profiler-guided flow (cf. OptimizeSchemaLP) — confirm.
    OPTIMIZE_LP = "OPTIMIZE_LP"
    REFINE = "REFINE"
    REPAIR = "REPAIR"
    ADAPTIVE = "ADAPTIVE"
|
|
|
|
|
|
@dataclass(frozen=True)
class CodeAndExplanation:
    """Immutable pairing of a parsed libcst module with a free-text explanation.

    Construction fails validation unless ``cst_module`` is a
    :class:`libcst.Module` whose rendered source compiles as Python.
    """

    # Parsed code. The validator below rejects ``None``, so the annotation no
    # longer advertises it (``None`` was already refused at runtime).
    cst_module: libcst.Module
    # Free-text explanation that accompanies the code.
    explanation: str

    @field_validator("cst_module")
    @classmethod
    def validate_cst_module(cls, v):
        """Check that ``v`` is a libcst Module whose source code compiles.

        Args:
            v: The value supplied for ``cst_module``.

        Returns:
            ``v`` unchanged when it is valid.

        Raises:
            ValueError: If ``v`` is not a ``libcst.Module``, or if rendering or
                compiling its source raises.
        """
        if not isinstance(v, libcst.Module):
            raise ValueError("cst_module must be an instance of libcst.Module")
        try:
            # Render the CST back to source text.
            source_code = v.code
            # Syntax check only; the resulting code object is discarded.
            # dont_inherit=True keeps this module's own `from __future__ import
            # annotations` from being applied to the checked code.
            compile(source_code, "<string>", "exec", dont_inherit=True)
        except Exception as e:
            # Chain the original error so the underlying SyntaxError stays visible.
            raise ValueError(f"Invalid cst_module, compilation error: {e}") from e
        return v
|
|
|
|
|
|
@dataclass(frozen=True)
class CodeExplanationAndID:
    """Immutable parsed libcst module plus an explanation and a string identifier.

    Construction fails validation unless ``cst_module`` is a
    :class:`libcst.Module` whose rendered source compiles as Python.
    """

    # Parsed code; must be a libcst Module (enforced by the validator below).
    cst_module: libcst.Module
    # Free-text explanation that accompanies the code.
    explanation: str
    # String identifier for this entry. The name shadows the ``id`` builtin but is
    # part of the public interface, so it is kept.
    # NOTE(review): what it identifies is decided by callers — not visible here.
    id: str

    @field_validator("cst_module")
    @classmethod
    def validate_cst_module(cls, v):
        """Check that ``v`` is a libcst Module whose source code compiles.

        Args:
            v: The value supplied for ``cst_module``.

        Returns:
            ``v`` unchanged when it is valid.

        Raises:
            ValueError: If ``v`` is not a ``libcst.Module``, or if rendering or
                compiling its source raises.
        """
        if not isinstance(v, libcst.Module):
            raise ValueError("cst_module must be an instance of libcst.Module")
        try:
            # Render the CST back to source text.
            source_code = v.code
            # Syntax check only; the resulting code object is discarded.
            # dont_inherit=True keeps this module's own `from __future__ import
            # annotations` from being applied to the checked code.
            compile(source_code, "<string>", "exec", dont_inherit=True)
        except Exception as e:
            # Chain the original error so the underlying SyntaxError stays visible.
            raise ValueError(f"Invalid cst_module, compilation error: {e}") from e
        return v
|
|
|
|
|
|
class OptimizeSchema(Schema):
    # django-ninja Schema (a pydantic model) for an optimize request; presumably
    # the request body of the optimize endpoint — confirm against the view.
    # Documented with comments rather than a docstring because a class docstring
    # would become the schema description in the generated OpenAPI output.
    # Under pydantic v2, a field with no default is required even when it is
    # annotated ``X | None``: the key must be sent, although its value may be null.
    source_code: str
    # NOTE(review): nullable but has no default, so clients must send the key
    # (null is accepted) — confirm this is intended.
    dependency_code: str | None
    trace_id: str
    python_version: str | None = None  # Made optional for multi-language support
    language: str = "python"  # Language identifier (python, javascript, typescript)
    language_version: str | None = None  # e.g., "ES2022", "Node 20", or a Python version
    experiment_metadata: dict[str, str] | None = None
    codeflash_version: str | None = None
    current_username: str | None = None
    repo_owner: str | None = None
    repo_name: str | None = None
    # Typed ``bool | None``, so an explicit null is accepted besides true/false.
    is_async: bool | None = False
    model: str | None = None  # Deprecated: multi-model is now handled by get_model_distribution
    call_sequence: int | None = None  # Deprecated: call_sequence is now auto-assigned
    n_candidates: int = 5  # default value for backward compatibility
|
|
|
|
|
|
class OptimizeSchemaLP(Schema):
    # django-ninja Schema (a pydantic model) for the line-profiler variant of the
    # optimize request; presumably the request body of that endpoint — confirm.
    # Documented with comments rather than a docstring because a class docstring
    # would become the schema description in the generated OpenAPI output.
    source_code: str
    # NOTE(review): nullable but has no default, so under pydantic v2 the key is
    # required (null is accepted) — confirm this is intended.
    dependency_code: str | None
    # NOTE(review): same required-but-nullable pattern as dependency_code.
    line_profiler_results: str | None
    trace_id: str
    # Required and non-nullable here, whereas OptimizeSchema makes python_version
    # optional; this schema also has no language / language_version fields.
    python_version: str
    experiment_metadata: dict[str, str] | None = None
    codeflash_version: str | None = None
    n_candidates: int = 6  # default value for backward compatibility
    # NOTE(review): OptimizeSchema marks model and call_sequence as deprecated;
    # confirm whether that also applies to this schema.
    model: str | None = None
    call_sequence: int | None = None
|