mirror of https://github.com/skygpu/skynet.git
407 lines
13 KiB
Python
407 lines
13 KiB
Python
import io
import logging
from collections.abc import Callable
from datetime import datetime, timezone
from typing import Self, Awaitable

import discord
from discord import (
    User as DCUser,
    Member,
    Message as DCMessage,
    Attachment,
    DMChannel
)
from discord.abc import Messageable
from discord.ext import commands

from skynet.config import FrontendConfig
from skynet.types import BodyV0Params
from skynet.constants import VERSION
from skynet.frontend.chatbot import BaseChatbot
from skynet.frontend.chatbot.db import FrontendUserDB
from skynet.frontend.chatbot.types import (
    BaseUser,
    BaseChatRoom,
    BaseFileInput,
    BaseCommands,
    BaseMessage
)
|
|
|
|
# Discord channel id of the main group chat; DiscordChatbot.init() looks it up.
# NOTE(review): -1 looks like a per-deployment placeholder — confirm it is
# replaced with a real channel id before running.
GROUP_ID: int = -1

# Discord user id that DiscordUser.is_admin treats as the administrator.
# NOTE(review): -1 looks like a placeholder that no real user matches — confirm.
ADMIN_USER_ID: int = -1
|
|
|
|
|
|
def timestamp_pretty():
    """Return the current UTC time of day formatted as ``HH:MM:SS``."""
    now_utc = datetime.now(tz=timezone.utc)
    return f'{now_utc:%H:%M:%S}'
|
|
|
|
|
|
class DiscordUser(BaseUser):
    """Expose a discord.py ``User``/``Member`` through the ``BaseUser`` interface."""

    def __init__(self, user: DCUser | Member):
        # wrapped discord.py object; every property below reads through it
        self._user = user

    @property
    def id(self) -> int:
        """The wrapped user's ``id``."""
        wrapped = self._user
        return wrapped.id

    @property
    def name(self) -> str:
        """The wrapped user's ``name``."""
        wrapped = self._user
        return wrapped.name

    @property
    def is_admin(self) -> bool:
        """True when this user's id equals the module-level ``ADMIN_USER_ID``."""
        return ADMIN_USER_ID == self.id
|
|
|
|
|
|
class DiscordChatRoom(BaseChatRoom):
    """Expose a discord.py messageable channel through the ``BaseChatRoom`` interface."""

    def __init__(self, channel: Messageable):
        # wrapped discord.py channel; DiscordChatbot sends messages through it
        self._channel = channel

    @property
    def id(self) -> int:
        """The wrapped channel's ``id``."""
        wrapped = self._channel
        return wrapped.id

    @property
    def is_private(self) -> bool:
        """Whether the wrapped channel is a ``DMChannel`` (a direct-message chat)."""
        wrapped = self._channel
        return isinstance(wrapped, DMChannel)
|
|
|
|
class DiscordFileInput(BaseFileInput):
    """File input backed by a discord.py ``Attachment`` or by stored values.

    Inputs built from a live message wrap an ``Attachment`` and can be
    downloaded. Inputs rebuilt with ``from_values`` only carry an ``id`` and a
    ``cid``, and have nothing to download.
    """

    def __init__(
        self,
        attachment: Attachment | None = None,
        id: int | None = None,
        cid: str | None = None
    ):
        self._attachment = attachment
        self._id = id
        self._cid = cid

        # raw attachment bytes, populated by ``download()``
        self._raw: bytes | None = None

    @classmethod
    def from_values(cls, id: int, cid: str) -> Self:
        """Rebuild an input from a previously stored ``id`` and ``cid``.

        A classmethod so it works whether called on the class or on an
        instance (the undecorated original broke when called on an instance).
        """
        return cls(id=id, cid=cid)

    @property
    def id(self) -> int:
        """Return the explicit id if one was given, else the attachment's id.

        Raises:
            ValueError: if neither an id nor an attachment is available.
        """
        if self._id is not None:
            return self._id

        if self._attachment is None:
            raise ValueError('DiscordFileInput has neither an id nor an attachment')

        return self._attachment.id

    @property
    def cid(self) -> str:
        """Return the stored cid.

        Raises:
            ValueError: if no (non-empty) cid has been set yet.
        """
        if self._cid:
            return self._cid

        raise ValueError('DiscordFileInput cid has not been set')

    def set_cid(self, cid: str):
        """Record the cid for this input."""
        self._cid = cid

    async def download(self) -> bytes:
        """Read the attachment's bytes, cache them in ``_raw`` and return them.

        Raises:
            ValueError: if this input has no attachment to read from.
        """
        if self._attachment is None:
            raise ValueError('DiscordFileInput has no attachment to download')

        self._raw = await self._attachment.read()
        return self._raw
|
|
|
|
|
|
class DiscordMessage(BaseMessage):
    """Expose a discord.py ``Message`` through the ``BaseMessage`` interface.

    ``cmd`` is the command this message invoked. It is ``None`` for messages
    the bot itself created.
    """

    def __init__(self, cmd: BaseCommands | None, msg: DCMessage):
        self._msg = msg
        self._cmd = cmd
        self._chat = DiscordChatRoom(msg.channel)

        # lazily built caches, see ``inputs`` and ``author``
        self._inputs: list[DiscordFileInput] | None = None
        self._author: DiscordUser | None = None

    @property
    def id(self) -> int:
        """The wrapped message's ``id``."""
        return self._msg.id

    @property
    def chat(self) -> DiscordChatRoom:
        """The chat room the message was posted in."""
        return self._chat

    @property
    def text(self) -> str:
        """Message text with the leading ``/<command> `` stripped.

        When no command is attached (e.g. bot-created messages), the full
        content is returned unchanged instead of failing on ``len(None)``.
        """
        # discord.py exposes the message body as ``content``
        content = self._msg.content
        if self._cmd is None:
            return content

        # remove command name, slash and first space
        return content[len(self._cmd) + 2:]

    @property
    def author(self) -> DiscordUser:
        """The message author, wrapped once and cached."""
        if self._author is None:
            self._author = DiscordUser(self._msg.author)

        return self._author

    @property
    def command(self) -> str | None:
        """The command this message invoked, or ``None``."""
        return self._cmd

    @property
    def inputs(self) -> list[DiscordFileInput]:
        """File inputs built once from the message's attachments."""
        if self._inputs is None:
            self._inputs = [
                DiscordFileInput(attachment=a)
                for a in (self._msg.attachments or [])
            ]

        return self._inputs
|
|
|
|
|
|
def generate_reply_embed(
    config: FrontendConfig,
    user: DiscordUser,
    params: BodyV0Params,
    tx_hash: str,
    worker: str,
) -> discord.Embed:
    """Build the embed posted with a finished generation result.

    The embed title links to ``tx_hash`` on the configured explorer. It then
    lists who requested and performed the job, the (truncated) prompt and the
    generation parameters. Optional parameters are included only when truthy.
    """
    embed = discord.Embed(
        title='[SKYNET Transaction Explorer]',
        url=f'https://{config.explorer_domain}/v2/explore/transaction/{tx_hash}',
        color=discord.Color.blue())

    # slicing is a no-op for shorter prompts, so no explicit length check
    prompt = params.prompt[:256]

    gen_str = (
        f'generated by {user.name}\n'
        f'performed by {worker}\n'
        f'reward: {config.reward}\n'
    )
    embed.add_field(
        name='General Info', value=f'```{gen_str}```', inline=False)

    embed.add_field(name='Prompt', value=f'```{prompt}\n```', inline=False)

    meta_str = f'seed: {params.seed}\n'
    meta_str += f'step: {params.step}\n'
    if params.guidance:
        meta_str += f'guidance: {params.guidance}\n'

    if params.strength:
        meta_str += f'strength: {params.strength}\n'

    meta_str += f'algo: {params.model}\n'
    if params.upscaler:
        meta_str += f'upscaler: {params.upscaler}\n'

    embed.add_field(name='Parameters', value=f'```{meta_str}```', inline=False)

    foot_str = f'Made with Skynet v{VERSION}\n'
    foot_str += 'JOIN THE SWARM: https://discord.gg/PAabjJtZAF'

    embed.set_footer(text=foot_str)

    return embed
|
|
|
|
|
|
def append_command_handler(
    client: commands.Bot,
    command: str,
    help_txt: str,
    fn: Callable[[DiscordMessage], Awaitable[None]]
):
    """Register coroutine function ``fn`` as the handler for prefix command ``command``.

    ``client`` must be a ``commands.Bot``, since only it provides ``.command()``.
    Before ``fn`` runs, the invoking message is wrapped in a ``DiscordMessage``
    and every attachment is downloaded, so handlers can use the cached bytes.
    """
    @client.command(name=command, help=help_txt)
    async def wrap_msg_and_handle(ctx: commands.Context):
        msg = DiscordMessage(cmd=command, msg=ctx.message)
        # fetch attachment bytes up front; later steps read ``_raw`` directly
        for file in msg.inputs:
            await file.download()

        await fn(msg)
|
|
|
|
|
|
class DiscordChatbot(BaseChatbot):
    """Discord frontend for Skynet, built on a discord.py ``commands.Bot``."""

    def __init__(
        self,
        config: FrontendConfig,
        db: FrontendUserDB,
    ):
        super().__init__(config, db)
        intents = discord.Intents(
            messages=True,
            guilds=True,
            typing=True,
            members=True,
            presences=True,
            reactions=True,
            message_content=True,
            voice_states=True
        )
        # ``commands.Bot`` is required here: plain ``discord.Client`` accepts
        # no ``command_prefix`` and has neither ``.command()`` nor
        # ``process_commands()``. The built-in help command is disabled so the
        # custom '/help' handler registered below does not collide with it.
        client = commands.Bot(
            command_prefix='/',
            intents=intents,
            help_command=None
        )

        @client.event
        async def on_ready():
            print(f'{client.user.name} has connected to Discord!')
            for guild in client.guilds:
                # only text channels are sendable; a category (or other
                # non-text channel) named "skynet" would raise on ``.send``
                for channel in guild.text_channels:
                    if channel.name == "skynet":
                        await channel.send('Skynet bot online')
                        await channel.send(
                            'Welcome to Skynet\'s Discord Bot,\n\n'
                            'Skynet operates as a decentralized compute layer, offering a wide array of '
                            'support for diverse AI paradigms through the use of blockchain technology. '
                            'Our present focus is image generation, powered by 11 distinct models.\n\n'
                            'To begin exploring, use the \'/help\' command or directly interact with the '
                            'provided buttons. Here is an example command to generate an image:\n\n'
                            '\'/txt2img a big red tractor in a giant field of corn\''
                        )

            print("\n==============")
            print("Logged in as")
            print(client.user.name)
            print(client.user.id)
            print("==============")

        @client.event
        async def on_message(message: DCMessage):
            if message.author == client.user:
                return

            # overriding on_message replaces the Bot's default command
            # dispatch, so commands must be processed explicitly on the bot
            await client.process_commands(message)

        @client.event
        async def on_command_error(ctx, error):
            if isinstance(error, commands.MissingRequiredArgument):
                await ctx.send('You missed a required argument, please try again.')

            elif isinstance(error, commands.CommandNotFound):
                # ordinary chat that merely starts with '/'; ignore quietly
                return

            else:
                # overriding this event disables discord.py's default error
                # reporting, so log everything else instead of dropping it
                logging.error(
                    'error while handling command %s', ctx.command, exc_info=error)

        append_command_handler(client, BaseCommands.HELP, 'Responds with help text', self.send_help)
        append_command_handler(client, BaseCommands.COOL, 'Display a list of cool prompt words', self.send_cool_words)
        append_command_handler(client, BaseCommands.QUEUE, 'Get information on current skynet queue', self.get_queue)
        append_command_handler(client, BaseCommands.CONFIG, 'Allows user to configure inference params', self.set_config)
        append_command_handler(client, BaseCommands.STATS, 'See user statistics', self.user_stats)
        append_command_handler(client, BaseCommands.DONATE, 'See donation information', self.donation_info)
        append_command_handler(client, BaseCommands.SAY, 'Admin command to make bot speak', self.say)

        append_command_handler(client, BaseCommands.TXT2IMG, 'Generate an image from a prompt', self.handle_request)
        append_command_handler(client, BaseCommands.REDO, 'Re-generate image using last prompt', self.handle_request)

        self.client = client
        self._main_room: DiscordChatRoom | None = None

    async def init(self):
        """Resolve the ``GROUP_ID`` channel into ``main_group``.

        Needs the client to be logged in, because it may fall back to an HTTP
        fetch. Errors from ``fetch_channel`` (e.g. unknown id) propagate.
        """
        # get_channel is a synchronous cache lookup (not awaitable) and the
        # cache is still empty before the gateway connects, so fall back to a fetch
        dc_channel = self.client.get_channel(GROUP_ID)
        if dc_channel is None:
            dc_channel = await self.client.fetch_channel(GROUP_ID)

        self._main_room = DiscordChatRoom(channel=dc_channel)
        logging.info('initialized')

    async def run(self):
        """Log in, resolve the main channel, then stay connected until shutdown."""
        # Client.run() blocks and starts its own event loop, so it cannot be
        # awaited from a coroutine; use the async login/connect pair instead.
        try:
            await self.client.login(self.config.token)
            await self.init()
            await self.client.connect()

        finally:
            await self.client.close()

    @property
    def main_group(self) -> DiscordChatRoom:
        """The chat room resolved by ``init()`` (``None`` before it runs)."""
        return self._main_room

    async def new_msg(self, chat: DiscordChatRoom, text: str, **kwargs) -> DiscordMessage:
        """Send ``text`` (plus discord.py ``send`` kwargs) to ``chat``."""
        dc_msg = await chat._channel.send(text, **kwargs)
        return DiscordMessage(cmd=None, msg=dc_msg)

    async def reply_to(self, msg: DiscordMessage, text: str, **kwargs) -> DiscordMessage:
        """Reply to ``msg`` with ``text`` (plus discord.py ``reply`` kwargs)."""
        dc_msg = await msg._msg.reply(content=text, **kwargs)
        return DiscordMessage(cmd=None, msg=dc_msg)

    async def edit_msg(self, msg: DiscordMessage, text: str, **kwargs):
        """Replace the content of ``msg`` with ``text`` (plus ``edit`` kwargs)."""
        await msg._msg.edit(content=text, **kwargs)

    async def create_status_msg(self, msg: DiscordMessage, init_text: str, force_user: DiscordUser | None = None) -> tuple[BaseUser, BaseMessage, dict]:
        """Post a 'live updates' embed for ``msg`` and start tracking the request.

        Args:
            msg: the user's command message.
            init_text: initial status text shown in the embed.
            force_user: attribute the request to this user instead of the author.

        Returns:
            ``(user, status_msg, user_row)`` where ``user_row`` comes from
            ``db.get_or_create_user``.
        """
        # maybe init user
        user = force_user if force_user else msg.author

        user_row = await self.db.get_or_create_user(user.id)

        # create status msg
        embed = discord.Embed(
            title='live updates',
            description=init_text,
            color=discord.Color.blue()
        )
        status_msg = await self.new_msg(msg.chat, None, embed=embed)

        # start tracking of request in db
        await self.db.new_user_request(user.id, msg.id, status_msg.id, status=init_text)
        return user, status_msg, user_row

    async def update_status_msg(self, msg: DiscordMessage, text: str):
        """Store ``text`` as the request status and show it in the status embed."""
        await self.db.update_user_request_by_sid(msg.id, text)
        embed = discord.Embed(
            title='live updates',
            description=text,
            color=discord.Color.blue()
        )
        await self.edit_msg(msg, None, embed=embed)

    async def append_status_msg(self, msg: DiscordMessage, text: str):
        """Append ``text`` to the stored request status and refresh the embed."""
        request = await self.db.get_user_request_by_sid(msg.id)
        await self.update_status_msg(msg, request['status'] + text)

    async def update_request_status_timeout(self, status_msg: DiscordMessage):
        """Mark the request as timed out in its status message."""
        await self.append_status_msg(
            status_msg,
            f'\n[{timestamp_pretty()}] **timeout processing request**',
        )

    async def update_request_status_step_0(self, status_msg: DiscordMessage, user_msg: DiscordMessage):
        """Reset the status to 'broadcasting' for the request in ``user_msg``.

        The command and requester come from ``user_msg``: ``status_msg`` is the
        bot's own message, so it carries no command and the bot as author.
        """
        await self.update_status_msg(
            status_msg,
            f'processing a \'{user_msg.command}\' request by {user_msg.author.name}\n'
            f'[{timestamp_pretty()}] *broadcasting transaction to chain...* '
        )

    async def update_request_status_step_1(self, status_msg: DiscordMessage, tx_result: dict):
        """Append the enqueue transaction link once the request is broadcast."""
        enqueue_tx_id = tx_result['transaction_id']
        enqueue_tx_link = f'[**Your request on Skynet Explorer**](https://{self.config.explorer_domain}/v2/explore/transaction/{enqueue_tx_id})'
        await self.append_status_msg(
            status_msg,
            '**broadcasted!** \n'
            f'{enqueue_tx_link}\n'
            f'[{timestamp_pretty()}] *workers are processing request...* '
        )

    async def update_request_status_step_2(self, status_msg: DiscordMessage, submit_tx_hash: str):
        """Append the result transaction link once a worker has submitted."""
        tx_link = f'[**Your result on Skynet Explorer**](https://{self.config.explorer_domain}/v2/explore/transaction/{submit_tx_hash})'
        await self.append_status_msg(
            status_msg,
            '**request processed!**\n'
            f'{tx_link}\n'
            f'[{timestamp_pretty()}] *trying to download image...*\n '
        )

    async def update_request_status_final(
        self,
        og_msg: DiscordMessage,
        status_msg: DiscordMessage,
        user: DiscordUser,
        params: BodyV0Params,
        inputs: list[DiscordFileInput],
        submit_tx_hash: str,
        worker: str,
        result_url: str,
        result_img: bytes | None
    ):
        """Replace the status message with the final result embed.

        If ``result_img`` is missing, the status message is kept and a link to
        ``result_url`` is appended instead. Otherwise the status message is
        deleted and a result embed is posted. When inputs exist, the last one
        is attached as the embed thumbnail.
        """
        embed = generate_reply_embed(
            self.config, user, params, submit_tx_hash, worker)

        if not result_img:
            # result found on chain but failed to fetch img from ipfs
            await self.append_status_msg(status_msg, f'[{timestamp_pretty()}] *Couldn\'t get IPFS hosted img [**here**]({result_url})!*')
            return

        await status_msg._msg.delete()

        embed.set_image(url=result_url)

        if not inputs:
            await self.new_msg(og_msg.chat, None, embed=embed)
            return

        _input = inputs[-1]
        raw = _input._raw
        if raw is None:
            # not pre-downloaded by the command wrapper; fetch it now
            raw = await _input.download()

        filename = f'image-{og_msg.id}.png'
        # discord.File wants a path or a binary file object; raw bytes would be
        # misread as a filesystem path, so wrap them in an in-memory buffer
        dc_file = discord.File(io.BytesIO(raw), filename=filename)
        embed.set_thumbnail(url=f'attachment://{filename}')
        await self.new_msg(og_msg.chat, None, embed=embed, file=dc_file)
|