Merge pull request #2064 from codeflash-ai/fix/tracer-subprocess-exit-codes

fix: check subprocess exit codes in Java tracer
This commit is contained in:
mashraf-222 2026-04-21 15:35:46 +02:00 committed by GitHub
commit 67cf123929
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 177 additions and 8 deletions

View file

@ -18,17 +18,19 @@ logger = logging.getLogger(__name__)
GRACEFUL_SHUTDOWN_WAIT = 5 # seconds to wait after SIGTERM before SIGKILL GRACEFUL_SHUTDOWN_WAIT = 5 # seconds to wait after SIGTERM before SIGKILL
def _run_java_with_graceful_timeout( def _run_java_with_graceful_timeout(java_command: list[str], env: dict[str, str], timeout: int, stage_name: str) -> int:
java_command: list[str], env: dict[str, str], timeout: int, stage_name: str
) -> None:
"""Run a Java command with graceful timeout handling. """Run a Java command with graceful timeout handling.
Sends SIGTERM first (allowing JFR dump and shutdown hooks to run), Sends SIGTERM first (allowing JFR dump and shutdown hooks to run),
then SIGKILL if the process doesn't exit within GRACEFUL_SHUTDOWN_WAIT seconds. then SIGKILL if the process doesn't exit within GRACEFUL_SHUTDOWN_WAIT seconds.
Returns the process exit code, or -1 if the process was killed due to timeout.
""" """
if not timeout: if not timeout:
subprocess.run(java_command, env=env, check=False) result = subprocess.run(java_command, env=env, check=False)
return if result.returncode != 0:
logger.warning("%s exited with code %d", stage_name, result.returncode)
return result.returncode
import signal import signal
@ -46,6 +48,11 @@ def _run_java_with_graceful_timeout(
logger.warning("%s stage did not exit after SIGTERM, sending SIGKILL", stage_name) logger.warning("%s stage did not exit after SIGTERM, sending SIGKILL", stage_name)
proc.kill() proc.kill()
proc.wait() proc.wait()
return -1
if proc.returncode != 0:
logger.warning("%s exited with code %d", stage_name, proc.returncode)
return proc.returncode
# --add-opens flags needed for Kryo serialization on Java 16+ # --add-opens flags needed for Kryo serialization on Java 16+
@ -85,12 +92,23 @@ class JavaTracer:
combined_env = self.build_combined_env(jfr_file, config_path) combined_env = self.build_combined_env(jfr_file, config_path)
logger.info("Running combined JFR profiling + argument capture...") logger.info("Running combined JFR profiling + argument capture...")
_run_java_with_graceful_timeout(java_command, combined_env, timeout, "Combined tracing") exit_code = _run_java_with_graceful_timeout(java_command, combined_env, timeout, "Combined tracing")
if not trace_db_path.exists():
msg = (
f"Combined tracing failed with exit code {exit_code} — trace database was not created at "
f"{trace_db_path}. Cannot proceed without trace data."
)
raise RuntimeError(msg)
if exit_code != 0:
logger.warning(
"Combined tracing exited with code %d but trace database was created — proceeding with partial data",
exit_code,
)
if not jfr_file.exists(): if not jfr_file.exists():
logger.warning("JFR file was not created at %s", jfr_file) logger.warning("JFR file was not created at %s", jfr_file)
if not trace_db_path.exists():
logger.error("Trace database was not created at %s", trace_db_path)
return trace_db_path, jfr_file return trace_db_path, jfr_file

View file

@ -0,0 +1,151 @@
from __future__ import annotations
from typing import TYPE_CHECKING
from unittest.mock import MagicMock, patch
if TYPE_CHECKING:
from pathlib import Path
import pytest
from codeflash.languages.java.tracer import JavaTracer, _run_java_with_graceful_timeout
class TestRunJavaWithGracefulTimeout:
    """Exit-code contract of ``_run_java_with_graceful_timeout``.

    Tests that pass a timeout of ``0`` patch ``subprocess.run``; tests that pass
    ``60`` patch ``subprocess.Popen``. No real Java process is ever started.
    """

    def test_returns_zero_on_success(self) -> None:
        """A return code of 0 from ``subprocess.run`` is handed back as 0."""
        completed = MagicMock(returncode=0)
        with patch("codeflash.languages.java.tracer.subprocess.run", return_value=completed):
            exit_code = _run_java_with_graceful_timeout(["java", "-version"], {}, 0, "test")
        assert exit_code == 0

    def test_returns_nonzero_on_failure(self) -> None:
        """A return code of 1 from ``subprocess.run`` is handed back as 1."""
        completed = MagicMock(returncode=1)
        with patch("codeflash.languages.java.tracer.subprocess.run", return_value=completed):
            exit_code = _run_java_with_graceful_timeout(["java", "-version"], {}, 0, "test")
        assert exit_code == 1

    def test_returns_exit_code_137_oom_kill(self) -> None:
        """Return code 137 is passed through unchanged."""
        completed = MagicMock(returncode=137)
        with patch("codeflash.languages.java.tracer.subprocess.run", return_value=completed):
            exit_code = _run_java_with_graceful_timeout(["java", "-version"], {}, 0, "test")
        assert exit_code == 137

    def test_timeout_path_returns_zero_on_success(self) -> None:
        """With a timeout set, a process that finishes with 0 yields 0."""
        proc = MagicMock(returncode=0)
        with patch("codeflash.languages.java.tracer.subprocess.Popen", return_value=proc):
            exit_code = _run_java_with_graceful_timeout(["java", "-version"], {}, 60, "test")
        assert exit_code == 0

    def test_timeout_path_returns_nonzero_on_failure(self) -> None:
        """With a timeout set, a process that finishes with 1 yields 1."""
        proc = MagicMock(returncode=1)
        with patch("codeflash.languages.java.tracer.subprocess.Popen", return_value=proc):
            exit_code = _run_java_with_graceful_timeout(["java", "-version"], {}, 60, "test")
        assert exit_code == 1

    def test_timeout_returns_negative_one(self) -> None:
        """A timed-out process that exits on the follow-up wait yields -1."""
        import subprocess

        proc = MagicMock()
        # The first wait() raises TimeoutExpired; the second one returns normally.
        wait_outcomes = [subprocess.TimeoutExpired(cmd="java", timeout=60), None]
        proc.wait.side_effect = wait_outcomes
        with patch("codeflash.languages.java.tracer.subprocess.Popen", return_value=proc):
            exit_code = _run_java_with_graceful_timeout(["java", "-version"], {}, 60, "test")
        assert exit_code == -1

    def test_timeout_sends_sigterm_then_sigkill(self) -> None:
        """Two consecutive wait timeouts lead to one SIGTERM, one kill(), and -1."""
        import signal
        import subprocess

        proc = MagicMock()
        # wait() raises twice (initial timeout, then the post-SIGTERM wait) and
        # only the third call returns.
        wait_outcomes = [
            subprocess.TimeoutExpired(cmd="java", timeout=60),
            subprocess.TimeoutExpired(cmd="java", timeout=5),
            None,
        ]
        proc.wait.side_effect = wait_outcomes
        with patch("codeflash.languages.java.tracer.subprocess.Popen", return_value=proc):
            exit_code = _run_java_with_graceful_timeout(["java", "-version"], {}, 60, "test")
        assert exit_code == -1
        proc.send_signal.assert_called_once_with(signal.SIGTERM)
        proc.kill.assert_called_once()
class TestJavaTracerExitCodeHandling:
    """How ``JavaTracer.trace`` reacts to the runner's exit code and the trace database.

    The subprocess runner, ``build_combined_env`` and ``create_tracer_config``
    are all patched, so only the exit-code / file-existence logic is exercised.
    """

    @staticmethod
    def _trace_with_stub_runner(tmp_path: Path, exit_code: int, *, create_db: bool) -> tuple:
        """Run ``JavaTracer.trace`` against a stubbed runner.

        The stub optionally writes a placeholder trace database before returning
        ``exit_code``. Returns ``(expected_db_path, result_of_trace)``.
        """
        db_path = (tmp_path / "trace.db").resolve()
        tracer = JavaTracer()

        def stub_runner(java_command: list[str], env: dict, timeout: int, stage_name: str) -> int:
            if create_db:
                db_path.write_bytes(b"fake-db")
            return exit_code

        with (
            patch("codeflash.languages.java.tracer._run_java_with_graceful_timeout", side_effect=stub_runner),
            patch.object(tracer, "build_combined_env", return_value={}),
            patch.object(tracer, "create_tracer_config", return_value=tmp_path / "config.json"),
        ):
            outcome = tracer.trace(
                java_command=["java", "-cp", ".", "Main"], trace_db_path=db_path, packages=["com.example"]
            )
        return db_path, outcome

    def test_success_with_trace_db_created(self, tmp_path: Path) -> None:
        """Exit code 0 with a database on disk returns the database path."""
        expected_db, (trace_db, _jfr_file) = self._trace_with_stub_runner(tmp_path, 0, create_db=True)
        assert trace_db == expected_db

    def test_failure_without_trace_db_raises(self, tmp_path: Path) -> None:
        """Exit code 1 with no database raises ``RuntimeError`` naming the code."""
        with pytest.raises(RuntimeError, match="Combined tracing failed with exit code 1"):
            self._trace_with_stub_runner(tmp_path, 1, create_db=False)

    def test_nonzero_exit_with_trace_db_continues(self, tmp_path: Path) -> None:
        """Exit code 1 is tolerated when the database was still written."""
        expected_db, (trace_db, _jfr_file) = self._trace_with_stub_runner(tmp_path, 1, create_db=True)
        assert trace_db == expected_db

    def test_timeout_without_trace_db_raises(self, tmp_path: Path) -> None:
        """Exit code -1 with no database raises ``RuntimeError``."""
        with pytest.raises(RuntimeError, match="Combined tracing failed with exit code -1"):
            self._trace_with_stub_runner(tmp_path, -1, create_db=False)