From 43114941577c405fab4b1b9c28cfb27f3ac2ab67 Mon Sep 17 00:00:00 2001 From: Mateusz Poliwczak Date: Thu, 16 Apr 2026 10:56:14 +0200 Subject: [PATCH] Make id required (non-nullable) for tool/subagent/output calls Also while here, move the id field to be first. --- splunklib/ai/engines/langchain.py | 43 ++++++++++++++++++++++++++++--- splunklib/ai/messages.py | 8 +++--- 2 files changed, 43 insertions(+), 8 deletions(-) diff --git a/splunklib/ai/engines/langchain.py b/splunklib/ai/engines/langchain.py index bd103724..354c6ebe 100644 --- a/splunklib/ai/engines/langchain.py +++ b/splunklib/ai/engines/langchain.py @@ -481,10 +481,45 @@ def unpack_tool_call(self, call: LC_ToolCall) -> LC_ToolCall: return call + class _CheckCallIDMiddleware(LC_AgentMiddleware): + def _check_has_call_id(self, msg: LC_AIMessage) -> None: + for call in msg.tool_calls: + if not call["id"]: + # If we ever hit this with real model, just generate a random call_id here. + raise Exception("LLM returned a Tool Call without a call_id") + + @override + async def awrap_model_call( + self, + request: LC_ModelRequest, + handler: Callable[[LC_ModelRequest], Awaitable[LC_ModelCallResult]], + ) -> LC_ModelCallResult: + try: + resp = await handler(request) + ai_message = resp + if isinstance(ai_message, LC_ExtendedModelResponse): + ai_message = ai_message.model_response + if isinstance(ai_message, LC_ModelResponse): + ai_message = next( + ( + m + for m in ai_message.result + if isinstance(m, LC_AIMessage) + ), + None, + ) + assert ai_message, "AIMessage not found found in response" + self._check_has_call_id(ai_message) + return resp + except LC_StructuredOutputError as e: + self._check_has_call_id(e.ai_message) + raise + lc_middleware.append(_ToolFailureArtifact()) if len(conversational_subagents) > 0: lc_middleware.append(_ThreadIDMiddleware()) lc_middleware.append(_SubagentArgumentPacker()) + lc_middleware.append(_CheckCallIDMiddleware()) class _DEBUGMiddleware(LC_AgentMiddleware): @override @@ -1254,7 +1289,7 @@ def _convert_model_result_from_lc(model_response: LC_ModelCallResult) -> ModelRe StructuredOutputCall( name=tc["name"].removeprefix(TOOL_STRATEGY_TOOL_PREFIX), args=tc["args"], - id=tc["id"], + id=tc["id"] or "", ) for tc in ai_message.tool_calls if tc["name"].startswith(TOOL_STRATEGY_TOOL_PREFIX) @@ -1529,7 +1564,7 @@ def _map_tool_call_from_langchain(tool_call: LC_ToolCall) -> ToolCall | Subagent name=_denormalize_agent_name(name), args=SubagentLCArgs(**tool_call["args"]).args, thread_id=SubagentLCArgs(**tool_call["args"]).thread_id, - id=tool_call["id"], + id=tool_call["id"] or "", ) tool_type: ToolType = ( @@ -1538,7 +1573,7 @@ def _map_tool_call_from_langchain(tool_call: LC_ToolCall) -> ToolCall | Subagent return ToolCall( name=_denormalize_tool_name(name), args=tool_call["args"], - id=tool_call["id"], + id=tool_call["id"] or "", type=tool_type, ) @@ -1567,9 +1602,9 @@ def _map_message_from_langchain(message: LC_BaseMessage) -> BaseMessage: ], structured_output_calls=[ StructuredOutputCall( + tc["id"] or "", tc["name"].removeprefix(TOOL_STRATEGY_TOOL_PREFIX), tc["args"], - tc["id"], ) for tc in message.tool_calls if tc["name"].startswith(TOOL_STRATEGY_TOOL_PREFIX) diff --git a/splunklib/ai/messages.py b/splunklib/ai/messages.py index 04db32b6..0771250e 100644 --- a/splunklib/ai/messages.py +++ b/splunklib/ai/messages.py @@ -23,25 +23,25 @@ @dataclass(frozen=True) class ToolCall: + id: str name: str - args: dict[str, Any] - id: str | None # TODO: can be None? type: ToolType + args: dict[str, Any] @dataclass(frozen=True) class SubagentCall: + id: str name: str args: str | dict[str, Any] - id: str | None # TODO: can be None? thread_id: str | None @dataclass(frozen=True) class StructuredOutputCall: + id: str name: str args: dict[str, Any] - id: str | None # TODO: can be None? @dataclass(frozen=True)