mirror of
https://github.com/codeflash-ai/codeflash.git
synced 2026-05-04 18:25:17 +00:00
Another replacment mole bites the dust.
This commit is contained in:
parent
9f29126123
commit
0fbb09ee1c
6 changed files with 186 additions and 237 deletions
|
|
@ -6,17 +6,18 @@ import libcst as cst
|
|||
from libcst import FunctionDef
|
||||
|
||||
from codeflash.code_utils.code_extractor import add_needed_imports_from_module
|
||||
from codeflash.discovery.functions_to_optimize import FunctionParent
|
||||
|
||||
|
||||
class OptimFunctionCollector(cst.CSTVisitor):
|
||||
METADATA_DEPENDENCIES = (cst.metadata.ParentNodeProvider,)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
function_name: str,
|
||||
class_name: str | None,
|
||||
contextual_functions: set[tuple[str, str]],
|
||||
preexisting_functions: list[str] | None = None,
|
||||
self,
|
||||
function_name: str,
|
||||
class_name: str | None,
|
||||
contextual_functions: set[tuple[str, str]],
|
||||
preexisting_functions: list[tuple[str, list[FunctionParent]]] | None = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
if preexisting_functions is None:
|
||||
|
|
@ -42,36 +43,32 @@ class OptimFunctionCollector(cst.CSTVisitor):
|
|||
if node.name.value == self.function_name:
|
||||
self.optim_body = node
|
||||
elif (
|
||||
self.preexisting_functions
|
||||
and node.name.value not in self.preexisting_functions
|
||||
and (
|
||||
isinstance(parent, cst.Module)
|
||||
or (parent2 is not None and not isinstance(parent2, cst.ClassDef))
|
||||
)
|
||||
self.preexisting_functions
|
||||
and (node.name.value, []) not in self.preexisting_functions
|
||||
and (
|
||||
isinstance(parent, cst.Module)
|
||||
or (parent2 is not None and not isinstance(parent2, cst.ClassDef))
|
||||
)
|
||||
):
|
||||
self.optim_new_functions.append(node)
|
||||
|
||||
def visit_ClassDef_body(self, node: cst.ClassDef) -> None:
|
||||
parents = [FunctionParent(name=node.name.value, type="ClassDef")]
|
||||
for child_node in node.body.body:
|
||||
if (
|
||||
isinstance(child_node, cst.FunctionDef)
|
||||
and (
|
||||
node.name.value,
|
||||
child_node.name.value,
|
||||
)
|
||||
not in self.contextual_functions
|
||||
):
|
||||
if isinstance(child_node, cst.FunctionDef) and (
|
||||
node.name.value, child_node.name.value) not in self.contextual_functions and (
|
||||
child_node.name.value, parents) not in self.preexisting_functions:
|
||||
self.optim_new_class_functions.append(child_node)
|
||||
|
||||
|
||||
class OptimFunctionReplacer(cst.CSTTransformer):
|
||||
def __init__(
|
||||
self,
|
||||
function_name: str,
|
||||
optim_body: cst.FunctionDef,
|
||||
optim_new_class_functions: list[cst.FunctionDef],
|
||||
optim_new_functions: list[cst.FunctionDef],
|
||||
class_name: str | None = None,
|
||||
self,
|
||||
function_name: str,
|
||||
optim_body: cst.FunctionDef,
|
||||
optim_new_class_functions: list[cst.FunctionDef],
|
||||
optim_new_functions: list[cst.FunctionDef],
|
||||
class_name: str | None = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.function_name = function_name
|
||||
|
|
@ -86,12 +83,12 @@ class OptimFunctionReplacer(cst.CSTTransformer):
|
|||
return False
|
||||
|
||||
def leave_FunctionDef(
|
||||
self,
|
||||
original_node: cst.FunctionDef,
|
||||
updated_node: cst.FunctionDef,
|
||||
self,
|
||||
original_node: cst.FunctionDef,
|
||||
updated_node: cst.FunctionDef,
|
||||
) -> cst.FunctionDef:
|
||||
if original_node.name.value == self.function_name and (
|
||||
self.depth == 0 or (self.depth == 1 and self.in_class)
|
||||
self.depth == 0 or (self.depth == 1 and self.in_class)
|
||||
):
|
||||
return self.optim_body
|
||||
return updated_node
|
||||
|
|
@ -104,9 +101,9 @@ class OptimFunctionReplacer(cst.CSTTransformer):
|
|||
return self.in_class
|
||||
|
||||
def leave_ClassDef(
|
||||
self,
|
||||
original_node: cst.ClassDef,
|
||||
updated_node: cst.ClassDef,
|
||||
self,
|
||||
original_node: cst.ClassDef,
|
||||
updated_node: cst.ClassDef,
|
||||
) -> cst.ClassDef:
|
||||
self.depth -= 1
|
||||
if self.in_class and (self.depth == 0) and (original_node.name.value == self.class_name):
|
||||
|
|
@ -132,7 +129,7 @@ class OptimFunctionReplacer(cst.CSTTransformer):
|
|||
body=(
|
||||
*node.body[: max_function_index + 1],
|
||||
*self.optim_new_functions,
|
||||
*node.body[max_function_index + 1 :],
|
||||
*node.body[max_function_index + 1:],
|
||||
),
|
||||
)
|
||||
elif class_index is not None:
|
||||
|
|
@ -140,7 +137,7 @@ class OptimFunctionReplacer(cst.CSTTransformer):
|
|||
body=(
|
||||
*node.body[: class_index + 1],
|
||||
*self.optim_new_functions,
|
||||
*node.body[class_index + 1 :],
|
||||
*node.body[class_index + 1:],
|
||||
),
|
||||
)
|
||||
else:
|
||||
|
|
@ -149,11 +146,11 @@ class OptimFunctionReplacer(cst.CSTTransformer):
|
|||
|
||||
|
||||
def replace_functions_in_file(
|
||||
source_code: str,
|
||||
original_function_names: list[str],
|
||||
optimized_code: str,
|
||||
preexisting_functions: list[str],
|
||||
contextual_functions: set[tuple[str, str]],
|
||||
source_code: str,
|
||||
original_function_names: list[str],
|
||||
optimized_code: str,
|
||||
preexisting_functions: list[tuple[str, list[FunctionParent]]],
|
||||
contextual_functions: set[tuple[str, str]],
|
||||
) -> str:
|
||||
parsed_function_names = []
|
||||
for original_function_name in original_function_names:
|
||||
|
|
@ -196,14 +193,14 @@ def replace_functions_in_file(
|
|||
|
||||
|
||||
def replace_functions_and_add_imports(
|
||||
source_code: str,
|
||||
function_names: list[str],
|
||||
optimized_code: str,
|
||||
file_path_of_module_with_function_to_optimize: str,
|
||||
module_abspath: str,
|
||||
preexisting_functions: list[str],
|
||||
contextual_functions: set[tuple[str, str]],
|
||||
project_root_path: str,
|
||||
source_code: str,
|
||||
function_names: list[str],
|
||||
optimized_code: str,
|
||||
file_path_of_module_with_function_to_optimize: str,
|
||||
module_abspath: str,
|
||||
preexisting_functions: list[tuple[str, list[FunctionParent]]],
|
||||
contextual_functions: set[tuple[str, str]],
|
||||
project_root_path: str,
|
||||
) -> str:
|
||||
return add_needed_imports_from_module(
|
||||
optimized_code,
|
||||
|
|
@ -221,13 +218,13 @@ def replace_functions_and_add_imports(
|
|||
|
||||
|
||||
def replace_function_definitions_in_module(
|
||||
function_names: list[str],
|
||||
optimized_code: str,
|
||||
file_path_of_module_with_function_to_optimize: str,
|
||||
module_abspath: str,
|
||||
preexisting_functions: list[str],
|
||||
contextual_functions: set[tuple[str, str]],
|
||||
project_root_path: str,
|
||||
function_names: list[str],
|
||||
optimized_code: str,
|
||||
file_path_of_module_with_function_to_optimize: str,
|
||||
module_abspath: str,
|
||||
preexisting_functions: list[tuple[str, list[FunctionParent]]],
|
||||
contextual_functions: set[tuple[str, str]],
|
||||
project_root_path: str,
|
||||
) -> None:
|
||||
file: IO[str]
|
||||
with open(module_abspath, encoding="utf8") as file:
|
||||
|
|
|
|||
|
|
@ -1,9 +1,10 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import ast
|
||||
import logging
|
||||
import os
|
||||
import site
|
||||
from tempfile import TemporaryDirectory
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
|
||||
def module_name_from_file_path(file_path: str, project_root_path: str) -> str:
|
||||
|
|
@ -21,16 +22,16 @@ def file_path_from_module_name(module_name: str, project_root_path: str) -> str:
|
|||
|
||||
def ellipsis_in_ast(module: ast.AST) -> bool:
|
||||
for node in ast.walk(module):
|
||||
if isinstance(node, ast.Constant) and node.value == ...:
|
||||
if isinstance(node, ast.Constant) and node.value is ...:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def get_imports_from_file(
|
||||
file_path: Optional[str] = None,
|
||||
file_string: Optional[str] = None,
|
||||
file_ast: Optional[ast.AST] = None,
|
||||
) -> List[Union[ast.Import, ast.ImportFrom]]:
|
||||
file_path: str | None = None,
|
||||
file_string: str | None = None,
|
||||
file_ast: ast.AST | None = None,
|
||||
) -> list[ast.Import | ast.ImportFrom]:
|
||||
assert (
|
||||
sum([file_path is not None, file_string is not None, file_ast is not None]) == 1
|
||||
), "Must provide exactly one of file_path, file_string, or file_ast"
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@ from typing import Optional
|
|||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from codeflash.discovery.functions_to_optimize import FunctionParent
|
||||
from codeflash.api.aiservice import OptimizedCandidate
|
||||
from codeflash.optimization.function_context import Source
|
||||
from codeflash.verification.test_results import TestResults
|
||||
|
|
@ -20,7 +21,7 @@ class CodeOptimizationContext(BaseModel):
|
|||
code_to_optimize_with_helpers: str
|
||||
contextual_dunder_methods: set[tuple[str, str]]
|
||||
helper_functions: list[tuple[Source, str, str]]
|
||||
preexisting_functions: list[str]
|
||||
preexisting_functions: list[tuple[str, list[FunctionParent]]]
|
||||
|
||||
|
||||
class OptimizedCandidateResult(BaseModel):
|
||||
|
|
|
|||
|
|
@ -23,7 +23,6 @@ from codeflash.code_utils import env_utils
|
|||
from codeflash.code_utils.code_extractor import add_needed_imports_from_module, extract_code
|
||||
from codeflash.code_utils.code_replacer import replace_function_definitions_in_module
|
||||
from codeflash.code_utils.code_utils import (
|
||||
get_all_function_names,
|
||||
get_run_tmp_file,
|
||||
module_name_from_file_path,
|
||||
)
|
||||
|
|
@ -623,9 +622,12 @@ class Optimizer:
|
|||
)
|
||||
if code_to_optimize is None:
|
||||
return Failure("Could not find function to optimize.")
|
||||
success, preexisting_functions = get_all_function_names(code_to_optimize)
|
||||
if not success:
|
||||
return Failure("Error in parsing the code, skipping optimization.")
|
||||
preexisting_functions: list[tuple[str, list[FunctionParent]]] = [
|
||||
(name, [
|
||||
FunctionParent(
|
||||
name=class_name, type="ClassDef")]) for class_name, name in contextual_dunder_methods]
|
||||
preexisting_functions.append(
|
||||
(function_to_optimize.function_name, function_to_optimize.parents))
|
||||
(
|
||||
helper_code,
|
||||
helper_functions,
|
||||
|
|
@ -667,9 +669,9 @@ class Optimizer:
|
|||
function_to_optimize.file_path,
|
||||
project_root,
|
||||
)
|
||||
preexisting_functions.extend(
|
||||
[fn[0].full_name.split(".")[-1] for fn in helper_functions],
|
||||
)
|
||||
preexisting_functions.extend([(qualified_name_list[-1], ([FunctionParent(name=qualified_name_list[-2], type="ClassDef")])) if len(
|
||||
qualified_name_list := fn[0].full_name.split(".")) > 1 else (
|
||||
qualified_name_list[-1], []) for fn in helper_functions])
|
||||
contextual_dunder_methods.update(helper_dunder_methods)
|
||||
return Success(
|
||||
CodeOptimizationContext(
|
||||
|
|
|
|||
|
|
@ -0,0 +1,74 @@
|
|||
from abc import abstractmethod
|
||||
from typing import Any, List, Mapping, MutableMapping, Optional, Tuple
|
||||
|
||||
import requests
|
||||
from requests.auth import AuthBase
|
||||
|
||||
|
||||
class AbstractOauth2Authenticator(AuthBase):
|
||||
def __init__(
|
||||
self,
|
||||
refresh_token_error_status_codes: Tuple[int, ...] = (),
|
||||
refresh_token_error_key: str = "",
|
||||
refresh_token_error_values: Tuple[str, ...] = (),
|
||||
) -> None:
|
||||
"""
|
||||
If all of refresh_token_error_status_codes, refresh_token_error_key, and refresh_token_error_values are set,
|
||||
then http errors with such params will be wrapped in AirbyteTracedException.
|
||||
"""
|
||||
self._refresh_token_error_status_codes = refresh_token_error_status_codes
|
||||
self._refresh_token_error_key = refresh_token_error_key
|
||||
self._refresh_token_error_values = refresh_token_error_values
|
||||
|
||||
def __call__(self, request: requests.PreparedRequest) -> requests.PreparedRequest:
|
||||
"""Attach the HTTP headers required to authenticate on the HTTP request"""
|
||||
request.headers.update(self.get_auth_header())
|
||||
return request
|
||||
|
||||
def build_refresh_request_body(self) -> Mapping[str, Any]:
|
||||
"""
|
||||
Returns the request body to set on the refresh request
|
||||
|
||||
Override to define additional parameters
|
||||
"""
|
||||
payload: MutableMapping[str, Any] = {
|
||||
"grant_type": self.get_grant_type(),
|
||||
"client_id": self.get_client_id(),
|
||||
"client_secret": self.get_client_secret(),
|
||||
"refresh_token": self.get_refresh_token(),
|
||||
}
|
||||
|
||||
if self.get_scopes():
|
||||
payload["scopes"] = self.get_scopes()
|
||||
|
||||
if self.get_refresh_request_body():
|
||||
for key, val in self.get_refresh_request_body().items():
|
||||
# We defer to existing oauth constructs over custom configured fields
|
||||
if key not in payload:
|
||||
payload[key] = val
|
||||
|
||||
return payload
|
||||
|
||||
@abstractmethod
|
||||
def get_grant_type(self) -> str:
|
||||
"""Returns grant_type specified for requesting access_token"""
|
||||
|
||||
@abstractmethod
|
||||
def get_client_id(self) -> str:
|
||||
"""The client id to authenticate"""
|
||||
|
||||
@abstractmethod
|
||||
def get_client_secret(self) -> str:
|
||||
"""The client secret to authenticate"""
|
||||
|
||||
@abstractmethod
|
||||
def get_refresh_token(self) -> Optional[str]:
|
||||
"""The token used to refresh the access token when it expires"""
|
||||
|
||||
@abstractmethod
|
||||
def get_scopes(self) -> List[str]:
|
||||
"""List of requested scopes"""
|
||||
|
||||
@abstractmethod
|
||||
def get_refresh_request_body(self) -> Mapping[str, Any]:
|
||||
"""Returns the request body to set on the refresh request"""
|
||||
|
|
@ -4,7 +4,7 @@ import os
|
|||
from argparse import Namespace
|
||||
from pathlib import Path
|
||||
|
||||
from codeflash.code_utils.code_replacer import replace_functions_and_add_imports
|
||||
from codeflash.code_utils.code_replacer import replace_functions_and_add_imports, replace_functions_in_file
|
||||
from codeflash.discovery.functions_to_optimize import FunctionParent, FunctionToOptimize
|
||||
from codeflash.optimization.optimizer import Optimizer
|
||||
|
||||
|
|
@ -50,7 +50,8 @@ print("Hello world")
|
|||
"""
|
||||
|
||||
function_name: str = "NewClass.new_function"
|
||||
preexisting_functions: list[str] = ["new_function"]
|
||||
preexisting_functions: list[tuple[str, list[FunctionParent]]] = [
|
||||
("new_function", [FunctionParent(name="NewClass", type="ClassDef")])]
|
||||
contextual_functions: set[tuple[str, str]] = {("NewClass", "__init__")}
|
||||
new_code: str = replace_functions_and_add_imports(
|
||||
source_code=original_code,
|
||||
|
|
@ -111,7 +112,7 @@ print("Hello world")
|
|||
"""
|
||||
|
||||
function_name: str = "NewClass.new_function"
|
||||
preexisting_functions: list[str] = ["new_function", "other_function"]
|
||||
preexisting_functions: list[tuple[str, list[FunctionParent]]] = [("new_function", []), ("other_function", [])]
|
||||
contextual_functions: set[tuple[str, str]] = {("NewClass", "__init__")}
|
||||
new_code: str = replace_functions_and_add_imports(
|
||||
source_code=original_code,
|
||||
|
|
@ -172,7 +173,7 @@ print("Salut monde")
|
|||
"""
|
||||
|
||||
function_names: list[str] = ["module.other_function"]
|
||||
preexisting_functions: list[str] = []
|
||||
preexisting_functions: list[tuple[str, list[FunctionParent]]] = []
|
||||
contextual_functions: set[tuple[str, str]] = set()
|
||||
new_code: str = replace_functions_and_add_imports(
|
||||
source_code=original_code,
|
||||
|
|
@ -236,7 +237,7 @@ print("Salut monde")
|
|||
"""
|
||||
|
||||
function_names: list[str] = ["module.yet_another_function", "module.other_function"]
|
||||
preexisting_functions: list[str] = []
|
||||
preexisting_functions: list[tuple[str, list[FunctionParent]]] = []
|
||||
contextual_functions: set[tuple[str, str]] = set()
|
||||
new_code: str = replace_functions_and_add_imports(
|
||||
source_code=original_code,
|
||||
|
|
@ -289,7 +290,7 @@ def supersort(doink):
|
|||
"""
|
||||
|
||||
function_names: list[str] = ["sorter_deps"]
|
||||
preexisting_functions: list[str] = ["sorter_deps"]
|
||||
preexisting_functions: list[tuple[str, list[FunctionParent]]] = [("sorter_deps", [])]
|
||||
contextual_functions: set[tuple[str, str]] = set()
|
||||
new_code: str = replace_functions_and_add_imports(
|
||||
source_code=original_code,
|
||||
|
|
@ -373,7 +374,7 @@ print("Not cool")
|
|||
optimized_code=optim_code,
|
||||
file_path_of_module_with_function_to_optimize=str(Path(__file__).resolve()),
|
||||
module_abspath=str(Path(__file__).resolve()),
|
||||
preexisting_functions=["other_function", "yet_another_function", "blob"],
|
||||
preexisting_functions=[("other_function", []), ("yet_another_function", []), ("blob", [])],
|
||||
contextual_functions=set(),
|
||||
project_root_path=str(Path(__file__).resolve().parent.resolve()),
|
||||
)
|
||||
|
|
@ -573,10 +574,9 @@ class CacheConfig(BaseConfig):
|
|||
)
|
||||
"""
|
||||
function_names: list[str] = ["CacheSimilarityEvalConfig.from_config"]
|
||||
preexisting_functions: list[str] = [
|
||||
"__init__",
|
||||
"from_config",
|
||||
]
|
||||
parents = [FunctionParent(name="CacheConfig", type="ClassDef")]
|
||||
preexisting_functions: list[tuple[str, list[FunctionParent]]] = [("__init__", parents), ("from_config", parents)]
|
||||
|
||||
contextual_functions: set[tuple[str, str]] = {
|
||||
("CacheSimilarityEvalConfig", "__init__"),
|
||||
("CacheConfig", "__init__"),
|
||||
|
|
@ -652,9 +652,8 @@ def test_test_libcst_code_replacement8() -> None:
|
|||
return np.sum(a != b) / a.size
|
||||
'''
|
||||
function_names: list[str] = ["_EmbeddingDistanceChainMixin._hamming_distance"]
|
||||
preexisting_functions: list[str] = [
|
||||
"_hamming_distance",
|
||||
]
|
||||
preexisting_functions: list[tuple[str, list[FunctionParent]]] = [
|
||||
("_hamming_distance", [FunctionParent("_EmbeddingDistanceChainMixin", "ClassDef")])]
|
||||
contextual_functions: set[tuple[str, str]] = set()
|
||||
new_code: str = replace_functions_and_add_imports(
|
||||
source_code=original_code,
|
||||
|
|
@ -709,9 +708,9 @@ def totally_new_function(value: Optional[str]):
|
|||
|
||||
print("Hello world")
|
||||
"""
|
||||
|
||||
parents = [FunctionParent(name="NewClass", type="ClassDef")]
|
||||
function_name: str = "NewClass.__init__"
|
||||
preexisting_functions: list[str] = ["__init__", "__call__"]
|
||||
preexisting_functions: list[tuple[str, list[FunctionParent]]] = [("__init__", parents), ("__call__", parents)]
|
||||
contextual_functions: set[tuple[str, str]] = {
|
||||
("NewClass", "__init__"),
|
||||
("NewClass", "__call__"),
|
||||
|
|
@ -794,168 +793,43 @@ class MainClass:
|
|||
|
||||
|
||||
def test_code_replacement11() -> None:
|
||||
optim_code = '''from abc import abstractmethod
|
||||
from typing import Any, List, Mapping, MutableMapping, Optional, Tuple
|
||||
|
||||
import requests
|
||||
from requests.auth import AuthBase
|
||||
|
||||
|
||||
class AbstractOauth2Authenticator(AuthBase):
|
||||
def __init__(
|
||||
self,
|
||||
refresh_token_error_status_codes: Tuple[int, ...] = (),
|
||||
refresh_token_error_key: str = "",
|
||||
refresh_token_error_values: Tuple[str, ...] = (),
|
||||
) -> None:
|
||||
"""
|
||||
If all of refresh_token_error_status_codes, refresh_token_error_key, and refresh_token_error_values are set,
|
||||
then http errors with such params will be wrapped in AirbyteTracedException.
|
||||
"""
|
||||
self._refresh_token_error_status_codes = refresh_token_error_status_codes
|
||||
self._refresh_token_error_key = refresh_token_error_key
|
||||
self._refresh_token_error_values = refresh_token_error_values
|
||||
|
||||
def __call__(self, request: requests.PreparedRequest) -> requests.PreparedRequest:
|
||||
"""Attach the HTTP headers required to authenticate on the HTTP request"""
|
||||
request.headers.update(self.get_auth_header())
|
||||
return request
|
||||
|
||||
def build_refresh_request_body(self) -> Mapping[str, Any]:
|
||||
"""
|
||||
Returns the request body to set on the refresh request
|
||||
|
||||
Override to define additional parameters
|
||||
"""
|
||||
payload: MutableMapping[str, Any] = {
|
||||
"grant_type": self.get_grant_type(),
|
||||
"client_id": self.get_client_id(),
|
||||
"client_secret": self.get_client_secret(),
|
||||
"refresh_token": self.get_refresh_token(),
|
||||
}
|
||||
|
||||
scopes = self.get_scopes()
|
||||
if scopes:
|
||||
payload["scopes"] = scopes
|
||||
|
||||
refresh_request_body = self.get_refresh_request_body()
|
||||
if refresh_request_body:
|
||||
for key, val in refresh_request_body.items():
|
||||
if key not in payload:
|
||||
payload[key] = val
|
||||
|
||||
optim_code = '''class Fu():
|
||||
def foo(self) -> dict[str, str]:
|
||||
payload: dict[str, str] = {"bar": self.bar(), "real_bar": str(self.real_bar() + 1)}
|
||||
return payload
|
||||
|
||||
@abstractmethod
|
||||
def get_grant_type(self) -> str:
|
||||
"""Returns grant_type specified for requesting access_token"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_client_id(self) -> str:
|
||||
"""The client id to authenticate"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_client_secret(self) -> str:
|
||||
"""The client secret to authenticate"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_refresh_token(self) -> Optional[str]:
|
||||
"""The token used to refresh the access token when it expires"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_scopes(self) -> List[str]:
|
||||
"""List of requested scopes"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_refresh_request_body(self) -> Mapping[str, Any]:
|
||||
"""Returns the request body to set on the refresh request"""
|
||||
def real_bar(self) -> int:
|
||||
"""No abstract nonsense"""
|
||||
pass
|
||||
'''
|
||||
original_code = '''import requests
|
||||
from abc import abstractmethod
|
||||
from requests.auth import AuthBase
|
||||
from typing import Any, List, Mapping, MutableMapping, Optional, Tuple
|
||||
|
||||
class AbstractOauth2Authenticator(AuthBase):
|
||||
def __init__(
|
||||
self,
|
||||
refresh_token_error_status_codes: Tuple[int, ...] = (),
|
||||
refresh_token_error_key: str = "",
|
||||
refresh_token_error_values: Tuple[str, ...] = (),
|
||||
) -> None:
|
||||
"""
|
||||
If all of refresh_token_error_status_codes, refresh_token_error_key, and refresh_token_error_values are set,
|
||||
then http errors with such params will be wrapped in AirbyteTracedException.
|
||||
"""
|
||||
self._refresh_token_error_status_codes = refresh_token_error_status_codes
|
||||
self._refresh_token_error_key = refresh_token_error_key
|
||||
self._refresh_token_error_values = refresh_token_error_values
|
||||
def __call__(self, request: requests.PreparedRequest) -> requests.PreparedRequest:
|
||||
"""Attach the HTTP headers required to authenticate on the HTTP request"""
|
||||
request.headers.update(self.get_auth_header())
|
||||
return request
|
||||
def build_refresh_request_body(self) -> Mapping[str, Any]:
|
||||
"""
|
||||
Returns the request body to set on the refresh request
|
||||
|
||||
Override to define additional parameters
|
||||
"""
|
||||
payload: MutableMapping[str, Any] = {
|
||||
"grant_type": self.get_grant_type(),
|
||||
"client_id": self.get_client_id(),
|
||||
"client_secret": self.get_client_secret(),
|
||||
"refresh_token": self.get_refresh_token(),
|
||||
}
|
||||
|
||||
if self.get_scopes():
|
||||
payload["scopes"] = self.get_scopes()
|
||||
|
||||
if self.get_refresh_request_body():
|
||||
for key, val in self.get_refresh_request_body().items():
|
||||
# We defer to existing oauth constructs over custom configured fields
|
||||
if key not in payload:
|
||||
payload[key] = val
|
||||
|
||||
original_code = '''class Fu():
|
||||
def foo(self) -> dict[str, str]:
|
||||
payload: dict[str, str] = {"bar": self.bar(), "real_bar": str(self.real_bar())}
|
||||
return payload
|
||||
@abstractmethod
|
||||
def get_grant_type(self) -> str:
|
||||
"""Returns grant_type specified for requesting access_token"""
|
||||
@abstractmethod
|
||||
def get_client_id(self) -> str:
|
||||
"""The client id to authenticate"""
|
||||
@abstractmethod
|
||||
def get_client_secret(self) -> str:
|
||||
"""The client secret to authenticate"""
|
||||
@abstractmethod
|
||||
def get_refresh_token(self) -> Optional[str]:
|
||||
"""The token used to refresh the access token when it expires"""
|
||||
@abstractmethod
|
||||
def get_scopes(self) -> List[str]:
|
||||
"""List of requested scopes"""
|
||||
@abstractmethod
|
||||
def get_refresh_request_body(self) -> Mapping[str, Any]:
|
||||
"""Returns the request body to set on the refresh request"""
|
||||
|
||||
def real_bar(self) -> int:
|
||||
"""No abstract nonsense"""
|
||||
return 0
|
||||
'''
|
||||
expected_code = '''class Fu():
|
||||
def foo(self) -> dict[str, str]:
|
||||
payload: dict[str, str] = {"bar": self.bar(), "real_bar": str(self.real_bar() + 1)}
|
||||
return payload
|
||||
|
||||
def real_bar(self) -> int:
|
||||
"""No abstract nonsense"""
|
||||
return 0
|
||||
'''
|
||||
|
||||
function_name: str = "AbstractOauth2Authenticator.build_refresh_request_body"
|
||||
# TODO : Fill the right values here
|
||||
preexisting_functions: list[str] = ["__init__", "__call__"]
|
||||
contextual_functions: set[tuple[str, str]] = {
|
||||
("AbstractOauth2Authenticator", "__init__"),
|
||||
("AbstractOauth2Authenticator", "__call__"),
|
||||
}
|
||||
new_code: str = replace_functions_and_add_imports(
|
||||
function_name: str = "Fu.foo"
|
||||
parents = [FunctionParent("Fu", "ClassDef")]
|
||||
preexisting_functions: list[tuple[str, list[FunctionParent]]] = [("foo", parents), ("real_bar", parents)]
|
||||
contextual_functions: set[tuple[str, str]] = set()
|
||||
new_code: str = replace_functions_in_file(
|
||||
source_code=original_code,
|
||||
function_names=[function_name],
|
||||
original_function_names=[function_name],
|
||||
optimized_code=optim_code,
|
||||
file_path_of_module_with_function_to_optimize=str(Path(__file__).resolve()),
|
||||
module_abspath=str(Path(__file__).resolve()),
|
||||
preexisting_functions=preexisting_functions,
|
||||
contextual_functions=contextual_functions,
|
||||
project_root_path=str(Path(__file__).resolve().parent.resolve()),
|
||||
)
|
||||
assert new_code == expected_code
|
||||
|
|
|
|||
Loading…
Reference in a new issue