codeflash-internal/cli/codeflash/api/cfapi.py
KRRT7 84b42aa6cf various improvements
- Refactor cfapi.py to handle error responses from the API and return appropriate messages.
- make the CF-API clearer when returning responses
formatting & ruff

make ruff happy
2024-10-28 18:17:19 -05:00

173 lines
6.2 KiB
Python

from __future__ import annotations

import json
import os
from functools import lru_cache
from http import HTTPStatus
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, Optional
from urllib.parse import urlencode

import requests
from pydantic.json import pydantic_encoder

from codeflash.cli_cmds.console import logger
from codeflash.code_utils.env_utils import ensure_codeflash_api_key, get_codeflash_api_key, get_pr_number
from codeflash.code_utils.git_utils import get_repo_owner_and_name
from codeflash.github.PrComment import FileDiffContent, PrComment

if TYPE_CHECKING:
    from requests import Response
# Pick the CF API host once at import time. CODEFLASH_CFAPI_SERVER=local (any casing)
# points the CLI at a server on localhost; any other value, or no value, means production.
if os.environ.get("CODEFLASH_CFAPI_SERVER", "prod").lower() != "local":
    CFAPI_BASE_URL = "https://app.codeflash.ai"
else:
    CFAPI_BASE_URL = "http://localhost:3001"
    logger.info(f"Using local CF API at {CFAPI_BASE_URL}.")
def make_cfapi_request(endpoint: str, method: str, payload: dict[str, Any] | None = None) -> Response:
    """Send an authenticated request to the CF API and hand back the raw response.

    :param endpoint: Path appended to ``{CFAPI_BASE_URL}/cfapi`` (include the leading "/").
    :param method: HTTP method name, compared case-insensitively. "POST" sends ``payload`` as a
        JSON body; every other value results in a GET request.
    :param payload: Body for POST requests. It is serialized with ``json.dumps`` using
        ``pydantic_encoder`` as the fallback encoder for values json cannot handle natively.
        Ignored for non-POST requests.
    :return: The ``requests`` response object; the status code is not checked here.
    """
    target = f"{CFAPI_BASE_URL}/cfapi{endpoint}"
    headers = {"Authorization": f"Bearer {get_codeflash_api_key()}"}

    # Anything other than POST is treated as a GET.
    if method.upper() != "POST":
        return requests.get(target, headers=headers, timeout=60)

    body = json.dumps(payload, indent=None, default=pydantic_encoder)
    headers["Content-Type"] = "application/json"
    return requests.post(target, data=body, headers=headers, timeout=60)
@lru_cache(maxsize=1)
def get_user_id() -> Optional[str]:
    """Retrieve the user's userid by making a request to the /cfapi/cli-get-user endpoint.

    The result is memoized by ``lru_cache(maxsize=1)``, so the lookup happens at most once per
    process. A ``None`` result is memoized too.

    :return: The userid (the response body text). Returns None if ``ensure_codeflash_api_key()``
        fails, if the endpoint answers with a non-200 status, or if the request cannot be
        completed at all (network/transport error).
    """
    if not ensure_codeflash_api_key():
        return None
    try:
        response = make_cfapi_request(endpoint="/cli-get-user", method="GET")
    except requests.RequestException as e:
        # Honour the "None if the request fails" contract instead of letting connection
        # errors / timeouts escape to the caller.
        logger.error(f"Failed to look up your userid; could not reach the CF API ({e})")
        return None
    if response.status_code == 200:
        return response.text
    logger.error(f"Failed to look up your userid; is your CF API key valid? ({response.reason})")
    return None
def suggest_changes(
    owner: str,
    repo: str,
    pr_number: int,
    file_changes: dict[str, FileDiffContent],
    pr_comment: PrComment,
    existing_tests: str,
    generated_tests: str,
    trace_id: str,
) -> Response:
    """Propose an optimization on an existing pull request via the CF API.

    Will make a review suggestion when possible; otherwise a new dependent pull request
    carrying the suggested changes is created.

    :param owner: The owner of the repository.
    :param repo: The name of the repository.
    :param pr_number: The number of the pull request to suggest changes on.
    :param file_changes: A dictionary of file changes.
    :param pr_comment: The pull request comment object, containing the optimization explanation,
        best runtime, etc. Sent as ``pr_comment.to_json()``.
    :param existing_tests: The existing tests, forwarded as-is.
    :param generated_tests: The generated tests, forwarded as-is.
    :param trace_id: The trace id for this optimization run, forwarded as-is.
    :return: The raw response from the ``/suggest-pr-changes`` endpoint.
    """
    # Keyword names double as the camelCase JSON field names the endpoint expects.
    request_body = dict(
        owner=owner,
        repo=repo,
        pullNumber=pr_number,
        diffContents=file_changes,
        prCommentFields=pr_comment.to_json(),
        existingTests=existing_tests,
        generatedTests=generated_tests,
        traceId=trace_id,
    )
    return make_cfapi_request(
        endpoint="/suggest-pr-changes",
        method="POST",
        payload=request_body,
    )
def create_pr(
    owner: str,
    repo: str,
    base_branch: str,
    file_changes: dict[str, FileDiffContent],
    pr_comment: PrComment,
    existing_tests: str,
    generated_tests: str,
    trace_id: str,
) -> Response:
    """Ask the CF API to open a new pull request targeting ``base_branch`` (usually 'main').

    :param owner: The owner of the repository.
    :param repo: The name of the repository.
    :param base_branch: The base branch the new pull request should target.
    :param file_changes: A dictionary of file changes.
    :param pr_comment: The pull request comment object, containing the optimization explanation,
        best runtime, etc. Sent as ``pr_comment.to_json()``.
    :param existing_tests: The existing tests, forwarded as-is.
    :param generated_tests: The generated tests, forwarded as-is.
    :param trace_id: The trace id for this optimization run, forwarded as-is.
    :return: The raw response from the ``/create-pr`` endpoint.
    """
    # No conversion happens here: values json can't encode natively are left to the
    # fallback encoder used by make_cfapi_request.
    return make_cfapi_request(
        endpoint="/create-pr",
        method="POST",
        payload=dict(
            owner=owner,
            repo=repo,
            baseBranch=base_branch,
            diffContents=file_changes,
            prCommentFields=pr_comment.to_json(),
            existingTests=existing_tests,
            generatedTests=generated_tests,
            traceId=trace_id,
        ),
    )
def is_github_app_installed_on_repo(owner: str, repo: str) -> bool:
    """Check if the Codeflash GitHub App is installed on the specified repository.

    :param owner: The owner of the repository.
    :param repo: The name of the repository.
    :return: True only when the CF API answers with a successful status and the body "true".
        Otherwise the response text is logged as an error and False is returned.
    """
    # URL-encode the query so that owner/repo values containing reserved characters
    # ("&", "#", "?", spaces, ...) cannot break or inject query parameters. For ordinary
    # GitHub names this yields exactly "repo=<repo>&owner=<owner>".
    query = urlencode({"repo": repo, "owner": owner})
    response = make_cfapi_request(endpoint=f"/is-github-app-installed?{query}", method="GET")
    if not response.ok or response.text != "true":
        logger.error(f"Error: {response.text}")
        return False
    return True
def _response_message(response: Response) -> str:
    """Return the "message" field of a JSON error body, falling back to the raw body text."""
    try:
        body = response.json()
    except ValueError:  # requests' JSONDecodeError subclasses ValueError
        return response.text
    if isinstance(body, dict) and "message" in body:
        return str(body["message"])
    return response.text


def get_blocklisted_functions() -> dict[str, set[str]]:
    """Retrieve blocklisted functions for the current pull request.

    Posts the PR number and repository owner/name to ``/verify-existing-optimizations``.

    :return: A dictionary mapping file *names* (``Path(path).name``, not full paths) to sets of
        blocklisted function names with any "()" removed. Entries from same-named files in
        different directories are merged. Returns ``{}`` when there is no PR number, when the
        endpoint answers 404 (logged at debug level) or 500 (logged as an error), or when any
        other HTTP/network error occurs (logged as an error).
    """
    pr_number = get_pr_number()
    if pr_number is None:
        return {}
    owner, repo = get_repo_owner_and_name()
    information = {"pr_number": pr_number, "repo_owner": owner, "repo_name": repo}
    try:
        req = make_cfapi_request(endpoint="/verify-existing-optimizations", method="POST", payload=information)
        if req.status_code == HTTPStatus.NOT_FOUND:
            logger.debug(_response_message(req))
            return {}
        if req.status_code == HTTPStatus.INTERNAL_SERVER_ERROR:
            logger.error(_response_message(req))
            return {}
        req.raise_for_status()
        content: dict[str, list[str]] = req.json()
    except requests.RequestException as e:
        logger.error(f"Error getting blocklisted functions: {e}")
        return {}

    blocklist: dict[str, set[str]] = {}
    for file_path, function_names in content.items():
        # Keys are collapsed to bare file names, so merge instead of overwriting; otherwise
        # e.g. "a/utils.py" and "b/utils.py" would silently drop one file's blocklist.
        blocklist.setdefault(Path(file_path).name, set()).update(
            name.replace("()", "") for name in function_names
        )
    return blocklist