Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 11 additions & 30 deletions src/agentex/lib/core/temporal/plugins/openai_agents/hooks/hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from agents.tool_context import ToolContext

from agentex.types.text_content import TextContent
from agentex.types.task_message_content import ToolRequestContent, ToolResponseContent
from agentex.types.task_message_content import ToolResponseContent
from agentex.lib.core.observability.llm_metrics_hooks import LLMMetricsHooks
from agentex.lib.core.temporal.plugins.openai_agents.hooks.activities import stream_lifecycle_content

Expand Down Expand Up @@ -106,44 +106,25 @@ async def on_agent_end(self, context: RunContextWrapper, agent: Agent, output: A

@override
async def on_tool_start(self, context: RunContextWrapper, agent: Agent, tool: Tool) -> None: # noqa: ARG002
"""Stream tool request when a tool starts execution.
"""Called when a tool starts execution.

Extracts the tool_call_id and tool_arguments from the context and streams a
ToolRequestContent message to the UI showing that the tool is about to execute.
The tool request (ToolRequestContent) is now streamed live by the model
layer (TemporalStreamingModel) as the function-call arguments arrive over
the Responses API stream. Emitting it again here would double-render the
tool request in the UI, so this hook no longer streams a ToolRequestContent.
The tool *result* is still streamed by on_tool_end (ToolResponseContent),
which the model stream does not produce.

Args:
context: The run context wrapper (will be a ToolContext with tool_call_id and tool_arguments)
agent: The agent executing the tool
tool: The tool being executed
"""
import json

tool_context = context if isinstance(context, ToolContext) else None
tool_call_id = tool_context.tool_call_id if tool_context else f"call_{id(tool)}"

# Extract tool arguments from context
tool_arguments = {}
if tool_context and hasattr(tool_context, 'tool_arguments'):
try:
# tool_arguments is a JSON string, parse it
tool_arguments = json.loads(tool_context.tool_arguments)
except (json.JSONDecodeError, TypeError):
# If parsing fails, log and use empty dict
logger.warning(f"Failed to parse tool arguments: {tool_context.tool_arguments}")
tool_arguments = {}

await workflow.execute_activity(
stream_lifecycle_content,
args=[
self.task_id,
ToolRequestContent(
author="agent",
tool_call_id=tool_call_id,
name=tool.name,
arguments=tool_arguments,
).model_dump(),
],
start_to_close_timeout=self.timeout,
logger.debug(
f"[TemporalStreamingHooks] Tool '{tool.name}' started (tool_call_id={tool_call_id}); "
"tool request is streamed live by the model layer, not re-emitted here."
)

@override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,10 @@
from agentex.lib.utils.logging import make_logger
from agentex.lib.core.tracing.tracer import AsyncTracer
from agentex.types.task_message_delta import TextDelta, ReasoningContentDelta, ReasoningSummaryDelta
from agentex.types.tool_request_delta import ToolRequestDelta
from agentex.types.task_message_update import StreamTaskMessageFull, StreamTaskMessageDelta
from agentex.types.task_message_content import TextContent, ReasoningContent
from agentex.types.tool_request_content import ToolRequestContent
from agentex.lib.adk.utils._modules.client import create_async_agentex_client
from agentex.lib.core.temporal.plugins.openai_agents.interceptors.context_interceptor import (
streaming_task_id,
Expand Down Expand Up @@ -671,6 +673,7 @@ async def get_response(

# Process events from the Responses API stream
function_calls_in_progress = {} # Track function calls being streamed
tool_call_contexts: dict[int, Any] = {} # Open streaming contexts per function call

async for event in stream:
event_count += 1
Expand Down Expand Up @@ -723,14 +726,29 @@ async def get_response(
).__aenter__()
elif item and getattr(item, 'type', None) == 'function_call':
# Track the function call being streamed
call_id = getattr(item, 'call_id', '')
name = getattr(item, 'name', '')
function_calls_in_progress[output_index] = {
'id': getattr(item, 'id', ''),
'call_id': getattr(item, 'call_id', ''),
'name': getattr(item, 'name', ''),
'call_id': call_id,
'name': name,
'arguments': getattr(item, 'arguments', ''),
}
logger.debug(f"[TemporalStreamingModel] Starting function call: {item.name}")

# Open a streaming context so tool-call args stream live to the UI
tool_ctx = await adk.streaming.streaming_task_message_context(
task_id=task_id,
initial_content=ToolRequestContent(
author="agent",
tool_call_id=call_id,
name=name,
arguments={},
),
streaming_mode=self.streaming_mode,
).__aenter__()
tool_call_contexts[output_index] = tool_ctx

elif item and getattr(item, 'type', None) == 'message':
# Track the message being streamed
streaming_context = await adk.streaming.streaming_task_message_context(
Expand All @@ -752,6 +770,25 @@ async def get_response(
function_calls_in_progress[output_index]['arguments'] += delta
logger.debug(f"[TemporalStreamingModel] Function call args delta: {delta[:50]}...")

# Stream the args delta live to the UI. The delta event carries
# no call_id/name, so pull them from function_calls_in_progress
# (populated at announce time, keyed by output_index).
ctx = tool_call_contexts.get(output_index)
if ctx is not None and delta:
try:
await ctx.stream_update(StreamTaskMessageDelta(
parent_task_message=ctx.task_message,
delta=ToolRequestDelta(
type="tool_request",
tool_call_id=function_calls_in_progress[output_index]['call_id'],
name=function_calls_in_progress[output_index]['name'],
arguments_delta=delta,
),
type="delta",
))
except Exception as e:
logger.warning(f"Failed to stream tool-call args delta: {e}")

elif isinstance(event, ResponseFunctionCallArgumentsDoneEvent):
# Function call arguments complete
output_index = getattr(event, 'output_index', 0)
Expand Down Expand Up @@ -874,6 +911,14 @@ async def get_response(
)
output_items.append(tool_call)

# Close + pop the live-streaming context for this tool call
ctx = tool_call_contexts.pop(output_index, None)
if ctx is not None:
try:
await ctx.close()
except Exception as e:
logger.warning(f"Failed to close tool-call stream context: {e}")

elif isinstance(event, ResponseReasoningSummaryPartAddedEvent):
# New reasoning part/summary started - reset accumulator
part = getattr(event, 'part', None)
Expand Down Expand Up @@ -907,6 +952,16 @@ async def get_response(
await streaming_context.close()
streaming_context = None

# Close any tool-call contexts still open (e.g. stream ended without
# a per-item done event). A partial / invalid-JSON args close can
# raise, so guard each one so it never crashes the activity.
for ctx in tool_call_contexts.values():
try:
await ctx.close()
except Exception as e:
logger.warning(f"Failed to close tool-call stream context: {e}")
tool_call_contexts.clear()

# Build the response from output items collected during streaming
# Create output from the items we collected
response_output = []
Expand Down Expand Up @@ -1061,6 +1116,17 @@ async def get_response(

except Exception as e:
logger.error(f"Error using Responses API: {e}")
# Close any tool-call streaming contexts still open so the error
# path doesn't leak open contexts. A partial / invalid-JSON args
# close can raise, so guard each one. tool_call_contexts may be
# unbound if the error fired before the stream loop started.
for ctx in locals().get("tool_call_contexts", {}).values():
try:
await ctx.close()
except Exception as close_err:
logger.warning(f"Failed to close tool-call stream context: {close_err}")
if "tool_call_contexts" in locals():
tool_call_contexts.clear()
# LLMMetricsHooks.on_llm_end doesn't fire on error, so emit the
# failure counter here. Best-effort so the typed LLM exception
# always propagates intact for retry / circuit-breaker logic.
Expand Down
Loading