Compare commits

...

81 Commits

Author SHA1 Message Date
RockChinQ
45a10b4ac7 chore: release v3.2.4 2024-07-05 18:19:10 +08:00
RockChinQ
b5d33ef629 perf: improve error reporting during pipeline processing 2024-07-04 13:03:58 +08:00
RockChinQ
d3629916bf fix: align user_notice handling with MessageChain (#809) 2024-07-04 12:47:55 +08:00
RockChinQ
c5cb26d295 fix: setting alter in the GroupNormalMessageReceived event has no effect (#803) 2024-07-03 23:16:16 +08:00
RockChinQ
4b2785c5eb fix: image recognition not working with the official QQ API (#825) 2024-07-03 22:36:35 +08:00
RockChinQ
7ed190e6d2 doc: remove advertisement 2024-07-03 17:50:58 +08:00
Junyan Qin
eac041cdd2 Merge pull request #834 from RockChinQ/feat/env-reminder
Feat: add launch notes stage
2024-07-03 17:45:56 +08:00
RockChinQ
05527cfc01 feat: add a notice about selection mode on Windows 2024-07-03 17:44:10 +08:00
RockChinQ
61e2af4a14 feat: add launch notes stage 2024-07-03 17:34:23 +08:00
RockChinQ
79804b6ecd chore: release 3.2.3 2024-06-26 10:55:21 +08:00
Junyan Qin
76434b2f4e Merge pull request #829 from RockChinQ/version/3.2.3
Release 3.2.3
2024-06-26 10:54:42 +08:00
RockChinQ
ec8bd4922e fix: incorrect resprule selection logic (#810) 2024-06-26 10:37:08 +08:00
RockChinQ
4ffa773fac fix: images wrongly converted to text in prefix-triggered responses (#820) 2024-06-26 10:15:21 +08:00
Junyan Qin
ea8b7bc8aa Merge pull request #818 from Huoyuuu/master
fix: ensure content is string in chatcmpl call method
2024-06-24 17:12:46 +08:00
RockChinQ
39ce5646f6 perf: join content elements with newline separators 2024-06-24 17:04:50 +08:00
Huoyuuu
5092a82739 Update chatcmpl.py 2024-06-19 19:13:00 +08:00
Huoyuuu
3bba0b6d9a Merge pull request #1 from Huoyuuu/fix/issue-817-ensure-content-string
fix: ensure content is string in chatcmpl call method
2024-06-19 17:32:30 +08:00
Huoyuuu
7a19dd503d fix: ensure content is string in chatcmpl call method
fix: ensure content is string in chatcmpl call method

- Ensure user message content is a string instead of an array
- Updated `call` method in `chatcmpl.py` to guarantee content is a string
- Resolves compatibility issue with the yi-large model
2024-06-19 17:26:06 +08:00
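The bullet points above describe coercing list-form message content into a plain string before the request is sent. A minimal sketch of that idea, assuming OpenAI-style message dicts; the helper name and content-part layout are illustrative, not the project's actual code:

```python
def ensure_string_content(message: dict) -> dict:
    """Return a copy of the message whose 'content' is a plain string.

    Models without multi-modal support (such as yi-large, per the commit
    above) may reject list-form content, so text parts are joined here.
    """
    content = message.get("content")
    if isinstance(content, list):
        # Keep only the text parts; image parts have no text representation.
        texts = [
            part.get("text", "")
            for part in content
            if isinstance(part, dict) and part.get("type") == "text"
        ]
        return {**message, "content": "\n".join(texts)}
    return message
```

Already-string content passes through unchanged, so the helper is safe to apply to every outgoing message.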
RockChinQ
9e6a01fefd chore: release v3.2.2 2024-05-31 19:20:34 +08:00
RockChinQ
933471b4d9 perf: print the full traceback when startup fails (#799) 2024-05-31 15:37:56 +08:00
RockChinQ
f81808d239 perf: add syntax checking for JSON config files (#796) 2024-05-29 21:11:21 +08:00
RockChinQ
96832b6f7d perf: ignore empty assistant content messages (#795) 2024-05-29 21:00:48 +08:00
Junyan Qin
e2eb0a84b0 Merge pull request #797 from RockChinQ/feat/context-truncater
Feat: message truncator
2024-05-29 20:38:14 +08:00
RockChinQ
c8eb2e3376 feat: message truncator 2024-05-29 20:34:49 +08:00
Junyan Qin
21fe5822f9 Merge pull request #794 from RockChinQ/perf/advanced-fixwin
Feat: fixwin rate limiter supports a configurable window size
2024-05-26 10:33:49 +08:00
RockChinQ
d49cc9a7a3 feat: fixwin rate limiter supports a configurable window size (#791) 2024-05-26 10:29:10 +08:00
Junyan Qin
910d0bfae1 Update README.md 2024-05-25 12:27:27 +08:00
RockChinQ
d6761949ca chore: release v3.2.1 2024-05-23 16:29:26 +08:00
RockChinQ
6afac1f593 feat: allow specifying the telemetry server URL 2024-05-23 16:25:51 +08:00
RockChinQ
4d1a270d22 doc: add link to qcg-center source code 2024-05-23 16:16:13 +08:00
Junyan Qin
a7888f5536 Merge pull request #787 from RockChinQ/perf/claude-ability
Perf: improved support for Claude capabilities
2024-05-22 20:33:39 +08:00
RockChinQ
b9049e91cf chore: sync llm-models.json 2024-05-22 20:31:46 +08:00
RockChinQ
7db56c8e77 feat: vision support for claude 2024-05-22 20:09:29 +08:00
Junyan Qin
50563cb957 Merge pull request #785 from RockChinQ/fix/msg-chain-compability
Fix: fix compatibility of query.resp_messages with plugin reply
2024-05-18 20:13:50 +08:00
RockChinQ
18ae2299a7 fix: fix compatibility of query.resp_messages with plugin reply 2024-05-18 20:08:48 +08:00
RockChinQ
7463e0aab9 perf: remove leftover config.py fields in several places (#781) 2024-05-18 18:52:45 +08:00
Junyan Qin
c92d47bb95 Merge pull request #779 from jerryliang122/master
Fix image errors in aiocqhttp
2024-05-17 17:05:58 +08:00
RockChinQ
0b1af7df91 perf: unify the check logic 2024-05-17 17:05:20 +08:00
jerryliang122
a9104eb2da Send images base64-encoded to fix cqhttp failing to send images 2024-05-17 08:20:06 +00:00
RockChinQ
abbd15d5cc chore: release v3.2.0.1 2024-05-17 09:48:20 +08:00
RockChinQ
aadfa14d59 fix: claude request failures 2024-05-17 09:46:06 +08:00
Junyan Qin
4cd10bbe25 Update README.md 2024-05-16 22:17:46 +08:00
RockChinQ
1d4a6b71ab chore: release v3.2.0 2024-05-16 21:22:40 +08:00
Junyan Qin
a7f830dd73 Merge pull request #773 from RockChinQ/feat/multi-modal
Feat: multi-modal support
2024-05-16 21:13:15 +08:00
RockChinQ
bae86ac05c chore: restore version number 2024-05-16 21:03:56 +08:00
RockChinQ
a3706bfe21 perf: minor refinements 2024-05-16 21:02:59 +08:00
RockChinQ
91e23b8c11 perf: add an LRU cache to the image base64 function 2024-05-16 20:52:17 +08:00
RockChinQ
37ef1c9fab feat: remove OSS-related code 2024-05-16 20:32:30 +08:00
RockChinQ
6bc6f77af1 feat: transfer images via base64 2024-05-16 20:25:51 +08:00
RockChinQ
2c478ccc25 feat: add a per-model vision support parameter 2024-05-16 20:11:54 +08:00
RockChinQ
404e5492a3 chore: sync existing model info 2024-05-16 18:29:23 +08:00
RockChinQ
d5b5d667a5 feat: vision multi-modal support for models 2024-05-15 21:40:18 +08:00
RockChinQ
8807f02f36 perf: change resp_message_chain to a list type (#770) 2024-05-14 23:08:49 +08:00
RockChinQ
269e561497 perf: store messages back into the conversation only when the request succeeds (#769) 2024-05-14 22:41:39 +08:00
RockChinQ
527ad81d38 feat: decouple the chat handler from the requester (#772) 2024-05-14 22:20:31 +08:00
Junyan Qin
972d3c18af Update README.md 2024-05-08 21:49:45 +08:00
Junyan Qin
3cbfc078fc doc(README.md): update community group 4 number 2024-05-08 21:46:19 +08:00
RockChinQ
fde6822b5c chore: release v3.1.1 2024-05-08 02:28:40 +00:00
Junyan Qin
930321bcf1 Merge pull request #762 from RockChinQ/feat/deepseek
Feat: support deepseek models
2024-05-07 22:48:37 +08:00
RockChinQ
c45931363a feat: deepseek config migration 2024-05-07 14:45:59 +00:00
RockChinQ
9c6491e5ee feat: support deepseek models 2024-05-07 14:28:52 +00:00
RockChinQ
9bc248f5bc feat: remove the submit-messages-tokens config item 2024-05-07 12:32:54 +00:00
Junyan Qin
becac2fde5 doc(README.md): add GitHub Trending badge 2024-04-29 21:00:22 +08:00
RockChinQ
1e1a103882 feat: aiocqhttp allows image URLs as parameters 2024-04-11 03:26:12 +00:00
RockChinQ
e5cffb7c9b chore: release v3.1.0.4 2024-04-06 16:51:15 +08:00
RockChinQ
e2becf7777 feat: remove the parent-process check (#750) 2024-04-06 16:50:35 +08:00
RockChinQ
a6b875a242 fix: wrong parameters for the GroupMessageReceived event 2024-04-04 16:50:45 +08:00
RockChinQ
b5e67f3df8 fix: RuntimeContainer wrongly passed during content function calls 2024-04-04 15:08:40 +08:00
RockChinQ
2093fb16a7 chore: release v3.1.0.3 2024-04-02 22:33:36 +08:00
RockChinQ
fc9a9d2386 fix: missing psutil dependency 2024-04-02 22:33:06 +08:00
RockChinQ
5e69f78f7e chore: drop support for Python 3.9 2024-04-01 18:16:49 +08:00
RockChinQ
6919bece77 chore: release v3.1.0.2 2024-03-31 14:41:32 +08:00
RockChinQ
8b003739f1 feat: message.content supports mirai.MessageChain objects (#741) 2024-03-31 14:38:15 +08:00
RockChinQ
2e9229a6ad fix: the working directory must be the main.py directory 2024-03-30 21:34:22 +08:00
RockChinQ
5a3e7fe8ee perf: prevent launching by double-click 2024-03-30 21:28:42 +08:00
RockChinQ
7b3d7e7bd6 fix: incorrect loading flow for JSON config files 2024-03-30 19:01:59 +08:00
Junyan Qin
fdd7c1864d feat(chatcmpl): catch exceptions during function calls (#749) 2024-03-30 09:45:30 +00:00
Junyan Qin
cac5a5adff fix(qq-botpy): duplicate msg_seq when a single query produces multiple replies in a group 2024-03-30 02:58:37 +00:00
RockChinQ
63307633c2 feat: chatcmpl requests also ignore empty system prompt messages (#745) 2024-03-29 17:34:09 +08:00
RockChinQ
387dfa39ff fix: content filter not applied (#743) 2024-03-29 17:24:42 +08:00
Junyan Qin
1f797f899c doc(README.md): add usage count badge 2024-03-26 15:25:08 +08:00
76 changed files with 1380 additions and 428 deletions

2
.gitignore vendored
View File

@@ -34,4 +34,4 @@ bard.json
 res/instance_id.json
 .DS_Store
 /data
-botpy.log
+botpy.log*

View File

@@ -2,34 +2,30 @@
 <p align="center">
 <img src="https://qchatgpt.rockchin.top/logo.png" alt="QChatGPT" width="180" />
 </p>
 <div align="center">
 # QChatGPT
+<a href="https://trendshift.io/repositories/6187" target="_blank"><img src="https://trendshift.io/api/badge/repositories/6187" alt="RockChinQ%2FQChatGPT | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a>
 [![GitHub release (latest by date)](https://img.shields.io/github/v/release/RockChinQ/QChatGPT)](https://github.com/RockChinQ/QChatGPT/releases/latest)
 <a href="https://hub.docker.com/repository/docker/rockchin/qchatgpt">
 <img src="https://img.shields.io/docker/pulls/rockchin/qchatgpt?color=blue" alt="docker pull">
 </a>
+![Dynamic JSON Badge](https://img.shields.io/badge/dynamic/json?url=https%3A%2F%2Fapi.qchatgpt.rockchin.top%2Fapi%2Fv2%2Fview%2Frealtime%2Fcount_query%3Fminute%3D10080&query=%24.data.count&label=%E4%BD%BF%E7%94%A8%E9%87%8F%EF%BC%887%E6%97%A5%EF%BC%89)
 ![Wakapi Count](https://wakapi.rockchin.top/api/badge/RockChinQ/interval:any/project:QChatGPT)
-<a href="https://codecov.io/gh/RockChinQ/QChatGPT" >
-<img src="https://codecov.io/gh/RockChinQ/QChatGPT/graph/badge.svg?token=pjxYIL2kbC"/>
-</a>
 <br/>
-<img src="https://img.shields.io/badge/python-3.9 | 3.10 | 3.11-blue.svg" alt="python">
+<img src="https://img.shields.io/badge/python-3.10 | 3.11 | 3.12-blue.svg" alt="python">
 <a href="http://qm.qq.com/cgi-bin/qm/qr?_wv=1027&k=66-aWvn8cbP4c1ut_1YYkvvGVeEtyTH8&authKey=pTaKBK5C%2B8dFzQ4XlENf6MHTCLaHnlKcCRx7c14EeVVlpX2nRSaS8lJm8YeM4mCU&noverify=0&group_code=195992197">
 <img alt="Static Badge" src="https://img.shields.io/badge/%E5%AE%98%E6%96%B9%E7%BE%A4-195992197-purple">
 </a>
-<a href="http://qm.qq.com/cgi-bin/qm/qr?_wv=1027&k=nC80H57wmKPwRDLFeQrDDjVl81XuC21P&authKey=2wTUTfoQ5v%2BD4C5zfpuR%2BSPMDqdXgDXA%2FS2wHI1NxTfWIG%2B%2FqK08dgyjMMOzhXa9&noverify=0&group_code=248432104">
-<img alt="Static Badge" src="https://img.shields.io/badge/%E7%A4%BE%E5%8C%BA%E7%BE%A4-248432104-purple">
+<a href="https://qm.qq.com/q/1yxEaIgXMA">
+<img alt="Static Badge" src="https://img.shields.io/badge/%E7%A4%BE%E5%8C%BA%E7%BE%A4-619154800-purple">
 </a>
-<a href="https://www.bilibili.com/video/BV14h4y1w7TC">
-<img alt="Static Badge" src="https://img.shields.io/badge/%E8%A7%86%E9%A2%91%E6%95%99%E7%A8%8B-208647">
+<a href="https://codecov.io/gh/RockChinQ/QChatGPT" >
+<img src="https://codecov.io/gh/RockChinQ/QChatGPT/graph/badge.svg?token=pjxYIL2kbC"/>
 </a>
-<a href="https://www.bilibili.com/video/BV11h4y1y74H">
-<img alt="Static Badge" src="https://img.shields.io/badge/Linux%E9%83%A8%E7%BD%B2%E8%A7%86%E9%A2%91-208647">
-</a>
 ## 使用文档
 <a href="https://qchatgpt.rockchin.top">项目主页</a>
@@ -43,7 +39,8 @@
 <a href="https://github.com/RockChinQ/qcg-installer">安装器源码</a>
 <a href="https://github.com/RockChinQ/qcg-tester">测试工程源码</a>
+<a href="https://github.com/RockChinQ/qcg-center">遥测服务端源码</a>
 <a href="https://github.com/the-lazy-me/QChatGPT-Wiki">官方文档储存库</a>
-<img alt="回复效果(带有联网插件)" src="https://qchatgpt.rockchin.top/assets/image/QChatGPT-1211.png" width="500px"/>
+<img alt="回复效果(带有联网插件)" src="https://qchatgpt.rockchin.top/assets/image/QChatGPT-0516.png" width="500px"/>
 </div>

18
main.py
View File

@@ -1,5 +1,6 @@
 # QChatGPT 终端启动入口
 # 在此层级解决依赖项检查。
+# QChatGPT/main.py
 asciiart = r"""
  ___  ___ _         _    ___ ___ _____
@@ -49,6 +50,23 @@ async def main_entry():
 if __name__ == '__main__':
+    import os
+
+    # 检查本目录是否有main.py且包含QChatGPT字符串
+    invalid_pwd = False
+
+    if not os.path.exists('main.py'):
+        invalid_pwd = True
+    else:
+        with open('main.py', 'r', encoding='utf-8') as f:
+            content = f.read()
+        if "QChatGPT/main.py" not in content:
+            invalid_pwd = True
+    if invalid_pwd:
+        print("请在QChatGPT项目根目录下以命令形式运行此程序。")
+        input("按任意键退出...")
+        exit(0)
+
     import asyncio
     asyncio.run(main_entry())

View File

@@ -9,8 +9,6 @@ from .groups import plugin
 from ...core import app
-BACKEND_URL = "https://api.qchatgpt.rockchin.top/api/v2"
-
 class V2CenterAPI:
     """中央服务器 v2 API 交互类"""
@@ -23,7 +21,7 @@ class V2CenterAPI:
     plugin: plugin.V2PluginDataAPI = None
     """插件 API 组"""
-    def __init__(self, ap: app.Application, basic_info: dict = None, runtime_info: dict = None):
+    def __init__(self, ap: app.Application, backend_url: str, basic_info: dict = None, runtime_info: dict = None):
         """初始化"""
         logging.debug("basic_info: %s, runtime_info: %s", basic_info, runtime_info)
@@ -31,7 +29,7 @@ class V2CenterAPI:
         apigroup.APIGroup._basic_info = basic_info
         apigroup.APIGroup._runtime_info = runtime_info
-        self.main = main.V2MainDataAPI(BACKEND_URL, ap)
-        self.usage = usage.V2UsageDataAPI(BACKEND_URL, ap)
-        self.plugin = plugin.V2PluginDataAPI(BACKEND_URL, ap)
+        self.main = main.V2MainDataAPI(backend_url, ap)
+        self.usage = usage.V2UsageDataAPI(backend_url, ap)
+        self.plugin = plugin.V2PluginDataAPI(backend_url, ap)

View File

@@ -13,11 +13,16 @@ class CommandReturn(pydantic.BaseModel):
     """命令返回值
     """
-    text: typing.Optional[str]
+    text: typing.Optional[str] = None
     """文本
     """
-    image: typing.Optional[mirai.Image]
+    image: typing.Optional[mirai.Image] = None
+    """弃用"""
+
+    image_url: typing.Optional[str] = None
+    """图片链接
+    """
     error: typing.Optional[errors.CommandError]= None
     """错误

View File

@@ -24,7 +24,7 @@ class DefaultOperator(operator.CommandOperator):
             content = ""
             for msg in prompt.messages:
-                content += f" {msg.role}: {msg.content}"
+                content += f" {msg.readable_str()}\n"
             reply_str += f"名称: {prompt.name}\n内容: \n{content}\n\n"
@@ -45,18 +45,18 @@ class DefaultSetOperator(operator.CommandOperator):
         context: entities.ExecuteContext
     ) -> typing.AsyncGenerator[entities.CommandReturn, None]:
         if len(context.crt_params) == 0:
             yield entities.CommandReturn(error=errors.ParamNotEnoughError('请提供情景预设名称'))
         else:
             prompt_name = context.crt_params[0]
             try:
                 prompt = await self.ap.prompt_mgr.get_prompt_by_prefix(prompt_name)
                 if prompt is None:
                     yield entities.CommandReturn(error=errors.CommandError("设置当前会话默认情景预设失败: 未找到情景预设 {}".format(prompt_name)))
                 else:
                     context.session.use_prompt_name = prompt.name
                     yield entities.CommandReturn(text=f"已设置当前会话默认情景预设为 {prompt_name}, !reset 后生效")
             except Exception as e:
                 traceback.print_exc()
                 yield entities.CommandReturn(error=errors.CommandError("设置当前会话默认情景预设失败: "+str(e)))

View File

@@ -30,7 +30,7 @@ class LastOperator(operator.CommandOperator):
                 context.session.using_conversation = context.session.conversations[index-1]
                 time_str = context.session.using_conversation.create_time.strftime("%Y-%m-%d %H:%M:%S")
-                yield entities.CommandReturn(text=f"已切换到上一个对话: {index} {time_str}: {context.session.using_conversation.messages[0].content}")
+                yield entities.CommandReturn(text=f"已切换到上一个对话: {index} {time_str}: {context.session.using_conversation.messages[0].readable_str()}")
                 return
         else:
             yield entities.CommandReturn(error=errors.CommandOperationError('当前没有对话'))

View File

@@ -42,7 +42,7 @@ class ListOperator(operator.CommandOperator):
                 using_conv_index = index
             if index >= page * record_per_page and index < (page + 1) * record_per_page:
-                content += f"{index} {time_str}: {conv.messages[0].content if len(conv.messages) > 0 else '无内容'}\n"
+                content += f"{index} {time_str}: {conv.messages[0].readable_str() if len(conv.messages) > 0 else '无内容'}\n"
             index += 1
         if content == '':
@@ -51,6 +51,6 @@ class ListOperator(operator.CommandOperator):
         if context.session.using_conversation is None:
             content += "\n当前处于新会话"
         else:
-            content += f"\n当前会话: {using_conv_index} {context.session.using_conversation.create_time.strftime('%Y-%m-%d %H:%M:%S')}: {context.session.using_conversation.messages[0].content if len(context.session.using_conversation.messages) > 0 else '无内容'}"
+            content += f"\n当前会话: {using_conv_index} {context.session.using_conversation.create_time.strftime('%Y-%m-%d %H:%M:%S')}: {context.session.using_conversation.messages[0].readable_str() if len(context.session.using_conversation.messages) > 0 else '无内容'}"
         yield entities.CommandReturn(text=f"{page + 1} 页 (时间倒序):\n{content}")

View File

@@ -19,24 +19,34 @@ class JSONConfigFile(file_model.ConfigFile):
         return os.path.exists(self.config_file_name)
     async def create(self):
-        shutil.copyfile(self.template_file_name, self.config_file_name)
+        if self.template_file_name is not None:
+            shutil.copyfile(self.template_file_name, self.config_file_name)
+        elif self.template_data is not None:
+            with open(self.config_file_name, "w", encoding="utf-8") as f:
+                json.dump(self.template_data, f, indent=4, ensure_ascii=False)
+        else:
+            raise ValueError("template_file_name or template_data must be provided")
-    async def load(self) -> dict:
+    async def load(self, completion: bool=True) -> dict:
         if not self.exists():
             await self.create()
+        if self.template_file_name is not None:
+            with open(self.template_file_name, "r", encoding="utf-8") as f:
+                self.template_data = json.load(f)
+
         with open(self.config_file_name, "r", encoding="utf-8") as f:
-            cfg = json.load(f)
+            try:
+                cfg = json.load(f)
+            except json.JSONDecodeError as e:
+                raise Exception(f"配置文件 {self.config_file_name} 语法错误: {e}")
-        # 从模板文件中进行补全
-        with open(self.template_file_name, "r", encoding="utf-8") as f:
-            self.template_data = json.load(f)
-        for key in self.template_data:
-            if key not in cfg:
-                cfg[key] = self.template_data[key]
+        if completion:
+            for key in self.template_data:
+                if key not in cfg:
+                    cfg[key] = self.template_data[key]
         return cfg
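The load flow in this diff does two things: parse the JSON with an explicit syntax check, then fill in keys missing from the user's config using the template. A standalone sketch of those two steps (function names are illustrative, not from the repository):

```python
import json

def parse_config(text: str) -> dict:
    """Parse JSON config text, surfacing syntax errors with a clear message."""
    try:
        return json.loads(text)
    except json.JSONDecodeError as e:
        # Re-raise with context, like the diff's "配置文件 ... 语法错误" message.
        raise ValueError(f"config syntax error: {e}") from e

def complete_config(cfg: dict, template: dict) -> dict:
    """Shallow-fill keys that exist in the template but not in the config."""
    for key, value in template.items():
        if key not in cfg:
            cfg[key] = value
    return cfg
```

Note the completion is shallow, matching the diff: nested dicts present in the user config are kept wholesale, not merged key-by-key.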

View File

@@ -25,7 +25,7 @@ class PythonModuleConfigFile(file_model.ConfigFile):
     async def create(self):
         shutil.copyfile(self.template_file_name, self.config_file_name)
-    async def load(self) -> dict:
+    async def load(self, completion: bool=True) -> dict:
         module_name = os.path.splitext(os.path.basename(self.config_file_name))[0]
         module = importlib.import_module(module_name)
@@ -43,18 +43,19 @@ class PythonModuleConfigFile(file_model.ConfigFile):
             cfg[key] = getattr(module, key)
         # 从模板模块文件中进行补全
-        module_name = os.path.splitext(os.path.basename(self.template_file_name))[0]
-        module = importlib.import_module(module_name)
-
-        for key in dir(module):
-            if key.startswith('__'):
-                continue
-            if not isinstance(getattr(module, key), allowed_types):
-                continue
-            if key not in cfg:
-                cfg[key] = getattr(module, key)
+        if completion:
+            module_name = os.path.splitext(os.path.basename(self.template_file_name))[0]
+            module = importlib.import_module(module_name)
+
+            for key in dir(module):
+                if key.startswith('__'):
+                    continue
+                if not isinstance(getattr(module, key), allowed_types):
+                    continue
+                if key not in cfg:
+                    cfg[key] = getattr(module, key)
         return cfg

View File

@@ -20,8 +20,8 @@ class ConfigManager:
         self.file = cfg_file
         self.data = {}
-    async def load_config(self):
-        self.data = await self.file.load()
+    async def load_config(self, completion: bool=True):
+        self.data = await self.file.load(completion=completion)
     async def dump_config(self):
         await self.file.save(self.data)
@@ -30,7 +30,7 @@ class ConfigManager:
         self.file.save_sync(self.data)
-async def load_python_module_config(config_name: str, template_name: str) -> ConfigManager:
+async def load_python_module_config(config_name: str, template_name: str, completion: bool=True) -> ConfigManager:
     """加载Python模块配置文件"""
     cfg_inst = pymodule.PythonModuleConfigFile(
         config_name,
@@ -38,12 +38,12 @@ async def load_python_module_config(config_name: str, template_name: str) -> Con
     )
     cfg_mgr = ConfigManager(cfg_inst)
-    await cfg_mgr.load_config()
+    await cfg_mgr.load_config(completion=completion)
     return cfg_mgr
-async def load_json_config(config_name: str, template_name: str=None, template_data: dict=None) -> ConfigManager:
+async def load_json_config(config_name: str, template_name: str=None, template_data: dict=None, completion: bool=True) -> ConfigManager:
     """加载JSON配置文件"""
     cfg_inst = json_file.JSONConfigFile(
         config_name,
@@ -52,6 +52,6 @@ async def load_json_config(config_name: str, template_name: str=None, template_d
    )
     cfg_mgr = ConfigManager(cfg_inst)
-    await cfg_mgr.load_config()
+    await cfg_mgr.load_config(completion=completion)
     return cfg_mgr

View File

@@ -0,0 +1,30 @@
+from __future__ import annotations
+
+from .. import migration
+
+
+@migration.migration_class("deepseek-config-completion", 5)
+class DeepseekConfigCompletionMigration(migration.Migration):
+    """OpenAI配置迁移
+    """
+
+    async def need_migrate(self) -> bool:
+        """判断当前环境是否需要运行此迁移
+        """
+        return 'deepseek-chat-completions' not in self.ap.provider_cfg.data['requester'] \
+            or 'deepseek' not in self.ap.provider_cfg.data['keys']
+
+    async def run(self):
+        """执行迁移
+        """
+        if 'deepseek-chat-completions' not in self.ap.provider_cfg.data['requester']:
+            self.ap.provider_cfg.data['requester']['deepseek-chat-completions'] = {
+                'base-url': 'https://api.deepseek.com',
+                'args': {},
+                'timeout': 120,
+            }
+
+        if 'deepseek' not in self.ap.provider_cfg.data['keys']:
+            self.ap.provider_cfg.data['keys']['deepseek'] = []
+
+        await self.ap.provider_cfg.dump_config()

View File

@@ -0,0 +1,19 @@
+from __future__ import annotations
+
+from .. import migration
+
+
+@migration.migration_class("vision-config", 6)
+class VisionConfigMigration(migration.Migration):
+    """迁移"""
+
+    async def need_migrate(self) -> bool:
+        """判断当前环境是否需要运行此迁移"""
+        return "enable-vision" not in self.ap.provider_cfg.data
+
+    async def run(self):
+        """执行迁移"""
+        if "enable-vision" not in self.ap.provider_cfg.data:
+            self.ap.provider_cfg.data["enable-vision"] = False
+
+        await self.ap.provider_cfg.dump_config()

View File

@@ -0,0 +1,20 @@
+from __future__ import annotations
+
+from .. import migration
+
+
+@migration.migration_class("qcg-center-url-config", 7)
+class QCGCenterURLConfigMigration(migration.Migration):
+    """迁移"""
+
+    async def need_migrate(self) -> bool:
+        """判断当前环境是否需要运行此迁移"""
+        return "qcg-center-url" not in self.ap.system_cfg.data
+
+    async def run(self):
+        """执行迁移"""
+        if "qcg-center-url" not in self.ap.system_cfg.data:
+            self.ap.system_cfg.data["qcg-center-url"] = "https://api.qchatgpt.rockchin.top/api/v2"
+
+        await self.ap.system_cfg.dump_config()

View File

@@ -0,0 +1,29 @@
+from __future__ import annotations
+
+from .. import migration
+
+
+@migration.migration_class("ad-fixwin-cfg-migration", 8)
+class AdFixwinConfigMigration(migration.Migration):
+    """迁移"""
+
+    async def need_migrate(self) -> bool:
+        """判断当前环境是否需要运行此迁移"""
+        return isinstance(
+            self.ap.pipeline_cfg.data["rate-limit"]["fixwin"]["default"],
+            int
+        )
+
+    async def run(self):
+        """执行迁移"""
+        for session_name in self.ap.pipeline_cfg.data["rate-limit"]["fixwin"]:
+            temp_dict = {
+                "window-size": 60,
+                "limit": self.ap.pipeline_cfg.data["rate-limit"]["fixwin"][session_name]
+            }
+
+            self.ap.pipeline_cfg.data["rate-limit"]["fixwin"][session_name] = temp_dict
+
+        await self.ap.pipeline_cfg.dump_config()
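The migration above wraps each old integer rate limit in a dict carrying a default 60-second window. The transformation can be sketched in isolation (the helper name is hypothetical; the output shape matches the diff):

```python
def migrate_fixwin_limits(fixwin: dict, window_size: int = 60) -> dict:
    """Convert old-style integer limits into {"window-size", "limit"} dicts.

    Mirrors the ad-fixwin-cfg-migration above, which detects the old format
    by checking isinstance(fixwin["default"], int).
    """
    return {
        session: {"window-size": window_size, "limit": limit}
        for session, limit in fixwin.items()
    }
```

Like the migration itself, this assumes every value is still an old-style int; running it twice would wrap already-migrated dicts, which is why need_migrate guards the real code.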

View File

@@ -0,0 +1,24 @@
+from __future__ import annotations
+
+from .. import migration
+
+
+@migration.migration_class("msg-truncator-cfg-migration", 9)
+class MsgTruncatorConfigMigration(migration.Migration):
+    """迁移"""
+
+    async def need_migrate(self) -> bool:
+        """判断当前环境是否需要运行此迁移"""
+        return 'msg-truncate' not in self.ap.pipeline_cfg.data
+
+    async def run(self):
+        """执行迁移"""
+        self.ap.pipeline_cfg.data['msg-truncate'] = {
+            'method': 'round',
+            'round': {
+                'max-round': 10
+            }
+        }
+
+        await self.ap.pipeline_cfg.dump_config()

View File

@@ -22,7 +22,7 @@ class ConfigFile(metaclass=abc.ABCMeta):
         pass
     @abc.abstractmethod
-    async def load(self) -> dict:
+    async def load(self, completion: bool=True) -> dict:
         pass
     @abc.abstractmethod

View File

@@ -15,7 +15,7 @@ from ..command import cmdmgr
 from ..plugin import manager as plugin_mgr
 from ..pipeline import pool
 from ..pipeline import controller, stagemgr
-from ..utils import version as version_mgr, proxy as proxy_mgr
+from ..utils import version as version_mgr, proxy as proxy_mgr, announce as announce_mgr
 class Application:
@@ -69,6 +69,8 @@ class Application:
     ver_mgr: version_mgr.VersionManager = None
+    ann_mgr: announce_mgr.AnnouncementManager = None
+
     proxy_mgr: proxy_mgr.ProxyManager = None
     logger: logging.Logger = None

View File

@@ -1,18 +1,21 @@
 from __future__ import print_function
+import traceback
+
 from . import app
 from ..audit import identifier
 from . import stage
 # 引入启动阶段实现以便注册
-from .stages import load_config, setup_logger, build_app, migrate
+from .stages import load_config, setup_logger, build_app, migrate, show_notes
 stage_order = [
     "LoadConfigStage",
     "MigrationStage",
     "SetupLoggerStage",
-    "BuildAppStage"
+    "BuildAppStage",
+    "ShowNotesStage"
 ]
@@ -27,6 +30,7 @@ async def make_app() -> app.Application:
     for stage_name in stage_order:
         stage_cls = stage.preregistered_stages[stage_name]
         stage_inst = stage_cls()
+
         await stage_inst.run(ap)
     await ap.initialize()
@@ -35,5 +39,8 @@ async def make_app() -> app.Application:
 async def main():
-    app_inst = await make_app()
-    await app_inst.run()
+    try:
+        app_inst = await make_app()
+        await app_inst.run()
+    except Exception as e:
+        traceback.print_exc()

View File

@@ -8,16 +8,3 @@ from ...config.impls import pymodule
 load_python_module_config = config_mgr.load_python_module_config
 load_json_config = config_mgr.load_json_config
-
-async def override_config_manager(cfg_mgr: config_mgr.ConfigManager) -> list[str]:
-    override_json = json.load(open("override.json", "r", encoding="utf-8"))
-    overrided = []
-
-    config = cfg_mgr.data
-
-    for key in override_json:
-        if key in config:
-            config[key] = override_json[key]
-            overrided.append(key)
-
-    return overrided

View File

@@ -13,6 +13,8 @@ required_deps = {
     "tiktoken": "tiktoken",
     "yaml": "pyyaml",
     "aiohttp": "aiohttp",
+    "psutil": "psutil",
+    "async_lru": "async-lru",
 }

View File

@@ -67,12 +67,15 @@ class Query(pydantic.BaseModel):
     use_funcs: typing.Optional[list[tools_entities.LLMFunction]] = None
     """使用的函数,由前置处理器阶段设置"""
-    resp_messages: typing.Optional[list[llm_entities.Message]] = []
+    resp_messages: typing.Optional[list[llm_entities.Message]] | typing.Optional[list[mirai.MessageChain]] = []
     """由Process阶段生成的回复消息对象列表"""
-    resp_message_chain: typing.Optional[mirai.MessageChain] = None
+    resp_message_chain: typing.Optional[list[mirai.MessageChain]] = None
     """回复消息链从resp_messages包装而得"""
+
+    # ======= 内部保留 =======
+
+    current_stage: "pkg.pipeline.stagemgr.StageInstContainer" = None
     class Config:
         arbitrary_types_allowed = True

44
pkg/core/note.py Normal file
View File

@@ -0,0 +1,44 @@
+from __future__ import annotations
+
+import abc
+import typing
+
+from . import app
+
+
+preregistered_notes: list[typing.Type[LaunchNote]] = []
+
+
+def note_class(name: str, number: int):
+    """注册一个启动信息
+    """
+    def decorator(cls: typing.Type[LaunchNote]) -> typing.Type[LaunchNote]:
+        cls.name = name
+        cls.number = number
+
+        preregistered_notes.append(cls)
+
+        return cls
+
+    return decorator
+
+
+class LaunchNote(abc.ABC):
+    """启动信息
+    """
+
+    name: str
+
+    number: int
+
+    ap: app.Application
+
+    def __init__(self, ap: app.Application):
+        self.ap = ap
+
+    @abc.abstractmethod
+    async def need_show(self) -> bool:
+        """判断当前环境是否需要显示此启动信息
+        """
+        pass
+
+    @abc.abstractmethod
+    async def yield_note(self) -> typing.AsyncGenerator[typing.Tuple[str, int], None]:
+        """生成启动信息
+        """
+        pass

View File

View File

@@ -0,0 +1,20 @@
+from __future__ import annotations
+
+import typing
+
+from .. import note, app
+
+
+@note.note_class("ClassicNotes", 1)
+class ClassicNotes(note.LaunchNote):
+    """经典启动信息
+    """
+
+    async def need_show(self) -> bool:
+        return True
+
+    async def yield_note(self) -> typing.AsyncGenerator[typing.Tuple[str, int], None]:
+        yield await self.ap.ann_mgr.show_announcements()
+
+        yield await self.ap.ver_mgr.show_version_update()

View File

@@ -0,0 +1,21 @@
+from __future__ import annotations
+
+import typing
+import os
+import sys
+import logging
+
+from .. import note, app
+
+
+@note.note_class("SelectionModeOnWindows", 2)
+class SelectionModeOnWindows(note.LaunchNote):
+    """Windows 上的选择模式提示信息
+    """
+
+    async def need_show(self) -> bool:
+        return os.name == 'nt'
+
+    async def yield_note(self) -> typing.AsyncGenerator[typing.Tuple[str, int], None]:
+        yield """您正在使用 Windows 系统,若窗口左上角显示处于”选择“模式,程序将被暂停运行,此时请右键窗口中空白区域退出选择模式。""", logging.INFO

View File

@@ -15,7 +15,6 @@ from ...provider.sysprompt import sysprompt as llm_prompt_mgr
 from ...provider.tools import toolmgr as llm_tool_mgr
 from ...platform import manager as im_mgr
-
 @stage.stage_class("BuildAppStage")
 class BuildAppStage(stage.BootingStage):
     """构建应用阶段
@@ -35,6 +34,7 @@ class BuildAppStage(stage.BootingStage):
         center_v2_api = center_v2.V2CenterAPI(
             ap,
+            backend_url=ap.system_cfg.data["qcg-center-url"],
             basic_info={
                 "host_id": identifier.identifier["host_id"],
                 "instance_id": identifier.identifier["instance_id"],
@@ -53,12 +53,10 @@ class BuildAppStage(stage.BootingStage):
         # 发送公告
         ann_mgr = announce.AnnouncementManager(ap)
-        await ann_mgr.show_announcements()
+        ap.ann_mgr = ann_mgr
         ap.query_pool = pool.QueryPool()
-        await ap.ver_mgr.show_version_update()
-
         plugin_mgr_inst = plugin_mgr.PluginManager(ap)
         await plugin_mgr_inst.initialize()
         ap.plugin_mgr = plugin_mgr_inst
@@ -83,7 +81,6 @@ class BuildAppStage(stage.BootingStage):
         llm_tool_mgr_inst = llm_tool_mgr.ToolManager(ap)
         await llm_tool_mgr_inst.initialize()
         ap.tool_mgr = llm_tool_mgr_inst
-
         im_mgr_inst = im_mgr.PlatformManager(ap=ap)
         await im_mgr_inst.initialize()
         ap.platform_mgr = im_mgr_inst
@@ -92,5 +89,6 @@ class BuildAppStage(stage.BootingStage):
         await stage_mgr.initialize()
         ap.stage_mgr = stage_mgr
+
         ctrl = controller.Controller(ap)
         ap.ctrl = ctrl

View File

@@ -12,11 +12,11 @@ class LoadConfigStage(stage.BootingStage):
async def run(self, ap: app.Application): async def run(self, ap: app.Application):
"""启动 """启动
""" """
ap.command_cfg = await config.load_json_config("data/config/command.json", "templates/command.json") ap.command_cfg = await config.load_json_config("data/config/command.json", "templates/command.json", completion=False)
ap.pipeline_cfg = await config.load_json_config("data/config/pipeline.json", "templates/pipeline.json") ap.pipeline_cfg = await config.load_json_config("data/config/pipeline.json", "templates/pipeline.json", completion=False)
ap.platform_cfg = await config.load_json_config("data/config/platform.json", "templates/platform.json") ap.platform_cfg = await config.load_json_config("data/config/platform.json", "templates/platform.json", completion=False)
ap.provider_cfg = await config.load_json_config("data/config/provider.json", "templates/provider.json") ap.provider_cfg = await config.load_json_config("data/config/provider.json", "templates/provider.json", completion=False)
ap.system_cfg = await config.load_json_config("data/config/system.json", "templates/system.json") ap.system_cfg = await config.load_json_config("data/config/system.json", "templates/system.json", completion=False)
ap.plugin_setting_meta = await config.load_json_config("plugins/plugins.json", "templates/plugin-settings.json") ap.plugin_setting_meta = await config.load_json_config("plugins/plugins.json", "templates/plugin-settings.json")
await ap.plugin_setting_meta.dump_config() await ap.plugin_setting_meta.dump_config()

View File

@@ -5,6 +5,7 @@ import importlib
from .. import stage, app from .. import stage, app
from ...config import migration from ...config import migration
from ...config.migrations import m001_sensitive_word_migration, m002_openai_config_migration, m003_anthropic_requester_cfg_completion, m004_moonshot_cfg_completion from ...config.migrations import m001_sensitive_word_migration, m002_openai_config_migration, m003_anthropic_requester_cfg_completion, m004_moonshot_cfg_completion
from ...config.migrations import m005_deepseek_cfg_completion, m006_vision_config, m007_qcg_center_url, m008_ad_fixwin_config_migrate, m009_msg_truncator_cfg
@stage.stage_class("MigrationStage") @stage.stage_class("MigrationStage")

View File

@@ -0,0 +1,28 @@
from __future__ import annotations
from .. import stage, app, note
from ..notes import n001_classic_msgs, n002_selection_mode_on_windows
@stage.stage_class("ShowNotesStage")
class ShowNotesStage(stage.BootingStage):
"""显示启动信息阶段
"""
async def run(self, ap: app.Application):
# 排序
note.preregistered_notes.sort(key=lambda x: x.number)
for note_cls in note.preregistered_notes:
try:
note_inst = note_cls(ap)
if await note_inst.need_show():
async for ret in note_inst.yield_note():
if not ret:
continue
msg, level = ret
if msg:
ap.logger.log(level, msg)
except Exception as e:
continue

View File

@@ -8,7 +8,10 @@ from ...config import manager as cfg_mgr
@stage.stage_class('BanSessionCheckStage') @stage.stage_class('BanSessionCheckStage')
class BanSessionCheckStage(stage.PipelineStage): class BanSessionCheckStage(stage.PipelineStage):
"""访问控制处理阶段""" """访问控制处理阶段
仅检查query中群号或个人号是否在访问控制列表中。
"""
async def initialize(self): async def initialize(self):
pass pass

View File

@@ -9,12 +9,24 @@ from ...core import entities as core_entities
from ...config import manager as cfg_mgr from ...config import manager as cfg_mgr
from . import filter as filter_model, entities as filter_entities from . import filter as filter_model, entities as filter_entities
from .filters import cntignore, banwords, baiduexamine from .filters import cntignore, banwords, baiduexamine
from ...provider import entities as llm_entities
@stage.stage_class('PostContentFilterStage') @stage.stage_class('PostContentFilterStage')
@stage.stage_class('PreContentFilterStage') @stage.stage_class('PreContentFilterStage')
class ContentFilterStage(stage.PipelineStage): class ContentFilterStage(stage.PipelineStage):
"""内容过滤阶段""" """内容过滤阶段
前置:
检查消息是否符合规则,不符合则拦截。
改写:
message_chain
后置:
检查AI回复消息是否符合规则可能进行改写不符合则拦截。
改写:
query.resp_messages
"""
filter_chain: list[filter_model.ContentFilter] filter_chain: list[filter_model.ContentFilter]
@@ -25,7 +37,7 @@ class ContentFilterStage(stage.PipelineStage):
async def initialize(self): async def initialize(self):
filters_required = [ filters_required = [
"content-filter" "content-ignore",
] ]
if self.ap.pipeline_cfg.data['check-sensitive-words']: if self.ap.pipeline_cfg.data['check-sensitive-words']:
@@ -130,14 +142,37 @@ class ContentFilterStage(stage.PipelineStage):
"""处理 """处理
""" """
if stage_inst_name == 'PreContentFilterStage': if stage_inst_name == 'PreContentFilterStage':
contain_non_text = False
for me in query.message_chain:
if not isinstance(me, mirai.Plain):
contain_non_text = True
break
if contain_non_text:
self.ap.logger.debug(f"消息中包含非文本消息,跳过内容过滤器检查。")
return entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE,
new_query=query
)
return await self._pre_process( return await self._pre_process(
str(query.message_chain).strip(), str(query.message_chain).strip(),
query query
) )
elif stage_inst_name == 'PostContentFilterStage': elif stage_inst_name == 'PostContentFilterStage':
return await self._post_process( # 仅处理 query.resp_messages[-1].content 是 str 的情况
query.resp_messages[-1].content, if isinstance(query.resp_messages[-1], llm_entities.Message) and isinstance(query.resp_messages[-1].content, str):
query return await self._post_process(
) query.resp_messages[-1].content,
query
)
else:
self.ap.logger.debug(f"resp_messages[-1] 不是 Message 类型或 query.resp_messages[-1].content 不是 str 类型,跳过内容过滤器检查。")
return entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE,
new_query=query
)
else: else:
raise ValueError(f'未知的 stage_inst_name: {stage_inst_name}') raise ValueError(f'未知的 stage_inst_name: {stage_inst_name}')

View File

@@ -4,6 +4,8 @@ import enum
import pydantic import pydantic
from ...provider import entities as llm_entities
class ResultLevel(enum.Enum): class ResultLevel(enum.Enum):
"""结果等级""" """结果等级"""
@@ -38,7 +40,7 @@ class FilterResult(pydantic.BaseModel):
""" """
replacement: str replacement: str
"""替换后的消息 """替换后的文本消息
内容过滤器可以进行一些遮掩处理,然后把遮掩后的消息返回。 内容过滤器可以进行一些遮掩处理,然后把遮掩后的消息返回。
若没有修改内容,也需要返回原消息。 若没有修改内容,也需要返回原消息。

View File

@@ -5,6 +5,7 @@ import typing
from ...core import app from ...core import app
from . import entities from . import entities
from ...provider import entities as llm_entities
preregistered_filters: list[typing.Type[ContentFilter]] = [] preregistered_filters: list[typing.Type[ContentFilter]] = []
@@ -63,7 +64,7 @@ class ContentFilter(metaclass=abc.ABCMeta):
pass pass
@abc.abstractmethod @abc.abstractmethod
async def process(self, message: str) -> entities.FilterResult: async def process(self, message: str=None, image_url=None) -> entities.FilterResult:
"""处理消息 """处理消息
分为前后阶段,具体取决于 enable_stages 的值。 分为前后阶段,具体取决于 enable_stages 的值。
@@ -71,6 +72,7 @@ class ContentFilter(metaclass=abc.ABCMeta):
Args: Args:
message (str): 需要检查的内容 message (str): 需要检查的内容
image_url (str): 要检查的图片的 URL
Returns: Returns:
entities.FilterResult: 过滤结果,具体内容请查看 entities.FilterResult 类的文档 entities.FilterResult: 过滤结果,具体内容请查看 entities.FilterResult 类的文档

View File

@@ -8,7 +8,7 @@ from ....config import manager as cfg_mgr
@filter_model.filter_class("ban-word-filter") @filter_model.filter_class("ban-word-filter")
class BanWordFilter(filter_model.ContentFilter): class BanWordFilter(filter_model.ContentFilter):
"""根据内容禁言""" """根据内容过滤"""
async def initialize(self): async def initialize(self):
pass pass

View File

@@ -4,6 +4,8 @@ import asyncio
import typing import typing
import traceback import traceback
import mirai
from ..core import app, entities from ..core import app, entities
from . import entities as pipeline_entities from . import entities as pipeline_entities
from ..plugin import events from ..plugin import events
@@ -68,6 +70,17 @@ class Controller:
"""检查输出 """检查输出
""" """
if result.user_notice: if result.user_notice:
# 处理str类型
if isinstance(result.user_notice, str):
result.user_notice = mirai.MessageChain(
mirai.Plain(result.user_notice)
)
elif isinstance(result.user_notice, list):
result.user_notice = mirai.MessageChain(
*result.user_notice
)
await self.ap.platform_mgr.send( await self.ap.platform_mgr.send(
query.message_event, query.message_event,
result.user_notice, result.user_notice,
@@ -109,6 +122,8 @@ class Controller:
while i < len(self.ap.stage_mgr.stage_containers): while i < len(self.ap.stage_mgr.stage_containers):
stage_container = self.ap.stage_mgr.stage_containers[i] stage_container = self.ap.stage_mgr.stage_containers[i]
query.current_stage = stage_container # 标记到 Query 对象里
result = stage_container.inst.process(query, stage_container.inst_name) result = stage_container.inst.process(query, stage_container.inst_name)
@@ -149,7 +164,7 @@ class Controller:
try: try:
await self._execute_from_stage(0, query) await self._execute_from_stage(0, query)
except Exception as e: except Exception as e:
self.ap.logger.error(f"处理请求时出错 query_id={query.query_id}: {e}") self.ap.logger.error(f"处理请求时出错 query_id={query.query_id} stage={query.current_stage.inst_name} : {e}")
self.ap.logger.debug(f"Traceback: {traceback.format_exc()}") self.ap.logger.debug(f"Traceback: {traceback.format_exc()}")
# traceback.print_exc() # traceback.print_exc()
finally: finally:

View File

@@ -16,6 +16,9 @@ from ...config import manager as cfg_mgr
@stage.stage_class("LongTextProcessStage") @stage.stage_class("LongTextProcessStage")
class LongTextProcessStage(stage.PipelineStage): class LongTextProcessStage(stage.PipelineStage):
"""长消息处理阶段 """长消息处理阶段
改写:
- resp_message_chain
""" """
strategy_impl: strategy.LongTextStrategy strategy_impl: strategy.LongTextStrategy
@@ -31,18 +34,18 @@ class LongTextProcessStage(stage.PipelineStage):
if os.name == "nt": if os.name == "nt":
use_font = "C:/Windows/Fonts/msyh.ttc" use_font = "C:/Windows/Fonts/msyh.ttc"
if not os.path.exists(use_font): if not os.path.exists(use_font):
self.ap.logger.warn("未找到字体文件且无法使用Windows自带字体更换为转发消息组件以发送长消息您可以在config.py中调整相关设置。") self.ap.logger.warn("未找到字体文件且无法使用Windows自带字体更换为转发消息组件以发送长消息您可以在配置文件中调整相关设置。")
config['blob_message_strategy'] = "forward" config['blob_message_strategy'] = "forward"
else: else:
self.ap.logger.info("使用Windows自带字体" + use_font) self.ap.logger.info("使用Windows自带字体" + use_font)
config['font-path'] = use_font config['font-path'] = use_font
else: else:
self.ap.logger.warn("未找到字体文件,且无法使用系统自带字体,更换为转发消息组件以发送长消息,您可以在config.py中调整相关设置。") self.ap.logger.warn("未找到字体文件,且无法使用系统自带字体,更换为转发消息组件以发送长消息,您可以在配置文件中调整相关设置。")
self.ap.platform_cfg.data['long-text-process']['strategy'] = "forward" self.ap.platform_cfg.data['long-text-process']['strategy'] = "forward"
except: except:
traceback.print_exc() traceback.print_exc()
self.ap.logger.error("加载字体文件失败({}),更换为转发消息组件以发送长消息,您可以在config.py中调整相关设置。".format(use_font)) self.ap.logger.error("加载字体文件失败({}),更换为转发消息组件以发送长消息,您可以在配置文件中调整相关设置。".format(use_font))
self.ap.platform_cfg.data['long-text-process']['strategy'] = "forward" self.ap.platform_cfg.data['long-text-process']['strategy'] = "forward"
@@ -56,8 +59,19 @@ class LongTextProcessStage(stage.PipelineStage):
await self.strategy_impl.initialize() await self.strategy_impl.initialize()
async def process(self, query: core_entities.Query, stage_inst_name: str) -> entities.StageProcessResult: async def process(self, query: core_entities.Query, stage_inst_name: str) -> entities.StageProcessResult:
if len(str(query.resp_message_chain)) > self.ap.platform_cfg.data['long-text-process']['threshold']: # 检查是否包含非 Plain 组件
query.resp_message_chain = MessageChain(await self.strategy_impl.process(str(query.resp_message_chain), query)) contains_non_plain = False
for msg in query.resp_message_chain[-1]:
if not isinstance(msg, Plain):
contains_non_plain = True
break
if contains_non_plain:
self.ap.logger.debug("消息中包含非 Plain 组件,跳过长消息处理。")
elif len(str(query.resp_message_chain[-1])) > self.ap.platform_cfg.data['long-text-process']['threshold']:
query.resp_message_chain[-1] = MessageChain(await self.strategy_impl.process(str(query.resp_message_chain[-1]), query))
return entities.StageProcessResult( return entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE, result_type=entities.ResultType.CONTINUE,
new_query=query new_query=query

View File

View File

@@ -0,0 +1,35 @@
from __future__ import annotations
from .. import stage, entities, stagemgr
from ...core import entities as core_entities
from . import truncator
from .truncators import round
@stage.stage_class("ConversationMessageTruncator")
class ConversationMessageTruncator(stage.PipelineStage):
"""会话消息截断器
用于截断会话消息链,以适应平台消息长度限制。
"""
trun: truncator.Truncator
async def initialize(self):
use_method = self.ap.pipeline_cfg.data['msg-truncate']['method']
for trun in truncator.preregistered_truncators:
if trun.name == use_method:
self.trun = trun(self.ap)
break
else:
raise ValueError(f"未知的截断器: {use_method}")
async def process(self, query: core_entities.Query, stage_inst_name: str) -> entities.StageProcessResult:
"""处理
"""
query = await self.trun.truncate(query)
return entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE,
new_query=query
)

View File

@@ -0,0 +1,56 @@
from __future__ import annotations
import typing
import abc
from ...core import entities as core_entities, app
preregistered_truncators: list[typing.Type[Truncator]] = []
def truncator_class(
name: str
) -> typing.Callable[[typing.Type[Truncator]], typing.Type[Truncator]]:
"""截断器类装饰器
Args:
name (str): 截断器名称
Returns:
typing.Callable[[typing.Type[Truncator]], typing.Type[Truncator]]: 装饰器
"""
def decorator(cls: typing.Type[Truncator]) -> typing.Type[Truncator]:
assert issubclass(cls, Truncator)
cls.name = name
preregistered_truncators.append(cls)
return cls
return decorator
class Truncator(abc.ABC):
"""消息截断器基类
"""
name: str
ap: app.Application
def __init__(self, ap: app.Application):
self.ap = ap
async def initialize(self):
pass
@abc.abstractmethod
async def truncate(self, query: core_entities.Query) -> core_entities.Query:
"""截断
一般只需要操作query.messages也可以扩展操作query.prompt, query.user_message。
请勿操作其他字段。
"""
pass

View File

@@ -0,0 +1,32 @@
from __future__ import annotations
from .. import truncator
from ....core import entities as core_entities
@truncator.truncator_class("round")
class RoundTruncator(truncator.Truncator):
"""前文回合数阶段器
"""
async def truncate(self, query: core_entities.Query) -> core_entities.Query:
"""截断
"""
max_round = self.ap.pipeline_cfg.data['msg-truncate']['round']['max-round']
temp_messages = []
current_round = 0
# 从后往前遍历
for msg in query.messages[::-1]:
if current_round < max_round:
temp_messages.append(msg)
if msg.role == 'user':
current_round += 1
else:
break
query.messages = temp_messages[::-1]
return query

View File

@@ -43,7 +43,7 @@ class QueryPool:
message_event=message_event, message_event=message_event,
message_chain=message_chain, message_chain=message_chain,
resp_messages=[], resp_messages=[],
resp_message_chain=None, resp_message_chain=[],
adapter=adapter adapter=adapter
) )
self.queries.append(query) self.queries.append(query)

View File

@@ -1,5 +1,7 @@
from __future__ import annotations from __future__ import annotations
import mirai
from .. import stage, entities, stagemgr from .. import stage, entities, stagemgr
from ...core import entities as core_entities from ...core import entities as core_entities
from ...provider import entities as llm_entities from ...provider import entities as llm_entities
@@ -9,6 +11,16 @@ from ...plugin import events
@stage.stage_class("PreProcessor") @stage.stage_class("PreProcessor")
class PreProcessor(stage.PipelineStage): class PreProcessor(stage.PipelineStage):
"""请求预处理阶段 """请求预处理阶段
签出会话、prompt、上文、模型、内容函数。
改写:
- session
- prompt
- messages
- user_message
- use_model
- use_funcs
""" """
async def process( async def process(
@@ -27,21 +39,42 @@ class PreProcessor(stage.PipelineStage):
query.prompt = conversation.prompt.copy() query.prompt = conversation.prompt.copy()
query.messages = conversation.messages.copy() query.messages = conversation.messages.copy()
query.user_message = llm_entities.Message(
role='user',
content=str(query.message_chain).strip()
)
query.use_model = conversation.use_model query.use_model = conversation.use_model
query.use_funcs = conversation.use_funcs query.use_funcs = conversation.use_funcs if query.use_model.tool_call_supported else None
# 检查vision是否启用没启用就删除所有图片
if not self.ap.provider_cfg.data['enable-vision'] or not query.use_model.vision_supported:
for msg in query.messages:
if isinstance(msg.content, list):
for me in msg.content:
if me.type == 'image_url':
msg.content.remove(me)
content_list = []
for me in query.message_chain:
if isinstance(me, mirai.Plain):
content_list.append(
llm_entities.ContentElement.from_text(me.text)
)
elif isinstance(me, mirai.Image):
if self.ap.provider_cfg.data['enable-vision'] and query.use_model.vision_supported:
if me.url is not None:
content_list.append(
llm_entities.ContentElement.from_image_url(str(me.url))
)
query.user_message = llm_entities.Message( # TODO 适配多模态输入
role='user',
content=content_list
)
# =========== 触发事件 PromptPreProcessing # =========== 触发事件 PromptPreProcessing
session = query.session
event_ctx = await self.ap.plugin_mgr.emit_event( event_ctx = await self.ap.plugin_mgr.emit_event(
event=events.PromptPreProcessing( event=events.PromptPreProcessing(
session_name=f'{session.launcher_type.value}_{session.launcher_id}', session_name=f'{query.session.launcher_type.value}_{query.session.launcher_id}',
default_prompt=query.prompt.messages, default_prompt=query.prompt.messages,
prompt=query.messages, prompt=query.messages,
query=query query=query

View File

@@ -3,6 +3,7 @@ from __future__ import annotations
import typing import typing
import time import time
import traceback import traceback
import json
import mirai import mirai
@@ -41,12 +42,7 @@ class ChatMessageHandler(handler.MessageHandler):
if event_ctx.event.reply is not None: if event_ctx.event.reply is not None:
mc = mirai.MessageChain(event_ctx.event.reply) mc = mirai.MessageChain(event_ctx.event.reply)
query.resp_messages.append( query.resp_messages.append(mc)
llm_entities.Message(
role='plugin',
content=str(mc),
)
)
yield entities.StageProcessResult( yield entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE, result_type=entities.ResultType.CONTINUE,
@@ -66,13 +62,8 @@ class ChatMessageHandler(handler.MessageHandler):
) )
if event_ctx.event.alter is not None: if event_ctx.event.alter is not None:
query.message_chain = mirai.MessageChain([ # if isinstance(event_ctx.event, str): # 现在暂时不考虑多模态alter
mirai.Plain(event_ctx.event.alter) query.user_message.content = event_ctx.event.alter
])
query.messages.append(
query.user_message
)
text_length = 0 text_length = 0
@@ -80,7 +71,7 @@ class ChatMessageHandler(handler.MessageHandler):
try: try:
async for result in query.use_model.requester.request(query): async for result in self.runner(query):
query.resp_messages.append(result) query.resp_messages.append(result)
self.ap.logger.info(f'对话({query.query_id})响应: {self.cut_str(result.readable_str())}') self.ap.logger.info(f'对话({query.query_id})响应: {self.cut_str(result.readable_str())}')
@@ -92,6 +83,9 @@ class ChatMessageHandler(handler.MessageHandler):
result_type=entities.ResultType.CONTINUE, result_type=entities.ResultType.CONTINUE,
new_query=query new_query=query
) )
query.session.using_conversation.messages.append(query.user_message)
query.session.using_conversation.messages.extend(query.resp_messages)
except Exception as e: except Exception as e:
self.ap.logger.error(f'对话({query.query_id})请求失败: {str(e)}') self.ap.logger.error(f'对话({query.query_id})请求失败: {str(e)}')
@@ -104,8 +98,6 @@ class ChatMessageHandler(handler.MessageHandler):
debug_notice=traceback.format_exc() debug_notice=traceback.format_exc()
) )
finally: finally:
query.session.using_conversation.messages.append(query.user_message)
query.session.using_conversation.messages.extend(query.resp_messages)
await self.ap.ctr_mgr.usage.post_query_record( await self.ap.ctr_mgr.usage.post_query_record(
session_type=query.session.launcher_type.value, session_type=query.session.launcher_type.value,
@@ -115,4 +107,65 @@ class ChatMessageHandler(handler.MessageHandler):
model_name=query.use_model.name, model_name=query.use_model.name,
response_seconds=int(time.time() - start_time), response_seconds=int(time.time() - start_time),
retry_times=-1, retry_times=-1,
) )
async def runner(
self,
query: core_entities.Query,
) -> typing.AsyncGenerator[llm_entities.Message, None]:
"""执行一个请求处理过程中的LLM接口请求、函数调用的循环
这是临时处理方案后续可能改为使用LangChain或者自研的工作流处理器
"""
await query.use_model.requester.preprocess(query)
pending_tool_calls = []
req_messages = query.prompt.messages.copy() + query.messages.copy() + [query.user_message]
# 首次请求
msg = await query.use_model.requester.call(query.use_model, req_messages, query.use_funcs)
yield msg
pending_tool_calls = msg.tool_calls
req_messages.append(msg)
# 持续请求,只要还有待处理的工具调用就继续处理调用
while pending_tool_calls:
for tool_call in pending_tool_calls:
try:
func = tool_call.function
parameters = json.loads(func.arguments)
func_ret = await self.ap.tool_mgr.execute_func_call(
query, func.name, parameters
)
msg = llm_entities.Message(
role="tool", content=json.dumps(func_ret, ensure_ascii=False), tool_call_id=tool_call.id
)
yield msg
req_messages.append(msg)
except Exception as e:
# 工具调用出错,添加一个报错信息到 req_messages
err_msg = llm_entities.Message(
role="tool", content=f"err: {e}", tool_call_id=tool_call.id
)
yield err_msg
req_messages.append(err_msg)
# 处理完所有调用,再次请求
msg = await query.use_model.requester.call(query.use_model, req_messages, query.use_funcs)
yield msg
pending_tool_calls = msg.tool_calls
req_messages.append(msg)

View File

@@ -48,12 +48,7 @@ class CommandHandler(handler.MessageHandler):
if event_ctx.event.reply is not None: if event_ctx.event.reply is not None:
mc = mirai.MessageChain(event_ctx.event.reply) mc = mirai.MessageChain(event_ctx.event.reply)
query.resp_messages.append( query.resp_messages.append(mc)
llm_entities.Message(
role='command',
content=str(mc),
)
)
yield entities.StageProcessResult( yield entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE, result_type=entities.ResultType.CONTINUE,
@@ -80,9 +75,6 @@ class CommandHandler(handler.MessageHandler):
session=session session=session
): ):
if ret.error is not None: if ret.error is not None:
# query.resp_message_chain = mirai.MessageChain([
# mirai.Plain(str(ret.error))
# ])
query.resp_messages.append( query.resp_messages.append(
llm_entities.Message( llm_entities.Message(
role='command', role='command',
@@ -96,18 +88,28 @@ class CommandHandler(handler.MessageHandler):
result_type=entities.ResultType.CONTINUE, result_type=entities.ResultType.CONTINUE,
new_query=query new_query=query
) )
elif ret.text is not None: elif ret.text is not None or ret.image_url is not None:
# query.resp_message_chain = mirai.MessageChain([
# mirai.Plain(ret.text) content: list[llm_entities.ContentElement]= []
# ])
if ret.text is not None:
content.append(
llm_entities.ContentElement.from_text(ret.text)
)
if ret.image_url is not None:
content.append(
llm_entities.ContentElement.from_image_url(ret.image_url)
)
query.resp_messages.append( query.resp_messages.append(
llm_entities.Message( llm_entities.Message(
role='command', role='command',
content=ret.text, content=content,
) )
) )
self.ap.logger.info(f'命令返回: {self.cut_str(ret.text)}') self.ap.logger.info(f'命令返回: {self.cut_str(str(content[0]))}')
yield entities.StageProcessResult( yield entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE, result_type=entities.ResultType.CONTINUE,

View File

@@ -11,7 +11,13 @@ from ...config import manager as cfg_mgr
@stage.stage_class("MessageProcessor") @stage.stage_class("MessageProcessor")
class Processor(stage.PipelineStage): class Processor(stage.PipelineStage):
"""请求实际处理阶段""" """请求实际处理阶段
通过命令处理器和聊天处理器处理消息。
改写:
- resp_messages
"""
cmd_handler: handler.MessageHandler cmd_handler: handler.MessageHandler

View File

@@ -1,18 +1,15 @@
# 固定窗口算法
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
import time import time
from .. import algo from .. import algo
# 固定窗口算法
class SessionContainer: class SessionContainer:
wait_lock: asyncio.Lock wait_lock: asyncio.Lock
records: dict[int, int] records: dict[int, int]
"""访问记录key为每分钟的起始时间戳value为访问次数""" """访问记录key为每窗口长度的起始时间戳value为访问次数"""
def __init__(self): def __init__(self):
self.wait_lock = asyncio.Lock() self.wait_lock = asyncio.Lock()
@@ -47,30 +44,34 @@ class FixedWindowAlgo(algo.ReteLimitAlgo):
# 等待锁 # 等待锁
async with container.wait_lock: async with container.wait_lock:
# 获取窗口大小和限制
window_size = self.ap.pipeline_cfg.data['rate-limit']['fixwin']['default']['window-size']
limitation = self.ap.pipeline_cfg.data['rate-limit']['fixwin']['default']['limit']
if session_name in self.ap.pipeline_cfg.data['rate-limit']['fixwin']:
window_size = self.ap.pipeline_cfg.data['rate-limit']['fixwin'][session_name]['window-size']
limitation = self.ap.pipeline_cfg.data['rate-limit']['fixwin'][session_name]['limit']
# 获取当前时间戳 # 获取当前时间戳
now = int(time.time()) now = int(time.time())
# 获取当前分钟的起始时间戳 # 获取当前窗口的起始时间戳
now = now - now % 60 now = now - now % window_size
# 获取当前分钟的访问次数 # 获取当前窗口的访问次数
count = container.records.get(now, 0) count = container.records.get(now, 0)
limitation = self.ap.pipeline_cfg.data['rate-limit']['fixwin']['default']
if session_name in self.ap.pipeline_cfg.data['rate-limit']['fixwin']:
limitation = self.ap.pipeline_cfg.data['rate-limit']['fixwin'][session_name]
# 如果访问次数超过了限制 # 如果访问次数超过了限制
if count >= limitation: if count >= limitation:
if self.ap.pipeline_cfg.data['rate-limit']['strategy'] == 'drop': if self.ap.pipeline_cfg.data['rate-limit']['strategy'] == 'drop':
return False return False
elif self.ap.pipeline_cfg.data['rate-limit']['strategy'] == 'wait': elif self.ap.pipeline_cfg.data['rate-limit']['strategy'] == 'wait':
# 等待下一分钟 # 等待下一窗口
await asyncio.sleep(60 - time.time() % 60) await asyncio.sleep(window_size - time.time() % window_size)
now = int(time.time()) now = int(time.time())
now = now - now % 60 now = now - now % window_size
if now not in container.records: if now not in container.records:
container.records = {} container.records = {}

View File

@@ -11,7 +11,10 @@ from ...core import entities as core_entities
@stage.stage_class("RequireRateLimitOccupancy") @stage.stage_class("RequireRateLimitOccupancy")
@stage.stage_class("ReleaseRateLimitOccupancy") @stage.stage_class("ReleaseRateLimitOccupancy")
class RateLimit(stage.PipelineStage): class RateLimit(stage.PipelineStage):
"""限速器控制阶段""" """限速器控制阶段
不改写query只检查是否需要限速。
"""
algo: algo.ReteLimitAlgo algo: algo.ReteLimitAlgo

View File

@@ -31,7 +31,7 @@ class SendResponseBackStage(stage.PipelineStage):
await self.ap.platform_mgr.send( await self.ap.platform_mgr.send(
query.message_event, query.message_event,
query.resp_message_chain, query.resp_message_chain[-1],
adapter=query.adapter adapter=query.adapter
) )

View File

@@ -14,9 +14,12 @@ from ...config import manager as cfg_mgr
@stage.stage_class("GroupRespondRuleCheckStage") @stage.stage_class("GroupRespondRuleCheckStage")
class GroupRespondRuleCheckStage(stage.PipelineStage): class GroupRespondRuleCheckStage(stage.PipelineStage):
"""群组响应规则检查器 """群组响应规则检查器
仅检查群消息是否符合规则。
""" """
rule_matchers: list[rule.GroupRespondRule] rule_matchers: list[rule.GroupRespondRule]
"""检查器实例"""
async def initialize(self): async def initialize(self):
"""初始化检查器 """初始化检查器
@@ -31,7 +34,7 @@ class GroupRespondRuleCheckStage(stage.PipelineStage):
async def process(self, query: core_entities.Query, stage_inst_name: str) -> entities.StageProcessResult: async def process(self, query: core_entities.Query, stage_inst_name: str) -> entities.StageProcessResult:
if query.launcher_type.value != 'group': if query.launcher_type.value != 'group': # 只处理群消息
return entities.StageProcessResult( return entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE, result_type=entities.ResultType.CONTINUE,
new_query=query new_query=query
@@ -41,8 +44,8 @@ class GroupRespondRuleCheckStage(stage.PipelineStage):
use_rule = rules['default'] use_rule = rules['default']
if str(query.launcher_id) in use_rule: if str(query.launcher_id) in rules:
use_rule = use_rule[str(query.launcher_id)] use_rule = rules[str(query.launcher_id)]
for rule_matcher in self.rule_matchers: # 任意一个匹配就放行 for rule_matcher in self.rule_matchers: # 任意一个匹配就放行
res = await rule_matcher.match(str(query.message_chain), query.message_chain, use_rule, query) res = await rule_matcher.match(str(query.message_chain), query.message_chain, use_rule, query)

View File

@@ -20,11 +20,14 @@ class PrefixRule(rule_model.GroupRespondRule):
for prefix in prefixes: for prefix in prefixes:
if message_text.startswith(prefix): if message_text.startswith(prefix):
# 查找第一个plain元素
for me in message_chain:
if isinstance(me, mirai.Plain):
me.text = me.text[len(prefix):]
return entities.RuleJudgeResult( return entities.RuleJudgeResult(
matching=True, matching=True,
replacement=mirai.MessageChain([ replacement=message_chain,
mirai.Plain(message_text[len(prefix):])
]),
) )
return entities.RuleJudgeResult( return entities.RuleJudgeResult(

View File

@@ -13,21 +13,23 @@ from .respback import respback
from .wrapper import wrapper from .wrapper import wrapper
from .preproc import preproc from .preproc import preproc
from .ratelimit import ratelimit from .ratelimit import ratelimit
from .msgtrun import msgtrun
# 请求处理阶段顺序 # 请求处理阶段顺序
stage_order = [ stage_order = [
"GroupRespondRuleCheckStage", "GroupRespondRuleCheckStage", # 群响应规则检查
"BanSessionCheckStage", "BanSessionCheckStage", # 封禁会话检查
"PreContentFilterStage", "PreContentFilterStage", # 内容过滤前置阶段
"PreProcessor", "PreProcessor", # 预处理器
"RequireRateLimitOccupancy", "ConversationMessageTruncator", # 会话消息截断器
"MessageProcessor", "RequireRateLimitOccupancy", # 请求速率限制占用
"ReleaseRateLimitOccupancy", "MessageProcessor", # 处理器
"PostContentFilterStage", "ReleaseRateLimitOccupancy", # 释放速率限制占用
"ResponseWrapper", "PostContentFilterStage", # 内容过滤后置阶段
"LongTextProcessStage", "ResponseWrapper", # 响应包装器
"SendResponseBackStage", "LongTextProcessStage", # 长文本处理
"SendResponseBackStage", # 发送响应
] ]

View File

@@ -14,6 +14,13 @@ from ...plugin import events
 
 @stage.stage_class("ResponseWrapper")
 class ResponseWrapper(stage.PipelineStage):
+    """Response wrapping stage
+
+    Wraps the reply message into a human-readable form.
+
+    Rewrites:
+    - resp_message_chain
+    """
 
     async def initialize(self):
         pass
@@ -25,75 +32,49 @@ class ResponseWrapper(stage.PipelineStage):
     ) -> typing.AsyncGenerator[entities.StageProcessResult, None]:
         """Process
         """
-        if query.resp_messages[-1].role == 'command':
-            query.resp_message_chain = mirai.MessageChain("[bot] "+query.resp_messages[-1].content)
-
+        # if resp_messages[-1] is already a MessageChain
+        if isinstance(query.resp_messages[-1], mirai.MessageChain):
+            query.resp_message_chain.append(query.resp_messages[-1])
             yield entities.StageProcessResult(
                 result_type=entities.ResultType.CONTINUE,
                 new_query=query
             )
-        elif query.resp_messages[-1].role == 'plugin':
-            query.resp_message_chain = mirai.MessageChain(query.resp_messages[-1].content)
-
-            yield entities.StageProcessResult(
-                result_type=entities.ResultType.CONTINUE,
-                new_query=query
-            )
         else:
-            if query.resp_messages[-1].role == 'assistant':
-                result = query.resp_messages[-1]
-                session = await self.ap.sess_mgr.get_session(query)
+            if query.resp_messages[-1].role == 'command':
+                # query.resp_message_chain.append(mirai.MessageChain("[bot] "+query.resp_messages[-1].content))
+                query.resp_message_chain.append(query.resp_messages[-1].get_content_mirai_message_chain(prefix_text='[bot] '))
 
-                reply_text = ''
+                yield entities.StageProcessResult(
+                    result_type=entities.ResultType.CONTINUE,
+                    new_query=query
+                )
+            elif query.resp_messages[-1].role == 'plugin':
+                # if not isinstance(query.resp_messages[-1].content, mirai.MessageChain):
+                #     query.resp_message_chain.append(mirai.MessageChain(query.resp_messages[-1].content))
+                # else:
+                #     query.resp_message_chain.append(query.resp_messages[-1].content)
+                query.resp_message_chain.append(query.resp_messages[-1].get_content_mirai_message_chain())
 
-                if result.content is not None:  # has content
-                    reply_text = result.content
+                yield entities.StageProcessResult(
+                    result_type=entities.ResultType.CONTINUE,
+                    new_query=query
+                )
+            else:
+                if query.resp_messages[-1].role == 'assistant':
+                    result = query.resp_messages[-1]
+                    session = await self.ap.sess_mgr.get_session(query)
 
-                    # ============= emit plugin event ===============
-                    event_ctx = await self.ap.plugin_mgr.emit_event(
-                        event=events.NormalMessageResponded(
-                            launcher_type=query.launcher_type.value,
-                            launcher_id=query.launcher_id,
-                            sender_id=query.sender_id,
-                            session=session,
-                            prefix='',
-                            response_text=reply_text,
-                            finish_reason='stop',
-                            funcs_called=[fc.function.name for fc in result.tool_calls] if result.tool_calls is not None else [],
-                            query=query
-                        )
-                    )
-                    if event_ctx.is_prevented_default():
-                        yield entities.StageProcessResult(
-                            result_type=entities.ResultType.INTERRUPT,
-                            new_query=query
-                        )
-                    else:
-                        if event_ctx.event.reply is not None:
-                            query.resp_message_chain = mirai.MessageChain(event_ctx.event.reply)
-                        else:
-                            query.resp_message_chain = mirai.MessageChain([mirai.Plain(reply_text)])
-
-                        yield entities.StageProcessResult(
-                            result_type=entities.ResultType.CONTINUE,
-                            new_query=query
-                        )
-                if result.tool_calls is not None:  # has function calls
-                    function_names = [tc.function.name for tc in result.tool_calls]
-                    reply_text = f'调用函数 {".".join(function_names)}...'
-                    query.resp_message_chain = mirai.MessageChain([mirai.Plain(reply_text)])
-                    if self.ap.platform_cfg.data['track-function-calls']:
+                    reply_text = ''
+
+                    if result.content:  # has content
+                        reply_text = str(result.get_content_mirai_message_chain())
+
+                        # ============= emit plugin event ===============
                         event_ctx = await self.ap.plugin_mgr.emit_event(
                             event=events.NormalMessageResponded(
                                 launcher_type=query.launcher_type.value,
@@ -107,7 +88,6 @@ class ResponseWrapper(stage.PipelineStage):
                                 query=query
                             )
                         )
-
                         if event_ctx.is_prevented_default():
                             yield entities.StageProcessResult(
                                 result_type=entities.ResultType.INTERRUPT,
@@ -116,13 +96,56 @@ class ResponseWrapper(stage.PipelineStage):
                         else:
                             if event_ctx.event.reply is not None:
-                                query.resp_message_chain = mirai.MessageChain(event_ctx.event.reply)
+                                query.resp_message_chain.append(mirai.MessageChain(event_ctx.event.reply))
                             else:
-                                query.resp_message_chain = mirai.MessageChain([mirai.Plain(reply_text)])
+                                query.resp_message_chain.append(result.get_content_mirai_message_chain())
 
                             yield entities.StageProcessResult(
                                 result_type=entities.ResultType.CONTINUE,
                                 new_query=query
                             )
+                    if result.tool_calls is not None:  # has function calls
+                        function_names = [tc.function.name for tc in result.tool_calls]
+                        reply_text = f'调用函数 {".".join(function_names)}...'
+                        query.resp_message_chain.append(mirai.MessageChain([mirai.Plain(reply_text)]))
+                        if self.ap.platform_cfg.data['track-function-calls']:
+                            event_ctx = await self.ap.plugin_mgr.emit_event(
+                                event=events.NormalMessageResponded(
+                                    launcher_type=query.launcher_type.value,
+                                    launcher_id=query.launcher_id,
+                                    sender_id=query.sender_id,
+                                    session=session,
+                                    prefix='',
+                                    response_text=reply_text,
+                                    finish_reason='stop',
+                                    funcs_called=[fc.function.name for fc in result.tool_calls] if result.tool_calls is not None else [],
+                                    query=query
+                                )
+                            )
+                            if event_ctx.is_prevented_default():
+                                yield entities.StageProcessResult(
+                                    result_type=entities.ResultType.INTERRUPT,
+                                    new_query=query
+                                )
+                            else:
+                                if event_ctx.event.reply is not None:
+                                    query.resp_message_chain.append(mirai.MessageChain(event_ctx.event.reply))
+                                else:
+                                    query.resp_message_chain.append(mirai.MessageChain([mirai.Plain(reply_text)]))
+
+                                yield entities.StageProcessResult(
+                                    result_type=entities.ResultType.CONTINUE,
+                                    new_query=query
+                                )
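The key behavioral change in this stage is appending to query.resp_message_chain instead of overwriting it, so a command reply and a later assistant reply can both be delivered. A sketch with plain lists standing in for the real chain type:

```python
def wrap(resp_message_chain: list, new_chain: list) -> list:
    # old behavior was `resp_message_chain = new_chain`, which discarded
    # any reply wrapped earlier in the same query
    resp_message_chain.append(new_chain)  # new behavior: accumulate replies
    return resp_message_chain

resp_message_chain = []  # stand-in for query.resp_message_chain
wrap(resp_message_chain, ["[bot] command output"])
wrap(resp_message_chain, ["assistant reply"])
# both replies survive to be sent back
```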


@@ -82,8 +82,8 @@ class PlatformManager:
             event_ctx = await self.ap.plugin_mgr.emit_event(
                 event=events.GroupMessageReceived(
-                    launcher_type='person',
-                    launcher_id=event.sender.id,
+                    launcher_type='group',
+                    launcher_id=event.group.id,
                     sender_id=event.sender.id,
                     message_chain=event.message_chain,
                     query=None


@@ -30,7 +30,16 @@ class AiocqhttpMessageConverter(adapter.MessageConverter):
                 msg_id = msg.id
                 msg_time = msg.time
             elif type(msg) is mirai.Image:
-                msg_list.append(aiocqhttp.MessageSegment.image(msg.path))
+                arg = ''
+                if msg.base64:
+                    arg = msg.base64
+                    msg_list.append(aiocqhttp.MessageSegment.image(f"base64://{arg}"))
+                elif msg.url:
+                    arg = msg.url
+                    msg_list.append(aiocqhttp.MessageSegment.image(arg))
+                elif msg.path:
+                    arg = msg.path
+                    msg_list.append(aiocqhttp.MessageSegment.image(arg))
             elif type(msg) is mirai.At:
                 msg_list.append(aiocqhttp.MessageSegment.at(msg.target))
             elif type(msg) is mirai.AtAll:
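The converter now prefers an in-memory base64 payload over a URL, falling back to a local path last; only the base64 case needs the base64:// scheme. A self-contained sketch of that priority order (function name is illustrative):

```python
def image_segment_arg(base64_data=None, url=None, path=None) -> str:
    """Pick the image source in the order the adapter now uses."""
    if base64_data:
        # inline payload: no file or network access needed on the receiving side
        return f"base64://{base64_data}"
    if url:
        return url
    if path:
        return path
    raise ValueError("image has no usable source")

assert image_segment_arg(base64_data="AAAA") == "base64://AAAA"
assert image_segment_arg(url="https://example.com/a.png") == "https://example.com/a.png"
assert image_segment_arg(path="/tmp/a.png") == "/tmp/a.png"
```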


@@ -322,7 +322,7 @@ class NakuruProjectAdapter(adapter_model.MessageSourceAdapter):
                 proxies=None
             )
             if resp.status_code == 403:
-                raise Exception("go-cqhttp拒绝访问,请检查config.py中nakuru_config的token是否与go-cqhttp设置的access-token匹配")
+                raise Exception("go-cqhttp拒绝访问,请检查配置文件中nakuru适配器的配置")
             self.bot_account_id = int(resp.json()['data']['user_id'])
         except Exception as e:
             raise Exception("获取go-cqhttp账号信息失败, 请检查是否已启动go-cqhttp并配置正确")


@@ -198,7 +198,6 @@ class OfficialMessageConverter(adapter_model.MessageConverter):
         bot_account_id: int = 0,
     ) -> mirai.MessageChain:
         yiri_msg_list = []
-
         # store the id
         yiri_msg_list.append(
@@ -218,7 +217,7 @@ class OfficialMessageConverter(adapter_model.MessageConverter):
             yiri_msg_list.append(mirai.At(target=mention.id))
 
         for attachment in message.attachments:
-            if attachment.content_type == "image":
+            if attachment.content_type.startswith("image"):
                 yiri_msg_list.append(mirai.Image(url=attachment.url))
             else:
                 logging.warning(
@@ -368,11 +367,15 @@ class OfficialAdapter(adapter_model.MessageSourceAdapter):
     member_openid_mapping: OpenIDMapping[str, int] = None
     group_openid_mapping: OpenIDMapping[str, int] = None
 
+    group_msg_seq = None
+
     def __init__(self, cfg: dict, ap: app.Application):
         """Initialize the adapter"""
         self.cfg = cfg
         self.ap = ap
+        self.group_msg_seq = 1
 
         switchs = {}
 
         for intent in cfg["intents"]:
@@ -419,8 +422,6 @@ class OfficialAdapter(adapter_model.MessageSourceAdapter):
         message_list = self.message_converter.yiri2target(message)
-
-        msg_seq = 1
 
         for msg in message_list:
             args = {}
@@ -462,8 +463,8 @@ class OfficialAdapter(adapter_model.MessageSourceAdapter):
                     args["msg_id"] = cached_message_ids[
                         str(message_source.message_chain.message_id)
                     ]
-                    args["msg_seq"] = msg_seq
-                    msg_seq += 1
+                    args["msg_seq"] = self.group_msg_seq
+                    self.group_msg_seq += 1
                 await self.bot.api.post_group_message(**args)
 
     async def is_muted(self, group_id: int) -> bool:
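Moving msg_seq from a method-local variable to adapter state makes the sequence monotonic across send calls, instead of restarting at 1 for every reply. A minimal sketch of the difference (class and field names simplified from the adapter above):

```python
class GroupSender:
    def __init__(self):
        # survives across send() calls, unlike the old local counter
        self.group_msg_seq = 1

    def send(self, parts):
        sent = []
        for part in parts:
            sent.append({"msg_seq": self.group_msg_seq, "content": part})
            self.group_msg_seq += 1
        return sent

s = GroupSender()
first = s.send(["a", "b"])
second = s.send(["c"])
assert [m["msg_seq"] for m in first] == [1, 2]
assert second[0]["msg_seq"] == 3  # no longer resets to 1 on the next call
```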


@@ -4,6 +4,8 @@ import typing
 import enum
 import pydantic
+import mirai
 
 
 class FunctionCall(pydantic.BaseModel):
     name: str
@@ -19,6 +21,39 @@ class ToolCall(pydantic.BaseModel):
     function: FunctionCall
 
 
+class ImageURLContentObject(pydantic.BaseModel):
+    url: str
+
+    def __str__(self):
+        return self.url[:128] + ('...' if len(self.url) > 128 else '')
+
+
+class ContentElement(pydantic.BaseModel):
+    type: str
+    """content type"""
+
+    text: typing.Optional[str] = None
+
+    image_url: typing.Optional[ImageURLContentObject] = None
+
+    def __str__(self):
+        if self.type == 'text':
+            return self.text
+        elif self.type == 'image_url':
+            return f'[图片]({self.image_url})'
+        else:
+            return '未知内容'
+
+    @classmethod
+    def from_text(cls, text: str):
+        return cls(type='text', text=text)
+
+    @classmethod
+    def from_image_url(cls, image_url: str):
+        return cls(type='image_url', image_url=ImageURLContentObject(url=image_url))
+
+
 class Message(pydantic.BaseModel):
     """Message"""
@@ -28,12 +63,9 @@ class Message(pydantic.BaseModel):
     name: typing.Optional[str] = None
     """name, set only on function-call returns"""
 
-    content: typing.Optional[str] = None
+    content: typing.Optional[list[ContentElement]] | typing.Optional[str] = None
     """content"""
 
-    function_call: typing.Optional[FunctionCall] = None
-    """function call; no longer supported, use tool_calls"""
-
     tool_calls: typing.Optional[list[ToolCall]] = None
     """tool calls"""
@@ -41,10 +73,47 @@ class Message(pydantic.BaseModel):
     def readable_str(self) -> str:
         if self.content is not None:
-            return self.content
-        elif self.function_call is not None:
-            return f'{self.function_call.name}({self.function_call.arguments})'
+            return str(self.role) + ": " + str(self.get_content_mirai_message_chain())
         elif self.tool_calls is not None:
             return f'调用工具: {self.tool_calls[0].id}'
         else:
             return '未知消息'
+
+    def get_content_mirai_message_chain(self, prefix_text: str="") -> mirai.MessageChain | None:
+        """Convert the content to a mirai MessageChain object
+
+        Args:
+            prefix_text (str): prefix text for the first Plain component
+        """
+        if self.content is None:
+            return None
+        elif isinstance(self.content, str):
+            return mirai.MessageChain([mirai.Plain(prefix_text+self.content)])
+        elif isinstance(self.content, list):
+            mc = []
+            for ce in self.content:
+                if ce.type == 'text':
+                    mc.append(mirai.Plain(ce.text))
+                elif ce.type == 'image':
+                    if ce.image_url.url.startswith("http"):
+                        mc.append(mirai.Image(url=ce.image_url.url))
+                    else:  # base64
+                        b64_str = ce.image_url.url
+                        if b64_str.startswith("data:"):
+                            b64_str = b64_str.split(",")[1]
+                        mc.append(mirai.Image(base64=b64_str))
+
+            # find the first text component
+            if prefix_text:
+                for i, c in enumerate(mc):
+                    if isinstance(c, mirai.Plain):
+                        mc[i] = mirai.Plain(prefix_text+c.text)
+                        break
+                else:
+                    mc.insert(0, mirai.Plain(prefix_text))
+
+            return mirai.MessageChain(mc)
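get_content_mirai_message_chain normalizes both plain-string and multi-element content, and strips the `data:...;base64,` header before handing a payload to the image type. A stdlib sketch of the element walk, with tuples standing in for mirai components (simplified: the prefix is always inserted as its own leading part, and 'image_url' is used as the element tag):

```python
def content_to_chain(content, prefix_text: str = ""):
    """Convert str-or-list content to a flat list of (kind, value) parts."""
    if content is None:
        return None
    if isinstance(content, str):
        return [("text", prefix_text + content)]
    parts = []
    for el in content:
        if el["type"] == "text":
            parts.append(("text", el["text"]))
        elif el["type"] == "image_url":
            url = el["image_url"]["url"]
            if url.startswith("data:"):
                # strip the "data:image/...;base64," header, keep the raw payload
                parts.append(("image_b64", url.split(",", 1)[1]))
            else:
                parts.append(("image_url", url))
    if prefix_text:
        parts.insert(0, ("text", prefix_text))
    return parts

parts = content_to_chain(
    [{"type": "image_url", "image_url": {"url": "data:image/jpeg;base64,QUJD"}},
     {"type": "text", "text": "hi"}],
    prefix_text="[bot] ",
)
assert parts[0] == ("text", "[bot] ")
assert parts[1] == ("image_b64", "QUJD")
assert parts[2] == ("text", "hi")
```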


@@ -6,6 +6,8 @@ import typing
 from ...core import app
 from ...core import entities as core_entities
 from .. import entities as llm_entities
+from . import entities as modelmgr_entities
+from ..tools import entities as tools_entities
 
 
 preregistered_requesters: list[typing.Type[LLMAPIRequester]] = []
@@ -33,20 +35,31 @@ class LLMAPIRequester(metaclass=abc.ABCMeta):
     async def initialize(self):
         pass
 
-    @abc.abstractmethod
-    async def request(
+    async def preprocess(
         self,
         query: core_entities.Query,
-    ) -> typing.AsyncGenerator[llm_entities.Message, None]:
-        """Request the API
-
-        The conversation context is available from the query object.
-        Message objects may be yielded multiple times.
-
-        Args:
-            query (core_entities.Query): context object for this request
-
-        Yields:
-            pkg.provider.entities.Message: response message object
-        """
-        raise NotImplementedError
+    ):
+        """Preprocess
+
+        Handle requester-specific compatibility issues with the Query object here.
+        """
+        pass
+
+    @abc.abstractmethod
+    async def call(
+        self,
+        model: modelmgr_entities.LLMModelInfo,
+        messages: typing.List[llm_entities.Message],
+        funcs: typing.List[tools_entities.LLMFunction] = None,
+    ) -> llm_entities.Message:
+        """Call the API
+
+        Args:
+            model (modelmgr_entities.LLMModelInfo): model info to use
+            messages (typing.List[llm_entities.Message]): list of message objects
+            funcs (typing.List[tools_entities.LLMFunction], optional): tool functions to use. Defaults to None.
+
+        Returns:
+            llm_entities.Message: response message object
+        """
+        pass


@@ -11,6 +11,7 @@ from .. import api, entities, errors
 from ....core import entities as core_entities
 from ... import entities as llm_entities
 from ...tools import entities as tools_entities
+from ....utils import image
 
 
 @api.requester_class("anthropic-messages")
@@ -27,47 +28,76 @@ class AnthropicMessages(api.LLMAPIRequester):
             proxies=self.ap.proxy_mgr.get_forward_proxies()
         )
 
-    async def request(
+    async def call(
         self,
-        query: core_entities.Query,
-    ) -> typing.AsyncGenerator[llm_entities.Message, None]:
-        self.client.api_key = query.use_model.token_mgr.get_token()
+        model: entities.LLMModelInfo,
+        messages: typing.List[llm_entities.Message],
+        funcs: typing.List[tools_entities.LLMFunction] = None,
+    ) -> llm_entities.Message:
+        self.client.api_key = model.token_mgr.get_token()
 
         args = self.ap.provider_cfg.data['requester']['anthropic-messages']['args'].copy()
-        args["model"] = query.use_model.name if query.use_model.model_name is None else query.use_model.model_name
+        args["model"] = model.name if model.model_name is None else model.model_name
 
-        req_messages = [  # req_messages is class-internal only; external sync goes through query.messages
-            m.dict(exclude_none=True) for m in query.prompt.messages
-        ] + [m.dict(exclude_none=True) for m in query.messages]
+        # process messages
 
-        # drop all messages with role=system & content=''
-        req_messages = [
-            m for m in req_messages if not (m["role"] == "system" and m["content"].strip() == "")
-        ]
+        # system
+        system_role_message = None
 
-        # if there are role=system messages, change them to role=user and append a role=assistant message after each
-        system_role_index = []
-        for i, m in enumerate(req_messages):
-            if m["role"] == "system":
-                system_role_index.append(i)
-                m["role"] = "user"
+        for i, m in enumerate(messages):
+            if m.role == "system":
+                system_role_message = m
+                messages.pop(i)
+                break
 
-        if system_role_index:
-            for i in system_role_index[::-1]:
-                req_messages.insert(i + 1, {"role": "assistant", "content": "Okay, I'll follow."})
+        if isinstance(system_role_message, llm_entities.Message) \
+            and isinstance(system_role_message.content, str):
+            args['system'] = system_role_message.content
 
-        # ignore empty messages: users may send them and upper layers do not filter them out
-        req_messages = [
-            m for m in req_messages if m["content"].strip() != ""
-        ]
+        req_messages = []
+
+        for m in messages:
+            if isinstance(m.content, str) and m.content.strip() != "":
+                req_messages.append(m.dict(exclude_none=True))
+            elif isinstance(m.content, list):
+                # m.content = [
+                #     c for c in m.content if c.type == "text"
+                # ]
+                # if len(m.content) > 0:
+                #     req_messages.append(m.dict(exclude_none=True))
+                msg_dict = m.dict(exclude_none=True)
+
+                for i, ce in enumerate(m.content):
+                    if ce.type == "image_url":
+                        alter_image_ele = {
+                            "type": "image",
+                            "source": {
+                                "type": "base64",
+                                "media_type": "image/jpeg",
+                                "data": await image.qq_image_url_to_base64(ce.image_url.url)
+                            }
+                        }
+                        msg_dict["content"][i] = alter_image_ele
+
+                req_messages.append(msg_dict)
 
         args["messages"] = req_messages
 
+        # anthropic's tools are in beta and the sdk is unstable, so not supported for now
+        #
+        # if funcs:
+        #     tools = await self.ap.tool_mgr.generate_tools_for_openai(funcs)
+        #     if tools:
+        #         args["tools"] = tools
+
         try:
             resp = await self.client.messages.create(**args)
 
-            yield llm_entities.Message(
+            return llm_entities.Message(
                 content=resp.content[0].text,
                 role=resp.role
             )
@@ -79,4 +109,4 @@ class AnthropicMessages(api.LLMAPIRequester):
             if 'model: ' in str(e):
                 raise errors.RequesterError(f'模型无效: {e.message}')
             else:
                 raise errors.RequesterError(f'请求地址无效: {e.message}')
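Anthropic's Messages API takes the system prompt as a top-level `system` parameter rather than a `role: "system"` entry in the message list, which is why the code above pops the first system message out before building the request. A sketch of that split (simplified: it also requires string content, like the code above):

```python
def split_system(messages):
    """Pop the first system message; return (system_text, remaining_messages)."""
    system_text = None
    rest = list(messages)
    for i, m in enumerate(rest):
        if m["role"] == "system" and isinstance(m.get("content"), str):
            system_text = m["content"]
            rest.pop(i)
            break  # only the first system message is promoted
    return system_text, rest

system, rest = split_system([
    {"role": "system", "content": "You are helpful."},
    {"role": "user", "content": "hi"},
])
assert system == "You are helpful."
assert rest == [{"role": "user", "content": "hi"}]
```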


@@ -3,16 +3,20 @@ from __future__ import annotations
 import asyncio
 import typing
 import json
+import base64
 from typing import AsyncGenerator
 
 import openai
 import openai.types.chat.chat_completion as chat_completion
 import httpx
+import aiohttp
+import async_lru
 
 from .. import api, entities, errors
 from ....core import entities as core_entities, app
 from ... import entities as llm_entities
 from ...tools import entities as tools_entities
+from ....utils import image
 
 
 @api.requester_class("openai-chat-completions")
@@ -43,7 +47,6 @@ class OpenAIChatCompletions(api.LLMAPIRequester):
         self,
         args: dict,
     ) -> chat_completion.ChatCompletion:
-        self.ap.logger.debug(f"req chat_completion with args {args}")
         return await self.client.chat.completions.create(**args)
 
     async def _make_msg(
@@ -67,14 +70,22 @@ class OpenAIChatCompletions(api.LLMAPIRequester):
         args = self.requester_cfg['args'].copy()
         args["model"] = use_model.name if use_model.model_name is None else use_model.model_name
 
-        if use_model.tool_call_supported:
+        if use_funcs:
             tools = await self.ap.tool_mgr.generate_tools_for_openai(use_funcs)
 
             if tools:
                 args["tools"] = tools
 
         # set the messages for this request
-        messages = req_messages
+        messages = req_messages.copy()
+
+        # check for vision content
+        for msg in messages:
+            if 'content' in msg and isinstance(msg["content"], list):
+                for me in msg["content"]:
+                    if me["type"] == "image_url":
+                        me["image_url"]['url'] = await self.get_base64_str(me["image_url"]['url'])
 
         args["messages"] = messages
 
         # send the request
@@ -84,59 +95,26 @@ class OpenAIChatCompletions(api.LLMAPIRequester):
         message = await self._make_msg(resp)
 
         return message
 
-    async def _request(
-        self, query: core_entities.Query
-    ) -> typing.AsyncGenerator[llm_entities.Message, None]:
-        """Request"""
-
-        pending_tool_calls = []
-
-        req_messages = [  # req_messages is class-internal only; external sync goes through query.messages
-            m.dict(exclude_none=True) for m in query.prompt.messages
-        ] + [m.dict(exclude_none=True) for m in query.messages]
-
-        # req_messages.append({"role": "user", "content": str(query.message_chain)})
-
-        msg = await self._closure(req_messages, query.use_model, query.use_funcs)
-
-        yield msg
-
-        pending_tool_calls = msg.tool_calls
-
-        req_messages.append(msg.dict(exclude_none=True))
-
-        while pending_tool_calls:
-            for tool_call in pending_tool_calls:
-                func = tool_call.function
-
-                parameters = json.loads(func.arguments)
-
-                func_ret = await self.ap.tool_mgr.execute_func_call(
-                    query, func.name, parameters
-                )
-
-                msg = llm_entities.Message(
-                    role="tool", content=json.dumps(func_ret, ensure_ascii=False), tool_call_id=tool_call.id
-                )
-
-                yield msg
-
-                req_messages.append(msg.dict(exclude_none=True))
-
-            # all calls handled; continue the request
-            msg = await self._closure(req_messages, query.use_model, query.use_funcs)
-
-            yield msg
-
-            pending_tool_calls = msg.tool_calls
-
-            req_messages.append(msg.dict(exclude_none=True))
-
-    async def request(self, query: core_entities.Query) -> AsyncGenerator[llm_entities.Message, None]:
+    async def call(
+        self,
+        model: entities.LLMModelInfo,
+        messages: typing.List[llm_entities.Message],
+        funcs: typing.List[tools_entities.LLMFunction] = None,
+    ) -> llm_entities.Message:
+        req_messages = []  # req_messages is class-internal only; external sync goes through query.messages
+        for m in messages:
+            msg_dict = m.dict(exclude_none=True)
+            content = msg_dict.get("content")
+            if isinstance(content, list):
+                # check whether every part of the content list is text
+                if all(isinstance(part, dict) and part.get("type") == "text" for part in content):
+                    # merge all text parts into a single string
+                    msg_dict["content"] = "\n".join(part["text"] for part in content)
+            req_messages.append(msg_dict)
+
         try:
-            async for msg in self._request(query):
-                yield msg
+            return await self._closure(req_messages, model, funcs)
         except asyncio.TimeoutError:
             raise errors.RequesterError('请求超时')
         except openai.BadRequestError as e:
@@ -149,6 +127,16 @@ class OpenAIChatCompletions(api.LLMAPIRequester):
         except openai.NotFoundError as e:
             raise errors.RequesterError(f'请求路径错误: {e.message}')
         except openai.RateLimitError as e:
-            raise errors.RequesterError(f'请求过于频繁: {e.message}')
+            raise errors.RequesterError(f'请求过于频繁或余额不足: {e.message}')
         except openai.APIError as e:
             raise errors.RequesterError(f'请求错误: {e.message}')
+
+    @async_lru.alru_cache(maxsize=128)
+    async def get_base64_str(
+        self,
+        original_url: str,
+    ) -> str:
+        base64_image = await image.qq_image_url_to_base64(original_url)
+        return f"data:image/jpeg;base64,{base64_image}"


@@ -0,0 +1,53 @@
from __future__ import annotations
from ....core import app
from . import chatcmpl
from .. import api, entities, errors
from ....core import entities as core_entities, app
from ... import entities as llm_entities
from ...tools import entities as tools_entities
@api.requester_class("deepseek-chat-completions")
class DeepseekChatCompletions(chatcmpl.OpenAIChatCompletions):
"""Deepseek ChatCompletion API 请求器"""
def __init__(self, ap: app.Application):
self.requester_cfg = ap.provider_cfg.data['requester']['deepseek-chat-completions']
self.ap = ap
async def _closure(
self,
req_messages: list[dict],
use_model: entities.LLMModelInfo,
use_funcs: list[tools_entities.LLMFunction] = None,
) -> llm_entities.Message:
self.client.api_key = use_model.token_mgr.get_token()
args = self.requester_cfg['args'].copy()
args["model"] = use_model.name if use_model.model_name is None else use_model.model_name
if use_funcs:
tools = await self.ap.tool_mgr.generate_tools_for_openai(use_funcs)
if tools:
args["tools"] = tools
        # set the messages for this request
        messages = req_messages

        # deepseek does not support multimodal content; flatten content to plain text
        for m in messages:
            if 'content' in m and isinstance(m["content"], list):
                m["content"] = " ".join([c["text"] for c in m["content"]])

        args["messages"] = messages

        # send the request
        resp = await self._req(args)

        # process the result
        message = await self._make_msg(resp)

        return message
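Because deepseek's endpoint only accepts string content, the requester flattens any list-style content down to its text parts before sending. A sketch of that flattening (this version adds a type filter the code above omits, so image-only elements are skipped rather than raising):

```python
def flatten_content(messages: list) -> list:
    """Join list-style content into a single string for text-only APIs."""
    for m in messages:
        if isinstance(m.get("content"), list):
            m["content"] = " ".join(
                c["text"] for c in m["content"] if c.get("type") == "text"
            )
    return messages

msgs = flatten_content([
    {"role": "user", "content": [{"type": "text", "text": "hello"},
                                 {"type": "text", "text": "world"}]},
])
assert msgs[0]["content"] == "hello world"
```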


@@ -3,7 +3,10 @@ from __future__ import annotations
 
 from ....core import app
 from . import chatcmpl
-from .. import api
+from .. import api, entities, errors
+from ....core import entities as core_entities, app
+from ... import entities as llm_entities
+from ...tools import entities as tools_entities
 
 
 @api.requester_class("moonshot-chat-completions")
@@ -13,3 +16,41 @@ class MoonshotChatCompletions(chatcmpl.OpenAIChatCompletions):
 
     def __init__(self, ap: app.Application):
         self.requester_cfg = ap.provider_cfg.data['requester']['moonshot-chat-completions']
         self.ap = ap
+
+    async def _closure(
+        self,
+        req_messages: list[dict],
+        use_model: entities.LLMModelInfo,
+        use_funcs: list[tools_entities.LLMFunction] = None,
+    ) -> llm_entities.Message:
+        self.client.api_key = use_model.token_mgr.get_token()
+
+        args = self.requester_cfg['args'].copy()
+        args["model"] = use_model.name if use_model.model_name is None else use_model.model_name
+
+        if use_funcs:
+            tools = await self.ap.tool_mgr.generate_tools_for_openai(use_funcs)
+            if tools:
+                args["tools"] = tools
+
+        # set the messages for this request
+        messages = req_messages
+
+        # moonshot does not support multimodal content; flatten content to plain text
+        for m in messages:
+            if 'content' in m and isinstance(m["content"], list):
+                m["content"] = " ".join([c["text"] for c in m["content"]])
+
+        # drop empty messages
+        messages = [m for m in messages if m["content"].strip() != ""]
+
+        args["messages"] = messages
+
+        # send the request
+        resp = await self._req(args)
+
+        # process the result
+        message = await self._make_msg(resp)
+
+        return message


@@ -21,5 +21,7 @@ class LLMModelInfo(pydantic.BaseModel):
 
     tool_call_supported: typing.Optional[bool] = False
 
+    vision_supported: typing.Optional[bool] = False
+
     class Config:
         arbitrary_types_allowed = True


@@ -6,7 +6,7 @@ from . import entities
 
 from ...core import app
 from . import token, api
-from .apis import chatcmpl, anthropicmsgs, moonshotchatcmpl
+from .apis import chatcmpl, anthropicmsgs, moonshotchatcmpl, deepseekchatcmpl
 
 FETCH_MODEL_LIST_URL = "https://api.qchatgpt.rockchin.top/api/v2/fetch/model_list"
 
@@ -37,7 +37,7 @@ class ModelManager:
         raise ValueError(f"无法确定模型 {name} 的信息,请在元数据中配置")
 
     async def initialize(self):
         # initialize token_mgr and requester instances
         for k, v in self.ap.provider_cfg.data['keys'].items():
             self.token_mgrs[k] = token.TokenManager(k, v)
 
@@ -83,7 +83,8 @@ class ModelManager:
                     model_name=None,
                     token_mgr=self.token_mgrs[model['token_mgr']],
                     requester=self.requesters[model['requester']],
-                    tool_call_supported=model['tool_call_supported']
+                    tool_call_supported=model['tool_call_supported'],
+                    vision_supported=model['vision_supported']
                 )
                 break
 
@@ -95,13 +96,15 @@ class ModelManager:
             token_mgr = self.token_mgrs[model['token_mgr']] if 'token_mgr' in model else default_model_info.token_mgr
             requester = self.requesters[model['requester']] if 'requester' in model else default_model_info.requester
             tool_call_supported = model.get('tool_call_supported', default_model_info.tool_call_supported)
+            vision_supported = model.get('vision_supported', default_model_info.vision_supported)
 
             model_info = entities.LLMModelInfo(
                 name=model['name'],
                 model_name=model_name,
                 token_mgr=token_mgr,
                 requester=requester,
-                tool_call_supported=tool_call_supported
+                tool_call_supported=tool_call_supported,
+                vision_supported=vision_supported
             )
 
             self.model_list.append(model_info)


@@ -9,11 +9,10 @@ from ...plugin import context as plugin_context
class ToolManager: class ToolManager:
"""LLM工具管理器 """LLM工具管理器"""
"""
ap: app.Application ap: app.Application
def __init__(self, ap: app.Application): def __init__(self, ap: app.Application):
self.ap = ap self.ap = ap
self.all_functions = [] self.all_functions = []
@@ -22,35 +21,33 @@ class ToolManager:
pass pass
async def get_function(self, name: str) -> entities.LLMFunction: async def get_function(self, name: str) -> entities.LLMFunction:
"""获取函数 """获取函数"""
"""
for function in await self.get_all_functions(): for function in await self.get_all_functions():
if function.name == name: if function.name == name:
return function return function
return None return None
async def get_function_and_plugin(self, name: str) -> typing.Tuple[entities.LLMFunction, plugin_context.BasePlugin]: async def get_function_and_plugin(
"""获取函数和插件 self, name: str
""" ) -> typing.Tuple[entities.LLMFunction, plugin_context.BasePlugin]:
"""获取函数和插件"""
for plugin in self.ap.plugin_mgr.plugins: for plugin in self.ap.plugin_mgr.plugins:
for function in plugin.content_functions: for function in plugin.content_functions:
if function.name == name: if function.name == name:
return function, plugin return function, plugin.plugin_inst
return None, None return None, None
async def get_all_functions(self) -> list[entities.LLMFunction]: async def get_all_functions(self) -> list[entities.LLMFunction]:
"""获取所有函数 """获取所有函数"""
"""
all_functions: list[entities.LLMFunction] = [] all_functions: list[entities.LLMFunction] = []
for plugin in self.ap.plugin_mgr.plugins: for plugin in self.ap.plugin_mgr.plugins:
all_functions.extend(plugin.content_functions) all_functions.extend(plugin.content_functions)
return all_functions return all_functions
async def generate_tools_for_openai(self, use_funcs: entities.LLMFunction) -> str: async def generate_tools_for_openai(self, use_funcs: list[entities.LLMFunction]) -> list:
"""生成函数列表 """生成函数列表"""
"""
tools = [] tools = []
for function in use_funcs: for function in use_funcs:
@@ -60,40 +57,71 @@ class ToolManager:
                "function": {
                    "name": function.name,
                    "description": function.description,
                    "parameters": function.parameters,
                },
            }

            tools.append(function_schema)

        return tools

    async def generate_tools_for_anthropic(
        self, use_funcs: list[entities.LLMFunction]
    ) -> list:
        """Generate the tools list for Anthropic

        e.g.

        [
            {
                "name": "get_stock_price",
                "description": "Get the current stock price for a given ticker symbol.",
                "input_schema": {
                    "type": "object",
                    "properties": {
                        "ticker": {
                            "type": "string",
                            "description": "The stock ticker symbol, e.g. AAPL for Apple Inc."
                        }
                    },
                    "required": ["ticker"]
                }
            }
        ]
        """
        tools = []

        for function in use_funcs:
            if function.enable:
                function_schema = {
                    "name": function.name,
                    "description": function.description,
                    "input_schema": function.parameters,
                }
                tools.append(function_schema)

        return tools
    async def execute_func_call(
        self, query: core_entities.Query, name: str, parameters: dict
    ) -> typing.Any:
        """Execute a function call"""
        try:
            function, plugin = await self.get_function_and_plugin(name)
            if function is None:
                return None

            parameters = parameters.copy()
            parameters = {"query": query, **parameters}

            return await function.func(plugin, **parameters)
        except Exception as e:
            self.ap.logger.error(f"执行函数 {name} 时发生错误: {e}")
            traceback.print_exc()
            return f"error occurred when executing function {name}: {e}"
        finally:
            plugin = None
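The `parameters.copy()` followed by a rebuild with `{"query": query, **parameters}` keeps the caller's dict untouched and injects `query` first, so a same-named key in the model-supplied arguments would take precedence. A small standalone sketch of that merge semantics (the helper name is hypothetical):

```python
def build_call_args(query, parameters: dict) -> dict:
    """Merge the query object into the tool-call parameters."""
    # "query" is written first, so a same-named key coming from the
    # model's arguments would take precedence in the merged dict.
    return {"query": query, **parameters}


original = {"ticker": "AAPL"}
args = build_call_args("q-1", original)
# original is left unmodified; args carries both keys
```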
@@ -107,11 +135,11 @@ class ToolManager:
            await self.ap.ctr_mgr.usage.post_function_record(
                plugin={
                    "name": plugin.plugin_name,
                    "remote": plugin.plugin_source,
                    "version": plugin.plugin_version,
                    "author": plugin.plugin_author,
                },
                function_name=function.name,
                function_description=function.description,
            )


@@ -4,6 +4,7 @@ import json
import typing
import os
import base64
import logging

import pydantic
import requests
@@ -107,17 +108,20 @@ class AnnouncementManager:
    async def show_announcements(
        self
    ) -> typing.Tuple[str, int]:
        """Show announcements"""
        try:
            announcements = await self.fetch_new()
            ann_text = ""
            for ann in announcements:
                ann_text += f"[公告] {ann.time}: {ann.content}\n"

            if announcements:
                await self.ap.ctr_mgr.main.post_announcement_showed(
                    ids=[item.id for item in announcements]
                )
            return ann_text, logging.INFO
        except Exception as e:
            return f'获取公告时出错: {e}', logging.WARNING


@@ -1 +1 @@
semantic_version = "v3.2.4"

pkg/utils/image.py (new file)

@@ -0,0 +1,41 @@
import base64
import typing
from urllib.parse import urlparse, parse_qs
import ssl

import aiohttp


async def qq_image_url_to_base64(
    image_url: str
) -> str:
    """Convert a QQ image URL to base64

    Args:
        image_url (str): QQ image URL

    Returns:
        str: base64-encoded content
    """
    parsed = urlparse(image_url)
    query = parse_qs(parsed.query)

    # Flatten the query dictionary
    query = {k: v[0] for k, v in query.items()}

    ssl_context = ssl.create_default_context()
    ssl_context.check_hostname = False
    ssl_context.verify_mode = ssl.CERT_NONE

    async with aiohttp.ClientSession(trust_env=False) as session:
        async with session.get(
            f"http://{parsed.netloc}{parsed.path}",
            params=query,
            ssl=ssl_context
        ) as resp:
            resp.raise_for_status()  # raise on HTTP errors
            file_bytes = await resp.read()
            base64_str = base64.b64encode(file_bytes).decode()

    return base64_str
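The query-flattening step above relies on `parse_qs` returning a list of values for each key; keeping `v[0]` collapses it back to a plain str-to-str mapping (discarding any duplicate keys). A standalone illustration, with a made-up QQ-style URL:

```python
from urllib.parse import urlparse, parse_qs

# A made-up QQ-style image URL, for illustration only
url = "http://gchat.qpic.cn/gchatpic_new/0/0-0-ABC/0?term=2&is_origin=0"
parsed = urlparse(url)

# parse_qs maps every key to a list of values; keeping v[0]
# flattens it back to a plain str -> str dict
raw = parse_qs(parsed.query)
flat = {k: v[0] for k, v in raw.items()}
```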


@@ -1,6 +1,8 @@
from __future__ import annotations

import os
import typing
import logging
import time

import requests
@@ -213,11 +215,11 @@ class VersionManager:
    async def show_version_update(
        self
    ) -> typing.Tuple[str, int]:
        try:
            if await self.ap.ver_mgr.is_new_version_available():
                return "有新版本可用,请使用管理员账号发送 !update 命令更新", logging.INFO
        except Exception as e:
            return f"检查版本更新时出错: {e}", logging.WARNING
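Both `show_announcements` and `show_version_update` now return a `(text, level)` tuple instead of logging directly, so the new startup-information stage can decide how to emit them. A hedged sketch of such a consumer (the `emit` helper is hypothetical, not part of the source):

```python
import logging


def emit(logger: logging.Logger, results) -> list:
    """Log each (text, level) tuple, skipping empty messages."""
    emitted = []
    for text, level in results:
        if text:
            logger.log(level, text)
            emitted.append((text, level))
    return emitted


logger = logging.getLogger("startup-stage")
emitted = emit(logger, [
    ("有新版本可用", logging.INFO),
    ("", logging.INFO),            # empty text: skipped
    ("fetch failed", logging.WARNING),
])
```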


@@ -12,4 +12,6 @@ PyYaml
aiohttp
pydantic
websockets
urllib3
psutil
async-lru


@@ -4,23 +4,73 @@
            "name": "default",
            "requester": "openai-chat-completions",
            "token_mgr": "openai",
            "tool_call_supported": false,
            "vision_supported": false
        },
        {
            "name": "gpt-3.5-turbo-0125",
            "tool_call_supported": true,
            "vision_supported": false
        },
        {
            "name": "gpt-3.5-turbo",
            "tool_call_supported": true,
            "vision_supported": false
        },
        {
            "name": "gpt-3.5-turbo-1106",
            "tool_call_supported": true,
            "vision_supported": false
        },
        {
            "name": "gpt-4-turbo",
            "tool_call_supported": true,
            "vision_supported": true
        },
        {
            "name": "gpt-4-turbo-2024-04-09",
            "tool_call_supported": true,
            "vision_supported": true
        },
        {
            "name": "gpt-4-turbo-preview",
            "tool_call_supported": true,
            "vision_supported": true
        },
        {
            "name": "gpt-4-0125-preview",
            "tool_call_supported": true,
            "vision_supported": true
        },
        {
            "name": "gpt-4-1106-preview",
            "tool_call_supported": true,
            "vision_supported": true
        },
        {
            "name": "gpt-4",
            "tool_call_supported": true,
            "vision_supported": true
        },
        {
            "name": "gpt-4o",
            "tool_call_supported": true,
            "vision_supported": true
        },
        {
            "name": "gpt-4-0613",
            "tool_call_supported": true,
            "vision_supported": true
        },
        {
            "name": "gpt-4-32k",
            "tool_call_supported": true,
            "vision_supported": true
        },
        {
            "name": "gpt-4-32k-0613",
            "tool_call_supported": true,
            "vision_supported": true
        },
        {
            "model_name": "SparkDesk",
@@ -33,32 +83,43 @@
        {
            "name": "claude-3-opus-20240229",
            "requester": "anthropic-messages",
            "token_mgr": "anthropic",
            "vision_supported": true
        },
        {
            "name": "claude-3-sonnet-20240229",
            "requester": "anthropic-messages",
            "token_mgr": "anthropic",
            "vision_supported": true
        },
        {
            "name": "claude-3-haiku-20240307",
            "requester": "anthropic-messages",
            "token_mgr": "anthropic",
            "vision_supported": true
        },
        {
            "name": "moonshot-v1-8k",
            "requester": "moonshot-chat-completions",
            "token_mgr": "moonshot",
            "tool_call_supported": true
        },
        {
            "name": "moonshot-v1-32k",
            "requester": "moonshot-chat-completions",
            "token_mgr": "moonshot",
            "tool_call_supported": true
        },
        {
            "name": "moonshot-v1-128k",
            "requester": "moonshot-chat-completions",
            "token_mgr": "moonshot",
            "tool_call_supported": true
        },
        {
            "name": "deepseek-chat",
            "requester": "deepseek-chat-completions",
            "token_mgr": "deepseek"
        }
    ]
}
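The per-model `tool_call_supported` / `vision_supported` flags in this list can be read back with a simple name lookup. A sketch over a small subset of the entries above (treating flags absent from an entry, e.g. on `deepseek-chat`, as false — an assumption about the consumer, not stated in the file):

```python
import json

MODELS_SUBSET = """
[
    {"name": "gpt-3.5-turbo", "tool_call_supported": true, "vision_supported": false},
    {"name": "gpt-4o", "tool_call_supported": true, "vision_supported": true},
    {"name": "deepseek-chat", "requester": "deepseek-chat-completions", "token_mgr": "deepseek"}
]
"""

models = json.loads(MODELS_SUBSET)


def find_model(name: str):
    """Return the entry with the given name, or None."""
    for m in models:
        if m.get("name") == name:
            return m
    return None


def supports_vision(name: str) -> bool:
    # Missing flags are treated as False (assumption)
    m = find_model(name)
    return bool(m and m.get("vision_supported", False))
```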


@@ -25,12 +25,20 @@
        "api-key": "",
        "api-secret": ""
    },
    "rate-limit": {
        "strategy": "drop",
        "algo": "fixwin",
        "fixwin": {
            "default": {
                "window-size": 60,
                "limit": 60
            }
        }
    },
    "msg-truncate": {
        "method": "round",
        "round": {
            "max-round": 10
        }
    }
}
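The reshaped `fixwin` config pairs a `window-size` (seconds) with a `limit` (requests allowed per window). A minimal fixed-window limiter illustrating the algorithm, with an injectable clock for testability (this is a sketch of the technique, not the project's implementation):

```python
import time


class FixedWindowLimiter:
    """Allow at most `limit` events per `window_size`-second window."""

    def __init__(self, window_size: int = 60, limit: int = 60):
        self.window_size = window_size
        self.limit = limit
        self.window_start = float("-inf")
        self.count = 0

    def allow(self, now: float = None) -> bool:
        now = time.monotonic() if now is None else now
        if now - self.window_start >= self.window_size:
            # The window has elapsed: start a new one and reset the counter
            self.window_start = now
            self.count = 0
        if self.count < self.limit:
            self.count += 1
            return True
        return False
```

With the `"strategy": "drop"` setting, a `False` return would correspond to dropping the incoming message rather than queueing it.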


@@ -1,5 +1,6 @@
{
    "enable-chat": true,
    "enable-vision": true,
    "keys": {
        "openai": [
            "sk-1234567890"
@@ -9,6 +10,9 @@
        ],
        "moonshot": [
            "sk-1234567890"
        ],
        "deepseek": [
            "sk-1234567890"
        ]
    },
    "requester": {
@@ -28,6 +32,11 @@
        "base-url": "https://api.moonshot.cn/v1",
        "args": {},
        "timeout": 120
        },
        "deepseek-chat-completions": {
            "base-url": "https://api.deepseek.com",
            "args": {},
            "timeout": 120
        }
    },
    "model": "gpt-3.5-turbo",


@@ -10,5 +10,6 @@
        "default": 1
    },
    "pipeline-concurrency": 20,
    "qcg-center-url": "https://api.qchatgpt.rockchin.top/api/v2",
    "help-message": "QChatGPT - 😎高稳定性、🧩支持插件、🌏实时联网的 ChatGPT QQ 机器人🤖\n链接:https://q.rkcn.top"
}