codeflash-agent/packages/codeflash-api/tests/test_language_python.py
Kevin Turcios 3a07579bb0 Raise codeflash-api test coverage from 81% to 92%
Add 182 new tests across optimize, V4A diff, CST utils, and postprocess
modules. Key coverage improvements:
- optimize/_pipeline.py: 29% → 97%
- optimize/_router.py: 40% → 93%
- diff/_v4a.py: 40% → 97%
- languages/python/_cst_utils.py: 67% → 96%
- languages/python/_postprocess.py: 67% → 87%

Also apply ruff format to 5 files that had formatting drift.
2026-04-22 23:39:54 -05:00

1512 lines
50 KiB
Python

from __future__ import annotations
import libcst as cst
import pytest
from codeflash_api.languages.python._cst_utils import (
AnyEllipsisVisitor,
DefinitionRemover,
DepthTrackingMixin,
ImportTrackingVisitor,
InvalidEllipsisVisitor,
any_ellipsis_in_cst,
build_module_path,
collect_imported_names_from_import,
collect_imported_names_from_import_from,
compare_unparsed_ast_to_source,
ellipsis_in_cst_not_types,
evaluate_expression,
extract_import_info,
extract_imports_from_import,
extract_imports_from_import_from,
file_path_to_module_path,
find_init,
get_base_class_name,
get_dotted_name,
has_decorator,
make_number_node,
parse_module_to_cst,
unparse_parse_source,
)
from codeflash_api.languages.python._postprocess import (
CSTAnnotationNameCollector,
DocstringTransformer,
DocstringVisitor,
EllipsisContainingCodeVisitor,
OptimizationCandidate,
_strip_comments_from_code,
add_future_annotations_import,
clean_extraneous_comments,
clean_extraneous_comments_pipeline,
cleanup_explanations,
dedup_and_sort_imports,
deduplicate_optimizations,
equality_check,
extract_names_from_cst_annotation,
filter_ellipsis_containing_code,
fix_forward_references,
fix_missing_docstring,
has_future_annotations_import,
optimizations_postprocessing_pipeline,
safe_isort,
)
from codeflash_api.languages.python._validator import (
parse_python_or_none,
validate_python_syntax,
)
class TestDepthTrackingMixin:
    """Tests for DepthTrackingMixin."""

    def test_initial_state(self) -> None:
        """
        A freshly created mixin reports top-level scope.
        """
        tracker = DepthTrackingMixin()
        assert tracker._is_top_level()

    def test_function_depth(self) -> None:
        """
        Visiting a function nests the scope; leaving it restores top level.
        """
        tracker = DepthTrackingMixin()

        tracker._visit_function()
        assert not tracker._is_top_level()
        assert tracker._is_inside_function()

        tracker._leave_function()
        assert tracker._is_top_level()
class TestFilePathToModulePath:
    """Tests for file_path_to_module_path."""

    def test_unix_path(self) -> None:
        """
        POSIX separators turn into dots and the ``.py`` suffix is dropped.
        """
        expected = "path.to.module"
        assert expected == file_path_to_module_path("path/to/module.py")

    def test_windows_path(self) -> None:
        """
        Windows separators turn into dots and the ``.py`` suffix is dropped.
        """
        expected = "path.to.module"
        assert expected == file_path_to_module_path("path\\to\\module.py")
class TestGetDottedName:
    """Tests for get_dotted_name."""

    def test_simple_name(self) -> None:
        """
        A bare Name yields its identifier.
        """
        assert "foo" == get_dotted_name(cst.Name("foo"))

    def test_attribute(self) -> None:
        """
        An Attribute yields the dot-joined path.
        """
        dotted = cst.Attribute(value=cst.Name("foo"), attr=cst.Name("bar"))
        assert "foo.bar" == get_dotted_name(dotted)

    def test_none(self) -> None:
        """
        A missing node yields the empty string.
        """
        assert "" == get_dotted_name(None)
class TestBuildModulePath:
    """Tests for build_module_path."""

    def test_single_part(self) -> None:
        """
        Single name becomes cst.Name.
        """
        result = build_module_path("foo")
        assert isinstance(result, cst.Name)
        assert "foo" == result.value

    def test_dotted(self) -> None:
        """
        Dotted path becomes a cst.Attribute chain that renders back to the
        full dotted path.
        """
        result = build_module_path("foo.bar.baz")
        assert isinstance(result, cst.Attribute)
        assert "baz" == result.attr.value
        # Checking only the outermost ``attr`` would let a chain with
        # dropped or reordered inner components pass, so render the node.
        assert "foo.bar.baz" == cst.Module(body=[]).code_for_node(result)
class TestEvaluateExpression:
    """Tests for evaluate_expression."""

    def test_integer(self) -> None:
        """
        A decimal Integer literal evaluates to its numeric value.
        """
        literal = cst.Integer("42")
        assert 42 == evaluate_expression(literal)

    def test_hex(self) -> None:
        """
        A hexadecimal Integer literal honours its base prefix.
        """
        literal = cst.Integer("0xff")
        assert 255 == evaluate_expression(literal)

    def test_negative(self) -> None:
        """
        Unary minus negates the evaluated operand.
        """
        negated = cst.UnaryOperation(
            operator=cst.Minus(), expression=cst.Integer("5")
        )
        assert -5 == evaluate_expression(negated)

    def test_binary_add(self) -> None:
        """
        Adding two literals yields their sum.
        """
        addition = cst.BinaryOperation(
            left=cst.Integer("3"), operator=cst.Add(), right=cst.Integer("4")
        )
        assert 7 == evaluate_expression(addition)

    def test_unevaluable(self) -> None:
        """
        A bare Name cannot be evaluated, so None comes back.
        """
        assert evaluate_expression(cst.Name("x")) is None
class TestMakeNumberNode:
    """Tests for make_number_node."""

    def test_positive(self) -> None:
        """
        Positive value becomes Integer.
        """
        node = make_number_node(5)
        assert isinstance(node, cst.Integer)
        assert "5" == node.value

    def test_negative(self) -> None:
        """
        Negative value becomes a unary-minus UnaryOperation that evaluates
        back to the original number.
        """
        node = make_number_node(-3)
        assert isinstance(node, cst.UnaryOperation)
        assert isinstance(node.operator, cst.Minus)
        # Round-trip through evaluate_expression so the operand is checked
        # as well, not just the outer node type.
        assert -3 == evaluate_expression(node)
class TestParseModuleToCst:
    """Tests for parse_module_to_cst."""

    def test_caches(self) -> None:
        """
        Parsing identical source twice returns the very same Module object.
        """
        source = "x = 1"
        first = parse_module_to_cst(source)
        second = parse_module_to_cst(source)
        assert first is second
class TestDefinitionRemover:
    """Tests for DefinitionRemover."""

    def test_removes_function(self) -> None:
        """
        A named top-level function is dropped and reported as removed.
        """
        source = "def foo():\n pass\ndef bar():\n pass\n"
        tree = cst.parse_module(source)
        remover = DefinitionRemover({"foo"})
        pruned = tree.visit(remover)
        assert "foo" not in pruned.code
        assert "bar" in pruned.code
        assert "foo" in remover.removed_names

    def test_protected_name(self) -> None:
        """
        Protection wins over a removal request for the same name.
        """
        tree = cst.parse_module("def foo():\n pass\n")
        kept = tree.visit(
            DefinitionRemover({"foo"}, protected_names={"foo"})
        )
        assert "foo" in kept.code
class TestImportTrackingVisitor:
    """Tests for ImportTrackingVisitor."""

    def test_tracks_imports(self) -> None:
        """
        Names bound by both ``import`` and ``from ... import`` are recorded.
        """
        import ast

        parsed = ast.parse("import os\nfrom sys import path")
        tracker = ImportTrackingVisitor()
        tracker.visit(parsed)
        tracked = tracker.imported_names
        assert "os" in tracked
        assert "path" in tracked
class TestFindInit:
    """Tests for find_init."""

    def test_found(self) -> None:
        """
        __init__ in a class is found.
        """
        import ast

        # The method body must be indented deeper than its ``def`` line;
        # otherwise ast.parse raises IndentationError before find_init runs.
        source = (
            "class Foo:\n"
            "    def __init__(self):\n"
            "        pass\n"
        )
        tree = ast.parse(source)
        assert find_init(tree)

    def test_not_found(self) -> None:
        """
        No __init__ returns False.
        """
        import ast

        source = (
            "class Foo:\n"
            "    def bar(self):\n"
            "        pass\n"
        )
        tree = ast.parse(source)
        assert not find_init(tree)
class TestValidatePythonSyntax:
    """Tests for validate_python_syntax."""

    def test_valid(self) -> None:
        """
        A well-formed assignment is accepted.
        """
        source = "x = 1\n"
        assert validate_python_syntax(source)

    def test_invalid(self) -> None:
        """
        Malformed source is rejected.
        """
        source = "def (:\n"
        assert not validate_python_syntax(source)

    def test_empty_body(self) -> None:
        """
        Source made of nothing but a comment is rejected.
        """
        source = "# just a comment\n"
        assert not validate_python_syntax(source)
class TestParsePythonOrNone:
    """Tests for parse_python_or_none."""

    def test_valid(self) -> None:
        """
        Well-formed source is parsed into a cst.Module.
        """
        parsed = parse_python_or_none("x = 1\n")
        assert parsed is not None
        assert isinstance(parsed, cst.Module)

    def test_invalid(self) -> None:
        """
        Malformed source yields None instead of raising.
        """
        parsed = parse_python_or_none("def (:\n")
        assert parsed is None
class TestOptimizationCandidate:
    """Tests for OptimizationCandidate."""

    def test_frozen(self) -> None:
        """
        Assigning to a field of an existing candidate is rejected.
        """
        candidate = OptimizationCandidate(
            id="test",
            explanation="faster",
            cst_module=cst.parse_module("x = 1"),
        )
        with pytest.raises(AttributeError):
            candidate.id = "changed"
class TestSafeIsort:
    """Tests for safe_isort."""

    def test_sorts(self) -> None:
        """
        Out-of-order imports come back alphabetised.
        """
        sorted_code = safe_isort("import sys\nimport os\n")
        os_pos = sorted_code.index("os")
        sys_pos = sorted_code.index("sys")
        assert os_pos < sys_pos

    def test_invalid_returns_original(self) -> None:
        """
        Unparseable input is handed back untouched.
        """
        broken = "not valid python {{{{"
        assert broken == safe_isort(broken)
class TestDeduplicateOptimizations:
    """Tests for deduplicate_optimizations."""

    def test_removes_duplicates(self) -> None:
        """
        Candidates that differ only in whitespace collapse to a single one.
        """
        original = cst.parse_module("x = 1\n")
        candidates = [
            OptimizationCandidate(
                cst_module=cst.parse_module(source),
                explanation=explanation,
                id=candidate_id,
            )
            for source, explanation, candidate_id in (
                ("x = 1\n", "a", "1"),
                ("x=1\n", "b", "2"),
            )
        ]
        deduped = deduplicate_optimizations(original, candidates)
        assert 1 == len(deduped)
class TestEqualityCheck:
    """Tests for equality_check."""

    @staticmethod
    def _candidate(source: str, explanation: str) -> OptimizationCandidate:
        """Build a candidate with id "1" from raw source text."""
        return OptimizationCandidate(
            cst_module=cst.parse_module(source),
            explanation=explanation,
            id="1",
        )

    def test_filters_identical(self) -> None:
        """
        A candidate whose code matches the original is dropped.
        """
        original = cst.parse_module("x = 1\n")
        kept = equality_check(original, [self._candidate("x = 1\n", "same")])
        assert 0 == len(kept)

    def test_keeps_different(self) -> None:
        """
        A candidate with different code survives.
        """
        original = cst.parse_module("x = 1\n")
        changed = self._candidate("x = 2\n", "different")
        kept = equality_check(original, [changed])
        assert 1 == len(kept)
class TestFilterEllipsis:
    """Tests for filter_ellipsis_containing_code."""

    @staticmethod
    def _candidate(source: str, explanation: str) -> OptimizationCandidate:
        """Build a candidate with id "1" from raw source text."""
        return OptimizationCandidate(
            cst_module=cst.parse_module(source),
            explanation=explanation,
            id="1",
        )

    def test_filters_introduced_ellipsis(self) -> None:
        """
        A candidate that adds ``...`` where the original had none is dropped.
        """
        original = cst.parse_module("x = 1\n")
        survivors = filter_ellipsis_containing_code(
            original, [self._candidate("x = ...\n", "bad")]
        )
        assert 0 == len(survivors)

    def test_keeps_when_original_has_ellipsis(self) -> None:
        """
        Ellipsis already present in the original does not trigger filtering.
        """
        original = cst.parse_module("x = ...\n")
        survivors = filter_ellipsis_containing_code(
            original, [self._candidate("x = ...\n", "ok")]
        )
        assert 1 == len(survivors)
class TestFixMissingDocstring:
    """Tests for fix_missing_docstring."""

    def test_restores_docstring(self) -> None:
        """
        A function docstring stripped by the optimizer is put back.
        """
        before = cst.parse_module('def foo():\n """My docstring."""\n pass\n')
        stripped = cst.parse_module("def foo():\n pass\n")
        candidate = OptimizationCandidate(
            cst_module=stripped, explanation="faster", id="1"
        )
        fixed = fix_missing_docstring(before, [candidate])[0]
        assert "My docstring" in fixed.cst_module.code
class TestHasFutureAnnotationsImport:
    """Tests for has_future_annotations_import."""

    def test_present(self) -> None:
        """
        Detects an existing ``from __future__ import annotations``.
        """
        source = "from __future__ import annotations\n"
        assert has_future_annotations_import(cst.parse_module(source))

    def test_absent(self) -> None:
        """
        Reports False when the future import is missing.
        """
        source = "x = 1\n"
        assert not has_future_annotations_import(cst.parse_module(source))
class TestAddFutureAnnotationsImport:
    """Tests for add_future_annotations_import."""

    def test_adds_when_needed(self) -> None:
        """
        An annotation naming an undefined type triggers insertion of the
        future import.
        """
        before = cst.parse_module("def foo(x: MyType) -> None:\n pass\n")
        after = add_future_annotations_import(before)
        assert "from __future__ import annotations" in after.code

    def test_skips_when_present(self) -> None:
        """
        A module that already carries the future import is returned as-is.
        """
        before = cst.parse_module(
            "from __future__ import annotations\n"
            "def foo(x: MyType) -> None:\n pass\n"
        )
        assert add_future_annotations_import(before) is before

    def test_skips_when_no_undefined(self) -> None:
        """
        Annotations using only builtins leave the module object untouched.
        """
        before = cst.parse_module("def foo(x: int) -> None:\n pass\n")
        assert add_future_annotations_import(before) is before
class TestCleanupExplanations:
    """Tests for cleanup_explanations."""

    def test_removes_code_block(self) -> None:
        """
        Fenced Markdown code is stripped out of the explanation text.
        """
        tree = cst.parse_module("x = 1\n")
        fenced = "```python\nx = 1\n```"
        candidate = OptimizationCandidate(
            cst_module=tree, explanation=fenced, id="1"
        )
        cleaned = cleanup_explanations(tree, [candidate])
        assert "```" not in cleaned[0].explanation
class TestPostprocessingPipeline:
    """Tests for optimizations_postprocessing_pipeline."""

    @staticmethod
    def _candidate(source: str, explanation: str) -> OptimizationCandidate:
        """Build a candidate with id "1" from raw source text."""
        return OptimizationCandidate(
            cst_module=cst.parse_module(source),
            explanation=explanation,
            id="1",
        )

    def test_empty_candidates(self) -> None:
        """
        No candidates in means no candidates out.
        """
        original = cst.parse_module("x = 1\n")
        assert [] == optimizations_postprocessing_pipeline(original, [])

    def test_filters_identical(self) -> None:
        """
        The pipeline discards a candidate that matches the original.
        """
        original = cst.parse_module("x = 1\n")
        survivors = optimizations_postprocessing_pipeline(
            original, [self._candidate("x = 1\n", "same")]
        )
        assert 0 == len(survivors)

    def test_keeps_valid_optimization(self) -> None:
        """
        A genuinely different candidate makes it through the pipeline.
        """
        original = cst.parse_module("x = 1 + 1\n")
        survivors = optimizations_postprocessing_pipeline(
            original, [self._candidate("x = 2\n", "precomputed")]
        )
        assert 1 == len(survivors)
# ---------------------------------------------------------------------------
# _cst_utils.py — additional coverage
# ---------------------------------------------------------------------------
class TestDepthTrackingMixinClassDepth:
    """Tests for class depth tracking in DepthTrackingMixin."""

    def test_class_depth(self) -> None:
        """Class scope is reported while inside a class and cleared on exit."""
        tracker = DepthTrackingMixin()
        assert not tracker._is_inside_class()

        tracker._visit_class()
        assert tracker._is_inside_class()
        assert not tracker._is_top_level_class()

        tracker._leave_class()
        assert tracker._is_top_level_class()

    def test_top_level_function_check(self) -> None:
        """_is_top_level_function is False only while inside a function."""
        tracker = DepthTrackingMixin()
        assert tracker._is_top_level_function()

        tracker._visit_function()
        assert not tracker._is_top_level_function()

        tracker._leave_function()
        assert tracker._is_top_level_function()

    def test_top_level_class_with_function(self) -> None:
        """_is_top_level_class is False while inside a function."""
        tracker = DepthTrackingMixin()
        tracker._visit_function()
        assert not tracker._is_top_level_class()
class TestExtractImportInfo:
    """Tests for extract_import_info."""

    def test_simple_name(self) -> None:
        """Simple import alias extracts name correctly."""
        alias = cst.ImportAlias(name=cst.Name("os"))
        available, module, original = extract_import_info(alias)
        assert "os" == available
        assert "" == module
        assert "os" == original

    def test_dotted_name_without_module(self) -> None:
        """Dotted import without module returns first component as available."""
        alias = cst.ImportAlias(
            name=cst.Attribute(value=cst.Name("os"), attr=cst.Name("path"))
        )
        available, module, original = extract_import_info(alias)
        assert "os" == available
        assert "" == module
        assert "os.path" == original

    def test_with_asname(self) -> None:
        """Import alias with asname uses the alias."""
        alias = cst.ImportAlias(
            name=cst.Name("os"),
            asname=cst.AsName(
                whitespace_before_as=cst.SimpleWhitespace(" "),
                whitespace_after_as=cst.SimpleWhitespace(" "),
                name=cst.Name("operating_system"),
            ),
        )
        available, module, original = extract_import_info(alias)
        assert "operating_system" == available
        # No module_name was passed, so the module part stays empty, just
        # as it does for the un-aliased plain import above.
        assert "" == module
        assert "os" == original

    def test_with_module_name(self) -> None:
        """When module_name is given, available_name is the original name."""
        alias = cst.ImportAlias(name=cst.Name("path"))
        available, module, original = extract_import_info(alias, "os")
        assert "path" == available
        assert "os" == module
        assert "path" == original
class TestExtractImportsFromImport:
    """Tests for extract_imports_from_import."""

    def test_simple_import(self) -> None:
        """Every module in a comma-separated ``import`` is reported."""
        statement = cst.parse_statement("import os, sys\n")
        assert isinstance(statement, cst.SimpleStatementLine)
        first_small = statement.body[0]
        assert isinstance(first_small, cst.Import)
        extracted = extract_imports_from_import(first_small)
        for name in ("os", "sys"):
            assert name in extracted
class TestExtractImportsFromImportFrom:
    """Tests for extract_imports_from_import_from."""

    @staticmethod
    def _import_from(source: str) -> cst.ImportFrom:
        """Parse a one-line ``from ... import ...`` into its ImportFrom node."""
        statement = cst.parse_statement(source)
        assert isinstance(statement, cst.SimpleStatementLine)
        node = statement.body[0]
        assert isinstance(node, cst.ImportFrom)
        return node

    def test_from_import(self) -> None:
        """Names from 'from X import Y' are extracted, keyed by name."""
        extracted = extract_imports_from_import_from(
            self._import_from("from os import path, getcwd\n")
        )
        assert "path" in extracted
        assert "getcwd" in extracted
        assert ("os", "path") == extracted["path"]

    def test_import_star_returns_empty(self) -> None:
        """A wildcard import yields an empty mapping."""
        extracted = extract_imports_from_import_from(
            self._import_from("from os import *\n")
        )
        assert {} == extracted
class TestCollectImportedNames:
    """Tests for collect_imported_names_from_import and _from_import_from."""

    def test_collect_from_import(self) -> None:
        """Collects available names from import statement."""
        node = cst.parse_statement("import os, sys\n")
        # parse_statement may return a compound statement, whose ``body`` is
        # not an indexable sequence; narrow first, as the sibling tests do.
        assert isinstance(node, cst.SimpleStatementLine)
        import_node = node.body[0]
        assert isinstance(import_node, cst.Import)
        result = collect_imported_names_from_import(import_node)
        assert {"os", "sys"} == result

    def test_collect_from_import_from(self) -> None:
        """Collects available names from from-import statement."""
        node = cst.parse_statement("from os import path, getcwd\n")
        assert isinstance(node, cst.SimpleStatementLine)
        import_from = node.body[0]
        assert isinstance(import_from, cst.ImportFrom)
        result = collect_imported_names_from_import_from(import_from)
        assert {"path", "getcwd"} == result

    def test_collect_from_import_from_star(self) -> None:
        """ImportStar from-import returns empty set."""
        node = cst.parse_statement("from os import *\n")
        assert isinstance(node, cst.SimpleStatementLine)
        import_from = node.body[0]
        assert isinstance(import_from, cst.ImportFrom)
        result = collect_imported_names_from_import_from(import_from)
        assert set() == result
class TestDefinitionRemoverAdvanced:
    """Advanced tests for DefinitionRemover."""

    def test_removes_class(self) -> None:
        """Top-level class is removed by name."""
        code = "class Foo:\n pass\nclass Bar:\n pass\n"
        module = cst.parse_module(code)
        remover = DefinitionRemover({"Foo"})
        result = module.visit(remover)
        assert "Foo" not in result.code
        assert "Bar" in result.code
        assert "Foo" in remover.removed_names

    def test_removes_method_causes_class_removal(self) -> None:
        """Removing a qualified method removes the enclosing class."""
        # The method body must sit deeper than its ``def`` line, otherwise
        # cst.parse_module rejects the snippet before the remover runs.
        code = "class Foo:\n    def bar(self):\n        pass\n"
        module = cst.parse_module(code)
        remover = DefinitionRemover({"Foo.bar"})
        result = module.visit(remover)
        assert "Foo" not in result.code
        assert "Foo" in remover.removed_names

    def test_nested_function_not_removed(self) -> None:
        """A function nested inside another is not removed as top-level."""
        code = (
            "def outer():\n"
            "    def foo():\n"
            "        pass\n"
            "    return foo()\n"
        )
        module = cst.parse_module(code)
        remover = DefinitionRemover({"foo"})
        result = module.visit(remover)
        assert "foo" in result.code

    def test_protected_class_not_removed(self) -> None:
        """Protected class is not removed."""
        code = "class Foo:\n pass\n"
        module = cst.parse_module(code)
        remover = DefinitionRemover({"Foo"}, protected_names={"Foo"})
        result = module.visit(remover)
        assert "Foo" in result.code
class TestImportTrackingVisitorAdvanced:
    """Additional tests for ImportTrackingVisitor."""

    @staticmethod
    def _visit(source: str) -> ImportTrackingVisitor:
        """Run a fresh ImportTrackingVisitor over the parsed ``source``."""
        import ast

        tracker = ImportTrackingVisitor()
        tracker.visit(ast.parse(source))
        return tracker

    def test_tracks_import_with_alias(self) -> None:
        """``import X as Y`` records Y rather than X."""
        names = self._visit("import os as operating_system").imported_names
        assert "operating_system" in names
        assert "os" not in names

    def test_tracks_from_import_with_alias(self) -> None:
        """``from M import X as Y`` records Y rather than X."""
        names = self._visit("from os import path as p").imported_names
        assert "p" in names
        assert "path" not in names

    def test_skips_star_import(self) -> None:
        """A wildcard import records nothing."""
        names = self._visit("from os import *").imported_names
        assert 0 == len(names)

    def test_dotted_import_uses_first_component(self) -> None:
        """``import os.path`` records the leading package ``os``."""
        names = self._visit("import os.path").imported_names
        assert "os" in names
class TestGetBaseClassName:
    """Tests for get_base_class_name."""

    def test_simple_name(self) -> None:
        """A plain Name base is returned verbatim."""
        base = cst.Arg(value=cst.Name("Base"))
        assert "Base" == get_base_class_name(base)

    def test_attribute_base(self) -> None:
        """For a dotted base only the final attribute name is returned."""
        qualified = cst.Attribute(
            value=cst.Name("mod"), attr=cst.Name("Base")
        )
        assert "Base" == get_base_class_name(cst.Arg(value=qualified))

    def test_non_name_returns_none(self) -> None:
        """A literal used as a base expression yields None."""
        base = cst.Arg(value=cst.Integer("42"))
        assert get_base_class_name(base) is None
class TestHasDecorator:
    """Tests for has_decorator."""

    @staticmethod
    def _first_class(source: str) -> cst.ClassDef:
        """Parse ``source`` and return its first statement as a ClassDef."""
        first = cst.parse_module(source).body[0]
        assert isinstance(first, cst.ClassDef)
        return first

    def test_simple_decorator(self) -> None:
        """A bare ``@dataclass`` is detected."""
        klass = self._first_class("@dataclass\nclass Foo:\n pass\n")
        assert has_decorator(klass, "dataclass")

    def test_call_decorator(self) -> None:
        """A called ``@dataclass(...)`` is detected too."""
        klass = self._first_class(
            "@dataclass(frozen=True)\nclass Foo:\n pass\n"
        )
        assert has_decorator(klass, "dataclass")

    def test_missing_decorator(self) -> None:
        """An undecorated class reports False."""
        klass = self._first_class("class Foo:\n pass\n")
        assert not has_decorator(klass, "dataclass")

    def test_wrong_decorator(self) -> None:
        """A class carrying some other decorator reports False."""
        klass = self._first_class("@staticmethod\nclass Foo:\n pass\n")
        assert not has_decorator(klass, "dataclass")
class TestEvaluateExpressionAdvanced:
    """Additional tests for evaluate_expression."""

    @staticmethod
    def _binop(
        left: str, operator: cst.BaseBinaryOp, right: str
    ) -> cst.BinaryOperation:
        """Build ``left <operator> right`` from two integer literals."""
        return cst.BinaryOperation(
            left=cst.Integer(left), operator=operator, right=cst.Integer(right)
        )

    def test_float(self) -> None:
        """A Float literal is truncated to an int."""
        assert 3 == evaluate_expression(cst.Float("3.9"))

    def test_binary_multiply(self) -> None:
        """Multiplication of two literals yields their product."""
        assert 12 == evaluate_expression(self._binop("3", cst.Multiply(), "4"))

    def test_binary_subtract(self) -> None:
        """Subtraction of two literals yields their difference."""
        assert 7 == evaluate_expression(self._binop("10", cst.Subtract(), "3"))

    def test_binary_power(self) -> None:
        """Exponentiation of two literals yields the power."""
        assert 8 == evaluate_expression(self._binop("2", cst.Power(), "3"))

    def test_unsupported_binary_op(self) -> None:
        """Floor division is not supported and yields None."""
        floor_div = self._binop("10", cst.FloorDivide(), "3")
        assert evaluate_expression(floor_div) is None

    def test_nested_negative_unevaluable(self) -> None:
        """Negating an unevaluable operand still yields None."""
        negated_name = cst.UnaryOperation(
            operator=cst.Minus(), expression=cst.Name("x")
        )
        assert evaluate_expression(negated_name) is None
class TestEllipsisInCstNotTypes:
    """Tests for ellipsis_in_cst_not_types."""

    def test_no_ellipsis(self) -> None:
        """Code free of ``...`` is not flagged."""
        tree = cst.parse_module("x = 1\n")
        assert not ellipsis_in_cst_not_types(tree)

    def test_ellipsis_in_assignment(self) -> None:
        """``...`` used as an assigned value is flagged."""
        tree = cst.parse_module("x = ...\n")
        assert ellipsis_in_cst_not_types(tree)

    def test_ellipsis_in_type_annotation(self) -> None:
        """``...`` inside a type subscript is tolerated."""
        tree = cst.parse_module("x: tuple[int, ...] = (1,)\n")
        assert not ellipsis_in_cst_not_types(tree)

    def test_ellipsis_as_function_body(self) -> None:
        """``...`` serving as a function body placeholder is tolerated."""
        tree = cst.parse_module("def foo():\n ...\n")
        assert not ellipsis_in_cst_not_types(tree)
class TestAnyEllipsisInCst:
    """Tests for any_ellipsis_in_cst."""

    def test_no_ellipsis(self) -> None:
        """Code free of ``...`` reports False."""
        tree = cst.parse_module("x = 1\n")
        assert not any_ellipsis_in_cst(tree)

    def test_has_ellipsis(self) -> None:
        """Any occurrence of ``...`` reports True."""
        tree = cst.parse_module("x = ...\n")
        assert any_ellipsis_in_cst(tree)
class TestCompareUnparsedAstToSource:
    """Tests for compare_unparsed_ast_to_source."""

    def test_equivalent(self) -> None:
        """Identical snippets compare equal."""
        left, right = "x = 1", "x = 1"
        assert compare_unparsed_ast_to_source(left, right)

    def test_different(self) -> None:
        """Snippets assigning different values compare unequal."""
        left, right = "x = 1", "x = 2"
        assert not compare_unparsed_ast_to_source(left, right)
class TestUnparseParseSource:
    """Tests for unparse_parse_source."""

    def test_normalizes(self) -> None:
        """An already-canonical assignment survives the ast round-trip."""
        canonical = "x = 1"
        assert canonical == unparse_parse_source(canonical)
# ---------------------------------------------------------------------------
# _postprocess.py — additional coverage
# ---------------------------------------------------------------------------
class TestDocstringVisitor:
    """Tests for DocstringVisitor."""

    def test_collects_function_docstring(self) -> None:
        """Collects docstring from a top-level function."""
        module = cst.parse_module('def foo():\n """Hello."""\n pass\n')
        visitor = DocstringVisitor()
        module.visit(visitor)
        assert "Hello." == visitor.original_docstrings["foo"]

    def test_collects_class_docstring(self) -> None:
        """Collects docstring from a class."""
        module = cst.parse_module(
            'class Foo:\n """Class doc."""\n pass\n'
        )
        visitor = DocstringVisitor()
        module.visit(visitor)
        assert "Class doc." == visitor.original_docstrings["Foo"]

    def test_collects_method_docstring(self) -> None:
        """Collects docstring from a method with qualified name."""
        # The method body must be indented deeper than its ``def`` line,
        # otherwise cst.parse_module rejects the snippet outright.
        module = cst.parse_module(
            "class Foo:\n"
            "    def bar(self):\n"
            '        """Method doc."""\n'
            "        pass\n"
        )
        visitor = DocstringVisitor()
        module.visit(visitor)
        assert "Method doc." == visitor.original_docstrings["Foo.bar"]

    def test_no_docstring(self) -> None:
        """Function without docstring is not collected."""
        module = cst.parse_module("def foo():\n pass\n")
        visitor = DocstringVisitor()
        module.visit(visitor)
        assert {} == visitor.original_docstrings

    def test_class_name_reset_after_leave(self) -> None:
        """class_name is reset to None after leaving a class."""
        module = cst.parse_module(
            "class Foo:\n pass\ndef bar():\n pass\n"
        )
        visitor = DocstringVisitor()
        module.visit(visitor)
        assert visitor.class_name is None
class TestDocstringTransformer:
    """Tests for DocstringTransformer."""

    def test_restores_missing_function_docstring(self) -> None:
        """Restores a removed function docstring."""
        docstrings = {"foo": "My docstring."}
        transformer = DocstringTransformer(docstrings)
        optimized = cst.parse_module("def foo():\n pass\n")
        result = optimized.visit(transformer)
        assert "My docstring." in result.code

    def test_replaces_changed_function_docstring(self) -> None:
        """Replaces a modified function docstring with original."""
        docstrings = {"foo": "Original doc."}
        transformer = DocstringTransformer(docstrings)
        optimized = cst.parse_module(
            'def foo():\n """Wrong doc."""\n pass\n'
        )
        result = optimized.visit(transformer)
        assert "Original doc." in result.code
        # "Replaces" means the modified text must be gone, not merely that
        # the original was added alongside it.
        assert "Wrong doc." not in result.code

    def test_restores_missing_class_docstring(self) -> None:
        """Restores a removed class docstring."""
        docstrings = {"Foo": "Class docstring."}
        transformer = DocstringTransformer(docstrings)
        optimized = cst.parse_module("class Foo:\n pass\n")
        result = optimized.visit(transformer)
        assert "Class docstring." in result.code

    def test_replaces_changed_class_docstring(self) -> None:
        """Replaces a modified class docstring with original."""
        docstrings = {"Foo": "Original class doc."}
        transformer = DocstringTransformer(docstrings)
        optimized = cst.parse_module(
            'class Foo:\n """Wrong class doc."""\n pass\n'
        )
        result = optimized.visit(transformer)
        assert "Original class doc." in result.code
        assert "Wrong class doc." not in result.code

    def test_no_change_without_original(self) -> None:
        """No change when no original docstring exists."""
        transformer = DocstringTransformer({})
        optimized = cst.parse_module("def foo():\n pass\n")
        result = optimized.visit(transformer)
        assert result.code == optimized.code
class TestFixMissingDocstringAdvanced:
    """Additional tests for fix_missing_docstring."""

    def test_restores_class_docstring(self) -> None:
        """Restores a removed class docstring."""
        original = cst.parse_module(
            'class Foo:\n """Class doc."""\n pass\n'
        )
        optimized = cst.parse_module("class Foo:\n pass\n")
        c = OptimizationCandidate(
            cst_module=optimized,
            explanation="faster",
            id="1",
        )
        result = fix_missing_docstring(original, [c])
        assert "Class doc." in result[0].cst_module.code

    def test_restores_method_docstring(self) -> None:
        """Restores a removed method docstring."""
        # Method bodies must be indented deeper than their ``def`` line,
        # otherwise cst.parse_module rejects both snippets.
        original = cst.parse_module(
            "class Foo:\n"
            "    def bar(self):\n"
            '        """Method doc."""\n'
            "        pass\n"
        )
        optimized = cst.parse_module(
            "class Foo:\n"
            "    def bar(self):\n"
            "        pass\n"
        )
        c = OptimizationCandidate(
            cst_module=optimized,
            explanation="faster",
            id="1",
        )
        result = fix_missing_docstring(original, [c])
        assert "Method doc." in result[0].cst_module.code
class TestDedupAndSortImports:
    """Tests for dedup_and_sort_imports."""

    def test_sorts_imports(self) -> None:
        """Unsorted imports get sorted."""
        code = "import sys\nimport os\n\nx = 1\n"
        module = cst.parse_module(code)
        c = OptimizationCandidate(
            cst_module=module,
            explanation="test",
            id="1",
        )
        result = dedup_and_sort_imports(module, [c])
        assert 1 == len(result)
        result_code = result[0].cst_module.code
        # Match whole statements rather than bare "os"/"sys" substrings so
        # the ordering check cannot be satisfied by some unrelated token.
        os_at = result_code.index("import os")
        assert os_at < result_code.index("import sys")

    def test_no_change_when_sorted(self) -> None:
        """Already sorted imports keep their order."""
        code = "import os\nimport sys\n"
        module = cst.parse_module(code)
        c = OptimizationCandidate(
            cst_module=module,
            explanation="test",
            id="1",
        )
        result = dedup_and_sort_imports(module, [c])
        assert 1 == len(result)
        # Previously only the count was checked; pin that the sorted order
        # survives as well.
        result_code = result[0].cst_module.code
        os_at = result_code.index("import os")
        assert os_at < result_code.index("import sys")
class TestEllipsisContainingCodeVisitor:
    """Tests for EllipsisContainingCodeVisitor."""

    def test_no_ellipsis(self) -> None:
        """Visiting code without ``...`` leaves the flag unset."""
        detector = EllipsisContainingCodeVisitor()
        cst.parse_module("x = 1\n").visit(detector)
        assert not detector.ellipsis_containing_code

    def test_has_ellipsis(self) -> None:
        """Visiting code with ``...`` sets the flag."""
        detector = EllipsisContainingCodeVisitor()
        cst.parse_module("x = ...\n").visit(detector)
        assert detector.ellipsis_containing_code
class TestFilterEllipsisAdvanced:
    """Additional tests for filter_ellipsis_containing_code."""

    @staticmethod
    def _candidate(
        source: str, explanation: str, candidate_id: str
    ) -> OptimizationCandidate:
        """Build a candidate from raw source text."""
        return OptimizationCandidate(
            cst_module=cst.parse_module(source),
            explanation=explanation,
            id=candidate_id,
        )

    def test_keeps_non_ellipsis_candidate(self) -> None:
        """An ellipsis-free candidate survives when the original has none."""
        original = cst.parse_module("x = 1\n")
        survivors = filter_ellipsis_containing_code(
            original, [self._candidate("x = 2\n", "ok", "1")]
        )
        assert 1 == len(survivors)

    def test_mixed_candidates(self) -> None:
        """Only the candidate that introduces ``...`` is dropped."""
        original = cst.parse_module("x = 1\n")
        good = self._candidate("x = 2\n", "good", "1")
        bad = self._candidate("x = ...\n", "bad", "2")
        survivors = filter_ellipsis_containing_code(original, [good, bad])
        assert 1 == len(survivors)
        assert "1" == survivors[0].id
class TestStripCommentsFromCode:
    """Tests for _strip_comments_from_code."""

    def test_removes_inline_comment(self) -> None:
        """A trailing ``#`` comment is dropped while the code stays."""
        stripped = _strip_comments_from_code("x = 1 # a comment\n")
        assert "#" not in stripped
        assert "x = 1" in stripped

    def test_removes_full_line_comment(self) -> None:
        """A comment occupying its own line is dropped."""
        stripped = _strip_comments_from_code("# full line comment\nx = 1\n")
        assert "full line comment" not in stripped
        assert "x = 1" in stripped

    def test_preserves_strings(self) -> None:
        """A ``#`` inside a string literal is not treated as a comment."""
        stripped = _strip_comments_from_code('x = "# not a comment"\n')
        assert "# not a comment" in stripped

    def test_invalid_code_returns_original(self) -> None:
        """Unparseable input is returned unchanged."""
        broken = "def (:\n"
        assert broken == _strip_comments_from_code(broken)
class TestCleanExtraneousComments:
    """Tests for clean_extraneous_comments."""

    @staticmethod
    def _clean(original_code: str, optimized_code: str) -> str:
        """Parse both sources, clean the optimized one, return its code."""
        return clean_extraneous_comments(
            cst.parse_module(original_code),
            cst.parse_module(optimized_code),
        ).code

    def test_preserves_original_comments(self) -> None:
        """A comment already present in the original code is kept."""
        source = "x = 1 # original comment\n"
        assert "original comment" in self._clean(source, source)

    def test_removes_new_comments_on_unchanged_lines(self) -> None:
        """A comment newly attached to an unchanged line is dropped."""
        cleaned = self._clean(
            "x = 1\ny = 2\n",
            "x = 1 # new comment\ny = 2\n",
        )
        assert "new comment" not in cleaned
class TestCleanExtraneousCommentsPipeline:
    """Tests for clean_extraneous_comments_pipeline."""

    def test_processes_candidates(self) -> None:
        """Every candidate handed to the pipeline comes back out."""
        candidates = [
            OptimizationCandidate(
                cst_module=cst.parse_module("x = 2 # optimized\n"),
                explanation="test",
                id="1",
            )
        ]
        processed = clean_extraneous_comments_pipeline(
            cst.parse_module("x = 1\n"), candidates
        )
        assert len(processed) == 1
class TestEqualityCheckAdvanced:
    """Additional edge-case tests for equality_check."""

    @staticmethod
    def _surviving(
        candidate_code: str, explanation: str, original_code: str
    ) -> list[OptimizationCandidate]:
        """Run equality_check on one candidate against module ``x = 1``.

        ``original_code`` is forwarded as the explicit original source.
        """
        module = cst.parse_module("x = 1\n")
        candidate = OptimizationCandidate(
            cst_module=cst.parse_module(candidate_code),
            explanation=explanation,
            id="1",
        )
        return equality_check(
            module, [candidate], original_code=original_code
        )

    def test_with_explicit_original_code(self) -> None:
        """A differing candidate survives when original code is explicit."""
        survivors = self._surviving("x = 2\n", "different", "x = 1\n")
        assert len(survivors) == 1

    def test_filters_with_explicit_code(self) -> None:
        """A candidate identical to the explicit original code is dropped."""
        survivors = self._surviving("x = 1\n", "same", "x = 1\n")
        assert len(survivors) == 0

    def test_unparseable_original_falls_back_to_string(self) -> None:
        """An unparseable original forces a plain string comparison."""
        survivors = self._surviving("x = 2\n", "different", "def (:\n")
        assert len(survivors) == 1
class TestExtractNamesFromCstAnnotation:
    """Tests for extract_names_from_cst_annotation."""

    @staticmethod
    def _names(annotation: cst.BaseExpression) -> set[str]:
        """Return the names the extractor collects from *annotation*."""
        found: set[str] = set()
        extract_names_from_cst_annotation(annotation, found)
        return found

    def test_simple_name(self) -> None:
        """A bare Name contributes exactly its own identifier."""
        assert self._names(cst.Name("int")) == {"int"}

    def test_attribute(self) -> None:
        """The base of a dotted attribute (``typing.List``) is collected."""
        dotted = cst.Attribute(value=cst.Name("typing"), attr=cst.Name("List"))
        assert "typing" in self._names(dotted)

    def test_subscript(self) -> None:
        """Both the generic and its argument in ``List[int]`` are collected."""
        element = cst.SubscriptElement(slice=cst.Index(value=cst.Name("int")))
        generic = cst.Subscript(value=cst.Name("List"), slice=[element])
        found = self._names(generic)
        assert "List" in found
        assert "int" in found

    def test_binary_operation(self) -> None:
        """Both operands of a union-style ``str | int`` are collected."""
        union = cst.BinaryOperation(
            left=cst.Name("str"),
            operator=cst.BitOr(),
            right=cst.Name("int"),
        )
        assert self._names(union) == {"str", "int"}

    def test_tuple_annotation(self) -> None:
        """Every element of a tuple annotation is collected."""
        pair = cst.Tuple(
            elements=[cst.Element(value=cst.Name(n)) for n in ("int", "str")]
        )
        assert self._names(pair) == {"int", "str"}
class TestCSTAnnotationNameCollector:
    """Tests for CSTAnnotationNameCollector."""

    @staticmethod
    def _collect(source: str) -> CSTAnnotationNameCollector:
        """Parse *source* and walk it with a fresh collector."""
        collector = CSTAnnotationNameCollector()
        cst.parse_module(source).visit(collector)
        return collector

    def test_collects_function_annotations(self) -> None:
        """Parameter and return annotations are both recorded."""
        collector = self._collect(
            "def foo(x: MyType) -> ReturnType:\n pass\n"
        )
        for expected in ("MyType", "ReturnType"):
            assert expected in collector.annotation_names

    def test_collects_imports(self) -> None:
        """Names bound by ``import`` and ``from ... import`` are tracked."""
        imported = self._collect(
            "import os\nfrom sys import path\n"
        ).imported_names
        assert "os" in imported
        assert "path" in imported

    def test_collects_class_names(self) -> None:
        """A class definition registers its name as defined."""
        assert "Foo" in self._collect("class Foo:\n pass\n").defined_names

    def test_collects_function_names(self) -> None:
        """A function definition registers its name as defined."""
        assert "foo" in self._collect("def foo():\n pass\n").defined_names

    def test_undefined_annotations(self) -> None:
        """An unknown annotation name is reported; ``None`` is not."""
        undefined = self._collect(
            "def foo(x: MyType) -> None:\n pass\n"
        ).get_undefined_annotation_names()
        assert "MyType" in undefined
        assert "None" not in undefined

    def test_no_undefined_with_import(self) -> None:
        """An imported annotation name is not reported as undefined."""
        undefined = self._collect(
            "from mymod import MyType\ndef foo(x: MyType) -> None:\n pass\n"
        ).get_undefined_annotation_names()
        assert "MyType" not in undefined

    def test_collects_kwonly_params(self) -> None:
        """Keyword-only parameter annotations are recorded."""
        collector = self._collect("def foo(*, x: KWType) -> None:\n pass\n")
        assert "KWType" in collector.annotation_names

    def test_collects_star_arg_annotation(self) -> None:
        """The ``*args`` annotation is recorded."""
        collector = self._collect(
            "def foo(*args: ArgsType) -> None:\n pass\n"
        )
        assert "ArgsType" in collector.annotation_names

    def test_collects_star_kwarg_annotation(self) -> None:
        """The ``**kwargs`` annotation is recorded."""
        collector = self._collect(
            "def foo(**kwargs: KwargsType) -> None:\n pass\n"
        )
        assert "KwargsType" in collector.annotation_names

    def test_collects_ann_assign(self) -> None:
        """The annotation of an annotated assignment is recorded."""
        assert "MyType" in self._collect("x: MyType = 1\n").annotation_names
class TestFixForwardReferences:
    """Tests for fix_forward_references."""

    # The exact import statement fix_forward_references is expected to
    # insert (or leave alone when it is already present).
    _FUTURE_IMPORT = "from __future__ import annotations"

    @staticmethod
    def _fix(code: str) -> list[OptimizationCandidate]:
        """Run fix_forward_references on one candidate built from *code*.

        The original module is always ``x = 1``; only the candidate varies.
        """
        original = cst.parse_module("x = 1\n")
        candidate = OptimizationCandidate(
            cst_module=cst.parse_module(code),
            explanation="test",
            id="1",
        )
        return fix_forward_references(original, [candidate])

    def test_adds_import_when_needed(self) -> None:
        """Adds future annotations import for undefined annotation names."""
        result = self._fix("def foo(x: MyType) -> None:\n pass\n")
        assert 1 == len(result)
        assert self._FUTURE_IMPORT in result[0].cst_module.code

    def test_no_change_when_all_defined(self) -> None:
        """No change when all annotations are builtins."""
        result = self._fix("def foo(x: int) -> None:\n pass\n")
        assert 1 == len(result)
        assert self._FUTURE_IMPORT not in result[0].cst_module.code

    def test_no_change_when_already_present(self) -> None:
        """An existing future-annotations import is kept, not duplicated."""
        code = (
            "from __future__ import annotations\n"
            "def foo(x: MyType) -> None:\n pass\n"
        )
        result = self._fix(code)
        assert 1 == len(result)
        # Checking only the count would let a duplicated import slip through;
        # "no change" means the statement still appears exactly once.
        assert 1 == result[0].cst_module.code.count(self._FUTURE_IMPORT)
class TestPostprocessingPipelineAdvanced:
    """Additional tests for optimizations_postprocessing_pipeline."""

    @staticmethod
    def _run(
        original_source: str, candidate_source: str, explanation: str
    ) -> list[OptimizationCandidate]:
        """Feed a single candidate with id ``"1"`` through the full pipeline."""
        original = cst.parse_module(original_source)
        candidate = OptimizationCandidate(
            cst_module=cst.parse_module(candidate_source),
            explanation=explanation,
            id="1",
        )
        return optimizations_postprocessing_pipeline(original, [candidate])

    def test_pipeline_restores_docstring(self) -> None:
        """A docstring dropped by the candidate is put back."""
        survivors = self._run(
            'def foo():\n """My doc."""\n return 1\n',
            "def foo():\n return 2\n",
            "faster",
        )
        assert len(survivors) == 1
        assert "My doc." in survivors[0].cst_module.code

    def test_pipeline_adds_future_annotations(self) -> None:
        """An undefined annotation name triggers the future import."""
        survivors = self._run(
            "def foo():\n return 1\n",
            "def foo(x: MyCustomType) -> None:\n return 2\n",
            "typed",
        )
        assert len(survivors) == 1
        code = survivors[0].cst_module.code
        assert "from __future__ import annotations" in code

    def test_pipeline_cleans_explanations(self) -> None:
        """Markdown code fences are stripped from the explanation."""
        survivors = self._run(
            "x = 1 + 1\n",
            "x = 2\n",
            "```python\nx = 2\n```",
        )
        assert len(survivors) == 1
        assert "```" not in survivors[0].explanation

    def test_pipeline_filters_ellipsis(self) -> None:
        """A candidate that introduces ``...`` is discarded."""
        survivors = self._run(
            "x = 1\ny = 2\n",
            "x = ...\ny = ...\n",
            "bad",
        )
        assert len(survivors) == 0