From 3c0b9664ffbc4d7d9a567c016c68f3fd9b83547b Mon Sep 17 00:00:00 2001 From: Philipp Date: Sat, 6 Apr 2024 18:23:45 +0200 Subject: [PATCH] first commit --- base-config.yaml | 4 ++ dalle.py | 135 +++++++++++++++++++++++++++++++++++++++++++++++ maubot.yaml | 11 ++++ 3 files changed, 150 insertions(+) create mode 100644 base-config.yaml create mode 100644 dalle.py create mode 100644 maubot.yaml diff --git a/base-config.yaml b/base-config.yaml new file mode 100644 index 0000000..d2aa33b --- /dev/null +++ b/base-config.yaml @@ -0,0 +1,4 @@ +respond_to_notice: False +name: "@dalle:boehm.sh" +nickname: "DALL-E" +model: "dalle" diff --git a/dalle.py b/dalle.py new file mode 100644 index 0000000..8eb2546 --- /dev/null +++ b/dalle.py @@ -0,0 +1,135 @@ +import json +import urllib +import requests +import io +import base64 +import asyncio +import requests +import re + +from PIL import Image +from typing import Type, Deque, Dict, Generator +from mautrix.types import ImageInfo, EventType, MessageType, RelationType +from mautrix.types.event.message import BaseFileInfo, Format, TextMessageEventContent +from mautrix.util.config import BaseProxyConfig, ConfigUpdateHelper +from maubot import Plugin, MessageEvent +from maubot.handlers import event +from mautrix.client import Client + +class Config(BaseProxyConfig): + def do_update(self, helper: ConfigUpdateHelper) -> None: + helper.copy("respond_to_notice") + helper.copy("model") + helper.copy("name") + helper.copy("nickname") + +class Dalle(Plugin): + + name: str + nickname: str + + async def start(self) -> None: + self.config.load_and_update() + self.name = self.config["name"] + self.nickname = self.config["nickname"] + + @classmethod + def get_config_class(cls) -> Type[BaseProxyConfig]: + return Config + + async def should_respond(self, event: MessageEvent) -> bool: + if (event.sender == self.client.mxid or + event.content.body.startswith("!") or + event.content["msgtype"] != MessageType.TEXT or + event.content.relates_to["rel_type"] == RelationType.REPLACE): + return False + + # Check if user is using element + # Check if bot is mentioned + if "m.mentions" in event.content: + # case if element x and desktop element + if re.search("(^|\s)(@)?" + self.name + "([ :,.!?]|$)", event.content["m.mentions"]["user_ids"][0], re.IGNORECASE): + return True + # most other clients + elif re.search("(^|\s)(@)?" + self.name + "([ :,.!?]|$)", event.content.body, re.IGNORECASE): + return True + # element + elif re.search("(^|\s)(@)?" + self.nickname + "([ :,.!?]|$)", event.content.body, re.IGNORECASE): + return True + + # Reply to all DMs + if len(await self.client.get_joined_members(event.room_id)) == 2: + return True + + return False + + @event.on(EventType.ROOM_MESSAGE) + async def on_message(self, event: MessageEvent) -> None: + if not await self.should_respond(event): + return + + if "m.mentions" in event.content: + # case if element x and desktop element + if re.search("(^|\s)(@)?" + self.name + "([ :,.!?]|$)", event.content["m.mentions"]["user_ids"][0], re.IGNORECASE): + event.content.body = event.content.body.replace(self.nickname, "") + # most other clients + elif re.search("(^|\s)(@)?" + self.name + "([ :,.!?]|$)", event.content.body, re.IGNORECASE): + event.content.body = event.content.body.replace(self.name, "") + # element + elif re.search(self.nickname + "([ :,.!?]|$)", event.content.body, re.IGNORECASE): + event.content.body = event.content.body.replace(self.nickname, "") + + + try: + await event.mark_read() + + # Call the Dalle API to get picture + await self.client.set_typing(event.room_id, timeout=99999) + media = await self._call_dalle(event.content["body"]) + + file_name = event.content["body"] + + # Send the repond back to the chat room + uri = await self.client.upload_media(media, mime_type="image/jpeg", filename=file_name) + await self.client.send_file(event.room_id, url=uri, info=BaseFileInfo(mimetype="image/jpeg"), file_name=file_name, file_type=MessageType.IMAGE) + + # Reset out typing status + await self.client.set_typing(event.room_id, timeout=0) + except Exception as e: + await self.client.set_typing(event.room_id, timeout=0) + self.log.exception(f"We have failed somewhere? {e}") + pass + + async def _call_dalle(self, prompt): + headers = { "Content-Type": "application/json" } + data = { + "model": self.config["model"], + "prompt": prompt + } + + response = requests.post("https://nexra.aryahcr.cc/api/image/complements", headers=headers, data=json.dumps(data)) + + if response.status_code != 200: + self.log.warning(f"Unexpected status sending request to nexra.aryahcr.cc: {response.status_code}") + return + + count = -1 + for i in range(len(response.text)): + if count <= -1: + if response.text[i] == "{": + count = i + else: + break + response_json = json.loads(response.text[count:]) + + content = response_json['images'][0] + + base64_str = content.replace("data:image/jpeg;base64,", "") + mem_file = io.BytesIO(base64.decodebytes(bytes(base64_str, "utf-8"))) + + pil_image = Image.open(mem_file) + pil_image.save(mem_file, format=pil_image.format) + mem_file.seek(0) + + return mem_file + diff --git a/maubot.yaml b/maubot.yaml new file mode 100644 index 0000000..88a283a --- /dev/null +++ b/maubot.yaml @@ -0,0 +1,11 @@ +maubot: 0.1.0 +id: sh.boehm.dalle +version: 0.0.029 +license: MIT +modules: + - dalle +main_class: dalle/Dalle +database: false +config: true +extra_files: + - base-config.yaml