fix typing issues

This commit is contained in:
ali 2026-04-28 14:46:49 +03:00
parent ea51f780a3
commit a5aa75d717
5 changed files with 30 additions and 30 deletions

View file

@ -3,8 +3,9 @@ from __future__ import annotations
import logging
from typing import TYPE_CHECKING
from codeflash.languages.base import CodeContext, HelperFunction, Language
from codeflash.languages.base import CodeContext, HelperFunction
from codeflash.languages.golang.parser import GoAnalyzer
from codeflash.languages.language_enum import Language
if TYPE_CHECKING:
from pathlib import Path

View file

@ -27,7 +27,7 @@ if TYPE_CHECKING:
class GoFunctionOptimizer(FunctionOptimizer):
def get_code_optimization_context(self) -> Result[CodeOptimizationContext, str]:
from codeflash.languages import get_language_support
from codeflash.languages.base import Language
from codeflash.languages.language_enum import Language
language = Language(self.function_to_optimize.language)
lang_support = get_language_support(language)

View file

@ -3,7 +3,7 @@ from __future__ import annotations
import json
import logging
import re
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any
from codeflash.models.models import FunctionTestInvocation, InvocationId, TestResults
@ -12,6 +12,7 @@ if TYPE_CHECKING:
from pathlib import Path
from codeflash.models.models import TestFiles
from codeflash.models.test_type import TestType
from codeflash.verification.verification_utils import TestConfig
logger = logging.getLogger(__name__)
@ -29,7 +30,7 @@ def parse_go_test_output(
test_json_path: Path,
test_files: TestFiles,
test_config: TestConfig,
run_result: subprocess.CompletedProcess | None = None,
run_result: subprocess.CompletedProcess[str] | None = None,
) -> TestResults:
test_results = TestResults()
@ -71,10 +72,11 @@ def parse_go_test_output(
active[test_name] = _TestIteration(test_name=test_name, package=package)
continue
it = active.get(test_name)
if it is None:
it = _TestIteration(test_name=test_name, package=package)
active[test_name] = it
maybe_it = active.get(test_name)
if maybe_it is None:
maybe_it = _TestIteration(test_name=test_name, package=package)
active[test_name] = maybe_it
it = maybe_it
if action == "output":
output_text = event.get("Output", "")
@ -109,9 +111,6 @@ def parse_go_test_output(
test_file_path = _resolve_test_file(it.test_name, it.package, test_files, base_dir)
test_type = _resolve_test_type(test_file_path, test_files)
if test_type is None:
logger.debug("Skipping test %s: could not resolve test type", it.test_name)
continue
test_results.add(
FunctionTestInvocation(
@ -157,7 +156,7 @@ class _TestIteration:
self.stdout: str = ""
def _read_json_output(path: Path, run_result: subprocess.CompletedProcess | None) -> str:
def _read_json_output(path: Path, run_result: subprocess.CompletedProcess[str] | None) -> str:
try:
content = path.read_text(encoding="utf-8")
if content.strip():
@ -165,15 +164,12 @@ def _read_json_output(path: Path, run_result: subprocess.CompletedProcess | None
except Exception:
pass
if run_result is not None:
stdout = run_result.stdout
if isinstance(stdout, bytes):
stdout = stdout.decode("utf-8", errors="replace")
return stdout or ""
return run_result.stdout or ""
return ""
def _parse_json_lines(content: str) -> list[dict]:
events: list[dict] = []
def _parse_json_lines(content: str) -> list[dict[str, Any]]:
events: list[dict[str, Any]] = []
for line in content.splitlines():
line = line.strip()
if not line:
@ -199,7 +195,7 @@ def _resolve_test_file(test_name: str, package: str, test_files: TestFiles, base
return base_dir / f"{test_name}.go"
def _resolve_test_type(test_file_path: Path, test_files: TestFiles):
def _resolve_test_type(test_file_path: Path, test_files: TestFiles) -> TestType:
from codeflash.models.test_type import TestType
test_type = test_files.get_test_type_by_instrumented_file_path(test_file_path)

View file

@ -6,7 +6,10 @@ from typing import TYPE_CHECKING
from codeflash.languages.golang.parser import GoAnalyzer
if TYPE_CHECKING:
import tree_sitter
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.languages.golang.parser import GoGlobalDeclaration
logger = logging.getLogger(__name__)
@ -102,7 +105,7 @@ def _merge_global_var_const(optimized_code: str, original_source: str, analyzer:
return original_source
orig_decls = analyzer.find_global_declarations(original_source)
orig_names_to_decl: dict[str, object] = {}
orig_names_to_decl: dict[str, GoGlobalDeclaration] = {}
for decl in orig_decls:
for name in decl.names:
orig_names_to_decl[name] = decl
@ -131,7 +134,7 @@ def _merge_global_var_const(optimized_code: str, original_source: str, analyzer:
return original_source
def _replace_declaration_block(source: str, orig_decl: object, new_source_code: str) -> str:
def _replace_declaration_block(source: str, orig_decl: GoGlobalDeclaration, new_source_code: str) -> str:
lines = source.splitlines(keepends=True)
start = orig_decl.starting_line - 1
end = orig_decl.ending_line
@ -186,19 +189,19 @@ def remove_test_functions(test_source: str, functions_to_remove: list[str], anal
return "".join(lines)
def _find_doc_comment_start(node: object) -> int | None:
prev = getattr(node, "prev_named_sibling", None)
def _find_doc_comment_start(node: tree_sitter.Node) -> int | None:
prev = node.prev_named_sibling
if prev is None:
return None
if getattr(prev, "type", None) != "comment":
if prev.type != "comment":
return None
if prev.end_point.row + 1 != node.start_point.row:
return None
comment_start = prev.start_point.row + 1
comment_start: int = prev.start_point.row + 1
current = prev
while True:
earlier = getattr(current, "prev_named_sibling", None)
if earlier is None or getattr(earlier, "type", None) != "comment":
earlier = current.prev_named_sibling
if earlier is None or earlier.type != "comment":
break
if earlier.end_point.row + 1 != current.start_point.row:
break

View file

@ -4,6 +4,7 @@ import logging
from pathlib import Path
from typing import TYPE_CHECKING, Any
from codeflash.languages.base import LanguageSupport
from codeflash.languages.golang.comparator import compare_test_results as _compare_results
from codeflash.languages.golang.config import detect_go_project, detect_go_version
from codeflash.languages.golang.context import extract_code_context as _extract_context
@ -29,18 +30,17 @@ if TYPE_CHECKING:
DependencyResolver,
FunctionFilterCriteria,
HelperFunction,
InvocationId,
ReferenceInfo,
TestInfo,
)
from codeflash.models.function_types import FunctionToOptimize
from codeflash.models.models import GeneratedTestsList
from codeflash.models.models import GeneratedTestsList, InvocationId
logger = logging.getLogger(__name__)
@register_language
class GoSupport:
class GoSupport(LanguageSupport):
def __init__(self) -> None:
self._analyzer = GoAnalyzer()
self._go_version: str | None = None