From dd0167b9a5f77d3620411b849c4e55c2ad038b22 Mon Sep 17 00:00:00 2001 From: Tyler Goodlet Date: Wed, 6 Dec 2023 17:53:35 -0500 Subject: [PATCH] Make `fsp.cascade()` expect src/dst `Flume`s Been meaning to do this for a while, and there's still a few design / interface kinks (like `.mkt: MktPair` which should be better generalized?) but this flips over all of the fsp chaining engine to operate on the higher level `Flume` APIs via the newly cobbled `Cascade` thinger.. --- piker/fsp/__init__.py | 13 +++- piker/fsp/_engine.py | 172 +++++++++++++++++++++++------------------- piker/ui/_fsp.py | 24 +++--- 3 files changed, 116 insertions(+), 93 deletions(-) diff --git a/piker/fsp/__init__.py b/piker/fsp/__init__.py index e463ac26..0651069e 100644 --- a/piker/fsp/__init__.py +++ b/piker/fsp/__init__.py @@ -26,7 +26,10 @@ from ._api import ( maybe_mk_fsp_shm, Fsp, ) -from ._engine import cascade +from ._engine import ( + cascade, + Cascade, +) from ._volume import ( dolla_vlm, flow_rates, @@ -35,6 +38,7 @@ from ._volume import ( __all__: list[str] = [ 'cascade', + 'Cascade', 'maybe_mk_fsp_shm', 'Fsp', 'dolla_vlm', @@ -46,9 +50,12 @@ __all__: list[str] = [ async def latency( source: 'TickStream[Dict[str, float]]', # noqa ohlcv: np.ndarray + ) -> AsyncIterator[np.ndarray]: - """Latency measurements, broker to piker. - """ + ''' + Latency measurements, broker to piker. + + ''' # TODO: do we want to offer yielding this async # before the rt data connection comes up? diff --git a/piker/fsp/_engine.py b/piker/fsp/_engine.py index 29b93631..acc7309e 100644 --- a/piker/fsp/_engine.py +++ b/piker/fsp/_engine.py @@ -24,8 +24,6 @@ from functools import partial from typing import ( AsyncIterator, Callable, - Optional, - Union, ) import numpy as np @@ -37,7 +35,6 @@ from tractor.msg import NamespacePath from piker.types import Struct from ..log import get_logger, get_console_log from .. 
import data -from ..data import attach_shm_array from ..data.feed import ( Flume, Feed, @@ -117,8 +114,8 @@ class Cascade(Struct): ''' # TODO: make these `Flume`s - src: ShmArray - dst: ShmArray + src: Flume + dst: Flume tn: trio.Nursery fsp: Fsp # UI-side middleware ctl API @@ -139,11 +136,12 @@ class Cascade(Struct): # always trigger UI refresh after history update, # see ``piker.ui._fsp.FspAdmin.open_chain()`` and # ``piker.ui._display.trigger_update()``. + dst_shm: ShmArray = self.dst.rt_shm await self.client_stream.send({ 'fsp_update': { - 'key': self.dst.token, - 'first': self.dst._first.value, - 'last': self.dst._last.value, + 'key': dst_shm.token, + 'first': dst_shm._first.value, + 'last': dst_shm._last.value, } }) return index @@ -154,10 +152,10 @@ class Cascade(Struct): output array is aligned to its source array. ''' - src: ShmArray = self.src - dst: ShmArray = self.dst - step_diff = src.index - dst.index - len_diff = abs(len(src.array) - len(dst.array)) + src_shm: ShmArray = self.src.rt_shm + dst_shm: ShmArray = self.dst.rt_shm + step_diff = src_shm.index - dst_shm.index + len_diff = abs(len(src_shm.array) - len(dst_shm.array)) synced: bool = not ( # the source is likely backfilling and we must # sync history calculations @@ -172,7 +170,7 @@ class Cascade(Struct): fsp: Fsp = self.fsp log.warning( '***DESYNCED FSP***\n' - f'{fsp.ns_path}@{src.token}\n' + f'{fsp.ns_path}@{src_shm.token}\n' f'step_diff: {step_diff}\n' f'len_diff: {len_diff}\n' ) @@ -183,10 +181,10 @@ class Cascade(Struct): ) async def poll_and_sync_to_step(self) -> int: - synced, step_diff, _ = self.is_synced() #src, dst) + synced, step_diff, _ = self.is_synced() while not synced: await self.resync() - synced, step_diff, _ = self.is_synced() #src, dst) + synced, step_diff, _ = self.is_synced() return step_diff @@ -203,16 +201,13 @@ class Cascade(Struct): async def connect_streams( - casc: Cascade, mkt: MktPair, - flume: Flume, quote_stream: trio.abc.ReceiveChannel, + src: Flume, + dst: 
Flume, - src: ShmArray, - dst: ShmArray, - - func: Callable, + edge_func: Callable, # attach_stream: bool = False, task_status: TaskStatus[None] = trio.TASK_STATUS_IGNORED, @@ -226,7 +221,7 @@ async def connect_streams( Not literally, but something like: - func(Flume_in) -> Flume_out + edge_func(Flume_in) -> Flume_out ''' profiler = Profiler( @@ -234,12 +229,14 @@ async def connect_streams( disabled=True ) - fqme: str = mkt.fqme + # TODO: just pull it from src.mkt.fqme no? + # fqme: str = mkt.fqme + fqme: str = src.mkt.fqme # TODO: dynamic introspection of what the underlying (vertex) # function actually requires from input node (flumes) then # deliver those inputs as part of a graph "compilation" step? - out_stream = func( + out_stream = edge_func( # TODO: do we even need this if we do the feed api right? # shouldn't a local stream do this before we get a handle @@ -249,19 +246,19 @@ async def connect_streams( # XXX: currently the ``ohlcv`` arg, but we should allow # (dynamic) requests for src flume (node) streams? - flume.rt_shm, + src.rt_shm, ) # HISTORY COMPUTE PHASE # conduct a single iteration of fsp with historical bars input # and get historical output. - history_output: Union[ - dict[str, np.ndarray], # multi-output case - np.ndarray, # single output case - ] + history_output: ( + dict[str, np.ndarray] # multi-output case + | np.ndarray, # single output case + ) history_output = await anext(out_stream) - func_name = func.__name__ + func_name = edge_func.__name__ profiler(f'{func_name} generated history') # build struct array with an 'index' field to push as history @@ -269,10 +266,12 @@ async def connect_streams( # TODO: push using a[['f0', 'f1', .., 'fn']] = .. syntax no? # if the output array is multi-field then push # each respective field. 
- fields = getattr(dst.array.dtype, 'fields', None).copy() + dst_shm: ShmArray = dst.rt_shm + fields = getattr(dst_shm.array.dtype, 'fields', None).copy() fields.pop('index') - history_by_field: Optional[np.ndarray] = None - src_time = src.array['time'] + history_by_field: np.ndarray | None = None + src_shm: ShmArray = src.rt_shm + src_time = src_shm.array['time'] if ( fields and @@ -291,7 +290,7 @@ async def connect_streams( if history_by_field is None: if output is None: - length = len(src.array) + length = len(src_shm.array) else: length = len(output) @@ -300,7 +299,7 @@ async def connect_streams( # will be pushed to shm. history_by_field = np.zeros( length, - dtype=dst.array.dtype + dtype=dst_shm.array.dtype ) if output is None: @@ -317,13 +316,13 @@ async def connect_streams( ) history_by_field = np.zeros( len(history_output), - dtype=dst.array.dtype + dtype=dst_shm.array.dtype ) history_by_field[func_name] = history_output history_by_field['time'] = src_time[-len(history_by_field):] - history_output['time'] = src.array['time'] + history_output['time'] = src_shm.array['time'] # TODO: XXX: # THERE'S A BIG BUG HERE WITH THE `index` field since we're @@ -336,11 +335,11 @@ async def connect_streams( # is `index` aware such that historical data can be indexed # relative to the true first datum? Not sure if this is sane # for incremental compuations. - first = dst._first.value = src._first.value + first = dst_shm._first.value = src_shm._first.value # TODO: can we use this `start` flag instead of the manual # setting above? - index = dst.push( + index = dst_shm.push( history_by_field, start=first, ) @@ -367,12 +366,12 @@ async def connect_streams( log.debug(f"{func_name}: {processed}") key, output = processed # dst.array[-1][key] = output - dst.array[[key, 'time']][-1] = ( + dst_shm.array[[key, 'time']][-1] = ( output, # TODO: what about pushing ``time.time_ns()`` # in which case we'll need to round at the graphics # processing / sampling layer? 
- src.array[-1]['time'] + src_shm.array[-1]['time'] ) # NOTE: for now we aren't streaming this to the consumer @@ -384,7 +383,7 @@ async def connect_streams( # N-consumers who subscribe for the real-time output, # which we'll likely want to implement using local-mem # chans for the fan out? - # index = src.index + # index = src_shm.index # if attach_stream: # await client_stream.send(index) @@ -405,16 +404,15 @@ async def cascade( # data feed key fqme: str, - # TODO: expect and attach from `Flume.to_msg()`s! - src_shm_token: dict, - dst_shm_token: tuple[str, np.dtype], - + # flume pair cascaded using an "edge function" + src_flume_addr: dict, + dst_flume_addr: dict, ns_path: NamespacePath, shm_registry: dict[str, _Token], zero_on_step: bool = False, - loglevel: Optional[str] = None, + loglevel: str | None = None, ) -> None: ''' @@ -430,8 +428,14 @@ async def cascade( if loglevel: get_console_log(loglevel) - src: ShmArray = attach_shm_array(token=src_shm_token) - dst: ShmArray = attach_shm_array(readonly=False, token=dst_shm_token) + src: Flume = Flume.from_msg(src_flume_addr) + dst: Flume = Flume.from_msg( + dst_flume_addr, + readonly=False, + ) + + # src: ShmArray = attach_shm_array(token=src_shm_token) + # dst: ShmArray = attach_shm_array(readonly=False, token=dst_shm_token) reg = _load_builtins() lines = '\n'.join([f'{key.rpartition(":")[2]} => {key}' for key in reg]) @@ -439,11 +443,11 @@ async def cascade( f'Registered FSP set:\n{lines}' ) - # update actorlocal flows table which registers - # readonly "instances" of this fsp for symbol/source - # so that consumer fsps can look it up by source + fsp. - # TODO: ugh i hate this wind/unwind to list over the wire - # but not sure how else to do it. + # NOTE XXX: update actorlocal flows table which registers + # readonly "instances" of this fsp for symbol/source so that + # consumer fsps can look it up by source + fsp. + # TODO: ugh i hate this wind/unwind to list over the wire but + # not sure how else to do it. 
for (token, fsp_name, dst_token) in shm_registry: Fsp._flow_registry[( _Token.from_msg(token), @@ -459,6 +463,9 @@ async def cascade( # TODO: assume it's a func target path raise ValueError(f'Unknown fsp target: {ns_path}') + _fqme: str = src.mkt.fqme + assert _fqme == fqme + # open a data feed stream with requested broker feed: Feed async with data.feed.maybe_open_feed( @@ -472,12 +479,21 @@ async def cascade( ) as feed: - flume = feed.flumes[fqme] - mkt = flume.mkt + flume: Flume = feed.flumes[fqme] + # XXX: can't do this since flume.feed will be set XD + # assert flume == src + assert flume.mkt == src.mkt + mkt: MktPair = flume.mkt + + # NOTE: FOR NOW, sanity checks around the feed as being + # always the src flume (until we get to fancier/lengthier + # chains/graphs. + assert src.rt_shm.token == flume.rt_shm.token + + # XXX: won't work bc the _hist_shm_token value will be + # list[list] after IPC.. + # assert flume.to_msg() == src_flume_addr - # TODO: make an equivalent `Flume` around the Fsp output - # streams and chain them using a `Cascade` Bo - assert src.token == flume.rt_shm.token profiler(f'{func}: feed up') func_name: str = func.__name__ @@ -497,34 +513,34 @@ async def cascade( # TODO: this seems like it should be wrapped somewhere? 
fsp_target = partial( - connect_streams, casc=casc, mkt=mkt, - flume=flume, quote_stream=flume.stream, - # shm + # flumes and shm passthrough src=src, dst=dst, # chain function which takes src flume input(s) # and renders dst flume output(s) - func=func + edge_func=func ) async with casc.open_edge( bind_func=fsp_target, ) as index: # casc.bind_func = fsp_target # index = await tn.start(fsp_target) + dst_shm: ShmArray = dst.rt_shm + src_shm: ShmArray = src.rt_shm if zero_on_step: - last = dst.array[-1:] + last = dst.rt_shm.array[-1:] zeroed = np.zeros(last.shape, dtype=last.dtype) profiler(f'{func_name}: fsp up') - # sync client + # sync to client-side actor await ctx.started(index) # XXX: rt stream with client which we MUST @@ -532,24 +548,26 @@ async def cascade( # incremental "updates" as history prepends take # place. async with ctx.open_stream() as client_stream: - casc.client_stream = client_stream + casc.client_stream: tractor.MsgStream = client_stream - s, step, ld = casc.is_synced() #src, dst) + s, step, ld = casc.is_synced() # detect sample period step for subscription to increment # signal - times = src.array['time'] + times = src.rt_shm.array['time'] if len(times) > 1: last_ts = times[-1] - delay_s = float(last_ts - times[times != last_ts][-1]) + delay_s: float = float(last_ts - times[times != last_ts][-1]) else: # our default "HFT" sample rate. - delay_s = _default_delay_s + delay_s: float = _default_delay_s # sub and increment the underlying shared memory buffer # on every step msg received from the global `samplerd` # service. - async with open_sample_stream(float(delay_s)) as istream: + async with open_sample_stream( + float(delay_s) + ) as istream: profiler(f'{func_name}: sample stream up') profiler.finish() @@ -560,7 +578,7 @@ async def cascade( # respawn the compute task if the source # array has been updated such that we compute # new history from the (prepended) source. 
- synced, step_diff, _ = casc.is_synced() #src, dst) + synced, step_diff, _ = casc.is_synced() if not synced: step_diff: int = await casc.poll_and_sync_to_step() @@ -570,7 +588,7 @@ async def cascade( continue # read out last shm row, copy and write new row - array = dst.array + array = dst_shm.array # some metrics like vlm should be reset # to zero every step. @@ -579,14 +597,14 @@ async def cascade( else: last = array[-1:].copy() - dst.push(last) + dst.rt_shm.push(last) # sync with source buffer's time step - src_l2 = src.array[-2:] + src_l2 = src_shm.array[-2:] src_li, src_lt = src_l2[-1][['index', 'time']] src_2li, src_2lt = src_l2[-2][['index', 'time']] - dst._array['time'][src_li] = src_lt - dst._array['time'][src_2li] = src_2lt + dst_shm._array['time'][src_li] = src_lt + dst_shm._array['time'][src_2li] = src_2lt # last2 = dst.array[-2:] # if ( diff --git a/piker/ui/_fsp.py b/piker/ui/_fsp.py index 23cec162..58842a73 100644 --- a/piker/ui/_fsp.py +++ b/piker/ui/_fsp.py @@ -390,7 +390,7 @@ class FspAdmin: complete: trio.Event, started: trio.Event, fqme: str, - dst_fsp_flume: Flume, + dst_flume: Flume, conf: dict, target: Fsp, loglevel: str, @@ -408,16 +408,14 @@ class FspAdmin: # chaining entrypoint cascade, + # TODO: can't we just drop this and expect + # far end to read the src flume's .mkt.fqme? # data feed key fqme=fqme, - # TODO: pass `Flume.to_msg()`s here? 
- # mems - src_shm_token=self.flume.rt_shm.token, - dst_shm_token=dst_fsp_flume.rt_shm.token, - - # target - ns_path=ns_path, + src_flume_addr=self.flume.to_msg(), + dst_flume_addr=dst_flume.to_msg(), + ns_path=ns_path, # edge-bind-func loglevel=loglevel, zero_on_step=conf.get('zero_on_step', False), @@ -431,14 +429,14 @@ class FspAdmin: ctx.open_stream() as stream, ): - dst_fsp_flume.stream: tractor.MsgStream = stream + dst_flume.stream: tractor.MsgStream = stream # register output data self._registry[ (fqme, ns_path) ] = ( stream, - dst_fsp_flume.rt_shm, + dst_flume.rt_shm, complete ) @@ -515,7 +513,7 @@ class FspAdmin: broker='piker', _atype='fsp', ) - dst_fsp_flume = Flume( + dst_flume = Flume( mkt=mkt, _rt_shm_token=dst_shm.token, first_quote={}, @@ -543,13 +541,13 @@ class FspAdmin: complete, started, fqme, - dst_fsp_flume, + dst_flume, conf, target, loglevel, ) - return dst_fsp_flume, started + return dst_flume, started async def open_fsp_chart( self,