From 555e5d1a1b8274e513c455a8525638028b1cdd67 Mon Sep 17 00:00:00 2001 From: mac Date: Tue, 9 Jun 2026 22:28:57 +0800 Subject: [PATCH 1/4] Graceful shutdown --- .env.example | 4 + docker-compose.yml | 1 + src/paperscout/__main__.py | 51 ++++- src/paperscout/config.py | 4 + src/paperscout/health.py | 1 + src/paperscout/monitor.py | 58 ++++- src/paperscout/scout.py | 58 ++++- src/paperscout/shutdown.py | 63 ++++++ tests/test_message_queue.py | 414 +++++++----------------------------- tests/test_monitor.py | 95 +++++++-- tests/test_shutdown.py | 62 ++++++ 11 files changed, 437 insertions(+), 374 deletions(-) create mode 100644 src/paperscout/shutdown.py create mode 100644 tests/test_shutdown.py diff --git a/.env.example b/.env.example index 1db9e7b..a78f271 100644 --- a/.env.example +++ b/.env.example @@ -73,3 +73,7 @@ CACHE_TTL_HOURS=1 # Log level for both console and file (DEBUG|INFO|WARNING|ERROR). LOG_LEVEL=INFO LOG_RETENTION_DAYS=7 + +# --- Graceful shutdown (optional) --- +# Max time to drain queued Slack messages on SIGTERM/SIGINT shutdown. +SHUTDOWN_MQ_DRAIN_TIMEOUT_SECONDS=30 diff --git a/docker-compose.yml b/docker-compose.yml index c4c7386..a945825 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -17,4 +17,5 @@ services: options: max-size: "10m" max-file: "5" + stop_grace_period: 45s restart: unless-stopped diff --git a/src/paperscout/__main__.py b/src/paperscout/__main__.py index e466fe0..48478ba 100644 --- a/src/paperscout/__main__.py +++ b/src/paperscout/__main__.py @@ -5,6 +5,7 @@ import asyncio import logging import logging.handlers +import signal import sys import threading from datetime import datetime, timezone @@ -22,6 +23,7 @@ notify_users, register_handlers, ) +from .shutdown import shutdown_services from .sources import ISOProber, WG21Index from .storage import ProbeState, UserWatchlist @@ -128,8 +130,36 @@ def _setup_logging(data_dir: Path, console_level: str = "INFO", retention_days: logging.getLogger(lib).setLevel(logging.WARNING) +def _register_shutdown_signals( + loop: asyncio.AbstractEventLoop, + shutdown_event: asyncio.Event, + shutdown_reason: list[str | None], +) -> None: + """Register SIGTERM/SIGINT handlers that set *shutdown_event*.""" + + def _on_signal(signame: str) -> None: + if shutdown_reason[0] is None: + shutdown_reason[0] = signame + shutdown_event.set() + + for sig, name in ((signal.SIGTERM, "SIGTERM"), (signal.SIGINT, "SIGINT")): + try: + loop.add_signal_handler(sig, lambda n=name: _on_signal(n)) + except NotImplementedError: + signal.signal(sig, lambda *_a, n=name: _on_signal(n)) + + async def _async_main() -> None: """Start DB, Slack app, health server, and the polling scheduler.""" + shutdown_event = asyncio.Event() + shutdown_reason: list[str | None] = [None] + health_server = None + bolt_thread = None + mq = None + app = None + + _register_shutdown_signals(asyncio.get_running_loop(), shutdown_event, shutdown_reason) + data_dir = settings.data_dir data_dir.mkdir(parents=True, exist_ok=True) @@ -217,7 +247,7 @@ def _extra_health_fields() -> dict: register_handlers(app, user_watchlist, state, paper_count_fn, launch_time) - start_health_server( + health_server = start_health_server( settings.health_port, launch_time, state, @@ -230,12 +260,27 @@ def _extra_health_fields() -> dict: target=app.start, kwargs={"port": settings.port}, daemon=True, + name="bolt", ) bolt_thread.start() enqueue_startup_status(mq, state, paper_count_fn) - await scheduler.run_forever() + try: + await scheduler.run_forever(shutdown_event) + finally: + shutdown_services( + reason=shutdown_reason[0] or "unknown", + mq=mq, + health_server=health_server, + health_thread=( + getattr(health_server, "_paperscout_thread", None) if health_server else None + ), + app=app, + bolt_thread=bolt_thread, + mq_drain_timeout=settings.shutdown_mq_drain_timeout_seconds, + thread_join_timeout=settings.shutdown_thread_join_timeout_seconds, + ) def main() -> None: @@ -244,7 +289,7 @@ def main() -> None: asyncio.run(_async_main()) except KeyboardInterrupt: log.info("=== Paperscout shutting down (KeyboardInterrupt) ===") - sys.exit(0) + sys.exit(0) if __name__ == "__main__": diff --git a/src/paperscout/config.py b/src/paperscout/config.py index 0461c6f..bc50c63 100644 --- a/src/paperscout/config.py +++ b/src/paperscout/config.py @@ -107,6 +107,10 @@ class Settings(BaseSettings): mq_circuit_breaker_cooldown_seconds: int = Field(default=60, ge=1) mq_max_size: int = Field(default=1000, ge=1) + # -- Graceful shutdown -- + shutdown_mq_drain_timeout_seconds: float = Field(default=30.0, ge=0.1) + shutdown_thread_join_timeout_seconds: float = Field(default=5.0, ge=0.1) + @model_validator(mode="after") def _require_slack_credentials_unless_testing(self) -> Settings: """Slack tokens must be set for real runs; pytest sets ``_PAPERSCOUT_TESTING=1``.""" diff --git a/src/paperscout/health.py b/src/paperscout/health.py index 18257dc..398c87a 100644 --- a/src/paperscout/health.py +++ b/src/paperscout/health.py @@ -100,6 +100,7 @@ def start_health_server( server = HTTPServer((bind_host, port), handler) thread = threading.Thread(target=server.serve_forever, daemon=True, name="health") + server._paperscout_thread = thread # noqa: SLF001 — joined during graceful shutdown thread.start() log.info("Health endpoint listening on %s:%d", bind_host, port) return server diff --git a/src/paperscout/monitor.py b/src/paperscout/monitor.py index 4a8790f..b95478c 100644 --- a/src/paperscout/monitor.py +++ b/src/paperscout/monitor.py @@ -3,6 +3,7 @@ from __future__ import annotations import asyncio +import contextlib import copy import logging import threading @@ -456,7 +457,33 @@ async def poll_once(self) -> PollResult: self._publish_health_snapshot() return result - async def run_forever(self) -> None: + async def _poll_once_or_cancel(self, shutdown_event: asyncio.Event) -> None: + """Run ``poll_once`` or cancel promptly when *shutdown_event* is set.""" + poll_task = asyncio.create_task(self.poll_once()) + shutdown_task = asyncio.create_task(shutdown_event.wait()) + try: + done, _pending = await asyncio.wait( + {poll_task, shutdown_task}, + return_when=asyncio.FIRST_COMPLETED, + ) + if shutdown_task in done: + poll_task.cancel() + await poll_task + else: + shutdown_task.cancel() + with contextlib.suppress(asyncio.CancelledError): + await shutdown_task + await poll_task + except asyncio.CancelledError: + raise + finally: + for task in (poll_task, shutdown_task): + if not task.done(): + task.cancel() + with contextlib.suppress(asyncio.CancelledError): + await task + + async def run_forever(self, shutdown_event: asyncio.Event | None = None) -> None: """Run ``poll_once`` on an interval, with overrun cooldown between cycles.""" interval = self.cfg.poll_interval_minutes * 60 cooldown = self.cfg.poll_overrun_cooldown_seconds @@ -468,10 +495,18 @@ async def run_forever(self) -> None: self.cfg.enable_bulk_wg21, ) run_started_wall = time.time() - while True: + shutdown_requested = False + while shutdown_event is None or not shutdown_event.is_set(): t0 = time.monotonic() try: - await self.poll_once() + if shutdown_event is not None: + await self._poll_once_or_cancel(shutdown_event) + else: + await self.poll_once() + except asyncio.CancelledError: + log.info("POLL-CANCELLED poll=%d reason=shutdown", self._poll_count) + shutdown_requested = True + break except ConfigurationError as exc: log.critical( "POLL-FATAL failure_category=%s poll=%d %s", @@ -514,6 +549,10 @@ async def run_forever(self) -> None: ) elapsed = time.monotonic() - t0 + if shutdown_event is not None and shutdown_event.is_set(): + shutdown_requested = True + break + if self.ops_alert_fn: alert_threshold = 2 * interval now_wall = time.time() @@ -542,4 +581,15 @@ async def run_forever(self) -> None: elapsed, interval, ) - await asyncio.sleep(sleep_for) + if shutdown_event is not None: + try: + await asyncio.wait_for(shutdown_event.wait(), timeout=sleep_for) + shutdown_requested = True + break + except asyncio.TimeoutError: + pass + else: + await asyncio.sleep(sleep_for) + + if shutdown_requested or (shutdown_event is not None and shutdown_event.is_set()): + log.info("SCHEDULER-STOP reason=shutdown_event") diff --git a/src/paperscout/scout.py b/src/paperscout/scout.py index 94f6144..c3cfbc8 100644 --- a/src/paperscout/scout.py +++ b/src/paperscout/scout.py @@ -33,6 +33,8 @@ def create_app() -> App: SLACK_MAX_TEXT = 3000 +_MQ_SENTINEL = object() + # ── Message Queue ───────────────────────────────────────────────────────────── @@ -150,9 +152,7 @@ class MessageQueue: def __init__(self, app: App): self._app = app - self._q: queue.Queue[tuple[str, str, dict]] = queue.Queue( - maxsize=settings.mq_max_size, - ) + self._q: queue.Queue[Any] = queue.Queue(maxsize=settings.mq_max_size) self._last_send: dict[str, float] = {} self._lock = threading.Lock() self._queue_lock = threading.Lock() @@ -163,13 +163,53 @@ def __init__(self, app: App): cooldown_seconds=settings.mq_circuit_breaker_cooldown_seconds, ) self._warned_high_water = False + self._stop_requested = threading.Event() + self._drain_sent_count = 0 + self._drain_sent_lock = threading.Lock() def start(self) -> None: """Start the background sender thread.""" + if self._thread is not None and self._thread.is_alive(): + return self._thread = threading.Thread(target=self._run, daemon=True, name="mq-sender") self._thread.start() log.info("MessageQueue started") + def stop(self) -> None: + """Signal the sender thread to exit after draining queued messages.""" + if self._stop_requested.is_set(): + return + with self._drain_sent_lock: + self._drain_sent_count = 0 + self._stop_requested.set() + self._put_shutdown_sentinel() + + def join(self, timeout: float | None = None) -> bool: + """Wait for the sender thread. Return True if still alive (timed out).""" + if self._thread is None: + return False + self._thread.join(timeout) + return self._thread.is_alive() + + def drain(self, timeout: float | None = None) -> int: + """Stop the queue and block until drained or *timeout* expires. + + Returns the number of messages successfully sent during drain. + """ + if not self._stop_requested.is_set(): + self.stop() + drain_timeout = ( + timeout if timeout is not None else settings.shutdown_mq_drain_timeout_seconds + ) + self.join(drain_timeout) + with self._drain_sent_lock: + return self._drain_sent_count + + def _put_shutdown_sentinel(self) -> None: + """Enqueue shutdown sentinel, bypassing circuit breaker and drop-oldest.""" + with self._queue_lock: + self._q.put(_MQ_SENTINEL) + def depth(self) -> int: """Approximate number of messages waiting to send.""" with self._queue_lock: @@ -235,10 +275,17 @@ def enqueue(self, channel: str, text: str, **kwargs) -> bool: def _run(self) -> None: while True: try: - channel, text, kwargs = self._q.get(timeout=1) + item = self._q.get(timeout=1) except queue.Empty: + if self._stop_requested.is_set(): + break continue + if item is _MQ_SENTINEL: + self._q.task_done() + break + + channel, text, kwargs = item self._throttle(channel) self._send_with_retry(channel, text, kwargs) self._q.task_done() @@ -285,6 +332,9 @@ def _send_with_retry(self, channel: str, text: str, kwargs: dict) -> None: with self._lock: self._last_send[channel] = time.monotonic() self._breaker.record_success() + if self._stop_requested.is_set(): + with self._drain_sent_lock: + self._drain_sent_count += 1 return except SlackApiError as exc: if exc.response.status_code == 429: diff --git a/src/paperscout/shutdown.py b/src/paperscout/shutdown.py new file mode 100644 index 0000000..2bb7ef2 --- /dev/null +++ b/src/paperscout/shutdown.py @@ -0,0 +1,63 @@ +"""Graceful process shutdown: drain MQ, stop HTTP servers, join worker threads.""" + +from __future__ import annotations + +import logging +import threading +from http.server import HTTPServer + +from slack_bolt import App + +from .scout import MessageQueue + +log = logging.getLogger("paperscout") + + +def stop_bolt_server(app: App) -> None: + """Stop Slack Bolt dev server (slack-bolt 1.28+ private ``_development_server`` API).""" + dev = getattr(app, "_development_server", None) + if dev is None: + return + server = getattr(dev, "_server", None) + if server is not None: + server.shutdown() + + +def _join_thread(thread: threading.Thread | None, timeout: float, label: str) -> None: + if thread is None or not thread.is_alive(): + return + thread.join(timeout) + if thread.is_alive(): + log.warning("shutdown: %s thread did not exit within %.1fs", label, timeout) + + +def shutdown_services( + *, + reason: str, + mq: MessageQueue | None, + health_server: HTTPServer | None, + health_thread: threading.Thread | None, + app: App | None, + bolt_thread: threading.Thread | None, + mq_drain_timeout: float, + thread_join_timeout: float, +) -> int: + """Ordered teardown. Returns the number of messages drained from the queue.""" + drained = 0 + if mq is not None: + drained = mq.drain(timeout=mq_drain_timeout) + + if health_server is not None: + health_server.shutdown() + _join_thread(health_thread, thread_join_timeout, "health") + + if app is not None: + stop_bolt_server(app) + _join_thread(bolt_thread, thread_join_timeout, "bolt") + + log.info( + "=== Paperscout shutting down (%s) — drained %d queued message(s) ===", + reason, + drained, + ) + return drained diff --git a/tests/test_message_queue.py b/tests/test_message_queue.py index 5d9e144..139a21b 100644 --- a/tests/test_message_queue.py +++ b/tests/test_message_queue.py @@ -1,358 +1,92 @@ -"""Tests for paperscout.scout.MessageQueue (Slack chat.postMessage worker).""" +"""Tests for MessageQueue graceful shutdown.""" from __future__ import annotations -import logging -import queue import threading import time -from unittest.mock import MagicMock, patch +from unittest.mock import MagicMock -import pytest from slack_sdk.errors import SlackApiError -import paperscout.config as cfg from paperscout.scout import CircuitState, MessageQueue -def _slack_error(status: int, headers: dict | None = None) -> SlackApiError: - resp = MagicMock() - resp.status_code = status - resp.headers = headers if headers is not None else {} - return SlackApiError("slack error", resp) +def _make_mq() -> MessageQueue: + app = MagicMock() + app.client.chat_postMessage = MagicMock() + return MessageQueue(app) -@pytest.fixture() -def mq_settings(monkeypatch): - """Fast, small queue/breaker settings for tests.""" - monkeypatch.setattr(cfg.settings, "mq_max_retries", 3) - monkeypatch.setattr(cfg.settings, "mq_circuit_breaker_threshold", 2) - monkeypatch.setattr(cfg.settings, "mq_circuit_breaker_cooldown_seconds", 10) - monkeypatch.setattr(cfg.settings, "mq_max_size", 5) - - -class TestMessageQueueDirect: - """Exercise ``_throttle`` / ``_send_with_retry`` without starting the daemon thread.""" - - def test_health_fields_reports_depth_and_utilization(self): - mq = MessageQueue(MagicMock()) - mq.enqueue("C1", "x") - fields = mq.health_fields() - assert fields["mq_depth"] == 1 - assert fields["mq_max_size"] >= 1 - assert 0.0 <= fields["mq_utilization"] <= 1.0 - assert fields["mq_circuit_state"] == "closed" - - def test_health_fields_clamps_utilization_when_depth_exceeds_max(self): - mq = MessageQueue(MagicMock()) - with patch("paperscout.scout.settings") as cfg: - cfg.mq_max_size = 2 - for i in range(5): - mq.enqueue(f"C{i}", "x") - fields = mq.health_fields() - assert fields["mq_depth"] == 5 - assert fields["mq_max_size"] == 2 - assert fields["mq_utilization"] == 1.0 - - def test_send_success_updates_last_send(self): - app = MagicMock() - mq = MessageQueue(app) - with patch.object(mq, "_throttle"): - mq._send_with_retry("C1", "hello", {}) - app.client.chat_postMessage.assert_called_once_with( - channel="C1", - text="hello", - unfurl_links=False, - unfurl_media=False, - ) - - def test_send_forwards_extra_kwargs(self): - app = MagicMock() - mq = MessageQueue(app) - with patch.object(mq, "_throttle"): - mq._send_with_retry("C1", "x", {"thread_ts": "99.9"}) - app.client.chat_postMessage.assert_called_once_with( - channel="C1", - text="x", - unfurl_links=False, - unfurl_media=False, - thread_ts="99.9", - ) - - def test_429_retries_then_success(self): - app = MagicMock() - app.client.chat_postMessage.side_effect = [ - _slack_error(429, {"Retry-After": "2"}), - None, - ] - mq = MessageQueue(app) - sleeps: list[float] = [] - - with patch.object(mq, "_throttle"): - with patch("paperscout.scout.time.sleep", side_effect=sleeps.append): - mq._send_with_retry("C1", "hi", {}) - - assert app.client.chat_postMessage.call_count == 2 - assert sleeps == [2.0] - - def test_429_default_retry_after_when_header_missing(self): - app = MagicMock() - app.client.chat_postMessage.side_effect = [ - _slack_error(429, {}), - None, - ] - mq = MessageQueue(app) - sleeps: list[float] = [] - - with patch.object(mq, "_throttle"): - with patch("paperscout.scout.time.sleep", side_effect=sleeps.append): - mq._send_with_retry("C1", "hi", {}) - - assert sleeps == [5.0] - - def test_429_retry_cap_exhaustion_dead_letters(self, mq_settings, caplog): - app = MagicMock() - app.client.chat_postMessage.side_effect = _slack_error(429, {"Retry-After": "1"}) - mq = MessageQueue(app) - - with patch.object(mq, "_throttle"): - with patch("paperscout.scout.time.sleep"): - with caplog.at_level(logging.ERROR): - mq._send_with_retry("C1", "stuck message", {}) - - assert app.client.chat_postMessage.call_count == cfg.settings.mq_max_retries + 1 - assert any("MQ-DEAD-LETTER" in r.message for r in caplog.records) - assert any("retry_exhausted" in r.message for r in caplog.records) - - def test_circuit_breaker_trips_after_consecutive_failures(self, mq_settings, caplog): - app = MagicMock() - app.client.chat_postMessage.side_effect = _slack_error(500) - mq = MessageQueue(app) - - with patch.object(mq, "_throttle"): - with caplog.at_level(logging.ERROR): - mq._send_with_retry("C1", "a", {}) - mq._send_with_retry("C1", "b", {}) - - assert mq._breaker.state == CircuitState.OPEN - assert any("MQ-CIRCUIT-OPEN" in r.message for r in caplog.records) - - with patch.object(mq, "_throttle"): - with caplog.at_level(logging.ERROR): - mq._send_with_retry("C1", "c", {}) - - assert app.client.chat_postMessage.call_count == 2 - assert any("circuit_open" in r.message for r in caplog.records) - - def test_circuit_breaker_half_open_recovery(self, mq_settings, caplog): - app = MagicMock() - mq = MessageQueue(app) - with mq._breaker._lock: - mq._breaker._state = CircuitState.OPEN - mq._breaker._opened_at = 1000.0 - mq._breaker._consecutive_failures = cfg.settings.mq_circuit_breaker_threshold - - mono = [1000.0] - - def fake_monotonic(): - return mono[0] - - with patch.object(mq, "_throttle"): - with patch("paperscout.scout.time.monotonic", side_effect=fake_monotonic): - with caplog.at_level(logging.INFO): - mq._send_with_retry("C1", "blocked", {}) - assert mq._breaker.state == CircuitState.OPEN - - mono[0] = 1011.0 - app.client.chat_postMessage.side_effect = None - mq._send_with_retry("C1", "probe ok", {}) - - assert mq._breaker.state == CircuitState.CLOSED - assert any("MQ-CIRCUIT-HALF-OPEN" in r.message for r in caplog.records) - - def test_circuit_breaker_half_open_failure_reopens(self, mq_settings): - app = MagicMock() - app.client.chat_postMessage.side_effect = _slack_error(500) - mq = MessageQueue(app) - with mq._breaker._lock: - mq._breaker._state = CircuitState.HALF_OPEN - - with patch.object(mq, "_throttle"): - mq._send_with_retry("C1", "fail probe", {}) - - assert mq._breaker.state == CircuitState.OPEN - - def test_non_429_slack_error_stops(self): - app = MagicMock() - app.client.chat_postMessage.side_effect = _slack_error(500) - mq = MessageQueue(app) - - with patch.object(mq, "_throttle"): - mq._send_with_retry("C1", "hi", {}) - - assert app.client.chat_postMessage.call_count == 1 - - def test_generic_exception_stops(self): - app = MagicMock() - app.client.chat_postMessage.side_effect = RuntimeError("network down") - mq = MessageQueue(app) - - with patch.object(mq, "_throttle"): - mq._send_with_retry("C1", "hi", {}) - - assert app.client.chat_postMessage.call_count == 1 - - def test_throttle_sleeps_when_within_one_second(self): - app = MagicMock() - mq = MessageQueue(app) - mq._last_send["C1"] = 1000.0 - - sleeps: list[float] = [] - - with patch("paperscout.scout.time.monotonic", return_value=1000.4): - with patch("paperscout.scout.time.sleep", side_effect=sleeps.append): - mq._throttle("C1") - - assert len(sleeps) == 1 - assert sleeps[0] == pytest.approx(0.6, rel=1e-3) - - def test_throttle_no_sleep_when_idle(self): - app = MagicMock() - mq = MessageQueue(app) - mq._last_send["C1"] = 0.0 - - sleeps: list[float] = [] - - with patch("paperscout.scout.time.monotonic", return_value=5000.0): - with patch("paperscout.scout.time.sleep", side_effect=sleeps.append): - mq._throttle("C1") - - assert sleeps == [] - - -class TestMessageQueueBounded: - def test_enqueue_normal_returns_true(self, mq_settings): - app = MagicMock() - mq = MessageQueue(app) - assert mq.enqueue("C1", "hello") is True - assert mq.depth() == 1 - - def test_enqueue_respects_max_size_drop_oldest(self, mq_settings, caplog): - app = MagicMock() - mq = MessageQueue(app) - for i in range(cfg.settings.mq_max_size): - assert mq.enqueue("C", f"msg-{i}") is True - - with caplog.at_level(logging.WARNING): - assert mq.enqueue("C", "newest") is True - - assert mq.depth() == cfg.settings.mq_max_size - assert any("drop-oldest" in r.message for r in caplog.records) - - with mq._queue_lock: - items = [] - while True: - try: - items.append(mq._q.get_nowait()) - except queue.Empty: - break - texts = [t for _, t, _ in items] - assert "msg-0" not in texts - assert "newest" in texts - - def test_enqueue_retries_put_when_get_nowait_empty_after_full(self, mq_settings): - """Full then Empty on drop path must retry put, not silently discard the new item.""" - mq = MessageQueue(MagicMock()) - real_put = mq._q.put_nowait - put_attempts = 0 - - def put_side_effect(item): - nonlocal put_attempts - put_attempts += 1 - if put_attempts == 1: - raise queue.Full - return real_put(item) - - with patch.object(mq._q, "put_nowait", side_effect=put_side_effect): - with patch.object(mq._q, "get_nowait", side_effect=queue.Empty): - assert mq.enqueue("C", "new-item") is True - - assert put_attempts == 2 - assert mq.depth() == 1 - with mq._queue_lock: - _, text, _ = mq._q.get_nowait() - assert text == "new-item" - - def test_enqueue_rejected_when_circuit_open(self, mq_settings, caplog): - app = MagicMock() - mq = MessageQueue(app) - with mq._breaker._lock: - mq._breaker._state = CircuitState.OPEN - mq._breaker._opened_at = time.monotonic() - - with caplog.at_level(logging.WARNING): - assert mq.enqueue("C1", "blocked") is False - - assert mq.depth() == 0 - assert any("enqueue-rejected" in r.message for r in caplog.records) - - def test_enqueue_accepts_after_cooldown_expires(self, mq_settings): - mq = MessageQueue(MagicMock()) - with mq._breaker._lock: - mq._breaker._state = CircuitState.OPEN - mq._breaker._opened_at = 1000.0 - - with patch("paperscout.scout.time.monotonic", return_value=1011.0): - assert mq.enqueue("C1", "after cooldown") is True - assert mq.depth() == 1 - - def test_health_fields_circuit_state_open_after_trip(self, mq_settings, caplog): - app = MagicMock() - app.client.chat_postMessage.side_effect = _slack_error(500) - mq = MessageQueue(app) - with patch.object(mq, "_throttle"): - with caplog.at_level(logging.ERROR): - mq._send_with_retry("C1", "a", {}) - mq._send_with_retry("C1", "b", {}) - assert mq.health_fields()["mq_circuit_state"] == "open" - - def test_health_fields_reports_depth_and_utilization(self, mq_settings): - app = MagicMock() - mq = MessageQueue(app) - mq.enqueue("C1", "a") - mq.enqueue("C1", "b") - fields = mq.health_fields() - assert fields["mq_depth"] == 2 - assert fields["mq_max_size"] == cfg.settings.mq_max_size - assert fields["mq_utilization"] == pytest.approx(2 / cfg.settings.mq_max_size, rel=1e-3) - assert fields["mq_circuit_state"] == "closed" - - def test_high_water_warning_at_80_percent(self, monkeypatch, caplog): - monkeypatch.setattr(cfg.settings, "mq_max_size", 10) - app = MagicMock() - mq = MessageQueue(app) - threshold = int(0.8 * cfg.settings.mq_max_size) - for i in range(threshold - 1): - mq.enqueue("C", f"m{i}") - - with caplog.at_level(logging.WARNING): - mq.enqueue("C", "tip-over") - - assert any("high-water" in r.message for r in caplog.records) - +class TestMessageQueueShutdown: + def test_stop_drains_pending_messages(self): + mq = _make_mq() + mq.start() + for i in range(3): + mq.enqueue("C1", f"msg-{i}") + drained = mq.drain(timeout=5.0) + assert drained == 3 + assert mq._app.client.chat_postMessage.call_count == 3 + assert not mq.join(timeout=0.1) + + def test_stop_is_idempotent(self): + mq = _make_mq() + mq.start() + mq.enqueue("C1", "hello") + mq.stop() + mq.stop() + drained = mq.drain(timeout=5.0) + assert drained == 1 + assert not mq.join(timeout=0.1) + + def test_drain_counts_only_successful_sends(self): + mq = _make_mq() + response = MagicMock() + response.status_code = 500 + err = SlackApiError("fail", response) + + def side_effect(**_kwargs): + if mq._app.client.chat_postMessage.call_count == 1: + raise err + return MagicMock() + + mq._app.client.chat_postMessage.side_effect = side_effect + mq.start() + mq.enqueue("C1", "fail") + mq.enqueue("C1", "ok") + drained = mq.drain(timeout=5.0) + assert drained == 1 -class TestMessageQueueThreaded: - def test_enqueue_processed_by_background_thread(self): - app = MagicMock() - mq = MessageQueue(app) - done = threading.Event() + def test_drain_times_out(self): + mq = _make_mq() - def side_effect(**kwargs): - done.set() + def slow_send(**_kwargs): + time.sleep(2.0) - app.client.chat_postMessage.side_effect = side_effect + mq._app.client.chat_postMessage.side_effect = slow_send + mq.start() + mq.enqueue("C1", "slow") + assert mq.join(timeout=0.1) + drained = mq.drain(timeout=0.1) + assert drained == 0 + mq.join(timeout=5.0) + assert not mq.join(timeout=0.1) + + def test_start_guard_no_double_thread(self): + mq = _make_mq() + mq.start() + first = mq._thread + mq.start() + assert mq._thread is first + assert threading.active_count() >= 1 + mq.drain(timeout=2.0) + def test_stop_bypasses_open_circuit(self): + mq = _make_mq() mq.start() - assert mq.enqueue("D123", "queued message") is True - assert done.wait(timeout=5.0), "chat_postMessage was not invoked in time" - app.client.chat_postMessage.assert_called() + for _ in range(mq._breaker._threshold): + mq._breaker.record_failure() + assert mq._breaker.state == CircuitState.OPEN + assert not mq.enqueue("C1", "blocked") + mq.stop() + assert mq.drain(timeout=2.0) == 0 + assert not mq.join(timeout=0.1) diff --git a/tests/test_monitor.py b/tests/test_monitor.py index f83edf8..6f9e12d 100644 --- a/tests/test_monitor.py +++ b/tests/test_monitor.py @@ -3,6 +3,7 @@ from __future__ import annotations import asyncio +import contextlib import logging from datetime import datetime, timedelta, timezone from unittest.mock import AsyncMock, MagicMock, patch @@ -457,23 +458,26 @@ async def fake_run_cycle(): assert seed_result.probe_hits == [hit] assert state.is_discovered(hit.url) - async def test_run_forever_calls_poll_and_breaks_on_cancel(self, fake_pool): + async def test_run_forever_calls_poll_and_breaks_on_shutdown(self, fake_pool): scheduler, _, _, _, _ = _make_scheduler(fake_pool) + shutdown_event = asyncio.Event() call_count = 0 async def mock_poll_once(): nonlocal call_count call_count += 1 - raise asyncio.CancelledError() + shutdown_event.set() scheduler.poll_once = mock_poll_once with patch("asyncio.sleep", AsyncMock()): - with pytest.raises(asyncio.CancelledError): - await scheduler.run_forever() + await scheduler.run_forever(shutdown_event) assert call_count == 1 async def test_run_forever_continues_after_poll_exception(self, fake_pool): - scheduler, _, _, _, _ = _make_scheduler(fake_pool, poll_interval_minutes=0) + scheduler, _, _, _, _ = _make_scheduler( + fake_pool, poll_interval_minutes=0, poll_overrun_cooldown_seconds=0 + ) + shutdown_event = asyncio.Event() call_count = 0 async def mock_poll_once(): @@ -481,16 +485,18 @@ async def mock_poll_once(): call_count += 1 if call_count == 1: raise RuntimeError("poll failed") - raise asyncio.CancelledError() + shutdown_event.set() scheduler.poll_once = mock_poll_once - with patch("asyncio.sleep", AsyncMock()): - with pytest.raises(asyncio.CancelledError): - await scheduler.run_forever() + with patch("asyncio.wait_for", AsyncMock(side_effect=asyncio.TimeoutError)): + await scheduler.run_forever(shutdown_event) assert call_count == 2 async def test_run_forever_emits_timeout_failure_category(self, fake_pool, caplog): - scheduler, _, _, _, _ = _make_scheduler(fake_pool, poll_interval_minutes=0) + scheduler, _, _, _, _ = _make_scheduler( + fake_pool, poll_interval_minutes=0, poll_overrun_cooldown_seconds=0 + ) + shutdown_event = asyncio.Event() call_count = 0 async def mock_poll_once(): @@ -498,18 +504,20 @@ async def mock_poll_once(): call_count += 1 if call_count == 1: raise httpx.TimeoutException("boom", request=MagicMock()) - raise asyncio.CancelledError() + shutdown_event.set() scheduler.poll_once = mock_poll_once with caplog.at_level(logging.ERROR, logger="paperscout.monitor"): - with patch("asyncio.sleep", AsyncMock()): - with pytest.raises(asyncio.CancelledError): - await scheduler.run_forever() + with patch("asyncio.wait_for", AsyncMock(side_effect=asyncio.TimeoutError)): + await scheduler.run_forever(shutdown_event) assert "failure_category=TIMEOUT" in caplog.text assert call_count == 2 async def test_run_forever_emits_network_failure_category(self, fake_pool, caplog): - scheduler, _, _, _, _ = _make_scheduler(fake_pool, poll_interval_minutes=0) + scheduler, _, _, _, _ = _make_scheduler( + fake_pool, poll_interval_minutes=0, poll_overrun_cooldown_seconds=0 + ) + shutdown_event = asyncio.Event() call_count = 0 async def mock_poll_once(): @@ -517,16 +525,59 @@ async def mock_poll_once(): call_count += 1 if call_count == 1: raise httpx.ConnectError("no route", request=MagicMock()) - raise asyncio.CancelledError() + shutdown_event.set() scheduler.poll_once = mock_poll_once with caplog.at_level(logging.ERROR, logger="paperscout.monitor"): - with patch("asyncio.sleep", AsyncMock()): - with pytest.raises(asyncio.CancelledError): - await scheduler.run_forever() + with patch("asyncio.wait_for", AsyncMock(side_effect=asyncio.TimeoutError)): + await scheduler.run_forever(shutdown_event) assert "failure_category=NETWORK" in caplog.text assert call_count == 2 + async def test_run_forever_exits_on_shutdown_event_during_sleep(self, fake_pool): + scheduler, _, _, _, _ = _make_scheduler(fake_pool, poll_interval_minutes=30) + shutdown_event = asyncio.Event() + + async def mock_poll_once(): + shutdown_event.set() + + scheduler.poll_once = mock_poll_once + with patch("asyncio.sleep", AsyncMock()) as sleep_m: + await scheduler.run_forever(shutdown_event) + sleep_m.assert_not_called() + + async def test_run_forever_exits_when_event_set_before_first_poll(self, fake_pool): + scheduler, _, _, _, _ = _make_scheduler(fake_pool) + shutdown_event = asyncio.Event() + shutdown_event.set() + scheduler.poll_once = AsyncMock() + await scheduler.run_forever(shutdown_event) + scheduler.poll_once.assert_not_called() + + async def test_run_forever_cancels_in_flight_poll(self, fake_pool, caplog): + scheduler, _, _, _, _ = _make_scheduler(fake_pool) + shutdown_event = asyncio.Event() + poll_started = asyncio.Event() + + async def slow_poll_once(): + poll_started.set() + await asyncio.Event().wait() + + async def request_shutdown(): + await poll_started.wait() + shutdown_event.set() + + scheduler.poll_once = slow_poll_once + stopper = asyncio.create_task(request_shutdown()) + try: + with caplog.at_level(logging.INFO, logger="paperscout.monitor"): + await scheduler.run_forever(shutdown_event) + finally: + stopper.cancel() + with contextlib.suppress(asyncio.CancelledError): + await stopper + assert "POLL-CANCELLED" in caplog.text + async def test_failed_probe_cycle_does_not_advance_last_successful_poll_normal_path( self, fake_pool ): @@ -613,8 +664,7 @@ async def capture_sleep(duration: float): mock_time.monotonic.side_effect = [0.0, 360.0, 0.0] scheduler.poll_once = mock_poll_once with patch("asyncio.sleep", capture_sleep): - with pytest.raises(asyncio.CancelledError): - await scheduler.run_forever() + await scheduler.run_forever() assert len(slept) == 1 assert slept[0] == pytest.approx(1440.0) @@ -639,8 +689,7 @@ async def capture_sleep(duration: float): mock_time.monotonic.side_effect = [0.0, 2000.0, 0.0] scheduler.poll_once = mock_poll_once with patch("asyncio.sleep", capture_sleep): - with pytest.raises(asyncio.CancelledError): - await scheduler.run_forever() + await scheduler.run_forever() assert len(slept) == 1 assert slept[0] == pytest.approx(300.0) diff --git a/tests/test_shutdown.py b/tests/test_shutdown.py new file mode 100644 index 0000000..c263454 --- /dev/null +++ b/tests/test_shutdown.py @@ -0,0 +1,62 @@ +"""Tests for paperscout.shutdown.""" + +from __future__ import annotations + +import logging +from http.server import HTTPServer +from unittest.mock import MagicMock + +from paperscout.shutdown import shutdown_services, stop_bolt_server + + +class TestShutdownServices: + def test_shutdown_services_drains_mq_and_logs(self, caplog): + mq = MagicMock() + mq.drain.return_value = 2 + with caplog.at_level(logging.INFO, logger="paperscout"): + drained = shutdown_services( + reason="SIGTERM", + mq=mq, + health_server=None, + health_thread=None, + app=None, + bolt_thread=None, + mq_drain_timeout=30.0, + thread_join_timeout=5.0, + ) + assert drained == 2 + mq.drain.assert_called_once_with(timeout=30.0) + assert any("SIGTERM" in r.message and "drained 2" in r.message for r in caplog.records) + + def test_shutdown_services_skips_none_handles(self): + shutdown_services( + reason="unknown", + mq=None, + health_server=None, + health_thread=None, + app=None, + bolt_thread=None, + mq_drain_timeout=30.0, + thread_join_timeout=5.0, + ) + + def test_stop_bolt_server_calls_shutdown(self): + app = MagicMock() + server = MagicMock() + app._development_server = MagicMock(_server=server) + stop_bolt_server(app) + server.shutdown.assert_called_once() + + def test_shutdown_services_stops_health_server(self): + health_server = MagicMock(spec=HTTPServer) + shutdown_services( + reason="SIGINT", + mq=None, + health_server=health_server, + health_thread=None, + app=None, + bolt_thread=None, + mq_drain_timeout=30.0, + thread_join_timeout=5.0, + ) + health_server.shutdown.assert_called_once() From 1b29f9209350a710e7338819782eea5219dd696f Mon Sep 17 00:00:00 2001 From: mac Date: Wed, 10 Jun 2026 04:45:10 +0800 Subject: [PATCH 2/4] addressed ai reviews --- .env.example | 2 ++ src/paperscout/__main__.py | 1 + src/paperscout/monitor.py | 2 ++ src/paperscout/scout.py | 8 +++-- src/paperscout/shutdown.py | 37 ++++++++++++++++++----- tests/test_message_queue.py | 47 ++++++++++++++++++++++++++++++ tests/test_monitor.py | 58 +++++++++++++++++++++++++++---------- tests/test_shutdown.py | 19 ++++++++++++ 8 files changed, 150 insertions(+), 24 deletions(-) diff --git a/.env.example b/.env.example index a78f271..eb53db1 100644 --- a/.env.example +++ b/.env.example @@ -77,3 +77,5 @@ LOG_RETENTION_DAYS=7 # --- Graceful shutdown (optional) --- # Max time to drain queued Slack messages on SIGTERM/SIGINT shutdown. SHUTDOWN_MQ_DRAIN_TIMEOUT_SECONDS=30 +# Max time to wait for server threads (health, Bolt) to exit during shutdown. +SHUTDOWN_THREAD_JOIN_TIMEOUT_SECONDS=5 diff --git a/src/paperscout/__main__.py b/src/paperscout/__main__.py index 48478ba..75635a4 100644 --- a/src/paperscout/__main__.py +++ b/src/paperscout/__main__.py @@ -138,6 +138,7 @@ def _register_shutdown_signals( """Register SIGTERM/SIGINT handlers that set *shutdown_event*.""" def _on_signal(signame: str) -> None: + """Record the first shutdown signal and wake the scheduler.""" if shutdown_reason[0] is None: shutdown_reason[0] = signame shutdown_event.set() diff --git a/src/paperscout/monitor.py b/src/paperscout/monitor.py index b95478c..da23019 100644 --- a/src/paperscout/monitor.py +++ b/src/paperscout/monitor.py @@ -504,6 +504,8 @@ async def run_forever(self, shutdown_event: asyncio.Event | None = None) -> None else: await self.poll_once() except asyncio.CancelledError: + if shutdown_event is None or not shutdown_event.is_set(): + raise log.info("POLL-CANCELLED poll=%d reason=shutdown", self._poll_count) shutdown_requested = True break diff --git a/src/paperscout/scout.py b/src/paperscout/scout.py index c3cfbc8..d11ccdb 100644 --- a/src/paperscout/scout.py +++ b/src/paperscout/scout.py @@ -208,7 +208,10 @@ def drain(self, timeout: float | None = None) -> int: def _put_shutdown_sentinel(self) -> None: """Enqueue shutdown sentinel, bypassing circuit breaker and drop-oldest.""" with self._queue_lock: - self._q.put(_MQ_SENTINEL) + try: + self._q.put_nowait(_MQ_SENTINEL) + except queue.Full: + pass # _stop_requested still lets _run() exit via queue.Empty + flag def depth(self) -> int: """Approximate number of messages waiting to send.""" @@ -273,6 +276,7 @@ def enqueue(self, channel: str, text: str, **kwargs) -> bool: return True def _run(self) -> None: + """Background sender loop; exits on sentinel or when stop is requested.""" while True: try: item = self._q.get(timeout=1) @@ -315,7 +319,7 @@ def _dead_letter( ) def _send_with_retry(self, channel: str, text: str, kwargs: dict) -> None: - if not self._breaker.allow_send(): + if not self._stop_requested.is_set() and not self._breaker.allow_send(): self._dead_letter(channel, text, reason="circuit_open", kwargs=kwargs) return diff --git a/src/paperscout/shutdown.py b/src/paperscout/shutdown.py index 2bb7ef2..b3646b3 100644 --- a/src/paperscout/shutdown.py +++ b/src/paperscout/shutdown.py @@ -14,16 +14,24 @@ def stop_bolt_server(app: App) -> None: - """Stop Slack Bolt dev server (slack-bolt 1.28+ private ``_development_server`` API).""" + """Stop Slack Bolt HTTP dev server started via ``app.start()``. + + Bolt has no public graceful-shutdown API for the dev server; this uses the + private ``_development_server._server`` handle (slack-bolt pinned in uv.lock). + """ dev = getattr(app, "_development_server", None) if dev is None: return server = getattr(dev, "_server", None) if server is not None: - server.shutdown() + try: + server.shutdown() + except Exception: + log.exception("shutdown: bolt server shutdown failed") def _join_thread(thread: threading.Thread | None, timeout: float, label: str) -> None: + """Wait for *thread* to finish; log a warning if it exceeds *timeout*.""" if thread is None or not thread.is_alive(): return thread.join(timeout) @@ -45,15 +53,30 @@ def shutdown_services( """Ordered teardown. Returns the number of messages drained from the queue.""" drained = 0 if mq is not None: - drained = mq.drain(timeout=mq_drain_timeout) + try: + drained = mq.drain(timeout=mq_drain_timeout) + except Exception: + log.exception("shutdown: MQ drain failed") if health_server is not None: - health_server.shutdown() - _join_thread(health_thread, thread_join_timeout, "health") + try: + health_server.shutdown() + except Exception: + log.exception("shutdown: health server shutdown failed") + try: + _join_thread(health_thread, thread_join_timeout, "health") + except Exception: + log.exception("shutdown: health thread join failed") if app is not None: - stop_bolt_server(app) - _join_thread(bolt_thread, thread_join_timeout, "bolt") + try: + stop_bolt_server(app) + except Exception: + log.exception("shutdown: bolt server stop failed") + try: + _join_thread(bolt_thread, thread_join_timeout, "bolt") + except Exception: + log.exception("shutdown: bolt thread join failed") log.info( "=== Paperscout shutting down (%s) — drained %d queued message(s) ===", diff --git a/tests/test_message_queue.py b/tests/test_message_queue.py index 139a21b..3e5d0a8 100644 --- a/tests/test_message_queue.py +++ b/tests/test_message_queue.py @@ -90,3 +90,50 @@ def test_stop_bypasses_open_circuit(self): mq.stop() assert mq.drain(timeout=2.0) == 0 assert not mq.join(timeout=0.1) + + def test_drain_sends_despite_open_circuit(self): + mq = _make_mq() + gate = threading.Event() + + def gated_send(**_kwargs): + gate.wait(timeout=2.0) + + mq._app.client.chat_postMessage.side_effect = gated_send + mq.start() + mq.enqueue("C1", "queued") + time.sleep(0.05) + for _ in range(mq._breaker._threshold): + mq._breaker.record_failure() + assert mq._breaker.state == CircuitState.OPEN + mq.stop() + gate.set() + drained = mq.drain(timeout=5.0) + assert drained == 1 + + def test_stop_does_not_block_on_full_queue(self, monkeypatch): + monkeypatch.setattr("paperscout.scout.settings.mq_max_size", 1) + mq = _make_mq() + sender_busy = threading.Event() + + def slow_send(**_kwargs): + sender_busy.set() + time.sleep(10) + + mq._app.client.chat_postMessage.side_effect = slow_send + mq.start() + mq.enqueue("C1", "first") + assert sender_busy.wait(timeout=2.0) + with mq._queue_lock: + mq._q.put_nowait(("C1", "second", {})) + + stop_done = threading.Event() + + def call_stop(): + mq.stop() + stop_done.set() + + stopper = threading.Thread(target=call_stop) + stopper.start() + assert stop_done.wait(timeout=1.0) + stopper.join(timeout=1.0) + mq.drain(timeout=2.0) diff --git a/tests/test_monitor.py b/tests/test_monitor.py index 6f9e12d..94b4991 100644 --- a/tests/test_monitor.py +++ b/tests/test_monitor.py @@ -24,6 +24,13 @@ from tests.conftest import make_test_settings +def _wait_for_timeout(awaitable, timeout=None): + """Sync mock for asyncio.wait_for that times out immediately without orphaning coroutines.""" + if hasattr(awaitable, "close"): + awaitable.close() + raise asyncio.TimeoutError + + def _recent_hit(**kwargs) -> ProbeHit: defaults = dict( url="https://isocpp.org/files/papers/D9999R0.pdf", @@ -458,6 +465,17 @@ async def fake_run_cycle(): assert seed_result.probe_hits == [hit] assert state.is_discovered(hit.url) + async def test_run_forever_reraises_cancelled_error_without_shutdown_event(self, fake_pool): + scheduler, _, _, _, _ = _make_scheduler(fake_pool) + + async def mock_poll_once(): + raise asyncio.CancelledError() + + scheduler.poll_once = mock_poll_once + with patch("asyncio.sleep", AsyncMock()): + with pytest.raises(asyncio.CancelledError): + await scheduler.run_forever() + async def test_run_forever_calls_poll_and_breaks_on_shutdown(self, fake_pool): scheduler, _, _, _, _ = _make_scheduler(fake_pool) shutdown_event = asyncio.Event() @@ -488,7 +506,7 @@ async def mock_poll_once(): shutdown_event.set() scheduler.poll_once = mock_poll_once - with patch("asyncio.wait_for", AsyncMock(side_effect=asyncio.TimeoutError)): + with patch("asyncio.wait_for", _wait_for_timeout): await scheduler.run_forever(shutdown_event) assert call_count == 2 @@ -508,7 +526,7 @@ async def mock_poll_once(): scheduler.poll_once = mock_poll_once with caplog.at_level(logging.ERROR, logger="paperscout.monitor"): - with patch("asyncio.wait_for", AsyncMock(side_effect=asyncio.TimeoutError)): + with patch("asyncio.wait_for", _wait_for_timeout): await scheduler.run_forever(shutdown_event) assert "failure_category=TIMEOUT" in caplog.text assert call_count == 2 @@ -529,7 +547,7 @@ async def mock_poll_once(): scheduler.poll_once = mock_poll_once with caplog.at_level(logging.ERROR, logger="paperscout.monitor"): - with patch("asyncio.wait_for", AsyncMock(side_effect=asyncio.TimeoutError)): + with patch("asyncio.wait_for", _wait_for_timeout): await scheduler.run_forever(shutdown_event) assert "failure_category=NETWORK" in caplog.text assert call_count == 2 @@ -648,6 +666,7 @@ async def test_run_forever_adaptive_sleep_normal_cycle(self, fake_pool): scheduler, _, _, _, _ = _make_scheduler( fake_pool, poll_interval_minutes=30, poll_overrun_cooldown_seconds=300 ) + shutdown_event = asyncio.Event() call_count = 0 slept: list[float] = [] @@ -655,16 +674,20 @@ async def mock_poll_once(): nonlocal call_count call_count += 1 if call_count >= 2: - raise asyncio.CancelledError() + shutdown_event.set() - async def capture_sleep(duration: float): - slept.append(duration) + def capture_wait_for(awaitable, timeout=None): + if hasattr(awaitable, "close"): + awaitable.close() + if timeout is not None: + slept.append(timeout) + raise asyncio.TimeoutError with patch("paperscout.monitor.time") as mock_time: - mock_time.monotonic.side_effect = [0.0, 360.0, 0.0] + mock_time.monotonic.side_effect = [0.0, 360.0, 0.0, 1.0] scheduler.poll_once = mock_poll_once - with patch("asyncio.sleep", capture_sleep): - await scheduler.run_forever() + with patch("asyncio.wait_for", side_effect=capture_wait_for): + await scheduler.run_forever(shutdown_event) assert len(slept) == 1 assert slept[0] == pytest.approx(1440.0) @@ -673,6 +696,7 @@ async def test_run_forever_adaptive_sleep_overrun_cycle(self, fake_pool): scheduler, _, _, _, _ = _make_scheduler( fake_pool, poll_interval_minutes=30, poll_overrun_cooldown_seconds=300 ) + shutdown_event = asyncio.Event() call_count = 0 slept: list[float] = [] @@ -680,16 +704,20 @@ async def mock_poll_once(): nonlocal call_count call_count += 1 if call_count >= 2: - raise asyncio.CancelledError() + shutdown_event.set() - async def capture_sleep(duration: float): - slept.append(duration) + def capture_wait_for(awaitable, timeout=None): + if hasattr(awaitable, "close"): + awaitable.close() + if timeout is not None: + slept.append(timeout) + raise asyncio.TimeoutError with patch("paperscout.monitor.time") as mock_time: - mock_time.monotonic.side_effect = [0.0, 2000.0, 0.0] + mock_time.monotonic.side_effect = [0.0, 2000.0, 0.0, 1.0] scheduler.poll_once = mock_poll_once - with patch("asyncio.sleep", capture_sleep): - await scheduler.run_forever() + with patch("asyncio.wait_for", side_effect=capture_wait_for): + await scheduler.run_forever(shutdown_event) assert len(slept) == 1 assert slept[0] == pytest.approx(300.0) diff --git a/tests/test_shutdown.py b/tests/test_shutdown.py index c263454..1acb092 100644 --- a/tests/test_shutdown.py +++ b/tests/test_shutdown.py @@ -60,3 +60,22 @@ def test_shutdown_services_stops_health_server(self): thread_join_timeout=5.0, ) health_server.shutdown.assert_called_once() + + def test_shutdown_services_continues_after_mq_drain_failure(self, caplog): + mq = MagicMock() + mq.drain.side_effect = RuntimeError("drain boom") + health_server = MagicMock(spec=HTTPServer) + with caplog.at_level(logging.INFO, logger="paperscout"): + drained = shutdown_services( + reason="SIGTERM", + mq=mq, + health_server=health_server, + health_thread=None, + app=None, + bolt_thread=None, + mq_drain_timeout=30.0, + thread_join_timeout=5.0, + ) + assert drained == 0 + health_server.shutdown.assert_called_once() + assert any("drained 0" in r.message for r in caplog.records) From afaaaf88da0b3d919d96766b5075c579818834ab Mon Sep 17 00:00:00 2001 From: mac Date: Wed, 10 Jun 2026 05:22:48 +0800 Subject: [PATCH 3/4] addressed second ai reviews --- docker-compose.yml | 4 ++++ src/paperscout/__main__.py | 40 ++++++++++++++++++------------------- src/paperscout/scout.py | 7 +++++++ tests/test_message_queue.py | 12 ++++++++++- 4 files changed, 42 insertions(+), 21 deletions(-) diff --git a/docker-compose.yml b/docker-compose.yml index a945825..4558c98 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -17,5 +17,9 @@ services: options: max-size: "10m" max-file: "5" + # Must exceed the combined shutdown budget: + # SHUTDOWN_MQ_DRAIN_TIMEOUT_SECONDS (30 s default) + # + SHUTDOWN_THREAD_JOIN_TIMEOUT_SECONDS × 2 (5 s × 2 = 10 s default) + # = 40 s. Keep at least 5 s of headroom above that sum. stop_grace_period: 45s restart: unless-stopped diff --git a/src/paperscout/__main__.py b/src/paperscout/__main__.py index 75635a4..ed28c36 100644 --- a/src/paperscout/__main__.py +++ b/src/paperscout/__main__.py @@ -246,28 +246,28 @@ def _extra_health_fields() -> dict: _pool_status(pool), ) - register_handlers(app, user_watchlist, state, paper_count_fn, launch_time) - - health_server = start_health_server( - settings.health_port, - launch_time, - state, - paper_count_fn, - bind_host=settings.health_bind_host, - extra_fields_fn=_extra_health_fields, - ) - log.info("Starting Slack Bolt app on port %d", settings.port) - bolt_thread = threading.Thread( - target=app.start, - kwargs={"port": settings.port}, - daemon=True, - name="bolt", - ) - bolt_thread.start() + try: + register_handlers(app, user_watchlist, state, paper_count_fn, launch_time) + + health_server = start_health_server( + settings.health_port, + launch_time, + state, + paper_count_fn, + bind_host=settings.health_bind_host, + extra_fields_fn=_extra_health_fields, + ) + log.info("Starting Slack Bolt app on port %d", settings.port) + bolt_thread = threading.Thread( + target=app.start, + kwargs={"port": settings.port}, + daemon=True, + name="bolt", + ) + bolt_thread.start() - enqueue_startup_status(mq, state, paper_count_fn) + enqueue_startup_status(mq, state, paper_count_fn) - try: await scheduler.run_forever(shutdown_event) finally: shutdown_services( diff --git a/src/paperscout/scout.py b/src/paperscout/scout.py index d11ccdb..bba0861 100644 --- a/src/paperscout/scout.py +++ b/src/paperscout/scout.py @@ -233,6 +233,13 @@ def health_fields(self) -> dict[str, Any]: def enqueue(self, channel: str, text: str, **kwargs) -> bool: """Queue a ``chat.postMessage``; return False when the circuit breaker rejects.""" + if self._stop_requested.is_set(): + log.warning( + "MQ enqueue-rejected shutdown %s %s", + _redact_channel(channel), + _payload_meta(text, kwargs), + ) + return False if not self._breaker.allow_send(): log.warning( "MQ enqueue-rejected circuit=open %s %s", diff --git a/tests/test_message_queue.py b/tests/test_message_queue.py index 3e5d0a8..5a771d7 100644 --- a/tests/test_message_queue.py +++ b/tests/test_message_queue.py @@ -94,14 +94,16 @@ def test_stop_bypasses_open_circuit(self): def test_drain_sends_despite_open_circuit(self): mq = _make_mq() gate = threading.Event() + sender_started = threading.Event() def gated_send(**_kwargs): + sender_started.set() gate.wait(timeout=2.0) mq._app.client.chat_postMessage.side_effect = gated_send mq.start() mq.enqueue("C1", "queued") - time.sleep(0.05) + assert sender_started.wait(timeout=2.0), "sender did not start" for _ in range(mq._breaker._threshold): mq._breaker.record_failure() assert mq._breaker.state == CircuitState.OPEN @@ -110,6 +112,14 @@ def gated_send(**_kwargs): drained = mq.drain(timeout=5.0) assert drained == 1 + def test_enqueue_rejected_after_stop(self): + mq = _make_mq() + mq.start() + mq.stop() + assert mq.enqueue("C1", "too-late") is False + assert mq._app.client.chat_postMessage.call_count == 0 + mq.drain(timeout=2.0) + def test_stop_does_not_block_on_full_queue(self, monkeypatch): monkeypatch.setattr("paperscout.scout.settings.mq_max_size", 1) mq = _make_mq() From 78cdbb2391bc4ddba9c324e9a400e4877c6e45b3 Mon Sep 17 00:00:00 2001 From: mac Date: Wed, 10 Jun 2026 05:50:20 +0800 Subject: [PATCH 4/4] addressed third reviews --- .env.example | 6 ++++++ docker-compose.yml | 7 ++++++- src/paperscout/__main__.py | 21 ++++++++++++++++++++ src/paperscout/config.py | 4 ++++ src/paperscout/scout.py | 31 +++++++++++++++++++++++------ tests/test_message_queue.py | 39 +++++++++++++++++++++++++++++++++++++ 6 files changed, 101 insertions(+), 7 deletions(-) diff --git a/.env.example b/.env.example index eb53db1..dc21172 100644 --- a/.env.example +++ b/.env.example @@ -79,3 +79,9 @@ LOG_RETENTION_DAYS=7 SHUTDOWN_MQ_DRAIN_TIMEOUT_SECONDS=30 # Max time to wait for server threads (health, Bolt) to exit during shutdown. SHUTDOWN_THREAD_JOIN_TIMEOUT_SECONDS=5 +# Set to the container stop grace period (seconds). When non-zero, a startup +# warning is emitted if the combined shutdown budget +# (SHUTDOWN_MQ_DRAIN_TIMEOUT_SECONDS + 2 × SHUTDOWN_THREAD_JOIN_TIMEOUT_SECONDS) +# meets or exceeds this value. In docker-compose.yml this also controls +# stop_grace_period (default 45 s, which exceeds the 40 s default budget). +STOP_GRACE_PERIOD_SECONDS=45 diff --git a/docker-compose.yml b/docker-compose.yml index 4558c98..5c02fbf 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -10,6 +10,9 @@ services: env_file: .env environment: HEALTH_BIND_HOST: "0.0.0.0" + # Passed into the container so the startup budget check can compare it + # against the configured shutdown timeouts. Must match stop_grace_period. + STOP_GRACE_PERIOD_SECONDS: "${STOP_GRACE_PERIOD_SECONDS:-45}" extra_hosts: - "host.docker.internal:host-gateway" logging: @@ -21,5 +24,7 @@ services: # SHUTDOWN_MQ_DRAIN_TIMEOUT_SECONDS (30 s default) # + SHUTDOWN_THREAD_JOIN_TIMEOUT_SECONDS × 2 (5 s × 2 = 10 s default) # = 40 s. Keep at least 5 s of headroom above that sum. - stop_grace_period: 45s + # Override STOP_GRACE_PERIOD_SECONDS in the environment or .env file to + # change both this period and the in-process budget check simultaneously. + stop_grace_period: "${STOP_GRACE_PERIOD_SECONDS:-45}s" restart: unless-stopped diff --git a/src/paperscout/__main__.py b/src/paperscout/__main__.py index ed28c36..923d31c 100644 --- a/src/paperscout/__main__.py +++ b/src/paperscout/__main__.py @@ -188,6 +188,27 @@ async def _async_main() -> None: settings.frontier_gap_threshold, ) + _shutdown_budget = ( + settings.shutdown_mq_drain_timeout_seconds + + 2 * settings.shutdown_thread_join_timeout_seconds + ) + log.info( + "Shutdown budget: %.0fs (mq_drain=%.0f + 2×thread_join=%.0f)", + _shutdown_budget, + settings.shutdown_mq_drain_timeout_seconds, + settings.shutdown_thread_join_timeout_seconds, + ) + if ( + settings.stop_grace_period_seconds > 0 + and _shutdown_budget >= settings.stop_grace_period_seconds + ): + log.warning( + "Shutdown budget %.0fs ≥ stop_grace_period %.0fs — increase " + "STOP_GRACE_PERIOD_SECONDS or reduce SHUTDOWN_*_TIMEOUT_SECONDS", + _shutdown_budget, + settings.stop_grace_period_seconds, + ) + if not settings.database_url: log.error("DATABASE_URL is not set — cannot start") sys.exit(1) diff --git a/src/paperscout/config.py b/src/paperscout/config.py index bc50c63..d487d53 100644 --- a/src/paperscout/config.py +++ b/src/paperscout/config.py @@ -110,6 +110,10 @@ class Settings(BaseSettings): # -- Graceful shutdown -- shutdown_mq_drain_timeout_seconds: float = Field(default=30.0, ge=0.1) shutdown_thread_join_timeout_seconds: float = Field(default=5.0, ge=0.1) + # Set to the container orchestrator's stop/grace period (seconds). + # When non-zero, a startup warning is emitted if the combined shutdown budget + # (mq_drain + 2 × thread_join) meets or exceeds this value. + stop_grace_period_seconds: float = Field(default=0.0, ge=0.0) @model_validator(mode="after") def _require_slack_credentials_unless_testing(self) -> Settings: diff --git a/src/paperscout/scout.py b/src/paperscout/scout.py index bba0861..54bb8da 100644 --- a/src/paperscout/scout.py +++ b/src/paperscout/scout.py @@ -251,21 +251,40 @@ def enqueue(self, channel: str, text: str, **kwargs) -> bool: item = (channel, text, kwargs) max_size = settings.mq_max_size with self._queue_lock: + if self._stop_requested.is_set(): + log.warning( + "MQ enqueue-rejected shutdown %s %s", + _redact_channel(channel), + _payload_meta(text, kwargs), + ) + return False while True: try: self._q.put_nowait(item) break except queue.Full: try: - dropped_ch, dropped_text, dropped_kwargs = self._q.get_nowait() - log.warning( - "MQ drop-oldest %s %s", - _redact_channel(dropped_ch), - _payload_meta(dropped_text, dropped_kwargs), - ) + dropped = self._q.get_nowait() except queue.Empty: # Consumer may have taken an item between Full and get_nowait; retry put. continue + if dropped is _MQ_SENTINEL: + try: + self._q.put_nowait(_MQ_SENTINEL) + except queue.Full: + pass + log.warning( + "MQ enqueue-rejected shutdown %s %s", + _redact_channel(channel), + _payload_meta(text, kwargs), + ) + return False + dropped_ch, dropped_text, dropped_kwargs = dropped + log.warning( + "MQ drop-oldest %s %s", + _redact_channel(dropped_ch), + _payload_meta(dropped_text, dropped_kwargs), + ) if max_size > 0: depth = self._q.qsize() high = 0.8 * max_size diff --git a/tests/test_message_queue.py b/tests/test_message_queue.py index 5a771d7..33fbfd9 100644 --- a/tests/test_message_queue.py +++ b/tests/test_message_queue.py @@ -120,6 +120,45 @@ def test_enqueue_rejected_after_stop(self): assert mq._app.client.chat_postMessage.call_count == 0 mq.drain(timeout=2.0) + def test_enqueue_rejects_when_stop_races_under_lock(self): + """Re-check under _queue_lock rejects enqueues that passed the pre-lock check.""" + mq = _make_mq() + mq.start() + + class _GateLock: + def __init__(self, real: threading.Lock): + self._real = real + self.entered = threading.Event() + self.proceed = threading.Event() + + def __enter__(self): + self._real.acquire() + self.entered.set() + self.proceed.wait(timeout=2.0) + return self + + def __exit__(self, *_args): + self._real.release() + + gate = _GateLock(mq._queue_lock) + mq._queue_lock = gate + + result: list[bool] = [] + + def try_enqueue(): + result.append(mq.enqueue("C1", "late")) + + waiter = threading.Thread(target=try_enqueue) + waiter.start() + assert gate.entered.wait(timeout=2.0), "enqueue did not reach lock" + mq.stop() + gate.proceed.set() + waiter.join(timeout=2.0) + + assert result == [False] + assert mq._app.client.chat_postMessage.call_count == 0 + mq.drain(timeout=2.0) + def test_stop_does_not_block_on_full_queue(self, monkeypatch): monkeypatch.setattr("paperscout.scout.settings.mq_max_size", 1) mq = _make_mq()