mirror of
https://github.com/langbot-app/LangBot.git
synced 2025-11-25 11:29:39 +08:00
Compare commits
292 Commits
noReleaseP
...
v2.2.1
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
e89035e11c | ||
|
|
2ea711e629 | ||
|
|
a716f071be | ||
|
|
3450a91824 | ||
|
|
d2c2b457e5 | ||
|
|
9cd7e49804 | ||
|
|
e9155e836f | ||
|
|
ed248539c7 | ||
|
|
54cc75506f | ||
|
|
4269c7927e | ||
|
|
064ac7f603 | ||
|
|
48ccf15273 | ||
|
|
b920ced6d4 | ||
|
|
69610a674c | ||
|
|
1828e34190 | ||
|
|
d53f4e3917 | ||
|
|
01706d5b4e | ||
|
|
8916b8a450 | ||
|
|
ed33af5638 | ||
|
|
c94a9e1ae6 | ||
|
|
e2e93afd06 | ||
|
|
a810158d5b | ||
|
|
5a5ebb95fc | ||
|
|
61dd9e29c0 | ||
|
|
ac65d81ba1 | ||
|
|
7288d3cb15 | ||
|
|
7477c7c67f | ||
|
|
453952859e | ||
|
|
85d46089e3 | ||
|
|
3b55f706de | ||
|
|
f448276423 | ||
|
|
830ee704da | ||
|
|
393369e446 | ||
|
|
2cc6a09905 | ||
|
|
d7d9d88e16 | ||
|
|
357d6aaf75 | ||
|
|
8059c422e3 | ||
|
|
b336e1334d | ||
|
|
12a0942ddb | ||
|
|
7e5a77f77e | ||
|
|
2933d4843f | ||
|
|
c5de978098 | ||
|
|
8b9cfab072 | ||
|
|
ea5f3c222f | ||
|
|
36bcbca15b | ||
|
|
2b2060e71b | ||
|
|
451688f2df | ||
|
|
d993852de7 | ||
|
|
9d73770a4e | ||
|
|
2541acf9d2 | ||
|
|
a1bfbad24e | ||
|
|
8af4918048 | ||
|
|
49f4ab0ec8 | ||
|
|
85c623fb0f | ||
|
|
9e28298250 | ||
|
|
7a04ef0985 | ||
|
|
83005e9ba9 | ||
|
|
f0c78f0529 | ||
|
|
3f638adcf9 | ||
|
|
d9405d8d5d | ||
|
|
606713a418 | ||
|
|
52102f0d0a | ||
|
|
61c29829ed | ||
|
|
df30931aad | ||
|
|
5afcc03e8b | ||
|
|
fbeb4673f4 | ||
|
|
4aba319560 | ||
|
|
74f79e002c | ||
|
|
2668ef2b3f | ||
|
|
74c018e271 | ||
|
|
64776fd601 | ||
|
|
59877bf71d | ||
|
|
d2800ac58b | ||
|
|
ffef944119 | ||
|
|
651b291ef6 | ||
|
|
e4b581f197 | ||
|
|
4f3939e2d9 | ||
|
|
1048ca612d | ||
|
|
b1a2d21ee9 | ||
|
|
dd4e8bdc8b | ||
|
|
e28c9bae0c | ||
|
|
5c10f520fb | ||
|
|
f8abe90674 | ||
|
|
964ad42cb4 | ||
|
|
424b970469 | ||
|
|
792366e221 | ||
|
|
79e970c4c3 | ||
|
|
d12acd5f31 | ||
|
|
13e55e05a4 | ||
|
|
9a7490bc2f | ||
|
|
a610a9d3d3 | ||
|
|
56e906c83f | ||
|
|
101f26e5a3 | ||
|
|
0bba205cf2 | ||
|
|
cc3beb191f | ||
|
|
42f5092bb9 | ||
|
|
bc6728d123 | ||
|
|
754278f80f | ||
|
|
c9c980b6fe | ||
|
|
a457d13d2c | ||
|
|
7440e9e5d2 | ||
|
|
39d901a5cb | ||
|
|
2e1ebff985 | ||
|
|
b8ed9ba321 | ||
|
|
c89a8e1cd1 | ||
|
|
480d201c55 | ||
|
|
a4b7d4a012 | ||
|
|
7fe676712b | ||
|
|
552733129c | ||
|
|
a4d73090f8 | ||
|
|
7d39b72800 | ||
|
|
f1e12563e9 | ||
|
|
0ac5e5b35e | ||
|
|
6b3f74a39a | ||
|
|
3c3e2e86c3 | ||
|
|
204a778db2 | ||
|
|
3594e64bfc | ||
|
|
c23d114094 | ||
|
|
6cb3fdc7c9 | ||
|
|
c57642bd4e | ||
|
|
891ee0fac8 | ||
|
|
1b69f0b668 | ||
|
|
46b310ceb9 | ||
|
|
85fe44ec92 | ||
|
|
fdcec0fbf7 | ||
|
|
2664ea8622 | ||
|
|
862724da74 | ||
|
|
a1c167fb7f | ||
|
|
adc2290fc1 | ||
|
|
8713fd8130 | ||
|
|
77df3d1ae5 | ||
|
|
2234e9db0e | ||
|
|
dd3d403de8 | ||
|
|
5364c36a79 | ||
|
|
118fbe3f7d | ||
|
|
61ec8e96f2 | ||
|
|
19289527ae | ||
|
|
77fdd6ddb8 | ||
|
|
f7830b5e9d | ||
|
|
13e5d76a44 | ||
|
|
7b8ad2e315 | ||
|
|
623f094e5b | ||
|
|
fd25d61b56 | ||
|
|
6f5802551f | ||
|
|
cbab824fd0 | ||
|
|
0c3d911e74 | ||
|
|
e161343d72 | ||
|
|
4984896c95 | ||
|
|
28d1f5ead9 | ||
|
|
5044f757fb | ||
|
|
aa28b5aead | ||
|
|
5ada507c2b | ||
|
|
48be080fe0 | ||
|
|
cc2442e761 | ||
|
|
a5560823d9 | ||
|
|
9cd313c4df | ||
|
|
173f05a8ae | ||
|
|
29819668e3 | ||
|
|
9cbb9734f2 | ||
|
|
cf897410ee | ||
|
|
bf1896f959 | ||
|
|
e7c79a5156 | ||
|
|
eb60c1b0a0 | ||
|
|
6b1b69c741 | ||
|
|
352930694a | ||
|
|
215ed7ab0e | ||
|
|
2da5883d7a | ||
|
|
b6d731cf87 | ||
|
|
bef918749d | ||
|
|
96b7674644 | ||
|
|
d0bcf6940a | ||
|
|
c35fd9c4b7 | ||
|
|
0bb5923257 | ||
|
|
41752aff60 | ||
|
|
e873d81b63 | ||
|
|
7489a11ab3 | ||
|
|
86b7a8482c | ||
|
|
731dedf155 | ||
|
|
131297d859 | ||
|
|
7ce9687702 | ||
|
|
d56163b19b | ||
|
|
e0f8a04f8e | ||
|
|
3c08741cb6 | ||
|
|
c902822723 | ||
|
|
13f31d3fae | ||
|
|
70268c0cbb | ||
|
|
04eaf9f3e9 | ||
|
|
fd57b7df18 | ||
|
|
d72c364962 | ||
|
|
618487947b | ||
|
|
e75140d732 | ||
|
|
e09f6105a1 | ||
|
|
d3a6928e3a | ||
|
|
8b2128b4dc | ||
|
|
0773490c77 | ||
|
|
c2610a32e4 | ||
|
|
aaf72de552 | ||
|
|
65664ae178 | ||
|
|
742600fc4f | ||
|
|
6531aae617 | ||
|
|
842748947f | ||
|
|
b733f8f55b | ||
|
|
b7ae1fa516 | ||
|
|
64c587c17d | ||
|
|
133d8bbeef | ||
|
|
7d3bc4203e | ||
|
|
24a10265f3 | ||
|
|
351039dc3b | ||
|
|
3ed9f1a532 | ||
|
|
99282100a0 | ||
|
|
057d8a05d7 | ||
|
|
71b69a3226 | ||
|
|
d1a5c9a090 | ||
|
|
2b20d946e6 | ||
|
|
3c96e1298c | ||
|
|
4e54c24bf0 | ||
|
|
2894309fa6 | ||
|
|
fbd53dae7c | ||
|
|
ba2c362082 | ||
|
|
680085d16f | ||
|
|
2319c7eae2 | ||
|
|
645099ecf2 | ||
|
|
d51a0a644a | ||
|
|
37abc79551 | ||
|
|
20bdc7de58 | ||
|
|
690e542f37 | ||
|
|
33f80c8d16 | ||
|
|
e01cc09a28 | ||
|
|
120ec98ba7 | ||
|
|
b4938ba1fb | ||
|
|
41d0082cee | ||
|
|
0e786660b4 | ||
|
|
6af55d8a1d | ||
|
|
2c9e7f70f2 | ||
|
|
42819daf0f | ||
|
|
08d86dbd30 | ||
|
|
32e8f08398 | ||
|
|
78c73def8a | ||
|
|
82d845b5c8 | ||
|
|
a6bda0dec7 | ||
|
|
40fd9b0579 | ||
|
|
3eda4382b2 | ||
|
|
fd2812a30b | ||
|
|
fd27a7c999 | ||
|
|
870aba0560 | ||
|
|
37153e7360 | ||
|
|
f06b16437c | ||
|
|
e582780195 | ||
|
|
df9e89deb7 | ||
|
|
b4033b2902 | ||
|
|
023ed21363 | ||
|
|
52d6721ae2 | ||
|
|
fa967c3c89 | ||
|
|
6d81821557 | ||
|
|
56664f9fbc | ||
|
|
eb1564a3dd | ||
|
|
d5c6d43ddf | ||
|
|
da5b1cf3fa | ||
|
|
4232ab6f47 | ||
|
|
78c1ad16ce | ||
|
|
9962a6ebcc | ||
|
|
36def20a07 | ||
|
|
c7689d3c89 | ||
|
|
dfa8621a1a | ||
|
|
f884313d72 | ||
|
|
7afe5f39bf | ||
|
|
01bc529b93 | ||
|
|
77bf1c7d8e | ||
|
|
9d31d8b071 | ||
|
|
5256d3c718 | ||
|
|
c662e2c4e3 | ||
|
|
06264354cf | ||
|
|
2dcbe87986 | ||
|
|
8506cdae8f | ||
|
|
bf7487fafe | ||
|
|
0afc2d5903 | ||
|
|
2d62b5937e | ||
|
|
97ddb10ff5 | ||
|
|
2a74c8e053 | ||
|
|
bd920cedf5 | ||
|
|
ee76929fee | ||
|
|
8bb8a72060 | ||
|
|
dcc5d40a04 | ||
|
|
3bebeb4d99 | ||
|
|
d2922afce2 | ||
|
|
5dfd0a9b50 | ||
|
|
67cfe654b8 | ||
|
|
7ca8dcfb6a | ||
|
|
948b0f4df9 | ||
|
|
584cacba6c | ||
|
|
45ed06be64 | ||
|
|
3dec627d40 |
2
.github/ISSUE_TEMPLATE/漏洞反馈.md
vendored
2
.github/ISSUE_TEMPLATE/漏洞反馈.md
vendored
@@ -7,6 +7,8 @@ assignees: ''
|
||||
|
||||
---
|
||||
|
||||
请认真按照实际情况填写以下信息!!!!
|
||||
|
||||
**运行环境**
|
||||
- 部署方式:
|
||||
手动部署/自动部署/Docker部署
|
||||
|
||||
15
.github/dependabot.yml
vendored
Normal file
15
.github/dependabot.yml
vendored
Normal file
@@ -0,0 +1,15 @@
|
||||
# To get started with Dependabot version updates, you'll need to specify which
|
||||
# package ecosystems to update and where the package manifests are located.
|
||||
# Please see the documentation for all configuration options:
|
||||
# https://docs.github.com/github/administering-a-repository/configuration-options-for-dependency-updates
|
||||
|
||||
version: 2
|
||||
updates:
|
||||
- package-ecosystem: "pip" # See documentation for possible values
|
||||
directory: "/" # Location of package manifests
|
||||
schedule:
|
||||
interval: "weekly"
|
||||
allow:
|
||||
- dependency-name: "yiri-mirai"
|
||||
- dependency-name: "dulwich"
|
||||
- dependency-name: "openai"
|
||||
12
.gitignore
vendored
12
.gitignore
vendored
@@ -3,4 +3,14 @@ config.py
|
||||
__pycache__/
|
||||
database.db
|
||||
qchatgpt.log
|
||||
config.py
|
||||
/banlist.py
|
||||
plugins/
|
||||
!plugins/__init__.py
|
||||
/revcfg.py
|
||||
prompts/
|
||||
logs/
|
||||
sensitive.json
|
||||
temp/
|
||||
current_tag
|
||||
scenario/
|
||||
!scenario/default-template.json
|
||||
19
CONTRIBUTING.md
Normal file
19
CONTRIBUTING.md
Normal file
@@ -0,0 +1,19 @@
|
||||
## 参与项目
|
||||
|
||||
欢迎为此项目贡献代码或其他支持,以使您的点子或众人期待的功能成为现实,助力社区成长。
|
||||
|
||||
### 贡献形式
|
||||
|
||||
- 提交PR,解决issues中提到的bug或期待的功能
|
||||
- 提交PR,实现您设想的功能(请先提出issue与作者沟通)
|
||||
- 优化代码架构,使各个模块的组织更加整洁优雅
|
||||
- 在issues中提出发现的bug或者期待的功能
|
||||
- 为本项目在其他社交平台撰写文章、制作视频等
|
||||
- 为本项目的衍生项目作出贡献,或开发插件增加功能
|
||||
|
||||
### 如何开始
|
||||
|
||||
- 加入本项目交流群,一同探讨项目相关事务
|
||||
- 解决本项目或衍生项目的issues中亟待解决的问题
|
||||
- 阅读并完善本项目文档
|
||||
- 在各个社交媒体撰写本项目教程等
|
||||
155
README.md
155
README.md
@@ -1,19 +1,113 @@
|
||||
# QChatGPT🤖
|
||||
> 2023/3/3 官方接口疑似被墙,可考虑使用网络代理 [#198](https://github.com/RockChinQ/QChatGPT/issues/198)
|
||||
> 2023/3/3 现已在主线支持官方ChatGPT接口,使用方法查看[#195](https://github.com/RockChinQ/QChatGPT/issues/195)
|
||||
> 2023/3/2 OpenAI已发布ChatGPT官方接口,我们正在全力接入,预计明日前完成,请查看[此PR](https://github.com/RockChinQ/QChatGPT/pull/194)
|
||||
> 2023/2/16 现已支持接入ChatGPT网页版,详情请完成部署并查看底部**插件**小节或[此仓库](https://github.com/RockChinQ/revLibs)
|
||||
|
||||
- 到[项目Wiki](https://github.com/RockChinQ/QChatGPT/wiki)可了解项目详细信息
|
||||
- 由bilibili TheLazy制作的[视频教程](https://www.bilibili.com/video/BV15v4y1X7aP)
|
||||
- 测试号: 2196084348
|
||||
- 交流、答疑群: 204785790
|
||||
- **进群提问前请您`确保`已经找遍文档和issue均无法解决**
|
||||
- **进群提问前请您`确保`已经找遍文档和issue均无法解决**
|
||||
- 交流、答疑群: ~~204785790~~(已满)、~~691226829~~(已满)、656285629
|
||||
- **进群提问前请您`确保`已经找遍文档和issue均无法解决**
|
||||
- QQ频道机器人见[QQChannelChatGPT](https://github.com/Soulter/QQChannelChatGPT)
|
||||
|
||||
通过调用OpenAI GPT-3模型提供的Completion API来实现一个更加智能的QQ机器人
|
||||
通过调用OpenAI的ChatGPT等语言模型来实现一个更加智能的QQ机器人
|
||||
|
||||
## 🍺模型适配一览
|
||||
|
||||
### 文字对话
|
||||
|
||||
- OpenAI GPT-3.5模型(ChatGPT API), 本项目原生支持, 默认使用
|
||||
- OpenAI GPT-3模型, 本项目原生支持, 部署完成后前往config.py切换
|
||||
- ChatGPT网页版逆向API, 由[插件](https://github.com/RockChinQ/revLibs)接入
|
||||
|
||||
### 故事续写
|
||||
|
||||
- NovelAI API, 由[插件](https://github.com/dominoar/QCPNovelAi)接入
|
||||
|
||||
### 图片绘制
|
||||
|
||||
- OpenAI DALL·E模型, 本项目原生支持, 使用方法查看[Wiki功能使用页](https://github.com/RockChinQ/QChatGPT/wiki/%E5%8A%9F%E8%83%BD%E4%BD%BF%E7%94%A8#%E5%8A%9F%E8%83%BD%E7%82%B9%E5%88%97%E4%B8%BE)
|
||||
- NovelAI API, 由[插件](https://github.com/dominoar/QCPNovelAi)接入
|
||||
|
||||
### 语音生成
|
||||
|
||||
- TTS+VITS, 由[插件](https://github.com/dominoar/QChatPlugins)接入
|
||||
|
||||
## ✅功能
|
||||
|
||||
查看[Wiki功能使用页](https://github.com/RockChinQ/QChatGPT/wiki/%E5%8A%9F%E8%83%BD%E4%BD%BF%E7%94%A8#%E5%8A%9F%E8%83%BD%E7%82%B9%E5%88%97%E4%B8%BE)
|
||||
<details>
|
||||
<summary>✅支持敏感词过滤,避免账号风险</summary>
|
||||
|
||||
- 难以监测机器人与用户对话时的内容,故引入此功能以减少机器人风险
|
||||
- 加入了百度云内容审核,在`config.py`中修改`baidu_check`的值,并填写`baidu_api_key`和`baidu_secret_key`以开启此功能
|
||||
- 编辑`sensitive.json`,并在`config.py`中修改`sensitive_word_filter`的值以开启此功能
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>✅群内多种响应规则,不必at</summary>
|
||||
|
||||
- 默认回复`ai`作为前缀或`@`机器人的消息
|
||||
- 详细见`config.py`中的`response_rules`字段
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>✅完善的多api-key管理,超额自动切换</summary>
|
||||
|
||||
- 支持配置多个`api-key`,内部统计使用量并在超额时自动切换
|
||||
- 请在`config.py`中修改`openai_config`的值以设置`api-key`
|
||||
- 可以在`config.py`中修改`api_key_fee_threshold`来自定义切换阈值
|
||||
- 运行期间向机器人说`!usage`以查看当前使用情况
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>✅支持预设指令文字</summary>
|
||||
|
||||
- 支持以自然语言预设文字,自定义机器人人格等信息
|
||||
- 详见`config.py`中的`default_prompt`部分
|
||||
- 支持设置多个预设情景,并通过!reset、!default等指令控制,详细请查看[wiki指令](https://github.com/RockChinQ/QChatGPT/wiki/%E5%8A%9F%E8%83%BD%E4%BD%BF%E7%94%A8#%E6%9C%BA%E5%99%A8%E4%BA%BA%E6%8C%87%E4%BB%A4)
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>✅支持对话、绘图等模型,可玩性更高</summary>
|
||||
|
||||
- 现已支持OpenAI的对话`Completion API`和绘图`Image API`
|
||||
- 向机器人发送指令`!draw <prompt>`即可使用绘图模型
|
||||
</details>
|
||||
<details>
|
||||
<summary>✅支持指令控制热重载、热更新</summary>
|
||||
|
||||
- 允许在运行期间修改`config.py`或其他代码后,以管理员账号向机器人发送指令`!reload`进行热重载,无需重启
|
||||
- 运行期间允许以管理员账号向机器人发送指令`!update`进行热更新,拉取远程最新代码并执行热重载
|
||||
</details>
|
||||
<details>
|
||||
<summary>✅支持插件加载🧩</summary>
|
||||
|
||||
- 自行实现插件加载器及相关支持
|
||||
- 详细查看[插件使用页](https://github.com/RockChinQ/QChatGPT/wiki/%E6%8F%92%E4%BB%B6%E4%BD%BF%E7%94%A8)
|
||||
</details>
|
||||
<details>
|
||||
<summary>✅私聊、群聊黑名单机制</summary>
|
||||
|
||||
- 支持将人或群聊加入黑名单以忽略其消息
|
||||
- 详见Wiki`加入黑名单`节
|
||||
</details>
|
||||
<details>
|
||||
<summary>✅长消息处理策略</summary>
|
||||
|
||||
- 支持将长消息转换成图片或消息记录组件,避免消息刷屏
|
||||
- 请查看`config.py`中`blob_message_strategy`等字段
|
||||
</details>
|
||||
<details>
|
||||
<summary>✅回复速度限制</summary>
|
||||
|
||||
- 支持限制单会话内每分钟可进行的对话次数
|
||||
- 具有“等待”和“丢弃”两种策略
|
||||
- “等待”策略:在获取到回复后,等待直到此次响应时间达到对话响应时间均值
|
||||
- “丢弃”策略:此分钟内对话次数达到限制时,丢弃之后的对话
|
||||
- 详细请查看config.py中的相关配置
|
||||
</details>
|
||||
|
||||
详情请查看[Wiki功能使用页](https://github.com/RockChinQ/QChatGPT/wiki/%E5%8A%9F%E8%83%BD%E4%BD%BF%E7%94%A8#%E5%8A%9F%E8%83%BD%E7%82%B9%E5%88%97%E4%B8%BE)
|
||||
|
||||
## 🔩部署
|
||||
|
||||
@@ -21,9 +115,9 @@
|
||||
|
||||
### - 注册OpenAI账号
|
||||
|
||||
参考以下文章
|
||||
参考以下文章自行注册
|
||||
|
||||
> [只需 1 元搞定 ChatGPT 注册](https://zhuanlan.zhihu.com/p/589470082)
|
||||
> [国内注册ChatGPT的方法(100%可用)](https://www.pythonthree.com/register-openai-chatgpt/)
|
||||
> [手把手教你如何注册ChatGPT,超级详细](https://guxiaobei.com/51461)
|
||||
|
||||
注册成功后请前往[个人中心查看](https://beta.openai.com/account/api-keys)api_key
|
||||
@@ -49,10 +143,7 @@
|
||||
<details>
|
||||
<summary>手动部署适用于所有平台</summary>
|
||||
|
||||
- 请使用Python 3.9.x以上版本
|
||||
- 请注意OpenAI账号额度消耗
|
||||
- 每个账户仅有18美元免费额度,如未绑定银行卡,则会在超出时报错
|
||||
- OpenAI收费标准:默认使用的`text-davinci-003`模型 0.02美元/千字
|
||||
- 请使用Python 3.9.x以上版本
|
||||
|
||||
#### 配置Mirai
|
||||
|
||||
@@ -95,17 +186,51 @@ python3 main.py
|
||||
|
||||
**常见问题**
|
||||
|
||||
- mirai登录提示`QQ版本过低`,见[此issue](https://github.com/RockChinQ/QChatGPT/issues/38)
|
||||
- mirai登录提示`QQ版本过低`,见[此issue](https://github.com/RockChinQ/QChatGPT/issues/137)
|
||||
- 如提示安装`uvicorn`或`hypercorn`请*不要*安装,这两个不是必需的,目前存在未知原因bug
|
||||
- 如报错`TypeError: As of 3.10, the *loop* parameter was removed from Lock() since it is no longer necessary`, 请参考 [此处](https://github.com/RockChinQ/QChatGPT/issues/5)
|
||||
|
||||
|
||||
</details>
|
||||
|
||||
## 🚀使用
|
||||
|
||||
查看[Wiki功能使用页](https://github.com/RockChinQ/QChatGPT/wiki/%E5%8A%9F%E8%83%BD%E4%BD%BF%E7%94%A8#%E4%BD%BF%E7%94%A8%E6%96%B9%E5%BC%8F)
|
||||
|
||||
## 🧩插件生态
|
||||
|
||||
现已支持自行开发插件对功能进行扩展或自定义程序行为
|
||||
详见[Wiki插件使用页](https://github.com/RockChinQ/QChatGPT/wiki/%E6%8F%92%E4%BB%B6%E4%BD%BF%E7%94%A8)
|
||||
开发教程见[Wiki插件开发页](https://github.com/RockChinQ/QChatGPT/wiki/%E6%8F%92%E4%BB%B6%E5%BC%80%E5%8F%91)
|
||||
|
||||
### 示例插件
|
||||
|
||||
在`tests/plugin_examples`目录下,将其整个目录复制到`plugins`目录下即可使用
|
||||
|
||||
- `cmdcn` - 主程序指令中文形式
|
||||
- `hello_plugin` - 在收到消息`hello`时回复相应消息
|
||||
- `urlikethisijustsix` - 收到冒犯性消息时回复相应消息
|
||||
|
||||
### 更多
|
||||
|
||||
欢迎提交新的插件
|
||||
|
||||
- [revLibs](https://github.com/RockChinQ/revLibs) - 将ChatGPT网页版接入此项目,关于[官方接口和网页版有什么区别](https://github.com/RockChinQ/QChatGPT/wiki/%E5%AE%98%E6%96%B9%E6%8E%A5%E5%8F%A3%E4%B8%8EChatGPT%E7%BD%91%E9%A1%B5%E7%89%88)
|
||||
- [hello_plugin](https://github.com/RockChinQ/hello_plugin) - `hello_plugin` 的储存库形式,插件开发模板
|
||||
- [dominoar/QChatPlugins](https://github.com/dominoar/QchatPlugins) - dominoar编写的诸多新功能插件(语言输出、Ranimg、屏蔽词规则等)
|
||||
- [dominoar/QCP-NovelAi](https://github.com/dominoar/QCP-NovelAi) - NovelAI 故事叙述与绘画
|
||||
|
||||
## 😘致谢
|
||||
|
||||
- [@the-lazy-me](https://github.com/the-lazy-me) 为本项目制作[视频教程](https://www.bilibili.com/video/BV15v4y1X7aP)
|
||||
- [@mikumifa](https://github.com/mikumifa) 本项目Docker部署仓库开发者
|
||||
- [@dominoar](https://github.com/dominoar) 为本项目开发多种插件
|
||||
- [@hissincn](https://github.com/hissincn) 本项目贡献者
|
||||
- [@LINSTCL](https://github.com/LINSTCL) GPT-3.5官方模型适配贡献者
|
||||
- [@Haibersut](https://github.com/Haibersut) 本项目贡献者
|
||||
- [@万神的星空](https://github.com/qq255204159) 整合包发行
|
||||
|
||||
以及其他所有为本项目提供支持的朋友们。
|
||||
|
||||
## 👍赞赏
|
||||
|
||||
<img alt="赞赏码" src="res/mm_reward_qrcode_1672840549070.png" width="400" height="400"/>
|
||||
<img alt="赞赏码" src="res/mm_reward_qrcode_1672840549070.png" width="400" height="400"/>
|
||||
|
||||
20
banlist-template.py
Normal file
20
banlist-template.py
Normal file
@@ -0,0 +1,20 @@
|
||||
# 是否启用禁用列表
|
||||
enable = True
|
||||
|
||||
# 禁用规则(黑名单)
|
||||
# person为个人,其中的QQ号会被禁止与机器人进行私聊或群聊交互
|
||||
# 示例: person = [2854196310, 1234567890, 9876543210]
|
||||
# group为群组,其中的群号会被禁止与机器人进行交互
|
||||
# 示例: group = [123456789, 987654321, 1234567890]
|
||||
#
|
||||
# 支持正则表达式,字符串都将被识别为正则表达式,例如:
|
||||
# person = [12345678, 87654321, "2854.*"]
|
||||
# group = [123456789, 987654321, "1234.*"]
|
||||
# 若要排除某个QQ号或群号(即允许使用),可以在前面加上"!",例如:
|
||||
# person = ["!1234567890"]
|
||||
# group = ["!987654321"]
|
||||
# 排除规则优先级高于包含规则,即如果同时存在包含规则和排除规则,排除规则将生效,例如:
|
||||
# person = ["1234.*", "!1234567890"]
|
||||
# 那么1234567890将不会被禁用,而其他以1234开头的QQ号都会被禁用
|
||||
person = [2854196310] # 2854196310是Q群管家机器人的QQ号,默认屏蔽以免出现循环
|
||||
group = [204785790, 691226829] # 本项目交流群的群号,默认屏蔽,避免在交流群测试机器人
|
||||
@@ -20,65 +20,145 @@ mirai_http_api_config = {
|
||||
|
||||
# [必需] OpenAI的配置
|
||||
# api_key: OpenAI的API Key
|
||||
# http_proxy: 请求OpenAI时使用的代理,None为不使用,https和socks5暂不能使用
|
||||
# 若只有一个api-key,请直接修改以下内容中的"openai_api_key"为你的api-key
|
||||
#
|
||||
# 如准备了多个api-key,可以以字典的形式填写,程序会自动选择可用的api-key
|
||||
# 例如{
|
||||
# "api0": "sk-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx",
|
||||
# "api1": "sk-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx",
|
||||
# }
|
||||
# 例如
|
||||
# openai_config = {
|
||||
# "api_key": {
|
||||
# "default": "sk-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx",
|
||||
# "key1": "sk-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx",
|
||||
# "key2": "sk-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx",
|
||||
# },
|
||||
# "http_proxy": "http://127.0.0.1:12345"
|
||||
# }
|
||||
openai_config = {
|
||||
"api_key": {
|
||||
"default": "openai_api_key"
|
||||
},
|
||||
"http_proxy": None
|
||||
}
|
||||
|
||||
# 管理员QQ号,用于接收报错等通知及执行管理员级别指令,为0时关闭此功能
|
||||
# [必需] 管理员QQ号,用于接收报错等通知及执行管理员级别指令
|
||||
# 支持多个管理员,可以使用list形式设置,例如:
|
||||
# admin_qq = [12345678, 87654321]
|
||||
admin_qq = 0
|
||||
|
||||
# 情景预设(机器人人格)
|
||||
# 每个会话的预设信息,影响所有会话,无视指令重置
|
||||
# 可以通过这个字段指定某些情况的回复,可直接用自然语言描述指令
|
||||
# 例如: 如果我之后想获取帮助,请你说“输入!help获取帮助”,
|
||||
# 例如:
|
||||
# default_prompt = "如果我之后想获取帮助,请你说“输入!help获取帮助”"
|
||||
# 这样用户在不知所措的时候机器人就会提示其输入!help获取帮助
|
||||
# 可参考 https://github.com/PlexPt/awesome-chatgpt-prompts-zh
|
||||
default_prompt = "如果我之后想获取帮助,请你说“输入!help获取帮助”"
|
||||
#
|
||||
# 如果需要多个情景预设,并在运行期间方便切换,请使用字典的形式填写,例如
|
||||
# default_prompt = {
|
||||
# "default": "如果我之后想获取帮助,请你说“输入!help获取帮助”",
|
||||
# "linux-terminal": "我想让你充当 Linux 终端。我将输入命令,您将回复终端应显示的内容。",
|
||||
# "en-dict": "我想让你充当英英词典,对于给出的英文单词,你要给出其中文意思以及英文解释,并且给出一个例句,此外不要有其他反馈。",
|
||||
# }
|
||||
#
|
||||
# 在使用期间即可通过指令:
|
||||
# !reset [名称]
|
||||
# 来使用指定的情景预设重置会话
|
||||
# 例如:
|
||||
# !reset linux-terminal
|
||||
# 若不指定名称,则使用默认情景预设
|
||||
#
|
||||
# 也可以使用指令:
|
||||
# !default <名称>
|
||||
# 将指定的情景预设设置为默认情景预设
|
||||
# 例如:
|
||||
# !default linux-terminal
|
||||
# 之后的会话重置时若不指定名称,则使用linux-terminal情景预设
|
||||
#
|
||||
# 还可以加载文件中的预设文字,使用方法请查看:https://github.com/RockChinQ/QChatGPT/wiki/%E5%8A%9F%E8%83%BD%E4%BD%BF%E7%94%A8#%E9%A2%84%E8%AE%BE%E6%96%87%E5%AD%97
|
||||
default_prompt = {
|
||||
"default": "如果我之后想获取帮助,请你说“输入!help获取帮助”",
|
||||
}
|
||||
|
||||
# 实验性设置项: JSON完整情景导入
|
||||
# 预设prompt模式
|
||||
# 参考值:旧版本方式:default | 完整情景:full_scenario
|
||||
preset_mode = "default"
|
||||
|
||||
# 群内响应规则
|
||||
# 符合此消息的群内消息即使不包含at机器人也会响应
|
||||
# 支持消息前缀匹配及正则表达式匹配
|
||||
# 注意:由消息前缀(prefix)匹配的消息中将会删除此前缀,正则表达式匹配的消息不会删除匹配的部分
|
||||
# 支持设置是否响应at消息、随机响应概率
|
||||
# 注意:由消息前缀(prefix)匹配的消息中将会删除此前缀,正则表达式(regexp)匹配的消息不会删除匹配的部分
|
||||
# 前缀匹配优先级高于正则表达式匹配
|
||||
# 正则表达式简明教程:https://www.runoob.com/regexp/regexp-tutorial.html
|
||||
response_rules = {
|
||||
"at": True, # 是否响应at机器人的消息
|
||||
"prefix": ["/ai", "!ai", "!ai", "ai"],
|
||||
"regexp": [] # "为什么.*", "怎么?样.*", "怎么.*", "如何.*", "[Hh]ow to.*", "[Ww]hy not.*", "[Ww]hat is.*", ".*怎么办", ".*咋办"
|
||||
"regexp": [], # "为什么.*", "怎么?样.*", "怎么.*", "如何.*", "[Hh]ow to.*", "[Ww]hy not.*", "[Ww]hat is.*", ".*怎么办", ".*咋办"
|
||||
"random_rate": 0.0, # 随机响应概率,0.0-1.0,0.0为不随机响应,1.0为响应所有消息, 仅在前几项判断不通过时生效
|
||||
}
|
||||
|
||||
# 单个api-key的费用警告阈值
|
||||
# 当使用此api-key进行请求所消耗的费用估算达到此阈值时,会在控制台输出警告并通知管理员
|
||||
# 若之后还有未使用超过此值的api-key,则会切换到新的api-key进行请求
|
||||
# 单位:美元
|
||||
api_key_fee_threshold = 18.0
|
||||
# 消息忽略规则
|
||||
# 适用于私聊及群聊
|
||||
# 符合此规则的消息将不会被响应
|
||||
# 支持消息前缀匹配及正则表达式匹配
|
||||
# 此设置优先级高于response_rules
|
||||
# 用以过滤mirai等其他层级的指令
|
||||
# @see https://github.com/RockChinQ/QChatGPT/issues/165
|
||||
ignore_rules = {
|
||||
"prefix": ["/"],
|
||||
"regexp": []
|
||||
}
|
||||
|
||||
# 是否根据估算的使用费用切换api-key
|
||||
# 设置为False将只在接口报错超额时自动切换
|
||||
auto_switch_api_key = False
|
||||
# 是否检查收到的消息中是否包含敏感词
|
||||
# 若收到的消息无法通过下方指定的敏感词检查策略,则发送提示信息
|
||||
income_msg_check = False
|
||||
|
||||
# 敏感词过滤开关,以同样数量的*代替敏感词回复
|
||||
# 请在sensitive.json中添加敏感词
|
||||
sensitive_word_filter = True
|
||||
|
||||
# 是否启用百度云内容安全审核
|
||||
# 注册方式查看 https://cloud.baidu.com/doc/ANTIPORN/s/Wkhu9d5iy
|
||||
baidu_check = False
|
||||
|
||||
# 百度云API_KEY 24位英文数字字符串
|
||||
baidu_api_key = ""
|
||||
|
||||
# 百度云SECRET_KEY 32位的英文数字字符串
|
||||
baidu_secret_key = ""
|
||||
|
||||
# 不合规消息自定义返回
|
||||
inappropriate_message_tips = "[百度云]请珍惜机器人,当前返回内容不合规"
|
||||
|
||||
# 启动时是否发送赞赏码
|
||||
# 仅当使用量已经超过2048字时发送
|
||||
encourage_sponsor_at_start = True
|
||||
|
||||
# 每次向OpenAI接口发送对话记录上下文的字符数
|
||||
# 最大不超过(4096 - max_tokens)个字符,max_tokens为上述completion_api_params中的max_tokens
|
||||
# 最大不超过(4096 - max_tokens)个字符,max_tokens为下方completion_api_params中的max_tokens
|
||||
# 注意:较大的prompt_submit_length会导致OpenAI账户额度消耗更快
|
||||
prompt_submit_length = 1024
|
||||
|
||||
# OpenAI的completion API的参数
|
||||
# OpenAI补全API的参数
|
||||
# 请在下方填写模型,程序自动选择接口
|
||||
# 现已支持的模型有:
|
||||
#
|
||||
# 'gpt-3.5-turbo'
|
||||
# 'gpt-3.5-turbo-0301'
|
||||
# 'text-davinci-003'
|
||||
# 'text-davinci-002'
|
||||
# 'code-davinci-002'
|
||||
# 'code-cushman-001'
|
||||
# 'text-curie-001'
|
||||
# 'text-babbage-001'
|
||||
# 'text-ada-001'
|
||||
#
|
||||
# 具体请查看OpenAI的文档: https://beta.openai.com/docs/api-reference/completions/create
|
||||
completion_api_params = {
|
||||
"model": "text-davinci-003",
|
||||
"temperature": 0.6, # 数值越低得到的回答越理性,取值范围[0, 1]
|
||||
"max_tokens": 512, # 每次向OpenAI请求的最大字符数, 不高于4096
|
||||
"model": "gpt-3.5-turbo",
|
||||
"temperature": 0.9, # 数值越低得到的回答越理性,取值范围[0, 1]
|
||||
"max_tokens": 1024, # 每次获取OpenAI接口响应的文字量上限, 不高于4096
|
||||
"top_p": 1, # 生成的文本的文本与要求的符合度, 取值范围[0, 1]
|
||||
"frequency_penalty": 0.2,
|
||||
"presence_penalty": 1.0,
|
||||
@@ -90,31 +170,74 @@ image_api_params = {
|
||||
"size": "256x256", # 图片尺寸,支持256x256, 512x512, 1024x1024
|
||||
}
|
||||
|
||||
# 回复消息时是否引用原消息
|
||||
# 群内回复消息时是否引用原消息
|
||||
quote_origin = True
|
||||
|
||||
# 回复绘图时是否包含图片描述
|
||||
include_image_description = True
|
||||
|
||||
# 消息处理的超时时间,单位为秒
|
||||
process_message_timeout = 15
|
||||
|
||||
# 机器人的配置
|
||||
# user_name: 管理员(主人)的名字
|
||||
# bot_name: 机器人的名字
|
||||
user_name = 'You'
|
||||
bot_name = 'Bot'
|
||||
process_message_timeout = 30
|
||||
|
||||
# 回复消息时是否显示[GPT]前缀
|
||||
show_prefix = False
|
||||
|
||||
# 应用长消息处理策略的阈值
|
||||
# 当回复消息长度超过此值时,将使用长消息处理策略
|
||||
blob_message_threshold = 256
|
||||
|
||||
# 长消息处理策略
|
||||
# - "image": 将长消息转换为图片发送
|
||||
# - "forward": 将长消息转换为转发消息组件发送
|
||||
blob_message_strategy = "forward"
|
||||
|
||||
# 文字转图片时使用的字体文件路径
|
||||
# 当策略为"image"时生效
|
||||
# 若在Windows系统下,程序会自动使用Windows自带的微软雅黑字体
|
||||
# 若未填写或不存在且不是Windows,将禁用文字转图片功能,改为使用转发消息组件
|
||||
font_path = ""
|
||||
|
||||
# 消息处理超时重试次数
|
||||
retry_times = 3
|
||||
|
||||
# 消息处理出错时是否向用户隐藏错误详细信息
|
||||
# 设置为True时,仅向管理员发送错误详细信息
|
||||
# 设置为False时,向用户及管理员发送错误详细信息
|
||||
hide_exce_info_to_user = False
|
||||
|
||||
# 消息处理出错时向用户发送的提示信息
|
||||
# 仅当hide_exce_info_to_user为True时生效
|
||||
# 设置为空字符串时,不发送提示信息
|
||||
alter_tip_message = '出错了,请稍后再试'
|
||||
|
||||
# 机器人线程池大小
|
||||
# 该参数决定机器人可以同时处理几个人的消息,超出线程池数量的请求会被阻塞,不会被丢弃
|
||||
# 如果你不清楚该参数的意义,请不要更改
|
||||
pool_num = 10
|
||||
|
||||
# 每个会话的过期时间,单位为秒
|
||||
# 默认值20分钟
|
||||
session_expire_time = 60 * 20
|
||||
|
||||
# 会话限速
|
||||
# 单会话内每分钟可进行的对话次数
|
||||
# 若不需要限速,可以设置为一个很大的值
|
||||
# 默认值60次,基本上不会触发限速
|
||||
rate_limitation = 60
|
||||
|
||||
# 会话限速策略
|
||||
# - "wait": 每次对话获取到回复时,等待一定时间再发送回复,保证其不会超过限速均值
|
||||
# - "drop": 此分钟内,若对话次数超过限速次数,则丢弃之后的对话,每自然分钟重置
|
||||
rate_limit_strategy = "wait"
|
||||
|
||||
# drop策略时,超过限速均值时,丢弃的对话的提示信息
|
||||
# 仅当rate_limitation_strategy为"drop"时生效
|
||||
# 若设置为空字符串,则不发送提示信息
|
||||
rate_limit_drop_tip = "本分钟对话次数超过限速次数,此对话被丢弃"
|
||||
|
||||
# 是否在启动时进行依赖库更新
|
||||
upgrade_dependencies = True
|
||||
|
||||
# 是否上报统计信息
|
||||
# 用于统计机器人的使用情况,不会收集任何用户信息
|
||||
# 仅上报时间、字数使用量、绘图使用量,其他信息不会上报
|
||||
@@ -131,11 +254,4 @@ help_message = """此机器人通过调用OpenAI的GPT-3大型语言模型生成
|
||||
每次会话最后一次交互后{}分钟后会自动结束,结束后将开启新会话,如需继续前一次会话请发送 !last 重新开启
|
||||
欢迎到github.com/RockChinQ/QChatGPT 给个star
|
||||
|
||||
帮助信息:
|
||||
!help - 显示帮助
|
||||
!reset - 重置会话
|
||||
!last - 切换到前一次的对话
|
||||
!next - 切换到后一次的对话
|
||||
!prompt - 显示当前对话所有内容
|
||||
!list - 列出所有历史会话
|
||||
!usage - 列出各个api-key的使用量""".format(session_expire_time // 60)
|
||||
指令帮助信息请查看: https://github.com/RockChinQ/QChatGPT/wiki/%E5%8A%9F%E8%83%BD%E4%BD%BF%E7%94%A8#%E6%9C%BA%E5%99%A8%E4%BA%BA%E6%8C%87%E4%BB%A4""".format(session_expire_time // 60)
|
||||
|
||||
283
main.py
283
main.py
@@ -7,16 +7,24 @@ import time
|
||||
import logging
|
||||
import sys
|
||||
|
||||
import mirai.exceptions
|
||||
import websockets.exceptions
|
||||
|
||||
try:
|
||||
import colorlog
|
||||
except ImportError:
|
||||
print("未安装colorlog,请查看 https://github.com/RockChinQ/qcg-installer/issues/15")
|
||||
sys.exit(1)
|
||||
# 尝试安装
|
||||
import pkg.utils.pkgmgr as pkgmgr
|
||||
pkgmgr.install_requirements("requirements.txt")
|
||||
try:
|
||||
import colorlog
|
||||
except ImportError:
|
||||
print("依赖不满足,请查看 https://github.com/RockChinQ/qcg-installer/issues/15")
|
||||
sys.exit(1)
|
||||
import colorlog
|
||||
|
||||
import requests
|
||||
import websockets.exceptions
|
||||
from urllib3.exceptions import InsecureRequestWarning
|
||||
|
||||
|
||||
sys.path.append(".")
|
||||
|
||||
log_colors_config = {
|
||||
@@ -35,12 +43,87 @@ def init_db():
|
||||
database.initialize_database()
|
||||
|
||||
|
||||
def ensure_dependencies():
|
||||
import pkg.utils.pkgmgr as pkgmgr
|
||||
pkgmgr.run_pip(["install", "openai", "Pillow", "--upgrade",
|
||||
"-i", "https://pypi.douban.com/simple/",
|
||||
"--trusted-host", "pypi.douban.com"])
|
||||
|
||||
|
||||
known_exception_caught = False
|
||||
|
||||
log_file_name = "qchatgpt.log"
|
||||
|
||||
|
||||
def init_runtime_log_file():
|
||||
"""为此次运行生成日志文件
|
||||
格式: qchatgpt-yyyy-MM-dd-HH-mm-ss.log
|
||||
"""
|
||||
global log_file_name
|
||||
|
||||
# 检查logs目录是否存在
|
||||
if not os.path.exists("logs"):
|
||||
os.mkdir("logs")
|
||||
|
||||
# 检查本目录是否有qchatgpt.log,若有,移动到logs目录
|
||||
if os.path.exists("qchatgpt.log"):
|
||||
shutil.move("qchatgpt.log", "logs/qchatgpt.legacy.log")
|
||||
|
||||
log_file_name = "logs/qchatgpt-%s.log" % time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime())
|
||||
|
||||
|
||||
def reset_logging():
|
||||
global log_file_name
|
||||
assert os.path.exists('config.py')
|
||||
|
||||
config = importlib.import_module('config')
|
||||
|
||||
import pkg.utils.context
|
||||
|
||||
if pkg.utils.context.context['logger_handler'] is not None:
|
||||
logging.getLogger().removeHandler(pkg.utils.context.context['logger_handler'])
|
||||
|
||||
for handler in logging.getLogger().handlers:
|
||||
logging.getLogger().removeHandler(handler)
|
||||
|
||||
logging.basicConfig(level=config.logging_level, # 设置日志输出格式
|
||||
filename=log_file_name, # log日志输出的文件位置和文件名
|
||||
format="[%(asctime)s.%(msecs)03d] %(filename)s (%(lineno)d) - [%(levelname)s] : %(message)s",
|
||||
# 日志输出的格式
|
||||
# -8表示占位符,让输出左对齐,输出长度都为8位
|
||||
datefmt="%Y-%m-%d %H:%M:%S" # 时间输出的格式
|
||||
)
|
||||
sh = logging.StreamHandler()
|
||||
sh.setLevel(config.logging_level)
|
||||
sh.setFormatter(colorlog.ColoredFormatter(
|
||||
fmt="%(log_color)s[%(asctime)s.%(msecs)03d] %(filename)s (%(lineno)d) - [%(levelname)s] : "
|
||||
"%(message)s",
|
||||
datefmt="%Y-%m-%d %H:%M:%S",
|
||||
log_colors=log_colors_config
|
||||
))
|
||||
logging.getLogger().addHandler(sh)
|
||||
pkg.utils.context.context['logger_handler'] = sh
|
||||
return sh
|
||||
|
||||
|
||||
def main(first_time_init=False):
|
||||
"""启动流程,reload之后会被执行"""
|
||||
|
||||
global known_exception_caught
|
||||
|
||||
import config
|
||||
# 更新openai库到最新版本
|
||||
if not hasattr(config, 'upgrade_dependencies') or config.upgrade_dependencies:
|
||||
print("正在更新依赖库,请等待...")
|
||||
if not hasattr(config, 'upgrade_dependencies'):
|
||||
print("这个操作不是必须的,如果不想更新,请在config.py中添加upgrade_dependencies=False")
|
||||
else:
|
||||
print("这个操作不是必须的,如果不想更新,请在config.py中将upgrade_dependencies设置为False")
|
||||
try:
|
||||
ensure_dependencies()
|
||||
except Exception as e:
|
||||
print("更新openai库失败:{}, 请忽略或自行更新".format(e))
|
||||
|
||||
known_exception_caught = False
|
||||
try:
|
||||
# 导入config.py
|
||||
@@ -48,37 +131,58 @@ def main(first_time_init=False):
|
||||
|
||||
config = importlib.import_module('config')
|
||||
|
||||
init_runtime_log_file()
|
||||
|
||||
sh = reset_logging()
|
||||
|
||||
# 配置完整性校验
|
||||
is_integrity = True
|
||||
config_template = importlib.import_module('config-template')
|
||||
for key in dir(config_template):
|
||||
if not key.startswith("__") and not hasattr(config, key):
|
||||
setattr(config, key, getattr(config_template, key))
|
||||
logging.warning("[{}]不存在".format(key))
|
||||
is_integrity = False
|
||||
if not is_integrity:
|
||||
logging.warning("配置文件不完整,请依据config-template.py检查config.py")
|
||||
logging.warning("以上配置已被设为默认值,将在5秒后继续启动... ")
|
||||
time.sleep(5)
|
||||
|
||||
import pkg.utils.context
|
||||
pkg.utils.context.set_config(config)
|
||||
|
||||
if pkg.utils.context.context['logger_handler'] is not None:
|
||||
logging.getLogger().removeHandler(pkg.utils.context.context['logger_handler'])
|
||||
|
||||
logging.basicConfig(level=config.logging_level, # 设置日志输出格式
|
||||
filename='qchatgpt.log', # log日志输出的文件位置和文件名
|
||||
format="[%(asctime)s.%(msecs)03d] %(filename)s (%(lineno)d) - [%(levelname)s] : %(message)s",
|
||||
# 日志输出的格式
|
||||
# -8表示占位符,让输出左对齐,输出长度都为8位
|
||||
datefmt="%Y-%m-%d %H:%M:%S" # 时间输出的格式
|
||||
)
|
||||
sh = logging.StreamHandler()
|
||||
sh.setLevel(config.logging_level)
|
||||
sh.setFormatter(colorlog.ColoredFormatter(
|
||||
fmt="%(log_color)s[%(asctime)s.%(msecs)03d] %(filename)s (%(lineno)d) - [%(levelname)s] : "
|
||||
"%(message)s",
|
||||
datefmt="%Y-%m-%d %H:%M:%S",
|
||||
log_colors=log_colors_config
|
||||
))
|
||||
logging.getLogger().addHandler(sh)
|
||||
|
||||
# 检查是否设置了管理员
|
||||
if not (hasattr(config, 'admin_qq') and config.admin_qq != 0):
|
||||
logging.warning("未设置管理员QQ,管理员权限指令及运行告警将无法使用,如需设置请修改config.py中的admin_qq字段")
|
||||
# logging.warning("未设置管理员QQ,管理员权限指令及运行告警将无法使用,如需设置请修改config.py中的admin_qq字段")
|
||||
while True:
|
||||
try:
|
||||
config.admin_qq = int(input("未设置管理员QQ,管理员权限指令及运行告警将无法使用,请输入管理员QQ号: "))
|
||||
# 写入到文件
|
||||
|
||||
# 读取文件
|
||||
config_file_str = ""
|
||||
with open("config.py", "r", encoding="utf-8") as f:
|
||||
config_file_str = f.read()
|
||||
# 替换
|
||||
config_file_str = config_file_str.replace("admin_qq = 0", "admin_qq = " + str(config.admin_qq))
|
||||
# 写入
|
||||
with open("config.py", "w", encoding="utf-8") as f:
|
||||
f.write(config_file_str)
|
||||
|
||||
print("管理员QQ已设置,如需修改请修改config.py中的admin_qq字段")
|
||||
time.sleep(4)
|
||||
break
|
||||
except ValueError:
|
||||
print("请输入数字")
|
||||
|
||||
import pkg.openai.manager
|
||||
import pkg.database.manager
|
||||
import pkg.openai.session
|
||||
import pkg.qqbot.manager
|
||||
import pkg.openai.dprompt
|
||||
|
||||
pkg.openai.dprompt.read_prompt_from_file()
|
||||
pkg.openai.dprompt.read_scenario_from_file()
|
||||
|
||||
pkg.utils.context.context['logger_handler'] = sh
|
||||
# 主启动流程
|
||||
@@ -94,9 +198,17 @@ def main(first_time_init=False):
|
||||
# 初始化qq机器人
|
||||
qqbot = pkg.qqbot.manager.QQBotManager(mirai_http_api_config=config.mirai_http_api_config,
|
||||
timeout=config.process_message_timeout, retry=config.retry_times,
|
||||
first_time_init=first_time_init)
|
||||
first_time_init=first_time_init, pool_num=config.pool_num)
|
||||
|
||||
if first_time_init: # 不是热重载之后的启动,则不启动新的bot线程
|
||||
# 加载插件
|
||||
import pkg.plugin.host
|
||||
pkg.plugin.host.load_plugins()
|
||||
|
||||
pkg.plugin.host.initialize_plugins()
|
||||
|
||||
if first_time_init: # 不是热重载之后的启动,则启动新的bot线程
|
||||
|
||||
import mirai.exceptions
|
||||
|
||||
def run_bot_wrapper():
|
||||
global known_exception_caught
|
||||
@@ -105,16 +217,16 @@ def main(first_time_init=False):
|
||||
except TypeError as e:
|
||||
if str(e).__contains__("argument 'debug'"):
|
||||
logging.error(
|
||||
"连接bot失败:{}, 请查看 https://github.com/RockChinQ/QChatGPT/issues/82".format(e))
|
||||
"连接bot失败:{}, 解决方案: https://github.com/RockChinQ/QChatGPT/issues/82".format(e))
|
||||
known_exception_caught = True
|
||||
elif str(e).__contains__("As of 3.10, the *loop*"):
|
||||
logging.error(
|
||||
"Websockets版本过低:{}, 请查看 https://github.com/RockChinQ/QChatGPT/issues/5".format(e))
|
||||
"Websockets版本过低:{}, 解决方案: https://github.com/RockChinQ/QChatGPT/issues/5".format(e))
|
||||
known_exception_caught = True
|
||||
|
||||
except websockets.exceptions.InvalidStatus as e:
|
||||
logging.error(
|
||||
"mirai-api-http端口无法使用:{}, 请查看 https://github.com/RockChinQ/QChatGPT/issues/22".format(
|
||||
"mirai-api-http端口无法使用:{}, 解决方案: https://github.com/RockChinQ/QChatGPT/issues/22".format(
|
||||
e))
|
||||
known_exception_caught = True
|
||||
except mirai.exceptions.NetworkError as e:
|
||||
@@ -123,18 +235,32 @@ def main(first_time_init=False):
|
||||
except Exception as e:
|
||||
if str(e).__contains__("404"):
|
||||
logging.error(
|
||||
"mirai-api-http端口无法使用:{}, 请查看 https://github.com/RockChinQ/QChatGPT/issues/22".format(
|
||||
"mirai-api-http端口无法使用:{}, 解决方案: https://github.com/RockChinQ/QChatGPT/issues/22".format(
|
||||
e))
|
||||
known_exception_caught = True
|
||||
elif str(e).__contains__("signal only works in main thread"):
|
||||
logging.error(
|
||||
"hypercorn异常:{}, 解决方案: https://github.com/RockChinQ/QChatGPT/issues/86".format(
|
||||
e))
|
||||
known_exception_caught = True
|
||||
elif str(e).__contains__("did not receive a valid HTTP"):
|
||||
logging.error(
|
||||
"mirai-api-http端口无法使用:{}, 解决方案: https://github.com/RockChinQ/QChatGPT/issues/22".format(
|
||||
e))
|
||||
else:
|
||||
logging.error(
|
||||
"捕捉到未知异常:{}, 请前往 https://github.com/RockChinQ/issues 查找或提issue".format(e))
|
||||
"捕捉到未知异常:{}, 请前往 https://github.com/RockChinQ/QChatGPT/issues 查找或提issue".format(e))
|
||||
known_exception_caught = True
|
||||
raise e
|
||||
|
||||
qq_bot_thread = threading.Thread(target=run_bot_wrapper, args=(), daemon=True)
|
||||
qq_bot_thread.start()
|
||||
finally:
|
||||
# 判断若是Windows,输出选择模式可能会暂停程序的警告
|
||||
if os.name == 'nt':
|
||||
time.sleep(2)
|
||||
logging.info("您正在使用Windows系统,若命令行窗口处于“选择”模式,程序可能会被暂停,此时请右键点击窗口空白区域使其取消选择模式。")
|
||||
|
||||
time.sleep(12)
|
||||
if first_time_init:
|
||||
if not known_exception_caught:
|
||||
@@ -145,17 +271,36 @@ def main(first_time_init=False):
|
||||
else:
|
||||
logging.info('热重载完成')
|
||||
|
||||
while True:
|
||||
try:
|
||||
time.sleep(10)
|
||||
if qqbot != pkg.utils.context.get_qqbot_manager(): # 已经reload了
|
||||
logging.info("以前的main流程由于reload退出")
|
||||
break
|
||||
except KeyboardInterrupt:
|
||||
stop()
|
||||
# 发送赞赏码
|
||||
if hasattr(config, 'encourage_sponsor_at_start') \
|
||||
and config.encourage_sponsor_at_start \
|
||||
and pkg.utils.context.get_openai_manager().audit_mgr.get_total_text_length() >= 2048:
|
||||
|
||||
print("程序退出")
|
||||
sys.exit(0)
|
||||
logging.info("发送赞赏码")
|
||||
from mirai import MessageChain, Plain, Image
|
||||
import pkg.utils.constants
|
||||
message_chain = MessageChain([
|
||||
Plain("自2022年12月初以来,开发者已经花费了大量时间和精力来维护本项目,如果您觉得本项目对您有帮助,欢迎赞赏开发者,"
|
||||
"以支持项目稳定运行😘"),
|
||||
Image(base64=pkg.utils.constants.alipay_qr_b64),
|
||||
Image(base64=pkg.utils.constants.wechat_qr_b64),
|
||||
Plain("BTC: 3N4Azee63vbBB9boGv9Rjf4N5SocMe5eCq\nXMR: 89LS21EKQuDGkyQoe2nDupiuWXk4TVD6FALvSKv5owfmeJEPFpHeMsZLYtLiJ6GxLrhsRe5gMs6MyMSDn4GNQAse2Mae4KE\n\n"),
|
||||
Plain("(本消息仅在启动时发送至管理员,如果您不想再看到此消息,请在config.py中将encourage_sponsor_at_start设置为False)")
|
||||
])
|
||||
pkg.utils.context.get_qqbot_manager().notify_admin_message_chain(message_chain)
|
||||
|
||||
time.sleep(5)
|
||||
import pkg.utils.updater
|
||||
try:
|
||||
if pkg.utils.updater.is_new_version_available():
|
||||
pkg.utils.context.get_qqbot_manager().notify_admin("新版本可用,请发送 !update 进行自动更新\n更新日志:\n{}".format("\n".join(pkg.utils.updater.get_rls_notes())))
|
||||
else:
|
||||
logging.info("当前已是最新版本")
|
||||
|
||||
except Exception as e:
|
||||
logging.warning("检查更新失败:{}".format(e))
|
||||
|
||||
return qqbot
|
||||
|
||||
|
||||
def stop():
|
||||
@@ -163,10 +308,12 @@ def stop():
|
||||
import pkg.qqbot.manager
|
||||
import pkg.openai.session
|
||||
try:
|
||||
import pkg.plugin.host
|
||||
pkg.plugin.host.unload_plugins()
|
||||
|
||||
qqbot_inst = pkg.utils.context.get_qqbot_manager()
|
||||
assert isinstance(qqbot_inst, pkg.qqbot.manager.QQBotManager)
|
||||
|
||||
pkg.utils.context.get_openai_manager().key_mgr.dump_fee()
|
||||
for session in pkg.openai.session.sessions:
|
||||
logging.info('持久化session: %s', session)
|
||||
pkg.openai.session.sessions[session].persistence()
|
||||
@@ -183,22 +330,48 @@ if __name__ == '__main__':
|
||||
print('请先在config.py中填写配置')
|
||||
sys.exit(0)
|
||||
|
||||
# 检查是否有banlist.py,如果没有就把banlist-template.py复制一份
|
||||
if not os.path.exists('banlist.py'):
|
||||
shutil.copy('banlist-template.py', 'banlist.py')
|
||||
|
||||
# 检查是否有sensitive.json
|
||||
if not os.path.exists("sensitive.json"):
|
||||
shutil.copy("sensitive-template.json", "sensitive.json")
|
||||
|
||||
# 检查是否有scenario/default.json
|
||||
if not os.path.exists("scenario/default.json"):
|
||||
shutil.copy("scenario/default-template.json", "scenario/default.json")
|
||||
|
||||
# 检查temp目录
|
||||
if not os.path.exists("temp/"):
|
||||
os.mkdir("temp/")
|
||||
|
||||
# 检查并创建plugins、prompts目录
|
||||
check_path = ["plugins", "prompts"]
|
||||
for path in check_path:
|
||||
if not os.path.exists(path):
|
||||
os.mkdir(path)
|
||||
|
||||
if len(sys.argv) > 1 and sys.argv[1] == 'init_db':
|
||||
init_db()
|
||||
sys.exit(0)
|
||||
|
||||
elif len(sys.argv) > 1 and sys.argv[1] == 'update':
|
||||
try:
|
||||
from dulwich import porcelain
|
||||
|
||||
repo = porcelain.open_repo('.')
|
||||
porcelain.pull(repo)
|
||||
except ModuleNotFoundError:
|
||||
print("dulwich模块未安装,请查看 https://github.com/RockChinQ/QChatGPT/issues/77")
|
||||
print("正在进行程序更新...")
|
||||
import pkg.utils.updater as updater
|
||||
updater.update_all(cli=True)
|
||||
sys.exit(0)
|
||||
|
||||
# import pkg.utils.configmgr
|
||||
#
|
||||
# pkg.utils.configmgr.set_config_and_reload("quote_origin", False)
|
||||
requests.packages.urllib3.disable_warnings(InsecureRequestWarning)
|
||||
|
||||
main(True)
|
||||
qqbot = main(True)
|
||||
|
||||
import pkg.utils.context
|
||||
while True:
|
||||
try:
|
||||
time.sleep(10)
|
||||
except KeyboardInterrupt:
|
||||
stop()
|
||||
|
||||
print("程序退出")
|
||||
sys.exit(0)
|
||||
|
||||
@@ -0,0 +1,3 @@
|
||||
"""
|
||||
审计相关操作
|
||||
"""
|
||||
@@ -1,3 +1,7 @@
|
||||
"""
|
||||
使用量统计以及数据上报功能实现
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
@@ -5,14 +9,16 @@ import logging
|
||||
import requests
|
||||
|
||||
import pkg.utils.context
|
||||
|
||||
version = "0.1.0"
|
||||
import pkg.utils.updater
|
||||
|
||||
|
||||
class DataGatherer:
|
||||
"""数据收集器"""
|
||||
|
||||
usage = {}
|
||||
"""以key值md5为key,{
|
||||
"""各api-key的使用量
|
||||
|
||||
以key值md5为key,{
|
||||
"text": {
|
||||
"text-davinci-003": 文字量:int,
|
||||
},
|
||||
@@ -21,22 +27,38 @@ class DataGatherer:
|
||||
}
|
||||
}为值的字典"""
|
||||
|
||||
version_str = "undetermined"
|
||||
|
||||
def __init__(self):
|
||||
self.load_from_db()
|
||||
try:
|
||||
self.version_str = pkg.utils.updater.get_current_tag() # 从updater模块获取版本号
|
||||
except:
|
||||
pass
|
||||
|
||||
def report_to_server(self, subservice_name: str, count: int):
|
||||
"""向中央服务器报告使用量
|
||||
|
||||
只会报告此次请求的使用量,不会报告总量。
|
||||
不包含除版本号、使用类型、使用量以外的任何信息,仅供开发者分析使用情况。
|
||||
"""
|
||||
try:
|
||||
config = pkg.utils.context.get_config()
|
||||
if hasattr(config, "report_usage") and not config.report_usage:
|
||||
return
|
||||
res = requests.get("http://rockchin.top:18989/usage?service_name=qchatgpt.{}&version={}&count={}".format(subservice_name, version, count))
|
||||
res = requests.get("http://rockchin.top:18989/usage?service_name=qchatgpt.{}&version={}&count={}".format(subservice_name, self.version_str, count))
|
||||
if res.status_code != 200 or res.text != "ok":
|
||||
logging.warning("report to server failed, status_code: {}, text: {}".format(res.status_code, res.text))
|
||||
except:
|
||||
return
|
||||
|
||||
def report_text_model_usage(self, model, text):
|
||||
key_md5 = pkg.utils.context.get_openai_manager().key_mgr.get_using_key_md5()
|
||||
def get_usage(self, key_md5):
|
||||
return self.usage[key_md5] if key_md5 in self.usage else {}
|
||||
|
||||
def report_text_model_usage(self, model, total_tokens):
|
||||
"""调用方报告文字模型请求文字使用量"""
|
||||
|
||||
key_md5 = pkg.utils.context.get_openai_manager().key_mgr.get_using_key_md5() # 以key的md5进行储存
|
||||
|
||||
if key_md5 not in self.usage:
|
||||
self.usage[key_md5] = {}
|
||||
@@ -47,13 +69,15 @@ class DataGatherer:
|
||||
if model not in self.usage[key_md5]["text"]:
|
||||
self.usage[key_md5]["text"][model] = 0
|
||||
|
||||
length = int((len(text.encode('utf-8')) - len(text)) / 2 + len(text))
|
||||
length = total_tokens
|
||||
self.usage[key_md5]["text"][model] += length
|
||||
self.dump_to_db()
|
||||
|
||||
self.report_to_server("text", length)
|
||||
|
||||
def report_image_model_usage(self, size):
|
||||
"""调用方报告图片模型请求图片使用量"""
|
||||
|
||||
key_md5 = pkg.utils.context.get_openai_manager().key_mgr.get_using_key_md5()
|
||||
|
||||
if key_md5 not in self.usage:
|
||||
@@ -71,6 +95,7 @@ class DataGatherer:
|
||||
self.report_to_server("image", 1)
|
||||
|
||||
def get_text_length_of_key(self, key):
|
||||
"""获取指定api-key (明文) 的文字总使用量(本地记录)"""
|
||||
key_md5 = hashlib.md5(key.encode('utf-8')).hexdigest()
|
||||
if key_md5 not in self.usage:
|
||||
return 0
|
||||
@@ -80,6 +105,8 @@ class DataGatherer:
|
||||
return sum(self.usage[key_md5]["text"].values())
|
||||
|
||||
def get_image_count_of_key(self, key):
|
||||
"""获取指定api-key (明文) 的图片总使用量(本地记录)"""
|
||||
|
||||
key_md5 = hashlib.md5(key.encode('utf-8')).hexdigest()
|
||||
if key_md5 not in self.usage:
|
||||
return 0
|
||||
@@ -88,6 +115,15 @@ class DataGatherer:
|
||||
# 遍历其中所有模型,求和
|
||||
return sum(self.usage[key_md5]["image"].values())
|
||||
|
||||
def get_total_text_length(self):
|
||||
"""获取所有api-key的文字总使用量(本地记录)"""
|
||||
total = 0
|
||||
for key in self.usage:
|
||||
if "text" not in self.usage[key]:
|
||||
continue
|
||||
total += sum(self.usage[key]["text"].values())
|
||||
return total
|
||||
|
||||
def dump_to_db(self):
|
||||
pkg.utils.context.get_database_manager().dump_usage_json(self.usage)
|
||||
|
||||
|
||||
@@ -0,0 +1,3 @@
|
||||
"""
|
||||
数据库操作封装
|
||||
"""
|
||||
@@ -1,3 +1,6 @@
|
||||
"""
|
||||
数据库管理模块
|
||||
"""
|
||||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
@@ -9,9 +12,9 @@ import sqlite3
|
||||
import pkg.utils.context
|
||||
|
||||
|
||||
# 数据库管理
|
||||
# 为其他模块提供数据库操作接口
|
||||
class DatabaseManager:
|
||||
"""封装数据库底层操作,并提供方法给上层使用"""
|
||||
|
||||
conn = None
|
||||
cursor = None
|
||||
|
||||
@@ -23,21 +26,25 @@ class DatabaseManager:
|
||||
|
||||
# 连接到数据库文件
|
||||
def reconnect(self):
|
||||
"""连接到数据库"""
|
||||
self.conn = sqlite3.connect('database.db', check_same_thread=False)
|
||||
self.cursor = self.conn.cursor()
|
||||
|
||||
def close(self):
|
||||
self.conn.close()
|
||||
|
||||
def execute(self, *args, **kwargs) -> Cursor:
|
||||
def __execute__(self, *args, **kwargs) -> Cursor:
|
||||
# logging.debug('SQL: {}'.format(sql))
|
||||
logging.debug('SQL: {}'.format(args))
|
||||
c = self.cursor.execute(*args, **kwargs)
|
||||
self.conn.commit()
|
||||
return c
|
||||
|
||||
# 初始化数据库的函数
|
||||
def initialize_database(self):
|
||||
self.execute("""
|
||||
"""创建数据表"""
|
||||
|
||||
self.__execute__("""
|
||||
create table if not exists `sessions` (
|
||||
`id` INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
`name` varchar(255) not null,
|
||||
@@ -46,20 +53,24 @@ class DatabaseManager:
|
||||
`create_timestamp` bigint not null,
|
||||
`last_interact_timestamp` bigint not null,
|
||||
`status` varchar(255) not null default 'on_going',
|
||||
`default_prompt` text not null default '',
|
||||
`prompt` text not null
|
||||
)
|
||||
""")
|
||||
|
||||
# self.execute("""
|
||||
# create table if not exists `api_key_usage`(
|
||||
# `id` INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
# `key_md5` varchar(255) not null,
|
||||
# `timestamp` bigint not null,
|
||||
# `usage` bigint not null
|
||||
# )
|
||||
# """)
|
||||
# 检查sessions表是否存在`default_prompt`字段
|
||||
self.__execute__("PRAGMA table_info('sessions')")
|
||||
columns = self.cursor.fetchall()
|
||||
has_default_prompt = False
|
||||
for field in columns:
|
||||
if field[1] == 'default_prompt':
|
||||
has_default_prompt = True
|
||||
break
|
||||
if not has_default_prompt:
|
||||
self.__execute__("alter table `sessions` add column `default_prompt` text not null default ''")
|
||||
|
||||
self.execute("""
|
||||
|
||||
self.__execute__("""
|
||||
create table if not exists `account_fee`(
|
||||
`id` INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
`key_md5` varchar(255) not null,
|
||||
@@ -68,7 +79,7 @@ class DatabaseManager:
|
||||
)
|
||||
""")
|
||||
|
||||
self.execute("""
|
||||
self.__execute__("""
|
||||
create table if not exists `account_usage`(
|
||||
`id` INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
`json` text not null
|
||||
@@ -78,47 +89,49 @@ class DatabaseManager:
|
||||
|
||||
# session持久化
|
||||
def persistence_session(self, subject_type: str, subject_number: int, create_timestamp: int,
|
||||
last_interact_timestamp: int, prompt: str):
|
||||
last_interact_timestamp: int, prompt: str, default_prompt: str = ''):
|
||||
"""持久化指定session"""
|
||||
|
||||
# 检查是否已经有了此name和create_timestamp的session
|
||||
# 如果有,就更新prompt和last_interact_timestamp
|
||||
# 如果没有,就插入一条新的记录
|
||||
self.execute("""
|
||||
self.__execute__("""
|
||||
select count(*) from `sessions` where `type` = '{}' and `number` = {} and `create_timestamp` = {}
|
||||
""".format(subject_type, subject_number, create_timestamp))
|
||||
count = self.cursor.fetchone()[0]
|
||||
if count == 0:
|
||||
|
||||
sql = """
|
||||
insert into `sessions` (`name`, `type`, `number`, `create_timestamp`, `last_interact_timestamp`, `prompt`)
|
||||
values (?, ?, ?, ?, ?, ?)
|
||||
insert into `sessions` (`name`, `type`, `number`, `create_timestamp`, `last_interact_timestamp`, `prompt`, `default_prompt`)
|
||||
values (?, ?, ?, ?, ?, ?, ?)
|
||||
"""
|
||||
|
||||
self.execute(sql,
|
||||
("{}_{}".format(subject_type, subject_number), subject_type, subject_number, create_timestamp,
|
||||
last_interact_timestamp, prompt))
|
||||
self.__execute__(sql,
|
||||
("{}_{}".format(subject_type, subject_number), subject_type, subject_number, create_timestamp,
|
||||
last_interact_timestamp, prompt, default_prompt))
|
||||
else:
|
||||
sql = """
|
||||
update `sessions` set `last_interact_timestamp` = ?, `prompt` = ?
|
||||
where `type` = ? and `number` = ? and `create_timestamp` = ?
|
||||
"""
|
||||
|
||||
self.execute(sql, (last_interact_timestamp, prompt, subject_type,
|
||||
subject_number, create_timestamp))
|
||||
self.__execute__(sql, (last_interact_timestamp, prompt, subject_type,
|
||||
subject_number, create_timestamp))
|
||||
|
||||
# 显式关闭一个session
|
||||
def explicit_close_session(self, session_name: str, create_timestamp: int):
|
||||
self.execute("""
|
||||
self.__execute__("""
|
||||
update `sessions` set `status` = 'explicitly_closed' where `name` = '{}' and `create_timestamp` = {}
|
||||
""".format(session_name, create_timestamp))
|
||||
|
||||
def set_session_ongoing(self, session_name: str, create_timestamp: int):
|
||||
self.execute("""
|
||||
self.__execute__("""
|
||||
update `sessions` set `status` = 'on_going' where `name` = '{}' and `create_timestamp` = {}
|
||||
""".format(session_name, create_timestamp))
|
||||
|
||||
# 设置session为过期
|
||||
def set_session_expired(self, session_name: str, create_timestamp: int):
|
||||
self.execute("""
|
||||
self.__execute__("""
|
||||
update `sessions` set `status` = 'expired' where `name` = '{}' and `create_timestamp` = {}
|
||||
""".format(session_name, create_timestamp))
|
||||
|
||||
@@ -126,8 +139,8 @@ class DatabaseManager:
|
||||
def load_valid_sessions(self) -> dict:
|
||||
# 从数据库中加载所有还没过期的session
|
||||
config = pkg.utils.context.get_config()
|
||||
self.execute("""
|
||||
select `name`, `type`, `number`, `create_timestamp`, `last_interact_timestamp`, `prompt`, `status`
|
||||
self.__execute__("""
|
||||
select `name`, `type`, `number`, `create_timestamp`, `last_interact_timestamp`, `prompt`, `status`, `default_prompt`
|
||||
from `sessions` where `last_interact_timestamp` > {}
|
||||
""".format(int(time.time()) - config.session_expire_time))
|
||||
results = self.cursor.fetchall()
|
||||
@@ -140,6 +153,7 @@ class DatabaseManager:
|
||||
last_interact_timestamp = result[4]
|
||||
prompt = result[5]
|
||||
status = result[6]
|
||||
default_prompt = result[7]
|
||||
|
||||
# 当且仅当最后一个该对象的会话是on_going状态时,才会被加载
|
||||
if status == 'on_going':
|
||||
@@ -148,7 +162,8 @@ class DatabaseManager:
|
||||
'subject_number': subject_number,
|
||||
'create_timestamp': create_timestamp,
|
||||
'last_interact_timestamp': last_interact_timestamp,
|
||||
'prompt': prompt
|
||||
'prompt': prompt,
|
||||
'default_prompt': default_prompt
|
||||
}
|
||||
else:
|
||||
if session_name in sessions:
|
||||
@@ -159,8 +174,8 @@ class DatabaseManager:
|
||||
# 获取此session_name前一个session的数据
|
||||
def last_session(self, session_name: str, cursor_timestamp: int):
|
||||
|
||||
self.execute("""
|
||||
select `name`, `type`, `number`, `create_timestamp`, `last_interact_timestamp`, `prompt`, `status`
|
||||
self.__execute__("""
|
||||
select `name`, `type`, `number`, `create_timestamp`, `last_interact_timestamp`, `prompt`, `status`, `default_prompt`
|
||||
from `sessions` where `name` = '{}' and `last_interact_timestamp` < {} order by `last_interact_timestamp` desc
|
||||
limit 1
|
||||
""".format(session_name, cursor_timestamp))
|
||||
@@ -176,20 +191,22 @@ class DatabaseManager:
|
||||
last_interact_timestamp = result[4]
|
||||
prompt = result[5]
|
||||
status = result[6]
|
||||
default_prompt = result[7]
|
||||
|
||||
return {
|
||||
'subject_type': subject_type,
|
||||
'subject_number': subject_number,
|
||||
'create_timestamp': create_timestamp,
|
||||
'last_interact_timestamp': last_interact_timestamp,
|
||||
'prompt': prompt
|
||||
'prompt': prompt,
|
||||
'default_prompt': default_prompt
|
||||
}
|
||||
|
||||
# 获取此session_name后一个session的数据
|
||||
def next_session(self, session_name: str, cursor_timestamp: int):
|
||||
|
||||
self.execute("""
|
||||
select `name`, `type`, `number`, `create_timestamp`, `last_interact_timestamp`, `prompt`, `status`
|
||||
self.__execute__("""
|
||||
select `name`, `type`, `number`, `create_timestamp`, `last_interact_timestamp`, `prompt`, `status`, `default_prompt`
|
||||
from `sessions` where `name` = '{}' and `last_interact_timestamp` > {} order by `last_interact_timestamp` asc
|
||||
limit 1
|
||||
""".format(session_name, cursor_timestamp))
|
||||
@@ -205,19 +222,21 @@ class DatabaseManager:
|
||||
last_interact_timestamp = result[4]
|
||||
prompt = result[5]
|
||||
status = result[6]
|
||||
default_prompt = result[7]
|
||||
|
||||
return {
|
||||
'subject_type': subject_type,
|
||||
'subject_number': subject_number,
|
||||
'create_timestamp': create_timestamp,
|
||||
'last_interact_timestamp': last_interact_timestamp,
|
||||
'prompt': prompt
|
||||
'prompt': prompt,
|
||||
'default_prompt': default_prompt
|
||||
}
|
||||
|
||||
# 列出与某个对象的所有对话session
|
||||
def list_history(self, session_name: str, capacity: int, page: int, replace: str = ""):
|
||||
self.execute("""
|
||||
select `name`, `type`, `number`, `create_timestamp`, `last_interact_timestamp`, `prompt`, `status`
|
||||
def list_history(self, session_name: str, capacity: int, page: int):
|
||||
self.__execute__("""
|
||||
select `name`, `type`, `number`, `create_timestamp`, `last_interact_timestamp`, `prompt`, `status`, `default_prompt`
|
||||
from `sessions` where `name` = '{}' order by `last_interact_timestamp` desc limit {} offset {}
|
||||
""".format(session_name, capacity, capacity * page))
|
||||
results = self.cursor.fetchall()
|
||||
@@ -230,17 +249,40 @@ class DatabaseManager:
|
||||
last_interact_timestamp = result[4]
|
||||
prompt = result[5]
|
||||
status = result[6]
|
||||
default_prompt = result[7]
|
||||
|
||||
sessions.append({
|
||||
'subject_type': subject_type,
|
||||
'subject_number': subject_number,
|
||||
'create_timestamp': create_timestamp,
|
||||
'last_interact_timestamp': last_interact_timestamp,
|
||||
'prompt': prompt if replace == "" else prompt.replace(replace, "")
|
||||
'prompt': prompt,
|
||||
'default_prompt': default_prompt
|
||||
})
|
||||
|
||||
return sessions
|
||||
|
||||
def delete_history(self, session_name: str, index: int) -> bool:
|
||||
# 删除倒序第index个session
|
||||
# 查找其id再删除
|
||||
self.__execute__("""
|
||||
delete from `sessions` where `id` in (select `id` from `sessions` where `name` = '{}' order by `last_interact_timestamp` desc limit 1 offset {})
|
||||
""".format(session_name, index))
|
||||
|
||||
return self.cursor.rowcount == 1
|
||||
|
||||
def delete_all_history(self, session_name: str) -> bool:
|
||||
self.__execute__("""
|
||||
delete from `sessions` where `name` = '{}'
|
||||
""".format(session_name))
|
||||
return self.cursor.rowcount > 0
|
||||
|
||||
def delete_all_session_history(self) -> bool:
|
||||
self.__execute__("""
|
||||
delete from `sessions`
|
||||
""")
|
||||
return self.cursor.rowcount > 0
|
||||
|
||||
# 将apikey的使用量存进数据库
|
||||
def dump_api_key_usage(self, api_keys: dict, usage: dict):
|
||||
logging.debug('dumping api key usage...')
|
||||
@@ -255,22 +297,22 @@ class DatabaseManager:
|
||||
usage_count = usage[key_md5]
|
||||
# 将使用量存进数据库
|
||||
# 先检查是否已存在
|
||||
self.execute("""
|
||||
self.__execute__("""
|
||||
select count(*) from `api_key_usage` where `key_md5` = '{}'""".format(key_md5))
|
||||
result = self.cursor.fetchone()
|
||||
if result[0] == 0:
|
||||
# 不存在则插入
|
||||
self.execute("""
|
||||
self.__execute__("""
|
||||
insert into `api_key_usage` (`key_md5`, `usage`,`timestamp`) values ('{}', {}, {})
|
||||
""".format(key_md5, usage_count, int(time.time())))
|
||||
else:
|
||||
# 存在则更新,timestamp设置为当前
|
||||
self.execute("""
|
||||
self.__execute__("""
|
||||
update `api_key_usage` set `usage` = {}, `timestamp` = {} where `key_md5` = '{}'
|
||||
""".format(usage_count, int(time.time()), key_md5))
|
||||
|
||||
def load_api_key_usage(self):
|
||||
self.execute("""
|
||||
self.__execute__("""
|
||||
select `key_md5`, `usage` from `api_key_usage`
|
||||
""")
|
||||
results = self.cursor.fetchall()
|
||||
@@ -281,63 +323,25 @@ class DatabaseManager:
|
||||
usage[key_md5] = usage_count
|
||||
return usage
|
||||
|
||||
def dump_api_key_fee(self, api_keys: dict, fee: dict):
|
||||
logging.debug("dumping api key fee...")
|
||||
logging.debug(api_keys)
|
||||
logging.debug(fee)
|
||||
for api_key in api_keys:
|
||||
# 计算key的md5值
|
||||
key_md5 = hashlib.md5(api_keys[api_key].encode('utf-8')).hexdigest()
|
||||
# 获取使用量
|
||||
fee_count = 0
|
||||
if key_md5 in fee:
|
||||
fee_count = fee[key_md5]
|
||||
# 将使用量存进数据库
|
||||
# 先检查是否已存在
|
||||
self.execute("""
|
||||
select count(*) from `account_fee` where `key_md5` = '{}'""".format(key_md5))
|
||||
result = self.cursor.fetchone()
|
||||
if result[0] == 0:
|
||||
# 不存在则插入
|
||||
self.execute("""
|
||||
insert into `account_fee` (`key_md5`, `fee`,`timestamp`) values ('{}', {}, {})
|
||||
""".format(key_md5, fee_count, int(time.time())))
|
||||
else:
|
||||
# 存在则更新,timestamp设置为当前
|
||||
self.execute("""
|
||||
update `account_fee` set `fee` = {}, `timestamp` = {} where `key_md5` = '{}'
|
||||
""".format(fee_count, int(time.time()), key_md5))
|
||||
|
||||
def load_api_key_fee(self):
|
||||
self.execute("""
|
||||
select `key_md5`, `fee` from `account_fee`
|
||||
""")
|
||||
results = self.cursor.fetchall()
|
||||
fee = {}
|
||||
for result in results:
|
||||
key_md5 = result[0]
|
||||
fee_count = result[1]
|
||||
fee[key_md5] = fee_count
|
||||
return fee
|
||||
|
||||
def dump_usage_json(self, usage: dict):
|
||||
|
||||
json_str = json.dumps(usage)
|
||||
self.execute("""
|
||||
self.__execute__("""
|
||||
select count(*) from `account_usage`""")
|
||||
result = self.cursor.fetchone()
|
||||
if result[0] == 0:
|
||||
# 不存在则插入
|
||||
self.execute("""
|
||||
self.__execute__("""
|
||||
insert into `account_usage` (`json`) values ('{}')
|
||||
""".format(json_str))
|
||||
else:
|
||||
# 存在则更新
|
||||
self.execute("""
|
||||
self.__execute__("""
|
||||
update `account_usage` set `json` = '{}' where `id` = 1
|
||||
""".format(json_str))
|
||||
|
||||
def load_usage_json(self):
|
||||
self.execute("""
|
||||
self.__execute__("""
|
||||
select `json` from `account_usage` order by id desc limit 1
|
||||
""")
|
||||
result = self.cursor.fetchone()
|
||||
|
||||
@@ -0,0 +1,2 @@
|
||||
"""OpenAI 接口处理及会话管理相关
|
||||
"""
|
||||
|
||||
121
pkg/openai/dprompt.py
Normal file
121
pkg/openai/dprompt.py
Normal file
@@ -0,0 +1,121 @@
|
||||
# 多情景预设值管理
|
||||
import json
|
||||
import logging
|
||||
|
||||
__current__ = "default"
|
||||
"""当前默认使用的情景预设的名称
|
||||
|
||||
由管理员使用`!default <名称>`指令切换
|
||||
"""
|
||||
|
||||
__prompts_from_files__ = {}
|
||||
"""从文件中读取的情景预设值"""
|
||||
|
||||
__scenario_from_files__ = {}
|
||||
|
||||
|
||||
def read_prompt_from_file():
|
||||
"""从文件读取预设值"""
|
||||
# 读取prompts/目录下的所有文件,以文件名为键,文件内容为值
|
||||
# 保存在__prompts_from_files__中
|
||||
global __prompts_from_files__
|
||||
import os
|
||||
|
||||
__prompts_from_files__ = {}
|
||||
for file in os.listdir("prompts"):
|
||||
with open(os.path.join("prompts", file), encoding="utf-8") as f:
|
||||
__prompts_from_files__[file] = f.read()
|
||||
|
||||
|
||||
def read_scenario_from_file():
|
||||
"""从JSON文件读取情景预设"""
|
||||
global __scenario_from_files__
|
||||
import os
|
||||
|
||||
__scenario_from_files__ = {}
|
||||
for file in os.listdir("scenario"):
|
||||
if file == "default-template.json":
|
||||
continue
|
||||
with open(os.path.join("scenario", file), encoding="utf-8") as f:
|
||||
__scenario_from_files__[file] = json.load(f)
|
||||
|
||||
|
||||
def get_prompt_dict() -> dict:
|
||||
"""获取预设值字典"""
|
||||
import config
|
||||
default_prompt = config.default_prompt
|
||||
if type(default_prompt) == str:
|
||||
default_prompt = {"default": default_prompt}
|
||||
elif type(default_prompt) == dict:
|
||||
pass
|
||||
else:
|
||||
raise TypeError("default_prompt must be str or dict")
|
||||
|
||||
# 将文件中的预设值合并到default_prompt中
|
||||
for key in __prompts_from_files__:
|
||||
default_prompt[key] = __prompts_from_files__[key]
|
||||
|
||||
return default_prompt
|
||||
|
||||
|
||||
def set_current(name):
|
||||
global __current__
|
||||
for key in get_prompt_dict():
|
||||
if key.lower().startswith(name.lower()):
|
||||
__current__ = key
|
||||
return
|
||||
raise KeyError("未找到情景预设: " + name)
|
||||
|
||||
|
||||
def get_current():
|
||||
global __current__
|
||||
return __current__
|
||||
|
||||
|
||||
def set_to_default():
|
||||
global __current__
|
||||
default_dict = get_prompt_dict()
|
||||
|
||||
if "default" in default_dict:
|
||||
__current__ = "default"
|
||||
else:
|
||||
__current__ = list(default_dict.keys())[0]
|
||||
|
||||
|
||||
def get_prompt(name: str = None) -> list:
|
||||
global __scenario_from_files__
|
||||
import config
|
||||
preset_mode = config.preset_mode
|
||||
|
||||
"""获取预设值"""
|
||||
if name is None:
|
||||
name = get_current()
|
||||
|
||||
# JSON预设方式
|
||||
if preset_mode == 'full_scenario':
|
||||
import os
|
||||
|
||||
for key in __scenario_from_files__:
|
||||
if key.lower().startswith(name.lower()):
|
||||
logging.debug('成功加载情景预设从JSON文件: {}'.format(key))
|
||||
return __scenario_from_files__[key]['prompt']
|
||||
|
||||
# 默认预设方式
|
||||
elif preset_mode == 'default':
|
||||
|
||||
default_dict = get_prompt_dict()
|
||||
|
||||
for key in default_dict:
|
||||
if key.lower().startswith(name.lower()):
|
||||
return [
|
||||
{
|
||||
"role": "user",
|
||||
"content": default_dict[key]
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "好的。"
|
||||
}
|
||||
]
|
||||
|
||||
raise KeyError("未找到默认情景预设: " + name)
|
||||
@@ -2,27 +2,29 @@
|
||||
import hashlib
|
||||
import logging
|
||||
|
||||
import pkg.database.manager
|
||||
import pkg.qqbot.manager
|
||||
import pkg.utils.context
|
||||
import pkg.plugin.host as plugin_host
|
||||
import pkg.plugin.models as plugin_models
|
||||
|
||||
|
||||
class KeysManager:
|
||||
api_key = {}
|
||||
|
||||
# api-key的使用量
|
||||
# 其中键为api-key的md5值,值为使用量
|
||||
fee = {}
|
||||
|
||||
api_key_fee_threshold = 18.0
|
||||
"""所有api-key"""
|
||||
|
||||
using_key = ""
|
||||
"""当前使用的api-key
|
||||
"""
|
||||
|
||||
alerted = []
|
||||
"""已提示过超额的key
|
||||
|
||||
记录在此以避免重复提示
|
||||
"""
|
||||
|
||||
# 在此list中的都是经超额报错标记过的api-key
|
||||
# 记录的是key值,仅在运行时有效
|
||||
exceeded = []
|
||||
"""已超额的key
|
||||
|
||||
供自动切换功能识别
|
||||
"""
|
||||
|
||||
def get_using_key(self):
|
||||
return self.using_key
|
||||
@@ -31,13 +33,6 @@ class KeysManager:
|
||||
return hashlib.md5(self.using_key.encode('utf-8')).hexdigest()
|
||||
|
||||
def __init__(self, api_key):
|
||||
# if hasattr(config, 'api_key_usage_threshold'):
|
||||
# self.api_key_usage_threshold = config.api_key_usage_threshold
|
||||
|
||||
config = pkg.utils.context.get_config()
|
||||
if hasattr(config, 'api_key_fee_threshold'):
|
||||
self.api_key_fee_threshold = config.api_key_fee_threshold
|
||||
self.load_fee()
|
||||
|
||||
if type(api_key) is dict:
|
||||
self.api_key = api_key
|
||||
@@ -48,27 +43,32 @@ class KeysManager:
|
||||
elif type(api_key) is list:
|
||||
for i in range(len(api_key)):
|
||||
self.api_key[str(i)] = api_key[i]
|
||||
|
||||
self.auto_switch()
|
||||
# 从usage中删除未加载的api-key的记录
|
||||
# 不删了,也许会运行时添加曾经有记录的api-key
|
||||
|
||||
if 'exceeded_keys' in pkg.utils.context.context and pkg.utils.context.context['exceeded_keys'] is not None:
|
||||
self.exceeded = pkg.utils.context.context['exceeded_keys']
|
||||
self.auto_switch()
|
||||
|
||||
# 根据tested自动切换到可用的api-key
|
||||
# 返回是否切换成功, 切换后的api-key的别名
|
||||
def auto_switch(self) -> (bool, str):
|
||||
self.dump_fee()
|
||||
"""尝试切换api-key
|
||||
|
||||
Returns:
|
||||
是否切换成功, 切换后的api-key的别名
|
||||
"""
|
||||
|
||||
for key_name in self.api_key:
|
||||
if self.api_key[key_name] not in self.exceeded:
|
||||
self.using_key = self.api_key[key_name]
|
||||
|
||||
logging.info("使用api-key:" + key_name)
|
||||
|
||||
# 触发插件事件
|
||||
args = {
|
||||
"key_name": key_name,
|
||||
"key_list": self.api_key.keys()
|
||||
}
|
||||
_ = plugin_host.emit(plugin_models.KeySwitched, **args)
|
||||
|
||||
return True, key_name
|
||||
# if self.get_fee(self.api_key[key_name]) < self.api_key_fee_threshold:
|
||||
# self.using_key = self.api_key[key_name]
|
||||
# logging.info("使用api-key:" + key_name)
|
||||
# return True, key_name
|
||||
|
||||
self.using_key = list(self.api_key.values())[0]
|
||||
logging.info("使用api-key:" + list(self.api_key.keys())[0])
|
||||
@@ -78,14 +78,10 @@ class KeysManager:
|
||||
def add(self, key_name, key):
|
||||
self.api_key[key_name] = key
|
||||
|
||||
# 设置当前使用的api-key使用量超限
|
||||
# 这是在尝试调用api时发生超限异常时调用的
|
||||
def set_current_exceeded(self):
|
||||
# md5 = hashlib.md5(self.using_key.encode('utf-8')).hexdigest()
|
||||
# self.usage[md5] = self.api_key_usage_threshold
|
||||
# self.fee[md5] = self.api_key_fee_threshold
|
||||
"""设置当前使用的api-key使用量超限
|
||||
"""
|
||||
self.exceeded.append(self.using_key)
|
||||
self.dump_fee()
|
||||
|
||||
def get_key_name(self, api_key):
|
||||
"""根据api-key获取其别名"""
|
||||
@@ -93,45 +89,3 @@ class KeysManager:
|
||||
if self.api_key[key_name] == api_key:
|
||||
return key_name
|
||||
return ""
|
||||
|
||||
def get_fee(self, api_key):
|
||||
md5 = hashlib.md5(api_key.encode('utf-8')).hexdigest()
|
||||
if md5 not in self.fee:
|
||||
self.fee[md5] = 0
|
||||
return self.fee[md5]
|
||||
|
||||
def report_fee(self, fee: float) -> bool:
|
||||
logging.debug("report fee:" + str(fee))
|
||||
|
||||
md5 = hashlib.md5(self.using_key.encode('utf-8')).hexdigest()
|
||||
if md5 not in self.fee:
|
||||
self.fee[md5] = 0
|
||||
|
||||
self.fee[md5] += fee
|
||||
|
||||
config = pkg.utils.context.get_config()
|
||||
if self.fee[md5] >= self.api_key_fee_threshold and \
|
||||
hasattr(config, 'auto_switch_api_key') and config.auto_switch_api_key:
|
||||
switch_result, key_name = self.auto_switch()
|
||||
|
||||
# 检查是否切换到新的
|
||||
if switch_result:
|
||||
if key_name not in self.alerted:
|
||||
# 通知管理员
|
||||
pkg.utils.context.get_qqbot_manager().notify_admin("api-key已切换到:" + key_name)
|
||||
self.alerted.append(key_name)
|
||||
return True
|
||||
else:
|
||||
if key_name not in self.alerted:
|
||||
# 通知管理员
|
||||
pkg.utils.context.get_qqbot_manager().notify_admin("api-key已用完,无未使用的api-key可供切换")
|
||||
self.alerted.append(key_name)
|
||||
return False
|
||||
return False
|
||||
|
||||
def dump_fee(self):
|
||||
pkg.utils.context.get_database_manager().dump_api_key_fee(api_keys=self.api_key, fee=self.fee)
|
||||
|
||||
def load_fee(self):
|
||||
self.fee = pkg.utils.context.get_database_manager().load_api_key_fee()
|
||||
logging.info("load fee:" + str(self.fee))
|
||||
|
||||
@@ -3,14 +3,16 @@ import logging
|
||||
import openai
|
||||
|
||||
import pkg.openai.keymgr
|
||||
import pkg.openai.pricing as pricing
|
||||
import pkg.utils.context
|
||||
import pkg.audit.gatherer
|
||||
from pkg.openai.modelmgr import ModelRequest, create_openai_model_request
|
||||
|
||||
|
||||
# 为其他模块提供与OpenAI交互的接口
|
||||
class OpenAIInteract:
|
||||
api_params = {}
|
||||
"""OpenAI 接口封装
|
||||
|
||||
将文字接口和图片接口封装供调用方使用
|
||||
"""
|
||||
|
||||
key_mgr: pkg.openai.keymgr.KeysManager = None
|
||||
|
||||
@@ -21,37 +23,61 @@ class OpenAIInteract:
|
||||
}
|
||||
|
||||
def __init__(self, api_key: str):
|
||||
# self.api_key = api_key
|
||||
|
||||
self.key_mgr = pkg.openai.keymgr.KeysManager(api_key)
|
||||
self.audit_mgr = pkg.audit.gatherer.DataGatherer()
|
||||
|
||||
logging.info("文字总使用量:%d", self.audit_mgr.get_total_text_length())
|
||||
|
||||
openai.api_key = self.key_mgr.get_using_key()
|
||||
|
||||
pkg.utils.context.set_openai_manager(self)
|
||||
|
||||
# 请求OpenAI Completion
|
||||
def request_completion(self, prompt, stop):
|
||||
def request_completion(self, prompts) -> str:
|
||||
"""请求补全接口回复
|
||||
|
||||
Parameters:
|
||||
prompts (str): 提示语
|
||||
|
||||
Returns:
|
||||
str: 回复
|
||||
"""
|
||||
|
||||
config = pkg.utils.context.get_config()
|
||||
response = openai.Completion.create(
|
||||
prompt=prompt,
|
||||
stop=stop,
|
||||
timeout=config.process_message_timeout,
|
||||
|
||||
# 根据模型选择使用的接口
|
||||
ai: ModelRequest = create_openai_model_request(
|
||||
config.completion_api_params['model'],
|
||||
'user',
|
||||
config.openai_config["http_proxy"] if "http_proxy" in config.openai_config else None
|
||||
)
|
||||
ai.request(
|
||||
prompts,
|
||||
**config.completion_api_params
|
||||
)
|
||||
response = ai.get_response()
|
||||
|
||||
self.audit_mgr.report_text_model_usage(config.completion_api_params['model'],
|
||||
prompt + response['choices'][0]['text'])
|
||||
logging.debug("OpenAI response: %s", response)
|
||||
|
||||
switched = self.key_mgr.report_fee(pricing.language_base_price(config.completion_api_params['model'],
|
||||
prompt + response['choices'][0]['text']))
|
||||
if switched:
|
||||
openai.api_key = self.key_mgr.get_using_key()
|
||||
if 'model' in config.completion_api_params:
|
||||
self.audit_mgr.report_text_model_usage(config.completion_api_params['model'],
|
||||
ai.get_total_tokens())
|
||||
elif 'engine' in config.completion_api_params:
|
||||
self.audit_mgr.report_text_model_usage(config.completion_api_params['engine'],
|
||||
response['usage']['total_tokens'])
|
||||
|
||||
return response
|
||||
return ai.get_message()
|
||||
|
||||
def request_image(self, prompt):
|
||||
def request_image(self, prompt) -> dict:
|
||||
"""请求图片接口回复
|
||||
|
||||
Parameters:
|
||||
prompt (str): 提示语
|
||||
|
||||
Returns:
|
||||
dict: 响应
|
||||
"""
|
||||
config = pkg.utils.context.get_config()
|
||||
params = config.image_api_params if hasattr(config, "image_api_params") else self.default_image_api_params
|
||||
|
||||
@@ -63,10 +89,5 @@ class OpenAIInteract:
|
||||
|
||||
self.audit_mgr.report_image_model_usage(params['size'])
|
||||
|
||||
switched = self.key_mgr.report_fee(pricing.image_price(params['size']))
|
||||
|
||||
if switched:
|
||||
openai.api_key = self.key_mgr.get_using_key()
|
||||
|
||||
return response
|
||||
|
||||
|
||||
@@ -1,7 +1,26 @@
|
||||
# 提供与模型交互的抽象接口
|
||||
"""OpenAI 接口底层封装
|
||||
|
||||
目前使用的对话接口有:
|
||||
ChatCompletion - gpt-3.5-turbo 等模型
|
||||
Completion - text-davinci-003 等模型
|
||||
此模块封装此两个接口的请求实现,为上层提供统一的调用方式
|
||||
"""
|
||||
import openai, logging, threading, asyncio
|
||||
import openai.error as aiE
|
||||
|
||||
COMPLETION_MODELS = {
|
||||
'text-davinci-003'
|
||||
'text-davinci-003',
|
||||
'text-davinci-002',
|
||||
'code-davinci-002',
|
||||
'code-cushman-001',
|
||||
'text-curie-001',
|
||||
'text-babbage-001',
|
||||
'text-ada-001',
|
||||
}
|
||||
|
||||
CHAT_COMPLETION_MODELS = {
|
||||
'gpt-3.5-turbo',
|
||||
'gpt-3.5-turbo-0301',
|
||||
}
|
||||
|
||||
EDIT_MODELS = {
|
||||
@@ -13,22 +32,153 @@ IMAGE_MODELS = {
|
||||
}
|
||||
|
||||
|
||||
# ModelManager
|
||||
# 由session包含
|
||||
class ModelMgr(object):
|
||||
class ModelRequest:
|
||||
"""模型接口请求父类"""
|
||||
|
||||
using_completion_model = ""
|
||||
using_edit_model = ""
|
||||
using_image_model = ""
|
||||
can_chat = False
|
||||
runtime: threading.Thread = None
|
||||
ret = {}
|
||||
proxy: str = None
|
||||
request_ready = True
|
||||
error_info: str = "若在没有任何错误的情况下看到这句话,请带着配置文件上报Issues"
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
def __init__(self, model_name, user_name, request_fun, http_proxy:str = None, time_out = None):
|
||||
self.model_name = model_name
|
||||
self.user_name = user_name
|
||||
self.request_fun = request_fun
|
||||
self.time_out = time_out
|
||||
if http_proxy != None:
|
||||
self.proxy = http_proxy
|
||||
openai.proxy = self.proxy
|
||||
self.request_ready = False
|
||||
|
||||
def get_using_completion_model(self):
|
||||
return self.using_completion_model
|
||||
async def __a_request__(self, **kwargs):
|
||||
"""异步请求"""
|
||||
|
||||
def get_using_edit_model(self):
|
||||
return self.using_edit_model
|
||||
try:
|
||||
self.ret:dict = await self.request_fun(**kwargs)
|
||||
self.request_ready = True
|
||||
except aiE.APIConnectionError as e:
|
||||
self.error_info = "{}\n请检查网络连接或代理是否正常".format(e)
|
||||
raise ConnectionError(self.error_info)
|
||||
except ValueError as e:
|
||||
self.error_info = "{}\n该错误可能是由于http_proxy格式设置错误引起的"
|
||||
except Exception as e:
|
||||
self.error_info = "{}\n由于请求异常产生的未知错误,请查看日志".format(e)
|
||||
raise Exception(self.error_info)
|
||||
|
||||
def get_using_image_model(self):
|
||||
return self.using_image_model
|
||||
def request(self, **kwargs):
|
||||
"""向接口发起请求"""
|
||||
|
||||
if self.proxy != None: #异步请求
|
||||
self.request_ready = False
|
||||
loop = asyncio.new_event_loop()
|
||||
self.runtime = threading.Thread(
|
||||
target=loop.run_until_complete,
|
||||
args=(self.__a_request__(**kwargs),)
|
||||
)
|
||||
self.runtime.start()
|
||||
else: #同步请求
|
||||
self.ret = self.request_fun(**kwargs)
|
||||
|
||||
def __msg_handle__(self, msg):
|
||||
"""将prompt dict转换成接口需要的格式"""
|
||||
return msg
|
||||
|
||||
def ret_handle(self):
|
||||
'''
|
||||
API消息返回处理函数
|
||||
若重写该方法,应检查异步线程状态,或在需要检查处super该方法
|
||||
'''
|
||||
if self.runtime != None and isinstance(self.runtime, threading.Thread):
|
||||
self.runtime.join(self.time_out)
|
||||
if self.request_ready:
|
||||
return
|
||||
raise Exception(self.error_info)
|
||||
|
||||
def get_total_tokens(self):
|
||||
try:
|
||||
return self.ret['usage']['total_tokens']
|
||||
except:
|
||||
return 0
|
||||
|
||||
def get_message(self):
|
||||
return self.message
|
||||
|
||||
def get_response(self):
|
||||
return self.ret
|
||||
|
||||
|
||||
class ChatCompletionModel(ModelRequest):
|
||||
"""ChatCompletion接口的请求实现"""
|
||||
|
||||
Chat_role = ['system', 'user', 'assistant']
|
||||
def __init__(self, model_name, user_name, http_proxy:str = None, **kwargs):
|
||||
if http_proxy == None:
|
||||
request_fun = openai.ChatCompletion.create
|
||||
else:
|
||||
request_fun = openai.ChatCompletion.acreate
|
||||
self.can_chat = True
|
||||
super().__init__(model_name, user_name, request_fun, http_proxy, **kwargs)
|
||||
|
||||
def request(self, prompts, **kwargs):
|
||||
prompts = self.__msg_handle__(prompts)
|
||||
kwargs['messages'] = prompts
|
||||
super().request(**kwargs)
|
||||
self.ret_handle()
|
||||
|
||||
def __msg_handle__(self, msgs):
|
||||
temp_msgs = []
|
||||
# 把msgs拷贝进temp_msgs
|
||||
for msg in msgs:
|
||||
temp_msgs.append(msg.copy())
|
||||
return temp_msgs
|
||||
|
||||
def get_message(self):
|
||||
return self.ret["choices"][0]["message"]['content'] #需要时直接加载加快请求速度,降低内存消耗
|
||||
|
||||
|
||||
class CompletionModel(ModelRequest):
|
||||
"""Completion接口的请求实现"""
|
||||
|
||||
def __init__(self, model_name, user_name, http_proxy:str = None, **kwargs):
|
||||
if http_proxy == None:
|
||||
request_fun = openai.Completion.create
|
||||
else:
|
||||
request_fun = openai.Completion.acreate
|
||||
super().__init__(model_name, user_name, request_fun, http_proxy, **kwargs)
|
||||
|
||||
def request(self, prompts, **kwargs):
|
||||
prompts = self.__msg_handle__(prompts)
|
||||
kwargs['prompt'] = prompts
|
||||
super().request(**kwargs)
|
||||
self.ret_handle()
|
||||
|
||||
def __msg_handle__(self, msgs):
|
||||
prompt = ''
|
||||
for msg in msgs:
|
||||
prompt = prompt + "{}: {}\n".format(msg['role'], msg['content'])
|
||||
# for msg in msgs:
|
||||
# if msg['role'] == 'assistant':
|
||||
# prompt = prompt + "{}\n".format(msg['content'])
|
||||
# else:
|
||||
# prompt = prompt + "{}:{}\n".format(msg['role'] , msg['content'])
|
||||
prompt = prompt + "assistant: "
|
||||
return prompt
|
||||
|
||||
def get_message(self):
|
||||
return self.ret["choices"][0]["text"]
|
||||
|
||||
|
||||
def create_openai_model_request(model_name: str, user_name: str = 'user', http_proxy:str = None) -> ModelRequest:
|
||||
"""使用给定的模型名称创建模型请求对象"""
|
||||
if model_name in CHAT_COMPLETION_MODELS:
|
||||
model = ChatCompletionModel(model_name, user_name, http_proxy)
|
||||
elif model_name in COMPLETION_MODELS:
|
||||
model = CompletionModel(model_name, user_name, http_proxy)
|
||||
else :
|
||||
log = "找不到模型[{}],请检查配置文件".format(model_name)
|
||||
logging.error(log)
|
||||
raise IndexError(log)
|
||||
logging.debug("使用接口[{}]创建模型请求[{}]".format(model.__class__.__name__, model_name))
|
||||
return model
|
||||
|
||||
@@ -1,3 +1,6 @@
|
||||
# 计费模块
|
||||
# 已弃用 https://github.com/RockChinQ/QChatGPT/issues/81
|
||||
|
||||
import logging
|
||||
|
||||
pricing = {
|
||||
@@ -1,11 +1,21 @@
|
||||
"""主线使用的会话管理模块
|
||||
|
||||
每个人、每个群单独一个session,session内部保留了对话的上下文,
|
||||
"""
|
||||
|
||||
import logging
|
||||
import threading
|
||||
import time
|
||||
import json
|
||||
|
||||
import pkg.openai.manager
|
||||
import pkg.openai.modelmgr
|
||||
import pkg.database.manager
|
||||
import pkg.utils.context
|
||||
|
||||
import pkg.plugin.host as plugin_host
|
||||
import pkg.plugin.models as plugin_models
|
||||
|
||||
# 运行时保存的所有session
|
||||
sessions = {}
|
||||
|
||||
@@ -15,8 +25,38 @@ class SessionOfflineStatus:
|
||||
EXPLICITLY_CLOSED = 'explicitly_closed'
|
||||
|
||||
|
||||
# 重置session.prompt
|
||||
def reset_session_prompt(session_name, prompt):
|
||||
# 备份原始数据
|
||||
bak_path = 'logs/{}-{}.bak'.format(
|
||||
session_name,
|
||||
time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime())
|
||||
)
|
||||
f = open(bak_path, 'w+')
|
||||
f.write(prompt)
|
||||
f.close()
|
||||
# 生成新数据
|
||||
config = pkg.utils.context.get_config()
|
||||
prompt = [
|
||||
{
|
||||
'role': 'system',
|
||||
'content': config.default_prompt['default'] if type(config.default_prompt) == dict else config.default_prompt
|
||||
}
|
||||
]
|
||||
# 警告
|
||||
logging.warning(
|
||||
"""
|
||||
用户[{}]的数据已被重置,有可能是因为数据版本过旧或存储错误
|
||||
原始数据将备份在:
|
||||
{}""".format(session_name, bak_path)
|
||||
) # 为保证多行文本格式正确故无缩进
|
||||
return prompt
|
||||
|
||||
|
||||
# 从数据加载session
|
||||
def load_sessions():
|
||||
"""从数据库加载sessions"""
|
||||
|
||||
global sessions
|
||||
|
||||
db_inst = pkg.utils.context.get_database_manager()
|
||||
@@ -30,7 +70,13 @@ def load_sessions():
|
||||
temp_session.name = session_name
|
||||
temp_session.create_timestamp = session_data[session_name]['create_timestamp']
|
||||
temp_session.last_interact_timestamp = session_data[session_name]['last_interact_timestamp']
|
||||
temp_session.prompt = session_data[session_name]['prompt']
|
||||
try:
|
||||
temp_session.prompt = json.loads(session_data[session_name]['prompt'])
|
||||
except Exception:
|
||||
temp_session.prompt = reset_session_prompt(session_name, session_data[session_name]['prompt'])
|
||||
temp_session.persistence()
|
||||
temp_session.default_prompt = json.loads(session_data[session_name]['default_prompt']) if \
|
||||
session_data[session_name]['default_prompt'] else []
|
||||
|
||||
sessions[session_name] = temp_session
|
||||
|
||||
@@ -51,38 +97,23 @@ def dump_session(session_name: str):
|
||||
del sessions[session_name]
|
||||
|
||||
|
||||
# def blocked_func(lock: threading.Lock):
|
||||
#
|
||||
# def decorator(func):
|
||||
# def wrapper(*args, **kwargs):
|
||||
# print('lock acquire,{}'.format(lock))
|
||||
# lock.acquire()
|
||||
# try:
|
||||
# return func(*args, **kwargs)
|
||||
# finally:
|
||||
# lock.release()
|
||||
#
|
||||
# return wrapper
|
||||
#
|
||||
# return decorator
|
||||
|
||||
|
||||
# 通用的OpenAI API交互session
|
||||
# session内部保留了对话的上下文,
|
||||
# 收到用户消息后,将上下文提交给OpenAI API生成回复
|
||||
class Session:
|
||||
name = ''
|
||||
|
||||
prompt = ""
|
||||
prompt = []
|
||||
"""使用list来保存会话中的回合"""
|
||||
|
||||
import config
|
||||
|
||||
user_name = config.user_name if hasattr(config, 'user_name') and config.user_name != '' else 'You'
|
||||
bot_name = config.bot_name if hasattr(config, 'bot_name') and config.bot_name != '' else 'Bot'
|
||||
default_prompt = []
|
||||
"""本session的默认prompt"""
|
||||
|
||||
create_timestamp = 0
|
||||
"""会话创建时间"""
|
||||
|
||||
last_interact_timestamp = 0
|
||||
"""上次交互(产生回复)时间"""
|
||||
|
||||
just_switched_to_exist_session = False
|
||||
|
||||
@@ -102,13 +133,14 @@ class Session:
|
||||
logging.debug('{},lock release successfully,{}'.format(self.name, self.response_lock))
|
||||
|
||||
# 从配置文件获取会话预设信息
|
||||
def get_default_prompt(self):
|
||||
config = pkg.utils.context.get_config()
|
||||
user_name = config.user_name if hasattr(config, 'user_name') and config.user_name != '' else 'You'
|
||||
bot_name = config.bot_name if hasattr(config, 'bot_name') and config.bot_name != '' else 'Bot'
|
||||
return user_name + ":{}\n".format(config.default_prompt if hasattr(config, 'default_prompt') \
|
||||
and config.default_prompt != "" else '') + \
|
||||
bot_name + ":好的\n"
|
||||
def get_default_prompt(self, use_default: str = None):
|
||||
import pkg.openai.dprompt as dprompt
|
||||
|
||||
if use_default is None:
|
||||
use_default = dprompt.get_current()
|
||||
|
||||
current_default_prompt = dprompt.get_prompt(use_default)
|
||||
return current_default_prompt
|
||||
|
||||
def __init__(self, name: str):
|
||||
self.name = name
|
||||
@@ -117,7 +149,9 @@ class Session:
|
||||
self.schedule()
|
||||
|
||||
self.response_lock = threading.Lock()
|
||||
self.prompt = self.get_default_prompt()
|
||||
|
||||
self.default_prompt = self.get_default_prompt()
|
||||
logging.debug("prompt is: {}".format(self.default_prompt))
|
||||
|
||||
# 设定检查session最后一次对话是否超过过期时间的计时器
|
||||
def schedule(self):
|
||||
@@ -136,6 +170,17 @@ class Session:
|
||||
config = pkg.utils.context.get_config()
|
||||
if int(time.time()) - self.last_interact_timestamp > config.session_expire_time:
|
||||
logging.info('session {} 已过期'.format(self.name))
|
||||
|
||||
# 触发插件事件
|
||||
args = {
|
||||
'session_name': self.name,
|
||||
'session': self,
|
||||
'session_expire_time': config.session_expire_time
|
||||
}
|
||||
event = pkg.plugin.host.emit(plugin_models.SessionExpired, **args)
|
||||
if event.is_prevented_default():
|
||||
return
|
||||
|
||||
self.reset(expired=True, schedule_new=False)
|
||||
|
||||
# 删除此session
|
||||
@@ -145,24 +190,32 @@ class Session:
|
||||
# 请求回复
|
||||
# 这个函数是阻塞的
|
||||
def append(self, text: str) -> str:
|
||||
"""向session中添加一条消息,返回接口回复"""
|
||||
|
||||
self.last_interact_timestamp = int(time.time())
|
||||
|
||||
# max_rounds = config.prompt_submit_round_amount if hasattr(config, 'prompt_submit_round_amount') else 7
|
||||
# 触发插件事件
|
||||
if not self.prompt:
|
||||
args = {
|
||||
'session_name': self.name,
|
||||
'session': self,
|
||||
'default_prompt': self.default_prompt,
|
||||
}
|
||||
|
||||
event = pkg.plugin.host.emit(plugin_models.SessionFirstMessageReceived, **args)
|
||||
if event.is_prevented_default():
|
||||
return None
|
||||
|
||||
config = pkg.utils.context.get_config()
|
||||
max_rounds = 1000 # 不再限制回合数
|
||||
max_length = config.prompt_submit_length if hasattr(config, "prompt_submit_length") else 1024
|
||||
|
||||
# 向API请求补全
|
||||
response = pkg.utils.context.get_openai_manager().request_completion(
|
||||
self.cut_out(self.prompt + self.user_name + ':' +
|
||||
text + '\n' + self.bot_name + ':',
|
||||
max_rounds, max_length),
|
||||
self.user_name + ':')
|
||||
message = pkg.utils.context.get_openai_manager().request_completion(
|
||||
self.cut_out(text, max_length),
|
||||
)
|
||||
|
||||
self.prompt += self.user_name + ':' + text + '\n' + self.bot_name + ':'
|
||||
# print(response)
|
||||
# 处理回复
|
||||
res_test = response["choices"][0]["text"]
|
||||
# 成功获取,处理回复
|
||||
res_test = message
|
||||
res_ans = res_test
|
||||
|
||||
# 去除开头可能的提示
|
||||
@@ -171,38 +224,59 @@ class Session:
|
||||
del (res_ans_spt[0])
|
||||
res_ans = '\n\n'.join(res_ans_spt)
|
||||
|
||||
self.prompt += "{}".format(res_ans) + '\n'
|
||||
# 将此次对话的双方内容加入到prompt中
|
||||
self.prompt.append({'role': 'user', 'content': text})
|
||||
self.prompt.append({'role': 'assistant', 'content': res_ans})
|
||||
|
||||
if self.just_switched_to_exist_session:
|
||||
self.just_switched_to_exist_session = False
|
||||
self.set_ongoing()
|
||||
|
||||
return res_ans
|
||||
return res_ans if res_ans[0] != '\n' else res_ans[1:]
|
||||
|
||||
# 从尾部截取prompt里不多于max_rounds个回合,长度不大于max_tokens的字符串
|
||||
# 保证都是完整的对话
|
||||
def cut_out(self, prompt: str, max_rounds: int, max_tokens: int) -> str:
|
||||
# 分隔出每个回合
|
||||
rounds_spt_by_user_name = prompt.split(self.user_name + ':')
|
||||
# 删除上一回合并返回上一回合的问题
|
||||
def undo(self) -> str:
|
||||
self.last_interact_timestamp = int(time.time())
|
||||
|
||||
result = ''
|
||||
# 删除最后两个消息
|
||||
if len(self.prompt) < 2:
|
||||
raise Exception('之前无对话,无法撤销')
|
||||
|
||||
checked_rounds = 0
|
||||
# 从后往前遍历,加到result前面,检查result是否符合要求
|
||||
for i in range(len(rounds_spt_by_user_name) - 1, 0, -1):
|
||||
result_temp = self.user_name + ':' + rounds_spt_by_user_name[i] + result
|
||||
checked_rounds += 1
|
||||
question = self.prompt[-2]['content']
|
||||
self.prompt = self.prompt[:-2]
|
||||
|
||||
if checked_rounds > max_rounds:
|
||||
# 返回上一回合的问题
|
||||
return question
|
||||
|
||||
# 构建对话体
|
||||
def cut_out(self, msg: str, max_tokens: int) -> list:
|
||||
"""将现有prompt进行切割处理,使得新的prompt长度不超过max_tokens"""
|
||||
# 如果用户消息长度超过max_tokens,直接返回
|
||||
temp_prompt: list = []
|
||||
temp_prompt += self.default_prompt
|
||||
temp_prompt.append(
|
||||
{
|
||||
'role': 'user',
|
||||
'content': msg
|
||||
}
|
||||
)
|
||||
|
||||
token_count = 0
|
||||
for item in temp_prompt:
|
||||
token_count += len(item['content'])
|
||||
|
||||
# 倒序遍历prompt
|
||||
for i in range(len(self.prompt) - 1, -1, -1):
|
||||
if token_count >= max_tokens:
|
||||
break
|
||||
|
||||
if int((len(result_temp.encode('utf-8')) - len(result_temp)) / 2 + len(result_temp)) > max_tokens:
|
||||
break
|
||||
# 将prompt加到temp_prompt倒数第二个位置
|
||||
temp_prompt.insert(len(self.default_prompt), self.prompt[i])
|
||||
token_count += len(self.prompt[i]['content'])
|
||||
|
||||
result = result_temp
|
||||
logging.debug('cut_out: {}'.format(json.dumps(temp_prompt, ensure_ascii=False, indent=4)))
|
||||
|
||||
logging.debug('cut_out: {}'.format(result))
|
||||
return result
|
||||
return temp_prompt
|
||||
|
||||
# 持久化session
|
||||
def persistence(self):
|
||||
@@ -217,18 +291,29 @@ class Session:
|
||||
subject_number = int(name_spt[1])
|
||||
|
||||
db_inst.persistence_session(subject_type, subject_number, self.create_timestamp, self.last_interact_timestamp,
|
||||
self.prompt)
|
||||
json.dumps(self.prompt), json.dumps(self.default_prompt))
|
||||
|
||||
# 重置session
|
||||
def reset(self, explicit: bool = False, expired: bool = False, schedule_new: bool = True):
|
||||
if self.prompt != self.get_default_prompt():
|
||||
def reset(self, explicit: bool = False, expired: bool = False, schedule_new: bool = True, use_prompt: str = None):
|
||||
if self.prompt:
|
||||
self.persistence()
|
||||
if explicit:
|
||||
# 触发插件事件
|
||||
args = {
|
||||
'session_name': self.name,
|
||||
'session': self
|
||||
}
|
||||
|
||||
# 此事件不支持阻止默认行为
|
||||
_ = pkg.plugin.host.emit(plugin_models.SessionExplicitReset, **args)
|
||||
|
||||
pkg.utils.context.get_database_manager().explicit_close_session(self.name, self.create_timestamp)
|
||||
|
||||
if expired:
|
||||
pkg.utils.context.get_database_manager().set_session_expired(self.name, self.create_timestamp)
|
||||
self.prompt = self.get_default_prompt()
|
||||
|
||||
self.default_prompt = self.get_default_prompt(use_prompt)
|
||||
self.prompt = []
|
||||
self.create_timestamp = int(time.time())
|
||||
self.last_interact_timestamp = int(time.time())
|
||||
self.just_switched_to_exist_session = False
|
||||
@@ -252,7 +337,12 @@ class Session:
|
||||
|
||||
self.create_timestamp = last_one['create_timestamp']
|
||||
self.last_interact_timestamp = last_one['last_interact_timestamp']
|
||||
self.prompt = last_one['prompt']
|
||||
try:
|
||||
self.prompt = json.loads(last_one['prompt'])
|
||||
except json.decoder.JSONDecodeError:
|
||||
self.prompt = reset_session_prompt(self.name, last_one['prompt'])
|
||||
self.persistence()
|
||||
self.default_prompt = json.loads(last_one['default_prompt']) if last_one['default_prompt'] else []
|
||||
|
||||
self.just_switched_to_exist_session = True
|
||||
return self
|
||||
@@ -267,14 +357,24 @@ class Session:
|
||||
|
||||
self.create_timestamp = next_one['create_timestamp']
|
||||
self.last_interact_timestamp = next_one['last_interact_timestamp']
|
||||
self.prompt = next_one['prompt']
|
||||
try:
|
||||
self.prompt = json.loads(next_one['prompt'])
|
||||
except json.decoder.JSONDecodeError:
|
||||
self.prompt = reset_session_prompt(self.name, next_one['prompt'])
|
||||
self.persistence()
|
||||
self.default_prompt = json.loads(next_one['default_prompt']) if next_one['default_prompt'] else []
|
||||
|
||||
self.just_switched_to_exist_session = True
|
||||
return self
|
||||
|
||||
def list_history(self, capacity: int = 10, page: int = 0):
|
||||
return pkg.utils.context.get_database_manager().list_history(self.name, capacity, page,
|
||||
self.get_default_prompt())
|
||||
return pkg.utils.context.get_database_manager().list_history(self.name, capacity, page)
|
||||
|
||||
def delete_history(self, index: int) -> bool:
|
||||
return pkg.utils.context.get_database_manager().delete_history(self.name, index)
|
||||
|
||||
def delete_all_history(self) -> bool:
|
||||
return pkg.utils.context.get_database_manager().delete_all_history(self.name)
|
||||
|
||||
def draw_image(self, prompt: str):
|
||||
return pkg.utils.context.get_openai_manager().request_image(prompt)
|
||||
|
||||
4
pkg/plugin/__init__.py
Normal file
4
pkg/plugin/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
"""插件支持包
|
||||
|
||||
包含插件基类、插件宿主以及部分API接口
|
||||
"""
|
||||
325
pkg/plugin/host.py
Normal file
325
pkg/plugin/host.py
Normal file
@@ -0,0 +1,325 @@
|
||||
# 插件管理模块
|
||||
import asyncio
|
||||
import logging
|
||||
import importlib
|
||||
import os
|
||||
import pkgutil
|
||||
import sys
|
||||
import traceback
|
||||
|
||||
import pkg.utils.context as context
|
||||
import pkg.plugin.switch as switch
|
||||
import pkg.plugin.settings as settings
|
||||
|
||||
from mirai import Mirai
|
||||
|
||||
__plugins__ = {}
|
||||
"""
|
||||
插件列表
|
||||
|
||||
示例:
|
||||
{
|
||||
"example": {
|
||||
"path": "plugins/example/main.py",
|
||||
"enabled: True,
|
||||
"name": "example",
|
||||
"description": "example",
|
||||
"version": "0.0.1",
|
||||
"author": "RockChinQ",
|
||||
"class": <class 'plugins.example.ExamplePlugin'>,
|
||||
"hooks": {
|
||||
"person_message": [
|
||||
<function ExamplePlugin.person_message at 0x0000020E1D1B8D38>
|
||||
]
|
||||
},
|
||||
"instance": None
|
||||
}
|
||||
}"""
|
||||
|
||||
__plugins_order__ = []
|
||||
"""插件顺序"""
|
||||
|
||||
|
||||
def generate_plugin_order():
|
||||
""" 根据__plugin__生成插件初始顺序,无视是否启用 """
|
||||
global __plugins_order__
|
||||
__plugins_order__ = []
|
||||
for plugin_name in __plugins__:
|
||||
__plugins_order__.append(plugin_name)
|
||||
|
||||
|
||||
def iter_plugins():
|
||||
""" 按照顺序迭代插件 """
|
||||
for plugin_name in __plugins_order__:
|
||||
yield __plugins__[plugin_name]
|
||||
|
||||
|
||||
def iter_plugins_name():
|
||||
""" 迭代插件名 """
|
||||
for plugin_name in __plugins_order__:
|
||||
yield plugin_name
|
||||
|
||||
|
||||
__current_module_path__ = ""
|
||||
|
||||
|
||||
def walk_plugin_path(module, prefix='', path_prefix=''):
|
||||
global __current_module_path__
|
||||
"""遍历插件路径"""
|
||||
for item in pkgutil.iter_modules(module.__path__):
|
||||
if item.ispkg:
|
||||
logging.debug("扫描插件包: plugins/{}".format(path_prefix + item.name))
|
||||
walk_plugin_path(__import__(module.__name__ + '.' + item.name, fromlist=['']),
|
||||
prefix + item.name + '.', path_prefix + item.name + '/')
|
||||
else:
|
||||
try:
|
||||
logging.debug("扫描插件模块: plugins/{}".format(path_prefix + item.name + '.py'))
|
||||
__current_module_path__ = "plugins/"+path_prefix + item.name + '.py'
|
||||
|
||||
importlib.import_module(module.__name__ + '.' + item.name)
|
||||
logging.info('加载模块: plugins/{} 成功'.format(path_prefix + item.name + '.py'))
|
||||
except:
|
||||
logging.error('加载模块: plugins/{} 失败: {}'.format(path_prefix + item.name + '.py', sys.exc_info()))
|
||||
traceback.print_exc()
|
||||
|
||||
|
||||
def load_plugins():
|
||||
""" 加载插件 """
|
||||
logging.info("加载插件")
|
||||
PluginHost()
|
||||
walk_plugin_path(__import__('plugins'))
|
||||
|
||||
logging.debug(__plugins__)
|
||||
|
||||
# 加载开关数据
|
||||
switch.load_switch()
|
||||
|
||||
# 生成初始顺序
|
||||
generate_plugin_order()
|
||||
# 加载插件顺序
|
||||
settings.load_settings()
|
||||
|
||||
|
||||
def initialize_plugins():
|
||||
""" 初始化插件 """
|
||||
logging.info("初始化插件")
|
||||
import pkg.plugin.models as models
|
||||
for plugin in iter_plugins():
|
||||
if not plugin['enabled']:
|
||||
continue
|
||||
try:
|
||||
models.__current_registering_plugin__ = plugin['name']
|
||||
plugin['instance'] = plugin["class"](plugin_host=context.get_plugin_host())
|
||||
logging.info("插件 {} 已初始化".format(plugin['name']))
|
||||
except:
|
||||
logging.error("插件{}初始化时发生错误: {}".format(plugin['name'], sys.exc_info()))
|
||||
|
||||
|
||||
def unload_plugins():
|
||||
""" 卸载插件
|
||||
"""
|
||||
# 不再显式卸载插件,因为当程序结束时,插件的析构函数会被系统执行
|
||||
# for plugin in __plugins__.values():
|
||||
# if plugin['enabled'] and plugin['instance'] is not None:
|
||||
# if not hasattr(plugin['instance'], '__del__'):
|
||||
# logging.warning("插件{}没有定义析构函数".format(plugin['name']))
|
||||
# else:
|
||||
# try:
|
||||
# plugin['instance'].__del__()
|
||||
# logging.info("卸载插件: {}".format(plugin['name']))
|
||||
# plugin['instance'] = None
|
||||
# except:
|
||||
# logging.error("插件{}卸载时发生错误: {}".format(plugin['name'], sys.exc_info()))
|
||||
|
||||
|
||||
def install_plugin(repo_url: str):
|
||||
""" 安装插件,从git储存库获取并解决依赖 """
|
||||
try:
|
||||
import pkg.utils.pkgmgr
|
||||
pkg.utils.pkgmgr.ensure_dulwich()
|
||||
except:
|
||||
pass
|
||||
|
||||
try:
|
||||
import dulwich
|
||||
except ModuleNotFoundError:
|
||||
raise Exception("dulwich模块未安装,请查看 https://github.com/RockChinQ/QChatGPT/issues/77")
|
||||
|
||||
from dulwich import porcelain
|
||||
|
||||
logging.info("克隆插件储存库: {}".format(repo_url))
|
||||
repo = porcelain.clone(repo_url, "plugins/"+repo_url.split(".git")[0].split("/")[-1]+"/", checkout=True)
|
||||
|
||||
# 检查此目录是否包含requirements.txt
|
||||
if os.path.exists("plugins/"+repo_url.split(".git")[0].split("/")[-1]+"/requirements.txt"):
|
||||
logging.info("检测到requirements.txt,正在安装依赖")
|
||||
import pkg.utils.pkgmgr
|
||||
pkg.utils.pkgmgr.install_requirements("plugins/"+repo_url.split(".git")[0].split("/")[-1]+"/requirements.txt")
|
||||
|
||||
import main
|
||||
main.reset_logging()
|
||||
|
||||
|
||||
class EventContext:
|
||||
""" 事件上下文 """
|
||||
eid = 0
|
||||
"""事件编号"""
|
||||
|
||||
name = ""
|
||||
|
||||
__prevent_default__ = False
|
||||
""" 是否阻止默认行为 """
|
||||
|
||||
__prevent_postorder__ = False
|
||||
""" 是否阻止后续插件的执行 """
|
||||
|
||||
__return_value__ = {}
|
||||
""" 返回值
|
||||
示例:
|
||||
{
|
||||
"example": [
|
||||
'value1',
|
||||
'value2',
|
||||
3,
|
||||
4,
|
||||
{
|
||||
'key1': 'value1',
|
||||
},
|
||||
['value1', 'value2']
|
||||
]
|
||||
}
|
||||
"""
|
||||
|
||||
def add_return(self, key: str, ret):
|
||||
"""添加返回值"""
|
||||
if key not in self.__return_value__:
|
||||
self.__return_value__[key] = []
|
||||
self.__return_value__[key].append(ret)
|
||||
|
||||
def get_return(self, key: str):
|
||||
"""获取key的所有返回值"""
|
||||
if key in self.__return_value__:
|
||||
return self.__return_value__[key]
|
||||
return None
|
||||
|
||||
def get_return_value(self, key: str):
|
||||
"""获取key的首个返回值"""
|
||||
if key in self.__return_value__:
|
||||
return self.__return_value__[key][0]
|
||||
return None
|
||||
|
||||
def prevent_default(self):
|
||||
"""阻止默认行为"""
|
||||
self.__prevent_default__ = True
|
||||
|
||||
def prevent_postorder(self):
|
||||
"""阻止后续插件执行"""
|
||||
self.__prevent_postorder__ = True
|
||||
|
||||
def is_prevented_default(self):
|
||||
"""是否阻止默认行为"""
|
||||
return self.__prevent_default__
|
||||
|
||||
def is_prevented_postorder(self):
|
||||
"""是否阻止后序插件执行"""
|
||||
return self.__prevent_postorder__
|
||||
|
||||
def __init__(self, name: str):
|
||||
self.name = name
|
||||
self.eid = EventContext.eid
|
||||
self.__prevent_default__ = False
|
||||
self.__prevent_postorder__ = False
|
||||
self.__return_value__ = {}
|
||||
EventContext.eid += 1
|
||||
|
||||
|
||||
def emit(event_name: str, **kwargs) -> EventContext:
|
||||
""" 触发事件 """
|
||||
import pkg.utils.context as context
|
||||
if context.get_plugin_host() is None:
|
||||
return None
|
||||
return context.get_plugin_host().emit(event_name, **kwargs)
|
||||
|
||||
|
||||
class PluginHost:
|
||||
"""插件宿主"""
|
||||
|
||||
def __init__(self):
|
||||
context.set_plugin_host(self)
|
||||
|
||||
def get_runtime_context(self) -> context:
|
||||
"""获取运行时上下文(pkg.utils.context模块的对象)
|
||||
|
||||
此上下文用于和主程序其他模块交互(数据库、QQ机器人、OpenAI接口等)
|
||||
详见pkg.utils.context模块
|
||||
其中的context变量保存了其他重要模块的类对象,可以使用这些对象进行交互
|
||||
"""
|
||||
return context
|
||||
|
||||
def get_bot(self) -> Mirai:
|
||||
"""获取机器人对象"""
|
||||
return context.get_qqbot_manager().bot
|
||||
|
||||
def send_person_message(self, person, message):
|
||||
"""发送私聊消息"""
|
||||
asyncio.run(self.get_bot().send_friend_message(person, message))
|
||||
|
||||
def send_group_message(self, group, message):
|
||||
"""发送群消息"""
|
||||
asyncio.run(self.get_bot().send_group_message(group, message))
|
||||
|
||||
def notify_admin(self, message):
|
||||
"""通知管理员"""
|
||||
context.get_qqbot_manager().notify_admin(message)
|
||||
|
||||
def emit(self, event_name: str, **kwargs) -> EventContext:
|
||||
""" 触发事件 """
|
||||
import json
|
||||
|
||||
event_context = EventContext(event_name)
|
||||
logging.debug("触发事件: {} ({})".format(event_name, event_context.eid))
|
||||
for plugin in iter_plugins():
|
||||
|
||||
if not plugin['enabled']:
|
||||
continue
|
||||
|
||||
# if plugin['instance'] is None:
|
||||
# # 从关闭状态切到开启状态之后,重新加载插件
|
||||
# try:
|
||||
# plugin['instance'] = plugin["class"](plugin_host=self)
|
||||
# logging.info("插件 {} 已初始化".format(plugin['name']))
|
||||
# except:
|
||||
# logging.error("插件 {} 初始化时发生错误: {}".format(plugin['name'], sys.exc_info()))
|
||||
# continue
|
||||
|
||||
if 'hooks' not in plugin or event_name not in plugin['hooks']:
|
||||
continue
|
||||
|
||||
hooks = []
|
||||
if event_name in plugin["hooks"]:
|
||||
hooks = plugin["hooks"][event_name]
|
||||
for hook in hooks:
|
||||
try:
|
||||
already_prevented_default = event_context.is_prevented_default()
|
||||
|
||||
kwargs['host'] = context.get_plugin_host()
|
||||
kwargs['event'] = event_context
|
||||
|
||||
hook(plugin['instance'], **kwargs)
|
||||
|
||||
if event_context.is_prevented_default() and not already_prevented_default:
|
||||
logging.debug("插件 {} 已要求阻止事件 {} 的默认行为".format(plugin['name'], event_name))
|
||||
|
||||
except Exception as e:
|
||||
logging.error("插件{}触发事件{}时发生错误".format(plugin['name'], event_name))
|
||||
logging.error(traceback.format_exc())
|
||||
|
||||
# print("done:{}".format(plugin['name']))
|
||||
if event_context.is_prevented_postorder():
|
||||
logging.debug("插件 {} 阻止了后序插件的执行".format(plugin['name']))
|
||||
break
|
||||
|
||||
logging.debug("事件 {} ({}) 处理完毕,返回值: {}".format(event_name, event_context.eid,
|
||||
event_context.__return_value__))
|
||||
|
||||
return event_context
|
||||
223
pkg/plugin/models.py
Normal file
223
pkg/plugin/models.py
Normal file
@@ -0,0 +1,223 @@
|
||||
import logging
|
||||
|
||||
import pkg.plugin.host as host
|
||||
import pkg.utils.context
|
||||
|
||||
PersonMessageReceived = "person_message_received"
|
||||
"""收到私聊消息时,在判断是否应该响应前触发
|
||||
kwargs:
|
||||
launcher_type: str 发起对象类型(group/person)
|
||||
launcher_id: int 发起对象ID(群号/QQ号)
|
||||
sender_id: int 发送者ID(QQ号)
|
||||
message_chain: mirai.models.message.MessageChain 消息链
|
||||
"""
|
||||
|
||||
GroupMessageReceived = "group_message_received"
|
||||
"""收到群聊消息时,在判断是否应该响应前触发(所有群消息)
|
||||
kwargs:
|
||||
launcher_type: str 发起对象类型(group/person)
|
||||
launcher_id: int 发起对象ID(群号/QQ号)
|
||||
sender_id: int 发送者ID(QQ号)
|
||||
message_chain: mirai.models.message.MessageChain 消息链
|
||||
"""
|
||||
|
||||
PersonNormalMessageReceived = "person_normal_message_received"
|
||||
"""判断为应该处理的私聊普通消息时触发
|
||||
kwargs:
|
||||
launcher_type: str 发起对象类型(group/person)
|
||||
launcher_id: int 发起对象ID(群号/QQ号)
|
||||
sender_id: int 发送者ID(QQ号)
|
||||
text_message: str 消息文本
|
||||
|
||||
returns (optional):
|
||||
alter: str 修改后的消息文本
|
||||
reply: list 回复消息组件列表
|
||||
"""
|
||||
|
||||
PersonCommandSent = "person_command_sent"
|
||||
"""判断为应该处理的私聊指令时触发
|
||||
kwargs:
|
||||
launcher_type: str 发起对象类型(group/person)
|
||||
launcher_id: int 发起对象ID(群号/QQ号)
|
||||
sender_id: int 发送者ID(QQ号)
|
||||
command: str 指令
|
||||
params: list[str] 参数列表
|
||||
text_message: str 完整指令文本
|
||||
is_admin: bool 是否为管理员
|
||||
|
||||
returns (optional):
|
||||
alter: str 修改后的完整指令文本
|
||||
reply: list 回复消息组件列表
|
||||
"""
|
||||
|
||||
GroupNormalMessageReceived = "group_normal_message_received"
|
||||
"""判断为应该处理的群聊普通消息时触发
|
||||
kwargs:
|
||||
launcher_type: str 发起对象类型(group/person)
|
||||
launcher_id: int 发起对象ID(群号/QQ号)
|
||||
sender_id: int 发送者ID(QQ号)
|
||||
text_message: str 消息文本
|
||||
|
||||
returns (optional):
|
||||
alter: str 修改后的消息文本
|
||||
reply: list 回复消息组件列表
|
||||
"""
|
||||
|
||||
GroupCommandSent = "group_command_sent"
|
||||
"""判断为应该处理的群聊指令时触发
|
||||
kwargs:
|
||||
launcher_type: str 发起对象类型(group/person)
|
||||
launcher_id: int 发起对象ID(群号/QQ号)
|
||||
sender_id: int 发送者ID(QQ号)
|
||||
command: str 指令
|
||||
params: list[str] 参数列表
|
||||
text_message: str 完整指令文本
|
||||
is_admin: bool 是否为管理员
|
||||
|
||||
returns (optional):
|
||||
alter: str 修改后的完整指令文本
|
||||
reply: list 回复消息组件列表
|
||||
"""
|
||||
|
||||
NormalMessageResponded = "normal_message_responded"
|
||||
"""获取到对普通消息的文字响应时触发
|
||||
kwargs:
|
||||
launcher_type: str 发起对象类型(group/person)
|
||||
launcher_id: int 发起对象ID(群号/QQ号)
|
||||
sender_id: int 发送者ID(QQ号)
|
||||
session: pkg.openai.session.Session 会话对象
|
||||
prefix: str 回复文字消息的前缀
|
||||
response_text: str 响应文本
|
||||
|
||||
returns (optional):
|
||||
prefix: str 修改后的回复文字消息的前缀
|
||||
reply: list 替换回复消息组件列表
|
||||
"""
|
||||
|
||||
SessionFirstMessageReceived = "session_first_message_received"
|
||||
"""会话被第一次交互时触发
|
||||
kwargs:
|
||||
session_name: str 会话名称(<launcher_type>_<launcher_id>)
|
||||
session: pkg.openai.session.Session 会话对象
|
||||
default_prompt: str 预设值
|
||||
"""
|
||||
|
||||
SessionExplicitReset = "session_reset"
|
||||
"""会话被用户手动重置时触发,此事件不支持阻止默认行为
|
||||
kwargs:
|
||||
session_name: str 会话名称(<launcher_type>_<launcher_id>)
|
||||
session: pkg.openai.session.Session 会话对象
|
||||
"""
|
||||
|
||||
SessionExpired = "session_expired"
|
||||
"""会话过期时触发
|
||||
kwargs:
|
||||
session_name: str 会话名称(<launcher_type>_<launcher_id>)
|
||||
session: pkg.openai.session.Session 会话对象
|
||||
session_expire_time: int 已设置的会话过期时间(秒)
|
||||
"""
|
||||
|
||||
KeyExceeded = "key_exceeded"
|
||||
"""api-key超额时触发
|
||||
kwargs:
|
||||
key_name: str 超额的api-key名称
|
||||
usage: dict 超额的api-key使用情况
|
||||
exceeded_keys: list[str] 超额的api-key列表
|
||||
"""
|
||||
|
||||
KeySwitched = "key_switched"
|
||||
"""api-key超额切换成功时触发,此事件不支持阻止默认行为
|
||||
kwargs:
|
||||
key_name: str 切换成功的api-key名称
|
||||
key_list: list[str] api-key列表
|
||||
"""
|
||||
|
||||
|
||||
def on(event: str):
|
||||
"""注册事件监听器
|
||||
:param
|
||||
event: str 事件名称
|
||||
"""
|
||||
return Plugin.on(event)
|
||||
|
||||
|
||||
__current_registering_plugin__ = ""
|
||||
|
||||
|
||||
class Plugin:
|
||||
"""插件基类"""
|
||||
|
||||
host: host.PluginHost
|
||||
"""插件宿主,提供插件的一些基础功能"""
|
||||
|
||||
@classmethod
|
||||
def on(cls, event):
|
||||
"""事件处理器装饰器
|
||||
|
||||
:param
|
||||
event: 事件类型
|
||||
:return:
|
||||
None
|
||||
"""
|
||||
global __current_registering_plugin__
|
||||
|
||||
def wrapper(func):
|
||||
plugin_hooks = host.__plugins__[__current_registering_plugin__]["hooks"]
|
||||
|
||||
if event not in plugin_hooks:
|
||||
plugin_hooks[event] = []
|
||||
plugin_hooks[event].append(func)
|
||||
|
||||
# print("registering hook: p='{}', e='{}', f={}".format(__current_registering_plugin__, event, func))
|
||||
|
||||
host.__plugins__[__current_registering_plugin__]["hooks"] = plugin_hooks
|
||||
|
||||
return func
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
def register(name: str, description: str, version: str, author: str):
|
||||
"""注册插件, 此函数作为装饰器使用
|
||||
|
||||
Args:
|
||||
name (str): 插件名称
|
||||
description (str): 插件描述
|
||||
version (str): 插件版本
|
||||
author (str): 插件作者
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
global __current_registering_plugin__
|
||||
|
||||
__current_registering_plugin__ = name
|
||||
# print("registering plugin: n='{}', d='{}', v={}, a='{}'".format(name, description, version, author))
|
||||
host.__plugins__[name] = {
|
||||
"name": name,
|
||||
"description": description,
|
||||
"version": version,
|
||||
"author": author,
|
||||
"hooks": {},
|
||||
"path": host.__current_module_path__,
|
||||
"enabled": True,
|
||||
"instance": None,
|
||||
}
|
||||
|
||||
def wrapper(cls: Plugin):
|
||||
cls.name = name
|
||||
cls.description = description
|
||||
cls.version = version
|
||||
cls.author = author
|
||||
cls.host = pkg.utils.context.get_plugin_host()
|
||||
cls.enabled = True
|
||||
cls.path = host.__current_module_path__
|
||||
|
||||
# 存到插件列表
|
||||
host.__plugins__[name]["class"] = cls
|
||||
|
||||
logging.info("插件注册完成: n='{}', d='{}', v={}, a='{}' ({})".format(name, description, version, author, cls))
|
||||
|
||||
return cls
|
||||
|
||||
return wrapper
|
||||
84
pkg/plugin/settings.py
Normal file
84
pkg/plugin/settings.py
Normal file
@@ -0,0 +1,84 @@
|
||||
import json
|
||||
import os
|
||||
|
||||
import pkg.plugin.host as host
|
||||
import logging
|
||||
|
||||
|
||||
def wrapper_dict_from_runtime_context() -> dict:
|
||||
"""从变量中包装settings.json的数据字典"""
|
||||
settings = {
|
||||
"order": []
|
||||
}
|
||||
|
||||
for plugin_name in host.__plugins_order__:
|
||||
settings["order"].append(plugin_name)
|
||||
|
||||
return settings
|
||||
|
||||
|
||||
def apply_settings(settings: dict):
|
||||
"""将settings.json数据应用到变量中"""
|
||||
if "order" in settings:
|
||||
host.__plugins_order__ = settings["order"]
|
||||
|
||||
|
||||
def dump_settings():
|
||||
"""保存settings.json数据"""
|
||||
logging.debug("保存plugins/settings.json数据")
|
||||
|
||||
settings = wrapper_dict_from_runtime_context()
|
||||
|
||||
with open("plugins/settings.json", "w", encoding="utf-8") as f:
|
||||
json.dump(settings, f, indent=4, ensure_ascii=False)
|
||||
|
||||
|
||||
def load_settings():
|
||||
"""加载settings.json数据"""
|
||||
logging.debug("加载plugins/settings.json数据")
|
||||
|
||||
# 读取plugins/settings.json
|
||||
settings = {
|
||||
}
|
||||
|
||||
# 检查文件是否存在
|
||||
if not os.path.exists("plugins/settings.json"):
|
||||
# 不存在则创建
|
||||
with open("plugins/settings.json", "w", encoding="utf-8") as f:
|
||||
json.dump(wrapper_dict_from_runtime_context(), f, indent=4, ensure_ascii=False)
|
||||
|
||||
with open("plugins/settings.json", "r", encoding="utf-8") as f:
|
||||
settings = json.load(f)
|
||||
|
||||
if settings is None:
|
||||
settings = {
|
||||
}
|
||||
|
||||
# 检查每个设置项
|
||||
if "order" not in settings:
|
||||
settings["order"] = []
|
||||
|
||||
settings_modified = False
|
||||
|
||||
settings_copy = settings.copy()
|
||||
|
||||
# 检查settings中多余的插件项
|
||||
|
||||
# order
|
||||
for plugin_name in settings_copy["order"]:
|
||||
if plugin_name not in host.__plugins_order__:
|
||||
settings["order"].remove(plugin_name)
|
||||
settings_modified = True
|
||||
|
||||
# 检查settings中缺少的插件项
|
||||
|
||||
# order
|
||||
for plugin_name in host.__plugins_order__:
|
||||
if plugin_name not in settings_copy["order"]:
|
||||
settings["order"].append(plugin_name)
|
||||
settings_modified = True
|
||||
|
||||
apply_settings(settings)
|
||||
|
||||
if settings_modified:
|
||||
dump_settings()
|
||||
89
pkg/plugin/switch.py
Normal file
89
pkg/plugin/switch.py
Normal file
@@ -0,0 +1,89 @@
|
||||
# 控制插件的开关
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
|
||||
import pkg.plugin.host as host
|
||||
|
||||
|
||||
def wrapper_dict_from_plugin_list() -> dict:
|
||||
""" 将插件列表转换为开关json """
|
||||
switch = {}
|
||||
|
||||
for plugin_name in host.__plugins__:
|
||||
plugin = host.__plugins__[plugin_name]
|
||||
|
||||
switch[plugin_name] = {
|
||||
"path": plugin["path"],
|
||||
"enabled": plugin["enabled"],
|
||||
}
|
||||
|
||||
return switch
|
||||
|
||||
|
||||
def apply_switch(switch: dict):
|
||||
"""将开关数据应用到插件列表中"""
|
||||
# print("将开关数据应用到插件列表中")
|
||||
# print(switch)
|
||||
for plugin_name in switch:
|
||||
host.__plugins__[plugin_name]["enabled"] = switch[plugin_name]["enabled"]
|
||||
|
||||
|
||||
def dump_switch():
|
||||
""" 保存开关数据 """
|
||||
logging.debug("保存开关数据")
|
||||
# 将开关数据写入plugins/switch.json
|
||||
|
||||
switch = wrapper_dict_from_plugin_list()
|
||||
|
||||
with open("plugins/switch.json", "w", encoding="utf-8") as f:
|
||||
json.dump(switch, f, indent=4, ensure_ascii=False)
|
||||
|
||||
|
||||
def load_switch():
|
||||
""" 加载开关数据 """
|
||||
logging.debug("加载开关数据")
|
||||
# 读取plugins/switch.json
|
||||
|
||||
switch = {}
|
||||
|
||||
# 检查文件是否存在
|
||||
if not os.path.exists("plugins/switch.json"):
|
||||
# 不存在则创建
|
||||
with open("plugins/switch.json", "w", encoding="utf-8") as f:
|
||||
json.dump(switch, f, indent=4, ensure_ascii=False)
|
||||
|
||||
with open("plugins/switch.json", "r", encoding="utf-8") as f:
|
||||
switch = json.load(f)
|
||||
|
||||
if switch is None:
|
||||
switch = {}
|
||||
|
||||
switch_modified = False
|
||||
|
||||
switch_copy = switch.copy()
|
||||
# 检查switch中多余的和path不相符的
|
||||
for plugin_name in switch_copy:
|
||||
if plugin_name not in host.__plugins__:
|
||||
del switch[plugin_name]
|
||||
switch_modified = True
|
||||
elif switch[plugin_name]["path"] != host.__plugins__[plugin_name]["path"]:
|
||||
# 删除此不相符的
|
||||
del switch[plugin_name]
|
||||
switch_modified = True
|
||||
|
||||
# 检查plugin中多余的
|
||||
for plugin_name in host.__plugins__:
|
||||
if plugin_name not in switch:
|
||||
switch[plugin_name] = {
|
||||
"path": host.__plugins__[plugin_name]["path"],
|
||||
"enabled": host.__plugins__[plugin_name]["enabled"],
|
||||
}
|
||||
switch_modified = True
|
||||
|
||||
# 应用开关数据
|
||||
apply_switch(switch)
|
||||
|
||||
# 如果switch有修改,保存
|
||||
if switch_modified:
|
||||
dump_switch()
|
||||
50
pkg/qqbot/banlist.py
Normal file
50
pkg/qqbot/banlist.py
Normal file
@@ -0,0 +1,50 @@
|
||||
import pkg.utils.context
|
||||
|
||||
|
||||
def is_banned(launcher_type: str, launcher_id: int, sender_id: int) -> bool:
|
||||
if not pkg.utils.context.get_qqbot_manager().enable_banlist:
|
||||
return False
|
||||
|
||||
result = False
|
||||
|
||||
if launcher_type == 'group':
|
||||
# 检查是否显式声明发起人QQ要被person忽略
|
||||
if sender_id in pkg.utils.context.get_qqbot_manager().ban_person:
|
||||
result = True
|
||||
else:
|
||||
for group_rule in pkg.utils.context.get_qqbot_manager().ban_group:
|
||||
if type(group_rule) == int:
|
||||
if group_rule == launcher_id: # 此群群号被禁用
|
||||
result = True
|
||||
elif type(group_rule) == str:
|
||||
if group_rule.startswith('!'):
|
||||
# 截取!后面的字符串作为表达式,判断是否匹配
|
||||
reg_str = group_rule[1:]
|
||||
import re
|
||||
if re.match(reg_str, str(launcher_id)): # 被豁免,最高级别
|
||||
result = False
|
||||
break
|
||||
else:
|
||||
# 判断是否匹配regexp
|
||||
import re
|
||||
if re.match(group_rule, str(launcher_id)): # 此群群号被禁用
|
||||
result = True
|
||||
|
||||
else:
|
||||
# ban_person, 与群规则相同
|
||||
for person_rule in pkg.utils.context.get_qqbot_manager().ban_person:
|
||||
if type(person_rule) == int:
|
||||
if person_rule == launcher_id:
|
||||
result = True
|
||||
elif type(person_rule) == str:
|
||||
if person_rule.startswith('!'):
|
||||
reg_str = person_rule[1:]
|
||||
import re
|
||||
if re.match(reg_str, str(launcher_id)):
|
||||
result = False
|
||||
break
|
||||
else:
|
||||
import re
|
||||
if re.match(person_rule, str(launcher_id)):
|
||||
result = True
|
||||
return result
|
||||
105
pkg/qqbot/blob.py
Normal file
105
pkg/qqbot/blob.py
Normal file
@@ -0,0 +1,105 @@
|
||||
# 长消息处理相关
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
import base64
|
||||
|
||||
import config
|
||||
from mirai.models.message import MessageComponent, MessageChain, Image
|
||||
from mirai.models.message import ForwardMessageNode
|
||||
from mirai.models.base import MiraiBaseModel
|
||||
from typing import List
|
||||
import pkg.utils.context as context
|
||||
import pkg.utils.text2img as text2img
|
||||
|
||||
|
||||
class ForwardMessageDiaplay(MiraiBaseModel):
|
||||
title: str = "群聊的聊天记录"
|
||||
brief: str = "[聊天记录]"
|
||||
source: str = "聊天记录"
|
||||
preview: List[str] = []
|
||||
summary: str = "查看x条转发消息"
|
||||
|
||||
|
||||
class Forward(MessageComponent):
|
||||
"""合并转发。"""
|
||||
type: str = "Forward"
|
||||
"""消息组件类型。"""
|
||||
display: ForwardMessageDiaplay
|
||||
"""显示信息"""
|
||||
node_list: List[ForwardMessageNode]
|
||||
"""转发消息节点列表。"""
|
||||
def __init__(self, *args, **kwargs):
|
||||
if len(args) == 1:
|
||||
self.node_list = args[0]
|
||||
super().__init__(**kwargs)
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def __str__(self):
|
||||
return '[聊天记录]'
|
||||
|
||||
|
||||
def text_to_image(text: str) -> MessageComponent:
|
||||
"""将文本转换成图片"""
|
||||
# 检查temp文件夹是否存在
|
||||
if not os.path.exists('temp'):
|
||||
os.mkdir('temp')
|
||||
img_path = text2img.text_to_image(text_str=text, save_as='temp/{}.png'.format(int(time.time())))
|
||||
|
||||
compressed_path, size = text2img.compress_image(img_path, outfile="temp/{}_compressed.png".format(int(time.time())))
|
||||
# 读取图片,转换成base64
|
||||
with open(compressed_path, 'rb') as f:
|
||||
img = f.read()
|
||||
|
||||
b64 = base64.b64encode(img)
|
||||
|
||||
# 删除图片
|
||||
os.remove(img_path)
|
||||
|
||||
# 判断compressed_path是否存在
|
||||
if os.path.exists(compressed_path):
|
||||
os.remove(compressed_path)
|
||||
# 返回图片
|
||||
return Image(base64=b64.decode('utf-8'))
|
||||
|
||||
|
||||
def check_text(text: str) -> list:
|
||||
"""检查文本是否为长消息,并转换成该使用的消息链组件"""
|
||||
if not hasattr(config, 'blob_message_threshold'):
|
||||
return [text]
|
||||
|
||||
if len(text) > config.blob_message_threshold:
|
||||
if not hasattr(config, 'blob_message_strategy'):
|
||||
raise AttributeError('未定义长消息处理策略')
|
||||
|
||||
# logging.info("长消息: {}".format(text))
|
||||
if config.blob_message_strategy == 'image':
|
||||
# 转换成图片
|
||||
return [text_to_image(text)]
|
||||
elif config.blob_message_strategy == 'forward':
|
||||
# 敏感词屏蔽
|
||||
text = context.get_qqbot_manager().reply_filter.process(text)
|
||||
|
||||
# 包装转发消息
|
||||
display = ForwardMessageDiaplay(
|
||||
title='群聊的聊天记录',
|
||||
brief='[聊天记录]',
|
||||
source='聊天记录',
|
||||
preview=["bot: "+text],
|
||||
summary="查看1条转发消息"
|
||||
)
|
||||
|
||||
node = ForwardMessageNode(
|
||||
sender_id=config.mirai_http_api_config['qq'],
|
||||
sender_name='bot',
|
||||
message_chain=MessageChain([text])
|
||||
)
|
||||
|
||||
forward = Forward(
|
||||
display=display,
|
||||
node_list=[node]
|
||||
)
|
||||
|
||||
return [forward]
|
||||
else:
|
||||
return [text]
|
||||
387
pkg/qqbot/command.py
Normal file
387
pkg/qqbot/command.py
Normal file
@@ -0,0 +1,387 @@
|
||||
# 指令处理模块
|
||||
import logging
|
||||
import json
|
||||
import datetime
|
||||
import os
|
||||
import threading
|
||||
import traceback
|
||||
|
||||
import pkg.openai.session
|
||||
import pkg.openai.manager
|
||||
import pkg.utils.reloader
|
||||
import pkg.utils.updater
|
||||
import pkg.utils.context
|
||||
import pkg.qqbot.message
|
||||
import pkg.utils.credit as credit
|
||||
|
||||
from mirai import Image
|
||||
|
||||
|
||||
def config_operation(cmd, params):
|
||||
reply = []
|
||||
config = pkg.utils.context.get_config()
|
||||
reply_str = ""
|
||||
if len(params) == 0:
|
||||
reply = ["[bot]err:请输入配置项"]
|
||||
else:
|
||||
cfg_name = params[0]
|
||||
if cfg_name == 'all':
|
||||
reply_str = "[bot]所有配置项:\n\n"
|
||||
for cfg in dir(config):
|
||||
if not cfg.startswith('__') and not cfg == 'logging':
|
||||
# 根据配置项类型进行格式化,如果是字典则转换为json并格式化
|
||||
if isinstance(getattr(config, cfg), str):
|
||||
reply_str += "{}: \"{}\"\n".format(cfg, getattr(config, cfg))
|
||||
elif isinstance(getattr(config, cfg), dict):
|
||||
# 不进行unicode转义,并格式化
|
||||
reply_str += "{}: {}\n".format(cfg,
|
||||
json.dumps(getattr(config, cfg),
|
||||
ensure_ascii=False, indent=4))
|
||||
else:
|
||||
reply_str += "{}: {}\n".format(cfg, getattr(config, cfg))
|
||||
reply = [reply_str]
|
||||
elif cfg_name in dir(config):
|
||||
if len(params) == 1:
|
||||
# 按照配置项类型进行格式化
|
||||
if isinstance(getattr(config, cfg_name), str):
|
||||
reply_str = "[bot]配置项{}: \"{}\"\n".format(cfg_name, getattr(config, cfg_name))
|
||||
elif isinstance(getattr(config, cfg_name), dict):
|
||||
reply_str = "[bot]配置项{}: {}\n".format(cfg_name,
|
||||
json.dumps(getattr(config, cfg_name),
|
||||
ensure_ascii=False, indent=4))
|
||||
else:
|
||||
reply_str = "[bot]配置项{}: {}\n".format(cfg_name, getattr(config, cfg_name))
|
||||
reply = [reply_str]
|
||||
else:
|
||||
cfg_value = " ".join(params[1:])
|
||||
# 类型转换,如果是json则转换为字典
|
||||
if cfg_value == 'true':
|
||||
cfg_value = True
|
||||
elif cfg_value == 'false':
|
||||
cfg_value = False
|
||||
elif cfg_value.isdigit():
|
||||
cfg_value = int(cfg_value)
|
||||
elif cfg_value.startswith('{') and cfg_value.endswith('}'):
|
||||
cfg_value = json.loads(cfg_value)
|
||||
else:
|
||||
try:
|
||||
cfg_value = float(cfg_value)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
# 检查类型是否匹配
|
||||
if isinstance(getattr(config, cfg_name), type(cfg_value)):
|
||||
setattr(config, cfg_name, cfg_value)
|
||||
pkg.utils.context.set_config(config)
|
||||
reply = ["[bot]配置项{}修改成功".format(cfg_name)]
|
||||
else:
|
||||
reply = ["[bot]err:配置项{}类型不匹配".format(cfg_name)]
|
||||
|
||||
else:
|
||||
reply = ["[bot]err:未找到配置项 {}".format(cfg_name)]
|
||||
|
||||
return reply
|
||||
|
||||
|
||||
def plugin_operation(cmd, params, is_admin):
|
||||
reply = []
|
||||
|
||||
import pkg.plugin.host as plugin_host
|
||||
import pkg.utils.updater as updater
|
||||
|
||||
plugin_list = plugin_host.__plugins__
|
||||
|
||||
if len(params) == 0:
|
||||
reply_str = "[bot]所有插件({}):\n".format(len(plugin_host.__plugins__))
|
||||
idx = 0
|
||||
for key in plugin_host.iter_plugins_name():
|
||||
plugin = plugin_list[key]
|
||||
reply_str += "\n#{} {} {}\n{}\nv{}\n作者: {}\n"\
|
||||
.format((idx+1), plugin['name'],
|
||||
"[已禁用]" if not plugin['enabled'] else "",
|
||||
plugin['description'],
|
||||
plugin['version'], plugin['author'])
|
||||
|
||||
if updater.is_repo("/".join(plugin['path'].split('/')[:-1])):
|
||||
remote_url = updater.get_remote_url("/".join(plugin['path'].split('/')[:-1]))
|
||||
if remote_url != "https://github.com/RockChinQ/QChatGPT" and remote_url != "https://gitee.com/RockChin/QChatGPT":
|
||||
reply_str += "源码: "+remote_url+"\n"
|
||||
|
||||
idx += 1
|
||||
|
||||
reply = [reply_str]
|
||||
elif params[0] == 'update':
|
||||
# 更新所有插件
|
||||
if is_admin:
|
||||
def closure():
|
||||
import pkg.utils.context
|
||||
updated = []
|
||||
for key in plugin_list:
|
||||
plugin = plugin_list[key]
|
||||
if updater.is_repo("/".join(plugin['path'].split('/')[:-1])):
|
||||
success = updater.pull_latest("/".join(plugin['path'].split('/')[:-1]))
|
||||
if success:
|
||||
updated.append(plugin['name'])
|
||||
|
||||
# 检查是否有requirements.txt
|
||||
pkg.utils.context.get_qqbot_manager().notify_admin("正在安装依赖...")
|
||||
for key in plugin_list:
|
||||
plugin = plugin_list[key]
|
||||
if os.path.exists("/".join(plugin['path'].split('/')[:-1])+"/requirements.txt"):
|
||||
logging.info("{}检测到requirements.txt,安装依赖".format(plugin['name']))
|
||||
import pkg.utils.pkgmgr
|
||||
pkg.utils.pkgmgr.install_requirements("/".join(plugin['path'].split('/')[:-1])+"/requirements.txt")
|
||||
|
||||
import main
|
||||
main.reset_logging()
|
||||
|
||||
pkg.utils.context.get_qqbot_manager().notify_admin("已更新插件: {}".format(", ".join(updated)))
|
||||
|
||||
threading.Thread(target=closure).start()
|
||||
reply = ["[bot]正在更新所有插件,请勿重复发起..."]
|
||||
else:
|
||||
reply = ["[bot]err:权限不足"]
|
||||
elif params[0].startswith("http"):
|
||||
if is_admin:
|
||||
|
||||
def closure():
|
||||
try:
|
||||
plugin_host.install_plugin(params[0])
|
||||
pkg.utils.context.get_qqbot_manager().notify_admin("插件安装成功,请发送 !reload 指令重载插件")
|
||||
except Exception as e:
|
||||
logging.error("插件安装失败:{}".format(e))
|
||||
pkg.utils.context.get_qqbot_manager().notify_admin("插件安装失败:{}".format(e))
|
||||
|
||||
threading.Thread(target=closure, args=()).start()
|
||||
reply = ["[bot]正在安装插件..."]
|
||||
else:
|
||||
reply = ["[bot]err:权限不足,请使用管理员账号私聊发起"]
|
||||
return reply
|
||||
|
||||
|
||||
def process_command(session_name: str, text_message: str, mgr, config,
|
||||
launcher_type: str, launcher_id: int, sender_id: int, is_admin: bool) -> list:
|
||||
reply = []
|
||||
try:
|
||||
logging.info(
|
||||
"[{}]发起指令:{}".format(session_name, text_message[:min(20, len(text_message))] + (
|
||||
"..." if len(text_message) > 20 else "")))
|
||||
|
||||
cmd = text_message[1:].strip().split(' ')[0]
|
||||
|
||||
params = text_message[1:].strip().split(' ')[1:]
|
||||
if cmd == 'help':
|
||||
reply = ["[bot]" + config.help_message]
|
||||
elif cmd == 'reset':
|
||||
if len(params) == 0:
|
||||
pkg.openai.session.get_session(session_name).reset(explicit=True)
|
||||
reply = ["[bot]会话已重置"]
|
||||
else:
|
||||
pkg.openai.session.get_session(session_name).reset(explicit=True, use_prompt=params[0])
|
||||
reply = ["[bot]会话已重置,使用场景预设:{}".format(params[0])]
|
||||
elif cmd == 'last':
|
||||
result = pkg.openai.session.get_session(session_name).last_session()
|
||||
if result is None:
|
||||
reply = ["[bot]没有前一次的对话"]
|
||||
else:
|
||||
datetime_str = datetime.datetime.fromtimestamp(result.create_timestamp).strftime(
|
||||
'%Y-%m-%d %H:%M:%S')
|
||||
reply = ["[bot]已切换到前一次的对话:\n创建时间:{}\n".format(datetime_str)]
|
||||
elif cmd == 'next':
|
||||
result = pkg.openai.session.get_session(session_name).next_session()
|
||||
if result is None:
|
||||
reply = ["[bot]没有后一次的对话"]
|
||||
else:
|
||||
datetime_str = datetime.datetime.fromtimestamp(result.create_timestamp).strftime(
|
||||
'%Y-%m-%d %H:%M:%S')
|
||||
reply = ["[bot]已切换到后一次的对话:\n创建时间:{}\n".format(datetime_str)]
|
||||
elif cmd == 'prompt':
|
||||
msgs = ""
|
||||
session:list = pkg.openai.session.get_session(session_name).prompt
|
||||
for msg in session:
|
||||
if len(params) != 0 and params[0] in ['-all', '-a']:
|
||||
msgs = msgs + "{}: {}\n\n".format(msg['role'], msg['content'])
|
||||
elif len(msg['content']) > 30:
|
||||
msgs = msgs + "[{}]: {}...\n\n".format(msg['role'], msg['content'][:30])
|
||||
else:
|
||||
msgs = msgs + "[{}]: {}\n\n".format(msg['role'], msg['content'])
|
||||
reply = ["[bot]当前对话所有内容:\n{}".format(msgs)]
|
||||
elif cmd == 'list':
|
||||
pkg.openai.session.get_session(session_name).persistence()
|
||||
page = 0
|
||||
|
||||
if len(params) > 0:
|
||||
try:
|
||||
page = int(params[0])
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
results = pkg.openai.session.get_session(session_name).list_history(page=page)
|
||||
if len(results) == 0:
|
||||
reply = ["[bot]第{}页没有历史会话".format(page)]
|
||||
else:
|
||||
reply_str = "[bot]历史会话 第{}页:\n".format(page)
|
||||
current = -1
|
||||
for i in range(len(results)):
|
||||
# 时间(使用create_timestamp转换) 序号 部分内容
|
||||
datetime_obj = datetime.datetime.fromtimestamp(results[i]['create_timestamp'])
|
||||
msg = ""
|
||||
try:
|
||||
msg = json.loads(results[i]['prompt'])
|
||||
except json.decoder.JSONDecodeError:
|
||||
msg = pkg.openai.session.reset_session_prompt(session_name, results[i]['prompt'])
|
||||
# 持久化
|
||||
pkg.openai.session.get_session(session_name).persistence()
|
||||
if len(msg) >= 2:
|
||||
reply_str += "#{} 创建:{} {}\n".format(i + page * 10,
|
||||
datetime_obj.strftime("%Y-%m-%d %H:%M:%S"),
|
||||
msg[0]['content'])
|
||||
else:
|
||||
reply_str += "#{} 创建:{} {}\n".format(i + page * 10,
|
||||
datetime_obj.strftime("%Y-%m-%d %H:%M:%S"),
|
||||
"无内容")
|
||||
if results[i]['create_timestamp'] == pkg.openai.session.get_session(
|
||||
session_name).create_timestamp:
|
||||
current = i + page * 10
|
||||
|
||||
reply_str += "\n以上信息倒序排列"
|
||||
if current != -1:
|
||||
reply_str += ",当前会话是 #{}\n".format(current)
|
||||
else:
|
||||
reply_str += ",当前处于全新会话或不在此页"
|
||||
|
||||
reply = [reply_str]
|
||||
elif cmd == 'resend':
|
||||
session = pkg.openai.session.get_session(session_name)
|
||||
to_send = session.undo()
|
||||
|
||||
reply = pkg.qqbot.message.process_normal_message(to_send, mgr, config,
|
||||
launcher_type, launcher_id, sender_id)
|
||||
elif cmd == 'del': # 删除指定会话历史记录
|
||||
if len(params) == 0:
|
||||
reply = ["[bot]参数不足, 格式: !del <序号>\n可以通过!list查看序号"]
|
||||
else:
|
||||
if params[0] == 'all':
|
||||
pkg.openai.session.get_session(session_name).delete_all_history()
|
||||
reply = ["[bot]已删除所有历史会话"]
|
||||
elif params[0].isdigit():
|
||||
if pkg.openai.session.get_session(session_name).delete_history(int(params[0])):
|
||||
reply = ["[bot]已删除历史会话 #{}".format(params[0])]
|
||||
else:
|
||||
reply = ["[bot]没有历史会话 #{}".format(params[0])]
|
||||
else:
|
||||
reply = ["[bot]参数错误, 格式: !del <序号>\n可以通过!list查看序号"]
|
||||
elif cmd == 'usage':
|
||||
reply_str = "[bot]各api-key使用情况:\n\n"
|
||||
|
||||
api_keys = pkg.utils.context.get_openai_manager().key_mgr.api_key
|
||||
for key_name in api_keys:
|
||||
text_length = pkg.utils.context.get_openai_manager().audit_mgr \
|
||||
.get_text_length_of_key(api_keys[key_name])
|
||||
image_count = pkg.utils.context.get_openai_manager().audit_mgr \
|
||||
.get_image_count_of_key(api_keys[key_name])
|
||||
reply_str += "{}:\n - 文本长度:{}\n - 图片数量:{}\n".format(key_name, int(text_length),
|
||||
int(image_count))
|
||||
# 获取此key的额度
|
||||
try:
|
||||
credit_data = credit.fetch_credit_data(api_keys[key_name])
|
||||
reply_str += " - 使用额度:{:.2f}/{:.2f}\n".format(credit_data['total_used'],credit_data['total_granted'])
|
||||
except Exception as e:
|
||||
logging.warning("获取额度失败:{}".format(e))
|
||||
|
||||
reply = [reply_str]
|
||||
elif cmd == 'draw':
|
||||
if len(params) == 0:
|
||||
reply = ["[bot]err:请输入图片描述文字"]
|
||||
else:
|
||||
session = pkg.openai.session.get_session(session_name)
|
||||
|
||||
res = session.draw_image(" ".join(params))
|
||||
|
||||
logging.debug("draw_image result:{}".format(res))
|
||||
reply = [Image(url=res['data'][0]['url'])]
|
||||
if not (hasattr(config, 'include_image_description')
|
||||
and not config.include_image_description):
|
||||
reply.append(" ".join(params))
|
||||
elif cmd == 'version':
|
||||
reply_str = "[bot]当前版本:\n{}\n".format(pkg.utils.updater.get_current_version_info())
|
||||
try:
|
||||
if pkg.utils.updater.is_new_version_available():
|
||||
reply_str += "\n有新版本可用,请使用命令 !update 进行更新"
|
||||
except:
|
||||
pass
|
||||
|
||||
reply = [reply_str]
|
||||
|
||||
elif cmd == 'plugin':
|
||||
reply = plugin_operation(cmd, params, is_admin)
|
||||
|
||||
elif cmd == 'default':
|
||||
if len(params) == 0:
|
||||
# 输出目前所有情景预设
|
||||
import pkg.openai.dprompt as dprompt
|
||||
reply_str = "[bot]当前所有情景预设:\n\n"
|
||||
for key,value in dprompt.get_prompt_dict().items():
|
||||
reply_str += " - {}: {}\n".format(key,value)
|
||||
|
||||
reply_str += "\n当前默认情景预设:{}\n".format(dprompt.get_current())
|
||||
reply_str += "请使用!default <情景预设>来设置默认情景预设"
|
||||
reply = [reply_str]
|
||||
elif len(params) >0 and is_admin:
|
||||
# 设置默认情景
|
||||
import pkg.openai.dprompt as dprompt
|
||||
try:
|
||||
dprompt.set_current(params[0])
|
||||
reply = ["[bot]已设置默认情景预设为:{}".format(dprompt.get_current())]
|
||||
except KeyError:
|
||||
reply = ["[bot]err: 未找到情景预设:{}".format(params[0])]
|
||||
else:
|
||||
reply = ["[bot]err: 仅管理员可设置默认情景预设"]
|
||||
elif cmd == "delhst" and is_admin:
|
||||
if len(params) == 0:
|
||||
reply = ["[bot]err:请输入要删除的会话名: group_<群号> 或者 person_<QQ号>, 或使用 !delhst all 删除所有会话的历史记录"]
|
||||
else:
|
||||
if params[0] == "all":
|
||||
pkg.utils.context.get_database_manager().delete_all_session_history()
|
||||
reply = ["[bot]已删除所有会话的历史记录"]
|
||||
else:
|
||||
if pkg.utils.context.get_database_manager().delete_all_history(params[0]):
|
||||
reply = ["[bot]已删除会话 {} 的所有历史记录".format(params[0])]
|
||||
else:
|
||||
reply = ["[bot]未找到会话 {} 的历史记录".format(params[0])]
|
||||
elif cmd == 'reload' and is_admin:
|
||||
def reload_task():
|
||||
pkg.utils.reloader.reload_all()
|
||||
|
||||
threading.Thread(target=reload_task, daemon=True).start()
|
||||
elif cmd == 'update' and is_admin:
|
||||
def update_task():
|
||||
try:
|
||||
if pkg.utils.updater.update_all():
|
||||
pkg.utils.reloader.reload_all(notify=False)
|
||||
pkg.utils.context.get_qqbot_manager().notify_admin("更新完成")
|
||||
else:
|
||||
pkg.utils.context.get_qqbot_manager().notify_admin("无新版本")
|
||||
except Exception as e0:
|
||||
traceback.print_exc()
|
||||
pkg.utils.context.get_qqbot_manager().notify_admin("更新失败:{}".format(e0))
|
||||
return
|
||||
|
||||
threading.Thread(target=update_task, daemon=True).start()
|
||||
|
||||
reply = ["[bot]正在更新,请耐心等待,请勿重复发起更新..."]
|
||||
elif cmd == 'cfg' and is_admin:
|
||||
reply = config_operation(cmd, params)
|
||||
else:
|
||||
if cmd.startswith("~") and is_admin:
|
||||
config_item = cmd[1:]
|
||||
params = [config_item] + params
|
||||
reply = config_operation("cfg", params)
|
||||
else:
|
||||
reply = ["[bot]err:未知的指令或权限不足: " + cmd]
|
||||
except Exception as e:
|
||||
mgr.notify_admin("{}指令执行失败:{}".format(session_name, e))
|
||||
logging.exception(e)
|
||||
reply = ["[bot]err:{}".format(e)]
|
||||
|
||||
return reply
|
||||
@@ -1,18 +1,84 @@
|
||||
# 敏感词过滤模块
|
||||
import re
|
||||
import requests
|
||||
import json
|
||||
import logging
|
||||
|
||||
|
||||
class ReplyFilter:
|
||||
|
||||
sensitive_words = []
|
||||
mask = "*"
|
||||
mask_word = ""
|
||||
|
||||
def __init__(self, sensitive_words: list):
|
||||
# 默认值( 兼容性考虑 )
|
||||
baidu_check = False
|
||||
baidu_api_key = ""
|
||||
baidu_secret_key = ""
|
||||
inappropriate_message_tips = "[百度云]请珍惜机器人,当前返回内容不合规"
|
||||
|
||||
def __init__(self, sensitive_words: list, mask: str = "*", mask_word: str = ""):
|
||||
self.sensitive_words = sensitive_words
|
||||
self.mask = mask
|
||||
self.mask_word = mask_word
|
||||
import config
|
||||
if hasattr(config, 'baidu_check') and hasattr(config, 'baidu_api_key') and hasattr(config, 'baidu_secret_key'):
|
||||
self.baidu_check = config.baidu_check
|
||||
self.baidu_api_key = config.baidu_api_key
|
||||
self.baidu_secret_key = config.baidu_secret_key
|
||||
self.inappropriate_message_tips = config.inappropriate_message_tips
|
||||
|
||||
def is_illegal(self, message: str) -> bool:
|
||||
processed = self.process(message)
|
||||
if processed != message:
|
||||
return True
|
||||
return False
|
||||
|
||||
def process(self, message: str) -> str:
|
||||
|
||||
# 本地关键词屏蔽
|
||||
for word in self.sensitive_words:
|
||||
match = re.findall(word, message)
|
||||
if len(match) > 0:
|
||||
for i in range(len(match)):
|
||||
message = message.replace(match[i], "*" * len(match[i]))
|
||||
if self.mask_word == "":
|
||||
message = message.replace(match[i], self.mask * len(match[i]))
|
||||
else:
|
||||
message = message.replace(match[i], self.mask_word)
|
||||
|
||||
# 百度云审核
|
||||
if self.baidu_check:
|
||||
|
||||
# 百度云审核URL
|
||||
baidu_url = "https://aip.baidubce.com/rest/2.0/solution/v1/text_censor/v2/user_defined?access_token=" + \
|
||||
str(requests.post("https://aip.baidubce.com/oauth/2.0/token",
|
||||
params={"grant_type": "client_credentials",
|
||||
"client_id": self.baidu_api_key,
|
||||
"client_secret": self.baidu_secret_key}).json().get("access_token"))
|
||||
|
||||
# 百度云审核
|
||||
payload = "text=" + message
|
||||
logging.info("向百度云发送:" + payload)
|
||||
headers = {'Content-Type': 'application/x-www-form-urlencoded', 'Accept': 'application/json'}
|
||||
|
||||
if isinstance(payload, str):
|
||||
payload = payload.encode('utf-8')
|
||||
|
||||
response = requests.request("POST", baidu_url, headers=headers, data=payload)
|
||||
response_dict = json.loads(response.text)
|
||||
|
||||
if "error_code" in response_dict:
|
||||
error_msg = response_dict.get("error_msg")
|
||||
logging.warning(f"百度云判定出错,错误信息:{error_msg}")
|
||||
conclusion = f"百度云判定出错,错误信息:{error_msg}\n以下是原消息:{message}"
|
||||
else:
|
||||
conclusion = response_dict["conclusion"]
|
||||
if conclusion in ("合规"):
|
||||
logging.info(f"百度云判定结果:{conclusion}")
|
||||
return message
|
||||
else:
|
||||
logging.warning(f"百度云判定结果:{conclusion}")
|
||||
conclusion = self.inappropriate_message_tips
|
||||
# 返回百度云审核结果
|
||||
return conclusion
|
||||
|
||||
return message
|
||||
|
||||
19
pkg/qqbot/ignore.py
Normal file
19
pkg/qqbot/ignore.py
Normal file
@@ -0,0 +1,19 @@
|
||||
import re
|
||||
|
||||
|
||||
def ignore(msg: str) -> bool:
|
||||
"""检查消息是否应该被忽略"""
|
||||
import config
|
||||
|
||||
if not hasattr(config, 'ignore_rules'):
|
||||
return False
|
||||
|
||||
if 'prefix' in config.ignore_rules:
|
||||
for rule in config.ignore_rules['prefix']:
|
||||
if msg.startswith(rule):
|
||||
return True
|
||||
|
||||
if 'regexp' in config.ignore_rules:
|
||||
for rule in config.ignore_rules['regexp']:
|
||||
if re.search(rule, msg):
|
||||
return True
|
||||
@@ -2,10 +2,12 @@ import asyncio
|
||||
import json
|
||||
import os
|
||||
import threading
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
import mirai.models.bus
|
||||
from mirai import At, GroupMessage, MessageEvent, Mirai, Plain, StrangerMessage, WebSocketAdapter, HTTPAdapter, \
|
||||
from mirai import At, GroupMessage, MessageEvent, Mirai, StrangerMessage, WebSocketAdapter, HTTPAdapter, \
|
||||
FriendMessage, Image
|
||||
from func_timeout import func_set_timeout
|
||||
|
||||
import pkg.openai.session
|
||||
import pkg.openai.manager
|
||||
@@ -16,15 +18,12 @@ import pkg.qqbot.filter
|
||||
import pkg.qqbot.process as processor
|
||||
import pkg.utils.context
|
||||
|
||||
|
||||
# 并行运行
|
||||
def go(func, args=()):
|
||||
thread = threading.Thread(target=func, args=args, daemon=True)
|
||||
thread.start()
|
||||
import pkg.plugin.host as plugin_host
|
||||
import pkg.plugin.models as plugin_models
|
||||
|
||||
|
||||
# 检查消息是否符合泛响应匹配机制
|
||||
def check_response_rule(text: str) -> (bool, str):
|
||||
def check_response_rule(text: str):
|
||||
config = pkg.utils.context.get_config()
|
||||
if not hasattr(config, 'response_rules'):
|
||||
return False, ''
|
||||
@@ -47,25 +46,65 @@ def check_response_rule(text: str) -> (bool, str):
|
||||
return False, ""
|
||||
|
||||
|
||||
def response_at():
|
||||
config = pkg.utils.context.get_config()
|
||||
if 'at' not in config.response_rules:
|
||||
return True
|
||||
|
||||
return config.response_rules['at']
|
||||
|
||||
|
||||
def random_responding():
|
||||
config = pkg.utils.context.get_config()
|
||||
if 'random_rate' in config.response_rules:
|
||||
import random
|
||||
return random.random() < config.response_rules['random_rate']
|
||||
return False
|
||||
|
||||
|
||||
# 控制QQ消息输入输出的类
|
||||
class QQBotManager:
|
||||
retry = 3
|
||||
|
||||
bot = None
|
||||
#线程池控制
|
||||
pool = None
|
||||
|
||||
bot: Mirai = None
|
||||
|
||||
reply_filter = None
|
||||
|
||||
def __init__(self, mirai_http_api_config: dict, timeout: int = 60, retry: int = 3, first_time_init=True):
|
||||
enable_banlist = False
|
||||
|
||||
ban_person = []
|
||||
ban_group = []
|
||||
|
||||
def __init__(self, mirai_http_api_config: dict, timeout: int = 60, retry: int = 3, pool_num: int = 10, first_time_init=True):
|
||||
self.timeout = timeout
|
||||
self.retry = retry
|
||||
|
||||
self.pool_num = pool_num
|
||||
self.pool = ThreadPoolExecutor(max_workers=self.pool_num)
|
||||
logging.debug("Registered thread pool Size:{}".format(pool_num))
|
||||
|
||||
# 加载禁用列表
|
||||
if os.path.exists("banlist.py"):
|
||||
import banlist
|
||||
self.enable_banlist = banlist.enable
|
||||
self.ban_person = banlist.person
|
||||
self.ban_group = banlist.group
|
||||
logging.info("加载禁用列表: person: {}, group: {}".format(self.ban_person, self.ban_group))
|
||||
|
||||
config = pkg.utils.context.get_config()
|
||||
if os.path.exists("sensitive.json") \
|
||||
and config.sensitive_word_filter is not None \
|
||||
and config.sensitive_word_filter:
|
||||
with open("sensitive.json", "r", encoding="utf-8") as f:
|
||||
self.reply_filter = pkg.qqbot.filter.ReplyFilter(json.load(f)['words'])
|
||||
sensitive_json = json.load(f)
|
||||
self.reply_filter = pkg.qqbot.filter.ReplyFilter(
|
||||
sensitive_words=sensitive_json['words'],
|
||||
mask=sensitive_json['mask'] if 'mask' in sensitive_json else '*',
|
||||
mask_word=sensitive_json['mask_word'] if 'mask_word' in sensitive_json else ''
|
||||
)
|
||||
else:
|
||||
self.reply_filter = pkg.qqbot.filter.ReplyFilter([])
|
||||
|
||||
@@ -82,15 +121,64 @@ class QQBotManager:
|
||||
# Caution: 注册新的事件处理器之后,请务必在unsubscribe_all中编写相应的取消订阅代码
|
||||
@self.bot.on(FriendMessage)
|
||||
async def on_friend_message(event: FriendMessage):
|
||||
go(self.on_person_message, (event,))
|
||||
|
||||
def friend_message_handler(event: FriendMessage):
|
||||
|
||||
# 触发事件
|
||||
args = {
|
||||
"launcher_type": "person",
|
||||
"launcher_id": event.sender.id,
|
||||
"sender_id": event.sender.id,
|
||||
"message_chain": event.message_chain,
|
||||
}
|
||||
plugin_event = plugin_host.emit(plugin_models.PersonMessageReceived, **args)
|
||||
|
||||
if plugin_event.is_prevented_default():
|
||||
return
|
||||
|
||||
self.on_person_message(event)
|
||||
|
||||
self.go(friend_message_handler, event)
|
||||
|
||||
@self.bot.on(StrangerMessage)
|
||||
async def on_stranger_message(event: StrangerMessage):
|
||||
go(self.on_person_message, (event,))
|
||||
|
||||
def stranger_message_handler(event: StrangerMessage):
|
||||
# 触发事件
|
||||
args = {
|
||||
"launcher_type": "person",
|
||||
"launcher_id": event.sender.id,
|
||||
"sender_id": event.sender.id,
|
||||
"message_chain": event.message_chain,
|
||||
}
|
||||
plugin_event = plugin_host.emit(plugin_models.PersonMessageReceived, **args)
|
||||
|
||||
if plugin_event.is_prevented_default():
|
||||
return
|
||||
|
||||
self.on_person_message(event)
|
||||
|
||||
self.go(stranger_message_handler, event)
|
||||
|
||||
@self.bot.on(GroupMessage)
|
||||
async def on_group_message(event: GroupMessage):
|
||||
go(self.on_group_message, (event,))
|
||||
|
||||
def group_message_handler(event: GroupMessage):
|
||||
# 触发事件
|
||||
args = {
|
||||
"launcher_type": "group",
|
||||
"launcher_id": event.group.id,
|
||||
"sender_id": event.sender.id,
|
||||
"message_chain": event.message_chain,
|
||||
}
|
||||
plugin_event = plugin_host.emit(plugin_models.GroupMessageReceived, **args)
|
||||
|
||||
if plugin_event.is_prevented_default():
|
||||
return
|
||||
|
||||
self.on_group_message(event)
|
||||
|
||||
self.go(group_message_handler, event)
|
||||
|
||||
def unsubscribe_all():
|
||||
"""取消所有订阅
|
||||
@@ -107,6 +195,9 @@ class QQBotManager:
|
||||
|
||||
self.unsubscribe_all = unsubscribe_all
|
||||
|
||||
def go(self, func, *args, **kwargs):
|
||||
self.pool.submit(func, *args, **kwargs)
|
||||
|
||||
def first_time_init(self, mirai_http_api_config: dict):
|
||||
"""热重载后不再运行此函数"""
|
||||
|
||||
@@ -142,6 +233,7 @@ class QQBotManager:
|
||||
|
||||
# 私聊消息处理
|
||||
def on_person_message(self, event: MessageEvent):
|
||||
import config
|
||||
reply = ''
|
||||
|
||||
if event.sender.id == self.bot.qq:
|
||||
@@ -154,12 +246,21 @@ class QQBotManager:
|
||||
failed = 0
|
||||
for i in range(self.retry):
|
||||
try:
|
||||
reply = processor.process_message('person', event.sender.id, str(event.message_chain),
|
||||
event.message_chain,
|
||||
event.sender.id)
|
||||
|
||||
@func_set_timeout(config.process_message_timeout)
|
||||
def time_ctrl_wrapper():
|
||||
reply = processor.process_message('person', event.sender.id, str(event.message_chain),
|
||||
event.message_chain,
|
||||
event.sender.id)
|
||||
return reply
|
||||
|
||||
reply = time_ctrl_wrapper()
|
||||
break
|
||||
except FunctionTimedOut:
|
||||
logging.warning("person_{}: 超时,重试中({})".format(event.sender.id, i))
|
||||
pkg.openai.session.get_session('person_{}'.format(event.sender.id)).release_response_lock()
|
||||
if "person_{}".format(event.sender.id) in pkg.qqbot.process.processing:
|
||||
pkg.qqbot.process.processing.remove('person_{}'.format(event.sender.id))
|
||||
failed += 1
|
||||
continue
|
||||
|
||||
@@ -173,7 +274,7 @@ class QQBotManager:
|
||||
|
||||
# 群消息处理
|
||||
def on_group_message(self, event: GroupMessage):
|
||||
|
||||
import config
|
||||
reply = ''
|
||||
|
||||
def process(text=None) -> str:
|
||||
@@ -185,13 +286,21 @@ class QQBotManager:
|
||||
failed = 0
|
||||
for i in range(self.retry):
|
||||
try:
|
||||
replys = processor.process_message('group', event.group.id,
|
||||
str(event.message_chain).strip() if text is None else text,
|
||||
event.message_chain,
|
||||
event.sender.id)
|
||||
@func_set_timeout(config.process_message_timeout)
|
||||
def time_ctrl_wrapper():
|
||||
replys = processor.process_message('group', event.group.id,
|
||||
str(event.message_chain).strip() if text is None else text,
|
||||
event.message_chain,
|
||||
event.sender.id)
|
||||
return replys
|
||||
|
||||
replys = time_ctrl_wrapper()
|
||||
break
|
||||
except FunctionTimedOut:
|
||||
logging.warning("group_{}: 超时,重试中({})".format(event.group.id, i))
|
||||
pkg.openai.session.get_session('group_{}'.format(event.group.id)).release_response_lock()
|
||||
if "group_{}".format(event.group.id) in pkg.qqbot.process.processing:
|
||||
pkg.qqbot.process.processing.remove('group_{}'.format(event.group.id))
|
||||
failed += 1
|
||||
continue
|
||||
|
||||
@@ -204,14 +313,19 @@ class QQBotManager:
|
||||
|
||||
if Image in event.message_chain:
|
||||
pass
|
||||
elif At(self.bot.qq) not in event.message_chain:
|
||||
check, result = check_response_rule(str(event.message_chain).strip())
|
||||
|
||||
if check:
|
||||
reply = process(result.strip())
|
||||
else:
|
||||
# 直接调用
|
||||
reply = process()
|
||||
if At(self.bot.qq) in event.message_chain and response_at():
|
||||
# 直接调用
|
||||
reply = process()
|
||||
else:
|
||||
check, result = check_response_rule(str(event.message_chain).strip())
|
||||
|
||||
if check:
|
||||
reply = process(result.strip())
|
||||
# 检查是否随机响应
|
||||
elif random_responding():
|
||||
logging.info("随机响应group_{}消息".format(event.group.id))
|
||||
reply = process()
|
||||
|
||||
if reply:
|
||||
return self.send(event, reply)
|
||||
@@ -219,7 +333,25 @@ class QQBotManager:
|
||||
# 通知系统管理员
|
||||
def notify_admin(self, message: str):
|
||||
config = pkg.utils.context.get_config()
|
||||
if hasattr(config, "admin_qq") and config.admin_qq != 0:
|
||||
if hasattr(config, "admin_qq") and config.admin_qq != 0 and config.admin_qq != []:
|
||||
logging.info("通知管理员:{}".format(message))
|
||||
send_task = self.bot.send_friend_message(config.admin_qq, "[bot]{}".format(message))
|
||||
threading.Thread(target=asyncio.run, args=(send_task,)).start()
|
||||
if type(config.admin_qq) == int:
|
||||
send_task = self.bot.send_friend_message(config.admin_qq, "[bot]{}".format(message))
|
||||
threading.Thread(target=asyncio.run, args=(send_task,)).start()
|
||||
else:
|
||||
for adm in config.admin_qq:
|
||||
send_task = self.bot.send_friend_message(adm, "[bot]{}".format(message))
|
||||
threading.Thread(target=asyncio.run, args=(send_task,)).start()
|
||||
|
||||
|
||||
def notify_admin_message_chain(self, message):
|
||||
config = pkg.utils.context.get_config()
|
||||
if hasattr(config, "admin_qq") and config.admin_qq != 0 and config.admin_qq != []:
|
||||
logging.info("通知管理员:{}".format(message))
|
||||
if type(config.admin_qq) == int:
|
||||
send_task = self.bot.send_friend_message(config.admin_qq, message)
|
||||
threading.Thread(target=asyncio.run, args=(send_task,)).start()
|
||||
else:
|
||||
for adm in config.admin_qq:
|
||||
send_task = self.bot.send_friend_message(adm, message)
|
||||
threading.Thread(target=asyncio.run, args=(send_task,)).start()
|
||||
|
||||
130
pkg/qqbot/message.py
Normal file
130
pkg/qqbot/message.py
Normal file
@@ -0,0 +1,130 @@
|
||||
# 普通消息处理模块
|
||||
import logging
|
||||
import time
|
||||
import openai
|
||||
import pkg.utils.context
|
||||
import pkg.openai.session
|
||||
|
||||
import pkg.plugin.host as plugin_host
|
||||
import pkg.plugin.models as plugin_models
|
||||
import pkg.qqbot.blob as blob
|
||||
|
||||
|
||||
def handle_exception(notify_admin: str = "", set_reply: str = "") -> list:
|
||||
"""处理异常,当notify_admin不为空时,会通知管理员,返回通知用户的消息"""
|
||||
import config
|
||||
pkg.utils.context.get_qqbot_manager().notify_admin(notify_admin)
|
||||
if hasattr(config, 'hide_exce_info_to_user') and config.hide_exce_info_to_user:
|
||||
if hasattr(config, 'alter_tip_message'):
|
||||
return [config.alter_tip_message] if config.alter_tip_message else []
|
||||
else:
|
||||
return ["[bot]出错了,请重试或联系管理员"]
|
||||
else:
|
||||
return [set_reply]
|
||||
|
||||
|
||||
def process_normal_message(text_message: str, mgr, config, launcher_type: str,
|
||||
launcher_id: int, sender_id: int) -> list:
|
||||
session_name = f"{launcher_type}_{launcher_id}"
|
||||
logging.info("[{}]发送消息:{}".format(session_name, text_message[:min(20, len(text_message))] + (
|
||||
"..." if len(text_message) > 20 else "")))
|
||||
|
||||
session = pkg.openai.session.get_session(session_name)
|
||||
|
||||
unexpected_exception_times = 0
|
||||
|
||||
max_unexpected_exception_times = 3
|
||||
|
||||
reply = []
|
||||
while True:
|
||||
if unexpected_exception_times >= max_unexpected_exception_times:
|
||||
reply = handle_exception(notify_admin=f"{session_name},多次尝试失败。", set_reply=f"[bot]多次尝试失败,请重试或联系管理员")
|
||||
break
|
||||
try:
|
||||
prefix = "[GPT]" if hasattr(config, "show_prefix") and config.show_prefix else ""
|
||||
|
||||
text = session.append(text_message)
|
||||
|
||||
# 触发插件事件
|
||||
args = {
|
||||
"launcher_type": launcher_type,
|
||||
"launcher_id": launcher_id,
|
||||
"sender_id": sender_id,
|
||||
"session": session,
|
||||
"prefix": prefix,
|
||||
"response_text": text
|
||||
}
|
||||
|
||||
event = pkg.plugin.host.emit(plugin_models.NormalMessageResponded, **args)
|
||||
|
||||
if event.get_return_value("prefix") is not None:
|
||||
prefix = event.get_return_value("prefix")
|
||||
|
||||
if event.get_return_value("reply") is not None:
|
||||
reply = event.get_return_value("reply")
|
||||
|
||||
if not event.is_prevented_default():
|
||||
reply = blob.check_text(prefix + text)
|
||||
except openai.error.APIConnectionError as e:
|
||||
err_msg = str(e)
|
||||
if err_msg.__contains__('Error communicating with OpenAI'):
|
||||
reply = handle_exception("{}会话调用API失败:{}\n请尝试关闭网络代理来解决此问题。".format(session_name, e),
|
||||
"[bot]err:调用API失败,请重试或联系管理员,或等待修复")
|
||||
else:
|
||||
reply = handle_exception("{}会话调用API失败:{}".format(session_name, e), "[bot]err:调用API失败,请重试或联系管理员,或等待修复")
|
||||
except openai.error.RateLimitError as e:
|
||||
logging.debug(type(e))
|
||||
logging.debug(e.error['message'])
|
||||
|
||||
if 'message' in e.error and e.error['message'].__contains__('You exceeded your current quota'):
|
||||
# 尝试切换api-key
|
||||
current_key_name = pkg.utils.context.get_openai_manager().key_mgr.get_key_name(
|
||||
pkg.utils.context.get_openai_manager().key_mgr.using_key
|
||||
)
|
||||
pkg.utils.context.get_openai_manager().key_mgr.set_current_exceeded()
|
||||
|
||||
# 触发插件事件
|
||||
args = {
|
||||
'key_name': current_key_name,
|
||||
'usage': pkg.utils.context.get_openai_manager().audit_mgr
|
||||
.get_usage(pkg.utils.context.get_openai_manager().key_mgr.get_using_key_md5()),
|
||||
'exceeded_keys': pkg.utils.context.get_openai_manager().key_mgr.exceeded,
|
||||
}
|
||||
event = plugin_host.emit(plugin_models.KeyExceeded, **args)
|
||||
|
||||
if not event.is_prevented_default():
|
||||
switched, name = pkg.utils.context.get_openai_manager().key_mgr.auto_switch()
|
||||
|
||||
if not switched:
|
||||
reply = handle_exception(
|
||||
"api-key调用额度超限({}),无可用api_key,请向OpenAI账户充值或在config.py中更换api_key;如果你认为这是误判,请尝试重启程序。".format(
|
||||
current_key_name), "[bot]err:API调用额度超额,请联系管理员,或等待修复")
|
||||
else:
|
||||
openai.api_key = pkg.utils.context.get_openai_manager().key_mgr.get_using_key()
|
||||
mgr.notify_admin("api-key调用额度超限({}),接口报错,已切换到{}".format(current_key_name, name))
|
||||
reply = ["[bot]err:API调用额度超额,已自动切换,请重新发送消息"]
|
||||
continue
|
||||
elif 'message' in e.error and e.error['message'].__contains__('You can retry your request'):
|
||||
# 重试
|
||||
unexpected_exception_times += 1
|
||||
continue
|
||||
elif 'message' in e.error and e.error['message']\
|
||||
.__contains__('The server had an error while processing your request'):
|
||||
# 重试
|
||||
unexpected_exception_times += 1
|
||||
continue
|
||||
else:
|
||||
reply = handle_exception("{}会话调用API失败:{}".format(session_name, e),
|
||||
"[bot]err:RateLimitError,请重试或联系作者,或等待修复")
|
||||
except openai.error.InvalidRequestError as e:
|
||||
reply = handle_exception("{}API调用参数错误:{}\n\n这可能是由于config.py中的prompt_submit_length参数或"
|
||||
"completion_api_params中的max_tokens参数数值过大导致的,请尝试将其降低".format(
|
||||
session_name, e), "[bot]err:API调用参数错误,请联系管理员,或等待修复")
|
||||
except openai.error.ServiceUnavailableError as e:
|
||||
reply = handle_exception("{}API调用服务不可用:{}".format(session_name, e), "[bot]err:API调用服务不可用,请重试或联系管理员,或等待修复")
|
||||
except Exception as e:
|
||||
logging.exception(e)
|
||||
reply = handle_exception("{}会话处理异常:{}".format(session_name, e), "[bot]err:{}".format(e))
|
||||
break
|
||||
|
||||
return reply
|
||||
@@ -1,96 +1,44 @@
|
||||
# 此模块提供了消息处理的具体逻辑的接口
|
||||
import asyncio
|
||||
import datetime
|
||||
import json
|
||||
import threading
|
||||
import time
|
||||
|
||||
from func_timeout import func_set_timeout
|
||||
import mirai
|
||||
import logging
|
||||
import openai
|
||||
|
||||
from mirai import Image, MessageChain
|
||||
from mirai import MessageChain, Plain
|
||||
|
||||
# 这里不使用动态引入config
|
||||
# 因为在这里动态引入会卡死程序
|
||||
# 而此模块静态引用config与动态引入的表现一致
|
||||
import config as config_init_import
|
||||
# 已弃用,由于超时时间现已动态使用
|
||||
# import config as config_init_import
|
||||
|
||||
import pkg.openai.session
|
||||
import pkg.openai.manager
|
||||
import pkg.utils.reloader
|
||||
import pkg.utils.updater
|
||||
import pkg.utils.context
|
||||
import pkg.qqbot.message
|
||||
import pkg.qqbot.command
|
||||
import pkg.qqbot.ratelimit as ratelimit
|
||||
|
||||
import pkg.plugin.host as plugin_host
|
||||
import pkg.plugin.models as plugin_models
|
||||
import pkg.qqbot.ignore as ignore
|
||||
import pkg.qqbot.banlist as banlist
|
||||
|
||||
processing = []
|
||||
|
||||
|
||||
def config_operation(cmd, params):
|
||||
reply = []
|
||||
config = pkg.utils.context.get_config()
|
||||
reply_str = ""
|
||||
if len(params) == 0:
|
||||
reply = ["[bot]err:请输入配置项"]
|
||||
def is_admin(qq: int) -> bool:
|
||||
"""兼容list和int类型的管理员判断"""
|
||||
import config
|
||||
if type(config.admin_qq) == list:
|
||||
return qq in config.admin_qq
|
||||
else:
|
||||
cfg_name = params[0]
|
||||
if cfg_name == 'all':
|
||||
reply_str = "[bot]所有配置项:\n\n"
|
||||
for cfg in dir(config):
|
||||
if not cfg.startswith('__') and not cfg == 'logging':
|
||||
# 根据配置项类型进行格式化,如果是字典则转换为json并格式化
|
||||
if isinstance(getattr(config, cfg), str):
|
||||
reply_str += "{}: \"{}\"\n".format(cfg, getattr(config, cfg))
|
||||
elif isinstance(getattr(config, cfg), dict):
|
||||
# 不进行unicode转义,并格式化
|
||||
reply_str += "{}: {}\n".format(cfg,
|
||||
json.dumps(getattr(config, cfg),
|
||||
ensure_ascii=False, indent=4))
|
||||
else:
|
||||
reply_str += "{}: {}\n".format(cfg, getattr(config, cfg))
|
||||
reply = [reply_str]
|
||||
elif cfg_name in dir(config):
|
||||
if len(params) == 1:
|
||||
# 按照配置项类型进行格式化
|
||||
if isinstance(getattr(config, cfg_name), str):
|
||||
reply_str = "[bot]配置项{}: \"{}\"\n".format(cfg_name, getattr(config, cfg_name))
|
||||
elif isinstance(getattr(config, cfg_name), dict):
|
||||
reply_str = "[bot]配置项{}: {}\n".format(cfg_name,
|
||||
json.dumps(getattr(config, cfg_name),
|
||||
ensure_ascii=False, indent=4))
|
||||
else:
|
||||
reply_str = "[bot]配置项{}: {}\n".format(cfg_name, getattr(config, cfg_name))
|
||||
reply = [reply_str]
|
||||
else:
|
||||
cfg_value = " ".join(params[1:])
|
||||
# 类型转换,如果是json则转换为字典
|
||||
if cfg_value == 'true':
|
||||
cfg_value = True
|
||||
elif cfg_value == 'false':
|
||||
cfg_value = False
|
||||
elif cfg_value.isdigit():
|
||||
cfg_value = int(cfg_value)
|
||||
elif cfg_value.startswith('{') and cfg_value.endswith('}'):
|
||||
cfg_value = json.loads(cfg_value)
|
||||
else:
|
||||
try:
|
||||
cfg_value = float(cfg_value)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
# 检查类型是否匹配
|
||||
if isinstance(getattr(config, cfg_name), type(cfg_value)):
|
||||
setattr(config, cfg_name, cfg_value)
|
||||
pkg.utils.context.set_config(config)
|
||||
reply = ["[bot]配置项{}修改成功".format(cfg_name)]
|
||||
else:
|
||||
reply = ["[bot]err:配置项{}类型不匹配".format(cfg_name)]
|
||||
|
||||
else:
|
||||
reply = ["[bot]err:未找到配置项 {}".format(cfg_name)]
|
||||
|
||||
return reply
|
||||
return qq == config.admin_qq
|
||||
|
||||
|
||||
@func_set_timeout(config_init_import.process_message_timeout)
|
||||
def process_message(launcher_type: str, launcher_id: int, text_message: str, message_chain: MessageChain,
|
||||
sender_id: int) -> MessageChain:
|
||||
global processing
|
||||
@@ -100,231 +48,121 @@ def process_message(launcher_type: str, launcher_id: int, text_message: str, mes
|
||||
reply = []
|
||||
session_name = "{}_{}".format(launcher_type, launcher_id)
|
||||
|
||||
# 检查发送方是否被禁用
|
||||
if banlist.is_banned(launcher_type, launcher_id, sender_id):
|
||||
logging.info("根据禁用列表忽略{}_{}的消息".format(launcher_type, launcher_id))
|
||||
return []
|
||||
|
||||
if ignore.ignore(text_message):
|
||||
logging.info("根据忽略规则忽略消息: {}".format(text_message))
|
||||
return []
|
||||
|
||||
# 检查是否被禁言
|
||||
if launcher_type == 'group':
|
||||
result = mgr.bot.member_info(target=launcher_id, member_id=mgr.bot.qq).get()
|
||||
result = asyncio.run(result)
|
||||
if result.mute_time_remaining > 0:
|
||||
logging.info("机器人被禁言,跳过消息处理(group_{},剩余{}s)".format(launcher_id,
|
||||
result.mute_time_remaining))
|
||||
result.mute_time_remaining))
|
||||
return reply
|
||||
|
||||
import config
|
||||
if hasattr(config, 'income_msg_check') and config.income_msg_check:
|
||||
if mgr.reply_filter.is_illegal(text_message):
|
||||
return MessageChain(Plain("[bot] 你的提问中有不合适的内容, 请更换措辞~"))
|
||||
|
||||
pkg.openai.session.get_session(session_name).acquire_response_lock()
|
||||
|
||||
text_message = text_message.strip()
|
||||
|
||||
# 处理消息
|
||||
try:
|
||||
if session_name in processing:
|
||||
pkg.openai.session.get_session(session_name).release_response_lock()
|
||||
return ["[bot]err:正在处理中,请稍后再试"]
|
||||
|
||||
processing.append(session_name)
|
||||
return MessageChain([Plain("[bot]err:正在处理中,请稍后再试")])
|
||||
|
||||
config = pkg.utils.context.get_config()
|
||||
|
||||
processing.append(session_name)
|
||||
try:
|
||||
|
||||
if text_message.startswith('!') or text_message.startswith("!"): # 指令
|
||||
try:
|
||||
logging.info(
|
||||
"[{}]发起指令:{}".format(session_name, text_message[:min(20, len(text_message))] + (
|
||||
"..." if len(text_message) > 20 else "")))
|
||||
# 触发插件事件
|
||||
args = {
|
||||
'launcher_type': launcher_type,
|
||||
'launcher_id': launcher_id,
|
||||
'sender_id': sender_id,
|
||||
'command': text_message[1:].strip().split(' ')[0],
|
||||
'params': text_message[1:].strip().split(' ')[1:],
|
||||
'text_message': text_message,
|
||||
'is_admin': is_admin(sender_id),
|
||||
}
|
||||
event = plugin_host.emit(plugin_models.PersonCommandSent
|
||||
if launcher_type == 'person'
|
||||
else plugin_models.GroupCommandSent, **args)
|
||||
|
||||
cmd = text_message[1:].strip().split(' ')[0]
|
||||
if event.get_return_value("alter") is not None:
|
||||
text_message = event.get_return_value("alter")
|
||||
|
||||
params = text_message[1:].strip().split(' ')[1:]
|
||||
if cmd == 'help':
|
||||
reply = ["[bot]" + config.help_message]
|
||||
elif cmd == 'reset':
|
||||
pkg.openai.session.get_session(session_name).reset(explicit=True)
|
||||
reply = ["[bot]会话已重置"]
|
||||
elif cmd == 'last':
|
||||
result = pkg.openai.session.get_session(session_name).last_session()
|
||||
if result is None:
|
||||
reply = ["[bot]没有前一次的对话"]
|
||||
else:
|
||||
datetime_str = datetime.datetime.fromtimestamp(result.create_timestamp).strftime(
|
||||
'%Y-%m-%d %H:%M:%S')
|
||||
reply = ["[bot]已切换到前一次的对话:\n创建时间:{}\n".format(
|
||||
datetime_str) + result.prompt[
|
||||
:min(100,
|
||||
len(result.prompt))] + \
|
||||
("..." if len(result.prompt) > 100 else "#END#")]
|
||||
elif cmd == 'next':
|
||||
result = pkg.openai.session.get_session(session_name).next_session()
|
||||
if result is None:
|
||||
reply = ["[bot]没有后一次的对话"]
|
||||
else:
|
||||
datetime_str = datetime.datetime.fromtimestamp(result.create_timestamp).strftime(
|
||||
'%Y-%m-%d %H:%M:%S')
|
||||
reply = ["[bot]已切换到后一次的对话:\n创建时间:{}\n".format(
|
||||
datetime_str) + result.prompt[
|
||||
:min(100,
|
||||
len(result.prompt))] + \
|
||||
("..." if len(result.prompt) > 100 else "#END#")]
|
||||
elif cmd == 'prompt':
|
||||
reply = ["[bot]当前对话所有内容:\n" + pkg.openai.session.get_session(session_name).prompt]
|
||||
elif cmd == 'list':
|
||||
pkg.openai.session.get_session(session_name).persistence()
|
||||
page = 0
|
||||
# 取出插件提交的返回值赋值给reply
|
||||
if event.get_return_value("reply") is not None:
|
||||
reply = event.get_return_value("reply")
|
||||
|
||||
if len(params) > 0:
|
||||
try:
|
||||
page = int(params[0])
|
||||
except ValueError:
|
||||
pass
|
||||
if not event.is_prevented_default():
|
||||
reply = pkg.qqbot.command.process_command(session_name, text_message,
|
||||
mgr, config, launcher_type, launcher_id, sender_id, is_admin(sender_id))
|
||||
|
||||
results = pkg.openai.session.get_session(session_name).list_history(page=page)
|
||||
if len(results) == 0:
|
||||
reply = ["[bot]第{}页没有历史会话".format(page)]
|
||||
else:
|
||||
reply_str = "[bot]历史会话 第{}页:\n".format(page)
|
||||
current = -1
|
||||
for i in range(len(results)):
|
||||
# 时间(使用create_timestamp转换) 序号 部分内容
|
||||
datetime_obj = datetime.datetime.fromtimestamp(results[i]['create_timestamp'])
|
||||
reply_str += "#{} 创建:{} {}\n".format(i + page * 10,
|
||||
datetime_obj.strftime("%Y-%m-%d %H:%M:%S"),
|
||||
results[i]['prompt'][
|
||||
:min(20, len(results[i]['prompt']))])
|
||||
if results[i]['create_timestamp'] == pkg.openai.session.get_session(
|
||||
session_name).create_timestamp:
|
||||
current = i + page * 10
|
||||
|
||||
reply_str += "\n以上信息倒序排列"
|
||||
if current != -1:
|
||||
reply_str += ",当前会话是 #{}\n".format(current)
|
||||
else:
|
||||
reply_str += ",当前处于全新会话或不在此页"
|
||||
|
||||
reply = [reply_str]
|
||||
elif cmd == 'fee':
|
||||
api_keys = pkg.utils.context.get_openai_manager().key_mgr.api_key
|
||||
reply_str = "[bot]api-key费用情况(估算):(阈值:{})\n\n".format(
|
||||
pkg.utils.context.get_openai_manager().key_mgr.api_key_fee_threshold)
|
||||
|
||||
using_key_name = ""
|
||||
for api_key in api_keys:
|
||||
reply_str += "{}:\n - {}美元 {}%\n".format(api_key,
|
||||
round(
|
||||
pkg.utils.context.get_openai_manager().key_mgr.get_fee(
|
||||
api_keys[api_key]), 6),
|
||||
round(
|
||||
pkg.utils.context.get_openai_manager().key_mgr.get_fee(
|
||||
api_keys[
|
||||
api_key]) / pkg.utils.context.get_openai_manager().key_mgr.api_key_fee_threshold * 100,
|
||||
3))
|
||||
if api_keys[api_key] == pkg.utils.context.get_openai_manager().key_mgr.using_key:
|
||||
using_key_name = api_key
|
||||
reply_str += "\n当前使用:{}".format(using_key_name)
|
||||
|
||||
reply = [reply_str]
|
||||
elif cmd == 'usage':
|
||||
reply_str = "[bot]各api-key使用情况:\n\n"
|
||||
|
||||
api_keys = pkg.utils.context.get_openai_manager().key_mgr.api_key
|
||||
for key_name in api_keys:
|
||||
text_length = pkg.utils.context.get_openai_manager().audit_mgr\
|
||||
.get_text_length_of_key(api_keys[key_name])
|
||||
image_count = pkg.utils.context.get_openai_manager().audit_mgr\
|
||||
.get_image_count_of_key(api_keys[key_name])
|
||||
reply_str += "{}:\n - 文本长度:{}\n - 图片数量:{}\n".format(key_name, int(text_length), int(image_count))
|
||||
|
||||
reply = [reply_str]
|
||||
elif cmd == 'draw':
|
||||
if len(params) == 0:
|
||||
reply = ["[bot]err:请输入图片描述文字"]
|
||||
else:
|
||||
session = pkg.openai.session.get_session(session_name)
|
||||
|
||||
res = session.draw_image(" ".join(params))
|
||||
|
||||
logging.debug("draw_image result:{}".format(res))
|
||||
reply = [Image(url=res['data'][0]['url'])]
|
||||
if not (hasattr(config, 'include_image_description')
|
||||
and not config.include_image_description):
|
||||
reply.append(" ".join(params))
|
||||
elif cmd == 'reload' and launcher_type == 'person' and launcher_id == config.admin_qq:
|
||||
def reload_task():
|
||||
pkg.utils.reloader.reload_all()
|
||||
|
||||
threading.Thread(target=reload_task, daemon=True).start()
|
||||
elif cmd == 'update' and launcher_type == 'person' and launcher_id == config.admin_qq:
|
||||
def update_task():
|
||||
try:
|
||||
pkg.utils.updater.update_all()
|
||||
except Exception as e0:
|
||||
pkg.utils.context.get_qqbot_manager().notify_admin("更新失败:{}".format(e0))
|
||||
return
|
||||
pkg.utils.reloader.reload_all(notify=False)
|
||||
pkg.utils.context.get_qqbot_manager().notify_admin("更新完成")
|
||||
|
||||
threading.Thread(target=update_task, daemon=True).start()
|
||||
elif cmd == 'cfg' and launcher_type == 'person' and launcher_id == config.admin_qq:
|
||||
reply = config_operation(cmd, params)
|
||||
else:
|
||||
if cmd.startswith("~") and launcher_type == 'person' and launcher_id == config.admin_qq:
|
||||
config_item = cmd[1:]
|
||||
params = [config_item] + params
|
||||
reply = config_operation("cfg", params)
|
||||
else:
|
||||
reply = ["[bot]err:未知的指令或权限不足: "+cmd]
|
||||
except Exception as e:
|
||||
mgr.notify_admin("{}指令执行失败:{}".format(session_name, e))
|
||||
logging.exception(e)
|
||||
reply = ["[bot]err:{}".format(e)]
|
||||
else: # 消息
|
||||
logging.info("[{}]发送消息:{}".format(session_name, text_message[:min(20, len(text_message))] + (
|
||||
"..." if len(text_message) > 20 else "")))
|
||||
# 限速丢弃检查
|
||||
# print(ratelimit.__crt_minute_usage__[session_name])
|
||||
if hasattr(config, "rate_limitation") and config.rate_limit_strategy == "drop":
|
||||
if ratelimit.is_reach_limit(session_name):
|
||||
logging.info("根据限速策略丢弃[{}]消息: {}".format(session_name, text_message))
|
||||
return MessageChain(["[bot]"+config.rate_limit_drop_tip]) if hasattr(config, "rate_limit_drop_tip") and config.rate_limit_drop_tip != "" else []
|
||||
|
||||
session = pkg.openai.session.get_session(session_name)
|
||||
before = time.time()
|
||||
# 触发插件事件
|
||||
args = {
|
||||
"launcher_type": launcher_type,
|
||||
"launcher_id": launcher_id,
|
||||
"sender_id": sender_id,
|
||||
"text_message": text_message,
|
||||
}
|
||||
event = plugin_host.emit(plugin_models.PersonNormalMessageReceived
|
||||
if launcher_type == 'person'
|
||||
else plugin_models.GroupNormalMessageReceived, **args)
|
||||
|
||||
while True:
|
||||
try:
|
||||
prefix = "[GPT]" if hasattr(config, "show_prefix") and config.show_prefix else ""
|
||||
reply = [prefix + session.append(text_message)]
|
||||
except openai.error.APIConnectionError as e:
|
||||
mgr.notify_admin("{}会话调用API失败:{}".format(session_name, e))
|
||||
reply = ["[bot]err:调用API失败,请重试或联系作者,或等待修复"]
|
||||
except openai.error.RateLimitError as e:
|
||||
# 尝试切换api-key
|
||||
current_tokens_amt = pkg.utils.context.get_openai_manager().key_mgr.get_fee(
|
||||
pkg.utils.context.get_openai_manager().key_mgr.get_using_key())
|
||||
pkg.utils.context.get_openai_manager().key_mgr.set_current_exceeded()
|
||||
switched, name = pkg.utils.context.get_openai_manager().key_mgr.auto_switch()
|
||||
if event.get_return_value("alter") is not None:
|
||||
text_message = event.get_return_value("alter")
|
||||
|
||||
if not switched:
|
||||
mgr.notify_admin("API调用额度超限({}),无可用api_key,请向OpenAI账户充值或在config.py中更换api_key".format(
|
||||
current_tokens_amt))
|
||||
reply = ["[bot]err:API调用额度超额,请联系作者,或等待修复"]
|
||||
else:
|
||||
openai.api_key = pkg.utils.context.get_openai_manager().key_mgr.get_using_key()
|
||||
mgr.notify_admin("API调用额度超限({}),接口报错,已切换到{}".format(current_tokens_amt, name))
|
||||
reply = ["[bot]err:API调用额度超额,已自动切换,请重新发送消息"]
|
||||
continue
|
||||
except openai.error.InvalidRequestError as e:
|
||||
mgr.notify_admin("{}API调用参数错误:{}\n\n这可能是由于config.py中的prompt_submit_length参数或"
|
||||
"completion_api_params中的max_tokens参数数值过大导致的,请尝试将其降低".format(
|
||||
session_name, e))
|
||||
reply = ["[bot]err:API调用参数错误,请联系作者,或等待修复"]
|
||||
except openai.error.ServiceUnavailableError as e:
|
||||
# mgr.notify_admin("{}API调用服务不可用:{}".format(session_name, e))
|
||||
reply = ["[bot]err:API调用服务暂不可用,请尝试重试"]
|
||||
except Exception as e:
|
||||
logging.exception(e)
|
||||
reply = ["[bot]err:{}".format(e)]
|
||||
break
|
||||
# 取出插件提交的返回值赋值给reply
|
||||
if event.get_return_value("reply") is not None:
|
||||
reply = event.get_return_value("reply")
|
||||
|
||||
if reply is not None and type(reply[0]) == str:
|
||||
if not event.is_prevented_default():
|
||||
reply = pkg.qqbot.message.process_normal_message(text_message,
|
||||
mgr, config, launcher_type, launcher_id, sender_id)
|
||||
|
||||
# 限速等待时间
|
||||
if hasattr(config, "rate_limitation") and config.rate_limit_strategy == "wait":
|
||||
time.sleep(ratelimit.get_rest_wait_time(session_name, time.time() - before))
|
||||
|
||||
if hasattr(config, "rate_limitation"):
|
||||
ratelimit.add_usage(session_name)
|
||||
|
||||
if reply is not None and len(reply) > 0 and (type(reply[0]) == str or type(reply[0]) == mirai.Plain):
|
||||
if type(reply[0]) == mirai.Plain:
|
||||
reply[0] = reply[0].text
|
||||
logging.info(
|
||||
"回复[{}]文字消息:{}".format(session_name,
|
||||
reply[0][:min(100, len(reply[0]))] + (
|
||||
"..." if len(reply[0]) > 100 else "")))
|
||||
reply = [mgr.reply_filter.process(reply[0])]
|
||||
else:
|
||||
logging.info("回复[{}]图片消息:{}".format(session_name, reply))
|
||||
logging.info("回复[{}]消息".format(session_name))
|
||||
|
||||
finally:
|
||||
processing.remove(session_name)
|
||||
finally:
|
||||
pkg.openai.session.get_session(session_name).release_response_lock()
|
||||
|
||||
return MessageChain(reply)
|
||||
return MessageChain(reply)
|
||||
|
||||
86
pkg/qqbot/ratelimit.py
Normal file
86
pkg/qqbot/ratelimit.py
Normal file
@@ -0,0 +1,86 @@
|
||||
# 限速相关模块
|
||||
import time
|
||||
import logging
|
||||
import threading
|
||||
|
||||
__crt_minute_usage__ = {}
|
||||
"""当前分钟每个会话的对话次数"""
|
||||
|
||||
|
||||
__timer_thr__: threading.Thread = None
|
||||
|
||||
|
||||
def add_usage(session_name: str):
|
||||
"""增加会话的对话次数"""
|
||||
global __crt_minute_usage__
|
||||
if session_name in __crt_minute_usage__:
|
||||
__crt_minute_usage__[session_name] += 1
|
||||
else:
|
||||
__crt_minute_usage__[session_name] = 1
|
||||
|
||||
|
||||
def start_timer():
|
||||
"""启动定时器"""
|
||||
global __timer_thr__
|
||||
__timer_thr__ = threading.Thread(target=run_timer, daemon=True)
|
||||
__timer_thr__.start()
|
||||
|
||||
|
||||
def run_timer():
|
||||
"""启动定时器,每分钟清空一次对话次数"""
|
||||
global __crt_minute_usage__
|
||||
global __timer_thr__
|
||||
|
||||
# 等待直到整分钟
|
||||
time.sleep(60 - time.time() % 60)
|
||||
|
||||
while True:
|
||||
if __timer_thr__ != threading.current_thread():
|
||||
break
|
||||
|
||||
logging.debug("清空当前分钟的对话次数")
|
||||
__crt_minute_usage__ = {}
|
||||
time.sleep(60)
|
||||
|
||||
|
||||
def get_usage(session_name: str) -> int:
|
||||
"""获取会话的对话次数"""
|
||||
global __crt_minute_usage__
|
||||
if session_name in __crt_minute_usage__:
|
||||
return __crt_minute_usage__[session_name]
|
||||
else:
|
||||
return 0
|
||||
|
||||
|
||||
def get_rest_wait_time(session_name: str, spent: float) -> float:
|
||||
"""获取会话此回合的剩余等待时间"""
|
||||
global __crt_minute_usage__
|
||||
|
||||
import config
|
||||
|
||||
if not hasattr(config, 'rate_limitation'):
|
||||
return 0
|
||||
|
||||
min_seconds_per_round = 60.0 / config.rate_limitation
|
||||
|
||||
if session_name in __crt_minute_usage__:
|
||||
return max(0, min_seconds_per_round - spent)
|
||||
else:
|
||||
return 0
|
||||
|
||||
|
||||
def is_reach_limit(session_name: str) -> bool:
|
||||
"""判断会话是否超过限制"""
|
||||
global __crt_minute_usage__
|
||||
|
||||
import config
|
||||
|
||||
if not hasattr(config, 'rate_limitation'):
|
||||
return False
|
||||
|
||||
if session_name in __crt_minute_usage__:
|
||||
return __crt_minute_usage__[session_name] >= config.rate_limitation
|
||||
else:
|
||||
return False
|
||||
|
||||
start_timer()
|
||||
5
pkg/utils/constants.py
Normal file
5
pkg/utils/constants.py
Normal file
File diff suppressed because one or more lines are too long
@@ -1,7 +1,3 @@
|
||||
import pkg.database.manager
|
||||
import pkg.openai.manager
|
||||
import pkg.qqbot.manager
|
||||
|
||||
context = {
|
||||
'inst': {
|
||||
'database.manager.DatabaseManager': None,
|
||||
@@ -10,6 +6,7 @@ context = {
|
||||
},
|
||||
'logger_handler': None,
|
||||
'config': None,
|
||||
'plugin_host': None,
|
||||
}
|
||||
|
||||
|
||||
@@ -42,4 +39,12 @@ def set_qqbot_manager(inst):
|
||||
|
||||
|
||||
def get_qqbot_manager():
|
||||
return context['inst']['qqbot.manager.QQBotManager']
|
||||
return context['inst']['qqbot.manager.QQBotManager']
|
||||
|
||||
|
||||
def set_plugin_host(inst):
|
||||
context['plugin_host'] = inst
|
||||
|
||||
|
||||
def get_plugin_host():
|
||||
return context['plugin_host']
|
||||
|
||||
13
pkg/utils/credit.py
Normal file
13
pkg/utils/credit.py
Normal file
@@ -0,0 +1,13 @@
|
||||
# OpenAI账号免费额度剩余查询
|
||||
import requests
|
||||
|
||||
|
||||
def fetch_credit_data(api_key: str) -> dict:
|
||||
"""OpenAI账号免费额度剩余查询"""
|
||||
resp = requests.get(
|
||||
url="https://api.openai.com/dashboard/billing/credit_grants",
|
||||
headers={
|
||||
"Authorization": "Bearer {}".format(api_key),
|
||||
}
|
||||
)
|
||||
return resp.json()
|
||||
40
pkg/utils/pkgmgr.py
Normal file
40
pkg/utils/pkgmgr.py
Normal file
@@ -0,0 +1,40 @@
|
||||
from pip._internal import main as pipmain
|
||||
|
||||
import main
|
||||
|
||||
|
||||
def install(package):
|
||||
pipmain(['install', package])
|
||||
main.reset_logging()
|
||||
|
||||
|
||||
def run_pip(params: list):
|
||||
pipmain(params)
|
||||
main.reset_logging()
|
||||
|
||||
|
||||
def install_requirements(file):
|
||||
pipmain(['install', '-r', file, "--upgrade"])
|
||||
main.reset_logging()
|
||||
|
||||
|
||||
def ensure_dulwich():
|
||||
# 尝试三次
|
||||
for i in range(3):
|
||||
try:
|
||||
import dulwich
|
||||
return
|
||||
except ImportError:
|
||||
install('dulwich')
|
||||
|
||||
raise ImportError("无法自动安装dulwich库")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
try:
|
||||
install("openai11")
|
||||
except Exception as e:
|
||||
print(111)
|
||||
print(e)
|
||||
|
||||
print(222)
|
||||
@@ -4,15 +4,18 @@ import threading
|
||||
import importlib
|
||||
import pkgutil
|
||||
import pkg.utils.context
|
||||
import pkg.plugin.host
|
||||
|
||||
|
||||
def walk(module, prefix=''):
|
||||
def walk(module, prefix='', path_prefix=''):
|
||||
"""遍历并重载所有模块"""
|
||||
for item in pkgutil.iter_modules(module.__path__):
|
||||
if item.ispkg:
|
||||
walk(__import__(module.__name__ + '.' + item.name, fromlist=['']), prefix + item.name + '.')
|
||||
|
||||
walk(__import__(module.__name__ + '.' + item.name, fromlist=['']), prefix + item.name + '.', path_prefix + item.name + '/')
|
||||
else:
|
||||
logging.info('reload module: {}'.format(prefix + item.name))
|
||||
logging.info('reload module: {}, path: {}'.format(prefix + item.name, path_prefix + item.name + '.py'))
|
||||
pkg.plugin.host.__current_module_path__ = "plugins/" + path_prefix + item.name + '.py'
|
||||
importlib.reload(__import__(module.__name__ + '.' + item.name, fromlist=['']))
|
||||
|
||||
|
||||
@@ -31,8 +34,13 @@ def reload_all(notify=True):
|
||||
walk(pkg)
|
||||
importlib.reload(__import__('config'))
|
||||
importlib.reload(__import__('main'))
|
||||
importlib.reload(__import__('banlist'))
|
||||
pkg.utils.context.context = context
|
||||
|
||||
# 重载插件
|
||||
import plugins
|
||||
walk(plugins)
|
||||
|
||||
# 执行启动流程
|
||||
logging.info("执行程序启动流程")
|
||||
threading.Thread(target=main.main, args=(False,), daemon=False).start()
|
||||
|
||||
193
pkg/utils/text2img.py
Normal file
193
pkg/utils/text2img.py
Normal file
@@ -0,0 +1,193 @@
|
||||
import logging
|
||||
|
||||
from PIL import Image, ImageDraw, ImageFont
|
||||
import re
|
||||
import os
|
||||
import config
|
||||
import traceback
|
||||
|
||||
text_render_font: ImageFont = None
|
||||
|
||||
if hasattr(config, "blob_message_strategy") and config.blob_message_strategy == "image": # 仅在启用了image时才加载字体
|
||||
use_font = config.font_path if hasattr(config, "font_path") else ""
|
||||
try:
|
||||
|
||||
# 检查是否存在
|
||||
if not os.path.exists(use_font):
|
||||
# 若是windows系统,使用微软雅黑
|
||||
if os.name == "nt":
|
||||
use_font = "C:/Windows/Fonts/msyh.ttc"
|
||||
if not os.path.exists(use_font):
|
||||
logging.warn("未找到字体文件,且无法使用Windows自带字体,更换为转发消息组件以发送长消息,您可以在config.py中调整相关设置。")
|
||||
config.blob_message_strategy = "forward"
|
||||
else:
|
||||
logging.info("使用Windows自带字体:" + use_font)
|
||||
text_render_font = ImageFont.truetype(use_font, 32, encoding="utf-8")
|
||||
else:
|
||||
logging.warn("未找到字体文件,且无法使用Windows自带字体,更换为转发消息组件以发送长消息,您可以在config.py中调整相关设置。")
|
||||
config.blob_message_strategy = "forward"
|
||||
else:
|
||||
text_render_font = ImageFont.truetype(use_font, 32, encoding="utf-8")
|
||||
except:
|
||||
traceback.print_exc()
|
||||
logging.error("加载字体文件失败({}),更换为转发消息组件以发送长消息,您可以在config.py中调整相关设置。".format(use_font))
|
||||
config.blob_message_strategy = "forward"
|
||||
|
||||
|
||||
def indexNumber(path=''):
|
||||
"""
|
||||
查找字符串中数字所在串中的位置
|
||||
:param path:目标字符串
|
||||
:return:<class 'list'>: <class 'list'>: [['1', 16], ['2', 35], ['1', 51]]
|
||||
"""
|
||||
kv = []
|
||||
nums = []
|
||||
beforeDatas = re.findall('[\d]+', path)
|
||||
for num in beforeDatas:
|
||||
indexV = []
|
||||
times = path.count(num)
|
||||
if times > 1:
|
||||
if num not in nums:
|
||||
indexs = re.finditer(num, path)
|
||||
for index in indexs:
|
||||
iV = []
|
||||
i = index.span()[0]
|
||||
iV.append(num)
|
||||
iV.append(i)
|
||||
kv.append(iV)
|
||||
nums.append(num)
|
||||
else:
|
||||
index = path.find(num)
|
||||
indexV.append(num)
|
||||
indexV.append(index)
|
||||
kv.append(indexV)
|
||||
# 根据数字位置排序
|
||||
indexSort = []
|
||||
resultIndex = []
|
||||
for vi in kv:
|
||||
indexSort.append(vi[1])
|
||||
indexSort.sort()
|
||||
for i in indexSort:
|
||||
for v in kv:
|
||||
if i == v[1]:
|
||||
resultIndex.append(v)
|
||||
return resultIndex
|
||||
|
||||
|
||||
def get_size(file):
|
||||
# 获取文件大小:KB
|
||||
size = os.path.getsize(file)
|
||||
return size / 1024
|
||||
|
||||
|
||||
def get_outfile(infile, outfile):
|
||||
if outfile:
|
||||
return outfile
|
||||
dir, suffix = os.path.splitext(infile)
|
||||
outfile = '{}-out{}'.format(dir, suffix)
|
||||
return outfile
|
||||
|
||||
|
||||
def compress_image(infile, outfile='', kb=100, step=20, quality=90):
|
||||
"""不改变图片尺寸压缩到指定大小
|
||||
:param infile: 压缩源文件
|
||||
:param outfile: 压缩文件保存地址
|
||||
:param mb: 压缩目标,KB
|
||||
:param step: 每次调整的压缩比率
|
||||
:param quality: 初始压缩比率
|
||||
:return: 压缩文件地址,压缩文件大小
|
||||
"""
|
||||
o_size = get_size(infile)
|
||||
if o_size <= kb:
|
||||
return infile, o_size
|
||||
outfile = get_outfile(infile, outfile)
|
||||
while o_size > kb:
|
||||
im = Image.open(infile)
|
||||
im.save(outfile, quality=quality)
|
||||
if quality - step < 0:
|
||||
break
|
||||
quality -= step
|
||||
o_size = get_size(outfile)
|
||||
return outfile, get_size(outfile)
|
||||
|
||||
|
||||
def text_to_image(text_str: str, save_as="temp.png", width=800):
|
||||
global text_render_font
|
||||
|
||||
text_str = text_str.replace("\t", " ")
|
||||
|
||||
# 分行
|
||||
lines = text_str.split('\n')
|
||||
|
||||
# 计算并分割
|
||||
final_lines = []
|
||||
|
||||
text_width = width-80
|
||||
for line in lines:
|
||||
# 如果长了就分割
|
||||
line_width = text_render_font.getlength(line)
|
||||
if line_width < text_width:
|
||||
final_lines.append(line)
|
||||
continue
|
||||
else:
|
||||
rest_text = line
|
||||
while True:
|
||||
# 分割最前面的一行
|
||||
point = int(len(rest_text) * (text_width / line_width))
|
||||
|
||||
# 检查断点是否在数字中间
|
||||
numbers = indexNumber(rest_text)
|
||||
|
||||
for number in numbers:
|
||||
if number[1] < point < number[1] + len(number[0]) and number[1] != 0:
|
||||
point = number[1]
|
||||
break
|
||||
|
||||
final_lines.append(rest_text[:point])
|
||||
rest_text = rest_text[point:]
|
||||
line_width = text_render_font.getlength(rest_text)
|
||||
if line_width < text_width:
|
||||
final_lines.append(rest_text)
|
||||
break
|
||||
else:
|
||||
continue
|
||||
# 准备画布
|
||||
img = Image.new('RGBA', (width, max(280, len(final_lines) * 35 + 65)), (255, 255, 255, 255))
|
||||
draw = ImageDraw.Draw(img, mode='RGBA')
|
||||
|
||||
|
||||
# 绘制正文
|
||||
line_number = 0
|
||||
offset_x = 20
|
||||
offset_y = 30
|
||||
for final_line in final_lines:
|
||||
draw.text((offset_x, offset_y + 35 * line_number), final_line, fill=(0, 0, 0), font=text_render_font)
|
||||
# 遍历此行,检查是否有emoji
|
||||
idx_in_line = 0
|
||||
for ch in final_line:
|
||||
# if self.is_emoji(ch):
|
||||
# emoji_img_valid = ensure_emoji(hex(ord(ch))[2:])
|
||||
# if emoji_img_valid: # emoji图像可用,绘制到指定位置
|
||||
# emoji_image = Image.open("emojis/{}.png".format(hex(ord(ch))[2:]), mode='r').convert('RGBA')
|
||||
# emoji_image = emoji_image.resize((32, 32))
|
||||
|
||||
# x, y = emoji_image.size
|
||||
|
||||
# final_emoji_img = Image.new('RGBA', emoji_image.size, (255, 255, 255))
|
||||
# final_emoji_img.paste(emoji_image, (0, 0, x, y), emoji_image)
|
||||
|
||||
# img.paste(final_emoji_img, box=(int(offset_x + idx_in_line * 32), offset_y + 35 * line_number))
|
||||
|
||||
# 检查字符占位宽
|
||||
char_code = ord(ch)
|
||||
if char_code >= 127:
|
||||
idx_in_line += 1
|
||||
else:
|
||||
idx_in_line += 0.5
|
||||
|
||||
line_number += 1
|
||||
|
||||
|
||||
img.save(save_as)
|
||||
|
||||
return save_as
|
||||
@@ -1,15 +1,254 @@
|
||||
import datetime
|
||||
import logging
|
||||
import os.path
|
||||
|
||||
import requests
|
||||
import json
|
||||
|
||||
import pkg.utils.constants
|
||||
|
||||
|
||||
def check_dulwich_closure():
|
||||
try:
|
||||
import pkg.utils.pkgmgr
|
||||
pkg.utils.pkgmgr.ensure_dulwich()
|
||||
except:
|
||||
pass
|
||||
|
||||
def update_all():
|
||||
"""使用dulwich更新源码"""
|
||||
try:
|
||||
import dulwich
|
||||
except ModuleNotFoundError:
|
||||
raise Exception("dulwich模块未安装,请查看 https://github.com/RockChinQ/QChatGPT/issues/77")
|
||||
|
||||
|
||||
def pull_latest(repo_path: str) -> bool:
|
||||
"""拉取最新代码"""
|
||||
check_dulwich_closure()
|
||||
|
||||
from dulwich import porcelain
|
||||
|
||||
repo = porcelain.open_repo(repo_path)
|
||||
porcelain.pull(repo)
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def get_release_list() -> list:
|
||||
"""获取发行列表"""
|
||||
rls_list_resp = requests.get(
|
||||
url="https://api.github.com/repos/RockChinQ/QChatGPT/releases"
|
||||
)
|
||||
|
||||
rls_list = rls_list_resp.json()
|
||||
|
||||
return rls_list
|
||||
|
||||
|
||||
def get_current_tag() -> str:
|
||||
"""获取当前tag"""
|
||||
current_tag = pkg.utils.constants.semantic_version
|
||||
if os.path.exists("current_tag"):
|
||||
with open("current_tag", "r") as f:
|
||||
current_tag = f.read()
|
||||
|
||||
return current_tag
|
||||
|
||||
|
||||
def update_all(cli: bool = False) -> bool:
|
||||
"""检查更新并下载源码"""
|
||||
current_tag = get_current_tag()
|
||||
|
||||
rls_list = get_release_list()
|
||||
|
||||
latest_rls = {}
|
||||
rls_notes = []
|
||||
for rls in rls_list:
|
||||
rls_notes.append(rls['name']) # 使用发行名称作为note
|
||||
if rls['tag_name'] == current_tag:
|
||||
break
|
||||
|
||||
if latest_rls == {}:
|
||||
latest_rls = rls
|
||||
if not cli:
|
||||
logging.info("更新日志: {}".format(rls_notes))
|
||||
else:
|
||||
print("更新日志: {}".format(rls_notes))
|
||||
|
||||
if latest_rls == {}: # 没有新版本
|
||||
return False
|
||||
|
||||
# 下载最新版本的zip到temp目录
|
||||
if not cli:
|
||||
logging.info("开始下载最新版本: {}".format(latest_rls['zipball_url']))
|
||||
else:
|
||||
print("开始下载最新版本: {}".format(latest_rls['zipball_url']))
|
||||
zip_url = latest_rls['zipball_url']
|
||||
zip_resp = requests.get(url=zip_url)
|
||||
zip_data = zip_resp.content
|
||||
|
||||
# 检查temp/updater目录
|
||||
if not os.path.exists("temp"):
|
||||
os.mkdir("temp")
|
||||
if not os.path.exists("temp/updater"):
|
||||
os.mkdir("temp/updater")
|
||||
with open("temp/updater/{}.zip".format(latest_rls['tag_name']), "wb") as f:
|
||||
f.write(zip_data)
|
||||
|
||||
if not cli:
|
||||
logging.info("下载最新版本完成: {}".format("temp/updater/{}.zip".format(latest_rls['tag_name'])))
|
||||
else:
|
||||
print("下载最新版本完成: {}".format("temp/updater/{}.zip".format(latest_rls['tag_name'])))
|
||||
|
||||
# 解压zip到temp/updater/<tag_name>/
|
||||
import zipfile
|
||||
# 检查目标文件夹
|
||||
if os.path.exists("temp/updater/{}".format(latest_rls['tag_name'])):
|
||||
import shutil
|
||||
shutil.rmtree("temp/updater/{}".format(latest_rls['tag_name']))
|
||||
os.mkdir("temp/updater/{}".format(latest_rls['tag_name']))
|
||||
with zipfile.ZipFile("temp/updater/{}.zip".format(latest_rls['tag_name']), 'r') as zip_ref:
|
||||
zip_ref.extractall("temp/updater/{}".format(latest_rls['tag_name']))
|
||||
|
||||
# 覆盖源码
|
||||
source_root = ""
|
||||
# 找到temp/updater/<tag_name>/中的第一个子目录路径
|
||||
for root, dirs, files in os.walk("temp/updater/{}".format(latest_rls['tag_name'])):
|
||||
if root != "temp/updater/{}".format(latest_rls['tag_name']):
|
||||
source_root = root
|
||||
break
|
||||
|
||||
# 覆盖源码
|
||||
import shutil
|
||||
for root, dirs, files in os.walk(source_root):
|
||||
# 覆盖所有子文件子目录
|
||||
for file in files:
|
||||
src = os.path.join(root, file)
|
||||
dst = src.replace(source_root, ".")
|
||||
if os.path.exists(dst):
|
||||
os.remove(dst)
|
||||
|
||||
# 检查目标文件夹是否存在
|
||||
if not os.path.exists(os.path.dirname(dst)):
|
||||
os.makedirs(os.path.dirname(dst))
|
||||
# 检查目标文件是否存在
|
||||
if not os.path.exists(dst):
|
||||
# 创建目标文件
|
||||
open(dst, "w").close()
|
||||
|
||||
shutil.copy(src, dst)
|
||||
|
||||
# 把current_tag写入文件
|
||||
current_tag = latest_rls['tag_name']
|
||||
with open("current_tag", "w") as f:
|
||||
f.write(current_tag)
|
||||
|
||||
# 通知管理员
|
||||
if not cli:
|
||||
import pkg.utils.context
|
||||
pkg.utils.context.get_qqbot_manager().notify_admin("已更新到最新版本: {}\n更新日志:\n{}\n新功能通常可以在config-template.py中看到,完整的更新日志请前往 https://github.com/RockChinQ/QChatGPT/releases 查看".format(current_tag, "\n".join(rls_notes)))
|
||||
else:
|
||||
print("已更新到最新版本: {}\n更新日志:\n{}\n新功能通常可以在config-template.py中看到,完整的更新日志请前往 https://github.com/RockChinQ/QChatGPT/releases 查看".format(current_tag, "\n".join(rls_notes)))
|
||||
return True
|
||||
|
||||
|
||||
def is_repo(path: str) -> bool:
|
||||
"""检查是否是git仓库"""
|
||||
check_dulwich_closure()
|
||||
|
||||
from dulwich import porcelain
|
||||
try:
|
||||
from dulwich import porcelain
|
||||
repo = porcelain.open_repo('.')
|
||||
porcelain.pull(repo)
|
||||
except ModuleNotFoundError:
|
||||
raise Exception("dulwich模块未安装,请查看 https://github.com/RockChinQ/QChatGPT/issues/77")
|
||||
except dulwich.porcelain.DivergedBranches:
|
||||
raise Exception("分支不一致,自动更新仅支持master分支,请手动更新(https://github.com/RockChinQ/QChatGPT/issues/76)")
|
||||
porcelain.open_repo(path)
|
||||
return True
|
||||
except:
|
||||
return False
|
||||
|
||||
|
||||
def get_remote_url(repo_path: str) -> str:
|
||||
"""获取远程仓库地址"""
|
||||
check_dulwich_closure()
|
||||
|
||||
from dulwich import porcelain
|
||||
repo = porcelain.open_repo(repo_path)
|
||||
return str(porcelain.get_remote_repo(repo, "origin")[1])
|
||||
|
||||
|
||||
def get_current_version_info() -> str:
|
||||
"""获取当前版本信息"""
|
||||
rls_list = get_release_list()
|
||||
current_tag = get_current_tag()
|
||||
for rls in rls_list:
|
||||
if rls['tag_name'] == current_tag:
|
||||
return rls['name'] + "\n" + rls['body']
|
||||
return "未知版本"
|
||||
|
||||
|
||||
def get_commit_id_and_time_and_msg() -> str:
|
||||
"""获取当前提交id和时间和提交信息"""
|
||||
check_dulwich_closure()
|
||||
|
||||
from dulwich import porcelain
|
||||
|
||||
repo = porcelain.open_repo('.')
|
||||
|
||||
for entry in repo.get_walker():
|
||||
tz = datetime.timezone(datetime.timedelta(hours=entry.commit.commit_timezone // 3600))
|
||||
dt = datetime.datetime.fromtimestamp(entry.commit.commit_time, tz)
|
||||
return str(entry.commit.id)[2:9] + " " + dt.strftime('%Y-%m-%d %H:%M:%S') + " [" + str(entry.commit.message, encoding="utf-8").strip()+"]"
|
||||
|
||||
|
||||
def get_current_commit_id() -> str:
|
||||
"""检查是否有新版本"""
|
||||
check_dulwich_closure()
|
||||
|
||||
from dulwich import porcelain
|
||||
|
||||
repo = porcelain.open_repo('.')
|
||||
current_commit_id = ""
|
||||
for entry in repo.get_walker():
|
||||
current_commit_id = str(entry.commit.id)[2:-1]
|
||||
break
|
||||
|
||||
return current_commit_id
|
||||
|
||||
|
||||
def is_new_version_available() -> bool:
|
||||
"""检查是否有新版本"""
|
||||
# 从github获取release列表
|
||||
rls_list = get_release_list()
|
||||
if rls_list is None:
|
||||
return False
|
||||
|
||||
# 获取当前版本
|
||||
current_tag = get_current_tag()
|
||||
|
||||
# 检查是否有新版本
|
||||
for rls in rls_list:
|
||||
if rls['tag_name'] == current_tag:
|
||||
return False
|
||||
else:
|
||||
return True
|
||||
|
||||
|
||||
def get_rls_notes() -> list:
|
||||
"""获取更新日志"""
|
||||
# 从github获取release列表
|
||||
rls_list = get_release_list()
|
||||
if rls_list is None:
|
||||
return None
|
||||
|
||||
# 获取当前版本
|
||||
current_tag = get_current_tag()
|
||||
|
||||
# 检查是否有新版本
|
||||
rls_notes = []
|
||||
for rls in rls_list:
|
||||
if rls['tag_name'] == current_tag:
|
||||
break
|
||||
|
||||
rls_notes.append(rls['name'])
|
||||
|
||||
return rls_notes
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
update_all()
|
||||
|
||||
0
plugins/__init__.py
Normal file
0
plugins/__init__.py
Normal file
9
requirements.txt
Normal file
9
requirements.txt
Normal file
@@ -0,0 +1,9 @@
|
||||
requests~=2.28.1
|
||||
openai~=0.27.0
|
||||
dulwich~=0.21.3
|
||||
colorlog~=6.6.0
|
||||
yiri-mirai~=0.2.6.1
|
||||
websockets~=10.4
|
||||
urllib3~=1.26.10
|
||||
func_timeout~=4.3.5
|
||||
Pillow
|
||||
BIN
res/alipay.jpg
Normal file
BIN
res/alipay.jpg
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 26 KiB |
BIN
res/plugin_hello_group.jpg
Normal file
BIN
res/plugin_hello_group.jpg
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 54 KiB |
BIN
res/plugin_hello_person.png
Normal file
BIN
res/plugin_hello_person.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 73 KiB |
12
scenario/default-template.json
Normal file
12
scenario/default-template.json
Normal file
@@ -0,0 +1,12 @@
|
||||
{
|
||||
"prompt": [
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are a helpful assistant. 如果我需要帮助,你要说“输入!help获得帮助”"
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "好的,我是一个能干的AI助手。 如果你需要帮助,我会说“输入!help获得帮助”"
|
||||
}
|
||||
]
|
||||
}
|
||||
@@ -1,4 +1,7 @@
|
||||
{
|
||||
"说明": "mask将替换敏感词中的每一个字,若mask_word值不为空,则将敏感词整个替换为mask_word的值",
|
||||
"mask": "*",
|
||||
"mask_word": "",
|
||||
"words": [
|
||||
"习近平",
|
||||
"胡锦涛",
|
||||
@@ -9,6 +12,7 @@
|
||||
"毛泽东",
|
||||
"邓小平",
|
||||
"周恩来",
|
||||
"马克思",
|
||||
"社会主义",
|
||||
"共产党",
|
||||
"共产主义",
|
||||
@@ -21,6 +25,8 @@
|
||||
"天安门",
|
||||
"六四",
|
||||
"政治局常委",
|
||||
"两会",
|
||||
"共青团",
|
||||
"学潮",
|
||||
"八九",
|
||||
"二十大",
|
||||
@@ -48,6 +54,7 @@
|
||||
"作爱",
|
||||
"做爱",
|
||||
"性交",
|
||||
"性爱",
|
||||
"自慰",
|
||||
"阴茎",
|
||||
"淫妇",
|
||||
46
tests/compatibility_tests/models_and_interfaces.py
Normal file
46
tests/compatibility_tests/models_and_interfaces.py
Normal file
@@ -0,0 +1,46 @@
|
||||
import openai
|
||||
import time
|
||||
|
||||
# 测试completion api
|
||||
models = [
|
||||
'gpt-3.5-turbo',
|
||||
'gpt-3.5-turbo-0301',
|
||||
'text-davinci-003',
|
||||
'text-davinci-002',
|
||||
'code-davinci-002',
|
||||
'code-cushman-001',
|
||||
'text-curie-001',
|
||||
'text-babbage-001',
|
||||
'text-ada-001',
|
||||
]
|
||||
|
||||
openai.api_key = "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"
|
||||
|
||||
for model in models:
|
||||
print('Testing model: ', model)
|
||||
|
||||
# completion api
|
||||
try:
|
||||
response = openai.Completion.create(
|
||||
model=model,
|
||||
prompt="Say this is a test",
|
||||
max_tokens=7,
|
||||
temperature=0
|
||||
)
|
||||
print(' completion api: ', response['choices'][0]['text'].strip())
|
||||
except Exception as e:
|
||||
print(' completion api err: ', e)
|
||||
|
||||
# chat completion api
|
||||
try:
|
||||
completion = openai.ChatCompletion.create(
|
||||
model="gpt-3.5-turbo",
|
||||
messages=[
|
||||
{"role": "user", "content": "Hello!"}
|
||||
]
|
||||
)
|
||||
print(" chat api: ",completion.choices[0].message['content'].strip())
|
||||
except Exception as e:
|
||||
print(' chat api err: ', e)
|
||||
|
||||
time.sleep(60)
|
||||
3
tests/plugin_examples/__init__.py
Normal file
3
tests/plugin_examples/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
# 插件示例
|
||||
# 将此目录下的目录放入plugins目录即可使用
|
||||
# 每个示例插件的功能请查看其包内的__init__.py或README.md
|
||||
0
tests/plugin_examples/auto_approval/__init__.py
Normal file
0
tests/plugin_examples/auto_approval/__init__.py
Normal file
44
tests/plugin_examples/auto_approval/main.py
Normal file
44
tests/plugin_examples/auto_approval/main.py
Normal file
@@ -0,0 +1,44 @@
|
||||
from mirai import Mirai
|
||||
|
||||
import pkg.qqbot.manager
|
||||
from pkg.plugin.models import *
|
||||
from pkg.plugin.host import PluginHost
|
||||
|
||||
from mirai.models import MemberJoinRequestEvent
|
||||
|
||||
"""
|
||||
加群自动审批
|
||||
"""
|
||||
|
||||
__group_id__ = 1025599757
|
||||
__application_contains__ = ['github', 'gitee', 'Github', 'Gitee', 'GitHub']
|
||||
|
||||
|
||||
# 注册插件
|
||||
@register(name="加群审批", description="自动审批加群申请", version="0.1", author="RockChinQ")
|
||||
class AutoApproval(Plugin):
|
||||
|
||||
bot: Mirai = None
|
||||
|
||||
# 插件加载时触发
|
||||
def __init__(self, plugin_host: PluginHost):
|
||||
qqmgr = plugin_host.get_runtime_context().get_qqbot_manager()
|
||||
assert isinstance(qqmgr, pkg.qqbot.manager.QQBotManager)
|
||||
self.bot = qqmgr.bot
|
||||
|
||||
# 向YiriMirai注册 加群申请 事件处理函数
|
||||
@qqmgr.bot.on(MemberJoinRequestEvent)
|
||||
async def process(event: MemberJoinRequestEvent):
|
||||
assert isinstance(qqmgr, pkg.qqbot.manager.QQBotManager)
|
||||
if event.group_id == __group_id__:
|
||||
if any([x in event.message for x in __application_contains__]):
|
||||
logging.info("自动同意加群申请")
|
||||
await qqmgr.bot.allow(event)
|
||||
|
||||
self.process = process
|
||||
|
||||
# 插件卸载时触发
|
||||
def __del__(self):
|
||||
# 关闭时向YiriMirai注销 加群申请 事件处理函数
|
||||
if self.bot is not None:
|
||||
self.bot.bus.unsubscribe(MemberJoinRequestEvent, self.process)
|
||||
0
tests/plugin_examples/cmdcn/__init__.py
Normal file
0
tests/plugin_examples/cmdcn/__init__.py
Normal file
51
tests/plugin_examples/cmdcn/cmdcn.py
Normal file
51
tests/plugin_examples/cmdcn/cmdcn.py
Normal file
@@ -0,0 +1,51 @@
|
||||
from pkg.plugin.models import *
|
||||
from pkg.plugin.host import EventContext, PluginHost
|
||||
|
||||
"""
|
||||
基本命令的中文形式支持
|
||||
"""
|
||||
|
||||
|
||||
__mapping__ = {
|
||||
"帮助": "help",
|
||||
"重置": "reset",
|
||||
"前一次": "last",
|
||||
"后一次": "next",
|
||||
"会话内容": "prompt",
|
||||
"列出会话": "list",
|
||||
"重新回答": "resend",
|
||||
"使用量": "usage",
|
||||
"绘画": "draw",
|
||||
"版本": "version",
|
||||
"热重载": "reload",
|
||||
"热更新": "update",
|
||||
"配置": "cfg",
|
||||
}
|
||||
|
||||
|
||||
@register(name="CmdCN", description="命令中文支持", version="0.1", author="RockChinQ")
|
||||
class CmdCnPlugin(Plugin):
|
||||
|
||||
def __init__(self, plugin_host: PluginHost):
|
||||
pass
|
||||
|
||||
# 私聊发送指令
|
||||
@on(PersonCommandSent)
|
||||
def person_command_sent(self, event: EventContext, **kwargs):
|
||||
cmd = kwargs['command']
|
||||
if cmd in __mapping__:
|
||||
|
||||
# 返回替换后的指令
|
||||
event.add_return("alter", "!"+__mapping__[cmd]+" "+" ".join(kwargs['params']))
|
||||
|
||||
# 群聊发送指令
|
||||
@on(GroupCommandSent)
|
||||
def group_command_sent(self, event: EventContext, **kwargs):
|
||||
cmd = kwargs['command']
|
||||
if cmd in __mapping__:
|
||||
|
||||
# 返回替换后的指令
|
||||
event.add_return("alter", "!"+__mapping__[cmd]+" "+" ".join(kwargs['params']))
|
||||
|
||||
def __del__(self):
|
||||
pass
|
||||
0
tests/plugin_examples/hello_plugin/__init__.py
Normal file
0
tests/plugin_examples/hello_plugin/__init__.py
Normal file
50
tests/plugin_examples/hello_plugin/main.py
Normal file
50
tests/plugin_examples/hello_plugin/main.py
Normal file
@@ -0,0 +1,50 @@
|
||||
from pkg.plugin.models import *
|
||||
from pkg.plugin.host import EventContext, PluginHost
|
||||
|
||||
"""
|
||||
在收到私聊或群聊消息"hello"时,回复"hello, <发送者id>!"或"hello, everyone!"
|
||||
"""
|
||||
|
||||
|
||||
# 注册插件
|
||||
@register(name="Hello", description="hello world", version="0.1", author="RockChinQ")
|
||||
class HelloPlugin(Plugin):
|
||||
|
||||
# 插件加载时触发
|
||||
# plugin_host (pkg.plugin.host.PluginHost) 提供了与主程序交互的一些方法,详细请查看其源码
|
||||
def __init__(self, plugin_host: PluginHost):
|
||||
pass
|
||||
|
||||
# 当收到个人消息时触发
|
||||
@on(PersonNormalMessageReceived)
|
||||
def person_normal_message_received(self, event: EventContext, **kwargs):
|
||||
msg = kwargs['text_message']
|
||||
if msg == "hello": # 如果消息为hello
|
||||
|
||||
# 输出调试信息
|
||||
logging.debug("hello, {}".format(kwargs['sender_id']))
|
||||
|
||||
# 回复消息 "hello, <发送者id>!"
|
||||
event.add_return("reply", ["hello, {}!".format(kwargs['sender_id'])])
|
||||
|
||||
# 阻止该事件默认行为(向接口获取回复)
|
||||
event.prevent_default()
|
||||
|
||||
# 当收到群消息时触发
|
||||
@on(GroupNormalMessageReceived)
|
||||
def group_normal_message_received(self, event: EventContext, **kwargs):
|
||||
msg = kwargs['text_message']
|
||||
if msg == "hello": # 如果消息为hello
|
||||
|
||||
# 输出调试信息
|
||||
logging.debug("hello, {}".format(kwargs['sender_id']))
|
||||
|
||||
# 回复消息 "hello, everyone!"
|
||||
event.add_return("reply", ["hello, everyone!"])
|
||||
|
||||
# 阻止该事件默认行为(向接口获取回复)
|
||||
event.prevent_default()
|
||||
|
||||
# 插件卸载时触发
|
||||
def __del__(self):
|
||||
pass
|
||||
44
tests/plugin_examples/urlikethisijustsix/urlt.py
Normal file
44
tests/plugin_examples/urlikethisijustsix/urlt.py
Normal file
@@ -0,0 +1,44 @@
|
||||
import random
|
||||
|
||||
from mirai import Plain
|
||||
|
||||
from pkg.plugin.models import *
|
||||
from pkg.plugin.host import EventContext, PluginHost
|
||||
|
||||
"""
|
||||
私聊或群聊消息为以下列出的一些冒犯性词语时,自动回复__random_reply__中的一句话
|
||||
"""
|
||||
|
||||
|
||||
__words__ = ['sb', "傻逼", "dinner", "操你妈", "cnm", "fuck you", "fuckyou",
|
||||
"f*ck you", "弱智", "若智", "答辩", "依托答辩", "低能儿", "nt", "脑瘫", "闹谈", "老坛"]
|
||||
|
||||
__random_reply__ = ['好好好', "啊对对对", "好好好好", "你说得对", "谢谢夸奖"]
|
||||
|
||||
|
||||
@register(name="啊对对对", description="你都这样了,我就顺从你吧", version="0.1", author="RockChinQ")
|
||||
class AdddPlugin(Plugin):
|
||||
|
||||
def __init__(self, plugin_host: PluginHost):
|
||||
pass
|
||||
|
||||
# 绑定私聊消息事件和群消息事件
|
||||
@on(PersonNormalMessageReceived)
|
||||
@on(GroupNormalMessageReceived)
|
||||
def normal_message_received(self, event: EventContext, **kwargs):
|
||||
msg = kwargs['text_message']
|
||||
|
||||
# 如果消息中包含关键词
|
||||
if msg in __words__:
|
||||
# 随机一个回复
|
||||
idx = random.randint(0, len(__random_reply__)-1)
|
||||
|
||||
# 返回回复的消息
|
||||
event.add_return("reply", [Plain(__random_reply__[idx])])
|
||||
|
||||
# 阻止向接口获取回复
|
||||
event.prevent_default()
|
||||
event.prevent_postorder()
|
||||
|
||||
def __del__(self):
|
||||
pass
|
||||
Reference in New Issue
Block a user