mirror of https://github.com/skygpu/skynet.git
Compare commits
1 Commits
86ed291875
...
bb1c82eb66
Author | SHA1 | Date |
---|---|---|
|
bb1c82eb66 |
|
@ -210,50 +210,22 @@ def telegram(
|
|||
db_pass: str
|
||||
):
|
||||
import asyncio
|
||||
from .frontend.telegram import SkynetTelegramFrontend
|
||||
from skynet.frontend.chatbot.telegram import TelegramChatbot
|
||||
from skynet.frontend.chatbot.db import FrontendUserDB
|
||||
|
||||
logging.basicConfig(level=loglevel)
|
||||
|
||||
config = load_skynet_toml()
|
||||
tg_token = config.telegram.tg_token
|
||||
|
||||
key = config.telegram.key
|
||||
account = config.telegram.account
|
||||
permission = config.telegram.permission
|
||||
node_url = config.telegram.node_url
|
||||
hyperion_url = config.telegram.hyperion_url
|
||||
|
||||
ipfs_url = config.telegram.ipfs_url
|
||||
|
||||
try:
|
||||
explorer_domain = config.telegram.explorer_domain
|
||||
|
||||
except ConfigParsingError:
|
||||
explorer_domain = DEFAULT_EXPLORER_DOMAIN
|
||||
|
||||
try:
|
||||
ipfs_domain = config.telegram.ipfs_domain
|
||||
|
||||
except ConfigParsingError:
|
||||
ipfs_domain = DEFAULT_IPFS_DOMAIN
|
||||
config = load_skynet_toml().telegram
|
||||
|
||||
async def _async_main():
|
||||
frontend = SkynetTelegramFrontend(
|
||||
tg_token,
|
||||
account,
|
||||
permission,
|
||||
node_url,
|
||||
hyperion_url,
|
||||
db_host, db_user, db_pass,
|
||||
ipfs_url,
|
||||
key=key,
|
||||
explorer_domain=explorer_domain,
|
||||
ipfs_domain=ipfs_domain
|
||||
)
|
||||
|
||||
async with frontend.open():
|
||||
await frontend.bot.infinity_polling()
|
||||
|
||||
async with FrontendUserDB(
|
||||
config.db_user,
|
||||
config.db_pass,
|
||||
config.db_host,
|
||||
config.db_name
|
||||
) as db:
|
||||
bot = TelegramChatbot(config, db)
|
||||
await bot.run()
|
||||
|
||||
asyncio.run(_async_main())
|
||||
|
||||
|
|
|
@ -36,6 +36,13 @@ class FrontendConfig(msgspec.Struct):
|
|||
hyperion_url: str
|
||||
ipfs_url: str
|
||||
token: str
|
||||
ipfs_domain: str = 'ipfs.skygpu.net'
|
||||
explorer_domain: str = 'explorer.skygpu.net'
|
||||
proto_version: int = 0
|
||||
reward: str = '20.0000 GPU'
|
||||
receiver: str = 'gpu.scd'
|
||||
result_max_width: int = 1280
|
||||
result_max_height: int = 1280
|
||||
|
||||
|
||||
class PinnerConfig(msgspec.Struct):
|
||||
|
|
|
@ -14,91 +14,91 @@ MODELS: dict[str, ModelDesc] = {
|
|||
'runwayml/stable-diffusion-v1-5': ModelDesc(
|
||||
short='stable',
|
||||
mem=6,
|
||||
attrs={'size': {'w': 512, 'h': 512}},
|
||||
attrs={'size': {'w': 512, 'h': 512}, 'step': 28},
|
||||
tags=['txt2img']
|
||||
),
|
||||
'stabilityai/stable-diffusion-2-1-base': ModelDesc(
|
||||
short='stable2',
|
||||
mem=6,
|
||||
attrs={'size': {'w': 512, 'h': 512}},
|
||||
attrs={'size': {'w': 512, 'h': 512}, 'step': 28},
|
||||
tags=['txt2img']
|
||||
),
|
||||
'snowkidy/stable-diffusion-xl-base-0.9': ModelDesc(
|
||||
short='stablexl0.9',
|
||||
mem=8.3,
|
||||
attrs={'size': {'w': 1024, 'h': 1024}},
|
||||
attrs={'size': {'w': 1024, 'h': 1024}, 'step': 28},
|
||||
tags=['txt2img']
|
||||
),
|
||||
'Linaqruf/anything-v3.0': ModelDesc(
|
||||
short='hdanime',
|
||||
mem=6,
|
||||
attrs={'size': {'w': 512, 'h': 512}},
|
||||
attrs={'size': {'w': 512, 'h': 512}, 'step': 28},
|
||||
tags=['txt2img']
|
||||
),
|
||||
'hakurei/waifu-diffusion': ModelDesc(
|
||||
short='waifu',
|
||||
mem=6,
|
||||
attrs={'size': {'w': 512, 'h': 512}},
|
||||
attrs={'size': {'w': 512, 'h': 512}, 'step': 28},
|
||||
tags=['txt2img']
|
||||
),
|
||||
'nitrosocke/Ghibli-Diffusion': ModelDesc(
|
||||
short='ghibli',
|
||||
mem=6,
|
||||
attrs={'size': {'w': 512, 'h': 512}},
|
||||
attrs={'size': {'w': 512, 'h': 512}, 'step': 28},
|
||||
tags=['txt2img']
|
||||
),
|
||||
'dallinmackay/Van-Gogh-diffusion': ModelDesc(
|
||||
short='van-gogh',
|
||||
mem=6,
|
||||
attrs={'size': {'w': 512, 'h': 512}},
|
||||
attrs={'size': {'w': 512, 'h': 512}, 'step': 28},
|
||||
tags=['txt2img']
|
||||
),
|
||||
'lambdalabs/sd-pokemon-diffusers': ModelDesc(
|
||||
short='pokemon',
|
||||
mem=6,
|
||||
attrs={'size': {'w': 512, 'h': 512}},
|
||||
attrs={'size': {'w': 512, 'h': 512}, 'step': 28},
|
||||
tags=['txt2img']
|
||||
),
|
||||
'Envvi/Inkpunk-Diffusion': ModelDesc(
|
||||
short='ink',
|
||||
mem=6,
|
||||
attrs={'size': {'w': 512, 'h': 512}},
|
||||
attrs={'size': {'w': 512, 'h': 512}, 'step': 28},
|
||||
tags=['txt2img']
|
||||
),
|
||||
'nousr/robo-diffusion': ModelDesc(
|
||||
short='robot',
|
||||
mem=6,
|
||||
attrs={'size': {'w': 512, 'h': 512}},
|
||||
attrs={'size': {'w': 512, 'h': 512}, 'step': 28},
|
||||
tags=['txt2img']
|
||||
),
|
||||
'black-forest-labs/FLUX.1-schnell': ModelDesc(
|
||||
short='flux',
|
||||
mem=24,
|
||||
attrs={'size': {'w': 1024, 'h': 1024}},
|
||||
attrs={'size': {'w': 1024, 'h': 1024}, 'step': 4},
|
||||
tags=['txt2img']
|
||||
),
|
||||
'black-forest-labs/FLUX.1-Fill-dev': ModelDesc(
|
||||
short='flux-inpaint',
|
||||
mem=24,
|
||||
attrs={'size': {'w': 1024, 'h': 1024}},
|
||||
attrs={'size': {'w': 1024, 'h': 1024}, 'step': 28},
|
||||
tags=['inpaint']
|
||||
),
|
||||
'diffusers/stable-diffusion-xl-1.0-inpainting-0.1': ModelDesc(
|
||||
short='stablexl-inpaint',
|
||||
mem=8.3,
|
||||
attrs={'size': {'w': 1024, 'h': 1024}},
|
||||
attrs={'size': {'w': 1024, 'h': 1024}, 'step': 28},
|
||||
tags=['inpaint']
|
||||
),
|
||||
'prompthero/openjourney': ModelDesc(
|
||||
short='midj',
|
||||
mem=6,
|
||||
attrs={'size': {'w': 512, 'h': 512}},
|
||||
attrs={'size': {'w': 512, 'h': 512}, 'step': 28},
|
||||
tags=['txt2img', 'img2img']
|
||||
),
|
||||
'stabilityai/stable-diffusion-xl-base-1.0': ModelDesc(
|
||||
short='stablexl',
|
||||
mem=8.3,
|
||||
attrs={'size': {'w': 1024, 'h': 1024}},
|
||||
attrs={'size': {'w': 1024, 'h': 1024}, 'step': 28},
|
||||
tags=['txt2img']
|
||||
),
|
||||
}
|
||||
|
@ -225,8 +225,6 @@ Noise is added to the image you use as an init image for img2img, and then the\
|
|||
|
||||
HELP_UNKWNOWN_PARAM = 'don\'t have any info on that.'
|
||||
|
||||
GROUP_ID = -1001541979235
|
||||
|
||||
MP_ENABLED_ROLES = ['god']
|
||||
|
||||
MIN_STEP = 1
|
||||
|
|
|
@ -0,0 +1,448 @@
|
|||
import io
|
||||
import json
|
||||
import asyncio
|
||||
import logging
|
||||
from abc import ABC, abstractmethod, abstractproperty
|
||||
from PIL import Image, UnidentifiedImageError
|
||||
from random import randint
|
||||
from decimal import Decimal
|
||||
from hashlib import sha256
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
import msgspec
|
||||
from leap import CLEOS
|
||||
from leap.hyperion import HyperionAPI
|
||||
|
||||
from skynet.ipfs import AsyncIPFSHTTP, get_ipfs_file
|
||||
from skynet.types import BodyV0, BodyV0Params
|
||||
from skynet.config import FrontendConfig
|
||||
from skynet.constants import (
|
||||
MODELS, GPU_CONTRACT_ABI,
|
||||
HELP_TEXT,
|
||||
HELP_TOPICS,
|
||||
HELP_UNKWNOWN_PARAM,
|
||||
COOL_WORDS,
|
||||
DONATION_INFO
|
||||
)
|
||||
from skynet.frontend import validate_user_config_request
|
||||
from skynet.frontend.chatbot.db import FrontendUserDB
|
||||
from skynet.frontend.chatbot.types import (
|
||||
BaseUser,
|
||||
BaseChatRoom,
|
||||
BaseCommands,
|
||||
BaseFileInput,
|
||||
BaseMessage
|
||||
)
|
||||
|
||||
|
||||
def perform_auto_conf(config: dict) -> dict:
|
||||
model = MODELS[config['model']]
|
||||
|
||||
maybe_step = model.attrs.get('step', None)
|
||||
if maybe_step:
|
||||
config['step'] = maybe_step
|
||||
|
||||
maybe_width = model.attrs.get('width', None)
|
||||
if maybe_width:
|
||||
config['width'] = maybe_step
|
||||
|
||||
maybe_height = model.attrs.get('height', None)
|
||||
if maybe_height:
|
||||
config['height'] = maybe_step
|
||||
|
||||
return config
|
||||
|
||||
|
||||
def sanitize_params(params: dict) -> dict:
|
||||
if 'seed' not in params:
|
||||
params['seed'] = randint(0, 0xffffffff)
|
||||
|
||||
s_params = {}
|
||||
for key, val in params.items():
|
||||
if isinstance(val, Decimal):
|
||||
val = str(val)
|
||||
|
||||
s_params[key] = val
|
||||
|
||||
return s_params
|
||||
|
||||
|
||||
class RequestTimeoutError(BaseException):
|
||||
...
|
||||
|
||||
|
||||
class BaseChatbot(ABC):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: FrontendConfig,
|
||||
db: FrontendUserDB
|
||||
):
|
||||
self.db = db
|
||||
self.config = config
|
||||
self.ipfs = AsyncIPFSHTTP(config.ipfs_url)
|
||||
self.cleos = CLEOS(endpoint=config.node_url)
|
||||
self.cleos.load_abi('gpu.scd', GPU_CONTRACT_ABI)
|
||||
self.hyperion = HyperionAPI(config.hyperion_url)
|
||||
|
||||
async def init(self):
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def run(self):
|
||||
...
|
||||
|
||||
@abstractproperty
|
||||
def main_group(self) -> BaseChatRoom:
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def new_msg(self, chat: BaseChatRoom, text: str, **kwargs) -> BaseMessage:
|
||||
'''
|
||||
Send text to a chat/channel.
|
||||
'''
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def reply_to(self, msg: BaseMessage, text: str, **kwargs) -> BaseMessage:
|
||||
'''
|
||||
Reply to existing message by sending new message.
|
||||
'''
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def edit_msg(self, msg: BaseMessage, text: str, **kwargs):
|
||||
'''
|
||||
Edit an existing message.
|
||||
'''
|
||||
...
|
||||
|
||||
async def create_status_msg(self, msg: BaseMessage, init_text: str) -> tuple[BaseUser, BaseMessage, dict]:
|
||||
# maybe init user
|
||||
user = msg.author
|
||||
user_row = await self.db.get_or_create_user(user.id)
|
||||
|
||||
# create status msg
|
||||
status_msg = await self.reply_to(msg, init_text)
|
||||
|
||||
# start tracking of request in db
|
||||
await self.db.new_user_request(user.id, msg.id, status_msg.id, status=init_text)
|
||||
return [user, status_msg, user_row]
|
||||
|
||||
async def update_status_msg(self, msg: BaseMessage, text: str):
|
||||
'''
|
||||
Update an existing status message, also mirrors changes on db
|
||||
'''
|
||||
await self.db.update_user_request_by_sid(msg.id, text)
|
||||
await self.edit_msg(msg, text)
|
||||
|
||||
async def append_status_msg(self, msg: BaseMessage, text: str):
|
||||
'''
|
||||
Append text to an existing status message
|
||||
'''
|
||||
request = await self.db.get_user_request_by_sid(msg.id)
|
||||
await self.update_status_msg(msg, request['status'] + text)
|
||||
|
||||
@abstractmethod
|
||||
async def update_request_status_timeout(self, status_msg: BaseMessage):
|
||||
'''
|
||||
Notify users when we timedout trying to find a matching submit
|
||||
'''
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def update_request_status_step_0(self, status_msg: BaseMessage, user: BaseUser):
|
||||
'''
|
||||
First step in request status message lifecycle, should notify which user sent the request
|
||||
and that we are about to broadcast the request to chain
|
||||
'''
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def update_request_status_step_1(self, status_msg: BaseMessage, tx_result: dict):
|
||||
'''
|
||||
Second step in request status message lifecycle, should notify enqueue transaction
|
||||
was processed by chain, and provide a link to the tx in the chain explorer
|
||||
'''
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def update_request_status_step_2(self, status_msg: BaseMessage, submit_tx_hash: str):
|
||||
'''
|
||||
Third step in request status message lifecycle, should notify matching submit transaction
|
||||
was found, and provide a link to the tx in the chain explorer
|
||||
'''
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def update_request_status_final(
|
||||
self,
|
||||
og_msg: BaseMessage,
|
||||
status_msg: BaseMessage,
|
||||
user: BaseUser,
|
||||
params: BodyV0Params,
|
||||
inputs: list[BaseFileInput],
|
||||
submit_tx_hash: str,
|
||||
worker: str,
|
||||
result_img: bytes | None
|
||||
):
|
||||
'''
|
||||
Last step in request status message lifecycle, should delete status message and send a
|
||||
new message replying to the original user's message, generate the appropiate
|
||||
reply caption and if provided also sent the found result img
|
||||
'''
|
||||
...
|
||||
|
||||
|
||||
async def handle_request(
|
||||
self,
|
||||
msg: BaseMessage
|
||||
):
|
||||
if msg.chat.is_private:
|
||||
return
|
||||
|
||||
if len(msg.text) == 0:
|
||||
await self.reply_to(msg.id, 'empty prompt ignored.')
|
||||
return
|
||||
|
||||
# maybe initialize user db row and send a new msg thats gonna
|
||||
# be updated throughout the request lifecycle
|
||||
user, status_msg, user_row = await self.create_status_msg(
|
||||
msg, 'started processing a {msg.command} request...')
|
||||
|
||||
# if this is a redo msg, we attempt to get the input params from db
|
||||
# else use msg properties
|
||||
match msg.command:
|
||||
case BaseCommands.TXT2IMG | BaseCommands.IMG2IMG:
|
||||
prompt = msg.text
|
||||
command = msg.command
|
||||
inputs = msg.inputs
|
||||
|
||||
case BaseCommands.REDO:
|
||||
prompt = await self.db.get_last_prompt_of(user.id)
|
||||
command = await self.db.get_last_method_of(user.id)
|
||||
inputs = await self.db.get_last_inputs_of(user.id)
|
||||
|
||||
if not prompt:
|
||||
await self.reply_to(msg.id, 'no last prompt found, try doing a non-redo request first')
|
||||
return
|
||||
|
||||
case _:
|
||||
await self.reply_to(msg.id, f'unknown request of type {msg.command}')
|
||||
return
|
||||
|
||||
# maybe apply recomended settings to this request
|
||||
del user_row['id']
|
||||
if user_row['autoconf']:
|
||||
user_row = perform_auto_conf(user_row)
|
||||
|
||||
body = BodyV0(
|
||||
method=command,
|
||||
params=BodyV0Params(
|
||||
prompt=prompt,
|
||||
**user_row
|
||||
)
|
||||
)
|
||||
|
||||
# publish inputs to ipfs
|
||||
input_cids = []
|
||||
for i in inputs:
|
||||
i.publish()
|
||||
input_cids.append(i.cid)
|
||||
|
||||
inputs_str = ','.join((i for i in input_cids))
|
||||
|
||||
# unless its a redo request, update db user data
|
||||
if command != BaseCommands.REDO:
|
||||
await self.db.update_user_stats(
|
||||
user.id,
|
||||
command,
|
||||
last_prompt=prompt,
|
||||
last_inputs=inputs
|
||||
)
|
||||
|
||||
await self.update_request_status_step_0(status_msg)
|
||||
|
||||
# prepare and send enqueue request
|
||||
request_time = datetime.now().isoformat()
|
||||
str_body = msgspec.json.encode(body).decode('utf-8')
|
||||
|
||||
enqueue_receipt = await self.cleos.a_push_action(
|
||||
self.receiver,
|
||||
'enqueue',
|
||||
(
|
||||
self.config.account,
|
||||
str_body,
|
||||
inputs_str,
|
||||
self.config.reward,
|
||||
1
|
||||
)
|
||||
)
|
||||
|
||||
await self.update_request_status_step_1(status_msg, enqueue_receipt)
|
||||
|
||||
# wait and search submit request using hyperion endpoint
|
||||
console = enqueue_receipt['processed']['action_traces'][0]['console']
|
||||
console_lines = console.split('\n')
|
||||
|
||||
request_id = None
|
||||
request_hash = None
|
||||
if self.config.proto_version == 0:
|
||||
'''
|
||||
v0 has req_id:nonce printed in enqueue console output
|
||||
to search for a result request_hash arg on submit has
|
||||
to match the sha256 of nonce + body + input_str
|
||||
'''
|
||||
request_id, nonce = console_lines[-1].rstrip().split(':')
|
||||
request_hash = sha256(
|
||||
(nonce + str_body + inputs_str).encode('utf-8')).hexdigest.upper()
|
||||
|
||||
request_id = int(request_id)
|
||||
|
||||
elif self.config.proto_version == 1:
|
||||
'''
|
||||
v1 uses a global unique nonce and prints it on enqueue
|
||||
console output to search for a result request_id arg
|
||||
on submit has to match the printed req_id
|
||||
'''
|
||||
request_id = int(console_lines[-1].rstrip())
|
||||
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
worker = None
|
||||
submit_tx_hash = None
|
||||
result_cid = None
|
||||
for i in range(1, self.config.request_timeout + 1):
|
||||
try:
|
||||
submits = await self.hyperion.aget_actions(
|
||||
account=self.config.account,
|
||||
filter=f'{self.config.receiver}:submit',
|
||||
sort='desc',
|
||||
after=request_time
|
||||
)
|
||||
if self.config.proto_version == 0:
|
||||
actions = [
|
||||
action
|
||||
for action in submits['actions']
|
||||
if action['act']['data']['request_hash'] == request_hash
|
||||
]
|
||||
elif self.config.proto_version == 1:
|
||||
actions = [
|
||||
action
|
||||
for action in submits['actions']
|
||||
if action['act']['data']['request_id'] == request_id
|
||||
]
|
||||
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
if len(actions) > 0:
|
||||
action = actions[0]
|
||||
submit_tx_hash = action['trx_id']
|
||||
data = action['act']['data']
|
||||
result_cid = data['ipfs_hash']
|
||||
worker = data['worker']
|
||||
logging.info('found matching submit! tx: {submit_tx_hash} cid: {ipfs_hash}')
|
||||
break
|
||||
|
||||
except json.JSONDecodeError:
|
||||
if i < self.config.request_timeout:
|
||||
logging.error('network error while searching for submit, retry...')
|
||||
await asyncio.sleep(1)
|
||||
|
||||
# if we found matching submit submit_tx_hash, worker, and result_cid will not be None
|
||||
if not result_cid:
|
||||
await self.update_request_status_timeout(status_msg)
|
||||
raise RequestTimeoutError
|
||||
|
||||
await self.update_request_status_step_2(status_msg, submit_tx_hash)
|
||||
|
||||
# attempt to get the image and send it
|
||||
result_link = f'https://{self.config.ipfs_domain}/ipfs/{result_cid}'
|
||||
get_img_response = await get_ipfs_file(result_link)
|
||||
|
||||
result_img = None
|
||||
if get_img_response and get_img_response.status_code == 200:
|
||||
try:
|
||||
with Image.open(io.BytesIO(get_img_response.raw)) as img:
|
||||
w, h = img.size
|
||||
|
||||
if (
|
||||
w > self.config.result_max_width
|
||||
or
|
||||
h > self.config.result_max_height
|
||||
):
|
||||
max_size = (self.config.result_max_width, self.config.result_max_height)
|
||||
logging.warning(
|
||||
f'raw result is of size {img.size}, resizing to {max_size}')
|
||||
img.thumbnail(max_size)
|
||||
|
||||
tmp_buf = io.BytesIO()
|
||||
img.save(tmp_buf, format='PNG')
|
||||
result_img = tmp_buf.getvalue()
|
||||
|
||||
except UnidentifiedImageError:
|
||||
logging.warning(f'couldn\'t get ipfs result at {result_link}!')
|
||||
|
||||
await self.update_request_status_final(
|
||||
msg, status_msg, user, body, inputs, submit_tx_hash, worker, result_img)
|
||||
|
||||
await self.db.increment_generated(user.id)
|
||||
|
||||
async def send_help(self, msg: BaseMessage):
|
||||
if len(msg.text) == 0:
|
||||
await self.reply_to(msg, HELP_TEXT)
|
||||
|
||||
else:
|
||||
if msg.text in HELP_TOPICS:
|
||||
await self.reply_to(msg, HELP_TOPICS[msg.text])
|
||||
|
||||
else:
|
||||
await self.reply_to(msg, HELP_UNKWNOWN_PARAM)
|
||||
|
||||
async def send_cool_words(self, msg: BaseMessage):
|
||||
await self.reply_to(msg, '\n'.join(COOL_WORDS))
|
||||
|
||||
async def get_queue(self, msg: BaseMessage):
|
||||
an_hour_ago = datetime.now() - timedelta(hours=1)
|
||||
queue = await self.cleos.aget_table(
|
||||
self.config.receiver, self.config.receiver, 'queue',
|
||||
index_position=2,
|
||||
key_type='i64',
|
||||
sort='desc',
|
||||
lower_bound=int(an_hour_ago.timestamp())
|
||||
)
|
||||
await self.reply_to(
|
||||
msg, f'Requests on skynet queue: {len(queue)}')
|
||||
|
||||
async def set_config(self, msg: BaseMessage):
|
||||
try:
|
||||
attr, val, reply_txt = validate_user_config_request(msg.text)
|
||||
|
||||
await self.db.update_user_config(msg.author.id, attr, val)
|
||||
|
||||
except BaseException as e:
|
||||
reply_txt = str(e)
|
||||
|
||||
finally:
|
||||
await self.reply_to(msg, reply_txt)
|
||||
|
||||
async def user_stats(self, msg: BaseMessage):
|
||||
await self.db.get_or_create_user(msg.author.id)
|
||||
generated, joined, role = await self.db.get_user_stats(msg.author.id)
|
||||
|
||||
stats_str = f'generated: {generated}\n'
|
||||
stats_str += f'joined: {joined}\n'
|
||||
stats_str += f'role: {role}\n'
|
||||
|
||||
await self.reply_to(msg, stats_str)
|
||||
|
||||
async def donation_info(self, msg: BaseMessage):
|
||||
await self.reply_to(msg, DONATION_INFO)
|
||||
|
||||
async def say(self, msg: BaseMessage):
|
||||
if not msg.chat.is_private or not msg.author.is_admin:
|
||||
return
|
||||
|
||||
await self.new_msg(self.main_group, msg.text)
|
|
@ -0,0 +1,440 @@
|
|||
import logging
|
||||
import random
|
||||
import string
|
||||
import time
|
||||
from datetime import datetime
|
||||
|
||||
import docker
|
||||
import psycopg2
|
||||
import asyncpg
|
||||
|
||||
from psycopg2.extensions import ISOLATION_LEVEL_AUTOCOMMIT
|
||||
from contextlib import contextmanager as cm
|
||||
|
||||
from skynet.constants import (
|
||||
DEFAULT_ROLE, DEFAULT_MODEL, DEFAULT_STEP,
|
||||
DEFAULT_WIDTH, DEFAULT_HEIGHT, DEFAULT_GUIDANCE,
|
||||
DEFAULT_STRENGTH, DEFAULT_UPSCALER
|
||||
)
|
||||
from skynet.frontend.chatbot.types import BaseFileInput
|
||||
|
||||
DB_INIT_SQL = """
|
||||
CREATE SCHEMA IF NOT EXISTS skynet;
|
||||
|
||||
CREATE TABLE IF NOT EXISTS skynet.user(
|
||||
id BIGSERIAL PRIMARY KEY NOT NULL,
|
||||
generated INT NOT NULL,
|
||||
joined TIMESTAMP NOT NULL,
|
||||
last_method TEXT,
|
||||
last_prompt TEXT,
|
||||
last_inputs TEXT,
|
||||
role VARCHAR(128) NOT NULL
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS skynet.user_config(
|
||||
id BIGSERIAL NOT NULL,
|
||||
model VARCHAR(512) NOT NULL,
|
||||
step INT NOT NULL,
|
||||
width INT NOT NULL,
|
||||
height INT NOT NULL,
|
||||
seed NUMERIC,
|
||||
guidance DECIMAL NOT NULL,
|
||||
strength DECIMAL NOT NULL,
|
||||
upscaler VARCHAR(128),
|
||||
autoconf BOOLEAN DEFAULT TRUE,
|
||||
CONSTRAINT fk_config
|
||||
FOREIGN KEY(id)
|
||||
REFERENCES skynet.user(id)
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS skynet.user_requests(
|
||||
id BIGSERIAL NOT NULL,
|
||||
user_id BIGSERIAL NOT NULL,
|
||||
sent TIMESTAMP NOT NULL,
|
||||
status TEXT NOT NULL,
|
||||
status_msg BIGSERIAL PRIMARY KEY NOT NULL,
|
||||
CONSTRAINT fk_user_req
|
||||
FOREIGN KEY(user_id)
|
||||
REFERENCES skynet.user(id)
|
||||
);
|
||||
"""
|
||||
|
||||
def try_decode_uid(uid: str) -> tuple[str | None, int | None]:
|
||||
"""
|
||||
Attempts to decode the user ID. The user ID can be just an integer
|
||||
or of the format 'proto+uid'. Returns (None, int) if it's just an
|
||||
integer or (proto, int) if it's 'proto+uid'. Returns (None, None)
|
||||
if neither format is valid.
|
||||
"""
|
||||
try:
|
||||
return None, int(uid)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
try:
|
||||
proto, uid_str = uid.split("+", 1)
|
||||
return proto, int(uid_str)
|
||||
except ValueError:
|
||||
logging.warning(f"Got non-chat-proto UID?: {uid}")
|
||||
return None, None
|
||||
|
||||
|
||||
@cm
|
||||
def open_new_database(cleanup: bool = True):
|
||||
"""
|
||||
Context manager that spins up a temporary Postgres Docker container,
|
||||
creates a 'skynet' user and database, and yields (container, password, host).
|
||||
Stops the container on exit if 'cleanup' is True.
|
||||
"""
|
||||
root_password = "".join(random.choice(string.ascii_lowercase) for _ in range(12))
|
||||
skynet_password = "".join(random.choice(string.ascii_lowercase) for _ in range(12))
|
||||
|
||||
dclient = docker.from_env()
|
||||
container = dclient.containers.run(
|
||||
"postgres",
|
||||
name="skynet-test-postgres",
|
||||
ports={"5432/tcp": None},
|
||||
environment={"POSTGRES_PASSWORD": root_password},
|
||||
detach=True,
|
||||
)
|
||||
|
||||
try:
|
||||
# Wait for Postgres to be ready
|
||||
for log_line in container.logs(stream=True):
|
||||
line = log_line.decode().rstrip()
|
||||
logging.info(line)
|
||||
if (
|
||||
"database system is ready to accept connections" in line
|
||||
or "database system is shut down" in line
|
||||
):
|
||||
break
|
||||
|
||||
container.reload()
|
||||
port_info = container.ports["5432/tcp"][0]
|
||||
port = port_info["HostPort"]
|
||||
db_host = f"localhost:{port}"
|
||||
|
||||
# Let PostgreSQL settle
|
||||
time.sleep(1)
|
||||
logging.info("Creating 'skynet' database...")
|
||||
|
||||
with psycopg2.connect(
|
||||
user="postgres", password=root_password, host="localhost", port=port
|
||||
) as conn:
|
||||
conn.set_isolation_level(ISOLATION_LEVEL_AUTOCOMMIT)
|
||||
with conn.cursor() as cursor:
|
||||
cursor.execute(f"CREATE USER skynet WITH PASSWORD '{skynet_password}'")
|
||||
cursor.execute("CREATE DATABASE skynet")
|
||||
cursor.execute("GRANT ALL PRIVILEGES ON DATABASE skynet TO skynet")
|
||||
|
||||
logging.info("Database setup complete.")
|
||||
yield container, skynet_password, db_host
|
||||
|
||||
finally:
|
||||
if container and cleanup:
|
||||
container.stop()
|
||||
|
||||
|
||||
class FrontendUserDB:
|
||||
"""
|
||||
A class that manages the connection pool for the 'skynet' database,
|
||||
initializes the schema if needed, and provides high-level methods
|
||||
for interacting with the 'skynet' tables.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
db_user: str,
|
||||
db_pass: str,
|
||||
db_host: str,
|
||||
db_name: str
|
||||
):
|
||||
self.db_user = db_user
|
||||
self.db_pass = db_pass
|
||||
self.db_host = db_host
|
||||
self.db_name = db_name
|
||||
self.pool: asyncpg.Pool | None = None
|
||||
|
||||
async def __aenter__(self) -> "FrontendUserDB":
|
||||
dsn = f"postgres://{self.db_user}:{self.db_pass}@{self.db_host}/{self.db_name}"
|
||||
self.pool = await asyncpg.create_pool(dsn=dsn)
|
||||
await self._init_db()
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
if self.pool:
|
||||
await self.pool.close()
|
||||
|
||||
async def _init_db(self):
|
||||
"""
|
||||
Ensures the 'skynet' schema and tables exist. Also checks for
|
||||
missing columns and adds them if necessary.
|
||||
"""
|
||||
async with self.pool.acquire() as conn:
|
||||
# Check if schema is already initialized
|
||||
result = await conn.fetch("""
|
||||
SELECT DISTINCT table_schema
|
||||
FROM information_schema.tables
|
||||
WHERE table_schema = 'skynet'
|
||||
""")
|
||||
if not result:
|
||||
await conn.execute(DB_INIT_SQL)
|
||||
|
||||
# Check if 'autoconf' column exists in user_config
|
||||
col_check = await conn.fetch("""
|
||||
SELECT column_name
|
||||
FROM information_schema.columns
|
||||
WHERE table_name = 'user_config' AND column_name = 'autoconf'
|
||||
""")
|
||||
if not col_check:
|
||||
await conn.execute(
|
||||
"ALTER TABLE skynet.user_config ADD COLUMN autoconf BOOLEAN DEFAULT TRUE;"
|
||||
)
|
||||
|
||||
# -------------
|
||||
# USER METHODS
|
||||
# -------------
|
||||
|
||||
async def get_user_config(self, user_id: int):
|
||||
"""
|
||||
Fetches the user_config for the given user ID.
|
||||
Returns the record if found, otherwise None.
|
||||
"""
|
||||
async with self.pool.acquire() as conn:
|
||||
records = await conn.fetch(
|
||||
"SELECT * FROM skynet.user_config WHERE id = $1", user_id
|
||||
)
|
||||
return records[0] if len(records) == 1 else None
|
||||
|
||||
async def get_user(self, user_id: int):
|
||||
"""Alias for get_user_config (same data returned)."""
|
||||
return await self.get_user_config(user_id)
|
||||
|
||||
async def new_user(self, user_id: int):
|
||||
"""
|
||||
Inserts a new user in skynet.user and its corresponding user_config record.
|
||||
Raises ValueError if the user already exists.
|
||||
"""
|
||||
existing = await self.get_user(user_id)
|
||||
if existing:
|
||||
raise ValueError("User already present in DB")
|
||||
|
||||
logging.info(f"New user! {user_id}")
|
||||
now = datetime.utcnow()
|
||||
|
||||
async with self.pool.acquire() as conn:
|
||||
async with conn.transaction():
|
||||
await conn.execute(
|
||||
"""
|
||||
INSERT INTO skynet.user(
|
||||
id, generated, joined,
|
||||
last_method, last_prompt, last_inputs, role
|
||||
)
|
||||
VALUES($1, 0, $2, 'txt2img', NULL, NULL, $3)
|
||||
""",
|
||||
user_id,
|
||||
now,
|
||||
DEFAULT_ROLE,
|
||||
)
|
||||
await conn.execute(
|
||||
"""
|
||||
INSERT INTO skynet.user_config(
|
||||
id, model, step, width,
|
||||
height, guidance, strength, upscaler
|
||||
)
|
||||
VALUES($1, $2, $3, $4, $5, $6, $7, $8)
|
||||
""",
|
||||
user_id,
|
||||
DEFAULT_MODEL,
|
||||
DEFAULT_STEP,
|
||||
DEFAULT_WIDTH,
|
||||
DEFAULT_HEIGHT,
|
||||
DEFAULT_GUIDANCE,
|
||||
DEFAULT_STRENGTH,
|
||||
DEFAULT_UPSCALER,
|
||||
)
|
||||
|
||||
async def get_or_create_user(self, user_id: int):
|
||||
"""
|
||||
Retrieves a user_config record for the given user_id.
|
||||
If none exists, creates the user and returns the new record.
|
||||
"""
|
||||
user_cfg = await self.get_user(user_id)
|
||||
if not user_cfg:
|
||||
await self.new_user(user_id)
|
||||
user_cfg = await self.get_user(user_id)
|
||||
return user_cfg
|
||||
|
||||
async def update_user(self, user_id: int, attr: str, val):
|
||||
"""
|
||||
Generic function to update a single field in skynet.user for a given user_id.
|
||||
"""
|
||||
async with self.pool.acquire() as conn:
|
||||
await conn.execute(
|
||||
f"UPDATE skynet.user SET {attr} = $2 WHERE id = $1", user_id, val
|
||||
)
|
||||
|
||||
async def update_user_config(self, user_id: int, attr: str, val):
|
||||
"""
|
||||
Generic function to update a single field in skynet.user_config for a given user_id.
|
||||
"""
|
||||
async with self.pool.acquire() as conn:
|
||||
await conn.execute(
|
||||
f"UPDATE skynet.user_config SET {attr} = $2 WHERE id = $1", user_id, val
|
||||
)
|
||||
|
||||
async def get_user_stats(self, user_id: int):
|
||||
"""
|
||||
Returns (generated, joined, role) for the given user_id.
|
||||
"""
|
||||
async with self.pool.acquire() as conn:
|
||||
records = await conn.fetch(
|
||||
"""
|
||||
SELECT generated, joined, role
|
||||
FROM skynet.user
|
||||
WHERE id = $1
|
||||
""",
|
||||
user_id,
|
||||
)
|
||||
return records[0] if records else None
|
||||
|
||||
async def increment_generated(self, user_id: int):
|
||||
"""
|
||||
Increments the 'generated' count for a given user by 1.
|
||||
"""
|
||||
async with self.pool.acquire() as conn:
|
||||
await conn.execute(
|
||||
"""
|
||||
UPDATE skynet.user
|
||||
SET generated = generated + 1
|
||||
WHERE id = $1
|
||||
""",
|
||||
user_id,
|
||||
)
|
||||
|
||||
async def update_user_stats(
|
||||
self,
|
||||
user_id: int,
|
||||
method: str,
|
||||
last_prompt: str | None = None,
|
||||
last_inputs: list | None = None
|
||||
):
|
||||
"""
|
||||
Updates various 'last_*' fields in skynet.user.
|
||||
"""
|
||||
await self.update_user(user_id, "last_method", method)
|
||||
if last_prompt is not None:
|
||||
await self.update_user(user_id, "last_prompt", last_prompt)
|
||||
|
||||
last_inputs_str = None
|
||||
if isinstance(last_inputs, list):
|
||||
last_inputs_str = ','.join((f'{f.id}:{f.cid}' for f in last_inputs))
|
||||
await self.update_user(user_id, "last_inputs", last_inputs_str)
|
||||
|
||||
logging.info("Updated user stats: %s", (method, last_prompt, last_inputs_str))
|
||||
|
||||
# ----------------------
|
||||
# USER REQUESTS METHODS
|
||||
# ----------------------
|
||||
|
||||
async def get_user_request(self, request_id: int):
|
||||
"""
|
||||
Fetches all matching rows for a given request_id.
|
||||
"""
|
||||
async with self.pool.acquire() as conn:
|
||||
return await conn.fetch(
|
||||
"SELECT * FROM skynet.user_requests WHERE id = $1", request_id
|
||||
)
|
||||
|
||||
async def get_user_request_by_sid(self, status_msg_id: int):
|
||||
"""
|
||||
Fetches exactly one row (first row) by status_msg primary key.
|
||||
"""
|
||||
async with self.pool.acquire() as conn:
|
||||
records = await conn.fetch(
|
||||
"SELECT * FROM skynet.user_requests WHERE status_msg = $1", status_msg_id
|
||||
)
|
||||
return records[0] if records else None
|
||||
|
||||
async def new_user_request(
|
||||
self,
|
||||
user_id: int,
|
||||
request_id: int,
|
||||
status_msg_id: int,
|
||||
status: str = "started processing request..."
|
||||
):
|
||||
"""
|
||||
Inserts a new row in skynet.user_requests.
|
||||
"""
|
||||
now = datetime.utcnow()
|
||||
async with self.pool.acquire() as conn:
|
||||
async with conn.transaction():
|
||||
await conn.execute(
|
||||
"""
|
||||
INSERT INTO skynet.user_requests(
|
||||
id, user_id, sent, status, status_msg
|
||||
)
|
||||
VALUES($1, $2, $3, $4, $5)
|
||||
""",
|
||||
request_id, user_id, now, status, status_msg_id
|
||||
)
|
||||
|
||||
async def update_user_request(self, request_id: int, status: str):
|
||||
"""
|
||||
Updates the 'status' for a user request identified by 'request_id'.
|
||||
"""
|
||||
async with self.pool.acquire() as conn:
|
||||
await conn.execute(
|
||||
"""
|
||||
UPDATE skynet.user_requests
|
||||
SET status = $2
|
||||
WHERE id = $1
|
||||
""",
|
||||
request_id, status
|
||||
)
|
||||
|
||||
async def update_user_request_by_sid(self, sid: int, status: str):
|
||||
"""
|
||||
Updates the 'status' for a user request identified by 'status_msg'.
|
||||
"""
|
||||
async with self.pool.acquire() as conn:
|
||||
await conn.execute(
|
||||
"""
|
||||
UPDATE skynet.user_requests
|
||||
SET status = $2
|
||||
WHERE status_msg = $1
|
||||
""",
|
||||
sid, status
|
||||
)
|
||||
|
||||
# ----------------------------
|
||||
# Convenience "Get Last" Helpers
|
||||
# ----------------------------
|
||||
|
||||
async def get_last_method_of(self, user_id: int) -> str | None:
|
||||
async with self.pool.acquire() as conn:
|
||||
return await conn.fetchval(
|
||||
"SELECT last_method FROM skynet.user WHERE id = $1", user_id
|
||||
)
|
||||
|
||||
async def get_last_prompt_of(self, user_id: int) -> str | None:
|
||||
async with self.pool.acquire() as conn:
|
||||
return await conn.fetchval(
|
||||
"SELECT last_prompt FROM skynet.user WHERE id = $1", user_id
|
||||
)
|
||||
|
||||
async def get_last_inputs_of(self, user_id: int) -> list[BaseFileInput] | None:
|
||||
async with self.pool.acquire() as conn:
|
||||
last_inputs_str = await conn.fetchval(
|
||||
"SELECT last_inputs FROM skynet.user WHERE id = $1", user_id
|
||||
)
|
||||
|
||||
if not last_inputs_str:
|
||||
return None
|
||||
|
||||
last_inputs = []
|
||||
for i in last_inputs_str.split(','):
|
||||
id, cid = i.split(':')
|
||||
last_inputs.from_values(id, cid)
|
||||
|
||||
return last_inputs
|
|
@ -0,0 +1,378 @@
|
|||
import traceback
|
||||
|
||||
from typing import Self, Awaitable
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from telebot.types import (
|
||||
AsyncTeleBot,
|
||||
User as TGUser,
|
||||
Chat as TGChat,
|
||||
PhotoSize as TGPhotoSize,
|
||||
Message as TGMessage,
|
||||
InputMediaPhoto,
|
||||
InlineKeyboardButton,
|
||||
InlineKeyboardMarkup
|
||||
)
|
||||
from telebot.async_telebot import ExceptionHandler
|
||||
from telebot.formatting import hlink
|
||||
|
||||
from skynet.types import BodyV0Params
|
||||
from skynet.config import FrontendConfig
|
||||
from skynet.constants import VERSION
|
||||
from skynet.frontend.chatbot import BaseChatbot
|
||||
from skynet.frontend.chatbot.db import FrontendUserDB
|
||||
from skynet.frontend.types import (
|
||||
BaseUser,
|
||||
BaseChatRoom,
|
||||
BaseFileInput,
|
||||
BaseCommands,
|
||||
BaseMessage
|
||||
)
|
||||
|
||||
GROUP_ID = -1001541979235
|
||||
ADMIN_USER_ID = 383385940
|
||||
|
||||
|
||||
# Chatbot types impls
|
||||
|
||||
class TelegramUser(BaseUser):
|
||||
def __init__(self, user: TGUser):
|
||||
self._user = user
|
||||
|
||||
@property
|
||||
def id(self) -> int:
|
||||
return self._user.id
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
if self._user.username:
|
||||
return f'@{self._user.username}'
|
||||
|
||||
return f'{self._user.first_name} id: {self.id}'
|
||||
|
||||
@property
|
||||
def is_admin(self) -> bool:
|
||||
return self.id == ADMIN_USER_ID
|
||||
|
||||
|
||||
class TelegramChatRoom(BaseChatRoom):
|
||||
|
||||
def __init__(self, chat: TGChat):
|
||||
self._chat = chat
|
||||
|
||||
@property
|
||||
def id(self) -> int:
|
||||
return self._chat.id
|
||||
|
||||
@property
|
||||
def is_private(self) -> bool:
|
||||
return self._chat.type == 'private'
|
||||
|
||||
|
||||
class TelegramFileInput(BaseFileInput):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
photo: TGPhotoSize | None = None,
|
||||
id: int | None = None,
|
||||
cid: str | None = None
|
||||
):
|
||||
self._photo = photo
|
||||
self._id = id
|
||||
self._cid = cid
|
||||
|
||||
self._raw = None
|
||||
|
||||
def from_values(id: int, cid: str) -> Self:
|
||||
return TelegramFileInput(id=id, cid=cid)
|
||||
|
||||
@property
|
||||
def id(self) -> int:
|
||||
if self._id:
|
||||
return self._id
|
||||
|
||||
return self._photo.file_id
|
||||
|
||||
@property
|
||||
def cid(self) -> str:
|
||||
if self._cid:
|
||||
return self._cid
|
||||
|
||||
raise ValueError
|
||||
|
||||
async def download(self, bot: AsyncTeleBot) -> bytes:
|
||||
file_path = (await bot.get_file(self.id)).file_path
|
||||
self._raw = await bot.download_file(file_path)
|
||||
return self._raw
|
||||
|
||||
|
||||
class TelegramMessage(BaseMessage):
|
||||
|
||||
def __init__(self, cmd: BaseCommands | None, msg: TGMessage):
|
||||
self._msg = msg
|
||||
self._cmd = cmd
|
||||
self._chat = TelegramChatRoom(msg.chat)
|
||||
|
||||
@property
|
||||
def id(self) -> int:
|
||||
return self._msg.message_id
|
||||
|
||||
@property
|
||||
def chat(self) -> TelegramChatRoom:
|
||||
return self._chat
|
||||
|
||||
@property
|
||||
def text(self) -> str:
|
||||
return self._msg.text[len(self._cmd) + 1:]
|
||||
|
||||
@property
|
||||
def author(self) -> TelegramUser:
|
||||
return TelegramUser(self._msg.from_user)
|
||||
|
||||
@property
|
||||
def command(self) -> str | None:
|
||||
return self._cmd
|
||||
|
||||
@property
|
||||
def inputs(self) -> list[TelegramFileInput]:
|
||||
return [
|
||||
TelegramFileInput(photo=p)
|
||||
for p in self._msg.photo
|
||||
]
|
||||
|
||||
|
||||
# generic tg utils
|
||||
|
||||
def timestamp_pretty():
|
||||
return datetime.now(timezone.utc).strftime('%H:%M:%S')
|
||||
|
||||
|
||||
class TGExceptionHandler(ExceptionHandler):
|
||||
|
||||
def handle(exception):
|
||||
traceback.print_exc()
|
||||
|
||||
|
||||
def build_redo_menu():
|
||||
btn_redo = InlineKeyboardButton("Redo", callback_data='{\"method\": \"redo\"}')
|
||||
inline_keyboard = InlineKeyboardMarkup()
|
||||
inline_keyboard.add(btn_redo)
|
||||
return inline_keyboard
|
||||
|
||||
|
||||
def prepare_metainfo_caption(user: TelegramUser, worker: str, reward: str, params: BodyV0Params) -> str:
|
||||
prompt = params.prompt
|
||||
if len(prompt) > 256:
|
||||
prompt = prompt[:256]
|
||||
|
||||
meta_str = f'<u>by {user.name}</u>\n'
|
||||
meta_str += f'<i>performed by {worker}</i>\n'
|
||||
meta_str += f'<b><u>reward: {reward}</u></b>\n'
|
||||
|
||||
meta_str += f'<code>prompt:</code> {prompt}\n'
|
||||
meta_str += f'<code>seed: {params.seed}</code>\n'
|
||||
meta_str += f'<code>step: {params.step}</code>\n'
|
||||
if params.guidance:
|
||||
meta_str += f'<code>guidance: {params.guidance}</code>\n'
|
||||
|
||||
if params.strength:
|
||||
meta_str += f'<code>strength: {params.strength}</code>\n'
|
||||
|
||||
meta_str += f'<code>algo: {params.model}</code>\n'
|
||||
|
||||
meta_str += f'<b><u>Made with Skynet v{VERSION}</u></b>\n'
|
||||
meta_str += '<b>JOIN THE SWARM: @skynetgpu</b>'
|
||||
return meta_str
|
||||
|
||||
|
||||
def generate_reply_caption(
|
||||
user: TelegramUser,
|
||||
params: BodyV0Params,
|
||||
tx_hash: str,
|
||||
worker: str,
|
||||
reward: str,
|
||||
explorer_domain: str
|
||||
):
|
||||
explorer_link = hlink(
|
||||
'SKYNET Transaction Explorer',
|
||||
f'https://{explorer_domain}/v2/explore/transaction/{tx_hash}'
|
||||
)
|
||||
|
||||
meta_info = prepare_metainfo_caption(user, worker, reward, params)
|
||||
|
||||
final_msg = '\n'.join([
|
||||
'Worker finished your task!',
|
||||
explorer_link,
|
||||
f'PARAMETER INFO:\n{meta_info}'
|
||||
])
|
||||
|
||||
final_msg = '\n'.join([
|
||||
f'<b><i>{explorer_link}</i></b>',
|
||||
f'{meta_info}'
|
||||
])
|
||||
|
||||
return final_msg
|
||||
|
||||
|
||||
def append_handler(bot: AsyncTeleBot, command: str, fn: Awaitable):
|
||||
@bot.message_handler(commands=[command])
|
||||
async def wrap_msg_and_handle(tg_msg: TGMessage):
|
||||
await fn(TelegramMessage(cmd=command, msg=tg_msg))
|
||||
|
||||
|
||||
class TelegramChatbot(BaseChatbot):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: FrontendConfig,
|
||||
db: FrontendUserDB,
|
||||
):
|
||||
super().__init__(config, db)
|
||||
bot = AsyncTeleBot(config.token, exception_handler=TGExceptionHandler)
|
||||
|
||||
append_handler(bot, BaseCommands.HELP, self.send_help)
|
||||
append_handler(bot, BaseCommands.COOL, self.send_cool_words)
|
||||
append_handler(bot, BaseCommands.QUEUE, self.get_queue)
|
||||
append_handler(bot, BaseCommands.CONFIG, self.set_config)
|
||||
append_handler(bot, BaseCommands.STATS, self.user_stats)
|
||||
append_handler(bot, BaseCommands.DONATE, self.donation_info)
|
||||
append_handler(bot, BaseCommands.SAY, self.say)
|
||||
|
||||
append_handler(bot, BaseCommands.TXT2IMG, self.handle_request)
|
||||
append_handler(bot, BaseCommands.IMG2IMG, self.handle_request)
|
||||
append_handler(bot, BaseCommands.REDO, self.handle_request)
|
||||
|
||||
self._main_room: TelegramChatRoom | None = None
|
||||
|
||||
async def init(self):
|
||||
tg_group = await self.bot.get_chat(GROUP_ID)
|
||||
self._main_room = TelegramChatRoom(chat=tg_group)
|
||||
|
||||
async def run(self):
|
||||
await self.init()
|
||||
await self.bot.infinity_polling()
|
||||
|
||||
@property
|
||||
def main_group(self) -> TelegramChatRoom:
|
||||
return self._main_room
|
||||
|
||||
async def new_msg(self, chat: TelegramChatRoom, text: str) -> TelegramMessage:
|
||||
msg = await self.bot.send_message(chat.id, text, parse_mode='HTML')
|
||||
return TelegramMessage(cmd=None, msg=msg)
|
||||
|
||||
async def reply_to(self, msg: TelegramMessage, text: str) -> TelegramMessage:
|
||||
msg = await self.bot.reply_to(msg._msg, text, parse_mode='HTML')
|
||||
return TelegramMessage(cmd=None, msg=msg)
|
||||
|
||||
async def edit_msg(self, msg: TelegramMessage, text: str):
|
||||
await self.bot.edit_message_text(
|
||||
text,
|
||||
chat_id=msg.chat.id,
|
||||
message_id=msg.id,
|
||||
parse_mode='HTML'
|
||||
)
|
||||
|
||||
async def update_request_status_timeout(self, status_msg: TelegramMessage):
|
||||
'''
|
||||
Notify users when we timedout trying to find a matching submit
|
||||
'''
|
||||
await self.append_status_msg(
|
||||
status_msg,
|
||||
f'\n[{timestamp_pretty()}] <b>timeout processing request</b>',
|
||||
)
|
||||
|
||||
async def update_request_status_step_0(self, status_msg: TelegramMessage):
|
||||
'''
|
||||
First step in request status message lifecycle, should notify which user sent the request
|
||||
and that we are about to broadcast the request to chain
|
||||
'''
|
||||
await self.update_status_msg(
|
||||
status_msg,
|
||||
f'processing a \'{status_msg.command}\' request by {status_msg.author.name}\n'
|
||||
f'[{timestamp_pretty()}] <i>broadcasting transaction to chain...</i>'
|
||||
)
|
||||
|
||||
async def update_request_status_step_1(self, status_msg: TelegramMessage, tx_result: dict):
|
||||
'''
|
||||
Second step in request status message lifecycle, should notify enqueue transaction
|
||||
was processed by chain, and provide a link to the tx in the chain explorer
|
||||
'''
|
||||
enqueue_tx_id = tx_result['transaction_id']
|
||||
enqueue_tx_link = hlink(
|
||||
'Your request on Skynet Explorer',
|
||||
f'https://{self.explorer_domain}/v2/explore/transaction/{enqueue_tx_id}'
|
||||
)
|
||||
await self.append_status_msg(
|
||||
status_msg,
|
||||
f' <b>broadcasted!</b>\n'
|
||||
f'<b>{enqueue_tx_link}</b>\n'
|
||||
f'[{timestamp_pretty()}] <i>workers are processing request...</i>',
|
||||
)
|
||||
|
||||
async def update_request_status_step_2(self, status_msg: TelegramMessage, submit_tx_hash: str):
|
||||
'''
|
||||
Third step in request status message lifecycle, should notify matching submit transaction
|
||||
was found, and provide a link to the tx in the chain explorer
|
||||
'''
|
||||
tx_link = hlink(
|
||||
'Your result on Skynet Explorer',
|
||||
f'https://{self.explorer_domain}/v2/explore/transaction/{submit_tx_hash}'
|
||||
)
|
||||
await self.append_status_msg(
|
||||
status_msg,
|
||||
f' <b>request processed!</b>\n'
|
||||
f'<b>{tx_link}</b>\n'
|
||||
f'[{timestamp_pretty()}] <i>trying to download image...</i>\n',
|
||||
)
|
||||
|
||||
async def update_request_status_final(
|
||||
self,
|
||||
og_msg: TelegramMessage,
|
||||
status_msg: TelegramMessage,
|
||||
user: TelegramUser,
|
||||
params: BodyV0Params,
|
||||
inputs: list[TelegramFileInput],
|
||||
submit_tx_hash: str,
|
||||
worker: str,
|
||||
result_img: bytes | None
|
||||
):
|
||||
'''
|
||||
Last step in request status message lifecycle, should delete status message and send a
|
||||
new message replying to the original user's message, generate the appropiate
|
||||
reply caption and if provided also sent the found result img
|
||||
'''
|
||||
caption = generate_reply_caption(
|
||||
user, worker, self.config.reward, params)
|
||||
|
||||
await self.bot.delete_message(
|
||||
chat_id=status_msg.chat.id,
|
||||
message_id=status_msg.id
|
||||
)
|
||||
|
||||
if not result_img:
|
||||
# result found on chain but failed to fetch img from ipfs
|
||||
await self.reply_to(og_msg, caption, reply_markup=build_redo_menu())
|
||||
return
|
||||
|
||||
match len(inputs):
|
||||
case 0:
|
||||
await self.bot.send_photo(
|
||||
status_msg.chat.id,
|
||||
caption=caption,
|
||||
photo=result_img,
|
||||
reply_markup=build_redo_menu(),
|
||||
parse_mode='HTML'
|
||||
)
|
||||
|
||||
case 1:
|
||||
_input = inputs.pop()
|
||||
await self.bot.send_media_group(
|
||||
status_msg.chat.id,
|
||||
media=[
|
||||
InputMediaPhoto(_input.id),
|
||||
InputMediaPhoto(result_img, caption=caption, parse_mode='HTML')
|
||||
]
|
||||
)
|
||||
|
||||
case _:
|
||||
raise NotImplementedError
|
|
@ -0,0 +1,110 @@
|
|||
import io
|
||||
|
||||
from ABC import ABC, abstractproperty, abstractmethod
|
||||
from enum import StrEnum
|
||||
from typing import Self
|
||||
from PIL import Image
|
||||
|
||||
from skynet.ipfs import AsyncIPFSHTTP
|
||||
|
||||
|
||||
class BaseUser(ABC):
|
||||
|
||||
@abstractproperty
|
||||
def id(self) -> int:
|
||||
...
|
||||
|
||||
@abstractproperty
|
||||
def name(self) -> str:
|
||||
...
|
||||
|
||||
@abstractproperty
|
||||
def is_admin(self) -> bool:
|
||||
...
|
||||
|
||||
|
||||
class BaseChatRoom(ABC):
|
||||
@abstractproperty
|
||||
def id(self) -> int:
|
||||
...
|
||||
|
||||
@abstractproperty
|
||||
def is_private(self) -> bool:
|
||||
...
|
||||
|
||||
|
||||
class BaseFileInput(ABC):
|
||||
|
||||
@staticmethod
|
||||
@abstractmethod
|
||||
def from_values(id: int, cid: str) -> Self:
|
||||
...
|
||||
|
||||
@abstractproperty
|
||||
def id(self) -> int:
|
||||
...
|
||||
|
||||
@abstractproperty
|
||||
def cid(self) -> str:
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def download(self, *args) -> bytes:
|
||||
...
|
||||
|
||||
async def publish(self, ipfs_api: AsyncIPFSHTTP, user_row: dict):
|
||||
with Image.open(io.BytesIO(self._raw)) as img:
|
||||
w, h = img.size
|
||||
|
||||
if (
|
||||
w > user_row['width']
|
||||
or
|
||||
h > user_row['height']
|
||||
):
|
||||
img.thumbnail((user_row['width'], user_row['height']))
|
||||
|
||||
img_path = '/tmp/ipfs-staging/img.png'
|
||||
img.save(img_path, format='PNG')
|
||||
|
||||
ipfs_info = await ipfs_api.add(img_path)
|
||||
ipfs_hash = ipfs_info['Hash']
|
||||
await ipfs_api.pin(ipfs_hash)
|
||||
|
||||
|
||||
class BaseCommands(StrEnum):
|
||||
TXT2IMG = 'txt2img'
|
||||
IMG2IMG = 'img2img'
|
||||
REDO = 'redo'
|
||||
HELP = 'help'
|
||||
COOL = 'cool'
|
||||
QUEUE = 'queue'
|
||||
CONFIG = 'config'
|
||||
STATS = 'stats'
|
||||
DONATE = 'donate'
|
||||
SAY = 'say'
|
||||
|
||||
|
||||
class BaseMessage(ABC):
|
||||
@abstractproperty
|
||||
def id(self) -> int:
|
||||
...
|
||||
|
||||
@abstractproperty
|
||||
def chat(self) -> BaseChatRoom:
|
||||
...
|
||||
|
||||
@abstractproperty
|
||||
def text(self) -> str:
|
||||
...
|
||||
|
||||
@abstractproperty
|
||||
def author(self) -> BaseUser:
|
||||
...
|
||||
|
||||
@abstractproperty
|
||||
def command(self) -> str | None:
|
||||
...
|
||||
|
||||
@abstractproperty
|
||||
def inputs(self) -> list[BaseFileInput]:
|
||||
...
|
|
@ -1,295 +0,0 @@
|
|||
import io
|
||||
import random
|
||||
import logging
|
||||
import asyncio
|
||||
|
||||
from PIL import Image, UnidentifiedImageError
|
||||
from json import JSONDecodeError
|
||||
from decimal import Decimal
|
||||
from hashlib import sha256
|
||||
from datetime import datetime
|
||||
from contextlib import AsyncExitStack
|
||||
from contextlib import asynccontextmanager as acm
|
||||
|
||||
from leap.cleos import CLEOS
|
||||
from leap.protocol import Name, Asset
|
||||
from leap.hyperion import HyperionAPI
|
||||
|
||||
from telebot.types import InputMediaPhoto
|
||||
from telebot.async_telebot import AsyncTeleBot
|
||||
|
||||
from skynet.db import open_database_connection
|
||||
from skynet.ipfs import get_ipfs_file, AsyncIPFSHTTP
|
||||
from skynet.constants import *
|
||||
|
||||
from . import *
|
||||
|
||||
from .utils import *
|
||||
from .handlers import create_handler_context
|
||||
|
||||
|
||||
class SkynetTelegramFrontend:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
token: str,
|
||||
account: str,
|
||||
permission: str,
|
||||
node_url: str,
|
||||
hyperion_url: str,
|
||||
db_host: str,
|
||||
db_user: str,
|
||||
db_pass: str,
|
||||
ipfs_node: str,
|
||||
key: str,
|
||||
explorer_domain: str,
|
||||
ipfs_domain: str
|
||||
):
|
||||
self.token = token
|
||||
self.account = account
|
||||
self.permission = permission
|
||||
self.node_url = node_url
|
||||
self.hyperion_url = hyperion_url
|
||||
self.db_host = db_host
|
||||
self.db_user = db_user
|
||||
self.db_pass = db_pass
|
||||
self.key = key
|
||||
self.explorer_domain = explorer_domain
|
||||
self.ipfs_domain = ipfs_domain
|
||||
|
||||
self.bot = AsyncTeleBot(token, exception_handler=SKYExceptionHandler)
|
||||
self.cleos = CLEOS(endpoint=node_url)
|
||||
self.cleos.load_abi('gpu.scd', GPU_CONTRACT_ABI)
|
||||
self.hyperion = HyperionAPI(hyperion_url)
|
||||
self.ipfs_node = AsyncIPFSHTTP(ipfs_node)
|
||||
|
||||
self._async_exit_stack = AsyncExitStack()
|
||||
|
||||
async def start(self):
|
||||
self.db_call = await self._async_exit_stack.enter_async_context(
|
||||
open_database_connection(
|
||||
self.db_user, self.db_pass, self.db_host))
|
||||
|
||||
create_handler_context(self)
|
||||
|
||||
async def stop(self):
|
||||
await self._async_exit_stack.aclose()
|
||||
|
||||
@acm
|
||||
async def open(self):
|
||||
await self.start()
|
||||
yield self
|
||||
await self.stop()
|
||||
|
||||
async def update_status_message(
|
||||
self, status_msg, new_text: str, **kwargs
|
||||
):
|
||||
await self.db_call(
|
||||
'update_user_request_by_sid', status_msg.id, new_text)
|
||||
return await self.bot.edit_message_text(
|
||||
new_text,
|
||||
chat_id=status_msg.chat.id,
|
||||
message_id=status_msg.id,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
async def append_status_message(
|
||||
self, status_msg, add_text: str, **kwargs
|
||||
):
|
||||
request = await self.db_call('get_user_request_by_sid', status_msg.id)
|
||||
await self.update_status_message(
|
||||
status_msg,
|
||||
request['status'] + add_text,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
async def work_request(
|
||||
self,
|
||||
user,
|
||||
status_msg,
|
||||
method: str,
|
||||
params: dict,
|
||||
file_id: str | None = None,
|
||||
inputs: list[str] = []
|
||||
) -> bool:
|
||||
if params['seed'] == None:
|
||||
params['seed'] = random.randint(0, 0xFFFFFFFF)
|
||||
|
||||
sanitized_params = {}
|
||||
for key, val in params.items():
|
||||
if isinstance(val, Decimal):
|
||||
val = str(val)
|
||||
|
||||
sanitized_params[key] = val
|
||||
|
||||
body = json.dumps({
|
||||
'method': 'diffuse',
|
||||
'params': sanitized_params
|
||||
})
|
||||
request_time = datetime.now().isoformat()
|
||||
|
||||
await self.update_status_message(
|
||||
status_msg,
|
||||
f'processing a \'{method}\' request by {tg_user_pretty(user)}\n'
|
||||
f'[{timestamp_pretty()}] <i>broadcasting transaction to chain...</i>',
|
||||
parse_mode='HTML'
|
||||
)
|
||||
|
||||
reward = '20.0000 GPU'
|
||||
res = await self.cleos.a_push_action(
|
||||
'gpu.scd',
|
||||
'enqueue',
|
||||
list({
|
||||
'user': Name(self.account),
|
||||
'request_body': body,
|
||||
'binary_data': ','.join(inputs),
|
||||
'reward': Asset.from_str(reward),
|
||||
'min_verification': 1
|
||||
}.values()),
|
||||
self.account, self.key, permission=self.permission
|
||||
)
|
||||
|
||||
if 'code' in res or 'statusCode' in res:
|
||||
logging.error(json.dumps(res, indent=4))
|
||||
await self.update_status_message(
|
||||
status_msg,
|
||||
'skynet has suffered an internal error trying to fill this request')
|
||||
return False
|
||||
|
||||
enqueue_tx_id = res['transaction_id']
|
||||
enqueue_tx_link = hlink(
|
||||
'Your request on Skynet Explorer',
|
||||
f'https://{self.explorer_domain}/v2/explore/transaction/{enqueue_tx_id}'
|
||||
)
|
||||
|
||||
await self.append_status_message(
|
||||
status_msg,
|
||||
f' <b>broadcasted!</b>\n'
|
||||
f'<b>{enqueue_tx_link}</b>\n'
|
||||
f'[{timestamp_pretty()}] <i>workers are processing request...</i>',
|
||||
parse_mode='HTML'
|
||||
)
|
||||
|
||||
out = res['processed']['action_traces'][0]['console']
|
||||
|
||||
request_id, nonce = out.split(':')
|
||||
|
||||
request_hash = sha256(
|
||||
(nonce + body + ','.join(inputs)).encode('utf-8')).hexdigest().upper()
|
||||
|
||||
request_id = int(request_id)
|
||||
|
||||
logging.info(f'{request_id} enqueued.')
|
||||
|
||||
tx_hash = None
|
||||
ipfs_hash = None
|
||||
for i in range(60 * 3):
|
||||
try:
|
||||
submits = await self.hyperion.aget_actions(
|
||||
account=self.account,
|
||||
filter='gpu.scd:submit',
|
||||
sort='desc',
|
||||
after=request_time
|
||||
)
|
||||
actions = [
|
||||
action
|
||||
for action in submits['actions']
|
||||
if action[
|
||||
'act']['data']['request_hash'] == request_hash
|
||||
]
|
||||
if len(actions) > 0:
|
||||
tx_hash = actions[0]['trx_id']
|
||||
data = actions[0]['act']['data']
|
||||
ipfs_hash = data['ipfs_hash']
|
||||
worker = data['worker']
|
||||
logging.info('Found matching submit!')
|
||||
break
|
||||
|
||||
except JSONDecodeError:
|
||||
logging.error(f'network error while getting actions, retry..')
|
||||
|
||||
await asyncio.sleep(1)
|
||||
|
||||
if not ipfs_hash:
|
||||
await self.update_status_message(
|
||||
status_msg,
|
||||
f'\n[{timestamp_pretty()}] <b>timeout processing request</b>',
|
||||
parse_mode='HTML'
|
||||
)
|
||||
return False
|
||||
|
||||
tx_link = hlink(
|
||||
'Your result on Skynet Explorer',
|
||||
f'https://{self.explorer_domain}/v2/explore/transaction/{tx_hash}'
|
||||
)
|
||||
|
||||
await self.append_status_message(
|
||||
status_msg,
|
||||
f' <b>request processed!</b>\n'
|
||||
f'<b>{tx_link}</b>\n'
|
||||
f'[{timestamp_pretty()}] <i>trying to download image...</i>\n',
|
||||
parse_mode='HTML'
|
||||
)
|
||||
|
||||
caption = generate_reply_caption(
|
||||
user, params, tx_hash, worker, reward, self.explorer_domain)
|
||||
|
||||
# attempt to get the image and send it
|
||||
ipfs_link = f'https://{self.ipfs_domain}/ipfs/{ipfs_hash}'
|
||||
|
||||
res = await get_ipfs_file(ipfs_link)
|
||||
logging.info(f'got response from {ipfs_link}')
|
||||
if not res or res.status_code != 200:
|
||||
logging.warning(f'couldn\'t get ipfs binary data at {ipfs_link}!')
|
||||
|
||||
else:
|
||||
try:
|
||||
with Image.open(io.BytesIO(res.raw)) as image:
|
||||
w, h = image.size
|
||||
|
||||
if w > TG_MAX_WIDTH or h > TG_MAX_HEIGHT:
|
||||
logging.warning(f'result is of size {image.size}')
|
||||
image.thumbnail((TG_MAX_WIDTH, TG_MAX_HEIGHT))
|
||||
|
||||
tmp_buf = io.BytesIO()
|
||||
image.save(tmp_buf, format='PNG')
|
||||
png_img = tmp_buf.getvalue()
|
||||
|
||||
except UnidentifiedImageError:
|
||||
logging.warning(f'couldn\'t get ipfs binary data at {ipfs_link}!')
|
||||
|
||||
if not png_img:
|
||||
await self.update_status_message(
|
||||
status_msg,
|
||||
caption,
|
||||
reply_markup=build_redo_menu(),
|
||||
parse_mode='HTML'
|
||||
)
|
||||
return True
|
||||
|
||||
logging.info(f'success! sending generated image')
|
||||
await self.bot.delete_message(
|
||||
chat_id=status_msg.chat.id, message_id=status_msg.id)
|
||||
if file_id: # img2img
|
||||
await self.bot.send_media_group(
|
||||
status_msg.chat.id,
|
||||
media=[
|
||||
InputMediaPhoto(file_id),
|
||||
InputMediaPhoto(
|
||||
png_img,
|
||||
caption=caption,
|
||||
parse_mode='HTML'
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
else: # txt2img
|
||||
await self.bot.send_photo(
|
||||
status_msg.chat.id,
|
||||
caption=caption,
|
||||
photo=png_img,
|
||||
reply_markup=build_redo_menu(),
|
||||
parse_mode='HTML'
|
||||
)
|
||||
|
||||
return True
|
|
@ -1,365 +0,0 @@
|
|||
import io
|
||||
import json
|
||||
import logging
|
||||
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
from PIL import Image
|
||||
from telebot.types import CallbackQuery, Message
|
||||
|
||||
from skynet.frontend import validate_user_config_request, perform_auto_conf
|
||||
from skynet.constants import *
|
||||
|
||||
|
||||
def create_handler_context(frontend: 'SkynetTelegramFrontend'):
|
||||
|
||||
bot = frontend.bot
|
||||
cleos = frontend.cleos
|
||||
db_call = frontend.db_call
|
||||
work_request = frontend.work_request
|
||||
|
||||
ipfs_node = frontend.ipfs_node
|
||||
|
||||
# generic / simple handlers
|
||||
|
||||
@bot.message_handler(commands=['help'])
|
||||
async def send_help(message):
|
||||
splt_msg = message.text.split(' ')
|
||||
|
||||
if len(splt_msg) == 1:
|
||||
await bot.reply_to(message, HELP_TEXT)
|
||||
|
||||
else:
|
||||
param = splt_msg[1]
|
||||
if param in HELP_TOPICS:
|
||||
await bot.reply_to(message, HELP_TOPICS[param])
|
||||
|
||||
else:
|
||||
await bot.reply_to(message, HELP_UNKWNOWN_PARAM)
|
||||
|
||||
@bot.message_handler(commands=['cool'])
|
||||
async def send_cool_words(message):
|
||||
await bot.reply_to(message, '\n'.join(COOL_WORDS))
|
||||
|
||||
@bot.message_handler(commands=['queue'])
|
||||
async def queue(message):
|
||||
an_hour_ago = datetime.now() - timedelta(hours=1)
|
||||
queue = await cleos.aget_table(
|
||||
'gpu.scd', 'gpu.scd', 'queue',
|
||||
index_position=2,
|
||||
key_type='i64',
|
||||
sort='desc',
|
||||
lower_bound=int(an_hour_ago.timestamp())
|
||||
)
|
||||
await bot.reply_to(
|
||||
message, f'Total requests on skynet queue: {len(queue)}')
|
||||
|
||||
|
||||
@bot.message_handler(commands=['config'])
|
||||
async def set_config(message):
|
||||
user = message.from_user.id
|
||||
try:
|
||||
attr, val, reply_txt = validate_user_config_request(
|
||||
message.text)
|
||||
|
||||
logging.info(f'user config update: {attr} to {val}')
|
||||
await db_call('update_user_config', user, attr, val)
|
||||
logging.info('done')
|
||||
|
||||
except BaseException as e:
|
||||
reply_txt = str(e)
|
||||
|
||||
finally:
|
||||
await bot.reply_to(message, reply_txt)
|
||||
|
||||
@bot.message_handler(commands=['stats'])
|
||||
async def user_stats(message):
|
||||
user = message.from_user.id
|
||||
|
||||
await db_call('get_or_create_user', user)
|
||||
generated, joined, role = await db_call('get_user_stats', user)
|
||||
|
||||
stats_str = f'generated: {generated}\n'
|
||||
stats_str += f'joined: {joined}\n'
|
||||
stats_str += f'role: {role}\n'
|
||||
|
||||
await bot.reply_to(
|
||||
message, stats_str)
|
||||
|
||||
@bot.message_handler(commands=['donate'])
|
||||
async def donation_info(message):
|
||||
await bot.reply_to(
|
||||
message, DONATION_INFO)
|
||||
|
||||
@bot.message_handler(commands=['say'])
|
||||
async def say(message):
|
||||
chat = message.chat
|
||||
user = message.from_user
|
||||
|
||||
if (chat.type == 'group') or (user.id != 383385940):
|
||||
return
|
||||
|
||||
await bot.send_message(GROUP_ID, message.text[4:])
|
||||
|
||||
|
||||
# generic txt2img handler
|
||||
|
||||
async def _generic_txt2img(message_or_query):
|
||||
if isinstance(message_or_query, CallbackQuery):
|
||||
query = message_or_query
|
||||
message = query.message
|
||||
user = query.from_user
|
||||
chat = query.message.chat
|
||||
|
||||
else:
|
||||
message = message_or_query
|
||||
user = message.from_user
|
||||
chat = message.chat
|
||||
|
||||
if chat.type == 'private':
|
||||
return
|
||||
|
||||
reply_id = None
|
||||
if chat.type == 'group' and chat.id == GROUP_ID:
|
||||
reply_id = message.message_id
|
||||
|
||||
user_row = await db_call('get_or_create_user', user.id)
|
||||
|
||||
# init new msg
|
||||
init_msg = 'started processing txt2img request...'
|
||||
status_msg = await bot.reply_to(message, init_msg)
|
||||
await db_call(
|
||||
'new_user_request', user.id, message.id, status_msg.id, status=init_msg)
|
||||
|
||||
prompt = ' '.join(message.text.split(' ')[1:])
|
||||
|
||||
if len(prompt) == 0:
|
||||
await bot.edit_message_text(
|
||||
'Empty text prompt ignored.',
|
||||
chat_id=status_msg.chat.id,
|
||||
message_id=status_msg.id
|
||||
)
|
||||
await db_call('update_user_request', status_msg.id, 'Empty text prompt ignored.')
|
||||
return
|
||||
|
||||
logging.info(f'mid: {message.id}')
|
||||
|
||||
user_config = {**user_row}
|
||||
del user_config['id']
|
||||
|
||||
if user_config['autoconf']:
|
||||
user_config = perform_auto_conf(user_config)
|
||||
|
||||
params = {
|
||||
'prompt': prompt,
|
||||
**user_config
|
||||
}
|
||||
|
||||
await db_call(
|
||||
'update_user_stats', user.id, 'txt2img', last_prompt=prompt)
|
||||
|
||||
success = await work_request(user, status_msg, 'txt2img', params)
|
||||
|
||||
if success:
|
||||
await db_call('increment_generated', user.id)
|
||||
|
||||
|
||||
# generic img2img handler
|
||||
|
||||
async def _generic_img2img(message_or_query):
|
||||
if isinstance(message_or_query, CallbackQuery):
|
||||
query = message_or_query
|
||||
message = query.message
|
||||
user = query.from_user
|
||||
chat = query.message.chat
|
||||
|
||||
else:
|
||||
message = message_or_query
|
||||
user = message.from_user
|
||||
chat = message.chat
|
||||
|
||||
if chat.type == 'private':
|
||||
return
|
||||
|
||||
reply_id = None
|
||||
if chat.type == 'group' and chat.id == GROUP_ID:
|
||||
reply_id = message.message_id
|
||||
|
||||
user_row = await db_call('get_or_create_user', user.id)
|
||||
|
||||
# init new msg
|
||||
init_msg = 'started processing txt2img request...'
|
||||
status_msg = await bot.reply_to(message, init_msg)
|
||||
await db_call(
|
||||
'new_user_request', user.id, message.id, status_msg.id, status=init_msg)
|
||||
|
||||
if not message.caption.startswith('/img2img'):
|
||||
await bot.reply_to(
|
||||
message,
|
||||
'For image to image you need to add /img2img to the beggining of your caption'
|
||||
)
|
||||
return
|
||||
|
||||
prompt = ' '.join(message.caption.split(' ')[1:])
|
||||
|
||||
if len(prompt) == 0:
|
||||
await bot.reply_to(message, 'Empty text prompt ignored.')
|
||||
return
|
||||
|
||||
file_id = message.photo[-1].file_id
|
||||
file_path = (await bot.get_file(file_id)).file_path
|
||||
image_raw = await bot.download_file(file_path)
|
||||
|
||||
user_config = {**user_row}
|
||||
del user_config['id']
|
||||
if user_config['autoconf']:
|
||||
user_config = perform_auto_conf(user_config)
|
||||
|
||||
with Image.open(io.BytesIO(image_raw)) as image:
|
||||
w, h = image.size
|
||||
|
||||
if w > user_config['width'] or h > user_config['height']:
|
||||
logging.warning(f'user sent img of size {image.size}')
|
||||
image.thumbnail(
|
||||
(user_config['width'], user_config['height']))
|
||||
logging.warning(f'resized it to {image.size}')
|
||||
|
||||
image_loc = 'ipfs-staging/image.png'
|
||||
image.save(image_loc, format='PNG')
|
||||
|
||||
ipfs_info = await ipfs_node.add(image_loc)
|
||||
ipfs_hash = ipfs_info['Hash']
|
||||
await ipfs_node.pin(ipfs_hash)
|
||||
|
||||
logging.info(f'published input image {ipfs_hash} on ipfs')
|
||||
|
||||
logging.info(f'mid: {message.id}')
|
||||
|
||||
params = {
|
||||
'prompt': prompt,
|
||||
**user_config
|
||||
}
|
||||
|
||||
await db_call(
|
||||
'update_user_stats',
|
||||
user.id,
|
||||
'img2img',
|
||||
last_file=file_id,
|
||||
last_prompt=prompt,
|
||||
last_binary=ipfs_hash
|
||||
)
|
||||
|
||||
success = await work_request(
|
||||
user, status_msg, 'img2img', params,
|
||||
file_id=file_id,
|
||||
inputs=ipfs_hash
|
||||
)
|
||||
|
||||
if success:
|
||||
await db_call('increment_generated', user.id)
|
||||
|
||||
|
||||
# generic redo handler
|
||||
|
||||
async def _redo(message_or_query):
|
||||
is_query = False
|
||||
if isinstance(message_or_query, CallbackQuery):
|
||||
is_query = True
|
||||
query = message_or_query
|
||||
message = query.message
|
||||
user = query.from_user
|
||||
chat = query.message.chat
|
||||
|
||||
elif isinstance(message_or_query, Message):
|
||||
message = message_or_query
|
||||
user = message.from_user
|
||||
chat = message.chat
|
||||
|
||||
if chat.type == 'private':
|
||||
return
|
||||
|
||||
init_msg = 'started processing redo request...'
|
||||
if is_query:
|
||||
status_msg = await bot.send_message(chat.id, init_msg)
|
||||
|
||||
else:
|
||||
status_msg = await bot.reply_to(message, init_msg)
|
||||
|
||||
method = await db_call('get_last_method_of', user.id)
|
||||
prompt = await db_call('get_last_prompt_of', user.id)
|
||||
|
||||
file_id = None
|
||||
binary = ''
|
||||
if method == 'img2img':
|
||||
file_id = await db_call('get_last_file_of', user.id)
|
||||
binary = await db_call('get_last_binary_of', user.id)
|
||||
|
||||
if not prompt:
|
||||
await bot.reply_to(
|
||||
message,
|
||||
'no last prompt found, do a txt2img cmd first!'
|
||||
)
|
||||
return
|
||||
|
||||
|
||||
user_row = await db_call('get_or_create_user', user.id)
|
||||
await db_call(
|
||||
'new_user_request', user.id, message.id, status_msg.id, status=init_msg)
|
||||
user_config = {**user_row}
|
||||
del user_config['id']
|
||||
if user_config['autoconf']:
|
||||
user_config = perform_auto_conf(user_config)
|
||||
|
||||
params = {
|
||||
'prompt': prompt,
|
||||
**user_config
|
||||
}
|
||||
|
||||
success = await work_request(
|
||||
user, status_msg, 'redo', params,
|
||||
file_id=file_id,
|
||||
inputs=binary
|
||||
)
|
||||
|
||||
if success:
|
||||
await db_call('increment_generated', user.id)
|
||||
|
||||
|
||||
# "proxy" handlers just request routers
|
||||
|
||||
@bot.message_handler(commands=['txt2img'])
|
||||
async def send_txt2img(message):
|
||||
await _generic_txt2img(message)
|
||||
|
||||
@bot.message_handler(func=lambda message: True, content_types=[
|
||||
'photo', 'document'])
|
||||
async def send_img2img(message):
|
||||
await _generic_img2img(message)
|
||||
|
||||
@bot.message_handler(commands=['img2img'])
|
||||
async def img2img_missing_image(message):
|
||||
await bot.reply_to(
|
||||
message,
|
||||
'seems you tried to do an img2img command without sending image'
|
||||
)
|
||||
|
||||
@bot.message_handler(commands=['redo'])
|
||||
async def redo(message):
|
||||
await _redo(message)
|
||||
|
||||
@bot.callback_query_handler(func=lambda call: True)
|
||||
async def callback_query(call):
|
||||
msg = json.loads(call.data)
|
||||
logging.info(call.data)
|
||||
method = msg.get('method')
|
||||
match method:
|
||||
case 'redo':
|
||||
await _redo(call)
|
||||
|
||||
|
||||
# catch all handler for things we dont support
|
||||
|
||||
@bot.message_handler(func=lambda message: True)
|
||||
async def echo_message(message):
|
||||
if message.text[0] == '/':
|
||||
await bot.reply_to(message, UNKNOWN_CMD_TEXT)
|
|
@ -1,105 +0,0 @@
|
|||
import json
|
||||
import logging
|
||||
import traceback
|
||||
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from telebot.types import InlineKeyboardButton, InlineKeyboardMarkup
|
||||
from telebot.async_telebot import ExceptionHandler
|
||||
from telebot.formatting import hlink
|
||||
|
||||
from skynet.constants import *
|
||||
|
||||
|
||||
def timestamp_pretty():
|
||||
return datetime.now(timezone.utc).strftime('%H:%M:%S')
|
||||
|
||||
|
||||
def tg_user_pretty(tguser):
|
||||
if tguser.username:
|
||||
return f'@{tguser.username}'
|
||||
else:
|
||||
return f'{tguser.first_name} id: {tguser.id}'
|
||||
|
||||
|
||||
class SKYExceptionHandler(ExceptionHandler):
|
||||
|
||||
def handle(exception):
|
||||
traceback.print_exc()
|
||||
|
||||
|
||||
def build_redo_menu():
|
||||
btn_redo = InlineKeyboardButton("Redo", callback_data=json.dumps({'method': 'redo'}))
|
||||
inline_keyboard = InlineKeyboardMarkup()
|
||||
inline_keyboard.add(btn_redo)
|
||||
return inline_keyboard
|
||||
|
||||
|
||||
def prepare_metainfo_caption(tguser, worker: str, reward: str, meta: dict) -> str:
|
||||
prompt = meta["prompt"]
|
||||
if len(prompt) > 256:
|
||||
prompt = prompt[:256]
|
||||
|
||||
|
||||
meta_str = f'<u>by {tg_user_pretty(tguser)}</u>\n'
|
||||
meta_str += f'<i>performed by {worker}</i>\n'
|
||||
meta_str += f'<b><u>reward: {reward}</u></b>\n'
|
||||
|
||||
meta_str += f'<code>prompt:</code> {prompt}\n'
|
||||
meta_str += f'<code>seed: {meta["seed"]}</code>\n'
|
||||
meta_str += f'<code>step: {meta["step"]}</code>\n'
|
||||
meta_str += f'<code>guidance: {meta["guidance"]}</code>\n'
|
||||
if meta['strength']:
|
||||
meta_str += f'<code>strength: {meta["strength"]}</code>\n'
|
||||
meta_str += f'<code>algo: {meta["model"]}</code>\n'
|
||||
if meta['upscaler']:
|
||||
meta_str += f'<code>upscaler: {meta["upscaler"]}</code>\n'
|
||||
|
||||
meta_str += f'<b><u>Made with Skynet v{VERSION}</u></b>\n'
|
||||
meta_str += f'<b>JOIN THE SWARM: @skynetgpu</b>'
|
||||
return meta_str
|
||||
|
||||
|
||||
def generate_reply_caption(
|
||||
tguser, # telegram user
|
||||
params: dict,
|
||||
tx_hash: str,
|
||||
worker: str,
|
||||
reward: str,
|
||||
explorer_domain: str
|
||||
):
|
||||
explorer_link = hlink(
|
||||
'SKYNET Transaction Explorer',
|
||||
f'https://{explorer_domain}/v2/explore/transaction/{tx_hash}'
|
||||
)
|
||||
|
||||
meta_info = prepare_metainfo_caption(tguser, worker, reward, params)
|
||||
|
||||
final_msg = '\n'.join([
|
||||
'Worker finished your task!',
|
||||
explorer_link,
|
||||
f'PARAMETER INFO:\n{meta_info}'
|
||||
])
|
||||
|
||||
final_msg = '\n'.join([
|
||||
f'<b><i>{explorer_link}</i></b>',
|
||||
f'{meta_info}'
|
||||
])
|
||||
|
||||
logging.info(final_msg)
|
||||
|
||||
return final_msg
|
||||
|
||||
|
||||
async def get_global_config(cleos):
|
||||
return (await cleos.aget_table(
|
||||
'gpu.scd', 'gpu.scd', 'config'))[0]
|
||||
|
||||
async def get_user_nonce(cleos, user: str):
|
||||
return (await cleos.aget_table(
|
||||
'gpu.scd', 'gpu.scd', 'users',
|
||||
index_position=1,
|
||||
key_type='name',
|
||||
lower_bound=user,
|
||||
upper_bound=user
|
||||
))[0]['nonce']
|
Loading…
Reference in New Issue