From a6e921548bc4b343a88d9f3f359b03840fa8a6b3 Mon Sep 17 00:00:00 2001 From: Esmeralda Gallardo Date: Thu, 27 Oct 2022 15:05:41 -0300 Subject: [PATCH] Modified recv_task(): added functionality to restart ws after timeout, modified match msg and added new case to match in case of receiving an error. --- piker/data/_web_bs.py | 66 ++++++++++++++++++++++++++++--------------- 1 file changed, 43 insertions(+), 23 deletions(-) diff --git a/piker/data/_web_bs.py b/piker/data/_web_bs.py index 7a6bd5df..a597f7be 100644 --- a/piker/data/_web_bs.py +++ b/piker/data/_web_bs.py @@ -23,6 +23,7 @@ from itertools import count from types import ModuleType from typing import Any, Optional, Callable, AsyncGenerator import json +import sys import trio import trio_websocket @@ -139,7 +140,7 @@ class NoBsWs: async def open_autorecon_ws( url: str, - # TODO: proper type annot smh + # TODO: proper type cannot smh fixture: Optional[Callable] = None, ) -> AsyncGenerator[tuple[...], NoBsWs]: @@ -169,14 +170,17 @@ class JSONRPCResult(Struct): result: Optional[dict] = None error: Optional[dict] = None - @asynccontextmanager async def open_jsonrpc_session( url: str, start_id: int = 0, response_type: type = JSONRPCResult, request_type: Optional[type] = None, - request_hook: Optional[Callable] = None + request_hook: Optional[Callable] = None, + error_hook: Optional[Callable] = None, + timeout: int = 5, + timeout_hook: Optional[Callable] = None, + timeout_args: list = [], ) -> Callable[[str, dict], dict]: async with ( @@ -221,33 +225,49 @@ async def open_jsonrpc_session( ''' receives every ws message and stores it in its corresponding result field, then sets the event to wakeup original sender tasks. - also, recieves responses to requests originated from the server side. + also recieves responses to requests originated from the server side. + reconnects the tasks after timeout. ''' - async for msg in ws: - match msg: - case { - 'result': _ - }: - msg = response_type(**msg) + with trio.move_on_after(timeout) as cancel_scope: + async for msg in ws: + match msg: + case { + 'result': result, + 'id': mid, + } if res_entry := rpc_results.get(mid): - if msg.id not in rpc_results: + res_entry['result'] = response_type(**msg) + res_entry['event'].set() + + case { + 'result': _, + 'id': mid, + } if not rpc_results.get(mid): log.warning(f'Wasn\'t expecting ws msg: {json.dumps(msg, indent=4)}') - res = rpc_results.setdefault( - msg.id, - {'result': None, 'event': trio.Event()} - ) + case { + 'method': _, + 'params': _, + }: + log.debug(f'Recieved\n{msg}') + if request_hook: + await request_hook(request_type(**msg)) - res['result'] = msg - res['event'].set() + case { + 'error': error + }: + log.warning(f'Recieved\n{error}') + if error_hook: + await error_hook(response_type(**msg)) - case { - 'method': _, - 'params': _ - }: + case _: + log.warning(f'Unhandled JSON-RPC msg!?\n{msg}') - if request_hook: - await request_hook(request_type(**msg)) + if cancel_scope.cancelled_caught: + await ws._connect() + n.start_soon(recv_task) + if timeout_hook: + n.start_soon(timeout_hook, json_rpc, *timeout_args) n.start_soon(recv_task)