Merge pull request #2064 from codeflash-ai/fix/tracer-subprocess-exit-codes
fix: check subprocess exit codes in Java tracer
This commit is contained in:
commit
67cf123929
2 changed files with 177 additions and 8 deletions
|
|
@ -18,17 +18,19 @@ logger = logging.getLogger(__name__)
|
|||
GRACEFUL_SHUTDOWN_WAIT = 5 # seconds to wait after SIGTERM before SIGKILL
|
||||
|
||||
|
||||
def _run_java_with_graceful_timeout(
|
||||
java_command: list[str], env: dict[str, str], timeout: int, stage_name: str
|
||||
) -> None:
|
||||
def _run_java_with_graceful_timeout(java_command: list[str], env: dict[str, str], timeout: int, stage_name: str) -> int:
|
||||
"""Run a Java command with graceful timeout handling.
|
||||
|
||||
Sends SIGTERM first (allowing JFR dump and shutdown hooks to run),
|
||||
then SIGKILL if the process doesn't exit within GRACEFUL_SHUTDOWN_WAIT seconds.
|
||||
|
||||
Returns the process exit code, or -1 if the process was killed due to timeout.
|
||||
"""
|
||||
if not timeout:
|
||||
subprocess.run(java_command, env=env, check=False)
|
||||
return
|
||||
result = subprocess.run(java_command, env=env, check=False)
|
||||
if result.returncode != 0:
|
||||
logger.warning("%s exited with code %d", stage_name, result.returncode)
|
||||
return result.returncode
|
||||
|
||||
import signal
|
||||
|
||||
|
|
@ -46,6 +48,11 @@ def _run_java_with_graceful_timeout(
|
|||
logger.warning("%s stage did not exit after SIGTERM, sending SIGKILL", stage_name)
|
||||
proc.kill()
|
||||
proc.wait()
|
||||
return -1
|
||||
|
||||
if proc.returncode != 0:
|
||||
logger.warning("%s exited with code %d", stage_name, proc.returncode)
|
||||
return proc.returncode
|
||||
|
||||
|
||||
# --add-opens flags needed for Kryo serialization on Java 16+
|
||||
|
|
@ -85,12 +92,23 @@ class JavaTracer:
|
|||
combined_env = self.build_combined_env(jfr_file, config_path)
|
||||
|
||||
logger.info("Running combined JFR profiling + argument capture...")
|
||||
_run_java_with_graceful_timeout(java_command, combined_env, timeout, "Combined tracing")
|
||||
exit_code = _run_java_with_graceful_timeout(java_command, combined_env, timeout, "Combined tracing")
|
||||
|
||||
if not trace_db_path.exists():
|
||||
msg = (
|
||||
f"Combined tracing failed with exit code {exit_code} — trace database was not created at "
|
||||
f"{trace_db_path}. Cannot proceed without trace data."
|
||||
)
|
||||
raise RuntimeError(msg)
|
||||
|
||||
if exit_code != 0:
|
||||
logger.warning(
|
||||
"Combined tracing exited with code %d but trace database was created — proceeding with partial data",
|
||||
exit_code,
|
||||
)
|
||||
|
||||
if not jfr_file.exists():
|
||||
logger.warning("JFR file was not created at %s", jfr_file)
|
||||
if not trace_db_path.exists():
|
||||
logger.error("Trace database was not created at %s", trace_db_path)
|
||||
|
||||
return trace_db_path, jfr_file
|
||||
|
||||
|
|
|
|||
151
tests/test_languages/test_java/test_tracer_exit_codes.py
Normal file
151
tests/test_languages/test_java/test_tracer_exit_codes.py
Normal file
|
|
@ -0,0 +1,151 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from codeflash.languages.java.tracer import JavaTracer, _run_java_with_graceful_timeout
|
||||
|
||||
|
||||
class TestRunJavaWithGracefulTimeout:
    """Checks the integer returned by ``_run_java_with_graceful_timeout``.

    Tests that pass a timeout of ``0`` stub ``subprocess.run``; tests that pass
    a timeout of ``60`` stub ``subprocess.Popen``. No real process is started.
    """

    # Patch targets for the two subprocess entry points stubbed by these tests.
    _RUN_TARGET = "codeflash.languages.java.tracer.subprocess.run"
    _POPEN_TARGET = "codeflash.languages.java.tracer.subprocess.Popen"

    @staticmethod
    def _exit_code_for(patch_target: str, stub: MagicMock, timeout: int) -> int:
        """Run the function under test while ``patch_target`` yields ``stub``.

        Returns whatever ``_run_java_with_graceful_timeout`` returns.
        """
        with patch(patch_target, return_value=stub):
            return _run_java_with_graceful_timeout(["java", "-version"], {}, timeout, "test")

    def test_returns_zero_on_success(self) -> None:
        """A zero return code from ``subprocess.run`` is passed through."""
        completed = MagicMock(returncode=0)
        assert self._exit_code_for(self._RUN_TARGET, completed, 0) == 0

    def test_returns_nonzero_on_failure(self) -> None:
        """A failing return code from ``subprocess.run`` is passed through."""
        completed = MagicMock(returncode=1)
        assert self._exit_code_for(self._RUN_TARGET, completed, 0) == 1

    def test_returns_exit_code_137_oom_kill(self) -> None:
        """A 137 return code from ``subprocess.run`` is passed through unchanged."""
        completed = MagicMock(returncode=137)
        assert self._exit_code_for(self._RUN_TARGET, completed, 0) == 137

    def test_timeout_path_returns_zero_on_success(self) -> None:
        """With a timeout set, a zero ``returncode`` on the process is returned."""
        process = MagicMock(returncode=0)
        assert self._exit_code_for(self._POPEN_TARGET, process, 60) == 0

    def test_timeout_path_returns_nonzero_on_failure(self) -> None:
        """With a timeout set, a non-zero ``returncode`` on the process is returned."""
        process = MagicMock(returncode=1)
        assert self._exit_code_for(self._POPEN_TARGET, process, 60) == 1

    def test_timeout_returns_negative_one(self) -> None:
        """A timed-out first wait followed by a successful wait yields ``-1``."""
        import subprocess

        process = MagicMock()
        # First wait() times out, SIGTERM wait succeeds
        process.wait.side_effect = [
            subprocess.TimeoutExpired(cmd="java", timeout=60),
            None,  # SIGTERM wait succeeds
        ]
        assert self._exit_code_for(self._POPEN_TARGET, process, 60) == -1

    def test_timeout_sends_sigterm_then_sigkill(self) -> None:
        """Two consecutive wait timeouts lead to one SIGTERM, one kill, and ``-1``."""
        import signal
        import subprocess

        process = MagicMock()
        # First wait() times out, SIGTERM wait also times out
        process.wait.side_effect = [
            subprocess.TimeoutExpired(cmd="java", timeout=60),
            subprocess.TimeoutExpired(cmd="java", timeout=5),
            None,
        ]
        exit_code = self._exit_code_for(self._POPEN_TARGET, process, 60)

        assert exit_code == -1
        process.send_signal.assert_called_once_with(signal.SIGTERM)
        process.kill.assert_called_once()
|
||||
|
||||
|
||||
class TestJavaTracerExitCodeHandling:
    """Checks how ``JavaTracer.trace`` reacts to the runner's exit code.

    Each test replaces ``_run_java_with_graceful_timeout`` with a fake that
    returns a chosen exit code and optionally writes the trace database file.
    """

    @staticmethod
    def _trace_with_stubbed_runner(tmp_path: Path, exit_code: int, *, create_db: bool) -> tuple[Path, Path]:
        """Call ``JavaTracer.trace`` with the Java runner and config helpers stubbed.

        Returns ``(requested_db_path, db_path_returned_by_trace)``. Any exception
        raised by ``trace`` propagates to the caller.
        """
        requested_db = (tmp_path / "trace.db").resolve()
        tracer = JavaTracer()

        def fake_runner(java_command: list[str], env: dict, timeout: int, stage_name: str) -> int:
            # Stand-in for the Java run: optionally leave a database file behind.
            if create_db:
                requested_db.write_bytes(b"fake-db")
            return exit_code

        with (
            patch("codeflash.languages.java.tracer._run_java_with_graceful_timeout", side_effect=fake_runner),
            patch.object(tracer, "build_combined_env", return_value={}),
            patch.object(tracer, "create_tracer_config", return_value=tmp_path / "config.json"),
        ):
            returned_db, _jfr_file = tracer.trace(
                java_command=["java", "-cp", ".", "Main"], trace_db_path=requested_db, packages=["com.example"]
            )
        return requested_db, returned_db

    def test_success_with_trace_db_created(self, tmp_path: Path) -> None:
        """Exit code 0 with a database present returns the requested path."""
        requested_db, returned_db = self._trace_with_stubbed_runner(tmp_path, 0, create_db=True)
        assert returned_db == requested_db

    def test_failure_without_trace_db_raises(self, tmp_path: Path) -> None:
        """Exit code 1 with no database raises ``RuntimeError`` naming the code."""
        with pytest.raises(RuntimeError, match="Combined tracing failed with exit code 1"):
            self._trace_with_stubbed_runner(tmp_path, 1, create_db=False)

    def test_nonzero_exit_with_trace_db_continues(self, tmp_path: Path) -> None:
        """Exit code 1 with a database present still returns the requested path."""
        requested_db, returned_db = self._trace_with_stubbed_runner(tmp_path, 1, create_db=True)
        assert returned_db == requested_db

    def test_timeout_without_trace_db_raises(self, tmp_path: Path) -> None:
        """Exit code -1 with no database raises ``RuntimeError`` naming the code."""
        with pytest.raises(RuntimeError, match="Combined tracing failed with exit code -1"):
            self._trace_with_stubbed_runner(tmp_path, -1, create_db=False)
|
||||
Loading…
Reference in a new issue