From f1521d7a2d8adcdacf5f5d98c29861640263c152 Mon Sep 17 00:00:00 2001 From: Mohamed Ashraf Date: Tue, 28 Apr 2026 15:40:02 +0000 Subject: [PATCH] fix: resolve pre-existing mypy errors on files touched by this PR The prek mypy hook runs on changed files and bypasses the pyproject.toml tests/ exclude, surfacing pre-existing errors in both context.py and test_context.py that block CI for this PR. Fixes applied: - Import Language from language_enum instead of base (base re-exports are not explicit; strict mypy flags attr-defined) - Annotate _extract_class_declaration, _import_to_statement, get_java_imported_type_skeletons, and resolved_imports - Guard None start/end_line in _extract_function_source_by_lines and find_helper_functions; guard None file_path in the import skeleton loop - Drop unreachable `if not node: continue` in _extract_public_method_signatures (JavaMethodNode.node is non-nullable) - Add -> None to every test method and fix an `int | None` comparison in test_context.py All 880 Java tests pass after the change. --- codeflash/languages/java/context.py | 26 ++- .../test_languages/test_java/test_context.py | 195 +++++++++--------- 2 files changed, 116 insertions(+), 105 deletions(-) diff --git a/codeflash/languages/java/context.py b/codeflash/languages/java/context.py index 1b5ffa74a..e2f6ada5c 100644 --- a/codeflash/languages/java/context.py +++ b/codeflash/languages/java/context.py @@ -11,10 +11,11 @@ import logging from typing import TYPE_CHECKING from codeflash.code_utils.code_utils import encoded_tokens_len -from codeflash.languages.base import CodeContext, HelperFunction, Language +from codeflash.languages.base import CodeContext, HelperFunction from codeflash.languages.java.discovery import discover_functions_from_source from codeflash.languages.java.import_resolver import JavaImportResolver, find_helper_files from codeflash.languages.java.parser import get_java_analyzer +from codeflash.languages.language_enum import Language if TYPE_CHECKING: from pathlib import Path @@ -22,7 +23,8 @@ if TYPE_CHECKING: from tree_sitter import Node from codeflash.discovery.functions_to_optimize import FunctionToOptimize - from codeflash.languages.java.parser import JavaAnalyzer, JavaMethodNode + from codeflash.languages.java.import_resolver import ResolvedImport + from codeflash.languages.java.parser import JavaAnalyzer, JavaImportInfo, JavaMethodNode logger = logging.getLogger(__name__) @@ -360,7 +362,7 @@ def _extract_type_declaration(type_node: Node, source_bytes: bytes, type_kind: s # Keep old function name for backwards compatibility -def _extract_class_declaration(node, source_bytes): +def _extract_class_declaration(node: Node, source_bytes: bytes) -> str: return _extract_type_declaration(node, source_bytes, "class") @@ -629,6 +631,8 @@ def _extract_function_source_by_lines(source: str, function: FunctionToOptimize) start_line = function.doc_start_line or function.starting_line end_line = function.ending_line + if start_line is None or end_line is None: + return "" # Convert from 1-indexed to 0-indexed start_idx = start_line - 1 @@ -672,6 +676,8 @@ def find_helper_functions( func_id = f"{file_path}:{func.qualified_name}" if func_id not in visited_functions: visited_functions.add(func_id) + if func.starting_line is None or func.ending_line is None: + continue # Extract the function source using tree-sitter for resilient lookup func_source = extract_function_source(source, func, analyzer=analyzer) @@ -795,7 +801,7 @@ def extract_read_only_context(source: str, function: FunctionToOptimize, analyze return "\n".join(context_parts) -def _import_to_statement(import_info) -> str: +def _import_to_statement(import_info: JavaImportInfo) -> str: """Convert a JavaImportInfo to an import statement string. Args: @@ -898,7 +904,11 @@ def _extract_type_names_from_code(code: str, analyzer: JavaAnalyzer) -> set[str] def get_java_imported_type_skeletons( - imports: list, project_root: Path, module_root: Path | None, analyzer: JavaAnalyzer, target_code: str = "" + imports: list[JavaImportInfo], + project_root: Path, + module_root: Path | None, + analyzer: JavaAnalyzer, + target_code: str = "", ) -> str: """Extract type skeletons for project-internal imported types. @@ -933,7 +943,7 @@ def get_java_imported_type_skeletons( priority_types = _extract_type_names_from_code(target_code, analyzer) # Pre-resolve all imports, expanding wildcards into individual types - resolved_imports: list = [] + resolved_imports: list[ResolvedImport] = [] for imp in imports: if imp.is_wildcard: # First try unfiltered expansion with a cap. If the package is small enough, take all types. @@ -978,7 +988,7 @@ def get_java_imported_type_skeletons( for resolved in resolved_imports: class_name = resolved.class_name - if not class_name: + if not class_name or resolved.file_path is None: continue dedup_key = (str(resolved.file_path), class_name) @@ -1100,8 +1110,6 @@ def _extract_public_method_signatures(source: str, class_name: str, analyzer: Ja continue node = method.node - if not node: - continue # Check if the method is public is_public = False diff --git a/tests/test_languages/test_java/test_context.py b/tests/test_languages/test_java/test_context.py index d9d771c01..b8316de15 100644 --- a/tests/test_languages/test_java/test_context.py +++ b/tests/test_languages/test_java/test_context.py @@ -2,7 +2,8 @@ from pathlib import Path -from codeflash.languages.base import FunctionFilterCriteria, Language +from codeflash.languages.base import FunctionFilterCriteria +from codeflash.languages.language_enum import Language from codeflash.languages.java.context import ( TypeSkeleton, _extract_public_method_signatures, @@ -23,7 +24,7 @@ NO_RETURN_FILTER = FunctionFilterCriteria(require_return=False) class TestExtractCodeContextBasic: """Tests for basic extract_code_context functionality.""" - def test_simple_method(self, tmp_path: Path): + def test_simple_method(self, tmp_path: Path) -> None: """Test extracting context for a simple method.""" java_file = tmp_path / "Calculator.java" java_file.write_text("""public class Calculator { @@ -53,7 +54,7 @@ class TestExtractCodeContextBasic: assert context.helper_functions == [] assert context.read_only_context == "" - def test_method_with_javadoc(self, tmp_path: Path): + def test_method_with_javadoc(self, tmp_path: Path) -> None: """Test extracting context for method with Javadoc.""" java_file = tmp_path / "Calculator.java" java_file.write_text("""public class Calculator { @@ -94,7 +95,7 @@ class TestExtractCodeContextBasic: assert context.helper_functions == [] assert context.read_only_context == "" - def test_static_method(self, tmp_path: Path): + def test_static_method(self, tmp_path: Path) -> None: """Test extracting context for a static method.""" java_file = tmp_path / "MathUtils.java" java_file.write_text("""public class MathUtils { @@ -123,7 +124,7 @@ class TestExtractCodeContextBasic: assert context.helper_functions == [] assert context.read_only_context == "" - def test_private_method(self, tmp_path: Path): + def test_private_method(self, tmp_path: Path) -> None: """Test extracting context for a private method.""" java_file = tmp_path / "Helper.java" java_file.write_text("""public class Helper { @@ -149,7 +150,7 @@ class TestExtractCodeContextBasic: """ ) - def test_protected_method(self, tmp_path: Path): + def test_protected_method(self, tmp_path: Path) -> None: """Test extracting context for a protected method.""" java_file = tmp_path / "Base.java" java_file.write_text("""public class Base { @@ -175,7 +176,7 @@ class TestExtractCodeContextBasic: """ ) - def test_synchronized_method(self, tmp_path: Path): + def test_synchronized_method(self, tmp_path: Path) -> None: """Test extracting context for a synchronized method.""" java_file = tmp_path / "Counter.java" java_file.write_text("""public class Counter { @@ -200,7 +201,7 @@ class TestExtractCodeContextBasic: """ ) - def test_method_with_throws(self, tmp_path: Path): + def test_method_with_throws(self, tmp_path: Path) -> None: """Test extracting context for a method with throws clause.""" java_file = tmp_path / "FileHandler.java" java_file.write_text("""public class FileHandler { @@ -225,7 +226,7 @@ class TestExtractCodeContextBasic: """ ) - def test_method_with_varargs(self, tmp_path: Path): + def test_method_with_varargs(self, tmp_path: Path) -> None: """Test extracting context for a method with varargs.""" java_file = tmp_path / "Logger.java" java_file.write_text("""public class Logger { @@ -250,7 +251,7 @@ class TestExtractCodeContextBasic: """ ) - def test_void_method(self, tmp_path: Path): + def test_void_method(self, tmp_path: Path) -> None: """Test extracting context for a void method.""" java_file = tmp_path / "Printer.java" java_file.write_text("""public class Printer { @@ -277,7 +278,7 @@ class TestExtractCodeContextBasic: """ ) - def test_generic_return_type(self, tmp_path: Path): + def test_generic_return_type(self, tmp_path: Path) -> None: """Test extracting context for a method with generic return type.""" java_file = tmp_path / "Container.java" java_file.write_text("""public class Container { @@ -306,7 +307,7 @@ class TestExtractCodeContextBasic: class TestExtractCodeContextWithImports: """Tests for extract_code_context with various import types.""" - def test_with_package_and_imports(self, tmp_path: Path): + def test_with_package_and_imports(self, tmp_path: Path) -> None: """Test context extraction with package and imports.""" java_file = tmp_path / "Calculator.java" java_file.write_text("""package com.example; @@ -344,7 +345,7 @@ public class Calculator { # Fields are in skeleton, so read_only_context is empty assert context.read_only_context == "" - def test_with_static_imports(self, tmp_path: Path): + def test_with_static_imports(self, tmp_path: Path) -> None: """Test context extraction with static imports.""" java_file = tmp_path / "Calculator.java" java_file.write_text("""package com.example; @@ -380,7 +381,7 @@ public class Calculator { "import static java.lang.Math.sqrt;", ] - def test_with_wildcard_imports(self, tmp_path: Path): + def test_with_wildcard_imports(self, tmp_path: Path) -> None: """Test context extraction with wildcard imports.""" java_file = tmp_path / "Processor.java" java_file.write_text("""package com.example; @@ -402,7 +403,7 @@ public class Processor { assert context.language == Language.JAVA assert context.imports == ["import java.util.*;", "import java.io.*;"] - def test_with_multiple_import_types(self, tmp_path: Path): + def test_with_multiple_import_types(self, tmp_path: Path) -> None: """Test context extraction with various import types.""" java_file = tmp_path / "Handler.java" java_file.write_text("""package com.example; @@ -453,7 +454,7 @@ class TestExtractCodeContextWithFields: read_only_context should be empty to avoid duplication. """ - def test_with_instance_fields(self, tmp_path: Path): + def test_with_instance_fields(self, tmp_path: Path) -> None: """Test context extraction with instance fields.""" java_file = tmp_path / "Person.java" java_file.write_text("""public class Person { @@ -488,7 +489,7 @@ class TestExtractCodeContextWithFields: assert context.imports == [] assert context.helper_functions == [] - def test_with_static_fields(self, tmp_path: Path): + def test_with_static_fields(self, tmp_path: Path) -> None: """Test context extraction with static fields.""" java_file = tmp_path / "Counter.java" java_file.write_text("""public class Counter { @@ -519,7 +520,7 @@ class TestExtractCodeContextWithFields: # Fields are in skeleton, so read_only_context is empty assert context.read_only_context == "" - def test_with_final_fields(self, tmp_path: Path): + def test_with_final_fields(self, tmp_path: Path) -> None: """Test context extraction with final fields.""" java_file = tmp_path / "Config.java" java_file.write_text("""public class Config { @@ -549,7 +550,7 @@ class TestExtractCodeContextWithFields: ) assert context.read_only_context == "" - def test_with_static_final_constants(self, tmp_path: Path): + def test_with_static_final_constants(self, tmp_path: Path) -> None: """Test context extraction with static final constants.""" java_file = tmp_path / "Constants.java" java_file.write_text("""public class Constants { @@ -581,7 +582,7 @@ class TestExtractCodeContextWithFields: ) assert context.read_only_context == "" - def test_with_volatile_fields(self, tmp_path: Path): + def test_with_volatile_fields(self, tmp_path: Path) -> None: """Test context extraction with volatile fields.""" java_file = tmp_path / "ThreadSafe.java" java_file.write_text("""public class ThreadSafe { @@ -611,7 +612,7 @@ class TestExtractCodeContextWithFields: ) assert context.read_only_context == "" - def test_with_generic_fields(self, tmp_path: Path): + def test_with_generic_fields(self, tmp_path: Path) -> None: """Test context extraction with generic type fields.""" java_file = tmp_path / "Container.java" java_file.write_text("""public class Container { @@ -643,7 +644,7 @@ class TestExtractCodeContextWithFields: ) assert context.read_only_context == "" - def test_with_array_fields(self, tmp_path: Path): + def test_with_array_fields(self, tmp_path: Path) -> None: """Test context extraction with array fields.""" java_file = tmp_path / "ArrayHolder.java" java_file.write_text("""public class ArrayHolder { @@ -679,7 +680,7 @@ class TestExtractCodeContextWithFields: class TestExtractCodeContextWithHelpers: """Tests for extract_code_context with helper functions.""" - def test_single_helper_method(self, tmp_path: Path): + def test_single_helper_method(self, tmp_path: Path) -> None: """Test context extraction with a single helper method.""" java_file = tmp_path / "Processor.java" java_file.write_text("""public class Processor { @@ -715,7 +716,7 @@ class TestExtractCodeContextWithHelpers: == "private String normalize(String s) {\n return s.trim().toLowerCase();\n }" ) - def test_multiple_helper_methods(self, tmp_path: Path): + def test_multiple_helper_methods(self, tmp_path: Path) -> None: """Test context extraction with multiple helper methods.""" java_file = tmp_path / "Processor.java" java_file.write_text("""public class Processor { @@ -758,7 +759,7 @@ class TestExtractCodeContextWithHelpers: helper_names = sorted([h.name for h in context.helper_functions]) assert helper_names == ["trim", "upper"] - def test_chained_helper_calls(self, tmp_path: Path): + def test_chained_helper_calls(self, tmp_path: Path) -> None: """Test context extraction with chained helper calls.""" java_file = tmp_path / "Processor.java" java_file.write_text("""public class Processor { @@ -784,7 +785,7 @@ class TestExtractCodeContextWithHelpers: helper_names = [h.name for h in context.helper_functions] assert helper_names == ["normalize"] - def test_no_helpers_when_none_called(self, tmp_path: Path): + def test_no_helpers_when_none_called(self, tmp_path: Path) -> None: """Test context extraction when no helpers are called.""" java_file = tmp_path / "Calculator.java" java_file.write_text("""public class Calculator { @@ -814,7 +815,7 @@ class TestExtractCodeContextWithHelpers: ) assert context.helper_functions == [] - def test_static_helper_from_instance_method(self, tmp_path: Path): + def test_static_helper_from_instance_method(self, tmp_path: Path) -> None: """Test context extraction with static helper called from instance method.""" java_file = tmp_path / "Calculator.java" java_file.write_text("""public class Calculator { @@ -840,7 +841,7 @@ class TestExtractCodeContextWithHelpers: class TestExtractCodeContextWithJavadoc: """Tests for extract_code_context with various Javadoc patterns.""" - def test_simple_javadoc(self, tmp_path: Path): + def test_simple_javadoc(self, tmp_path: Path) -> None: """Test context extraction with simple Javadoc.""" java_file = tmp_path / "Example.java" java_file.write_text("""public class Example { @@ -866,7 +867,7 @@ class TestExtractCodeContextWithJavadoc: """ ) - def test_javadoc_with_params(self, tmp_path: Path): + def test_javadoc_with_params(self, tmp_path: Path) -> None: """Test context extraction with Javadoc @param tags.""" java_file = tmp_path / "Calculator.java" java_file.write_text("""public class Calculator { @@ -900,7 +901,7 @@ class TestExtractCodeContextWithJavadoc: """ ) - def test_javadoc_with_return(self, tmp_path: Path): + def test_javadoc_with_return(self, tmp_path: Path) -> None: """Test context extraction with Javadoc @return tag.""" java_file = tmp_path / "Calculator.java" java_file.write_text("""public class Calculator { @@ -932,7 +933,7 @@ class TestExtractCodeContextWithJavadoc: """ ) - def test_javadoc_with_throws(self, tmp_path: Path): + def test_javadoc_with_throws(self, tmp_path: Path) -> None: """Test context extraction with Javadoc @throws tag.""" java_file = tmp_path / "Divider.java" java_file.write_text("""public class Divider { @@ -968,7 +969,7 @@ class TestExtractCodeContextWithJavadoc: """ ) - def test_javadoc_multiline(self, tmp_path: Path): + def test_javadoc_multiline(self, tmp_path: Path) -> None: """Test context extraction with multi-paragraph Javadoc.""" java_file = tmp_path / "Complex.java" java_file.write_text("""public class Complex { @@ -1020,7 +1021,7 @@ class TestExtractCodeContextWithJavadoc: class TestExtractCodeContextWithGenerics: """Tests for extract_code_context with generic types.""" - def test_generic_method_type_parameter(self, tmp_path: Path): + def test_generic_method_type_parameter(self, tmp_path: Path) -> None: """Test context extraction with generic type parameter.""" java_file = tmp_path / "Utils.java" java_file.write_text("""public class Utils { @@ -1044,7 +1045,7 @@ class TestExtractCodeContextWithGenerics: """ ) - def test_bounded_type_parameter(self, tmp_path: Path): + def test_bounded_type_parameter(self, tmp_path: Path) -> None: """Test context extraction with bounded type parameter.""" java_file = tmp_path / "Statistics.java" java_file.write_text("""public class Statistics { @@ -1076,7 +1077,7 @@ class TestExtractCodeContextWithGenerics: """ ) - def test_wildcard_type(self, tmp_path: Path): + def test_wildcard_type(self, tmp_path: Path) -> None: """Test context extraction with wildcard type.""" java_file = tmp_path / "Printer.java" java_file.write_text("""public class Printer { @@ -1100,7 +1101,7 @@ class TestExtractCodeContextWithGenerics: """ ) - def test_bounded_wildcard_extends(self, tmp_path: Path): + def test_bounded_wildcard_extends(self, tmp_path: Path) -> None: """Test context extraction with upper bounded wildcard.""" java_file = tmp_path / "Aggregator.java" java_file.write_text("""public class Aggregator { @@ -1132,7 +1133,7 @@ class TestExtractCodeContextWithGenerics: """ ) - def test_bounded_wildcard_super(self, tmp_path: Path): + def test_bounded_wildcard_super(self, tmp_path: Path) -> None: """Test context extraction with lower bounded wildcard.""" java_file = tmp_path / "Filler.java" java_file.write_text("""public class Filler { @@ -1158,7 +1159,7 @@ class TestExtractCodeContextWithGenerics: """ ) - def test_multiple_type_parameters(self, tmp_path: Path): + def test_multiple_type_parameters(self, tmp_path: Path) -> None: """Test context extraction with multiple type parameters.""" java_file = tmp_path / "Mapper.java" java_file.write_text("""public class Mapper { @@ -1190,7 +1191,7 @@ class TestExtractCodeContextWithGenerics: """ ) - def test_recursive_type_bound(self, tmp_path: Path): + def test_recursive_type_bound(self, tmp_path: Path) -> None: """Test context extraction with recursive type bound.""" java_file = tmp_path / "Sorter.java" java_file.write_text("""public class Sorter { @@ -1218,7 +1219,7 @@ class TestExtractCodeContextWithGenerics: class TestExtractCodeContextWithAnnotations: """Tests for extract_code_context with annotations.""" - def test_override_annotation(self, tmp_path: Path): + def test_override_annotation(self, tmp_path: Path) -> None: """Test context extraction with @Override annotation.""" java_file = tmp_path / "Child.java" java_file.write_text("""public class Child extends Parent { @@ -1244,7 +1245,7 @@ class TestExtractCodeContextWithAnnotations: """ ) - def test_deprecated_annotation(self, tmp_path: Path): + def test_deprecated_annotation(self, tmp_path: Path) -> None: """Test context extraction with @Deprecated annotation.""" java_file = tmp_path / "Legacy.java" java_file.write_text("""public class Legacy { @@ -1270,7 +1271,7 @@ class TestExtractCodeContextWithAnnotations: """ ) - def test_suppress_warnings_annotation(self, tmp_path: Path): + def test_suppress_warnings_annotation(self, tmp_path: Path) -> None: """Test context extraction with @SuppressWarnings annotation.""" java_file = tmp_path / "Processor.java" java_file.write_text("""public class Processor { @@ -1296,7 +1297,7 @@ class TestExtractCodeContextWithAnnotations: """ ) - def test_multiple_annotations(self, tmp_path: Path): + def test_multiple_annotations(self, tmp_path: Path) -> None: """Test context extraction with multiple annotations.""" java_file = tmp_path / "Service.java" java_file.write_text("""public class Service { @@ -1326,7 +1327,7 @@ class TestExtractCodeContextWithAnnotations: """ ) - def test_annotation_with_array_value(self, tmp_path: Path): + def test_annotation_with_array_value(self, tmp_path: Path) -> None: """Test context extraction with annotation array value.""" java_file = tmp_path / "Handler.java" java_file.write_text("""public class Handler { @@ -1356,7 +1357,7 @@ class TestExtractCodeContextWithAnnotations: class TestExtractCodeContextWithInheritance: """Tests for extract_code_context with inheritance scenarios.""" - def test_method_in_subclass(self, tmp_path: Path): + def test_method_in_subclass(self, tmp_path: Path) -> None: """Test context extraction for method in subclass.""" java_file = tmp_path / "AdvancedCalc.java" java_file.write_text("""public class AdvancedCalc extends Calculator { @@ -1382,7 +1383,7 @@ class TestExtractCodeContextWithInheritance: """ ) - def test_interface_implementation(self, tmp_path: Path): + def test_interface_implementation(self, tmp_path: Path) -> None: """Test context extraction for interface implementation.""" java_file = tmp_path / "MyComparable.java" java_file.write_text("""public class MyComparable implements Comparable { @@ -1414,7 +1415,7 @@ class TestExtractCodeContextWithInheritance: # Fields are in skeleton, so read_only_context is empty (no duplication) assert context.read_only_context == "" - def test_multiple_interfaces(self, tmp_path: Path): + def test_multiple_interfaces(self, tmp_path: Path) -> None: """Test context extraction for multiple interface implementations.""" java_file = tmp_path / "MultiImpl.java" java_file.write_text("""public class MultiImpl implements Runnable, Comparable { @@ -1446,7 +1447,7 @@ class TestExtractCodeContextWithInheritance: """ ) - def test_default_interface_method(self, tmp_path: Path): + def test_default_interface_method(self, tmp_path: Path) -> None: """Test context extraction for default interface method.""" java_file = tmp_path / "MyInterface.java" java_file.write_text("""public interface MyInterface { @@ -1479,7 +1480,7 @@ class TestExtractCodeContextWithInheritance: class TestExtractCodeContextWithInnerClasses: """Tests for extract_code_context with inner/nested classes.""" - def test_static_nested_class_method(self, tmp_path: Path): + def test_static_nested_class_method(self, tmp_path: Path) -> None: """Inner class methods are excluded from discovery and cannot be context-extracted. Methods of static nested classes are skipped in discovery because they @@ -1499,7 +1500,7 @@ class TestExtractCodeContextWithInnerClasses: # Inner class method must NOT be discovered assert compute_func is None - def test_inner_class_method(self, tmp_path: Path): + def test_inner_class_method(self, tmp_path: Path) -> None: """Inner class methods are excluded from discovery and cannot be context-extracted. Methods of non-static inner classes are skipped in discovery because they @@ -1525,7 +1526,7 @@ class TestExtractCodeContextWithInnerClasses: class TestExtractCodeContextWithEnumAndInterface: """Tests for extract_code_context with enums and interfaces.""" - def test_enum_method(self, tmp_path: Path): + def test_enum_method(self, tmp_path: Path) -> None: """Test context extraction for enum method.""" java_file = tmp_path / "Operation.java" java_file.write_text("""public enum Operation { @@ -1568,7 +1569,7 @@ class TestExtractCodeContextWithEnumAndInterface: ) assert context.read_only_context == "" - def test_interface_default_method(self, tmp_path: Path): + def test_interface_default_method(self, tmp_path: Path) -> None: """Test context extraction for interface default method.""" java_file = tmp_path / "Greeting.java" java_file.write_text("""public interface Greeting { @@ -1595,7 +1596,7 @@ class TestExtractCodeContextWithEnumAndInterface: ) assert context.read_only_context == "" - def test_interface_static_method(self, tmp_path: Path): + def test_interface_static_method(self, tmp_path: Path) -> None: """Test context extraction for interface static method.""" java_file = tmp_path / "Factory.java" java_file.write_text("""public interface Factory { @@ -1626,7 +1627,7 @@ class TestExtractCodeContextWithEnumAndInterface: class TestExtractCodeContextEdgeCases: """Tests for extract_code_context edge cases.""" - def test_empty_method(self, tmp_path: Path): + def test_empty_method(self, tmp_path: Path) -> None: """Test context extraction for empty method.""" java_file = tmp_path / "Empty.java" java_file.write_text("""public class Empty { @@ -1650,7 +1651,7 @@ class TestExtractCodeContextEdgeCases: """ ) - def test_single_line_method(self, tmp_path: Path): + def test_single_line_method(self, tmp_path: Path) -> None: """Test context extraction for single-line method.""" java_file = tmp_path / "OneLiner.java" java_file.write_text("""public class OneLiner { @@ -1670,7 +1671,7 @@ class TestExtractCodeContextEdgeCases: """ ) - def test_method_with_lambda(self, tmp_path: Path): + def test_method_with_lambda(self, tmp_path: Path) -> None: """Test context extraction for method with lambda.""" java_file = tmp_path / "Functional.java" java_file.write_text("""public class Functional { @@ -1698,7 +1699,7 @@ class TestExtractCodeContextEdgeCases: """ ) - def test_method_with_method_reference(self, tmp_path: Path): + def test_method_with_method_reference(self, tmp_path: Path) -> None: """Test context extraction for method with method reference.""" java_file = tmp_path / "Printer.java" java_file.write_text("""public class Printer { @@ -1722,7 +1723,7 @@ class TestExtractCodeContextEdgeCases: """ ) - def test_deeply_nested_blocks(self, tmp_path: Path): + def test_deeply_nested_blocks(self, tmp_path: Path) -> None: """Test context extraction for method with deeply nested blocks.""" java_file = tmp_path / "Nested.java" java_file.write_text("""public class Nested { @@ -1776,7 +1777,7 @@ class TestExtractCodeContextEdgeCases: """ ) - def test_unicode_in_source(self, tmp_path: Path): + def test_unicode_in_source(self, tmp_path: Path) -> None: """Test context extraction for method with unicode characters.""" java_file = tmp_path / "Unicode.java" java_file.write_text( @@ -1803,7 +1804,7 @@ class TestExtractCodeContextEdgeCases: """ ) - def test_file_not_found(self, tmp_path: Path): + def test_file_not_found(self, tmp_path: Path) -> None: """Test context extraction for missing file.""" from codeflash.discovery.functions_to_optimize import FunctionToOptimize from codeflash.models.function_types import FunctionParent @@ -1824,7 +1825,7 @@ class TestExtractCodeContextEdgeCases: assert context.language == Language.JAVA assert context.target_file == missing_file - def test_max_helper_depth_zero(self, tmp_path: Path): + def test_max_helper_depth_zero(self, tmp_path: Path) -> None: """Test context extraction with max_helper_depth=0.""" java_file = tmp_path / "Calculator.java" java_file.write_text("""public class Calculator { @@ -1858,7 +1859,7 @@ class TestExtractCodeContextEdgeCases: class TestExtractCodeContextWithConstructor: """Tests for extract_code_context with constructors in class skeleton.""" - def test_class_with_constructor(self, tmp_path: Path): + def test_class_with_constructor(self, tmp_path: Path) -> None: """Test context extraction includes constructor in skeleton.""" java_file = tmp_path / "Person.java" java_file.write_text("""public class Person { @@ -1898,7 +1899,7 @@ class TestExtractCodeContextWithConstructor: """ ) - def test_class_with_multiple_constructors(self, tmp_path: Path): + def test_class_with_multiple_constructors(self, tmp_path: Path) -> None: """Test context extraction includes all constructors in skeleton.""" java_file = tmp_path / "Config.java" java_file.write_text("""public class Config { @@ -1956,7 +1957,7 @@ class TestExtractCodeContextWithConstructor: class TestExtractCodeContextFullIntegration: """Integration tests for extract_code_context with all components.""" - def test_full_context_with_all_components(self, tmp_path: Path): + def test_full_context_with_all_components(self, tmp_path: Path) -> None: """Test context extraction with imports, fields, and helpers.""" java_file = tmp_path / "Service.java" java_file.write_text("""package com.example; @@ -2007,7 +2008,7 @@ public class Service { assert len(context.helper_functions) == 1 assert context.helper_functions[0].name == "transform" - def test_complex_class_with_javadoc_and_annotations(self, tmp_path: Path): + def test_complex_class_with_javadoc_and_annotations(self, tmp_path: Path) -> None: """Test context extraction for complex class with javadoc and annotations.""" java_file = tmp_path / "Calculator.java" java_file.write_text("""package com.example.math; @@ -2073,7 +2074,7 @@ public class Calculator { class TestExtractClassContext: """Tests for extract_class_context.""" - def test_extract_class_with_imports(self, tmp_path: Path): + def test_extract_class_with_imports(self, tmp_path: Path) -> None: """Test extracting full class context with imports.""" java_file = tmp_path / "Calculator.java" java_file.write_text("""package com.example; @@ -2112,7 +2113,7 @@ public class Calculator { }""" ) - def test_extract_class_not_found(self, tmp_path: Path): + def test_extract_class_not_found(self, tmp_path: Path) -> None: """Test extracting non-existent class returns empty string.""" java_file = tmp_path / "Test.java" java_file.write_text("""public class Test { @@ -2124,7 +2125,7 @@ public class Calculator { assert context == "" - def test_extract_class_missing_file(self, tmp_path: Path): + def test_extract_class_missing_file(self, tmp_path: Path) -> None: """Test extracting from missing file returns empty string.""" missing_file = tmp_path / "Missing.java" @@ -2141,7 +2142,7 @@ class TestExtractFunctionSourceStaleLineNumbers: extraction should still find the correct function by name. """ - def test_extraction_with_stale_line_numbers(self): + def test_extraction_with_stale_line_numbers(self) -> None: """Verify extraction works when pre-computed line numbers no longer match the source.""" # Original source: functionA at lines 2-4, functionB at lines 5-7 original_source = """public class Utils { @@ -2177,7 +2178,7 @@ class TestExtractFunctionSourceStaleLineNumbers: assert "functionB" in result assert "return 2;" in result - def test_extraction_without_analyzer_uses_line_numbers(self): + def test_extraction_without_analyzer_uses_line_numbers(self) -> None: """Without analyzer, extraction falls back to pre-computed line numbers.""" source = """public class Utils { public int functionA() { @@ -2196,7 +2197,7 @@ class TestExtractFunctionSourceStaleLineNumbers: assert "functionB" in result assert "return 2;" in result - def test_extraction_with_javadoc_after_file_modification(self): + def test_extraction_with_javadoc_after_file_modification(self) -> None: """Verify Javadoc is included when using tree-sitter extraction on modified files.""" original_source = """public class Utils { /** Adds two numbers. */ @@ -2233,7 +2234,7 @@ class TestExtractFunctionSourceStaleLineNumbers: assert "public int subtract" in result assert "return a - b;" in result - def test_extraction_with_overloaded_methods(self): + def test_extraction_with_overloaded_methods(self) -> None: """Verify correct overload is selected using line proximity.""" source = """public class Utils { public int process(int x) { @@ -2247,13 +2248,15 @@ class TestExtractFunctionSourceStaleLineNumbers: analyzer = get_java_analyzer() functions = discover_functions_from_source(source, file_path=Path("Utils.java")) # Get the second overload (process(int, int)) - func_two_args = [f for f in functions if f.function_name == "process" and f.ending_line > 4][0] + func_two_args = [ + f for f in functions if f.function_name == "process" and f.ending_line is not None and f.ending_line > 4 + ][0] result = extract_function_source(source, func_two_args, analyzer=analyzer) assert "int x, int y" in result assert "return x + y;" in result - def test_extraction_function_not_found_falls_back(self): + def test_extraction_function_not_found_falls_back(self) -> None: """If tree-sitter can't find the method, fall back to line numbers.""" source = """public class Utils { public int functionA() { @@ -2282,7 +2285,7 @@ FIXTURE_DIR = Path(__file__).parent.parent / "fixtures" / "java_maven" class TestGetJavaImportedTypeSkeletons: """Tests for get_java_imported_type_skeletons().""" - def test_resolves_internal_imports(self): + def test_resolves_internal_imports(self) -> None: """Verify that project-internal imports are resolved and skeletons extracted.""" project_root = FIXTURE_DIR module_root = FIXTURE_DIR / "src" / "main" / "java" @@ -2297,7 +2300,7 @@ class TestGetJavaImportedTypeSkeletons: assert "MathHelper" in result assert "Formatter" in result - def test_skeletons_contain_method_signatures(self): + def test_skeletons_contain_method_signatures(self) -> None: """Verify extracted skeletons include public method signatures.""" project_root = FIXTURE_DIR module_root = FIXTURE_DIR / "src" / "main" / "java" @@ -2313,7 +2316,7 @@ class TestGetJavaImportedTypeSkeletons: assert "multiply" in result assert "factorial" in result - def test_skips_external_imports(self): + def test_skips_external_imports(self) -> None: """Verify that standard library and external imports are skipped.""" project_root = FIXTURE_DIR module_root = FIXTURE_DIR / "src" / "main" / "java" @@ -2328,7 +2331,7 @@ class TestGetJavaImportedTypeSkeletons: # No internal imports → empty result assert result == "" - def test_deduplicates_imports(self): + def test_deduplicates_imports(self) -> None: """Verify that the same type imported twice is only included once.""" project_root = FIXTURE_DIR module_root = FIXTURE_DIR / "src" / "main" / "java" @@ -2344,7 +2347,7 @@ class TestGetJavaImportedTypeSkeletons: # Count occurrences of MathHelper — should appear exactly once assert result.count("class MathHelper") == 1 - def test_empty_imports_returns_empty(self): + def test_empty_imports_returns_empty(self) -> None: """Verify that empty import list returns empty string.""" project_root = FIXTURE_DIR analyzer = get_java_analyzer() @@ -2353,7 +2356,7 @@ class TestGetJavaImportedTypeSkeletons: assert result == "" - def test_respects_token_budget(self): + def test_respects_token_budget(self) -> None: """Verify that the function stops when token budget is exceeded.""" project_root = FIXTURE_DIR module_root = FIXTURE_DIR / "src" / "main" / "java" @@ -2378,7 +2381,7 @@ class TestGetJavaImportedTypeSkeletons: class TestExtractPublicMethodSignatures: """Tests for _extract_public_method_signatures().""" - def test_extracts_public_methods(self): + def test_extracts_public_methods(self) -> None: """Verify public method signatures are extracted.""" source = """public class Foo { public int add(int a, int b) { @@ -2398,7 +2401,7 @@ class TestExtractPublicMethodSignatures: # private method should not be included assert not any("secret" in s for s in sigs) - def test_excludes_constructors(self): + def test_excludes_constructors(self) -> None: """Verify constructors are excluded from method signatures.""" source = """public class Bar { public Bar(int x) { this.x = x; } @@ -2411,7 +2414,7 @@ class TestExtractPublicMethodSignatures: assert "getX" in sigs[0] assert not any("Bar(" in s for s in sigs) - def test_empty_class_returns_empty(self): + def test_empty_class_returns_empty(self) -> None: """Verify empty class returns no signatures.""" source = """public class Empty {}""" analyzer = get_java_analyzer() @@ -2419,7 +2422,7 @@ class TestExtractPublicMethodSignatures: assert sigs == [] - def test_filters_by_class_name(self): + def test_filters_by_class_name(self) -> None: """Verify only methods from the specified class are returned.""" source = """public class A { public int aMethod() { return 1; } @@ -2440,7 +2443,7 @@ class B { class TestFormatSkeletonForContext: """Tests for _format_skeleton_for_context().""" - def test_formats_basic_skeleton(self): + def test_formats_basic_skeleton(self) -> None: """Verify basic skeleton formatting with fields and constructors.""" source = """public class Widget { private int size; @@ -2467,7 +2470,7 @@ class TestFormatSkeletonForContext: assert "getSize" in result assert result.endswith("}") - def test_formats_enum_skeleton(self): + def test_formats_enum_skeleton(self) -> None: """Verify enum formatting includes constants.""" source = """public enum Color { RED, GREEN, BLUE; @@ -2490,7 +2493,7 @@ class TestFormatSkeletonForContext: assert "RED, GREEN, BLUE;" in result assert "lower" in result - def test_formats_empty_class(self): + def test_formats_empty_class(self) -> None: """Verify formatting of a class with no fields or methods.""" source = """public class Empty {}""" analyzer = get_java_analyzer() @@ -2512,7 +2515,7 @@ class TestFormatSkeletonForContext: class TestGetJavaImportedTypeSkeletonsEdgeCases: """Additional edge case tests for get_java_imported_type_skeletons().""" - def test_wildcard_imports_are_expanded(self): + def test_wildcard_imports_are_expanded(self) -> None: """Wildcard imports (e.g., import com.example.helpers.*) are expanded to individual types.""" project_root = FIXTURE_DIR module_root = FIXTURE_DIR / "src" / "main" / "java" @@ -2530,7 +2533,7 @@ class TestGetJavaImportedTypeSkeletonsEdgeCases: # Wildcard imports should now be expanded to individual classes found in the package directory assert "MathHelper" in result - def test_large_wildcard_is_filtered_to_referenced_types(self, tmp_path: Path): + def test_large_wildcard_is_filtered_to_referenced_types(self, tmp_path: Path) -> None: """When wildcard expands to >50 types, only types referenced in target code are included.""" from codeflash.languages.java.context import MAX_WILDCARD_TYPES_UNFILTERED @@ -2560,7 +2563,7 @@ class TestGetJavaImportedTypeSkeletonsEdgeCases: # Types not referenced in target code should be excluded assert "Type050" not in result - def test_import_to_nonexistent_class_in_file(self): + def test_import_to_nonexistent_class_in_file(self) -> None: """When an import resolves to a file but the class doesn't exist in it, skeleton extraction returns None.""" analyzer = get_java_analyzer() @@ -2570,7 +2573,7 @@ class TestGetJavaImportedTypeSkeletonsEdgeCases: assert skeleton is None - def test_skeleton_output_is_well_formed(self): + def test_skeleton_output_is_well_formed(self) -> None: """Verify the skeleton string has proper Java-like structure with braces.""" project_root = FIXTURE_DIR module_root = FIXTURE_DIR / "src" / "main" / "java" @@ -2593,7 +2596,7 @@ class TestGetJavaImportedTypeSkeletonsEdgeCases: class TestExtractPublicMethodSignaturesEdgeCases: """Additional edge case tests for _extract_public_method_signatures().""" - def test_excludes_protected_and_package_private(self): + def test_excludes_protected_and_package_private(self) -> None: """Verify protected and package-private methods are excluded.""" source = """public class Visibility { public int publicMethod() { return 1; } @@ -2610,7 +2613,7 @@ class TestExtractPublicMethodSignaturesEdgeCases: assert not any("packagePrivateMethod" in s for s in sigs) assert not any("privateMethod" in s for s in sigs) - def test_handles_overloaded_methods(self): + def test_handles_overloaded_methods(self) -> None: """Verify all public overloads are extracted.""" source = """public class Overloaded { public int process(int x) { return x; } @@ -2624,7 +2627,7 @@ class TestExtractPublicMethodSignaturesEdgeCases: # All should contain "process" assert all("process" in s for s in sigs) - def test_handles_generic_methods(self): + def test_handles_generic_methods(self) -> None: """Verify generic method signatures are extracted correctly.""" source = """public class Generic { public T identity(T value) { return value; } @@ -2641,7 +2644,7 @@ class TestExtractPublicMethodSignaturesEdgeCases: class TestFormatSkeletonRoundTrip: """Tests that verify _extract_type_skeleton → _format_skeleton_for_context produces valid output.""" - def test_round_trip_produces_valid_skeleton(self): + def test_round_trip_produces_valid_skeleton(self) -> None: """Extract a real skeleton and format it — verify the output is sensible.""" source = """public class Service { private final String name; @@ -2689,7 +2692,7 @@ class TestFormatSkeletonRoundTrip: # Should end properly assert result.strip().endswith("}") - def test_round_trip_with_fixture_mathhelper(self): + def test_round_trip_with_fixture_mathhelper(self) -> None: """Round-trip test using the real MathHelper fixture file.""" source = (FIXTURE_DIR / "src" / "main" / "java" / "com" / "example" / "helpers" / "MathHelper.java").read_text() analyzer = get_java_analyzer()