add retrying to our aiservice requests

This commit is contained in:
Kevin Turcios 2024-10-27 12:52:36 -05:00
parent cbe5fcc989
commit 9bda7bf470
3 changed files with 58 additions and 51 deletions

View file

@ -6,6 +6,7 @@ import platform
from typing import TYPE_CHECKING, Any
import requests
import stamina
from pydantic.dataclasses import dataclass
from pydantic.json import pydantic_encoder
@ -25,22 +26,60 @@ if TYPE_CHECKING:
@dataclass(frozen=True)
class OptimizedCandidate:
"""Optimized candidate, containing the optimized source code, explanation, and optimization ID."""
source_code: str
explanation: str
optimization_id: str
ph_events = {
"cli-optimize-error-caught": ("/optimize", "Error generating optimized candidates"),
"cli-optimize-error-response": ("/optimize", "Error generating optimized candidates"),
"cli-testgen-error-caught": ("/testgen", "Error generating tests"),
"cli-testgen-error-response": ("/testgen", "Error generating tests"),
None: ("/log_features", "Error logging features"),
}
def stamina_on_error_ph_event(exc: Exception) -> bool:
"""Handle errors by sending events to PostHog before retrying."""
if isinstance(exc, requests.HTTPError):
try:
if exc.request and exc.request.url:
endpoint = exc.request.url.split("/ai")[-1]
for event_key, (event_endpoint, event_message) in ph_events.items():
if endpoint.startswith(event_endpoint):
if event_key:
ph(
event_key,
{"response_status_code": exc.response.status_code, "error": exc.response.text},
)
# logger.exception(f"{event_message}: {exc.response.status_code} - {exc.response.text}")
logger.info(f"{event_message}: {exc.response.status_code} - {exc.response.text}")
break
except (AttributeError, KeyError) as e:
logger.error(f"Error reporting to ph: {e}")
return True
return False
class AiServiceClient:
"""Client for interacting with the AI service."""
def __init__(self) -> None:
"""Initialize the AI service client with base URL and headers."""
self.base_url = self.get_aiservice_base_url()
self.headers = {"Authorization": f"Bearer {get_codeflash_api_key()}", "Connection": "close"}
def get_aiservice_base_url(self) -> str:
"""Get the base URL for the AI service based on the environment."""
if os.environ.get("CODEFLASH_AIS_SERVER", default="prod").lower() == "local":
logger.info("Using local AI Service at http://localhost:8000")
return "http://localhost:8000"
return "https://app.codeflash.ai"
@stamina.retry(on=stamina_on_error_ph_event, wait_initial=0.5, attempts=5)
def make_ai_service_request(
self, endpoint: str, method: str = "POST", payload: dict[str, Any] | None = None, timeout: float | None = None
) -> requests.Response:
@ -59,7 +98,7 @@ class AiServiceClient:
response = requests.post(url, data=json_payload, headers=headers, timeout=timeout)
else:
response = requests.get(url, headers=self.headers, timeout=timeout)
# response.raise_for_status() # Will raise an HTTPError if the HTTP request returned an unsuccessful status code
response.raise_for_status() # Will raise an HTTPError if the HTTP request returned an unsuccessful status code
return response
def optimize_python_code(
@ -91,33 +130,18 @@ class AiServiceClient:
}
logger.info("Generating optimized candidates ...")
console.rule()
try:
response = self.make_ai_service_request("/optimize", payload=payload, timeout=600)
except requests.exceptions.RequestException as e:
logger.exception(f"Error generating optimized candidates: {e}")
ph("cli-optimize-error-caught", {"error": str(e)})
return []
if response.status_code == 200:
optimizations_json = response.json()["optimizations"]
logger.info(f"Generated {len(optimizations_json)} candidates.")
console.rule()
return [
OptimizedCandidate(
source_code=opt["source_code"],
explanation=opt["explanation"],
optimization_id=opt["optimization_id"],
)
for opt in optimizations_json
]
try:
error = response.json()["error"]
except Exception:
error = response.text
logger.error(f"Error generating optimized candidates: {response.status_code} - {error}")
ph("cli-optimize-error-response", {"response_status_code": response.status_code, "error": error})
response = self.make_ai_service_request("/optimize", payload=payload, timeout=600)
optimizations_json = response.json()["optimizations"]
logger.info(f"Generated {len(optimizations_json)} candidates.")
console.rule()
return []
return [
OptimizedCandidate(
source_code=opt["source_code"], explanation=opt["explanation"], optimization_id=opt["optimization_id"]
)
for opt in optimizations_json
]
def log_results(
self,
@ -146,10 +170,8 @@ class AiServiceClient:
"is_correct": is_correct,
"codeflash_version": codeflash_version,
}
try:
self.make_ai_service_request("/log_features", payload=payload, timeout=5)
except requests.exceptions.RequestException as e:
logger.exception(f"Error logging features: {e}")
self.make_ai_service_request("/log_features", payload=payload, timeout=5)
def generate_regression_tests(
self,
@ -185,6 +207,7 @@ class AiServiceClient:
"pytest",
"unittest",
], f"Invalid test framework, got {test_framework} but expected 'pytest' or 'unittest'"
payload = {
"source_code_being_tested": source_code_being_tested,
"function_to_optimize": function_to_optimize,
@ -198,28 +221,12 @@ class AiServiceClient:
"python_version": platform.python_version(),
"codeflash_version": codeflash_version,
}
try:
response = self.make_ai_service_request("/testgen", payload=payload, timeout=600)
except requests.exceptions.RequestException as e:
logger.exception(f"Error generating tests: {e}")
ph("cli-testgen-error-caught", {"error": str(e)})
return None
# the timeout should be the same as the timeout for the AI service backend
response = self.make_ai_service_request("/testgen", payload=payload, timeout=600)
response_json = response.json()
logger.debug(f"Generated tests for function {function_to_optimize.function_name}")
if response.status_code == 200:
response_json = response.json()
logger.debug(f"Generated tests for function {function_to_optimize.function_name}")
return response_json["generated_tests"], response_json["instrumented_tests"]
try:
error = response.json()["error"]
logger.error(f"Error generating tests: {response.status_code} - {error}")
ph("cli-testgen-error-response", {"response_status_code": response.status_code, "error": error})
return None
except Exception:
logger.error(f"Error generating tests: {response.status_code} - {response.text}")
ph("cli-testgen-error-response", {"response_status_code": response.status_code, "error": response.text})
return None
return response_json["generated_tests"], response_json["instrumented_tests"]
class LocalAiServiceClient(AiServiceClient):

View file

@ -61,7 +61,6 @@ SPINNER_TYPES = {
"hamburger",
"dots",
"squish",
"christmas",
"toggle13",
"star",
"boxBounce",

View file

@ -46,6 +46,7 @@ types-requests = "^2.32.0.20241016"
types-six = "^1.16.21.20241009"
types-cffi = "^1.16.0.20240331"
types-openpyxl = "^3.1.5.20241020"
stamina = "^24.3.0"
[tool.poetry.group.dev]
optional = true