diff --git a/src/mcp/client/streamable_http.py b/src/mcp/client/streamable_http.py index aa3e50e07e..cc174e982c 100644 --- a/src/mcp/client/streamable_http.py +++ b/src/mcp/client/streamable_http.py @@ -240,7 +240,11 @@ async def _handle_resumption_request(self, ctx: RequestContext) -> None: event_source.response.raise_for_status() logger.debug("Resumption GET SSE connection established") + response_complete = False async for sse in event_source.aiter_sse(): # pragma: no branch + if response_complete: + continue + is_complete = await self._handle_sse_event( sse, ctx.read_stream_writer, @@ -248,8 +252,7 @@ async def _handle_resumption_request(self, ctx: RequestContext) -> None: ctx.metadata.on_resumption_token_update if ctx.metadata else None, ) if is_complete: - await event_source.response.aclose() - break + response_complete = True async def _handle_post_request(self, ctx: RequestContext) -> None: """Handle a POST request with response processing.""" @@ -342,7 +345,11 @@ async def _handle_sse_response( try: event_source = EventSource(response) + response_complete = False async for sse in event_source.aiter_sse(): # pragma: no branch + if response_complete: + continue + # Track last event ID for potential reconnection if sse.id: last_event_id = sse.id @@ -359,13 +366,15 @@ async def _handle_sse_response( is_initialization=is_initialization, ) # If the SSE event indicates completion, like returning response/error - # break the loop + # keep draining the response to EOF so the HTTP connection can be reused. if is_complete: - await response.aclose() - return # Normal completion, no reconnect needed + response_complete = True except Exception: logger.debug("SSE stream ended", exc_info=True) # pragma: no cover + if response_complete: + return # Normal completion, no reconnect needed + # Stream ended without response - reconnect if we received an event with ID if last_event_id is not None: # pragma: no branch logger.info("SSE stream disconnected, reconnecting...") @@ -405,7 +414,11 @@ async def _handle_reconnection( reconnect_last_event_id: str = last_event_id reconnect_retry_ms = retry_interval_ms + response_complete = False async for sse in event_source.aiter_sse(): + if response_complete: + continue + if sse.id: # pragma: no branch reconnect_last_event_id = sse.id if sse.retry is not None: @@ -418,8 +431,10 @@ async def _handle_reconnection( ctx.metadata.on_resumption_token_update if ctx.metadata else None, ) if is_complete: - await event_source.response.aclose() - return + response_complete = True + + if response_complete: + return # Stream ended again without response - reconnect again (reset attempt counter) logger.info("SSE stream disconnected, reconnecting...") diff --git a/tests/shared/test_streamable_http.py b/tests/shared/test_streamable_http.py index 3d5770fb61..a27b8e796d 100644 --- a/tests/shared/test_streamable_http.py +++ b/tests/shared/test_streamable_http.py @@ -27,6 +27,7 @@ from starlette.requests import Request from starlette.routing import Mount +import mcp.client.streamable_http as streamable_http from mcp import MCPError, types from mcp.client.session import ClientSession from mcp.client.streamable_http import StreamableHTTPTransport, streamable_http_client @@ -139,6 +140,38 @@ async def replay_events_after( # pragma: no cover return target_stream_id +class FakeStreamResponse: + def __init__(self) -> None: + self.close_count = 0 + + def raise_for_status(self) -> None: + pass + + async def aclose(self) -> None: + self.close_count += 1 + + +class FakeEventSource: + def __init__(self, events: list[ServerSentEvent]) -> None: + self.response = FakeStreamResponse() + self.events = events + self.seen = 0 + + async def aiter_sse(self) -> AsyncIterator[ServerSentEvent]: + for event in self.events: + self.seen += 1 + yield event + + +def jsonrpc_response_event(request_id: str, event_id: str) -> ServerSentEvent: + return ServerSentEvent( + event="message", + data=json.dumps({"jsonrpc": "2.0", "id": request_id, "result": {}}), + id=event_id, + retry=None, + ) + + @dataclass class ServerState: lock: anyio.Event = field(default_factory=anyio.Event) @@ -1803,6 +1836,88 @@ async def test_handle_sse_event_skips_empty_data(): await read_stream.aclose() +@pytest.mark.anyio +async def test_handle_sse_response_drains_after_terminal_event(monkeypatch: pytest.MonkeyPatch): + transport = StreamableHTTPTransport(url="http://localhost:8000/mcp") + response = FakeStreamResponse() + event_source = FakeEventSource( + [ + jsonrpc_response_event("request-1", "event-1"), + ServerSentEvent(event="message", data="", id="event-2", retry=None), + ] + ) + monkeypatch.setattr(streamable_http, "EventSource", lambda _response: event_source) + + write_stream, read_stream = create_context_streams[SessionMessage | Exception](1) + try: + async with httpx.AsyncClient() as client: + ctx = streamable_http.RequestContext( + client=client, + session_id=None, + session_message=SessionMessage( + JSONRPCRequest(jsonrpc="2.0", id="request-1", method="tools/call", params={}) + ), + metadata=None, + read_stream_writer=write_stream, + ) + + await transport._handle_sse_response(response, ctx) + + received = await read_stream.receive() + assert isinstance(received.message, types.JSONRPCResponse) + assert received.message.id == "request-1" + assert event_source.seen == 2 + assert response.close_count == 0 + finally: + await write_stream.aclose() + await read_stream.aclose() + + +@pytest.mark.anyio +async def test_reconnection_drains_after_terminal_event(monkeypatch: pytest.MonkeyPatch): + transport = StreamableHTTPTransport(url="http://localhost:8000/mcp") + event_source = FakeEventSource( + [ + jsonrpc_response_event("request-1", "event-2"), + ServerSentEvent(event="message", data="", id="event-3", retry=None), + ] + ) + + async def sleep_noop(_delay: float) -> None: + pass + + @asynccontextmanager + async def connect_sse(*args: Any, **kwargs: Any) -> AsyncIterator[FakeEventSource]: + yield event_source + + monkeypatch.setattr(streamable_http.anyio, "sleep", sleep_noop) + monkeypatch.setattr(streamable_http, "aconnect_sse", connect_sse) + + write_stream, read_stream = create_context_streams[SessionMessage | Exception](1) + try: + async with httpx.AsyncClient() as client: + ctx = streamable_http.RequestContext( + client=client, + session_id=None, + session_message=SessionMessage( + JSONRPCRequest(jsonrpc="2.0", id="request-1", method="tools/call", params={}) + ), + metadata=None, + read_stream_writer=write_stream, + ) + + await transport._handle_reconnection(ctx, last_event_id="event-1") + + received = await read_stream.receive() + assert isinstance(received.message, types.JSONRPCResponse) + assert received.message.id == "request-1" + assert event_source.seen == 2 + assert event_source.response.close_count == 0 + finally: + await write_stream.aclose() + await read_stream.aclose() + + @pytest.mark.anyio async def test_priming_event_not_sent_for_old_protocol_version(): """Test that _maybe_send_priming_event skips for old protocol versions (backwards compat)."""