diff --git a/docs/architecture.md b/docs/architecture.md index ea2601a..296c29f 100644 --- a/docs/architecture.md +++ b/docs/architecture.md @@ -75,7 +75,7 @@ In `utils/tool_dispatch.py`, tool results are classified through `_parse_tool_re When adding a new tool renderer: -1. Add a `(predicate, builder)` pair to `_TOOL_RESULT_DISPATCH` in `utils/tool_dispatch.py`, preserving existing predicate order unless you also update fixtures and ordering tests (`tests/test_jsonl_parser.py`, `tests/test_real_session_fixtures.py`). Order is **not** “specific before generic” in general — the first match wins. `_tool_result_pred_task_message` is the intentional broad-before-narrow exception (`task_id` or `message` before retrieval/completed/async). +1. Add a `(predicate, builder)` pair to `_TOOL_RESULT_DISPATCH` in `utils/tool_dispatch.py`, preserving existing predicate order unless you also update fixtures and ordering tests (`tests/test_jsonl_parser.py`, `tests/test_real_session_fixtures.py`). Order is **not** “specific before generic” in general — the first match wins. `is_task_message_tool_result` is the intentional broad-before-narrow exception (`task_id` or `message` before retrieval/completed/async). 2. Add or extend a JSONL fixture under `tests/fixtures/` (especially for overlaps with existing predicates). 3. Run `pytest tests/test_jsonl_parser.py tests/test_real_session_fixtures.py -v`. diff --git a/models/__init__.py b/models/__init__.py index 5f5b21c..089a5d5 100644 --- a/models/__init__.py +++ b/models/__init__.py @@ -3,14 +3,17 @@ from models.errors import ErrorResponse from models.export import ExportStateDict from models.project import ProjectDict, ProjectSessionRowDict, SessionListItemDict +from models.record_data import RecordDataUnion from models.search import SearchHitDict from models.session import ( MessageDict, QuickSessionInfoDict, SessionDict, SessionMetadataDict, + ToolUseDict, ) from models.stats import FilesTouchedDict, SessionStatsDict +from models.tool_results import ToolResultUnion __all__ = [ "ErrorResponse", @@ -25,4 +28,7 @@ "SessionListItemDict", "SessionMetadataDict", "SessionStatsDict", + "RecordDataUnion", + "ToolResultUnion", + "ToolUseDict", ] diff --git a/models/record_data.py b/models/record_data.py new file mode 100644 index 0000000..04b1512 --- /dev/null +++ b/models/record_data.py @@ -0,0 +1,43 @@ +"""TypedDict shapes for record-level ``data`` payloads on progress messages.""" + +from typing import Literal, TypedDict + + +class BashProgressDataDict(TypedDict, total=False): + type: Literal["bash_progress"] + output: str + + +class HookProgressDataDict(TypedDict, total=False): + type: Literal["hook_progress"] + output: str + + +class AgentProgressDataDict(TypedDict, total=False): + type: Literal["agent_progress"] + message: str + + +class SummaryDataDict(TypedDict, total=False): + """Summary-style progress payloads (when present on progress entries).""" + + type: Literal["summary"] + summary: str + + +class CompactBoundaryDataDict(TypedDict, total=False): + """Compact-boundary metadata when carried on a data blob.""" + + type: Literal["compact_boundary"] + trigger: str + pre_tokens: int + + +RecordDataUnion = ( + BashProgressDataDict + | HookProgressDataDict + | AgentProgressDataDict + | SummaryDataDict + | CompactBoundaryDataDict + | dict[str, object] +) diff --git a/models/session.py b/models/session.py index 85a0791..2fd150c 100644 --- a/models/session.py +++ b/models/session.py @@ -1,6 +1,27 @@ """Parsed session shapes from jsonl_parser.""" -from typing import Any, NotRequired, TypedDict +from typing import Any, Literal, NotRequired, TypedDict + +from models.record_data import RecordDataUnion +from models.tool_results import ToolNameLiteral, ToolResultUnion + + +class ToolUseDict(TypedDict, total=False): + id: str + # Literal | str is just str for mypy — documents known tool names, not exhaustiveness. + name: ToolNameLiteral | str + input: dict[str, object] + + +class MessageUsageDict(TypedDict, total=False): + input_tokens: int + output_tokens: int + cache_read: int + cache_creation: int + service_tier: str | None + + +SystemSubtypeLiteral = Literal["compact_boundary", "init"] class MessageDict(TypedDict): @@ -12,18 +33,18 @@ class MessageDict(TypedDict): content: NotRequired[str] images: NotRequired[list[Any] | None] is_sidechain: NotRequired[bool] - tool_result: NotRequired[Any] - tool_result_parsed: NotRequired[dict[str, Any] | None] + tool_result: NotRequired[ToolResultUnion | None] + tool_result_parsed: NotRequired[dict[str, object] | None] slug: NotRequired[str | None] model: NotRequired[str] stop_reason: NotRequired[str] thinking: NotRequired[str | None] - tool_uses: NotRequired[list[dict[str, Any]] | None] + tool_uses: NotRequired[list[ToolUseDict] | None] is_api_error: NotRequired[bool] - usage: NotRequired[dict[str, Any]] + usage: NotRequired[MessageUsageDict] subtype: NotRequired[str] level: NotRequired[str] - data: NotRequired[Any] + data: NotRequired[RecordDataUnion] progress_type: NotRequired[str] tool_use_id: NotRequired[str | None] parent_tool_use_id: NotRequired[str | None] diff --git a/models/tool_results.py b/models/tool_results.py new file mode 100644 index 0000000..ca881b7 --- /dev/null +++ b/models/tool_results.py @@ -0,0 +1,233 @@ +"""TypedDict shapes for Claude Code toolUseResult blobs at the JSONL parse boundary. + +Ground truth: tests/test_jsonl_parser.py, tests/test_real_session_fixtures.py, +and utils/tool_dispatch.py predicate order (first match wins). +""" + +from typing import Literal, TypedDict, TypeGuard + + +class BashToolResultDict(TypedDict, total=False): + stdout: str + stderr: str + exitCode: int + interrupted: bool + is_error: bool + returnCodeInterpretation: str + + +class FileEditToolResultDict(TypedDict, total=False): + structuredPatch: str + filePath: str + newString: str + replaceAll: bool + + +class PlanToolResultDict(TypedDict, total=False): + plan: list[object] + filePath: str + content: str + + +class FileWriteToolResultDict(TypedDict, total=False): + filePath: str + content: str + + +class GlobToolResultDict(TypedDict, total=False): + filenames: list[str] + numFiles: int + truncated: bool + durationMs: int + + +class GrepToolResultDict(TypedDict, total=False): + mode: str + numFiles: int + numLines: int + content: str + durationMs: int + + +class ReadFileObjDict(TypedDict, total=False): + filePath: str + numLines: int + content: str + + +class ReadToolResultDict(TypedDict, total=False): + file: ReadFileObjDict + content: list[object] + + +class WebSearchToolResultDict(TypedDict, total=False): + query: str + results: list[object] | None + durationSeconds: float + + +class WebFetchToolResultDict(TypedDict, total=False): + url: str + code: int + durationMs: int + + +class TaskMessageToolResultDict(TypedDict, total=False): + task_id: str + task_type: str + message: str + agentId: str + + +class TaskRetrievalToolResultDict(TypedDict, total=False): + retrieval_status: str + task: dict[str, object] + + +class TaskCompletedToolResultDict(TypedDict, total=False): + agentId: str + totalDurationMs: int + status: str + totalTokens: int + totalToolUseCount: int + + +class TaskAsyncToolResultDict(TypedDict, total=False): + agentId: str + isAsync: bool + status: str + description: str + + +class TodoItemDict(TypedDict, total=False): + id: str + content: str + + +class TodoWriteToolResultDict(TypedDict, total=False): + newTodos: list[TodoItemDict] + oldTodos: list[TodoItemDict] + + +class UserInputToolResultDict(TypedDict, total=False): + questions: list[dict[str, object]] + answers: dict[str, object] + + +class ToolResultContentBlockDict(TypedDict, total=False): + type: str + source: dict[str, object] + + +class ToolResultWithContentDict(TypedDict, total=False): + """Read-on-image and similar payloads that embed content blocks.""" + + content: list[ToolResultContentBlockDict] + + +# Dict passed into dispatch predicates (structural superset of all tool blobs). +ToolResultDict = dict[str, object] + +ToolResultUnion = ( + str + | BashToolResultDict + | FileEditToolResultDict + | PlanToolResultDict + | FileWriteToolResultDict + | GlobToolResultDict + | GrepToolResultDict + | ReadToolResultDict + | WebSearchToolResultDict + | WebFetchToolResultDict + | TaskMessageToolResultDict + | TaskRetrievalToolResultDict + | TaskCompletedToolResultDict + | TaskAsyncToolResultDict + | TodoWriteToolResultDict + | UserInputToolResultDict + | ToolResultWithContentDict + | dict[str, object] +) + + +def is_tool_result_dict(tr: ToolResultUnion | None) -> TypeGuard[ToolResultDict]: + return isinstance(tr, dict) + + +def is_bash_tool_result(tr: ToolResultDict) -> TypeGuard[BashToolResultDict]: + return "stdout" in tr or "stderr" in tr + + +def is_file_edit_tool_result(tr: ToolResultDict) -> TypeGuard[FileEditToolResultDict]: + return "structuredPatch" in tr or ("filePath" in tr and "newString" in tr) + + +def is_plan_tool_result(tr: ToolResultDict) -> TypeGuard[PlanToolResultDict]: + return "plan" in tr and "filePath" in tr + + +def is_file_write_tool_result(tr: ToolResultDict) -> TypeGuard[FileWriteToolResultDict]: + return "filePath" in tr and "content" in tr + + +def is_glob_tool_result(tr: ToolResultDict) -> TypeGuard[GlobToolResultDict]: + filenames = tr.get("filenames") + return "filenames" in tr and isinstance(filenames, list) + + +def is_grep_tool_result(tr: ToolResultDict) -> TypeGuard[GrepToolResultDict]: + return "mode" in tr and "numFiles" in tr + + +def is_read_tool_result(tr: ToolResultDict) -> TypeGuard[ReadToolResultDict]: + file_obj = tr.get("file") + return "file" in tr and isinstance(file_obj, dict) + + +def is_web_search_tool_result(tr: ToolResultDict) -> TypeGuard[WebSearchToolResultDict]: + return "query" in tr and "results" in tr + + +def is_web_fetch_tool_result(tr: ToolResultDict) -> TypeGuard[WebFetchToolResultDict]: + return "url" in tr and "code" in tr + + +def is_task_message_tool_result(tr: ToolResultDict) -> TypeGuard[TaskMessageToolResultDict]: + # Broad: matches ``task_id`` OR ``message``. Runs before retrieval/completed/async + # arms in tool_dispatch — same short-circuit order as the historical if/elif chain. + return "task_id" in tr or "message" in tr + + +def is_task_retrieval_tool_result(tr: ToolResultDict) -> TypeGuard[TaskRetrievalToolResultDict]: + return "retrieval_status" in tr and "task" in tr + + +def is_task_completed_tool_result(tr: ToolResultDict) -> TypeGuard[TaskCompletedToolResultDict]: + return "agentId" in tr and "totalDurationMs" in tr + + +def is_task_async_tool_result(tr: ToolResultDict) -> TypeGuard[TaskAsyncToolResultDict]: + return "agentId" in tr and "isAsync" in tr + + +def is_todo_write_tool_result(tr: ToolResultDict) -> TypeGuard[TodoWriteToolResultDict]: + return "newTodos" in tr or "oldTodos" in tr + + +def is_user_input_tool_result(tr: ToolResultDict) -> TypeGuard[UserInputToolResultDict]: + return "questions" in tr and "answers" in tr + + +# Tool names on assistant tool_use blocks — pairs with slug on user tool_result rows. +ToolNameLiteral = Literal[ + "Bash", + "Read", + "Write", + "Edit", + "Glob", + "Grep", + "Task", + "TodoWrite", + "WebFetch", + "WebSearch", +] diff --git a/tests/test_real_session_fixtures.py b/tests/test_real_session_fixtures.py index 9696dbc..b6f820e 100644 --- a/tests/test_real_session_fixtures.py +++ b/tests/test_real_session_fixtures.py @@ -151,7 +151,7 @@ def test_task_retrieval_not_misclassified_as_task_message() -> None: def test_task_completed_with_message_key_matches_task_message_first() -> None: """Legacy dispatch: broad task_message runs before task_completed when ``message`` present. - ``_tool_result_pred_task_message`` matches any dict with a ``message`` or ``task_id`` + ``is_task_message_tool_result`` matches any dict with a ``message`` or ``task_id`` key. Future tool shapes that add ``message`` for status text (e.g. web-fetch) would be misclassified as task until dispatch order is refined — this test locks that known false-positive surface. diff --git a/utils/jsonl_parser.py b/utils/jsonl_parser.py index 69e216a..b3a4539 100644 --- a/utils/jsonl_parser.py +++ b/utils/jsonl_parser.py @@ -6,7 +6,9 @@ from datetime import datetime from typing import Any -from models.session import MessageDict, SessionDict +from models.record_data import RecordDataUnion +from models.session import MessageDict, SessionDict, ToolUseDict +from models.tool_results import ToolResultUnion, is_tool_result_dict from utils.jsonl_helpers import ( entry_message as _entry_message, extract_images as _extract_images, @@ -172,11 +174,12 @@ def _process_user( text = _extract_text(content) images = _extract_images(content) - tool_result = entry.get("toolUseResult") + raw_tool_result = entry.get("toolUseResult") + tool_result: ToolResultUnion | None = raw_tool_result if raw_tool_result is not None else None tool_result_parsed = _parse_tool_result(tool_result, entry.get("slug")) # Also extract images from toolUseResult content (e.g., Read tool on image files) - if isinstance(tool_result, dict) and "content" in tool_result: + if is_tool_result_dict(tool_result) and "content" in tool_result: tr_content = tool_result["content"] if isinstance(tr_content, list): tr_images = _extract_images(tr_content) @@ -244,7 +247,7 @@ def _process_assistant( content_parts = _normalize_content(msg.get("content", [])) text_parts = [] thinking_parts = [] - tool_uses = [] + tool_uses: list[ToolUseDict] = [] for part in content_parts: ptype = part.get("type") @@ -254,20 +257,20 @@ def _process_assistant( thinking_parts.append(part.get("thinking", "")) elif ptype == "tool_use": tool_name = part.get("name", "unknown") - tool_input = part.get("input", {}) + raw_input = part.get("input", {}) + safe_input = raw_input if isinstance(raw_input, dict) else {} metadata["total_tool_calls"] += 1 metadata["tool_call_counts"][tool_name] = ( metadata["tool_call_counts"].get(tool_name, 0) + 1 ) - tool_uses.append( - { - "id": part.get("id"), - "name": tool_name, - "input": tool_input, - } - ) - # Track file activity from tool inputs - safe_input = tool_input if isinstance(tool_input, dict) else {} + tool_use: ToolUseDict = { + "name": tool_name, + "input": safe_input, + } + tool_id = part.get("id") + if isinstance(tool_id, str): + tool_use["id"] = tool_id + tool_uses.append(tool_use) _track_file_activity(tool_name, safe_input, metadata) messages.append( @@ -328,8 +331,9 @@ def _process_system( def _process_progress(entry: dict[str, Any], messages: list[MessageDict]) -> None: """Capture progress entries -- streaming bash output, hook results, etc. These are noisy so we mostly just store them for the JSON export.""" - data = entry.get("data", {}) - progress_type = data.get("type", "") + raw_data = entry.get("data", {}) + data: RecordDataUnion = raw_data if isinstance(raw_data, dict) else {} + progress_type = str(data.get("type", "")) messages.append( { diff --git a/utils/md_exporter.py b/utils/md_exporter.py index 0ba3605..ef9bdf3 100644 --- a/utils/md_exporter.py +++ b/utils/md_exporter.py @@ -4,7 +4,7 @@ from datetime import datetime from typing import Any -from models.session import MessageDict, SessionDict +from models.session import MessageDict, SessionDict, ToolUseDict from models.stats import SessionStatsDict from utils.session_stats import format_duration @@ -266,9 +266,10 @@ def _render_assistant(msg: MessageDict) -> str: return "\n".join(lines) -def _render_tool_use(tool: dict[str, Any]) -> str: +def _render_tool_use(tool: ToolUseDict) -> str: name = tool.get("name", "unknown") - inp = tool.get("input", {}) + raw_inp = tool.get("input", {}) + inp: dict[str, object] = raw_inp if isinstance(raw_inp, dict) else {} lines = [] lines.append(f"\n> **Tool: {name}**") @@ -307,15 +308,23 @@ def _render_tool_use(tool: dict[str, Any]) -> str: if inp.get("prompt"): lines.append(f">\n> **Prompt:**\n> ```\n> {inp['prompt']}\n> ```") elif name == "TodoWrite": - todos = inp.get("todos", []) - for t in todos: - status = t.get("status", "") - icon = {"completed": "[x]", "in_progress": "[~]", "pending": "[ ]"}.get(status, "[ ]") - lines.append(f"> - {icon} {t.get('content', '')}") + raw_todos = inp.get("todos", []) + if isinstance(raw_todos, list): + for t in raw_todos: + if not isinstance(t, dict): + continue + status = t.get("status", "") + icon = {"completed": "[x]", "in_progress": "[~]", "pending": "[ ]"}.get( + str(status), "[ ]" + ) + lines.append(f"> - {icon} {t.get('content', '')}") elif name == "AskUserQuestion": - questions = inp.get("questions", []) - for q in questions: - lines.append(f">\n> Q: {q.get('question', '')}") + raw_questions = inp.get("questions", []) + if isinstance(raw_questions, list): + for q in raw_questions: + if not isinstance(q, dict): + continue + lines.append(f">\n> Q: {q.get('question', '')}") else: lines.append(f">\n> Input: `{str(inp)}`") diff --git a/utils/tool_dispatch.py b/utils/tool_dispatch.py index de1bcdc..03081c2 100644 --- a/utils/tool_dispatch.py +++ b/utils/tool_dispatch.py @@ -9,16 +9,35 @@ To add a shape: append ``(pred, build)`` at the end, or insert only after verifying predicates above would not steal intended matches. -""" - -from typing import Any +Predicates live in ``models.tool_results`` (single source of truth for narrowing). +""" -def _tool_result_pred_bash(tr: dict[str, Any]) -> bool: - return "stdout" in tr or "stderr" in tr +from typing import cast + +from models.tool_results import ( + ToolResultDict, + ToolResultUnion, + is_bash_tool_result, + is_file_edit_tool_result, + is_file_write_tool_result, + is_glob_tool_result, + is_grep_tool_result, + is_plan_tool_result, + is_read_tool_result, + is_task_async_tool_result, + is_task_completed_tool_result, + is_task_message_tool_result, + is_task_retrieval_tool_result, + is_todo_write_tool_result, + is_tool_result_dict, + is_user_input_tool_result, + is_web_fetch_tool_result, + is_web_search_tool_result, +) -def _tool_result_build_bash(tr: dict[str, Any], base: dict[str, Any]) -> dict[str, Any]: +def _tool_result_build_bash(tr: ToolResultDict, base: dict[str, object]) -> dict[str, object]: result = dict(base) result["result_type"] = "bash" result["stdout"] = tr.get("stdout", "") @@ -30,11 +49,7 @@ def _tool_result_build_bash(tr: dict[str, Any], base: dict[str, Any]) -> dict[st return result -def _tool_result_pred_file_edit(tr: dict[str, Any]) -> bool: - return "structuredPatch" in tr or ("filePath" in tr and "newString" in tr) - - -def _tool_result_build_file_edit(tr: dict[str, Any], base: dict[str, Any]) -> dict[str, Any]: +def _tool_result_build_file_edit(tr: ToolResultDict, base: dict[str, object]) -> dict[str, object]: # Summary fields only; full blob (e.g. structuredPatch) stays on message tool_result. result = dict(base) result["result_type"] = "file_edit" @@ -43,48 +58,34 @@ def _tool_result_build_file_edit(tr: dict[str, Any], base: dict[str, Any]) -> di return result -def _tool_result_pred_plan(tr: dict[str, Any]) -> bool: - return "plan" in tr and "filePath" in tr - - -def _tool_result_build_plan(tr: dict[str, Any], base: dict[str, Any]) -> dict[str, Any]: +def _tool_result_build_plan(tr: ToolResultDict, base: dict[str, object]) -> dict[str, object]: result = dict(base) result["result_type"] = "plan" result["file_path"] = tr.get("filePath", "") return result -def _tool_result_pred_file_write(tr: dict[str, Any]) -> bool: - return "filePath" in tr and "content" in tr - - -def _tool_result_build_file_write(tr: dict[str, Any], base: dict[str, Any]) -> dict[str, Any]: +def _tool_result_build_file_write(tr: ToolResultDict, base: dict[str, object]) -> dict[str, object]: result = dict(base) result["result_type"] = "file_write" result["file_path"] = tr.get("filePath", "") return result -def _tool_result_pred_glob(tr: dict[str, Any]) -> bool: - return "filenames" in tr and isinstance(tr.get("filenames"), list) - - -def _tool_result_build_glob(tr: dict[str, Any], base: dict[str, Any]) -> dict[str, Any]: +def _tool_result_build_glob(tr: ToolResultDict, base: dict[str, object]) -> dict[str, object]: result = dict(base) - filenames = tr["filenames"] + raw_filenames = tr.get("filenames") + filenames = raw_filenames if isinstance(raw_filenames, list) else [] result["result_type"] = "glob" - result["num_files"] = tr.get("numFiles", len(filenames)) + num_files = tr.get("numFiles") + result["num_files"] = num_files if isinstance(num_files, int) else len(filenames) result["truncated"] = tr.get("truncated", False) result["duration_ms"] = tr.get("durationMs") result["filenames"] = filenames return result -def _tool_result_pred_grep(tr: dict[str, Any]) -> bool: - return "mode" in tr and "numFiles" in tr - - -def _tool_result_build_grep(tr: dict[str, Any], base: dict[str, Any]) -> dict[str, Any]: +def _tool_result_build_grep(tr: ToolResultDict, base: dict[str, object]) -> dict[str, object]: result = dict(base) result["result_type"] = "grep" result["mode"] = tr.get("mode") @@ -97,13 +98,10 @@ def _tool_result_build_grep(tr: dict[str, Any], base: dict[str, Any]) -> dict[st return result -def _tool_result_pred_file_read(tr: dict[str, Any]) -> bool: - return "file" in tr and isinstance(tr["file"], dict) - - -def _tool_result_build_file_read(tr: dict[str, Any], base: dict[str, Any]) -> dict[str, Any]: +def _tool_result_build_file_read(tr: ToolResultDict, base: dict[str, object]) -> dict[str, object]: result = dict(base) - file_obj = tr["file"] + raw_file = tr.get("file") + file_obj = raw_file if isinstance(raw_file, dict) else {} result["result_type"] = "file_read" result["file_path"] = file_obj.get("filePath", "") result["num_lines"] = file_obj.get("numLines") @@ -113,11 +111,7 @@ def _tool_result_build_file_read(tr: dict[str, Any], base: dict[str, Any]) -> di return result -def _tool_result_pred_web_search(tr: dict[str, Any]) -> bool: - return "query" in tr and "results" in tr - - -def _tool_result_build_web_search(tr: dict[str, Any], base: dict[str, Any]) -> dict[str, Any]: +def _tool_result_build_web_search(tr: ToolResultDict, base: dict[str, object]) -> dict[str, object]: result = dict(base) result["result_type"] = "web_search" result["query"] = tr.get("query", "") @@ -132,11 +126,7 @@ def _tool_result_build_web_search(tr: dict[str, Any], base: dict[str, Any]) -> d return result -def _tool_result_pred_web_fetch(tr: dict[str, Any]) -> bool: - return "url" in tr and "code" in tr - - -def _tool_result_build_web_fetch(tr: dict[str, Any], base: dict[str, Any]) -> dict[str, Any]: +def _tool_result_build_web_fetch(tr: ToolResultDict, base: dict[str, object]) -> dict[str, object]: result = dict(base) result["result_type"] = "web_fetch" result["url"] = tr.get("url", "") @@ -145,15 +135,9 @@ def _tool_result_build_web_fetch(tr: dict[str, Any], base: dict[str, Any]) -> di return result -def _tool_result_pred_task_message(tr: dict[str, Any]) -> bool: - # Broad: matches ``task_id`` OR ``message``. Runs before retrieval/completed/async - # arms below — same short-circuit order as the original if/elif chain. Payloads - # that also carry e.g. ``agentId`` still classify here if they have ``message``. - # Refining order needs golden fixtures; track as follow-up if real collisions appear. - return "task_id" in tr or "message" in tr - - -def _tool_result_build_task_message(tr: dict[str, Any], base: dict[str, Any]) -> dict[str, Any]: +def _tool_result_build_task_message( + tr: ToolResultDict, base: dict[str, object] +) -> dict[str, object]: result = dict(base) result["result_type"] = "task" result["task_id"] = tr.get("task_id") @@ -161,11 +145,9 @@ def _tool_result_build_task_message(tr: dict[str, Any], base: dict[str, Any]) -> return result -def _tool_result_pred_task_retrieval(tr: dict[str, Any]) -> bool: - return "retrieval_status" in tr and "task" in tr - - -def _tool_result_build_task_retrieval(tr: dict[str, Any], base: dict[str, Any]) -> dict[str, Any]: +def _tool_result_build_task_retrieval( + tr: ToolResultDict, base: dict[str, object] +) -> dict[str, object]: result = dict(base) task_obj = tr["task"] if isinstance(tr["task"], dict) else {} result["result_type"] = "task" @@ -174,11 +156,9 @@ def _tool_result_build_task_retrieval(tr: dict[str, Any], base: dict[str, Any]) return result -def _tool_result_pred_task_completed(tr: dict[str, Any]) -> bool: - return "agentId" in tr and "totalDurationMs" in tr - - -def _tool_result_build_task_completed(tr: dict[str, Any], base: dict[str, Any]) -> dict[str, Any]: +def _tool_result_build_task_completed( + tr: ToolResultDict, base: dict[str, object] +) -> dict[str, object]: result = dict(base) result["result_type"] = "task" result["agent_id"] = tr.get("agentId") @@ -189,11 +169,7 @@ def _tool_result_build_task_completed(tr: dict[str, Any], base: dict[str, Any]) return result -def _tool_result_pred_task_async(tr: dict[str, Any]) -> bool: - return "agentId" in tr and "isAsync" in tr - - -def _tool_result_build_task_async(tr: dict[str, Any], base: dict[str, Any]) -> dict[str, Any]: +def _tool_result_build_task_async(tr: ToolResultDict, base: dict[str, object]) -> dict[str, object]: result = dict(base) result["result_type"] = "task" result["agent_id"] = tr.get("agentId") @@ -202,11 +178,7 @@ def _tool_result_build_task_async(tr: dict[str, Any], base: dict[str, Any]) -> d return result -def _tool_result_pred_todo_write(tr: dict[str, Any]) -> bool: - return "newTodos" in tr or "oldTodos" in tr - - -def _tool_result_build_todo_write(tr: dict[str, Any], base: dict[str, Any]) -> dict[str, Any]: +def _tool_result_build_todo_write(tr: ToolResultDict, base: dict[str, object]) -> dict[str, object]: result = dict(base) new_todos = tr.get("newTodos", []) result["result_type"] = "todo_write" @@ -215,11 +187,7 @@ def _tool_result_build_todo_write(tr: dict[str, Any], base: dict[str, Any]) -> d return result -def _tool_result_pred_user_input(tr: dict[str, Any]) -> bool: - return "questions" in tr and "answers" in tr - - -def _tool_result_build_user_input(tr: dict[str, Any], base: dict[str, Any]) -> dict[str, Any]: +def _tool_result_build_user_input(tr: ToolResultDict, base: dict[str, object]) -> dict[str, object]: result = dict(base) result["result_type"] = "user_input" result["questions"] = tr.get("questions", []) @@ -230,25 +198,27 @@ def _tool_result_build_user_input(tr: dict[str, Any], base: dict[str, Any]) -> d # Registry order is load-bearing (see module docstring). # ``plan`` before ``file_write``: plan blobs may carry ``filePath`` + ``content``. _TOOL_RESULT_DISPATCH = ( - (_tool_result_pred_bash, _tool_result_build_bash), - (_tool_result_pred_file_edit, _tool_result_build_file_edit), - (_tool_result_pred_plan, _tool_result_build_plan), - (_tool_result_pred_file_write, _tool_result_build_file_write), - (_tool_result_pred_glob, _tool_result_build_glob), - (_tool_result_pred_grep, _tool_result_build_grep), - (_tool_result_pred_file_read, _tool_result_build_file_read), - (_tool_result_pred_web_search, _tool_result_build_web_search), - (_tool_result_pred_web_fetch, _tool_result_build_web_fetch), - (_tool_result_pred_task_message, _tool_result_build_task_message), - (_tool_result_pred_task_retrieval, _tool_result_build_task_retrieval), - (_tool_result_pred_task_completed, _tool_result_build_task_completed), - (_tool_result_pred_task_async, _tool_result_build_task_async), - (_tool_result_pred_todo_write, _tool_result_build_todo_write), - (_tool_result_pred_user_input, _tool_result_build_user_input), + (is_bash_tool_result, _tool_result_build_bash), + (is_file_edit_tool_result, _tool_result_build_file_edit), + (is_plan_tool_result, _tool_result_build_plan), + (is_file_write_tool_result, _tool_result_build_file_write), + (is_glob_tool_result, _tool_result_build_glob), + (is_grep_tool_result, _tool_result_build_grep), + (is_read_tool_result, _tool_result_build_file_read), + (is_web_search_tool_result, _tool_result_build_web_search), + (is_web_fetch_tool_result, _tool_result_build_web_fetch), + (is_task_message_tool_result, _tool_result_build_task_message), + (is_task_retrieval_tool_result, _tool_result_build_task_retrieval), + (is_task_completed_tool_result, _tool_result_build_task_completed), + (is_task_async_tool_result, _tool_result_build_task_async), + (is_todo_write_tool_result, _tool_result_build_todo_write), + (is_user_input_tool_result, _tool_result_build_user_input), ) -def _parse_tool_result(tool_result: Any, slug: str | None = None) -> dict[str, Any] | None: +def _parse_tool_result( + tool_result: ToolResultUnion | None, slug: str | None = None +) -> dict[str, object] | None: """Figure out what kind of tool result this is (bash, file edit, glob, etc.) by looking at which keys are present, since the JSONL doesn't always tag them. @@ -259,13 +229,14 @@ def _parse_tool_result(tool_result: Any, slug: str | None = None) -> dict[str, A Append a new pair at the end to register a shape, or insert mid-table only after checking interactions with broader predicates above (see notes on the tuple).""" - if not isinstance(tool_result, dict): + if not is_tool_result_dict(tool_result): return None - base = {"slug": slug} + base: dict[str, object] = {"slug": slug} for pred, build in _TOOL_RESULT_DISPATCH: if pred(tool_result): - return build(tool_result, base) + # Builders take ToolResultDict; cast after pred (heterogeneous tuple, no union narrow). + return build(cast(ToolResultDict, tool_result), base) result = dict(base) result["result_type"] = "unknown"