skynet/skynet/frontend/chatbot/discord.py

407 lines
13 KiB
Python

import io
import logging
from collections.abc import Callable
from datetime import datetime, timezone
from typing import Awaitable, Self

import discord
from discord import (
    User as DCUser,
    Member,
    Message as DCMessage,
    Attachment,
    DMChannel
)
from discord.abc import Messageable
from discord.ext import commands

from skynet.config import FrontendConfig
from skynet.types import BodyV0Params
from skynet.constants import VERSION
from skynet.frontend.chatbot import BaseChatbot
from skynet.frontend.chatbot.db import FrontendUserDB
from skynet.frontend.chatbot.types import (
    BaseUser,
    BaseChatRoom,
    BaseFileInput,
    BaseCommands,
    BaseMessage
)
# Discord ids this frontend depends on: GROUP_ID is the channel resolved as
# the main group, ADMIN_USER_ID is the only user reported as admin.
# NOTE(review): -1 looks like an unset placeholder — confirm real ids are
# supplied before deployment.
GROUP_ID, ADMIN_USER_ID = -1, -1
def timestamp_pretty():
    """Return the current UTC wall-clock time formatted as ``HH:MM:SS``."""
    now_utc = datetime.now(tz=timezone.utc)
    return f'{now_utc:%H:%M:%S}'
class DiscordUser(BaseUser):
    """``BaseUser`` view over a discord.py ``User`` or guild ``Member``."""

    def __init__(self, user: DCUser | Member):
        self._user = user

    @property
    def id(self) -> int:
        """Discord id of the wrapped user."""
        wrapped = self._user
        return wrapped.id

    @property
    def name(self) -> str:
        """Discord username of the wrapped user."""
        wrapped = self._user
        return wrapped.name

    @property
    def is_admin(self) -> bool:
        """True only for the single user whose id equals ``ADMIN_USER_ID``."""
        return ADMIN_USER_ID == self.id
class DiscordChatRoom(BaseChatRoom):
    """``BaseChatRoom`` view over any discord.py ``Messageable`` channel."""

    def __init__(self, channel: Messageable):
        self._channel = channel

    @property
    def id(self) -> int:
        """Id of the wrapped channel."""
        channel = self._channel
        return channel.id

    @property
    def is_private(self) -> bool:
        """True when the wrapped channel is a direct-message ``DMChannel``."""
        channel = self._channel
        return isinstance(channel, DMChannel)
class DiscordFileInput(BaseFileInput):
    """File input backed either by a live Discord attachment or by stored values.

    Inputs received in a message wrap an ``Attachment``; inputs rebuilt from
    storage (see ``from_values``) only carry an ``id`` and a content id
    (``cid``) and cannot be downloaded.
    """

    def __init__(
        self,
        attachment: Attachment | None = None,
        id: int | None = None,
        cid: str | None = None
    ):
        self._attachment = attachment
        self._id = id
        self._cid = cid
        # raw file bytes, populated by ``download``
        self._raw: bytes | None = None

    @classmethod
    def from_values(cls, id: int, cid: str) -> Self:
        """Rebuild an input from a previously stored id and content id.

        Declared as a classmethod so it works when called on the class or on
        an instance (the undecorated original only worked on the class).
        """
        return cls(id=id, cid=cid)

    @property
    def id(self) -> int:
        """Explicit id if one was given, otherwise the attachment's id."""
        if self._id is not None:
            return self._id
        return self._attachment.id

    @property
    def cid(self) -> str:
        """Content id of the input.

        Raises:
            ValueError: if no cid was given at construction or via ``set_cid``.
        """
        if self._cid:
            return self._cid
        raise ValueError('cid has not been set for this input')

    def set_cid(self, cid: str):
        """Record the content id assigned to this input."""
        self._cid = cid

    async def download(self) -> bytes:
        """Read the attachment's bytes, cache them on ``_raw`` and return them.

        Raises:
            ValueError: if the input was built without an attachment.
        """
        if self._attachment is None:
            raise ValueError('input has no attachment to download')
        self._raw = await self._attachment.read()
        return self._raw
class DiscordMessage(BaseMessage):
    """Adapter exposing a discord.py ``Message`` through the ``BaseMessage`` API.

    ``cmd`` is the command the message invoked, or ``None`` for messages the
    bot creates itself (status and result messages).
    """

    def __init__(self, cmd: BaseCommands | None, msg: DCMessage):
        self._msg = msg
        self._cmd = cmd
        self._chat = DiscordChatRoom(msg.channel)
        # lazily built caches, filled by ``inputs`` and ``author``
        self._inputs: list[DiscordFileInput] | None = None
        self._author: DiscordUser | None = None

    @property
    def id(self) -> int:
        return self._msg.id

    @property
    def chat(self) -> DiscordChatRoom:
        return self._chat

    @property
    def text(self) -> str:
        """Message text with the leading ``/<command> `` removed.

        Messages without a command return their full content.
        """
        # discord.py exposes the message body as ``content``
        content = self._msg.content
        if self._cmd is None:
            return content
        # skip the slash, the command name and the separating space
        return content[len(self._cmd) + 2:]

    @property
    def author(self) -> DiscordUser:
        """Author of the message, wrapped once and then cached."""
        if self._author is None:
            self._author = DiscordUser(self._msg.author)
        return self._author

    @property
    def command(self) -> str | None:
        return self._cmd

    @property
    def inputs(self) -> list[DiscordFileInput]:
        """One ``DiscordFileInput`` per attachment, built on first access."""
        if self._inputs is None:
            self._inputs = [
                DiscordFileInput(attachment=attachment)
                for attachment in self._msg.attachments
            ]
        return self._inputs
def generate_reply_embed(
    config: FrontendConfig,
    user: DiscordUser,
    params: BodyV0Params,
    tx_hash: str,
    worker: str,
) -> discord.Embed:
    """Build the embed posted with a finished generation.

    The title links to ``tx_hash`` on the configured explorer. Three fields
    follow: general info (requester, worker, reward), the prompt cut to 256
    characters, and the inference parameters. Guidance, strength and upscaler
    are listed only when truthy.
    """
    explorer_url = (
        f'https://{config.explorer_domain}/v2/explore/transaction/{tx_hash}')
    embed = discord.Embed(
        title='[SKYNET Transaction Explorer]',
        url=explorer_url,
        color=discord.Color.blue())

    general_lines = [
        f'generated by {user.name}',
        f'performed by {worker}',
        f'reward: {config.reward}',
    ]
    general_info = ''.join(f'{line}\n' for line in general_lines)
    embed.add_field(
        name='General Info', value=f'```{general_info}```', inline=False)

    short_prompt = params.prompt[:256]
    embed.add_field(name='Prompt', value=f'```{short_prompt}\n```', inline=False)

    param_lines = [f'seed: {params.seed}', f'step: {params.step}']
    if params.guidance:
        param_lines.append(f'guidance: {params.guidance}')
    if params.strength:
        param_lines.append(f'strength: {params.strength}')
    param_lines.append(f'algo: {params.model}')
    if params.upscaler:
        param_lines.append(f'upscaler: {params.upscaler}')
    parameters = ''.join(f'{line}\n' for line in param_lines)
    embed.add_field(name='Parameters', value=f'```{parameters}```', inline=False)

    embed.set_footer(
        text=f'Made with Skynet v{VERSION}\n'
             'JOIN THE SWARM: https://discord.gg/PAabjJtZAF')
    return embed
def append_command_handler(
    client: commands.Bot,
    command: str,
    help_txt: str,
    fn: Callable[[DiscordMessage], Awaitable]
):
    """Register coroutine function ``fn`` as the handler of prefix command ``command``.

    ``client`` must be a ``commands.Bot``: plain ``discord.Client`` has no
    ``command`` decorator. The invoking message is wrapped in a
    ``DiscordMessage``, and every attachment is downloaded (filling each
    input's raw bytes) before ``fn`` is awaited with that message.
    """
    @client.command(name=command, help=help_txt)
    async def wrap_msg_and_handle(ctx: commands.Context):
        msg = DiscordMessage(cmd=command, msg=ctx.message)
        for file in msg.inputs:
            await file.download()
        await fn(msg)
class DiscordChatbot(BaseChatbot):
    """Discord frontend for Skynet built on ``discord.ext.commands.Bot``.

    Registers one '/'-prefixed command per ``BaseCommands`` entry and reports
    request progress in a "live updates" embed. That embed is edited in place
    and then deleted and replaced by a result embed once the image is
    available.
    """

    def __init__(
        self,
        config: FrontendConfig,
        db: FrontendUserDB,
    ):
        super().__init__(config, db)

        intents = discord.Intents(
            messages=True,
            guilds=True,
            typing=True,
            members=True,
            presences=True,
            reactions=True,
            message_content=True,
            voice_states=True
        )

        # Prefix commands (``command_prefix``, ``.command()``,
        # ``process_commands``) exist only on ``commands.Bot``, not on
        # ``discord.Client``. The built-in help command is disabled so the
        # BaseCommands.HELP handler below can register without a name clash
        # (assumes BaseCommands.HELP is 'help', as the welcome text suggests).
        client = commands.Bot(
            command_prefix='/',
            intents=intents,
            help_command=None
        )

        @client.event
        async def on_ready():
            # NOTE(review): discord.py may dispatch on_ready again after a
            # reconnect, which would repeat these greetings.
            print(f'{client.user.name} has connected to Discord!')
            for guild in client.guilds:
                # only text channels can be sent to; a category named
                # "skynet" would otherwise raise on ``send``
                for channel in guild.text_channels:
                    if channel.name == "skynet":
                        await channel.send('Skynet bot online')
                        await channel.send(
                            'Welcome to Skynet\'s Discord Bot,\n\n'
                            'Skynet operates as a decentralized compute layer, offering a wide array of '
                            'support for diverse AI paradigms through the use of blockchain technology. '
                            'Our present focus is image generation, powered by 11 distinct models.\n\n'
                            'To begin exploring, use the \'/help\' command or directly interact with the '
                            'provided buttons. Here is an example command to generate an image:\n\n'
                            '\'/txt2img a big red tractor in a giant field of corn\''
                        )
            print("\n==============")
            print("Logged in as")
            print(client.user.name)
            print(client.user.id)
            print("==============")

        @client.event
        async def on_message(message: DCMessage):
            if message.author == client.user:
                return
            # overriding on_message replaces Bot's default dispatch, so
            # commands must be processed explicitly on the bot
            await client.process_commands(message)

        @client.event
        async def on_command_error(ctx: commands.Context, error: commands.CommandError):
            if isinstance(error, commands.MissingRequiredArgument):
                await ctx.send('You missed a required argument, please try again.')
            elif isinstance(error, commands.CommandNotFound):
                # any "/..." chat line that is not a registered command
                return
            else:
                # this handler replaces discord.py's default error printing,
                # so log everything else instead of dropping it
                logging.error(
                    'error while handling command %s', ctx.command,
                    exc_info=error)

        append_command_handler(client, BaseCommands.HELP, 'Responds with help text', self.send_help)
        append_command_handler(client, BaseCommands.COOL, 'Display a list of cool prompt words', self.send_cool_words)
        append_command_handler(client, BaseCommands.QUEUE, 'Get information on current skynet queue', self.get_queue)
        append_command_handler(client, BaseCommands.CONFIG, 'Allows user to configure inference params', self.set_config)
        append_command_handler(client, BaseCommands.STATS, 'See user statistics', self.user_stats)
        append_command_handler(client, BaseCommands.DONATE, 'See donation information', self.donation_info)
        append_command_handler(client, BaseCommands.SAY, 'Admin command to make bot speak', self.say)
        append_command_handler(client, BaseCommands.TXT2IMG, 'Generate an image from a prompt', self.handle_request)
        append_command_handler(client, BaseCommands.REDO, 'Re-generate image using last prompt', self.handle_request)

        self.client = client
        self._main_room: DiscordChatRoom | None = None

    async def init(self):
        """Resolve ``GROUP_ID`` into the main chat room.

        ``get_channel`` is a synchronous cache lookup (not awaitable), and the
        cache is empty until the gateway connects. A miss therefore falls back
        to ``fetch_channel``, an HTTP call that needs the client logged in.
        """
        dc_channel = self.client.get_channel(GROUP_ID)
        if dc_channel is None:
            dc_channel = await self.client.fetch_channel(GROUP_ID)
        self._main_room = DiscordChatRoom(channel=dc_channel)
        logging.info('initialized')

    async def run(self):
        """Log in, resolve the main room, then connect to the gateway.

        ``Client.run`` is blocking and starts its own event loop, so it cannot
        be awaited from a coroutine. ``login`` + ``connect`` is the async
        equivalent of ``Client.start``, split so ``init`` can use HTTP before
        connecting. The ``async with`` block closes the client on exit.
        """
        async with self.client:
            await self.client.login(self.config.token)
            await self.init()
            await self.client.connect()

    @property
    def main_group(self) -> DiscordChatRoom:
        return self._main_room

    async def new_msg(self, chat: DiscordChatRoom, text: str, **kwargs) -> DiscordMessage:
        """Send ``text`` (plus send kwargs such as ``embed``/``file``) to ``chat``."""
        dc_msg = await chat._channel.send(text, **kwargs)
        return DiscordMessage(cmd=None, msg=dc_msg)

    async def reply_to(self, msg: DiscordMessage, text: str, **kwargs) -> DiscordMessage:
        """Reply to ``msg`` with ``text``."""
        dc_msg = await msg._msg.reply(content=text, **kwargs)
        return DiscordMessage(cmd=None, msg=dc_msg)

    async def edit_msg(self, msg: DiscordMessage, text: str, **kwargs):
        """Replace the content of ``msg`` (``None`` clears it)."""
        await msg._msg.edit(content=text, **kwargs)

    async def create_status_msg(self, msg: DiscordMessage, init_text: str, force_user: DiscordUser | None = None) -> tuple[BaseUser, BaseMessage, dict]:
        """Post a "live updates" embed for ``msg`` and start tracking it in the db.

        Returns ``(user, status_msg, user_row)``. ``user`` is ``force_user``
        when given, otherwise the message author.
        """
        # maybe init user
        user = msg.author
        if force_user:
            user = force_user

        user_row = await self.db.get_or_create_user(user.id)

        # create status msg
        embed = discord.Embed(
            title='live updates',
            description=init_text,
            color=discord.Color.blue()
        )
        status_msg = await self.new_msg(msg.chat, None, embed=embed)

        # start tracking of request in db
        await self.db.new_user_request(user.id, msg.id, status_msg.id, status=init_text)

        return user, status_msg, user_row

    async def update_status_msg(self, msg: DiscordMessage, text: str):
        """Persist ``text`` as the request status and show it in the embed."""
        await self.db.update_user_request_by_sid(msg.id, text)
        embed = discord.Embed(
            title='live updates',
            description=text,
            color=discord.Color.blue()
        )
        await self.edit_msg(msg, None, embed=embed)

    async def append_status_msg(self, msg: DiscordMessage, text: str):
        """Append ``text`` to the stored status of the request tracked by ``msg``."""
        request = await self.db.get_user_request_by_sid(msg.id)
        await self.update_status_msg(msg, request['status'] + text)

    async def update_request_status_timeout(self, status_msg: DiscordMessage):
        await self.append_status_msg(
            status_msg,
            f'\n[{timestamp_pretty()}] **timeout processing request**',
        )

    async def update_request_status_step_0(self, status_msg: DiscordMessage, user_msg: DiscordMessage):
        # The status message is sent by the bot with no command attached, so
        # the command and requester are taken from the user's message.
        await self.update_status_msg(
            status_msg,
            f'processing a \'{user_msg.command}\' request by {user_msg.author.name}\n'
            f'[{timestamp_pretty()}] *broadcasting transaction to chain...* '
        )

    async def update_request_status_step_1(self, status_msg: DiscordMessage, tx_result: dict):
        enqueue_tx_id = tx_result['transaction_id']
        enqueue_tx_link = f'[**Your request on Skynet Explorer**](https://{self.config.explorer_domain}/v2/explore/transaction/{enqueue_tx_id})'
        await self.append_status_msg(
            status_msg,
            '**broadcasted!** \n'
            f'{enqueue_tx_link}\n'
            f'[{timestamp_pretty()}] *workers are processing request...* '
        )

    async def update_request_status_step_2(self, status_msg: DiscordMessage, submit_tx_hash: str):
        tx_link = f'[**Your result on Skynet Explorer**](https://{self.config.explorer_domain}/v2/explore/transaction/{submit_tx_hash})'
        await self.append_status_msg(
            status_msg,
            '**request processed!**\n'
            f'{tx_link}\n'
            f'[{timestamp_pretty()}] *trying to download image...*\n '
        )

    async def update_request_status_final(
        self,
        og_msg: DiscordMessage,
        status_msg: DiscordMessage,
        user: DiscordUser,
        params: BodyV0Params,
        inputs: list[DiscordFileInput],
        submit_tx_hash: str,
        worker: str,
        result_url: str,
        result_img: bytes | None
    ):
        """Replace the status embed with the result embed.

        If the image could not be fetched, the status message is kept and
        given a link to ``result_url`` instead. When the request had inputs,
        the last one is attached and shown as the embed thumbnail.
        """
        if not result_img:
            # result found on chain but failed to fetch img from ipfs
            await self.append_status_msg(status_msg, f'[{timestamp_pretty()}] *Couldn\'t get IPFS hosted img [**here**]({result_url})!*')
            return

        embed = generate_reply_embed(
            self.config, user, params, submit_tx_hash, worker)

        await status_msg._msg.delete()

        embed.set_image(url=result_url)
        if not inputs:
            await self.new_msg(og_msg.chat, None, embed=embed)
            return

        filename = f'image-{og_msg.id}.png'
        # discord.File treats str/bytes as a filesystem path, so the raw
        # image bytes must be wrapped in a binary file-like object
        dc_file = discord.File(io.BytesIO(inputs[-1]._raw), filename=filename)
        embed.set_thumbnail(url=f'attachment://{filename}')
        await self.new_msg(og_msg.chat, None, embed=embed, file=dc_file)