diff --git a/src/mcp/client/session.py b/src/mcp/client/session.py index 0cea454a7..d5c01d14f 100644 --- a/src/mcp/client/session.py +++ b/src/mcp/client/session.py @@ -324,8 +324,11 @@ async def call_tool( async def _validate_tool_result(self, name: str, result: types.CallToolResult) -> None: """Validate the structured content of a tool result against its output schema.""" if name not in self._tool_output_schemas: - # refresh output schema cache - await self.list_tools() + # refresh output schema cache — follow next_cursor page by page so tools + # beyond the first page are found; stop as soon as `name` is in the cache. + list_result = await self.list_tools() + while list_result.next_cursor is not None and name not in self._tool_output_schemas: + list_result = await self.list_tools(params=types.PaginatedRequestParams(cursor=list_result.next_cursor)) output_schema = None if name in self._tool_output_schemas: @@ -476,5 +479,9 @@ async def _received_notification(self, notification: types.ServerNotification) - # Clients MAY use this to retry requests or update UI # The notification contains the elicitationId of the completed elicitation pass + case types.ToolListChangedNotification(): + # The server's tool list has changed; invalidate the cached output schemas + # so the next call_tool fetches fresh schemas before validating. 
+ self._tool_output_schemas.clear() case _: pass diff --git a/tests/client/test_output_schema_validation.py b/tests/client/test_output_schema_validation.py index d78197b5c..948aebb8f 100644 --- a/tests/client/test_output_schema_validation.py +++ b/tests/client/test_output_schema_validation.py @@ -12,6 +12,7 @@ PaginatedRequestParams, TextContent, Tool, + ToolListChangedNotification, ) @@ -163,3 +164,114 @@ async def on_call_tool(ctx: ServerRequestContext, params: CallToolRequestParams) assert result.is_error is False assert "Tool mystery_tool not listed" in caplog.text + + +@pytest.mark.anyio +async def test_tool_list_changed_notification_clears_schema_cache(): + """ToolListChangedNotification must invalidate the cached output schemas. + + Flow: + Call 1 — schema v1 (integer). Client caches v1. Result validates OK. + Call 2 — server switches to v2 (string), sends ToolListChangedNotification + *before* returning the result, then returns a string value. + + Without the fix the client keeps the stale v1 schema and validates the + string result against it → RuntimeError (a spurious validation failure). + With the fix the notification clears the cache, list_tools() re-fetches v2, + and the string result validates correctly → no error. 
+ """ + schema_v1 = { + "type": "object", + "properties": {"result": {"type": "integer"}}, + "required": ["result"], + } + schema_v2 = { + "type": "object", + "properties": {"result": {"type": "string"}}, + "required": ["result"], + } + + use_v2: list[bool] = [False] # mutable container so nested functions can write to it + + async def on_list_tools(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListToolsResult: + schema = schema_v2 if use_v2[0] else schema_v1 + return ListToolsResult( + tools=[Tool(name="dynamic_tool", description="d", input_schema={"type": "object"}, output_schema=schema)] + ) + + call_count: list[int] = [0] + + async def on_call_tool(ctx: ServerRequestContext, params: CallToolRequestParams) -> CallToolResult: + call_count[0] += 1 + if call_count[0] == 1: + # First call: v1 schema, no notification, integer result. + return CallToolResult( + content=[TextContent(type="text", text="r")], + structured_content={"result": 42}, # valid for v1 (integer) + ) + # Second call: switch schema to v2, notify BEFORE returning the result, + # then return a string value that is valid only under v2. + use_v2[0] = True + await ctx.session.send_notification(ToolListChangedNotification()) + return CallToolResult( + content=[TextContent(type="text", text="r")], + structured_content={"result": "hello"}, # valid for v2 (string), invalid for v1 + ) + + server = Server("test-server", on_list_tools=on_list_tools, on_call_tool=on_call_tool) + + async with Client(server) as client: + # Call 1: populates the cache with v1 schema and succeeds. + result1 = await client.call_tool("dynamic_tool", {}) + assert result1.structured_content == {"result": 42} + + # Call 2: notification arrives first → (with fix) cache cleared → list_tools() + # fetches v2 → string "hello" is valid → no error. + # Without the fix: stale v1 still in cache → "hello" fails integer check → RuntimeError. 
+ result2 = await client.call_tool("dynamic_tool", {}) + assert result2.structured_content == {"result": "hello"} + + +@pytest.mark.anyio +async def test_validate_tool_result_paginates_all_pages(): + """_validate_tool_result must paginate through all tool pages when refreshing. + + Without the fix, only the first page of list_tools() is fetched. A tool that + sits on a later page is never found in the cache, so its output schema is + silently skipped — invalid structured_content is accepted without error. + """ + output_schema = { + "type": "object", + "properties": {"result": {"type": "integer"}}, + "required": ["result"], + } + + page1_tools = [Tool(name=f"tool_{i}", description="d", input_schema={"type": "object"}) for i in range(3)] + page2_tools = [ + Tool( + name="paginated_tool", + description="d", + input_schema={"type": "object"}, + output_schema=output_schema, + ) + ] + + async def on_list_tools(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListToolsResult: + if params is not None and params.cursor == "page2": + return ListToolsResult(tools=page2_tools, next_cursor=None) + return ListToolsResult(tools=page1_tools, next_cursor="page2") + + async def on_call_tool(ctx: ServerRequestContext, params: CallToolRequestParams) -> CallToolResult: + # Returns a string for "result" — invalid per the integer schema. + return CallToolResult( + content=[TextContent(type="text", text="r")], + structured_content={"result": "not_an_integer"}, + ) + + server = Server("test-server", on_list_tools=on_list_tools, on_call_tool=on_call_tool) + + async with Client(server) as client: + # With the fix: both pages are fetched, schema is found, invalid content raises. + # Without the fix: only page 1 is fetched, tool not found, validation silently skipped. + with pytest.raises(RuntimeError, match="Invalid structured content returned by tool paginated_tool"): + await client.call_tool("paginated_tool", {})