feat: 完成基本功能

This commit is contained in:
Rock Chin
2022-12-08 12:06:04 +08:00
parent 9ca641a2e5
commit 77ec1c7ff0
5 changed files with 53 additions and 16 deletions

View File

@@ -1,5 +1,6 @@
mirai_http_api_config = { mirai_http_api_config = {
"host": "", "host": "",
"port": 8080,
"verifyKey": "", "verifyKey": "",
"qq": 0 "qq": 0
} }

17
main.py
View File

@@ -1,10 +1,13 @@
import os import os
import shutil import shutil
import sys import sys
import threading
import time
import pkg.openai.manager import pkg.openai.manager
import pkg.database.manager import pkg.database.manager
import pkg.openai.session import pkg.openai.session
import pkg.qqbot.manager
def init_db(): def init_db():
@@ -32,9 +35,23 @@ def main():
# 加载所有未超时的session # 加载所有未超时的session
pkg.openai.session.load_sessions() pkg.openai.session.load_sessions()
# 初始化qq机器人
qqbot = pkg.qqbot.manager.QQBotManager(mirai_http_api_config=config.mirai_http_api_config,
timeout=config.process_message_timeout, retry=config.retry_times)
qq_bot_thread = threading.Thread(target=qqbot.bot.run, args=(), daemon=True)
qq_bot_thread.start()
if __name__ == '__main__': if __name__ == '__main__':
if len(sys.argv) > 1 and sys.argv[1] == 'init_db': if len(sys.argv) > 1 and sys.argv[1] == 'init_db':
init_db() init_db()
sys.exit(0) sys.exit(0)
main() main()
while True:
try:
time.sleep(86400)
except KeyboardInterrupt:
print("程序退出")
break

View File

@@ -1,6 +1,7 @@
import time import time
import pymysql import pymysql
from pymysql.converters import escape_string
import config import config
@@ -61,12 +62,13 @@ class DatabaseManager:
insert into `sessions` (`name`, `type`, `number`, `create_timestamp`, `last_interact_timestamp`, `prompt`) insert into `sessions` (`name`, `type`, `number`, `create_timestamp`, `last_interact_timestamp`, `prompt`)
values ('{}', '{}', {}, {}, {}, '{}') values ('{}', '{}', {}, {}, {}, '{}')
""".format("{}_{}".format(subject_type, subject_number), subject_type, subject_number, create_timestamp, """.format("{}_{}".format(subject_type, subject_number), subject_type, subject_number, create_timestamp,
last_interact_timestamp, prompt)) last_interact_timestamp, escape_string(prompt)))
else: else:
self.cursor.execute(""" self.cursor.execute("""
update `sessions` set `last_interact_timestamp` = {}, `prompt` = '{}' update `sessions` set `last_interact_timestamp` = {}, `prompt` = '{}'
where `type` = '{}' and `number` = {} and `create_timestamp` = {} where `type` = '{}' and `number` = {} and `create_timestamp` = {}
""".format(last_interact_timestamp, prompt, subject_type, subject_number, create_timestamp)) """.format(last_interact_timestamp, escape_string(prompt), subject_type,
subject_number, create_timestamp))
# 记载还没过期的session数据 # 记载还没过期的session数据
def load_valid_sessions(self) -> dict: def load_valid_sessions(self) -> dict:

View File

@@ -88,3 +88,16 @@ class Session:
db_inst.persistence_session(subject_type, subject_number, self.create_timestamp, self.last_interact_timestamp, db_inst.persistence_session(subject_type, subject_number, self.create_timestamp, self.last_interact_timestamp,
self.prompt) self.prompt)
def reset(self):
if self.prompt != '':
self.persistence()
self.prompt = ''
self.create_timestamp = int(time.time())
self.last_interact_timestamp = 0
def last_session(self):
pass
def next_session(self):
pass

View File

@@ -2,7 +2,7 @@ from mirai import At, GroupMessage, MessageEvent, Mirai, Plain, StrangerMessage,
import pkg.openai.session import pkg.openai.session
from func_timeout import func_set_timeout, FunctionTimedOut from func_timeout import func_set_timeout, FunctionTimedOut
help_text = """ help_text = """帮助信息:
!help - 显示帮助 !help - 显示帮助
!reset - 重置会话 !reset - 重置会话
!last - 切换到上一次的对话 !last - 切换到上一次的对话
@@ -13,6 +13,7 @@ inst = None
processing = [] processing = []
class QQBotManager: class QQBotManager:
timeout = 60 timeout = 60
retry = 3 retry = 3
@@ -29,7 +30,7 @@ class QQBotManager:
adapter=WebSocketAdapter( adapter=WebSocketAdapter(
verify_key=mirai_http_api_config['verifyKey'], verify_key=mirai_http_api_config['verifyKey'],
host=mirai_http_api_config['host'], host=mirai_http_api_config['host'],
port=8080 port=mirai_http_api_config['port']
) )
) )
@@ -56,26 +57,27 @@ class QQBotManager:
reply = '' reply = ''
session_name = "{}_{}".format(launcher_type, launcher_id) session_name = "{}_{}".format(launcher_type, launcher_id)
if text_message.startswith('!'): # 指令 if text_message.startswith('!') or text_message.startswith(""): # 指令
cmd = text_message cmd = text_message[1:].strip()
if cmd == '!help': if cmd == 'help':
reply = help_text reply = "[bot]" + help_text
elif cmd == '!reset': elif cmd == 'reset':
pkg.openai.session.get_session(session_name).reset()
reply = "[bot]会话已重置"
elif cmd == 'last':
pass pass
elif cmd == '!last': elif cmd == 'next':
pass
elif cmd == '!next':
pass pass
else: # 消息 else: # 消息
session = pkg.openai.session.get_session(session_name) session = pkg.openai.session.get_session(session_name)
reply = session.append(text_message) reply = "[GPT]" + session.append(text_message)
return reply return reply
async def on_person_message(self, event: MessageEvent): async def on_person_message(self, event: MessageEvent):
if "person_{}".format(event.sender.id) in processing: if "person_{}".format(event.sender.id) in processing:
return return await self.bot.send(event, "err:正在处理中,请稍后再试")
reply = '' reply = ''
@@ -107,7 +109,7 @@ class QQBotManager:
async def on_group_message(self, event: GroupMessage): async def on_group_message(self, event: GroupMessage):
if "group_{}".format(event.group.id) in processing: if "group_{}".format(event.group.id) in processing:
return return await self.bot.send(event, "err:正在处理中,请稍后再试")
reply = '' reply = ''
@@ -116,13 +118,15 @@ class QQBotManager:
elif At(self.bot.qq) not in event.message_chain: elif At(self.bot.qq) not in event.message_chain:
pass pass
else: else:
event.message_chain.remove(At(self.bot.qq))
processing.append("group_{}".format(event.sender.id)) processing.append("group_{}".format(event.sender.id))
# 超时则重试,重试超过次数则放弃 # 超时则重试,重试超过次数则放弃
failed = 0 failed = 0
for i in range(self.retry): for i in range(self.retry):
try: try:
reply = self.process_message('group', event.group.id, str(event.message_chain)) reply = self.process_message('group', event.group.id, str(event.message_chain).strip())
break break
except FunctionTimedOut: except FunctionTimedOut:
failed += 1 failed += 1