mirror of https://github.com/skygpu/skynet.git
Compare commits
6 Commits
eeb27d5bbf
...
480e7236d6
Author | SHA1 | Date |
---|---|---|
|
480e7236d6 | |
|
d7ee0cb660 | |
|
7a4387a064 | |
|
8746a5c75b | |
|
9da4736976 | |
|
9f4122fef5 |
|
@ -77,7 +77,8 @@ explicit = true
|
|||
torch = { index = "torch" }
|
||||
triton = { index = "torch" }
|
||||
torchvision = { index = "torch" }
|
||||
py-leap = { git = "https://github.com/guilledk/py-leap.git", rev = "v0.1a35" }
|
||||
py-leap = { git = "https://github.com/guilledk/py-leap.git", branch = "struct_unwrap" }
|
||||
# py-leap = { path = "../py-leap", editable = true }
|
||||
pytest-dockerctl = { git = "https://github.com/pikers/pytest-dockerctl.git", branch = "g_update" }
|
||||
|
||||
[build-system]
|
||||
|
|
119
skynet/cli.py
119
skynet/cli.py
|
@ -165,7 +165,7 @@ def run(*args, **kwargs):
|
|||
|
||||
@run.command()
|
||||
def db():
|
||||
from .db import open_new_database
|
||||
from skynet.frontend.chatbot.db import open_new_database
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
with open_new_database(cleanup=False) as db_params:
|
||||
|
@ -197,125 +197,52 @@ def dgpu(
|
|||
|
||||
@run.command()
|
||||
@click.option('--loglevel', '-l', default='INFO', help='logging level')
|
||||
@click.option(
|
||||
'--db-host', '-h', default='localhost:5432')
|
||||
@click.option(
|
||||
'--db-user', '-u', default='skynet')
|
||||
@click.option(
|
||||
'--db-pass', '-u', default='password')
|
||||
def telegram(
|
||||
loglevel: str,
|
||||
db_host: str,
|
||||
db_user: str,
|
||||
db_pass: str
|
||||
):
|
||||
import asyncio
|
||||
from .frontend.telegram import SkynetTelegramFrontend
|
||||
from skynet.frontend.chatbot.telegram import TelegramChatbot
|
||||
from skynet.frontend.chatbot.db import FrontendUserDB
|
||||
|
||||
logging.basicConfig(level=loglevel)
|
||||
|
||||
config = load_skynet_toml()
|
||||
tg_token = config.telegram.tg_token
|
||||
|
||||
key = config.telegram.key
|
||||
account = config.telegram.account
|
||||
permission = config.telegram.permission
|
||||
node_url = config.telegram.node_url
|
||||
hyperion_url = config.telegram.hyperion_url
|
||||
|
||||
ipfs_url = config.telegram.ipfs_url
|
||||
|
||||
try:
|
||||
explorer_domain = config.telegram.explorer_domain
|
||||
|
||||
except ConfigParsingError:
|
||||
explorer_domain = DEFAULT_EXPLORER_DOMAIN
|
||||
|
||||
try:
|
||||
ipfs_domain = config.telegram.ipfs_domain
|
||||
|
||||
except ConfigParsingError:
|
||||
ipfs_domain = DEFAULT_IPFS_DOMAIN
|
||||
config = load_skynet_toml().telegram
|
||||
|
||||
async def _async_main():
|
||||
frontend = SkynetTelegramFrontend(
|
||||
tg_token,
|
||||
account,
|
||||
permission,
|
||||
node_url,
|
||||
hyperion_url,
|
||||
db_host, db_user, db_pass,
|
||||
ipfs_url,
|
||||
key=key,
|
||||
explorer_domain=explorer_domain,
|
||||
ipfs_domain=ipfs_domain
|
||||
)
|
||||
|
||||
async with frontend.open():
|
||||
await frontend.bot.infinity_polling()
|
||||
|
||||
async with FrontendUserDB(
|
||||
config.db_user,
|
||||
config.db_pass,
|
||||
config.db_host,
|
||||
config.db_name
|
||||
) as db:
|
||||
bot = TelegramChatbot(config, db)
|
||||
await bot.run()
|
||||
|
||||
asyncio.run(_async_main())
|
||||
|
||||
|
||||
@run.command()
|
||||
@click.option('--loglevel', '-l', default='INFO', help='logging level')
|
||||
@click.option(
|
||||
'--db-host', '-h', default='localhost:5432')
|
||||
@click.option(
|
||||
'--db-user', '-u', default='skynet')
|
||||
@click.option(
|
||||
'--db-pass', '-u', default='password')
|
||||
def discord(
|
||||
loglevel: str,
|
||||
db_host: str,
|
||||
db_user: str,
|
||||
db_pass: str
|
||||
):
|
||||
import asyncio
|
||||
from .frontend.discord import SkynetDiscordFrontend
|
||||
from skynet.frontend.chatbot.discord import DiscordChatbot
|
||||
from skynet.frontend.chatbot.db import FrontendUserDB
|
||||
|
||||
logging.basicConfig(level=loglevel)
|
||||
|
||||
config = load_skynet_toml()
|
||||
dc_token = config.discord.dc_token
|
||||
|
||||
key = config.discord.key
|
||||
account = config.discord.account
|
||||
permission = config.discord.permission
|
||||
node_url = config.discord.node_url
|
||||
hyperion_url = config.discord.hyperion_url
|
||||
|
||||
ipfs_url = config.discord.ipfs_url
|
||||
|
||||
try:
|
||||
explorer_domain = config.discord.explorer_domain
|
||||
|
||||
except ConfigParsingError:
|
||||
explorer_domain = DEFAULT_EXPLORER_DOMAIN
|
||||
|
||||
try:
|
||||
ipfs_domain = config.discord.ipfs_domain
|
||||
|
||||
except ConfigParsingError:
|
||||
ipfs_domain = DEFAULT_IPFS_DOMAIN
|
||||
config = load_skynet_toml().discord
|
||||
|
||||
async def _async_main():
|
||||
frontend = SkynetDiscordFrontend(
|
||||
# dc_token,
|
||||
account,
|
||||
permission,
|
||||
node_url,
|
||||
hyperion_url,
|
||||
db_host, db_user, db_pass,
|
||||
ipfs_url,
|
||||
key=key,
|
||||
explorer_domain=explorer_domain,
|
||||
ipfs_domain=ipfs_domain
|
||||
)
|
||||
|
||||
async with frontend.open():
|
||||
await frontend.bot.start(dc_token)
|
||||
async with FrontendUserDB(
|
||||
config.db_user,
|
||||
config.db_pass,
|
||||
config.db_host,
|
||||
config.db_name
|
||||
) as db:
|
||||
bot = DiscordChatbot(config, db)
|
||||
await bot.run()
|
||||
|
||||
asyncio.run(_async_main())
|
||||
|
||||
|
|
|
@ -32,10 +32,22 @@ class FrontendConfig(msgspec.Struct):
|
|||
account: str
|
||||
permission: str
|
||||
key: str
|
||||
node_url: str
|
||||
hyperion_url: str
|
||||
ipfs_url: str
|
||||
token: str
|
||||
db_host: str
|
||||
db_user: str
|
||||
db_pass: str
|
||||
db_name: str = 'skynet'
|
||||
node_url: str = 'https://testnet.telos.net'
|
||||
hyperion_url: str = 'https://testnet.skygpu.net'
|
||||
ipfs_domain: str = 'ipfs.skygpu.net'
|
||||
explorer_domain: str = 'explorer.skygpu.net'
|
||||
request_timeout: int = 60 * 3
|
||||
proto_version: int = 0
|
||||
reward: str = '20.0000 GPU'
|
||||
receiver: str = 'gpu.scd'
|
||||
result_max_width: int = 1280
|
||||
result_max_height: int = 1280
|
||||
|
||||
|
||||
class PinnerConfig(msgspec.Struct):
|
||||
|
|
|
@ -14,91 +14,91 @@ MODELS: dict[str, ModelDesc] = {
|
|||
'runwayml/stable-diffusion-v1-5': ModelDesc(
|
||||
short='stable',
|
||||
mem=6,
|
||||
attrs={'size': {'w': 512, 'h': 512}},
|
||||
attrs={'size': {'w': 512, 'h': 512}, 'step': 28},
|
||||
tags=['txt2img']
|
||||
),
|
||||
'stabilityai/stable-diffusion-2-1-base': ModelDesc(
|
||||
short='stable2',
|
||||
mem=6,
|
||||
attrs={'size': {'w': 512, 'h': 512}},
|
||||
attrs={'size': {'w': 512, 'h': 512}, 'step': 28},
|
||||
tags=['txt2img']
|
||||
),
|
||||
'snowkidy/stable-diffusion-xl-base-0.9': ModelDesc(
|
||||
short='stablexl0.9',
|
||||
mem=8.3,
|
||||
attrs={'size': {'w': 1024, 'h': 1024}},
|
||||
attrs={'size': {'w': 1024, 'h': 1024}, 'step': 28},
|
||||
tags=['txt2img']
|
||||
),
|
||||
'Linaqruf/anything-v3.0': ModelDesc(
|
||||
short='hdanime',
|
||||
mem=6,
|
||||
attrs={'size': {'w': 512, 'h': 512}},
|
||||
attrs={'size': {'w': 512, 'h': 512}, 'step': 28},
|
||||
tags=['txt2img']
|
||||
),
|
||||
'hakurei/waifu-diffusion': ModelDesc(
|
||||
short='waifu',
|
||||
mem=6,
|
||||
attrs={'size': {'w': 512, 'h': 512}},
|
||||
attrs={'size': {'w': 512, 'h': 512}, 'step': 28},
|
||||
tags=['txt2img']
|
||||
),
|
||||
'nitrosocke/Ghibli-Diffusion': ModelDesc(
|
||||
short='ghibli',
|
||||
mem=6,
|
||||
attrs={'size': {'w': 512, 'h': 512}},
|
||||
attrs={'size': {'w': 512, 'h': 512}, 'step': 28},
|
||||
tags=['txt2img']
|
||||
),
|
||||
'dallinmackay/Van-Gogh-diffusion': ModelDesc(
|
||||
short='van-gogh',
|
||||
mem=6,
|
||||
attrs={'size': {'w': 512, 'h': 512}},
|
||||
attrs={'size': {'w': 512, 'h': 512}, 'step': 28},
|
||||
tags=['txt2img']
|
||||
),
|
||||
'lambdalabs/sd-pokemon-diffusers': ModelDesc(
|
||||
short='pokemon',
|
||||
mem=6,
|
||||
attrs={'size': {'w': 512, 'h': 512}},
|
||||
attrs={'size': {'w': 512, 'h': 512}, 'step': 28},
|
||||
tags=['txt2img']
|
||||
),
|
||||
'Envvi/Inkpunk-Diffusion': ModelDesc(
|
||||
short='ink',
|
||||
mem=6,
|
||||
attrs={'size': {'w': 512, 'h': 512}},
|
||||
attrs={'size': {'w': 512, 'h': 512}, 'step': 28},
|
||||
tags=['txt2img']
|
||||
),
|
||||
'nousr/robo-diffusion': ModelDesc(
|
||||
short='robot',
|
||||
mem=6,
|
||||
attrs={'size': {'w': 512, 'h': 512}},
|
||||
attrs={'size': {'w': 512, 'h': 512}, 'step': 28},
|
||||
tags=['txt2img']
|
||||
),
|
||||
'black-forest-labs/FLUX.1-schnell': ModelDesc(
|
||||
short='flux',
|
||||
mem=24,
|
||||
attrs={'size': {'w': 1024, 'h': 1024}},
|
||||
attrs={'size': {'w': 1024, 'h': 1024}, 'step': 4},
|
||||
tags=['txt2img']
|
||||
),
|
||||
'black-forest-labs/FLUX.1-Fill-dev': ModelDesc(
|
||||
short='flux-inpaint',
|
||||
mem=24,
|
||||
attrs={'size': {'w': 1024, 'h': 1024}},
|
||||
attrs={'size': {'w': 1024, 'h': 1024}, 'step': 28},
|
||||
tags=['inpaint']
|
||||
),
|
||||
'diffusers/stable-diffusion-xl-1.0-inpainting-0.1': ModelDesc(
|
||||
short='stablexl-inpaint',
|
||||
mem=8.3,
|
||||
attrs={'size': {'w': 1024, 'h': 1024}},
|
||||
attrs={'size': {'w': 1024, 'h': 1024}, 'step': 28},
|
||||
tags=['inpaint']
|
||||
),
|
||||
'prompthero/openjourney': ModelDesc(
|
||||
short='midj',
|
||||
mem=6,
|
||||
attrs={'size': {'w': 512, 'h': 512}},
|
||||
attrs={'size': {'w': 512, 'h': 512}, 'step': 28},
|
||||
tags=['txt2img', 'img2img']
|
||||
),
|
||||
'stabilityai/stable-diffusion-xl-base-1.0': ModelDesc(
|
||||
short='stablexl',
|
||||
mem=8.3,
|
||||
attrs={'size': {'w': 1024, 'h': 1024}},
|
||||
attrs={'size': {'w': 1024, 'h': 1024}, 'step': 28},
|
||||
tags=['txt2img']
|
||||
),
|
||||
}
|
||||
|
@ -225,8 +225,6 @@ Noise is added to the image you use as an init image for img2img, and then the\
|
|||
|
||||
HELP_UNKWNOWN_PARAM = 'don\'t have any info on that.'
|
||||
|
||||
GROUP_ID = -1001541979235
|
||||
|
||||
MP_ENABLED_ROLES = ['god']
|
||||
|
||||
MIN_STEP = 1
|
||||
|
|
|
@ -1,6 +1,10 @@
|
|||
import random
|
||||
|
||||
from ..constants import *
|
||||
from ..constants import (
|
||||
MODELS,
|
||||
get_model_by_shortname,
|
||||
MAX_STEP, MIN_STEP, MAX_WIDTH, MAX_HEIGHT, MAX_GUIDANCE
|
||||
)
|
||||
|
||||
|
||||
class ConfigRequestFormatError(BaseException):
|
||||
|
@ -26,17 +30,17 @@ class ConfigSizeDivisionByEight(BaseException):
|
|||
def validate_user_config_request(req: str):
|
||||
params = req.split(' ')
|
||||
|
||||
if len(params) < 3:
|
||||
if len(params) < 2:
|
||||
raise ConfigRequestFormatError('config request format incorrect')
|
||||
|
||||
else:
|
||||
try:
|
||||
attr = params[1]
|
||||
attr = params[0]
|
||||
|
||||
match attr:
|
||||
case 'model' | 'algo':
|
||||
attr = 'model'
|
||||
val = params[2]
|
||||
val = params[1]
|
||||
shorts = [model_info.short for model_info in MODELS.values()]
|
||||
if val not in shorts:
|
||||
raise ConfigUnknownAlgorithm(f'no model named {val}')
|
||||
|
@ -44,38 +48,38 @@ def validate_user_config_request(req: str):
|
|||
val = get_model_by_shortname(val)
|
||||
|
||||
case 'step':
|
||||
val = int(params[2])
|
||||
val = int(params[1])
|
||||
val = max(min(val, MAX_STEP), MIN_STEP)
|
||||
|
||||
case 'width':
|
||||
val = max(min(int(params[2]), MAX_WIDTH), 16)
|
||||
val = max(min(int(params[1]), MAX_WIDTH), 16)
|
||||
if val % 8 != 0:
|
||||
raise ConfigSizeDivisionByEight(
|
||||
'size must be divisible by 8!')
|
||||
|
||||
case 'height':
|
||||
val = max(min(int(params[2]), MAX_HEIGHT), 16)
|
||||
val = max(min(int(params[1]), MAX_HEIGHT), 16)
|
||||
if val % 8 != 0:
|
||||
raise ConfigSizeDivisionByEight(
|
||||
'size must be divisible by 8!')
|
||||
|
||||
case 'seed':
|
||||
val = params[2]
|
||||
val = params[1]
|
||||
if val == 'auto':
|
||||
val = None
|
||||
else:
|
||||
val = int(params[2])
|
||||
val = int(params[1])
|
||||
|
||||
case 'guidance':
|
||||
val = float(params[2])
|
||||
val = float(params[1])
|
||||
val = max(min(val, MAX_GUIDANCE), 0)
|
||||
|
||||
case 'strength':
|
||||
val = float(params[2])
|
||||
val = float(params[1])
|
||||
val = max(min(val, 0.99), 0.01)
|
||||
|
||||
case 'upscaler':
|
||||
val = params[2]
|
||||
val = params[1]
|
||||
if val == 'off':
|
||||
val = None
|
||||
elif val != 'x4':
|
||||
|
@ -83,7 +87,7 @@ def validate_user_config_request(req: str):
|
|||
f'\"{val}\" is not a valid upscaler')
|
||||
|
||||
case 'autoconf':
|
||||
val = params[2]
|
||||
val = params[1]
|
||||
if val == 'on':
|
||||
val = True
|
||||
|
||||
|
|
|
@ -0,0 +1,484 @@
|
|||
import io
|
||||
import json
|
||||
import asyncio
|
||||
import logging
|
||||
from abc import ABC, abstractmethod, abstractproperty
|
||||
from PIL import Image, UnidentifiedImageError
|
||||
from random import randint
|
||||
from decimal import Decimal
|
||||
from hashlib import sha256
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
import msgspec
|
||||
from leap import CLEOS
|
||||
from leap.hyperion import HyperionAPI
|
||||
|
||||
from skynet.ipfs import AsyncIPFSHTTP, get_ipfs_file
|
||||
from skynet.types import BodyV0, BodyV0Params
|
||||
from skynet.config import FrontendConfig
|
||||
from skynet.constants import (
|
||||
MODELS, GPU_CONTRACT_ABI,
|
||||
HELP_TEXT,
|
||||
HELP_TOPICS,
|
||||
HELP_UNKWNOWN_PARAM,
|
||||
COOL_WORDS,
|
||||
DONATION_INFO,
|
||||
UNKNOWN_CMD_TEXT
|
||||
)
|
||||
from skynet.frontend import validate_user_config_request
|
||||
from skynet.frontend.chatbot.db import FrontendUserDB
|
||||
from skynet.frontend.chatbot.types import (
|
||||
BaseUser,
|
||||
BaseChatRoom,
|
||||
BaseCommands,
|
||||
BaseFileInput,
|
||||
BaseMessage
|
||||
)
|
||||
|
||||
|
||||
def perform_auto_conf(config: dict) -> dict:
|
||||
model = MODELS[config['model']]
|
||||
|
||||
maybe_step = model.attrs.get('step', None)
|
||||
if maybe_step:
|
||||
config['step'] = maybe_step
|
||||
|
||||
maybe_width = model.attrs.get('width', None)
|
||||
if maybe_width:
|
||||
config['width'] = maybe_step
|
||||
|
||||
maybe_height = model.attrs.get('height', None)
|
||||
if maybe_height:
|
||||
config['height'] = maybe_step
|
||||
|
||||
return config
|
||||
|
||||
|
||||
def sanitize_params(params: dict) -> dict:
|
||||
if (
|
||||
'seed' not in params
|
||||
or
|
||||
params['seed'] is None
|
||||
):
|
||||
params['seed'] = randint(0, 0xffffffff)
|
||||
|
||||
s_params = {}
|
||||
for key, val in params.items():
|
||||
if isinstance(val, Decimal):
|
||||
val = str(val)
|
||||
|
||||
s_params[key] = val
|
||||
|
||||
return s_params
|
||||
|
||||
|
||||
class RequestTimeoutError(BaseException):
|
||||
...
|
||||
|
||||
|
||||
class BaseChatbot(ABC):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: FrontendConfig,
|
||||
db: FrontendUserDB
|
||||
):
|
||||
self.db = db
|
||||
self.config = config
|
||||
self.ipfs = AsyncIPFSHTTP(config.ipfs_url)
|
||||
self.cleos = CLEOS(endpoint=config.node_url)
|
||||
self.cleos.load_abi(config.receiver, GPU_CONTRACT_ABI)
|
||||
self.cleos.import_key(config.account, config.key)
|
||||
self.hyperion = HyperionAPI(config.hyperion_url)
|
||||
|
||||
async def init(self):
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def run(self):
|
||||
...
|
||||
|
||||
@abstractproperty
|
||||
def main_group(self) -> BaseChatRoom:
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def new_msg(self, chat: BaseChatRoom, text: str, **kwargs) -> BaseMessage:
|
||||
'''
|
||||
Send text to a chat/channel.
|
||||
'''
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def reply_to(self, msg: BaseMessage, text: str, **kwargs) -> BaseMessage:
|
||||
'''
|
||||
Reply to existing message by sending new message.
|
||||
'''
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def edit_msg(self, msg: BaseMessage, text: str, **kwargs):
|
||||
'''
|
||||
Edit an existing message.
|
||||
'''
|
||||
...
|
||||
|
||||
async def create_status_msg(self, msg: BaseMessage, init_text: str, force_user: BaseUser | None = None) -> tuple[BaseUser, BaseMessage, dict]:
|
||||
# maybe init user
|
||||
user = msg.author
|
||||
if force_user:
|
||||
user = force_user
|
||||
|
||||
user_row = await self.db.get_or_create_user(user.id)
|
||||
|
||||
# create status msg
|
||||
status_msg = await self.reply_to(msg, init_text)
|
||||
|
||||
# start tracking of request in db
|
||||
await self.db.new_user_request(user.id, msg.id, status_msg.id, status=init_text)
|
||||
return [user, status_msg, user_row]
|
||||
|
||||
async def update_status_msg(self, msg: BaseMessage, text: str):
|
||||
'''
|
||||
Update an existing status message, also mirrors changes on db
|
||||
'''
|
||||
await self.db.update_user_request_by_sid(msg.id, text)
|
||||
await self.edit_msg(msg, text)
|
||||
|
||||
async def append_status_msg(self, msg: BaseMessage, text: str):
|
||||
'''
|
||||
Append text to an existing status message
|
||||
'''
|
||||
request = await self.db.get_user_request_by_sid(msg.id)
|
||||
await self.update_status_msg(msg, request['status'] + text)
|
||||
|
||||
@abstractmethod
|
||||
async def update_request_status_timeout(self, status_msg: BaseMessage):
|
||||
'''
|
||||
Notify users when we timedout trying to find a matching submit
|
||||
'''
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def update_request_status_step_0(self, status_msg: BaseMessage, user_msg: BaseMessage):
|
||||
'''
|
||||
First step in request status message lifecycle, should notify which user sent the request
|
||||
and that we are about to broadcast the request to chain
|
||||
'''
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def update_request_status_step_1(self, status_msg: BaseMessage, tx_result: dict):
|
||||
'''
|
||||
Second step in request status message lifecycle, should notify enqueue transaction
|
||||
was processed by chain, and provide a link to the tx in the chain explorer
|
||||
'''
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def update_request_status_step_2(self, status_msg: BaseMessage, submit_tx_hash: str):
|
||||
'''
|
||||
Third step in request status message lifecycle, should notify matching submit transaction
|
||||
was found, and provide a link to the tx in the chain explorer
|
||||
'''
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def update_request_status_final(
|
||||
self,
|
||||
og_msg: BaseMessage,
|
||||
status_msg: BaseMessage,
|
||||
user: BaseUser,
|
||||
params: BodyV0Params,
|
||||
inputs: list[BaseFileInput],
|
||||
submit_tx_hash: str,
|
||||
worker: str,
|
||||
result_url: str,
|
||||
result_img: bytes | None
|
||||
):
|
||||
'''
|
||||
Last step in request status message lifecycle, should delete status message and send a
|
||||
new message replying to the original user's message, generate the appropiate
|
||||
reply caption and if provided also sent the found result img
|
||||
'''
|
||||
...
|
||||
|
||||
|
||||
async def handle_request(
|
||||
self,
|
||||
msg: BaseMessage,
|
||||
force_user: BaseUser | None = None
|
||||
):
|
||||
if msg.chat.is_private:
|
||||
return
|
||||
|
||||
if (
|
||||
len(msg.text) == 0
|
||||
and
|
||||
msg.command != BaseCommands.REDO
|
||||
):
|
||||
await self.reply_to(msg, 'empty prompt ignored.')
|
||||
return
|
||||
|
||||
# maybe initialize user db row and send a new msg thats gonna
|
||||
# be updated throughout the request lifecycle
|
||||
user, status_msg, user_row = await self.create_status_msg(
|
||||
msg, f'started processing a {msg.command} request...', force_user=force_user)
|
||||
|
||||
# if this is a redo msg, we attempt to get the input params from db
|
||||
# else use msg properties
|
||||
match msg.command:
|
||||
case BaseCommands.TXT2IMG | BaseCommands.IMG2IMG:
|
||||
prompt = msg.text
|
||||
command = msg.command
|
||||
inputs = msg.inputs
|
||||
|
||||
case BaseCommands.REDO:
|
||||
prompt = await self.db.get_last_prompt_of(user.id)
|
||||
command = await self.db.get_last_method_of(user.id)
|
||||
inputs = await self.db.get_last_inputs_of(user.id)
|
||||
|
||||
if not prompt:
|
||||
await self.reply_to(msg, 'no last prompt found, try doing a non-redo request first')
|
||||
return
|
||||
|
||||
case _:
|
||||
await self.reply_to(msg, f'unknown request of type {msg.command}')
|
||||
return
|
||||
|
||||
if (
|
||||
msg.command == BaseCommands.IMG2IMG
|
||||
and
|
||||
len(inputs) == 0
|
||||
):
|
||||
await self.edit_msg(status_msg, 'seems you tried to do an img2img command without sending image')
|
||||
return
|
||||
|
||||
# maybe apply recomended settings to this request
|
||||
del user_row['id']
|
||||
if user_row['autoconf']:
|
||||
user_row = perform_auto_conf(user_row)
|
||||
|
||||
user_row = sanitize_params(user_row)
|
||||
|
||||
body = BodyV0(
|
||||
method=command,
|
||||
params=BodyV0Params(
|
||||
prompt=prompt,
|
||||
**user_row
|
||||
)
|
||||
)
|
||||
|
||||
# publish inputs to ipfs
|
||||
input_cids = []
|
||||
for i in inputs:
|
||||
await i.publish(self.ipfs, user_row)
|
||||
input_cids.append(i.cid)
|
||||
|
||||
inputs_str = ','.join((i for i in input_cids))
|
||||
|
||||
# unless its a redo request, update db user data
|
||||
if command != BaseCommands.REDO:
|
||||
await self.db.update_user_stats(
|
||||
user.id,
|
||||
command,
|
||||
last_prompt=prompt,
|
||||
last_inputs=inputs
|
||||
)
|
||||
|
||||
await self.update_request_status_step_0(status_msg, msg)
|
||||
|
||||
# prepare and send enqueue request
|
||||
request_time = datetime.now().isoformat()
|
||||
str_body = msgspec.json.encode(body).decode('utf-8')
|
||||
|
||||
enqueue_receipt = await self.cleos.a_push_action(
|
||||
self.config.receiver,
|
||||
'enqueue',
|
||||
[
|
||||
self.config.account,
|
||||
str_body,
|
||||
inputs_str,
|
||||
self.config.reward,
|
||||
1
|
||||
],
|
||||
self.config.account,
|
||||
key=self.cleos.private_keys[self.config.account],
|
||||
permission=self.config.permission
|
||||
)
|
||||
|
||||
await self.update_request_status_step_1(status_msg, enqueue_receipt)
|
||||
|
||||
# wait and search submit request using hyperion endpoint
|
||||
console = enqueue_receipt['processed']['action_traces'][0]['console']
|
||||
console_lines = console.split('\n')
|
||||
|
||||
request_id = None
|
||||
request_hash = None
|
||||
if self.config.proto_version == 0:
|
||||
'''
|
||||
v0 has req_id:nonce printed in enqueue console output
|
||||
to search for a result request_hash arg on submit has
|
||||
to match the sha256 of nonce + body + input_str
|
||||
'''
|
||||
request_id, nonce = console_lines[-1].rstrip().split(':')
|
||||
request_hash = sha256(
|
||||
(nonce + str_body + inputs_str).encode('utf-8')).hexdigest().upper()
|
||||
|
||||
request_id = int(request_id)
|
||||
|
||||
elif self.config.proto_version == 1:
|
||||
'''
|
||||
v1 uses a global unique nonce and prints it on enqueue
|
||||
console output to search for a result request_id arg
|
||||
on submit has to match the printed req_id
|
||||
'''
|
||||
request_id = int(console_lines[-1].rstrip())
|
||||
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
worker = None
|
||||
submit_tx_hash = None
|
||||
result_cid = None
|
||||
for i in range(1, self.config.request_timeout + 1):
|
||||
try:
|
||||
submits = await self.hyperion.aget_actions(
|
||||
account=self.config.account,
|
||||
filter=f'{self.config.receiver}:submit',
|
||||
sort='desc',
|
||||
after=request_time
|
||||
)
|
||||
if self.config.proto_version == 0:
|
||||
actions = [
|
||||
action
|
||||
for action in submits['actions']
|
||||
if action['act']['data']['request_hash'] == request_hash
|
||||
]
|
||||
elif self.config.proto_version == 1:
|
||||
actions = [
|
||||
action
|
||||
for action in submits['actions']
|
||||
if action['act']['data']['request_id'] == request_id
|
||||
]
|
||||
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
if len(actions) > 0:
|
||||
action = actions[0]
|
||||
submit_tx_hash = action['trx_id']
|
||||
data = action['act']['data']
|
||||
result_cid = data['ipfs_hash']
|
||||
worker = data['worker']
|
||||
logging.info(f'found matching submit! tx: {submit_tx_hash} cid: {result_cid}')
|
||||
break
|
||||
|
||||
except json.JSONDecodeError:
|
||||
if i < self.config.request_timeout:
|
||||
logging.error('network error while searching for submit, retry...')
|
||||
|
||||
await asyncio.sleep(1)
|
||||
|
||||
# if we found matching submit submit_tx_hash, worker, and result_cid will not be None
|
||||
if not result_cid:
|
||||
await self.update_request_status_timeout(status_msg)
|
||||
raise RequestTimeoutError
|
||||
|
||||
await self.update_request_status_step_2(status_msg, submit_tx_hash)
|
||||
|
||||
# attempt to get the image and send it
|
||||
result_link = f'https://{self.config.ipfs_domain}/ipfs/{result_cid}'
|
||||
get_img_response = await get_ipfs_file(result_link)
|
||||
|
||||
result_img = None
|
||||
if get_img_response and get_img_response.status_code == 200:
|
||||
try:
|
||||
with Image.open(io.BytesIO(get_img_response.read())) as img:
|
||||
w, h = img.size
|
||||
|
||||
if (
|
||||
w > self.config.result_max_width
|
||||
or
|
||||
h > self.config.result_max_height
|
||||
):
|
||||
max_size = (self.config.result_max_width, self.config.result_max_height)
|
||||
logging.warning(
|
||||
f'raw result is of size {img.size}, resizing to {max_size}')
|
||||
img.thumbnail(max_size)
|
||||
|
||||
tmp_buf = io.BytesIO()
|
||||
img.save(tmp_buf, format='PNG')
|
||||
result_img = tmp_buf.getvalue()
|
||||
|
||||
except UnidentifiedImageError:
|
||||
logging.warning(f'couldn\'t get ipfs result at {result_link}!')
|
||||
|
||||
await self.update_request_status_final(
|
||||
msg, status_msg, user, body.params, inputs, submit_tx_hash, worker, result_link, result_img)
|
||||
|
||||
await self.db.increment_generated(user.id)
|
||||
|
||||
async def send_help(self, msg: BaseMessage):
|
||||
if len(msg.text) == 0:
|
||||
await self.reply_to(msg, HELP_TEXT)
|
||||
|
||||
else:
|
||||
if msg.text in HELP_TOPICS:
|
||||
await self.reply_to(msg, HELP_TOPICS[msg.text])
|
||||
|
||||
else:
|
||||
await self.reply_to(msg, HELP_UNKWNOWN_PARAM)
|
||||
|
||||
async def send_cool_words(self, msg: BaseMessage):
|
||||
await self.reply_to(msg, '\n'.join(COOL_WORDS))
|
||||
|
||||
async def get_queue(self, msg: BaseMessage):
|
||||
an_hour_ago = datetime.now() - timedelta(hours=1)
|
||||
queue = await self.cleos.aget_table(
|
||||
self.config.receiver, self.config.receiver, 'queue',
|
||||
index_position=2,
|
||||
key_type='i64',
|
||||
sort='desc',
|
||||
lower_bound=int(an_hour_ago.timestamp())
|
||||
)
|
||||
await self.reply_to(
|
||||
msg, f'Requests on skynet queue: {len(queue)}')
|
||||
|
||||
async def set_config(self, msg: BaseMessage):
|
||||
try:
|
||||
attr, val, reply_txt = validate_user_config_request(msg.text)
|
||||
|
||||
await self.db.update_user_config(msg.author.id, attr, val)
|
||||
|
||||
except BaseException as e:
|
||||
reply_txt = str(e)
|
||||
|
||||
finally:
|
||||
await self.reply_to(msg, reply_txt)
|
||||
|
||||
async def user_stats(self, msg: BaseMessage):
|
||||
await self.db.get_or_create_user(msg.author.id)
|
||||
generated, joined, role = await self.db.get_user_stats(msg.author.id)
|
||||
|
||||
stats_str = f'generated: {generated}\n'
|
||||
stats_str += f'joined: {joined}\n'
|
||||
stats_str += f'role: {role}\n'
|
||||
|
||||
await self.reply_to(msg, stats_str)
|
||||
|
||||
async def donation_info(self, msg: BaseMessage):
|
||||
await self.reply_to(msg, DONATION_INFO)
|
||||
|
||||
async def say(self, msg: BaseMessage):
|
||||
if (
|
||||
msg.chat.is_private
|
||||
or
|
||||
not msg.author.is_admin
|
||||
):
|
||||
return
|
||||
|
||||
await self.new_msg(self.main_group, msg.text)
|
||||
|
||||
async def echo_unknown(self, msg: BaseMessage):
|
||||
await self.reply_to(msg, UNKNOWN_CMD_TEXT)
|
|
@ -0,0 +1,424 @@
|
|||
import logging
|
||||
import random
|
||||
import string
|
||||
import time
|
||||
from datetime import datetime
|
||||
|
||||
import docker
|
||||
import psycopg2
|
||||
import asyncpg
|
||||
|
||||
from psycopg2.extensions import ISOLATION_LEVEL_AUTOCOMMIT
|
||||
from contextlib import contextmanager as cm
|
||||
|
||||
from skynet.constants import (
|
||||
DEFAULT_ROLE, DEFAULT_MODEL, DEFAULT_STEP,
|
||||
DEFAULT_WIDTH, DEFAULT_HEIGHT, DEFAULT_GUIDANCE,
|
||||
DEFAULT_STRENGTH, DEFAULT_UPSCALER
|
||||
)
|
||||
from skynet.frontend.chatbot.types import BaseFileInput
|
||||
|
||||
DB_INIT_SQL = """
|
||||
CREATE SCHEMA IF NOT EXISTS skynet;
|
||||
|
||||
CREATE TABLE IF NOT EXISTS skynet.user(
|
||||
id BIGSERIAL PRIMARY KEY NOT NULL,
|
||||
generated INT NOT NULL,
|
||||
joined TIMESTAMP NOT NULL,
|
||||
last_method TEXT,
|
||||
last_prompt TEXT,
|
||||
last_inputs TEXT,
|
||||
role VARCHAR(128) NOT NULL
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS skynet.user_config(
|
||||
id BIGSERIAL NOT NULL,
|
||||
model VARCHAR(512) NOT NULL,
|
||||
step INT NOT NULL,
|
||||
width INT NOT NULL,
|
||||
height INT NOT NULL,
|
||||
seed NUMERIC,
|
||||
guidance DECIMAL NOT NULL,
|
||||
strength DECIMAL NOT NULL,
|
||||
upscaler VARCHAR(128),
|
||||
autoconf BOOLEAN DEFAULT TRUE,
|
||||
CONSTRAINT fk_config
|
||||
FOREIGN KEY(id)
|
||||
REFERENCES skynet.user(id)
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS skynet.user_requests(
|
||||
id BIGSERIAL NOT NULL,
|
||||
user_id BIGSERIAL NOT NULL,
|
||||
sent TIMESTAMP NOT NULL,
|
||||
status TEXT NOT NULL,
|
||||
status_msg BIGSERIAL PRIMARY KEY NOT NULL,
|
||||
CONSTRAINT fk_user_req
|
||||
FOREIGN KEY(user_id)
|
||||
REFERENCES skynet.user(id)
|
||||
);
|
||||
"""
|
||||
|
||||
|
||||
@cm
|
||||
def open_new_database(cleanup: bool = True):
|
||||
"""
|
||||
Context manager that spins up a temporary Postgres Docker container,
|
||||
creates a 'skynet' user and database, and yields (container, password, host).
|
||||
Stops the container on exit if 'cleanup' is True.
|
||||
"""
|
||||
root_password = "".join(random.choice(string.ascii_lowercase) for _ in range(12))
|
||||
skynet_password = "".join(random.choice(string.ascii_lowercase) for _ in range(12))
|
||||
|
||||
dclient = docker.from_env()
|
||||
container = dclient.containers.run(
|
||||
"postgres",
|
||||
name="skynet-test-postgres",
|
||||
ports={"5432/tcp": None},
|
||||
environment={"POSTGRES_PASSWORD": root_password},
|
||||
detach=True,
|
||||
)
|
||||
|
||||
try:
|
||||
# Wait for Postgres to be ready
|
||||
for log_line in container.logs(stream=True):
|
||||
line = log_line.decode().rstrip()
|
||||
logging.info(line)
|
||||
if (
|
||||
"database system is ready to accept connections" in line
|
||||
or "database system is shut down" in line
|
||||
):
|
||||
break
|
||||
|
||||
container.reload()
|
||||
port_info = container.ports["5432/tcp"][0]
|
||||
port = port_info["HostPort"]
|
||||
db_host = f"localhost:{port}"
|
||||
|
||||
# Let PostgreSQL settle
|
||||
time.sleep(1)
|
||||
logging.info("Creating 'skynet' database...")
|
||||
|
||||
conn = psycopg2.connect(
|
||||
user="postgres", password=root_password, host="localhost", port=port
|
||||
)
|
||||
conn.set_isolation_level(ISOLATION_LEVEL_AUTOCOMMIT)
|
||||
conn.autocommit = True
|
||||
cursor = conn.cursor()
|
||||
cursor.execute(f"CREATE USER skynet WITH PASSWORD '{skynet_password}'")
|
||||
cursor.execute("CREATE DATABASE skynet")
|
||||
cursor.execute("GRANT ALL PRIVILEGES ON DATABASE skynet TO skynet")
|
||||
cursor.close()
|
||||
conn.close()
|
||||
|
||||
logging.info("Database setup complete.")
|
||||
yield container, skynet_password, db_host
|
||||
|
||||
finally:
|
||||
if container and cleanup:
|
||||
container.stop()
|
||||
|
||||
|
||||
class FrontendUserDB:
|
||||
"""
|
||||
A class that manages the connection pool for the 'skynet' database,
|
||||
initializes the schema if needed, and provides high-level methods
|
||||
for interacting with the 'skynet' tables.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
db_user: str,
|
||||
db_pass: str,
|
||||
db_host: str,
|
||||
db_name: str
|
||||
):
|
||||
self.db_user = db_user
|
||||
self.db_pass = db_pass
|
||||
self.db_host = db_host
|
||||
self.db_name = db_name
|
||||
self.pool: asyncpg.Pool | None = None
|
||||
|
||||
async def __aenter__(self) -> "FrontendUserDB":
|
||||
dsn = f"postgres://{self.db_user}:{self.db_pass}@{self.db_host}/{self.db_name}"
|
||||
self.pool = await asyncpg.create_pool(dsn=dsn)
|
||||
await self._init_db()
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
if self.pool:
|
||||
await self.pool.close()
|
||||
|
||||
async def _init_db(self):
|
||||
"""
|
||||
Ensures the 'skynet' schema and tables exist. Also checks for
|
||||
missing columns and adds them if necessary.
|
||||
"""
|
||||
async with self.pool.acquire() as conn:
|
||||
# Check if schema is already initialized
|
||||
result = await conn.fetch("""
|
||||
SELECT DISTINCT table_schema
|
||||
FROM information_schema.tables
|
||||
WHERE table_schema = 'skynet'
|
||||
""")
|
||||
if not result:
|
||||
await conn.execute(DB_INIT_SQL)
|
||||
|
||||
# Check if 'autoconf' column exists in user_config
|
||||
col_check = await conn.fetch("""
|
||||
SELECT column_name
|
||||
FROM information_schema.columns
|
||||
WHERE table_name = 'user_config' AND column_name = 'autoconf'
|
||||
""")
|
||||
if not col_check:
|
||||
await conn.execute(
|
||||
"ALTER TABLE skynet.user_config ADD COLUMN autoconf BOOLEAN DEFAULT TRUE;"
|
||||
)
|
||||
|
||||
# -------------
|
||||
# USER METHODS
|
||||
# -------------
|
||||
|
||||
async def get_user_config(self, user_id: int):
|
||||
"""
|
||||
Fetches the user_config for the given user ID.
|
||||
Returns the record if found, otherwise None.
|
||||
"""
|
||||
async with self.pool.acquire() as conn:
|
||||
records = await conn.fetch(
|
||||
"SELECT * FROM skynet.user_config WHERE id = $1", user_id
|
||||
)
|
||||
return dict(records[0]) if len(records) == 1 else None
|
||||
|
||||
async def get_user(self, user_id: int):
|
||||
"""Alias for get_user_config (same data returned)."""
|
||||
return await self.get_user_config(user_id)
|
||||
|
||||
async def new_user(self, user_id: int):
|
||||
"""
|
||||
Inserts a new user in skynet.user and its corresponding user_config record.
|
||||
Raises ValueError if the user already exists.
|
||||
"""
|
||||
existing = await self.get_user(user_id)
|
||||
if existing:
|
||||
raise ValueError("User already present in DB")
|
||||
|
||||
logging.info(f"New user! {user_id}")
|
||||
now = datetime.utcnow()
|
||||
|
||||
async with self.pool.acquire() as conn:
|
||||
async with conn.transaction():
|
||||
await conn.execute(
|
||||
"""
|
||||
INSERT INTO skynet.user(
|
||||
id, generated, joined,
|
||||
last_method, last_prompt, last_inputs, role
|
||||
)
|
||||
VALUES($1, 0, $2, 'txt2img', NULL, NULL, $3)
|
||||
""",
|
||||
user_id,
|
||||
now,
|
||||
DEFAULT_ROLE,
|
||||
)
|
||||
await conn.execute(
|
||||
"""
|
||||
INSERT INTO skynet.user_config(
|
||||
id, model, step, width,
|
||||
height, guidance, strength, upscaler
|
||||
)
|
||||
VALUES($1, $2, $3, $4, $5, $6, $7, $8)
|
||||
""",
|
||||
user_id,
|
||||
DEFAULT_MODEL,
|
||||
DEFAULT_STEP,
|
||||
DEFAULT_WIDTH,
|
||||
DEFAULT_HEIGHT,
|
||||
DEFAULT_GUIDANCE,
|
||||
DEFAULT_STRENGTH,
|
||||
DEFAULT_UPSCALER,
|
||||
)
|
||||
|
||||
async def get_or_create_user(self, user_id: int):
|
||||
"""
|
||||
Retrieves a user_config record for the given user_id.
|
||||
If none exists, creates the user and returns the new record.
|
||||
"""
|
||||
user_cfg = await self.get_user(user_id)
|
||||
if not user_cfg:
|
||||
await self.new_user(user_id)
|
||||
user_cfg = await self.get_user(user_id)
|
||||
return user_cfg
|
||||
|
||||
async def update_user(self, user_id: int, attr: str, val):
|
||||
"""
|
||||
Generic function to update a single field in skynet.user for a given user_id.
|
||||
"""
|
||||
async with self.pool.acquire() as conn:
|
||||
await conn.execute(
|
||||
f"UPDATE skynet.user SET {attr} = $2 WHERE id = $1", user_id, val
|
||||
)
|
||||
|
||||
async def update_user_config(self, user_id: int, attr: str, val):
|
||||
"""
|
||||
Generic function to update a single field in skynet.user_config for a given user_id.
|
||||
"""
|
||||
async with self.pool.acquire() as conn:
|
||||
await conn.execute(
|
||||
f"UPDATE skynet.user_config SET {attr} = $2 WHERE id = $1", user_id, val
|
||||
)
|
||||
|
||||
async def get_user_stats(self, user_id: int):
|
||||
"""
|
||||
Returns (generated, joined, role) for the given user_id.
|
||||
"""
|
||||
async with self.pool.acquire() as conn:
|
||||
records = await conn.fetch(
|
||||
"""
|
||||
SELECT generated, joined, role
|
||||
FROM skynet.user
|
||||
WHERE id = $1
|
||||
""",
|
||||
user_id,
|
||||
)
|
||||
return records[0] if records else None
|
||||
|
||||
async def increment_generated(self, user_id: int):
|
||||
"""
|
||||
Increments the 'generated' count for a given user by 1.
|
||||
"""
|
||||
async with self.pool.acquire() as conn:
|
||||
await conn.execute(
|
||||
"""
|
||||
UPDATE skynet.user
|
||||
SET generated = generated + 1
|
||||
WHERE id = $1
|
||||
""",
|
||||
user_id,
|
||||
)
|
||||
|
||||
async def update_user_stats(
|
||||
self,
|
||||
user_id: int,
|
||||
method: str,
|
||||
last_prompt: str | None = None,
|
||||
last_inputs: list | None = None
|
||||
):
|
||||
"""
|
||||
Updates various 'last_*' fields in skynet.user.
|
||||
"""
|
||||
await self.update_user(user_id, "last_method", method)
|
||||
if last_prompt is not None:
|
||||
await self.update_user(user_id, "last_prompt", last_prompt)
|
||||
|
||||
last_inputs_str = None
|
||||
if isinstance(last_inputs, list):
|
||||
last_inputs_str = ','.join((f'{f.id}:{f.cid}' for f in last_inputs))
|
||||
await self.update_user(user_id, "last_inputs", last_inputs_str)
|
||||
|
||||
logging.info("Updated user stats: %s", (method, last_prompt, last_inputs_str))
|
||||
|
||||
# ----------------------
|
||||
# USER REQUESTS METHODS
|
||||
# ----------------------
|
||||
|
||||
async def get_user_request(self, request_id: int):
|
||||
"""
|
||||
Fetches all matching rows for a given request_id.
|
||||
"""
|
||||
async with self.pool.acquire() as conn:
|
||||
return await conn.fetch(
|
||||
"SELECT * FROM skynet.user_requests WHERE id = $1", request_id
|
||||
)
|
||||
|
||||
async def get_user_request_by_sid(self, status_msg_id: int):
|
||||
"""
|
||||
Fetches exactly one row (first row) by status_msg primary key.
|
||||
"""
|
||||
async with self.pool.acquire() as conn:
|
||||
records = await conn.fetch(
|
||||
"SELECT * FROM skynet.user_requests WHERE status_msg = $1", status_msg_id
|
||||
)
|
||||
return records[0] if records else None
|
||||
|
||||
async def new_user_request(
|
||||
self,
|
||||
user_id: int,
|
||||
request_id: int,
|
||||
status_msg_id: int,
|
||||
status: str = "started processing request..."
|
||||
):
|
||||
"""
|
||||
Inserts a new row in skynet.user_requests.
|
||||
"""
|
||||
now = datetime.utcnow()
|
||||
async with self.pool.acquire() as conn:
|
||||
async with conn.transaction():
|
||||
await conn.execute(
|
||||
"""
|
||||
INSERT INTO skynet.user_requests(
|
||||
id, user_id, sent, status, status_msg
|
||||
)
|
||||
VALUES($1, $2, $3, $4, $5)
|
||||
""",
|
||||
request_id, user_id, now, status, status_msg_id
|
||||
)
|
||||
|
||||
async def update_user_request(self, request_id: int, status: str):
|
||||
"""
|
||||
Updates the 'status' for a user request identified by 'request_id'.
|
||||
"""
|
||||
async with self.pool.acquire() as conn:
|
||||
await conn.execute(
|
||||
"""
|
||||
UPDATE skynet.user_requests
|
||||
SET status = $2
|
||||
WHERE id = $1
|
||||
""",
|
||||
request_id, status
|
||||
)
|
||||
|
||||
async def update_user_request_by_sid(self, sid: int, status: str):
|
||||
"""
|
||||
Updates the 'status' for a user request identified by 'status_msg'.
|
||||
"""
|
||||
async with self.pool.acquire() as conn:
|
||||
await conn.execute(
|
||||
"""
|
||||
UPDATE skynet.user_requests
|
||||
SET status = $2
|
||||
WHERE status_msg = $1
|
||||
""",
|
||||
sid, status
|
||||
)
|
||||
|
||||
# ----------------------------
|
||||
# Convenience "Get Last" Helpers
|
||||
# ----------------------------
|
||||
|
||||
async def get_last_method_of(self, user_id: int) -> str | None:
|
||||
async with self.pool.acquire() as conn:
|
||||
return await conn.fetchval(
|
||||
"SELECT last_method FROM skynet.user WHERE id = $1", user_id
|
||||
)
|
||||
|
||||
async def get_last_prompt_of(self, user_id: int) -> str | None:
|
||||
async with self.pool.acquire() as conn:
|
||||
return await conn.fetchval(
|
||||
"SELECT last_prompt FROM skynet.user WHERE id = $1", user_id
|
||||
)
|
||||
|
||||
async def get_last_inputs_of(self, user_id: int) -> list[BaseFileInput] | None:
|
||||
async with self.pool.acquire() as conn:
|
||||
last_inputs_str = await conn.fetchval(
|
||||
"SELECT last_inputs FROM skynet.user WHERE id = $1", user_id
|
||||
)
|
||||
|
||||
if not last_inputs_str:
|
||||
return []
|
||||
|
||||
last_inputs = []
|
||||
for i in last_inputs_str.split(','):
|
||||
id, cid = i.split(':')
|
||||
last_inputs.from_values(id, cid)
|
||||
|
||||
return last_inputs
|
|
@ -0,0 +1,406 @@
|
|||
import logging
|
||||
from typing import Self, Awaitable
|
||||
from datetime import datetime, timezone
|
||||
|
||||
import discord
|
||||
from discord import (
|
||||
User as DCUser,
|
||||
Member,
|
||||
Message as DCMessage,
|
||||
Attachment,
|
||||
DMChannel
|
||||
)
|
||||
from discord.abc import Messageable
|
||||
from discord.ext import commands
|
||||
|
||||
from skynet.config import FrontendConfig
|
||||
from skynet.types import BodyV0Params
|
||||
from skynet.constants import VERSION
|
||||
from skynet.frontend.chatbot import BaseChatbot
|
||||
from skynet.frontend.chatbot.db import FrontendUserDB
|
||||
from skynet.frontend.chatbot.types import (
|
||||
BaseUser,
|
||||
BaseChatRoom,
|
||||
BaseFileInput,
|
||||
BaseCommands,
|
||||
BaseMessage
|
||||
)
|
||||
|
||||
GROUP_ID = -1
|
||||
ADMIN_USER_ID = -1
|
||||
|
||||
|
||||
def timestamp_pretty():
|
||||
return datetime.now(timezone.utc).strftime('%H:%M:%S')
|
||||
|
||||
|
||||
class DiscordUser(BaseUser):
|
||||
|
||||
def __init__(self, user: DCUser | Member):
|
||||
self._user = user
|
||||
|
||||
@property
|
||||
def id(self) -> int:
|
||||
return self._user.id
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return self._user.name
|
||||
|
||||
@property
|
||||
def is_admin(self) -> bool:
|
||||
return self.id == ADMIN_USER_ID
|
||||
|
||||
|
||||
class DiscordChatRoom(BaseChatRoom):
|
||||
|
||||
def __init__(self, channel: Messageable):
|
||||
self._channel = channel
|
||||
|
||||
@property
|
||||
def id(self) -> int:
|
||||
return self._channel.id
|
||||
|
||||
@property
|
||||
def is_private(self) -> bool:
|
||||
return isinstance(self._channel, DMChannel)
|
||||
|
||||
class DiscordFileInput(BaseFileInput):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
attachment: Attachment | None = None,
|
||||
id: int | None = None,
|
||||
cid: int | None = None
|
||||
):
|
||||
self._attachment = attachment
|
||||
self._id = id
|
||||
self._cid = cid
|
||||
|
||||
self._raw = None
|
||||
|
||||
def from_values(id: int, cid: str) -> Self:
|
||||
return DiscordFileInput(id=id, cid=cid)
|
||||
|
||||
@property
|
||||
def id(self) -> int:
|
||||
if self._id:
|
||||
return self._id
|
||||
|
||||
return self._attachment.id
|
||||
|
||||
@property
|
||||
def cid(self) -> str:
|
||||
if self._cid:
|
||||
return self._cid
|
||||
|
||||
raise ValueError
|
||||
|
||||
def set_cid(self, cid: str):
|
||||
self._cid = cid
|
||||
|
||||
async def download(self) -> bytes:
|
||||
self._raw = await self._attachment.read()
|
||||
return self._raw
|
||||
|
||||
|
||||
class DiscordMessage(BaseMessage):
|
||||
|
||||
def __init__(self, cmd: BaseCommands | None, msg: DCMessage):
|
||||
self._msg = msg
|
||||
self._cmd = cmd
|
||||
self._chat = DiscordChatRoom(msg.channel)
|
||||
self._inputs: list[DiscordFileInput] | None = None
|
||||
self._author = None
|
||||
|
||||
@property
|
||||
def id(self) -> int:
|
||||
return self._msg.id
|
||||
|
||||
@property
|
||||
def chat(self) -> DiscordChatRoom:
|
||||
return self._chat
|
||||
|
||||
@property
|
||||
def text(self) -> str:
|
||||
# remove command name, slash and first space
|
||||
return self._msg.contents[len(self._cmd) + 2:]
|
||||
|
||||
@property
|
||||
def author(self) -> DiscordUser:
|
||||
if self._author:
|
||||
return self._author
|
||||
|
||||
return DiscordUser(self._msg.author)
|
||||
|
||||
@property
|
||||
def command(self) -> str | None:
|
||||
return self._cmd
|
||||
|
||||
@property
|
||||
def inputs(self) -> list[DiscordFileInput]:
|
||||
if self._inputs is None:
|
||||
self._inputs = []
|
||||
if self._msg.attachments:
|
||||
self._inputs = [
|
||||
DiscordFileInput(attachment=a)
|
||||
for a in self._msg.attachments
|
||||
]
|
||||
|
||||
return self._inputs
|
||||
|
||||
|
||||
def generate_reply_embed(
|
||||
config: FrontendConfig,
|
||||
user: DiscordUser,
|
||||
params: BodyV0Params,
|
||||
tx_hash: str,
|
||||
worker: str,
|
||||
) -> discord.Embed:
|
||||
embed = discord.Embed(
|
||||
title='[SKYNET Transaction Explorer]',
|
||||
url=f'https://{config.explorer_domain}/v2/explore/transaction/{tx_hash}',
|
||||
color=discord.Color.blue())
|
||||
|
||||
prompt = params.prompt
|
||||
if len(prompt) > 256:
|
||||
prompt = prompt[:256]
|
||||
|
||||
gen_str = f'generated by {user.name}\n'
|
||||
gen_str += f'performed by {worker}\n'
|
||||
gen_str += f'reward: {config.reward}\n'
|
||||
|
||||
embed.add_field(
|
||||
name='General Info', value=f'```{gen_str}```', inline=False)
|
||||
# meta_str = f'__by {user.name}__\n'
|
||||
# meta_str += f'*performed by {worker}*\n'
|
||||
# meta_str += f'__**reward: {reward}**__\n'
|
||||
embed.add_field(name='Prompt', value=f'```{prompt}\n```', inline=False)
|
||||
|
||||
# meta_str = f'`prompt:` {prompt}\n'
|
||||
|
||||
meta_str = f'seed: {params.seed}\n'
|
||||
meta_str += f'step: {params.step}\n'
|
||||
if params.guidance:
|
||||
meta_str += f'guidance: {params.guidance}\n'
|
||||
|
||||
if params.strength:
|
||||
meta_str += f'strength: {params.strength}\n'
|
||||
|
||||
meta_str += f'algo: {params.model}\n'
|
||||
if params.upscaler:
|
||||
meta_str += f'upscaler: {params.upscaler}\n'
|
||||
|
||||
embed.add_field(name='Parameters', value=f'```{meta_str}```', inline=False)
|
||||
|
||||
foot_str = f'Made with Skynet v{VERSION}\n'
|
||||
foot_str += 'JOIN THE SWARM: https://discord.gg/PAabjJtZAF'
|
||||
|
||||
embed.set_footer(text=foot_str)
|
||||
|
||||
return embed
|
||||
|
||||
|
||||
def append_command_handler(client: discord.Client, command: str, help_txt: str, fn: Awaitable):
|
||||
@client.command(name=command, help=help_txt)
|
||||
async def wrap_msg_and_handle(ctx: commands.Context):
|
||||
msg = DiscordMessage(cmd=command, msg=ctx.message)
|
||||
for file in msg.inputs:
|
||||
await file.download()
|
||||
await fn(msg)
|
||||
|
||||
|
||||
class DiscordChatbot(BaseChatbot):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: FrontendConfig,
|
||||
db: FrontendUserDB,
|
||||
):
|
||||
super().__init__(config, db)
|
||||
intents = discord.Intents(
|
||||
messages=True,
|
||||
guilds=True,
|
||||
typing=True,
|
||||
members=True,
|
||||
presences=True,
|
||||
reactions=True,
|
||||
message_content=True,
|
||||
voice_states=True
|
||||
)
|
||||
client = discord.Client(
|
||||
command_prefix='/',
|
||||
intents=intents
|
||||
)
|
||||
|
||||
@client.event
|
||||
async def on_ready():
|
||||
print(f'{client.user.name} has connected to Discord!')
|
||||
for guild in client.guilds:
|
||||
for channel in guild.channels:
|
||||
if channel.name == "skynet":
|
||||
await channel.send('Skynet bot online') # view=SkynetView(self.bot))
|
||||
await channel.send(
|
||||
'Welcome to Skynet\'s Discord Bot,\n\n'
|
||||
'Skynet operates as a decentralized compute layer, offering a wide array of '
|
||||
'support for diverse AI paradigms through the use of blockchain technology. '
|
||||
'Our present focus is image generation, powered by 11 distinct models.\n\n'
|
||||
'To begin exploring, use the \'/help\' command or directly interact with the'
|
||||
'provided buttons. Here is an example command to generate an image:\n\n'
|
||||
'\'/txt2img a big red tractor in a giant field of corn\''
|
||||
)
|
||||
|
||||
print("\n==============")
|
||||
print("Logged in as")
|
||||
print(client.user.name)
|
||||
print(client.user.id)
|
||||
print("==============")
|
||||
|
||||
@client.event
|
||||
async def on_message(message: DCMessage):
|
||||
if message.author == client.user:
|
||||
return
|
||||
|
||||
await self.process_commands(message)
|
||||
|
||||
@client.event
|
||||
async def on_command_error(ctx, error):
|
||||
if isinstance(error, commands.MissingRequiredArgument):
|
||||
await ctx.send('You missed a required argument, please try again.')
|
||||
|
||||
append_command_handler(client, BaseCommands.HELP, 'Responds with help text', self.send_help)
|
||||
append_command_handler(client, BaseCommands.COOL, 'Display a list of cool prompt words', self.send_cool_words)
|
||||
append_command_handler(client, BaseCommands.QUEUE, 'Get information on current skynet queue', self.get_queue)
|
||||
append_command_handler(client, BaseCommands.CONFIG, 'Allows user to configure inference params', self.set_config)
|
||||
append_command_handler(client, BaseCommands.STATS, 'See user statistics', self.user_stats)
|
||||
append_command_handler(client, BaseCommands.DONATE, 'See donation information', self.donation_info)
|
||||
append_command_handler(client, BaseCommands.SAY, 'Admin command to make bot speak', self.say)
|
||||
|
||||
append_command_handler(client, BaseCommands.TXT2IMG, 'Generate an image from a prompt', self.handle_request)
|
||||
append_command_handler(client, BaseCommands.REDO, 'Re-generate image using last prompt', self.handle_request)
|
||||
|
||||
self.client = client
|
||||
self._main_room: DiscordChatRoom | None = None
|
||||
|
||||
async def init(self):
|
||||
dc_channel = await self.client.get_channel(GROUP_ID)
|
||||
self._main_room = DiscordChatRoom(channel=dc_channel)
|
||||
logging.info('initialized')
|
||||
|
||||
async def run(self):
|
||||
await self.init()
|
||||
await self.client.run(self.config.token)
|
||||
|
||||
@property
|
||||
def main_group(self) -> DiscordChatRoom:
|
||||
return self._main_room
|
||||
|
||||
async def new_msg(self, chat: DiscordChatRoom, text: str, **kwargs) -> DiscordMessage:
|
||||
dc_msg = await chat._channel.send(text, **kwargs)
|
||||
return DiscordMessage(cmd=None, msg=dc_msg)
|
||||
|
||||
async def reply_to(self, msg: DiscordMessage, text: str, **kwargs) -> DiscordMessage:
|
||||
dc_msg = await msg._msg.reply(content=text, **kwargs)
|
||||
return DiscordMessage(cmd=None, msg=dc_msg)
|
||||
|
||||
async def edit_msg(self, msg: DiscordMessage, text: str, **kwargs):
|
||||
await msg._msg.edit(content=text, **kwargs)
|
||||
|
||||
async def create_status_msg(self, msg: DiscordMessage, init_text: str, force_user: DiscordUser | None = None) -> tuple[BaseUser, BaseMessage, dict]:
|
||||
# maybe init user
|
||||
user = msg.author
|
||||
if force_user:
|
||||
user = force_user
|
||||
|
||||
user_row = await self.db.get_or_create_user(user.id)
|
||||
|
||||
# create status msg
|
||||
embed = discord.Embed(
|
||||
title='live updates',
|
||||
description=init_text,
|
||||
color=discord.Color.blue()
|
||||
)
|
||||
status_msg = await self.new_msg(msg.chat, None, embed=embed)
|
||||
|
||||
# start tracking of request in db
|
||||
await self.db.new_user_request(user.id, msg.id, status_msg.id, status=init_text)
|
||||
return [user, status_msg, user_row]
|
||||
|
||||
async def update_status_msg(self, msg: DiscordMessage, text: str):
|
||||
await self.db.update_user_request_by_sid(msg.id, text)
|
||||
embed = discord.Embed(
|
||||
title='live updates',
|
||||
description=text,
|
||||
color=discord.Color.blue()
|
||||
)
|
||||
await self.edit_msg(msg, None, embed=embed)
|
||||
|
||||
async def append_status_msg(self, msg: DiscordMessage, text: str):
|
||||
request = await self.db.get_user_request_by_sid(msg.id)
|
||||
await self.update_status_msg(msg, request['status'] + text)
|
||||
|
||||
async def update_request_status_timeout(self, status_msg: DiscordMessage):
|
||||
await self.append_status_msg(
|
||||
status_msg,
|
||||
f'\n[{timestamp_pretty()}] **timeout processing request**',
|
||||
)
|
||||
|
||||
async def update_request_status_step_0(self, status_msg: DiscordMessage, user_msg: DiscordMessage):
|
||||
await self.update_status_msg(
|
||||
status_msg,
|
||||
f'processing a \'{status_msg.cmd}\' request by {status_msg.author.name}\n'
|
||||
f'[{timestamp_pretty()}] *broadcasting transaction to chain...* '
|
||||
)
|
||||
|
||||
async def update_request_status_step_1(self, status_msg: DiscordMessage, tx_result: dict):
|
||||
enqueue_tx_id = tx_result['transaction_id']
|
||||
enqueue_tx_link = f'[**Your request on Skynet Explorer**](https://{self.config.explorer_domain}/v2/explore/transaction/{enqueue_tx_id})'
|
||||
await self.append_status_msg(
|
||||
status_msg,
|
||||
'**broadcasted!** \n'
|
||||
f'{enqueue_tx_link}\n'
|
||||
f'[{timestamp_pretty()}] *workers are processing request...* '
|
||||
)
|
||||
|
||||
async def update_request_status_step_2(self, status_msg: DiscordMessage, submit_tx_hash: str):
|
||||
tx_link = f'[**Your result on Skynet Explorer**](https://{self.config.explorer_domain}/v2/explore/transaction/{submit_tx_hash})'
|
||||
await self.append_status_msg(
|
||||
status_msg,
|
||||
'**request processed!**\n'
|
||||
f'{tx_link}\n'
|
||||
f'[{timestamp_pretty()}] *trying to download image...*\n '
|
||||
)
|
||||
|
||||
async def update_request_status_final(
|
||||
self,
|
||||
og_msg: DiscordMessage,
|
||||
status_msg: DiscordMessage,
|
||||
user: DiscordUser,
|
||||
params: BodyV0Params,
|
||||
inputs: list[DiscordFileInput],
|
||||
submit_tx_hash: str,
|
||||
worker: str,
|
||||
result_url: str,
|
||||
result_img: bytes | None
|
||||
):
|
||||
embed = generate_reply_embed(
|
||||
self.config, user, params, submit_tx_hash, worker)
|
||||
|
||||
if not result_img:
|
||||
# result found on chain but failed to fetch img from ipfs
|
||||
await self.append_status_msg(status_msg, f'[{timestamp_pretty()}] *Couldn\'t get IPFS hosted img [**here**]({result_url})!*')
|
||||
return
|
||||
|
||||
await status_msg._msg.delete()
|
||||
|
||||
embed.set_image(url=result_url)
|
||||
|
||||
match len(inputs):
|
||||
case 0:
|
||||
await self.new_msg(og_msg.chat, None, embed=embed)
|
||||
|
||||
case _:
|
||||
_input = inputs[-1]
|
||||
dc_file = discord.File(_input._raw, filename=f'image-{og_msg.id}.png')
|
||||
embed.set_thumbnail(url=f'attachment://image-{og_msg.id}.png')
|
||||
await self.new_msg(og_msg.chat, None, embed=embed, file=dc_file)
|
|
@ -0,0 +1,425 @@
|
|||
import json
|
||||
import logging
|
||||
import traceback
|
||||
|
||||
from typing import Self, Awaitable
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from telebot.types import (
|
||||
User as TGUser,
|
||||
Chat as TGChat,
|
||||
PhotoSize as TGPhotoSize,
|
||||
Message as TGMessage,
|
||||
CallbackQuery,
|
||||
InputMediaPhoto,
|
||||
InlineKeyboardButton,
|
||||
InlineKeyboardMarkup
|
||||
)
|
||||
from telebot.async_telebot import AsyncTeleBot, ExceptionHandler
|
||||
from telebot.formatting import hlink
|
||||
|
||||
from skynet.types import BodyV0Params
|
||||
from skynet.config import FrontendConfig
|
||||
from skynet.constants import VERSION
|
||||
from skynet.frontend.chatbot import BaseChatbot
|
||||
from skynet.frontend.chatbot.db import FrontendUserDB
|
||||
from skynet.frontend.chatbot.types import (
|
||||
BaseUser,
|
||||
BaseChatRoom,
|
||||
BaseFileInput,
|
||||
BaseCommands,
|
||||
BaseMessage
|
||||
)
|
||||
|
||||
GROUP_ID = -1001541979235
|
||||
TEST_GROUP_ID = -4099622703
|
||||
ADMIN_USER_ID = 383385940
|
||||
|
||||
|
||||
# Chatbot types impls
|
||||
|
||||
class TelegramUser(BaseUser):
|
||||
def __init__(self, user: TGUser):
|
||||
self._user = user
|
||||
|
||||
@property
|
||||
def id(self) -> int:
|
||||
return self._user.id
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
if self._user.username:
|
||||
return f'@{self._user.username}'
|
||||
|
||||
return f'{self._user.first_name} id: {self.id}'
|
||||
|
||||
@property
|
||||
def is_admin(self) -> bool:
|
||||
return self.id == ADMIN_USER_ID
|
||||
|
||||
|
||||
class TelegramChatRoom(BaseChatRoom):
|
||||
|
||||
def __init__(self, chat: TGChat):
|
||||
self._chat = chat
|
||||
|
||||
@property
|
||||
def id(self) -> int:
|
||||
return self._chat.id
|
||||
|
||||
@property
|
||||
def is_private(self) -> bool:
|
||||
return self._chat.type == 'private'
|
||||
|
||||
|
||||
class TelegramFileInput(BaseFileInput):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
photo: TGPhotoSize | None = None,
|
||||
id: int | None = None,
|
||||
cid: str | None = None
|
||||
):
|
||||
self._photo = photo
|
||||
self._id = id
|
||||
self._cid = cid
|
||||
|
||||
self._raw = None
|
||||
|
||||
def from_values(id: int, cid: str) -> Self:
|
||||
return TelegramFileInput(id=id, cid=cid)
|
||||
|
||||
@property
|
||||
def id(self) -> int:
|
||||
if self._id:
|
||||
return self._id
|
||||
|
||||
return self._photo.file_id
|
||||
|
||||
@property
|
||||
def cid(self) -> str:
|
||||
if self._cid:
|
||||
return self._cid
|
||||
|
||||
raise ValueError
|
||||
|
||||
def set_cid(self, cid: str):
|
||||
self._cid = cid
|
||||
|
||||
async def download(self, bot: AsyncTeleBot) -> bytes:
|
||||
file_path = (await bot.get_file(self.id)).file_path
|
||||
self._raw = await bot.download_file(file_path)
|
||||
return self._raw
|
||||
|
||||
|
||||
class TelegramMessage(BaseMessage):
|
||||
|
||||
def __init__(self, cmd: BaseCommands | None, msg: TGMessage):
|
||||
self._msg = msg
|
||||
self._cmd = cmd
|
||||
self._chat = TelegramChatRoom(msg.chat)
|
||||
self._inputs: list[TelegramFileInput] | None = None
|
||||
|
||||
self._author = None
|
||||
|
||||
@property
|
||||
def id(self) -> int:
|
||||
return self._msg.message_id
|
||||
|
||||
@property
|
||||
def chat(self) -> TelegramChatRoom:
|
||||
return self._chat
|
||||
|
||||
@property
|
||||
def text(self) -> str:
|
||||
# remove command name, slash and first space
|
||||
if self._msg.text:
|
||||
return self._msg.text[len(self._cmd) + 2:]
|
||||
|
||||
return self._msg.caption[len(self._cmd) + 2:]
|
||||
|
||||
@property
|
||||
def author(self) -> TelegramUser:
|
||||
if self._author:
|
||||
return self._author
|
||||
|
||||
return TelegramUser(self._msg.from_user)
|
||||
|
||||
@property
|
||||
def command(self) -> str | None:
|
||||
return self._cmd
|
||||
|
||||
@property
|
||||
def inputs(self) -> list[TelegramFileInput]:
|
||||
if self._inputs is None:
|
||||
self._inputs = []
|
||||
if self._msg.photo:
|
||||
self._inputs = [
|
||||
TelegramFileInput(photo=p)
|
||||
for p in self._msg.photo
|
||||
]
|
||||
|
||||
return self._inputs
|
||||
|
||||
|
||||
# generic tg utils
|
||||
|
||||
def timestamp_pretty():
|
||||
return datetime.now(timezone.utc).strftime('%H:%M:%S')
|
||||
|
||||
|
||||
class TGExceptionHandler(ExceptionHandler):
|
||||
|
||||
def handle(exception):
|
||||
traceback.print_exc()
|
||||
|
||||
|
||||
def build_redo_menu():
|
||||
btn_redo = InlineKeyboardButton("Redo", callback_data='{\"method\": \"redo\"}')
|
||||
inline_keyboard = InlineKeyboardMarkup()
|
||||
inline_keyboard.add(btn_redo)
|
||||
return inline_keyboard
|
||||
|
||||
|
||||
def prepare_metainfo_caption(user: TelegramUser, worker: str, reward: str, params: BodyV0Params) -> str:
|
||||
prompt = params.prompt
|
||||
if len(prompt) > 256:
|
||||
prompt = prompt[:256]
|
||||
|
||||
meta_str = f'<u>by {user.name}</u>\n'
|
||||
meta_str += f'<i>performed by {worker}</i>\n'
|
||||
meta_str += f'<b><u>reward: {reward}</u></b>\n'
|
||||
|
||||
meta_str += f'<code>prompt:</code> {prompt}\n'
|
||||
meta_str += f'<code>seed: {params.seed}</code>\n'
|
||||
meta_str += f'<code>step: {params.step}</code>\n'
|
||||
if params.guidance:
|
||||
meta_str += f'<code>guidance: {params.guidance}</code>\n'
|
||||
|
||||
if params.strength:
|
||||
meta_str += f'<code>strength: {params.strength}</code>\n'
|
||||
|
||||
meta_str += f'<code>algo: {params.model}</code>\n'
|
||||
|
||||
meta_str += f'<b><u>Made with Skynet v{VERSION}</u></b>\n'
|
||||
meta_str += '<b>JOIN THE SWARM: @skynetgpu</b>'
|
||||
return meta_str
|
||||
|
||||
|
||||
def generate_reply_caption(
|
||||
config: FrontendConfig,
|
||||
user: TelegramUser,
|
||||
params: BodyV0Params,
|
||||
tx_hash: str,
|
||||
worker: str,
|
||||
):
|
||||
explorer_link = hlink(
|
||||
'SKYNET Transaction Explorer',
|
||||
f'https://{config.explorer_domain}/v2/explore/transaction/{tx_hash}'
|
||||
)
|
||||
|
||||
meta_info = prepare_metainfo_caption(user, worker, config.reward, params)
|
||||
|
||||
final_msg = '\n'.join([
|
||||
'Worker finished your task!',
|
||||
explorer_link,
|
||||
f'PARAMETER INFO:\n{meta_info}'
|
||||
])
|
||||
|
||||
final_msg = '\n'.join([
|
||||
f'<b><i>{explorer_link}</i></b>',
|
||||
f'{meta_info}'
|
||||
])
|
||||
|
||||
return final_msg
|
||||
|
||||
|
||||
def append_handler(bot: AsyncTeleBot, command: str, fn: Awaitable):
|
||||
@bot.message_handler(commands=[command])
|
||||
async def wrap_msg_and_handle(tg_msg: TGMessage):
|
||||
await fn(TelegramMessage(cmd=command, msg=tg_msg))
|
||||
|
||||
|
||||
class TelegramChatbot(BaseChatbot):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: FrontendConfig,
|
||||
db: FrontendUserDB,
|
||||
):
|
||||
super().__init__(config, db)
|
||||
bot = AsyncTeleBot(config.token, exception_handler=TGExceptionHandler)
|
||||
|
||||
append_handler(bot, BaseCommands.HELP, self.send_help)
|
||||
append_handler(bot, BaseCommands.COOL, self.send_cool_words)
|
||||
append_handler(bot, BaseCommands.QUEUE, self.get_queue)
|
||||
append_handler(bot, BaseCommands.CONFIG, self.set_config)
|
||||
append_handler(bot, BaseCommands.STATS, self.user_stats)
|
||||
append_handler(bot, BaseCommands.DONATE, self.donation_info)
|
||||
append_handler(bot, BaseCommands.SAY, self.say)
|
||||
|
||||
append_handler(bot, BaseCommands.TXT2IMG, self.handle_request)
|
||||
|
||||
append_handler(bot, BaseCommands.IMG2IMG, self.handle_request)
|
||||
|
||||
@bot.message_handler(func=lambda _: True, content_types=['photo', 'document'])
|
||||
async def handle_img2img(tg_msg: TGMessage):
|
||||
msg = TelegramMessage(cmd='img2img', msg=tg_msg)
|
||||
for file in msg.inputs:
|
||||
await file.download(bot)
|
||||
await self.handle_request(msg)
|
||||
|
||||
append_handler(bot, BaseCommands.REDO, self.handle_request)
|
||||
|
||||
@bot.message_handler(func=lambda _: True)
|
||||
async def unknown_cmd(tg_msg: TGMessage):
|
||||
if tg_msg.text[0] == '/':
|
||||
msg = TelegramMessage(cmd='unknown', msg=tg_msg)
|
||||
await self.echo_unknown(msg)
|
||||
|
||||
@bot.callback_query_handler(func=lambda _: True)
|
||||
async def callback_query(call: CallbackQuery):
|
||||
call_json = json.loads(call.data)
|
||||
method = call_json.get('method')
|
||||
match method:
|
||||
case 'redo':
|
||||
msg = await self.new_msg(self.main_group, 'processing a redo request...')
|
||||
msg._cmd = 'redo'
|
||||
await self.handle_request(msg, force_user=TelegramUser(user=call.from_user))
|
||||
await bot.delete_message(chat_id=self.main_group.id, message_id=msg.id)
|
||||
|
||||
self.bot = bot
|
||||
|
||||
self._main_room: TelegramChatRoom | None = None
|
||||
|
||||
async def init(self):
|
||||
tg_group = await self.bot.get_chat(TEST_GROUP_ID)
|
||||
self._main_room = TelegramChatRoom(chat=tg_group)
|
||||
logging.info('initialized')
|
||||
|
||||
async def run(self):
|
||||
await self.init()
|
||||
await self.bot.infinity_polling()
|
||||
|
||||
@property
|
||||
def main_group(self) -> TelegramChatRoom:
|
||||
return self._main_room
|
||||
|
||||
async def new_msg(self, chat: TelegramChatRoom, text: str) -> TelegramMessage:
|
||||
msg = await self.bot.send_message(chat.id, text, parse_mode='HTML')
|
||||
return TelegramMessage(cmd=None, msg=msg)
|
||||
|
||||
async def reply_to(self, msg: TelegramMessage, text: str) -> TelegramMessage:
|
||||
msg = await self.bot.reply_to(msg._msg, text, parse_mode='HTML')
|
||||
return TelegramMessage(cmd=None, msg=msg)
|
||||
|
||||
async def edit_msg(self, msg: TelegramMessage, text: str):
|
||||
await self.bot.edit_message_text(
|
||||
text,
|
||||
chat_id=msg.chat.id,
|
||||
message_id=msg.id,
|
||||
parse_mode='HTML'
|
||||
)
|
||||
|
||||
async def update_request_status_timeout(self, status_msg: TelegramMessage):
|
||||
'''
|
||||
Notify users when we timedout trying to find a matching submit
|
||||
'''
|
||||
await self.append_status_msg(
|
||||
status_msg,
|
||||
f'\n[{timestamp_pretty()}] <b>timeout processing request</b>',
|
||||
)
|
||||
|
||||
async def update_request_status_step_0(self, status_msg: TelegramMessage, user_msg: TelegramMessage):
|
||||
'''
|
||||
First step in request status message lifecycle, should notify which user sent the request
|
||||
and that we are about to broadcast the request to chain
|
||||
'''
|
||||
await self.update_status_msg(
|
||||
status_msg,
|
||||
f'processing a \'{user_msg.command}\' request by {user_msg.author.name}\n'
|
||||
f'[{timestamp_pretty()}] <i>broadcasting transaction to chain...</i>'
|
||||
)
|
||||
|
||||
async def update_request_status_step_1(self, status_msg: TelegramMessage, tx_result: dict):
|
||||
'''
|
||||
Second step in request status message lifecycle, should notify enqueue transaction
|
||||
was processed by chain, and provide a link to the tx in the chain explorer
|
||||
'''
|
||||
enqueue_tx_id = tx_result['transaction_id']
|
||||
enqueue_tx_link = hlink(
|
||||
'Your request on Skynet Explorer',
|
||||
f'https://{self.config.explorer_domain}/v2/explore/transaction/{enqueue_tx_id}'
|
||||
)
|
||||
await self.append_status_msg(
|
||||
status_msg,
|
||||
f' <b>broadcasted!</b>\n'
|
||||
f'<b>{enqueue_tx_link}</b>\n'
|
||||
f'[{timestamp_pretty()}] <i>workers are processing request...</i>',
|
||||
)
|
||||
|
||||
async def update_request_status_step_2(self, status_msg: TelegramMessage, submit_tx_hash: str):
|
||||
'''
|
||||
Third step in request status message lifecycle, should notify matching submit transaction
|
||||
was found, and provide a link to the tx in the chain explorer
|
||||
'''
|
||||
tx_link = hlink(
|
||||
'Your result on Skynet Explorer',
|
||||
f'https://{self.config.explorer_domain}/v2/explore/transaction/{submit_tx_hash}'
|
||||
)
|
||||
await self.append_status_msg(
|
||||
status_msg,
|
||||
f' <b>request processed!</b>\n'
|
||||
f'<b>{tx_link}</b>\n'
|
||||
f'[{timestamp_pretty()}] <i>trying to download image...</i>\n',
|
||||
)
|
||||
|
||||
async def update_request_status_final(
|
||||
self,
|
||||
og_msg: TelegramMessage,
|
||||
status_msg: TelegramMessage,
|
||||
user: TelegramUser,
|
||||
params: BodyV0Params,
|
||||
inputs: list[TelegramFileInput],
|
||||
submit_tx_hash: str,
|
||||
worker: str,
|
||||
result_url: str,
|
||||
result_img: bytes | None
|
||||
):
|
||||
'''
|
||||
Last step in request status message lifecycle, should delete status message and send a
|
||||
new message replying to the original user's message, generate the appropiate
|
||||
reply caption and if provided also sent the found result img
|
||||
'''
|
||||
caption = generate_reply_caption(
|
||||
self.config, user, params, submit_tx_hash, worker)
|
||||
|
||||
await self.bot.delete_message(
|
||||
chat_id=status_msg.chat.id,
|
||||
message_id=status_msg.id
|
||||
)
|
||||
|
||||
if not result_img:
|
||||
# result found on chain but failed to fetch img from ipfs
|
||||
await self.reply_to(og_msg, caption, reply_markup=build_redo_menu())
|
||||
return
|
||||
|
||||
match len(inputs):
|
||||
case 0:
|
||||
await self.bot.send_photo(
|
||||
status_msg.chat.id,
|
||||
caption=caption,
|
||||
photo=result_img,
|
||||
reply_markup=build_redo_menu(),
|
||||
parse_mode='HTML'
|
||||
)
|
||||
|
||||
case _:
|
||||
_input = inputs[-1]
|
||||
await self.bot.send_media_group(
|
||||
status_msg.chat.id,
|
||||
media=[
|
||||
InputMediaPhoto(_input.id),
|
||||
InputMediaPhoto(result_img, caption=caption, parse_mode='HTML')
|
||||
]
|
||||
)
|
|
@ -0,0 +1,116 @@
|
|||
import io
|
||||
|
||||
from abc import ABC, abstractproperty, abstractmethod
|
||||
from enum import StrEnum
|
||||
from typing import Self
|
||||
from pathlib import Path
|
||||
from PIL import Image
|
||||
|
||||
from skynet.ipfs import AsyncIPFSHTTP
|
||||
|
||||
|
||||
class BaseUser(ABC):
|
||||
|
||||
@abstractproperty
|
||||
def id(self) -> int:
|
||||
...
|
||||
|
||||
@abstractproperty
|
||||
def name(self) -> str:
|
||||
...
|
||||
|
||||
@abstractproperty
|
||||
def is_admin(self) -> bool:
|
||||
...
|
||||
|
||||
|
||||
class BaseChatRoom(ABC):
|
||||
@abstractproperty
|
||||
def id(self) -> int:
|
||||
...
|
||||
|
||||
@abstractproperty
|
||||
def is_private(self) -> bool:
|
||||
...
|
||||
|
||||
|
||||
class BaseFileInput(ABC):
|
||||
|
||||
@staticmethod
|
||||
@abstractmethod
|
||||
def from_values(id: int, cid: str) -> Self:
|
||||
...
|
||||
|
||||
@abstractproperty
|
||||
def id(self) -> int:
|
||||
...
|
||||
|
||||
@abstractproperty
|
||||
def cid(self) -> str:
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def download(self, *args) -> bytes:
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def set_cid(self, cid: str):
|
||||
...
|
||||
|
||||
async def publish(self, ipfs_api: AsyncIPFSHTTP, user_row: dict):
|
||||
with Image.open(io.BytesIO(self._raw)) as img:
|
||||
w, h = img.size
|
||||
|
||||
if (
|
||||
w > user_row['width']
|
||||
or
|
||||
h > user_row['height']
|
||||
):
|
||||
img.thumbnail((user_row['width'], user_row['height']))
|
||||
|
||||
img_path = Path('/tmp/ipfs-staging/img.png')
|
||||
img.save(img_path, format='PNG')
|
||||
|
||||
ipfs_info = await ipfs_api.add(img_path)
|
||||
ipfs_hash = ipfs_info['Hash']
|
||||
self.set_cid(ipfs_hash)
|
||||
await ipfs_api.pin(ipfs_hash)
|
||||
|
||||
|
||||
class BaseCommands(StrEnum):
|
||||
TXT2IMG = 'txt2img'
|
||||
IMG2IMG = 'img2img'
|
||||
REDO = 'redo'
|
||||
HELP = 'help'
|
||||
COOL = 'cool'
|
||||
QUEUE = 'queue'
|
||||
CONFIG = 'config'
|
||||
STATS = 'stats'
|
||||
DONATE = 'donate'
|
||||
SAY = 'say'
|
||||
|
||||
|
||||
class BaseMessage(ABC):
|
||||
@abstractproperty
|
||||
def id(self) -> int:
|
||||
...
|
||||
|
||||
@abstractproperty
|
||||
def chat(self) -> BaseChatRoom:
|
||||
...
|
||||
|
||||
@abstractproperty
|
||||
def text(self) -> str:
|
||||
...
|
||||
|
||||
@abstractproperty
|
||||
def author(self) -> BaseUser:
|
||||
...
|
||||
|
||||
@abstractproperty
|
||||
def command(self) -> str | None:
|
||||
...
|
||||
|
||||
@abstractproperty
|
||||
def inputs(self) -> list[BaseFileInput]:
|
||||
...
|
|
@ -1,322 +0,0 @@
|
|||
from json import JSONDecodeError
|
||||
import random
|
||||
import logging
|
||||
import asyncio
|
||||
|
||||
from decimal import Decimal
|
||||
from hashlib import sha256
|
||||
from datetime import datetime
|
||||
from contextlib import (
|
||||
ExitStack,
|
||||
AsyncExitStack,
|
||||
)
|
||||
from contextlib import asynccontextmanager as acm
|
||||
|
||||
from leap.cleos import CLEOS
|
||||
from leap.sugar import (
|
||||
Name,
|
||||
asset_from_str,
|
||||
collect_stdout,
|
||||
)
|
||||
from leap.hyperion import HyperionAPI
|
||||
# from telebot.types import InputMediaPhoto
|
||||
|
||||
import discord
|
||||
import requests
|
||||
import io
|
||||
from PIL import Image, UnidentifiedImageError
|
||||
|
||||
from skynet.db import open_database_connection
|
||||
from skynet.ipfs import get_ipfs_file, AsyncIPFSHTTP
|
||||
from skynet.constants import *
|
||||
|
||||
from . import *
|
||||
from .bot import DiscordBot
|
||||
|
||||
from .utils import *
|
||||
from .handlers import create_handler_context
|
||||
from .ui import SkynetView
|
||||
|
||||
|
||||
class SkynetDiscordFrontend:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
# token: str,
|
||||
account: str,
|
||||
permission: str,
|
||||
node_url: str,
|
||||
hyperion_url: str,
|
||||
db_host: str,
|
||||
db_user: str,
|
||||
db_pass: str,
|
||||
ipfs_url: str,
|
||||
remote_ipfs_node: str,
|
||||
key: str,
|
||||
explorer_domain: str,
|
||||
ipfs_domain: str
|
||||
):
|
||||
# self.token = token
|
||||
self.account = account
|
||||
self.permission = permission
|
||||
self.node_url = node_url
|
||||
self.hyperion_url = hyperion_url
|
||||
self.db_host = db_host
|
||||
self.db_user = db_user
|
||||
self.db_pass = db_pass
|
||||
self.ipfs_url = ipfs_url
|
||||
self.remote_ipfs_node = remote_ipfs_node
|
||||
self.key = key
|
||||
self.explorer_domain = explorer_domain
|
||||
self.ipfs_domain = ipfs_domain
|
||||
|
||||
self.bot = DiscordBot(self)
|
||||
self.cleos = CLEOS(None, None, url=node_url, remote=node_url)
|
||||
self.hyperion = HyperionAPI(hyperion_url)
|
||||
self.ipfs_node = AsyncIPFSHTTP(ipfs_url)
|
||||
|
||||
self._exit_stack = ExitStack()
|
||||
self._async_exit_stack = AsyncExitStack()
|
||||
|
||||
async def start(self):
|
||||
if self.remote_ipfs_node:
|
||||
await self.ipfs_node.connect(self.remote_ipfs_node)
|
||||
|
||||
self.db_call = await self._async_exit_stack.enter_async_context(
|
||||
open_database_connection(
|
||||
self.db_user, self.db_pass, self.db_host))
|
||||
|
||||
create_handler_context(self)
|
||||
|
||||
async def stop(self):
|
||||
await self._async_exit_stack.aclose()
|
||||
self._exit_stack.close()
|
||||
|
||||
@acm
|
||||
async def open(self):
|
||||
await self.start()
|
||||
yield self
|
||||
await self.stop()
|
||||
|
||||
# maybe do this?
|
||||
# async def update_status_message(
|
||||
# self, status_msg, new_text: str, **kwargs
|
||||
# ):
|
||||
# await self.db_call(
|
||||
# 'update_user_request_by_sid', status_msg.id, new_text)
|
||||
# return await self.bot.edit_message_text(
|
||||
# new_text,
|
||||
# chat_id=status_msg.chat.id,
|
||||
# message_id=status_msg.id,
|
||||
# **kwargs
|
||||
# )
|
||||
|
||||
# async def append_status_message(
|
||||
# self, status_msg, add_text: str, **kwargs
|
||||
# ):
|
||||
# request = await self.db_call('get_user_request_by_sid', status_msg.id)
|
||||
# await self.update_status_message(
|
||||
# status_msg,
|
||||
# request['status'] + add_text,
|
||||
# **kwargs
|
||||
# )
|
||||
|
||||
async def work_request(
|
||||
self,
|
||||
user,
|
||||
status_msg,
|
||||
method: str,
|
||||
params: dict,
|
||||
ctx: discord.ext.commands.context.Context | discord.Message,
|
||||
file_id: str | None = None,
|
||||
binary_data: str = ''
|
||||
) -> bool:
|
||||
send = ctx.channel.send
|
||||
|
||||
if params['seed'] == None:
|
||||
params['seed'] = random.randint(0, 0xFFFFFFFF)
|
||||
|
||||
sanitized_params = {}
|
||||
for key, val in params.items():
|
||||
if isinstance(val, Decimal):
|
||||
val = str(val)
|
||||
|
||||
sanitized_params[key] = val
|
||||
|
||||
body = json.dumps({
|
||||
'method': 'diffuse',
|
||||
'params': sanitized_params
|
||||
})
|
||||
request_time = datetime.now().isoformat()
|
||||
|
||||
await status_msg.delete()
|
||||
msg_text = f'processing a \'{method}\' request by {user.name}\n[{timestamp_pretty()}] *broadcasting transaction to chain...* '
|
||||
embed = discord.Embed(
|
||||
title='live updates',
|
||||
description=msg_text,
|
||||
color=discord.Color.blue())
|
||||
|
||||
message = await send(embed=embed)
|
||||
|
||||
reward = '20.0000 GPU'
|
||||
res = await self.cleos.a_push_action(
|
||||
'gpu.scd',
|
||||
'enqueue',
|
||||
{
|
||||
'user': Name(self.account),
|
||||
'request_body': body,
|
||||
'binary_data': binary_data,
|
||||
'reward': asset_from_str(reward),
|
||||
'min_verification': 1
|
||||
},
|
||||
self.account, self.key, permission=self.permission
|
||||
)
|
||||
|
||||
if 'code' in res or 'statusCode' in res:
|
||||
logging.error(json.dumps(res, indent=4))
|
||||
await self.bot.channel.send(
|
||||
status_msg,
|
||||
'skynet has suffered an internal error trying to fill this request')
|
||||
return False
|
||||
|
||||
enqueue_tx_id = res['transaction_id']
|
||||
enqueue_tx_link = f'[**Your request on Skynet Explorer**](https://{self.explorer_domain}/v2/explore/transaction/{enqueue_tx_id})'
|
||||
|
||||
msg_text += f'**broadcasted!** \n{enqueue_tx_link}\n[{timestamp_pretty()}] *workers are processing request...* '
|
||||
embed = discord.Embed(
|
||||
title='live updates',
|
||||
description=msg_text,
|
||||
color=discord.Color.blue())
|
||||
|
||||
await message.edit(embed=embed)
|
||||
|
||||
out = collect_stdout(res)
|
||||
|
||||
request_id, nonce = out.split(':')
|
||||
|
||||
request_hash = sha256(
|
||||
(nonce + body + binary_data).encode('utf-8')).hexdigest().upper()
|
||||
|
||||
request_id = int(request_id)
|
||||
|
||||
logging.info(f'{request_id} enqueued.')
|
||||
|
||||
tx_hash = None
|
||||
ipfs_hash = None
|
||||
for i in range(60):
|
||||
try:
|
||||
submits = await self.hyperion.aget_actions(
|
||||
account=self.account,
|
||||
filter='gpu.scd:submit',
|
||||
sort='desc',
|
||||
after=request_time
|
||||
)
|
||||
actions = [
|
||||
action
|
||||
for action in submits['actions']
|
||||
if action[
|
||||
'act']['data']['request_hash'] == request_hash
|
||||
]
|
||||
if len(actions) > 0:
|
||||
tx_hash = actions[0]['trx_id']
|
||||
data = actions[0]['act']['data']
|
||||
ipfs_hash = data['ipfs_hash']
|
||||
worker = data['worker']
|
||||
logging.info('Found matching submit!')
|
||||
break
|
||||
|
||||
except JSONDecodeError:
|
||||
logging.error(f'network error while getting actions, retry..')
|
||||
|
||||
await asyncio.sleep(1)
|
||||
|
||||
if not ipfs_hash:
|
||||
|
||||
timeout_text = f'\n[{timestamp_pretty()}] **timeout processing request**'
|
||||
embed = discord.Embed(
|
||||
title='live updates',
|
||||
description=timeout_text,
|
||||
color=discord.Color.blue())
|
||||
|
||||
await message.edit(embed=embed)
|
||||
return False
|
||||
|
||||
tx_link = f'[**Your result on Skynet Explorer**](https://{self.explorer_domain}/v2/explore/transaction/{tx_hash})'
|
||||
|
||||
msg_text += f'**request processed!**\n{tx_link}\n[{timestamp_pretty()}] *trying to download image...*\n '
|
||||
embed = discord.Embed(
|
||||
title='live updates',
|
||||
description=msg_text,
|
||||
color=discord.Color.blue())
|
||||
|
||||
await message.edit(embed=embed)
|
||||
|
||||
# attempt to get the image and send it
|
||||
results = {}
|
||||
ipfs_link = f'https://{self.ipfs_domain}/ipfs/{ipfs_hash}'
|
||||
ipfs_link_legacy = ipfs_link + '/image.png'
|
||||
|
||||
async def get_and_set_results(link: str):
|
||||
res = await get_ipfs_file(link)
|
||||
logging.info(f'got response from {link}')
|
||||
if not res or res.status_code != 200:
|
||||
logging.warning(f'couldn\'t get ipfs binary data at {link}!')
|
||||
|
||||
else:
|
||||
try:
|
||||
with Image.open(io.BytesIO(res.raw)) as image:
|
||||
tmp_buf = io.BytesIO()
|
||||
image.save(tmp_buf, format='PNG')
|
||||
png_img = tmp_buf.getvalue()
|
||||
results[link] = png_img
|
||||
|
||||
except UnidentifiedImageError:
|
||||
logging.warning(
|
||||
f'couldn\'t get ipfs binary data at {link}!')
|
||||
|
||||
tasks = [
|
||||
get_and_set_results(ipfs_link),
|
||||
get_and_set_results(ipfs_link_legacy)
|
||||
]
|
||||
await asyncio.gather(*tasks)
|
||||
|
||||
png_img = None
|
||||
if ipfs_link_legacy in results:
|
||||
png_img = results[ipfs_link_legacy]
|
||||
|
||||
if ipfs_link in results:
|
||||
png_img = results[ipfs_link]
|
||||
|
||||
if not png_img:
|
||||
logging.error(f'couldn\'t get ipfs hosted image at {ipfs_link}!')
|
||||
embed.add_field(
|
||||
name='Error', value=f'couldn\'t get ipfs hosted image [**here**]({ipfs_link})!')
|
||||
await message.edit(embed=embed, view=SkynetView(self))
|
||||
return True
|
||||
|
||||
# reword this function, may not need caption
|
||||
caption, embed = generate_reply_caption(
|
||||
user, params, tx_hash, worker, reward, self.explorer_domain)
|
||||
|
||||
logging.info(f'success! sending generated image')
|
||||
await message.delete()
|
||||
if file_id: # img2img
|
||||
embed.set_image(url=ipfs_link)
|
||||
orig_url = f'https://{self.ipfs_domain}/ipfs/' + binary_data
|
||||
res = requests.get(orig_url, stream=True)
|
||||
if res.status_code == 200:
|
||||
with io.BytesIO(res.content) as img:
|
||||
file = discord.File(img, filename='image.png')
|
||||
embed.set_thumbnail(url='attachment://image.png')
|
||||
await send(embed=embed, view=SkynetView(self), file=file)
|
||||
# orig_url = f'https://{self.ipfs_domain}/ipfs/' \
|
||||
# + binary_data + '/image.png'
|
||||
# embed.set_thumbnail(
|
||||
# url=orig_url)
|
||||
else:
|
||||
await send(embed=embed, view=SkynetView(self))
|
||||
else: # txt2img
|
||||
embed.set_image(url=ipfs_link)
|
||||
await send(embed=embed, view=SkynetView(self))
|
||||
|
||||
return True
|
|
@ -1,89 +0,0 @@
|
|||
# import os
|
||||
import discord
|
||||
import asyncio
|
||||
# from dotenv import load_dotenv
|
||||
# from pathlib import Path
|
||||
from discord.ext import commands
|
||||
from .ui import SkynetView
|
||||
|
||||
|
||||
# # Auth
|
||||
# current_dir = Path(__file__).resolve().parent
|
||||
# # parent_dir = current_dir.parent
|
||||
# env_file_path = current_dir / ".env"
|
||||
# load_dotenv(dotenv_path=env_file_path)
|
||||
#
|
||||
# discordToken = os.getenv("DISCORD_TOKEN")
|
||||
|
||||
|
||||
# Actual Discord bot.
|
||||
class DiscordBot(commands.Bot):
|
||||
|
||||
def __init__(self, bot, *args, **kwargs):
|
||||
self.bot = bot
|
||||
intents = discord.Intents(
|
||||
messages=True,
|
||||
guilds=True,
|
||||
typing=True,
|
||||
members=True,
|
||||
presences=True,
|
||||
reactions=True,
|
||||
message_content=True,
|
||||
voice_states=True
|
||||
)
|
||||
super().__init__(command_prefix='/', intents=intents, *args, **kwargs)
|
||||
|
||||
# async def setup_hook(self):
|
||||
# db.poll_db.start()
|
||||
|
||||
async def on_ready(self):
|
||||
print(f'{self.user.name} has connected to Discord!')
|
||||
for guild in self.guilds:
|
||||
for channel in guild.channels:
|
||||
if channel.name == "skynet":
|
||||
await channel.send('Skynet bot online', view=SkynetView(self.bot))
|
||||
# intro_msg = await channel.send('Welcome to the Skynet discord bot.\nSkynet is a decentralized compute layer, focused on supporting AI paradigms. Skynet leverages blockchain technology to manage work requests and fills. We are currently featuring image generation and support 11 different models. Get started with the /help command, or just click on some buttons. Here is an example command to generate an image:\n/txt2img a big red tractor in a giant field of corn')
|
||||
intro_msg = await channel.send("Welcome to Skynet's Discord Bot,\n\nSkynet operates as a decentralized compute layer, offering a wide array of support for diverse AI paradigms through the use of blockchain technology. Our present focus is image generation, powered by 11 distinct models.\n\nTo begin exploring, use the '/help' command or directly interact with the provided buttons. Here is an example command to generate an image:\n\n'/txt2img a big red tractor in a giant field of corn'")
|
||||
# await intro_msg.pin()
|
||||
|
||||
print("\n==============")
|
||||
print("Logged in as")
|
||||
print(self.user.name)
|
||||
print(self.user.id)
|
||||
print("==============")
|
||||
|
||||
async def on_message(self, message):
|
||||
if isinstance(message.channel, discord.DMChannel):
|
||||
return
|
||||
elif message.channel.name != 'skynet':
|
||||
return
|
||||
elif message.author == self.user:
|
||||
return
|
||||
await self.process_commands(message)
|
||||
# await asyncio.sleep(3)
|
||||
# await message.channel.send('', view=SkynetView(self.bot))
|
||||
|
||||
async def on_command_error(self, ctx, error):
|
||||
if isinstance(error, commands.MissingRequiredArgument):
|
||||
await ctx.send('You missed a required argument, please try again.')
|
||||
|
||||
# async def on_message(self, message):
|
||||
# print(f"message from {message.author} what he said {message.content}")
|
||||
# await message.channel.send(message.content)
|
||||
|
||||
# bot=DiscordBot()
|
||||
# @bot.command(name='config', help='Responds with the configuration')
|
||||
# async def config(ctx):
|
||||
# response = "This is the bot configuration" # Put your bot configuration here
|
||||
# await ctx.send(response)
|
||||
#
|
||||
# @bot.command(name='helper', help='Responds with a help')
|
||||
# async def helper(ctx):
|
||||
# response = "This is help information" # Put your help response here
|
||||
# await ctx.send(response)
|
||||
#
|
||||
# @bot.command(name='txt2img', help='Responds with an image')
|
||||
# async def txt2img(ctx, *, arg):
|
||||
# response = f"This is your prompt: {arg}"
|
||||
# await ctx.send(response)
|
||||
# bot.run(discordToken)
|
|
@ -1,601 +0,0 @@
|
|||
import io
|
||||
import json
|
||||
import logging
|
||||
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
from PIL import Image
|
||||
# from telebot.types import CallbackQuery, Message
|
||||
|
||||
from skynet.frontend import validate_user_config_request
|
||||
from skynet.constants import *
|
||||
from .ui import SkynetView
|
||||
|
||||
|
||||
def create_handler_context(frontend: 'SkynetDiscordFrontend'):
|
||||
|
||||
bot = frontend.bot
|
||||
cleos = frontend.cleos
|
||||
db_call = frontend.db_call
|
||||
work_request = frontend.work_request
|
||||
|
||||
ipfs_node = frontend.ipfs_node
|
||||
|
||||
@bot.command(name='config', help='Responds with the configuration')
|
||||
async def set_config(ctx):
|
||||
|
||||
user = ctx.author
|
||||
try:
|
||||
attr, val, reply_txt = validate_user_config_request(
|
||||
ctx.message.content)
|
||||
|
||||
logging.info(f'user config update: {attr} to {val}')
|
||||
await db_call('update_user_config', user.id, attr, val)
|
||||
logging.info('done')
|
||||
|
||||
except BaseException as e:
|
||||
reply_txt = str(e)
|
||||
|
||||
finally:
|
||||
await ctx.reply(content=reply_txt, view=SkynetView(frontend))
|
||||
|
||||
bot.remove_command('help')
|
||||
|
||||
@bot.command(name='help', help='Responds with a help')
|
||||
async def help(ctx):
|
||||
splt_msg = ctx.message.content.split(' ')
|
||||
|
||||
if len(splt_msg) == 1:
|
||||
await ctx.send(content=f'```{HELP_TEXT}```', view=SkynetView(frontend))
|
||||
|
||||
else:
|
||||
param = splt_msg[1]
|
||||
if param in HELP_TOPICS:
|
||||
await ctx.send(content=f'```{HELP_TOPICS[param]}```', view=SkynetView(frontend))
|
||||
|
||||
else:
|
||||
await ctx.send(content=f'```{HELP_UNKWNOWN_PARAM}```', view=SkynetView(frontend))
|
||||
|
||||
@bot.command(name='cool', help='Display a list of cool prompt words')
|
||||
async def send_cool_words(ctx):
|
||||
clean_cool_word = '\n'.join(CLEAN_COOL_WORDS)
|
||||
await ctx.send(content=f'```{clean_cool_word}```', view=SkynetView(frontend))
|
||||
|
||||
@bot.command(name='stats', help='See user statistics')
|
||||
async def user_stats(ctx):
|
||||
user = ctx.author
|
||||
|
||||
await db_call('get_or_create_user', user.id)
|
||||
generated, joined, role = await db_call('get_user_stats', user.id)
|
||||
|
||||
stats_str = f'```generated: {generated}\n'
|
||||
stats_str += f'joined: {joined}\n'
|
||||
stats_str += f'role: {role}\n```'
|
||||
|
||||
await ctx.reply(stats_str, view=SkynetView(frontend))
|
||||
|
||||
@bot.command(name='donate', help='See donate info')
|
||||
async def donation_info(ctx):
|
||||
await ctx.reply(
|
||||
f'```\n{DONATION_INFO}```', view=SkynetView(frontend))
|
||||
|
||||
@bot.command(name='txt2img', help='Responds with an image')
|
||||
async def send_txt2img(ctx):
|
||||
|
||||
# grab user from ctx
|
||||
user = ctx.author
|
||||
user_row = await db_call('get_or_create_user', user.id)
|
||||
|
||||
# init new msg
|
||||
init_msg = 'started processing txt2img request...'
|
||||
status_msg = await ctx.send(init_msg)
|
||||
await db_call(
|
||||
'new_user_request', user.id, ctx.message.id, status_msg.id, status=init_msg)
|
||||
|
||||
prompt = ' '.join(ctx.message.content.split(' ')[1:])
|
||||
|
||||
if len(prompt) == 0:
|
||||
await status_msg.edit(content='Empty text prompt ignored.'
|
||||
)
|
||||
await db_call('update_user_request', status_msg.id, 'Empty text prompt ignored.')
|
||||
return
|
||||
|
||||
logging.info(f'mid: {ctx.message.id}')
|
||||
|
||||
user_config = {**user_row}
|
||||
del user_config['id']
|
||||
|
||||
params = {
|
||||
'prompt': prompt,
|
||||
**user_config
|
||||
}
|
||||
|
||||
await db_call(
|
||||
'update_user_stats', user.id, 'txt2img', last_prompt=prompt)
|
||||
|
||||
success = await work_request(user, status_msg, 'txt2img', params, ctx)
|
||||
|
||||
if success:
|
||||
await db_call('increment_generated', user.id)
|
||||
|
||||
@bot.command(name='redo', help='Redo last request')
|
||||
async def redo(ctx):
|
||||
init_msg = 'started processing redo request...'
|
||||
status_msg = await ctx.send(init_msg)
|
||||
user = ctx.author
|
||||
|
||||
method = await db_call('get_last_method_of', user.id)
|
||||
prompt = await db_call('get_last_prompt_of', user.id)
|
||||
|
||||
file_id = None
|
||||
binary = ''
|
||||
if method == 'img2img':
|
||||
file_id = await db_call('get_last_file_of', user.id)
|
||||
binary = await db_call('get_last_binary_of', user.id)
|
||||
|
||||
if not prompt:
|
||||
await status_msg.edit(
|
||||
content='no last prompt found, do a txt2img cmd first!',
|
||||
view=SkynetView(frontend)
|
||||
)
|
||||
return
|
||||
|
||||
user_row = await db_call('get_or_create_user', user.id)
|
||||
await db_call(
|
||||
'new_user_request', user.id, ctx.message.id, status_msg.id, status=init_msg)
|
||||
user_config = {**user_row}
|
||||
del user_config['id']
|
||||
|
||||
params = {
|
||||
'prompt': prompt,
|
||||
**user_config
|
||||
}
|
||||
|
||||
success = await work_request(
|
||||
user, status_msg, 'redo', params, ctx,
|
||||
file_id=file_id,
|
||||
binary_data=binary
|
||||
)
|
||||
|
||||
if success:
|
||||
await db_call('increment_generated', user.id)
|
||||
|
||||
@bot.command(name='img2img', help='Responds with an image')
|
||||
async def send_img2img(ctx):
|
||||
# if isinstance(message_or_query, CallbackQuery):
|
||||
# query = message_or_query
|
||||
# message = query.message
|
||||
# user = query.from_user
|
||||
# chat = query.message.chat
|
||||
#
|
||||
# else:
|
||||
# message = message_or_query
|
||||
# user = message.from_user
|
||||
# chat = message.chat
|
||||
|
||||
# reply_id = None
|
||||
# if chat.type == 'group' and chat.id == GROUP_ID:
|
||||
# reply_id = message.message_id
|
||||
#
|
||||
user = ctx.author
|
||||
user_row = await db_call('get_or_create_user', user.id)
|
||||
|
||||
# init new msg
|
||||
init_msg = 'started processing img2img request...'
|
||||
status_msg = await ctx.send(init_msg)
|
||||
await db_call(
|
||||
'new_user_request', user.id, ctx.message.id, status_msg.id, status=init_msg)
|
||||
|
||||
if not ctx.message.content.startswith('/img2img'):
|
||||
await ctx.reply(
|
||||
'For image to image you need to add /img2img to the beggining of your caption'
|
||||
)
|
||||
return
|
||||
|
||||
prompt = ' '.join(ctx.message.content.split(' ')[1:])
|
||||
|
||||
if len(prompt) == 0:
|
||||
await ctx.reply('Empty text prompt ignored.')
|
||||
return
|
||||
|
||||
# file_id = message.photo[-1].file_id
|
||||
# file_path = (await bot.get_file(file_id)).file_path
|
||||
# image_raw = await bot.download_file(file_path)
|
||||
#
|
||||
|
||||
file = ctx.message.attachments[-1]
|
||||
file_id = str(file.id)
|
||||
# file bytes
|
||||
image_raw = await file.read()
|
||||
|
||||
user_config = {**user_row}
|
||||
del user_config['id']
|
||||
with Image.open(io.BytesIO(image_raw)) as image:
|
||||
w, h = image.size
|
||||
|
||||
if w > user_config['width'] or h > user_config['height']:
|
||||
logging.warning(f'user sent img of size {image.size}')
|
||||
image.thumbnail(
|
||||
(user_config['width'], user_config['height']))
|
||||
logging.warning(f'resized it to {image.size}')
|
||||
|
||||
# if w > 512 or h > 512:
|
||||
# logging.warning(f'user sent img of size {image.size}')
|
||||
# image.thumbnail((512, 512))
|
||||
# logging.warning(f'resized it to {image.size}')
|
||||
# image.save(f'ipfs-docker-staging/image.png', format='PNG')
|
||||
image_loc = 'ipfs-staging/image.png'
|
||||
image.save(image_loc, format='PNG')
|
||||
|
||||
ipfs_info = await ipfs_node.add(image_loc)
|
||||
ipfs_hash = ipfs_info['Hash']
|
||||
await ipfs_node.pin(ipfs_hash)
|
||||
|
||||
logging.info(f'published input image {ipfs_hash} on ipfs')
|
||||
|
||||
logging.info(f'mid: {ctx.message.id}')
|
||||
|
||||
params = {
|
||||
'prompt': prompt,
|
||||
**user_config
|
||||
}
|
||||
|
||||
await db_call(
|
||||
'update_user_stats',
|
||||
user.id,
|
||||
'img2img',
|
||||
last_prompt=prompt,
|
||||
last_file=file_id,
|
||||
last_binary=ipfs_hash
|
||||
)
|
||||
|
||||
success = await work_request(
|
||||
user, status_msg, 'img2img', params, ctx,
|
||||
file_id=file_id,
|
||||
binary_data=ipfs_hash
|
||||
)
|
||||
|
||||
if success:
|
||||
await db_call('increment_generated', user.id)
|
||||
|
||||
# TODO: DELETE BELOW
|
||||
# user = 'testworker3'
|
||||
# status_msg = 'status'
|
||||
# params = {
|
||||
# 'prompt': arg,
|
||||
# 'seed': None,
|
||||
# 'step': 35,
|
||||
# 'guidance': 7.5,
|
||||
# 'strength': 0.5,
|
||||
# 'width': 512,
|
||||
# 'height': 512,
|
||||
# 'upscaler': None,
|
||||
# 'model': 'prompthero/openjourney',
|
||||
# }
|
||||
#
|
||||
# ec = await work_request(user, status_msg, 'txt2img', params, ctx)
|
||||
# print(ec)
|
||||
|
||||
# if ec == 0:
|
||||
# await db_call('increment_generated', user.id)
|
||||
|
||||
# response = f"This is your prompt: {arg}"
|
||||
# await ctx.send(response)
|
||||
|
||||
# generic / simple handlers
|
||||
|
||||
# @bot.message_handler(commands=['help'])
|
||||
# async def send_help(message):
|
||||
# splt_msg = message.text.split(' ')
|
||||
#
|
||||
# if len(splt_msg) == 1:
|
||||
# await bot.reply_to(message, HELP_TEXT)
|
||||
#
|
||||
# else:
|
||||
# param = splt_msg[1]
|
||||
# if param in HELP_TOPICS:
|
||||
# await bot.reply_to(message, HELP_TOPICS[param])
|
||||
#
|
||||
# else:
|
||||
# await bot.reply_to(message, HELP_UNKWNOWN_PARAM)
|
||||
#
|
||||
# @bot.message_handler(commands=['cool'])
|
||||
# async def send_cool_words(message):
|
||||
# await bot.reply_to(message, '\n'.join(COOL_WORDS))
|
||||
#
|
||||
# @bot.message_handler(commands=['queue'])
|
||||
# async def queue(message):
|
||||
# an_hour_ago = datetime.now() - timedelta(hours=1)
|
||||
# queue = await cleos.aget_table(
|
||||
# 'gpu.scd', 'gpu.scd', 'queue',
|
||||
# index_position=2,
|
||||
# key_type='i64',
|
||||
# sort='desc',
|
||||
# lower_bound=int(an_hour_ago.timestamp())
|
||||
# )
|
||||
# await bot.reply_to(
|
||||
# message, f'Total requests on skynet queue: {len(queue)}')
|
||||
|
||||
# @bot.message_handler(commands=['config'])
|
||||
# async def set_config(message):
|
||||
# user = message.from_user.id
|
||||
# try:
|
||||
# attr, val, reply_txt = validate_user_config_request(
|
||||
# message.text)
|
||||
#
|
||||
# logging.info(f'user config update: {attr} to {val}')
|
||||
# await db_call('update_user_config', user, attr, val)
|
||||
# logging.info('done')
|
||||
#
|
||||
# except BaseException as e:
|
||||
# reply_txt = str(e)
|
||||
#
|
||||
# finally:
|
||||
# await bot.reply_to(message, reply_txt)
|
||||
#
|
||||
# @bot.message_handler(commands=['stats'])
|
||||
# async def user_stats(message):
|
||||
# user = message.from_user.id
|
||||
#
|
||||
# await db_call('get_or_create_user', user)
|
||||
# generated, joined, role = await db_call('get_user_stats', user)
|
||||
#
|
||||
# stats_str = f'generated: {generated}\n'
|
||||
# stats_str += f'joined: {joined}\n'
|
||||
# stats_str += f'role: {role}\n'
|
||||
#
|
||||
# await bot.reply_to(
|
||||
# message, stats_str)
|
||||
#
|
||||
# @bot.message_handler(commands=['donate'])
|
||||
# async def donation_info(message):
|
||||
# await bot.reply_to(
|
||||
# message, DONATION_INFO)
|
||||
#
|
||||
# @bot.message_handler(commands=['say'])
|
||||
# async def say(message):
|
||||
# chat = message.chat
|
||||
# user = message.from_user
|
||||
#
|
||||
# if (chat.type == 'group') or (user.id != 383385940):
|
||||
# return
|
||||
#
|
||||
# await bot.send_message(GROUP_ID, message.text[4:])
|
||||
|
||||
# generic txt2img handler
|
||||
|
||||
# async def _generic_txt2img(message_or_query):
|
||||
# if isinstance(message_or_query, CallbackQuery):
|
||||
# query = message_or_query
|
||||
# message = query.message
|
||||
# user = query.from_user
|
||||
# chat = query.message.chat
|
||||
#
|
||||
# else:
|
||||
# message = message_or_query
|
||||
# user = message.from_user
|
||||
# chat = message.chat
|
||||
#
|
||||
# reply_id = None
|
||||
# if chat.type == 'group' and chat.id == GROUP_ID:
|
||||
# reply_id = message.message_id
|
||||
#
|
||||
# user_row = await db_call('get_or_create_user', user.id)
|
||||
#
|
||||
# # init new msg
|
||||
# init_msg = 'started processing txt2img request...'
|
||||
# status_msg = await bot.reply_to(message, init_msg)
|
||||
# await db_call(
|
||||
# 'new_user_request', user.id, message.id, status_msg.id, status=init_msg)
|
||||
#
|
||||
# prompt = ' '.join(message.text.split(' ')[1:])
|
||||
#
|
||||
# if len(prompt) == 0:
|
||||
# await bot.edit_message_text(
|
||||
# 'Empty text prompt ignored.',
|
||||
# chat_id=status_msg.chat.id,
|
||||
# message_id=status_msg.id
|
||||
# )
|
||||
# await db_call('update_user_request', status_msg.id, 'Empty text prompt ignored.')
|
||||
# return
|
||||
#
|
||||
# logging.info(f'mid: {message.id}')
|
||||
#
|
||||
# user_config = {**user_row}
|
||||
# del user_config['id']
|
||||
#
|
||||
# params = {
|
||||
# 'prompt': prompt,
|
||||
# **user_config
|
||||
# }
|
||||
#
|
||||
# await db_call(
|
||||
# 'update_user_stats', user.id, 'txt2img', last_prompt=prompt)
|
||||
#
|
||||
# ec = await work_request(user, status_msg, 'txt2img', params)
|
||||
|
||||
# if ec == 0:
|
||||
# await db_call('increment_generated', user.id)
|
||||
#
|
||||
#
|
||||
# # generic img2img handler
|
||||
#
|
||||
# async def _generic_img2img(message_or_query):
|
||||
# if isinstance(message_or_query, CallbackQuery):
|
||||
# query = message_or_query
|
||||
# message = query.message
|
||||
# user = query.from_user
|
||||
# chat = query.message.chat
|
||||
#
|
||||
# else:
|
||||
# message = message_or_query
|
||||
# user = message.from_user
|
||||
# chat = message.chat
|
||||
#
|
||||
# reply_id = None
|
||||
# if chat.type == 'group' and chat.id == GROUP_ID:
|
||||
# reply_id = message.message_id
|
||||
#
|
||||
# user_row = await db_call('get_or_create_user', user.id)
|
||||
#
|
||||
# # init new msg
|
||||
# init_msg = 'started processing txt2img request...'
|
||||
# status_msg = await bot.reply_to(message, init_msg)
|
||||
# await db_call(
|
||||
# 'new_user_request', user.id, message.id, status_msg.id, status=init_msg)
|
||||
#
|
||||
# if not message.caption.startswith('/img2img'):
|
||||
# await bot.reply_to(
|
||||
# message,
|
||||
# 'For image to image you need to add /img2img to the beggining of your caption'
|
||||
# )
|
||||
# return
|
||||
#
|
||||
# prompt = ' '.join(message.caption.split(' ')[1:])
|
||||
#
|
||||
# if len(prompt) == 0:
|
||||
# await bot.reply_to(message, 'Empty text prompt ignored.')
|
||||
# return
|
||||
#
|
||||
# file_id = message.photo[-1].file_id
|
||||
# file_path = (await bot.get_file(file_id)).file_path
|
||||
# image_raw = await bot.download_file(file_path)
|
||||
|
||||
# with Image.open(io.BytesIO(image_raw)) as image:
|
||||
# w, h = image.size
|
||||
#
|
||||
# if w > 512 or h > 512:
|
||||
# logging.warning(f'user sent img of size {image.size}')
|
||||
# image.thumbnail((512, 512))
|
||||
# logging.warning(f'resized it to {image.size}')
|
||||
#
|
||||
# image.save(f'ipfs-docker-staging/image.png', format='PNG')
|
||||
#
|
||||
# ipfs_hash = ipfs_node.add('image.png')
|
||||
# ipfs_node.pin(ipfs_hash)
|
||||
#
|
||||
# logging.info(f'published input image {ipfs_hash} on ipfs')
|
||||
#
|
||||
# logging.info(f'mid: {message.id}')
|
||||
#
|
||||
# user_config = {**user_row}
|
||||
# del user_config['id']
|
||||
#
|
||||
# params = {
|
||||
# 'prompt': prompt,
|
||||
# **user_config
|
||||
# }
|
||||
#
|
||||
# await db_call(
|
||||
# 'update_user_stats',
|
||||
# user.id,
|
||||
# 'img2img',
|
||||
# last_file=file_id,
|
||||
# last_prompt=prompt,
|
||||
# last_binary=ipfs_hash
|
||||
# )
|
||||
#
|
||||
# ec = await work_request(
|
||||
# user, status_msg, 'img2img', params,
|
||||
# file_id=file_id,
|
||||
# binary_data=ipfs_hash
|
||||
# )
|
||||
#
|
||||
# if ec == 0:
|
||||
# await db_call('increment_generated', user.id)
|
||||
#
|
||||
|
||||
# generic redo handler
|
||||
|
||||
# async def _redo(message_or_query):
|
||||
# is_query = False
|
||||
# if isinstance(message_or_query, CallbackQuery):
|
||||
# is_query = True
|
||||
# query = message_or_query
|
||||
# message = query.message
|
||||
# user = query.from_user
|
||||
# chat = query.message.chat
|
||||
#
|
||||
# elif isinstance(message_or_query, Message):
|
||||
# message = message_or_query
|
||||
# user = message.from_user
|
||||
# chat = message.chat
|
||||
#
|
||||
# init_msg = 'started processing redo request...'
|
||||
# if is_query:
|
||||
# status_msg = await bot.send_message(chat.id, init_msg)
|
||||
#
|
||||
# else:
|
||||
# status_msg = await bot.reply_to(message, init_msg)
|
||||
#
|
||||
# method = await db_call('get_last_method_of', user.id)
|
||||
# prompt = await db_call('get_last_prompt_of', user.id)
|
||||
#
|
||||
# file_id = None
|
||||
# binary = ''
|
||||
# if method == 'img2img':
|
||||
# file_id = await db_call('get_last_file_of', user.id)
|
||||
# binary = await db_call('get_last_binary_of', user.id)
|
||||
#
|
||||
# if not prompt:
|
||||
# await bot.reply_to(
|
||||
# message,
|
||||
# 'no last prompt found, do a txt2img cmd first!'
|
||||
# )
|
||||
# return
|
||||
#
|
||||
#
|
||||
# user_row = await db_call('get_or_create_user', user.id)
|
||||
# await db_call(
|
||||
# 'new_user_request', user.id, message.id, status_msg.id, status=init_msg)
|
||||
# user_config = {**user_row}
|
||||
# del user_config['id']
|
||||
#
|
||||
# params = {
|
||||
# 'prompt': prompt,
|
||||
# **user_config
|
||||
# }
|
||||
#
|
||||
# await work_request(
|
||||
# user, status_msg, 'redo', params,
|
||||
# file_id=file_id,
|
||||
# binary_data=binary
|
||||
# )
|
||||
|
||||
# "proxy" handlers just request routers
|
||||
|
||||
# @bot.message_handler(commands=['txt2img'])
|
||||
# async def send_txt2img(message):
|
||||
# await _generic_txt2img(message)
|
||||
#
|
||||
# @bot.message_handler(func=lambda message: True, content_types=[
|
||||
# 'photo', 'document'])
|
||||
# async def send_img2img(message):
|
||||
# await _generic_img2img(message)
|
||||
#
|
||||
# @bot.message_handler(commands=['img2img'])
|
||||
# async def img2img_missing_image(message):
|
||||
# await bot.reply_to(
|
||||
# message,
|
||||
# 'seems you tried to do an img2img command without sending image'
|
||||
# )
|
||||
#
|
||||
# @bot.message_handler(commands=['redo'])
|
||||
# async def redo(message):
|
||||
# await _redo(message)
|
||||
#
|
||||
# @bot.callback_query_handler(func=lambda call: True)
|
||||
# async def callback_query(call):
|
||||
# msg = json.loads(call.data)
|
||||
# logging.info(call.data)
|
||||
# method = msg.get('method')
|
||||
# match method:
|
||||
# case 'redo':
|
||||
# await _redo(call)
|
||||
|
||||
# catch all handler for things we dont support
|
||||
|
||||
# @bot.message_handler(func=lambda message: True)
|
||||
# async def echo_message(message):
|
||||
# if message.text[0] == '/':
|
||||
# await bot.reply_to(message, UNKNOWN_CMD_TEXT)
|
|
@ -1,325 +0,0 @@
|
|||
import io
|
||||
import discord
|
||||
from PIL import Image
|
||||
import logging
|
||||
from skynet.constants import *
|
||||
from skynet.frontend import validate_user_config_request
|
||||
|
||||
|
||||
class SkynetView(discord.ui.View):
|
||||
|
||||
def __init__(self, bot):
|
||||
self.bot = bot
|
||||
super().__init__(timeout=None)
|
||||
self.add_item(RedoButton(
|
||||
'redo', discord.ButtonStyle.primary, self.bot))
|
||||
self.add_item(Txt2ImgButton(
|
||||
'txt2img', discord.ButtonStyle.primary, self.bot))
|
||||
self.add_item(Img2ImgButton(
|
||||
'img2img', discord.ButtonStyle.primary, self.bot))
|
||||
self.add_item(StatsButton(
|
||||
'stats', discord.ButtonStyle.secondary, self.bot))
|
||||
self.add_item(DonateButton(
|
||||
'donate', discord.ButtonStyle.secondary, self.bot))
|
||||
self.add_item(ConfigButton(
|
||||
'config', discord.ButtonStyle.secondary, self.bot))
|
||||
self.add_item(HelpButton(
|
||||
'help', discord.ButtonStyle.secondary, self.bot))
|
||||
self.add_item(CoolButton(
|
||||
'cool', discord.ButtonStyle.secondary, self.bot))
|
||||
|
||||
|
||||
class Txt2ImgButton(discord.ui.Button):
|
||||
|
||||
def __init__(self, label: str, style: discord.ButtonStyle, bot):
|
||||
self.bot = bot
|
||||
super().__init__(label=label, style=style)
|
||||
|
||||
async def callback(self, interaction):
|
||||
db_call = self.bot.db_call
|
||||
work_request = self.bot.work_request
|
||||
msg = await grab('Enter your prompt:', interaction)
|
||||
# grab user from msg
|
||||
user = msg.author
|
||||
user_row = await db_call('get_or_create_user', user.id)
|
||||
|
||||
# init new msg
|
||||
init_msg = 'started processing txt2img request...'
|
||||
status_msg = await msg.channel.send(init_msg)
|
||||
await db_call(
|
||||
'new_user_request', user.id, msg.id, status_msg.id, status=init_msg)
|
||||
|
||||
prompt = msg.content
|
||||
|
||||
if len(prompt) == 0:
|
||||
await status_msg.edit(content='Empty text prompt ignored.'
|
||||
)
|
||||
await db_call('update_user_request', status_msg.id, 'Empty text prompt ignored.')
|
||||
return
|
||||
|
||||
logging.info(f'mid: {msg.id}')
|
||||
|
||||
user_config = {**user_row}
|
||||
del user_config['id']
|
||||
|
||||
params = {
|
||||
'prompt': prompt,
|
||||
**user_config
|
||||
}
|
||||
|
||||
await db_call(
|
||||
'update_user_stats', user.id, 'txt2img', last_prompt=prompt)
|
||||
|
||||
success = await work_request(user, status_msg, 'txt2img', params, msg)
|
||||
|
||||
if success:
|
||||
await db_call('increment_generated', user.id)
|
||||
|
||||
|
||||
class Img2ImgButton(discord.ui.Button):
|
||||
|
||||
def __init__(self, label: str, style: discord.ButtonStyle, bot):
|
||||
self.bot = bot
|
||||
super().__init__(label=label, style=style)
|
||||
|
||||
async def callback(self, interaction):
|
||||
db_call = self.bot.db_call
|
||||
work_request = self.bot.work_request
|
||||
ipfs_node = self.bot.ipfs_node
|
||||
msg = await grab('Attach an Image. Enter your prompt:', interaction)
|
||||
|
||||
user = msg.author
|
||||
user_row = await db_call('get_or_create_user', user.id)
|
||||
|
||||
# init new msg
|
||||
init_msg = 'started processing img2img request...'
|
||||
status_msg = await msg.channel.send(init_msg)
|
||||
await db_call(
|
||||
'new_user_request', user.id, msg.id, status_msg.id, status=init_msg)
|
||||
|
||||
# if not msg.content.startswith('/img2img'):
|
||||
# await msg.reply(
|
||||
# 'For image to image you need to add /img2img to the beggining of your caption'
|
||||
# )
|
||||
# return
|
||||
|
||||
prompt = msg.content
|
||||
|
||||
if len(prompt) == 0:
|
||||
await msg.reply('Empty text prompt ignored.')
|
||||
return
|
||||
|
||||
# file_id = message.photo[-1].file_id
|
||||
# file_path = (await bot.get_file(file_id)).file_path
|
||||
# image_raw = await bot.download_file(file_path)
|
||||
#
|
||||
|
||||
file = msg.attachments[-1]
|
||||
file_id = str(file.id)
|
||||
# file bytes
|
||||
image_raw = await file.read()
|
||||
|
||||
user_config = {**user_row}
|
||||
del user_config['id']
|
||||
|
||||
with Image.open(io.BytesIO(image_raw)) as image:
|
||||
w, h = image.size
|
||||
|
||||
if w > user_config['width'] or h > user_config['height']:
|
||||
logging.warning(f'user sent img of size {image.size}')
|
||||
image.thumbnail(
|
||||
(user_config['width'], user_config['height']))
|
||||
logging.warning(f'resized it to {image.size}')
|
||||
|
||||
# if w > 512 or h > 512:
|
||||
# logging.warning(f'user sent img of size {image.size}')
|
||||
# image.thumbnail((512, 512))
|
||||
# logging.warning(f'resized it to {image.size}')
|
||||
# image.save(f'ipfs-docker-staging/image.png', format='PNG')
|
||||
image_loc = 'ipfs-staging/image.png'
|
||||
image.save(image_loc, format='PNG')
|
||||
|
||||
ipfs_info = await ipfs_node.add(image_loc)
|
||||
ipfs_hash = ipfs_info['Hash']
|
||||
await ipfs_node.pin(ipfs_hash)
|
||||
|
||||
logging.info(f'published input image {ipfs_hash} on ipfs')
|
||||
|
||||
logging.info(f'mid: {msg.id}')
|
||||
|
||||
params = {
|
||||
'prompt': prompt,
|
||||
**user_config
|
||||
}
|
||||
|
||||
await db_call(
|
||||
'update_user_stats',
|
||||
user.id,
|
||||
'img2img',
|
||||
last_prompt=prompt,
|
||||
last_file=file_id,
|
||||
last_binary=ipfs_hash
|
||||
)
|
||||
|
||||
success = await work_request(
|
||||
user, status_msg, 'img2img', params, msg,
|
||||
file_id=file_id,
|
||||
binary_data=ipfs_hash
|
||||
)
|
||||
|
||||
if success:
|
||||
await db_call('increment_generated', user.id)
|
||||
|
||||
|
||||
class RedoButton(discord.ui.Button):
|
||||
|
||||
def __init__(self, label: str, style: discord.ButtonStyle, bot):
|
||||
self.bot = bot
|
||||
super().__init__(label=label, style=style)
|
||||
|
||||
async def callback(self, interaction):
|
||||
db_call = self.bot.db_call
|
||||
work_request = self.bot.work_request
|
||||
init_msg = 'started processing redo request...'
|
||||
await interaction.response.send_message(init_msg)
|
||||
status_msg = await interaction.original_response()
|
||||
user = interaction.user
|
||||
|
||||
method = await db_call('get_last_method_of', user.id)
|
||||
prompt = await db_call('get_last_prompt_of', user.id)
|
||||
|
||||
file_id = None
|
||||
binary = ''
|
||||
if method == 'img2img':
|
||||
file_id = await db_call('get_last_file_of', user.id)
|
||||
binary = await db_call('get_last_binary_of', user.id)
|
||||
|
||||
if not prompt:
|
||||
await status_msg.edit(
|
||||
content='no last prompt found, do a txt2img cmd first!',
|
||||
view=SkynetView(self.bot)
|
||||
)
|
||||
return
|
||||
|
||||
user_row = await db_call('get_or_create_user', user.id)
|
||||
await db_call(
|
||||
'new_user_request', user.id, interaction.id, status_msg.id, status=init_msg)
|
||||
user_config = {**user_row}
|
||||
del user_config['id']
|
||||
|
||||
params = {
|
||||
'prompt': prompt,
|
||||
**user_config
|
||||
}
|
||||
success = await work_request(
|
||||
user, status_msg, 'redo', params, interaction,
|
||||
file_id=file_id,
|
||||
binary_data=binary
|
||||
)
|
||||
|
||||
if success:
|
||||
await db_call('increment_generated', user.id)
|
||||
|
||||
|
||||
class ConfigButton(discord.ui.Button):
|
||||
|
||||
def __init__(self, label: str, style: discord.ButtonStyle, bot):
|
||||
self.bot = bot
|
||||
super().__init__(label=label, style=style)
|
||||
|
||||
async def callback(self, interaction):
|
||||
db_call = self.bot.db_call
|
||||
msg = await grab('What params do you want to change? (format: <param> <value>)', interaction)
|
||||
|
||||
user = interaction.user
|
||||
try:
|
||||
attr, val, reply_txt = validate_user_config_request(
|
||||
'/config ' + msg.content)
|
||||
|
||||
logging.info(f'user config update: {attr} to {val}')
|
||||
await db_call('update_user_config', user.id, attr, val)
|
||||
logging.info('done')
|
||||
|
||||
except BaseException as e:
|
||||
reply_txt = str(e)
|
||||
|
||||
finally:
|
||||
await msg.reply(content=reply_txt, view=SkynetView(self.bot))
|
||||
|
||||
|
||||
class StatsButton(discord.ui.Button):
|
||||
|
||||
def __init__(self, label: str, style: discord.ButtonStyle, bot):
|
||||
self.bot = bot
|
||||
super().__init__(label=label, style=style)
|
||||
|
||||
async def callback(self, interaction):
|
||||
db_call = self.bot.db_call
|
||||
|
||||
user = interaction.user
|
||||
|
||||
await db_call('get_or_create_user', user.id)
|
||||
generated, joined, role = await db_call('get_user_stats', user.id)
|
||||
|
||||
stats_str = f'```generated: {generated}\n'
|
||||
stats_str += f'joined: {joined}\n'
|
||||
stats_str += f'role: {role}\n```'
|
||||
|
||||
await interaction.response.send_message(
|
||||
content=stats_str, view=SkynetView(self.bot))
|
||||
|
||||
|
||||
class DonateButton(discord.ui.Button):
|
||||
|
||||
def __init__(self, label: str, style: discord.ButtonStyle, bot):
|
||||
self.bot = bot
|
||||
super().__init__(label=label, style=style)
|
||||
|
||||
async def callback(self, interaction):
|
||||
await interaction.response.send_message(
|
||||
content=f'```\n{DONATION_INFO}```',
|
||||
view=SkynetView(self.bot))
|
||||
|
||||
|
||||
class CoolButton(discord.ui.Button):
|
||||
|
||||
def __init__(self, label: str, style: discord.ButtonStyle, bot):
|
||||
self.bot = bot
|
||||
super().__init__(label=label, style=style)
|
||||
|
||||
async def callback(self, interaction):
|
||||
clean_cool_word = '\n'.join(CLEAN_COOL_WORDS)
|
||||
await interaction.response.send_message(
|
||||
content=f'```{clean_cool_word}```',
|
||||
view=SkynetView(self.bot))
|
||||
|
||||
|
||||
class HelpButton(discord.ui.Button):
|
||||
|
||||
def __init__(self, label: str, style: discord.ButtonStyle, bot):
|
||||
self.bot = bot
|
||||
super().__init__(label=label, style=style)
|
||||
|
||||
async def callback(self, interaction):
|
||||
msg = await grab('What would you like help with? (a for all)', interaction)
|
||||
|
||||
param = msg.content
|
||||
|
||||
if param == 'a':
|
||||
await msg.reply(content=f'```{HELP_TEXT}```', view=SkynetView(self.bot))
|
||||
|
||||
else:
|
||||
if param in HELP_TOPICS:
|
||||
await msg.reply(content=f'```{HELP_TOPICS[param]}```', view=SkynetView(self.bot))
|
||||
|
||||
else:
|
||||
await msg.reply(content=f'```{HELP_UNKWNOWN_PARAM}```', view=SkynetView(self.bot))
|
||||
|
||||
|
||||
async def grab(prompt, interaction):
|
||||
def vet(m):
|
||||
return m.author == interaction.user and m.channel == interaction.channel
|
||||
|
||||
await interaction.response.send_message(prompt, ephemeral=True)
|
||||
message = await interaction.client.wait_for('message', check=vet)
|
||||
return message
|
|
@ -1,123 +0,0 @@
|
|||
import json
|
||||
import logging
|
||||
import traceback
|
||||
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from telebot.types import InlineKeyboardButton, InlineKeyboardMarkup
|
||||
from telebot.async_telebot import ExceptionHandler
|
||||
from telebot.formatting import hlink
|
||||
import discord
|
||||
|
||||
from skynet.constants import *
|
||||
|
||||
|
||||
def timestamp_pretty():
|
||||
return datetime.now(timezone.utc).strftime('%H:%M:%S')
|
||||
|
||||
|
||||
def tg_user_pretty(tguser):
|
||||
if tguser.username:
|
||||
return f'@{tguser.username}'
|
||||
else:
|
||||
return f'{tguser.first_name} id: {tguser.id}'
|
||||
|
||||
|
||||
class SKYExceptionHandler(ExceptionHandler):
|
||||
|
||||
def handle(exception):
|
||||
traceback.print_exc()
|
||||
|
||||
|
||||
def build_redo_menu():
|
||||
btn_redo = InlineKeyboardButton(
|
||||
"Redo", callback_data=json.dumps({'method': 'redo'}))
|
||||
inline_keyboard = InlineKeyboardMarkup()
|
||||
inline_keyboard.add(btn_redo)
|
||||
return inline_keyboard
|
||||
|
||||
|
||||
def prepare_metainfo_caption(user, worker: str, reward: str, meta: dict, embed) -> str:
|
||||
prompt = meta["prompt"]
|
||||
if len(prompt) > 256:
|
||||
prompt = prompt[:256]
|
||||
|
||||
gen_str = f'generated by {user.name}\n'
|
||||
gen_str += f'performed by {worker}\n'
|
||||
gen_str += f'reward: {reward}\n'
|
||||
|
||||
embed.add_field(
|
||||
name='General Info', value=f'```{gen_str}```', inline=False)
|
||||
# meta_str = f'__by {user.name}__\n'
|
||||
# meta_str += f'*performed by {worker}*\n'
|
||||
# meta_str += f'__**reward: {reward}**__\n'
|
||||
embed.add_field(name='Prompt', value=f'```{prompt}\n```', inline=False)
|
||||
|
||||
# meta_str = f'`prompt:` {prompt}\n'
|
||||
|
||||
meta_str = f'seed: {meta["seed"]}\n'
|
||||
meta_str += f'step: {meta["step"]}\n'
|
||||
meta_str += f'guidance: {meta["guidance"]}\n'
|
||||
|
||||
if meta['strength']:
|
||||
meta_str += f'strength: {meta["strength"]}\n'
|
||||
meta_str += f'algo: {meta["model"]}\n'
|
||||
if meta['upscaler']:
|
||||
meta_str += f'upscaler: {meta["upscaler"]}\n'
|
||||
|
||||
embed.add_field(name='Parameters', value=f'```{meta_str}```', inline=False)
|
||||
|
||||
foot_str = f'Made with Skynet v{VERSION}\n'
|
||||
foot_str += f'JOIN THE SWARM: https://discord.gg/PAabjJtZAF'
|
||||
|
||||
embed.set_footer(text=foot_str)
|
||||
|
||||
return meta_str
|
||||
|
||||
|
||||
def generate_reply_caption(
|
||||
user, # discord user
|
||||
params: dict,
|
||||
tx_hash: str,
|
||||
worker: str,
|
||||
reward: str,
|
||||
explorer_domain: str
|
||||
):
|
||||
explorer_link = discord.Embed(
|
||||
title='[SKYNET Transaction Explorer]',
|
||||
url=f'https://{explorer_domain}/v2/explore/transaction/{tx_hash}',
|
||||
color=discord.Color.blue())
|
||||
|
||||
meta_info = prepare_metainfo_caption(
|
||||
user, worker, reward, params, explorer_link)
|
||||
|
||||
# why do we have this?
|
||||
final_msg = '\n'.join([
|
||||
'Worker finished your task!',
|
||||
# explorer_link,
|
||||
f'PARAMETER INFO:\n{meta_info}'
|
||||
])
|
||||
|
||||
# final_msg += '\n'.join([
|
||||
# # f'***{explorer_link}***',
|
||||
# f'{meta_info}'
|
||||
# ])
|
||||
|
||||
logging.info(final_msg)
|
||||
|
||||
return final_msg, explorer_link
|
||||
|
||||
|
||||
async def get_global_config(cleos):
|
||||
return (await cleos.aget_table(
|
||||
'gpu.scd', 'gpu.scd', 'config'))[0]
|
||||
|
||||
|
||||
async def get_user_nonce(cleos, user: str):
|
||||
return (await cleos.aget_table(
|
||||
'gpu.scd', 'gpu.scd', 'users',
|
||||
index_position=1,
|
||||
key_type='name',
|
||||
lower_bound=user,
|
||||
upper_bound=user
|
||||
))[0]['nonce']
|
|
@ -1,295 +0,0 @@
|
|||
import io
|
||||
import random
|
||||
import logging
|
||||
import asyncio
|
||||
|
||||
from PIL import Image, UnidentifiedImageError
|
||||
from json import JSONDecodeError
|
||||
from decimal import Decimal
|
||||
from hashlib import sha256
|
||||
from datetime import datetime
|
||||
from contextlib import AsyncExitStack
|
||||
from contextlib import asynccontextmanager as acm
|
||||
|
||||
from leap.cleos import CLEOS
|
||||
from leap.protocol import Name, Asset
|
||||
from leap.hyperion import HyperionAPI
|
||||
|
||||
from telebot.types import InputMediaPhoto
|
||||
from telebot.async_telebot import AsyncTeleBot
|
||||
|
||||
from skynet.db import open_database_connection
|
||||
from skynet.ipfs import get_ipfs_file, AsyncIPFSHTTP
|
||||
from skynet.constants import *
|
||||
|
||||
from . import *
|
||||
|
||||
from .utils import *
|
||||
from .handlers import create_handler_context
|
||||
|
||||
|
||||
class SkynetTelegramFrontend:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
token: str,
|
||||
account: str,
|
||||
permission: str,
|
||||
node_url: str,
|
||||
hyperion_url: str,
|
||||
db_host: str,
|
||||
db_user: str,
|
||||
db_pass: str,
|
||||
ipfs_node: str,
|
||||
key: str,
|
||||
explorer_domain: str,
|
||||
ipfs_domain: str
|
||||
):
|
||||
self.token = token
|
||||
self.account = account
|
||||
self.permission = permission
|
||||
self.node_url = node_url
|
||||
self.hyperion_url = hyperion_url
|
||||
self.db_host = db_host
|
||||
self.db_user = db_user
|
||||
self.db_pass = db_pass
|
||||
self.key = key
|
||||
self.explorer_domain = explorer_domain
|
||||
self.ipfs_domain = ipfs_domain
|
||||
|
||||
self.bot = AsyncTeleBot(token, exception_handler=SKYExceptionHandler)
|
||||
self.cleos = CLEOS(endpoint=node_url)
|
||||
self.cleos.load_abi('gpu.scd', GPU_CONTRACT_ABI)
|
||||
self.hyperion = HyperionAPI(hyperion_url)
|
||||
self.ipfs_node = AsyncIPFSHTTP(ipfs_node)
|
||||
|
||||
self._async_exit_stack = AsyncExitStack()
|
||||
|
||||
async def start(self):
|
||||
self.db_call = await self._async_exit_stack.enter_async_context(
|
||||
open_database_connection(
|
||||
self.db_user, self.db_pass, self.db_host))
|
||||
|
||||
create_handler_context(self)
|
||||
|
||||
async def stop(self):
|
||||
await self._async_exit_stack.aclose()
|
||||
|
||||
@acm
|
||||
async def open(self):
|
||||
await self.start()
|
||||
yield self
|
||||
await self.stop()
|
||||
|
||||
async def update_status_message(
|
||||
self, status_msg, new_text: str, **kwargs
|
||||
):
|
||||
await self.db_call(
|
||||
'update_user_request_by_sid', status_msg.id, new_text)
|
||||
return await self.bot.edit_message_text(
|
||||
new_text,
|
||||
chat_id=status_msg.chat.id,
|
||||
message_id=status_msg.id,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
async def append_status_message(
|
||||
self, status_msg, add_text: str, **kwargs
|
||||
):
|
||||
request = await self.db_call('get_user_request_by_sid', status_msg.id)
|
||||
await self.update_status_message(
|
||||
status_msg,
|
||||
request['status'] + add_text,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
async def work_request(
|
||||
self,
|
||||
user,
|
||||
status_msg,
|
||||
method: str,
|
||||
params: dict,
|
||||
file_id: str | None = None,
|
||||
inputs: list[str] = []
|
||||
) -> bool:
|
||||
if params['seed'] == None:
|
||||
params['seed'] = random.randint(0, 0xFFFFFFFF)
|
||||
|
||||
sanitized_params = {}
|
||||
for key, val in params.items():
|
||||
if isinstance(val, Decimal):
|
||||
val = str(val)
|
||||
|
||||
sanitized_params[key] = val
|
||||
|
||||
body = json.dumps({
|
||||
'method': 'diffuse',
|
||||
'params': sanitized_params
|
||||
})
|
||||
request_time = datetime.now().isoformat()
|
||||
|
||||
await self.update_status_message(
|
||||
status_msg,
|
||||
f'processing a \'{method}\' request by {tg_user_pretty(user)}\n'
|
||||
f'[{timestamp_pretty()}] <i>broadcasting transaction to chain...</i>',
|
||||
parse_mode='HTML'
|
||||
)
|
||||
|
||||
reward = '20.0000 GPU'
|
||||
res = await self.cleos.a_push_action(
|
||||
'gpu.scd',
|
||||
'enqueue',
|
||||
list({
|
||||
'user': Name(self.account),
|
||||
'request_body': body,
|
||||
'binary_data': ','.join(inputs),
|
||||
'reward': Asset.from_str(reward),
|
||||
'min_verification': 1
|
||||
}.values()),
|
||||
self.account, self.key, permission=self.permission
|
||||
)
|
||||
|
||||
if 'code' in res or 'statusCode' in res:
|
||||
logging.error(json.dumps(res, indent=4))
|
||||
await self.update_status_message(
|
||||
status_msg,
|
||||
'skynet has suffered an internal error trying to fill this request')
|
||||
return False
|
||||
|
||||
enqueue_tx_id = res['transaction_id']
|
||||
enqueue_tx_link = hlink(
|
||||
'Your request on Skynet Explorer',
|
||||
f'https://{self.explorer_domain}/v2/explore/transaction/{enqueue_tx_id}'
|
||||
)
|
||||
|
||||
await self.append_status_message(
|
||||
status_msg,
|
||||
f' <b>broadcasted!</b>\n'
|
||||
f'<b>{enqueue_tx_link}</b>\n'
|
||||
f'[{timestamp_pretty()}] <i>workers are processing request...</i>',
|
||||
parse_mode='HTML'
|
||||
)
|
||||
|
||||
out = res['processed']['action_traces'][0]['console']
|
||||
|
||||
request_id, nonce = out.split(':')
|
||||
|
||||
request_hash = sha256(
|
||||
(nonce + body + ','.join(inputs)).encode('utf-8')).hexdigest().upper()
|
||||
|
||||
request_id = int(request_id)
|
||||
|
||||
logging.info(f'{request_id} enqueued.')
|
||||
|
||||
tx_hash = None
|
||||
ipfs_hash = None
|
||||
for i in range(60 * 3):
|
||||
try:
|
||||
submits = await self.hyperion.aget_actions(
|
||||
account=self.account,
|
||||
filter='gpu.scd:submit',
|
||||
sort='desc',
|
||||
after=request_time
|
||||
)
|
||||
actions = [
|
||||
action
|
||||
for action in submits['actions']
|
||||
if action[
|
||||
'act']['data']['request_hash'] == request_hash
|
||||
]
|
||||
if len(actions) > 0:
|
||||
tx_hash = actions[0]['trx_id']
|
||||
data = actions[0]['act']['data']
|
||||
ipfs_hash = data['ipfs_hash']
|
||||
worker = data['worker']
|
||||
logging.info('Found matching submit!')
|
||||
break
|
||||
|
||||
except JSONDecodeError:
|
||||
logging.error(f'network error while getting actions, retry..')
|
||||
|
||||
await asyncio.sleep(1)
|
||||
|
||||
if not ipfs_hash:
|
||||
await self.update_status_message(
|
||||
status_msg,
|
||||
f'\n[{timestamp_pretty()}] <b>timeout processing request</b>',
|
||||
parse_mode='HTML'
|
||||
)
|
||||
return False
|
||||
|
||||
tx_link = hlink(
|
||||
'Your result on Skynet Explorer',
|
||||
f'https://{self.explorer_domain}/v2/explore/transaction/{tx_hash}'
|
||||
)
|
||||
|
||||
await self.append_status_message(
|
||||
status_msg,
|
||||
f' <b>request processed!</b>\n'
|
||||
f'<b>{tx_link}</b>\n'
|
||||
f'[{timestamp_pretty()}] <i>trying to download image...</i>\n',
|
||||
parse_mode='HTML'
|
||||
)
|
||||
|
||||
caption = generate_reply_caption(
|
||||
user, params, tx_hash, worker, reward, self.explorer_domain)
|
||||
|
||||
# attempt to get the image and send it
|
||||
ipfs_link = f'https://{self.ipfs_domain}/ipfs/{ipfs_hash}'
|
||||
|
||||
res = await get_ipfs_file(ipfs_link)
|
||||
logging.info(f'got response from {ipfs_link}')
|
||||
if not res or res.status_code != 200:
|
||||
logging.warning(f'couldn\'t get ipfs binary data at {ipfs_link}!')
|
||||
|
||||
else:
|
||||
try:
|
||||
with Image.open(io.BytesIO(res.raw)) as image:
|
||||
w, h = image.size
|
||||
|
||||
if w > TG_MAX_WIDTH or h > TG_MAX_HEIGHT:
|
||||
logging.warning(f'result is of size {image.size}')
|
||||
image.thumbnail((TG_MAX_WIDTH, TG_MAX_HEIGHT))
|
||||
|
||||
tmp_buf = io.BytesIO()
|
||||
image.save(tmp_buf, format='PNG')
|
||||
png_img = tmp_buf.getvalue()
|
||||
|
||||
except UnidentifiedImageError:
|
||||
logging.warning(f'couldn\'t get ipfs binary data at {ipfs_link}!')
|
||||
|
||||
if not png_img:
|
||||
await self.update_status_message(
|
||||
status_msg,
|
||||
caption,
|
||||
reply_markup=build_redo_menu(),
|
||||
parse_mode='HTML'
|
||||
)
|
||||
return True
|
||||
|
||||
logging.info(f'success! sending generated image')
|
||||
await self.bot.delete_message(
|
||||
chat_id=status_msg.chat.id, message_id=status_msg.id)
|
||||
if file_id: # img2img
|
||||
await self.bot.send_media_group(
|
||||
status_msg.chat.id,
|
||||
media=[
|
||||
InputMediaPhoto(file_id),
|
||||
InputMediaPhoto(
|
||||
png_img,
|
||||
caption=caption,
|
||||
parse_mode='HTML'
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
else: # txt2img
|
||||
await self.bot.send_photo(
|
||||
status_msg.chat.id,
|
||||
caption=caption,
|
||||
photo=png_img,
|
||||
reply_markup=build_redo_menu(),
|
||||
parse_mode='HTML'
|
||||
)
|
||||
|
||||
return True
|
|
@ -1,365 +0,0 @@
|
|||
import io
|
||||
import json
|
||||
import logging
|
||||
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
from PIL import Image
|
||||
from telebot.types import CallbackQuery, Message
|
||||
|
||||
from skynet.frontend import validate_user_config_request, perform_auto_conf
|
||||
from skynet.constants import *
|
||||
|
||||
|
||||
def create_handler_context(frontend: 'SkynetTelegramFrontend'):
|
||||
|
||||
bot = frontend.bot
|
||||
cleos = frontend.cleos
|
||||
db_call = frontend.db_call
|
||||
work_request = frontend.work_request
|
||||
|
||||
ipfs_node = frontend.ipfs_node
|
||||
|
||||
# generic / simple handlers
|
||||
|
||||
@bot.message_handler(commands=['help'])
|
||||
async def send_help(message):
|
||||
splt_msg = message.text.split(' ')
|
||||
|
||||
if len(splt_msg) == 1:
|
||||
await bot.reply_to(message, HELP_TEXT)
|
||||
|
||||
else:
|
||||
param = splt_msg[1]
|
||||
if param in HELP_TOPICS:
|
||||
await bot.reply_to(message, HELP_TOPICS[param])
|
||||
|
||||
else:
|
||||
await bot.reply_to(message, HELP_UNKWNOWN_PARAM)
|
||||
|
||||
@bot.message_handler(commands=['cool'])
|
||||
async def send_cool_words(message):
|
||||
await bot.reply_to(message, '\n'.join(COOL_WORDS))
|
||||
|
||||
@bot.message_handler(commands=['queue'])
|
||||
async def queue(message):
|
||||
an_hour_ago = datetime.now() - timedelta(hours=1)
|
||||
queue = await cleos.aget_table(
|
||||
'gpu.scd', 'gpu.scd', 'queue',
|
||||
index_position=2,
|
||||
key_type='i64',
|
||||
sort='desc',
|
||||
lower_bound=int(an_hour_ago.timestamp())
|
||||
)
|
||||
await bot.reply_to(
|
||||
message, f'Total requests on skynet queue: {len(queue)}')
|
||||
|
||||
|
||||
@bot.message_handler(commands=['config'])
|
||||
async def set_config(message):
|
||||
user = message.from_user.id
|
||||
try:
|
||||
attr, val, reply_txt = validate_user_config_request(
|
||||
message.text)
|
||||
|
||||
logging.info(f'user config update: {attr} to {val}')
|
||||
await db_call('update_user_config', user, attr, val)
|
||||
logging.info('done')
|
||||
|
||||
except BaseException as e:
|
||||
reply_txt = str(e)
|
||||
|
||||
finally:
|
||||
await bot.reply_to(message, reply_txt)
|
||||
|
||||
@bot.message_handler(commands=['stats'])
|
||||
async def user_stats(message):
|
||||
user = message.from_user.id
|
||||
|
||||
await db_call('get_or_create_user', user)
|
||||
generated, joined, role = await db_call('get_user_stats', user)
|
||||
|
||||
stats_str = f'generated: {generated}\n'
|
||||
stats_str += f'joined: {joined}\n'
|
||||
stats_str += f'role: {role}\n'
|
||||
|
||||
await bot.reply_to(
|
||||
message, stats_str)
|
||||
|
||||
@bot.message_handler(commands=['donate'])
|
||||
async def donation_info(message):
|
||||
await bot.reply_to(
|
||||
message, DONATION_INFO)
|
||||
|
||||
@bot.message_handler(commands=['say'])
|
||||
async def say(message):
|
||||
chat = message.chat
|
||||
user = message.from_user
|
||||
|
||||
if (chat.type == 'group') or (user.id != 383385940):
|
||||
return
|
||||
|
||||
await bot.send_message(GROUP_ID, message.text[4:])
|
||||
|
||||
|
||||
# generic txt2img handler
|
||||
|
||||
async def _generic_txt2img(message_or_query):
|
||||
if isinstance(message_or_query, CallbackQuery):
|
||||
query = message_or_query
|
||||
message = query.message
|
||||
user = query.from_user
|
||||
chat = query.message.chat
|
||||
|
||||
else:
|
||||
message = message_or_query
|
||||
user = message.from_user
|
||||
chat = message.chat
|
||||
|
||||
if chat.type == 'private':
|
||||
return
|
||||
|
||||
reply_id = None
|
||||
if chat.type == 'group' and chat.id == GROUP_ID:
|
||||
reply_id = message.message_id
|
||||
|
||||
user_row = await db_call('get_or_create_user', user.id)
|
||||
|
||||
# init new msg
|
||||
init_msg = 'started processing txt2img request...'
|
||||
status_msg = await bot.reply_to(message, init_msg)
|
||||
await db_call(
|
||||
'new_user_request', user.id, message.id, status_msg.id, status=init_msg)
|
||||
|
||||
prompt = ' '.join(message.text.split(' ')[1:])
|
||||
|
||||
if len(prompt) == 0:
|
||||
await bot.edit_message_text(
|
||||
'Empty text prompt ignored.',
|
||||
chat_id=status_msg.chat.id,
|
||||
message_id=status_msg.id
|
||||
)
|
||||
await db_call('update_user_request', status_msg.id, 'Empty text prompt ignored.')
|
||||
return
|
||||
|
||||
logging.info(f'mid: {message.id}')
|
||||
|
||||
user_config = {**user_row}
|
||||
del user_config['id']
|
||||
|
||||
if user_config['autoconf']:
|
||||
user_config = perform_auto_conf(user_config)
|
||||
|
||||
params = {
|
||||
'prompt': prompt,
|
||||
**user_config
|
||||
}
|
||||
|
||||
await db_call(
|
||||
'update_user_stats', user.id, 'txt2img', last_prompt=prompt)
|
||||
|
||||
success = await work_request(user, status_msg, 'txt2img', params)
|
||||
|
||||
if success:
|
||||
await db_call('increment_generated', user.id)
|
||||
|
||||
|
||||
# generic img2img handler
|
||||
|
||||
async def _generic_img2img(message_or_query):
|
||||
if isinstance(message_or_query, CallbackQuery):
|
||||
query = message_or_query
|
||||
message = query.message
|
||||
user = query.from_user
|
||||
chat = query.message.chat
|
||||
|
||||
else:
|
||||
message = message_or_query
|
||||
user = message.from_user
|
||||
chat = message.chat
|
||||
|
||||
if chat.type == 'private':
|
||||
return
|
||||
|
||||
reply_id = None
|
||||
if chat.type == 'group' and chat.id == GROUP_ID:
|
||||
reply_id = message.message_id
|
||||
|
||||
user_row = await db_call('get_or_create_user', user.id)
|
||||
|
||||
# init new msg
|
||||
init_msg = 'started processing txt2img request...'
|
||||
status_msg = await bot.reply_to(message, init_msg)
|
||||
await db_call(
|
||||
'new_user_request', user.id, message.id, status_msg.id, status=init_msg)
|
||||
|
||||
if not message.caption.startswith('/img2img'):
|
||||
await bot.reply_to(
|
||||
message,
|
||||
'For image to image you need to add /img2img to the beggining of your caption'
|
||||
)
|
||||
return
|
||||
|
||||
prompt = ' '.join(message.caption.split(' ')[1:])
|
||||
|
||||
if len(prompt) == 0:
|
||||
await bot.reply_to(message, 'Empty text prompt ignored.')
|
||||
return
|
||||
|
||||
file_id = message.photo[-1].file_id
|
||||
file_path = (await bot.get_file(file_id)).file_path
|
||||
image_raw = await bot.download_file(file_path)
|
||||
|
||||
user_config = {**user_row}
|
||||
del user_config['id']
|
||||
if user_config['autoconf']:
|
||||
user_config = perform_auto_conf(user_config)
|
||||
|
||||
with Image.open(io.BytesIO(image_raw)) as image:
|
||||
w, h = image.size
|
||||
|
||||
if w > user_config['width'] or h > user_config['height']:
|
||||
logging.warning(f'user sent img of size {image.size}')
|
||||
image.thumbnail(
|
||||
(user_config['width'], user_config['height']))
|
||||
logging.warning(f'resized it to {image.size}')
|
||||
|
||||
image_loc = 'ipfs-staging/image.png'
|
||||
image.save(image_loc, format='PNG')
|
||||
|
||||
ipfs_info = await ipfs_node.add(image_loc)
|
||||
ipfs_hash = ipfs_info['Hash']
|
||||
await ipfs_node.pin(ipfs_hash)
|
||||
|
||||
logging.info(f'published input image {ipfs_hash} on ipfs')
|
||||
|
||||
logging.info(f'mid: {message.id}')
|
||||
|
||||
params = {
|
||||
'prompt': prompt,
|
||||
**user_config
|
||||
}
|
||||
|
||||
await db_call(
|
||||
'update_user_stats',
|
||||
user.id,
|
||||
'img2img',
|
||||
last_file=file_id,
|
||||
last_prompt=prompt,
|
||||
last_binary=ipfs_hash
|
||||
)
|
||||
|
||||
success = await work_request(
|
||||
user, status_msg, 'img2img', params,
|
||||
file_id=file_id,
|
||||
inputs=ipfs_hash
|
||||
)
|
||||
|
||||
if success:
|
||||
await db_call('increment_generated', user.id)
|
||||
|
||||
|
||||
# generic redo handler
|
||||
|
||||
async def _redo(message_or_query):
|
||||
is_query = False
|
||||
if isinstance(message_or_query, CallbackQuery):
|
||||
is_query = True
|
||||
query = message_or_query
|
||||
message = query.message
|
||||
user = query.from_user
|
||||
chat = query.message.chat
|
||||
|
||||
elif isinstance(message_or_query, Message):
|
||||
message = message_or_query
|
||||
user = message.from_user
|
||||
chat = message.chat
|
||||
|
||||
if chat.type == 'private':
|
||||
return
|
||||
|
||||
init_msg = 'started processing redo request...'
|
||||
if is_query:
|
||||
status_msg = await bot.send_message(chat.id, init_msg)
|
||||
|
||||
else:
|
||||
status_msg = await bot.reply_to(message, init_msg)
|
||||
|
||||
method = await db_call('get_last_method_of', user.id)
|
||||
prompt = await db_call('get_last_prompt_of', user.id)
|
||||
|
||||
file_id = None
|
||||
binary = ''
|
||||
if method == 'img2img':
|
||||
file_id = await db_call('get_last_file_of', user.id)
|
||||
binary = await db_call('get_last_binary_of', user.id)
|
||||
|
||||
if not prompt:
|
||||
await bot.reply_to(
|
||||
message,
|
||||
'no last prompt found, do a txt2img cmd first!'
|
||||
)
|
||||
return
|
||||
|
||||
|
||||
user_row = await db_call('get_or_create_user', user.id)
|
||||
await db_call(
|
||||
'new_user_request', user.id, message.id, status_msg.id, status=init_msg)
|
||||
user_config = {**user_row}
|
||||
del user_config['id']
|
||||
if user_config['autoconf']:
|
||||
user_config = perform_auto_conf(user_config)
|
||||
|
||||
params = {
|
||||
'prompt': prompt,
|
||||
**user_config
|
||||
}
|
||||
|
||||
success = await work_request(
|
||||
user, status_msg, 'redo', params,
|
||||
file_id=file_id,
|
||||
inputs=binary
|
||||
)
|
||||
|
||||
if success:
|
||||
await db_call('increment_generated', user.id)
|
||||
|
||||
|
||||
# "proxy" handlers just request routers
|
||||
|
||||
@bot.message_handler(commands=['txt2img'])
|
||||
async def send_txt2img(message):
|
||||
await _generic_txt2img(message)
|
||||
|
||||
@bot.message_handler(func=lambda message: True, content_types=[
|
||||
'photo', 'document'])
|
||||
async def send_img2img(message):
|
||||
await _generic_img2img(message)
|
||||
|
||||
@bot.message_handler(commands=['img2img'])
|
||||
async def img2img_missing_image(message):
|
||||
await bot.reply_to(
|
||||
message,
|
||||
'seems you tried to do an img2img command without sending image'
|
||||
)
|
||||
|
||||
@bot.message_handler(commands=['redo'])
|
||||
async def redo(message):
|
||||
await _redo(message)
|
||||
|
||||
@bot.callback_query_handler(func=lambda call: True)
|
||||
async def callback_query(call):
|
||||
msg = json.loads(call.data)
|
||||
logging.info(call.data)
|
||||
method = msg.get('method')
|
||||
match method:
|
||||
case 'redo':
|
||||
await _redo(call)
|
||||
|
||||
|
||||
# catch all handler for things we dont support
|
||||
|
||||
@bot.message_handler(func=lambda message: True)
|
||||
async def echo_message(message):
|
||||
if message.text[0] == '/':
|
||||
await bot.reply_to(message, UNKNOWN_CMD_TEXT)
|
|
@ -1,105 +0,0 @@
|
|||
import json
|
||||
import logging
|
||||
import traceback
|
||||
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from telebot.types import InlineKeyboardButton, InlineKeyboardMarkup
|
||||
from telebot.async_telebot import ExceptionHandler
|
||||
from telebot.formatting import hlink
|
||||
|
||||
from skynet.constants import *
|
||||
|
||||
|
||||
def timestamp_pretty():
|
||||
return datetime.now(timezone.utc).strftime('%H:%M:%S')
|
||||
|
||||
|
||||
def tg_user_pretty(tguser):
|
||||
if tguser.username:
|
||||
return f'@{tguser.username}'
|
||||
else:
|
||||
return f'{tguser.first_name} id: {tguser.id}'
|
||||
|
||||
|
||||
class SKYExceptionHandler(ExceptionHandler):
|
||||
|
||||
def handle(exception):
|
||||
traceback.print_exc()
|
||||
|
||||
|
||||
def build_redo_menu():
|
||||
btn_redo = InlineKeyboardButton("Redo", callback_data=json.dumps({'method': 'redo'}))
|
||||
inline_keyboard = InlineKeyboardMarkup()
|
||||
inline_keyboard.add(btn_redo)
|
||||
return inline_keyboard
|
||||
|
||||
|
||||
def prepare_metainfo_caption(tguser, worker: str, reward: str, meta: dict) -> str:
|
||||
prompt = meta["prompt"]
|
||||
if len(prompt) > 256:
|
||||
prompt = prompt[:256]
|
||||
|
||||
|
||||
meta_str = f'<u>by {tg_user_pretty(tguser)}</u>\n'
|
||||
meta_str += f'<i>performed by {worker}</i>\n'
|
||||
meta_str += f'<b><u>reward: {reward}</u></b>\n'
|
||||
|
||||
meta_str += f'<code>prompt:</code> {prompt}\n'
|
||||
meta_str += f'<code>seed: {meta["seed"]}</code>\n'
|
||||
meta_str += f'<code>step: {meta["step"]}</code>\n'
|
||||
meta_str += f'<code>guidance: {meta["guidance"]}</code>\n'
|
||||
if meta['strength']:
|
||||
meta_str += f'<code>strength: {meta["strength"]}</code>\n'
|
||||
meta_str += f'<code>algo: {meta["model"]}</code>\n'
|
||||
if meta['upscaler']:
|
||||
meta_str += f'<code>upscaler: {meta["upscaler"]}</code>\n'
|
||||
|
||||
meta_str += f'<b><u>Made with Skynet v{VERSION}</u></b>\n'
|
||||
meta_str += f'<b>JOIN THE SWARM: @skynetgpu</b>'
|
||||
return meta_str
|
||||
|
||||
|
||||
def generate_reply_caption(
|
||||
tguser, # telegram user
|
||||
params: dict,
|
||||
tx_hash: str,
|
||||
worker: str,
|
||||
reward: str,
|
||||
explorer_domain: str
|
||||
):
|
||||
explorer_link = hlink(
|
||||
'SKYNET Transaction Explorer',
|
||||
f'https://{explorer_domain}/v2/explore/transaction/{tx_hash}'
|
||||
)
|
||||
|
||||
meta_info = prepare_metainfo_caption(tguser, worker, reward, params)
|
||||
|
||||
final_msg = '\n'.join([
|
||||
'Worker finished your task!',
|
||||
explorer_link,
|
||||
f'PARAMETER INFO:\n{meta_info}'
|
||||
])
|
||||
|
||||
final_msg = '\n'.join([
|
||||
f'<b><i>{explorer_link}</i></b>',
|
||||
f'{meta_info}'
|
||||
])
|
||||
|
||||
logging.info(final_msg)
|
||||
|
||||
return final_msg
|
||||
|
||||
|
||||
async def get_global_config(cleos):
|
||||
return (await cleos.aget_table(
|
||||
'gpu.scd', 'gpu.scd', 'config'))[0]
|
||||
|
||||
async def get_user_nonce(cleos, user: str):
|
||||
return (await cleos.aget_table(
|
||||
'gpu.scd', 'gpu.scd', 'users',
|
||||
index_position=1,
|
||||
key_type='name',
|
||||
lower_bound=user,
|
||||
upper_bound=user
|
||||
))[0]['nonce']
|
|
@ -85,6 +85,7 @@ class BodyV0Params(Struct):
|
|||
strength: str | float | None = None
|
||||
output_type: str | None = 'png'
|
||||
upscaler: str | None = None
|
||||
autoconf: bool | None = None
|
||||
|
||||
|
||||
class BodyV0(Struct):
|
||||
|
|
48
uv.lock
48
uv.lock
|
@ -1,4 +1,5 @@
|
|||
version = 1
|
||||
revision = 1
|
||||
requires-python = ">=3.10, <3.13"
|
||||
resolution-markers = [
|
||||
"python_full_version >= '3.12' and sys_platform == 'darwin'",
|
||||
|
@ -271,7 +272,7 @@ name = "cffi"
|
|||
version = "1.17.1"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "pycparser" },
|
||||
{ name = "pycparser", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/fc/97/c783634659c2920c3fc70419e3af40972dbaf758daa229a7d6ea6135c90d/cffi-1.17.1.tar.gz", hash = "sha256:1c39c6016c32bc48dd54561950ebd6836e1670f2ae46128f67cf49e789c52824", size = 516621 }
|
||||
wheels = [
|
||||
|
@ -1271,7 +1272,7 @@ name = "nvidia-cudnn-cu12"
|
|||
version = "9.1.0.70"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "nvidia-cublas-cu12" },
|
||||
{ name = "nvidia-cublas-cu12", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
|
||||
]
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/9f/fd/713452cd72343f682b1c7b9321e23829f00b842ceaedcda96e742ea0b0b3/nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl", hash = "sha256:165764f44ef8c61fcdfdfdbe769d687e06374059fbb388b6c89ecb0e28793a6f", size = 664752741 },
|
||||
|
@ -1298,9 +1299,9 @@ name = "nvidia-cusolver-cu12"
|
|||
version = "11.4.5.107"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "nvidia-cublas-cu12" },
|
||||
{ name = "nvidia-cusparse-cu12" },
|
||||
{ name = "nvidia-nvjitlink-cu12" },
|
||||
{ name = "nvidia-cublas-cu12", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
|
||||
{ name = "nvidia-cusparse-cu12", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
|
||||
{ name = "nvidia-nvjitlink-cu12", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
|
||||
]
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/bc/1d/8de1e5c67099015c834315e333911273a8c6aaba78923dd1d1e25fc5f217/nvidia_cusolver_cu12-11.4.5.107-py3-none-manylinux1_x86_64.whl", hash = "sha256:8a7ec542f0412294b15072fa7dab71d31334014a69f953004ea7a118206fe0dd", size = 124161928 },
|
||||
|
@ -1311,7 +1312,7 @@ name = "nvidia-cusparse-cu12"
|
|||
version = "12.1.0.106"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "nvidia-nvjitlink-cu12" },
|
||||
{ name = "nvidia-nvjitlink-cu12", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
|
||||
]
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/65/5b/cfaeebf25cd9fdec14338ccb16f6b2c4c7fa9163aefcf057d86b9cc248bb/nvidia_cusparse_cu12-12.1.0.106-py3-none-manylinux1_x86_64.whl", hash = "sha256:f3b50f42cf363f86ab21f720998517a659a48131e8d538dc02f8768237bd884c", size = 195958278 },
|
||||
|
@ -1598,7 +1599,7 @@ wheels = [
|
|||
[[package]]
|
||||
name = "py-leap"
|
||||
version = "0.1a35"
|
||||
source = { git = "https://github.com/guilledk/py-leap.git?branch=struct_unwrap#18b3c73e724922a060db5f8ea2b9d9727b6152cc" }
|
||||
source = { editable = "../py-leap" }
|
||||
dependencies = [
|
||||
{ name = "base58" },
|
||||
{ name = "cryptos" },
|
||||
|
@ -1608,6 +1609,33 @@ dependencies = [
|
|||
{ name = "ripemd-hash" },
|
||||
]
|
||||
|
||||
[package.metadata]
|
||||
requires-dist = [
|
||||
{ name = "base58", specifier = ">=2.1.1,<3" },
|
||||
{ name = "cryptos", specifier = ">=2.0.9,<3" },
|
||||
{ name = "httpx", specifier = ">=0.28.1,<0.29" },
|
||||
{ name = "msgspec", specifier = ">=0.19.0" },
|
||||
{ name = "requests", specifier = "<2.32.0" },
|
||||
{ name = "ripemd-hash", specifier = ">=1.0.1,<2" },
|
||||
]
|
||||
|
||||
[package.metadata.requires-dev]
|
||||
dev = [
|
||||
{ name = "docker", specifier = ">=6.1.3,<7" },
|
||||
{ name = "pdbpp", specifier = ">=0.10.3,<0.11" },
|
||||
{ name = "pytest", specifier = ">=8.3.4,<9" },
|
||||
{ name = "pytest-trio", specifier = ">=0.8.0,<0.9" },
|
||||
]
|
||||
docs = [
|
||||
{ name = "sphinx", specifier = "==7.1.2" },
|
||||
{ name = "sphinx-rtd-theme", specifier = "==1.3.0" },
|
||||
]
|
||||
snaps = [
|
||||
{ name = "bs4", specifier = ">=0.0.2,<0.0.3" },
|
||||
{ name = "tdqm", specifier = ">=0.0.1,<0.0.2" },
|
||||
{ name = "zstandard", specifier = ">=0.21.0,<0.22" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pycparser"
|
||||
version = "2.22"
|
||||
|
@ -1633,8 +1661,6 @@ wheels = [
|
|||
{ url = "https://files.pythonhosted.org/packages/61/74/49f5d20c514ccc631b940cc9dfec45dcce418dc84a98463a2e2ebec33904/pycryptodomex-3.21.0-cp36-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:52e23a0a6e61691134aa8c8beba89de420602541afaae70f66e16060fdcd677e", size = 2257982 },
|
||||
{ url = "https://files.pythonhosted.org/packages/92/4b/d33ef74e2cc0025a259936661bb53432c5bbbadc561c5f2e023bcd73ce4c/pycryptodomex-3.21.0-cp36-abi3-win32.whl", hash = "sha256:a3d77919e6ff56d89aada1bd009b727b874d464cb0e2e3f00a49f7d2e709d76e", size = 1779052 },
|
||||
{ url = "https://files.pythonhosted.org/packages/5b/be/7c991840af1184009fc86267160948350d1bf875f153c97bb471ad944e40/pycryptodomex-3.21.0-cp36-abi3-win_amd64.whl", hash = "sha256:b0e9765f93fe4890f39875e6c90c96cb341767833cfa767f41b490b506fa9ec0", size = 1816307 },
|
||||
{ url = "https://files.pythonhosted.org/packages/af/ac/24125ad36778914a36f08d61ba5338cb9159382c638d9761ee19c8de822c/pycryptodomex-3.21.0-pp27-pypy_73-manylinux2010_x86_64.whl", hash = "sha256:feaecdce4e5c0045e7a287de0c4351284391fe170729aa9182f6bd967631b3a8", size = 1694999 },
|
||||
{ url = "https://files.pythonhosted.org/packages/93/73/be7a54a5903508070e5508925ba94493a1f326cfeecfff750e3eb250ea28/pycryptodomex-3.21.0-pp27-pypy_73-win32.whl", hash = "sha256:365aa5a66d52fd1f9e0530ea97f392c48c409c2f01ff8b9a39c73ed6f527d36c", size = 1769437 },
|
||||
{ url = "https://files.pythonhosted.org/packages/e5/9f/39a6187f3986841fa6a9f35c6fdca5030ef73ff708b45a993813a51d7d10/pycryptodomex-3.21.0-pp310-pypy310_pp73-macosx_10_15_x86_64.whl", hash = "sha256:3efddfc50ac0ca143364042324046800c126a1d63816d532f2e19e6f2d8c0c31", size = 1619607 },
|
||||
{ url = "https://files.pythonhosted.org/packages/f8/70/60bb08e9e9841b18d4669fb69d84b64ce900aacd7eb0ebebd4c7b9bdecd3/pycryptodomex-3.21.0-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0df2608682db8279a9ebbaf05a72f62a321433522ed0e499bc486a6889b96bf3", size = 1653571 },
|
||||
{ url = "https://files.pythonhosted.org/packages/c9/6f/191b73509291c5ff0dddec9cc54797b1d73303c12b2e4017b24678e57099/pycryptodomex-3.21.0-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:5823d03e904ea3e53aebd6799d6b8ec63b7675b5d2f4a4bd5e3adcb512d03b37", size = 1691548 },
|
||||
|
@ -2126,7 +2152,7 @@ requires-dist = [
|
|||
{ name = "outcome", specifier = ">=1.3.0.post0" },
|
||||
{ name = "pillow", specifier = ">=10.0.1,<11" },
|
||||
{ name = "protobuf", specifier = ">=5.29.3,<6" },
|
||||
{ name = "py-leap", git = "https://github.com/guilledk/py-leap.git?branch=struct_unwrap" },
|
||||
{ name = "py-leap", editable = "../py-leap" },
|
||||
{ name = "pytz", specifier = "~=2023.3.post1" },
|
||||
{ name = "toml", specifier = ">=0.10.2,<0.11" },
|
||||
{ name = "trio", specifier = ">=0.22.2,<0.23" },
|
||||
|
@ -2436,7 +2462,7 @@ name = "triton"
|
|||
version = "3.1.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "filelock" },
|
||||
{ name = "filelock", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
|
||||
]
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/98/29/69aa56dc0b2eb2602b553881e34243475ea2afd9699be042316842788ff5/triton-3.1.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6b0dd10a925263abbe9fa37dcde67a5e9b2383fc269fdf59f5657cac38c5d1d8", size = 209460013 },
|
||||
|
|
Loading…
Reference in New Issue