first commit
This commit is contained in:
commit
3c0b9664ff
3 changed files with 150 additions and 0 deletions
4
base-config.yaml
Normal file
4
base-config.yaml
Normal file
|
@ -0,0 +1,4 @@
|
|||
respond_to_notice: False
|
||||
name: "@dalle:boehm.sh"
|
||||
nickname: "DALL-E"
|
||||
model: "dalle"
|
135
dalle.py
Normal file
135
dalle.py
Normal file
|
@ -0,0 +1,135 @@
|
|||
import json
|
||||
import urllib
|
||||
import requests
|
||||
import io
|
||||
import base64
|
||||
import asyncio
|
||||
import requests
|
||||
import re
|
||||
|
||||
from PIL import Image
|
||||
from typing import Type, Deque, Dict, Generator
|
||||
from mautrix.types import ImageInfo, EventType, MessageType, RelationType
|
||||
from mautrix.types.event.message import BaseFileInfo, Format, TextMessageEventContent
|
||||
from mautrix.util.config import BaseProxyConfig, ConfigUpdateHelper
|
||||
from maubot import Plugin, MessageEvent
|
||||
from maubot.handlers import event
|
||||
from mautrix.client import Client
|
||||
|
||||
class Config(BaseProxyConfig):
|
||||
def do_update(self, helper: ConfigUpdateHelper) -> None:
|
||||
helper.copy("respond_to_notice")
|
||||
helper.copy("model")
|
||||
helper.copy("name")
|
||||
helper.copy("nickname")
|
||||
|
||||
class Dalle(Plugin):
|
||||
|
||||
name: str
|
||||
nickname: str
|
||||
|
||||
async def start(self) -> None:
|
||||
self.config.load_and_update()
|
||||
self.name = self.config["name"]
|
||||
self.nickname = self.config["nickname"]
|
||||
|
||||
@classmethod
|
||||
def get_config_class(cls) -> Type[BaseProxyConfig]:
|
||||
return Config
|
||||
|
||||
async def should_respond(self, event: MessageEvent) -> bool:
|
||||
if (event.sender == self.client.mxid or
|
||||
event.content.body.startswith("!") or
|
||||
event.content["msgtype"] != MessageType.TEXT or
|
||||
event.content.relates_to["rel_type"] == RelationType.REPLACE):
|
||||
return False
|
||||
|
||||
# Check if user is using element
|
||||
# Check if bot is mentioned
|
||||
if "m.mentions" in event.content:
|
||||
# case if element x and desktop element
|
||||
if re.search("(^|\s)(@)?" + self.name + "([ :,.!?]|$)", event.content["m.mentions"]["user_ids"][0], re.IGNORECASE):
|
||||
return True
|
||||
# most other clients
|
||||
elif re.search("(^|\s)(@)?" + self.name + "([ :,.!?]|$)", event.content.body, re.IGNORECASE):
|
||||
return True
|
||||
# element
|
||||
elif re.search("(^|\s)(@)?" + self.nickname + "([ :,.!?]|$)", event.content.body, re.IGNORECASE):
|
||||
return True
|
||||
|
||||
# Reply to all DMs
|
||||
if len(await self.client.get_joined_members(event.room_id)) == 2:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
@event.on(EventType.ROOM_MESSAGE)
|
||||
async def on_message(self, event: MessageEvent) -> None:
|
||||
if not await self.should_respond(event):
|
||||
return
|
||||
|
||||
if "m.mentions" in event.content:
|
||||
# case if element x and desktop element
|
||||
if re.search("(^|\s)(@)?" + self.name + "([ :,.!?]|$)", event.content["m.mentions"]["user_ids"][0], re.IGNORECASE):
|
||||
event.content.body = event.content.body.replace(self.nickname, "")
|
||||
# most other clients
|
||||
elif re.search("(^|\s)(@)?" + self.name + "([ :,.!?]|$)", event.content.body, re.IGNORECASE):
|
||||
event.content.body = event.content.body.replace(self.name, "")
|
||||
# element
|
||||
elif re.search(self.nickname + "([ :,.!?]|$)", event.content.body, re.IGNORECASE):
|
||||
event.content.body = event.content.body.replace(self.nickname, "")
|
||||
|
||||
|
||||
try:
|
||||
await event.mark_read()
|
||||
|
||||
# Call the Dalle API to get picture
|
||||
await self.client.set_typing(event.room_id, timeout=99999)
|
||||
media = await self._call_dalle(event.content["body"])
|
||||
|
||||
file_name = event.content["body"]
|
||||
|
||||
# Send the repond back to the chat room
|
||||
uri = await self.client.upload_media(media, mime_type="image/jpeg", filename=file_name)
|
||||
await self.client.send_file(event.room_id, url=uri, info=BaseFileInfo(mimetype="image/jpeg"), file_name=file_name, file_type=MessageType.IMAGE)
|
||||
|
||||
# Reset out typing status
|
||||
await self.client.set_typing(event.room_id, timeout=0)
|
||||
except Exception as e:
|
||||
await self.client.set_typing(event.room_id, timeout=0)
|
||||
self.log.exception(f"We have failed somewhere? {e}")
|
||||
pass
|
||||
|
||||
async def _call_dalle(self, prompt):
|
||||
headers = { "Content-Type": "application/json" }
|
||||
data = {
|
||||
"model": self.config["model"],
|
||||
"prompt": prompt
|
||||
}
|
||||
|
||||
response = requests.post("https://nexra.aryahcr.cc/api/image/complements", headers=headers, data=json.dumps(data))
|
||||
|
||||
if response.status_code != 200:
|
||||
self.log.warning(f"Unexpected status sending request to nexra.aryahcr.cc: {response.status_code}")
|
||||
return
|
||||
|
||||
count = -1
|
||||
for i in range(len(response.text)):
|
||||
if count <= -1:
|
||||
if response.text[i] == "{":
|
||||
count = i
|
||||
else:
|
||||
break
|
||||
response_json = json.loads(response.text[count:])
|
||||
|
||||
content = response_json['images'][0]
|
||||
|
||||
base64_str = content.replace("data:image/jpeg;base64,", "")
|
||||
mem_file = io.BytesIO(base64.decodebytes(bytes(base64_str, "utf-8")))
|
||||
|
||||
pil_image = Image.open(mem_file)
|
||||
pil_image.save(mem_file, format=pil_image.format)
|
||||
mem_file.seek(0)
|
||||
|
||||
return mem_file
|
||||
|
11
maubot.yaml
Normal file
11
maubot.yaml
Normal file
|
@ -0,0 +1,11 @@
|
|||
maubot: 0.1.0
|
||||
id: sh.boehm.dalle
|
||||
version: 0.0.029
|
||||
license: MIT
|
||||
modules:
|
||||
- dalle
|
||||
main_class: dalle/Dalle
|
||||
database: false
|
||||
config: true
|
||||
extra_files:
|
||||
- base-config.yaml
|
Loading…
Reference in a new issue