From a5aa75d717ab6ffacdd2bce693e36233fb563063 Mon Sep 17 00:00:00 2001 From: ali Date: Tue, 28 Apr 2026 14:46:49 +0300 Subject: [PATCH] fix typing issues --- codeflash/languages/golang/context.py | 3 +- .../languages/golang/function_optimizer.py | 2 +- codeflash/languages/golang/parse.py | 30 ++++++++----------- codeflash/languages/golang/replacement.py | 19 +++++++----- codeflash/languages/golang/support.py | 6 ++-- 5 files changed, 30 insertions(+), 30 deletions(-) diff --git a/codeflash/languages/golang/context.py b/codeflash/languages/golang/context.py index a2a608e2b..eec372e5d 100644 --- a/codeflash/languages/golang/context.py +++ b/codeflash/languages/golang/context.py @@ -3,8 +3,9 @@ from __future__ import annotations import logging from typing import TYPE_CHECKING -from codeflash.languages.base import CodeContext, HelperFunction, Language +from codeflash.languages.base import CodeContext, HelperFunction from codeflash.languages.golang.parser import GoAnalyzer +from codeflash.languages.language_enum import Language if TYPE_CHECKING: from pathlib import Path diff --git a/codeflash/languages/golang/function_optimizer.py b/codeflash/languages/golang/function_optimizer.py index 184fc0070..0b679da3b 100644 --- a/codeflash/languages/golang/function_optimizer.py +++ b/codeflash/languages/golang/function_optimizer.py @@ -27,7 +27,7 @@ if TYPE_CHECKING: class GoFunctionOptimizer(FunctionOptimizer): def get_code_optimization_context(self) -> Result[CodeOptimizationContext, str]: from codeflash.languages import get_language_support - from codeflash.languages.base import Language + from codeflash.languages.language_enum import Language language = Language(self.function_to_optimize.language) lang_support = get_language_support(language) diff --git a/codeflash/languages/golang/parse.py b/codeflash/languages/golang/parse.py index 42da7a59c..e6c0f09e7 100644 --- a/codeflash/languages/golang/parse.py +++ b/codeflash/languages/golang/parse.py @@ -3,7 +3,7 @@ from __future__ import annotations import json import logging import re -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any from codeflash.models.models import FunctionTestInvocation, InvocationId, TestResults @@ -12,6 +12,7 @@ if TYPE_CHECKING: from pathlib import Path from codeflash.models.models import TestFiles + from codeflash.models.test_type import TestType from codeflash.verification.verification_utils import TestConfig logger = logging.getLogger(__name__) @@ -29,7 +30,7 @@ def parse_go_test_output( test_json_path: Path, test_files: TestFiles, test_config: TestConfig, - run_result: subprocess.CompletedProcess | None = None, + run_result: subprocess.CompletedProcess[str] | None = None, ) -> TestResults: test_results = TestResults() @@ -71,10 +72,11 @@ def parse_go_test_output( active[test_name] = _TestIteration(test_name=test_name, package=package) continue - it = active.get(test_name) - if it is None: - it = _TestIteration(test_name=test_name, package=package) - active[test_name] = it + maybe_it = active.get(test_name) + if maybe_it is None: + maybe_it = _TestIteration(test_name=test_name, package=package) + active[test_name] = maybe_it + it = maybe_it if action == "output": output_text = event.get("Output", "") @@ -109,9 +111,6 @@ def parse_go_test_output( test_file_path = _resolve_test_file(it.test_name, it.package, test_files, base_dir) test_type = _resolve_test_type(test_file_path, test_files) - if test_type is None: - logger.debug("Skipping test %s: could not resolve test type", it.test_name) - continue test_results.add( FunctionTestInvocation( @@ -157,7 +156,7 @@ class _TestIteration: self.stdout: str = "" -def _read_json_output(path: Path, run_result: subprocess.CompletedProcess | None) -> str: +def _read_json_output(path: Path, run_result: subprocess.CompletedProcess[str] | None) -> str: try: content = path.read_text(encoding="utf-8") if content.strip(): @@ -165,15 +164,12 @@ def _read_json_output(path: Path, run_result: subprocess.CompletedProcess | None except Exception: pass if run_result is not None: - stdout = run_result.stdout - if isinstance(stdout, bytes): - stdout = stdout.decode("utf-8", errors="replace") - return stdout or "" + return run_result.stdout or "" return "" -def _parse_json_lines(content: str) -> list[dict]: - events: list[dict] = [] +def _parse_json_lines(content: str) -> list[dict[str, Any]]: + events: list[dict[str, Any]] = [] for line in content.splitlines(): line = line.strip() if not line: @@ -199,7 +195,7 @@ def _resolve_test_file(test_name: str, package: str, test_files: TestFiles, base return base_dir / f"{test_name}.go" -def _resolve_test_type(test_file_path: Path, test_files: TestFiles): +def _resolve_test_type(test_file_path: Path, test_files: TestFiles) -> TestType: from codeflash.models.test_type import TestType test_type = test_files.get_test_type_by_instrumented_file_path(test_file_path) diff --git a/codeflash/languages/golang/replacement.py b/codeflash/languages/golang/replacement.py index 68b1c811e..c6301312f 100644 --- a/codeflash/languages/golang/replacement.py +++ b/codeflash/languages/golang/replacement.py @@ -6,7 +6,10 @@ from typing import TYPE_CHECKING from codeflash.languages.golang.parser import GoAnalyzer if TYPE_CHECKING: + import tree_sitter + from codeflash.discovery.functions_to_optimize import FunctionToOptimize + from codeflash.languages.golang.parser import GoGlobalDeclaration logger = logging.getLogger(__name__) @@ -102,7 +105,7 @@ def _merge_global_var_const(optimized_code: str, original_source: str, analyzer: return original_source orig_decls = analyzer.find_global_declarations(original_source) - orig_names_to_decl: dict[str, object] = {} + orig_names_to_decl: dict[str, GoGlobalDeclaration] = {} for decl in orig_decls: for name in decl.names: orig_names_to_decl[name] = decl @@ -131,7 +134,7 @@ def _merge_global_var_const(optimized_code: str, original_source: str, analyzer: return original_source -def _replace_declaration_block(source: str, orig_decl: object, new_source_code: str) -> str: +def _replace_declaration_block(source: str, orig_decl: GoGlobalDeclaration, new_source_code: str) -> str: lines = source.splitlines(keepends=True) start = orig_decl.starting_line - 1 end = orig_decl.ending_line @@ -186,19 +189,19 @@ def remove_test_functions(test_source: str, functions_to_remove: list[str], anal return "".join(lines) -def _find_doc_comment_start(node: object) -> int | None: - prev = getattr(node, "prev_named_sibling", None) +def _find_doc_comment_start(node: tree_sitter.Node) -> int | None: + prev = node.prev_named_sibling if prev is None: return None - if getattr(prev, "type", None) != "comment": + if prev.type != "comment": return None if prev.end_point.row + 1 != node.start_point.row: return None - comment_start = prev.start_point.row + 1 + comment_start: int = prev.start_point.row + 1 current = prev while True: - earlier = getattr(current, "prev_named_sibling", None) - if earlier is None or getattr(earlier, "type", None) != "comment": + earlier = current.prev_named_sibling + if earlier is None or earlier.type != "comment": break if earlier.end_point.row + 1 != current.start_point.row: break diff --git a/codeflash/languages/golang/support.py b/codeflash/languages/golang/support.py index 0552735c8..aaa7c8e3f 100644 --- a/codeflash/languages/golang/support.py +++ b/codeflash/languages/golang/support.py @@ -4,6 +4,7 @@ import logging from pathlib import Path from typing import TYPE_CHECKING, Any +from codeflash.languages.base import LanguageSupport from codeflash.languages.golang.comparator import compare_test_results as _compare_results from codeflash.languages.golang.config import detect_go_project, detect_go_version from codeflash.languages.golang.context import extract_code_context as _extract_context @@ -29,18 +30,17 @@ if TYPE_CHECKING: DependencyResolver, FunctionFilterCriteria, HelperFunction, - InvocationId, ReferenceInfo, TestInfo, ) from codeflash.models.function_types import FunctionToOptimize - from codeflash.models.models import GeneratedTestsList + from codeflash.models.models import GeneratedTestsList, InvocationId logger = logging.getLogger(__name__) @register_language -class GoSupport: +class GoSupport(LanguageSupport): def __init__(self) -> None: self._analyzer = GoAnalyzer() self._go_version: str | None = None