mirror of
https://github.com/codeflash-ai/codeflash-internal.git
synced 2026-05-04 18:25:18 +00:00
280 lines
11 KiB
Python
280 lines
11 KiB
Python
from __future__ import annotations
|
|
|
|
import ast
|
|
import logging
|
|
import re
|
|
from difflib import Differ
|
|
from typing import Dict, List, Optional
|
|
|
|
import isort
|
|
import libcst as cst
|
|
import sentry_sdk
|
|
from libcst import CSTTransformer, CSTVisitor, Expr, FunctionDef, IndentedBlock, SimpleStatementLine, SimpleString
|
|
|
|
from optimizer.code_utils.postprocess_constants import profanity_regex
|
|
from optimizer.code_utils.postprocess_filter import GibberishFilter
|
|
from optimizer.models import CodeExplanationAndID
|
|
from optimizer.optimizer_utils import compare_unparsed_ast_to_source, unparse_parse_source
|
|
from testgen.instrumentation.edit_generated_test import parse_module_to_cst
|
|
|
|
|
|
def deduplicate_optimizations(
    original_source_code: str, optimized_code_and_explanations: list[CodeExplanationAndID]
) -> list[CodeExplanationAndID]:
    """Drop candidates whose code is AST-equivalent to an earlier candidate.

    Two candidates are considered duplicates when the code rendered from their
    ``cst_module`` parses to the same ``ast.dump`` fingerprint (field names
    omitted), so purely cosmetic differences such as whitespace do not count.
    The first occurrence of each distinct AST is kept, preserving input order.

    Args:
    ----
        original_source_code (str): The original source code that was optimized
            (unused here; kept for the uniform pipeline-stage signature).
        optimized_code_and_explanations (List[CodeExplanationAndID]): Candidate
            optimizations with their explanations.

    Returns:
    -------
        List[CodeExplanationAndID]: The candidates with AST-level duplicates removed.

    """
    seen_fingerprints: set[str] = set()
    deduped: list[CodeExplanationAndID] = []
    for candidate in optimized_code_and_explanations:
        # Fingerprint the candidate's code at the AST level.
        fingerprint = ast.dump(ast.parse(candidate.cst_module.code), annotate_fields=False)
        if fingerprint in seen_fingerprints:
            continue
        seen_fingerprints.add(fingerprint)
        deduped.append(candidate)
    return deduped
|
|
|
|
|
|
def equality_check(
    original_source_code: str, optimized_code_and_explanations: List[CodeExplanationAndID]
) -> List[CodeExplanationAndID]:
    """Remove optimizations that are equivalent to the original source code.

    To not have the client check the original code as a candidate.

    If the original source cannot be parsed, or a candidate cannot be compared
    at the AST level, fall back to an exact string comparison so a parse
    failure never aborts the pipeline.

    Args:
    ----
        original_source_code (str): The original source code that was optimized.
        optimized_code_and_explanations (List[CodeExplanationAndID]): Candidate
            optimizations with their explanations.

    Returns:
    -------
        List[CodeExplanationAndID]: Candidates whose code differs from the original.

    """
    try:
        original_source_ast = unparse_parse_source(original_source_code)
    except Exception:
        # Original cannot be parsed: fall back to exact-text comparison.
        # (The previous version rebuilt identical CodeExplanationAndID objects
        # here; the existing objects are returned unchanged instead.)
        return [ce for ce in optimized_code_and_explanations if ce.cst_module.code != original_source_code]
    filtered_optimizations = []
    for ce in optimized_code_and_explanations:
        try:
            # Keep only candidates whose AST differs from the original's.
            if not compare_unparsed_ast_to_source(original_source_ast, ce.cst_module.code):
                filtered_optimizations.append(ce)
        except Exception:
            # Candidate could not be compared at the AST level; keep it unless
            # it is textually identical to the original.
            if ce.cst_module.code != original_source_code:
                filtered_optimizations.append(ce)
    return filtered_optimizations
|
|
|
|
|
|
# (pattern, replacement) pairs applied in order to each explanation by
# cleanup_explanations(). They strip LLM boilerplate and normalize punctuation.
explanation_sub_patterns = [
    # Remove lead-ins like " Here is the optimized version of the code:".
    (
        re.compile(
            r"\sHere (is|are) (the )?((code|optimization)|\S+\scode|(optimized|improved) versions? of (the code|these functions))(:|.)",
            re.IGNORECASE,
        ),
        "",
    ),
    # Remove fenced code blocks (``` ... ```), including multi-line ones.
    (re.compile(r"^```(.*?)```", re.MULTILINE | re.DOTALL), ""),
    # ", as follows:" usually preceded a now-removed code block; end the sentence.
    (re.compile(r", as follows:"), "."),
    # A colon at end-of-line likewise dangles once code blocks are gone.
    (re.compile(r":\n"), ".\n"),
]
|
|
|
|
|
|
def cleanup_explanations(
    original_source_code: str, optimized_code_and_explanations: list[CodeExplanationAndID]
) -> list[CodeExplanationAndID]:
    """Strip LLM boilerplate from each candidate's explanation text.

    Applies every (pattern, replacement) pair in ``explanation_sub_patterns``
    in order to each explanation; the code itself is left untouched.

    Args:
    ----
        original_source_code (str): The original source code that was optimized
            (unused here; kept for the uniform pipeline-stage signature).
        optimized_code_and_explanations (List[CodeExplanationAndID]): Candidate
            optimizations with their explanations.

    Returns:
    -------
        List[CodeExplanationAndID]: New objects with cleaned-up explanations.

    """
    new_optimized_code_and_explanations = []

    for ce in optimized_code_and_explanations:
        cleaned_up_explanation = ce.explanation
        # Substitutions are order-sensitive: code fences must go before the
        # dangling-punctuation fixups.
        for pattern, repl in explanation_sub_patterns:
            cleaned_up_explanation = pattern.sub(repl, cleaned_up_explanation)

        new_optimized_code_and_explanations.append(
            CodeExplanationAndID(cst_module=ce.cst_module, explanation=cleaned_up_explanation, id=ce.id)
        )

    return new_optimized_code_and_explanations
|
|
|
|
|
|
class DocstringVisitor(CSTVisitor):
    """Collect the docstrings of every class and function in a module.

    After visiting, ``original_docstrings`` maps names to docstring text:
    methods are keyed as ``ClassName.method_name``, top-level functions and
    classes by their bare name.
    """

    def __init__(self):
        # qualified name -> docstring text (as returned by get_docstring(clean=False)).
        self.original_docstrings = {}
        # Name of the class currently being visited, or None at module level.
        self.class_name = None

    def visit_ClassDef(self, node: cst.ClassDef) -> bool:
        self.class_name = node.name.value
        docstring = node.get_docstring(clean=False)
        if docstring:
            self.original_docstrings[self.class_name] = docstring
        # Descend into the class body so method docstrings are collected too.
        return True

    def leave_ClassDef(self, node: cst.ClassDef) -> None:
        # NOTE(review): a single flag cannot track nesting — a function defined
        # after a nested class (inside the outer class) would lose its
        # qualification. Presumably acceptable for the code being processed.
        self.class_name = None

    def visit_FunctionDef(self, node: FunctionDef) -> bool:
        function_name = node.name.value
        # Methods get "ClassName.method" keys; top-level functions the bare name.
        qualified_name = f"{self.class_name}.{function_name}" if self.class_name else function_name
        docstring = node.get_docstring(clean=False)
        if docstring:
            self.original_docstrings[qualified_name] = docstring
        # NOTE(review): annotated -> bool but falls through returning None;
        # libcst treats None like True (visit children) — confirm intended.
|
|
|
|
|
|
class DocstringTransformer(CSTTransformer):
    """Re-insert docstrings into classes/functions that are missing them.

    Given the mapping produced by DocstringVisitor (``ClassName.method`` or
    bare-name keys -> docstring text), prepend the original docstring to any
    matching class or function body that has none.
    """

    def __init__(self, original_docstrings: Dict[str, Optional[str]]):
        # qualified name -> original docstring text to restore.
        self.original_docstrings = original_docstrings
        # Name of the class currently being visited, or None at module level.
        self.class_name = None

    def visit_ClassDef(self, node: cst.ClassDef) -> bool:
        self.class_name = node.name.value
        # Descend so methods can be fixed up too.
        return True

    def leave_ClassDef(self, original_node: cst.ClassDef, updated_node: cst.ClassDef) -> cst.ClassDef:
        original_docstring = self.original_docstrings.get(self.class_name)
        # Only add a docstring when the class had one originally and the
        # (possibly optimizer-rewritten) class now lacks one.
        if original_docstring and not updated_node.get_docstring(clean=False):
            # NOTE(review): re-wrapping in triple quotes would produce invalid
            # code if the docstring text itself contains `"""` — confirm that
            # case cannot occur upstream.
            new_body = [SimpleStatementLine(body=[Expr(value=SimpleString(f'"""{original_docstring}"""'))])] + list(
                updated_node.body.body
            )
            updated_node = updated_node.with_changes(body=IndentedBlock(body=new_body))
        self.class_name = None
        return updated_node

    def leave_FunctionDef(self, original_node: FunctionDef, updated_node: FunctionDef) -> FunctionDef:
        function_name = updated_node.name.value
        # Mirror DocstringVisitor's keying: "ClassName.method" inside a class.
        qualified_name = f"{self.class_name}.{function_name}" if self.class_name else function_name
        original_docstring = self.original_docstrings.get(qualified_name)
        if original_docstring and not updated_node.get_docstring(clean=False):
            # Prepend the original docstring as the first body statement.
            new_body = [SimpleStatementLine(body=[Expr(value=SimpleString(f'"""{original_docstring}"""'))])] + list(
                updated_node.body.body
            )
            updated_node = updated_node.with_changes(body=IndentedBlock(body=new_body))
        return updated_node
|
|
|
|
|
|
def fix_missing_docstring(
    original_source_code: str, optimized_code_and_explanations: list[CodeExplanationAndID]
) -> list[CodeExplanationAndID]:
    """Restore docstrings that the optimizer dropped from candidate code.

    Collects every class/function docstring from the original source, then
    re-inserts each one into any candidate whose corresponding definition
    lost its docstring. Candidates are returned unchanged if the original
    source cannot be parsed.
    """
    collector = DocstringVisitor()
    try:
        source_tree = cst.parse_module(original_source_code)
    except Exception:
        # No parseable original means no docstrings to restore.
        return optimized_code_and_explanations
    source_tree.visit(collector)
    restorer = DocstringTransformer(collector.original_docstrings)
    restored = []
    for ce in optimized_code_and_explanations:
        restored.append(
            CodeExplanationAndID(cst_module=ce.cst_module.visit(restorer), explanation=ce.explanation, id=ce.id)
        )
    return restored
|
|
|
|
|
|
def dedup_and_sort_imports(
    original_source_code: str, optimized_code_and_explanations: list[CodeExplanationAndID]
) -> list[CodeExplanationAndID]:
    """Sort and deduplicate the import statements of every candidate via isort.

    If isort fails on a candidate, the failure is reported to Sentry and that
    candidate's code is kept as-is. Every candidate's code is re-parsed into a
    fresh CST module either way.
    """
    processed = []
    for ce in optimized_code_and_explanations:
        candidate_code = ce.cst_module.code
        try:
            # isort both sorts and deduplicates the imports.
            candidate_code = isort.code(candidate_code, disregard_skip=True)
        except Exception as exc:
            # isort can choke on unusual code; report and keep the code unchanged.
            sentry_sdk.capture_exception(exc)
        processed.append(
            CodeExplanationAndID(cst_module=parse_module_to_cst(candidate_code), explanation=ce.explanation, id=ce.id)
        )

    return processed
|
|
|
|
|
|
class EllipsisContainingCodeVisitor(CSTVisitor):
    """Flag whether a libcst tree contains an Ellipsis (``...``) node."""

    def __init__(self):
        # Set to True as soon as any Ellipsis node is seen.
        self.ellipsis_containing_code = False

    def visit_Ellipsis(self, node: cst.Ellipsis) -> bool:
        self.ellipsis_containing_code = True
        # One hit is enough; no need to descend into this node.
        return False
|
|
|
|
|
|
def filter_ellipsis_containing_code(
    original_source_code: str, optimized_code_and_explanations: list[CodeExplanationAndID]
) -> list[CodeExplanationAndID]:
    """Drop candidates that introduce ``...`` placeholders into the code.

    A candidate containing Ellipsis usually means the model elided part of the
    implementation. If the original code itself contains Ellipsis the check is
    skipped entirely, since ``...`` is then legitimate.

    Args:
    ----
        original_source_code (str): The original source code that was optimized.
        optimized_code_and_explanations (List[CodeExplanationAndID]): Candidate
            optimizations with their explanations.

    Returns:
    -------
        List[CodeExplanationAndID]: Candidates free of introduced Ellipsis nodes.

    """
    new_optimized_code_and_explanations = []
    original_visitor = EllipsisContainingCodeVisitor()
    try:
        original_tree = cst.parse_module(original_source_code)
    except Exception:
        # Unparseable original: we cannot tell whether ellipsis is legitimate,
        # so skip the filter rather than crash the pipeline (consistent with
        # the parse-failure fallback in fix_missing_docstring).
        return optimized_code_and_explanations
    original_tree.visit(original_visitor)
    if original_visitor.ellipsis_containing_code:
        # don't check for ellipsis containing optimized code if the original code contains ellipsis
        return optimized_code_and_explanations
    for ce in optimized_code_and_explanations:
        visitor = EllipsisContainingCodeVisitor()
        ce.cst_module.visit(visitor)
        if not visitor.ellipsis_containing_code:
            new_optimized_code_and_explanations.append(ce)
    return new_optimized_code_and_explanations
|
|
|
|
|
|
def remove_profanity_from_explanation(
    original_source_code: str, optimized_code_and_explanations: list[CodeExplanationAndID]
) -> list[CodeExplanationAndID]:
    """Drop candidates whose explanation matches the profanity regex.

    Args:
    ----
        original_source_code (str): The original source code that was optimized
            (unused here; kept for the uniform pipeline-stage signature).
        optimized_code_and_explanations (List[CodeExplanationAndID]): Candidate
            optimizations with their explanations.

    Returns:
    -------
        List[CodeExplanationAndID]: Candidates with clean explanations, in order.

    """
    new_optimized_code_and_explanations: list[CodeExplanationAndID] = []
    for ce in optimized_code_and_explanations:
        if profanity_regex.search(ce.explanation):
            # Lazy %-formatting: the message is only built if the record is emitted.
            logging.warning(
                "Profanity detected in explanation for optimization %s. Skipping this optimization.", ce.id
            )
            continue
        # Keep the existing object; rebuilding an identical CodeExplanationAndID
        # (as the previous version did) serves no purpose.
        new_optimized_code_and_explanations.append(ce)
    return new_optimized_code_and_explanations
|
|
|
|
|
|
def degibberization_pipeline(
    original_source_code: str, optimized_code_and_explanations: list[CodeExplanationAndID]
) -> list[CodeExplanationAndID]:
    """Clean gibberish out of candidate explanations.

    Non-gibberish explanations pass through unchanged. Gibberish explanations
    are cleaned and (with a logged diff) re-attached; if cleanup yields nothing
    usable, the candidate is dropped entirely.
    """
    degibberized = []

    for code_explanation in optimized_code_and_explanations:
        gibberish_filter = GibberishFilter(code_explanation)
        if not gibberish_filter.is_gibberish():
            # Explanation is fine as-is.
            degibberized.append(code_explanation)
            continue

        cleaned_explanation = gibberish_filter.cleanup_explanation()
        if not cleaned_explanation:
            # Nothing salvageable remained after cleanup: drop the candidate.
            continue

        # Log a line-level diff of the cleanup for observability.
        diff = Differ().compare(code_explanation.explanation.splitlines(), cleaned_explanation.splitlines())
        logging.info("Explanation cleaned for optimization %s:\n%s", code_explanation.id, "\n".join(diff))
        logging.info("removed words: %s", "\n".join(gibberish_filter.removed_words))
        degibberized.append(
            CodeExplanationAndID(
                cst_module=code_explanation.cst_module, explanation=cleaned_explanation, id=code_explanation.id
            )
        )

    return degibberized
|
|
|
|
|
|
def optimizations_postprocessing_pipeline(
    original_source_code: str, optimized_code_and_explanations: list[CodeExplanationAndID]
) -> list[CodeExplanationAndID]:
    """Run every post-processing stage over the candidates, in order.

    Each stage takes (original source, candidate list) and returns a possibly
    smaller or rewritten candidate list, which is fed to the next stage.
    """
    stages = [
        remove_profanity_from_explanation,
        fix_missing_docstring,  # We want to deduplicate with the fixed docstrings included
        deduplicate_optimizations,
        equality_check,
        dedup_and_sort_imports,
        cleanup_explanations,
        # degibberization_pipeline, # with newer models, the context window is large enough to not need this, restore if needed
        filter_ellipsis_containing_code,
    ]

    candidates = optimized_code_and_explanations
    for stage in stages:
        candidates = stage(original_source_code, candidates)
    return candidates
|