research: replace prune_cst boolean params with PruneConfig dataclass
This commit is contained in:
parent
282f2ba713
commit
c2c21da0c0
1 changed files with 39 additions and 113 deletions
|
|
@ -4,6 +4,7 @@ import ast
|
|||
import hashlib
|
||||
import os
|
||||
from collections import defaultdict, deque
|
||||
from dataclasses import dataclass, field
|
||||
from itertools import chain
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
|
|
@ -1650,6 +1651,18 @@ def remove_docstring_from_body(indented_block: cst.IndentedBlock) -> cst.CSTNode
|
|||
return indented_block
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class PruneConfig:
|
||||
defs_with_usages: dict[str, UsageInfo] | None = None
|
||||
helpers: set[str] | None = None
|
||||
remove_docstrings: bool = False
|
||||
include_target_in_output: bool = True
|
||||
exclude_init_from_targets: bool = False
|
||||
keep_class_init: bool = False
|
||||
include_dunder_methods: bool = False
|
||||
include_init_dunder: bool = False
|
||||
|
||||
|
||||
def parse_code_and_prune_cst(
|
||||
code: str,
|
||||
code_context_type: CodeContextType,
|
||||
|
|
@ -1662,34 +1675,28 @@ def parse_code_and_prune_cst(
|
|||
defs_with_usages = collect_top_level_defs_with_usages(module, target_functions | helpers_of_helper_functions)
|
||||
|
||||
if code_context_type == CodeContextType.READ_WRITABLE:
|
||||
filtered_node, found_target = prune_cst(
|
||||
module, target_functions, defs_with_usages=defs_with_usages, keep_class_init=True
|
||||
)
|
||||
cfg = PruneConfig(defs_with_usages=defs_with_usages, keep_class_init=True)
|
||||
elif code_context_type == CodeContextType.READ_ONLY:
|
||||
filtered_node, found_target = prune_cst(
|
||||
module,
|
||||
target_functions,
|
||||
cfg = PruneConfig(
|
||||
helpers=helpers_of_helper_functions,
|
||||
remove_docstrings=remove_docstrings,
|
||||
include_target_in_output=False,
|
||||
include_dunder_methods=True,
|
||||
)
|
||||
elif code_context_type == CodeContextType.TESTGEN:
|
||||
filtered_node, found_target = prune_cst(
|
||||
module,
|
||||
target_functions,
|
||||
cfg = PruneConfig(
|
||||
helpers=helpers_of_helper_functions,
|
||||
remove_docstrings=remove_docstrings,
|
||||
include_dunder_methods=True,
|
||||
include_init_dunder=True,
|
||||
)
|
||||
elif code_context_type == CodeContextType.HASHING:
|
||||
filtered_node, found_target = prune_cst(
|
||||
module, target_functions, remove_docstrings=True, exclude_init_from_targets=True
|
||||
)
|
||||
cfg = PruneConfig(remove_docstrings=True, exclude_init_from_targets=True)
|
||||
else:
|
||||
raise ValueError(f"Unknown code_context_type: {code_context_type}") # noqa: EM102
|
||||
|
||||
filtered_node, found_target = prune_cst(module, target_functions, cfg)
|
||||
|
||||
if not found_target:
|
||||
raise ValueError("No target functions found in the provided code")
|
||||
if filtered_node and isinstance(filtered_node, cst.Module):
|
||||
|
|
@ -1700,38 +1707,9 @@ def parse_code_and_prune_cst(
|
|||
def prune_cst(
|
||||
node: cst.CSTNode,
|
||||
target_functions: set[str],
|
||||
cfg: PruneConfig,
|
||||
prefix: str = "",
|
||||
*,
|
||||
defs_with_usages: dict[str, UsageInfo] | None = None,
|
||||
helpers: set[str] | None = None,
|
||||
remove_docstrings: bool = False,
|
||||
include_target_in_output: bool = True,
|
||||
exclude_init_from_targets: bool = False,
|
||||
keep_class_init: bool = False,
|
||||
include_dunder_methods: bool = False,
|
||||
include_init_dunder: bool = False,
|
||||
) -> tuple[cst.CSTNode | None, bool]:
|
||||
"""Unified function to prune CST nodes based on various filtering criteria.
|
||||
|
||||
Args:
|
||||
node: The CST node to filter
|
||||
target_functions: Set of qualified function names that are targets
|
||||
prefix: Current qualified name prefix (for class methods)
|
||||
defs_with_usages: Dict of definitions with usage info (for READ_WRITABLE mode)
|
||||
helpers: Set of helper function qualified names (for READ_ONLY/TESTGEN modes)
|
||||
remove_docstrings: Whether to remove docstrings from output
|
||||
include_target_in_output: Whether to include target functions in output
|
||||
exclude_init_from_targets: Whether to exclude __init__ from targets (HASHING mode)
|
||||
keep_class_init: Whether to keep __init__ methods in classes (READ_WRITABLE mode)
|
||||
include_dunder_methods: Whether to include dunder methods (READ_ONLY/TESTGEN modes)
|
||||
include_init_dunder: Whether to include __init__ in dunder methods
|
||||
|
||||
Returns:
|
||||
(filtered_node, found_target):
|
||||
filtered_node: The modified CST node or None if it should be removed.
|
||||
found_target: True if a target function was found in this node's subtree.
|
||||
|
||||
"""
|
||||
if isinstance(node, (cst.Import, cst.ImportFrom)):
|
||||
return None, False
|
||||
|
||||
|
|
@ -1739,37 +1717,33 @@ def prune_cst(
|
|||
qualified_name = f"{prefix}.{node.name.value}" if prefix else node.name.value
|
||||
|
||||
# Check if it's a helper function (higher priority than target)
|
||||
if helpers and qualified_name in helpers:
|
||||
if remove_docstrings and isinstance(node.body, cst.IndentedBlock):
|
||||
if cfg.helpers and qualified_name in cfg.helpers:
|
||||
if cfg.remove_docstrings and isinstance(node.body, cst.IndentedBlock):
|
||||
return node.with_changes(body=remove_docstring_from_body(node.body)), True
|
||||
return node, True
|
||||
|
||||
# Check if it's a target function
|
||||
if qualified_name in target_functions:
|
||||
# Handle exclude_init_from_targets for HASHING mode
|
||||
if exclude_init_from_targets and node.name.value == "__init__":
|
||||
if cfg.exclude_init_from_targets and node.name.value == "__init__":
|
||||
return None, False
|
||||
|
||||
if include_target_in_output:
|
||||
if remove_docstrings and isinstance(node.body, cst.IndentedBlock):
|
||||
if cfg.include_target_in_output:
|
||||
if cfg.remove_docstrings and isinstance(node.body, cst.IndentedBlock):
|
||||
return node.with_changes(body=remove_docstring_from_body(node.body)), True
|
||||
return node, True
|
||||
return None, True
|
||||
|
||||
# Handle class __init__ for READ_WRITABLE mode
|
||||
if keep_class_init and node.name.value == "__init__":
|
||||
if cfg.keep_class_init and node.name.value == "__init__":
|
||||
return node, False
|
||||
|
||||
# Handle dunder methods for READ_ONLY/TESTGEN modes
|
||||
if (
|
||||
include_dunder_methods
|
||||
cfg.include_dunder_methods
|
||||
and len(node.name.value) > 4
|
||||
and node.name.value.startswith("__")
|
||||
and node.name.value.endswith("__")
|
||||
):
|
||||
if not include_init_dunder and node.name.value == "__init__":
|
||||
if not cfg.include_init_dunder and node.name.value == "__init__":
|
||||
return None, False
|
||||
if remove_docstrings and isinstance(node.body, cst.IndentedBlock):
|
||||
if cfg.remove_docstrings and isinstance(node.body, cst.IndentedBlock):
|
||||
return node.with_changes(body=remove_docstring_from_body(node.body)), False
|
||||
return node, False
|
||||
|
||||
|
|
@ -1780,43 +1754,26 @@ def prune_cst(
|
|||
return None, False
|
||||
if not isinstance(node.body, cst.IndentedBlock):
|
||||
raise ValueError("ClassDef body is not an IndentedBlock") # noqa: TRY004
|
||||
class_prefix = node.name.value
|
||||
class_name = node.name.value
|
||||
|
||||
# Handle dependency classes for READ_WRITABLE mode
|
||||
if defs_with_usages:
|
||||
# Check if this class contains any target functions
|
||||
if cfg.defs_with_usages:
|
||||
has_target_functions = any(
|
||||
isinstance(stmt, cst.FunctionDef) and f"{class_prefix}.{stmt.name.value}" in target_functions
|
||||
isinstance(stmt, cst.FunctionDef) and f"{class_name}.{stmt.name.value}" in target_functions
|
||||
for stmt in node.body.body
|
||||
)
|
||||
|
||||
# If the class is used as a dependency (not containing target functions), keep it entirely
|
||||
if (
|
||||
not has_target_functions
|
||||
and class_name in defs_with_usages
|
||||
and defs_with_usages[class_name].used_by_qualified_function
|
||||
and class_name in cfg.defs_with_usages
|
||||
and cfg.defs_with_usages[class_name].used_by_qualified_function
|
||||
):
|
||||
return node, True
|
||||
|
||||
# Recursively filter each statement in the class body
|
||||
new_class_body: list[cst.CSTNode] = []
|
||||
found_in_class = False
|
||||
|
||||
for stmt in node.body.body:
|
||||
filtered, found_target = prune_cst(
|
||||
stmt,
|
||||
target_functions,
|
||||
class_prefix,
|
||||
defs_with_usages=defs_with_usages,
|
||||
helpers=helpers,
|
||||
remove_docstrings=remove_docstrings,
|
||||
include_target_in_output=include_target_in_output,
|
||||
exclude_init_from_targets=exclude_init_from_targets,
|
||||
keep_class_init=keep_class_init,
|
||||
include_dunder_methods=include_dunder_methods,
|
||||
include_init_dunder=include_init_dunder,
|
||||
)
|
||||
filtered, found_target = prune_cst(stmt, target_functions, cfg, class_name)
|
||||
found_in_class |= found_target
|
||||
if filtered:
|
||||
new_class_body.append(filtered)
|
||||
|
|
@ -1824,8 +1781,7 @@ def prune_cst(
|
|||
if not found_in_class:
|
||||
return None, False
|
||||
|
||||
# Apply docstring removal to class if needed
|
||||
if remove_docstrings and new_class_body:
|
||||
if cfg.remove_docstrings and new_class_body:
|
||||
updated_body = node.body.with_changes(body=new_class_body)
|
||||
assert isinstance(updated_body, cst.IndentedBlock)
|
||||
return node.with_changes(body=remove_docstring_from_body(updated_body)), True
|
||||
|
|
@ -1833,9 +1789,9 @@ def prune_cst(
|
|||
return node.with_changes(body=node.body.with_changes(body=new_class_body)) if new_class_body else None, True
|
||||
|
||||
# Handle assignments for READ_WRITABLE mode
|
||||
if defs_with_usages is not None:
|
||||
if cfg.defs_with_usages is not None:
|
||||
if isinstance(node, (cst.Assign, cst.AnnAssign, cst.AugAssign)):
|
||||
if is_assignment_used(node, defs_with_usages):
|
||||
if is_assignment_used(node, cfg.defs_with_usages):
|
||||
return node, True
|
||||
return None, False
|
||||
|
||||
|
|
@ -1844,41 +1800,11 @@ def prune_cst(
|
|||
if not section_names:
|
||||
return node, False
|
||||
|
||||
if helpers is not None:
|
||||
return recurse_sections(
|
||||
node,
|
||||
section_names,
|
||||
lambda child: prune_cst(
|
||||
child,
|
||||
target_functions,
|
||||
prefix,
|
||||
defs_with_usages=defs_with_usages,
|
||||
helpers=helpers,
|
||||
remove_docstrings=remove_docstrings,
|
||||
include_target_in_output=include_target_in_output,
|
||||
exclude_init_from_targets=exclude_init_from_targets,
|
||||
keep_class_init=keep_class_init,
|
||||
include_dunder_methods=include_dunder_methods,
|
||||
include_init_dunder=include_init_dunder,
|
||||
),
|
||||
keep_non_target_children=True,
|
||||
)
|
||||
return recurse_sections(
|
||||
node,
|
||||
section_names,
|
||||
lambda child: prune_cst(
|
||||
child,
|
||||
target_functions,
|
||||
prefix,
|
||||
defs_with_usages=defs_with_usages,
|
||||
helpers=helpers,
|
||||
remove_docstrings=remove_docstrings,
|
||||
include_target_in_output=include_target_in_output,
|
||||
exclude_init_from_targets=exclude_init_from_targets,
|
||||
keep_class_init=keep_class_init,
|
||||
include_dunder_methods=include_dunder_methods,
|
||||
include_init_dunder=include_init_dunder,
|
||||
),
|
||||
lambda child: prune_cst(child, target_functions, cfg, prefix),
|
||||
keep_non_target_children=cfg.helpers is not None,
|
||||
)
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue