From f6d805d6b1374880ee72892acfe23b02bcc71d06 Mon Sep 17 00:00:00 2001 From: Abigail Emery Date: Tue, 24 Feb 2026 17:20:02 +0000 Subject: [PATCH 01/21] wip --- src/blueapi/service/interface.py | 15 +++++++++++++ src/blueapi/service/main.py | 38 +++++++++++++++++++++++++++++++- 2 files changed, 52 insertions(+), 1 deletion(-) diff --git a/src/blueapi/service/interface.py b/src/blueapi/service/interface.py index c8b1e07245..d0c3dc6d0d 100644 --- a/src/blueapi/service/interface.py +++ b/src/blueapi/service/interface.py @@ -1,5 +1,6 @@ from collections.abc import Mapping from functools import cache +from multiprocessing.connection import Connection from typing import Any from bluesky.callbacks.tiled_writer import TiledWriter @@ -278,3 +279,17 @@ def get_python_env( """Retrieve information about the Python environment""" scratch = config().scratch return get_python_environment(config=scratch, name=name, source=source) + + +def pipe_events(tx: Connection) -> int: + + def handler( + worker_event: WorkerEvent, + cor_id: str | None, + ) -> None: + tx.send(worker_event) + + task_worker = worker() + sub_id = task_worker.worker_events.subscribe(handler) + + return sub_id diff --git a/src/blueapi/service/main.py b/src/blueapi/service/main.py index 81f9f849ee..dda91482a7 100644 --- a/src/blueapi/service/main.py +++ b/src/blueapi/service/main.py @@ -2,6 +2,7 @@ import urllib.parse from collections.abc import Awaitable, Callable from contextlib import asynccontextmanager +from multiprocessing import Pipe from typing import Annotated, Any import jwt @@ -14,8 +15,10 @@ HTTPException, Request, Response, + WebSocket, status, ) +from fastapi.concurrency import run_in_threadpool from fastapi.datastructures import Address from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import RedirectResponse, StreamingResponse @@ -38,7 +41,7 @@ from blueapi.config import ApplicationConfig, OIDCConfig, Tag from blueapi.service import interface from blueapi.worker import TrackableTask, WorkerState -from blueapi.worker.event import TaskStatusEnum +from blueapi.worker.event import TaskStatusEnum, WorkerEvent from .model import ( DeviceModel, @@ -541,6 +544,39 @@ def logout(runner: Annotated[WorkerDispatcher, Depends(_runner)]) -> Response: ) +@secure_router.websocket("/run_plan") +async def run_plan( + ws: WebSocket, + runner: Annotated[WorkerDispatcher, Depends(_runner)], +): + user = "alice" + + # ack ws + await ws.accept() + # accept task request through socket + rq = await ws.receive_json() + # submit task to runner + task_request: TaskRequest = TaskRequest.model_validate(rq) + task_id: str = runner.run(interface.submit_task, task_request, {"user": user}) + # add listener to runner + tx, rx = Pipe() + h = runner.run(interface.pipe_events, tx=tx) + # start task + task = WorkerTask(task_id=task_id) + runner.run( + interface.begin_task, + task=task, + ) + # pipe events to ws + while True: + event: WorkerEvent = await run_in_threadpool(rx.recv) + await ws.send_json(event.model_dump(mode="json")) + if event.is_complete(): + break + # ??? + # profit + + @start_as_current_span(TRACER, "config") def start(config: ApplicationConfig): import uvicorn From 6e1e9dccae001dfe66f0a55cc34ec89548d5b37d Mon Sep 17 00:00:00 2001 From: Abigail Emery Date: Tue, 24 Feb 2026 17:37:08 +0000 Subject: [PATCH 02/21] client wip --- src/blueapi/cli/cli.py | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/src/blueapi/cli/cli.py b/src/blueapi/cli/cli.py index bd2154f0a1..f04d98d126 100644 --- a/src/blueapi/cli/cli.py +++ b/src/blueapi/cli/cli.py @@ -382,6 +382,29 @@ def on_event(event: AnyEvent) -> None: raise ClickException(f"task could not run: {ve}") from ve +@controller.command(name="ws") +@click.argument("name", type=str) +@click.argument("parameters", type=ParametersType(), default={}, required=False) +def run_blocking( + name: str, + parameters: TaskParameters, +): + instrument_session = "cm33-3" + + from websockets.sync.client import connect + + task_req = TaskRequest( + name=name, + params=parameters, + instrument_session=instrument_session, + ) + + with connect("ws://localhost:8007/run_plan") as ws: + ws.send(task_req.model_dump_json()) + while message := ws.recv(): + print(message) + + @controller.command(name="state") @click.pass_obj @check_connection From 04cb661d8b26e0d1de28f8294041bdaf8084408f Mon Sep 17 00:00:00 2001 From: Abigail Emery Date: Tue, 24 Feb 2026 17:48:31 +0000 Subject: [PATCH 03/21] use normal iter --- src/blueapi/cli/cli.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/blueapi/cli/cli.py b/src/blueapi/cli/cli.py index f04d98d126..66c1c59a75 100644 --- a/src/blueapi/cli/cli.py +++ b/src/blueapi/cli/cli.py @@ -401,7 +401,7 @@ def run_blocking( with connect("ws://localhost:8007/run_plan") as ws: ws.send(task_req.model_dump_json()) - while message := ws.recv(): + for message in ws: print(message) From 9ad35bb024a9811f4e6e88908120f8f29b650085 Mon Sep 17 00:00:00 2001 From: Abigail Emery Date: Tue, 24 Feb 2026 17:48:48 +0000 Subject: [PATCH 04/21] close ws --- src/blueapi/service/main.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/blueapi/service/main.py b/src/blueapi/service/main.py index dda91482a7..24f4b29396 100644 --- a/src/blueapi/service/main.py +++ b/src/blueapi/service/main.py @@ -573,8 +573,7 @@ async def run_plan( await ws.send_json(event.model_dump(mode="json")) if event.is_complete(): break - # ??? - # profit + await ws.close() @start_as_current_span(TRACER, "config") From 56a51eb079235c314da753e111ec1cb82b3f9dc1 Mon Sep 17 00:00:00 2001 From: Abigail Emery Date: Tue, 24 Feb 2026 18:41:02 +0000 Subject: [PATCH 05/21] add some trys --- src/blueapi/service/main.py | 39 ++++++++++++++++++++++++------------- 1 file changed, 26 insertions(+), 13 deletions(-) diff --git a/src/blueapi/service/main.py b/src/blueapi/service/main.py index 24f4b29396..74531bb332 100644 --- a/src/blueapi/service/main.py +++ b/src/blueapi/service/main.py @@ -42,6 +42,7 @@ from blueapi.service import interface from blueapi.worker import TrackableTask, WorkerState from blueapi.worker.event import TaskStatusEnum, WorkerEvent +from blueapi.worker.worker_errors import WorkerBusyError from .model import ( DeviceModel, @@ -556,24 +557,36 @@ async def run_plan( # accept task request through socket rq = await ws.receive_json() # submit task to runner - task_request: TaskRequest = TaskRequest.model_validate(rq) - task_id: str = runner.run(interface.submit_task, task_request, {"user": user}) + try: + task_request: TaskRequest = TaskRequest.model_validate(rq) + task_id: str = runner.run(interface.submit_task, task_request, {"user": user}) + except ValidationError: + await ws.close(code=1003, reason="invalid args") + return + # add listener to runner tx, rx = Pipe() h = runner.run(interface.pipe_events, tx=tx) # start task - task = WorkerTask(task_id=task_id) - runner.run( - interface.begin_task, - task=task, - ) + try: + task = WorkerTask(task_id=task_id) + runner.run( + interface.begin_task, + task=task, + ) + except WorkerBusyError: + await ws.close(code=1013, reason="Worker busy") + return # pipe events to ws - while True: - event: WorkerEvent = await run_in_threadpool(rx.recv) - await ws.send_json(event.model_dump(mode="json")) - if event.is_complete(): - break - await ws.close() + try: + while True: + event: WorkerEvent = await run_in_threadpool(rx.recv) + await ws.send_json(event.model_dump(mode="json")) + if event.is_complete(): + break + finally: + await ws.close() + runner.run(interface.unpipe_events, h=h) @start_as_current_span(TRACER, "config") From 78e076bc57b5654c9144cbd039b00bda80a32a5d Mon Sep 17 00:00:00 2001 From: Abigail Emery Date: Tue, 24 Feb 2026 18:42:16 +0000 Subject: [PATCH 06/21] unpipe --- src/blueapi/service/interface.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/src/blueapi/service/interface.py b/src/blueapi/service/interface.py index d0c3dc6d0d..7e35833698 100644 --- a/src/blueapi/service/interface.py +++ b/src/blueapi/service/interface.py @@ -1,3 +1,4 @@ +import logging from collections.abc import Mapping from functools import cache from multiprocessing.connection import Connection @@ -23,6 +24,7 @@ WorkerTask, ) from blueapi.utils.serialization import access_blob +from blueapi.worker import task_worker from blueapi.worker.event import TaskStatusEnum, WorkerEvent, WorkerState from blueapi.worker.task import Task from blueapi.worker.task_worker import TaskWorker, TrackableTask @@ -30,7 +32,7 @@ """This module provides interface between web application and underlying Bluesky context and worker""" - +LOGGER = logging.getLogger(__name__) _CONFIG: ApplicationConfig = ApplicationConfig() @@ -287,9 +289,14 @@ def handler( worker_event: WorkerEvent, cor_id: str | None, ) -> None: + LOGGER.info("Sending event") tx.send(worker_event) task_worker = worker() sub_id = task_worker.worker_events.subscribe(handler) - return sub_id + + +def unpipe_events(h: int) -> None: + task_worker = worker() + task_worker.worker_events.unsubscribe(h) From 77ec78db3ca24788f7f7a54f48c7bdc8e2da4209 Mon Sep 17 00:00:00 2001 From: Peter Holloway Date: Tue, 3 Mar 2026 11:52:21 +0000 Subject: [PATCH 07/21] Move websocket handling into BlueapiRestClient --- src/blueapi/cli/cli.py | 25 ++++++++++++++----------- src/blueapi/client/client.py | 4 ++++ src/blueapi/client/rest.py | 17 ++++++++++++++++- 3 files changed, 34 insertions(+), 12 deletions(-) diff --git a/src/blueapi/cli/cli.py b/src/blueapi/cli/cli.py index 66c1c59a75..6c2fc84ee4 100644 --- a/src/blueapi/cli/cli.py +++ b/src/blueapi/cli/cli.py @@ -383,26 +383,29 @@ def on_event(event: AnyEvent) -> None: @controller.command(name="ws") +@click.pass_obj @click.argument("name", type=str) @click.argument("parameters", type=ParametersType(), default={}, required=False) +@click.option( + "-i", + "--instrument-session", + type=str, + help=textwrap.dedent(""" + Instrument session associated with running the plan, + used to tell blueapi where to store any data and as a security check: + the session must be valid and active and you must be a member of it."""), + required=True, +) def run_blocking( - name: str, - parameters: TaskParameters, + obj: dict, name: str, parameters: TaskParameters, instrument_session: str ): - instrument_session = "cm33-3" - - from websockets.sync.client import connect - + client = cast(BlueapiClient, obj["client"]) task_req = TaskRequest( name=name, params=parameters, instrument_session=instrument_session, ) - - with connect("ws://localhost:8007/run_plan") as ws: - ws.send(task_req.model_dump_json()) - for message in ws: - print(message) + client.run_blocking(task_req) @controller.command(name="state") diff --git a/src/blueapi/client/client.py b/src/blueapi/client/client.py index c5b41ff45e..0638f911f6 100644 --- a/src/blueapi/client/client.py +++ b/src/blueapi/client/client.py @@ -439,6 +439,10 @@ def get_active_task(self) -> WorkerTask: return self.active_task + @start_as_current_span(TRACER, "request") + def run_blocking(self, request: TaskRequest): + self._rest.run_blocking(request) + @start_as_current_span(TRACER, "task", "timeout") def run_task( self, diff --git a/src/blueapi/client/rest.py b/src/blueapi/client/rest.py index 6caa0add20..b758be3ecb 100644 --- a/src/blueapi/client/rest.py +++ b/src/blueapi/client/rest.py @@ -10,7 +10,8 @@ get_tracer, start_as_current_span, ) -from pydantic import BaseModel, TypeAdapter, ValidationError +from pydantic import BaseModel, TypeAdapter, ValidationError, WebsocketUrl +from websockets.sync.client import connect from blueapi import __version__ from blueapi.config import RestConfig @@ -308,6 +309,20 @@ def _request_and_deserialize( ) return deserialized + def run_blocking(self, req: TaskRequest): + url = self._ws_address().unicode_string().removesuffix("/") + "/run_plan" + print(url) + with connect(url) as ws: + ws.send(req.model_dump_json()) + for message in ws: + print(message) + + def _ws_address(self) -> WebsocketUrl: + # url = WebsocketUrl.build( + # scheme="ws", host=api.host, port=api.port, path=api.path + # ) + return WebsocketUrl("ws://localhost:8000/") + # https://github.com/DiamondLightSource/blueapi/issues/1256 - remove before 2.0 def __getattr__(name: str): From a3f341c3807e5845b1c40dd9e9470e748559ddbe Mon Sep 17 00:00:00 2001 From: Peter Holloway Date: Tue, 3 Mar 2026 12:16:10 +0000 Subject: [PATCH 08/21] Send all events through websocket --- src/blueapi/service/interface.py | 13 ++++++++----- src/blueapi/service/main.py | 10 +++++++--- 2 files changed, 15 insertions(+), 8 deletions(-) diff --git a/src/blueapi/service/interface.py b/src/blueapi/service/interface.py index 7e35833698..f30f5559c3 100644 --- a/src/blueapi/service/interface.py +++ b/src/blueapi/service/interface.py @@ -11,6 +11,7 @@ from blueapi.cli.scratch import get_python_environment from blueapi.config import ApplicationConfig, OIDCConfig, ServiceAccount, StompConfig +from blueapi.core.bluesky_types import DataEvent from blueapi.core.context import BlueskyContext from blueapi.core.event import EventStream from blueapi.log import set_up_logging @@ -24,8 +25,7 @@ WorkerTask, ) from blueapi.utils.serialization import access_blob -from blueapi.worker import task_worker -from blueapi.worker.event import TaskStatusEnum, WorkerEvent, WorkerState +from blueapi.worker.event import ProgressEvent, TaskStatusEnum, WorkerEvent, WorkerState from blueapi.worker.task import Task from blueapi.worker.task_worker import TaskWorker, TrackableTask @@ -286,17 +286,20 @@ def get_python_env( def pipe_events(tx: Connection) -> int: def handler( - worker_event: WorkerEvent, - cor_id: str | None, + worker_event: WorkerEvent | DataEvent | ProgressEvent, + _cor_id: str | None, ) -> None: - LOGGER.info("Sending event") tx.send(worker_event) task_worker = worker() sub_id = task_worker.worker_events.subscribe(handler) + sub_id = task_worker.data_events.subscribe(handler) + sub_id = task_worker.progress_events.subscribe(handler) return sub_id def unpipe_events(h: int) -> None: task_worker = worker() task_worker.worker_events.unsubscribe(h) + task_worker.data_events.unsubscribe(h) + task_worker.progress_events.unsubscribe(h) diff --git a/src/blueapi/service/main.py b/src/blueapi/service/main.py index 74531bb332..82bc8eeb4d 100644 --- a/src/blueapi/service/main.py +++ b/src/blueapi/service/main.py @@ -39,9 +39,10 @@ from blueapi import __version__ from blueapi.config import ApplicationConfig, OIDCConfig, Tag +from blueapi.core.bluesky_types import DataEvent from blueapi.service import interface from blueapi.worker import TrackableTask, WorkerState -from blueapi.worker.event import TaskStatusEnum, WorkerEvent +from blueapi.worker.event import ProgressEvent, TaskStatusEnum, WorkerEvent from blueapi.worker.worker_errors import WorkerBusyError from .model import ( @@ -67,6 +68,9 @@ LOGGER = logging.getLogger(__name__) +AnyEvent = WorkerEvent | DataEvent | ProgressEvent + + def _runner() -> WorkerDispatcher: """Intended to be used only with FastAPI Depends""" if RUNNER is None: @@ -580,9 +584,9 @@ async def run_plan( # pipe events to ws try: while True: - event: WorkerEvent = await run_in_threadpool(rx.recv) + event: AnyEvent = await run_in_threadpool(rx.recv) await ws.send_json(event.model_dump(mode="json")) - if event.is_complete(): + if isinstance(event, WorkerEvent) and event.is_complete(): break finally: await ws.close() From 65632538db595b48322a62e559173343222b1869 Mon Sep 17 00:00:00 2001 From: Peter Holloway Date: Wed, 4 Mar 2026 14:50:32 +0000 Subject: [PATCH 09/21] Split pipe subscribe handles --- src/blueapi/service/interface.py | 22 +++++++++++++--------- src/blueapi/service/main.py | 2 +- 2 files changed, 14 insertions(+), 10 deletions(-) diff --git a/src/blueapi/service/interface.py b/src/blueapi/service/interface.py index f30f5559c3..796fcfea23 100644 --- a/src/blueapi/service/interface.py +++ b/src/blueapi/service/interface.py @@ -283,7 +283,10 @@ def get_python_env( return get_python_environment(config=scratch, name=name, source=source) -def pipe_events(tx: Connection) -> int: +SubHandle = tuple[int, int, int] + + +def pipe_events(tx: Connection) -> SubHandle: def handler( worker_event: WorkerEvent | DataEvent | ProgressEvent, @@ -292,14 +295,15 @@ def handler( tx.send(worker_event) task_worker = worker() - sub_id = task_worker.worker_events.subscribe(handler) - sub_id = task_worker.data_events.subscribe(handler) - sub_id = task_worker.progress_events.subscribe(handler) - return sub_id + w_id = task_worker.worker_events.subscribe(handler) + d_id = task_worker.data_events.subscribe(handler) + p_id = task_worker.progress_events.subscribe(handler) + return (w_id, d_id, p_id) -def unpipe_events(h: int) -> None: +def unpipe_events(hnd: SubHandle) -> None: task_worker = worker() - task_worker.worker_events.unsubscribe(h) - task_worker.data_events.unsubscribe(h) - task_worker.progress_events.unsubscribe(h) + w, d, p = hnd + task_worker.worker_events.unsubscribe(w) + task_worker.data_events.unsubscribe(d) + task_worker.progress_events.unsubscribe(p) diff --git a/src/blueapi/service/main.py b/src/blueapi/service/main.py index 82bc8eeb4d..265d33e342 100644 --- a/src/blueapi/service/main.py +++ b/src/blueapi/service/main.py @@ -590,7 +590,7 @@ async def run_plan( break finally: await ws.close() - runner.run(interface.unpipe_events, h=h) + runner.run(interface.unpipe_events, hnd=h) @start_as_current_span(TRACER, "config") From b53c7695983bcf208c16bf19c6b7d3fd4eba0066 Mon Sep 17 00:00:00 2001 From: Peter Holloway Date: Wed, 4 Mar 2026 16:10:43 +0000 Subject: [PATCH 10/21] Re-use run subcommand for websockets --- src/blueapi/cli/cli.py | 10 +++++++++- src/blueapi/client/client.py | 19 +++++++++++++++++-- src/blueapi/client/rest.py | 5 +++-- 3 files changed, 29 insertions(+), 5 deletions(-) diff --git a/src/blueapi/cli/cli.py b/src/blueapi/cli/cli.py index 6c2fc84ee4..c045095f8f 100644 --- a/src/blueapi/cli/cli.py +++ b/src/blueapi/cli/cli.py @@ -307,6 +307,7 @@ def on_event( @controller.command(name="run") @click.argument("name", type=str) @click.argument("parameters", type=ParametersType(), default={}, required=False) +@click.option("--ws", type=bool, is_flag=True, default=False) @click.option( "--foreground/--background", "--fg/--bg", type=bool, is_flag=True, default=True ) @@ -334,6 +335,7 @@ def run_plan( name: str, timeout: float | None, foreground: bool, + ws: bool, instrument_session: str, parameters: TaskParameters, ) -> None: @@ -355,7 +357,13 @@ def on_event(event: AnyEvent) -> None: elif isinstance(event, DataEvent): callback(event.name, event.doc) - resp = client.run_task(task, on_event=on_event) + client.add_callback(on_event) + + if ws: + resp = client.run_blocking(task) + else: + resp = client.run_task(task) + match resp.result: case TaskResult(result=None, type="NoneType"): print("Plan succeeded") diff --git a/src/blueapi/client/client.py b/src/blueapi/client/client.py index 0638f911f6..8c41ced7ab 100644 --- a/src/blueapi/client/client.py +++ b/src/blueapi/client/client.py @@ -440,8 +440,23 @@ def get_active_task(self) -> WorkerTask: return self.active_task @start_as_current_span(TRACER, "request") - def run_blocking(self, request: TaskRequest): - self._rest.run_blocking(request) + def run_blocking( + self, request: TaskRequest, on_event: OnAnyEvent | None = None + ) -> TaskStatus: + for event in self._rest.run_blocking(request): + if on_event is not None: + on_event(event) + for cb in self._callbacks.values(): + try: + cb(event) + except Exception as e: + log.error(f"Callback ({cb}) failed for event: {event}", exc_info=e) + if isinstance(event, WorkerEvent) and event.is_complete(): + if event.task_status is None: + raise BlueskyRemoteControlError( + "Server completed without task status" + ) + return event.task_status @start_as_current_span(TRACER, "task", "timeout") def run_task( diff --git a/src/blueapi/client/rest.py b/src/blueapi/client/rest.py index b758be3ecb..8f3670b43d 100644 --- a/src/blueapi/client/rest.py +++ b/src/blueapi/client/rest.py @@ -14,6 +14,7 @@ from websockets.sync.client import connect from blueapi import __version__ +from blueapi.client.event_bus import AnyEvent from blueapi.config import RestConfig from blueapi.service.authentication import JWTAuth, SessionManager from blueapi.service.model import ( @@ -311,11 +312,11 @@ def _request_and_deserialize( def run_blocking(self, req: TaskRequest): url = self._ws_address().unicode_string().removesuffix("/") + "/run_plan" - print(url) with connect(url) as ws: ws.send(req.model_dump_json()) for message in ws: - print(message) + event = TypeAdapter(AnyEvent).validate_json(message) + yield event def _ws_address(self) -> WebsocketUrl: # url = WebsocketUrl.build( From 3ba83cb161de450eb6f0d036ee779bd132d573ec Mon Sep 17 00:00:00 2001 From: Abigail Emery Date: Wed, 4 Mar 2026 16:40:04 +0000 Subject: [PATCH 11/21] Raise for connection closing pre plan completed --- src/blueapi/client/client.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/blueapi/client/client.py b/src/blueapi/client/client.py index 8c41ced7ab..fd2c93f53f 100644 --- a/src/blueapi/client/client.py +++ b/src/blueapi/client/client.py @@ -457,6 +457,7 @@ def run_blocking( "Server completed without task status" ) return event.task_status + raise BlueskyRemoteControlError("Connection closed before plan completed.") @start_as_current_span(TRACER, "task", "timeout") def run_task( From 4333c65bea011730b78d36f559c2f36e98ab41ac Mon Sep 17 00:00:00 2001 From: Abigail Emery Date: Wed, 4 Mar 2026 16:41:45 +0000 Subject: [PATCH 12/21] Remove run blocking from cli --- src/blueapi/cli/cli.py | 26 -------------------------- 1 file changed, 26 deletions(-) diff --git a/src/blueapi/cli/cli.py b/src/blueapi/cli/cli.py index c045095f8f..60eef5b3c3 100644 --- a/src/blueapi/cli/cli.py +++ b/src/blueapi/cli/cli.py @@ -390,32 +390,6 @@ def on_event(event: AnyEvent) -> None: raise ClickException(f"task could not run: {ve}") from ve -@controller.command(name="ws") -@click.pass_obj -@click.argument("name", type=str) -@click.argument("parameters", type=ParametersType(), default={}, required=False) -@click.option( - "-i", - "--instrument-session", - type=str, - help=textwrap.dedent(""" - Instrument session associated with running the plan, - used to tell blueapi where to store any data and as a security check: - the session must be valid and active and you must be a member of it."""), - required=True, -) -def run_blocking( - obj: dict, name: str, parameters: TaskParameters, instrument_session: str -): - client = cast(BlueapiClient, obj["client"]) - task_req = TaskRequest( - name=name, - params=parameters, - instrument_session=instrument_session, - ) - client.run_blocking(task_req) - - @controller.command(name="state") @click.pass_obj @check_connection From 420b57ec7f70cf2e707359aea58484e0abff7c75 Mon Sep 17 00:00:00 2001 From: Abigail Emery Date: Wed, 4 Mar 2026 17:24:11 +0000 Subject: [PATCH 13/21] Catch plan key error in run_plan --- src/blueapi/service/main.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/blueapi/service/main.py b/src/blueapi/service/main.py index 265d33e342..02af1d33b5 100644 --- a/src/blueapi/service/main.py +++ b/src/blueapi/service/main.py @@ -567,6 +567,9 @@ async def run_plan( except ValidationError: await ws.close(code=1003, reason="invalid args") return + except KeyError: + await ws.close(code=1003, reason="unknown plan") + return # add listener to runner tx, rx = Pipe() From 5e2af5c23d4fb747252678d9552cf9d830faecc6 Mon Sep 17 00:00:00 2001 From: Peter Holloway Date: Fri, 6 Mar 2026 15:45:13 +0000 Subject: [PATCH 14/21] Refactor event pipe handling into context manager and iterable --- src/blueapi/service/interface.py | 37 +++++++++++------- src/blueapi/service/main.py | 48 +++++++++++------------ src/blueapi/service/runner.py | 66 ++++++++++++++++++++++++++------ 3 files changed, 101 insertions(+), 50 deletions(-) diff --git a/src/blueapi/service/interface.py b/src/blueapi/service/interface.py index 796fcfea23..274ff738e7 100644 --- a/src/blueapi/service/interface.py +++ b/src/blueapi/service/interface.py @@ -1,5 +1,6 @@ import logging from collections.abc import Mapping +from dataclasses import dataclass from functools import cache from multiprocessing.connection import Connection from typing import Any @@ -283,27 +284,35 @@ def get_python_env( return get_python_environment(config=scratch, name=name, source=source) -SubHandle = tuple[int, int, int] +@dataclass +class SubHandles: + worker: int + progress: int + data: int -def pipe_events(tx: Connection) -> SubHandle: +def pipe_events(tx: Connection) -> SubHandles: + tw = worker() def handler( worker_event: WorkerEvent | DataEvent | ProgressEvent, _cor_id: str | None, ) -> None: - tx.send(worker_event) - task_worker = worker() - w_id = task_worker.worker_events.subscribe(handler) - d_id = task_worker.data_events.subscribe(handler) - p_id = task_worker.progress_events.subscribe(handler) - return (w_id, d_id, p_id) + try: + tx.send(worker_event) + except BrokenPipeError: + LOGGER.warning("Sending event to broken pipe") + pass + w = tw.worker_events.subscribe(handler) + d = tw.data_events.subscribe(handler) + p = tw.progress_events.subscribe(handler) + return SubHandles(worker=w, data=d, progress=p) -def unpipe_events(hnd: SubHandle) -> None: - task_worker = worker() - w, d, p = hnd - task_worker.worker_events.unsubscribe(w) - task_worker.data_events.unsubscribe(d) - task_worker.progress_events.unsubscribe(p) + +def unpipe_events(hnd: SubHandles): + tw = worker() + tw.worker_events.unsubscribe(hnd.worker) + tw.data_events.unsubscribe(hnd.data) + tw.progress_events.unsubscribe(hnd.progress) diff --git a/src/blueapi/service/main.py b/src/blueapi/service/main.py index 02af1d33b5..c90cdf4c95 100644 --- a/src/blueapi/service/main.py +++ b/src/blueapi/service/main.py @@ -2,7 +2,6 @@ import urllib.parse from collections.abc import Awaitable, Callable from contextlib import asynccontextmanager -from multiprocessing import Pipe from typing import Annotated, Any import jwt @@ -16,9 +15,9 @@ Request, Response, WebSocket, + WebSocketDisconnect, status, ) -from fastapi.concurrency import run_in_threadpool from fastapi.datastructures import Address from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import RedirectResponse, StreamingResponse @@ -556,44 +555,45 @@ async def run_plan( ): user = "alice" - # ack ws + LOGGER.info("Starting WS plan") await ws.accept() - # accept task request through socket rq = await ws.receive_json() - # submit task to runner + LOGGER.info("Raw request: %s", rq) try: task_request: TaskRequest = TaskRequest.model_validate(rq) + LOGGER.info("Plan request: %s", task_request) task_id: str = runner.run(interface.submit_task, task_request, {"user": user}) + LOGGER.info("Task ID: %s", task_id) except ValidationError: + LOGGER.error("Args not valid", exc_info=True) await ws.close(code=1003, reason="invalid args") return except KeyError: + LOGGER.error("Plan not found", exc_info=True) await ws.close(code=1003, reason="unknown plan") return - # add listener to runner - tx, rx = Pipe() - h = runner.run(interface.pipe_events, tx=tx) - # start task try: - task = WorkerTask(task_id=task_id) - runner.run( - interface.begin_task, - task=task, - ) + with runner.event_pipe() as events: + LOGGER.info("Created event pipe") + runner.run(interface.begin_task, task=WorkerTask(task_id=task_id)) + async for evt in events: + LOGGER.debug("Event: %s", evt) + await ws.send_json(evt.model_dump(mode="json")) + if isinstance(evt, WorkerEvent) and evt.is_complete(): + LOGGER.info("End of stream") + break except WorkerBusyError: + LOGGER.error("Worker was busy") await ws.close(code=1013, reason="Worker busy") - return - # pipe events to ws - try: - while True: - event: AnyEvent = await run_in_threadpool(rx.recv) - await ws.send_json(event.model_dump(mode="json")) - if isinstance(event, WorkerEvent) and event.is_complete(): - break - finally: + except WebSocketDisconnect: + LOGGER.info("Client disconnected") + runner.run( + interface.cancel_active_task, failure=True, reason="Client disconnected" + ) + else: + LOGGER.info("Plan complete") await ws.close() - runner.run(interface.unpipe_events, hnd=h) @start_as_current_span(TRACER, "config") diff --git a/src/blueapi/service/runner.py b/src/blueapi/service/runner.py index 2b5a5f37f5..6b34c6ad17 100644 --- a/src/blueapi/service/runner.py +++ b/src/blueapi/service/runner.py @@ -1,10 +1,12 @@ +import asyncio import inspect import logging import signal import uuid -from collections.abc import Callable +from collections.abc import AsyncIterator, Callable from importlib import import_module from multiprocessing import Pool, set_start_method +from multiprocessing.connection import Connection, Pipe from multiprocessing.pool import Pool as PoolClass from typing import Any, ParamSpec, TypeVar @@ -15,11 +17,13 @@ ) from opentelemetry.context import attach from opentelemetry.propagate import get_global_textmap -from pydantic import TypeAdapter from blueapi.config import ApplicationConfig -from blueapi.service.interface import setup, teardown +from blueapi.core.bluesky_types import DataEvent +from blueapi.service import interface +from blueapi.service.interface import SubHandles, setup, teardown from blueapi.service.model import EnvironmentResponse +from blueapi.worker.event import ProgressEvent, WorkerEvent # The default multiprocessing start method is fork set_start_method("spawn", force=True) @@ -145,11 +149,57 @@ def run( kwargs, ) + def event_pipe(self): + return EventPipe(self) + @property def state(self) -> EnvironmentResponse: return self._state +class EventStream: + def __init__(self, rx: Connection): + self._rx = rx + + def __aiter__(self) -> AsyncIterator[WorkerEvent | DataEvent | ProgressEvent]: + return self + + async def __anext__(self) -> WorkerEvent | DataEvent | ProgressEvent: + data_available = asyncio.Event() + asyncio.get_event_loop().add_reader(self._rx.fileno(), data_available.set) + try: + while not self._rx.poll(): + await data_available.wait() + data_available.clear() + return self._rx.recv() + except BrokenPipeError: + raise StopAsyncIteration() from None + finally: + asyncio.get_event_loop().remove_reader(self._rx.fileno()) + + +class EventPipe: + runner: WorkerDispatcher + handles: list[tuple[SubHandles, Connection]] + + def __init__(self, runner: WorkerDispatcher): + self.runner = runner + self.handles = [] + + def __enter__(self) -> EventStream: + tx, rx = Pipe() + hnd = self.runner.run(interface.pipe_events, tx) + LOGGER.debug("Subscribing new event pipe: %s", hnd) + self.handles.append((hnd, tx)) + return EventStream(rx) + + def __exit__(self, *exc): + hnd, conn = self.handles.pop() + LOGGER.debug("Unsubscribing event pipe: %s", hnd) + conn.close() + self.runner.run(interface.unpipe_events, hnd) + + class InvalidRunnerStateError(Exception): def __init__(self, message): super().__init__(message) @@ -173,15 +223,7 @@ def import_and_run_function( func: Callable[..., T] = _validate_function( mod.__dict__.get(function_name, None), function_name ) - value = func(*args, **kwargs) - return _valid_return(value, expected_type) - - -def _valid_return(value: Any, expected_type: type[T] | None = None) -> T: - if expected_type is None: - return value - else: - return TypeAdapter(expected_type).validate_python(value) + return func(*args, **kwargs) def _validate_function(func: Any, function_name: str) -> Callable: From 347595fb9b1f2ef47686adc3b6d5955f732d6da1 Mon Sep 17 00:00:00 2001 From: Peter Holloway Date: Thu, 12 Mar 2026 10:37:12 +0000 Subject: [PATCH 15/21] Testing auth tokens --- src/blueapi/client/rest.py | 8 +++++++- src/blueapi/service/main.py | 28 ++++++++++++++++++++++++++-- 2 files changed, 33 insertions(+), 3 deletions(-) diff --git a/src/blueapi/client/rest.py b/src/blueapi/client/rest.py index 8f3670b43d..8e592cfe7b 100644 --- a/src/blueapi/client/rest.py +++ b/src/blueapi/client/rest.py @@ -312,7 +312,13 @@ def _request_and_deserialize( def run_blocking(self, req: TaskRequest): url = self._ws_address().unicode_string().removesuffix("/") + "/run_plan" - with connect(url) as ws: + with connect( + url, + additional_headers={ + "Cookie": "Authorization=Bearer cook", + "Authorization": "Bearer head", + }, + ) as ws: ws.send(req.model_dump_json()) for message in ws: event = TypeAdapter(AnyEvent).validate_json(message) diff --git a/src/blueapi/service/main.py b/src/blueapi/service/main.py index c90cdf4c95..c54affdfc4 100644 --- a/src/blueapi/service/main.py +++ b/src/blueapi/service/main.py @@ -9,8 +9,10 @@ APIRouter, BackgroundTasks, Body, + Cookie, Depends, FastAPI, + Header, HTTPException, Request, Response, @@ -107,6 +109,7 @@ async def inner(app: FastAPI): secure_router = APIRouter() +ws_router = APIRouter() open_router = APIRouter() @@ -122,12 +125,15 @@ def get_app(config: ApplicationConfig): openapi_tags=ApplicationConfig.TAG_METADATA, ) dependencies = [] + ws_dependencies = [] if config.oidc: dependencies.append(Depends(decode_access_token(config.oidc))) + ws_dependencies.append(Depends(init_ws_auth(config.oidc))) app.swagger_ui_init_oauth = { "clientId": "NOT_SUPPORTED", } app.include_router(open_router) + app.include_router(ws_router, dependencies=ws_dependencies) app.include_router(secure_router, dependencies=dependencies) app.add_exception_handler(KeyError, on_key_error_404) app.add_exception_handler(jwt.PyJWTError, on_token_error_401) @@ -168,6 +174,24 @@ def inner(request: Request, access_token: str = Depends(oauth_scheme)): return inner +def init_ws_auth(oidc_config: OIDCConfig): + LOGGER.info("Creating ws auth dependency") + + async def inner( + ws: WebSocket, + auth_header: str | None = Header(alias="authorization", default=None), + auth_cookie: str | None = Cookie(default=None, alias="Authorization"), + ): + print(auth_header) + print(auth_cookie) + print(ws.headers) + print(ws.cookies) + await ws.accept() + LOGGER.info("Authenticating websocket") + + return inner + + TRACER = get_tracer("interface") @@ -548,7 +572,7 @@ def logout(runner: Annotated[WorkerDispatcher, Depends(_runner)]) -> Response: ) -@secure_router.websocket("/run_plan") +@ws_router.websocket("/run_plan") async def run_plan( ws: WebSocket, runner: Annotated[WorkerDispatcher, Depends(_runner)], @@ -556,7 +580,7 @@ async def run_plan( user = "alice" LOGGER.info("Starting WS plan") - await ws.accept() + # await ws.accept() rq = await ws.receive_json() LOGGER.info("Raw request: %s", rq) try: From cd88dd9e66cdfe4c8406812c386bac5d55ff7664 Mon Sep 17 00:00:00 2001 From: Peter Holloway Date: Thu, 12 Mar 2026 14:03:49 +0000 Subject: [PATCH 16/21] Re-use existing auth dependency for websocket endpoint --- src/blueapi/service/authentication.py | 17 ++++++++++++ src/blueapi/service/main.py | 39 ++++++--------------------- 2 files changed, 25 insertions(+), 31 deletions(-) diff --git a/src/blueapi/service/authentication.py b/src/blueapi/service/authentication.py index c0a32b26f4..4452591777 100644 --- a/src/blueapi/service/authentication.py +++ b/src/blueapi/service/authentication.py @@ -14,6 +14,9 @@ import httpx import jwt import requests +from fastapi.requests import HTTPConnection +from fastapi.security import OAuth2AuthorizationCodeBearer +from fastapi.security.utils import get_authorization_scheme_param from pydantic import TypeAdapter from requests.auth import AuthBase @@ -266,3 +269,17 @@ def get_access_token(self): def sync_auth_flow(self, request): request.headers["Authorization"] = f"Bearer {self.get_access_token()}" yield request + + +class CommonHttpOAuth(OAuth2AuthorizationCodeBearer): + """Extended version of OAuth2 Auth to work with both WebSockets and HTTP Requests""" + + async def __call__(self, request: HTTPConnection) -> str | None: + authorization = request.headers.get("Authorization") + scheme, param = get_authorization_scheme_param(authorization) + if not authorization or scheme.lower() != "bearer": + if self.auto_error: + raise self.make_not_authenticated_error() + else: + return None # pragma: nocover + return param diff --git a/src/blueapi/service/main.py b/src/blueapi/service/main.py index c54affdfc4..736e13d4be 100644 --- a/src/blueapi/service/main.py +++ b/src/blueapi/service/main.py @@ -9,10 +9,8 @@ APIRouter, BackgroundTasks, Body, - Cookie, Depends, FastAPI, - Header, HTTPException, Request, Response, @@ -22,8 +20,8 @@ ) from fastapi.datastructures import Address from fastapi.middleware.cors import CORSMiddleware +from fastapi.requests import HTTPConnection from fastapi.responses import RedirectResponse, StreamingResponse -from fastapi.security import OAuth2AuthorizationCodeBearer from observability_utils.tracing import ( add_span_attributes, get_tracer, @@ -42,6 +40,7 @@ from blueapi.config import ApplicationConfig, OIDCConfig, Tag from blueapi.core.bluesky_types import DataEvent from blueapi.service import interface +from blueapi.service.authentication import CommonHttpOAuth from blueapi.worker import TrackableTask, WorkerState from blueapi.worker.event import ProgressEvent, TaskStatusEnum, WorkerEvent from blueapi.worker.worker_errors import WorkerBusyError @@ -109,7 +108,6 @@ async def inner(app: FastAPI): secure_router = APIRouter() -ws_router = APIRouter() open_router = APIRouter() @@ -125,15 +123,12 @@ def get_app(config: ApplicationConfig): openapi_tags=ApplicationConfig.TAG_METADATA, ) dependencies = [] - ws_dependencies = [] if config.oidc: dependencies.append(Depends(decode_access_token(config.oidc))) - ws_dependencies.append(Depends(init_ws_auth(config.oidc))) app.swagger_ui_init_oauth = { "clientId": "NOT_SUPPORTED", } app.include_router(open_router) - app.include_router(ws_router, dependencies=ws_dependencies) app.include_router(secure_router, dependencies=dependencies) app.add_exception_handler(KeyError, on_key_error_404) app.add_exception_handler(jwt.PyJWTError, on_token_error_401) @@ -153,13 +148,13 @@ def get_app(config: ApplicationConfig): def decode_access_token(config: OIDCConfig): jwkclient = jwt.PyJWKClient(config.jwks_uri) - oauth_scheme = OAuth2AuthorizationCodeBearer( + oauth_scheme = CommonHttpOAuth( authorizationUrl=config.authorization_endpoint, tokenUrl=config.token_endpoint, refreshUrl=config.token_endpoint, ) - def inner(request: Request, access_token: str = Depends(oauth_scheme)): + def inner(request: HTTPConnection, access_token: str = Depends(oauth_scheme)): signing_key = jwkclient.get_signing_key_from_jwt(access_token) decoded: dict[str, Any] = jwt.decode( access_token, @@ -174,24 +169,6 @@ def inner(request: Request, access_token: str = Depends(oauth_scheme)): return inner -def init_ws_auth(oidc_config: OIDCConfig): - LOGGER.info("Creating ws auth dependency") - - async def inner( - ws: WebSocket, - auth_header: str | None = Header(alias="authorization", default=None), - auth_cookie: str | None = Cookie(default=None, alias="Authorization"), - ): - print(auth_header) - print(auth_cookie) - print(ws.headers) - print(ws.cookies) - await ws.accept() - LOGGER.info("Authenticating websocket") - - return inner - - TRACER = get_tracer("interface") @@ -572,15 +549,15 @@ def logout(runner: Annotated[WorkerDispatcher, Depends(_runner)]) -> Response: ) -@ws_router.websocket("/run_plan") +@secure_router.websocket("/run_plan") async def run_plan( ws: WebSocket, runner: Annotated[WorkerDispatcher, Depends(_runner)], ): - user = "alice" + user = ws.state.decoded_access_token["fedid"] - LOGGER.info("Starting WS plan") - # await ws.accept() + LOGGER.info("Starting WS plan as %s", user) + await ws.accept() rq = await ws.receive_json() LOGGER.info("Raw request: %s", rq) try: From 8c3576b506d9e1f1e0dd35b3af12d3939ab5bea6 Mon Sep 17 00:00:00 2001 From: Peter Holloway Date: Thu, 12 Mar 2026 14:04:34 +0000 Subject: [PATCH 17/21] Add user auth token in websocket client --- src/blueapi/client/rest.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/src/blueapi/client/rest.py b/src/blueapi/client/rest.py index 8e592cfe7b..5fcc5a0bb5 100644 --- a/src/blueapi/client/rest.py +++ b/src/blueapi/client/rest.py @@ -312,13 +312,11 @@ def _request_and_deserialize( def run_blocking(self, req: TaskRequest): url = self._ws_address().unicode_string().removesuffix("/") + "/run_plan" - with connect( - url, - additional_headers={ - "Cookie": "Authorization=Bearer cook", - "Authorization": "Bearer head", - }, - ) as ws: + headers = {} + if self._session_manager: + auth = self._session_manager.get_valid_access_token() + headers["Authorization"] = f"Bearer {auth}" + with connect(url, additional_headers=headers) as ws: ws.send(req.model_dump_json()) for message in ws: event = TypeAdapter(AnyEvent).validate_json(message) From 3032d22967d5238934989a871a68e0dffafb31d4 Mon Sep 17 00:00:00 2001 From: Peter Holloway Date: Thu, 12 Mar 2026 16:39:56 +0000 Subject: [PATCH 18/21] Read authorization from cookie as well as header --- src/blueapi/service/authentication.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/src/blueapi/service/authentication.py b/src/blueapi/service/authentication.py index 4452591777..ee1ed63146 100644 --- a/src/blueapi/service/authentication.py +++ b/src/blueapi/service/authentication.py @@ -14,6 +14,7 @@ import httpx import jwt import requests +from fastapi import Cookie, Header from fastapi.requests import HTTPConnection from fastapi.security import OAuth2AuthorizationCodeBearer from fastapi.security.utils import get_authorization_scheme_param @@ -274,8 +275,13 @@ def sync_auth_flow(self, request): class CommonHttpOAuth(OAuth2AuthorizationCodeBearer): """Extended version of OAuth2 Auth to work with both WebSockets and HTTP Requests""" - async def __call__(self, request: HTTPConnection) -> str | None: - authorization = request.headers.get("Authorization") + async def __call__( + self, + request: HTTPConnection, + auth_header: str | None = Header(alias="Authorization", default=None), + auth_cookie: str | None = Cookie(alias="Authorization", default=None), + ) -> str | None: + authorization = auth_header or auth_cookie scheme, param = get_authorization_scheme_param(authorization) if not authorization or scheme.lower() != "bearer": if self.auto_error: From be070c4b929fdfd15fa0d322962313c4d98dabc1 Mon Sep 17 00:00:00 2001 From: Peter Holloway Date: Thu, 12 Mar 2026 16:41:32 +0000 Subject: [PATCH 19/21] Add user agent to websocket request --- src/blueapi/client/rest.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/blueapi/client/rest.py b/src/blueapi/client/rest.py index 5fcc5a0bb5..bfa7af6701 100644 --- a/src/blueapi/client/rest.py +++ b/src/blueapi/client/rest.py @@ -316,7 +316,11 @@ def run_blocking(self, req: TaskRequest): if self._session_manager: auth = self._session_manager.get_valid_access_token() headers["Authorization"] = f"Bearer {auth}" - with connect(url, additional_headers=headers) as ws: + with connect( + url, + additional_headers=headers, + user_agent_header="blueapi cli", + ) as ws: ws.send(req.model_dump_json()) for message in ws: event = TypeAdapter(AnyEvent).validate_json(message) From 420e73f4e165046e24e5e11b27293144342f61ac Mon Sep 17 00:00:00 2001 From: Peter Holloway Date: Wed, 1 Apr 2026 18:37:02 +0100 Subject: [PATCH 20/21] Add websocket support to middleware --- src/blueapi/service/main.py | 19 ++++----- src/blueapi/service/middleware.py | 60 +++++++++++++++++++++++++++ tests/unit_tests/service/test_main.py | 4 +- 3 files changed, 69 insertions(+), 14 deletions(-) create mode 100644 src/blueapi/service/middleware.py diff --git a/src/blueapi/service/main.py b/src/blueapi/service/main.py index 736e13d4be..12ba66786b 100644 --- a/src/blueapi/service/main.py +++ b/src/blueapi/service/main.py @@ -36,11 +36,14 @@ from starlette.responses import JSONResponse from super_state_machine.errors import TransitionError -from blueapi import __version__ from blueapi.config import ApplicationConfig, OIDCConfig, Tag from blueapi.core.bluesky_types import DataEvent from blueapi.service import interface from blueapi.service.authentication import CommonHttpOAuth +from blueapi.service.middleware import ( + ObservabilityContextPropagator, + VersionHeaders, +) from blueapi.worker import TrackableTask, WorkerState from blueapi.worker.event import ProgressEvent, TaskStatusEnum, WorkerEvent from blueapi.worker.worker_errors import WorkerBusyError @@ -132,8 +135,9 @@ def get_app(config: ApplicationConfig): app.include_router(secure_router, dependencies=dependencies) app.add_exception_handler(KeyError, on_key_error_404) app.add_exception_handler(jwt.PyJWTError, on_token_error_401) - app.middleware("http")(add_version_headers) - app.middleware("http")(inject_propagated_observability_context) + + app.add_middleware(ObservabilityContextPropagator) + app.add_middleware(VersionHeaders) app.middleware("http")(log_request_details) if config.api.cors: app.add_middleware( @@ -625,15 +629,6 @@ def start(config: ApplicationConfig): ) -async def add_version_headers( - request: Request, call_next: Callable[[Request], Awaitable[Response]] -): - response = await call_next(request) - response.headers["X-API-Version"] = ApplicationConfig.REST_API_VERSION - response.headers["X-BlueAPI-Version"] = __version__ - return response - - async def log_request_details( request: Request, call_next: Callable[[Request], Awaitable[StreamingResponse]] ) -> Response: diff --git a/src/blueapi/service/middleware.py b/src/blueapi/service/middleware.py new file mode 100644 index 0000000000..45ef033613 --- /dev/null +++ b/src/blueapi/service/middleware.py @@ -0,0 +1,60 @@ +import logging + +from opentelemetry.context import attach +from opentelemetry.propagate import get_global_textmap +from starlette.types import ASGIApp, Message, Receive, Scope, Send + +from blueapi import __version__ +from blueapi.config import ApplicationConfig + +OBS_LOGGER = logging.getLogger("blueapi.service.middleware.obs") +VER_LOGGER = logging.getLogger("blueapi.service.middleware.version") + +CONTEXT_HEADER = ApplicationConfig.CONTEXT_HEADER.encode() +VENDOR_CONTEXT_HEADER = ApplicationConfig.VENDOR_CONTEXT_HEADER.encode() + +API_VERSION = (b"x-api-version", ApplicationConfig.REST_API_VERSION.encode("utf-8")) +VERSION = (b"x-blueapi-version", __version__.encode("utf-8")) + + +class VersionHeaders: + def __init__(self, app: ASGIApp): + self.app = app + + async def __call__(self, scope: Scope, receive: Receive, send: Send): + if scope.get("type") not in ("websocket", "http"): + return await self.app(scope, receive, send) + + async def local_send(message: Message): + VER_LOGGER.info("message: %s", message) + if message["type"] in ("websocket.accept", "http.response.start"): + message["headers"].append(VERSION) + message["headers"].append(API_VERSION) + await send(message) + + await self.app(scope, receive, local_send) + + +class ObservabilityContextPropagator: + def __init__(self, app: ASGIApp): + self.app = app + + async def __call__(self, scope: Scope, receive: Receive, send: Send): + if scope["type"] not in ("http", "websocket"): + return await self.app(scope, receive, send) + + ctx = None + v_ctx = None + for key, val in scope.get("headers", ()): + if key == CONTEXT_HEADER: + ctx = val.decode() + elif key == VENDOR_CONTEXT_HEADER: + v_ctx = val.decode() + if ctx: + OBS_LOGGER.debug("Propagating observability context: %s, %s", ctx, v_ctx) + carrier = {ApplicationConfig.CONTEXT_HEADER: ctx} + if v_ctx: + carrier[ApplicationConfig.VENDOR_CONTEXT_HEADER] = v_ctx + attach(get_global_textmap().extract(carrier)) + + await self.app(scope, receive, send) diff --git a/tests/unit_tests/service/test_main.py b/tests/unit_tests/service/test_main.py index 4a4bcca634..fb689985e6 100644 --- a/tests/unit_tests/service/test_main.py +++ b/tests/unit_tests/service/test_main.py @@ -8,15 +8,15 @@ from blueapi import __version__ from blueapi.config import ApplicationConfig from blueapi.service.main import ( - add_version_headers, get_passthrough_headers, log_request_details, ) +from blueapi.service.middleware import VersionHeaders async def test_add_version_header(): app = FastAPI() - app.middleware("http")(add_version_headers) + app.add_middleware(VersionHeaders) @app.get("/") async def root(): From 2371239e3574dbe672830ebb9649beaec5f3d540 Mon Sep 17 00:00:00 2001 From: Peter Holloway Date: Tue, 7 Apr 2026 11:19:49 +0100 Subject: [PATCH 21/21] Add user agent to all requests --- src/blueapi/client/rest.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/src/blueapi/client/rest.py b/src/blueapi/client/rest.py index bfa7af6701..81721eee30 100644 --- a/src/blueapi/client/rest.py +++ b/src/blueapi/client/rest.py @@ -39,6 +39,8 @@ LOGGER = logging.getLogger(__name__) +USER_AGENT = f"blueapi cli {__version__}" + class UnauthorisedAccessError(Exception): pass @@ -277,14 +279,15 @@ def _request_and_deserialize( ) -> T: url = self._config.url.unicode_string().removesuffix("/") + suffix # Get the trace context to propagate to the REST API - carr = get_context_propagator() + headers = get_context_propagator() + headers["User-Agent"] = USER_AGENT try: response = self._pool.request( method, url, json=data, params=params, - headers=carr, + headers=headers, auth=JWTAuth(self._session_manager), ) except requests.exceptions.ConnectionError as ce: @@ -312,14 +315,14 @@ def _request_and_deserialize( def run_blocking(self, req: TaskRequest): url = self._ws_address().unicode_string().removesuffix("/") + "/run_plan" - headers = {} + headers = get_context_propagator() if self._session_manager: auth = self._session_manager.get_valid_access_token() headers["Authorization"] = f"Bearer {auth}" with connect( url, additional_headers=headers, - user_agent_header="blueapi cli", + user_agent_header=USER_AGENT, ) as ws: ws.send(req.model_dump_json()) for message in ws: