From 84b42aa6cf1d3f2587a01201ca23dfedaf3c124c Mon Sep 17 00:00:00 2001 From: KRRT7 Date: Mon, 28 Oct 2024 08:12:53 +0000 Subject: [PATCH] various improvements - Refactor cfapi.py to handle error responses from the API and return appropriate messages. - make the CF-API clearer when returning responses formatting & ruff make ruff happy --- cli/codeflash/api/cfapi.py | 40 ++++++++++++++----- cli/codeflash/cli_cmds/console_constants.py | 4 -- django/aiservice/testgen/testgen.py | 2 +- .../verify-existing-optimizations.ts | 14 +++++-- 4 files changed, 42 insertions(+), 18 deletions(-) diff --git a/cli/codeflash/api/cfapi.py b/cli/codeflash/api/cfapi.py index d2733570b..5d5827a95 100644 --- a/cli/codeflash/api/cfapi.py +++ b/cli/codeflash/api/cfapi.py @@ -4,17 +4,19 @@ import json import os from functools import lru_cache from pathlib import Path -from typing import Any, Dict, Optional +from typing import TYPE_CHECKING, Any, Dict, Optional import requests from pydantic.json import pydantic_encoder -from requests import Response from codeflash.cli_cmds.console import logger from codeflash.code_utils.env_utils import ensure_codeflash_api_key, get_codeflash_api_key, get_pr_number from codeflash.code_utils.git_utils import get_repo_owner_and_name from codeflash.github.PrComment import FileDiffContent, PrComment +if TYPE_CHECKING: + from requests import Response + if os.environ.get("CODEFLASH_CFAPI_SERVER", default="prod").lower() == "local": CFAPI_BASE_URL = "http://localhost:3001" logger.info(f"Using local CF API at {CFAPI_BASE_URL}.") @@ -22,8 +24,9 @@ else: CFAPI_BASE_URL = "https://app.codeflash.ai" -def make_cfapi_request(endpoint: str, method: str, payload: Optional[Dict[str, Any]] = None) -> requests.Response: +def make_cfapi_request(endpoint: str, method: str, payload: dict[str, Any] | None = None) -> Response: """Make an HTTP request using the specified method, URL, headers, and JSON payload. + :param endpoint: The endpoint URL to send the request to. :param method: The HTTP method to use ('GET', 'POST', etc.). :param payload: Optional JSON payload to include in the POST request body. @@ -34,15 +37,16 @@ def make_cfapi_request(endpoint: str, method: str, payload: Optional[Dict[str, A if method.upper() == "POST": json_payload = json.dumps(payload, indent=None, default=pydantic_encoder) cfapi_headers["Content-Type"] = "application/json" - response = requests.post(url, data=json_payload, headers=cfapi_headers) + response = requests.post(url, data=json_payload, headers=cfapi_headers, timeout=60) else: - response = requests.get(url, headers=cfapi_headers) + response = requests.get(url, headers=cfapi_headers, timeout=60) return response @lru_cache(maxsize=1) def get_user_id() -> Optional[str]: """Retrieve the user's userid by making a request to the /cfapi/cli-get-user endpoint. + :return: The userid or None if the request fails. """ if not ensure_codeflash_api_key(): @@ -66,6 +70,7 @@ def suggest_changes( trace_id: str, ) -> Response: """Suggest changes to a pull request. + Will make a review suggestion when possible; or create a new dependent pull request with the suggested changes. :param owner: The owner of the repository. @@ -86,8 +91,7 @@ def suggest_changes( "generatedTests": generated_tests, "traceId": trace_id, } - response = make_cfapi_request(endpoint="/suggest-pr-changes", method="POST", payload=payload) - return response + return make_cfapi_request(endpoint="/suggest-pr-changes", method="POST", payload=payload) def create_pr( @@ -100,7 +104,8 @@ def create_pr( generated_tests: str, trace_id: str, ) -> Response: - """Create a pull request, targeting the specified branch. (usually 'main') + """Create a pull request, targeting the specified branch. (usually 'main'). + :param owner: The owner of the repository. :param repo: The name of the repository. :param base_branch: The base branch to target. @@ -126,6 +131,7 @@ def create_pr( def is_github_app_installed_on_repo(owner: str, repo: str) -> bool: """Check if the Codeflash GitHub App is installed on the specified repository. + :param owner: The owner of the repository. :param repo: The name of the repository. :return: The response object. @@ -137,17 +143,31 @@ def is_github_app_installed_on_repo(owner: str, repo: str) -> bool: return True -def get_blocklisted_functions() -> dict[str, str]: +def get_blocklisted_functions() -> dict[str, set[str]]: + """Retrieve blocklisted functions for the current pull request. + + Returns A dictionary mapping filenames to sets of blocklisted function names. + """ pr_number = get_pr_number() if pr_number is None: return {} + not_found = 404 + internal_server_error = 500 + owner, repo = get_repo_owner_and_name() information = {"pr_number": pr_number, "repo_owner": owner, "repo_name": repo} try: req = make_cfapi_request(endpoint="/verify-existing-optimizations", method="POST", payload=information) + if req.status_code == not_found: + logger.debug(req.json()["message"]) + return {} + if req.status_code == internal_server_error: + logger.error(req.json()["message"]) + return {} + req.raise_for_status() content: dict[str, list[str]] = req.json() - except Exception as e: + except requests.RequestException as e: logger.error(f"Error getting blocklisted functions: {e}") return {} return {Path(k).name: {v.replace("()", "") for v in values} for k, values in content.items()} diff --git a/cli/codeflash/cli_cmds/console_constants.py b/cli/codeflash/cli_cmds/console_constants.py index 758b94adf..13f8aa621 100644 --- a/cli/codeflash/cli_cmds/console_constants.py +++ b/cli/codeflash/cli_cmds/console_constants.py @@ -1,7 +1,6 @@ SPINNER_TYPES = { "point", "simpleDots", - "earth", "pong", "boxBounce2", "bouncingBar", @@ -9,7 +8,6 @@ SPINNER_TYPES = { "dots7", "aesthetic", "toggle5", - "weather", "dots6", "dots8", "star2", @@ -21,7 +19,6 @@ SPINNER_TYPES = { "toggle12", "circle", "bouncingBall", - "clock", "toggle9", "shark", "circleHalves", @@ -64,6 +61,5 @@ SPINNER_TYPES = { "toggle13", "star", "boxBounce", - "runner", "toggle7", } diff --git a/django/aiservice/testgen/testgen.py b/django/aiservice/testgen/testgen.py index 0b2a3a32d..6406c8c86 100644 --- a/django/aiservice/testgen/testgen.py +++ b/django/aiservice/testgen/testgen.py @@ -357,7 +357,7 @@ async def testgen( print("/testgen: Generating tests...") debug_log_sensitive_data(f"Generating tests for function {data.function_to_optimize.function_name}") debug_log_sensitive_data(f"Source code being tested: {data.source_code_being_tested}") - max_tries = 3 + max_tries = 2 count = 0 for _ in range(max_tries): if count >= max_tries: diff --git a/js/cf-api/endpoints/verify-existing-optimizations.ts b/js/cf-api/endpoints/verify-existing-optimizations.ts index 358623839..e8edbaf0c 100644 --- a/js/cf-api/endpoints/verify-existing-optimizations.ts +++ b/js/cf-api/endpoints/verify-existing-optimizations.ts @@ -35,8 +35,8 @@ export async function verifyExistingOptimizations(req, res) { } console.error("Error getting PR:", error) return res - .status(500) - .send({ error: `Error getting PR ${pr_number} for ${repo_owner}/${repo_name}` }) + .status(404) + .send({ message: `Error getting PR ${pr_number} for ${repo_owner}/${repo_name}` }) } const optimizations_dict: { [key: string]: [string] } = {} @@ -52,6 +52,9 @@ export async function verifyExistingOptimizations(req, res) { } } } + else { + return res.status(404).send({ message: "No optimizations found for this PR" }) + } let pr_messages try { @@ -63,7 +66,7 @@ export async function verifyExistingOptimizations(req, res) { } catch (error: any) { console.error("Error getting PR messages:", error) return res - .status(500) + .status(404) .send({ error: `Error getting PR messages for ${repo_owner}/${repo_name}` }) } @@ -80,5 +83,10 @@ export async function verifyExistingOptimizations(req, res) { } } } + + if (Object.keys(optimizations_dict).length === 0) { + return res.status(404).send({ message: "No optimizations found for this PR" }) + } + return res.status(200).send(optimizations_dict) }