diff --git a/src/a2a/client/transports/__init__.py b/src/a2a/client/transports/__init__.py index af7c60f62..001ca67b5 100644 --- a/src/a2a/client/transports/__init__.py +++ b/src/a2a/client/transports/__init__.py @@ -3,6 +3,7 @@ from a2a.client.transports.base import ClientTransport from a2a.client.transports.jsonrpc import JsonRpcTransport from a2a.client.transports.rest import RestTransport +from a2a.client.transports.tenant_decorator import TenantTransportDecorator try: @@ -16,4 +17,5 @@ 'GrpcTransport', 'JsonRpcTransport', 'RestTransport', + 'TenantTransportDecorator', ] diff --git a/src/a2a/client/transports/tenant_decorator.py b/src/a2a/client/transports/tenant_decorator.py index d1059d757..80f596d2e 100644 --- a/src/a2a/client/transports/tenant_decorator.py +++ b/src/a2a/client/transports/tenant_decorator.py @@ -43,7 +43,7 @@ async def send_message( *, context: ClientCallContext | None = None, ) -> SendMessageResponse: - """Sends a streaming message request to the agent and yields responses as they arrive.""" + """Sends a non-streaming message request to the agent.""" request.tenant = self._resolve_tenant(request.tenant) return await self._base.send_message(request, context=context) diff --git a/tests/client/transports/test_tenant_decorator.py b/tests/client/transports/test_tenant_decorator.py index b08406bad..1e560d2ac 100644 --- a/tests/client/transports/test_tenant_decorator.py +++ b/tests/client/transports/test_tenant_decorator.py @@ -127,3 +127,22 @@ async def mock_stream(*args, **kwargs): async for _ in decorator.send_message_streaming(request_msg): pass assert request_msg.tenant == tenant_id + + @pytest.mark.asyncio + async def test_close_delegates_to_base( + self, mock_transport: AsyncMock + ) -> None: + """Test that close() is delegated to the underlying transport.""" + decorator = TenantTransportDecorator(mock_transport, 'test-tenant') + await decorator.close() + mock_transport.close.assert_awaited_once() + + @pytest.mark.asyncio + async def test_async_context_manager( + self, mock_transport: AsyncMock + ) -> None: + """Test that the decorator works as an async context manager.""" + decorator = TenantTransportDecorator(mock_transport, 'test-tenant') + async with decorator as transport: + assert transport is decorator + mock_transport.close.assert_awaited_once()