Compare commits

...

118 Commits
v2 ... v3.0.2

Author SHA1 Message Date
RockChinQ
ce881372ee chore: release v3.0.2 2024-03-02 21:03:04 +08:00
Junyan Qin
171ea7c375 Merge pull request #708 from RockChinQ/fix/llonebot-not-supported
Fix: 修复使用llonebot时的协议问题
2024-03-02 20:59:41 +08:00
RockChinQ
1e9a6f813f fix: 修复使用llonebot时的协议问题 2024-03-02 20:58:58 +08:00
Junyan Qin
39a7f3b2b9 Merge pull request #707 from RockChinQ/feat/booting-stages
Feat: 分阶段启动
2024-03-02 20:27:51 +08:00
RockChinQ
8d375a02db fix: 未导入问题 2024-03-02 20:05:23 +08:00
RockChinQ
cac8a0a414 perf: 优化导入 2024-03-02 16:39:29 +08:00
RockChinQ
c89623967e refactor: 应用初始化流程初步分阶段 2024-03-02 16:37:30 +08:00
RockChinQ
92aa9c1711 perf: 配置文件生成步骤移动到main.py 2024-03-02 14:57:55 +08:00
Junyan Qin
71f2a58acb feat: 依赖检查移动到main.py 2024-02-29 11:10:30 +00:00
RockChinQ
1f07a8a9e3 refactor: 移动pool到pipeline包 2024-02-29 03:38:38 +00:00
RockChinQ
cacd21bde7 refactor: 移动控制器到pipeline包 2024-02-29 03:38:38 +00:00
RockChinQ
a060ec66c3 deps: 整理依赖 2024-02-29 11:03:11 +08:00
Junyan Qin
fd10db3c75 ci: fix 2024-02-21 13:56:38 +00:00
Junyan Qin
db4c658980 chore: test 2024-02-21 13:52:54 +00:00
Junyan Qin
0ee88674f8 ci: update 2024-02-21 13:52:33 +00:00
Junyan Qin
3540759682 chore: release v3.0.1.1 2024-02-21 13:46:38 +00:00
Junyan Qin
44cc8f15b4 Merge pull request #695 from RockChinQ/ci/arm-image
CI: 构建arm64镜像
2024-02-21 21:45:40 +08:00
Junyan Qin
59f821bf0a ci: 构建arm64镜像 2024-02-21 13:44:07 +00:00
RockChinQ
80858672b0 perf: 控制台输出请求响应过程 2024-02-20 22:56:42 +08:00
RockChinQ
3258d5b255 chore: aiocqhttp默认监听地址改为0.0.0.0 2024-02-20 20:13:46 +08:00
RockChinQ
e8c8cc0a9c chore: release v3.0.1 2024-02-20 11:48:26 +08:00
Junyan Qin
570c19f29f Merge pull request #693 from RockChinQ/fix/3.9-compability
Fix: 针对python3.9的兼容性
2024-02-20 11:47:49 +08:00
RockChinQ
ee93fd8636 hotfix: 针对python3.9的兼容性 2024-02-20 11:47:04 +08:00
RockChinQ
1e6c32ffc7 fix: 'VersionManager' object has no attribute 'get_release_list' 2024-02-20 09:54:02 +08:00
RockChinQ
3ef2fb958c chore: release v3.0.0 2024-02-19 22:04:41 +08:00
RockChinQ
97edfe7cd7 doc: 整理README 2024-02-19 22:03:27 +08:00
Junyan Qin
1bdc96f8b2 Merge pull request #669 from RockChinQ/feat/asyncio
Refactor: 异步架构
2024-02-19 21:59:41 +08:00
RockChinQ
4ef285aee9 chore: 删除无用文件 2024-02-19 21:51:56 +08:00
RockChinQ
6ccee3b7cf chore: 删除 README_en.md 2024-02-19 21:48:52 +08:00
RockChinQ
082731ba32 fix: !version 命令获取最新版本失败时导致命令失败 2024-02-19 21:47:51 +08:00
RockChinQ
0bf85fb644 fix: msg_source无法通过审计接口发给center 2024-02-19 21:41:53 +08:00
RockChinQ
5ce1759dd9 fix: 启动后未进行对话时,!list会 越界异常 2024-02-19 21:40:34 +08:00
RockChinQ
1e016dfa24 ci: 修改工作流文件名 2024-02-19 20:37:40 +08:00
RockChinQ
7b3bb53f06 ci: 更换基础镜像 2024-02-19 20:36:26 +08:00
RockChinQ
53d0059848 perf: 不再需要exit来退出程序 2024-02-19 19:27:42 +08:00
RockChinQ
9a85178a29 deps: 重新添加nakuru 2024-02-19 19:17:18 +08:00
RockChinQ
d74681a128 deps: 删除无用依赖 2024-02-19 18:59:52 +08:00
RockChinQ
06c8773975 perf: 优化控制台输出 2024-02-16 14:11:22 +08:00
RockChinQ
ae358dd6d0 fix: 昨天错误的shutdown_trigger逻辑 2024-02-16 13:08:26 +08:00
RockChinQ
7174cbf41f feat: 支持 ctrl+c 退出 2024-02-15 22:21:56 +08:00
RockChinQ
f73d69e814 perf: 添加未启用适配器时的警告 2024-02-15 16:12:42 +08:00
Junyan Qin
8af174127d Merge pull request #685 from RockChinQ/feat/run-multi-adapter
Feat: 支持同时运行多个适配器
2024-02-12 13:38:56 +08:00
RockChinQ
991a0aa5f6 fix: 修复nakuru无法运行的问题 2024-02-12 13:37:41 +08:00
RockChinQ
abc19e78b8 feat: 命令行退出方式 2024-02-11 23:35:05 +08:00
RockChinQ
836df87e18 feat: 删除过时配置 2024-02-11 23:11:13 +08:00
RockChinQ
9cad94e961 feat: 支持同时运行多个平台适配器 2024-02-11 23:07:38 +08:00
Junyan Qin
b9568eb558 doc(README.md): 更新社区群群号 2024-02-11 09:47:21 +08:00
RockChinQ
f951625025 chore: 修改推荐的docker-compose.yaml配置 2024-02-08 13:45:26 +08:00
RockChinQ
c2b3b53c12 chore: 修改启动相关 2024-02-08 13:40:25 +08:00
RockChinQ
d95e18c202 chore: 整理代码 2024-02-08 13:37:27 +08:00
Junyan Qin
e705e707e5 Merge pull request #680 from RockChinQ/feat/nakuru
Feat: 恢复nakuru使用
2024-02-08 13:14:53 +08:00
RockChinQ
2fa5d7608f chore: 删除无效代码 2024-02-08 13:13:35 +08:00
RockChinQ
f9a3e99795 feat: 恢复nakuru使用 2024-02-08 13:12:33 +08:00
RockChinQ
d86ad25f86 feat: 正向代理支持 2024-02-07 23:58:22 +08:00
Junyan Qin
cf583486e3 Merge pull request #679 from RockChinQ/feat/botpy-qq
Feat: 接入 QQ 官方 API
2024-02-07 23:29:56 +08:00
RockChinQ
7366ca59c7 chore: 忽略botpy.log 2024-02-07 23:27:10 +08:00
RockChinQ
12820e6c64 feat: 支持qq-botpy 2024-02-07 23:21:32 +08:00
Junyan Qin
71b54fd684 Merge pull request #678 from RockChinQ/feat/aiocqhttp
Feat: 适配aiocqhttp
2024-02-07 20:23:43 +08:00
RockChinQ
aeb1912db6 feat: 适配aiocqhttp 2024-02-07 20:03:46 +08:00
Junyan Qin
84b2867148 Merge pull request #677 from RockChinQ/refactor/asyncio/config
Refactor: 配置文件重构
2024-02-07 00:09:23 +08:00
RockChinQ
5880dacad8 ci: 修改dockerfile 2024-02-07 00:07:55 +08:00
RockChinQ
b5b67ad958 refactor: 恢复命令权限设置 2024-02-06 23:57:21 +08:00
RockChinQ
2a913ed24c chore: 删除过时文件 2024-02-06 21:29:31 +08:00
RockChinQ
aab56294ba chore: 删除字体文件 2024-02-06 21:28:24 +08:00
RockChinQ
26912ef976 chore: 删除多余文件 2024-02-06 21:28:01 +08:00
RockChinQ
c1fed3410b chore: 删除过时的配置文件 2024-02-06 21:27:14 +08:00
RockChinQ
c853bba4ba refactor: 配置文件均改为json 2024-02-06 21:26:03 +08:00
RockChinQ
f340a44abf feat: 恢复ratelimit 2024-02-01 18:38:20 +08:00
RockChinQ
0dec10ddf2 chore: 删除tests目录 2024-02-01 18:38:04 +08:00
RockChinQ
7026abe56a perf: 完善openai异常处理 2024-02-01 18:11:47 +08:00
RockChinQ
a9d92115f8 feat: chat前的前文剪裁逻辑 2024-02-01 17:42:51 +08:00
RockChinQ
6f2d7d96d0 perf: 完善历史消息处理逻辑 2024-02-01 16:43:44 +08:00
RockChinQ
532a713355 refactor: 独立出预处理阶段 2024-02-01 16:35:00 +08:00
RockChinQ
976a9de39c refactor: 分隔LLM请求过程和消息封装过程 2024-02-01 15:48:26 +08:00
RockChinQ
32162afa65 refactor: 恢复所有审计API调用 2024-01-31 00:02:19 +08:00
RockChinQ
c1c751a9ab feat: 更新操作 2024-01-30 22:50:52 +08:00
RockChinQ
b749ba587d feat: 恢复强制消息延迟 2024-01-30 21:56:25 +08:00
GitHub Actions
b2741686fd Update override-all.json 2024-01-30 13:45:50 +00:00
RockChinQ
94bf7739a0 chore: 默认回复函数响应 2024-01-30 21:45:31 +08:00
RockChinQ
33d600fb6b refactor: 恢复插件事件调用 2024-01-30 21:45:17 +08:00
RockChinQ
e2de3d0102 feat: 删除部分插件事件 2024-01-30 17:47:03 +08:00
RockChinQ
6b76adc00e feat: 添加事件对象 2024-01-30 17:24:22 +08:00
RockChinQ
61f4cb2f65 perf: 完善模型信息 2024-01-30 16:58:11 +08:00
RockChinQ
28bd232dda feat: 添加更多LLM模型 2024-01-30 16:29:54 +08:00
RockChinQ
e9e458c877 feat: 公告和更新检查 2024-01-30 16:13:33 +08:00
RockChinQ
437971ded8 feat: 应用层异常处理 2024-01-30 14:58:34 +08:00
RockChinQ
3945ac95d1 refactor: 审计api改为异步 2024-01-29 21:58:47 +08:00
RockChinQ
13ab647dc0 perf: 完善插件加载流程 2024-01-29 21:41:20 +08:00
RockChinQ
c75b0ce8fb perf: 优化代码声明 2024-01-29 21:31:11 +08:00
RockChinQ
6cc4688660 refactor: 重构插件系统 2024-01-29 21:22:27 +08:00
RockChinQ
b730f17eb6 chore: 修改包名 2024-01-28 19:20:10 +08:00
RockChinQ
698782c537 chore: 整理文件 2024-01-28 18:45:18 +08:00
Junyan Qin
2b0faea8ec Merge pull request #673 from RockChinQ/refactor/asyncio/control-flow
Refactor: 请求处理控制流
2024-01-28 18:41:59 +08:00
RockChinQ
d130c376f4 chore: 删除命令权限同步脚本 2024-01-28 18:40:10 +08:00
RockChinQ
238c55a40e chore: 删除已弃用的文件 2024-01-28 18:38:47 +08:00
RockChinQ
b5924bb34f refactor: 添加更新命令 2024-01-28 18:27:48 +08:00
RockChinQ
1368ee22b2 refactor: 命令基本完成 2024-01-28 18:21:43 +08:00
RockChinQ
2a0cf57303 refactor: 命令处理基础 2024-01-28 00:16:42 +08:00
RockChinQ
f10af09bd2 refactor: AI对话基本完成 2024-01-27 21:50:40 +08:00
RockChinQ
850a4eeb7c refactor: 重构openai包基础组件架构 2024-01-27 00:06:38 +08:00
RockChinQ
411034902a feat: 启动时展示asciiart 2024-01-27 00:05:55 +08:00
RockChinQ
1900ddacbb chore: 删除 qqbot 包中的流程代码 2024-01-26 15:54:24 +08:00
RockChinQ
8d084427d2 refactor: 请求处理控制流基础架构 2024-01-26 15:51:49 +08:00
Junyan Qin
a064c24f60 Merge pull request #670 from RockChinQ/refactor/asyncio/simplify-qqbot-mgr
Refactor: 简化和调整qqbot包架构
2024-01-25 22:39:25 +08:00
RockChinQ
b43882aad0 refactor: 独立ratelimiter包 2024-01-25 22:35:15 +08:00
RockChinQ
f4ead5ec5c refactor: 独立resprule为单独的包 2024-01-25 18:07:28 +08:00
RockChinQ
ea9ae85428 refactor: 独立长消息处理为longtext包 2024-01-25 17:05:09 +08:00
RockChinQ
a9a798b19d refactor: filter和ignore独立成新的cntfilter包 2024-01-25 15:28:23 +08:00
RockChinQ
f4ae9df3bf refactor: 重构会话封禁功能处理逻辑 2024-01-24 23:38:13 +08:00
RockChinQ
f3bcff1261 chore: banlist模版移至根目录 2024-01-24 23:33:48 +08:00
RockChinQ
b4bd86549e chore: banlist模版移至根目录 2024-01-24 23:33:19 +08:00
RockChinQ
a975718a64 refactor: 暂时删除对热重载的支持 2024-01-24 22:29:19 +08:00
RockChinQ
3d06a18bcb refactor: 简化私聊群聊共同处理代码 2024-01-24 17:00:56 +08:00
RockChinQ
a236089785 refactor: 独立resprule模块 2024-01-24 16:11:56 +08:00
RockChinQ
2f877965cf chore: 删除部分注释代码 2024-01-23 23:27:55 +08:00
RockChinQ
ad5ef95e65 refactor: yirimirai 适配器实现异步 2024-01-23 22:28:30 +08:00
RockChinQ
8d35ecd711 refactor: 基本启动流程 2024-01-23 20:55:20 +08:00
RockChinQ
e63c6ac723 feat: 删除main.py中init_db函数 2024-01-23 15:42:23 +08:00
254 changed files with 7476 additions and 8799 deletions

View File

@@ -11,5 +11,4 @@ updates:
interval: "weekly"
allow:
- dependency-name: "yiri-mirai-rc"
- dependency-name: "dulwich"
- dependency-name: "openai"

View File

@@ -0,0 +1,48 @@
name: Build Docker Image
on:
#防止fork乱用action设置只能手动触发构建
workflow_dispatch:
## 发布release的时候会自动构建
release:
types: [published]
jobs:
publish-docker-image:
runs-on: ubuntu-latest
name: Build image
steps:
- name: Checkout
uses: actions/checkout@v2
- name: judge has env GITHUB_REF # 如果没有GITHUB_REF环境变量则把github.ref变量赋值给GITHUB_REF
run: |
if [ -z "$GITHUB_REF" ]; then
export GITHUB_REF=${{ github.ref }}
echo $GITHUB_REF
fi
# - name: Check GITHUB_REF env
# run: echo $GITHUB_REF
# - name: Get version # 在 GitHub Actions 运行环境
# id: get_version
# if: (startsWith(env.GITHUB_REF, 'refs/tags/')||startsWith(github.ref, 'refs/tags/')) && startsWith(github.repository, 'RockChinQ/QChatGPT')
# run: export GITHUB_REF=${GITHUB_REF/refs\/tags\//}
- name: Check version
id: check_version
run: |
echo $GITHUB_REF
# 如果是tag则去掉refs/tags/前缀
if [[ $GITHUB_REF == refs/tags/* ]]; then
echo "It's a tag"
echo $GITHUB_REF
echo $GITHUB_REF | awk -F '/' '{print $3}'
echo ::set-output name=version::$(echo $GITHUB_REF | awk -F '/' '{print $3}')
else
echo "It's not a tag"
echo $GITHUB_REF
echo ::set-output name=version::${GITHUB_REF}
fi
- name: Login to Registry
run: docker login --username=${{ secrets.DOCKER_USERNAME }} --password ${{ secrets.DOCKER_PASSWORD }}
- name: Create Buildx
run: docker buildx create --name mybuilder --use
- name: Build # image name: rockchin/qchatgpt:<VERSION>
run: docker buildx build --platform linux/arm64,linux/amd64 -t rockchin/qchatgpt:${{ steps.check_version.outputs.version }} -t rockchin/qchatgpt:latest . --push

View File

@@ -1,38 +0,0 @@
name: Build Docker Image
on:
#防止fork乱用action设置只能手动触发构建
workflow_dispatch:
## 发布release的时候会自动构建
release:
types: [published]
jobs:
publish-docker-image:
runs-on: ubuntu-latest
name: Build image
steps:
- name: Checkout
uses: actions/checkout@v2
- name: judge has env GITHUB_REF # 如果没有GITHUB_REF环境变量则把github.ref变量赋值给GITHUB_REF
run: |
if [ -z "$GITHUB_REF" ]; then
export GITHUB_REF=${{ github.ref }}
fi
- name: Check GITHUB_REF env
run: echo $GITHUB_REF
- name: Get version
id: get_version
if: (startsWith(env.GITHUB_REF, 'refs/tags/')||startsWith(github.ref, 'refs/tags/')) && startsWith(github.repository, 'RockChinQ/QChatGPT')
run: echo ::set-output name=VERSION::${GITHUB_REF/refs\/tags\//}
- name: Build # image name: rockchin/qchatgpt:<VERSION>
run: docker build --network=host -t rockchin/qchatgpt:${{ steps.get_version.outputs.VERSION }} -t rockchin/qchatgpt:latest .
- name: Login to Registry
run: docker login --username=${{ secrets.DOCKER_USERNAME }} --password ${{ secrets.DOCKER_PASSWORD }}
- name: Push image
if: (startsWith(env.GITHUB_REF, 'refs/tags/')||startsWith(github.ref, 'refs/tags/')) && startsWith(github.repository, 'RockChinQ/QChatGPT')
run: docker push rockchin/qchatgpt:${{ steps.get_version.outputs.VERSION }}
- name: Push latest image
if: (startsWith(env.GITHUB_REF, 'refs/tags/')||startsWith(github.ref, 'refs/tags/')) && startsWith(github.repository, 'RockChinQ/QChatGPT')
run: docker push rockchin/qchatgpt:latest

View File

@@ -1,58 +0,0 @@
name: Update cmdpriv-template
on:
push:
paths:
- 'pkg/qqbot/cmds/**'
pull_request:
types: [closed]
paths:
- 'pkg/qqbot/cmds/**'
jobs:
update-cmdpriv-template:
if: github.event.pull_request.merged == true || github.event_name == 'push'
runs-on: ubuntu-latest
steps:
- name: Checkout repository
uses: actions/checkout@v2
- name: Set up Python
uses: actions/setup-python@v2
with:
python-version: 3.10.13
- name: Install dependencies
run: |
python -m pip install --upgrade yiri-mirai-rc openai>=1.0.0 colorlog func_timeout dulwich Pillow CallingGPT tiktoken
python -m pip install -U openai>=1.0.0
- name: Copy Scripts
run: |
cp res/scripts/generate_cmdpriv_template.py .
- name: Generate Files
run: |
python main.py
- name: Run generate_cmdpriv_template.py
run: python3 generate_cmdpriv_template.py
- name: Check for changes in cmdpriv-template.json
id: check_changes
run: |
if git diff --name-only | grep -q "res/templates/cmdpriv-template.json"; then
echo "::set-output name=changes_detected::true"
else
echo "::set-output name=changes_detected::false"
fi
- name: Commit changes to cmdpriv-template.json
if: steps.check_changes.outputs.changes_detected == 'true'
run: |
git config --global user.name "GitHub Actions Bot"
git config --global user.email "<github-actions@github.com>"
git add res/templates/cmdpriv-template.json
git commit -m "Update cmdpriv-template.json"
git push

View File

@@ -1,52 +0,0 @@
name: Check and Update override_all
on:
push:
paths:
- 'config-template.py'
pull_request:
types:
- closed
branches:
- master
paths:
- 'config-template.py'
jobs:
update-override-all:
name: check and update
if: github.event.pull_request.merged == true || github.event_name == 'push'
runs-on: ubuntu-latest
steps:
- name: Checkout repository
uses: actions/checkout@v2
- name: Set up Python
uses: actions/setup-python@v2
with:
python-version: 3.x
- name: Install dependencies
run: |
python -m pip install --upgrade pip
- name: Copy Scripts
run: |
cp res/scripts/generate_override_all.py .
- name: Run generate_override_all.py
run: python3 generate_override_all.py
- name: Check for changes in override-all.json
id: check_changes
run: |
git diff --exit-code override-all.json || echo "::set-output name=changes_detected::true"
- name: Commit and push changes
if: steps.check_changes.outputs.changes_detected == 'true'
run: |
git config --global user.email "github-actions[bot]@users.noreply.github.com"
git config --global user.name "GitHub Actions"
git add override-all.json
git commit -m "Update override-all.json"
git push

2
.gitignore vendored
View File

@@ -33,3 +33,5 @@ bard.json
!/docker-compose.yaml
res/instance_id.json
.DS_Store
/data
botpy.log

View File

@@ -1,15 +1,10 @@
FROM python:3.10.13-bullseye
WORKDIR /QChatGPT
FROM python:3.10.13-slim
WORKDIR /app
COPY . /QChatGPT/
RUN ls
RUN python -m pip install -r requirements.txt && \
python -m pip install -U websockets==10.0 && \
python -m pip install -U httpcore httpx openai
# 生成配置文件
RUN python main.py
COPY . .
RUN apt update \
&& apt install gcc -y \
&& python -m pip install -r requirements.txt
CMD [ "python", "main.py" ]

View File

@@ -7,8 +7,6 @@
# QChatGPT
<blockquote> 🥳 QChatGPT 一周年啦,感谢大家的支持!欢迎前往<a href="https://github.com/RockChinQ/QChatGPT/discussions/627">讨论</a>。</blockquote>
[![GitHub release (latest by date)](https://img.shields.io/github/v/release/RockChinQ/QChatGPT)](https://github.com/RockChinQ/QChatGPT/releases/latest)
<a href="https://hub.docker.com/repository/docker/rockchin/qchatgpt">
<img src="https://img.shields.io/docker/pulls/rockchin/qchatgpt?color=blue" alt="docker pull">
@@ -22,8 +20,8 @@
<a href="http://qm.qq.com/cgi-bin/qm/qr?_wv=1027&k=66-aWvn8cbP4c1ut_1YYkvvGVeEtyTH8&authKey=pTaKBK5C%2B8dFzQ4XlENf6MHTCLaHnlKcCRx7c14EeVVlpX2nRSaS8lJm8YeM4mCU&noverify=0&group_code=195992197">
<img alt="Static Badge" src="https://img.shields.io/badge/%E5%AE%98%E6%96%B9%E7%BE%A4-195992197-purple">
</a>
<a href="http://qm.qq.com/cgi-bin/qm/qr?_wv=1027&k=nC80H57wmKPwRDLFeQrDDjVl81XuC21P&authKey=2wTUTfoQ5v%2BD4C5zfpuR%2BSPMDqdXgDXA%2FS2wHI1NxTfWIG%2B%2FqK08dgyjMMOzhXa9&noverify=0&group_code=738382634">
<img alt="Static Badge" src="https://img.shields.io/badge/%E7%A4%BE%E5%8C%BA%E7%BE%A4-738382634-purple">
<a href="http://qm.qq.com/cgi-bin/qm/qr?_wv=1027&k=nC80H57wmKPwRDLFeQrDDjVl81XuC21P&authKey=2wTUTfoQ5v%2BD4C5zfpuR%2BSPMDqdXgDXA%2FS2wHI1NxTfWIG%2B%2FqK08dgyjMMOzhXa9&noverify=0&group_code=248432104">
<img alt="Static Badge" src="https://img.shields.io/badge/%E7%A4%BE%E5%8C%BA%E7%BE%A4-248432104-purple">
</a>
<a href="https://www.bilibili.com/video/BV14h4y1w7TC">
<img alt="Static Badge" src="https://img.shields.io/badge/%E8%A7%86%E9%A2%91%E6%95%99%E7%A8%8B-208647">

View File

@@ -1,215 +0,0 @@
# QChatGPT🤖
<p align="center">
<img src="res/social.png" alt="QChatGPT" width="640" />
</p>
English | [简体中文](README.md)
[![GitHub release (latest by date)](https://img.shields.io/github/v/release/RockChinQ/QChatGPT?style=flat-square)](https://github.com/RockChinQ/QChatGPT/releases/latest)
![Wakapi Count](https://wakapi.dev/api/badge/RockChinQ/interval:any/project:QChatGPT)
- Refer to [Wiki](https://github.com/RockChinQ/QChatGPT/wiki) to get further information.
- Official QQ group: 656285629
- Community QQ group: 362515018
- QQ channel robot: [QQChannelChatGPT](https://github.com/Soulter/QQChannelChatGPT)
- Any contribution is welcome, please refer to [CONTRIBUTING.md](CONTRIBUTING.md)
## 🍺List of supported models
<details>
<summary>Details</summary>
### Chat
- OpenAI GPT-3.5 (ChatGPT API), default model
- OpenAI GPT-3, supported natively, switch to it in `config.py`
- OpenAI GPT-4, supported natively, qualification for internal testing required, switch to it in `config.py`
- ChatGPT website edition (GPT-3.5), see [revLibs plugin](https://github.com/RockChinQ/revLibs)
- ChatGPT website edition (GPT-4), ChatGPT plus subscription required, see [revLibs plugin](https://github.com/RockChinQ/revLibs)
- New Bing, see [revLibs plugin](https://github.com/RockChinQ/revLibs)
- HuggingChat, see [revLibs plugin](https://github.com/RockChinQ/revLibs), English only
### Story
- NovelAI API, see [QCPNovelAi plugin](https://github.com/dominoar/QCPNovelAi)
### Image
- OpenAI DALL·E, supported natively, see [Wiki(cn)](https://github.com/RockChinQ/QChatGPT/wiki/%E5%8A%9F%E8%83%BD%E4%BD%BF%E7%94%A8#%E5%8A%9F%E8%83%BD%E7%82%B9%E5%88%97%E4%B8%BE)
- NovelAI API, see [QCPNovelAi plugin](https://github.com/dominoar/QCPNovelAi)
### Voice
- TTS+VITS, see [QChatPlugins](https://github.com/dominoar/QChatPlugins)
- Plachta/VITS-Umamusume-voice-synthesizer, see [chat_voice plugin](https://github.com/oliverkirk-sudo/chat_voice)
</details>
Install this [plugin](https://github.com/RockChinQ/Switcher) to switch between different models.
## ✅Features
<details>
<summary>Details</summary>
- ✅Sensitive word filtering, avoid being banned
- ✅Multiple responding rules, including regular expression matching
- ✅Multiple api-key management, automatic switching when exceeding
- ✅Support for customizing the preset prompt text
- ✅Chat, story, image, voice, etc. models are supported
- ✅Support for hot reloading and hot updating
- ✅Support for plugin loading
- ✅Blacklist mechanism for private chat and group chat
- ✅Excellent long message processing strategy
- ✅Reply rate limitation
- ✅Support for network proxy
- ✅Support for customizing the output format
</details>
More details, see [Wiki(cn)](https://github.com/RockChinQ/QChatGPT/wiki/%E5%8A%9F%E8%83%BD%E4%BD%BF%E7%94%A8#%E5%8A%9F%E8%83%BD%E7%82%B9%E5%88%97%E4%B8%BE)
## 🔩Deployment
**If you encounter any problems during deployment, please search in the issue of [QChatGPT](https://github.com/RockChinQ/QChatGPT/issues) or [qcg-installer](https://github.com/RockChinQ/qcg-installer/issues) first.**
### - Register OpenAI account
> If you want to use a model other than OpenAI (such as New Bing), you can skip this step and directly refer to following steps, and then configure it according to the relevant plugin documentation.
To register OpenAI account, please refer to the following articles(in Chinese):
> [国内注册ChatGPT的方法(100%可用)](https://www.pythonthree.com/register-openai-chatgpt/)
> [手把手教你如何注册ChatGPT超级详细](https://guxiaobei.com/51461)
Check your api-key in [personal center](https://beta.openai.com/account/api-keys) after registration, and then follow the following steps to deploy.
### - Deploy Automatically
<details>
<summary>Details</summary>
#### Docker
See [this document(cn)](res/docs/docker_deploy.md)
Contributed by [@mikumifa](https://github.com/mikumifa)
#### Installer
Use [this installer](https://github.com/RockChinQ/qcg-installer) to deploy.
- The installer currently only supports some platforms, please refer to the repository document for details, and manually deploy for other platforms
</details>
### - Deploy Manually
<details>
<summary>Manually deployment supports any platforms</summary>
- Python 3.9.x or higher
#### 配置QQ登录框架
Currently supports mirai and go-cqhttp, configure either one
<details>
<summary>mirai</summary>
Follow [this tutorial(cn)](https://yiri-mirai.wybxc.cc/tutorials/01/configuration) to configure Mirai and YiriMirai.
After starting mirai-console, use the `login` command to log in to the QQ account, and keep the mirai-console running.
</details>
<details>
<summary>go-cqhttp</summary>
1. Follow [this tutorial(cn)](https://github.com/RockChinQ/QChatGPT/wiki/go-cqhttp%E9%85%8D%E7%BD%AE) to configure go-cqhttp.
2. Start go-cqhttp, make sure it is logged in and running.
</details>
#### Configure QChatGPT
1. Clone the repository
```bash
git clone https://github.com/RockChinQ/QChatGPT
cd QChatGPT
```
2. Install dependencies
```bash
pip3 install requests yiri-mirai-rc openai colorlog func_timeout dulwich Pillow nakuru-project-idk
```
3. Generate `config.py`
```bash
python3 main.py
```
4. Edit `config.py`
5. Run
```bash
python3 main.py
```
Any problems, please refer to the issues page.
</details>
## 🚀Usage
**After deployment, please read: [Commands(cn)](https://github.com/RockChinQ/QChatGPT/wiki/%E5%8A%9F%E8%83%BD%E4%BD%BF%E7%94%A8#%E6%9C%BA%E5%99%A8%E4%BA%BA%E6%8C%87%E4%BB%A4)**
**For more details, please refer to the [Wiki(cn)](https://github.com/RockChinQ/QChatGPT/wiki/%E5%8A%9F%E8%83%BD%E4%BD%BF%E7%94%A8#%E4%BD%BF%E7%94%A8%E6%96%B9%E5%BC%8F)**
## 🧩Plugin Ecosystem
Plugin [usage](https://github.com/RockChinQ/QChatGPT/wiki/%E6%8F%92%E4%BB%B6%E4%BD%BF%E7%94%A8) and [development](https://github.com/RockChinQ/QChatGPT/wiki/%E6%8F%92%E4%BB%B6%E5%BC%80%E5%8F%91) are supported.
<details>
<summary>List of plugins (cn)</summary>
### Examples
`tests/plugin_examples`目录下,将其整个目录复制到`plugins`目录下即可使用
- `cmdcn` - 主程序命令中文形式
- `hello_plugin` - 在收到消息`hello`时回复相应消息
- `urlikethisijustsix` - 收到冒犯性消息时回复相应消息
### More Plugins
欢迎提交新的插件
- [revLibs](https://github.com/RockChinQ/revLibs) - 将ChatGPT网页版接入此项目关于[官方接口和网页版有什么区别](https://github.com/RockChinQ/QChatGPT/wiki/%E5%AE%98%E6%96%B9%E6%8E%A5%E5%8F%A3%E4%B8%8EChatGPT%E7%BD%91%E9%A1%B5%E7%89%88)
- [Switcher](https://github.com/RockChinQ/Switcher) - 支持通过命令切换使用的模型
- [hello_plugin](https://github.com/RockChinQ/hello_plugin) - `hello_plugin` 的储存库形式,插件开发模板
- [dominoar/QChatPlugins](https://github.com/dominoar/QchatPlugins) - dominoar编写的诸多新功能插件语音输出、Ranimg、屏蔽词规则等
- [dominoar/QCP-NovelAi](https://github.com/dominoar/QCP-NovelAi) - NovelAI 故事叙述与绘画
- [oliverkirk-sudo/chat_voice](https://github.com/oliverkirk-sudo/chat_voice) - 文字转语音输出使用HuggingFace上的[VITS-Umamusume-voice-synthesizer模型](https://huggingface.co/spaces/Plachta/VITS-Umamusume-voice-synthesizer)
- [RockChinQ/WaitYiYan](https://github.com/RockChinQ/WaitYiYan) - 实时获取百度`文心一言`等待列表人数
- [chordfish-k/QChartGPT_Emoticon_Plugin](https://github.com/chordfish-k/QChartGPT_Emoticon_Plugin) - 使机器人根据回复内容发送表情包
- [oliverkirk-sudo/ChatPoeBot](https://github.com/oliverkirk-sudo/ChatPoeBot) - 接入[Poe](https://poe.com/)上的机器人
- [lieyanqzu/WeatherPlugin](https://github.com/lieyanqzu/WeatherPlugin) - 天气查询插件
</details>
## 😘Thanks
- [@the-lazy-me](https://github.com/the-lazy-me) video tutorial creator
- [@mikumifa](https://github.com/mikumifa) Docker deployment
- [@dominoar](https://github.com/dominoar) Plugin development
- [@万神的星空](https://github.com/qq255204159) Packages publisher
- [@ljcduo](https://github.com/ljcduo) GPT-4 API internal test account
And all [contributors](https://github.com/RockChinQ/QChatGPT/graphs/contributors) and other friends who support this project.
<!-- ## 👍赞赏
<img alt="赞赏码" src="res/mm_reward_qrcode_1672840549070.png" width="400" height="400"/> -->

View File

@@ -1,370 +0,0 @@
# 配置文件: 注释里标[必需]的参数必须修改, 其他参数根据需要修改, 但请勿删除
import logging
# 消息处理协议适配器
# 目前支持以下适配器:
# - "yirimirai": mirai的通信框架YiriMirai框架适配器, 请同时填写下方mirai_http_api_config
# - "nakuru": go-cqhttp通信框架请同时填写下方nakuru_config
msg_source_adapter = "yirimirai"
# [必需(与nakuru二选一取决于msg_source_adapter)] Mirai的配置
# 请到配置mirai的步骤中的教程查看每个字段的信息
# adapter: 选择适配器目前支持HTTPAdapter和WebSocketAdapter
# host: 运行mirai的主机地址
# port: 运行mirai的主机端口
# verifyKey: mirai-api-http的verifyKey
# qq: 机器人的QQ号
#
# 注意: QQ机器人配置不支持热重载及热更新
mirai_http_api_config = {
"adapter": "WebSocketAdapter",
"host": "localhost",
"port": 8080,
"verifyKey": "yirimirai",
"qq": 1234567890
}
# [必需(与mirai二选一取决于msg_source_adapter)]
# 使用nakuru-project框架连接go-cqhttp的配置
nakuru_config = {
"host": "localhost", # go-cqhttp的地址
"port": 6700, # go-cqhttp的正向websocket端口
"http_port": 5700, # go-cqhttp的正向http端口
"token": "" # 若在go-cqhttp的config.yml设置了access_token, 则填写此处
}
# [必需] OpenAI的配置
# api_key: OpenAI的API Key
# http_proxy: 请求OpenAI时使用的代理None为不使用https和socks5暂不能使用
# 若只有一个api-key请直接修改以下内容中的"openai_api_key"为你的api-key
#
# 如准备了多个api-key可以以字典的形式填写程序会自动选择可用的api-key
# 例如
# openai_config = {
# "api_key": {
# "default": "sk-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx",
# "key1": "sk-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx",
# "key2": "sk-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx",
# },
# "http_proxy": "http://127.0.0.1:12345"
# }
#
# 现已支持反向代理可以添加reverse_proxy字段以使用反向代理
# 使用反向代理可以在国内使用OpenAI的API反向代理的配置请参考
# https://github.com/Ice-Hazymoon/openai-scf-proxy
#
# 反向代理填写示例:
# openai_config = {
# "api_key": {
# "default": "sk-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx",
# "key1": "sk-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx",
# "key2": "sk-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx",
# },
# "reverse_proxy": "http://example.com:12345/v1"
# }
#
# 作者开设公用反向代理地址: https://api.openai.rockchin.top/v1
# 随时可能关闭,仅供测试使用,有条件建议使用正向代理或者自建反向代理
openai_config = {
"api_key": {
"default": "openai_api_key"
},
"http_proxy": None,
"reverse_proxy": None
}
# api-key切换策略
# active每次请求时都会切换api-key
# passive仅当api-key超额时才会切换api-key
switch_strategy = "active"
# [必需] 管理员QQ号用于接收报错等通知及执行管理员级别命令
# 支持多个管理员可以使用list形式设置例如
# admin_qq = [12345678, 87654321]
admin_qq = 0
# 情景预设(机器人人格)
# 每个会话的预设信息,影响所有会话,无视命令重置
# 可以通过这个字段指定某些情况的回复,可直接用自然语言描述指令
# 例如:
# default_prompt = "如果我之后想获取帮助,请你说“输入!help获取帮助”"
# 这样用户在不知所措的时候机器人就会提示其输入!help获取帮助
# 可参考 https://github.com/PlexPt/awesome-chatgpt-prompts-zh
#
# 如果需要多个情景预设,并在运行期间方便切换,请使用字典的形式填写,例如
# default_prompt = {
# "default": "如果我之后想获取帮助,请你说“输入!help获取帮助”",
# "linux-terminal": "我想让你充当 Linux 终端。我将输入命令,您将回复终端应显示的内容。",
# "en-dict": "我想让你充当英英词典,对于给出的英文单词,你要给出其中文意思以及英文解释,并且给出一个例句,此外不要有其他反馈。",
# }
#
# 在使用期间即可通过命令:
# !reset [名称]
# 来使用指定的情景预设重置会话
# 例如:
# !reset linux-terminal
# 若不指定名称,则使用默认情景预设
#
# 也可以使用命令:
# !default <名称>
# 将指定的情景预设设置为默认情景预设
# 例如:
# !default linux-terminal
# 之后的会话重置时若不指定名称则使用linux-terminal情景预设
#
# 还可以加载文件中的预设文字使用方法请查看https://github.com/RockChinQ/QChatGPT/wiki/%E5%8A%9F%E8%83%BD%E4%BD%BF%E7%94%A8#%E9%A2%84%E8%AE%BE%E6%96%87%E5%AD%97
default_prompt = {
"default": "如果用户之后想获取帮助,请你说“输入!help获取帮助”。",
}
# 情景预设格式
# 参考值默认方式normal | 完整情景full_scenario
# 默认方式 的格式为上述default_prompt中的内容或prompts目录下的文件名
# 完整情景方式 的格式为JSON在scenario目录下的JSON文件中列出对话的每个回合编写方法见scenario/default-template.json
# 编写方法请查看https://github.com/RockChinQ/QChatGPT/wiki/%E5%8A%9F%E8%83%BD%E4%BD%BF%E7%94%A8#%E9%A2%84%E8%AE%BE%E6%96%87%E5%AD%97full_scenario%E6%A8%A1%E5%BC%8F
preset_mode = "normal"
# 群内响应规则
# 符合此消息的群内消息即使不包含at机器人也会响应
# 支持消息前缀匹配及正则表达式匹配
# 支持设置是否响应at消息、随机响应概率
# 注意:由消息前缀(prefix)匹配的消息中将会删除此前缀,正则表达式(regexp)匹配的消息不会删除匹配的部分
# 前缀匹配优先级高于正则表达式匹配
# 正则表达式简明教程https://www.runoob.com/regexp/regexp-tutorial.html
#
# 支持针对不同群设置不同的响应规则,例如:
# response_rules = {
# "default": {
# "at": True,
# "prefix": ["/ai", "!ai", "ai", "ai"],
# "regexp": [],
# "random_rate": 0.0,
# },
# "12345678": {
# "at": False,
# "prefix": ["/ai", "!ai", "ai", "ai"],
# "regexp": [],
# "random_rate": 0.0,
# },
# }
#
# 以上设置将会在群号为12345678的群中关闭at响应
# 未单独设置的群将使用default规则
response_rules = {
"default": {
"at": True, # 是否响应at机器人的消息
"prefix": ["/ai", "!ai", "ai", "ai"],
"regexp": [], # "为什么.*", "怎么?样.*", "怎么.*", "如何.*", "[Hh]ow to.*", "[Ww]hy not.*", "[Ww]hat is.*", ".*怎么办", ".*咋办"
"random_rate": 0.0, # 随机响应概率0.0-1.00.0为不随机响应1.0为响应所有消息, 仅在前几项判断不通过时生效
},
}
# 消息忽略规则
# 适用于私聊及群聊
# 符合此规则的消息将不会被响应
# 支持消息前缀匹配及正则表达式匹配
# 此设置优先级高于response_rules
# 用以过滤mirai等其他层级的命令
# @see https://github.com/RockChinQ/QChatGPT/issues/165
ignore_rules = {
"prefix": ["/"],
"regexp": []
}
# 是否检查收到的消息中是否包含敏感词
# 若收到的消息无法通过下方指定的敏感词检查策略,则发送提示信息
income_msg_check = False
# 敏感词过滤开关,以同样数量的*代替敏感词回复
# 请在sensitive.json中添加敏感词
sensitive_word_filter = True
# 是否启用百度云内容安全审核
# 注册方式查看 https://cloud.baidu.com/doc/ANTIPORN/s/Wkhu9d5iy
baidu_check = False
# 百度云API_KEY 24位英文数字字符串
baidu_api_key = ""
# 百度云SECRET_KEY 32位的英文数字字符串
baidu_secret_key = ""
# 不合规消息自定义返回
inappropriate_message_tips = "[百度云]请珍惜机器人,当前返回内容不合规"
# 启动时是否发送赞赏码
# 仅当使用量已经超过2048字时发送
encourage_sponsor_at_start = True
# 每次向OpenAI接口发送对话记录上下文的字符数
# 最大不超过(4096 - max_tokens)个字符max_tokens为下方completion_api_params中的max_tokens
# 注意较大的prompt_submit_length会导致OpenAI账户额度消耗更快
prompt_submit_length = 3072
# 是否在token超限报错时自动重置会话
# 可在tips.py中编辑提示语
auto_reset = True
# OpenAI补全API的参数
# 请在下方填写模型,程序自动选择接口
# 模型文档https://platform.openai.com/docs/models
# 现已支持的模型有:
#
# ChatCompletions 接口:
# # GPT 4 系列
# "gpt-4-1106-preview",
# "gpt-4-vision-preview",
# "gpt-4",
# "gpt-4-32k",
# "gpt-4-0613",
# "gpt-4-32k-0613",
# "gpt-4-0314", # legacy
# "gpt-4-32k-0314", # legacy
# # GPT 3.5 系列
# "gpt-3.5-turbo-1106",
# "gpt-3.5-turbo",
# "gpt-3.5-turbo-16k",
# "gpt-3.5-turbo-0613", # legacy
# "gpt-3.5-turbo-16k-0613", # legacy
# "gpt-3.5-turbo-0301", # legacy
#
# Completions接口
# "gpt-3.5-turbo-instruct",
#
# 具体请查看OpenAI的文档: https://beta.openai.com/docs/api-reference/completions/create
# 请将内容修改到config.py中请勿修改config-template.py
#
# 支持通过 One API 接入多种模型请在上方的openai_config中设置One API的代理地址
# 并在此填写您要使用的模型名称详细请参考https://github.com/songquanpeng/one-api
#
# 支持的 One API 模型:
# "SparkDesk",
# "chatglm_pro",
# "chatglm_std",
# "chatglm_lite",
# "qwen-v1",
# "qwen-plus-v1",
# "ERNIE-Bot",
# "ERNIE-Bot-turbo",
# "gemini-pro",
completion_api_params = {
"model": "gpt-3.5-turbo",
"temperature": 0.9, # 数值越低得到的回答越理性,取值范围[0, 1]
}
# OpenAI的Image API的参数
# 具体请查看OpenAI的文档: https://platform.openai.com/docs/api-reference/images/create
image_api_params = {
"model": "dall-e-2", # 默认使用 dall-e-2 模型,也可以改为 dall-e-3
# 图片尺寸
# dall-e-2 模型支持 256x256, 512x512, 1024x1024
# dall-e-3 模型支持 1024x1024, 1792x1024, 1024x1792
"size": "256x256",
}
# 跟踪函数调用
# 为True时在每次GPT进行Function Calling时都会输出发送一条回复给用户
# 同时一次提问内所有的Function Calling和普通回复消息都会单独发送给用户
trace_function_calls = False
# 群内回复消息时是否引用原消息
quote_origin = False
# 群内回复消息时是否at发送者
at_sender = False
# 回复绘图时是否包含图片描述
include_image_description = True
# 消息处理的超时时间,单位为秒
process_message_timeout = 120
# 回复消息时是否显示[GPT]前缀
show_prefix = False
# 回复前的强制延迟时间,降低机器人被腾讯风控概率
# *此机制对命令和消息、私聊及群聊均生效
# 每次处理时从以下的范围取一个随机秒数,
# 当此次消息处理时间低于此秒数时,将会强制延迟至此秒数
# 例如:[1.5, 3]则每次处理时会随机取一个1.5-3秒的随机数若处理时间低于此随机数则强制延迟至此随机秒数
# 若您不需要此功能请将force_delay_range设置为[0, 0]
force_delay_range = [0, 0]
# 应用长消息处理策略的阈值
# 当回复消息长度超过此值时,将使用长消息处理策略
blob_message_threshold = 256
# 长消息处理策略
# - "image": 将长消息转换为图片发送
# - "forward": 将长消息转换为转发消息组件发送
blob_message_strategy = "forward"
# 允许等待
# 同一会话内,是否等待上一条消息处理完成后再处理下一条消息
# 若设置为False若上一条未处理完时收到了新消息将会丢弃新消息
# 丢弃消息时的提示信息可以在tips.py中修改
wait_last_done = True
# 文字转图片时使用的字体文件路径
# 当策略为"image"时生效
# 若在Windows系统下程序会自动使用Windows自带的微软雅黑字体
# 若未填写或不存在且不是Windows将禁用文字转图片功能改为使用转发消息组件
font_path = ""
# 消息处理超时重试次数
retry_times = 3
# 消息处理出错时是否向用户隐藏错误详细信息
# 设置为True时仅向管理员发送错误详细信息
# 设置为False时向用户及管理员发送错误详细信息
hide_exce_info_to_user = False
# 每个会话的过期时间,单位为秒
# 默认值20分钟
session_expire_time = 1200
# 会话限速
# 单会话内每分钟可进行的对话次数
# 若不需要限速,可以设置为一个很大的值
# 默认值60次基本上不会触发限速
#
# 若要设置针对某特定群的限速,请使用如下格式:
# {
# "group_<群号>": 60,
# "default": 60,
# }
# 若要设置针对某特定用户私聊的限速,请使用如下格式:
# {
# "person_<用户QQ>": 60,
# "default": 60,
# }
# 同时设置多个群和私聊的限速,示例:
# {
# "group_12345678": 60,
# "group_87654321": 60,
# "person_234567890": 60,
# "person_345678901": 60,
# "default": 60,
# }
#
# 注意: 未指定的都使用default的限速值default不可删除
rate_limitation = {
"default": 60,
}
# 会话限速策略
# - "wait": 每次对话获取到回复时,等待一定时间再发送回复,保证其不会超过限速均值
# - "drop": 此分钟内,若对话次数超过限速次数,则丢弃之后的对话,每自然分钟重置
rate_limit_strategy = "drop"
# 是否在启动时进行依赖库更新
upgrade_dependencies = False
# 是否上报统计信息
# 用于统计机器人的使用情况,数据不公开,不会收集任何敏感信息。
# 仅实例识别UUID、上报时间、字数使用量、绘图使用量、插件使用情况、用户信息其他信息不会上报
report_usage = True
# 日志级别
logging_level = logging.INFO

View File

@@ -4,15 +4,7 @@ services:
qchatgpt:
image: rockchin/qchatgpt:latest
volumes:
- ./config.py:/QChatGPT/config.py
- ./banlist.py:/QChatGPT/banlist.py
- ./cmdpriv.json:/QChatGPT/cmdpriv.json
- ./sensitive.json:/QChatGPT/sensitive.json
- ./tips.py:/QChatGPT/tips.py
# 目录映射
- ./plugins:/QChatGPT/plugins
- ./scenario:/QChatGPT/scenario
- ./temp:/QChatGPT/temp
- ./logs:/QChatGPT/logs
restart: always
- ./data:/app/data
- ./plugins:/app/plugins
restart: on-failure
# 根据具体环境配置网络

508
main.py
View File

@@ -1,496 +1,54 @@
import importlib
import json
import os
import shutil
import threading
import time
# QChatGPT 终端启动入口
# 在此层级解决依赖项检查。
import logging
import sys
import traceback
import asyncio
asciiart = r"""
___ ___ _ _ ___ ___ _____
/ _ \ / __| |_ __ _| |_ / __| _ \_ _|
| (_) | (__| ' \/ _` | _| (_ | _/ | |
\__\_\\___|_||_\__,_|\__|\___|_| |_|
sys.path.append(".")
⭐️开源地址: https://github.com/RockChinQ/QChatGPT
📖文档地址: https://q.rkcn.top
"""
def check_file():
# 检查是否有banlist.py,如果没有就把banlist-template.py复制一份
if not os.path.exists('banlist.py'):
shutil.copy('res/templates/banlist-template.py', 'banlist.py')
async def main_entry():
print(asciiart)
# 检查是否有sensitive.json
if not os.path.exists("sensitive.json"):
shutil.copy("res/templates/sensitive-template.json", "sensitive.json")
import sys
# 检查是否有scenario/default.json
if not os.path.exists("scenario/default.json"):
shutil.copy("scenario/default-template.json", "scenario/default.json")
# 检查依赖
# 检查cmdpriv.json
if not os.path.exists("cmdpriv.json"):
shutil.copy("res/templates/cmdpriv-template.json", "cmdpriv.json")
from pkg.core.bootutils import deps
# 检查tips_custom
if not os.path.exists("tips.py"):
shutil.copy("tips-custom-template.py", "tips.py")
missing_deps = await deps.check_deps()
# 检查temp目录
if not os.path.exists("temp/"):
os.mkdir("temp/")
# 检查并创建plugins、prompts目录
check_path = ["plugins", "prompts"]
for path in check_path:
if not os.path.exists(path):
os.mkdir(path)
# 配置文件存在性校验
if not os.path.exists('config.py'):
shutil.copy('config-template.py', 'config.py')
print('请先在config.py中填写配置')
if missing_deps:
print("以下依赖包未安装,将自动安装,请完成后重启程序:")
for dep in missing_deps:
print("-", dep)
await deps.install_deps(missing_deps)
print("已自动安装缺失的依赖包,请重启程序。")
sys.exit(0)
# 初始化相关文件
check_file()
from pkg.utils.log import init_runtime_log_file, reset_logging
from pkg.config import manager as config_mgr
from pkg.config.impls import pymodule as pymodule_cfg
try:
import colorlog
except ImportError:
# 尝试安装
import pkg.utils.pkgmgr as pkgmgr
try:
pkgmgr.install_requirements("requirements.txt")
import colorlog
except ImportError:
print("依赖不满足,请查看 https://github.com/RockChinQ/qcg-installer/issues/15")
sys.exit(1)
import colorlog
import requests
import websockets.exceptions
from urllib3.exceptions import InsecureRequestWarning
import pkg.utils.context
# 是否使用override.json覆盖配置
# 仅在启动时提供 --override 或 -r 参数时生效
use_override = False
def init_db():
import pkg.database.manager
database = pkg.database.manager.DatabaseManager()
database.initialize_database()
def ensure_dependencies():
import pkg.utils.pkgmgr as pkgmgr
pkgmgr.run_pip(["install", "openai", "Pillow", "nakuru-project-idk", "CallingGPT", "tiktoken", "--upgrade",
"-i", "https://pypi.tuna.tsinghua.edu.cn/simple",
"--trusted-host", "pypi.tuna.tsinghua.edu.cn"])
known_exception_caught = False
def override_config_manager():
config = pkg.utils.context.get_config_manager().data
if os.path.exists("override.json") and use_override:
override_json = json.load(open("override.json", "r", encoding="utf-8"))
overrided = []
for key in override_json:
if key in config:
config[key] = override_json[key]
# logging.info("覆写配置[{}]为[{}]".format(key, override_json[key]))
overrided.append(key)
else:
logging.error("无法覆写配置[{}]为[{}]该配置不存在请检查override.json是否正确".format(key, override_json[key]))
if len(overrided) > 0:
logging.info("已根据override.json覆写配置项: {}".format(", ".join(overrided)))
def complete_tips():
"""根据tips-custom-template模块补全tips模块的属性"""
non_exist_keys = []
is_integrity = True
logging.debug("检查tips模块完整性.")
tips_template = importlib.import_module('tips-custom-template')
tips = importlib.import_module('tips')
for key in dir(tips_template):
if not key.startswith("__") and not hasattr(tips, key):
setattr(tips, key, getattr(tips_template, key))
# logging.warning("[{}]不存在".format(key))
non_exist_keys.append(key)
is_integrity = False
if not is_integrity:
logging.warning("以下提示语字段不存在: {}".format(", ".join(non_exist_keys)))
logging.warning("tips模块不完整您可以依据tips-custom-template.py检查tips.py")
logging.warning("以上配置已被设为默认值将在3秒后继续启动... ")
time.sleep(3)
async def start_process(first_time_init=False):
"""启动流程reload之后会被执行"""
global known_exception_caught
import pkg.utils.context
# 计算host和instance标识符
import pkg.audit.identifier
pkg.audit.identifier.init()
# 加载配置
cfg_inst: pymodule_cfg.PythonModuleConfigFile = pymodule_cfg.PythonModuleConfigFile(
'config.py',
'config-template.py'
)
await config_mgr.ConfigManager(cfg_inst).load_config()
override_config_manager()
# 检查tips模块
complete_tips()
cfg = pkg.utils.context.get_config_manager().data
# 更新openai库到最新版本
if 'upgrade_dependencies' not in cfg or cfg['upgrade_dependencies']:
print("正在更新依赖库,请等待...")
if 'upgrade_dependencies' not in cfg:
print("这个操作不是必须的,如果不想更新,请在config.py中添加upgrade_dependencies=False")
else:
print("这个操作不是必须的,如果不想更新,请在config.py中将upgrade_dependencies设置为False")
try:
ensure_dependencies()
except Exception as e:
print("更新openai库失败:{}, 请忽略或自行更新".format(e))
known_exception_caught = False
try:
try:
sh = reset_logging()
pkg.utils.context.context['logger_handler'] = sh
# 初始化文字转图片
from pkg.utils import text2img
text2img.initialize()
# 检查是否设置了管理员
if cfg['admin_qq'] == 0:
# logging.warning("未设置管理员QQ,管理员权限命令及运行告警将无法使用,如需设置请修改config.py中的admin_qq字段")
while True:
try:
cfg['admin_qq'] = int(input("未设置管理员QQ,管理员权限命令及运行告警将无法使用,请输入管理员QQ号: "))
# 写入到文件
# 读取文件
config_file_str = ""
with open("config.py", "r", encoding="utf-8") as f:
config_file_str = f.read()
# 替换
config_file_str = config_file_str.replace("admin_qq = 0", "admin_qq = " + str(cfg['admin_qq']))
# 写入
with open("config.py", "w", encoding="utf-8") as f:
f.write(config_file_str)
print("管理员QQ已设置如需修改请修改config.py中的admin_qq字段")
time.sleep(4)
break
except ValueError:
print("请输入数字")
# 初始化中央服务器 API 交互实例
from pkg.utils.center import apigroup
from pkg.utils.center import v2 as center_v2
center_v2_api = center_v2.V2CenterAPI(
basic_info={
"host_id": pkg.audit.identifier.identifier['host_id'],
"instance_id": pkg.audit.identifier.identifier['instance_id'],
"semantic_version": pkg.utils.updater.get_current_tag(),
"platform": sys.platform,
},
runtime_info={
"admin_id": "{}".format(cfg['admin_qq']),
"msg_source": cfg['msg_source_adapter'],
}
)
pkg.utils.context.set_center_v2_api(center_v2_api)
import pkg.openai.manager
import pkg.database.manager
import pkg.openai.session
import pkg.qqbot.manager
import pkg.openai.dprompt
import pkg.qqbot.cmds.aamgr
try:
pkg.openai.dprompt.register_all()
pkg.qqbot.cmds.aamgr.register_all()
pkg.qqbot.cmds.aamgr.apply_privileges()
except Exception as e:
logging.error(e)
traceback.print_exc()
# 配置OpenAI proxy
import openai
openai.proxies = None # 先重置因为重载后可能需要清除proxy
if "http_proxy" in cfg['openai_config'] and cfg['openai_config']["http_proxy"] is not None:
openai.proxies = {
"http": cfg['openai_config']["http_proxy"],
"https": cfg['openai_config']["http_proxy"]
}
# 配置openai api_base
if "reverse_proxy" in cfg['openai_config'] and cfg['openai_config']["reverse_proxy"] is not None:
logging.debug("设置反向代理: "+cfg['openai_config']['reverse_proxy'])
openai.base_url = cfg['openai_config']["reverse_proxy"]
# 主启动流程
database = pkg.database.manager.DatabaseManager()
database.initialize_database()
openai_interact = pkg.openai.manager.OpenAIInteract(cfg['openai_config']['api_key'])
# 加载所有未超时的session
pkg.openai.session.load_sessions()
# 初始化qq机器人
qqbot = pkg.qqbot.manager.QQBotManager(first_time_init=first_time_init)
# 加载插件
import pkg.plugin.host
pkg.plugin.host.load_plugins()
pkg.plugin.host.initialize_plugins()
if first_time_init: # 不是热重载之后的启动,则启动新的bot线程
import mirai.exceptions
def run_bot_wrapper():
global known_exception_caught
try:
logging.debug("使用账号: {}".format(qqbot.bot_account_id))
qqbot.adapter.run_sync()
except TypeError as e:
if str(e).__contains__("argument 'debug'"):
logging.error(
"连接bot失败:{}, 解决方案: https://github.com/RockChinQ/QChatGPT/issues/82".format(e))
known_exception_caught = True
elif str(e).__contains__("As of 3.10, the *loop*"):
logging.error(
"Websockets版本过低:{}, 解决方案: https://github.com/RockChinQ/QChatGPT/issues/5".format(e))
known_exception_caught = True
except websockets.exceptions.InvalidStatus as e:
logging.error(
"mirai-api-http端口无法使用:{}, 解决方案: https://github.com/RockChinQ/QChatGPT/issues/22".format(
e))
known_exception_caught = True
except mirai.exceptions.NetworkError as e:
logging.error("连接mirai-api-http失败:{}, 请检查是否已按照文档启动mirai".format(e))
known_exception_caught = True
except Exception as e:
if str(e).__contains__("404"):
logging.error(
"mirai-api-http端口无法使用:{}, 解决方案: https://github.com/RockChinQ/QChatGPT/issues/22".format(
e))
known_exception_caught = True
elif str(e).__contains__("signal only works in main thread"):
logging.error(
"hypercorn异常:{}, 解决方案: https://github.com/RockChinQ/QChatGPT/issues/86".format(
e))
known_exception_caught = True
elif str(e).__contains__("did not receive a valid HTTP"):
logging.error(
"mirai-api-http端口无法使用:{}, 解决方案: https://github.com/RockChinQ/QChatGPT/issues/22".format(
e))
else:
import traceback
traceback.print_exc()
logging.error(
"捕捉到未知异常:{}, 请前往 https://github.com/RockChinQ/QChatGPT/issues 查找或提issue".format(e))
known_exception_caught = True
raise e
finally:
time.sleep(12)
threading.Thread(
target=run_bot_wrapper
).start()
except Exception as e:
traceback.print_exc()
if isinstance(e, KeyboardInterrupt):
logging.info("程序被用户中止")
sys.exit(0)
elif isinstance(e, SyntaxError):
logging.error("配置文件存在语法错误,请检查配置文件:\n1. 是否存在中文符号\n2. 是否已按照文件中的说明填写正确")
sys.exit(1)
else:
logging.error("初始化失败:{}".format(e))
sys.exit(1)
finally:
# 判断若是Windows输出选择模式可能会暂停程序的警告
if os.name == 'nt':
time.sleep(2)
logging.info("您正在使用Windows系统若命令行窗口处于“选择”模式程序可能会被暂停此时请右键点击窗口空白区域使其取消选择模式。")
time.sleep(12)
# 检查配置文件
if first_time_init:
if not known_exception_caught:
if cfg['msg_source_adapter'] == "yirimirai":
logging.info("QQ: {}, MAH: {}".format(cfg['mirai_http_api_config']['qq'], cfg['mirai_http_api_config']['host']+":"+str(cfg['mirai_http_api_config']['port'])))
logging.critical('程序启动完成,如长时间未显示 "成功登录到账号xxxxx" ,并且不回复消息,解决办法(请勿到群里问): '
'https://github.com/RockChinQ/QChatGPT/issues/37')
elif cfg['msg_source_adapter'] == 'nakuru':
logging.info("host: {}, port: {}, http_port: {}".format(cfg['nakuru_config']['host'], cfg['nakuru_config']['port'], cfg['nakuru_config']['http_port']))
logging.critical('程序启动完成,如长时间未显示 "Protocol: connected" ,并且不回复消息,请检查config.py中的nakuru_config是否正确')
else:
sys.exit(1)
else:
logging.info('热重载完成')
from pkg.core.bootutils import files
# 发送赞赏码
if cfg['encourage_sponsor_at_start'] \
and pkg.utils.context.get_openai_manager().audit_mgr.get_total_text_length() >= 2048:
generated_files = await files.generate_files()
logging.info("发送赞赏码")
from mirai import MessageChain, Plain, Image
import pkg.utils.constants
message_chain = MessageChain([
Plain("自2022年12月初以来开发者已经花费了大量时间和精力来维护本项目如果您觉得本项目对您有帮助欢迎赞赏开发者"
"以支持项目稳定运行😘"),
Image(base64=pkg.utils.constants.alipay_qr_b64),
Image(base64=pkg.utils.constants.wechat_qr_b64),
Plain("BTC: 3N4Azee63vbBB9boGv9Rjf4N5SocMe5eCq\nXMR: 89LS21EKQuDGkyQoe2nDupiuWXk4TVD6FALvSKv5owfmeJEPFpHeMsZLYtLiJ6GxLrhsRe5gMs6MyMSDn4GNQAse2Mae4KE\n\n"),
Plain("(本消息仅在启动时发送至管理员如果您不想再看到此消息请在config.py中将encourage_sponsor_at_start设置为False)")
])
pkg.utils.context.get_qqbot_manager().notify_admin_message_chain(message_chain)
if generated_files:
print("以下文件不存在,已自动生成,请按需修改配置文件后重启:")
for file in generated_files:
print("-", file)
time.sleep(5)
import pkg.utils.updater
try:
if pkg.utils.updater.is_new_version_available():
logging.info("新版本可用,请发送 !update 进行自动更新\n更新日志:\n{}".format("\n".join(pkg.utils.updater.get_rls_notes())))
else:
# logging.info("当前已是最新版本")
pass
except Exception as e:
logging.warning("检查更新失败:{}".format(e))
try:
import pkg.utils.announcement as announcement
new_announcement = announcement.fetch_new()
if len(new_announcement) > 0:
for announcement in new_announcement:
logging.critical("[公告]<{}> {}".format(announcement['time'], announcement['content']))
# 发送统计数据
pkg.utils.context.get_center_v2_api().main.post_announcement_showed(
[announcement['id'] for announcement in new_announcement]
)
except Exception as e:
logging.warning("获取公告失败:{}".format(e))
return qqbot
def stop():
import pkg.qqbot.manager
import pkg.openai.session
try:
import pkg.plugin.host
pkg.plugin.host.unload_plugins()
qqbot_inst = pkg.utils.context.get_qqbot_manager()
assert isinstance(qqbot_inst, pkg.qqbot.manager.QQBotManager)
for session in pkg.openai.session.sessions:
logging.info('持久化session: %s', session)
pkg.openai.session.sessions[session].persistence()
pkg.utils.context.get_database_manager().close()
except Exception as e:
if not isinstance(e, KeyboardInterrupt):
raise e
def main():
global use_override
# 检查是否携带了 --override 或 -r 参数
if '--override' in sys.argv or '-r' in sys.argv:
use_override = True
# 初始化logging
init_runtime_log_file()
pkg.utils.context.context['logger_handler'] = reset_logging()
# 配置线程池
from pkg.utils import ThreadCtl
thread_ctl = ThreadCtl(
sys_pool_num=8,
admin_pool_num=4,
user_pool_num=8
)
# 存进上下文
pkg.utils.context.set_thread_ctl(thread_ctl)
# 启动指令处理
if len(sys.argv) > 1 and sys.argv[1] == 'init_db':
init_db()
sys.exit(0)
elif len(sys.argv) > 1 and sys.argv[1] == 'update':
print("正在进行程序更新...")
import pkg.utils.updater as updater
updater.update_all(cli=True)
sys.exit(0)
# 关闭urllib的http警告
requests.packages.urllib3.disable_warnings(InsecureRequestWarning)
def run_wrapper():
asyncio.run(start_process(True))
pkg.utils.context.get_thread_ctl().submit_sys_task(
run_wrapper
)
# 主线程循环
while True:
try:
time.sleep(0xFF)
except:
stop()
pkg.utils.context.get_thread_ctl().shutdown()
launch_args = sys.argv.copy()
if "--cov-report" not in launch_args:
import platform
if platform.system() == 'Windows':
cmd = "taskkill /F /PID {}".format(os.getpid())
elif platform.system() in ['Linux', 'Darwin']:
cmd = "kill -9 {}".format(os.getpid())
os.system(cmd)
else:
print("正常退出以生成覆盖率报告")
sys.exit(0)
from pkg.core import boot
await boot.main()
if __name__ == '__main__':
main()
import asyncio
asyncio.run(main_entry())

View File

@@ -1,90 +0,0 @@
{
"comment": "这是override.json支持的字段全集, 关于override.json机制, 请查看https://github.com/RockChinQ/QChatGPT/pull/271",
"msg_source_adapter": "yirimirai",
"mirai_http_api_config": {
"adapter": "WebSocketAdapter",
"host": "localhost",
"port": 8080,
"verifyKey": "yirimirai",
"qq": 1234567890
},
"nakuru_config": {
"host": "localhost",
"port": 6700,
"http_port": 5700,
"token": ""
},
"openai_config": {
"api_key": {
"default": "openai_api_key"
},
"http_proxy": null,
"reverse_proxy": null
},
"switch_strategy": "active",
"admin_qq": 0,
"default_prompt": {
"default": "如果用户之后想获取帮助,请你说“输入!help获取帮助”。"
},
"preset_mode": "normal",
"response_rules": {
"default": {
"at": true,
"prefix": [
"/ai",
"!ai",
"ai",
"ai"
],
"regexp": [],
"random_rate": 0.0
}
},
"ignore_rules": {
"prefix": [
"/"
],
"regexp": []
},
"income_msg_check": false,
"sensitive_word_filter": true,
"baidu_check": false,
"baidu_api_key": "",
"baidu_secret_key": "",
"inappropriate_message_tips": "[百度云]请珍惜机器人,当前返回内容不合规",
"encourage_sponsor_at_start": true,
"prompt_submit_length": 3072,
"auto_reset": true,
"completion_api_params": {
"model": "gpt-3.5-turbo",
"temperature": 0.9
},
"image_api_params": {
"model": "dall-e-2",
"size": "256x256"
},
"trace_function_calls": false,
"quote_origin": false,
"at_sender": false,
"include_image_description": true,
"process_message_timeout": 120,
"show_prefix": false,
"force_delay_range": [
0,
0
],
"blob_message_threshold": 256,
"blob_message_strategy": "forward",
"wait_last_done": true,
"font_path": "",
"retry_times": 3,
"hide_exce_info_to_user": false,
"session_expire_time": 1200,
"rate_limitation": {
"default": 60
},
"rate_limit_strategy": "drop",
"upgrade_dependencies": false,
"report_usage": true,
"logging_level": 20
}

View File

@@ -0,0 +1,89 @@
from __future__ import annotations
import abc
import uuid
import json
import logging
import asyncio
import aiohttp
import requests
from ...core import app
class APIGroup(metaclass=abc.ABCMeta):
"""API 组抽象类"""
_basic_info: dict = None
_runtime_info: dict = None
prefix = None
ap: app.Application
def __init__(self, prefix: str, ap: app.Application):
self.prefix = prefix
self.ap = ap
async def _do(
self,
method: str,
path: str,
data: dict = None,
params: dict = None,
headers: dict = {},
**kwargs
):
self._runtime_info['account_id'] = "-1"
url = self.prefix + path
data = json.dumps(data)
headers['Content-Type'] = 'application/json'
try:
async with aiohttp.ClientSession() as session:
async with session.request(
method,
url,
data=data,
params=params,
headers=headers,
**kwargs
) as resp:
self.ap.logger.debug("data: %s", data)
self.ap.logger.debug("ret: %s", await resp.text())
except Exception as e:
self.ap.logger.debug(f'上报失败: {e}')
async def do(
self,
method: str,
path: str,
data: dict = None,
params: dict = None,
headers: dict = {},
**kwargs
) -> asyncio.Task:
"""执行请求"""
asyncio.create_task(self._do(method, path, data, params, headers, **kwargs))
def gen_rid(
self
):
"""生成一个请求 ID"""
return str(uuid.uuid4())
def basic_info(
self
):
"""获取基本信息"""
basic_info = APIGroup._basic_info.copy()
basic_info['rid'] = self.gen_rid()
return basic_info
def runtime_info(
self
):
"""获取运行时信息"""
return APIGroup._runtime_info

View File

@@ -1,22 +1,22 @@
from __future__ import annotations
from .. import apigroup
from ... import context
from ....core import app
class V2MainDataAPI(apigroup.APIGroup):
"""主程序相关 数据API"""
def __init__(self, prefix: str):
super().__init__(prefix+"/main")
def __init__(self, prefix: str, ap: app.Application):
self.ap = ap
super().__init__(prefix+"/main", ap)
def do(self, *args, **kwargs):
config = context.get_config_manager().data
if not config['report_usage']:
async def do(self, *args, **kwargs):
if not self.ap.system_cfg.data['report-usage']:
return None
return super().do(*args, **kwargs)
return await super().do(*args, **kwargs)
def post_update_record(
async def post_update_record(
self,
spent_seconds: int,
infer_reason: str,
@@ -24,7 +24,7 @@ class V2MainDataAPI(apigroup.APIGroup):
new_version: str,
):
"""提交更新记录"""
return self.do(
return await self.do(
"POST",
"/update",
data={
@@ -38,12 +38,12 @@ class V2MainDataAPI(apigroup.APIGroup):
}
)
def post_announcement_showed(
async def post_announcement_showed(
self,
ids: list[int],
):
"""提交公告已阅"""
return self.do(
return await self.do(
"POST",
"/announcement",
data={

View File

@@ -1,27 +1,27 @@
from __future__ import annotations
from ....core import app
from .. import apigroup
from ... import context
class V2PluginDataAPI(apigroup.APIGroup):
"""插件数据相关 API"""
def __init__(self, prefix: str):
super().__init__(prefix+"/plugin")
def __init__(self, prefix: str, ap: app.Application):
self.ap = ap
super().__init__(prefix+"/plugin", ap)
def do(self, *args, **kwargs):
config = context.get_config_manager().data
if not config['report_usage']:
async def do(self, *args, **kwargs):
if not self.ap.system_cfg.data['report-usage']:
return None
return super().do(*args, **kwargs)
return await super().do(*args, **kwargs)
def post_install_record(
async def post_install_record(
self,
plugin: dict
):
"""提交插件安装记录"""
return self.do(
return await self.do(
"POST",
"/install",
data={
@@ -30,12 +30,12 @@ class V2PluginDataAPI(apigroup.APIGroup):
}
)
def post_remove_record(
async def post_remove_record(
self,
plugin: dict
):
"""提交插件卸载记录"""
return self.do(
return await self.do(
"POST",
"/remove",
data={
@@ -44,14 +44,14 @@ class V2PluginDataAPI(apigroup.APIGroup):
}
)
def post_update_record(
async def post_update_record(
self,
plugin: dict,
old_version: str,
new_version: str,
):
"""提交插件更新记录"""
return self.do(
return await self.do(
"POST",
"/update",
data={

View File

@@ -1,22 +1,22 @@
from __future__ import annotations
from .. import apigroup
from ... import context
from ....core import app
class V2UsageDataAPI(apigroup.APIGroup):
"""使用量数据相关 API"""
def __init__(self, prefix: str):
super().__init__(prefix+"/usage")
def __init__(self, prefix: str, ap: app.Application):
self.ap = ap
super().__init__(prefix+"/usage", ap)
def do(self, *args, **kwargs):
config = context.get_config_manager().data
if not config['report_usage']:
async def do(self, *args, **kwargs):
if not self.ap.system_cfg.data['report-usage']:
return None
return super().do(*args, **kwargs)
def post_query_record(
return await super().do(*args, **kwargs)
async def post_query_record(
self,
session_type: str,
session_id: str,
@@ -27,7 +27,7 @@ class V2UsageDataAPI(apigroup.APIGroup):
retry_times: int,
):
"""提交请求记录"""
return self.do(
return await self.do(
"POST",
"/query",
data={
@@ -47,13 +47,13 @@ class V2UsageDataAPI(apigroup.APIGroup):
}
)
def post_event_record(
async def post_event_record(
self,
plugins: list[dict],
event_name: str,
):
"""提交事件触发记录"""
return self.do(
return await self.do(
"POST",
"/event",
data={
@@ -66,14 +66,14 @@ class V2UsageDataAPI(apigroup.APIGroup):
}
)
def post_function_record(
async def post_function_record(
self,
plugin: dict,
function_name: str,
function_description: str,
):
"""提交内容函数使用记录"""
return self.do(
return await self.do(
"POST",
"/function",
data={

View File

@@ -6,6 +6,7 @@ from . import apigroup
from .groups import main
from .groups import usage
from .groups import plugin
from ...core import app
BACKEND_URL = "https://api.qchatgpt.rockchin.top/api/v2"
@@ -22,7 +23,7 @@ class V2CenterAPI:
plugin: plugin.V2PluginDataAPI = None
"""插件 API 组"""
def __init__(self, basic_info: dict = None, runtime_info: dict = None):
def __init__(self, ap: app.Application, basic_info: dict = None, runtime_info: dict = None):
"""初始化"""
logging.debug("basic_info: %s, runtime_info: %s", basic_info, runtime_info)
@@ -30,6 +31,7 @@ class V2CenterAPI:
apigroup.APIGroup._basic_info = basic_info
apigroup.APIGroup._runtime_info = runtime_info
self.main = main.V2MainDataAPI(BACKEND_URL)
self.usage = usage.V2UsageDataAPI(BACKEND_URL)
self.plugin = plugin.V2PluginDataAPI(BACKEND_URL)
self.main = main.V2MainDataAPI(BACKEND_URL, ap)
self.usage = usage.V2UsageDataAPI(BACKEND_URL, ap)
self.plugin = plugin.V2PluginDataAPI(BACKEND_URL, ap)

View File

@@ -1,114 +0,0 @@
"""
使用量统计以及数据上报功能实现
"""
import hashlib
import json
import logging
import threading
import requests
from ..utils import context
from ..utils import updater
class DataGatherer:
"""数据收集器"""
usage = {}
"""各api-key的使用量
以key值md5为key,{
"text": {
"gpt-3.5-turbo": 文字量:int,
},
"image": {
"256x256": 图片数量:int,
}
}为值的字典"""
version_str = "undetermined"
def __init__(self):
self.load_from_db()
try:
self.version_str = updater.get_current_tag() # 从updater模块获取版本号
except:
pass
def get_usage(self, key_md5):
return self.usage[key_md5] if key_md5 in self.usage else {}
def report_text_model_usage(self, model, total_tokens):
"""调用方报告文字模型请求文字使用量"""
key_md5 = context.get_openai_manager().key_mgr.get_using_key_md5() # 以key的md5进行储存
if key_md5 not in self.usage:
self.usage[key_md5] = {}
if "text" not in self.usage[key_md5]:
self.usage[key_md5]["text"] = {}
if model not in self.usage[key_md5]["text"]:
self.usage[key_md5]["text"][model] = 0
length = total_tokens
self.usage[key_md5]["text"][model] += length
self.dump_to_db()
def report_image_model_usage(self, size):
"""调用方报告图片模型请求图片使用量"""
key_md5 = context.get_openai_manager().key_mgr.get_using_key_md5()
if key_md5 not in self.usage:
self.usage[key_md5] = {}
if "image" not in self.usage[key_md5]:
self.usage[key_md5]["image"] = {}
if size not in self.usage[key_md5]["image"]:
self.usage[key_md5]["image"][size] = 0
self.usage[key_md5]["image"][size] += 1
self.dump_to_db()
def get_text_length_of_key(self, key):
"""获取指定api-key (明文) 的文字总使用量(本地记录)"""
key_md5 = hashlib.md5(key.encode('utf-8')).hexdigest()
if key_md5 not in self.usage:
return 0
if "text" not in self.usage[key_md5]:
return 0
# 遍历其中所有模型,求和
return sum(self.usage[key_md5]["text"].values())
def get_image_count_of_key(self, key):
"""获取指定api-key (明文) 的图片总使用量(本地记录)"""
key_md5 = hashlib.md5(key.encode('utf-8')).hexdigest()
if key_md5 not in self.usage:
return 0
if "image" not in self.usage[key_md5]:
return 0
# 遍历其中所有模型,求和
return sum(self.usage[key_md5]["image"].values())
def get_total_text_length(self):
"""获取所有api-key的文字总使用量(本地记录)"""
total = 0
for key in self.usage:
if "text" not in self.usage[key]:
continue
total += sum(self.usage[key]["text"].values())
return total
def dump_to_db(self):
context.get_database_manager().dump_usage_json(self.usage)
def load_from_db(self):
json_str = context.get_database_manager().load_usage_json()
if json_str is not None:
self.usage = json.loads(json_str)

125
pkg/command/cmdmgr.py Normal file
View File

@@ -0,0 +1,125 @@
from __future__ import annotations
import typing
from ..core import app, entities as core_entities
from ..provider import entities as llm_entities
from . import entities, operator, errors
from ..config import manager as cfg_mgr
from .operators import func, plugin, default, reset, list as list_cmd, last, next, delc, resend, prompt, cmd, help, version, update
class CommandManager:
"""命令管理器
"""
ap: app.Application
cmd_list: list[operator.CommandOperator]
def __init__(self, ap: app.Application):
self.ap = ap
async def initialize(self):
# 设置各个类的路径
def set_path(cls: operator.CommandOperator, ancestors: list[str]):
cls.path = '.'.join(ancestors + [cls.name])
for op in operator.preregistered_operators:
if op.parent_class == cls:
set_path(op, ancestors + [cls.name])
for cls in operator.preregistered_operators:
if cls.parent_class is None:
set_path(cls, [])
# 应用命令权限配置
for cls in operator.preregistered_operators:
if cls.path in self.ap.command_cfg.data['privilege']:
cls.lowest_privilege = self.ap.command_cfg.data['privilege'][cls.path]
# 实例化所有类
self.cmd_list = [cls(self.ap) for cls in operator.preregistered_operators]
# 设置所有类的子节点
for cmd in self.cmd_list:
cmd.children = [child for child in self.cmd_list if child.parent_class == cmd.__class__]
# 初始化所有类
for cmd in self.cmd_list:
await cmd.initialize()
async def _execute(
self,
context: entities.ExecuteContext,
operator_list: list[operator.CommandOperator],
operator: operator.CommandOperator = None
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
"""执行命令
"""
found = False
if len(context.crt_params) > 0:
for oper in operator_list:
if (context.crt_params[0] == oper.name \
or context.crt_params[0] in oper.alias) \
and (oper.parent_class is None or oper.parent_class == operator.__class__):
found = True
context.crt_command = context.crt_params[0]
context.crt_params = context.crt_params[1:]
async for ret in self._execute(
context,
oper.children,
oper
):
yield ret
break
if not found:
if operator is None:
yield entities.CommandReturn(
error=errors.CommandNotFoundError(context.crt_params[0])
)
else:
if operator.lowest_privilege > context.privilege:
yield entities.CommandReturn(
error=errors.CommandPrivilegeError(operator.name)
)
else:
async for ret in operator.execute(context):
yield ret
async def execute(
self,
command_text: str,
query: core_entities.Query,
session: core_entities.Session
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
"""执行命令
"""
privilege = 1
if f'{query.launcher_type.value}_{query.launcher_id}' in self.ap.system_cfg.data['admin-sessions']:
privilege = 2
ctx = entities.ExecuteContext(
query=query,
session=session,
command_text=command_text,
command='',
crt_command='',
params=command_text.split(' '),
crt_params=command_text.split(' '),
privilege=privilege
)
async for ret in self._execute(
ctx,
self.cmd_list
):
yield ret

42
pkg/command/entities.py Normal file
View File

@@ -0,0 +1,42 @@
from __future__ import annotations
import typing
import pydantic
import mirai
from ..core import app, entities as core_entities
from . import errors, operator
class CommandReturn(pydantic.BaseModel):
text: typing.Optional[str]
"""文本
"""
image: typing.Optional[mirai.Image]
error: typing.Optional[errors.CommandError]= None
class Config:
arbitrary_types_allowed = True
class ExecuteContext(pydantic.BaseModel):
query: core_entities.Query
session: core_entities.Session
command_text: str
command: str
crt_command: str
params: list[str]
crt_params: list[str]
privilege: int

33
pkg/command/errors.py Normal file
View File

@@ -0,0 +1,33 @@
class CommandError(Exception):
def __init__(self, message: str = None):
self.message = message
def __str__(self):
return self.message
class CommandNotFoundError(CommandError):
def __init__(self, message: str = None):
super().__init__("未知命令: "+message)
class CommandPrivilegeError(CommandError):
def __init__(self, message: str = None):
super().__init__("权限不足: "+message)
class ParamNotEnoughError(CommandError):
def __init__(self, message: str = None):
super().__init__("参数不足: "+message)
class CommandOperationError(CommandError):
def __init__(self, message: str = None):
super().__init__("操作失败: "+message)

78
pkg/command/operator.py Normal file
View File

@@ -0,0 +1,78 @@
from __future__ import annotations
import typing
import abc
from ..core import app, entities as core_entities
from . import entities
preregistered_operators: list[typing.Type[CommandOperator]] = []
def operator_class(
name: str,
help: str,
usage: str = None,
alias: list[str] = [],
privilege: int=1, # 1为普通用户2为管理员
parent_class: typing.Type[CommandOperator] = None
) -> typing.Callable[[typing.Type[CommandOperator]], typing.Type[CommandOperator]]:
def decorator(cls: typing.Type[CommandOperator]) -> typing.Type[CommandOperator]:
cls.name = name
cls.alias = alias
cls.help = help
cls.usage = usage
cls.parent_class = parent_class
cls.lowest_privilege = privilege
preregistered_operators.append(cls)
return cls
return decorator
class CommandOperator(metaclass=abc.ABCMeta):
"""命令算子
"""
ap: app.Application
name: str
"""名称,搜索到时若符合则使用"""
path: str
"""路径所有父节点的name的连接用于定义命令权限"""
alias: list[str]
"""同name"""
help: str
"""此节点的帮助信息"""
usage: str = None
parent_class: typing.Union[typing.Type[CommandOperator], None] = None
"""父节点类。标记以供管理器在初始化时编织父子关系。"""
lowest_privilege: int = 0
"""最低权限。若权限低于此值,则不予执行。"""
children: list[CommandOperator]
"""子节点。解析命令时,若节点有子节点,则以下一个参数去匹配子节点,
若有匹配中的,转移到子节点中执行,若没有匹配中的或没有子节点,执行此节点。"""
def __init__(self, ap: app.Application):
self.ap = ap
self.children = []
async def initialize(self):
pass
@abc.abstractmethod
async def execute(
self,
context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
pass

View File

@@ -0,0 +1,50 @@
from __future__ import annotations
import typing
from .. import operator, entities, cmdmgr, errors
@operator.operator_class(
name="cmd",
help='显示命令列表',
usage='!cmd\n!cmd <命令名称>'
)
class CmdOperator(operator.CommandOperator):
"""命令列表
"""
async def execute(
self,
context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
"""执行
"""
if len(context.crt_params) == 0:
reply_str = "当前所有命令: \n\n"
for cmd in self.ap.cmd_mgr.cmd_list:
if cmd.parent_class is None:
reply_str += f"{cmd.name}: {cmd.help}\n"
reply_str += "\n使用 !cmd <命令名称> 查看命令的详细帮助"
yield entities.CommandReturn(text=reply_str.strip())
else:
cmd_name = context.crt_params[0]
cmd = None
for _cmd in self.ap.cmd_mgr.cmd_list:
if (cmd_name == _cmd.name or cmd_name in _cmd.alias) and (_cmd.parent_class is None):
cmd = _cmd
break
if cmd is None:
yield entities.CommandReturn(error=errors.CommandNotFoundError(cmd_name))
else:
reply_str = f"{cmd.name}: {cmd.help}\n\n"
reply_str += f"使用方法: \n{cmd.usage}"
yield entities.CommandReturn(text=reply_str.strip())

View File

@@ -0,0 +1,62 @@
from __future__ import annotations
import typing
import traceback
from .. import operator, entities, cmdmgr, errors
@operator.operator_class(
name="default",
help="操作情景预设",
usage='!default\n!default set <指定情景预设为默认>'
)
class DefaultOperator(operator.CommandOperator):
async def execute(
self,
context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
reply_str = "当前所有情景预设: \n\n"
for prompt in self.ap.prompt_mgr.get_all_prompts():
content = ""
for msg in prompt.messages:
content += f" {msg.role}: {msg.content}"
reply_str += f"名称: {prompt.name}\n内容: \n{content}\n\n"
reply_str += f"当前会话使用的是: {context.session.use_prompt_name}"
yield entities.CommandReturn(text=reply_str.strip())
@operator.operator_class(
name="set",
help="设置当前会话默认情景预设",
parent_class=DefaultOperator
)
class DefaultSetOperator(operator.CommandOperator):
async def execute(
self,
context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
if len(context.crt_params) == 0:
yield entities.CommandReturn(error=errors.ParamNotEnoughError('请提供情景预设名称'))
else:
prompt_name = context.crt_params[0]
try:
prompt = await self.ap.prompt_mgr.get_prompt_by_prefix(prompt_name)
if prompt is None:
yield entities.CommandReturn(error=errors.CommandError("设置当前会话默认情景预设失败: 未找到情景预设 {}".format(prompt_name)))
else:
context.session.use_prompt_name = prompt.name
yield entities.CommandReturn(text=f"已设置当前会话默认情景预设为 {prompt_name}, !reset 后生效")
except Exception as e:
traceback.print_exc()
yield entities.CommandReturn(error=errors.CommandError("设置当前会话默认情景预设失败: "+str(e)))

View File

@@ -0,0 +1,62 @@
from __future__ import annotations
import typing
import datetime
from .. import operator, entities, cmdmgr, errors
@operator.operator_class(
name="del",
help="删除当前会话的历史记录",
usage='!del <序号>\n!del all'
)
class DelOperator(operator.CommandOperator):
async def execute(
self,
context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
if context.session.conversations:
delete_index = 0
if len(context.crt_params) > 0:
try:
delete_index = int(context.crt_params[0])
except:
yield entities.CommandReturn(error=errors.CommandOperationError('索引必须是整数'))
return
if delete_index < 0 or delete_index >= len(context.session.conversations):
yield entities.CommandReturn(error=errors.CommandOperationError('索引超出范围'))
return
# 倒序
to_delete_index = len(context.session.conversations)-1-delete_index
if context.session.conversations[to_delete_index] == context.session.using_conversation:
context.session.using_conversation = None
del context.session.conversations[to_delete_index]
yield entities.CommandReturn(text=f"已删除对话: {delete_index}")
else:
yield entities.CommandReturn(error=errors.CommandOperationError('当前没有对话'))
@operator.operator_class(
name="all",
help="删除此会话的所有历史记录",
parent_class=DelOperator
)
class DelAllOperator(operator.CommandOperator):
async def execute(
self,
context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
context.session.conversations = []
context.session.using_conversation = None
yield entities.CommandReturn(text="已删除所有对话")

View File

@@ -0,0 +1,27 @@
from __future__ import annotations
from typing import AsyncGenerator
from .. import operator, entities, cmdmgr
@operator.operator_class(name="func", help="查看所有已注册的内容函数", usage='!func')
class FuncOperator(operator.CommandOperator):
async def execute(
self, context: entities.ExecuteContext
) -> AsyncGenerator[entities.CommandReturn, None]:
reply_str = "当前已加载的内容函数: \n\n"
index = 1
all_functions = await self.ap.tool_mgr.get_all_functions()
for func in all_functions:
reply_str += "{}. {}{}:\n{}\n\n".format(
index,
("(已禁用) " if not func.enable else ""),
func.name,
func.description,
)
index += 1
yield entities.CommandReturn(text=reply_str)

View File

@@ -0,0 +1,23 @@
from __future__ import annotations
import typing
from .. import operator, entities, cmdmgr, errors
@operator.operator_class(
name='help',
help='显示帮助',
usage='!help\n!help <命令名称>'
)
class HelpOperator(operator.CommandOperator):
async def execute(
self,
context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
help = self.ap.system_cfg.data['help-message']
help += '\n发送命令 !cmd 可查看命令列表'
yield entities.CommandReturn(text=help)

View File

@@ -0,0 +1,36 @@
from __future__ import annotations
import typing
import datetime
from .. import operator, entities, cmdmgr, errors
@operator.operator_class(
name="last",
help="切换到前一个对话",
usage='!last'
)
class LastOperator(operator.CommandOperator):
async def execute(
self,
context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
if context.session.conversations:
# 找到当前会话的上一个会话
for index in range(len(context.session.conversations)-1, -1, -1):
if context.session.conversations[index] == context.session.using_conversation:
if index == 0:
yield entities.CommandReturn(error=errors.CommandOperationError('已经是第一个对话了'))
return
else:
context.session.using_conversation = context.session.conversations[index-1]
time_str = context.session.using_conversation.create_time.strftime("%Y-%m-%d %H:%M:%S")
yield entities.CommandReturn(text=f"已切换到上一个对话: {index} {time_str}: {context.session.using_conversation.messages[0].content}")
return
else:
yield entities.CommandReturn(error=errors.CommandOperationError('当前没有对话'))

View File

@@ -0,0 +1,56 @@
from __future__ import annotations
import typing
import datetime
from .. import operator, entities, cmdmgr, errors
@operator.operator_class(
name="list",
help="列出此会话中的所有历史对话",
usage='!list\n!list <页码>'
)
class ListOperator(operator.CommandOperator):
async def execute(
self,
context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
page = 0
if len(context.crt_params) > 0:
try:
page = int(context.crt_params[0]-1)
except:
yield entities.CommandReturn(error=errors.CommandOperationError('页码应为整数'))
return
record_per_page = 10
content = ''
index = 0
using_conv_index = 0
for conv in context.session.conversations[::-1]:
time_str = conv.create_time.strftime("%Y-%m-%d %H:%M:%S")
if conv == context.session.using_conversation:
using_conv_index = index
if index >= page * record_per_page and index < (page + 1) * record_per_page:
content += f"{index} {time_str}: {conv.messages[0].content if len(conv.messages) > 0 else '无内容'}\n"
index += 1
if content == '':
content = ''
else:
if context.session.using_conversation is None:
content += "\n当前处于新会话"
else:
content += f"\n当前会话: {using_conv_index} {context.session.using_conversation.create_time.strftime('%Y-%m-%d %H:%M:%S')}: {context.session.using_conversation.messages[0].content if len(context.session.using_conversation.messages) > 0 else '无内容'}"
yield entities.CommandReturn(text=f"{page + 1} 页 (时间倒序):\n{content}")

View File

@@ -0,0 +1,35 @@
from __future__ import annotations
import typing
import datetime
from .. import operator, entities, cmdmgr, errors
@operator.operator_class(
name="next",
help="切换到后一个对话",
usage='!next'
)
class NextOperator(operator.CommandOperator):
async def execute(
self,
context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
if context.session.conversations:
# 找到当前会话的下一个会话
for index in range(len(context.session.conversations)):
if context.session.conversations[index] == context.session.using_conversation:
if index == len(context.session.conversations)-1:
yield entities.CommandReturn(error=errors.CommandOperationError('已经是最后一个对话了'))
return
else:
context.session.using_conversation = context.session.conversations[index+1]
time_str = context.session.using_conversation.create_time.strftime("%Y-%m-%d %H:%M:%S")
yield entities.CommandReturn(text=f"已切换到后一个对话: {index} {time_str}: {context.session.using_conversation.messages[0].content}")
return
else:
yield entities.CommandReturn(error=errors.CommandOperationError('当前没有对话'))

View File

@@ -0,0 +1,237 @@
from __future__ import annotations
import typing
import traceback
from .. import operator, entities, cmdmgr, errors
from ...core import app
@operator.operator_class(
name="plugin",
help="插件操作",
usage="!plugin\n!plugin get <插件仓库地址>\n!plugin update\n!plugin del <插件名>\n!plugin on <插件名>\n!plugin off <插件名>"
)
class PluginOperator(operator.CommandOperator):
async def execute(
self,
context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
plugin_list = self.ap.plugin_mgr.plugins
reply_str = "所有插件({}):\n".format(len(plugin_list))
idx = 0
for plugin in plugin_list:
reply_str += "\n#{} {} {}\n{}\nv{}\n作者: {}\n"\
.format((idx+1), plugin.plugin_name,
"[已禁用]" if not plugin.enabled else "",
plugin.plugin_description,
plugin.plugin_version, plugin.plugin_author)
# TODO 从元数据调远程地址
idx += 1
yield entities.CommandReturn(text=reply_str)
@operator.operator_class(
name="get",
help="安装插件",
privilege=2,
parent_class=PluginOperator
)
class PluginGetOperator(operator.CommandOperator):
async def execute(
self,
context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
if len(context.crt_params) == 0:
yield entities.CommandReturn(error=errors.ParamNotEnoughError('请提供插件仓库地址'))
else:
repo = context.crt_params[0]
yield entities.CommandReturn(text="正在安装插件...")
try:
await self.ap.plugin_mgr.install_plugin(repo)
yield entities.CommandReturn(text="插件安装成功,请重启程序以加载插件")
except Exception as e:
traceback.print_exc()
yield entities.CommandReturn(error=errors.CommandError("插件安装失败: "+str(e)))
@operator.operator_class(
name="update",
help="更新插件",
privilege=2,
parent_class=PluginOperator
)
class PluginUpdateOperator(operator.CommandOperator):
async def execute(
self,
context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
if len(context.crt_params) == 0:
yield entities.CommandReturn(error=errors.ParamNotEnoughError('请提供插件名称'))
else:
plugin_name = context.crt_params[0]
try:
plugin_container = self.ap.plugin_mgr.get_plugin_by_name(plugin_name)
if plugin_container is not None:
yield entities.CommandReturn(text="正在更新插件...")
await self.ap.plugin_mgr.update_plugin(plugin_name)
yield entities.CommandReturn(text="插件更新成功,请重启程序以加载插件")
else:
yield entities.CommandReturn(error=errors.CommandError("插件更新失败: 未找到插件"))
except Exception as e:
traceback.print_exc()
yield entities.CommandReturn(error=errors.CommandError("插件更新失败: "+str(e)))
@operator.operator_class(
name="all",
help="更新所有插件",
privilege=2,
parent_class=PluginUpdateOperator
)
class PluginUpdateAllOperator(operator.CommandOperator):
async def execute(
self,
context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
try:
plugins = [
p.plugin_name
for p in self.ap.plugin_mgr.plugins
]
if plugins:
yield entities.CommandReturn(text="正在更新插件...")
updated = []
try:
for plugin_name in plugins:
await self.ap.plugin_mgr.update_plugin(plugin_name)
updated.append(plugin_name)
except Exception as e:
traceback.print_exc()
yield entities.CommandReturn(error=errors.CommandError("插件更新失败: "+str(e)))
yield entities.CommandReturn(text="已更新插件: {}".format(", ".join(updated)))
else:
yield entities.CommandReturn(text="没有可更新的插件")
except Exception as e:
traceback.print_exc()
yield entities.CommandReturn(error=errors.CommandError("插件更新失败: "+str(e)))
@operator.operator_class(
name="del",
help="删除插件",
privilege=2,
parent_class=PluginOperator
)
class PluginDelOperator(operator.CommandOperator):
async def execute(
self,
context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
if len(context.crt_params) == 0:
yield entities.CommandReturn(error=errors.ParamNotEnoughError('请提供插件名称'))
else:
plugin_name = context.crt_params[0]
try:
plugin_container = self.ap.plugin_mgr.get_plugin_by_name(plugin_name)
if plugin_container is not None:
yield entities.CommandReturn(text="正在删除插件...")
await self.ap.plugin_mgr.uninstall_plugin(plugin_name)
yield entities.CommandReturn(text="插件删除成功,请重启程序以加载插件")
else:
yield entities.CommandReturn(error=errors.CommandError("插件删除失败: 未找到插件"))
except Exception as e:
traceback.print_exc()
yield entities.CommandReturn(error=errors.CommandError("插件删除失败: "+str(e)))
async def update_plugin_status(plugin_name: str, new_status: bool, ap: app.Application):
if ap.plugin_mgr.get_plugin_by_name(plugin_name) is not None:
for plugin in ap.plugin_mgr.plugins:
if plugin.plugin_name == plugin_name:
plugin.enabled = new_status
for func in plugin.content_functions:
func.enable = new_status
await ap.plugin_mgr.setting.dump_container_setting(ap.plugin_mgr.plugins)
break
return True
else:
return False
@operator.operator_class(
name="on",
help="启用插件",
privilege=2,
parent_class=PluginOperator
)
class PluginEnableOperator(operator.CommandOperator):
async def execute(
self,
context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
if len(context.crt_params) == 0:
yield entities.CommandReturn(error=errors.ParamNotEnoughError('请提供插件名称'))
else:
plugin_name = context.crt_params[0]
try:
if await update_plugin_status(plugin_name, True, self.ap):
yield entities.CommandReturn(text="已启用插件: {}".format(plugin_name))
else:
yield entities.CommandReturn(error=errors.CommandError("插件状态修改失败: 未找到插件 {}".format(plugin_name)))
except Exception as e:
traceback.print_exc()
yield entities.CommandReturn(error=errors.CommandError("插件状态修改失败: "+str(e)))
@operator.operator_class(
name="off",
help="禁用插件",
privilege=2,
parent_class=PluginOperator
)
class PluginDisableOperator(operator.CommandOperator):
async def execute(
self,
context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
if len(context.crt_params) == 0:
yield entities.CommandReturn(error=errors.ParamNotEnoughError('请提供插件名称'))
else:
plugin_name = context.crt_params[0]
try:
if await update_plugin_status(plugin_name, False, self.ap):
yield entities.CommandReturn(text="已禁用插件: {}".format(plugin_name))
else:
yield entities.CommandReturn(error=errors.CommandError("插件状态修改失败: 未找到插件 {}".format(plugin_name)))
except Exception as e:
traceback.print_exc()
yield entities.CommandReturn(error=errors.CommandError("插件状态修改失败: "+str(e)))

View File

@@ -0,0 +1,29 @@
from __future__ import annotations
import typing
from .. import operator, entities, cmdmgr, errors
@operator.operator_class(
name="prompt",
help="查看当前对话的前文",
usage='!prompt'
)
class PromptOperator(operator.CommandOperator):
async def execute(
self,
context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
"""执行
"""
if context.session.using_conversation is None:
yield entities.CommandReturn(error=errors.CommandOperationError('当前没有对话'))
else:
reply_str = '当前对话所有内容:\n\n'
for msg in context.session.using_conversation.messages:
reply_str += f"{msg.role}: {msg.content}\n"
yield entities.CommandReturn(text=reply_str)

View File

@@ -0,0 +1,34 @@
from __future__ import annotations
import typing
from .. import operator, entities, cmdmgr, errors
@operator.operator_class(
name="resend",
help="重发当前会话的最后一条消息",
usage='!resend'
)
class ResendOperator(operator.CommandOperator):
async def execute(
self,
context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
# 回滚到最后一条用户message前
if context.session.using_conversation is None:
yield entities.CommandReturn(error=errors.CommandError("当前没有对话"))
else:
conv_msg = context.session.using_conversation.messages
# 倒序一直删到最后一条用户message
while len(conv_msg) > 0 and conv_msg[-1].role != 'user':
conv_msg.pop()
if len(conv_msg) > 0:
# 删除最后一条用户message
conv_msg.pop()
# 不重发了,提示用户已删除就行了
yield entities.CommandReturn(text="已删除最后一次请求记录")

View File

@@ -0,0 +1,23 @@
from __future__ import annotations
import typing
from .. import operator, entities, cmdmgr, errors
@operator.operator_class(
name="reset",
help="重置当前会话",
usage='!reset'
)
class ResetOperator(operator.CommandOperator):
async def execute(
self,
context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
"""执行
"""
context.session.using_conversation = None
yield entities.CommandReturn(text="已重置当前会话")

View File

@@ -0,0 +1,30 @@
from __future__ import annotations
import typing
import traceback
from .. import operator, entities, cmdmgr, errors
@operator.operator_class(
name="update",
help="更新程序",
usage='!update',
privilege=2
)
class UpdateCommand(operator.CommandOperator):
async def execute(
self,
context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
try:
yield entities.CommandReturn(text="正在进行更新...")
if await self.ap.ver_mgr.update_all():
yield entities.CommandReturn(text="更新完成,请重启程序以应用更新")
else:
yield entities.CommandReturn(text="当前已是最新版本")
except Exception as e:
traceback.print_exc()
yield entities.CommandReturn(error=errors.CommandError("更新失败: "+str(e)))

View File

@@ -0,0 +1,27 @@
from __future__ import annotations
import typing
from .. import operator, cmdmgr, entities, errors
@operator.operator_class(
name="version",
help="显示版本信息",
usage='!version'
)
class VersionCommand(operator.CommandOperator):
async def execute(
self,
context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
reply_str = f"当前版本: \n{self.ap.ver_mgr.get_current_version()}"
try:
if await self.ap.ver_mgr.is_new_version_available():
reply_str += "\n\n有新版本可用, 使用 !update 更新"
except:
pass
yield entities.CommandReturn(text=reply_str.strip())

47
pkg/config/impls/json.py Normal file
View File

@@ -0,0 +1,47 @@
import os
import shutil
import json
from .. import model as file_model
class JSONConfigFile(file_model.ConfigFile):
"""JSON配置文件"""
config_file_name: str = None
"""配置文件名"""
template_file_name: str = None
"""模板文件名"""
def __init__(self, config_file_name: str, template_file_name: str) -> None:
self.config_file_name = config_file_name
self.template_file_name = template_file_name
def exists(self) -> bool:
return os.path.exists(self.config_file_name)
async def create(self):
shutil.copyfile(self.template_file_name, self.config_file_name)
async def load(self) -> dict:
if not self.exists():
await self.create()
with open(self.config_file_name, 'r', encoding='utf-8') as f:
cfg = json.load(f)
# 从模板文件中进行补全
with open(self.template_file_name, 'r', encoding='utf-8') as f:
template_cfg = json.load(f)
for key in template_cfg:
if key not in cfg:
cfg[key] = template_cfg[key]
return cfg
async def save(self, cfg: dict):
with open(self.config_file_name, 'w', encoding='utf-8') as f:
json.dump(cfg, f, indent=4, ensure_ascii=False)

View File

@@ -1,5 +1,10 @@
from __future__ import annotations
from . import model as file_model
from ..utils import context
from .impls import pymodule, json as json_file
managers: ConfigManager = []
class ConfigManager:
@@ -14,10 +19,35 @@ class ConfigManager:
def __init__(self, cfg_file: file_model.ConfigFile) -> None:
self.file = cfg_file
self.data = {}
context.set_config_manager(self)
async def load_config(self):
self.data = await self.file.load()
async def dump_config(self):
await self.file.save(self.data)
async def load_python_module_config(config_name: str, template_name: str) -> ConfigManager:
"""加载Python模块配置文件"""
cfg_inst = pymodule.PythonModuleConfigFile(
config_name,
template_name
)
cfg_mgr = ConfigManager(cfg_inst)
await cfg_mgr.load_config()
return cfg_mgr
async def load_json_config(config_name: str, template_name: str) -> ConfigManager:
"""加载JSON配置文件"""
cfg_inst = json_file.JSONConfigFile(
config_name,
template_name
)
cfg_mgr = ConfigManager(cfg_inst)
await cfg_mgr.load_config()
return cfg_mgr

106
pkg/core/app.py Normal file
View File

@@ -0,0 +1,106 @@
from __future__ import annotations
import logging
import asyncio
import traceback
from ..platform import manager as im_mgr
from ..provider.session import sessionmgr as llm_session_mgr
from ..provider.requester import modelmgr as llm_model_mgr
from ..provider.sysprompt import sysprompt as llm_prompt_mgr
from ..provider.tools import toolmgr as llm_tool_mgr
from ..config import manager as config_mgr
from ..audit.center import v2 as center_mgr
from ..command import cmdmgr
from ..plugin import manager as plugin_mgr
from ..pipeline import pool
from ..pipeline import controller, stagemgr
from ..utils import version as version_mgr, proxy as proxy_mgr
class Application:
im_mgr: im_mgr.PlatformManager = None
cmd_mgr: cmdmgr.CommandManager = None
sess_mgr: llm_session_mgr.SessionManager = None
model_mgr: llm_model_mgr.ModelManager = None
prompt_mgr: llm_prompt_mgr.PromptManager = None
tool_mgr: llm_tool_mgr.ToolManager = None
command_cfg: config_mgr.ConfigManager = None
pipeline_cfg: config_mgr.ConfigManager = None
platform_cfg: config_mgr.ConfigManager = None
provider_cfg: config_mgr.ConfigManager = None
system_cfg: config_mgr.ConfigManager = None
ctr_mgr: center_mgr.V2CenterAPI = None
plugin_mgr: plugin_mgr.PluginManager = None
query_pool: pool.QueryPool = None
ctrl: controller.Controller = None
stage_mgr: stagemgr.StageManager = None
ver_mgr: version_mgr.VersionManager = None
proxy_mgr: proxy_mgr.ProxyManager = None
logger: logging.Logger = None
def __init__(self):
pass
async def initialize(self):
pass
async def run(self):
await self.plugin_mgr.load_plugins()
await self.plugin_mgr.initialize_plugins()
tasks = []
try:
tasks = [
asyncio.create_task(self.im_mgr.run()),
asyncio.create_task(self.ctrl.run())
]
# async def interrupt(tasks):
# await asyncio.sleep(1.5)
# while await aioconsole.ainput("使用 ctrl+c 或 'exit' 退出程序 > ") != 'exit':
# pass
# for task in tasks:
# task.cancel()
# await interrupt(tasks)
import signal
def signal_handler(sig, frame):
for task in tasks:
task.cancel()
self.logger.info("程序退出.")
exit(0)
signal.signal(signal.SIGINT, signal_handler)
await asyncio.gather(*tasks, return_exceptions=True)
except asyncio.CancelledError:
pass
except Exception as e:
self.logger.error(f"应用运行致命异常: {e}")
self.logger.debug(f"Traceback: {traceback.format_exc()}")

35
pkg/core/boot.py Normal file
View File

@@ -0,0 +1,35 @@
from __future__ import print_function
from . import app
from ..audit import identifier
from . import stage
from .stages import load_config, setup_logger, build_app
stage_order = [
"LoadConfigStage",
"SetupLoggerStage",
"BuildAppStage"
]
async def make_app() -> app.Application:
# 生成标识符
identifier.init()
ap = app.Application()
for stage_name in stage_order:
stage_cls = stage.preregistered_stages[stage_name]
stage_inst = stage_cls()
await stage_inst.run(ap)
await ap.initialize()
return ap
async def main():
app_inst = await make_app()
await app_inst.run()

View File

@@ -0,0 +1,23 @@
from __future__ import annotations
import json
from ...config import manager as config_mgr
from ...config.impls import pymodule
load_python_module_config = config_mgr.load_python_module_config
load_json_config = config_mgr.load_json_config
async def override_config_manager(cfg_mgr: config_mgr.ConfigManager) -> list[str]:
override_json = json.load(open("override.json", "r", encoding="utf-8"))
overrided = []
config = cfg_mgr.data
for key in override_json:
if key in config:
config[key] = override_json[key]
overrided.append(key)
return overrided

View File

@@ -0,0 +1,34 @@
import pip
required_deps = {
"requests": "requests",
"openai": "openai",
"colorlog": "colorlog",
"mirai": "yiri-mirai-rc",
"aiocqhttp": "aiocqhttp",
"botpy": "qq-botpy",
"PIL": "pillow",
"nakuru": "nakuru-project-idk",
"CallingGPT": "CallingGPT",
"tiktoken": "tiktoken",
"yaml": "pyyaml",
"aiohttp": "aiohttp",
}
async def check_deps() -> list[str]:
global required_deps
missing_deps = []
for dep in required_deps:
try:
__import__(dep)
except ImportError:
missing_deps.append(dep)
return missing_deps
async def install_deps(deps: list[str]):
global required_deps
for dep in deps:
pip.main(["install", required_deps[dep]])

View File

@@ -0,0 +1,43 @@
from __future__ import annotations
import os
import shutil
import sys
required_files = {
"plugins/__init__.py": "templates/__init__.py",
"plugins/plugins.json": "templates/plugin-settings.json",
"data/config/command.json": "templates/command.json",
"data/config/pipeline.json": "templates/pipeline.json",
"data/config/platform.json": "templates/platform.json",
"data/config/provider.json": "templates/provider.json",
"data/config/system.json": "templates/system.json",
"data/config/sensitive-words.json": "templates/sensitive-words.json",
"data/scenario/default.json": "templates/scenario-template.json",
}
required_paths = [
"temp",
"data",
"data/prompts",
"data/scenario",
"data/logs",
"data/config",
"plugins"
]
async def generate_files() -> list[str]:
global required_files, required_paths
for required_paths in required_paths:
if not os.path.exists(required_paths):
os.mkdir(required_paths)
generated_files = []
for file in required_files:
if not os.path.exists(file):
shutil.copyfile(required_files[file], file)
generated_files.append(file)
return generated_files

61
pkg/core/bootutils/log.py Normal file
View File

@@ -0,0 +1,61 @@
import logging
import os
import sys
import time
import colorlog
log_colors_config = {
"DEBUG": "green", # cyan white
"INFO": "white",
"WARNING": "yellow",
"ERROR": "red",
"CRITICAL": "cyan",
}
async def init_logging() -> logging.Logger:
# 删除所有现有的logger
for handler in logging.root.handlers[:]:
logging.root.removeHandler(handler)
level = logging.INFO
if "DEBUG" in os.environ and os.environ["DEBUG"] in ["true", "1"]:
level = logging.DEBUG
log_file_name = "data/logs/qcg-%s.log" % time.strftime(
"%Y-%m-%d-%H-%M-%S", time.localtime()
)
qcg_logger = logging.getLogger("qcg")
qcg_logger.setLevel(level)
color_formatter = colorlog.ColoredFormatter(
fmt="%(log_color)s[%(asctime)s.%(msecs)03d] %(pathname)s (%(lineno)d) - [%(levelname)s] :\n %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
log_colors=log_colors_config,
)
stream_handler = logging.StreamHandler(sys.stdout)
log_handlers: logging.Handler = [stream_handler, logging.FileHandler(log_file_name)]
for handler in log_handlers:
handler.setLevel(level)
handler.setFormatter(color_formatter)
qcg_logger.addHandler(handler)
qcg_logger.debug("日志初始化完成,日志级别:%s" % level)
logging.basicConfig(
level=logging.CRITICAL, # 设置日志输出格式
format="[DEPR][%(asctime)s.%(msecs)03d] %(pathname)s (%(lineno)d) - [%(levelname)s] :\n%(message)s",
# 日志输出的格式
# -8表示占位符让输出左对齐输出长度都为8位
datefmt="%Y-%m-%d %H:%M:%S", # 时间输出的格式
handlers=[logging.NullHandler()],
)
return qcg_logger

116
pkg/core/entities.py Normal file
View File

@@ -0,0 +1,116 @@
from __future__ import annotations
import enum
import typing
import datetime
import asyncio
import pydantic
import mirai
from ..provider import entities as llm_entities
from ..provider.requester import entities
from ..provider.sysprompt import entities as sysprompt_entities
from ..provider.tools import entities as tools_entities
from ..platform import adapter as msadapter
class LauncherTypes(enum.Enum):
PERSON = 'person'
"""私聊"""
GROUP = 'group'
"""群聊"""
class Query(pydantic.BaseModel):
"""一次请求的信息封装"""
query_id: int
"""请求ID添加进请求池时生成"""
launcher_type: LauncherTypes
"""会话类型platform设置"""
launcher_id: int
"""会话IDplatform设置"""
sender_id: int
"""发送者IDplatform设置"""
message_event: mirai.MessageEvent
"""事件platform收到的事件"""
message_chain: mirai.MessageChain
"""消息链platform收到的消息链"""
adapter: msadapter.MessageSourceAdapter
"""适配器对象"""
session: typing.Optional[Session] = None
"""会话对象,由前置处理器设置"""
messages: typing.Optional[list[llm_entities.Message]] = []
"""历史消息列表,由前置处理器设置"""
prompt: typing.Optional[sysprompt_entities.Prompt] = None
"""情景预设内容,由前置处理器设置"""
user_message: typing.Optional[llm_entities.Message] = None
"""此次请求的用户消息对象,由前置处理器设置"""
use_model: typing.Optional[entities.LLMModelInfo] = None
"""使用的模型,由前置处理器设置"""
use_funcs: typing.Optional[list[tools_entities.LLMFunction]] = None
"""使用的函数,由前置处理器设置"""
resp_messages: typing.Optional[list[llm_entities.Message]] = []
"""由provider生成的回复消息对象列表"""
resp_message_chain: typing.Optional[mirai.MessageChain] = None
"""回复消息链从resp_messages包装而得"""
class Config:
arbitrary_types_allowed = True
class Conversation(pydantic.BaseModel):
"""对话"""
prompt: sysprompt_entities.Prompt
messages: list[llm_entities.Message]
create_time: typing.Optional[datetime.datetime] = pydantic.Field(default_factory=datetime.datetime.now)
update_time: typing.Optional[datetime.datetime] = pydantic.Field(default_factory=datetime.datetime.now)
use_model: entities.LLMModelInfo
use_funcs: typing.Optional[list[tools_entities.LLMFunction]]
class Session(pydantic.BaseModel):
"""会话"""
launcher_type: LauncherTypes
launcher_id: int
sender_id: typing.Optional[int] = 0
use_prompt_name: typing.Optional[str] = 'default'
using_conversation: typing.Optional[Conversation] = None
conversations: typing.Optional[list[Conversation]] = []
create_time: typing.Optional[datetime.datetime] = pydantic.Field(default_factory=datetime.datetime.now)
update_time: typing.Optional[datetime.datetime] = pydantic.Field(default_factory=datetime.datetime.now)
semaphore: typing.Optional[asyncio.Semaphore] = None
class Config:
arbitrary_types_allowed = True

30
pkg/core/stage.py Normal file
View File

@@ -0,0 +1,30 @@
from __future__ import annotations
import abc
import typing
from . import app
preregistered_stages: dict[str, typing.Type[BootingStage]] = {}
def stage_class(
name: str
):
def decorator(cls: typing.Type[BootingStage]) -> typing.Type[BootingStage]:
preregistered_stages[name] = cls
return cls
return decorator
class BootingStage(abc.ABC):
"""启动阶段
"""
name: str = None
@abc.abstractmethod
async def run(self, ap: app.Application):
"""启动
"""
pass

View File

@@ -0,0 +1,95 @@
from __future__ import annotations
import sys
from .. import stage, app
from ...utils import version, proxy, announce
from ...audit.center import v2 as center_v2
from ...audit import identifier
from ...pipeline import pool, controller, stagemgr
from ...plugin import manager as plugin_mgr
from ...command import cmdmgr
from ...provider.session import sessionmgr as llm_session_mgr
from ...provider.requester import modelmgr as llm_model_mgr
from ...provider.sysprompt import sysprompt as llm_prompt_mgr
from ...provider.tools import toolmgr as llm_tool_mgr
from ...platform import manager as im_mgr
@stage.stage_class("BuildAppStage")
class BuildAppStage(stage.BootingStage):
"""构建应用阶段
"""
async def run(self, ap: app.Application):
"""启动
"""
proxy_mgr = proxy.ProxyManager(ap)
await proxy_mgr.initialize()
ap.proxy_mgr = proxy_mgr
ver_mgr = version.VersionManager(ap)
await ver_mgr.initialize()
ap.ver_mgr = ver_mgr
center_v2_api = center_v2.V2CenterAPI(
ap,
basic_info={
"host_id": identifier.identifier["host_id"],
"instance_id": identifier.identifier["instance_id"],
"semantic_version": ver_mgr.get_current_version(),
"platform": sys.platform,
},
runtime_info={
"admin_id": "{}".format(ap.system_cfg.data["admin-sessions"]),
"msg_source": str([
adapter_cfg['adapter'] if 'adapter' in adapter_cfg else 'unknown'
for adapter_cfg in ap.platform_cfg.data['platform-adapters'] if adapter_cfg['enable']
]),
},
)
ap.ctr_mgr = center_v2_api
# 发送公告
ann_mgr = announce.AnnouncementManager(ap)
await ann_mgr.show_announcements()
ap.query_pool = pool.QueryPool()
await ap.ver_mgr.show_version_update()
plugin_mgr_inst = plugin_mgr.PluginManager(ap)
await plugin_mgr_inst.initialize()
ap.plugin_mgr = plugin_mgr_inst
cmd_mgr_inst = cmdmgr.CommandManager(ap)
await cmd_mgr_inst.initialize()
ap.cmd_mgr = cmd_mgr_inst
llm_model_mgr_inst = llm_model_mgr.ModelManager(ap)
await llm_model_mgr_inst.initialize()
ap.model_mgr = llm_model_mgr_inst
llm_session_mgr_inst = llm_session_mgr.SessionManager(ap)
await llm_session_mgr_inst.initialize()
ap.sess_mgr = llm_session_mgr_inst
llm_prompt_mgr_inst = llm_prompt_mgr.PromptManager(ap)
await llm_prompt_mgr_inst.initialize()
ap.prompt_mgr = llm_prompt_mgr_inst
llm_tool_mgr_inst = llm_tool_mgr.ToolManager(ap)
await llm_tool_mgr_inst.initialize()
ap.tool_mgr = llm_tool_mgr_inst
im_mgr_inst = im_mgr.PlatformManager(ap=ap)
await im_mgr_inst.initialize()
ap.im_mgr = im_mgr_inst
stage_mgr = stagemgr.StageManager(ap)
await stage_mgr.initialize()
ap.stage_mgr = stage_mgr
ctrl = controller.Controller(ap)
ap.ctrl = ctrl

View File

@@ -0,0 +1,19 @@
from __future__ import annotations
from .. import stage, app
from ..bootutils import config
@stage.stage_class("LoadConfigStage")
class LoadConfigStage(stage.BootingStage):
"""加载配置文件阶段
"""
async def run(self, ap: app.Application):
"""启动
"""
ap.command_cfg = await config.load_json_config("data/config/command.json", "templates/command.json")
ap.pipeline_cfg = await config.load_json_config("data/config/pipeline.json", "templates/pipeline.json")
ap.platform_cfg = await config.load_json_config("data/config/platform.json", "templates/platform.json")
ap.provider_cfg = await config.load_json_config("data/config/provider.json", "templates/provider.json")
ap.system_cfg = await config.load_json_config("data/config/system.json", "templates/system.json")

View File

@@ -0,0 +1,15 @@
from __future__ import annotations
from .. import stage, app
from ..bootutils import log
@stage.stage_class("SetupLoggerStage")
class SetupLoggerStage(stage.BootingStage):
"""设置日志器阶段
"""
async def run(self, ap: app.Application):
"""启动
"""
ap.logger = await log.init_logging()

View File

@@ -1,3 +0,0 @@
"""
数据库操作封装
"""

View File

@@ -1,365 +0,0 @@
"""
数据库管理模块
"""
import hashlib
import json
import logging
import time
import sqlite3
from ..utils import context
class DatabaseManager:
"""封装数据库底层操作,并提供方法给上层使用"""
conn = None
cursor = None
def __init__(self):
self.reconnect()
context.set_database_manager(self)
# 连接到数据库文件
def reconnect(self):
"""连接到数据库"""
self.conn = sqlite3.connect('database.db', check_same_thread=False)
self.cursor = self.conn.cursor()
def close(self):
self.conn.close()
def __execute__(self, *args, **kwargs) -> sqlite3.Cursor:
# logging.debug('SQL: {}'.format(sql))
logging.debug('SQL: {}'.format(args))
c = self.cursor.execute(*args, **kwargs)
self.conn.commit()
return c
# 初始化数据库的函数
def initialize_database(self):
"""创建数据表"""
self.__execute__("""
create table if not exists `sessions` (
`id` INTEGER PRIMARY KEY AUTOINCREMENT,
`name` varchar(255) not null,
`type` varchar(255) not null,
`number` bigint not null,
`create_timestamp` bigint not null,
`last_interact_timestamp` bigint not null,
`status` varchar(255) not null default 'on_going',
`default_prompt` text not null default '',
`prompt` text not null,
`token_counts` text not null default '[]'
)
""")
# 检查sessions表是否存在`default_prompt`字段, 检查是否存在`token_counts`字段
self.__execute__("PRAGMA table_info('sessions')")
columns = self.cursor.fetchall()
has_default_prompt = False
has_token_counts = False
for field in columns:
if field[1] == 'default_prompt':
has_default_prompt = True
if field[1] == 'token_counts':
has_token_counts = True
if has_default_prompt and has_token_counts:
break
if not has_default_prompt:
self.__execute__("alter table `sessions` add column `default_prompt` text not null default ''")
if not has_token_counts:
self.__execute__("alter table `sessions` add column `token_counts` text not null default '[]'")
self.__execute__("""
create table if not exists `account_fee`(
`id` INTEGER PRIMARY KEY AUTOINCREMENT,
`key_md5` varchar(255) not null,
`timestamp` bigint not null,
`fee` DECIMAL(12,6) not null
)
""")
self.__execute__("""
create table if not exists `account_usage`(
`id` INTEGER PRIMARY KEY AUTOINCREMENT,
`json` text not null
)
""")
# print('Database initialized.')
# session持久化
def persistence_session(self, subject_type: str, subject_number: int, create_timestamp: int,
last_interact_timestamp: int, prompt: str, default_prompt: str = '', token_counts: str = ''):
"""持久化指定session"""
# 检查是否已经有了此name和create_timestamp的session
# 如果有就更新prompt和last_interact_timestamp
# 如果没有,就插入一条新的记录
self.__execute__("""
select count(*) from `sessions` where `type` = '{}' and `number` = {} and `create_timestamp` = {}
""".format(subject_type, subject_number, create_timestamp))
count = self.cursor.fetchone()[0]
if count == 0:
sql = """
insert into `sessions` (`name`, `type`, `number`, `create_timestamp`, `last_interact_timestamp`, `prompt`, `default_prompt`, `token_counts`)
values (?, ?, ?, ?, ?, ?, ?, ?)
"""
self.__execute__(sql,
("{}_{}".format(subject_type, subject_number), subject_type, subject_number, create_timestamp,
last_interact_timestamp, prompt, default_prompt, token_counts))
else:
sql = """
update `sessions` set `last_interact_timestamp` = ?, `prompt` = ?, `token_counts` = ?
where `type` = ? and `number` = ? and `create_timestamp` = ?
"""
self.__execute__(sql, (last_interact_timestamp, prompt, token_counts, subject_type,
subject_number, create_timestamp))
# 显式关闭一个session
def explicit_close_session(self, session_name: str, create_timestamp: int):
self.__execute__("""
update `sessions` set `status` = 'explicitly_closed' where `name` = '{}' and `create_timestamp` = {}
""".format(session_name, create_timestamp))
def set_session_ongoing(self, session_name: str, create_timestamp: int):
self.__execute__("""
update `sessions` set `status` = 'on_going' where `name` = '{}' and `create_timestamp` = {}
""".format(session_name, create_timestamp))
# 设置session为过期
def set_session_expired(self, session_name: str, create_timestamp: int):
self.__execute__("""
update `sessions` set `status` = 'expired' where `name` = '{}' and `create_timestamp` = {}
""".format(session_name, create_timestamp))
# 从数据库加载还没过期的session数据
def load_valid_sessions(self) -> dict:
# 从数据库中加载所有还没过期的session
config = context.get_config_manager().data
self.__execute__("""
select `name`, `type`, `number`, `create_timestamp`, `last_interact_timestamp`, `prompt`, `status`, `default_prompt`, `token_counts`
from `sessions` where `last_interact_timestamp` > {}
""".format(int(time.time()) - config['session_expire_time']))
results = self.cursor.fetchall()
sessions = {}
for result in results:
session_name = result[0]
subject_type = result[1]
subject_number = result[2]
create_timestamp = result[3]
last_interact_timestamp = result[4]
prompt = result[5]
status = result[6]
default_prompt = result[7]
token_counts = result[8]
# 当且仅当最后一个该对象的会话是on_going状态时才会被加载
if status == 'on_going':
sessions[session_name] = {
'subject_type': subject_type,
'subject_number': subject_number,
'create_timestamp': create_timestamp,
'last_interact_timestamp': last_interact_timestamp,
'prompt': prompt,
'default_prompt': default_prompt,
'token_counts': token_counts
}
else:
if session_name in sessions:
del sessions[session_name]
return sessions
# 获取此session_name前一个session的数据
def last_session(self, session_name: str, cursor_timestamp: int):
self.__execute__("""
select `name`, `type`, `number`, `create_timestamp`, `last_interact_timestamp`, `prompt`, `status`, `default_prompt`, `token_counts`
from `sessions` where `name` = '{}' and `last_interact_timestamp` < {} order by `last_interact_timestamp` desc
limit 1
""".format(session_name, cursor_timestamp))
results = self.cursor.fetchall()
if len(results) == 0:
return None
result = results[0]
session_name = result[0]
subject_type = result[1]
subject_number = result[2]
create_timestamp = result[3]
last_interact_timestamp = result[4]
prompt = result[5]
status = result[6]
default_prompt = result[7]
token_counts = result[8]
return {
'subject_type': subject_type,
'subject_number': subject_number,
'create_timestamp': create_timestamp,
'last_interact_timestamp': last_interact_timestamp,
'prompt': prompt,
'default_prompt': default_prompt,
'token_counts': token_counts
}
# 获取此session_name后一个session的数据
def next_session(self, session_name: str, cursor_timestamp: int):
self.__execute__("""
select `name`, `type`, `number`, `create_timestamp`, `last_interact_timestamp`, `prompt`, `status`, `default_prompt`, `token_counts`
from `sessions` where `name` = '{}' and `last_interact_timestamp` > {} order by `last_interact_timestamp` asc
limit 1
""".format(session_name, cursor_timestamp))
results = self.cursor.fetchall()
if len(results) == 0:
return None
result = results[0]
session_name = result[0]
subject_type = result[1]
subject_number = result[2]
create_timestamp = result[3]
last_interact_timestamp = result[4]
prompt = result[5]
status = result[6]
default_prompt = result[7]
token_counts = result[8]
return {
'subject_type': subject_type,
'subject_number': subject_number,
'create_timestamp': create_timestamp,
'last_interact_timestamp': last_interact_timestamp,
'prompt': prompt,
'default_prompt': default_prompt,
'token_counts': token_counts
}
# 列出与某个对象的所有对话session
def list_history(self, session_name: str, capacity: int, page: int):
self.__execute__("""
select `name`, `type`, `number`, `create_timestamp`, `last_interact_timestamp`, `prompt`, `status`, `default_prompt`, `token_counts`
from `sessions` where `name` = '{}' order by `last_interact_timestamp` desc limit {} offset {}
""".format(session_name, capacity, capacity * page))
results = self.cursor.fetchall()
sessions = []
for result in results:
session_name = result[0]
subject_type = result[1]
subject_number = result[2]
create_timestamp = result[3]
last_interact_timestamp = result[4]
prompt = result[5]
status = result[6]
default_prompt = result[7]
token_counts = result[8]
sessions.append({
'subject_type': subject_type,
'subject_number': subject_number,
'create_timestamp': create_timestamp,
'last_interact_timestamp': last_interact_timestamp,
'prompt': prompt,
'default_prompt': default_prompt,
'token_counts': token_counts
})
return sessions
def delete_history(self, session_name: str, index: int) -> bool:
# 删除倒序第index个session
# 查找其id再删除
self.__execute__("""
delete from `sessions` where `id` in (select `id` from `sessions` where `name` = '{}' order by `last_interact_timestamp` desc limit 1 offset {})
""".format(session_name, index))
return self.cursor.rowcount == 1
def delete_all_history(self, session_name: str) -> bool:
self.__execute__("""
delete from `sessions` where `name` = '{}'
""".format(session_name))
return self.cursor.rowcount > 0
def delete_all_session_history(self) -> bool:
self.__execute__("""
delete from `sessions`
""")
return self.cursor.rowcount > 0
# 将apikey的使用量存进数据库
def dump_api_key_usage(self, api_keys: dict, usage: dict):
logging.debug('dumping api key usage...')
logging.debug(api_keys)
logging.debug(usage)
for api_key in api_keys:
# 计算key的md5值
key_md5 = hashlib.md5(api_keys[api_key].encode('utf-8')).hexdigest()
# 获取使用量
usage_count = 0
if key_md5 in usage:
usage_count = usage[key_md5]
# 将使用量存进数据库
# 先检查是否已存在
self.__execute__("""
select count(*) from `api_key_usage` where `key_md5` = '{}'""".format(key_md5))
result = self.cursor.fetchone()
if result[0] == 0:
# 不存在则插入
self.__execute__("""
insert into `api_key_usage` (`key_md5`, `usage`,`timestamp`) values ('{}', {}, {})
""".format(key_md5, usage_count, int(time.time())))
else:
# 存在则更新timestamp设置为当前
self.__execute__("""
update `api_key_usage` set `usage` = {}, `timestamp` = {} where `key_md5` = '{}'
""".format(usage_count, int(time.time()), key_md5))
def load_api_key_usage(self):
self.__execute__("""
select `key_md5`, `usage` from `api_key_usage`
""")
results = self.cursor.fetchall()
usage = {}
for result in results:
key_md5 = result[0]
usage_count = result[1]
usage[key_md5] = usage_count
return usage
def dump_usage_json(self, usage: dict):
json_str = json.dumps(usage)
self.__execute__("""
select count(*) from `account_usage`""")
result = self.cursor.fetchone()
if result[0] == 0:
# 不存在则插入
self.__execute__("""
insert into `account_usage` (`json`) values ('{}')
""".format(json_str))
else:
# 存在则更新
self.__execute__("""
update `account_usage` set `json` = '{}' where `id` = 1
""".format(json_str))
def load_usage_json(self):
self.__execute__("""
select `json` from `account_usage` order by id desc limit 1
""")
result = self.cursor.fetchone()
if result is None:
return None
else:
return result[0]

View File

@@ -1,232 +0,0 @@
import json
import logging
import openai
from openai.types.chat import chat_completion_message
from .model import RequestBase
from .. import funcmgr
from ...plugin import host
from ...utils import context
class ChatCompletionRequest(RequestBase):
"""调用ChatCompletion接口的请求类。
此类保证每一次返回的角色为assistant的信息的finish_reason一定为stop。
若有函数调用响应,本类的返回瀑布是:函数调用请求->函数调用结果->...->assistant的信息->stop。
"""
model: str
messages: list[dict[str, str]]
kwargs: dict
stopped: bool = False
pending_func_call: chat_completion_message.FunctionCall = None
pending_msg: str
def flush_pending_msg(self):
self.append_message(
role="assistant",
content=self.pending_msg
)
self.pending_msg = ""
def append_message(self, role: str, content: str, name: str=None, function_call: dict=None):
msg = {
"role": role,
"content": content
}
if name is not None:
msg['name'] = name
if function_call is not None:
msg['function_call'] = function_call
self.messages.append(msg)
def __init__(
self,
client: openai.Client,
model: str,
messages: list[dict[str, str]],
**kwargs
):
self.client = client
self.model = model
self.messages = messages.copy()
self.kwargs = kwargs
self.req_func = self.client.chat.completions.create
self.pending_func_call = None
self.stopped = False
self.pending_msg = ""
def __iter__(self):
return self
def __next__(self) -> dict:
if self.stopped:
raise StopIteration()
if self.pending_func_call is None: # 没有待处理的函数调用请求
args = {
"model": self.model,
"messages": self.messages,
}
funcs = funcmgr.get_func_schema_list()
if len(funcs) > 0:
args['functions'] = funcs
# 拼接kwargs
args = {**args, **self.kwargs}
from openai.types.chat import chat_completion
resp: chat_completion.ChatCompletion = self._req(**args)
choice0 = resp.choices[0]
# 如果不是函数调用且finish_reason为stop则停止迭代
if choice0.finish_reason == 'stop': # and choice0["finish_reason"] == "stop"
self.stopped = True
if hasattr(choice0.message, 'function_call') and choice0.message.function_call is not None:
self.pending_func_call = choice0.message.function_call
self.append_message(
role="assistant",
content=choice0.message.content,
function_call=choice0.message.function_call
)
return {
"id": resp.id,
"choices": [
{
"index": choice0.index,
"message": {
"role": "assistant",
"type": "function_call",
"content": choice0.message.content,
"function_call": {
"name": choice0.message.function_call.name,
"arguments": choice0.message.function_call.arguments
}
},
"finish_reason": "function_call"
}
],
"usage": {
"prompt_tokens": resp.usage.prompt_tokens,
"completion_tokens": resp.usage.completion_tokens,
"total_tokens": resp.usage.total_tokens
}
}
else:
# self.pending_msg += choice0['message']['content']
# 普通回复一定处于最后方故不用再追加进内部messages
return {
"id": resp.id,
"choices": [
{
"index": choice0.index,
"message": {
"role": "assistant",
"type": "text",
"content": choice0.message.content
},
"finish_reason": choice0.finish_reason
}
],
"usage": {
"prompt_tokens": resp.usage.prompt_tokens,
"completion_tokens": resp.usage.completion_tokens,
"total_tokens": resp.usage.total_tokens
}
}
else: # 处理函数调用请求
cp_pending_func_call = self.pending_func_call.copy()
self.pending_func_call = None
func_name = cp_pending_func_call.name
arguments = {}
try:
try:
arguments = json.loads(cp_pending_func_call.arguments)
# 若不是json格式的异常处理
except json.decoder.JSONDecodeError:
# 获取函数的参数列表
func_schema = funcmgr.get_func_schema(func_name)
arguments = {
func_schema['parameters']['required'][0]: cp_pending_func_call.arguments
}
logging.info("执行函数调用: name={}, arguments={}".format(func_name, arguments))
# 执行函数调用
ret = ""
try:
ret = funcmgr.execute_function(func_name, arguments)
logging.info("函数执行完成。")
except Exception as e:
ret = "error: execute function failed: {}".format(str(e))
logging.error("函数执行失败: {}".format(str(e)))
# 上报数据
plugin_info = host.get_plugin_info_for_audit(func_name.split('-')[0])
audit_func_name = func_name.split('-')[1]
audit_func_desc = funcmgr.get_func_schema(func_name)['description']
context.get_center_v2_api().usage.post_function_record(
plugin=plugin_info,
function_name=audit_func_name,
function_description=audit_func_desc,
)
self.append_message(
role="function",
content=json.dumps(ret, ensure_ascii=False),
name=func_name
)
return {
"id": -1,
"choices": [
{
"index": -1,
"message": {
"role": "function",
"type": "function_return",
"function_name": func_name,
"content": json.dumps(ret, ensure_ascii=False)
},
"finish_reason": "function_return"
}
],
"usage": {
"prompt_tokens": 0,
"completion_tokens": 0,
"total_tokens": 0
}
}
except funcmgr.ContentFunctionNotFoundError:
raise Exception("没有找到函数: {}".format(func_name))

View File

@@ -1,100 +0,0 @@
import openai
from openai.types import completion, completion_choice
from . import model
class CompletionRequest(model.RequestBase):
"""调用Completion接口的请求类。
调用方可以一直next completion直到finish_reason为stop。
"""
model: str
prompt: str
kwargs: dict
stopped: bool = False
def __init__(
self,
client: openai.Client,
model: str,
messages: list[dict[str, str]],
**kwargs
):
self.client = client
self.model = model
self.prompt = ""
for message in messages:
self.prompt += message["role"] + ": " + message["content"] + "\n"
self.prompt += "assistant: "
self.kwargs = kwargs
self.req_func = self.client.completions.create
def __iter__(self):
return self
def __next__(self) -> dict:
"""调用Completion接口返回生成的文本
{
"id": "id",
"choices": [
{
"index": 0,
"message": {
"role": "assistant",
"type": "text",
"content": "message"
},
"finish_reason": "reason"
}
],
"usage": {
"prompt_tokens": 10,
"completion_tokens": 20,
"total_tokens": 30
}
}
"""
if self.stopped:
raise StopIteration()
resp: completion.Completion = self._req(
model=self.model,
prompt=self.prompt,
**self.kwargs
)
if resp.choices[0].finish_reason == "stop":
self.stopped = True
choice0: completion_choice.CompletionChoice = resp.choices[0]
self.prompt += choice0.text
return {
"id": resp.id,
"choices": [
{
"index": choice0.index,
"message": {
"role": "assistant",
"type": "text",
"content": choice0.text
},
"finish_reason": choice0.finish_reason
}
],
"usage": {
"prompt_tokens": resp.usage.prompt_tokens,
"completion_tokens": resp.usage.completion_tokens,
"total_tokens": resp.usage.total_tokens
}
}

View File

@@ -1,40 +0,0 @@
# 定义不同接口请求的模型
import logging
import openai
from ...utils import context
class RequestBase:
client: openai.Client
req_func: callable
def __init__(self, *args, **kwargs):
raise NotImplementedError
def _next_key(self):
switched, name = context.get_openai_manager().key_mgr.auto_switch()
logging.debug("切换api-key: switched={}, name={}".format(switched, name))
self.client.api_key = context.get_openai_manager().key_mgr.get_using_key()
def _req(self, **kwargs):
"""处理代理问题"""
logging.debug("请求接口参数: %s", str(kwargs))
config = context.get_config_manager().data
ret = self.req_func(**kwargs)
logging.debug("接口请求返回:%s", str(ret))
if config['switch_strategy'] == 'active':
self._next_key()
return ret
def __iter__(self):
raise self
def __next__(self):
raise NotImplementedError

View File

@@ -1,134 +0,0 @@
# 多情景预设值管理
import json
import logging
import os
from ..utils import context
# __current__ = "default"
# """当前默认使用的情景预设的名称
# 由管理员使用`!default <名称>`命令切换
# """
# __prompts_from_files__ = {}
# """从文件中读取的情景预设值"""
# __scenario_from_files__ = {}
class ScenarioMode:
"""情景预设模式抽象类"""
using_prompt_name = "default"
"""新session创建时使用的prompt名称"""
prompts: dict[str, list] = {}
def __init__(self):
logging.debug("prompts: {}".format(self.prompts))
def list(self) -> dict[str, list]:
"""获取所有情景预设的名称及内容"""
return self.prompts
def get_prompt(self, name: str) -> tuple[list, str]:
"""获取指定情景预设的名称及内容"""
for key in self.prompts:
if key.startswith(name):
return self.prompts[key], key
raise Exception("没有找到情景预设: {}".format(name))
def set_using_name(self, name: str) -> str:
"""设置默认情景预设"""
for key in self.prompts:
if key.startswith(name):
self.using_prompt_name = key
return key
raise Exception("没有找到情景预设: {}".format(name))
def get_full_name(self, name: str) -> str:
"""获取完整的情景预设名称"""
for key in self.prompts:
if key.startswith(name):
return key
raise Exception("没有找到情景预设: {}".format(name))
def get_using_name(self) -> str:
"""获取默认情景预设"""
return self.using_prompt_name
class NormalScenarioMode(ScenarioMode):
"""普通情景预设模式"""
def __init__(self):
config = context.get_config_manager().data
# 加载config中的default_prompt值
if type(config['default_prompt']) == str:
self.using_prompt_name = "default"
self.prompts = {"default": [
{
"role": "system",
"content": config['default_prompt']
}
]}
elif type(config['default_prompt']) == dict:
for key in config['default_prompt']:
self.prompts[key] = [
{
"role": "system",
"content": config['default_prompt'][key]
}
]
# 从prompts/目录下的文件中载入
# 遍历文件
for file in os.listdir("prompts"):
with open(os.path.join("prompts", file), encoding="utf-8") as f:
self.prompts[file] = [
{
"role": "system",
"content": f.read()
}
]
class FullScenarioMode(ScenarioMode):
"""完整情景预设模式"""
def __init__(self):
"""从json读取所有"""
# 遍历scenario/目录下的所有文件以文件名为键文件内容中的prompt为值
for file in os.listdir("scenario"):
if file == "default-template.json":
continue
with open(os.path.join("scenario", file), encoding="utf-8") as f:
self.prompts[file] = json.load(f)["prompt"]
super().__init__()
scenario_mode_mapping = {}
"""情景预设模式名称与对象的映射"""
def register_all():
"""注册所有情景预设模式,不使用装饰器,因为装饰器的方式不支持热重载"""
global scenario_mode_mapping
scenario_mode_mapping = {
"normal": NormalScenarioMode(),
"full_scenario": FullScenarioMode()
}
def mode_inst() -> ScenarioMode:
"""获取指定名称的情景预设模式对象"""
config = context.get_config_manager().data
if config['preset_mode'] == "default":
config['preset_mode'] = "normal"
return scenario_mode_mapping[config['preset_mode']]

View File

@@ -1,46 +0,0 @@
# 封装了function calling的一些支持函数
import logging
from ..plugin import host
class ContentFunctionNotFoundError(Exception):
pass
def get_func_schema_list() -> list:
"""从plugin包中的函数结构中获取并处理成受GPT支持的格式"""
if not host.__enable_content_functions__:
return []
schemas = []
for func in host.__callable_functions__:
if func['enabled']:
fun_cp = func.copy()
del fun_cp['enabled']
schemas.append(fun_cp)
return schemas
def get_func(name: str) -> callable:
if name not in host.__function_inst_map__:
raise ContentFunctionNotFoundError("没有找到内容函数: {}".format(name))
return host.__function_inst_map__[name]
def get_func_schema(name: str) -> dict:
for func in host.__callable_functions__:
if func['name'] == name:
return func
raise ContentFunctionNotFoundError("没有找到内容函数: {}".format(name))
def execute_function(name: str, kwargs: dict) -> any:
"""执行函数调用"""
logging.debug("executing function: name='{}', kwargs={}".format(name, kwargs))
func = get_func(name)
return func(**kwargs)

View File

@@ -1,103 +0,0 @@
# 此模块提供了维护api-key的各种功能
import hashlib
import logging
from ..plugin import host as plugin_host
from ..plugin import models as plugin_models
class KeysManager:
api_key = {}
"""所有api-key"""
using_key = ""
"""当前使用的api-key"""
alerted = []
"""已提示过超额的key
记录在此以避免重复提示
"""
exceeded = []
"""已超额的key
供自动切换功能识别
"""
def get_using_key(self):
return self.using_key
def get_using_key_md5(self):
return hashlib.md5(self.using_key.encode('utf-8')).hexdigest()
def __init__(self, api_key):
assert type(api_key) == dict
self.api_key = api_key
# 从usage中删除未加载的api-key的记录
# 不删了也许会运行时添加曾经有记录的api-key
self.auto_switch()
def auto_switch(self) -> tuple[bool, str]:
"""尝试切换api-key
Returns:
是否切换成功, 切换后的api-key的别名
"""
index = 0
for key_name in self.api_key:
if self.api_key[key_name] == self.using_key:
break
index += 1
# 从当前key开始向后轮询
start_index = index
index += 1
if index >= len(self.api_key):
index = 0
while index != start_index:
key_name = list(self.api_key.keys())[index]
if self.api_key[key_name] not in self.exceeded:
self.using_key = self.api_key[key_name]
logging.debug("使用api-key:" + key_name)
# 触发插件事件
args = {
"key_name": key_name,
"key_list": self.api_key.keys()
}
_ = plugin_host.emit(plugin_models.KeySwitched, **args)
return True, key_name
index += 1
if index >= len(self.api_key):
index = 0
self.using_key = list(self.api_key.values())[start_index]
logging.debug("使用api-key:" + list(self.api_key.keys())[start_index])
return False, list(self.api_key.keys())[start_index]
def add(self, key_name, key):
self.api_key[key_name] = key
def set_current_exceeded(self):
"""设置当前使用的api-key使用量超限"""
self.exceeded.append(self.using_key)
def get_key_name(self, api_key):
"""根据api-key获取其别名"""
for key_name in self.api_key:
if self.api_key[key_name] == api_key:
return key_name
return ""

View File

@@ -1,90 +0,0 @@
import logging
import openai
from openai.types import images_response
from ..openai import keymgr
from ..utils import context
from ..audit import gatherer
from ..openai import modelmgr
from ..openai.api import model as api_model
class OpenAIInteract:
"""OpenAI 接口封装
将文字接口和图片接口封装供调用方使用
"""
key_mgr: keymgr.KeysManager = None
audit_mgr: gatherer.DataGatherer = None
default_image_api_params = {
"size": "256x256",
}
client: openai.Client = None
def __init__(self, api_key: str):
self.key_mgr = keymgr.KeysManager(api_key)
self.audit_mgr = gatherer.DataGatherer()
# logging.info("文字总使用量:%d", self.audit_mgr.get_total_text_length())
self.client = openai.Client(
api_key=self.key_mgr.get_using_key(),
base_url=openai.base_url
)
context.set_openai_manager(self)
def request_completion(self, messages: list):
"""请求补全接口回复=
"""
# 选择接口请求类
config = context.get_config_manager().data
request: api_model.RequestBase
model: str = config['completion_api_params']['model']
cp_parmas = config['completion_api_params'].copy()
del cp_parmas['model']
request = modelmgr.select_request_cls(self.client, model, messages, cp_parmas)
# 请求接口
for resp in request:
if resp['usage']['total_tokens'] > 0:
self.audit_mgr.report_text_model_usage(
model,
resp['usage']['total_tokens']
)
yield resp
def request_image(self, prompt) -> images_response.ImagesResponse:
"""请求图片接口回复
Parameters:
prompt (str): 提示语
Returns:
dict: 响应
"""
config = context.get_config_manager().data
params = config['image_api_params']
response = self.client.images.generate(
prompt=prompt,
n=1,
**params
)
self.audit_mgr.report_image_model_usage(params['size'])
return response

View File

@@ -1,139 +0,0 @@
"""OpenAI 接口底层封装
目前使用的对话接口有:
ChatCompletion - gpt-3.5-turbo 等模型
Completion - text-davinci-003 等模型
此模块封装此两个接口的请求实现,为上层提供统一的调用方式
"""
import tiktoken
import openai
from ..openai.api import model as api_model
from ..openai.api import completion as api_completion
from ..openai.api import chat_completion as api_chat_completion
COMPLETION_MODELS = {
"gpt-3.5-turbo-instruct",
}
CHAT_COMPLETION_MODELS = {
# GPT 4 系列
"gpt-4-1106-preview",
"gpt-4-vision-preview",
"gpt-4",
"gpt-4-32k",
"gpt-4-0613",
"gpt-4-32k-0613",
"gpt-4-0314", # legacy
"gpt-4-32k-0314", # legacy
# GPT 3.5 系列
"gpt-3.5-turbo-1106",
"gpt-3.5-turbo",
"gpt-3.5-turbo-16k",
"gpt-3.5-turbo-0613", # legacy
"gpt-3.5-turbo-16k-0613", # legacy
"gpt-3.5-turbo-0301", # legacy
# One-API 接入
"SparkDesk",
"chatglm_pro",
"chatglm_std",
"chatglm_lite",
"qwen-v1",
"qwen-plus-v1",
"ERNIE-Bot",
"ERNIE-Bot-turbo",
"gemini-pro",
}
EDIT_MODELS = {
}
IMAGE_MODELS = {
}
def select_request_cls(client: openai.Client, model_name: str, messages: list, args: dict) -> api_model.RequestBase:
if model_name in CHAT_COMPLETION_MODELS:
return api_chat_completion.ChatCompletionRequest(client, model_name, messages, **args)
elif model_name in COMPLETION_MODELS:
return api_completion.CompletionRequest(client, model_name, messages, **args)
raise ValueError("不支持模型[{}],请检查配置文件".format(model_name))
def count_chat_completion_tokens(messages: list, model: str) -> int:
"""Return the number of tokens used by a list of messages."""
try:
encoding = tiktoken.encoding_for_model(model)
except KeyError:
print("Warning: model not found. Using cl100k_base encoding.")
encoding = tiktoken.get_encoding("cl100k_base")
if model in {
"gpt-3.5-turbo-0613",
"gpt-3.5-turbo-16k-0613",
"gpt-4-0314",
"gpt-4-32k-0314",
"gpt-4-0613",
"gpt-4-32k-0613",
"SparkDesk",
"chatglm_pro",
"chatglm_std",
"chatglm_lite",
"qwen-v1",
"qwen-plus-v1",
"ERNIE-Bot",
"ERNIE-Bot-turbo",
"gemini-pro",
}:
tokens_per_message = 3
tokens_per_name = 1
elif model == "gpt-3.5-turbo-0301":
tokens_per_message = 4 # every message follows <|start|>{role/name}\n{content}<|end|>\n
tokens_per_name = -1 # if there's a name, the role is omitted
elif "gpt-3.5-turbo" in model:
# print("Warning: gpt-3.5-turbo may update over time. Returning num tokens assuming gpt-3.5-turbo-0613.")
return count_chat_completion_tokens(messages, model="gpt-3.5-turbo-0613")
elif "gpt-4" in model:
# print("Warning: gpt-4 may update over time. Returning num tokens assuming gpt-4-0613.")
return count_chat_completion_tokens(messages, model="gpt-4-0613")
else:
raise NotImplementedError(
f"""count_chat_completion_tokens() is not implemented for model {model}. See https://github.com/openai/openai-python/blob/main/chatml.md for information on how messages are converted to tokens."""
)
num_tokens = 0
for message in messages:
num_tokens += tokens_per_message
for key, value in message.items():
num_tokens += len(encoding.encode(value))
if key == "name":
num_tokens += tokens_per_name
num_tokens += 3 # every reply is primed with <|start|>assistant<|message|>
return num_tokens
def count_completion_tokens(messages: list, model: str) -> int:
try:
encoding = tiktoken.encoding_for_model(model)
except KeyError:
print("Warning: model not found. Using cl100k_base encoding.")
encoding = tiktoken.get_encoding("cl100k_base")
text = ""
for message in messages:
text += message['role'] + message['content'] + "\n"
text += "assistant: "
return len(encoding.encode(text))
def count_tokens(messages: list, model: str):
if model in CHAT_COMPLETION_MODELS:
return count_chat_completion_tokens(messages, model)
elif model in COMPLETION_MODELS:
return count_completion_tokens(messages, model)
raise ValueError("不支持模型[{}],请检查配置文件".format(model))

View File

@@ -1,504 +0,0 @@
"""主线使用的会话管理模块
每个人、每个群单独一个sessionsession内部保留了对话的上下文
"""
import logging
import threading
import time
import json
from ..openai import manager as openai_manager
from ..openai import modelmgr as openai_modelmgr
from ..database import manager as database_manager
from ..utils import context as context
from ..plugin import host as plugin_host
from ..plugin import models as plugin_models
# 运行时保存的所有session
sessions = {}
class SessionOfflineStatus:
ON_GOING = 'on_going'
EXPLICITLY_CLOSED = 'explicitly_closed'
# 从数据加载session
def load_sessions():
"""从数据库加载sessions"""
global sessions
db_inst = context.get_database_manager()
session_data = db_inst.load_valid_sessions()
for session_name in session_data:
logging.debug('加载session: {}'.format(session_name))
temp_session = Session(session_name)
temp_session.name = session_name
temp_session.create_timestamp = session_data[session_name]['create_timestamp']
temp_session.last_interact_timestamp = session_data[session_name]['last_interact_timestamp']
temp_session.prompt = json.loads(session_data[session_name]['prompt'])
temp_session.token_counts = json.loads(session_data[session_name]['token_counts'])
temp_session.default_prompt = json.loads(session_data[session_name]['default_prompt']) if \
session_data[session_name]['default_prompt'] else []
sessions[session_name] = temp_session
# 获取指定名称的session如果不存在则创建一个新的
def get_session(session_name: str) -> 'Session':
global sessions
if session_name not in sessions:
sessions[session_name] = Session(session_name)
return sessions[session_name]
def dump_session(session_name: str):
global sessions
if session_name in sessions:
assert isinstance(sessions[session_name], Session)
sessions[session_name].persistence()
del sessions[session_name]
# 通用的OpenAI API交互session
# session内部保留了对话的上下文
# 收到用户消息后将上下文提交给OpenAI API生成回复
class Session:
name = ''
prompt = []
"""使用list来保存会话中的回合"""
default_prompt = []
"""本session的默认prompt"""
create_timestamp = 0
"""会话创建时间"""
last_interact_timestamp = 0
"""上次交互(产生回复)时间"""
just_switched_to_exist_session = False
response_lock = None
# 加锁
def acquire_response_lock(self):
logging.debug('{},lock acquire,{}'.format(self.name, self.response_lock))
self.response_lock.acquire()
logging.debug('{},lock acquire successfully,{}'.format(self.name, self.response_lock))
# 释放锁
def release_response_lock(self):
if self.response_lock.locked():
logging.debug('{},lock release,{}'.format(self.name, self.response_lock))
self.response_lock.release()
logging.debug('{},lock release successfully,{}'.format(self.name, self.response_lock))
# 从配置文件获取会话预设信息
def get_default_prompt(self, use_default: str = None):
import pkg.openai.dprompt as dprompt
if use_default is None:
use_default = dprompt.mode_inst().get_using_name()
current_default_prompt, _ = dprompt.mode_inst().get_prompt(use_default)
return current_default_prompt
def __init__(self, name: str):
self.name = name
self.create_timestamp = int(time.time())
self.last_interact_timestamp = int(time.time())
self.prompt = []
self.token_counts = []
self.schedule()
self.response_lock = threading.Lock()
self.default_prompt = self.get_default_prompt()
logging.debug("prompt is: {}".format(self.default_prompt))
# 设定检查session最后一次对话是否超过过期时间的计时器
def schedule(self):
threading.Thread(target=self.expire_check_timer_loop, args=(self.create_timestamp,)).start()
# 检查session是否已经过期
def expire_check_timer_loop(self, create_timestamp: int):
global sessions
while True:
time.sleep(60)
# 不是此session已更换退出
if self.create_timestamp != create_timestamp or self not in sessions.values():
return
config = context.get_config_manager().data
if int(time.time()) - self.last_interact_timestamp > config['session_expire_time']:
logging.info('session {} 已过期'.format(self.name))
# 触发插件事件
args = {
'session_name': self.name,
'session': self,
'session_expire_time': config['session_expire_time']
}
event = plugin_host.emit(plugin_models.SessionExpired, **args)
if event.is_prevented_default():
return
self.reset(expired=True, schedule_new=False)
# 删除此session
del sessions[self.name]
return
# 请求回复
# 这个函数是阻塞的
def query(self, text: str=None) -> tuple[str, str, list[str]]:
"""向session中添加一条消息返回接口回复
Args:
text (str): 用户消息
Returns:
tuple[str, str]: (接口回复, finish_reason, 已调用的函数列表)
"""
self.last_interact_timestamp = int(time.time())
# 触发插件事件
if not self.prompt:
args = {
'session_name': self.name,
'session': self,
'default_prompt': self.default_prompt,
}
event = plugin_host.emit(plugin_models.SessionFirstMessageReceived, **args)
if event.is_prevented_default():
return None, None, None
config = context.get_config_manager().data
max_length = config['prompt_submit_length']
local_default_prompt = self.default_prompt.copy()
local_prompt = self.prompt.copy()
# 触发PromptPreProcessing事件
args = {
'session_name': self.name,
'default_prompt': self.default_prompt,
'prompt': self.prompt,
'text_message': text,
}
event = plugin_host.emit(plugin_models.PromptPreProcessing, **args)
if event.get_return_value('default_prompt') is not None:
local_default_prompt = event.get_return_value('default_prompt')
if event.get_return_value('prompt') is not None:
local_prompt = event.get_return_value('prompt')
if event.get_return_value('text_message') is not None:
text = event.get_return_value('text_message')
# 裁剪messages到合适长度
prompts, _ = self.cut_out(text, max_length, local_default_prompt, local_prompt)
res_text = ""
pending_msgs = []
total_tokens = 0
finish_reason: str = ""
funcs = []
trace_func_calls = config['trace_function_calls']
botmgr = context.get_qqbot_manager()
session_name_spt: list[str] = self.name.split("_")
pending_res_text = ""
start_time = time.time()
# TODO 对不起,我知道这样非常非常屎山,但我之后会重构的
for resp in context.get_openai_manager().request_completion(prompts):
if pending_res_text != "":
botmgr.adapter.send_message(
session_name_spt[0],
session_name_spt[1],
pending_res_text
)
pending_res_text = ""
finish_reason = resp['choices'][0]['finish_reason']
if resp['choices'][0]['message']['role'] == "assistant" and resp['choices'][0]['message']['content'] != None: # 包含纯文本响应
if not trace_func_calls:
res_text += resp['choices'][0]['message']['content']
else:
res_text = resp['choices'][0]['message']['content']
pending_res_text = resp['choices'][0]['message']['content']
total_tokens += resp['usage']['total_tokens']
msg = {
"role": "assistant",
"content": resp['choices'][0]['message']['content']
}
if 'function_call' in resp['choices'][0]['message']:
msg['function_call'] = json.dumps(resp['choices'][0]['message']['function_call'])
pending_msgs.append(msg)
if resp['choices'][0]['message']['type'] == 'function_call':
# self.prompt.append(
# {
# "role": "assistant",
# "content": "function call: "+json.dumps(resp['choices'][0]['message']['function_call'])
# }
# )
if trace_func_calls:
botmgr.adapter.send_message(
session_name_spt[0],
session_name_spt[1],
"调用函数 "+resp['choices'][0]['message']['function_call']['name'] + "..."
)
total_tokens += resp['usage']['total_tokens']
elif resp['choices'][0]['message']['type'] == 'function_return':
# self.prompt.append(
# {
# "role": "function",
# "name": resp['choices'][0]['message']['function_name'],
# "content": json.dumps(resp['choices'][0]['message']['content'])
# }
# )
# total_tokens += resp['usage']['total_tokens']
funcs.append(
resp['choices'][0]['message']['function_name']
)
pass
# 向API请求补全
# message, total_token = pkg.utils.context.get_openai_manager().request_completion(
# prompts,
# )
# 成功获取,处理回复
# res_test = message
res_ans = res_text.strip()
# 将此次对话的双方内容加入到prompt中
# self.prompt.append({'role': 'user', 'content': text})
# self.prompt.append({'role': 'assistant', 'content': res_ans})
if text:
self.prompt.append({'role': 'user', 'content': text})
# 添加pending_msgs
self.prompt += pending_msgs
# 向token_counts中添加本回合的token数量
# self.token_counts.append(total_tokens-total_token_before_query)
# logging.debug("本回合使用token: {}, session counts: {}".format(total_tokens-total_token_before_query, self.token_counts))
if self.just_switched_to_exist_session:
self.just_switched_to_exist_session = False
self.set_ongoing()
# 上报使用量数据
session_type = session_name_spt[0]
session_id = session_name_spt[1]
ability_provider = "QChatGPT.Text"
usage = total_tokens
model_name = context.get_config_manager().data['completion_api_params']['model']
response_seconds = int(time.time() - start_time)
retry_times = -1 # 暂不记录
context.get_center_v2_api().usage.post_query_record(
session_type=session_type,
session_id=session_id,
query_ability_provider=ability_provider,
usage=usage,
model_name=model_name,
response_seconds=response_seconds,
retry_times=retry_times
)
return res_ans if res_ans[0] != '\n' else res_ans[1:], finish_reason, funcs
# 删除上一回合并返回上一回合的问题
def undo(self) -> str:
self.last_interact_timestamp = int(time.time())
# 删除最后两个消息
if len(self.prompt) < 2:
raise Exception('之前无对话,无法撤销')
question = self.prompt[-2]['content']
self.prompt = self.prompt[:-2]
self.token_counts = self.token_counts[:-1]
# 返回上一回合的问题
return question
# 构建对话体
def cut_out(self, msg: str, max_tokens: int, default_prompt: list, prompt: list) -> tuple[list, list]:
"""将现有prompt进行切割处理使得新的prompt长度不超过max_tokens
:return: (新的prompt, 新的token_counts)
"""
# 最终由三个部分组成
# - default_prompt 情景预设固定值
# - changable_prompts 可变部分, 此会话中的历史对话回合
# - current_question 当前问题
# 包装目前的对话回合内容
changable_prompts = []
use_model = context.get_config_manager().data['completion_api_params']['model']
ptr = len(prompt) - 1
# 直接从后向前扫描拼接,不管是否是整回合
while ptr >= 0:
if openai_modelmgr.count_tokens(prompt[ptr:ptr+1]+changable_prompts, use_model) > max_tokens:
break
changable_prompts.insert(0, prompt[ptr])
ptr -= 1
# 将default_prompt和changable_prompts合并
result_prompt = default_prompt + changable_prompts
# 添加当前问题
if msg:
result_prompt.append(
{
'role': 'user',
'content': msg
}
)
logging.debug("cut_out: {}".format(json.dumps(result_prompt, ensure_ascii=False, indent=4)))
return result_prompt, openai_modelmgr.count_tokens(changable_prompts, use_model)
# 持久化session
def persistence(self):
if self.prompt == self.get_default_prompt():
return
db_inst = context.get_database_manager()
name_spt = self.name.split('_')
subject_type = name_spt[0]
subject_number = int(name_spt[1])
db_inst.persistence_session(subject_type, subject_number, self.create_timestamp, self.last_interact_timestamp,
json.dumps(self.prompt), json.dumps(self.default_prompt), json.dumps(self.token_counts))
# 重置session
def reset(self, explicit: bool = False, expired: bool = False, schedule_new: bool = True, use_prompt: str = None, persist: bool = False):
if self.prompt:
self.persistence()
if explicit:
# 触发插件事件
args = {
'session_name': self.name,
'session': self
}
# 此事件不支持阻止默认行为
_ = plugin_host.emit(plugin_models.SessionExplicitReset, **args)
context.get_database_manager().explicit_close_session(self.name, self.create_timestamp)
if expired:
context.get_database_manager().set_session_expired(self.name, self.create_timestamp)
if not persist: # 不要求保持default prompt
self.default_prompt = self.get_default_prompt(use_prompt)
self.prompt = []
self.token_counts = []
self.create_timestamp = int(time.time())
self.last_interact_timestamp = int(time.time())
self.just_switched_to_exist_session = False
# self.response_lock = threading.Lock()
if schedule_new:
self.schedule()
# 将本session的数据库状态设置为on_going
def set_ongoing(self):
context.get_database_manager().set_session_ongoing(self.name, self.create_timestamp)
# 切换到上一个session
def last_session(self):
last_one = context.get_database_manager().last_session(self.name, self.last_interact_timestamp)
if last_one is None:
return None
else:
self.persistence()
self.create_timestamp = last_one['create_timestamp']
self.last_interact_timestamp = last_one['last_interact_timestamp']
self.prompt = json.loads(last_one['prompt'])
self.token_counts = json.loads(last_one['token_counts'])
self.default_prompt = json.loads(last_one['default_prompt']) if last_one['default_prompt'] else []
self.just_switched_to_exist_session = True
return self
# 切换到下一个session
def next_session(self):
next_one = context.get_database_manager().next_session(self.name, self.last_interact_timestamp)
if next_one is None:
return None
else:
self.persistence()
self.create_timestamp = next_one['create_timestamp']
self.last_interact_timestamp = next_one['last_interact_timestamp']
self.prompt = json.loads(next_one['prompt'])
self.token_counts = json.loads(next_one['token_counts'])
self.default_prompt = json.loads(next_one['default_prompt']) if next_one['default_prompt'] else []
self.just_switched_to_exist_session = True
return self
def list_history(self, capacity: int = 10, page: int = 0):
return context.get_database_manager().list_history(self.name, capacity, page)
def delete_history(self, index: int) -> bool:
return context.get_database_manager().delete_history(self.name, index)
def delete_all_history(self) -> bool:
return context.get_database_manager().delete_all_history(self.name)
def draw_image(self, prompt: str):
return context.get_openai_manager().request_image(prompt)

View File

@@ -0,0 +1,45 @@
from __future__ import annotations
import re
from .. import stage, entities, stagemgr
from ...core import entities as core_entities
from ...config import manager as cfg_mgr
@stage.stage_class('BanSessionCheckStage')
class BanSessionCheckStage(stage.PipelineStage):
async def initialize(self):
pass
async def process(
self,
query: core_entities.Query,
stage_inst_name: str
) -> entities.StageProcessResult:
found = False
mode = self.ap.pipeline_cfg.data['access-control']['mode']
sess_list = self.ap.pipeline_cfg.data['access-control'][mode]
if (query.launcher_type == 'group' and 'group_*' in sess_list) \
or (query.launcher_type == 'person' and 'person_*' in sess_list):
found = True
else:
for sess in sess_list:
if sess == f"{query.launcher_type}_{query.launcher_id}":
found = True
break
result = False
if mode == 'blacklist':
result = found
return entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE if not result else entities.ResultType.INTERRUPT,
new_query=query,
debug_notice=f'根据访问控制忽略消息: {query.launcher_type}_{query.launcher_id}' if result else ''
)

View File

@@ -0,0 +1,133 @@
from __future__ import annotations
import mirai
from ...core import app
from .. import stage, entities, stagemgr
from ...core import entities as core_entities
from ...config import manager as cfg_mgr
from . import filter, entities as filter_entities
from .filters import cntignore, banwords, baiduexamine
@stage.stage_class('PostContentFilterStage')
@stage.stage_class('PreContentFilterStage')
class ContentFilterStage(stage.PipelineStage):
filter_chain: list[filter.ContentFilter]
def __init__(self, ap: app.Application):
self.filter_chain = []
super().__init__(ap)
async def initialize(self):
self.filter_chain.append(cntignore.ContentIgnore(self.ap))
if self.ap.pipeline_cfg.data['check-sensitive-words']:
self.filter_chain.append(banwords.BanWordFilter(self.ap))
if self.ap.pipeline_cfg.data['baidu-cloud-examine']['enable']:
self.filter_chain.append(baiduexamine.BaiduCloudExamine(self.ap))
for filter in self.filter_chain:
await filter.initialize()
async def _pre_process(
self,
message: str,
query: core_entities.Query,
) -> entities.StageProcessResult:
"""请求llm前处理消息
只要有一个不通过就不放行,只放行 PASS 的消息
"""
if not self.ap.pipeline_cfg.data['income-msg-check']:
return entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE,
new_query=query
)
else:
for filter in self.filter_chain:
if filter_entities.EnableStage.PRE in filter.enable_stages:
result = await filter.process(message)
if result.level in [
filter_entities.ResultLevel.BLOCK,
filter_entities.ResultLevel.MASKED
]:
return entities.StageProcessResult(
result_type=entities.ResultType.INTERRUPT,
new_query=query,
user_notice=result.user_notice,
console_notice=result.console_notice
)
elif result.level == filter_entities.ResultLevel.PASS: # 传到下一个
message = result.replacement
query.message_chain = mirai.MessageChain(
mirai.Plain(message)
)
return entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE,
new_query=query
)
async def _post_process(
self,
message: str,
query: core_entities.Query,
) -> entities.StageProcessResult:
"""请求llm后处理响应
只要是 PASS 或者 MASKED 的就通过此 filter将其 replacement 设置为message进入下一个 filter
"""
if message is None:
return entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE,
new_query=query
)
else:
message = message.strip()
for filter in self.filter_chain:
if filter_entities.EnableStage.POST in filter.enable_stages:
result = await filter.process(message)
if result.level == filter_entities.ResultLevel.BLOCK:
return entities.StageProcessResult(
result_type=entities.ResultType.INTERRUPT,
new_query=query,
user_notice=result.user_notice,
console_notice=result.console_notice
)
elif result.level in [
filter_entities.ResultLevel.PASS,
filter_entities.ResultLevel.MASKED
]:
message = result.replacement
query.resp_messages[-1].content = message
return entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE,
new_query=query
)
async def process(
self,
query: core_entities.Query,
stage_inst_name: str
) -> entities.StageProcessResult:
"""处理
"""
if stage_inst_name == 'PreContentFilterStage':
return await self._pre_process(
str(query.message_chain).strip(),
query
)
elif stage_inst_name == 'PostContentFilterStage':
return await self._post_process(
query.resp_messages[-1].content,
query
)
else:
raise ValueError(f'未知的 stage_inst_name: {stage_inst_name}')

View File

@@ -0,0 +1,64 @@
import typing
import enum
import pydantic
class ResultLevel(enum.Enum):
"""结果等级"""
PASS = enum.auto()
"""通过"""
WARN = enum.auto()
"""警告"""
MASKED = enum.auto()
"""已掩去"""
BLOCK = enum.auto()
"""阻止"""
class EnableStage(enum.Enum):
"""启用阶段"""
PRE = enum.auto()
"""预处理"""
POST = enum.auto()
"""后处理"""
class FilterResult(pydantic.BaseModel):
level: ResultLevel
replacement: str
"""替换后的消息"""
user_notice: str
"""不通过时,用户提示消息"""
console_notice: str
"""不通过时,控制台提示消息"""
class ManagerResultLevel(enum.Enum):
"""处理器结果等级"""
CONTINUE = enum.auto()
"""继续"""
INTERRUPT = enum.auto()
"""中断"""
class FilterManagerResult(pydantic.BaseModel):
level: ManagerResultLevel
replacement: str
"""替换后的消息"""
user_notice: str
"""用户提示消息"""
console_notice: str
"""控制台提示消息"""

View File

@@ -0,0 +1,34 @@
# 内容过滤器的抽象类
from __future__ import annotations
import abc
from ...core import app
from . import entities
class ContentFilter(metaclass=abc.ABCMeta):
ap: app.Application
def __init__(self, ap: app.Application):
self.ap = ap
@property
def enable_stages(self):
"""启用的阶段
"""
return [
entities.EnableStage.PRE,
entities.EnableStage.POST
]
async def initialize(self):
"""初始化过滤器
"""
pass
@abc.abstractmethod
async def process(self, message: str) -> entities.FilterResult:
"""处理消息
"""
raise NotImplementedError

View File

@@ -0,0 +1,61 @@
from __future__ import annotations
import aiohttp
from .. import entities
from .. import filter as filter_model
BAIDU_EXAMINE_URL = "https://aip.baidubce.com/rest/2.0/solution/v1/text_censor/v2/user_defined?access_token={}"
BAIDU_EXAMINE_TOKEN_URL = "https://aip.baidubce.com/oauth/2.0/token"
class BaiduCloudExamine(filter_model.ContentFilter):
"""百度云内容审核"""
async def _get_token(self) -> str:
async with aiohttp.ClientSession() as session:
async with session.post(
BAIDU_EXAMINE_TOKEN_URL,
params={
"grant_type": "client_credentials",
"client_id": self.ap.pipeline_cfg.data['baidu-cloud-examine']['api-key'],
"client_secret": self.ap.pipeline_cfg.data['baidu-cloud-examine']['api-secret']
}
) as resp:
return (await resp.json())['access_token']
async def process(self, message: str) -> entities.FilterResult:
async with aiohttp.ClientSession() as session:
async with session.post(
BAIDU_EXAMINE_URL.format(await self._get_token()),
headers={'Content-Type': 'application/x-www-form-urlencoded', 'Accept': 'application/json'},
data=f"text={message}".encode('utf-8')
) as resp:
result = await resp.json()
if "error_code" in result:
return entities.FilterResult(
level=entities.ResultLevel.BLOCK,
replacement=message,
user_notice='',
console_notice=f"百度云判定出错,错误信息:{result['error_msg']}"
)
else:
conclusion = result["conclusion"]
if conclusion in ("合规"):
return entities.FilterResult(
level=entities.ResultLevel.PASS,
replacement=message,
user_notice='',
console_notice=f"百度云判定结果:{conclusion}"
)
else:
return entities.FilterResult(
level=entities.ResultLevel.BLOCK,
replacement=message,
user_notice="消息中存在不合适的内容, 请修改",
console_notice=f"百度云判定结果:{conclusion}"
)

View File

@@ -0,0 +1,44 @@
from __future__ import annotations
import re
from .. import filter as filter_model
from .. import entities
from ....config import manager as cfg_mgr
class BanWordFilter(filter_model.ContentFilter):
"""根据内容禁言"""
sensitive: cfg_mgr.ConfigManager
async def initialize(self):
self.sensitive = await cfg_mgr.load_json_config(
"data/config/sensitive-words.json",
"templates/sensitive-words.json"
)
async def process(self, message: str) -> entities.FilterResult:
found = False
for word in self.sensitive.data['words']:
match = re.findall(word, message)
if len(match) > 0:
found = True
for i in range(len(match)):
if self.sensitive.data['mask_word'] == "":
message = message.replace(
match[i], self.sensitive.data['mask'] * len(match[i])
)
else:
message = message.replace(
match[i], self.sensitive.data['mask_word']
)
return entities.FilterResult(
level=entities.ResultLevel.MASKED if found else entities.ResultLevel.PASS,
replacement=message,
user_notice='消息中存在不合适的内容, 请修改' if found else '',
console_notice=''
)

View File

@@ -0,0 +1,43 @@
from __future__ import annotations
import re
from .. import entities
from .. import filter as filter_model
class ContentIgnore(filter_model.ContentFilter):
"""根据内容忽略消息"""
@property
def enable_stages(self):
return [
entities.EnableStage.PRE,
]
async def process(self, message: str) -> entities.FilterResult:
if 'prefix' in self.ap.pipeline_cfg.data['ignore-rules']:
for rule in self.ap.pipeline_cfg.data['ignore-rules']['prefix']:
if message.startswith(rule):
return entities.FilterResult(
level=entities.ResultLevel.BLOCK,
replacement='',
user_notice='',
console_notice='根据 ignore_rules 中的 prefix 规则,忽略消息'
)
if 'regexp' in self.ap.pipeline_cfg.data['ignore-rules']:
for rule in self.ap.pipeline_cfg.data['ignore-rules']['regexp']:
if re.search(rule, message):
return entities.FilterResult(
level=entities.ResultLevel.BLOCK,
replacement='',
user_notice='',
console_notice='根据 ignore_rules 中的 regexp 规则,忽略消息'
)
return entities.FilterResult(
level=entities.ResultLevel.PASS,
replacement=message,
user_notice='',
console_notice=''
)

161
pkg/pipeline/controller.py Normal file
View File

@@ -0,0 +1,161 @@
from __future__ import annotations
import asyncio
import typing
import traceback
from ..core import app, entities
from . import entities as pipeline_entities
from ..plugin import events
class Controller:
"""总控制器
"""
ap: app.Application
semaphore: asyncio.Semaphore = None
"""请求并发控制信号量"""
def __init__(self, ap: app.Application):
self.ap = ap
self.semaphore = asyncio.Semaphore(self.ap.system_cfg.data['pipeline-concurrency'])
async def consumer(self):
"""事件处理循环
"""
try:
while True:
selected_query: entities.Query = None
# 取请求
async with self.ap.query_pool:
queries: list[entities.Query] = self.ap.query_pool.queries
for query in queries:
session = await self.ap.sess_mgr.get_session(query)
self.ap.logger.debug(f"Checking query {query} session {session}")
if not session.semaphore.locked():
selected_query = query
await session.semaphore.acquire()
break
if selected_query: # 找到了
queries.remove(selected_query)
else: # 没找到 说明:没有请求 或者 所有query对应的session都已达到并发上限
await self.ap.query_pool.condition.wait()
continue
if selected_query:
async def _process_query(selected_query):
async with self.semaphore: # 总并发上限
await self.process_query(selected_query)
async with self.ap.query_pool:
(await self.ap.sess_mgr.get_session(selected_query)).semaphore.release()
# 通知其他协程,有新的请求可以处理了
self.ap.query_pool.condition.notify_all()
asyncio.create_task(_process_query(selected_query))
except Exception as e:
# traceback.print_exc()
self.ap.logger.error(f"控制器循环出错: {e}")
self.ap.logger.debug(f"Traceback: {traceback.format_exc()}")
async def _check_output(self, query: entities.Query, result: pipeline_entities.StageProcessResult):
"""检查输出
"""
if result.user_notice:
await self.ap.im_mgr.send(
query.message_event,
result.user_notice,
query.adapter
)
if result.debug_notice:
self.ap.logger.debug(result.debug_notice)
if result.console_notice:
self.ap.logger.info(result.console_notice)
if result.error_notice:
self.ap.logger.error(result.error_notice)
async def _execute_from_stage(
self,
stage_index: int,
query: entities.Query,
):
"""从指定阶段开始执行
如何看懂这里为什么这么写?
去问 GPT-4:
Q1: 现在有一个责任链其中有多个stagequery对象在其中传递stage.process可能返回Result也有可能返回typing.AsyncGenerator[Result, None]
如果返回的是生成器需要挨个生成result检查是否result中是否要求继续如果要求继续就进行下一个stage。如果此次生成器产生的result处理完了就继续生成下一个result
调用后续的stage直到该生成器全部生成完。责任链中可能有多个stage会返回生成器
Q2: 不是这样的你可能理解有误。如果我们责任链上有这些Stage
A B C D E F G
如果所有的stage都返回Result且所有Result都要求继续那么执行顺序是
A B C D E F G
现在假设C返回的是AsyncGenerator那么执行顺序是
A B C D E F G C D E F G C D E F G ...
Q3: 但是如果不止一个stage会返回生成器呢
"""
i = stage_index
while i < len(self.ap.stage_mgr.stage_containers):
stage_container = self.ap.stage_mgr.stage_containers[i]
result = stage_container.inst.process(query, stage_container.inst_name)
if isinstance(result, typing.Coroutine):
result = await result
if isinstance(result, pipeline_entities.StageProcessResult): # 直接返回结果
self.ap.logger.debug(f"Stage {stage_container.inst_name} processed query {query} res {result}")
await self._check_output(query, result)
if result.result_type == pipeline_entities.ResultType.INTERRUPT:
self.ap.logger.debug(f"Stage {stage_container.inst_name} interrupted query {query}")
break
elif result.result_type == pipeline_entities.ResultType.CONTINUE:
query = result.new_query
elif isinstance(result, typing.AsyncGenerator): # 生成器
self.ap.logger.debug(f"Stage {stage_container.inst_name} processed query {query} gen")
async for sub_result in result:
self.ap.logger.debug(f"Stage {stage_container.inst_name} processed query {query} res {sub_result}")
await self._check_output(query, sub_result)
if sub_result.result_type == pipeline_entities.ResultType.INTERRUPT:
self.ap.logger.debug(f"Stage {stage_container.inst_name} interrupted query {query}")
break
elif sub_result.result_type == pipeline_entities.ResultType.CONTINUE:
query = sub_result.new_query
await self._execute_from_stage(i + 1, query)
break
i += 1
async def process_query(self, query: entities.Query):
"""处理请求
"""
self.ap.logger.debug(f"Processing query {query}")
try:
await self._execute_from_stage(0, query)
except Exception as e:
self.ap.logger.error(f"处理请求时出错 query_id={query.query_id}: {e}")
self.ap.logger.debug(f"Traceback: {traceback.format_exc()}")
# traceback.print_exc()
finally:
self.ap.logger.debug(f"Query {query} processed")
async def run(self):
"""运行控制器
"""
await self.consumer()

40
pkg/pipeline/entities.py Normal file
View File

@@ -0,0 +1,40 @@
from __future__ import annotations
import enum
import typing
import pydantic
import mirai
import mirai.models.message as mirai_message
from ..core import entities
class ResultType(enum.Enum):
CONTINUE = enum.auto()
"""继续流水线"""
INTERRUPT = enum.auto()
"""中断流水线"""
class StageProcessResult(pydantic.BaseModel):
result_type: ResultType
new_query: entities.Query
user_notice: typing.Optional[typing.Union[str, list[mirai_message.MessageComponent], mirai.MessageChain, None]] = []
"""只要设置了就会发送给用户"""
# TODO delete
# admin_notice: typing.Optional[typing.Union[str, list[mirai_message.MessageComponent], mirai.MessageChain, None]] = []
"""只要设置了就会发送给管理员"""
console_notice: typing.Optional[str] = ''
"""只要设置了就会输出到控制台"""
debug_notice: typing.Optional[str] = ''
error_notice: typing.Optional[str] = ''

View File

@@ -0,0 +1,59 @@
from __future__ import annotations
import os
import traceback
from PIL import Image, ImageDraw, ImageFont
from mirai.models.message import MessageComponent, Plain, MessageChain
from ...core import app
from . import strategy
from .strategies import image, forward
from .. import stage, entities, stagemgr
from ...core import entities as core_entities
from ...config import manager as cfg_mgr
@stage.stage_class("LongTextProcessStage")
class LongTextProcessStage(stage.PipelineStage):
strategy_impl: strategy.LongTextStrategy
async def initialize(self):
config = self.ap.platform_cfg.data['long-text-process']
if config['strategy'] == 'image':
use_font = config['font-path']
try:
# 检查是否存在
if not os.path.exists(use_font):
# 若是windows系统使用微软雅黑
if os.name == "nt":
use_font = "C:/Windows/Fonts/msyh.ttc"
if not os.path.exists(use_font):
self.ap.logger.warn("未找到字体文件且无法使用Windows自带字体更换为转发消息组件以发送长消息您可以在config.py中调整相关设置。")
config['blob_message_strategy'] = "forward"
else:
self.ap.logger.info("使用Windows自带字体" + use_font)
config['font-path'] = use_font
else:
self.ap.logger.warn("未找到字体文件且无法使用系统自带字体更换为转发消息组件以发送长消息您可以在config.py中调整相关设置。")
self.ap.platform_cfg.data['long-text-process']['strategy'] = "forward"
except:
traceback.print_exc()
self.ap.logger.error("加载字体文件失败({})更换为转发消息组件以发送长消息您可以在config.py中调整相关设置。".format(use_font))
self.ap.platform_cfg.data['long-text-process']['strategy'] = "forward"
if config['strategy'] == 'image':
self.strategy_impl = image.Text2ImageStrategy(self.ap)
elif config['strategy'] == 'forward':
self.strategy_impl = forward.ForwardComponentStrategy(self.ap)
await self.strategy_impl.initialize()
async def process(self, query: core_entities.Query, stage_inst_name: str) -> entities.StageProcessResult:
if len(str(query.resp_message_chain)) > self.ap.platform_cfg.data['long-text-process']['threshold']:
query.resp_message_chain = MessageChain(await self.strategy_impl.process(str(query.resp_message_chain), query))
return entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE,
new_query=query
)

View File

@@ -0,0 +1,63 @@
# 转发消息组件
from __future__ import annotations
import typing
from mirai.models import MessageChain
from mirai.models.message import MessageComponent, ForwardMessageNode
from mirai.models.base import MiraiBaseModel
from .. import strategy as strategy_model
from ....core import entities as core_entities
class ForwardMessageDiaplay(MiraiBaseModel):
title: str = "群聊的聊天记录"
brief: str = "[聊天记录]"
source: str = "聊天记录"
preview: typing.List[str] = []
summary: str = "查看x条转发消息"
class Forward(MessageComponent):
"""合并转发。"""
type: str = "Forward"
"""消息组件类型。"""
display: ForwardMessageDiaplay
"""显示信息"""
node_list: typing.List[ForwardMessageNode]
"""转发消息节点列表。"""
def __init__(self, *args, **kwargs):
if len(args) == 1:
self.node_list = args[0]
super().__init__(**kwargs)
super().__init__(*args, **kwargs)
def __str__(self):
return '[聊天记录]'
class ForwardComponentStrategy(strategy_model.LongTextStrategy):
async def process(self, message: str, query: core_entities.Query) -> list[MessageComponent]:
display = ForwardMessageDiaplay(
title="群聊的聊天记录",
brief="[聊天记录]",
source="聊天记录",
preview=["QQ用户: "+message],
summary="查看1条转发消息"
)
node_list = [
ForwardMessageNode(
sender_id=query.adapter.bot_account_id,
sender_name='QQ用户',
message_chain=MessageChain([message])
)
]
forward = Forward(
display=display,
node_list=node_list
)
return [forward]

View File

@@ -0,0 +1,198 @@
from __future__ import annotations
import typing
import os
import base64
import time
import re
from PIL import Image, ImageDraw, ImageFont
from mirai.models import MessageChain, Image as ImageComponent
from mirai.models.message import MessageComponent
from .. import strategy as strategy_model
from ....core import entities as core_entities
class Text2ImageStrategy(strategy_model.LongTextStrategy):
text_render_font: ImageFont.FreeTypeFont
async def initialize(self):
self.text_render_font = ImageFont.truetype(self.ap.platform_cfg.data['long-text-process']['font-path'], 32, encoding="utf-8")
async def process(self, message: str, query: core_entities.Query) -> list[MessageComponent]:
img_path = self.text_to_image(
text_str=message,
save_as='temp/{}.png'.format(int(time.time()))
)
compressed_path, size = self.compress_image(
img_path,
outfile="temp/{}_compressed.png".format(int(time.time()))
)
with open(compressed_path, 'rb') as f:
img = f.read()
b64 = base64.b64encode(img)
# 删除图片
os.remove(img_path)
if os.path.exists(compressed_path):
os.remove(compressed_path)
return [
ImageComponent(
base64=b64.decode('utf-8'),
)
]
def indexNumber(self, path=''):
"""
查找字符串中数字所在串中的位置
:param path:目标字符串
:return:<class 'list'>: <class 'list'>: [['1', 16], ['2', 35], ['1', 51]]
"""
kv = []
nums = []
beforeDatas = re.findall('[\d]+', path)
for num in beforeDatas:
indexV = []
times = path.count(num)
if times > 1:
if num not in nums:
indexs = re.finditer(num, path)
for index in indexs:
iV = []
i = index.span()[0]
iV.append(num)
iV.append(i)
kv.append(iV)
nums.append(num)
else:
index = path.find(num)
indexV.append(num)
indexV.append(index)
kv.append(indexV)
# 根据数字位置排序
indexSort = []
resultIndex = []
for vi in kv:
indexSort.append(vi[1])
indexSort.sort()
for i in indexSort:
for v in kv:
if i == v[1]:
resultIndex.append(v)
return resultIndex
def get_size(self, file):
# 获取文件大小:KB
size = os.path.getsize(file)
return size / 1024
def get_outfile(self, infile, outfile):
if outfile:
return outfile
dir, suffix = os.path.splitext(infile)
outfile = '{}-out{}'.format(dir, suffix)
return outfile
def compress_image(self, infile, outfile='', kb=100, step=20, quality=90):
"""不改变图片尺寸压缩到指定大小
:param infile: 压缩源文件
:param outfile: 压缩文件保存地址
:param mb: 压缩目标,KB
:param step: 每次调整的压缩比率
:param quality: 初始压缩比率
:return: 压缩文件地址,压缩文件大小
"""
o_size = self.get_size(infile)
if o_size <= kb:
return infile, o_size
outfile = self.get_outfile(infile, outfile)
while o_size > kb:
im = Image.open(infile)
im.save(outfile, quality=quality)
if quality - step < 0:
break
quality -= step
o_size = self.get_size(outfile)
return outfile, self.get_size(outfile)
def text_to_image(self, text_str: str, save_as="temp.png", width=800):
text_str = text_str.replace("\t", " ")
# 分行
lines = text_str.split('\n')
# 计算并分割
final_lines = []
text_width = width-80
self.ap.logger.debug("lines: {}, text_width: {}".format(lines, text_width))
for line in lines:
# 如果长了就分割
line_width = self.text_render_font.getlength(line)
self.ap.logger.debug("line_width: {}".format(line_width))
if line_width < text_width:
final_lines.append(line)
continue
else:
rest_text = line
while True:
# 分割最前面的一行
point = int(len(rest_text) * (text_width / line_width))
# 检查断点是否在数字中间
numbers = self.indexNumber(rest_text)
for number in numbers:
if number[1] < point < number[1] + len(number[0]) and number[1] != 0:
point = number[1]
break
final_lines.append(rest_text[:point])
rest_text = rest_text[point:]
line_width = self.text_render_font.getlength(rest_text)
if line_width < text_width:
final_lines.append(rest_text)
break
else:
continue
# 准备画布
img = Image.new('RGBA', (width, max(280, len(final_lines) * 35 + 65)), (255, 255, 255, 255))
draw = ImageDraw.Draw(img, mode='RGBA')
self.ap.logger.debug("正在绘制图片...")
# 绘制正文
line_number = 0
offset_x = 20
offset_y = 30
for final_line in final_lines:
draw.text((offset_x, offset_y + 35 * line_number), final_line, fill=(0, 0, 0), font=self.text_render_font)
# 遍历此行,检查是否有emoji
idx_in_line = 0
for ch in final_line:
# 检查字符占位宽
char_code = ord(ch)
if char_code >= 127:
idx_in_line += 1
else:
idx_in_line += 0.5
line_number += 1
self.ap.logger.debug("正在保存图片...")
img.save(save_as)
return save_as

View File

@@ -0,0 +1,23 @@
from __future__ import annotations
import abc
import typing
import mirai
from mirai.models.message import MessageComponent
from ...core import app
from ...core import entities as core_entities
class LongTextStrategy(metaclass=abc.ABCMeta):
ap: app.Application
def __init__(self, ap: app.Application):
self.ap = ap
async def initialize(self):
pass
@abc.abstractmethod
async def process(self, message: str, query: core_entities.Query) -> list[MessageComponent]:
return []

57
pkg/pipeline/pool.py Normal file
View File

@@ -0,0 +1,57 @@
from __future__ import annotations
import asyncio
import mirai
from ..core import entities
from ..platform import adapter as msadapter
class QueryPool:
query_id_counter: int = 0
pool_lock: asyncio.Lock
queries: list[entities.Query]
condition: asyncio.Condition
def __init__(self):
self.query_id_counter = 0
self.pool_lock = asyncio.Lock()
self.queries = []
self.condition = asyncio.Condition(self.pool_lock)
async def add_query(
self,
launcher_type: entities.LauncherTypes,
launcher_id: int,
sender_id: int,
message_event: mirai.MessageEvent,
message_chain: mirai.MessageChain,
adapter: msadapter.MessageSourceAdapter
) -> entities.Query:
async with self.condition:
query = entities.Query(
query_id=self.query_id_counter,
launcher_type=launcher_type,
launcher_id=launcher_id,
sender_id=sender_id,
message_event=message_event,
message_chain=message_chain,
resp_messages=[],
resp_message_chain=None,
adapter=adapter
)
self.queries.append(query)
self.query_id_counter += 1
self.condition.notify_all()
async def __aenter__(self):
await self.pool_lock.acquire()
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
self.pool_lock.release()

View File

@@ -0,0 +1,79 @@
from __future__ import annotations
from .. import stage, entities, stagemgr
from ...core import entities as core_entities
from ...provider import entities as llm_entities
from ...plugin import events
@stage.stage_class("PreProcessor")
class PreProcessor(stage.PipelineStage):
"""预处理器
"""
async def process(
self,
query: core_entities.Query,
stage_inst_name: str,
) -> entities.StageProcessResult:
"""处理
"""
session = await self.ap.sess_mgr.get_session(query)
conversation = await self.ap.sess_mgr.get_conversation(session)
# 从会话取出消息和情景预设到query
query.session = session
query.prompt = conversation.prompt.copy()
query.messages = conversation.messages.copy()
query.user_message = llm_entities.Message(
role='user',
content=str(query.message_chain).strip()
)
query.use_model = conversation.use_model
query.use_funcs = conversation.use_funcs
# =========== 触发事件 PromptPreProcessing
session = query.session
event_ctx = await self.ap.plugin_mgr.emit_event(
event=events.PromptPreProcessing(
session_name=f'{session.launcher_type.value}_{session.launcher_id}',
default_prompt=query.prompt.messages,
prompt=query.messages,
query=query
)
)
query.prompt.messages = event_ctx.event.default_prompt
query.messages = event_ctx.event.prompt
# 根据模型max_tokens剪裁
max_tokens = min(query.use_model.max_tokens, self.ap.pipeline_cfg.data['submit-messages-tokens'])
test_messages = query.prompt.messages + query.messages + [query.user_message]
while await query.use_model.tokenizer.count_token(test_messages, query.use_model) > max_tokens:
# 前文都pop完了还是大于max_tokens由于prompt和user_messages不能删减报错
if len(query.prompt.messages) == 0:
return entities.StageProcessResult(
result_type=entities.ResultType.INTERRUPT,
new_query=query,
user_notice='输入内容过长,请减少情景预设或者输入内容长度',
console_notice='输入内容过长,请减少情景预设或者输入内容长度,或者增大配置文件中的 submit-messages-tokens 项但不能超过所用模型最大tokens数'
)
query.messages.pop(0) # pop第一个肯定是role=user的
# 继续pop到第二个role=user前一个
while len(query.messages) > 0 and query.messages[0].role != 'user':
query.messages.pop(0)
test_messages = query.prompt.messages + query.messages + [query.user_message]
return entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE,
new_query=query
)

View File

@@ -0,0 +1,34 @@
from __future__ import annotations
import abc
from ...core import app
from ...core import entities as core_entities
from .. import entities
class MessageHandler(metaclass=abc.ABCMeta):
ap: app.Application
def __init__(self, ap: app.Application):
self.ap = ap
async def initialize(self):
pass
@abc.abstractmethod
async def handle(
self,
query: core_entities.Query,
) -> entities.StageProcessResult:
raise NotImplementedError
def cut_str(self, s: str) -> str:
"""
取字符串第一行最多20个字符若有多行或超过20个字符则加省略号
"""
s0 = s.split('\n')[0]
if len(s0) > 20 or '\n' in s:
s0 = s0[:20] + '...'
return s0

View File

@@ -0,0 +1,113 @@
from __future__ import annotations
import typing
import time
import traceback
import mirai
from .. import handler
from ... import entities
from ....core import entities as core_entities
from ....provider import entities as llm_entities
from ....plugin import events
class ChatMessageHandler(handler.MessageHandler):
async def handle(
self,
query: core_entities.Query,
) -> typing.AsyncGenerator[entities.StageProcessResult, None]:
"""处理
"""
# 取session
# 取conversation
# 调API
# 生成器
# 触发插件事件
event_class = events.PersonNormalMessageReceived if query.launcher_type == core_entities.LauncherTypes.PERSON else events.GroupNormalMessageReceived
event_ctx = await self.ap.plugin_mgr.emit_event(
event=event_class(
launcher_type=query.launcher_type.value,
launcher_id=query.launcher_id,
sender_id=query.sender_id,
text_message=str(query.message_chain),
query=query
)
)
if event_ctx.is_prevented_default():
if event_ctx.event.reply is not None:
query.resp_message_chain = mirai.MessageChain(event_ctx.event.reply)
yield entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE,
new_query=query
)
else:
yield entities.StageProcessResult(
result_type=entities.ResultType.INTERRUPT,
new_query=query
)
else:
if not self.ap.provider_cfg.data['enable-chat']:
yield entities.StageProcessResult(
result_type=entities.ResultType.INTERRUPT,
new_query=query,
)
if event_ctx.event.alter is not None:
query.message_chain = mirai.MessageChain([
mirai.Plain(event_ctx.event.alter)
])
query.messages.append(
query.user_message
)
text_length = 0
start_time = time.time()
try:
async for result in query.use_model.requester.request(query):
query.resp_messages.append(result)
self.ap.logger.info(f'对话({query.query_id})响应: {self.cut_str(result.readable_str())}')
if result.content is not None:
text_length += len(result.content)
yield entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE,
new_query=query
)
except Exception as e:
self.ap.logger.error(f'对话({query.query_id})请求失败: {str(e)}')
yield entities.StageProcessResult(
result_type=entities.ResultType.INTERRUPT,
new_query=query,
user_notice='请求失败' if self.ap.platform_cfg.data['hide-exception-info'] else f'{e}',
error_notice=f'{e}',
debug_notice=traceback.format_exc()
)
finally:
query.session.using_conversation.messages.append(query.user_message)
query.session.using_conversation.messages.extend(query.resp_messages)
await self.ap.ctr_mgr.usage.post_query_record(
session_type=query.session.launcher_type.value,
session_id=str(query.session.launcher_id),
query_ability_provider="QChatGPT.Chat",
usage=text_length,
model_name=query.use_model.name,
response_seconds=int(time.time() - start_time),
retry_times=-1,
)

View File

@@ -0,0 +1,121 @@
from __future__ import annotations
import typing
import mirai
from .. import handler
from ... import entities
from ....core import entities as core_entities
from ....provider import entities as llm_entities
from ....plugin import events
class CommandHandler(handler.MessageHandler):
async def handle(
self,
query: core_entities.Query,
) -> typing.AsyncGenerator[entities.StageProcessResult, None]:
"""处理
"""
event_class = events.PersonCommandSent if query.launcher_type == core_entities.LauncherTypes.PERSON else events.GroupCommandSent
privilege = 1
if f'{query.launcher_type.value}_{query.launcher_id}' in self.ap.system_cfg.data['admin-sessions']:
privilege = 2
spt = str(query.message_chain).strip().split(' ')
event_ctx = await self.ap.plugin_mgr.emit_event(
event=event_class(
launcher_type=query.launcher_type.value,
launcher_id=query.launcher_id,
sender_id=query.sender_id,
command=spt[0],
params=spt[1:] if len(spt) > 1 else [],
text_message=str(query.message_chain),
is_admin=(privilege==2),
query=query
)
)
if event_ctx.is_prevented_default():
if event_ctx.event.reply is not None:
mc = mirai.MessageChain(event_ctx.event.reply)
query.resp_messages.append(
llm_entities.Message(
role='command',
content=str(mc),
)
)
yield entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE,
new_query=query
)
else:
yield entities.StageProcessResult(
result_type=entities.ResultType.INTERRUPT,
new_query=query
)
else:
if event_ctx.event.alter is not None:
query.message_chain = mirai.MessageChain([
mirai.Plain(event_ctx.event.alter)
])
session = await self.ap.sess_mgr.get_session(query)
command_text = str(query.message_chain).strip()[1:]
async for ret in self.ap.cmd_mgr.execute(
command_text=command_text,
query=query,
session=session
):
if ret.error is not None:
# query.resp_message_chain = mirai.MessageChain([
# mirai.Plain(str(ret.error))
# ])
query.resp_messages.append(
llm_entities.Message(
role='command',
content=str(ret.error),
)
)
self.ap.logger.info(f'命令({query.query_id})报错: {self.cut_str(str(ret.error))}')
yield entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE,
new_query=query
)
elif ret.text is not None:
# query.resp_message_chain = mirai.MessageChain([
# mirai.Plain(ret.text)
# ])
query.resp_messages.append(
llm_entities.Message(
role='command',
content=ret.text,
)
)
self.ap.logger.info(f'命令返回: {self.cut_str(ret.text)}')
yield entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE,
new_query=query
)
else:
yield entities.StageProcessResult(
result_type=entities.ResultType.INTERRUPT,
new_query=query
)

View File

@@ -0,0 +1,45 @@
from __future__ import annotations
from ...core import app, entities as core_entities
from . import handler
from .handlers import chat, command
from .. import entities
from .. import stage, entities, stagemgr
from ...core import entities as core_entities
from ...config import manager as cfg_mgr
@stage.stage_class("MessageProcessor")
class Processor(stage.PipelineStage):
cmd_handler: handler.MessageHandler
chat_handler: handler.MessageHandler
async def initialize(self):
self.cmd_handler = command.CommandHandler(self.ap)
self.chat_handler = chat.ChatMessageHandler(self.ap)
await self.cmd_handler.initialize()
await self.chat_handler.initialize()
async def process(
self,
query: core_entities.Query,
stage_inst_name: str,
) -> entities.StageProcessResult:
"""处理
"""
message_text = str(query.message_chain).strip()
self.ap.logger.info(f"处理 {query.launcher_type.value}_{query.launcher_id} 的请求({query.query_id}): {message_text}")
async def generator():
if message_text.startswith('!') or message_text.startswith(''):
async for result in self.cmd_handler.handle(query):
yield result
else:
async for result in self.chat_handler.handle(query):
yield result
return generator()

View File

View File

@@ -0,0 +1,24 @@
from __future__ import annotations
import abc
from ...core import app
class ReteLimitAlgo(metaclass=abc.ABCMeta):
ap: app.Application
def __init__(self, ap: app.Application):
self.ap = ap
async def initialize(self):
pass
@abc.abstractmethod
async def require_access(self, launcher_type: str, launcher_id: int) -> bool:
raise NotImplementedError
@abc.abstractmethod
async def release_access(self, launcher_type: str, launcher_id: int):
raise NotImplementedError

Some files were not shown because too many files have changed in this diff Show More