57 lines
1.8 KiB
Python
57 lines
1.8 KiB
Python
import libcst
|
|
from ninja import Schema
|
|
from pydantic import field_validator
|
|
from pydantic.dataclasses import dataclass
|
|
|
|
|
|
@dataclass(frozen=True)
class CodeAndExplanation:
    """Pair an optional libcst module with a textual explanation.

    ``cst_module`` is checked by ``validate_cst_module``: it may be ``None``;
    otherwise it must be a ``libcst.Module`` whose rendered source compiles.
    """

    # Code as a libcst tree; the annotation explicitly allows None.
    cst_module: libcst.Module | None
    # Text accompanying the code.
    explanation: str

    @field_validator("cst_module")
    @classmethod
    def validate_cst_module(cls, v):
        """Validate that ``v`` is None or a module that renders to compilable code.

        Args:
            v: The value supplied for ``cst_module``.

        Returns:
            ``v`` unchanged when it passes the checks.

        Raises:
            ValueError: If ``v`` is neither ``None`` nor a ``libcst.Module``,
                or if rendering/compiling its source code fails.
        """
        # The field is declared `libcst.Module | None`; without this guard the
        # isinstance check below would reject the None the annotation permits.
        if v is None:
            return v
        if not isinstance(v, libcst.Module):
            raise ValueError("cst_module must be an instance of libcst.Module")
        try:
            # Unparse the CST module to get the source code
            source_code = v.code
            # Compile the source code to check for syntax errors
            compile(source_code, "<string>", "exec")
        except Exception as e:
            # Any failure while rendering or compiling is reported uniformly
            # as a validation error; keep the original exception as the cause.
            raise ValueError(f"Invalid cst_module, compilation error: {e}") from e
        return v
|
|
|
|
|
|
@dataclass(frozen=True)
class CodeExplanationAndID:
    """Bundle a libcst module, its explanation and a string identifier.

    ``cst_module`` is checked by ``validate_cst_module``: it must be a
    ``libcst.Module`` whose rendered source compiles.
    """

    # Code as a libcst tree (required, not nullable).
    cst_module: libcst.Module
    # Text accompanying the code.
    explanation: str
    # String identifier for this entry.
    id: str

    @field_validator("cst_module")
    @classmethod
    def validate_cst_module(cls, v):
        """Validate that ``v`` is a module that renders to compilable code.

        Args:
            v: The value supplied for ``cst_module``.

        Returns:
            ``v`` unchanged when it passes the checks.

        Raises:
            ValueError: If ``v`` is not a ``libcst.Module``, or if
                rendering/compiling its source code fails.
        """
        if not isinstance(v, libcst.Module):
            raise ValueError("cst_module must be an instance of libcst.Module")
        try:
            # Unparse the CST module to get the source code
            source_code = v.code
            # Compile the source code to check for syntax errors
            compile(source_code, "<string>", "exec")
        except Exception as e:
            # Any failure while rendering or compiling is reported uniformly
            # as a validation error; keep the original exception as the cause.
            raise ValueError(f"Invalid cst_module, compilation error: {e}") from e
        return v
|
|
|
|
|
|
class OptimizeSchema(Schema):
    # Ninja schema; judging by its name it carries the inputs of an
    # "optimize" call -- TODO confirm against the endpoint that uses it.
    # Documented with comments instead of a docstring so the generated
    # schema description is left exactly as it was.

    # --- Required fields (no default value) ---
    source_code: str
    dependency_code: str | None  # may be None, but has no default
    trace_id: str
    python_version: str

    # --- Optional fields defaulting to None ---
    experiment_metadata: dict[str, str] | None = None
    codeflash_version: str | None = None
    current_username: str | None = None
    repo_owner: str | None = None
    repo_name: str | None = None

    # --- Optional fields with non-None defaults ---
    is_async: bool | None = False
    n_candidates: int | None = 5
|