diff --git a/piker/data/_sampling.py b/piker/data/_sampling.py index 61b2bd2f..f8230bd7 100644 --- a/piker/data/_sampling.py +++ b/piker/data/_sampling.py @@ -24,7 +24,6 @@ from collections import Counter import time from typing import ( TYPE_CHECKING, - Union, ) import tractor @@ -319,11 +318,10 @@ async def sample_and_broadcast( sub_key: str = broker_symbol.lower() subs: list[ tuple[ - Union[tractor.MsgStream, trio.MemorySendChannel], - tractor.Context, + tractor.MsgStream | trio.MemorySendChannel, float | None, # tick throttle in Hz ] - ] = bus._subscribers[sub_key] + ] = bus.get_subs(sub_key) # NOTE: by default the broker backend doesn't append # it's own "name" into the fqsn schema (but maybe it @@ -332,7 +330,7 @@ async def sample_and_broadcast( fqsn = f'{broker_symbol}.{brokername}' lags: int = 0 - for (stream, ctx, tick_throttle) in subs: + for (stream, tick_throttle) in subs.copy(): try: with trio.move_on_after(0.2) as cs: if tick_throttle: @@ -344,6 +342,7 @@ async def sample_and_broadcast( ) except trio.WouldBlock: overruns[sub_key] += 1 + ctx = stream._ctx chan = ctx.chan log.warning( @@ -399,9 +398,9 @@ async def sample_and_broadcast( # so far seems like no since this should all # be single-threaded. Doing it anyway though # since there seems to be some kinda race.. - bus.remove_sub( + bus.remove_subs( sub_key, - (stream, ctx, tick_throttle), + {(stream, tick_throttle)}, ) diff --git a/piker/data/feed.py b/piker/data/feed.py index 6cb25bdc..93630a13 100644 --- a/piker/data/feed.py +++ b/piker/data/feed.py @@ -21,6 +21,7 @@ This module is enabled for ``brokerd`` daemons. """ from __future__ import annotations +from collections import defaultdict from contextlib import asynccontextmanager as acm from datetime import datetime from functools import partial @@ -111,16 +112,16 @@ class _FeedsBus(Struct): task_lock: trio.StrictFIFOLock = trio.StrictFIFOLock() - _subscribers: dict[ + _subscribers: defaultdict[ str, - list[ + set[ tuple[ - Union[tractor.MsgStream, trio.MemorySendChannel], - tractor.Context, - Optional[float], # tick throttle in Hz + tractor.MsgStream | trio.MemorySendChannel, + # tractor.Context, + float | None, # tick throttle in Hz ] ] - ] = {} + ] = defaultdict(set) async def start_task( self, @@ -147,38 +148,53 @@ class _FeedsBus(Struct): # task: trio.lowlevel.Task, # ) -> bool: # ... + def get_subs( self, key: str, - ) -> list[ + ) -> set[ tuple[ Union[tractor.MsgStream, trio.MemorySendChannel], - tractor.Context, + # tractor.Context, float | None, # tick throttle in Hz ] ]: + ''' + Get the ``set`` of consumer subscription entries for the given key. + + ''' return self._subscribers[key] - def remove_sub( + def add_subs( self, key: str, - sub: tuple, - ) -> bool: + subs: set[tuple[ + tractor.MsgStream | trio.MemorySendChannel, + # tractor.Context, + float | None, # tick throttle in Hz + ]], + ) -> set[tuple]: ''' - Remove a consumer's subscription entry for the given key. + Add a ``set`` of consumer subscription entries for the given key. ''' - stream, ctx, tick_throttle = sub - subs = self.get_subs(key) - try: - subs.remove(sub) - except ValueError: - chan = ctx.chan - log.error( - f'Stream was already removed from subs!?\n' - f'{key}:' - f'{ctx.cid}@{chan.uid}' - ) + _subs = self._subscribers[key] + _subs.update(subs) + return _subs + + def remove_subs( + self, + key: str, + subs: set[tuple], + + ) -> set[tuple]: + ''' + Remove a ``set`` of consumer subscription entries for key. + + ''' + _subs = self.get_subs(key) + _subs.difference_update(subs) + return _subs _bus: _FeedsBus = None @@ -969,12 +985,6 @@ class Flume(Struct): else: yield istream - async def pause(self) -> None: - await self.stream.send('pause') - - async def resume(self) -> None: - await self.stream.send('resume') - def get_ds_info( self, ) -> tuple[float, float, float]: @@ -1308,7 +1318,7 @@ async def open_feed_bus( # the sampler subscription since the backend isn't (yet) # expected to append it's own name to the fqsn, so we filter # on keys which *do not* include that name (e.g .ib) . - bus._subscribers.setdefault(bfqsn, []) + bus._subscribers.setdefault(bfqsn, set()) # sync feed subscribers with flume handles await ctx.started( @@ -1324,7 +1334,7 @@ async def open_feed_bus( ctx.open_stream() as stream, ): - local_subs: list = [] + local_subs: dict[str, set[tuple]] = {} for fqsn, flume in flumes.items(): # re-send to trigger display loop cycle (necessary especially # when the mkt is closed and no real-time messages are @@ -1361,43 +1371,42 @@ async def open_feed_bus( # stream it's the throttle task does the work of # incrementally forwarding to the IPC stream at the throttle # rate. - sub = (send, ctx, tick_throttle) + send._ctx = ctx # mock internal ``tractor.MsgStream`` ref + sub = (send, tick_throttle) else: - sub = (stream, ctx, tick_throttle) + sub = (stream, tick_throttle) # TODO: add an api for this on the bus? # maybe use the current task-id to key the sub list that's # added / removed? Or maybe we can add a general # pause-resume by sub-key api? bfqsn = fqsn.removesuffix(f'.{brokername}') - bus_subs = bus._subscribers[bfqsn] - bus_subs.append(sub) - local_subs.append(sub) + local_subs.setdefault(bfqsn, set()).add(sub) + bus.add_subs(bfqsn, {sub}) + # sync caller with all subs registered state sub_registered.set() + uid = ctx.chan.uid try: - uid = ctx.chan.uid - # ctrl protocol for start/stop of quote streams based on UI # state (eg. don't need a stream when a symbol isn't being # displayed). async for msg in stream: if msg == 'pause': - for sub in local_subs: - if sub in bus_subs: - log.info( - f'Pausing {fqsn} feed for {uid}') - bus_subs.remove(sub) + for bfqsn, subs in local_subs.items(): + log.info( + f'Pausing {bfqsn} feed for {uid}') + bus.remove_subs(bfqsn, subs) elif msg == 'resume': - for sub in local_subs: - if sub not in bus_subs: - log.info( - f'Resuming {fqsn} feed for {uid}') - bus_subs.append(sub) + for bfqsn, subs in local_subs.items(): + log.info( + f'Resuming {bfqsn} feed for {uid}') + bus.add_subs(bfqsn, subs) + else: raise ValueError(msg) finally: @@ -1410,11 +1419,8 @@ async def open_feed_bus( cs.cancel() # drop all subs for this task from the bus - for sub in local_subs: - try: - bus._subscribers[bfqsn].remove(sub) - except ValueError: - log.warning(f'{sub} for {symbol} was already removed?') + for bfqsn, subs in local_subs.items(): + bus.remove_subs(bfqsn, subs) class Feed(Struct): @@ -1492,6 +1498,14 @@ class Feed(Struct): # def name(self) -> str: # return self.mod.name + async def pause(self) -> None: + for stream in set(self.streams.values()): + await stream.send('pause') + + async def resume(self) -> None: + for stream in set(self.streams.values()): + await stream.send('resume') + @acm async def install_brokerd_search(