mirror of
https://github.com/codeflash-ai/codeflash-internal.git
synced 2026-05-04 18:25:18 +00:00
various improvements
- Refactor cfapi.py to handle error responses from the API and return appropriate messages. - make the CF-API clearer when returning responses formatting & ruff make ruff happy
This commit is contained in:
parent
108aa2ea0d
commit
84b42aa6cf
4 changed files with 42 additions and 18 deletions
|
|
@ -4,17 +4,19 @@ import json
|
|||
import os
|
||||
from functools import lru_cache
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional
|
||||
from typing import TYPE_CHECKING, Any, Dict, Optional
|
||||
|
||||
import requests
|
||||
from pydantic.json import pydantic_encoder
|
||||
from requests import Response
|
||||
|
||||
from codeflash.cli_cmds.console import logger
|
||||
from codeflash.code_utils.env_utils import ensure_codeflash_api_key, get_codeflash_api_key, get_pr_number
|
||||
from codeflash.code_utils.git_utils import get_repo_owner_and_name
|
||||
from codeflash.github.PrComment import FileDiffContent, PrComment
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from requests import Response
|
||||
|
||||
if os.environ.get("CODEFLASH_CFAPI_SERVER", default="prod").lower() == "local":
|
||||
CFAPI_BASE_URL = "http://localhost:3001"
|
||||
logger.info(f"Using local CF API at {CFAPI_BASE_URL}.")
|
||||
|
|
@ -22,8 +24,9 @@ else:
|
|||
CFAPI_BASE_URL = "https://app.codeflash.ai"
|
||||
|
||||
|
||||
def make_cfapi_request(endpoint: str, method: str, payload: Optional[Dict[str, Any]] = None) -> requests.Response:
|
||||
def make_cfapi_request(endpoint: str, method: str, payload: dict[str, Any] | None = None) -> Response:
|
||||
"""Make an HTTP request using the specified method, URL, headers, and JSON payload.
|
||||
|
||||
:param endpoint: The endpoint URL to send the request to.
|
||||
:param method: The HTTP method to use ('GET', 'POST', etc.).
|
||||
:param payload: Optional JSON payload to include in the POST request body.
|
||||
|
|
@ -34,15 +37,16 @@ def make_cfapi_request(endpoint: str, method: str, payload: Optional[Dict[str, A
|
|||
if method.upper() == "POST":
|
||||
json_payload = json.dumps(payload, indent=None, default=pydantic_encoder)
|
||||
cfapi_headers["Content-Type"] = "application/json"
|
||||
response = requests.post(url, data=json_payload, headers=cfapi_headers)
|
||||
response = requests.post(url, data=json_payload, headers=cfapi_headers, timeout=60)
|
||||
else:
|
||||
response = requests.get(url, headers=cfapi_headers)
|
||||
response = requests.get(url, headers=cfapi_headers, timeout=60)
|
||||
return response
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def get_user_id() -> Optional[str]:
|
||||
"""Retrieve the user's userid by making a request to the /cfapi/cli-get-user endpoint.
|
||||
|
||||
:return: The userid or None if the request fails.
|
||||
"""
|
||||
if not ensure_codeflash_api_key():
|
||||
|
|
@ -66,6 +70,7 @@ def suggest_changes(
|
|||
trace_id: str,
|
||||
) -> Response:
|
||||
"""Suggest changes to a pull request.
|
||||
|
||||
Will make a review suggestion when possible;
|
||||
or create a new dependent pull request with the suggested changes.
|
||||
:param owner: The owner of the repository.
|
||||
|
|
@ -86,8 +91,7 @@ def suggest_changes(
|
|||
"generatedTests": generated_tests,
|
||||
"traceId": trace_id,
|
||||
}
|
||||
response = make_cfapi_request(endpoint="/suggest-pr-changes", method="POST", payload=payload)
|
||||
return response
|
||||
return make_cfapi_request(endpoint="/suggest-pr-changes", method="POST", payload=payload)
|
||||
|
||||
|
||||
def create_pr(
|
||||
|
|
@ -100,7 +104,8 @@ def create_pr(
|
|||
generated_tests: str,
|
||||
trace_id: str,
|
||||
) -> Response:
|
||||
"""Create a pull request, targeting the specified branch. (usually 'main')
|
||||
"""Create a pull request, targeting the specified branch. (usually 'main').
|
||||
|
||||
:param owner: The owner of the repository.
|
||||
:param repo: The name of the repository.
|
||||
:param base_branch: The base branch to target.
|
||||
|
|
@ -126,6 +131,7 @@ def create_pr(
|
|||
|
||||
def is_github_app_installed_on_repo(owner: str, repo: str) -> bool:
|
||||
"""Check if the Codeflash GitHub App is installed on the specified repository.
|
||||
|
||||
:param owner: The owner of the repository.
|
||||
:param repo: The name of the repository.
|
||||
:return: The response object.
|
||||
|
|
@ -137,17 +143,31 @@ def is_github_app_installed_on_repo(owner: str, repo: str) -> bool:
|
|||
return True
|
||||
|
||||
|
||||
def get_blocklisted_functions() -> dict[str, str]:
|
||||
def get_blocklisted_functions() -> dict[str, set[str]]:
|
||||
"""Retrieve blocklisted functions for the current pull request.
|
||||
|
||||
Returns A dictionary mapping filenames to sets of blocklisted function names.
|
||||
"""
|
||||
pr_number = get_pr_number()
|
||||
if pr_number is None:
|
||||
return {}
|
||||
|
||||
not_found = 404
|
||||
internal_server_error = 500
|
||||
|
||||
owner, repo = get_repo_owner_and_name()
|
||||
information = {"pr_number": pr_number, "repo_owner": owner, "repo_name": repo}
|
||||
try:
|
||||
req = make_cfapi_request(endpoint="/verify-existing-optimizations", method="POST", payload=information)
|
||||
if req.status_code == not_found:
|
||||
logger.debug(req.json()["message"])
|
||||
return {}
|
||||
if req.status_code == internal_server_error:
|
||||
logger.error(req.json()["message"])
|
||||
return {}
|
||||
req.raise_for_status()
|
||||
content: dict[str, list[str]] = req.json()
|
||||
except Exception as e:
|
||||
except requests.RequestException as e:
|
||||
logger.error(f"Error getting blocklisted functions: {e}")
|
||||
return {}
|
||||
return {Path(k).name: {v.replace("()", "") for v in values} for k, values in content.items()}
|
||||
|
|
|
|||
|
|
@ -1,7 +1,6 @@
|
|||
SPINNER_TYPES = {
|
||||
"point",
|
||||
"simpleDots",
|
||||
"earth",
|
||||
"pong",
|
||||
"boxBounce2",
|
||||
"bouncingBar",
|
||||
|
|
@ -9,7 +8,6 @@ SPINNER_TYPES = {
|
|||
"dots7",
|
||||
"aesthetic",
|
||||
"toggle5",
|
||||
"weather",
|
||||
"dots6",
|
||||
"dots8",
|
||||
"star2",
|
||||
|
|
@ -21,7 +19,6 @@ SPINNER_TYPES = {
|
|||
"toggle12",
|
||||
"circle",
|
||||
"bouncingBall",
|
||||
"clock",
|
||||
"toggle9",
|
||||
"shark",
|
||||
"circleHalves",
|
||||
|
|
@ -64,6 +61,5 @@ SPINNER_TYPES = {
|
|||
"toggle13",
|
||||
"star",
|
||||
"boxBounce",
|
||||
"runner",
|
||||
"toggle7",
|
||||
}
|
||||
|
|
|
|||
|
|
@ -357,7 +357,7 @@ async def testgen(
|
|||
print("/testgen: Generating tests...")
|
||||
debug_log_sensitive_data(f"Generating tests for function {data.function_to_optimize.function_name}")
|
||||
debug_log_sensitive_data(f"Source code being tested: {data.source_code_being_tested}")
|
||||
max_tries = 3
|
||||
max_tries = 2
|
||||
count = 0
|
||||
for _ in range(max_tries):
|
||||
if count >= max_tries:
|
||||
|
|
|
|||
|
|
@ -35,8 +35,8 @@ export async function verifyExistingOptimizations(req, res) {
|
|||
}
|
||||
console.error("Error getting PR:", error)
|
||||
return res
|
||||
.status(500)
|
||||
.send({ error: `Error getting PR ${pr_number} for ${repo_owner}/${repo_name}` })
|
||||
.status(404)
|
||||
.send({ message: `Error getting PR ${pr_number} for ${repo_owner}/${repo_name}` })
|
||||
}
|
||||
|
||||
const optimizations_dict: { [key: string]: [string] } = {}
|
||||
|
|
@ -52,6 +52,9 @@ export async function verifyExistingOptimizations(req, res) {
|
|||
}
|
||||
}
|
||||
}
|
||||
else {
|
||||
return res.status(404).send({ message: "No optimizations found for this PR" })
|
||||
}
|
||||
|
||||
let pr_messages
|
||||
try {
|
||||
|
|
@ -63,7 +66,7 @@ export async function verifyExistingOptimizations(req, res) {
|
|||
} catch (error: any) {
|
||||
console.error("Error getting PR messages:", error)
|
||||
return res
|
||||
.status(500)
|
||||
.status(404)
|
||||
.send({ error: `Error getting PR messages for ${repo_owner}/${repo_name}` })
|
||||
}
|
||||
|
||||
|
|
@ -80,5 +83,10 @@ export async function verifyExistingOptimizations(req, res) {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (Object.keys(optimizations_dict).length === 0) {
|
||||
return res.status(404).send({ message: "No optimizations found for this PR" })
|
||||
}
|
||||
|
||||
return res.status(200).send(optimizations_dict)
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in a new issue