maubot-dalle-bot/dalle.py

134 lines
5 KiB
Python

import json
import urllib
import io
import base64
import asyncio
import aiohttp
import re
from PIL import Image
from typing import Type, Deque, Dict, Generator
from mautrix.types import ImageInfo, EventType, MessageType, RelationType
from mautrix.types.event.message import BaseFileInfo, Format, TextMessageEventContent
from mautrix.util.config import BaseProxyConfig, ConfigUpdateHelper
from maubot import Plugin, MessageEvent
from maubot.handlers import event
from mautrix.client import Client
class Config(BaseProxyConfig):
    """Configuration for the DALL-E plugin."""

    def do_update(self, helper: ConfigUpdateHelper) -> None:
        """Copy each plugin-specific key via ``helper.copy`` during a config update."""
        for key in ("respond_to_notice", "model", "name", "nickname"):
            helper.copy(key)
class Dalle(Plugin):
    """Maubot plugin that turns chat messages into generated images.

    The bot answers text messages that address it by name/nickname (either in
    the body or through ``m.mentions``) and every text message in a room with
    exactly two joined members (treated as a DM). The rest of the message is
    sent as a prompt to the nexra image API, and the returned picture is
    uploaded and posted back into the room.
    """

    # Image-generation endpoint queried by _call_dalle.
    _API_URL = "https://nexra.aryahcr.cc/api/image/complements"

    name: str
    nickname: str

    async def start(self) -> None:
        """Load the config and cache the name/nickname the bot answers to."""
        self.config.load_and_update()
        self.name = self.config["name"]
        self.nickname = self.config["nickname"]

    @classmethod
    def get_config_class(cls) -> Type[BaseProxyConfig]:
        return Config

    @staticmethod
    def _word_pattern(word: str) -> str:
        """Return a regex matching *word* as a standalone, optionally @-prefixed token.

        Group 1 captures the preceding whitespace (or start of text) so a
        substitution with ``\\1`` keeps the text around the removed token
        intact. The word is escaped so config values containing regex
        metacharacters (e.g. ``.``) are matched literally.
        """
        return r"(^|\s)(@)?" + re.escape(word) + r"([ :,.!?]|$)"

    @classmethod
    def _matches_word(cls, word: str, text: str) -> bool:
        """Return True if *word* occurs in *text* as a standalone token (case-insensitive).

        An empty word never matches: the bare pattern would otherwise match
        every message and make the bot answer everything.
        """
        if not word or not text:
            return False
        return re.search(cls._word_pattern(word), text, re.IGNORECASE) is not None

    @staticmethod
    def _mentioned_user_ids(event: MessageEvent) -> list:
        """Return the user IDs listed in the event's ``m.mentions`` (possibly empty).

        Clients may send ``m.mentions`` without ``user_ids`` or with an empty
        list, so those cases yield ``[]`` instead of raising.
        """
        try:
            user_ids = event.content["m.mentions"]["user_ids"]
        except (KeyError, TypeError, AttributeError):
            return []
        return list(user_ids or [])

    def _mentioned_via_user_ids(self, event: MessageEvent) -> bool:
        """Return True if any ``m.mentions`` user ID matches the bot's name."""
        return any(self._matches_word(self.name, uid) for uid in self._mentioned_user_ids(event))

    def _is_addressed(self, event: MessageEvent) -> bool:
        """Return True if the bot is mentioned via ``m.mentions`` or by name/nickname in the body."""
        if self._mentioned_via_user_ids(event):
            return True
        body = event.content.body
        return self._matches_word(self.name, body) or self._matches_word(self.nickname, body)

    async def should_respond(self, event: MessageEvent) -> bool:
        """Decide whether *event* is a prompt the bot should answer.

        Ignores the bot's own messages, ``!``-commands, any msgtype other than
        plain text, and edits. Otherwise answers when the bot is addressed,
        or unconditionally in rooms with exactly two joined members.
        """
        if (event.sender == self.client.mxid
                or event.content.body.startswith("!")
                or event.content["msgtype"] != MessageType.TEXT
                or event.content.relates_to["rel_type"] == RelationType.REPLACE):
            return False
        # NOTE(review): the "respond_to_notice" config key is copied by Config
        # but never consulted here; notices are always rejected by the
        # msgtype check above.
        if self._is_addressed(event):
            return True
        # Reply to all DMs
        return len(await self.client.get_joined_members(event.room_id)) == 2

    def _extract_prompt(self, event: MessageEvent) -> str:
        """Return the message body with the bot's mention token removed and whitespace stripped.

        When the bot is mentioned through ``m.mentions`` the body is assumed to
        contain its nickname (display name); otherwise whichever of
        name/nickname matched the body is removed. Only the first standalone
        occurrence is removed, case-insensitively, together with one trailing
        separator (space, ``:``, ``,`` ...). The event itself is not modified.
        """
        body = event.content.body
        if self._mentioned_via_user_ids(event):
            word = self.nickname
        elif self._matches_word(self.name, body):
            word = self.name
        elif self._matches_word(self.nickname, body):
            word = self.nickname
        else:
            word = ""
        if word:
            body = re.sub(self._word_pattern(word), r"\1", body, count=1, flags=re.IGNORECASE)
        return body.strip()

    @event.on(EventType.ROOM_MESSAGE)
    async def on_message(self, event: MessageEvent) -> None:
        """Generate an image for an addressed message and post it to the room.

        Failures are logged and never propagate out of the handler; the
        typing indicator is cleared on every path.
        """
        if not await self.should_respond(event):
            return
        prompt = self._extract_prompt(event)
        if not prompt:
            # Nothing left after removing the mention (e.g. the message was just the bot's name).
            return
        try:
            await event.mark_read()
            # Show a typing indicator while the image is being generated.
            await self.client.set_typing(event.room_id, timeout=99999)
            media = await self._call_dalle(prompt)
            if media is None:
                # _call_dalle has already logged the HTTP failure.
                return
            # Send the response back to the chat room; the prompt doubles as the file name.
            uri = await self.client.upload_media(media, mime_type="image/jpeg", filename=prompt)
            await self.client.send_file(
                event.room_id,
                url=uri,
                info=BaseFileInfo(mimetype="image/jpeg"),
                file_name=prompt,
                file_type=MessageType.IMAGE,
            )
        except Exception as e:
            self.log.exception(f"We have failed somewhere? {e}")
        finally:
            # Reset our typing status whether or not the request succeeded.
            await self.client.set_typing(event.room_id, timeout=0)

    async def _call_dalle(self, prompt):
        """Request an image for *prompt* from the image API.

        Returns:
            An ``io.BytesIO`` positioned at 0 holding the image, or ``None``
            when the API answers with a non-200 status (a warning is logged).

        Raises:
            ValueError: if the response body has no JSON object or no image.
            aiohttp.ClientError: on transport failures.
        """
        headers = {"Content-Type": "application/json"}
        data = {
            "model": self.config["model"],
            "prompt": prompt,
        }
        async with aiohttp.request("POST", self._API_URL, headers=headers, data=json.dumps(data)) as response:
            if response.status != 200:
                # aiohttp exposes the HTTP code as .status (there is no .status_code).
                self.log.warning(f"Unexpected status sending request to nexra.aryahcr.cc: {response.status}")
                return None
            response_text = await response.text()
        return self._decode_image(response_text)

    @staticmethod
    def _decode_image(response_text: str) -> io.BytesIO:
        """Extract the first base64 image from the API's response text.

        The response may carry non-JSON text before the JSON object, so
        parsing starts at the first ``{``; ``raw_decode`` also tolerates
        trailing data after the object. The image is decoded with Pillow and
        re-saved in its own format into a fresh buffer.

        Raises:
            ValueError: if no JSON object or no usable ``images[0]`` string is present.
        """
        start = response_text.find("{")
        if start == -1:
            raise ValueError("No JSON object found in image API response")
        response_json, _ = json.JSONDecoder().raw_decode(response_text, start)
        images = response_json.get("images") or []
        if not images or not isinstance(images[0], str):
            raise ValueError("Image API response contains no image")
        content = images[0]
        # Strip any data-URL header such as "data:image/jpeg;base64,".
        if content.startswith("data:"):
            content = content.partition(",")[2]
        raw = base64.b64decode(content)
        # Re-save into a separate buffer: writing into the buffer Pillow is
        # reading from would append/overlap the re-encoded data with the original.
        out = io.BytesIO()
        with Image.open(io.BytesIO(raw)) as pil_image:
            pil_image.save(out, format=pil_image.format)
        out.seek(0)
        return out