- Refactor cfapi.py to handle error responses from the API and return appropriate messages. - Make the CF API clearer when returning responses. - Formatting and ruff fixes (make ruff happy).
173 lines
6.2 KiB
Python
173 lines
6.2 KiB
Python
from __future__ import annotations
|
|
|
|
import json
|
|
import os
|
|
from functools import lru_cache
|
|
from pathlib import Path
|
|
from typing import TYPE_CHECKING, Any, Dict, Optional
|
|
|
|
import requests
|
|
from pydantic.json import pydantic_encoder
|
|
|
|
from codeflash.cli_cmds.console import logger
|
|
from codeflash.code_utils.env_utils import ensure_codeflash_api_key, get_codeflash_api_key, get_pr_number
|
|
from codeflash.code_utils.git_utils import get_repo_owner_and_name
|
|
from codeflash.github.PrComment import FileDiffContent, PrComment
|
|
|
|
if TYPE_CHECKING:
|
|
from requests import Response
|
|
|
|
# Select the API host once at import time. Setting CODEFLASH_CFAPI_SERVER to
# "local" (case-insensitive) targets a locally running server; any other value,
# or leaving it unset, targets production.
if os.environ.get("CODEFLASH_CFAPI_SERVER", "prod").lower() != "local":
    CFAPI_BASE_URL = "https://app.codeflash.ai"
else:
    CFAPI_BASE_URL = "http://localhost:3001"
    logger.info(f"Using local CF API at {CFAPI_BASE_URL}.")
|
|
|
|
|
|
def make_cfapi_request(endpoint: str, method: str, payload: dict[str, Any] | None = None) -> Response:
    """Send an authenticated request to the Codeflash API and return the raw response.

    The request URL is ``CFAPI_BASE_URL`` + ``/cfapi`` + *endpoint*, and every request
    carries a bearer token obtained from ``get_codeflash_api_key()``.

    :param endpoint: Path appended after ``/cfapi`` (for example ``/cli-get-user``).
    :param method: HTTP method name, compared case-insensitively. ``POST`` sends
        *payload* as a JSON body; any other value is sent as a ``GET``.
    :param payload: Optional data serialized to JSON for ``POST`` requests; ignored otherwise.
    :return: The response object from the API, unmodified (no status check is done here).
    """
    headers = {"Authorization": f"Bearer {get_codeflash_api_key()}"}
    target = f"{CFAPI_BASE_URL}/cfapi{endpoint}"

    # Anything that is not a POST is issued as a plain GET without a body.
    if method.upper() != "POST":
        return requests.get(target, headers=headers, timeout=60)

    # pydantic_encoder lets json.dumps serialize objects it cannot handle natively.
    body = json.dumps(payload, indent=None, default=pydantic_encoder)
    headers["Content-Type"] = "application/json"
    return requests.post(target, data=body, headers=headers, timeout=60)
|
|
|
|
|
|
@lru_cache(maxsize=1)
def get_user_id() -> Optional[str]:
    """Look up the current user's id through the ``/cli-get-user`` endpoint.

    The result is memoized by ``lru_cache``, so the lookup happens at most once per
    process; a ``None`` result is cached just like a successful one.

    :return: The response body (the userid) on HTTP 200, or None when no API key is
        available or the request does not return 200.
    """
    if ensure_codeflash_api_key():
        lookup = make_cfapi_request(endpoint="/cli-get-user", method="GET")
        if lookup.status_code != 200:
            logger.error(f"Failed to look up your userid; is your CF API key valid? ({lookup.reason})")
            return None
        return lookup.text
    return None
|
|
|
|
|
|
def suggest_changes(
    owner: str,
    repo: str,
    pr_number: int,
    file_changes: dict[str, FileDiffContent],
    pr_comment: PrComment,
    existing_tests: str,
    generated_tests: str,
    trace_id: str,
) -> Response:
    """Post suggested changes for an existing pull request to ``/suggest-pr-changes``.

    Will make a review suggestion when possible;
    or create a new dependent pull request with the suggested changes.

    :param owner: The owner of the repository.
    :param repo: The name of the repository.
    :param pr_number: The number of the pull request (sent as ``pullNumber``).
    :param file_changes: A dictionary of file changes (sent as ``diffContents``).
    :param pr_comment: The pull request comment object, containing the optimization explanation,
        best runtime, etc. Its ``to_json()`` output is sent as ``prCommentFields``.
    :param existing_tests: The existing tests (sent as ``existingTests``).
    :param generated_tests: The generated tests (sent as ``generatedTests``).
    :param trace_id: The trace identifier (sent as ``traceId``).
    :return: The response object.
    """
    return make_cfapi_request(
        endpoint="/suggest-pr-changes",
        method="POST",
        payload={
            "owner": owner,
            "repo": repo,
            "pullNumber": pr_number,
            "diffContents": file_changes,
            "prCommentFields": pr_comment.to_json(),
            "existingTests": existing_tests,
            "generatedTests": generated_tests,
            "traceId": trace_id,
        },
    )
|
|
|
|
|
|
def create_pr(
    owner: str,
    repo: str,
    base_branch: str,
    file_changes: dict[str, FileDiffContent],
    pr_comment: PrComment,
    existing_tests: str,
    generated_tests: str,
    trace_id: str,
) -> Response:
    """Request a new pull request against the given base branch via ``/create-pr``.

    :param owner: The owner of the repository.
    :param repo: The name of the repository.
    :param base_branch: The base branch to target (usually 'main'); sent as ``baseBranch``.
    :param file_changes: A dictionary of file changes (sent as ``diffContents``).
    :param pr_comment: The pull request comment object, containing the optimization explanation,
        best runtime, etc. Its ``to_json()`` output is sent as ``prCommentFields``.
    :param existing_tests: The existing tests (sent as ``existingTests``).
    :param generated_tests: The generated tests (sent as ``generatedTests``).
    :param trace_id: The trace identifier (sent as ``traceId``).
    :return: The response object.
    """
    request_body = {
        "owner": owner,
        "repo": repo,
        "baseBranch": base_branch,
        "diffContents": file_changes,
        "prCommentFields": pr_comment.to_json(),
        "existingTests": existing_tests,
        "generatedTests": generated_tests,
        "traceId": trace_id,
    }
    return make_cfapi_request(endpoint="/create-pr", method="POST", payload=request_body)
|
|
|
|
|
|
def is_github_app_installed_on_repo(owner: str, repo: str) -> bool:
    """Check if the Codeflash GitHub App is installed on the specified repository.

    :param owner: The owner of the repository.
    :param repo: The name of the repository.
    :return: True only when the API responds successfully with the literal body ``true``;
        False otherwise, in which case the response body is logged as an error.
    """
    from urllib.parse import urlencode

    # Percent-encode the values so characters such as '&', '#' or spaces in owner/repo
    # cannot alter the query string. Ordinary names are left unchanged.
    query = urlencode({"repo": repo, "owner": owner})
    response = make_cfapi_request(endpoint=f"/is-github-app-installed?{query}", method="GET")
    if response.ok and response.text == "true":
        return True
    logger.error(f"Error: {response.text}")
    return False
|
|
|
|
|
|
def get_blocklisted_functions() -> dict[str, set[str]]:
    """Retrieve blocklisted functions for the current pull request.

    Posts the PR number and the repository owner/name to the
    ``/verify-existing-optimizations`` endpoint and reshapes the JSON answer.

    :return: A dictionary mapping file names (final path component only) to sets of
        blocklisted function names with any ``()`` removed. An empty dictionary is
        returned when there is no PR number, when the API answers 404 (message logged
        at debug level) or 500 (message logged as an error), or when the request or
        response decoding fails.
    """
    pr_number = get_pr_number()
    if pr_number is None:
        return {}

    not_found = 404
    internal_server_error = 500

    owner, repo = get_repo_owner_and_name()
    information = {"pr_number": pr_number, "repo_owner": owner, "repo_name": repo}
    try:
        req = make_cfapi_request(endpoint="/verify-existing-optimizations", method="POST", payload=information)
        if req.status_code in (not_found, internal_server_error):
            # The error body is expected to be JSON with a "message" field. Fall back to
            # the raw text when it is not JSON (ValueError), lacks the key (KeyError),
            # or is not a JSON object (TypeError), so a malformed error cannot crash us.
            try:
                message = req.json()["message"]
            except (ValueError, KeyError, TypeError):
                message = req.text
            if req.status_code == not_found:
                logger.debug(message)
            else:
                logger.error(message)
            return {}
        req.raise_for_status()
        content: dict[str, list[str]] = req.json()
    except (requests.RequestException, ValueError) as e:
        # ValueError covers a non-JSON success body on requests versions whose decode
        # error is not a RequestException.
        logger.error(f"Error getting blocklisted functions: {e}")
        return {}
    return {Path(k).name: {v.replace("()", "") for v in values} for k, values in content.items()}
|