diff --git a/base-config.yaml b/base-config.yaml index 45b39df..8680723 100644 --- a/base-config.yaml +++ b/base-config.yaml @@ -3,3 +3,5 @@ name: "@dalle:boehm.sh" nickname: "DALL-E" model: "dalle" reply_in_thread: True +redis_host: "192.168.1.242" +redis_port: "6379" diff --git a/gpt/__init__.py b/gpt/__init__.py new file mode 100644 index 0000000..22acd5e --- /dev/null +++ b/gpt/__init__.py @@ -0,0 +1 @@ +from .gpt import Gpt diff --git a/gpt/gpt.py b/gpt/gpt.py new file mode 100644 index 0000000..800f8ab --- /dev/null +++ b/gpt/gpt.py @@ -0,0 +1,140 @@ +import json +import urllib +import requests +import io +import base64 +import asyncio +import requests +import re + +from PIL import Image +from typing import Type, Deque, Dict, Generator +from mautrix.types import ImageInfo, EventType, MessageType, RelationType +from mautrix.types.event.message import BaseFileInfo, Format, TextMessageEventContent +from mautrix.util.config import BaseProxyConfig, ConfigUpdateHelper +from mautrix.util import markdown +from maubot import Plugin, MessageEvent +from maubot.handlers import event +from mautrix.client import Client +from .history import History + +class Config(BaseProxyConfig): + def do_update(self, helper: ConfigUpdateHelper) -> None: + helper.copy("respond_to_notice") + helper.copy("reply_in_thread") + helper.copy("model") + helper.copy("name") + helper.copy("nickname") + helper.copy("redis_host") + helper.copy("redis_port") + +class Gpt(Plugin): + + name: str + nickname: str + history: History + + async def start(self) -> None: + self.config.load_and_update() + self.name = self.config["name"] + self.nickname = self.config["nickname"] + self.history = History(self.config["redis_host"], self.config["redis_port"]) + + @classmethod + def get_config_class(cls) -> Type[BaseProxyConfig]: + return Config + + async def should_respond(self, event: MessageEvent) -> bool: + if (event.sender == self.client.mxid or + event.content.body.startswith("!") or + event.content["msgtype"] != MessageType.TEXT or + event.content.relates_to["rel_type"] == RelationType.REPLACE): + return False + + # Check if user is using element + # Check if bot is mentioned + if "m.mentions" in event.content: + # case if element x and desktop element + if re.search("(^|\s)(@)?" + self.name + "([ :,.!?]|$)", event.content["m.mentions"]["user_ids"][0], re.IGNORECASE): + return True + # most other clients + elif re.search("(^|\s)(@)?" + self.name + "([ :,.!?]|$)", event.content.body, re.IGNORECASE): + return True + # element + elif re.search("(^|\s)(@)?" + self.nickname + "([ :,.!?]|$)", event.content.body, re.IGNORECASE): + return True + + # Reply to all DMs + if len(await self.client.get_joined_members(event.room_id)) == 2: + return True + + return False + + @event.on(EventType.ROOM_MESSAGE) + async def on_message(self, event: MessageEvent) -> None: + if not await self.should_respond(event): + return + + if "m.mentions" in event.content: + # case if element x and desktop element + if re.search("(^|\s)(@)?" + self.name + "([ :,.!?]|$)", event.content["m.mentions"]["user_ids"][0], re.IGNORECASE): + event.content.body = event.content.body.replace(self.nickname, "") + # most other clients + elif re.search("(^|\s)(@)?" + self.name + "([ :,.!?]|$)", event.content.body, re.IGNORECASE): + event.content.body = event.content.body.replace(self.name, "") + # element + elif re.search(self.nickname + "([ :,.!?]|$)", event.content.body, re.IGNORECASE): + event.content.body = event.content.body.replace(self.nickname, "") + + + try: + await event.mark_read() + + # Call the GPT API to get picture + await self.client.set_typing(event.room_id, timeout=99999) + messages = await self.history.get(event) + print(messages) + response = await self._call_gpt(event.content["body"], messages) + + + # Send the repond back to the chat room + content = TextMessageEventContent(msgtype=MessageType.NOTICE, body=response, format=Format.HTML, formatted_body=markdown.render(response)) + await event.respond(content, in_thread=self.config['reply_in_thread']) + + # Reset our typing status + await self.client.set_typing(event.room_id, timeout=0) + + # Send this to our cache + await self.history.add(event, event.content["body"], response) + except Exception as e: + await self.client.set_typing(event.room_id, timeout=0) + self.log.exception(f"We have failed somewhere? {e}") + pass + + async def _call_gpt(self, prompt, messages): + headers = { "Content-Type": "application/json" } + data = { + "messages": messages, + "model": self.config["model"], + "prompt": prompt + } + + response = requests.post("https://nexra.aryahcr.cc/api/chat/gpt", headers=headers, data=json.dumps(data)) + + if response.status_code != 200: + self.log.warning(f"Unexpected status sending request to nexra.aryahcr.cc: {response.status_code}") + return + + count = -1 + for i in range(len(response.text)): + if count <= -1: + if response.text[i] == "{": + count = i + else: + break + response_json = json.loads(response.text[count:]) + + content = response_json['gpt'] + + return content + diff --git a/gpt/history.py b/gpt/history.py new file mode 100644 index 0000000..0faaea7 --- /dev/null +++ b/gpt/history.py @@ -0,0 +1,41 @@ +import redis.asyncio as redis +import json + +class History(): + r: redis.Redis() + + def __init__(self, redis_host, redis_port): + self.r = redis.Redis(host=redis_host, port=redis_port, db=0) + + def createCacheKey(self, id): + return f"gpt-history-user-{id}" + + async def reset(self, userData): + key = self.createCacheKey(userData["sender"]) + await self.r.mset(key, null) + + async def get(self, userData): + key = self.createCacheKey(userData["sender"]) + history = await self.r.get(key) + + if history is not None: + history = json.loads(history) + return history + else: + return [] + + + async def add(self, userData, userMessage, assistantMessage): + history = await self.get(userData) + history.append({ "role": "user", "content": userMessage }) + history.append({ "role": "assistant", "content": assistantMessage }) + + key = self.createCacheKey(userData["sender"]) + history = json.dumps(history) + await self.r.psetex(key, 300_000, history) # 5 mins + + async def dump(self, userData): + history = await self.get(userData) + if (len(history) == 0): + return "You have no ChatGPT history at the moment." + # Going to add this at a later point via haste.snrd.eu diff --git a/maubot.yaml b/maubot.yaml index 68793cc..8695415 100644 --- a/maubot.yaml +++ b/maubot.yaml @@ -1,10 +1,10 @@ maubot: 0.1.0 id: sh.boehm.gpt -version: 0.0.004 +version: 0.0.008 license: MIT modules: - gpt -main_class: gpt/Gpt +main_class: Gpt database: false config: true extra_files: