Stnadardize json parsing to use pydantic json encoder in all cases (in pydantic 1 RootModel does not exist, we use pydantic's json encoder instead.)
This commit is contained in:
parent
96d2a7cad0
commit
48dca8ff50
7 changed files with 33 additions and 34 deletions
|
|
@ -1,9 +1,11 @@
|
|||
import json
|
||||
import logging
|
||||
import os
|
||||
import requests
|
||||
from pydantic import RootModel
|
||||
from typing import Any, Dict, List, Tuple, Optional
|
||||
|
||||
import requests
|
||||
from pydantic.json import pydantic_encoder
|
||||
|
||||
from codeflash.analytics.posthog import ph
|
||||
from codeflash.code_utils.env_utils import get_codeflash_api_key
|
||||
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
|
||||
|
|
@ -23,19 +25,19 @@ def make_ai_service_request(
|
|||
) -> requests.Response:
|
||||
"""
|
||||
Make an API request to the given endpoint on the AI service.
|
||||
|
||||
Parameters:
|
||||
- endpoint (str): The endpoint to call, e.g., "/optimize".
|
||||
- method (str): The HTTP method to use, e.g., "POST".
|
||||
- data (Dict[str, Any]): The data to send in the request.
|
||||
|
||||
Returns:
|
||||
- requests.Response: The response from the API.
|
||||
:param endpoint: The endpoint to call, e.g., "/optimize".
|
||||
:param method: The HTTP method to use ('GET' or 'POST').
|
||||
:param payload: Optional JSON payload to include in the POST request body.
|
||||
:param timeout: The timeout for the request.
|
||||
:return: The response object from the API.
|
||||
"""
|
||||
url = f"{AI_SERVICE_BASE_URL}/ai{endpoint}"
|
||||
ai_service_headers = {"Authorization": f"Bearer {get_codeflash_api_key()}"}
|
||||
if method.upper() == "POST":
|
||||
response = requests.post(url, json=payload, headers=ai_service_headers, timeout=timeout)
|
||||
json_payload = json.dumps(payload, indent=4, default=pydantic_encoder)
|
||||
response = requests.post(
|
||||
url, data=json_payload, headers=ai_service_headers, timeout=timeout
|
||||
)
|
||||
else:
|
||||
response = requests.get(url, headers=ai_service_headers, timeout=timeout)
|
||||
# response.raise_for_status() # Will raise an HTTPError if the HTTP request returned an unsuccessful status code
|
||||
|
|
@ -55,10 +57,10 @@ def optimize_python_code(
|
|||
Returns:
|
||||
- List[Tuple[str, str]]: A list of tuples where the first element is the optimized code and the second is the explanation.
|
||||
"""
|
||||
data = {"source_code": source_code, "num_variants": num_variants}
|
||||
payload = {"source_code": source_code, "num_variants": num_variants}
|
||||
logging.info(f"Generating optimized candidates ...")
|
||||
try:
|
||||
response = make_ai_service_request("/optimize", payload=data, timeout=600)
|
||||
response = make_ai_service_request("/optimize", payload=payload, timeout=600)
|
||||
except requests.exceptions.RequestException as e:
|
||||
logging.error(f"Error generating optimized candidates: {e}")
|
||||
ph("cli-optimize-error-caught", {"error": str(e)})
|
||||
|
|
@ -108,11 +110,9 @@ def generate_regression_tests(
|
|||
"pytest",
|
||||
"unittest",
|
||||
], f"Invalid test framework, got {test_framework} but expected 'pytest' or 'unittest'"
|
||||
data = {
|
||||
payload = {
|
||||
"source_code_being_tested": source_code_being_tested,
|
||||
"function_to_optimize": RootModel[FunctionToOptimize](function_to_optimize).model_dump(
|
||||
mode="json"
|
||||
),
|
||||
"function_to_optimize": function_to_optimize,
|
||||
"dependent_function_names": dependent_function_names,
|
||||
"module_path": module_path,
|
||||
"test_module_path": test_module_path,
|
||||
|
|
@ -120,7 +120,7 @@ def generate_regression_tests(
|
|||
"test_timeout": test_timeout,
|
||||
}
|
||||
try:
|
||||
response = make_ai_service_request("/testgen", payload=data, timeout=600)
|
||||
response = make_ai_service_request("/testgen", payload=payload, timeout=600)
|
||||
except requests.exceptions.RequestException as e:
|
||||
logging.error(f"Error generating tests: {e}")
|
||||
ph("cli-testgen-error-caught", {"error": str(e)})
|
||||
|
|
|
|||
|
|
@ -1,9 +1,11 @@
|
|||
import json
|
||||
import logging
|
||||
import os
|
||||
from functools import lru_cache
|
||||
from typing import Optional, Dict, Any
|
||||
|
||||
import requests
|
||||
from pydantic.json import pydantic_encoder
|
||||
from requests import Response
|
||||
|
||||
from codeflash.code_utils.env_utils import get_codeflash_api_key
|
||||
|
|
@ -21,15 +23,16 @@ def make_cfapi_request(
|
|||
) -> requests.Response:
|
||||
"""
|
||||
Make an HTTP request using the specified method, URL, headers, and JSON payload.
|
||||
:param endpoint: The URL to send the request to.
|
||||
:param endpoint: The endpoint URL to send the request to.
|
||||
:param method: The HTTP method to use ('GET', 'POST', etc.).
|
||||
:param payload: Optional JSON payload to include in the request body.
|
||||
:return: The response object.
|
||||
:param payload: Optional JSON payload to include in the POST request body.
|
||||
:return: The response object from the API.
|
||||
"""
|
||||
url = f"{CFAPI_BASE_URL}/cfapi{endpoint}"
|
||||
cfapi_headers = {"Authorization": f"Bearer {get_codeflash_api_key()}"}
|
||||
if method.upper() == "POST":
|
||||
response = requests.post(url, json=payload, headers=cfapi_headers)
|
||||
json_payload = json.dumps(payload, indent=4, default=pydantic_encoder)
|
||||
response = requests.post(url, json=json_payload, headers=cfapi_headers)
|
||||
else:
|
||||
response = requests.get(url, headers=cfapi_headers)
|
||||
return response
|
||||
|
|
|
|||
|
|
@ -1,8 +1,10 @@
|
|||
import logging
|
||||
import os
|
||||
from functools import lru_cache
|
||||
from typing import Optional
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def get_codeflash_api_key() -> Optional[str]:
|
||||
api_key = os.environ.get("CODEFLASH_API_KEY")
|
||||
if not api_key:
|
||||
|
|
@ -33,11 +35,13 @@ def ensure_codeflash_api_key() -> bool:
|
|||
return True
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def get_codeflash_org_key() -> Optional[str]:
|
||||
api_key = os.environ.get("CODEFLASH_ORG_KEY")
|
||||
return api_key
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def get_pr_number() -> Optional[int]:
|
||||
pr_number = os.environ.get("CODEFLASH_PR_NUMBER")
|
||||
if not pr_number:
|
||||
|
|
|
|||
|
|
@ -6,7 +6,6 @@ from typing import List
|
|||
import jedi
|
||||
import tiktoken
|
||||
from jedi.api.classes import Name
|
||||
from pydantic import RootModel
|
||||
from pydantic.dataclasses import dataclass
|
||||
|
||||
from codeflash.code_utils.code_extractor import get_code_no_skeleton, get_code
|
||||
|
|
@ -44,11 +43,6 @@ class Source:
|
|||
source_code: str
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class SourceList(RootModel):
|
||||
root: list[Source]
|
||||
|
||||
|
||||
def get_type_annotation_context(
|
||||
function: FunctionToOptimize, jedi_script: jedi.Script, project_root_path: str
|
||||
) -> List[Source]:
|
||||
|
|
|
|||
|
|
@ -64,9 +64,7 @@ def check_create_pr(
|
|||
repo=repo,
|
||||
base_branch=base_branch,
|
||||
file_changes={
|
||||
relative_path: FileDiffContent(
|
||||
oldContent=original_code, newContent=new_code
|
||||
).model_dump(mode="json")
|
||||
relative_path: FileDiffContent(oldContent=original_code, newContent=new_code)
|
||||
},
|
||||
pr_comment=PrComment(
|
||||
optimization_explanation=explanation.explanation_message(),
|
||||
|
|
|
|||
|
|
@ -1,3 +1,3 @@
|
|||
# These version placeholders will be replaced by poetry-dynamic-versioning during `poetry build`.
|
||||
__version__ = "0.3.1"
|
||||
__version_tuple__ = (0, 3, 1)
|
||||
__version__ = "0.3.3.post16.dev0+8f05c9c"
|
||||
__version_tuple__ = (0, 3, 3, "post16", "dev0", "8f05c9c")
|
||||
|
|
|
|||
|
|
@ -22,7 +22,7 @@ pytest-timeout = ">=2.1.0"
|
|||
tomlkit = ">=0.11.7"
|
||||
unittest-xml-reporting = ">=3.2.0"
|
||||
junitparser = ">=3.1.0"
|
||||
pydantic = "^2.5.2"
|
||||
pydantic = "==1.10.14"
|
||||
black = ">=22.3.0"
|
||||
humanize = ">=4.0.0"
|
||||
posthog = ">=3.0.0"
|
||||
|
|
|
|||
Loading…
Reference in a new issue