diff --git a/piker/brokers/questrade.py b/piker/brokers/questrade.py index 4192c2de..6c0427ed 100644 --- a/piker/brokers/questrade.py +++ b/piker/brokers/questrade.py @@ -5,7 +5,8 @@ import time from datetime import datetime from functools import partial import configparser -from typing import List, Tuple, Dict, Any +from operator import itemgetter +from typing import List, Tuple, Dict, Any, Iterator, NamedTuple import trio from async_generator import asynccontextmanager @@ -24,13 +25,19 @@ log = get_logger(__name__) _refresh_token_ep = 'https://login.questrade.com/oauth2/' _version = 'v1' -_rate_limit = 3 # queries/sec +_rate_limit = 4 # queries/sec class QuestradeError(Exception): "Non-200 OK response code" +class ContractsKey(NamedTuple): + symbol: str + id: int + expiry: datetime + + class _API: """Questrade API endpoints exposed as methods and wrapped with an http session. @@ -61,7 +68,11 @@ class _API: 'symbols', params={'ids': ids, 'names': names}) async def quotes(self, ids: str) -> dict: - return await self._request('markets/quotes', params={'ids': ids}) + quotes = (await self._request( + 'markets/quotes', params={'ids': ids}))['quotes'] + for quote in quotes: + quote['key'] = quote['symbol'] + return quotes async def candles(self, id: str, start: str, end, interval) -> dict: return await self._request(f'markets/candles/{id}', params={}) @@ -79,20 +90,19 @@ class _API: async def option_quotes( self, - contracts: Dict[int, Dict[str, dict]], + contracts: Dict[ContractsKey, Dict[int, dict]], option_ids: List[int] = [], # if you don't want them all ) -> dict: - "Retrieve option chain quotes for all option ids or by filter(s)." + """Retrieve option chain quotes for all option ids or by filter(s). + """ filters = [ { "underlyingId": int(symbol_id), "expiryDate": str(expiry), } # every expiry per symbol id - for symbol_id, expiries in contracts.items() - for expiry in expiries + for (symbol, symbol_id, expiry), bystrike in contracts.items() ] - resp = await self._sess.post( path=f'/markets/quotes/options', json={'filters': filters, 'optionIds': option_ids} @@ -111,9 +121,9 @@ class Client: self.api = _API(self._sess) self._conf = config self.access_data = {} - self.user_data = {} self._reload_config(config) - self._symbol_cache = {} + self._symbol_cache: Dict[str, int] = {} + self._contracts2expiries = {} def _reload_config(self, config=None, **kwargs): log.warn("Reloading access config data") @@ -252,8 +262,7 @@ class Client: """ t2ids = await self.tickers2ids(tickers) ids = ','.join(t2ids.values()) - results = (await self.api.quotes(ids=ids))['quotes'] - quotes = {quote['symbol']: quote for quote in results} + quotes = (await self.api.quotes(ids=ids)) # set None for all symbols not found if len(t2ids) < len(tickers): @@ -266,7 +275,7 @@ class Client: async def symbol2contracts( self, symbol: str - ) -> Tuple[int, Dict[datetime, dict]]: + ) -> Dict[Tuple[str, int, datetime], dict]: """Return option contract for the given symbol. The most useful part is the expiries which can be passed to the option @@ -274,15 +283,18 @@ class Client: """ id = int((await self.tickers2ids([symbol]))[symbol]) contracts = await self.api.option_contracts(id) - return id, { - # convert to native datetime objs for sorting - datetime.fromisoformat(item['expiryDate']): - item for item in contracts + return { + ContractsKey( + symbol=symbol, + id=id, + # convert to native datetime objs for sorting + expiry=datetime.fromisoformat(item['expiryDate'])): + item for item in contracts } async def get_all_contracts( self, - symbols: List[str], + symbols: Iterator[str], # {symbol_id: {dt_iso_contract: {strike_price: {contract_id: id}}}} ) -> Dict[int, Dict[str, Dict[int, Any]]]: """Look up all contracts for each symbol in ``symbols`` and return the @@ -293,21 +305,29 @@ class Client: per symbol) and thus the return values should be cached for use with ``option_chains()``. """ - by_id = {} + by_key = {} for symbol in symbols: - id, contracts = await self.symbol2contracts(symbol) - by_id[id] = { - dt.isoformat(timespec='microseconds'): { + contracts = await self.symbol2contracts(symbol) + # FIXME: chainPerRoot here is probably why in some UIs + # you see a second chain with a (1) suffixed; should + # probably handle this eventually. + for key, byroot in sorted( + # sort by datetime + contracts.items(), + key=lambda item: item[0].expiry + ): + by_key[ + ContractsKey( + key.symbol, + key.id, + # converting back - maybe just do this initially? + key.expiry.isoformat(timespec='microseconds'), + ) + ] = { item['strikePrice']: item for item in byroot['chainPerRoot'][0]['chainPerStrikePrice'] } - for dt, byroot in sorted( - # sort by datetime - contracts.items(), - key=lambda item: item[0] - ) - } - return by_id + return by_key async def option_chains( self, @@ -316,12 +336,14 @@ class Client: ) -> Dict[str, Dict[str, Dict[str, Any]]]: """Return option chain snap quote for each ticker in ``symbols``. """ - quotes = await self.api.option_quotes(contracts) - batch = {} - for quote in quotes: - batch.setdefault( - quote['underlying'], {} - )[quote['symbol']] = quote + batch = [] + for key, bystrike in contracts.items(): + quotes = await self.api.option_quotes({key: bystrike}) + for quote in quotes: + # index by .symbol, .expiry since that's what + # a subscriber (currently) sends initially + quote['key'] = (key[0], key[2]) + batch.extend(quotes) return batch @@ -391,15 +413,14 @@ async def get_client() -> Client: write_conf(client) -async def quoter(client: Client, tickers: List[str]): - """Stock Quoter context. +async def stock_quoter(client: Client, tickers: List[str]): + """Stock quoter context. Yeah so fun times..QT has this symbol to ``int`` id lookup system that you have to use to get any quotes. That means we try to be smart and maintain a cache of this map lazily as requests from in for new tickers/symbols. Most of the closure variables here are to deal with that. """ - @async_lifo_cache(maxsize=128) async def get_symbol_id_seq(symbols: Tuple[str]): """For each tuple ``(symbol_1, symbol_2, ... , symbol_n)`` @@ -411,6 +432,7 @@ async def quoter(client: Client, tickers: List[str]): """Query for quotes using cached symbol ids. """ if not tickers: + # don't hit the network return {} ids = await get_symbol_id_seq(tuple(tickers)) @@ -418,41 +440,88 @@ async def quoter(client: Client, tickers: List[str]): try: quotes_resp = await client.api.quotes(ids=ids) except (QuestradeError, BrokerError) as qterr: - if "Access token is invalid" in str(qterr.args[0]): - # out-of-process piker actor may have - # renewed already.. - client._reload_config() - try: - quotes_resp = await client.api.quotes(ids=ids) - except BrokerError as qterr: - if "Access token is invalid" in str(qterr.args[0]): - # TODO: this will crash when run from a sub-actor since - # STDIN can't be acquired. The right way to handle this - # is to make a request to the parent actor (i.e. - # spawner of this) to call this - # `client.ensure_access()` locally thus blocking until - # the user provides an API key on the "client side" - await client.ensure_access(force_refresh=True) - quotes_resp = await client.api.quotes(ids=ids) - else: + if "Access token is invalid" not in str(qterr.args[0]): raise + # out-of-process piker actor may have + # renewed already.. + client._reload_config() + try: + quotes_resp = await client.api.quotes(ids=ids) + except BrokerError as qterr: + if "Access token is invalid" in str(qterr.args[0]): + # TODO: this will crash when run from a sub-actor since + # STDIN can't be acquired. The right way to handle this + # is to make a request to the parent actor (i.e. + # spawner of this) to call this + # `client.ensure_access()` locally thus blocking until + # the user provides an API key on the "client side" + await client.ensure_access(force_refresh=True) + quotes_resp = await client.api.quotes(ids=ids) - # dict packing and post-processing - quotes = {} - for quote in quotes_resp['quotes']: - quotes[quote['symbol']] = quote - + # post-processing + for quote in quotes_resp: if quote.get('delay', 0) > 0: log.warn(f"Delayed quote:\n{quote}") - return quotes + return quotes_resp - # strip out unknown/invalid symbols - first_quotes_dict = await get_quote(tickers) - for symbol, quote in first_quotes_dict.items(): - if quote['low52w'] is None: - log.warn( - f"{symbol} seems to be defunct") + return get_quote + + +async def option_quoter(client: Client, tickers: List[str]): + """Option quoter context. + """ + # sanity + if isinstance(tickers[0], tuple): + datetime.fromisoformat(tickers[0][1]) + else: + log.warn(f"Ignoring option quoter call with {tickers}") + # TODO make caller always check that a quoter has been set + return + + @async_lifo_cache(maxsize=128) + async def get_contract_by_date(sym_date_pairs: Tuple[Tuple[str, str]]): + """For each tuple, + ``(symbol_date_1, symbol_date_2, ... , symbol_date_n)`` + return a contract dict. + """ + symbols = map(itemgetter(0), sym_date_pairs) + dates = map(itemgetter(1), sym_date_pairs) + contracts = await client.get_all_contracts(symbols) + selected = {} + for key, val in contracts.items(): + if key.expiry in dates: + selected[key] = val + + return selected + + async def get_quote(symbol_date_pairs): + """Query for quotes using cached symbol ids. + """ + contracts = await get_contract_by_date( + tuple(symbol_date_pairs)) + try: + quotes = await client.option_chains(contracts) + except (QuestradeError, BrokerError) as qterr: + if "Access token is invalid" not in str(qterr.args[0]): + raise + # out-of-process piker actor may have + # renewed already.. + client._reload_config() + try: + quotes = await client.option_chains(contracts) + except BrokerError as qterr: + if "Access token is invalid" in str(qterr.args[0]): + # TODO: this will crash when run from a sub-actor since + # STDIN can't be acquired. The right way to handle this + # is to make a request to the parent actor (i.e. + # spawner of this) to call this + # `client.ensure_access()` locally thus blocking until + # the user provides an API key on the "client side" + await client.ensure_access(force_refresh=True) + quotes = await client.option_chains(contracts) + + return quotes return get_quote