From 265515e63ed2656ffa514d3c328564679600eab6 Mon Sep 17 00:00:00 2001 From: Daniel Miller Date: Fri, 29 May 2026 10:09:00 -0400 Subject: [PATCH] feat(openai-agents): stream tool-call argument deltas to the UI MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The TemporalStreamingModel already parsed the Responses API function_call_arguments deltas but only accumulated them — tool calls reached the UI all-at-once via the hooks layer. Now, as the model generates a tool call, the model layer opens a ToolRequestContent streaming context per call (keyed by output_index) and pushes ToolRequestDelta chunks as the arguments arrive, closing on item-done (mirrors the existing text-delta path). on_tool_start no longer emits a duplicate ToolRequestContent (the model layer streams it live now); on_tool_end still emits the ToolResponseContent result. The returned ModelResponse/output_items/usage assembly is unchanged. Co-Authored-By: Claude Opus 4.8 (1M context) --- .../plugins/openai_agents/hooks/hooks.py | 41 +++-------- .../models/temporal_streaming_model.py | 70 ++++++++++++++++++- 2 files changed, 79 insertions(+), 32 deletions(-) diff --git a/src/agentex/lib/core/temporal/plugins/openai_agents/hooks/hooks.py b/src/agentex/lib/core/temporal/plugins/openai_agents/hooks/hooks.py index 758b0db27..40e17f4cb 100644 --- a/src/agentex/lib/core/temporal/plugins/openai_agents/hooks/hooks.py +++ b/src/agentex/lib/core/temporal/plugins/openai_agents/hooks/hooks.py @@ -13,7 +13,7 @@ from agents.tool_context import ToolContext from agentex.types.text_content import TextContent -from agentex.types.task_message_content import ToolRequestContent, ToolResponseContent +from agentex.types.task_message_content import ToolResponseContent from agentex.lib.core.observability.llm_metrics_hooks import LLMMetricsHooks from agentex.lib.core.temporal.plugins.openai_agents.hooks.activities import stream_lifecycle_content @@ -106,44 +106,25 @@ async def 
on_agent_end(self, context: RunContextWrapper, agent: Agent, output: A @override async def on_tool_start(self, context: RunContextWrapper, agent: Agent, tool: Tool) -> None: # noqa: ARG002 - """Stream tool request when a tool starts execution. + """Called when a tool starts execution. - Extracts the tool_call_id and tool_arguments from the context and streams a - ToolRequestContent message to the UI showing that the tool is about to execute. + The tool request (ToolRequestContent) is now streamed live by the model + layer (TemporalStreamingModel) as the function-call arguments arrive over + the Responses API stream. Emitting it again here would double-render the + tool request in the UI, so this hook no longer streams a ToolRequestContent. + The tool *result* is still streamed by on_tool_end (ToolResponseContent), + which the model stream does not produce. Args: context: The run context wrapper (will be a ToolContext with tool_call_id and tool_arguments) agent: The agent executing the tool tool: The tool being executed """ - import json - tool_context = context if isinstance(context, ToolContext) else None tool_call_id = tool_context.tool_call_id if tool_context else f"call_{id(tool)}" - - # Extract tool arguments from context - tool_arguments = {} - if tool_context and hasattr(tool_context, 'tool_arguments'): - try: - # tool_arguments is a JSON string, parse it - tool_arguments = json.loads(tool_context.tool_arguments) - except (json.JSONDecodeError, TypeError): - # If parsing fails, log and use empty dict - logger.warning(f"Failed to parse tool arguments: {tool_context.tool_arguments}") - tool_arguments = {} - - await workflow.execute_activity( - stream_lifecycle_content, - args=[ - self.task_id, - ToolRequestContent( - author="agent", - tool_call_id=tool_call_id, - name=tool.name, - arguments=tool_arguments, - ).model_dump(), - ], - start_to_close_timeout=self.timeout, + logger.debug( + f"[TemporalStreamingHooks] Tool '{tool.name}' started 
(tool_call_id={tool_call_id}); " + "tool request is streamed live by the model layer, not re-emitted here." ) @override diff --git a/src/agentex/lib/core/temporal/plugins/openai_agents/models/temporal_streaming_model.py b/src/agentex/lib/core/temporal/plugins/openai_agents/models/temporal_streaming_model.py index 7ccc6627a..2ea14dfa8 100644 --- a/src/agentex/lib/core/temporal/plugins/openai_agents/models/temporal_streaming_model.py +++ b/src/agentex/lib/core/temporal/plugins/openai_agents/models/temporal_streaming_model.py @@ -64,8 +64,10 @@ from agentex.lib.utils.logging import make_logger from agentex.lib.core.tracing.tracer import AsyncTracer from agentex.types.task_message_delta import TextDelta, ReasoningContentDelta, ReasoningSummaryDelta +from agentex.types.tool_request_delta import ToolRequestDelta from agentex.types.task_message_update import StreamTaskMessageFull, StreamTaskMessageDelta from agentex.types.task_message_content import TextContent, ReasoningContent +from agentex.types.tool_request_content import ToolRequestContent from agentex.lib.adk.utils._modules.client import create_async_agentex_client from agentex.lib.core.temporal.plugins.openai_agents.interceptors.context_interceptor import ( streaming_task_id, @@ -671,6 +673,7 @@ async def get_response( # Process events from the Responses API stream function_calls_in_progress = {} # Track function calls being streamed + tool_call_contexts: dict[int, Any] = {} # Open streaming contexts per function call async for event in stream: event_count += 1 @@ -723,14 +726,29 @@ async def get_response( ).__aenter__() elif item and getattr(item, 'type', None) == 'function_call': # Track the function call being streamed + call_id = getattr(item, 'call_id', '') + name = getattr(item, 'name', '') function_calls_in_progress[output_index] = { 'id': getattr(item, 'id', ''), - 'call_id': getattr(item, 'call_id', ''), - 'name': getattr(item, 'name', ''), + 'call_id': call_id, + 'name': name, 'arguments': getattr(item, 
'arguments', ''), } logger.debug(f"[TemporalStreamingModel] Starting function call: {item.name}") + # Open a streaming context so tool-call args stream live to the UI + tool_ctx = await adk.streaming.streaming_task_message_context( + task_id=task_id, + initial_content=ToolRequestContent( + author="agent", + tool_call_id=call_id, + name=name, + arguments={}, + ), + streaming_mode=self.streaming_mode, + ).__aenter__() + tool_call_contexts[output_index] = tool_ctx + elif item and getattr(item, 'type', None) == 'message': # Track the message being streamed streaming_context = await adk.streaming.streaming_task_message_context( @@ -752,6 +770,25 @@ async def get_response( function_calls_in_progress[output_index]['arguments'] += delta logger.debug(f"[TemporalStreamingModel] Function call args delta: {delta[:50]}...") + # Stream the args delta live to the UI. The delta event carries + # no call_id/name, so pull them from function_calls_in_progress + # (populated at announce time, keyed by output_index). 
+ ctx = tool_call_contexts.get(output_index) + if ctx is not None and delta: + try: + await ctx.stream_update(StreamTaskMessageDelta( + parent_task_message=ctx.task_message, + delta=ToolRequestDelta( + type="tool_request", + tool_call_id=function_calls_in_progress[output_index]['call_id'], + name=function_calls_in_progress[output_index]['name'], + arguments_delta=delta, + ), + type="delta", + )) + except Exception as e: + logger.warning(f"Failed to stream tool-call args delta: {e}") + elif isinstance(event, ResponseFunctionCallArgumentsDoneEvent): # Function call arguments complete output_index = getattr(event, 'output_index', 0) @@ -874,6 +911,14 @@ async def get_response( ) output_items.append(tool_call) + # Close + pop the live-streaming context for this tool call + ctx = tool_call_contexts.pop(output_index, None) + if ctx is not None: + try: + await ctx.close() + except Exception as e: + logger.warning(f"Failed to close tool-call stream context: {e}") + elif isinstance(event, ResponseReasoningSummaryPartAddedEvent): # New reasoning part/summary started - reset accumulator part = getattr(event, 'part', None) @@ -907,6 +952,16 @@ async def get_response( await streaming_context.close() streaming_context = None + # Close any tool-call contexts still open (e.g. stream ended without + # a per-item done event). A partial / invalid-JSON args close can + # raise, so guard each one so it never crashes the activity. + for ctx in tool_call_contexts.values(): + try: + await ctx.close() + except Exception as e: + logger.warning(f"Failed to close tool-call stream context: {e}") + tool_call_contexts.clear() + # Build the response from output items collected during streaming # Create output from the items we collected response_output = [] @@ -1061,6 +1116,17 @@ async def get_response( except Exception as e: logger.error(f"Error using Responses API: {e}") + # Close any tool-call streaming contexts still open so the error + # path doesn't leak open contexts. 
Closing a context whose args are + # partial or invalid JSON can raise, so guard each one. tool_call_contexts may be + # unbound if the error fired before the stream loop started. + for ctx in locals().get("tool_call_contexts", {}).values(): + try: + await ctx.close() + except Exception as close_err: + logger.warning(f"Failed to close tool-call stream context: {close_err}") + if "tool_call_contexts" in locals(): + tool_call_contexts.clear() # LLMMetricsHooks.on_llm_end doesn't fire on error, so emit the # failure counter here. Best-effort so the typed LLM exception # always propagates intact for retry / circuit-breaker logic.