Another replacment mole bites the dust.

This commit is contained in:
RD 2024-06-21 16:43:43 -07:00
parent 9f29126123
commit 0fbb09ee1c
6 changed files with 186 additions and 237 deletions

View file

@ -6,17 +6,18 @@ import libcst as cst
from libcst import FunctionDef
from codeflash.code_utils.code_extractor import add_needed_imports_from_module
from codeflash.discovery.functions_to_optimize import FunctionParent
class OptimFunctionCollector(cst.CSTVisitor):
METADATA_DEPENDENCIES = (cst.metadata.ParentNodeProvider,)
def __init__(
self,
function_name: str,
class_name: str | None,
contextual_functions: set[tuple[str, str]],
preexisting_functions: list[str] | None = None,
self,
function_name: str,
class_name: str | None,
contextual_functions: set[tuple[str, str]],
preexisting_functions: list[tuple[str, list[FunctionParent]]] | None = None,
) -> None:
super().__init__()
if preexisting_functions is None:
@ -42,36 +43,32 @@ class OptimFunctionCollector(cst.CSTVisitor):
if node.name.value == self.function_name:
self.optim_body = node
elif (
self.preexisting_functions
and node.name.value not in self.preexisting_functions
and (
isinstance(parent, cst.Module)
or (parent2 is not None and not isinstance(parent2, cst.ClassDef))
)
self.preexisting_functions
and (node.name.value, []) not in self.preexisting_functions
and (
isinstance(parent, cst.Module)
or (parent2 is not None and not isinstance(parent2, cst.ClassDef))
)
):
self.optim_new_functions.append(node)
def visit_ClassDef_body(self, node: cst.ClassDef) -> None:
parents = [FunctionParent(name=node.name.value, type="ClassDef")]
for child_node in node.body.body:
if (
isinstance(child_node, cst.FunctionDef)
and (
node.name.value,
child_node.name.value,
)
not in self.contextual_functions
):
if isinstance(child_node, cst.FunctionDef) and (
node.name.value, child_node.name.value) not in self.contextual_functions and (
child_node.name.value, parents) not in self.preexisting_functions:
self.optim_new_class_functions.append(child_node)
class OptimFunctionReplacer(cst.CSTTransformer):
def __init__(
self,
function_name: str,
optim_body: cst.FunctionDef,
optim_new_class_functions: list[cst.FunctionDef],
optim_new_functions: list[cst.FunctionDef],
class_name: str | None = None,
self,
function_name: str,
optim_body: cst.FunctionDef,
optim_new_class_functions: list[cst.FunctionDef],
optim_new_functions: list[cst.FunctionDef],
class_name: str | None = None,
) -> None:
super().__init__()
self.function_name = function_name
@ -86,12 +83,12 @@ class OptimFunctionReplacer(cst.CSTTransformer):
return False
def leave_FunctionDef(
self,
original_node: cst.FunctionDef,
updated_node: cst.FunctionDef,
self,
original_node: cst.FunctionDef,
updated_node: cst.FunctionDef,
) -> cst.FunctionDef:
if original_node.name.value == self.function_name and (
self.depth == 0 or (self.depth == 1 and self.in_class)
self.depth == 0 or (self.depth == 1 and self.in_class)
):
return self.optim_body
return updated_node
@ -104,9 +101,9 @@ class OptimFunctionReplacer(cst.CSTTransformer):
return self.in_class
def leave_ClassDef(
self,
original_node: cst.ClassDef,
updated_node: cst.ClassDef,
self,
original_node: cst.ClassDef,
updated_node: cst.ClassDef,
) -> cst.ClassDef:
self.depth -= 1
if self.in_class and (self.depth == 0) and (original_node.name.value == self.class_name):
@ -132,7 +129,7 @@ class OptimFunctionReplacer(cst.CSTTransformer):
body=(
*node.body[: max_function_index + 1],
*self.optim_new_functions,
*node.body[max_function_index + 1 :],
*node.body[max_function_index + 1:],
),
)
elif class_index is not None:
@ -140,7 +137,7 @@ class OptimFunctionReplacer(cst.CSTTransformer):
body=(
*node.body[: class_index + 1],
*self.optim_new_functions,
*node.body[class_index + 1 :],
*node.body[class_index + 1:],
),
)
else:
@ -149,11 +146,11 @@ class OptimFunctionReplacer(cst.CSTTransformer):
def replace_functions_in_file(
source_code: str,
original_function_names: list[str],
optimized_code: str,
preexisting_functions: list[str],
contextual_functions: set[tuple[str, str]],
source_code: str,
original_function_names: list[str],
optimized_code: str,
preexisting_functions: list[tuple[str, list[FunctionParent]]],
contextual_functions: set[tuple[str, str]],
) -> str:
parsed_function_names = []
for original_function_name in original_function_names:
@ -196,14 +193,14 @@ def replace_functions_in_file(
def replace_functions_and_add_imports(
source_code: str,
function_names: list[str],
optimized_code: str,
file_path_of_module_with_function_to_optimize: str,
module_abspath: str,
preexisting_functions: list[str],
contextual_functions: set[tuple[str, str]],
project_root_path: str,
source_code: str,
function_names: list[str],
optimized_code: str,
file_path_of_module_with_function_to_optimize: str,
module_abspath: str,
preexisting_functions: list[tuple[str, list[FunctionParent]]],
contextual_functions: set[tuple[str, str]],
project_root_path: str,
) -> str:
return add_needed_imports_from_module(
optimized_code,
@ -221,13 +218,13 @@ def replace_functions_and_add_imports(
def replace_function_definitions_in_module(
function_names: list[str],
optimized_code: str,
file_path_of_module_with_function_to_optimize: str,
module_abspath: str,
preexisting_functions: list[str],
contextual_functions: set[tuple[str, str]],
project_root_path: str,
function_names: list[str],
optimized_code: str,
file_path_of_module_with_function_to_optimize: str,
module_abspath: str,
preexisting_functions: list[tuple[str, list[FunctionParent]]],
contextual_functions: set[tuple[str, str]],
project_root_path: str,
) -> None:
file: IO[str]
with open(module_abspath, encoding="utf8") as file:

View file

@ -1,9 +1,10 @@
from __future__ import annotations
import ast
import logging
import os
import site
from tempfile import TemporaryDirectory
from typing import List, Optional, Tuple, Union
def module_name_from_file_path(file_path: str, project_root_path: str) -> str:
@ -21,16 +22,16 @@ def file_path_from_module_name(module_name: str, project_root_path: str) -> str:
def ellipsis_in_ast(module: ast.AST) -> bool:
for node in ast.walk(module):
if isinstance(node, ast.Constant) and node.value == ...:
if isinstance(node, ast.Constant) and node.value is ...:
return True
return False
def get_imports_from_file(
file_path: Optional[str] = None,
file_string: Optional[str] = None,
file_ast: Optional[ast.AST] = None,
) -> List[Union[ast.Import, ast.ImportFrom]]:
file_path: str | None = None,
file_string: str | None = None,
file_ast: ast.AST | None = None,
) -> list[ast.Import | ast.ImportFrom]:
assert (
sum([file_path is not None, file_string is not None, file_ast is not None]) == 1
), "Must provide exactly one of file_path, file_string, or file_ast"

View file

@ -4,6 +4,7 @@ from typing import Optional
from pydantic import BaseModel
from codeflash.discovery.functions_to_optimize import FunctionParent
from codeflash.api.aiservice import OptimizedCandidate
from codeflash.optimization.function_context import Source
from codeflash.verification.test_results import TestResults
@ -20,7 +21,7 @@ class CodeOptimizationContext(BaseModel):
code_to_optimize_with_helpers: str
contextual_dunder_methods: set[tuple[str, str]]
helper_functions: list[tuple[Source, str, str]]
preexisting_functions: list[str]
preexisting_functions: list[tuple[str, list[FunctionParent]]]
class OptimizedCandidateResult(BaseModel):

View file

@ -23,7 +23,6 @@ from codeflash.code_utils import env_utils
from codeflash.code_utils.code_extractor import add_needed_imports_from_module, extract_code
from codeflash.code_utils.code_replacer import replace_function_definitions_in_module
from codeflash.code_utils.code_utils import (
get_all_function_names,
get_run_tmp_file,
module_name_from_file_path,
)
@ -623,9 +622,12 @@ class Optimizer:
)
if code_to_optimize is None:
return Failure("Could not find function to optimize.")
success, preexisting_functions = get_all_function_names(code_to_optimize)
if not success:
return Failure("Error in parsing the code, skipping optimization.")
preexisting_functions: list[tuple[str, list[FunctionParent]]] = [
(name, [
FunctionParent(
name=class_name, type="ClassDef")]) for class_name, name in contextual_dunder_methods]
preexisting_functions.append(
(function_to_optimize.function_name, function_to_optimize.parents))
(
helper_code,
helper_functions,
@ -667,9 +669,9 @@ class Optimizer:
function_to_optimize.file_path,
project_root,
)
preexisting_functions.extend(
[fn[0].full_name.split(".")[-1] for fn in helper_functions],
)
preexisting_functions.extend([(qualified_name_list[-1], ([FunctionParent(name=qualified_name_list[-2], type="ClassDef")])) if len(
qualified_name_list := fn[0].full_name.split(".")) > 1 else (
qualified_name_list[-1], []) for fn in helper_functions])
contextual_dunder_methods.update(helper_dunder_methods)
return Success(
CodeOptimizationContext(

View file

@ -0,0 +1,74 @@
from abc import abstractmethod
from typing import Any, List, Mapping, MutableMapping, Optional, Tuple
import requests
from requests.auth import AuthBase
class AbstractOauth2Authenticator(AuthBase):
def __init__(
self,
refresh_token_error_status_codes: Tuple[int, ...] = (),
refresh_token_error_key: str = "",
refresh_token_error_values: Tuple[str, ...] = (),
) -> None:
"""
If all of refresh_token_error_status_codes, refresh_token_error_key, and refresh_token_error_values are set,
then http errors with such params will be wrapped in AirbyteTracedException.
"""
self._refresh_token_error_status_codes = refresh_token_error_status_codes
self._refresh_token_error_key = refresh_token_error_key
self._refresh_token_error_values = refresh_token_error_values
def __call__(self, request: requests.PreparedRequest) -> requests.PreparedRequest:
"""Attach the HTTP headers required to authenticate on the HTTP request"""
request.headers.update(self.get_auth_header())
return request
def build_refresh_request_body(self) -> Mapping[str, Any]:
"""
Returns the request body to set on the refresh request
Override to define additional parameters
"""
payload: MutableMapping[str, Any] = {
"grant_type": self.get_grant_type(),
"client_id": self.get_client_id(),
"client_secret": self.get_client_secret(),
"refresh_token": self.get_refresh_token(),
}
if self.get_scopes():
payload["scopes"] = self.get_scopes()
if self.get_refresh_request_body():
for key, val in self.get_refresh_request_body().items():
# We defer to existing oauth constructs over custom configured fields
if key not in payload:
payload[key] = val
return payload
@abstractmethod
def get_grant_type(self) -> str:
"""Returns grant_type specified for requesting access_token"""
@abstractmethod
def get_client_id(self) -> str:
"""The client id to authenticate"""
@abstractmethod
def get_client_secret(self) -> str:
"""The client secret to authenticate"""
@abstractmethod
def get_refresh_token(self) -> Optional[str]:
"""The token used to refresh the access token when it expires"""
@abstractmethod
def get_scopes(self) -> List[str]:
"""List of requested scopes"""
@abstractmethod
def get_refresh_request_body(self) -> Mapping[str, Any]:
"""Returns the request body to set on the refresh request"""

View file

@ -4,7 +4,7 @@ import os
from argparse import Namespace
from pathlib import Path
from codeflash.code_utils.code_replacer import replace_functions_and_add_imports
from codeflash.code_utils.code_replacer import replace_functions_and_add_imports, replace_functions_in_file
from codeflash.discovery.functions_to_optimize import FunctionParent, FunctionToOptimize
from codeflash.optimization.optimizer import Optimizer
@ -50,7 +50,8 @@ print("Hello world")
"""
function_name: str = "NewClass.new_function"
preexisting_functions: list[str] = ["new_function"]
preexisting_functions: list[tuple[str, list[FunctionParent]]] = [
("new_function", [FunctionParent(name="NewClass", type="ClassDef")])]
contextual_functions: set[tuple[str, str]] = {("NewClass", "__init__")}
new_code: str = replace_functions_and_add_imports(
source_code=original_code,
@ -111,7 +112,7 @@ print("Hello world")
"""
function_name: str = "NewClass.new_function"
preexisting_functions: list[str] = ["new_function", "other_function"]
preexisting_functions: list[tuple[str, list[FunctionParent]]] = [("new_function", []), ("other_function", [])]
contextual_functions: set[tuple[str, str]] = {("NewClass", "__init__")}
new_code: str = replace_functions_and_add_imports(
source_code=original_code,
@ -172,7 +173,7 @@ print("Salut monde")
"""
function_names: list[str] = ["module.other_function"]
preexisting_functions: list[str] = []
preexisting_functions: list[tuple[str, list[FunctionParent]]] = []
contextual_functions: set[tuple[str, str]] = set()
new_code: str = replace_functions_and_add_imports(
source_code=original_code,
@ -236,7 +237,7 @@ print("Salut monde")
"""
function_names: list[str] = ["module.yet_another_function", "module.other_function"]
preexisting_functions: list[str] = []
preexisting_functions: list[tuple[str, list[FunctionParent]]] = []
contextual_functions: set[tuple[str, str]] = set()
new_code: str = replace_functions_and_add_imports(
source_code=original_code,
@ -289,7 +290,7 @@ def supersort(doink):
"""
function_names: list[str] = ["sorter_deps"]
preexisting_functions: list[str] = ["sorter_deps"]
preexisting_functions: list[tuple[str, list[FunctionParent]]] = [("sorter_deps", [])]
contextual_functions: set[tuple[str, str]] = set()
new_code: str = replace_functions_and_add_imports(
source_code=original_code,
@ -373,7 +374,7 @@ print("Not cool")
optimized_code=optim_code,
file_path_of_module_with_function_to_optimize=str(Path(__file__).resolve()),
module_abspath=str(Path(__file__).resolve()),
preexisting_functions=["other_function", "yet_another_function", "blob"],
preexisting_functions=[("other_function", []), ("yet_another_function", []), ("blob", [])],
contextual_functions=set(),
project_root_path=str(Path(__file__).resolve().parent.resolve()),
)
@ -573,10 +574,9 @@ class CacheConfig(BaseConfig):
)
"""
function_names: list[str] = ["CacheSimilarityEvalConfig.from_config"]
preexisting_functions: list[str] = [
"__init__",
"from_config",
]
parents = [FunctionParent(name="CacheConfig", type="ClassDef")]
preexisting_functions: list[tuple[str, list[FunctionParent]]] = [("__init__", parents), ("from_config", parents)]
contextual_functions: set[tuple[str, str]] = {
("CacheSimilarityEvalConfig", "__init__"),
("CacheConfig", "__init__"),
@ -652,9 +652,8 @@ def test_test_libcst_code_replacement8() -> None:
return np.sum(a != b) / a.size
'''
function_names: list[str] = ["_EmbeddingDistanceChainMixin._hamming_distance"]
preexisting_functions: list[str] = [
"_hamming_distance",
]
preexisting_functions: list[tuple[str, list[FunctionParent]]] = [
("_hamming_distance", [FunctionParent("_EmbeddingDistanceChainMixin", "ClassDef")])]
contextual_functions: set[tuple[str, str]] = set()
new_code: str = replace_functions_and_add_imports(
source_code=original_code,
@ -709,9 +708,9 @@ def totally_new_function(value: Optional[str]):
print("Hello world")
"""
parents = [FunctionParent(name="NewClass", type="ClassDef")]
function_name: str = "NewClass.__init__"
preexisting_functions: list[str] = ["__init__", "__call__"]
preexisting_functions: list[tuple[str, list[FunctionParent]]] = [("__init__", parents), ("__call__", parents)]
contextual_functions: set[tuple[str, str]] = {
("NewClass", "__init__"),
("NewClass", "__call__"),
@ -794,168 +793,43 @@ class MainClass:
def test_code_replacement11() -> None:
optim_code = '''from abc import abstractmethod
from typing import Any, List, Mapping, MutableMapping, Optional, Tuple
import requests
from requests.auth import AuthBase
class AbstractOauth2Authenticator(AuthBase):
def __init__(
self,
refresh_token_error_status_codes: Tuple[int, ...] = (),
refresh_token_error_key: str = "",
refresh_token_error_values: Tuple[str, ...] = (),
) -> None:
"""
If all of refresh_token_error_status_codes, refresh_token_error_key, and refresh_token_error_values are set,
then http errors with such params will be wrapped in AirbyteTracedException.
"""
self._refresh_token_error_status_codes = refresh_token_error_status_codes
self._refresh_token_error_key = refresh_token_error_key
self._refresh_token_error_values = refresh_token_error_values
def __call__(self, request: requests.PreparedRequest) -> requests.PreparedRequest:
"""Attach the HTTP headers required to authenticate on the HTTP request"""
request.headers.update(self.get_auth_header())
return request
def build_refresh_request_body(self) -> Mapping[str, Any]:
"""
Returns the request body to set on the refresh request
Override to define additional parameters
"""
payload: MutableMapping[str, Any] = {
"grant_type": self.get_grant_type(),
"client_id": self.get_client_id(),
"client_secret": self.get_client_secret(),
"refresh_token": self.get_refresh_token(),
}
scopes = self.get_scopes()
if scopes:
payload["scopes"] = scopes
refresh_request_body = self.get_refresh_request_body()
if refresh_request_body:
for key, val in refresh_request_body.items():
if key not in payload:
payload[key] = val
optim_code = '''class Fu():
def foo(self) -> dict[str, str]:
payload: dict[str, str] = {"bar": self.bar(), "real_bar": str(self.real_bar() + 1)}
return payload
@abstractmethod
def get_grant_type(self) -> str:
"""Returns grant_type specified for requesting access_token"""
pass
@abstractmethod
def get_client_id(self) -> str:
"""The client id to authenticate"""
pass
@abstractmethod
def get_client_secret(self) -> str:
"""The client secret to authenticate"""
pass
@abstractmethod
def get_refresh_token(self) -> Optional[str]:
"""The token used to refresh the access token when it expires"""
pass
@abstractmethod
def get_scopes(self) -> List[str]:
"""List of requested scopes"""
pass
@abstractmethod
def get_refresh_request_body(self) -> Mapping[str, Any]:
"""Returns the request body to set on the refresh request"""
def real_bar(self) -> int:
"""No abstract nonsense"""
pass
'''
original_code = '''import requests
from abc import abstractmethod
from requests.auth import AuthBase
from typing import Any, List, Mapping, MutableMapping, Optional, Tuple
class AbstractOauth2Authenticator(AuthBase):
def __init__(
self,
refresh_token_error_status_codes: Tuple[int, ...] = (),
refresh_token_error_key: str = "",
refresh_token_error_values: Tuple[str, ...] = (),
) -> None:
"""
If all of refresh_token_error_status_codes, refresh_token_error_key, and refresh_token_error_values are set,
then http errors with such params will be wrapped in AirbyteTracedException.
"""
self._refresh_token_error_status_codes = refresh_token_error_status_codes
self._refresh_token_error_key = refresh_token_error_key
self._refresh_token_error_values = refresh_token_error_values
def __call__(self, request: requests.PreparedRequest) -> requests.PreparedRequest:
"""Attach the HTTP headers required to authenticate on the HTTP request"""
request.headers.update(self.get_auth_header())
return request
def build_refresh_request_body(self) -> Mapping[str, Any]:
"""
Returns the request body to set on the refresh request
Override to define additional parameters
"""
payload: MutableMapping[str, Any] = {
"grant_type": self.get_grant_type(),
"client_id": self.get_client_id(),
"client_secret": self.get_client_secret(),
"refresh_token": self.get_refresh_token(),
}
if self.get_scopes():
payload["scopes"] = self.get_scopes()
if self.get_refresh_request_body():
for key, val in self.get_refresh_request_body().items():
# We defer to existing oauth constructs over custom configured fields
if key not in payload:
payload[key] = val
original_code = '''class Fu():
def foo(self) -> dict[str, str]:
payload: dict[str, str] = {"bar": self.bar(), "real_bar": str(self.real_bar())}
return payload
@abstractmethod
def get_grant_type(self) -> str:
"""Returns grant_type specified for requesting access_token"""
@abstractmethod
def get_client_id(self) -> str:
"""The client id to authenticate"""
@abstractmethod
def get_client_secret(self) -> str:
"""The client secret to authenticate"""
@abstractmethod
def get_refresh_token(self) -> Optional[str]:
"""The token used to refresh the access token when it expires"""
@abstractmethod
def get_scopes(self) -> List[str]:
"""List of requested scopes"""
@abstractmethod
def get_refresh_request_body(self) -> Mapping[str, Any]:
"""Returns the request body to set on the refresh request"""
def real_bar(self) -> int:
"""No abstract nonsense"""
return 0
'''
expected_code = '''class Fu():
def foo(self) -> dict[str, str]:
payload: dict[str, str] = {"bar": self.bar(), "real_bar": str(self.real_bar() + 1)}
return payload
def real_bar(self) -> int:
"""No abstract nonsense"""
return 0
'''
function_name: str = "AbstractOauth2Authenticator.build_refresh_request_body"
# TODO : Fill the right values here
preexisting_functions: list[str] = ["__init__", "__call__"]
contextual_functions: set[tuple[str, str]] = {
("AbstractOauth2Authenticator", "__init__"),
("AbstractOauth2Authenticator", "__call__"),
}
new_code: str = replace_functions_and_add_imports(
function_name: str = "Fu.foo"
parents = [FunctionParent("Fu", "ClassDef")]
preexisting_functions: list[tuple[str, list[FunctionParent]]] = [("foo", parents), ("real_bar", parents)]
contextual_functions: set[tuple[str, str]] = set()
new_code: str = replace_functions_in_file(
source_code=original_code,
function_names=[function_name],
original_function_names=[function_name],
optimized_code=optim_code,
file_path_of_module_with_function_to_optimize=str(Path(__file__).resolve()),
module_abspath=str(Path(__file__).resolve()),
preexisting_functions=preexisting_functions,
contextual_functions=contextual_functions,
project_root_path=str(Path(__file__).resolve().parent.resolve()),
)
assert new_code == expected_code