diff --git a/piker/data/feed.py b/piker/data/feed.py index cdd19070..bde8fe72 100644 --- a/piker/data/feed.py +++ b/piker/data/feed.py @@ -34,7 +34,6 @@ import trio from trio_typing import TaskStatus import tractor from pydantic import BaseModel -from fuzzywuzzy import process as fuzzy from ..brokers import get_brokermod from ..log import get_logger, get_console_log @@ -376,38 +375,6 @@ class Feed: yield self._trade_stream - @asynccontextmanager - async def open_symbol_search(self) -> AsyncIterator[dict]: - - open_search = getattr(self.mod, 'open_symbol_search', None) - if open_search is None: - - # just return a pure pass through searcher - async def passthru(text: str) -> Dict[str, Any]: - return text - - self.search = passthru - yield self.search - self.search = None - return - - async with self._brokerd_portal.open_context( - open_search, - ) as (ctx, cache): - - # shield here since we expect the search rpc to be - # cancellable by the user as they see fit. - async with ctx.open_stream(shield=True) as stream: - - async def search(text: str) -> Dict[str, Any]: - await stream.send(text) - return await stream.receive() - - # deliver search func to consumer - self.search = search - yield search - self.search = None - def sym_to_shm_key( broker: str, @@ -417,7 +384,7 @@ def sym_to_shm_key( # cache of brokernames to feeds -_cache: Dict[str, Feed] = {} +_cache: Dict[str, Callable] = {} _cache_lock: trio.Lock = trio.Lock() @@ -434,21 +401,60 @@ def get_multi_search() -> Callable[..., Awaitable]: async def pack_matches( brokername: str, pattern: str, + search: Callable[..., Awaitable[dict]], ) -> None: - matches[brokername] = await feed.search(pattern) + log.debug(f'Searching {brokername} for "{pattern}"') + matches[brokername] = await search(pattern) # TODO: make this an async stream? async with trio.open_nursery() as n: - for (brokername, startup_sym), feed in _cache.items(): - if feed.search: - n.start_soon(pack_matches, brokername, pattern) + for brokername, search in _cache.items(): + n.start_soon(pack_matches, brokername, pattern, search) return matches return multisearcher +@asynccontextmanager +async def open_symbol_search( + brokermod: ModuleType, + brokerd_portal: tractor._portal.Portal, +) -> AsyncIterator[dict]: + + global _cache + + open_search = getattr(brokermod, 'open_symbol_search', None) + if open_search is None: + + # just return a pure pass through searcher + async def passthru(text: str) -> Dict[str, Any]: + return text + + yield passthru + return + + async with brokerd_portal.open_context( + open_search, + ) as (ctx, cache): + + # shield here since we expect the search rpc to be + # cancellable by the user as they see fit. + async with ctx.open_stream() as stream: + + async def search(text: str) -> Dict[str, Any]: + await stream.send(text) + return await stream.receive() + + # deliver search func to consumer + try: + _cache[brokermod.name] = search + yield search + finally: + _cache.pop(brokermod.name) + + @asynccontextmanager async def open_feed( brokername: str, @@ -536,8 +542,8 @@ async def open_feed( feed._max_sample_rate = max(ohlc_sample_rates) - _cache[(brokername, sym)] = feed - - async with feed.open_symbol_search(): + if brokername in _cache: yield feed - + else: + async with open_symbol_search(mod, feed._brokerd_portal): + yield feed