ruff check fixes

This commit is contained in:
Kevin Turcios 2025-08-07 14:35:21 -07:00
parent f72ed92ce4
commit 61b030e47b
16 changed files with 61 additions and 64 deletions

View file

@@ -3,36 +3,39 @@ from __future__ import annotations
import logging
import os
import sys
from collections.abc import Callable
from typing import TYPE_CHECKING
from dotenv import load_dotenv
from httpx import AsyncClient
from openai import AsyncOpenAI
from openai.lib.azure import AsyncAzureOpenAI
if TYPE_CHECKING:
from collections.abc import Callable
IS_PRODUCTION = os.environ.get("ENVIRONMENT", default="") == "PRODUCTION"
LOGGING_FORMAT = "[%(levelname)s] %(message)s"
def load_env():
def load_env() -> None:
if not IS_PRODUCTION:
load_dotenv()
def set_logging_level():
def set_logging_level() -> None:
if IS_PRODUCTION:
logging.basicConfig(level=logging.INFO, format=LOGGING_FORMAT, stream=sys.stdout)
else:
logging.basicConfig(level=logging.DEBUG, format=LOGGING_FORMAT, stream=sys.stdout)
def debug_log_sensitive_data(message: str):
def debug_log_sensitive_data(message: str) -> None:
if not IS_PRODUCTION:
logging.debug(message)
def debug_log_sensitive_data_from_callable(message: Callable[[], str | None]):
def debug_log_sensitive_data_from_callable(message: Callable[[], str | None]) -> None:
if not IS_PRODUCTION:
logging.debug(message())

View file

@@ -7,7 +7,7 @@ from authapp.auth import AuthBearer
@async_only_middleware
class AuthMiddleware:
def __init__(self, get_response):
def __init__(self, get_response) -> None:
self.get_response = get_response
self.auth_bearer = AuthBearer()

View file

@@ -5,7 +5,7 @@ from django.utils.decorators import async_only_middleware
@async_only_middleware
class HealthCheckMiddleware:
def __init__(self, get_response):
def __init__(self, get_response) -> None:
self.get_response = get_response
if iscoroutinefunction(self.get_response):
print("HealthcheckMiddleware is async.")

View file

@@ -15,7 +15,7 @@ RATE_LIMIT_MAX = int(os.getenv("RATE_LIMIT_MAX", "40"))
# TODO: Implement a distributed caching solution (e.g., Redis) for multi-server environments.
@async_only_middleware
class RateLimitMiddleware:
def __init__(self, get_response):
def __init__(self, get_response) -> None:
self.get_response = get_response
if iscoroutinefunction(self.get_response):
print("RateLimitMiddleware is async.")

View file

@@ -11,7 +11,7 @@ class FunctionParent:
type: str
@dataclass(frozen=True, config=dict(arbitrary_types_allowed=True))
@dataclass(frozen=True, config={"arbitrary_types_allowed": True})
class FunctionToOptimize:
function_name: str
file_path: str

View file

@@ -3,7 +3,7 @@ import pytest
from aiservice.common_utils import parse_python_version, validate_trace_id
def test_python_version():
def test_python_version() -> None:
assert parse_python_version("3.9.0") == (3, 9, 0)
assert parse_python_version("3.12.2") == (3, 12, 2)
with pytest.raises(ValueError):
@@ -16,7 +16,7 @@ def test_python_version():
parse_python_version("3.1231231129.23432412312312313212412131212122112")
def test_validate_trace_id():
def test_validate_trace_id() -> None:
assert validate_trace_id("f47ac10b-58cc-4372-a567-0e02b2c3d479")
assert validate_trace_id("f47ac10b-58cc-4372-a567-0e02b2c3EXP0")
assert validate_trace_id("f47ac10b-58cc-4372-a567-0e02b2c3EXP1")

View file

@@ -7,7 +7,7 @@ from authapp.auth_utils import hash_api_key, instance_for_api_key
from authapp.models import CFAPIKeys, Subscriptions
async def check_subscription_status(user_id, tier):
async def check_subscription_status(user_id, tier) -> bool:
"""Check if a user has a premium subscription that doesn't require feature logging.
Args:

View file

@@ -6,7 +6,7 @@ from .models import CFAPIKeys
def hash_api_key(api_key: str) -> str:
"""Hash the API key using SHA-384 and encode the result using base64url.
The hashing source of truth implementation is the hashApiKey function in js/common/token-functions.ts
The hashing source of truth implementation is the hashApiKey function in js/common/token-functions.ts.
Parameters
----------
@@ -35,7 +35,6 @@ async def instance_for_api_key(hashed_api_key: str) -> CFAPIKeys | None:
"""
try:
api_key_instance = await CFAPIKeys.objects.aget(key=hashed_api_key)
return api_key_instance
return await CFAPIKeys.objects.aget(key=hashed_api_key)
except CFAPIKeys.DoesNotExist:
return None

View file

@@ -1,9 +1,8 @@
from __future__ import annotations
import asyncio
import datetime as dt
import logging
from typing import Any
from typing import TYPE_CHECKING, Any
from asgiref.sync import sync_to_async
from django.db import IntegrityError, transaction
@@ -12,6 +11,9 @@ from ninja import NinjaAPI, Schema
from aiservice.common_utils import validate_trace_id
from log_features.models import OptimizationFeatures
if TYPE_CHECKING:
import datetime as dt
features_api = NinjaAPI(urls_namespace="log_features")
@@ -73,7 +75,7 @@ async def log_features(
@sync_to_async
@transaction.atomic
def db_operation():
def db_operation() -> None:
# Try to get existing record with a lock
f, created = OptimizationFeatures.objects.select_for_update().get_or_create(
trace_id=trace_id,
@@ -187,7 +189,7 @@ async def log_features(
except Exception as e:
logging.exception(f"Error logging features: {e}")
raise e
raise
class LoggingSchema(Schema):

View file

@@ -9,7 +9,7 @@ from aiservice.analytics.posthog import ph
from aiservice.env_specific import load_env, set_logging_level
def main():
def main() -> None:
"""Run administrative tasks."""
os.environ.setdefault("DJANGO_SETTINGS_MODULE", "aiservice.settings")
# Get the network name of the machine

View file

@@ -1,8 +1,8 @@
import ast
def find_init(node):
"""Recursively search an AST for a FunctionDef node named '__init__'
def find_init(node) -> bool:
"""Recursively search an AST for a FunctionDef node named '__init__'.
Args:
node: An AST node to search
@@ -16,8 +16,4 @@ def find_init(node):
return True
# Recursively check all child nodes
for child in ast.iter_child_nodes(node):
if find_init(child):
return True
return False
return any(find_init(child) for child in ast.iter_child_nodes(node))

View file

@@ -184,7 +184,7 @@ class SingleOptimizerContext(BaseOptimizerContext):
# MultiOptimizerContext #
##########################################################################################
class MultiOptimizerContext(BaseOptimizerContext):
def __init__(self, base_system_prompt: str, base_user_prompt: str, source_code: str):
def __init__(self, base_system_prompt: str, base_user_prompt: str, source_code: str) -> None:
super().__init__(base_system_prompt, base_user_prompt, source_code)
self.original_file_to_code = split_markdown_code(source_code)
@@ -193,12 +193,11 @@ class MultiOptimizerContext(BaseOptimizerContext):
def get_user_prompt(self, dependency_code: str, line_profiler_results: str | None) -> str:
has_init = any(find_init(ast.parse(code)) for code in self.original_file_to_code.values())
user_prompt = (
return (
f"{DEPS_CONTEXT_PROMPT.format(dependency_code=dependency_code) if dependency_code else ''}"
f"{self.base_user_prompt.format(source_code=self.source_code, init_optimization_prompt=INIT_OPTIMIZATION_PROMPT if has_init else '')}"
f"{LINE_PROF_CONTEXT_PROMPT.format(line_profiler_results=line_profiler_results) if line_profiler_results else ''}"
)
return user_prompt
def extract_code_and_explanation_from_llm_res(self, content: str) -> CodeStrAndExplanation:
code_blocks = MARKDOWN_BLOCK_PATTERN.findall(content)

View file

@@ -1,4 +1,5 @@
from dataclasses import dataclass
from typing import Never
import libcst as cst
from pydantic import ValidationError
@@ -37,7 +38,7 @@ class RefinementContextData:
# BaseRefinerContext #
##########################################################################################
class BaseRefinerContext:
def __init__(self, ctx_data: RefinementContextData, base_system_prompt: str, base_user_prompt: str):
def __init__(self, ctx_data: RefinementContextData, base_system_prompt: str, base_user_prompt: str) -> None:
self.data = ctx_data
self.base_system_prompt = base_system_prompt
self.base_user_prompt = base_user_prompt
@@ -89,7 +90,7 @@ class BaseRefinerContext:
except cst.ParserSyntaxError:
return False
def validate_python_module(self, code: str):
def validate_python_module(self, code: str) -> Never:
raise NotImplementedError
@@ -97,7 +98,7 @@ class BaseRefinerContext:
# SingleRefinerContext #
##########################################################################################
class SingleRefinerContext(BaseRefinerContext):
def __init__(self, ctx_data: RefinementContextData, base_system_prompt: str, base_user_prompt: str):
def __init__(self, ctx_data: RefinementContextData, base_system_prompt: str, base_user_prompt: str) -> None:
super().__init__(ctx_data, base_system_prompt, base_user_prompt)
def get_system_prompt(self) -> str:
@@ -134,19 +135,19 @@ class SingleRefinerContext(BaseRefinerContext):
def is_valid_refinement(self, new_refined_code: str) -> bool:
return super().is_valid_refinement(new_refined_code)
def validate_python_module(self, code: str):
def validate_python_module(self, code: str) -> None:
try:
cst_module = parse_module_to_cst(code)
CodeAndExplanation(cst_module, "")
except (ValueError, ValidationError, cst.ParserSyntaxError) as exc:
raise exc
except (ValueError, ValidationError, cst.ParserSyntaxError):
raise
##########################################################################################
# MultiRefinerContext #
##########################################################################################
class MultiRefinerContext(BaseRefinerContext):
def __init__(self, ctx_data: RefinementContextData, base_system_prompt: str, base_user_prompt: str):
def __init__(self, ctx_data: RefinementContextData, base_system_prompt: str, base_user_prompt: str) -> None:
super().__init__(ctx_data, base_system_prompt, base_user_prompt)
def get_system_prompt(self) -> str:
@@ -190,10 +191,10 @@ class MultiRefinerContext(BaseRefinerContext):
break
return valid
def validate_python_module(self, code: str):
def validate_python_module(self, code: str) -> None:
for code in split_markdown_code(self.data.optimized_source_code).values():
try:
cst_module = parse_module_to_cst(code)
CodeAndExplanation(cst_module, "")
except (ValueError, ValidationError, cst.ParserSyntaxError) as exc:
raise exc
except (ValueError, ValidationError, cst.ParserSyntaxError):
raise

View file

@@ -210,7 +210,7 @@ async def optimize(request, data: OptimizeSchema) -> tuple[int, OptimizeResponse
response = OptimizeResponseSchema(optimizations=optimization_response_items)
def log_response():
def log_response() -> None:
debug_log_sensitive_data(f"Response:\n{response.json()}")
for opt in response.optimizations:
debug_log_sensitive_data(f"Optimized source:\n{opt.source_code}")

View file

@@ -27,12 +27,11 @@ class RefinerTestCase:
def create_optimizer_context(code: str) -> BaseOptimizerContext:
ctx = BaseOptimizerContext.get_dynamic_context("", "", code)
return ctx
return BaseOptimizerContext.get_dynamic_context("", "", code)
def create_refiner_context(optimized_code: str) -> BaseRefinerContext:
ctx = BaseRefinerContext.get_dynamic_context(
return BaseRefinerContext.get_dynamic_context(
RefinementContextData(
original_source_code="",
original_line_profiler_results="",
@@ -47,7 +46,6 @@ def create_refiner_context(optimized_code: str) -> BaseRefinerContext:
"",
"",
)
return ctx
single_optimizer_test_cases: list[OptimizerTestCase] = [
@@ -566,7 +564,7 @@ def sorter_test_final10(arr):
]
def test_single_optimizer_context():
def test_single_optimizer_context() -> None:
for t in single_optimizer_test_cases:
ctx = create_optimizer_context("")
ctx.extract_code_and_explanation_from_llm_res(t.llm_response)
@@ -577,7 +575,7 @@ def test_single_optimizer_context():
assert ctx.is_valid_code()
def test_multi_optimizer_context():
def test_multi_optimizer_context() -> None:
for t in multi_optimizer_test_cases:
ctx = create_optimizer_context(t.original_code)
ctx.extract_code_and_explanation_from_llm_res(t.llm_response)
@@ -588,7 +586,7 @@ def test_multi_optimizer_context():
assert ctx.is_valid_code()
def test_single_refiner_context():
def test_single_refiner_context() -> None:
for t in single_refiner_test_cases:
ctx = create_refiner_context(t.optimized_code)
patches = ctx.extract_diff_patches_from_llm_res(t.llm_response)
@@ -597,7 +595,7 @@ def test_single_refiner_context():
assert ctx.is_valid_refinement(refined_code)
def test_multi_refiner_context():
def test_multi_refiner_context() -> None:
for t in multi_refiner_test_cases:
ctx = create_refiner_context(t.optimized_code)
patches = ctx.extract_diff_patches_from_llm_res(t.llm_response)
@@ -607,7 +605,7 @@ def test_multi_refiner_context():
assert ctx.is_valid_refinement(refined_code)
def test_split_and_group_markdown_code():
def test_split_and_group_markdown_code() -> None:
code = """```python:path/to/file1.py
file1=__name__
```
@@ -643,7 +641,7 @@ def sorter2(arr):
assert group_code(expected) == code
def test_system_prompt_for_multi_optimizer():
def test_system_prompt_for_multi_optimizer() -> None:
ctx = create_optimizer_context(
"```python:some_file.py\nprint(\"Hello world\")\n```\n```python:some_other_file.py\nprint('hi')\n````\n"
)

View file

@@ -5,7 +5,7 @@ from optimizer.models import CodeExplanationAndID
from optimizer.postprocess import DocstringTransformer, DocstringVisitor
def test_function_docstring_preservation():
def test_function_docstring_preservation() -> None:
# Original code with docstring
original_code = """
def example_function():
@@ -31,7 +31,7 @@ def example_function():
assert "This is a docstring for the example function" in transformed_tree.code
def test_function_docstring_in_both_functions():
def test_function_docstring_in_both_functions() -> None:
# Original code with docstring
original_code = """
def example_function():
@@ -58,7 +58,7 @@ def example_function():
assert "This is a docstring for the example function" in transformed_tree.code
def test_class_docstring_preservation():
def test_class_docstring_preservation() -> None:
# Original code with class docstring
original_code = """
class ExampleClass:
@@ -86,7 +86,7 @@ class ExampleClass:
assert "This is a docstring for the example class" in transformed_tree.code
def test_class_docstring_in_both_functions():
def test_class_docstring_in_both_functions() -> None:
# Original code with class docstring
original_code = """
class ExampleClass:
@@ -115,7 +115,7 @@ class ExampleClass:
assert "This is a docstring for the example class" in transformed_tree.code
def test_method_docstring_preservation():
def test_method_docstring_preservation() -> None:
# Original code with method docstring
original_code = """
class ExampleClass:
@@ -143,7 +143,7 @@ class ExampleClass:
assert "This is a docstring for the method" in transformed_tree.code
def test_fix_missing_docstring_pipeline_function():
def test_fix_missing_docstring_pipeline_function() -> None:
# Test the integration with the fix_missing_docstring pipeline function
from optimizer.postprocess import fix_missing_docstring
@@ -173,7 +173,7 @@ def example_function():
@pytest.mark.skip(
reason="This currently results in an exception and should be fixed. This test case reproduces the error"
)
def test_docstring_exception():
def test_docstring_exception() -> None:
# derived from optimizing https://github.com/pydantic/pydantic-ai/blob/39e28771e538a3a4af98222ca565ecfa402d9c08/pydantic_ai_slim/pydantic_ai/agent.py#L1717
original_code = """import dataclasses
import warnings
@@ -414,11 +414,11 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
]
# Apply the pipeline function
result = fix_missing_docstring(original_code, code_explanations)
fix_missing_docstring(original_code, code_explanations)
# TODO : This test case fails
def test_benjamin_button():
def test_benjamin_button() -> None:
original_code = """
def test_1():
\"\"\"useful docstring
@@ -443,14 +443,13 @@ def test_2():
# Apply transformer to optimized code
transformer = DocstringTransformer(original_visitor.original_docstrings)
optimized_tree = cst.parse_module(optimized_code)
transformed_tree = optimized_tree.visit(transformer)
optimized_tree.visit(transformer)
# Check if all docstrings were preserved
transformed_code = transformed_tree.code
assert True
def test_multiple_functions_and_classes():
def test_multiple_functions_and_classes() -> None:
# Test with multiple functions and classes
original_code = """
def function1():
@@ -548,7 +547,7 @@ class ExampleClass:
}
def test_docstring_preservation_with_fixtures(code_with_docstrings, code_without_docstrings):
def test_docstring_preservation_with_fixtures(code_with_docstrings, code_without_docstrings) -> None:
# Test all three types of docstrings using fixtures
for code_type in ["function", "class", "method"]:
original_code = code_with_docstrings[code_type]