import io import json import asyncio import logging from abc import ABC, abstractmethod, abstractproperty from PIL import Image, UnidentifiedImageError from random import randint from decimal import Decimal from hashlib import sha256 from datetime import datetime, timedelta import msgspec from leap import CLEOS from leap.hyperion import HyperionAPI from skynet.ipfs import AsyncIPFSHTTP, get_ipfs_file from skynet.types import BodyV0, BodyV0Params from skynet.config import FrontendConfig from skynet.constants import ( MODELS, GPU_CONTRACT_ABI, HELP_TEXT, HELP_TOPICS, HELP_UNKWNOWN_PARAM, COOL_WORDS, DONATION_INFO, UNKNOWN_CMD_TEXT ) from skynet.frontend import validate_user_config_request from skynet.frontend.chatbot.db import FrontendUserDB from skynet.frontend.chatbot.types import ( BaseUser, BaseChatRoom, BaseCommands, BaseFileInput, BaseMessage ) def perform_auto_conf(config: dict) -> dict: model = MODELS[config['model']] maybe_step = model.attrs.get('step', None) if maybe_step: config['step'] = maybe_step maybe_width = model.attrs.get('width', None) if maybe_width: config['width'] = maybe_step maybe_height = model.attrs.get('height', None) if maybe_height: config['height'] = maybe_step return config def sanitize_params(params: dict) -> dict: if ( 'seed' not in params or params['seed'] is None ): params['seed'] = randint(0, 0xffffffff) s_params = {} for key, val in params.items(): if isinstance(val, Decimal): val = str(val) s_params[key] = val return s_params class RequestTimeoutError(BaseException): ... class BaseChatbot(ABC): def __init__( self, config: FrontendConfig, db: FrontendUserDB ): self.db = db self.config = config self.ipfs = AsyncIPFSHTTP(config.ipfs_url) self.cleos = CLEOS(endpoint=config.node_url) self.cleos.load_abi(config.receiver, GPU_CONTRACT_ABI) self.cleos.import_key(config.account, config.key) self.hyperion = HyperionAPI(config.hyperion_url) async def init(self): ... @abstractmethod async def run(self): ... @abstractproperty def main_group(self) -> BaseChatRoom: ... @abstractmethod async def new_msg(self, chat: BaseChatRoom, text: str, **kwargs) -> BaseMessage: ''' Send text to a chat/channel. ''' ... @abstractmethod async def reply_to(self, msg: BaseMessage, text: str, **kwargs) -> BaseMessage: ''' Reply to existing message by sending new message. ''' ... @abstractmethod async def edit_msg(self, msg: BaseMessage, text: str, **kwargs): ''' Edit an existing message. ''' ... async def create_status_msg(self, msg: BaseMessage, init_text: str, force_user: BaseUser | None = None) -> tuple[BaseUser, BaseMessage, dict]: # maybe init user user = msg.author if force_user: user = force_user user_row = await self.db.get_or_create_user(user.id) # create status msg status_msg = await self.reply_to(msg, init_text) # start tracking of request in db await self.db.new_user_request(user.id, msg.id, status_msg.id, status=init_text) return [user, status_msg, user_row] async def update_status_msg(self, msg: BaseMessage, text: str): ''' Update an existing status message, also mirrors changes on db ''' await self.db.update_user_request_by_sid(msg.id, text) await self.edit_msg(msg, text) async def append_status_msg(self, msg: BaseMessage, text: str): ''' Append text to an existing status message ''' request = await self.db.get_user_request_by_sid(msg.id) await self.update_status_msg(msg, request['status'] + text) @abstractmethod async def update_request_status_timeout(self, status_msg: BaseMessage): ''' Notify users when we timedout trying to find a matching submit ''' ... @abstractmethod async def update_request_status_step_0(self, status_msg: BaseMessage, user_msg: BaseMessage): ''' First step in request status message lifecycle, should notify which user sent the request and that we are about to broadcast the request to chain ''' ... @abstractmethod async def update_request_status_step_1(self, status_msg: BaseMessage, tx_result: dict): ''' Second step in request status message lifecycle, should notify enqueue transaction was processed by chain, and provide a link to the tx in the chain explorer ''' ... @abstractmethod async def update_request_status_step_2(self, status_msg: BaseMessage, submit_tx_hash: str): ''' Third step in request status message lifecycle, should notify matching submit transaction was found, and provide a link to the tx in the chain explorer ''' ... @abstractmethod async def update_request_status_final( self, og_msg: BaseMessage, status_msg: BaseMessage, user: BaseUser, params: BodyV0Params, inputs: list[BaseFileInput], submit_tx_hash: str, worker: str, result_url: str, result_img: bytes | None ): ''' Last step in request status message lifecycle, should delete status message and send a new message replying to the original user's message, generate the appropiate reply caption and if provided also sent the found result img ''' ... async def handle_request( self, msg: BaseMessage, force_user: BaseUser | None = None ): if msg.chat.is_private: return if ( len(msg.text) == 0 and msg.command != BaseCommands.REDO ): await self.reply_to(msg, 'empty prompt ignored.') return # maybe initialize user db row and send a new msg thats gonna # be updated throughout the request lifecycle user, status_msg, user_row = await self.create_status_msg( msg, f'started processing a {msg.command} request...', force_user=force_user) # if this is a redo msg, we attempt to get the input params from db # else use msg properties match msg.command: case BaseCommands.TXT2IMG | BaseCommands.IMG2IMG: prompt = msg.text command = msg.command inputs = msg.inputs case BaseCommands.REDO: prompt = await self.db.get_last_prompt_of(user.id) command = await self.db.get_last_method_of(user.id) inputs = await self.db.get_last_inputs_of(user.id) if not prompt: await self.reply_to(msg, 'no last prompt found, try doing a non-redo request first') return case _: await self.reply_to(msg, f'unknown request of type {msg.command}') return if ( msg.command == BaseCommands.IMG2IMG and len(inputs) == 0 ): await self.edit_msg(status_msg, 'seems you tried to do an img2img command without sending image') return # maybe apply recomended settings to this request del user_row['id'] if user_row['autoconf']: user_row = perform_auto_conf(user_row) user_row = sanitize_params(user_row) body = BodyV0( method=command, params=BodyV0Params( prompt=prompt, **user_row ) ) # publish inputs to ipfs input_cids = [] for i in inputs: await i.publish(self.ipfs, user_row) input_cids.append(i.cid) inputs_str = ','.join((i for i in input_cids)) # unless its a redo request, update db user data if command != BaseCommands.REDO: await self.db.update_user_stats( user.id, command, last_prompt=prompt, last_inputs=inputs ) await self.update_request_status_step_0(status_msg, msg) # prepare and send enqueue request request_time = datetime.now().isoformat() str_body = msgspec.json.encode(body).decode('utf-8') enqueue_receipt = await self.cleos.a_push_action( self.config.receiver, 'enqueue', [ self.config.account, str_body, inputs_str, self.config.reward, 1 ], self.config.account, key=self.cleos.private_keys[self.config.account], permission=self.config.permission ) await self.update_request_status_step_1(status_msg, enqueue_receipt) # wait and search submit request using hyperion endpoint console = enqueue_receipt['processed']['action_traces'][0]['console'] console_lines = console.split('\n') request_id = None request_hash = None if self.config.proto_version == 0: ''' v0 has req_id:nonce printed in enqueue console output to search for a result request_hash arg on submit has to match the sha256 of nonce + body + input_str ''' request_id, nonce = console_lines[-1].rstrip().split(':') request_hash = sha256( (nonce + str_body + inputs_str).encode('utf-8')).hexdigest().upper() request_id = int(request_id) elif self.config.proto_version == 1: ''' v1 uses a global unique nonce and prints it on enqueue console output to search for a result request_id arg on submit has to match the printed req_id ''' request_id = int(console_lines[-1].rstrip()) else: raise NotImplementedError worker = None submit_tx_hash = None result_cid = None for i in range(1, self.config.request_timeout + 1): try: submits = await self.hyperion.aget_actions( account=self.config.account, filter=f'{self.config.receiver}:submit', sort='desc', after=request_time ) if self.config.proto_version == 0: actions = [ action for action in submits['actions'] if action['act']['data']['request_hash'] == request_hash ] elif self.config.proto_version == 1: actions = [ action for action in submits['actions'] if action['act']['data']['request_id'] == request_id ] else: raise NotImplementedError if len(actions) > 0: action = actions[0] submit_tx_hash = action['trx_id'] data = action['act']['data'] result_cid = data['ipfs_hash'] worker = data['worker'] logging.info(f'found matching submit! tx: {submit_tx_hash} cid: {result_cid}') break except json.JSONDecodeError: if i < self.config.request_timeout: logging.error('network error while searching for submit, retry...') await asyncio.sleep(1) # if we found matching submit submit_tx_hash, worker, and result_cid will not be None if not result_cid: await self.update_request_status_timeout(status_msg) raise RequestTimeoutError await self.update_request_status_step_2(status_msg, submit_tx_hash) # attempt to get the image and send it result_link = f'https://{self.config.ipfs_domain}/ipfs/{result_cid}' get_img_response = await get_ipfs_file(result_link) result_img = None if get_img_response and get_img_response.status_code == 200: try: with Image.open(io.BytesIO(get_img_response.read())) as img: w, h = img.size if ( w > self.config.result_max_width or h > self.config.result_max_height ): max_size = (self.config.result_max_width, self.config.result_max_height) logging.warning( f'raw result is of size {img.size}, resizing to {max_size}') img.thumbnail(max_size) tmp_buf = io.BytesIO() img.save(tmp_buf, format='PNG') result_img = tmp_buf.getvalue() except UnidentifiedImageError: logging.warning(f'couldn\'t get ipfs result at {result_link}!') await self.update_request_status_final( msg, status_msg, user, body.params, inputs, submit_tx_hash, worker, result_link, result_img) await self.db.increment_generated(user.id) async def send_help(self, msg: BaseMessage): if len(msg.text) == 0: await self.reply_to(msg, HELP_TEXT) else: if msg.text in HELP_TOPICS: await self.reply_to(msg, HELP_TOPICS[msg.text]) else: await self.reply_to(msg, HELP_UNKWNOWN_PARAM) async def send_cool_words(self, msg: BaseMessage): await self.reply_to(msg, '\n'.join(COOL_WORDS)) async def get_queue(self, msg: BaseMessage): an_hour_ago = datetime.now() - timedelta(hours=1) queue = await self.cleos.aget_table( self.config.receiver, self.config.receiver, 'queue', index_position=2, key_type='i64', sort='desc', lower_bound=int(an_hour_ago.timestamp()) ) await self.reply_to( msg, f'Requests on skynet queue: {len(queue)}') async def set_config(self, msg: BaseMessage): try: attr, val, reply_txt = validate_user_config_request(msg.text) await self.db.update_user_config(msg.author.id, attr, val) except BaseException as e: reply_txt = str(e) finally: await self.reply_to(msg, reply_txt) async def user_stats(self, msg: BaseMessage): await self.db.get_or_create_user(msg.author.id) generated, joined, role = await self.db.get_user_stats(msg.author.id) stats_str = f'generated: {generated}\n' stats_str += f'joined: {joined}\n' stats_str += f'role: {role}\n' await self.reply_to(msg, stats_str) async def donation_info(self, msg: BaseMessage): await self.reply_to(msg, DONATION_INFO) async def say(self, msg: BaseMessage): if ( msg.chat.is_private or not msg.author.is_admin ): return await self.new_msg(self.main_group, msg.text) async def echo_unknown(self, msg: BaseMessage): await self.reply_to(msg, UNKNOWN_CMD_TEXT)