codeflash-internal/django/aiservice/optimizer/postprocess.py
2025-02-25 18:30:06 -08:00

280 lines
11 KiB
Python

from __future__ import annotations
import ast
import logging
import re
from difflib import Differ
from typing import Dict, List, Optional
import isort
import libcst as cst
import sentry_sdk
from libcst import CSTTransformer, CSTVisitor, Expr, FunctionDef, IndentedBlock, SimpleStatementLine, SimpleString
from optimizer.code_utils.postprocess_constants import profanity_regex
from optimizer.code_utils.postprocess_filter import GibberishFilter
from optimizer.models import CodeExplanationAndID
from optimizer.optimizer_utils import compare_unparsed_ast_to_source, unparse_parse_source
from testgen.instrumentation.edit_generated_test import parse_module_to_cst
def deduplicate_optimizations(
    original_source_code: str, optimized_code_and_explanations: list[CodeExplanationAndID]
) -> list[CodeExplanationAndID]:
    """Remove optimizations that have equivalent code.

    Two candidates are duplicates if their 'cst_module' code parses to the same
    Python AST (so formatting-only differences dedupe). The first occurrence is
    kept; order is otherwise preserved.

    Args:
    ----
    original_source_code (str): The original source code that was optimized (unused here,
        kept for pipeline-signature consistency).
    optimized_code_and_explanations (List[CodeExplanationAndID]): A list of CodeExplanationAndID
        objects representing the optimized code and their explanations.

    Returns:
    -------
    List[CodeExplanationAndID]: A list of CodeExplanationAndID objects with duplicates removed.
    """
    seen_keys: set = set()
    unique_optimizations: list[CodeExplanationAndID] = []
    for optimization in optimized_code_and_explanations:
        code = optimization.cst_module.code
        try:
            # Normalize through an AST dump so whitespace/formatting differences
            # don't defeat deduplication.
            key = ("ast", ast.dump(ast.parse(code), annotate_fields=False))
        except SyntaxError:
            # Don't let one unparseable candidate crash the whole pipeline
            # (consistent with the defensive style of equality_check and
            # fix_missing_docstring); fall back to the raw text as the key.
            key = ("raw", code)
        if key not in seen_keys:
            seen_keys.add(key)
            unique_optimizations.append(optimization)
    return unique_optimizations
def equality_check(
    original_source_code: str, optimized_code_and_explanations: List[CodeExplanationAndID]
) -> List[CodeExplanationAndID]:
    """Drop candidates whose code is equivalent to the original source.

    Keeps the client from benchmarking the original code as one of its own
    candidates. Comparison is AST-based where possible; whenever parsing or
    AST comparison fails, an exact text comparison is used instead.
    """
    try:
        original_source_ast = unparse_parse_source(original_source_code)
    except Exception:
        # Original can't be parsed: fall back to exact text comparison for everything.
        return [
            CodeExplanationAndID(cst_module=ce.cst_module, explanation=ce.explanation, id=ce.id)
            for ce in optimized_code_and_explanations
            if ce.cst_module.code != original_source_code
        ]
    kept: List[CodeExplanationAndID] = []
    for ce in optimized_code_and_explanations:
        candidate_code = ce.cst_module.code
        try:
            matches_original = compare_unparsed_ast_to_source(original_source_ast, candidate_code)
        except Exception:
            # AST comparison failed for this candidate only; compare raw text.
            matches_original = candidate_code == original_source_code
        if not matches_original:
            kept.append(ce)
    return kept
# (compiled_pattern, replacement) pairs applied in order to every candidate's
# explanation text; they strip boilerplate lead-ins and code fences that models
# tend to emit around the actual explanation.
explanation_sub_patterns = [
    (
        # Drop lead-in phrases like " Here is the optimized version of the code:".
        re.compile(
            r"\sHere (is|are) (the )?((code|optimization)|\S+\scode|(optimized|improved) versions? of (the code|these functions))(:|.)",
            re.IGNORECASE,
        ),
        "",
    ),
    # Remove fenced ``` code blocks (DOTALL lets a fence span multiple lines).
    (re.compile(r"^```(.*?)```", re.MULTILINE | re.DOTALL), ""),
    # ", as follows:" reads badly once the code block it introduced is gone.
    (re.compile(r", as follows:"), "."),
    # A colon at end-of-line usually introduced removed code; end the sentence instead.
    (re.compile(r":\n"), ".\n"),
]
def cleanup_explanations(original_source_code: str, optimized_code_and_explanations: list[CodeExplanationAndID]):
    """Scrub boilerplate from every candidate's explanation text.

    Applies each (pattern, replacement) pair in explanation_sub_patterns, in
    order, to each explanation and returns new CodeExplanationAndID objects
    with the cleaned text; code modules and ids are passed through untouched.
    """

    def _scrub(text: str) -> str:
        # Substitutions are order-sensitive, so apply them sequentially.
        for pattern, replacement in explanation_sub_patterns:
            text = pattern.sub(replacement, text)
        return text

    return [
        CodeExplanationAndID(cst_module=ce.cst_module, explanation=_scrub(ce.explanation), id=ce.id)
        for ce in optimized_code_and_explanations
    ]
class DocstringVisitor(CSTVisitor):
    """Collect docstrings from a module, keyed by definition name.

    Class docstrings are stored under the class name, method docstrings under
    "ClassName.method", and top-level function docstrings under the bare
    function name. Results accumulate in ``original_docstrings``.
    NOTE(review): only one level of class nesting is tracked — a nested class
    overwrites the tracked name and clears it on exit; confirm that's acceptable
    for the inputs this sees.
    """

    def __init__(self):
        self.original_docstrings = {}
        # Name of the class currently being visited, or None at module level.
        self.class_name = None

    def visit_ClassDef(self, node: cst.ClassDef) -> bool:
        self.class_name = node.name.value
        # clean=False keeps the docstring's original indentation/whitespace.
        if doc := node.get_docstring(clean=False):
            self.original_docstrings[self.class_name] = doc
        return True

    def leave_ClassDef(self, node: cst.ClassDef) -> None:
        self.class_name = None

    def visit_FunctionDef(self, node: FunctionDef) -> bool:
        name = node.name.value
        key = name if self.class_name is None else f"{self.class_name}.{name}"
        if doc := node.get_docstring(clean=False):
            self.original_docstrings[key] = doc
class DocstringTransformer(CSTTransformer):
    """Re-insert recorded docstrings into definitions that lost them.

    Given the ``original_docstrings`` mapping produced by DocstringVisitor
    (class name or "ClassName.method" -> docstring text), prepend the recorded
    docstring to any class or function body that no longer has one.
    """

    def __init__(self, original_docstrings: Dict[str, Optional[str]]):
        self.original_docstrings = original_docstrings
        # Name of the class currently being traversed, or None at module level.
        self.class_name = None

    def visit_ClassDef(self, node: cst.ClassDef) -> bool:
        self.class_name = node.name.value
        return True

    def _restore(self, node, key):
        """Return *node* with the docstring recorded under *key* prepended, if it's missing one."""
        doc = self.original_docstrings.get(key)
        if not doc or node.get_docstring(clean=False):
            return node
        docstring_stmt = SimpleStatementLine(body=[Expr(value=SimpleString(f'"""{doc}"""'))])
        return node.with_changes(body=IndentedBlock(body=[docstring_stmt, *node.body.body]))

    def leave_ClassDef(self, original_node: cst.ClassDef, updated_node: cst.ClassDef) -> cst.ClassDef:
        restored = self._restore(updated_node, self.class_name)
        self.class_name = None
        return restored

    def leave_FunctionDef(self, original_node: FunctionDef, updated_node: FunctionDef) -> FunctionDef:
        name = updated_node.name.value
        key = name if self.class_name is None else f"{self.class_name}.{name}"
        return self._restore(updated_node, key)
def fix_missing_docstring(
    original_source_code: str, optimized_code_and_explanations: list[CodeExplanationAndID]
) -> list[CodeExplanationAndID]:
    """Restore docstrings from the original source into optimized candidates.

    Harvests every class/function docstring from the original source, then runs
    each candidate's CST through DocstringTransformer so definitions whose
    docstring was dropped during optimization get it back. If the original
    source cannot be parsed, candidates are returned unchanged.
    """
    try:
        source_tree = cst.parse_module(original_source_code)
    except Exception:
        # No original tree means nothing to restore from; pass candidates through.
        return optimized_code_and_explanations
    collector = DocstringVisitor()
    source_tree.visit(collector)
    transformer = DocstringTransformer(collector.original_docstrings)
    restored: list[CodeExplanationAndID] = []
    for ce in optimized_code_and_explanations:
        restored.append(
            CodeExplanationAndID(cst_module=ce.cst_module.visit(transformer), explanation=ce.explanation, id=ce.id)
        )
    return restored
def dedup_and_sort_imports(
    original_source_code: str, optimized_code_and_explanations: list[CodeExplanationAndID]
) -> list[CodeExplanationAndID]:
    """Sort and de-duplicate each candidate's imports with isort.

    Every candidate is re-parsed into a fresh CST from the (possibly isort-ed)
    source. If isort fails on a candidate, the error is reported to Sentry and
    that candidate's code is kept as-is.
    """
    results: list[CodeExplanationAndID] = []
    for ce in optimized_code_and_explanations:
        candidate_source = ce.cst_module.code
        try:
            # disregard_skip ignores any isort:skip directives in the code.
            candidate_source = isort.code(candidate_source, disregard_skip=True)
        except Exception as exc:
            # Best-effort: report the failure but keep the unsorted code.
            sentry_sdk.capture_exception(exc)
        results.append(
            CodeExplanationAndID(
                cst_module=parse_module_to_cst(candidate_source), explanation=ce.explanation, id=ce.id
            )
        )
    return results
class EllipsisContainingCodeVisitor(CSTVisitor):
    """Flags whether a visited module contains any `...` (Ellipsis) node.

    Used by filter_ellipsis_containing_code to detect candidates where the
    model elided code with `...` instead of emitting a full implementation.
    """

    def __init__(self):
        # Set to True as soon as the first Ellipsis node is encountered.
        self.ellipsis_containing_code = False
    def visit_Ellipsis(self, node: cst.Ellipsis) -> bool:
        self.ellipsis_containing_code = True
        # One hit is enough; no need to descend into this node's children.
        return False
def filter_ellipsis_containing_code(
    original_source_code: str, optimized_code_and_explanations: list[CodeExplanationAndID]
) -> list[CodeExplanationAndID]:
    """Drop candidates whose code contains `...` (likely elided implementations).

    If the original source itself contains an Ellipsis, it is a legitimate part
    of the code, so no candidates are filtered. If the original source cannot
    be parsed, candidates are likewise returned unchanged rather than crashing
    the pipeline (consistent with fix_missing_docstring's error handling).
    """
    try:
        original_tree = cst.parse_module(original_source_code)
    except Exception:
        # Previously an unparseable original crashed the whole pipeline here;
        # skip this filter instead, matching the other stages' defensiveness.
        return optimized_code_and_explanations
    original_visitor = EllipsisContainingCodeVisitor()
    original_tree.visit(original_visitor)
    if original_visitor.ellipsis_containing_code:
        # don't check for ellipsis containing optimized code if the original code contains ellipsis
        return optimized_code_and_explanations
    kept: list[CodeExplanationAndID] = []
    for ce in optimized_code_and_explanations:
        visitor = EllipsisContainingCodeVisitor()
        ce.cst_module.visit(visitor)
        if not visitor.ellipsis_containing_code:
            kept.append(ce)
    return kept
def remove_profanity_from_explanation(
    original_source_code: str, optimized_code_and_explanations: list[CodeExplanationAndID]
):
    """Drop any candidate whose explanation matches the profanity regex.

    Clean candidates are passed through unchanged; matching candidates are
    skipped with a warning naming the optimization id.
    """
    kept: list[CodeExplanationAndID] = []
    for ce in optimized_code_and_explanations:
        if profanity_regex.search(ce.explanation):
            # Lazy %-args: the message is only formatted if the warning is emitted.
            logging.warning(
                "Profanity detected in explanation for optimization %s. Skipping this optimization.", ce.id
            )
            continue
        # No fields change for clean entries, so reuse the object instead of
        # rebuilding an identical CodeExplanationAndID.
        kept.append(ce)
    return kept
def degibberization_pipeline(
    original_source_code: str, optimized_code_and_explanations: list[CodeExplanationAndID]
) -> list[CodeExplanationAndID]:
    """Clean gibberish out of candidate explanations via GibberishFilter.

    Non-gibberish candidates pass through untouched. Gibberish candidates get
    their cleaned explanation (logged as a diff against the original text when
    a non-empty cleanup is produced). Currently disabled in the main pipeline.
    """
    cleaned_candidates: list[CodeExplanationAndID] = []
    for candidate in optimized_code_and_explanations:
        gibberish_filter = GibberishFilter(candidate)
        if not gibberish_filter.is_gibberish():
            cleaned_candidates.append(candidate)
            continue
        cleaned_explanation = gibberish_filter.cleanup_explanation()
        if cleaned_explanation:
            # Log what changed, line by line, plus the words that were stripped.
            diff_lines = Differ().compare(
                candidate.explanation.splitlines(), cleaned_explanation.splitlines()
            )
            logging.info("Explanation cleaned for optimization %s:\n%s", candidate.id, "\n".join(diff_lines))
            logging.info("removed words: %s", "\n".join(gibberish_filter.removed_words))
        cleaned_candidates.append(
            CodeExplanationAndID(
                cst_module=candidate.cst_module, explanation=cleaned_explanation, id=candidate.id
            )
        )
    return cleaned_candidates
def optimizations_postprocessing_pipeline(
    original_source_code: str, optimized_code_and_explanations: list[CodeExplanationAndID]
) -> list[CodeExplanationAndID]:
    """Run every post-processing stage over the candidates, in order.

    Each stage takes (original_source_code, candidates) and returns a possibly
    filtered/rewritten candidate list, which feeds the next stage.
    """
    stages = (
        remove_profanity_from_explanation,
        fix_missing_docstring,  # We want to deduplicate with the fixed docstrings included
        deduplicate_optimizations,
        equality_check,
        dedup_and_sort_imports,
        cleanup_explanations,
        # degibberization_pipeline, # with newer models, the context window is large enough to not need this, restore if needed
        filter_ellipsis_containing_code,
    )
    candidates = optimized_code_and_explanations
    for stage in stages:
        candidates = stage(original_source_code, candidates)
    return candidates