From 347f4ba56dff2199356d99a26c4302c0c95c171f Mon Sep 17 00:00:00 2001 From: Albert Callarisa Date: Fri, 17 Apr 2026 09:08:51 +0200 Subject: [PATCH 1/4] Adds support for pydantic workflow and activity inputs/outputs Signed-off-by: Albert Callarisa --- examples/workflow/README.md | 28 ++++ examples/workflow/pydantic_models.py | 100 +++++++++++ examples/workflow/requirements.txt | 1 + .../workflow/_durabletask/internal/shared.py | 12 ++ .../dapr/ext/workflow/_model_protocol.py | 148 ++++++++++++++++ .../dapr/ext/workflow/workflow_runtime.py | 13 ++ .../tests/durabletask/test_serialization.py | 47 ++++++ .../tests/test_model_protocol.py | 158 ++++++++++++++++++ .../tests/test_workflow_runtime.py | 126 +++++++++++++- 9 files changed, 632 insertions(+), 1 deletion(-) create mode 100644 examples/workflow/pydantic_models.py create mode 100644 ext/dapr-ext-workflow/dapr/ext/workflow/_model_protocol.py create mode 100644 ext/dapr-ext-workflow/tests/test_model_protocol.py diff --git a/examples/workflow/README.md b/examples/workflow/README.md index cf3cce610..30210e667 100644 --- a/examples/workflow/README.md +++ b/examples/workflow/README.md @@ -508,3 +508,31 @@ dapr run --app-id wf-versioning-example -- python3 versioning.py part1 dapr run --app-id wf-versioning-example --log-level debug -- python3 versioning.py part2 ``` + +### Pydantic models as workflow/activity inputs + +This example shows how to pass [Pydantic](https://docs.pydantic.dev/) `BaseModel` +instances directly as workflow and activity inputs. When a workflow or activity +annotates its input parameter with a `BaseModel` subclass, the runtime +reconstructs the model from the decoded JSON payload automatically — no manual +`model_validate` call is needed at the receiving side. + +The wire format remains plain JSON, so workflows and activities stay +interop-friendly with non-Python Dapr apps. Outputs coming back from activities +arrive as dicts; reconstructing them into a typed instance is a one-liner +(`OrderResult.model_validate(...)`). + + + +```sh +dapr run --app-id wf-pydantic-example -- python3 pydantic_models.py +``` + diff --git a/examples/workflow/pydantic_models.py b/examples/workflow/pydantic_models.py new file mode 100644 index 000000000..c4aefd79f --- /dev/null +++ b/examples/workflow/pydantic_models.py @@ -0,0 +1,100 @@ +# -*- coding: utf-8 -*- +# Copyright 2026 The Dapr Authors +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Native Pydantic model support in Dapr workflows and activities. + +Inputs annotated with a Pydantic BaseModel are reconstructed automatically on +the receiving side — no manual serialization is needed. Outputs are emitted +as plain JSON so the wire format stays interop-friendly with non-Python Dapr +apps. +""" + +from time import sleep + +from dapr.ext.workflow import ( + DaprWorkflowClient, + DaprWorkflowContext, + WorkflowActivityContext, + WorkflowRuntime, +) +from pydantic import BaseModel + + +class OrderRequest(BaseModel): + order_id: str + customer: str + amount: float + + +class OrderResult(BaseModel): + order_id: str + approved: bool + message: str + + +wfr = WorkflowRuntime() +instance_id = 'pydantic-demo' + + +@wfr.workflow(name='order_workflow') +def order_workflow(ctx: DaprWorkflowContext, order: OrderRequest): + # `order` arrives as a real OrderRequest instance — the runtime reads the + # annotation and reconstructs the model from the decoded JSON automatically. + if not ctx.is_replaying: + print( + f'[workflow] received order {order.order_id} ' + f'for {order.customer} amount={order.amount}', + flush=True, + ) + raw = yield ctx.call_activity(approve_order, input=order) + # Activity results come back as a plain dict. One line turns them into a + # typed instance. + result = OrderResult.model_validate(raw) + if not ctx.is_replaying: + print( + f'[workflow] activity returned approved={result.approved}', + flush=True, + ) + return result + + +@wfr.activity(name='approve_order') +def approve_order(ctx: WorkflowActivityContext, order: OrderRequest) -> OrderResult: + # Same story: `order` is already an OrderRequest instance here. + print(f'[activity] approving order {order.order_id}', flush=True) + if order.amount <= 100.0: + return OrderResult(order_id=order.order_id, approved=True, message='auto-approved') + return OrderResult(order_id=order.order_id, approved=False, message='needs review') + + +def main(): + wfr.start() + sleep(5) + client = DaprWorkflowClient() + + order = OrderRequest(order_id='O-100', customer='Acme', amount=42.0) + client.schedule_new_workflow(workflow=order_workflow, input=order, instance_id=instance_id) + state = client.wait_for_workflow_completion(instance_id, timeout_in_seconds=30) + + # state.serialized_output is a JSON string — reconstruct a typed instance. + output = OrderResult.model_validate_json(state.serialized_output) + print( + f'[client] workflow output: order_id={output.order_id} ' + f'approved={output.approved} message={output.message}', + flush=True, + ) + + client.purge_workflow(instance_id) + wfr.shutdown() + + +if __name__ == '__main__': + main() diff --git a/examples/workflow/requirements.txt b/examples/workflow/requirements.txt index faabd0063..023e2c2b4 100644 --- a/examples/workflow/requirements.txt +++ b/examples/workflow/requirements.txt @@ -1,2 +1,3 @@ dapr-ext-workflow>=1.17.0.dev dapr>=1.17.0.dev +pydantic>=2.0 diff --git a/ext/dapr-ext-workflow/dapr/ext/workflow/_durabletask/internal/shared.py b/ext/dapr-ext-workflow/dapr/ext/workflow/_durabletask/internal/shared.py index 540887657..ffe144dfb 100644 --- a/ext/dapr-ext-workflow/dapr/ext/workflow/_durabletask/internal/shared.py +++ b/ext/dapr-ext-workflow/dapr/ext/workflow/_durabletask/internal/shared.py @@ -156,6 +156,18 @@ def encode(self, obj: Any) -> str: return super().encode(obj) def default(self, obj): + # Dapr-specific: objects implementing the duck-typed model protocol + # (model_dump + model_validate) are emitted as plain JSON objects with + # no AUTO_SERIALIZED marker, so the payload stays readable by + # non-Python SDKs and by workflows/activities that don't import the + # same class. Type-directed reconstruction happens at the + # activity/workflow input boundary in + # dapr.ext.workflow.workflow_runtime. No pydantic dependency — any + # class matching the protocol works (Pydantic v2, SQLModel, custom). + from dapr.ext.workflow import _model_protocol + + if _model_protocol.is_model(obj): + return _model_protocol.dump_model(obj) if dataclasses.is_dataclass(obj): # Dataclasses are not serializable by default, so we convert them to a dict and mark them for # automatic deserialization by the receiver diff --git a/ext/dapr-ext-workflow/dapr/ext/workflow/_model_protocol.py b/ext/dapr-ext-workflow/dapr/ext/workflow/_model_protocol.py new file mode 100644 index 000000000..fa1f29438 --- /dev/null +++ b/ext/dapr-ext-workflow/dapr/ext/workflow/_model_protocol.py @@ -0,0 +1,148 @@ +# -*- coding: utf-8 -*- + +""" +Copyright 2026 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +from __future__ import annotations + +import inspect +import typing +from functools import lru_cache +from types import SimpleNamespace +from typing import Any, Callable, Optional + +# A "model" here is anything that implements the Pydantic v2 shape: +# - model_dump(self, ...) -> dict +# - cls.model_validate(value) -> instance +# We duck-type on these names rather than importing pydantic so the SDK has no +# hard dependency on pydantic (or any specific version of it). SQLModel, +# FastAPI response models, and custom classes mirroring the protocol all work. + + +def is_model(obj: Any) -> bool: + """Whether obj implements the model protocol (model_dump + model_validate).""" + return is_model_class(type(obj)) + + +def is_model_class(cls: Any) -> bool: + """Whether cls is a class implementing the model protocol.""" + return ( + inspect.isclass(cls) + and callable(getattr(cls, 'model_dump', None)) + and callable(getattr(cls, 'model_validate', None)) + ) + + +@lru_cache(maxsize=None) +def _supports_mode_kwarg(cls: type) -> bool: + """Whether cls.model_dump accepts a `mode` keyword (Pydantic v2 signature).""" + try: + sig = inspect.signature(cls.model_dump) + except (TypeError, ValueError): + return False + params = sig.parameters + if 'mode' in params: + return True + return any(p.kind == inspect.Parameter.VAR_KEYWORD for p in params.values()) + + +def dump_model(model: Any) -> Any: + """Serialize a model instance to a JSON-compatible primitive graph. + + Prefers model_dump(mode='json') when supported so nested datetimes, enums, + and UUIDs render into JSON-safe primitives. Falls back to bare model_dump() + for protocol-compatible classes that don't accept the mode kwarg — those + classes are responsible for returning JSON-safe values themselves. + """ + if not is_model(model): + raise TypeError( + f'Expected a model-like object with model_dump/model_validate, ' + f'got {type(model).__name__}' + ) + cls = type(model) + if _supports_mode_kwarg(cls): + return model.model_dump(mode='json') + return model.model_dump() + + +def coerce_to_model(value: Any, cls: type) -> Any: + """Reconstruct a model instance from a decoded JSON payload. + + Accepts dicts, SimpleNamespace (from the InternalJSONDecoder's + AUTO_SERIALIZED path), or already-instantiated models. Any other shape + raises TypeError so the failure surfaces at the activity/workflow + boundary rather than later as an attribute access error. + """ + if not is_model_class(cls): + raise TypeError(f'{cls!r} is not a model class (no model_dump/model_validate)') + if isinstance(value, cls): + return value + if isinstance(value, SimpleNamespace): + value = vars(value) + if isinstance(value, dict): + return cls.model_validate(value) + raise TypeError( + f'Cannot coerce value of type {type(value).__name__} into {cls.__name__}; ' + 'expected a dict, SimpleNamespace, or existing model instance.' + ) + + +def resolve_input_model(fn: Callable[..., Any]) -> Optional[type]: + """Return the model class annotated on fn's input parameter, if any. + + Workflow and activity functions take (ctx, input) — we look at the second + positional parameter's annotation. Returns None when no annotation is + present or the annotation is not a model class. Optional[Model] and + Model | None are unwrapped to Model. + """ + try: + sig = inspect.signature(fn) + except (TypeError, ValueError): + return None + + params = list(sig.parameters.values()) + if len(params) < 2: + return None + + annotation = params[1].annotation + if annotation is inspect.Parameter.empty: + return None + + if isinstance(annotation, str): + try: + hints = typing.get_type_hints(fn) + annotation = hints.get(params[1].name, annotation) + except Exception: + return None + + annotation = _unwrap_optional(annotation) + return annotation if is_model_class(annotation) else None + + +def _unwrap_optional(annotation: Any) -> Any: + """Unwrap Optional[X] / X | None to X. Leaves other annotations unchanged.""" + origin = typing.get_origin(annotation) + if origin is typing.Union or _is_pep604_union(origin): + args = [a for a in typing.get_args(annotation) if a is not type(None)] + if len(args) == 1: + return args[0] + return annotation + + +def _is_pep604_union(origin: Any) -> bool: + try: + from types import UnionType # type: ignore[attr-defined] + + return origin is UnionType + except ImportError: + return False diff --git a/ext/dapr-ext-workflow/dapr/ext/workflow/workflow_runtime.py b/ext/dapr-ext-workflow/dapr/ext/workflow/workflow_runtime.py index 197254a89..2a62ab096 100644 --- a/ext/dapr-ext-workflow/dapr/ext/workflow/workflow_runtime.py +++ b/ext/dapr-ext-workflow/dapr/ext/workflow/workflow_runtime.py @@ -19,6 +19,7 @@ from typing import Optional, Sequence, TypeVar, Union import grpc +from dapr.ext.workflow import _model_protocol from dapr.ext.workflow._durabletask import task, worker from dapr.ext.workflow.dapr_workflow_context import DaprWorkflowContext from dapr.ext.workflow.logger import Logger, LoggerOptions @@ -89,6 +90,8 @@ def register_workflow(self, fn: Workflow, *, name: Optional[str] = None): effective_name = name or fn.__name__ self._logger.info(f"Registering workflow '{effective_name}' with runtime") + input_model = _model_protocol.resolve_input_model(fn) + def orchestrationWrapper(ctx: task.OrchestrationContext, inp: Optional[TInput] = None): """Responsible to call Workflow function in orchestrationWrapper""" instance_id = getattr(ctx, 'instance_id', 'unknown') @@ -98,6 +101,8 @@ def orchestrationWrapper(ctx: task.OrchestrationContext, inp: Optional[TInput] = if inp is None: result = fn(daprWfContext) else: + if input_model is not None and not isinstance(inp, input_model): + inp = _model_protocol.coerce_to_model(inp, input_model) result = fn(daprWfContext, inp) return result except Exception as e: @@ -131,11 +136,15 @@ def register_versioned_workflow( f"Registering version {version_name} of workflow '{effective_name}' with runtime" ) + input_model = _model_protocol.resolve_input_model(fn) + def orchestrationWrapper(ctx: task.OrchestrationContext, inp: Optional[TInput] = None): """Responsible to call Workflow function in orchestrationWrapper""" daprWfContext = DaprWorkflowContext(ctx, self._logger.get_options()) if inp is None: return fn(daprWfContext) + if input_model is not None and not isinstance(inp, input_model): + inp = _model_protocol.coerce_to_model(inp, input_model) return fn(daprWfContext, inp) if hasattr(fn, '_workflow_registered'): @@ -167,6 +176,8 @@ def register_activity(self, fn: Activity, *, name: Optional[str] = None): effective_name = name or fn.__name__ self._logger.info(f"Registering activity '{effective_name}' with runtime") + input_model = _model_protocol.resolve_input_model(fn) + def activityWrapper(ctx: task.ActivityContext, inp: Optional[TInput] = None): """Responsible to call Activity function in activityWrapper""" activity_id = getattr(ctx, 'task_id', 'unknown') @@ -176,6 +187,8 @@ def activityWrapper(ctx: task.ActivityContext, inp: Optional[TInput] = None): if inp is None: result = fn(wfActivityContext) else: + if input_model is not None and not isinstance(inp, input_model): + inp = _model_protocol.coerce_to_model(inp, input_model) result = fn(wfActivityContext, inp) return result except Exception as e: diff --git a/ext/dapr-ext-workflow/tests/durabletask/test_serialization.py b/ext/dapr-ext-workflow/tests/durabletask/test_serialization.py index 334bed174..2b9a79dfb 100644 --- a/ext/dapr-ext-workflow/tests/durabletask/test_serialization.py +++ b/ext/dapr-ext-workflow/tests/durabletask/test_serialization.py @@ -16,6 +16,7 @@ from types import SimpleNamespace from dapr.ext.workflow._durabletask.internal.shared import AUTO_SERIALIZED, from_json, to_json +from pydantic import BaseModel @dataclass @@ -85,3 +86,49 @@ def test_to_json_nested_dataclass_collection(): assert isinstance(decoded, list) assert [item.count for item in decoded] == [1, 2] assert [item.name for item in decoded] == ['first', 'second'] + + +class Order(BaseModel): + order_id: str + amount: float + + +class Item(BaseModel): + sku: str + qty: int + + +def test_to_json_pydantic_model_emits_plain_dict(): + encoded = to_json(Order(order_id='o1', amount=9.99)) + + # Pydantic payloads must not carry the AUTO_SERIALIZED marker so that + # cross-language and marker-unaware receivers can read them as plain JSON. + assert AUTO_SERIALIZED not in encoded + + decoded = from_json(encoded) + assert decoded == {'order_id': 'o1', 'amount': 9.99} + + +def test_to_json_pydantic_model_in_list_emits_plain_dicts(): + encoded = to_json([Item(sku='A', qty=1), Item(sku='B', qty=2)]) + + assert AUTO_SERIALIZED not in encoded + + decoded = from_json(encoded) + assert decoded == [{'sku': 'A', 'qty': 1}, {'sku': 'B', 'qty': 2}] + + +def test_to_json_pydantic_and_dataclass_coexist(): + payload = { + 'order': Order(order_id='o1', amount=1.0), + 'detail': SamplePayload(count=3, name='x'), + } + encoded = to_json(payload) + + # Dataclass still carries AUTO_SERIALIZED; Pydantic does not. + assert encoded.count(AUTO_SERIALIZED) == 1 + + decoded = from_json(encoded) + assert decoded['order'] == {'order_id': 'o1', 'amount': 1.0} + assert isinstance(decoded['detail'], SimpleNamespace) + assert decoded['detail'].count == 3 diff --git a/ext/dapr-ext-workflow/tests/test_model_protocol.py b/ext/dapr-ext-workflow/tests/test_model_protocol.py new file mode 100644 index 000000000..f815e6830 --- /dev/null +++ b/ext/dapr-ext-workflow/tests/test_model_protocol.py @@ -0,0 +1,158 @@ +# -*- coding: utf-8 -*- + +""" +Copyright 2026 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import unittest +from types import SimpleNamespace +from typing import Optional + +from dapr.ext.workflow import _model_protocol +from pydantic import BaseModel, ValidationError + + +class Order(BaseModel): + order_id: str + amount: float + + +class ModelProtocolTest(unittest.TestCase): + """Model-protocol helpers exercised against real Pydantic models.""" + + def test_is_model_recognizes_pydantic_instance(self): + self.assertTrue(_model_protocol.is_model(Order(order_id='o1', amount=1.0))) + + def test_is_model_class_recognizes_pydantic_class(self): + self.assertTrue(_model_protocol.is_model_class(Order)) + + def test_is_model_rejects_plain_objects(self): + self.assertFalse(_model_protocol.is_model(None)) + self.assertFalse(_model_protocol.is_model({'a': 1})) + self.assertFalse(_model_protocol.is_model(object())) + self.assertFalse(_model_protocol.is_model_class(dict)) + self.assertFalse(_model_protocol.is_model_class(None)) + + def test_dump_model_uses_json_mode(self): + dumped = _model_protocol.dump_model(Order(order_id='o1', amount=2.5)) + self.assertEqual(dumped, {'order_id': 'o1', 'amount': 2.5}) + + def test_dump_model_rejects_non_model(self): + with self.assertRaises(TypeError): + _model_protocol.dump_model({'order_id': 'o1', 'amount': 1.0}) + + def test_coerce_to_model_from_dict(self): + order = _model_protocol.coerce_to_model({'order_id': 'o1', 'amount': 3.0}, Order) + self.assertIsInstance(order, Order) + self.assertEqual(order.order_id, 'o1') + self.assertEqual(order.amount, 3.0) + + def test_coerce_to_model_from_simplenamespace(self): + ns = SimpleNamespace(order_id='o2', amount=4.0) + order = _model_protocol.coerce_to_model(ns, Order) + self.assertIsInstance(order, Order) + self.assertEqual(order.order_id, 'o2') + self.assertEqual(order.amount, 4.0) + + def test_coerce_to_model_passthrough_when_already_instance(self): + original = Order(order_id='o3', amount=5.0) + self.assertIs(_model_protocol.coerce_to_model(original, Order), original) + + def test_coerce_to_model_rejects_unsupported_shape(self): + with self.assertRaises(TypeError): + _model_protocol.coerce_to_model(42, Order) + with self.assertRaises(TypeError): + _model_protocol.coerce_to_model([1, 2, 3], Order) + + def test_coerce_to_model_rejects_non_model_class(self): + with self.assertRaises(TypeError): + _model_protocol.coerce_to_model({'x': 1}, dict) + + def test_coerce_to_model_raises_validation_error_on_invalid_payload(self): + with self.assertRaises(ValidationError): + _model_protocol.coerce_to_model({'order_id': 'o1'}, Order) # missing amount + + +class ResolveInputModelTest(unittest.TestCase): + def test_resolves_pydantic_annotation(self): + def my_activity(ctx, order: Order): + return order + + self.assertIs(_model_protocol.resolve_input_model(my_activity), Order) + + def test_unwraps_optional(self): + def my_activity(ctx, order: Optional[Order] = None): + return order + + self.assertIs(_model_protocol.resolve_input_model(my_activity), Order) + + def test_returns_none_when_no_annotation(self): + def my_activity(ctx, order): + return order + + self.assertIsNone(_model_protocol.resolve_input_model(my_activity)) + + def test_returns_none_for_non_model_annotation(self): + def my_activity(ctx, order: dict): + return order + + self.assertIsNone(_model_protocol.resolve_input_model(my_activity)) + + def test_returns_none_for_ctx_only(self): + def my_activity(ctx): + return None + + self.assertIsNone(_model_protocol.resolve_input_model(my_activity)) + + +class _DuckModelNoModeKwarg: + """Non-Pydantic class matching the model protocol without a mode kwarg. + + Exercises the _supports_mode_kwarg fallback path — real Pydantic v2 always + accepts `mode`, so this behavior needs a non-Pydantic class to hit. + """ + + def __init__(self, name: str, value: int): + self.name = name + self.value = value + + def model_dump(self) -> dict: + return {'name': self.name, 'value': self.value} + + @classmethod + def model_validate(cls, data: dict) -> '_DuckModelNoModeKwarg': + return cls(name=data['name'], value=data['value']) + + +class ProtocolOpennessTest(unittest.TestCase): + """The protocol is open to any class implementing model_dump/model_validate.""" + + def test_dump_falls_back_when_model_dump_has_no_mode_kwarg(self): + dumped = _model_protocol.dump_model(_DuckModelNoModeKwarg('x', 7)) + self.assertEqual(dumped, {'name': 'x', 'value': 7}) + + def test_is_model_class_rejects_partial_implementations(self): + class DumpOnly: + def model_dump(self): + return {} + + class ValidateOnly: + @classmethod + def model_validate(cls, data): + return cls() + + self.assertFalse(_model_protocol.is_model_class(DumpOnly)) + self.assertFalse(_model_protocol.is_model_class(ValidateOnly)) + + +if __name__ == '__main__': + unittest.main() diff --git a/ext/dapr-ext-workflow/tests/test_workflow_runtime.py b/ext/dapr-ext-workflow/tests/test_workflow_runtime.py index 0c6c2afd3..340ed9ac6 100644 --- a/ext/dapr-ext-workflow/tests/test_workflow_runtime.py +++ b/ext/dapr-ext-workflow/tests/test_workflow_runtime.py @@ -14,12 +14,19 @@ """ import unittest -from typing import List +from typing import List, Optional from unittest import mock from dapr.ext.workflow.dapr_workflow_context import DaprWorkflowContext from dapr.ext.workflow.workflow_activity_context import WorkflowActivityContext from dapr.ext.workflow.workflow_runtime import WorkflowRuntime, alternate_name +from pydantic import BaseModel, ValidationError + + +class Order(BaseModel): + order_id: str + amount: float + listOrchestrators: List[str] = [] listActivities: List[str] = [] @@ -630,3 +637,120 @@ def my_fn(ctx): with self.assertRaises(ValueError) as ctx: alternate_name(name='second')(my_fn) self.assertIn('already has an alternate name', str(ctx.exception)) + + +class PydanticInputCoercionTest(unittest.TestCase): + """Signature-directed Pydantic input coercion in workflow/activity wrappers.""" + + def setUp(self): + listActivities.clear() + listOrchestrators.clear() + self._registry_patch = mock.patch( + 'dapr.ext.workflow._durabletask.worker._Registry', return_value=FakeTaskHubGrpcWorker() + ) + self._registry_patch.start() + self.runtime = WorkflowRuntime() + self.fake_registry = self.runtime._WorkflowRuntime__worker._registry + + def tearDown(self): + mock.patch.stopall() + + def test_activity_wrapper_coerces_dict_to_pydantic_model(self): + received = {} + + def my_act(ctx, order: Order): + received['order'] = order + return order.amount * 2 + + self.runtime.register_activity(my_act, name='pydantic_act') + wrapper = self.fake_registry._activity_fns['pydantic_act'] + + result = wrapper(mock.MagicMock(), {'order_id': 'o1', 'amount': 5.0}) + self.assertIsInstance(received['order'], Order) + self.assertEqual(received['order'].order_id, 'o1') + self.assertEqual(result, 10.0) + + def test_workflow_wrapper_coerces_dict_to_pydantic_model(self): + received = {} + + def my_wf(ctx, order: Order): + received['order'] = order + return order.order_id + + self.runtime.register_workflow(my_wf, name='pydantic_wf') + wrapper = self.fake_registry._orchestrator_fns['pydantic_wf'] + + result = wrapper(mock.MagicMock(), {'order_id': 'o2', 'amount': 3.0}) + self.assertIsInstance(received['order'], Order) + self.assertEqual(result, 'o2') + + def test_activity_wrapper_passthrough_when_not_annotated(self): + def my_act(ctx, inp): + return inp + + self.runtime.register_activity(my_act, name='plain_act') + wrapper = self.fake_registry._activity_fns['plain_act'] + + payload = {'order_id': 'o3', 'amount': 1.0} + result = wrapper(mock.MagicMock(), payload) + self.assertIs(result, payload) + + def test_workflow_wrapper_passthrough_for_primitive_annotation(self): + def my_wf(ctx, n: int): + return n + 1 + + self.runtime.register_workflow(my_wf, name='int_wf') + wrapper = self.fake_registry._orchestrator_fns['int_wf'] + + result = wrapper(mock.MagicMock(), 41) + self.assertEqual(result, 42) + + def test_activity_wrapper_handles_optional_annotation(self): + def my_act(ctx, order: Optional[Order] = None): + return order + + self.runtime.register_activity(my_act, name='optional_act') + wrapper = self.fake_registry._activity_fns['optional_act'] + + self.assertIsNone(wrapper(mock.MagicMock(), None)) + result = wrapper(mock.MagicMock(), {'order_id': 'o4', 'amount': 7.0}) + self.assertIsInstance(result, Order) + self.assertEqual(result.amount, 7.0) + + def test_activity_wrapper_passes_through_existing_model_instance(self): + instance = Order(order_id='o5', amount=9.0) + + def my_act(ctx, order: Order): + return order + + self.runtime.register_activity(my_act, name='reuse_act') + wrapper = self.fake_registry._activity_fns['reuse_act'] + + result = wrapper(mock.MagicMock(), instance) + self.assertIs(result, instance) + + def test_activity_wrapper_raises_validation_error_for_invalid_payload(self): + def my_act(ctx, order: Order): + return order + + self.runtime.register_activity(my_act, name='invalid_act') + wrapper = self.fake_registry._activity_fns['invalid_act'] + + with self.assertRaises(ValidationError): + wrapper(mock.MagicMock(), {'order_id': 'o6'}) # missing amount + + def test_versioned_workflow_wrapper_coerces_input(self): + received = {} + + def my_wf(ctx, order: Order): + received['order'] = order + return order.order_id + + self.runtime.register_versioned_workflow( + my_wf, name='versioned_pydantic', version_name='v1', is_latest=True + ) + wrapper = self.fake_registry._orchestrator_fns['versioned_pydantic'] + + result = wrapper(mock.MagicMock(), {'order_id': 'v1', 'amount': 2.0}) + self.assertIsInstance(received['order'], Order) + self.assertEqual(result, 'v1') From 672a0ca07153c43d6f265d12e49e0556424fca09 Mon Sep 17 00:00:00 2001 From: Albert Callarisa Date: Fri, 17 Apr 2026 12:22:43 +0200 Subject: [PATCH 2/4] Address copilot comments Signed-off-by: Albert Callarisa --- examples/workflow/README.md | 2 +- examples/workflow/pydantic_models.py | 3 --- .../dapr/ext/workflow/_durabletask/internal/shared.py | 3 +-- ext/dapr-ext-workflow/dapr/ext/workflow/workflow_runtime.py | 3 ++- ext/dapr-ext-workflow/tests/test_model_protocol.py | 4 ---- 5 files changed, 4 insertions(+), 11 deletions(-) diff --git a/examples/workflow/README.md b/examples/workflow/README.md index 30210e667..ecc28193d 100644 --- a/examples/workflow/README.md +++ b/examples/workflow/README.md @@ -529,7 +529,7 @@ expected_stdout_lines: - "[activity] approving order O-100" - "[workflow] activity returned approved=True" - "[client] workflow output: order_id=O-100 approved=True message=auto-approved" -timeout_seconds: 30 +timeout_seconds: 60 --> ```sh diff --git a/examples/workflow/pydantic_models.py b/examples/workflow/pydantic_models.py index c4aefd79f..e9b1471ea 100644 --- a/examples/workflow/pydantic_models.py +++ b/examples/workflow/pydantic_models.py @@ -17,8 +17,6 @@ apps. """ -from time import sleep - from dapr.ext.workflow import ( DaprWorkflowClient, DaprWorkflowContext, @@ -77,7 +75,6 @@ def approve_order(ctx: WorkflowActivityContext, order: OrderRequest) -> OrderRes def main(): wfr.start() - sleep(5) client = DaprWorkflowClient() order = OrderRequest(order_id='O-100', customer='Acme', amount=42.0) diff --git a/ext/dapr-ext-workflow/dapr/ext/workflow/_durabletask/internal/shared.py b/ext/dapr-ext-workflow/dapr/ext/workflow/_durabletask/internal/shared.py index ffe144dfb..5c9bd9f9b 100644 --- a/ext/dapr-ext-workflow/dapr/ext/workflow/_durabletask/internal/shared.py +++ b/ext/dapr-ext-workflow/dapr/ext/workflow/_durabletask/internal/shared.py @@ -17,6 +17,7 @@ from typing import Any, Optional, Sequence, Union import grpc +from dapr.ext.workflow import _model_protocol ClientInterceptor = Union[ grpc.UnaryUnaryClientInterceptor, @@ -164,8 +165,6 @@ def default(self, obj): # activity/workflow input boundary in # dapr.ext.workflow.workflow_runtime. No pydantic dependency — any # class matching the protocol works (Pydantic v2, SQLModel, custom). - from dapr.ext.workflow import _model_protocol - if _model_protocol.is_model(obj): return _model_protocol.dump_model(obj) if dataclasses.is_dataclass(obj): diff --git a/ext/dapr-ext-workflow/dapr/ext/workflow/workflow_runtime.py b/ext/dapr-ext-workflow/dapr/ext/workflow/workflow_runtime.py index 2a62ab096..0a0c300ef 100644 --- a/ext/dapr-ext-workflow/dapr/ext/workflow/workflow_runtime.py +++ b/ext/dapr-ext-workflow/dapr/ext/workflow/workflow_runtime.py @@ -19,7 +19,6 @@ from typing import Optional, Sequence, TypeVar, Union import grpc -from dapr.ext.workflow import _model_protocol from dapr.ext.workflow._durabletask import task, worker from dapr.ext.workflow.dapr_workflow_context import DaprWorkflowContext from dapr.ext.workflow.logger import Logger, LoggerOptions @@ -32,6 +31,8 @@ from dapr.conf import settings from dapr.conf.helpers import GrpcEndpoint +from . import _model_protocol + T = TypeVar('T') TInput = TypeVar('TInput') TOutput = TypeVar('TOutput') diff --git a/ext/dapr-ext-workflow/tests/test_model_protocol.py b/ext/dapr-ext-workflow/tests/test_model_protocol.py index f815e6830..074c6f8eb 100644 --- a/ext/dapr-ext-workflow/tests/test_model_protocol.py +++ b/ext/dapr-ext-workflow/tests/test_model_protocol.py @@ -152,7 +152,3 @@ def model_validate(cls, data): self.assertFalse(_model_protocol.is_model_class(DumpOnly)) self.assertFalse(_model_protocol.is_model_class(ValidateOnly)) - - -if __name__ == '__main__': - unittest.main() From b0a841d1f18eb5cd5b21e82ba53b66ba5229451a Mon Sep 17 00:00:00 2001 From: Albert Callarisa Date: Fri, 17 Apr 2026 12:42:48 +0200 Subject: [PATCH 3/4] Address copilot comments Signed-off-by: Albert Callarisa --- .../dapr/ext/workflow/_model_protocol.py | 27 ++++++++++--------- .../dapr/ext/workflow/workflow_runtime.py | 26 +++++++++++------- .../tests/test_model_protocol.py | 18 ++++++------- .../tests/test_workflow_runtime.py | 11 ++++++++ 4 files changed, 52 insertions(+), 30 deletions(-) diff --git a/ext/dapr-ext-workflow/dapr/ext/workflow/_model_protocol.py b/ext/dapr-ext-workflow/dapr/ext/workflow/_model_protocol.py index fa1f29438..d11624c2f 100644 --- a/ext/dapr-ext-workflow/dapr/ext/workflow/_model_protocol.py +++ b/ext/dapr-ext-workflow/dapr/ext/workflow/_model_protocol.py @@ -97,36 +97,39 @@ def coerce_to_model(value: Any, cls: type) -> Any: ) -def resolve_input_model(fn: Callable[..., Any]) -> Optional[type]: - """Return the model class annotated on fn's input parameter, if any. - - Workflow and activity functions take (ctx, input) — we look at the second - positional parameter's annotation. Returns None when no annotation is - present or the annotation is not a model class. Optional[Model] and - Model | None are unwrapped to Model. +def resolve_input(fn: Callable[..., Any]) -> tuple[bool, Optional[type]]: + """Inspect fn's input parameter. + + Returns (accepts_input, model_class): + - accepts_input is True when fn declares a second positional parameter + (beyond the context) — the runtime must then pass the input through + even when it is None, so `Optional[Model]` works without a default. + - model_class is the model class annotated on that parameter, or None + when there is no annotation or the annotation is not a model. + Optional[Model] and Model | None are unwrapped to Model. """ try: sig = inspect.signature(fn) except (TypeError, ValueError): - return None + return False, None params = list(sig.parameters.values()) if len(params) < 2: - return None + return False, None annotation = params[1].annotation if annotation is inspect.Parameter.empty: - return None + return True, None if isinstance(annotation, str): try: hints = typing.get_type_hints(fn) annotation = hints.get(params[1].name, annotation) except Exception: - return None + return True, None annotation = _unwrap_optional(annotation) - return annotation if is_model_class(annotation) else None + return True, (annotation if is_model_class(annotation) else None) def _unwrap_optional(annotation: Any) -> Any: diff --git a/ext/dapr-ext-workflow/dapr/ext/workflow/workflow_runtime.py b/ext/dapr-ext-workflow/dapr/ext/workflow/workflow_runtime.py index 0a0c300ef..e0627377d 100644 --- a/ext/dapr-ext-workflow/dapr/ext/workflow/workflow_runtime.py +++ b/ext/dapr-ext-workflow/dapr/ext/workflow/workflow_runtime.py @@ -91,7 +91,7 @@ def register_workflow(self, fn: Workflow, *, name: Optional[str] = None): effective_name = name or fn.__name__ self._logger.info(f"Registering workflow '{effective_name}' with runtime") - input_model = _model_protocol.resolve_input_model(fn) + accepts_input, input_model = _model_protocol.resolve_input(fn) def orchestrationWrapper(ctx: task.OrchestrationContext, inp: Optional[TInput] = None): """Responsible to call Workflow function in orchestrationWrapper""" @@ -99,10 +99,14 @@ def orchestrationWrapper(ctx: task.OrchestrationContext, inp: Optional[TInput] = try: daprWfContext = DaprWorkflowContext(ctx, self._logger.get_options()) - if inp is None: + if not accepts_input: result = fn(daprWfContext) else: - if input_model is not None and not isinstance(inp, input_model): + if ( + inp is not None + and input_model is not None + and not isinstance(inp, input_model) + ): inp = _model_protocol.coerce_to_model(inp, input_model) result = fn(daprWfContext, inp) return result @@ -137,14 +141,14 @@ def register_versioned_workflow( f"Registering version {version_name} of workflow '{effective_name}' with runtime" ) - input_model = _model_protocol.resolve_input_model(fn) + accepts_input, input_model = _model_protocol.resolve_input(fn) def orchestrationWrapper(ctx: task.OrchestrationContext, inp: Optional[TInput] = None): """Responsible to call Workflow function in orchestrationWrapper""" daprWfContext = DaprWorkflowContext(ctx, self._logger.get_options()) - if inp is None: + if not accepts_input: return fn(daprWfContext) - if input_model is not None and not isinstance(inp, input_model): + if inp is not None and input_model is not None and not isinstance(inp, input_model): inp = _model_protocol.coerce_to_model(inp, input_model) return fn(daprWfContext, inp) @@ -177,7 +181,7 @@ def register_activity(self, fn: Activity, *, name: Optional[str] = None): effective_name = name or fn.__name__ self._logger.info(f"Registering activity '{effective_name}' with runtime") - input_model = _model_protocol.resolve_input_model(fn) + accepts_input, input_model = _model_protocol.resolve_input(fn) def activityWrapper(ctx: task.ActivityContext, inp: Optional[TInput] = None): """Responsible to call Activity function in activityWrapper""" @@ -185,10 +189,14 @@ def activityWrapper(ctx: task.ActivityContext, inp: Optional[TInput] = None): try: wfActivityContext = WorkflowActivityContext(ctx) - if inp is None: + if not accepts_input: result = fn(wfActivityContext) else: - if input_model is not None and not isinstance(inp, input_model): + if ( + inp is not None + and input_model is not None + and not isinstance(inp, input_model) + ): inp = _model_protocol.coerce_to_model(inp, input_model) result = fn(wfActivityContext, inp) return result diff --git a/ext/dapr-ext-workflow/tests/test_model_protocol.py b/ext/dapr-ext-workflow/tests/test_model_protocol.py index 074c6f8eb..bcfa6b1d6 100644 --- a/ext/dapr-ext-workflow/tests/test_model_protocol.py +++ b/ext/dapr-ext-workflow/tests/test_model_protocol.py @@ -82,36 +82,36 @@ def test_coerce_to_model_raises_validation_error_on_invalid_payload(self): _model_protocol.coerce_to_model({'order_id': 'o1'}, Order) # missing amount -class ResolveInputModelTest(unittest.TestCase): +class ResolveInputTest(unittest.TestCase): def test_resolves_pydantic_annotation(self): def my_activity(ctx, order: Order): return order - self.assertIs(_model_protocol.resolve_input_model(my_activity), Order) + self.assertEqual(_model_protocol.resolve_input(my_activity), (True, Order)) def test_unwraps_optional(self): def my_activity(ctx, order: Optional[Order] = None): return order - self.assertIs(_model_protocol.resolve_input_model(my_activity), Order) + self.assertEqual(_model_protocol.resolve_input(my_activity), (True, Order)) - def test_returns_none_when_no_annotation(self): + def test_accepts_input_without_annotation(self): def my_activity(ctx, order): return order - self.assertIsNone(_model_protocol.resolve_input_model(my_activity)) + self.assertEqual(_model_protocol.resolve_input(my_activity), (True, None)) - def test_returns_none_for_non_model_annotation(self): + def test_accepts_input_with_non_model_annotation(self): def my_activity(ctx, order: dict): return order - self.assertIsNone(_model_protocol.resolve_input_model(my_activity)) + self.assertEqual(_model_protocol.resolve_input(my_activity), (True, None)) - def test_returns_none_for_ctx_only(self): + def test_ctx_only_does_not_accept_input(self): def my_activity(ctx): return None - self.assertIsNone(_model_protocol.resolve_input_model(my_activity)) + self.assertEqual(_model_protocol.resolve_input(my_activity), (False, None)) class _DuckModelNoModeKwarg: diff --git a/ext/dapr-ext-workflow/tests/test_workflow_runtime.py b/ext/dapr-ext-workflow/tests/test_workflow_runtime.py index 340ed9ac6..233bd032f 100644 --- a/ext/dapr-ext-workflow/tests/test_workflow_runtime.py +++ b/ext/dapr-ext-workflow/tests/test_workflow_runtime.py @@ -754,3 +754,14 @@ def my_wf(ctx, order: Order): result = wrapper(mock.MagicMock(), {'order_id': 'v1', 'amount': 2.0}) self.assertIsInstance(received['order'], Order) self.assertEqual(result, 'v1') + + def test_activity_wrapper_passes_none_to_fn_that_expects_input(self): + """Regression: Optional[Model] without a default must receive None, not be dropped.""" + + def my_act(ctx, order: Optional[Order]): + return order + + self.runtime.register_activity(my_act, name='optional_no_default_act') + wrapper = self.fake_registry._activity_fns['optional_no_default_act'] + + self.assertIsNone(wrapper(mock.MagicMock(), None)) From e77c2b12147dfa164f48674bb500965d6d8e2a87 Mon Sep 17 00:00:00 2001 From: Albert Callarisa Date: Mon, 20 Apr 2026 11:08:20 +0200 Subject: [PATCH 4/4] cleaner conditions Signed-off-by: Albert Callarisa --- .../dapr/ext/workflow/workflow_runtime.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/ext/dapr-ext-workflow/dapr/ext/workflow/workflow_runtime.py b/ext/dapr-ext-workflow/dapr/ext/workflow/workflow_runtime.py index e0627377d..11bae78ac 100644 --- a/ext/dapr-ext-workflow/dapr/ext/workflow/workflow_runtime.py +++ b/ext/dapr-ext-workflow/dapr/ext/workflow/workflow_runtime.py @@ -103,8 +103,8 @@ def orchestrationWrapper(ctx: task.OrchestrationContext, inp: Optional[TInput] = result = fn(daprWfContext) else: if ( - inp is not None - and input_model is not None + (inp is not None) + and (input_model is not None) and not isinstance(inp, input_model) ): inp = _model_protocol.coerce_to_model(inp, input_model) @@ -148,7 +148,7 @@ def orchestrationWrapper(ctx: task.OrchestrationContext, inp: Optional[TInput] = daprWfContext = DaprWorkflowContext(ctx, self._logger.get_options()) if not accepts_input: return fn(daprWfContext) - if inp is not None and input_model is not None and not isinstance(inp, input_model): + if (inp is not None) and (input_model is not None) and not isinstance(inp, input_model): inp = _model_protocol.coerce_to_model(inp, input_model) return fn(daprWfContext, inp) @@ -193,8 +193,8 @@ def activityWrapper(ctx: task.ActivityContext, inp: Optional[TInput] = None): result = fn(wfActivityContext) else: if ( - inp is not None - and input_model is not None + (inp is not None) + and (input_model is not None) and not isinstance(inp, input_model) ): inp = _model_protocol.coerce_to_model(inp, input_model)