tests and fix global assignments imports
This commit is contained in:
parent
2acda6a411
commit
2e394f6b8f
6 changed files with 73 additions and 122 deletions
|
|
@ -1,77 +1,25 @@
|
|||
from os import getenv
|
||||
from typing import Optional
|
||||
|
||||
from attrs import define, evolve, field
|
||||
from attrs import define, evolve
|
||||
|
||||
from code_to_optimize.code_directories.circular_deps.constants import DEFAULT_API_URL, DEFAULT_APP_URL
|
||||
from constants import DEFAULT_API_URL, DEFAULT_APP_URL
|
||||
|
||||
|
||||
@define
|
||||
class GalileoApiClient():
|
||||
"""A Client which has been authenticated for use on secured endpoints
|
||||
|
||||
The following are accepted as keyword arguments and will be used to construct httpx Clients internally:
|
||||
|
||||
``base_url``: The base URL for the API, all requests are made to a relative path to this URL
|
||||
This can also be set via the GALILEO_CONSOLE_URL environment variable
|
||||
|
||||
``api_key``: The API key to be sent with every request
|
||||
This can also be set via the GALILEO_API_KEY environment variable
|
||||
|
||||
``cookies``: A dictionary of cookies to be sent with every request
|
||||
|
||||
``headers``: A dictionary of headers to be sent with every request
|
||||
|
||||
``timeout``: The maximum amount of a time a request can take. API functions will raise
|
||||
httpx.TimeoutException if this is exceeded.
|
||||
|
||||
``verify_ssl``: Whether or not to verify the SSL certificate of the API server. This should be True in production,
|
||||
but can be set to False for testing purposes.
|
||||
|
||||
``follow_redirects``: Whether or not to follow redirects. Default value is False.
|
||||
|
||||
``httpx_args``: A dictionary of additional arguments to be passed to the ``httpx.Client`` and ``httpx.AsyncClient`` constructor.
|
||||
|
||||
Attributes:
|
||||
raise_on_unexpected_status: Whether or not to raise an errors.UnexpectedStatus if the API returns a
|
||||
status code that was not documented in the source OpenAPI document. Can also be provided as a keyword
|
||||
argument to the constructor.
|
||||
token: The token to use for authentication
|
||||
prefix: The prefix to use for the Authorization header
|
||||
auth_header_name: The name of the Authorization header
|
||||
"""
|
||||
|
||||
_base_url: Optional[str] = field(factory=lambda: GalileoApiClient.get_api_url(), kw_only=True, alias="base_url")
|
||||
_api_key: Optional[str] = field(factory=lambda: getenv("GALILEO_API_KEY", None), kw_only=True, alias="api_key")
|
||||
token: Optional[str] = None
|
||||
|
||||
api_key_header_name: str = "Galileo-API-Key"
|
||||
class ApiClient():
|
||||
api_key_header_name: str = "API-Key"
|
||||
client_type_header_name: str = "client-type"
|
||||
client_type_header_value: str = "sdk-python"
|
||||
|
||||
@staticmethod
|
||||
def get_console_url() -> str:
|
||||
console_url = getenv("GALILEO_CONSOLE_URL", DEFAULT_API_URL)
|
||||
console_url = getenv("CONSOLE_URL", DEFAULT_API_URL)
|
||||
if DEFAULT_API_URL == console_url:
|
||||
return DEFAULT_APP_URL
|
||||
|
||||
return console_url
|
||||
|
||||
def with_api_key(self, api_key: str) -> "GalileoApiClient":
|
||||
def with_api_key(self, api_key: str) -> "ApiClient": # ---> here is the problem with circular dependency, this makes libcst thinks that ApiClient needs an import despite it's already in the same file.
|
||||
"""Get a new client matching this one with a new API key"""
|
||||
if self._client is not None:
|
||||
self._client.headers.update({self.api_key_header_name: api_key})
|
||||
if self._async_client is not None:
|
||||
self._async_client.headers.update({self.api_key_header_name: api_key})
|
||||
return evolve(self, api_key=api_key)
|
||||
|
||||
@staticmethod
|
||||
def get_api_url(base_url: Optional[str] = None) -> str:
|
||||
api_url = base_url or getenv("GALILEO_CONSOLE_URL", DEFAULT_API_URL)
|
||||
if api_url is None:
|
||||
raise ValueError("base_url or GALILEO_CONSOLE_URL must be set")
|
||||
if any(map(api_url.__contains__, ["localhost", "127.0.0.1"])):
|
||||
api_url = "http://localhost:8088"
|
||||
else:
|
||||
api_url = api_url.replace("app.galileo.ai", "api.galileo.ai").replace("console", "api")
|
||||
return api_url
|
||||
|
|
|
|||
37
code_to_optimize/code_directories/circular_deps/optimized.py
Normal file
37
code_to_optimize/code_directories/circular_deps/optimized.py
Normal file
|
|
@ -0,0 +1,37 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import urllib.parse
|
||||
from os import getenv
|
||||
|
||||
from attrs import define
|
||||
from api_client import ApiClient
|
||||
from constants import DEFAULT_API_URL, DEFAULT_APP_URL
|
||||
|
||||
|
||||
@define
|
||||
class ApiClient():
|
||||
|
||||
@staticmethod
|
||||
def get_console_url() -> str:
|
||||
# Cache env lookup for speed
|
||||
console_url = getenv("CONSOLE_URL")
|
||||
if not console_url or console_url == DEFAULT_API_URL:
|
||||
return DEFAULT_APP_URL
|
||||
return console_url
|
||||
|
||||
# Pre-parse netlocs that are checked frequently to avoid parsing repeatedly
|
||||
_DEFAULT_APP_URL_NETLOC = urllib.parse.urlparse(DEFAULT_APP_URL).netloc
|
||||
_DEFAULT_API_URL_NETLOC = urllib.parse.urlparse(DEFAULT_API_URL).netloc
|
||||
|
||||
def get_dest_url(url: str) -> str:
|
||||
destination = url if url else ApiClient.get_console_url()
|
||||
# Replace only if 'console.' is at the beginning to avoid partial matches
|
||||
if destination.startswith("console."):
|
||||
destination = "api." + destination[len("console."):]
|
||||
else:
|
||||
destination = destination.replace("console.", "api.", 1)
|
||||
|
||||
parsed_url = urllib.parse.urlparse(destination)
|
||||
if parsed_url.netloc == _DEFAULT_APP_URL_NETLOC or parsed_url.netloc == _DEFAULT_API_URL_NETLOC:
|
||||
return f"{DEFAULT_APP_URL}api/traces"
|
||||
return f"{parsed_url.scheme}://{parsed_url.netloc}/traces"
|
||||
|
|
@ -0,0 +1,7 @@
|
|||
[tool.codeflash]
|
||||
# All paths are relative to this pyproject.toml's directory.
|
||||
module-root = "."
|
||||
tests-root = "tests"
|
||||
test-framework = "pytest"
|
||||
ignore-paths = []
|
||||
formatter-cmds = ["black $file"]
|
||||
|
|
@ -325,36 +325,30 @@ def add_needed_imports_from_module(
|
|||
)
|
||||
)
|
||||
cst.parse_module(src_module_code).visit(gatherer)
|
||||
scheduled_unused_imports = []
|
||||
try:
|
||||
for mod in gatherer.module_imports:
|
||||
AddImportsVisitor.add_needed_import(dst_context, mod)
|
||||
scheduled_unused_imports.append((mod, "", ""))
|
||||
RemoveImportsVisitor.remove_unused_import(dst_context, mod)
|
||||
for mod, obj_seq in gatherer.object_mapping.items():
|
||||
logger.debug(f"dst_context.full_module_name: {dst_context.full_module_name}")
|
||||
logger.debug(f"mod: {mod}")
|
||||
logger.debug(f"obj_seq: {obj_seq}")
|
||||
logger.debug(f"helper_functions_fqn: {helper_functions_fqn}")
|
||||
for obj in obj_seq:
|
||||
if (
|
||||
f"{mod}.{obj}" in helper_functions_fqn
|
||||
or dst_context.full_module_name == mod # avoid circular imports
|
||||
f"{mod}.{obj}" in helper_functions_fqn or dst_context.full_module_name == mod # avoid circular deps
|
||||
):
|
||||
continue # Skip adding imports for helper functions already in the context
|
||||
AddImportsVisitor.add_needed_import(dst_context, mod, obj)
|
||||
scheduled_unused_imports.append((mod, obj, ""))
|
||||
RemoveImportsVisitor.remove_unused_import(dst_context, mod, obj)
|
||||
except Exception as e:
|
||||
logger.exception(f"Error adding imports to destination module code: {e}")
|
||||
return dst_module_code
|
||||
for mod, asname in gatherer.module_aliases.items():
|
||||
AddImportsVisitor.add_needed_import(dst_context, mod, asname=asname)
|
||||
scheduled_unused_imports.append((mod, "", asname))
|
||||
RemoveImportsVisitor.remove_unused_import(dst_context, mod, asname=asname)
|
||||
for mod, alias_pairs in gatherer.alias_mapping.items():
|
||||
for alias_pair in alias_pairs:
|
||||
if f"{mod}.{alias_pair[0]}" in helper_functions_fqn:
|
||||
continue
|
||||
AddImportsVisitor.add_needed_import(dst_context, mod, alias_pair[0], asname=alias_pair[1])
|
||||
scheduled_unused_imports.append((mod, alias_pair[0], alias_pair[1]))
|
||||
RemoveImportsVisitor.remove_unused_import(dst_context, mod, alias_pair[0], asname=alias_pair[1])
|
||||
|
||||
try:
|
||||
parsed_module = cst.parse_module(dst_module_code)
|
||||
|
|
@ -363,9 +357,6 @@ def add_needed_imports_from_module(
|
|||
return dst_module_code # Return the original code if there's a syntax error
|
||||
try:
|
||||
transformed_module = AddImportsVisitor(dst_context).transform_module(parsed_module)
|
||||
for _import in scheduled_unused_imports:
|
||||
(_module, _obj, _alias) = _import
|
||||
RemoveImportsVisitor.remove_unused_import(dst_context, module=_module, obj=_obj, asname=_alias)
|
||||
transformed_module = RemoveImportsVisitor(dst_context).transform_module(transformed_module)
|
||||
return transformed_module.code.lstrip("\n")
|
||||
except Exception as e:
|
||||
|
|
|
|||
|
|
@ -397,13 +397,6 @@ def replace_functions_and_add_imports(
|
|||
preexisting_objects: set[tuple[str, tuple[FunctionParent, ...]]],
|
||||
project_root_path: Path,
|
||||
) -> str:
|
||||
logger.debug("start from here,...")
|
||||
logger.debug(f"source_code: {source_code}")
|
||||
logger.debug(f"function_names: {function_names}")
|
||||
logger.debug(f"optimized_code: {optimized_code}")
|
||||
logger.debug(f"module_abspath: {module_abspath}")
|
||||
logger.debug(f"preexisting_objects: {preexisting_objects}")
|
||||
logger.debug(f"project_root_path: {project_root_path}")
|
||||
return add_needed_imports_from_module(
|
||||
optimized_code,
|
||||
replace_functions_in_file(source_code, function_names, optimized_code, preexisting_objects),
|
||||
|
|
@ -422,12 +415,16 @@ def replace_function_definitions_in_module(
|
|||
) -> bool:
|
||||
source_code: str = module_abspath.read_text(encoding="utf8")
|
||||
new_code: str = replace_functions_and_add_imports(
|
||||
source_code, function_names, optimized_code, module_abspath, preexisting_objects, project_root_path
|
||||
add_global_assignments(optimized_code, source_code),
|
||||
function_names,
|
||||
optimized_code,
|
||||
module_abspath,
|
||||
preexisting_objects,
|
||||
project_root_path,
|
||||
)
|
||||
if is_zero_diff(source_code, new_code):
|
||||
return False
|
||||
code_with_global_assignments = add_global_assignments(optimized_code, new_code)
|
||||
module_abspath.write_text(code_with_global_assignments, encoding="utf8")
|
||||
module_abspath.write_text(new_code, encoding="utf8")
|
||||
return True
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -11,6 +11,8 @@ from codeflash.context.code_context_extractor import get_code_optimization_conte
|
|||
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
|
||||
from codeflash.models.models import FunctionParent
|
||||
from codeflash.optimization.optimizer import Optimizer
|
||||
from codeflash.code_utils.code_replacer import replace_functions_and_add_imports
|
||||
from codeflash.code_utils.code_extractor import add_global_assignments
|
||||
|
||||
|
||||
class HelperClass:
|
||||
|
|
@ -2436,51 +2438,20 @@ class SimpleClass:
|
|||
assert "return 42" in code_content
|
||||
|
||||
|
||||
from codeflash.code_utils.code_replacer import replace_functions_and_add_imports
|
||||
|
||||
def test_replace_functions_and_add_imports():
|
||||
path_to_root = Path(__file__).resolve().parent.parent / "code_to_optimize" / "code_directories" / "circular_deps"
|
||||
optimized_code = '''from __future__ import annotations
|
||||
|
||||
import urllib.parse
|
||||
from os import getenv
|
||||
|
||||
from attrs import define
|
||||
from code_to_optimize.code_directories.circular_deps.constants import DEFAULT_API_URL, DEFAULT_APP_URL
|
||||
|
||||
# Precompute constant netlocs for set membership test
|
||||
_DEFAULT_APP_NETLOC = urllib.parse.urlparse(DEFAULT_APP_URL).netloc
|
||||
_DEFAULT_API_NETLOC = urllib.parse.urlparse(DEFAULT_API_URL).netloc
|
||||
_NETLOC_SET = {_DEFAULT_APP_NETLOC, _DEFAULT_API_NETLOC}
|
||||
|
||||
@define
|
||||
class GalileoApiClient():
|
||||
|
||||
@staticmethod
|
||||
def get_console_url() -> str:
|
||||
# Return DEFAULT_APP_URL if the env var is not set or set to DEFAULT_API_URL
|
||||
console_url = getenv("GALILEO_CONSOLE_URL", DEFAULT_API_URL)
|
||||
if console_url == DEFAULT_API_URL:
|
||||
return DEFAULT_APP_URL
|
||||
return console_url
|
||||
|
||||
def _set_destination(console_url: str) -> str:
|
||||
"""
|
||||
Parse the console_url and return the destination for the OpenTelemetry traces.
|
||||
"""
|
||||
destination = (console_url or GalileoApiClient.get_console_url()).replace("console.", "api.")
|
||||
parsed_url = urllib.parse.urlparse(destination)
|
||||
if parsed_url.netloc in _NETLOC_SET:
|
||||
return f"{DEFAULT_APP_URL}api/galileo/otel/traces"
|
||||
return f"{parsed_url.scheme}://{parsed_url.netloc}/otel/traces"'''
|
||||
file_abs_path = path_to_root / "api_client.py"
|
||||
optimized_code = Path(path_to_root / "optimized.py").read_text(encoding="utf-8")
|
||||
content = Path(file_abs_path).read_text(encoding="utf-8")
|
||||
new_code = replace_functions_and_add_imports(
|
||||
source_code= content,
|
||||
function_names= ["GalileoApiClient.get_console_url"],
|
||||
source_code= add_global_assignments(optimized_code, content),
|
||||
function_names= ["ApiClient.get_console_url"],
|
||||
optimized_code= optimized_code,
|
||||
module_abspath= file_abs_path,
|
||||
preexisting_objects= {('GalileoApiClient', ()), ('_set_destination', ()), ('get_console_url', (FunctionParent(name='GalileoApiClient', type='ClassDef'),))},
|
||||
module_abspath= Path(file_abs_path),
|
||||
preexisting_objects= {('ApiClient', ()), ('get_console_url', (FunctionParent(name='ApiClient', type='ClassDef'),))},
|
||||
project_root_path= Path(path_to_root),
|
||||
)
|
||||
print(new_code)
|
||||
assert 1 == 1
|
||||
assert "import ApiClient" not in new_code, "Error: Circular dependency found"
|
||||
|
||||
assert "import urllib.parse" in new_code, "Make sure imports for optimization global assignments exist"
|
||||
|
|
|
|||
Loading…
Reference in a new issue