Mirror of https://github.com/codeflash-ai/codeflash-agent.git (synced 2026-05-04 18:25:19 +00:00)
Add 182 new tests across the optimize, V4A diff, CST utils, and postprocess modules. Key coverage improvements:

- optimize/_pipeline.py: 29% → 97%
- optimize/_router.py: 40% → 93%
- diff/_v4a.py: 40% → 97%
- languages/python/_cst_utils.py: 67% → 96%
- languages/python/_postprocess.py: 67% → 87%

Also apply ruff format to 5 files that had formatting drift.
1512 lines · 50 KiB · Python
from __future__ import annotations
|
|
|
|
import libcst as cst
|
|
import pytest
|
|
|
|
from codeflash_api.languages.python._cst_utils import (
|
|
AnyEllipsisVisitor,
|
|
DefinitionRemover,
|
|
DepthTrackingMixin,
|
|
ImportTrackingVisitor,
|
|
InvalidEllipsisVisitor,
|
|
any_ellipsis_in_cst,
|
|
build_module_path,
|
|
collect_imported_names_from_import,
|
|
collect_imported_names_from_import_from,
|
|
compare_unparsed_ast_to_source,
|
|
ellipsis_in_cst_not_types,
|
|
evaluate_expression,
|
|
extract_import_info,
|
|
extract_imports_from_import,
|
|
extract_imports_from_import_from,
|
|
file_path_to_module_path,
|
|
find_init,
|
|
get_base_class_name,
|
|
get_dotted_name,
|
|
has_decorator,
|
|
make_number_node,
|
|
parse_module_to_cst,
|
|
unparse_parse_source,
|
|
)
|
|
from codeflash_api.languages.python._postprocess import (
|
|
CSTAnnotationNameCollector,
|
|
DocstringTransformer,
|
|
DocstringVisitor,
|
|
EllipsisContainingCodeVisitor,
|
|
OptimizationCandidate,
|
|
_strip_comments_from_code,
|
|
add_future_annotations_import,
|
|
clean_extraneous_comments,
|
|
clean_extraneous_comments_pipeline,
|
|
cleanup_explanations,
|
|
dedup_and_sort_imports,
|
|
deduplicate_optimizations,
|
|
equality_check,
|
|
extract_names_from_cst_annotation,
|
|
filter_ellipsis_containing_code,
|
|
fix_forward_references,
|
|
fix_missing_docstring,
|
|
has_future_annotations_import,
|
|
optimizations_postprocessing_pipeline,
|
|
safe_isort,
|
|
)
|
|
from codeflash_api.languages.python._validator import (
|
|
parse_python_or_none,
|
|
validate_python_syntax,
|
|
)
|
|
|
|
|
|
class TestDepthTrackingMixin:
    """Exercise the function-depth bookkeeping of DepthTrackingMixin."""

    def test_initial_state(self) -> None:
        """A freshly built mixin reports that it sits at top level."""
        tracker = DepthTrackingMixin()
        assert tracker._is_top_level()

    def test_function_depth(self) -> None:
        """Visiting a function leaves top level; leaving it goes back."""
        tracker = DepthTrackingMixin()

        tracker._visit_function()
        assert not tracker._is_top_level()
        assert tracker._is_inside_function()

        tracker._leave_function()
        assert tracker._is_top_level()
|
|
|
|
|
|
class TestFilePathToModulePath:
    """Check separator handling in file_path_to_module_path."""

    def test_unix_path(self) -> None:
        """POSIX separators become dots and the ``.py`` suffix is dropped."""
        converted = file_path_to_module_path("path/to/module.py")
        assert converted == "path.to.module"

    def test_windows_path(self) -> None:
        """Backslash separators become dots and the ``.py`` suffix is dropped."""
        converted = file_path_to_module_path("path\\to\\module.py")
        assert converted == "path.to.module"
|
|
|
|
|
|
class TestGetDottedName:
    """Check get_dotted_name on names, attributes and ``None``."""

    def test_simple_name(self) -> None:
        """A bare Name yields its identifier."""
        assert get_dotted_name(cst.Name("foo")) == "foo"

    def test_attribute(self) -> None:
        """An Attribute yields its dot-joined path."""
        dotted = cst.Attribute(value=cst.Name("foo"), attr=cst.Name("bar"))
        assert get_dotted_name(dotted) == "foo.bar"

    def test_none(self) -> None:
        """``None`` yields the empty string."""
        assert get_dotted_name(None) == ""
|
|
|
|
|
|
class TestBuildModulePath:
    """Check the CST node shapes produced by build_module_path."""

    def test_single_part(self) -> None:
        """A dot-free path produces a plain ``cst.Name``."""
        node = build_module_path("foo")
        assert isinstance(node, cst.Name)
        assert node.value == "foo"

    def test_dotted(self) -> None:
        """A dotted path produces an Attribute whose ``attr`` is the last part."""
        node = build_module_path("foo.bar.baz")
        assert isinstance(node, cst.Attribute)
        assert node.attr.value == "baz"
|
|
|
|
|
|
class TestEvaluateExpression:
    """Constant folding of literal and simple arithmetic CST nodes."""

    def test_integer(self) -> None:
        """A decimal Integer literal folds to its numeric value."""
        assert evaluate_expression(cst.Integer("42")) == 42

    def test_hex(self) -> None:
        """A hexadecimal Integer literal folds to its numeric value."""
        assert evaluate_expression(cst.Integer("0xff")) == 255

    def test_negative(self) -> None:
        """A unary minus negates its folded operand."""
        negated = cst.UnaryOperation(
            operator=cst.Minus(), expression=cst.Integer("5")
        )
        assert evaluate_expression(negated) == -5

    def test_binary_add(self) -> None:
        """An addition of two literals folds to their sum."""
        addition = cst.BinaryOperation(
            left=cst.Integer("3"), operator=cst.Add(), right=cst.Integer("4")
        )
        assert evaluate_expression(addition) == 7

    def test_unevaluable(self) -> None:
        """A bare Name has no constant value, so the result is ``None``."""
        assert evaluate_expression(cst.Name("x")) is None
|
|
|
|
|
|
class TestMakeNumberNode:
    """Check which CST node make_number_node emits for each sign."""

    def test_positive(self) -> None:
        """A positive value is emitted as a ``cst.Integer`` literal."""
        literal = make_number_node(5)
        assert isinstance(literal, cst.Integer)
        assert literal.value == "5"

    def test_negative(self) -> None:
        """A negative value is emitted as a ``cst.UnaryOperation``."""
        assert isinstance(make_number_node(-3), cst.UnaryOperation)
|
|
|
|
|
|
class TestParseModuleToCst:
    """Check memoisation in parse_module_to_cst."""

    def test_caches(self) -> None:
        """Parsing the same text twice hands back the very same object."""
        source = "x = 1"
        first, second = parse_module_to_cst(source), parse_module_to_cst(source)
        assert first is second
|
|
|
|
|
|
class TestDefinitionRemover:
    """Check top-level function removal by DefinitionRemover."""

    def test_removes_function(self) -> None:
        """A requested top-level function disappears and is recorded."""
        tree = cst.parse_module("def foo():\n pass\ndef bar():\n pass\n")
        remover = DefinitionRemover({"foo"})
        rewritten = tree.visit(remover)

        assert "foo" not in rewritten.code
        assert "bar" in rewritten.code
        assert "foo" in remover.removed_names

    def test_protected_name(self) -> None:
        """A name listed in ``protected_names`` survives even when requested."""
        tree = cst.parse_module("def foo():\n pass\n")
        remover = DefinitionRemover({"foo"}, protected_names={"foo"})
        assert "foo" in tree.visit(remover).code
|
|
|
|
|
|
class TestImportTrackingVisitor:
    """Check name collection by the ``ast``-based ImportTrackingVisitor."""

    def test_tracks_imports(self) -> None:
        """Names bound by ``import`` and ``from ... import`` are both recorded."""
        import ast

        syntax_tree = ast.parse("import os\nfrom sys import path")
        tracker = ImportTrackingVisitor()
        tracker.visit(syntax_tree)

        for expected in ("os", "path"):
            assert expected in tracker.imported_names
|
|
|
|
|
|
class TestFindInit:
    """Tests for find_init."""

    # The class snippets below need two distinct indentation levels (method
    # header vs. method body). With ``def`` and ``pass`` at the same column,
    # ``ast.parse`` raises IndentationError before find_init is ever reached.

    def test_found(self) -> None:
        """``__init__`` defined in a class body is found."""
        import ast

        tree = ast.parse(
            "class Foo:\n"
            "    def __init__(self):\n"
            "        pass\n"
        )
        assert find_init(tree)

    def test_not_found(self) -> None:
        """A class without ``__init__`` yields a falsy result."""
        import ast

        tree = ast.parse(
            "class Foo:\n"
            "    def bar(self):\n"
            "        pass\n"
        )
        assert not find_init(tree)
|
|
|
|
|
|
class TestValidatePythonSyntax:
    """Check the verdict returned by validate_python_syntax."""

    def test_valid(self) -> None:
        """A simple assignment is accepted."""
        assert validate_python_syntax("x = 1\n")

    def test_invalid(self) -> None:
        """Malformed source is rejected."""
        assert not validate_python_syntax("def (:\n")

    def test_empty_body(self) -> None:
        """A module consisting only of a comment is rejected."""
        comment_only = "# just a comment\n"
        assert not validate_python_syntax(comment_only)
|
|
|
|
|
|
class TestParsePythonOrNone:
    """Check that parse_python_or_none yields a Module or ``None``."""

    def test_valid(self) -> None:
        """Well-formed source is parsed into a ``cst.Module``."""
        parsed = parse_python_or_none("x = 1\n")
        assert parsed is not None
        assert isinstance(parsed, cst.Module)

    def test_invalid(self) -> None:
        """Malformed source yields ``None`` rather than raising."""
        assert parse_python_or_none("def (:\n") is None
|
|
|
|
|
|
class TestOptimizationCandidate:
    """Check immutability of OptimizationCandidate."""

    def test_frozen(self) -> None:
        """Reassigning a field after construction raises AttributeError."""
        candidate = OptimizationCandidate(
            cst_module=cst.parse_module("x = 1"), explanation="faster", id="test"
        )
        with pytest.raises(AttributeError):
            candidate.id = "changed"
|
|
|
|
|
|
class TestSafeIsort:
    """Check import sorting and the fallback path of safe_isort."""

    def test_sorts(self) -> None:
        """Out-of-order imports come back with ``os`` ahead of ``sys``."""
        reordered = safe_isort("import sys\nimport os\n")
        assert reordered.index("os") < reordered.index("sys")

    def test_invalid_returns_original(self) -> None:
        """Unparseable text is handed back untouched."""
        broken = "not valid python {{{{"
        assert safe_isort(broken) == broken
|
|
|
|
|
|
class TestDeduplicateOptimizations:
    """Check AST-level de-duplication in deduplicate_optimizations."""

    def test_removes_duplicates(self) -> None:
        """Candidates that differ only in whitespace collapse to one entry."""
        baseline = cst.parse_module("x = 1\n")
        variants = [
            OptimizationCandidate(
                cst_module=cst.parse_module(text), explanation=note, id=ident
            )
            for text, note, ident in (("x = 1\n", "a", "1"), ("x=1\n", "b", "2"))
        ]
        assert len(deduplicate_optimizations(baseline, variants)) == 1
|
|
|
|
|
|
class TestEqualityCheck:
    """Check that equality_check drops candidates equal to the original."""

    def test_filters_identical(self) -> None:
        """A candidate with the same code as the original is discarded."""
        baseline = cst.parse_module("x = 1\n")
        unchanged = OptimizationCandidate(
            cst_module=cst.parse_module("x = 1\n"), explanation="same", id="1"
        )
        assert len(equality_check(baseline, [unchanged])) == 0

    def test_keeps_different(self) -> None:
        """A candidate with different code is retained."""
        baseline = cst.parse_module("x = 1\n")
        changed = OptimizationCandidate(
            cst_module=cst.parse_module("x = 2\n"),
            explanation="different",
            id="1",
        )
        assert len(equality_check(baseline, [changed])) == 1
|
|
|
|
|
|
class TestFilterEllipsis:
    """Check ellipsis-based filtering in filter_ellipsis_containing_code."""

    def test_filters_introduced_ellipsis(self) -> None:
        """A candidate adding ``...`` absent from the original is dropped."""
        baseline = cst.parse_module("x = 1\n")
        placeholder = OptimizationCandidate(
            cst_module=cst.parse_module("x = ...\n"), explanation="bad", id="1"
        )
        kept = filter_ellipsis_containing_code(baseline, [placeholder])
        assert len(kept) == 0

    def test_keeps_when_original_has_ellipsis(self) -> None:
        """When the original already has ``...`` the candidate is retained."""
        baseline = cst.parse_module("x = ...\n")
        candidate = OptimizationCandidate(
            cst_module=cst.parse_module("x = ...\n"), explanation="ok", id="1"
        )
        kept = filter_ellipsis_containing_code(baseline, [candidate])
        assert len(kept) == 1
|
|
|
|
|
|
class TestFixMissingDocstring:
    """Check docstring restoration by fix_missing_docstring."""

    def test_restores_docstring(self) -> None:
        """A docstring dropped by the candidate is written back."""
        source_tree = cst.parse_module(
            'def foo():\n """My docstring."""\n pass\n'
        )
        stripped = OptimizationCandidate(
            cst_module=cst.parse_module("def foo():\n pass\n"),
            explanation="faster",
            id="1",
        )
        repaired = fix_missing_docstring(source_tree, [stripped])
        assert "My docstring" in repaired[0].cst_module.code
|
|
|
|
|
|
class TestHasFutureAnnotationsImport:
    """Check detection of ``from __future__ import annotations``."""

    def test_present(self) -> None:
        """A module containing the import is detected."""
        tree = cst.parse_module("from __future__ import annotations\n")
        assert has_future_annotations_import(tree)

    def test_absent(self) -> None:
        """A module without the import is not detected."""
        tree = cst.parse_module("x = 1\n")
        assert not has_future_annotations_import(tree)
|
|
|
|
|
|
class TestAddFutureAnnotationsImport:
    """Check when add_future_annotations_import inserts the import."""

    def test_adds_when_needed(self) -> None:
        """An annotation naming an undefined type causes the import to be added."""
        tree = cst.parse_module("def foo(x: MyType) -> None:\n pass\n")
        updated = add_future_annotations_import(tree)
        assert "from __future__ import annotations" in updated.code

    def test_skips_when_present(self) -> None:
        """A module that already has the import is returned as the same object."""
        tree = cst.parse_module(
            "from __future__ import annotations\ndef foo(x: MyType) -> None:\n pass\n"
        )
        assert add_future_annotations_import(tree) is tree

    def test_skips_when_no_undefined(self) -> None:
        """Annotations that only use ``int`` leave the module object untouched."""
        tree = cst.parse_module("def foo(x: int) -> None:\n pass\n")
        assert add_future_annotations_import(tree) is tree
|
|
|
|
|
|
class TestCleanupExplanations:
    """Check explanation sanitising in cleanup_explanations."""

    def test_removes_code_block(self) -> None:
        """Markdown code fences are stripped out of the explanation text."""
        tree = cst.parse_module("x = 1\n")
        fenced = OptimizationCandidate(
            cst_module=tree, explanation="```python\nx = 1\n```", id="1"
        )
        cleaned = cleanup_explanations(tree, [fenced])
        assert "```" not in cleaned[0].explanation
|
|
|
|
|
|
class TestPostprocessingPipeline:
    """End-to-end checks for optimizations_postprocessing_pipeline."""

    def test_empty_candidates(self) -> None:
        """No candidates in means an empty list out."""
        baseline = cst.parse_module("x = 1\n")
        assert optimizations_postprocessing_pipeline(baseline, []) == []

    def test_filters_identical(self) -> None:
        """A candidate identical to the original is filtered out."""
        baseline = cst.parse_module("x = 1\n")
        unchanged = OptimizationCandidate(
            cst_module=cst.parse_module("x = 1\n"), explanation="same", id="1"
        )
        survivors = optimizations_postprocessing_pipeline(baseline, [unchanged])
        assert len(survivors) == 0

    def test_keeps_valid_optimization(self) -> None:
        """A genuinely different, valid candidate makes it through."""
        baseline = cst.parse_module("x = 1 + 1\n")
        folded = OptimizationCandidate(
            cst_module=cst.parse_module("x = 2\n"),
            explanation="precomputed",
            id="1",
        )
        survivors = optimizations_postprocessing_pipeline(baseline, [folded])
        assert len(survivors) == 1
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# _cst_utils.py — additional coverage
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestDepthTrackingMixinClassDepth:
    """Class-depth and combined top-level queries on DepthTrackingMixin."""

    def test_class_depth(self) -> None:
        """Visiting a class is reflected in the class-depth queries."""
        state = DepthTrackingMixin()
        assert not state._is_inside_class()

        state._visit_class()
        assert state._is_inside_class()
        assert not state._is_top_level_class()

        state._leave_class()
        assert state._is_top_level_class()

    def test_top_level_function_check(self) -> None:
        """``_is_top_level_function`` is false only while inside a function."""
        state = DepthTrackingMixin()
        assert state._is_top_level_function()

        state._visit_function()
        assert not state._is_top_level_function()

        state._leave_function()
        assert state._is_top_level_function()

    def test_top_level_class_with_function(self) -> None:
        """``_is_top_level_class`` is false while inside a function."""
        state = DepthTrackingMixin()
        state._visit_function()
        assert not state._is_top_level_class()
|
|
|
|
|
|
class TestExtractImportInfo:
    """Check the (available, module, original) triple of extract_import_info."""

    def test_simple_name(self) -> None:
        """``import os`` binds ``os`` with an empty module part."""
        available, module, original = extract_import_info(
            cst.ImportAlias(name=cst.Name("os"))
        )
        assert (available, module, original) == ("os", "", "os")

    def test_dotted_name_without_module(self) -> None:
        """``import os.path`` exposes only its leading ``os`` component."""
        dotted_alias = cst.ImportAlias(
            name=cst.Attribute(value=cst.Name("os"), attr=cst.Name("path"))
        )
        available, module, original = extract_import_info(dotted_alias)
        assert (available, module, original) == ("os", "", "os.path")

    def test_with_asname(self) -> None:
        """An ``as`` clause makes the alias the available name."""
        renamed = cst.ImportAlias(
            name=cst.Name("os"),
            asname=cst.AsName(
                whitespace_before_as=cst.SimpleWhitespace(" "),
                whitespace_after_as=cst.SimpleWhitespace(" "),
                name=cst.Name("operating_system"),
            ),
        )
        available, _module, original = extract_import_info(renamed)
        assert (available, original) == ("operating_system", "os")

    def test_with_module_name(self) -> None:
        """With a module name supplied, the original name stays available."""
        available, module, original = extract_import_info(
            cst.ImportAlias(name=cst.Name("path")), "os"
        )
        assert (available, module, original) == ("path", "os", "path")
|
|
|
|
|
|
class TestExtractImportsFromImport:
    """Check name extraction from plain ``import`` statements."""

    def test_simple_import(self) -> None:
        """Every module in ``import os, sys`` appears in the result."""
        statement = cst.parse_statement("import os, sys\n")
        assert isinstance(statement, cst.SimpleStatementLine)

        first_small = statement.body[0]
        assert isinstance(first_small, cst.Import)

        extracted = extract_imports_from_import(first_small)
        for name in ("os", "sys"):
            assert name in extracted
|
|
|
|
|
|
class TestExtractImportsFromImportFrom:
    """Check name extraction from ``from X import Y`` statements."""

    def test_from_import(self) -> None:
        """Each imported name maps to a ``(module, name)`` pair."""
        statement = cst.parse_statement("from os import path, getcwd\n")
        assert isinstance(statement, cst.SimpleStatementLine)

        from_node = statement.body[0]
        assert isinstance(from_node, cst.ImportFrom)

        mapping = extract_imports_from_import_from(from_node)
        for name in ("path", "getcwd"):
            assert name in mapping
        assert mapping["path"] == ("os", "path")

    def test_import_star_returns_empty(self) -> None:
        """A star import yields an empty mapping."""
        statement = cst.parse_statement("from os import *\n")
        assert isinstance(statement, cst.SimpleStatementLine)

        from_node = statement.body[0]
        assert isinstance(from_node, cst.ImportFrom)

        assert extract_imports_from_import_from(from_node) == {}
|
|
|
|
|
|
class TestCollectImportedNames:
    """Check the name sets from collect_imported_names_from_import(_from)."""

    def test_collect_from_import(self) -> None:
        """``import os, sys`` yields ``{"os", "sys"}``."""
        import_stmt = cst.parse_statement("import os, sys\n").body[0]
        assert isinstance(import_stmt, cst.Import)
        assert collect_imported_names_from_import(import_stmt) == {"os", "sys"}

    def test_collect_from_import_from(self) -> None:
        """``from os import path, getcwd`` yields both bound names."""
        from_stmt = cst.parse_statement("from os import path, getcwd\n").body[0]
        assert isinstance(from_stmt, cst.ImportFrom)
        collected = collect_imported_names_from_import_from(from_stmt)
        assert collected == {"path", "getcwd"}

    def test_collect_from_import_from_star(self) -> None:
        """A star import yields an empty set."""
        from_stmt = cst.parse_statement("from os import *\n").body[0]
        assert isinstance(from_stmt, cst.ImportFrom)
        assert collect_imported_names_from_import_from(from_stmt) == set()
|
|
|
|
|
|
class TestDefinitionRemoverAdvanced:
    """Advanced tests for DefinitionRemover."""

    def test_removes_class(self) -> None:
        """Top-level class is removed by name."""
        code = "class Foo:\n    pass\nclass Bar:\n    pass\n"
        module = cst.parse_module(code)
        remover = DefinitionRemover({"Foo"})
        result = module.visit(remover)
        assert "Foo" not in result.code
        assert "Bar" in result.code
        assert "Foo" in remover.removed_names

    def test_removes_method_causes_class_removal(self) -> None:
        """Removing a qualified method removes the enclosing class."""
        # The method body needs its own indentation level; with ``def`` and
        # ``pass`` at the same column the snippet would not parse at all.
        code = "class Foo:\n    def bar(self):\n        pass\n"
        module = cst.parse_module(code)
        remover = DefinitionRemover({"Foo.bar"})
        result = module.visit(remover)
        assert "Foo" not in result.code
        assert "Foo" in remover.removed_names

    def test_nested_function_not_removed(self) -> None:
        """A function nested inside another is not removed as top-level."""
        code = (
            "def outer():\n"
            "    def foo():\n"
            "        pass\n"
            "    return foo()\n"
        )
        module = cst.parse_module(code)
        remover = DefinitionRemover({"foo"})
        result = module.visit(remover)
        # Check for the definition itself: a bare "foo" would still match the
        # ``return foo()`` call even if the nested def had been stripped.
        assert "def foo" in result.code

    def test_protected_class_not_removed(self) -> None:
        """Protected class is not removed."""
        code = "class Foo:\n    pass\n"
        module = cst.parse_module(code)
        remover = DefinitionRemover({"Foo"}, protected_names={"Foo"})
        result = module.visit(remover)
        assert "Foo" in result.code
|
|
|
|
|
|
class TestImportTrackingVisitorAdvanced:
    """Alias, star and dotted-import handling in ImportTrackingVisitor."""

    @staticmethod
    def _track(source: str) -> ImportTrackingVisitor:
        """Parse *source* with ``ast`` and run a fresh tracker over it."""
        import ast

        tracker = ImportTrackingVisitor()
        tracker.visit(ast.parse(source))
        return tracker

    def test_tracks_import_with_alias(self) -> None:
        """``import os as operating_system`` records only the alias."""
        names = self._track("import os as operating_system").imported_names
        assert "operating_system" in names
        assert "os" not in names

    def test_tracks_from_import_with_alias(self) -> None:
        """``from os import path as p`` records only the alias."""
        names = self._track("from os import path as p").imported_names
        assert "p" in names
        assert "path" not in names

    def test_skips_star_import(self) -> None:
        """A star import records nothing."""
        names = self._track("from os import *").imported_names
        assert len(names) == 0

    def test_dotted_import_uses_first_component(self) -> None:
        """``import os.path`` records the leading ``os`` component."""
        assert "os" in self._track("import os.path").imported_names
|
|
|
|
|
|
class TestGetBaseClassName:
    """Check name resolution of class-base arguments."""

    def test_simple_name(self) -> None:
        """A bare Name base yields its identifier."""
        assert get_base_class_name(cst.Arg(value=cst.Name("Base"))) == "Base"

    def test_attribute_base(self) -> None:
        """A dotted base yields the final attribute name."""
        qualified = cst.Attribute(value=cst.Name("mod"), attr=cst.Name("Base"))
        assert get_base_class_name(cst.Arg(value=qualified)) == "Base"

    def test_non_name_returns_none(self) -> None:
        """A base that is neither Name nor Attribute yields ``None``."""
        assert get_base_class_name(cst.Arg(value=cst.Integer("42"))) is None
|
|
|
|
|
|
class TestHasDecorator:
    """Decorator detection on class definitions via has_decorator."""

    @staticmethod
    def _parse_class(source: str) -> cst.ClassDef:
        """Parse *source* and return its first statement, asserting it is a class."""
        first = cst.parse_module(source).body[0]
        assert isinstance(first, cst.ClassDef)
        return first

    def test_simple_decorator(self) -> None:
        """A bare ``@dataclass`` decorator is recognised."""
        klass = self._parse_class("@dataclass\nclass Foo:\n pass\n")
        assert has_decorator(klass, "dataclass")

    def test_call_decorator(self) -> None:
        """A called ``@dataclass(...)`` decorator is recognised."""
        klass = self._parse_class("@dataclass(frozen=True)\nclass Foo:\n pass\n")
        assert has_decorator(klass, "dataclass")

    def test_missing_decorator(self) -> None:
        """An undecorated class is not reported."""
        klass = self._parse_class("class Foo:\n pass\n")
        assert not has_decorator(klass, "dataclass")

    def test_wrong_decorator(self) -> None:
        """A class carrying a different decorator is not reported."""
        klass = self._parse_class("@staticmethod\nclass Foo:\n pass\n")
        assert not has_decorator(klass, "dataclass")
|
|
|
|
|
|
class TestEvaluateExpressionAdvanced:
    """Further constant-folding cases for evaluate_expression."""

    @staticmethod
    def _binop(
        left: str, operator: cst.BaseBinaryOp, right: str
    ) -> cst.BinaryOperation:
        """Build ``<left> <operator> <right>`` from two integer literals."""
        return cst.BinaryOperation(
            left=cst.Integer(left), operator=operator, right=cst.Integer(right)
        )

    def test_float(self) -> None:
        """A Float literal folds to an int (``3.9`` gives ``3``)."""
        assert evaluate_expression(cst.Float("3.9")) == 3

    def test_binary_multiply(self) -> None:
        """Multiplication folds to the product."""
        assert evaluate_expression(self._binop("3", cst.Multiply(), "4")) == 12

    def test_binary_subtract(self) -> None:
        """Subtraction folds to the difference."""
        assert evaluate_expression(self._binop("10", cst.Subtract(), "3")) == 7

    def test_binary_power(self) -> None:
        """Exponentiation folds to the power."""
        assert evaluate_expression(self._binop("2", cst.Power(), "3")) == 8

    def test_unsupported_binary_op(self) -> None:
        """Floor division is not folded and yields ``None``."""
        folded = evaluate_expression(self._binop("10", cst.FloorDivide(), "3"))
        assert folded is None

    def test_nested_negative_unevaluable(self) -> None:
        """Negating a non-constant operand yields ``None``."""
        negated_name = cst.UnaryOperation(
            operator=cst.Minus(), expression=cst.Name("x")
        )
        assert evaluate_expression(negated_name) is None
|
|
|
|
|
|
class TestEllipsisInCstNotTypes:
    """Distinguish placeholder ``...`` from legitimate uses of it."""

    def test_no_ellipsis(self) -> None:
        """Code with no ellipsis is not flagged."""
        assert not ellipsis_in_cst_not_types(cst.parse_module("x = 1\n"))

    def test_ellipsis_in_assignment(self) -> None:
        """``...`` used as an assigned value is flagged."""
        assert ellipsis_in_cst_not_types(cst.parse_module("x = ...\n"))

    def test_ellipsis_in_type_annotation(self) -> None:
        """``...`` inside a subscripted type annotation is not flagged."""
        annotated = cst.parse_module("x: tuple[int, ...] = (1,)\n")
        assert not ellipsis_in_cst_not_types(annotated)

    def test_ellipsis_as_function_body(self) -> None:
        """``...`` as the sole body of a function is not flagged."""
        stub = cst.parse_module("def foo():\n ...\n")
        assert not ellipsis_in_cst_not_types(stub)
|
|
|
|
|
|
class TestAnyEllipsisInCst:
    """Check the unconditional ellipsis detector any_ellipsis_in_cst."""

    def test_no_ellipsis(self) -> None:
        """Code with no ellipsis is reported as clean."""
        assert not any_ellipsis_in_cst(cst.parse_module("x = 1\n"))

    def test_has_ellipsis(self) -> None:
        """Any occurrence of ``...`` is reported."""
        assert any_ellipsis_in_cst(cst.parse_module("x = ...\n"))
|
|
|
|
|
|
class TestCompareUnparsedAstToSource:
    """Check AST-normalised comparison of two source strings."""

    def test_equivalent(self) -> None:
        """Identical programs compare as equal."""
        same = compare_unparsed_ast_to_source("x = 1", "x = 1")
        assert same

    def test_different(self) -> None:
        """Programs with different constants compare as unequal."""
        same = compare_unparsed_ast_to_source("x = 1", "x = 2")
        assert not same
|
|
|
|
|
|
class TestUnparseParseSource:
    """Check the ``ast`` parse/unparse round trip."""

    def test_normalizes(self) -> None:
        """Canonically formatted input survives the round trip unchanged."""
        round_tripped = unparse_parse_source("x = 1")
        assert round_tripped == "x = 1"
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# _postprocess.py — additional coverage
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestDocstringVisitor:
    """Tests for DocstringVisitor."""

    # Snippets use four-space indentation throughout. A method body must be
    # indented deeper than its ``def`` line, otherwise the source fails to
    # parse before the visitor ever runs.

    def test_collects_function_docstring(self) -> None:
        """Collects docstring from a top-level function."""
        module = cst.parse_module('def foo():\n    """Hello."""\n    pass\n')
        visitor = DocstringVisitor()
        module.visit(visitor)
        assert "Hello." == visitor.original_docstrings["foo"]

    def test_collects_class_docstring(self) -> None:
        """Collects docstring from a class."""
        module = cst.parse_module(
            'class Foo:\n    """Class doc."""\n    pass\n'
        )
        visitor = DocstringVisitor()
        module.visit(visitor)
        assert "Class doc." == visitor.original_docstrings["Foo"]

    def test_collects_method_docstring(self) -> None:
        """Collects a method docstring under its ``Class.method`` name."""
        module = cst.parse_module(
            "class Foo:\n"
            "    def bar(self):\n"
            '        """Method doc."""\n'
            "        pass\n"
        )
        visitor = DocstringVisitor()
        module.visit(visitor)
        assert "Method doc." == visitor.original_docstrings["Foo.bar"]

    def test_no_docstring(self) -> None:
        """Function without docstring is not collected."""
        module = cst.parse_module("def foo():\n    pass\n")
        visitor = DocstringVisitor()
        module.visit(visitor)
        assert {} == visitor.original_docstrings

    def test_class_name_reset_after_leave(self) -> None:
        """``class_name`` is back to ``None`` after leaving a class."""
        module = cst.parse_module(
            "class Foo:\n    pass\ndef bar():\n    pass\n"
        )
        visitor = DocstringVisitor()
        module.visit(visitor)
        assert visitor.class_name is None
|
|
|
|
|
|
class TestDocstringTransformer:
    """Tests for DocstringTransformer."""

    def test_restores_missing_function_docstring(self) -> None:
        """Restores a removed function docstring."""
        transformer = DocstringTransformer({"foo": "My docstring."})
        optimized = cst.parse_module("def foo():\n    pass\n")
        result = optimized.visit(transformer)
        assert "My docstring." in result.code

    def test_replaces_changed_function_docstring(self) -> None:
        """Replaces a modified function docstring with the original."""
        transformer = DocstringTransformer({"foo": "Original doc."})
        optimized = cst.parse_module(
            'def foo():\n    """Wrong doc."""\n    pass\n'
        )
        result = optimized.visit(transformer)
        assert "Original doc." in result.code
        # Replacement rather than insertion: the candidate's text must be gone.
        assert "Wrong doc." not in result.code

    def test_restores_missing_class_docstring(self) -> None:
        """Restores a removed class docstring."""
        transformer = DocstringTransformer({"Foo": "Class docstring."})
        optimized = cst.parse_module("class Foo:\n    pass\n")
        result = optimized.visit(transformer)
        assert "Class docstring." in result.code

    def test_replaces_changed_class_docstring(self) -> None:
        """Replaces a modified class docstring with the original."""
        transformer = DocstringTransformer({"Foo": "Original class doc."})
        optimized = cst.parse_module(
            'class Foo:\n    """Wrong class doc."""\n    pass\n'
        )
        result = optimized.visit(transformer)
        assert "Original class doc." in result.code
        # Replacement rather than insertion: the candidate's text must be gone.
        assert "Wrong class doc." not in result.code

    def test_no_change_without_original(self) -> None:
        """Code is unchanged when no original docstring is known."""
        transformer = DocstringTransformer({})
        optimized = cst.parse_module("def foo():\n    pass\n")
        result = optimized.visit(transformer)
        assert result.code == optimized.code
|
|
|
|
|
|
class TestFixMissingDocstringAdvanced:
    """Additional tests for fix_missing_docstring."""

    def test_restores_class_docstring(self) -> None:
        """Restores a removed class docstring."""
        original = cst.parse_module(
            'class Foo:\n """Class doc."""\n pass\n'
        )
        optimized = cst.parse_module("class Foo:\n pass\n")
        c = OptimizationCandidate(
            cst_module=optimized,
            explanation="faster",
            id="1",
        )
        result = fix_missing_docstring(original, [c])
        # Check the count first so a dropped candidate fails clearly instead
        # of raising IndexError on result[0].
        assert 1 == len(result)
        assert "Class doc." in result[0].cst_module.code

    def test_restores_method_docstring(self) -> None:
        """Restores a removed method docstring."""
        # Each nested body must be indented deeper than its header. With a
        # uniform one-space indent ``def bar`` has no body and parse_module
        # raises a syntax error before fix_missing_docstring is reached.
        original = cst.parse_module(
            "class Foo:\n"
            "    def bar(self):\n"
            '        """Method doc."""\n'
            "        pass\n"
        )
        optimized = cst.parse_module(
            "class Foo:\n    def bar(self):\n        pass\n"
        )
        c = OptimizationCandidate(
            cst_module=optimized,
            explanation="faster",
            id="1",
        )
        result = fix_missing_docstring(original, [c])
        assert 1 == len(result)
        assert "Method doc." in result[0].cst_module.code
|
|
|
|
|
|
class TestDedupAndSortImports:
    """Tests for dedup_and_sort_imports."""

    def test_sorts_imports(self) -> None:
        """Unsorted imports get sorted."""
        code = "import sys\nimport os\n\nx = 1\n"
        module = cst.parse_module(code)
        c = OptimizationCandidate(
            cst_module=module,
            explanation="test",
            id="1",
        )
        result = dedup_and_sort_imports(module, [c])
        assert 1 == len(result)
        result_code = result[0].cst_module.code
        # Compare whole statements: a bare "os"/"sys" substring could match
        # unrelated text and make the ordering check meaningless.
        assert result_code.index("import os") < result_code.index("import sys")

    def test_no_change_when_sorted(self) -> None:
        """Already sorted imports are unchanged."""
        code = "import os\nimport sys\n"
        module = cst.parse_module(code)
        c = OptimizationCandidate(
            cst_module=module,
            explanation="test",
            id="1",
        )
        result = dedup_and_sort_imports(module, [c])
        assert 1 == len(result)
        # "Unchanged" is the claim under test: pin the emitted source, not
        # just the number of surviving candidates.
        assert code == result[0].cst_module.code
|
|
|
|
|
|
class TestEllipsisContainingCodeVisitor:
    """Behaviour of the EllipsisContainingCodeVisitor flag."""

    def test_no_ellipsis(self) -> None:
        """Plain assignment code leaves the flag falsy."""
        parsed = cst.parse_module("x = 1\n")
        ellipsis_visitor = EllipsisContainingCodeVisitor()
        parsed.visit(ellipsis_visitor)
        assert not ellipsis_visitor.ellipsis_containing_code

    def test_has_ellipsis(self) -> None:
        """An ``...`` literal in the module makes the flag truthy."""
        parsed = cst.parse_module("x = ...\n")
        ellipsis_visitor = EllipsisContainingCodeVisitor()
        parsed.visit(ellipsis_visitor)
        assert ellipsis_visitor.ellipsis_containing_code
|
|
|
|
|
|
class TestFilterEllipsisAdvanced:
    """Extra coverage for filter_ellipsis_containing_code."""

    def test_keeps_non_ellipsis_candidate(self) -> None:
        """An ellipsis-free candidate survives when the original has none."""
        original = cst.parse_module("x = 1\n")
        candidate = OptimizationCandidate(
            cst_module=cst.parse_module("x = 2\n"),
            explanation="ok",
            id="1",
        )
        kept = filter_ellipsis_containing_code(original, [candidate])
        assert len(kept) == 1

    def test_mixed_candidates(self) -> None:
        """Of two candidates, only the one introducing ``...`` is dropped."""
        original = cst.parse_module("x = 1\n")
        specs = (
            ("x = 2\n", "good", "1"),
            ("x = ...\n", "bad", "2"),
        )
        candidates = [
            OptimizationCandidate(
                cst_module=cst.parse_module(source),
                explanation=why,
                id=ident,
            )
            for source, why, ident in specs
        ]
        kept = filter_ellipsis_containing_code(original, candidates)
        assert [survivor.id for survivor in kept] == ["1"]
|
|
|
|
|
|
class TestStripCommentsFromCode:
    """Unit tests for the _strip_comments_from_code helper."""

    def test_removes_inline_comment(self) -> None:
        """A trailing comment is dropped while the statement remains."""
        stripped = _strip_comments_from_code("x = 1 # a comment\n")
        assert "#" not in stripped
        assert "x = 1" in stripped

    def test_removes_full_line_comment(self) -> None:
        """A comment occupying its own line is dropped."""
        stripped = _strip_comments_from_code("# full line comment\nx = 1\n")
        assert "full line comment" not in stripped
        assert "x = 1" in stripped

    def test_preserves_strings(self) -> None:
        """A ``#`` inside a string literal is left alone."""
        stripped = _strip_comments_from_code('x = "# not a comment"\n')
        assert "# not a comment" in stripped

    def test_invalid_code_returns_original(self) -> None:
        """Syntactically invalid input comes back as-is."""
        broken = "def (:\n"
        assert _strip_comments_from_code(broken) == broken
|
|
|
|
|
|
class TestCleanExtraneousComments:
    """Unit tests for clean_extraneous_comments."""

    def test_preserves_original_comments(self) -> None:
        """A comment already present in the original survives cleaning."""
        source = "x = 1 # original comment\n"
        cleaned = clean_extraneous_comments(
            cst.parse_module(source), cst.parse_module(source)
        )
        assert "original comment" in cleaned.code

    def test_removes_new_comments_on_unchanged_lines(self) -> None:
        """A comment added to an otherwise untouched line is removed."""
        before = cst.parse_module("x = 1\ny = 2\n")
        after = cst.parse_module("x = 1 # new comment\ny = 2\n")
        cleaned = clean_extraneous_comments(before, after)
        assert "new comment" not in cleaned.code
|
|
|
|
|
|
class TestCleanExtraneousCommentsPipeline:
    """Unit tests for clean_extraneous_comments_pipeline."""

    def test_processes_candidates(self) -> None:
        """A single candidate yields a single processed result."""
        baseline = cst.parse_module("x = 1\n")
        candidates = [
            OptimizationCandidate(
                cst_module=cst.parse_module("x = 2 # optimized\n"),
                explanation="test",
                id="1",
            )
        ]
        processed = clean_extraneous_comments_pipeline(baseline, candidates)
        assert len(processed) == 1
|
|
|
|
|
|
class TestEqualityCheckAdvanced:
    """Edge cases for equality_check with an explicit ``original_code``."""

    @staticmethod
    def _candidate(source: str, explanation: str) -> OptimizationCandidate:
        """Wrap *source* in a candidate whose id is ``"1"``."""
        return OptimizationCandidate(
            cst_module=cst.parse_module(source),
            explanation=explanation,
            id="1",
        )

    def test_with_explicit_original_code(self) -> None:
        """A differing candidate is kept when original_code is supplied."""
        kept = equality_check(
            cst.parse_module("x = 1\n"),
            [self._candidate("x = 2\n", "different")],
            original_code="x = 1\n",
        )
        assert len(kept) == 1

    def test_filters_with_explicit_code(self) -> None:
        """A candidate identical to original_code is filtered out."""
        kept = equality_check(
            cst.parse_module("x = 1\n"),
            [self._candidate("x = 1\n", "same")],
            original_code="x = 1\n",
        )
        assert len(kept) == 0

    def test_unparseable_original_falls_back_to_string(self) -> None:
        """Unparseable original_code falls back to string comparison."""
        kept = equality_check(
            cst.parse_module("x = 1\n"),
            [self._candidate("x = 2\n", "different")],
            original_code="def (:\n",
        )
        assert len(kept) == 1
|
|
|
|
|
|
class TestExtractNamesFromCstAnnotation:
    """Unit tests for extract_names_from_cst_annotation."""

    @staticmethod
    def _names_of(annotation: cst.BaseExpression) -> set[str]:
        """Run the extractor on *annotation*; return the names it added."""
        collected: set[str] = set()
        extract_names_from_cst_annotation(annotation, collected)
        return collected

    def test_simple_name(self) -> None:
        """A bare Name contributes its own value."""
        assert self._names_of(cst.Name("int")) == {"int"}

    def test_attribute(self) -> None:
        """For ``typing.List`` the base ``typing`` is collected."""
        annotation = cst.Attribute(
            value=cst.Name("typing"), attr=cst.Name("List")
        )
        assert "typing" in self._names_of(annotation)

    def test_subscript(self) -> None:
        """``List[int]`` yields both the generic and its argument."""
        annotation = cst.Subscript(
            value=cst.Name("List"),
            slice=[
                cst.SubscriptElement(slice=cst.Index(value=cst.Name("int")))
            ],
        )
        assert {"List", "int"} <= self._names_of(annotation)

    def test_binary_operation(self) -> None:
        """Both sides of an ``str | int`` union are collected."""
        union = cst.BinaryOperation(
            left=cst.Name("str"),
            operator=cst.BitOr(),
            right=cst.Name("int"),
        )
        assert self._names_of(union) == {"str", "int"}

    def test_tuple_annotation(self) -> None:
        """Every element of a tuple annotation is collected."""
        pair = cst.Tuple(
            elements=[
                cst.Element(value=cst.Name("int")),
                cst.Element(value=cst.Name("str")),
            ]
        )
        assert self._names_of(pair) == {"int", "str"}
|
|
|
|
|
|
class TestCSTAnnotationNameCollector:
    """Unit tests for CSTAnnotationNameCollector."""

    @staticmethod
    def _collect(source: str) -> CSTAnnotationNameCollector:
        """Parse *source*, walk it with a fresh collector, return it."""
        module = cst.parse_module(source)
        collector = CSTAnnotationNameCollector()
        module.visit(collector)
        return collector

    def test_collects_function_annotations(self) -> None:
        """Parameter and return annotations are both recorded."""
        collector = self._collect(
            "def foo(x: MyType) -> ReturnType:\n pass\n"
        )
        for expected in ("MyType", "ReturnType"):
            assert expected in collector.annotation_names

    def test_collects_imports(self) -> None:
        """Names bound by ``import`` and ``from ... import`` are tracked."""
        collector = self._collect("import os\nfrom sys import path\n")
        for expected in ("os", "path"):
            assert expected in collector.imported_names

    def test_collects_class_names(self) -> None:
        """A class statement records the class name as defined."""
        collector = self._collect("class Foo:\n pass\n")
        assert "Foo" in collector.defined_names

    def test_collects_function_names(self) -> None:
        """A def statement records the function name as defined."""
        collector = self._collect("def foo():\n pass\n")
        assert "foo" in collector.defined_names

    def test_undefined_annotations(self) -> None:
        """An unresolved annotation is reported; ``None`` is not."""
        collector = self._collect("def foo(x: MyType) -> None:\n pass\n")
        undefined = collector.get_undefined_annotation_names()
        assert "MyType" in undefined
        assert "None" not in undefined

    def test_no_undefined_with_import(self) -> None:
        """An imported annotation name is not reported as undefined."""
        collector = self._collect(
            "from mymod import MyType\ndef foo(x: MyType) -> None:\n pass\n"
        )
        assert "MyType" not in collector.get_undefined_annotation_names()

    def test_collects_kwonly_params(self) -> None:
        """Keyword-only parameter annotations are recorded."""
        collector = self._collect("def foo(*, x: KWType) -> None:\n pass\n")
        assert "KWType" in collector.annotation_names

    def test_collects_star_arg_annotation(self) -> None:
        """The ``*args`` annotation is recorded."""
        collector = self._collect(
            "def foo(*args: ArgsType) -> None:\n pass\n"
        )
        assert "ArgsType" in collector.annotation_names

    def test_collects_star_kwarg_annotation(self) -> None:
        """The ``**kwargs`` annotation is recorded."""
        collector = self._collect(
            "def foo(**kwargs: KwargsType) -> None:\n pass\n"
        )
        assert "KwargsType" in collector.annotation_names

    def test_collects_ann_assign(self) -> None:
        """An annotated assignment's annotation is recorded."""
        collector = self._collect("x: MyType = 1\n")
        assert "MyType" in collector.annotation_names
|
|
|
|
|
|
class TestFixForwardReferences:
    """Tests for fix_forward_references."""

    def test_adds_import_when_needed(self) -> None:
        """Adds future annotations import for undefined annotation names."""
        module = cst.parse_module("x = 1\n")
        code = "def foo(x: MyType) -> None:\n pass\n"
        c = OptimizationCandidate(
            cst_module=cst.parse_module(code),
            explanation="test",
            id="1",
        )
        result = fix_forward_references(module, [c])
        assert 1 == len(result)
        assert (
            "from __future__ import annotations" in result[0].cst_module.code
        )

    def test_no_change_when_all_defined(self) -> None:
        """No import is added when every annotation name is builtin."""
        module = cst.parse_module("x = 1\n")
        code = "def foo(x: int) -> None:\n pass\n"
        c = OptimizationCandidate(
            cst_module=cst.parse_module(code),
            explanation="test",
            id="1",
        )
        result = fix_forward_references(module, [c])
        assert 1 == len(result)
        assert (
            "from __future__ import annotations"
            not in result[0].cst_module.code
        )

    def test_no_change_when_already_present(self) -> None:
        """No change when future annotations import already exists."""
        module = cst.parse_module("x = 1\n")
        code = (
            "from __future__ import annotations\n"
            "def foo(x: MyType) -> None:\n pass\n"
        )
        c = OptimizationCandidate(
            cst_module=cst.parse_module(code),
            explanation="test",
            id="1",
        )
        result = fix_forward_references(module, [c])
        assert 1 == len(result)
        # MyType is still undefined here, so guard against the import being
        # inserted a second time; the count alone would not catch that.
        assert 1 == result[0].cst_module.code.count(
            "from __future__ import annotations"
        )
|
|
|
|
|
|
class TestPostprocessingPipelineAdvanced:
    """End-to-end scenarios for optimizations_postprocessing_pipeline."""

    @staticmethod
    def _candidate(source: str, explanation: str) -> OptimizationCandidate:
        """Build a candidate with id ``"1"`` from raw *source*."""
        return OptimizationCandidate(
            cst_module=cst.parse_module(source),
            explanation=explanation,
            id="1",
        )

    def test_pipeline_restores_docstring(self) -> None:
        """A docstring dropped by the candidate is put back."""
        original = cst.parse_module(
            'def foo():\n """My doc."""\n return 1\n'
        )
        candidate = self._candidate("def foo():\n return 2\n", "faster")
        output = optimizations_postprocessing_pipeline(original, [candidate])
        assert len(output) == 1
        assert "My doc." in output[0].cst_module.code

    def test_pipeline_adds_future_annotations(self) -> None:
        """An unresolved annotation name triggers the __future__ import."""
        original = cst.parse_module("def foo():\n return 1\n")
        candidate = self._candidate(
            "def foo(x: MyCustomType) -> None:\n return 2\n", "typed"
        )
        output = optimizations_postprocessing_pipeline(original, [candidate])
        assert len(output) == 1
        emitted = output[0].cst_module.code
        assert "from __future__ import annotations" in emitted

    def test_pipeline_cleans_explanations(self) -> None:
        """Markdown code fences are stripped from the explanation."""
        original = cst.parse_module("x = 1 + 1\n")
        candidate = self._candidate("x = 2\n", "```python\nx = 2\n```")
        output = optimizations_postprocessing_pipeline(original, [candidate])
        assert len(output) == 1
        assert "```" not in output[0].explanation

    def test_pipeline_filters_ellipsis(self) -> None:
        """A candidate that replaces real code with ``...`` is discarded."""
        original = cst.parse_module("x = 1\ny = 2\n")
        candidate = self._candidate("x = ...\ny = ...\n", "bad")
        output = optimizations_postprocessing_pipeline(original, [candidate])
        assert len(output) == 0
|