some tests failing now
This commit is contained in:
parent
43cf1d7067
commit
08c8067630
5 changed files with 210 additions and 971 deletions
|
|
@ -1,653 +0,0 @@
|
|||
import base64
|
||||
import json
|
||||
from functools import partial
|
||||
from typing import Any, Dict, List, Literal, Optional, Type, Union
|
||||
import requests
|
||||
|
||||
from openai import OpenAI
|
||||
from openai._types import NOT_GIVEN
|
||||
from pydantic import ConfigDict, Field, model_validator
|
||||
|
||||
from inference.core.env import WORKFLOWS_REMOTE_EXECUTION_MAX_STEP_CONCURRENT_REQUESTS, API_BASE_URL
|
||||
from inference.core.managers.base import ModelManager
|
||||
from inference.core.utils.image_utils import encode_image_to_jpeg_bytes, load_image
|
||||
from inference.core.workflows.core_steps.common.utils import run_in_parallel
|
||||
from inference.core.workflows.core_steps.common.vlms import VLM_TASKS_METADATA
|
||||
from inference.core.workflows.execution_engine.entities.base import (
|
||||
Batch,
|
||||
OutputDefinition,
|
||||
WorkflowImageData,
|
||||
)
|
||||
from inference.core.workflows.execution_engine.entities.types import (
|
||||
FLOAT_KIND,
|
||||
IMAGE_KIND,
|
||||
LANGUAGE_MODEL_OUTPUT_KIND,
|
||||
LIST_OF_VALUES_KIND,
|
||||
SECRET_KIND,
|
||||
STRING_KIND,
|
||||
ImageInputField,
|
||||
Selector,
|
||||
)
|
||||
from inference.core.workflows.prototypes.block import (
|
||||
BlockResult,
|
||||
WorkflowBlock,
|
||||
WorkflowBlockManifest,
|
||||
)
|
||||
|
||||
SUPPORTED_TASK_TYPES_LIST = [
|
||||
"unconstrained",
|
||||
"ocr",
|
||||
"structured-answering",
|
||||
"classification",
|
||||
"multi-label-classification",
|
||||
"visual-question-answering",
|
||||
"caption",
|
||||
"detailed-caption",
|
||||
]
|
||||
SUPPORTED_TASK_TYPES = set(SUPPORTED_TASK_TYPES_LIST)
|
||||
|
||||
RELEVANT_TASKS_METADATA = {
|
||||
k: v for k, v in VLM_TASKS_METADATA.items() if k in SUPPORTED_TASK_TYPES
|
||||
}
|
||||
RELEVANT_TASKS_DOCS_DESCRIPTION = "\n\n".join(
|
||||
f"* **{v['name']}** (`{k}`) - {v['description']}"
|
||||
for k, v in RELEVANT_TASKS_METADATA.items()
|
||||
)
|
||||
|
||||
|
||||
LONG_DESCRIPTION = f"""
|
||||
Ask a question to OpenAI's GPT-4 with Vision model.
|
||||
|
||||
You can specify arbitrary text prompts or predefined ones, the block supports the following types of prompt:
|
||||
|
||||
{RELEVANT_TASKS_DOCS_DESCRIPTION}
|
||||
|
||||
You need to provide your OpenAI API key to use the GPT-4 with Vision model.
|
||||
"""
|
||||
|
||||
|
||||
TaskType = Literal[tuple(SUPPORTED_TASK_TYPES_LIST)]
|
||||
|
||||
TASKS_REQUIRING_PROMPT = {
|
||||
"unconstrained",
|
||||
"visual-question-answering",
|
||||
}
|
||||
|
||||
TASKS_REQUIRING_CLASSES = {
|
||||
"classification",
|
||||
"multi-label-classification",
|
||||
}
|
||||
|
||||
TASKS_REQUIRING_OUTPUT_STRUCTURE = {
|
||||
"structured-answering",
|
||||
}
|
||||
|
||||
|
||||
|
||||
class BlockManifest(WorkflowBlockManifest):
|
||||
model_config = ConfigDict(
|
||||
json_schema_extra={
|
||||
"name": "OpenAI",
|
||||
"version": "v3",
|
||||
"short_description": "Run OpenAI's GPT-4 with vision capabilities.",
|
||||
"long_description": LONG_DESCRIPTION,
|
||||
"license": "Apache-2.0",
|
||||
"block_type": "model",
|
||||
"search_keywords": ["LMM", "VLM", "ChatGPT", "GPT", "OpenAI"],
|
||||
"is_vlm_block": True,
|
||||
"task_type_property": "task_type",
|
||||
"ui_manifest": {
|
||||
"section": "model",
|
||||
"icon": "fal fa-atom",
|
||||
"blockPriority": 5,
|
||||
"popular": True,
|
||||
},
|
||||
},
|
||||
protected_namespaces=(),
|
||||
)
|
||||
type: Literal["roboflow_core/open_ai@v3"]
|
||||
images: Selector(kind=[IMAGE_KIND]) = ImageInputField
|
||||
task_type: TaskType = Field(
|
||||
default="unconstrained",
|
||||
description="Task type to be performed by model. Value determines required parameters and output response.",
|
||||
json_schema_extra={
|
||||
"values_metadata": RELEVANT_TASKS_METADATA,
|
||||
"recommended_parsers": {
|
||||
"structured-answering": "roboflow_core/json_parser@v1",
|
||||
"classification": "roboflow_core/vlm_as_classifier@v1",
|
||||
"multi-label-classification": "roboflow_core/vlm_as_classifier@v1",
|
||||
},
|
||||
"always_visible": True,
|
||||
},
|
||||
)
|
||||
prompt: Optional[Union[Selector(kind=[STRING_KIND]), str]] = Field(
|
||||
default=None,
|
||||
description="Text prompt to the OpenAI model",
|
||||
examples=["my prompt", "$inputs.prompt"],
|
||||
json_schema_extra={
|
||||
"relevant_for": {
|
||||
"task_type": {"values": TASKS_REQUIRING_PROMPT, "required": True},
|
||||
},
|
||||
"multiline": True,
|
||||
},
|
||||
)
|
||||
output_structure: Optional[Dict[str, str]] = Field(
|
||||
default=None,
|
||||
description="Dictionary with structure of expected JSON response",
|
||||
examples=[{"my_key": "description"}, "$inputs.output_structure"],
|
||||
json_schema_extra={
|
||||
"relevant_for": {
|
||||
"task_type": {
|
||||
"values": TASKS_REQUIRING_OUTPUT_STRUCTURE,
|
||||
"required": True,
|
||||
},
|
||||
},
|
||||
},
|
||||
)
|
||||
classes: Optional[Union[Selector(kind=[LIST_OF_VALUES_KIND]), List[str]]] = Field(
|
||||
default=None,
|
||||
description="List of classes to be used",
|
||||
examples=[["class-a", "class-b"], "$inputs.classes"],
|
||||
json_schema_extra={
|
||||
"relevant_for": {
|
||||
"task_type": {
|
||||
"values": TASKS_REQUIRING_CLASSES,
|
||||
"required": True,
|
||||
},
|
||||
},
|
||||
},
|
||||
)
|
||||
api_key: Union[Selector(kind=[STRING_KIND, SECRET_KIND]), str] = Field(
|
||||
description="Your OpenAI API key",
|
||||
examples=["xxx-xxx", "$inputs.openai_api_key"],
|
||||
private=True,
|
||||
)
|
||||
model_version: Union[
|
||||
Selector(kind=[STRING_KIND]), Literal["gpt-4o", "gpt-4o-mini"]
|
||||
] = Field(
|
||||
default="gpt-4o",
|
||||
description="Model to be used",
|
||||
examples=["gpt-4o", "$inputs.openai_model"],
|
||||
)
|
||||
image_detail: Union[
|
||||
Selector(kind=[STRING_KIND]), Literal["auto", "high", "low"]
|
||||
] = Field(
|
||||
default="auto",
|
||||
description="Indicates the image's quality, with 'high' suggesting it is of high resolution and should be processed or displayed with high fidelity.",
|
||||
examples=["auto", "high", "low"],
|
||||
)
|
||||
max_tokens: int = Field(
|
||||
default=450,
|
||||
description="Maximum number of tokens the model can generate in it's response.",
|
||||
)
|
||||
temperature: Optional[Union[float, Selector(kind=[FLOAT_KIND])]] = Field(
|
||||
default=None,
|
||||
description="Temperature to sample from the model - value in range 0.0-2.0, the higher - the more "
|
||||
'random / "creative" the generations are.',
|
||||
ge=0.0,
|
||||
le=2.0,
|
||||
)
|
||||
max_concurrent_requests: Optional[int] = Field(
|
||||
default=None,
|
||||
description="Number of concurrent requests that can be executed by block when batch of input images provided. "
|
||||
"If not given - block defaults to value configured globally in Workflows Execution Engine. "
|
||||
"Please restrict if you hit OpenAI limits.",
|
||||
)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate(self) -> "BlockManifest":
|
||||
if self.task_type in TASKS_REQUIRING_PROMPT and self.prompt is None:
|
||||
raise ValueError(
|
||||
f"`prompt` parameter required to be set for task `{self.task_type}`"
|
||||
)
|
||||
if self.task_type in TASKS_REQUIRING_CLASSES and self.classes is None:
|
||||
raise ValueError(
|
||||
f"`classes` parameter required to be set for task `{self.task_type}`"
|
||||
)
|
||||
if (
|
||||
self.task_type in TASKS_REQUIRING_OUTPUT_STRUCTURE
|
||||
and self.output_structure is None
|
||||
):
|
||||
raise ValueError(
|
||||
f"`output_structure` parameter required to be set for task `{self.task_type}`"
|
||||
)
|
||||
return self
|
||||
|
||||
@classmethod
|
||||
def get_parameters_accepting_batches(cls) -> List[str]:
|
||||
return ["images"]
|
||||
|
||||
@classmethod
|
||||
def describe_outputs(cls) -> List[OutputDefinition]:
|
||||
return [
|
||||
OutputDefinition(
|
||||
name="output", kind=[STRING_KIND, LANGUAGE_MODEL_OUTPUT_KIND]
|
||||
),
|
||||
OutputDefinition(name="classes", kind=[LIST_OF_VALUES_KIND]),
|
||||
]
|
||||
|
||||
@classmethod
|
||||
def get_execution_engine_compatibility(cls) -> Optional[str]:
|
||||
return ">=1.4.0,<2.0.0"
|
||||
|
||||
|
||||
class OpenAIBlockV3(WorkflowBlock):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_manager: ModelManager,
|
||||
api_key: Optional[str],
|
||||
):
|
||||
self._model_manager = model_manager
|
||||
self._api_key = api_key
|
||||
|
||||
@classmethod
|
||||
def get_init_parameters(cls) -> List[str]:
|
||||
return ["model_manager", "api_key"]
|
||||
|
||||
@classmethod
|
||||
def get_manifest(cls) -> Type[WorkflowBlockManifest]:
|
||||
return BlockManifest
|
||||
|
||||
@classmethod
|
||||
def get_execution_engine_compatibility(cls) -> Optional[str]:
|
||||
return ">=1.3.0,<2.0.0"
|
||||
|
||||
def run(
|
||||
self,
|
||||
images: Batch[WorkflowImageData],
|
||||
task_type: TaskType,
|
||||
prompt: Optional[str],
|
||||
output_structure: Optional[Dict[str, str]],
|
||||
classes: Optional[List[str]],
|
||||
api_key: str,
|
||||
model_version: str,
|
||||
image_detail: Literal["low", "high", "auto"],
|
||||
max_tokens: int,
|
||||
temperature: Optional[float],
|
||||
max_concurrent_requests: Optional[int],
|
||||
) -> BlockResult:
|
||||
inference_images = [i.to_inference_format() for i in images]
|
||||
raw_outputs = run_gpt_4v_llm_prompting(
|
||||
roboflow_api_key=self._api_key,
|
||||
images=inference_images,
|
||||
task_type=task_type,
|
||||
prompt=prompt,
|
||||
output_structure=output_structure,
|
||||
classes=classes,
|
||||
openai_api_key=api_key,
|
||||
gpt_model_version=model_version,
|
||||
gpt_image_detail=image_detail,
|
||||
max_tokens=max_tokens,
|
||||
temperature=temperature,
|
||||
max_concurrent_requests=max_concurrent_requests,
|
||||
)
|
||||
return [
|
||||
{"output": raw_output, "classes": classes} for raw_output in raw_outputs
|
||||
]
|
||||
|
||||
|
||||
def run_gpt_4v_llm_prompting(
|
||||
images: List[Dict[str, Any]],
|
||||
task_type: TaskType,
|
||||
prompt: Optional[str],
|
||||
output_structure: Optional[Dict[str, str]],
|
||||
classes: Optional[List[str]],
|
||||
roboflow_api_key: Optional[str],
|
||||
openai_api_key: Optional[str],
|
||||
gpt_model_version: str,
|
||||
gpt_image_detail: Literal["auto", "high", "low"],
|
||||
max_tokens: int,
|
||||
temperature: Optional[int],
|
||||
max_concurrent_requests: Optional[int],
|
||||
) -> List[str]:
|
||||
if task_type not in PROMPT_BUILDERS:
|
||||
raise ValueError(f"Task type: {task_type} not supported.")
|
||||
gpt4_prompts = []
|
||||
for image in images:
|
||||
loaded_image, _ = load_image(image)
|
||||
base64_image = base64.b64encode(
|
||||
encode_image_to_jpeg_bytes(loaded_image)
|
||||
).decode("ascii")
|
||||
generated_prompt = PROMPT_BUILDERS[task_type](
|
||||
base64_image=base64_image,
|
||||
prompt=prompt,
|
||||
output_structure=output_structure,
|
||||
classes=classes,
|
||||
gpt_image_detail=gpt_image_detail,
|
||||
)
|
||||
gpt4_prompts.append(generated_prompt)
|
||||
return execute_gpt_4v_requests(
|
||||
roboflow_api_key=roboflow_api_key,
|
||||
openai_api_key=openai_api_key,
|
||||
gpt4_prompts=gpt4_prompts,
|
||||
gpt_model_version=gpt_model_version,
|
||||
max_tokens=max_tokens,
|
||||
temperature=temperature,
|
||||
max_concurrent_requests=max_concurrent_requests,
|
||||
)
|
||||
|
||||
|
||||
def execute_gpt_4v_requests(
|
||||
roboflow_api_key:str,
|
||||
openai_api_key: str,
|
||||
gpt4_prompts: List[List[dict]],
|
||||
gpt_model_version: str,
|
||||
max_tokens: int,
|
||||
temperature: Optional[float],
|
||||
max_concurrent_requests: Optional[int],
|
||||
) -> List[str]:
|
||||
tasks = [
|
||||
partial(
|
||||
execute_gpt_4v_request,
|
||||
roboflow_api_key=roboflow_api_key,
|
||||
openai_api_key=openai_api_key,
|
||||
prompt=prompt,
|
||||
gpt_model_version=gpt_model_version,
|
||||
max_tokens=max_tokens,
|
||||
temperature=temperature,
|
||||
)
|
||||
for prompt in gpt4_prompts
|
||||
]
|
||||
max_workers = (
|
||||
max_concurrent_requests
|
||||
or WORKFLOWS_REMOTE_EXECUTION_MAX_STEP_CONCURRENT_REQUESTS
|
||||
)
|
||||
return run_in_parallel(
|
||||
tasks=tasks,
|
||||
max_workers=max_workers,
|
||||
)
|
||||
|
||||
|
||||
def _execute_proxied_openai_request(
|
||||
roboflow_api_key: str,
|
||||
openai_api_key: str,
|
||||
prompt: List[dict],
|
||||
gpt_model_version: str,
|
||||
max_tokens: int,
|
||||
temperature: Optional[float],
|
||||
) -> str:
|
||||
"""Executes OpenAI request via Roboflow proxy."""
|
||||
payload = {
|
||||
"model": gpt_model_version,
|
||||
"messages": prompt,
|
||||
"max_tokens": max_tokens,
|
||||
"openai_api_key": openai_api_key,
|
||||
}
|
||||
if temperature is not None:
|
||||
payload["temperature"] = temperature
|
||||
|
||||
try:
|
||||
endpoint = f"{API_BASE_URL}/apiproxy/openai?api_key={roboflow_api_key}"
|
||||
response = requests.post(endpoint, json=payload)
|
||||
response.raise_for_status()
|
||||
response_data = response.json()
|
||||
return response_data["choices"][0]["message"]["content"]
|
||||
except requests.exceptions.RequestException as e:
|
||||
raise RuntimeError(f"Failed to connect to Roboflow proxy: {e}") from e
|
||||
except (KeyError, IndexError) as e:
|
||||
raise RuntimeError(
|
||||
f"Invalid response structure from Roboflow proxy: {e} - Response: {response.text}"
|
||||
) from e
|
||||
|
||||
|
||||
def _execute_openai_request(
|
||||
openai_api_key: str,
|
||||
prompt: List[dict],
|
||||
gpt_model_version: str,
|
||||
max_tokens: int,
|
||||
temperature: Optional[float],
|
||||
) -> str:
|
||||
"""Executes OpenAI request directly."""
|
||||
temp_value = temperature if temperature is not None else NOT_GIVEN
|
||||
try:
|
||||
client = OpenAI(api_key=openai_api_key)
|
||||
response = client.chat.completions.create(
|
||||
model=gpt_model_version,
|
||||
messages=prompt,
|
||||
max_tokens=max_tokens,
|
||||
temperature=temp_value,
|
||||
)
|
||||
return response.choices[0].message.content
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"OpenAI API request failed: {e}") from e
|
||||
|
||||
|
||||
def execute_gpt_4v_request(
|
||||
roboflow_api_key: str,
|
||||
openai_api_key: str,
|
||||
prompt: List[dict],
|
||||
gpt_model_version: str,
|
||||
max_tokens: int,
|
||||
temperature: Optional[float],
|
||||
) -> str:
|
||||
if openai_api_key.startswith("rf_key:account") or openai_api_key.startswith(
|
||||
"rf_key:user:"
|
||||
):
|
||||
return _execute_proxied_openai_request(
|
||||
roboflow_api_key=roboflow_api_key,
|
||||
openai_api_key=openai_api_key,
|
||||
prompt=prompt,
|
||||
gpt_model_version=gpt_model_version,
|
||||
max_tokens=max_tokens,
|
||||
temperature=temperature,
|
||||
)
|
||||
else:
|
||||
return _execute_openai_request(
|
||||
openai_api_key=openai_api_key,
|
||||
prompt=prompt,
|
||||
gpt_model_version=gpt_model_version,
|
||||
max_tokens=max_tokens,
|
||||
temperature=temperature,
|
||||
)
|
||||
|
||||
|
||||
def prepare_unconstrained_prompt(
|
||||
base64_image: str,
|
||||
prompt: str,
|
||||
gpt_image_detail: str,
|
||||
**kwargs,
|
||||
) -> List[dict]:
|
||||
return [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": prompt},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": f"data:image/jpeg;base64,{base64_image}",
|
||||
"detail": gpt_image_detail,
|
||||
},
|
||||
},
|
||||
],
|
||||
}
|
||||
]
|
||||
|
||||
|
||||
def prepare_classification_prompt(
|
||||
base64_image: str, classes: List[str], gpt_image_detail: str, **kwargs
|
||||
) -> List[dict]:
|
||||
serialised_classes = ", ".join(classes)
|
||||
return [
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You act as single-class classification model. You must provide reasonable predictions. "
|
||||
"You are only allowed to produce JSON document in Markdown ```json [...]``` markers. "
|
||||
'Expected structure of json: {"class_name": "class-name", "confidence": 0.4}. '
|
||||
"`class-name` must be one of the class names defined by user. You are only allowed to return "
|
||||
"single JSON document, even if there are potentially multiple classes. You are not allowed to return list.",
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": f"List of all classes to be recognised by model: {serialised_classes}",
|
||||
},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": f"data:image/jpeg;base64,{base64_image}",
|
||||
"detail": gpt_image_detail,
|
||||
},
|
||||
},
|
||||
],
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
def prepare_multi_label_classification_prompt(
|
||||
base64_image: str, classes: List[str], gpt_image_detail: str, **kwargs
|
||||
) -> List[dict]:
|
||||
serialised_classes = ", ".join(classes)
|
||||
return [
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You act as multi-label classification model. You must provide reasonable predictions. "
|
||||
"You are only allowed to produce JSON document in Markdown ```json``` markers. "
|
||||
'Expected structure of json: {"predicted_classes": [{"class": "class-name-1", "confidence": 0.9}, '
|
||||
'{"class": "class-name-2", "confidence": 0.7}]}. '
|
||||
"`class-name-X` must be one of the class names defined by user and `confidence` is a float value in range "
|
||||
"0.0-1.0 that represent how sure you are that the class is present in the image. Only return class names "
|
||||
"that are visible.",
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": f"List of all classes to be recognised by model: {serialised_classes}",
|
||||
},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": f"data:image/jpeg;base64,{base64_image}",
|
||||
"detail": gpt_image_detail,
|
||||
},
|
||||
},
|
||||
],
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
def prepare_vqa_prompt(
|
||||
base64_image: str, prompt: str, gpt_image_detail: str, **kwargs
|
||||
) -> List[dict]:
|
||||
return [
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You act as Visual Question Answering model. Your task is to provide answer to question"
|
||||
"submitted by user. If this is open-question - answer with few sentences, for ABCD question, "
|
||||
"return only the indicator of the answer.",
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": f"Question: {prompt}"},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": f"data:image/jpeg;base64,{base64_image}",
|
||||
"detail": gpt_image_detail,
|
||||
},
|
||||
},
|
||||
],
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
def prepare_ocr_prompt(
|
||||
base64_image: str, gpt_image_detail: str, **kwargs
|
||||
) -> List[dict]:
|
||||
return [
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You act as OCR model. Your task is to read text from the image and return it in "
|
||||
"paragraphs representing the structure of texts in the image. You should only return "
|
||||
"recognised text, nothing else.",
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": f"data:image/jpeg;base64,{base64_image}",
|
||||
"detail": gpt_image_detail,
|
||||
},
|
||||
},
|
||||
],
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
def prepare_caption_prompt(
|
||||
base64_image: str, gpt_image_detail: str, short_description: bool, **kwargs
|
||||
) -> List[dict]:
|
||||
caption_detail_level = "Caption should be short."
|
||||
if not short_description:
|
||||
caption_detail_level = "Caption should be extensive."
|
||||
return [
|
||||
{
|
||||
"role": "system",
|
||||
"content": f"You act as image caption model. Your task is to provide description of the image. "
|
||||
f"{caption_detail_level}",
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": f"data:image/jpeg;base64,{base64_image}",
|
||||
"detail": gpt_image_detail,
|
||||
},
|
||||
},
|
||||
],
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
def prepare_structured_answering_prompt(
|
||||
base64_image: str, output_structure: Dict[str, str], gpt_image_detail: str, **kwargs
|
||||
) -> List[dict]:
|
||||
output_structure_serialised = json.dumps(output_structure, indent=4)
|
||||
return [
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are supposed to produce responses in JSON wrapped in Markdown markers: "
|
||||
"```json\nyour-response\n```. User is to provide you dictionary with keys and values. "
|
||||
"Each key must be present in your response. Values in user dictionary represent "
|
||||
"descriptions for JSON fields to be generated. Provide only JSON Markdown in response.",
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": f"Specification of requirements regarding output fields: \n"
|
||||
f"{output_structure_serialised}",
|
||||
},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": f"data:image/jpeg;base64,{base64_image}",
|
||||
"detail": gpt_image_detail,
|
||||
},
|
||||
},
|
||||
],
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
PROMPT_BUILDERS = {
|
||||
"unconstrained": prepare_unconstrained_prompt,
|
||||
"ocr": prepare_ocr_prompt,
|
||||
"visual-question-answering": prepare_vqa_prompt,
|
||||
"caption": partial(prepare_caption_prompt, short_description=True),
|
||||
"detailed-caption": partial(prepare_caption_prompt, short_description=False),
|
||||
"classification": prepare_classification_prompt,
|
||||
"multi-label-classification": prepare_multi_label_classification_prompt,
|
||||
"structured-answering": prepare_structured_answering_prompt,
|
||||
}
|
||||
|
|
@ -18,6 +18,139 @@ if TYPE_CHECKING:
|
|||
|
||||
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
|
||||
|
||||
from typing import List, Union
|
||||
|
||||
|
||||
class ImportCollector(cst.CSTVisitor):
|
||||
"""Visitor that collects all import statements in a module."""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.imports = []
|
||||
|
||||
def visit_Import(self, node: cst.Import) -> None:
|
||||
self.imports.append(node)
|
||||
|
||||
def visit_ImportFrom(self, node: cst.ImportFrom) -> None:
|
||||
self.imports.append(node)
|
||||
|
||||
|
||||
class GlobalStatementCollector(cst.CSTVisitor):
|
||||
"""Visitor that collects all global statements (excluding imports and functions/classes)."""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.global_statements = []
|
||||
self.in_function_or_class = False
|
||||
|
||||
def visit_ClassDef(self, node: cst.ClassDef) -> bool:
|
||||
# Don't visit inside classes
|
||||
self.in_function_or_class = True
|
||||
return False
|
||||
|
||||
def leave_ClassDef(self, original_node: cst.ClassDef) -> None:
|
||||
self.in_function_or_class = False
|
||||
|
||||
def visit_FunctionDef(self, node: cst.FunctionDef) -> bool:
|
||||
# Don't visit inside functions
|
||||
self.in_function_or_class = True
|
||||
return False
|
||||
|
||||
def leave_FunctionDef(self, original_node: cst.FunctionDef) -> None:
|
||||
self.in_function_or_class = False
|
||||
|
||||
def visit_SimpleStatementLine(self, node: cst.SimpleStatementLine) -> None:
|
||||
if not self.in_function_or_class:
|
||||
for statement in node.body:
|
||||
# Skip imports
|
||||
if not isinstance(statement, (cst.Import, cst.ImportFrom)):
|
||||
self.global_statements.append(node)
|
||||
break
|
||||
|
||||
|
||||
class LastImportFinder(cst.CSTVisitor):
|
||||
"""Finds the position of the last import statement in the module."""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.last_import_line = 0
|
||||
self.current_line = 0
|
||||
|
||||
def visit_SimpleStatementLine(self, node: cst.SimpleStatementLine) -> None:
|
||||
self.current_line += 1
|
||||
for statement in node.body:
|
||||
if isinstance(statement, (cst.Import, cst.ImportFrom)):
|
||||
self.last_import_line = self.current_line
|
||||
|
||||
|
||||
class ImportInserter(cst.CSTTransformer):
|
||||
"""Transformer that inserts global statements after the last import."""
|
||||
|
||||
def __init__(self, global_statements: List[cst.SimpleStatementLine], last_import_line: int):
|
||||
super().__init__()
|
||||
self.global_statements = global_statements
|
||||
self.last_import_line = last_import_line
|
||||
self.current_line = 0
|
||||
self.inserted = False
|
||||
|
||||
def leave_SimpleStatementLine(
|
||||
self, original_node: cst.SimpleStatementLine, updated_node: cst.SimpleStatementLine
|
||||
) -> cst.Module:
|
||||
self.current_line += 1
|
||||
|
||||
# If we're right after the last import and haven't inserted yet
|
||||
if self.current_line == self.last_import_line and not self.inserted:
|
||||
self.inserted = True
|
||||
return cst.Module(body=[updated_node] + self.global_statements)
|
||||
|
||||
return cst.Module(body=[updated_node])
|
||||
|
||||
def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module:
|
||||
# If there were no imports, add at the beginning of the module
|
||||
if self.last_import_line == 0 and not self.inserted:
|
||||
updated_body = list(updated_node.body)
|
||||
for stmt in reversed(self.global_statements):
|
||||
updated_body.insert(0, stmt)
|
||||
return updated_node.with_changes(body=updated_body)
|
||||
return updated_node
|
||||
|
||||
|
||||
def extract_global_statements(source_code: str) -> List[cst.SimpleStatementLine]:
|
||||
"""Extract global statements from source code."""
|
||||
module = cst.parse_module(source_code)
|
||||
collector = GlobalStatementCollector()
|
||||
module.visit(collector)
|
||||
return collector.global_statements
|
||||
|
||||
|
||||
def find_last_import_line(target_code: str) -> int:
|
||||
"""Find the line number of the last import statement."""
|
||||
module = cst.parse_module(target_code)
|
||||
finder = LastImportFinder()
|
||||
module.visit(finder)
|
||||
return finder.last_import_line
|
||||
|
||||
|
||||
def merge_globals(source_code: str, target_code: str) -> str:
|
||||
"""Merge global statements from source into target just after imports."""
|
||||
# Extract global statements from source
|
||||
global_statements = extract_global_statements(source_code)
|
||||
|
||||
# Find the last import line in target
|
||||
last_import_line = find_last_import_line(target_code)
|
||||
|
||||
# Parse the target code
|
||||
target_module = cst.parse_module(target_code)
|
||||
|
||||
# Create transformer to insert global statements
|
||||
transformer = ImportInserter(global_statements, last_import_line)
|
||||
|
||||
# Apply transformation
|
||||
modified_module = target_module.visit(transformer)
|
||||
|
||||
# Return the modified code
|
||||
return modified_module.code
|
||||
|
||||
|
||||
class FutureAliasedImportTransformer(cst.CSTTransformer):
|
||||
def leave_ImportFrom(
|
||||
|
|
@ -47,6 +180,20 @@ def add_needed_imports_from_module(
|
|||
helper_functions: list[FunctionSource] | None = None,
|
||||
helper_functions_fqn: set[str] | None = None,
|
||||
) -> str:
|
||||
global_statements = extract_global_statements(src_module_code)
|
||||
|
||||
# Find the last import line in target
|
||||
last_import_line = find_last_import_line(dst_module_code)
|
||||
|
||||
# Parse the target code
|
||||
target_module = cst.parse_module(dst_module_code)
|
||||
|
||||
# Create transformer to insert global statements
|
||||
transformer = ImportInserter(global_statements, last_import_line)
|
||||
#
|
||||
# # Apply transformation
|
||||
modified_module = target_module.visit(transformer)
|
||||
dst_module_code = modified_module.code
|
||||
"""Add all needed and used source module code imports to the destination module code, and return it."""
|
||||
src_module_code = delete___future___aliased_imports(src_module_code)
|
||||
if not helper_functions_fqn:
|
||||
|
|
|
|||
|
|
@ -12,7 +12,7 @@ from codeflash.code_utils.code_replacer import (
|
|||
replace_functions_in_file,
|
||||
)
|
||||
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
|
||||
from codeflash.models.models import FunctionParent
|
||||
from codeflash.models.models import CodeOptimizationContext, FunctionParent
|
||||
from codeflash.optimization.function_optimizer import FunctionOptimizer
|
||||
from codeflash.verification.verification_utils import TestConfig
|
||||
|
||||
|
|
@ -34,6 +34,48 @@ class FakeFunctionSource:
|
|||
jedi_definition: JediDefinition
|
||||
|
||||
|
||||
class Args:
|
||||
disable_imports_sorting = True
|
||||
formatter_cmds = ["disabled"]
|
||||
|
||||
|
||||
def test_code_replacement_global_statements():
|
||||
optimized_code = """import numpy as np
|
||||
inconsequential_var = '123'
|
||||
def sorter(arr):
|
||||
return arr.sort()"""
|
||||
code_path = (Path(__file__).parent.resolve() / "../code_to_optimize/bubble_sort_optimized.py").resolve()
|
||||
original_code_str = (Path(__file__).parent.resolve() / "../code_to_optimize/bubble_sort.py").read_text(
|
||||
encoding="utf-8"
|
||||
)
|
||||
code_path.write_text(original_code_str, encoding="utf-8")
|
||||
tests_root = Path("/Users/codeflash/Downloads/codeflash-dev/codeflash/code_to_optimize/tests/pytest/")
|
||||
project_root_path = (Path(__file__).parent / "..").resolve()
|
||||
func = FunctionToOptimize(function_name="sorter", parents=[], file_path=code_path)
|
||||
test_config = TestConfig(
|
||||
tests_root=tests_root,
|
||||
tests_project_rootdir=project_root_path,
|
||||
project_root_path=project_root_path,
|
||||
test_framework="pytest",
|
||||
pytest_cmd="pytest",
|
||||
)
|
||||
func_optimizer = FunctionOptimizer(function_to_optimize=func, test_cfg=test_config)
|
||||
code_context: CodeOptimizationContext = func_optimizer.get_code_optimization_context().unwrap()
|
||||
original_helper_code: dict[Path, str] = {}
|
||||
helper_function_paths = {hf.file_path for hf in code_context.helper_functions}
|
||||
for helper_function_path in helper_function_paths:
|
||||
with helper_function_path.open(encoding="utf8") as f:
|
||||
helper_code = f.read()
|
||||
original_helper_code[helper_function_path] = helper_code
|
||||
func_optimizer.args = Args()
|
||||
func_optimizer.replace_function_and_helpers_with_optimized_code(
|
||||
code_context=code_context, optimized_code=optimized_code
|
||||
)
|
||||
final_output = code_path.read_text(encoding="utf-8")
|
||||
assert "inconsequential_var = '123'" in final_output
|
||||
code_path.unlink(missing_ok=True)
|
||||
|
||||
|
||||
def test_test_libcst_code_replacement() -> None:
|
||||
optim_code = """import libcst as cst
|
||||
from typing import Optional
|
||||
|
|
@ -74,7 +116,7 @@ print("Hello world")
|
|||
"""
|
||||
|
||||
function_name: str = "NewClass.new_function"
|
||||
preexisting_objects: set[tuple[str, tuple[FunctionParent,...]]] = find_preexisting_objects(original_code)
|
||||
preexisting_objects: set[tuple[str, tuple[FunctionParent, ...]]] = find_preexisting_objects(original_code)
|
||||
new_code: str = replace_functions_and_add_imports(
|
||||
source_code=original_code,
|
||||
function_names=[function_name],
|
||||
|
|
@ -135,7 +177,7 @@ print("Hello world")
|
|||
"""
|
||||
|
||||
function_name: str = "NewClass.new_function"
|
||||
preexisting_objects: set[tuple[str, tuple[FunctionParent,...]]] = find_preexisting_objects(original_code)
|
||||
preexisting_objects: set[tuple[str, tuple[FunctionParent, ...]]] = find_preexisting_objects(original_code)
|
||||
new_code: str = replace_functions_and_add_imports(
|
||||
source_code=original_code,
|
||||
function_names=[function_name],
|
||||
|
|
@ -196,7 +238,7 @@ print("Salut monde")
|
|||
"""
|
||||
|
||||
function_names: list[str] = ["other_function"]
|
||||
preexisting_objects: set[tuple[str, tuple[FunctionParent,...]]] = find_preexisting_objects(original_code)
|
||||
preexisting_objects: set[tuple[str, tuple[FunctionParent, ...]]] = find_preexisting_objects(original_code)
|
||||
new_code: str = replace_functions_and_add_imports(
|
||||
source_code=original_code,
|
||||
function_names=function_names,
|
||||
|
|
@ -260,7 +302,7 @@ print("Salut monde")
|
|||
"""
|
||||
|
||||
function_names: list[str] = ["yet_another_function", "other_function"]
|
||||
preexisting_objects: set[tuple[str, tuple[FunctionParent,...]]] = find_preexisting_objects(original_code)
|
||||
preexisting_objects: set[tuple[str, tuple[FunctionParent, ...]]] = find_preexisting_objects(original_code)
|
||||
new_code: str = replace_functions_and_add_imports(
|
||||
source_code=original_code,
|
||||
function_names=function_names,
|
||||
|
|
@ -313,7 +355,7 @@ def supersort(doink):
|
|||
"""
|
||||
|
||||
function_names: list[str] = ["sorter_deps"]
|
||||
preexisting_objects: set[tuple[str, tuple[FunctionParent,...]]] = find_preexisting_objects(original_code)
|
||||
preexisting_objects: set[tuple[str, tuple[FunctionParent, ...]]] = find_preexisting_objects(original_code)
|
||||
new_code: str = replace_functions_and_add_imports(
|
||||
source_code=original_code,
|
||||
function_names=function_names,
|
||||
|
|
@ -591,7 +633,7 @@ class CacheConfig(BaseConfig):
|
|||
)
|
||||
"""
|
||||
function_names: list[str] = ["CacheSimilarityEvalConfig.from_config"]
|
||||
preexisting_objects: set[tuple[str, tuple[FunctionParent,...]]] = find_preexisting_objects(original_code)
|
||||
preexisting_objects: set[tuple[str, tuple[FunctionParent, ...]]] = find_preexisting_objects(original_code)
|
||||
|
||||
new_code: str = replace_functions_and_add_imports(
|
||||
source_code=original_code,
|
||||
|
|
@ -662,7 +704,7 @@ def test_test_libcst_code_replacement8() -> None:
|
|||
return np.sum(a != b) / a.size
|
||||
'''
|
||||
function_names: list[str] = ["_EmbeddingDistanceChainMixin._hamming_distance"]
|
||||
preexisting_objects: set[tuple[str, tuple[FunctionParent,...]]] = find_preexisting_objects(original_code)
|
||||
preexisting_objects: set[tuple[str, tuple[FunctionParent, ...]]] = find_preexisting_objects(original_code)
|
||||
new_code: str = replace_functions_and_add_imports(
|
||||
source_code=original_code,
|
||||
function_names=function_names,
|
||||
|
|
@ -715,7 +757,7 @@ def totally_new_function(value: Optional[str]):
|
|||
print("Hello world")
|
||||
"""
|
||||
function_name: str = "NewClass.__init__"
|
||||
preexisting_objects: set[tuple[str, tuple[FunctionParent,...]]] = find_preexisting_objects(original_code)
|
||||
preexisting_objects: set[tuple[str, tuple[FunctionParent, ...]]] = find_preexisting_objects(original_code)
|
||||
new_code: str = replace_functions_and_add_imports(
|
||||
source_code=original_code,
|
||||
function_names=[function_name],
|
||||
|
|
@ -811,7 +853,7 @@ def test_code_replacement11() -> None:
|
|||
|
||||
function_name: str = "Fu.foo"
|
||||
parents = (FunctionParent("Fu", "ClassDef"),)
|
||||
preexisting_objects: set[tuple[str, tuple[FunctionParent,...]]] = {("foo", parents), ("real_bar", parents)}
|
||||
preexisting_objects: set[tuple[str, tuple[FunctionParent, ...]]] = {("foo", parents), ("real_bar", parents)}
|
||||
new_code: str = replace_functions_in_file(
|
||||
source_code=original_code,
|
||||
original_function_names=[function_name],
|
||||
|
|
@ -850,7 +892,7 @@ def test_code_replacement12() -> None:
|
|||
pass
|
||||
'''
|
||||
|
||||
preexisting_objects: set[tuple[str, tuple[FunctionParent,...]]] = []
|
||||
preexisting_objects: set[tuple[str, tuple[FunctionParent, ...]]] = []
|
||||
new_code: str = replace_functions_in_file(
|
||||
source_code=original_code,
|
||||
original_function_names=["Fu.real_bar"],
|
||||
|
|
@ -887,7 +929,7 @@ def test_test_libcst_code_replacement13() -> None:
|
|||
"""
|
||||
|
||||
function_names: list[str] = ["yet_another_function", "other_function"]
|
||||
preexisting_objects: set[tuple[str, tuple[FunctionParent,...]]] = []
|
||||
preexisting_objects: set[tuple[str, tuple[FunctionParent, ...]]] = []
|
||||
new_code: str = replace_functions_and_add_imports(
|
||||
source_code=original_code,
|
||||
function_names=function_names,
|
||||
|
|
@ -1274,7 +1316,7 @@ def cosine_similarity_top_k(
|
|||
|
||||
return ret_idxs, scores
|
||||
'''
|
||||
preexisting_objects: set[tuple[str, tuple[FunctionParent,...]]] = find_preexisting_objects(original_code)
|
||||
preexisting_objects: set[tuple[str, tuple[FunctionParent, ...]]] = find_preexisting_objects(original_code)
|
||||
|
||||
helper_functions = [
|
||||
FakeFunctionSource(
|
||||
|
|
@ -1575,7 +1617,7 @@ print("Hello world")
|
|||
"NewClass.new_function2",
|
||||
"NestedClass.nested_function",
|
||||
] # Nested classes should be ignored, even if provided as target
|
||||
preexisting_objects: set[tuple[str, tuple[FunctionParent,...]]] = find_preexisting_objects(original_code)
|
||||
preexisting_objects: set[tuple[str, tuple[FunctionParent, ...]]] = find_preexisting_objects(original_code)
|
||||
new_code: str = replace_functions_and_add_imports(
|
||||
source_code=original_code,
|
||||
function_names=function_names,
|
||||
|
|
@ -1611,7 +1653,7 @@ print("Hello world")
|
|||
"""
|
||||
|
||||
function_names: list[str] = ["NewClass.__init__", "NewClass.__call__", "NewClass.new_function2"]
|
||||
preexisting_objects: set[tuple[str, tuple[FunctionParent,...]]] = find_preexisting_objects(original_code)
|
||||
preexisting_objects: set[tuple[str, tuple[FunctionParent, ...]]] = find_preexisting_objects(original_code)
|
||||
new_code: str = replace_functions_and_add_imports(
|
||||
source_code=original_code,
|
||||
function_names=function_names,
|
||||
|
|
|
|||
|
|
@ -1,147 +0,0 @@
|
|||
from pathlib import Path
|
||||
|
||||
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
|
||||
from codeflash.models.models import CodeOptimizationContext
|
||||
from codeflash.optimization.function_optimizer import FunctionOptimizer
|
||||
from codeflash.verification.verification_utils import TestConfig
|
||||
|
||||
class Args:
|
||||
disable_imports_sorting = True
|
||||
formatter_cmds = ["disabled"]
|
||||
def test_code_replacement_pr():
|
||||
optimized_code = """from typing import List, Optional
|
||||
|
||||
import requests
|
||||
from inference.core.env import API_BASE_URL
|
||||
from openai import OpenAI
|
||||
from openai._types import NOT_GIVEN
|
||||
|
||||
# Create a global requests session to reuse connections
|
||||
sess = requests.Session()
|
||||
|
||||
# Create a cache for OpenAI clients to avoid recreating them frequently
|
||||
openai_clients = {}
|
||||
|
||||
def _execute_proxied_openai_request(
|
||||
roboflow_api_key: str,
|
||||
openai_api_key: str,
|
||||
prompt: List[dict],
|
||||
gpt_model_version: str,
|
||||
max_tokens: int,
|
||||
temperature: Optional[float],
|
||||
) -> str:
|
||||
\"\"\"Executes OpenAI request via Roboflow proxy.\"\"\"
|
||||
payload = {
|
||||
\"model\": gpt_model_version,
|
||||
\"messages\": prompt,
|
||||
\"max_tokens\": max_tokens,
|
||||
\"openai_api_key\": openai_api_key,
|
||||
}
|
||||
if temperature is not None:
|
||||
payload[\"temperature\"] = temperature
|
||||
|
||||
try:
|
||||
endpoint = f\"{API_BASE_URL}/apiproxy/openai?api_key={roboflow_api_key}\"
|
||||
# Use global session for requests
|
||||
response = sess.post(endpoint, json=payload)
|
||||
response.raise_for_status()
|
||||
response_data = response.json()
|
||||
return response_data[\"choices\"][0][\"message\"][\"content\"]
|
||||
except requests.exceptions.RequestException as e:
|
||||
raise RuntimeError(f\"Failed to connect to Roboflow proxy: {e}\") from e
|
||||
except (KeyError, IndexError) as e:
|
||||
raise RuntimeError(
|
||||
f\"Invalid response structure from Roboflow proxy: {e} - Response: {response.text}\"
|
||||
) from e
|
||||
|
||||
|
||||
def _execute_openai_request(
|
||||
openai_api_key: str,
|
||||
prompt: List[dict],
|
||||
gpt_model_version: str,
|
||||
max_tokens: int,
|
||||
temperature: Optional[float],
|
||||
) -> str:
|
||||
\"\"\"Executes OpenAI request directly.\"\"\"
|
||||
temp_value = temperature if temperature is not None else NOT_GIVEN
|
||||
try:
|
||||
# Cache OpenAI client to avoid creating a new one each time
|
||||
if openai_api_key not in openai_clients:
|
||||
openai_clients[openai_api_key] = OpenAI(api_key=openai_api_key)
|
||||
client = openai_clients[openai_api_key]
|
||||
|
||||
response = client.chat.completions.create(
|
||||
model=gpt_model_version,
|
||||
messages=prompt,
|
||||
max_tokens=max_tokens,
|
||||
temperature=temp_value,
|
||||
)
|
||||
return response.choices[0].message.content
|
||||
except Exception as e:
|
||||
raise RuntimeError(f\"OpenAI API request failed: {e}\") from e
|
||||
|
||||
|
||||
def execute_gpt_4v_request(
|
||||
roboflow_api_key: str,
|
||||
openai_api_key: str,
|
||||
prompt: List[dict],
|
||||
gpt_model_version: str,
|
||||
max_tokens: int,
|
||||
temperature: Optional[float],
|
||||
) -> str:
|
||||
if openai_api_key.startswith(\"rf_key:account\") or openai_api_key.startswith(
|
||||
\"rf_key:user:\"
|
||||
):
|
||||
return _execute_proxied_openai_request(
|
||||
roboflow_api_key=roboflow_api_key,
|
||||
openai_api_key=openai_api_key,
|
||||
prompt=prompt,
|
||||
gpt_model_version=gpt_model_version,
|
||||
max_tokens=max_tokens,
|
||||
temperature=temperature,
|
||||
)
|
||||
else:
|
||||
return _execute_openai_request(
|
||||
openai_api_key=openai_api_key,
|
||||
prompt=prompt,
|
||||
gpt_model_version=gpt_model_version,
|
||||
max_tokens=max_tokens,
|
||||
temperature=temperature,
|
||||
)
|
||||
"""
|
||||
|
||||
code_path = (Path(__file__).parent.resolve() / "../code_to_optimize/roboflow.py").resolve()
|
||||
original_code_str = (Path(__file__).parent.resolve() / "../code_to_optimize/roboflow_original.py").read_text(encoding="utf-8")
|
||||
code_path.write_text(original_code_str, encoding="utf-8")
|
||||
tests_root = Path(__file__).parent.resolve() / "../code_to_optimize/tests/pytest/"
|
||||
project_root_path = (Path(__file__).parent / "..").resolve()
|
||||
func = FunctionToOptimize(function_name="execute_gpt_4v_request", parents=[], file_path=code_path)
|
||||
test_config = TestConfig(
|
||||
tests_root=tests_root,
|
||||
tests_project_rootdir=project_root_path,
|
||||
project_root_path=project_root_path,
|
||||
test_framework="pytest",
|
||||
pytest_cmd="pytest",
|
||||
)
|
||||
func_optimizer = FunctionOptimizer(function_to_optimize=func, test_cfg=test_config)
|
||||
code_context: CodeOptimizationContext = func_optimizer.get_code_optimization_context().unwrap()
|
||||
original_helper_code: dict[Path, str] = {}
|
||||
helper_function_paths = {hf.file_path for hf in code_context.helper_functions}
|
||||
for helper_function_path in helper_function_paths:
|
||||
with helper_function_path.open(encoding="utf8") as f:
|
||||
helper_code = f.read()
|
||||
original_helper_code[helper_function_path] = helper_code
|
||||
func_optimizer.args = Args()
|
||||
func_optimizer.replace_function_and_helpers_with_optimized_code(
|
||||
code_context=code_context, optimized_code=optimized_code
|
||||
)
|
||||
new_code, new_helper_code = func_optimizer.reformat_code_and_helpers(
|
||||
code_context.helper_functions, func.file_path, func_optimizer.function_to_optimize_source_code
|
||||
)
|
||||
original_code_combined = original_helper_code.copy()
|
||||
original_code_combined[func.file_path] = func_optimizer.function_to_optimize_source_code
|
||||
new_code_combined = new_helper_code.copy()
|
||||
new_code_combined[func.file_path] = new_code
|
||||
final_output = code_path.read_text(encoding="utf-8")
|
||||
assert "openai_clients = {}" in final_output
|
||||
code_path.unlink(missing_ok=True)
|
||||
|
|
@ -1,150 +0,0 @@
|
|||
from pathlib import Path
|
||||
|
||||
from codeflash.discovery.discover_unit_tests import discover_unit_tests
|
||||
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
|
||||
from codeflash.models.models import CodeOptimizationContext
|
||||
from codeflash.optimization.function_optimizer import FunctionOptimizer
|
||||
from codeflash.verification.verification_utils import TestConfig
|
||||
|
||||
class Args:
|
||||
disable_imports_sorting = True
|
||||
formatter_cmds = ["disabled"]
|
||||
def test_code_replacement_pr():
|
||||
optimized_code = """from typing import List, Optional
|
||||
|
||||
import requests
|
||||
from inference.core.env import API_BASE_URL
|
||||
from openai import OpenAI
|
||||
from openai._types import NOT_GIVEN
|
||||
|
||||
# Create a global requests session to reuse connections
|
||||
sess = requests.Session()
|
||||
|
||||
# Create a cache for OpenAI clients to avoid recreating them frequently
|
||||
openai_clients = {}
|
||||
|
||||
def _execute_proxied_openai_request(
|
||||
roboflow_api_key: str,
|
||||
openai_api_key: str,
|
||||
prompt: List[dict],
|
||||
gpt_model_version: str,
|
||||
max_tokens: int,
|
||||
temperature: Optional[float],
|
||||
) -> str:
|
||||
\"\"\"Executes OpenAI request via Roboflow proxy.\"\"\"
|
||||
payload = {
|
||||
\"model\": gpt_model_version,
|
||||
\"messages\": prompt,
|
||||
\"max_tokens\": max_tokens,
|
||||
\"openai_api_key\": openai_api_key,
|
||||
}
|
||||
if temperature is not None:
|
||||
payload[\"temperature\"] = temperature
|
||||
|
||||
try:
|
||||
endpoint = f\"{API_BASE_URL}/apiproxy/openai?api_key={roboflow_api_key}\"
|
||||
# Use global session for requests
|
||||
response = sess.post(endpoint, json=payload)
|
||||
response.raise_for_status()
|
||||
response_data = response.json()
|
||||
return response_data[\"choices\"][0][\"message\"][\"content\"]
|
||||
except requests.exceptions.RequestException as e:
|
||||
raise RuntimeError(f\"Failed to connect to Roboflow proxy: {e}\") from e
|
||||
except (KeyError, IndexError) as e:
|
||||
raise RuntimeError(
|
||||
f\"Invalid response structure from Roboflow proxy: {e} - Response: {response.text}\"
|
||||
) from e
|
||||
|
||||
|
||||
def _execute_openai_request(
|
||||
openai_api_key: str,
|
||||
prompt: List[dict],
|
||||
gpt_model_version: str,
|
||||
max_tokens: int,
|
||||
temperature: Optional[float],
|
||||
) -> str:
|
||||
\"\"\"Executes OpenAI request directly.\"\"\"
|
||||
temp_value = temperature if temperature is not None else NOT_GIVEN
|
||||
try:
|
||||
# Cache OpenAI client to avoid creating a new one each time
|
||||
if openai_api_key not in openai_clients:
|
||||
openai_clients[openai_api_key] = OpenAI(api_key=openai_api_key)
|
||||
client = openai_clients[openai_api_key]
|
||||
|
||||
response = client.chat.completions.create(
|
||||
model=gpt_model_version,
|
||||
messages=prompt,
|
||||
max_tokens=max_tokens,
|
||||
temperature=temp_value,
|
||||
)
|
||||
return response.choices[0].message.content
|
||||
except Exception as e:
|
||||
raise RuntimeError(f\"OpenAI API request failed: {e}\") from e
|
||||
|
||||
|
||||
def execute_gpt_4v_request(
|
||||
roboflow_api_key: str,
|
||||
openai_api_key: str,
|
||||
prompt: List[dict],
|
||||
gpt_model_version: str,
|
||||
max_tokens: int,
|
||||
temperature: Optional[float],
|
||||
) -> str:
|
||||
if openai_api_key.startswith(\"rf_key:account\") or openai_api_key.startswith(
|
||||
\"rf_key:user:\"
|
||||
):
|
||||
return _execute_proxied_openai_request(
|
||||
roboflow_api_key=roboflow_api_key,
|
||||
openai_api_key=openai_api_key,
|
||||
prompt=prompt,
|
||||
gpt_model_version=gpt_model_version,
|
||||
max_tokens=max_tokens,
|
||||
temperature=temperature,
|
||||
)
|
||||
else:
|
||||
return _execute_openai_request(
|
||||
openai_api_key=openai_api_key,
|
||||
prompt=prompt,
|
||||
gpt_model_version=gpt_model_version,
|
||||
max_tokens=max_tokens,
|
||||
temperature=temperature,
|
||||
)
|
||||
"""
|
||||
|
||||
code_path = (Path(__file__).parent.resolve() / "../code_to_optimize/roboflow.py").resolve()
|
||||
original_code_str = (Path(__file__).parent.resolve() / "../code_to_optimize/roboflow_original.py").read_text(encoding="utf-8")
|
||||
code_path.write_text(original_code_str, encoding="utf-8")
|
||||
tests_root = Path("/Users/codeflash/Downloads/codeflash-dev/codeflash/code_to_optimize/tests/pytest/roboflow_tests/tests/workflows/unit_tests/core_steps/models/foundation")
|
||||
# tests/workflows/unit_tests/core_steps/models/foundation/test_openai.py
|
||||
project_root_path = (Path(__file__).parent / "..").resolve()
|
||||
func = FunctionToOptimize(function_name="execute_gpt_4v_request", parents=[], file_path=code_path)
|
||||
test_config = TestConfig(
|
||||
tests_root=tests_root,
|
||||
tests_project_rootdir=project_root_path,
|
||||
project_root_path=project_root_path,
|
||||
test_framework="pytest",
|
||||
pytest_cmd="pytest",
|
||||
)
|
||||
discovered_tests = discover_unit_tests(test_config)
|
||||
func_optimizer = FunctionOptimizer(function_to_optimize=func, test_cfg=test_config)
|
||||
code_context: CodeOptimizationContext = func_optimizer.get_code_optimization_context().unwrap()
|
||||
original_helper_code: dict[Path, str] = {}
|
||||
helper_function_paths = {hf.file_path for hf in code_context.helper_functions}
|
||||
for helper_function_path in helper_function_paths:
|
||||
with helper_function_path.open(encoding="utf8") as f:
|
||||
helper_code = f.read()
|
||||
original_helper_code[helper_function_path] = helper_code
|
||||
func_optimizer.args = Args()
|
||||
func_optimizer.replace_function_and_helpers_with_optimized_code(
|
||||
code_context=code_context, optimized_code=optimized_code
|
||||
)
|
||||
# new_code, new_helper_code = func_optimizer.reformat_code_and_helpers(
|
||||
# code_context.helper_functions, func.file_path, func_optimizer.function_to_optimize_source_code
|
||||
# )
|
||||
# original_code_combined = original_helper_code.copy()
|
||||
# original_code_combined[func.file_path] = func_optimizer.function_to_optimize_source_code
|
||||
# new_code_combined = new_helper_code.copy()
|
||||
# new_code_combined[func.file_path] = new_code
|
||||
final_output = code_path.read_text(encoding="utf-8")
|
||||
assert "openai_clients = {}" in final_output
|
||||
code_path.unlink(missing_ok=True)
|
||||
Loading…
Reference in a new issue