Compare commits

..

92 Commits

Author SHA1 Message Date
RockChinQ
6919bece77 chore: release v3.1.0.2 2024-03-31 14:41:32 +08:00
RockChinQ
8b003739f1 feat: message.content 支持 mirai.MessageChain 对象 (#741) 2024-03-31 14:38:15 +08:00
RockChinQ
2e9229a6ad fix: 工作目录必须在 main.py 目录 2024-03-30 21:34:22 +08:00
RockChinQ
5a3e7fe8ee perf: 禁止双击运行 2024-03-30 21:28:42 +08:00
RockChinQ
7b3d7e7bd6 fix: json配置文件错误的加载流程 2024-03-30 19:01:59 +08:00
Junyan Qin
fdd7c1864d feat(chatcmpl): 对函数调用进行异常捕获 (#749) 2024-03-30 09:45:30 +00:00
Junyan Qin
cac5a5adff fix(qq-botpy): 群内单query多回复时msg_seq重复问题 2024-03-30 02:58:37 +00:00
RockChinQ
63307633c2 feat: chatcmpl请求时也忽略空的 system prompt message (#745) 2024-03-29 17:34:09 +08:00
RockChinQ
387dfa39ff fix: 内容过滤无效 (#743) 2024-03-29 17:24:42 +08:00
Junyan Qin
1f797f899c doc(README.md): 添加使用量计数徽标 2024-03-26 15:25:08 +08:00
RockChinQ
092bb0a1e2 chore: release v3.1.0.1 2024-03-23 22:50:54 +08:00
RockChinQ
2c3399e237 perf: 敏感词迁移的双条件检查 2024-03-23 22:41:21 +08:00
RockChinQ
835275b47f fix: 多处对 launcher_type 枚举的不当比较 (#736) 2024-03-23 22:39:42 +08:00
Junyan Qin
7b060ce3f9 doc(README.md): 更新wakapi路径 2024-03-23 19:14:43 +08:00
RockChinQ
1fb69311b0 chore: release v3.1.0 2024-03-22 17:17:16 +08:00
Junyan Qin
995d1f61d2 Merge pull request #735 from RockChinQ/feat/plugin-api
Feat: 插件异步 API
2024-03-22 17:10:06 +08:00
RockChinQ
80258e9182 perf: 修改platform_mgr名称 2024-03-22 17:09:43 +08:00
RockChinQ
bd6a32e08e doc: 为可扩展组件添加注释 2024-03-22 16:41:46 +08:00
RockChinQ
5f138de75b doc: 完善query对象的注释 2024-03-22 11:05:58 +08:00
RockChinQ
d0b0f2209a fix: chat处理过程的插件返回值目标错误 2024-03-20 23:32:28 +08:00
RockChinQ
0752698c1d chore: 完善plugin对外对象的注释 2024-03-20 18:43:52 +08:00
RockChinQ
9855c6b8f5 feat: 新的引入路径 2024-03-20 15:48:11 +08:00
RockChinQ
52a7c25540 feat: 异步风格插件方法注册器 2024-03-20 15:09:47 +08:00
RockChinQ
fa823de6b0 perf: 初始化config对象时支持传递dict作为模板 2024-03-20 14:20:56 +08:00
RockChinQ
f53070d8b6 feat: 插件加载阶段前置 (#681) 2024-03-19 22:48:02 +08:00
Junyan Qin
7677672691 Merge pull request #734 from RockChinQ/feat/moonshot
Feat: 添加对 moonshot 模型的支持
2024-03-19 22:41:40 +08:00
RockChinQ
dead8fa168 feat: 添加对 moonshot 模型的支持 2024-03-19 22:39:45 +08:00
RockChinQ
c6347bea45 fix: full-scenario 命名和目录名错误问题 (#731) 2024-03-18 21:05:54 +08:00
RockChinQ
32bd194bfc chore: anthropic 的配置补全迁移 2024-03-18 21:04:09 +08:00
Junyan Qin
cca48a394d Merge pull request #732 from RockChinQ/feat/claude-3
Feat: 接入 claude 3 系列模型
2024-03-18 11:27:22 +08:00
RockChinQ
a723c8ce37 perf: claude 的接口异常处理 2024-03-17 23:22:26 -04:00
RockChinQ
327b2509f6 perf: 忽略用户空消息 2024-03-17 23:06:40 -04:00
RockChinQ
1dae7bd655 feat: 对 claude api 的基本支持 2024-03-17 12:44:45 -04:00
RockChinQ
550a131685 deps: 添加 anthropic 依赖库 2024-03-17 12:03:25 -04:00
RockChinQ
0cfb8bb29f fix: 获取模型列表时未传递version参数 2024-03-16 22:23:02 +08:00
Junyan Qin
9c32420a95 Merge pull request #730 from RockChinQ/feat/customized-model
Feat: 允许自定义模型信息
2024-03-16 22:19:27 +08:00
RockChinQ
867093cc88 chore: 更改 provider.json 格式 2024-03-16 22:12:13 +08:00
RockChinQ
82763f8ec5 chore: 删除默认prompt 2024-03-16 21:43:45 +08:00
RockChinQ
97449065df feat: 通过元数据生成模型列表 2024-03-16 21:43:09 +08:00
Junyan Qin
9489783846 Merge pull request #729 from RockChinQ/feat/migration-stage
Feat: 配置文件迁移功能
2024-03-16 20:34:29 +08:00
RockChinQ
f91c9015bc feat: 添加配置文件迁移阶段 2024-03-16 20:27:17 +08:00
RockChinQ
302d86056d refactor: 所有的 json 加载统一到启动阶段中 2024-03-16 15:41:59 +08:00
Junyan Qin
98bebfddaa Merge pull request #728 from RockChinQ/feat/active-message
Feat: aiocqhttp 和 qq-botpy 适配器的主动消息发送接口
2024-03-16 15:18:27 +08:00
RockChinQ
dab20e3187 feat: aiocqhttp和qq-botpy的主动消息发送接口 2024-03-16 15:16:46 +08:00
RockChinQ
09e72f7c5f chore: 删除注释的代码 2024-03-14 17:24:36 +08:00
Junyan Qin
2028d85f84 Merge pull request #726 from RockChinQ/feat/qq-botpy-cache
Feat: qq-botpy 适配器对 member 和 group 的 openid 进行静态缓存
2024-03-14 16:05:14 +08:00
RockChinQ
ed3c0d9014 feat: qq-botpy 适配器对 member 和 group 的 openid 进行静态缓存 2024-03-14 16:00:22 +08:00
RockChinQ
be06150990 chore: aiocqhttp添加默认access-token参数 2024-03-13 16:53:30 +08:00
Junyan Qin
afb3fb4a31 Merge pull request #725 from RockChinQ/feat/aiocqhttp-access-token
Feat: aiocqhttp支持access-token
2024-03-13 16:49:56 +08:00
RockChinQ
d66577e6c3 feat: aiocqhttp支持access-token 2024-03-13 16:49:11 +08:00
Junyan Qin
6a4ea5446a Merge pull request #724 from RockChinQ/fix/at-resp
Fix: 回复并at机器人时会多一个at组件
2024-03-13 16:31:54 +08:00
RockChinQ
74e84c744a fix: 回复并at机器人时会多一个at组件 2024-03-13 16:31:06 +08:00
Junyan Qin
5ad2446cf3 Update bug-report.yml 2024-03-13 16:13:14 +08:00
Junyan Qin
63303bb5c0 Merge pull request #712 from RockChinQ/feat/component-extensibility
Feat: 更多组件的可扩展性
2024-03-13 00:32:26 +08:00
Junyan Qin
13393b6624 feat: 限速算法的扩展性 2024-03-12 16:31:54 +00:00
Junyan Qin
b9fa11c0c3 feat: prompt 加载器的扩展性 2024-03-12 16:22:07 +00:00
RockChinQ
8c6ce1f030 feat: 群响应规则的扩展性 2024-03-12 23:34:13 +08:00
RockChinQ
1d963d0f0c feat: 不再预先计算前文token数而是在报错时提醒用户重置 2024-03-12 16:04:11 +08:00
Junyan Qin
0ee383be27 Update announcement.json 2024-03-08 22:35:17 +08:00
RockChinQ
53d09129b4 fix: 命令事件的command参数处理错误 (#713) 2024-03-08 21:10:43 +08:00
RockChinQ
a398c6f311 feat: 消息平台适配器可扩展性 2024-03-08 20:40:54 +08:00
RockChinQ
4347ddd42a feat: 长消息处理策略可扩展性 2024-03-08 20:31:22 +08:00
RockChinQ
22cb8a6a06 feat: 内容过滤器的可扩展性 2024-03-08 20:22:06 +08:00
RockChinQ
7f554fd862 feat: command支持扩展命令类 2024-03-08 19:56:57 +08:00
Junyan Qin
a82bfa8a56 perf: 为命令装饰器添加断言 2024-03-08 11:38:26 +00:00
RockChinQ
95784debbf perf: 支持识别docker环境 2024-03-07 15:55:02 +08:00
Junyan Qin
2471c5bf0f Merge pull request #709 from RockChinQ/doc/comments
Doc: 补全部分注释
2024-03-03 16:35:31 +08:00
RockChinQ
2fe6d731b8 doc: 补全部分注释 2024-03-03 16:34:59 +08:00
RockChinQ
ce881372ee chore: release v3.0.2 2024-03-02 21:03:04 +08:00
Junyan Qin
171ea7c375 Merge pull request #708 from RockChinQ/fix/llonebot-not-supported
Fix: 修复使用llonebot时的协议问题
2024-03-02 20:59:41 +08:00
RockChinQ
1e9a6f813f fix: 修复使用llonebot时的协议问题 2024-03-02 20:58:58 +08:00
Junyan Qin
39a7f3b2b9 Merge pull request #707 from RockChinQ/feat/booting-stages
Feat: 分阶段启动
2024-03-02 20:27:51 +08:00
RockChinQ
8d375a02db fix: 未导入问题 2024-03-02 20:05:23 +08:00
RockChinQ
cac8a0a414 perf: 优化导入 2024-03-02 16:39:29 +08:00
RockChinQ
c89623967e refactor: 应用初始化流程初步分阶段 2024-03-02 16:37:30 +08:00
RockChinQ
92aa9c1711 perf: 配置文件生成步骤移动到main.py 2024-03-02 14:57:55 +08:00
Junyan Qin
71f2a58acb feat: 依赖检查移动到main.py 2024-02-29 11:10:30 +00:00
RockChinQ
1f07a8a9e3 refactor: 移动pool到pipeline包 2024-02-29 03:38:38 +00:00
RockChinQ
cacd21bde7 refactor: 移动控制器到pipeline包 2024-02-29 03:38:38 +00:00
RockChinQ
a060ec66c3 deps: 整理依赖 2024-02-29 11:03:11 +08:00
Junyan Qin
fd10db3c75 ci: fix 2024-02-21 13:56:38 +00:00
Junyan Qin
db4c658980 chore: test 2024-02-21 13:52:54 +00:00
Junyan Qin
0ee88674f8 ci: update 2024-02-21 13:52:33 +00:00
Junyan Qin
3540759682 chore: release v3.0.1.1 2024-02-21 13:46:38 +00:00
Junyan Qin
44cc8f15b4 Merge pull request #695 from RockChinQ/ci/arm-image
CI: 构建arm64镜像
2024-02-21 21:45:40 +08:00
Junyan Qin
59f821bf0a ci: 构建arm64镜像 2024-02-21 13:44:07 +00:00
RockChinQ
80858672b0 perf: 控制台输出请求响应过程 2024-02-20 22:56:42 +08:00
RockChinQ
3258d5b255 chore: aiocqhttp默认监听地址改为0.0.0.0 2024-02-20 20:13:46 +08:00
RockChinQ
e8c8cc0a9c chore: release v3.0.1 2024-02-20 11:48:26 +08:00
Junyan Qin
570c19f29f Merge pull request #693 from RockChinQ/fix/3.9-compability
Fix: 针对python3.9的兼容性
2024-02-20 11:47:49 +08:00
RockChinQ
ee93fd8636 hotfix: 针对python3.9的兼容性 2024-02-20 11:47:04 +08:00
RockChinQ
1e6c32ffc7 fix: 'VersionManager' object has no attribute 'get_release_list' 2024-02-20 09:54:02 +08:00
113 changed files with 2153 additions and 975 deletions

View File

@@ -16,11 +16,13 @@ body:
required: true
- type: dropdown
attributes:
label: 登录框架
label: 消息平台适配器
description: "连接QQ使用的框架"
options:
- Mirai
- go-cqhttp
- yiri-miraiMirai
- Nakurugo-cqhttp
- aiocqhttp使用 OneBot 协议接入的)
- qq-botpyQQ官方API
validations:
required: false
- type: input

View File

@@ -17,22 +17,32 @@ jobs:
run: |
if [ -z "$GITHUB_REF" ]; then
export GITHUB_REF=${{ github.ref }}
echo $GITHUB_REF
fi
# - name: Check GITHUB_REF env
# run: echo $GITHUB_REF
# - name: Get version # 在 GitHub Actions 运行环境
# id: get_version
# if: (startsWith(env.GITHUB_REF, 'refs/tags/')||startsWith(github.ref, 'refs/tags/')) && startsWith(github.repository, 'RockChinQ/QChatGPT')
# run: export GITHUB_REF=${GITHUB_REF/refs\/tags\//}
- name: Check version
id: check_version
run: |
echo $GITHUB_REF
# 如果是tag则去掉refs/tags/前缀
if [[ $GITHUB_REF == refs/tags/* ]]; then
echo "It's a tag"
echo $GITHUB_REF
echo $GITHUB_REF | awk -F '/' '{print $3}'
echo ::set-output name=version::$(echo $GITHUB_REF | awk -F '/' '{print $3}')
else
echo "It's not a tag"
echo $GITHUB_REF
echo ::set-output name=version::${GITHUB_REF}
fi
- name: Check GITHUB_REF env
run: echo $GITHUB_REF
- name: Get version
id: get_version
if: (startsWith(env.GITHUB_REF, 'refs/tags/')||startsWith(github.ref, 'refs/tags/')) && startsWith(github.repository, 'RockChinQ/QChatGPT')
run: echo ::set-output name=VERSION::${GITHUB_REF/refs\/tags\//}
- name: Build # image name: rockchin/qchatgpt:<VERSION>
run: docker build --network=host -t rockchin/qchatgpt:${{ steps.get_version.outputs.VERSION }} -t rockchin/qchatgpt:latest .
- name: Login to Registry
run: docker login --username=${{ secrets.DOCKER_USERNAME }} --password ${{ secrets.DOCKER_PASSWORD }}
- name: Push image
if: (startsWith(env.GITHUB_REF, 'refs/tags/')||startsWith(github.ref, 'refs/tags/')) && startsWith(github.repository, 'RockChinQ/QChatGPT')
run: docker push rockchin/qchatgpt:${{ steps.get_version.outputs.VERSION }}
- name: Push latest image
if: (startsWith(env.GITHUB_REF, 'refs/tags/')||startsWith(github.ref, 'refs/tags/')) && startsWith(github.repository, 'RockChinQ/QChatGPT')
run: docker push rockchin/qchatgpt:latest
- name: Create Buildx
run: docker buildx create --name mybuilder --use
- name: Build # image name: rockchin/qchatgpt:<VERSION>
run: docker buildx build --platform linux/arm64,linux/amd64 -t rockchin/qchatgpt:${{ steps.check_version.outputs.version }} -t rockchin/qchatgpt:latest . --push

2
.gitignore vendored
View File

@@ -34,4 +34,4 @@ bard.json
res/instance_id.json
.DS_Store
/data
botpy.log
botpy.log*

View File

@@ -3,6 +3,9 @@ WORKDIR /app
COPY . .
RUN python -m pip install -r requirements.txt
RUN apt update \
&& apt install gcc -y \
&& python -m pip install -r requirements.txt \
&& touch /.dockerenv
CMD [ "python", "main.py" ]

View File

@@ -11,10 +11,8 @@
<a href="https://hub.docker.com/repository/docker/rockchin/qchatgpt">
<img src="https://img.shields.io/docker/pulls/rockchin/qchatgpt?color=blue" alt="docker pull">
</a>
![Wakapi Count](https://wakapi.dev/api/badge/RockChinQ/interval:any/project:QChatGPT)
<a href="https://codecov.io/gh/RockChinQ/QChatGPT" >
<img src="https://codecov.io/gh/RockChinQ/QChatGPT/graph/badge.svg?token=pjxYIL2kbC"/>
</a>
![Dynamic JSON Badge](https://img.shields.io/badge/dynamic/json?url=https%3A%2F%2Fapi.qchatgpt.rockchin.top%2Fapi%2Fv2%2Fview%2Frealtime%2Fcount_query%3Fminute%3D10080&query=%24.data.count&label=%E4%BD%BF%E7%94%A8%E9%87%8F%EF%BC%887%E6%97%A5%EF%BC%89)
![Wakapi Count](https://wakapi.rockchin.top/api/badge/RockChinQ/interval:any/project:QChatGPT)
<br/>
<img src="https://img.shields.io/badge/python-3.9 | 3.10 | 3.11-blue.svg" alt="python">
<a href="http://qm.qq.com/cgi-bin/qm/qr?_wv=1027&k=66-aWvn8cbP4c1ut_1YYkvvGVeEtyTH8&authKey=pTaKBK5C%2B8dFzQ4XlENf6MHTCLaHnlKcCRx7c14EeVVlpX2nRSaS8lJm8YeM4mCU&noverify=0&group_code=195992197">
@@ -23,11 +21,8 @@
<a href="http://qm.qq.com/cgi-bin/qm/qr?_wv=1027&k=nC80H57wmKPwRDLFeQrDDjVl81XuC21P&authKey=2wTUTfoQ5v%2BD4C5zfpuR%2BSPMDqdXgDXA%2FS2wHI1NxTfWIG%2B%2FqK08dgyjMMOzhXa9&noverify=0&group_code=248432104">
<img alt="Static Badge" src="https://img.shields.io/badge/%E7%A4%BE%E5%8C%BA%E7%BE%A4-248432104-purple">
</a>
<a href="https://www.bilibili.com/video/BV14h4y1w7TC">
<img alt="Static Badge" src="https://img.shields.io/badge/%E8%A7%86%E9%A2%91%E6%95%99%E7%A8%8B-208647">
</a>
<a href="https://www.bilibili.com/video/BV11h4y1y74H">
<img alt="Static Badge" src="https://img.shields.io/badge/Linux%E9%83%A8%E7%BD%B2%E8%A7%86%E9%A2%91-208647">
<a href="https://codecov.io/gh/RockChinQ/QChatGPT" >
<img src="https://codecov.io/gh/RockChinQ/QChatGPT/graph/badge.svg?token=pjxYIL2kbC"/>
</a>
## 使用文档

73
main.py
View File

@@ -1,5 +1,6 @@
import asyncio
# QChatGPT 终端启动入口
# 在此层级解决依赖项检查。
# QChatGPT/main.py
asciiart = r"""
___ ___ _ _ ___ ___ _____
@@ -11,8 +12,72 @@ asciiart = r"""
📖文档地址: https://q.rkcn.top
"""
if __name__ == '__main__':
async def main_entry():
print(asciiart)
import sys
# 检查依赖
from pkg.core.bootutils import deps
missing_deps = await deps.check_deps()
if missing_deps:
print("以下依赖包未安装,将自动安装,请完成后重启程序:")
for dep in missing_deps:
print("-", dep)
await deps.install_deps(missing_deps)
print("已自动安装缺失的依赖包,请重启程序。")
sys.exit(0)
# 检查配置文件
from pkg.core.bootutils import files
generated_files = await files.generate_files()
if generated_files:
print("以下文件不存在,已自动生成,请按需修改配置文件后重启:")
for file in generated_files:
print("-", file)
sys.exit(0)
from pkg.core import boot
asyncio.run(boot.main())
await boot.main()
if __name__ == '__main__':
import os
import psutil
if os.name == 'nt':
allowed_parent_process = ['cmd.exe', 'powershell.exe', 'wsl.exe']
parent_process = psutil.Process(os.getppid()).name()
if parent_process not in allowed_parent_process:
print("请在命令行中运行此程序。")
input("按任意键退出...")
exit(0)
# 检查本目录是否有main.py且包含QChatGPT字符串
invalid_pwd = False
if not os.path.exists('main.py'):
invalid_pwd = True
else:
with open('main.py', 'r', encoding='utf-8') as f:
content = f.read()
if "QChatGPT/main.py" not in content:
invalid_pwd = True
if invalid_pwd:
print("请在QChatGPT项目根目录下运行此程序。")
input("按任意键退出...")
exit(0)
import asyncio
asyncio.run(main_entry())

View File

@@ -34,6 +34,9 @@ class APIGroup(metaclass=abc.ABCMeta):
headers: dict = {},
**kwargs
):
"""
执行请求
"""
self._runtime_info['account_id'] = "-1"
url = self.prefix + path

View File

@@ -1,3 +1,5 @@
# 实例 识别码 控制
import os
import uuid
import json

View File

@@ -7,6 +7,7 @@ from ..provider import entities as llm_entities
from . import entities, operator, errors
from ..config import manager as cfg_mgr
# 引入所有算子以便注册
from .operators import func, plugin, default, reset, list as list_cmd, last, next, delc, resend, prompt, cmd, help, version, update
@@ -17,6 +18,9 @@ class CommandManager:
ap: app.Application
cmd_list: list[operator.CommandOperator]
"""
运行时命令列表,扁平存储,各个对象包含对应的子节点引用
"""
def __init__(self, ap: app.Application):
self.ap = ap
@@ -60,7 +64,7 @@ class CommandManager:
"""
found = False
if len(context.crt_params) > 0:
if len(context.crt_params) > 0: # 查找下一个参数是否对应此节点的某个子节点名
for oper in operator_list:
if (context.crt_params[0] == oper.name \
or context.crt_params[0] in oper.alias) \
@@ -78,7 +82,7 @@ class CommandManager:
yield ret
break
if not found:
if not found: # 如果下一个参数未在此节点的子节点中找到,则执行此节点或者报错
if operator is None:
yield entities.CommandReturn(
error=errors.CommandNotFoundError(context.crt_params[0])

View File

@@ -10,6 +10,8 @@ from . import errors, operator
class CommandReturn(pydantic.BaseModel):
"""命令返回值
"""
text: typing.Optional[str]
"""文本
@@ -18,25 +20,52 @@ class CommandReturn(pydantic.BaseModel):
image: typing.Optional[mirai.Image]
error: typing.Optional[errors.CommandError]= None
"""错误
"""
class Config:
arbitrary_types_allowed = True
class ExecuteContext(pydantic.BaseModel):
"""单次命令执行上下文
"""
query: core_entities.Query
"""本次消息的请求对象"""
session: core_entities.Session
"""本次消息所属的会话对象"""
command_text: str
"""命令完整文本"""
command: str
"""命令名称"""
crt_command: str
"""当前命令
多级命令中crt_command为当前命令command为根命令。
例如:!plugin on Webwlkr
处理到plugin时command为plugincrt_command为plugin
处理到on时command为plugincrt_command为on
"""
params: list[str]
"""命令参数
整个命令以空格分割后的参数列表
"""
crt_params: list[str]
"""当前命令参数
多级命令中crt_params为当前命令参数params为根命令参数。
例如:!plugin on Webwlkr
处理到plugin时params为['on', 'Webwlkr']crt_params为['on', 'Webwlkr']
处理到on时params为['on', 'Webwlkr']crt_params为['Webwlkr']
"""
privilege: int
"""发起人权限"""

View File

@@ -8,17 +8,34 @@ from . import entities
preregistered_operators: list[typing.Type[CommandOperator]] = []
"""预注册命令算子列表。在初始化时,所有算子类会被注册到此列表中。"""
def operator_class(
name: str,
help: str,
help: str = "",
usage: str = None,
alias: list[str] = [],
privilege: int=1, # 1为普通用户2为管理员
parent_class: typing.Type[CommandOperator] = None
) -> typing.Callable[[typing.Type[CommandOperator]], typing.Type[CommandOperator]]:
"""命令类装饰器
Args:
name (str): 名称
help (str, optional): 帮助信息. Defaults to "".
usage (str, optional): 使用说明. Defaults to None.
alias (list[str], optional): 别名. Defaults to [].
privilege (int, optional): 权限1为普通用户可用2为仅管理员可用. Defaults to 1.
parent_class (typing.Type[CommandOperator], optional): 父节点若为None则为顶级命令. Defaults to None.
Returns:
typing.Callable[[typing.Type[CommandOperator]], typing.Type[CommandOperator]]: 装饰器
"""
def decorator(cls: typing.Type[CommandOperator]) -> typing.Type[CommandOperator]:
assert issubclass(cls, CommandOperator)
cls.name = name
cls.alias = alias
cls.help = help
@@ -34,7 +51,12 @@ def operator_class(
class CommandOperator(metaclass=abc.ABCMeta):
"""命令算子
"""命令算子抽象类
以下的参数均不需要在子类中设置,只需要在使用装饰器注册类时作为参数传递即可。
命令支持级联,即一个命令可以有多个子命令,子命令可以有子命令,以此类推。
处理命令时,若有子命令,会以当前参数列表的第一个参数去匹配子命令,若匹配成功,则转移到子命令中执行。
若没有匹配成功或没有子命令,则执行当前命令。
"""
ap: app.Application
@@ -43,7 +65,8 @@ class CommandOperator(metaclass=abc.ABCMeta):
"""名称,搜索到时若符合则使用"""
path: str
"""路径所有父节点的name的连接用于定义命令权限"""
"""路径所有父节点的name的连接用于定义命令权限,由管理器在初始化时自动设置。
"""
alias: list[str]
"""同name"""
@@ -52,8 +75,9 @@ class CommandOperator(metaclass=abc.ABCMeta):
"""此节点的帮助信息"""
usage: str = None
"""用法"""
parent_class: typing.Type[CommandOperator] | None = None
parent_class: typing.Union[typing.Type[CommandOperator], None] = None
"""父节点类。标记以供管理器在初始化时编织父子关系。"""
lowest_privilege: int = 0
@@ -75,4 +99,15 @@ class CommandOperator(metaclass=abc.ABCMeta):
self,
context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
"""实现此方法以执行命令
支持多次yield以返回多个结果。
例如:一个安装插件的命令,可能会有下载、解压、安装等多个步骤,每个步骤都可以返回一个结果。
Args:
context (entities.ExecuteContext): 命令执行上下文
Yields:
entities.CommandReturn: 命令返回封装
"""
pass

View File

@@ -8,40 +8,47 @@ from .. import model as file_model
class JSONConfigFile(file_model.ConfigFile):
"""JSON配置文件"""
config_file_name: str = None
"""配置文件名"""
template_file_name: str = None
"""模板文件名"""
def __init__(self, config_file_name: str, template_file_name: str) -> None:
def __init__(
self, config_file_name: str, template_file_name: str = None, template_data: dict = None
) -> None:
self.config_file_name = config_file_name
self.template_file_name = template_file_name
self.template_data = template_data
def exists(self) -> bool:
return os.path.exists(self.config_file_name)
async def create(self):
if self.template_file_name is not None:
shutil.copyfile(self.template_file_name, self.config_file_name)
elif self.template_data is not None:
with open(self.config_file_name, "w", encoding="utf-8") as f:
json.dump(self.template_data, f, indent=4, ensure_ascii=False)
else:
raise ValueError("template_file_name or template_data must be provided")
async def load(self) -> dict:
if not self.exists():
await self.create()
with open(self.config_file_name, 'r', encoding='utf-8') as f:
if self.template_file_name is not None:
with open(self.template_file_name, "r", encoding="utf-8") as f:
self.template_data = json.load(f)
with open(self.config_file_name, "r", encoding="utf-8") as f:
cfg = json.load(f)
# 从模板文件中进行补全
with open(self.template_file_name, 'r', encoding='utf-8') as f:
template_cfg = json.load(f)
for key in template_cfg:
for key in self.template_data:
if key not in cfg:
cfg[key] = template_cfg[key]
cfg[key] = self.template_data[key]
return cfg
async def save(self, cfg: dict):
with open(self.config_file_name, 'w', encoding='utf-8') as f:
with open(self.config_file_name, "w", encoding="utf-8") as f:
json.dump(cfg, f, indent=4, ensure_ascii=False)
def save_sync(self, cfg: dict):
with open(self.config_file_name, "w", encoding="utf-8") as f:
json.dump(cfg, f, indent=4, ensure_ascii=False)

View File

@@ -60,3 +60,6 @@ class PythonModuleConfigFile(file_model.ConfigFile):
async def save(self, data: dict):
logging.warning('Python模块配置文件不支持保存')
def save_sync(self, data: dict):
logging.warning('Python模块配置文件不支持保存')

View File

@@ -26,6 +26,9 @@ class ConfigManager:
async def dump_config(self):
await self.file.save(self.data)
def dump_config_sync(self):
self.file.save_sync(self.data)
async def load_python_module_config(config_name: str, template_name: str) -> ConfigManager:
"""加载Python模块配置文件"""
@@ -40,11 +43,12 @@ async def load_python_module_config(config_name: str, template_name: str) -> Con
return cfg_mgr
async def load_json_config(config_name: str, template_name: str) -> ConfigManager:
async def load_json_config(config_name: str, template_name: str=None, template_data: dict=None) -> ConfigManager:
"""加载JSON配置文件"""
cfg_inst = json_file.JSONConfigFile(
config_name,
template_name
template_name,
template_data
)
cfg_mgr = ConfigManager(cfg_inst)

47
pkg/config/migration.py Normal file
View File

@@ -0,0 +1,47 @@
from __future__ import annotations
import abc
import typing
from ..core import app
preregistered_migrations: list[typing.Type[Migration]] = []
"""当前阶段暂不支持扩展"""
def migration_class(name: str, number: int):
"""注册一个迁移
"""
def decorator(cls: typing.Type[Migration]) -> typing.Type[Migration]:
cls.name = name
cls.number = number
preregistered_migrations.append(cls)
return cls
return decorator
class Migration(abc.ABC):
"""一个版本的迁移
"""
name: str
number: int
ap: app.Application
def __init__(self, ap: app.Application):
self.ap = ap
@abc.abstractmethod
async def need_migrate(self) -> bool:
"""判断当前环境是否需要运行此迁移
"""
pass
@abc.abstractmethod
async def run(self):
"""执行迁移
"""
pass

View File

@@ -0,0 +1,26 @@
from __future__ import annotations
import os
import sys
from .. import migration
@migration.migration_class("sensitive-word-migration", 1)
class SensitiveWordMigration(migration.Migration):
"""敏感词迁移
"""
async def need_migrate(self) -> bool:
"""判断当前环境是否需要运行此迁移
"""
return os.path.exists("data/config/sensitive-words.json") and not os.path.exists("data/metadata/sensitive-words.json")
async def run(self):
"""执行迁移
"""
# 移动文件
os.rename("data/config/sensitive-words.json", "data/metadata/sensitive-words.json")
# 重新加载配置
await self.ap.sensitive_meta.load_config()

View File

@@ -0,0 +1,47 @@
from __future__ import annotations
from .. import migration
@migration.migration_class("openai-config-migration", 2)
class OpenAIConfigMigration(migration.Migration):
"""OpenAI配置迁移
"""
async def need_migrate(self) -> bool:
"""判断当前环境是否需要运行此迁移
"""
return 'openai-config' in self.ap.provider_cfg.data
async def run(self):
"""执行迁移
"""
old_openai_config = self.ap.provider_cfg.data['openai-config'].copy()
if 'keys' not in self.ap.provider_cfg.data:
self.ap.provider_cfg.data['keys'] = {}
if 'openai' not in self.ap.provider_cfg.data['keys']:
self.ap.provider_cfg.data['keys']['openai'] = []
self.ap.provider_cfg.data['keys']['openai'] = old_openai_config['api-keys']
self.ap.provider_cfg.data['model'] = old_openai_config['chat-completions-params']['model']
del old_openai_config['chat-completions-params']['model']
if 'requester' not in self.ap.provider_cfg.data:
self.ap.provider_cfg.data['requester'] = {}
if 'openai-chat-completions' not in self.ap.provider_cfg.data['requester']:
self.ap.provider_cfg.data['requester']['openai-chat-completions'] = {}
self.ap.provider_cfg.data['requester']['openai-chat-completions'] = {
'base-url': old_openai_config['base_url'],
'args': old_openai_config['chat-completions-params'],
'timeout': old_openai_config['request-timeout'],
}
del self.ap.provider_cfg.data['openai-config']
await self.ap.provider_cfg.dump_config()

View File

@@ -0,0 +1,32 @@
from __future__ import annotations
from .. import migration
@migration.migration_class("anthropic-requester-config-completion", 3)
class AnthropicRequesterConfigCompletionMigration(migration.Migration):
"""OpenAI配置迁移
"""
async def need_migrate(self) -> bool:
"""判断当前环境是否需要运行此迁移
"""
return 'anthropic-messages' not in self.ap.provider_cfg.data['requester'] \
or 'anthropic' not in self.ap.provider_cfg.data['keys']
async def run(self):
"""执行迁移
"""
if 'anthropic-messages' not in self.ap.provider_cfg.data['requester']:
self.ap.provider_cfg.data['requester']['anthropic-messages'] = {
'base-url': 'https://api.anthropic.com',
'args': {
'max_tokens': 1024
},
'timeout': 120,
}
if 'anthropic' not in self.ap.provider_cfg.data['keys']:
self.ap.provider_cfg.data['keys']['anthropic'] = []
await self.ap.provider_cfg.dump_config()

View File

@@ -0,0 +1,30 @@
from __future__ import annotations
from .. import migration
@migration.migration_class("moonshot-config-completion", 4)
class MoonshotConfigCompletionMigration(migration.Migration):
"""OpenAI配置迁移
"""
async def need_migrate(self) -> bool:
"""判断当前环境是否需要运行此迁移
"""
return 'moonshot-chat-completions' not in self.ap.provider_cfg.data['requester'] \
or 'moonshot' not in self.ap.provider_cfg.data['keys']
async def run(self):
"""执行迁移
"""
if 'moonshot-chat-completions' not in self.ap.provider_cfg.data['requester']:
self.ap.provider_cfg.data['requester']['moonshot-chat-completions'] = {
'base-url': 'https://api.moonshot.cn/v1',
'args': {},
'timeout': 120,
}
if 'moonshot' not in self.ap.provider_cfg.data['keys']:
self.ap.provider_cfg.data['keys']['moonshot'] = []
await self.ap.provider_cfg.dump_config()

View File

@@ -10,6 +10,9 @@ class ConfigFile(metaclass=abc.ABCMeta):
template_file_name: str = None
"""模板文件名"""
template_data: dict = None
"""模板数据"""
@abc.abstractmethod
def exists(self) -> bool:
pass
@@ -25,3 +28,7 @@ class ConfigFile(metaclass=abc.ABCMeta):
@abc.abstractmethod
async def save(self, data: dict):
pass
@abc.abstractmethod
def save_sync(self, data: dict):
pass

View File

@@ -4,24 +4,24 @@ import logging
import asyncio
import traceback
import aioconsole
from ..platform import manager as im_mgr
from ..provider.session import sessionmgr as llm_session_mgr
from ..provider.requester import modelmgr as llm_model_mgr
from ..provider.modelmgr import modelmgr as llm_model_mgr
from ..provider.sysprompt import sysprompt as llm_prompt_mgr
from ..provider.tools import toolmgr as llm_tool_mgr
from ..config import manager as config_mgr
from ..audit.center import v2 as center_mgr
from ..command import cmdmgr
from ..plugin import manager as plugin_mgr
from . import pool, controller
from ..pipeline import stagemgr
from ..pipeline import pool
from ..pipeline import controller, stagemgr
from ..utils import version as version_mgr, proxy as proxy_mgr
class Application:
im_mgr: im_mgr.PlatformManager = None
"""运行时应用对象和上下文"""
platform_mgr: im_mgr.PlatformManager = None
cmd_mgr: cmdmgr.CommandManager = None
@@ -33,6 +33,8 @@ class Application:
tool_mgr: llm_tool_mgr.ToolManager = None
# ======= 配置管理器 =======
command_cfg: config_mgr.ConfigManager = None
pipeline_cfg: config_mgr.ConfigManager = None
@@ -43,6 +45,18 @@ class Application:
system_cfg: config_mgr.ConfigManager = None
# ======= 元数据配置管理器 =======
sensitive_meta: config_mgr.ConfigManager = None
adapter_qq_botpy_meta: config_mgr.ConfigManager = None
plugin_setting_meta: config_mgr.ConfigManager = None
llm_models_meta: config_mgr.ConfigManager = None
# =========================
ctr_mgr: center_mgr.V2CenterAPI = None
plugin_mgr: plugin_mgr.PluginManager = None
@@ -66,27 +80,18 @@ class Application:
pass
async def run(self):
await self.plugin_mgr.load_plugins()
await self.plugin_mgr.initialize_plugins()
tasks = []
try:
tasks = [
asyncio.create_task(self.im_mgr.run()),
asyncio.create_task(self.platform_mgr.run()),
asyncio.create_task(self.ctrl.run())
]
# async def interrupt(tasks):
# await asyncio.sleep(1.5)
# while await aioconsole.ainput("使用 ctrl+c 或 'exit' 退出程序 > ") != 'exit':
# pass
# for task in tasks:
# task.cancel()
# await interrupt(tasks)
# 挂信号处理
import signal

View File

@@ -1,143 +1,33 @@
from __future__ import print_function
import os
import sys
from .bootutils import files
from .bootutils import deps
from .bootutils import log
from .bootutils import config
from . import app
from . import pool
from . import controller
from ..pipeline import stagemgr
from ..audit import identifier
from ..provider.session import sessionmgr as llm_session_mgr
from ..provider.requester import modelmgr as llm_model_mgr
from ..provider.sysprompt import sysprompt as llm_prompt_mgr
from ..provider.tools import toolmgr as llm_tool_mgr
from ..platform import manager as im_mgr
from ..command import cmdmgr
from ..plugin import manager as plugin_mgr
from ..audit.center import v2 as center_v2
from ..utils import version, proxy, announce
from . import stage
use_override = False
# 引入启动阶段实现以便注册
from .stages import load_config, setup_logger, build_app, migrate
stage_order = [
"LoadConfigStage",
"MigrationStage",
"SetupLoggerStage",
"BuildAppStage"
]
async def make_app() -> app.Application:
global use_override
generated_files = await files.generate_files()
if generated_files:
print("以下文件不存在,已自动生成,请按需修改配置文件后重启:")
for file in generated_files:
print("-", file)
sys.exit(0)
missing_deps = await deps.check_deps()
if missing_deps:
print("以下依赖包未安装,将自动安装,请完成后重启程序:")
for dep in missing_deps:
print("-", dep)
await deps.install_deps(missing_deps)
sys.exit(0)
qcg_logger = await log.init_logging()
# 生成标识符
identifier.init()
# ========== 加载配置文件 ==========
command_cfg = await config.load_json_config("data/config/command.json", "templates/command.json")
pipeline_cfg = await config.load_json_config("data/config/pipeline.json", "templates/pipeline.json")
platform_cfg = await config.load_json_config("data/config/platform.json", "templates/platform.json")
provider_cfg = await config.load_json_config("data/config/provider.json", "templates/provider.json")
system_cfg = await config.load_json_config("data/config/system.json", "templates/system.json")
# ========== 构建应用实例 ==========
ap = app.Application()
ap.logger = qcg_logger
ap.command_cfg = command_cfg
ap.pipeline_cfg = pipeline_cfg
ap.platform_cfg = platform_cfg
ap.provider_cfg = provider_cfg
ap.system_cfg = system_cfg
proxy_mgr = proxy.ProxyManager(ap)
await proxy_mgr.initialize()
ap.proxy_mgr = proxy_mgr
ver_mgr = version.VersionManager(ap)
await ver_mgr.initialize()
ap.ver_mgr = ver_mgr
center_v2_api = center_v2.V2CenterAPI(
ap,
basic_info={
"host_id": identifier.identifier["host_id"],
"instance_id": identifier.identifier["instance_id"],
"semantic_version": ver_mgr.get_current_version(),
"platform": sys.platform,
},
runtime_info={
"admin_id": "{}".format(system_cfg.data["admin-sessions"]),
"msg_source": str([
adapter_cfg['adapter'] if 'adapter' in adapter_cfg else 'unknown'
for adapter_cfg in platform_cfg.data['platform-adapters'] if adapter_cfg['enable']
]),
},
)
ap.ctr_mgr = center_v2_api
# 发送公告
ann_mgr = announce.AnnouncementManager(ap)
await ann_mgr.show_announcements()
ap.query_pool = pool.QueryPool()
await ap.ver_mgr.show_version_update()
plugin_mgr_inst = plugin_mgr.PluginManager(ap)
await plugin_mgr_inst.initialize()
ap.plugin_mgr = plugin_mgr_inst
cmd_mgr_inst = cmdmgr.CommandManager(ap)
await cmd_mgr_inst.initialize()
ap.cmd_mgr = cmd_mgr_inst
llm_model_mgr_inst = llm_model_mgr.ModelManager(ap)
await llm_model_mgr_inst.initialize()
ap.model_mgr = llm_model_mgr_inst
llm_session_mgr_inst = llm_session_mgr.SessionManager(ap)
await llm_session_mgr_inst.initialize()
ap.sess_mgr = llm_session_mgr_inst
llm_prompt_mgr_inst = llm_prompt_mgr.PromptManager(ap)
await llm_prompt_mgr_inst.initialize()
ap.prompt_mgr = llm_prompt_mgr_inst
llm_tool_mgr_inst = llm_tool_mgr.ToolManager(ap)
await llm_tool_mgr_inst.initialize()
ap.tool_mgr = llm_tool_mgr_inst
im_mgr_inst = im_mgr.PlatformManager(ap=ap)
await im_mgr_inst.initialize()
ap.im_mgr = im_mgr_inst
stage_mgr = stagemgr.StageManager(ap)
await stage_mgr.initialize()
ap.stage_mgr = stage_mgr
ctrl = controller.Controller(ap)
ap.ctrl = ctrl
# 执行启动阶段
for stage_name in stage_order:
stage_cls = stage.preregistered_stages[stage_name]
stage_inst = stage_cls()
await stage_inst.run(ap)
await ap.initialize()

View File

@@ -3,11 +3,13 @@ import pip
required_deps = {
"requests": "requests",
"openai": "openai",
"anthropic": "anthropic",
"colorlog": "colorlog",
"mirai": "yiri-mirai-rc",
"aiocqhttp": "aiocqhttp",
"botpy": "qq-botpy",
"PIL": "pillow",
"nakuru": "nakuru-project-idk",
"CallingGPT": "CallingGPT",
"tiktoken": "tiktoken",
"yaml": "pyyaml",
"aiohttp": "aiohttp",

View File

@@ -13,13 +13,13 @@ required_files = {
"data/config/platform.json": "templates/platform.json",
"data/config/provider.json": "templates/provider.json",
"data/config/system.json": "templates/system.json",
"data/config/sensitive-words.json": "templates/sensitive-words.json",
"data/scenario/default.json": "templates/scenario-template.json",
}
required_paths = [
"temp",
"data",
"data/metadata",
"data/prompts",
"data/scenario",
"data/logs",

View File

@@ -16,6 +16,10 @@ log_colors_config = {
async def init_logging() -> logging.Logger:
# 删除所有现有的logger
for handler in logging.root.handlers[:]:
logging.root.removeHandler(handler)
level = logging.INFO
if "DEBUG" in os.environ and os.environ["DEBUG"] in ["true", "1"]:
@@ -46,7 +50,7 @@ async def init_logging() -> logging.Logger:
qcg_logger.debug("日志初始化完成,日志级别:%s" % level)
logging.basicConfig(
level=logging.INFO, # 设置日志输出格式
level=logging.CRITICAL, # 设置日志输出格式
format="[DEPR][%(asctime)s.%(msecs)03d] %(pathname)s (%(lineno)d) - [%(levelname)s] :\n%(message)s",
# 日志输出的格式
# -8表示占位符让输出左对齐输出长度都为8位

View File

@@ -9,13 +9,14 @@ import pydantic
import mirai
from ..provider import entities as llm_entities
from ..provider.requester import entities
from ..provider.modelmgr import entities
from ..provider.sysprompt import entities as sysprompt_entities
from ..provider.tools import entities as tools_entities
from ..platform import adapter as msadapter
class LauncherTypes(enum.Enum):
"""一个请求的发起者类型"""
PERSON = 'person'
"""私聊"""
@@ -31,43 +32,43 @@ class Query(pydantic.BaseModel):
"""请求ID添加进请求池时生成"""
launcher_type: LauncherTypes
"""会话类型platform设置"""
"""会话类型platform处理阶段设置"""
launcher_id: int
"""会话IDplatform设置"""
"""会话IDplatform处理阶段设置"""
sender_id: int
"""发送者IDplatform设置"""
"""发送者IDplatform处理阶段设置"""
message_event: mirai.MessageEvent
"""事件platform收到的事件"""
"""事件platform收到的原始事件"""
message_chain: mirai.MessageChain
"""消息链platform收到的消息链"""
"""消息链platform收到的原始消息链"""
adapter: msadapter.MessageSourceAdapter
"""适配器对象"""
"""消息平台适配器对象单个app中可能启用了多个消息平台适配器此对象表明发起此query的适配器"""
session: typing.Optional[Session] = None
"""会话对象,由前置处理器设置"""
"""会话对象,由前置处理器阶段设置"""
messages: typing.Optional[list[llm_entities.Message]] = []
"""历史消息列表,由前置处理器设置"""
"""历史消息列表,由前置处理器阶段设置"""
prompt: typing.Optional[sysprompt_entities.Prompt] = None
"""情景预设内容,由前置处理器设置"""
"""情景预设内容,由前置处理器阶段设置"""
user_message: typing.Optional[llm_entities.Message] = None
"""此次请求的用户消息对象,由前置处理器设置"""
"""此次请求的用户消息对象,由前置处理器阶段设置"""
use_model: typing.Optional[entities.LLMModelInfo] = None
"""使用的模型,由前置处理器设置"""
"""使用的模型,由前置处理器阶段设置"""
use_funcs: typing.Optional[list[tools_entities.LLMFunction]] = None
"""使用的函数,由前置处理器设置"""
"""使用的函数,由前置处理器阶段设置"""
resp_messages: typing.Optional[list[llm_entities.Message]] = []
"""provider生成的回复消息对象列表"""
"""Process阶段生成的回复消息对象列表"""
resp_message_chain: typing.Optional[mirai.MessageChain] = None
"""回复消息链从resp_messages包装而得"""
@@ -77,7 +78,7 @@ class Query(pydantic.BaseModel):
class Conversation(pydantic.BaseModel):
"""对话"""
"""对话,包含于 Session 中,一个 Session 可以有多个历史 Conversation但只有一个当前使用的 Conversation"""
prompt: sysprompt_entities.Prompt
@@ -93,7 +94,7 @@ class Conversation(pydantic.BaseModel):
class Session(pydantic.BaseModel):
"""会话"""
"""会话,一个 Session 对应一个 {launcher_type.value}_{launcher_id}"""
launcher_type: LauncherTypes
launcher_id: int
@@ -111,6 +112,7 @@ class Session(pydantic.BaseModel):
update_time: typing.Optional[datetime.datetime] = pydantic.Field(default_factory=datetime.datetime.now)
semaphore: typing.Optional[asyncio.Semaphore] = None
"""当前会话的信号量,用于限制并发"""
class Config:
arbitrary_types_allowed = True

34
pkg/core/stage.py Normal file
View File

@@ -0,0 +1,34 @@
from __future__ import annotations
import abc
import typing
from . import app
preregistered_stages: dict[str, typing.Type[BootingStage]] = {}
"""预注册的请求处理阶段。在初始化时,所有请求处理阶段类会被注册到此字典中。
当前阶段暂不支持扩展
"""
def stage_class(
name: str
):
def decorator(cls: typing.Type[BootingStage]) -> typing.Type[BootingStage]:
preregistered_stages[name] = cls
return cls
return decorator
class BootingStage(abc.ABC):
"""启动阶段
"""
name: str = None
@abc.abstractmethod
async def run(self, ap: app.Application):
"""启动
"""
pass

View File

@@ -0,0 +1,96 @@
from __future__ import annotations
import sys
from .. import stage, app
from ...utils import version, proxy, announce, platform
from ...audit.center import v2 as center_v2
from ...audit import identifier
from ...pipeline import pool, controller, stagemgr
from ...plugin import manager as plugin_mgr
from ...command import cmdmgr
from ...provider.session import sessionmgr as llm_session_mgr
from ...provider.modelmgr import modelmgr as llm_model_mgr
from ...provider.sysprompt import sysprompt as llm_prompt_mgr
from ...provider.tools import toolmgr as llm_tool_mgr
from ...platform import manager as im_mgr
@stage.stage_class("BuildAppStage")
class BuildAppStage(stage.BootingStage):
"""构建应用阶段
"""
async def run(self, ap: app.Application):
"""构建app对象的各个组件对象并初始化
"""
proxy_mgr = proxy.ProxyManager(ap)
await proxy_mgr.initialize()
ap.proxy_mgr = proxy_mgr
ver_mgr = version.VersionManager(ap)
await ver_mgr.initialize()
ap.ver_mgr = ver_mgr
center_v2_api = center_v2.V2CenterAPI(
ap,
basic_info={
"host_id": identifier.identifier["host_id"],
"instance_id": identifier.identifier["instance_id"],
"semantic_version": ver_mgr.get_current_version(),
"platform": platform.get_platform(),
},
runtime_info={
"admin_id": "{}".format(ap.system_cfg.data["admin-sessions"]),
"msg_source": str([
adapter_cfg['adapter'] if 'adapter' in adapter_cfg else 'unknown'
for adapter_cfg in ap.platform_cfg.data['platform-adapters'] if adapter_cfg['enable']
]),
},
)
ap.ctr_mgr = center_v2_api
# 发送公告
ann_mgr = announce.AnnouncementManager(ap)
await ann_mgr.show_announcements()
ap.query_pool = pool.QueryPool()
await ap.ver_mgr.show_version_update()
plugin_mgr_inst = plugin_mgr.PluginManager(ap)
await plugin_mgr_inst.initialize()
ap.plugin_mgr = plugin_mgr_inst
await plugin_mgr_inst.load_plugins()
cmd_mgr_inst = cmdmgr.CommandManager(ap)
await cmd_mgr_inst.initialize()
ap.cmd_mgr = cmd_mgr_inst
llm_model_mgr_inst = llm_model_mgr.ModelManager(ap)
await llm_model_mgr_inst.initialize()
ap.model_mgr = llm_model_mgr_inst
llm_session_mgr_inst = llm_session_mgr.SessionManager(ap)
await llm_session_mgr_inst.initialize()
ap.sess_mgr = llm_session_mgr_inst
llm_prompt_mgr_inst = llm_prompt_mgr.PromptManager(ap)
await llm_prompt_mgr_inst.initialize()
ap.prompt_mgr = llm_prompt_mgr_inst
llm_tool_mgr_inst = llm_tool_mgr.ToolManager(ap)
await llm_tool_mgr_inst.initialize()
ap.tool_mgr = llm_tool_mgr_inst
im_mgr_inst = im_mgr.PlatformManager(ap=ap)
await im_mgr_inst.initialize()
ap.platform_mgr = im_mgr_inst
stage_mgr = stagemgr.StageManager(ap)
await stage_mgr.initialize()
ap.stage_mgr = stage_mgr
ctrl = controller.Controller(ap)
ap.ctrl = ctrl

View File

@@ -0,0 +1,31 @@
from __future__ import annotations
from .. import stage, app
from ..bootutils import config
@stage.stage_class("LoadConfigStage")
class LoadConfigStage(stage.BootingStage):
"""加载配置文件阶段
"""
async def run(self, ap: app.Application):
"""启动
"""
ap.command_cfg = await config.load_json_config("data/config/command.json", "templates/command.json")
ap.pipeline_cfg = await config.load_json_config("data/config/pipeline.json", "templates/pipeline.json")
ap.platform_cfg = await config.load_json_config("data/config/platform.json", "templates/platform.json")
ap.provider_cfg = await config.load_json_config("data/config/provider.json", "templates/provider.json")
ap.system_cfg = await config.load_json_config("data/config/system.json", "templates/system.json")
ap.plugin_setting_meta = await config.load_json_config("plugins/plugins.json", "templates/plugin-settings.json")
await ap.plugin_setting_meta.dump_config()
ap.sensitive_meta = await config.load_json_config("data/metadata/sensitive-words.json", "templates/metadata/sensitive-words.json")
await ap.sensitive_meta.dump_config()
ap.adapter_qq_botpy_meta = await config.load_json_config("data/metadata/adapter-qq-botpy.json", "templates/metadata/adapter-qq-botpy.json")
await ap.adapter_qq_botpy_meta.dump_config()
ap.llm_models_meta = await config.load_json_config("data/metadata/llm-models.json", "templates/metadata/llm-models.json")
await ap.llm_models_meta.dump_config()

View File

@@ -0,0 +1,28 @@
from __future__ import annotations
import importlib
from .. import stage, app
from ...config import migration
from ...config.migrations import m001_sensitive_word_migration, m002_openai_config_migration, m003_anthropic_requester_cfg_completion, m004_moonshot_cfg_completion
@stage.stage_class("MigrationStage")
class MigrationStage(stage.BootingStage):
"""迁移阶段
"""
async def run(self, ap: app.Application):
"""启动
"""
migrations = migration.preregistered_migrations
# 按照迁移号排序
migrations.sort(key=lambda x: x.number)
for migration_cls in migrations:
migration_instance = migration_cls(ap)
if await migration_instance.need_migrate():
await migration_instance.run()

View File

@@ -0,0 +1,15 @@
from __future__ import annotations
from .. import stage, app
from ..bootutils import log
@stage.stage_class("SetupLoggerStage")
class SetupLoggerStage(stage.BootingStage):
"""设置日志器阶段
"""
async def run(self, ap: app.Application):
"""启动
"""
ap.logger = await log.init_logging()

View File

@@ -8,6 +8,7 @@ from ...config import manager as cfg_mgr
@stage.stage_class('BanSessionCheckStage')
class BanSessionCheckStage(stage.PipelineStage):
"""访问控制处理阶段"""
async def initialize(self):
pass
@@ -24,22 +25,24 @@ class BanSessionCheckStage(stage.PipelineStage):
sess_list = self.ap.pipeline_cfg.data['access-control'][mode]
if (query.launcher_type == 'group' and 'group_*' in sess_list) \
or (query.launcher_type == 'person' and 'person_*' in sess_list):
if (query.launcher_type.value == 'group' and 'group_*' in sess_list) \
or (query.launcher_type.value == 'person' and 'person_*' in sess_list):
found = True
else:
for sess in sess_list:
if sess == f"{query.launcher_type}_{query.launcher_id}":
if sess == f"{query.launcher_type.value}_{query.launcher_id}":
found = True
break
result = False
ctn = False
if mode == 'blacklist':
result = found
if mode == 'whitelist':
ctn = found
else:
ctn = not found
return entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE if not result else entities.ResultType.INTERRUPT,
result_type=entities.ResultType.CONTINUE if ctn else entities.ResultType.INTERRUPT,
new_query=query,
debug_notice=f'根据访问控制忽略消息: {query.launcher_type}_{query.launcher_id}' if result else ''
console_notice=f'根据访问控制忽略消息: {query.launcher_type.value}_{query.launcher_id}' if not ctn else ''
)

View File

@@ -7,28 +7,38 @@ from ...core import app
from .. import stage, entities, stagemgr
from ...core import entities as core_entities
from ...config import manager as cfg_mgr
from . import filter, entities as filter_entities
from . import filter as filter_model, entities as filter_entities
from .filters import cntignore, banwords, baiduexamine
@stage.stage_class('PostContentFilterStage')
@stage.stage_class('PreContentFilterStage')
class ContentFilterStage(stage.PipelineStage):
"""内容过滤阶段"""
filter_chain: list[filter.ContentFilter]
filter_chain: list[filter_model.ContentFilter]
def __init__(self, ap: app.Application):
self.filter_chain = []
super().__init__(ap)
async def initialize(self):
self.filter_chain.append(cntignore.ContentIgnore(self.ap))
filters_required = [
"content-ignore",
]
if self.ap.pipeline_cfg.data['check-sensitive-words']:
self.filter_chain.append(banwords.BanWordFilter(self.ap))
filters_required.append("ban-word-filter")
if self.ap.pipeline_cfg.data['baidu-cloud-examine']['enable']:
self.filter_chain.append(baiduexamine.BaiduCloudExamine(self.ap))
filters_required.append("baidu-cloud-examine")
for filter in filter_model.preregistered_filters:
if filter.name in filters_required:
self.filter_chain.append(
filter(self.ap)
)
for filter in self.filter_chain:
await filter.initialize()
@@ -125,9 +135,17 @@ class ContentFilterStage(stage.PipelineStage):
query
)
elif stage_inst_name == 'PostContentFilterStage':
# 仅处理 query.resp_messages[-1].content 是 str 的情况
if isinstance(query.resp_messages[-1].content, str):
return await self._post_process(
query.resp_messages[-1].content,
query
)
else:
self.ap.logger.debug(f"resp_messages[-1] 不是 str 类型,跳过内容过滤器检查。")
return entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE,
new_query=query
)
else:
raise ValueError(f'未知的 stage_inst_name: {stage_inst_name}')

View File

@@ -31,15 +31,24 @@ class EnableStage(enum.Enum):
class FilterResult(pydantic.BaseModel):
level: ResultLevel
"""结果等级
对于前置处理阶段,只要有任意一个返回 非PASS 的内容过滤器结果,就会中断处理。
对于后置处理阶段,当且内容过滤器返回 BLOCK 时,会中断处理。
"""
replacement: str
"""替换后的消息"""
"""替换后的消息
内容过滤器可以进行一些遮掩处理,然后把遮掩后的消息返回。
若没有修改内容,也需要返回原消息。
"""
user_notice: str
"""不通过时,用户提示消息"""
"""不通过时,若此值不为空,将对用户提示消息"""
console_notice: str
"""不通过时,控制台提示消息"""
"""不通过时,若此值不为空,将在控制台提示消息"""
class ManagerResultLevel(enum.Enum):

View File

@@ -1,12 +1,42 @@
# 内容过滤器的抽象类
from __future__ import annotations
import abc
import typing
from ...core import app
from . import entities
preregistered_filters: list[typing.Type[ContentFilter]] = []
def filter_class(
name: str
) -> typing.Callable[[typing.Type[ContentFilter]], typing.Type[ContentFilter]]:
"""内容过滤器类装饰器
Args:
name (str): 过滤器名称
Returns:
typing.Callable[[typing.Type[ContentFilter]], typing.Type[ContentFilter]]: 装饰器
"""
def decorator(cls: typing.Type[ContentFilter]) -> typing.Type[ContentFilter]:
assert issubclass(cls, ContentFilter)
cls.name = name
preregistered_filters.append(cls)
return cls
return decorator
class ContentFilter(metaclass=abc.ABCMeta):
"""内容过滤器抽象类"""
name: str
ap: app.Application
@@ -16,6 +46,11 @@ class ContentFilter(metaclass=abc.ABCMeta):
@property
def enable_stages(self):
"""启用的阶段
默认为消息请求AI前后的两个阶段。
entity.EnableStage.PRE: 消息请求AI前此时需要检查的内容是用户的输入消息。
entity.EnableStage.POST: 消息请求AI后此时需要检查的内容是AI的回复消息。
"""
return [
entities.EnableStage.PRE,
@@ -30,5 +65,14 @@ class ContentFilter(metaclass=abc.ABCMeta):
@abc.abstractmethod
async def process(self, message: str) -> entities.FilterResult:
"""处理消息
分为前后阶段,具体取决于 enable_stages 的值。
对于内容过滤器来说,不需要考虑消息所处的阶段,只需要检查消息内容即可。
Args:
message (str): 需要检查的内容
Returns:
entities.FilterResult: 过滤结果,具体内容请查看 entities.FilterResult 类的文档
"""
raise NotImplementedError

View File

@@ -10,6 +10,7 @@ BAIDU_EXAMINE_URL = "https://aip.baidubce.com/rest/2.0/solution/v1/text_censor/v
BAIDU_EXAMINE_TOKEN_URL = "https://aip.baidubce.com/oauth/2.0/token"
@filter_model.filter_class("baidu-cloud-examine")
class BaiduCloudExamine(filter_model.ContentFilter):
"""百度云内容审核"""

View File

@@ -6,34 +6,30 @@ from .. import entities
from ....config import manager as cfg_mgr
@filter_model.filter_class("ban-word-filter")
class BanWordFilter(filter_model.ContentFilter):
"""根据内容禁言"""
sensitive: cfg_mgr.ConfigManager
async def initialize(self):
self.sensitive = await cfg_mgr.load_json_config(
"data/config/sensitive-words.json",
"templates/sensitive-words.json"
)
pass
async def process(self, message: str) -> entities.FilterResult:
found = False
for word in self.sensitive.data['words']:
for word in self.ap.sensitive_meta.data['words']:
match = re.findall(word, message)
if len(match) > 0:
found = True
for i in range(len(match)):
if self.sensitive.data['mask_word'] == "":
if self.ap.sensitive_meta.data['mask_word'] == "":
message = message.replace(
match[i], self.sensitive.data['mask'] * len(match[i])
match[i], self.ap.sensitive_meta.data['mask'] * len(match[i])
)
else:
message = message.replace(
match[i], self.sensitive.data['mask_word']
match[i], self.ap.sensitive_meta.data['mask_word']
)
return entities.FilterResult(

View File

@@ -5,6 +5,7 @@ from .. import entities
from .. import filter as filter_model
@filter_model.filter_class("content-ignore")
class ContentIgnore(filter_model.ContentFilter):
"""根据内容忽略消息"""

View File

@@ -4,8 +4,8 @@ import asyncio
import typing
import traceback
from . import app, entities
from ..pipeline import entities as pipeline_entities
from ..core import app, entities
from . import entities as pipeline_entities
from ..plugin import events
@@ -68,7 +68,7 @@ class Controller:
"""检查输出
"""
if result.user_notice:
await self.ap.im_mgr.send(
await self.ap.platform_mgr.send(
query.message_event,
result.user_notice,
query.adapter
@@ -85,7 +85,7 @@ class Controller:
stage_index: int,
query: entities.Query,
):
"""从指定阶段开始执行
"""从指定阶段开始执行,实现了责任链模式和基于生成器的阶段分叉功能。
如何看懂这里为什么这么写
去问 GPT-4:

View File

@@ -15,6 +15,8 @@ from ...config import manager as cfg_mgr
@stage.stage_class("LongTextProcessStage")
class LongTextProcessStage(stage.PipelineStage):
"""长消息处理阶段
"""
strategy_impl: strategy.LongTextStrategy
@@ -44,15 +46,29 @@ class LongTextProcessStage(stage.PipelineStage):
self.ap.platform_cfg.data['long-text-process']['strategy'] = "forward"
if config['strategy'] == 'image':
self.strategy_impl = image.Text2ImageStrategy(self.ap)
elif config['strategy'] == 'forward':
self.strategy_impl = forward.ForwardComponentStrategy(self.ap)
for strategy_cls in strategy.preregistered_strategies:
if strategy_cls.name == config['strategy']:
self.strategy_impl = strategy_cls(self.ap)
break
else:
raise ValueError(f"未找到名为 {config['strategy']} 的长消息处理策略")
await self.strategy_impl.initialize()
async def process(self, query: core_entities.Query, stage_inst_name: str) -> entities.StageProcessResult:
if len(str(query.resp_message_chain)) > self.ap.platform_cfg.data['long-text-process']['threshold']:
# 检查是否包含非 Plain 组件
contains_non_plain = False
for msg in query.resp_message_chain:
if not isinstance(msg, Plain):
contains_non_plain = True
break
if contains_non_plain:
self.ap.logger.debug("消息中包含非 Plain 组件,跳过长消息处理。")
elif len(str(query.resp_message_chain)) > self.ap.platform_cfg.data['long-text-process']['threshold']:
query.resp_message_chain = MessageChain(await self.strategy_impl.process(str(query.resp_message_chain), query))
return entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE,
new_query=query

View File

@@ -36,6 +36,7 @@ class Forward(MessageComponent):
return '[聊天记录]'
@strategy_model.strategy_class("forward")
class ForwardComponentStrategy(strategy_model.LongTextStrategy):
async def process(self, message: str, query: core_entities.Query) -> list[MessageComponent]:

View File

@@ -15,6 +15,7 @@ from .. import strategy as strategy_model
from ....core import entities as core_entities
@strategy_model.strategy_class("image")
class Text2ImageStrategy(strategy_model.LongTextStrategy):
text_render_font: ImageFont.FreeTypeFont

View File

@@ -9,7 +9,39 @@ from ...core import app
from ...core import entities as core_entities
preregistered_strategies: list[typing.Type[LongTextStrategy]] = []
def strategy_class(
name: str
) -> typing.Callable[[typing.Type[LongTextStrategy]], typing.Type[LongTextStrategy]]:
"""长文本处理策略类装饰器
Args:
name (str): 策略名称
Returns:
typing.Callable[[typing.Type[LongTextStrategy]], typing.Type[LongTextStrategy]]: 装饰器
"""
def decorator(cls: typing.Type[LongTextStrategy]) -> typing.Type[LongTextStrategy]:
assert issubclass(cls, LongTextStrategy)
cls.name = name
preregistered_strategies.append(cls)
return cls
return decorator
class LongTextStrategy(metaclass=abc.ABCMeta):
"""长文本处理策略抽象类
"""
name: str
ap: app.Application
def __init__(self, ap: app.Application):
@@ -20,4 +52,15 @@ class LongTextStrategy(metaclass=abc.ABCMeta):
@abc.abstractmethod
async def process(self, message: str, query: core_entities.Query) -> list[MessageComponent]:
"""处理长文本
在 platform.json 中配置 long-text-process 字段,只要 文本长度超过了 threshold 就会调用此方法
Args:
message (str): 消息
query (core_entities.Query): 此次请求的上下文对象
Returns:
list[mirai.models.messages.MessageComponent]: 转换后的 YiriMirai 消息组件列表
"""
return []

View File

@@ -4,11 +4,12 @@ import asyncio
import mirai
from . import entities
from ..core import entities
from ..platform import adapter as msadapter
class QueryPool:
"""请求池请求获得调度进入pipeline之前保存在这里"""
query_id_counter: int = 0

View File

@@ -8,7 +8,7 @@ from ...plugin import events
@stage.stage_class("PreProcessor")
class PreProcessor(stage.PipelineStage):
"""预处理
"""请求预处理阶段
"""
async def process(
@@ -51,28 +51,6 @@ class PreProcessor(stage.PipelineStage):
query.prompt.messages = event_ctx.event.default_prompt
query.messages = event_ctx.event.prompt
# 根据模型max_tokens剪裁
max_tokens = min(query.use_model.max_tokens, self.ap.pipeline_cfg.data['submit-messages-tokens'])
test_messages = query.prompt.messages + query.messages + [query.user_message]
while await query.use_model.tokenizer.count_token(test_messages, query.use_model) > max_tokens:
# 前文都pop完了还是大于max_tokens由于prompt和user_messages不能删减报错
if len(query.prompt.messages) == 0:
return entities.StageProcessResult(
result_type=entities.ResultType.INTERRUPT,
new_query=query,
user_notice='输入内容过长,请减少情景预设或者输入内容长度',
console_notice='输入内容过长,请减少情景预设或者输入内容长度,或者增大配置文件中的 submit-messages-tokens 项但不能超过所用模型最大tokens数'
)
query.messages.pop(0) # pop第一个肯定是role=user的
# 继续pop到第二个role=user前一个
while len(query.messages) > 0 and query.messages[0].role != 'user':
query.messages.pop(0)
test_messages = query.prompt.messages + query.messages + [query.user_message]
return entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE,
new_query=query

View File

@@ -23,3 +23,12 @@ class MessageHandler(metaclass=abc.ABCMeta):
query: core_entities.Query,
) -> entities.StageProcessResult:
raise NotImplementedError
def cut_str(self, s: str) -> str:
"""
取字符串第一行最多20个字符若有多行或超过20个字符则加省略号
"""
s0 = s.split('\n')[0]
if len(s0) > 20 or '\n' in s:
s0 = s0[:20] + '...'
return s0

View File

@@ -21,8 +21,6 @@ class ChatMessageHandler(handler.MessageHandler):
) -> typing.AsyncGenerator[entities.StageProcessResult, None]:
"""处理
"""
# 取session
# 取conversation
# 调API
# 生成器
@@ -41,7 +39,14 @@ class ChatMessageHandler(handler.MessageHandler):
if event_ctx.is_prevented_default():
if event_ctx.event.reply is not None:
query.resp_message_chain = mirai.MessageChain(event_ctx.event.reply)
mc = mirai.MessageChain(event_ctx.event.reply)
query.resp_messages.append(
llm_entities.Message(
role='plugin',
content=mc,
)
)
yield entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE,
@@ -78,6 +83,8 @@ class ChatMessageHandler(handler.MessageHandler):
async for result in query.use_model.requester.request(query):
query.resp_messages.append(result)
self.ap.logger.info(f'对话({query.query_id})响应: {self.cut_str(result.readable_str())}')
if result.content is not None:
text_length += len(result.content)
@@ -86,6 +93,9 @@ class ChatMessageHandler(handler.MessageHandler):
new_query=query
)
except Exception as e:
self.ap.logger.error(f'对话({query.query_id})请求失败: {str(e)}')
yield entities.StageProcessResult(
result_type=entities.ResultType.INTERRUPT,
new_query=query,

View File

@@ -19,15 +19,16 @@ class CommandHandler(handler.MessageHandler):
"""处理
"""
event_class = events.PersonCommandSent if query.launcher_type == core_entities.LauncherTypes.PERSON else events.GroupCommandSent
command_text = str(query.message_chain).strip()[1:]
privilege = 1
if f'{query.launcher_type.value}_{query.launcher_id}' in self.ap.system_cfg.data['admin-sessions']:
privilege = 2
spt = str(query.message_chain).strip().split(' ')
spt = command_text.split(' ')
event_class = events.PersonCommandSent if query.launcher_type == core_entities.LauncherTypes.PERSON else events.GroupCommandSent
event_ctx = await self.ap.plugin_mgr.emit_event(
event=event_class(
@@ -73,8 +74,6 @@ class CommandHandler(handler.MessageHandler):
session = await self.ap.sess_mgr.get_session(query)
command_text = str(query.message_chain).strip()[1:]
async for ret in self.ap.cmd_mgr.execute(
command_text=command_text,
query=query,
@@ -91,6 +90,8 @@ class CommandHandler(handler.MessageHandler):
)
)
self.ap.logger.info(f'命令({query.query_id})报错: {self.cut_str(str(ret.error))}')
yield entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE,
new_query=query
@@ -106,6 +107,8 @@ class CommandHandler(handler.MessageHandler):
)
)
self.ap.logger.info(f'命令返回: {self.cut_str(ret.text)}')
yield entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE,
new_query=query

View File

@@ -11,6 +11,7 @@ from ...config import manager as cfg_mgr
@stage.stage_class("MessageProcessor")
class Processor(stage.PipelineStage):
"""请求实际处理阶段"""
cmd_handler: handler.MessageHandler
@@ -34,7 +35,12 @@ class Processor(stage.PipelineStage):
self.ap.logger.info(f"处理 {query.launcher_type.value}_{query.launcher_id} 的请求({query.query_id}): {message_text}")
async def generator():
if message_text.startswith('!') or message_text.startswith(''):
return self.cmd_handler.handle(query)
async for result in self.cmd_handler.handle(query):
yield result
else:
return self.chat_handler.handle(query)
async for result in self.chat_handler.handle(query):
yield result
return generator()

View File

@@ -1,10 +1,26 @@
from __future__ import annotations
import abc
import typing
from ...core import app
preregistered_algos: list[typing.Type[ReteLimitAlgo]] = []
def algo_class(name: str):
def decorator(cls: typing.Type[ReteLimitAlgo]) -> typing.Type[ReteLimitAlgo]:
cls.name = name
preregistered_algos.append(cls)
return cls
return decorator
class ReteLimitAlgo(metaclass=abc.ABCMeta):
"""限流算法抽象类"""
name: str = None
ap: app.Application
@@ -16,9 +32,27 @@ class ReteLimitAlgo(metaclass=abc.ABCMeta):
@abc.abstractmethod
async def require_access(self, launcher_type: str, launcher_id: int) -> bool:
"""进入处理流程
这个方法对等待是友好的,意味着算法可以实现在这里等待一段时间以控制速率。
Args:
launcher_type (str): 请求者类型 群聊为 group 私聊为 person
launcher_id (int): 请求者ID
Returns:
bool: 是否允许进入处理流程若返回false则直接丢弃该请求
"""
raise NotImplementedError
@abc.abstractmethod
async def release_access(self, launcher_type: str, launcher_id: int):
"""退出处理流程
Args:
launcher_type (str): 请求者类型 群聊为 group 私聊为 person
launcher_id (int): 请求者ID
"""
raise NotImplementedError

View File

@@ -19,6 +19,7 @@ class SessionContainer:
self.records = {}
@algo.algo_class("fixwin")
class FixedWindowAlgo(algo.ReteLimitAlgo):
containers_lock: asyncio.Lock

View File

@@ -11,11 +11,24 @@ from ...core import entities as core_entities
@stage.stage_class("RequireRateLimitOccupancy")
@stage.stage_class("ReleaseRateLimitOccupancy")
class RateLimit(stage.PipelineStage):
"""限速器控制阶段"""
algo: algo.ReteLimitAlgo
async def initialize(self):
self.algo = fixedwin.FixedWindowAlgo(self.ap)
algo_name = self.ap.pipeline_cfg.data['rate-limit']['algo']
algo_class = None
for algo_cls in algo.preregistered_algos:
if algo_cls.name == algo_name:
algo_class = algo_cls
break
else:
raise ValueError(f'未知的限速算法: {algo_name}')
self.algo = algo_class(self.ap)
await self.algo.initialize()
async def process(
@@ -46,7 +59,7 @@ class RateLimit(stage.PipelineStage):
)
elif stage_inst_name == "ReleaseRateLimitOccupancy":
await self.algo.release_access(
query.launcher_type,
query.launcher_type.value,
query.launcher_id,
)
return entities.StageProcessResult(

View File

@@ -29,7 +29,7 @@ class SendResponseBackStage(stage.PipelineStage):
await asyncio.sleep(random_delay)
await self.ap.im_mgr.send(
await self.ap.platform_mgr.send(
query.message_event,
query.resp_message_chain,
adapter=query.adapter

View File

@@ -21,15 +21,13 @@ class GroupRespondRuleCheckStage(stage.PipelineStage):
async def initialize(self):
"""初始化检查器
"""
self.rule_matchers = [
atbot.AtBotRule(self.ap),
prefix.PrefixRule(self.ap),
regexp.RegExpRule(self.ap),
random.RandomRespRule(self.ap),
]
for rule_matcher in self.rule_matchers:
await rule_matcher.initialize()
self.rule_matchers = []
for rule_matcher in rule.preregisetered_rules:
rule_inst = rule_matcher(self.ap)
await rule_inst.initialize()
self.rule_matchers.append(rule_inst)
async def process(self, query: core_entities.Query, stage_inst_name: str) -> entities.StageProcessResult:

View File

@@ -1,5 +1,6 @@
from __future__ import annotations
import abc
import typing
import mirai
@@ -7,9 +8,20 @@ from ...core import app, entities as core_entities
from . import entities
preregisetered_rules: list[typing.Type[GroupRespondRule]] = []
def rule_class(name: str):
def decorator(cls: typing.Type[GroupRespondRule]) -> typing.Type[GroupRespondRule]:
cls.name = name
preregisetered_rules.append(cls)
return cls
return decorator
class GroupRespondRule(metaclass=abc.ABCMeta):
"""群组响应规则的抽象类
"""
name: str
ap: app.Application

View File

@@ -7,6 +7,7 @@ from .. import entities
from ....core import entities as core_entities
@rule_model.rule_class("at-bot")
class AtBotRule(rule_model.GroupRespondRule):
async def match(
@@ -19,6 +20,10 @@ class AtBotRule(rule_model.GroupRespondRule):
if message_chain.has(mirai.At(query.adapter.bot_account_id)) and rule_dict['at']:
message_chain.remove(mirai.At(query.adapter.bot_account_id))
if message_chain.has(mirai.At(query.adapter.bot_account_id)): # 回复消息时会at两次检查并删除重复的
message_chain.remove(mirai.At(query.adapter.bot_account_id))
return entities.RuleJudgeResult(
matching=True,
replacement=message_chain,

View File

@@ -5,6 +5,7 @@ from .. import entities
from ....core import entities as core_entities
@rule_model.rule_class("prefix")
class PrefixRule(rule_model.GroupRespondRule):
async def match(

View File

@@ -7,6 +7,7 @@ from .. import entities
from ....core import entities as core_entities
@rule_model.rule_class("random")
class RandomRespRule(rule_model.GroupRespondRule):
async def match(

View File

@@ -7,6 +7,7 @@ from .. import entities
from ....core import entities as core_entities
@rule_model.rule_class("regexp")
class RegExpRule(rule_model.GroupRespondRule):
async def match(

View File

@@ -15,6 +15,7 @@ from .preproc import preproc
from .ratelimit import ratelimit
# 请求处理阶段顺序
stage_order = [
"GroupRespondRuleCheckStage",
"BanSessionCheckStage",

View File

@@ -29,6 +29,16 @@ class ResponseWrapper(stage.PipelineStage):
if query.resp_messages[-1].role == 'command':
query.resp_message_chain = mirai.MessageChain("[bot] "+query.resp_messages[-1].content)
yield entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE,
new_query=query
)
elif query.resp_messages[-1].role == 'plugin':
if not isinstance(query.resp_messages[-1].content, mirai.MessageChain):
query.resp_message_chain = mirai.MessageChain(query.resp_messages[-1].content)
else:
query.resp_message_chain = query.resp_messages[-1].content
yield entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE,
new_query=query

View File

@@ -14,6 +14,14 @@ preregistered_adapters: list[typing.Type[MessageSourceAdapter]] = []
def adapter_class(
name: str
):
"""消息平台适配器类装饰器
Args:
name (str): 适配器名称
Returns:
typing.Callable[[typing.Type[MessageSourceAdapter]], typing.Type[MessageSourceAdapter]]: 装饰器
"""
def decorator(cls: typing.Type[MessageSourceAdapter]) -> typing.Type[MessageSourceAdapter]:
cls.name = name
preregistered_adapters.append(cls)
@@ -22,15 +30,24 @@ def adapter_class(
class MessageSourceAdapter(metaclass=abc.ABCMeta):
"""消息平台适配器基类"""
name: str
bot_account_id: int
"""机器人账号ID需要在初始化时设置"""
config: dict
ap: app.Application
def __init__(self, config: dict, ap: app.Application):
"""初始化适配器
Args:
config (dict): 对应的配置
ap (app.Application): 应用上下文
"""
self.config = config
self.ap = ap
@@ -40,7 +57,7 @@ class MessageSourceAdapter(metaclass=abc.ABCMeta):
target_id: str,
message: mirai.MessageChain
):
"""发送消息
"""主动发送消息
Args:
target_type (str): 目标类型,`person`或`group`

View File

@@ -146,7 +146,7 @@ class PlatformManager:
if len(self.adapters) == 0:
self.ap.logger.warning('未运行平台适配器,请根据文档配置并启用平台适配器。')
async def send(self, event, msg, adapter: msadapter.MessageSourceAdapter, check_quote=True, check_at_sender=True):
async def send(self, event: mirai.MessageEvent, msg: mirai.MessageChain, adapter: msadapter.MessageSourceAdapter, check_quote=True, check_at_sender=True):
if check_at_sender and self.ap.platform_cfg.data['at-sender'] and isinstance(event, GroupMessage):
@@ -163,25 +163,6 @@ class PlatformManager:
quote_origin=True if self.ap.platform_cfg.data['quote-origin'] and check_quote else False
)
# 通知系统管理员
# TODO delete
# async def notify_admin(self, message: str):
# await self.notify_admin_message_chain(MessageChain([Plain("[bot]{}".format(message))]))
# async def notify_admin_message_chain(self, message: mirai.MessageChain):
# if self.ap.system_cfg.data['admin-sessions'] != []:
# admin_list = []
# for admin in self.ap.system_cfg.data['admin-sessions']:
# admin_list.append(admin)
# for adm in admin_list:
# self.adapter.send_message(
# adm.split("_")[0],
# adm.split("_")[1],
# message
# )
async def run(self):
try:
tasks = []

View File

@@ -40,7 +40,6 @@ class AiocqhttpMessageConverter(adapter.MessageConverter):
elif type(msg) is mirai.Voice:
msg_list.append(aiocqhttp.MessageSegment.record(msg.path))
elif type(msg) is forward.Forward:
# print("aiocqhttp 暂不支持转发消息组件的转换,使用普通消息链发送")
for node in msg.node_list:
msg_list.extend(AiocqhttpMessageConverter.yiri2target(node.message_chain)[0])
@@ -170,7 +169,7 @@ class AiocqhttpEventConverter(adapter.EventConverter):
name=event.sender["nickname"],
permission=mirai.models.entities.Permission.Member,
),
special_title=event.sender["title"],
special_title=event.sender["title"] if "title" in event.sender else "",
join_timestamp=0,
last_speak_timestamp=0,
mute_time_remaining=0,
@@ -216,13 +215,21 @@ class AiocqhttpAdapter(adapter.MessageSourceAdapter):
self.ap = ap
if "access-token" in config:
self.bot = aiocqhttp.CQHttp(access_token=config["access-token"])
del self.config["access-token"]
else:
self.bot = aiocqhttp.CQHttp()
async def send_message(
self, target_type: str, target_id: str, message: mirai.MessageChain
):
# TODO 实现发送消息
return super().send_message(target_type, target_id, message)
aiocq_msg = AiocqhttpMessageConverter.yiri2target(message)[0]
if target_type == "group":
await self.bot.send_group_msg(group_id=int(target_id), message=aiocq_msg)
elif target_type == "person":
await self.bot.send_private_msg(user_id=int(target_id), message=aiocq_msg)
async def reply_message(
self,
@@ -230,7 +237,6 @@ class AiocqhttpAdapter(adapter.MessageSourceAdapter):
message: mirai.MessageChain,
quote_origin: bool = False,
):
aiocq_event = AiocqhttpEventConverter.yiri2target(message_source, self.bot_account_id)
aiocq_msg = AiocqhttpMessageConverter.yiri2target(message)[0]
if quote_origin:

View File

@@ -24,6 +24,8 @@ class NakuruProjectMessageConverter(adapter_model.MessageConverter):
msg_list = message_chain.__root__
elif type(message_chain) is list:
msg_list = message_chain
elif type(message_chain) is str:
msg_list = [mirai.Plain(message_chain)]
else:
raise Exception("Unknown message type: " + str(message_chain) + str(type(message_chain)))

View File

@@ -17,6 +17,7 @@ import botpy.types.message as botpy_message_type
from .. import adapter as adapter_model
from ...pipeline.longtext.strategies import forward
from ...core import app
from ...config import manager as cfg_mgr
class OfficialGroupMessage(mirai.GroupMessage):
@@ -34,6 +35,7 @@ cached_message_ids = {}
id_index = 0
def save_msg_id(message_id: str) -> int:
"""保存消息id"""
global id_index, cached_message_ids
@@ -43,43 +45,82 @@ def save_msg_id(message_id: str) -> int:
cached_message_ids[str(crt_index)] = message_id
return crt_index
cached_member_openids = {}
"""QQ官方 用户的id是字符串而YiriMirai的用户id是整数所以需要一个索引来进行转换"""
member_openid_index = 100
def char_to_value(char):
"""将单个字符转换为相应的数值。"""
if '0' <= char <= '9':
return ord(char) - ord('0')
elif 'A' <= char <= 'Z':
return ord(char) - ord('A') + 10
def save_member_openid(member_openid: str) -> int:
"""保存用户id"""
global member_openid_index, cached_member_openids
return ord(char) - ord('a') + 36
if member_openid in cached_member_openids.values():
return list(cached_member_openids.keys())[list(cached_member_openids.values()).index(member_openid)]
def digest(s: str) -> int:
"""计算字符串的hash值。"""
# 取末尾的8位
sub_s = s[-10:]
crt_index = member_openid_index
member_openid_index += 1
cached_member_openids[str(crt_index)] = member_openid
return crt_index
number = 0
base = 36
cached_group_openids = {}
"""QQ官方 群组的id是字符串而YiriMirai的群组id是整数所以需要一个索引来进行转换"""
for i in range(len(sub_s)):
number = number * base + char_to_value(sub_s[i])
group_openid_index = 1000
return number
def save_group_openid(group_openid: str) -> int:
"""保存群组id"""
global group_openid_index, cached_group_openids
K = typing.TypeVar("K")
V = typing.TypeVar("V")
if group_openid in cached_group_openids.values():
return list(cached_group_openids.keys())[list(cached_group_openids.values()).index(group_openid)]
crt_index = group_openid_index
group_openid_index += 1
cached_group_openids[str(crt_index)] = group_openid
return crt_index
class OpenIDMapping(typing.Generic[K, V]):
map: dict[K, V]
dump_func: typing.Callable
digest_func: typing.Callable[[K], V]
def __init__(self, map: dict[K, V], dump_func: typing.Callable, digest_func: typing.Callable[[K], V] = digest):
self.map = map
self.dump_func = dump_func
self.digest_func = digest_func
def __getitem__(self, key: K) -> V:
return self.map[key]
def __setitem__(self, key: K, value: V):
self.map[key] = value
self.dump_func()
def __contains__(self, key: K) -> bool:
return key in self.map
def __delitem__(self, key: K):
del self.map[key]
self.dump_func()
def getkey(self, value: V) -> K:
return list(self.map.keys())[list(self.map.values()).index(value)]
def save_openid(self, key: K) -> V:
if key in self.map:
return self.map[key]
value = self.digest_func(key)
self.map[key] = value
self.dump_func()
return value
class OfficialMessageConverter(adapter_model.MessageConverter):
"""QQ 官方消息转换器"""
@staticmethod
def yiri2target(message_chain: mirai.MessageChain):
"""将 YiriMirai 的消息链转换为 QQ 官方消息"""
@@ -89,8 +130,12 @@ class OfficialMessageConverter(adapter_model.MessageConverter):
msg_list = message_chain.__root__
elif type(message_chain) is list:
msg_list = message_chain
elif type(message_chain) is str:
msg_list = [mirai.Plain(text=message_chain)]
else:
raise Exception("Unknown message type: " + str(message_chain) + str(type(message_chain)))
raise Exception(
"Unknown message type: " + str(message_chain) + str(type(message_chain))
)
offcial_messages: list[dict] = []
"""
@@ -108,36 +153,24 @@ class OfficialMessageConverter(adapter_model.MessageConverter):
# 遍历并转换
for component in msg_list:
if type(component) is mirai.Plain:
offcial_messages.append({
"type": "text",
"content": component.text
})
offcial_messages.append({"type": "text", "content": component.text})
elif type(component) is mirai.Image:
if component.url is not None:
offcial_messages.append(
{
"type": "image",
"content": component.url
}
)
offcial_messages.append({"type": "image", "content": component.url})
elif component.path is not None:
offcial_messages.append(
{
"type": "file_image",
"content": component.path
}
{"type": "file_image", "content": component.path}
)
elif type(component) is mirai.At:
offcial_messages.append(
{
"type": "at",
"content": ""
}
)
offcial_messages.append({"type": "at", "content": ""})
elif type(component) is mirai.AtAll:
print("上层组件要求发送 AtAll 消息,但 QQ 官方 API 不支持此消息类型,忽略此消息。")
print(
"上层组件要求发送 AtAll 消息,但 QQ 官方 API 不支持此消息类型,忽略此消息。"
)
elif type(component) is mirai.Voice:
print("上层组件要求发送 Voice 消息,但 QQ 官方 API 不支持此消息类型,忽略此消息。")
print(
"上层组件要求发送 Voice 消息,但 QQ 官方 API 不支持此消息类型,忽略此消息。"
)
elif type(component) is forward.Forward:
# 转发消息
yiri_forward_node_list = component.node_list
@@ -148,20 +181,31 @@ class OfficialMessageConverter(adapter_model.MessageConverter):
message_chain = yiri_forward_node.message_chain
# 平铺
offcial_messages.extend(OfficialMessageConverter.yiri2target(message_chain))
offcial_messages.extend(
OfficialMessageConverter.yiri2target(message_chain)
)
except Exception as e:
import traceback
traceback.print_exc()
return offcial_messages
@staticmethod
def extract_message_chain_from_obj(message: typing.Union[botpy_message.Message, botpy_message.DirectMessage], message_id: str = None, bot_account_id: int = 0) -> mirai.MessageChain:
def extract_message_chain_from_obj(
message: typing.Union[botpy_message.Message, botpy_message.DirectMessage],
message_id: str = None,
bot_account_id: int = 0,
) -> mirai.MessageChain:
yiri_msg_list = []
# 存id
yiri_msg_list.append(mirai.models.message.Source(id=save_msg_id(message_id), time=datetime.datetime.now()))
yiri_msg_list.append(
mirai.models.message.Source(
id=save_msg_id(message_id), time=datetime.datetime.now()
)
)
if type(message) is not botpy_message.DirectMessage:
yiri_msg_list.append(mirai.At(target=bot_account_id))
@@ -177,7 +221,9 @@ class OfficialMessageConverter(adapter_model.MessageConverter):
if attachment.content_type == "image":
yiri_msg_list.append(mirai.Image(url=attachment.url))
else:
logging.warning("不支持的附件类型:" + attachment.content_type + ",忽略此附件。")
logging.warning(
"不支持的附件类型:" + attachment.content_type + ",忽略此附件。"
)
content = re.sub(r"<@!\d+>", "", str(message.content))
if content.strip() != "":
@@ -190,25 +236,36 @@ class OfficialMessageConverter(adapter_model.MessageConverter):
class OfficialEventConverter(adapter_model.EventConverter):
"""事件转换器"""
@staticmethod
def yiri2target(event: typing.Type[mirai.Event]):
member_openid_mapping: OpenIDMapping[str, int]
group_openid_mapping: OpenIDMapping[str, int]
def __init__(self, member_openid_mapping: OpenIDMapping[str, int], group_openid_mapping: OpenIDMapping[str, int]):
self.member_openid_mapping = member_openid_mapping
self.group_openid_mapping = group_openid_mapping
def yiri2target(self, event: typing.Type[mirai.Event]):
if event == mirai.GroupMessage:
return botpy_message.Message
elif event == mirai.FriendMessage:
return botpy_message.DirectMessage
else:
raise Exception("未支持转换的事件类型(YiriMirai -> Official): " + str(event))
raise Exception(
"未支持转换的事件类型(YiriMirai -> Official): " + str(event)
)
@staticmethod
def target2yiri(event: typing.Union[botpy_message.Message, botpy_message.DirectMessage]) -> mirai.Event:
def target2yiri(
self,
event: typing.Union[botpy_message.Message, botpy_message.DirectMessage]
) -> mirai.Event:
import mirai.models.entities as mirai_entities
if type(event) == botpy_message.Message: # 频道内,转群聊事件
permission = "MEMBER"
if '2' in event.member.roles:
if "2" in event.member.roles:
permission = "ADMINISTRATOR"
elif '4' in event.member.roles:
elif "4" in event.member.roles:
permission = "OWNER"
return mirai.GroupMessage(
@@ -219,15 +276,25 @@ class OfficialEventConverter(adapter_model.EventConverter):
group=mirai_entities.Group(
id=event.channel_id,
name=event.author.username,
permission=mirai_entities.Permission.Member
permission=mirai_entities.Permission.Member,
),
special_title="",
join_timestamp=int(
datetime.datetime.strptime(
event.member.joined_at, "%Y-%m-%dT%H:%M:%S%z"
).timestamp()
),
special_title='',
join_timestamp=int(datetime.datetime.strptime(event.member.joined_at, "%Y-%m-%dT%H:%M:%S%z").timestamp()),
last_speak_timestamp=datetime.datetime.now().timestamp(),
mute_time_remaining=0,
),
message_chain=OfficialMessageConverter.extract_message_chain_from_obj(event, event.id),
time=int(datetime.datetime.strptime(event.timestamp, "%Y-%m-%dT%H:%M:%S%z").timestamp()),
message_chain=OfficialMessageConverter.extract_message_chain_from_obj(
event, event.id
),
time=int(
datetime.datetime.strptime(
event.timestamp, "%Y-%m-%dT%H:%M:%S%z"
).timestamp()
),
)
elif type(event) == botpy_message.DirectMessage: # 私聊,转私聊事件
return mirai.FriendMessage(
@@ -236,12 +303,18 @@ class OfficialEventConverter(adapter_model.EventConverter):
nickname=event.author.username,
remark=event.author.username,
),
message_chain=OfficialMessageConverter.extract_message_chain_from_obj(event, event.id),
time=int(datetime.datetime.strptime(event.timestamp, "%Y-%m-%dT%H:%M:%S%z").timestamp()),
message_chain=OfficialMessageConverter.extract_message_chain_from_obj(
event, event.id
),
time=int(
datetime.datetime.strptime(
event.timestamp, "%Y-%m-%dT%H:%M:%S%z"
).timestamp()
),
)
elif type(event) == botpy_message.GroupMessage:
replacing_member_id = save_member_openid(event.author.member_openid)
replacing_member_id = self.member_openid_mapping.save_openid(event.author.member_openid)
return OfficialGroupMessage(
sender=mirai_entities.GroupMember(
@@ -249,29 +322,36 @@ class OfficialEventConverter(adapter_model.EventConverter):
member_name=replacing_member_id,
permission="MEMBER",
group=mirai_entities.Group(
id=save_group_openid(event.group_openid),
id=self.group_openid_mapping.save_openid(event.group_openid),
name=replacing_member_id,
permission=mirai_entities.Permission.Member
permission=mirai_entities.Permission.Member,
),
special_title='',
special_title="",
join_timestamp=int(0),
last_speak_timestamp=datetime.datetime.now().timestamp(),
mute_time_remaining=0,
),
message_chain=OfficialMessageConverter.extract_message_chain_from_obj(event, event.id),
time=int(datetime.datetime.strptime(event.timestamp, "%Y-%m-%dT%H:%M:%S%z").timestamp()),
message_chain=OfficialMessageConverter.extract_message_chain_from_obj(
event, event.id
),
time=int(
datetime.datetime.strptime(
event.timestamp, "%Y-%m-%dT%H:%M:%S%z"
).timestamp()
),
)
@adapter_model.adapter_class("qq-botpy")
class OfficialAdapter(adapter_model.MessageSourceAdapter):
"""QQ 官方消息适配器"""
bot: botpy.Client = None
bot_account_id: int = 0
message_converter: OfficialMessageConverter = OfficialMessageConverter()
# event_handler: adapter_model.EventHandler = adapter_model.EventHandler()
message_converter: OfficialMessageConverter
event_converter: OfficialEventConverter
cfg: dict = None
@@ -283,78 +363,110 @@ class OfficialAdapter(adapter_model.MessageSourceAdapter):
ap: app.Application
metadata: cfg_mgr.ConfigManager = None
member_openid_mapping: OpenIDMapping[str, int] = None
group_openid_mapping: OpenIDMapping[str, int] = None
group_msg_seq = None
def __init__(self, cfg: dict, ap: app.Application):
"""初始化适配器"""
self.cfg = cfg
self.ap = ap
self.group_msg_seq = 1
switchs = {}
for intent in cfg['intents']:
for intent in cfg["intents"]:
switchs[intent] = True
del cfg['intents']
del cfg["intents"]
intents = botpy.Intents(**switchs)
self.bot = botpy.Client(intents=intents)
async def send_message(
self,
target_type: str,
target_id: str,
message: mirai.MessageChain
self, target_type: str, target_id: str, message: mirai.MessageChain
):
pass
message_list = self.message_converter.yiri2target(message)
for msg in message_list:
args = {}
if msg["type"] == "text":
args["content"] = msg["content"]
elif msg["type"] == "image":
args["image"] = msg["content"]
elif msg["type"] == "file_image":
args["file_image"] = msg["content"]
else:
continue
if target_type == "group":
args["channel_id"] = str(target_id)
await self.bot.api.post_message(**args)
elif target_type == "person":
args["guild_id"] = str(target_id)
await self.bot.api.post_dms(**args)
async def reply_message(
self,
message_source: mirai.MessageEvent,
message: mirai.MessageChain,
quote_origin: bool = False
quote_origin: bool = False,
):
message_list = self.message_converter.yiri2target(message)
tasks = []
msg_seq = 1
message_list = self.message_converter.yiri2target(message)
for msg in message_list:
args = {}
if msg['type'] == 'text':
args['content'] = msg['content']
elif msg['type'] == 'image':
args['image'] = msg['content']
elif msg['type'] == 'file_image':
args['file_image'] = msg["content"]
if msg["type"] == "text":
args["content"] = msg["content"]
elif msg["type"] == "image":
args["image"] = msg["content"]
elif msg["type"] == "file_image":
args["file_image"] = msg["content"]
else:
continue
if quote_origin:
args['message_reference'] = botpy_message_type.Reference(message_id=cached_message_ids[str(message_source.message_chain.message_id)])
if type(message_source) == mirai.GroupMessage:
args['channel_id'] = str(message_source.sender.group.id)
args['msg_id'] = cached_message_ids[str(message_source.message_chain.message_id)]
await self.bot.api.post_message(**args)
elif type(message_source) == mirai.FriendMessage:
args['guild_id'] = str(message_source.sender.id)
args['msg_id'] = cached_message_ids[str(message_source.message_chain.message_id)]
await self.bot.api.post_dms(**args)
elif type(message_source) == OfficialGroupMessage:
# args['guild_id'] = str(message_source.sender.group.id)
# args['msg_id'] = cached_message_ids[str(message_source.message_chain.message_id)]
# await self.bot.api.post_message(**args)
if 'image' in args or 'file_image' in args:
continue
args['group_openid'] = cached_group_openids[str(message_source.sender.group.id)]
args['msg_id'] = cached_message_ids[str(message_source.message_chain.message_id)]
args['msg_seq'] = msg_seq
msg_seq += 1
await self.bot.api.post_group_message(
**args
args["message_reference"] = botpy_message_type.Reference(
message_id=cached_message_ids[
str(message_source.message_chain.message_id)
]
)
if type(message_source) == mirai.GroupMessage:
args["channel_id"] = str(message_source.sender.group.id)
args["msg_id"] = cached_message_ids[
str(message_source.message_chain.message_id)
]
await self.bot.api.post_message(**args)
elif type(message_source) == mirai.FriendMessage:
args["guild_id"] = str(message_source.sender.id)
args["msg_id"] = cached_message_ids[
str(message_source.message_chain.message_id)
]
await self.bot.api.post_dms(**args)
elif type(message_source) == OfficialGroupMessage:
if "image" in args or "file_image" in args:
continue
args["group_openid"] = self.group_openid_mapping.getkey(
message_source.sender.group.id
)
args["msg_id"] = cached_message_ids[
str(message_source.message_chain.message_id)
]
args["msg_seq"] = self.group_msg_seq
self.group_msg_seq += 1
await self.bot.api.post_group_message(**args)
async def is_muted(self, group_id: int) -> bool:
return False
@@ -362,14 +474,22 @@ class OfficialAdapter(adapter_model.MessageSourceAdapter):
def register_listener(
self,
event_type: typing.Type[mirai.Event],
callback: typing.Callable[[mirai.Event, adapter_model.MessageSourceAdapter], None]
callback: typing.Callable[
[mirai.Event, adapter_model.MessageSourceAdapter], None
],
):
try:
async def wrapper(message: typing.Union[botpy_message.Message, botpy_message.DirectMessage, botpy_message.GroupMessage]):
async def wrapper(
message: typing.Union[
botpy_message.Message,
botpy_message.DirectMessage,
botpy_message.GroupMessage,
]
):
self.cached_official_messages[str(message.id)] = message
await callback(OfficialEventConverter.target2yiri(message), self)
await callback(self.event_converter.target2yiri(message), self)
for event_handler in event_handler_mapping[event_type]:
setattr(self.bot, event_handler, wrapper)
@@ -380,15 +500,33 @@ class OfficialAdapter(adapter_model.MessageSourceAdapter):
def unregister_listener(
self,
event_type: typing.Type[mirai.Event],
callback: typing.Callable[[mirai.Event, adapter_model.MessageSourceAdapter], None]
callback: typing.Callable[
[mirai.Event, adapter_model.MessageSourceAdapter], None
],
):
delattr(self.bot, event_handler_mapping[event_type])
async def run_async(self):
self.ap.logger.info("运行 QQ 官方适配器")
await self.bot.start(
**self.cfg
self.metadata = self.ap.adapter_qq_botpy_meta
self.member_openid_mapping = OpenIDMapping(
map=self.metadata.data["mapping"]["members"],
dump_func=self.metadata.dump_config_sync,
)
self.group_openid_mapping = OpenIDMapping(
map=self.metadata.data["mapping"]["groups"],
dump_func=self.metadata.dump_config_sync,
)
self.message_converter = OfficialMessageConverter()
self.event_converter = OfficialEventConverter(
self.member_openid_mapping, self.group_openid_mapping
)
self.ap.logger.info("运行 QQ 官方适配器")
await self.bot.start(**self.cfg)
def kill(self) -> bool:
return False

View File

@@ -9,10 +9,86 @@ from ..provider.tools import entities as tools_entities
from ..core import app
def register(
name: str,
description: str,
version: str,
author: str
) -> typing.Callable[[typing.Type[BasePlugin]], typing.Type[BasePlugin]]:
"""注册插件类
使用示例:
@register(
name="插件名称",
description="插件描述",
version="插件版本",
author="插件作者"
)
class MyPlugin(BasePlugin):
pass
"""
pass
def handler(
event: typing.Type[events.BaseEventModel]
) -> typing.Callable[[typing.Callable], typing.Callable]:
"""注册事件监听器
使用示例:
class MyPlugin(BasePlugin):
@handler(NormalMessageResponded)
async def on_normal_message_responded(self, ctx: EventContext):
pass
"""
pass
def llm_func(
name: str=None,
) -> typing.Callable:
"""注册内容函数
使用示例:
class MyPlugin(BasePlugin):
@llm_func("access_the_web_page")
async def _(self, query, url: str, brief_len: int):
\"""Call this function to search about the question before you answer any questions.
- Do not search through google.com at any time.
- If you need to search somthing, visit https://www.sogou.com/web?query=<something>.
- If user ask you to open a url (start with http:// or https://), visit it directly.
- Summary the plain content result by yourself, DO NOT directly output anything in the result you got.
Args:
url(str): url to visit
brief_len(int): max length of the plain text content, recommend 1024-4096, prefer 4096
Returns:
str: plain text content of the web page or error message(starts with 'error:')
\"""
"""
pass
class BasePlugin(metaclass=abc.ABCMeta):
"""插件基类"""
host: APIHost
"""API宿主"""
ap: app.Application
"""应用程序对象"""
def __init__(self, host: APIHost):
self.host = host
async def initialize(self):
"""初始化插件"""
pass
class APIHost:
@@ -61,8 +137,10 @@ class EventContext:
"""事件编号"""
host: APIHost = None
"""API宿主"""
event: events.BaseEventModel = None
"""此次事件的对象具体类型为handler注册时指定监听的类型可查看events.py中的定义"""
__prevent_default__ = False
"""是否阻止默认行为"""

View File

@@ -10,8 +10,10 @@ from ..provider import entities as llm_entities
class BaseEventModel(pydantic.BaseModel):
"""事件模型基类"""
query: core_entities.Query | None
query: typing.Union[core_entities.Query, None]
"""此次请求的query对象非请求过程的事件时为None"""
class Config:
arbitrary_types_allowed = True

View File

@@ -1,3 +1,7 @@
# 此模块已过时
# 请从 pkg.plugin.context 引入 BasePlugin, EventContext 和 APIHost
# 最早将于 v3.4 移除此模块
from . events import *
from . context import EventContext, APIHost as PluginHost

View File

@@ -7,6 +7,7 @@ from ..core import app
class PluginInstaller(metaclass=abc.ABCMeta):
"""插件安装器抽象类"""
ap: app.Application

View File

@@ -12,6 +12,8 @@ from ...utils import pkgmgr
class GitHubRepoInstaller(installer.PluginInstaller):
"""GitHub仓库插件安装器
"""
def get_github_plugin_repo_label(self, repo_url: str) -> list[str]:
"""获取username, repo"""

View File

@@ -9,7 +9,7 @@ from . import context, events
class PluginLoader(metaclass=abc.ABCMeta):
"""插件加载器"""
"""插件加载器抽象类"""
ap: app.Application

View File

@@ -5,11 +5,10 @@ import pkgutil
import importlib
import traceback
from CallingGPT.entities.namespace import get_func_schema
from .. import loader, events, context, models, host
from ...core import entities as core_entities
from ...provider.tools import entities as tools_entities
from ...utils import funcschema
class PluginLoader(loader.PluginLoader):
@@ -29,6 +28,10 @@ class PluginLoader(loader.PluginLoader):
setattr(models, 'on', self.on)
setattr(models, 'func', self.func)
setattr(context, 'register', self.register)
setattr(context, 'handler', self.handler)
setattr(context, 'llm_func', self.llm_func)
def register(
self,
name: str,
@@ -57,6 +60,8 @@ class PluginLoader(loader.PluginLoader):
return wrapper
# 过时
# 最早将于 v3.4 版本移除
def on(
self,
event: typing.Type[events.BaseEventModel]
@@ -83,6 +88,8 @@ class PluginLoader(loader.PluginLoader):
return wrapper
# 过时
# 最早将于 v3.4 版本移除
def func(
self,
name: str=None,
@@ -91,10 +98,11 @@ class PluginLoader(loader.PluginLoader):
self.ap.logger.debug(f'注册内容函数 {name}')
def wrapper(func: typing.Callable) -> typing.Callable:
function_schema = get_func_schema(func)
function_schema = funcschema.get_func_schema(func)
function_name = self._current_container.plugin_name + '-' + (func.__name__ if name is None else name)
async def handler(
plugin: context.BasePlugin,
query: core_entities.Query,
*args,
**kwargs
@@ -116,6 +124,46 @@ class PluginLoader(loader.PluginLoader):
return wrapper
def handler(
self,
event: typing.Type[events.BaseEventModel]
) -> typing.Callable[[typing.Callable], typing.Callable]:
"""注册事件处理器"""
self.ap.logger.debug(f'注册事件处理器 {event.__name__}')
def wrapper(func: typing.Callable) -> typing.Callable:
self._current_container.event_handlers[event] = func
return func
return wrapper
def llm_func(
self,
name: str=None,
) -> typing.Callable:
"""注册内容函数"""
self.ap.logger.debug(f'注册内容函数 {name}')
def wrapper(func: typing.Callable) -> typing.Callable:
function_schema = funcschema.get_func_schema(func)
function_name = self._current_container.plugin_name + '-' + (func.__name__ if name is None else name)
llm_function = tools_entities.LLMFunction(
name=function_name,
human_desc='',
description=function_schema['description'],
enable=True,
parameters=function_schema['parameters'],
func=func,
)
self._current_container.content_functions.append(llm_function)
return func
return wrapper
async def _walk_plugin_path(
self,
module,

View File

@@ -5,11 +5,12 @@ import traceback
from ..core import app
from . import context, loader, events, installer, setting, models
from .loaders import legacy
from .loaders import classic
from .installers import github
class PluginManager:
"""插件管理器"""
ap: app.Application
@@ -25,7 +26,7 @@ class PluginManager:
def __init__(self, ap: app.Application):
self.ap = ap
self.loader = legacy.PluginLoader(ap)
self.loader = classic.PluginLoader(ap)
self.installer = github.GitHubRepoInstaller(ap)
self.setting = setting.SettingManager(ap)
self.api_host = context.APIHost(ap)
@@ -51,6 +52,9 @@ class PluginManager:
for plugin in self.plugins:
try:
plugin.plugin_inst = plugin.plugin_class(self.api_host)
plugin.plugin_inst.ap = self.ap
plugin.plugin_inst.host = self.api_host
await plugin.plugin_inst.initialize()
except Exception as e:
self.ap.logger.error(f'插件 {plugin.plugin_name} 初始化失败: {e}')
self.ap.logger.exception(e)
@@ -136,8 +140,7 @@ class PluginManager:
for plugin in self.plugins:
if plugin.enabled:
if event.__class__ in plugin.event_handlers:
emitted_plugins.append(plugin)
self.ap.logger.debug(f'插件 {plugin.plugin_name} 触发事件 {event.__class__.__name__}')
is_prevented_default_before_call = ctx.is_prevented_default()
@@ -150,6 +153,8 @@ class PluginManager:
self.ap.logger.error(f'插件 {plugin.plugin_name} 触发事件 {event.__class__.__name__} 时发生错误: {e}')
self.ap.logger.debug(f"Traceback: {traceback.format_exc()}")
emitted_plugins.append(plugin)
if not is_prevented_default_before_call and ctx.is_prevented_default():
self.ap.logger.debug(f'插件 {plugin.plugin_name} 阻止了默认行为执行')

View File

@@ -1,3 +1,7 @@
# 此模块已过时,请引入 pkg.plugin.context 中的 register, handler 和 llm_func 来注册插件、事件处理函数和内容函数
# 各个事件模型请从 pkg.plugin.events 引入
# 最早将于 v3.4 移除此模块
from __future__ import annotations
import typing

View File

@@ -6,6 +6,7 @@ from . import context
class SettingManager:
"""插件设置管理器"""
ap: app.Application
@@ -15,10 +16,7 @@ class SettingManager:
self.ap = ap
async def initialize(self):
self.settings = await cfg_mgr.load_json_config(
'plugins/plugins.json',
'templates/plugin-settings.json'
)
self.settings = self.ap.plugin_setting_meta
async def sync_setting(
self,

View File

@@ -4,6 +4,8 @@ import typing
import enum
import pydantic
import mirai
class FunctionCall(pydantic.BaseModel):
name: str
@@ -20,14 +22,31 @@ class ToolCall(pydantic.BaseModel):
class Message(pydantic.BaseModel):
role: str # user, system, assistant, tool, command
"""消息"""
role: str # user, system, assistant, tool, command, plugin
"""消息的角色"""
name: typing.Optional[str] = None
"""名称,仅函数调用返回时设置"""
content: typing.Optional[str] = None
content: typing.Optional[str] | typing.Optional[mirai.MessageChain] = None
"""内容"""
function_call: typing.Optional[FunctionCall] = None
"""函数调用不再受支持请使用tool_calls"""
tool_calls: typing.Optional[list[ToolCall]] = None
"""工具调用"""
tool_call_id: typing.Optional[str] = None
def readable_str(self) -> str:
if self.content is not None:
return str(self.content)
elif self.function_call is not None:
return f'{self.function_call.name}({self.function_call.arguments})'
elif self.tool_calls is not None:
return f'调用工具: {self.tool_calls[0].id}'
else:
return '未知消息'

View File

@@ -0,0 +1,52 @@
from __future__ import annotations
import abc
import typing
from ...core import app
from ...core import entities as core_entities
from .. import entities as llm_entities
preregistered_requesters: list[typing.Type[LLMAPIRequester]] = []
def requester_class(name: str):
def decorator(cls: typing.Type[LLMAPIRequester]) -> typing.Type[LLMAPIRequester]:
cls.name = name
preregistered_requesters.append(cls)
return cls
return decorator
class LLMAPIRequester(metaclass=abc.ABCMeta):
"""LLM API请求器
"""
name: str = None
ap: app.Application
def __init__(self, ap: app.Application):
self.ap = ap
async def initialize(self):
pass
@abc.abstractmethod
async def request(
self,
query: core_entities.Query,
) -> typing.AsyncGenerator[llm_entities.Message, None]:
"""请求API
对话前文可以从 query 对象中获取。
可以多次yield消息对象。
Args:
query (core_entities.Query): 本次请求的上下文对象
Yields:
pkg.provider.entities.Message: 返回消息对象
"""
raise NotImplementedError

View File

View File

@@ -0,0 +1,82 @@
from __future__ import annotations
import typing
import traceback
import anthropic
from .. import api, entities, errors
from .. import api, entities, errors
from ....core import entities as core_entities
from ... import entities as llm_entities
from ...tools import entities as tools_entities
@api.requester_class("anthropic-messages")
class AnthropicMessages(api.LLMAPIRequester):
"""Anthropic Messages API 请求器"""
client: anthropic.AsyncAnthropic
async def initialize(self):
self.client = anthropic.AsyncAnthropic(
api_key="",
base_url=self.ap.provider_cfg.data['requester']['anthropic-messages']['base-url'],
timeout=self.ap.provider_cfg.data['requester']['anthropic-messages']['timeout'],
proxies=self.ap.proxy_mgr.get_forward_proxies()
)
async def request(
self,
query: core_entities.Query,
) -> typing.AsyncGenerator[llm_entities.Message, None]:
self.client.api_key = query.use_model.token_mgr.get_token()
args = self.ap.provider_cfg.data['requester']['anthropic-messages']['args'].copy()
args["model"] = query.use_model.name if query.use_model.model_name is None else query.use_model.model_name
req_messages = [ # req_messages 仅用于类内,外部同步由 query.messages 进行
m.dict(exclude_none=True) for m in query.prompt.messages if m.content.strip() != ""
] + [m.dict(exclude_none=True) for m in query.messages]
# 删除所有 role=system & content='' 的消息
req_messages = [
m for m in req_messages if not (m["role"] == "system" and m["content"].strip() == "")
]
# 检查是否有 role=system 的消息,若有,改为 role=user并在后面加一个 role=assistant 的消息
system_role_index = []
for i, m in enumerate(req_messages):
if m["role"] == "system":
system_role_index.append(i)
m["role"] = "user"
if system_role_index:
for i in system_role_index[::-1]:
req_messages.insert(i + 1, {"role": "assistant", "content": "Okay, I'll follow."})
# 忽略掉空消息,用户可能发送空消息,而上层未过滤
req_messages = [
m for m in req_messages if m["content"].strip() != ""
]
args["messages"] = req_messages
try:
resp = await self.client.messages.create(**args)
yield llm_entities.Message(
content=resp.content[0].text,
role=resp.role
)
except anthropic.AuthenticationError as e:
raise errors.RequesterError(f'api-key 无效: {e.message}')
except anthropic.BadRequestError as e:
raise errors.RequesterError(str(e.message))
except anthropic.NotFoundError as e:
if 'model: ' in str(e):
raise errors.RequesterError(f'模型无效: {e.message}')
else:
raise errors.RequesterError(f'请求地址无效: {e.message}')

View File

@@ -9,22 +9,31 @@ import openai
import openai.types.chat.chat_completion as chat_completion
import httpx
from pkg.provider.entities import Message
from .. import api, entities, errors
from ....core import entities as core_entities
from ....core import entities as core_entities, app
from ... import entities as llm_entities
from ...tools import entities as tools_entities
class OpenAIChatCompletion(api.LLMAPIRequester):
@api.requester_class("openai-chat-completions")
class OpenAIChatCompletions(api.LLMAPIRequester):
"""OpenAI ChatCompletion API 请求器"""
client: openai.AsyncClient
requester_cfg: dict
def __init__(self, ap: app.Application):
self.ap = ap
self.requester_cfg = self.ap.provider_cfg.data['requester']['openai-chat-completions']
async def initialize(self):
self.client = openai.AsyncClient(
api_key="",
base_url=self.ap.provider_cfg.data['openai-config']['base_url'],
timeout=self.ap.provider_cfg.data['openai-config']['request-timeout'],
base_url=self.requester_cfg['base-url'],
timeout=self.requester_cfg['timeout'],
http_client=httpx.AsyncClient(
proxies=self.ap.proxy_mgr.get_forward_proxies()
)
@@ -55,7 +64,7 @@ class OpenAIChatCompletion(api.LLMAPIRequester):
) -> llm_entities.Message:
self.client.api_key = use_model.token_mgr.get_token()
args = self.ap.provider_cfg.data['openai-config']['chat-completions-params'].copy()
args = self.requester_cfg['args'].copy()
args["model"] = use_model.name if use_model.model_name is None else use_model.model_name
if use_model.tool_call_supported:
@@ -84,11 +93,12 @@ class OpenAIChatCompletion(api.LLMAPIRequester):
pending_tool_calls = []
req_messages = [ # req_messages 仅用于类内,外部同步由 query.messages 进行
m.dict(exclude_none=True) for m in query.prompt.messages
m.dict(exclude_none=True) for m in query.prompt.messages if m.content.strip() != ""
] + [m.dict(exclude_none=True) for m in query.messages]
# req_messages.append({"role": "user", "content": str(query.message_chain)})
# 首次请求
msg = await self._closure(req_messages, query.use_model, query.use_funcs)
yield msg
@@ -97,8 +107,10 @@ class OpenAIChatCompletion(api.LLMAPIRequester):
req_messages.append(msg.dict(exclude_none=True))
# 持续请求,只要还有待处理的工具调用就继续处理调用
while pending_tool_calls:
for tool_call in pending_tool_calls:
try:
func = tool_call.function
parameters = json.loads(func.arguments)
@@ -114,6 +126,17 @@ class OpenAIChatCompletion(api.LLMAPIRequester):
yield msg
req_messages.append(msg.dict(exclude_none=True))
except Exception as e:
# 出错,添加一个报错信息到 req_messages
err_msg = llm_entities.Message(
role="tool", content=f"err: {e}", tool_call_id=tool_call.id
)
yield err_msg
req_messages.append(
err_msg.dict(exclude_none=True)
)
# 处理完所有调用,继续请求
msg = await self._closure(req_messages, query.use_model, query.use_funcs)
@@ -124,14 +147,17 @@ class OpenAIChatCompletion(api.LLMAPIRequester):
req_messages.append(msg.dict(exclude_none=True))
async def request(self, query: core_entities.Query) -> AsyncGenerator[Message, None]:
async def request(self, query: core_entities.Query) -> AsyncGenerator[llm_entities.Message, None]:
try:
async for msg in self._request(query):
yield msg
except asyncio.TimeoutError:
raise errors.RequesterError('请求超时')
except openai.BadRequestError as e:
raise errors.RequesterError(f'请求错误: {e.message}')
if 'context_length_exceeded' in e.message:
raise errors.RequesterError(f'上文过长,请重置会话: {e.message}')
else:
raise errors.RequesterError(f'请求参数错误: {e.message}')
except openai.AuthenticationError as e:
raise errors.RequesterError(f'无效的 api-key: {e.message}')
except openai.NotFoundError as e:

View File

@@ -0,0 +1,15 @@
from __future__ import annotations
from ....core import app
from . import chatcmpl
from .. import api
@api.requester_class("moonshot-chat-completions")
class MoonshotChatCompletions(chatcmpl.OpenAIChatCompletions):
"""Moonshot ChatCompletion API 请求器"""
def __init__(self, ap: app.Application):
self.requester_cfg = ap.provider_cfg.data['requester']['moonshot-chat-completions']
self.ap = ap

View File

@@ -5,7 +5,7 @@ import typing
import pydantic
from . import api
from . import token, tokenizer
from . import token
class LLMModelInfo(pydantic.BaseModel):
@@ -19,11 +19,7 @@ class LLMModelInfo(pydantic.BaseModel):
requester: api.LLMAPIRequester
tokenizer: 'tokenizer.LLMTokenizer'
tool_call_supported: typing.Optional[bool] = False
max_tokens: typing.Optional[int] = 2048
class Config:
arbitrary_types_allowed = True

View File

@@ -0,0 +1,109 @@
from __future__ import annotations
import aiohttp
from . import entities
from ...core import app
from . import token, api
from .apis import chatcmpl, anthropicmsgs, moonshotchatcmpl
FETCH_MODEL_LIST_URL = "https://api.qchatgpt.rockchin.top/api/v2/fetch/model_list"
class ModelManager:
"""模型管理器"""
ap: app.Application
model_list: list[entities.LLMModelInfo]
requesters: dict[str, api.LLMAPIRequester]
token_mgrs: dict[str, token.TokenManager]
def __init__(self, ap: app.Application):
self.ap = ap
self.model_list = []
self.requesters = {}
self.token_mgrs = {}
async def get_model_by_name(self, name: str) -> entities.LLMModelInfo:
"""通过名称获取模型
"""
for model in self.model_list:
if model.name == name:
return model
raise ValueError(f"无法确定模型 {name} 的信息,请在元数据中配置")
async def initialize(self):
# 初始化token_mgr, requester
for k, v in self.ap.provider_cfg.data['keys'].items():
self.token_mgrs[k] = token.TokenManager(k, v)
for api_cls in api.preregistered_requesters:
api_inst = api_cls(self.ap)
await api_inst.initialize()
self.requesters[api_inst.name] = api_inst
# 尝试从api获取最新的模型信息
try:
async with aiohttp.ClientSession() as session:
async with session.request(
method="GET",
url=FETCH_MODEL_LIST_URL,
# 参数
params={
"version": self.ap.ver_mgr.get_current_version()
},
) as resp:
model_list = (await resp.json())['data']['list']
for model in model_list:
for index, local_model in enumerate(self.ap.llm_models_meta.data['list']):
if model['name'] == local_model['name']:
self.ap.llm_models_meta.data['list'][index] = model
break
else:
self.ap.llm_models_meta.data['list'].append(model)
await self.ap.llm_models_meta.dump_config()
except Exception as e:
self.ap.logger.debug(f'获取最新模型列表失败: {e}')
default_model_info: entities.LLMModelInfo = None
for model in self.ap.llm_models_meta.data['list']:
if model['name'] == 'default':
default_model_info = entities.LLMModelInfo(
name=model['name'],
model_name=None,
token_mgr=self.token_mgrs[model['token_mgr']],
requester=self.requesters[model['requester']],
tool_call_supported=model['tool_call_supported']
)
break
for model in self.ap.llm_models_meta.data['list']:
try:
model_name = model.get('model_name', default_model_info.model_name)
token_mgr = self.token_mgrs[model['token_mgr']] if 'token_mgr' in model else default_model_info.token_mgr
requester = self.requesters[model['requester']] if 'requester' in model else default_model_info.requester
tool_call_supported = model.get('tool_call_supported', default_model_info.tool_call_supported)
model_info = entities.LLMModelInfo(
name=model['name'],
model_name=model_name,
token_mgr=token_mgr,
requester=requester,
tool_call_supported=tool_call_supported
)
self.model_list.append(model_info)
except Exception as e:
self.ap.logger.error(f"初始化模型 {model['name']} 失败: {e} ,请检查配置文件")

View File

@@ -6,6 +6,8 @@ import pydantic
class TokenManager():
"""鉴权 Token 管理器
"""
provider: str

View File

@@ -1,29 +0,0 @@
from __future__ import annotations
import abc
import typing
from ...core import app
from ...core import entities as core_entities
from .. import entities as llm_entities
class LLMAPIRequester(metaclass=abc.ABCMeta):
"""LLM API请求器
"""
ap: app.Application
def __init__(self, ap: app.Application):
self.ap = ap
async def initialize(self):
pass
@abc.abstractmethod
async def request(
self,
query: core_entities.Query,
) -> typing.AsyncGenerator[llm_entities.Message, None]:
"""请求
"""
raise NotImplementedError

View File

@@ -1,242 +0,0 @@
from __future__ import annotations
from . import entities
from ...core import app
from .apis import chatcmpl
from . import token
from .tokenizers import tiktoken
class ModelManager:
ap: app.Application
model_list: list[entities.LLMModelInfo]
def __init__(self, ap: app.Application):
self.ap = ap
self.model_list = []
async def get_model_by_name(self, name: str) -> entities.LLMModelInfo:
"""通过名称获取模型
"""
for model in self.model_list:
if model.name == name:
return model
raise ValueError(f"不支持模型: {name} , 请检查配置文件")
async def initialize(self):
openai_chat_completion = chatcmpl.OpenAIChatCompletion(self.ap)
await openai_chat_completion.initialize()
openai_token_mgr = token.TokenManager(self.ap, list(self.ap.provider_cfg.data['openai-config']['api-keys']))
tiktoken_tokenizer = tiktoken.Tiktoken(self.ap)
model_list = [
entities.LLMModelInfo(
name="gpt-3.5-turbo",
token_mgr=openai_token_mgr,
requester=openai_chat_completion,
tool_call_supported=True,
tokenizer=tiktoken_tokenizer,
max_tokens=4096
),
entities.LLMModelInfo(
name="gpt-3.5-turbo-1106",
token_mgr=openai_token_mgr,
requester=openai_chat_completion,
tool_call_supported=True,
tokenizer=tiktoken_tokenizer,
max_tokens=16385
),
entities.LLMModelInfo(
name="gpt-3.5-turbo-16k",
token_mgr=openai_token_mgr,
requester=openai_chat_completion,
tool_call_supported=True,
tokenizer=tiktoken_tokenizer,
max_tokens=16385
),
entities.LLMModelInfo(
name="gpt-3.5-turbo-0613",
token_mgr=openai_token_mgr,
requester=openai_chat_completion,
tool_call_supported=True,
tokenizer=tiktoken_tokenizer,
max_tokens=4096
),
entities.LLMModelInfo(
name="gpt-3.5-turbo-16k-0613",
token_mgr=openai_token_mgr,
requester=openai_chat_completion,
tool_call_supported=True,
tokenizer=tiktoken_tokenizer,
max_tokens=16385
),
entities.LLMModelInfo(
name="gpt-3.5-turbo-0301",
token_mgr=openai_token_mgr,
requester=openai_chat_completion,
tool_call_supported=True,
tokenizer=tiktoken_tokenizer,
max_tokens=4096
)
]
self.model_list.extend(model_list)
gpt4_model_list = [
entities.LLMModelInfo(
name="gpt-4-0125-preview",
token_mgr=openai_token_mgr,
requester=openai_chat_completion,
tool_call_supported=True,
tokenizer=tiktoken_tokenizer,
max_tokens=128000
),
entities.LLMModelInfo(
name="gpt-4-turbo-preview",
token_mgr=openai_token_mgr,
requester=openai_chat_completion,
tool_call_supported=True,
tokenizer=tiktoken_tokenizer,
max_tokens=128000
),
entities.LLMModelInfo(
name="gpt-4-1106-preview",
token_mgr=openai_token_mgr,
requester=openai_chat_completion,
tool_call_supported=True,
tokenizer=tiktoken_tokenizer,
max_tokens=128000
),
entities.LLMModelInfo(
name="gpt-4-vision-preview",
token_mgr=openai_token_mgr,
requester=openai_chat_completion,
tool_call_supported=True,
tokenizer=tiktoken_tokenizer,
max_tokens=128000
),
entities.LLMModelInfo(
name="gpt-4",
token_mgr=openai_token_mgr,
requester=openai_chat_completion,
tool_call_supported=True,
tokenizer=tiktoken_tokenizer,
max_tokens=8192
),
entities.LLMModelInfo(
name="gpt-4-0613",
token_mgr=openai_token_mgr,
requester=openai_chat_completion,
tool_call_supported=True,
tokenizer=tiktoken_tokenizer,
max_tokens=8192
),
entities.LLMModelInfo(
name="gpt-4-32k",
token_mgr=openai_token_mgr,
requester=openai_chat_completion,
tool_call_supported=True,
tokenizer=tiktoken_tokenizer,
max_tokens=32768
),
entities.LLMModelInfo(
name="gpt-4-32k-0613",
token_mgr=openai_token_mgr,
requester=openai_chat_completion,
tool_call_supported=True,
tokenizer=tiktoken_tokenizer,
max_tokens=32768
)
]
self.model_list.extend(gpt4_model_list)
one_api_model_list = [
entities.LLMModelInfo(
name="OneAPI/SparkDesk",
model_name='SparkDesk',
token_mgr=openai_token_mgr,
requester=openai_chat_completion,
tool_call_supported=False,
tokenizer=tiktoken_tokenizer,
max_tokens=8192
),
entities.LLMModelInfo(
name="OneAPI/chatglm_pro",
model_name='chatglm_pro',
token_mgr=openai_token_mgr,
requester=openai_chat_completion,
tool_call_supported=False,
tokenizer=tiktoken_tokenizer,
max_tokens=128000
),
entities.LLMModelInfo(
name="OneAPI/chatglm_std",
model_name='chatglm_std',
token_mgr=openai_token_mgr,
requester=openai_chat_completion,
tool_call_supported=False,
tokenizer=tiktoken_tokenizer,
max_tokens=128000
),
entities.LLMModelInfo(
name="OneAPI/chatglm_lite",
model_name='chatglm_lite',
token_mgr=openai_token_mgr,
requester=openai_chat_completion,
tool_call_supported=False,
tokenizer=tiktoken_tokenizer,
max_tokens=128000
),
entities.LLMModelInfo(
name="OneAPI/qwen-v1",
model_name='qwen-v1',
token_mgr=openai_token_mgr,
requester=openai_chat_completion,
tool_call_supported=False,
tokenizer=tiktoken_tokenizer,
max_tokens=6000
),
entities.LLMModelInfo(
name="OneAPI/qwen-plus-v1",
model_name='qwen-plus-v1',
token_mgr=openai_token_mgr,
requester=openai_chat_completion,
tool_call_supported=False,
tokenizer=tiktoken_tokenizer,
max_tokens=30000
),
entities.LLMModelInfo(
name="OneAPI/ERNIE-Bot",
model_name='ERNIE-Bot',
token_mgr=openai_token_mgr,
requester=openai_chat_completion,
tool_call_supported=False,
tokenizer=tiktoken_tokenizer,
max_tokens=2000
),
entities.LLMModelInfo(
name="OneAPI/ERNIE-Bot-turbo",
model_name='ERNIE-Bot-turbo',
token_mgr=openai_token_mgr,
requester=openai_chat_completion,
tool_call_supported=False,
tokenizer=tiktoken_tokenizer,
max_tokens=7000
),
entities.LLMModelInfo(
name="OneAPI/gemini-pro",
model_name='gemini-pro',
token_mgr=openai_token_mgr,
requester=openai_chat_completion,
tool_call_supported=False,
tokenizer=tiktoken_tokenizer,
max_tokens=30720
),
]
self.model_list.extend(one_api_model_list)

View File

@@ -1,29 +0,0 @@
from __future__ import annotations
import abc
import typing
from ...core import app
from .. import entities as llm_entities
from . import entities
class LLMTokenizer(metaclass=abc.ABCMeta):
ap: app.Application
def __init__(self, ap: app.Application):
self.ap = ap
async def initialize(self):
"""初始化分词器
"""
pass
@abc.abstractmethod
async def count_token(
self,
messages: list[llm_entities.Message],
model: entities.LLMModelInfo
) -> int:
pass

View File

@@ -1,28 +0,0 @@
from __future__ import annotations
import tiktoken
from .. import tokenizer
from ... import entities as llm_entities
from .. import entities
class Tiktoken(tokenizer.LLMTokenizer):
async def count_token(
self,
messages: list[llm_entities.Message],
model: entities.LLMModelInfo
) -> int:
try:
encoding = tiktoken.encoding_for_model(model.name)
except KeyError:
# print("Warning: model not found. Using cl100k_base encoding.")
encoding = tiktoken.get_encoding("cl100k_base")
num_tokens = 0
for message in messages:
num_tokens += len(encoding.encode(message.role))
num_tokens += len(encoding.encode(message.content if message.content is not None else ''))
num_tokens += 3 # every reply is primed with <|start|>assistant<|message|>
return num_tokens

View File

@@ -6,6 +6,8 @@ from ...core import app, entities as core_entities
class SessionManager:
"""会话管理器
"""
ap: app.Application
@@ -39,6 +41,8 @@ class SessionManager:
return session
async def get_conversation(self, session: core_entities.Session) -> core_entities.Conversation:
"""获取对话或创建对话"""
if not session.conversations:
session.conversations = []
@@ -46,7 +50,7 @@ class SessionManager:
conversation = core_entities.Conversation(
prompt=await self.ap.prompt_mgr.get_prompt(session.use_prompt_name),
messages=[],
use_model=await self.ap.model_mgr.get_model_by_name(self.ap.provider_cfg.data['openai-config']['chat-completions-params']['model']),
use_model=await self.ap.model_mgr.get_model_by_name(self.ap.provider_cfg.data['model']),
use_funcs=await self.ap.tool_mgr.get_all_functions(),
)
session.conversations.append(conversation)

View File

@@ -10,5 +10,7 @@ class Prompt(pydantic.BaseModel):
"""供AI使用的Prompt"""
name: str
"""名称"""
messages: list[entities.Message]
"""消息列表"""

View File

@@ -1,13 +1,27 @@
from __future__ import annotations
import abc
import typing
from ...core import app
from . import entities
preregistered_loaders: list[typing.Type[PromptLoader]] = []
def loader_class(name: str):
def decorator(cls: typing.Type[PromptLoader]) -> typing.Type[PromptLoader]:
cls.name = name
preregistered_loaders.append(cls)
return cls
return decorator
class PromptLoader(metaclass=abc.ABCMeta):
"""Prompt加载器抽象类
"""
name: str
ap: app.Application
@@ -22,7 +36,7 @@ class PromptLoader(metaclass=abc.ABCMeta):
@abc.abstractmethod
async def load(self):
"""加载Prompt
"""加载Prompt存放到prompts列表中
"""
raise NotImplementedError

View File

@@ -8,14 +8,15 @@ from .. import entities
from ....provider import entities as llm_entities
@loader.loader_class("full-scenario")
class ScenarioPromptLoader(loader.PromptLoader):
"""加载scenario目录下的json"""
async def load(self):
"""加载Prompt
"""
for file in os.listdir("data/scenarios"):
with open("data/scenarios/{}".format(file), "r", encoding="utf-8") as f:
for file in os.listdir("data/scenario"):
with open("data/scenario/{}".format(file), "r", encoding="utf-8") as f:
file_str = f.read()
file_name = file.split(".")[0]
file_json = json.loads(file_str)

View File

@@ -6,6 +6,7 @@ from .. import entities
from ....provider import entities as llm_entities
@loader.loader_class("normal")
class SingleSystemPromptLoader(loader.PromptLoader):
"""配置文件中的单条system prompt的prompt加载器
"""

View File

@@ -6,6 +6,8 @@ from .loaders import single, scenario
class PromptManager:
"""Prompt管理器
"""
ap: app.Application
@@ -18,14 +20,18 @@ class PromptManager:
async def initialize(self):
loader_map = {
"normal": single.SingleSystemPromptLoader,
"full_scenario": scenario.ScenarioPromptLoader
}
mode_name = self.ap.provider_cfg.data['prompt-mode']
loader_cls = loader_map[self.ap.provider_cfg.data['prompt-mode']]
loader_class = None
self.loader_inst: loader.PromptLoader = loader_cls(self.ap)
for loader_cls in loader.preregistered_loaders:
if loader_cls.name == mode_name:
loader_class = loader_cls
break
else:
raise ValueError(f'未知的 Prompt 加载器: {mode_name}')
self.loader_inst: loader.PromptLoader = loader_class(self.ap)
await self.loader_inst.initialize()
await self.loader_inst.load()

View File

@@ -5,6 +5,7 @@ import traceback
from ...core import app, entities as core_entities
from . import entities
from ...plugin import context as plugin_context
class ToolManager:
@@ -28,6 +29,15 @@ class ToolManager:
return function
return None
async def get_function_and_plugin(self, name: str) -> typing.Tuple[entities.LLMFunction, plugin_context.BasePlugin]:
"""获取函数和插件
"""
for plugin in self.ap.plugin_mgr.plugins:
for function in plugin.content_functions:
if function.name == name:
return function, plugin
return None, None
async def get_all_functions(self) -> list[entities.LLMFunction]:
"""获取所有函数
"""
@@ -68,7 +78,7 @@ class ToolManager:
try:
function = await self.get_function(name)
function, plugin = await self.get_function_and_plugin(name)
if function is None:
return None
@@ -79,7 +89,7 @@ class ToolManager:
**parameters
}
return await function.func(**parameters)
return await function.func(plugin, **parameters)
except Exception as e:
self.ap.logger.error(f'执行函数 {name} 时发生错误: {e}')
traceback.print_exc()

Some files were not shown because too many files have changed in this diff Show More