适配线程版本

This commit is contained in:
chordfish
2023-03-09 17:56:57 +08:00
parent 357d6aaf75
commit d7d9d88e16
9 changed files with 132 additions and 83 deletions

View File

@@ -93,12 +93,15 @@ filter_ai_warning = False
# 群内响应规则
# 符合此消息的群内消息即使不包含at机器人也会响应
# 支持消息前缀匹配及正则表达式匹配
# 支持设置是否响应at消息、随机响应概率
# 注意:由消息前缀(prefix)匹配的消息中将会删除此前缀,正则表达式(regexp)匹配的消息不会删除匹配的部分
# 前缀匹配优先级高于正则表达式匹配
# 正则表达式简明教程https://www.runoob.com/regexp/regexp-tutorial.html
response_rules = {
"at": True, # 是否响应at机器人的消息
"prefix": ["/ai", "!ai", "ai", "ai"],
"regexp": [] # "为什么.*", "怎么?样.*", "怎么.*", "如何.*", "[Hh]ow to.*", "[Ww]hy not.*", "[Ww]hat is.*", ".*怎么办", ".*咋办"
"regexp": [], # "为什么.*", "怎么?样.*", "怎么.*", "如何.*", "[Hh]ow to.*", "[Ww]hy not.*", "[Ww]hat is.*", ".*怎么办", ".*咋办"
"random_rate": 0.0, # 随机响应概率0.0-1.00.0为不随机响应1.0为响应所有消息, 仅在前几项判断不通过时生效
}
# 消息忽略规则
@@ -213,6 +216,11 @@ hide_exce_info_to_user = False
# 设置为空字符串时,不发送提示信息
alter_tip_message = '出错了,请稍后再试'
# 机器人线程池大小
# 该参数决定机器人可以同时处理几个人的消息,超出线程池数量的请求会被阻塞,不会被丢弃
# 如果你不清楚该参数的意义,请不要更改
pool_num = 10
# 每个会话的过期时间,单位为秒
# 默认值20分钟
session_expire_time = 60 * 20

69
main.py
View File

@@ -45,7 +45,9 @@ def init_db():
def ensure_dependencies():
import pkg.utils.pkgmgr as pkgmgr
pkgmgr.run_pip(["install", "openai", "Pillow", "--upgrade"])
pkgmgr.run_pip(["install", "openai", "Pillow", "--upgrade",
"-i", "https://pypi.douban.com/simple/",
"--trusted-host", "pypi.douban.com"])
known_exception_caught = False
@@ -105,6 +107,8 @@ def reset_logging():
def main(first_time_init=False):
"""启动流程reload之后会被执行"""
global known_exception_caught
import config
@@ -127,13 +131,26 @@ def main(first_time_init=False):
config = importlib.import_module('config')
import pkg.utils.context
pkg.utils.context.set_config(config)
init_runtime_log_file()
sh = reset_logging()
# 配置完整性校验
is_integrity = True
config_template = importlib.import_module('config-template')
for key in dir(config_template):
if not key.startswith("__") and not hasattr(config, key):
setattr(config, key, getattr(config_template, key))
logging.warning("[{}]不存在".format(key))
is_integrity = False
if not is_integrity:
logging.warning("配置文件不完整请依据config-template.py检查config.py")
logging.warning("以上配置已被设为默认值将在5秒后继续启动... ")
time.sleep(5)
import pkg.utils.context
pkg.utils.context.set_config(config)
# 检查是否设置了管理员
if not (hasattr(config, 'admin_qq') and config.admin_qq != 0):
# logging.warning("未设置管理员QQ,管理员权限指令及运行告警将无法使用,如需设置请修改config.py中的admin_qq字段")
@@ -180,7 +197,7 @@ def main(first_time_init=False):
# 初始化qq机器人
qqbot = pkg.qqbot.manager.QQBotManager(mirai_http_api_config=config.mirai_http_api_config,
timeout=config.process_message_timeout, retry=config.retry_times,
first_time_init=first_time_init)
first_time_init=first_time_init, pool_num=config.pool_num)
# 加载插件
import pkg.plugin.host
@@ -188,7 +205,7 @@ def main(first_time_init=False):
pkg.plugin.host.initialize_plugins()
if first_time_init: # 不是热重载之后的启动,则启动新的bot线程
if first_time_init: # 不是热重载之后的启动,则启动新的bot线程
import mirai.exceptions
@@ -277,17 +294,7 @@ def main(first_time_init=False):
except Exception as e:
logging.warning("检查更新失败:{}".format(e))
while True:
try:
time.sleep(10)
if qqbot != pkg.utils.context.get_qqbot_manager(): # 已经reload了
logging.info("以前的main流程由于reload退出")
break
except KeyboardInterrupt:
stop()
print("程序退出")
sys.exit(0)
return qqbot
def stop():
@@ -340,19 +347,9 @@ if __name__ == '__main__':
sys.exit(0)
elif len(sys.argv) > 1 and sys.argv[1] == 'update':
try:
try:
import pkg.utils.pkgmgr
pkg.utils.pkgmgr.ensure_dulwich()
except:
pass
from dulwich import porcelain
repo = porcelain.open_repo('.')
porcelain.pull(repo)
except ModuleNotFoundError:
print("dulwich模块未安装,请查看 https://github.com/RockChinQ/QChatGPT/issues/77")
print("正在进行程序更新...")
import pkg.utils.updater as updater
updater.update_all(cli=True)
sys.exit(0)
# import pkg.utils.configmgr
@@ -360,4 +357,14 @@ if __name__ == '__main__':
# pkg.utils.configmgr.set_config_and_reload("quote_origin", False)
requests.packages.urllib3.disable_warnings(InsecureRequestWarning)
main(True)
qqbot = main(True)
import pkg.utils.context
while True:
try:
time.sleep(10)
except KeyboardInterrupt:
stop()
print("程序退出")
sys.exit(0)

Binary file not shown.

View File

@@ -4,7 +4,6 @@
"""
import logging
import os
import threading
import time
import json
@@ -21,8 +20,6 @@ import pkg.plugin.models as plugin_models
sessions = {}
class SessionOfflineStatus:
ON_GOING = 'on_going'
EXPLICITLY_CLOSED = 'explicitly_closed'
@@ -189,8 +186,6 @@ class Session:
self.schedule()
self.response_lock = threading.Lock()
self.bot_name = 'ai'
self.bot_filter = None
self.prompt = self.get_default_prompt()
# 设定检查session最后一次对话是否超过过期时间的计时器
@@ -235,7 +230,7 @@ class Session:
self.last_interact_timestamp = int(time.time())
# 触发插件事件
if self.prompt == self.get_default_prompt(get_only=True):
if self.prompt == self.get_default_prompt(get_only = True):
args = {
'session_name': self.name,
'session': self,
@@ -264,23 +259,10 @@ class Session:
del (res_ans_spt[0])
res_ans = '\n\n'.join(res_ans_spt)
#检测是否包含ai人格否定
logging.debug('bot_filter: {}'.format(self.bot_filter))
if config.filter_ai_warning and self.bot_filter:
import re
match = re.search(self.bot_filter['reg'], res_ans)
logging.debug(self.bot_filter)
logging.debug(res_ans)
if match:
logging.debug('回复:{} 检测到人格否定,替换中。。'.format(res_ans))
res_ans = self.bot_filter['replace']
logging.debug('替换为: {}'.format(res_ans))
# 将此次对话的双方内容加入到prompt中
self.prompt.append({'role': 'user', 'content': text})
self.prompt.append({'role': 'assistant', 'content': res_ans})
if self.just_switched_to_exist_session:
self.just_switched_to_exist_session = False
self.set_ongoing()
@@ -329,7 +311,7 @@ class Session:
# 持久化session
def persistence(self):
if self.prompt == self.get_default_prompt(get_only=True):
if self.prompt == self.get_default_prompt(get_only = True):
return
db_inst = pkg.utils.context.get_database_manager()

View File

@@ -7,6 +7,8 @@ import logging
class ReplyFilter:
sensitive_words = []
mask = "*"
mask_word = ""
# 默认值( 兼容性考虑 )
baidu_check = False
@@ -14,8 +16,10 @@ class ReplyFilter:
baidu_secret_key = ""
inappropriate_message_tips = "[百度云]请珍惜机器人,当前返回内容不合规"
def __init__(self, sensitive_words: list):
def __init__(self, sensitive_words: list, mask: str = "*", mask_word: str = ""):
self.sensitive_words = sensitive_words
self.mask = mask
self.mask_word = mask_word
import config
if hasattr(config, 'baidu_check') and hasattr(config, 'baidu_api_key') and hasattr(config, 'baidu_secret_key'):
self.baidu_check = config.baidu_check
@@ -36,7 +40,10 @@ class ReplyFilter:
match = re.findall(word, message)
if len(match) > 0:
for i in range(len(match)):
message = message.replace(match[i], "*" * len(match[i]))
if self.mask_word == "":
message = message.replace(match[i], self.mask * len(match[i]))
else:
message = message.replace(match[i], self.mask_word)
# 百度云审核
if self.baidu_check:

View File

@@ -2,6 +2,7 @@ import asyncio
import json
import os
import threading
from concurrent.futures import ThreadPoolExecutor
import mirai.models.bus
from mirai import At, GroupMessage, MessageEvent, Mirai, StrangerMessage, WebSocketAdapter, HTTPAdapter, \
@@ -21,12 +22,6 @@ import pkg.plugin.host as plugin_host
import pkg.plugin.models as plugin_models
# 并行运行
def go(func, args=()):
thread = threading.Thread(target=func, args=args, daemon=True)
thread.start()
# 检查消息是否符合泛响应匹配机制
def check_response_rule(text: str, event):
config = pkg.utils.context.get_config()
@@ -41,7 +36,6 @@ def check_response_rule(text: str, event):
import re
if re.search(bot_name, text):
return True, text
rules = config.response_rules
# 检查前缀匹配
if 'prefix' in rules:
@@ -60,10 +54,29 @@ def check_response_rule(text: str, event):
return False, ""
def response_at():
config = pkg.utils.context.get_config()
if 'at' not in config.response_rules:
return True
return config.response_rules['at']
def random_responding():
config = pkg.utils.context.get_config()
if 'random_rate' in config.response_rules:
import random
return random.random() < config.response_rules['random_rate']
return False
# 控制QQ消息输入输出的类
class QQBotManager:
retry = 3
#线程池控制
pool = None
bot: Mirai = None
reply_filter = None
@@ -73,11 +86,14 @@ class QQBotManager:
ban_person = []
ban_group = []
def __init__(self, mirai_http_api_config: dict, timeout: int = 60, retry: int = 3, first_time_init=True):
def __init__(self, mirai_http_api_config: dict, timeout: int = 60, retry: int = 3, pool_num: int = 10, first_time_init=True):
self.timeout = timeout
self.retry = retry
self.pool_num = pool_num
self.pool = ThreadPoolExecutor(max_workers=self.pool_num)
logging.debug("Registered thread pool Size:{}".format(pool_num))
# 加载禁用列表
if os.path.exists("banlist.py"):
import banlist
@@ -91,7 +107,12 @@ class QQBotManager:
and config.sensitive_word_filter is not None \
and config.sensitive_word_filter:
with open("sensitive.json", "r", encoding="utf-8") as f:
self.reply_filter = pkg.qqbot.filter.ReplyFilter(json.load(f)['words'])
sensitive_json = json.load(f)
self.reply_filter = pkg.qqbot.filter.ReplyFilter(
sensitive_words=sensitive_json['words'],
mask=sensitive_json['mask'] if 'mask' in sensitive_json else '*',
mask_word=sensitive_json['mask_word'] if 'mask_word' in sensitive_json else ''
)
else:
self.reply_filter = pkg.qqbot.filter.ReplyFilter([])
@@ -125,7 +146,7 @@ class QQBotManager:
self.on_person_message(event)
go(friend_message_handler, (event,))
self.go(friend_message_handler, event)
@self.bot.on(StrangerMessage)
async def on_stranger_message(event: StrangerMessage):
@@ -145,7 +166,7 @@ class QQBotManager:
self.on_person_message(event)
go(stranger_message_handler, (event,))
self.go(stranger_message_handler, event)
@self.bot.on(GroupMessage)
async def on_group_message(event: GroupMessage):
@@ -165,7 +186,7 @@ class QQBotManager:
self.on_group_message(event)
go(group_message_handler, (event,))
self.go(group_message_handler, event)
def unsubscribe_all():
"""取消所有订阅
@@ -182,6 +203,9 @@ class QQBotManager:
self.unsubscribe_all = unsubscribe_all
def go(self, func, *args, **kwargs):
self.pool.submit(func, *args, **kwargs)
def first_time_init(self, mirai_http_api_config: dict):
"""热重载后不再运行此函数"""
@@ -297,13 +321,18 @@ class QQBotManager:
if Image in event.message_chain:
pass
elif At(self.bot.qq) not in event.message_chain:
else:
if At(self.bot.qq) in event.message_chain and response_at():
# 直接调用
reply = process()
else:
check, result = check_response_rule(str(event.message_chain).strip(), event)
if check:
reply = process(result.strip())
else:
# 直接调用
# 检查是否随机响应
elif random_responding():
logging.info("随机响应group_{}消息".format(event.group.id))
reply = process()
if reply:

File diff suppressed because one or more lines are too long

View File

@@ -54,7 +54,7 @@ def get_current_tag() -> str:
return current_tag
def update_all() -> bool:
def update_all(cli: bool = False) -> bool:
"""检查更新并下载源码"""
current_tag = get_current_tag()
@@ -69,12 +69,19 @@ def update_all() -> bool:
if latest_rls == {}:
latest_rls = rls
if not cli:
logging.info("更新日志: {}".format(rls_notes))
else:
print("更新日志: {}".format(rls_notes))
if latest_rls == {}: # 没有新版本
return False
# 下载最新版本的zip到temp目录
if not cli:
logging.info("开始下载最新版本: {}".format(latest_rls['zipball_url']))
else:
print("开始下载最新版本: {}".format(latest_rls['zipball_url']))
zip_url = latest_rls['zipball_url']
zip_resp = requests.get(url=zip_url)
zip_data = zip_resp.content
@@ -87,7 +94,10 @@ def update_all() -> bool:
with open("temp/updater/{}.zip".format(latest_rls['tag_name']), "wb") as f:
f.write(zip_data)
if not cli:
logging.info("下载最新版本完成: {}".format("temp/updater/{}.zip".format(latest_rls['tag_name'])))
else:
print("下载最新版本完成: {}".format("temp/updater/{}.zip".format(latest_rls['tag_name'])))
# 解压zip到temp/updater/<tag_name>/
import zipfile
@@ -124,8 +134,11 @@ def update_all() -> bool:
f.write(current_tag)
# 通知管理员
if not cli:
import pkg.utils.context
pkg.utils.context.get_qqbot_manager().notify_admin("已更新到最新版本: {}\n更新日志:\n{}\n新功能通常可以在config-template.py中看到完整的更新日志请前往 https://github.com/RockChinQ/QChatGPT/releases 查看".format(current_tag, "\n".join(rls_notes)))
else:
print("已更新到最新版本: {}\n更新日志:\n{}\n新功能通常可以在config-template.py中看到完整的更新日志请前往 https://github.com/RockChinQ/QChatGPT/releases 查看".format(current_tag, "\n".join(rls_notes)))
return True

View File

@@ -1,4 +1,7 @@
{
"说明": "mask将替换敏感词中的每一个字若mask_word值不为空则将敏感词整个替换为mask_word的值",
"mask": "*",
"mask_word": "",
"words": [
"习近平",
"胡锦涛",