Stnadardize json parsing to use pydantic json encoder in all cases (in pydantic 1 RootModel does not exist, we use pydantic's json encoder instead.)

This commit is contained in:
afik.cohen 2024-02-05 18:47:44 -08:00
parent 96d2a7cad0
commit 48dca8ff50
7 changed files with 33 additions and 34 deletions

View file

@ -1,9 +1,11 @@
import json
import logging
import os
import requests
from pydantic import RootModel
from typing import Any, Dict, List, Tuple, Optional
import requests
from pydantic.json import pydantic_encoder
from codeflash.analytics.posthog import ph
from codeflash.code_utils.env_utils import get_codeflash_api_key
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
@ -23,19 +25,19 @@ def make_ai_service_request(
) -> requests.Response:
"""
Make an API request to the given endpoint on the AI service.
Parameters:
- endpoint (str): The endpoint to call, e.g., "/optimize".
- method (str): The HTTP method to use, e.g., "POST".
- data (Dict[str, Any]): The data to send in the request.
Returns:
- requests.Response: The response from the API.
:param endpoint: The endpoint to call, e.g., "/optimize".
:param method: The HTTP method to use ('GET' or 'POST').
:param payload: Optional JSON payload to include in the POST request body.
:param timeout: The timeout for the request.
:return: The response object from the API.
"""
url = f"{AI_SERVICE_BASE_URL}/ai{endpoint}"
ai_service_headers = {"Authorization": f"Bearer {get_codeflash_api_key()}"}
if method.upper() == "POST":
response = requests.post(url, json=payload, headers=ai_service_headers, timeout=timeout)
json_payload = json.dumps(payload, indent=4, default=pydantic_encoder)
response = requests.post(
url, data=json_payload, headers=ai_service_headers, timeout=timeout
)
else:
response = requests.get(url, headers=ai_service_headers, timeout=timeout)
# response.raise_for_status() # Will raise an HTTPError if the HTTP request returned an unsuccessful status code
@ -55,10 +57,10 @@ def optimize_python_code(
Returns:
- List[Tuple[str, str]]: A list of tuples where the first element is the optimized code and the second is the explanation.
"""
data = {"source_code": source_code, "num_variants": num_variants}
payload = {"source_code": source_code, "num_variants": num_variants}
logging.info(f"Generating optimized candidates ...")
try:
response = make_ai_service_request("/optimize", payload=data, timeout=600)
response = make_ai_service_request("/optimize", payload=payload, timeout=600)
except requests.exceptions.RequestException as e:
logging.error(f"Error generating optimized candidates: {e}")
ph("cli-optimize-error-caught", {"error": str(e)})
@ -108,11 +110,9 @@ def generate_regression_tests(
"pytest",
"unittest",
], f"Invalid test framework, got {test_framework} but expected 'pytest' or 'unittest'"
data = {
payload = {
"source_code_being_tested": source_code_being_tested,
"function_to_optimize": RootModel[FunctionToOptimize](function_to_optimize).model_dump(
mode="json"
),
"function_to_optimize": function_to_optimize,
"dependent_function_names": dependent_function_names,
"module_path": module_path,
"test_module_path": test_module_path,
@ -120,7 +120,7 @@ def generate_regression_tests(
"test_timeout": test_timeout,
}
try:
response = make_ai_service_request("/testgen", payload=data, timeout=600)
response = make_ai_service_request("/testgen", payload=payload, timeout=600)
except requests.exceptions.RequestException as e:
logging.error(f"Error generating tests: {e}")
ph("cli-testgen-error-caught", {"error": str(e)})

View file

@ -1,9 +1,11 @@
import json
import logging
import os
from functools import lru_cache
from typing import Optional, Dict, Any
import requests
from pydantic.json import pydantic_encoder
from requests import Response
from codeflash.code_utils.env_utils import get_codeflash_api_key
@ -21,15 +23,16 @@ def make_cfapi_request(
) -> requests.Response:
"""
Make an HTTP request using the specified method, URL, headers, and JSON payload.
:param endpoint: The URL to send the request to.
:param endpoint: The endpoint URL to send the request to.
:param method: The HTTP method to use ('GET', 'POST', etc.).
:param payload: Optional JSON payload to include in the request body.
:return: The response object.
:param payload: Optional JSON payload to include in the POST request body.
:return: The response object from the API.
"""
url = f"{CFAPI_BASE_URL}/cfapi{endpoint}"
cfapi_headers = {"Authorization": f"Bearer {get_codeflash_api_key()}"}
if method.upper() == "POST":
response = requests.post(url, json=payload, headers=cfapi_headers)
json_payload = json.dumps(payload, indent=4, default=pydantic_encoder)
response = requests.post(url, json=json_payload, headers=cfapi_headers)
else:
response = requests.get(url, headers=cfapi_headers)
return response

View file

@ -1,8 +1,10 @@
import logging
import os
from functools import lru_cache
from typing import Optional
@lru_cache(maxsize=1)
def get_codeflash_api_key() -> Optional[str]:
api_key = os.environ.get("CODEFLASH_API_KEY")
if not api_key:
@ -33,11 +35,13 @@ def ensure_codeflash_api_key() -> bool:
return True
@lru_cache(maxsize=1)
def get_codeflash_org_key() -> Optional[str]:
api_key = os.environ.get("CODEFLASH_ORG_KEY")
return api_key
@lru_cache(maxsize=1)
def get_pr_number() -> Optional[int]:
pr_number = os.environ.get("CODEFLASH_PR_NUMBER")
if not pr_number:

View file

@ -6,7 +6,6 @@ from typing import List
import jedi
import tiktoken
from jedi.api.classes import Name
from pydantic import RootModel
from pydantic.dataclasses import dataclass
from codeflash.code_utils.code_extractor import get_code_no_skeleton, get_code
@ -44,11 +43,6 @@ class Source:
source_code: str
@dataclass(frozen=True)
class SourceList(RootModel):
root: list[Source]
def get_type_annotation_context(
function: FunctionToOptimize, jedi_script: jedi.Script, project_root_path: str
) -> List[Source]:

View file

@ -64,9 +64,7 @@ def check_create_pr(
repo=repo,
base_branch=base_branch,
file_changes={
relative_path: FileDiffContent(
oldContent=original_code, newContent=new_code
).model_dump(mode="json")
relative_path: FileDiffContent(oldContent=original_code, newContent=new_code)
},
pr_comment=PrComment(
optimization_explanation=explanation.explanation_message(),

View file

@ -1,3 +1,3 @@
# These version placeholders will be replaced by poetry-dynamic-versioning during `poetry build`.
__version__ = "0.3.1"
__version_tuple__ = (0, 3, 1)
__version__ = "0.3.3.post16.dev0+8f05c9c"
__version_tuple__ = (0, 3, 3, "post16", "dev0", "8f05c9c")

View file

@ -22,7 +22,7 @@ pytest-timeout = ">=2.1.0"
tomlkit = ">=0.11.7"
unittest-xml-reporting = ">=3.2.0"
junitparser = ">=3.1.0"
pydantic = "^2.5.2"
pydantic = "==1.10.14"
black = ">=22.3.0"
humanize = ">=4.0.0"
posthog = ">=3.0.0"