From 5671562da23edb5c229395164e354777332b6445 Mon Sep 17 00:00:00 2001 From: Kevin Turcios Date: Mon, 16 Mar 2026 10:11:58 -0600 Subject: [PATCH] perf: eliminate redundant CST parsing in get_code_optimization_context Parse each file once instead of up to 16 times by: - Making remove_unused_definitions_by_function_names accept/return cst.Module - Making parse_code_and_prune_cst and add_needed_imports_from_module accept cst.Module - Threading the parsed Module through process_file_context - Adding extract_all_contexts_from_files that processes all 4 context types (READ_WRITABLE, READ_ONLY, HASHING, TESTGEN) in a single per-file pass --- .../python/context/code_context_extractor.py | 346 +++++++++++++----- .../context/unused_definition_remover.py | 26 +- .../python/static_analysis/code_extractor.py | 8 +- tests/test_remove_unused_definitions.py | 33 +- 4 files changed, 294 insertions(+), 119 deletions(-) diff --git a/codeflash/languages/python/context/code_context_extractor.py b/codeflash/languages/python/context/code_context_extractor.py index 002c6e32f..61d806bfd 100644 --- a/codeflash/languages/python/context/code_context_extractor.py +++ b/codeflash/languages/python/context/code_context_extractor.py @@ -4,7 +4,7 @@ import ast import hashlib import os from collections import defaultdict -from dataclasses import dataclass, field +from dataclasses import dataclass from itertools import chain from pathlib import Path from typing import TYPE_CHECKING @@ -46,22 +46,22 @@ if TYPE_CHECKING: from codeflash.languages.python.context.unused_definition_remover import UsageInfo +@dataclass +class AllContextResults: + read_writable: CodeStringsMarkdown + read_only: CodeStringsMarkdown + hashing: CodeStringsMarkdown + testgen: CodeStringsMarkdown + + def build_testgen_context( - helpers_of_fto_dict: dict[Path, set[FunctionSource]], - helpers_of_helpers_dict: dict[Path, set[FunctionSource]], + testgen_base: CodeStringsMarkdown, project_root_path: Path, *, - remove_docstrings: bool = False, include_enrichment: bool = True, function_to_optimize: FunctionToOptimize | None = None, ) -> CodeStringsMarkdown: - testgen_context = extract_code_markdown_context_from_files( - helpers_of_fto_dict, - helpers_of_helpers_dict, - project_root_path, - remove_docstrings=remove_docstrings, - code_context_type=CodeContextType.TESTGEN, - ) + testgen_context = testgen_base if include_enrichment: enrichment = enrich_testgen_context(testgen_context, project_root_path) @@ -114,14 +114,10 @@ def get_code_optimization_context( helpers_of_fto_qualified_names_dict, project_root_path ) - # Extract code context for optimization - final_read_writable_code = extract_code_markdown_context_from_files( - helpers_of_fto_dict, - {}, - project_root_path, - remove_docstrings=False, - code_context_type=CodeContextType.READ_WRITABLE, - ) + # Extract all code contexts in a single pass (one CST parse per file) + all_ctx = extract_all_contexts_from_files(helpers_of_fto_dict, helpers_of_helpers_dict, project_root_path) + + final_read_writable_code = all_ctx.read_writable # Ensure the target file is first in the code blocks so the LLM knows which file to optimize target_relative = function_to_optimize.file_path.resolve().relative_to(project_root_path.resolve()) @@ -130,20 +126,7 @@ def get_code_optimization_context( if target_blocks: final_read_writable_code.code_strings = target_blocks + other_blocks - read_only_code_markdown = extract_code_markdown_context_from_files( - helpers_of_fto_dict, - helpers_of_helpers_dict, - project_root_path, - 
remove_docstrings=False, - code_context_type=CodeContextType.READ_ONLY, - ) - hashing_code_context = extract_code_markdown_context_from_files( - helpers_of_fto_dict, - helpers_of_helpers_dict, - project_root_path, - remove_docstrings=True, - code_context_type=CodeContextType.HASHING, - ) + read_only_code_markdown = all_ctx.read_only # Handle token limits final_read_writable_tokens = encoded_tokens_len(final_read_writable_code.markdown) @@ -173,32 +156,29 @@ def get_code_optimization_context( # Progressive fallback for testgen context token limits testgen_context = build_testgen_context( - helpers_of_fto_dict, helpers_of_helpers_dict, project_root_path, function_to_optimize=function_to_optimize + all_ctx.testgen, project_root_path, function_to_optimize=function_to_optimize ) if encoded_tokens_len(testgen_context.markdown) > testgen_token_limit: logger.debug("Testgen context exceeded token limit, removing docstrings") - testgen_context = build_testgen_context( + testgen_base_no_docs = extract_code_markdown_context_from_files( helpers_of_fto_dict, helpers_of_helpers_dict, project_root_path, remove_docstrings=True, - function_to_optimize=function_to_optimize, + code_context_type=CodeContextType.TESTGEN, + ) + testgen_context = build_testgen_context( + testgen_base_no_docs, project_root_path, function_to_optimize=function_to_optimize ) if encoded_tokens_len(testgen_context.markdown) > testgen_token_limit: logger.debug("Testgen context still exceeded token limit, removing enrichment") - testgen_context = build_testgen_context( - helpers_of_fto_dict, - helpers_of_helpers_dict, - project_root_path, - remove_docstrings=True, - include_enrichment=False, - ) + testgen_context = build_testgen_context(testgen_base_no_docs, project_root_path, include_enrichment=False) if encoded_tokens_len(testgen_context.markdown) > testgen_token_limit: raise ValueError(TESTGEN_LIMIT_ERROR) - code_hash_context = hashing_code_context.markdown + code_hash_context = all_ctx.hashing.markdown code_hash = hashlib.sha256(code_hash_context.encode("utf-8")).hexdigest() all_helper_fqns = list({fs.fully_qualified_name for fs in helpers_of_fto_list + helpers_of_helpers_list}) @@ -230,15 +210,17 @@ def process_file_context( logger.exception(f"Error while parsing {file_path}: {e}") return None + try: + original_module = cst.parse_module(original_code) + except Exception as e: + logger.debug(f"Failed to parse {file_path} with libcst: {type(e).__name__}: {e}") + return None + try: all_names = primary_qualified_names | secondary_qualified_names - code_without_unused_defs = remove_unused_definitions_by_function_names(original_code, all_names) + cleaned_module = remove_unused_definitions_by_function_names(original_module, all_names) pruned_module = parse_code_and_prune_cst( - code_without_unused_defs, - code_context_type, - primary_qualified_names, - secondary_qualified_names, - remove_docstrings, + cleaned_module, code_context_type, primary_qualified_names, secondary_qualified_names, remove_docstrings ) except ValueError as e: logger.debug(f"Error while getting read-only code: {e}") @@ -249,7 +231,7 @@ def process_file_context( code_context = ast.unparse(ast.parse(pruned_module.code)) else: code_context = add_needed_imports_from_module( - src_module_code=original_code, + src_module_code=original_module, dst_module_code=pruned_module, src_path=file_path, dst_path=file_path, @@ -264,6 +246,197 @@ def process_file_context( return None +def extract_all_contexts_from_files( + helpers_of_fto: dict[Path, set[FunctionSource]], + 
helpers_of_helpers: dict[Path, set[FunctionSource]], + project_root_path: Path, +) -> AllContextResults: + """Extract all 4 code context types from files in a single pass, parsing each file only once.""" + # Deduplicate: remove HoH entries that overlap with FTO + helpers_of_helpers_no_overlap: dict[Path, set[FunctionSource]] = {} + for file_path, function_sources in helpers_of_helpers.items(): + if file_path in helpers_of_fto: + helpers_of_helpers[file_path] -= helpers_of_fto[file_path] + else: + helpers_of_helpers_no_overlap[file_path] = function_sources + + rw = CodeStringsMarkdown() + ro = CodeStringsMarkdown() + hashing = CodeStringsMarkdown() + testgen = CodeStringsMarkdown() + + # Process files containing FTO helpers (all 4 context types) + for file_path, function_sources in helpers_of_fto.items(): + fto_names = {func.qualified_name for func in function_sources} + hoh_funcs = helpers_of_helpers.get(file_path, set()) + hoh_names = {func.qualified_name for func in hoh_funcs} + rw_helper_functions = list(function_sources) + all_helper_functions = list(function_sources | hoh_funcs) + + try: + original_code = file_path.read_text("utf8") + except Exception as e: + logger.exception(f"Error while parsing {file_path}: {e}") + continue + + try: + original_module = cst.parse_module(original_code) + except Exception as e: + logger.debug(f"Failed to parse {file_path} with libcst: {type(e).__name__}: {e}") + continue + + try: + relative_path = file_path.resolve().relative_to(project_root_path.resolve()) + except ValueError: + relative_path = file_path + + # Clean by fto_names only (for RW) + rw_cleaned = remove_unused_definitions_by_function_names(original_module, fto_names) + # Clean by all names (for RO/HASH/TESTGEN) — reuse rw_cleaned if no extra HoH names + all_names = fto_names | hoh_names + all_cleaned = ( + remove_unused_definitions_by_function_names(original_module, all_names) if hoh_names else rw_cleaned + ) + + # READ_WRITABLE + try: + rw_pruned = parse_code_and_prune_cst( + rw_cleaned, CodeContextType.READ_WRITABLE, fto_names, set(), remove_docstrings=False + ) + if rw_pruned.code.strip(): + rw_code = add_needed_imports_from_module( + src_module_code=original_module, + dst_module_code=rw_pruned, + src_path=file_path, + dst_path=file_path, + project_root=project_root_path, + helper_functions=rw_helper_functions, + ) + rw.code_strings.append(CodeString(code=rw_code, file_path=relative_path)) + except ValueError as e: + logger.debug(f"Error while getting read-writable code: {e}") + + # READ_ONLY + try: + ro_pruned = parse_code_and_prune_cst( + all_cleaned, CodeContextType.READ_ONLY, fto_names, hoh_names, remove_docstrings=False + ) + if ro_pruned.code.strip(): + ro_code = add_needed_imports_from_module( + src_module_code=original_module, + dst_module_code=ro_pruned, + src_path=file_path, + dst_path=file_path, + project_root=project_root_path, + helper_functions=all_helper_functions, + ) + ro.code_strings.append(CodeString(code=ro_code, file_path=relative_path)) + except ValueError as e: + logger.debug(f"Error while getting read-only code: {e}") + + # HASHING + try: + hash_pruned = parse_code_and_prune_cst( + all_cleaned, CodeContextType.HASHING, fto_names, hoh_names, remove_docstrings=True + ) + if hash_pruned.code.strip(): + hash_code = ast.unparse(ast.parse(hash_pruned.code)) + hashing.code_strings.append(CodeString(code=hash_code, file_path=relative_path)) + except ValueError as e: + logger.debug(f"Error while getting hashing code: {e}") + + # TESTGEN + try: + testgen_pruned = 
parse_code_and_prune_cst( + all_cleaned, CodeContextType.TESTGEN, fto_names, hoh_names, remove_docstrings=False + ) + if testgen_pruned.code.strip(): + testgen_code = add_needed_imports_from_module( + src_module_code=original_module, + dst_module_code=testgen_pruned, + src_path=file_path, + dst_path=file_path, + project_root=project_root_path, + helper_functions=all_helper_functions, + ) + testgen.code_strings.append(CodeString(code=testgen_code, file_path=relative_path)) + except ValueError as e: + logger.debug(f"Error while getting testgen code: {e}") + + # Process files containing only helpers of helpers (RO/HASH/TESTGEN only) + for file_path, function_sources in helpers_of_helpers_no_overlap.items(): + hoh_names = {func.qualified_name for func in function_sources} + helper_functions = list(function_sources) + + try: + original_code = file_path.read_text("utf8") + except Exception as e: + logger.exception(f"Error while parsing {file_path}: {e}") + continue + + try: + original_module = cst.parse_module(original_code) + except Exception as e: + logger.debug(f"Failed to parse {file_path} with libcst: {type(e).__name__}: {e}") + continue + + try: + relative_path = file_path.resolve().relative_to(project_root_path.resolve()) + except ValueError: + relative_path = file_path + + cleaned = remove_unused_definitions_by_function_names(original_module, hoh_names) + + # READ_ONLY + try: + ro_pruned = parse_code_and_prune_cst( + cleaned, CodeContextType.READ_ONLY, set(), hoh_names, remove_docstrings=False + ) + if ro_pruned.code.strip(): + ro_code = add_needed_imports_from_module( + src_module_code=original_module, + dst_module_code=ro_pruned, + src_path=file_path, + dst_path=file_path, + project_root=project_root_path, + helper_functions=helper_functions, + ) + ro.code_strings.append(CodeString(code=ro_code, file_path=relative_path)) + except ValueError as e: + logger.debug(f"Error while getting read-only code: {e}") + + # HASHING + try: + hash_pruned = parse_code_and_prune_cst( + cleaned, CodeContextType.HASHING, set(), hoh_names, remove_docstrings=True + ) + if hash_pruned.code.strip(): + hash_code = ast.unparse(ast.parse(hash_pruned.code)) + hashing.code_strings.append(CodeString(code=hash_code, file_path=relative_path)) + except ValueError as e: + logger.debug(f"Error while getting hashing code: {e}") + + # TESTGEN + try: + testgen_pruned = parse_code_and_prune_cst( + cleaned, CodeContextType.TESTGEN, set(), hoh_names, remove_docstrings=False + ) + if testgen_pruned.code.strip(): + testgen_code = add_needed_imports_from_module( + src_module_code=original_module, + dst_module_code=testgen_pruned, + src_path=file_path, + dst_path=file_path, + project_root=project_root_path, + helper_functions=helper_functions, + ) + testgen.code_strings.append(CodeString(code=testgen_code, file_path=relative_path)) + except ValueError as e: + logger.debug(f"Error while getting testgen code: {e}") + + return AllContextResults(read_writable=rw, read_only=ro, hashing=hashing, testgen=testgen) + + def extract_code_markdown_context_from_files( helpers_of_fto: dict[Path, set[FunctionSource]], helpers_of_helpers: dict[Path, set[FunctionSource]], @@ -641,12 +814,13 @@ def _get_class_start_line(class_node: ast.ClassDef) -> int: def _class_has_explicit_init(class_node: ast.ClassDef) -> bool: - return any(isinstance(item, (ast.FunctionDef, ast.AsyncFunctionDef)) and item.name == "__init__" for item in class_node.body) + return any( + isinstance(item, (ast.FunctionDef, ast.AsyncFunctionDef)) and item.name == "__init__" + for 
item in class_node.body + ) -def _collect_synthetic_constructor_type_names( - class_node: ast.ClassDef, import_aliases: dict[str, str] -) -> set[str]: +def _collect_synthetic_constructor_type_names(class_node: ast.ClassDef, import_aliases: dict[str, str]) -> set[str]: is_dataclass, dataclass_init_enabled, _ = _get_dataclass_config(class_node, import_aliases) if not _is_namedtuple_class(class_node, import_aliases) and not is_dataclass: return set() @@ -712,7 +886,9 @@ def _extract_synthetic_init_parameters( if not include_in_init: continue - parameters.append((item.target.id, _get_node_source(item.annotation, module_source, "Any"), default_value, kw_only)) + parameters.append( + (item.target.id, _get_node_source(item.annotation, module_source, "Any"), default_value, kw_only) + ) return parameters @@ -727,10 +903,7 @@ def _build_synthetic_init_stub( return None parameters = _extract_synthetic_init_parameters( - class_node, - module_source, - import_aliases, - kw_only_by_default=dataclass_kw_only, + class_node, module_source, import_aliases, kw_only_by_default=dataclass_kw_only ) if not parameters: return None @@ -750,9 +923,7 @@ def _build_synthetic_init_stub( return f" def __init__({signature}):\n ..." -def _extract_function_stub_snippet( - fn_node: ast.FunctionDef | ast.AsyncFunctionDef, module_lines: list[str] -) -> str: +def _extract_function_stub_snippet(fn_node: ast.FunctionDef | ast.AsyncFunctionDef, module_lines: list[str]) -> str: start_line = min(d.lineno for d in fn_node.decorator_list) if fn_node.decorator_list else fn_node.lineno return "\n".join(module_lines[start_line - 1 : fn_node.end_lineno]) @@ -779,13 +950,17 @@ def _has_non_property_method_decorator( def _has_descriptor_like_class_fields(class_node: ast.ClassDef) -> bool: - return any(isinstance(item, (ast.Assign, ast.AnnAssign)) and isinstance(item.value, ast.Call) for item in class_node.body) + return any( + isinstance(item, (ast.Assign, ast.AnnAssign)) and isinstance(item.value, ast.Call) for item in class_node.body + ) def _should_use_raw_project_class_context(class_node: ast.ClassDef, import_aliases: dict[str, str]) -> bool: start_line = _get_class_start_line(class_node) class_line_count = class_node.end_lineno - start_line + 1 - is_small = class_line_count <= MAX_RAW_PROJECT_CLASS_LINES and len(class_node.body) <= MAX_RAW_PROJECT_CLASS_BODY_ITEMS + is_small = ( + class_line_count <= MAX_RAW_PROJECT_CLASS_LINES and len(class_node.body) <= MAX_RAW_PROJECT_CLASS_BODY_ITEMS + ) if is_small and _class_has_explicit_init(class_node): return True @@ -933,11 +1108,7 @@ def _append_project_class_context( if base_expr_name is None: continue resolved = _resolve_imported_class_reference( - base_expr_name, - module_tree, - module_path, - project_root_path, - module_cache, + base_expr_name, module_tree, module_path, project_root_path, module_cache ) if resolved is None: continue @@ -955,16 +1126,16 @@ def _append_project_class_context( code_strings, ) - code_strings.append(CodeString(code=_extract_raw_class_context(class_node, module_source, module_tree), file_path=module_path)) + code_strings.append( + CodeString(code=_extract_raw_class_context(class_node, module_source, module_tree), file_path=module_path) + ) emitted_classes.add(class_key) emitted_class_names.add(class_name) return True def _collect_type_names_from_function( - func_node: ast.FunctionDef | ast.AsyncFunctionDef, - tree: ast.Module, - class_name: str | None, + func_node: ast.FunctionDef | ast.AsyncFunctionDef, tree: ast.Module, class_name: str | None ) -> 
set[str]: type_names: set[str] = set() for arg in func_node.args.args + func_node.args.posonlyargs + func_node.args.kwonlyargs: @@ -974,7 +1145,11 @@ def _collect_type_names_from_function( if func_node.args.kwarg: type_names |= collect_type_names_from_annotation(func_node.args.kwarg.annotation) for body_node in ast.walk(func_node): - if isinstance(body_node, ast.Call) and isinstance(body_node.func, ast.Name) and body_node.func.id == "isinstance": + if ( + isinstance(body_node, ast.Call) + and isinstance(body_node.func, ast.Name) + and body_node.func.id == "isinstance" + ): if len(body_node.args) >= 2: second_arg = body_node.args[1] if isinstance(second_arg, ast.Name): @@ -1305,7 +1480,11 @@ def extract_imports_for_class(module_tree: ast.Module, class_node: ast.ClassDef, if not isinstance(node, (ast.Import, ast.ImportFrom)) or node.lineno in added_imports: continue for alias in node.names: - name = alias.asname if alias.asname else (alias.name.split(".")[0] if isinstance(node, ast.Import) else alias.name) + name = ( + alias.asname + if alias.asname + else (alias.name.split(".")[0] if isinstance(node, ast.Import) else alias.name) + ) if name in needed_names: import_lines.append(source_lines[node.lineno - 1]) added_imports.add(node.lineno) @@ -1340,9 +1519,7 @@ def remove_docstring_from_body(indented_block: cst.IndentedBlock) -> cst.CSTNode return indented_block -def _maybe_strip_docstring( - node: cst.FunctionDef | cst.ClassDef, cfg: PruneConfig -) -> cst.FunctionDef | cst.ClassDef: +def _maybe_strip_docstring(node: cst.FunctionDef | cst.ClassDef, cfg: PruneConfig) -> cst.FunctionDef | cst.ClassDef: if cfg.remove_docstrings and isinstance(node.body, cst.IndentedBlock): return node.with_changes(body=remove_docstring_from_body(node.body)) return node @@ -1361,17 +1538,17 @@ class PruneConfig: def parse_code_and_prune_cst( - code: str, + code: str | cst.Module, code_context_type: CodeContextType, target_functions: set[str], helpers_of_helper_functions: set[str] = set(), # noqa: B006 remove_docstrings: bool = False, ) -> cst.Module: """Parse and filter the code CST, returning the pruned Module.""" - module = cst.parse_module(code) - defs_with_usages = collect_top_level_defs_with_usages(module, target_functions | helpers_of_helper_functions) + module = code if isinstance(code, cst.Module) else cst.parse_module(code) if code_context_type == CodeContextType.READ_WRITABLE: + defs_with_usages = collect_top_level_defs_with_usages(module, target_functions | helpers_of_helper_functions) cfg = PruneConfig(defs_with_usages=defs_with_usages, keep_class_init=True) elif code_context_type == CodeContextType.READ_ONLY: cfg = PruneConfig( @@ -1402,10 +1579,7 @@ def parse_code_and_prune_cst( def prune_cst( - node: cst.CSTNode, - target_functions: set[str], - cfg: PruneConfig, - prefix: str = "", + node: cst.CSTNode, target_functions: set[str], cfg: PruneConfig, prefix: str = "" ) -> tuple[cst.CSTNode | None, bool]: if isinstance(node, (cst.Import, cst.ImportFrom)): return None, False diff --git a/codeflash/languages/python/context/unused_definition_remover.py b/codeflash/languages/python/context/unused_definition_remover.py index ec1123573..1a6bcc8fc 100644 --- a/codeflash/languages/python/context/unused_definition_remover.py +++ b/codeflash/languages/python/context/unused_definition_remover.py @@ -383,7 +383,9 @@ def remove_unused_definitions_recursively( if isinstance(statement, cst.FunctionDef): new_statements.append(statement) elif isinstance(statement, (cst.Assign, cst.AnnAssign, cst.AugAssign)): - if 
class_has_dependencies or is_assignment_used(statement, definitions, name_prefix=f"{class_name}."): + if class_has_dependencies or is_assignment_used( + statement, definitions, name_prefix=f"{class_name}." + ): new_statements.append(statement) else: new_statements.append(statement) @@ -425,13 +427,15 @@ def collect_top_level_defs_with_usages( return definitions -def remove_unused_definitions_by_function_names(code: str, qualified_function_names: set[str]) -> str: +def remove_unused_definitions_by_function_names( + code: Union[str, cst.Module], qualified_function_names: set[str] +) -> cst.Module: """Remove top-level definitions (classes, variables, functions) not used by the specified qualified function names.""" try: - module = cst.parse_module(code) + module = code if isinstance(code, cst.Module) else cst.parse_module(code) except Exception as e: logger.debug(f"Failed to parse code with libcst: {type(e).__name__}: {e}") - return code + return code if isinstance(code, cst.Module) else cst.parse_module("") try: defs_with_usages = collect_top_level_defs_with_usages(module, qualified_function_names) @@ -439,11 +443,11 @@ def remove_unused_definitions_by_function_names(code: str, qualified_function_na # Apply the recursive removal transformation modified_module, _ = remove_unused_definitions_recursively(module, defs_with_usages) - return modified_module.code if modified_module else "" + return modified_module if modified_module else cst.parse_module("") except Exception as e: # If any other error occurs during processing, return the original code logger.debug(f"Error processing code to remove unused definitions: {type(e).__name__}: {e}") - return code + return module def revert_unused_helper_functions( @@ -569,11 +573,7 @@ def find_target_node( def _collect_attr_names( - value_id: str, - attr_name: str, - class_name: str | None, - names: set[str], - imported_names_map: dict[str, set[str]], + value_id: str, attr_name: str, class_name: str | None, names: set[str], imported_names_map: dict[str, set[str]] ) -> None: if value_id == "self": names.add(attr_name) @@ -605,9 +605,7 @@ def _collect_called_names( called.update(mapped_names) elif isinstance(node.func, ast.Attribute): if isinstance(node.func.value, ast.Name): - _collect_attr_names( - node.func.value.id, node.func.attr, class_name, called, imported_names_map - ) + _collect_attr_names(node.func.value.id, node.func.attr, class_name, called, imported_names_map) else: called.add(node.func.attr) elif isinstance(node, ast.Attribute) and isinstance(node.value, ast.Name): diff --git a/codeflash/languages/python/static_analysis/code_extractor.py b/codeflash/languages/python/static_analysis/code_extractor.py index 01117cd73..c469f6bf4 100644 --- a/codeflash/languages/python/static_analysis/code_extractor.py +++ b/codeflash/languages/python/static_analysis/code_extractor.py @@ -540,7 +540,7 @@ def resolve_star_import(module_name: str, project_root: Path) -> set[str]: def add_needed_imports_from_module( - src_module_code: str, + src_module_code: str | cst.Module, dst_module_code: str | cst.Module, src_path: Path, dst_path: Path, @@ -549,7 +549,6 @@ def add_needed_imports_from_module( helper_functions_fqn: set[str] | None = None, ) -> str: """Add all needed and used source module code imports to the destination module code, and return it.""" - src_module_code = delete___future___aliased_imports(src_module_code) if not helper_functions_fqn: helper_functions_fqn = {f.fully_qualified_name for f in (helper_functions or [])} @@ -571,7 +570,10 @@ def 
add_needed_imports_from_module( ) ) try: - src_module = cst.parse_module(src_module_code) + if isinstance(src_module_code, cst.Module): + src_module = src_module_code.visit(FutureAliasedImportTransformer()) + else: + src_module = cst.parse_module(src_module_code).visit(FutureAliasedImportTransformer()) # Exclude function/class bodies so GatherImportsVisitor only sees module-level imports. # Nested imports (inside functions) are part of function logic and must not be # scheduled for add/remove — RemoveImportsVisitor would strip them as "unused". diff --git a/tests/test_remove_unused_definitions.py b/tests/test_remove_unused_definitions.py index 032942f29..3bc237ba4 100644 --- a/tests/test_remove_unused_definitions.py +++ b/tests/test_remove_unused_definitions.py @@ -33,7 +33,7 @@ def another_function(): qualified_functions = {"main_function"} result = remove_unused_definitions_by_function_names(code, qualified_functions) # Normalize whitespace for comparison - assert result.strip() == expected.strip() + assert result.code.strip() == expected.strip() def test_class_variable_removal() -> None: @@ -84,7 +84,7 @@ def helper_function(): qualified_functions = {"helper_function"} result = remove_unused_definitions_by_function_names(code, qualified_functions) # Normalize whitespace for comparison - assert result.strip() == expected.strip() + assert result.code.strip() == expected.strip() def test_complex_variable_dependencies() -> None: @@ -122,7 +122,7 @@ def tuple_user(): qualified_functions = {"main_function"} result = remove_unused_definitions_by_function_names(code, qualified_functions) - assert result.strip() == expected.strip() + assert result.code.strip() == expected.strip() def test_type_annotation_usage() -> None: @@ -156,7 +156,7 @@ def unused_function(param: UnusedType) -> UnusedType: qualified_functions = {"main_function"} result = remove_unused_definitions_by_function_names(code, qualified_functions) # Normalize whitespace for comparison - assert result.strip() == expected.strip() + assert result.code.strip() == expected.strip() def test_class_method_with_dunder_methods() -> None: @@ -215,7 +215,7 @@ def helper_function(): qualified_functions = {"MyClass.target_method"} result = remove_unused_definitions_by_function_names(code, qualified_functions) # Normalize whitespace for comparison - assert result.strip() == expected.strip() + assert result.code.strip() == expected.strip() def test_complex_type_annotations() -> None: @@ -263,7 +263,7 @@ def unused_function(param: UnusedType) -> None: qualified_functions = {"process_data"} result = remove_unused_definitions_by_function_names(code, qualified_functions) - assert result.strip() == expected.strip() + assert result.code.strip() == expected.strip() def test_try_except_finally_variables() -> None: @@ -325,7 +325,7 @@ def unused_function(): qualified_functions = {"use_constants", "use_cleanup"} result = remove_unused_definitions_by_function_names(code, qualified_functions) - assert result.strip() == expected.strip() + assert result.code.strip() == expected.strip() def test_base_class_inheritance() -> None: @@ -383,8 +383,9 @@ def test_function(): qualified_functions = {"test_function"} result = remove_unused_definitions_by_function_names(code, qualified_functions) # LayoutDumper should be preserved because ObjectDetectionLayoutDumper inherits from it - assert "class LayoutDumper" in result - assert "class ObjectDetectionLayoutDumper" in result + assert "class LayoutDumper" in result.code + assert "class ObjectDetectionLayoutDumper" in 
result.code + assert result.code.strip() == expected.strip() def test_conditional_and_loop_variables() -> None: @@ -471,7 +472,7 @@ def unused_function(): qualified_functions = {"get_platform_info", "get_loop_result"} result = remove_unused_definitions_by_function_names(code, qualified_functions) - assert result.strip() == expected.strip() + assert result.code.strip() == expected.strip() def test_enum_attribute_access_dependency() -> None: @@ -519,10 +520,10 @@ def process_message(kind): qualified_functions = {"process_message"} result = remove_unused_definitions_by_function_names(code, qualified_functions) # MessageKind should be preserved because process_message uses MessageKind.VALUE - assert "class MessageKind" in result + assert "class MessageKind" in result.code # UNUSED_VAR should be removed - assert "UNUSED_VAR" not in result - assert result.strip() == expected.strip() + assert "UNUSED_VAR" not in result.code + assert result.code.strip() == expected.strip() def test_attribute_access_does_not_track_attr_name() -> None: @@ -551,7 +552,7 @@ class MyClass: qualified_functions = {"MyClass.get_x", "MyClass.__init__"} result = remove_unused_definitions_by_function_names(code, qualified_functions) # Module-level x should NOT be kept (self.x doesn't reference it) - assert 'x = "module_level_x"' not in result + assert 'x = "module_level_x"' not in result.code # UNUSED_VAR should also be removed - assert "UNUSED_VAR" not in result - assert result.strip() == expected.strip() + assert "UNUSED_VAR" not in result.code + assert result.code.strip() == expected.strip()
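
Illustrative sketch (not part of the patch; assumes only that libcst is installed, and the function names below are hypothetical): the pattern this change relies on is parsing each file into a cst.Module once and threading that Module through every pruning pass, with helpers accepting either source text or an already-parsed Module so no pass re-parses the string.

    from __future__ import annotations

    import libcst as cst


    def ensure_module(code: str | cst.Module) -> cst.Module:
        # Accept source text or an already-parsed Module, so callers that
        # already hold a tree never trigger a second parse.
        return code if isinstance(code, cst.Module) else cst.parse_module(code)


    def build_all_contexts(source: str) -> dict[str, str]:
        module = cst.parse_module(source)  # the single parse for this file
        contexts: dict[str, str] = {}
        for kind in ("READ_WRITABLE", "READ_ONLY", "HASHING", "TESTGEN"):
            # Each context type reuses the shared Module; a real pruning pass
            # would transform the tree, here it is just rendered back to source.
            contexts[kind] = ensure_module(module).code
        return contexts


    if __name__ == "__main__":
        print(build_all_contexts("def f():\n    return 1\n")["HASHING"])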