Compare commits

...

204 Commits

Author SHA1 Message Date
RockChinQ
ee0d6dcdae chore: release v3.3.1.0 2024-09-08 15:14:24 +08:00
Junyan Qin
bcf1d92f73 Merge pull request #881 from RockChinQ/version/3.3.1.0
Version/3.3.1.0
2024-09-08 15:13:39 +08:00
RockChinQ
ffdec16ce6 docs: wiki 所有页面加上已弃用说明 2024-09-08 14:52:35 +08:00
RockChinQ
b2f6e84adc typo: 优化插件执行日志信息 2024-09-08 14:51:39 +08:00
Junyan Qin
f76c457e1f Update README.md 2024-09-03 20:07:41 +08:00
RockChinQ
80bd0a20df doc: 修复 README 中的logo图片 2024-08-30 14:48:23 +08:00
RockChinQ
efeaf73339 doc: 修改README图片链接 2024-08-30 11:13:04 +08:00
Junyan Qin
91b5100a24 Merge pull request #872 from RockChinQ/feat/config-file-api
Feat: 添加yaml配置文件的支持
2024-08-24 20:55:19 +08:00
RockChinQ
d1a06f4730 feat: 添加yaml配置文件的支持 2024-08-24 20:54:36 +08:00
Junyan Qin
b0b186e951 Merge pull request #871 from RockChinQ/feat/qq-c2c
Feat: 添加对 QQ 官方 API 私聊场景的支持
2024-08-24 17:04:41 +08:00
RockChinQ
4c8fedef6e feat: QQ官方api群聊和私聊支持图片 2024-08-24 17:01:35 +08:00
RockChinQ
718c221d01 feat: 支持官方机器人私信接口 2024-08-24 16:26:47 +08:00
Junyan Qin
077e77eee5 Merge pull request #869 from ligen131/lg/fix_image_format
fix: 发送正确的图片格式而不是默认的 `image/jpeg`
2024-08-24 15:47:55 +08:00
ligen131
b51ca06c7c fix: 发送正确的图片格式而不是默认的 image/jpeg 2024-08-19 00:00:29 +08:00
RockChinQ
2f092f4a87 chore: release v3.3.0.2 2024-08-01 23:14:07 +08:00
Junyan Qin
f1ff9c05c4 Merge pull request #864 from RockChinQ/version/3.3.0.2
fix: 消息忽略规则失效 (#854)
2024-08-01 23:12:33 +08:00
RockChinQ
c9c8603ccc fix: 消息忽略规则失效 (#854) 2024-08-01 23:01:28 +08:00
RockChinQ
47e281fb61 chore: release v3.3.0.1 2024-07-28 22:47:49 +08:00
RockChinQ
dc625647eb fix: ollama 依赖检查 2024-07-28 22:47:19 +08:00
RockChinQ
66cf1b05be chore: 优化issue和pr模板 2024-07-28 21:32:22 +08:00
RockChinQ
622cc89414 chore: release v3.3.0 2024-07-28 20:58:29 +08:00
Junyan Qin
78d98c40b1 Merge pull request #847 from RockChinQ/version/3.3
Release: 3.3
2024-07-28 20:57:26 +08:00
RockChinQ
1c5f06d9a9 feat: 添加 reply 和 send_message 两个插件api方法 2024-07-28 20:23:52 +08:00
Junyan Qin
998fe5a980 Merge pull request #857 from RockChinQ/feat/runner-abstraction
Feat: Runner 组件抽象
2024-07-28 18:47:38 +08:00
RockChinQ
8cad4089a7 feat: runner 层抽象 (#839) 2024-07-28 18:45:27 +08:00
RockChinQ
48cc3656bd feat: 允许自定义命令前缀 2024-07-28 16:01:58 +08:00
RockChinQ
68ddb3a6e1 feat: 添加 model 命令 2024-07-28 15:46:09 +08:00
ElvisChenML
70583f5ba0 Fixed aiocqhttp mirai.Voice类型无法正确传递url及base64的异常 2024-07-28 15:08:33 +08:00
Junyan Qin
5bebe01dd0 Update README.md 2024-07-28 15:08:33 +08:00
Junyan Qin
4dd976c9c5 Merge pull request #856 from ElvisChenML/pr
Fixed aiocqhttp mirai.Voice类型无法正确传递url及base64的异常
2024-07-28 13:05:06 +08:00
ElvisChenML
221b310485 Fixed aiocqhttp mirai.Voice类型无法正确传递url及base64的异常 2024-07-25 16:14:24 +08:00
Junyan Qin
dd1cec70c0 Update README.md 2024-07-13 09:15:18 +08:00
Junyan Qin
7656443b28 Merge pull request #845 from ElvisChenML/pr
fixed pkg\provider\entities.py\get_content_mirai_message_chain中ce.type图片类型不正确的异常
2024-07-10 00:13:48 +08:00
Junyan Qin
9d91c13b12 Merge pull request #844 from canyuan0801/pr
Feat: Ollama平台集成
2024-07-10 00:09:48 +08:00
RockChinQ
7c06141ce2 perf(ollama): 优化命令显示细节 2024-07-10 00:07:32 +08:00
RockChinQ
3dc413638b feat(ollama): 配置文件迁移 2024-07-09 23:37:34 +08:00
RockChinQ
bdb8baeddd perf(ollama): 修改请求器名称以适配请求路径 2024-07-09 23:37:19 +08:00
ElvisChenML
21966bfb69 fixed pkg\provider\entities.py\get_content_mirai_message_chain中ce.type图片类型不正确的异常 2024-07-09 17:04:11 +08:00
canyuan
e78c82e999 mod: merge ollama cmd 2024-07-09 16:19:09 +08:00
canyuan
2bdc3468d1 add ollama cmd 2024-07-09 14:57:39 +08:00
canyuan
987b3dc4ef add ollama chat 2024-07-09 14:57:28 +08:00
RockChinQ
45a10b4ac7 chore: release v3.2.4 2024-07-05 18:19:10 +08:00
RockChinQ
b5d33ef629 perf: 优化 pipeline 处理时的报错 2024-07-04 13:03:58 +08:00
RockChinQ
d3629916bf fix: user_notice 处理时为对齐为 MessageChain (#809) 2024-07-04 12:47:55 +08:00
RockChinQ
c5cb26d295 fix: GroupNormalMessageReceived事件设置 alter 无效 (#803) 2024-07-03 23:16:16 +08:00
RockChinQ
4b2785c5eb fix: QQ 官方 API 图片识别功能不正常 (#825) 2024-07-03 22:36:35 +08:00
RockChinQ
7ed190e6d2 doc: 删除广告 2024-07-03 17:50:58 +08:00
Junyan Qin
eac041cdd2 Merge pull request #834 from RockChinQ/feat/env-reminder
Feat: 添加启动信息阶段
2024-07-03 17:45:56 +08:00
RockChinQ
05527cfc01 feat: 添加 windows 下针对选择模式的提示 2024-07-03 17:44:10 +08:00
RockChinQ
61e2af4a14 feat: 添加启动信息阶段 2024-07-03 17:34:23 +08:00
RockChinQ
79804b6ecd chore: release 3.2.3 2024-06-26 10:55:21 +08:00
Junyan Qin
76434b2f4e Merge pull request #829 from RockChinQ/version/3.2.3
Release 3.2.3
2024-06-26 10:54:42 +08:00
RockChinQ
ec8bd4922e fix: 错误地resprule选择逻辑 (#810) 2024-06-26 10:37:08 +08:00
RockChinQ
4ffa773fac fix: 前缀响应时图片被错误地转换为文字 (#820) 2024-06-26 10:15:21 +08:00
Junyan Qin
ea8b7bc8aa Merge pull request #818 from Huoyuuu/master
fix: ensure content is string in chatcmpl call method
2024-06-24 17:12:46 +08:00
RockChinQ
39ce5646f6 perf: content元素拼接时使用换行符间隔 2024-06-24 17:04:50 +08:00
Huoyuuu
5092a82739 Update chatcmpl.py 2024-06-19 19:13:00 +08:00
Huoyuuu
3bba0b6d9a Merge pull request #1 from Huoyuuu/fix/issue-817-ensure-content-string
fix: ensure content is string in chatcmpl call method
2024-06-19 17:32:30 +08:00
Huoyuuu
7a19dd503d fix: ensure content is string in chatcmpl call method
fix: ensure content is string in chatcmpl call method

- Ensure user message content is a string instead of an array
- Updated `call` method in `chatcmpl.py` to guarantee content is a string
- Resolves compatibility issue with the yi-large model
2024-06-19 17:26:06 +08:00
RockChinQ
9e6a01fefd chore: release v3.2.2 2024-05-31 19:20:34 +08:00
RockChinQ
933471b4d9 perf: 启动失败时输出完整traceback (#799) 2024-05-31 15:37:56 +08:00
RockChinQ
f81808d239 perf: 添加JSON配置文件语法检查 (#796) 2024-05-29 21:11:21 +08:00
RockChinQ
96832b6f7d perf: 忽略空的 assistant content 消息 (#795) 2024-05-29 21:00:48 +08:00
Junyan Qin
e2eb0a84b0 Merge pull request #797 from RockChinQ/feat/context-truncater
Feat: 消息截断器
2024-05-29 20:38:14 +08:00
RockChinQ
c8eb2e3376 feat: 消息截断器 2024-05-29 20:34:49 +08:00
Junyan Qin
21fe5822f9 Merge pull request #794 from RockChinQ/perf/advanced-fixwin
Feat: fixwin限速支持设置窗口大小
2024-05-26 10:33:49 +08:00
RockChinQ
d49cc9a7a3 feat: fixwin限速支持设置窗口大小 (#791) 2024-05-26 10:29:10 +08:00
Junyan Qin
910d0bfae1 Update README.md 2024-05-25 12:27:27 +08:00
RockChinQ
d6761949ca chore: release v3.2.1 2024-05-23 16:29:26 +08:00
RockChinQ
6afac1f593 feat: 允许指定遥测服务器url 2024-05-23 16:25:51 +08:00
RockChinQ
4d1a270d22 doc: 添加qcg-center源码链接 2024-05-23 16:16:13 +08:00
Junyan Qin
a7888f5536 Merge pull request #787 from RockChinQ/perf/claude-ability
Perf: Claude 的能力完善支持
2024-05-22 20:33:39 +08:00
RockChinQ
b9049e91cf chore: 同步 llm-models.json 2024-05-22 20:31:46 +08:00
RockChinQ
7db56c8e77 feat: claude 支持视觉 2024-05-22 20:09:29 +08:00
Junyan Qin
50563cb957 Merge pull request #785 from RockChinQ/fix/msg-chain-compability
Fix: 修复 query.resp_messages 对插件reply的兼容性
2024-05-18 20:13:50 +08:00
RockChinQ
18ae2299a7 fix: 修复 query.resp_messages 对插件reply的兼容性 2024-05-18 20:08:48 +08:00
RockChinQ
7463e0aab9 perf: 删除多个地方残留的 config.py 字段 (#781) 2024-05-18 18:52:45 +08:00
Junyan Qin
c92d47bb95 Merge pull request #779 from jerryliang122/master
修复aiocqhttp的图片错误
2024-05-17 17:05:58 +08:00
RockChinQ
0b1af7df91 perf: 统一判断方式 2024-05-17 17:05:20 +08:00
jerryliang122
a9104eb2da 通过base64编码发送,修复cqhttp无法发送图片 2024-05-17 08:20:06 +00:00
RockChinQ
abbd15d5cc chore: release v3.2.0.1 2024-05-17 09:48:20 +08:00
RockChinQ
aadfa14d59 fix: claude 请求失败 2024-05-17 09:46:06 +08:00
Junyan Qin
4cd10bbe25 Update README.md 2024-05-16 22:17:46 +08:00
RockChinQ
1d4a6b71ab chore: release v3.2.0 2024-05-16 21:22:40 +08:00
Junyan Qin
a7f830dd73 Merge pull request #773 from RockChinQ/feat/multi-modal
Feat: 多模态
2024-05-16 21:13:15 +08:00
RockChinQ
bae86ac05c chore: 恢复版本号 2024-05-16 21:03:56 +08:00
RockChinQ
a3706bfe21 perf: 细节优化 2024-05-16 21:02:59 +08:00
RockChinQ
91e23b8c11 perf: 为图片base64函数添加lru 2024-05-16 20:52:17 +08:00
RockChinQ
37ef1c9fab feat: 删除oss相关代码 2024-05-16 20:32:30 +08:00
RockChinQ
6bc6f77af1 feat: 通过 base64 传输图片 2024-05-16 20:25:51 +08:00
RockChinQ
2c478ccc25 feat: 模型vision支持性参数 2024-05-16 20:11:54 +08:00
RockChinQ
404e5492a3 chore: 同步现有模型信息 2024-05-16 18:29:23 +08:00
RockChinQ
d5b5d667a5 feat: 模型视觉多模态支持 2024-05-15 21:40:18 +08:00
RockChinQ
8807f02f36 perf: resp_message_chain 改为 list 类型 (#770) 2024-05-14 23:08:49 +08:00
RockChinQ
269e561497 perf: messages 存回 conversation 应该仅在成功执行本次请求时执行 (#769) 2024-05-14 22:41:39 +08:00
RockChinQ
527ad81d38 feat: 解藕chat的处理器和请求器 (#772) 2024-05-14 22:20:31 +08:00
Junyan Qin
972d3c18af Update README.md 2024-05-08 21:49:45 +08:00
Junyan Qin
3cbfc078fc doc(README.md): 更新 社区四群群号 2024-05-08 21:46:19 +08:00
RockChinQ
fde6822b5c chore: release v3.1.1 2024-05-08 02:28:40 +00:00
Junyan Qin
930321bcf1 Merge pull request #762 from RockChinQ/feat/deepseek
Feat: 支持 deepseek 模型
2024-05-07 22:48:37 +08:00
RockChinQ
c45931363a feat: deepseek配置迁移 2024-05-07 14:45:59 +00:00
RockChinQ
9c6491e5ee feat: 支持 deepseek 的模型 2024-05-07 14:28:52 +00:00
RockChinQ
9bc248f5bc feat: 删除submit-messages-tokens配置项 2024-05-07 12:32:54 +00:00
Junyan Qin
becac2fde5 doc(README.md): 添加 GitHub Trending 徽标 2024-04-29 21:00:22 +08:00
RockChinQ
1e1a103882 feat: aiocqhttp允许使用图片链接作为参数 2024-04-11 03:26:12 +00:00
RockChinQ
e5cffb7c9b chore: release v3.1.0.4 2024-04-06 16:51:15 +08:00
RockChinQ
e2becf7777 feat: 删除父进程判断 (#750) 2024-04-06 16:50:35 +08:00
RockChinQ
a6b875a242 fix: GroupMessageReceived 事件参数错误 2024-04-04 16:50:45 +08:00
RockChinQ
b5e67f3df8 fix: 内容函数调用时错误地传递了RuntimeContainer 2024-04-04 15:08:40 +08:00
RockChinQ
2093fb16a7 chore: release v3.1.0.3 2024-04-02 22:33:36 +08:00
RockChinQ
fc9a9d2386 fix: 缺失的 psutil 依赖 2024-04-02 22:33:06 +08:00
RockChinQ
5e69f78f7e chore: 不再支持python 3.9 2024-04-01 18:16:49 +08:00
RockChinQ
6919bece77 chore: release v3.1.0.2 2024-03-31 14:41:32 +08:00
RockChinQ
8b003739f1 feat: message.content 支持 mirai.MessageChain 对象 (#741) 2024-03-31 14:38:15 +08:00
RockChinQ
2e9229a6ad fix: 工作目录必须在 main.py 目录 2024-03-30 21:34:22 +08:00
RockChinQ
5a3e7fe8ee perf: 禁止双击运行 2024-03-30 21:28:42 +08:00
RockChinQ
7b3d7e7bd6 fix: json配置文件错误的加载流程 2024-03-30 19:01:59 +08:00
Junyan Qin
fdd7c1864d feat(chatcmpl): 对函数调用进行异常捕获 (#749) 2024-03-30 09:45:30 +00:00
Junyan Qin
cac5a5adff fix(qq-botpy): 群内单query多回复时msg_seq重复问题 2024-03-30 02:58:37 +00:00
RockChinQ
63307633c2 feat: chatcmpl请求时也忽略空的 system prompt message (#745) 2024-03-29 17:34:09 +08:00
RockChinQ
387dfa39ff fix: 内容过滤无效 (#743) 2024-03-29 17:24:42 +08:00
Junyan Qin
1f797f899c doc(README.md): 添加使用量计数徽标 2024-03-26 15:25:08 +08:00
RockChinQ
092bb0a1e2 chore: release v3.1.0.1 2024-03-23 22:50:54 +08:00
RockChinQ
2c3399e237 perf: 敏感词迁移的双条件检查 2024-03-23 22:41:21 +08:00
RockChinQ
835275b47f fix: 多处对 launcher_type 枚举的不当比较 (#736) 2024-03-23 22:39:42 +08:00
Junyan Qin
7b060ce3f9 doc(README.md): 更新wakapi路径 2024-03-23 19:14:43 +08:00
RockChinQ
1fb69311b0 chore: release v3.1.0 2024-03-22 17:17:16 +08:00
Junyan Qin
995d1f61d2 Merge pull request #735 from RockChinQ/feat/plugin-api
Feat: 插件异步 API
2024-03-22 17:10:06 +08:00
RockChinQ
80258e9182 perf: 修改platform_mgr名称 2024-03-22 17:09:43 +08:00
RockChinQ
bd6a32e08e doc: 为可扩展组件添加注释 2024-03-22 16:41:46 +08:00
RockChinQ
5f138de75b doc: 完善query对象的注释 2024-03-22 11:05:58 +08:00
RockChinQ
d0b0f2209a fix: chat处理过程的插件返回值目标错误 2024-03-20 23:32:28 +08:00
RockChinQ
0752698c1d chore: 完善plugin对外对象的注释 2024-03-20 18:43:52 +08:00
RockChinQ
9855c6b8f5 feat: 新的引入路径 2024-03-20 15:48:11 +08:00
RockChinQ
52a7c25540 feat: 异步风格插件方法注册器 2024-03-20 15:09:47 +08:00
RockChinQ
fa823de6b0 perf: 初始化config对象时支持传递dict作为模板 2024-03-20 14:20:56 +08:00
RockChinQ
f53070d8b6 feat: 插件加载阶段前置 (#681) 2024-03-19 22:48:02 +08:00
Junyan Qin
7677672691 Merge pull request #734 from RockChinQ/feat/moonshot
Feat: 添加对 moonshot 模型的支持
2024-03-19 22:41:40 +08:00
RockChinQ
dead8fa168 feat: 添加对 moonshot 模型的支持 2024-03-19 22:39:45 +08:00
RockChinQ
c6347bea45 fix: full-scenario 命名和目录名错误问题 (#731) 2024-03-18 21:05:54 +08:00
RockChinQ
32bd194bfc chore: anthropic 的配置补全迁移 2024-03-18 21:04:09 +08:00
Junyan Qin
cca48a394d Merge pull request #732 from RockChinQ/feat/claude-3
Feat: 接入 claude 3 系列模型
2024-03-18 11:27:22 +08:00
RockChinQ
a723c8ce37 perf: claude 的接口异常处理 2024-03-17 23:22:26 -04:00
RockChinQ
327b2509f6 perf: 忽略用户空消息 2024-03-17 23:06:40 -04:00
RockChinQ
1dae7bd655 feat: 对 claude api 的基本支持 2024-03-17 12:44:45 -04:00
RockChinQ
550a131685 deps: 添加 anthropic 依赖库 2024-03-17 12:03:25 -04:00
RockChinQ
0cfb8bb29f fix: 获取模型列表时未传递version参数 2024-03-16 22:23:02 +08:00
Junyan Qin
9c32420a95 Merge pull request #730 from RockChinQ/feat/customized-model
Feat: 允许自定义模型信息
2024-03-16 22:19:27 +08:00
RockChinQ
867093cc88 chore: 更改 provider.json 格式 2024-03-16 22:12:13 +08:00
RockChinQ
82763f8ec5 chore: 删除默认prompt 2024-03-16 21:43:45 +08:00
RockChinQ
97449065df feat: 通过元数据生成模型列表 2024-03-16 21:43:09 +08:00
Junyan Qin
9489783846 Merge pull request #729 from RockChinQ/feat/migration-stage
Feat: 配置文件迁移功能
2024-03-16 20:34:29 +08:00
RockChinQ
f91c9015bc feat: 添加配置文件迁移阶段 2024-03-16 20:27:17 +08:00
RockChinQ
302d86056d refactor: 所有的 json 加载统一到启动阶段中 2024-03-16 15:41:59 +08:00
Junyan Qin
98bebfddaa Merge pull request #728 from RockChinQ/feat/active-message
Feat: aiocqhttp 和 qq-botpy 适配器的主动消息发送接口
2024-03-16 15:18:27 +08:00
RockChinQ
dab20e3187 feat: aiocqhttp和qq-botpy的主动消息发送接口 2024-03-16 15:16:46 +08:00
RockChinQ
09e72f7c5f chore: 删除注释的代码 2024-03-14 17:24:36 +08:00
Junyan Qin
2028d85f84 Merge pull request #726 from RockChinQ/feat/qq-botpy-cache
Feat: qq-botpy 适配器对 member 和 group 的 openid 进行静态缓存
2024-03-14 16:05:14 +08:00
RockChinQ
ed3c0d9014 feat: qq-botpy 适配器对 member 和 group 的 openid 进行静态缓存 2024-03-14 16:00:22 +08:00
RockChinQ
be06150990 chore: aiocqhttp添加默认access-token参数 2024-03-13 16:53:30 +08:00
Junyan Qin
afb3fb4a31 Merge pull request #725 from RockChinQ/feat/aiocqhttp-access-token
Feat: aiocqhttp支持access-token
2024-03-13 16:49:56 +08:00
RockChinQ
d66577e6c3 feat: aiocqhttp支持access-token 2024-03-13 16:49:11 +08:00
Junyan Qin
6a4ea5446a Merge pull request #724 from RockChinQ/fix/at-resp
Fix: 回复并at机器人时会多一个at组件
2024-03-13 16:31:54 +08:00
RockChinQ
74e84c744a fix: 回复并at机器人时会多一个at组件 2024-03-13 16:31:06 +08:00
Junyan Qin
5ad2446cf3 Update bug-report.yml 2024-03-13 16:13:14 +08:00
Junyan Qin
63303bb5c0 Merge pull request #712 from RockChinQ/feat/component-extensibility
Feat: 更多组件的可扩展性
2024-03-13 00:32:26 +08:00
Junyan Qin
13393b6624 feat: 限速算法的扩展性 2024-03-12 16:31:54 +00:00
Junyan Qin
b9fa11c0c3 feat: prompt 加载器的扩展性 2024-03-12 16:22:07 +00:00
RockChinQ
8c6ce1f030 feat: 群响应规则的扩展性 2024-03-12 23:34:13 +08:00
RockChinQ
1d963d0f0c feat: 不再预先计算前文token数而是在报错时提醒用户重置 2024-03-12 16:04:11 +08:00
Junyan Qin
0ee383be27 Update announcement.json 2024-03-08 22:35:17 +08:00
RockChinQ
53d09129b4 fix: 命令事件的command参数处理错误 (#713) 2024-03-08 21:10:43 +08:00
RockChinQ
a398c6f311 feat: 消息平台适配器可扩展性 2024-03-08 20:40:54 +08:00
RockChinQ
4347ddd42a feat: 长消息处理策略可扩展性 2024-03-08 20:31:22 +08:00
RockChinQ
22cb8a6a06 feat: 内容过滤器的可扩展性 2024-03-08 20:22:06 +08:00
RockChinQ
7f554fd862 feat: command支持扩展命令类 2024-03-08 19:56:57 +08:00
Junyan Qin
a82bfa8a56 perf: 为命令装饰器添加断言 2024-03-08 11:38:26 +00:00
RockChinQ
95784debbf perf: 支持识别docker环境 2024-03-07 15:55:02 +08:00
Junyan Qin
2471c5bf0f Merge pull request #709 from RockChinQ/doc/comments
Doc: 补全部分注释
2024-03-03 16:35:31 +08:00
RockChinQ
2fe6d731b8 doc: 补全部分注释 2024-03-03 16:34:59 +08:00
RockChinQ
ce881372ee chore: release v3.0.2 2024-03-02 21:03:04 +08:00
Junyan Qin
171ea7c375 Merge pull request #708 from RockChinQ/fix/llonebot-not-supported
Fix: 修复使用llonebot时的协议问题
2024-03-02 20:59:41 +08:00
RockChinQ
1e9a6f813f fix: 修复使用llonebot时的协议问题 2024-03-02 20:58:58 +08:00
Junyan Qin
39a7f3b2b9 Merge pull request #707 from RockChinQ/feat/booting-stages
Feat: 分阶段启动
2024-03-02 20:27:51 +08:00
RockChinQ
8d375a02db fix: 未导入问题 2024-03-02 20:05:23 +08:00
RockChinQ
cac8a0a414 perf: 优化导入 2024-03-02 16:39:29 +08:00
RockChinQ
c89623967e refactor: 应用初始化流程初步分阶段 2024-03-02 16:37:30 +08:00
RockChinQ
92aa9c1711 perf: 配置文件生成步骤移动到main.py 2024-03-02 14:57:55 +08:00
Junyan Qin
71f2a58acb feat: 依赖检查移动到main.py 2024-02-29 11:10:30 +00:00
RockChinQ
1f07a8a9e3 refactor: 移动pool到pipeline包 2024-02-29 03:38:38 +00:00
RockChinQ
cacd21bde7 refactor: 移动控制器到pipeline包 2024-02-29 03:38:38 +00:00
RockChinQ
a060ec66c3 deps: 整理依赖 2024-02-29 11:03:11 +08:00
Junyan Qin
fd10db3c75 ci: fix 2024-02-21 13:56:38 +00:00
Junyan Qin
db4c658980 chore: test 2024-02-21 13:52:54 +00:00
Junyan Qin
0ee88674f8 ci: update 2024-02-21 13:52:33 +00:00
Junyan Qin
3540759682 chore: release v3.0.1.1 2024-02-21 13:46:38 +00:00
Junyan Qin
44cc8f15b4 Merge pull request #695 from RockChinQ/ci/arm-image
CI: 构建arm64镜像
2024-02-21 21:45:40 +08:00
Junyan Qin
59f821bf0a ci: 构建arm64镜像 2024-02-21 13:44:07 +00:00
RockChinQ
80858672b0 perf: 控制台输出请求响应过程 2024-02-20 22:56:42 +08:00
RockChinQ
3258d5b255 chore: aiocqhttp默认监听地址改为0.0.0.0 2024-02-20 20:13:46 +08:00
RockChinQ
e8c8cc0a9c chore: release v3.0.1 2024-02-20 11:48:26 +08:00
Junyan Qin
570c19f29f Merge pull request #693 from RockChinQ/fix/3.9-compability
Fix: 针对python3.9的兼容性
2024-02-20 11:47:49 +08:00
RockChinQ
ee93fd8636 hotfix: 针对python3.9的兼容性 2024-02-20 11:47:04 +08:00
RockChinQ
1e6c32ffc7 fix: 'VersionManager' object has no attribute 'get_release_list' 2024-02-20 09:54:02 +08:00
162 changed files with 4175 additions and 1417 deletions

@@ -5,57 +5,35 @@ labels: ["bug?"]
 body:
   - type: dropdown
     attributes:
-      label: 部署方式
-      description: "主程序使用的部署方式"
-      options:
-        - 手动部署
-        - 安装器部署
-        - 一键安装包部署
-        - Docker部署
-    validations:
-      required: true
-  - type: dropdown
-    attributes:
-      label: 登录框架
+      label: 消息平台适配器
       description: "连接QQ使用的框架"
       options:
-        - Mirai
-        - go-cqhttp
+        - yiri-mirai(Mirai)
+        - Nakuru(go-cqhttp)
+        - aiocqhttp(使用 OneBot 协议接入的)
+        - qq-botpy(QQ官方API)
     validations:
       required: false
   - type: input
     attributes:
-      label: 系统环境
-      description: 操作系统、系统架构、**主机地理位置**,地理位置最好写清楚,涉及网络问题排查。
-      placeholder: 例如: CentOS x64 中国大陆、Windows11 美国
-    validations:
-      required: true
-  - type: input
-    attributes:
-      label: Python环境
-      description: 运行程序的Python版本
-      placeholder: 例如: Python 3.10
+      label: 运行环境
+      description: 操作系统、系统架构、**Python版本**、**主机地理位置**
+      placeholder: 例如: CentOS x64 Python 3.10.3、Docker 的直接写 Docker 就行
     validations:
       required: true
   - type: input
     attributes:
       label: QChatGPT版本
       description: QChatGPT版本号
-      placeholder: 例如: v2.6.0,可以使用`!version`命令查看
+      placeholder: 例如:v3.3.0,可以使用`!version`命令查看,或者到 pkg/utils/constants.py 查看
     validations:
       required: true
   - type: textarea
     attributes:
       label: 异常情况
-      description: 完整描述异常情况,什么时候发生的、发生了什么,尽可能详细
+      description: 完整描述异常情况,什么时候发生的、发生了什么。**请附带日志信息。**
     validations:
       required: true
-  - type: textarea
-    attributes:
-      label: 日志信息
-      description: 请提供完整的 **登录框架 和 QChatGPT控制台**的相关日志信息(若有),不提供日志信息**无法**为您排查问题,请尽可能详细
-    validations:
-      required: false
   - type: textarea
     attributes:
       label: 启用的插件

@@ -2,24 +2,16 @@
 实现/解决/优化的内容:
 
-### 事务
-- [ ] 已阅读仓库[贡献指引](https://github.com/RockChinQ/QChatGPT/blob/master/CONTRIBUTING.md)
-- [ ] 已与维护者在issues或其他平台沟通此PR大致内容
-
-## 以下内容可在起草PR后、合并PR前逐步完成
-
-### 功能
-- [ ] 已编写完善的配置文件字段说明(若有新增)
-- [ ] 已编写面向用户的新功能说明(若有必要)
-- [ ] 已测试新功能或更改
-
-### 兼容性
-- [ ] 已处理版本兼容性
-- [ ] 已处理插件兼容问题
-
-### 风险
-可能导致或已知的问题:
+## 检查清单
+
+### PR 作者完成
+- [ ] 阅读仓库[贡献指引](https://github.com/RockChinQ/QChatGPT/blob/master/CONTRIBUTING.md)了吗?
+- [ ] 与项目所有者沟通过了吗?
+
+### 项目所有者完成
+- [ ] 相关 issues 链接了吗?
+- [ ] 配置项写好了吗?迁移写好了吗?生效了吗?
+- [ ] 依赖写到 requirements.txt 和 core/bootutils/deps.py 了吗
+- [ ] 文档编写了吗?

@@ -17,22 +17,32 @@ jobs:
       run: |
         if [ -z "$GITHUB_REF" ]; then
           export GITHUB_REF=${{ github.ref }}
+          echo $GITHUB_REF
         fi
-    - name: Check GITHUB_REF env
-      run: echo $GITHUB_REF
-    - name: Get version
-      id: get_version
-      if: (startsWith(env.GITHUB_REF, 'refs/tags/')||startsWith(github.ref, 'refs/tags/')) && startsWith(github.repository, 'RockChinQ/QChatGPT')
-      run: echo ::set-output name=VERSION::${GITHUB_REF/refs\/tags\//}
-    - name: Build # image name: rockchin/qchatgpt:<VERSION>
-      run: docker build --network=host -t rockchin/qchatgpt:${{ steps.get_version.outputs.VERSION }} -t rockchin/qchatgpt:latest .
+    # - name: Check GITHUB_REF env
+    #   run: echo $GITHUB_REF
+    # - name: Get version # 在 GitHub Actions 运行环境
+    #   id: get_version
+    #   if: (startsWith(env.GITHUB_REF, 'refs/tags/')||startsWith(github.ref, 'refs/tags/')) && startsWith(github.repository, 'RockChinQ/QChatGPT')
+    #   run: export GITHUB_REF=${GITHUB_REF/refs\/tags\//}
+    - name: Check version
+      id: check_version
+      run: |
+        echo $GITHUB_REF
+        # 如果是tag则去掉refs/tags/前缀
+        if [[ $GITHUB_REF == refs/tags/* ]]; then
+          echo "It's a tag"
+          echo $GITHUB_REF
+          echo $GITHUB_REF | awk -F '/' '{print $3}'
+          echo ::set-output name=version::$(echo $GITHUB_REF | awk -F '/' '{print $3}')
+        else
+          echo "It's not a tag"
+          echo $GITHUB_REF
+          echo ::set-output name=version::${GITHUB_REF}
+        fi
     - name: Login to Registry
      run: docker login --username=${{ secrets.DOCKER_USERNAME }} --password ${{ secrets.DOCKER_PASSWORD }}
-    - name: Push image
-      if: (startsWith(env.GITHUB_REF, 'refs/tags/')||startsWith(github.ref, 'refs/tags/')) && startsWith(github.repository, 'RockChinQ/QChatGPT')
-      run: docker push rockchin/qchatgpt:${{ steps.get_version.outputs.VERSION }}
-    - name: Push latest image
-      if: (startsWith(env.GITHUB_REF, 'refs/tags/')||startsWith(github.ref, 'refs/tags/')) && startsWith(github.repository, 'RockChinQ/QChatGPT')
-      run: docker push rockchin/qchatgpt:latest
+    - name: Create Buildx
+      run: docker buildx create --name mybuilder --use
+    - name: Build # image name: rockchin/qchatgpt:<VERSION>
+      run: docker buildx build --platform linux/arm64,linux/amd64 -t rockchin/qchatgpt:${{ steps.check_version.outputs.version }} -t rockchin/qchatgpt:latest . --push

.gitignore

@@ -34,4 +34,4 @@ bard.json
 res/instance_id.json
 .DS_Store
 /data
-botpy.log
+botpy.log*

@@ -3,6 +3,9 @@ WORKDIR /app
 COPY . .
 
-RUN python -m pip install -r requirements.txt
+RUN apt update \
+    && apt install gcc -y \
+    && python -m pip install -r requirements.txt \
+    && touch /.dockerenv
 
 CMD [ "python", "main.py" ]


@@ -1,34 +1,30 @@
<p align="center"> <p align="center">
<img src="https://qchatgpt.rockchin.top/logo.png" alt="QChatGPT" width="180" /> <img src="https://qchatgpt.rockchin.top/chrome-512.png" alt="QChatGPT" width="180" />
</p> </p>
<div align="center"> <div align="center">
# QChatGPT # QChatGPT
<a href="https://trendshift.io/repositories/6187" target="_blank"><img src="https://trendshift.io/api/badge/repositories/6187" alt="RockChinQ%2FQChatGPT | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a>
[![GitHub release (latest by date)](https://img.shields.io/github/v/release/RockChinQ/QChatGPT)](https://github.com/RockChinQ/QChatGPT/releases/latest) [![GitHub release (latest by date)](https://img.shields.io/github/v/release/RockChinQ/QChatGPT)](https://github.com/RockChinQ/QChatGPT/releases/latest)
<a href="https://hub.docker.com/repository/docker/rockchin/qchatgpt"> <a href="https://hub.docker.com/repository/docker/rockchin/qchatgpt">
<img src="https://img.shields.io/docker/pulls/rockchin/qchatgpt?color=blue" alt="docker pull"> <img src="https://img.shields.io/docker/pulls/rockchin/qchatgpt?color=blue" alt="docker pull">
</a> </a>
![Wakapi Count](https://wakapi.dev/api/badge/RockChinQ/interval:any/project:QChatGPT) ![Dynamic JSON Badge](https://img.shields.io/badge/dynamic/json?url=https%3A%2F%2Fapi.qchatgpt.rockchin.top%2Fapi%2Fv2%2Fview%2Frealtime%2Fcount_query%3Fminute%3D10080&query=%24.data.count&label=%E4%BD%BF%E7%94%A8%E9%87%8F%EF%BC%887%E6%97%A5%EF%BC%89)
<a href="https://codecov.io/gh/RockChinQ/QChatGPT" > ![Wakapi Count](https://wakapi.rockchin.top/api/badge/RockChinQ/interval:any/project:QChatGPT)
<img src="https://codecov.io/gh/RockChinQ/QChatGPT/graph/badge.svg?token=pjxYIL2kbC"/>
</a>
<br/> <br/>
<img src="https://img.shields.io/badge/python-3.9 | 3.10 | 3.11-blue.svg" alt="python"> <img src="https://img.shields.io/badge/python-3.10 | 3.11 | 3.12-blue.svg" alt="python">
<a href="http://qm.qq.com/cgi-bin/qm/qr?_wv=1027&k=66-aWvn8cbP4c1ut_1YYkvvGVeEtyTH8&authKey=pTaKBK5C%2B8dFzQ4XlENf6MHTCLaHnlKcCRx7c14EeVVlpX2nRSaS8lJm8YeM4mCU&noverify=0&group_code=195992197"> <a href="http://qm.qq.com/cgi-bin/qm/qr?_wv=1027&k=66-aWvn8cbP4c1ut_1YYkvvGVeEtyTH8&authKey=pTaKBK5C%2B8dFzQ4XlENf6MHTCLaHnlKcCRx7c14EeVVlpX2nRSaS8lJm8YeM4mCU&noverify=0&group_code=195992197">
<img alt="Static Badge" src="https://img.shields.io/badge/%E5%AE%98%E6%96%B9%E7%BE%A4-195992197-purple"> <img alt="Static Badge" src="https://img.shields.io/badge/%E5%AE%98%E6%96%B9%E7%BE%A4-195992197-purple">
</a> </a>
<a href="http://qm.qq.com/cgi-bin/qm/qr?_wv=1027&k=nC80H57wmKPwRDLFeQrDDjVl81XuC21P&authKey=2wTUTfoQ5v%2BD4C5zfpuR%2BSPMDqdXgDXA%2FS2wHI1NxTfWIG%2B%2FqK08dgyjMMOzhXa9&noverify=0&group_code=248432104"> <a href="https://qm.qq.com/q/PClALFK242">
<img alt="Static Badge" src="https://img.shields.io/badge/%E7%A4%BE%E5%8C%BA%E7%BE%A4-248432104-purple"> <img alt="Static Badge" src="https://img.shields.io/badge/%E7%A4%BE%E5%8C%BA%E7%BE%A4-619154800-purple">
</a>
<a href="https://www.bilibili.com/video/BV14h4y1w7TC">
<img alt="Static Badge" src="https://img.shields.io/badge/%E8%A7%86%E9%A2%91%E6%95%99%E7%A8%8B-208647">
</a>
<a href="https://www.bilibili.com/video/BV11h4y1y74H">
<img alt="Static Badge" src="https://img.shields.io/badge/Linux%E9%83%A8%E7%BD%B2%E8%A7%86%E9%A2%91-208647">
</a> </a>
<a href="https://codecov.io/gh/RockChinQ/QChatGPT" >
<img src="https://codecov.io/gh/RockChinQ/QChatGPT/graph/badge.svg?token=pjxYIL2kbC"/>
</a>
## 使用文档 ## 使用文档
@@ -43,7 +39,17 @@
<a href="https://github.com/RockChinQ/qcg-installer">安装器源码</a> <a href="https://github.com/RockChinQ/qcg-installer">安装器源码</a>
<a href="https://github.com/RockChinQ/qcg-tester">测试工程源码</a> <a href="https://github.com/RockChinQ/qcg-tester">测试工程源码</a>
<a href="https://github.com/RockChinQ/qcg-center">遥测服务端源码</a>
<a href="https://github.com/the-lazy-me/QChatGPT-Wiki">官方文档储存库</a> <a href="https://github.com/the-lazy-me/QChatGPT-Wiki">官方文档储存库</a>
<img alt="回复效果(带有联网插件)" src="https://qchatgpt.rockchin.top/assets/image/QChatGPT-1211.png" width="500px"/> <hr/>
<div align="center">
京东云 4090 单卡 15C90G 实例 <br/>
仅需1.89/小时包月1225元起 <br/>
可选预装Stable Diffusion等应用随用随停计费透明欢迎首选支持 <br/>
https://3.cn/24A-2NXd
</div>
<img alt="回复效果(带有联网插件)" src="https://qchatgpt.rockchin.top/QChatGPT-0516.png" width="500px"/>
</div> </div>

main.py

@@ -1,5 +1,6 @@
-import asyncio
+# QChatGPT 终端启动入口
+# 在此层级解决依赖项检查。
+# QChatGPT/main.py
 
 asciiart = r"""
  ___ ___ _ _ ___ ___ _____
@@ -11,8 +12,61 @@ asciiart = r"""
 📖文档地址: https://q.rkcn.top
 """
 
-if __name__ == '__main__':
+async def main_entry():
     print(asciiart)
 
+    import sys
+
+    # 检查依赖
+    from pkg.core.bootutils import deps
+
+    missing_deps = await deps.check_deps()
+
+    if missing_deps:
+        print("以下依赖包未安装,将自动安装,请完成后重启程序:")
+        for dep in missing_deps:
+            print("-", dep)
+        await deps.install_deps(missing_deps)
+        print("已自动安装缺失的依赖包,请重启程序。")
+        sys.exit(0)
+
+    # 检查配置文件
+    from pkg.core.bootutils import files
+
+    generated_files = await files.generate_files()
+
+    if generated_files:
+        print("以下文件不存在,已自动生成,请按需修改配置文件后重启:")
+        for file in generated_files:
+            print("-", file)
+        sys.exit(0)
+
     from pkg.core import boot
-    asyncio.run(boot.main())
+    await boot.main()
+
+if __name__ == '__main__':
+    import os
+
+    # 检查本目录是否有main.py且包含QChatGPT字符串
+    invalid_pwd = False
+
+    if not os.path.exists('main.py'):
+        invalid_pwd = True
+    else:
+        with open('main.py', 'r', encoding='utf-8') as f:
+            content = f.read()
+            if "QChatGPT/main.py" not in content:
+                invalid_pwd = True
+    if invalid_pwd:
+        print("请在QChatGPT项目根目录下以命令形式运行此程序。")
+        input("按任意键退出...")
+        exit(0)
+
+    import asyncio
+
+    asyncio.run(main_entry())

@@ -34,6 +34,9 @@ class APIGroup(metaclass=abc.ABCMeta):
         headers: dict = {},
         **kwargs
     ):
+        """
+        执行请求
+        """
         self._runtime_info['account_id'] = "-1"
 
         url = self.prefix + path

@@ -9,8 +9,6 @@ from .groups import plugin
 from ...core import app
 
-BACKEND_URL = "https://api.qchatgpt.rockchin.top/api/v2"
-
 class V2CenterAPI:
     """中央服务器 v2 API 交互类"""
@@ -23,7 +21,7 @@ class V2CenterAPI:
     plugin: plugin.V2PluginDataAPI = None
     """插件 API 组"""
 
-    def __init__(self, ap: app.Application, basic_info: dict = None, runtime_info: dict = None):
+    def __init__(self, ap: app.Application, backend_url: str, basic_info: dict = None, runtime_info: dict = None):
         """初始化"""
 
         logging.debug("basic_info: %s, runtime_info: %s", basic_info, runtime_info)
@@ -31,7 +29,7 @@ class V2CenterAPI:
         apigroup.APIGroup._basic_info = basic_info
         apigroup.APIGroup._runtime_info = runtime_info
 
-        self.main = main.V2MainDataAPI(BACKEND_URL, ap)
-        self.usage = usage.V2UsageDataAPI(BACKEND_URL, ap)
-        self.plugin = plugin.V2PluginDataAPI(BACKEND_URL, ap)
+        self.main = main.V2MainDataAPI(backend_url, ap)
+        self.usage = usage.V2UsageDataAPI(backend_url, ap)
+        self.plugin = plugin.V2PluginDataAPI(backend_url, ap)

@@ -1,3 +1,5 @@
+# 实例 识别码 控制
+
 import os
 import uuid
 import json

@@ -7,7 +7,8 @@ from ..provider import entities as llm_entities
 from . import entities, operator, errors
 from ..config import manager as cfg_mgr
-from .operators import func, plugin, default, reset, list as list_cmd, last, next, delc, resend, prompt, cmd, help, version, update
+# 引入所有算子以便注册
+from .operators import func, plugin, default, reset, list as list_cmd, last, next, delc, resend, prompt, cmd, help, version, update, ollama, model
 
 class CommandManager:
@@ -17,6 +18,9 @@ class CommandManager:
     ap: app.Application
 
     cmd_list: list[operator.CommandOperator]
+    """
+    运行时命令列表,扁平存储,各个对象包含对应的子节点引用
+    """
 
     def __init__(self, ap: app.Application):
         self.ap = ap
@@ -60,7 +64,7 @@ class CommandManager:
         """
         found = False
-        if len(context.crt_params) > 0:
+        if len(context.crt_params) > 0: # 查找下一个参数是否对应此节点的某个子节点名
             for oper in operator_list:
                 if (context.crt_params[0] == oper.name \
                     or context.crt_params[0] in oper.alias) \
@@ -78,7 +82,7 @@ class CommandManager:
                     yield ret
                 break
-        if not found:
+        if not found: # 如果下一个参数未在此节点的子节点中找到,则执行此节点或者报错
             if operator is None:
                 yield entities.CommandReturn(
                     error=errors.CommandNotFoundError(context.crt_params[0])

@@ -10,33 +10,67 @@ from . import errors, operator
 class CommandReturn(pydantic.BaseModel):
+    """命令返回值
+    """
 
-    text: typing.Optional[str]
+    text: typing.Optional[str] = None
     """文本
     """
 
-    image: typing.Optional[mirai.Image]
+    image: typing.Optional[mirai.Image] = None
+    """弃用"""
+
+    image_url: typing.Optional[str] = None
+    """图片链接
+    """
 
     error: typing.Optional[errors.CommandError]= None
+    """错误
+    """
 
     class Config:
         arbitrary_types_allowed = True
 
 class ExecuteContext(pydantic.BaseModel):
+    """单次命令执行上下文
+    """
 
     query: core_entities.Query
+    """本次消息的请求对象"""
 
     session: core_entities.Session
+    """本次消息所属的会话对象"""
 
     command_text: str
+    """命令完整文本"""
 
     command: str
+    """命令名称"""
 
     crt_command: str
+    """当前命令
+    多级命令中,crt_command为当前命令,command为根命令。
+    例如:!plugin on Webwlkr
+    处理到plugin时,command为plugin,crt_command为plugin
+    处理到on时,command为plugin,crt_command为on
+    """
 
     params: list[str]
+    """命令参数
+    整个命令以空格分割后的参数列表
+    """
 
     crt_params: list[str]
+    """当前命令参数
+    多级命令中,crt_params为当前命令参数,params为根命令参数。
+    例如:!plugin on Webwlkr
+    处理到plugin时,params为['on', 'Webwlkr'],crt_params为['on', 'Webwlkr']
+    处理到on时,params为['on', 'Webwlkr'],crt_params为['Webwlkr']
+    """
 
     privilege: int
+    """发起人权限"""

@@ -8,17 +8,34 @@ from . import entities
 preregistered_operators: list[typing.Type[CommandOperator]] = []
+"""预注册命令算子列表。在初始化时,所有算子类会被注册到此列表中。"""
 
 def operator_class(
     name: str,
-    help: str,
+    help: str = "",
     usage: str = None,
     alias: list[str] = [],
     privilege: int=1, # 1为普通用户,2为管理员
     parent_class: typing.Type[CommandOperator] = None
 ) -> typing.Callable[[typing.Type[CommandOperator]], typing.Type[CommandOperator]]:
+    """命令类装饰器
+
+    Args:
+        name (str): 名称
+        help (str, optional): 帮助信息. Defaults to "".
+        usage (str, optional): 使用说明. Defaults to None.
+        alias (list[str], optional): 别名. Defaults to [].
+        privilege (int, optional): 权限,1为普通用户可用,2为仅管理员可用. Defaults to 1.
+        parent_class (typing.Type[CommandOperator], optional): 父节点,若为None则为顶级命令. Defaults to None.
+
+    Returns:
+        typing.Callable[[typing.Type[CommandOperator]], typing.Type[CommandOperator]]: 装饰器
+    """
 
     def decorator(cls: typing.Type[CommandOperator]) -> typing.Type[CommandOperator]:
+        assert issubclass(cls, CommandOperator)
+
         cls.name = name
         cls.alias = alias
         cls.help = help
@@ -34,7 +51,12 @@ def operator_class(
 class CommandOperator(metaclass=abc.ABCMeta):
-    """命令算子
+    """命令算子抽象类
+
+    以下的参数均不需要在子类中设置,只需要在使用装饰器注册类时作为参数传递即可。
+    命令支持级联,即一个命令可以有多个子命令,子命令可以有子命令,以此类推。
+    处理命令时,若有子命令,会以当前参数列表的第一个参数去匹配子命令,若匹配成功,则转移到子命令中执行。
+    若没有匹配成功或没有子命令,则执行当前命令。
     """
 
     ap: app.Application
@@ -43,7 +65,8 @@ class CommandOperator(metaclass=abc.ABCMeta):
     """名称,搜索到时若符合则使用"""
 
     path: str
-    """路径,所有父节点的name的连接,用于定义命令权限"""
+    """路径,所有父节点的name的连接,用于定义命令权限,由管理器在初始化时自动设置。
+    """
 
     alias: list[str]
     """同name"""
@@ -52,8 +75,9 @@ class CommandOperator(metaclass=abc.ABCMeta):
     """此节点的帮助信息"""
 
     usage: str = None
+    """用法"""
 
-    parent_class: typing.Type[CommandOperator] | None = None
+    parent_class: typing.Union[typing.Type[CommandOperator], None] = None
     """父节点类。标记以供管理器在初始化时编织父子关系。"""
 
     lowest_privilege: int = 0
@@ -75,4 +99,15 @@ class CommandOperator(metaclass=abc.ABCMeta):
         self,
         context: entities.ExecuteContext
     ) -> typing.AsyncGenerator[entities.CommandReturn, None]:
+        """实现此方法以执行命令
+
+        支持多次yield以返回多个结果。
+        例如:一个安装插件的命令,可能会有下载、解压、安装等多个步骤,每个步骤都可以返回一个结果。
+
+        Args:
+            context (entities.ExecuteContext): 命令执行上下文
+
+        Yields:
+            entities.CommandReturn: 命令返回封装
+        """
         pass
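
The decorator and async-generator contract documented above are the whole surface a new command needs. Below is a minimal illustrative sketch (not part of this changeset) of a cascaded command built on that API; the command name "demo", its reply texts, and the assumption that the file sits next to the other operators under pkg/command/operators/ are all invented for illustration.

from __future__ import annotations

import typing

from .. import operator, entities, errors


@operator.operator_class(
    name="demo",                 # root command: !demo (illustrative name)
    help="演示命令",
    usage="!demo\n!demo echo <文本>",
    privilege=1,                 # 1 = available to normal users
)
class DemoOperator(operator.CommandOperator):
    async def execute(
        self, context: entities.ExecuteContext
    ) -> typing.AsyncGenerator[entities.CommandReturn, None]:
        # runs when no sub-command matches the first parameter
        yield entities.CommandReturn(text="可用子命令: echo")


@operator.operator_class(
    name="echo",
    help="原样返回参数",
    parent_class=DemoOperator,   # registers "echo" as a child of !demo
)
class DemoEchoOperator(operator.CommandOperator):
    async def execute(
        self, context: entities.ExecuteContext
    ) -> typing.AsyncGenerator[entities.CommandReturn, None]:
        if len(context.crt_params) == 0:
            yield entities.CommandReturn(error=errors.ParamNotEnoughError('请提供文本'))
        else:
            # multiple yields are allowed; each becomes a separate reply
            yield entities.CommandReturn(text=" ".join(context.crt_params))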

@@ -24,7 +24,7 @@ class DefaultOperator(operator.CommandOperator):
             content = ""
             for msg in prompt.messages:
-                content += f" {msg.role}: {msg.content}"
+                content += f" {msg.readable_str()}\n"
 
             reply_str += f"名称: {prompt.name}\n内容: \n{content}\n\n"
@@ -45,18 +45,18 @@ class DefaultSetOperator(operator.CommandOperator):
         context: entities.ExecuteContext
     ) -> typing.AsyncGenerator[entities.CommandReturn, None]:
         if len(context.crt_params) == 0:
             yield entities.CommandReturn(error=errors.ParamNotEnoughError('请提供情景预设名称'))
         else:
             prompt_name = context.crt_params[0]
             try:
                 prompt = await self.ap.prompt_mgr.get_prompt_by_prefix(prompt_name)
                 if prompt is None:
                     yield entities.CommandReturn(error=errors.CommandError("设置当前会话默认情景预设失败: 未找到情景预设 {}".format(prompt_name)))
                 else:
                     context.session.use_prompt_name = prompt.name
                     yield entities.CommandReturn(text=f"已设置当前会话默认情景预设为 {prompt_name}, !reset 后生效")
             except Exception as e:
                 traceback.print_exc()
                 yield entities.CommandReturn(error=errors.CommandError("设置当前会话默认情景预设失败: "+str(e)))

@@ -30,7 +30,7 @@ class LastOperator(operator.CommandOperator):
                 context.session.using_conversation = context.session.conversations[index-1]
                 time_str = context.session.using_conversation.create_time.strftime("%Y-%m-%d %H:%M:%S")
 
-                yield entities.CommandReturn(text=f"已切换到上一个对话: {index} {time_str}: {context.session.using_conversation.messages[0].content}")
+                yield entities.CommandReturn(text=f"已切换到上一个对话: {index} {time_str}: {context.session.using_conversation.messages[0].readable_str()}")
                 return
         else:
             yield entities.CommandReturn(error=errors.CommandOperationError('当前没有对话'))

@@ -42,7 +42,7 @@ class ListOperator(operator.CommandOperator):
                 using_conv_index = index
             if index >= page * record_per_page and index < (page + 1) * record_per_page:
-                content += f"{index} {time_str}: {conv.messages[0].content if len(conv.messages) > 0 else '无内容'}\n"
+                content += f"{index} {time_str}: {conv.messages[0].readable_str() if len(conv.messages) > 0 else '无内容'}\n"
             index += 1
 
         if content == '':
@@ -51,6 +51,6 @@ class ListOperator(operator.CommandOperator):
         if context.session.using_conversation is None:
             content += "\n当前处于新会话"
         else:
-            content += f"\n当前会话: {using_conv_index} {context.session.using_conversation.create_time.strftime('%Y-%m-%d %H:%M:%S')}: {context.session.using_conversation.messages[0].content if len(context.session.using_conversation.messages) > 0 else '无内容'}"
+            content += f"\n当前会话: {using_conv_index} {context.session.using_conversation.create_time.strftime('%Y-%m-%d %H:%M:%S')}: {context.session.using_conversation.messages[0].readable_str() if len(context.session.using_conversation.messages) > 0 else '无内容'}"
 
         yield entities.CommandReturn(text=f"{page + 1} 页 (时间倒序):\n{content}")


@@ -0,0 +1,86 @@
from __future__ import annotations
import typing
from .. import operator, entities, cmdmgr, errors
@operator.operator_class(
name="model",
help='显示和切换模型列表',
usage='!model\n!model show <模型名>\n!model set <模型名>',
privilege=2
)
class ModelOperator(operator.CommandOperator):
"""Model命令"""
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]:
content = '模型列表:\n'
model_list = self.ap.model_mgr.model_list
for model in model_list:
content += f"\n名称: {model.name}\n"
content += f"请求器: {model.requester.name}\n"
content += f"\n当前对话使用模型: {context.query.use_model.name}\n"
content += f"新对话默认使用模型: {self.ap.provider_cfg.data.get('model')}\n"
yield entities.CommandReturn(text=content.strip())
@operator.operator_class(
name="show",
help='显示模型详情',
privilege=2,
parent_class=ModelOperator
)
class ModelShowOperator(operator.CommandOperator):
"""Model Show命令"""
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]:
model_name = context.crt_params[0]
model = None
for _model in self.ap.model_mgr.model_list:
if model_name == _model.name:
model = _model
break
if model is None:
yield entities.CommandReturn(error=errors.CommandError(f"未找到模型 {model_name}"))
else:
content = f"模型详情\n"
content += f"名称: {model.name}\n"
if model.model_name is not None:
content += f"请求模型名称: {model.model_name}\n"
content += f"请求器: {model.requester.name}\n"
content += f"密钥组: {model.token_mgr.provider}\n"
content += f"支持视觉: {model.vision_supported}\n"
content += f"支持工具: {model.tool_call_supported}\n"
yield entities.CommandReturn(text=content.strip())
@operator.operator_class(
name="set",
help='设置默认使用模型',
privilege=2,
parent_class=ModelOperator
)
class ModelSetOperator(operator.CommandOperator):
"""Model Set命令"""
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]:
model_name = context.crt_params[0]
model = None
for _model in self.ap.model_mgr.model_list:
if model_name == _model.name:
model = _model
break
if model is None:
yield entities.CommandReturn(error=errors.CommandError(f"未找到模型 {model_name}"))
else:
self.ap.provider_cfg.data['model'] = model_name
await self.ap.provider_cfg.dump_config()
yield entities.CommandReturn(text=f"已设置当前使用模型为 {model_name},重置会话以生效")


@@ -0,0 +1,121 @@
from __future__ import annotations
import json
import typing
import traceback
import ollama
from .. import operator, entities, errors
@operator.operator_class(
name="ollama",
help="ollama平台操作",
usage="!ollama\n!ollama show <模型名>\n!ollama pull <模型名>\n!ollama del <模型名>"
)
class OllamaOperator(operator.CommandOperator):
async def execute(
self, context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
try:
content: str = '模型列表:\n'
model_list: list = ollama.list().get('models', [])
for model in model_list:
content += f"名称: {model['name']}\n"
content += f"修改时间: {model['modified_at']}\n"
content += f"大小: {bytes_to_mb(model['size'])}MB\n\n"
yield entities.CommandReturn(text=f"{content.strip()}")
except ollama.ResponseError as e:
yield entities.CommandReturn(error=errors.CommandError(f"无法获取模型列表,请确认 Ollama 服务正常"))
def bytes_to_mb(num_bytes):
mb: float = num_bytes / 1024 / 1024
return format(mb, '.2f')
@operator.operator_class(
name="show",
help="ollama模型详情",
privilege=2,
parent_class=OllamaOperator
)
class OllamaShowOperator(operator.CommandOperator):
async def execute(
self, context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
content: str = '模型详情:\n'
try:
show: dict = ollama.show(model=context.crt_params[0])
model_info: dict = show.get('model_info', {})
ignore_show: str = 'too long to show...'
for key in ['license', 'modelfile']:
show[key] = ignore_show
for key in ['tokenizer.chat_template.rag', 'tokenizer.chat_template.tool_use']:
model_info[key] = ignore_show
content += json.dumps(show, indent=4)
yield entities.CommandReturn(text=content.strip())
except ollama.ResponseError as e:
yield entities.CommandReturn(error=errors.CommandError(f"无法获取模型详情,请确认 Ollama 服务正常"))
@operator.operator_class(
name="pull",
help="ollama模型拉取",
privilege=2,
parent_class=OllamaOperator
)
class OllamaPullOperator(operator.CommandOperator):
async def execute(
self, context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
try:
model_list: list = ollama.list().get('models', [])
if context.crt_params[0] in [model['name'] for model in model_list]:
yield entities.CommandReturn(text="模型已存在")
return
except ollama.ResponseError as e:
yield entities.CommandReturn(error=errors.CommandError(f"无法获取模型列表,请确认 Ollama 服务正常"))
return
on_progress: bool = False
progress_count: int = 0
try:
for resp in ollama.pull(model=context.crt_params[0], stream=True):
total: typing.Any = resp.get('total')
if not on_progress:
if total is not None:
on_progress = True
yield entities.CommandReturn(text=resp.get('status'))
else:
if total is None:
on_progress = False
completed: typing.Any = resp.get('completed')
if isinstance(completed, int) and isinstance(total, int):
percentage_completed = (completed / total) * 100
if percentage_completed > progress_count:
progress_count += 10
yield entities.CommandReturn(
text=f"下载进度: {completed}/{total} ({percentage_completed:.2f}%)")
except ollama.ResponseError as e:
yield entities.CommandReturn(text=f"拉取失败: {e.error}")
@operator.operator_class(
name="del",
help="ollama模型删除",
privilege=2,
parent_class=OllamaOperator
)
class OllamaDelOperator(operator.CommandOperator):
async def execute(
self, context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
try:
ret: str = ollama.delete(model=context.crt_params[0])['status']
except ollama.ResponseError as e:
ret = f"{e.error}"
yield entities.CommandReturn(text=ret)
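
The operators above drive the ollama Python client directly. As a point of reference, this standalone sketch exercises the same list/pull calls outside the bot; it assumes the ollama package from the updated requirements is installed and a local Ollama server is running, and the model name is only an example.

import ollama  # same client library the operators above use


def list_models() -> None:
    # counterpart of "!ollama": enumerate locally available models
    for model in ollama.list().get('models', []):
        print(model['name'], f"{model['size'] / 1024 / 1024:.2f}MB")


def pull_model(name: str = "qwen2:0.5b") -> None:  # example model name
    # counterpart of "!ollama pull <模型名>": stream download progress
    for resp in ollama.pull(model=name, stream=True):
        total, completed = resp.get('total'), resp.get('completed')
        if isinstance(total, int) and isinstance(completed, int):
            print(f"{completed}/{total} ({completed / total * 100:.2f}%)")
        else:
            print(resp.get('status'))


if __name__ == '__main__':
    list_models()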

@@ -20,7 +20,7 @@ class VersionCommand(operator.CommandOperator):
         try:
             if await self.ap.ver_mgr.is_new_version_available():
-                reply_str += "\n\n有新版本可用, 使用 !update 更新"
+                reply_str += "\n\n有新版本可用"
         except:
             pass

@@ -8,40 +8,52 @@ from .. import model as file_model
 class JSONConfigFile(file_model.ConfigFile):
     """JSON配置文件"""
 
-    config_file_name: str = None
-    """配置文件名"""
-
-    template_file_name: str = None
-    """模板文件名"""
-
-    def __init__(self, config_file_name: str, template_file_name: str) -> None:
+    def __init__(
+        self, config_file_name: str, template_file_name: str = None, template_data: dict = None
+    ) -> None:
         self.config_file_name = config_file_name
         self.template_file_name = template_file_name
+        self.template_data = template_data
 
     def exists(self) -> bool:
         return os.path.exists(self.config_file_name)
 
     async def create(self):
-        shutil.copyfile(self.template_file_name, self.config_file_name)
+        if self.template_file_name is not None:
+            shutil.copyfile(self.template_file_name, self.config_file_name)
+        elif self.template_data is not None:
+            with open(self.config_file_name, "w", encoding="utf-8") as f:
+                json.dump(self.template_data, f, indent=4, ensure_ascii=False)
+        else:
+            raise ValueError("template_file_name or template_data must be provided")
 
-    async def load(self) -> dict:
+    async def load(self, completion: bool=True) -> dict:
         if not self.exists():
             await self.create()
 
-        with open(self.config_file_name, 'r', encoding='utf-8') as f:
-            cfg = json.load(f)
-
-        # 从模板文件中进行补全
-        with open(self.template_file_name, 'r', encoding='utf-8') as f:
-            template_cfg = json.load(f)
-
-        for key in template_cfg:
-            if key not in cfg:
-                cfg[key] = template_cfg[key]
+        if self.template_file_name is not None:
+            with open(self.template_file_name, "r", encoding="utf-8") as f:
+                self.template_data = json.load(f)
+
+        with open(self.config_file_name, "r", encoding="utf-8") as f:
+            try:
+                cfg = json.load(f)
+            except json.JSONDecodeError as e:
+                raise Exception(f"配置文件 {self.config_file_name} 语法错误: {e}")
+
+        if completion:
+            for key in self.template_data:
+                if key not in cfg:
+                    cfg[key] = self.template_data[key]
 
         return cfg
 
     async def save(self, cfg: dict):
-        with open(self.config_file_name, 'w', encoding='utf-8') as f:
+        with open(self.config_file_name, "w", encoding="utf-8") as f:
             json.dump(cfg, f, indent=4, ensure_ascii=False)
+
+    def save_sync(self, cfg: dict):
+        with open(self.config_file_name, "w", encoding="utf-8") as f:
+            json.dump(cfg, f, indent=4, ensure_ascii=False)

@@ -25,7 +25,7 @@ class PythonModuleConfigFile(file_model.ConfigFile):
     async def create(self):
         shutil.copyfile(self.template_file_name, self.config_file_name)
 
-    async def load(self) -> dict:
+    async def load(self, completion: bool=True) -> dict:
         module_name = os.path.splitext(os.path.basename(self.config_file_name))[0]
 
         module = importlib.import_module(module_name)
@@ -43,20 +43,24 @@ class PythonModuleConfigFile(file_model.ConfigFile):
             cfg[key] = getattr(module, key)
 
         # 从模板模块文件中进行补全
-        module_name = os.path.splitext(os.path.basename(self.template_file_name))[0]
-        module = importlib.import_module(module_name)
+        if completion:
+            module_name = os.path.splitext(os.path.basename(self.template_file_name))[0]
+            module = importlib.import_module(module_name)
 
-        for key in dir(module):
-            if key.startswith('__'):
-                continue
-            if not isinstance(getattr(module, key), allowed_types):
-                continue
-            if key not in cfg:
-                cfg[key] = getattr(module, key)
+            for key in dir(module):
+                if key.startswith('__'):
+                    continue
+                if not isinstance(getattr(module, key), allowed_types):
+                    continue
+                if key not in cfg:
+                    cfg[key] = getattr(module, key)
 
         return cfg
 
     async def save(self, data: dict):
         logging.warning('Python模块配置文件不支持保存')
+
+    def save_sync(self, data: dict):
+        logging.warning('Python模块配置文件不支持保存')

pkg/config/impls/yaml.py (new file)

@@ -0,0 +1,59 @@
import os
import shutil

import yaml

from .. import model as file_model


class YAMLConfigFile(file_model.ConfigFile):
    """YAML配置文件"""

    def __init__(
        self, config_file_name: str, template_file_name: str = None, template_data: dict = None
    ) -> None:
        self.config_file_name = config_file_name
        self.template_file_name = template_file_name
        self.template_data = template_data

    def exists(self) -> bool:
        return os.path.exists(self.config_file_name)

    async def create(self):
        if self.template_file_name is not None:
            shutil.copyfile(self.template_file_name, self.config_file_name)
        elif self.template_data is not None:
            with open(self.config_file_name, "w", encoding="utf-8") as f:
                yaml.dump(self.template_data, f, indent=4, allow_unicode=True)
        else:
            raise ValueError("template_file_name or template_data must be provided")

    async def load(self, completion: bool=True) -> dict:
        if not self.exists():
            await self.create()

        if self.template_file_name is not None:
            with open(self.template_file_name, "r", encoding="utf-8") as f:
                self.template_data = yaml.load(f, Loader=yaml.FullLoader)

        with open(self.config_file_name, "r", encoding="utf-8") as f:
            try:
                cfg = yaml.load(f, Loader=yaml.FullLoader)
            except yaml.YAMLError as e:
                raise Exception(f"配置文件 {self.config_file_name} 语法错误: {e}")

        if completion:
            for key in self.template_data:
                if key not in cfg:
                    cfg[key] = self.template_data[key]

        return cfg

    async def save(self, cfg: dict):
        with open(self.config_file_name, "w", encoding="utf-8") as f:
            yaml.dump(cfg, f, indent=4, allow_unicode=True)

    def save_sync(self, cfg: dict):
        with open(self.config_file_name, "w", encoding="utf-8") as f:
            yaml.dump(cfg, f, indent=4, allow_unicode=True)
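
The class above leans on two PyYAML calls: yaml.dump(..., allow_unicode=True) when writing and yaml.load(..., Loader=yaml.FullLoader) when reading. A minimal round-trip sketch of just those calls, with an arbitrary file name and sample data, shows that non-ASCII values survive the save/load cycle:

import yaml

template = {"prompt": "你好", "retries": 3}  # arbitrary sample data

# write the way YAMLConfigFile.save() does
with open("demo-config.yaml", "w", encoding="utf-8") as f:
    yaml.dump(template, f, indent=4, allow_unicode=True)

# read the way YAMLConfigFile.load() does
with open("demo-config.yaml", "r", encoding="utf-8") as f:
    cfg = yaml.load(f, Loader=yaml.FullLoader)

assert cfg == template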

@@ -1,7 +1,7 @@
 from __future__ import annotations
 
 from . import model as file_model
-from .impls import pymodule, json as json_file
+from .impls import pymodule, json as json_file, yaml as yaml_file
 
 managers: ConfigManager = []
@@ -20,34 +20,78 @@ class ConfigManager:
         self.file = cfg_file
         self.data = {}
 
-    async def load_config(self):
-        self.data = await self.file.load()
+    async def load_config(self, completion: bool=True):
+        self.data = await self.file.load(completion=completion)
 
     async def dump_config(self):
         await self.file.save(self.data)
 
+    def dump_config_sync(self):
+        self.file.save_sync(self.data)
+
 
-async def load_python_module_config(config_name: str, template_name: str) -> ConfigManager:
-    """加载Python模块配置文件"""
+async def load_python_module_config(config_name: str, template_name: str, completion: bool=True) -> ConfigManager:
+    """加载Python模块配置文件
+
+    Args:
+        config_name (str): 配置文件名
+        template_name (str): 模板文件名
+        completion (bool): 是否自动补全内存中的配置文件
+
+    Returns:
+        ConfigManager: 配置文件管理器
+    """
     cfg_inst = pymodule.PythonModuleConfigFile(
         config_name,
         template_name
     )
 
     cfg_mgr = ConfigManager(cfg_inst)
-    await cfg_mgr.load_config()
+    await cfg_mgr.load_config(completion=completion)
 
     return cfg_mgr
 
-async def load_json_config(config_name: str, template_name: str) -> ConfigManager:
-    """加载JSON配置文件"""
+async def load_json_config(config_name: str, template_name: str=None, template_data: dict=None, completion: bool=True) -> ConfigManager:
+    """加载JSON配置文件
+
+    Args:
+        config_name (str): 配置文件名
+        template_name (str): 模板文件名
+        template_data (dict): 模板数据
+        completion (bool): 是否自动补全内存中的配置文件
+    """
     cfg_inst = json_file.JSONConfigFile(
         config_name,
-        template_name
+        template_name,
+        template_data
     )
 
     cfg_mgr = ConfigManager(cfg_inst)
-    await cfg_mgr.load_config()
+    await cfg_mgr.load_config(completion=completion)
+
+    return cfg_mgr
+
+async def load_yaml_config(config_name: str, template_name: str=None, template_data: dict=None, completion: bool=True) -> ConfigManager:
+    """加载YAML配置文件
+
+    Args:
+        config_name (str): 配置文件名
+        template_name (str): 模板文件名
+        template_data (dict): 模板数据
+        completion (bool): 是否自动补全内存中的配置文件
+
+    Returns:
+        ConfigManager: 配置文件管理器
+    """
+    cfg_inst = yaml_file.YAMLConfigFile(
+        config_name,
+        template_name,
+        template_data
+    )
+
+    cfg_mgr = ConfigManager(cfg_inst)
+    await cfg_mgr.load_config(completion=completion)
 
     return cfg_mgr
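
A hedged usage sketch of the loaders defined above, runnable only inside the project; the file paths and the template dict are invented for illustration, but the call shapes (template_data as an in-memory template, completion=True to fill missing keys, dump_config / dump_config_sync to persist) follow the code in this diff.

import asyncio

from pkg.config import manager as cfg_mgr  # module shown in this diff


async def main():
    # JSON config backed by an in-memory template instead of a template file
    example_cfg = await cfg_mgr.load_json_config(
        "data/config/example.json",                     # illustrative path
        template_data={"api-key": "", "timeout": 120},  # illustrative template
        completion=True,
    )

    # YAML config backed by a template file (paths are illustrative)
    pipeline_cfg = await cfg_mgr.load_yaml_config(
        "data/config/pipeline.yaml",
        template_name="templates/pipeline.yaml",
    )

    example_cfg.data["timeout"] = 60      # edit the loaded data in memory
    await example_cfg.dump_config()       # persist asynchronously
    pipeline_cfg.dump_config_sync()       # or synchronously


asyncio.run(main())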

@@ -10,6 +10,9 @@ class ConfigFile(metaclass=abc.ABCMeta):
     template_file_name: str = None
     """模板文件名"""
 
+    template_data: dict = None
+    """模板数据"""
+
     @abc.abstractmethod
     def exists(self) -> bool:
         pass
@@ -19,9 +22,13 @@ class ConfigFile(metaclass=abc.ABCMeta):
         pass
 
     @abc.abstractmethod
-    async def load(self) -> dict:
+    async def load(self, completion: bool=True) -> dict:
         pass
 
     @abc.abstractmethod
     async def save(self, data: dict):
         pass
+
+    @abc.abstractmethod
+    def save_sync(self, data: dict):
+        pass


@@ -4,24 +4,25 @@ import logging
 import asyncio
 import traceback

-import aioconsole
-
 from ..platform import manager as im_mgr
 from ..provider.session import sessionmgr as llm_session_mgr
-from ..provider.requester import modelmgr as llm_model_mgr
+from ..provider.modelmgr import modelmgr as llm_model_mgr
 from ..provider.sysprompt import sysprompt as llm_prompt_mgr
 from ..provider.tools import toolmgr as llm_tool_mgr
+from ..provider import runnermgr
 from ..config import manager as config_mgr
 from ..audit.center import v2 as center_mgr
 from ..command import cmdmgr
 from ..plugin import manager as plugin_mgr
-from . import pool, controller
-from ..pipeline import stagemgr
-from ..utils import version as version_mgr, proxy as proxy_mgr
+from ..pipeline import pool
+from ..pipeline import controller, stagemgr
+from ..utils import version as version_mgr, proxy as proxy_mgr, announce as announce_mgr


 class Application:
+    """运行时应用对象和上下文"""

-    im_mgr: im_mgr.PlatformManager = None
+    platform_mgr: im_mgr.PlatformManager = None

     cmd_mgr: cmdmgr.CommandManager = None

@@ -33,6 +34,10 @@ class Application:
     tool_mgr: llm_tool_mgr.ToolManager = None

+    runner_mgr: runnermgr.RunnerManager = None
+
+    # ======= 配置管理器 =======
+
     command_cfg: config_mgr.ConfigManager = None

     pipeline_cfg: config_mgr.ConfigManager = None

@@ -43,6 +48,18 @@
     system_cfg: config_mgr.ConfigManager = None

+    # ======= 元数据配置管理器 =======
+
+    sensitive_meta: config_mgr.ConfigManager = None
+
+    adapter_qq_botpy_meta: config_mgr.ConfigManager = None
+
+    plugin_setting_meta: config_mgr.ConfigManager = None
+
+    llm_models_meta: config_mgr.ConfigManager = None
+
+    # =========================
+
     ctr_mgr: center_mgr.V2CenterAPI = None

     plugin_mgr: plugin_mgr.PluginManager = None

@@ -55,6 +72,8 @@ class Application:
     ver_mgr: version_mgr.VersionManager = None

+    ann_mgr: announce_mgr.AnnouncementManager = None
+
     proxy_mgr: proxy_mgr.ProxyManager = None

     logger: logging.Logger = None

@@ -66,27 +85,18 @@ class Application:
         pass

     async def run(self):
-        await self.plugin_mgr.load_plugins()
         await self.plugin_mgr.initialize_plugins()

         tasks = []

         try:
             tasks = [
-                asyncio.create_task(self.im_mgr.run()),
+                asyncio.create_task(self.platform_mgr.run()),
                 asyncio.create_task(self.ctrl.run())
             ]

-            # async def interrupt(tasks):
-            #     await asyncio.sleep(1.5)
-            #     while await aioconsole.ainput("使用 ctrl+c 或 'exit' 退出程序 > ") != 'exit':
-            #         pass
-            #     for task in tasks:
-            #         task.cancel()
-            # await interrupt(tasks)
+            # 挂信号处理
             import signal

View File

@@ -1,143 +1,37 @@
from __future__ import print_function from __future__ import print_function
import os import traceback
import sys
from .bootutils import files
from .bootutils import deps
from .bootutils import log
from .bootutils import config
from . import app from . import app
from . import pool
from . import controller
from ..pipeline import stagemgr
from ..audit import identifier from ..audit import identifier
from ..provider.session import sessionmgr as llm_session_mgr from . import stage
from ..provider.requester import modelmgr as llm_model_mgr
from ..provider.sysprompt import sysprompt as llm_prompt_mgr
from ..provider.tools import toolmgr as llm_tool_mgr
from ..platform import manager as im_mgr
from ..command import cmdmgr
from ..plugin import manager as plugin_mgr
from ..audit.center import v2 as center_v2
from ..utils import version, proxy, announce
use_override = False # 引入启动阶段实现以便注册
from .stages import load_config, setup_logger, build_app, migrate, show_notes
stage_order = [
"LoadConfigStage",
"MigrationStage",
"SetupLoggerStage",
"BuildAppStage",
"ShowNotesStage"
]
async def make_app() -> app.Application: async def make_app() -> app.Application:
global use_override
generated_files = await files.generate_files()
if generated_files:
print("以下文件不存在,已自动生成,请按需修改配置文件后重启:")
for file in generated_files:
print("-", file)
sys.exit(0)
missing_deps = await deps.check_deps()
if missing_deps:
print("以下依赖包未安装,将自动安装,请完成后重启程序:")
for dep in missing_deps:
print("-", dep)
await deps.install_deps(missing_deps)
sys.exit(0)
qcg_logger = await log.init_logging()
# 生成标识符 # 生成标识符
identifier.init() identifier.init()
# ========== 加载配置文件 ==========
command_cfg = await config.load_json_config("data/config/command.json", "templates/command.json")
pipeline_cfg = await config.load_json_config("data/config/pipeline.json", "templates/pipeline.json")
platform_cfg = await config.load_json_config("data/config/platform.json", "templates/platform.json")
provider_cfg = await config.load_json_config("data/config/provider.json", "templates/provider.json")
system_cfg = await config.load_json_config("data/config/system.json", "templates/system.json")
# ========== 构建应用实例 ==========
ap = app.Application() ap = app.Application()
ap.logger = qcg_logger
ap.command_cfg = command_cfg # 执行启动阶段
ap.pipeline_cfg = pipeline_cfg for stage_name in stage_order:
ap.platform_cfg = platform_cfg stage_cls = stage.preregistered_stages[stage_name]
ap.provider_cfg = provider_cfg stage_inst = stage_cls()
ap.system_cfg = system_cfg
proxy_mgr = proxy.ProxyManager(ap) await stage_inst.run(ap)
await proxy_mgr.initialize()
ap.proxy_mgr = proxy_mgr
ver_mgr = version.VersionManager(ap)
await ver_mgr.initialize()
ap.ver_mgr = ver_mgr
center_v2_api = center_v2.V2CenterAPI(
ap,
basic_info={
"host_id": identifier.identifier["host_id"],
"instance_id": identifier.identifier["instance_id"],
"semantic_version": ver_mgr.get_current_version(),
"platform": sys.platform,
},
runtime_info={
"admin_id": "{}".format(system_cfg.data["admin-sessions"]),
"msg_source": str([
adapter_cfg['adapter'] if 'adapter' in adapter_cfg else 'unknown'
for adapter_cfg in platform_cfg.data['platform-adapters'] if adapter_cfg['enable']
]),
},
)
ap.ctr_mgr = center_v2_api
# 发送公告
ann_mgr = announce.AnnouncementManager(ap)
await ann_mgr.show_announcements()
ap.query_pool = pool.QueryPool()
await ap.ver_mgr.show_version_update()
plugin_mgr_inst = plugin_mgr.PluginManager(ap)
await plugin_mgr_inst.initialize()
ap.plugin_mgr = plugin_mgr_inst
cmd_mgr_inst = cmdmgr.CommandManager(ap)
await cmd_mgr_inst.initialize()
ap.cmd_mgr = cmd_mgr_inst
llm_model_mgr_inst = llm_model_mgr.ModelManager(ap)
await llm_model_mgr_inst.initialize()
ap.model_mgr = llm_model_mgr_inst
llm_session_mgr_inst = llm_session_mgr.SessionManager(ap)
await llm_session_mgr_inst.initialize()
ap.sess_mgr = llm_session_mgr_inst
llm_prompt_mgr_inst = llm_prompt_mgr.PromptManager(ap)
await llm_prompt_mgr_inst.initialize()
ap.prompt_mgr = llm_prompt_mgr_inst
llm_tool_mgr_inst = llm_tool_mgr.ToolManager(ap)
await llm_tool_mgr_inst.initialize()
ap.tool_mgr = llm_tool_mgr_inst
im_mgr_inst = im_mgr.PlatformManager(ap=ap)
await im_mgr_inst.initialize()
ap.im_mgr = im_mgr_inst
stage_mgr = stagemgr.StageManager(ap)
await stage_mgr.initialize()
ap.stage_mgr = stage_mgr
ctrl = controller.Controller(ap)
ap.ctrl = ctrl
await ap.initialize() await ap.initialize()
@@ -145,5 +39,8 @@ async def make_app() -> app.Application:
async def main(): async def main():
app_inst = await make_app() try:
await app_inst.run() app_inst = await make_app()
await app_inst.run()
except Exception as e:
traceback.print_exc()

View File

@@ -8,16 +8,3 @@ from ...config.impls import pymodule

 load_python_module_config = config_mgr.load_python_module_config
 load_json_config = config_mgr.load_json_config
-
-
-async def override_config_manager(cfg_mgr: config_mgr.ConfigManager) -> list[str]:
-    override_json = json.load(open("override.json", "r", encoding="utf-8"))
-
-    overrided = []
-
-    config = cfg_mgr.data
-
-    for key in override_json:
-        if key in config:
-            config[key] = override_json[key]
-            overrided.append(key)
-
-    return overrided

View File

@@ -3,14 +3,19 @@ import pip

 required_deps = {
     "requests": "requests",
     "openai": "openai",
+    "anthropic": "anthropic",
     "colorlog": "colorlog",
     "mirai": "yiri-mirai-rc",
+    "aiocqhttp": "aiocqhttp",
+    "botpy": "qq-botpy",
     "PIL": "pillow",
     "nakuru": "nakuru-project-idk",
-    "CallingGPT": "CallingGPT",
     "tiktoken": "tiktoken",
     "yaml": "pyyaml",
     "aiohttp": "aiohttp",
+    "psutil": "psutil",
+    "async_lru": "async-lru",
+    "ollama": "ollama",
 }

View File

@@ -13,13 +13,13 @@ required_files = {
     "data/config/platform.json": "templates/platform.json",
     "data/config/provider.json": "templates/provider.json",
     "data/config/system.json": "templates/system.json",
-    "data/config/sensitive-words.json": "templates/sensitive-words.json",
     "data/scenario/default.json": "templates/scenario-template.json",
 }

 required_paths = [
     "temp",
     "data",
+    "data/metadata",
     "data/prompts",
     "data/scenario",
     "data/logs",

View File

@@ -16,6 +16,10 @@ log_colors_config = {

 async def init_logging() -> logging.Logger:
+    # 删除所有现有的logger
+    for handler in logging.root.handlers[:]:
+        logging.root.removeHandler(handler)
+
     level = logging.INFO

     if "DEBUG" in os.environ and os.environ["DEBUG"] in ["true", "1"]:
@@ -46,7 +50,7 @@ async def init_logging() -> logging.Logger:
     qcg_logger.debug("日志初始化完成,日志级别:%s" % level)

     logging.basicConfig(
-        level=logging.INFO,  # 设置日志输出格式
+        level=logging.CRITICAL,  # 设置日志输出格式
         format="[DEPR][%(asctime)s.%(msecs)03d] %(pathname)s (%(lineno)d) - [%(levelname)s] :\n%(message)s",
         # 日志输出的格式
         # -8表示占位符让输出左对齐输出长度都为8位

View File

@@ -9,13 +9,14 @@ import pydantic
 import mirai

 from ..provider import entities as llm_entities
-from ..provider.requester import entities
+from ..provider.modelmgr import entities
 from ..provider.sysprompt import entities as sysprompt_entities
 from ..provider.tools import entities as tools_entities
 from ..platform import adapter as msadapter


 class LauncherTypes(enum.Enum):
+    """一个请求的发起者类型"""

     PERSON = 'person'
     """私聊"""
@@ -31,53 +32,56 @@ class Query(pydantic.BaseModel):
     """请求ID添加进请求池时生成"""

     launcher_type: LauncherTypes
-    """会话类型platform设置"""
+    """会话类型platform处理阶段设置"""

     launcher_id: int
-    """会话IDplatform设置"""
+    """会话IDplatform处理阶段设置"""

     sender_id: int
-    """发送者IDplatform设置"""
+    """发送者IDplatform处理阶段设置"""

     message_event: mirai.MessageEvent
-    """事件platform收到的事件"""
+    """事件platform收到的原始事件"""

     message_chain: mirai.MessageChain
-    """消息链platform收到的消息链"""
+    """消息链platform收到的原始消息链"""

     adapter: msadapter.MessageSourceAdapter
-    """适配器对象"""
+    """消息平台适配器对象单个app中可能启用了多个消息平台适配器此对象表明发起此query的适配器"""

     session: typing.Optional[Session] = None
-    """会话对象,由前置处理器设置"""
+    """会话对象,由前置处理器阶段设置"""

     messages: typing.Optional[list[llm_entities.Message]] = []
-    """历史消息列表,由前置处理器设置"""
+    """历史消息列表,由前置处理器阶段设置"""

     prompt: typing.Optional[sysprompt_entities.Prompt] = None
-    """情景预设内容,由前置处理器设置"""
+    """情景预设内容,由前置处理器阶段设置"""

     user_message: typing.Optional[llm_entities.Message] = None
-    """此次请求的用户消息对象,由前置处理器设置"""
+    """此次请求的用户消息对象,由前置处理器阶段设置"""

     use_model: typing.Optional[entities.LLMModelInfo] = None
-    """使用的模型,由前置处理器设置"""
+    """使用的模型,由前置处理器阶段设置"""

     use_funcs: typing.Optional[list[tools_entities.LLMFunction]] = None
-    """使用的函数,由前置处理器设置"""
+    """使用的函数,由前置处理器阶段设置"""

-    resp_messages: typing.Optional[list[llm_entities.Message]] = []
-    """provider生成的回复消息对象列表"""
+    resp_messages: typing.Optional[list[llm_entities.Message]] | typing.Optional[list[mirai.MessageChain]] = []
+    """Process阶段生成的回复消息对象列表"""

-    resp_message_chain: typing.Optional[mirai.MessageChain] = None
+    resp_message_chain: typing.Optional[list[mirai.MessageChain]] = None
     """回复消息链从resp_messages包装而得"""

+    # ======= 内部保留 =======
+
+    current_stage: "pkg.pipeline.stagemgr.StageInstContainer" = None

     class Config:
         arbitrary_types_allowed = True


 class Conversation(pydantic.BaseModel):
-    """对话"""
+    """对话,包含于 Session 中,一个 Session 可以有多个历史 Conversation但只有一个当前使用的 Conversation"""

     prompt: sysprompt_entities.Prompt
@@ -93,7 +97,7 @@ class Conversation(pydantic.BaseModel):

 class Session(pydantic.BaseModel):
-    """会话"""
+    """会话,一个 Session 对应一个 {launcher_type.value}_{launcher_id}"""

     launcher_type: LauncherTypes

     launcher_id: int
@@ -111,6 +115,7 @@ class Session(pydantic.BaseModel):
     update_time: typing.Optional[datetime.datetime] = pydantic.Field(default_factory=datetime.datetime.now)

     semaphore: typing.Optional[asyncio.Semaphore] = None
+    """当前会话的信号量,用于限制并发"""

     class Config:
         arbitrary_types_allowed = True

pkg/core/migration.py (new file, 47 lines)
View File

@@ -0,0 +1,47 @@
from __future__ import annotations
import abc
import typing
from . import app
preregistered_migrations: list[typing.Type[Migration]] = []
"""当前阶段暂不支持扩展"""
def migration_class(name: str, number: int):
"""注册一个迁移
"""
def decorator(cls: typing.Type[Migration]) -> typing.Type[Migration]:
cls.name = name
cls.number = number
preregistered_migrations.append(cls)
return cls
return decorator
class Migration(abc.ABC):
"""一个版本的迁移
"""
name: str
number: int
ap: app.Application
def __init__(self, ap: app.Application):
self.ap = ap
@abc.abstractmethod
async def need_migrate(self) -> bool:
"""判断当前环境是否需要运行此迁移
"""
pass
@abc.abstractmethod
async def run(self):
"""执行迁移
"""
pass
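For illustration, a hypothetical migration registered with this decorator (the name, number, and config key are examples, not part of the diff):

# Hypothetical migration (illustrative only).
@migration_class("example-flag-migration", 99)
class ExampleFlagMigration(Migration):

    async def need_migrate(self) -> bool:
        # Run only while the (hypothetical) key is still missing from system.json.
        return "example-flag" not in self.ap.system_cfg.data

    async def run(self):
        self.ap.system_cfg.data["example-flag"] = False
        await self.ap.system_cfg.dump_config()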

View File

@@ -0,0 +1,26 @@
from __future__ import annotations
import os
import sys
from .. import migration
@migration.migration_class("sensitive-word-migration", 1)
class SensitiveWordMigration(migration.Migration):
"""敏感词迁移
"""
async def need_migrate(self) -> bool:
"""判断当前环境是否需要运行此迁移
"""
return os.path.exists("data/config/sensitive-words.json") and not os.path.exists("data/metadata/sensitive-words.json")
async def run(self):
"""执行迁移
"""
# 移动文件
os.rename("data/config/sensitive-words.json", "data/metadata/sensitive-words.json")
# 重新加载配置
await self.ap.sensitive_meta.load_config()

View File

@@ -0,0 +1,47 @@
from __future__ import annotations
from .. import migration
@migration.migration_class("openai-config-migration", 2)
class OpenAIConfigMigration(migration.Migration):
"""OpenAI配置迁移
"""
async def need_migrate(self) -> bool:
"""判断当前环境是否需要运行此迁移
"""
return 'openai-config' in self.ap.provider_cfg.data
async def run(self):
"""执行迁移
"""
old_openai_config = self.ap.provider_cfg.data['openai-config'].copy()
if 'keys' not in self.ap.provider_cfg.data:
self.ap.provider_cfg.data['keys'] = {}
if 'openai' not in self.ap.provider_cfg.data['keys']:
self.ap.provider_cfg.data['keys']['openai'] = []
self.ap.provider_cfg.data['keys']['openai'] = old_openai_config['api-keys']
self.ap.provider_cfg.data['model'] = old_openai_config['chat-completions-params']['model']
del old_openai_config['chat-completions-params']['model']
if 'requester' not in self.ap.provider_cfg.data:
self.ap.provider_cfg.data['requester'] = {}
if 'openai-chat-completions' not in self.ap.provider_cfg.data['requester']:
self.ap.provider_cfg.data['requester']['openai-chat-completions'] = {}
self.ap.provider_cfg.data['requester']['openai-chat-completions'] = {
'base-url': old_openai_config['base_url'],
'args': old_openai_config['chat-completions-params'],
'timeout': old_openai_config['request-timeout'],
}
del self.ap.provider_cfg.data['openai-config']
await self.ap.provider_cfg.dump_config()
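Read together, the steps above map the legacy 'openai-config' block onto the new keys/model/requester layout; roughly (values are made up for illustration):

# Illustrative before/after of provider.json for this migration (values are examples).
old = {
    "openai-config": {
        "api-keys": ["sk-xxxx"],
        "base_url": "https://api.openai.com/v1",
        "chat-completions-params": {"model": "gpt-3.5-turbo", "temperature": 0.7},
        "request-timeout": 120,
    }
}
new = {
    "keys": {"openai": ["sk-xxxx"]},
    "model": "gpt-3.5-turbo",
    "requester": {
        "openai-chat-completions": {
            "base-url": "https://api.openai.com/v1",
            "args": {"temperature": 0.7},  # 'model' is moved out of the args dict
            "timeout": 120,
        }
    },
}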

View File

@@ -0,0 +1,32 @@
from __future__ import annotations
from .. import migration
@migration.migration_class("anthropic-requester-config-completion", 3)
class AnthropicRequesterConfigCompletionMigration(migration.Migration):
"""OpenAI配置迁移
"""
async def need_migrate(self) -> bool:
"""判断当前环境是否需要运行此迁移
"""
return 'anthropic-messages' not in self.ap.provider_cfg.data['requester'] \
or 'anthropic' not in self.ap.provider_cfg.data['keys']
async def run(self):
"""执行迁移
"""
if 'anthropic-messages' not in self.ap.provider_cfg.data['requester']:
self.ap.provider_cfg.data['requester']['anthropic-messages'] = {
'base-url': 'https://api.anthropic.com',
'args': {
'max_tokens': 1024
},
'timeout': 120,
}
if 'anthropic' not in self.ap.provider_cfg.data['keys']:
self.ap.provider_cfg.data['keys']['anthropic'] = []
await self.ap.provider_cfg.dump_config()

View File

@@ -0,0 +1,30 @@
from __future__ import annotations
from .. import migration
@migration.migration_class("moonshot-config-completion", 4)
class MoonshotConfigCompletionMigration(migration.Migration):
"""OpenAI配置迁移
"""
async def need_migrate(self) -> bool:
"""判断当前环境是否需要运行此迁移
"""
return 'moonshot-chat-completions' not in self.ap.provider_cfg.data['requester'] \
or 'moonshot' not in self.ap.provider_cfg.data['keys']
async def run(self):
"""执行迁移
"""
if 'moonshot-chat-completions' not in self.ap.provider_cfg.data['requester']:
self.ap.provider_cfg.data['requester']['moonshot-chat-completions'] = {
'base-url': 'https://api.moonshot.cn/v1',
'args': {},
'timeout': 120,
}
if 'moonshot' not in self.ap.provider_cfg.data['keys']:
self.ap.provider_cfg.data['keys']['moonshot'] = []
await self.ap.provider_cfg.dump_config()

View File

@@ -0,0 +1,30 @@
from __future__ import annotations
from .. import migration
@migration.migration_class("deepseek-config-completion", 5)
class DeepseekConfigCompletionMigration(migration.Migration):
"""OpenAI配置迁移
"""
async def need_migrate(self) -> bool:
"""判断当前环境是否需要运行此迁移
"""
return 'deepseek-chat-completions' not in self.ap.provider_cfg.data['requester'] \
or 'deepseek' not in self.ap.provider_cfg.data['keys']
async def run(self):
"""执行迁移
"""
if 'deepseek-chat-completions' not in self.ap.provider_cfg.data['requester']:
self.ap.provider_cfg.data['requester']['deepseek-chat-completions'] = {
'base-url': 'https://api.deepseek.com',
'args': {},
'timeout': 120,
}
if 'deepseek' not in self.ap.provider_cfg.data['keys']:
self.ap.provider_cfg.data['keys']['deepseek'] = []
await self.ap.provider_cfg.dump_config()

View File

@@ -0,0 +1,19 @@
from __future__ import annotations
from .. import migration
@migration.migration_class("vision-config", 6)
class VisionConfigMigration(migration.Migration):
"""迁移"""
async def need_migrate(self) -> bool:
"""判断当前环境是否需要运行此迁移"""
return "enable-vision" not in self.ap.provider_cfg.data
async def run(self):
"""执行迁移"""
if "enable-vision" not in self.ap.provider_cfg.data:
self.ap.provider_cfg.data["enable-vision"] = False
await self.ap.provider_cfg.dump_config()

View File

@@ -0,0 +1,20 @@
from __future__ import annotations
from .. import migration
@migration.migration_class("qcg-center-url-config", 7)
class QCGCenterURLConfigMigration(migration.Migration):
"""迁移"""
async def need_migrate(self) -> bool:
"""判断当前环境是否需要运行此迁移"""
return "qcg-center-url" not in self.ap.system_cfg.data
async def run(self):
"""执行迁移"""
if "qcg-center-url" not in self.ap.system_cfg.data:
self.ap.system_cfg.data["qcg-center-url"] = "https://api.qchatgpt.rockchin.top/api/v2"
await self.ap.system_cfg.dump_config()

View File

@@ -0,0 +1,29 @@
from __future__ import annotations
from .. import migration
@migration.migration_class("ad-fixwin-cfg-migration", 8)
class AdFixwinConfigMigration(migration.Migration):
"""迁移"""
async def need_migrate(self) -> bool:
"""判断当前环境是否需要运行此迁移"""
return isinstance(
self.ap.pipeline_cfg.data["rate-limit"]["fixwin"]["default"],
int
)
async def run(self):
"""执行迁移"""
for session_name in self.ap.pipeline_cfg.data["rate-limit"]["fixwin"]:
temp_dict = {
"window-size": 60,
"limit": self.ap.pipeline_cfg.data["rate-limit"]["fixwin"][session_name]
}
self.ap.pipeline_cfg.data["rate-limit"]["fixwin"][session_name] = temp_dict
await self.ap.pipeline_cfg.dump_config()
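In effect, each bare per-session limit becomes a dict with an assumed 60-second window, for example:

# Illustrative before/after for the fixwin rate-limit migration (session names are examples).
before = {"rate-limit": {"fixwin": {"default": 60, "group_12345": 30}}}
after = {"rate-limit": {"fixwin": {
    "default":     {"window-size": 60, "limit": 60},
    "group_12345": {"window-size": 60, "limit": 30},
}}}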

View File

@@ -0,0 +1,24 @@
from __future__ import annotations
from .. import migration
@migration.migration_class("msg-truncator-cfg-migration", 9)
class MsgTruncatorConfigMigration(migration.Migration):
"""迁移"""
async def need_migrate(self) -> bool:
"""判断当前环境是否需要运行此迁移"""
return 'msg-truncate' not in self.ap.pipeline_cfg.data
async def run(self):
"""执行迁移"""
self.ap.pipeline_cfg.data['msg-truncate'] = {
'method': 'round',
'round': {
'max-round': 10
}
}
await self.ap.pipeline_cfg.dump_config()

View File

@@ -0,0 +1,23 @@
from __future__ import annotations
from .. import migration
@migration.migration_class("ollama-requester-config", 10)
class MsgTruncatorConfigMigration(migration.Migration):
"""迁移"""
async def need_migrate(self) -> bool:
"""判断当前环境是否需要运行此迁移"""
return 'ollama-chat' not in self.ap.provider_cfg.data['requester']
async def run(self):
"""执行迁移"""
self.ap.provider_cfg.data['requester']['ollama-chat'] = {
"base-url": "http://127.0.0.1:11434",
"args": {},
"timeout": 600
}
await self.ap.provider_cfg.dump_config()

View File

@@ -0,0 +1,21 @@
from __future__ import annotations
from .. import migration
@migration.migration_class("command-prefix-config", 11)
class CommandPrefixConfigMigration(migration.Migration):
"""迁移"""
async def need_migrate(self) -> bool:
"""判断当前环境是否需要运行此迁移"""
return 'command-prefix' not in self.ap.command_cfg.data
async def run(self):
"""执行迁移"""
self.ap.command_cfg.data['command-prefix'] = [
"!", ""
]
await self.ap.command_cfg.dump_config()

View File

@@ -0,0 +1,19 @@
from __future__ import annotations
from .. import migration
@migration.migration_class("runner-config", 12)
class RunnerConfigMigration(migration.Migration):
"""迁移"""
async def need_migrate(self) -> bool:
"""判断当前环境是否需要运行此迁移"""
return 'runner' not in self.ap.provider_cfg.data
async def run(self):
"""执行迁移"""
self.ap.provider_cfg.data['runner'] = 'local-agent'
await self.ap.provider_cfg.dump_config()

pkg/core/note.py (new file, 44 lines)
View File

@@ -0,0 +1,44 @@
from __future__ import annotations
import abc
import typing
from . import app
preregistered_notes: list[typing.Type[LaunchNote]] = []
def note_class(name: str, number: int):
"""注册一个启动信息
"""
def decorator(cls: typing.Type[LaunchNote]) -> typing.Type[LaunchNote]:
cls.name = name
cls.number = number
preregistered_notes.append(cls)
return cls
return decorator
class LaunchNote(abc.ABC):
"""启动信息
"""
name: str
number: int
ap: app.Application
def __init__(self, ap: app.Application):
self.ap = ap
@abc.abstractmethod
async def need_show(self) -> bool:
"""判断当前环境是否需要显示此启动信息
"""
pass
@abc.abstractmethod
async def yield_note(self) -> typing.AsyncGenerator[typing.Tuple[str, int], None]:
"""生成启动信息
"""
pass
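For illustration, a hypothetical launch note registered with this decorator (name and text are examples, not part of the diff):

# Hypothetical launch note (illustrative only).
import logging

@note_class("ExampleNote", 99)
class ExampleNote(LaunchNote):

    async def need_show(self) -> bool:
        return True

    async def yield_note(self):
        # Each yielded tuple is (message, logging level); ShowNotesStage logs them.
        yield "This is an example launch note.", logging.INFO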

View File

@@ -0,0 +1,20 @@
from __future__ import annotations
import typing
from .. import note, app
@note.note_class("ClassicNotes", 1)
class ClassicNotes(note.LaunchNote):
"""经典启动信息
"""
async def need_show(self) -> bool:
return True
async def yield_note(self) -> typing.AsyncGenerator[typing.Tuple[str, int], None]:
yield await self.ap.ann_mgr.show_announcements()
yield await self.ap.ver_mgr.show_version_update()

View File

@@ -0,0 +1,21 @@
from __future__ import annotations
import typing
import os
import sys
import logging
from .. import note, app
@note.note_class("SelectionModeOnWindows", 2)
class SelectionModeOnWindows(note.LaunchNote):
"""Windows 上的选择模式提示信息
"""
async def need_show(self) -> bool:
return os.name == 'nt'
async def yield_note(self) -> typing.AsyncGenerator[typing.Tuple[str, int], None]:
yield """您正在使用 Windows 系统,若窗口左上角显示处于”选择“模式,程序将被暂停运行,此时请右键窗口中空白区域退出选择模式。""", logging.INFO

pkg/core/stage.py (new file, 34 lines)
View File

@@ -0,0 +1,34 @@
from __future__ import annotations
import abc
import typing
from . import app
preregistered_stages: dict[str, typing.Type[BootingStage]] = {}
"""预注册的请求处理阶段。在初始化时,所有请求处理阶段类会被注册到此字典中。
当前阶段暂不支持扩展
"""
def stage_class(
name: str
):
def decorator(cls: typing.Type[BootingStage]) -> typing.Type[BootingStage]:
preregistered_stages[name] = cls
return cls
return decorator
class BootingStage(abc.ABC):
"""启动阶段
"""
name: str = None
@abc.abstractmethod
async def run(self, ap: app.Application):
"""启动
"""
pass
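A hypothetical stage registered through this mechanism would look roughly like this; to actually run, its name would also have to be appended to stage_order in pkg/core/boot.py shown earlier:

# Hypothetical booting stage (illustrative only; not part of the diff).
@stage_class("ExampleStage")
class ExampleStage(BootingStage):

    async def run(self, ap: app.Application):
        # A stage receives the Application object and attaches whatever it builds.
        if ap.logger:
            ap.logger.info("ExampleStage executed")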

View File

@@ -0,0 +1,100 @@
from __future__ import annotations
import sys
from .. import stage, app
from ...utils import version, proxy, announce, platform
from ...audit.center import v2 as center_v2
from ...audit import identifier
from ...pipeline import pool, controller, stagemgr
from ...plugin import manager as plugin_mgr
from ...command import cmdmgr
from ...provider.session import sessionmgr as llm_session_mgr
from ...provider.modelmgr import modelmgr as llm_model_mgr
from ...provider.sysprompt import sysprompt as llm_prompt_mgr
from ...provider.tools import toolmgr as llm_tool_mgr
from ...provider import runnermgr
from ...platform import manager as im_mgr
@stage.stage_class("BuildAppStage")
class BuildAppStage(stage.BootingStage):
"""构建应用阶段
"""
async def run(self, ap: app.Application):
"""构建app对象的各个组件对象并初始化
"""
proxy_mgr = proxy.ProxyManager(ap)
await proxy_mgr.initialize()
ap.proxy_mgr = proxy_mgr
ver_mgr = version.VersionManager(ap)
await ver_mgr.initialize()
ap.ver_mgr = ver_mgr
center_v2_api = center_v2.V2CenterAPI(
ap,
backend_url=ap.system_cfg.data["qcg-center-url"],
basic_info={
"host_id": identifier.identifier["host_id"],
"instance_id": identifier.identifier["instance_id"],
"semantic_version": ver_mgr.get_current_version(),
"platform": platform.get_platform(),
},
runtime_info={
"admin_id": "{}".format(ap.system_cfg.data["admin-sessions"]),
"msg_source": str([
adapter_cfg['adapter'] if 'adapter' in adapter_cfg else 'unknown'
for adapter_cfg in ap.platform_cfg.data['platform-adapters'] if adapter_cfg['enable']
]),
},
)
ap.ctr_mgr = center_v2_api
# 发送公告
ann_mgr = announce.AnnouncementManager(ap)
ap.ann_mgr = ann_mgr
ap.query_pool = pool.QueryPool()
plugin_mgr_inst = plugin_mgr.PluginManager(ap)
await plugin_mgr_inst.initialize()
ap.plugin_mgr = plugin_mgr_inst
await plugin_mgr_inst.load_plugins()
cmd_mgr_inst = cmdmgr.CommandManager(ap)
await cmd_mgr_inst.initialize()
ap.cmd_mgr = cmd_mgr_inst
llm_model_mgr_inst = llm_model_mgr.ModelManager(ap)
await llm_model_mgr_inst.initialize()
ap.model_mgr = llm_model_mgr_inst
llm_session_mgr_inst = llm_session_mgr.SessionManager(ap)
await llm_session_mgr_inst.initialize()
ap.sess_mgr = llm_session_mgr_inst
llm_prompt_mgr_inst = llm_prompt_mgr.PromptManager(ap)
await llm_prompt_mgr_inst.initialize()
ap.prompt_mgr = llm_prompt_mgr_inst
llm_tool_mgr_inst = llm_tool_mgr.ToolManager(ap)
await llm_tool_mgr_inst.initialize()
ap.tool_mgr = llm_tool_mgr_inst
runner_mgr_inst = runnermgr.RunnerManager(ap)
await runner_mgr_inst.initialize()
ap.runner_mgr = runner_mgr_inst
im_mgr_inst = im_mgr.PlatformManager(ap=ap)
await im_mgr_inst.initialize()
ap.platform_mgr = im_mgr_inst
stage_mgr = stagemgr.StageManager(ap)
await stage_mgr.initialize()
ap.stage_mgr = stage_mgr
ctrl = controller.Controller(ap)
ap.ctrl = ctrl

View File

@@ -0,0 +1,31 @@
from __future__ import annotations
from .. import stage, app
from ..bootutils import config
@stage.stage_class("LoadConfigStage")
class LoadConfigStage(stage.BootingStage):
"""加载配置文件阶段
"""
async def run(self, ap: app.Application):
"""启动
"""
ap.command_cfg = await config.load_json_config("data/config/command.json", "templates/command.json", completion=False)
ap.pipeline_cfg = await config.load_json_config("data/config/pipeline.json", "templates/pipeline.json", completion=False)
ap.platform_cfg = await config.load_json_config("data/config/platform.json", "templates/platform.json", completion=False)
ap.provider_cfg = await config.load_json_config("data/config/provider.json", "templates/provider.json", completion=False)
ap.system_cfg = await config.load_json_config("data/config/system.json", "templates/system.json", completion=False)
ap.plugin_setting_meta = await config.load_json_config("plugins/plugins.json", "templates/plugin-settings.json")
await ap.plugin_setting_meta.dump_config()
ap.sensitive_meta = await config.load_json_config("data/metadata/sensitive-words.json", "templates/metadata/sensitive-words.json")
await ap.sensitive_meta.dump_config()
ap.adapter_qq_botpy_meta = await config.load_json_config("data/metadata/adapter-qq-botpy.json", "templates/metadata/adapter-qq-botpy.json")
await ap.adapter_qq_botpy_meta.dump_config()
ap.llm_models_meta = await config.load_json_config("data/metadata/llm-models.json", "templates/metadata/llm-models.json")
await ap.llm_models_meta.dump_config()

View File

@@ -0,0 +1,30 @@
from __future__ import annotations
import importlib
from .. import stage, app
from .. import migration
from ..migrations import m001_sensitive_word_migration, m002_openai_config_migration, m003_anthropic_requester_cfg_completion, m004_moonshot_cfg_completion
from ..migrations import m005_deepseek_cfg_completion, m006_vision_config, m007_qcg_center_url, m008_ad_fixwin_config_migrate, m009_msg_truncator_cfg
from ..migrations import m010_ollama_requester_config, m011_command_prefix_config, m012_runner_config
@stage.stage_class("MigrationStage")
class MigrationStage(stage.BootingStage):
"""迁移阶段
"""
async def run(self, ap: app.Application):
"""启动
"""
migrations = migration.preregistered_migrations
# 按照迁移号排序
migrations.sort(key=lambda x: x.number)
for migration_cls in migrations:
migration_instance = migration_cls(ap)
if await migration_instance.need_migrate():
await migration_instance.run()

View File

@@ -0,0 +1,15 @@
from __future__ import annotations
from .. import stage, app
from ..bootutils import log
@stage.stage_class("SetupLoggerStage")
class SetupLoggerStage(stage.BootingStage):
"""设置日志器阶段
"""
async def run(self, ap: app.Application):
"""启动
"""
ap.logger = await log.init_logging()

View File

@@ -0,0 +1,28 @@
from __future__ import annotations
from .. import stage, app, note
from ..notes import n001_classic_msgs, n002_selection_mode_on_windows
@stage.stage_class("ShowNotesStage")
class ShowNotesStage(stage.BootingStage):
"""显示启动信息阶段
"""
async def run(self, ap: app.Application):
# 排序
note.preregistered_notes.sort(key=lambda x: x.number)
for note_cls in note.preregistered_notes:
try:
note_inst = note_cls(ap)
if await note_inst.need_show():
async for ret in note_inst.yield_note():
if not ret:
continue
msg, level = ret
if msg:
ap.logger.log(level, msg)
except Exception as e:
continue

View File

@@ -8,6 +8,10 @@ from ...config import manager as cfg_mgr

 @stage.stage_class('BanSessionCheckStage')
 class BanSessionCheckStage(stage.PipelineStage):
+    """访问控制处理阶段
+
+    仅检查query中群号或个人号是否在访问控制列表中。
+    """

     async def initialize(self):
         pass
@@ -24,22 +28,24 @@ class BanSessionCheckStage(stage.PipelineStage):
         sess_list = self.ap.pipeline_cfg.data['access-control'][mode]

-        if (query.launcher_type == 'group' and 'group_*' in sess_list) \
-            or (query.launcher_type == 'person' and 'person_*' in sess_list):
+        if (query.launcher_type.value == 'group' and 'group_*' in sess_list) \
+            or (query.launcher_type.value == 'person' and 'person_*' in sess_list):
             found = True
         else:
             for sess in sess_list:
-                if sess == f"{query.launcher_type}_{query.launcher_id}":
+                if sess == f"{query.launcher_type.value}_{query.launcher_id}":
                     found = True
                     break

-        result = False
+        ctn = False

-        if mode == 'blacklist':
-            result = found
+        if mode == 'whitelist':
+            ctn = found
+        else:
+            ctn = not found

         return entities.StageProcessResult(
-            result_type=entities.ResultType.CONTINUE if not result else entities.ResultType.INTERRUPT,
+            result_type=entities.ResultType.CONTINUE if ctn else entities.ResultType.INTERRUPT,
             new_query=query,
-            debug_notice=f'根据访问控制忽略消息: {query.launcher_type}_{query.launcher_id}' if result else ''
+            console_notice=f'根据访问控制忽略消息: {query.launcher_type.value}_{query.launcher_id}' if not ctn else ''
         )
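For reference, a hedged sketch of the access-control section this stage reads — the exact pipeline.json schema is assumed from the code above (one list per mode; the 'mode' selector key is not visible in this hunk):

# Assumed shape of pipeline.json's access-control section (illustrative values).
access_control = {
    "mode": "blacklist",                      # assumed selector key
    "whitelist": ["group_*"],                 # wildcard: allow every group chat
    "blacklist": ["person_10001", "group_12345"],
}
# With mode == "whitelist", the message continues only if the session is found in the
# list; with any other mode (blacklist), it continues only if the session is NOT found.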

View File

@@ -1,34 +1,58 @@
from __future__ import annotations from __future__ import annotations
import mirai import mirai
import mirai.models
import mirai.models.message
from ...core import app from ...core import app
from .. import stage, entities, stagemgr from .. import stage, entities, stagemgr
from ...core import entities as core_entities from ...core import entities as core_entities
from ...config import manager as cfg_mgr from ...config import manager as cfg_mgr
from . import filter, entities as filter_entities from . import filter as filter_model, entities as filter_entities
from .filters import cntignore, banwords, baiduexamine from .filters import cntignore, banwords, baiduexamine
from ...provider import entities as llm_entities
@stage.stage_class('PostContentFilterStage') @stage.stage_class('PostContentFilterStage')
@stage.stage_class('PreContentFilterStage') @stage.stage_class('PreContentFilterStage')
class ContentFilterStage(stage.PipelineStage): class ContentFilterStage(stage.PipelineStage):
"""内容过滤阶段
filter_chain: list[filter.ContentFilter] 前置:
检查消息是否符合规则,不符合则拦截。
改写:
message_chain
后置:
检查AI回复消息是否符合规则可能进行改写不符合则拦截。
改写:
query.resp_messages
"""
filter_chain: list[filter_model.ContentFilter]
def __init__(self, ap: app.Application): def __init__(self, ap: app.Application):
self.filter_chain = [] self.filter_chain = []
super().__init__(ap) super().__init__(ap)
async def initialize(self): async def initialize(self):
self.filter_chain.append(cntignore.ContentIgnore(self.ap))
filters_required = [
"content-ignore",
]
if self.ap.pipeline_cfg.data['check-sensitive-words']: if self.ap.pipeline_cfg.data['check-sensitive-words']:
self.filter_chain.append(banwords.BanWordFilter(self.ap)) filters_required.append("ban-word-filter")
if self.ap.pipeline_cfg.data['baidu-cloud-examine']['enable']: if self.ap.pipeline_cfg.data['baidu-cloud-examine']['enable']:
self.filter_chain.append(baiduexamine.BaiduCloudExamine(self.ap)) filters_required.append("baidu-cloud-examine")
for filter in filter_model.preregistered_filters:
if filter.name in filters_required:
self.filter_chain.append(
filter(self.ap)
)
for filter in self.filter_chain: for filter in self.filter_chain:
await filter.initialize() await filter.initialize()
@@ -41,6 +65,7 @@ class ContentFilterStage(stage.PipelineStage):
"""请求llm前处理消息 """请求llm前处理消息
只要有一个不通过就不放行,只放行 PASS 的消息 只要有一个不通过就不放行,只放行 PASS 的消息
""" """
if not self.ap.pipeline_cfg.data['income-msg-check']: if not self.ap.pipeline_cfg.data['income-msg-check']:
return entities.StageProcessResult( return entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE, result_type=entities.ResultType.CONTINUE,
@@ -120,14 +145,39 @@ class ContentFilterStage(stage.PipelineStage):
"""处理 """处理
""" """
if stage_inst_name == 'PreContentFilterStage': if stage_inst_name == 'PreContentFilterStage':
contain_non_text = False
text_components = [mirai.Plain, mirai.models.message.Source]
for me in query.message_chain:
if type(me) not in text_components:
contain_non_text = True
break
if contain_non_text:
self.ap.logger.debug(f"消息中包含非文本消息,跳过内容过滤器检查。")
return entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE,
new_query=query
)
return await self._pre_process( return await self._pre_process(
str(query.message_chain).strip(), str(query.message_chain).strip(),
query query
) )
elif stage_inst_name == 'PostContentFilterStage': elif stage_inst_name == 'PostContentFilterStage':
return await self._post_process( # 仅处理 query.resp_messages[-1].content 是 str 的情况
query.resp_messages[-1].content, if isinstance(query.resp_messages[-1], llm_entities.Message) and isinstance(query.resp_messages[-1].content, str):
query return await self._post_process(
) query.resp_messages[-1].content,
query
)
else:
self.ap.logger.debug(f"resp_messages[-1] 不是 Message 类型或 query.resp_messages[-1].content 不是 str 类型,跳过内容过滤器检查。")
return entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE,
new_query=query
)
else: else:
raise ValueError(f'未知的 stage_inst_name: {stage_inst_name}') raise ValueError(f'未知的 stage_inst_name: {stage_inst_name}')

View File

@@ -4,6 +4,8 @@ import enum
import pydantic import pydantic
from ...provider import entities as llm_entities
class ResultLevel(enum.Enum): class ResultLevel(enum.Enum):
"""结果等级""" """结果等级"""
@@ -31,15 +33,24 @@ class EnableStage(enum.Enum):
class FilterResult(pydantic.BaseModel): class FilterResult(pydantic.BaseModel):
level: ResultLevel level: ResultLevel
"""结果等级
对于前置处理阶段,只要有任意一个返回 非PASS 的内容过滤器结果,就会中断处理。
对于后置处理阶段,当且内容过滤器返回 BLOCK 时,会中断处理。
"""
replacement: str replacement: str
"""替换后的消息""" """替换后的文本消息
内容过滤器可以进行一些遮掩处理,然后把遮掩后的消息返回。
若没有修改内容,也需要返回原消息。
"""
user_notice: str user_notice: str
"""不通过时,用户提示消息""" """不通过时,若此值不为空,将对用户提示消息"""
console_notice: str console_notice: str
"""不通过时,控制台提示消息""" """不通过时,若此值不为空,将在控制台提示消息"""
class ManagerResultLevel(enum.Enum): class ManagerResultLevel(enum.Enum):

View File

@@ -1,12 +1,43 @@
# 内容过滤器的抽象类 # 内容过滤器的抽象类
from __future__ import annotations from __future__ import annotations
import abc import abc
import typing
from ...core import app from ...core import app
from . import entities from . import entities
from ...provider import entities as llm_entities
preregistered_filters: list[typing.Type[ContentFilter]] = []
def filter_class(
name: str
) -> typing.Callable[[typing.Type[ContentFilter]], typing.Type[ContentFilter]]:
"""内容过滤器类装饰器
Args:
name (str): 过滤器名称
Returns:
typing.Callable[[typing.Type[ContentFilter]], typing.Type[ContentFilter]]: 装饰器
"""
def decorator(cls: typing.Type[ContentFilter]) -> typing.Type[ContentFilter]:
assert issubclass(cls, ContentFilter)
cls.name = name
preregistered_filters.append(cls)
return cls
return decorator
class ContentFilter(metaclass=abc.ABCMeta): class ContentFilter(metaclass=abc.ABCMeta):
"""内容过滤器抽象类"""
name: str
ap: app.Application ap: app.Application
@@ -16,6 +47,11 @@ class ContentFilter(metaclass=abc.ABCMeta):
@property @property
def enable_stages(self): def enable_stages(self):
"""启用的阶段 """启用的阶段
默认为消息请求AI前后的两个阶段。
entity.EnableStage.PRE: 消息请求AI前此时需要检查的内容是用户的输入消息。
entity.EnableStage.POST: 消息请求AI后此时需要检查的内容是AI的回复消息。
""" """
return [ return [
entities.EnableStage.PRE, entities.EnableStage.PRE,
@@ -28,7 +64,17 @@ class ContentFilter(metaclass=abc.ABCMeta):
pass pass
@abc.abstractmethod @abc.abstractmethod
async def process(self, message: str) -> entities.FilterResult: async def process(self, message: str=None, image_url=None) -> entities.FilterResult:
"""处理消息 """处理消息
分为前后阶段,具体取决于 enable_stages 的值。
对于内容过滤器来说,不需要考虑消息所处的阶段,只需要检查消息内容即可。
Args:
message (str): 需要检查的内容
image_url (str): 要检查的图片的 URL
Returns:
entities.FilterResult: 过滤结果,具体内容请查看 entities.FilterResult 类的文档
""" """
raise NotImplementedError raise NotImplementedError
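To illustrate the registry above, a hypothetical filter might look like this (the name, keyword, and use of ResultLevel.PASS/BLOCK are assumptions based on the surrounding docstrings, not part of the diff):

# Hypothetical custom content filter registered via filter_class (illustrative only).
@filter_class("example-keyword-filter")
class ExampleKeywordFilter(ContentFilter):

    async def initialize(self):
        pass

    async def process(self, message: str = None, image_url=None) -> entities.FilterResult:
        blocked = message is not None and "forbidden-example-word" in message
        return entities.FilterResult(
            level=entities.ResultLevel.BLOCK if blocked else entities.ResultLevel.PASS,
            replacement=message or "",
            user_notice="消息被示例过滤器拦截" if blocked else "",
            console_notice="example-keyword-filter blocked a message" if blocked else "",
        )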

View File

@@ -10,6 +10,7 @@ BAIDU_EXAMINE_URL = "https://aip.baidubce.com/rest/2.0/solution/v1/text_censor/v
BAIDU_EXAMINE_TOKEN_URL = "https://aip.baidubce.com/oauth/2.0/token" BAIDU_EXAMINE_TOKEN_URL = "https://aip.baidubce.com/oauth/2.0/token"
@filter_model.filter_class("baidu-cloud-examine")
class BaiduCloudExamine(filter_model.ContentFilter): class BaiduCloudExamine(filter_model.ContentFilter):
"""百度云内容审核""" """百度云内容审核"""

View File

@@ -6,34 +6,30 @@ from .. import entities
from ....config import manager as cfg_mgr from ....config import manager as cfg_mgr
@filter_model.filter_class("ban-word-filter")
class BanWordFilter(filter_model.ContentFilter): class BanWordFilter(filter_model.ContentFilter):
"""根据内容禁言""" """根据内容过滤"""
sensitive: cfg_mgr.ConfigManager
async def initialize(self): async def initialize(self):
self.sensitive = await cfg_mgr.load_json_config( pass
"data/config/sensitive-words.json",
"templates/sensitive-words.json"
)
async def process(self, message: str) -> entities.FilterResult: async def process(self, message: str) -> entities.FilterResult:
found = False found = False
for word in self.sensitive.data['words']: for word in self.ap.sensitive_meta.data['words']:
match = re.findall(word, message) match = re.findall(word, message)
if len(match) > 0: if len(match) > 0:
found = True found = True
for i in range(len(match)): for i in range(len(match)):
if self.sensitive.data['mask_word'] == "": if self.ap.sensitive_meta.data['mask_word'] == "":
message = message.replace( message = message.replace(
match[i], self.sensitive.data['mask'] * len(match[i]) match[i], self.ap.sensitive_meta.data['mask'] * len(match[i])
) )
else: else:
message = message.replace( message = message.replace(
match[i], self.sensitive.data['mask_word'] match[i], self.ap.sensitive_meta.data['mask_word']
) )
return entities.FilterResult( return entities.FilterResult(

View File

@@ -5,6 +5,7 @@ from .. import entities
from .. import filter as filter_model from .. import filter as filter_model
@filter_model.filter_class("content-ignore")
class ContentIgnore(filter_model.ContentFilter): class ContentIgnore(filter_model.ContentFilter):
"""根据内容忽略消息""" """根据内容忽略消息"""

View File

@@ -4,8 +4,10 @@ import asyncio
import typing import typing
import traceback import traceback
from . import app, entities import mirai
from ..pipeline import entities as pipeline_entities
from ..core import app, entities
from . import entities as pipeline_entities
from ..plugin import events from ..plugin import events
@@ -68,7 +70,18 @@ class Controller:
"""检查输出 """检查输出
""" """
if result.user_notice: if result.user_notice:
await self.ap.im_mgr.send( # 处理str类型
if isinstance(result.user_notice, str):
result.user_notice = mirai.MessageChain(
mirai.Plain(result.user_notice)
)
elif isinstance(result.user_notice, list):
result.user_notice = mirai.MessageChain(
*result.user_notice
)
await self.ap.platform_mgr.send(
query.message_event, query.message_event,
result.user_notice, result.user_notice,
query.adapter query.adapter
@@ -85,7 +98,7 @@ class Controller:
stage_index: int, stage_index: int,
query: entities.Query, query: entities.Query,
): ):
"""从指定阶段开始执行 """从指定阶段开始执行,实现了责任链模式和基于生成器的阶段分叉功能。
如何看懂这里为什么这么写 如何看懂这里为什么这么写
去问 GPT-4: 去问 GPT-4:
@@ -110,6 +123,8 @@ class Controller:
while i < len(self.ap.stage_mgr.stage_containers): while i < len(self.ap.stage_mgr.stage_containers):
stage_container = self.ap.stage_mgr.stage_containers[i] stage_container = self.ap.stage_mgr.stage_containers[i]
query.current_stage = stage_container # 标记到 Query 对象里
result = stage_container.inst.process(query, stage_container.inst_name) result = stage_container.inst.process(query, stage_container.inst_name)
if isinstance(result, typing.Coroutine): if isinstance(result, typing.Coroutine):
@@ -149,7 +164,7 @@ class Controller:
try: try:
await self._execute_from_stage(0, query) await self._execute_from_stage(0, query)
except Exception as e: except Exception as e:
self.ap.logger.error(f"处理请求时出错 query_id={query.query_id}: {e}") self.ap.logger.error(f"处理请求时出错 query_id={query.query_id} stage={query.current_stage.inst_name} : {e}")
self.ap.logger.debug(f"Traceback: {traceback.format_exc()}") self.ap.logger.debug(f"Traceback: {traceback.format_exc()}")
# traceback.print_exc() # traceback.print_exc()
finally: finally:

View File

@@ -15,6 +15,11 @@ from ...config import manager as cfg_mgr
@stage.stage_class("LongTextProcessStage") @stage.stage_class("LongTextProcessStage")
class LongTextProcessStage(stage.PipelineStage): class LongTextProcessStage(stage.PipelineStage):
"""长消息处理阶段
改写:
- resp_message_chain
"""
strategy_impl: strategy.LongTextStrategy strategy_impl: strategy.LongTextStrategy
@@ -29,30 +34,44 @@ class LongTextProcessStage(stage.PipelineStage):
if os.name == "nt": if os.name == "nt":
use_font = "C:/Windows/Fonts/msyh.ttc" use_font = "C:/Windows/Fonts/msyh.ttc"
if not os.path.exists(use_font): if not os.path.exists(use_font):
self.ap.logger.warn("未找到字体文件且无法使用Windows自带字体更换为转发消息组件以发送长消息您可以在config.py中调整相关设置。") self.ap.logger.warn("未找到字体文件且无法使用Windows自带字体更换为转发消息组件以发送长消息您可以在配置文件中调整相关设置。")
config['blob_message_strategy'] = "forward" config['blob_message_strategy'] = "forward"
else: else:
self.ap.logger.info("使用Windows自带字体" + use_font) self.ap.logger.info("使用Windows自带字体" + use_font)
config['font-path'] = use_font config['font-path'] = use_font
else: else:
self.ap.logger.warn("未找到字体文件,且无法使用系统自带字体,更换为转发消息组件以发送长消息,您可以在config.py中调整相关设置。") self.ap.logger.warn("未找到字体文件,且无法使用系统自带字体,更换为转发消息组件以发送长消息,您可以在配置文件中调整相关设置。")
self.ap.platform_cfg.data['long-text-process']['strategy'] = "forward" self.ap.platform_cfg.data['long-text-process']['strategy'] = "forward"
except: except:
traceback.print_exc() traceback.print_exc()
self.ap.logger.error("加载字体文件失败({}),更换为转发消息组件以发送长消息,您可以在config.py中调整相关设置。".format(use_font)) self.ap.logger.error("加载字体文件失败({}),更换为转发消息组件以发送长消息,您可以在配置文件中调整相关设置。".format(use_font))
self.ap.platform_cfg.data['long-text-process']['strategy'] = "forward" self.ap.platform_cfg.data['long-text-process']['strategy'] = "forward"
if config['strategy'] == 'image': for strategy_cls in strategy.preregistered_strategies:
self.strategy_impl = image.Text2ImageStrategy(self.ap) if strategy_cls.name == config['strategy']:
elif config['strategy'] == 'forward': self.strategy_impl = strategy_cls(self.ap)
self.strategy_impl = forward.ForwardComponentStrategy(self.ap) break
else:
raise ValueError(f"未找到名为 {config['strategy']} 的长消息处理策略")
await self.strategy_impl.initialize() await self.strategy_impl.initialize()
async def process(self, query: core_entities.Query, stage_inst_name: str) -> entities.StageProcessResult: async def process(self, query: core_entities.Query, stage_inst_name: str) -> entities.StageProcessResult:
if len(str(query.resp_message_chain)) > self.ap.platform_cfg.data['long-text-process']['threshold']: # 检查是否包含非 Plain 组件
query.resp_message_chain = MessageChain(await self.strategy_impl.process(str(query.resp_message_chain), query)) contains_non_plain = False
for msg in query.resp_message_chain[-1]:
if not isinstance(msg, Plain):
contains_non_plain = True
break
if contains_non_plain:
self.ap.logger.debug("消息中包含非 Plain 组件,跳过长消息处理。")
elif len(str(query.resp_message_chain[-1])) > self.ap.platform_cfg.data['long-text-process']['threshold']:
query.resp_message_chain[-1] = MessageChain(await self.strategy_impl.process(str(query.resp_message_chain[-1]), query))
return entities.StageProcessResult( return entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE, result_type=entities.ResultType.CONTINUE,
new_query=query new_query=query

View File

@@ -36,6 +36,7 @@ class Forward(MessageComponent):
return '[聊天记录]' return '[聊天记录]'
@strategy_model.strategy_class("forward")
class ForwardComponentStrategy(strategy_model.LongTextStrategy): class ForwardComponentStrategy(strategy_model.LongTextStrategy):
async def process(self, message: str, query: core_entities.Query) -> list[MessageComponent]: async def process(self, message: str, query: core_entities.Query) -> list[MessageComponent]:

View File

@@ -15,6 +15,7 @@ from .. import strategy as strategy_model
from ....core import entities as core_entities from ....core import entities as core_entities
@strategy_model.strategy_class("image")
class Text2ImageStrategy(strategy_model.LongTextStrategy): class Text2ImageStrategy(strategy_model.LongTextStrategy):
text_render_font: ImageFont.FreeTypeFont text_render_font: ImageFont.FreeTypeFont

View File

@@ -9,7 +9,39 @@ from ...core import app
from ...core import entities as core_entities from ...core import entities as core_entities
preregistered_strategies: list[typing.Type[LongTextStrategy]] = []
def strategy_class(
name: str
) -> typing.Callable[[typing.Type[LongTextStrategy]], typing.Type[LongTextStrategy]]:
"""长文本处理策略类装饰器
Args:
name (str): 策略名称
Returns:
typing.Callable[[typing.Type[LongTextStrategy]], typing.Type[LongTextStrategy]]: 装饰器
"""
def decorator(cls: typing.Type[LongTextStrategy]) -> typing.Type[LongTextStrategy]:
assert issubclass(cls, LongTextStrategy)
cls.name = name
preregistered_strategies.append(cls)
return cls
return decorator
class LongTextStrategy(metaclass=abc.ABCMeta): class LongTextStrategy(metaclass=abc.ABCMeta):
"""长文本处理策略抽象类
"""
name: str
ap: app.Application ap: app.Application
def __init__(self, ap: app.Application): def __init__(self, ap: app.Application):
@@ -20,4 +52,15 @@ class LongTextStrategy(metaclass=abc.ABCMeta):
@abc.abstractmethod @abc.abstractmethod
async def process(self, message: str, query: core_entities.Query) -> list[MessageComponent]: async def process(self, message: str, query: core_entities.Query) -> list[MessageComponent]:
"""处理长文本
在 platform.json 中配置 long-text-process 字段,只要 文本长度超过了 threshold 就会调用此方法
Args:
message (str): 消息
query (core_entities.Query): 此次请求的上下文对象
Returns:
list[mirai.models.messages.MessageComponent]: 转换后的 YiriMirai 消息组件列表
"""
return [] return []
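As a sketch of how a new strategy would plug into this registry (the name and behaviour are hypothetical; mirai.Plain is the component type used elsewhere in this diff):

# Hypothetical pass-through strategy registered via strategy_class (illustrative only).
import mirai

@strategy_class("plain-passthrough")
class PlainPassthroughStrategy(LongTextStrategy):

    async def process(self, message: str, query: core_entities.Query) -> list[MessageComponent]:
        # Return the long text unchanged as a single Plain component.
        return [mirai.Plain(message)]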

View File


@@ -0,0 +1,35 @@
from __future__ import annotations
from .. import stage, entities, stagemgr
from ...core import entities as core_entities
from . import truncator
from .truncators import round
@stage.stage_class("ConversationMessageTruncator")
class ConversationMessageTruncator(stage.PipelineStage):
"""会话消息截断器
用于截断会话消息链,以适应平台消息长度限制。
"""
trun: truncator.Truncator
async def initialize(self):
use_method = self.ap.pipeline_cfg.data['msg-truncate']['method']
for trun in truncator.preregistered_truncators:
if trun.name == use_method:
self.trun = trun(self.ap)
break
else:
raise ValueError(f"未知的截断器: {use_method}")
async def process(self, query: core_entities.Query, stage_inst_name: str) -> entities.StageProcessResult:
"""处理
"""
query = await self.trun.truncate(query)
return entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE,
new_query=query
)

View File

@@ -0,0 +1,56 @@
from __future__ import annotations
import typing
import abc
from ...core import entities as core_entities, app
preregistered_truncators: list[typing.Type[Truncator]] = []
def truncator_class(
name: str
) -> typing.Callable[[typing.Type[Truncator]], typing.Type[Truncator]]:
"""截断器类装饰器
Args:
name (str): 截断器名称
Returns:
typing.Callable[[typing.Type[Truncator]], typing.Type[Truncator]]: 装饰器
"""
def decorator(cls: typing.Type[Truncator]) -> typing.Type[Truncator]:
assert issubclass(cls, Truncator)
cls.name = name
preregistered_truncators.append(cls)
return cls
return decorator
class Truncator(abc.ABC):
"""消息截断器基类
"""
name: str
ap: app.Application
def __init__(self, ap: app.Application):
self.ap = ap
async def initialize(self):
pass
@abc.abstractmethod
async def truncate(self, query: core_entities.Query) -> core_entities.Query:
"""截断
一般只需要操作query.messages也可以扩展操作query.prompt, query.user_message。
请勿操作其他字段。
"""
pass

View File

@@ -0,0 +1,32 @@
from __future__ import annotations
from .. import truncator
from ....core import entities as core_entities
@truncator.truncator_class("round")
class RoundTruncator(truncator.Truncator):
"""前文回合数阶段器
"""
async def truncate(self, query: core_entities.Query) -> core_entities.Query:
"""截断
"""
max_round = self.ap.pipeline_cfg.data['msg-truncate']['round']['max-round']
temp_messages = []
current_round = 0
# 从后往前遍历
for msg in query.messages[::-1]:
if current_round < max_round:
temp_messages.append(msg)
if msg.role == 'user':
current_round += 1
else:
break
query.messages = temp_messages[::-1]
return query
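Concretely, with max-round set to 2 the truncator keeps only the two most recent user rounds; a standalone sketch of the same loop:

# Standalone sketch of the round-based truncation above (dummy message objects).
class Msg:
    def __init__(self, role, text):
        self.role, self.text = role, text

history = [Msg("user", "q1"), Msg("assistant", "a1"),
           Msg("user", "q2"), Msg("assistant", "a2"),
           Msg("user", "q3"), Msg("assistant", "a3")]

max_round, kept, rounds = 2, [], 0
for msg in history[::-1]:          # walk from the newest message backwards
    if rounds < max_round:
        kept.append(msg)
        if msg.role == "user":     # a 'user' message closes one round
            rounds += 1
    else:
        break
kept = kept[::-1]                  # restore chronological order
# kept now holds q2/a2/q3/a3: the two most recent rounds survive.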

View File

@@ -4,11 +4,12 @@ import asyncio

 import mirai

-from . import entities
+from ..core import entities
 from ..platform import adapter as msadapter


 class QueryPool:
+    """请求池请求获得调度进入pipeline之前保存在这里"""

     query_id_counter: int = 0
@@ -42,7 +43,7 @@ class QueryPool:
             message_event=message_event,
             message_chain=message_chain,
             resp_messages=[],
-            resp_message_chain=None,
+            resp_message_chain=[],
             adapter=adapter
         )
         self.queries.append(query)

View File

@@ -1,5 +1,7 @@
from __future__ import annotations from __future__ import annotations
import mirai
from .. import stage, entities, stagemgr from .. import stage, entities, stagemgr
from ...core import entities as core_entities from ...core import entities as core_entities
from ...provider import entities as llm_entities from ...provider import entities as llm_entities
@@ -8,7 +10,17 @@ from ...plugin import events
@stage.stage_class("PreProcessor") @stage.stage_class("PreProcessor")
class PreProcessor(stage.PipelineStage): class PreProcessor(stage.PipelineStage):
"""预处理 """请求预处理阶段
签出会话、prompt、上文、模型、内容函数。
改写:
- session
- prompt
- messages
- user_message
- use_model
- use_funcs
""" """
async def process( async def process(
@@ -27,21 +39,42 @@ class PreProcessor(stage.PipelineStage):
query.prompt = conversation.prompt.copy() query.prompt = conversation.prompt.copy()
query.messages = conversation.messages.copy() query.messages = conversation.messages.copy()
query.user_message = llm_entities.Message(
role='user',
content=str(query.message_chain).strip()
)
query.use_model = conversation.use_model query.use_model = conversation.use_model
query.use_funcs = conversation.use_funcs query.use_funcs = conversation.use_funcs if query.use_model.tool_call_supported else None
# 检查vision是否启用没启用就删除所有图片
if not self.ap.provider_cfg.data['enable-vision'] or not query.use_model.vision_supported:
for msg in query.messages:
if isinstance(msg.content, list):
for me in msg.content:
if me.type == 'image_url':
msg.content.remove(me)
content_list = []
for me in query.message_chain:
if isinstance(me, mirai.Plain):
content_list.append(
llm_entities.ContentElement.from_text(me.text)
)
elif isinstance(me, mirai.Image):
if self.ap.provider_cfg.data['enable-vision'] and query.use_model.vision_supported:
if me.url is not None:
content_list.append(
llm_entities.ContentElement.from_image_url(str(me.url))
)
query.user_message = llm_entities.Message( # TODO 适配多模态输入
role='user',
content=content_list
)
# =========== 触发事件 PromptPreProcessing # =========== 触发事件 PromptPreProcessing
session = query.session
event_ctx = await self.ap.plugin_mgr.emit_event( event_ctx = await self.ap.plugin_mgr.emit_event(
event=events.PromptPreProcessing( event=events.PromptPreProcessing(
session_name=f'{session.launcher_type.value}_{session.launcher_id}', session_name=f'{query.session.launcher_type.value}_{query.session.launcher_id}',
default_prompt=query.prompt.messages, default_prompt=query.prompt.messages,
prompt=query.messages, prompt=query.messages,
query=query query=query
@@ -51,28 +84,6 @@ class PreProcessor(stage.PipelineStage):
query.prompt.messages = event_ctx.event.default_prompt query.prompt.messages = event_ctx.event.default_prompt
query.messages = event_ctx.event.prompt query.messages = event_ctx.event.prompt
# 根据模型max_tokens剪裁
max_tokens = min(query.use_model.max_tokens, self.ap.pipeline_cfg.data['submit-messages-tokens'])
test_messages = query.prompt.messages + query.messages + [query.user_message]
while await query.use_model.tokenizer.count_token(test_messages, query.use_model) > max_tokens:
# 前文都pop完了还是大于max_tokens由于prompt和user_messages不能删减报错
if len(query.prompt.messages) == 0:
return entities.StageProcessResult(
result_type=entities.ResultType.INTERRUPT,
new_query=query,
user_notice='输入内容过长,请减少情景预设或者输入内容长度',
console_notice='输入内容过长,请减少情景预设或者输入内容长度,或者增大配置文件中的 submit-messages-tokens 项但不能超过所用模型最大tokens数'
)
query.messages.pop(0) # pop第一个肯定是role=user的
# 继续pop到第二个role=user前一个
while len(query.messages) > 0 and query.messages[0].role != 'user':
query.messages.pop(0)
test_messages = query.prompt.messages + query.messages + [query.user_message]
return entities.StageProcessResult( return entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE, result_type=entities.ResultType.CONTINUE,
new_query=query new_query=query

View File

@@ -23,3 +23,12 @@ class MessageHandler(metaclass=abc.ABCMeta):
query: core_entities.Query, query: core_entities.Query,
) -> entities.StageProcessResult: ) -> entities.StageProcessResult:
raise NotImplementedError raise NotImplementedError
def cut_str(self, s: str) -> str:
"""
取字符串第一行最多20个字符若有多行或超过20个字符则加省略号
"""
s0 = s.split('\n')[0]
if len(s0) > 20 or '\n' in s:
s0 = s0[:20] + '...'
return s0

View File

@@ -3,13 +3,14 @@ from __future__ import annotations
import typing import typing
import time import time
import traceback import traceback
import json
import mirai import mirai
from .. import handler from .. import handler
from ... import entities from ... import entities
from ....core import entities as core_entities from ....core import entities as core_entities
from ....provider import entities as llm_entities from ....provider import entities as llm_entities, runnermgr
from ....plugin import events from ....plugin import events
@@ -21,8 +22,6 @@ class ChatMessageHandler(handler.MessageHandler):
) -> typing.AsyncGenerator[entities.StageProcessResult, None]:
"""处理
"""
# 取session
# 取conversation
# 调API
# 生成器
@@ -41,7 +40,9 @@ class ChatMessageHandler(handler.MessageHandler):
if event_ctx.is_prevented_default():
if event_ctx.event.reply is not None:
- query.resp_message_chain = mirai.MessageChain(event_ctx.event.reply)
+ mc = mirai.MessageChain(event_ctx.event.reply)
query.resp_messages.append(mc)
yield entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE,
@@ -61,13 +62,8 @@ class ChatMessageHandler(handler.MessageHandler):
)
if event_ctx.event.alter is not None:
- query.message_chain = mirai.MessageChain([
- mirai.Plain(event_ctx.event.alter)
+ # if isinstance(event_ctx.event, str): # 现在暂时不考虑多模态alter
+ query.user_message.content = event_ctx.event.alter
])
query.messages.append(
query.user_message
)
text_length = 0
@@ -75,9 +71,13 @@ class ChatMessageHandler(handler.MessageHandler):
try:
- async for result in query.use_model.requester.request(query):
+ runner = self.ap.runner_mgr.get_runner()
async for result in runner.run(query):
query.resp_messages.append(result)
self.ap.logger.info(f'对话({query.query_id})响应: {self.cut_str(result.readable_str())}')
if result.content is not None:
text_length += len(result.content)
@@ -85,7 +85,13 @@ class ChatMessageHandler(handler.MessageHandler):
result_type=entities.ResultType.CONTINUE,
new_query=query
)
query.session.using_conversation.messages.append(query.user_message)
query.session.using_conversation.messages.extend(query.resp_messages)
except Exception as e:
self.ap.logger.error(f'对话({query.query_id})请求失败: {str(e)}')
yield entities.StageProcessResult(
result_type=entities.ResultType.INTERRUPT,
new_query=query,
@@ -94,8 +100,6 @@ class ChatMessageHandler(handler.MessageHandler):
debug_notice=traceback.format_exc()
)
finally:
query.session.using_conversation.messages.append(query.user_message)
query.session.using_conversation.messages.extend(query.resp_messages)
await self.ap.ctr_mgr.usage.post_query_record(
session_type=query.session.launcher_type.value,


@@ -19,15 +19,16 @@ class CommandHandler(handler.MessageHandler):
"""处理 """处理
""" """
event_class = events.PersonCommandSent if query.launcher_type == core_entities.LauncherTypes.PERSON else events.GroupCommandSent command_text = str(query.message_chain).strip()[1:]
privilege = 1 privilege = 1
if f'{query.launcher_type.value}_{query.launcher_id}' in self.ap.system_cfg.data['admin-sessions']: if f'{query.launcher_type.value}_{query.launcher_id}' in self.ap.system_cfg.data['admin-sessions']:
privilege = 2 privilege = 2
spt = str(query.message_chain).strip().split(' ') spt = command_text.split(' ')
event_class = events.PersonCommandSent if query.launcher_type == core_entities.LauncherTypes.PERSON else events.GroupCommandSent
event_ctx = await self.ap.plugin_mgr.emit_event( event_ctx = await self.ap.plugin_mgr.emit_event(
event=event_class( event=event_class(
@@ -47,12 +48,7 @@ class CommandHandler(handler.MessageHandler):
if event_ctx.event.reply is not None:
mc = mirai.MessageChain(event_ctx.event.reply)
- query.resp_messages.append(
+ query.resp_messages.append(mc)
llm_entities.Message(
role='command',
content=str(mc),
)
)
yield entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE,
@@ -73,17 +69,12 @@ class CommandHandler(handler.MessageHandler):
session = await self.ap.sess_mgr.get_session(query)
command_text = str(query.message_chain).strip()[1:]
async for ret in self.ap.cmd_mgr.execute(
command_text=command_text,
query=query,
session=session
):
if ret.error is not None:
# query.resp_message_chain = mirai.MessageChain([
# mirai.Plain(str(ret.error))
# ])
query.resp_messages.append(
llm_entities.Message(
role='command',
@@ -91,21 +82,35 @@ class CommandHandler(handler.MessageHandler):
)
)
self.ap.logger.info(f'命令({query.query_id})报错: {self.cut_str(str(ret.error))}')
yield entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE,
new_query=query
)
- elif ret.text is not None:
+ elif ret.text is not None or ret.image_url is not None:
# query.resp_message_chain = mirai.MessageChain([
- # mirai.Plain(ret.text)
+ content: list[llm_entities.ContentElement]= []
# ])
if ret.text is not None:
content.append(
llm_entities.ContentElement.from_text(ret.text)
)
if ret.image_url is not None:
content.append(
llm_entities.ContentElement.from_image_url(ret.image_url)
)
query.resp_messages.append(
llm_entities.Message(
role='command',
- content=ret.text,
+ content=content,
)
)
self.ap.logger.info(f'命令返回: {self.cut_str(str(content[0]))}')
yield entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE,
new_query=query


@@ -11,6 +11,13 @@ from ...config import manager as cfg_mgr
@stage.stage_class("MessageProcessor") @stage.stage_class("MessageProcessor")
class Processor(stage.PipelineStage): class Processor(stage.PipelineStage):
"""请求实际处理阶段
通过命令处理器和聊天处理器处理消息。
改写:
- resp_messages
"""
cmd_handler: handler.MessageHandler
@@ -34,7 +41,14 @@ class Processor(stage.PipelineStage):
self.ap.logger.info(f"处理 {query.launcher_type.value}_{query.launcher_id} 的请求({query.query_id}): {message_text}") self.ap.logger.info(f"处理 {query.launcher_type.value}_{query.launcher_id} 的请求({query.query_id}): {message_text}")
if message_text.startswith('!') or message_text.startswith(''): async def generator():
return self.cmd_handler.handle(query) cmd_prefix = self.ap.command_cfg.data['command-prefix']
else:
return self.chat_handler.handle(query) if any(message_text.startswith(prefix) for prefix in cmd_prefix):
async for result in self.cmd_handler.handle(query):
yield result
else:
async for result in self.chat_handler.handle(query):
yield result
return generator()
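The rewritten Processor routes a message to the command handler when it starts with any prefix listed under command-prefix, instead of the hard-coded exclamation marks. A sketch of that lookup with a hypothetical config value:

```python
# Hypothetical command-prefix data; the real values come from the command config file.
command_cfg = {"command-prefix": ["!", "!"]}

message_text = "!version"
is_command = any(message_text.startswith(p) for p in command_cfg["command-prefix"])
# True -> cmd_handler.handle(query); otherwise chat_handler.handle(query)
```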


@@ -1,10 +1,26 @@
from __future__ import annotations
import abc
import typing
from ...core import app
preregistered_algos: list[typing.Type[ReteLimitAlgo]] = []
def algo_class(name: str):
def decorator(cls: typing.Type[ReteLimitAlgo]) -> typing.Type[ReteLimitAlgo]:
cls.name = name
preregistered_algos.append(cls)
return cls
return decorator
class ReteLimitAlgo(metaclass=abc.ABCMeta):
"""限流算法抽象类"""
name: str = None
ap: app.Application
@@ -16,9 +32,27 @@ class ReteLimitAlgo(metaclass=abc.ABCMeta):
@abc.abstractmethod
async def require_access(self, launcher_type: str, launcher_id: int) -> bool:
"""进入处理流程
这个方法对等待是友好的,意味着算法可以实现在这里等待一段时间以控制速率。
Args:
launcher_type (str): 请求者类型 群聊为 group 私聊为 person
launcher_id (int): 请求者ID
Returns:
bool: 是否允许进入处理流程若返回false则直接丢弃该请求
"""
raise NotImplementedError
@abc.abstractmethod
async def release_access(self, launcher_type: str, launcher_id: int):
"""退出处理流程
Args:
launcher_type (str): 请求者类型 群聊为 group 私聊为 person
launcher_id (int): 请求者ID
"""
raise NotImplementedError
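With the new preregistered_algos registry, a rate-limit algorithm only needs the algo_class decorator to become selectable by name from the pipeline config. A minimal sketch; the import path and the pass-through logic are assumptions:

```python
# Hedged sketch of a custom rate-limit algorithm using the new registry.
from pkg.pipeline.ratelimit import algo  # assumed module path


@algo.algo_class("noop")
class NoOpAlgo(algo.ReteLimitAlgo):
    async def initialize(self):
        pass

    async def require_access(self, launcher_type: str, launcher_id: int) -> bool:
        return True  # never throttles; a real algorithm may await here to pace requests

    async def release_access(self, launcher_type: str, launcher_id: int):
        pass
```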


@@ -1,24 +1,22 @@
# 固定窗口算法
from __future__ import annotations
import asyncio
import time
from .. import algo
# 固定窗口算法
class SessionContainer:
wait_lock: asyncio.Lock
records: dict[int, int]
- """访问记录key为每分钟的起始时间戳value为访问次数"""
+ """访问记录key为每窗口长度的起始时间戳value为访问次数"""
def __init__(self):
self.wait_lock = asyncio.Lock()
self.records = {}
@algo.algo_class("fixwin")
class FixedWindowAlgo(algo.ReteLimitAlgo):
containers_lock: asyncio.Lock
@@ -46,30 +44,34 @@ class FixedWindowAlgo(algo.ReteLimitAlgo):
# 等待锁
async with container.wait_lock:
# 获取窗口大小和限制
window_size = self.ap.pipeline_cfg.data['rate-limit']['fixwin']['default']['window-size']
limitation = self.ap.pipeline_cfg.data['rate-limit']['fixwin']['default']['limit']
if session_name in self.ap.pipeline_cfg.data['rate-limit']['fixwin']:
window_size = self.ap.pipeline_cfg.data['rate-limit']['fixwin'][session_name]['window-size']
limitation = self.ap.pipeline_cfg.data['rate-limit']['fixwin'][session_name]['limit']
# 获取当前时间戳
now = int(time.time())
- # 获取当前分钟的起始时间戳
+ # 获取当前窗口的起始时间戳
- now = now - now % 60
+ now = now - now % window_size
- # 获取当前分钟的访问次数
+ # 获取当前窗口的访问次数
count = container.records.get(now, 0)
limitation = self.ap.pipeline_cfg.data['rate-limit']['fixwin']['default']
if session_name in self.ap.pipeline_cfg.data['rate-limit']['fixwin']:
limitation = self.ap.pipeline_cfg.data['rate-limit']['fixwin'][session_name]
# 如果访问次数超过了限制
if count >= limitation:
if self.ap.pipeline_cfg.data['rate-limit']['strategy'] == 'drop':
return False
elif self.ap.pipeline_cfg.data['rate-limit']['strategy'] == 'wait':
- # 等待下一分钟
+ # 等待下一窗口
- await asyncio.sleep(60 - time.time() % 60)
+ await asyncio.sleep(window_size - time.time() % window_size)
now = int(time.time())
- now = now - now % 60
+ now = now - now % window_size
if now not in container.records:
container.records = {}
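The fixed-window algorithm now reads a per-session window size and limit from the pipeline config, falling back to the default entry. Assumed shape of that section; key names follow the code, the session name and numbers are only examples:

```python
# Assumed pipeline config fragment consumed by FixedWindowAlgo.
rate_limit_cfg = {
    "rate-limit": {
        "strategy": "drop",   # or "wait"
        "algo": "fixwin",     # resolved through the new algo registry
        "fixwin": {
            "default": {"window-size": 60, "limit": 60},
            "group_12345678": {"window-size": 60, "limit": 10},  # hypothetical per-session override
        },
    }
}
```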


@@ -11,11 +11,27 @@ from ...core import entities as core_entities
@stage.stage_class("RequireRateLimitOccupancy") @stage.stage_class("RequireRateLimitOccupancy")
@stage.stage_class("ReleaseRateLimitOccupancy") @stage.stage_class("ReleaseRateLimitOccupancy")
class RateLimit(stage.PipelineStage): class RateLimit(stage.PipelineStage):
"""限速器控制阶段
不改写query只检查是否需要限速。
"""
algo: algo.ReteLimitAlgo
async def initialize(self):
self.algo = fixedwin.FixedWindowAlgo(self.ap)
algo_name = self.ap.pipeline_cfg.data['rate-limit']['algo']
algo_class = None
for algo_cls in algo.preregistered_algos:
if algo_cls.name == algo_name:
algo_class = algo_cls
break
else:
raise ValueError(f'未知的限速算法: {algo_name}')
self.algo = algo_class(self.ap)
await self.algo.initialize()
async def process(
@@ -46,7 +62,7 @@ class RateLimit(stage.PipelineStage):
)
elif stage_inst_name == "ReleaseRateLimitOccupancy":
await self.algo.release_access(
- query.launcher_type,
+ query.launcher_type.value,
query.launcher_id,
)
return entities.StageProcessResult(


@@ -29,9 +29,9 @@ class SendResponseBackStage(stage.PipelineStage):
await asyncio.sleep(random_delay)
- await self.ap.im_mgr.send(
+ await self.ap.platform_mgr.send(
query.message_event,
- query.resp_message_chain,
+ query.resp_message_chain[-1],
adapter=query.adapter
)


@@ -14,26 +14,27 @@ from ...config import manager as cfg_mgr
@stage.stage_class("GroupRespondRuleCheckStage") @stage.stage_class("GroupRespondRuleCheckStage")
class GroupRespondRuleCheckStage(stage.PipelineStage): class GroupRespondRuleCheckStage(stage.PipelineStage):
"""群组响应规则检查器 """群组响应规则检查器
仅检查群消息是否符合规则。
""" """
rule_matchers: list[rule.GroupRespondRule] rule_matchers: list[rule.GroupRespondRule]
"""检查器实例"""
async def initialize(self): async def initialize(self):
"""初始化检查器 """初始化检查器
""" """
self.rule_matchers = [
atbot.AtBotRule(self.ap),
prefix.PrefixRule(self.ap),
regexp.RegExpRule(self.ap),
random.RandomRespRule(self.ap),
]
- for rule_matcher in self.rule_matchers:
+ self.rule_matchers = []
await rule_matcher.initialize()
for rule_matcher in rule.preregisetered_rules:
rule_inst = rule_matcher(self.ap)
await rule_inst.initialize()
self.rule_matchers.append(rule_inst)
async def process(self, query: core_entities.Query, stage_inst_name: str) -> entities.StageProcessResult:
- if query.launcher_type.value != 'group':
+ if query.launcher_type.value != 'group': # 只处理群消息
return entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE,
new_query=query
@@ -43,8 +44,8 @@ class GroupRespondRuleCheckStage(stage.PipelineStage):
use_rule = rules['default']
- if str(query.launcher_id) in use_rule:
+ if str(query.launcher_id) in rules:
- use_rule = use_rule[str(query.launcher_id)]
+ use_rule = rules[str(query.launcher_id)]
for rule_matcher in self.rule_matchers: # 任意一个匹配就放行
res = await rule_matcher.match(str(query.message_chain), query.message_chain, use_rule, query)


@@ -1,5 +1,6 @@
from __future__ import annotations
import abc
import typing
import mirai
@@ -7,9 +8,20 @@ from ...core import app, entities as core_entities
from . import entities
preregisetered_rules: list[typing.Type[GroupRespondRule]] = []
def rule_class(name: str):
def decorator(cls: typing.Type[GroupRespondRule]) -> typing.Type[GroupRespondRule]:
cls.name = name
preregisetered_rules.append(cls)
return cls
return decorator
class GroupRespondRule(metaclass=abc.ABCMeta):
"""群组响应规则的抽象类
"""
name: str
ap: app.Application
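Group respond rules are now discovered through preregisetered_rules, so a custom rule only has to apply rule_class. A hedged sketch; the module paths and keyword check are assumptions, while the match signature follows the call site in this diff:

```python
# Hedged sketch of a custom group respond rule using the new rule registry.
import mirai

from pkg.pipeline.resprule import rule as rule_model  # assumed module paths
from pkg.pipeline.resprule import entities


@rule_model.rule_class("keyword")
class KeywordRule(rule_model.GroupRespondRule):
    async def match(self, message_text: str, message_chain: mirai.MessageChain, rule_dict: dict, query):
        hit = any(kw in message_text for kw in rule_dict.get('keywords', []))
        return entities.RuleJudgeResult(
            matching=hit,
            replacement=message_chain,
        )
```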


@@ -7,6 +7,7 @@ from .. import entities
from ....core import entities as core_entities
@rule_model.rule_class("at-bot")
class AtBotRule(rule_model.GroupRespondRule):
async def match(
@@ -19,6 +20,10 @@ class AtBotRule(rule_model.GroupRespondRule):
if message_chain.has(mirai.At(query.adapter.bot_account_id)) and rule_dict['at']:
message_chain.remove(mirai.At(query.adapter.bot_account_id))
if message_chain.has(mirai.At(query.adapter.bot_account_id)): # 回复消息时会at两次检查并删除重复的
message_chain.remove(mirai.At(query.adapter.bot_account_id))
return entities.RuleJudgeResult(
matching=True,
replacement=message_chain,


@@ -5,6 +5,7 @@ from .. import entities
from ....core import entities as core_entities
@rule_model.rule_class("prefix")
class PrefixRule(rule_model.GroupRespondRule):
async def match(
@@ -19,11 +20,14 @@ class PrefixRule(rule_model.GroupRespondRule):
for prefix in prefixes:
if message_text.startswith(prefix):
# 查找第一个plain元素
for me in message_chain:
if isinstance(me, mirai.Plain):
me.text = me.text[len(prefix):]
return entities.RuleJudgeResult(
matching=True,
- replacement=mirai.MessageChain([
- mirai.Plain(message_text[len(prefix):])
- ]),
+ replacement=message_chain,
)
return entities.RuleJudgeResult(


@@ -7,6 +7,7 @@ from .. import entities
from ....core import entities as core_entities
@rule_model.rule_class("random")
class RandomRespRule(rule_model.GroupRespondRule):
async def match(


@@ -7,6 +7,7 @@ from .. import entities
from ....core import entities as core_entities
@rule_model.rule_class("regexp")
class RegExpRule(rule_model.GroupRespondRule):
async def match(


@@ -13,20 +13,23 @@ from .respback import respback
from .wrapper import wrapper
from .preproc import preproc
from .ratelimit import ratelimit
from .msgtrun import msgtrun
# 请求处理阶段顺序
stage_order = [
- "GroupRespondRuleCheckStage",
- "BanSessionCheckStage",
- "PreContentFilterStage",
- "PreProcessor",
- "RequireRateLimitOccupancy",
- "MessageProcessor",
- "ReleaseRateLimitOccupancy",
- "PostContentFilterStage",
- "ResponseWrapper",
- "LongTextProcessStage",
- "SendResponseBackStage",
+ "GroupRespondRuleCheckStage", # 群响应规则检查
+ "BanSessionCheckStage", # 封禁会话检查
+ "PreContentFilterStage", # 内容过滤前置阶段
+ "PreProcessor", # 预处理器
+ "ConversationMessageTruncator", # 会话消息截断器
+ "RequireRateLimitOccupancy", # 请求速率限制占用
+ "MessageProcessor", # 处理器
+ "ReleaseRateLimitOccupancy", # 释放速率限制占用
+ "PostContentFilterStage", # 内容过滤后置阶段
+ "ResponseWrapper", # 响应包装器
+ "LongTextProcessStage", # 长文本处理
+ "SendResponseBackStage", # 发送响应
]


@@ -14,6 +14,13 @@ from ...plugin import events
@stage.stage_class("ResponseWrapper") @stage.stage_class("ResponseWrapper")
class ResponseWrapper(stage.PipelineStage): class ResponseWrapper(stage.PipelineStage):
"""回复包装阶段
把回复的 message 包装成人类识读的形式。
改写:
- resp_message_chain
"""
async def initialize(self):
pass
@@ -26,67 +33,48 @@ class ResponseWrapper(stage.PipelineStage):
"""处理 """处理
""" """
if query.resp_messages[-1].role == 'command': # 如果 resp_messages[-1] 已经是 MessageChain 了
query.resp_message_chain = mirai.MessageChain("[bot] "+query.resp_messages[-1].content) if isinstance(query.resp_messages[-1], mirai.MessageChain):
query.resp_message_chain.append(query.resp_messages[-1])
yield entities.StageProcessResult( yield entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE, result_type=entities.ResultType.CONTINUE,
new_query=query new_query=query
) )
else: else:
if query.resp_messages[-1].role == 'assistant': if query.resp_messages[-1].role == 'command':
result = query.resp_messages[-1] # query.resp_message_chain.append(mirai.MessageChain("[bot] "+query.resp_messages[-1].content))
session = await self.ap.sess_mgr.get_session(query) query.resp_message_chain.append(query.resp_messages[-1].get_content_mirai_message_chain(prefix_text='[bot] '))
reply_text = '' yield entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE,
new_query=query
)
elif query.resp_messages[-1].role == 'plugin':
# if not isinstance(query.resp_messages[-1].content, mirai.MessageChain):
# query.resp_message_chain.append(mirai.MessageChain(query.resp_messages[-1].content))
# else:
# query.resp_message_chain.append(query.resp_messages[-1].content)
query.resp_message_chain.append(query.resp_messages[-1].get_content_mirai_message_chain())
- if result.content is not None: # 有内容
- reply_text = result.content
+ yield entities.StageProcessResult(
+ result_type=entities.ResultType.CONTINUE,
new_query=query
)
else:
- # ============= 触发插件事件 ===============
- event_ctx = await self.ap.plugin_mgr.emit_event(
- event=events.NormalMessageResponded(
+ if query.resp_messages[-1].role == 'assistant':
+ result = query.resp_messages[-1]
+ session = await self.ap.sess_mgr.get_session(query)
launcher_type=query.launcher_type.value,
launcher_id=query.launcher_id,
sender_id=query.sender_id,
session=session,
prefix='',
response_text=reply_text,
finish_reason='stop',
funcs_called=[fc.function.name for fc in result.tool_calls] if result.tool_calls is not None else [],
query=query
)
)
if event_ctx.is_prevented_default():
yield entities.StageProcessResult(
result_type=entities.ResultType.INTERRUPT,
new_query=query
)
else:
if event_ctx.event.reply is not None:
- query.resp_message_chain = mirai.MessageChain(event_ctx.event.reply)
- else:
+ reply_text = ''
+ if result.content: # 有内容
reply_text = str(result.get_content_mirai_message_chain())
query.resp_message_chain = mirai.MessageChain([mirai.Plain(reply_text)])
yield entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE,
new_query=query
)
if result.tool_calls is not None: # 有函数调用
function_names = [tc.function.name for tc in result.tool_calls]
reply_text = f'调用函数 {".".join(function_names)}...'
query.resp_message_chain = mirai.MessageChain([mirai.Plain(reply_text)])
if self.ap.platform_cfg.data['track-function-calls']:
# ============= 触发插件事件 ===============
event_ctx = await self.ap.plugin_mgr.emit_event(
event=events.NormalMessageResponded(
launcher_type=query.launcher_type.value,
@@ -100,7 +88,6 @@ class ResponseWrapper(stage.PipelineStage):
query=query
)
)
if event_ctx.is_prevented_default():
yield entities.StageProcessResult(
result_type=entities.ResultType.INTERRUPT,
@@ -109,13 +96,56 @@ class ResponseWrapper(stage.PipelineStage):
else:
if event_ctx.event.reply is not None:
- query.resp_message_chain = mirai.MessageChain(event_ctx.event.reply)
+ query.resp_message_chain.append(mirai.MessageChain(event_ctx.event.reply))
else:
- query.resp_message_chain = mirai.MessageChain([mirai.Plain(reply_text)])
+ query.resp_message_chain.append(result.get_content_mirai_message_chain())
yield entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE,
new_query=query
)
if result.tool_calls is not None: # 有函数调用
function_names = [tc.function.name for tc in result.tool_calls]
reply_text = f'调用函数 {".".join(function_names)}...'
query.resp_message_chain.append(mirai.MessageChain([mirai.Plain(reply_text)]))
if self.ap.platform_cfg.data['track-function-calls']:
event_ctx = await self.ap.plugin_mgr.emit_event(
event=events.NormalMessageResponded(
launcher_type=query.launcher_type.value,
launcher_id=query.launcher_id,
sender_id=query.sender_id,
session=session,
prefix='',
response_text=reply_text,
finish_reason='stop',
funcs_called=[fc.function.name for fc in result.tool_calls] if result.tool_calls is not None else [],
query=query
)
)
if event_ctx.is_prevented_default():
yield entities.StageProcessResult(
result_type=entities.ResultType.INTERRUPT,
new_query=query
)
else:
if event_ctx.event.reply is not None:
query.resp_message_chain.append(mirai.MessageChain(event_ctx.event.reply))
else:
query.resp_message_chain.append(mirai.MessageChain([mirai.Plain(reply_text)]))
yield entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE,
new_query=query
)


@@ -14,6 +14,14 @@ preregistered_adapters: list[typing.Type[MessageSourceAdapter]] = []
def adapter_class(
name: str
):
"""消息平台适配器类装饰器
Args:
name (str): 适配器名称
Returns:
typing.Callable[[typing.Type[MessageSourceAdapter]], typing.Type[MessageSourceAdapter]]: 装饰器
"""
def decorator(cls: typing.Type[MessageSourceAdapter]) -> typing.Type[MessageSourceAdapter]:
cls.name = name
preregistered_adapters.append(cls)
@@ -22,15 +30,24 @@ def adapter_class(
class MessageSourceAdapter(metaclass=abc.ABCMeta):
"""消息平台适配器基类"""
name: str
bot_account_id: int
"""机器人账号ID需要在初始化时设置"""
config: dict
ap: app.Application
def __init__(self, config: dict, ap: app.Application):
"""初始化适配器
Args:
config (dict): 对应的配置
ap (app.Application): 应用上下文
"""
self.config = config
self.ap = ap
@@ -40,7 +57,7 @@ class MessageSourceAdapter(metaclass=abc.ABCMeta):
target_id: str,
message: mirai.MessageChain
):
- """发送消息
+ """主动发送消息
Args:
target_type (str): 目标类型,`person`或`group`
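adapter_class registers a message platform adapter under a name, mirroring the rate-limit and respond-rule registries. A minimal skeleton; the method set is inferred from the adapters in this diff and the bodies are stubs:

```python
# Hedged skeleton of a custom message platform adapter.
import mirai

from pkg.platform import adapter as msadapter  # assumed module path


@msadapter.adapter_class("my-platform")
class MyPlatformAdapter(msadapter.MessageSourceAdapter):
    async def send_message(self, target_type: str, target_id: str, message: mirai.MessageChain):
        ...  # proactively push a message to a `person` or `group` target

    async def reply_message(self, message_source: mirai.MessageEvent, message: mirai.MessageChain, quote_origin: bool = False):
        ...  # reply to the event that triggered the pipeline

    async def run_async(self):
        ...  # keep the connection to the platform alive

    def kill(self) -> bool:
        return False
```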


@@ -82,8 +82,8 @@ class PlatformManager:
event_ctx = await self.ap.plugin_mgr.emit_event(
event=events.GroupMessageReceived(
- launcher_type='person',
+ launcher_type='group',
- launcher_id=event.sender.id,
+ launcher_id=event.group.id,
sender_id=event.sender.id,
message_chain=event.message_chain,
query=None
@@ -146,9 +146,9 @@ class PlatformManager:
if len(self.adapters) == 0:
self.ap.logger.warning('未运行平台适配器,请根据文档配置并启用平台适配器。')
- async def send(self, event, msg, adapter: msadapter.MessageSourceAdapter, check_quote=True, check_at_sender=True):
+ async def send(self, event: mirai.MessageEvent, msg: mirai.MessageChain, adapter: msadapter.MessageSourceAdapter):
- if check_at_sender and self.ap.platform_cfg.data['at-sender'] and isinstance(event, GroupMessage):
+ if self.ap.platform_cfg.data['at-sender'] and isinstance(event, GroupMessage):
msg.insert(
0,
@@ -160,28 +160,9 @@ class PlatformManager:
await adapter.reply_message(
event,
msg,
- quote_origin=True if self.ap.platform_cfg.data['quote-origin'] and check_quote else False
+ quote_origin=True if self.ap.platform_cfg.data['quote-origin'] else False
)
# 通知系统管理员
# TODO delete
# async def notify_admin(self, message: str):
# await self.notify_admin_message_chain(MessageChain([Plain("[bot]{}".format(message))]))
# async def notify_admin_message_chain(self, message: mirai.MessageChain):
# if self.ap.system_cfg.data['admin-sessions'] != []:
# admin_list = []
# for admin in self.ap.system_cfg.data['admin-sessions']:
# admin_list.append(admin)
# for adm in admin_list:
# self.adapter.send_message(
# adm.split("_")[0],
# adm.split("_")[1],
# message
# )
async def run(self):
try:
tasks = []


@@ -30,7 +30,16 @@ class AiocqhttpMessageConverter(adapter.MessageConverter):
msg_id = msg.id
msg_time = msg.time
elif type(msg) is mirai.Image:
- msg_list.append(aiocqhttp.MessageSegment.image(msg.path))
+ arg = ''
if msg.base64:
arg = msg.base64
msg_list.append(aiocqhttp.MessageSegment.image(f"base64://{arg}"))
elif msg.url:
arg = msg.url
msg_list.append(aiocqhttp.MessageSegment.image(arg))
elif msg.path:
arg = msg.path
msg_list.append(aiocqhttp.MessageSegment.image(arg))
elif type(msg) is mirai.At:
msg_list.append(aiocqhttp.MessageSegment.at(msg.target))
elif type(msg) is mirai.AtAll:
@@ -38,9 +47,17 @@ class AiocqhttpMessageConverter(adapter.MessageConverter):
elif type(msg) is mirai.Face:
msg_list.append(aiocqhttp.MessageSegment.face(msg.face_id))
elif type(msg) is mirai.Voice:
- msg_list.append(aiocqhttp.MessageSegment.record(msg.path))
+ arg = ''
if msg.base64:
arg = msg.base64
msg_list.append(aiocqhttp.MessageSegment.record(f"base64://{arg}"))
elif msg.url:
arg = msg.url
msg_list.append(aiocqhttp.MessageSegment.record(arg))
elif msg.path:
arg = msg.path
msg_list.append(aiocqhttp.MessageSegment.record(msg.path))
elif type(msg) is forward.Forward:
# print("aiocqhttp 暂不支持转发消息组件的转换,使用普通消息链发送")
for node in msg.node_list:
msg_list.extend(AiocqhttpMessageConverter.yiri2target(node.message_chain)[0])
@@ -170,7 +187,7 @@ class AiocqhttpEventConverter(adapter.EventConverter):
name=event.sender["nickname"], name=event.sender["nickname"],
permission=mirai.models.entities.Permission.Member, permission=mirai.models.entities.Permission.Member,
), ),
special_title=event.sender["title"], special_title=event.sender["title"] if "title" in event.sender else "",
join_timestamp=0, join_timestamp=0,
last_speak_timestamp=0, last_speak_timestamp=0,
mute_time_remaining=0, mute_time_remaining=0,
@@ -216,13 +233,21 @@ class AiocqhttpAdapter(adapter.MessageSourceAdapter):
self.ap = ap
- self.bot = aiocqhttp.CQHttp()
+ if "access-token" in config:
self.bot = aiocqhttp.CQHttp(access_token=config["access-token"])
del self.config["access-token"]
else:
self.bot = aiocqhttp.CQHttp()
async def send_message(
self, target_type: str, target_id: str, message: mirai.MessageChain
):
- # TODO 实现发送消息
+ aiocq_msg = AiocqhttpMessageConverter.yiri2target(message)[0]
return super().send_message(target_type, target_id, message)
if target_type == "group":
await self.bot.send_group_msg(group_id=int(target_id), message=aiocq_msg)
elif target_type == "person":
await self.bot.send_private_msg(user_id=int(target_id), message=aiocq_msg)
async def reply_message(
self,
@@ -230,7 +255,6 @@ class AiocqhttpAdapter(adapter.MessageSourceAdapter):
message: mirai.MessageChain,
quote_origin: bool = False,
):
aiocq_event = AiocqhttpEventConverter.yiri2target(message_source, self.bot_account_id)
aiocq_msg = AiocqhttpMessageConverter.yiri2target(message)[0]
if quote_origin:


@@ -24,6 +24,8 @@ class NakuruProjectMessageConverter(adapter_model.MessageConverter):
msg_list = message_chain.__root__
elif type(message_chain) is list:
msg_list = message_chain
elif type(message_chain) is str:
msg_list = [mirai.Plain(message_chain)]
else:
raise Exception("Unknown message type: " + str(message_chain) + str(type(message_chain)))
@@ -320,7 +322,7 @@ class NakuruProjectAdapter(adapter_model.MessageSourceAdapter):
proxies=None
)
if resp.status_code == 403:
- raise Exception("go-cqhttp拒绝访问请检查config.py中nakuru_config的token是否与go-cqhttp设置的access-token匹配")
+ raise Exception("go-cqhttp拒绝访问请检查配置文件中nakuru适配器的配置")
self.bot_account_id = int(resp.json()['data']['user_id'])
except Exception as e:
raise Exception("获取go-cqhttp账号信息失败, 请检查是否已启动go-cqhttp并配置正确")


@@ -3,29 +3,31 @@ from __future__ import annotations
import logging
import typing
import datetime
import asyncio
import re
import traceback
import json
import threading
import mirai
import botpy
import botpy.message as botpy_message
import botpy.types.message as botpy_message_type
import pydantic
import pydantic.networks
from .. import adapter as adapter_model
from ...pipeline.longtext.strategies import forward
from ...core import app
from ...config import manager as cfg_mgr
class OfficialGroupMessage(mirai.GroupMessage):
pass
class OfficialFriendMessage(mirai.FriendMessage):
pass
event_handler_mapping = {
mirai.GroupMessage: ["on_at_message_create", "on_group_at_message_create"],
- mirai.FriendMessage: ["on_direct_message_create"],
+ mirai.FriendMessage: ["on_direct_message_create", "on_c2c_message_create"],
}
@@ -34,6 +36,7 @@ cached_message_ids = {}
id_index = 0
def save_msg_id(message_id: str) -> int:
"""保存消息id"""
global id_index, cached_message_ids
@@ -43,43 +46,82 @@ def save_msg_id(message_id: str) -> int:
cached_message_ids[str(crt_index)] = message_id
return crt_index
- cached_member_openids = {}
- """QQ官方 用户的id是字符串而YiriMirai的用户id是整数所以需要一个索引来进行转换"""
- member_openid_index = 100
- def save_member_openid(member_openid: str) -> int:
- """保存用户id"""
- global member_openid_index, cached_member_openids
- if member_openid in cached_member_openids.values():
- return list(cached_member_openids.keys())[list(cached_member_openids.values()).index(member_openid)]
- crt_index = member_openid_index
- member_openid_index += 1
- cached_member_openids[str(crt_index)] = member_openid
- return crt_index
- cached_group_openids = {}
- """QQ官方 群组的id是字符串而YiriMirai的群组id是整数所以需要一个索引来进行转换"""
- group_openid_index = 1000
- def save_group_openid(group_openid: str) -> int:
- """保存群组id"""
- global group_openid_index, cached_group_openids
- if group_openid in cached_group_openids.values():
- return list(cached_group_openids.keys())[list(cached_group_openids.values()).index(group_openid)]
- crt_index = group_openid_index
- group_openid_index += 1
- cached_group_openids[str(crt_index)] = group_openid
- return crt_index
+ def char_to_value(char):
+ """将单个字符转换为相应的数值。"""
+ if '0' <= char <= '9':
+ return ord(char) - ord('0')
+ elif 'A' <= char <= 'Z':
+ return ord(char) - ord('A') + 10
+ return ord(char) - ord('a') + 36
+ def digest(s: str) -> int:
+ """计算字符串的hash值。"""
+ # 取末尾的8位
+ sub_s = s[-10:]
+ number = 0
+ base = 36
+ for i in range(len(sub_s)):
+ number = number * base + char_to_value(sub_s[i])
+ return number
+ K = typing.TypeVar("K")
+ V = typing.TypeVar("V")
+ class OpenIDMapping(typing.Generic[K, V]):
+ map: dict[K, V]
dump_func: typing.Callable
digest_func: typing.Callable[[K], V]
def __init__(self, map: dict[K, V], dump_func: typing.Callable, digest_func: typing.Callable[[K], V] = digest):
self.map = map
self.dump_func = dump_func
self.digest_func = digest_func
def __getitem__(self, key: K) -> V:
return self.map[key]
def __setitem__(self, key: K, value: V):
self.map[key] = value
self.dump_func()
def __contains__(self, key: K) -> bool:
return key in self.map
def __delitem__(self, key: K):
del self.map[key]
self.dump_func()
def getkey(self, value: V) -> K:
return list(self.map.keys())[list(self.map.values()).index(value)]
def save_openid(self, key: K) -> V:
if key in self.map:
return self.map[key]
value = self.digest_func(key)
self.map[key] = value
self.dump_func()
return value
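digest() turns the tail of an open-id string into a stable integer, and OpenIDMapping persists the result through dump_func so the same open-id always maps to the same numeric id. A condensed, behaviour-equivalent walk-through of the hashing part:

```python
# Condensed copy of the open-id digest logic above, for illustration only.
def char_to_value(char):
    if '0' <= char <= '9':
        return ord(char) - ord('0')
    elif 'A' <= char <= 'Z':
        return ord(char) - ord('A') + 10
    return ord(char) - ord('a') + 36

def digest(s: str) -> int:
    number = 0
    for c in s[-10:]:               # only the trailing characters are folded in
        number = number * 36 + char_to_value(c)
    return number

print(digest("A1B2C3D4E5F6G7H8"))   # same open-id in -> same integer out
```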
class OfficialMessageConverter(adapter_model.MessageConverter):
"""QQ 官方消息转换器"""
@staticmethod
def yiri2target(message_chain: mirai.MessageChain):
"""将 YiriMirai 的消息链转换为 QQ 官方消息"""
@@ -89,8 +131,12 @@ class OfficialMessageConverter(adapter_model.MessageConverter):
msg_list = message_chain.__root__
elif type(message_chain) is list:
msg_list = message_chain
elif type(message_chain) is str:
msg_list = [mirai.Plain(text=message_chain)]
else:
- raise Exception("Unknown message type: " + str(message_chain) + str(type(message_chain)))
+ raise Exception(
"Unknown message type: " + str(message_chain) + str(type(message_chain))
)
offcial_messages: list[dict] = []
"""
@@ -108,36 +154,24 @@ class OfficialMessageConverter(adapter_model.MessageConverter):
# 遍历并转换
for component in msg_list:
if type(component) is mirai.Plain:
- offcial_messages.append({
- "type": "text",
- "content": component.text
- })
+ offcial_messages.append({"type": "text", "content": component.text})
elif type(component) is mirai.Image:
if component.url is not None:
- offcial_messages.append(
- {
- "type": "image",
- "content": component.url
- }
- )
+ offcial_messages.append({"type": "image", "content": component.url})
elif component.path is not None:
offcial_messages.append(
- {
- "type": "file_image",
- "content": component.path
- }
+ {"type": "file_image", "content": component.path}
)
elif type(component) is mirai.At:
- offcial_messages.append(
- {
- "type": "at",
- "content": ""
- }
- )
+ offcial_messages.append({"type": "at", "content": ""})
elif type(component) is mirai.AtAll:
- print("上层组件要求发送 AtAll 消息,但 QQ 官方 API 不支持此消息类型,忽略此消息。")
+ print(
+ "上层组件要求发送 AtAll 消息,但 QQ 官方 API 不支持此消息类型,忽略此消息。"
+ )
elif type(component) is mirai.Voice:
- print("上层组件要求发送 Voice 消息,但 QQ 官方 API 不支持此消息类型,忽略此消息。")
+ print(
+ "上层组件要求发送 Voice 消息,但 QQ 官方 API 不支持此消息类型,忽略此消息。"
+ )
elif type(component) is forward.Forward:
# 转发消息
yiri_forward_node_list = component.node_list
@@ -148,22 +182,32 @@ class OfficialMessageConverter(adapter_model.MessageConverter):
message_chain = yiri_forward_node.message_chain
# 平铺
- offcial_messages.extend(OfficialMessageConverter.yiri2target(message_chain))
+ offcial_messages.extend(
+ OfficialMessageConverter.yiri2target(message_chain)
+ )
except Exception as e:
import traceback
traceback.print_exc()
return offcial_messages
@staticmethod
- def extract_message_chain_from_obj(message: typing.Union[botpy_message.Message, botpy_message.DirectMessage], message_id: str = None, bot_account_id: int = 0) -> mirai.MessageChain:
+ def extract_message_chain_from_obj(
message: typing.Union[botpy_message.Message, botpy_message.DirectMessage, botpy_message.GroupMessage, botpy_message.C2CMessage],
message_id: str = None,
bot_account_id: int = 0,
) -> mirai.MessageChain:
yiri_msg_list = []
# 存id
- yiri_msg_list.append(mirai.models.message.Source(id=save_msg_id(message_id), time=datetime.datetime.now()))
+ yiri_msg_list.append(
mirai.models.message.Source(
id=save_msg_id(message_id), time=datetime.datetime.now()
)
)
- if type(message) is not botpy_message.DirectMessage:
+ if type(message) not in [botpy_message.DirectMessage, botpy_message.C2CMessage]:
yiri_msg_list.append(mirai.At(target=bot_account_id))
if hasattr(message, "mentions"):
@@ -174,10 +218,12 @@ class OfficialMessageConverter(adapter_model.MessageConverter):
yiri_msg_list.append(mirai.At(target=mention.id))
for attachment in message.attachments:
- if attachment.content_type == "image":
+ if attachment.content_type.startswith("image"):
yiri_msg_list.append(mirai.Image(url=attachment.url))
else:
- logging.warning("不支持的附件类型:" + attachment.content_type + ",忽略此附件。")
+ logging.warning(
"不支持的附件类型:" + attachment.content_type + ",忽略此附件。"
)
content = re.sub(r"<@!\d+>", "", str(message.content)) content = re.sub(r"<@!\d+>", "", str(message.content))
if content.strip() != "": if content.strip() != "":
@@ -190,25 +236,36 @@ class OfficialMessageConverter(adapter_model.MessageConverter):
class OfficialEventConverter(adapter_model.EventConverter):
"""事件转换器"""
- @staticmethod
- def yiri2target(event: typing.Type[mirai.Event]):
+ member_openid_mapping: OpenIDMapping[str, int]
group_openid_mapping: OpenIDMapping[str, int]
def __init__(self, member_openid_mapping: OpenIDMapping[str, int], group_openid_mapping: OpenIDMapping[str, int]):
self.member_openid_mapping = member_openid_mapping
self.group_openid_mapping = group_openid_mapping
def yiri2target(self, event: typing.Type[mirai.Event]):
if event == mirai.GroupMessage:
return botpy_message.Message
elif event == mirai.FriendMessage:
return botpy_message.DirectMessage
else:
- raise Exception("未支持转换的事件类型(YiriMirai -> Official): " + str(event))
+ raise Exception(
"未支持转换的事件类型(YiriMirai -> Official): " + str(event)
)
- @staticmethod
- def target2yiri(event: typing.Union[botpy_message.Message, botpy_message.DirectMessage]) -> mirai.Event:
+ def target2yiri(
+ self,
event: typing.Union[botpy_message.Message, botpy_message.DirectMessage, botpy_message.GroupMessage, botpy_message.C2CMessage],
) -> mirai.Event:
import mirai.models.entities as mirai_entities
if type(event) == botpy_message.Message: # 频道内,转群聊事件
permission = "MEMBER"
- if '2' in event.member.roles:
+ if "2" in event.member.roles:
permission = "ADMINISTRATOR"
- elif '4' in event.member.roles:
+ elif "4" in event.member.roles:
permission = "OWNER"
return mirai.GroupMessage(
@@ -219,29 +276,45 @@ class OfficialEventConverter(adapter_model.EventConverter):
group=mirai_entities.Group(
id=event.channel_id,
name=event.author.username,
- permission=mirai_entities.Permission.Member
+ permission=mirai_entities.Permission.Member,
),
special_title="",
join_timestamp=int(
datetime.datetime.strptime(
event.member.joined_at, "%Y-%m-%dT%H:%M:%S%z"
).timestamp()
),
- special_title='',
- join_timestamp=int(datetime.datetime.strptime(event.member.joined_at, "%Y-%m-%dT%H:%M:%S%z").timestamp()),
last_speak_timestamp=datetime.datetime.now().timestamp(),
mute_time_remaining=0,
),
- message_chain=OfficialMessageConverter.extract_message_chain_from_obj(event, event.id),
- time=int(datetime.datetime.strptime(event.timestamp, "%Y-%m-%dT%H:%M:%S%z").timestamp()),
+ message_chain=OfficialMessageConverter.extract_message_chain_from_obj(
+ event, event.id
),
time=int(
datetime.datetime.strptime(
event.timestamp, "%Y-%m-%dT%H:%M:%S%z"
).timestamp()
),
)
- elif type(event) == botpy_message.DirectMessage: # 私聊,转私聊事件
+ elif type(event) == botpy_message.DirectMessage: # 频道私聊,转私聊事件
return mirai.FriendMessage(
sender=mirai_entities.Friend(
id=event.guild_id,
nickname=event.author.username,
remark=event.author.username,
),
- message_chain=OfficialMessageConverter.extract_message_chain_from_obj(event, event.id),
- time=int(datetime.datetime.strptime(event.timestamp, "%Y-%m-%dT%H:%M:%S%z").timestamp()),
+ message_chain=OfficialMessageConverter.extract_message_chain_from_obj(
+ event, event.id
),
time=int(
datetime.datetime.strptime(
event.timestamp, "%Y-%m-%dT%H:%M:%S%z"
).timestamp()
),
)
- elif type(event) == botpy_message.GroupMessage:
+ elif type(event) == botpy_message.GroupMessage: # 群聊,转群聊事件
- replacing_member_id = save_member_openid(event.author.member_openid)
+ replacing_member_id = self.member_openid_mapping.save_openid(event.author.member_openid)
return OfficialGroupMessage(
sender=mirai_entities.GroupMember(
@@ -249,29 +322,55 @@ class OfficialEventConverter(adapter_model.EventConverter):
member_name=replacing_member_id,
permission="MEMBER",
group=mirai_entities.Group(
- id=save_group_openid(event.group_openid),
+ id=self.group_openid_mapping.save_openid(event.group_openid),
name=replacing_member_id,
- permission=mirai_entities.Permission.Member
+ permission=mirai_entities.Permission.Member,
),
- special_title='',
+ special_title="",
join_timestamp=int(0),
last_speak_timestamp=datetime.datetime.now().timestamp(),
mute_time_remaining=0,
),
- message_chain=OfficialMessageConverter.extract_message_chain_from_obj(event, event.id),
- time=int(datetime.datetime.strptime(event.timestamp, "%Y-%m-%dT%H:%M:%S%z").timestamp()),
+ message_chain=OfficialMessageConverter.extract_message_chain_from_obj(
+ event, event.id
),
time=int(
datetime.datetime.strptime(
event.timestamp, "%Y-%m-%dT%H:%M:%S%z"
).timestamp()
),
)
elif type(event) == botpy_message.C2CMessage: # 私聊,转私聊事件
user_id_alter = self.member_openid_mapping.save_openid(event.author.user_openid) # 实测这里的user_openid与group的member_openid是一样的
return OfficialFriendMessage(
sender=mirai_entities.Friend(
id=user_id_alter,
nickname=user_id_alter,
remark=user_id_alter,
),
message_chain=OfficialMessageConverter.extract_message_chain_from_obj(
event, event.id
),
time=int(
datetime.datetime.strptime(
event.timestamp, "%Y-%m-%dT%H:%M:%S%z"
).timestamp()
),
) )
@adapter_model.adapter_class("qq-botpy") @adapter_model.adapter_class("qq-botpy")
class OfficialAdapter(adapter_model.MessageSourceAdapter): class OfficialAdapter(adapter_model.MessageSourceAdapter):
"""QQ 官方消息适配器""" """QQ 官方消息适配器"""
bot: botpy.Client = None bot: botpy.Client = None
bot_account_id: int = 0 bot_account_id: int = 0
message_converter: OfficialMessageConverter = OfficialMessageConverter() message_converter: OfficialMessageConverter
# event_handler: adapter_model.EventHandler = adapter_model.EventHandler() event_converter: OfficialEventConverter
cfg: dict = None cfg: dict = None
@@ -283,78 +382,152 @@ class OfficialAdapter(adapter_model.MessageSourceAdapter):
ap: app.Application
metadata: cfg_mgr.ConfigManager = None
member_openid_mapping: OpenIDMapping[str, int] = None
group_openid_mapping: OpenIDMapping[str, int] = None
group_msg_seq = None
c2c_msg_seq = None
def __init__(self, cfg: dict, ap: app.Application):
"""初始化适配器"""
self.cfg = cfg
self.ap = ap
self.group_msg_seq = 1
self.c2c_msg_seq = 1
switchs = {}
- for intent in cfg['intents']:
+ for intent in cfg["intents"]:
switchs[intent] = True
- del cfg['intents']
+ del cfg["intents"]
intents = botpy.Intents(**switchs)
self.bot = botpy.Client(intents=intents)
async def send_message(
- self,
- target_type: str,
- target_id: str,
- message: mirai.MessageChain
+ self, target_type: str, target_id: str, message: mirai.MessageChain
):
- pass
+ message_list = self.message_converter.yiri2target(message)
for msg in message_list:
args = {}
if msg["type"] == "text":
args["content"] = msg["content"]
elif msg["type"] == "image":
args["image"] = msg["content"]
elif msg["type"] == "file_image":
args["file_image"] = msg["content"]
else:
continue
if target_type == "group":
args["channel_id"] = str(target_id)
await self.bot.api.post_message(**args)
elif target_type == "person":
args["guild_id"] = str(target_id)
await self.bot.api.post_dms(**args)
async def reply_message(
self,
message_source: mirai.MessageEvent,
message: mirai.MessageChain,
- quote_origin: bool = False
+ quote_origin: bool = False,
):
message_list = self.message_converter.yiri2target(message)
tasks = []
- msg_seq = 1
+ message_list = self.message_converter.yiri2target(message)
for msg in message_list:
args = {}
- if msg['type'] == 'text':
+ if msg["type"] == "text":
- args['content'] = msg['content']
+ args["content"] = msg["content"]
- elif msg['type'] == 'image':
+ elif msg["type"] == "image":
- args['image'] = msg['content']
+ args["image"] = msg["content"]
- elif msg['type'] == 'file_image':
+ elif msg["type"] == "file_image":
- args['file_image'] = msg["content"]
+ args["file_image"] = msg["content"]
else:
continue
if quote_origin:
- args['message_reference'] = botpy_message_type.Reference(message_id=cached_message_ids[str(message_source.message_chain.message_id)])
+ args["message_reference"] = botpy_message_type.Reference(
message_id=cached_message_ids[
- if type(message_source) == mirai.GroupMessage:
- args['channel_id'] = str(message_source.sender.group.id)
+ str(message_source.message_chain.message_id)
+ ]
args['msg_id'] = cached_message_ids[str(message_source.message_chain.message_id)]
await self.bot.api.post_message(**args)
elif type(message_source) == mirai.FriendMessage:
args['guild_id'] = str(message_source.sender.id)
args['msg_id'] = cached_message_ids[str(message_source.message_chain.message_id)]
await self.bot.api.post_dms(**args)
elif type(message_source) == OfficialGroupMessage:
# args['guild_id'] = str(message_source.sender.group.id)
# args['msg_id'] = cached_message_ids[str(message_source.message_chain.message_id)]
# await self.bot.api.post_message(**args)
if 'image' in args or 'file_image' in args:
continue
args['group_openid'] = cached_group_openids[str(message_source.sender.group.id)]
args['msg_id'] = cached_message_ids[str(message_source.message_chain.message_id)]
args['msg_seq'] = msg_seq
msg_seq += 1
await self.bot.api.post_group_message(
**args
)
if type(message_source) == mirai.GroupMessage:
args["channel_id"] = str(message_source.sender.group.id)
args["msg_id"] = cached_message_ids[
str(message_source.message_chain.message_id)
]
await self.bot.api.post_message(**args)
elif type(message_source) == mirai.FriendMessage:
args["guild_id"] = str(message_source.sender.id)
args["msg_id"] = cached_message_ids[
str(message_source.message_chain.message_id)
]
await self.bot.api.post_dms(**args)
elif type(message_source) == OfficialGroupMessage:
if "file_image" in args: # 暂不支持发送文件图片
continue
args["group_openid"] = self.group_openid_mapping.getkey(
message_source.sender.group.id
)
if "image" in args:
uploadMedia = await self.bot.api.post_group_file(
group_openid=args["group_openid"],
file_type=1,
url=str(args['image'])
)
del args['image']
args['media'] = uploadMedia
args['msg_type'] = 7
args["msg_id"] = cached_message_ids[
str(message_source.message_chain.message_id)
]
args["msg_seq"] = self.group_msg_seq
self.group_msg_seq += 1
await self.bot.api.post_group_message(**args)
elif type(message_source) == OfficialFriendMessage:
if "file_image" in args:
continue
args["openid"] = self.member_openid_mapping.getkey(
message_source.sender.id
)
if "image" in args:
uploadMedia = await self.bot.api.post_c2c_file(
openid=args["openid"],
file_type=1,
url=str(args['image'])
)
del args['image']
args['media'] = uploadMedia
args['msg_type'] = 7
args["msg_id"] = cached_message_ids[
str(message_source.message_chain.message_id)
]
args["msg_seq"] = self.c2c_msg_seq
self.c2c_msg_seq += 1
await self.bot.api.post_c2c_message(**args)
async def is_muted(self, group_id: int) -> bool:
return False
@@ -362,14 +535,22 @@ class OfficialAdapter(adapter_model.MessageSourceAdapter):
def register_listener(
self,
event_type: typing.Type[mirai.Event],
- callback: typing.Callable[[mirai.Event, adapter_model.MessageSourceAdapter], None]
+ callback: typing.Callable[
[mirai.Event, adapter_model.MessageSourceAdapter], None
],
):
try:
- async def wrapper(message: typing.Union[botpy_message.Message, botpy_message.DirectMessage, botpy_message.GroupMessage]):
+ async def wrapper(
message: typing.Union[
botpy_message.Message,
botpy_message.DirectMessage,
botpy_message.GroupMessage,
]
):
self.cached_official_messages[str(message.id)] = message
- await callback(OfficialEventConverter.target2yiri(message), self)
+ await callback(self.event_converter.target2yiri(message), self)
for event_handler in event_handler_mapping[event_type]:
setattr(self.bot, event_handler, wrapper)
@@ -380,15 +561,33 @@ class OfficialAdapter(adapter_model.MessageSourceAdapter):
def unregister_listener(
self,
event_type: typing.Type[mirai.Event],
- callback: typing.Callable[[mirai.Event, adapter_model.MessageSourceAdapter], None]
+ callback: typing.Callable[
[mirai.Event, adapter_model.MessageSourceAdapter], None
],
):
delattr(self.bot, event_handler_mapping[event_type])
async def run_async(self):
self.ap.logger.info("运行 QQ 官方适配器")
- await self.bot.start(
- **self.cfg
+ self.metadata = self.ap.adapter_qq_botpy_meta
self.member_openid_mapping = OpenIDMapping(
map=self.metadata.data["mapping"]["members"],
dump_func=self.metadata.dump_config_sync,
)
self.group_openid_mapping = OpenIDMapping(
map=self.metadata.data["mapping"]["groups"],
dump_func=self.metadata.dump_config_sync,
)
self.message_converter = OfficialMessageConverter()
self.event_converter = OfficialEventConverter(
self.member_openid_mapping, self.group_openid_mapping
)
self.ap.logger.info("运行 QQ 官方适配器")
await self.bot.start(**self.cfg)
def kill(self) -> bool:
return False


@@ -3,16 +3,93 @@ from __future__ import annotations
import typing
import abc
import pydantic
import mirai
from . import events
from ..provider.tools import entities as tools_entities
from ..core import app
def register(
name: str,
description: str,
version: str,
author: str
) -> typing.Callable[[typing.Type[BasePlugin]], typing.Type[BasePlugin]]:
"""注册插件类
使用示例:
@register(
name="插件名称",
description="插件描述",
version="插件版本",
author="插件作者"
)
class MyPlugin(BasePlugin):
pass
"""
pass
def handler(
event: typing.Type[events.BaseEventModel]
) -> typing.Callable[[typing.Callable], typing.Callable]:
"""注册事件监听器
使用示例:
class MyPlugin(BasePlugin):
@handler(NormalMessageResponded)
async def on_normal_message_responded(self, ctx: EventContext):
pass
"""
pass
def llm_func(
    name: str = None,
) -> typing.Callable:
    """Register a content function.

    Usage example:

        class MyPlugin(BasePlugin):

            @llm_func("access_the_web_page")
            async def _(self, query, url: str, brief_len: int):
                \"""Call this function to search about the question before you answer any questions.
                - Do not search through google.com at any time.
                - If you need to search something, visit https://www.sogou.com/web?query=<something>.
                - If the user asks you to open a url (starting with http:// or https://), visit it directly.
                - Summarize the plain content result by yourself; DO NOT directly output anything in the result you got.
                Args:
                    url(str): url to visit
                    brief_len(int): max length of the plain text content, recommend 1024-4096, prefer 4096
                Returns:
                    str: plain text content of the web page or error message (starts with 'error:')
                \"""
    """
    pass
class BasePlugin(metaclass=abc.ABCMeta):
    """Plugin base class"""

    host: APIHost
    """The API host"""

    ap: app.Application
    """The application object"""

    def __init__(self, host: APIHost):
        self.host = host

    async def initialize(self):
        """Initialize the plugin"""
        pass


class APIHost:
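The decorators and BasePlugin above make up the registration surface a plugin sees. A minimal sketch of a plugin using all three follows; it assumes register, handler and llm_func are exported from pkg.plugin.context alongside BasePlugin, and that NormalMessageResponded (named in the handler docstring) lives in pkg.plugin.events. The plugin name and function bodies are illustrative, not part of this changeset:

from pkg.plugin.context import register, handler, llm_func, BasePlugin, EventContext, APIHost
from pkg.plugin.events import NormalMessageResponded


@register(name="HelloPlugin", description="Demo plugin", version="0.1", author="example")
class HelloPlugin(BasePlugin):

    def __init__(self, host: APIHost):
        super().__init__(host)

    async def initialize(self):
        # load any resources the plugin needs here
        pass

    @handler(NormalMessageResponded)
    async def on_responded(self, ctx: EventContext):
        # inspect or post-process the model's response here
        pass

    @llm_func("current_time")
    async def current_time(self, query):
        """Return the current local time as an ISO 8601 string."""
        import datetime
        return datetime.datetime.now().isoformat()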
@@ -61,8 +138,10 @@ class EventContext:
"""事件编号""" """事件编号"""
host: APIHost = None host: APIHost = None
"""API宿主"""
event: events.BaseEventModel = None event: events.BaseEventModel = None
"""此次事件的对象具体类型为handler注册时指定监听的类型可查看events.py中的定义"""
__prevent_default__ = False __prevent_default__ = False
"""是否阻止默认行为""" """是否阻止默认行为"""
@@ -87,12 +166,55 @@ class EventContext:
        }
        """

    # ========== APIs that plugins may call ==========

    def add_return(self, key: str, ret):
        """Add a return value"""
        if key not in self.__return_value__:
            self.__return_value__[key] = []
        self.__return_value__[key].append(ret)
    async def reply(self, message_chain: mirai.MessageChain):
        """Reply to the message that triggered this request.

        Args:
            message_chain (mirai.MessageChain): A YiriMirai message chain; if the user is not running the YiriMirai adapter, it is converted to the target adapter's message chain automatically.
        """
        await self.host.ap.platform_mgr.send(
            event=self.event.query.message_event,
            msg=message_chain,
            adapter=self.event.query.adapter,
        )

    async def send_message(
        self,
        target_type: str,
        target_id: str,
        message: mirai.MessageChain
    ):
        """Proactively send a message.

        Args:
            target_type (str): Target type, `person` or `group`
            target_id (str): Target ID
            message (mirai.MessageChain): A YiriMirai message chain; if the user is not running the YiriMirai adapter, it is converted to the target adapter's message chain automatically.
        """
        await self.event.query.adapter.send_message(
            target_type=target_type,
            target_id=target_id,
            message=message
        )

    def prevent_postorder(self):
        """Prevent subsequent plugins from running"""
        self.__prevent_postorder__ = True

    def prevent_default(self):
        """Prevent the default behavior"""
        self.__prevent_default__ = True

    # ========== Internal reserved methods below; plugins should not call them ==========
    def get_return(self, key: str) -> list:
        """Get all return values recorded for a key"""
        if key in self.__return_value__:
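The reply / send_message pair added above is the new plugin-facing messaging API of this release: reply answers the query that produced the current event, while send_message pushes a message to an arbitrary target through the active adapter. A hedged usage sketch follows; the event class and target id are illustrative:

import mirai

from pkg.plugin.context import register, handler, BasePlugin, EventContext
from pkg.plugin.events import PersonNormalMessageReceived  # illustrative; any event carrying a query works


@register(name="ReplyDemo", description="Demo of the messaging API", version="0.1", author="example")
class ReplyDemo(BasePlugin):

    @handler(PersonNormalMessageReceived)
    async def on_person_message(self, ctx: EventContext):
        # answer the message that triggered this event
        await ctx.reply(mirai.MessageChain([mirai.Plain("Got it")]))

        # proactively message another target through the same adapter
        await ctx.send_message(
            target_type="person",
            target_id="123456789",   # illustrative id
            message=mirai.MessageChain([mirai.Plain("hello")]),
        )

        # suppress the pipeline's own reply and any later plugins
        ctx.prevent_default()
        ctx.prevent_postorder()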
@@ -105,14 +227,6 @@ class EventContext:
            return self.__return_value__[key][0]
        return None

    def prevent_default(self):
        """Prevent the default behavior"""
        self.__prevent_default__ = True

    def prevent_postorder(self):
        """Prevent subsequent plugins from running"""
        self.__prevent_postorder__ = True
    def is_prevented_default(self):
        """Whether the default behavior is prevented"""
        return self.__prevent_default__
@@ -121,6 +235,7 @@ class EventContext:
"""是否阻止后序插件执行""" """是否阻止后序插件执行"""
return self.__prevent_postorder__ return self.__prevent_postorder__
def __init__(self, host: APIHost, event: events.BaseEventModel): def __init__(self, host: APIHost, event: events.BaseEventModel):
self.eid = EventContext.eid self.eid = EventContext.eid


@@ -10,8 +10,10 @@ from ..provider import entities as llm_entities
class BaseEventModel(pydantic.BaseModel):
    """Base class for event models"""

    query: typing.Union[core_entities.Query, None]
    """The query object of this request; None for events that are not part of a request pipeline"""

    class Config:
        arbitrary_types_allowed = True
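The only functional change here is the query annotation: core_entities.Query | None becomes typing.Union[core_entities.Query, None]. A plausible reason, not stated in the diff, is that pydantic resolves these annotations when the model class is built, and the PEP 604 "|" union spelling is only available on Python 3.10+, whereas typing.Union works on older interpreters. A small self-contained illustration of the two spellings:

import typing

import pydantic


class Example(pydantic.BaseModel):
    # works on any supported interpreter:
    value: typing.Union[int, None] = None

    # the PEP 604 spelling "value: int | None = None" is equivalent,
    # but pydantic can only resolve it on Python >= 3.10


print(Example())  # value=None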


@@ -1,3 +1,7 @@
# This module is deprecated.
# Import BasePlugin, EventContext and APIHost from pkg.plugin.context instead.
# It will be removed no earlier than v3.4.
from . events import *
from . context import EventContext, APIHost as PluginHost
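For plugins still importing through this deprecated shim, the migration is a one-line import change. A sketch, assuming the shim is the historical pkg.plugin.host module (its file path is not shown in this view):

# before: the deprecated shim (assumed path), slated for removal no earlier than v3.4
# from pkg.plugin.host import EventContext, PluginHost

# after: import directly from the maintained module
from pkg.plugin.context import EventContext, APIHost as PluginHost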


@@ -7,6 +7,7 @@ from ..core import app
class PluginInstaller(metaclass=abc.ABCMeta):
    """Abstract base class for plugin installers"""

    ap: app.Application
