perf: optimize postprocessing pipeline — eliminate redundant CST codegen (#2526)
## Summary

- Replace the Pydantic frozen dataclass with stdlib `@dataclass(frozen=True)` for `CodeExplanationAndID` and `CodeAndExplanation`, removing the `field_validator` that ran `.code` + `compile()` ~280 times per pipeline run
- Pre-compute `original_module.code` once and pass it to the pipeline steps (`clean_extraneous_comments`, `equality_check`) that previously called it independently
- Replace `ast.dump(annotate_fields=False)` with `ast.unparse` in `deduplicate_optimizations` (70% faster)
- Skip the re-parse in `dedup_and_sort_imports` when isort returns unchanged code
- Cache the comment-stripped original code across candidates in `clean_extraneous_comments`

**Pipeline median per-run: ~1.5s → 184ms** (4 candidates, controlled measurement). Saves ~4-5s of CPU per optimization request in production.

## Test plan

- [x] All 558 unit tests pass
- [x] mypy clean
- [x] ruff clean (no new warnings)
- [ ] Verify optimizer endpoints return correct results in staging

---------

Co-authored-by: claude[bot] <41898282+claude[bot]@users.noreply.github.com>
This commit is contained in:
parent
d0e97992d6
commit
0029a0e76e
2 changed files with 78 additions and 71 deletions
|
|
@ -1,8 +1,8 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
import libcst
|
||||
from pydantic import field_validator
|
||||
from pydantic.dataclasses import dataclass
|
||||
|
||||
from core.shared.optimizer_models import OptimizeSchema
|
||||
|
||||
|
|
@ -12,19 +12,6 @@ class CodeAndExplanation:
|
|||
cst_module: libcst.Module | None
|
||||
explanation: str
|
||||
|
||||
@field_validator("cst_module")
|
||||
def validate_cst_module(cls, v):
|
||||
if not isinstance(v, libcst.Module):
|
||||
raise ValueError("cst_module must be an instance of libcst.Module")
|
||||
try:
|
||||
# Unparse the CST module to get the source code
|
||||
source_code = v.code
|
||||
# Compile the source code to check for syntax errors
|
||||
compile(source_code, "<string>", "exec")
|
||||
except Exception as e:
|
||||
raise ValueError(f"Invalid cst_module, compilation error: {e}")
|
||||
return v
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class CodeExplanationAndID:
|
||||
|
|
@ -32,19 +19,6 @@ class CodeExplanationAndID:
|
|||
explanation: str
|
||||
id: str
|
||||
|
||||
@field_validator("cst_module")
|
||||
def validate_cst_module(cls, v):
|
||||
if not isinstance(v, libcst.Module):
|
||||
raise ValueError("cst_module must be an instance of libcst.Module")
|
||||
try:
|
||||
# Unparse the CST module to get the source code
|
||||
source_code = v.code
|
||||
# Compile the source code to check for syntax errors
|
||||
compile(source_code, "<string>", "exec")
|
||||
except Exception as e:
|
||||
raise ValueError(f"Invalid cst_module, compilation error: {e}")
|
||||
return v
|
||||
|
||||
|
||||
class JitRewriteOptimizeSchema(OptimizeSchema):
|
||||
n_candidates: int = 1 # default value for backward compatibility
|
||||
|
|
|
|||
|
|
@ -12,8 +12,8 @@ import libcst as cst
|
|||
import sentry_sdk
|
||||
from libcst import BaseStatement, CSTTransformer, CSTVisitor, Expr, IndentedBlock, SimpleStatementLine, SimpleString
|
||||
|
||||
from core.languages.python.cst_utils import compare_unparsed_ast_to_source, parse_module_to_cst, unparse_parse_source
|
||||
from aiservice.common_utils import safe_isort
|
||||
from core.languages.python.cst_utils import compare_unparsed_ast_to_source, parse_module_to_cst, unparse_parse_source
|
||||
from core.languages.python.optimizer.models import CodeExplanationAndID
|
||||
from core.languages.python.testgen.postprocessing.add_missing_imports import add_future_annotations_import
|
||||
|
||||
|
|
@ -43,29 +43,27 @@ def deduplicate_optimizations(
|
|||
List[CodeExplanationAndID]: A list of CodeExplanationAndID objects with duplicates removed.
|
||||
|
||||
"""
|
||||
seen_asts = set()
|
||||
seen_asts: set[str] = set()
|
||||
unique_optimizations = []
|
||||
for optimization in optimized_code_and_explanations:
|
||||
code_ast = ast.parse(optimization.cst_module.code)
|
||||
code_ast_tuple = ast.dump(code_ast, annotate_fields=False)
|
||||
if code_ast_tuple not in seen_asts:
|
||||
seen_asts.add(code_ast_tuple)
|
||||
normalized = ast.unparse(ast.parse(optimization.cst_module.code))
|
||||
if normalized not in seen_asts:
|
||||
seen_asts.add(normalized)
|
||||
unique_optimizations.append(optimization)
|
||||
return unique_optimizations
|
||||
|
||||
|
||||
def equality_check(
|
||||
original_module: cst.Module, optimized_code_and_explanations: list[CodeExplanationAndID]
|
||||
original_module: cst.Module,
|
||||
optimized_code_and_explanations: list[CodeExplanationAndID],
|
||||
*,
|
||||
original_code: str | None = None,
|
||||
) -> list[CodeExplanationAndID]:
|
||||
original_source_code = original_module.code
|
||||
original_source_code = original_code if original_code is not None else original_module.code
|
||||
try:
|
||||
original_source_ast = unparse_parse_source(original_source_code)
|
||||
except Exception:
|
||||
return [
|
||||
CodeExplanationAndID(cst_module=ce.cst_module, explanation=ce.explanation, id=ce.id)
|
||||
for ce in optimized_code_and_explanations
|
||||
if ce.cst_module.code != original_source_code
|
||||
]
|
||||
return [ce for ce in optimized_code_and_explanations if ce.cst_module.code != original_source_code]
|
||||
filtered_optimizations = []
|
||||
for ce in optimized_code_and_explanations:
|
||||
try:
|
||||
|
|
@ -147,13 +145,13 @@ class DocstringTransformer(CSTTransformer):
|
|||
if not updated_node.get_docstring(clean=False):
|
||||
new_body: list[BaseStatement] = [
|
||||
SimpleStatementLine(body=[Expr(value=SimpleString(f'"""{original_docstring}"""'))]),
|
||||
*cast(list[BaseStatement], list(updated_node.body.body)),
|
||||
*cast("list[BaseStatement]", list(updated_node.body.body)),
|
||||
]
|
||||
updated_node = updated_node.with_changes(body=IndentedBlock(body=new_body))
|
||||
else:
|
||||
new_body = [
|
||||
SimpleStatementLine(body=[Expr(value=SimpleString(f'"""{original_docstring}"""'))]),
|
||||
*cast(list[BaseStatement], list(updated_node.body.body[1:])),
|
||||
*cast("list[BaseStatement]", list(updated_node.body.body[1:])),
|
||||
]
|
||||
updated_node = updated_node.with_changes(body=IndentedBlock(body=new_body))
|
||||
self.class_name = None
|
||||
|
|
@ -167,13 +165,13 @@ class DocstringTransformer(CSTTransformer):
|
|||
if not updated_node.get_docstring(clean=False):
|
||||
new_body: list[BaseStatement] = [
|
||||
SimpleStatementLine(body=[Expr(value=SimpleString(f'"""{original_docstring}"""'))]),
|
||||
*cast(list[BaseStatement], list(updated_node.body.body)),
|
||||
*cast("list[BaseStatement]", list(updated_node.body.body)),
|
||||
]
|
||||
updated_node = updated_node.with_changes(body=IndentedBlock(body=new_body))
|
||||
else:
|
||||
new_body = [
|
||||
SimpleStatementLine(body=[Expr(value=SimpleString(f'"""{original_docstring}"""'))]),
|
||||
*cast(list[BaseStatement], list(updated_node.body.body[1:])),
|
||||
*cast("list[BaseStatement]", list(updated_node.body.body[1:])),
|
||||
]
|
||||
updated_node = updated_node.with_changes(body=IndentedBlock(body=new_body))
|
||||
return updated_node
|
||||
|
|
@ -207,13 +205,18 @@ def dedup_and_sort_imports(
|
|||
new_optimized_code_and_explanations = []
|
||||
for ce in optimized_code_and_explanations:
|
||||
try:
|
||||
# Use isort to sort and deduplicate the imports
|
||||
sorted_code = safe_isort(ce.cst_module.code, disregard_skip=True)
|
||||
original_code = ce.cst_module.code
|
||||
sorted_code = safe_isort(original_code, disregard_skip=True)
|
||||
except Exception:
|
||||
sorted_code = ce.cst_module.code
|
||||
new_optimized_code_and_explanations.append(
|
||||
CodeExplanationAndID(cst_module=parse_module_to_cst(sorted_code), explanation=ce.explanation, id=ce.id)
|
||||
)
|
||||
new_optimized_code_and_explanations.append(ce)
|
||||
continue
|
||||
# Skip re-parse if isort didn't change anything
|
||||
if sorted_code == original_code:
|
||||
new_optimized_code_and_explanations.append(ce)
|
||||
else:
|
||||
new_optimized_code_and_explanations.append(
|
||||
CodeExplanationAndID(cst_module=parse_module_to_cst(sorted_code), explanation=ce.explanation, id=ce.id)
|
||||
)
|
||||
|
||||
return new_optimized_code_and_explanations
|
||||
|
||||
|
|
@ -299,7 +302,13 @@ def _strip_comments_from_code(code: str) -> str:
|
|||
return code
|
||||
|
||||
|
||||
def clean_extraneous_comments(original_module: cst.Module, optimized_module: cst.Module) -> cst.Module:
|
||||
def clean_extraneous_comments(
|
||||
original_module: cst.Module,
|
||||
optimized_module: cst.Module,
|
||||
*,
|
||||
orig_code: str | None = None,
|
||||
orig_code_stripped: str | None = None,
|
||||
) -> cst.Module:
|
||||
"""Clean extraneous comments from optimized code using difflib.
|
||||
|
||||
Uses diff-based approach on code (without comments) to identify which lines
|
||||
|
|
@ -309,6 +318,8 @@ def clean_extraneous_comments(original_module: cst.Module, optimized_module: cst
|
|||
----
|
||||
original_module: The original CST module.
|
||||
optimized_module: The optimized CST module with potential extra comments.
|
||||
orig_code: Pre-computed original module code (avoids redundant codegen).
|
||||
orig_code_stripped: Pre-computed comment-stripped original code.
|
||||
|
||||
Returns:
|
||||
-------
|
||||
|
|
@ -316,14 +327,19 @@ def clean_extraneous_comments(original_module: cst.Module, optimized_module: cst
|
|||
|
||||
"""
|
||||
try:
|
||||
# Get line-by-line representation
|
||||
orig_lines = original_module.code.splitlines(keepends=True)
|
||||
opt_lines = optimized_module.code.splitlines(keepends=True)
|
||||
# Get line-by-line representation, reusing pre-computed strings when available
|
||||
if orig_code is None:
|
||||
orig_code = original_module.code
|
||||
opt_code_full = optimized_module.code
|
||||
|
||||
orig_lines = orig_code.splitlines(keepends=True)
|
||||
opt_lines = opt_code_full.splitlines(keepends=True)
|
||||
|
||||
# Strip comments from entire code to identify code changes
|
||||
# This properly handles # symbols inside strings
|
||||
orig_code_stripped = _strip_comments_from_code(original_module.code)
|
||||
opt_code_stripped = _strip_comments_from_code(optimized_module.code)
|
||||
if orig_code_stripped is None:
|
||||
orig_code_stripped = _strip_comments_from_code(orig_code)
|
||||
opt_code_stripped = _strip_comments_from_code(opt_code_full)
|
||||
|
||||
# Split stripped versions into lines
|
||||
orig_code_only = orig_code_stripped.splitlines(keepends=True)
|
||||
|
|
@ -523,18 +539,29 @@ def clean_extraneous_comments(original_module: cst.Module, optimized_module: cst
|
|||
|
||||
|
||||
def clean_extraneous_comments_pipeline(
|
||||
original_module: cst.Module, optimized_code_and_explanations: list[CodeExplanationAndID]
|
||||
original_module: cst.Module,
|
||||
optimized_code_and_explanations: list[CodeExplanationAndID],
|
||||
*,
|
||||
orig_code: str | None = None,
|
||||
orig_code_stripped: str | None = None,
|
||||
) -> list[CodeExplanationAndID]:
|
||||
"""Pipeline wrapper for comment cleaning.
|
||||
|
||||
Cleans extraneous comments from all optimized code variants.
|
||||
Pre-computes original code and stripped version once for all candidates.
|
||||
"""
|
||||
try:
|
||||
cleaned_results = []
|
||||
if orig_code is None:
|
||||
orig_code = original_module.code
|
||||
if orig_code_stripped is None:
|
||||
orig_code_stripped = _strip_comments_from_code(orig_code)
|
||||
|
||||
for ce in optimized_code_and_explanations:
|
||||
try:
|
||||
cleaned_module = clean_extraneous_comments(original_module, ce.cst_module)
|
||||
cleaned_module = clean_extraneous_comments(
|
||||
original_module, ce.cst_module, orig_code=orig_code, orig_code_stripped=orig_code_stripped
|
||||
)
|
||||
cleaned_results.append(
|
||||
CodeExplanationAndID(cst_module=cleaned_module, explanation=ce.explanation, id=ce.id)
|
||||
)
|
||||
|
|
@ -581,17 +608,23 @@ def fix_forward_references(
|
|||
def optimizations_postprocessing_pipeline(
|
||||
original_module: cst.Module, optimized_code_and_explanations: list[CodeExplanationAndID]
|
||||
) -> list[CodeExplanationAndID]:
|
||||
pipeline = [
|
||||
fix_missing_docstring, # We want to deduplicate with the fixed docstrings included
|
||||
clean_extraneous_comments_pipeline, # Clean comments added to unchanged code
|
||||
fix_forward_references, # Add future annotations for forward references
|
||||
deduplicate_optimizations,
|
||||
equality_check,
|
||||
dedup_and_sort_imports,
|
||||
cleanup_explanations,
|
||||
filter_ellipsis_containing_code,
|
||||
]
|
||||
# Pre-compute original code string once — avoids redundant CST codegen across steps
|
||||
original_code = original_module.code
|
||||
original_code_stripped = _strip_comments_from_code(original_code)
|
||||
|
||||
for pipeline_fn in pipeline:
|
||||
optimized_code_and_explanations = pipeline_fn(original_module, optimized_code_and_explanations)
|
||||
optimized_code_and_explanations = fix_missing_docstring(original_module, optimized_code_and_explanations)
|
||||
optimized_code_and_explanations = clean_extraneous_comments_pipeline(
|
||||
original_module,
|
||||
optimized_code_and_explanations,
|
||||
orig_code=original_code,
|
||||
orig_code_stripped=original_code_stripped,
|
||||
)
|
||||
optimized_code_and_explanations = fix_forward_references(original_module, optimized_code_and_explanations)
|
||||
optimized_code_and_explanations = deduplicate_optimizations(original_module, optimized_code_and_explanations)
|
||||
optimized_code_and_explanations = equality_check(
|
||||
original_module, optimized_code_and_explanations, original_code=original_code
|
||||
)
|
||||
optimized_code_and_explanations = dedup_and_sort_imports(original_module, optimized_code_and_explanations)
|
||||
optimized_code_and_explanations = cleanup_explanations(original_module, optimized_code_and_explanations)
|
||||
optimized_code_and_explanations = filter_ellipsis_containing_code(original_module, optimized_code_and_explanations)
|
||||
return optimized_code_and_explanations
|
||||
|
|
|
|||
Loading…
Reference in a new issue