refactor: simplify code_context_extractor by extracting helper and removing dead code
- Extract build_testgen_context helper to reduce duplication in testgen token limit handling (~50 lines to ~20 lines) - Remove unused extract_code_string_context_from_files function (~100 lines) - Import get_section_names from unused_definition_remover instead of duplicating
This commit is contained in:
parent
47b5235978
commit
69740f0340
1 changed file with 48 additions and 161 deletions
|
|
@ -17,6 +17,7 @@ from codeflash.code_utils.config_consts import OPTIMIZATION_CONTEXT_TOKEN_LIMIT,
|
|||
from codeflash.context.unused_definition_remover import (
|
||||
collect_top_level_defs_with_usages,
|
||||
extract_names_from_targets,
|
||||
get_section_names,
|
||||
remove_unused_definitions_by_function_names,
|
||||
)
|
||||
from codeflash.discovery.functions_to_optimize import FunctionToOptimize # noqa: TC001
|
||||
|
|
@ -36,6 +37,38 @@ if TYPE_CHECKING:
|
|||
from codeflash.context.unused_definition_remover import UsageInfo
|
||||
|
||||
|
||||
def build_testgen_context(
    helpers_of_fto_dict: dict[Path, set[FunctionSource]],
    helpers_of_helpers_dict: dict[Path, set[FunctionSource]],
    project_root_path: Path,
    remove_docstrings: bool,  # noqa: FBT001
    include_imported_classes: bool,  # noqa: FBT001
) -> CodeStringsMarkdown:
    """Build testgen context with optional imported class definitions and external base inits.

    Extracts the TESTGEN code context for the function-to-optimize and its helpers, then
    optionally augments it with (a) class definitions for types imported from project
    modules and (b) ``__init__`` methods of external-library base classes. Both
    augmentations help the LLM understand constructors and how to mock inherited classes.

    Args:
        helpers_of_fto_dict: File path -> FunctionSources for the function to optimize
            and its first-degree helpers.
        helpers_of_helpers_dict: File path -> FunctionSources for helpers of helpers.
        project_root_path: Root path of the project, used to resolve imports.
        remove_docstrings: Strip docstrings from the extracted code (token-saving fallback).
        include_imported_classes: Include the imported-class / external-base-init
            augmentations; disabled as the last token-saving fallback.

    Returns:
        CodeStringsMarkdown with the assembled testgen context.
    """
    testgen_context = extract_code_markdown_context_from_files(
        helpers_of_fto_dict,
        helpers_of_helpers_dict,
        project_root_path,
        remove_docstrings=remove_docstrings,
        code_context_type=CodeContextType.TESTGEN,
    )

    # NOTE(review): both augmentations below are gated by include_imported_classes —
    # this matches the pre-refactor behavior where the final token-limit fallback
    # re-added neither of them. Confirm the external-base-inits gating is intended.
    if include_imported_classes:
        # Class definitions for types imported from project modules, so the LLM
        # can see constructors/structure when generating tests.
        imported_class_context = get_imported_class_definitions(testgen_context, project_root_path)
        if imported_class_context.code_strings:
            # CodeStringsMarkdown is treated as immutable: merge by building a new one.
            testgen_context = CodeStringsMarkdown(
                code_strings=testgen_context.code_strings + imported_class_context.code_strings
            )

        # __init__ methods of base classes from external libraries, so the LLM
        # knows how to instantiate/mock classes that inherit from them.
        external_base_inits = get_external_base_class_inits(testgen_context, project_root_path)
        if external_base_inits.code_strings:
            testgen_context = CodeStringsMarkdown(
                code_strings=testgen_context.code_strings + external_base_inits.code_strings
            )

    return testgen_context
|
||||
|
||||
|
||||
def get_code_optimization_context(
|
||||
function_to_optimize: FunctionToOptimize,
|
||||
project_root_path: Path,
|
||||
|
|
@ -119,69 +152,37 @@ def get_code_optimization_context(
|
|||
logger.debug("Code context has exceeded token limit, removing read-only code")
|
||||
read_only_context_code = ""
|
||||
|
||||
# Extract code context for testgen
|
||||
testgen_context = extract_code_markdown_context_from_files(
|
||||
# Extract code context for testgen with progressive fallback for token limits
|
||||
# Try in order: full context -> remove docstrings -> remove imported classes
|
||||
testgen_context = build_testgen_context(
|
||||
helpers_of_fto_dict,
|
||||
helpers_of_helpers_dict,
|
||||
project_root_path,
|
||||
remove_docstrings=False,
|
||||
code_context_type=CodeContextType.TESTGEN,
|
||||
include_imported_classes=True,
|
||||
)
|
||||
|
||||
# Extract class definitions for imported types from project modules
|
||||
# This helps the LLM understand class constructors and structure
|
||||
imported_class_context = get_imported_class_definitions(testgen_context, project_root_path)
|
||||
if imported_class_context.code_strings:
|
||||
# Merge imported class definitions into testgen context
|
||||
testgen_context = CodeStringsMarkdown(
|
||||
code_strings=testgen_context.code_strings + imported_class_context.code_strings
|
||||
)
|
||||
|
||||
# Extract __init__ methods from external library base classes
|
||||
# This helps the LLM understand how to mock/test classes that inherit from external libraries
|
||||
external_base_inits = get_external_base_class_inits(testgen_context, project_root_path)
|
||||
if external_base_inits.code_strings:
|
||||
testgen_context = CodeStringsMarkdown(
|
||||
code_strings=testgen_context.code_strings + external_base_inits.code_strings
|
||||
)
|
||||
|
||||
testgen_markdown_code = testgen_context.markdown
|
||||
testgen_code_token_length = encoded_tokens_len(testgen_markdown_code)
|
||||
if testgen_code_token_length > testgen_token_limit:
|
||||
# First try removing docstrings
|
||||
testgen_context = extract_code_markdown_context_from_files(
|
||||
if encoded_tokens_len(testgen_context.markdown) > testgen_token_limit:
|
||||
logger.debug("Testgen context exceeded token limit, removing docstrings")
|
||||
testgen_context = build_testgen_context(
|
||||
helpers_of_fto_dict,
|
||||
helpers_of_helpers_dict,
|
||||
project_root_path,
|
||||
remove_docstrings=True,
|
||||
code_context_type=CodeContextType.TESTGEN,
|
||||
include_imported_classes=True,
|
||||
)
|
||||
# Re-extract imported classes (they may still fit)
|
||||
imported_class_context = get_imported_class_definitions(testgen_context, project_root_path)
|
||||
if imported_class_context.code_strings:
|
||||
testgen_context = CodeStringsMarkdown(
|
||||
code_strings=testgen_context.code_strings + imported_class_context.code_strings
|
||||
)
|
||||
# Re-extract external base class inits
|
||||
external_base_inits = get_external_base_class_inits(testgen_context, project_root_path)
|
||||
if external_base_inits.code_strings:
|
||||
testgen_context = CodeStringsMarkdown(
|
||||
code_strings=testgen_context.code_strings + external_base_inits.code_strings
|
||||
)
|
||||
testgen_markdown_code = testgen_context.markdown
|
||||
testgen_code_token_length = encoded_tokens_len(testgen_markdown_code)
|
||||
if testgen_code_token_length > testgen_token_limit:
|
||||
# If still over limit, try without imported class definitions
|
||||
testgen_context = extract_code_markdown_context_from_files(
|
||||
|
||||
if encoded_tokens_len(testgen_context.markdown) > testgen_token_limit:
|
||||
logger.debug("Testgen context still exceeded token limit, removing imported class definitions")
|
||||
testgen_context = build_testgen_context(
|
||||
helpers_of_fto_dict,
|
||||
helpers_of_helpers_dict,
|
||||
project_root_path,
|
||||
remove_docstrings=True,
|
||||
code_context_type=CodeContextType.TESTGEN,
|
||||
include_imported_classes=False,
|
||||
)
|
||||
testgen_markdown_code = testgen_context.markdown
|
||||
testgen_code_token_length = encoded_tokens_len(testgen_markdown_code)
|
||||
if testgen_code_token_length > testgen_token_limit:
|
||||
|
||||
if encoded_tokens_len(testgen_context.markdown) > testgen_token_limit:
|
||||
raise ValueError("Testgen code context has exceeded token limit, cannot proceed")
|
||||
code_hash_context = hashing_code_context.markdown
|
||||
code_hash = hashlib.sha256(code_hash_context.encode("utf-8")).hexdigest()
|
||||
|
|
@ -197,114 +198,6 @@ def get_code_optimization_context(
|
|||
)
|
||||
|
||||
|
||||
def extract_code_string_context_from_files(
    helpers_of_fto: dict[Path, set[FunctionSource]],
    helpers_of_helpers: dict[Path, set[FunctionSource]],
    project_root_path: Path,
    remove_docstrings: bool = False,  # noqa: FBT001, FBT002
    code_context_type: CodeContextType = CodeContextType.READ_ONLY,
) -> CodeString:
    """Extract code context from files containing target functions and their helpers.

    This function processes two sets of files:
    1. Files containing the function to optimize (fto) and their first-degree helpers
    2. Files containing only helpers of helpers (with no overlap with the first set).

    For each file, it extracts relevant code based on the specified context type, adds necessary
    imports, and combines them.

    Args:
    ----
        helpers_of_fto: Dictionary mapping file paths to sets of Function Sources of function to optimize and its helpers
        helpers_of_helpers: Dictionary mapping file paths to sets of Function Sources of helpers of helper functions
        project_root_path: Root path of the project
        remove_docstrings: Whether to remove docstrings from the extracted code
        code_context_type: Type of code context to extract (READ_ONLY, READ_WRITABLE, or TESTGEN)

    Returns:
    -------
        CodeString containing the extracted code context with necessary imports

    """  # noqa: D205
    # Rearrange to remove overlaps, so we only access each file path once
    helpers_of_helpers_no_overlap = defaultdict(set)
    for file_path, function_sources in helpers_of_helpers.items():
        if file_path in helpers_of_fto:
            # Remove duplicates within the same file path, in case a helper of helper is also a helper of fto
            # NOTE(review): this mutates the caller's helpers_of_helpers dict in place — confirm intended.
            helpers_of_helpers[file_path] -= helpers_of_fto[file_path]
        else:
            helpers_of_helpers_no_overlap[file_path] = function_sources

    # Accumulates the combined code context across all processed files.
    final_code_string_context = ""

    # Extract code from file paths that contain fto and first degree helpers. helpers of helpers may also be included if they are in the same files
    for file_path, function_sources in helpers_of_fto.items():
        try:
            original_code = file_path.read_text("utf8")
        except Exception as e:
            # Unreadable file: log and move on rather than failing the whole extraction.
            logger.exception(f"Error while parsing {file_path}: {e}")
            continue
        try:
            qualified_function_names = {func.qualified_name for func in function_sources}
            helpers_of_helpers_qualified_names = {
                func.qualified_name for func in helpers_of_helpers.get(file_path, set())
            }
            # Drop top-level definitions not reachable from the target functions,
            # then prune the CST down to the requested context type.
            code_without_unused_defs = remove_unused_definitions_by_function_names(
                original_code, qualified_function_names | helpers_of_helpers_qualified_names
            )
            code_context = parse_code_and_prune_cst(
                code_without_unused_defs,
                code_context_type,
                qualified_function_names,
                helpers_of_helpers_qualified_names,
                remove_docstrings,
            )
        except ValueError as e:
            logger.debug(f"Error while getting read-only code: {e}")
            continue
        if code_context.strip():
            final_code_string_context += f"\n{code_context}"
            # Carry over the imports that the extracted snippet needs from its source module.
            final_code_string_context = add_needed_imports_from_module(
                src_module_code=original_code,
                dst_module_code=final_code_string_context,
                src_path=file_path,
                dst_path=file_path,
                project_root=project_root_path,
                helper_functions=list(helpers_of_fto.get(file_path, set()) | helpers_of_helpers.get(file_path, set())),
            )
    # READ_WRITABLE context only covers the fto and first-degree helpers, so stop here.
    if code_context_type == CodeContextType.READ_WRITABLE:
        return CodeString(code=final_code_string_context)
    # Extract code from file paths containing helpers of helpers
    for file_path, helper_function_sources in helpers_of_helpers_no_overlap.items():
        try:
            original_code = file_path.read_text("utf8")
        except Exception as e:
            logger.exception(f"Error while parsing {file_path}: {e}")
            continue
        try:
            qualified_helper_function_names = {func.qualified_name for func in helper_function_sources}
            code_without_unused_defs = remove_unused_definitions_by_function_names(
                original_code, qualified_helper_function_names
            )
            # No first-degree target functions in these files: pass an empty set.
            code_context = parse_code_and_prune_cst(
                code_without_unused_defs, code_context_type, set(), qualified_helper_function_names, remove_docstrings
            )
        except ValueError as e:
            logger.debug(f"Error while getting read-only code: {e}")
            continue

        if code_context.strip():
            final_code_string_context += f"\n{code_context}"
            final_code_string_context = add_needed_imports_from_module(
                src_module_code=original_code,
                dst_module_code=final_code_string_context,
                src_path=file_path,
                dst_path=file_path,
                project_root=project_root_path,
                helper_functions=list(helpers_of_helpers_no_overlap.get(file_path, set())),
            )
    return CodeString(code=final_code_string_context)
|
||||
|
||||
|
||||
def extract_code_markdown_context_from_files(
|
||||
helpers_of_fto: dict[Path, set[FunctionSource]],
|
||||
helpers_of_helpers: dict[Path, set[FunctionSource]],
|
||||
|
|
@ -939,12 +832,6 @@ def get_imported_names(import_node: cst.Import | cst.ImportFrom) -> set[str]:
|
|||
return names
|
||||
|
||||
|
||||
def get_section_names(node: cst.CSTNode) -> list[str]:
    """Return the body-section attribute names (e.g. ``body``, ``orelse``) that *node* defines."""
    candidate_sections = ("body", "orelse", "finalbody", "handlers")
    present = []
    for section in candidate_sections:
        if hasattr(node, section):
            present.append(section)
    return present
|
||||
|
||||
|
||||
def remove_docstring_from_body(indented_block: cst.IndentedBlock) -> cst.CSTNode:
|
||||
"""Removes the docstring from an indented block if it exists.""" # noqa: D401
|
||||
if not isinstance(indented_block.body[0], cst.SimpleStatementLine):
|
||||
|
|
|
|||
Loading…
Reference in a new issue