"""Maubot plugin that turns chat prompts into images via the nexra.aryahcr.cc API."""
import json
import urllib
import requests
import io
import base64
import asyncio
import re
from PIL import Image
from typing import Type, Deque, Dict, Generator, List, Optional, Tuple

from mautrix.types import ImageInfo, EventType, MessageType, RelationType
from mautrix.types.event.message import BaseFileInfo, Format, TextMessageEventContent
from mautrix.util.config import BaseProxyConfig, ConfigUpdateHelper
from maubot import Plugin, MessageEvent
from maubot.handlers import event
from mautrix.client import Client

# Image endpoint; it receives {"model": ..., "prompt": ...} as JSON (see Dalle._call_dalle).
_NEXRA_URL = "https://nexra.aryahcr.cc/api/image/complements"
# Seconds to wait for the image API. requests has no default timeout, so without
# this a stalled server would hang the worker thread forever.
_REQUEST_TIMEOUT = 120


class Config(BaseProxyConfig):
    """Plugin configuration; carries these keys over when the config is upgraded."""

    def do_update(self, helper: ConfigUpdateHelper) -> None:
        # NOTE(review): "respond_to_notice" is copied but never read by the plugin;
        # notices are always ignored in should_respond.
        helper.copy("respond_to_notice")
        helper.copy("model")
        helper.copy("name")
        helper.copy("nickname")


class Dalle(Plugin):
    """Reply to messages that address the bot (or to any DM) with a generated image."""

    name: str  # matched in message bodies and against mentioned user IDs
    nickname: str  # matched in message bodies; also what gets stripped for pill mentions

    async def start(self) -> None:
        self.config.load_and_update()
        self.name = self.config["name"]
        self.nickname = self.config["nickname"]

    @classmethod
    def get_config_class(cls) -> Type[BaseProxyConfig]:
        return Config

    @staticmethod
    def _text_addresses(name: Optional[str], text: Optional[str]) -> bool:
        """Return True if ``name`` (optionally prefixed by ``@``) appears as a word in ``text``.

        The name must start the text or follow whitespace, and be followed by
        one of `` :,.!?`` or the end of the text. Matching is case-insensitive
        and the name is treated literally (regex metacharacters are escaped).
        """
        if not name or not text:
            # An empty name would otherwise match almost any text.
            return False
        pattern = r"(^|\s)@?" + re.escape(name) + r"([ :,.!?]|$)"
        return re.search(pattern, text, re.IGNORECASE) is not None

    @staticmethod
    def _mentioned_user_ids(event: MessageEvent) -> List[str]:
        """Return the user IDs listed in the event's ``m.mentions``, or [] if absent/malformed."""
        try:
            if "m.mentions" not in event.content:
                return []
            user_ids = event.content["m.mentions"]["user_ids"]
        except (KeyError, TypeError, AttributeError):
            # e.g. Element sends "m.mentions": {} for messages that mention nobody.
            return []
        return [uid for uid in (user_ids or []) if isinstance(uid, str)]

    def _find_addressed_name(self, event: MessageEvent) -> Optional[str]:
        """Return the name to strip from the body if the bot is addressed, else None.

        Checks, in order:
        * explicit ``m.mentions`` user IDs (Element / Element X pills), which
          return ``self.nickname`` because that is the pill text in the body;
        * ``self.name`` in the body (most other clients);
        * ``self.nickname`` in the body.
        """
        for uid in self._mentioned_user_ids(event):
            if uid == self.client.mxid or self._text_addresses(self.name, uid):
                return self.nickname
        body = event.content.body or ""
        for candidate in (self.name, self.nickname):
            if self._text_addresses(candidate, body):
                return candidate
        return None

    def _extract_prompt(self, event: MessageEvent) -> str:
        """Return the message body with the bot's name removed, for use as the prompt."""
        body = event.content.body or ""
        target = self._find_addressed_name(event)
        if target:
            # Case-insensitive, because detection is case-insensitive too.
            body = re.sub(re.escape(target), "", body, flags=re.IGNORECASE)
        # Drop leftovers of the addressing, e.g. "dalle: a cat" -> "a cat".
        return body.strip().lstrip(":,").strip()

    async def should_respond(self, event: MessageEvent) -> bool:
        """Decide whether ``event`` is a text message the bot should answer.

        Ignores the bot's own messages, "!" commands, non-text messages and
        edits. Answers when the bot is addressed (see _find_addressed_name) or
        when the room has exactly two joined members.
        """
        content = event.content
        if event.sender == self.client.mxid:
            return False
        if (content.body or "").startswith("!"):
            return False
        if content.msgtype != MessageType.TEXT:
            return False
        if getattr(content.relates_to, "rel_type", None) == RelationType.REPLACE:
            return False

        if self._find_addressed_name(event) is not None:
            return True

        # Reply to all DMs (rooms with exactly two joined members).
        if len(await self.client.get_joined_members(event.room_id)) == 2:
            return True
        return False

    @staticmethod
    def _describe_image(data: bytes) -> Tuple[str, Optional[int], Optional[int]]:
        """Return (mime type, width, height) of encoded image ``data``.

        Falls back to ("image/jpeg", None, None) if Pillow cannot identify it.
        """
        try:
            with Image.open(io.BytesIO(data)) as img:
                mime = Image.MIME.get(img.format or "", "image/jpeg")
                return mime, img.width, img.height
        except OSError:
            return "image/jpeg", None, None

    @event.on(EventType.ROOM_MESSAGE)
    async def on_message(self, event: MessageEvent) -> None:
        """Generate an image for an addressed message and post it to the room."""
        if not await self.should_respond(event):
            return
        prompt = self._extract_prompt(event)
        if not prompt:
            return

        try:
            await event.mark_read()
            await self.client.set_typing(event.room_id, timeout=99999)
            media = await self._call_dalle(prompt)
            if media is None:
                # _call_dalle already logged why.
                return
            raw = media.getvalue()
            mime, width, height = self._describe_image(raw)
            # Send the response back to the chat room; the prompt doubles as file name.
            uri = await self.client.upload_media(raw, mime_type=mime, filename=prompt)
            await self.client.send_file(
                event.room_id,
                url=uri,
                info=ImageInfo(mimetype=mime, size=len(raw), width=width, height=height),
                file_name=prompt,
                file_type=MessageType.IMAGE,
            )
        except Exception as e:
            # Top-level handler boundary: log with traceback, keep the bot alive.
            self.log.exception(f"We have failed somewhere? {e}")
        finally:
            # Always clear the typing indicator, whether we succeeded or not.
            await self.client.set_typing(event.room_id, timeout=0)

    async def _call_dalle(self, prompt: str) -> Optional[io.BytesIO]:
        """Request an image for ``prompt`` and return its encoded bytes.

        Returns None (after logging a warning) on a non-200 status or when the
        response holds no JSON object or no images. Raises if the payload is
        not a decodable image.
        """
        headers = {"Content-Type": "application/json"}
        payload = json.dumps({"model": self.config["model"], "prompt": prompt})

        # requests is synchronous; run it in a worker thread so the event
        # loop (and every other room the bot serves) is not blocked.
        loop = asyncio.get_running_loop()
        response = await loop.run_in_executor(
            None,
            lambda: requests.post(
                _NEXRA_URL, headers=headers, data=payload, timeout=_REQUEST_TIMEOUT
            ),
        )
        if response.status_code != 200:
            self.log.warning(
                f"Unexpected status sending request to nexra.aryahcr.cc: {response.status_code}"
            )
            return None

        # The API may prefix its JSON with non-JSON text; parse from the first "{".
        text = response.text
        start = text.find("{")
        if start == -1:
            self.log.warning("No JSON object in response from nexra.aryahcr.cc")
            return None
        response_json = json.loads(text[start:])

        images = response_json.get("images") or []
        if not images:
            self.log.warning("Response from nexra.aryahcr.cc contained no images")
            return None

        image_data = images[0]
        if image_data.startswith("data:"):
            # Data URI "data:<mime>;base64,<payload>" -- keep only the payload.
            _, _, image_data = image_data.partition(",")
        raw = base64.b64decode(image_data)

        # Validate the bytes are an image before uploading; verify() raises on corruption.
        with Image.open(io.BytesIO(raw)) as img:
            img.verify()
        return io.BytesIO(raw)