Add language layer: CST utils, validator, postprocessing pipeline

Faithful port of Python language utilities from Django aiservice:
- _cst_utils.py: depth tracking, import extraction, definition removal,
  ellipsis detection, expression evaluation, module path helpers
- _validator.py: dual ast+libcst syntax validation, parse-or-none
- _postprocess.py: full optimization postprocessing pipeline including
  dedup, equality check, docstring restoration, comment cleaning,
  forward reference fixing, ellipsis filtering, isort
This commit is contained in:
Kevin Turcios 2026-04-21 22:04:39 -05:00
parent 5c6b82050a
commit 3e62f502e7
5 changed files with 2033 additions and 0 deletions

View file

@ -0,0 +1,528 @@
from __future__ import annotations
import ast
import operator
from collections import deque
from functools import lru_cache
from typing import TYPE_CHECKING
import libcst as cst
from libcst.metadata import (
MetadataWrapper,
ParentNodeProvider,
PositionProvider,
)
if TYPE_CHECKING:
from collections.abc import Iterator
class DepthTrackingMixin:
    """
    Mixin that tracks how deeply nested a CST traversal currently is.

    Transformers call the ``_visit_*`` / ``_leave_*`` hooks from their own
    visit/leave methods and query the ``_is_*`` predicates to decide whether
    the current node sits at module top level.
    """

    def __init__(self) -> None:
        # Number of enclosing function definitions at the current node.
        self._function_depth: int = 0
        # Number of enclosing class definitions at the current node.
        self._class_depth: int = 0

    def _visit_function(self) -> None:
        self._function_depth += 1

    def _leave_function(self) -> None:
        self._function_depth -= 1

    def _visit_class(self) -> None:
        self._class_depth += 1

    def _leave_class(self) -> None:
        self._class_depth -= 1

    def _is_top_level(self) -> bool:
        # Top level means: inside neither a function nor a class.
        return self._function_depth == 0 and self._class_depth == 0

    def _is_top_level_function(self) -> bool:
        # Alias of _is_top_level, kept for call-site readability.
        return self._is_top_level()

    def _is_top_level_class(self) -> bool:
        # Alias of _is_top_level, kept for call-site readability.
        return self._is_top_level()

    def _is_inside_class(self) -> bool:
        return self._class_depth > 0

    def _is_inside_function(self) -> bool:
        return self._function_depth > 0
def extract_import_info(
    alias: cst.ImportAlias, module_name: str = ""
) -> tuple[str, str, str]:
    """
    Extract (available_name, module_path, original_name) from an import alias.

    *module_name* is the ``from X`` part for ImportFrom aliases and empty for
    plain imports.
    """
    if isinstance(alias.name, cst.Name):
        original_name = alias.name.value
    else:
        original_name = get_dotted_name(alias.name)
    if alias.asname and isinstance(alias.asname.name, cst.Name):
        # ``... as y`` binds the alias name y.
        return alias.asname.name.value, module_name, original_name
    if module_name:
        # ``from m import x`` binds x in full.
        return original_name, module_name, original_name
    # Plain ``import a.b`` binds only the top-level package a.
    return original_name.split(".")[0], module_name, original_name
def extract_imports_from_import(
    node: cst.Import,
) -> dict[str, tuple[str, str]]:
    """
    Map every name bound by an Import statement to (module, original_name).
    """
    if isinstance(node.names, cst.ImportStar):
        # Star imports bind unknowable names.
        return {}
    bindings: dict[str, tuple[str, str]] = {}
    for alias in node.names:
        available_name, _, original_name = extract_import_info(alias)
        if available_name:
            # For plain imports the "module" is the dotted name itself.
            bindings[available_name] = (original_name, original_name)
    return bindings
def extract_imports_from_import_from(
    node: cst.ImportFrom,
) -> dict[str, tuple[str, str]]:
    """
    Map every name bound by an ImportFrom statement to (module, original_name).
    """
    if isinstance(node.names, cst.ImportStar):
        # Star imports bind unknowable names.
        return {}
    module_name = get_dotted_name(node.module) if node.module else ""
    bindings: dict[str, tuple[str, str]] = {}
    for alias in node.names:
        available_name, _, original_name = extract_import_info(
            alias, module_name
        )
        if available_name:
            bindings[available_name] = (module_name, original_name)
    return bindings
def collect_imported_names_from_import(
    node: cst.Import,
) -> set[str]:
    """
    Collect the names an Import statement binds into the namespace.
    """
    if isinstance(node.names, cst.ImportStar):
        return set()
    # Keys of the extraction mapping are exactly the bound names.
    return set(extract_imports_from_import(node))
def collect_imported_names_from_import_from(
    node: cst.ImportFrom,
) -> set[str]:
    """
    Collect the names an ImportFrom statement binds into the namespace.
    """
    if isinstance(node.names, cst.ImportStar):
        return set()
    # Keys of the extraction mapping are exactly the bound names.
    return set(extract_imports_from_import_from(node))
class DefinitionRemover(DepthTrackingMixin, cst.CSTTransformer):
    """
    Remove top-level class and function definitions by name.

    Names may be plain ("foo") or dotted ("Class.method"). For dotted names,
    the enclosing class is removed once one of its listed methods is found in
    the class body. Names in *protected_names* are never removed.
    """

    def __init__(
        self,
        names_to_remove: set[str] | list[str],
        protected_names: set[str] | None = None,
    ) -> None:
        DepthTrackingMixin.__init__(self)
        cst.CSTTransformer.__init__(self)
        names_list = list(names_to_remove)
        # Keep at most the last two dotted components: ["Class", "method"].
        self._qualified_names: list[list[str]] = [
            s.split(".")[-2:] for s in names_list
        ]
        # Final component of each name, used for direct matches.
        self._simple_names: set[str] = {
            pair[-1] for pair in self._qualified_names if pair
        }
        self.protected_names: set[str] = protected_names or set()
        # Names actually removed during traversal (for caller reporting).
        self.removed_names: set[str] = set()
        # True while inside a top-level class slated for removal.
        self._is_killable_class: bool = False
        # Methods that, if present in the current class, mark it killable.
        self._potential_methods: set[str] = set()

    def _should_remove(self, name: str) -> bool:
        """
        Check if *name* should be removed.
        """
        return name in self._simple_names and name not in self.protected_names

    def visit_ClassDef(self, node: cst.ClassDef) -> bool:
        # Capture top-level status before the depth counter is incremented.
        is_top_level = self._is_top_level_class()
        self._visit_class()
        if not is_top_level:
            return True
        class_name = node.name.value
        if self._should_remove(class_name):
            self._is_killable_class = True
            # Skip the body: the whole class is being removed.
            return False
        # Methods listed as "ClassName.method" targeting this class.
        potential_methods = {
            pair[-1]
            for pair in self._qualified_names
            if len(pair) > 1
            and pair[-2] == class_name
            and pair[-1] not in self.protected_names
        }
        if potential_methods:
            self._potential_methods = potential_methods
        return True

    def leave_ClassDef(
        self,
        original_node: cst.ClassDef,
        updated_node: cst.ClassDef,
    ) -> cst.ClassDef | cst.RemovalSentinel:
        self._leave_class()
        if self._is_top_level() and self._is_killable_class:
            self._is_killable_class = False
            self._potential_methods = set()
            self.removed_names.add(original_node.name.value)
            return cst.RemovalSentinel.REMOVE
        self._potential_methods = set()
        return updated_node

    def visit_FunctionDef(self, node: cst.FunctionDef) -> bool:
        self._visit_function()
        return True

    def leave_FunctionDef(
        self,
        original_node: cst.FunctionDef,
        updated_node: cst.FunctionDef,
    ) -> cst.FunctionDef | cst.RemovalSentinel:
        self._leave_function()
        func_name = original_node.name.value
        if self._is_top_level() and self._should_remove(func_name):
            self.removed_names.add(func_name)
            return cst.RemovalSentinel.REMOVE
        if func_name in self._potential_methods:
            # A listed method was found: mark the enclosing class killable.
            self._is_killable_class = True
            self._potential_methods = set()
        return updated_node
class ImportTrackingVisitor(ast.NodeVisitor):
    """
    AST visitor that records every name bound by import statements.
    """

    def __init__(self) -> None:
        # Names bound into the module namespace by imports.
        self.imported_names: set[str] = set()

    def visit_Import(self, node: ast.Import) -> None:
        # ``import a.b`` binds only the top-level package ``a``.
        self.imported_names.update(
            alias.asname or alias.name.partition(".")[0]
            for alias in node.names
        )
        self.generic_visit(node)

    def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
        # Star imports bind unknowable names; skip them.
        self.imported_names.update(
            alias.asname or alias.name
            for alias in node.names
            if alias.name != "*"
        )
        self.generic_visit(node)
@lru_cache(maxsize=128)
def parse_module_to_cst(
    module_str: str,
) -> cst.Module:
    """
    Parse a module string into its libCST representation.

    Results are memoized (128 most recent source strings) because the
    postprocessing pipeline may re-parse identical code repeatedly.
    """
    return cst.parse_module(module_str)
def file_path_to_module_path(file_path: str) -> str:
    """
    Translate a filesystem path like ``pkg/mod.py`` into ``pkg.mod``.

    Both POSIX and Windows separators are handled.
    """
    dotted = file_path.removesuffix(".py")
    for separator in ("/", "\\"):
        dotted = dotted.replace(separator, ".")
    return dotted
def get_dotted_name(
    node: cst.BaseExpression | None,
) -> str:
    """
    Render a Name/Attribute chain as a dotted path string.

    Returns "" for None and for any expression that is not a plain
    Name/Attribute chain.
    """
    if isinstance(node, cst.Name):
        return node.value
    if isinstance(node, cst.Attribute):
        prefix = get_dotted_name(node.value)
        suffix = node.attr.value
        return f"{prefix}.{suffix}" if prefix else suffix
    return ""
def get_base_class_name(
    base: cst.Arg,
) -> str | None:
    """
    Return the unqualified base-class name from a ClassDef argument.

    ``class C(m.Base)`` yields ``"Base"``; non-name expressions yield None.
    """
    expression = base.value
    if isinstance(expression, cst.Name):
        return expression.value
    if isinstance(expression, cst.Attribute):
        # For dotted bases only the final attribute matters.
        return expression.attr.value
    return None
def has_decorator(class_node: cst.ClassDef, decorator_name: str) -> bool:
    """
    Check whether *class_node* carries a decorator named *decorator_name*,
    either bare (``@dec``) or called (``@dec(...)``).
    """
    for wrapper in class_node.decorators:
        target = wrapper.decorator
        if isinstance(target, cst.Call):
            # Unwrap a called decorator down to the function being called.
            target = target.func
        if isinstance(target, cst.Name) and target.value == decorator_name:
            return True
    return False
def get_node_source_text(
    node: cst.CSTNode,
    source_lines: list[str],
    wrapper: MetadataWrapper,
    fallback: str = "...",
) -> str:
    """
    Get the source text of a CST node using position metadata.

    Returns *fallback* when the wrapper has no position entry for *node*.

    NOTE(review): the slice below reads only the node's first line using the
    start and end columns, which is only correct for nodes that begin and end
    on the same line; a multi-line node would produce a wrong slice. Confirm
    callers only pass single-line nodes.
    """
    pos = wrapper.resolve(PositionProvider).get(node)
    if pos is None:
        return fallback
    return source_lines[pos.start.line - 1][pos.start.column : pos.end.column]
def build_module_path(
    module_path: str,
) -> cst.Attribute | cst.Name:
    """
    Build nested Attribute/Name CST nodes from a dotted path string.

    ``"a.b.c"`` becomes ``Attribute(Attribute(Name("a"), "b"), "c")``.
    """
    first, *rest = module_path.split(".")
    node: cst.Attribute | cst.Name = cst.Name(first)
    for attr in rest:
        node = cst.Attribute(value=node, attr=cst.Name(attr))
    return node
# CST binary-operator node types mapped to the Python-level functions that
# evaluate_expression uses to constant-fold integer expressions.
_BINARY_OPS = {
    cst.Power: operator.pow,
    cst.Multiply: operator.mul,
    cst.Add: operator.add,
    cst.Subtract: operator.sub,
}
def make_number_node(
    value: int,
) -> cst.Integer | cst.UnaryOperation:
    """
    Build a CST literal for *value*.

    Negative values become a unary-minus applied to the absolute value,
    matching how Python source represents negative literals.
    """
    if value < 0:
        return cst.UnaryOperation(
            operator=cst.Minus(),
            expression=cst.Integer(str(-value)),
        )
    return cst.Integer(str(value))
def evaluate_expression(
    node: cst.BaseExpression,
) -> int | None:
    """
    Constant-fold *node* down to an int, or return None when unsupported.

    Supports integer/float literals, the binary operators in _BINARY_OPS,
    and unary minus.
    """
    if isinstance(node, cst.Integer):
        # Base 0 accepts hex/octal/binary literals as written in source.
        return int(node.value, 0)
    if isinstance(node, cst.Float):
        return int(float(node.value))
    if isinstance(node, cst.BinaryOperation):
        fold = _BINARY_OPS.get(type(node.operator))
        if fold is not None:
            lhs = evaluate_expression(node.left)
            rhs = evaluate_expression(node.right)
            if lhs is not None and rhs is not None:
                return fold(lhs, rhs)
        return None
    if isinstance(node, cst.UnaryOperation) and isinstance(
        node.operator, cst.Minus
    ):
        inner = evaluate_expression(node.expression)
        if inner is not None:
            return -inner
    return None
class AnyEllipsisVisitor(cst.CSTVisitor):
    """
    Visitor that flags the presence of any ``...`` literal in a module.
    """

    def __init__(self) -> None:
        # Flipped to True on the first Ellipsis node encountered.
        self.found = False

    def visit_Ellipsis(self, node: cst.Ellipsis) -> bool:
        self.found = True
        # Nothing useful below an Ellipsis node; stop descending.
        return False
class InvalidEllipsisVisitor(cst.CSTVisitor):
    """
    Detect ellipsis in invalid locations.

    An ``...`` is considered valid only inside a type-annotation-like context
    (a Subscript/Tuple ancestor, e.g. ``Callable[..., int]``) or as the sole
    statement of a function body; anything else sets ``invalid_found``.
    """

    METADATA_DEPENDENCIES = (ParentNodeProvider,)

    def __init__(self) -> None:
        # True once an ellipsis outside the allowed contexts is seen.
        self.invalid_found = False

    def visit_Ellipsis(self, node: cst.Ellipsis) -> bool:
        if self._is_in_type_annotation_context(node):
            return False
        if self._is_function_body_ellipsis(node):
            return False
        self.invalid_found = True
        return False

    def _iter_ancestors(
        self,
        node: cst.CSTNode,
        max_depth: int = 10,
    ) -> Iterator[cst.CSTNode]:
        """
        Yield ancestors walking up the tree, nearest first, up to *max_depth*.
        """
        current = node
        for _ in range(max_depth):
            parent = self.get_metadata(ParentNodeProvider, current, None)
            if parent is None:
                break
            yield parent
            current = parent

    def _is_in_type_annotation_context(self, node: cst.CSTNode) -> bool:
        """
        Check if *node* is within a type annotation.

        NOTE(review): any Subscript/Tuple ancestor counts, so an ellipsis
        inside an ordinary subscript like ``x[...]`` also passes — confirm
        this is intended.
        """
        return any(
            isinstance(ancestor, (cst.Subscript, cst.Tuple))
            for ancestor in self._iter_ancestors(node)
        )

    def _is_function_body_ellipsis(self, node: cst.Ellipsis) -> bool:
        """
        Check if *node* is a function body ellipsis.
        """
        # The exact ancestor chain for ``def f(): ...`` is
        # Expr -> SimpleStatementLine -> IndentedBlock -> FunctionDef.
        ancestors = list(self._iter_ancestors(node, max_depth=4))
        if len(ancestors) < 4:
            return False
        expected_types = (
            cst.Expr,
            cst.SimpleStatementLine,
            cst.IndentedBlock,
            cst.FunctionDef,
        )
        return all(
            isinstance(ancestor, expected)
            for ancestor, expected in zip(
                ancestors, expected_types, strict=True
            )
        )
def ellipsis_in_cst_not_types(
    module: cst.Module,
) -> bool:
    """
    Report whether *module* has an ellipsis outside the allowed contexts
    (type annotations and stub function bodies).
    """
    checker = InvalidEllipsisVisitor()
    # A metadata wrapper is required for parent-node resolution.
    MetadataWrapper(module).visit(checker)
    return checker.invalid_found
def any_ellipsis_in_cst(module: cst.Module) -> bool:
    """
    Report whether *module* contains an ellipsis anywhere at all.
    """
    checker = AnyEllipsisVisitor()
    module.visit(checker)
    return checker.found
def unparse_parse_source(source: str) -> str:
    """
    Round-trip *source* through the ast module, normalizing formatting.

    Raises SyntaxError when *source* does not parse.
    """
    tree = ast.parse(source)
    return ast.unparse(tree)
def compare_unparsed_ast_to_source(unparsed_ast: str, source: str) -> bool:
    """
    Return True when *unparsed_ast* equals *source* after the same
    ast round-trip normalization.
    """
    return unparse_parse_source(source) == unparsed_ast
def find_init(node: ast.AST) -> bool:
    """
    Return True when the tree rooted at *node* defines ``__init__``.
    """
    return any(
        isinstance(child, ast.FunctionDef) and child.name == "__init__"
        for child in ast.walk(node)
    )

View file

@ -0,0 +1,906 @@
from __future__ import annotations
import ast
import builtins
import io
import logging
import re
import tokenize
from difflib import SequenceMatcher
from typing import TYPE_CHECKING, cast
import attrs
import isort
import libcst as cst
from libcst import (
BaseStatement,
CSTTransformer,
CSTVisitor,
Expr,
IndentedBlock,
SimpleStatementLine,
SimpleString,
)
from libcst.codemod import CodemodContext
from libcst.codemod.visitors import AddImportsVisitor
from codeflash_api.languages.python._cst_utils import (
collect_imported_names_from_import,
collect_imported_names_from_import_from,
compare_unparsed_ast_to_source,
get_dotted_name,
parse_module_to_cst,
unparse_parse_source,
)
if TYPE_CHECKING:
from libcst import FunctionDef
# Module-level logger for the postprocessing pipeline.
log = logging.getLogger(__name__)
# Snapshot of builtin names, used to recognize annotation names that need
# neither a definition nor an import.
BUILTIN_NAMES = frozenset(dir(builtins))
@attrs.frozen
class OptimizationCandidate:
    """
    An optimized code variant with its CST module and explanation.

    Instances are immutable; pipeline stages build new instances when a
    field changes.
    """

    # Parsed CST of the candidate's source code.
    cst_module: cst.Module
    # Free-text explanation of the optimization (LLM-generated).
    explanation: str
    # Stable identifier used for logging and tracking.
    id: str
def safe_isort(code: str, **kwargs: object) -> str:
    """
    Sort imports in *code* with isort; fall back to the input on any error.
    """
    try:
        sorted_code = isort.code(code, **kwargs)
    except Exception:
        # isort can raise on unusual input; this helper is best-effort only.
        return code
    return sorted_code
def deduplicate_optimizations(
    _original_module: cst.Module,
    candidates: list[OptimizationCandidate],
) -> list[OptimizationCandidate]:
    """
    Drop candidates whose normalized AST duplicates an earlier candidate.

    The first candidate with a given AST fingerprint wins.
    """
    seen_fingerprints: set[str] = set()
    unique: list[OptimizationCandidate] = []
    for candidate in candidates:
        # Round-trip through ast so formatting differences do not matter.
        fingerprint = ast.unparse(ast.parse(candidate.cst_module.code))
        if fingerprint in seen_fingerprints:
            continue
        seen_fingerprints.add(fingerprint)
        unique.append(candidate)
    return unique
def equality_check(
    original_module: cst.Module,
    candidates: list[OptimizationCandidate],
    *,
    original_code: str | None = None,
) -> list[OptimizationCandidate]:
    """
    Drop candidates that are semantically identical to the original code.

    Comparison is AST-based where possible; when either side fails to parse
    it falls back to exact text comparison.
    """
    source = original_module.code if original_code is None else original_code
    try:
        normalized_original = unparse_parse_source(source)
    except Exception:
        # Original does not parse: only exact text matches count as equal.
        return [c for c in candidates if c.cst_module.code != source]
    kept: list[OptimizationCandidate] = []
    for candidate in candidates:
        code = candidate.cst_module.code
        try:
            is_same = compare_unparsed_ast_to_source(
                normalized_original, code
            )
        except Exception:
            # Candidate does not parse: keep it unless byte-identical.
            is_same = code == source
        if not is_same:
            kept.append(candidate)
    return kept
# (pattern, replacement) pairs applied in order to candidate explanations to
# strip common LLM boilerplate: "Here is the optimized code" lead-ins,
# fenced code blocks, and trailing-colon phrasing.
_EXPLANATION_PATTERNS = [
    (
        re.compile(
            r"\sHere (is|are) (the )?((code|optimization)"
            r"|\S+\scode|(optimized|improved) versions?"
            r" of (the code|these functions))(:|.)",
            re.IGNORECASE,
        ),
        "",
    ),
    (
        re.compile(r"^```(.*?)```", re.MULTILINE | re.DOTALL),
        "",
    ),
    (re.compile(r", as follows:"), "."),
    (re.compile(r":\n"), ".\n"),
]
def cleanup_explanations(
    _original_module: cst.Module,
    candidates: list[OptimizationCandidate],
) -> list[OptimizationCandidate]:
    """
    Strip LLM boilerplate from every candidate's explanation text.
    """
    cleaned_candidates: list[OptimizationCandidate] = []
    for candidate in candidates:
        text = candidate.explanation
        for pattern, replacement in _EXPLANATION_PATTERNS:
            text = pattern.sub(replacement, text)
        cleaned_candidates.append(
            OptimizationCandidate(
                cst_module=candidate.cst_module,
                explanation=text,
                id=candidate.id,
            )
        )
    return cleaned_candidates
class DocstringVisitor(CSTVisitor):
    """
    Collect docstrings of classes, functions, and methods by qualified name.

    Keys are "Class", "func", or "Class.method".
    """

    def __init__(self) -> None:
        # Maps qualified name to its raw (unclean) docstring.
        self.original_docstrings: dict[str, str] = {}
        # Name of the class currently being traversed, if any.
        self.class_name: str | None = None

    def visit_ClassDef(self, node: cst.ClassDef) -> bool:
        self.class_name = node.name.value
        docstring = node.get_docstring(clean=False)
        if docstring:
            self.original_docstrings[self.class_name] = docstring
        return True

    def leave_ClassDef(
        self,
        original_node: cst.ClassDef,
    ) -> None:
        self.class_name = None

    def visit_FunctionDef(self, node: FunctionDef) -> bool:
        name = node.name.value
        if self.class_name:
            # Qualify methods with the enclosing class name.
            name = f"{self.class_name}.{name}"
        docstring = node.get_docstring(clean=False)
        if docstring:
            self.original_docstrings[name] = docstring
        return True
class DocstringTransformer(CSTTransformer):
    """
    Restore original docstrings in optimized code.

    For every class/function whose qualified name has an entry in
    *original_docstrings*, the original docstring is prepended (when the
    optimized node lost its docstring) or substituted for the node's first
    body statement (when the optimized node carries a different docstring).

    The class/function splice logic was previously duplicated verbatim in
    leave_ClassDef and leave_FunctionDef; it now lives in one helper.
    """

    def __init__(self, original_docstrings: dict[str, str]) -> None:
        # Maps "Class", "func", or "Class.method" to the docstring to restore.
        self.original_docstrings = original_docstrings
        # Class currently being traversed; used to qualify method names.
        self.class_name: str | None = None

    @staticmethod
    def _restore_docstring(
        updated_node: cst.ClassDef | FunctionDef,
        docstring: str,
    ) -> cst.ClassDef | FunctionDef:
        """
        Return *updated_node* with *docstring* as its first body statement.
        """
        doc_stmt = SimpleStatementLine(
            body=[Expr(value=SimpleString(f'"""{docstring}"""'))]
        )
        existing_body = cast(
            "list[BaseStatement]", list(updated_node.body.body)
        )
        if updated_node.get_docstring(clean=False):
            # Replace the optimized docstring (the first body statement).
            existing_body = existing_body[1:]
        return updated_node.with_changes(
            body=IndentedBlock(body=[doc_stmt, *existing_body])
        )

    def visit_ClassDef(self, node: cst.ClassDef) -> bool:
        self.class_name = node.name.value
        return True

    def leave_ClassDef(
        self,
        original_node: cst.ClassDef,
        updated_node: cst.ClassDef,
    ) -> cst.ClassDef:
        original_docstring = (
            self.original_docstrings.get(self.class_name)
            if self.class_name
            else None
        )
        if original_docstring:
            updated_node = cast(
                "cst.ClassDef",
                self._restore_docstring(updated_node, original_docstring),
            )
        self.class_name = None
        return updated_node

    def leave_FunctionDef(
        self,
        original_node: FunctionDef,
        updated_node: FunctionDef,
    ) -> FunctionDef:
        function_name = updated_node.name.value
        qualified_name = (
            f"{self.class_name}.{function_name}"
            if self.class_name
            else function_name
        )
        original_docstring = self.original_docstrings.get(qualified_name)
        if original_docstring:
            updated_node = cast(
                "FunctionDef",
                self._restore_docstring(updated_node, original_docstring),
            )
        return updated_node
def fix_missing_docstring(
    original_module: cst.Module,
    candidates: list[OptimizationCandidate],
) -> list[OptimizationCandidate]:
    """
    Re-attach docstrings from the original module that candidates dropped.
    """
    collector = DocstringVisitor()
    try:
        original_module.visit(collector)
    except Exception:
        # If the original cannot be inspected, leave candidates untouched.
        return candidates
    restorer = DocstringTransformer(collector.original_docstrings)
    restored: list[OptimizationCandidate] = []
    for candidate in candidates:
        try:
            new_module = candidate.cst_module.visit(restorer)
        except Exception:
            log.warning("Failed to restore docstring for %s", candidate.id)
            restored.append(candidate)
            continue
        restored.append(
            OptimizationCandidate(
                cst_module=new_module,
                explanation=candidate.explanation,
                id=candidate.id,
            )
        )
    return restored
def dedup_and_sort_imports(
    _original_module: cst.Module,
    candidates: list[OptimizationCandidate],
) -> list[OptimizationCandidate]:
    """
    Normalize import ordering in every candidate via isort.
    """
    normalized: list[OptimizationCandidate] = []
    for candidate in candidates:
        try:
            code = candidate.cst_module.code
            sorted_code = safe_isort(code, disregard_skip=True)
        except Exception:
            normalized.append(candidate)
            continue
        if sorted_code == code:
            # Nothing changed; avoid a needless re-parse.
            normalized.append(candidate)
            continue
        normalized.append(
            OptimizationCandidate(
                cst_module=parse_module_to_cst(sorted_code),
                explanation=candidate.explanation,
                id=candidate.id,
            )
        )
    return normalized
class EllipsisContainingCodeVisitor(CSTVisitor):
    """
    Visitor that records whether a module contains any ``...`` literal.
    """

    def __init__(self) -> None:
        # Flipped to True on the first Ellipsis node encountered.
        self.ellipsis_containing_code = False

    def visit_Ellipsis(self, node: cst.Ellipsis) -> bool:
        self.ellipsis_containing_code = True
        # No need to descend further.
        return False
def filter_ellipsis_containing_code(
    original_module: cst.Module,
    candidates: list[OptimizationCandidate],
) -> list[OptimizationCandidate]:
    """
    Drop candidates containing ``...`` unless the original already had one.
    """
    baseline = EllipsisContainingCodeVisitor()
    original_module.visit(baseline)
    if baseline.ellipsis_containing_code:
        # The original itself uses ellipsis, so its presence is not a signal.
        return candidates
    kept: list[OptimizationCandidate] = []
    for candidate in candidates:
        probe = EllipsisContainingCodeVisitor()
        candidate.cst_module.visit(probe)
        if not probe.ellipsis_containing_code:
            kept.append(candidate)
    return kept
def _strip_comments_from_code(code: str) -> str:
"""
Remove all comments from Python code, preserving strings.
"""
try:
lines = code.splitlines(keepends=True)
tokens = tokenize.generate_tokens(io.StringIO(code).readline)
comments_by_line: list[list[tuple[int, int]]] = [
[] for _ in range(len(lines))
]
for token in tokens:
if token.type == tokenize.COMMENT:
line_idx = token.start[0] - 1
if 0 <= line_idx < len(lines):
comments_by_line[line_idx].append(
(token.start[1], token.end[1])
)
result_lines: list[str] = []
for line_num, line in enumerate(lines, start=1):
current = line
line_comments = comments_by_line[line_num - 1]
if line_comments:
line_comments.sort(reverse=True)
for start_col, end_col in line_comments:
current = current[:start_col].rstrip() + current[end_col:]
if not current.endswith("\n") and lines[line_num - 1].endswith(
"\n"
):
current += "\n"
result_lines.append(current)
return "".join(result_lines)
except (tokenize.TokenError, IndentationError):
return code
def clean_extraneous_comments(
    original_module: cst.Module,
    optimized_module: cst.Module,
    *,
    orig_code: str | None = None,
    orig_code_stripped: str | None = None,
) -> cst.Module:
    """
    Remove extraneous comments from optimized code using diff-based approach.

    Strategy: strip comments from both versions, diff the remaining code
    lines, then rebuild the optimized text keeping (a) changed code lines
    verbatim, (b) comment-only lines adjacent to a change, and (c) for
    unchanged code lines, the original's line together with any original
    comment lines around it. Any failure falls back to the optimized module
    unchanged.
    """
    try:
        if orig_code is None:
            orig_code = original_module.code
        opt_code_full = optimized_module.code
        orig_lines = orig_code.splitlines(keepends=True)
        opt_lines = opt_code_full.splitlines(keepends=True)
        # Comment-stripped variants; blank entries mark comment-only lines.
        if orig_code_stripped is None:
            orig_code_stripped = _strip_comments_from_code(orig_code)
        opt_code_stripped = _strip_comments_from_code(opt_code_full)
        orig_code_only = orig_code_stripped.splitlines(keepends=True)
        opt_code_only = opt_code_stripped.splitlines(keepends=True)
        # Filter to non-blank code lines, remembering each line's real index.
        orig_code_line_indices: list[int] = []
        orig_code_lines_filtered: list[str] = []
        for i, line in enumerate(orig_code_only):
            if line.strip():
                orig_code_line_indices.append(i)
                orig_code_lines_filtered.append(line)
        opt_code_line_indices: list[int] = []
        opt_code_lines_filtered: list[str] = []
        for i, line in enumerate(opt_code_only):
            if line.strip():
                opt_code_line_indices.append(i)
                opt_code_lines_filtered.append(line)
        # Diff the code-only lines to find what the optimization changed.
        code_matcher = SequenceMatcher(
            None,
            orig_code_lines_filtered,
            opt_code_lines_filtered,
        )
        code_changed_line_indices: set[int] = set()
        for tag, _i1, _i2, j1, j2 in code_matcher.get_opcodes():
            if tag != "equal":
                for j in range(j1, j2):
                    code_changed_line_indices.add(j)
        # Translate filtered indices back to real optimized-line indices.
        code_changed_lines: set[int] = set()
        for filtered_idx in code_changed_line_indices:
            if filtered_idx < len(opt_code_line_indices):
                code_changed_lines.add(opt_code_line_indices[filtered_idx])
        # Map each unchanged original line index to its optimized twin.
        orig_to_opt_mapping: dict[int, int] = {}
        for tag, i1, i2, j1, j2 in code_matcher.get_opcodes():
            if tag == "equal":
                for offset in range(min(i2 - i1, j2 - j1)):
                    orig_idx = orig_code_line_indices[i1 + offset]
                    opt_idx = opt_code_line_indices[j1 + offset]
                    orig_to_opt_mapping[orig_idx] = opt_idx
        # Rebuild the output line by line from the optimized text.
        result_lines: list[str] = []
        orig_idx = 0
        restored_orig_indices: set[int] = set()
        for opt_idx, opt_line in enumerate(opt_lines):
            if opt_idx in code_changed_lines:
                # Changed code is always kept verbatim.
                result_lines.append(opt_line)
            else:
                opt_code = (
                    opt_code_only[opt_idx]
                    if opt_idx < len(opt_code_only)
                    else ""
                )
                # Blank after stripping means comment-only (or blank) line.
                is_comment_only = not opt_code.strip()
                is_near_change = False
                if is_comment_only:
                    # Look forward to the next code line for a change...
                    for check_idx in range(opt_idx + 1, len(opt_lines)):
                        if check_idx in code_changed_lines:
                            is_near_change = True
                            break
                        check_code = (
                            opt_code_only[check_idx]
                            if check_idx < len(opt_code_only)
                            else ""
                        )
                        if check_code.strip():
                            break
                    # ...and backward to the previous code line.
                    if not is_near_change:
                        for check_idx in range(opt_idx - 1, -1, -1):
                            if check_idx in code_changed_lines:
                                is_near_change = True
                                break
                            check_code = (
                                opt_code_only[check_idx]
                                if check_idx < len(opt_code_only)
                                else ""
                            )
                            if check_code.strip():
                                break
                if is_comment_only and is_near_change:
                    # Comment explains a change: keep it.
                    result_lines.append(opt_line)
                elif is_comment_only and not is_near_change:
                    # Extraneous comment away from any change: drop it.
                    pass
                elif opt_idx in code_changed_lines:
                    result_lines.append(opt_line)
                else:
                    # Unchanged code line: restore the original's version
                    # (and its surrounding original comments).
                    found_orig = None
                    orig_line_idx = None
                    for orig_idx_search in range(orig_idx, len(orig_lines)):
                        orig_code_line = (
                            orig_code_only[orig_idx_search]
                            if orig_idx_search < len(orig_code_only)
                            else ""
                        )
                        if orig_code_line == opt_code:
                            found_orig = orig_lines[orig_idx_search]
                            orig_line_idx = orig_idx_search
                            orig_idx = orig_idx_search + 1
                            break
                    if found_orig:
                        # Restore original comment lines directly above.
                        if (
                            orig_line_idx is not None
                            and orig_line_idx > 0
                            and opt_idx not in code_changed_lines
                        ):
                            preceding_lines: list[str] = []
                            check_idx = orig_line_idx - 1
                            while check_idx >= 0:
                                check_code = (
                                    orig_code_only[check_idx]
                                    if check_idx < len(orig_code_only)
                                    else ""
                                )
                                if not check_code.strip():
                                    if (
                                        check_idx not in orig_to_opt_mapping
                                        and check_idx
                                        not in restored_orig_indices
                                    ):
                                        preceding_lines.insert(
                                            0,
                                            orig_lines[check_idx],
                                        )
                                        restored_orig_indices.add(check_idx)
                                    check_idx -= 1
                                else:
                                    break
                            result_lines.extend(preceding_lines)
                        result_lines.append(found_orig)
                        if orig_line_idx is not None:
                            restored_orig_indices.add(orig_line_idx)
                        # Restore original comment lines directly below,
                        # unless the next code line maps to a change.
                        if (
                            orig_line_idx is not None
                            and orig_line_idx < len(orig_lines) - 1
                            and opt_idx not in code_changed_lines
                        ):
                            trailing_lines: list[str] = []
                            check_idx = orig_line_idx + 1
                            found_changed_line = False
                            while check_idx < len(orig_lines):
                                check_code = (
                                    orig_code_only[check_idx]
                                    if check_idx < len(orig_code_only)
                                    else ""
                                )
                                if not check_code.strip():
                                    if (
                                        check_idx not in orig_to_opt_mapping
                                        and check_idx
                                        not in restored_orig_indices
                                    ):
                                        trailing_lines.append(
                                            orig_lines[check_idx]
                                        )
                                        restored_orig_indices.add(check_idx)
                                    check_idx += 1
                                else:
                                    if check_idx in orig_to_opt_mapping:
                                        next_opt_idx = orig_to_opt_mapping[
                                            check_idx
                                        ]
                                        if next_opt_idx in code_changed_lines:
                                            found_changed_line = True
                                    else:
                                        found_changed_line = True
                                    break
                            if not found_changed_line:
                                result_lines.extend(trailing_lines)
                    else:
                        # No matching original line: keep the optimized one.
                        result_lines.append(opt_line)
        cleaned_code = "".join(result_lines)
        return cst.parse_module(cleaned_code)
    except Exception:
        log.warning("Error cleaning comments")
        return optimized_module
def clean_extraneous_comments_pipeline(
    original_module: cst.Module,
    candidates: list[OptimizationCandidate],
    *,
    orig_code: str | None = None,
    orig_code_stripped: str | None = None,
) -> list[OptimizationCandidate]:
    """
    Apply clean_extraneous_comments to every candidate, best-effort.

    Precomputed *orig_code* / *orig_code_stripped* may be passed to avoid
    recomputing them per candidate.
    """
    try:
        if orig_code is None:
            orig_code = original_module.code
        if orig_code_stripped is None:
            orig_code_stripped = _strip_comments_from_code(orig_code)
        cleaned_candidates: list[OptimizationCandidate] = []
        for candidate in candidates:
            try:
                cleaned_module = clean_extraneous_comments(
                    original_module,
                    candidate.cst_module,
                    orig_code=orig_code,
                    orig_code_stripped=orig_code_stripped,
                )
            except Exception:
                log.warning(
                    "Error cleaning comments for %s",
                    candidate.id,
                )
                cleaned_candidates.append(candidate)
                continue
            cleaned_candidates.append(
                OptimizationCandidate(
                    cst_module=cleaned_module,
                    explanation=candidate.explanation,
                    id=candidate.id,
                )
            )
        return cleaned_candidates
    except Exception:
        log.warning("Error in comment cleaning pipeline")
        return candidates
def extract_names_from_cst_annotation(
    annotation: cst.BaseExpression,
    names: set[str],
) -> None:
    """
    Recursively add every bare Name inside *annotation* to *names*.
    """
    if isinstance(annotation, cst.Name):
        names.add(annotation.value)
        return
    if isinstance(annotation, cst.Attribute):
        # For dotted names only the base matters for import resolution.
        extract_names_from_cst_annotation(annotation.value, names)
        return
    if isinstance(annotation, cst.Subscript):
        extract_names_from_cst_annotation(annotation.value, names)
        for sub_element in annotation.slice:
            if isinstance(sub_element, cst.SubscriptElement) and isinstance(
                sub_element.slice, cst.Index
            ):
                extract_names_from_cst_annotation(
                    sub_element.slice.value, names
                )
        return
    if isinstance(annotation, cst.BinaryOperation):
        # Handles PEP 604 unions such as ``int | None``.
        extract_names_from_cst_annotation(annotation.left, names)
        extract_names_from_cst_annotation(annotation.right, names)
        return
    if isinstance(annotation, (cst.Tuple, cst.List)):
        for item in annotation.elements:
            if isinstance(item, (cst.Element, cst.StarredElement)):
                extract_names_from_cst_annotation(item.value, names)
class CSTAnnotationNameCollector(cst.CSTVisitor):
    """
    Collect names used in type annotations, alongside the names defined or
    imported in the same module.
    """

    def __init__(self) -> None:
        # Names appearing inside annotations.
        self.annotation_names: set[str] = set()
        # Names defined by class/function statements.
        self.defined_names: set[str] = set()
        # Names bound by import statements.
        self.imported_names: set[str] = set()

    def visit_Import(self, node: cst.Import) -> bool | None:
        self.imported_names.update(collect_imported_names_from_import(node))
        return False

    def visit_ImportFrom(self, node: cst.ImportFrom) -> bool | None:
        names = collect_imported_names_from_import_from(node)
        self.imported_names.update(names)
        return False

    def visit_FunctionDef(self, node: cst.FunctionDef) -> bool | None:
        self.defined_names.add(node.name.value)
        self._collect_from_function(node)
        return True

    def visit_ClassDef(self, node: cst.ClassDef) -> bool | None:
        self.defined_names.add(node.name.value)
        return True

    def visit_AnnAssign(self, node: cst.AnnAssign) -> bool | None:
        extract_names_from_cst_annotation(
            node.annotation.annotation,
            self.annotation_names,
        )
        return True

    def _collect_from_function(self, node: cst.FunctionDef) -> None:
        """
        Collect annotation names from parameters and the return annotation.
        """
        params = node.params
        annotated = [
            param
            for param in (
                *params.params,
                *params.kwonly_params,
                *params.posonly_params,
            )
            if param.annotation
        ]
        # ``*args`` may be a ParamStar sentinel rather than a Param.
        if (
            isinstance(params.star_arg, cst.Param)
            and params.star_arg.annotation
        ):
            annotated.append(params.star_arg)
        if params.star_kwarg and params.star_kwarg.annotation:
            annotated.append(params.star_kwarg)
        for param in annotated:
            extract_names_from_cst_annotation(
                param.annotation.annotation,
                self.annotation_names,
            )
        if node.returns:
            extract_names_from_cst_annotation(
                node.returns.annotation,
                self.annotation_names,
            )

    def get_undefined_annotation_names(self) -> set[str]:
        """
        Return annotation names with no local definition, import, or builtin.
        """
        known = self.defined_names | self.imported_names | BUILTIN_NAMES
        return self.annotation_names - known
def has_future_annotations_import(
    module: cst.Module,
) -> bool:
    """
    Check if *module* has 'from __future__ import annotations'.
    """
    for statement in module.body:
        if not isinstance(statement, cst.SimpleStatementLine):
            continue
        for small in statement.body:
            if not isinstance(small, cst.ImportFrom):
                continue
            if not small.module:
                continue
            if get_dotted_name(small.module) != "__future__":
                continue
            if isinstance(small.names, cst.ImportStar):
                # Star imports carry no alias list to inspect.
                continue
            if any(
                alias.name.value == "annotations" for alias in small.names
            ):
                return True
    return False
def add_future_annotations_import(
    module: cst.Module,
) -> cst.Module:
    """
    Ensure ``from __future__ import annotations`` is present when annotations
    reference names that are not otherwise resolvable.

    Returns the module unchanged (same object) when nothing is needed.
    """
    if has_future_annotations_import(module):
        return module
    collector = CSTAnnotationNameCollector()
    module.visit(collector)
    undefined = collector.get_undefined_annotation_names()
    if not undefined:
        # Every annotation name resolves; no forward references to defer.
        return module
    context = CodemodContext()
    AddImportsVisitor.add_needed_import(context, "__future__", "annotations")
    return module.visit(AddImportsVisitor(context))
def fix_forward_references(
    _original_module: cst.Module,
    candidates: list[OptimizationCandidate],
) -> list[OptimizationCandidate]:
    """
    Add a future-annotations import to candidates that need one.
    """
    fixed: list[OptimizationCandidate] = []
    for candidate in candidates:
        try:
            updated = add_future_annotations_import(candidate.cst_module)
        except Exception:
            log.warning("Error fixing forward refs for %s", candidate.id)
            fixed.append(candidate)
            continue
        if updated is candidate.cst_module:
            # Identity means the helper made no change.
            fixed.append(candidate)
        else:
            fixed.append(
                OptimizationCandidate(
                    cst_module=updated,
                    explanation=candidate.explanation,
                    id=candidate.id,
                )
            )
    return fixed
def optimizations_postprocessing_pipeline(
    original_module: cst.Module,
    candidates: list[OptimizationCandidate],
) -> list[OptimizationCandidate]:
    """
    Run the full postprocessing pipeline on optimization candidates.

    Stages run in a fixed order; each stage takes and returns the whole
    candidate list, so later stages see earlier stages' rewrites. The order
    matters: docstrings and comments are repaired before the dedup/equality
    filters so that cosmetic differences do not mask true duplicates.
    """
    original_code = original_module.code
    # Comment-stripped original is computed once here and shared with the
    # comment-cleaning stage instead of being recomputed per candidate.
    original_code_stripped = _strip_comments_from_code(original_code)
    # Restore docstrings the optimizer may have dropped.
    candidates = fix_missing_docstring(original_module, candidates)
    # Remove comments the optimizer introduced that were not in the original.
    candidates = clean_extraneous_comments_pipeline(
        original_module,
        candidates,
        orig_code=original_code,
        orig_code_stripped=original_code_stripped,
    )
    # Add `from __future__ import annotations` where forward refs need it.
    candidates = fix_forward_references(original_module, candidates)
    # Collapse candidates that are AST-identical to each other.
    candidates = deduplicate_optimizations(original_module, candidates)
    # Drop candidates that are identical to the original code.
    candidates = equality_check(
        original_module,
        candidates,
        original_code=original_code,
    )
    candidates = dedup_and_sort_imports(original_module, candidates)
    # Strip markdown fences etc. from the human-readable explanations.
    candidates = cleanup_explanations(original_module, candidates)
    # Reject candidates that introduced `...` placeholders.
    candidates = filter_ellipsis_containing_code(original_module, candidates)
    return candidates

View file

@ -0,0 +1,43 @@
from __future__ import annotations
import ast
import logging
import libcst as cst
log = logging.getLogger(__name__)
def validate_python_syntax(source: str) -> bool:
    """
    Check if *source* is valid Python using both ast and libcst.

    Returns False for syntactically invalid code and for code whose parsed
    module body is empty (e.g. comment-only or whitespace-only input).
    """
    try:
        ast.parse(source)
    except (SyntaxError, ValueError):
        # ast.parse raises ValueError (not SyntaxError) for some rejected
        # inputs, e.g. source strings containing null bytes; a validator
        # should report False for those rather than crash.
        return False
    try:
        module = cst.parse_module(source)
    except cst.ParserSyntaxError:
        return False
    # Treat an empty module body (comments/whitespace only) as invalid.
    return bool(module.body)
def parse_python_or_none(
    source: str,
) -> cst.Module | None:
    """
    Parse *source* into a CST module, or None on failure.

    Also returns None when the parsed module has an empty body.
    """
    try:
        parsed = cst.parse_module(source)
    except cst.ParserSyntaxError:
        log.warning("Failed to parse Python source")
        return None
    # An empty body (comments/whitespace only) is treated as a failure too.
    return parsed if parsed.body else None

View file

@ -0,0 +1,543 @@
from __future__ import annotations
import libcst as cst
import pytest
from codeflash_api.languages.python._cst_utils import (
AnyEllipsisVisitor,
DefinitionRemover,
DepthTrackingMixin,
ImportTrackingVisitor,
build_module_path,
evaluate_expression,
file_path_to_module_path,
find_init,
get_base_class_name,
get_dotted_name,
has_decorator,
make_number_node,
parse_module_to_cst,
unparse_parse_source,
)
from codeflash_api.languages.python._postprocess import (
OptimizationCandidate,
add_future_annotations_import,
cleanup_explanations,
deduplicate_optimizations,
equality_check,
filter_ellipsis_containing_code,
fix_missing_docstring,
has_future_annotations_import,
optimizations_postprocessing_pipeline,
safe_isort,
)
from codeflash_api.languages.python._validator import (
parse_python_or_none,
validate_python_syntax,
)
class TestDepthTrackingMixin:
    """Tests for DepthTrackingMixin."""

    def test_initial_state(self) -> None:
        """
        A freshly constructed mixin reports top level.
        """
        tracker = DepthTrackingMixin()
        assert tracker._is_top_level()

    def test_function_depth(self) -> None:
        """
        Function entry/exit toggles the depth state.
        """
        tracker = DepthTrackingMixin()
        tracker._visit_function()
        assert tracker._is_inside_function()
        assert not tracker._is_top_level()
        tracker._leave_function()
        assert tracker._is_top_level()
class TestFilePathToModulePath:
    """Tests for file_path_to_module_path."""

    def test_unix_path(self) -> None:
        """
        A POSIX-style path converts to a dotted module path.
        """
        result = file_path_to_module_path("path/to/module.py")
        assert result == "path.to.module"

    def test_windows_path(self) -> None:
        """
        A Windows-style path converts to a dotted module path.
        """
        result = file_path_to_module_path("path\\to\\module.py")
        assert result == "path.to.module"
class TestGetDottedName:
    """Tests for get_dotted_name."""

    def test_simple_name(self) -> None:
        """
        A bare Name node yields its value.
        """
        assert get_dotted_name(cst.Name("foo")) == "foo"

    def test_attribute(self) -> None:
        """
        An Attribute chain yields the full dotted path.
        """
        attr = cst.Attribute(value=cst.Name("foo"), attr=cst.Name("bar"))
        assert get_dotted_name(attr) == "foo.bar"

    def test_none(self) -> None:
        """
        Passing None yields the empty string.
        """
        assert get_dotted_name(None) == ""
class TestBuildModulePath:
    """Tests for build_module_path."""

    def test_single_part(self) -> None:
        """
        A single segment builds a plain Name node.
        """
        node = build_module_path("foo")
        assert isinstance(node, cst.Name)
        assert node.value == "foo"

    def test_dotted(self) -> None:
        """
        Multiple segments build a nested Attribute chain.
        """
        node = build_module_path("foo.bar.baz")
        assert isinstance(node, cst.Attribute)
        assert node.attr.value == "baz"
class TestEvaluateExpression:
    """Tests for evaluate_expression."""

    def test_integer(self) -> None:
        """
        A decimal Integer node evaluates to its int value.
        """
        assert evaluate_expression(cst.Integer("42")) == 42

    def test_hex(self) -> None:
        """
        A hexadecimal Integer node evaluates correctly.
        """
        assert evaluate_expression(cst.Integer("0xff")) == 255

    def test_negative(self) -> None:
        """
        Unary minus negates the operand.
        """
        expr = cst.UnaryOperation(
            operator=cst.Minus(),
            expression=cst.Integer("5"),
        )
        assert evaluate_expression(expr) == -5

    def test_binary_add(self) -> None:
        """
        A binary Add evaluates both sides and sums them.
        """
        expr = cst.BinaryOperation(
            left=cst.Integer("3"),
            operator=cst.Add(),
            right=cst.Integer("4"),
        )
        assert evaluate_expression(expr) == 7

    def test_unevaluable(self) -> None:
        """
        A non-literal node evaluates to None.
        """
        assert evaluate_expression(cst.Name("x")) is None
class TestMakeNumberNode:
    """Tests for make_number_node."""

    def test_positive(self) -> None:
        """
        A non-negative value produces an Integer node.
        """
        node = make_number_node(5)
        assert isinstance(node, cst.Integer)
        assert node.value == "5"

    def test_negative(self) -> None:
        """
        A negative value is wrapped in a UnaryOperation.
        """
        assert isinstance(make_number_node(-3), cst.UnaryOperation)
class TestParseModuleToCst:
    """Tests for parse_module_to_cst."""

    def test_caches(self) -> None:
        """
        Repeated parses of identical source return the cached object.
        """
        source = "x = 1"
        first = parse_module_to_cst(source)
        second = parse_module_to_cst(source)
        assert first is second
class TestDefinitionRemover:
    """Tests for DefinitionRemover."""

    def test_removes_function(self) -> None:
        """
        A targeted top-level function is removed and recorded.
        """
        source = "def foo():\n    pass\ndef bar():\n    pass\n"
        remover = DefinitionRemover({"foo"})
        updated = cst.parse_module(source).visit(remover)
        assert "foo" not in updated.code
        assert "bar" in updated.code
        assert "foo" in remover.removed_names

    def test_protected_name(self) -> None:
        """
        A name listed in protected_names survives removal.
        """
        source = "def foo():\n    pass\n"
        remover = DefinitionRemover({"foo"}, protected_names={"foo"})
        updated = cst.parse_module(source).visit(remover)
        assert "foo" in updated.code
class TestImportTrackingVisitor:
    """Tests for ImportTrackingVisitor."""

    def test_tracks_imports(self) -> None:
        """
        Both plain and from-imports contribute names.
        """
        import ast

        visitor = ImportTrackingVisitor()
        visitor.visit(ast.parse("import os\nfrom sys import path"))
        assert "os" in visitor.imported_names
        assert "path" in visitor.imported_names
class TestFindInit:
    """Tests for find_init."""

    def test_found(self) -> None:
        """
        A class defining __init__ is detected.
        """
        import ast

        tree = ast.parse("class Foo:\n    def __init__(self):\n        pass\n")
        assert find_init(tree)

    def test_not_found(self) -> None:
        """
        A class without __init__ is not detected.
        """
        import ast

        tree = ast.parse("class Foo:\n    def bar(self):\n        pass\n")
        assert not find_init(tree)
class TestValidatePythonSyntax:
    """Tests for validate_python_syntax."""

    def test_valid(self) -> None:
        """
        Well-formed source validates.
        """
        assert validate_python_syntax("x = 1\n")

    def test_invalid(self) -> None:
        """
        Malformed source does not validate.
        """
        assert not validate_python_syntax("def (:\n")

    def test_empty_body(self) -> None:
        """
        Comment-only source does not validate.
        """
        assert not validate_python_syntax("# just a comment\n")
class TestParsePythonOrNone:
    """Tests for parse_python_or_none."""

    def test_valid(self) -> None:
        """
        Well-formed source parses to a Module.
        """
        parsed = parse_python_or_none("x = 1\n")
        # isinstance also rules out None, covering both original checks.
        assert isinstance(parsed, cst.Module)

    def test_invalid(self) -> None:
        """
        Malformed source parses to None.
        """
        assert parse_python_or_none("def (:\n") is None
class TestOptimizationCandidate:
    """Tests for OptimizationCandidate."""

    def test_frozen(self) -> None:
        """
        Assigning to a field of the frozen dataclass raises.
        """
        candidate = OptimizationCandidate(
            cst_module=cst.parse_module("x = 1"),
            explanation="faster",
            id="test",
        )
        # FrozenInstanceError subclasses AttributeError.
        with pytest.raises(AttributeError):
            candidate.id = "changed"
class TestSafeIsort:
    """Tests for safe_isort."""

    def test_sorts(self) -> None:
        """
        Out-of-order imports come back sorted.
        """
        sorted_code = safe_isort("import sys\nimport os\n")
        assert sorted_code.index("os") < sorted_code.index("sys")

    def test_invalid_returns_original(self) -> None:
        """
        Unparseable input comes back unchanged.
        """
        broken = "not valid python {{{{"
        assert safe_isort(broken) == broken
class TestDeduplicateOptimizations:
    """Tests for deduplicate_optimizations."""

    def test_removes_duplicates(self) -> None:
        """
        AST-equal candidates collapse to a single entry.
        """
        original = cst.parse_module("x = 1\n")
        first = OptimizationCandidate(
            cst_module=cst.parse_module("x = 1\n"),
            explanation="a",
            id="1",
        )
        # Same AST as `first`, different whitespace and metadata.
        second = OptimizationCandidate(
            cst_module=cst.parse_module("x=1\n"),
            explanation="b",
            id="2",
        )
        deduped = deduplicate_optimizations(original, [first, second])
        assert len(deduped) == 1
class TestEqualityCheck:
    """Tests for equality_check."""

    def test_filters_identical(self) -> None:
        """
        A candidate equal to the original is dropped.
        """
        original = cst.parse_module("x = 1\n")
        same = OptimizationCandidate(
            cst_module=cst.parse_module("x = 1\n"),
            explanation="same",
            id="1",
        )
        assert len(equality_check(original, [same])) == 0

    def test_keeps_different(self) -> None:
        """
        A candidate that differs from the original survives.
        """
        original = cst.parse_module("x = 1\n")
        changed = OptimizationCandidate(
            cst_module=cst.parse_module("x = 2\n"),
            explanation="different",
            id="1",
        )
        assert len(equality_check(original, [changed])) == 1
class TestFilterEllipsis:
    """Tests for filter_ellipsis_containing_code."""

    def test_filters_introduced_ellipsis(self) -> None:
        """
        A candidate that adds an ellipsis is dropped.
        """
        original = cst.parse_module("x = 1\n")
        bad = OptimizationCandidate(
            cst_module=cst.parse_module("x = ...\n"),
            explanation="bad",
            id="1",
        )
        assert len(filter_ellipsis_containing_code(original, [bad])) == 0

    def test_keeps_when_original_has_ellipsis(self) -> None:
        """
        Candidates survive when the original already used ellipsis.
        """
        original = cst.parse_module("x = ...\n")
        ok = OptimizationCandidate(
            cst_module=cst.parse_module("x = ...\n"),
            explanation="ok",
            id="1",
        )
        assert len(filter_ellipsis_containing_code(original, [ok])) == 1
class TestFixMissingDocstring:
    """Tests for fix_missing_docstring."""

    def test_restores_docstring(self) -> None:
        """
        A docstring dropped by the optimizer is put back.
        """
        original = cst.parse_module(
            'def foo():\n    """My docstring."""\n    pass\n'
        )
        candidate = OptimizationCandidate(
            cst_module=cst.parse_module("def foo():\n    pass\n"),
            explanation="faster",
            id="1",
        )
        fixed = fix_missing_docstring(original, [candidate])
        assert "My docstring" in fixed[0].cst_module.code
class TestHasFutureAnnotationsImport:
    """Tests for has_future_annotations_import."""

    def test_present(self) -> None:
        """
        The import is detected when it exists.
        """
        parsed = cst.parse_module("from __future__ import annotations\n")
        assert has_future_annotations_import(parsed)

    def test_absent(self) -> None:
        """
        False is reported when the import is missing.
        """
        assert not has_future_annotations_import(cst.parse_module("x = 1\n"))
class TestAddFutureAnnotationsImport:
    """Tests for add_future_annotations_import."""

    def test_adds_when_needed(self) -> None:
        """
        Undefined annotation names trigger the import.
        """
        parsed = cst.parse_module("def foo(x: MyType) -> None:\n    pass\n")
        updated = add_future_annotations_import(parsed)
        assert "from __future__ import annotations" in updated.code

    def test_skips_when_present(self) -> None:
        """
        An existing import leaves the module untouched.
        """
        source = "from __future__ import annotations\ndef foo(x: MyType) -> None:\n    pass\n"
        parsed = cst.parse_module(source)
        assert add_future_annotations_import(parsed) is parsed

    def test_skips_when_no_undefined(self) -> None:
        """
        Fully resolved annotations leave the module untouched.
        """
        parsed = cst.parse_module("def foo(x: int) -> None:\n    pass\n")
        assert add_future_annotations_import(parsed) is parsed
class TestCleanupExplanations:
    """Tests for cleanup_explanations."""

    def test_removes_code_block(self) -> None:
        """
        Fenced markdown blocks are stripped from explanations.
        """
        parsed = cst.parse_module("x = 1\n")
        candidate = OptimizationCandidate(
            cst_module=parsed,
            explanation="```python\nx = 1\n```",
            id="1",
        )
        cleaned = cleanup_explanations(parsed, [candidate])
        assert "```" not in cleaned[0].explanation
class TestPostprocessingPipeline:
    """Tests for optimizations_postprocessing_pipeline."""

    def test_empty_candidates(self) -> None:
        """
        No candidates in, no candidates out.
        """
        parsed = cst.parse_module("x = 1\n")
        assert optimizations_postprocessing_pipeline(parsed, []) == []

    def test_filters_identical(self) -> None:
        """
        A candidate identical to the original is removed.
        """
        parsed = cst.parse_module("x = 1\n")
        duplicate = OptimizationCandidate(
            cst_module=cst.parse_module("x = 1\n"),
            explanation="same",
            id="1",
        )
        survivors = optimizations_postprocessing_pipeline(parsed, [duplicate])
        assert len(survivors) == 0

    def test_keeps_valid_optimization(self) -> None:
        """
        A genuinely different candidate passes through.
        """
        parsed = cst.parse_module("x = 1 + 1\n")
        better = OptimizationCandidate(
            cst_module=cst.parse_module("x = 2\n"),
            explanation="precomputed",
            id="1",
        )
        survivors = optimizations_postprocessing_pipeline(parsed, [better])
        assert len(survivors) == 1

View file

@ -81,6 +81,19 @@ ignore = [
"S104", # binding to 0.0.0.0 is the dev default, overridden in production
"S105", # dev placeholder secret_key, overridden via env var in production
]
"packages/codeflash-api/src/codeflash_api/languages/python/_cst_utils.py" = [
"N802", # libcst visitor methods must match visit_NodeName / leave_NodeName
"PLR2004", # magic values in faithfully ported ellipsis depth checks
]
"packages/codeflash-api/src/codeflash_api/languages/python/_postprocess.py" = [
"BLE001", # broad except is intentional — postprocessing must not crash the pipeline
"C901", # faithfully ported recursive annotation walker and comment cleaner
"N802", # libcst visitor methods must match visit_NodeName / leave_NodeName
"PERF203", # try/except in loop is intentional — each candidate processed independently
"PLR0912", # faithfully ported comment cleaning logic
"PLR0915", # faithfully ported comment cleaning logic
"TRY300", # faithfully ported control flow
]
"packages/codeflash-api/src/codeflash_api/diff/_v4a.py" = [
"C901", # peek_next_section and _parse_update_file_sections faithfully ported from Django
"PLR0912", # too many branches in faithfully ported _parse_update_file_sections