From b138aff6600a970108dc69a3d9cb563d37d56c3a Mon Sep 17 00:00:00 2001 From: mtsmyassin Date: Tue, 7 Apr 2026 20:53:52 -0400 Subject: [PATCH] feat: add ArtifactStreamer for streaming artifact updates with stable ID Adds a stateful streaming helper to a2a.utils that maintains a stable artifact_id across chunks, enabling correct append=True semantics for TaskArtifactUpdateEvent. Closes #833 --- src/a2a/utils/__init__.py | 2 + src/a2a/utils/artifact.py | 109 ++++++++++++++++++++++++++++- tests/utils/test_artifact.py | 128 +++++++++++++++++++++++++++++++++++ 3 files changed, 238 insertions(+), 1 deletion(-) diff --git a/src/a2a/utils/__init__.py b/src/a2a/utils/__init__.py index e5b5663dd..5478377ca 100644 --- a/src/a2a/utils/__init__.py +++ b/src/a2a/utils/__init__.py @@ -1,6 +1,7 @@ """Utility functions for the A2A Python SDK.""" from a2a.utils.artifact import ( + ArtifactStreamer, get_artifact_text, new_artifact, new_data_artifact, @@ -39,6 +40,7 @@ 'DEFAULT_RPC_URL', 'EXTENDED_AGENT_CARD_PATH', 'PREV_AGENT_CARD_WELL_KNOWN_PATH', + 'ArtifactStreamer', 'append_artifact_to_task', 'are_modalities_compatible', 'build_text_artifact', diff --git a/src/a2a/utils/artifact.py b/src/a2a/utils/artifact.py index 5053ca421..4c1b81278 100644 --- a/src/a2a/utils/artifact.py +++ b/src/a2a/utils/artifact.py @@ -4,7 +4,13 @@ from typing import Any -from a2a.types import Artifact, DataPart, Part, TextPart +from a2a.types import ( + Artifact, + DataPart, + Part, + TaskArtifactUpdateEvent, + TextPart, +) from a2a.utils.parts import get_text_parts @@ -86,3 +92,104 @@ def get_artifact_text(artifact: Artifact, delimiter: str = '\n') -> str: A single string containing all text content, or an empty string if no text parts are found. """ return delimiter.join(get_text_parts(artifact.parts)) + + +class ArtifactStreamer: + """A stateful helper for streaming artifact updates with a stable artifact ID. + + Solves the problem where calling ``new_text_artifact`` in a loop generates + a fresh ``artifact_id`` each time, making ``append=True`` unusable. + + Example:: + + streamer = ArtifactStreamer(context_id, task_id, name='response') + + async for chunk in llm.stream(prompt): + await event_queue.enqueue_event(streamer.append(chunk)) + + await event_queue.enqueue_event(streamer.finalize()) + + Args: + context_id: The context ID associated with the task. + task_id: The ID of the task this artifact belongs to. + name: A human-readable name for the artifact. + description: An optional description of the artifact. + """ + + def __init__( + self, + context_id: str, + task_id: str, + name: str, + description: str | None = None, + ) -> None: + self._context_id = context_id + self._task_id = task_id + self._name = name + self._description = description + self._artifact_id = str(uuid.uuid4()) + self._finalized = False + + @property + def artifact_id(self) -> str: + """The stable artifact ID used across all chunks.""" + return self._artifact_id + + def append(self, text: str) -> TaskArtifactUpdateEvent: + """Create an append event for the next chunk of text. + + Args: + text: The text content to append. + + Returns: + A ``TaskArtifactUpdateEvent`` with ``append=True`` and + ``last_chunk=False``. + + Raises: + RuntimeError: If ``finalize()`` has already been called. + """ + if self._finalized: + raise RuntimeError( + 'Cannot append after finalize() has been called.' + ) + return TaskArtifactUpdateEvent( + context_id=self._context_id, + task_id=self._task_id, + append=True, + last_chunk=False, + artifact=Artifact( + artifact_id=self._artifact_id, + name=self._name, + description=self._description, + parts=[Part(root=TextPart(text=text))], + ), + ) + + def finalize(self, text: str = '') -> TaskArtifactUpdateEvent: + """Create the final chunk event, closing the stream. + + Args: + text: Optional final text content. Defaults to empty string. + + Returns: + A ``TaskArtifactUpdateEvent`` with ``append=True`` and + ``last_chunk=True``. + + Raises: + RuntimeError: If ``finalize()`` has already been called. + """ + if self._finalized: + raise RuntimeError('finalize() has already been called.') + self._finalized = True + return TaskArtifactUpdateEvent( + context_id=self._context_id, + task_id=self._task_id, + append=True, + last_chunk=True, + artifact=Artifact( + artifact_id=self._artifact_id, + name=self._name, + description=self._description, + parts=[Part(root=TextPart(text=text))], + ), + ) diff --git a/tests/utils/test_artifact.py b/tests/utils/test_artifact.py index 489c047c4..11499553e 100644 --- a/tests/utils/test_artifact.py +++ b/tests/utils/test_artifact.py @@ -7,9 +7,11 @@ Artifact, DataPart, Part, + TaskArtifactUpdateEvent, TextPart, ) from a2a.utils.artifact import ( + ArtifactStreamer, get_artifact_text, new_artifact, new_data_artifact, @@ -155,5 +157,131 @@ def test_get_artifact_text_empty_parts(self): assert result == '' +class TestArtifactStreamer(unittest.TestCase): + def setUp(self): + self.context_id = 'ctx-123' + self.task_id = 'task-456' + self.name = 'response' + + def test_stable_artifact_id_across_appends(self): + streamer = ArtifactStreamer( + self.context_id, self.task_id, name=self.name + ) + event1 = streamer.append('Hello ') + event2 = streamer.append('world') + self.assertEqual( + event1.artifact.artifact_id, event2.artifact.artifact_id + ) + + def test_append_returns_correct_event_type(self): + streamer = ArtifactStreamer( + self.context_id, self.task_id, name=self.name + ) + event = streamer.append('chunk') + self.assertIsInstance(event, TaskArtifactUpdateEvent) + + def test_append_sets_append_true_last_chunk_false(self): + streamer = ArtifactStreamer( + self.context_id, self.task_id, name=self.name + ) + event = streamer.append('chunk') + self.assertTrue(event.append) + self.assertFalse(event.last_chunk) + + def test_append_sets_context_and_task_ids(self): + streamer = ArtifactStreamer( + self.context_id, self.task_id, name=self.name + ) + event = streamer.append('chunk') + self.assertEqual(event.context_id, self.context_id) + self.assertEqual(event.task_id, self.task_id) + + def test_append_sets_text_content(self): + streamer = ArtifactStreamer( + self.context_id, self.task_id, name=self.name + ) + event = streamer.append('Hello world') + self.assertEqual(len(event.artifact.parts), 1) + self.assertEqual(event.artifact.parts[0].root.text, 'Hello world') + + def test_append_sets_artifact_name_and_description(self): + streamer = ArtifactStreamer( + self.context_id, + self.task_id, + name='my-artifact', + description='A streamed response', + ) + event = streamer.append('chunk') + self.assertEqual(event.artifact.name, 'my-artifact') + self.assertEqual(event.artifact.description, 'A streamed response') + + def test_finalize_sets_last_chunk_true(self): + streamer = ArtifactStreamer( + self.context_id, self.task_id, name=self.name + ) + event = streamer.finalize('done') + self.assertTrue(event.append) + self.assertTrue(event.last_chunk) + + def test_finalize_with_empty_text(self): + streamer = ArtifactStreamer( + self.context_id, self.task_id, name=self.name + ) + event = streamer.finalize() + self.assertEqual(event.artifact.parts[0].root.text, '') + self.assertTrue(event.last_chunk) + + def test_finalize_uses_same_artifact_id(self): + streamer = ArtifactStreamer( + self.context_id, self.task_id, name=self.name + ) + append_event = streamer.append('chunk') + finalize_event = streamer.finalize() + self.assertEqual( + append_event.artifact.artifact_id, + finalize_event.artifact.artifact_id, + ) + + def test_append_after_finalize_raises(self): + streamer = ArtifactStreamer( + self.context_id, self.task_id, name=self.name + ) + streamer.finalize() + with self.assertRaises(RuntimeError): + streamer.append('too late') + + def test_double_finalize_raises(self): + streamer = ArtifactStreamer( + self.context_id, self.task_id, name=self.name + ) + streamer.finalize() + with self.assertRaises(RuntimeError): + streamer.finalize() + + def test_artifact_id_property(self): + streamer = ArtifactStreamer( + self.context_id, self.task_id, name=self.name + ) + artifact_id = streamer.artifact_id + self.assertIsInstance(artifact_id, str) + self.assertTrue(len(artifact_id) > 0) + + @patch('uuid.uuid4') + def test_artifact_id_from_uuid(self, mock_uuid4): + mock_uuid = uuid.UUID('12345678-1234-5678-1234-567812345678') + mock_uuid4.return_value = mock_uuid + streamer = ArtifactStreamer( + self.context_id, self.task_id, name=self.name + ) + self.assertEqual(streamer.artifact_id, str(mock_uuid)) + + def test_description_defaults_to_none(self): + streamer = ArtifactStreamer( + self.context_id, self.task_id, name=self.name + ) + event = streamer.append('chunk') + self.assertIsNone(event.artifact.description) + + if __name__ == '__main__': unittest.main()