mirror of
https://github.com/codeflash-ai/codeflash.git
synced 2026-05-04 18:25:17 +00:00
try to repro the code replacement bug
This commit is contained in:
parent
02fd821eb1
commit
e669598f1e
1 changed files with 168 additions and 0 deletions
|
|
@ -791,3 +791,171 @@ class MainClass:
|
|||
original_source_code=original_code,
|
||||
).unwrap()
|
||||
assert code_context.code_to_optimize_with_helpers == get_code_output
|
||||
|
||||
|
||||
def test_code_replacement11() -> None:
|
||||
optim_code = '''from abc import abstractmethod
|
||||
from typing import Any, List, Mapping, MutableMapping, Optional, Tuple
|
||||
|
||||
import requests
|
||||
from requests.auth import AuthBase
|
||||
|
||||
|
||||
class AbstractOauth2Authenticator(AuthBase):
|
||||
def __init__(
|
||||
self,
|
||||
refresh_token_error_status_codes: Tuple[int, ...] = (),
|
||||
refresh_token_error_key: str = "",
|
||||
refresh_token_error_values: Tuple[str, ...] = (),
|
||||
) -> None:
|
||||
"""
|
||||
If all of refresh_token_error_status_codes, refresh_token_error_key, and refresh_token_error_values are set,
|
||||
then http errors with such params will be wrapped in AirbyteTracedException.
|
||||
"""
|
||||
self._refresh_token_error_status_codes = refresh_token_error_status_codes
|
||||
self._refresh_token_error_key = refresh_token_error_key
|
||||
self._refresh_token_error_values = refresh_token_error_values
|
||||
|
||||
def __call__(self, request: requests.PreparedRequest) -> requests.PreparedRequest:
|
||||
"""Attach the HTTP headers required to authenticate on the HTTP request"""
|
||||
request.headers.update(self.get_auth_header())
|
||||
return request
|
||||
|
||||
def build_refresh_request_body(self) -> Mapping[str, Any]:
|
||||
"""
|
||||
Returns the request body to set on the refresh request
|
||||
|
||||
Override to define additional parameters
|
||||
"""
|
||||
payload: MutableMapping[str, Any] = {
|
||||
"grant_type": self.get_grant_type(),
|
||||
"client_id": self.get_client_id(),
|
||||
"client_secret": self.get_client_secret(),
|
||||
"refresh_token": self.get_refresh_token(),
|
||||
}
|
||||
|
||||
scopes = self.get_scopes()
|
||||
if scopes:
|
||||
payload["scopes"] = scopes
|
||||
|
||||
refresh_request_body = self.get_refresh_request_body()
|
||||
if refresh_request_body:
|
||||
for key, val in refresh_request_body.items():
|
||||
if key not in payload:
|
||||
payload[key] = val
|
||||
|
||||
return payload
|
||||
|
||||
@abstractmethod
|
||||
def get_grant_type(self) -> str:
|
||||
"""Returns grant_type specified for requesting access_token"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_client_id(self) -> str:
|
||||
"""The client id to authenticate"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_client_secret(self) -> str:
|
||||
"""The client secret to authenticate"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_refresh_token(self) -> Optional[str]:
|
||||
"""The token used to refresh the access token when it expires"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_scopes(self) -> List[str]:
|
||||
"""List of requested scopes"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_refresh_request_body(self) -> Mapping[str, Any]:
|
||||
"""Returns the request body to set on the refresh request"""
|
||||
pass
|
||||
'''
|
||||
original_code = '''import requests
|
||||
from abc import abstractmethod
|
||||
from requests.auth import AuthBase
|
||||
from typing import Any, List, Mapping, MutableMapping, Optional, Tuple
|
||||
|
||||
class AbstractOauth2Authenticator(AuthBase):
|
||||
def __init__(
|
||||
self,
|
||||
refresh_token_error_status_codes: Tuple[int, ...] = (),
|
||||
refresh_token_error_key: str = "",
|
||||
refresh_token_error_values: Tuple[str, ...] = (),
|
||||
) -> None:
|
||||
"""
|
||||
If all of refresh_token_error_status_codes, refresh_token_error_key, and refresh_token_error_values are set,
|
||||
then http errors with such params will be wrapped in AirbyteTracedException.
|
||||
"""
|
||||
self._refresh_token_error_status_codes = refresh_token_error_status_codes
|
||||
self._refresh_token_error_key = refresh_token_error_key
|
||||
self._refresh_token_error_values = refresh_token_error_values
|
||||
def __call__(self, request: requests.PreparedRequest) -> requests.PreparedRequest:
|
||||
"""Attach the HTTP headers required to authenticate on the HTTP request"""
|
||||
request.headers.update(self.get_auth_header())
|
||||
return request
|
||||
def build_refresh_request_body(self) -> Mapping[str, Any]:
|
||||
"""
|
||||
Returns the request body to set on the refresh request
|
||||
|
||||
Override to define additional parameters
|
||||
"""
|
||||
payload: MutableMapping[str, Any] = {
|
||||
"grant_type": self.get_grant_type(),
|
||||
"client_id": self.get_client_id(),
|
||||
"client_secret": self.get_client_secret(),
|
||||
"refresh_token": self.get_refresh_token(),
|
||||
}
|
||||
|
||||
if self.get_scopes():
|
||||
payload["scopes"] = self.get_scopes()
|
||||
|
||||
if self.get_refresh_request_body():
|
||||
for key, val in self.get_refresh_request_body().items():
|
||||
# We defer to existing oauth constructs over custom configured fields
|
||||
if key not in payload:
|
||||
payload[key] = val
|
||||
|
||||
return payload
|
||||
@abstractmethod
|
||||
def get_grant_type(self) -> str:
|
||||
"""Returns grant_type specified for requesting access_token"""
|
||||
@abstractmethod
|
||||
def get_client_id(self) -> str:
|
||||
"""The client id to authenticate"""
|
||||
@abstractmethod
|
||||
def get_client_secret(self) -> str:
|
||||
"""The client secret to authenticate"""
|
||||
@abstractmethod
|
||||
def get_refresh_token(self) -> Optional[str]:
|
||||
"""The token used to refresh the access token when it expires"""
|
||||
@abstractmethod
|
||||
def get_scopes(self) -> List[str]:
|
||||
"""List of requested scopes"""
|
||||
@abstractmethod
|
||||
def get_refresh_request_body(self) -> Mapping[str, Any]:
|
||||
"""Returns the request body to set on the refresh request"""
|
||||
'''
|
||||
|
||||
function_name: str = "AbstractOauth2Authenticator.build_refresh_request_body"
|
||||
# TODO : Fill the right values here
|
||||
preexisting_functions: list[str] = ["__init__", "__call__"]
|
||||
contextual_functions: set[tuple[str, str]] = {
|
||||
("AbstractOauth2Authenticator", "__init__"),
|
||||
("AbstractOauth2Authenticator", "__call__"),
|
||||
}
|
||||
new_code: str = replace_functions_and_add_imports(
|
||||
source_code=original_code,
|
||||
function_names=[function_name],
|
||||
optimized_code=optim_code,
|
||||
file_path_of_module_with_function_to_optimize=str(Path(__file__).resolve()),
|
||||
module_abspath=str(Path(__file__).resolve()),
|
||||
preexisting_functions=preexisting_functions,
|
||||
contextual_functions=contextual_functions,
|
||||
project_root_path=str(Path(__file__).resolve().parent.resolve()),
|
||||
)
|
||||
|
|
|
|||
Loading…
Reference in a new issue