fix: harden error handling and add missing future annotations

Error handling:
- Protect ast.parse() in _normalizer.py (returns original on SyntaxError)
- Protect cst.parse_module() in _replacement.py (raises ValueError)
- Narrow except Exception to OSError/SyntaxError in _discovery.py (2 sites)
- Narrow except Exception to sqlite3.Error/OSError in _data_parsers.py
- Narrow pickle except to specific unpickling errors in _data_parsers.py

Missing future annotations:
- Add from __future__ import annotations to 12 __init__.py files
This commit is contained in:
Kevin Turcios 2026-04-24 01:36:04 -05:00
parent 6b73b07d15
commit 90a46d732c
16 changed files with 68 additions and 17 deletions

View file

@ -1,5 +1,7 @@
"""Functional programming utilities: Result, Stream, compose, and more.""" """Functional programming utilities: Result, Stream, compose, and more."""
from __future__ import annotations
from .new_type import new_type from .new_type import new_type
from .result import Err, Ok, Result from .result import Err, Ok, Result
from .safe import safe, safe_method from .safe import safe, safe_method

View file

@ -1,5 +1,7 @@
"""AI service wrappers for refinement and repair.""" """AI service wrappers for refinement and repair."""
from __future__ import annotations
from ._refinement import ( from ._refinement import (
AdaptiveCandidate, AdaptiveCandidate,
AdaptiveOptimizeRequest, AdaptiveOptimizeRequest,

View file

@ -1,5 +1,7 @@
"""Code analysis, discovery, and function ranking.""" """Code analysis, discovery, and function ranking."""
from __future__ import annotations
from ._code_utils import find_preexisting_objects from ._code_utils import find_preexisting_objects
from ._discovery import discover_functions from ._discovery import discover_functions
from ._extraction import extract_function_source from ._extraction import extract_function_source

View file

@ -253,11 +253,16 @@ def inspect_top_level_functions_or_methods(
line_no: int | None = None, line_no: int | None = None,
) -> FunctionProperties | None: ) -> FunctionProperties | None:
"""Inspect whether a function/method is top-level in *file_name*.""" """Inspect whether a function/method is top-level in *file_name*."""
with file_name.open(encoding="utf8") as file: try:
try: source = file_name.read_text(encoding="utf-8")
ast_module = ast.parse(file.read()) except OSError:
except Exception: # noqa: BLE001 log.warning("Failed to read %s", file_name)
return None return None
try:
ast_module = ast.parse(source)
except SyntaxError:
log.debug("Failed to parse %s", file_name)
return None
visitor = TopLevelFunctionOrMethodVisitor( visitor = TopLevelFunctionOrMethodVisitor(
file_name=file_name, file_name=file_name,
function_or_method_name=function_or_method_name, function_or_method_name=function_or_method_name,
@ -317,8 +322,12 @@ def find_all_functions_in_file(
"""Find all optimizable functions in a Python file.""" """Find all optimizable functions in a Python file."""
try: try:
source = file_path.read_text(encoding="utf-8") source = file_path.read_text(encoding="utf-8")
except OSError:
log.warning("Failed to read %s", file_path)
return {}
try:
ast.parse(source, filename=str(file_path)) ast.parse(source, filename=str(file_path))
except Exception: # noqa: BLE001 except SyntaxError:
log.debug("Failed to parse %s", file_path) log.debug("Failed to parse %s", file_path)
return {} return {}
fns = discover_functions(source, file_path) fns = discover_functions(source, file_path)

View file

@ -190,7 +190,10 @@ def normalize_python_code(code: str, remove_docstrings: bool = True) -> str: #
Replaces local variable names with canonical forms (var_0, var_1, etc.) Replaces local variable names with canonical forms (var_0, var_1, etc.)
while preserving function names, class names, parameters, and imports. while preserving function names, class names, parameters, and imports.
""" """
tree = ast.parse(code) try:
tree = ast.parse(code)
except SyntaxError:
return code
if remove_docstrings: if remove_docstrings:
_remove_docstrings_from_ast(tree) _remove_docstrings_from_ast(tree)

View file

@ -1,5 +1,7 @@
"""Public programmatic API for codeflash-python.""" """Public programmatic API for codeflash-python."""
from __future__ import annotations
from ._config import OptimizationConfig from ._config import OptimizationConfig
from ._session import OptimizationSession, optimize_function from ._session import OptimizationSession, optimize_function

View file

@ -1,5 +1,7 @@
"""Benchmark tracing and profiling models.""" """Benchmark tracing and profiling models."""
from __future__ import annotations
from .models import BenchmarkKey, ConcurrencyMetrics, ProcessedBenchmarkInfo from .models import BenchmarkKey, ConcurrencyMetrics, ProcessedBenchmarkInfo
__all__ = [ __all__ = [

View file

@ -1,5 +1,7 @@
"""Code generation and source replacement.""" """Code generation and source replacement."""
from __future__ import annotations
from ..analysis._code_utils import find_preexisting_objects from ..analysis._code_utils import find_preexisting_objects
from ._replacement import ( from ._replacement import (
replace_function_source, replace_function_source,

View file

@ -120,16 +120,24 @@ def replace_function_source(
class_name = function.class_name class_name = function.class_name
func_name = function.function_name func_name = function.function_name
optimized_func = _find_function( try:
cst.parse_module(new_source), optimized_func = _find_function(
class_name, cst.parse_module(new_source),
func_name, class_name,
) func_name,
)
except cst.ParserSyntaxError:
msg = f"Failed to parse new_source for {function.qualified_name!r}"
raise ValueError(msg) from None
if optimized_func is None: if optimized_func is None:
msg = f"Function {function.qualified_name!r} not found in new_source" msg = f"Function {function.qualified_name!r} not found in new_source"
raise ValueError(msg) raise ValueError(msg)
original = cst.parse_module(source) try:
original = cst.parse_module(source)
except cst.ParserSyntaxError:
msg = f"Failed to parse source for {function.qualified_name!r}"
raise ValueError(msg) from None
new_body: list[cst.BaseStatement] = [] new_body: list[cst.BaseStatement] = []
for node in original.body: for node in original.body:

View file

@ -1,5 +1,7 @@
"""Context extraction for function optimization.""" """Context extraction for function optimization."""
from __future__ import annotations
from .helpers import discover_helpers from .helpers import discover_helpers
from .models import ( from .models import (
CodeContextType, CodeContextType,

View file

@ -1 +1,3 @@
"""Runtime decorators and utilities.""" """Runtime decorators and utilities."""
from __future__ import annotations

View file

@ -1 +1,3 @@
"""Pickle patching utilities for handling unpicklable objects.""" """Pickle patching utilities for handling unpicklable objects."""
from __future__ import annotations

View file

@ -1,5 +1,7 @@
"""Test discovery, file-level filtering, and Jedi-based linking.""" """Test discovery, file-level filtering, and Jedi-based linking."""
from __future__ import annotations
from .discovery import ( from .discovery import (
discover_tests_pytest, discover_tests_pytest,
discover_unit_tests, discover_unit_tests,

View file

@ -1,5 +1,7 @@
"""Test execution infrastructure.""" """Test execution infrastructure."""
from __future__ import annotations
from ._parse_results import parse_test_results from ._parse_results import parse_test_results
from ._test_runner import ( from ._test_runner import (
async_run_behavioral_tests, async_run_behavioral_tests,

View file

@ -45,15 +45,22 @@ def parse_sqlite_test_results(
" return_value, verification_type" " return_value, verification_type"
" FROM test_results" " FROM test_results"
).fetchall() ).fetchall()
except Exception: # noqa: BLE001 except sqlite3.Error:
log.warning( log.warning(
"Failed to parse test results from %s.", "Failed to read test results from %s.",
sqlite_file_path, sqlite_file_path,
exc_info=True, exc_info=True,
) )
if db is not None: if db is not None:
db.close() db.close()
return test_results return test_results
except OSError:
log.warning(
"Failed to open %s.",
sqlite_file_path,
exc_info=True,
)
return test_results
finally: finally:
if db is not None: if db is not None:
db.close() db.close()
@ -125,8 +132,8 @@ def _process_sqlite_row_inner(
try: try:
ret_val = (pickle.loads(val[7]),) # noqa: S301 ret_val = (pickle.loads(val[7]),) # noqa: S301
except Exception: # noqa: BLE001 except (pickle.UnpicklingError, EOFError, ValueError, TypeError):
log.debug( log.warning(
"Failed to deserialize return value for %s", "Failed to deserialize return value for %s",
test_function_name, test_function_name,
exc_info=True, exc_info=True,

View file

@ -1,5 +1,7 @@
"""Behavioral verification and optimization results.""" """Behavioral verification and optimization results."""
from __future__ import annotations
from ._baseline import establish_original_code_baseline from ._baseline import establish_original_code_baseline
from ._verification import compare_test_results from ._verification import compare_test_results
from .models import ( from .models import (