ready to review
This commit is contained in:
parent
08a3d2507c
commit
66ba5cefc2
6 changed files with 40 additions and 27 deletions
|
|
@ -2,6 +2,25 @@ from __future__ import annotations
|
|||
|
||||
import uuid
|
||||
|
||||
import isort
|
||||
|
||||
|
||||
def safe_isort_code(code: str, **kwargs) -> str: # noqa: ANN003
|
||||
"""Wrap isort.code to returns the original code if isort fails.
|
||||
|
||||
Args:
|
||||
code (str): The Python source code to sort imports for.
|
||||
**kwargs: Any additional keyword arguments accepted by isort.code.
|
||||
|
||||
Returns:
|
||||
str: The sorted code, or the original code if an exception is raised.
|
||||
|
||||
"""
|
||||
try:
|
||||
return isort.code(code, **kwargs)
|
||||
except Exception: # noqa: BLE001
|
||||
return code
|
||||
|
||||
|
||||
def parse_python_version(version: str) -> tuple[int, int, int]:
|
||||
assert len(version) < 30, "Invalid version format"
|
||||
|
|
|
|||
|
|
@ -5,15 +5,14 @@ import logging
|
|||
import re
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import isort
|
||||
import libcst as cst
|
||||
import sentry_sdk
|
||||
from libcst import CSTTransformer, CSTVisitor, Expr, IndentedBlock, SimpleStatementLine, SimpleString
|
||||
from testgen.instrumentation.edit_generated_test import parse_module_to_cst
|
||||
|
||||
from optimizer.code_utils.postprocess_constants import profanity_regex
|
||||
from optimizer.models import CodeExplanationAndID
|
||||
from optimizer.optimizer_utils import compare_unparsed_ast_to_source, unparse_parse_source
|
||||
from testgen.instrumentation.edit_generated_test import parse_module_to_cst
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from libcst import FunctionDef
|
||||
|
|
@ -205,7 +204,7 @@ def dedup_and_sort_imports(
|
|||
for ce in optimized_code_and_explanations:
|
||||
try:
|
||||
# Use isort to sort and deduplicate the imports
|
||||
sorted_code = isort.code(ce.cst_module.code, disregard_skip=True)
|
||||
sorted_code = safe_isort(ce.cst_module.code, disregard_skip=True)
|
||||
except Exception: # noqa: BLE001
|
||||
sorted_code = ce.cst_module.code
|
||||
new_optimized_code_and_explanations.append(
|
||||
|
|
|
|||
|
|
@ -8,8 +8,8 @@ from dataclasses import dataclass
|
|||
import black
|
||||
import isort
|
||||
import sentry_sdk
|
||||
|
||||
from aiservice.models.functions_to_optimize import FunctionParent, FunctionToOptimize
|
||||
|
||||
from testgen.models import TestingMode
|
||||
|
||||
plat_str = platform.python_version_tuple()
|
||||
|
|
@ -38,7 +38,7 @@ def format_and_float_to_top(code: str) -> str:
|
|||
original_level = logger.level
|
||||
logger.setLevel(logging.INFO) # Suppress debug logs from black, which spams the aiservice console
|
||||
try:
|
||||
formatted_code = black.format_str(isort.code(code, config=isort.Config(float_to_top=True)), mode=black.Mode())
|
||||
formatted_code = black.format_str(safe_isort(code, config=isort.Config(float_to_top=True)), mode=black.Mode())
|
||||
except Exception:
|
||||
formatted_code = code
|
||||
logger.setLevel(original_level)
|
||||
|
|
@ -172,7 +172,7 @@ class InjectPerfAndLogging(ast.NodeTransformer):
|
|||
isinstance(node.func, ast.Attribute) and node.func.attr == self.only_function_name
|
||||
)
|
||||
# Check for await expressions (for async functions)
|
||||
elif isinstance(node, ast.Await) and isinstance(node.value, ast.Call):
|
||||
if isinstance(node, ast.Await) and isinstance(node.value, ast.Call):
|
||||
return (isinstance(node.value.func, ast.Name) and node.value.func.id == self.only_function_name) or (
|
||||
isinstance(node.value.func, ast.Attribute) and node.value.func.attr == self.only_function_name
|
||||
)
|
||||
|
|
@ -340,9 +340,7 @@ class InjectPerfAndLogging(ast.NodeTransformer):
|
|||
def find_target_function_call(self, node: ast.AST) -> ast.Call | ast.Await | None:
|
||||
for child in ast.walk(node):
|
||||
if self.is_target_function_node(child):
|
||||
if isinstance(child, ast.Call):
|
||||
return child
|
||||
elif isinstance(child, ast.Await):
|
||||
if isinstance(child, ast.Call) or isinstance(child, ast.Await):
|
||||
return child
|
||||
return None
|
||||
|
||||
|
|
|
|||
|
|
@ -3,7 +3,6 @@ from __future__ import annotations
|
|||
import ast
|
||||
import logging
|
||||
|
||||
import isort
|
||||
import sentry_sdk
|
||||
|
||||
|
||||
|
|
@ -22,7 +21,7 @@ def validate_testgen_code(code: str, python_version: tuple[int, int, int], max_l
|
|||
ValueError: If no valid code is found after checking lines
|
||||
|
||||
"""
|
||||
fixed_code = isort.code(code, float_to_top=True)
|
||||
fixed_code = safe_isort(code, float_to_top=True)
|
||||
# Split code into lines and calculate end positions
|
||||
lines = fixed_code.splitlines(keepends=True)
|
||||
end_positions = []
|
||||
|
|
|
|||
|
|
@ -7,7 +7,6 @@ import os
|
|||
from pathlib import Path
|
||||
from typing import SupportsIndex
|
||||
|
||||
import isort
|
||||
from aiservice.common_utils import parse_python_version
|
||||
from aiservice.env_specific import create_openai_client, debug_log_sensitive_data
|
||||
from aiservice.models.aimodels import EXECUTE_MODEL, EXPLAIN_MODEL, LLM, PLAN_MODEL, calculate_llm_cost
|
||||
|
|
@ -341,7 +340,7 @@ To help unit test the function above, list diverse scenarios that the function s
|
|||
if tries == 0:
|
||||
raise TestGenerationFailedException("Failed to generate test code after 2 tries.")
|
||||
|
||||
await update_optimization_cost(trace_id=trace_id, cost=total_llm_cost)
|
||||
await update_optimization_cost(trace_id=trace_id, cost=total_llm_cost)
|
||||
# return the unit test as a string
|
||||
return code
|
||||
|
||||
|
|
@ -380,8 +379,8 @@ class TestGenErrorResponseSchema(Schema):
|
|||
error: str
|
||||
|
||||
|
||||
from aiservice.analytics.posthog import ph
|
||||
import sentry_sdk
|
||||
from aiservice.analytics.posthog import ph
|
||||
|
||||
|
||||
@testgen_api.post(
|
||||
|
|
@ -425,7 +424,7 @@ async def testgen(request, data: TestGenSchema) -> tuple[int, TestGenResponseSch
|
|||
trace_id=data.trace_id,
|
||||
)
|
||||
print("/testgen: Instrumenting tests...")
|
||||
instrumented_test_source = isort.code(
|
||||
instrumented_test_source = safe_isort(
|
||||
instrument_test_source(
|
||||
test_source=generated_test_source,
|
||||
function_to_optimize=data.function_to_optimize,
|
||||
|
|
@ -442,7 +441,7 @@ async def testgen(request, data: TestGenSchema) -> tuple[int, TestGenResponseSch
|
|||
generated_test_source, data.function_to_optimize, data.module_path
|
||||
)
|
||||
# Use isort to sort and deduplicate the imports in the generated test code
|
||||
sorted_imports_test_source = isort.code(generated_test_source)
|
||||
sorted_imports_test_source = safe_isort(generated_test_source)
|
||||
try:
|
||||
parse_module_to_cst(sorted_imports_test_source)
|
||||
except Exception as e:
|
||||
|
|
|
|||
|
|
@ -8,20 +8,19 @@ import re
|
|||
from pathlib import Path
|
||||
from typing import SupportsIndex
|
||||
|
||||
import isort
|
||||
import sentry_sdk
|
||||
import stamina
|
||||
from aiservice.analytics.posthog import ph
|
||||
from aiservice.common_utils import parse_python_version, should_hack_for_demo, validate_trace_id
|
||||
from aiservice.env_specific import IS_PRODUCTION, debug_log_sensitive_data, open_ai_client
|
||||
from aiservice.models.aimodels import EXECUTE_MODEL, LLM, calculate_llm_cost
|
||||
from authapp.auth import AuthBearer
|
||||
from log_features.log_event import update_optimization_cost
|
||||
from log_features.log_features import log_features_optimized
|
||||
from ninja import NinjaAPI
|
||||
from ninja.errors import HttpError
|
||||
from openai import OpenAIError
|
||||
|
||||
from aiservice.analytics.posthog import ph
|
||||
from aiservice.common_utils import parse_python_version, should_hack_for_demo, validate_trace_id
|
||||
from aiservice.env_specific import IS_PRODUCTION, debug_log_sensitive_data, open_ai_client
|
||||
from aiservice.models.aimodels import EXECUTE_MODEL, LLM, calculate_llm_cost # noqa: TC001
|
||||
from authapp.auth import AuthBearer # noqa: TC001
|
||||
from log_features.log_event import update_optimization_cost
|
||||
from log_features.log_features import log_features_optimized
|
||||
from testgen.instrumentation.edit_generated_test import parse_module_to_cst, replace_definition_with_import
|
||||
from testgen.instrumentation.instrument_new_tests import instrument_test_source
|
||||
from testgen.models import (
|
||||
|
|
@ -153,10 +152,10 @@ def instrument_tests(
|
|||
}
|
||||
|
||||
behavior_result = instrument_test_source(**common_args, mode=TestingMode.BEHAVIOR)
|
||||
instrumented_behavior = isort.code(behavior_result, float_to_top=True) if behavior_result is not None else None
|
||||
instrumented_behavior = safe_isort(behavior_result, float_to_top=True) if behavior_result is not None else None
|
||||
|
||||
perf_result = instrument_test_source(**common_args, mode=TestingMode.PERFORMANCE)
|
||||
instrumented_perf = isort.code(perf_result, float_to_top=True) if perf_result is not None else None
|
||||
instrumented_perf = safe_isort(perf_result, float_to_top=True) if perf_result is not None else None
|
||||
|
||||
return instrumented_behavior, instrumented_perf
|
||||
|
||||
|
|
@ -271,7 +270,7 @@ async def generate_regression_tests_from_function(
|
|||
msg = "There was an error detected in the function to optimize, is it valid Python code?"
|
||||
raise TestGenerationFailedError(msg)
|
||||
|
||||
sorted_generated_tests = isort.code(generated_test_source)
|
||||
sorted_generated_tests = safe_isort(generated_test_source)
|
||||
|
||||
try:
|
||||
parse_module_to_cst(sorted_generated_tests)
|
||||
|
|
|
|||
Loading…
Reference in a new issue