ready to review

This commit is contained in:
Aseem Saxena 2025-10-22 21:08:06 -07:00
parent 08a3d2507c
commit 66ba5cefc2
6 changed files with 40 additions and 27 deletions

View file

@ -2,6 +2,25 @@ from __future__ import annotations
import uuid
import isort
def safe_isort_code(code: str, **kwargs) -> str:  # noqa: ANN003
    """Wrap isort.code to return the original code if isort fails.

    Args:
        code (str): The Python source code to sort imports for.
        **kwargs: Any additional keyword arguments accepted by isort.code.

    Returns:
        str: The sorted code, or the original code if an exception is raised.
    """
    try:
        return isort.code(code, **kwargs)
    except Exception:  # noqa: BLE001
        # Best-effort by design: import sorting is cosmetic, so never let a
        # failure (syntax error, isort config issue, etc.) propagate to callers.
        return code
def parse_python_version(version: str) -> tuple[int, int, int]:
assert len(version) < 30, "Invalid version format"

View file

@ -5,15 +5,14 @@ import logging
import re
from typing import TYPE_CHECKING
import isort
import libcst as cst
import sentry_sdk
from libcst import CSTTransformer, CSTVisitor, Expr, IndentedBlock, SimpleStatementLine, SimpleString
from testgen.instrumentation.edit_generated_test import parse_module_to_cst
from optimizer.code_utils.postprocess_constants import profanity_regex
from optimizer.models import CodeExplanationAndID
from optimizer.optimizer_utils import compare_unparsed_ast_to_source, unparse_parse_source
from testgen.instrumentation.edit_generated_test import parse_module_to_cst
if TYPE_CHECKING:
from libcst import FunctionDef
@ -205,7 +204,7 @@ def dedup_and_sort_imports(
for ce in optimized_code_and_explanations:
try:
# Use isort to sort and deduplicate the imports
sorted_code = isort.code(ce.cst_module.code, disregard_skip=True)
sorted_code = safe_isort(ce.cst_module.code, disregard_skip=True)
except Exception: # noqa: BLE001
sorted_code = ce.cst_module.code
new_optimized_code_and_explanations.append(

View file

@ -8,8 +8,8 @@ from dataclasses import dataclass
import black
import isort
import sentry_sdk
from aiservice.models.functions_to_optimize import FunctionParent, FunctionToOptimize
from testgen.models import TestingMode
plat_str = platform.python_version_tuple()
@ -38,7 +38,7 @@ def format_and_float_to_top(code: str) -> str:
original_level = logger.level
logger.setLevel(logging.INFO) # Suppress debug logs from black, which spams the aiservice console
try:
formatted_code = black.format_str(isort.code(code, config=isort.Config(float_to_top=True)), mode=black.Mode())
formatted_code = black.format_str(safe_isort(code, config=isort.Config(float_to_top=True)), mode=black.Mode())
except Exception:
formatted_code = code
logger.setLevel(original_level)
@ -172,7 +172,7 @@ class InjectPerfAndLogging(ast.NodeTransformer):
isinstance(node.func, ast.Attribute) and node.func.attr == self.only_function_name
)
# Check for await expressions (for async functions)
elif isinstance(node, ast.Await) and isinstance(node.value, ast.Call):
if isinstance(node, ast.Await) and isinstance(node.value, ast.Call):
return (isinstance(node.value.func, ast.Name) and node.value.func.id == self.only_function_name) or (
isinstance(node.value.func, ast.Attribute) and node.value.func.attr == self.only_function_name
)
@ -340,9 +340,7 @@ class InjectPerfAndLogging(ast.NodeTransformer):
def find_target_function_call(self, node: ast.AST) -> ast.Call | ast.Await | None:
for child in ast.walk(node):
if self.is_target_function_node(child):
if isinstance(child, ast.Call):
return child
elif isinstance(child, ast.Await):
if isinstance(child, ast.Call) or isinstance(child, ast.Await):
return child
return None

View file

@ -3,7 +3,6 @@ from __future__ import annotations
import ast
import logging
import isort
import sentry_sdk
@ -22,7 +21,7 @@ def validate_testgen_code(code: str, python_version: tuple[int, int, int], max_l
ValueError: If no valid code is found after checking lines
"""
fixed_code = isort.code(code, float_to_top=True)
fixed_code = safe_isort(code, float_to_top=True)
# Split code into lines and calculate end positions
lines = fixed_code.splitlines(keepends=True)
end_positions = []

View file

@ -7,7 +7,6 @@ import os
from pathlib import Path
from typing import SupportsIndex
import isort
from aiservice.common_utils import parse_python_version
from aiservice.env_specific import create_openai_client, debug_log_sensitive_data
from aiservice.models.aimodels import EXECUTE_MODEL, EXPLAIN_MODEL, LLM, PLAN_MODEL, calculate_llm_cost
@ -341,7 +340,7 @@ To help unit test the function above, list diverse scenarios that the function s
if tries == 0:
raise TestGenerationFailedException("Failed to generate test code after 2 tries.")
await update_optimization_cost(trace_id=trace_id, cost=total_llm_cost)
await update_optimization_cost(trace_id=trace_id, cost=total_llm_cost)
# return the unit test as a string
return code
@ -380,8 +379,8 @@ class TestGenErrorResponseSchema(Schema):
error: str
from aiservice.analytics.posthog import ph
import sentry_sdk
from aiservice.analytics.posthog import ph
@testgen_api.post(
@ -425,7 +424,7 @@ async def testgen(request, data: TestGenSchema) -> tuple[int, TestGenResponseSch
trace_id=data.trace_id,
)
print("/testgen: Instrumenting tests...")
instrumented_test_source = isort.code(
instrumented_test_source = safe_isort(
instrument_test_source(
test_source=generated_test_source,
function_to_optimize=data.function_to_optimize,
@ -442,7 +441,7 @@ async def testgen(request, data: TestGenSchema) -> tuple[int, TestGenResponseSch
generated_test_source, data.function_to_optimize, data.module_path
)
# Use isort to sort and deduplicate the imports in the generated test code
sorted_imports_test_source = isort.code(generated_test_source)
sorted_imports_test_source = safe_isort(generated_test_source)
try:
parse_module_to_cst(sorted_imports_test_source)
except Exception as e:

View file

@ -8,20 +8,19 @@ import re
from pathlib import Path
from typing import SupportsIndex
import isort
import sentry_sdk
import stamina
from aiservice.analytics.posthog import ph
from aiservice.common_utils import parse_python_version, should_hack_for_demo, validate_trace_id
from aiservice.env_specific import IS_PRODUCTION, debug_log_sensitive_data, open_ai_client
from aiservice.models.aimodels import EXECUTE_MODEL, LLM, calculate_llm_cost
from authapp.auth import AuthBearer
from log_features.log_event import update_optimization_cost
from log_features.log_features import log_features_optimized
from ninja import NinjaAPI
from ninja.errors import HttpError
from openai import OpenAIError
from aiservice.analytics.posthog import ph
from aiservice.common_utils import parse_python_version, should_hack_for_demo, validate_trace_id
from aiservice.env_specific import IS_PRODUCTION, debug_log_sensitive_data, open_ai_client
from aiservice.models.aimodels import EXECUTE_MODEL, LLM, calculate_llm_cost # noqa: TC001
from authapp.auth import AuthBearer # noqa: TC001
from log_features.log_event import update_optimization_cost
from log_features.log_features import log_features_optimized
from testgen.instrumentation.edit_generated_test import parse_module_to_cst, replace_definition_with_import
from testgen.instrumentation.instrument_new_tests import instrument_test_source
from testgen.models import (
@ -153,10 +152,10 @@ def instrument_tests(
}
behavior_result = instrument_test_source(**common_args, mode=TestingMode.BEHAVIOR)
instrumented_behavior = isort.code(behavior_result, float_to_top=True) if behavior_result is not None else None
instrumented_behavior = safe_isort(behavior_result, float_to_top=True) if behavior_result is not None else None
perf_result = instrument_test_source(**common_args, mode=TestingMode.PERFORMANCE)
instrumented_perf = isort.code(perf_result, float_to_top=True) if perf_result is not None else None
instrumented_perf = safe_isort(perf_result, float_to_top=True) if perf_result is not None else None
return instrumented_behavior, instrumented_perf
@ -271,7 +270,7 @@ async def generate_regression_tests_from_function(
msg = "There was an error detected in the function to optimize, is it valid Python code?"
raise TestGenerationFailedError(msg)
sorted_generated_tests = isort.code(generated_test_source)
sorted_generated_tests = safe_isort(generated_test_source)
try:
parse_module_to_cst(sorted_generated_tests)