mirror of
https://github.com/codeflash-ai/codeflash-internal.git
synced 2026-05-04 18:25:18 +00:00
225 lines
8.7 KiB
Python
225 lines
8.7 KiB
Python
import json
|
|
import logging
|
|
import os
|
|
import platform
|
|
from typing import Any, Dict, List, Optional, Tuple
|
|
|
|
import requests
|
|
from pydantic.dataclasses import dataclass
|
|
from pydantic.json import pydantic_encoder
|
|
|
|
from codeflash.code_utils.env_utils import get_codeflash_api_key
|
|
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
|
|
from codeflash.telemetry.posthog import ph
|
|
|
|
|
|
@dataclass(frozen=True)
|
|
class OptimizedCandidate:
|
|
source_code: str
|
|
explanation: str
|
|
optimization_id: str
|
|
|
|
|
|
class AiServiceClient:
|
|
def __init__(self):
|
|
self.base_url = self.get_aiservice_base_url()
|
|
self.headers = {"Authorization": f"Bearer {get_codeflash_api_key()}"}
|
|
|
|
def get_aiservice_base_url(self) -> str:
|
|
if os.environ.get("AIS_SERVER", default="prod").lower() == "local":
|
|
logging.info("Using local AI Service at http://localhost:8000/")
|
|
return "http://localhost:8000/"
|
|
return "https://app.codeflash.ai"
|
|
|
|
def make_ai_service_request(
|
|
self,
|
|
endpoint: str,
|
|
method: str = "POST",
|
|
payload: Optional[Dict[str, Any]] = None,
|
|
timeout: float = None,
|
|
) -> requests.Response:
|
|
"""Make an API request to the given endpoint on the AI service.
|
|
:param endpoint: The endpoint to call, e.g., "/optimize".
|
|
:param method: The HTTP method to use ('GET' or 'POST').
|
|
:param payload: Optional JSON payload to include in the POST request body.
|
|
:param timeout: The timeout for the request.
|
|
:return: The response object from the API.
|
|
"""
|
|
url = f"{self.base_url}/ai{endpoint}"
|
|
if method.upper() == "POST":
|
|
json_payload = json.dumps(payload, indent=None, default=pydantic_encoder)
|
|
headers = {**self.headers, "Content-Type": "application/json"}
|
|
response = requests.post(url, data=json_payload, headers=headers, timeout=timeout)
|
|
else:
|
|
response = requests.get(url, headers=self.headers, timeout=timeout)
|
|
# response.raise_for_status() # Will raise an HTTPError if the HTTP request returned an unsuccessful status code
|
|
return response
|
|
|
|
def optimize_python_code(
|
|
self,
|
|
source_code: str,
|
|
trace_id: str,
|
|
num_variants: int = 10,
|
|
) -> List[OptimizedCandidate]:
|
|
"""Optimize the given python code for performance by making a request to the Django endpoint.
|
|
|
|
Parameters
|
|
----------
|
|
- source_code (str): The python code to optimize.
|
|
- num_variants (int): Number of optimization variants to generate. Default is 10.
|
|
|
|
Returns
|
|
-------
|
|
- List[Optimization]: A list of Optimization objects.
|
|
|
|
"""
|
|
payload = {
|
|
"source_code": source_code,
|
|
"num_variants": num_variants,
|
|
"trace_id": trace_id,
|
|
"python_version": platform.python_version(),
|
|
}
|
|
logging.info("Generating optimized candidates ...")
|
|
try:
|
|
response = self.make_ai_service_request(
|
|
"/optimize",
|
|
payload=payload,
|
|
timeout=600,
|
|
)
|
|
except requests.exceptions.RequestException as e:
|
|
logging.exception(f"Error generating optimized candidates: {e}")
|
|
ph("cli-optimize-error-caught", {"error": str(e)})
|
|
return []
|
|
|
|
if response.status_code == 200:
|
|
optimizations_json = response.json()["optimizations"]
|
|
logging.info(f"Generated {len(optimizations_json)} candidates.")
|
|
return [
|
|
OptimizedCandidate(
|
|
source_code=opt["source_code"],
|
|
explanation=opt["explanation"],
|
|
optimization_id=opt["optimization_id"],
|
|
)
|
|
for opt in optimizations_json
|
|
]
|
|
try:
|
|
error = response.json()["error"]
|
|
except Exception:
|
|
error = response.text
|
|
logging.error(f"Error generating optimized candidates: {response.status_code} - {error}")
|
|
ph(
|
|
"cli-optimize-error-response",
|
|
{"response_status_code": response.status_code, "error": error},
|
|
)
|
|
return []
|
|
|
|
def log_results(
|
|
self,
|
|
function_trace_id: str,
|
|
speedup_ratio: Optional[Dict[str, float]],
|
|
original_runtime: Optional[float],
|
|
optimized_runtime: Optional[Dict[str, float]],
|
|
is_correct: Optional[Dict[str, bool]],
|
|
experiment_id: Optional[str],
|
|
) -> None:
|
|
"""Log features to the database.
|
|
|
|
Parameters
|
|
----------
|
|
- function_trace_id (str): The UUID.
|
|
- speedup_ratio (Optional[Dict[str, float]]): The speedup.
|
|
- original_runtime (Optional[Dict[str, float]]): The original runtime.
|
|
- optimized_runtime (Optional[Dict[str, float]]): The optimized runtime.
|
|
- is_correct (Optional[Dict[str, bool]]): Whether the optimized code is correct.
|
|
|
|
"""
|
|
payload = {
|
|
"trace_id": function_trace_id,
|
|
"speedup_ratio": speedup_ratio,
|
|
"original_runtime": original_runtime,
|
|
"optimized_runtime": optimized_runtime,
|
|
"is_correct": is_correct,
|
|
"experiment_id": experiment_id,
|
|
}
|
|
try:
|
|
self.make_ai_service_request("/log_features", payload=payload, timeout=5)
|
|
except requests.exceptions.RequestException as e:
|
|
logging.exception(f"Error logging features: {e}")
|
|
|
|
def generate_regression_tests(
|
|
self,
|
|
source_code_being_tested: str,
|
|
function_to_optimize: FunctionToOptimize,
|
|
dependent_function_names: list[str],
|
|
module_path: str,
|
|
test_module_path: str,
|
|
test_framework: str,
|
|
test_timeout: int,
|
|
trace_id: str,
|
|
) -> Optional[Tuple[str, str]]:
|
|
"""Generate regression tests for the given function by making a request to the Django endpoint.
|
|
|
|
Parameters
|
|
----------
|
|
- source_code_being_tested (str): The source code of the function being tested.
|
|
- function_to_optimize (FunctionToOptimize): The function to optimize.
|
|
- dependent_function_names (list[Source]): List of dependent function names.
|
|
- module_path (str): The module path where the function is located.
|
|
- test_module_path (str): The module path for the test code.
|
|
- test_framework (str): The test framework to use, e.g., "pytest".
|
|
- test_timeout (int): The timeout for each test in seconds.
|
|
|
|
Returns
|
|
-------
|
|
- Dict[str, str] | None: The generated regression tests and instrumented tests, or None if an error occurred.
|
|
|
|
"""
|
|
assert test_framework in [
|
|
"pytest",
|
|
"unittest",
|
|
], f"Invalid test framework, got {test_framework} but expected 'pytest' or 'unittest'"
|
|
payload = {
|
|
"source_code_being_tested": source_code_being_tested,
|
|
"function_to_optimize": function_to_optimize,
|
|
"dependent_function_names": dependent_function_names,
|
|
"module_path": module_path,
|
|
"test_module_path": test_module_path,
|
|
"test_framework": test_framework,
|
|
"test_timeout": test_timeout,
|
|
"trace_id": trace_id,
|
|
"python_version": platform.python_version(),
|
|
}
|
|
try:
|
|
response = self.make_ai_service_request("/testgen", payload=payload, timeout=600)
|
|
except requests.exceptions.RequestException as e:
|
|
logging.exception(f"Error generating tests: {e}")
|
|
ph("cli-testgen-error-caught", {"error": str(e)})
|
|
return None
|
|
|
|
# the timeout should be the same as the timeout for the AI service backend
|
|
|
|
if response.status_code == 200:
|
|
response_json = response.json()
|
|
logging.info(f"Generated tests for function {function_to_optimize.function_name}")
|
|
return response_json["generated_tests"], response_json["instrumented_tests"]
|
|
else:
|
|
try:
|
|
error = response.json()["error"]
|
|
logging.error(f"Error generating tests: {response.status_code} - {error}")
|
|
ph(
|
|
"cli-testgen-error-response",
|
|
{"response_status_code": response.status_code, "error": error},
|
|
)
|
|
return None
|
|
except Exception:
|
|
logging.exception(f"Error generating tests: {response.status_code} - {response.text}")
|
|
ph(
|
|
"cli-testgen-error-response",
|
|
{"response_status_code": response.status_code, "error": response.text},
|
|
)
|
|
return None
|
|
|
|
|
|
class LocalAiServiceClient(AiServiceClient):
|
|
def get_aiservice_base_url(self) -> str:
|
|
return "http://localhost:8000/"
|