some tests failing now

aseembits93 2025-04-29 18:34:40 -07:00
parent 43cf1d7067
commit 08c8067630
5 changed files with 210 additions and 971 deletions


@@ -1,653 +0,0 @@
import base64
import json
from functools import partial
from typing import Any, Dict, List, Literal, Optional, Type, Union
import requests
from openai import OpenAI
from openai._types import NOT_GIVEN
from pydantic import ConfigDict, Field, model_validator
from inference.core.env import WORKFLOWS_REMOTE_EXECUTION_MAX_STEP_CONCURRENT_REQUESTS, API_BASE_URL
from inference.core.managers.base import ModelManager
from inference.core.utils.image_utils import encode_image_to_jpeg_bytes, load_image
from inference.core.workflows.core_steps.common.utils import run_in_parallel
from inference.core.workflows.core_steps.common.vlms import VLM_TASKS_METADATA
from inference.core.workflows.execution_engine.entities.base import (
Batch,
OutputDefinition,
WorkflowImageData,
)
from inference.core.workflows.execution_engine.entities.types import (
FLOAT_KIND,
IMAGE_KIND,
LANGUAGE_MODEL_OUTPUT_KIND,
LIST_OF_VALUES_KIND,
SECRET_KIND,
STRING_KIND,
ImageInputField,
Selector,
)
from inference.core.workflows.prototypes.block import (
BlockResult,
WorkflowBlock,
WorkflowBlockManifest,
)
SUPPORTED_TASK_TYPES_LIST = [
"unconstrained",
"ocr",
"structured-answering",
"classification",
"multi-label-classification",
"visual-question-answering",
"caption",
"detailed-caption",
]
SUPPORTED_TASK_TYPES = set(SUPPORTED_TASK_TYPES_LIST)
RELEVANT_TASKS_METADATA = {
k: v for k, v in VLM_TASKS_METADATA.items() if k in SUPPORTED_TASK_TYPES
}
RELEVANT_TASKS_DOCS_DESCRIPTION = "\n\n".join(
f"* **{v['name']}** (`{k}`) - {v['description']}"
for k, v in RELEVANT_TASKS_METADATA.items()
)
LONG_DESCRIPTION = f"""
Ask a question to OpenAI's GPT-4 with Vision model.
You can specify arbitrary text prompts or use one of the predefined ones; the block supports the following types of prompt:
{RELEVANT_TASKS_DOCS_DESCRIPTION}
You need to provide your OpenAI API key to use the GPT-4 with Vision model.
"""
TaskType = Literal[tuple(SUPPORTED_TASK_TYPES_LIST)]
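# Editor's note: Literal[tuple(...)] builds the Literal type dynamically from
# SUPPORTED_TASK_TYPES_LIST. Pydantic resolves it at runtime, but static type
# checkers generally reject this form; the equivalent static spelling would be:
#
#   TaskType = Literal[
#       "unconstrained", "ocr", "structured-answering", "classification",
#       "multi-label-classification", "visual-question-answering",
#       "caption", "detailed-caption",
#   ]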
TASKS_REQUIRING_PROMPT = {
"unconstrained",
"visual-question-answering",
}
TASKS_REQUIRING_CLASSES = {
"classification",
"multi-label-classification",
}
TASKS_REQUIRING_OUTPUT_STRUCTURE = {
"structured-answering",
}
class BlockManifest(WorkflowBlockManifest):
model_config = ConfigDict(
json_schema_extra={
"name": "OpenAI",
"version": "v3",
"short_description": "Run OpenAI's GPT-4 with vision capabilities.",
"long_description": LONG_DESCRIPTION,
"license": "Apache-2.0",
"block_type": "model",
"search_keywords": ["LMM", "VLM", "ChatGPT", "GPT", "OpenAI"],
"is_vlm_block": True,
"task_type_property": "task_type",
"ui_manifest": {
"section": "model",
"icon": "fal fa-atom",
"blockPriority": 5,
"popular": True,
},
},
protected_namespaces=(),
)
type: Literal["roboflow_core/open_ai@v3"]
images: Selector(kind=[IMAGE_KIND]) = ImageInputField
task_type: TaskType = Field(
default="unconstrained",
description="Task type to be performed by model. Value determines required parameters and output response.",
json_schema_extra={
"values_metadata": RELEVANT_TASKS_METADATA,
"recommended_parsers": {
"structured-answering": "roboflow_core/json_parser@v1",
"classification": "roboflow_core/vlm_as_classifier@v1",
"multi-label-classification": "roboflow_core/vlm_as_classifier@v1",
},
"always_visible": True,
},
)
prompt: Optional[Union[Selector(kind=[STRING_KIND]), str]] = Field(
default=None,
description="Text prompt to the OpenAI model",
examples=["my prompt", "$inputs.prompt"],
json_schema_extra={
"relevant_for": {
"task_type": {"values": TASKS_REQUIRING_PROMPT, "required": True},
},
"multiline": True,
},
)
output_structure: Optional[Dict[str, str]] = Field(
default=None,
description="Dictionary with structure of expected JSON response",
examples=[{"my_key": "description"}, "$inputs.output_structure"],
json_schema_extra={
"relevant_for": {
"task_type": {
"values": TASKS_REQUIRING_OUTPUT_STRUCTURE,
"required": True,
},
},
},
)
classes: Optional[Union[Selector(kind=[LIST_OF_VALUES_KIND]), List[str]]] = Field(
default=None,
description="List of classes to be used",
examples=[["class-a", "class-b"], "$inputs.classes"],
json_schema_extra={
"relevant_for": {
"task_type": {
"values": TASKS_REQUIRING_CLASSES,
"required": True,
},
},
},
)
api_key: Union[Selector(kind=[STRING_KIND, SECRET_KIND]), str] = Field(
description="Your OpenAI API key",
examples=["xxx-xxx", "$inputs.openai_api_key"],
private=True,
)
model_version: Union[
Selector(kind=[STRING_KIND]), Literal["gpt-4o", "gpt-4o-mini"]
] = Field(
default="gpt-4o",
description="Model to be used",
examples=["gpt-4o", "$inputs.openai_model"],
)
image_detail: Union[
Selector(kind=[STRING_KIND]), Literal["auto", "high", "low"]
] = Field(
default="auto",
description="Indicates the image's quality, with 'high' suggesting it is of high resolution and should be processed or displayed with high fidelity.",
examples=["auto", "high", "low"],
)
max_tokens: int = Field(
default=450,
description="Maximum number of tokens the model can generate in it's response.",
)
temperature: Optional[Union[float, Selector(kind=[FLOAT_KIND])]] = Field(
default=None,
description="Temperature to sample from the model - value in range 0.0-2.0, the higher - the more "
'random / "creative" the generations are.',
ge=0.0,
le=2.0,
)
max_concurrent_requests: Optional[int] = Field(
default=None,
description="Number of concurrent requests that can be executed by block when batch of input images provided. "
"If not given - block defaults to value configured globally in Workflows Execution Engine. "
"Please restrict if you hit OpenAI limits.",
)
@model_validator(mode="after")
def validate(self) -> "BlockManifest":
if self.task_type in TASKS_REQUIRING_PROMPT and self.prompt is None:
raise ValueError(
f"`prompt` parameter required to be set for task `{self.task_type}`"
)
if self.task_type in TASKS_REQUIRING_CLASSES and self.classes is None:
raise ValueError(
f"`classes` parameter required to be set for task `{self.task_type}`"
)
if (
self.task_type in TASKS_REQUIRING_OUTPUT_STRUCTURE
and self.output_structure is None
):
raise ValueError(
f"`output_structure` parameter required to be set for task `{self.task_type}`"
)
return self
@classmethod
def get_parameters_accepting_batches(cls) -> List[str]:
return ["images"]
@classmethod
def describe_outputs(cls) -> List[OutputDefinition]:
return [
OutputDefinition(
name="output", kind=[STRING_KIND, LANGUAGE_MODEL_OUTPUT_KIND]
),
OutputDefinition(name="classes", kind=[LIST_OF_VALUES_KIND]),
]
@classmethod
def get_execution_engine_compatibility(cls) -> Optional[str]:
return ">=1.4.0,<2.0.0"
class OpenAIBlockV3(WorkflowBlock):
def __init__(
self,
model_manager: ModelManager,
api_key: Optional[str],
):
self._model_manager = model_manager
self._api_key = api_key
@classmethod
def get_init_parameters(cls) -> List[str]:
return ["model_manager", "api_key"]
@classmethod
def get_manifest(cls) -> Type[WorkflowBlockManifest]:
return BlockManifest
@classmethod
def get_execution_engine_compatibility(cls) -> Optional[str]:
return ">=1.3.0,<2.0.0"
def run(
self,
images: Batch[WorkflowImageData],
task_type: TaskType,
prompt: Optional[str],
output_structure: Optional[Dict[str, str]],
classes: Optional[List[str]],
api_key: str,
model_version: str,
image_detail: Literal["low", "high", "auto"],
max_tokens: int,
temperature: Optional[float],
max_concurrent_requests: Optional[int],
) -> BlockResult:
inference_images = [i.to_inference_format() for i in images]
raw_outputs = run_gpt_4v_llm_prompting(
roboflow_api_key=self._api_key,
images=inference_images,
task_type=task_type,
prompt=prompt,
output_structure=output_structure,
classes=classes,
openai_api_key=api_key,
gpt_model_version=model_version,
gpt_image_detail=image_detail,
max_tokens=max_tokens,
temperature=temperature,
max_concurrent_requests=max_concurrent_requests,
)
return [
{"output": raw_output, "classes": classes} for raw_output in raw_outputs
]
def run_gpt_4v_llm_prompting(
images: List[Dict[str, Any]],
task_type: TaskType,
prompt: Optional[str],
output_structure: Optional[Dict[str, str]],
classes: Optional[List[str]],
roboflow_api_key: Optional[str],
openai_api_key: Optional[str],
gpt_model_version: str,
gpt_image_detail: Literal["auto", "high", "low"],
max_tokens: int,
temperature: Optional[float],
max_concurrent_requests: Optional[int],
) -> List[str]:
if task_type not in PROMPT_BUILDERS:
raise ValueError(f"Task type: {task_type} not supported.")
gpt4_prompts = []
for image in images:
loaded_image, _ = load_image(image)
base64_image = base64.b64encode(
encode_image_to_jpeg_bytes(loaded_image)
).decode("ascii")
generated_prompt = PROMPT_BUILDERS[task_type](
base64_image=base64_image,
prompt=prompt,
output_structure=output_structure,
classes=classes,
gpt_image_detail=gpt_image_detail,
)
gpt4_prompts.append(generated_prompt)
return execute_gpt_4v_requests(
roboflow_api_key=roboflow_api_key,
openai_api_key=openai_api_key,
gpt4_prompts=gpt4_prompts,
gpt_model_version=gpt_model_version,
max_tokens=max_tokens,
temperature=temperature,
max_concurrent_requests=max_concurrent_requests,
)
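# Editor's sketch (not part of the original file): the encoding step above in
# isolation, assuming load_image yields an 8-bit numpy array.
#
#   import base64
#   import numpy as np
#   from inference.core.utils.image_utils import encode_image_to_jpeg_bytes
#
#   image = np.zeros((64, 64, 3), dtype=np.uint8)
#   b64 = base64.b64encode(encode_image_to_jpeg_bytes(image)).decode("ascii")
#   data_url = f"data:image/jpeg;base64,{b64}"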
def execute_gpt_4v_requests(
roboflow_api_key: str,
openai_api_key: str,
gpt4_prompts: List[List[dict]],
gpt_model_version: str,
max_tokens: int,
temperature: Optional[float],
max_concurrent_requests: Optional[int],
) -> List[str]:
tasks = [
partial(
execute_gpt_4v_request,
roboflow_api_key=roboflow_api_key,
openai_api_key=openai_api_key,
prompt=prompt,
gpt_model_version=gpt_model_version,
max_tokens=max_tokens,
temperature=temperature,
)
for prompt in gpt4_prompts
]
max_workers = (
max_concurrent_requests
or WORKFLOWS_REMOTE_EXECUTION_MAX_STEP_CONCURRENT_REQUESTS
)
return run_in_parallel(
tasks=tasks,
max_workers=max_workers,
)
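# Editor's sketch (not part of the original file): the fan-out above, assuming
# run_in_parallel is roughly a ThreadPoolExecutor wrapper that preserves input
# order. A minimal stand-in:
#
#   from concurrent.futures import ThreadPoolExecutor
#
#   def run_in_parallel(tasks, max_workers):
#       with ThreadPoolExecutor(max_workers=max_workers) as pool:
#           return list(pool.map(lambda task: task(), tasks))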
def _execute_proxied_openai_request(
roboflow_api_key: str,
openai_api_key: str,
prompt: List[dict],
gpt_model_version: str,
max_tokens: int,
temperature: Optional[float],
) -> str:
"""Executes OpenAI request via Roboflow proxy."""
payload = {
"model": gpt_model_version,
"messages": prompt,
"max_tokens": max_tokens,
"openai_api_key": openai_api_key,
}
if temperature is not None:
payload["temperature"] = temperature
try:
endpoint = f"{API_BASE_URL}/apiproxy/openai?api_key={roboflow_api_key}"
response = requests.post(endpoint, json=payload)
response.raise_for_status()
response_data = response.json()
return response_data["choices"][0]["message"]["content"]
except requests.exceptions.RequestException as e:
raise RuntimeError(f"Failed to connect to Roboflow proxy: {e}") from e
except (KeyError, IndexError) as e:
raise RuntimeError(
f"Invalid response structure from Roboflow proxy: {e} - Response: {response.text}"
) from e
def _execute_openai_request(
openai_api_key: str,
prompt: List[dict],
gpt_model_version: str,
max_tokens: int,
temperature: Optional[float],
) -> str:
"""Executes OpenAI request directly."""
temp_value = temperature if temperature is not None else NOT_GIVEN
try:
client = OpenAI(api_key=openai_api_key)
response = client.chat.completions.create(
model=gpt_model_version,
messages=prompt,
max_tokens=max_tokens,
temperature=temp_value,
)
return response.choices[0].message.content
except Exception as e:
raise RuntimeError(f"OpenAI API request failed: {e}") from e
def execute_gpt_4v_request(
roboflow_api_key: str,
openai_api_key: str,
prompt: List[dict],
gpt_model_version: str,
max_tokens: int,
temperature: Optional[float],
) -> str:
if openai_api_key.startswith("rf_key:account") or openai_api_key.startswith(
"rf_key:user:"
):
return _execute_proxied_openai_request(
roboflow_api_key=roboflow_api_key,
openai_api_key=openai_api_key,
prompt=prompt,
gpt_model_version=gpt_model_version,
max_tokens=max_tokens,
temperature=temperature,
)
else:
return _execute_openai_request(
openai_api_key=openai_api_key,
prompt=prompt,
gpt_model_version=gpt_model_version,
max_tokens=max_tokens,
temperature=temperature,
)
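# Editor's note: keys prefixed with "rf_key:account" or "rf_key:user:" are
# Roboflow-managed and routed through the proxy; anything else is treated as a
# raw OpenAI key and sent directly. Hypothetical calls:
#
#   execute_gpt_4v_request(..., openai_api_key="rf_key:user:abc")  # proxied
#   execute_gpt_4v_request(..., openai_api_key="sk-...")           # direct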
def prepare_unconstrained_prompt(
base64_image: str,
prompt: str,
gpt_image_detail: str,
**kwargs,
) -> List[dict]:
return [
{
"role": "user",
"content": [
{"type": "text", "text": prompt},
{
"type": "image_url",
"image_url": {
"url": f"data:image/jpeg;base64,{base64_image}",
"detail": gpt_image_detail,
},
},
],
}
]
def prepare_classification_prompt(
base64_image: str, classes: List[str], gpt_image_detail: str, **kwargs
) -> List[dict]:
serialised_classes = ", ".join(classes)
return [
{
"role": "system",
"content": "You act as single-class classification model. You must provide reasonable predictions. "
"You are only allowed to produce JSON document in Markdown ```json [...]``` markers. "
'Expected structure of json: {"class_name": "class-name", "confidence": 0.4}. '
"`class-name` must be one of the class names defined by user. You are only allowed to return "
"single JSON document, even if there are potentially multiple classes. You are not allowed to return list.",
},
{
"role": "user",
"content": [
{
"type": "text",
"text": f"List of all classes to be recognised by model: {serialised_classes}",
},
{
"type": "image_url",
"image_url": {
"url": f"data:image/jpeg;base64,{base64_image}",
"detail": gpt_image_detail,
},
},
],
},
]
def prepare_multi_label_classification_prompt(
base64_image: str, classes: List[str], gpt_image_detail: str, **kwargs
) -> List[dict]:
serialised_classes = ", ".join(classes)
return [
{
"role": "system",
"content": "You act as multi-label classification model. You must provide reasonable predictions. "
"You are only allowed to produce JSON document in Markdown ```json``` markers. "
'Expected structure of json: {"predicted_classes": [{"class": "class-name-1", "confidence": 0.9}, '
'{"class": "class-name-2", "confidence": 0.7}]}. '
"`class-name-X` must be one of the class names defined by user and `confidence` is a float value in range "
"0.0-1.0 that represent how sure you are that the class is present in the image. Only return class names "
"that are visible.",
},
{
"role": "user",
"content": [
{
"type": "text",
"text": f"List of all classes to be recognised by model: {serialised_classes}",
},
{
"type": "image_url",
"image_url": {
"url": f"data:image/jpeg;base64,{base64_image}",
"detail": gpt_image_detail,
},
},
],
},
]
def prepare_vqa_prompt(
base64_image: str, prompt: str, gpt_image_detail: str, **kwargs
) -> List[dict]:
return [
{
"role": "system",
"content": "You act as Visual Question Answering model. Your task is to provide answer to question"
"submitted by user. If this is open-question - answer with few sentences, for ABCD question, "
"return only the indicator of the answer.",
},
{
"role": "user",
"content": [
{"type": "text", "text": f"Question: {prompt}"},
{
"type": "image_url",
"image_url": {
"url": f"data:image/jpeg;base64,{base64_image}",
"detail": gpt_image_detail,
},
},
],
},
]
def prepare_ocr_prompt(
base64_image: str, gpt_image_detail: str, **kwargs
) -> List[dict]:
return [
{
"role": "system",
"content": "You act as OCR model. Your task is to read text from the image and return it in "
"paragraphs representing the structure of texts in the image. You should only return "
"recognised text, nothing else.",
},
{
"role": "user",
"content": [
{
"type": "image_url",
"image_url": {
"url": f"data:image/jpeg;base64,{base64_image}",
"detail": gpt_image_detail,
},
},
],
},
]
def prepare_caption_prompt(
base64_image: str, gpt_image_detail: str, short_description: bool, **kwargs
) -> List[dict]:
caption_detail_level = "Caption should be short."
if not short_description:
caption_detail_level = "Caption should be extensive."
return [
{
"role": "system",
"content": f"You act as image caption model. Your task is to provide description of the image. "
f"{caption_detail_level}",
},
{
"role": "user",
"content": [
{
"type": "image_url",
"image_url": {
"url": f"data:image/jpeg;base64,{base64_image}",
"detail": gpt_image_detail,
},
},
],
},
]
def prepare_structured_answering_prompt(
base64_image: str, output_structure: Dict[str, str], gpt_image_detail: str, **kwargs
) -> List[dict]:
output_structure_serialised = json.dumps(output_structure, indent=4)
return [
{
"role": "system",
"content": "You are supposed to produce responses in JSON wrapped in Markdown markers: "
"```json\nyour-response\n```. User is to provide you dictionary with keys and values. "
"Each key must be present in your response. Values in user dictionary represent "
"descriptions for JSON fields to be generated. Provide only JSON Markdown in response.",
},
{
"role": "user",
"content": [
{
"type": "text",
"text": f"Specification of requirements regarding output fields: \n"
f"{output_structure_serialised}",
},
{
"type": "image_url",
"image_url": {
"url": f"data:image/jpeg;base64,{base64_image}",
"detail": gpt_image_detail,
},
},
],
},
]
PROMPT_BUILDERS = {
"unconstrained": prepare_unconstrained_prompt,
"ocr": prepare_ocr_prompt,
"visual-question-answering": prepare_vqa_prompt,
"caption": partial(prepare_caption_prompt, short_description=True),
"detailed-caption": partial(prepare_caption_prompt, short_description=False),
"classification": prepare_classification_prompt,
"multi-label-classification": prepare_multi_label_classification_prompt,
"structured-answering": prepare_structured_answering_prompt,
}
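# Editor's sketch (not part of the original file): how run_gpt_4v_llm_prompting
# uses the dispatch table above; `b64` stands in for a base64-encoded JPEG.
#
#   messages = PROMPT_BUILDERS["multi-label-classification"](
#       base64_image=b64,
#       prompt=None,
#       output_structure=None,
#       classes=["cat", "dog"],
#       gpt_image_detail="auto",
#   )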


@@ -18,6 +18,139 @@ if TYPE_CHECKING:
    from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from typing import List, Union
class ImportCollector(cst.CSTVisitor):
    """Visitor that collects all import statements in a module."""

    def __init__(self):
        super().__init__()
        self.imports = []

    def visit_Import(self, node: cst.Import) -> None:
        self.imports.append(node)

    def visit_ImportFrom(self, node: cst.ImportFrom) -> None:
        self.imports.append(node)

class GlobalStatementCollector(cst.CSTVisitor):
    """Visitor that collects all global statements (excluding imports and functions/classes)."""

    def __init__(self):
        super().__init__()
        self.global_statements = []
        self.in_function_or_class = False

    def visit_ClassDef(self, node: cst.ClassDef) -> bool:
        # Don't visit inside classes
        self.in_function_or_class = True
        return False

    def leave_ClassDef(self, original_node: cst.ClassDef) -> None:
        self.in_function_or_class = False

    def visit_FunctionDef(self, node: cst.FunctionDef) -> bool:
        # Don't visit inside functions
        self.in_function_or_class = True
        return False

    def leave_FunctionDef(self, original_node: cst.FunctionDef) -> None:
        self.in_function_or_class = False

    def visit_SimpleStatementLine(self, node: cst.SimpleStatementLine) -> None:
        if not self.in_function_or_class:
            for statement in node.body:
                # Skip imports; collect every other top-level simple statement
                if not isinstance(statement, (cst.Import, cst.ImportFrom)):
                    self.global_statements.append(node)
                    break

class LastImportFinder(cst.CSTVisitor):
    """Finds the position of the last import statement in the module."""

    def __init__(self):
        super().__init__()
        self.last_import_line = 0
        self.current_line = 0

    def visit_SimpleStatementLine(self, node: cst.SimpleStatementLine) -> None:
        self.current_line += 1
        for statement in node.body:
            if isinstance(statement, (cst.Import, cst.ImportFrom)):
                self.last_import_line = self.current_line

class ImportInserter(cst.CSTTransformer):
    """Transformer that inserts global statements after the last import."""

    def __init__(self, global_statements: List[cst.SimpleStatementLine], last_import_line: int):
        super().__init__()
        self.global_statements = global_statements
        self.last_import_line = last_import_line
        self.current_line = 0
        self.inserted = False

    def leave_SimpleStatementLine(
        self, original_node: cst.SimpleStatementLine, updated_node: cst.SimpleStatementLine
    ) -> Union[cst.SimpleStatementLine, cst.FlattenSentinel]:
        self.current_line += 1
        # If we're right after the last import and haven't inserted yet, splice
        # the global statements in after this line. A FlattenSentinel keeps the
        # surrounding module intact; wrapping the node in a new cst.Module here
        # would corrupt the tree.
        if self.current_line == self.last_import_line and not self.inserted:
            self.inserted = True
            return cst.FlattenSentinel([updated_node, *self.global_statements])
        return updated_node

    def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module:
        # If there were no imports, add at the beginning of the module
        if self.last_import_line == 0 and not self.inserted:
            updated_body = list(updated_node.body)
            for stmt in reversed(self.global_statements):
                updated_body.insert(0, stmt)
            return updated_node.with_changes(body=updated_body)
        return updated_node

def extract_global_statements(source_code: str) -> List[cst.SimpleStatementLine]:
    """Extract global statements from source code."""
    module = cst.parse_module(source_code)
    collector = GlobalStatementCollector()
    module.visit(collector)
    return collector.global_statements

def find_last_import_line(target_code: str) -> int:
    """Find the line number of the last import statement."""
    module = cst.parse_module(target_code)
    finder = LastImportFinder()
    module.visit(finder)
    return finder.last_import_line

def merge_globals(source_code: str, target_code: str) -> str:
    """Merge global statements from source into target just after the imports."""
    global_statements = extract_global_statements(source_code)
    last_import_line = find_last_import_line(target_code)
    target_module = cst.parse_module(target_code)
    transformer = ImportInserter(global_statements, last_import_line)
    modified_module = target_module.visit(transformer)
    return modified_module.code
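# Editor's sketch (not part of this commit): intended behavior of merge_globals
# on a small example (exact blank-line handling depends on libcst rendering).
#
#   src = "import os\nCACHE = {}\n"
#   dst = "import sys\n\ndef f():\n    return CACHE\n"
#   print(merge_globals(src, dst))
#   # import sys
#   # CACHE = {}
#   #
#   # def f():
#   #     return CACHE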
class FutureAliasedImportTransformer(cst.CSTTransformer): class FutureAliasedImportTransformer(cst.CSTTransformer):
def leave_ImportFrom( def leave_ImportFrom(
@@ -47,6 +180,20 @@ def add_needed_imports_from_module(
    helper_functions: list[FunctionSource] | None = None,
    helper_functions_fqn: set[str] | None = None,
) -> str:
    """Add all needed and used source module code imports to the destination module code, and return it."""
    # Merge global statements from the source module into the destination
    # just after its imports, so globals the optimized code relies on survive.
    global_statements = extract_global_statements(src_module_code)
    last_import_line = find_last_import_line(dst_module_code)
    target_module = cst.parse_module(dst_module_code)
    transformer = ImportInserter(global_statements, last_import_line)
    modified_module = target_module.visit(transformer)
    dst_module_code = modified_module.code
    src_module_code = delete___future___aliased_imports(src_module_code)
    if not helper_functions_fqn:


@@ -12,7 +12,7 @@ from codeflash.code_utils.code_replacer import (
    replace_functions_in_file,
)
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
-from codeflash.models.models import FunctionParent
+from codeflash.models.models import CodeOptimizationContext, FunctionParent
from codeflash.optimization.function_optimizer import FunctionOptimizer
from codeflash.verification.verification_utils import TestConfig
@@ -34,6 +34,48 @@ class FakeFunctionSource:
    jedi_definition: JediDefinition
class Args:
    disable_imports_sorting = True
    formatter_cmds = ["disabled"]

def test_code_replacement_global_statements():
    optimized_code = """import numpy as np
inconsequential_var = '123'
def sorter(arr):
    return arr.sort()"""
    code_path = (Path(__file__).parent.resolve() / "../code_to_optimize/bubble_sort_optimized.py").resolve()
    original_code_str = (Path(__file__).parent.resolve() / "../code_to_optimize/bubble_sort.py").read_text(
        encoding="utf-8"
    )
    code_path.write_text(original_code_str, encoding="utf-8")
    tests_root = Path(__file__).parent.resolve() / "../code_to_optimize/tests/pytest/"
    project_root_path = (Path(__file__).parent / "..").resolve()
    func = FunctionToOptimize(function_name="sorter", parents=[], file_path=code_path)
    test_config = TestConfig(
        tests_root=tests_root,
        tests_project_rootdir=project_root_path,
        project_root_path=project_root_path,
        test_framework="pytest",
        pytest_cmd="pytest",
    )
    func_optimizer = FunctionOptimizer(function_to_optimize=func, test_cfg=test_config)
    code_context: CodeOptimizationContext = func_optimizer.get_code_optimization_context().unwrap()
    original_helper_code: dict[Path, str] = {}
    helper_function_paths = {hf.file_path for hf in code_context.helper_functions}
    for helper_function_path in helper_function_paths:
        with helper_function_path.open(encoding="utf8") as f:
            helper_code = f.read()
            original_helper_code[helper_function_path] = helper_code
    func_optimizer.args = Args()
    func_optimizer.replace_function_and_helpers_with_optimized_code(
        code_context=code_context, optimized_code=optimized_code
    )
    final_output = code_path.read_text(encoding="utf-8")
    assert "inconsequential_var = '123'" in final_output
    code_path.unlink(missing_ok=True)
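# Editor's note (not part of this commit): to run just this test, something
# like the following works (test file path assumed):
#
#   pytest tests/test_code_replacement.py -k test_code_replacement_global_statements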
def test_test_libcst_code_replacement() -> None:
    optim_code = """import libcst as cst
from typing import Optional
@@ -74,7 +116,7 @@ print("Hello world")
"""
    function_name: str = "NewClass.new_function"
-    preexisting_objects: set[tuple[str, tuple[FunctionParent,...]]] = find_preexisting_objects(original_code)
+    preexisting_objects: set[tuple[str, tuple[FunctionParent, ...]]] = find_preexisting_objects(original_code)
    new_code: str = replace_functions_and_add_imports(
        source_code=original_code,
        function_names=[function_name],
@@ -135,7 +177,7 @@ print("Hello world")
"""
    function_name: str = "NewClass.new_function"
-    preexisting_objects: set[tuple[str, tuple[FunctionParent,...]]] = find_preexisting_objects(original_code)
+    preexisting_objects: set[tuple[str, tuple[FunctionParent, ...]]] = find_preexisting_objects(original_code)
    new_code: str = replace_functions_and_add_imports(
        source_code=original_code,
        function_names=[function_name],
@@ -196,7 +238,7 @@ print("Salut monde")
"""
    function_names: list[str] = ["other_function"]
-    preexisting_objects: set[tuple[str, tuple[FunctionParent,...]]] = find_preexisting_objects(original_code)
+    preexisting_objects: set[tuple[str, tuple[FunctionParent, ...]]] = find_preexisting_objects(original_code)
    new_code: str = replace_functions_and_add_imports(
        source_code=original_code,
        function_names=function_names,
@@ -260,7 +302,7 @@ print("Salut monde")
"""
    function_names: list[str] = ["yet_another_function", "other_function"]
-    preexisting_objects: set[tuple[str, tuple[FunctionParent,...]]] = find_preexisting_objects(original_code)
+    preexisting_objects: set[tuple[str, tuple[FunctionParent, ...]]] = find_preexisting_objects(original_code)
    new_code: str = replace_functions_and_add_imports(
        source_code=original_code,
        function_names=function_names,
@@ -313,7 +355,7 @@ def supersort(doink):
"""
    function_names: list[str] = ["sorter_deps"]
-    preexisting_objects: set[tuple[str, tuple[FunctionParent,...]]] = find_preexisting_objects(original_code)
+    preexisting_objects: set[tuple[str, tuple[FunctionParent, ...]]] = find_preexisting_objects(original_code)
    new_code: str = replace_functions_and_add_imports(
        source_code=original_code,
        function_names=function_names,
@@ -591,7 +633,7 @@ class CacheConfig(BaseConfig):
    )
"""
    function_names: list[str] = ["CacheSimilarityEvalConfig.from_config"]
-    preexisting_objects: set[tuple[str, tuple[FunctionParent,...]]] = find_preexisting_objects(original_code)
+    preexisting_objects: set[tuple[str, tuple[FunctionParent, ...]]] = find_preexisting_objects(original_code)
    new_code: str = replace_functions_and_add_imports(
        source_code=original_code,
@@ -662,7 +704,7 @@ def test_test_libcst_code_replacement8() -> None:
        return np.sum(a != b) / a.size
'''
    function_names: list[str] = ["_EmbeddingDistanceChainMixin._hamming_distance"]
-    preexisting_objects: set[tuple[str, tuple[FunctionParent,...]]] = find_preexisting_objects(original_code)
+    preexisting_objects: set[tuple[str, tuple[FunctionParent, ...]]] = find_preexisting_objects(original_code)
    new_code: str = replace_functions_and_add_imports(
        source_code=original_code,
        function_names=function_names,
@@ -715,7 +757,7 @@ def totally_new_function(value: Optional[str]):
    print("Hello world")
"""
    function_name: str = "NewClass.__init__"
-    preexisting_objects: set[tuple[str, tuple[FunctionParent,...]]] = find_preexisting_objects(original_code)
+    preexisting_objects: set[tuple[str, tuple[FunctionParent, ...]]] = find_preexisting_objects(original_code)
    new_code: str = replace_functions_and_add_imports(
        source_code=original_code,
        function_names=[function_name],
@@ -811,7 +853,7 @@ def test_code_replacement11() -> None:
    function_name: str = "Fu.foo"
    parents = (FunctionParent("Fu", "ClassDef"),)
-    preexisting_objects: set[tuple[str, tuple[FunctionParent,...]]] = {("foo", parents), ("real_bar", parents)}
+    preexisting_objects: set[tuple[str, tuple[FunctionParent, ...]]] = {("foo", parents), ("real_bar", parents)}
    new_code: str = replace_functions_in_file(
        source_code=original_code,
        original_function_names=[function_name],
@@ -850,7 +892,7 @@ def test_code_replacement12() -> None:
        pass
'''
-    preexisting_objects: set[tuple[str, tuple[FunctionParent,...]]] = []
+    preexisting_objects: set[tuple[str, tuple[FunctionParent, ...]]] = []
    new_code: str = replace_functions_in_file(
        source_code=original_code,
        original_function_names=["Fu.real_bar"],
@@ -887,7 +929,7 @@ def test_test_libcst_code_replacement13() -> None:
"""
    function_names: list[str] = ["yet_another_function", "other_function"]
-    preexisting_objects: set[tuple[str, tuple[FunctionParent,...]]] = []
+    preexisting_objects: set[tuple[str, tuple[FunctionParent, ...]]] = []
    new_code: str = replace_functions_and_add_imports(
        source_code=original_code,
        function_names=function_names,
@@ -1098,8 +1140,8 @@ class TestResults(BaseModel):
    )
    assert (
        new_code
        == """from __future__ import annotations
import sys
from codeflash.verification.comparator import comparator
from enum import Enum
@@ -1274,7 +1316,7 @@ def cosine_similarity_top_k(
    return ret_idxs, scores
'''
-    preexisting_objects: set[tuple[str, tuple[FunctionParent,...]]] = find_preexisting_objects(original_code)
+    preexisting_objects: set[tuple[str, tuple[FunctionParent, ...]]] = find_preexisting_objects(original_code)
    helper_functions = [
        FakeFunctionSource(
@@ -1304,8 +1346,8 @@ def cosine_similarity_top_k(
        project_root_path=Path(__file__).parent.parent.resolve(),
    )
    assert (
        new_code
        == '''import numpy as np
from pydantic.dataclasses import dataclass
from typing import List, Optional, Tuple, Union
@dataclass(config=dict(arbitrary_types_allowed=True))
@@ -1363,8 +1405,8 @@ def cosine_similarity_top_k(
    )
    assert (
        new_helper_code
        == '''import numpy as np
from pydantic.dataclasses import dataclass
from typing import List, Optional, Tuple, Union
@dataclass(config=dict(arbitrary_types_allowed=True))
@@ -1575,7 +1617,7 @@ print("Hello world")
        "NewClass.new_function2",
        "NestedClass.nested_function",
    ]  # Nested classes should be ignored, even if provided as target
-    preexisting_objects: set[tuple[str, tuple[FunctionParent,...]]] = find_preexisting_objects(original_code)
+    preexisting_objects: set[tuple[str, tuple[FunctionParent, ...]]] = find_preexisting_objects(original_code)
    new_code: str = replace_functions_and_add_imports(
        source_code=original_code,
        function_names=function_names,
@@ -1611,7 +1653,7 @@ print("Hello world")
"""
    function_names: list[str] = ["NewClass.__init__", "NewClass.__call__", "NewClass.new_function2"]
-    preexisting_objects: set[tuple[str, tuple[FunctionParent,...]]] = find_preexisting_objects(original_code)
+    preexisting_objects: set[tuple[str, tuple[FunctionParent, ...]]] = find_preexisting_objects(original_code)
    new_code: str = replace_functions_and_add_imports(
        source_code=original_code,
        function_names=function_names,


@@ -1,147 +0,0 @@
from pathlib import Path
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.models.models import CodeOptimizationContext
from codeflash.optimization.function_optimizer import FunctionOptimizer
from codeflash.verification.verification_utils import TestConfig
class Args:
disable_imports_sorting = True
formatter_cmds = ["disabled"]
def test_code_replacement_pr():
optimized_code = """from typing import List, Optional
import requests
from inference.core.env import API_BASE_URL
from openai import OpenAI
from openai._types import NOT_GIVEN
# Create a global requests session to reuse connections
sess = requests.Session()
# Create a cache for OpenAI clients to avoid recreating them frequently
openai_clients = {}
def _execute_proxied_openai_request(
roboflow_api_key: str,
openai_api_key: str,
prompt: List[dict],
gpt_model_version: str,
max_tokens: int,
temperature: Optional[float],
) -> str:
\"\"\"Executes OpenAI request via Roboflow proxy.\"\"\"
payload = {
\"model\": gpt_model_version,
\"messages\": prompt,
\"max_tokens\": max_tokens,
\"openai_api_key\": openai_api_key,
}
if temperature is not None:
payload[\"temperature\"] = temperature
try:
endpoint = f\"{API_BASE_URL}/apiproxy/openai?api_key={roboflow_api_key}\"
# Use global session for requests
response = sess.post(endpoint, json=payload)
response.raise_for_status()
response_data = response.json()
return response_data[\"choices\"][0][\"message\"][\"content\"]
except requests.exceptions.RequestException as e:
raise RuntimeError(f\"Failed to connect to Roboflow proxy: {e}\") from e
except (KeyError, IndexError) as e:
raise RuntimeError(
f\"Invalid response structure from Roboflow proxy: {e} - Response: {response.text}\"
) from e
def _execute_openai_request(
openai_api_key: str,
prompt: List[dict],
gpt_model_version: str,
max_tokens: int,
temperature: Optional[float],
) -> str:
\"\"\"Executes OpenAI request directly.\"\"\"
temp_value = temperature if temperature is not None else NOT_GIVEN
try:
# Cache OpenAI client to avoid creating a new one each time
if openai_api_key not in openai_clients:
openai_clients[openai_api_key] = OpenAI(api_key=openai_api_key)
client = openai_clients[openai_api_key]
response = client.chat.completions.create(
model=gpt_model_version,
messages=prompt,
max_tokens=max_tokens,
temperature=temp_value,
)
return response.choices[0].message.content
except Exception as e:
raise RuntimeError(f\"OpenAI API request failed: {e}\") from e
def execute_gpt_4v_request(
roboflow_api_key: str,
openai_api_key: str,
prompt: List[dict],
gpt_model_version: str,
max_tokens: int,
temperature: Optional[float],
) -> str:
if openai_api_key.startswith(\"rf_key:account\") or openai_api_key.startswith(
\"rf_key:user:\"
):
return _execute_proxied_openai_request(
roboflow_api_key=roboflow_api_key,
openai_api_key=openai_api_key,
prompt=prompt,
gpt_model_version=gpt_model_version,
max_tokens=max_tokens,
temperature=temperature,
)
else:
return _execute_openai_request(
openai_api_key=openai_api_key,
prompt=prompt,
gpt_model_version=gpt_model_version,
max_tokens=max_tokens,
temperature=temperature,
)
"""
code_path = (Path(__file__).parent.resolve() / "../code_to_optimize/roboflow.py").resolve()
original_code_str = (Path(__file__).parent.resolve() / "../code_to_optimize/roboflow_original.py").read_text(encoding="utf-8")
code_path.write_text(original_code_str, encoding="utf-8")
tests_root = Path(__file__).parent.resolve() / "../code_to_optimize/tests/pytest/"
project_root_path = (Path(__file__).parent / "..").resolve()
func = FunctionToOptimize(function_name="execute_gpt_4v_request", parents=[], file_path=code_path)
test_config = TestConfig(
tests_root=tests_root,
tests_project_rootdir=project_root_path,
project_root_path=project_root_path,
test_framework="pytest",
pytest_cmd="pytest",
)
func_optimizer = FunctionOptimizer(function_to_optimize=func, test_cfg=test_config)
code_context: CodeOptimizationContext = func_optimizer.get_code_optimization_context().unwrap()
original_helper_code: dict[Path, str] = {}
helper_function_paths = {hf.file_path for hf in code_context.helper_functions}
for helper_function_path in helper_function_paths:
with helper_function_path.open(encoding="utf8") as f:
helper_code = f.read()
original_helper_code[helper_function_path] = helper_code
func_optimizer.args = Args()
func_optimizer.replace_function_and_helpers_with_optimized_code(
code_context=code_context, optimized_code=optimized_code
)
new_code, new_helper_code = func_optimizer.reformat_code_and_helpers(
code_context.helper_functions, func.file_path, func_optimizer.function_to_optimize_source_code
)
original_code_combined = original_helper_code.copy()
original_code_combined[func.file_path] = func_optimizer.function_to_optimize_source_code
new_code_combined = new_helper_code.copy()
new_code_combined[func.file_path] = new_code
final_output = code_path.read_text(encoding="utf-8")
assert "openai_clients = {}" in final_output
code_path.unlink(missing_ok=True)


@@ -1,150 +0,0 @@
from pathlib import Path
from codeflash.discovery.discover_unit_tests import discover_unit_tests
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.models.models import CodeOptimizationContext
from codeflash.optimization.function_optimizer import FunctionOptimizer
from codeflash.verification.verification_utils import TestConfig
class Args:
disable_imports_sorting = True
formatter_cmds = ["disabled"]
def test_code_replacement_pr():
optimized_code = """from typing import List, Optional
import requests
from inference.core.env import API_BASE_URL
from openai import OpenAI
from openai._types import NOT_GIVEN
# Create a global requests session to reuse connections
sess = requests.Session()
# Create a cache for OpenAI clients to avoid recreating them frequently
openai_clients = {}
def _execute_proxied_openai_request(
roboflow_api_key: str,
openai_api_key: str,
prompt: List[dict],
gpt_model_version: str,
max_tokens: int,
temperature: Optional[float],
) -> str:
\"\"\"Executes OpenAI request via Roboflow proxy.\"\"\"
payload = {
\"model\": gpt_model_version,
\"messages\": prompt,
\"max_tokens\": max_tokens,
\"openai_api_key\": openai_api_key,
}
if temperature is not None:
payload[\"temperature\"] = temperature
try:
endpoint = f\"{API_BASE_URL}/apiproxy/openai?api_key={roboflow_api_key}\"
# Use global session for requests
response = sess.post(endpoint, json=payload)
response.raise_for_status()
response_data = response.json()
return response_data[\"choices\"][0][\"message\"][\"content\"]
except requests.exceptions.RequestException as e:
raise RuntimeError(f\"Failed to connect to Roboflow proxy: {e}\") from e
except (KeyError, IndexError) as e:
raise RuntimeError(
f\"Invalid response structure from Roboflow proxy: {e} - Response: {response.text}\"
) from e
def _execute_openai_request(
openai_api_key: str,
prompt: List[dict],
gpt_model_version: str,
max_tokens: int,
temperature: Optional[float],
) -> str:
\"\"\"Executes OpenAI request directly.\"\"\"
temp_value = temperature if temperature is not None else NOT_GIVEN
try:
# Cache OpenAI client to avoid creating a new one each time
if openai_api_key not in openai_clients:
openai_clients[openai_api_key] = OpenAI(api_key=openai_api_key)
client = openai_clients[openai_api_key]
response = client.chat.completions.create(
model=gpt_model_version,
messages=prompt,
max_tokens=max_tokens,
temperature=temp_value,
)
return response.choices[0].message.content
except Exception as e:
raise RuntimeError(f\"OpenAI API request failed: {e}\") from e
def execute_gpt_4v_request(
roboflow_api_key: str,
openai_api_key: str,
prompt: List[dict],
gpt_model_version: str,
max_tokens: int,
temperature: Optional[float],
) -> str:
if openai_api_key.startswith(\"rf_key:account\") or openai_api_key.startswith(
\"rf_key:user:\"
):
return _execute_proxied_openai_request(
roboflow_api_key=roboflow_api_key,
openai_api_key=openai_api_key,
prompt=prompt,
gpt_model_version=gpt_model_version,
max_tokens=max_tokens,
temperature=temperature,
)
else:
return _execute_openai_request(
openai_api_key=openai_api_key,
prompt=prompt,
gpt_model_version=gpt_model_version,
max_tokens=max_tokens,
temperature=temperature,
)
"""
code_path = (Path(__file__).parent.resolve() / "../code_to_optimize/roboflow.py").resolve()
original_code_str = (Path(__file__).parent.resolve() / "../code_to_optimize/roboflow_original.py").read_text(encoding="utf-8")
code_path.write_text(original_code_str, encoding="utf-8")
tests_root = Path("/Users/codeflash/Downloads/codeflash-dev/codeflash/code_to_optimize/tests/pytest/roboflow_tests/tests/workflows/unit_tests/core_steps/models/foundation")
# tests/workflows/unit_tests/core_steps/models/foundation/test_openai.py
project_root_path = (Path(__file__).parent / "..").resolve()
func = FunctionToOptimize(function_name="execute_gpt_4v_request", parents=[], file_path=code_path)
test_config = TestConfig(
tests_root=tests_root,
tests_project_rootdir=project_root_path,
project_root_path=project_root_path,
test_framework="pytest",
pytest_cmd="pytest",
)
discovered_tests = discover_unit_tests(test_config)
func_optimizer = FunctionOptimizer(function_to_optimize=func, test_cfg=test_config)
code_context: CodeOptimizationContext = func_optimizer.get_code_optimization_context().unwrap()
original_helper_code: dict[Path, str] = {}
helper_function_paths = {hf.file_path for hf in code_context.helper_functions}
for helper_function_path in helper_function_paths:
with helper_function_path.open(encoding="utf8") as f:
helper_code = f.read()
original_helper_code[helper_function_path] = helper_code
func_optimizer.args = Args()
func_optimizer.replace_function_and_helpers_with_optimized_code(
code_context=code_context, optimized_code=optimized_code
)
# new_code, new_helper_code = func_optimizer.reformat_code_and_helpers(
# code_context.helper_functions, func.file_path, func_optimizer.function_to_optimize_source_code
# )
# original_code_combined = original_helper_code.copy()
# original_code_combined[func.file_path] = func_optimizer.function_to_optimize_source_code
# new_code_combined = new_helper_code.copy()
# new_code_combined[func.file_path] = new_code
final_output = code_path.read_text(encoding="utf-8")
assert "openai_clients = {}" in final_output
code_path.unlink(missing_ok=True)