From 97efac24cf52f4a0318f89c39dafc8cf353d1b17 Mon Sep 17 00:00:00 2001 From: Andy Staples Date: Tue, 14 Apr 2026 11:17:11 -0600 Subject: [PATCH 1/3] Add history retrieval and terminal instance ID paging APIs --- CHANGELOG.md | 6 + durabletask/client.py | 87 +++- durabletask/history.py | 535 ++++++++++++++++++++ durabletask/internal/history_helpers.py | 68 +++ durabletask/testing/in_memory_backend.py | 63 ++- tests/durabletask/test_batch_actions.py | 39 ++ tests/durabletask/test_client.py | 144 +++++- tests/durabletask/test_orchestration_e2e.py | 27 + 8 files changed, 964 insertions(+), 5 deletions(-) create mode 100644 durabletask/history.py create mode 100644 durabletask/internal/history_helpers.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 2e51ec3c..b405b513 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## Unreleased +ADDED + +- Added `get_orchestration_history()` and `list_instance_ids()` to the sync and async gRPC clients. +- Added in-memory backend support for `StreamInstanceHistory` and `ListInstanceIds` to enable history retrieval and completion-window pagination in tests. +- Added internal history helpers for aggregating streamed history events, de-externalizing payload-backed values, and converting history events to dictionaries. + ## v1.4.0 ADDED diff --git a/durabletask/client.py b/durabletask/client.py index a73fc343..8d2831e0 100644 --- a/durabletask/client.py +++ b/durabletask/client.py @@ -6,14 +6,16 @@ from dataclasses import dataclass from datetime import datetime from enum import Enum -from typing import Any, List, Optional, Sequence, TypeVar, Union +from typing import Any, Generic, List, Optional, Sequence, TypeVar, Union import grpc import grpc.aio +import durabletask.history as history from durabletask.entities import EntityInstanceId from durabletask.entities.entity_metadata import EntityMetadata import durabletask.internal.helpers as helpers +import durabletask.internal.history_helpers as history_helpers import durabletask.internal.orchestrator_service_pb2 as pb import durabletask.internal.orchestrator_service_pb2_grpc as stubs import durabletask.internal.shared as shared @@ -37,6 +39,7 @@ TInput = TypeVar('TInput') TOutput = TypeVar('TOutput') +TItem = TypeVar('TItem') class OrchestrationStatus(Enum): @@ -99,6 +102,12 @@ class PurgeInstancesResult: is_complete: bool +@dataclass +class Page(Generic[TItem]): + items: List[TItem] + continuation_token: Optional[str] + + @dataclass class CleanEntityStorageResult: empty_entities_removed: int @@ -218,6 +227,44 @@ def get_orchestration_state(self, instance_id: str, *, fetch_payloads: bool = Tr payload_helpers.deexternalize_payloads(res, self._payload_store) return new_orchestration_state(req.instanceId, res) + def get_orchestration_history(self, + instance_id: str, *, + execution_id: Optional[str] = None, + for_work_item_processing: bool = False) -> List[history.HistoryEvent]: + req = pb.StreamInstanceHistoryRequest( + instanceId=instance_id, + executionId=helpers.get_string_value(execution_id), + forWorkItemProcessing=for_work_item_processing, + ) + self._logger.info(f"Retrieving history for instance '{instance_id}'.") + stream = self._stub.StreamInstanceHistory(req) + return history_helpers.collect_history_events(stream, self._payload_store) + + def list_instance_ids(self, + runtime_status: Optional[List[OrchestrationStatus]] = None, + completed_time_from: Optional[datetime] = None, + completed_time_to: Optional[datetime] = None, + page_size: Optional[int] = None, + continuation_token: Optional[str] = None) -> Page[str]: + req = pb.ListInstanceIdsRequest( + runtimeStatus=[status.value for status in runtime_status] if runtime_status else None, + completedTimeFrom=helpers.new_timestamp(completed_time_from) if completed_time_from else None, + completedTimeTo=helpers.new_timestamp(completed_time_to) if completed_time_to else None, + pageSize=page_size or 0, + lastInstanceKey=helpers.get_string_value(continuation_token), + ) + self._logger.info( + "Listing terminal instance IDs with filters: " + f"runtime_status={[str(status) for status in runtime_status] if runtime_status else None}, " + f"completed_time_from={completed_time_from}, " + f"completed_time_to={completed_time_to}, " + f"page_size={page_size}, " + f"continuation_token={continuation_token}" + ) + resp: pb.ListInstanceIdsResponse = self._stub.ListInstanceIds(req) + next_token = resp.lastInstanceKey.value if resp.HasField("lastInstanceKey") else None + return Page(items=list(resp.instanceIds), continuation_token=next_token) + def get_all_orchestration_states(self, orchestration_query: Optional[OrchestrationQuery] = None ) -> List[OrchestrationState]: @@ -502,6 +549,44 @@ async def get_orchestration_state(self, instance_id: str, *, await payload_helpers.deexternalize_payloads_async(res, self._payload_store) return new_orchestration_state(req.instanceId, res) + async def get_orchestration_history(self, + instance_id: str, *, + execution_id: Optional[str] = None, + for_work_item_processing: bool = False) -> List[history.HistoryEvent]: + req = pb.StreamInstanceHistoryRequest( + instanceId=instance_id, + executionId=helpers.get_string_value(execution_id), + forWorkItemProcessing=for_work_item_processing, + ) + self._logger.info(f"Retrieving history for instance '{instance_id}'.") + stream = self._stub.StreamInstanceHistory(req) + return await history_helpers.collect_history_events_async(stream, self._payload_store) + + async def list_instance_ids(self, + runtime_status: Optional[List[OrchestrationStatus]] = None, + completed_time_from: Optional[datetime] = None, + completed_time_to: Optional[datetime] = None, + page_size: Optional[int] = None, + continuation_token: Optional[str] = None) -> Page[str]: + req = pb.ListInstanceIdsRequest( + runtimeStatus=[status.value for status in runtime_status] if runtime_status else None, + completedTimeFrom=helpers.new_timestamp(completed_time_from) if completed_time_from else None, + completedTimeTo=helpers.new_timestamp(completed_time_to) if completed_time_to else None, + pageSize=page_size or 0, + lastInstanceKey=helpers.get_string_value(continuation_token), + ) + self._logger.info( + "Listing terminal instance IDs with filters: " + f"runtime_status={[str(status) for status in runtime_status] if runtime_status else None}, " + f"completed_time_from={completed_time_from}, " + f"completed_time_to={completed_time_to}, " + f"page_size={page_size}, " + f"continuation_token={continuation_token}" + ) + resp: pb.ListInstanceIdsResponse = await self._stub.ListInstanceIds(req) + next_token = resp.lastInstanceKey.value if resp.HasField("lastInstanceKey") else None + return Page(items=list(resp.instanceIds), continuation_token=next_token) + async def get_all_orchestration_states(self, orchestration_query: Optional[OrchestrationQuery] = None ) -> List[OrchestrationState]: diff --git a/durabletask/history.py b/durabletask/history.py new file mode 100644 index 00000000..34046c62 --- /dev/null +++ b/durabletask/history.py @@ -0,0 +1,535 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from __future__ import annotations + +from dataclasses import asdict, dataclass +from datetime import datetime +from typing import Any, Optional + +from google.protobuf import json_format +from google.protobuf.message import Message + +from durabletask import task +import durabletask.internal.orchestrator_service_pb2 as pb + + +@dataclass(slots=True) +class OrchestrationInstance: + instance_id: str + execution_id: Optional[str] = None + + +@dataclass(slots=True) +class ParentInstanceInfo: + task_scheduled_id: int + name: Optional[str] = None + version: Optional[str] = None + orchestration_instance: Optional[OrchestrationInstance] = None + + +@dataclass(slots=True) +class TraceContext: + trace_parent: str + span_id: str + trace_state: Optional[str] = None + + +@dataclass(slots=True) +class HistoryEvent: + event_id: int + timestamp: datetime + + def to_dict(self) -> dict[str, Any]: + return _to_serializable(asdict(self)) + + +@dataclass(slots=True) +class ExecutionStartedEvent(HistoryEvent): + name: str + version: Optional[str] = None + input: Optional[str] = None + orchestration_instance: Optional[OrchestrationInstance] = None + parent_instance: Optional[ParentInstanceInfo] = None + scheduled_start_timestamp: Optional[datetime] = None + parent_trace_context: Optional[TraceContext] = None + orchestration_span_id: Optional[str] = None + tags: Optional[dict[str, str]] = None + + +@dataclass(slots=True) +class ExecutionCompletedEvent(HistoryEvent): + orchestration_status: int + result: Optional[str] = None + failure_details: Optional[task.FailureDetails] = None + + +@dataclass(slots=True) +class ExecutionTerminatedEvent(HistoryEvent): + input: Optional[str] = None + recurse: bool = False + + +@dataclass(slots=True) +class TaskScheduledEvent(HistoryEvent): + name: str + version: Optional[str] = None + input: Optional[str] = None + parent_trace_context: Optional[TraceContext] = None + tags: Optional[dict[str, str]] = None + + +@dataclass(slots=True) +class TaskCompletedEvent(HistoryEvent): + task_scheduled_id: int + result: Optional[str] = None + + +@dataclass(slots=True) +class TaskFailedEvent(HistoryEvent): + task_scheduled_id: int + failure_details: Optional[task.FailureDetails] = None + + +@dataclass(slots=True) +class SubOrchestrationInstanceCreatedEvent(HistoryEvent): + instance_id: str + name: str + version: Optional[str] = None + input: Optional[str] = None + parent_trace_context: Optional[TraceContext] = None + tags: Optional[dict[str, str]] = None + + +@dataclass(slots=True) +class SubOrchestrationInstanceCompletedEvent(HistoryEvent): + task_scheduled_id: int + result: Optional[str] = None + + +@dataclass(slots=True) +class SubOrchestrationInstanceFailedEvent(HistoryEvent): + task_scheduled_id: int + failure_details: Optional[task.FailureDetails] = None + + +@dataclass(slots=True) +class TimerCreatedEvent(HistoryEvent): + fire_at: datetime + + +@dataclass(slots=True) +class TimerFiredEvent(HistoryEvent): + fire_at: datetime + timer_id: int + + +@dataclass(slots=True) +class OrchestratorStartedEvent(HistoryEvent): + pass + + +@dataclass(slots=True) +class OrchestratorCompletedEvent(HistoryEvent): + pass + + +@dataclass(slots=True) +class EventSentEvent(HistoryEvent): + instance_id: str + name: str + input: Optional[str] = None + + +@dataclass(slots=True) +class EventRaisedEvent(HistoryEvent): + name: str + input: Optional[str] = None + + +@dataclass(slots=True) +class GenericEvent(HistoryEvent): + data: Optional[str] = None + + +@dataclass(slots=True) +class HistoryStateEvent(HistoryEvent): + orchestration_state: dict[str, Any] + + +@dataclass(slots=True) +class ContinueAsNewEvent(HistoryEvent): + input: Optional[str] = None + + +@dataclass(slots=True) +class ExecutionSuspendedEvent(HistoryEvent): + input: Optional[str] = None + + +@dataclass(slots=True) +class ExecutionResumedEvent(HistoryEvent): + input: Optional[str] = None + + +@dataclass(slots=True) +class EntityOperationSignaledEvent(HistoryEvent): + request_id: str + operation: str + scheduled_time: Optional[datetime] = None + input: Optional[str] = None + target_instance_id: Optional[str] = None + + +@dataclass(slots=True) +class EntityOperationCalledEvent(HistoryEvent): + request_id: str + operation: str + scheduled_time: Optional[datetime] = None + input: Optional[str] = None + parent_instance_id: Optional[str] = None + parent_execution_id: Optional[str] = None + target_instance_id: Optional[str] = None + + +@dataclass(slots=True) +class EntityOperationCompletedEvent(HistoryEvent): + request_id: str + output: Optional[str] = None + + +@dataclass(slots=True) +class EntityOperationFailedEvent(HistoryEvent): + request_id: str + failure_details: Optional[task.FailureDetails] = None + + +@dataclass(slots=True) +class EntityLockRequestedEvent(HistoryEvent): + critical_section_id: str + lock_set: list[str] + position: int + parent_instance_id: Optional[str] = None + + +@dataclass(slots=True) +class EntityLockGrantedEvent(HistoryEvent): + critical_section_id: str + + +@dataclass(slots=True) +class EntityUnlockSentEvent(HistoryEvent): + critical_section_id: str + parent_instance_id: Optional[str] = None + target_instance_id: Optional[str] = None + + +@dataclass(slots=True) +class ExecutionRewoundEvent(HistoryEvent): + reason: Optional[str] = None + parent_execution_id: Optional[str] = None + instance_id: Optional[str] = None + parent_trace_context: Optional[TraceContext] = None + name: Optional[str] = None + version: Optional[str] = None + input: Optional[str] = None + parent_instance: Optional[ParentInstanceInfo] = None + tags: Optional[dict[str, str]] = None + + +def _from_protobuf(event: pb.HistoryEvent) -> HistoryEvent: + event_type = event.WhichOneof('eventType') + if event_type is None: + raise ValueError('History event does not have an eventType set') + converter = _EVENT_CONVERTERS.get(event_type) + if converter is None: + raise ValueError(f'Unsupported history event type: {event_type}') + return converter(event) + + +def to_dict(event: HistoryEvent) -> dict[str, Any]: + return event.to_dict() + + +def _base_kwargs(event: pb.HistoryEvent) -> dict[str, Any]: + return { + 'event_id': event.eventId, + 'timestamp': event.timestamp.ToDatetime(), + } + + +def _string_value(msg: Message, field_name: str) -> Optional[str]: + if msg.HasField(field_name): + return getattr(msg, field_name).value + return None + + +def _timestamp_value(msg: Message, field_name: str) -> Optional[datetime]: + if msg.HasField(field_name): + return getattr(msg, field_name).ToDatetime() + return None + + +def _failure_details(msg: Message, field_name: str) -> Optional[task.FailureDetails]: + if not msg.HasField(field_name): + return None + details = getattr(msg, field_name) + return task.FailureDetails( + details.errorMessage, + details.errorType, + details.stackTrace.value if details.HasField('stackTrace') else None, + ) + + +def _trace_context(msg: Message, field_name: str) -> Optional[TraceContext]: + if not msg.HasField(field_name): + return None + value = getattr(msg, field_name) + return TraceContext( + trace_parent=value.traceParent, + span_id=value.spanID, + trace_state=value.traceState.value if value.HasField('traceState') else None, + ) + + +def _orchestration_instance(msg: Message, field_name: str) -> Optional[OrchestrationInstance]: + if not msg.HasField(field_name): + return None + value = getattr(msg, field_name) + return OrchestrationInstance( + instance_id=value.instanceId, + execution_id=value.executionId.value if value.HasField('executionId') else None, + ) + + +def _parent_instance(msg: Message, field_name: str) -> Optional[ParentInstanceInfo]: + if not msg.HasField(field_name): + return None + value = getattr(msg, field_name) + orchestration_instance = None + if value.HasField('orchestrationInstance'): + orchestration_instance = OrchestrationInstance( + instance_id=value.orchestrationInstance.instanceId, + execution_id=value.orchestrationInstance.executionId.value + if value.orchestrationInstance.HasField('executionId') else None, + ) + return ParentInstanceInfo( + task_scheduled_id=value.taskScheduledId, + name=value.name.value if value.HasField('name') else None, + version=value.version.value if value.HasField('version') else None, + orchestration_instance=orchestration_instance, + ) + + +def _message_to_dict(msg: Message) -> dict[str, Any]: + return json_format.MessageToDict(msg, preserving_proto_field_name=True) + + +def _to_serializable(value: Any) -> Any: + if isinstance(value, datetime): + return value.isoformat() + if isinstance(value, list): + return [_to_serializable(item) for item in value] + if isinstance(value, dict): + return {key: _to_serializable(item) for key, item in value.items()} + return value + + +_EVENT_CONVERTERS: dict[str, Any] = { + 'executionStarted': lambda event: ExecutionStartedEvent( + **_base_kwargs(event), + name=event.executionStarted.name, + version=_string_value(event.executionStarted, 'version'), + input=_string_value(event.executionStarted, 'input'), + orchestration_instance=_orchestration_instance(event.executionStarted, 'orchestrationInstance'), + parent_instance=_parent_instance(event.executionStarted, 'parentInstance'), + scheduled_start_timestamp=_timestamp_value(event.executionStarted, 'scheduledStartTimestamp'), + parent_trace_context=_trace_context(event.executionStarted, 'parentTraceContext'), + orchestration_span_id=_string_value(event.executionStarted, 'orchestrationSpanID'), + tags=dict(event.executionStarted.tags) if event.executionStarted.tags else None, + ), + 'executionCompleted': lambda event: ExecutionCompletedEvent( + **_base_kwargs(event), + orchestration_status=event.executionCompleted.orchestrationStatus, + result=_string_value(event.executionCompleted, 'result'), + failure_details=_failure_details(event.executionCompleted, 'failureDetails'), + ), + 'executionTerminated': lambda event: ExecutionTerminatedEvent( + **_base_kwargs(event), + input=_string_value(event.executionTerminated, 'input'), + recurse=event.executionTerminated.recurse, + ), + 'taskScheduled': lambda event: TaskScheduledEvent( + **_base_kwargs(event), + name=event.taskScheduled.name, + version=_string_value(event.taskScheduled, 'version'), + input=_string_value(event.taskScheduled, 'input'), + parent_trace_context=_trace_context(event.taskScheduled, 'parentTraceContext'), + tags=dict(event.taskScheduled.tags) if event.taskScheduled.tags else None, + ), + 'taskCompleted': lambda event: TaskCompletedEvent( + **_base_kwargs(event), + task_scheduled_id=event.taskCompleted.taskScheduledId, + result=_string_value(event.taskCompleted, 'result'), + ), + 'taskFailed': lambda event: TaskFailedEvent( + **_base_kwargs(event), + task_scheduled_id=event.taskFailed.taskScheduledId, + failure_details=_failure_details(event.taskFailed, 'failureDetails'), + ), + 'subOrchestrationInstanceCreated': lambda event: SubOrchestrationInstanceCreatedEvent( + **_base_kwargs(event), + instance_id=event.subOrchestrationInstanceCreated.instanceId, + name=event.subOrchestrationInstanceCreated.name, + version=_string_value(event.subOrchestrationInstanceCreated, 'version'), + input=_string_value(event.subOrchestrationInstanceCreated, 'input'), + parent_trace_context=_trace_context(event.subOrchestrationInstanceCreated, 'parentTraceContext'), + tags=dict(event.subOrchestrationInstanceCreated.tags) if event.subOrchestrationInstanceCreated.tags else None, + ), + 'subOrchestrationInstanceCompleted': lambda event: SubOrchestrationInstanceCompletedEvent( + **_base_kwargs(event), + task_scheduled_id=event.subOrchestrationInstanceCompleted.taskScheduledId, + result=_string_value(event.subOrchestrationInstanceCompleted, 'result'), + ), + 'subOrchestrationInstanceFailed': lambda event: SubOrchestrationInstanceFailedEvent( + **_base_kwargs(event), + task_scheduled_id=event.subOrchestrationInstanceFailed.taskScheduledId, + failure_details=_failure_details(event.subOrchestrationInstanceFailed, 'failureDetails'), + ), + 'timerCreated': lambda event: TimerCreatedEvent( + **_base_kwargs(event), + fire_at=event.timerCreated.fireAt.ToDatetime(), + ), + 'timerFired': lambda event: TimerFiredEvent( + **_base_kwargs(event), + fire_at=event.timerFired.fireAt.ToDatetime(), + timer_id=event.timerFired.timerId, + ), + 'orchestratorStarted': lambda event: OrchestratorStartedEvent(**_base_kwargs(event)), + 'orchestratorCompleted': lambda event: OrchestratorCompletedEvent(**_base_kwargs(event)), + 'eventSent': lambda event: EventSentEvent( + **_base_kwargs(event), + instance_id=event.eventSent.instanceId, + name=event.eventSent.name, + input=_string_value(event.eventSent, 'input'), + ), + 'eventRaised': lambda event: EventRaisedEvent( + **_base_kwargs(event), + name=event.eventRaised.name, + input=_string_value(event.eventRaised, 'input'), + ), + 'genericEvent': lambda event: GenericEvent( + **_base_kwargs(event), + data=_string_value(event.genericEvent, 'data'), + ), + 'historyState': lambda event: HistoryStateEvent( + **_base_kwargs(event), + orchestration_state=_message_to_dict(event.historyState.orchestrationState), + ), + 'continueAsNew': lambda event: ContinueAsNewEvent( + **_base_kwargs(event), + input=_string_value(event.continueAsNew, 'input'), + ), + 'executionSuspended': lambda event: ExecutionSuspendedEvent( + **_base_kwargs(event), + input=_string_value(event.executionSuspended, 'input'), + ), + 'executionResumed': lambda event: ExecutionResumedEvent( + **_base_kwargs(event), + input=_string_value(event.executionResumed, 'input'), + ), + 'entityOperationSignaled': lambda event: EntityOperationSignaledEvent( + **_base_kwargs(event), + request_id=event.entityOperationSignaled.requestId, + operation=event.entityOperationSignaled.operation, + scheduled_time=_timestamp_value(event.entityOperationSignaled, 'scheduledTime'), + input=_string_value(event.entityOperationSignaled, 'input'), + target_instance_id=_string_value(event.entityOperationSignaled, 'targetInstanceId'), + ), + 'entityOperationCalled': lambda event: EntityOperationCalledEvent( + **_base_kwargs(event), + request_id=event.entityOperationCalled.requestId, + operation=event.entityOperationCalled.operation, + scheduled_time=_timestamp_value(event.entityOperationCalled, 'scheduledTime'), + input=_string_value(event.entityOperationCalled, 'input'), + parent_instance_id=_string_value(event.entityOperationCalled, 'parentInstanceId'), + parent_execution_id=_string_value(event.entityOperationCalled, 'parentExecutionId'), + target_instance_id=_string_value(event.entityOperationCalled, 'targetInstanceId'), + ), + 'entityOperationCompleted': lambda event: EntityOperationCompletedEvent( + **_base_kwargs(event), + request_id=event.entityOperationCompleted.requestId, + output=_string_value(event.entityOperationCompleted, 'output'), + ), + 'entityOperationFailed': lambda event: EntityOperationFailedEvent( + **_base_kwargs(event), + request_id=event.entityOperationFailed.requestId, + failure_details=_failure_details(event.entityOperationFailed, 'failureDetails'), + ), + 'entityLockRequested': lambda event: EntityLockRequestedEvent( + **_base_kwargs(event), + critical_section_id=event.entityLockRequested.criticalSectionId, + lock_set=list(event.entityLockRequested.lockSet), + position=event.entityLockRequested.position, + parent_instance_id=_string_value(event.entityLockRequested, 'parentInstanceId'), + ), + 'entityLockGranted': lambda event: EntityLockGrantedEvent( + **_base_kwargs(event), + critical_section_id=event.entityLockGranted.criticalSectionId, + ), + 'entityUnlockSent': lambda event: EntityUnlockSentEvent( + **_base_kwargs(event), + critical_section_id=event.entityUnlockSent.criticalSectionId, + parent_instance_id=_string_value(event.entityUnlockSent, 'parentInstanceId'), + target_instance_id=_string_value(event.entityUnlockSent, 'targetInstanceId'), + ), + 'executionRewound': lambda event: ExecutionRewoundEvent( + **_base_kwargs(event), + reason=_string_value(event.executionRewound, 'reason'), + parent_execution_id=_string_value(event.executionRewound, 'parentExecutionId'), + instance_id=_string_value(event.executionRewound, 'instanceId'), + parent_trace_context=_trace_context(event.executionRewound, 'parentTraceContext'), + name=_string_value(event.executionRewound, 'name'), + version=_string_value(event.executionRewound, 'version'), + input=_string_value(event.executionRewound, 'input'), + parent_instance=_parent_instance(event.executionRewound, 'parentInstance'), + tags=dict(event.executionRewound.tags) if event.executionRewound.tags else None, + ), +} + + +__all__ = [ + 'ContinueAsNewEvent', + 'EntityLockGrantedEvent', + 'EntityLockRequestedEvent', + 'EntityOperationCalledEvent', + 'EntityOperationCompletedEvent', + 'EntityOperationFailedEvent', + 'EntityOperationSignaledEvent', + 'EntityUnlockSentEvent', + 'EventRaisedEvent', + 'EventSentEvent', + 'ExecutionCompletedEvent', + 'ExecutionResumedEvent', + 'ExecutionRewoundEvent', + 'ExecutionStartedEvent', + 'ExecutionSuspendedEvent', + 'ExecutionTerminatedEvent', + 'GenericEvent', + 'HistoryEvent', + 'HistoryStateEvent', + 'OrchestrationInstance', + 'OrchestratorCompletedEvent', + 'OrchestratorStartedEvent', + 'ParentInstanceInfo', + 'SubOrchestrationInstanceCompletedEvent', + 'SubOrchestrationInstanceCreatedEvent', + 'SubOrchestrationInstanceFailedEvent', + 'TaskCompletedEvent', + 'TaskFailedEvent', + 'TaskScheduledEvent', + 'TimerCreatedEvent', + 'TimerFiredEvent', + 'TraceContext', + 'to_dict', +] diff --git a/durabletask/internal/history_helpers.py b/durabletask/internal/history_helpers.py new file mode 100644 index 00000000..35ff946a --- /dev/null +++ b/durabletask/internal/history_helpers.py @@ -0,0 +1,68 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from __future__ import annotations + +from typing import AsyncIterable, Iterable, Optional + +import durabletask.history as history +import durabletask.internal.orchestrator_service_pb2 as pb +from durabletask.payload import helpers as payload_helpers +from durabletask.payload.store import PayloadStore + + +def collect_history_events( + chunks: Iterable[pb.HistoryChunk], + payload_store: Optional[PayloadStore] = None, +) -> list[history.HistoryEvent]: + events: list[history.HistoryEvent] = [] + for chunk in chunks: + events.extend(_clone_and_convert_events(chunk.events, payload_store)) + return events + + +async def collect_history_events_async( + chunks: AsyncIterable[pb.HistoryChunk], + payload_store: Optional[PayloadStore] = None, +) -> list[history.HistoryEvent]: + events: list[history.HistoryEvent] = [] + async for chunk in chunks: + events.extend(await _clone_and_convert_events_async(chunk.events, payload_store)) + return events + + +def history_event_to_dict(event: history.HistoryEvent) -> dict: + return history.to_dict(event) + + +def _clone_and_convert_events( + source_events: Iterable[pb.HistoryEvent], + payload_store: Optional[PayloadStore], +) -> list[history.HistoryEvent]: + events: list[history.HistoryEvent] = [] + for source_event in source_events: + event = source_event + if payload_store is not None: + # deexternalize_payloads mutates messages in-place, so clone to avoid + # mutating protobuf instances owned by gRPC/deserializer internals. + event = pb.HistoryEvent() + event.CopyFrom(source_event) + payload_helpers.deexternalize_payloads(event, payload_store) + events.append(history._from_protobuf(event)) + return events + + +async def _clone_and_convert_events_async( + source_events: Iterable[pb.HistoryEvent], + payload_store: Optional[PayloadStore], +) -> list[history.HistoryEvent]: + events: list[history.HistoryEvent] = [] + for source_event in source_events: + event = source_event + if payload_store is not None: + # Async deexternalization mutates messages in-place, so clone first. + event = pb.HistoryEvent() + event.CopyFrom(source_event) + await payload_helpers.deexternalize_payloads_async(event, payload_store) + events.append(history._from_protobuf(event)) + return events diff --git a/durabletask/testing/in_memory_backend.py b/durabletask/testing/in_memory_backend.py index d4a22a99..19630bd8 100644 --- a/durabletask/testing/in_memory_backend.py +++ b/durabletask/testing/in_memory_backend.py @@ -41,6 +41,7 @@ class OrchestrationInstance: custom_status: Optional[str] = None created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc)) last_updated_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc)) + completed_at: Optional[datetime] = None failure_details: Optional[pb.TaskFailureDetails] = None history: list[pb.HistoryEvent] = field(default_factory=list) pending_events: list[pb.HistoryEvent] = field(default_factory=list) @@ -238,6 +239,7 @@ def StartInstance(self, request: pb.CreateInstanceRequest, context): input=request.input.value if request.input else None, created_at=now, last_updated_at=now, + completed_at=None, completion_token=self._next_completion_token, tags=dict(request.tags) if request.tags else None, ) @@ -453,6 +455,43 @@ def RestartInstance(self, request: pb.RestartInstanceRequest, context): f"Restarted instance '{request.instanceId}' as '{new_instance_id}'") return pb.RestartInstanceResponse(instanceId=new_instance_id) + def ListInstanceIds(self, request: pb.ListInstanceIdsRequest, context): + """Lists terminal orchestration instance IDs with completion-time pagination.""" + with self._lock: + matching = [] + for instance in self._instances.values(): + if not self._is_terminal_status(instance.status): + continue + if request.runtimeStatus and instance.status not in request.runtimeStatus: + continue + if instance.completed_at is None: + continue + if request.HasField("completedTimeFrom") and instance.completed_at < request.completedTimeFrom.ToDatetime(timezone.utc): + continue + if request.HasField("completedTimeTo") and instance.completed_at >= request.completedTimeTo.ToDatetime(timezone.utc): + continue + matching.append(instance) + + matching.sort(key=lambda i: (i.completed_at, i.instance_id)) + + start_index = 0 + if request.HasField("lastInstanceKey") and request.lastInstanceKey.value: + for idx, instance in enumerate(matching): + if instance.instance_id == request.lastInstanceKey.value: + start_index = idx + 1 + break + + page_size = request.pageSize if request.pageSize > 0 else len(matching) + page = matching[start_index:start_index + page_size] + next_token = None + if start_index + page_size < len(matching) and page: + next_token = wrappers_pb2.StringValue(value=page[-1].instance_id) + + return pb.ListInstanceIdsResponse( + instanceIds=[instance.instance_id for instance in page], + lastInstanceKey=next_token, + ) + @staticmethod def _parse_work_item_filters(request: pb.GetWorkItemsRequest): """Extract filters from the request. @@ -1084,8 +1123,18 @@ def CleanEntityStorage(self, request: pb.CleanEntityStorageRequest, context): ) def StreamInstanceHistory(self, request: pb.StreamInstanceHistoryRequest, context): - """Streams instance history (not implemented).""" - context.abort(grpc.StatusCode.UNIMPLEMENTED, "StreamInstanceHistory not implemented") + """Streams orchestration history for an instance.""" + with self._lock: + instance = self._instances.get(request.instanceId) + if instance is None: + context.abort(grpc.StatusCode.NOT_FOUND, + f"Orchestration instance '{request.instanceId}' not found") + return + history = [self._clone_history_event(event) for event in instance.history] + + chunk_size = 100 + for offset in range(0, len(history), chunk_size): + yield pb.HistoryChunk(events=history[offset:offset + chunk_size]) def CreateTaskHub(self, request: pb.CreateTaskHubRequest, context): """Creates task hub resources (no-op for in-memory).""" @@ -1178,6 +1227,7 @@ def _create_instance_internal(self, instance_id: str, name: str, input=encoded_input, created_at=now, last_updated_at=now, + completed_at=None, completion_token=self._next_completion_token, ) self._next_completion_token += 1 @@ -1239,6 +1289,7 @@ def _build_instance_response(self, instance: OrchestrationInstance, orchestrationStatus=instance.status, createdTimestamp=created_ts, lastUpdatedTimestamp=updated_ts, + completedTimestamp=helpers.new_timestamp(instance.completed_at) if instance.completed_at else None, input=wrappers_pb2.StringValue(value=instance.input) if include_payloads and instance.input else None, output=wrappers_pb2.StringValue(value=instance.output) if include_payloads and instance.output else None, customStatus=wrappers_pb2.StringValue(value=instance.custom_status) if instance.custom_status else None, @@ -1319,6 +1370,7 @@ def _process_complete_orchestration_action(self, instance: OrchestrationInstance instance.status = status instance.output = complete_action.result.value if complete_action.result else None instance.failure_details = complete_action.failureDetails if complete_action.failureDetails else None + instance.completed_at = datetime.now(timezone.utc) if self._is_terminal_status(status) else None if status == pb.ORCHESTRATION_STATUS_CONTINUED_AS_NEW: # Handle continue-as-new @@ -1336,6 +1388,7 @@ def _process_complete_orchestration_action(self, instance: OrchestrationInstance instance.output = None instance.failure_details = None instance.status = pb.ORCHESTRATION_STATUS_PENDING + instance.completed_at = None # Save any events that arrived during the in-flight dispatch so # they can be appended AFTER the new execution started events. @@ -1357,6 +1410,12 @@ def _process_complete_orchestration_action(self, instance: OrchestrationInstance self._enqueue_orchestration(instance.instance_id) + @staticmethod + def _clone_history_event(event: pb.HistoryEvent) -> pb.HistoryEvent: + cloned_event = pb.HistoryEvent() + cloned_event.CopyFrom(event) + return cloned_event + def _process_schedule_task_action(self, instance: OrchestrationInstance, action: pb.OrchestratorAction): """Processes a schedule task action.""" diff --git a/tests/durabletask/test_batch_actions.py b/tests/durabletask/test_batch_actions.py index f25805b5..86f4c3b1 100644 --- a/tests/durabletask/test_batch_actions.py +++ b/tests/durabletask/test_batch_actions.py @@ -305,6 +305,45 @@ def test_purge_orchestrations_by_time_range(backend): worker.stop() +def test_list_instance_ids_paginates_terminal_instances(backend): + worker = TaskHubGrpcWorker(host_address=HOST) + c = TaskHubGrpcClient(host_address=HOST) + + worker.add_orchestrator(empty_orchestrator) + worker.add_orchestrator(failing_orchestrator) + worker.start() + + try: + completed_id = c.schedule_new_orchestration(empty_orchestrator, input='done') + c.wait_for_orchestration_completion(completed_id, timeout=30) + + failed_id = c.schedule_new_orchestration(failing_orchestrator) + failed_state = c.wait_for_orchestration_completion(failed_id, timeout=30) + + window_start = datetime.now(timezone.utc) - timedelta(minutes=1) + first_page = c.list_instance_ids( + runtime_status=[client.OrchestrationStatus.COMPLETED, client.OrchestrationStatus.FAILED], + completed_time_from=window_start, + page_size=1, + ) + second_page = c.list_instance_ids( + runtime_status=[client.OrchestrationStatus.COMPLETED, client.OrchestrationStatus.FAILED], + completed_time_from=window_start, + page_size=1, + continuation_token=first_page.continuation_token, + ) + finally: + worker.stop() + + assert len(first_page.items) == 1 + assert len(second_page.items) == 1 + assert set(first_page.items + second_page.items) == {completed_id, failed_id} + assert failed_state is not None + assert failed_state.runtime_status == client.OrchestrationStatus.FAILED + assert first_page.continuation_token in {completed_id, failed_id} + assert second_page.continuation_token is None + + def test_get_all_entities(backend): counter_value = 0 diff --git a/tests/durabletask/test_client.py b/tests/durabletask/test_client.py index 006c0987..0216d9cf 100644 --- a/tests/durabletask/test_client.py +++ b/tests/durabletask/test_client.py @@ -1,6 +1,13 @@ -from unittest.mock import ANY, MagicMock, patch +from datetime import datetime, timezone +from unittest.mock import ANY, AsyncMock, MagicMock, patch -from durabletask.client import AsyncTaskHubGrpcClient +import pytest +from google.protobuf import wrappers_pb2 + +import durabletask.history as history +import durabletask.internal.orchestrator_service_pb2 as pb +from durabletask.client import AsyncTaskHubGrpcClient, OrchestrationStatus, TaskHubGrpcClient +from durabletask.payload.store import LargePayloadStorageOptions, PayloadStore from durabletask.internal.grpc_interceptor import ( DefaultAsyncClientInterceptorImpl, @@ -17,6 +24,37 @@ INTERCEPTORS = [DefaultClientInterceptorImpl(METADATA)] +class FakePayloadStore(PayloadStore): + TOKEN_PREFIX = 'fake://' + + def __init__(self): + self._options = LargePayloadStorageOptions(threshold_bytes=1, max_stored_payload_bytes=1024 * 1024) + self._blobs: dict[str, bytes] = {} + self._counter = 0 + + @property + def options(self) -> LargePayloadStorageOptions: + return self._options + + def upload(self, payload: bytes, *, instance_id=None) -> str: + self._counter += 1 + token = f'{self.TOKEN_PREFIX}{self._counter}' + self._blobs[token] = payload + return token + + def download(self, token: str) -> bytes: + return self._blobs[token] + + def is_known_token(self, value: str) -> bool: + return value.startswith(self.TOKEN_PREFIX) + + async def upload_async(self, payload: bytes, *, instance_id=None) -> str: + return self.upload(payload, instance_id=instance_id) + + async def download_async(self, token: str) -> bytes: + return self.download(token) + + # ==== Sync channel tests ==== @@ -185,3 +223,105 @@ def test_async_client_creates_with_metadata(): assert interceptors is not None assert len(interceptors) == 1 assert isinstance(interceptors[0], DefaultAsyncClientInterceptorImpl) + + +def test_get_orchestration_history_aggregates_chunks_and_deexternalizes_payloads(): + store = FakePayloadStore() + token = store.upload(b'history payload') + stream = [ + pb.HistoryChunk(events=[ + pb.HistoryEvent( + eventId=1, + taskCompleted=pb.TaskCompletedEvent( + taskScheduledId=42, + result=wrappers_pb2.StringValue(value=token), + ), + ) + ]), + pb.HistoryChunk(events=[pb.HistoryEvent(eventId=2, executionCompleted=pb.ExecutionCompletedEvent())]), + ] + + stub = MagicMock() + stub.StreamInstanceHistory.return_value = stream + + with patch('durabletask.client.shared.get_grpc_channel', return_value=MagicMock()), patch( + 'durabletask.client.stubs.TaskHubSidecarServiceStub', return_value=stub): + history_client = TaskHubGrpcClient(payload_store=store) + events = history_client.get_orchestration_history('abc') + + assert [event.event_id for event in events] == [1, 2] + assert isinstance(events[0], history.TaskCompletedEvent) + assert events[0].result == 'history payload' + req = stub.StreamInstanceHistory.call_args.args[0] + assert req.instanceId == 'abc' + + +def test_list_instance_ids_returns_page(): + stub = MagicMock() + stub.ListInstanceIds.return_value = pb.ListInstanceIdsResponse( + instanceIds=['a', 'b'], + lastInstanceKey=wrappers_pb2.StringValue(value='b'), + ) + + with patch('durabletask.client.shared.get_grpc_channel', return_value=MagicMock()), patch( + 'durabletask.client.stubs.TaskHubSidecarServiceStub', return_value=stub): + history_client = TaskHubGrpcClient() + page = history_client.list_instance_ids( + runtime_status=[OrchestrationStatus.COMPLETED], + completed_time_from=datetime(2025, 1, 1, tzinfo=timezone.utc), + page_size=2, + continuation_token='prev', + ) + + assert page.items == ['a', 'b'] + assert page.continuation_token == 'b' + req = stub.ListInstanceIds.call_args.args[0] + assert list(req.runtimeStatus) == [pb.ORCHESTRATION_STATUS_COMPLETED] + assert req.pageSize == 2 + assert req.lastInstanceKey.value == 'prev' + + +@pytest.mark.asyncio +async def test_async_get_orchestration_history_aggregates_chunks_and_deexternalizes_payloads(): + store = FakePayloadStore() + token = store.upload(b'async history payload') + + async def stream(): + yield pb.HistoryChunk(events=[ + pb.HistoryEvent( + eventId=3, + taskCompleted=pb.TaskCompletedEvent( + taskScheduledId=43, + result=wrappers_pb2.StringValue(value=token), + ), + ) + ]) + + stub = MagicMock() + stub.StreamInstanceHistory.return_value = stream() + + with patch('durabletask.client.shared.get_async_grpc_channel', return_value=MagicMock()), patch( + 'durabletask.client.stubs.TaskHubSidecarServiceStub', return_value=stub): + history_client = AsyncTaskHubGrpcClient(payload_store=store) + events = await history_client.get_orchestration_history('async-abc') + + assert [event.event_id for event in events] == [3] + assert isinstance(events[0], history.TaskCompletedEvent) + assert events[0].result == 'async history payload' + + +@pytest.mark.asyncio +async def test_async_list_instance_ids_returns_page(): + stub = MagicMock() + stub.ListInstanceIds = AsyncMock(return_value=pb.ListInstanceIdsResponse( + instanceIds=['one'], + lastInstanceKey=wrappers_pb2.StringValue(value='one'), + )) + + with patch('durabletask.client.shared.get_async_grpc_channel', return_value=MagicMock()), patch( + 'durabletask.client.stubs.TaskHubSidecarServiceStub', return_value=stub): + history_client = AsyncTaskHubGrpcClient() + page = await history_client.list_instance_ids(page_size=1) + + assert page.items == ['one'] + assert page.continuation_token == 'one' diff --git a/tests/durabletask/test_orchestration_e2e.py b/tests/durabletask/test_orchestration_e2e.py index 103f14aa..aea110b1 100644 --- a/tests/durabletask/test_orchestration_e2e.py +++ b/tests/durabletask/test_orchestration_e2e.py @@ -10,6 +10,7 @@ import pytest from durabletask import client, task, worker +import durabletask.history as history from durabletask.testing import create_test_backend HOST = "localhost:50054" @@ -128,6 +129,32 @@ def orchestrator(ctx: task.OrchestrationContext, input: int): assert compensation_counter == 2 +def test_get_orchestration_history(): + + def plus_one(_: task.ActivityContext, input: int) -> int: + return input + 1 + + def simple(ctx: task.OrchestrationContext, value: int): + result = yield ctx.call_activity(plus_one, input=value) + return result + + with worker.TaskHubGrpcWorker(host_address=HOST) as w: + w.add_orchestrator(simple) + w.add_activity(plus_one) + w.start() + + task_hub_client = client.TaskHubGrpcClient(host_address=HOST) + instance_id = task_hub_client.schedule_new_orchestration(simple, input=1) + state = task_hub_client.wait_for_orchestration_completion(instance_id, timeout=30) + events = task_hub_client.get_orchestration_history(instance_id) + + assert state is not None + assert len(events) > 0 + assert any(isinstance(event, history.ExecutionStartedEvent) for event in events) + assert any(isinstance(event, history.TaskScheduledEvent) for event in events) + assert any(isinstance(event, history.TaskCompletedEvent) for event in events) + + def test_sub_orchestration_fan_out(): threadLock = threading.Lock() activity_counter = 0 From 765cf5e6b66c5cd2fcebec4d1faff6841e838387 Mon Sep 17 00:00:00 2001 From: Andy Staples Date: Tue, 14 Apr 2026 11:26:34 -0600 Subject: [PATCH 2/3] CHANGELOG clarity, instructions --- .github/copilot-instructions.md | 10 ++++++++++ CHANGELOG.md | 3 +-- 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/.github/copilot-instructions.md b/.github/copilot-instructions.md index 23914c0d..df9fb40e 100644 --- a/.github/copilot-instructions.md +++ b/.github/copilot-instructions.md @@ -15,9 +15,19 @@ building durable orchestrations. The repo contains two packages: - Update `CHANGELOG.md` for core SDK changes and `durabletask-azuremanaged/CHANGELOG.md` for provider changes. - If a change affects both packages, update both changelogs. +- Include changelog entries for externally observable outcomes only, such as + new public APIs, behavior changes, bug fixes users can notice, breaking + changes, and new configuration capabilities. - Do NOT document internal-only changes in changelogs, including CI/workflow updates, test-only changes, refactors with no user-visible behavior change, and implementation details that do not affect public behavior or API. +- When in doubt, write the changelog entry in terms of user impact (what users + can now do or what behavior changed), not implementation mechanism (how it + was implemented internally). + +Examples: +- Include: "Added `get_orchestration_history()` to retrieve orchestration history from the client." +- Exclude: "Added internal helper functions to aggregate streamed history chunks." ## Language and Style diff --git a/CHANGELOG.md b/CHANGELOG.md index b405b513..4979a514 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,8 +10,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ADDED - Added `get_orchestration_history()` and `list_instance_ids()` to the sync and async gRPC clients. -- Added in-memory backend support for `StreamInstanceHistory` and `ListInstanceIds` to enable history retrieval and completion-window pagination in tests. -- Added internal history helpers for aggregating streamed history events, de-externalizing payload-backed values, and converting history events to dictionaries. +- Added in-memory backend support for `StreamInstanceHistory` and `ListInstanceIds` so local orchestration tests can retrieve history and page terminal instance IDs by completion window. ## v1.4.0 From 342eb249a17b77be04d039382ce8d8afcad57c89 Mon Sep 17 00:00:00 2001 From: Andy Staples Date: Tue, 14 Apr 2026 11:57:04 -0600 Subject: [PATCH 3/3] Address PR review comments --- durabletask/client.py | 4 ++-- durabletask/history.py | 10 ++++----- durabletask/testing/in_memory_backend.py | 24 +++++++++++++++------ tests/durabletask/test_batch_actions.py | 3 ++- tests/durabletask/test_orchestration_e2e.py | 9 +++++--- 5 files changed, 32 insertions(+), 18 deletions(-) diff --git a/durabletask/client.py b/durabletask/client.py index 8d2831e0..a143503a 100644 --- a/durabletask/client.py +++ b/durabletask/client.py @@ -247,7 +247,7 @@ def list_instance_ids(self, page_size: Optional[int] = None, continuation_token: Optional[str] = None) -> Page[str]: req = pb.ListInstanceIdsRequest( - runtimeStatus=[status.value for status in runtime_status] if runtime_status else None, + runtimeStatus=[status.value for status in runtime_status] if runtime_status else [], completedTimeFrom=helpers.new_timestamp(completed_time_from) if completed_time_from else None, completedTimeTo=helpers.new_timestamp(completed_time_to) if completed_time_to else None, pageSize=page_size or 0, @@ -569,7 +569,7 @@ async def list_instance_ids(self, page_size: Optional[int] = None, continuation_token: Optional[str] = None) -> Page[str]: req = pb.ListInstanceIdsRequest( - runtimeStatus=[status.value for status in runtime_status] if runtime_status else None, + runtimeStatus=[status.value for status in runtime_status] if runtime_status else [], completedTimeFrom=helpers.new_timestamp(completed_time_from) if completed_time_from else None, completedTimeTo=helpers.new_timestamp(completed_time_to) if completed_time_to else None, pageSize=page_size or 0, diff --git a/durabletask/history.py b/durabletask/history.py index 34046c62..43fc30b8 100644 --- a/durabletask/history.py +++ b/durabletask/history.py @@ -4,7 +4,7 @@ from __future__ import annotations from dataclasses import asdict, dataclass -from datetime import datetime +from datetime import datetime, timezone from typing import Any, Optional from google.protobuf import json_format @@ -254,7 +254,7 @@ def to_dict(event: HistoryEvent) -> dict[str, Any]: def _base_kwargs(event: pb.HistoryEvent) -> dict[str, Any]: return { 'event_id': event.eventId, - 'timestamp': event.timestamp.ToDatetime(), + 'timestamp': event.timestamp.ToDatetime(timezone.utc), } @@ -266,7 +266,7 @@ def _string_value(msg: Message, field_name: str) -> Optional[str]: def _timestamp_value(msg: Message, field_name: str) -> Optional[datetime]: if msg.HasField(field_name): - return getattr(msg, field_name).ToDatetime() + return getattr(msg, field_name).ToDatetime(timezone.utc) return None @@ -398,11 +398,11 @@ def _to_serializable(value: Any) -> Any: ), 'timerCreated': lambda event: TimerCreatedEvent( **_base_kwargs(event), - fire_at=event.timerCreated.fireAt.ToDatetime(), + fire_at=event.timerCreated.fireAt.ToDatetime(timezone.utc), ), 'timerFired': lambda event: TimerFiredEvent( **_base_kwargs(event), - fire_at=event.timerFired.fireAt.ToDatetime(), + fire_at=event.timerFired.fireAt.ToDatetime(timezone.utc), timer_id=event.timerFired.timerId, ), 'orchestratorStarted': lambda event: OrchestratorStartedEvent(**_base_kwargs(event)), diff --git a/durabletask/testing/in_memory_backend.py b/durabletask/testing/in_memory_backend.py index 19630bd8..744c4c24 100644 --- a/durabletask/testing/in_memory_backend.py +++ b/durabletask/testing/in_memory_backend.py @@ -10,6 +10,7 @@ or external storage is not desired. """ +import bisect import logging import threading import time @@ -98,6 +99,10 @@ class StateWaiter: result: Optional[OrchestrationInstance] = None +_DEFAULT_PAGE_SIZE = 100 +_TOKEN_SEP = '|' + + class InMemoryOrchestrationBackend(stubs.TaskHubSidecarServiceServicer): """ In-memory backend for durable orchestrations suitable for testing. @@ -473,19 +478,24 @@ def ListInstanceIds(self, request: pb.ListInstanceIdsRequest, context): matching.append(instance) matching.sort(key=lambda i: (i.completed_at, i.instance_id)) + sort_keys = [(i.completed_at, i.instance_id) for i in matching] start_index = 0 if request.HasField("lastInstanceKey") and request.lastInstanceKey.value: - for idx, instance in enumerate(matching): - if instance.instance_id == request.lastInstanceKey.value: - start_index = idx + 1 - break - - page_size = request.pageSize if request.pageSize > 0 else len(matching) + token = request.lastInstanceKey.value + sep_idx = token.index(_TOKEN_SEP) + token_ts = datetime.fromisoformat(token[:sep_idx]).replace(tzinfo=timezone.utc) + token_id = token[sep_idx + 1:] + # bisect_right positions us just after the cursor entry + start_index = bisect.bisect_right(sort_keys, (token_ts, token_id)) + + page_size = request.pageSize if request.pageSize > 0 else _DEFAULT_PAGE_SIZE page = matching[start_index:start_index + page_size] next_token = None if start_index + page_size < len(matching) and page: - next_token = wrappers_pb2.StringValue(value=page[-1].instance_id) + last = page[-1] + encoded = f"{last.completed_at.isoformat()}{_TOKEN_SEP}{last.instance_id}" + next_token = wrappers_pb2.StringValue(value=encoded) return pb.ListInstanceIdsResponse( instanceIds=[instance.instance_id for instance in page], diff --git a/tests/durabletask/test_batch_actions.py b/tests/durabletask/test_batch_actions.py index 86f4c3b1..b0778503 100644 --- a/tests/durabletask/test_batch_actions.py +++ b/tests/durabletask/test_batch_actions.py @@ -340,7 +340,8 @@ def test_list_instance_ids_paginates_terminal_instances(backend): assert set(first_page.items + second_page.items) == {completed_id, failed_id} assert failed_state is not None assert failed_state.runtime_status == client.OrchestrationStatus.FAILED - assert first_page.continuation_token in {completed_id, failed_id} + assert first_page.continuation_token is not None + assert any(instance_id in first_page.continuation_token for instance_id in {completed_id, failed_id}) assert second_page.continuation_token is None diff --git a/tests/durabletask/test_orchestration_e2e.py b/tests/durabletask/test_orchestration_e2e.py index aea110b1..065a8b11 100644 --- a/tests/durabletask/test_orchestration_e2e.py +++ b/tests/durabletask/test_orchestration_e2e.py @@ -144,9 +144,12 @@ def simple(ctx: task.OrchestrationContext, value: int): w.start() task_hub_client = client.TaskHubGrpcClient(host_address=HOST) - instance_id = task_hub_client.schedule_new_orchestration(simple, input=1) - state = task_hub_client.wait_for_orchestration_completion(instance_id, timeout=30) - events = task_hub_client.get_orchestration_history(instance_id) + try: + instance_id = task_hub_client.schedule_new_orchestration(simple, input=1) + state = task_hub_client.wait_for_orchestration_completion(instance_id, timeout=30) + events = task_hub_client.get_orchestration_history(instance_id) + finally: + task_hub_client.close() assert state is not None assert len(events) > 0