mirror of
https://github.com/codeflash-ai/codeflash-internal.git
synced 2026-05-04 18:25:18 +00:00
ruff check fixes
This commit is contained in:
parent
f72ed92ce4
commit
61b030e47b
16 changed files with 61 additions and 64 deletions
|
|
@ -3,36 +3,39 @@ from __future__ import annotations
|
|||
import logging
|
||||
import os
|
||||
import sys
|
||||
from collections.abc import Callable
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from httpx import AsyncClient
|
||||
from openai import AsyncOpenAI
|
||||
from openai.lib.azure import AsyncAzureOpenAI
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Callable
|
||||
|
||||
IS_PRODUCTION = os.environ.get("ENVIRONMENT", default="") == "PRODUCTION"
|
||||
|
||||
LOGGING_FORMAT = "[%(levelname)s] %(message)s"
|
||||
|
||||
|
||||
def load_env():
|
||||
def load_env() -> None:
|
||||
if not IS_PRODUCTION:
|
||||
load_dotenv()
|
||||
|
||||
|
||||
def set_logging_level():
|
||||
def set_logging_level() -> None:
|
||||
if IS_PRODUCTION:
|
||||
logging.basicConfig(level=logging.INFO, format=LOGGING_FORMAT, stream=sys.stdout)
|
||||
else:
|
||||
logging.basicConfig(level=logging.DEBUG, format=LOGGING_FORMAT, stream=sys.stdout)
|
||||
|
||||
|
||||
def debug_log_sensitive_data(message: str):
|
||||
def debug_log_sensitive_data(message: str) -> None:
|
||||
if not IS_PRODUCTION:
|
||||
logging.debug(message)
|
||||
|
||||
|
||||
def debug_log_sensitive_data_from_callable(message: Callable[[], str | None]):
|
||||
def debug_log_sensitive_data_from_callable(message: Callable[[], str | None]) -> None:
|
||||
if not IS_PRODUCTION:
|
||||
logging.debug(message())
|
||||
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ from authapp.auth import AuthBearer
|
|||
|
||||
@async_only_middleware
|
||||
class AuthMiddleware:
|
||||
def __init__(self, get_response):
|
||||
def __init__(self, get_response) -> None:
|
||||
self.get_response = get_response
|
||||
self.auth_bearer = AuthBearer()
|
||||
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ from django.utils.decorators import async_only_middleware
|
|||
|
||||
@async_only_middleware
|
||||
class HealthCheckMiddleware:
|
||||
def __init__(self, get_response):
|
||||
def __init__(self, get_response) -> None:
|
||||
self.get_response = get_response
|
||||
if iscoroutinefunction(self.get_response):
|
||||
print("HealthcheckMiddleware is async.")
|
||||
|
|
|
|||
|
|
@ -15,7 +15,7 @@ RATE_LIMIT_MAX = int(os.getenv("RATE_LIMIT_MAX", "40"))
|
|||
# TODO: Implement a distributed caching solution (e.g., Redis) for multi-server environments.
|
||||
@async_only_middleware
|
||||
class RateLimitMiddleware:
|
||||
def __init__(self, get_response):
|
||||
def __init__(self, get_response) -> None:
|
||||
self.get_response = get_response
|
||||
if iscoroutinefunction(self.get_response):
|
||||
print("RateLimitMiddleware is async.")
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ class FunctionParent:
|
|||
type: str
|
||||
|
||||
|
||||
@dataclass(frozen=True, config=dict(arbitrary_types_allowed=True))
|
||||
@dataclass(frozen=True, config={"arbitrary_types_allowed": True})
|
||||
class FunctionToOptimize:
|
||||
function_name: str
|
||||
file_path: str
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ import pytest
|
|||
from aiservice.common_utils import parse_python_version, validate_trace_id
|
||||
|
||||
|
||||
def test_python_version():
|
||||
def test_python_version() -> None:
|
||||
assert parse_python_version("3.9.0") == (3, 9, 0)
|
||||
assert parse_python_version("3.12.2") == (3, 12, 2)
|
||||
with pytest.raises(ValueError):
|
||||
|
|
@ -16,7 +16,7 @@ def test_python_version():
|
|||
parse_python_version("3.1231231129.23432412312312313212412131212122112")
|
||||
|
||||
|
||||
def test_validate_trace_id():
|
||||
def test_validate_trace_id() -> None:
|
||||
assert validate_trace_id("f47ac10b-58cc-4372-a567-0e02b2c3d479")
|
||||
assert validate_trace_id("f47ac10b-58cc-4372-a567-0e02b2c3EXP0")
|
||||
assert validate_trace_id("f47ac10b-58cc-4372-a567-0e02b2c3EXP1")
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ from authapp.auth_utils import hash_api_key, instance_for_api_key
|
|||
from authapp.models import CFAPIKeys, Subscriptions
|
||||
|
||||
|
||||
async def check_subscription_status(user_id, tier):
|
||||
async def check_subscription_status(user_id, tier) -> bool:
|
||||
"""Check if a user has a premium subscription that doesn't require feature logging.
|
||||
|
||||
Args:
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ from .models import CFAPIKeys
|
|||
|
||||
def hash_api_key(api_key: str) -> str:
|
||||
"""Hash the API key using SHA-384 and encode the result using base64url.
|
||||
The hashing source of truth implementation is the hashApiKey function in js/common/token-functions.ts
|
||||
The hashing source of truth implementation is the hashApiKey function in js/common/token-functions.ts.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
|
|
@ -35,7 +35,6 @@ async def instance_for_api_key(hashed_api_key: str) -> CFAPIKeys | None:
|
|||
|
||||
"""
|
||||
try:
|
||||
api_key_instance = await CFAPIKeys.objects.aget(key=hashed_api_key)
|
||||
return api_key_instance
|
||||
return await CFAPIKeys.objects.aget(key=hashed_api_key)
|
||||
except CFAPIKeys.DoesNotExist:
|
||||
return None
|
||||
|
|
|
|||
|
|
@ -1,9 +1,8 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import datetime as dt
|
||||
import logging
|
||||
from typing import Any
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from asgiref.sync import sync_to_async
|
||||
from django.db import IntegrityError, transaction
|
||||
|
|
@ -12,6 +11,9 @@ from ninja import NinjaAPI, Schema
|
|||
from aiservice.common_utils import validate_trace_id
|
||||
from log_features.models import OptimizationFeatures
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import datetime as dt
|
||||
|
||||
features_api = NinjaAPI(urls_namespace="log_features")
|
||||
|
||||
|
||||
|
|
@ -73,7 +75,7 @@ async def log_features(
|
|||
|
||||
@sync_to_async
|
||||
@transaction.atomic
|
||||
def db_operation():
|
||||
def db_operation() -> None:
|
||||
# Try to get existing record with a lock
|
||||
f, created = OptimizationFeatures.objects.select_for_update().get_or_create(
|
||||
trace_id=trace_id,
|
||||
|
|
@ -187,7 +189,7 @@ async def log_features(
|
|||
|
||||
except Exception as e:
|
||||
logging.exception(f"Error logging features: {e}")
|
||||
raise e
|
||||
raise
|
||||
|
||||
|
||||
class LoggingSchema(Schema):
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ from aiservice.analytics.posthog import ph
|
|||
from aiservice.env_specific import load_env, set_logging_level
|
||||
|
||||
|
||||
def main():
|
||||
def main() -> None:
|
||||
"""Run administrative tasks."""
|
||||
os.environ.setdefault("DJANGO_SETTINGS_MODULE", "aiservice.settings")
|
||||
# Get the network name of the machine
|
||||
|
|
|
|||
|
|
@ -1,8 +1,8 @@
|
|||
import ast
|
||||
|
||||
|
||||
def find_init(node):
|
||||
"""Recursively search an AST for a FunctionDef node named '__init__'
|
||||
def find_init(node) -> bool:
|
||||
"""Recursively search an AST for a FunctionDef node named '__init__'.
|
||||
|
||||
Args:
|
||||
node: An AST node to search
|
||||
|
|
@ -16,8 +16,4 @@ def find_init(node):
|
|||
return True
|
||||
|
||||
# Recursively check all child nodes
|
||||
for child in ast.iter_child_nodes(node):
|
||||
if find_init(child):
|
||||
return True
|
||||
|
||||
return False
|
||||
return any(find_init(child) for child in ast.iter_child_nodes(node))
|
||||
|
|
|
|||
|
|
@ -184,7 +184,7 @@ class SingleOptimizerContext(BaseOptimizerContext):
|
|||
# MultiOptimizerContext #
|
||||
##########################################################################################
|
||||
class MultiOptimizerContext(BaseOptimizerContext):
|
||||
def __init__(self, base_system_prompt: str, base_user_prompt: str, source_code: str):
|
||||
def __init__(self, base_system_prompt: str, base_user_prompt: str, source_code: str) -> None:
|
||||
super().__init__(base_system_prompt, base_user_prompt, source_code)
|
||||
self.original_file_to_code = split_markdown_code(source_code)
|
||||
|
||||
|
|
@ -193,12 +193,11 @@ class MultiOptimizerContext(BaseOptimizerContext):
|
|||
|
||||
def get_user_prompt(self, dependency_code: str, line_profiler_results: str | None) -> str:
|
||||
has_init = any(find_init(ast.parse(code)) for code in self.original_file_to_code.values())
|
||||
user_prompt = (
|
||||
return (
|
||||
f"{DEPS_CONTEXT_PROMPT.format(dependency_code=dependency_code) if dependency_code else ''}"
|
||||
f"{self.base_user_prompt.format(source_code=self.source_code, init_optimization_prompt=INIT_OPTIMIZATION_PROMPT if has_init else '')}"
|
||||
f"{LINE_PROF_CONTEXT_PROMPT.format(line_profiler_results=line_profiler_results) if line_profiler_results else ''}"
|
||||
)
|
||||
return user_prompt
|
||||
|
||||
def extract_code_and_explanation_from_llm_res(self, content: str) -> CodeStrAndExplanation:
|
||||
code_blocks = MARKDOWN_BLOCK_PATTERN.findall(content)
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
from dataclasses import dataclass
|
||||
from typing import Never
|
||||
|
||||
import libcst as cst
|
||||
from pydantic import ValidationError
|
||||
|
|
@ -37,7 +38,7 @@ class RefinementContextData:
|
|||
# BaseRefinerContext #
|
||||
##########################################################################################
|
||||
class BaseRefinerContext:
|
||||
def __init__(self, ctx_data: RefinementContextData, base_system_prompt: str, base_user_prompt: str):
|
||||
def __init__(self, ctx_data: RefinementContextData, base_system_prompt: str, base_user_prompt: str) -> None:
|
||||
self.data = ctx_data
|
||||
self.base_system_prompt = base_system_prompt
|
||||
self.base_user_prompt = base_user_prompt
|
||||
|
|
@ -89,7 +90,7 @@ class BaseRefinerContext:
|
|||
except cst.ParserSyntaxError:
|
||||
return False
|
||||
|
||||
def validate_python_module(self, code: str):
|
||||
def validate_python_module(self, code: str) -> Never:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
|
|
@ -97,7 +98,7 @@ class BaseRefinerContext:
|
|||
# SingleRefinerContext #
|
||||
##########################################################################################
|
||||
class SingleRefinerContext(BaseRefinerContext):
|
||||
def __init__(self, ctx_data: RefinementContextData, base_system_prompt: str, base_user_prompt: str):
|
||||
def __init__(self, ctx_data: RefinementContextData, base_system_prompt: str, base_user_prompt: str) -> None:
|
||||
super().__init__(ctx_data, base_system_prompt, base_user_prompt)
|
||||
|
||||
def get_system_prompt(self) -> str:
|
||||
|
|
@ -134,19 +135,19 @@ class SingleRefinerContext(BaseRefinerContext):
|
|||
def is_valid_refinement(self, new_refined_code: str) -> bool:
|
||||
return super().is_valid_refinement(new_refined_code)
|
||||
|
||||
def validate_python_module(self, code: str):
|
||||
def validate_python_module(self, code: str) -> None:
|
||||
try:
|
||||
cst_module = parse_module_to_cst(code)
|
||||
CodeAndExplanation(cst_module, "")
|
||||
except (ValueError, ValidationError, cst.ParserSyntaxError) as exc:
|
||||
raise exc
|
||||
except (ValueError, ValidationError, cst.ParserSyntaxError):
|
||||
raise
|
||||
|
||||
|
||||
##########################################################################################
|
||||
# MultiRefinerContext #
|
||||
##########################################################################################
|
||||
class MultiRefinerContext(BaseRefinerContext):
|
||||
def __init__(self, ctx_data: RefinementContextData, base_system_prompt: str, base_user_prompt: str):
|
||||
def __init__(self, ctx_data: RefinementContextData, base_system_prompt: str, base_user_prompt: str) -> None:
|
||||
super().__init__(ctx_data, base_system_prompt, base_user_prompt)
|
||||
|
||||
def get_system_prompt(self) -> str:
|
||||
|
|
@ -190,10 +191,10 @@ class MultiRefinerContext(BaseRefinerContext):
|
|||
break
|
||||
return valid
|
||||
|
||||
def validate_python_module(self, code: str):
|
||||
def validate_python_module(self, code: str) -> None:
|
||||
for code in split_markdown_code(self.data.optimized_source_code).values():
|
||||
try:
|
||||
cst_module = parse_module_to_cst(code)
|
||||
CodeAndExplanation(cst_module, "")
|
||||
except (ValueError, ValidationError, cst.ParserSyntaxError) as exc:
|
||||
raise exc
|
||||
except (ValueError, ValidationError, cst.ParserSyntaxError):
|
||||
raise
|
||||
|
|
|
|||
|
|
@ -210,7 +210,7 @@ async def optimize(request, data: OptimizeSchema) -> tuple[int, OptimizeResponse
|
|||
|
||||
response = OptimizeResponseSchema(optimizations=optimization_response_items)
|
||||
|
||||
def log_response():
|
||||
def log_response() -> None:
|
||||
debug_log_sensitive_data(f"Response:\n{response.json()}")
|
||||
for opt in response.optimizations:
|
||||
debug_log_sensitive_data(f"Optimized source:\n{opt.source_code}")
|
||||
|
|
|
|||
|
|
@ -27,12 +27,11 @@ class RefinerTestCase:
|
|||
|
||||
|
||||
def create_optimizer_context(code: str) -> BaseOptimizerContext:
|
||||
ctx = BaseOptimizerContext.get_dynamic_context("", "", code)
|
||||
return ctx
|
||||
return BaseOptimizerContext.get_dynamic_context("", "", code)
|
||||
|
||||
|
||||
def create_refiner_context(optimized_code: str) -> BaseRefinerContext:
|
||||
ctx = BaseRefinerContext.get_dynamic_context(
|
||||
return BaseRefinerContext.get_dynamic_context(
|
||||
RefinementContextData(
|
||||
original_source_code="",
|
||||
original_line_profiler_results="",
|
||||
|
|
@ -47,7 +46,6 @@ def create_refiner_context(optimized_code: str) -> BaseRefinerContext:
|
|||
"",
|
||||
"",
|
||||
)
|
||||
return ctx
|
||||
|
||||
|
||||
single_optimizer_test_cases: list[OptimizerTestCase] = [
|
||||
|
|
@ -566,7 +564,7 @@ def sorter_test_final10(arr):
|
|||
]
|
||||
|
||||
|
||||
def test_single_optimizer_context():
|
||||
def test_single_optimizer_context() -> None:
|
||||
for t in single_optimizer_test_cases:
|
||||
ctx = create_optimizer_context("")
|
||||
ctx.extract_code_and_explanation_from_llm_res(t.llm_response)
|
||||
|
|
@ -577,7 +575,7 @@ def test_single_optimizer_context():
|
|||
assert ctx.is_valid_code()
|
||||
|
||||
|
||||
def test_multi_optimizer_context():
|
||||
def test_multi_optimizer_context() -> None:
|
||||
for t in multi_optimizer_test_cases:
|
||||
ctx = create_optimizer_context(t.original_code)
|
||||
ctx.extract_code_and_explanation_from_llm_res(t.llm_response)
|
||||
|
|
@ -588,7 +586,7 @@ def test_multi_optimizer_context():
|
|||
assert ctx.is_valid_code()
|
||||
|
||||
|
||||
def test_single_refiner_context():
|
||||
def test_single_refiner_context() -> None:
|
||||
for t in single_refiner_test_cases:
|
||||
ctx = create_refiner_context(t.optimized_code)
|
||||
patches = ctx.extract_diff_patches_from_llm_res(t.llm_response)
|
||||
|
|
@ -597,7 +595,7 @@ def test_single_refiner_context():
|
|||
assert ctx.is_valid_refinement(refined_code)
|
||||
|
||||
|
||||
def test_multi_refiner_context():
|
||||
def test_multi_refiner_context() -> None:
|
||||
for t in multi_refiner_test_cases:
|
||||
ctx = create_refiner_context(t.optimized_code)
|
||||
patches = ctx.extract_diff_patches_from_llm_res(t.llm_response)
|
||||
|
|
@ -607,7 +605,7 @@ def test_multi_refiner_context():
|
|||
assert ctx.is_valid_refinement(refined_code)
|
||||
|
||||
|
||||
def test_split_and_group_markdown_code():
|
||||
def test_split_and_group_markdown_code() -> None:
|
||||
code = """```python:path/to/file1.py
|
||||
file1=__name__
|
||||
```
|
||||
|
|
@ -643,7 +641,7 @@ def sorter2(arr):
|
|||
assert group_code(expected) == code
|
||||
|
||||
|
||||
def test_system_prompt_for_multi_optimizer():
|
||||
def test_system_prompt_for_multi_optimizer() -> None:
|
||||
ctx = create_optimizer_context(
|
||||
"```python:some_file.py\nprint(\"Hello world\")\n```\n```python:some_other_file.py\nprint('hi')\n````\n"
|
||||
)
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ from optimizer.models import CodeExplanationAndID
|
|||
from optimizer.postprocess import DocstringTransformer, DocstringVisitor
|
||||
|
||||
|
||||
def test_function_docstring_preservation():
|
||||
def test_function_docstring_preservation() -> None:
|
||||
# Original code with docstring
|
||||
original_code = """
|
||||
def example_function():
|
||||
|
|
@ -31,7 +31,7 @@ def example_function():
|
|||
assert "This is a docstring for the example function" in transformed_tree.code
|
||||
|
||||
|
||||
def test_function_docstring_in_both_functions():
|
||||
def test_function_docstring_in_both_functions() -> None:
|
||||
# Original code with docstring
|
||||
original_code = """
|
||||
def example_function():
|
||||
|
|
@ -58,7 +58,7 @@ def example_function():
|
|||
assert "This is a docstring for the example function" in transformed_tree.code
|
||||
|
||||
|
||||
def test_class_docstring_preservation():
|
||||
def test_class_docstring_preservation() -> None:
|
||||
# Original code with class docstring
|
||||
original_code = """
|
||||
class ExampleClass:
|
||||
|
|
@ -86,7 +86,7 @@ class ExampleClass:
|
|||
assert "This is a docstring for the example class" in transformed_tree.code
|
||||
|
||||
|
||||
def test_class_docstring_in_both_functions():
|
||||
def test_class_docstring_in_both_functions() -> None:
|
||||
# Original code with class docstring
|
||||
original_code = """
|
||||
class ExampleClass:
|
||||
|
|
@ -115,7 +115,7 @@ class ExampleClass:
|
|||
assert "This is a docstring for the example class" in transformed_tree.code
|
||||
|
||||
|
||||
def test_method_docstring_preservation():
|
||||
def test_method_docstring_preservation() -> None:
|
||||
# Original code with method docstring
|
||||
original_code = """
|
||||
class ExampleClass:
|
||||
|
|
@ -143,7 +143,7 @@ class ExampleClass:
|
|||
assert "This is a docstring for the method" in transformed_tree.code
|
||||
|
||||
|
||||
def test_fix_missing_docstring_pipeline_function():
|
||||
def test_fix_missing_docstring_pipeline_function() -> None:
|
||||
# Test the integration with the fix_missing_docstring pipeline function
|
||||
from optimizer.postprocess import fix_missing_docstring
|
||||
|
||||
|
|
@ -173,7 +173,7 @@ def example_function():
|
|||
@pytest.mark.skip(
|
||||
reason="This currently results in an exception and should be fixed. This test case reproduces the error"
|
||||
)
|
||||
def test_docstring_exception():
|
||||
def test_docstring_exception() -> None:
|
||||
# derived from optimizing https://github.com/pydantic/pydantic-ai/blob/39e28771e538a3a4af98222ca565ecfa402d9c08/pydantic_ai_slim/pydantic_ai/agent.py#L1717
|
||||
original_code = """import dataclasses
|
||||
import warnings
|
||||
|
|
@ -414,11 +414,11 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
|
|||
]
|
||||
|
||||
# Apply the pipeline function
|
||||
result = fix_missing_docstring(original_code, code_explanations)
|
||||
fix_missing_docstring(original_code, code_explanations)
|
||||
# TODO : This test case fails
|
||||
|
||||
|
||||
def test_benjamin_button():
|
||||
def test_benjamin_button() -> None:
|
||||
original_code = """
|
||||
def test_1():
|
||||
\"\"\"useful docstring
|
||||
|
|
@ -443,14 +443,13 @@ def test_2():
|
|||
# Apply transformer to optimized code
|
||||
transformer = DocstringTransformer(original_visitor.original_docstrings)
|
||||
optimized_tree = cst.parse_module(optimized_code)
|
||||
transformed_tree = optimized_tree.visit(transformer)
|
||||
optimized_tree.visit(transformer)
|
||||
|
||||
# Check if all docstrings were preserved
|
||||
transformed_code = transformed_tree.code
|
||||
assert True
|
||||
|
||||
|
||||
def test_multiple_functions_and_classes():
|
||||
def test_multiple_functions_and_classes() -> None:
|
||||
# Test with multiple functions and classes
|
||||
original_code = """
|
||||
def function1():
|
||||
|
|
@ -548,7 +547,7 @@ class ExampleClass:
|
|||
}
|
||||
|
||||
|
||||
def test_docstring_preservation_with_fixtures(code_with_docstrings, code_without_docstrings):
|
||||
def test_docstring_preservation_with_fixtures(code_with_docstrings, code_without_docstrings) -> None:
|
||||
# Test all three types of docstrings using fixtures
|
||||
for code_type in ["function", "class", "method"]:
|
||||
original_code = code_with_docstrings[code_type]
|
||||
|
|
|
|||
Loading…
Reference in a new issue