Add tests and fix imports for global assignments

This commit is contained in:
mohammed 2025-06-29 00:23:36 +03:00
parent 2acda6a411
commit 2e394f6b8f
6 changed files with 73 additions and 122 deletions

View file

@ -1,77 +1,25 @@
from os import getenv
from typing import Optional
from attrs import define, evolve, field
from attrs import define, evolve
from code_to_optimize.code_directories.circular_deps.constants import DEFAULT_API_URL, DEFAULT_APP_URL
from constants import DEFAULT_API_URL, DEFAULT_APP_URL
@define
class GalileoApiClient():
"""A Client which has been authenticated for use on secured endpoints
The following are accepted as keyword arguments and will be used to construct httpx Clients internally:
``base_url``: The base URL for the API, all requests are made to a relative path to this URL
This can also be set via the GALILEO_CONSOLE_URL environment variable
``api_key``: The API key to be sent with every request
This can also be set via the GALILEO_API_KEY environment variable
``cookies``: A dictionary of cookies to be sent with every request
``headers``: A dictionary of headers to be sent with every request
``timeout``: The maximum amount of a time a request can take. API functions will raise
httpx.TimeoutException if this is exceeded.
``verify_ssl``: Whether or not to verify the SSL certificate of the API server. This should be True in production,
but can be set to False for testing purposes.
``follow_redirects``: Whether or not to follow redirects. Default value is False.
``httpx_args``: A dictionary of additional arguments to be passed to the ``httpx.Client`` and ``httpx.AsyncClient`` constructor.
Attributes:
raise_on_unexpected_status: Whether or not to raise an errors.UnexpectedStatus if the API returns a
status code that was not documented in the source OpenAPI document. Can also be provided as a keyword
argument to the constructor.
token: The token to use for authentication
prefix: The prefix to use for the Authorization header
auth_header_name: The name of the Authorization header
"""
_base_url: Optional[str] = field(factory=lambda: GalileoApiClient.get_api_url(), kw_only=True, alias="base_url")
_api_key: Optional[str] = field(factory=lambda: getenv("GALILEO_API_KEY", None), kw_only=True, alias="api_key")
token: Optional[str] = None
api_key_header_name: str = "Galileo-API-Key"
class ApiClient():
api_key_header_name: str = "API-Key"
client_type_header_name: str = "client-type"
client_type_header_value: str = "sdk-python"
@staticmethod
def get_console_url() -> str:
console_url = getenv("GALILEO_CONSOLE_URL", DEFAULT_API_URL)
console_url = getenv("CONSOLE_URL", DEFAULT_API_URL)
if DEFAULT_API_URL == console_url:
return DEFAULT_APP_URL
return console_url
def with_api_key(self, api_key: str) -> "GalileoApiClient":
def with_api_key(self, api_key: str) -> "ApiClient": # NOTE: this quoted return annotation is the circular-dependency trigger: it makes libcst think ApiClient needs an import even though the class is already defined in this same file.
"""Get a new client matching this one with a new API key.

If ``_client`` or ``_async_client`` is already set, its API-key header is
updated in place first; the returned client is an ``evolve`` copy of this
one created with the new ``api_key``.
"""
if self._client is not None:
self._client.headers.update({self.api_key_header_name: api_key})
if self._async_client is not None:
self._async_client.headers.update({self.api_key_header_name: api_key})
return evolve(self, api_key=api_key)
@staticmethod
def get_api_url(base_url: Optional[str] = None) -> str:
api_url = base_url or getenv("GALILEO_CONSOLE_URL", DEFAULT_API_URL)
if api_url is None:
raise ValueError("base_url or GALILEO_CONSOLE_URL must be set")
if any(map(api_url.__contains__, ["localhost", "127.0.0.1"])):
api_url = "http://localhost:8088"
else:
api_url = api_url.replace("app.galileo.ai", "api.galileo.ai").replace("console", "api")
return api_url

View file

@ -0,0 +1,37 @@
from __future__ import annotations
import urllib.parse
from os import getenv
from attrs import define
from api_client import ApiClient
from constants import DEFAULT_API_URL, DEFAULT_APP_URL
@define
class ApiClient():
    """Minimal client that exposes the console URL lookup."""

    @staticmethod
    def get_console_url() -> str:
        """Return the console URL read from the ``CONSOLE_URL`` environment variable.

        ``DEFAULT_APP_URL`` is returned instead when the variable is unset,
        empty, or equal to ``DEFAULT_API_URL``.
        """
        configured_url = getenv("CONSOLE_URL")
        # Only a non-empty value that differs from the default API URL is used as-is.
        if configured_url and configured_url != DEFAULT_API_URL:
            return configured_url
        return DEFAULT_APP_URL
# Netloc components of the default app and API URLs, parsed once at import time
# so that get_dest_url can compare against them without re-parsing the defaults
# on every call.
_DEFAULT_APP_URL_NETLOC = urllib.parse.urlparse(DEFAULT_APP_URL).netloc
_DEFAULT_API_URL_NETLOC = urllib.parse.urlparse(DEFAULT_API_URL).netloc
def get_dest_url(url: str) -> str:
    """Return the traces endpoint URL that corresponds to a console URL.

    Args:
        url: Console URL to translate. When empty, the value of
            ``ApiClient.get_console_url()`` is used instead.

    Returns:
        ``f"{DEFAULT_APP_URL}api/traces"`` when the translated URL's netloc is
        the default app or default API netloc; otherwise
        ``<scheme>://<netloc>/traces`` built from the translated URL.
    """
    destination = url or ApiClient.get_console_url()
    # Swap only the first "console." for "api.". This single call also covers a
    # string that starts with "console.", so no separate prefix branch is needed.
    destination = destination.replace("console.", "api.", 1)
    # NOTE(review): a URL without a scheme parses to an empty netloc, so callers
    # are presumably expected to pass absolute URLs -- confirm.
    parsed_url = urllib.parse.urlparse(destination)
    if parsed_url.netloc in (_DEFAULT_APP_URL_NETLOC, _DEFAULT_API_URL_NETLOC):
        return f"{DEFAULT_APP_URL}api/traces"
    return f"{parsed_url.scheme}://{parsed_url.netloc}/traces"

View file

@ -0,0 +1,7 @@
[tool.codeflash]
# All paths are relative to this pyproject.toml's directory.
module-root = "."
tests-root = "tests"
test-framework = "pytest"
ignore-paths = []
formatter-cmds = ["black $file"]

View file

@ -325,36 +325,30 @@ def add_needed_imports_from_module(
)
)
cst.parse_module(src_module_code).visit(gatherer)
scheduled_unused_imports = []
try:
for mod in gatherer.module_imports:
AddImportsVisitor.add_needed_import(dst_context, mod)
scheduled_unused_imports.append((mod, "", ""))
RemoveImportsVisitor.remove_unused_import(dst_context, mod)
for mod, obj_seq in gatherer.object_mapping.items():
logger.debug(f"dst_context.full_module_name: {dst_context.full_module_name}")
logger.debug(f"mod: {mod}")
logger.debug(f"obj_seq: {obj_seq}")
logger.debug(f"helper_functions_fqn: {helper_functions_fqn}")
for obj in obj_seq:
if (
f"{mod}.{obj}" in helper_functions_fqn
or dst_context.full_module_name == mod # avoid circular imports
f"{mod}.{obj}" in helper_functions_fqn or dst_context.full_module_name == mod # avoid circular deps
):
continue # Skip adding imports for helper functions already in the context
AddImportsVisitor.add_needed_import(dst_context, mod, obj)
scheduled_unused_imports.append((mod, obj, ""))
RemoveImportsVisitor.remove_unused_import(dst_context, mod, obj)
except Exception as e:
logger.exception(f"Error adding imports to destination module code: {e}")
return dst_module_code
for mod, asname in gatherer.module_aliases.items():
AddImportsVisitor.add_needed_import(dst_context, mod, asname=asname)
scheduled_unused_imports.append((mod, "", asname))
RemoveImportsVisitor.remove_unused_import(dst_context, mod, asname=asname)
for mod, alias_pairs in gatherer.alias_mapping.items():
for alias_pair in alias_pairs:
if f"{mod}.{alias_pair[0]}" in helper_functions_fqn:
continue
AddImportsVisitor.add_needed_import(dst_context, mod, alias_pair[0], asname=alias_pair[1])
scheduled_unused_imports.append((mod, alias_pair[0], alias_pair[1]))
RemoveImportsVisitor.remove_unused_import(dst_context, mod, alias_pair[0], asname=alias_pair[1])
try:
parsed_module = cst.parse_module(dst_module_code)
@ -363,9 +357,6 @@ def add_needed_imports_from_module(
return dst_module_code # Return the original code if there's a syntax error
try:
transformed_module = AddImportsVisitor(dst_context).transform_module(parsed_module)
for _import in scheduled_unused_imports:
(_module, _obj, _alias) = _import
RemoveImportsVisitor.remove_unused_import(dst_context, module=_module, obj=_obj, asname=_alias)
transformed_module = RemoveImportsVisitor(dst_context).transform_module(transformed_module)
return transformed_module.code.lstrip("\n")
except Exception as e:

View file

@ -397,13 +397,6 @@ def replace_functions_and_add_imports(
preexisting_objects: set[tuple[str, tuple[FunctionParent, ...]]],
project_root_path: Path,
) -> str:
logger.debug("start from here,...")
logger.debug(f"source_code: {source_code}")
logger.debug(f"function_names: {function_names}")
logger.debug(f"optimized_code: {optimized_code}")
logger.debug(f"module_abspath: {module_abspath}")
logger.debug(f"preexisting_objects: {preexisting_objects}")
logger.debug(f"project_root_path: {project_root_path}")
return add_needed_imports_from_module(
optimized_code,
replace_functions_in_file(source_code, function_names, optimized_code, preexisting_objects),
@ -422,12 +415,16 @@ def replace_function_definitions_in_module(
) -> bool:
source_code: str = module_abspath.read_text(encoding="utf8")
new_code: str = replace_functions_and_add_imports(
source_code, function_names, optimized_code, module_abspath, preexisting_objects, project_root_path
add_global_assignments(optimized_code, source_code),
function_names,
optimized_code,
module_abspath,
preexisting_objects,
project_root_path,
)
if is_zero_diff(source_code, new_code):
return False
code_with_global_assignments = add_global_assignments(optimized_code, new_code)
module_abspath.write_text(code_with_global_assignments, encoding="utf8")
module_abspath.write_text(new_code, encoding="utf8")
return True

View file

@ -11,6 +11,8 @@ from codeflash.context.code_context_extractor import get_code_optimization_conte
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.models.models import FunctionParent
from codeflash.optimization.optimizer import Optimizer
from codeflash.code_utils.code_replacer import replace_functions_and_add_imports
from codeflash.code_utils.code_extractor import add_global_assignments
class HelperClass:
@ -2436,51 +2438,20 @@ class SimpleClass:
assert "return 42" in code_content
from codeflash.code_utils.code_replacer import replace_functions_and_add_imports
def test_replace_functions_and_add_imports():
path_to_root = Path(__file__).resolve().parent.parent / "code_to_optimize" / "code_directories" / "circular_deps"
optimized_code = '''from __future__ import annotations
import urllib.parse
from os import getenv
from attrs import define
from code_to_optimize.code_directories.circular_deps.constants import DEFAULT_API_URL, DEFAULT_APP_URL
# Precompute constant netlocs for set membership test
_DEFAULT_APP_NETLOC = urllib.parse.urlparse(DEFAULT_APP_URL).netloc
_DEFAULT_API_NETLOC = urllib.parse.urlparse(DEFAULT_API_URL).netloc
_NETLOC_SET = {_DEFAULT_APP_NETLOC, _DEFAULT_API_NETLOC}
@define
class GalileoApiClient():
@staticmethod
def get_console_url() -> str:
# Return DEFAULT_APP_URL if the env var is not set or set to DEFAULT_API_URL
console_url = getenv("GALILEO_CONSOLE_URL", DEFAULT_API_URL)
if console_url == DEFAULT_API_URL:
return DEFAULT_APP_URL
return console_url
def _set_destination(console_url: str) -> str:
"""
Parse the console_url and return the destination for the OpenTelemetry traces.
"""
destination = (console_url or GalileoApiClient.get_console_url()).replace("console.", "api.")
parsed_url = urllib.parse.urlparse(destination)
if parsed_url.netloc in _NETLOC_SET:
return f"{DEFAULT_APP_URL}api/galileo/otel/traces"
return f"{parsed_url.scheme}://{parsed_url.netloc}/otel/traces"'''
file_abs_path = path_to_root / "api_client.py"
optimized_code = Path(path_to_root / "optimized.py").read_text(encoding="utf-8")
content = Path(file_abs_path).read_text(encoding="utf-8")
new_code = replace_functions_and_add_imports(
source_code= content,
function_names= ["GalileoApiClient.get_console_url"],
source_code= add_global_assignments(optimized_code, content),
function_names= ["ApiClient.get_console_url"],
optimized_code= optimized_code,
module_abspath= file_abs_path,
preexisting_objects= {('GalileoApiClient', ()), ('_set_destination', ()), ('get_console_url', (FunctionParent(name='GalileoApiClient', type='ClassDef'),))},
module_abspath= Path(file_abs_path),
preexisting_objects= {('ApiClient', ()), ('get_console_url', (FunctionParent(name='ApiClient', type='ClassDef'),))},
project_root_path= Path(path_to_root),
)
print(new_code)
assert 1 == 1
assert "import ApiClient" not in new_code, "Error: Circular dependency found"
assert "import urllib.parse" in new_code, "Make sure imports for optimization global assignments exist"