Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 45 additions & 0 deletions cloud_pipelines_backend/event_listeners.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
import dataclasses
import threading
import typing
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Google style guide for typing would be to not import it as a module, I believe we do that in other places in this repo as well.

https://google.github.io/styleguide/pyguide.html#31912-imports-for-typing

image.png


_CallbackEntry: typing.TypeAlias = tuple[typing.Callable[..., None], bool]


@dataclasses.dataclass(frozen=True, kw_only=True)
class Event:
"""Marker base class for all event types."""


_EventType = typing.TypeVar("_EventType", bound=Event)

_listeners: dict[type, list[_CallbackEntry]] = {}
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

FYI - there is typing.Final which will have linting to make sure this variable can't be assigned. Not necessary needed here, but just an FYI when I create global variables.



def subscribe(
*,
event_type: type[_EventType],
callback: typing.Callable[[_EventType], None],
asynchronous: bool = True,
) -> None:
"""Subscribe callback to event_type. Called once at startup per consumer.

Args:
event_type: The event class to subscribe to.
callback: Called with the event instance when an event of that type is emitted.
asynchronous: When True (default), the callback is dispatched on a
separate daemon thread so emit() returns immediately. When False,
the callback is invoked synchronously on the calling thread.
"""
_listeners.setdefault(event_type, []).append((callback, asynchronous))
Comment thread
yuechao-qin marked this conversation as resolved.


def emit(
*,
event: _EventType,
) -> None:
"""Dispatch event to all callbacks subscribed to its type."""
for callback, asynchronous in _listeners.get(type(event), []):
if asynchronous:
threading.Thread(target=callback, args=(event,), daemon=True).start()
else:
callback(event)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe wrapping a try/catch here would be important to not break the emit for loop?

Add a new test for this too.

Copy link
Copy Markdown
Collaborator Author

@morgan-wowk morgan-wowk Apr 9, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm...I guess that would help ensure at least the "async" listeners are guarenteed. But a synchronous listener should prevent any further action / subsequent synchronous listners ffrom executing.

Does that sound good? We add a try catch to ensure the async listeners are always called?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We could also just dispatch the asychronous listeners first instead of having a try catch

189 changes: 189 additions & 0 deletions tests/test_event_listeners.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,189 @@
"""Tests for cloud_pipelines_backend.event_listeners."""

import threading
import typing

import pytest

from cloud_pipelines_backend import event_listeners


@pytest.fixture(autouse=True)
def reset_listeners() -> typing.Generator[None, None, None]:
"""Clear _listeners before and after every test."""
event_listeners._listeners.clear()
yield
event_listeners._listeners.clear()


class _SampleEvent(event_listeners.Event):
"""Minimal Event subclass for testing."""


class _AnotherEvent(event_listeners.Event):
"""Second Event subclass to verify type-keyed dispatch."""


class TestSubscribeAndEmitSync:
def test_callback_is_called_with_event(self) -> None:
received: list[_SampleEvent] = []
event = _SampleEvent()

event_listeners.subscribe(
event_type=_SampleEvent,
callback=received.append,
asynchronous=False,
)
event_listeners.emit(event=event)

assert received == [event]

def test_multiple_callbacks_all_called_in_order(self) -> None:
calls: list[int] = []

event_listeners.subscribe(
event_type=_SampleEvent,
callback=lambda _e: calls.append(1),
asynchronous=False,
)
event_listeners.subscribe(
event_type=_SampleEvent,
callback=lambda _e: calls.append(2),
asynchronous=False,
)
event_listeners.emit(event=_SampleEvent())

assert calls == [1, 2]

def test_emit_without_subscribers_is_noop(self) -> None:
event_listeners.emit(event=_SampleEvent())

def test_callbacks_only_fired_for_matching_event_type(self) -> None:
calls: list[str] = []

event_listeners.subscribe(
event_type=_SampleEvent,
callback=lambda _e: calls.append("sample"),
asynchronous=False,
)
event_listeners.subscribe(
event_type=_AnotherEvent,
callback=lambda _e: calls.append("another"),
asynchronous=False,
)

event_listeners.emit(event=_SampleEvent())

assert calls == ["sample"]


class TestAsynchronousDispatch:
def test_async_callback_runs_on_separate_thread(self) -> None:
callback_thread_ident: list[int] = []
done = threading.Event()

def callback(_e: _SampleEvent) -> None:
callback_thread_ident.append(threading.current_thread().ident)
done.set()

event_listeners.subscribe(
event_type=_SampleEvent,
callback=callback,
asynchronous=True,
)
event_listeners.emit(event=_SampleEvent())

assert done.wait(timeout=2.0), "async callback did not fire within 2 s"
assert callback_thread_ident[0] != threading.main_thread().ident

def test_asynchronous_defaults_to_true(self) -> None:
callback_thread_ident: list[int] = []
done = threading.Event()

def callback(_e: _SampleEvent) -> None:
callback_thread_ident.append(threading.current_thread().ident)
done.set()

event_listeners.subscribe(event_type=_SampleEvent, callback=callback)
event_listeners.emit(event=_SampleEvent())

assert done.wait(timeout=2.0), "default async callback did not fire within 2 s"
assert callback_thread_ident[0] != threading.main_thread().ident

def test_sync_callback_runs_on_calling_thread(self) -> None:
callback_thread_ident: list[int] = []

event_listeners.subscribe(
event_type=_SampleEvent,
callback=lambda _e: callback_thread_ident.append(
threading.current_thread().ident
),
asynchronous=False,
)
event_listeners.emit(event=_SampleEvent())

assert callback_thread_ident == [threading.main_thread().ident]

def test_mixed_sync_and_async_subscribers(self) -> None:
sync_thread_ident: list[int] = []
async_thread_ident: list[int] = []
async_done = threading.Event()

def sync_callback(_e: _SampleEvent) -> None:
sync_thread_ident.append(threading.current_thread().ident)

def async_callback(_e: _SampleEvent) -> None:
async_thread_ident.append(threading.current_thread().ident)
async_done.set()

event_listeners.subscribe(
event_type=_SampleEvent,
callback=sync_callback,
asynchronous=False,
)
event_listeners.subscribe(
event_type=_SampleEvent,
callback=async_callback,
asynchronous=True,
)

event_listeners.emit(event=_SampleEvent())

assert sync_thread_ident == [threading.main_thread().ident]
assert async_done.wait(timeout=2.0), "async callback did not fire"
assert async_thread_ident[0] != threading.main_thread().ident

def test_exception_in_async_callback_does_not_propagate_to_caller(self) -> None:
"""A runtime exception raised inside an async callback must not reach emit()'s caller."""
callback_ran = threading.Event()
exception_raised = threading.Event()
caught_exceptions: list[BaseException] = []

original_excepthook = threading.excepthook

def _capture(args: threading.ExceptHookArgs) -> None:
caught_exceptions.append(args.exc_value)
exception_raised.set()

threading.excepthook = _capture
try:
def raising_callback(_event: _SampleEvent) -> None:
callback_ran.set()
raise RuntimeError("boom")

event_listeners.subscribe(
event_type=_SampleEvent,
callback=raising_callback,
asynchronous=True,
)

# emit() must return normally even though the callback will raise.
event_listeners.emit(event=_SampleEvent())
main_thread_continued = True

assert callback_ran.wait(timeout=2.0), "async callback did not run"
assert exception_raised.wait(timeout=2.0), "thread exception hook did not fire"
assert main_thread_continued
assert isinstance(caught_exceptions[0], RuntimeError)
finally:
threading.excepthook = original_excepthook