refactor: simplify code_context_extractor by extracting helper and removing dead code

- Extract build_testgen_context helper to reduce duplication in testgen
  token limit handling (~50 lines to ~20 lines)
- Remove unused extract_code_string_context_from_files function (~100 lines)
- Import get_section_names from unused_definition_remover instead of duplicating it
This commit is contained in:
Kevin Turcios 2026-01-24 09:25:41 -05:00
parent 47b5235978
commit 69740f0340

View file

@ -17,6 +17,7 @@ from codeflash.code_utils.config_consts import OPTIMIZATION_CONTEXT_TOKEN_LIMIT,
from codeflash.context.unused_definition_remover import (
collect_top_level_defs_with_usages,
extract_names_from_targets,
get_section_names,
remove_unused_definitions_by_function_names,
)
from codeflash.discovery.functions_to_optimize import FunctionToOptimize # noqa: TC001
@ -36,6 +37,38 @@ if TYPE_CHECKING:
from codeflash.context.unused_definition_remover import UsageInfo
def build_testgen_context(
    helpers_of_fto_dict: dict[Path, set[FunctionSource]],
    helpers_of_helpers_dict: dict[Path, set[FunctionSource]],
    project_root_path: Path,
    remove_docstrings: bool,  # noqa: FBT001
    include_imported_classes: bool,  # noqa: FBT001
) -> CodeStringsMarkdown:
    """Assemble the testgen code context for the target function and its helpers.

    Optionally augments the base context with class definitions for imported
    project types and, always, with __init__ methods of external-library base
    classes found in the (possibly augmented) context.
    """
    context = extract_code_markdown_context_from_files(
        helpers_of_fto_dict,
        helpers_of_helpers_dict,
        project_root_path,
        remove_docstrings=remove_docstrings,
        code_context_type=CodeContextType.TESTGEN,
    )
    if include_imported_classes:
        class_definitions = get_imported_class_definitions(context, project_root_path)
        if class_definitions.code_strings:
            # Merge imported class definitions into the testgen context.
            context = CodeStringsMarkdown(code_strings=context.code_strings + class_definitions.code_strings)
    # Base-class __init__ extraction runs on the augmented context so that
    # classes pulled in above are also inspected.
    base_class_inits = get_external_base_class_inits(context, project_root_path)
    if base_class_inits.code_strings:
        context = CodeStringsMarkdown(code_strings=context.code_strings + base_class_inits.code_strings)
    return context
def get_code_optimization_context(
function_to_optimize: FunctionToOptimize,
project_root_path: Path,
@ -119,69 +152,37 @@ def get_code_optimization_context(
logger.debug("Code context has exceeded token limit, removing read-only code")
read_only_context_code = ""
# Extract code context for testgen
testgen_context = extract_code_markdown_context_from_files(
# Extract code context for testgen with progressive fallback for token limits
# Try in order: full context -> remove docstrings -> remove imported classes
testgen_context = build_testgen_context(
helpers_of_fto_dict,
helpers_of_helpers_dict,
project_root_path,
remove_docstrings=False,
code_context_type=CodeContextType.TESTGEN,
include_imported_classes=True,
)
# Extract class definitions for imported types from project modules
# This helps the LLM understand class constructors and structure
imported_class_context = get_imported_class_definitions(testgen_context, project_root_path)
if imported_class_context.code_strings:
# Merge imported class definitions into testgen context
testgen_context = CodeStringsMarkdown(
code_strings=testgen_context.code_strings + imported_class_context.code_strings
)
# Extract __init__ methods from external library base classes
# This helps the LLM understand how to mock/test classes that inherit from external libraries
external_base_inits = get_external_base_class_inits(testgen_context, project_root_path)
if external_base_inits.code_strings:
testgen_context = CodeStringsMarkdown(
code_strings=testgen_context.code_strings + external_base_inits.code_strings
)
testgen_markdown_code = testgen_context.markdown
testgen_code_token_length = encoded_tokens_len(testgen_markdown_code)
if testgen_code_token_length > testgen_token_limit:
# First try removing docstrings
testgen_context = extract_code_markdown_context_from_files(
if encoded_tokens_len(testgen_context.markdown) > testgen_token_limit:
logger.debug("Testgen context exceeded token limit, removing docstrings")
testgen_context = build_testgen_context(
helpers_of_fto_dict,
helpers_of_helpers_dict,
project_root_path,
remove_docstrings=True,
code_context_type=CodeContextType.TESTGEN,
include_imported_classes=True,
)
# Re-extract imported classes (they may still fit)
imported_class_context = get_imported_class_definitions(testgen_context, project_root_path)
if imported_class_context.code_strings:
testgen_context = CodeStringsMarkdown(
code_strings=testgen_context.code_strings + imported_class_context.code_strings
)
# Re-extract external base class inits
external_base_inits = get_external_base_class_inits(testgen_context, project_root_path)
if external_base_inits.code_strings:
testgen_context = CodeStringsMarkdown(
code_strings=testgen_context.code_strings + external_base_inits.code_strings
)
testgen_markdown_code = testgen_context.markdown
testgen_code_token_length = encoded_tokens_len(testgen_markdown_code)
if testgen_code_token_length > testgen_token_limit:
# If still over limit, try without imported class definitions
testgen_context = extract_code_markdown_context_from_files(
if encoded_tokens_len(testgen_context.markdown) > testgen_token_limit:
logger.debug("Testgen context still exceeded token limit, removing imported class definitions")
testgen_context = build_testgen_context(
helpers_of_fto_dict,
helpers_of_helpers_dict,
project_root_path,
remove_docstrings=True,
code_context_type=CodeContextType.TESTGEN,
include_imported_classes=False,
)
testgen_markdown_code = testgen_context.markdown
testgen_code_token_length = encoded_tokens_len(testgen_markdown_code)
if testgen_code_token_length > testgen_token_limit:
if encoded_tokens_len(testgen_context.markdown) > testgen_token_limit:
raise ValueError("Testgen code context has exceeded token limit, cannot proceed")
code_hash_context = hashing_code_context.markdown
code_hash = hashlib.sha256(code_hash_context.encode("utf-8")).hexdigest()
@ -197,114 +198,6 @@ def get_code_optimization_context(
)
def extract_code_string_context_from_files(
    helpers_of_fto: dict[Path, set[FunctionSource]],
    helpers_of_helpers: dict[Path, set[FunctionSource]],
    project_root_path: Path,
    remove_docstrings: bool = False,  # noqa: FBT001, FBT002
    code_context_type: CodeContextType = CodeContextType.READ_ONLY,
) -> CodeString:
    """Extract code context from files containing target functions and their helpers.

    This function processes two sets of files:
    1. Files containing the function to optimize (fto) and their first-degree helpers
    2. Files containing only helpers of helpers (with no overlap with the first set).

    For each file, it extracts relevant code based on the specified context type, adds necessary
    imports, and combines them.

    Args:
    ----
        helpers_of_fto: Dictionary mapping file paths to sets of Function Sources of function to optimize and its helpers
        helpers_of_helpers: Dictionary mapping file paths to sets of Function Sources of helpers of helper functions
        project_root_path: Root path of the project
        remove_docstrings: Whether to remove docstrings from the extracted code
        code_context_type: Type of code context to extract (READ_ONLY, READ_WRITABLE, or TESTGEN)

    Returns:
    -------
        CodeString containing the extracted code context with necessary imports

    """  # noqa: D205
    # Rearrange to remove overlaps, so we only access each file path once.
    # NOTE: the de-duplicated helper sets are kept in a local mapping instead of
    # subtracting in place, so the caller's `helpers_of_helpers` dict is never
    # mutated as a side effect.
    helpers_of_helpers_dedup: dict[Path, set[FunctionSource]] = {}
    helpers_of_helpers_no_overlap: dict[Path, set[FunctionSource]] = defaultdict(set)
    for file_path, function_sources in helpers_of_helpers.items():
        if file_path in helpers_of_fto:
            # A helper of a helper may also be a helper of the fto in the same file; drop such duplicates.
            helpers_of_helpers_dedup[file_path] = function_sources - helpers_of_fto[file_path]
        else:
            helpers_of_helpers_no_overlap[file_path] = function_sources

    final_code_string_context = ""
    # Extract code from file paths that contain fto and first degree helpers.
    # Helpers of helpers may also be included if they live in the same files.
    for file_path, function_sources in helpers_of_fto.items():
        try:
            original_code = file_path.read_text("utf8")
        except Exception as e:
            logger.exception(f"Error while parsing {file_path}: {e}")
            continue
        try:
            qualified_function_names = {func.qualified_name for func in function_sources}
            helpers_of_helpers_qualified_names = {
                func.qualified_name for func in helpers_of_helpers_dedup.get(file_path, set())
            }
            code_without_unused_defs = remove_unused_definitions_by_function_names(
                original_code, qualified_function_names | helpers_of_helpers_qualified_names
            )
            code_context = parse_code_and_prune_cst(
                code_without_unused_defs,
                code_context_type,
                qualified_function_names,
                helpers_of_helpers_qualified_names,
                remove_docstrings,
            )
        except ValueError as e:
            logger.debug(f"Error while getting read-only code: {e}")
            continue
        if code_context.strip():
            final_code_string_context += f"\n{code_context}"
            final_code_string_context = add_needed_imports_from_module(
                src_module_code=original_code,
                dst_module_code=final_code_string_context,
                src_path=file_path,
                dst_path=file_path,
                project_root=project_root_path,
                helper_functions=list(
                    helpers_of_fto.get(file_path, set()) | helpers_of_helpers_dedup.get(file_path, set())
                ),
            )

    # Read-writable context never includes files that only contain helpers of helpers.
    if code_context_type == CodeContextType.READ_WRITABLE:
        return CodeString(code=final_code_string_context)

    # Extract code from file paths containing only helpers of helpers.
    for file_path, helper_function_sources in helpers_of_helpers_no_overlap.items():
        try:
            original_code = file_path.read_text("utf8")
        except Exception as e:
            logger.exception(f"Error while parsing {file_path}: {e}")
            continue
        try:
            qualified_helper_function_names = {func.qualified_name for func in helper_function_sources}
            code_without_unused_defs = remove_unused_definitions_by_function_names(
                original_code, qualified_helper_function_names
            )
            code_context = parse_code_and_prune_cst(
                code_without_unused_defs, code_context_type, set(), qualified_helper_function_names, remove_docstrings
            )
        except ValueError as e:
            logger.debug(f"Error while getting read-only code: {e}")
            continue
        if code_context.strip():
            final_code_string_context += f"\n{code_context}"
            final_code_string_context = add_needed_imports_from_module(
                src_module_code=original_code,
                dst_module_code=final_code_string_context,
                src_path=file_path,
                dst_path=file_path,
                project_root=project_root_path,
                helper_functions=list(helpers_of_helpers_no_overlap.get(file_path, set())),
            )
    return CodeString(code=final_code_string_context)
def extract_code_markdown_context_from_files(
helpers_of_fto: dict[Path, set[FunctionSource]],
helpers_of_helpers: dict[Path, set[FunctionSource]],
@ -939,12 +832,6 @@ def get_imported_names(import_node: cst.Import | cst.ImportFrom) -> set[str]:
return names
def get_section_names(node: cst.CSTNode) -> list[str]:
    """Return the body-section attribute names (e.g. body, orelse) that exist on *node*."""
    found: list[str] = []
    for section in ("body", "orelse", "finalbody", "handlers"):
        if hasattr(node, section):
            found.append(section)
    return found
def remove_docstring_from_body(indented_block: cst.IndentedBlock) -> cst.CSTNode:
"""Removes the docstring from an indented block if it exists.""" # noqa: D401
if not isinstance(indented_block.body[0], cst.SimpleStatementLine):