diff --git a/agents/mcp/src/hyperforge_mcp/agent.py b/agents/mcp/src/hyperforge_mcp/agent.py index 1ead713..4c5f239 100644 --- a/agents/mcp/src/hyperforge_mcp/agent.py +++ b/agents/mcp/src/hyperforge_mcp/agent.py @@ -421,6 +421,14 @@ async def process_tool( "; ".join(error_texts) if error_texts else str(tool_result.meta) ) logger.error(f"Tool {tool_name} encountered an error: {error_message}") + context.chunks.append( + Chunk( + chunk_id=f"mcp_{self.config.id}_{tool_name}_error", + title=f"MCP tool error: {tool_name}", + text=(f"Tool: {tool_name}\nError: {error_message}"), + origin_agent=self.config.module, + ) + ) await memory.add_step( step_module=self.config.module, step_title=self.step_title("Tool error"), @@ -429,10 +437,32 @@ async def process_tool( timeit=time() - t0, ) return - # TODO: extract better information from the tool result - context.structured.append( - json.dumps(tool_result.structuredContent, indent=2, default=str) + text_blocks = sum( + 1 for block in tool_result.content if isinstance(block, types.TextContent) + ) + image_blocks = sum( + 1 for block in tool_result.content if isinstance(block, types.ImageContent) ) + resource_blocks = sum( + 1 for block in tool_result.content if isinstance(block, types.ResourceLink) + ) + + trace_lines = [ + f"Tool: {tool_name}", + f"is_error: {tool_result.isError}", + f"text_blocks: {text_blocks}", + f"image_blocks: {image_blocks}", + f"resource_links: {resource_blocks}", + ] + if tool_result.structuredContent is not None: + structured = json.dumps( + tool_result.structuredContent, indent=2, default=str + ) + trace_lines.append("Structured content (truncated):") + trace_lines.append( + structured[:2000] + ("...(truncated)" if len(structured) > 2000 else "") + ) + context.structured.append(structured) messages.append( Message(author=Author.NUCLIA, text=f"Tool {tool_name} executed") ) @@ -527,8 +557,10 @@ async def process_tool( if block.resource.mimeType is not None else "application/octet-stream", 
b64encoded=block.resource.blob, + ) ) - ) + + messages.append(Message(author=Author.NUCLIA, text="\n".join(trace_lines))) step_value = ( f"Used tool: {tool_name} with arguments: {tool_arguments}" @@ -1127,6 +1159,7 @@ async def _get_question_context( loaded_tools = False max_retries = 2 for attempt in range(max_retries): + interaction_completed = False try: async with self.http_streaming_session_ctx( manager=manager, memory=memory @@ -1142,6 +1175,19 @@ async def _get_question_context( f"Failed to preload tools from MCP server: {e}" ) self.tools = [] + if self.tools and loaded_tools and attempt == 0: + tools_text = "\n".join( + f"- {t.name}: {t.description or '(no description)'}" + for t in self.tools + ) + context.chunks.append( + Chunk( + chunk_id=f"mcp_{self.config.id}_tools_list", + title="Available MCP tools", + text=f"The following tools are available:\n{tools_text}", + origin_agent=self.config.module, + ) + ) if attempt == 0: # Only preload prompts and resources on the first attempt try: @@ -1166,14 +1212,24 @@ async def _get_question_context( ) = await self.mcp_interaction( memory, manager, question, context ) + interaction_completed = True break # Success, exit retry loop except Exception as e: + if interaction_completed: + logger.warning( + "Ignoring MCP HTTP teardown error after successful interaction: %s", + repr(e), + ) + break + logger.exception( f"Error during MCP HTTP interaction (attempt {attempt + 1}/{max_retries})" ) if attempt + 1 == max_retries: raise e + finally: + self.session = None elif self.driver_context_manager is not None: async with self.driver_context_manager as (read_stream, write_stream): # type: ignore if read_stream is None or write_stream is None: diff --git a/agents/mcp/src/hyperforge_mcp/config_driver.py b/agents/mcp/src/hyperforge_mcp/config_driver.py index 6d5d859..6299fc3 100644 --- a/agents/mcp/src/hyperforge_mcp/config_driver.py +++ b/agents/mcp/src/hyperforge_mcp/config_driver.py @@ -1,18 +1,29 @@ from enum import Enum 
-from typing import ClassVar, Dict, Literal, Optional +from typing import Any, ClassVar, Dict, Literal, Optional from hyperforge.driver import DriverConfig, EncryptedPayload +from hyperforge.settings import OAuthSettings from hyperforge.utils import WidgetType -from pydantic import Field +from pydantic import Field, model_validator from pydantic.config import ConfigDict +def _redirect_uris_schema_default(schema: Dict[str, Any]) -> None: + try: + callback_url = OAuthSettings().mcp_callback_url + if callback_url: + schema["default"] = [callback_url] + except Exception: + pass + + class MCPHTTPInnerConfig(EncryptedPayload): - encrypted_fields: ClassVar[list[str]] = [] + encrypted_fields: ClassVar[list[str]] = ["client_secret"] uri: str timeout: float = 60 * 5 headers: Dict[str, str] = Field(default_factory=dict) + sse_read_timeout: float = Field(default=300, title="SSE read timeout in seconds") ca_certificate: Optional[str] = Field( default=None, title="CA certificate for HTTPS", @@ -27,7 +38,12 @@ class MCPHTTPInnerConfig(EncryptedPayload): server_url: Optional[str] = Field( default=None, title="OAuth Authorization Server URL" ) - redirect_uris: list[str] = Field(default_factory=list, title="OAuth Redirect URIs") + redirect_uris: list[str] = Field( + default_factory=list, + title="OAuth Redirect URI", + description="The callback URL registered in your OAuth Connected App. Auto-filled from the server configuration - do not change.", + json_schema_extra=_redirect_uris_schema_default, + ) auth_server_url: Optional[str] = Field( default=None, title="OAuth Authorization Server URL" ) @@ -44,6 +60,53 @@ class MCPHTTPInnerConfig(EncryptedPayload): scope: str = Field( default="user", title="OAuth Scopes", description="Default: 'user'" ) + client_id: Optional[str] = Field( + default=None, + title="OAuth Client ID", + description="Pre-registered client ID. 
If set, skips Dynamic Client Registration.", + ) + client_secret: Optional[str] = Field( + default=None, + title="OAuth Client Secret", + description="Pre-registered client secret. Required when the AS is not a public client.", + ) + authorization_endpoint: Optional[str] = Field( + default=None, + title="OAuth Authorization Endpoint Override", + description=( + "Override the authorization endpoint discovered via RFC 8414 metadata. " + "Use when the AS advertises a non-functional /authorize path." + ), + ) + token_endpoint: Optional[str] = Field( + default=None, + title="OAuth Token Endpoint Override", + description=( + "Override the token endpoint discovered via RFC 8414 metadata. " + "Use when the AS uses a non-standard token path (e.g. uses " + "/services/oauth2/token instead of /token)." + ), + ) + pkce: bool = Field( + default=True, + title="Enable PKCE", + description=( + "Whether to use PKCE (Proof Key for Code Exchange) in the OAuth 2.0 flow. " + "Set to false for Authorization Servers that do not support PKCE " + "(e.g. Connected Apps without PKCE enabled)." 
+ ), + ) + + @model_validator(mode="after") + def _force_redirect_uris(self) -> "MCPHTTPInnerConfig": + """Always override redirect_uris with the zone callback URL.""" + try: + callback_url = OAuthSettings().mcp_callback_url + if callback_url: + self.redirect_uris = [callback_url] + except Exception: + pass + return self class MCPHTTPDriverConfig(DriverConfig[MCPHTTPInnerConfig]): diff --git a/agents/mcp/src/hyperforge_mcp/http.py b/agents/mcp/src/hyperforge_mcp/http.py index 2fba909..2fa875b 100644 --- a/agents/mcp/src/hyperforge_mcp/http.py +++ b/agents/mcp/src/hyperforge_mcp/http.py @@ -1,27 +1,407 @@ +import asyncio +import base64 +import hashlib +import secrets import ssl import tempfile -from functools import partial +from collections.abc import Awaitable, Callable +from functools import cache, partial from typing import Any, Optional -from urllib.parse import parse_qs, urlparse +from urllib.parse import parse_qs, urlencode, urlparse +from uuid import uuid4 import httpx +from cryptography.fernet import Fernet from httpx import Auth, Timeout +from mcp.client.auth import OAuthClientProvider, PKCEParameters, TokenStorage +from mcp.client.auth.exceptions import OAuthFlowError +from mcp.client.streamable_http import streamable_http_client +from mcp.shared.auth import ( + OAuthClientInformationFull, + OAuthClientMetadata, + OAuthMetadata, + OAuthToken, +) +from pydantic import AnyUrl, BaseModel + +from hyperforge import logger from hyperforge.configure import driver from hyperforge.driver import Driver -from hyperforge.interaction import Feedback +from hyperforge.interaction import ( + Feedback, + OAuthAuthenticateURL, + OAuthFeedbackReturnSchema, + Provider, +) from hyperforge.memory import QuestionMemory - -# from mcp.shared.auth import OAuthClientMetadata from hyperforge.utils.http import SafeTransport +from hyperforge_mcp.config_driver import MCPHTTPDriverConfig, MCPHTTPInnerConfig -# from httpx import BasicAuth, DigestAuth -from mcp.client.auth import 
OAuthClientProvider, TokenStorage -from mcp.client.streamable_http import streamable_http_client -from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata, OAuthToken -from pydantic import AnyUrl +# --------------------------------------------------------------------------- +# Stateless MCP OAuth state - Fernet-encrypted routing context +# --------------------------------------------------------------------------- +# The OAuth ``state`` parameter generated by the SDK is replaced with a +# Fernet token that contains the routing context (account_id, session_id, +# etc.) encrypted and signed. +# --------------------------------------------------------------------------- + +_MCP_OAUTH_STATE_TTL = 600 # seconds - max age for the Fernet token +_MCP_OAUTH_SINGLE_FLIGHT_LOCKS: dict[str, asyncio.Lock] = {} +_PKCE_VERIFIER_ALPHABET = ( + "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789" +) -from hyperforge import logger -from hyperforge_mcp.config_driver import MCPHTTPDriverConfig, MCPHTTPInnerConfig + + +def _get_single_flight_lock(key: str) -> asyncio.Lock: + lock = _MCP_OAUTH_SINGLE_FLIGHT_LOCKS.get(key) + if lock is None: + lock = asyncio.Lock() + _MCP_OAUTH_SINGLE_FLIGHT_LOCKS[key] = lock + return lock + + +def _fingerprint(value: str) -> str: + return hashlib.sha256(value.encode()).hexdigest()[:12] + + +def _pkce_challenge_for_verifier(code_verifier: str) -> str: + digest = hashlib.sha256(code_verifier.encode()).digest() + return base64.urlsafe_b64encode(digest).decode().rstrip("=") + + +def _generate_pkce_parameters() -> PKCEParameters: + code_verifier = "".join(secrets.choice(_PKCE_VERIFIER_ALPHABET) for _ in range(128)) + return PKCEParameters( + code_verifier=code_verifier, + code_challenge=_pkce_challenge_for_verifier(code_verifier), + ) + + +class MCPOAuthRoutingParams(BaseModel): + """Routing context embedded in the OAuth state parameter. 
+ + Carries the identifiers needed by the callback endpoint to reconstruct + the NATS subject and route the token back to the originating request. + """ + + account_id: str + agent_id: str + workflow_id: str + session_id: str + question_id: str + oauth_uuid: str + sdk_state: str = "" + + @classmethod + def from_memory(cls, memory: QuestionMemory) -> "MCPOAuthRoutingParams": + return cls( + account_id=memory.get_account_id(), + agent_id=memory.get_agent_id(), + workflow_id=memory.get_workflow_id(), + session_id=memory.get_session_id(), + question_id=memory.original_question_uuid, + oauth_uuid=memory.original_question_uuid, + ) + + +@cache +def _get_mcp_fernet() -> Fernet: + """Return a Fernet instance using the shared ENCRYPTION_SECRET_KEY env var.""" + from hyperforge.db.settings import EncryptionSettings + + settings = EncryptionSettings() + return Fernet(settings.encryption_secret_key) + + +def encrypt_mcp_oauth_state(routing: MCPOAuthRoutingParams) -> str: + """Encrypt routing context into a URL-safe Fernet token. + + The token embeds its creation timestamp so that ``decrypt_mcp_oauth_state`` can reject + tokens that arrive more than ``_MCP_OAUTH_STATE_TTL`` seconds later. + + Args: + routing: MCPOAuthRoutingParams with all routing identifiers and sdk_state. + + Returns: + A URL-safe string suitable for use as the OAuth ``state`` parameter. + """ + f = _get_mcp_fernet() + payload = routing.model_dump_json().encode() + return f.encrypt(payload).decode() + + +def decrypt_mcp_oauth_state(token: str) -> MCPOAuthRoutingParams: + """Decrypt and validate a Fernet-encrypted MCP OAuth state token. + + Args: + token: The value of the ``state`` query parameter from the callback. + + Returns: + The MCPOAuthRoutingParams originally passed to ``encrypt_mcp_oauth_state``. + + Raises: + cryptography.fernet.InvalidToken: If the token is invalid, tampered, + or older than ``_MCP_OAUTH_STATE_TTL`` seconds. 
+ """ + f = _get_mcp_fernet() + payload = f.decrypt(token.encode(), ttl=_MCP_OAUTH_STATE_TTL) + return MCPOAuthRoutingParams.model_validate_json(payload) + + +class _AcceptJsonOAuthClientProvider(OAuthClientProvider): + """OAuthClientProvider that requests JSON from the token endpoint. + + Some OAuth servers default to returning tokens in + URL-encoded form data format. Adding ``Accept: application/json`` makes them + return standard JSON, which the MCP SDK expects. + """ + + async def _exchange_token_authorization_code( + self, + auth_code: str, + code_verifier: str, + *, + token_data: dict[str, Any] | None = None, + ) -> httpx.Request: + req = await super()._exchange_token_authorization_code( + auth_code, code_verifier, token_data=token_data + ) + req.headers["Accept"] = "application/json" + return req + + +class _RoutedOAuthClientProvider(_AcceptJsonOAuthClientProvider): + """OAuthClientProvider that embeds routing context in the OAuth state parameter. + + Replaces the SDK-generated random ``state`` token with a Fernet-encrypted + payload containing routing metadata (account_id, session_id, ...). This + allows the callback endpoint to reconstruct the NATS subject. + Forked from MCP SDK's built-in OAuthClientProvider with the same overrides as the MCP SDK's AuthorizationCodeGrantProvider. + """ + + def __init__( + self, + *args: Any, + routing: MCPOAuthRoutingParams, + authorization_endpoint: str | None = None, + token_endpoint: str | None = None, + pkce: bool = True, + single_flight_key: str | None = None, + callback_handler_for_oauth_uuid: Callable[ + [str], Awaitable[tuple[str, str | None]] + ] + | None = None, + **kwargs: Any, + ) -> None: + super().__init__(*args, **kwargs) + self._routing = routing + self._authorization_endpoint = authorization_endpoint + # Keep token_endpoint separate - OASM discovery overwrites oauth_metadata on 401. 
+ self._token_endpoint = token_endpoint + self._pkce = pkce + self._single_flight_key = single_flight_key + self._callback_handler_for_oauth_uuid = callback_handler_for_oauth_uuid + self._pkce_grants: dict[str, tuple[str, str]] = {} + if token_endpoint: + # Seed oauth_metadata so refresh flows work before OASM discovery runs. + existing = self.context.oauth_metadata + self.context.oauth_metadata = OAuthMetadata( + issuer=existing.issuer if existing else "https://placeholder.invalid", + authorization_endpoint=( + existing.authorization_endpoint + if existing and existing.authorization_endpoint + else "https://placeholder.invalid/authorize" + ), + token_endpoint=token_endpoint, # type: ignore[arg-type] + ) + + def _get_token_endpoint(self) -> str: + """Use the manually configured endpoint, ignoring OASM discovery.""" + if self._token_endpoint: + return self._token_endpoint + return super()._get_token_endpoint() + + async def async_auth_flow(self, request: httpx.Request): + if self._single_flight_key is None: + auth_flow = super().async_auth_flow(request) + try: + auth_request = await auth_flow.__anext__() + while True: + response = yield auth_request + auth_request = await auth_flow.asend(response) + except StopAsyncIteration: + return + return + + lock = _get_single_flight_lock(self._single_flight_key) + logger.info( + "mcp_oauth single_flight_wait: key=%s", + self._single_flight_key, + ) + async with lock: + logger.info( + "mcp_oauth single_flight_enter: key=%s", + self._single_flight_key, + ) + auth_flow = super().async_auth_flow(request) + try: + auth_request = await auth_flow.__anext__() + while True: + response = yield auth_request + auth_request = await auth_flow.asend(response) + except StopAsyncIteration: + return + finally: + logger.info( + "mcp_oauth single_flight_exit: key=%s", + self._single_flight_key, + ) + + async def _perform_authorization_code_grant(self) -> tuple[str, str]: + """Override to use a Fernet-encrypted state instead of a random token.""" 
+ if self.context.client_metadata.redirect_uris is None: + raise OAuthFlowError( + "No redirect URIs provided for authorization code grant" + ) + if not self.context.redirect_handler: + raise OAuthFlowError( + "No redirect handler provided for authorization code grant" + ) + if self._callback_handler_for_oauth_uuid is None: + raise OAuthFlowError( + "No callback handler provided for authorization code grant" + ) + + if self._authorization_endpoint: + auth_endpoint = self._authorization_endpoint + elif ( + self.context.oauth_metadata + and self.context.oauth_metadata.authorization_endpoint + ): + auth_endpoint = str(self.context.oauth_metadata.authorization_endpoint) + else: + auth_base_url = self.context.get_authorization_base_url( + self.context.server_url + ) + auth_endpoint = auth_base_url + "/authorize" + + if not self.context.client_info: + raise OAuthFlowError("No client info available for authorization") + + # Generate PKCE parameters and a random nonce that the SDK will verify. + pkce_params = _generate_pkce_parameters() + sdk_state = secrets.token_urlsafe(32) + oauth_uuid = uuid4().hex + verifier_fp = _fingerprint(pkce_params.code_verifier) + challenge_fp = _fingerprint(pkce_params.code_challenge) + + # Encrypt routing context + grant identifiers into the OAuth state parameter. 
+ encrypted_state = encrypt_mcp_oauth_state( + self._routing.model_copy( + update={"oauth_uuid": oauth_uuid, "sdk_state": sdk_state} + ) + ) + + auth_params: dict[str, str] = { + "response_type": "code", + "client_id": self.context.client_info.client_id or "", + "redirect_uri": str((self.context.client_metadata.redirect_uris or [])[0]), + "state": encrypted_state, + } + if self._pkce: + auth_params["code_challenge"] = pkce_params.code_challenge + auth_params["code_challenge_method"] = "S256" + self._pkce_grants[verifier_fp] = (oauth_uuid, challenge_fp) + logger.info( + "mcp_oauth pkce_grant_created: oauth_uuid=%s question_id=%s verifier_fp=%s challenge_fp=%s", + oauth_uuid, + self._routing.question_id, + verifier_fp, + challenge_fp, + ) + + if self.context.should_include_resource_param(self.context.protocol_version): + auth_params["resource"] = self.context.get_resource_url() + + if self.context.client_metadata.scope: + auth_params["scope"] = self.context.client_metadata.scope + if "offline_access" in self.context.client_metadata.scope.split(): + auth_params["prompt"] = "consent" + + await self.context.redirect_handler(f"{auth_endpoint}?{urlencode(auth_params)}") + + # Wait for callback - the callback handler returns (code, sdk_state). 
+ auth_code, returned_state = await self._callback_handler_for_oauth_uuid( + oauth_uuid + ) + + if returned_state is None or not secrets.compare_digest( + returned_state, sdk_state + ): + raise OAuthFlowError( + f"State parameter mismatch: {returned_state} != {sdk_state}" + ) + + if not auth_code: + raise OAuthFlowError("No authorization code received") + + logger.info( + "mcp_oauth authorization_grant_ready: oauth_uuid=%s question_id=%s pkce=%s verifier_fp=%s", + oauth_uuid, + self._routing.question_id, + self._pkce, + verifier_fp if self._pkce else "disabled", + ) + return auth_code, pkce_params.code_verifier if self._pkce else "" + + async def _exchange_token_authorization_code( + self, + auth_code: str, + code_verifier: str, + *, + token_data: dict[str, Any] | None = None, + ) -> Any: + """Override to skip code_verifier when PKCE is disabled.""" + if self._pkce: + verifier_fp = _fingerprint(code_verifier) + expected_oauth_uuid, expected_challenge_fp = self._pkce_grants.pop( + verifier_fp, ("unknown", "unknown") + ) + actual_challenge = _pkce_challenge_for_verifier(code_verifier) + actual_challenge_fp = _fingerprint(actual_challenge) + logger.info( + "mcp_oauth token_exchange_pkce: oauth_uuid=%s question_id=%s verifier_fp=%s challenge_fp=%s expected_challenge_fp=%s match=%s", + expected_oauth_uuid, + self._routing.question_id, + verifier_fp, + actual_challenge_fp, + expected_challenge_fp, + expected_challenge_fp == actual_challenge_fp, + ) + req = await super()._exchange_token_authorization_code( + auth_code, code_verifier, token_data=token_data + ) + logger.info( + "mcp_oauth token_exchange_request_built: oauth_uuid=%s question_id=%s token_url=%s has_code_verifier=%s", + expected_oauth_uuid, + self._routing.question_id, + req.url, + "code_verifier=" in req.content.decode(errors="ignore"), + ) + return req + # PKCE disabled: build the token request without code_verifier. 
+ if not self.context.client_info: + raise OAuthFlowError("Missing client info") + token_url = self._get_token_endpoint() + data: dict[str, str] = { + "grant_type": "authorization_code", + "code": auth_code, + "redirect_uri": str((self.context.client_metadata.redirect_uris or [])[0]), + "client_id": self.context.client_info.client_id or "", + } + headers = {"Content-Type": "application/x-www-form-urlencoded"} + data, headers = self.context.prepare_token_auth(data, headers) + return httpx.Request("POST", token_url, data=data, headers=headers) def create_mcp_http_client( @@ -110,64 +490,231 @@ def create_mcp_http_client( return httpx.AsyncClient(transport=safe_transport, **kwargs) -class InMemoryTokenStorage(TokenStorage): - """Demo In-memory token storage implementation.""" +class FeedbackTokenStorage(TokenStorage): + """Token storage that negotiates MCP credentials with the user via the Feedback mechanism. - def __init__(self): - self.tokens: OAuthToken | None = None - self.client_info: OAuthClientInformationFull | None = None + On ``get_tokens()``, asks the user for an existing Bearer token (60 s timeout). + If none, returns ``None`` so the SDK starts the full OAuth flow. + On ``set_tokens()``, notifies the user so they can store it client-side. + Nothing is persisted server-side. 
+ """ + + def __init__( + self, + memory: QuestionMemory, + agent_id: str, + uri: str, + client_id: Optional[str], + client_secret: Optional[str], + redirect_uris: list[str], + grant_types: list[str], + response_types: list[str], + scope: Optional[str], + ) -> None: + self._memory = memory + self._agent_id = agent_id + self._uri = uri + self._client_info: OAuthClientInformationFull | None = None + if client_id is not None: + self._client_info = OAuthClientInformationFull( + client_id=client_id, + client_secret=client_secret, + token_endpoint_auth_method="client_secret_post" + if client_secret + else "none", + client_name="Hyperforge MCP Client", + redirect_uris=[AnyUrl(x) for x in redirect_uris], + grant_types=grant_types, + response_types=response_types, + scope=scope, + ) async def get_tokens(self) -> OAuthToken | None: - """Get stored tokens.""" - return self.tokens + """Ask the user for an existing Bearer token (60 s timeout). + + Returns the token if provided, or None to let the SDK start the OAuth flow. + """ + logger.info( + "mcp_oauth get_tokens: asking client for cached credentials; uri=%s agent=%s session=%s", + self._uri, + self._agent_id, + self._memory.get_session_id(), + ) + feedback = Feedback( + request_id=self._memory.get_session_id(), + question="Get credentials", + data=None, + get_credentials={self._uri: Provider.MCP_OAUTH}, + module="oauth", + agent_id=self._agent_id, + response_schema=OAuthFeedbackReturnSchema.model_json_schema(), + # Give the client 60 s to reply before assuming "no credentials" + # and starting OAuth automatically. 
+ timeout_ms=60_000, + ) + try: + answer = await asyncio.wait_for( + self._memory.send_feedback(feedback), timeout=61.0 + ) + except (asyncio.TimeoutError, asyncio.CancelledError, Exception) as exc: + logger.info( + "mcp_oauth get_tokens: no answer from client (timeout/cancel); uri=%s reason=%s", + self._uri, + type(exc).__name__, + ) + answer = None + if answer is not None: + try: + schema = OAuthFeedbackReturnSchema.model_validate_json(answer.response) + if schema.existing_credentials: + creds = schema.existing_credentials.get(self._uri, {}) + access_token = creds.get("access_token") + if access_token: + logger.info( + "mcp_oauth get_tokens: cached token found; uri=%s agent=%s", + self._uri, + self._agent_id, + ) + return OAuthToken( + access_token=access_token, token_type="Bearer" + ) + except Exception: + pass + logger.info( + "mcp_oauth get_tokens: no cached token - starting full OAuth flow; uri=%s agent=%s", + self._uri, + self._agent_id, + ) + return None async def set_tokens(self, tokens: OAuthToken) -> None: - """Store tokens.""" - self.tokens = tokens + """Notify the WS client of the new token before the SDK retries. + + We await directly so the feedback is delivered before the SDK + continues - a fire-and-forget task gets cancelled during MCP + transport teardown and never reaches the WebSocket client. + """ + logger.info( + "mcp_oauth set_tokens: sending new token to WS client; uri=%s agent=%s session=%s", + self._uri, + self._agent_id, + self._memory.get_session_id(), + ) + feedback = Feedback( + request_id=self._memory.get_session_id(), + question="Send credentials", + data=None, + credentials={ + self._uri: { + "access_token": tokens.access_token, + "token_type": tokens.token_type, + } + }, + module="oauth", + agent_id=self._agent_id, + response_schema=OAuthFeedbackReturnSchema.model_json_schema(), + # Give the client time to ACK. 
The SDK waits for set_tokens + # to return before retrying the MCP request, so up to 60 s + # is acceptable; the client can ACK immediately. + timeout_ms=60_000, + ) + try: + await asyncio.wait_for(self._memory.send_feedback(feedback), timeout=61.0) + logger.info( + "mcp_oauth set_tokens: token delivered to WS client; uri=%s agent=%s", + self._uri, + self._agent_id, + ) + except (asyncio.TimeoutError, asyncio.CancelledError, Exception) as exc: + logger.warning( + "mcp_oauth set_tokens: failed to deliver token (ACK optional); uri=%s reason=%s", + self._uri, + type(exc).__name__, + ) async def get_client_info(self) -> OAuthClientInformationFull | None: - """Get stored client information.""" - return self.client_info + return self._client_info async def set_client_info(self, client_info: OAuthClientInformationFull) -> None: - """Store client information.""" - self.client_info = client_info - - -async def handle_redirect( - memory: QuestionMemory, module: str, agent_id: str, request_id: str, auth_url: str -) -> None: - resp = await memory.send_feedback( - Feedback( - request_id=request_id, - question=f"Visit: {auth_url}", - module=module, - agent_id=agent_id, - data={"auth_url": auth_url}, - response_schema=None, - ) + self._client_info = client_info + + +async def handle_redirect(memory: QuestionMemory, auth_url: str) -> None: + """Send the OAuth authorization URL to the user. + + The routing context is already embedded in the ``state`` parameter of + ``auth_url`` as a Fernet-encrypted token - no in-process registration + is needed here. 
+ """ + logger.info( + "mcp_oauth handle_redirect: sending OAuth URL to user; agent=%s session=%s", + memory.get_agent_id(), + memory.get_session_id(), + ) + await memory.send_oauth(oauth=OAuthAuthenticateURL(oauth_url=auth_url)) + logger.info( + "mcp_oauth handle_redirect: OAuth URL delivered, waiting for user to authenticate; agent=%s session=%s", + memory.get_agent_id(), + memory.get_session_id(), ) - logger.info("Redirect feedback sent:", resp) async def handle_callback( - memory: QuestionMemory, module: str, agent_id: str, request_id: str + memory: QuestionMemory, module: str, agent_id: str, oauth_uuid: str ) -> tuple[str, str | None]: - callback_url = await memory.send_feedback( - Feedback( - request_id=request_id, - question="Callback", - module=module, - agent_id=agent_id, - data=None, - response_schema=None, - ) + has_fn = getattr(memory, "oauth_callback_fn", None) is not None + logger.info( + "mcp_oauth handle_callback: waiting for OAuth callback; agent=%s session=%s question_id=%s oauth_uuid=%s has_oauth_fn=%s", + agent_id, + memory.get_session_id(), + memory.original_question_uuid, + oauth_uuid, + has_fn, + ) + callback_payload = await memory.recv_oauth_callback( + question_id=memory.original_question_uuid, + oauth_uuid=oauth_uuid, ) - if callback_url: - params = parse_qs(urlparse(callback_url.response).query) + if callback_payload: + if "://" in callback_payload: + params = parse_qs(urlparse(callback_payload).query) + else: + params = parse_qs(callback_payload.lstrip("?")) + + if "error" in params: + err = params["error"][0] + desc = params.get("error_description", [""])[0] + logger.error( + "mcp_oauth handle_callback: OAuth error from provider; agent=%s error=%s desc=%s", + agent_id, + err, + desc, + ) + raise ValueError(f"OAuth authorization failed: {err} {desc}".strip()) + + if "code" not in params: + logger.error( + "mcp_oauth handle_callback: callback payload missing code; agent=%s payload_keys=%s", + agent_id, + list(params.keys()), + ) + raise 
ValueError( + "OAuth callback payload does not include authorization code" + ) + + logger.info( + "mcp_oauth handle_callback: authorization code received ok; agent=%s session=%s", + agent_id, + memory.get_session_id(), + ) return params["code"][0], params.get("state", [None])[0] else: - raise ValueError("No callback URL received") + logger.error( + "mcp_oauth handle_callback: no callback payload received (timeout?); agent=%s session=%s", + agent_id, + memory.get_session_id(), + ) + raise ValueError("No OAuth callback payload received") @driver( @@ -207,21 +754,43 @@ def client( new_headers.update(headers) if self.config.auth_server_url is not None: - auth: Auth | None = OAuthClientProvider( + routing = MCPOAuthRoutingParams.from_memory(memory) + storage = FeedbackTokenStorage( + memory=memory, + agent_id=agent_id, + uri=self.config.uri, + client_id=self.config.client_id, + client_secret=self.config.client_secret, + redirect_uris=self.config.redirect_uris, + grant_types=self.config.grant_types, + response_types=self.config.response_types, + scope=self.config.scope, + ) + auth: Auth | None = _RoutedOAuthClientProvider( server_url=self.config.auth_server_url, client_metadata=OAuthClientMetadata( - client_name="ARAG MCP Client", + client_name="Hyperforge MCP Client", redirect_uris=[AnyUrl(x) for x in self.config.redirect_uris], grant_types=self.config.grant_types, response_types=self.config.response_types, scope=self.config.scope, ), - storage=InMemoryTokenStorage(), - redirect_handler=partial( - handle_redirect, memory, module, agent_id, request_id + storage=storage, + routing=routing, + authorization_endpoint=self.config.authorization_endpoint, + token_endpoint=self.config.token_endpoint, + pkce=self.config.pkce, + single_flight_key=( + f"{routing.account_id}.{routing.agent_id}." + f"{routing.workflow_id}.{routing.session_id}." 
+ f"{routing.question_id}.{self.config.uri}" ), - callback_handler=partial( - handle_callback, memory, module, agent_id, request_id + redirect_handler=partial(handle_redirect, memory), + callback_handler_for_oauth_uuid=partial( + handle_callback, + memory, + module, + agent_id, ), ) else: diff --git a/agents/mcp/tests/test_mcp_oauth.py b/agents/mcp/tests/test_mcp_oauth.py new file mode 100644 index 0000000..4d4b699 --- /dev/null +++ b/agents/mcp/tests/test_mcp_oauth.py @@ -0,0 +1,244 @@ +import asyncio +import base64 +import hashlib +from types import SimpleNamespace +from unittest.mock import AsyncMock, patch +from urllib.parse import parse_qs, urlparse + +import httpx +import pytest +from mcp.client.auth import PKCEParameters +from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata +from pydantic import AnyUrl + +from hyperforge_mcp.http import ( + _MCP_OAUTH_SINGLE_FLIGHT_LOCKS, + MCPOAuthRoutingParams, + _generate_pkce_parameters, + _RoutedOAuthClientProvider, + handle_callback, +) + + +def _make_provider(**kwargs) -> _RoutedOAuthClientProvider: + redirect_handler = kwargs.pop("redirect_handler", AsyncMock()) + callback_handler = kwargs.pop("callback_handler_for_oauth_uuid", AsyncMock()) + provider = _RoutedOAuthClientProvider( + server_url="https://auth.example.com", + client_metadata=OAuthClientMetadata( + client_name="client", + redirect_uris=[AnyUrl("https://app.example.com/callback")], + grant_types=["authorization_code"], + response_types=["code"], + ), + storage=SimpleNamespace(), + routing=MCPOAuthRoutingParams( + account_id="account", + agent_id="agent", + workflow_id="workflow", + session_id="session", + question_id="question", + oauth_uuid="base-oauth", + ), + redirect_handler=redirect_handler, + callback_handler_for_oauth_uuid=callback_handler, + **kwargs, + ) + provider.context.client_info = OAuthClientInformationFull( + client_id="client-id", + redirect_uris=[AnyUrl("https://app.example.com/callback")], + ) + return provider + 
+ +def _make_pkce(verifier: str) -> PKCEParameters: + digest = hashlib.sha256(verifier.encode()).digest() + return PKCEParameters( + code_verifier=verifier, + code_challenge=base64.urlsafe_b64encode(digest).decode().rstrip("="), + ) + + +def test_generate_pkce_parameters_uses_salesforce_safe_verifier(): + pkce = _generate_pkce_parameters() + + assert len(pkce.code_verifier) == 128 + assert pkce.code_verifier.isalnum() + assert pkce.code_challenge == _make_pkce(pkce.code_verifier).code_challenge + + +@pytest.mark.asyncio +async def test_handle_callback_uses_separate_oauth_uuid(): + memory = SimpleNamespace( + original_question_uuid="question-id", + oauth_callback_fn=object(), + recv_oauth_callback=AsyncMock(return_value="code=auth-code&state=sdk-state"), + get_session_id=lambda: "session-id", + ) + + code, state = await handle_callback( + memory=memory, + module="module", + agent_id="agent-id", + oauth_uuid="oauth-attempt-id", + ) + + assert code == "auth-code" + assert state == "sdk-state" + memory.recv_oauth_callback.assert_awaited_once_with( + question_id="question-id", + oauth_uuid="oauth-attempt-id", + ) + + +@pytest.mark.asyncio +async def test_oauth_uuid_is_unique_per_authorization_grant(): + sdk_states_by_oauth_uuid: dict[str, str] = {} + callback_oauth_uuids: list[str] = [] + + def encrypt_state(routing: MCPOAuthRoutingParams) -> str: + sdk_states_by_oauth_uuid[routing.oauth_uuid] = routing.sdk_state + return f"encrypted-{routing.oauth_uuid}" + + async def callback_handler(oauth_uuid: str) -> tuple[str, str | None]: + callback_oauth_uuids.append(oauth_uuid) + return "auth-code", sdk_states_by_oauth_uuid[oauth_uuid] + + provider = _make_provider( + authorization_endpoint="https://auth.example.com/authorize", + callback_handler_for_oauth_uuid=callback_handler, + ) + + with ( + patch( + "hyperforge_mcp.http.uuid4", + side_effect=[ + SimpleNamespace(hex="oauth-grant-1"), + SimpleNamespace(hex="oauth-grant-2"), + ], + ), + patch( + 
"hyperforge_mcp.http.encrypt_mcp_oauth_state", + side_effect=encrypt_state, + ), + ): + await provider._perform_authorization_code_grant() + await provider._perform_authorization_code_grant() + + assert callback_oauth_uuids == ["oauth-grant-1", "oauth-grant-2"] + + +@pytest.mark.asyncio +async def test_authorization_grant_keeps_matching_pkce_pair(): + sdk_states_by_oauth_uuid: dict[str, str] = {} + + def encrypt_state(routing: MCPOAuthRoutingParams) -> str: + sdk_states_by_oauth_uuid[routing.oauth_uuid] = routing.sdk_state + return f"encrypted-{routing.oauth_uuid}" + + async def callback_handler(oauth_uuid: str) -> tuple[str, str | None]: + return f"auth-code-{oauth_uuid}", sdk_states_by_oauth_uuid[oauth_uuid] + + redirect_handler = AsyncMock() + provider = _make_provider( + authorization_endpoint="https://auth.example.com/authorize", + token_endpoint="https://auth.example.com/token", + redirect_handler=redirect_handler, + callback_handler_for_oauth_uuid=callback_handler, + ) + + verifier = "A" * 128 + pkce = _make_pkce(verifier) + with ( + patch( + "hyperforge_mcp.http.uuid4", + return_value=SimpleNamespace(hex="grant-id"), + ), + patch( + "hyperforge_mcp.http._generate_pkce_parameters", + return_value=pkce, + ), + patch( + "hyperforge_mcp.http.encrypt_mcp_oauth_state", + side_effect=encrypt_state, + ), + ): + auth_code, code_verifier = await provider._perform_authorization_code_grant() + + redirect_url = redirect_handler.await_args.args[0] + auth_params = parse_qs(urlparse(redirect_url).query) + assert auth_params["code_challenge"] == [pkce.code_challenge] + assert auth_params["code_challenge_method"] == ["S256"] + assert auth_code == "auth-code-grant-id" + assert code_verifier == verifier + + token_request = await provider._exchange_token_authorization_code( + auth_code, code_verifier + ) + token_params = parse_qs(token_request.content.decode()) + assert token_params["code"] == ["auth-code-grant-id"] + assert token_params["code_verifier"] == [verifier] + + 
+@pytest.mark.asyncio +async def test_token_exchange_uses_fresh_token_data_per_grant(): + provider = _make_provider( + token_endpoint="https://auth.example.com/token", + ) + + first_request = await provider._exchange_token_authorization_code( + "auth-code-1", "verifier-1" + ) + second_request = await provider._exchange_token_authorization_code( + "auth-code-2", "verifier-2" + ) + + first_params = parse_qs(first_request.content.decode()) + second_params = parse_qs(second_request.content.decode()) + + assert first_params["code"] == ["auth-code-1"] + assert first_params["code_verifier"] == ["verifier-1"] + assert second_params["code"] == ["auth-code-2"] + assert second_params["code_verifier"] == ["verifier-2"] + + +@pytest.mark.asyncio +async def test_oauth_auth_flow_is_single_flight_per_question_and_uri(): + _MCP_OAUTH_SINGLE_FLIGHT_LOCKS.clear() + started: list[int] = [] + release_first_flow = asyncio.Event() + + async def fake_auth_flow(self, request): + started.append(id(self)) + auth_request = httpx.Request("GET", "https://auth.example.com/authorize") + yield auth_request + await release_first_flow.wait() + + async def drive_auth_flow(provider): + flow = provider.async_auth_flow(httpx.Request("GET", "https://mcp.example.com")) + auth_request = await flow.__anext__() + with pytest.raises(StopAsyncIteration): + await flow.asend(httpx.Response(200, request=auth_request)) + + provider_1 = _make_provider(single_flight_key="same-question-uri") + provider_2 = _make_provider(single_flight_key="same-question-uri") + + try: + with patch( + "mcp.client.auth.OAuthClientProvider.async_auth_flow", + fake_auth_flow, + ): + task_1 = asyncio.create_task(drive_auth_flow(provider_1)) + while len(started) < 1: + await asyncio.sleep(0) + + task_2 = asyncio.create_task(drive_auth_flow(provider_2)) + await asyncio.sleep(0) + assert len(started) == 1 + + release_first_flow.set() + await asyncio.gather(task_1, task_2) + finally: + _MCP_OAUTH_SINGLE_FLIGHT_LOCKS.clear() + + assert 
len(started) == 2 diff --git a/hyperforge/src/hyperforge/api/v1/oauth.py b/hyperforge/src/hyperforge/api/v1/oauth.py index 5722cd9..5b44a64 100644 --- a/hyperforge/src/hyperforge/api/v1/oauth.py +++ b/hyperforge/src/hyperforge/api/v1/oauth.py @@ -1,5 +1,7 @@ import logging +from urllib.parse import urlencode +from cryptography.fernet import InvalidToken from fastapi import Query from starlette.requests import Request from starlette.responses import HTMLResponse @@ -7,12 +9,32 @@ from hyperforge.api.settings import Settings from hyperforge.api.v1.router import router from hyperforge.api.v1.utils import tracer +from hyperforge_mcp.http import _fingerprint, decrypt_mcp_oauth_state logger = logging.getLogger(__name__) RENDER = "
You can close this window and return to the application.
" +def _build_oauth_subject( + settings: Settings, + account_id: str, + agent_id: str, + workflow_id: str, + session: str, + question_id: str, + oauth_uuid: str, +) -> str: + return settings.oauth_subject.format( + account=account_id, + agent_id=agent_id, + session=session, + question=question_id, + oauth_uuid=oauth_uuid, + workflow_id=workflow_id, + ) + + @router.get( "/api/auth/agent/{agent_id}/workflow/{workflow_id}/session/{session}/oauth/{oauth_uuid}/callback", status_code=200, @@ -34,13 +56,14 @@ async def oauth_callback( Callback from oauth flow on RAO that requires to send creds to websocket """ settings: Settings = request.app.settings - subject = settings.oauth_subject.format( - account=account_id, - agent_id=agent_id, - session=session, - question=question_id, - oauth_uuid=oauth_uuid, - workflow_id=workflow_id, + subject = _build_oauth_subject( + settings, + account_id, + agent_id, + workflow_id, + session, + question_id, + oauth_uuid, ) # Request a question with tracer().start_as_current_span("Request activation"): @@ -58,3 +81,83 @@ async def oauth_callback( ) return HTMLResponse(content=RENDER) + + +@router.get( + "/api/auth/mcp/callback", + status_code=200, + description="Generic MCP OAuth callback (fixed redirect URI, state-routed)", + tags=["Retrieval Agent"], + include_in_schema=False, +) +async def mcp_oauth_callback_generic( + request: Request, + code: str | None = Query(None, include_in_schema=False), + state: str | None = Query(None, include_in_schema=False), + error: str | None = Query(None, include_in_schema=False), + error_description: str | None = Query(None, include_in_schema=False), +): + settings: Settings = request.app.settings + + if state is None: + logger.warning("MCP generic OAuth callback received without state parameter") + return HTMLResponse(content="Missing OAuth state parameter", status_code=400) + + try: + routing = decrypt_mcp_oauth_state(state) + except InvalidToken: + logger.warning( + "MCP generic OAuth callback: invalid 
or expired state (Fernet decryption failed)" + ) + return HTMLResponse(content="Invalid or expired OAuth state", status_code=400) + + sdk_state = routing.sdk_state or None + if sdk_state is None: + logger.warning( + "MCP generic OAuth callback: decrypted state missing sdk_state field" + ) + return HTMLResponse(content="Malformed OAuth state", status_code=400) + + subject = _build_oauth_subject( + settings, + routing.account_id, + routing.agent_id, + routing.workflow_id, + routing.session_id, + routing.question_id, + routing.oauth_uuid, + ) + + payload_data: dict[str, str] = {} + if code is not None: + payload_data["code"] = code + if sdk_state is not None: + payload_data["state"] = sdk_state + if error is not None: + payload_data["error"] = error + if error_description is not None: + payload_data["error_description"] = error_description + + payload = urlencode(payload_data) + + with tracer().start_as_current_span("MCP generic OAuth callback"): + logger.info( + "MCP generic OAuth callback: agent=%s workflow=%s session=%s question_id=%s oauth_uuid=%s sdk_state_fp=%s has_code=%s has_error=%s", + routing.agent_id, + routing.workflow_id, + routing.session_id, + routing.question_id, + routing.oauth_uuid, + _fingerprint(sdk_state), + code is not None, + error is not None, + ) + await request.app.broker.send_reply(subject, payload) + logger.info("mcp_oauth send_reply: published to stream %s", subject) + + if error is not None: + desc = f": {error_description}" if error_description else "" + return HTMLResponse( + content=f"OAuth authorization failed ({error}{desc})", status_code=400 + ) + return HTMLResponse(content=RENDER) diff --git a/hyperforge/src/hyperforge/interaction.py b/hyperforge/src/hyperforge/interaction.py index 80235ea..e03bf6c 100644 --- a/hyperforge/src/hyperforge/interaction.py +++ b/hyperforge/src/hyperforge/interaction.py @@ -60,6 +60,7 @@ class Provider(Enum): AZURE_CERTIFICATE_CREDENTIALS = "azure_certificate_credentials" AWS_S3_ACCESS_KEYS = 
"aws_s3_access_keys" SHAREFILE_OAUTH = "sharefile_oauth" + MCP_OAUTH = "mcp_oauth" class OAuthAuthenticateURL(BaseModel): diff --git a/hyperforge/src/hyperforge/memory/memory.py b/hyperforge/src/hyperforge/memory/memory.py index 4d22855..9729b7e 100644 --- a/hyperforge/src/hyperforge/memory/memory.py +++ b/hyperforge/src/hyperforge/memory/memory.py @@ -65,6 +65,7 @@ class BaseSessionMemory: agent_id: str = "" workflow_id: str = "" + account_id: str = "" kbid: Optional[str] = None # User information dictionary @@ -622,6 +623,10 @@ def get_workflow_id(self) -> str: """Returns the workflow ID for the current question. The workflow ID is a unique identifier that is shared across all questions and interactions that belong to the same workflow. This can be used to group related interactions together, and to keep track of the conversation history in a coherent way.""" return self.session.workflow_id + def get_account_id(self) -> str: + """Returns the account ID for the current question.""" + return self.session.account_id + def context_user_info(self) -> str: """Returns a string with user information that can be used in the context of the agent. This can include information such as user preferences, user history, or any other relevant information about the user that can help the agent to generate a more personalized and accurate response.""" return self.session.context_user_info() diff --git a/hyperforge/src/hyperforge/server/session.py b/hyperforge/src/hyperforge/server/session.py index 761ed73..02bb0e8 100644 --- a/hyperforge/src/hyperforge/server/session.py +++ b/hyperforge/src/hyperforge/server/session.py @@ -181,8 +181,12 @@ async def activate(self, message: StartInteraction): config=config.memory, agent=message.agent_id, workflow_id=message.workflow_id, + account_id=message.account, ) - self.memory[message.session] = memory + # Ephemeral sessions are short-lived and should not be persisted in + # the shared LRU cache across activations. 
+ if message.session != "ephemeral": + self.memory[message.session] = memory else: memory = self.memory[message.session] diff --git a/hyperforge/src/hyperforge/server/utils.py b/hyperforge/src/hyperforge/server/utils.py index f5b4f8b..c185079 100644 --- a/hyperforge/src/hyperforge/server/utils.py +++ b/hyperforge/src/hyperforge/server/utils.py @@ -16,6 +16,7 @@ async def get_memory( config: MemoryConfig, agent: str, workflow_id: str, + account_id: str = "", ) -> BaseSessionMemory: memory: BaseSessionMemory @@ -53,5 +54,6 @@ async def get_memory( ) memory.init(session=session) + memory.account_id = account_id return memory diff --git a/hyperforge/src/hyperforge/settings.py b/hyperforge/src/hyperforge/settings.py index 69f8e2a..62db0b0 100644 --- a/hyperforge/src/hyperforge/settings.py +++ b/hyperforge/src/hyperforge/settings.py @@ -2,17 +2,23 @@ from pydantic_settings import BaseSettings _REDIRECT_PATH = "/api/auth/agent/{agent_id}/workflow/{workflow_id}/session/{session_id}/oauth/{oauth_uuid}/callback" +_MCP_CALLBACK_PATH = "/api/auth/mcp/callback" class OAuthSettings(BaseSettings): nuclia_public_url: str = "https://{zone}.nuclia.com" nuclia_zone: str = "arag" rao_redirect_url: str = "" + mcp_callback_url: str = "" @model_validator(mode="after") def _resolve_urls(self) -> "OAuthSettings": self.nuclia_public_url = self.nuclia_public_url.format(zone=self.nuclia_zone) if not self.rao_redirect_url: self.rao_redirect_url = self.nuclia_public_url.rstrip("/") + _REDIRECT_PATH + if not self.mcp_callback_url: + self.mcp_callback_url = ( + self.nuclia_public_url.rstrip("/") + _MCP_CALLBACK_PATH + ) return self diff --git a/hyperforge/tests/server/test_session_manager.py b/hyperforge/tests/server/test_session_manager.py new file mode 100644 index 0000000..8ab6e26 --- /dev/null +++ b/hyperforge/tests/server/test_session_manager.py @@ -0,0 +1,59 @@ +from types import SimpleNamespace +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from 
hyperforge.pubsub import StartInteraction +from hyperforge.server.session import SessionManager + + +@pytest.mark.asyncio +async def test_ephemeral_session_is_not_cached(): + manager = SessionManager( + settings=SimpleNamespace( + answers_subject="arag.{account}.{agent_id}.{workflow_id}.{session}.{question}.answer", + internal_nucliadb_url="", + internal_nua_api="", + internal_nua=False, + local_openai=None, + external_nua_api_key=None, + standalone=False, + ), + broker=None, # type: ignore[arg-type] + agent_manager=SimpleNamespace(get_agent_config=AsyncMock()), + cache=None, # type: ignore[arg-type] + ) + manager.agent_manager.get_agent_config.return_value = SimpleNamespace( + memory=SimpleNamespace(), + rules=SimpleNamespace(rules=[]), + ) + + memory = SimpleNamespace(rules=None) + memory.start_question = MagicMock(return_value=SimpleNamespace()) + task = MagicMock() + message = StartInteraction( + account="account", + agent_id="agent", + session="ephemeral", + question_id="question-id", + question="question", + ) + + with ( + patch( + "hyperforge.server.session.get_memory", + new_callable=AsyncMock, + return_value=memory, + ) as get_memory, + patch( + "hyperforge.server.session.get_state", + new_callable=AsyncMock, + return_value=SimpleNamespace(), + ), + patch.object(manager, "answer", new=MagicMock(return_value=object())), + patch("hyperforge.server.session.asyncio.create_task", return_value=task), + ): + await manager.activate(message) + + get_memory.assert_awaited_once() + assert "ephemeral" not in manager.memory diff --git a/hyperforge/tests/unit/arag/test_mcp_oauth_callback.py b/hyperforge/tests/unit/arag/test_mcp_oauth_callback.py new file mode 100644 index 0000000..cf97972 --- /dev/null +++ b/hyperforge/tests/unit/arag/test_mcp_oauth_callback.py @@ -0,0 +1,174 @@ +"""Unit tests for the MCP OAuth generic callback endpoint.""" + +from unittest.mock import AsyncMock +from urllib.parse import parse_qs + +import pytest +from cryptography.fernet import Fernet 
+from starlette.testclient import TestClient + +import hyperforge_mcp.http as http_module +from hyperforge_mcp.http import ( + MCPOAuthRoutingParams, + decrypt_mcp_oauth_state, + encrypt_mcp_oauth_state, +) + +# --------------------------------------------------------------------------- +# Test Fernet key - injected via env var so EncryptionSettings picks it up. +# --------------------------------------------------------------------------- +_TEST_KEY = Fernet.generate_key().decode() + +_ROUTING = MCPOAuthRoutingParams( + account_id="acct-1", + agent_id="agent-1", + workflow_id="wf-1", + session_id="sess-1", + question_id="q-1", + oauth_uuid="q-1", +) + +_ROUTING_WITH_SDK_STATE = _ROUTING.model_copy( + update={"sdk_state": "random-sdk-nonce-abc123"} +) + + +@pytest.fixture(autouse=True) +def _inject_key(monkeypatch): + """Inject the test Fernet key and clear the @cache so tests use it.""" + monkeypatch.setenv("ENCRYPTION_SECRET_KEY", _TEST_KEY) + # Clear @cache so _get_mcp_fernet() returns a Fernet with the test key. 
+ + http_module._get_mcp_fernet.cache_clear() + yield + http_module._get_mcp_fernet.cache_clear() + + +def _make_app(): + """Return a minimal FastAPI app that mounts only the oauth router.""" + from fastapi import FastAPI + + from hyperforge.api.settings import Settings + from hyperforge.api.v1.router import router + + app = FastAPI() + app.settings = Settings() + app.broker = AsyncMock() + app.include_router(router) + return app + + +@pytest.fixture() +def client(): + app = _make_app() + with TestClient(app, raise_server_exceptions=True) as c: + yield c, app + + +# --------------------------------------------------------------------------- +# Unit tests for encrypt_mcp_oauth_state / decrypt_mcp_oauth_state +# --------------------------------------------------------------------------- + + +def test_encrypt_decrypt_roundtrip(): + """Verify the full model_dump_json -> Fernet -> model_validate_json round-trip works.""" + token = encrypt_mcp_oauth_state(_ROUTING_WITH_SDK_STATE) + result = decrypt_mcp_oauth_state(token) + assert result == _ROUTING_WITH_SDK_STATE + + +# --------------------------------------------------------------------------- +# Tests for the HTTP endpoint +# --------------------------------------------------------------------------- + + +def _make_state(sdk_state: str = "sdk-nonce-abc") -> str: + return encrypt_mcp_oauth_state(_ROUTING.model_copy(update={"sdk_state": sdk_state})) + + +def test_missing_state_returns_400(client): + tc, _ = client + resp = tc.get("/api/auth/mcp/callback?code=abc123") + assert resp.status_code == 400 + assert "Missing OAuth state" in resp.text + + +def test_invalid_state_returns_400(client): + tc, _ = client + resp = tc.get("/api/auth/mcp/callback?state=garbage-token&code=abc123") + assert resp.status_code == 400 + assert "Invalid or expired" in resp.text + + +def test_valid_state_publishes_sdk_state_and_returns_200(client): + tc, app = client + sdk_nonce = "my-sdk-nonce" + state = _make_state(sdk_nonce) + + resp = 
tc.get(f"/api/auth/mcp/callback?state={state}&code=mycode") + + assert resp.status_code == 200 + assert app.broker.send_reply.called + _, published_payload = app.broker.send_reply.call_args[0] + params = parse_qs(published_payload) + assert params["code"] == ["mycode"] + # The published state must be sdk_state (not the full Fernet token). + assert params["state"] == [sdk_nonce] + + +def test_valid_state_can_be_used_twice(client): + """Fernet tokens are stateless - the same token can theoretically be reused + within its TTL window. This tests that the endpoint is idempotent.""" + tc, app = client + state = _make_state() + + resp1 = tc.get(f"/api/auth/mcp/callback?state={state}&code=mycode") + resp2 = tc.get(f"/api/auth/mcp/callback?state={state}&code=mycode") + assert resp1.status_code == 200 + assert resp2.status_code == 200 + + +def test_oauth_error_publishes_and_returns_400(client): + tc, app = client + state = _make_state() + + resp = tc.get( + f"/api/auth/mcp/callback?state={state}&error=access_denied&error_description=User+denied" + ) + + assert resp.status_code == 400 + assert "access_denied" in resp.text + assert app.broker.send_reply.called + _, published_payload = app.broker.send_reply.call_args[0] + params = parse_qs(published_payload) + assert params["error"] == ["access_denied"] + assert params["error_description"] == ["User denied"] + + +def test_oauth_error_without_description(client): + tc, _ = client + state = _make_state() + + resp = tc.get(f"/api/auth/mcp/callback?state={state}&error=server_error") + assert resp.status_code == 400 + assert "server_error" in resp.text + + +def test_state_encrypted_with_different_key_returns_400(client, monkeypatch): + """A token encrypted with a different key must be rejected.""" + + other_key = Fernet.generate_key().decode() + monkeypatch.setenv("ENCRYPTION_SECRET_KEY", other_key) + http_module._get_mcp_fernet.cache_clear() + + # encrypt_mcp_oauth_state now uses the *other* key + state_with_other_key = 
_make_state() + + # Reset to original test key for the endpoint + monkeypatch.setenv("ENCRYPTION_SECRET_KEY", _TEST_KEY) + http_module._get_mcp_fernet.cache_clear() + + tc, _ = client + resp = tc.get(f"/api/auth/mcp/callback?state={state_with_other_key}&code=abc") + assert resp.status_code == 400 + assert "Invalid or expired" in resp.text