diff --git a/.github/workflows/mypy.yml b/.github/workflows/mypy.yml index bccaa46cb..34fc2fe9e 100644 --- a/.github/workflows/mypy.yml +++ b/.github/workflows/mypy.yml @@ -22,11 +22,10 @@ jobs: - name: Install uv uses: astral-sh/setup-uv@v6 - with: - version: "0.5.30" - name: sync uv run: | + uv venv --seed uv sync diff --git a/.gitignore b/.gitignore index 899bbf88f..fe9d37dc9 100644 --- a/.gitignore +++ b/.gitignore @@ -262,3 +262,6 @@ tessl.json **/node_modules/** /dist-nuitka/main.dist/* packages/codeflash/.npmrc + +# Tessl auto-generates AGENTS.md on install; ignore to avoid cluttering git status +AGENTS.md diff --git a/CLAUDE.md b/CLAUDE.md index 7499c45ff..c77ece2b0 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -88,6 +88,10 @@ else: - Commit message body should be concise (1-2 sentences max) - PR titles should also use conventional format + + # Agent Rules @.tessl/RULES.md follow the [instructions](.tessl/RULES.md) + +@AGENTS.md diff --git a/codeflash/LICENSE b/codeflash/LICENSE index 285f15162..19fb335ef 100644 --- a/codeflash/LICENSE +++ b/codeflash/LICENSE @@ -3,7 +3,7 @@ Business Source License 1.1 Parameters Licensor: CodeFlash Inc. -Licensed Work: Codeflash Client version 0.19.x +Licensed Work: Codeflash Client version 0.20.x The Licensed Work is (c) 2024 CodeFlash Inc. Additional Use Grant: None. Production use of the Licensed Work is only permitted @@ -13,7 +13,7 @@ Additional Use Grant: None. Production use of the Licensed Work is only permitte Platform. Please visit codeflash.ai for further information. 
-Change Date: 2029-12-21 +Change Date: 2030-01-26 Change License: MIT diff --git a/codeflash/api/aiservice.py b/codeflash/api/aiservice.py index 1d7df25a9..1071dbc51 100644 --- a/codeflash/api/aiservice.py +++ b/codeflash/api/aiservice.py @@ -268,7 +268,7 @@ class AiServiceClient: logger.info("!lsp|Rewriting as a JIT function…") console.rule() try: - response = self.make_ai_service_request("/rewrite_jit", payload=payload, timeout=60) + response = self.make_ai_service_request("/rewrite_jit", payload=payload, timeout=self.timeout) except requests.exceptions.RequestException as e: logger.exception(f"Error generating jit rewritten candidate: {e}") ph("cli-jit-rewrite-error-caught", {"error": str(e)}) diff --git a/codeflash/cli_cmds/cli.py b/codeflash/cli_cmds/cli.py index d5eab953b..9dca009fd 100644 --- a/codeflash/cli_cmds/cli.py +++ b/codeflash/cli_cmds/cli.py @@ -84,6 +84,9 @@ def parse_args() -> Namespace: parser.add_argument( "--no-gen-tests", action="store_true", help="Do not generate tests, use only existing tests for optimization." ) + parser.add_argument( + "--no-jit-opts", action="store_true", help="Do not generate JIT-compiled optimizations for numerical code." 
+ ) parser.add_argument("--staging-review", action="store_true", help="Upload optimizations to staging for review") parser.add_argument( "--verify-setup", diff --git a/codeflash/code_utils/code_extractor.py b/codeflash/code_utils/code_extractor.py index 66281b0b5..4bb42a96b 100644 --- a/codeflash/code_utils/code_extractor.py +++ b/codeflash/code_utils/code_extractor.py @@ -26,12 +26,117 @@ if TYPE_CHECKING: from codeflash.models.models import FunctionSource +class GlobalFunctionCollector(cst.CSTVisitor): + """Collects all module-level function definitions (not inside classes or other functions).""" + + def __init__(self) -> None: + super().__init__() + self.functions: dict[str, cst.FunctionDef] = {} + self.function_order: list[str] = [] + self.scope_depth = 0 + + def visit_FunctionDef(self, node: cst.FunctionDef) -> Optional[bool]: + if self.scope_depth == 0: + # Module-level function + name = node.name.value + self.functions[name] = node + if name not in self.function_order: + self.function_order.append(name) + self.scope_depth += 1 + return True + + def leave_FunctionDef(self, original_node: cst.FunctionDef) -> None: # noqa: ARG002 + self.scope_depth -= 1 + + def visit_ClassDef(self, node: cst.ClassDef) -> Optional[bool]: # noqa: ARG002 + self.scope_depth += 1 + return True + + def leave_ClassDef(self, original_node: cst.ClassDef) -> None: # noqa: ARG002 + self.scope_depth -= 1 + + +class GlobalFunctionTransformer(cst.CSTTransformer): + """Transforms/adds module-level functions from the new file to the original file.""" + + def __init__(self, new_functions: dict[str, cst.FunctionDef], new_function_order: list[str]) -> None: + super().__init__() + self.new_functions = new_functions + self.new_function_order = new_function_order + self.processed_functions: set[str] = set() + self.scope_depth = 0 + + def visit_FunctionDef(self, node: cst.FunctionDef) -> None: # noqa: ARG002 + self.scope_depth += 1 + + def leave_FunctionDef(self, original_node: cst.FunctionDef, 
updated_node: cst.FunctionDef) -> cst.FunctionDef: + self.scope_depth -= 1 + if self.scope_depth > 0: + return updated_node + + # Check if this is a module-level function we need to replace + name = original_node.name.value + if name in self.new_functions: + self.processed_functions.add(name) + return self.new_functions[name] + return updated_node + + def visit_ClassDef(self, node: cst.ClassDef) -> None: # noqa: ARG002 + self.scope_depth += 1 + + def leave_ClassDef(self, original_node: cst.ClassDef, updated_node: cst.ClassDef) -> cst.ClassDef: # noqa: ARG002 + self.scope_depth -= 1 + return updated_node + + def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module: # noqa: ARG002 + # Add any new functions that weren't in the original file + new_statements = list(updated_node.body) + + functions_to_append = [ + self.new_functions[name] + for name in self.new_function_order + if name not in self.processed_functions and name in self.new_functions + ] + + if functions_to_append: + # Find the position of the last function or class definition + insert_index = find_insertion_index_after_imports(updated_node) + for i, stmt in enumerate(new_statements): + if isinstance(stmt, (cst.FunctionDef, cst.ClassDef)): + insert_index = i + 1 + + # Add empty line before each new function + function_nodes = [] + for func in functions_to_append: + func_with_empty_line = func.with_changes(leading_lines=[cst.EmptyLine(), *func.leading_lines]) + function_nodes.append(func_with_empty_line) + + new_statements = list(chain(new_statements[:insert_index], function_nodes, new_statements[insert_index:])) + + return updated_node.with_changes(body=new_statements) + + +def collect_referenced_names(node: cst.CSTNode) -> set[str]: + """Collect all names referenced in a CST node using recursive traversal.""" + names: set[str] = set() + + def _collect(n: cst.CSTNode) -> None: + if isinstance(n, cst.Name): + names.add(n.value) + # Recursively process all children + for 
child in n.children: + _collect(child) + + _collect(node) + return names + + class GlobalAssignmentCollector(cst.CSTVisitor): """Collects all global assignment statements.""" def __init__(self) -> None: super().__init__() - self.assignments: dict[str, cst.Assign] = {} + self.assignments: dict[str, cst.Assign | cst.AnnAssign] = {} self.assignment_order: list[str] = [] # Track scope depth to identify global assignments self.scope_depth = 0 @@ -73,6 +178,21 @@ class GlobalAssignmentCollector(cst.CSTVisitor): self.assignment_order.append(name) return True + def visit_AnnAssign(self, node: cst.AnnAssign) -> Optional[bool]: + # Handle annotated assignments like: _CACHE: Dict[str, int] = {} + # Only process module-level annotated assignments with a value + if ( + self.scope_depth == 0 + and self.if_else_depth == 0 + and isinstance(node.target, cst.Name) + and node.value is not None + ): + name = node.target.value + self.assignments[name] = node + if name not in self.assignment_order: + self.assignment_order.append(name) + return True + def find_insertion_index_after_imports(node: cst.Module) -> int: """Find the position of the last import statement in the top-level of the module.""" @@ -104,7 +224,7 @@ def find_insertion_index_after_imports(node: cst.Module) -> int: class GlobalAssignmentTransformer(cst.CSTTransformer): """Transforms global assignments in the original file with those from the new file.""" - def __init__(self, new_assignments: dict[str, cst.Assign], new_assignment_order: list[str]) -> None: + def __init__(self, new_assignments: dict[str, cst.Assign | cst.AnnAssign], new_assignment_order: list[str]) -> None: super().__init__() self.new_assignments = new_assignments self.new_assignment_order = new_assignment_order @@ -151,38 +271,120 @@ class GlobalAssignmentTransformer(cst.CSTTransformer): return updated_node + def leave_AnnAssign(self, original_node: cst.AnnAssign, updated_node: cst.AnnAssign) -> cst.CSTNode: + if self.scope_depth > 0 or self.if_else_depth 
> 0: + return updated_node + + # Check if this is a global annotated assignment we need to replace + if isinstance(original_node.target, cst.Name): + name = original_node.target.value + if name in self.new_assignments: + self.processed_assignments.add(name) + return self.new_assignments[name] + + return updated_node + def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module: # noqa: ARG002 # Add any new assignments that weren't in the original file new_statements = list(updated_node.body) # Find assignments to append assignments_to_append = [ - self.new_assignments[name] + (name, self.new_assignments[name]) for name in self.new_assignment_order if name not in self.processed_assignments and name in self.new_assignments ] - if assignments_to_append: - # after last top-level imports + if not assignments_to_append: + return updated_node.with_changes(body=new_statements) + + # Collect all class and function names defined in the module + # These are the names that assignments might reference + module_defined_names: set[str] = set() + for stmt in new_statements: + if isinstance(stmt, (cst.ClassDef, cst.FunctionDef)): + module_defined_names.add(stmt.name.value) + + # Partition assignments: those that reference module definitions go at the end, + # those that don't can go right after imports + assignments_after_imports: list[tuple[str, cst.Assign | cst.AnnAssign]] = [] + assignments_after_definitions: list[tuple[str, cst.Assign | cst.AnnAssign]] = [] + + for name, assignment in assignments_to_append: + # Get the value being assigned + if isinstance(assignment, (cst.Assign, cst.AnnAssign)) and assignment.value is not None: + value_node = assignment.value + else: + # No value to analyze, safe to place after imports + assignments_after_imports.append((name, assignment)) + continue + + # Collect names referenced in the assignment value + referenced_names = collect_referenced_names(value_node) + + # Check if any referenced names are module-level 
definitions + if referenced_names & module_defined_names: + # This assignment references a class/function, place it after definitions + assignments_after_definitions.append((name, assignment)) + else: + # Safe to place right after imports + assignments_after_imports.append((name, assignment)) + + # Insert assignments that don't depend on module definitions right after imports + if assignments_after_imports: insert_index = find_insertion_index_after_imports(updated_node) + assignment_lines = [ + cst.SimpleStatementLine([assignment], leading_lines=[cst.EmptyLine()]) + for _, assignment in assignments_after_imports + ] + new_statements = list(chain(new_statements[:insert_index], assignment_lines, new_statements[insert_index:])) + + # Insert assignments that depend on module definitions after all class/function definitions + if assignments_after_definitions: + # Find the position after the last function or class definition + insert_index = find_insertion_index_after_imports(cst.Module(body=new_statements)) + for i, stmt in enumerate(new_statements): + if isinstance(stmt, (cst.FunctionDef, cst.ClassDef)): + insert_index = i + 1 assignment_lines = [ cst.SimpleStatementLine([assignment], leading_lines=[cst.EmptyLine()]) - for assignment in assignments_to_append + for _, assignment in assignments_after_definitions ] - new_statements = list(chain(new_statements[:insert_index], assignment_lines, new_statements[insert_index:])) - # Add a blank line after the last assignment if needed - after_index = insert_index + len(assignment_lines) - if after_index < len(new_statements): - next_stmt = new_statements[after_index] - # If there's no empty line, add one - has_empty = any(isinstance(line, cst.EmptyLine) for line in next_stmt.leading_lines) - if not has_empty: - new_statements[after_index] = next_stmt.with_changes( - leading_lines=[cst.EmptyLine(), *next_stmt.leading_lines] - ) + return updated_node.with_changes(body=new_statements) + + +class 
GlobalStatementTransformer(cst.CSTTransformer): + """Transformer that appends global statements at the end of the module. + + This ensures that global statements (like function calls at module level) are placed + after all functions, classes, and assignments they might reference, preventing NameError + at module load time. + + This transformer should be run LAST after GlobalFunctionTransformer and + GlobalAssignmentTransformer have already added their content. + """ + + def __init__(self, global_statements: list[cst.SimpleStatementLine]) -> None: + super().__init__() + self.global_statements = global_statements + + def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module: # noqa: ARG002 + if not self.global_statements: + return updated_node + + new_statements = list(updated_node.body) + + # Add empty line before each statement for readability + statement_lines = [ + stmt.with_changes(leading_lines=[cst.EmptyLine(), *stmt.leading_lines]) for stmt in self.global_statements + ] + + # Append statements at the end of the module + # This ensures they come after all functions, classes, and assignments + new_statements.extend(statement_lines) return updated_node.with_changes(body=new_statements) @@ -214,8 +416,8 @@ class GlobalStatementCollector(cst.CSTVisitor): def visit_SimpleStatementLine(self, node: cst.SimpleStatementLine) -> None: if not self.in_function_or_class: for statement in node.body: - # Skip imports - if not isinstance(statement, (cst.Import, cst.ImportFrom, cst.Assign)): + # Skip imports and assignments (both regular and annotated) + if not isinstance(statement, (cst.Import, cst.ImportFrom, cst.Assign, cst.AnnAssign)): self.global_statements.append(node) break @@ -310,40 +512,6 @@ class DottedImportCollector(cst.CSTVisitor): self._collect_imports_from_block(node.body) -class ImportInserter(cst.CSTTransformer): - """Transformer that inserts global statements after the last import.""" - - def __init__(self, global_statements: 
list[cst.SimpleStatementLine], last_import_line: int) -> None: - super().__init__() - self.global_statements = global_statements - self.last_import_line = last_import_line - self.current_line = 0 - self.inserted = False - - def leave_SimpleStatementLine( - self, - original_node: cst.SimpleStatementLine, # noqa: ARG002 - updated_node: cst.SimpleStatementLine, - ) -> cst.Module: - self.current_line += 1 - - # If we're right after the last import and haven't inserted yet - if self.current_line == self.last_import_line and not self.inserted: - self.inserted = True - return cst.Module(body=[updated_node, *self.global_statements]) - - return cst.Module(body=[updated_node]) - - def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module: # noqa: ARG002 - # If there were no imports, add at the beginning of the module - if self.last_import_line == 0 and not self.inserted: - updated_body = list(updated_node.body) - for stmt in reversed(self.global_statements): - updated_body.insert(0, stmt) - return updated_node.with_changes(body=updated_body) - return updated_node - - def extract_global_statements(source_code: str) -> tuple[cst.Module, list[cst.SimpleStatementLine]]: """Extract global statements from source code.""" module = cst.parse_module(source_code) @@ -395,34 +563,58 @@ def add_global_assignments(src_module_code: str, dst_module_code: str) -> str: continue unique_global_statements.append(stmt) - mod_dst_code = dst_module_code - # Insert unique global statements if any - if unique_global_statements: - last_import_line = find_last_import_line(dst_module_code) - # Reuse already-parsed dst_module - transformer = ImportInserter(unique_global_statements, last_import_line) - # Use visit inplace, don't parse again - modified_module = dst_module.visit(transformer) - mod_dst_code = modified_module.code - # Parse the code after insertion - original_module = cst.parse_module(mod_dst_code) - else: - # No new statements to insert, reuse already-parsed 
dst_module - original_module = dst_module + # Reuse already-parsed dst_module + original_module = dst_module # Parse the src_module_code once only (already done above: src_module) # Collect assignments from the new file - new_collector = GlobalAssignmentCollector() - src_module.visit(new_collector) - # Only create transformer if there are assignments to insert/transform - if not new_collector.assignments: # nothing to transform - return mod_dst_code + new_assignment_collector = GlobalAssignmentCollector() + src_module.visit(new_assignment_collector) - # Transform the original destination module - transformer = GlobalAssignmentTransformer(new_collector.assignments, new_collector.assignment_order) - transformed_module = original_module.visit(transformer) + # Collect module-level functions from both source and destination + src_function_collector = GlobalFunctionCollector() + src_module.visit(src_function_collector) - return transformed_module.code + dst_function_collector = GlobalFunctionCollector() + original_module.visit(dst_function_collector) + + # Filter out functions that already exist in the destination (only add truly new functions) + new_functions = { + name: func + for name, func in src_function_collector.functions.items() + if name not in dst_function_collector.functions + } + new_function_order = [name for name in src_function_collector.function_order if name in new_functions] + + # If there are no assignments, no new functions, and no global statements, return unchanged + if not new_assignment_collector.assignments and not new_functions and not unique_global_statements: + return dst_module_code + + # The order of transformations matters: + # 1. Functions first - so assignments and statements can reference them + # 2. Assignments second - so they come after functions but before statements + # 3. 
Global statements last - so they can reference both functions and assignments + + # Transform functions if any + if new_functions: + function_transformer = GlobalFunctionTransformer(new_functions, new_function_order) + original_module = original_module.visit(function_transformer) + + # Transform assignments if any + if new_assignment_collector.assignments: + transformer = GlobalAssignmentTransformer( + new_assignment_collector.assignments, new_assignment_collector.assignment_order + ) + original_module = original_module.visit(transformer) + + # Insert global statements (like function calls at module level) LAST, + # after all functions and assignments are added, to ensure they can reference any + # functions or variables defined in the module + if unique_global_statements: + statement_transformer = GlobalStatementTransformer(unique_global_statements) + original_module = original_module.visit(statement_transformer) + + return original_module.code def resolve_star_import(module_name: str, project_root: Path) -> set[str]: diff --git a/codeflash/context/code_context_extractor.py b/codeflash/context/code_context_extractor.py index 4590c1106..4b7b49015 100644 --- a/codeflash/context/code_context_extractor.py +++ b/codeflash/context/code_context_extractor.py @@ -5,6 +5,7 @@ import hashlib import os from collections import defaultdict from itertools import chain +from pathlib import Path from typing import TYPE_CHECKING, cast import libcst as cst @@ -16,6 +17,7 @@ from codeflash.code_utils.config_consts import OPTIMIZATION_CONTEXT_TOKEN_LIMIT, from codeflash.context.unused_definition_remover import ( collect_top_level_defs_with_usages, extract_names_from_targets, + get_section_names, remove_unused_definitions_by_function_names, ) from codeflash.discovery.functions_to_optimize import FunctionToOptimize # noqa: TC001 @@ -32,14 +34,44 @@ from codeflash.models.models import ( from codeflash.optimization.function_context import belongs_to_function_qualified if TYPE_CHECKING: - 
from pathlib import Path - from jedi.api.classes import Name from libcst import CSTNode from codeflash.context.unused_definition_remover import UsageInfo +def build_testgen_context( + helpers_of_fto_dict: dict[Path, set[FunctionSource]], + helpers_of_helpers_dict: dict[Path, set[FunctionSource]], + project_root_path: Path, + remove_docstrings: bool, # noqa: FBT001 + include_imported_classes: bool, # noqa: FBT001 +) -> CodeStringsMarkdown: + """Build testgen context with optional imported class definitions and external base inits.""" + testgen_context = extract_code_markdown_context_from_files( + helpers_of_fto_dict, + helpers_of_helpers_dict, + project_root_path, + remove_docstrings=remove_docstrings, + code_context_type=CodeContextType.TESTGEN, + ) + + if include_imported_classes: + imported_class_context = get_imported_class_definitions(testgen_context, project_root_path) + if imported_class_context.code_strings: + testgen_context = CodeStringsMarkdown( + code_strings=testgen_context.code_strings + imported_class_context.code_strings + ) + + external_base_inits = get_external_base_class_inits(testgen_context, project_root_path) + if external_base_inits.code_strings: + testgen_context = CodeStringsMarkdown( + code_strings=testgen_context.code_strings + external_base_inits.code_strings + ) + + return testgen_context + + def get_code_optimization_context( function_to_optimize: FunctionToOptimize, project_root_path: Path, @@ -129,55 +161,37 @@ def get_code_optimization_context( logger.debug("Code context has exceeded token limit, removing read-only code") read_only_context_code = "" - # Extract code context for testgen - testgen_context = extract_code_markdown_context_from_files( + # Extract code context for testgen with progressive fallback for token limits + # Try in order: full context -> remove docstrings -> remove imported classes + testgen_context = build_testgen_context( helpers_of_fto_dict, helpers_of_helpers_dict, project_root_path, remove_docstrings=False, 
- code_context_type=CodeContextType.TESTGEN, + include_imported_classes=True, ) - # Extract class definitions for imported types from project modules - # This helps the LLM understand class constructors and structure - imported_class_context = get_imported_class_definitions(testgen_context, project_root_path) - if imported_class_context.code_strings: - # Merge imported class definitions into testgen context - testgen_context = CodeStringsMarkdown( - code_strings=testgen_context.code_strings + imported_class_context.code_strings - ) - - testgen_markdown_code = testgen_context.markdown - testgen_code_token_length = encoded_tokens_len(testgen_markdown_code) - if testgen_code_token_length > testgen_token_limit: - # First try removing docstrings - testgen_context = extract_code_markdown_context_from_files( + if encoded_tokens_len(testgen_context.markdown) > testgen_token_limit: + logger.debug("Testgen context exceeded token limit, removing docstrings") + testgen_context = build_testgen_context( helpers_of_fto_dict, helpers_of_helpers_dict, project_root_path, remove_docstrings=True, - code_context_type=CodeContextType.TESTGEN, + include_imported_classes=True, ) - # Re-extract imported classes (they may still fit) - imported_class_context = get_imported_class_definitions(testgen_context, project_root_path) - if imported_class_context.code_strings: - testgen_context = CodeStringsMarkdown( - code_strings=testgen_context.code_strings + imported_class_context.code_strings - ) - testgen_markdown_code = testgen_context.markdown - testgen_code_token_length = encoded_tokens_len(testgen_markdown_code) - if testgen_code_token_length > testgen_token_limit: - # If still over limit, try without imported class definitions - testgen_context = extract_code_markdown_context_from_files( + + if encoded_tokens_len(testgen_context.markdown) > testgen_token_limit: + logger.debug("Testgen context still exceeded token limit, removing imported class definitions") + testgen_context = 
build_testgen_context( helpers_of_fto_dict, helpers_of_helpers_dict, project_root_path, remove_docstrings=True, - code_context_type=CodeContextType.TESTGEN, + include_imported_classes=False, ) - testgen_markdown_code = testgen_context.markdown - testgen_code_token_length = encoded_tokens_len(testgen_markdown_code) - if testgen_code_token_length > testgen_token_limit: + + if encoded_tokens_len(testgen_context.markdown) > testgen_token_limit: raise ValueError("Testgen code context has exceeded token limit, cannot proceed") code_hash_context = hashing_code_context.markdown code_hash = hashlib.sha256(code_hash_context.encode("utf-8")).hexdigest() @@ -193,264 +207,6 @@ def get_code_optimization_context( ) -def get_code_optimization_context_for_language( - function_to_optimize: FunctionToOptimize, - project_root_path: Path, - optim_token_limit: int = OPTIMIZATION_CONTEXT_TOKEN_LIMIT, - testgen_token_limit: int = TESTGEN_CONTEXT_TOKEN_LIMIT, -) -> CodeOptimizationContext: - """Extract code optimization context for non-Python languages. - - Uses the language support abstraction to extract code context and converts - it to the CodeOptimizationContext format expected by the pipeline. - - This function supports multi-file context extraction, grouping helpers by file - and creating proper CodeStringsMarkdown with file paths for multi-file replacement. - - Args: - function_to_optimize: The function to extract context for. - project_root_path: Root of the project. - optim_token_limit: Token limit for optimization context. - testgen_token_limit: Token limit for testgen context. - - Returns: - CodeOptimizationContext with target code and dependencies. 
- - """ - from codeflash.languages import get_language_support - from codeflash.languages.base import FunctionInfo, ParentInfo - - # Get language support for this function - language = Language(function_to_optimize.language) - lang_support = get_language_support(language) - - # Convert FunctionToOptimize to FunctionInfo for language support - parents = tuple(ParentInfo(name=p.name, type=p.type) for p in function_to_optimize.parents) - func_info = FunctionInfo( - name=function_to_optimize.function_name, - file_path=function_to_optimize.file_path, - start_line=function_to_optimize.starting_line or 1, - end_line=function_to_optimize.ending_line or 1, - parents=parents, - is_async=function_to_optimize.is_async, - is_method=len(function_to_optimize.parents) > 0, - language=language, - ) - - # Extract code context using language support - code_context = lang_support.extract_code_context(func_info, project_root_path, project_root_path) - - # Build imports string if available - imports_code = "\n".join(code_context.imports) if code_context.imports else "" - - # Get relative path for target file - try: - target_relative_path = function_to_optimize.file_path.resolve().relative_to(project_root_path.resolve()) - except ValueError: - target_relative_path = function_to_optimize.file_path - - # Group helpers by file path - helpers_by_file: dict[Path, list] = defaultdict(list) - helper_function_sources = [] - - for helper in code_context.helper_functions: - helpers_by_file[helper.file_path].append(helper) - - # Convert to FunctionSource for pipeline compatibility - helper_function_sources.append( - FunctionSource( - file_path=helper.file_path, - qualified_name=helper.qualified_name, - fully_qualified_name=helper.qualified_name, - only_function_name=helper.name, - source_code=helper.source_code, - jedi_definition=None, - ) - ) - - # Build read-writable code (target file + same-file helpers + global variables) - read_writable_code_strings = [] - - # Combine target code with 
same-file helpers - target_file_code = code_context.target_code - same_file_helpers = helpers_by_file.get(function_to_optimize.file_path, []) - if same_file_helpers: - helper_code = "\n\n".join(h.source_code for h in same_file_helpers) - target_file_code = target_file_code + "\n\n" + helper_code - - # Add global variables (module-level declarations) referenced by the function and helpers - # These should be included in read-writable context so AI can modify them if needed - if code_context.read_only_context: - target_file_code = code_context.read_only_context + "\n\n" + target_file_code - - # Add imports to target file code - if imports_code: - target_file_code = imports_code + "\n\n" + target_file_code - - read_writable_code_strings.append( - CodeString(code=target_file_code, file_path=target_relative_path, language=function_to_optimize.language) - ) - - # Add helper files (cross-file helpers) - for file_path, file_helpers in helpers_by_file.items(): - if file_path == function_to_optimize.file_path: - continue # Already included in target file - - try: - helper_relative_path = file_path.resolve().relative_to(project_root_path.resolve()) - except ValueError: - helper_relative_path = file_path - - # Combine all helpers from this file - combined_helper_code = "\n\n".join(h.source_code for h in file_helpers) - - read_writable_code_strings.append( - CodeString( - code=combined_helper_code, file_path=helper_relative_path, language=function_to_optimize.language - ) - ) - - read_writable_code = CodeStringsMarkdown( - code_strings=read_writable_code_strings, language=function_to_optimize.language - ) - - # Build testgen context (same as read_writable for non-Python) - testgen_context = CodeStringsMarkdown( - code_strings=read_writable_code_strings.copy(), language=function_to_optimize.language - ) - - # Check token limits - read_writable_tokens = encoded_tokens_len(read_writable_code.markdown) - if read_writable_tokens > optim_token_limit: - raise ValueError("Read-writable 
code has exceeded token limit, cannot proceed") - - testgen_tokens = encoded_tokens_len(testgen_context.markdown) - if testgen_tokens > testgen_token_limit: - raise ValueError("Testgen code context has exceeded token limit, cannot proceed") - - # Generate code hash from all read-writable code - code_hash = hashlib.sha256(read_writable_code.flat.encode("utf-8")).hexdigest() - - return CodeOptimizationContext( - testgen_context=testgen_context, - read_writable_code=read_writable_code, - # Global variables are now included in read-writable code, so don't duplicate in read-only - read_only_context_code="", - hashing_code_context=read_writable_code.flat, - hashing_code_context_hash=code_hash, - helper_functions=helper_function_sources, - preexisting_objects=set(), # Not implemented for non-Python yet - ) - - -def extract_code_string_context_from_files( - helpers_of_fto: dict[Path, set[FunctionSource]], - helpers_of_helpers: dict[Path, set[FunctionSource]], - project_root_path: Path, - remove_docstrings: bool = False, # noqa: FBT001, FBT002 - code_context_type: CodeContextType = CodeContextType.READ_ONLY, -) -> CodeString: - """Extract code context from files containing target functions and their helpers. - This function processes two sets of files: - 1. Files containing the function to optimize (fto) and their first-degree helpers - 2. Files containing only helpers of helpers (with no overlap with the first set). - - For each file, it extracts relevant code based on the specified context type, adds necessary - imports, and combines them. 
- - Args: - ---- - helpers_of_fto: Dictionary mapping file paths to sets of Function Sources of function to optimize and its helpers - helpers_of_helpers: Dictionary mapping file paths to sets of Function Sources of helpers of helper functions - project_root_path: Root path of the project - remove_docstrings: Whether to remove docstrings from the extracted code - code_context_type: Type of code context to extract (READ_ONLY, READ_WRITABLE, or TESTGEN) - - Returns: - ------- - CodeString containing the extracted code context with necessary imports - - """ # noqa: D205 - # Rearrange to remove overlaps, so we only access each file path once - helpers_of_helpers_no_overlap = defaultdict(set) - for file_path, function_sources in helpers_of_helpers.items(): - if file_path in helpers_of_fto: - # Remove duplicates within the same file path, in case a helper of helper is also a helper of fto - helpers_of_helpers[file_path] -= helpers_of_fto[file_path] - else: - helpers_of_helpers_no_overlap[file_path] = function_sources - - final_code_string_context = "" - - # Extract code from file paths that contain fto and first degree helpers. 
helpers of helpers may also be included if they are in the same files - for file_path, function_sources in helpers_of_fto.items(): - try: - original_code = file_path.read_text("utf8") - except Exception as e: - logger.exception(f"Error while parsing {file_path}: {e}") - continue - try: - qualified_function_names = {func.qualified_name for func in function_sources} - helpers_of_helpers_qualified_names = { - func.qualified_name for func in helpers_of_helpers.get(file_path, set()) - } - code_without_unused_defs = remove_unused_definitions_by_function_names( - original_code, qualified_function_names | helpers_of_helpers_qualified_names - ) - code_context = parse_code_and_prune_cst( - code_without_unused_defs, - code_context_type, - qualified_function_names, - helpers_of_helpers_qualified_names, - remove_docstrings, - ) - except ValueError as e: - logger.debug(f"Error while getting read-only code: {e}") - continue - if code_context.strip(): - final_code_string_context += f"\n{code_context}" - final_code_string_context = add_needed_imports_from_module( - src_module_code=original_code, - dst_module_code=final_code_string_context, - src_path=file_path, - dst_path=file_path, - project_root=project_root_path, - helper_functions=list(helpers_of_fto.get(file_path, set()) | helpers_of_helpers.get(file_path, set())), - ) - if code_context_type == CodeContextType.READ_WRITABLE: - return CodeString(code=final_code_string_context) - # Extract code from file paths containing helpers of helpers - for file_path, helper_function_sources in helpers_of_helpers_no_overlap.items(): - try: - original_code = file_path.read_text("utf8") - except Exception as e: - logger.exception(f"Error while parsing {file_path}: {e}") - continue - try: - qualified_helper_function_names = {func.qualified_name for func in helper_function_sources} - code_without_unused_defs = remove_unused_definitions_by_function_names( - original_code, qualified_helper_function_names - ) - code_context = 
parse_code_and_prune_cst( - code_without_unused_defs, code_context_type, set(), qualified_helper_function_names, remove_docstrings - ) - except ValueError as e: - logger.debug(f"Error while getting read-only code: {e}") - continue - - if code_context.strip(): - final_code_string_context += f"\n{code_context}" - final_code_string_context = add_needed_imports_from_module( - src_module_code=original_code, - dst_module_code=final_code_string_context, - src_path=file_path, - dst_path=file_path, - project_root=project_root_path, - helper_functions=list(helpers_of_helpers_no_overlap.get(file_path, set())), - ) - return CodeString(code=final_code_string_context) - - def extract_code_markdown_context_from_files( helpers_of_fto: dict[Path, set[FunctionSource]], helpers_of_helpers: dict[Path, set[FunctionSource]], @@ -685,6 +441,10 @@ def get_imported_class_definitions(code_context: CodeStringsMarkdown, project_ro the LLM understand the actual class structure (constructors, methods, inheritance) rather than just seeing import statements. + Also recursively extracts base classes when a class inherits from another class + in the same module, ensuring the full inheritance chain is available for + understanding constructor signatures. 
+ Args: code_context: The already extracted code context containing imports project_root_path: Root path of the project @@ -727,6 +487,68 @@ def get_imported_class_definitions(code_context: CodeStringsMarkdown, project_ro class_code_strings: list[CodeString] = [] + module_cache: dict[Path, tuple[str, ast.Module]] = {} + + def get_module_source_and_tree(module_path: Path) -> tuple[str, ast.Module] | None: + if module_path in module_cache: + return module_cache[module_path] + try: + module_source = module_path.read_text(encoding="utf-8") + module_tree = ast.parse(module_source) + except Exception: + return None + else: + module_cache[module_path] = (module_source, module_tree) + return module_source, module_tree + + def extract_class_and_bases( + class_name: str, module_path: Path, module_source: str, module_tree: ast.Module + ) -> None: + """Extract a class and its base classes recursively from the same module.""" + # Skip if already extracted + if (module_path, class_name) in extracted_classes: + return + + # Find the class definition in the module + class_node = None + for node in ast.walk(module_tree): + if isinstance(node, ast.ClassDef) and node.name == class_name: + class_node = node + break + + if class_node is None: + return + + # First, recursively extract base classes from the same module + for base in class_node.bases: + base_name = None + if isinstance(base, ast.Name): + base_name = base.id + elif isinstance(base, ast.Attribute): + # For module.ClassName, we skip (cross-module inheritance) + continue + + if base_name and base_name not in existing_definitions: + # Check if base class is defined in the same module + extract_class_and_bases(base_name, module_path, module_source, module_tree) + + # Now extract this class (after its bases, so base classes appear first) + if (module_path, class_name) in extracted_classes: + return # Already added by another path + + lines = module_source.split("\n") + start_line = class_node.lineno + if 
class_node.decorator_list: + start_line = min(d.lineno for d in class_node.decorator_list) + class_source = "\n".join(lines[start_line - 1 : class_node.end_lineno]) + + # Extract imports for the class + class_imports = extract_imports_for_class(module_tree, class_node, module_source) + full_source = class_imports + "\n\n" + class_source if class_imports else class_source + + class_code_strings.append(CodeString(code=full_source, file_path=module_path)) + extracted_classes.add((module_path, class_name)) + for name, module_name in imported_names.items(): # Skip if already defined in context if name in existing_definitions: @@ -752,28 +574,14 @@ def get_imported_class_definitions(code_context: CodeStringsMarkdown, project_ro if path_belongs_to_site_packages(module_path): continue - # Skip if we've already extracted this class - if (module_path, name) in extracted_classes: + # Get module source and tree + result = get_module_source_and_tree(module_path) + if result is None: continue + module_source, module_tree = result - # Parse the module to find the class definition - module_source = module_path.read_text(encoding="utf-8") - module_tree = ast.parse(module_source) - - for node in ast.walk(module_tree): - if isinstance(node, ast.ClassDef) and node.name == name: - # Extract the class source code - lines = module_source.split("\n") - class_source = "\n".join(lines[node.lineno - 1 : node.end_lineno]) - - # Also extract any necessary imports for the class (base classes, type hints) - class_imports = _extract_imports_for_class(module_tree, node, module_source) - - full_source = class_imports + "\n\n" + class_source if class_imports else class_source - - class_code_strings.append(CodeString(code=full_source, file_path=module_path)) - extracted_classes.add((module_path, name)) - break + # Extract the class and its base classes + extract_class_and_bases(name, module_path, module_source, module_tree) except Exception: logger.debug(f"Error extracting class definition for {name} 
from {module_name}") @@ -782,10 +590,111 @@ def get_imported_class_definitions(code_context: CodeStringsMarkdown, project_ro return CodeStringsMarkdown(code_strings=class_code_strings) -def _extract_imports_for_class(module_tree: ast.Module, class_node: ast.ClassDef, module_source: str) -> str: +def get_external_base_class_inits(code_context: CodeStringsMarkdown, project_root_path: Path) -> CodeStringsMarkdown: + """Extract __init__ methods from external library base classes. + + Scans the code context for classes that inherit from external libraries and extracts + just their __init__ methods. This helps the LLM understand constructor signatures + for mocking or instantiation. + """ + import importlib + import inspect + import textwrap + + all_code = "\n".join(cs.code for cs in code_context.code_strings) + + try: + tree = ast.parse(all_code) + except SyntaxError: + return CodeStringsMarkdown(code_strings=[]) + + imported_names: dict[str, str] = {} + external_bases: list[tuple[str, str]] = [] + for node in ast.walk(tree): + if isinstance(node, ast.ImportFrom) and node.module: + for alias in node.names: + if alias.name != "*": + imported_name = alias.asname if alias.asname else alias.name + imported_names[imported_name] = node.module + elif isinstance(node, ast.ClassDef): + for base in node.bases: + base_name = None + if isinstance(base, ast.Name): + base_name = base.id + elif isinstance(base, ast.Attribute) and isinstance(base.value, ast.Name): + base_name = base.attr + + if base_name and base_name in imported_names: + module_name = imported_names[base_name] + if not _is_project_module(module_name, project_root_path): + external_bases.append((base_name, module_name)) + + if not external_bases: + return CodeStringsMarkdown(code_strings=[]) + + code_strings: list[CodeString] = [] + extracted: set[tuple[str, str]] = set() + + for base_name, module_name in external_bases: + if (module_name, base_name) in extracted: + continue + + try: + module = 
importlib.import_module(module_name) + base_class = getattr(module, base_name, None) + if base_class is None: + continue + + init_method = getattr(base_class, "__init__", None) + if init_method is None: + continue + + try: + init_source = inspect.getsource(init_method) + init_source = textwrap.dedent(init_source) + class_file = Path(inspect.getfile(base_class)) + parts = class_file.parts + if "site-packages" in parts: + idx = parts.index("site-packages") + class_file = Path(*parts[idx + 1 :]) + except (OSError, TypeError): + continue + + class_source = f"class {base_name}:\n" + textwrap.indent(init_source, " ") + code_strings.append(CodeString(code=class_source, file_path=class_file)) + extracted.add((module_name, base_name)) + + except (ImportError, ModuleNotFoundError, AttributeError): + logger.debug(f"Failed to extract __init__ for {module_name}.{base_name}") + continue + + return CodeStringsMarkdown(code_strings=code_strings) + + +def _is_project_module(module_name: str, project_root_path: Path) -> bool: + """Check if a module is part of the project (not external/stdlib).""" + import importlib.util + + try: + spec = importlib.util.find_spec(module_name) + except (ImportError, ModuleNotFoundError, ValueError): + return False + else: + if spec is None or spec.origin is None: + return False + module_path = Path(spec.origin) + # Check if the module is in site-packages (external dependency) + # This must be checked first because .venv/site-packages is under project root + if path_belongs_to_site_packages(module_path): + return False + # Check if the module is within the project root + return str(module_path).startswith(str(project_root_path) + os.sep) + + +def extract_imports_for_class(module_tree: ast.Module, class_node: ast.ClassDef, module_source: str) -> str: """Extract import statements needed for a class definition. - This extracts imports for base classes and commonly used type annotations. 
+ This extracts imports for base classes, decorators, and type annotations. """ needed_names: set[str] = set() @@ -797,35 +706,139 @@ def _extract_imports_for_class(module_tree: ast.Module, class_node: ast.ClassDef # For things like abc.ABC, we need the module name needed_names.add(base.value.id) + # Get decorator names (e.g., dataclass, field) + for decorator in class_node.decorator_list: + if isinstance(decorator, ast.Name): + needed_names.add(decorator.id) + elif isinstance(decorator, ast.Call): + if isinstance(decorator.func, ast.Name): + needed_names.add(decorator.func.id) + elif isinstance(decorator.func, ast.Attribute) and isinstance(decorator.func.value, ast.Name): + needed_names.add(decorator.func.value.id) + + # Get type annotation names from class body (for dataclass fields) + for item in ast.walk(class_node): + if isinstance(item, ast.AnnAssign) and item.annotation: + collect_names_from_annotation(item.annotation, needed_names) + # Also check for field() calls which are common in dataclasses + if isinstance(item, ast.Call) and isinstance(item.func, ast.Name): + needed_names.add(item.func.id) + # Find imports that provide these names import_lines: list[str] = [] source_lines = module_source.split("\n") + added_imports: set[int] = set() # Track line numbers to avoid duplicates for node in module_tree.body: if isinstance(node, ast.Import): for alias in node.names: name = alias.asname if alias.asname else alias.name.split(".")[0] - if name in needed_names: + if name in needed_names and node.lineno not in added_imports: import_lines.append(source_lines[node.lineno - 1]) + added_imports.add(node.lineno) break elif isinstance(node, ast.ImportFrom): for alias in node.names: name = alias.asname if alias.asname else alias.name - if name in needed_names: + if name in needed_names and node.lineno not in added_imports: import_lines.append(source_lines[node.lineno - 1]) + added_imports.add(node.lineno) break return "\n".join(import_lines) +def 
collect_names_from_annotation(node: ast.expr, names: set[str]) -> None: + """Recursively collect type annotation names from an AST node.""" + if isinstance(node, ast.Name): + names.add(node.id) + elif isinstance(node, ast.Subscript): + collect_names_from_annotation(node.value, names) + collect_names_from_annotation(node.slice, names) + elif isinstance(node, ast.Tuple): + for elt in node.elts: + collect_names_from_annotation(elt, names) + elif isinstance(node, ast.BinOp): # For Union types with | syntax + collect_names_from_annotation(node.left, names) + collect_names_from_annotation(node.right, names) + elif isinstance(node, ast.Attribute) and isinstance(node.value, ast.Name): + names.add(node.value.id) + + def is_dunder_method(name: str) -> bool: return len(name) > 4 and name.isascii() and name.startswith("__") and name.endswith("__") -def get_section_names(node: cst.CSTNode) -> list[str]: - """Returns the section attribute names (e.g., body, orelse) for a given node if they exist.""" # noqa: D401 - possible_sections = ["body", "orelse", "finalbody", "handlers"] - return [sec for sec in possible_sections if hasattr(node, sec)] +class UsedNameCollector(cst.CSTVisitor): + """Collects all base names referenced in code (for import preservation).""" + + def __init__(self) -> None: + self.used_names: set[str] = set() + self.defined_names: set[str] = set() + + def visit_Name(self, node: cst.Name) -> None: + self.used_names.add(node.value) + + def visit_Attribute(self, node: cst.Attribute) -> bool | None: + base = node.value + while isinstance(base, cst.Attribute): + base = base.value + if isinstance(base, cst.Name): + self.used_names.add(base.value) + return True + + def visit_FunctionDef(self, node: cst.FunctionDef) -> bool | None: + self.defined_names.add(node.name.value) + return True + + def visit_ClassDef(self, node: cst.ClassDef) -> bool | None: + self.defined_names.add(node.name.value) + return True + + def visit_Assign(self, node: cst.Assign) -> bool | None: + 
for target in node.targets: + names = extract_names_from_targets(target.target) + self.defined_names.update(names) + return True + + def visit_AnnAssign(self, node: cst.AnnAssign) -> bool | None: + names = extract_names_from_targets(node.target) + self.defined_names.update(names) + return True + + def get_external_names(self) -> set[str]: + return self.used_names - self.defined_names - {"self", "cls"} + + +def get_imported_names(import_node: cst.Import | cst.ImportFrom) -> set[str]: + """Extract the names made available by an import statement.""" + names: set[str] = set() + if isinstance(import_node, cst.Import): + if isinstance(import_node.names, cst.ImportStar): + return {"*"} + for alias in import_node.names: + if isinstance(alias, cst.ImportAlias): + if alias.asname and isinstance(alias.asname.name, cst.Name): + names.add(alias.asname.name.value) + elif isinstance(alias.name, cst.Name): + names.add(alias.name.value) + elif isinstance(alias.name, cst.Attribute): + # import foo.bar -> accessible as "foo" + base: cst.BaseExpression = alias.name + while isinstance(base, cst.Attribute): + base = base.value + if isinstance(base, cst.Name): + names.add(base.value) + elif isinstance(import_node, cst.ImportFrom): + if isinstance(import_node.names, cst.ImportStar): + return {"*"} + for alias in import_node.names: + if isinstance(alias, cst.ImportAlias): + if alias.asname and isinstance(alias.asname.name, cst.Name): + names.add(alias.asname.name.value) + elif isinstance(alias.name, cst.Name): + names.add(alias.name.value) + return names def remove_docstring_from_body(indented_block: cst.IndentedBlock) -> cst.CSTNode: @@ -852,12 +865,22 @@ def parse_code_and_prune_cst( if code_context_type == CodeContextType.READ_WRITABLE: filtered_node, found_target = prune_cst_for_read_writable_code(module, target_functions, defs_with_usages) elif code_context_type == CodeContextType.READ_ONLY: - filtered_node, found_target = prune_cst_for_read_only_code( - module, target_functions, 
helpers_of_helper_functions, remove_docstrings=remove_docstrings + filtered_node, found_target = prune_cst_for_context( + module, + target_functions, + helpers_of_helper_functions, + remove_docstrings=remove_docstrings, + include_target_in_output=False, + include_init_dunder=False, ) elif code_context_type == CodeContextType.TESTGEN: - filtered_node, found_target = prune_cst_for_testgen_code( - module, target_functions, helpers_of_helper_functions, remove_docstrings=remove_docstrings + filtered_node, found_target = prune_cst_for_context( + module, + target_functions, + helpers_of_helper_functions, + remove_docstrings=remove_docstrings, + include_target_in_output=True, + include_init_dunder=True, ) elif code_context_type == CodeContextType.HASHING: filtered_node, found_target = prune_cst_for_code_hashing(module, target_functions) @@ -899,10 +922,29 @@ def prune_cst_for_read_writable_code( # noqa: PLR0911 # Do not recurse into nested classes if prefix: return None, False + + class_name = node.name.value + # Assuming always an IndentedBlock if not isinstance(node.body, cst.IndentedBlock): raise ValueError("ClassDef body is not an IndentedBlock") # noqa: TRY004 - class_prefix = f"{prefix}.{node.name.value}" if prefix else node.name.value + class_prefix = f"{prefix}.{class_name}" if prefix else class_name + + # Check if this class contains any target functions + has_target_functions = any( + isinstance(stmt, cst.FunctionDef) and f"{class_prefix}.{stmt.name.value}" in target_functions + for stmt in node.body.body + ) + + # If the class is used as a dependency (not containing target functions), keep it entirely + # This handles cases like enums, dataclasses, and other types used by the target function + if ( + not has_target_functions + and class_name in defs_with_usages + and defs_with_usages[class_name].used_by_qualified_function + ): + return node, True + new_body = [] found_target = False @@ -1062,17 +1104,29 @@ def prune_cst_for_code_hashing( # noqa: PLR0911 return 
(node.with_changes(**updates) if updates else node), True -def prune_cst_for_read_only_code( # noqa: PLR0911 +def prune_cst_for_context( # noqa: PLR0911 node: cst.CSTNode, target_functions: set[str], helpers_of_helper_functions: set[str], prefix: str = "", remove_docstrings: bool = False, # noqa: FBT001, FBT002 + include_target_in_output: bool = False, # noqa: FBT001, FBT002 + include_init_dunder: bool = False, # noqa: FBT001, FBT002 ) -> tuple[cst.CSTNode | None, bool]: - """Recursively filter the node for read-only context. + """Recursively filter the node for code context extraction. - Returns - ------- + Args: + node: The CST node to filter + target_functions: Set of qualified function names that are targets + helpers_of_helper_functions: Set of helper function qualified names + prefix: Current qualified name prefix (for class methods) + remove_docstrings: Whether to remove docstrings from output + include_target_in_output: If True, include target functions in output (testgen mode) + If False, exclude target functions (read-only mode) + include_init_dunder: If True, include __init__ in dunder methods (testgen mode) + If False, exclude __init__ from dunder methods (read-only mode) + + Returns: (filtered_node, found_target): filtered_node: The modified CST node or None if it should be removed. found_target: True if a target function was found in this node's subtree. 
@@ -1083,17 +1137,28 @@ def prune_cst_for_read_only_code( # noqa: PLR0911 if isinstance(node, cst.FunctionDef): qualified_name = f"{prefix}.{node.name.value}" if prefix else node.name.value - # If it's a target function, remove it but mark found_target = True + + # Check if it's a helper of helper function if qualified_name in helpers_of_helper_functions: + if remove_docstrings and isinstance(node.body, cst.IndentedBlock): + return node.with_changes(body=remove_docstring_from_body(node.body)), True return node, True + + # Check if it's a target function if qualified_name in target_functions: + if include_target_in_output: + if remove_docstrings and isinstance(node.body, cst.IndentedBlock): + return node.with_changes(body=remove_docstring_from_body(node.body)), True + return node, True return None, True - # Keep only dunder methods - if is_dunder_method(node.name.value) and node.name.value != "__init__": + + # Check dunder methods + # For read-only mode, exclude __init__; for testgen mode, include all dunders + if is_dunder_method(node.name.value) and (include_init_dunder or node.name.value != "__init__"): if remove_docstrings and isinstance(node.body, cst.IndentedBlock): - new_body = remove_docstring_from_body(node.body) - return node.with_changes(body=new_body), False + return node.with_changes(body=remove_docstring_from_body(node.body)), False return node, False + return None, False if isinstance(node, cst.ClassDef): @@ -1110,8 +1175,14 @@ def prune_cst_for_read_only_code( # noqa: PLR0911 found_in_class = False new_class_body: list[CSTNode] = [] for stmt in node.body.body: - filtered, found_target = prune_cst_for_read_only_code( - stmt, target_functions, helpers_of_helper_functions, class_prefix, remove_docstrings=remove_docstrings + filtered, found_target = prune_cst_for_context( + stmt, + target_functions, + helpers_of_helper_functions, + class_prefix, + remove_docstrings=remove_docstrings, + include_target_in_output=include_target_in_output, + 
include_init_dunder=include_init_dunder, ) found_in_class |= found_target if filtered: @@ -1140,8 +1211,14 @@ def prune_cst_for_read_only_code( # noqa: PLR0911 new_children = [] section_found_target = False for child in original_content: - filtered, found_target = prune_cst_for_read_only_code( - child, target_functions, helpers_of_helper_functions, prefix, remove_docstrings=remove_docstrings + filtered, found_target = prune_cst_for_context( + child, + target_functions, + helpers_of_helper_functions, + prefix, + remove_docstrings=remove_docstrings, + include_target_in_output=include_target_in_output, + include_init_dunder=include_init_dunder, ) if filtered: new_children.append(filtered) @@ -1151,122 +1228,19 @@ def prune_cst_for_read_only_code( # noqa: PLR0911 found_any_target |= section_found_target updates[section] = new_children elif original_content is not None: - filtered, found_target = prune_cst_for_read_only_code( - original_content, - target_functions, - helpers_of_helper_functions, - prefix, - remove_docstrings=remove_docstrings, - ) - found_any_target |= found_target - if filtered: - updates[section] = filtered - if updates: - return (node.with_changes(**updates), found_any_target) - - return None, False - - -def prune_cst_for_testgen_code( # noqa: PLR0911 - node: cst.CSTNode, - target_functions: set[str], - helpers_of_helper_functions: set[str], - prefix: str = "", - remove_docstrings: bool = False, # noqa: FBT001, FBT002 -) -> tuple[cst.CSTNode | None, bool]: - """Recursively filter the node for testgen context. - - Returns - ------- - (filtered_node, found_target): - filtered_node: The modified CST node or None if it should be removed. - found_target: True if a target function was found in this node's subtree. 
- - """ - if isinstance(node, (cst.Import, cst.ImportFrom)): - return None, False - - if isinstance(node, cst.FunctionDef): - qualified_name = f"{prefix}.{node.name.value}" if prefix else node.name.value - # If it's a target function, remove it but mark found_target = True - if qualified_name in helpers_of_helper_functions or qualified_name in target_functions: - if remove_docstrings and isinstance(node.body, cst.IndentedBlock): - new_body = remove_docstring_from_body(node.body) - return node.with_changes(body=new_body), True - return node, True - # Keep all dunder methods - if is_dunder_method(node.name.value): - if remove_docstrings and isinstance(node.body, cst.IndentedBlock): - new_body = remove_docstring_from_body(node.body) - return node.with_changes(body=new_body), False - return node, False - return None, False - - if isinstance(node, cst.ClassDef): - # Do not recurse into nested classes - if prefix: - return None, False - # Assuming always an IndentedBlock - if not isinstance(node.body, cst.IndentedBlock): - raise ValueError("ClassDef body is not an IndentedBlock") # noqa: TRY004 - - class_prefix = f"{prefix}.{node.name.value}" if prefix else node.name.value - - # First pass: detect if there is a target function in the class - found_in_class = False - new_class_body: list[CSTNode] = [] - for stmt in node.body.body: - filtered, found_target = prune_cst_for_testgen_code( - stmt, target_functions, helpers_of_helper_functions, class_prefix, remove_docstrings=remove_docstrings - ) - found_in_class |= found_target - if filtered: - new_class_body.append(filtered) - - if not found_in_class: - return None, False - - if remove_docstrings: - return node.with_changes( - body=remove_docstring_from_body(node.body.with_changes(body=new_class_body)) - ) if new_class_body else None, True - return node.with_changes(body=node.body.with_changes(body=new_class_body)) if new_class_body else None, True - - # For other nodes, keep the node and recursively filter children - 
section_names = get_section_names(node) - if not section_names: - return node, False - - updates: dict[str, list[cst.CSTNode] | cst.CSTNode] = {} - found_any_target = False - - for section in section_names: - original_content = getattr(node, section, None) - if isinstance(original_content, (list, tuple)): - new_children = [] - section_found_target = False - for child in original_content: - filtered, found_target = prune_cst_for_testgen_code( - child, target_functions, helpers_of_helper_functions, prefix, remove_docstrings=remove_docstrings - ) - if filtered: - new_children.append(filtered) - section_found_target |= found_target - - if section_found_target or new_children: - found_any_target |= section_found_target - updates[section] = new_children - elif original_content is not None: - filtered, found_target = prune_cst_for_testgen_code( + filtered, found_target = prune_cst_for_context( original_content, target_functions, helpers_of_helper_functions, prefix, remove_docstrings=remove_docstrings, + include_target_in_output=include_target_in_output, + include_init_dunder=include_init_dunder, ) found_any_target |= found_target if filtered: updates[section] = filtered + if updates: return (node.with_changes(**updates), found_any_target) diff --git a/codeflash/context/unused_definition_remover.py b/codeflash/context/unused_definition_remover.py index 64b52cba3..51baa90e0 100644 --- a/codeflash/context/unused_definition_remover.py +++ b/codeflash/context/unused_definition_remover.py @@ -295,11 +295,18 @@ class DependencyCollector(cst.CSTVisitor): return if name in self.definitions and name != self.current_top_level_name: - # skip if we are refrencing a class attribute and not a top-level definition + # Skip if this Name is the .attr part of an Attribute (e.g., 'x' in 'self.x') + # We only want to track the base/value of attribute access, not the attribute name itself if self.class_depth > 0: parent = self.get_metadata(cst.metadata.ParentNodeProvider, node) if parent is 
not None and isinstance(parent, cst.Attribute): - return + # Check if this Name is the .attr (property name), not the .value (base) + # If it's the .attr, skip it - attribute names aren't references to definitions + if parent.attr is node: + return + # If it's the .value (base), only skip if it's self/cls + if name in ("self", "cls"): + return self.definitions[self.current_top_level_name].dependencies.add(name) @@ -553,16 +560,6 @@ def remove_unused_definitions_by_function_names(code: str, qualified_function_na return code -def print_definitions(definitions: dict[str, UsageInfo]) -> None: - """Print information about each definition without the complex node object, used for debugging.""" - print(f"Found {len(definitions)} definitions:") - for name, info in sorted(definitions.items()): - print(f" - Name: {name}") - print(f" Used by qualified function: {info.used_by_qualified_function}") - print(f" Dependencies: {', '.join(sorted(info.dependencies)) if info.dependencies else 'None'}") - print() - - def revert_unused_helper_functions( project_root: Path, unused_helpers: list[FunctionSource], original_helper_code: dict[Path, str] ) -> None: @@ -637,43 +634,40 @@ def _analyze_imports_in_optimized_code( func_name = helper.only_function_name module_name = helper.file_path.stem # Cache function lookup for this (module, func) - file_entry = helpers_by_file_and_func[module_name] - if func_name in file_entry: - file_entry[func_name].append(helper) - else: - file_entry[func_name] = [helper] + helpers_by_file_and_func[module_name].setdefault(func_name, []).append(helper) helpers_by_file[module_name].append(helper) - # Optimize attribute lookups and method binding outside the loop - helpers_by_file_and_func_get = helpers_by_file_and_func.get - helpers_by_file_get = helpers_by_file.get - for node in ast.walk(optimized_ast): if isinstance(node, ast.ImportFrom): # Handle "from module import function" statements module_name = node.module if module_name: - file_entry = 
helpers_by_file_and_func_get(module_name, None) + file_entry = helpers_by_file_and_func.get(module_name) if file_entry: for alias in node.names: imported_name = alias.asname if alias.asname else alias.name original_name = alias.name - helpers = file_entry.get(original_name, None) + helpers = file_entry.get(original_name) if helpers: + imported_set = imported_names_map[imported_name] for helper in helpers: - imported_names_map[imported_name].add(helper.qualified_name) - imported_names_map[imported_name].add(helper.fully_qualified_name) + imported_set.add(helper.qualified_name) + imported_set.add(helper.fully_qualified_name) elif isinstance(node, ast.Import): # Handle "import module" statements for alias in node.names: imported_name = alias.asname if alias.asname else alias.name module_name = alias.name - for helper in helpers_by_file_get(module_name, []): - # For "import module" statements, functions would be called as module.function - full_call = f"{imported_name}.{helper.only_function_name}" - imported_names_map[full_call].add(helper.qualified_name) - imported_names_map[full_call].add(helper.fully_qualified_name) + helpers = helpers_by_file.get(module_name) + if helpers: + imported_set = imported_names_map[f"{imported_name}.{{func}}"] + for helper in helpers: + # For "import module" statements, functions would be called as module.function + full_call = f"{imported_name}.{helper.only_function_name}" + full_call_set = imported_names_map[full_call] + full_call_set.add(helper.qualified_name) + full_call_set.add(helper.fully_qualified_name) return dict(imported_names_map) @@ -758,27 +752,31 @@ def detect_unused_helper_functions( called_name = node.func.id called_function_names.add(called_name) # Also add the qualified name if this is an imported function - if called_name in imported_names_map: - called_function_names.update(imported_names_map[called_name]) + mapped_names = imported_names_map.get(called_name) + if mapped_names: + 
called_function_names.update(mapped_names) elif isinstance(node.func, ast.Attribute): # Method call: obj.method() or self.method() or module.function() if isinstance(node.func.value, ast.Name): - if node.func.value.id == "self": + attr_name = node.func.attr + value_id = node.func.value.id + if value_id == "self": # self.method_name() -> add both method_name and ClassName.method_name - called_function_names.add(node.func.attr) + called_function_names.add(attr_name) + # For class methods, also add the qualified name # For class methods, also add the qualified name if hasattr(function_to_optimize, "parents") and function_to_optimize.parents: class_name = function_to_optimize.parents[0].name - called_function_names.add(f"{class_name}.{node.func.attr}") + called_function_names.add(f"{class_name}.{attr_name}") else: - # obj.method() or module.function() - attr_name = node.func.attr called_function_names.add(attr_name) - called_function_names.add(f"{node.func.value.id}.{attr_name}") + full_call = f"{value_id}.{attr_name}" + called_function_names.add(full_call) # Check if this is a module.function call that maps to a helper - full_call = f"{node.func.value.id}.{attr_name}" - if full_call in imported_names_map: - called_function_names.update(imported_names_map[full_call]) + mapped_names = imported_names_map.get(full_call) + if mapped_names: + called_function_names.update(mapped_names) + # Handle nested attribute access like obj.attr.method() # Handle nested attribute access like obj.attr.method() else: called_function_names.add(node.func.attr) @@ -788,6 +786,7 @@ def detect_unused_helper_functions( # Find helper functions that are no longer called unused_helpers = [] + entrypoint_file_path = function_to_optimize.file_path for helper_function in code_context.helper_functions: jedi_type = helper_function.jedi_definition.type if helper_function.jedi_definition else None if jedi_type != "class": # Include when jedi_definition is None (non-Python) @@ -796,29 +795,30 @@ def 
detect_unused_helper_functions( helper_simple_name = helper_function.only_function_name helper_fully_qualified_name = helper_function.fully_qualified_name - # Create a set of all possible names this helper might be called by - possible_call_names = {helper_qualified_name, helper_simple_name, helper_fully_qualified_name} - + # Check membership efficiently - exit early on first match + if ( + helper_qualified_name in called_function_names + or helper_simple_name in called_function_names + or helper_fully_qualified_name in called_function_names + ): + is_called = True # For cross-file helpers, also consider module-based calls - if helper_function.file_path != function_to_optimize.file_path: + elif helper_function.file_path != entrypoint_file_path: # Add potential module.function combinations module_name = helper_function.file_path.stem - possible_call_names.add(f"{module_name}.{helper_simple_name}") - - # Check if any of the possible names are in the called functions - is_called = bool(possible_call_names.intersection(called_function_names)) + module_call = f"{module_name}.{helper_simple_name}" + is_called = module_call in called_function_names + else: + is_called = False if not is_called: unused_helpers.append(helper_function) logger.debug(f"Helper function {helper_qualified_name} is not called in optimized code") - logger.debug(f" Checked names: {possible_call_names}") else: logger.debug(f"Helper function {helper_qualified_name} is still called in optimized code") - logger.debug(f" Called via: {possible_call_names.intersection(called_function_names)}") - - ret_val = unused_helpers except Exception as e: logger.debug(f"Error detecting unused helper functions: {e}") - ret_val = [] - return ret_val + return [] + else: + return unused_helpers diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py index ab4c32a37..2fb062eb4 100644 --- a/codeflash/optimization/function_optimizer.py +++ 
b/codeflash/optimization/function_optimizer.py @@ -700,7 +700,9 @@ class FunctionOptimizer: ): console.rule() new_code_context = code_context - if self.is_numerical_code: # if the code is numerical in nature (uses numpy/tensorflow/math/pytorch/jax) + if ( + self.is_numerical_code and not self.args.no_jit_opts + ): # if the code is numerical in nature (uses numpy/tensorflow/math/pytorch/jax) jit_compiled_opt_candidate = self.aiservice_client.get_jit_rewritten_code( code_context.read_writable_code.markdown, self.function_trace_id ) @@ -729,7 +731,7 @@ class FunctionOptimizer: read_writable_code=code_context.read_writable_code, read_only_context_code=code_context.read_only_context_code, run_experiment=should_run_experiment, - is_numerical_code=self.is_numerical_code, + is_numerical_code=self.is_numerical_code and not self.args.no_jit_opts, ) concurrent.futures.wait([future_tests, future_optimizations]) @@ -1251,7 +1253,7 @@ class FunctionOptimizer: ) if self.experiment_id else None, - is_numerical_code=self.is_numerical_code, + is_numerical_code=self.is_numerical_code and not self.args.no_jit_opts, language=self.function_to_optimize.language, ) diff --git a/codeflash/tracer.py b/codeflash/tracer.py index 74c1593a7..8c7a7621f 100644 --- a/codeflash/tracer.py +++ b/codeflash/tracer.py @@ -138,6 +138,13 @@ def main(args: Namespace | None = None) -> ArgumentParser: env["PYTHONPATH"] = f"{project_root_str}{os.pathsep}{pythonpath}" else: env["PYTHONPATH"] = project_root_str + # Disable JIT compilation to ensure tracing captures all function calls + env["NUMBA_DISABLE_JIT"] = str(1) + env["TORCHDYNAMO_DISABLE"] = str(1) + env["PYTORCH_JIT"] = str(0) + env["TF_XLA_FLAGS"] = "--tf_xla_auto_jit=0" + env["TF_ENABLE_ONEDNN_OPTS"] = str(0) + env["JAX_DISABLE_JIT"] = str(1) processes.append( subprocess.Popen( [ @@ -175,6 +182,13 @@ def main(args: Namespace | None = None) -> ArgumentParser: env["PYTHONPATH"] = f"{project_root_str}{os.pathsep}{pythonpath}" else: env["PYTHONPATH"] = 
project_root_str + # Disable JIT compilation to ensure tracing captures all function calls + env["NUMBA_DISABLE_JIT"] = str(1) + env["TORCHDYNAMO_DISABLE"] = str(1) + env["PYTORCH_JIT"] = str(0) + env["TF_XLA_FLAGS"] = "--tf_xla_auto_jit=0" + env["TF_ENABLE_ONEDNN_OPTS"] = str(0) + env["JAX_DISABLE_JIT"] = str(1) subprocess.run( [ diff --git a/codeflash/verification/codeflash_capture.py b/codeflash/verification/codeflash_capture.py index 991f4d624..5c2bf4b6f 100644 --- a/codeflash/verification/codeflash_capture.py +++ b/codeflash/verification/codeflash_capture.py @@ -15,6 +15,8 @@ from typing import Callable import dill as pickle from dill import PicklingWarning +from codeflash.picklepatch.pickle_patcher import PicklePatcher + warnings.filterwarnings("ignore", category=PicklingWarning) @@ -148,18 +150,29 @@ def codeflash_capture(function_name: str, tmp_dir_path: str, tests_root: str, is print(f"!######{test_stdout_tag}######!") # Capture instance state after initialization - if hasattr(args[0], "__dict__"): - instance_state = args[ - 0 - ].__dict__ # self is always the first argument, this is ensured during instrumentation + # self is always the first argument, this is ensured during instrumentation + instance = args[0] + if hasattr(instance, "__dict__"): + instance_state = instance.__dict__ + elif hasattr(instance, "__slots__"): + # For classes using __slots__, capture slot values + instance_state = { + slot: getattr(instance, slot, None) for slot in instance.__slots__ if hasattr(instance, slot) + } else: - raise ValueError("Instance state could not be captured.") + # For C extension types or other special classes (e.g., Playwright's Page), + # capture all non-private, non-callable attributes + instance_state = { + attr: getattr(instance, attr) + for attr in dir(instance) + if not attr.startswith("_") and not callable(getattr(instance, attr, None)) + } codeflash_cur.execute( "CREATE TABLE IF NOT EXISTS test_results (test_module_path TEXT, test_class_name TEXT, 
test_function_name TEXT, function_getting_tested TEXT, loop_index INTEGER, iteration_id TEXT, runtime INTEGER, return_value BLOB, verification_type TEXT)" ) # Write to sqlite - pickled_return_value = pickle.dumps(exception) if exception else pickle.dumps(instance_state) + pickled_return_value = pickle.dumps(exception) if exception else PicklePatcher.dumps(instance_state) codeflash_cur.execute( "INSERT INTO test_results VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)", ( diff --git a/codeflash/verification/comparator.py b/codeflash/verification/comparator.py index e370d35ad..8d256d0da 100644 --- a/codeflash/verification/comparator.py +++ b/codeflash/verification/comparator.py @@ -257,6 +257,10 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool: # noqa: ANN001 return False return True + # Handle mappingproxy (read-only dict view, commonly seen as class.__dict__) + if isinstance(orig, types.MappingProxyType): + return comparator(dict(orig), dict(new), superset_obj) + # Handle dict view types (dict_keys, dict_values, dict_items) # Use type name checking since these are not directly importable types type_name = type(orig).__name__ diff --git a/codeflash/verification/equivalence.py b/codeflash/verification/equivalence.py index 2966fec14..4fa167071 100644 --- a/codeflash/verification/equivalence.py +++ b/codeflash/verification/equivalence.py @@ -18,6 +18,12 @@ reprlib_repr = reprlib.Repr() reprlib_repr.maxstring = 1500 test_diff_repr = reprlib_repr.repr +def safe_repr(obj: object) -> str: + """Safely get repr of an object, handling Mock objects with corrupted state.""" + try: + return repr(obj) + except (AttributeError, TypeError, RecursionError) as e: + return f"<unrepresentable object of type {type(obj).__name__}: {e}>" def compare_test_results( original_results: TestResults, @@ -102,8 +108,8 @@ def compare_test_results( test_diffs.append( TestDiff( scope=TestDiffScope.RETURN_VALUE, - original_value=test_diff_repr(repr(original_test_result.return_value)), - candidate_value=test_diff_repr(repr(cdd_test_result.return_value)), 
+ original_value=test_diff_repr(safe_repr(original_test_result.return_value)), + candidate_value=test_diff_repr(safe_repr(cdd_test_result.return_value)), test_src_code=original_test_result.id.get_src_code(original_test_result.file_name), candidate_pytest_error=cdd_pytest_error, original_pass=original_test_result.did_pass, diff --git a/codeflash/version.py b/codeflash/version.py index 2da327166..6225467e3 100644 --- a/codeflash/version.py +++ b/codeflash/version.py @@ -1,2 +1,2 @@ # These version placeholders will be replaced by uv-dynamic-versioning during build. -__version__ = "0.19.1.post300.dev0+ad95a41b" +__version__ = "0.20.0" diff --git a/docs/cli-reference/flags.mdx b/docs/cli-reference/flags.mdx deleted file mode 100644 index ad9fc5f31..000000000 --- a/docs/cli-reference/flags.mdx +++ /dev/null @@ -1,191 +0,0 @@ ---- -title: "Flags Reference" -description: "Complete reference for all Codeflash CLI flags and options" -icon: "list" -sidebarTitle: "Flags Reference" -keywords: ["flags", "options", "arguments", "command line"] ---- - -# Flags Reference - -Complete reference for all Codeflash CLI flags and command-line options. - ---- - -## Main Command Flags - -| Flag | Type | Description | -|------|------|-------------| -| `--file` | `PATH` | Optimize only this file | -| `--function` | `NAME` | Optimize only this function (requires `--file`) | -| `--all` | `[PATH]` | Optimize all functions. 
Optional path to start from | -| `--replay-test` | `PATH` | Path to replay test file(s) | -| `--benchmark` | flag | Enable benchmark mode | -| `--no-pr` | flag | Don't create PR, update locally | -| `--no-gen-tests` | flag | Don't generate tests | -| `--no-draft` | flag | Skip draft PRs | -| `--worktree` | flag | Use git worktree | -| `--staging-review` | flag | Upload to staging | -| `--verbose` / `-v` | flag | Verbose debug output | -| `--verify-setup` | flag | Run setup verification | -| `--version` | flag | Show version | - ---- - -## Configuration Override Flags - -Override settings from `pyproject.toml` via command line. - -| Flag | Type | Description | -|------|------|-------------| -| `--config-file` | `PATH` | Path to pyproject.toml | -| `--module-root` | `PATH` | Python module root directory | -| `--tests-root` | `PATH` | Tests directory | -| `--benchmarks-root` | `PATH` | Benchmarks directory | - - - - - ```bash - # Override config file location - codeflash --file src/app.py --function main --config-file configs/pyproject.toml --no-pr - - # Override module root - codeflash --file src/app.py --function main --module-root src --no-pr - - # Override tests root - codeflash --file src/app.py --function main --tests-root tests/unit --no-pr - - # Combine multiple overrides - codeflash --file src/app.py --function main \ - --module-root src \ - --tests-root tests \ - --no-pr - ``` - - - ```powershell - # Override config file location - codeflash --file src\app.py --function main --config-file configs\pyproject.toml --no-pr - - # Override module root - codeflash --file src\app.py --function main --module-root src --no-pr - - # Override tests root - codeflash --file src\app.py --function main --tests-root tests\unit --no-pr - ``` - - - - ---- - -## Optimize Subcommand Flags - -Flags specific to the `codeflash optimize` command. 
- -| Flag | Type | Description | -|------|------|-------------| -| `--output` | `PATH` | Trace file output path (default: `codeflash.trace`) | -| `--timeout` | `INT` | Maximum trace time in seconds | -| `--max-function-count` | `INT` | Max times to trace a function (default: 100) | -| `--config-file-path` | `PATH` | Path to pyproject.toml | -| `--trace-only` | flag | Only trace, don't optimize | - - - The `--output` flag specifies where to save the trace file. If not specified, it defaults to `codeflash.trace` in the current directory. - - ---- - -## Behavior Flags - -Control how Codeflash behaves during optimization. - -| Flag | Description | -|------|-------------| -| `--no-pr` | Run locally without creating a pull request | -| `--no-gen-tests` | Use only existing tests, skip test generation | -| `--no-draft` | Skip optimization for draft PRs (CI mode) | -| `--worktree` | Use git worktree for isolated optimization | -| `--staging-review` | Upload optimizations to staging for review | -| `--verbose` / `-v` | Enable verbose debug logging | - - -```bash -# Local optimization only -codeflash --file src/app.py --function main --no-pr - -# Use only existing tests -codeflash --file src/app.py --function main --no-gen-tests --no-pr - -# Enable verbose logging -codeflash --file src/app.py --function main --verbose --no-pr - -# Use worktree for isolation -codeflash --file src/app.py --function main --worktree --no-pr - -# Upload to staging -codeflash --all --staging-review --no-pr -``` - - ---- - -## Flag Combinations - -Common flag combinations for different use cases: - -### Local Development - -```bash -# Optimize locally with verbose output -codeflash --file src/app.py --function main --no-pr --verbose -``` - -### CI/CD Pipeline - -```bash -# Skip draft PRs and use existing tests only -codeflash --all --no-draft --no-gen-tests -``` - -### Debugging - -```bash -# Trace only with custom output and timeout -codeflash optimize app.py --trace-only --output debug.trace 
--timeout 60 -``` - -### Custom Configuration - -```bash -# Override multiple config settings -codeflash --file src/app.py --function main \ - --module-root src \ - --tests-root tests/unit \ - --benchmarks-root tests/benchmarks \ - --no-pr -``` - ---- - -## Next Steps - - - - Learn how to use optimization commands - - - Fix common issues - - - diff --git a/docs/cli-reference/index.mdx b/docs/cli-reference/index.mdx deleted file mode 100644 index f91fb2e1b..000000000 --- a/docs/cli-reference/index.mdx +++ /dev/null @@ -1,208 +0,0 @@ ---- -title: "CLI Reference" -description: "Complete command-line reference for Codeflash CLI commands, flags, and options" -icon: "terminal" -sidebarTitle: "Overview" -keywords: - [ - "CLI", - "command line", - "commands", - "flags", - "options", - "reference", - "terminal", - ] ---- - -# Codeflash CLI Reference - -Complete command-line reference for all Codeflash commands, flags, and options with practical examples you can run directly in your terminal. - - - **Prerequisites** - Ensure Codeflash is installed in your Python environment - and you have a configured `pyproject.toml` in your project. - - ---- - -## Quick Start - - - - ```bash - # Activate virtual environment (if using one) - source .venv/bin/activate - - # Verify installation - codeflash --version - ``` - - - ```powershell - # Activate virtual environment (if using one) - .venv\Scripts\activate - - # Verify installation - codeflash --version - ``` - - - ---- - -## Common Workflows - -### 1. First-Time Setup - - - - ```bash - pip install codeflash - ``` - - - ```bash - codeflash init - ``` - - - ```bash - codeflash --verify-setup - ``` - - - ```bash - codeflash --file src/main.py --function my_function --no-pr - ``` - - - ---- - -### 2. Optimize a Workflow - - - - ```bash - codeflash optimize my_script.py --arg1 value1 - ``` - - - Check the generated PR or local changes for optimization suggestions. - - - ---- - -### 3. 
CI/CD Integration - - - - ```bash - codeflash init-actions - ``` - - - Review and merge the generated GitHub Actions workflow. - - - Codeflash will now optimize code in every PR automatically! - - - ---- - -## Help & Version - -```bash -# Display version -codeflash --version - -# Main help -codeflash --help - -# Subcommand help -codeflash optimize --help -codeflash init --help -``` - ---- - -## Documentation Structure - -This CLI reference is organized into the following sections: - - - - Initialize projects, set up GitHub Actions, and verify installation - - - Optimize single functions or entire codebases - - - Trace script execution and optimize based on real usage - - - Complete reference for all command-line flags - - - Solutions for common CLI issues - - - ---- - -## Next Steps - - - - Learn how to optimize individual functions - - - Optimize entire workflows with tracing - - - Set up continuous optimization - - - Advanced configuration options - - - diff --git a/docs/cli-reference/optimization.mdx b/docs/cli-reference/optimization.mdx deleted file mode 100644 index 34284e4b0..000000000 --- a/docs/cli-reference/optimization.mdx +++ /dev/null @@ -1,172 +0,0 @@ ---- -title: "Optimization Commands" -description: "Optimize single functions or entire codebases with Codeflash CLI" -icon: "bullseye" -sidebarTitle: "Optimization Commands" -keywords: ["optimization", "function", "file", "all", "commands"] ---- - -# Optimization Commands - -Commands for optimizing individual functions or entire codebases. - ---- - -## Optimize a Single Function - -Target a specific function in a file for optimization. 
- -```bash -codeflash --file --function -``` - - - - - ```bash - # Basic optimization (creates PR) - codeflash --file src/utils.py --function calculate_metrics - - # Local optimization only (no PR) - codeflash --file src/utils.py --function calculate_metrics --no-pr - - # With verbose output - codeflash --file src/utils.py --function calculate_metrics --no-pr --verbose - ``` - - - ```powershell - # Basic optimization (creates PR) - codeflash --file src\utils.py --function calculate_metrics - - # Local optimization only (no PR) - codeflash --file src\utils.py --function calculate_metrics --no-pr - - # With verbose output - codeflash --file src\utils.py --function calculate_metrics --no-pr --verbose - ``` - - - - - - **Important**: The file must be within your configured `module-root` - directory. Files outside `module-root` will be ignored with "Functions outside - module-root" message. - - ---- - -## Optimize All Functions - -Optimize all functions in your entire codebase or a specific directory. 
- -```bash -# Optimize entire codebase -codeflash --all - -# Optimize specific directory -codeflash --all src/core/ -``` - - - - - ```bash - # Optimize all (creates PRs) - codeflash --all - - # Optimize all locally (no PRs) - codeflash --all --no-pr - - # Optimize specific directory - codeflash --all src/algorithms/ --no-pr - - # Skip draft PRs in CI - codeflash --all --no-draft - ``` - - - ```powershell - # Optimize all (creates PRs) - codeflash --all - - # Optimize all locally (no PRs) - codeflash --all --no-pr - - # Optimize specific directory - codeflash --all src\algorithms\ --no-pr - - # Skip draft PRs in CI - codeflash --all --no-draft - ``` - - - - - - When using `--all`, Codeflash will: - - Discover all optimizable functions in your codebase - - Create separate PRs for each function (or update locally with `--no-pr`) - - Process functions in batches to avoid overwhelming your repository - - ---- - -## Benchmark Mode - -Optimize code based on performance benchmarks using pytest-benchmark format. - -```bash -codeflash --file --benchmark --benchmarks-root -``` - - - - - ```bash - # With benchmarks-root flag - codeflash --file src/core.py --benchmark --benchmarks-root tests/benchmarks --no-pr - - # If benchmarks-root is in pyproject.toml - codeflash --file src/core.py --benchmark --no-pr - ``` - - - ```powershell - # With benchmarks-root flag - codeflash --file src\core.py --benchmark --benchmarks-root tests\benchmarks --no-pr - - # If benchmarks-root is in pyproject.toml - codeflash --file src\core.py --benchmark --no-pr - ``` - - - - - - The `--benchmarks-root` directory must exist and be configured either via - `pyproject.toml` or the command-line flag. 
- - ---- - -## Next Steps - - - - Learn about trace-based optimization - - - Complete flag reference - - - diff --git a/docs/cli-reference/setup.mdx b/docs/cli-reference/setup.mdx deleted file mode 100644 index c9a3be441..000000000 --- a/docs/cli-reference/setup.mdx +++ /dev/null @@ -1,125 +0,0 @@ ---- -title: "Setup Commands" -description: "Initialize projects, set up GitHub Actions, and verify Codeflash installation" -icon: "wrench" -sidebarTitle: "Setup Commands" -keywords: ["setup", "init", "installation", "github actions", "verify"] ---- - -# Setup Commands - -Commands for initializing Codeflash in your project, setting up continuous optimization, and verifying your installation. - ---- - -## `codeflash init` - -Initialize Codeflash for your Python project. This creates the configuration in `pyproject.toml`. - - -```bash Basic -codeflash init -``` - -```bash With Formatter Override -codeflash init --override-formatter-check -``` - - - - The `init` command will guide you through an interactive setup process, - including API key configuration, module selection, and GitHub App - installation. - - -**What it does:** - -- Prompts for your Python module directory (`module-root`) -- Prompts for your test directory (`tests-root`) -- Configures code formatter preferences -- Sets up telemetry preferences -- Optionally installs the Codeflash VS Code extension -- Optionally sets up GitHub Actions workflow - ---- - -## `codeflash init-actions` - -Set up GitHub Actions workflow for continuous optimization on every pull request. - -```bash -codeflash init-actions -``` - -**What it does:** - -- Creates a workflow file in `.github/workflows/` -- Opens a PR with the workflow configuration -- Requires the Codeflash GitHub App to be installed - - - This command requires the Codeflash GitHub App to be installed on your repository. If you haven't installed it, you'll be prompted with a link to do so. 
- - ---- - -## `codeflash vscode-install` - -Install the Codeflash extension for VS Code, Cursor, or Windsurf. - -```bash -codeflash vscode-install -``` - -**What it does:** - -- Detects which editor you're using (VS Code, Cursor, or Windsurf) -- Downloads and installs the appropriate extension -- Works with both Marketplace and Open VSX sources - - - This command is also run automatically during `codeflash init` if you choose to install the extension. - - ---- - -## `codeflash --verify-setup` - -Verify your Codeflash installation by running a sample optimization. - -```bash -codeflash --verify-setup -``` - -**What it does:** - -- Creates a temporary demo file -- Runs a sample optimization -- Verifies all components are working correctly -- Cleans up the demo file afterward - - - This command takes about 3 minutes to complete. It's a great way to ensure everything is set up correctly before optimizing your actual code. - - ---- - -## Next Steps - - - - Learn how to optimize functions - - - Complete flag reference - - - diff --git a/docs/cli-reference/tracing.mdx b/docs/cli-reference/tracing.mdx deleted file mode 100644 index c2394a6d5..000000000 --- a/docs/cli-reference/tracing.mdx +++ /dev/null @@ -1,213 +0,0 @@ ---- -title: "Tracing & Workflows" -description: "Trace script execution and optimize functions based on real-world usage" -icon: "route" -sidebarTitle: "Tracing & Workflows" -keywords: ["tracing", "optimize", "workflow", "replay test", "pytest"] ---- - -# Tracing & Workflows - -Trace Python script execution and optimize functions based on real-world usage patterns. - ---- - -## `codeflash optimize` - -Trace a Python script's execution and optimize functions based on real-world usage. 
- -```bash -codeflash optimize [script_args] -``` - - - - - ```bash - # Basic trace and optimize - codeflash optimize app.py - - # With script arguments - codeflash optimize process.py --input data.csv --output results.json - - # Custom trace output file - codeflash optimize app.py --output custom_trace.trace - - # With timeout (30 seconds) - codeflash optimize long_running_script.py --timeout 30 - - # Limit function trace count - codeflash optimize app.py --max-function-count 50 - - # Specify config file - codeflash optimize app.py --config-file-path pyproject.toml - - # Local only (no PR) - codeflash optimize app.py --no-pr - ``` - - - ```powershell - # Basic trace and optimize - codeflash optimize app.py - - # With script arguments - codeflash optimize process.py --input data.csv --output results.json - - # Custom trace output file - codeflash optimize app.py --output custom_trace.trace - - # With timeout (30 seconds) - codeflash optimize long_running_script.py --timeout 30 - - # Limit function trace count - codeflash optimize app.py --max-function-count 50 - - # Specify config file - codeflash optimize app.py --config-file-path pyproject.toml - - # Local only (no PR) - codeflash optimize app.py --no-pr - ``` - - - - -**How it works:** - -1. Runs your script with the provided arguments -2. Traces all function calls during execution -3. Identifies which functions are called and how often -4. Generates replay tests based on actual usage -5. Optimizes the traced functions - ---- - -## Trace with pytest - -Optimize functions called during pytest test execution. 
- - - - ```bash - # Trace pytest tests - codeflash optimize -m pytest tests/ - - # Trace specific test file - codeflash optimize -m pytest tests/test_core.py - - # With pytest arguments - codeflash optimize -m pytest tests/ -v --tb=short - ``` - - - ```powershell - # Trace pytest tests - codeflash optimize -m pytest tests\ - - # Trace specific test file - codeflash optimize -m pytest tests\test_core.py - - # With pytest arguments - codeflash optimize -m pytest tests\ -v --tb=short - ``` - - - - - Tracing pytest tests is great for optimizing functions that are heavily used in your test suite, ensuring optimizations work correctly with your existing tests. - - ---- - -## Trace Only (Generate Replay Tests) - -Create trace files and replay tests without running optimization. - - - - ```bash - # Trace only - generates replay test - codeflash optimize app.py --output trace_file.trace --trace-only - - # Then optimize using the replay test - codeflash --replay-test tests/test_app_py__replay_test_0.py --no-pr - ``` - - - ```powershell - # Trace only - generates replay test - codeflash optimize app.py --output trace_file.trace --trace-only - - # Then optimize using the replay test - codeflash --replay-test tests\test_app_py__replay_test_0.py --no-pr - ``` - - - - - **Replay test naming**: Files are named based on the traced script path. For - `src/app.py`, the replay test will be named like - `test_srcapp_py__replay_test_0.py`. - - -**Use cases for trace-only:** - -- Generate replay tests for later optimization -- Debug tracing issues without running full optimization -- Create reusable test cases from script execution - ---- - -## Replay Test Optimization - -Optimize functions using previously generated replay tests. 
- -```bash -codeflash --replay-test -``` - - - - - ```bash - # Optimize using replay test - codeflash --replay-test tests/test_app_py__replay_test_0.py --no-pr - - # Multiple replay tests - codeflash --replay-test tests/test_*.py --no-pr - ``` - - - ```powershell - # Optimize using replay test - codeflash --replay-test tests\test_app_py__replay_test_0.py --no-pr - - # Multiple replay tests (use Get-ChildItem for globbing) - codeflash --replay-test (Get-ChildItem tests\test_*.py) --no-pr - ``` - - - - ---- - -## Next Steps - - - - Learn about function optimization - - - Complete flag reference - - - diff --git a/docs/cli-reference/troubleshooting.mdx b/docs/cli-reference/troubleshooting.mdx deleted file mode 100644 index 7c8134242..000000000 --- a/docs/cli-reference/troubleshooting.mdx +++ /dev/null @@ -1,157 +0,0 @@ ---- -title: "CLI Troubleshooting" -description: "Solutions for common Codeflash CLI issues and errors" -icon: "wrench" -sidebarTitle: "Troubleshooting" -keywords: ["troubleshooting", "errors", "issues", "problems", "debugging"] ---- - -# CLI Troubleshooting - -Solutions for common issues when using the Codeflash CLI. - ---- - -## Common Issues - - - - **Problem**: Function not found because file is outside `module-root`. - - **Solution**: Ensure your file is within the `module-root` directory specified in `pyproject.toml`. - - ```bash - # Check your module-root - grep "module-root" pyproject.toml - - # Use the correct path (e.g., if module-root is "src") - codeflash --file src/myfile.py --function my_function --no-pr - ``` - - - - **Problem**: Using `--benchmark` without specifying benchmarks directory. - - **Solution**: Either add `benchmarks-root` to `pyproject.toml` or use the flag: - - ```bash - codeflash --file src/app.py --benchmark --benchmarks-root tests/benchmarks --no-pr - ``` - - - - **Problem**: Replay test filename doesn't match expected path. - - **Solution**: Replay tests include the module path in their name. 
Check the actual filename: - - ```bash - # Linux/macOS - ls tests/test_*replay*.py - - # Windows - dir tests\test_*replay*.py - ``` - - - Replay test files are named based on the traced script path. For `src/app.py`, - the replay test will be named like `test_srcapp_py__replay_test_0.py`. - - - - - **Problem**: PR creation fails due to missing GitHub App. - - **Solution**: Install the Codeflash GitHub App or use `--no-pr` for local optimization: - - ```bash - # Local optimization - codeflash --file src/app.py --function main --no-pr - - # Or install the GitHub App - # https://github.com/apps/codeflash-ai/installations/select_target - ``` - - - - **Problem**: Codeflash can't find your Python modules. - - **Solution**: - - 1. Verify `module-root` is correctly set in `pyproject.toml` - 2. Ensure you're running from the project root - 3. Check that your Python environment has all dependencies installed - - ```bash - # Verify module-root - cat pyproject.toml | grep module-root - - # Check Python path - python -c "import sys; print(sys.path)" - ``` - - - - **Problem**: Codeflash can't generate tests for your function. - - **Solution**: - - 1. Ensure your function has a return statement - 2. Check that the function is not a property or class method with special decorators - 3. Use `--no-gen-tests` to skip test generation and use existing tests only - - ```bash - codeflash --file src/app.py --function main --no-gen-tests --no-pr - ``` - - - - **Problem**: Optimization takes too long or times out. - - **Solution**: - - 1. Use `--verbose` to see what's happening - 2. For tracing, use `--timeout` to limit trace duration - 3. For large functions, consider breaking them down - - ```bash - # Limit trace time - codeflash optimize app.py --timeout 30 - - # See detailed progress - codeflash --file src/app.py --function main --verbose --no-pr - ``` - - - ---- - -## Getting Help - -If you're still experiencing issues: - -1. 
**Check the logs**: Use `--verbose` flag to see detailed output -2. **Verify setup**: Run `codeflash --verify-setup` to check your installation -3. **Check configuration**: Ensure `pyproject.toml` is correctly configured -4. **View help**: Run `codeflash --help` or `codeflash --help` - ---- - -## Next Steps - - - - Review setup and initialization - - - Complete flag reference - - - diff --git a/docs/docs.json b/docs/docs.json index 32d979600..579a8355c 100644 --- a/docs/docs.json +++ b/docs/docs.json @@ -18,35 +18,15 @@ { "tab": "Documentation", "groups": [ - { - "group": "🚀 Quickstart", - "pages": [ - "getting-started/local-installation", - "getting-started/javascript-installation" - ] - }, { "group": "🏠 Overview", "pages": ["index"] }, { - "group": "📖 Codeflash CLI", + "group": "🚀 Quickstart", "pages": [ - "cli-reference/index", - "cli-reference/setup", - "cli-reference/optimization", - "cli-reference/tracing", - "cli-reference/flags", - "cli-reference/troubleshooting" - ] - }, - { - "group": "🛠 IDE Extension", - "pages": [ - "editor-plugins/vscode/index", - "editor-plugins/vscode/features", - "editor-plugins/vscode/configuration", - "editor-plugins/vscode/troubleshooting" + "getting-started/local-installation", + "getting-started/javascript-installation" ] }, { @@ -65,6 +45,15 @@ "optimizing-with-codeflash/review-optimizations" ] }, + { + "group": "🛠 IDE Extension", + "pages": [ + "editor-plugins/vscode/index", + "editor-plugins/vscode/features", + "editor-plugins/vscode/configuration", + "editor-plugins/vscode/troubleshooting" + ] + }, { "group": "🧠 Core Concepts", "pages": [ diff --git a/docs/editor-plugins/vscode/configuration.mdx b/docs/editor-plugins/vscode/configuration.mdx index a19400d37..d8a113a2d 100644 --- a/docs/editor-plugins/vscode/configuration.mdx +++ b/docs/editor-plugins/vscode/configuration.mdx @@ -146,7 +146,4 @@ When configuration issues are detected, the extension displays clear error messa Complete pyproject.toml reference - - Command-line 
options - diff --git a/docs/editor-plugins/vscode/index.mdx b/docs/editor-plugins/vscode/index.mdx index f598fe594..8cf0f3b10 100644 --- a/docs/editor-plugins/vscode/index.mdx +++ b/docs/editor-plugins/vscode/index.mdx @@ -204,7 +204,6 @@ The extension works alongside the Codeflash CLI. You can: - **Use extension for interactive work** — Optimize individual functions as you code - **Mix both** — The extension picks up CLI results when you return to the editor -For CLI documentation, see the [Codeflash CLI](/cli-reference/index). --- @@ -220,9 +219,6 @@ For CLI documentation, see the [Codeflash CLI](/cli-reference/index). Fix common issues - - Command-line interface docs - --- diff --git a/docs/editor-plugins/vscode/troubleshooting.mdx b/docs/editor-plugins/vscode/troubleshooting.mdx index 7af38446c..09a408ec1 100644 --- a/docs/editor-plugins/vscode/troubleshooting.mdx +++ b/docs/editor-plugins/vscode/troubleshooting.mdx @@ -208,8 +208,5 @@ If you're still experiencing issues: Customize extension settings - - Command-line interface docs - diff --git a/tests/test_code_context_extractor.py b/tests/test_code_context_extractor.py index 0593c37bc..71db216e4 100644 --- a/tests/test_code_context_extractor.py +++ b/tests/test_code_context_extractor.py @@ -7,13 +7,19 @@ from collections import defaultdict from pathlib import Path import pytest -from codeflash.context.code_context_extractor import get_code_optimization_context, get_imported_class_definitions -from codeflash.models.models import CodeString, CodeStringsMarkdown -from codeflash.discovery.functions_to_optimize import FunctionToOptimize -from codeflash.models.models import FunctionParent -from codeflash.optimization.optimizer import Optimizer + +from codeflash.code_utils.code_extractor import GlobalAssignmentCollector, add_global_assignments from codeflash.code_utils.code_replacer import replace_functions_and_add_imports -from codeflash.code_utils.code_extractor import add_global_assignments, 
GlobalAssignmentCollector +from codeflash.context.code_context_extractor import ( + collect_names_from_annotation, + extract_imports_for_class, + get_code_optimization_context, + get_external_base_class_inits, + get_imported_class_definitions, +) +from codeflash.discovery.functions_to_optimize import FunctionToOptimize +from codeflash.models.models import CodeString, CodeStringsMarkdown, FunctionParent +from codeflash.optimization.optimizer import Optimizer class HelperClass: @@ -86,7 +92,10 @@ def test_code_replacement10() -> None: code_ctx = get_code_optimization_context(function_to_optimize=func_top_optimize, project_root_path=file_path.parent) qualified_names = {func.qualified_name for func in code_ctx.helper_functions} # HelperClass.__init__ is now tracked because HelperClass(self.name) instantiates the class - assert qualified_names == {"HelperClass.helper_method", "HelperClass.__init__"} # Nested method should not be in here + assert qualified_names == { + "HelperClass.helper_method", + "HelperClass.__init__", + } # Nested method should not be in here read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code hashing_context = code_ctx.hashing_code_context @@ -229,7 +238,7 @@ def test_bubble_sort_helper() -> None: read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code hashing_context = code_ctx.hashing_code_context - expected_read_write_context = f""" + expected_read_write_context = """ ```python:code_to_optimize/code_directories/retriever/bubble_sort_with_math.py import math @@ -1103,7 +1112,9 @@ class HelperClass: code_ctx = get_code_optimization_context(function_to_optimize, opt.args.project_root) # the global x variable shouldn't be included in any context type - assert code_ctx.read_writable_code.flat == '''# file: test_code.py + assert ( + code_ctx.read_writable_code.flat + == '''# file: test_code.py class MyClass: def __init__(self): self.x = 1 @@ -1118,7 +1129,10 
@@ class HelperClass: def helper_method(self): return self.x ''' - assert code_ctx.testgen_context.flat == '''# file: test_code.py + ) + assert ( + code_ctx.testgen_context.flat + == '''# file: test_code.py class MyClass: """A class with a helper method. """ def __init__(self): @@ -1138,6 +1152,7 @@ class HelperClass: def helper_method(self): return self.x ''' + ) def test_repo_helper() -> None: @@ -2348,9 +2363,7 @@ def standalone_function(): assert '"""Helper method with docstring."""' not in hashing_context, ( "Docstrings should be removed from helper functions" ) - assert '"""Process data method."""' not in hashing_context, ( - "Docstrings should be removed from helper class methods" - ) + assert '"""Process data method."""' not in hashing_context, "Docstrings should be removed from helper class methods" def test_hashing_code_context_with_nested_classes(tmp_path: Path) -> None: @@ -2588,16 +2601,21 @@ def test_circular_deps(): optimized_code = Path(path_to_root / "optimized.py").read_text(encoding="utf-8") content = Path(file_abs_path).read_text(encoding="utf-8") new_code = replace_functions_and_add_imports( - source_code= add_global_assignments(optimized_code, content), - function_names= ["ApiClient.get_console_url"], - optimized_code= optimized_code, - module_abspath= Path(file_abs_path), - preexisting_objects= {('ApiClient', ()), ('get_console_url', (FunctionParent(name='ApiClient', type='ClassDef'),))}, - project_root_path= Path(path_to_root), + source_code=add_global_assignments(optimized_code, content), + function_names=["ApiClient.get_console_url"], + optimized_code=optimized_code, + module_abspath=Path(file_abs_path), + preexisting_objects={ + ("ApiClient", ()), + ("get_console_url", (FunctionParent(name="ApiClient", type="ClassDef"),)), + }, + project_root_path=Path(path_to_root), ) assert "import ApiClient" not in new_code, "Error: Circular dependency found" assert "import urllib.parse" in new_code, "Make sure imports for optimization global 
assignments exist" + + def test_global_assignment_collector_with_async_function(): """Test GlobalAssignmentCollector correctly identifies global assignments outside async functions.""" import libcst as cst @@ -2745,6 +2763,380 @@ FINAL_ASSIGNMENT = {"data": "value"} assert collector.assignment_order == expected_order +def test_global_assignment_collector_annotated_assignments(): + """Test GlobalAssignmentCollector correctly handles annotated assignments (AnnAssign).""" + import libcst as cst + + source_code = """ +# Regular global assignment +REGULAR_VAR = "regular" + +# Annotated global assignments +TYPED_VAR: str = "typed" +CACHE: dict[str, int] = {} +SENTINEL: object = object() + +# Annotated without value (type declaration only) - should NOT be collected +DECLARED_ONLY: int + +def some_function(): + # Annotated assignment inside function - should not be collected + local_typed: str = "local" + return local_typed + +class SomeClass: + # Class-level annotated assignment - should not be collected + class_attr: str = "class" + +# Another regular assignment +FINAL_VAR = 123 +""" + + tree = cst.parse_module(source_code) + collector = GlobalAssignmentCollector() + tree.visit(collector) + + # Should collect both regular and annotated global assignments with values + assert len(collector.assignments) == 5 + assert "REGULAR_VAR" in collector.assignments + assert "TYPED_VAR" in collector.assignments + assert "CACHE" in collector.assignments + assert "SENTINEL" in collector.assignments + assert "FINAL_VAR" in collector.assignments + + # Should not collect type declarations without values + assert "DECLARED_ONLY" not in collector.assignments + + # Should not collect assignments from inside functions or classes + assert "local_typed" not in collector.assignments + assert "class_attr" not in collector.assignments + + # Verify correct order + expected_order = ["REGULAR_VAR", "TYPED_VAR", "CACHE", "SENTINEL", "FINAL_VAR"] + assert collector.assignment_order == expected_order + 
+ +def test_global_function_collector(): + """Test GlobalFunctionCollector correctly collects module-level function definitions.""" + import libcst as cst + + from codeflash.code_utils.code_extractor import GlobalFunctionCollector + + source_code = """ +# Module-level functions +def helper_function(): + return "helper" + +def another_helper(x: int) -> str: + return str(x) + +class SomeClass: + def method(self): + # This is a method, not a module-level function + return "method" + + def another_method(self): + # Also a method + def nested_function(): + # Nested function inside method + return "nested" + return nested_function() + +def final_function(): + def inner_function(): + # This is a nested function, not module-level + return "inner" + return inner_function() +""" + + tree = cst.parse_module(source_code) + collector = GlobalFunctionCollector() + tree.visit(collector) + + # Should collect only module-level functions + assert len(collector.functions) == 3 + assert "helper_function" in collector.functions + assert "another_helper" in collector.functions + assert "final_function" in collector.functions + + # Should not collect methods or nested functions + assert "method" not in collector.functions + assert "another_method" not in collector.functions + assert "nested_function" not in collector.functions + assert "inner_function" not in collector.functions + + # Verify correct order + expected_order = ["helper_function", "another_helper", "final_function"] + assert collector.function_order == expected_order + + +def test_add_global_assignments_with_new_functions(): + """Test add_global_assignments correctly adds new module-level functions.""" + source_code = """\ +from functools import lru_cache + +class SkyvernPage: + @staticmethod + def action_wrap(action): + return _get_decorator_for_action(action) + +@lru_cache(maxsize=None) +def _get_decorator_for_action(action): + def decorator(fn): + return fn + return decorator +""" + + destination_code = """\ +from 
functools import lru_cache + +class SkyvernPage: + @staticmethod + def action_wrap(action): + # Original implementation + return action +""" + + expected = """\ +from functools import lru_cache + +class SkyvernPage: + @staticmethod + def action_wrap(action): + # Original implementation + return action + + +@lru_cache(maxsize=None) +def _get_decorator_for_action(action): + def decorator(fn): + return fn + return decorator +""" + + result = add_global_assignments(source_code, destination_code) + assert result == expected + + +def test_add_global_assignments_does_not_duplicate_existing_functions(): + """Test add_global_assignments does not duplicate functions that already exist in destination.""" + source_code = """\ +def helper(): + return "source_helper" + +def existing_function(): + return "source_existing" +""" + + destination_code = """\ +def existing_function(): + return "dest_existing" + +class MyClass: + pass +""" + + expected = """\ +def existing_function(): + return "dest_existing" + +class MyClass: + pass + +def helper(): + return "source_helper" +""" + + result = add_global_assignments(source_code, destination_code) + assert result == expected + + +def test_add_global_assignments_with_decorated_functions(): + """Test add_global_assignments correctly adds decorated functions.""" + source_code = """\ +from functools import lru_cache +from typing import Callable + +_LOCAL_CACHE: dict[str, int] = {} + +@lru_cache(maxsize=128) +def cached_helper(x: int) -> int: + return x * 2 + +def regular_helper(): + return "regular" +""" + + destination_code = """\ +from typing import Any + +class MyClass: + def method(self): + return cached_helper(5) +""" + + expected = """\ +from typing import Any + +_LOCAL_CACHE: dict[str, int] = {} + +class MyClass: + def method(self): + return cached_helper(5) + + +@lru_cache(maxsize=128) +def cached_helper(x: int) -> int: + return x * 2 + + +def regular_helper(): + return "regular" +""" + + result = add_global_assignments(source_code, 
destination_code) + assert result == expected + + +def test_add_global_assignments_references_class_defined_in_module(): + """Test that global assignments referencing classes are placed after those class definitions. + + This test verifies the fix for a bug where LLM-generated optimization code like: + _REIFIERS = {MessageKind.XXX: lambda d: ...} + was placed BEFORE the MessageKind class definition, causing NameError at module load. + + The fix ensures that new global assignments are inserted AFTER all class/function + definitions in the module, so they can safely reference any class defined in the module. + """ + source_code = """\ +import enum + +class MessageKind(enum.StrEnum): + ASK = "ask" + REPLY = "reply" + +_MESSAGE_HANDLERS = { + MessageKind.ASK: lambda: "ask handler", + MessageKind.REPLY: lambda: "reply handler", +} + +def handle_message(kind): + return _MESSAGE_HANDLERS[kind]() +""" + + destination_code = """\ +import enum + +class MessageKind(enum.StrEnum): + ASK = "ask" + REPLY = "reply" + +def handle_message(kind): + if kind == MessageKind.ASK: + return "ask" + return "reply" +""" + + # Global assignments are now inserted AFTER class/function definitions + # to ensure they can reference classes defined in the module + expected = """\ +import enum + +class MessageKind(enum.StrEnum): + ASK = "ask" + REPLY = "reply" + +def handle_message(kind): + if kind == MessageKind.ASK: + return "ask" + return "reply" + +_MESSAGE_HANDLERS = { + MessageKind.ASK: lambda: "ask handler", + MessageKind.REPLY: lambda: "reply handler", +} +""" + + result = add_global_assignments(source_code, destination_code) + assert result == expected + + +def test_add_global_assignments_function_calls_after_function_definitions(): + """Test that global function calls are placed after the functions they reference. 
+ + This test verifies the fix for a bug where LLM-generated optimization code like: + def _register(kind, factory): + _factories[kind] = factory + + _register(MessageKind.ASK, lambda: "ask") + + would have the _register(...) calls placed BEFORE the _register function definition, + causing NameError at module load time. + + The fix ensures that new global statements (like function calls) are inserted AFTER + all class/function definitions, so they can safely reference any function defined in + the module. + """ + source_code = """\ +import enum + +class MessageKind(enum.StrEnum): + ASK = "ask" + REPLY = "reply" + +_factories = {} + +def _register(kind, factory): + _factories[kind] = factory + +_register(MessageKind.ASK, lambda: "ask handler") +_register(MessageKind.REPLY, lambda: "reply handler") + +def handle_message(kind): + return _factories[kind]() +""" + + destination_code = """\ +import enum + +class MessageKind(enum.StrEnum): + ASK = "ask" + REPLY = "reply" + +def handle_message(kind): + if kind == MessageKind.ASK: + return "ask" + return "reply" +""" + + expected = """\ +import enum + +_factories = {} + +class MessageKind(enum.StrEnum): + ASK = "ask" + REPLY = "reply" + +def handle_message(kind): + if kind == MessageKind.ASK: + return "ask" + return "reply" + + +def _register(kind, factory): + _factories[kind] = factory + + +_register(MessageKind.ASK, lambda: "ask handler") + +_register(MessageKind.REPLY, lambda: "reply handler") +""" + + result = add_global_assignments(source_code, destination_code) + assert result == expected + + def test_class_instantiation_includes_init_as_helper(tmp_path: Path) -> None: """Test that when a class is instantiated, its __init__ method is tracked as a helper. 
@@ -2785,11 +3177,7 @@ def target_function(): ) ) function_to_optimize = FunctionToOptimize( - function_name="target_function", - file_path=file_path, - parents=[], - starting_line=None, - ending_line=None, + function_name="target_function", file_path=file_path, parents=[], starting_line=None, ending_line=None ) code_ctx = get_code_optimization_context(function_to_optimize, opt.args.project_root) @@ -2803,15 +3191,11 @@ def target_function(): # The testgen context should contain the class with __init__ (critical for LLM to know constructor) testgen_context = code_ctx.testgen_context.markdown assert "class DataDumper:" in testgen_context, "DataDumper class should be in testgen context" - assert "def __init__(self, data):" in testgen_context, ( - "__init__ method should be included in testgen context" - ) + assert "def __init__(self, data):" in testgen_context, "__init__ method should be included in testgen context" # The hashing context should NOT contain __init__ (excluded for stability) hashing_context = code_ctx.hashing_code_context - assert "__init__" not in hashing_context, ( - "__init__ should NOT be in hashing context (excluded for hash stability)" - ) + assert "__init__" not in hashing_context, "__init__ should NOT be in hashing context (excluded for hash stability)" def test_class_instantiation_preserves_full_class_in_testgen(tmp_path: Path) -> None: @@ -2865,11 +3249,7 @@ def dump_layout(layout_type, layout): ) ) function_to_optimize = FunctionToOptimize( - function_name="dump_layout", - file_path=file_path, - parents=[], - starting_line=None, - ending_line=None, + function_name="dump_layout", file_path=file_path, parents=[], starting_line=None, ending_line=None ) code_ctx = get_code_optimization_context(function_to_optimize, opt.args.project_root) @@ -2879,9 +3259,7 @@ def dump_layout(layout_type, layout): assert "ObjectDetectionLayoutDumper.__init__" in qualified_names, ( "ObjectDetectionLayoutDumper.__init__ should be tracked" ) - assert 
"LayoutDumper.__init__" in qualified_names, ( - "LayoutDumper.__init__ should be tracked" - ) + assert "LayoutDumper.__init__" in qualified_names, "LayoutDumper.__init__ should be tracked" # The testgen context should include both classes with their __init__ methods testgen_context = code_ctx.testgen_context.markdown @@ -2891,9 +3269,7 @@ def dump_layout(layout_type, layout): ) # Both __init__ methods should be in the testgen context (so LLM knows constructor signatures) - assert testgen_context.count("def __init__") >= 2, ( - "Both __init__ methods should be in testgen context" - ) + assert testgen_context.count("def __init__") >= 2, "Both __init__ methods should be in testgen context" def test_get_imported_class_definitions_extracts_project_classes(tmp_path: Path) -> None: @@ -2929,7 +3305,7 @@ class Text(Element): elements_path.write_text(elements_code, encoding="utf-8") # Create another module that imports from elements - chunking_code = ''' + chunking_code = """ from mypackage.elements import Element class PreChunk: @@ -2939,14 +3315,12 @@ class PreChunk: class Accumulator: def will_fit(self, chunk: PreChunk) -> bool: return True -''' +""" chunking_path = package_dir / "chunking.py" chunking_path.write_text(chunking_code, encoding="utf-8") # Create CodeStringsMarkdown from the chunking module (simulating testgen context) - context = CodeStringsMarkdown( - code_strings=[CodeString(code=chunking_code, file_path=chunking_path)] - ) + context = CodeStringsMarkdown(code_strings=[CodeString(code=chunking_code, file_path=chunking_path)]) # Call get_imported_class_definitions result = get_imported_class_definitions(context, tmp_path) @@ -2970,16 +3344,16 @@ def test_get_imported_class_definitions_skips_existing_definitions(tmp_path: Pat (package_dir / "__init__.py").write_text("", encoding="utf-8") # Create a module with a class definition - elements_code = ''' + elements_code = """ class Element: def __init__(self, text: str): self.text = text -''' +""" elements_path 
= package_dir / "elements.py" elements_path.write_text(elements_code, encoding="utf-8") # Create code that imports Element but also redefines it locally - code_with_local_def = ''' + code_with_local_def = """ from mypackage.elements import Element # Local redefinition (this happens when LLM redefines classes) @@ -2990,13 +3364,11 @@ class Element: class User: def process(self, elem: Element): pass -''' +""" code_path = package_dir / "user.py" code_path.write_text(code_with_local_def, encoding="utf-8") - context = CodeStringsMarkdown( - code_strings=[CodeString(code=code_with_local_def, file_path=code_path)] - ) + context = CodeStringsMarkdown(code_strings=[CodeString(code=code_with_local_def, file_path=code_path)]) # Call get_imported_class_definitions result = get_imported_class_definitions(context, tmp_path) @@ -3013,7 +3385,7 @@ def test_get_imported_class_definitions_skips_third_party(tmp_path: Path) -> Non (package_dir / "__init__.py").write_text("", encoding="utf-8") # Code with stdlib/third-party imports - code = ''' + code = """ from pathlib import Path from typing import Optional from dataclasses import dataclass @@ -3021,13 +3393,11 @@ from dataclasses import dataclass class MyClass: def __init__(self, path: Path): self.path = path -''' +""" code_path = package_dir / "main.py" code_path.write_text(code, encoding="utf-8") - context = CodeStringsMarkdown( - code_strings=[CodeString(code=code, file_path=code_path)] - ) + context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)]) # Call get_imported_class_definitions result = get_imported_class_definitions(context, tmp_path) @@ -3044,7 +3414,7 @@ def test_get_imported_class_definitions_handles_multiple_imports(tmp_path: Path) (package_dir / "__init__.py").write_text("", encoding="utf-8") # Create a module with multiple class definitions - types_code = ''' + types_code = """ class TypeA: def __init__(self, value: int): self.value = value @@ -3056,24 +3426,22 @@ class TypeB: class 
TypeC: def __init__(self): pass -''' +""" types_path = package_dir / "types.py" types_path.write_text(types_code, encoding="utf-8") # Create code that imports multiple classes - code = ''' + code = """ from mypackage.types import TypeA, TypeB class Processor: def process(self, a: TypeA, b: TypeB): pass -''' +""" code_path = package_dir / "processor.py" code_path.write_text(code, encoding="utf-8") - context = CodeStringsMarkdown( - code_strings=[CodeString(code=code, file_path=code_path)] - ) + context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)]) # Call get_imported_class_definitions result = get_imported_class_definitions(context, tmp_path) @@ -3085,3 +3453,1200 @@ class Processor: assert "class TypeA" in all_extracted_code, "Should contain TypeA class" assert "class TypeB" in all_extracted_code, "Should contain TypeB class" assert "class TypeC" not in all_extracted_code, "Should NOT contain TypeC (not imported)" + + +def test_get_imported_class_definitions_includes_dataclass_decorators(tmp_path: Path) -> None: + """Test that get_imported_class_definitions includes decorators when extracting dataclasses.""" + # Create a package structure + package_dir = tmp_path / "mypackage" + package_dir.mkdir() + (package_dir / "__init__.py").write_text("", encoding="utf-8") + + # Create a module with dataclass definitions (like LLMConfig in skyvern) + models_code = """from dataclasses import dataclass, field +from typing import Optional + +@dataclass(frozen=True) +class LLMConfigBase: + model_name: str + required_env_vars: list[str] + supports_vision: bool + add_assistant_prefix: bool + +@dataclass(frozen=True) +class LLMConfig(LLMConfigBase): + litellm_params: Optional[dict] = field(default=None) + max_tokens: int | None = None +""" + models_path = package_dir / "models.py" + models_path.write_text(models_code, encoding="utf-8") + + # Create code that imports the dataclass + code = """from mypackage.models import LLMConfig + +class 
ConfigRegistry: + def get_config(self) -> LLMConfig: + pass +""" + code_path = package_dir / "registry.py" + code_path.write_text(code, encoding="utf-8") + + context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)]) + + # Call get_imported_class_definitions + result = get_imported_class_definitions(context, tmp_path) + + # Should extract both LLMConfigBase (base class) and LLMConfig + assert len(result.code_strings) == 2, "Should extract both LLMConfig and its base class LLMConfigBase" + + # Combine extracted code to check for all required elements + all_extracted_code = "\n".join(cs.code for cs in result.code_strings) + + # Verify the base class is extracted first (for proper inheritance understanding) + base_class_idx = all_extracted_code.find("class LLMConfigBase") + derived_class_idx = all_extracted_code.find("class LLMConfig(") + assert base_class_idx < derived_class_idx, "Base class should appear before derived class" + + # Verify both classes include @dataclass decorators + assert all_extracted_code.count("@dataclass(frozen=True)") == 2, ( + "Should include @dataclass decorator for both classes" + ) + assert "class LLMConfig" in all_extracted_code, "Should contain LLMConfig class definition" + assert "class LLMConfigBase" in all_extracted_code, "Should contain LLMConfigBase class definition" + + # Verify imports are included for dataclass-related items + assert "from dataclasses import" in all_extracted_code, "Should include dataclasses import" + + +def test_get_imported_class_definitions_extracts_imports_for_decorated_classes(tmp_path: Path) -> None: + """Test that extract_imports_for_class includes decorator and type annotation imports.""" + # Create a package structure + package_dir = tmp_path / "mypackage" + package_dir.mkdir() + (package_dir / "__init__.py").write_text("", encoding="utf-8") + + # Create a module with decorated class that uses field() and various type annotations + models_code = """from dataclasses import 
dataclass, field +from typing import Optional, List + +@dataclass +class Config: + name: str + values: List[int] = field(default_factory=list) + description: Optional[str] = None +""" + models_path = package_dir / "models.py" + models_path.write_text(models_code, encoding="utf-8") + + # Create code that imports the class + code = """from mypackage.models import Config + +def create_config() -> Config: + return Config(name="test") +""" + code_path = package_dir / "main.py" + code_path.write_text(code, encoding="utf-8") + + context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)]) + + result = get_imported_class_definitions(context, tmp_path) + + assert len(result.code_strings) == 1, "Should extract Config class" + extracted_code = result.code_strings[0].code + + # The extracted code should include the decorator + assert "@dataclass" in extracted_code, "Should include @dataclass decorator" + # The imports should include dataclass and field + assert "from dataclasses import" in extracted_code, "Should include dataclasses import for decorator" + + +class TestCollectNamesFromAnnotation: + """Tests for the collect_names_from_annotation helper function.""" + + def test_simple_name(self): + """Test extracting a simple type name.""" + import ast + + code = "def f(x: MyClass): pass" + annotation = ast.parse(code).body[0].args.args[0].annotation + names: set[str] = set() + collect_names_from_annotation(annotation, names) + assert "MyClass" in names + + def test_subscript_type(self): + """Test extracting names from generic types like List[int].""" + import ast + + code = "def f(x: List[int]): pass" + annotation = ast.parse(code).body[0].args.args[0].annotation + names: set[str] = set() + collect_names_from_annotation(annotation, names) + assert "List" in names + assert "int" in names + + def test_optional_type(self): + """Test extracting names from Optional[MyClass].""" + import ast + + code = "def f(x: Optional[MyClass]): pass" + annotation = 
ast.parse(code).body[0].args.args[0].annotation + names: set[str] = set() + collect_names_from_annotation(annotation, names) + assert "Optional" in names + assert "MyClass" in names + + def test_union_type_with_pipe(self): + """Test extracting names from union types with | syntax.""" + import ast + + code = "def f(x: int | str | None): pass" + annotation = ast.parse(code).body[0].args.args[0].annotation + names: set[str] = set() + collect_names_from_annotation(annotation, names) + # int | str | None becomes BinOp nodes + assert "int" in names + assert "str" in names + + def test_nested_generic_types(self): + """Test extracting names from nested generics like Dict[str, List[MyClass]].""" + import ast + + code = "def f(x: Dict[str, List[MyClass]]): pass" + annotation = ast.parse(code).body[0].args.args[0].annotation + names: set[str] = set() + collect_names_from_annotation(annotation, names) + assert "Dict" in names + assert "str" in names + assert "List" in names + assert "MyClass" in names + + def test_tuple_annotation(self): + """Test extracting names from tuple type hints.""" + import ast + + code = "def f(x: tuple[int, str, MyClass]): pass" + annotation = ast.parse(code).body[0].args.args[0].annotation + names: set[str] = set() + collect_names_from_annotation(annotation, names) + assert "tuple" in names + assert "int" in names + assert "str" in names + assert "MyClass" in names + + +class TestExtractImportsForClass: + """Tests for the extract_imports_for_class helper function.""" + + def test_extracts_base_class_imports(self): + """Test that base class imports are extracted.""" + import ast + + module_source = """from abc import ABC +from mypackage import BaseClass + +class MyClass(BaseClass, ABC): + pass +""" + tree = ast.parse(module_source) + class_node = next(n for n in ast.walk(tree) if isinstance(n, ast.ClassDef)) + result = extract_imports_for_class(tree, class_node, module_source) + assert "from abc import ABC" in result + assert "from mypackage import 
BaseClass" in result + + def test_extracts_decorator_imports(self): + """Test that decorator imports are extracted.""" + import ast + + module_source = """from dataclasses import dataclass +from functools import lru_cache + +@dataclass +class MyClass: + name: str +""" + tree = ast.parse(module_source) + class_node = next(n for n in ast.walk(tree) if isinstance(n, ast.ClassDef)) + result = extract_imports_for_class(tree, class_node, module_source) + assert "from dataclasses import dataclass" in result + + def test_extracts_type_annotation_imports(self): + """Test that type annotation imports are extracted.""" + import ast + + module_source = """from typing import Optional, List +from mypackage.models import Config + +@dataclass +class MyClass: + config: Optional[Config] + items: List[str] +""" + tree = ast.parse(module_source) + class_node = next(n for n in ast.walk(tree) if isinstance(n, ast.ClassDef)) + result = extract_imports_for_class(tree, class_node, module_source) + assert "from typing import Optional, List" in result + assert "from mypackage.models import Config" in result + + def test_extracts_field_function_imports(self): + """Test that field() function imports are extracted for dataclasses.""" + import ast + + module_source = """from dataclasses import dataclass, field +from typing import List + +@dataclass +class MyClass: + items: List[str] = field(default_factory=list) +""" + tree = ast.parse(module_source) + class_node = next(n for n in ast.walk(tree) if isinstance(n, ast.ClassDef)) + result = extract_imports_for_class(tree, class_node, module_source) + assert "from dataclasses import dataclass, field" in result + + def test_no_duplicate_imports(self): + """Test that duplicate imports are not included.""" + import ast + + module_source = """from typing import Optional + +@dataclass +class MyClass: + field1: Optional[str] + field2: Optional[int] +""" + tree = ast.parse(module_source) + class_node = next(n for n in ast.walk(tree) if isinstance(n, 
ast.ClassDef)) + result = extract_imports_for_class(tree, class_node, module_source) + # Should only have one import line even though Optional is used twice + assert result.count("from typing import Optional") == 1 + + +def test_get_imported_class_definitions_multiple_decorators(tmp_path: Path) -> None: + """Test that classes with multiple decorators are extracted correctly.""" + package_dir = tmp_path / "mypackage" + package_dir.mkdir() + (package_dir / "__init__.py").write_text("", encoding="utf-8") + + models_code = """from dataclasses import dataclass +from functools import total_ordering + +@total_ordering +@dataclass +class OrderedConfig: + name: str + priority: int + + def __lt__(self, other): + return self.priority < other.priority +""" + models_path = package_dir / "models.py" + models_path.write_text(models_code, encoding="utf-8") + + code = """from mypackage.models import OrderedConfig + +def sort_configs(configs: list[OrderedConfig]) -> list[OrderedConfig]: + return sorted(configs) +""" + code_path = package_dir / "main.py" + code_path.write_text(code, encoding="utf-8") + + context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)]) + + result = get_imported_class_definitions(context, tmp_path) + + assert len(result.code_strings) == 1 + extracted_code = result.code_strings[0].code + + # Both decorators should be included + assert "@total_ordering" in extracted_code, "Should include @total_ordering decorator" + assert "@dataclass" in extracted_code, "Should include @dataclass decorator" + assert "class OrderedConfig" in extracted_code + + +def test_get_imported_class_definitions_extracts_multilevel_inheritance(tmp_path: Path) -> None: + """Test that base classes are recursively extracted for multi-level inheritance. + + This is critical for understanding dataclass constructor signatures, as fields + from parent classes become required positional arguments in child classes. 
+ """ + # Create a package structure + package_dir = tmp_path / "mypackage" + package_dir.mkdir() + (package_dir / "__init__.py").write_text("", encoding="utf-8") + + # Create a module with multi-level inheritance like skyvern's LLM models: + # GrandParent -> Parent -> Child + models_code = '''from dataclasses import dataclass, field +from typing import Optional, Literal + +@dataclass(frozen=True) +class GrandParentConfig: + """Base config with common fields.""" + model_name: str + required_env_vars: list[str] + +@dataclass(frozen=True) +class ParentConfig(GrandParentConfig): + """Intermediate config adding vision support.""" + supports_vision: bool + add_assistant_prefix: bool + +@dataclass(frozen=True) +class ChildConfig(ParentConfig): + """Full config with optional parameters.""" + litellm_params: Optional[dict] = field(default=None) + max_tokens: int | None = None + temperature: float | None = 0.7 + +@dataclass(frozen=True) +class RouterConfig(ParentConfig): + """Router config branching from ParentConfig.""" + model_list: list + main_model_group: str + routing_strategy: Literal["simple", "least-busy"] = "simple" +''' + models_path = package_dir / "models.py" + models_path.write_text(models_code, encoding="utf-8") + + # Create code that imports only the child classes (not the base classes) + code = """from mypackage.models import ChildConfig, RouterConfig + +class ConfigRegistry: + def get_child_config(self) -> ChildConfig: + pass + + def get_router_config(self) -> RouterConfig: + pass +""" + code_path = package_dir / "registry.py" + code_path.write_text(code, encoding="utf-8") + + context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)]) + + # Call get_imported_class_definitions + result = get_imported_class_definitions(context, tmp_path) + + # Should extract 4 classes: GrandParentConfig, ParentConfig, ChildConfig, RouterConfig + # (all classes needed to understand the full inheritance hierarchy) + assert len(result.code_strings) 
== 4, ( + f"Should extract 4 classes (GrandParent, Parent, Child, Router), got {len(result.code_strings)}" + ) + + # Combine extracted code + all_extracted_code = "\n".join(cs.code for cs in result.code_strings) + + # Verify all classes are extracted + assert "class GrandParentConfig" in all_extracted_code, "Should extract GrandParentConfig base class" + assert "class ParentConfig(GrandParentConfig)" in all_extracted_code, "Should extract ParentConfig" + assert "class ChildConfig(ParentConfig)" in all_extracted_code, "Should extract ChildConfig" + assert "class RouterConfig(ParentConfig)" in all_extracted_code, "Should extract RouterConfig" + + # Verify classes are ordered correctly (base classes before derived) + grandparent_idx = all_extracted_code.find("class GrandParentConfig") + parent_idx = all_extracted_code.find("class ParentConfig(") + child_idx = all_extracted_code.find("class ChildConfig(") + router_idx = all_extracted_code.find("class RouterConfig(") + + assert grandparent_idx < parent_idx, "GrandParentConfig should appear before ParentConfig" + assert parent_idx < child_idx, "ParentConfig should appear before ChildConfig" + assert parent_idx < router_idx, "ParentConfig should appear before RouterConfig" + + # Verify the critical fields are visible for constructor understanding + assert "model_name: str" in all_extracted_code, "Should include model_name field from GrandParent" + assert "required_env_vars: list[str]" in all_extracted_code, "Should include required_env_vars field" + assert "supports_vision: bool" in all_extracted_code, "Should include supports_vision field from Parent" + assert "litellm_params:" in all_extracted_code, "Should include litellm_params field from Child" + assert "model_list: list" in all_extracted_code, "Should include model_list field from Router" + + +def test_get_external_base_class_inits_extracts_userdict(tmp_path: Path) -> None: + """Extracts __init__ from collections.UserDict when a class inherits from it.""" + code = 
"""from collections import UserDict + +class MyCustomDict(UserDict): + pass +""" + code_path = tmp_path / "mydict.py" + code_path.write_text(code, encoding="utf-8") + + context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)]) + result = get_external_base_class_inits(context, tmp_path) + + assert len(result.code_strings) == 1 + code_string = result.code_strings[0] + + expected_code = """\ +class UserDict: + def __init__(self, dict=None, /, **kwargs): + self.data = {} + if dict is not None: + self.update(dict) + if kwargs: + self.update(kwargs) +""" + assert code_string.code == expected_code + assert code_string.file_path.as_posix().endswith("collections/__init__.py") + + +def test_get_external_base_class_inits_skips_project_classes(tmp_path: Path) -> None: + """Returns empty when base class is from the project, not external.""" + child_code = """from base import ProjectBase + +class Child(ProjectBase): + pass +""" + child_path = tmp_path / "child.py" + child_path.write_text(child_code, encoding="utf-8") + + context = CodeStringsMarkdown(code_strings=[CodeString(code=child_code, file_path=child_path)]) + result = get_external_base_class_inits(context, tmp_path) + + assert result.code_strings == [] + + +def test_get_external_base_class_inits_skips_builtins(tmp_path: Path) -> None: + """Returns empty for builtin classes like list that have no inspectable source.""" + code = """class MyList(list): + pass +""" + code_path = tmp_path / "mylist.py" + code_path.write_text(code, encoding="utf-8") + + context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)]) + result = get_external_base_class_inits(context, tmp_path) + + assert result.code_strings == [] + + +def test_get_external_base_class_inits_deduplicates(tmp_path: Path) -> None: + """Extracts the same external base class only once even when inherited multiple times.""" + code = """from collections import UserDict + +class MyDict1(UserDict): + pass + +class 
MyDict2(UserDict): + pass +""" + code_path = tmp_path / "mydicts.py" + code_path.write_text(code, encoding="utf-8") + + context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)]) + result = get_external_base_class_inits(context, tmp_path) + + assert len(result.code_strings) == 1 + expected_code = """\ +class UserDict: + def __init__(self, dict=None, /, **kwargs): + self.data = {} + if dict is not None: + self.update(dict) + if kwargs: + self.update(kwargs) +""" + assert result.code_strings[0].code == expected_code + + +def test_get_external_base_class_inits_empty_when_no_inheritance(tmp_path: Path) -> None: + """Returns empty when there are no external base classes.""" + code = """class SimpleClass: + pass +""" + code_path = tmp_path / "simple.py" + code_path.write_text(code, encoding="utf-8") + + context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)]) + result = get_external_base_class_inits(context, tmp_path) + + assert result.code_strings == [] + + +@pytest.mark.skipif(sys.version_info < (3, 11), reason="enum.StrEnum requires Python 3.11+") +def test_dependency_classes_kept_in_read_writable_context(tmp_path: Path) -> None: + """Tests that classes used as dependencies (enums, dataclasses) are kept in read-writable context. + + This test verifies that when a function uses classes like enums or dataclasses + as types or in match statements, those classes are included in the optimization + context, even though they don't contain any target functions. 
+ """ + code = ''' +import dataclasses +import enum +import typing as t + + +class MessageKind(enum.StrEnum): + ASK_FOR_CLIPBOARD_RESPONSE = "ask-for-clipboard-response" + BEGIN_EXFILTRATION = "begin-exfiltration" + + +@dataclasses.dataclass +class Message: + kind: str + + +@dataclasses.dataclass +class MessageInAskForClipboardResponse(Message): + kind: t.Literal[MessageKind.ASK_FOR_CLIPBOARD_RESPONSE] = MessageKind.ASK_FOR_CLIPBOARD_RESPONSE + text: str = "" + + +@dataclasses.dataclass +class MessageInBeginExfiltration(Message): + kind: t.Literal[MessageKind.BEGIN_EXFILTRATION] = MessageKind.BEGIN_EXFILTRATION + + +MessageIn = ( + MessageInAskForClipboardResponse + | MessageInBeginExfiltration +) + + +def reify_channel_message(data: dict) -> MessageIn: + kind = data.get("kind", None) + + match kind: + case MessageKind.ASK_FOR_CLIPBOARD_RESPONSE: + text = data.get("text") or "" + return MessageInAskForClipboardResponse(text=text) + case MessageKind.BEGIN_EXFILTRATION: + return MessageInBeginExfiltration() + case _: + raise ValueError(f"Unknown message kind: '{kind}'") +''' + code_path = tmp_path / "message.py" + code_path.write_text(code, encoding="utf-8") + + func_to_optimize = FunctionToOptimize( + function_name="reify_channel_message", + file_path=code_path, + parents=[], + ) + + code_ctx = get_code_optimization_context( + function_to_optimize=func_to_optimize, + project_root_path=tmp_path, + ) + + expected_read_writable = """ +```python:message.py +import dataclasses +import enum +import typing as t + +class MessageKind(enum.StrEnum): + ASK_FOR_CLIPBOARD_RESPONSE = "ask-for-clipboard-response" + BEGIN_EXFILTRATION = "begin-exfiltration" + + +@dataclasses.dataclass +class Message: + kind: str + + +@dataclasses.dataclass +class MessageInAskForClipboardResponse(Message): + kind: t.Literal[MessageKind.ASK_FOR_CLIPBOARD_RESPONSE] = MessageKind.ASK_FOR_CLIPBOARD_RESPONSE + text: str = "" + + +@dataclasses.dataclass +class MessageInBeginExfiltration(Message): + kind: 
t.Literal[MessageKind.BEGIN_EXFILTRATION] = MessageKind.BEGIN_EXFILTRATION + + +MessageIn = ( + MessageInAskForClipboardResponse + | MessageInBeginExfiltration +) + + +def reify_channel_message(data: dict) -> MessageIn: + kind = data.get("kind", None) + + match kind: + case MessageKind.ASK_FOR_CLIPBOARD_RESPONSE: + text = data.get("text") or "" + return MessageInAskForClipboardResponse(text=text) + case MessageKind.BEGIN_EXFILTRATION: + return MessageInBeginExfiltration() + case _: + raise ValueError(f"Unknown message kind: '{kind}'") +``` +""" + assert code_ctx.read_writable_code.markdown.strip() == expected_read_writable.strip() + + +def test_testgen_context_includes_external_base_inits(tmp_path: Path) -> None: + """Test that external base class __init__ methods are included in testgen context. + + This covers line 65 in code_context_extractor.py where external_base_inits.code_strings + are appended to the testgen context when a class inherits from an external library. + """ + code = """from collections import UserDict + +class MyCustomDict(UserDict): + def target_method(self): + return self.data +""" + file_path = tmp_path / "test_code.py" + file_path.write_text(code, encoding="utf-8") + + func_to_optimize = FunctionToOptimize( + function_name="target_method", + file_path=file_path, + parents=[FunctionParent(name="MyCustomDict", type="ClassDef")], + ) + + code_ctx = get_code_optimization_context( + function_to_optimize=func_to_optimize, + project_root_path=tmp_path, + ) + + # The testgen context should include the UserDict __init__ method + testgen_context = code_ctx.testgen_context.markdown + assert "class UserDict:" in testgen_context, "UserDict class should be in testgen context" + assert "def __init__" in testgen_context, "UserDict __init__ should be in testgen context" + assert "self.data = {}" in testgen_context, "UserDict __init__ body should be included" + + +def test_read_only_code_removed_when_exceeds_limit(tmp_path: Path) -> None: + """Test read-only 
code is completely removed when it exceeds token limit even without docstrings. + + This covers lines 152-153 in code_context_extractor.py where read_only_context_code is set + to empty string when it still exceeds the token limit after docstring removal. + """ + # Create a second-degree helper with large implementation that has no docstrings + # Second-degree helpers go into read-only context + long_lines = [" x = 0"] + for i in range(150): + long_lines.append(f" x = x + {i}") + long_lines.append(" return x") + long_body = "\n".join(long_lines) + + code = f""" +class MyClass: + def __init__(self): + self.x = 1 + + def target_method(self): + return first_helper() + + +def first_helper(): + # First degree helper - calls second degree + return second_helper() + + +def second_helper(): + # Second degree helper - goes into read-only context +{long_body} +""" + file_path = tmp_path / "test_code.py" + file_path.write_text(code, encoding="utf-8") + + func_to_optimize = FunctionToOptimize( + function_name="target_method", + file_path=file_path, + parents=[FunctionParent(name="MyClass", type="ClassDef")], + ) + + # Use a small optim_token_limit that allows read-writable but not read-only + # Read-writable is ~48 tokens, read-only is ~600 tokens + code_ctx = get_code_optimization_context( + function_to_optimize=func_to_optimize, + project_root_path=tmp_path, + optim_token_limit=100, # Small limit to trigger read-only removal + ) + + # The read-only context should be empty because it exceeded the limit + assert code_ctx.read_only_context_code == "", "Read-only code should be removed when exceeding token limit" + + +def test_testgen_removes_imported_classes_on_overflow(tmp_path: Path) -> None: + """Test testgen context removes imported class definitions when exceeding token limit. 
+ + This covers lines 176-186 in code_context_extractor.py where: + - Testgen context exceeds limit (line 175) + - Removing docstrings still exceeds (line 175 again) + - Removing imported classes succeeds (line 177-183) + """ + # Create a package structure with a large type class used only in type annotations + # This ensures get_imported_class_definitions extracts the full class + package_dir = tmp_path / "mypackage" + package_dir.mkdir() + (package_dir / "__init__.py").write_text("", encoding="utf-8") + + # Create a large class with methods that will be extracted via get_imported_class_definitions + # Use methods WITHOUT docstrings so removing docstrings won't help much + many_methods = "\n".join([f" def method_{i}(self):\n return {i}" for i in range(100)]) + type_class_code = f''' +class TypeClass: + """A type class for annotations.""" + + def __init__(self, value: int): + self.value = value + +{many_methods} +''' + type_class_path = package_dir / "types.py" + type_class_path.write_text(type_class_code, encoding="utf-8") + + # Main module uses TypeClass only in annotation (not instantiated) + # This triggers get_imported_class_definitions to extract the full class + main_code = """ +from mypackage.types import TypeClass + +def target_function(obj: TypeClass) -> int: + return obj.value +""" + main_path = package_dir / "main.py" + main_path.write_text(main_code, encoding="utf-8") + + func_to_optimize = FunctionToOptimize( + function_name="target_function", + file_path=main_path, + parents=[], + ) + + # Use a testgen_token_limit that: + # - Is exceeded by full context with imported class (~1500 tokens) + # - Is exceeded even after removing docstrings + # - But fits when imported class is removed (~40 tokens) + code_ctx = get_code_optimization_context( + function_to_optimize=func_to_optimize, + project_root_path=tmp_path, + testgen_token_limit=200, # Small limit to trigger imported class removal + ) + + # The testgen context should exist (didn't raise ValueError) + 
testgen_context = code_ctx.testgen_context.markdown + assert testgen_context, "Testgen context should not be empty" + + # The target function should still be there + assert "def target_function" in testgen_context, "Target function should be in testgen context" + + # The large imported class should NOT be included (removed due to token limit) + assert "class TypeClass" not in testgen_context, ( + "TypeClass should be removed from testgen context when exceeding token limit" + ) + + +def test_testgen_raises_when_all_fallbacks_fail(tmp_path: Path) -> None: + """Test that ValueError is raised when testgen context exceeds limit even after all fallbacks. + + This covers line 186 in code_context_extractor.py. + """ + # Create a function with a very long body that exceeds limits even without imports/docstrings + long_lines = [" x = 0"] + for i in range(200): + long_lines.append(f" x = x + {i}") + long_lines.append(" return x") + long_body = "\n".join(long_lines) + + code = f""" +def target_function(): +{long_body} +""" + file_path = tmp_path / "test_code.py" + file_path.write_text(code, encoding="utf-8") + + func_to_optimize = FunctionToOptimize( + function_name="target_function", + file_path=file_path, + parents=[], + ) + + # Use a very small testgen_token_limit that cannot fit even the base function + with pytest.raises(ValueError, match="Testgen code context has exceeded token limit"): + get_code_optimization_context( + function_to_optimize=func_to_optimize, + project_root_path=tmp_path, + testgen_token_limit=50, # Very small limit + ) + + +def test_get_external_base_class_inits_attribute_base(tmp_path: Path) -> None: + """Test handling of base class accessed as module.ClassName (ast.Attribute). + + This covers line 616 in code_context_extractor.py. 
+ """ + # Use the standard import style which the code actually handles + code = """from collections import UserDict + +class MyDict(UserDict): + def custom_method(self): + return self.data +""" + code_path = tmp_path / "mydict.py" + code_path.write_text(code, encoding="utf-8") + + context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)]) + result = get_external_base_class_inits(context, tmp_path) + + # Should extract UserDict __init__ + assert len(result.code_strings) == 1 + assert "class UserDict:" in result.code_strings[0].code + assert "def __init__" in result.code_strings[0].code + + +def test_get_external_base_class_inits_no_init_method(tmp_path: Path) -> None: + """Test handling when base class has no __init__ method. + + This covers line 641 in code_context_extractor.py. + """ + # Create a class inheriting from a class that doesn't have inspectable __init__ + code = """from typing import Protocol + +class MyProtocol(Protocol): + pass +""" + code_path = tmp_path / "myproto.py" + code_path.write_text(code, encoding="utf-8") + + context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)]) + result = get_external_base_class_inits(context, tmp_path) + + # Protocol's __init__ can't be easily inspected, should handle gracefully + # Result may be empty or contain Protocol based on implementation + assert isinstance(result.code_strings, list) + + +def test_collect_names_from_annotation_attribute(tmp_path: Path) -> None: + """Test collect_names_from_annotation handles ast.Attribute annotations. + + This covers line 756 in code_context_extractor.py. 
+ """ + # Use __import__ to avoid polluting the test file's detected imports + ast_mod = __import__("ast") + + # Parse code with type annotation using attribute access + code = "x: typing.List[int] = []" + tree = ast_mod.parse(code) + names: set[str] = set() + + # Find the annotation node + for node in ast_mod.walk(tree): + if isinstance(node, ast_mod.AnnAssign) and node.annotation: + collect_names_from_annotation(node.annotation, names) + break + + assert "typing" in names + + +def test_extract_imports_for_class_decorator_call_attribute(tmp_path: Path) -> None: + """Test extract_imports_for_class handles decorator calls with attribute access. + + This covers lines 707-708 in code_context_extractor.py. + """ + ast_mod = __import__("ast") + + code = """ +import functools + +@functools.lru_cache(maxsize=128) +class CachedClass: + pass +""" + tree = ast_mod.parse(code) + + # Find the class node + class_node = None + for node in ast_mod.walk(tree): + if isinstance(node, ast_mod.ClassDef): + class_node = node + break + + assert class_node is not None + result = extract_imports_for_class(tree, class_node, code) + + # Should include the functools import + assert "functools" in result + + +def test_annotated_assignment_in_read_writable(tmp_path: Path) -> None: + """Test that annotated assignments used by target function are in read-writable context. + + This covers lines 965-969 in code_context_extractor.py. 
+ """ + code = """ +CONFIG_VALUE: int = 42 + +class MyClass: + def __init__(self): + self.x = CONFIG_VALUE + + def target_method(self): + return self.x +""" + file_path = tmp_path / "test_code.py" + file_path.write_text(code, encoding="utf-8") + + func_to_optimize = FunctionToOptimize( + function_name="target_method", + file_path=file_path, + parents=[FunctionParent(name="MyClass", type="ClassDef")], + ) + + code_ctx = get_code_optimization_context( + function_to_optimize=func_to_optimize, + project_root_path=tmp_path, + ) + + # CONFIG_VALUE should be in read-writable context since it's used by __init__ + read_writable = code_ctx.read_writable_code.markdown + assert "CONFIG_VALUE" in read_writable + + +def test_imported_class_definitions_module_path_none(tmp_path: Path) -> None: + """Test handling when module_path is None in get_imported_class_definitions. + + This covers line 560 in code_context_extractor.py. + """ + # Create code that imports from a non-existent or unresolvable module + code = """ +from nonexistent_module_xyz import SomeClass + +class MyClass: + def method(self, obj: SomeClass): + pass +""" + code_path = tmp_path / "test.py" + code_path.write_text(code, encoding="utf-8") + + context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)]) + result = get_imported_class_definitions(context, tmp_path) + + # Should handle gracefully and return empty or partial results + assert isinstance(result.code_strings, list) + + +def test_get_imported_names_import_star(tmp_path: Path) -> None: + """Test get_imported_names handles import * correctly. + + This covers lines 808-809 and 824-825 in code_context_extractor.py. 
+ """ + import libcst as cst + + # Test regular import * + # Note: "import *" is not valid Python, but "from x import *" is + from_import_star = cst.parse_statement("from os import *") + assert isinstance(from_import_star, cst.SimpleStatementLine) + import_node = from_import_star.body[0] + assert isinstance(import_node, cst.ImportFrom) + + from codeflash.context.code_context_extractor import get_imported_names + + result = get_imported_names(import_node) + assert result == {"*"} + + +def test_get_imported_names_aliased_import(tmp_path: Path) -> None: + """Test get_imported_names handles aliased imports correctly. + + This covers lines 812-813 and 828-829 in code_context_extractor.py. + """ + import libcst as cst + + from codeflash.context.code_context_extractor import get_imported_names + + # Test import with alias + import_stmt = cst.parse_statement("import numpy as np") + assert isinstance(import_stmt, cst.SimpleStatementLine) + import_node = import_stmt.body[0] + assert isinstance(import_node, cst.Import) + + result = get_imported_names(import_node) + assert "np" in result + + # Test from import with alias + from_import_stmt = cst.parse_statement("from os import path as ospath") + assert isinstance(from_import_stmt, cst.SimpleStatementLine) + from_import_node = from_import_stmt.body[0] + assert isinstance(from_import_node, cst.ImportFrom) + + result2 = get_imported_names(from_import_node) + assert "ospath" in result2 + + +def test_get_imported_names_dotted_import(tmp_path: Path) -> None: + """Test get_imported_names handles dotted imports correctly. + + This covers lines 816-822 in code_context_extractor.py. 
+ """ + import libcst as cst + + from codeflash.context.code_context_extractor import get_imported_names + + # Test dotted import like "import os.path" + import_stmt = cst.parse_statement("import os.path") + assert isinstance(import_stmt, cst.SimpleStatementLine) + import_node = import_stmt.body[0] + assert isinstance(import_node, cst.Import) + + result = get_imported_names(import_node) + assert "os" in result + + +def test_used_name_collector_comprehensive(tmp_path: Path) -> None: + """Test UsedNameCollector handles various node types. + + This covers lines 767-801 in code_context_extractor.py. + """ + import libcst as cst + + from codeflash.context.code_context_extractor import UsedNameCollector + + code = """ +import os +from typing import List + +x: int = 1 +y = os.path.join("a", "b") + +class MyClass: + z = 10 + +def my_func(): + pass +""" + module = cst.parse_module(code) + collector = UsedNameCollector() + # In libcst, the walker traverses the module + cst.MetadataWrapper(module).visit(collector) + + # Check used names + assert "os" in collector.used_names + assert "int" in collector.used_names + assert "List" in collector.used_names + + # Check defined names + assert "x" in collector.defined_names + assert "y" in collector.defined_names + assert "MyClass" in collector.defined_names + assert "my_func" in collector.defined_names + + # Check external names (used but not defined) + external = collector.get_external_names() + assert "os" in external + assert "x" not in external # x is defined + + +def test_imported_class_with_base_in_same_module(tmp_path: Path) -> None: + """Test that imported classes with bases in the same module are extracted correctly. + + This covers line 528 in code_context_extractor.py - early return for already extracted. 
+ """ + package_dir = tmp_path / "mypackage" + package_dir.mkdir() + (package_dir / "__init__.py").write_text("", encoding="utf-8") + + # Create a module with inheritance chain + module_code = """ +class BaseClass: + def __init__(self): + self.base = True + +class MiddleClass(BaseClass): + def __init__(self): + super().__init__() + self.middle = True + +class DerivedClass(MiddleClass): + def __init__(self): + super().__init__() + self.derived = True +""" + module_path = package_dir / "classes.py" + module_path.write_text(module_code, encoding="utf-8") + + # Main module imports and uses the derived class + main_code = """ +from mypackage.classes import DerivedClass + +def target_function(obj: DerivedClass) -> bool: + return obj.derived +""" + main_path = package_dir / "main.py" + main_path.write_text(main_code, encoding="utf-8") + + context = CodeStringsMarkdown(code_strings=[CodeString(code=main_code, file_path=main_path)]) + result = get_imported_class_definitions(context, tmp_path) + + # Should extract the inheritance chain + all_code = "\n".join(cs.code for cs in result.code_strings) + assert "class BaseClass" in all_code or "class DerivedClass" in all_code + + +def test_get_imported_names_from_import_without_alias(tmp_path: Path) -> None: + """Test get_imported_names handles from imports without aliases. + + This covers lines 830-831 in code_context_extractor.py. + """ + import libcst as cst + + from codeflash.context.code_context_extractor import get_imported_names + + # Test from import without alias + from_import_stmt = cst.parse_statement("from os import path, getcwd") + assert isinstance(from_import_stmt, cst.SimpleStatementLine) + from_import_node = from_import_stmt.body[0] + assert isinstance(from_import_node, cst.ImportFrom) + + result = get_imported_names(from_import_node) + assert "path" in result + assert "getcwd" in result + + +def test_get_imported_names_regular_import(tmp_path: Path) -> None: + """Test get_imported_names handles regular imports. 
+ + This covers lines 814-815 in code_context_extractor.py. + """ + import libcst as cst + + from codeflash.context.code_context_extractor import get_imported_names + + # Test regular import without alias + import_stmt = cst.parse_statement("import json") + assert isinstance(import_stmt, cst.SimpleStatementLine) + import_node = import_stmt.body[0] + assert isinstance(import_node, cst.Import) + + result = get_imported_names(import_node) + assert "json" in result + + +def test_augmented_assignment_not_in_context(tmp_path: Path) -> None: + """Test that augmented assignments are handled but not included unless used. + + This covers line 962-969 in code_context_extractor.py. + """ + code = """ +counter = 0 + +class MyClass: + def __init__(self): + global counter + counter += 1 + + def target_method(self): + return 42 +""" + file_path = tmp_path / "test_code.py" + file_path.write_text(code, encoding="utf-8") + + func_to_optimize = FunctionToOptimize( + function_name="target_method", + file_path=file_path, + parents=[FunctionParent(name="MyClass", type="ClassDef")], + ) + + code_ctx = get_code_optimization_context( + function_to_optimize=func_to_optimize, + project_root_path=tmp_path, + ) + + # counter should be in context since __init__ uses it + read_writable = code_ctx.read_writable_code.markdown + assert "counter" in read_writable diff --git a/tests/test_code_replacement.py b/tests/test_code_replacement.py index 04d83f13f..da83146a8 100644 --- a/tests/test_code_replacement.py +++ b/tests/test_code_replacement.py @@ -2119,7 +2119,6 @@ print("Hello world") expected_code = """import numpy as np a = 6 - if 2<3: a=4 else: diff --git a/tests/test_codeflash_capture.py b/tests/test_codeflash_capture.py index 6d76e2bf6..b9112f047 100644 --- a/tests/test_codeflash_capture.py +++ b/tests/test_codeflash_capture.py @@ -1602,7 +1602,94 @@ def calculate_portfolio_metrics( # now the test should match and no diffs should be found assert len(diffs) == 0 assert matched - + finally: 
test_path.unlink(missing_ok=True) - fto_file_path.unlink(missing_ok=True) \ No newline at end of file + fto_file_path.unlink(missing_ok=True) + + +def test_codeflash_capture_with_slots_class() -> None: + """Test that codeflash_capture works with classes that use __slots__ instead of __dict__.""" + test_code = """ +from code_to_optimize.tests.pytest.sample_code import SlotsClass +import unittest + +def test_slots_class(): + obj = SlotsClass(10, "test") + assert obj.x == 10 + assert obj.y == "test" +""" + test_dir = (Path(__file__).parent.parent / "code_to_optimize" / "tests" / "pytest").resolve() + tmp_dir_path = get_run_tmp_file(Path("test_return_values")) + sample_code = f""" +from codeflash.verification.codeflash_capture import codeflash_capture + +class SlotsClass: + __slots__ = ('x', 'y') + + @codeflash_capture(function_name="SlotsClass.__init__", tmp_dir_path="{tmp_dir_path.as_posix()}", tests_root="{test_dir.as_posix()}") + def __init__(self, x, y): + self.x = x + self.y = y +""" + test_file_name = "test_slots_class_temp.py" + test_path = test_dir / test_file_name + test_path_perf = test_dir / "test_slots_class_temp_perf.py" + + tests_root = Path(__file__).parent.resolve() / "../code_to_optimize/tests/pytest/" + project_root_path = (Path(__file__).parent / "..").resolve() + sample_code_path = test_dir / "sample_code.py" + + try: + with test_path.open("w") as f: + f.write(test_code) + with sample_code_path.open("w") as f: + f.write(sample_code) + + test_env = os.environ.copy() + test_env["CODEFLASH_TEST_ITERATION"] = "0" + test_env["CODEFLASH_LOOP_INDEX"] = "1" + test_type = TestType.EXISTING_UNIT_TEST + test_config = TestConfig( + tests_root=tests_root, + tests_project_rootdir=project_root_path, + project_root_path=project_root_path, + test_framework="pytest", + pytest_cmd="pytest", + ) + fto = FunctionToOptimize( + function_name="__init__", + file_path=sample_code_path, + parents=[FunctionParent(name="SlotsClass", type="ClassDef")], + ) + func_optimizer = 
FunctionOptimizer(function_to_optimize=fto, test_cfg=test_config) + func_optimizer.test_files = TestFiles( + test_files=[ + TestFile( + instrumented_behavior_file_path=test_path, + test_type=test_type, + original_file_path=test_path, + benchmarking_file_path=test_path_perf, + ) + ] + ) + test_results, coverage_data = func_optimizer.run_and_parse_tests( + testing_type=TestingMode.BEHAVIOR, + test_env=test_env, + test_files=func_optimizer.test_files, + optimization_iteration=0, + pytest_min_loops=1, + pytest_max_loops=1, + testing_time=0.1, + ) + + # Test should pass and capture the slots values + assert len(test_results) == 1 + assert test_results[0].did_pass + # The return value should contain the slot values + assert test_results[0].return_value[0]["x"] == 10 + assert test_results[0].return_value[0]["y"] == "test" + + finally: + test_path.unlink(missing_ok=True) + sample_code_path.unlink(missing_ok=True) \ No newline at end of file diff --git a/tests/test_comparator.py b/tests/test_comparator.py index a62d61d80..7ce23febb 100644 --- a/tests/test_comparator.py +++ b/tests/test_comparator.py @@ -2316,6 +2316,77 @@ def test_dict_views() -> None: assert not comparator(d.items(), [("a", 1), ("b", 2)]) +def test_mappingproxy() -> None: + """Test comparator support for types.MappingProxyType (read-only dict view).""" + import types + + # Basic equality + mp1 = types.MappingProxyType({"a": 1, "b": 2, "c": 3}) + mp2 = types.MappingProxyType({"a": 1, "b": 2, "c": 3}) + assert comparator(mp1, mp2) + + # Different values + mp3 = types.MappingProxyType({"a": 1, "b": 2, "c": 4}) + assert not comparator(mp1, mp3) + + # Different keys + mp4 = types.MappingProxyType({"a": 1, "b": 2, "d": 3}) + assert not comparator(mp1, mp4) + + # Different length + mp5 = types.MappingProxyType({"a": 1, "b": 2}) + assert not comparator(mp1, mp5) + + # Order doesn't matter (like dict) + mp6 = types.MappingProxyType({"c": 3, "a": 1, "b": 2}) + assert comparator(mp1, mp6) + + # Empty mappingproxy + 
empty1 = types.MappingProxyType({}) + empty2 = types.MappingProxyType({}) + assert comparator(empty1, empty2) + + # Nested values + nested1 = types.MappingProxyType({"a": [1, 2, 3], "b": {"x": 1}}) + nested2 = types.MappingProxyType({"a": [1, 2, 3], "b": {"x": 1}}) + nested3 = types.MappingProxyType({"a": [1, 2, 4], "b": {"x": 1}}) + assert comparator(nested1, nested2) + assert not comparator(nested1, nested3) + + # mappingproxy is not equal to dict (different types) + d = {"a": 1, "b": 2} + mp = types.MappingProxyType({"a": 1, "b": 2}) + assert not comparator(mp, d) + assert not comparator(d, mp) + + # Verify class __dict__ is indeed a mappingproxy + class MyClass: + x = 1 + y = 2 + + assert isinstance(MyClass.__dict__, types.MappingProxyType) + + +def test_mappingproxy_superset() -> None: + """Test comparator superset_obj support for mappingproxy.""" + import types + + mp1 = types.MappingProxyType({"a": 1, "b": 2}) + mp2 = types.MappingProxyType({"a": 1, "b": 2, "c": 3}) + + # mp2 is a superset of mp1 + assert comparator(mp1, mp2, superset_obj=True) + # mp1 is not a superset of mp2 + assert not comparator(mp2, mp1, superset_obj=True) + + # Same mappingproxy with superset_obj=True + assert comparator(mp1, mp1, superset_obj=True) + + # Different values even with superset + mp3 = types.MappingProxyType({"a": 1, "b": 99, "c": 3}) + assert not comparator(mp1, mp3, superset_obj=True) + + def test_tensorflow_tensor() -> None: """Test comparator support for TensorFlow Tensor objects.""" try: diff --git a/tests/test_get_read_writable_code.py b/tests/test_get_read_writable_code.py index d1eeb6e99..952479d3a 100644 --- a/tests/test_get_read_writable_code.py +++ b/tests/test_get_read_writable_code.py @@ -218,8 +218,28 @@ def test_no_targets_found() -> None: def target(self): pass """ + result = parse_code_and_prune_cst(dedent(code),CodeContextType.READ_WRITABLE, {"MyClass.Inner.target"}) + expected = dedent(""" + class MyClass: + def method(self): + pass + + class Inner: + 
def target(self): + pass + """) + assert result.strip() == expected.strip() + + +def test_no_targets_found_raises_for_nonexistent() -> None: + """Test that ValueError is raised when the target function doesn't exist at all.""" + code = """ + class MyClass: + def method(self): + pass + """ with pytest.raises(ValueError, match="No target functions found in the provided code"): - parse_code_and_prune_cst(dedent(code),CodeContextType.READ_WRITABLE, {"MyClass.Inner.target"}) + parse_code_and_prune_cst(dedent(code),CodeContextType.READ_WRITABLE, {"NonExistent.target"}) def test_module_var() -> None: diff --git a/tests/test_multi_file_code_replacement.py b/tests/test_multi_file_code_replacement.py index e33e98d24..2d1f22509 100644 --- a/tests/test_multi_file_code_replacement.py +++ b/tests/test_multi_file_code_replacement.py @@ -124,7 +124,7 @@ def _get_string_usage(text: str) -> Usage: helper_file.unlink(missing_ok=True) main_file.unlink(missing_ok=True) - + expected_helper = """import re from collections.abc import Sequence diff --git a/tests/test_remove_unused_definitions.py b/tests/test_remove_unused_definitions.py index 8d09a95e1..edf11e7c5 100644 --- a/tests/test_remove_unused_definitions.py +++ b/tests/test_remove_unused_definitions.py @@ -481,3 +481,86 @@ def unused_function(): qualified_functions = {"get_platform_info", "get_loop_result"} result = remove_unused_definitions_by_function_names(code, qualified_functions) assert result.strip() == expected.strip() + + +def test_enum_attribute_access_dependency() -> None: + """Test that enum/class attribute access like MessageKind.VALUE is tracked as a dependency.""" + code = """ +from enum import Enum + +class MessageKind(Enum): + VALUE = "value" + OTHER = "other" + +class UnusedEnum(Enum): + UNUSED = "unused" + +UNUSED_VAR = 123 + +def process_message(kind): + match kind: + case MessageKind.VALUE: + return "got value" + case MessageKind.OTHER: + return "got other" + return "unknown" +""" + + expected = """ +from enum 
import Enum + +class MessageKind(Enum): + VALUE = "value" + OTHER = "other" + +class UnusedEnum(Enum): + UNUSED = "unused" + +def process_message(kind): + match kind: + case MessageKind.VALUE: + return "got value" + case MessageKind.OTHER: + return "got other" + return "unknown" +""" + + qualified_functions = {"process_message"} + result = remove_unused_definitions_by_function_names(code, qualified_functions) + # MessageKind should be preserved because process_message uses MessageKind.VALUE + assert "class MessageKind" in result + # UNUSED_VAR should be removed + assert "UNUSED_VAR" not in result + assert result.strip() == expected.strip() + + +def test_attribute_access_does_not_track_attr_name() -> None: + """Test that self.x attribute access doesn't track 'x' as a dependency on module-level x.""" + code = """ +x = "module_level_x" +UNUSED_VAR = "unused" + +class MyClass: + def __init__(self): + self.x = 1 # This 'x' is an attribute, not a reference to module-level 'x' + + def get_x(self): + return self.x # This 'x' is also an attribute access +""" + + expected = """ +class MyClass: + def __init__(self): + self.x = 1 # This 'x' is an attribute, not a reference to module-level 'x' + + def get_x(self): + return self.x # This 'x' is also an attribute access +""" + + qualified_functions = {"MyClass.get_x", "MyClass.__init__"} + result = remove_unused_definitions_by_function_names(code, qualified_functions) + # Module-level x should NOT be kept (self.x doesn't reference it) + assert 'x = "module_level_x"' not in result + # UNUSED_VAR should also be removed + assert "UNUSED_VAR" not in result + assert result.strip() == expected.strip()