try to repro the code replacement bug

This commit is contained in:
Saurabh Misra 2024-06-17 18:27:13 -07:00
parent 02fd821eb1
commit e669598f1e

View file

@ -791,3 +791,171 @@ class MainClass:
original_source_code=original_code,
).unwrap()
assert code_context.code_to_optimize_with_helpers == get_code_output
def test_code_replacement11() -> None:
optim_code = '''from abc import abstractmethod
from typing import Any, List, Mapping, MutableMapping, Optional, Tuple
import requests
from requests.auth import AuthBase
class AbstractOauth2Authenticator(AuthBase):
def __init__(
self,
refresh_token_error_status_codes: Tuple[int, ...] = (),
refresh_token_error_key: str = "",
refresh_token_error_values: Tuple[str, ...] = (),
) -> None:
"""
If all of refresh_token_error_status_codes, refresh_token_error_key, and refresh_token_error_values are set,
then http errors with such params will be wrapped in AirbyteTracedException.
"""
self._refresh_token_error_status_codes = refresh_token_error_status_codes
self._refresh_token_error_key = refresh_token_error_key
self._refresh_token_error_values = refresh_token_error_values
def __call__(self, request: requests.PreparedRequest) -> requests.PreparedRequest:
"""Attach the HTTP headers required to authenticate on the HTTP request"""
request.headers.update(self.get_auth_header())
return request
def build_refresh_request_body(self) -> Mapping[str, Any]:
"""
Returns the request body to set on the refresh request
Override to define additional parameters
"""
payload: MutableMapping[str, Any] = {
"grant_type": self.get_grant_type(),
"client_id": self.get_client_id(),
"client_secret": self.get_client_secret(),
"refresh_token": self.get_refresh_token(),
}
scopes = self.get_scopes()
if scopes:
payload["scopes"] = scopes
refresh_request_body = self.get_refresh_request_body()
if refresh_request_body:
for key, val in refresh_request_body.items():
if key not in payload:
payload[key] = val
return payload
@abstractmethod
def get_grant_type(self) -> str:
"""Returns grant_type specified for requesting access_token"""
pass
@abstractmethod
def get_client_id(self) -> str:
"""The client id to authenticate"""
pass
@abstractmethod
def get_client_secret(self) -> str:
"""The client secret to authenticate"""
pass
@abstractmethod
def get_refresh_token(self) -> Optional[str]:
"""The token used to refresh the access token when it expires"""
pass
@abstractmethod
def get_scopes(self) -> List[str]:
"""List of requested scopes"""
pass
@abstractmethod
def get_refresh_request_body(self) -> Mapping[str, Any]:
"""Returns the request body to set on the refresh request"""
pass
'''
original_code = '''import requests
from abc import abstractmethod
from requests.auth import AuthBase
from typing import Any, List, Mapping, MutableMapping, Optional, Tuple
class AbstractOauth2Authenticator(AuthBase):
def __init__(
self,
refresh_token_error_status_codes: Tuple[int, ...] = (),
refresh_token_error_key: str = "",
refresh_token_error_values: Tuple[str, ...] = (),
) -> None:
"""
If all of refresh_token_error_status_codes, refresh_token_error_key, and refresh_token_error_values are set,
then http errors with such params will be wrapped in AirbyteTracedException.
"""
self._refresh_token_error_status_codes = refresh_token_error_status_codes
self._refresh_token_error_key = refresh_token_error_key
self._refresh_token_error_values = refresh_token_error_values
def __call__(self, request: requests.PreparedRequest) -> requests.PreparedRequest:
"""Attach the HTTP headers required to authenticate on the HTTP request"""
request.headers.update(self.get_auth_header())
return request
def build_refresh_request_body(self) -> Mapping[str, Any]:
"""
Returns the request body to set on the refresh request
Override to define additional parameters
"""
payload: MutableMapping[str, Any] = {
"grant_type": self.get_grant_type(),
"client_id": self.get_client_id(),
"client_secret": self.get_client_secret(),
"refresh_token": self.get_refresh_token(),
}
if self.get_scopes():
payload["scopes"] = self.get_scopes()
if self.get_refresh_request_body():
for key, val in self.get_refresh_request_body().items():
# We defer to existing oauth constructs over custom configured fields
if key not in payload:
payload[key] = val
return payload
@abstractmethod
def get_grant_type(self) -> str:
"""Returns grant_type specified for requesting access_token"""
@abstractmethod
def get_client_id(self) -> str:
"""The client id to authenticate"""
@abstractmethod
def get_client_secret(self) -> str:
"""The client secret to authenticate"""
@abstractmethod
def get_refresh_token(self) -> Optional[str]:
"""The token used to refresh the access token when it expires"""
@abstractmethod
def get_scopes(self) -> List[str]:
"""List of requested scopes"""
@abstractmethod
def get_refresh_request_body(self) -> Mapping[str, Any]:
"""Returns the request body to set on the refresh request"""
'''
function_name: str = "AbstractOauth2Authenticator.build_refresh_request_body"
# TODO : Fill the right values here
preexisting_functions: list[str] = ["__init__", "__call__"]
contextual_functions: set[tuple[str, str]] = {
("AbstractOauth2Authenticator", "__init__"),
("AbstractOauth2Authenticator", "__call__"),
}
new_code: str = replace_functions_and_add_imports(
source_code=original_code,
function_names=[function_name],
optimized_code=optim_code,
file_path_of_module_with_function_to_optimize=str(Path(__file__).resolve()),
module_abspath=str(Path(__file__).resolve()),
preexisting_functions=preexisting_functions,
contextual_functions=contextual_functions,
project_root_path=str(Path(__file__).resolve().parent.resolve()),
)