diff --git a/skynet/cli.py b/skynet/cli.py
index 58297ca..80799ae 100755
--- a/skynet/cli.py
+++ b/skynet/cli.py
@@ -210,50 +210,22 @@ def telegram(
db_pass: str
):
import asyncio
- from .frontend.telegram import SkynetTelegramFrontend
+ from skynet.frontend.chatbot.telegram import TelegramChatbot
+ from skynet.frontend.chatbot.db import FrontendUserDB
logging.basicConfig(level=loglevel)
- config = load_skynet_toml()
- tg_token = config.telegram.tg_token
-
- key = config.telegram.key
- account = config.telegram.account
- permission = config.telegram.permission
- node_url = config.telegram.node_url
- hyperion_url = config.telegram.hyperion_url
-
- ipfs_url = config.telegram.ipfs_url
-
- try:
- explorer_domain = config.telegram.explorer_domain
-
- except ConfigParsingError:
- explorer_domain = DEFAULT_EXPLORER_DOMAIN
-
- try:
- ipfs_domain = config.telegram.ipfs_domain
-
- except ConfigParsingError:
- ipfs_domain = DEFAULT_IPFS_DOMAIN
+ config = load_skynet_toml().telegram
async def _async_main():
- frontend = SkynetTelegramFrontend(
- tg_token,
- account,
- permission,
- node_url,
- hyperion_url,
- db_host, db_user, db_pass,
- ipfs_url,
- key=key,
- explorer_domain=explorer_domain,
- ipfs_domain=ipfs_domain
- )
-
- async with frontend.open():
- await frontend.bot.infinity_polling()
-
+ async with FrontendUserDB(
+ config.db_user,
+ config.db_pass,
+ config.db_host,
+ config.db_name
+ ) as db:
+ bot = TelegramChatbot(config, db)
+ await bot.run()
asyncio.run(_async_main())
diff --git a/skynet/config.py b/skynet/config.py
index fed610d..3310c4f 100755
--- a/skynet/config.py
+++ b/skynet/config.py
@@ -36,6 +36,13 @@ class FrontendConfig(msgspec.Struct):
hyperion_url: str
ipfs_url: str
token: str
+ ipfs_domain: str = 'ipfs.skygpu.net'
+ explorer_domain: str = 'explorer.skygpu.net'
+ proto_version: int = 0
+ reward: str = '20.0000 GPU'
+ receiver: str = 'gpu.scd'
+ result_max_width: int = 1280
+ result_max_height: int = 1280
class PinnerConfig(msgspec.Struct):
diff --git a/skynet/constants.py b/skynet/constants.py
index 4b09f6e..4de9b23 100755
--- a/skynet/constants.py
+++ b/skynet/constants.py
@@ -14,91 +14,91 @@ MODELS: dict[str, ModelDesc] = {
'runwayml/stable-diffusion-v1-5': ModelDesc(
short='stable',
mem=6,
- attrs={'size': {'w': 512, 'h': 512}},
+ attrs={'size': {'w': 512, 'h': 512}, 'step': 28},
tags=['txt2img']
),
'stabilityai/stable-diffusion-2-1-base': ModelDesc(
short='stable2',
mem=6,
- attrs={'size': {'w': 512, 'h': 512}},
+ attrs={'size': {'w': 512, 'h': 512}, 'step': 28},
tags=['txt2img']
),
'snowkidy/stable-diffusion-xl-base-0.9': ModelDesc(
short='stablexl0.9',
mem=8.3,
- attrs={'size': {'w': 1024, 'h': 1024}},
+ attrs={'size': {'w': 1024, 'h': 1024}, 'step': 28},
tags=['txt2img']
),
'Linaqruf/anything-v3.0': ModelDesc(
short='hdanime',
mem=6,
- attrs={'size': {'w': 512, 'h': 512}},
+ attrs={'size': {'w': 512, 'h': 512}, 'step': 28},
tags=['txt2img']
),
'hakurei/waifu-diffusion': ModelDesc(
short='waifu',
mem=6,
- attrs={'size': {'w': 512, 'h': 512}},
+ attrs={'size': {'w': 512, 'h': 512}, 'step': 28},
tags=['txt2img']
),
'nitrosocke/Ghibli-Diffusion': ModelDesc(
short='ghibli',
mem=6,
- attrs={'size': {'w': 512, 'h': 512}},
+ attrs={'size': {'w': 512, 'h': 512}, 'step': 28},
tags=['txt2img']
),
'dallinmackay/Van-Gogh-diffusion': ModelDesc(
short='van-gogh',
mem=6,
- attrs={'size': {'w': 512, 'h': 512}},
+ attrs={'size': {'w': 512, 'h': 512}, 'step': 28},
tags=['txt2img']
),
'lambdalabs/sd-pokemon-diffusers': ModelDesc(
short='pokemon',
mem=6,
- attrs={'size': {'w': 512, 'h': 512}},
+ attrs={'size': {'w': 512, 'h': 512}, 'step': 28},
tags=['txt2img']
),
'Envvi/Inkpunk-Diffusion': ModelDesc(
short='ink',
mem=6,
- attrs={'size': {'w': 512, 'h': 512}},
+ attrs={'size': {'w': 512, 'h': 512}, 'step': 28},
tags=['txt2img']
),
'nousr/robo-diffusion': ModelDesc(
short='robot',
mem=6,
- attrs={'size': {'w': 512, 'h': 512}},
+ attrs={'size': {'w': 512, 'h': 512}, 'step': 28},
tags=['txt2img']
),
'black-forest-labs/FLUX.1-schnell': ModelDesc(
short='flux',
mem=24,
- attrs={'size': {'w': 1024, 'h': 1024}},
+ attrs={'size': {'w': 1024, 'h': 1024}, 'step': 4},
tags=['txt2img']
),
'black-forest-labs/FLUX.1-Fill-dev': ModelDesc(
short='flux-inpaint',
mem=24,
- attrs={'size': {'w': 1024, 'h': 1024}},
+ attrs={'size': {'w': 1024, 'h': 1024}, 'step': 28},
tags=['inpaint']
),
'diffusers/stable-diffusion-xl-1.0-inpainting-0.1': ModelDesc(
short='stablexl-inpaint',
mem=8.3,
- attrs={'size': {'w': 1024, 'h': 1024}},
+ attrs={'size': {'w': 1024, 'h': 1024}, 'step': 28},
tags=['inpaint']
),
'prompthero/openjourney': ModelDesc(
short='midj',
mem=6,
- attrs={'size': {'w': 512, 'h': 512}},
+ attrs={'size': {'w': 512, 'h': 512}, 'step': 28},
tags=['txt2img', 'img2img']
),
'stabilityai/stable-diffusion-xl-base-1.0': ModelDesc(
short='stablexl',
mem=8.3,
- attrs={'size': {'w': 1024, 'h': 1024}},
+ attrs={'size': {'w': 1024, 'h': 1024}, 'step': 28},
tags=['txt2img']
),
}
@@ -225,8 +225,6 @@ Noise is added to the image you use as an init image for img2img, and then the\
HELP_UNKWNOWN_PARAM = 'don\'t have any info on that.'
-GROUP_ID = -1001541979235
-
MP_ENABLED_ROLES = ['god']
MIN_STEP = 1
diff --git a/skynet/frontend/chatbot/__init__.py b/skynet/frontend/chatbot/__init__.py
new file mode 100644
index 0000000..26dc7ff
--- /dev/null
+++ b/skynet/frontend/chatbot/__init__.py
@@ -0,0 +1,448 @@
+import io
+import json
+import asyncio
+import logging
+from abc import ABC, abstractmethod, abstractproperty
+from PIL import Image, UnidentifiedImageError
+from random import randint
+from decimal import Decimal
+from hashlib import sha256
+from datetime import datetime, timedelta
+
+import msgspec
+from leap import CLEOS
+from leap.hyperion import HyperionAPI
+
+from skynet.ipfs import AsyncIPFSHTTP, get_ipfs_file
+from skynet.types import BodyV0, BodyV0Params
+from skynet.config import FrontendConfig
+from skynet.constants import (
+ MODELS, GPU_CONTRACT_ABI,
+ HELP_TEXT,
+ HELP_TOPICS,
+ HELP_UNKWNOWN_PARAM,
+ COOL_WORDS,
+ DONATION_INFO
+)
+from skynet.frontend import validate_user_config_request
+from skynet.frontend.chatbot.db import FrontendUserDB
+from skynet.frontend.chatbot.types import (
+ BaseUser,
+ BaseChatRoom,
+ BaseCommands,
+ BaseFileInput,
+ BaseMessage
+)
+
+
+def perform_auto_conf(config: dict) -> dict:
+ model = MODELS[config['model']]
+
+ maybe_step = model.attrs.get('step', None)
+ if maybe_step:
+ config['step'] = maybe_step
+
+ maybe_width = model.attrs.get('width', None)
+ if maybe_width:
+ config['width'] = maybe_step
+
+ maybe_height = model.attrs.get('height', None)
+ if maybe_height:
+ config['height'] = maybe_step
+
+ return config
+
+
+def sanitize_params(params: dict) -> dict:
+ if 'seed' not in params:
+ params['seed'] = randint(0, 0xffffffff)
+
+ s_params = {}
+ for key, val in params.items():
+ if isinstance(val, Decimal):
+ val = str(val)
+
+ s_params[key] = val
+
+ return s_params
+
+
+class RequestTimeoutError(BaseException):
+ ...
+
+
+class BaseChatbot(ABC):
+
+ def __init__(
+ self,
+ config: FrontendConfig,
+ db: FrontendUserDB
+ ):
+ self.db = db
+ self.config = config
+ self.ipfs = AsyncIPFSHTTP(config.ipfs_url)
+ self.cleos = CLEOS(endpoint=config.node_url)
+ self.cleos.load_abi('gpu.scd', GPU_CONTRACT_ABI)
+ self.hyperion = HyperionAPI(config.hyperion_url)
+
+ async def init(self):
+ ...
+
+ @abstractmethod
+ async def run(self):
+ ...
+
+ @abstractproperty
+ def main_group(self) -> BaseChatRoom:
+ ...
+
+ @abstractmethod
+ async def new_msg(self, chat: BaseChatRoom, text: str, **kwargs) -> BaseMessage:
+ '''
+ Send text to a chat/channel.
+ '''
+ ...
+
+ @abstractmethod
+ async def reply_to(self, msg: BaseMessage, text: str, **kwargs) -> BaseMessage:
+ '''
+ Reply to existing message by sending new message.
+ '''
+ ...
+
+ @abstractmethod
+ async def edit_msg(self, msg: BaseMessage, text: str, **kwargs):
+ '''
+ Edit an existing message.
+ '''
+ ...
+
+ async def create_status_msg(self, msg: BaseMessage, init_text: str) -> tuple[BaseUser, BaseMessage, dict]:
+ # maybe init user
+ user = msg.author
+ user_row = await self.db.get_or_create_user(user.id)
+
+ # create status msg
+ status_msg = await self.reply_to(msg, init_text)
+
+ # start tracking of request in db
+ await self.db.new_user_request(user.id, msg.id, status_msg.id, status=init_text)
+ return [user, status_msg, user_row]
+
+ async def update_status_msg(self, msg: BaseMessage, text: str):
+ '''
+ Update an existing status message, also mirrors changes on db
+ '''
+ await self.db.update_user_request_by_sid(msg.id, text)
+ await self.edit_msg(msg, text)
+
+ async def append_status_msg(self, msg: BaseMessage, text: str):
+ '''
+ Append text to an existing status message
+ '''
+ request = await self.db.get_user_request_by_sid(msg.id)
+ await self.update_status_msg(msg, request['status'] + text)
+
+ @abstractmethod
+ async def update_request_status_timeout(self, status_msg: BaseMessage):
+ '''
+ Notify users when we timedout trying to find a matching submit
+ '''
+ ...
+
+ @abstractmethod
+ async def update_request_status_step_0(self, status_msg: BaseMessage, user: BaseUser):
+ '''
+ First step in request status message lifecycle, should notify which user sent the request
+ and that we are about to broadcast the request to chain
+ '''
+ ...
+
+ @abstractmethod
+ async def update_request_status_step_1(self, status_msg: BaseMessage, tx_result: dict):
+ '''
+ Second step in request status message lifecycle, should notify enqueue transaction
+ was processed by chain, and provide a link to the tx in the chain explorer
+ '''
+ ...
+
+ @abstractmethod
+ async def update_request_status_step_2(self, status_msg: BaseMessage, submit_tx_hash: str):
+ '''
+ Third step in request status message lifecycle, should notify matching submit transaction
+ was found, and provide a link to the tx in the chain explorer
+ '''
+ ...
+
+ @abstractmethod
+ async def update_request_status_final(
+ self,
+ og_msg: BaseMessage,
+ status_msg: BaseMessage,
+ user: BaseUser,
+ params: BodyV0Params,
+ inputs: list[BaseFileInput],
+ submit_tx_hash: str,
+ worker: str,
+ result_img: bytes | None
+ ):
+ '''
+ Last step in request status message lifecycle, should delete status message and send a
+ new message replying to the original user's message, generate the appropiate
+ reply caption and if provided also sent the found result img
+ '''
+ ...
+
+
+ async def handle_request(
+ self,
+ msg: BaseMessage
+ ):
+ if msg.chat.is_private:
+ return
+
+ if len(msg.text) == 0:
+ await self.reply_to(msg.id, 'empty prompt ignored.')
+ return
+
+ # maybe initialize user db row and send a new msg thats gonna
+ # be updated throughout the request lifecycle
+ user, status_msg, user_row = await self.create_status_msg(
+ msg, 'started processing a {msg.command} request...')
+
+ # if this is a redo msg, we attempt to get the input params from db
+ # else use msg properties
+ match msg.command:
+ case BaseCommands.TXT2IMG | BaseCommands.IMG2IMG:
+ prompt = msg.text
+ command = msg.command
+ inputs = msg.inputs
+
+ case BaseCommands.REDO:
+ prompt = await self.db.get_last_prompt_of(user.id)
+ command = await self.db.get_last_method_of(user.id)
+ inputs = await self.db.get_last_inputs_of(user.id)
+
+ if not prompt:
+ await self.reply_to(msg.id, 'no last prompt found, try doing a non-redo request first')
+ return
+
+ case _:
+ await self.reply_to(msg.id, f'unknown request of type {msg.command}')
+ return
+
+ # maybe apply recomended settings to this request
+ del user_row['id']
+ if user_row['autoconf']:
+ user_row = perform_auto_conf(user_row)
+
+ body = BodyV0(
+ method=command,
+ params=BodyV0Params(
+ prompt=prompt,
+ **user_row
+ )
+ )
+
+ # publish inputs to ipfs
+ input_cids = []
+ for i in inputs:
+ i.publish()
+ input_cids.append(i.cid)
+
+ inputs_str = ','.join((i for i in input_cids))
+
+ # unless its a redo request, update db user data
+ if command != BaseCommands.REDO:
+ await self.db.update_user_stats(
+ user.id,
+ command,
+ last_prompt=prompt,
+ last_inputs=inputs
+ )
+
+ await self.update_request_status_step_0(status_msg)
+
+ # prepare and send enqueue request
+ request_time = datetime.now().isoformat()
+ str_body = msgspec.json.encode(body).decode('utf-8')
+
+ enqueue_receipt = await self.cleos.a_push_action(
+ self.receiver,
+ 'enqueue',
+ (
+ self.config.account,
+ str_body,
+ inputs_str,
+ self.config.reward,
+ 1
+ )
+ )
+
+ await self.update_request_status_step_1(status_msg, enqueue_receipt)
+
+ # wait and search submit request using hyperion endpoint
+ console = enqueue_receipt['processed']['action_traces'][0]['console']
+ console_lines = console.split('\n')
+
+ request_id = None
+ request_hash = None
+ if self.config.proto_version == 0:
+ '''
+ v0 has req_id:nonce printed in enqueue console output
+ to search for a result request_hash arg on submit has
+ to match the sha256 of nonce + body + input_str
+ '''
+ request_id, nonce = console_lines[-1].rstrip().split(':')
+ request_hash = sha256(
+ (nonce + str_body + inputs_str).encode('utf-8')).hexdigest.upper()
+
+ request_id = int(request_id)
+
+ elif self.config.proto_version == 1:
+ '''
+ v1 uses a global unique nonce and prints it on enqueue
+ console output to search for a result request_id arg
+ on submit has to match the printed req_id
+ '''
+ request_id = int(console_lines[-1].rstrip())
+
+ else:
+ raise NotImplementedError
+
+ worker = None
+ submit_tx_hash = None
+ result_cid = None
+ for i in range(1, self.config.request_timeout + 1):
+ try:
+ submits = await self.hyperion.aget_actions(
+ account=self.config.account,
+ filter=f'{self.config.receiver}:submit',
+ sort='desc',
+ after=request_time
+ )
+ if self.config.proto_version == 0:
+ actions = [
+ action
+ for action in submits['actions']
+ if action['act']['data']['request_hash'] == request_hash
+ ]
+ elif self.config.proto_version == 1:
+ actions = [
+ action
+ for action in submits['actions']
+ if action['act']['data']['request_id'] == request_id
+ ]
+
+ else:
+ raise NotImplementedError
+
+ if len(actions) > 0:
+ action = actions[0]
+ submit_tx_hash = action['trx_id']
+ data = action['act']['data']
+ result_cid = data['ipfs_hash']
+ worker = data['worker']
+ logging.info('found matching submit! tx: {submit_tx_hash} cid: {ipfs_hash}')
+ break
+
+ except json.JSONDecodeError:
+ if i < self.config.request_timeout:
+ logging.error('network error while searching for submit, retry...')
+ await asyncio.sleep(1)
+
+ # if we found matching submit submit_tx_hash, worker, and result_cid will not be None
+ if not result_cid:
+ await self.update_request_status_timeout(status_msg)
+ raise RequestTimeoutError
+
+ await self.update_request_status_step_2(status_msg, submit_tx_hash)
+
+ # attempt to get the image and send it
+ result_link = f'https://{self.config.ipfs_domain}/ipfs/{result_cid}'
+ get_img_response = await get_ipfs_file(result_link)
+
+ result_img = None
+ if get_img_response and get_img_response.status_code == 200:
+ try:
+ with Image.open(io.BytesIO(get_img_response.raw)) as img:
+ w, h = img.size
+
+ if (
+ w > self.config.result_max_width
+ or
+ h > self.config.result_max_height
+ ):
+ max_size = (self.config.result_max_width, self.config.result_max_height)
+ logging.warning(
+ f'raw result is of size {img.size}, resizing to {max_size}')
+ img.thumbnail(max_size)
+
+ tmp_buf = io.BytesIO()
+ img.save(tmp_buf, format='PNG')
+ result_img = tmp_buf.getvalue()
+
+ except UnidentifiedImageError:
+ logging.warning(f'couldn\'t get ipfs result at {result_link}!')
+
+ await self.update_request_status_final(
+ msg, status_msg, user, body, inputs, submit_tx_hash, worker, result_img)
+
+ await self.db.increment_generated(user.id)
+
+ async def send_help(self, msg: BaseMessage):
+ if len(msg.text) == 0:
+ await self.reply_to(msg, HELP_TEXT)
+
+ else:
+ if msg.text in HELP_TOPICS:
+ await self.reply_to(msg, HELP_TOPICS[msg.text])
+
+ else:
+ await self.reply_to(msg, HELP_UNKWNOWN_PARAM)
+
+ async def send_cool_words(self, msg: BaseMessage):
+ await self.reply_to(msg, '\n'.join(COOL_WORDS))
+
+ async def get_queue(self, msg: BaseMessage):
+ an_hour_ago = datetime.now() - timedelta(hours=1)
+ queue = await self.cleos.aget_table(
+ self.config.receiver, self.config.receiver, 'queue',
+ index_position=2,
+ key_type='i64',
+ sort='desc',
+ lower_bound=int(an_hour_ago.timestamp())
+ )
+ await self.reply_to(
+ msg, f'Requests on skynet queue: {len(queue)}')
+
+ async def set_config(self, msg: BaseMessage):
+ try:
+ attr, val, reply_txt = validate_user_config_request(msg.text)
+
+ await self.db.update_user_config(msg.author.id, attr, val)
+
+ except BaseException as e:
+ reply_txt = str(e)
+
+ finally:
+ await self.reply_to(msg, reply_txt)
+
+ async def user_stats(self, msg: BaseMessage):
+ await self.db.get_or_create_user(msg.author.id)
+ generated, joined, role = await self.db.get_user_stats(msg.author.id)
+
+ stats_str = f'generated: {generated}\n'
+ stats_str += f'joined: {joined}\n'
+ stats_str += f'role: {role}\n'
+
+ await self.reply_to(msg, stats_str)
+
+ async def donation_info(self, msg: BaseMessage):
+ await self.reply_to(msg, DONATION_INFO)
+
+ async def say(self, msg: BaseMessage):
+ if not msg.chat.is_private or not msg.author.is_admin:
+ return
+
+ await self.new_msg(self.main_group, msg.text)
diff --git a/skynet/frontend/chatbot/db.py b/skynet/frontend/chatbot/db.py
new file mode 100644
index 0000000..3e36bff
--- /dev/null
+++ b/skynet/frontend/chatbot/db.py
@@ -0,0 +1,440 @@
+import logging
+import random
+import string
+import time
+from datetime import datetime
+
+import docker
+import psycopg2
+import asyncpg
+
+from psycopg2.extensions import ISOLATION_LEVEL_AUTOCOMMIT
+from contextlib import contextmanager as cm
+
+from skynet.constants import (
+ DEFAULT_ROLE, DEFAULT_MODEL, DEFAULT_STEP,
+ DEFAULT_WIDTH, DEFAULT_HEIGHT, DEFAULT_GUIDANCE,
+ DEFAULT_STRENGTH, DEFAULT_UPSCALER
+)
+from skynet.frontend.chatbot.types import BaseFileInput
+
+DB_INIT_SQL = """
+CREATE SCHEMA IF NOT EXISTS skynet;
+
+CREATE TABLE IF NOT EXISTS skynet.user(
+ id BIGSERIAL PRIMARY KEY NOT NULL,
+ generated INT NOT NULL,
+ joined TIMESTAMP NOT NULL,
+ last_method TEXT,
+ last_prompt TEXT,
+ last_inputs TEXT,
+ role VARCHAR(128) NOT NULL
+);
+
+CREATE TABLE IF NOT EXISTS skynet.user_config(
+ id BIGSERIAL NOT NULL,
+ model VARCHAR(512) NOT NULL,
+ step INT NOT NULL,
+ width INT NOT NULL,
+ height INT NOT NULL,
+ seed NUMERIC,
+ guidance DECIMAL NOT NULL,
+ strength DECIMAL NOT NULL,
+ upscaler VARCHAR(128),
+ autoconf BOOLEAN DEFAULT TRUE,
+ CONSTRAINT fk_config
+ FOREIGN KEY(id)
+ REFERENCES skynet.user(id)
+);
+
+CREATE TABLE IF NOT EXISTS skynet.user_requests(
+ id BIGSERIAL NOT NULL,
+ user_id BIGSERIAL NOT NULL,
+ sent TIMESTAMP NOT NULL,
+ status TEXT NOT NULL,
+ status_msg BIGSERIAL PRIMARY KEY NOT NULL,
+ CONSTRAINT fk_user_req
+ FOREIGN KEY(user_id)
+ REFERENCES skynet.user(id)
+);
+"""
+
+def try_decode_uid(uid: str) -> tuple[str | None, int | None]:
+ """
+ Attempts to decode the user ID. The user ID can be just an integer
+ or of the format 'proto+uid'. Returns (None, int) if it's just an
+ integer or (proto, int) if it's 'proto+uid'. Returns (None, None)
+ if neither format is valid.
+ """
+ try:
+ return None, int(uid)
+ except ValueError:
+ pass
+
+ try:
+ proto, uid_str = uid.split("+", 1)
+ return proto, int(uid_str)
+ except ValueError:
+ logging.warning(f"Got non-chat-proto UID?: {uid}")
+ return None, None
+
+
+@cm
+def open_new_database(cleanup: bool = True):
+ """
+ Context manager that spins up a temporary Postgres Docker container,
+ creates a 'skynet' user and database, and yields (container, password, host).
+ Stops the container on exit if 'cleanup' is True.
+ """
+ root_password = "".join(random.choice(string.ascii_lowercase) for _ in range(12))
+ skynet_password = "".join(random.choice(string.ascii_lowercase) for _ in range(12))
+
+ dclient = docker.from_env()
+ container = dclient.containers.run(
+ "postgres",
+ name="skynet-test-postgres",
+ ports={"5432/tcp": None},
+ environment={"POSTGRES_PASSWORD": root_password},
+ detach=True,
+ )
+
+ try:
+ # Wait for Postgres to be ready
+ for log_line in container.logs(stream=True):
+ line = log_line.decode().rstrip()
+ logging.info(line)
+ if (
+ "database system is ready to accept connections" in line
+ or "database system is shut down" in line
+ ):
+ break
+
+ container.reload()
+ port_info = container.ports["5432/tcp"][0]
+ port = port_info["HostPort"]
+ db_host = f"localhost:{port}"
+
+ # Let PostgreSQL settle
+ time.sleep(1)
+ logging.info("Creating 'skynet' database...")
+
+ with psycopg2.connect(
+ user="postgres", password=root_password, host="localhost", port=port
+ ) as conn:
+ conn.set_isolation_level(ISOLATION_LEVEL_AUTOCOMMIT)
+ with conn.cursor() as cursor:
+ cursor.execute(f"CREATE USER skynet WITH PASSWORD '{skynet_password}'")
+ cursor.execute("CREATE DATABASE skynet")
+ cursor.execute("GRANT ALL PRIVILEGES ON DATABASE skynet TO skynet")
+
+ logging.info("Database setup complete.")
+ yield container, skynet_password, db_host
+
+ finally:
+ if container and cleanup:
+ container.stop()
+
+
+class FrontendUserDB:
+ """
+ A class that manages the connection pool for the 'skynet' database,
+ initializes the schema if needed, and provides high-level methods
+ for interacting with the 'skynet' tables.
+ """
+
+ def __init__(
+ self,
+ db_user: str,
+ db_pass: str,
+ db_host: str,
+ db_name: str
+ ):
+ self.db_user = db_user
+ self.db_pass = db_pass
+ self.db_host = db_host
+ self.db_name = db_name
+ self.pool: asyncpg.Pool | None = None
+
+ async def __aenter__(self) -> "FrontendUserDB":
+ dsn = f"postgres://{self.db_user}:{self.db_pass}@{self.db_host}/{self.db_name}"
+ self.pool = await asyncpg.create_pool(dsn=dsn)
+ await self._init_db()
+ return self
+
+ async def __aexit__(self, exc_type, exc_val, exc_tb):
+ if self.pool:
+ await self.pool.close()
+
+ async def _init_db(self):
+ """
+ Ensures the 'skynet' schema and tables exist. Also checks for
+ missing columns and adds them if necessary.
+ """
+ async with self.pool.acquire() as conn:
+ # Check if schema is already initialized
+ result = await conn.fetch("""
+ SELECT DISTINCT table_schema
+ FROM information_schema.tables
+ WHERE table_schema = 'skynet'
+ """)
+ if not result:
+ await conn.execute(DB_INIT_SQL)
+
+ # Check if 'autoconf' column exists in user_config
+ col_check = await conn.fetch("""
+ SELECT column_name
+ FROM information_schema.columns
+ WHERE table_name = 'user_config' AND column_name = 'autoconf'
+ """)
+ if not col_check:
+ await conn.execute(
+ "ALTER TABLE skynet.user_config ADD COLUMN autoconf BOOLEAN DEFAULT TRUE;"
+ )
+
+ # -------------
+ # USER METHODS
+ # -------------
+
+ async def get_user_config(self, user_id: int):
+ """
+ Fetches the user_config for the given user ID.
+ Returns the record if found, otherwise None.
+ """
+ async with self.pool.acquire() as conn:
+ records = await conn.fetch(
+ "SELECT * FROM skynet.user_config WHERE id = $1", user_id
+ )
+ return records[0] if len(records) == 1 else None
+
+ async def get_user(self, user_id: int):
+ """Alias for get_user_config (same data returned)."""
+ return await self.get_user_config(user_id)
+
+ async def new_user(self, user_id: int):
+ """
+ Inserts a new user in skynet.user and its corresponding user_config record.
+ Raises ValueError if the user already exists.
+ """
+ existing = await self.get_user(user_id)
+ if existing:
+ raise ValueError("User already present in DB")
+
+ logging.info(f"New user! {user_id}")
+ now = datetime.utcnow()
+
+ async with self.pool.acquire() as conn:
+ async with conn.transaction():
+ await conn.execute(
+ """
+ INSERT INTO skynet.user(
+ id, generated, joined,
+ last_method, last_prompt, last_inputs, role
+ )
+ VALUES($1, 0, $2, 'txt2img', NULL, NULL, $3)
+ """,
+ user_id,
+ now,
+ DEFAULT_ROLE,
+ )
+ await conn.execute(
+ """
+ INSERT INTO skynet.user_config(
+ id, model, step, width,
+ height, guidance, strength, upscaler
+ )
+ VALUES($1, $2, $3, $4, $5, $6, $7, $8)
+ """,
+ user_id,
+ DEFAULT_MODEL,
+ DEFAULT_STEP,
+ DEFAULT_WIDTH,
+ DEFAULT_HEIGHT,
+ DEFAULT_GUIDANCE,
+ DEFAULT_STRENGTH,
+ DEFAULT_UPSCALER,
+ )
+
+ async def get_or_create_user(self, user_id: int):
+ """
+ Retrieves a user_config record for the given user_id.
+ If none exists, creates the user and returns the new record.
+ """
+ user_cfg = await self.get_user(user_id)
+ if not user_cfg:
+ await self.new_user(user_id)
+ user_cfg = await self.get_user(user_id)
+ return user_cfg
+
+ async def update_user(self, user_id: int, attr: str, val):
+ """
+ Generic function to update a single field in skynet.user for a given user_id.
+ """
+ async with self.pool.acquire() as conn:
+ await conn.execute(
+ f"UPDATE skynet.user SET {attr} = $2 WHERE id = $1", user_id, val
+ )
+
+ async def update_user_config(self, user_id: int, attr: str, val):
+ """
+ Generic function to update a single field in skynet.user_config for a given user_id.
+ """
+ async with self.pool.acquire() as conn:
+ await conn.execute(
+ f"UPDATE skynet.user_config SET {attr} = $2 WHERE id = $1", user_id, val
+ )
+
+ async def get_user_stats(self, user_id: int):
+ """
+ Returns (generated, joined, role) for the given user_id.
+ """
+ async with self.pool.acquire() as conn:
+ records = await conn.fetch(
+ """
+ SELECT generated, joined, role
+ FROM skynet.user
+ WHERE id = $1
+ """,
+ user_id,
+ )
+ return records[0] if records else None
+
+ async def increment_generated(self, user_id: int):
+ """
+ Increments the 'generated' count for a given user by 1.
+ """
+ async with self.pool.acquire() as conn:
+ await conn.execute(
+ """
+ UPDATE skynet.user
+ SET generated = generated + 1
+ WHERE id = $1
+ """,
+ user_id,
+ )
+
+ async def update_user_stats(
+ self,
+ user_id: int,
+ method: str,
+ last_prompt: str | None = None,
+ last_inputs: list | None = None
+ ):
+ """
+ Updates various 'last_*' fields in skynet.user.
+ """
+ await self.update_user(user_id, "last_method", method)
+ if last_prompt is not None:
+ await self.update_user(user_id, "last_prompt", last_prompt)
+
+ last_inputs_str = None
+ if isinstance(last_inputs, list):
+ last_inputs_str = ','.join((f'{f.id}:{f.cid}' for f in last_inputs))
+ await self.update_user(user_id, "last_inputs", last_inputs_str)
+
+ logging.info("Updated user stats: %s", (method, last_prompt, last_inputs_str))
+
+ # ----------------------
+ # USER REQUESTS METHODS
+ # ----------------------
+
+ async def get_user_request(self, request_id: int):
+ """
+ Fetches all matching rows for a given request_id.
+ """
+ async with self.pool.acquire() as conn:
+ return await conn.fetch(
+ "SELECT * FROM skynet.user_requests WHERE id = $1", request_id
+ )
+
+ async def get_user_request_by_sid(self, status_msg_id: int):
+ """
+ Fetches exactly one row (first row) by status_msg primary key.
+ """
+ async with self.pool.acquire() as conn:
+ records = await conn.fetch(
+ "SELECT * FROM skynet.user_requests WHERE status_msg = $1", status_msg_id
+ )
+ return records[0] if records else None
+
+ async def new_user_request(
+ self,
+ user_id: int,
+ request_id: int,
+ status_msg_id: int,
+ status: str = "started processing request..."
+ ):
+ """
+ Inserts a new row in skynet.user_requests.
+ """
+ now = datetime.utcnow()
+ async with self.pool.acquire() as conn:
+ async with conn.transaction():
+ await conn.execute(
+ """
+ INSERT INTO skynet.user_requests(
+ id, user_id, sent, status, status_msg
+ )
+ VALUES($1, $2, $3, $4, $5)
+ """,
+ request_id, user_id, now, status, status_msg_id
+ )
+
+ async def update_user_request(self, request_id: int, status: str):
+ """
+ Updates the 'status' for a user request identified by 'request_id'.
+ """
+ async with self.pool.acquire() as conn:
+ await conn.execute(
+ """
+ UPDATE skynet.user_requests
+ SET status = $2
+ WHERE id = $1
+ """,
+ request_id, status
+ )
+
+ async def update_user_request_by_sid(self, sid: int, status: str):
+ """
+ Updates the 'status' for a user request identified by 'status_msg'.
+ """
+ async with self.pool.acquire() as conn:
+ await conn.execute(
+ """
+ UPDATE skynet.user_requests
+ SET status = $2
+ WHERE status_msg = $1
+ """,
+ sid, status
+ )
+
+ # ----------------------------
+ # Convenience "Get Last" Helpers
+ # ----------------------------
+
+ async def get_last_method_of(self, user_id: int) -> str | None:
+ async with self.pool.acquire() as conn:
+ return await conn.fetchval(
+ "SELECT last_method FROM skynet.user WHERE id = $1", user_id
+ )
+
+ async def get_last_prompt_of(self, user_id: int) -> str | None:
+ async with self.pool.acquire() as conn:
+ return await conn.fetchval(
+ "SELECT last_prompt FROM skynet.user WHERE id = $1", user_id
+ )
+
+ async def get_last_inputs_of(self, user_id: int) -> list[BaseFileInput] | None:
+ async with self.pool.acquire() as conn:
+ last_inputs_str = await conn.fetchval(
+ "SELECT last_inputs FROM skynet.user WHERE id = $1", user_id
+ )
+
+ if not last_inputs_str:
+ return None
+
+ last_inputs = []
+ for i in last_inputs_str.split(','):
+ id, cid = i.split(':')
+ last_inputs.from_values(id, cid)
+
+ return last_inputs
diff --git a/skynet/frontend/chatbot/telegram.py b/skynet/frontend/chatbot/telegram.py
new file mode 100644
index 0000000..ba18de9
--- /dev/null
+++ b/skynet/frontend/chatbot/telegram.py
@@ -0,0 +1,378 @@
+import traceback
+
+from typing import Self, Awaitable
+from datetime import datetime, timezone
+
+from telebot.types import (
+ AsyncTeleBot,
+ User as TGUser,
+ Chat as TGChat,
+ PhotoSize as TGPhotoSize,
+ Message as TGMessage,
+ InputMediaPhoto,
+ InlineKeyboardButton,
+ InlineKeyboardMarkup
+)
+from telebot.async_telebot import ExceptionHandler
+from telebot.formatting import hlink
+
+from skynet.types import BodyV0Params
+from skynet.config import FrontendConfig
+from skynet.constants import VERSION
+from skynet.frontend.chatbot import BaseChatbot
+from skynet.frontend.chatbot.db import FrontendUserDB
+from skynet.frontend.types import (
+ BaseUser,
+ BaseChatRoom,
+ BaseFileInput,
+ BaseCommands,
+ BaseMessage
+)
+
+GROUP_ID = -1001541979235
+ADMIN_USER_ID = 383385940
+
+
+# Chatbot types impls
+
+class TelegramUser(BaseUser):
+ def __init__(self, user: TGUser):
+ self._user = user
+
+ @property
+ def id(self) -> int:
+ return self._user.id
+
+ @property
+ def name(self) -> str:
+ if self._user.username:
+ return f'@{self._user.username}'
+
+ return f'{self._user.first_name} id: {self.id}'
+
+ @property
+ def is_admin(self) -> bool:
+ return self.id == ADMIN_USER_ID
+
+
+class TelegramChatRoom(BaseChatRoom):
+
+ def __init__(self, chat: TGChat):
+ self._chat = chat
+
+ @property
+ def id(self) -> int:
+ return self._chat.id
+
+ @property
+ def is_private(self) -> bool:
+ return self._chat.type == 'private'
+
+
+class TelegramFileInput(BaseFileInput):
+
+ def __init__(
+ self,
+ photo: TGPhotoSize | None = None,
+ id: int | None = None,
+ cid: str | None = None
+ ):
+ self._photo = photo
+ self._id = id
+ self._cid = cid
+
+ self._raw = None
+
+ def from_values(id: int, cid: str) -> Self:
+ return TelegramFileInput(id=id, cid=cid)
+
+ @property
+ def id(self) -> int:
+ if self._id:
+ return self._id
+
+ return self._photo.file_id
+
+ @property
+ def cid(self) -> str:
+ if self._cid:
+ return self._cid
+
+ raise ValueError
+
+ async def download(self, bot: AsyncTeleBot) -> bytes:
+ file_path = (await bot.get_file(self.id)).file_path
+ self._raw = await bot.download_file(file_path)
+ return self._raw
+
+
+class TelegramMessage(BaseMessage):
+
+ def __init__(self, cmd: BaseCommands | None, msg: TGMessage):
+ self._msg = msg
+ self._cmd = cmd
+ self._chat = TelegramChatRoom(msg.chat)
+
+ @property
+ def id(self) -> int:
+ return self._msg.message_id
+
+ @property
+ def chat(self) -> TelegramChatRoom:
+ return self._chat
+
+ @property
+ def text(self) -> str:
+ return self._msg.text[len(self._cmd) + 1:]
+
+ @property
+ def author(self) -> TelegramUser:
+ return TelegramUser(self._msg.from_user)
+
+ @property
+ def command(self) -> str | None:
+ return self._cmd
+
+ @property
+ def inputs(self) -> list[TelegramFileInput]:
+ return [
+ TelegramFileInput(photo=p)
+ for p in self._msg.photo
+ ]
+
+
+# generic tg utils
+
+def timestamp_pretty():
+ return datetime.now(timezone.utc).strftime('%H:%M:%S')
+
+
+class TGExceptionHandler(ExceptionHandler):
+
+ def handle(exception):
+ traceback.print_exc()
+
+
+def build_redo_menu():
+ btn_redo = InlineKeyboardButton("Redo", callback_data='{\"method\": \"redo\"}')
+ inline_keyboard = InlineKeyboardMarkup()
+ inline_keyboard.add(btn_redo)
+ return inline_keyboard
+
+
+def prepare_metainfo_caption(user: TelegramUser, worker: str, reward: str, params: BodyV0Params) -> str:
+ prompt = params.prompt
+ if len(prompt) > 256:
+ prompt = prompt[:256]
+
+ meta_str = f'by {user.name}\n'
+ meta_str += f'performed by {worker}\n'
+ meta_str += f'reward: {reward}\n'
+
+ meta_str += f'prompt:
{prompt}\n'
+ meta_str += f'seed: {params.seed}
\n'
+ meta_str += f'step: {params.step}
\n'
+ if params.guidance:
+ meta_str += f'guidance: {params.guidance}
\n'
+
+ if params.strength:
+ meta_str += f'strength: {params.strength}
\n'
+
+ meta_str += f'algo: {params.model}
\n'
+
+ meta_str += f'Made with Skynet v{VERSION}\n'
+ meta_str += 'JOIN THE SWARM: @skynetgpu'
+ return meta_str
+
+
+def generate_reply_caption(
+ user: TelegramUser,
+ params: BodyV0Params,
+ tx_hash: str,
+ worker: str,
+ reward: str,
+ explorer_domain: str
+):
+ explorer_link = hlink(
+ 'SKYNET Transaction Explorer',
+ f'https://{explorer_domain}/v2/explore/transaction/{tx_hash}'
+ )
+
+ meta_info = prepare_metainfo_caption(user, worker, reward, params)
+
+ final_msg = '\n'.join([
+ 'Worker finished your task!',
+ explorer_link,
+ f'PARAMETER INFO:\n{meta_info}'
+ ])
+
+ final_msg = '\n'.join([
+ f'{explorer_link}',
+ f'{meta_info}'
+ ])
+
+ return final_msg
+
+
+def append_handler(bot: AsyncTeleBot, command: str, fn: Awaitable):
+ @bot.message_handler(commands=[command])
+ async def wrap_msg_and_handle(tg_msg: TGMessage):
+ await fn(TelegramMessage(cmd=command, msg=tg_msg))
+
+
+class TelegramChatbot(BaseChatbot):
+
+ def __init__(
+ self,
+ config: FrontendConfig,
+ db: FrontendUserDB,
+ ):
+ super().__init__(config, db)
+ bot = AsyncTeleBot(config.token, exception_handler=TGExceptionHandler)
+
+ append_handler(bot, BaseCommands.HELP, self.send_help)
+ append_handler(bot, BaseCommands.COOL, self.send_cool_words)
+ append_handler(bot, BaseCommands.QUEUE, self.get_queue)
+ append_handler(bot, BaseCommands.CONFIG, self.set_config)
+ append_handler(bot, BaseCommands.STATS, self.user_stats)
+ append_handler(bot, BaseCommands.DONATE, self.donation_info)
+ append_handler(bot, BaseCommands.SAY, self.say)
+
+ append_handler(bot, BaseCommands.TXT2IMG, self.handle_request)
+ append_handler(bot, BaseCommands.IMG2IMG, self.handle_request)
+ append_handler(bot, BaseCommands.REDO, self.handle_request)
+
+ self._main_room: TelegramChatRoom | None = None
+
+ async def init(self):
+ tg_group = await self.bot.get_chat(GROUP_ID)
+ self._main_room = TelegramChatRoom(chat=tg_group)
+
+ async def run(self):
+ await self.init()
+ await self.bot.infinity_polling()
+
+ @property
+ def main_group(self) -> TelegramChatRoom:
+ return self._main_room
+
+ async def new_msg(self, chat: TelegramChatRoom, text: str) -> TelegramMessage:
+ msg = await self.bot.send_message(chat.id, text, parse_mode='HTML')
+ return TelegramMessage(cmd=None, msg=msg)
+
+ async def reply_to(self, msg: TelegramMessage, text: str) -> TelegramMessage:
+ msg = await self.bot.reply_to(msg._msg, text, parse_mode='HTML')
+ return TelegramMessage(cmd=None, msg=msg)
+
+ async def edit_msg(self, msg: TelegramMessage, text: str):
+ await self.bot.edit_message_text(
+ text,
+ chat_id=msg.chat.id,
+ message_id=msg.id,
+ parse_mode='HTML'
+ )
+
+ async def update_request_status_timeout(self, status_msg: TelegramMessage):
+ '''
+ Notify users when we timedout trying to find a matching submit
+ '''
+ await self.append_status_msg(
+ status_msg,
+ f'\n[{timestamp_pretty()}] timeout processing request',
+ )
+
+ async def update_request_status_step_0(self, status_msg: TelegramMessage):
+ '''
+ First step in request status message lifecycle, should notify which user sent the request
+ and that we are about to broadcast the request to chain
+ '''
+ await self.update_status_msg(
+ status_msg,
+ f'processing a \'{status_msg.command}\' request by {status_msg.author.name}\n'
+ f'[{timestamp_pretty()}] broadcasting transaction to chain...'
+ )
+
+ async def update_request_status_step_1(self, status_msg: TelegramMessage, tx_result: dict):
+ '''
+ Second step in request status message lifecycle, should notify enqueue transaction
+ was processed by chain, and provide a link to the tx in the chain explorer
+ '''
+ enqueue_tx_id = tx_result['transaction_id']
+ enqueue_tx_link = hlink(
+ 'Your request on Skynet Explorer',
+ f'https://{self.explorer_domain}/v2/explore/transaction/{enqueue_tx_id}'
+ )
+ await self.append_status_msg(
+ status_msg,
+ f' broadcasted!\n'
+ f'{enqueue_tx_link}\n'
+ f'[{timestamp_pretty()}] workers are processing request...',
+ )
+
+ async def update_request_status_step_2(self, status_msg: TelegramMessage, submit_tx_hash: str):
+ '''
+ Third step in request status message lifecycle, should notify matching submit transaction
+ was found, and provide a link to the tx in the chain explorer
+ '''
+ tx_link = hlink(
+ 'Your result on Skynet Explorer',
+ f'https://{self.explorer_domain}/v2/explore/transaction/{submit_tx_hash}'
+ )
+ await self.append_status_msg(
+ status_msg,
+ f' request processed!\n'
+ f'{tx_link}\n'
+ f'[{timestamp_pretty()}] trying to download image...\n',
+ )
+
+ async def update_request_status_final(
+ self,
+ og_msg: TelegramMessage,
+ status_msg: TelegramMessage,
+ user: TelegramUser,
+ params: BodyV0Params,
+ inputs: list[TelegramFileInput],
+ submit_tx_hash: str,
+ worker: str,
+ result_img: bytes | None
+ ):
+ '''
+ Last step in request status message lifecycle, should delete status message and send a
+ new message replying to the original user's message, generate the appropiate
+ reply caption and if provided also sent the found result img
+ '''
+ caption = generate_reply_caption(
+ user, worker, self.config.reward, params)
+
+ await self.bot.delete_message(
+ chat_id=status_msg.chat.id,
+ message_id=status_msg.id
+ )
+
+ if not result_img:
+ # result found on chain but failed to fetch img from ipfs
+ await self.reply_to(og_msg, caption, reply_markup=build_redo_menu())
+ return
+
+ match len(inputs):
+ case 0:
+ await self.bot.send_photo(
+ status_msg.chat.id,
+ caption=caption,
+ photo=result_img,
+ reply_markup=build_redo_menu(),
+ parse_mode='HTML'
+ )
+
+ case 1:
+ _input = inputs.pop()
+ await self.bot.send_media_group(
+ status_msg.chat.id,
+ media=[
+ InputMediaPhoto(_input.id),
+ InputMediaPhoto(result_img, caption=caption, parse_mode='HTML')
+ ]
+ )
+
+ case _:
+ raise NotImplementedError
diff --git a/skynet/frontend/chatbot/types.py b/skynet/frontend/chatbot/types.py
new file mode 100644
index 0000000..3ae4e1a
--- /dev/null
+++ b/skynet/frontend/chatbot/types.py
@@ -0,0 +1,110 @@
+import io
+
+from ABC import ABC, abstractproperty, abstractmethod
+from enum import StrEnum
+from typing import Self
+from PIL import Image
+
+from skynet.ipfs import AsyncIPFSHTTP
+
+
+class BaseUser(ABC):
+
+ @abstractproperty
+ def id(self) -> int:
+ ...
+
+ @abstractproperty
+ def name(self) -> str:
+ ...
+
+ @abstractproperty
+ def is_admin(self) -> bool:
+ ...
+
+
+class BaseChatRoom(ABC):
+ @abstractproperty
+ def id(self) -> int:
+ ...
+
+ @abstractproperty
+ def is_private(self) -> bool:
+ ...
+
+
+class BaseFileInput(ABC):
+
+ @staticmethod
+ @abstractmethod
+ def from_values(id: int, cid: str) -> Self:
+ ...
+
+ @abstractproperty
+ def id(self) -> int:
+ ...
+
+ @abstractproperty
+ def cid(self) -> str:
+ ...
+
+ @abstractmethod
+ async def download(self, *args) -> bytes:
+ ...
+
+ async def publish(self, ipfs_api: AsyncIPFSHTTP, user_row: dict):
+ with Image.open(io.BytesIO(self._raw)) as img:
+ w, h = img.size
+
+ if (
+ w > user_row['width']
+ or
+ h > user_row['height']
+ ):
+ img.thumbnail((user_row['width'], user_row['height']))
+
+ img_path = '/tmp/ipfs-staging/img.png'
+ img.save(img_path, format='PNG')
+
+ ipfs_info = await ipfs_api.add(img_path)
+ ipfs_hash = ipfs_info['Hash']
+ await ipfs_api.pin(ipfs_hash)
+
+
+class BaseCommands(StrEnum):
+ TXT2IMG = 'txt2img'
+ IMG2IMG = 'img2img'
+ REDO = 'redo'
+ HELP = 'help'
+ COOL = 'cool'
+ QUEUE = 'queue'
+ CONFIG = 'config'
+ STATS = 'stats'
+ DONATE = 'donate'
+ SAY = 'say'
+
+
+class BaseMessage(ABC):
+ @abstractproperty
+ def id(self) -> int:
+ ...
+
+ @abstractproperty
+ def chat(self) -> BaseChatRoom:
+ ...
+
+ @abstractproperty
+ def text(self) -> str:
+ ...
+
+ @abstractproperty
+ def author(self) -> BaseUser:
+ ...
+
+ @abstractproperty
+ def command(self) -> str | None:
+ ...
+
+ @abstractproperty
+ def inputs(self) -> list[BaseFileInput]:
+ ...
diff --git a/skynet/frontend/telegram/__init__.py b/skynet/frontend/telegram/__init__.py
deleted file mode 100644
index 540a240..0000000
--- a/skynet/frontend/telegram/__init__.py
+++ /dev/null
@@ -1,295 +0,0 @@
-import io
-import random
-import logging
-import asyncio
-
-from PIL import Image, UnidentifiedImageError
-from json import JSONDecodeError
-from decimal import Decimal
-from hashlib import sha256
-from datetime import datetime
-from contextlib import AsyncExitStack
-from contextlib import asynccontextmanager as acm
-
-from leap.cleos import CLEOS
-from leap.protocol import Name, Asset
-from leap.hyperion import HyperionAPI
-
-from telebot.types import InputMediaPhoto
-from telebot.async_telebot import AsyncTeleBot
-
-from skynet.db import open_database_connection
-from skynet.ipfs import get_ipfs_file, AsyncIPFSHTTP
-from skynet.constants import *
-
-from . import *
-
-from .utils import *
-from .handlers import create_handler_context
-
-
-class SkynetTelegramFrontend:
-
- def __init__(
- self,
- token: str,
- account: str,
- permission: str,
- node_url: str,
- hyperion_url: str,
- db_host: str,
- db_user: str,
- db_pass: str,
- ipfs_node: str,
- key: str,
- explorer_domain: str,
- ipfs_domain: str
- ):
- self.token = token
- self.account = account
- self.permission = permission
- self.node_url = node_url
- self.hyperion_url = hyperion_url
- self.db_host = db_host
- self.db_user = db_user
- self.db_pass = db_pass
- self.key = key
- self.explorer_domain = explorer_domain
- self.ipfs_domain = ipfs_domain
-
- self.bot = AsyncTeleBot(token, exception_handler=SKYExceptionHandler)
- self.cleos = CLEOS(endpoint=node_url)
- self.cleos.load_abi('gpu.scd', GPU_CONTRACT_ABI)
- self.hyperion = HyperionAPI(hyperion_url)
- self.ipfs_node = AsyncIPFSHTTP(ipfs_node)
-
- self._async_exit_stack = AsyncExitStack()
-
- async def start(self):
- self.db_call = await self._async_exit_stack.enter_async_context(
- open_database_connection(
- self.db_user, self.db_pass, self.db_host))
-
- create_handler_context(self)
-
- async def stop(self):
- await self._async_exit_stack.aclose()
-
- @acm
- async def open(self):
- await self.start()
- yield self
- await self.stop()
-
- async def update_status_message(
- self, status_msg, new_text: str, **kwargs
- ):
- await self.db_call(
- 'update_user_request_by_sid', status_msg.id, new_text)
- return await self.bot.edit_message_text(
- new_text,
- chat_id=status_msg.chat.id,
- message_id=status_msg.id,
- **kwargs
- )
-
- async def append_status_message(
- self, status_msg, add_text: str, **kwargs
- ):
- request = await self.db_call('get_user_request_by_sid', status_msg.id)
- await self.update_status_message(
- status_msg,
- request['status'] + add_text,
- **kwargs
- )
-
- async def work_request(
- self,
- user,
- status_msg,
- method: str,
- params: dict,
- file_id: str | None = None,
- inputs: list[str] = []
- ) -> bool:
- if params['seed'] == None:
- params['seed'] = random.randint(0, 0xFFFFFFFF)
-
- sanitized_params = {}
- for key, val in params.items():
- if isinstance(val, Decimal):
- val = str(val)
-
- sanitized_params[key] = val
-
- body = json.dumps({
- 'method': 'diffuse',
- 'params': sanitized_params
- })
- request_time = datetime.now().isoformat()
-
- await self.update_status_message(
- status_msg,
- f'processing a \'{method}\' request by {tg_user_pretty(user)}\n'
- f'[{timestamp_pretty()}] broadcasting transaction to chain...',
- parse_mode='HTML'
- )
-
- reward = '20.0000 GPU'
- res = await self.cleos.a_push_action(
- 'gpu.scd',
- 'enqueue',
- list({
- 'user': Name(self.account),
- 'request_body': body,
- 'binary_data': ','.join(inputs),
- 'reward': Asset.from_str(reward),
- 'min_verification': 1
- }.values()),
- self.account, self.key, permission=self.permission
- )
-
- if 'code' in res or 'statusCode' in res:
- logging.error(json.dumps(res, indent=4))
- await self.update_status_message(
- status_msg,
- 'skynet has suffered an internal error trying to fill this request')
- return False
-
- enqueue_tx_id = res['transaction_id']
- enqueue_tx_link = hlink(
- 'Your request on Skynet Explorer',
- f'https://{self.explorer_domain}/v2/explore/transaction/{enqueue_tx_id}'
- )
-
- await self.append_status_message(
- status_msg,
- f' broadcasted!\n'
- f'{enqueue_tx_link}\n'
- f'[{timestamp_pretty()}] workers are processing request...',
- parse_mode='HTML'
- )
-
- out = res['processed']['action_traces'][0]['console']
-
- request_id, nonce = out.split(':')
-
- request_hash = sha256(
- (nonce + body + ','.join(inputs)).encode('utf-8')).hexdigest().upper()
-
- request_id = int(request_id)
-
- logging.info(f'{request_id} enqueued.')
-
- tx_hash = None
- ipfs_hash = None
- for i in range(60 * 3):
- try:
- submits = await self.hyperion.aget_actions(
- account=self.account,
- filter='gpu.scd:submit',
- sort='desc',
- after=request_time
- )
- actions = [
- action
- for action in submits['actions']
- if action[
- 'act']['data']['request_hash'] == request_hash
- ]
- if len(actions) > 0:
- tx_hash = actions[0]['trx_id']
- data = actions[0]['act']['data']
- ipfs_hash = data['ipfs_hash']
- worker = data['worker']
- logging.info('Found matching submit!')
- break
-
- except JSONDecodeError:
- logging.error(f'network error while getting actions, retry..')
-
- await asyncio.sleep(1)
-
- if not ipfs_hash:
- await self.update_status_message(
- status_msg,
- f'\n[{timestamp_pretty()}] timeout processing request',
- parse_mode='HTML'
- )
- return False
-
- tx_link = hlink(
- 'Your result on Skynet Explorer',
- f'https://{self.explorer_domain}/v2/explore/transaction/{tx_hash}'
- )
-
- await self.append_status_message(
- status_msg,
- f' request processed!\n'
- f'{tx_link}\n'
- f'[{timestamp_pretty()}] trying to download image...\n',
- parse_mode='HTML'
- )
-
- caption = generate_reply_caption(
- user, params, tx_hash, worker, reward, self.explorer_domain)
-
- # attempt to get the image and send it
- ipfs_link = f'https://{self.ipfs_domain}/ipfs/{ipfs_hash}'
-
- res = await get_ipfs_file(ipfs_link)
- logging.info(f'got response from {ipfs_link}')
- if not res or res.status_code != 200:
- logging.warning(f'couldn\'t get ipfs binary data at {ipfs_link}!')
-
- else:
- try:
- with Image.open(io.BytesIO(res.raw)) as image:
- w, h = image.size
-
- if w > TG_MAX_WIDTH or h > TG_MAX_HEIGHT:
- logging.warning(f'result is of size {image.size}')
- image.thumbnail((TG_MAX_WIDTH, TG_MAX_HEIGHT))
-
- tmp_buf = io.BytesIO()
- image.save(tmp_buf, format='PNG')
- png_img = tmp_buf.getvalue()
-
- except UnidentifiedImageError:
- logging.warning(f'couldn\'t get ipfs binary data at {ipfs_link}!')
-
- if not png_img:
- await self.update_status_message(
- status_msg,
- caption,
- reply_markup=build_redo_menu(),
- parse_mode='HTML'
- )
- return True
-
- logging.info(f'success! sending generated image')
- await self.bot.delete_message(
- chat_id=status_msg.chat.id, message_id=status_msg.id)
- if file_id: # img2img
- await self.bot.send_media_group(
- status_msg.chat.id,
- media=[
- InputMediaPhoto(file_id),
- InputMediaPhoto(
- png_img,
- caption=caption,
- parse_mode='HTML'
- )
- ],
- )
-
- else: # txt2img
- await self.bot.send_photo(
- status_msg.chat.id,
- caption=caption,
- photo=png_img,
- reply_markup=build_redo_menu(),
- parse_mode='HTML'
- )
-
- return True
diff --git a/skynet/frontend/telegram/handlers.py b/skynet/frontend/telegram/handlers.py
deleted file mode 100644
index e9eaebb..0000000
--- a/skynet/frontend/telegram/handlers.py
+++ /dev/null
@@ -1,365 +0,0 @@
-import io
-import json
-import logging
-
-from datetime import datetime, timedelta
-
-from PIL import Image
-from telebot.types import CallbackQuery, Message
-
-from skynet.frontend import validate_user_config_request, perform_auto_conf
-from skynet.constants import *
-
-
-def create_handler_context(frontend: 'SkynetTelegramFrontend'):
-
- bot = frontend.bot
- cleos = frontend.cleos
- db_call = frontend.db_call
- work_request = frontend.work_request
-
- ipfs_node = frontend.ipfs_node
-
- # generic / simple handlers
-
- @bot.message_handler(commands=['help'])
- async def send_help(message):
- splt_msg = message.text.split(' ')
-
- if len(splt_msg) == 1:
- await bot.reply_to(message, HELP_TEXT)
-
- else:
- param = splt_msg[1]
- if param in HELP_TOPICS:
- await bot.reply_to(message, HELP_TOPICS[param])
-
- else:
- await bot.reply_to(message, HELP_UNKWNOWN_PARAM)
-
- @bot.message_handler(commands=['cool'])
- async def send_cool_words(message):
- await bot.reply_to(message, '\n'.join(COOL_WORDS))
-
- @bot.message_handler(commands=['queue'])
- async def queue(message):
- an_hour_ago = datetime.now() - timedelta(hours=1)
- queue = await cleos.aget_table(
- 'gpu.scd', 'gpu.scd', 'queue',
- index_position=2,
- key_type='i64',
- sort='desc',
- lower_bound=int(an_hour_ago.timestamp())
- )
- await bot.reply_to(
- message, f'Total requests on skynet queue: {len(queue)}')
-
-
- @bot.message_handler(commands=['config'])
- async def set_config(message):
- user = message.from_user.id
- try:
- attr, val, reply_txt = validate_user_config_request(
- message.text)
-
- logging.info(f'user config update: {attr} to {val}')
- await db_call('update_user_config', user, attr, val)
- logging.info('done')
-
- except BaseException as e:
- reply_txt = str(e)
-
- finally:
- await bot.reply_to(message, reply_txt)
-
- @bot.message_handler(commands=['stats'])
- async def user_stats(message):
- user = message.from_user.id
-
- await db_call('get_or_create_user', user)
- generated, joined, role = await db_call('get_user_stats', user)
-
- stats_str = f'generated: {generated}\n'
- stats_str += f'joined: {joined}\n'
- stats_str += f'role: {role}\n'
-
- await bot.reply_to(
- message, stats_str)
-
- @bot.message_handler(commands=['donate'])
- async def donation_info(message):
- await bot.reply_to(
- message, DONATION_INFO)
-
- @bot.message_handler(commands=['say'])
- async def say(message):
- chat = message.chat
- user = message.from_user
-
- if (chat.type == 'group') or (user.id != 383385940):
- return
-
- await bot.send_message(GROUP_ID, message.text[4:])
-
-
- # generic txt2img handler
-
- async def _generic_txt2img(message_or_query):
- if isinstance(message_or_query, CallbackQuery):
- query = message_or_query
- message = query.message
- user = query.from_user
- chat = query.message.chat
-
- else:
- message = message_or_query
- user = message.from_user
- chat = message.chat
-
- if chat.type == 'private':
- return
-
- reply_id = None
- if chat.type == 'group' and chat.id == GROUP_ID:
- reply_id = message.message_id
-
- user_row = await db_call('get_or_create_user', user.id)
-
- # init new msg
- init_msg = 'started processing txt2img request...'
- status_msg = await bot.reply_to(message, init_msg)
- await db_call(
- 'new_user_request', user.id, message.id, status_msg.id, status=init_msg)
-
- prompt = ' '.join(message.text.split(' ')[1:])
-
- if len(prompt) == 0:
- await bot.edit_message_text(
- 'Empty text prompt ignored.',
- chat_id=status_msg.chat.id,
- message_id=status_msg.id
- )
- await db_call('update_user_request', status_msg.id, 'Empty text prompt ignored.')
- return
-
- logging.info(f'mid: {message.id}')
-
- user_config = {**user_row}
- del user_config['id']
-
- if user_config['autoconf']:
- user_config = perform_auto_conf(user_config)
-
- params = {
- 'prompt': prompt,
- **user_config
- }
-
- await db_call(
- 'update_user_stats', user.id, 'txt2img', last_prompt=prompt)
-
- success = await work_request(user, status_msg, 'txt2img', params)
-
- if success:
- await db_call('increment_generated', user.id)
-
-
- # generic img2img handler
-
- async def _generic_img2img(message_or_query):
- if isinstance(message_or_query, CallbackQuery):
- query = message_or_query
- message = query.message
- user = query.from_user
- chat = query.message.chat
-
- else:
- message = message_or_query
- user = message.from_user
- chat = message.chat
-
- if chat.type == 'private':
- return
-
- reply_id = None
- if chat.type == 'group' and chat.id == GROUP_ID:
- reply_id = message.message_id
-
- user_row = await db_call('get_or_create_user', user.id)
-
- # init new msg
- init_msg = 'started processing txt2img request...'
- status_msg = await bot.reply_to(message, init_msg)
- await db_call(
- 'new_user_request', user.id, message.id, status_msg.id, status=init_msg)
-
- if not message.caption.startswith('/img2img'):
- await bot.reply_to(
- message,
- 'For image to image you need to add /img2img to the beggining of your caption'
- )
- return
-
- prompt = ' '.join(message.caption.split(' ')[1:])
-
- if len(prompt) == 0:
- await bot.reply_to(message, 'Empty text prompt ignored.')
- return
-
- file_id = message.photo[-1].file_id
- file_path = (await bot.get_file(file_id)).file_path
- image_raw = await bot.download_file(file_path)
-
- user_config = {**user_row}
- del user_config['id']
- if user_config['autoconf']:
- user_config = perform_auto_conf(user_config)
-
- with Image.open(io.BytesIO(image_raw)) as image:
- w, h = image.size
-
- if w > user_config['width'] or h > user_config['height']:
- logging.warning(f'user sent img of size {image.size}')
- image.thumbnail(
- (user_config['width'], user_config['height']))
- logging.warning(f'resized it to {image.size}')
-
- image_loc = 'ipfs-staging/image.png'
- image.save(image_loc, format='PNG')
-
- ipfs_info = await ipfs_node.add(image_loc)
- ipfs_hash = ipfs_info['Hash']
- await ipfs_node.pin(ipfs_hash)
-
- logging.info(f'published input image {ipfs_hash} on ipfs')
-
- logging.info(f'mid: {message.id}')
-
- params = {
- 'prompt': prompt,
- **user_config
- }
-
- await db_call(
- 'update_user_stats',
- user.id,
- 'img2img',
- last_file=file_id,
- last_prompt=prompt,
- last_binary=ipfs_hash
- )
-
- success = await work_request(
- user, status_msg, 'img2img', params,
- file_id=file_id,
- inputs=ipfs_hash
- )
-
- if success:
- await db_call('increment_generated', user.id)
-
-
- # generic redo handler
-
- async def _redo(message_or_query):
- is_query = False
- if isinstance(message_or_query, CallbackQuery):
- is_query = True
- query = message_or_query
- message = query.message
- user = query.from_user
- chat = query.message.chat
-
- elif isinstance(message_or_query, Message):
- message = message_or_query
- user = message.from_user
- chat = message.chat
-
- if chat.type == 'private':
- return
-
- init_msg = 'started processing redo request...'
- if is_query:
- status_msg = await bot.send_message(chat.id, init_msg)
-
- else:
- status_msg = await bot.reply_to(message, init_msg)
-
- method = await db_call('get_last_method_of', user.id)
- prompt = await db_call('get_last_prompt_of', user.id)
-
- file_id = None
- binary = ''
- if method == 'img2img':
- file_id = await db_call('get_last_file_of', user.id)
- binary = await db_call('get_last_binary_of', user.id)
-
- if not prompt:
- await bot.reply_to(
- message,
- 'no last prompt found, do a txt2img cmd first!'
- )
- return
-
-
- user_row = await db_call('get_or_create_user', user.id)
- await db_call(
- 'new_user_request', user.id, message.id, status_msg.id, status=init_msg)
- user_config = {**user_row}
- del user_config['id']
- if user_config['autoconf']:
- user_config = perform_auto_conf(user_config)
-
- params = {
- 'prompt': prompt,
- **user_config
- }
-
- success = await work_request(
- user, status_msg, 'redo', params,
- file_id=file_id,
- inputs=binary
- )
-
- if success:
- await db_call('increment_generated', user.id)
-
-
- # "proxy" handlers just request routers
-
- @bot.message_handler(commands=['txt2img'])
- async def send_txt2img(message):
- await _generic_txt2img(message)
-
- @bot.message_handler(func=lambda message: True, content_types=[
- 'photo', 'document'])
- async def send_img2img(message):
- await _generic_img2img(message)
-
- @bot.message_handler(commands=['img2img'])
- async def img2img_missing_image(message):
- await bot.reply_to(
- message,
- 'seems you tried to do an img2img command without sending image'
- )
-
- @bot.message_handler(commands=['redo'])
- async def redo(message):
- await _redo(message)
-
- @bot.callback_query_handler(func=lambda call: True)
- async def callback_query(call):
- msg = json.loads(call.data)
- logging.info(call.data)
- method = msg.get('method')
- match method:
- case 'redo':
- await _redo(call)
-
-
- # catch all handler for things we dont support
-
- @bot.message_handler(func=lambda message: True)
- async def echo_message(message):
- if message.text[0] == '/':
- await bot.reply_to(message, UNKNOWN_CMD_TEXT)
diff --git a/skynet/frontend/telegram/utils.py b/skynet/frontend/telegram/utils.py
deleted file mode 100644
index 13271fb..0000000
--- a/skynet/frontend/telegram/utils.py
+++ /dev/null
@@ -1,105 +0,0 @@
-import json
-import logging
-import traceback
-
-from datetime import datetime, timezone
-
-from telebot.types import InlineKeyboardButton, InlineKeyboardMarkup
-from telebot.async_telebot import ExceptionHandler
-from telebot.formatting import hlink
-
-from skynet.constants import *
-
-
-def timestamp_pretty():
- return datetime.now(timezone.utc).strftime('%H:%M:%S')
-
-
-def tg_user_pretty(tguser):
- if tguser.username:
- return f'@{tguser.username}'
- else:
- return f'{tguser.first_name} id: {tguser.id}'
-
-
-class SKYExceptionHandler(ExceptionHandler):
-
- def handle(exception):
- traceback.print_exc()
-
-
-def build_redo_menu():
- btn_redo = InlineKeyboardButton("Redo", callback_data=json.dumps({'method': 'redo'}))
- inline_keyboard = InlineKeyboardMarkup()
- inline_keyboard.add(btn_redo)
- return inline_keyboard
-
-
-def prepare_metainfo_caption(tguser, worker: str, reward: str, meta: dict) -> str:
- prompt = meta["prompt"]
- if len(prompt) > 256:
- prompt = prompt[:256]
-
-
- meta_str = f'by {tg_user_pretty(tguser)}\n'
- meta_str += f'performed by {worker}\n'
- meta_str += f'reward: {reward}\n'
-
- meta_str += f'prompt:
{prompt}\n'
- meta_str += f'seed: {meta["seed"]}
\n'
- meta_str += f'step: {meta["step"]}
\n'
- meta_str += f'guidance: {meta["guidance"]}
\n'
- if meta['strength']:
- meta_str += f'strength: {meta["strength"]}
\n'
- meta_str += f'algo: {meta["model"]}
\n'
- if meta['upscaler']:
- meta_str += f'upscaler: {meta["upscaler"]}
\n'
-
- meta_str += f'Made with Skynet v{VERSION}\n'
- meta_str += f'JOIN THE SWARM: @skynetgpu'
- return meta_str
-
-
-def generate_reply_caption(
- tguser, # telegram user
- params: dict,
- tx_hash: str,
- worker: str,
- reward: str,
- explorer_domain: str
-):
- explorer_link = hlink(
- 'SKYNET Transaction Explorer',
- f'https://{explorer_domain}/v2/explore/transaction/{tx_hash}'
- )
-
- meta_info = prepare_metainfo_caption(tguser, worker, reward, params)
-
- final_msg = '\n'.join([
- 'Worker finished your task!',
- explorer_link,
- f'PARAMETER INFO:\n{meta_info}'
- ])
-
- final_msg = '\n'.join([
- f'{explorer_link}',
- f'{meta_info}'
- ])
-
- logging.info(final_msg)
-
- return final_msg
-
-
-async def get_global_config(cleos):
- return (await cleos.aget_table(
- 'gpu.scd', 'gpu.scd', 'config'))[0]
-
-async def get_user_nonce(cleos, user: str):
- return (await cleos.aget_table(
- 'gpu.scd', 'gpu.scd', 'users',
- index_position=1,
- key_type='name',
- lower_bound=user,
- upper_bound=user
- ))[0]['nonce']