perf: eliminate redundant CST parsing in get_code_optimization_context

Parse each file once instead of up to 16 times by:
- Making remove_unused_definitions_by_function_names accept/return cst.Module
- Making parse_code_and_prune_cst and add_needed_imports_from_module accept cst.Module
- Threading the parsed Module through process_file_context
- Adding extract_all_contexts_from_files that processes all 4 context types
  (READ_WRITABLE, READ_ONLY, HASHING, TESTGEN) in a single per-file pass
Kevin Turcios 2026-03-16 10:11:58 -06:00
parent 17f4bbd6f9
commit 5671562da2
4 changed files with 294 additions and 119 deletions
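The sketch below illustrates the parse-once pattern the commit message describes, using plain libcst. ContextType, prune_for_context, and extract_all_contexts are hypothetical stand-ins for the project's CodeContextType, parse_code_and_prune_cst, and extract_all_contexts_from_files; it shows the shape of the change, not the actual implementation.

from enum import Enum

import libcst as cst


class ContextType(Enum):  # hypothetical stand-in for CodeContextType
    READ_WRITABLE = "read_writable"
    READ_ONLY = "read_only"
    HASHING = "hashing"
    TESTGEN = "testgen"


def prune_for_context(module: cst.Module, context_type: ContextType) -> str:
    # Stand-in for parse_code_and_prune_cst: the real helper filters the CST per
    # context type; here we simply render the module back to source.
    return module.code


def extract_all_contexts(source: str) -> dict[ContextType, str]:
    # Before this commit: each context type re-parsed the file (up to 16 parses per file).
    # After: parse once and reuse the same cst.Module for every context type.
    module = cst.parse_module(source)
    return {ctx: prune_for_context(module, ctx) for ctx in ContextType}


if __name__ == "__main__":
    contexts = extract_all_contexts("def target():\n    return 1\n")
    print(sorted(ctx.value for ctx in contexts))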


@@ -4,7 +4,7 @@ import ast
import hashlib
import os
from collections import defaultdict
from dataclasses import dataclass, field
from dataclasses import dataclass
from itertools import chain
from pathlib import Path
from typing import TYPE_CHECKING
@@ -46,22 +46,22 @@ if TYPE_CHECKING:
from codeflash.languages.python.context.unused_definition_remover import UsageInfo
@dataclass
class AllContextResults:
read_writable: CodeStringsMarkdown
read_only: CodeStringsMarkdown
hashing: CodeStringsMarkdown
testgen: CodeStringsMarkdown
def build_testgen_context(
helpers_of_fto_dict: dict[Path, set[FunctionSource]],
helpers_of_helpers_dict: dict[Path, set[FunctionSource]],
testgen_base: CodeStringsMarkdown,
project_root_path: Path,
*,
remove_docstrings: bool = False,
include_enrichment: bool = True,
function_to_optimize: FunctionToOptimize | None = None,
) -> CodeStringsMarkdown:
testgen_context = extract_code_markdown_context_from_files(
helpers_of_fto_dict,
helpers_of_helpers_dict,
project_root_path,
remove_docstrings=remove_docstrings,
code_context_type=CodeContextType.TESTGEN,
)
testgen_context = testgen_base
if include_enrichment:
enrichment = enrich_testgen_context(testgen_context, project_root_path)
@@ -114,14 +114,10 @@ def get_code_optimization_context(
helpers_of_fto_qualified_names_dict, project_root_path
)
# Extract code context for optimization
final_read_writable_code = extract_code_markdown_context_from_files(
helpers_of_fto_dict,
{},
project_root_path,
remove_docstrings=False,
code_context_type=CodeContextType.READ_WRITABLE,
)
# Extract all code contexts in a single pass (one CST parse per file)
all_ctx = extract_all_contexts_from_files(helpers_of_fto_dict, helpers_of_helpers_dict, project_root_path)
final_read_writable_code = all_ctx.read_writable
# Ensure the target file is first in the code blocks so the LLM knows which file to optimize
target_relative = function_to_optimize.file_path.resolve().relative_to(project_root_path.resolve())
@@ -130,20 +126,7 @@ def get_code_optimization_context(
if target_blocks:
final_read_writable_code.code_strings = target_blocks + other_blocks
read_only_code_markdown = extract_code_markdown_context_from_files(
helpers_of_fto_dict,
helpers_of_helpers_dict,
project_root_path,
remove_docstrings=False,
code_context_type=CodeContextType.READ_ONLY,
)
hashing_code_context = extract_code_markdown_context_from_files(
helpers_of_fto_dict,
helpers_of_helpers_dict,
project_root_path,
remove_docstrings=True,
code_context_type=CodeContextType.HASHING,
)
read_only_code_markdown = all_ctx.read_only
# Handle token limits
final_read_writable_tokens = encoded_tokens_len(final_read_writable_code.markdown)
@@ -173,32 +156,29 @@ def get_code_optimization_context(
# Progressive fallback for testgen context token limits
testgen_context = build_testgen_context(
helpers_of_fto_dict, helpers_of_helpers_dict, project_root_path, function_to_optimize=function_to_optimize
all_ctx.testgen, project_root_path, function_to_optimize=function_to_optimize
)
if encoded_tokens_len(testgen_context.markdown) > testgen_token_limit:
logger.debug("Testgen context exceeded token limit, removing docstrings")
testgen_context = build_testgen_context(
testgen_base_no_docs = extract_code_markdown_context_from_files(
helpers_of_fto_dict,
helpers_of_helpers_dict,
project_root_path,
remove_docstrings=True,
function_to_optimize=function_to_optimize,
code_context_type=CodeContextType.TESTGEN,
)
testgen_context = build_testgen_context(
testgen_base_no_docs, project_root_path, function_to_optimize=function_to_optimize
)
if encoded_tokens_len(testgen_context.markdown) > testgen_token_limit:
logger.debug("Testgen context still exceeded token limit, removing enrichment")
testgen_context = build_testgen_context(
helpers_of_fto_dict,
helpers_of_helpers_dict,
project_root_path,
remove_docstrings=True,
include_enrichment=False,
)
testgen_context = build_testgen_context(testgen_base_no_docs, project_root_path, include_enrichment=False)
if encoded_tokens_len(testgen_context.markdown) > testgen_token_limit:
raise ValueError(TESTGEN_LIMIT_ERROR)
code_hash_context = hashing_code_context.markdown
code_hash_context = all_ctx.hashing.markdown
code_hash = hashlib.sha256(code_hash_context.encode("utf-8")).hexdigest()
all_helper_fqns = list({fs.fully_qualified_name for fs in helpers_of_fto_list + helpers_of_helpers_list})
@@ -230,15 +210,17 @@ def process_file_context(
logger.exception(f"Error while parsing {file_path}: {e}")
return None
try:
original_module = cst.parse_module(original_code)
except Exception as e:
logger.debug(f"Failed to parse {file_path} with libcst: {type(e).__name__}: {e}")
return None
try:
all_names = primary_qualified_names | secondary_qualified_names
code_without_unused_defs = remove_unused_definitions_by_function_names(original_code, all_names)
cleaned_module = remove_unused_definitions_by_function_names(original_module, all_names)
pruned_module = parse_code_and_prune_cst(
code_without_unused_defs,
code_context_type,
primary_qualified_names,
secondary_qualified_names,
remove_docstrings,
cleaned_module, code_context_type, primary_qualified_names, secondary_qualified_names, remove_docstrings
)
except ValueError as e:
logger.debug(f"Error while getting read-only code: {e}")
@@ -249,7 +231,7 @@ def process_file_context(
code_context = ast.unparse(ast.parse(pruned_module.code))
else:
code_context = add_needed_imports_from_module(
src_module_code=original_code,
src_module_code=original_module,
dst_module_code=pruned_module,
src_path=file_path,
dst_path=file_path,
@@ -264,6 +246,197 @@ def process_file_context(
return None
def extract_all_contexts_from_files(
helpers_of_fto: dict[Path, set[FunctionSource]],
helpers_of_helpers: dict[Path, set[FunctionSource]],
project_root_path: Path,
) -> AllContextResults:
"""Extract all 4 code context types from files in a single pass, parsing each file only once."""
# Deduplicate: remove HoH entries that overlap with FTO
helpers_of_helpers_no_overlap: dict[Path, set[FunctionSource]] = {}
for file_path, function_sources in helpers_of_helpers.items():
if file_path in helpers_of_fto:
helpers_of_helpers[file_path] -= helpers_of_fto[file_path]
else:
helpers_of_helpers_no_overlap[file_path] = function_sources
rw = CodeStringsMarkdown()
ro = CodeStringsMarkdown()
hashing = CodeStringsMarkdown()
testgen = CodeStringsMarkdown()
# Process files containing FTO helpers (all 4 context types)
for file_path, function_sources in helpers_of_fto.items():
fto_names = {func.qualified_name for func in function_sources}
hoh_funcs = helpers_of_helpers.get(file_path, set())
hoh_names = {func.qualified_name for func in hoh_funcs}
rw_helper_functions = list(function_sources)
all_helper_functions = list(function_sources | hoh_funcs)
try:
original_code = file_path.read_text("utf8")
except Exception as e:
logger.exception(f"Error while parsing {file_path}: {e}")
continue
try:
original_module = cst.parse_module(original_code)
except Exception as e:
logger.debug(f"Failed to parse {file_path} with libcst: {type(e).__name__}: {e}")
continue
try:
relative_path = file_path.resolve().relative_to(project_root_path.resolve())
except ValueError:
relative_path = file_path
# Clean by fto_names only (for RW)
rw_cleaned = remove_unused_definitions_by_function_names(original_module, fto_names)
# Clean by all names (for RO/HASH/TESTGEN) — reuse rw_cleaned if no extra HoH names
all_names = fto_names | hoh_names
all_cleaned = (
remove_unused_definitions_by_function_names(original_module, all_names) if hoh_names else rw_cleaned
)
# READ_WRITABLE
try:
rw_pruned = parse_code_and_prune_cst(
rw_cleaned, CodeContextType.READ_WRITABLE, fto_names, set(), remove_docstrings=False
)
if rw_pruned.code.strip():
rw_code = add_needed_imports_from_module(
src_module_code=original_module,
dst_module_code=rw_pruned,
src_path=file_path,
dst_path=file_path,
project_root=project_root_path,
helper_functions=rw_helper_functions,
)
rw.code_strings.append(CodeString(code=rw_code, file_path=relative_path))
except ValueError as e:
logger.debug(f"Error while getting read-writable code: {e}")
# READ_ONLY
try:
ro_pruned = parse_code_and_prune_cst(
all_cleaned, CodeContextType.READ_ONLY, fto_names, hoh_names, remove_docstrings=False
)
if ro_pruned.code.strip():
ro_code = add_needed_imports_from_module(
src_module_code=original_module,
dst_module_code=ro_pruned,
src_path=file_path,
dst_path=file_path,
project_root=project_root_path,
helper_functions=all_helper_functions,
)
ro.code_strings.append(CodeString(code=ro_code, file_path=relative_path))
except ValueError as e:
logger.debug(f"Error while getting read-only code: {e}")
# HASHING
try:
hash_pruned = parse_code_and_prune_cst(
all_cleaned, CodeContextType.HASHING, fto_names, hoh_names, remove_docstrings=True
)
if hash_pruned.code.strip():
hash_code = ast.unparse(ast.parse(hash_pruned.code))
hashing.code_strings.append(CodeString(code=hash_code, file_path=relative_path))
except ValueError as e:
logger.debug(f"Error while getting hashing code: {e}")
# TESTGEN
try:
testgen_pruned = parse_code_and_prune_cst(
all_cleaned, CodeContextType.TESTGEN, fto_names, hoh_names, remove_docstrings=False
)
if testgen_pruned.code.strip():
testgen_code = add_needed_imports_from_module(
src_module_code=original_module,
dst_module_code=testgen_pruned,
src_path=file_path,
dst_path=file_path,
project_root=project_root_path,
helper_functions=all_helper_functions,
)
testgen.code_strings.append(CodeString(code=testgen_code, file_path=relative_path))
except ValueError as e:
logger.debug(f"Error while getting testgen code: {e}")
# Process files containing only helpers of helpers (RO/HASH/TESTGEN only)
for file_path, function_sources in helpers_of_helpers_no_overlap.items():
hoh_names = {func.qualified_name for func in function_sources}
helper_functions = list(function_sources)
try:
original_code = file_path.read_text("utf8")
except Exception as e:
logger.exception(f"Error while parsing {file_path}: {e}")
continue
try:
original_module = cst.parse_module(original_code)
except Exception as e:
logger.debug(f"Failed to parse {file_path} with libcst: {type(e).__name__}: {e}")
continue
try:
relative_path = file_path.resolve().relative_to(project_root_path.resolve())
except ValueError:
relative_path = file_path
cleaned = remove_unused_definitions_by_function_names(original_module, hoh_names)
# READ_ONLY
try:
ro_pruned = parse_code_and_prune_cst(
cleaned, CodeContextType.READ_ONLY, set(), hoh_names, remove_docstrings=False
)
if ro_pruned.code.strip():
ro_code = add_needed_imports_from_module(
src_module_code=original_module,
dst_module_code=ro_pruned,
src_path=file_path,
dst_path=file_path,
project_root=project_root_path,
helper_functions=helper_functions,
)
ro.code_strings.append(CodeString(code=ro_code, file_path=relative_path))
except ValueError as e:
logger.debug(f"Error while getting read-only code: {e}")
# HASHING
try:
hash_pruned = parse_code_and_prune_cst(
cleaned, CodeContextType.HASHING, set(), hoh_names, remove_docstrings=True
)
if hash_pruned.code.strip():
hash_code = ast.unparse(ast.parse(hash_pruned.code))
hashing.code_strings.append(CodeString(code=hash_code, file_path=relative_path))
except ValueError as e:
logger.debug(f"Error while getting hashing code: {e}")
# TESTGEN
try:
testgen_pruned = parse_code_and_prune_cst(
cleaned, CodeContextType.TESTGEN, set(), hoh_names, remove_docstrings=False
)
if testgen_pruned.code.strip():
testgen_code = add_needed_imports_from_module(
src_module_code=original_module,
dst_module_code=testgen_pruned,
src_path=file_path,
dst_path=file_path,
project_root=project_root_path,
helper_functions=helper_functions,
)
testgen.code_strings.append(CodeString(code=testgen_code, file_path=relative_path))
except ValueError as e:
logger.debug(f"Error while getting testgen code: {e}")
return AllContextResults(read_writable=rw, read_only=ro, hashing=hashing, testgen=testgen)
def extract_code_markdown_context_from_files(
helpers_of_fto: dict[Path, set[FunctionSource]],
helpers_of_helpers: dict[Path, set[FunctionSource]],
@@ -641,12 +814,13 @@ def _get_class_start_line(class_node: ast.ClassDef) -> int:
def _class_has_explicit_init(class_node: ast.ClassDef) -> bool:
return any(isinstance(item, (ast.FunctionDef, ast.AsyncFunctionDef)) and item.name == "__init__" for item in class_node.body)
return any(
isinstance(item, (ast.FunctionDef, ast.AsyncFunctionDef)) and item.name == "__init__"
for item in class_node.body
)
def _collect_synthetic_constructor_type_names(
class_node: ast.ClassDef, import_aliases: dict[str, str]
) -> set[str]:
def _collect_synthetic_constructor_type_names(class_node: ast.ClassDef, import_aliases: dict[str, str]) -> set[str]:
is_dataclass, dataclass_init_enabled, _ = _get_dataclass_config(class_node, import_aliases)
if not _is_namedtuple_class(class_node, import_aliases) and not is_dataclass:
return set()
@@ -712,7 +886,9 @@ def _extract_synthetic_init_parameters(
if not include_in_init:
continue
parameters.append((item.target.id, _get_node_source(item.annotation, module_source, "Any"), default_value, kw_only))
parameters.append(
(item.target.id, _get_node_source(item.annotation, module_source, "Any"), default_value, kw_only)
)
return parameters
@@ -727,10 +903,7 @@ def _build_synthetic_init_stub(
return None
parameters = _extract_synthetic_init_parameters(
class_node,
module_source,
import_aliases,
kw_only_by_default=dataclass_kw_only,
class_node, module_source, import_aliases, kw_only_by_default=dataclass_kw_only
)
if not parameters:
return None
@@ -750,9 +923,7 @@ def _build_synthetic_init_stub(
return f" def __init__({signature}):\n ..."
def _extract_function_stub_snippet(
fn_node: ast.FunctionDef | ast.AsyncFunctionDef, module_lines: list[str]
) -> str:
def _extract_function_stub_snippet(fn_node: ast.FunctionDef | ast.AsyncFunctionDef, module_lines: list[str]) -> str:
start_line = min(d.lineno for d in fn_node.decorator_list) if fn_node.decorator_list else fn_node.lineno
return "\n".join(module_lines[start_line - 1 : fn_node.end_lineno])
@@ -779,13 +950,17 @@ def _has_non_property_method_decorator(
def _has_descriptor_like_class_fields(class_node: ast.ClassDef) -> bool:
return any(isinstance(item, (ast.Assign, ast.AnnAssign)) and isinstance(item.value, ast.Call) for item in class_node.body)
return any(
isinstance(item, (ast.Assign, ast.AnnAssign)) and isinstance(item.value, ast.Call) for item in class_node.body
)
def _should_use_raw_project_class_context(class_node: ast.ClassDef, import_aliases: dict[str, str]) -> bool:
start_line = _get_class_start_line(class_node)
class_line_count = class_node.end_lineno - start_line + 1
is_small = class_line_count <= MAX_RAW_PROJECT_CLASS_LINES and len(class_node.body) <= MAX_RAW_PROJECT_CLASS_BODY_ITEMS
is_small = (
class_line_count <= MAX_RAW_PROJECT_CLASS_LINES and len(class_node.body) <= MAX_RAW_PROJECT_CLASS_BODY_ITEMS
)
if is_small and _class_has_explicit_init(class_node):
return True
@@ -933,11 +1108,7 @@ def _append_project_class_context(
if base_expr_name is None:
continue
resolved = _resolve_imported_class_reference(
base_expr_name,
module_tree,
module_path,
project_root_path,
module_cache,
base_expr_name, module_tree, module_path, project_root_path, module_cache
)
if resolved is None:
continue
@@ -955,16 +1126,16 @@ def _append_project_class_context(
code_strings,
)
code_strings.append(CodeString(code=_extract_raw_class_context(class_node, module_source, module_tree), file_path=module_path))
code_strings.append(
CodeString(code=_extract_raw_class_context(class_node, module_source, module_tree), file_path=module_path)
)
emitted_classes.add(class_key)
emitted_class_names.add(class_name)
return True
def _collect_type_names_from_function(
func_node: ast.FunctionDef | ast.AsyncFunctionDef,
tree: ast.Module,
class_name: str | None,
func_node: ast.FunctionDef | ast.AsyncFunctionDef, tree: ast.Module, class_name: str | None
) -> set[str]:
type_names: set[str] = set()
for arg in func_node.args.args + func_node.args.posonlyargs + func_node.args.kwonlyargs:
@@ -974,7 +1145,11 @@ def _collect_type_names_from_function(
if func_node.args.kwarg:
type_names |= collect_type_names_from_annotation(func_node.args.kwarg.annotation)
for body_node in ast.walk(func_node):
if isinstance(body_node, ast.Call) and isinstance(body_node.func, ast.Name) and body_node.func.id == "isinstance":
if (
isinstance(body_node, ast.Call)
and isinstance(body_node.func, ast.Name)
and body_node.func.id == "isinstance"
):
if len(body_node.args) >= 2:
second_arg = body_node.args[1]
if isinstance(second_arg, ast.Name):
@@ -1305,7 +1480,11 @@ def extract_imports_for_class(module_tree: ast.Module, class_node: ast.ClassDef,
if not isinstance(node, (ast.Import, ast.ImportFrom)) or node.lineno in added_imports:
continue
for alias in node.names:
name = alias.asname if alias.asname else (alias.name.split(".")[0] if isinstance(node, ast.Import) else alias.name)
name = (
alias.asname
if alias.asname
else (alias.name.split(".")[0] if isinstance(node, ast.Import) else alias.name)
)
if name in needed_names:
import_lines.append(source_lines[node.lineno - 1])
added_imports.add(node.lineno)
@@ -1340,9 +1519,7 @@ def remove_docstring_from_body(indented_block: cst.IndentedBlock) -> cst.CSTNode
return indented_block
def _maybe_strip_docstring(
node: cst.FunctionDef | cst.ClassDef, cfg: PruneConfig
) -> cst.FunctionDef | cst.ClassDef:
def _maybe_strip_docstring(node: cst.FunctionDef | cst.ClassDef, cfg: PruneConfig) -> cst.FunctionDef | cst.ClassDef:
if cfg.remove_docstrings and isinstance(node.body, cst.IndentedBlock):
return node.with_changes(body=remove_docstring_from_body(node.body))
return node
@@ -1361,17 +1538,17 @@ class PruneConfig:
def parse_code_and_prune_cst(
code: str,
code: str | cst.Module,
code_context_type: CodeContextType,
target_functions: set[str],
helpers_of_helper_functions: set[str] = set(), # noqa: B006
remove_docstrings: bool = False,
) -> cst.Module:
"""Parse and filter the code CST, returning the pruned Module."""
module = cst.parse_module(code)
defs_with_usages = collect_top_level_defs_with_usages(module, target_functions | helpers_of_helper_functions)
module = code if isinstance(code, cst.Module) else cst.parse_module(code)
if code_context_type == CodeContextType.READ_WRITABLE:
defs_with_usages = collect_top_level_defs_with_usages(module, target_functions | helpers_of_helper_functions)
cfg = PruneConfig(defs_with_usages=defs_with_usages, keep_class_init=True)
elif code_context_type == CodeContextType.READ_ONLY:
cfg = PruneConfig(
@@ -1402,10 +1579,7 @@ def parse_code_and_prune_cst(
def prune_cst(
node: cst.CSTNode,
target_functions: set[str],
cfg: PruneConfig,
prefix: str = "",
node: cst.CSTNode, target_functions: set[str], cfg: PruneConfig, prefix: str = ""
) -> tuple[cst.CSTNode | None, bool]:
if isinstance(node, (cst.Import, cst.ImportFrom)):
return None, False


@@ -383,7 +383,9 @@ def remove_unused_definitions_recursively(
if isinstance(statement, cst.FunctionDef):
new_statements.append(statement)
elif isinstance(statement, (cst.Assign, cst.AnnAssign, cst.AugAssign)):
if class_has_dependencies or is_assignment_used(statement, definitions, name_prefix=f"{class_name}."):
if class_has_dependencies or is_assignment_used(
statement, definitions, name_prefix=f"{class_name}."
):
new_statements.append(statement)
else:
new_statements.append(statement)
@@ -425,13 +427,15 @@ def collect_top_level_defs_with_usages(
return definitions
def remove_unused_definitions_by_function_names(code: str, qualified_function_names: set[str]) -> str:
def remove_unused_definitions_by_function_names(
code: Union[str, cst.Module], qualified_function_names: set[str]
) -> cst.Module:
"""Remove top-level definitions (classes, variables, functions) not used by the specified qualified function names."""
try:
module = cst.parse_module(code)
module = code if isinstance(code, cst.Module) else cst.parse_module(code)
except Exception as e:
logger.debug(f"Failed to parse code with libcst: {type(e).__name__}: {e}")
return code
return code if isinstance(code, cst.Module) else cst.parse_module("")
try:
defs_with_usages = collect_top_level_defs_with_usages(module, qualified_function_names)
@@ -439,11 +443,11 @@ def remove_unused_definitions_by_function_na
# Apply the recursive removal transformation
modified_module, _ = remove_unused_definitions_recursively(module, defs_with_usages)
return modified_module.code if modified_module else ""
return modified_module if modified_module else cst.parse_module("")
except Exception as e:
# If any other error occurs during processing, return the original code
logger.debug(f"Error processing code to remove unused definitions: {type(e).__name__}: {e}")
return code
return module
def revert_unused_helper_functions(
@@ -569,11 +573,7 @@ def find_target_node(
def _collect_attr_names(
value_id: str,
attr_name: str,
class_name: str | None,
names: set[str],
imported_names_map: dict[str, set[str]],
value_id: str, attr_name: str, class_name: str | None, names: set[str], imported_names_map: dict[str, set[str]]
) -> None:
if value_id == "self":
names.add(attr_name)
@@ -605,9 +605,7 @@ def _collect_called_names(
called.update(mapped_names)
elif isinstance(node.func, ast.Attribute):
if isinstance(node.func.value, ast.Name):
_collect_attr_names(
node.func.value.id, node.func.attr, class_name, called, imported_names_map
)
_collect_attr_names(node.func.value.id, node.func.attr, class_name, called, imported_names_map)
else:
called.add(node.func.attr)
elif isinstance(node, ast.Attribute) and isinstance(node.value, ast.Name):


@@ -540,7 +540,7 @@ def resolve_star_import(module_name: str, project_root: Path) -> set[str]:
def add_needed_imports_from_module(
src_module_code: str,
src_module_code: str | cst.Module,
dst_module_code: str | cst.Module,
src_path: Path,
dst_path: Path,
@@ -549,7 +549,6 @@ def add_needed_imports_from_module(
helper_functions_fqn: set[str] | None = None,
) -> str:
"""Add all needed and used source module code imports to the destination module code, and return it."""
src_module_code = delete___future___aliased_imports(src_module_code)
if not helper_functions_fqn:
helper_functions_fqn = {f.fully_qualified_name for f in (helper_functions or [])}
@@ -571,7 +570,10 @@
)
)
try:
src_module = cst.parse_module(src_module_code)
if isinstance(src_module_code, cst.Module):
src_module = src_module_code.visit(FutureAliasedImportTransformer())
else:
src_module = cst.parse_module(src_module_code).visit(FutureAliasedImportTransformer())
# Exclude function/class bodies so GatherImportsVisitor only sees module-level imports.
# Nested imports (inside functions) are part of function logic and must not be
# scheduled for add/remove — RemoveImportsVisitor would strip them as "unused".


@@ -33,7 +33,7 @@ def another_function():
qualified_functions = {"main_function"}
result = remove_unused_definitions_by_function_names(code, qualified_functions)
# Normalize whitespace for comparison
assert result.strip() == expected.strip()
assert result.code.strip() == expected.strip()
def test_class_variable_removal() -> None:
@@ -84,7 +84,7 @@ def helper_function():
qualified_functions = {"helper_function"}
result = remove_unused_definitions_by_function_names(code, qualified_functions)
# Normalize whitespace for comparison
assert result.strip() == expected.strip()
assert result.code.strip() == expected.strip()
def test_complex_variable_dependencies() -> None:
@@ -122,7 +122,7 @@ def tuple_user():
qualified_functions = {"main_function"}
result = remove_unused_definitions_by_function_names(code, qualified_functions)
assert result.strip() == expected.strip()
assert result.code.strip() == expected.strip()
def test_type_annotation_usage() -> None:
@@ -156,7 +156,7 @@ def unused_function(param: UnusedType) -> UnusedType:
qualified_functions = {"main_function"}
result = remove_unused_definitions_by_function_names(code, qualified_functions)
# Normalize whitespace for comparison
assert result.strip() == expected.strip()
assert result.code.strip() == expected.strip()
def test_class_method_with_dunder_methods() -> None:
@@ -215,7 +215,7 @@ def helper_function():
qualified_functions = {"MyClass.target_method"}
result = remove_unused_definitions_by_function_names(code, qualified_functions)
# Normalize whitespace for comparison
assert result.strip() == expected.strip()
assert result.code.strip() == expected.strip()
def test_complex_type_annotations() -> None:
@@ -263,7 +263,7 @@ def unused_function(param: UnusedType) -> None:
qualified_functions = {"process_data"}
result = remove_unused_definitions_by_function_names(code, qualified_functions)
assert result.strip() == expected.strip()
assert result.code.strip() == expected.strip()
def test_try_except_finally_variables() -> None:
@@ -325,7 +325,7 @@ def unused_function():
qualified_functions = {"use_constants", "use_cleanup"}
result = remove_unused_definitions_by_function_names(code, qualified_functions)
assert result.strip() == expected.strip()
assert result.code.strip() == expected.strip()
def test_base_class_inheritance() -> None:
@@ -383,8 +383,9 @@ def test_function():
qualified_functions = {"test_function"}
result = remove_unused_definitions_by_function_names(code, qualified_functions)
# LayoutDumper should be preserved because ObjectDetectionLayoutDumper inherits from it
assert "class LayoutDumper" in result
assert "class ObjectDetectionLayoutDumper" in result
assert "class LayoutDumper" in result.code
assert "class ObjectDetectionLayoutDumper" in result.code
assert result.code.strip() == expected.strip()
def test_conditional_and_loop_variables() -> None:
@@ -471,7 +472,7 @@ def unused_function():
qualified_functions = {"get_platform_info", "get_loop_result"}
result = remove_unused_definitions_by_function_names(code, qualified_functions)
assert result.strip() == expected.strip()
assert result.code.strip() == expected.strip()
def test_enum_attribute_access_dependency() -> None:
@@ -519,10 +520,10 @@ def process_message(kind):
qualified_functions = {"process_message"}
result = remove_unused_definitions_by_function_names(code, qualified_functions)
# MessageKind should be preserved because process_message uses MessageKind.VALUE
assert "class MessageKind" in result
assert "class MessageKind" in result.code
# UNUSED_VAR should be removed
assert "UNUSED_VAR" not in result
assert result.strip() == expected.strip()
assert "UNUSED_VAR" not in result.code
assert result.code.strip() == expected.strip()
def test_attribute_access_does_not_track_attr_name() -> None:
@@ -551,7 +552,7 @@ class MyClass:
qualified_functions = {"MyClass.get_x", "MyClass.__init__"}
result = remove_unused_definitions_by_function_names(code, qualified_functions)
# Module-level x should NOT be kept (self.x doesn't reference it)
assert 'x = "module_level_x"' not in result
assert 'x = "module_level_x"' not in result.code
# UNUSED_VAR should also be removed
assert "UNUSED_VAR" not in result
assert result.strip() == expected.strip()
assert "UNUSED_VAR" not in result.code
assert result.code.strip() == expected.strip()
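The assertions above switch from result.strip() to result.code.strip() because remove_unused_definitions_by_function_names now returns a cst.Module instead of a string. A minimal sketch of the str | cst.Module convention the refactor threads through these helpers, with normalize_module as a hypothetical helper rather than the project's code:

from __future__ import annotations

import libcst as cst


def normalize_module(code: str | cst.Module) -> cst.Module:
    # Hypothetical helper: accept raw source or an already-parsed Module so
    # callers that already hold a Module avoid a redundant re-parse.
    return code if isinstance(code, cst.Module) else cst.parse_module(code)


source = "def helper():\n    return 1\n"
module = normalize_module(source)   # parsed once
reused = normalize_module(module)   # passed through, no second parse
assert reused is module
assert reused.code == source        # downstream consumers read .code, as in the tests above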