diff --git a/src/mcp/client/streamable_http.py b/src/mcp/client/streamable_http.py index aa3e50e07e..7589c8c3bb 100644 --- a/src/mcp/client/streamable_http.py +++ b/src/mcp/client/streamable_http.py @@ -43,6 +43,8 @@ MCP_SESSION_ID = "mcp-session-id" MCP_PROTOCOL_VERSION = "mcp-protocol-version" +MCP_METHOD = "mcp-method" +MCP_NAME = "mcp-name" LAST_EVENT_ID = "last-event-id" # Reconnection defaults @@ -82,7 +84,7 @@ def __init__(self, url: str) -> None: self.session_id: str | None = None self.protocol_version: str | None = None - def _prepare_headers(self) -> dict[str, str]: + def _prepare_headers(self, message: JSONRPCMessage | None = None) -> dict[str, str]: """Build MCP-specific request headers. These headers will be merged with the httpx.AsyncClient's default headers, @@ -97,8 +99,28 @@ def _prepare_headers(self) -> dict[str, str]: headers[MCP_SESSION_ID] = self.session_id if self.protocol_version: headers[MCP_PROTOCOL_VERSION] = self.protocol_version + if isinstance(message, JSONRPCRequest | JSONRPCNotification): + headers[MCP_METHOD] = message.method + if mcp_name := self._get_mcp_name(message): + headers[MCP_NAME] = mcp_name return headers + def _get_mcp_name(self, message: JSONRPCRequest | JSONRPCNotification) -> str | None: + params = message.params + if not isinstance(params, dict): + return None + + if message.method in {"tools/call", "prompts/get"}: + value = params.get("name") + elif message.method in {"resources/read", "resources/subscribe", "resources/unsubscribe"}: + value = params.get("uri") + else: + return None + + if value is None: + return None + return str(value) + def _is_initialization_request(self, message: JSONRPCMessage) -> bool: """Check if the message is an initialization request.""" return isinstance(message, JSONRPCRequest) and message.method == "initialize" @@ -253,8 +275,8 @@ async def _handle_resumption_request(self, ctx: RequestContext) -> None: async def _handle_post_request(self, ctx: RequestContext) -> None: """Handle a POST request with response processing.""" - headers = self._prepare_headers() message = ctx.session_message.message + headers = self._prepare_headers(message) is_initialization = self._is_initialization_request(message) async with ctx.client.stream( diff --git a/tests/shared/test_streamable_http.py b/tests/shared/test_streamable_http.py index 3d5770fb61..b4ef1191b4 100644 --- a/tests/shared/test_streamable_http.py +++ b/tests/shared/test_streamable_http.py @@ -1718,6 +1718,33 @@ def test_server_validates_protocol_version_header(basic_server: None, basic_serv assert response.status_code == 200 +@pytest.mark.parametrize( + ("method", "params", "expected_name"), + [ + ("tools/call", {"name": "echo_headers"}, "echo_headers"), + ("prompts/get", {"name": "summarize"}, "summarize"), + ("resources/read", {"uri": "file:///tmp/readme.md"}, "file:///tmp/readme.md"), + ("resources/subscribe", {"uri": "file:///tmp/readme.md"}, "file:///tmp/readme.md"), + ("resources/unsubscribe", {"uri": "file:///tmp/readme.md"}, "file:///tmp/readme.md"), + ("tools/call", {}, None), + ("resources/read", {}, None), + ("tools/list", {}, None), + ], +) +def test_streamable_http_client_adds_sep_2243_headers(method: str, params: dict[str, Any], expected_name: str | None): + """POST requests include SEP-2243 method/name headers.""" + transport = StreamableHTTPTransport("https://example.com/mcp") + message = JSONRPCRequest(jsonrpc="2.0", id=1, method=method, params=params) + + headers = transport._prepare_headers(message) + + assert headers["mcp-method"] == method + if expected_name is None: + assert "mcp-name" not in headers + else: + assert headers["mcp-name"] == expected_name + + def test_server_backwards_compatibility_no_protocol_version(basic_server: None, basic_server_url: str): """Test server accepts requests without protocol version header.""" # First initialize a session to get a valid session ID