mirror of
https://github.com/codeflash-ai/codeflash.git
synced 2026-05-04 18:25:17 +00:00
fix typing issues
This commit is contained in:
parent
ea51f780a3
commit
a5aa75d717
5 changed files with 30 additions and 30 deletions
|
|
@ -3,8 +3,9 @@ from __future__ import annotations
|
|||
import logging
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from codeflash.languages.base import CodeContext, HelperFunction, Language
|
||||
from codeflash.languages.base import CodeContext, HelperFunction
|
||||
from codeflash.languages.golang.parser import GoAnalyzer
|
||||
from codeflash.languages.language_enum import Language
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pathlib import Path
|
||||
|
|
|
|||
|
|
@ -27,7 +27,7 @@ if TYPE_CHECKING:
|
|||
class GoFunctionOptimizer(FunctionOptimizer):
|
||||
def get_code_optimization_context(self) -> Result[CodeOptimizationContext, str]:
|
||||
from codeflash.languages import get_language_support
|
||||
from codeflash.languages.base import Language
|
||||
from codeflash.languages.language_enum import Language
|
||||
|
||||
language = Language(self.function_to_optimize.language)
|
||||
lang_support = get_language_support(language)
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ from __future__ import annotations
|
|||
import json
|
||||
import logging
|
||||
import re
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from codeflash.models.models import FunctionTestInvocation, InvocationId, TestResults
|
||||
|
||||
|
|
@ -12,6 +12,7 @@ if TYPE_CHECKING:
|
|||
from pathlib import Path
|
||||
|
||||
from codeflash.models.models import TestFiles
|
||||
from codeflash.models.test_type import TestType
|
||||
from codeflash.verification.verification_utils import TestConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -29,7 +30,7 @@ def parse_go_test_output(
|
|||
test_json_path: Path,
|
||||
test_files: TestFiles,
|
||||
test_config: TestConfig,
|
||||
run_result: subprocess.CompletedProcess | None = None,
|
||||
run_result: subprocess.CompletedProcess[str] | None = None,
|
||||
) -> TestResults:
|
||||
test_results = TestResults()
|
||||
|
||||
|
|
@ -71,10 +72,11 @@ def parse_go_test_output(
|
|||
active[test_name] = _TestIteration(test_name=test_name, package=package)
|
||||
continue
|
||||
|
||||
it = active.get(test_name)
|
||||
if it is None:
|
||||
it = _TestIteration(test_name=test_name, package=package)
|
||||
active[test_name] = it
|
||||
maybe_it = active.get(test_name)
|
||||
if maybe_it is None:
|
||||
maybe_it = _TestIteration(test_name=test_name, package=package)
|
||||
active[test_name] = maybe_it
|
||||
it = maybe_it
|
||||
|
||||
if action == "output":
|
||||
output_text = event.get("Output", "")
|
||||
|
|
@ -109,9 +111,6 @@ def parse_go_test_output(
|
|||
|
||||
test_file_path = _resolve_test_file(it.test_name, it.package, test_files, base_dir)
|
||||
test_type = _resolve_test_type(test_file_path, test_files)
|
||||
if test_type is None:
|
||||
logger.debug("Skipping test %s: could not resolve test type", it.test_name)
|
||||
continue
|
||||
|
||||
test_results.add(
|
||||
FunctionTestInvocation(
|
||||
|
|
@ -157,7 +156,7 @@ class _TestIteration:
|
|||
self.stdout: str = ""
|
||||
|
||||
|
||||
def _read_json_output(path: Path, run_result: subprocess.CompletedProcess | None) -> str:
|
||||
def _read_json_output(path: Path, run_result: subprocess.CompletedProcess[str] | None) -> str:
|
||||
try:
|
||||
content = path.read_text(encoding="utf-8")
|
||||
if content.strip():
|
||||
|
|
@ -165,15 +164,12 @@ def _read_json_output(path: Path, run_result: subprocess.CompletedProcess | None
|
|||
except Exception:
|
||||
pass
|
||||
if run_result is not None:
|
||||
stdout = run_result.stdout
|
||||
if isinstance(stdout, bytes):
|
||||
stdout = stdout.decode("utf-8", errors="replace")
|
||||
return stdout or ""
|
||||
return run_result.stdout or ""
|
||||
return ""
|
||||
|
||||
|
||||
def _parse_json_lines(content: str) -> list[dict]:
|
||||
events: list[dict] = []
|
||||
def _parse_json_lines(content: str) -> list[dict[str, Any]]:
|
||||
events: list[dict[str, Any]] = []
|
||||
for line in content.splitlines():
|
||||
line = line.strip()
|
||||
if not line:
|
||||
|
|
@ -199,7 +195,7 @@ def _resolve_test_file(test_name: str, package: str, test_files: TestFiles, base
|
|||
return base_dir / f"{test_name}.go"
|
||||
|
||||
|
||||
def _resolve_test_type(test_file_path: Path, test_files: TestFiles):
|
||||
def _resolve_test_type(test_file_path: Path, test_files: TestFiles) -> TestType:
|
||||
from codeflash.models.test_type import TestType
|
||||
|
||||
test_type = test_files.get_test_type_by_instrumented_file_path(test_file_path)
|
||||
|
|
|
|||
|
|
@ -6,7 +6,10 @@ from typing import TYPE_CHECKING
|
|||
from codeflash.languages.golang.parser import GoAnalyzer
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import tree_sitter
|
||||
|
||||
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
|
||||
from codeflash.languages.golang.parser import GoGlobalDeclaration
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -102,7 +105,7 @@ def _merge_global_var_const(optimized_code: str, original_source: str, analyzer:
|
|||
return original_source
|
||||
|
||||
orig_decls = analyzer.find_global_declarations(original_source)
|
||||
orig_names_to_decl: dict[str, object] = {}
|
||||
orig_names_to_decl: dict[str, GoGlobalDeclaration] = {}
|
||||
for decl in orig_decls:
|
||||
for name in decl.names:
|
||||
orig_names_to_decl[name] = decl
|
||||
|
|
@ -131,7 +134,7 @@ def _merge_global_var_const(optimized_code: str, original_source: str, analyzer:
|
|||
return original_source
|
||||
|
||||
|
||||
def _replace_declaration_block(source: str, orig_decl: object, new_source_code: str) -> str:
|
||||
def _replace_declaration_block(source: str, orig_decl: GoGlobalDeclaration, new_source_code: str) -> str:
|
||||
lines = source.splitlines(keepends=True)
|
||||
start = orig_decl.starting_line - 1
|
||||
end = orig_decl.ending_line
|
||||
|
|
@ -186,19 +189,19 @@ def remove_test_functions(test_source: str, functions_to_remove: list[str], anal
|
|||
return "".join(lines)
|
||||
|
||||
|
||||
def _find_doc_comment_start(node: object) -> int | None:
|
||||
prev = getattr(node, "prev_named_sibling", None)
|
||||
def _find_doc_comment_start(node: tree_sitter.Node) -> int | None:
|
||||
prev = node.prev_named_sibling
|
||||
if prev is None:
|
||||
return None
|
||||
if getattr(prev, "type", None) != "comment":
|
||||
if prev.type != "comment":
|
||||
return None
|
||||
if prev.end_point.row + 1 != node.start_point.row:
|
||||
return None
|
||||
comment_start = prev.start_point.row + 1
|
||||
comment_start: int = prev.start_point.row + 1
|
||||
current = prev
|
||||
while True:
|
||||
earlier = getattr(current, "prev_named_sibling", None)
|
||||
if earlier is None or getattr(earlier, "type", None) != "comment":
|
||||
earlier = current.prev_named_sibling
|
||||
if earlier is None or earlier.type != "comment":
|
||||
break
|
||||
if earlier.end_point.row + 1 != current.start_point.row:
|
||||
break
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@ import logging
|
|||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from codeflash.languages.base import LanguageSupport
|
||||
from codeflash.languages.golang.comparator import compare_test_results as _compare_results
|
||||
from codeflash.languages.golang.config import detect_go_project, detect_go_version
|
||||
from codeflash.languages.golang.context import extract_code_context as _extract_context
|
||||
|
|
@ -29,18 +30,17 @@ if TYPE_CHECKING:
|
|||
DependencyResolver,
|
||||
FunctionFilterCriteria,
|
||||
HelperFunction,
|
||||
InvocationId,
|
||||
ReferenceInfo,
|
||||
TestInfo,
|
||||
)
|
||||
from codeflash.models.function_types import FunctionToOptimize
|
||||
from codeflash.models.models import GeneratedTestsList
|
||||
from codeflash.models.models import GeneratedTestsList, InvocationId
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@register_language
|
||||
class GoSupport:
|
||||
class GoSupport(LanguageSupport):
|
||||
def __init__(self) -> None:
|
||||
self._analyzer = GoAnalyzer()
|
||||
self._go_version: str | None = None
|
||||
|
|
|
|||
Loading…
Reference in a new issue