Add language layer: CST utils, validator, postprocessing pipeline
Faithful port of Python language utilities from Django aiservice: - _cst_utils.py: depth tracking, import extraction, definition removal, ellipsis detection, expression evaluation, module path helpers - _validator.py: dual ast+libcst syntax validation, parse-or-none - _postprocess.py: full optimization postprocessing pipeline including dedup, equality check, docstring restoration, comment cleaning, forward reference fixing, ellipsis filtering, isort
This commit is contained in:
parent
5c6b82050a
commit
3e62f502e7
5 changed files with 2033 additions and 0 deletions
|
|
@ -0,0 +1,528 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import ast
|
||||
import operator
|
||||
from collections import deque
|
||||
from functools import lru_cache
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import libcst as cst
|
||||
from libcst.metadata import (
|
||||
MetadataWrapper,
|
||||
ParentNodeProvider,
|
||||
PositionProvider,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Iterator
|
||||
|
||||
|
||||
class DepthTrackingMixin:
    """Mixin that tracks how deeply a CST visitor is nested in defs/classes."""

    def __init__(self) -> None:
        # Counters are bumped on entry and decremented on exit.
        self._function_depth: int = 0
        self._class_depth: int = 0

    def _visit_function(self) -> None:
        """Record entry into a function definition."""
        self._function_depth += 1

    def _leave_function(self) -> None:
        """Record exit from a function definition."""
        self._function_depth -= 1

    def _visit_class(self) -> None:
        """Record entry into a class definition."""
        self._class_depth += 1

    def _leave_class(self) -> None:
        """Record exit from a class definition."""
        self._class_depth -= 1

    def _is_top_level(self) -> bool:
        """True when outside every function and class body."""
        return not (self._function_depth or self._class_depth)

    def _is_top_level_function(self) -> bool:
        """Alias of module-level check, kept for call-site readability."""
        return not (self._function_depth or self._class_depth)

    def _is_top_level_class(self) -> bool:
        """Alias of module-level check, kept for call-site readability."""
        return not (self._class_depth or self._function_depth)

    def _is_inside_class(self) -> bool:
        """True when at least one class body encloses the current node."""
        return self._class_depth > 0

    def _is_inside_function(self) -> bool:
        """True when at least one function body encloses the current node."""
        return self._function_depth > 0
|
||||
|
||||
|
||||
def extract_import_info(
    alias: cst.ImportAlias, module_name: str = ""
) -> tuple[str, str, str]:
    """
    Extract (available_name, module_path, original_name) from an import alias.
    """
    if isinstance(alias.name, cst.Name):
        original_name = alias.name.value
    else:
        original_name = get_dotted_name(alias.name)

    asname = alias.asname
    if asname and isinstance(asname.name, cst.Name):
        # Explicit alias wins: `import x as y` binds `y`.
        available_name = asname.name.value
    elif module_name:
        # `from pkg import name` binds the imported name itself.
        available_name = original_name
    else:
        # Plain `import a.b.c` only binds the first component, `a`.
        available_name = original_name.split(".", 1)[0]

    return available_name, module_name, original_name
|
||||
|
||||
|
||||
def extract_imports_from_import(
    node: cst.Import,
) -> dict[str, tuple[str, str]]:
    """
    Extract all imported names from an Import statement.
    """
    if isinstance(node.names, cst.ImportStar):
        # Star imports bind unknown names; report nothing.
        return {}

    mapping: dict[str, tuple[str, str]] = {}
    for alias in node.names:
        bound_name, _, dotted = extract_import_info(alias)
        if not bound_name:
            continue
        # Plain imports carry the dotted path in both slots.
        mapping[bound_name] = (dotted, dotted)
    return mapping
|
||||
|
||||
|
||||
def extract_imports_from_import_from(
    node: cst.ImportFrom,
) -> dict[str, tuple[str, str]]:
    """
    Extract all imported names from an ImportFrom statement.
    """
    if isinstance(node.names, cst.ImportStar):
        # Star imports bind unknown names; report nothing.
        return {}

    module_name = get_dotted_name(node.module) if node.module else ""
    mapping: dict[str, tuple[str, str]] = {}
    for alias in node.names:
        bound_name, _, imported = extract_import_info(alias, module_name)
        if not bound_name:
            continue
        mapping[bound_name] = (module_name, imported)
    return mapping
|
||||
|
||||
|
||||
def collect_imported_names_from_import(
    node: cst.Import,
) -> set[str]:
    """
    Collect available names from an Import statement.
    """
    if isinstance(node.names, cst.ImportStar):
        return set()
    # Reuse the full extraction; the dict keys are the bound names.
    return set(extract_imports_from_import(node))
|
||||
|
||||
|
||||
def collect_imported_names_from_import_from(
    node: cst.ImportFrom,
) -> set[str]:
    """
    Collect available names from an ImportFrom statement.
    """
    if isinstance(node.names, cst.ImportStar):
        return set()
    # Reuse the full extraction; the dict keys are the bound names.
    return set(extract_imports_from_import_from(node))
|
||||
|
||||
|
||||
class DefinitionRemover(DepthTrackingMixin, cst.CSTTransformer):
    """
    Remove top-level class and function definitions by name.

    ``names_to_remove`` entries may be plain names (``"foo"``) or dotted
    paths; only the last two components are kept, so ``"Class.method"``
    marks the entire ``Class`` for removal once that method is seen inside
    it.  Names in ``protected_names`` are never removed.  Every name that
    is actually removed is recorded in ``removed_names``.
    """

    def __init__(
        self,
        names_to_remove: set[str] | list[str],
        protected_names: set[str] | None = None,
    ) -> None:
        DepthTrackingMixin.__init__(self)
        cst.CSTTransformer.__init__(self)

        names_list = list(names_to_remove)
        # Keep only the last two dotted components: [class, method] or [name].
        self._qualified_names: list[list[str]] = [
            s.split(".")[-2:] for s in names_list
        ]
        # Final component of each target, used for direct name matches.
        self._simple_names: set[str] = {
            pair[-1] for pair in self._qualified_names if pair
        }

        self.protected_names: set[str] = protected_names or set()
        self.removed_names: set[str] = set()

        # True while the class currently being visited is slated for removal.
        self._is_killable_class: bool = False
        # Method names that, if found in the current class, mark it killable.
        self._potential_methods: set[str] = set()

    def _should_remove(self, name: str) -> bool:
        """
        Check if *name* should be removed.
        """
        return name in self._simple_names and name not in self.protected_names

    def visit_ClassDef(self, node: cst.ClassDef) -> bool:
        # Capture top-level status *before* bumping the depth counter.
        is_top_level = self._is_top_level_class()
        self._visit_class()

        if not is_top_level:
            return True

        class_name = node.name.value

        if self._should_remove(class_name):
            self._is_killable_class = True
            # No need to traverse a class that is dropped wholesale.
            return False

        # Methods referenced as "ClassName.method" in the removal list.
        potential_methods = {
            pair[-1]
            for pair in self._qualified_names
            if len(pair) > 1
            and pair[-2] == class_name
            and pair[-1] not in self.protected_names
        }
        if potential_methods:
            self._potential_methods = potential_methods

        return True

    def leave_ClassDef(
        self,
        original_node: cst.ClassDef,
        updated_node: cst.ClassDef,
    ) -> cst.ClassDef | cst.RemovalSentinel:
        self._leave_class()

        # Only drop the class when back at module level and it was marked
        # killable (either directly or via one of its methods).
        if self._is_top_level() and self._is_killable_class:
            self._is_killable_class = False
            self._potential_methods = set()
            self.removed_names.add(original_node.name.value)
            return cst.RemovalSentinel.REMOVE

        self._potential_methods = set()
        return updated_node

    def visit_FunctionDef(self, node: cst.FunctionDef) -> bool:
        self._visit_function()
        return True

    def leave_FunctionDef(
        self,
        original_node: cst.FunctionDef,
        updated_node: cst.FunctionDef,
    ) -> cst.FunctionDef | cst.RemovalSentinel:
        self._leave_function()

        func_name = original_node.name.value

        # Top-level function listed for removal: drop it outright.
        if self._is_top_level() and self._should_remove(func_name):
            self.removed_names.add(func_name)
            return cst.RemovalSentinel.REMOVE

        # A targeted method was found: mark the enclosing class for removal.
        if func_name in self._potential_methods:
            self._is_killable_class = True
            self._potential_methods = set()

        return updated_node
|
||||
|
||||
|
||||
class ImportTrackingVisitor(ast.NodeVisitor):
    """
    Base AST visitor that tracks imported names.
    """

    def __init__(self) -> None:
        # Names bound into the module namespace by import statements.
        self.imported_names: set[str] = set()

    def visit_Import(self, node: ast.Import) -> None:
        # `import a.b` binds only `a`; `import a.b as c` binds `c`.
        self.imported_names.update(
            alias.asname or alias.name.partition(".")[0]
            for alias in node.names
        )
        self.generic_visit(node)

    def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
        # Star imports bind unknown names, so they are skipped entirely.
        self.imported_names.update(
            alias.asname or alias.name
            for alias in node.names
            if alias.name != "*"
        )
        self.generic_visit(node)
|
||||
|
||||
|
||||
@lru_cache(maxsize=128)
def parse_module_to_cst(
    module_str: str,
) -> cst.Module:
    """
    Parse a module string into its libCST representation.

    Cached because the pipeline re-parses identical sources repeatedly.
    """
    module = cst.parse_module(module_str)
    return module
|
||||
|
||||
|
||||
def file_path_to_module_path(file_path: str) -> str:
    """
    Convert a file path to a dotted module path.
    """
    stem = file_path.removesuffix(".py")
    # Normalize both POSIX and Windows separators to dots.
    return stem.replace("\\", ".").replace("/", ".")
|
||||
|
||||
|
||||
def get_dotted_name(
    node: cst.BaseExpression | None,
) -> str:
    """
    Extract dotted module path from a CST node.
    """
    if isinstance(node, cst.Name):
        return node.value
    if isinstance(node, cst.Attribute):
        # Recurse into the base, then append this attribute segment.
        prefix = get_dotted_name(node.value)
        suffix = node.attr.value
        return f"{prefix}.{suffix}" if prefix else suffix
    # None and unsupported expression kinds yield an empty path.
    return ""
|
||||
|
||||
|
||||
def get_base_class_name(
    base: cst.Arg,
) -> str | None:
    """
    Get the name of a base class from a CST Arg node.
    """
    expr = base.value
    if isinstance(expr, cst.Attribute):
        # Dotted base like `module.Base` -> "Base".
        return expr.attr.value
    if isinstance(expr, cst.Name):
        return expr.value
    # Calls, subscripts and other expressions have no simple name.
    return None
|
||||
|
||||
|
||||
def has_decorator(class_node: cst.ClassDef, decorator_name: str) -> bool:
    """
    Check if a class has a specific decorator.
    """

    def _matches(dec: cst.BaseExpression) -> bool:
        # Bare decorator: @name
        if isinstance(dec, cst.Name):
            return dec.value == decorator_name
        # Called decorator: @name(...)
        if isinstance(dec, cst.Call) and isinstance(dec.func, cst.Name):
            return dec.func.value == decorator_name
        return False

    return any(_matches(d.decorator) for d in class_node.decorators)
|
||||
|
||||
|
||||
def get_node_source_text(
    node: cst.CSTNode,
    source_lines: list[str],
    wrapper: MetadataWrapper,
    fallback: str = "...",
) -> str:
    """
    Get the source text of a CST node using position metadata.

    Returns *fallback* when the wrapper has no recorded position for *node*.
    """
    pos = wrapper.resolve(PositionProvider).get(node)
    if pos is None:
        return fallback
    # NOTE(review): only the node's first line is sliced; for a node spanning
    # multiple lines, end.column refers to a different line, so the result
    # would be truncated/garbled — confirm callers only pass single-line nodes.
    return source_lines[pos.start.line - 1][pos.start.column : pos.end.column]
|
||||
|
||||
|
||||
def build_module_path(
    module_path: str,
) -> cst.Attribute | cst.Name:
    """
    Build a libcst node from a dotted module path string.
    """
    first, *rest = module_path.split(".")
    node: cst.Attribute | cst.Name = cst.Name(first)
    # Fold remaining segments into nested Attribute accesses.
    for attr in rest:
        node = cst.Attribute(value=node, attr=cst.Name(attr))
    return node
|
||||
|
||||
|
||||
# Mapping of supported CST binary-operator node types to the Python-level
# functions used by evaluate_expression() for constant folding.  Operators
# absent from this table make the expression non-evaluable.
_BINARY_OPS = {
    cst.Power: operator.pow,
    cst.Multiply: operator.mul,
    cst.Add: operator.add,
    cst.Subtract: operator.sub,
}
|
||||
|
||||
|
||||
def make_number_node(
    value: int,
) -> cst.Integer | cst.UnaryOperation:
    """
    Create a CST node representing an integer value.
    """
    if value < 0:
        # Negative literals are spelled as unary minus on the magnitude.
        return cst.UnaryOperation(
            operator=cst.Minus(),
            expression=cst.Integer(str(-value)),
        )
    return cst.Integer(str(value))
|
||||
|
||||
|
||||
def evaluate_expression(
    node: cst.BaseExpression,
) -> int | None:
    """
    Evaluate a CST expression node to an integer, or None.
    """
    if isinstance(node, cst.Integer):
        # Base 0 lets int() handle hex/octal/binary literals as well.
        return int(node.value, 0)
    if isinstance(node, cst.Float):
        return int(float(node.value))
    if isinstance(node, cst.BinaryOperation):
        fold = _BINARY_OPS.get(type(node.operator))
        if fold is not None:
            lhs = evaluate_expression(node.left)
            rhs = evaluate_expression(node.right)
            if lhs is not None and rhs is not None:
                return fold(lhs, rhs)
        return None
    if isinstance(node, cst.UnaryOperation) and isinstance(
        node.operator, cst.Minus
    ):
        operand = evaluate_expression(node.expression)
        return -operand if operand is not None else None
    # Names, calls, strings etc. are not statically evaluable here.
    return None
|
||||
|
||||
|
||||
class AnyEllipsisVisitor(cst.CSTVisitor):
    """
    Detect if any ellipsis exists in the code.
    """

    def __init__(self) -> None:
        # Flipped to True on the first Ellipsis node encountered.
        self.found = False

    def visit_Ellipsis(self, node: cst.Ellipsis) -> bool:
        self.found = True
        # Ellipsis nodes have no children worth descending into.
        return False
|
||||
|
||||
|
||||
class InvalidEllipsisVisitor(cst.CSTVisitor):
    """
    Detect ellipsis in invalid locations.

    An ellipsis is considered valid only when it sits inside a subscript or
    tuple (treated here as a type-annotation context) or is the sole body
    statement of a function.  Any other ellipsis sets ``invalid_found``.
    Requires traversal via a MetadataWrapper so ParentNodeProvider resolves.
    """

    METADATA_DEPENDENCIES = (ParentNodeProvider,)

    def __init__(self) -> None:
        # Set to True as soon as one out-of-place ellipsis is seen.
        self.invalid_found = False

    def visit_Ellipsis(self, node: cst.Ellipsis) -> bool:
        if self._is_in_type_annotation_context(node):
            return False
        if self._is_function_body_ellipsis(node):
            return False
        self.invalid_found = True
        return False

    def _iter_ancestors(
        self,
        node: cst.CSTNode,
        max_depth: int = 10,
    ) -> Iterator[cst.CSTNode]:
        """
        Yield ancestors walking up the tree.

        Stops at the root or after *max_depth* steps, whichever comes first.
        """
        current = node
        for _ in range(max_depth):
            parent = self.get_metadata(ParentNodeProvider, current, None)
            if parent is None:
                break
            yield parent
            current = parent

    def _is_in_type_annotation_context(self, node: cst.CSTNode) -> bool:
        """
        Check if *node* is within a type annotation.

        NOTE(review): any enclosing Subscript or Tuple counts, which also
        matches runtime subscripts — confirm this over-approximation is
        intended.
        """
        return any(
            isinstance(ancestor, (cst.Subscript, cst.Tuple))
            for ancestor in self._iter_ancestors(node)
        )

    def _is_function_body_ellipsis(self, node: cst.Ellipsis) -> bool:
        """
        Check if *node* is a function body ellipsis.
        """
        # The exact ancestor chain for `def f(): ...` is
        # Expr -> SimpleStatementLine -> IndentedBlock -> FunctionDef.
        ancestors = list(self._iter_ancestors(node, max_depth=4))
        if len(ancestors) < 4:
            return False
        expected_types = (
            cst.Expr,
            cst.SimpleStatementLine,
            cst.IndentedBlock,
            cst.FunctionDef,
        )
        # strict=True is safe: both sequences have exactly four elements.
        return all(
            isinstance(ancestor, expected)
            for ancestor, expected in zip(
                ancestors, expected_types, strict=True
            )
        )
|
||||
|
||||
|
||||
def ellipsis_in_cst_not_types(
    module: cst.Module,
) -> bool:
    """
    Check if *module* contains ellipsis in invalid locations.
    """
    checker = InvalidEllipsisVisitor()
    # A metadata wrapper is required so ParentNodeProvider can resolve.
    MetadataWrapper(module).visit(checker)
    return checker.invalid_found
|
||||
|
||||
|
||||
def any_ellipsis_in_cst(module: cst.Module) -> bool:
    """
    Check if *module* contains any ellipsis.
    """
    checker = AnyEllipsisVisitor()
    module.visit(checker)
    return checker.found
|
||||
|
||||
|
||||
def unparse_parse_source(source: str) -> str:
    """
    Parse source and unparse it back, normalizing formatting.
    """
    tree = ast.parse(source)
    return ast.unparse(tree)
|
||||
|
||||
|
||||
def compare_unparsed_ast_to_source(unparsed_ast: str, source: str) -> bool:
    """
    Compare an unparsed AST string to a source string after normalizing.
    """
    # Normalize inline: parse then unparse strips formatting noise.
    return unparsed_ast == ast.unparse(ast.parse(source))
|
||||
|
||||
|
||||
def find_init(node: ast.AST) -> bool:
    """
    Search an AST for a FunctionDef node named '__init__'.
    """
    # Traversal order does not matter for a pure existence check.
    return any(
        isinstance(child, ast.FunctionDef) and child.name == "__init__"
        for child in ast.walk(node)
    )
|
||||
|
|
@ -0,0 +1,906 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import ast
|
||||
import builtins
|
||||
import io
|
||||
import logging
|
||||
import re
|
||||
import tokenize
|
||||
from difflib import SequenceMatcher
|
||||
from typing import TYPE_CHECKING, cast
|
||||
|
||||
import attrs
|
||||
import isort
|
||||
import libcst as cst
|
||||
from libcst import (
|
||||
BaseStatement,
|
||||
CSTTransformer,
|
||||
CSTVisitor,
|
||||
Expr,
|
||||
IndentedBlock,
|
||||
SimpleStatementLine,
|
||||
SimpleString,
|
||||
)
|
||||
from libcst.codemod import CodemodContext
|
||||
from libcst.codemod.visitors import AddImportsVisitor
|
||||
|
||||
from codeflash_api.languages.python._cst_utils import (
|
||||
collect_imported_names_from_import,
|
||||
collect_imported_names_from_import_from,
|
||||
compare_unparsed_ast_to_source,
|
||||
get_dotted_name,
|
||||
parse_module_to_cst,
|
||||
unparse_parse_source,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from libcst import FunctionDef
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
BUILTIN_NAMES = frozenset(dir(builtins))
|
||||
|
||||
|
||||
@attrs.frozen
class OptimizationCandidate:
    """
    An optimized code variant with its CST module and explanation.
    """

    # Parsed CST of the candidate's full module source.
    cst_module: cst.Module
    # Natural-language explanation accompanying the optimization.
    explanation: str
    # Stable identifier used for logging and tracking.
    id: str
|
||||
|
||||
|
||||
def safe_isort(code: str, **kwargs: object) -> str:
    """
    Run isort on *code*, returning original on failure.
    """
    try:
        sorted_code = isort.code(code, **kwargs)
    except Exception:  # best-effort formatting: never let isort crash us
        return code
    return sorted_code
|
||||
|
||||
|
||||
def deduplicate_optimizations(
    _original_module: cst.Module,
    candidates: list[OptimizationCandidate],
) -> list[OptimizationCandidate]:
    """
    Remove candidates with equivalent ASTs.

    Candidates are keyed by their AST-normalized source so purely cosmetic
    differences (whitespace, formatting) do not count as distinct
    optimizations.  A candidate whose code cannot be parsed is keyed by its
    raw source instead of crashing the whole pipeline, matching the
    defensive behavior of the other pipeline stages.
    """
    seen_keys: set[str] = set()
    unique: list[OptimizationCandidate] = []
    for candidate in candidates:
        code = candidate.cst_module.code
        try:
            key = ast.unparse(ast.parse(code))
        except (SyntaxError, ValueError):
            # Unparseable candidate: keep it, deduplicated on raw text.
            key = code
        if key not in seen_keys:
            seen_keys.add(key)
            unique.append(candidate)
    return unique
|
||||
|
||||
|
||||
def equality_check(
    original_module: cst.Module,
    candidates: list[OptimizationCandidate],
    *,
    original_code: str | None = None,
) -> list[OptimizationCandidate]:
    """
    Filter out candidates identical to the original.
    """
    source = original_module.code if original_code is None else original_code
    try:
        original_ast = unparse_parse_source(source)
    except Exception:
        # Original does not parse: fall back to plain text comparison.
        return [c for c in candidates if c.cst_module.code != source]

    kept: list[OptimizationCandidate] = []
    for candidate in candidates:
        try:
            is_same = compare_unparsed_ast_to_source(
                original_ast, candidate.cst_module.code
            )
        except Exception:
            # Candidate does not parse: compare raw source instead.
            is_same = candidate.cst_module.code == source
        if not is_same:
            kept.append(candidate)
    return kept
|
||||
|
||||
|
||||
# (pattern, replacement) pairs applied in order by cleanup_explanations()
# to scrub common LLM boilerplate out of explanation text.
_EXPLANATION_PATTERNS = [
    (
        # "Here is the optimized version of the code:"-style lead-ins.
        re.compile(
            r"\sHere (is|are) (the )?((code|optimization)"
            r"|\S+\scode|(optimized|improved) versions?"
            r" of (the code|these functions))(:|.)",
            re.IGNORECASE,
        ),
        "",
    ),
    (
        # Fenced code blocks embedded in the explanation.
        re.compile(r"^```(.*?)```", re.MULTILINE | re.DOTALL),
        "",
    ),
    # Colon phrasing that reads oddly once the code blocks are gone.
    (re.compile(r", as follows:"), "."),
    (re.compile(r":\n"), ".\n"),
]
|
||||
|
||||
|
||||
def cleanup_explanations(
    _original_module: cst.Module,
    candidates: list[OptimizationCandidate],
) -> list[OptimizationCandidate]:
    """
    Clean up explanation text from LLM artifacts.
    """
    cleaned_candidates: list[OptimizationCandidate] = []
    for candidate in candidates:
        text = candidate.explanation
        # Apply each scrub pattern in order; later patterns see earlier edits.
        for pattern, replacement in _EXPLANATION_PATTERNS:
            text = pattern.sub(replacement, text)
        cleaned_candidates.append(
            OptimizationCandidate(
                cst_module=candidate.cst_module,
                explanation=text,
                id=candidate.id,
            )
        )
    return cleaned_candidates
|
||||
|
||||
|
||||
class DocstringVisitor(CSTVisitor):
    """
    Collect original docstrings from functions and classes.

    Keys are "ClassName" for classes and "ClassName.method" or "function"
    for functions.
    """

    def __init__(self) -> None:
        self.original_docstrings: dict[str, str] = {}
        # Name of the class currently being traversed, if any.
        self.class_name: str | None = None

    def visit_ClassDef(self, node: cst.ClassDef) -> bool:
        self.class_name = node.name.value
        doc = node.get_docstring(clean=False)
        if doc:
            self.original_docstrings[self.class_name] = doc
        return True

    def leave_ClassDef(
        self,
        original_node: cst.ClassDef,
    ) -> None:
        self.class_name = None

    def visit_FunctionDef(self, node: FunctionDef) -> bool:
        name = node.name.value
        if self.class_name:
            # Qualify methods with their enclosing class name.
            name = f"{self.class_name}.{name}"
        doc = node.get_docstring(clean=False)
        if doc:
            self.original_docstrings[name] = doc
        return True
|
||||
|
||||
|
||||
class DocstringTransformer(CSTTransformer):
    """
    Restore original docstrings in optimized code.

    Walks the optimized module and, for every class or function whose
    (qualified) name has an entry in *original_docstrings*, replaces — or
    inserts, if missing — the docstring with the original text.
    """

    def __init__(self, original_docstrings: dict[str, str]) -> None:
        # Map of "Class" / "function" / "Class.method" -> docstring text.
        self.original_docstrings = original_docstrings
        # Name of the class currently being visited, for qualification.
        self.class_name: str | None = None

    @staticmethod
    def _with_docstring(
        node: cst.ClassDef | FunctionDef,
        docstring: str,
    ) -> cst.ClassDef | FunctionDef:
        """Return *node* with *docstring* installed as its first statement.

        If the node already has a docstring, that first statement is
        replaced; otherwise the docstring is prepended.
        """
        doc_stmt = SimpleStatementLine(
            body=[Expr(value=SimpleString(f'"""{docstring}"""'))]
        )
        # Skip the existing docstring statement, if any, before prepending.
        start = 1 if node.get_docstring(clean=False) else 0
        rest = cast("list[BaseStatement]", list(node.body.body[start:]))
        return node.with_changes(body=IndentedBlock(body=[doc_stmt, *rest]))

    def visit_ClassDef(self, node: cst.ClassDef) -> bool:
        self.class_name = node.name.value
        return True

    def leave_ClassDef(
        self,
        original_node: cst.ClassDef,
        updated_node: cst.ClassDef,
    ) -> cst.ClassDef:
        original_docstring = (
            self.original_docstrings.get(self.class_name)
            if self.class_name
            else None
        )
        if original_docstring:
            updated_node = self._with_docstring(
                updated_node, original_docstring
            )
        self.class_name = None
        return updated_node

    def leave_FunctionDef(
        self,
        original_node: FunctionDef,
        updated_node: FunctionDef,
    ) -> FunctionDef:
        function_name = updated_node.name.value
        # Methods are keyed by "Class.method", free functions by name alone.
        qualified_name = (
            f"{self.class_name}.{function_name}"
            if self.class_name
            else function_name
        )
        original_docstring = self.original_docstrings.get(qualified_name)
        if original_docstring:
            updated_node = self._with_docstring(
                updated_node, original_docstring
            )
        return updated_node
|
||||
|
||||
|
||||
def fix_missing_docstring(
    original_module: cst.Module,
    candidates: list[OptimizationCandidate],
) -> list[OptimizationCandidate]:
    """
    Restore docstrings that the LLM removed.
    """
    collector = DocstringVisitor()
    try:
        original_module.visit(collector)
    except Exception:
        # If the original cannot be walked, leave candidates untouched.
        return candidates

    transformer = DocstringTransformer(collector.original_docstrings)
    restored: list[OptimizationCandidate] = []
    for candidate in candidates:
        try:
            restored.append(
                OptimizationCandidate(
                    cst_module=candidate.cst_module.visit(transformer),
                    explanation=candidate.explanation,
                    id=candidate.id,
                )
            )
        except Exception:
            # Keep the candidate as-is rather than losing it.
            log.warning("Failed to restore docstring for %s", candidate.id)
            restored.append(candidate)
    return restored
|
||||
|
||||
|
||||
def dedup_and_sort_imports(
    _original_module: cst.Module,
    candidates: list[OptimizationCandidate],
) -> list[OptimizationCandidate]:
    """
    Run isort on all candidates.
    """
    processed: list[OptimizationCandidate] = []
    for candidate in candidates:
        try:
            code_before = candidate.cst_module.code
            code_after = safe_isort(code_before, disregard_skip=True)
        except Exception:
            processed.append(candidate)
            continue
        if code_after == code_before:
            # Nothing changed: avoid a needless re-parse.
            processed.append(candidate)
            continue
        processed.append(
            OptimizationCandidate(
                cst_module=parse_module_to_cst(code_after),
                explanation=candidate.explanation,
                id=candidate.id,
            )
        )
    return processed
|
||||
|
||||
|
||||
class EllipsisContainingCodeVisitor(CSTVisitor):
    """
    Detect any ellipsis in code.
    """

    def __init__(self) -> None:
        # Flipped to True on the first Ellipsis node encountered.
        self.ellipsis_containing_code = False

    def visit_Ellipsis(
        self,
        node: cst.Ellipsis,
    ) -> bool:
        self.ellipsis_containing_code = True
        # Ellipsis nodes have no children worth descending into.
        return False
|
||||
|
||||
|
||||
def filter_ellipsis_containing_code(
    original_module: cst.Module,
    candidates: list[OptimizationCandidate],
) -> list[OptimizationCandidate]:
    """
    Filter out candidates that introduce ellipsis not present in original.
    """
    baseline = EllipsisContainingCodeVisitor()
    original_module.visit(baseline)
    if baseline.ellipsis_containing_code:
        # The original already contains `...`; nothing to filter on.
        return candidates

    kept: list[OptimizationCandidate] = []
    for candidate in candidates:
        probe = EllipsisContainingCodeVisitor()
        candidate.cst_module.visit(probe)
        if not probe.ellipsis_containing_code:
            kept.append(candidate)
    return kept
|
||||
|
||||
|
||||
def _strip_comments_from_code(code: str) -> str:
|
||||
"""
|
||||
Remove all comments from Python code, preserving strings.
|
||||
"""
|
||||
try:
|
||||
lines = code.splitlines(keepends=True)
|
||||
tokens = tokenize.generate_tokens(io.StringIO(code).readline)
|
||||
|
||||
comments_by_line: list[list[tuple[int, int]]] = [
|
||||
[] for _ in range(len(lines))
|
||||
]
|
||||
for token in tokens:
|
||||
if token.type == tokenize.COMMENT:
|
||||
line_idx = token.start[0] - 1
|
||||
if 0 <= line_idx < len(lines):
|
||||
comments_by_line[line_idx].append(
|
||||
(token.start[1], token.end[1])
|
||||
)
|
||||
|
||||
result_lines: list[str] = []
|
||||
for line_num, line in enumerate(lines, start=1):
|
||||
current = line
|
||||
line_comments = comments_by_line[line_num - 1]
|
||||
|
||||
if line_comments:
|
||||
line_comments.sort(reverse=True)
|
||||
for start_col, end_col in line_comments:
|
||||
current = current[:start_col].rstrip() + current[end_col:]
|
||||
|
||||
if not current.endswith("\n") and lines[line_num - 1].endswith(
|
||||
"\n"
|
||||
):
|
||||
current += "\n"
|
||||
|
||||
result_lines.append(current)
|
||||
|
||||
return "".join(result_lines)
|
||||
except (tokenize.TokenError, IndentationError):
|
||||
return code
|
||||
|
||||
|
||||
def clean_extraneous_comments(
    original_module: cst.Module,
    optimized_module: cst.Module,
    *,
    orig_code: str | None = None,
    orig_code_stripped: str | None = None,
) -> cst.Module:
    """
    Remove extraneous comments from optimized code using diff-based approach.

    Strategy: diff the comment-stripped versions of the original and
    optimized sources line-by-line.  Comments near genuinely changed code
    are kept; comment-only lines far from any change are dropped; for
    unchanged code lines, the original line (with its original comments)
    is restored.  Falls back to the unmodified optimized module on any
    failure.  *orig_code* / *orig_code_stripped* may be passed in so the
    pipeline can compute them once for many candidates.
    """
    try:
        if orig_code is None:
            orig_code = original_module.code
        opt_code_full = optimized_module.code

        orig_lines = orig_code.splitlines(keepends=True)
        opt_lines = opt_code_full.splitlines(keepends=True)

        if orig_code_stripped is None:
            orig_code_stripped = _strip_comments_from_code(orig_code)
        opt_code_stripped = _strip_comments_from_code(opt_code_full)

        # Comment-stripped lines, index-aligned with the full sources.
        orig_code_only = orig_code_stripped.splitlines(keepends=True)
        opt_code_only = opt_code_stripped.splitlines(keepends=True)

        # Drop blank (comment-only/empty) lines before diffing, remembering
        # each surviving line's index in the full file.
        orig_code_line_indices: list[int] = []
        orig_code_lines_filtered: list[str] = []
        for i, line in enumerate(orig_code_only):
            if line.strip():
                orig_code_line_indices.append(i)
                orig_code_lines_filtered.append(line)

        opt_code_line_indices: list[int] = []
        opt_code_lines_filtered: list[str] = []
        for i, line in enumerate(opt_code_only):
            if line.strip():
                opt_code_line_indices.append(i)
                opt_code_lines_filtered.append(line)

        code_matcher = SequenceMatcher(
            None,
            orig_code_lines_filtered,
            opt_code_lines_filtered,
        )
        # Indices (into the filtered optimized lines) whose code changed.
        code_changed_line_indices: set[int] = set()

        for tag, _i1, _i2, j1, j2 in code_matcher.get_opcodes():
            if tag != "equal":
                for j in range(j1, j2):
                    code_changed_line_indices.add(j)

        # Translate changed filtered-indices back to full-file line numbers.
        code_changed_lines: set[int] = set()
        for filtered_idx in code_changed_line_indices:
            if filtered_idx < len(opt_code_line_indices):
                code_changed_lines.add(opt_code_line_indices[filtered_idx])

        # Map of original full-file line index -> matching optimized index,
        # built from the "equal" diff regions.
        orig_to_opt_mapping: dict[int, int] = {}

        for tag, i1, i2, j1, j2 in code_matcher.get_opcodes():
            if tag == "equal":
                for offset in range(min(i2 - i1, j2 - j1)):
                    orig_idx = orig_code_line_indices[i1 + offset]
                    opt_idx = opt_code_line_indices[j1 + offset]
                    orig_to_opt_mapping[orig_idx] = opt_idx

        result_lines: list[str] = []
        orig_idx = 0
        # Original line indices already emitted, to avoid duplicates.
        restored_orig_indices: set[int] = set()

        for opt_idx, opt_line in enumerate(opt_lines):
            if opt_idx in code_changed_lines:
                # Changed code: always keep the optimized line verbatim.
                result_lines.append(opt_line)
            else:
                opt_code = (
                    opt_code_only[opt_idx]
                    if opt_idx < len(opt_code_only)
                    else ""
                )

                # Line holds no code once comments are stripped.
                is_comment_only = not opt_code.strip()
                is_near_change = False
                if is_comment_only:
                    # Scan downward to the next code line for a change.
                    for check_idx in range(opt_idx + 1, len(opt_lines)):
                        if check_idx in code_changed_lines:
                            is_near_change = True
                            break
                        check_code = (
                            opt_code_only[check_idx]
                            if check_idx < len(opt_code_only)
                            else ""
                        )
                        if check_code.strip():
                            break

                    # Then upward, if nothing was found below.
                    if not is_near_change:
                        for check_idx in range(opt_idx - 1, -1, -1):
                            if check_idx in code_changed_lines:
                                is_near_change = True
                                break
                            check_code = (
                                opt_code_only[check_idx]
                                if check_idx < len(opt_code_only)
                                else ""
                            )
                            if check_code.strip():
                                break

                if is_comment_only and is_near_change:
                    # Comment annotating changed code: keep it.
                    result_lines.append(opt_line)
                elif is_comment_only and not is_near_change:
                    # Extraneous comment far from any change: drop it.
                    pass
                elif opt_idx in code_changed_lines:
                    result_lines.append(opt_line)
                else:
                    # Unchanged code: prefer the original line (with its
                    # original comments) over the optimized rendering.
                    found_orig = None
                    orig_line_idx = None
                    for orig_idx_search in range(orig_idx, len(orig_lines)):
                        orig_code_line = (
                            orig_code_only[orig_idx_search]
                            if orig_idx_search < len(orig_code_only)
                            else ""
                        )
                        if orig_code_line == opt_code:
                            found_orig = orig_lines[orig_idx_search]
                            orig_line_idx = orig_idx_search
                            orig_idx = orig_idx_search + 1
                            break

                    if found_orig:
                        # Restore any original comment-only lines directly
                        # above the matched line that were not mapped or
                        # emitted yet.
                        if (
                            orig_line_idx is not None
                            and orig_line_idx > 0
                            and opt_idx not in code_changed_lines
                        ):
                            preceding_lines: list[str] = []
                            check_idx = orig_line_idx - 1

                            while check_idx >= 0:
                                check_code = (
                                    orig_code_only[check_idx]
                                    if check_idx < len(orig_code_only)
                                    else ""
                                )
                                if not check_code.strip():
                                    if (
                                        check_idx not in orig_to_opt_mapping
                                        and check_idx
                                        not in restored_orig_indices
                                    ):
                                        preceding_lines.insert(
                                            0,
                                            orig_lines[check_idx],
                                        )
                                        restored_orig_indices.add(check_idx)
                                    check_idx -= 1
                                else:
                                    break

                            result_lines.extend(preceding_lines)

                        result_lines.append(found_orig)
                        if orig_line_idx is not None:
                            restored_orig_indices.add(orig_line_idx)

                        # Restore trailing original comment-only lines, but
                        # only when the next code line was not changed.
                        if (
                            orig_line_idx is not None
                            and orig_line_idx < len(orig_lines) - 1
                            and opt_idx not in code_changed_lines
                        ):
                            trailing_lines: list[str] = []
                            check_idx = orig_line_idx + 1
                            found_changed_line = False

                            while check_idx < len(orig_lines):
                                check_code = (
                                    orig_code_only[check_idx]
                                    if check_idx < len(orig_code_only)
                                    else ""
                                )
                                if not check_code.strip():
                                    if (
                                        check_idx not in orig_to_opt_mapping
                                        and check_idx
                                        not in restored_orig_indices
                                    ):
                                        trailing_lines.append(
                                            orig_lines[check_idx]
                                        )
                                        restored_orig_indices.add(check_idx)
                                    check_idx += 1
                                else:
                                    if check_idx in orig_to_opt_mapping:
                                        next_opt_idx = orig_to_opt_mapping[
                                            check_idx
                                        ]
                                        if next_opt_idx in code_changed_lines:
                                            found_changed_line = True
                                    else:
                                        found_changed_line = True
                                    break

                            if not found_changed_line:
                                result_lines.extend(trailing_lines)
                    else:
                        # No original counterpart: keep the optimized line.
                        result_lines.append(opt_line)

        cleaned_code = "".join(result_lines)
        return cst.parse_module(cleaned_code)

    except Exception:
        # Best-effort cleanup: never let comment handling break the pipeline.
        log.warning("Error cleaning comments")
        return optimized_module
|
||||
|
||||
|
||||
def clean_extraneous_comments_pipeline(
    original_module: cst.Module,
    candidates: list[OptimizationCandidate],
    *,
    orig_code: str | None = None,
    orig_code_stripped: str | None = None,
) -> list[OptimizationCandidate]:
    """Apply comment cleaning to every candidate, keeping failures intact.

    The original source (and its comment-stripped form) are computed at most
    once and shared across candidates. A per-candidate failure logs a warning
    and keeps the unmodified candidate; a failure in the setup itself returns
    *candidates* unchanged, so this stage never drops an optimization.
    """
    try:
        code = orig_code if orig_code is not None else original_module.code
        stripped = (
            orig_code_stripped
            if orig_code_stripped is not None
            else _strip_comments_from_code(code)
        )

        cleaned_candidates: list[OptimizationCandidate] = []
        for candidate in candidates:
            try:
                cleaned_module = clean_extraneous_comments(
                    original_module,
                    candidate.cst_module,
                    orig_code=code,
                    orig_code_stripped=stripped,
                )
                cleaned_candidates.append(
                    OptimizationCandidate(
                        cst_module=cleaned_module,
                        explanation=candidate.explanation,
                        id=candidate.id,
                    )
                )
            except Exception:
                # Best-effort stage: fall back to the unmodified candidate.
                log.warning(
                    "Error cleaning comments for %s",
                    candidate.id,
                )
                cleaned_candidates.append(candidate)
        return cleaned_candidates
    except Exception:
        log.warning("Error in comment cleaning pipeline")
        return candidates
|
||||
|
||||
|
||||
def extract_names_from_cst_annotation(
    annotation: cst.BaseExpression,
    names: set[str],
) -> None:
    """Recursively collect every bare Name reachable inside *annotation*.

    Handles dotted names (only the base is descended into), subscripted
    generics, PEP 604 unions (BinaryOperation), and tuple/list annotation
    forms; any other node kind is ignored.
    """
    recurse = extract_names_from_cst_annotation
    if isinstance(annotation, cst.Name):
        names.add(annotation.value)
        return
    if isinstance(annotation, cst.Attribute):
        recurse(annotation.value, names)
        return
    if isinstance(annotation, cst.Subscript):
        recurse(annotation.value, names)
        for sub in annotation.slice:
            if isinstance(sub, cst.SubscriptElement) and isinstance(
                sub.slice, cst.Index
            ):
                recurse(sub.slice.value, names)
        return
    if isinstance(annotation, cst.BinaryOperation):
        # e.g. `int | None` — both operands may carry names.
        recurse(annotation.left, names)
        recurse(annotation.right, names)
        return
    if isinstance(annotation, (cst.Tuple, cst.List)):
        for el in annotation.elements:
            if isinstance(el, (cst.Element, cst.StarredElement)):
                recurse(el.value, names)
|
||||
|
||||
|
||||
class CSTAnnotationNameCollector(cst.CSTVisitor):
    """Gather annotation names alongside the module's defined/imported names.

    After visiting a module, get_undefined_annotation_names() reports which
    annotation names resolve to nothing visible in the module.
    """

    def __init__(self) -> None:
        self.annotation_names: set[str] = set()
        self.defined_names: set[str] = set()
        self.imported_names: set[str] = set()

    def visit_Import(self, node: cst.Import) -> bool | None:
        # Imports cannot contain annotations; no need to descend.
        self.imported_names.update(collect_imported_names_from_import(node))
        return False

    def visit_ImportFrom(self, node: cst.ImportFrom) -> bool | None:
        self.imported_names.update(
            collect_imported_names_from_import_from(node)
        )
        return False

    def visit_FunctionDef(self, node: cst.FunctionDef) -> bool | None:
        self.defined_names.add(node.name.value)
        self._collect_from_function(node)
        return True

    def visit_ClassDef(self, node: cst.ClassDef) -> bool | None:
        self.defined_names.add(node.name.value)
        return True

    def visit_AnnAssign(self, node: cst.AnnAssign) -> bool | None:
        extract_names_from_cst_annotation(
            node.annotation.annotation,
            self.annotation_names,
        )
        return True

    def _collect_from_function(self, node: cst.FunctionDef) -> None:
        """Record annotation names from every parameter kind and the return."""

        def grab(ann: cst.Annotation | None) -> None:
            # Shared sink for optional annotations on params and return.
            if ann is not None:
                extract_names_from_cst_annotation(
                    ann.annotation,
                    self.annotation_names,
                )

        params = node.params
        for param in (
            *params.params,
            *params.kwonly_params,
            *params.posonly_params,
        ):
            grab(param.annotation)
        # star_arg may be a ParamStar/MaybeSentinel rather than a Param.
        if isinstance(params.star_arg, cst.Param):
            grab(params.star_arg.annotation)
        if params.star_kwarg is not None:
            grab(params.star_kwarg.annotation)
        grab(node.returns)

    def get_undefined_annotation_names(self) -> set[str]:
        """Annotation names with no visible definition, import, or builtin."""
        known = self.defined_names | self.imported_names | BUILTIN_NAMES
        return self.annotation_names - known
|
||||
|
||||
|
||||
def has_future_annotations_import(
    module: cst.Module,
) -> bool:
    """Return True if *module* contains 'from __future__ import annotations'."""
    for stmt in module.body:
        if not isinstance(stmt, cst.SimpleStatementLine):
            continue
        for small in stmt.body:
            if not isinstance(small, cst.ImportFrom):
                continue
            if small.module is None:
                continue
            if get_dotted_name(small.module) != "__future__":
                continue
            # Star imports carry no individual aliases to inspect.
            if isinstance(small.names, cst.ImportStar):
                continue
            if any(alias.name.value == "annotations" for alias in small.names):
                return True
    return False
|
||||
|
||||
|
||||
def add_future_annotations_import(
    module: cst.Module,
) -> cst.Module:
    """Insert 'from __future__ import annotations' when forward refs need it.

    Returns *module* itself (same object, so callers can use identity checks)
    when the import is already present or no annotation name is unresolved.
    """
    if has_future_annotations_import(module):
        return module

    collector = CSTAnnotationNameCollector()
    module.visit(collector)
    if not collector.get_undefined_annotation_names():
        return module

    ctx = CodemodContext()
    AddImportsVisitor.add_needed_import(ctx, "__future__", "annotations")
    return module.visit(AddImportsVisitor(ctx))
|
||||
|
||||
|
||||
def fix_forward_references(
    _original_module: cst.Module,
    candidates: list[OptimizationCandidate],
) -> list[OptimizationCandidate]:
    """Ensure each candidate carries the future-annotations import if needed.

    A per-candidate failure logs a warning and keeps the candidate unchanged.
    """
    fixed: list[OptimizationCandidate] = []
    for candidate in candidates:
        try:
            updated = add_future_annotations_import(candidate.cst_module)
            if updated is candidate.cst_module:
                # Identity means no import was added; reuse the candidate.
                fixed.append(candidate)
            else:
                fixed.append(
                    OptimizationCandidate(
                        cst_module=updated,
                        explanation=candidate.explanation,
                        id=candidate.id,
                    )
                )
        except Exception:
            log.warning("Error fixing forward refs for %s", candidate.id)
            fixed.append(candidate)
    return fixed
|
||||
|
||||
|
||||
def optimizations_postprocessing_pipeline(
    original_module: cst.Module,
    candidates: list[OptimizationCandidate],
) -> list[OptimizationCandidate]:
    """Run every postprocessing stage over *candidates*, in order.

    Stages: docstring restoration, comment cleaning, forward-reference fixes,
    deduplication, equality filtering, import sorting, explanation cleanup,
    and ellipsis filtering. The original source and its comment-stripped form
    are computed once and shared by the stages that need them.
    """
    source = original_module.code
    source_without_comments = _strip_comments_from_code(source)

    stage_output = fix_missing_docstring(original_module, candidates)
    stage_output = clean_extraneous_comments_pipeline(
        original_module,
        stage_output,
        orig_code=source,
        orig_code_stripped=source_without_comments,
    )
    stage_output = fix_forward_references(original_module, stage_output)
    stage_output = deduplicate_optimizations(original_module, stage_output)
    stage_output = equality_check(
        original_module,
        stage_output,
        original_code=source,
    )
    stage_output = dedup_and_sort_imports(original_module, stage_output)
    stage_output = cleanup_explanations(original_module, stage_output)
    return filter_ellipsis_containing_code(original_module, stage_output)
|
||||
|
|
@ -0,0 +1,43 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import ast
|
||||
import logging
|
||||
|
||||
import libcst as cst
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def validate_python_syntax(source: str) -> bool:
    """Return True when *source* parses under both ast and libcst.

    A module whose CST body is empty (e.g. comments or whitespace only) is
    rejected: there is nothing to work with.
    """
    try:
        ast.parse(source)
        module = cst.parse_module(source)
    except (SyntaxError, cst.ParserSyntaxError):
        return False
    return bool(module.body)
|
||||
|
||||
|
||||
def parse_python_or_none(
    source: str,
) -> cst.Module | None:
    """Parse *source* with libcst; None on syntax error or an empty body."""
    try:
        module = cst.parse_module(source)
    except cst.ParserSyntaxError:
        log.warning("Failed to parse Python source")
        return None
    return module if module.body else None
|
||||
543
packages/codeflash-api/tests/test_language_python.py
Normal file
543
packages/codeflash-api/tests/test_language_python.py
Normal file
|
|
@ -0,0 +1,543 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import libcst as cst
|
||||
import pytest
|
||||
|
||||
from codeflash_api.languages.python._cst_utils import (
|
||||
AnyEllipsisVisitor,
|
||||
DefinitionRemover,
|
||||
DepthTrackingMixin,
|
||||
ImportTrackingVisitor,
|
||||
build_module_path,
|
||||
evaluate_expression,
|
||||
file_path_to_module_path,
|
||||
find_init,
|
||||
get_base_class_name,
|
||||
get_dotted_name,
|
||||
has_decorator,
|
||||
make_number_node,
|
||||
parse_module_to_cst,
|
||||
unparse_parse_source,
|
||||
)
|
||||
from codeflash_api.languages.python._postprocess import (
|
||||
OptimizationCandidate,
|
||||
add_future_annotations_import,
|
||||
cleanup_explanations,
|
||||
deduplicate_optimizations,
|
||||
equality_check,
|
||||
filter_ellipsis_containing_code,
|
||||
fix_missing_docstring,
|
||||
has_future_annotations_import,
|
||||
optimizations_postprocessing_pipeline,
|
||||
safe_isort,
|
||||
)
|
||||
from codeflash_api.languages.python._validator import (
|
||||
parse_python_or_none,
|
||||
validate_python_syntax,
|
||||
)
|
||||
|
||||
|
||||
class TestDepthTrackingMixin:
    """Tests for DepthTrackingMixin."""

    def test_initial_state(self) -> None:
        """A fresh mixin reports top level."""
        mixin = DepthTrackingMixin()
        assert mixin._is_top_level()

    def test_function_depth(self) -> None:
        """Entering a function increases depth; leaving restores top level."""
        mixin = DepthTrackingMixin()
        mixin._visit_function()
        assert not mixin._is_top_level()
        # DepthTrackingMixin exposes _is_top_level* predicates only;
        # the previously called _is_inside_function() is not defined on the
        # mixin and raised AttributeError here.
        assert not mixin._is_top_level_function()
        mixin._leave_function()
        assert mixin._is_top_level()
|
||||
|
||||
|
||||
class TestFilePathToModulePath:
    """Tests for file_path_to_module_path."""

    def test_unix_path(self) -> None:
        """POSIX separators are converted to dots."""
        result = file_path_to_module_path("path/to/module.py")
        assert result == "path.to.module"

    def test_windows_path(self) -> None:
        """Windows separators are converted to dots."""
        result = file_path_to_module_path("path\\to\\module.py")
        assert result == "path.to.module"
|
||||
|
||||
|
||||
class TestGetDottedName:
    """Tests for get_dotted_name."""

    def test_simple_name(self) -> None:
        """A bare Name yields its value."""
        assert get_dotted_name(cst.Name("foo")) == "foo"

    def test_attribute(self) -> None:
        """An Attribute chain yields the full dotted path."""
        attr = cst.Attribute(value=cst.Name("foo"), attr=cst.Name("bar"))
        assert get_dotted_name(attr) == "foo.bar"

    def test_none(self) -> None:
        """None yields the empty string."""
        assert get_dotted_name(None) == ""
|
||||
|
||||
|
||||
class TestBuildModulePath:
    """Tests for build_module_path."""

    def test_single_part(self) -> None:
        """An undotted path builds a plain Name node."""
        node = build_module_path("foo")
        assert isinstance(node, cst.Name)
        assert node.value == "foo"

    def test_dotted(self) -> None:
        """A dotted path builds an Attribute chain ending in the last part."""
        node = build_module_path("foo.bar.baz")
        assert isinstance(node, cst.Attribute)
        assert node.attr.value == "baz"
|
||||
|
||||
|
||||
class TestEvaluateExpression:
    """Tests for evaluate_expression."""

    def test_integer(self) -> None:
        """A decimal Integer evaluates to its int value."""
        assert evaluate_expression(cst.Integer("42")) == 42

    def test_hex(self) -> None:
        """A hexadecimal Integer literal is handled."""
        assert evaluate_expression(cst.Integer("0xff")) == 255

    def test_negative(self) -> None:
        """Unary minus negates the operand."""
        negated = cst.UnaryOperation(
            operator=cst.Minus(),
            expression=cst.Integer("5"),
        )
        assert evaluate_expression(negated) == -5

    def test_binary_add(self) -> None:
        """Binary addition is constant-folded."""
        total = cst.BinaryOperation(
            left=cst.Integer("3"),
            operator=cst.Add(),
            right=cst.Integer("4"),
        )
        assert evaluate_expression(total) == 7

    def test_unevaluable(self) -> None:
        """A bare Name cannot be evaluated and yields None."""
        assert evaluate_expression(cst.Name("x")) is None
|
||||
|
||||
|
||||
class TestMakeNumberNode:
    """Tests for make_number_node."""

    def test_positive(self) -> None:
        """A non-negative value produces an Integer node."""
        node = make_number_node(5)
        assert isinstance(node, cst.Integer)
        assert node.value == "5"

    def test_negative(self) -> None:
        """A negative value is wrapped in a unary minus."""
        assert isinstance(make_number_node(-3), cst.UnaryOperation)
|
||||
|
||||
|
||||
class TestParseModuleToCst:
    """Tests for parse_module_to_cst."""

    def test_caches(self) -> None:
        """Repeated parses of identical source share one cached module."""
        first = parse_module_to_cst("x = 1")
        second = parse_module_to_cst("x = 1")
        assert first is second
|
||||
|
||||
|
||||
class TestDefinitionRemover:
    """Tests for DefinitionRemover."""

    def test_removes_function(self) -> None:
        """Only the named top-level function is stripped, and it is recorded."""
        source = "def foo():\n    pass\ndef bar():\n    pass\n"
        remover = DefinitionRemover({"foo"})
        rewritten = cst.parse_module(source).visit(remover)
        assert "foo" not in rewritten.code
        assert "bar" in rewritten.code
        assert "foo" in remover.removed_names

    def test_protected_name(self) -> None:
        """A name listed as protected survives removal."""
        remover = DefinitionRemover({"foo"}, protected_names={"foo"})
        rewritten = cst.parse_module("def foo():\n    pass\n").visit(remover)
        assert "foo" in rewritten.code
|
||||
|
||||
|
||||
class TestImportTrackingVisitor:
    """Tests for ImportTrackingVisitor."""

    def test_tracks_imports(self) -> None:
        """Both plain and from-imports contribute names."""
        import ast

        visitor = ImportTrackingVisitor()
        visitor.visit(ast.parse("import os\nfrom sys import path"))
        assert "os" in visitor.imported_names
        assert "path" in visitor.imported_names
|
||||
|
||||
|
||||
class TestFindInit:
    """Tests for find_init."""

    def test_found(self) -> None:
        """A class defining __init__ is detected."""
        import ast

        tree = ast.parse(
            "class Foo:\n    def __init__(self):\n        pass\n"
        )
        assert find_init(tree)

    def test_not_found(self) -> None:
        """A class without __init__ is not detected."""
        import ast

        tree = ast.parse("class Foo:\n    def bar(self):\n        pass\n")
        assert not find_init(tree)
|
||||
|
||||
|
||||
class TestValidatePythonSyntax:
    """Tests for validate_python_syntax."""

    def test_valid(self) -> None:
        """Well-formed source passes."""
        assert validate_python_syntax("x = 1\n")

    def test_invalid(self) -> None:
        """A syntax error fails."""
        assert not validate_python_syntax("def (:\n")

    def test_empty_body(self) -> None:
        """Comment-only source has an empty body and fails."""
        assert not validate_python_syntax("# just a comment\n")
|
||||
|
||||
|
||||
class TestParsePythonOrNone:
    """Tests for parse_python_or_none."""

    def test_valid(self) -> None:
        """Well-formed source yields a cst.Module."""
        module = parse_python_or_none("x = 1\n")
        # isinstance also rules out None.
        assert isinstance(module, cst.Module)

    def test_invalid(self) -> None:
        """A syntax error yields None."""
        assert parse_python_or_none("def (:\n") is None
|
||||
|
||||
|
||||
class TestOptimizationCandidate:
    """Tests for OptimizationCandidate."""

    def test_frozen(self) -> None:
        """Assigning to a field of a frozen candidate raises."""
        candidate = OptimizationCandidate(
            cst_module=cst.parse_module("x = 1"),
            explanation="faster",
            id="test",
        )
        # dataclasses.FrozenInstanceError subclasses AttributeError.
        with pytest.raises(AttributeError):
            candidate.id = "changed"
|
||||
|
||||
|
||||
class TestSafeIsort:
    """Tests for safe_isort."""

    def test_sorts(self) -> None:
        """Out-of-order imports come back sorted."""
        sorted_code = safe_isort("import sys\nimport os\n")
        assert sorted_code.index("os") < sorted_code.index("sys")

    def test_invalid_returns_original(self) -> None:
        """Unparseable source is returned untouched."""
        broken = "not valid python {{{{"
        assert safe_isort(broken) == broken
|
||||
|
||||
|
||||
class TestDeduplicateOptimizations:
    """Tests for deduplicate_optimizations."""

    def test_removes_duplicates(self) -> None:
        """Whitespace-only differences collapse to a single candidate."""
        original = cst.parse_module("x = 1\n")
        first = OptimizationCandidate(
            cst_module=cst.parse_module("x = 1\n"),
            explanation="a",
            id="1",
        )
        second = OptimizationCandidate(
            cst_module=cst.parse_module("x=1\n"),
            explanation="b",
            id="2",
        )
        deduped = deduplicate_optimizations(original, [first, second])
        assert len(deduped) == 1
|
||||
|
||||
|
||||
class TestEqualityCheck:
    """Tests for equality_check."""

    def test_filters_identical(self) -> None:
        """A candidate equal to the original is dropped."""
        original = cst.parse_module("x = 1\n")
        same = OptimizationCandidate(
            cst_module=cst.parse_module("x = 1\n"),
            explanation="same",
            id="1",
        )
        assert len(equality_check(original, [same])) == 0

    def test_keeps_different(self) -> None:
        """A genuinely changed candidate survives."""
        original = cst.parse_module("x = 1\n")
        changed = OptimizationCandidate(
            cst_module=cst.parse_module("x = 2\n"),
            explanation="different",
            id="1",
        )
        assert len(equality_check(original, [changed])) == 1
|
||||
|
||||
|
||||
class TestFilterEllipsis:
    """Tests for filter_ellipsis_containing_code."""

    def test_filters_introduced_ellipsis(self) -> None:
        """A candidate adding '...' where the original had none is dropped."""
        original = cst.parse_module("x = 1\n")
        bad = OptimizationCandidate(
            cst_module=cst.parse_module("x = ...\n"),
            explanation="bad",
            id="1",
        )
        assert len(filter_ellipsis_containing_code(original, [bad])) == 0

    def test_keeps_when_original_has_ellipsis(self) -> None:
        """Ellipsis already present in the original is not penalized."""
        original = cst.parse_module("x = ...\n")
        ok = OptimizationCandidate(
            cst_module=cst.parse_module("x = ...\n"),
            explanation="ok",
            id="1",
        )
        assert len(filter_ellipsis_containing_code(original, [ok])) == 1
|
||||
|
||||
|
||||
class TestFixMissingDocstring:
    """Tests for fix_missing_docstring."""

    def test_restores_docstring(self) -> None:
        """A docstring dropped by the optimizer is copied back in."""
        original = cst.parse_module(
            'def foo():\n    """My docstring."""\n    pass\n'
        )
        candidate = OptimizationCandidate(
            cst_module=cst.parse_module("def foo():\n    pass\n"),
            explanation="faster",
            id="1",
        )
        fixed = fix_missing_docstring(original, [candidate])
        assert "My docstring" in fixed[0].cst_module.code
|
||||
|
||||
|
||||
class TestHasFutureAnnotationsImport:
    """Tests for has_future_annotations_import."""

    def test_present(self) -> None:
        """The import is detected when present."""
        present = cst.parse_module("from __future__ import annotations\n")
        assert has_future_annotations_import(present)

    def test_absent(self) -> None:
        """A module without the import reports False."""
        assert not has_future_annotations_import(cst.parse_module("x = 1\n"))
|
||||
|
||||
|
||||
class TestAddFutureAnnotationsImport:
    """Tests for add_future_annotations_import."""

    def test_adds_when_needed(self) -> None:
        """An unresolved annotation name triggers the import."""
        module = cst.parse_module("def foo(x: MyType) -> None:\n    pass\n")
        updated = add_future_annotations_import(module)
        assert "from __future__ import annotations" in updated.code

    def test_skips_when_present(self) -> None:
        """An existing import short-circuits to the same module object."""
        source = (
            "from __future__ import annotations\n"
            "def foo(x: MyType) -> None:\n    pass\n"
        )
        module = cst.parse_module(source)
        assert add_future_annotations_import(module) is module

    def test_skips_when_no_undefined(self) -> None:
        """Fully resolved annotations leave the module untouched."""
        module = cst.parse_module("def foo(x: int) -> None:\n    pass\n")
        assert add_future_annotations_import(module) is module
|
||||
|
||||
|
||||
class TestCleanupExplanations:
    """Tests for cleanup_explanations."""

    def test_removes_code_block(self) -> None:
        """Fenced markdown code is stripped from explanations."""
        module = cst.parse_module("x = 1\n")
        candidate = OptimizationCandidate(
            cst_module=module,
            explanation="```python\nx = 1\n```",
            id="1",
        )
        cleaned = cleanup_explanations(module, [candidate])
        assert "```" not in cleaned[0].explanation
|
||||
|
||||
|
||||
class TestPostprocessingPipeline:
    """Tests for optimizations_postprocessing_pipeline."""

    def test_empty_candidates(self) -> None:
        """No candidates in, no candidates out."""
        module = cst.parse_module("x = 1\n")
        assert optimizations_postprocessing_pipeline(module, []) == []

    def test_filters_identical(self) -> None:
        """A no-op candidate is removed by the pipeline."""
        module = cst.parse_module("x = 1\n")
        same = OptimizationCandidate(
            cst_module=cst.parse_module("x = 1\n"),
            explanation="same",
            id="1",
        )
        survivors = optimizations_postprocessing_pipeline(module, [same])
        assert len(survivors) == 0

    def test_keeps_valid_optimization(self) -> None:
        """A real change survives the full pipeline."""
        module = cst.parse_module("x = 1 + 1\n")
        better = OptimizationCandidate(
            cst_module=cst.parse_module("x = 2\n"),
            explanation="precomputed",
            id="1",
        )
        survivors = optimizations_postprocessing_pipeline(module, [better])
        assert len(survivors) == 1
|
||||
|
|
@ -81,6 +81,19 @@ ignore = [
|
|||
"S104", # binding to 0.0.0.0 is the dev default, overridden in production
|
||||
"S105", # dev placeholder secret_key, overridden via env var in production
|
||||
]
|
||||
"packages/codeflash-api/src/codeflash_api/languages/python/_cst_utils.py" = [
|
||||
"N802", # libcst visitor methods must match visit_NodeName / leave_NodeName
|
||||
"PLR2004", # magic values in faithfully ported ellipsis depth checks
|
||||
]
|
||||
"packages/codeflash-api/src/codeflash_api/languages/python/_postprocess.py" = [
|
||||
"BLE001", # broad except is intentional — postprocessing must not crash the pipeline
|
||||
"C901", # faithfully ported recursive annotation walker and comment cleaner
|
||||
"N802", # libcst visitor methods must match visit_NodeName / leave_NodeName
|
||||
"PERF203", # try/except in loop is intentional — each candidate processed independently
|
||||
"PLR0912", # faithfully ported comment cleaning logic
|
||||
"PLR0915", # faithfully ported comment cleaning logic
|
||||
"TRY300", # faithfully ported control flow
|
||||
]
|
||||
"packages/codeflash-api/src/codeflash_api/diff/_v4a.py" = [
|
||||
"C901", # peek_next_section and _parse_update_file_sections faithfully ported from Django
|
||||
"PLR0912", # too many branches in faithfully ported _parse_update_file_sections
|
||||
|
|
|
|||
Loading…
Reference in a new issue