various improvements

- Refactor cfapi.py to handle error responses from the API and return appropriate messages.
- make the CF-API clearer when returning responses
formatting & ruff

make ruff happy
This commit is contained in:
KRRT7 2024-10-28 08:12:53 +00:00 committed by Kevin Turcios
parent 108aa2ea0d
commit 84b42aa6cf
4 changed files with 42 additions and 18 deletions

View file

@ -4,17 +4,19 @@ import json
import os
from functools import lru_cache
from pathlib import Path
from typing import Any, Dict, Optional
from typing import TYPE_CHECKING, Any, Dict, Optional
import requests
from pydantic.json import pydantic_encoder
from requests import Response
from codeflash.cli_cmds.console import logger
from codeflash.code_utils.env_utils import ensure_codeflash_api_key, get_codeflash_api_key, get_pr_number
from codeflash.code_utils.git_utils import get_repo_owner_and_name
from codeflash.github.PrComment import FileDiffContent, PrComment
if TYPE_CHECKING:
from requests import Response
if os.environ.get("CODEFLASH_CFAPI_SERVER", default="prod").lower() == "local":
CFAPI_BASE_URL = "http://localhost:3001"
logger.info(f"Using local CF API at {CFAPI_BASE_URL}.")
@ -22,8 +24,9 @@ else:
CFAPI_BASE_URL = "https://app.codeflash.ai"
def make_cfapi_request(endpoint: str, method: str, payload: Optional[Dict[str, Any]] = None) -> requests.Response:
def make_cfapi_request(endpoint: str, method: str, payload: dict[str, Any] | None = None) -> Response:
"""Make an HTTP request using the specified method, URL, headers, and JSON payload.
:param endpoint: The endpoint URL to send the request to.
:param method: The HTTP method to use ('GET', 'POST', etc.).
:param payload: Optional JSON payload to include in the POST request body.
@ -34,15 +37,16 @@ def make_cfapi_request(endpoint: str, method: str, payload: Optional[Dict[str, A
if method.upper() == "POST":
json_payload = json.dumps(payload, indent=None, default=pydantic_encoder)
cfapi_headers["Content-Type"] = "application/json"
response = requests.post(url, data=json_payload, headers=cfapi_headers)
response = requests.post(url, data=json_payload, headers=cfapi_headers, timeout=60)
else:
response = requests.get(url, headers=cfapi_headers)
response = requests.get(url, headers=cfapi_headers, timeout=60)
return response
@lru_cache(maxsize=1)
def get_user_id() -> Optional[str]:
"""Retrieve the user's userid by making a request to the /cfapi/cli-get-user endpoint.
:return: The userid or None if the request fails.
"""
if not ensure_codeflash_api_key():
@ -66,6 +70,7 @@ def suggest_changes(
trace_id: str,
) -> Response:
"""Suggest changes to a pull request.
Will make a review suggestion when possible;
or create a new dependent pull request with the suggested changes.
:param owner: The owner of the repository.
@ -86,8 +91,7 @@ def suggest_changes(
"generatedTests": generated_tests,
"traceId": trace_id,
}
response = make_cfapi_request(endpoint="/suggest-pr-changes", method="POST", payload=payload)
return response
return make_cfapi_request(endpoint="/suggest-pr-changes", method="POST", payload=payload)
def create_pr(
@ -100,7 +104,8 @@ def create_pr(
generated_tests: str,
trace_id: str,
) -> Response:
"""Create a pull request, targeting the specified branch. (usually 'main')
"""Create a pull request, targeting the specified branch. (usually 'main').
:param owner: The owner of the repository.
:param repo: The name of the repository.
:param base_branch: The base branch to target.
@ -126,6 +131,7 @@ def create_pr(
def is_github_app_installed_on_repo(owner: str, repo: str) -> bool:
"""Check if the Codeflash GitHub App is installed on the specified repository.
:param owner: The owner of the repository.
:param repo: The name of the repository.
:return: The response object.
@ -137,17 +143,31 @@ def is_github_app_installed_on_repo(owner: str, repo: str) -> bool:
return True
def get_blocklisted_functions() -> dict[str, str]:
def get_blocklisted_functions() -> dict[str, set[str]]:
"""Retrieve blocklisted functions for the current pull request.
Returns A dictionary mapping filenames to sets of blocklisted function names.
"""
pr_number = get_pr_number()
if pr_number is None:
return {}
not_found = 404
internal_server_error = 500
owner, repo = get_repo_owner_and_name()
information = {"pr_number": pr_number, "repo_owner": owner, "repo_name": repo}
try:
req = make_cfapi_request(endpoint="/verify-existing-optimizations", method="POST", payload=information)
if req.status_code == not_found:
logger.debug(req.json()["message"])
return {}
if req.status_code == internal_server_error:
logger.error(req.json()["message"])
return {}
req.raise_for_status()
content: dict[str, list[str]] = req.json()
except Exception as e:
except requests.RequestException as e:
logger.error(f"Error getting blocklisted functions: {e}")
return {}
return {Path(k).name: {v.replace("()", "") for v in values} for k, values in content.items()}

View file

@ -1,7 +1,6 @@
SPINNER_TYPES = {
"point",
"simpleDots",
"earth",
"pong",
"boxBounce2",
"bouncingBar",
@ -9,7 +8,6 @@ SPINNER_TYPES = {
"dots7",
"aesthetic",
"toggle5",
"weather",
"dots6",
"dots8",
"star2",
@ -21,7 +19,6 @@ SPINNER_TYPES = {
"toggle12",
"circle",
"bouncingBall",
"clock",
"toggle9",
"shark",
"circleHalves",
@ -64,6 +61,5 @@ SPINNER_TYPES = {
"toggle13",
"star",
"boxBounce",
"runner",
"toggle7",
}

View file

@ -357,7 +357,7 @@ async def testgen(
print("/testgen: Generating tests...")
debug_log_sensitive_data(f"Generating tests for function {data.function_to_optimize.function_name}")
debug_log_sensitive_data(f"Source code being tested: {data.source_code_being_tested}")
max_tries = 3
max_tries = 2
count = 0
for _ in range(max_tries):
if count >= max_tries:

View file

@ -35,8 +35,8 @@ export async function verifyExistingOptimizations(req, res) {
}
console.error("Error getting PR:", error)
return res
.status(500)
.send({ error: `Error getting PR ${pr_number} for ${repo_owner}/${repo_name}` })
.status(404)
.send({ message: `Error getting PR ${pr_number} for ${repo_owner}/${repo_name}` })
}
const optimizations_dict: { [key: string]: [string] } = {}
@ -52,6 +52,9 @@ export async function verifyExistingOptimizations(req, res) {
}
}
}
else {
return res.status(404).send({ message: "No optimizations found for this PR" })
}
let pr_messages
try {
@ -63,7 +66,7 @@ export async function verifyExistingOptimizations(req, res) {
} catch (error: any) {
console.error("Error getting PR messages:", error)
return res
.status(500)
.status(404)
.send({ error: `Error getting PR messages for ${repo_owner}/${repo_name}` })
}
@ -80,5 +83,10 @@ export async function verifyExistingOptimizations(req, res) {
}
}
}
if (Object.keys(optimizations_dict).length === 0) {
return res.status(404).send({ message: "No optimizations found for this PR" })
}
return res.status(200).send(optimizations_dict)
}