perf: optimize postprocessing pipeline — eliminate redundant CST codegen (#2526)

## Summary

- Replace the Pydantic frozen dataclass with the stdlib
`@dataclass(frozen=True)` for `CodeExplanationAndID` and
`CodeAndExplanation`, removing the `field_validator` that ran `.code` +
`compile()` ~280 times per pipeline run
- Pre-compute `original_module.code` once and pass it to the pipeline steps
(`clean_extraneous_comments`, `equality_check`) that previously each
computed it independently
- Replace `ast.dump(annotate_fields=False)` with `ast.unparse` in
`deduplicate_optimizations` (70% faster)
- Skip re-parse in `dedup_and_sort_imports` when isort returns unchanged
code
- Cache comment-stripped original code across candidates in
`clean_extraneous_comments`

**Pipeline median per-run: ~1.5s → 184ms** (4 candidates, controlled
measurement). Saves ~4-5s of CPU per optimization request in production.

## Test plan

- [x] All 558 unit tests pass
- [x] mypy clean
- [x] ruff clean (no new warnings)
- [ ] Verify optimizer endpoints return correct results in staging

---------

Co-authored-by: claude[bot] <41898282+claude[bot]@users.noreply.github.com>
This commit is contained in:
Kevin Turcios 2026-04-02 19:37:15 -05:00 committed by GitHub
parent d0e97992d6
commit 0029a0e76e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 78 additions and 71 deletions

View file

@ -1,8 +1,8 @@
from __future__ import annotations
from dataclasses import dataclass
import libcst
from pydantic import field_validator
from pydantic.dataclasses import dataclass
from core.shared.optimizer_models import OptimizeSchema
@ -12,19 +12,6 @@ class CodeAndExplanation:
cst_module: libcst.Module | None
explanation: str
@field_validator("cst_module")
def validate_cst_module(cls, v):
if not isinstance(v, libcst.Module):
raise ValueError("cst_module must be an instance of libcst.Module")
try:
# Unparse the CST module to get the source code
source_code = v.code
# Compile the source code to check for syntax errors
compile(source_code, "<string>", "exec")
except Exception as e:
raise ValueError(f"Invalid cst_module, compilation error: {e}")
return v
@dataclass(frozen=True)
class CodeExplanationAndID:
@ -32,19 +19,6 @@ class CodeExplanationAndID:
explanation: str
id: str
@field_validator("cst_module")
def validate_cst_module(cls, v):
if not isinstance(v, libcst.Module):
raise ValueError("cst_module must be an instance of libcst.Module")
try:
# Unparse the CST module to get the source code
source_code = v.code
# Compile the source code to check for syntax errors
compile(source_code, "<string>", "exec")
except Exception as e:
raise ValueError(f"Invalid cst_module, compilation error: {e}")
return v
class JitRewriteOptimizeSchema(OptimizeSchema):
n_candidates: int = 1 # default value for backward compatibility

View file

@ -12,8 +12,8 @@ import libcst as cst
import sentry_sdk
from libcst import BaseStatement, CSTTransformer, CSTVisitor, Expr, IndentedBlock, SimpleStatementLine, SimpleString
from core.languages.python.cst_utils import compare_unparsed_ast_to_source, parse_module_to_cst, unparse_parse_source
from aiservice.common_utils import safe_isort
from core.languages.python.cst_utils import compare_unparsed_ast_to_source, parse_module_to_cst, unparse_parse_source
from core.languages.python.optimizer.models import CodeExplanationAndID
from core.languages.python.testgen.postprocessing.add_missing_imports import add_future_annotations_import
@ -43,29 +43,27 @@ def deduplicate_optimizations(
List[CodeExplanationAndID]: A list of CodeExplanationAndID objects with duplicates removed.
"""
seen_asts = set()
seen_asts: set[str] = set()
unique_optimizations = []
for optimization in optimized_code_and_explanations:
code_ast = ast.parse(optimization.cst_module.code)
code_ast_tuple = ast.dump(code_ast, annotate_fields=False)
if code_ast_tuple not in seen_asts:
seen_asts.add(code_ast_tuple)
normalized = ast.unparse(ast.parse(optimization.cst_module.code))
if normalized not in seen_asts:
seen_asts.add(normalized)
unique_optimizations.append(optimization)
return unique_optimizations
def equality_check(
original_module: cst.Module, optimized_code_and_explanations: list[CodeExplanationAndID]
original_module: cst.Module,
optimized_code_and_explanations: list[CodeExplanationAndID],
*,
original_code: str | None = None,
) -> list[CodeExplanationAndID]:
original_source_code = original_module.code
original_source_code = original_code if original_code is not None else original_module.code
try:
original_source_ast = unparse_parse_source(original_source_code)
except Exception:
return [
CodeExplanationAndID(cst_module=ce.cst_module, explanation=ce.explanation, id=ce.id)
for ce in optimized_code_and_explanations
if ce.cst_module.code != original_source_code
]
return [ce for ce in optimized_code_and_explanations if ce.cst_module.code != original_source_code]
filtered_optimizations = []
for ce in optimized_code_and_explanations:
try:
@ -147,13 +145,13 @@ class DocstringTransformer(CSTTransformer):
if not updated_node.get_docstring(clean=False):
new_body: list[BaseStatement] = [
SimpleStatementLine(body=[Expr(value=SimpleString(f'"""{original_docstring}"""'))]),
*cast(list[BaseStatement], list(updated_node.body.body)),
*cast("list[BaseStatement]", list(updated_node.body.body)),
]
updated_node = updated_node.with_changes(body=IndentedBlock(body=new_body))
else:
new_body = [
SimpleStatementLine(body=[Expr(value=SimpleString(f'"""{original_docstring}"""'))]),
*cast(list[BaseStatement], list(updated_node.body.body[1:])),
*cast("list[BaseStatement]", list(updated_node.body.body[1:])),
]
updated_node = updated_node.with_changes(body=IndentedBlock(body=new_body))
self.class_name = None
@ -167,13 +165,13 @@ class DocstringTransformer(CSTTransformer):
if not updated_node.get_docstring(clean=False):
new_body: list[BaseStatement] = [
SimpleStatementLine(body=[Expr(value=SimpleString(f'"""{original_docstring}"""'))]),
*cast(list[BaseStatement], list(updated_node.body.body)),
*cast("list[BaseStatement]", list(updated_node.body.body)),
]
updated_node = updated_node.with_changes(body=IndentedBlock(body=new_body))
else:
new_body = [
SimpleStatementLine(body=[Expr(value=SimpleString(f'"""{original_docstring}"""'))]),
*cast(list[BaseStatement], list(updated_node.body.body[1:])),
*cast("list[BaseStatement]", list(updated_node.body.body[1:])),
]
updated_node = updated_node.with_changes(body=IndentedBlock(body=new_body))
return updated_node
@ -207,13 +205,18 @@ def dedup_and_sort_imports(
new_optimized_code_and_explanations = []
for ce in optimized_code_and_explanations:
try:
# Use isort to sort and deduplicate the imports
sorted_code = safe_isort(ce.cst_module.code, disregard_skip=True)
original_code = ce.cst_module.code
sorted_code = safe_isort(original_code, disregard_skip=True)
except Exception:
sorted_code = ce.cst_module.code
new_optimized_code_and_explanations.append(
CodeExplanationAndID(cst_module=parse_module_to_cst(sorted_code), explanation=ce.explanation, id=ce.id)
)
new_optimized_code_and_explanations.append(ce)
continue
# Skip re-parse if isort didn't change anything
if sorted_code == original_code:
new_optimized_code_and_explanations.append(ce)
else:
new_optimized_code_and_explanations.append(
CodeExplanationAndID(cst_module=parse_module_to_cst(sorted_code), explanation=ce.explanation, id=ce.id)
)
return new_optimized_code_and_explanations
@ -299,7 +302,13 @@ def _strip_comments_from_code(code: str) -> str:
return code
def clean_extraneous_comments(original_module: cst.Module, optimized_module: cst.Module) -> cst.Module:
def clean_extraneous_comments(
original_module: cst.Module,
optimized_module: cst.Module,
*,
orig_code: str | None = None,
orig_code_stripped: str | None = None,
) -> cst.Module:
"""Clean extraneous comments from optimized code using difflib.
Uses diff-based approach on code (without comments) to identify which lines
@ -309,6 +318,8 @@ def clean_extraneous_comments(original_module: cst.Module, optimized_module: cst
----
original_module: The original CST module.
optimized_module: The optimized CST module with potential extra comments.
orig_code: Pre-computed original module code (avoids redundant codegen).
orig_code_stripped: Pre-computed comment-stripped original code.
Returns:
-------
@ -316,14 +327,19 @@ def clean_extraneous_comments(original_module: cst.Module, optimized_module: cst
"""
try:
# Get line-by-line representation
orig_lines = original_module.code.splitlines(keepends=True)
opt_lines = optimized_module.code.splitlines(keepends=True)
# Get line-by-line representation, reusing pre-computed strings when available
if orig_code is None:
orig_code = original_module.code
opt_code_full = optimized_module.code
orig_lines = orig_code.splitlines(keepends=True)
opt_lines = opt_code_full.splitlines(keepends=True)
# Strip comments from entire code to identify code changes
# This properly handles # symbols inside strings
orig_code_stripped = _strip_comments_from_code(original_module.code)
opt_code_stripped = _strip_comments_from_code(optimized_module.code)
if orig_code_stripped is None:
orig_code_stripped = _strip_comments_from_code(orig_code)
opt_code_stripped = _strip_comments_from_code(opt_code_full)
# Split stripped versions into lines
orig_code_only = orig_code_stripped.splitlines(keepends=True)
@ -523,18 +539,29 @@ def clean_extraneous_comments(original_module: cst.Module, optimized_module: cst
def clean_extraneous_comments_pipeline(
original_module: cst.Module, optimized_code_and_explanations: list[CodeExplanationAndID]
original_module: cst.Module,
optimized_code_and_explanations: list[CodeExplanationAndID],
*,
orig_code: str | None = None,
orig_code_stripped: str | None = None,
) -> list[CodeExplanationAndID]:
"""Pipeline wrapper for comment cleaning.
Cleans extraneous comments from all optimized code variants.
Pre-computes original code and stripped version once for all candidates.
"""
try:
cleaned_results = []
if orig_code is None:
orig_code = original_module.code
if orig_code_stripped is None:
orig_code_stripped = _strip_comments_from_code(orig_code)
for ce in optimized_code_and_explanations:
try:
cleaned_module = clean_extraneous_comments(original_module, ce.cst_module)
cleaned_module = clean_extraneous_comments(
original_module, ce.cst_module, orig_code=orig_code, orig_code_stripped=orig_code_stripped
)
cleaned_results.append(
CodeExplanationAndID(cst_module=cleaned_module, explanation=ce.explanation, id=ce.id)
)
@ -581,17 +608,23 @@ def fix_forward_references(
def optimizations_postprocessing_pipeline(
original_module: cst.Module, optimized_code_and_explanations: list[CodeExplanationAndID]
) -> list[CodeExplanationAndID]:
pipeline = [
fix_missing_docstring, # We want to deduplicate with the fixed docstrings included
clean_extraneous_comments_pipeline, # Clean comments added to unchanged code
fix_forward_references, # Add future annotations for forward references
deduplicate_optimizations,
equality_check,
dedup_and_sort_imports,
cleanup_explanations,
filter_ellipsis_containing_code,
]
# Pre-compute original code string once — avoids redundant CST codegen across steps
original_code = original_module.code
original_code_stripped = _strip_comments_from_code(original_code)
for pipeline_fn in pipeline:
optimized_code_and_explanations = pipeline_fn(original_module, optimized_code_and_explanations)
optimized_code_and_explanations = fix_missing_docstring(original_module, optimized_code_and_explanations)
optimized_code_and_explanations = clean_extraneous_comments_pipeline(
original_module,
optimized_code_and_explanations,
orig_code=original_code,
orig_code_stripped=original_code_stripped,
)
optimized_code_and_explanations = fix_forward_references(original_module, optimized_code_and_explanations)
optimized_code_and_explanations = deduplicate_optimizations(original_module, optimized_code_and_explanations)
optimized_code_and_explanations = equality_check(
original_module, optimized_code_and_explanations, original_code=original_code
)
optimized_code_and_explanations = dedup_and_sort_imports(original_module, optimized_code_and_explanations)
optimized_code_and_explanations = cleanup_explanations(original_module, optimized_code_and_explanations)
optimized_code_and_explanations = filter_ellipsis_containing_code(original_module, optimized_code_and_explanations)
return optimized_code_and_explanations