add retrying to our aiservice requests
This commit is contained in:
parent
cbe5fcc989
commit
9bda7bf470
3 changed files with 58 additions and 51 deletions
|
|
@ -6,6 +6,7 @@ import platform
|
|||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import requests
|
||||
import stamina
|
||||
from pydantic.dataclasses import dataclass
|
||||
from pydantic.json import pydantic_encoder
|
||||
|
||||
|
|
@ -25,22 +26,60 @@ if TYPE_CHECKING:
|
|||
|
||||
@dataclass(frozen=True)
|
||||
class OptimizedCandidate:
|
||||
"""Optimized candidate, containing the optimized source code, explanation, and optimization ID."""
|
||||
|
||||
source_code: str
|
||||
explanation: str
|
||||
optimization_id: str
|
||||
|
||||
|
||||
ph_events = {
|
||||
"cli-optimize-error-caught": ("/optimize", "Error generating optimized candidates"),
|
||||
"cli-optimize-error-response": ("/optimize", "Error generating optimized candidates"),
|
||||
"cli-testgen-error-caught": ("/testgen", "Error generating tests"),
|
||||
"cli-testgen-error-response": ("/testgen", "Error generating tests"),
|
||||
None: ("/log_features", "Error logging features"),
|
||||
}
|
||||
|
||||
|
||||
def stamina_on_error_ph_event(exc: Exception) -> bool:
|
||||
"""Handle errors by sending events to PostHog before retrying."""
|
||||
if isinstance(exc, requests.HTTPError):
|
||||
try:
|
||||
if exc.request and exc.request.url:
|
||||
endpoint = exc.request.url.split("/ai")[-1]
|
||||
for event_key, (event_endpoint, event_message) in ph_events.items():
|
||||
if endpoint.startswith(event_endpoint):
|
||||
if event_key:
|
||||
ph(
|
||||
event_key,
|
||||
{"response_status_code": exc.response.status_code, "error": exc.response.text},
|
||||
)
|
||||
# logger.exception(f"{event_message}: {exc.response.status_code} - {exc.response.text}")
|
||||
logger.info(f"{event_message}: {exc.response.status_code} - {exc.response.text}")
|
||||
break
|
||||
except (AttributeError, KeyError) as e:
|
||||
logger.error(f"Error reporting to ph: {e}")
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
class AiServiceClient:
|
||||
"""Client for interacting with the AI service."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize the AI service client with base URL and headers."""
|
||||
self.base_url = self.get_aiservice_base_url()
|
||||
self.headers = {"Authorization": f"Bearer {get_codeflash_api_key()}", "Connection": "close"}
|
||||
|
||||
def get_aiservice_base_url(self) -> str:
|
||||
"""Get the base URL for the AI service based on the environment."""
|
||||
if os.environ.get("CODEFLASH_AIS_SERVER", default="prod").lower() == "local":
|
||||
logger.info("Using local AI Service at http://localhost:8000")
|
||||
return "http://localhost:8000"
|
||||
return "https://app.codeflash.ai"
|
||||
|
||||
@stamina.retry(on=stamina_on_error_ph_event, wait_initial=0.5, attempts=5)
|
||||
def make_ai_service_request(
|
||||
self, endpoint: str, method: str = "POST", payload: dict[str, Any] | None = None, timeout: float | None = None
|
||||
) -> requests.Response:
|
||||
|
|
@ -59,7 +98,7 @@ class AiServiceClient:
|
|||
response = requests.post(url, data=json_payload, headers=headers, timeout=timeout)
|
||||
else:
|
||||
response = requests.get(url, headers=self.headers, timeout=timeout)
|
||||
# response.raise_for_status() # Will raise an HTTPError if the HTTP request returned an unsuccessful status code
|
||||
response.raise_for_status() # Will raise an HTTPError if the HTTP request returned an unsuccessful status code
|
||||
return response
|
||||
|
||||
def optimize_python_code(
|
||||
|
|
@ -91,33 +130,18 @@ class AiServiceClient:
|
|||
}
|
||||
logger.info("Generating optimized candidates ...")
|
||||
console.rule()
|
||||
try:
|
||||
response = self.make_ai_service_request("/optimize", payload=payload, timeout=600)
|
||||
except requests.exceptions.RequestException as e:
|
||||
logger.exception(f"Error generating optimized candidates: {e}")
|
||||
ph("cli-optimize-error-caught", {"error": str(e)})
|
||||
return []
|
||||
|
||||
if response.status_code == 200:
|
||||
optimizations_json = response.json()["optimizations"]
|
||||
logger.info(f"Generated {len(optimizations_json)} candidates.")
|
||||
console.rule()
|
||||
return [
|
||||
OptimizedCandidate(
|
||||
source_code=opt["source_code"],
|
||||
explanation=opt["explanation"],
|
||||
optimization_id=opt["optimization_id"],
|
||||
)
|
||||
for opt in optimizations_json
|
||||
]
|
||||
try:
|
||||
error = response.json()["error"]
|
||||
except Exception:
|
||||
error = response.text
|
||||
logger.error(f"Error generating optimized candidates: {response.status_code} - {error}")
|
||||
ph("cli-optimize-error-response", {"response_status_code": response.status_code, "error": error})
|
||||
response = self.make_ai_service_request("/optimize", payload=payload, timeout=600)
|
||||
|
||||
optimizations_json = response.json()["optimizations"]
|
||||
logger.info(f"Generated {len(optimizations_json)} candidates.")
|
||||
console.rule()
|
||||
return []
|
||||
return [
|
||||
OptimizedCandidate(
|
||||
source_code=opt["source_code"], explanation=opt["explanation"], optimization_id=opt["optimization_id"]
|
||||
)
|
||||
for opt in optimizations_json
|
||||
]
|
||||
|
||||
def log_results(
|
||||
self,
|
||||
|
|
@ -146,10 +170,8 @@ class AiServiceClient:
|
|||
"is_correct": is_correct,
|
||||
"codeflash_version": codeflash_version,
|
||||
}
|
||||
try:
|
||||
self.make_ai_service_request("/log_features", payload=payload, timeout=5)
|
||||
except requests.exceptions.RequestException as e:
|
||||
logger.exception(f"Error logging features: {e}")
|
||||
|
||||
self.make_ai_service_request("/log_features", payload=payload, timeout=5)
|
||||
|
||||
def generate_regression_tests(
|
||||
self,
|
||||
|
|
@ -185,6 +207,7 @@ class AiServiceClient:
|
|||
"pytest",
|
||||
"unittest",
|
||||
], f"Invalid test framework, got {test_framework} but expected 'pytest' or 'unittest'"
|
||||
|
||||
payload = {
|
||||
"source_code_being_tested": source_code_being_tested,
|
||||
"function_to_optimize": function_to_optimize,
|
||||
|
|
@ -198,28 +221,12 @@ class AiServiceClient:
|
|||
"python_version": platform.python_version(),
|
||||
"codeflash_version": codeflash_version,
|
||||
}
|
||||
try:
|
||||
response = self.make_ai_service_request("/testgen", payload=payload, timeout=600)
|
||||
except requests.exceptions.RequestException as e:
|
||||
logger.exception(f"Error generating tests: {e}")
|
||||
ph("cli-testgen-error-caught", {"error": str(e)})
|
||||
return None
|
||||
|
||||
# the timeout should be the same as the timeout for the AI service backend
|
||||
response = self.make_ai_service_request("/testgen", payload=payload, timeout=600)
|
||||
response_json = response.json()
|
||||
logger.debug(f"Generated tests for function {function_to_optimize.function_name}")
|
||||
|
||||
if response.status_code == 200:
|
||||
response_json = response.json()
|
||||
logger.debug(f"Generated tests for function {function_to_optimize.function_name}")
|
||||
return response_json["generated_tests"], response_json["instrumented_tests"]
|
||||
try:
|
||||
error = response.json()["error"]
|
||||
logger.error(f"Error generating tests: {response.status_code} - {error}")
|
||||
ph("cli-testgen-error-response", {"response_status_code": response.status_code, "error": error})
|
||||
return None
|
||||
except Exception:
|
||||
logger.error(f"Error generating tests: {response.status_code} - {response.text}")
|
||||
ph("cli-testgen-error-response", {"response_status_code": response.status_code, "error": response.text})
|
||||
return None
|
||||
return response_json["generated_tests"], response_json["instrumented_tests"]
|
||||
|
||||
|
||||
class LocalAiServiceClient(AiServiceClient):
|
||||
|
|
|
|||
|
|
@ -61,7 +61,6 @@ SPINNER_TYPES = {
|
|||
"hamburger",
|
||||
"dots",
|
||||
"squish",
|
||||
"christmas",
|
||||
"toggle13",
|
||||
"star",
|
||||
"boxBounce",
|
||||
|
|
|
|||
|
|
@ -46,6 +46,7 @@ types-requests = "^2.32.0.20241016"
|
|||
types-six = "^1.16.21.20241009"
|
||||
types-cffi = "^1.16.0.20240331"
|
||||
types-openpyxl = "^3.1.5.20241020"
|
||||
stamina = "^24.3.0"
|
||||
|
||||
[tool.poetry.group.dev]
|
||||
optional = true
|
||||
|
|
|
|||
Loading…
Reference in a new issue