mirror of
https://github.com/codeflash-ai/codeflash.git
synced 2026-05-04 18:25:17 +00:00
perf: eliminate redundant CST parsing in get_code_optimization_context
Parse each file once instead of up to 16 times by: - Making remove_unused_definitions_by_function_names accept/return cst.Module - Making parse_code_and_prune_cst and add_needed_imports_from_module accept cst.Module - Threading the parsed Module through process_file_context - Adding extract_all_contexts_from_files that processes all 4 context types (READ_WRITABLE, READ_ONLY, HASHING, TESTGEN) in a single per-file pass
This commit is contained in:
parent
17f4bbd6f9
commit
5671562da2
4 changed files with 294 additions and 119 deletions
|
|
@ -4,7 +4,7 @@ import ast
|
|||
import hashlib
|
||||
import os
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass, field
|
||||
from dataclasses import dataclass
|
||||
from itertools import chain
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
|
|
@ -46,22 +46,22 @@ if TYPE_CHECKING:
|
|||
from codeflash.languages.python.context.unused_definition_remover import UsageInfo
|
||||
|
||||
|
||||
@dataclass
|
||||
class AllContextResults:
|
||||
read_writable: CodeStringsMarkdown
|
||||
read_only: CodeStringsMarkdown
|
||||
hashing: CodeStringsMarkdown
|
||||
testgen: CodeStringsMarkdown
|
||||
|
||||
|
||||
def build_testgen_context(
|
||||
helpers_of_fto_dict: dict[Path, set[FunctionSource]],
|
||||
helpers_of_helpers_dict: dict[Path, set[FunctionSource]],
|
||||
testgen_base: CodeStringsMarkdown,
|
||||
project_root_path: Path,
|
||||
*,
|
||||
remove_docstrings: bool = False,
|
||||
include_enrichment: bool = True,
|
||||
function_to_optimize: FunctionToOptimize | None = None,
|
||||
) -> CodeStringsMarkdown:
|
||||
testgen_context = extract_code_markdown_context_from_files(
|
||||
helpers_of_fto_dict,
|
||||
helpers_of_helpers_dict,
|
||||
project_root_path,
|
||||
remove_docstrings=remove_docstrings,
|
||||
code_context_type=CodeContextType.TESTGEN,
|
||||
)
|
||||
testgen_context = testgen_base
|
||||
|
||||
if include_enrichment:
|
||||
enrichment = enrich_testgen_context(testgen_context, project_root_path)
|
||||
|
|
@ -114,14 +114,10 @@ def get_code_optimization_context(
|
|||
helpers_of_fto_qualified_names_dict, project_root_path
|
||||
)
|
||||
|
||||
# Extract code context for optimization
|
||||
final_read_writable_code = extract_code_markdown_context_from_files(
|
||||
helpers_of_fto_dict,
|
||||
{},
|
||||
project_root_path,
|
||||
remove_docstrings=False,
|
||||
code_context_type=CodeContextType.READ_WRITABLE,
|
||||
)
|
||||
# Extract all code contexts in a single pass (one CST parse per file)
|
||||
all_ctx = extract_all_contexts_from_files(helpers_of_fto_dict, helpers_of_helpers_dict, project_root_path)
|
||||
|
||||
final_read_writable_code = all_ctx.read_writable
|
||||
|
||||
# Ensure the target file is first in the code blocks so the LLM knows which file to optimize
|
||||
target_relative = function_to_optimize.file_path.resolve().relative_to(project_root_path.resolve())
|
||||
|
|
@ -130,20 +126,7 @@ def get_code_optimization_context(
|
|||
if target_blocks:
|
||||
final_read_writable_code.code_strings = target_blocks + other_blocks
|
||||
|
||||
read_only_code_markdown = extract_code_markdown_context_from_files(
|
||||
helpers_of_fto_dict,
|
||||
helpers_of_helpers_dict,
|
||||
project_root_path,
|
||||
remove_docstrings=False,
|
||||
code_context_type=CodeContextType.READ_ONLY,
|
||||
)
|
||||
hashing_code_context = extract_code_markdown_context_from_files(
|
||||
helpers_of_fto_dict,
|
||||
helpers_of_helpers_dict,
|
||||
project_root_path,
|
||||
remove_docstrings=True,
|
||||
code_context_type=CodeContextType.HASHING,
|
||||
)
|
||||
read_only_code_markdown = all_ctx.read_only
|
||||
|
||||
# Handle token limits
|
||||
final_read_writable_tokens = encoded_tokens_len(final_read_writable_code.markdown)
|
||||
|
|
@ -173,32 +156,29 @@ def get_code_optimization_context(
|
|||
|
||||
# Progressive fallback for testgen context token limits
|
||||
testgen_context = build_testgen_context(
|
||||
helpers_of_fto_dict, helpers_of_helpers_dict, project_root_path, function_to_optimize=function_to_optimize
|
||||
all_ctx.testgen, project_root_path, function_to_optimize=function_to_optimize
|
||||
)
|
||||
|
||||
if encoded_tokens_len(testgen_context.markdown) > testgen_token_limit:
|
||||
logger.debug("Testgen context exceeded token limit, removing docstrings")
|
||||
testgen_context = build_testgen_context(
|
||||
testgen_base_no_docs = extract_code_markdown_context_from_files(
|
||||
helpers_of_fto_dict,
|
||||
helpers_of_helpers_dict,
|
||||
project_root_path,
|
||||
remove_docstrings=True,
|
||||
function_to_optimize=function_to_optimize,
|
||||
code_context_type=CodeContextType.TESTGEN,
|
||||
)
|
||||
testgen_context = build_testgen_context(
|
||||
testgen_base_no_docs, project_root_path, function_to_optimize=function_to_optimize
|
||||
)
|
||||
|
||||
if encoded_tokens_len(testgen_context.markdown) > testgen_token_limit:
|
||||
logger.debug("Testgen context still exceeded token limit, removing enrichment")
|
||||
testgen_context = build_testgen_context(
|
||||
helpers_of_fto_dict,
|
||||
helpers_of_helpers_dict,
|
||||
project_root_path,
|
||||
remove_docstrings=True,
|
||||
include_enrichment=False,
|
||||
)
|
||||
testgen_context = build_testgen_context(testgen_base_no_docs, project_root_path, include_enrichment=False)
|
||||
|
||||
if encoded_tokens_len(testgen_context.markdown) > testgen_token_limit:
|
||||
raise ValueError(TESTGEN_LIMIT_ERROR)
|
||||
code_hash_context = hashing_code_context.markdown
|
||||
code_hash_context = all_ctx.hashing.markdown
|
||||
code_hash = hashlib.sha256(code_hash_context.encode("utf-8")).hexdigest()
|
||||
|
||||
all_helper_fqns = list({fs.fully_qualified_name for fs in helpers_of_fto_list + helpers_of_helpers_list})
|
||||
|
|
@ -230,15 +210,17 @@ def process_file_context(
|
|||
logger.exception(f"Error while parsing {file_path}: {e}")
|
||||
return None
|
||||
|
||||
try:
|
||||
original_module = cst.parse_module(original_code)
|
||||
except Exception as e:
|
||||
logger.debug(f"Failed to parse {file_path} with libcst: {type(e).__name__}: {e}")
|
||||
return None
|
||||
|
||||
try:
|
||||
all_names = primary_qualified_names | secondary_qualified_names
|
||||
code_without_unused_defs = remove_unused_definitions_by_function_names(original_code, all_names)
|
||||
cleaned_module = remove_unused_definitions_by_function_names(original_module, all_names)
|
||||
pruned_module = parse_code_and_prune_cst(
|
||||
code_without_unused_defs,
|
||||
code_context_type,
|
||||
primary_qualified_names,
|
||||
secondary_qualified_names,
|
||||
remove_docstrings,
|
||||
cleaned_module, code_context_type, primary_qualified_names, secondary_qualified_names, remove_docstrings
|
||||
)
|
||||
except ValueError as e:
|
||||
logger.debug(f"Error while getting read-only code: {e}")
|
||||
|
|
@ -249,7 +231,7 @@ def process_file_context(
|
|||
code_context = ast.unparse(ast.parse(pruned_module.code))
|
||||
else:
|
||||
code_context = add_needed_imports_from_module(
|
||||
src_module_code=original_code,
|
||||
src_module_code=original_module,
|
||||
dst_module_code=pruned_module,
|
||||
src_path=file_path,
|
||||
dst_path=file_path,
|
||||
|
|
@ -264,6 +246,197 @@ def process_file_context(
|
|||
return None
|
||||
|
||||
|
||||
def extract_all_contexts_from_files(
|
||||
helpers_of_fto: dict[Path, set[FunctionSource]],
|
||||
helpers_of_helpers: dict[Path, set[FunctionSource]],
|
||||
project_root_path: Path,
|
||||
) -> AllContextResults:
|
||||
"""Extract all 4 code context types from files in a single pass, parsing each file only once."""
|
||||
# Deduplicate: remove HoH entries that overlap with FTO
|
||||
helpers_of_helpers_no_overlap: dict[Path, set[FunctionSource]] = {}
|
||||
for file_path, function_sources in helpers_of_helpers.items():
|
||||
if file_path in helpers_of_fto:
|
||||
helpers_of_helpers[file_path] -= helpers_of_fto[file_path]
|
||||
else:
|
||||
helpers_of_helpers_no_overlap[file_path] = function_sources
|
||||
|
||||
rw = CodeStringsMarkdown()
|
||||
ro = CodeStringsMarkdown()
|
||||
hashing = CodeStringsMarkdown()
|
||||
testgen = CodeStringsMarkdown()
|
||||
|
||||
# Process files containing FTO helpers (all 4 context types)
|
||||
for file_path, function_sources in helpers_of_fto.items():
|
||||
fto_names = {func.qualified_name for func in function_sources}
|
||||
hoh_funcs = helpers_of_helpers.get(file_path, set())
|
||||
hoh_names = {func.qualified_name for func in hoh_funcs}
|
||||
rw_helper_functions = list(function_sources)
|
||||
all_helper_functions = list(function_sources | hoh_funcs)
|
||||
|
||||
try:
|
||||
original_code = file_path.read_text("utf8")
|
||||
except Exception as e:
|
||||
logger.exception(f"Error while parsing {file_path}: {e}")
|
||||
continue
|
||||
|
||||
try:
|
||||
original_module = cst.parse_module(original_code)
|
||||
except Exception as e:
|
||||
logger.debug(f"Failed to parse {file_path} with libcst: {type(e).__name__}: {e}")
|
||||
continue
|
||||
|
||||
try:
|
||||
relative_path = file_path.resolve().relative_to(project_root_path.resolve())
|
||||
except ValueError:
|
||||
relative_path = file_path
|
||||
|
||||
# Clean by fto_names only (for RW)
|
||||
rw_cleaned = remove_unused_definitions_by_function_names(original_module, fto_names)
|
||||
# Clean by all names (for RO/HASH/TESTGEN) — reuse rw_cleaned if no extra HoH names
|
||||
all_names = fto_names | hoh_names
|
||||
all_cleaned = (
|
||||
remove_unused_definitions_by_function_names(original_module, all_names) if hoh_names else rw_cleaned
|
||||
)
|
||||
|
||||
# READ_WRITABLE
|
||||
try:
|
||||
rw_pruned = parse_code_and_prune_cst(
|
||||
rw_cleaned, CodeContextType.READ_WRITABLE, fto_names, set(), remove_docstrings=False
|
||||
)
|
||||
if rw_pruned.code.strip():
|
||||
rw_code = add_needed_imports_from_module(
|
||||
src_module_code=original_module,
|
||||
dst_module_code=rw_pruned,
|
||||
src_path=file_path,
|
||||
dst_path=file_path,
|
||||
project_root=project_root_path,
|
||||
helper_functions=rw_helper_functions,
|
||||
)
|
||||
rw.code_strings.append(CodeString(code=rw_code, file_path=relative_path))
|
||||
except ValueError as e:
|
||||
logger.debug(f"Error while getting read-writable code: {e}")
|
||||
|
||||
# READ_ONLY
|
||||
try:
|
||||
ro_pruned = parse_code_and_prune_cst(
|
||||
all_cleaned, CodeContextType.READ_ONLY, fto_names, hoh_names, remove_docstrings=False
|
||||
)
|
||||
if ro_pruned.code.strip():
|
||||
ro_code = add_needed_imports_from_module(
|
||||
src_module_code=original_module,
|
||||
dst_module_code=ro_pruned,
|
||||
src_path=file_path,
|
||||
dst_path=file_path,
|
||||
project_root=project_root_path,
|
||||
helper_functions=all_helper_functions,
|
||||
)
|
||||
ro.code_strings.append(CodeString(code=ro_code, file_path=relative_path))
|
||||
except ValueError as e:
|
||||
logger.debug(f"Error while getting read-only code: {e}")
|
||||
|
||||
# HASHING
|
||||
try:
|
||||
hash_pruned = parse_code_and_prune_cst(
|
||||
all_cleaned, CodeContextType.HASHING, fto_names, hoh_names, remove_docstrings=True
|
||||
)
|
||||
if hash_pruned.code.strip():
|
||||
hash_code = ast.unparse(ast.parse(hash_pruned.code))
|
||||
hashing.code_strings.append(CodeString(code=hash_code, file_path=relative_path))
|
||||
except ValueError as e:
|
||||
logger.debug(f"Error while getting hashing code: {e}")
|
||||
|
||||
# TESTGEN
|
||||
try:
|
||||
testgen_pruned = parse_code_and_prune_cst(
|
||||
all_cleaned, CodeContextType.TESTGEN, fto_names, hoh_names, remove_docstrings=False
|
||||
)
|
||||
if testgen_pruned.code.strip():
|
||||
testgen_code = add_needed_imports_from_module(
|
||||
src_module_code=original_module,
|
||||
dst_module_code=testgen_pruned,
|
||||
src_path=file_path,
|
||||
dst_path=file_path,
|
||||
project_root=project_root_path,
|
||||
helper_functions=all_helper_functions,
|
||||
)
|
||||
testgen.code_strings.append(CodeString(code=testgen_code, file_path=relative_path))
|
||||
except ValueError as e:
|
||||
logger.debug(f"Error while getting testgen code: {e}")
|
||||
|
||||
# Process files containing only helpers of helpers (RO/HASH/TESTGEN only)
|
||||
for file_path, function_sources in helpers_of_helpers_no_overlap.items():
|
||||
hoh_names = {func.qualified_name for func in function_sources}
|
||||
helper_functions = list(function_sources)
|
||||
|
||||
try:
|
||||
original_code = file_path.read_text("utf8")
|
||||
except Exception as e:
|
||||
logger.exception(f"Error while parsing {file_path}: {e}")
|
||||
continue
|
||||
|
||||
try:
|
||||
original_module = cst.parse_module(original_code)
|
||||
except Exception as e:
|
||||
logger.debug(f"Failed to parse {file_path} with libcst: {type(e).__name__}: {e}")
|
||||
continue
|
||||
|
||||
try:
|
||||
relative_path = file_path.resolve().relative_to(project_root_path.resolve())
|
||||
except ValueError:
|
||||
relative_path = file_path
|
||||
|
||||
cleaned = remove_unused_definitions_by_function_names(original_module, hoh_names)
|
||||
|
||||
# READ_ONLY
|
||||
try:
|
||||
ro_pruned = parse_code_and_prune_cst(
|
||||
cleaned, CodeContextType.READ_ONLY, set(), hoh_names, remove_docstrings=False
|
||||
)
|
||||
if ro_pruned.code.strip():
|
||||
ro_code = add_needed_imports_from_module(
|
||||
src_module_code=original_module,
|
||||
dst_module_code=ro_pruned,
|
||||
src_path=file_path,
|
||||
dst_path=file_path,
|
||||
project_root=project_root_path,
|
||||
helper_functions=helper_functions,
|
||||
)
|
||||
ro.code_strings.append(CodeString(code=ro_code, file_path=relative_path))
|
||||
except ValueError as e:
|
||||
logger.debug(f"Error while getting read-only code: {e}")
|
||||
|
||||
# HASHING
|
||||
try:
|
||||
hash_pruned = parse_code_and_prune_cst(
|
||||
cleaned, CodeContextType.HASHING, set(), hoh_names, remove_docstrings=True
|
||||
)
|
||||
if hash_pruned.code.strip():
|
||||
hash_code = ast.unparse(ast.parse(hash_pruned.code))
|
||||
hashing.code_strings.append(CodeString(code=hash_code, file_path=relative_path))
|
||||
except ValueError as e:
|
||||
logger.debug(f"Error while getting hashing code: {e}")
|
||||
|
||||
# TESTGEN
|
||||
try:
|
||||
testgen_pruned = parse_code_and_prune_cst(
|
||||
cleaned, CodeContextType.TESTGEN, set(), hoh_names, remove_docstrings=False
|
||||
)
|
||||
if testgen_pruned.code.strip():
|
||||
testgen_code = add_needed_imports_from_module(
|
||||
src_module_code=original_module,
|
||||
dst_module_code=testgen_pruned,
|
||||
src_path=file_path,
|
||||
dst_path=file_path,
|
||||
project_root=project_root_path,
|
||||
helper_functions=helper_functions,
|
||||
)
|
||||
testgen.code_strings.append(CodeString(code=testgen_code, file_path=relative_path))
|
||||
except ValueError as e:
|
||||
logger.debug(f"Error while getting testgen code: {e}")
|
||||
|
||||
return AllContextResults(read_writable=rw, read_only=ro, hashing=hashing, testgen=testgen)
|
||||
|
||||
|
||||
def extract_code_markdown_context_from_files(
|
||||
helpers_of_fto: dict[Path, set[FunctionSource]],
|
||||
helpers_of_helpers: dict[Path, set[FunctionSource]],
|
||||
|
|
@ -641,12 +814,13 @@ def _get_class_start_line(class_node: ast.ClassDef) -> int:
|
|||
|
||||
|
||||
def _class_has_explicit_init(class_node: ast.ClassDef) -> bool:
|
||||
return any(isinstance(item, (ast.FunctionDef, ast.AsyncFunctionDef)) and item.name == "__init__" for item in class_node.body)
|
||||
return any(
|
||||
isinstance(item, (ast.FunctionDef, ast.AsyncFunctionDef)) and item.name == "__init__"
|
||||
for item in class_node.body
|
||||
)
|
||||
|
||||
|
||||
def _collect_synthetic_constructor_type_names(
|
||||
class_node: ast.ClassDef, import_aliases: dict[str, str]
|
||||
) -> set[str]:
|
||||
def _collect_synthetic_constructor_type_names(class_node: ast.ClassDef, import_aliases: dict[str, str]) -> set[str]:
|
||||
is_dataclass, dataclass_init_enabled, _ = _get_dataclass_config(class_node, import_aliases)
|
||||
if not _is_namedtuple_class(class_node, import_aliases) and not is_dataclass:
|
||||
return set()
|
||||
|
|
@ -712,7 +886,9 @@ def _extract_synthetic_init_parameters(
|
|||
if not include_in_init:
|
||||
continue
|
||||
|
||||
parameters.append((item.target.id, _get_node_source(item.annotation, module_source, "Any"), default_value, kw_only))
|
||||
parameters.append(
|
||||
(item.target.id, _get_node_source(item.annotation, module_source, "Any"), default_value, kw_only)
|
||||
)
|
||||
return parameters
|
||||
|
||||
|
||||
|
|
@ -727,10 +903,7 @@ def _build_synthetic_init_stub(
|
|||
return None
|
||||
|
||||
parameters = _extract_synthetic_init_parameters(
|
||||
class_node,
|
||||
module_source,
|
||||
import_aliases,
|
||||
kw_only_by_default=dataclass_kw_only,
|
||||
class_node, module_source, import_aliases, kw_only_by_default=dataclass_kw_only
|
||||
)
|
||||
if not parameters:
|
||||
return None
|
||||
|
|
@ -750,9 +923,7 @@ def _build_synthetic_init_stub(
|
|||
return f" def __init__({signature}):\n ..."
|
||||
|
||||
|
||||
def _extract_function_stub_snippet(
|
||||
fn_node: ast.FunctionDef | ast.AsyncFunctionDef, module_lines: list[str]
|
||||
) -> str:
|
||||
def _extract_function_stub_snippet(fn_node: ast.FunctionDef | ast.AsyncFunctionDef, module_lines: list[str]) -> str:
|
||||
start_line = min(d.lineno for d in fn_node.decorator_list) if fn_node.decorator_list else fn_node.lineno
|
||||
return "\n".join(module_lines[start_line - 1 : fn_node.end_lineno])
|
||||
|
||||
|
|
@ -779,13 +950,17 @@ def _has_non_property_method_decorator(
|
|||
|
||||
|
||||
def _has_descriptor_like_class_fields(class_node: ast.ClassDef) -> bool:
|
||||
return any(isinstance(item, (ast.Assign, ast.AnnAssign)) and isinstance(item.value, ast.Call) for item in class_node.body)
|
||||
return any(
|
||||
isinstance(item, (ast.Assign, ast.AnnAssign)) and isinstance(item.value, ast.Call) for item in class_node.body
|
||||
)
|
||||
|
||||
|
||||
def _should_use_raw_project_class_context(class_node: ast.ClassDef, import_aliases: dict[str, str]) -> bool:
|
||||
start_line = _get_class_start_line(class_node)
|
||||
class_line_count = class_node.end_lineno - start_line + 1
|
||||
is_small = class_line_count <= MAX_RAW_PROJECT_CLASS_LINES and len(class_node.body) <= MAX_RAW_PROJECT_CLASS_BODY_ITEMS
|
||||
is_small = (
|
||||
class_line_count <= MAX_RAW_PROJECT_CLASS_LINES and len(class_node.body) <= MAX_RAW_PROJECT_CLASS_BODY_ITEMS
|
||||
)
|
||||
|
||||
if is_small and _class_has_explicit_init(class_node):
|
||||
return True
|
||||
|
|
@ -933,11 +1108,7 @@ def _append_project_class_context(
|
|||
if base_expr_name is None:
|
||||
continue
|
||||
resolved = _resolve_imported_class_reference(
|
||||
base_expr_name,
|
||||
module_tree,
|
||||
module_path,
|
||||
project_root_path,
|
||||
module_cache,
|
||||
base_expr_name, module_tree, module_path, project_root_path, module_cache
|
||||
)
|
||||
if resolved is None:
|
||||
continue
|
||||
|
|
@ -955,16 +1126,16 @@ def _append_project_class_context(
|
|||
code_strings,
|
||||
)
|
||||
|
||||
code_strings.append(CodeString(code=_extract_raw_class_context(class_node, module_source, module_tree), file_path=module_path))
|
||||
code_strings.append(
|
||||
CodeString(code=_extract_raw_class_context(class_node, module_source, module_tree), file_path=module_path)
|
||||
)
|
||||
emitted_classes.add(class_key)
|
||||
emitted_class_names.add(class_name)
|
||||
return True
|
||||
|
||||
|
||||
def _collect_type_names_from_function(
|
||||
func_node: ast.FunctionDef | ast.AsyncFunctionDef,
|
||||
tree: ast.Module,
|
||||
class_name: str | None,
|
||||
func_node: ast.FunctionDef | ast.AsyncFunctionDef, tree: ast.Module, class_name: str | None
|
||||
) -> set[str]:
|
||||
type_names: set[str] = set()
|
||||
for arg in func_node.args.args + func_node.args.posonlyargs + func_node.args.kwonlyargs:
|
||||
|
|
@ -974,7 +1145,11 @@ def _collect_type_names_from_function(
|
|||
if func_node.args.kwarg:
|
||||
type_names |= collect_type_names_from_annotation(func_node.args.kwarg.annotation)
|
||||
for body_node in ast.walk(func_node):
|
||||
if isinstance(body_node, ast.Call) and isinstance(body_node.func, ast.Name) and body_node.func.id == "isinstance":
|
||||
if (
|
||||
isinstance(body_node, ast.Call)
|
||||
and isinstance(body_node.func, ast.Name)
|
||||
and body_node.func.id == "isinstance"
|
||||
):
|
||||
if len(body_node.args) >= 2:
|
||||
second_arg = body_node.args[1]
|
||||
if isinstance(second_arg, ast.Name):
|
||||
|
|
@ -1305,7 +1480,11 @@ def extract_imports_for_class(module_tree: ast.Module, class_node: ast.ClassDef,
|
|||
if not isinstance(node, (ast.Import, ast.ImportFrom)) or node.lineno in added_imports:
|
||||
continue
|
||||
for alias in node.names:
|
||||
name = alias.asname if alias.asname else (alias.name.split(".")[0] if isinstance(node, ast.Import) else alias.name)
|
||||
name = (
|
||||
alias.asname
|
||||
if alias.asname
|
||||
else (alias.name.split(".")[0] if isinstance(node, ast.Import) else alias.name)
|
||||
)
|
||||
if name in needed_names:
|
||||
import_lines.append(source_lines[node.lineno - 1])
|
||||
added_imports.add(node.lineno)
|
||||
|
|
@ -1340,9 +1519,7 @@ def remove_docstring_from_body(indented_block: cst.IndentedBlock) -> cst.CSTNode
|
|||
return indented_block
|
||||
|
||||
|
||||
def _maybe_strip_docstring(
|
||||
node: cst.FunctionDef | cst.ClassDef, cfg: PruneConfig
|
||||
) -> cst.FunctionDef | cst.ClassDef:
|
||||
def _maybe_strip_docstring(node: cst.FunctionDef | cst.ClassDef, cfg: PruneConfig) -> cst.FunctionDef | cst.ClassDef:
|
||||
if cfg.remove_docstrings and isinstance(node.body, cst.IndentedBlock):
|
||||
return node.with_changes(body=remove_docstring_from_body(node.body))
|
||||
return node
|
||||
|
|
@ -1361,17 +1538,17 @@ class PruneConfig:
|
|||
|
||||
|
||||
def parse_code_and_prune_cst(
|
||||
code: str,
|
||||
code: str | cst.Module,
|
||||
code_context_type: CodeContextType,
|
||||
target_functions: set[str],
|
||||
helpers_of_helper_functions: set[str] = set(), # noqa: B006
|
||||
remove_docstrings: bool = False,
|
||||
) -> cst.Module:
|
||||
"""Parse and filter the code CST, returning the pruned Module."""
|
||||
module = cst.parse_module(code)
|
||||
defs_with_usages = collect_top_level_defs_with_usages(module, target_functions | helpers_of_helper_functions)
|
||||
module = code if isinstance(code, cst.Module) else cst.parse_module(code)
|
||||
|
||||
if code_context_type == CodeContextType.READ_WRITABLE:
|
||||
defs_with_usages = collect_top_level_defs_with_usages(module, target_functions | helpers_of_helper_functions)
|
||||
cfg = PruneConfig(defs_with_usages=defs_with_usages, keep_class_init=True)
|
||||
elif code_context_type == CodeContextType.READ_ONLY:
|
||||
cfg = PruneConfig(
|
||||
|
|
@ -1402,10 +1579,7 @@ def parse_code_and_prune_cst(
|
|||
|
||||
|
||||
def prune_cst(
|
||||
node: cst.CSTNode,
|
||||
target_functions: set[str],
|
||||
cfg: PruneConfig,
|
||||
prefix: str = "",
|
||||
node: cst.CSTNode, target_functions: set[str], cfg: PruneConfig, prefix: str = ""
|
||||
) -> tuple[cst.CSTNode | None, bool]:
|
||||
if isinstance(node, (cst.Import, cst.ImportFrom)):
|
||||
return None, False
|
||||
|
|
|
|||
|
|
@ -383,7 +383,9 @@ def remove_unused_definitions_recursively(
|
|||
if isinstance(statement, cst.FunctionDef):
|
||||
new_statements.append(statement)
|
||||
elif isinstance(statement, (cst.Assign, cst.AnnAssign, cst.AugAssign)):
|
||||
if class_has_dependencies or is_assignment_used(statement, definitions, name_prefix=f"{class_name}."):
|
||||
if class_has_dependencies or is_assignment_used(
|
||||
statement, definitions, name_prefix=f"{class_name}."
|
||||
):
|
||||
new_statements.append(statement)
|
||||
else:
|
||||
new_statements.append(statement)
|
||||
|
|
@ -425,13 +427,15 @@ def collect_top_level_defs_with_usages(
|
|||
return definitions
|
||||
|
||||
|
||||
def remove_unused_definitions_by_function_names(code: str, qualified_function_names: set[str]) -> str:
|
||||
def remove_unused_definitions_by_function_names(
|
||||
code: Union[str, cst.Module], qualified_function_names: set[str]
|
||||
) -> cst.Module:
|
||||
"""Remove top-level definitions (classes, variables, functions) not used by the specified qualified function names."""
|
||||
try:
|
||||
module = cst.parse_module(code)
|
||||
module = code if isinstance(code, cst.Module) else cst.parse_module(code)
|
||||
except Exception as e:
|
||||
logger.debug(f"Failed to parse code with libcst: {type(e).__name__}: {e}")
|
||||
return code
|
||||
return code if isinstance(code, cst.Module) else cst.parse_module("")
|
||||
|
||||
try:
|
||||
defs_with_usages = collect_top_level_defs_with_usages(module, qualified_function_names)
|
||||
|
|
@ -439,11 +443,11 @@ def remove_unused_definitions_by_function_names(code: str, qualified_function_na
|
|||
# Apply the recursive removal transformation
|
||||
modified_module, _ = remove_unused_definitions_recursively(module, defs_with_usages)
|
||||
|
||||
return modified_module.code if modified_module else ""
|
||||
return modified_module if modified_module else cst.parse_module("")
|
||||
except Exception as e:
|
||||
# If any other error occurs during processing, return the original code
|
||||
logger.debug(f"Error processing code to remove unused definitions: {type(e).__name__}: {e}")
|
||||
return code
|
||||
return module
|
||||
|
||||
|
||||
def revert_unused_helper_functions(
|
||||
|
|
@ -569,11 +573,7 @@ def find_target_node(
|
|||
|
||||
|
||||
def _collect_attr_names(
|
||||
value_id: str,
|
||||
attr_name: str,
|
||||
class_name: str | None,
|
||||
names: set[str],
|
||||
imported_names_map: dict[str, set[str]],
|
||||
value_id: str, attr_name: str, class_name: str | None, names: set[str], imported_names_map: dict[str, set[str]]
|
||||
) -> None:
|
||||
if value_id == "self":
|
||||
names.add(attr_name)
|
||||
|
|
@ -605,9 +605,7 @@ def _collect_called_names(
|
|||
called.update(mapped_names)
|
||||
elif isinstance(node.func, ast.Attribute):
|
||||
if isinstance(node.func.value, ast.Name):
|
||||
_collect_attr_names(
|
||||
node.func.value.id, node.func.attr, class_name, called, imported_names_map
|
||||
)
|
||||
_collect_attr_names(node.func.value.id, node.func.attr, class_name, called, imported_names_map)
|
||||
else:
|
||||
called.add(node.func.attr)
|
||||
elif isinstance(node, ast.Attribute) and isinstance(node.value, ast.Name):
|
||||
|
|
|
|||
|
|
@ -540,7 +540,7 @@ def resolve_star_import(module_name: str, project_root: Path) -> set[str]:
|
|||
|
||||
|
||||
def add_needed_imports_from_module(
|
||||
src_module_code: str,
|
||||
src_module_code: str | cst.Module,
|
||||
dst_module_code: str | cst.Module,
|
||||
src_path: Path,
|
||||
dst_path: Path,
|
||||
|
|
@ -549,7 +549,6 @@ def add_needed_imports_from_module(
|
|||
helper_functions_fqn: set[str] | None = None,
|
||||
) -> str:
|
||||
"""Add all needed and used source module code imports to the destination module code, and return it."""
|
||||
src_module_code = delete___future___aliased_imports(src_module_code)
|
||||
if not helper_functions_fqn:
|
||||
helper_functions_fqn = {f.fully_qualified_name for f in (helper_functions or [])}
|
||||
|
||||
|
|
@ -571,7 +570,10 @@ def add_needed_imports_from_module(
|
|||
)
|
||||
)
|
||||
try:
|
||||
src_module = cst.parse_module(src_module_code)
|
||||
if isinstance(src_module_code, cst.Module):
|
||||
src_module = src_module_code.visit(FutureAliasedImportTransformer())
|
||||
else:
|
||||
src_module = cst.parse_module(src_module_code).visit(FutureAliasedImportTransformer())
|
||||
# Exclude function/class bodies so GatherImportsVisitor only sees module-level imports.
|
||||
# Nested imports (inside functions) are part of function logic and must not be
|
||||
# scheduled for add/remove — RemoveImportsVisitor would strip them as "unused".
|
||||
|
|
|
|||
|
|
@ -33,7 +33,7 @@ def another_function():
|
|||
qualified_functions = {"main_function"}
|
||||
result = remove_unused_definitions_by_function_names(code, qualified_functions)
|
||||
# Normalize whitespace for comparison
|
||||
assert result.strip() == expected.strip()
|
||||
assert result.code.strip() == expected.strip()
|
||||
|
||||
|
||||
def test_class_variable_removal() -> None:
|
||||
|
|
@ -84,7 +84,7 @@ def helper_function():
|
|||
qualified_functions = {"helper_function"}
|
||||
result = remove_unused_definitions_by_function_names(code, qualified_functions)
|
||||
# Normalize whitespace for comparison
|
||||
assert result.strip() == expected.strip()
|
||||
assert result.code.strip() == expected.strip()
|
||||
|
||||
|
||||
def test_complex_variable_dependencies() -> None:
|
||||
|
|
@ -122,7 +122,7 @@ def tuple_user():
|
|||
|
||||
qualified_functions = {"main_function"}
|
||||
result = remove_unused_definitions_by_function_names(code, qualified_functions)
|
||||
assert result.strip() == expected.strip()
|
||||
assert result.code.strip() == expected.strip()
|
||||
|
||||
|
||||
def test_type_annotation_usage() -> None:
|
||||
|
|
@ -156,7 +156,7 @@ def unused_function(param: UnusedType) -> UnusedType:
|
|||
qualified_functions = {"main_function"}
|
||||
result = remove_unused_definitions_by_function_names(code, qualified_functions)
|
||||
# Normalize whitespace for comparison
|
||||
assert result.strip() == expected.strip()
|
||||
assert result.code.strip() == expected.strip()
|
||||
|
||||
|
||||
def test_class_method_with_dunder_methods() -> None:
|
||||
|
|
@ -215,7 +215,7 @@ def helper_function():
|
|||
qualified_functions = {"MyClass.target_method"}
|
||||
result = remove_unused_definitions_by_function_names(code, qualified_functions)
|
||||
# Normalize whitespace for comparison
|
||||
assert result.strip() == expected.strip()
|
||||
assert result.code.strip() == expected.strip()
|
||||
|
||||
|
||||
def test_complex_type_annotations() -> None:
|
||||
|
|
@ -263,7 +263,7 @@ def unused_function(param: UnusedType) -> None:
|
|||
|
||||
qualified_functions = {"process_data"}
|
||||
result = remove_unused_definitions_by_function_names(code, qualified_functions)
|
||||
assert result.strip() == expected.strip()
|
||||
assert result.code.strip() == expected.strip()
|
||||
|
||||
|
||||
def test_try_except_finally_variables() -> None:
|
||||
|
|
@ -325,7 +325,7 @@ def unused_function():
|
|||
|
||||
qualified_functions = {"use_constants", "use_cleanup"}
|
||||
result = remove_unused_definitions_by_function_names(code, qualified_functions)
|
||||
assert result.strip() == expected.strip()
|
||||
assert result.code.strip() == expected.strip()
|
||||
|
||||
|
||||
def test_base_class_inheritance() -> None:
|
||||
|
|
@ -383,8 +383,9 @@ def test_function():
|
|||
qualified_functions = {"test_function"}
|
||||
result = remove_unused_definitions_by_function_names(code, qualified_functions)
|
||||
# LayoutDumper should be preserved because ObjectDetectionLayoutDumper inherits from it
|
||||
assert "class LayoutDumper" in result
|
||||
assert "class ObjectDetectionLayoutDumper" in result
|
||||
assert "class LayoutDumper" in result.code
|
||||
assert "class ObjectDetectionLayoutDumper" in result.code
|
||||
assert result.code.strip() == expected.strip()
|
||||
|
||||
|
||||
def test_conditional_and_loop_variables() -> None:
|
||||
|
|
@ -471,7 +472,7 @@ def unused_function():
|
|||
|
||||
qualified_functions = {"get_platform_info", "get_loop_result"}
|
||||
result = remove_unused_definitions_by_function_names(code, qualified_functions)
|
||||
assert result.strip() == expected.strip()
|
||||
assert result.code.strip() == expected.strip()
|
||||
|
||||
|
||||
def test_enum_attribute_access_dependency() -> None:
|
||||
|
|
@ -519,10 +520,10 @@ def process_message(kind):
|
|||
qualified_functions = {"process_message"}
|
||||
result = remove_unused_definitions_by_function_names(code, qualified_functions)
|
||||
# MessageKind should be preserved because process_message uses MessageKind.VALUE
|
||||
assert "class MessageKind" in result
|
||||
assert "class MessageKind" in result.code
|
||||
# UNUSED_VAR should be removed
|
||||
assert "UNUSED_VAR" not in result
|
||||
assert result.strip() == expected.strip()
|
||||
assert "UNUSED_VAR" not in result.code
|
||||
assert result.code.strip() == expected.strip()
|
||||
|
||||
|
||||
def test_attribute_access_does_not_track_attr_name() -> None:
|
||||
|
|
@ -551,7 +552,7 @@ class MyClass:
|
|||
qualified_functions = {"MyClass.get_x", "MyClass.__init__"}
|
||||
result = remove_unused_definitions_by_function_names(code, qualified_functions)
|
||||
# Module-level x should NOT be kept (self.x doesn't reference it)
|
||||
assert 'x = "module_level_x"' not in result
|
||||
assert 'x = "module_level_x"' not in result.code
|
||||
# UNUSED_VAR should also be removed
|
||||
assert "UNUSED_VAR" not in result
|
||||
assert result.strip() == expected.strip()
|
||||
assert "UNUSED_VAR" not in result.code
|
||||
assert result.code.strip() == expected.strip()
|
||||
|
|
|
|||
Loading…
Reference in a new issue