mirror of
https://github.com/langbot-app/LangBot.git
synced 2025-11-25 03:15:06 +08:00
Compare commits
78 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
092bb0a1e2 | ||
|
|
2c3399e237 | ||
|
|
835275b47f | ||
|
|
7b060ce3f9 | ||
|
|
1fb69311b0 | ||
|
|
995d1f61d2 | ||
|
|
80258e9182 | ||
|
|
bd6a32e08e | ||
|
|
5f138de75b | ||
|
|
d0b0f2209a | ||
|
|
0752698c1d | ||
|
|
9855c6b8f5 | ||
|
|
52a7c25540 | ||
|
|
fa823de6b0 | ||
|
|
f53070d8b6 | ||
|
|
7677672691 | ||
|
|
dead8fa168 | ||
|
|
c6347bea45 | ||
|
|
32bd194bfc | ||
|
|
cca48a394d | ||
|
|
a723c8ce37 | ||
|
|
327b2509f6 | ||
|
|
1dae7bd655 | ||
|
|
550a131685 | ||
|
|
0cfb8bb29f | ||
|
|
9c32420a95 | ||
|
|
867093cc88 | ||
|
|
82763f8ec5 | ||
|
|
97449065df | ||
|
|
9489783846 | ||
|
|
f91c9015bc | ||
|
|
302d86056d | ||
|
|
98bebfddaa | ||
|
|
dab20e3187 | ||
|
|
09e72f7c5f | ||
|
|
2028d85f84 | ||
|
|
ed3c0d9014 | ||
|
|
be06150990 | ||
|
|
afb3fb4a31 | ||
|
|
d66577e6c3 | ||
|
|
6a4ea5446a | ||
|
|
74e84c744a | ||
|
|
5ad2446cf3 | ||
|
|
63303bb5c0 | ||
|
|
13393b6624 | ||
|
|
b9fa11c0c3 | ||
|
|
8c6ce1f030 | ||
|
|
1d963d0f0c | ||
|
|
0ee383be27 | ||
|
|
53d09129b4 | ||
|
|
a398c6f311 | ||
|
|
4347ddd42a | ||
|
|
22cb8a6a06 | ||
|
|
7f554fd862 | ||
|
|
a82bfa8a56 | ||
|
|
95784debbf | ||
|
|
2471c5bf0f | ||
|
|
2fe6d731b8 | ||
|
|
ce881372ee | ||
|
|
171ea7c375 | ||
|
|
1e9a6f813f | ||
|
|
39a7f3b2b9 | ||
|
|
8d375a02db | ||
|
|
cac8a0a414 | ||
|
|
c89623967e | ||
|
|
92aa9c1711 | ||
|
|
71f2a58acb | ||
|
|
1f07a8a9e3 | ||
|
|
cacd21bde7 | ||
|
|
a060ec66c3 | ||
|
|
fd10db3c75 | ||
|
|
db4c658980 | ||
|
|
0ee88674f8 | ||
|
|
3540759682 | ||
|
|
44cc8f15b4 | ||
|
|
59f821bf0a | ||
|
|
80858672b0 | ||
|
|
3258d5b255 |
8
.github/ISSUE_TEMPLATE/bug-report.yml
vendored
8
.github/ISSUE_TEMPLATE/bug-report.yml
vendored
@@ -16,11 +16,13 @@ body:
|
||||
required: true
|
||||
- type: dropdown
|
||||
attributes:
|
||||
label: 登录框架
|
||||
label: 消息平台适配器
|
||||
description: "连接QQ使用的框架"
|
||||
options:
|
||||
- Mirai
|
||||
- go-cqhttp
|
||||
- yiri-mirai(Mirai)
|
||||
- Nakuru(go-cqhttp)
|
||||
- aiocqhttp(使用 OneBot 协议接入的)
|
||||
- qq-botpy(QQ官方API)
|
||||
validations:
|
||||
required: false
|
||||
- type: input
|
||||
|
||||
42
.github/workflows/build-docker-image.yml
vendored
42
.github/workflows/build-docker-image.yml
vendored
@@ -17,22 +17,32 @@ jobs:
|
||||
run: |
|
||||
if [ -z "$GITHUB_REF" ]; then
|
||||
export GITHUB_REF=${{ github.ref }}
|
||||
echo $GITHUB_REF
|
||||
fi
|
||||
# - name: Check GITHUB_REF env
|
||||
# run: echo $GITHUB_REF
|
||||
# - name: Get version # 在 GitHub Actions 运行环境
|
||||
# id: get_version
|
||||
# if: (startsWith(env.GITHUB_REF, 'refs/tags/')||startsWith(github.ref, 'refs/tags/')) && startsWith(github.repository, 'RockChinQ/QChatGPT')
|
||||
# run: export GITHUB_REF=${GITHUB_REF/refs\/tags\//}
|
||||
- name: Check version
|
||||
id: check_version
|
||||
run: |
|
||||
echo $GITHUB_REF
|
||||
# 如果是tag,则去掉refs/tags/前缀
|
||||
if [[ $GITHUB_REF == refs/tags/* ]]; then
|
||||
echo "It's a tag"
|
||||
echo $GITHUB_REF
|
||||
echo $GITHUB_REF | awk -F '/' '{print $3}'
|
||||
echo ::set-output name=version::$(echo $GITHUB_REF | awk -F '/' '{print $3}')
|
||||
else
|
||||
echo "It's not a tag"
|
||||
echo $GITHUB_REF
|
||||
echo ::set-output name=version::${GITHUB_REF}
|
||||
fi
|
||||
- name: Check GITHUB_REF env
|
||||
run: echo $GITHUB_REF
|
||||
- name: Get version
|
||||
id: get_version
|
||||
if: (startsWith(env.GITHUB_REF, 'refs/tags/')||startsWith(github.ref, 'refs/tags/')) && startsWith(github.repository, 'RockChinQ/QChatGPT')
|
||||
run: echo ::set-output name=VERSION::${GITHUB_REF/refs\/tags\//}
|
||||
- name: Build # image name: rockchin/qchatgpt:<VERSION>
|
||||
run: docker build --network=host -t rockchin/qchatgpt:${{ steps.get_version.outputs.VERSION }} -t rockchin/qchatgpt:latest .
|
||||
- name: Login to Registry
|
||||
run: docker login --username=${{ secrets.DOCKER_USERNAME }} --password ${{ secrets.DOCKER_PASSWORD }}
|
||||
|
||||
- name: Push image
|
||||
if: (startsWith(env.GITHUB_REF, 'refs/tags/')||startsWith(github.ref, 'refs/tags/')) && startsWith(github.repository, 'RockChinQ/QChatGPT')
|
||||
run: docker push rockchin/qchatgpt:${{ steps.get_version.outputs.VERSION }}
|
||||
|
||||
- name: Push latest image
|
||||
if: (startsWith(env.GITHUB_REF, 'refs/tags/')||startsWith(github.ref, 'refs/tags/')) && startsWith(github.repository, 'RockChinQ/QChatGPT')
|
||||
run: docker push rockchin/qchatgpt:latest
|
||||
- name: Create Buildx
|
||||
run: docker buildx create --name mybuilder --use
|
||||
- name: Build # image name: rockchin/qchatgpt:<VERSION>
|
||||
run: docker buildx build --platform linux/arm64,linux/amd64 -t rockchin/qchatgpt:${{ steps.check_version.outputs.version }} -t rockchin/qchatgpt:latest . --push
|
||||
|
||||
@@ -3,6 +3,9 @@ WORKDIR /app
|
||||
|
||||
COPY . .
|
||||
|
||||
RUN python -m pip install -r requirements.txt
|
||||
RUN apt update \
|
||||
&& apt install gcc -y \
|
||||
&& python -m pip install -r requirements.txt \
|
||||
&& touch /.dockerenv
|
||||
|
||||
CMD [ "python", "main.py" ]
|
||||
@@ -11,7 +11,7 @@
|
||||
<a href="https://hub.docker.com/repository/docker/rockchin/qchatgpt">
|
||||
<img src="https://img.shields.io/docker/pulls/rockchin/qchatgpt?color=blue" alt="docker pull">
|
||||
</a>
|
||||

|
||||

|
||||
<a href="https://codecov.io/gh/RockChinQ/QChatGPT" >
|
||||
<img src="https://codecov.io/gh/RockChinQ/QChatGPT/graph/badge.svg?token=pjxYIL2kbC"/>
|
||||
</a>
|
||||
|
||||
44
main.py
44
main.py
@@ -1,5 +1,5 @@
|
||||
import asyncio
|
||||
|
||||
# QChatGPT 终端启动入口
|
||||
# 在此层级解决依赖项检查。
|
||||
|
||||
asciiart = r"""
|
||||
___ ___ _ _ ___ ___ _____
|
||||
@@ -11,8 +11,44 @@ asciiart = r"""
|
||||
📖文档地址: https://q.rkcn.top
|
||||
"""
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
async def main_entry():
|
||||
print(asciiart)
|
||||
|
||||
import sys
|
||||
|
||||
# 检查依赖
|
||||
|
||||
from pkg.core.bootutils import deps
|
||||
|
||||
missing_deps = await deps.check_deps()
|
||||
|
||||
if missing_deps:
|
||||
print("以下依赖包未安装,将自动安装,请完成后重启程序:")
|
||||
for dep in missing_deps:
|
||||
print("-", dep)
|
||||
await deps.install_deps(missing_deps)
|
||||
print("已自动安装缺失的依赖包,请重启程序。")
|
||||
sys.exit(0)
|
||||
|
||||
# 检查配置文件
|
||||
|
||||
from pkg.core.bootutils import files
|
||||
|
||||
generated_files = await files.generate_files()
|
||||
|
||||
if generated_files:
|
||||
print("以下文件不存在,已自动生成,请按需修改配置文件后重启:")
|
||||
for file in generated_files:
|
||||
print("-", file)
|
||||
|
||||
sys.exit(0)
|
||||
|
||||
from pkg.core import boot
|
||||
asyncio.run(boot.main())
|
||||
await boot.main()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
import asyncio
|
||||
|
||||
asyncio.run(main_entry())
|
||||
|
||||
@@ -34,6 +34,9 @@ class APIGroup(metaclass=abc.ABCMeta):
|
||||
headers: dict = {},
|
||||
**kwargs
|
||||
):
|
||||
"""
|
||||
执行请求
|
||||
"""
|
||||
self._runtime_info['account_id'] = "-1"
|
||||
|
||||
url = self.prefix + path
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
# 实例 识别码 控制
|
||||
|
||||
import os
|
||||
import uuid
|
||||
import json
|
||||
|
||||
@@ -7,6 +7,7 @@ from ..provider import entities as llm_entities
|
||||
from . import entities, operator, errors
|
||||
from ..config import manager as cfg_mgr
|
||||
|
||||
# 引入所有算子以便注册
|
||||
from .operators import func, plugin, default, reset, list as list_cmd, last, next, delc, resend, prompt, cmd, help, version, update
|
||||
|
||||
|
||||
@@ -17,6 +18,9 @@ class CommandManager:
|
||||
ap: app.Application
|
||||
|
||||
cmd_list: list[operator.CommandOperator]
|
||||
"""
|
||||
运行时命令列表,扁平存储,各个对象包含对应的子节点引用
|
||||
"""
|
||||
|
||||
def __init__(self, ap: app.Application):
|
||||
self.ap = ap
|
||||
@@ -60,7 +64,7 @@ class CommandManager:
|
||||
"""
|
||||
|
||||
found = False
|
||||
if len(context.crt_params) > 0:
|
||||
if len(context.crt_params) > 0: # 查找下一个参数是否对应此节点的某个子节点名
|
||||
for oper in operator_list:
|
||||
if (context.crt_params[0] == oper.name \
|
||||
or context.crt_params[0] in oper.alias) \
|
||||
@@ -78,7 +82,7 @@ class CommandManager:
|
||||
yield ret
|
||||
break
|
||||
|
||||
if not found:
|
||||
if not found: # 如果下一个参数未在此节点的子节点中找到,则执行此节点或者报错
|
||||
if operator is None:
|
||||
yield entities.CommandReturn(
|
||||
error=errors.CommandNotFoundError(context.crt_params[0])
|
||||
|
||||
@@ -10,6 +10,8 @@ from . import errors, operator
|
||||
|
||||
|
||||
class CommandReturn(pydantic.BaseModel):
|
||||
"""命令返回值
|
||||
"""
|
||||
|
||||
text: typing.Optional[str]
|
||||
"""文本
|
||||
@@ -18,25 +20,52 @@ class CommandReturn(pydantic.BaseModel):
|
||||
image: typing.Optional[mirai.Image]
|
||||
|
||||
error: typing.Optional[errors.CommandError]= None
|
||||
"""错误
|
||||
"""
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
|
||||
class ExecuteContext(pydantic.BaseModel):
|
||||
"""单次命令执行上下文
|
||||
"""
|
||||
|
||||
query: core_entities.Query
|
||||
"""本次消息的请求对象"""
|
||||
|
||||
session: core_entities.Session
|
||||
"""本次消息所属的会话对象"""
|
||||
|
||||
command_text: str
|
||||
"""命令完整文本"""
|
||||
|
||||
command: str
|
||||
"""命令名称"""
|
||||
|
||||
crt_command: str
|
||||
"""当前命令
|
||||
|
||||
多级命令中crt_command为当前命令,command为根命令。
|
||||
例如:!plugin on Webwlkr
|
||||
处理到plugin时,command为plugin,crt_command为plugin
|
||||
处理到on时,command为plugin,crt_command为on
|
||||
"""
|
||||
|
||||
params: list[str]
|
||||
"""命令参数
|
||||
|
||||
整个命令以空格分割后的参数列表
|
||||
"""
|
||||
|
||||
crt_params: list[str]
|
||||
"""当前命令参数
|
||||
|
||||
多级命令中crt_params为当前命令参数,params为根命令参数。
|
||||
例如:!plugin on Webwlkr
|
||||
处理到plugin时,params为['on', 'Webwlkr'],crt_params为['on', 'Webwlkr']
|
||||
处理到on时,params为['on', 'Webwlkr'],crt_params为['Webwlkr']
|
||||
"""
|
||||
|
||||
privilege: int
|
||||
"""发起人权限"""
|
||||
|
||||
@@ -8,17 +8,34 @@ from . import entities
|
||||
|
||||
|
||||
preregistered_operators: list[typing.Type[CommandOperator]] = []
|
||||
"""预注册命令算子列表。在初始化时,所有算子类会被注册到此列表中。"""
|
||||
|
||||
|
||||
def operator_class(
|
||||
name: str,
|
||||
help: str,
|
||||
help: str = "",
|
||||
usage: str = None,
|
||||
alias: list[str] = [],
|
||||
privilege: int=1, # 1为普通用户,2为管理员
|
||||
parent_class: typing.Type[CommandOperator] = None
|
||||
) -> typing.Callable[[typing.Type[CommandOperator]], typing.Type[CommandOperator]]:
|
||||
"""命令类装饰器
|
||||
|
||||
Args:
|
||||
name (str): 名称
|
||||
help (str, optional): 帮助信息. Defaults to "".
|
||||
usage (str, optional): 使用说明. Defaults to None.
|
||||
alias (list[str], optional): 别名. Defaults to [].
|
||||
privilege (int, optional): 权限,1为普通用户可用,2为仅管理员可用. Defaults to 1.
|
||||
parent_class (typing.Type[CommandOperator], optional): 父节点,若为None则为顶级命令. Defaults to None.
|
||||
|
||||
Returns:
|
||||
typing.Callable[[typing.Type[CommandOperator]], typing.Type[CommandOperator]]: 装饰器
|
||||
"""
|
||||
|
||||
def decorator(cls: typing.Type[CommandOperator]) -> typing.Type[CommandOperator]:
|
||||
assert issubclass(cls, CommandOperator)
|
||||
|
||||
cls.name = name
|
||||
cls.alias = alias
|
||||
cls.help = help
|
||||
@@ -34,7 +51,12 @@ def operator_class(
|
||||
|
||||
|
||||
class CommandOperator(metaclass=abc.ABCMeta):
|
||||
"""命令算子
|
||||
"""命令算子抽象类
|
||||
|
||||
以下的参数均不需要在子类中设置,只需要在使用装饰器注册类时作为参数传递即可。
|
||||
命令支持级联,即一个命令可以有多个子命令,子命令可以有子命令,以此类推。
|
||||
处理命令时,若有子命令,会以当前参数列表的第一个参数去匹配子命令,若匹配成功,则转移到子命令中执行。
|
||||
若没有匹配成功或没有子命令,则执行当前命令。
|
||||
"""
|
||||
|
||||
ap: app.Application
|
||||
@@ -43,7 +65,8 @@ class CommandOperator(metaclass=abc.ABCMeta):
|
||||
"""名称,搜索到时若符合则使用"""
|
||||
|
||||
path: str
|
||||
"""路径,所有父节点的name的连接,用于定义命令权限"""
|
||||
"""路径,所有父节点的name的连接,用于定义命令权限,由管理器在初始化时自动设置。
|
||||
"""
|
||||
|
||||
alias: list[str]
|
||||
"""同name"""
|
||||
@@ -52,6 +75,7 @@ class CommandOperator(metaclass=abc.ABCMeta):
|
||||
"""此节点的帮助信息"""
|
||||
|
||||
usage: str = None
|
||||
"""用法"""
|
||||
|
||||
parent_class: typing.Union[typing.Type[CommandOperator], None] = None
|
||||
"""父节点类。标记以供管理器在初始化时编织父子关系。"""
|
||||
@@ -75,4 +99,15 @@ class CommandOperator(metaclass=abc.ABCMeta):
|
||||
self,
|
||||
context: entities.ExecuteContext
|
||||
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
||||
"""实现此方法以执行命令
|
||||
|
||||
支持多次yield以返回多个结果。
|
||||
例如:一个安装插件的命令,可能会有下载、解压、安装等多个步骤,每个步骤都可以返回一个结果。
|
||||
|
||||
Args:
|
||||
context (entities.ExecuteContext): 命令执行上下文
|
||||
|
||||
Yields:
|
||||
entities.CommandReturn: 命令返回封装
|
||||
"""
|
||||
pass
|
||||
|
||||
@@ -8,15 +8,12 @@ from .. import model as file_model
|
||||
class JSONConfigFile(file_model.ConfigFile):
|
||||
"""JSON配置文件"""
|
||||
|
||||
config_file_name: str = None
|
||||
"""配置文件名"""
|
||||
|
||||
template_file_name: str = None
|
||||
"""模板文件名"""
|
||||
|
||||
def __init__(self, config_file_name: str, template_file_name: str) -> None:
|
||||
def __init__(
|
||||
self, config_file_name: str, template_file_name: str = None, template_data: dict = None
|
||||
) -> None:
|
||||
self.config_file_name = config_file_name
|
||||
self.template_file_name = template_file_name
|
||||
self.template_data = template_data
|
||||
|
||||
def exists(self) -> bool:
|
||||
return os.path.exists(self.config_file_name)
|
||||
@@ -29,19 +26,24 @@ class JSONConfigFile(file_model.ConfigFile):
|
||||
if not self.exists():
|
||||
await self.create()
|
||||
|
||||
with open(self.config_file_name, 'r', encoding='utf-8') as f:
|
||||
cfg = json.load(f)
|
||||
if self.template_file_name is not None:
|
||||
with open(self.config_file_name, "r", encoding="utf-8") as f:
|
||||
cfg = json.load(f)
|
||||
|
||||
# 从模板文件中进行补全
|
||||
with open(self.template_file_name, 'r', encoding='utf-8') as f:
|
||||
template_cfg = json.load(f)
|
||||
with open(self.template_file_name, "r", encoding="utf-8") as f:
|
||||
self.template_data = json.load(f)
|
||||
|
||||
for key in template_cfg:
|
||||
for key in self.template_data:
|
||||
if key not in cfg:
|
||||
cfg[key] = template_cfg[key]
|
||||
cfg[key] = self.template_data[key]
|
||||
|
||||
return cfg
|
||||
|
||||
|
||||
async def save(self, cfg: dict):
|
||||
with open(self.config_file_name, 'w', encoding='utf-8') as f:
|
||||
json.dump(cfg, f, indent=4, ensure_ascii=False)
|
||||
with open(self.config_file_name, "w", encoding="utf-8") as f:
|
||||
json.dump(cfg, f, indent=4, ensure_ascii=False)
|
||||
|
||||
def save_sync(self, cfg: dict):
|
||||
with open(self.config_file_name, "w", encoding="utf-8") as f:
|
||||
json.dump(cfg, f, indent=4, ensure_ascii=False)
|
||||
|
||||
@@ -60,3 +60,6 @@ class PythonModuleConfigFile(file_model.ConfigFile):
|
||||
|
||||
async def save(self, data: dict):
|
||||
logging.warning('Python模块配置文件不支持保存')
|
||||
|
||||
def save_sync(self, data: dict):
|
||||
logging.warning('Python模块配置文件不支持保存')
|
||||
@@ -26,6 +26,9 @@ class ConfigManager:
|
||||
async def dump_config(self):
|
||||
await self.file.save(self.data)
|
||||
|
||||
def dump_config_sync(self):
|
||||
self.file.save_sync(self.data)
|
||||
|
||||
|
||||
async def load_python_module_config(config_name: str, template_name: str) -> ConfigManager:
|
||||
"""加载Python模块配置文件"""
|
||||
@@ -40,11 +43,12 @@ async def load_python_module_config(config_name: str, template_name: str) -> Con
|
||||
return cfg_mgr
|
||||
|
||||
|
||||
async def load_json_config(config_name: str, template_name: str) -> ConfigManager:
|
||||
async def load_json_config(config_name: str, template_name: str=None, template_data: dict=None) -> ConfigManager:
|
||||
"""加载JSON配置文件"""
|
||||
cfg_inst = json_file.JSONConfigFile(
|
||||
config_name,
|
||||
template_name
|
||||
template_name,
|
||||
template_data
|
||||
)
|
||||
|
||||
cfg_mgr = ConfigManager(cfg_inst)
|
||||
|
||||
47
pkg/config/migration.py
Normal file
47
pkg/config/migration.py
Normal file
@@ -0,0 +1,47 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import abc
|
||||
import typing
|
||||
|
||||
from ..core import app
|
||||
|
||||
|
||||
preregistered_migrations: list[typing.Type[Migration]] = []
|
||||
"""当前阶段暂不支持扩展"""
|
||||
|
||||
def migration_class(name: str, number: int):
|
||||
"""注册一个迁移
|
||||
"""
|
||||
def decorator(cls: typing.Type[Migration]) -> typing.Type[Migration]:
|
||||
cls.name = name
|
||||
cls.number = number
|
||||
preregistered_migrations.append(cls)
|
||||
return cls
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
class Migration(abc.ABC):
|
||||
"""一个版本的迁移
|
||||
"""
|
||||
|
||||
name: str
|
||||
|
||||
number: int
|
||||
|
||||
ap: app.Application
|
||||
|
||||
def __init__(self, ap: app.Application):
|
||||
self.ap = ap
|
||||
|
||||
@abc.abstractmethod
|
||||
async def need_migrate(self) -> bool:
|
||||
"""判断当前环境是否需要运行此迁移
|
||||
"""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
async def run(self):
|
||||
"""执行迁移
|
||||
"""
|
||||
pass
|
||||
26
pkg/config/migrations/m001_sensitive_word_migration.py
Normal file
26
pkg/config/migrations/m001_sensitive_word_migration.py
Normal file
@@ -0,0 +1,26 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import sys
|
||||
|
||||
from .. import migration
|
||||
|
||||
|
||||
@migration.migration_class("sensitive-word-migration", 1)
|
||||
class SensitiveWordMigration(migration.Migration):
|
||||
"""敏感词迁移
|
||||
"""
|
||||
|
||||
async def need_migrate(self) -> bool:
|
||||
"""判断当前环境是否需要运行此迁移
|
||||
"""
|
||||
return os.path.exists("data/config/sensitive-words.json") and not os.path.exists("data/metadata/sensitive-words.json")
|
||||
|
||||
async def run(self):
|
||||
"""执行迁移
|
||||
"""
|
||||
# 移动文件
|
||||
os.rename("data/config/sensitive-words.json", "data/metadata/sensitive-words.json")
|
||||
|
||||
# 重新加载配置
|
||||
await self.ap.sensitive_meta.load_config()
|
||||
47
pkg/config/migrations/m002_openai_config_migration.py
Normal file
47
pkg/config/migrations/m002_openai_config_migration.py
Normal file
@@ -0,0 +1,47 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from .. import migration
|
||||
|
||||
|
||||
@migration.migration_class("openai-config-migration", 2)
|
||||
class OpenAIConfigMigration(migration.Migration):
|
||||
"""OpenAI配置迁移
|
||||
"""
|
||||
|
||||
async def need_migrate(self) -> bool:
|
||||
"""判断当前环境是否需要运行此迁移
|
||||
"""
|
||||
return 'openai-config' in self.ap.provider_cfg.data
|
||||
|
||||
async def run(self):
|
||||
"""执行迁移
|
||||
"""
|
||||
old_openai_config = self.ap.provider_cfg.data['openai-config'].copy()
|
||||
|
||||
if 'keys' not in self.ap.provider_cfg.data:
|
||||
self.ap.provider_cfg.data['keys'] = {}
|
||||
|
||||
if 'openai' not in self.ap.provider_cfg.data['keys']:
|
||||
self.ap.provider_cfg.data['keys']['openai'] = []
|
||||
|
||||
self.ap.provider_cfg.data['keys']['openai'] = old_openai_config['api-keys']
|
||||
|
||||
self.ap.provider_cfg.data['model'] = old_openai_config['chat-completions-params']['model']
|
||||
|
||||
del old_openai_config['chat-completions-params']['model']
|
||||
|
||||
if 'requester' not in self.ap.provider_cfg.data:
|
||||
self.ap.provider_cfg.data['requester'] = {}
|
||||
|
||||
if 'openai-chat-completions' not in self.ap.provider_cfg.data['requester']:
|
||||
self.ap.provider_cfg.data['requester']['openai-chat-completions'] = {}
|
||||
|
||||
self.ap.provider_cfg.data['requester']['openai-chat-completions'] = {
|
||||
'base-url': old_openai_config['base_url'],
|
||||
'args': old_openai_config['chat-completions-params'],
|
||||
'timeout': old_openai_config['request-timeout'],
|
||||
}
|
||||
|
||||
del self.ap.provider_cfg.data['openai-config']
|
||||
|
||||
await self.ap.provider_cfg.dump_config()
|
||||
@@ -0,0 +1,32 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from .. import migration
|
||||
|
||||
|
||||
@migration.migration_class("anthropic-requester-config-completion", 3)
|
||||
class AnthropicRequesterConfigCompletionMigration(migration.Migration):
|
||||
"""OpenAI配置迁移
|
||||
"""
|
||||
|
||||
async def need_migrate(self) -> bool:
|
||||
"""判断当前环境是否需要运行此迁移
|
||||
"""
|
||||
return 'anthropic-messages' not in self.ap.provider_cfg.data['requester'] \
|
||||
or 'anthropic' not in self.ap.provider_cfg.data['keys']
|
||||
|
||||
async def run(self):
|
||||
"""执行迁移
|
||||
"""
|
||||
if 'anthropic-messages' not in self.ap.provider_cfg.data['requester']:
|
||||
self.ap.provider_cfg.data['requester']['anthropic-messages'] = {
|
||||
'base-url': 'https://api.anthropic.com',
|
||||
'args': {
|
||||
'max_tokens': 1024
|
||||
},
|
||||
'timeout': 120,
|
||||
}
|
||||
|
||||
if 'anthropic' not in self.ap.provider_cfg.data['keys']:
|
||||
self.ap.provider_cfg.data['keys']['anthropic'] = []
|
||||
|
||||
await self.ap.provider_cfg.dump_config()
|
||||
30
pkg/config/migrations/m004_moonshot_cfg_completion.py
Normal file
30
pkg/config/migrations/m004_moonshot_cfg_completion.py
Normal file
@@ -0,0 +1,30 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from .. import migration
|
||||
|
||||
|
||||
@migration.migration_class("moonshot-config-completion", 4)
|
||||
class MoonshotConfigCompletionMigration(migration.Migration):
|
||||
"""OpenAI配置迁移
|
||||
"""
|
||||
|
||||
async def need_migrate(self) -> bool:
|
||||
"""判断当前环境是否需要运行此迁移
|
||||
"""
|
||||
return 'moonshot-chat-completions' not in self.ap.provider_cfg.data['requester'] \
|
||||
or 'moonshot' not in self.ap.provider_cfg.data['keys']
|
||||
|
||||
async def run(self):
|
||||
"""执行迁移
|
||||
"""
|
||||
if 'moonshot-chat-completions' not in self.ap.provider_cfg.data['requester']:
|
||||
self.ap.provider_cfg.data['requester']['moonshot-chat-completions'] = {
|
||||
'base-url': 'https://api.moonshot.cn/v1',
|
||||
'args': {},
|
||||
'timeout': 120,
|
||||
}
|
||||
|
||||
if 'moonshot' not in self.ap.provider_cfg.data['keys']:
|
||||
self.ap.provider_cfg.data['keys']['moonshot'] = []
|
||||
|
||||
await self.ap.provider_cfg.dump_config()
|
||||
@@ -10,6 +10,9 @@ class ConfigFile(metaclass=abc.ABCMeta):
|
||||
template_file_name: str = None
|
||||
"""模板文件名"""
|
||||
|
||||
template_data: dict = None
|
||||
"""模板数据"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def exists(self) -> bool:
|
||||
pass
|
||||
@@ -25,3 +28,7 @@ class ConfigFile(metaclass=abc.ABCMeta):
|
||||
@abc.abstractmethod
|
||||
async def save(self, data: dict):
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def save_sync(self, data: dict):
|
||||
pass
|
||||
|
||||
@@ -4,24 +4,24 @@ import logging
|
||||
import asyncio
|
||||
import traceback
|
||||
|
||||
import aioconsole
|
||||
|
||||
from ..platform import manager as im_mgr
|
||||
from ..provider.session import sessionmgr as llm_session_mgr
|
||||
from ..provider.requester import modelmgr as llm_model_mgr
|
||||
from ..provider.modelmgr import modelmgr as llm_model_mgr
|
||||
from ..provider.sysprompt import sysprompt as llm_prompt_mgr
|
||||
from ..provider.tools import toolmgr as llm_tool_mgr
|
||||
from ..config import manager as config_mgr
|
||||
from ..audit.center import v2 as center_mgr
|
||||
from ..command import cmdmgr
|
||||
from ..plugin import manager as plugin_mgr
|
||||
from . import pool, controller
|
||||
from ..pipeline import stagemgr
|
||||
from ..pipeline import pool
|
||||
from ..pipeline import controller, stagemgr
|
||||
from ..utils import version as version_mgr, proxy as proxy_mgr
|
||||
|
||||
|
||||
class Application:
|
||||
im_mgr: im_mgr.PlatformManager = None
|
||||
"""运行时应用对象和上下文"""
|
||||
|
||||
platform_mgr: im_mgr.PlatformManager = None
|
||||
|
||||
cmd_mgr: cmdmgr.CommandManager = None
|
||||
|
||||
@@ -33,6 +33,8 @@ class Application:
|
||||
|
||||
tool_mgr: llm_tool_mgr.ToolManager = None
|
||||
|
||||
# ======= 配置管理器 =======
|
||||
|
||||
command_cfg: config_mgr.ConfigManager = None
|
||||
|
||||
pipeline_cfg: config_mgr.ConfigManager = None
|
||||
@@ -43,6 +45,18 @@ class Application:
|
||||
|
||||
system_cfg: config_mgr.ConfigManager = None
|
||||
|
||||
# ======= 元数据配置管理器 =======
|
||||
|
||||
sensitive_meta: config_mgr.ConfigManager = None
|
||||
|
||||
adapter_qq_botpy_meta: config_mgr.ConfigManager = None
|
||||
|
||||
plugin_setting_meta: config_mgr.ConfigManager = None
|
||||
|
||||
llm_models_meta: config_mgr.ConfigManager = None
|
||||
|
||||
# =========================
|
||||
|
||||
ctr_mgr: center_mgr.V2CenterAPI = None
|
||||
|
||||
plugin_mgr: plugin_mgr.PluginManager = None
|
||||
@@ -66,27 +80,18 @@ class Application:
|
||||
pass
|
||||
|
||||
async def run(self):
|
||||
await self.plugin_mgr.load_plugins()
|
||||
await self.plugin_mgr.initialize_plugins()
|
||||
|
||||
tasks = []
|
||||
|
||||
try:
|
||||
|
||||
|
||||
|
||||
tasks = [
|
||||
asyncio.create_task(self.im_mgr.run()),
|
||||
asyncio.create_task(self.platform_mgr.run()),
|
||||
asyncio.create_task(self.ctrl.run())
|
||||
]
|
||||
|
||||
# async def interrupt(tasks):
|
||||
# await asyncio.sleep(1.5)
|
||||
# while await aioconsole.ainput("使用 ctrl+c 或 'exit' 退出程序 > ") != 'exit':
|
||||
# pass
|
||||
# for task in tasks:
|
||||
# task.cancel()
|
||||
|
||||
# await interrupt(tasks)
|
||||
# 挂信号处理
|
||||
|
||||
import signal
|
||||
|
||||
|
||||
142
pkg/core/boot.py
142
pkg/core/boot.py
@@ -1,143 +1,33 @@
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
import sys
|
||||
|
||||
from .bootutils import files
|
||||
from .bootutils import deps
|
||||
from .bootutils import log
|
||||
from .bootutils import config
|
||||
|
||||
from . import app
|
||||
from . import pool
|
||||
from . import controller
|
||||
from ..pipeline import stagemgr
|
||||
from ..audit import identifier
|
||||
from ..provider.session import sessionmgr as llm_session_mgr
|
||||
from ..provider.requester import modelmgr as llm_model_mgr
|
||||
from ..provider.sysprompt import sysprompt as llm_prompt_mgr
|
||||
from ..provider.tools import toolmgr as llm_tool_mgr
|
||||
from ..platform import manager as im_mgr
|
||||
from ..command import cmdmgr
|
||||
from ..plugin import manager as plugin_mgr
|
||||
from ..audit.center import v2 as center_v2
|
||||
from ..utils import version, proxy, announce
|
||||
from . import stage
|
||||
|
||||
use_override = False
|
||||
# 引入启动阶段实现以便注册
|
||||
from .stages import load_config, setup_logger, build_app, migrate
|
||||
|
||||
|
||||
stage_order = [
|
||||
"LoadConfigStage",
|
||||
"MigrationStage",
|
||||
"SetupLoggerStage",
|
||||
"BuildAppStage"
|
||||
]
|
||||
|
||||
|
||||
async def make_app() -> app.Application:
|
||||
global use_override
|
||||
|
||||
generated_files = await files.generate_files()
|
||||
|
||||
if generated_files:
|
||||
print("以下文件不存在,已自动生成,请按需修改配置文件后重启:")
|
||||
for file in generated_files:
|
||||
print("-", file)
|
||||
|
||||
sys.exit(0)
|
||||
|
||||
missing_deps = await deps.check_deps()
|
||||
|
||||
if missing_deps:
|
||||
print("以下依赖包未安装,将自动安装,请完成后重启程序:")
|
||||
for dep in missing_deps:
|
||||
print("-", dep)
|
||||
await deps.install_deps(missing_deps)
|
||||
sys.exit(0)
|
||||
|
||||
qcg_logger = await log.init_logging()
|
||||
|
||||
# 生成标识符
|
||||
identifier.init()
|
||||
|
||||
# ========== 加载配置文件 ==========
|
||||
|
||||
command_cfg = await config.load_json_config("data/config/command.json", "templates/command.json")
|
||||
pipeline_cfg = await config.load_json_config("data/config/pipeline.json", "templates/pipeline.json")
|
||||
platform_cfg = await config.load_json_config("data/config/platform.json", "templates/platform.json")
|
||||
provider_cfg = await config.load_json_config("data/config/provider.json", "templates/provider.json")
|
||||
system_cfg = await config.load_json_config("data/config/system.json", "templates/system.json")
|
||||
|
||||
# ========== 构建应用实例 ==========
|
||||
ap = app.Application()
|
||||
ap.logger = qcg_logger
|
||||
|
||||
ap.command_cfg = command_cfg
|
||||
ap.pipeline_cfg = pipeline_cfg
|
||||
ap.platform_cfg = platform_cfg
|
||||
ap.provider_cfg = provider_cfg
|
||||
ap.system_cfg = system_cfg
|
||||
|
||||
proxy_mgr = proxy.ProxyManager(ap)
|
||||
await proxy_mgr.initialize()
|
||||
ap.proxy_mgr = proxy_mgr
|
||||
|
||||
ver_mgr = version.VersionManager(ap)
|
||||
await ver_mgr.initialize()
|
||||
ap.ver_mgr = ver_mgr
|
||||
|
||||
center_v2_api = center_v2.V2CenterAPI(
|
||||
ap,
|
||||
basic_info={
|
||||
"host_id": identifier.identifier["host_id"],
|
||||
"instance_id": identifier.identifier["instance_id"],
|
||||
"semantic_version": ver_mgr.get_current_version(),
|
||||
"platform": sys.platform,
|
||||
},
|
||||
runtime_info={
|
||||
"admin_id": "{}".format(system_cfg.data["admin-sessions"]),
|
||||
"msg_source": str([
|
||||
adapter_cfg['adapter'] if 'adapter' in adapter_cfg else 'unknown'
|
||||
for adapter_cfg in platform_cfg.data['platform-adapters'] if adapter_cfg['enable']
|
||||
]),
|
||||
},
|
||||
)
|
||||
ap.ctr_mgr = center_v2_api
|
||||
|
||||
# 发送公告
|
||||
ann_mgr = announce.AnnouncementManager(ap)
|
||||
await ann_mgr.show_announcements()
|
||||
|
||||
ap.query_pool = pool.QueryPool()
|
||||
|
||||
await ap.ver_mgr.show_version_update()
|
||||
|
||||
plugin_mgr_inst = plugin_mgr.PluginManager(ap)
|
||||
await plugin_mgr_inst.initialize()
|
||||
ap.plugin_mgr = plugin_mgr_inst
|
||||
|
||||
cmd_mgr_inst = cmdmgr.CommandManager(ap)
|
||||
await cmd_mgr_inst.initialize()
|
||||
ap.cmd_mgr = cmd_mgr_inst
|
||||
|
||||
llm_model_mgr_inst = llm_model_mgr.ModelManager(ap)
|
||||
await llm_model_mgr_inst.initialize()
|
||||
ap.model_mgr = llm_model_mgr_inst
|
||||
|
||||
llm_session_mgr_inst = llm_session_mgr.SessionManager(ap)
|
||||
await llm_session_mgr_inst.initialize()
|
||||
ap.sess_mgr = llm_session_mgr_inst
|
||||
|
||||
llm_prompt_mgr_inst = llm_prompt_mgr.PromptManager(ap)
|
||||
await llm_prompt_mgr_inst.initialize()
|
||||
ap.prompt_mgr = llm_prompt_mgr_inst
|
||||
|
||||
llm_tool_mgr_inst = llm_tool_mgr.ToolManager(ap)
|
||||
await llm_tool_mgr_inst.initialize()
|
||||
ap.tool_mgr = llm_tool_mgr_inst
|
||||
|
||||
im_mgr_inst = im_mgr.PlatformManager(ap=ap)
|
||||
await im_mgr_inst.initialize()
|
||||
ap.im_mgr = im_mgr_inst
|
||||
|
||||
stage_mgr = stagemgr.StageManager(ap)
|
||||
await stage_mgr.initialize()
|
||||
ap.stage_mgr = stage_mgr
|
||||
|
||||
ctrl = controller.Controller(ap)
|
||||
ap.ctrl = ctrl
|
||||
# 执行启动阶段
|
||||
for stage_name in stage_order:
|
||||
stage_cls = stage.preregistered_stages[stage_name]
|
||||
stage_inst = stage_cls()
|
||||
await stage_inst.run(ap)
|
||||
|
||||
await ap.initialize()
|
||||
|
||||
|
||||
@@ -3,11 +3,13 @@ import pip
|
||||
required_deps = {
|
||||
"requests": "requests",
|
||||
"openai": "openai",
|
||||
"anthropic": "anthropic",
|
||||
"colorlog": "colorlog",
|
||||
"mirai": "yiri-mirai-rc",
|
||||
"aiocqhttp": "aiocqhttp",
|
||||
"botpy": "qq-botpy",
|
||||
"PIL": "pillow",
|
||||
"nakuru": "nakuru-project-idk",
|
||||
"CallingGPT": "CallingGPT",
|
||||
"tiktoken": "tiktoken",
|
||||
"yaml": "pyyaml",
|
||||
"aiohttp": "aiohttp",
|
||||
|
||||
@@ -13,13 +13,13 @@ required_files = {
|
||||
"data/config/platform.json": "templates/platform.json",
|
||||
"data/config/provider.json": "templates/provider.json",
|
||||
"data/config/system.json": "templates/system.json",
|
||||
"data/config/sensitive-words.json": "templates/sensitive-words.json",
|
||||
"data/scenario/default.json": "templates/scenario-template.json",
|
||||
}
|
||||
|
||||
required_paths = [
|
||||
"temp",
|
||||
"data",
|
||||
"data/metadata",
|
||||
"data/prompts",
|
||||
"data/scenario",
|
||||
"data/logs",
|
||||
|
||||
@@ -16,6 +16,10 @@ log_colors_config = {
|
||||
|
||||
|
||||
async def init_logging() -> logging.Logger:
|
||||
# 删除所有现有的logger
|
||||
for handler in logging.root.handlers[:]:
|
||||
logging.root.removeHandler(handler)
|
||||
|
||||
level = logging.INFO
|
||||
|
||||
if "DEBUG" in os.environ and os.environ["DEBUG"] in ["true", "1"]:
|
||||
@@ -46,7 +50,7 @@ async def init_logging() -> logging.Logger:
|
||||
|
||||
qcg_logger.debug("日志初始化完成,日志级别:%s" % level)
|
||||
logging.basicConfig(
|
||||
level=logging.INFO, # 设置日志输出格式
|
||||
level=logging.CRITICAL, # 设置日志输出格式
|
||||
format="[DEPR][%(asctime)s.%(msecs)03d] %(pathname)s (%(lineno)d) - [%(levelname)s] :\n%(message)s",
|
||||
# 日志输出的格式
|
||||
# -8表示占位符,让输出左对齐,输出长度都为8位
|
||||
|
||||
@@ -9,13 +9,14 @@ import pydantic
|
||||
import mirai
|
||||
|
||||
from ..provider import entities as llm_entities
|
||||
from ..provider.requester import entities
|
||||
from ..provider.modelmgr import entities
|
||||
from ..provider.sysprompt import entities as sysprompt_entities
|
||||
from ..provider.tools import entities as tools_entities
|
||||
from ..platform import adapter as msadapter
|
||||
|
||||
|
||||
class LauncherTypes(enum.Enum):
|
||||
"""一个请求的发起者类型"""
|
||||
|
||||
PERSON = 'person'
|
||||
"""私聊"""
|
||||
@@ -31,43 +32,43 @@ class Query(pydantic.BaseModel):
|
||||
"""请求ID,添加进请求池时生成"""
|
||||
|
||||
launcher_type: LauncherTypes
|
||||
"""会话类型,platform设置"""
|
||||
"""会话类型,platform处理阶段设置"""
|
||||
|
||||
launcher_id: int
|
||||
"""会话ID,platform设置"""
|
||||
"""会话ID,platform处理阶段设置"""
|
||||
|
||||
sender_id: int
|
||||
"""发送者ID,platform设置"""
|
||||
"""发送者ID,platform处理阶段设置"""
|
||||
|
||||
message_event: mirai.MessageEvent
|
||||
"""事件,platform收到的事件"""
|
||||
"""事件,platform收到的原始事件"""
|
||||
|
||||
message_chain: mirai.MessageChain
|
||||
"""消息链,platform收到的消息链"""
|
||||
"""消息链,platform收到的原始消息链"""
|
||||
|
||||
adapter: msadapter.MessageSourceAdapter
|
||||
"""适配器对象"""
|
||||
"""消息平台适配器对象,单个app中可能启用了多个消息平台适配器,此对象表明发起此query的适配器"""
|
||||
|
||||
session: typing.Optional[Session] = None
|
||||
"""会话对象,由前置处理器设置"""
|
||||
"""会话对象,由前置处理器阶段设置"""
|
||||
|
||||
messages: typing.Optional[list[llm_entities.Message]] = []
|
||||
"""历史消息列表,由前置处理器设置"""
|
||||
"""历史消息列表,由前置处理器阶段设置"""
|
||||
|
||||
prompt: typing.Optional[sysprompt_entities.Prompt] = None
|
||||
"""情景预设内容,由前置处理器设置"""
|
||||
"""情景预设内容,由前置处理器阶段设置"""
|
||||
|
||||
user_message: typing.Optional[llm_entities.Message] = None
|
||||
"""此次请求的用户消息对象,由前置处理器设置"""
|
||||
"""此次请求的用户消息对象,由前置处理器阶段设置"""
|
||||
|
||||
use_model: typing.Optional[entities.LLMModelInfo] = None
|
||||
"""使用的模型,由前置处理器设置"""
|
||||
"""使用的模型,由前置处理器阶段设置"""
|
||||
|
||||
use_funcs: typing.Optional[list[tools_entities.LLMFunction]] = None
|
||||
"""使用的函数,由前置处理器设置"""
|
||||
"""使用的函数,由前置处理器阶段设置"""
|
||||
|
||||
resp_messages: typing.Optional[list[llm_entities.Message]] = []
|
||||
"""由provider生成的回复消息对象列表"""
|
||||
"""由Process阶段生成的回复消息对象列表"""
|
||||
|
||||
resp_message_chain: typing.Optional[mirai.MessageChain] = None
|
||||
"""回复消息链,从resp_messages包装而得"""
|
||||
@@ -77,7 +78,7 @@ class Query(pydantic.BaseModel):
|
||||
|
||||
|
||||
class Conversation(pydantic.BaseModel):
|
||||
"""对话"""
|
||||
"""对话,包含于 Session 中,一个 Session 可以有多个历史 Conversation,但只有一个当前使用的 Conversation"""
|
||||
|
||||
prompt: sysprompt_entities.Prompt
|
||||
|
||||
@@ -93,7 +94,7 @@ class Conversation(pydantic.BaseModel):
|
||||
|
||||
|
||||
class Session(pydantic.BaseModel):
|
||||
"""会话"""
|
||||
"""会话,一个 Session 对应一个 {launcher_type.value}_{launcher_id}"""
|
||||
launcher_type: LauncherTypes
|
||||
|
||||
launcher_id: int
|
||||
@@ -111,6 +112,7 @@ class Session(pydantic.BaseModel):
|
||||
update_time: typing.Optional[datetime.datetime] = pydantic.Field(default_factory=datetime.datetime.now)
|
||||
|
||||
semaphore: typing.Optional[asyncio.Semaphore] = None
|
||||
"""当前会话的信号量,用于限制并发"""
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
34
pkg/core/stage.py
Normal file
34
pkg/core/stage.py
Normal file
@@ -0,0 +1,34 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import abc
|
||||
import typing
|
||||
|
||||
from . import app
|
||||
|
||||
|
||||
preregistered_stages: dict[str, typing.Type[BootingStage]] = {}
|
||||
"""预注册的请求处理阶段。在初始化时,所有请求处理阶段类会被注册到此字典中。
|
||||
|
||||
当前阶段暂不支持扩展
|
||||
"""
|
||||
|
||||
def stage_class(
|
||||
name: str
|
||||
):
|
||||
def decorator(cls: typing.Type[BootingStage]) -> typing.Type[BootingStage]:
|
||||
preregistered_stages[name] = cls
|
||||
return cls
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
class BootingStage(abc.ABC):
|
||||
"""启动阶段
|
||||
"""
|
||||
name: str = None
|
||||
|
||||
@abc.abstractmethod
|
||||
async def run(self, ap: app.Application):
|
||||
"""启动
|
||||
"""
|
||||
pass
|
||||
96
pkg/core/stages/build_app.py
Normal file
96
pkg/core/stages/build_app.py
Normal file
@@ -0,0 +1,96 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
|
||||
from .. import stage, app
|
||||
from ...utils import version, proxy, announce, platform
|
||||
from ...audit.center import v2 as center_v2
|
||||
from ...audit import identifier
|
||||
from ...pipeline import pool, controller, stagemgr
|
||||
from ...plugin import manager as plugin_mgr
|
||||
from ...command import cmdmgr
|
||||
from ...provider.session import sessionmgr as llm_session_mgr
|
||||
from ...provider.modelmgr import modelmgr as llm_model_mgr
|
||||
from ...provider.sysprompt import sysprompt as llm_prompt_mgr
|
||||
from ...provider.tools import toolmgr as llm_tool_mgr
|
||||
from ...platform import manager as im_mgr
|
||||
|
||||
|
||||
@stage.stage_class("BuildAppStage")
|
||||
class BuildAppStage(stage.BootingStage):
|
||||
"""构建应用阶段
|
||||
"""
|
||||
|
||||
async def run(self, ap: app.Application):
|
||||
"""构建app对象的各个组件对象并初始化
|
||||
"""
|
||||
|
||||
proxy_mgr = proxy.ProxyManager(ap)
|
||||
await proxy_mgr.initialize()
|
||||
ap.proxy_mgr = proxy_mgr
|
||||
|
||||
ver_mgr = version.VersionManager(ap)
|
||||
await ver_mgr.initialize()
|
||||
ap.ver_mgr = ver_mgr
|
||||
|
||||
center_v2_api = center_v2.V2CenterAPI(
|
||||
ap,
|
||||
basic_info={
|
||||
"host_id": identifier.identifier["host_id"],
|
||||
"instance_id": identifier.identifier["instance_id"],
|
||||
"semantic_version": ver_mgr.get_current_version(),
|
||||
"platform": platform.get_platform(),
|
||||
},
|
||||
runtime_info={
|
||||
"admin_id": "{}".format(ap.system_cfg.data["admin-sessions"]),
|
||||
"msg_source": str([
|
||||
adapter_cfg['adapter'] if 'adapter' in adapter_cfg else 'unknown'
|
||||
for adapter_cfg in ap.platform_cfg.data['platform-adapters'] if adapter_cfg['enable']
|
||||
]),
|
||||
},
|
||||
)
|
||||
ap.ctr_mgr = center_v2_api
|
||||
|
||||
# 发送公告
|
||||
ann_mgr = announce.AnnouncementManager(ap)
|
||||
await ann_mgr.show_announcements()
|
||||
|
||||
ap.query_pool = pool.QueryPool()
|
||||
|
||||
await ap.ver_mgr.show_version_update()
|
||||
|
||||
plugin_mgr_inst = plugin_mgr.PluginManager(ap)
|
||||
await plugin_mgr_inst.initialize()
|
||||
ap.plugin_mgr = plugin_mgr_inst
|
||||
await plugin_mgr_inst.load_plugins()
|
||||
|
||||
cmd_mgr_inst = cmdmgr.CommandManager(ap)
|
||||
await cmd_mgr_inst.initialize()
|
||||
ap.cmd_mgr = cmd_mgr_inst
|
||||
|
||||
llm_model_mgr_inst = llm_model_mgr.ModelManager(ap)
|
||||
await llm_model_mgr_inst.initialize()
|
||||
ap.model_mgr = llm_model_mgr_inst
|
||||
|
||||
llm_session_mgr_inst = llm_session_mgr.SessionManager(ap)
|
||||
await llm_session_mgr_inst.initialize()
|
||||
ap.sess_mgr = llm_session_mgr_inst
|
||||
|
||||
llm_prompt_mgr_inst = llm_prompt_mgr.PromptManager(ap)
|
||||
await llm_prompt_mgr_inst.initialize()
|
||||
ap.prompt_mgr = llm_prompt_mgr_inst
|
||||
|
||||
llm_tool_mgr_inst = llm_tool_mgr.ToolManager(ap)
|
||||
await llm_tool_mgr_inst.initialize()
|
||||
ap.tool_mgr = llm_tool_mgr_inst
|
||||
|
||||
im_mgr_inst = im_mgr.PlatformManager(ap=ap)
|
||||
await im_mgr_inst.initialize()
|
||||
ap.platform_mgr = im_mgr_inst
|
||||
|
||||
stage_mgr = stagemgr.StageManager(ap)
|
||||
await stage_mgr.initialize()
|
||||
ap.stage_mgr = stage_mgr
|
||||
|
||||
ctrl = controller.Controller(ap)
|
||||
ap.ctrl = ctrl
|
||||
31
pkg/core/stages/load_config.py
Normal file
31
pkg/core/stages/load_config.py
Normal file
@@ -0,0 +1,31 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from .. import stage, app
|
||||
from ..bootutils import config
|
||||
|
||||
|
||||
@stage.stage_class("LoadConfigStage")
|
||||
class LoadConfigStage(stage.BootingStage):
|
||||
"""加载配置文件阶段
|
||||
"""
|
||||
|
||||
async def run(self, ap: app.Application):
|
||||
"""启动
|
||||
"""
|
||||
ap.command_cfg = await config.load_json_config("data/config/command.json", "templates/command.json")
|
||||
ap.pipeline_cfg = await config.load_json_config("data/config/pipeline.json", "templates/pipeline.json")
|
||||
ap.platform_cfg = await config.load_json_config("data/config/platform.json", "templates/platform.json")
|
||||
ap.provider_cfg = await config.load_json_config("data/config/provider.json", "templates/provider.json")
|
||||
ap.system_cfg = await config.load_json_config("data/config/system.json", "templates/system.json")
|
||||
|
||||
ap.plugin_setting_meta = await config.load_json_config("plugins/plugins.json", "templates/plugin-settings.json")
|
||||
await ap.plugin_setting_meta.dump_config()
|
||||
|
||||
ap.sensitive_meta = await config.load_json_config("data/metadata/sensitive-words.json", "templates/metadata/sensitive-words.json")
|
||||
await ap.sensitive_meta.dump_config()
|
||||
|
||||
ap.adapter_qq_botpy_meta = await config.load_json_config("data/metadata/adapter-qq-botpy.json", "templates/metadata/adapter-qq-botpy.json")
|
||||
await ap.adapter_qq_botpy_meta.dump_config()
|
||||
|
||||
ap.llm_models_meta = await config.load_json_config("data/metadata/llm-models.json", "templates/metadata/llm-models.json")
|
||||
await ap.llm_models_meta.dump_config()
|
||||
28
pkg/core/stages/migrate.py
Normal file
28
pkg/core/stages/migrate.py
Normal file
@@ -0,0 +1,28 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib
|
||||
|
||||
from .. import stage, app
|
||||
from ...config import migration
|
||||
from ...config.migrations import m001_sensitive_word_migration, m002_openai_config_migration, m003_anthropic_requester_cfg_completion, m004_moonshot_cfg_completion
|
||||
|
||||
|
||||
@stage.stage_class("MigrationStage")
|
||||
class MigrationStage(stage.BootingStage):
|
||||
"""迁移阶段
|
||||
"""
|
||||
|
||||
async def run(self, ap: app.Application):
|
||||
"""启动
|
||||
"""
|
||||
|
||||
migrations = migration.preregistered_migrations
|
||||
|
||||
# 按照迁移号排序
|
||||
migrations.sort(key=lambda x: x.number)
|
||||
|
||||
for migration_cls in migrations:
|
||||
migration_instance = migration_cls(ap)
|
||||
|
||||
if await migration_instance.need_migrate():
|
||||
await migration_instance.run()
|
||||
15
pkg/core/stages/setup_logger.py
Normal file
15
pkg/core/stages/setup_logger.py
Normal file
@@ -0,0 +1,15 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from .. import stage, app
|
||||
from ..bootutils import log
|
||||
|
||||
|
||||
@stage.stage_class("SetupLoggerStage")
|
||||
class SetupLoggerStage(stage.BootingStage):
|
||||
"""设置日志器阶段
|
||||
"""
|
||||
|
||||
async def run(self, ap: app.Application):
|
||||
"""启动
|
||||
"""
|
||||
ap.logger = await log.init_logging()
|
||||
@@ -8,6 +8,7 @@ from ...config import manager as cfg_mgr
|
||||
|
||||
@stage.stage_class('BanSessionCheckStage')
|
||||
class BanSessionCheckStage(stage.PipelineStage):
|
||||
"""访问控制处理阶段"""
|
||||
|
||||
async def initialize(self):
|
||||
pass
|
||||
@@ -24,22 +25,24 @@ class BanSessionCheckStage(stage.PipelineStage):
|
||||
|
||||
sess_list = self.ap.pipeline_cfg.data['access-control'][mode]
|
||||
|
||||
if (query.launcher_type == 'group' and 'group_*' in sess_list) \
|
||||
or (query.launcher_type == 'person' and 'person_*' in sess_list):
|
||||
if (query.launcher_type.value == 'group' and 'group_*' in sess_list) \
|
||||
or (query.launcher_type.value == 'person' and 'person_*' in sess_list):
|
||||
found = True
|
||||
else:
|
||||
for sess in sess_list:
|
||||
if sess == f"{query.launcher_type}_{query.launcher_id}":
|
||||
if sess == f"{query.launcher_type.value}_{query.launcher_id}":
|
||||
found = True
|
||||
break
|
||||
|
||||
ctn = False
|
||||
|
||||
result = False
|
||||
|
||||
if mode == 'blacklist':
|
||||
result = found
|
||||
if mode == 'whitelist':
|
||||
ctn = found
|
||||
else:
|
||||
ctn = not found
|
||||
|
||||
return entities.StageProcessResult(
|
||||
result_type=entities.ResultType.CONTINUE if not result else entities.ResultType.INTERRUPT,
|
||||
result_type=entities.ResultType.CONTINUE if ctn else entities.ResultType.INTERRUPT,
|
||||
new_query=query,
|
||||
debug_notice=f'根据访问控制忽略消息: {query.launcher_type}_{query.launcher_id}' if result else ''
|
||||
console_notice=f'根据访问控制忽略消息: {query.launcher_type.value}_{query.launcher_id}' if not ctn else ''
|
||||
)
|
||||
|
||||
@@ -7,28 +7,38 @@ from ...core import app
|
||||
from .. import stage, entities, stagemgr
|
||||
from ...core import entities as core_entities
|
||||
from ...config import manager as cfg_mgr
|
||||
from . import filter, entities as filter_entities
|
||||
from . import filter as filter_model, entities as filter_entities
|
||||
from .filters import cntignore, banwords, baiduexamine
|
||||
|
||||
|
||||
@stage.stage_class('PostContentFilterStage')
|
||||
@stage.stage_class('PreContentFilterStage')
|
||||
class ContentFilterStage(stage.PipelineStage):
|
||||
"""内容过滤阶段"""
|
||||
|
||||
filter_chain: list[filter.ContentFilter]
|
||||
filter_chain: list[filter_model.ContentFilter]
|
||||
|
||||
def __init__(self, ap: app.Application):
|
||||
self.filter_chain = []
|
||||
super().__init__(ap)
|
||||
|
||||
async def initialize(self):
|
||||
self.filter_chain.append(cntignore.ContentIgnore(self.ap))
|
||||
|
||||
filters_required = [
|
||||
"content-filter"
|
||||
]
|
||||
|
||||
if self.ap.pipeline_cfg.data['check-sensitive-words']:
|
||||
self.filter_chain.append(banwords.BanWordFilter(self.ap))
|
||||
|
||||
filters_required.append("ban-word-filter")
|
||||
|
||||
if self.ap.pipeline_cfg.data['baidu-cloud-examine']['enable']:
|
||||
self.filter_chain.append(baiduexamine.BaiduCloudExamine(self.ap))
|
||||
filters_required.append("baidu-cloud-examine")
|
||||
|
||||
for filter in filter_model.preregistered_filters:
|
||||
if filter.name in filters_required:
|
||||
self.filter_chain.append(
|
||||
filter(self.ap)
|
||||
)
|
||||
|
||||
for filter in self.filter_chain:
|
||||
await filter.initialize()
|
||||
|
||||
@@ -31,15 +31,24 @@ class EnableStage(enum.Enum):
|
||||
|
||||
class FilterResult(pydantic.BaseModel):
|
||||
level: ResultLevel
|
||||
"""结果等级
|
||||
|
||||
对于前置处理阶段,只要有任意一个返回 非PASS 的内容过滤器结果,就会中断处理。
|
||||
对于后置处理阶段,当且内容过滤器返回 BLOCK 时,会中断处理。
|
||||
"""
|
||||
|
||||
replacement: str
|
||||
"""替换后的消息"""
|
||||
"""替换后的消息
|
||||
|
||||
内容过滤器可以进行一些遮掩处理,然后把遮掩后的消息返回。
|
||||
若没有修改内容,也需要返回原消息。
|
||||
"""
|
||||
|
||||
user_notice: str
|
||||
"""不通过时,用户提示消息"""
|
||||
"""不通过时,若此值不为空,将对用户提示消息"""
|
||||
|
||||
console_notice: str
|
||||
"""不通过时,控制台提示消息"""
|
||||
"""不通过时,若此值不为空,将在控制台提示消息"""
|
||||
|
||||
|
||||
class ManagerResultLevel(enum.Enum):
|
||||
|
||||
@@ -1,12 +1,42 @@
|
||||
# 内容过滤器的抽象类
|
||||
from __future__ import annotations
|
||||
import abc
|
||||
import typing
|
||||
|
||||
from ...core import app
|
||||
from . import entities
|
||||
|
||||
|
||||
preregistered_filters: list[typing.Type[ContentFilter]] = []
|
||||
|
||||
|
||||
def filter_class(
|
||||
name: str
|
||||
) -> typing.Callable[[typing.Type[ContentFilter]], typing.Type[ContentFilter]]:
|
||||
"""内容过滤器类装饰器
|
||||
|
||||
Args:
|
||||
name (str): 过滤器名称
|
||||
|
||||
Returns:
|
||||
typing.Callable[[typing.Type[ContentFilter]], typing.Type[ContentFilter]]: 装饰器
|
||||
"""
|
||||
def decorator(cls: typing.Type[ContentFilter]) -> typing.Type[ContentFilter]:
|
||||
assert issubclass(cls, ContentFilter)
|
||||
|
||||
cls.name = name
|
||||
|
||||
preregistered_filters.append(cls)
|
||||
|
||||
return cls
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
class ContentFilter(metaclass=abc.ABCMeta):
|
||||
"""内容过滤器抽象类"""
|
||||
|
||||
name: str
|
||||
|
||||
ap: app.Application
|
||||
|
||||
@@ -16,6 +46,11 @@ class ContentFilter(metaclass=abc.ABCMeta):
|
||||
@property
|
||||
def enable_stages(self):
|
||||
"""启用的阶段
|
||||
|
||||
默认为消息请求AI前后的两个阶段。
|
||||
|
||||
entity.EnableStage.PRE: 消息请求AI前,此时需要检查的内容是用户的输入消息。
|
||||
entity.EnableStage.POST: 消息请求AI后,此时需要检查的内容是AI的回复消息。
|
||||
"""
|
||||
return [
|
||||
entities.EnableStage.PRE,
|
||||
@@ -30,5 +65,14 @@ class ContentFilter(metaclass=abc.ABCMeta):
|
||||
@abc.abstractmethod
|
||||
async def process(self, message: str) -> entities.FilterResult:
|
||||
"""处理消息
|
||||
|
||||
分为前后阶段,具体取决于 enable_stages 的值。
|
||||
对于内容过滤器来说,不需要考虑消息所处的阶段,只需要检查消息内容即可。
|
||||
|
||||
Args:
|
||||
message (str): 需要检查的内容
|
||||
|
||||
Returns:
|
||||
entities.FilterResult: 过滤结果,具体内容请查看 entities.FilterResult 类的文档
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -10,6 +10,7 @@ BAIDU_EXAMINE_URL = "https://aip.baidubce.com/rest/2.0/solution/v1/text_censor/v
|
||||
BAIDU_EXAMINE_TOKEN_URL = "https://aip.baidubce.com/oauth/2.0/token"
|
||||
|
||||
|
||||
@filter_model.filter_class("baidu-cloud-examine")
|
||||
class BaiduCloudExamine(filter_model.ContentFilter):
|
||||
"""百度云内容审核"""
|
||||
|
||||
|
||||
@@ -6,34 +6,30 @@ from .. import entities
|
||||
from ....config import manager as cfg_mgr
|
||||
|
||||
|
||||
@filter_model.filter_class("ban-word-filter")
|
||||
class BanWordFilter(filter_model.ContentFilter):
|
||||
"""根据内容禁言"""
|
||||
|
||||
sensitive: cfg_mgr.ConfigManager
|
||||
|
||||
async def initialize(self):
|
||||
self.sensitive = await cfg_mgr.load_json_config(
|
||||
"data/config/sensitive-words.json",
|
||||
"templates/sensitive-words.json"
|
||||
)
|
||||
pass
|
||||
|
||||
async def process(self, message: str) -> entities.FilterResult:
|
||||
found = False
|
||||
|
||||
for word in self.sensitive.data['words']:
|
||||
for word in self.ap.sensitive_meta.data['words']:
|
||||
match = re.findall(word, message)
|
||||
|
||||
if len(match) > 0:
|
||||
found = True
|
||||
|
||||
for i in range(len(match)):
|
||||
if self.sensitive.data['mask_word'] == "":
|
||||
if self.ap.sensitive_meta.data['mask_word'] == "":
|
||||
message = message.replace(
|
||||
match[i], self.sensitive.data['mask'] * len(match[i])
|
||||
match[i], self.ap.sensitive_meta.data['mask'] * len(match[i])
|
||||
)
|
||||
else:
|
||||
message = message.replace(
|
||||
match[i], self.sensitive.data['mask_word']
|
||||
match[i], self.ap.sensitive_meta.data['mask_word']
|
||||
)
|
||||
|
||||
return entities.FilterResult(
|
||||
|
||||
@@ -5,6 +5,7 @@ from .. import entities
|
||||
from .. import filter as filter_model
|
||||
|
||||
|
||||
@filter_model.filter_class("content-ignore")
|
||||
class ContentIgnore(filter_model.ContentFilter):
|
||||
"""根据内容忽略消息"""
|
||||
|
||||
|
||||
@@ -4,8 +4,8 @@ import asyncio
|
||||
import typing
|
||||
import traceback
|
||||
|
||||
from . import app, entities
|
||||
from ..pipeline import entities as pipeline_entities
|
||||
from ..core import app, entities
|
||||
from . import entities as pipeline_entities
|
||||
from ..plugin import events
|
||||
|
||||
|
||||
@@ -68,7 +68,7 @@ class Controller:
|
||||
"""检查输出
|
||||
"""
|
||||
if result.user_notice:
|
||||
await self.ap.im_mgr.send(
|
||||
await self.ap.platform_mgr.send(
|
||||
query.message_event,
|
||||
result.user_notice,
|
||||
query.adapter
|
||||
@@ -85,7 +85,7 @@ class Controller:
|
||||
stage_index: int,
|
||||
query: entities.Query,
|
||||
):
|
||||
"""从指定阶段开始执行
|
||||
"""从指定阶段开始执行,实现了责任链模式和基于生成器的阶段分叉功能。
|
||||
|
||||
如何看懂这里为什么这么写?
|
||||
去问 GPT-4:
|
||||
@@ -15,6 +15,8 @@ from ...config import manager as cfg_mgr
|
||||
|
||||
@stage.stage_class("LongTextProcessStage")
|
||||
class LongTextProcessStage(stage.PipelineStage):
|
||||
"""长消息处理阶段
|
||||
"""
|
||||
|
||||
strategy_impl: strategy.LongTextStrategy
|
||||
|
||||
@@ -43,11 +45,14 @@ class LongTextProcessStage(stage.PipelineStage):
|
||||
self.ap.logger.error("加载字体文件失败({}),更换为转发消息组件以发送长消息,您可以在config.py中调整相关设置。".format(use_font))
|
||||
|
||||
self.ap.platform_cfg.data['long-text-process']['strategy'] = "forward"
|
||||
|
||||
if config['strategy'] == 'image':
|
||||
self.strategy_impl = image.Text2ImageStrategy(self.ap)
|
||||
elif config['strategy'] == 'forward':
|
||||
self.strategy_impl = forward.ForwardComponentStrategy(self.ap)
|
||||
|
||||
for strategy_cls in strategy.preregistered_strategies:
|
||||
if strategy_cls.name == config['strategy']:
|
||||
self.strategy_impl = strategy_cls(self.ap)
|
||||
break
|
||||
else:
|
||||
raise ValueError(f"未找到名为 {config['strategy']} 的长消息处理策略")
|
||||
|
||||
await self.strategy_impl.initialize()
|
||||
|
||||
async def process(self, query: core_entities.Query, stage_inst_name: str) -> entities.StageProcessResult:
|
||||
|
||||
@@ -36,6 +36,7 @@ class Forward(MessageComponent):
|
||||
return '[聊天记录]'
|
||||
|
||||
|
||||
@strategy_model.strategy_class("forward")
|
||||
class ForwardComponentStrategy(strategy_model.LongTextStrategy):
|
||||
|
||||
async def process(self, message: str, query: core_entities.Query) -> list[MessageComponent]:
|
||||
|
||||
@@ -15,6 +15,7 @@ from .. import strategy as strategy_model
|
||||
from ....core import entities as core_entities
|
||||
|
||||
|
||||
@strategy_model.strategy_class("image")
|
||||
class Text2ImageStrategy(strategy_model.LongTextStrategy):
|
||||
|
||||
text_render_font: ImageFont.FreeTypeFont
|
||||
|
||||
@@ -9,7 +9,39 @@ from ...core import app
|
||||
from ...core import entities as core_entities
|
||||
|
||||
|
||||
preregistered_strategies: list[typing.Type[LongTextStrategy]] = []
|
||||
|
||||
|
||||
def strategy_class(
|
||||
name: str
|
||||
) -> typing.Callable[[typing.Type[LongTextStrategy]], typing.Type[LongTextStrategy]]:
|
||||
"""长文本处理策略类装饰器
|
||||
|
||||
Args:
|
||||
name (str): 策略名称
|
||||
|
||||
Returns:
|
||||
typing.Callable[[typing.Type[LongTextStrategy]], typing.Type[LongTextStrategy]]: 装饰器
|
||||
"""
|
||||
|
||||
def decorator(cls: typing.Type[LongTextStrategy]) -> typing.Type[LongTextStrategy]:
|
||||
assert issubclass(cls, LongTextStrategy)
|
||||
|
||||
cls.name = name
|
||||
|
||||
preregistered_strategies.append(cls)
|
||||
|
||||
return cls
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
class LongTextStrategy(metaclass=abc.ABCMeta):
|
||||
"""长文本处理策略抽象类
|
||||
"""
|
||||
|
||||
name: str
|
||||
|
||||
ap: app.Application
|
||||
|
||||
def __init__(self, ap: app.Application):
|
||||
@@ -20,4 +52,15 @@ class LongTextStrategy(metaclass=abc.ABCMeta):
|
||||
|
||||
@abc.abstractmethod
|
||||
async def process(self, message: str, query: core_entities.Query) -> list[MessageComponent]:
|
||||
"""处理长文本
|
||||
|
||||
在 platform.json 中配置 long-text-process 字段,只要 文本长度超过了 threshold 就会调用此方法
|
||||
|
||||
Args:
|
||||
message (str): 消息
|
||||
query (core_entities.Query): 此次请求的上下文对象
|
||||
|
||||
Returns:
|
||||
list[mirai.models.messages.MessageComponent]: 转换后的 YiriMirai 消息组件列表
|
||||
"""
|
||||
return []
|
||||
|
||||
@@ -4,11 +4,12 @@ import asyncio
|
||||
|
||||
import mirai
|
||||
|
||||
from . import entities
|
||||
from ..core import entities
|
||||
from ..platform import adapter as msadapter
|
||||
|
||||
|
||||
class QueryPool:
|
||||
"""请求池,请求获得调度进入pipeline之前,保存在这里"""
|
||||
|
||||
query_id_counter: int = 0
|
||||
|
||||
@@ -8,7 +8,7 @@ from ...plugin import events
|
||||
|
||||
@stage.stage_class("PreProcessor")
|
||||
class PreProcessor(stage.PipelineStage):
|
||||
"""预处理器
|
||||
"""请求预处理阶段
|
||||
"""
|
||||
|
||||
async def process(
|
||||
@@ -51,28 +51,6 @@ class PreProcessor(stage.PipelineStage):
|
||||
query.prompt.messages = event_ctx.event.default_prompt
|
||||
query.messages = event_ctx.event.prompt
|
||||
|
||||
# 根据模型max_tokens剪裁
|
||||
max_tokens = min(query.use_model.max_tokens, self.ap.pipeline_cfg.data['submit-messages-tokens'])
|
||||
|
||||
test_messages = query.prompt.messages + query.messages + [query.user_message]
|
||||
|
||||
while await query.use_model.tokenizer.count_token(test_messages, query.use_model) > max_tokens:
|
||||
# 前文都pop完了,还是大于max_tokens,由于prompt和user_messages不能删减,报错
|
||||
if len(query.prompt.messages) == 0:
|
||||
return entities.StageProcessResult(
|
||||
result_type=entities.ResultType.INTERRUPT,
|
||||
new_query=query,
|
||||
user_notice='输入内容过长,请减少情景预设或者输入内容长度',
|
||||
console_notice='输入内容过长,请减少情景预设或者输入内容长度,或者增大配置文件中的 submit-messages-tokens 项(但不能超过所用模型最大tokens数)'
|
||||
)
|
||||
|
||||
query.messages.pop(0) # pop第一个肯定是role=user的
|
||||
# 继续pop到第二个role=user前一个
|
||||
while len(query.messages) > 0 and query.messages[0].role != 'user':
|
||||
query.messages.pop(0)
|
||||
|
||||
test_messages = query.prompt.messages + query.messages + [query.user_message]
|
||||
|
||||
return entities.StageProcessResult(
|
||||
result_type=entities.ResultType.CONTINUE,
|
||||
new_query=query
|
||||
|
||||
@@ -23,3 +23,12 @@ class MessageHandler(metaclass=abc.ABCMeta):
|
||||
query: core_entities.Query,
|
||||
) -> entities.StageProcessResult:
|
||||
raise NotImplementedError
|
||||
|
||||
def cut_str(self, s: str) -> str:
|
||||
"""
|
||||
取字符串第一行,最多20个字符,若有多行,或超过20个字符,则加省略号
|
||||
"""
|
||||
s0 = s.split('\n')[0]
|
||||
if len(s0) > 20 or '\n' in s:
|
||||
s0 = s0[:20] + '...'
|
||||
return s0
|
||||
|
||||
@@ -21,8 +21,6 @@ class ChatMessageHandler(handler.MessageHandler):
|
||||
) -> typing.AsyncGenerator[entities.StageProcessResult, None]:
|
||||
"""处理
|
||||
"""
|
||||
# 取session
|
||||
# 取conversation
|
||||
# 调API
|
||||
# 生成器
|
||||
|
||||
@@ -41,7 +39,14 @@ class ChatMessageHandler(handler.MessageHandler):
|
||||
|
||||
if event_ctx.is_prevented_default():
|
||||
if event_ctx.event.reply is not None:
|
||||
query.resp_message_chain = mirai.MessageChain(event_ctx.event.reply)
|
||||
mc = mirai.MessageChain(event_ctx.event.reply)
|
||||
|
||||
query.resp_messages.append(
|
||||
llm_entities.Message(
|
||||
role='plugin',
|
||||
content=str(mc),
|
||||
)
|
||||
)
|
||||
|
||||
yield entities.StageProcessResult(
|
||||
result_type=entities.ResultType.CONTINUE,
|
||||
@@ -78,6 +83,8 @@ class ChatMessageHandler(handler.MessageHandler):
|
||||
async for result in query.use_model.requester.request(query):
|
||||
query.resp_messages.append(result)
|
||||
|
||||
self.ap.logger.info(f'对话({query.query_id})响应: {self.cut_str(result.readable_str())}')
|
||||
|
||||
if result.content is not None:
|
||||
text_length += len(result.content)
|
||||
|
||||
@@ -86,6 +93,9 @@ class ChatMessageHandler(handler.MessageHandler):
|
||||
new_query=query
|
||||
)
|
||||
except Exception as e:
|
||||
|
||||
self.ap.logger.error(f'对话({query.query_id})请求失败: {str(e)}')
|
||||
|
||||
yield entities.StageProcessResult(
|
||||
result_type=entities.ResultType.INTERRUPT,
|
||||
new_query=query,
|
||||
|
||||
@@ -19,15 +19,16 @@ class CommandHandler(handler.MessageHandler):
|
||||
"""处理
|
||||
"""
|
||||
|
||||
event_class = events.PersonCommandSent if query.launcher_type == core_entities.LauncherTypes.PERSON else events.GroupCommandSent
|
||||
|
||||
command_text = str(query.message_chain).strip()[1:]
|
||||
|
||||
privilege = 1
|
||||
|
||||
if f'{query.launcher_type.value}_{query.launcher_id}' in self.ap.system_cfg.data['admin-sessions']:
|
||||
privilege = 2
|
||||
|
||||
spt = str(query.message_chain).strip().split(' ')
|
||||
spt = command_text.split(' ')
|
||||
|
||||
event_class = events.PersonCommandSent if query.launcher_type == core_entities.LauncherTypes.PERSON else events.GroupCommandSent
|
||||
|
||||
event_ctx = await self.ap.plugin_mgr.emit_event(
|
||||
event=event_class(
|
||||
@@ -73,8 +74,6 @@ class CommandHandler(handler.MessageHandler):
|
||||
|
||||
session = await self.ap.sess_mgr.get_session(query)
|
||||
|
||||
command_text = str(query.message_chain).strip()[1:]
|
||||
|
||||
async for ret in self.ap.cmd_mgr.execute(
|
||||
command_text=command_text,
|
||||
query=query,
|
||||
@@ -91,6 +90,8 @@ class CommandHandler(handler.MessageHandler):
|
||||
)
|
||||
)
|
||||
|
||||
self.ap.logger.info(f'命令({query.query_id})报错: {self.cut_str(str(ret.error))}')
|
||||
|
||||
yield entities.StageProcessResult(
|
||||
result_type=entities.ResultType.CONTINUE,
|
||||
new_query=query
|
||||
@@ -106,6 +107,8 @@ class CommandHandler(handler.MessageHandler):
|
||||
)
|
||||
)
|
||||
|
||||
self.ap.logger.info(f'命令返回: {self.cut_str(ret.text)}')
|
||||
|
||||
yield entities.StageProcessResult(
|
||||
result_type=entities.ResultType.CONTINUE,
|
||||
new_query=query
|
||||
|
||||
@@ -11,6 +11,7 @@ from ...config import manager as cfg_mgr
|
||||
|
||||
@stage.stage_class("MessageProcessor")
|
||||
class Processor(stage.PipelineStage):
|
||||
"""请求实际处理阶段"""
|
||||
|
||||
cmd_handler: handler.MessageHandler
|
||||
|
||||
@@ -34,7 +35,12 @@ class Processor(stage.PipelineStage):
|
||||
|
||||
self.ap.logger.info(f"处理 {query.launcher_type.value}_{query.launcher_id} 的请求({query.query_id}): {message_text}")
|
||||
|
||||
if message_text.startswith('!') or message_text.startswith('!'):
|
||||
return self.cmd_handler.handle(query)
|
||||
else:
|
||||
return self.chat_handler.handle(query)
|
||||
async def generator():
|
||||
if message_text.startswith('!') or message_text.startswith('!'):
|
||||
async for result in self.cmd_handler.handle(query):
|
||||
yield result
|
||||
else:
|
||||
async for result in self.chat_handler.handle(query):
|
||||
yield result
|
||||
|
||||
return generator()
|
||||
|
||||
@@ -1,11 +1,27 @@
|
||||
from __future__ import annotations
|
||||
import abc
|
||||
import typing
|
||||
|
||||
from ...core import app
|
||||
|
||||
|
||||
class ReteLimitAlgo(metaclass=abc.ABCMeta):
|
||||
preregistered_algos: list[typing.Type[ReteLimitAlgo]] = []
|
||||
|
||||
def algo_class(name: str):
|
||||
|
||||
def decorator(cls: typing.Type[ReteLimitAlgo]) -> typing.Type[ReteLimitAlgo]:
|
||||
cls.name = name
|
||||
preregistered_algos.append(cls)
|
||||
return cls
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
class ReteLimitAlgo(metaclass=abc.ABCMeta):
|
||||
"""限流算法抽象类"""
|
||||
|
||||
name: str = None
|
||||
|
||||
ap: app.Application
|
||||
|
||||
def __init__(self, ap: app.Application):
|
||||
@@ -16,9 +32,27 @@ class ReteLimitAlgo(metaclass=abc.ABCMeta):
|
||||
|
||||
@abc.abstractmethod
|
||||
async def require_access(self, launcher_type: str, launcher_id: int) -> bool:
|
||||
"""进入处理流程
|
||||
|
||||
这个方法对等待是友好的,意味着算法可以实现在这里等待一段时间以控制速率。
|
||||
|
||||
Args:
|
||||
launcher_type (str): 请求者类型 群聊为 group 私聊为 person
|
||||
launcher_id (int): 请求者ID
|
||||
|
||||
Returns:
|
||||
bool: 是否允许进入处理流程,若返回false,则直接丢弃该请求
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
async def release_access(self, launcher_type: str, launcher_id: int):
|
||||
"""退出处理流程
|
||||
|
||||
Args:
|
||||
launcher_type (str): 请求者类型 群聊为 group 私聊为 person
|
||||
launcher_id (int): 请求者ID
|
||||
"""
|
||||
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -19,6 +19,7 @@ class SessionContainer:
|
||||
self.records = {}
|
||||
|
||||
|
||||
@algo.algo_class("fixwin")
|
||||
class FixedWindowAlgo(algo.ReteLimitAlgo):
|
||||
|
||||
containers_lock: asyncio.Lock
|
||||
|
||||
@@ -11,11 +11,24 @@ from ...core import entities as core_entities
|
||||
@stage.stage_class("RequireRateLimitOccupancy")
|
||||
@stage.stage_class("ReleaseRateLimitOccupancy")
|
||||
class RateLimit(stage.PipelineStage):
|
||||
"""限速器控制阶段"""
|
||||
|
||||
algo: algo.ReteLimitAlgo
|
||||
|
||||
async def initialize(self):
|
||||
self.algo = fixedwin.FixedWindowAlgo(self.ap)
|
||||
|
||||
algo_name = self.ap.pipeline_cfg.data['rate-limit']['algo']
|
||||
|
||||
algo_class = None
|
||||
|
||||
for algo_cls in algo.preregistered_algos:
|
||||
if algo_cls.name == algo_name:
|
||||
algo_class = algo_cls
|
||||
break
|
||||
else:
|
||||
raise ValueError(f'未知的限速算法: {algo_name}')
|
||||
|
||||
self.algo = algo_class(self.ap)
|
||||
await self.algo.initialize()
|
||||
|
||||
async def process(
|
||||
@@ -46,7 +59,7 @@ class RateLimit(stage.PipelineStage):
|
||||
)
|
||||
elif stage_inst_name == "ReleaseRateLimitOccupancy":
|
||||
await self.algo.release_access(
|
||||
query.launcher_type,
|
||||
query.launcher_type.value,
|
||||
query.launcher_id,
|
||||
)
|
||||
return entities.StageProcessResult(
|
||||
|
||||
@@ -29,7 +29,7 @@ class SendResponseBackStage(stage.PipelineStage):
|
||||
|
||||
await asyncio.sleep(random_delay)
|
||||
|
||||
await self.ap.im_mgr.send(
|
||||
await self.ap.platform_mgr.send(
|
||||
query.message_event,
|
||||
query.resp_message_chain,
|
||||
adapter=query.adapter
|
||||
|
||||
@@ -21,15 +21,13 @@ class GroupRespondRuleCheckStage(stage.PipelineStage):
|
||||
async def initialize(self):
|
||||
"""初始化检查器
|
||||
"""
|
||||
self.rule_matchers = [
|
||||
atbot.AtBotRule(self.ap),
|
||||
prefix.PrefixRule(self.ap),
|
||||
regexp.RegExpRule(self.ap),
|
||||
random.RandomRespRule(self.ap),
|
||||
]
|
||||
|
||||
for rule_matcher in self.rule_matchers:
|
||||
await rule_matcher.initialize()
|
||||
self.rule_matchers = []
|
||||
|
||||
for rule_matcher in rule.preregisetered_rules:
|
||||
rule_inst = rule_matcher(self.ap)
|
||||
await rule_inst.initialize()
|
||||
self.rule_matchers.append(rule_inst)
|
||||
|
||||
async def process(self, query: core_entities.Query, stage_inst_name: str) -> entities.StageProcessResult:
|
||||
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from __future__ import annotations
|
||||
import abc
|
||||
import typing
|
||||
|
||||
import mirai
|
||||
|
||||
@@ -7,9 +8,20 @@ from ...core import app, entities as core_entities
|
||||
from . import entities
|
||||
|
||||
|
||||
preregisetered_rules: list[typing.Type[GroupRespondRule]] = []
|
||||
|
||||
def rule_class(name: str):
|
||||
def decorator(cls: typing.Type[GroupRespondRule]) -> typing.Type[GroupRespondRule]:
|
||||
cls.name = name
|
||||
preregisetered_rules.append(cls)
|
||||
return cls
|
||||
return decorator
|
||||
|
||||
|
||||
class GroupRespondRule(metaclass=abc.ABCMeta):
|
||||
"""群组响应规则的抽象类
|
||||
"""
|
||||
name: str
|
||||
|
||||
ap: app.Application
|
||||
|
||||
|
||||
@@ -7,6 +7,7 @@ from .. import entities
|
||||
from ....core import entities as core_entities
|
||||
|
||||
|
||||
@rule_model.rule_class("at-bot")
|
||||
class AtBotRule(rule_model.GroupRespondRule):
|
||||
|
||||
async def match(
|
||||
@@ -19,6 +20,10 @@ class AtBotRule(rule_model.GroupRespondRule):
|
||||
|
||||
if message_chain.has(mirai.At(query.adapter.bot_account_id)) and rule_dict['at']:
|
||||
message_chain.remove(mirai.At(query.adapter.bot_account_id))
|
||||
|
||||
if message_chain.has(mirai.At(query.adapter.bot_account_id)): # 回复消息时会at两次,检查并删除重复的
|
||||
message_chain.remove(mirai.At(query.adapter.bot_account_id))
|
||||
|
||||
return entities.RuleJudgeResult(
|
||||
matching=True,
|
||||
replacement=message_chain,
|
||||
|
||||
@@ -5,6 +5,7 @@ from .. import entities
|
||||
from ....core import entities as core_entities
|
||||
|
||||
|
||||
@rule_model.rule_class("prefix")
|
||||
class PrefixRule(rule_model.GroupRespondRule):
|
||||
|
||||
async def match(
|
||||
|
||||
@@ -7,6 +7,7 @@ from .. import entities
|
||||
from ....core import entities as core_entities
|
||||
|
||||
|
||||
@rule_model.rule_class("random")
|
||||
class RandomRespRule(rule_model.GroupRespondRule):
|
||||
|
||||
async def match(
|
||||
|
||||
@@ -7,6 +7,7 @@ from .. import entities
|
||||
from ....core import entities as core_entities
|
||||
|
||||
|
||||
@rule_model.rule_class("regexp")
|
||||
class RegExpRule(rule_model.GroupRespondRule):
|
||||
|
||||
async def match(
|
||||
|
||||
@@ -15,6 +15,7 @@ from .preproc import preproc
|
||||
from .ratelimit import ratelimit
|
||||
|
||||
|
||||
# 请求处理阶段顺序
|
||||
stage_order = [
|
||||
"GroupRespondRuleCheckStage",
|
||||
"BanSessionCheckStage",
|
||||
|
||||
@@ -29,6 +29,13 @@ class ResponseWrapper(stage.PipelineStage):
|
||||
if query.resp_messages[-1].role == 'command':
|
||||
query.resp_message_chain = mirai.MessageChain("[bot] "+query.resp_messages[-1].content)
|
||||
|
||||
yield entities.StageProcessResult(
|
||||
result_type=entities.ResultType.CONTINUE,
|
||||
new_query=query
|
||||
)
|
||||
elif query.resp_messages[-1].role == 'plugin':
|
||||
query.resp_message_chain = mirai.MessageChain(query.resp_messages[-1].content)
|
||||
|
||||
yield entities.StageProcessResult(
|
||||
result_type=entities.ResultType.CONTINUE,
|
||||
new_query=query
|
||||
|
||||
@@ -14,6 +14,14 @@ preregistered_adapters: list[typing.Type[MessageSourceAdapter]] = []
|
||||
def adapter_class(
|
||||
name: str
|
||||
):
|
||||
"""消息平台适配器类装饰器
|
||||
|
||||
Args:
|
||||
name (str): 适配器名称
|
||||
|
||||
Returns:
|
||||
typing.Callable[[typing.Type[MessageSourceAdapter]], typing.Type[MessageSourceAdapter]]: 装饰器
|
||||
"""
|
||||
def decorator(cls: typing.Type[MessageSourceAdapter]) -> typing.Type[MessageSourceAdapter]:
|
||||
cls.name = name
|
||||
preregistered_adapters.append(cls)
|
||||
@@ -22,15 +30,24 @@ def adapter_class(
|
||||
|
||||
|
||||
class MessageSourceAdapter(metaclass=abc.ABCMeta):
|
||||
"""消息平台适配器基类"""
|
||||
|
||||
name: str
|
||||
|
||||
bot_account_id: int
|
||||
"""机器人账号ID,需要在初始化时设置"""
|
||||
|
||||
config: dict
|
||||
|
||||
ap: app.Application
|
||||
|
||||
def __init__(self, config: dict, ap: app.Application):
|
||||
"""初始化适配器
|
||||
|
||||
Args:
|
||||
config (dict): 对应的配置
|
||||
ap (app.Application): 应用上下文
|
||||
"""
|
||||
self.config = config
|
||||
self.ap = ap
|
||||
|
||||
@@ -40,7 +57,7 @@ class MessageSourceAdapter(metaclass=abc.ABCMeta):
|
||||
target_id: str,
|
||||
message: mirai.MessageChain
|
||||
):
|
||||
"""发送消息
|
||||
"""主动发送消息
|
||||
|
||||
Args:
|
||||
target_type (str): 目标类型,`person`或`group`
|
||||
|
||||
@@ -146,7 +146,7 @@ class PlatformManager:
|
||||
if len(self.adapters) == 0:
|
||||
self.ap.logger.warning('未运行平台适配器,请根据文档配置并启用平台适配器。')
|
||||
|
||||
async def send(self, event, msg, adapter: msadapter.MessageSourceAdapter, check_quote=True, check_at_sender=True):
|
||||
async def send(self, event: mirai.MessageEvent, msg: mirai.MessageChain, adapter: msadapter.MessageSourceAdapter, check_quote=True, check_at_sender=True):
|
||||
|
||||
if check_at_sender and self.ap.platform_cfg.data['at-sender'] and isinstance(event, GroupMessage):
|
||||
|
||||
@@ -163,25 +163,6 @@ class PlatformManager:
|
||||
quote_origin=True if self.ap.platform_cfg.data['quote-origin'] and check_quote else False
|
||||
)
|
||||
|
||||
# 通知系统管理员
|
||||
# TODO delete
|
||||
# async def notify_admin(self, message: str):
|
||||
# await self.notify_admin_message_chain(MessageChain([Plain("[bot]{}".format(message))]))
|
||||
|
||||
# async def notify_admin_message_chain(self, message: mirai.MessageChain):
|
||||
# if self.ap.system_cfg.data['admin-sessions'] != []:
|
||||
|
||||
# admin_list = []
|
||||
# for admin in self.ap.system_cfg.data['admin-sessions']:
|
||||
# admin_list.append(admin)
|
||||
|
||||
# for adm in admin_list:
|
||||
# self.adapter.send_message(
|
||||
# adm.split("_")[0],
|
||||
# adm.split("_")[1],
|
||||
# message
|
||||
# )
|
||||
|
||||
async def run(self):
|
||||
try:
|
||||
tasks = []
|
||||
|
||||
@@ -40,7 +40,6 @@ class AiocqhttpMessageConverter(adapter.MessageConverter):
|
||||
elif type(msg) is mirai.Voice:
|
||||
msg_list.append(aiocqhttp.MessageSegment.record(msg.path))
|
||||
elif type(msg) is forward.Forward:
|
||||
# print("aiocqhttp 暂不支持转发消息组件的转换,使用普通消息链发送")
|
||||
|
||||
for node in msg.node_list:
|
||||
msg_list.extend(AiocqhttpMessageConverter.yiri2target(node.message_chain)[0])
|
||||
@@ -170,7 +169,7 @@ class AiocqhttpEventConverter(adapter.EventConverter):
|
||||
name=event.sender["nickname"],
|
||||
permission=mirai.models.entities.Permission.Member,
|
||||
),
|
||||
special_title=event.sender["title"],
|
||||
special_title=event.sender["title"] if "title" in event.sender else "",
|
||||
join_timestamp=0,
|
||||
last_speak_timestamp=0,
|
||||
mute_time_remaining=0,
|
||||
@@ -216,21 +215,28 @@ class AiocqhttpAdapter(adapter.MessageSourceAdapter):
|
||||
|
||||
self.ap = ap
|
||||
|
||||
self.bot = aiocqhttp.CQHttp()
|
||||
if "access-token" in config:
|
||||
self.bot = aiocqhttp.CQHttp(access_token=config["access-token"])
|
||||
del self.config["access-token"]
|
||||
else:
|
||||
self.bot = aiocqhttp.CQHttp()
|
||||
|
||||
async def send_message(
|
||||
self, target_type: str, target_id: str, message: mirai.MessageChain
|
||||
):
|
||||
# TODO 实现发送消息
|
||||
return super().send_message(target_type, target_id, message)
|
||||
aiocq_msg = AiocqhttpMessageConverter.yiri2target(message)[0]
|
||||
|
||||
if target_type == "group":
|
||||
await self.bot.send_group_msg(group_id=int(target_id), message=aiocq_msg)
|
||||
elif target_type == "person":
|
||||
await self.bot.send_private_msg(user_id=int(target_id), message=aiocq_msg)
|
||||
|
||||
async def reply_message(
|
||||
self,
|
||||
message_source: mirai.MessageEvent,
|
||||
message: mirai.MessageChain,
|
||||
quote_origin: bool = False,
|
||||
):
|
||||
|
||||
):
|
||||
aiocq_event = AiocqhttpEventConverter.yiri2target(message_source, self.bot_account_id)
|
||||
aiocq_msg = AiocqhttpMessageConverter.yiri2target(message)[0]
|
||||
if quote_origin:
|
||||
|
||||
@@ -24,6 +24,8 @@ class NakuruProjectMessageConverter(adapter_model.MessageConverter):
|
||||
msg_list = message_chain.__root__
|
||||
elif type(message_chain) is list:
|
||||
msg_list = message_chain
|
||||
elif type(message_chain) is str:
|
||||
msg_list = [mirai.Plain(message_chain)]
|
||||
else:
|
||||
raise Exception("Unknown message type: " + str(message_chain) + str(type(message_chain)))
|
||||
|
||||
|
||||
@@ -17,6 +17,7 @@ import botpy.types.message as botpy_message_type
|
||||
from .. import adapter as adapter_model
|
||||
from ...pipeline.longtext.strategies import forward
|
||||
from ...core import app
|
||||
from ...config import manager as cfg_mgr
|
||||
|
||||
|
||||
class OfficialGroupMessage(mirai.GroupMessage):
|
||||
@@ -34,52 +35,92 @@ cached_message_ids = {}
|
||||
|
||||
id_index = 0
|
||||
|
||||
|
||||
def save_msg_id(message_id: str) -> int:
|
||||
"""保存消息id"""
|
||||
global id_index, cached_message_ids
|
||||
|
||||
|
||||
crt_index = id_index
|
||||
id_index += 1
|
||||
cached_message_ids[str(crt_index)] = message_id
|
||||
return crt_index
|
||||
|
||||
cached_member_openids = {}
|
||||
"""QQ官方 用户的id是字符串,而YiriMirai的用户id是整数,所以需要一个索引来进行转换"""
|
||||
|
||||
member_openid_index = 100
|
||||
|
||||
def save_member_openid(member_openid: str) -> int:
|
||||
"""保存用户id"""
|
||||
global member_openid_index, cached_member_openids
|
||||
def char_to_value(char):
|
||||
"""将单个字符转换为相应的数值。"""
|
||||
if '0' <= char <= '9':
|
||||
return ord(char) - ord('0')
|
||||
elif 'A' <= char <= 'Z':
|
||||
return ord(char) - ord('A') + 10
|
||||
|
||||
if member_openid in cached_member_openids.values():
|
||||
return list(cached_member_openids.keys())[list(cached_member_openids.values()).index(member_openid)]
|
||||
|
||||
crt_index = member_openid_index
|
||||
member_openid_index += 1
|
||||
cached_member_openids[str(crt_index)] = member_openid
|
||||
return crt_index
|
||||
return ord(char) - ord('a') + 36
|
||||
|
||||
cached_group_openids = {}
|
||||
"""QQ官方 群组的id是字符串,而YiriMirai的群组id是整数,所以需要一个索引来进行转换"""
|
||||
def digest(s: str) -> int:
|
||||
"""计算字符串的hash值。"""
|
||||
# 取末尾的8位
|
||||
sub_s = s[-10:]
|
||||
|
||||
group_openid_index = 1000
|
||||
number = 0
|
||||
base = 36
|
||||
|
||||
def save_group_openid(group_openid: str) -> int:
|
||||
"""保存群组id"""
|
||||
global group_openid_index, cached_group_openids
|
||||
for i in range(len(sub_s)):
|
||||
number = number * base + char_to_value(sub_s[i])
|
||||
|
||||
return number
|
||||
|
||||
K = typing.TypeVar("K")
|
||||
V = typing.TypeVar("V")
|
||||
|
||||
|
||||
class OpenIDMapping(typing.Generic[K, V]):
|
||||
|
||||
map: dict[K, V]
|
||||
|
||||
dump_func: typing.Callable
|
||||
|
||||
digest_func: typing.Callable[[K], V]
|
||||
|
||||
def __init__(self, map: dict[K, V], dump_func: typing.Callable, digest_func: typing.Callable[[K], V] = digest):
|
||||
self.map = map
|
||||
|
||||
self.dump_func = dump_func
|
||||
|
||||
self.digest_func = digest_func
|
||||
|
||||
def __getitem__(self, key: K) -> V:
|
||||
return self.map[key]
|
||||
|
||||
def __setitem__(self, key: K, value: V):
|
||||
self.map[key] = value
|
||||
self.dump_func()
|
||||
|
||||
def __contains__(self, key: K) -> bool:
|
||||
return key in self.map
|
||||
|
||||
def __delitem__(self, key: K):
|
||||
del self.map[key]
|
||||
self.dump_func()
|
||||
|
||||
def getkey(self, value: V) -> K:
|
||||
return list(self.map.keys())[list(self.map.values()).index(value)]
|
||||
|
||||
if group_openid in cached_group_openids.values():
|
||||
return list(cached_group_openids.keys())[list(cached_group_openids.values()).index(group_openid)]
|
||||
|
||||
crt_index = group_openid_index
|
||||
group_openid_index += 1
|
||||
cached_group_openids[str(crt_index)] = group_openid
|
||||
return crt_index
|
||||
def save_openid(self, key: K) -> V:
|
||||
|
||||
if key in self.map:
|
||||
return self.map[key]
|
||||
|
||||
value = self.digest_func(key)
|
||||
|
||||
self.map[key] = value
|
||||
|
||||
self.dump_func()
|
||||
|
||||
return value
|
||||
|
||||
|
||||
class OfficialMessageConverter(adapter_model.MessageConverter):
|
||||
"""QQ 官方消息转换器"""
|
||||
|
||||
@staticmethod
|
||||
def yiri2target(message_chain: mirai.MessageChain):
|
||||
"""将 YiriMirai 的消息链转换为 QQ 官方消息"""
|
||||
@@ -89,9 +130,13 @@ class OfficialMessageConverter(adapter_model.MessageConverter):
|
||||
msg_list = message_chain.__root__
|
||||
elif type(message_chain) is list:
|
||||
msg_list = message_chain
|
||||
elif type(message_chain) is str:
|
||||
msg_list = [mirai.Plain(text=message_chain)]
|
||||
else:
|
||||
raise Exception("Unknown message type: " + str(message_chain) + str(type(message_chain)))
|
||||
|
||||
raise Exception(
|
||||
"Unknown message type: " + str(message_chain) + str(type(message_chain))
|
||||
)
|
||||
|
||||
offcial_messages: list[dict] = []
|
||||
"""
|
||||
{
|
||||
@@ -108,36 +153,24 @@ class OfficialMessageConverter(adapter_model.MessageConverter):
|
||||
# 遍历并转换
|
||||
for component in msg_list:
|
||||
if type(component) is mirai.Plain:
|
||||
offcial_messages.append({
|
||||
"type": "text",
|
||||
"content": component.text
|
||||
})
|
||||
offcial_messages.append({"type": "text", "content": component.text})
|
||||
elif type(component) is mirai.Image:
|
||||
if component.url is not None:
|
||||
offcial_messages.append(
|
||||
{
|
||||
"type": "image",
|
||||
"content": component.url
|
||||
}
|
||||
)
|
||||
offcial_messages.append({"type": "image", "content": component.url})
|
||||
elif component.path is not None:
|
||||
offcial_messages.append(
|
||||
{
|
||||
"type": "file_image",
|
||||
"content": component.path
|
||||
}
|
||||
{"type": "file_image", "content": component.path}
|
||||
)
|
||||
elif type(component) is mirai.At:
|
||||
offcial_messages.append(
|
||||
{
|
||||
"type": "at",
|
||||
"content": ""
|
||||
}
|
||||
)
|
||||
offcial_messages.append({"type": "at", "content": ""})
|
||||
elif type(component) is mirai.AtAll:
|
||||
print("上层组件要求发送 AtAll 消息,但 QQ 官方 API 不支持此消息类型,忽略此消息。")
|
||||
print(
|
||||
"上层组件要求发送 AtAll 消息,但 QQ 官方 API 不支持此消息类型,忽略此消息。"
|
||||
)
|
||||
elif type(component) is mirai.Voice:
|
||||
print("上层组件要求发送 Voice 消息,但 QQ 官方 API 不支持此消息类型,忽略此消息。")
|
||||
print(
|
||||
"上层组件要求发送 Voice 消息,但 QQ 官方 API 不支持此消息类型,忽略此消息。"
|
||||
)
|
||||
elif type(component) is forward.Forward:
|
||||
# 转发消息
|
||||
yiri_forward_node_list = component.node_list
|
||||
@@ -146,22 +179,33 @@ class OfficialMessageConverter(adapter_model.MessageConverter):
|
||||
for yiri_forward_node in yiri_forward_node_list:
|
||||
try:
|
||||
message_chain = yiri_forward_node.message_chain
|
||||
|
||||
|
||||
# 平铺
|
||||
offcial_messages.extend(OfficialMessageConverter.yiri2target(message_chain))
|
||||
offcial_messages.extend(
|
||||
OfficialMessageConverter.yiri2target(message_chain)
|
||||
)
|
||||
except Exception as e:
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
|
||||
return offcial_messages
|
||||
|
||||
|
||||
@staticmethod
|
||||
def extract_message_chain_from_obj(message: typing.Union[botpy_message.Message, botpy_message.DirectMessage], message_id: str = None, bot_account_id: int = 0) -> mirai.MessageChain:
|
||||
def extract_message_chain_from_obj(
|
||||
message: typing.Union[botpy_message.Message, botpy_message.DirectMessage],
|
||||
message_id: str = None,
|
||||
bot_account_id: int = 0,
|
||||
) -> mirai.MessageChain:
|
||||
yiri_msg_list = []
|
||||
|
||||
# 存id
|
||||
|
||||
yiri_msg_list.append(mirai.models.message.Source(id=save_msg_id(message_id), time=datetime.datetime.now()))
|
||||
yiri_msg_list.append(
|
||||
mirai.models.message.Source(
|
||||
id=save_msg_id(message_id), time=datetime.datetime.now()
|
||||
)
|
||||
)
|
||||
|
||||
if type(message) is not botpy_message.DirectMessage:
|
||||
yiri_msg_list.append(mirai.At(target=bot_account_id))
|
||||
@@ -177,7 +221,9 @@ class OfficialMessageConverter(adapter_model.MessageConverter):
|
||||
if attachment.content_type == "image":
|
||||
yiri_msg_list.append(mirai.Image(url=attachment.url))
|
||||
else:
|
||||
logging.warning("不支持的附件类型:" + attachment.content_type + ",忽略此附件。")
|
||||
logging.warning(
|
||||
"不支持的附件类型:" + attachment.content_type + ",忽略此附件。"
|
||||
)
|
||||
|
||||
content = re.sub(r"<@!\d+>", "", str(message.content))
|
||||
if content.strip() != "":
|
||||
@@ -186,29 +232,40 @@ class OfficialMessageConverter(adapter_model.MessageConverter):
|
||||
chain = mirai.MessageChain(yiri_msg_list)
|
||||
|
||||
return chain
|
||||
|
||||
|
||||
|
||||
class OfficialEventConverter(adapter_model.EventConverter):
|
||||
"""事件转换器"""
|
||||
@staticmethod
|
||||
def yiri2target(event: typing.Type[mirai.Event]):
|
||||
|
||||
member_openid_mapping: OpenIDMapping[str, int]
|
||||
group_openid_mapping: OpenIDMapping[str, int]
|
||||
|
||||
def __init__(self, member_openid_mapping: OpenIDMapping[str, int], group_openid_mapping: OpenIDMapping[str, int]):
|
||||
self.member_openid_mapping = member_openid_mapping
|
||||
self.group_openid_mapping = group_openid_mapping
|
||||
|
||||
def yiri2target(self, event: typing.Type[mirai.Event]):
|
||||
if event == mirai.GroupMessage:
|
||||
return botpy_message.Message
|
||||
elif event == mirai.FriendMessage:
|
||||
return botpy_message.DirectMessage
|
||||
else:
|
||||
raise Exception("未支持转换的事件类型(YiriMirai -> Official): " + str(event))
|
||||
raise Exception(
|
||||
"未支持转换的事件类型(YiriMirai -> Official): " + str(event)
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def target2yiri(event: typing.Union[botpy_message.Message, botpy_message.DirectMessage]) -> mirai.Event:
|
||||
def target2yiri(
|
||||
self,
|
||||
event: typing.Union[botpy_message.Message, botpy_message.DirectMessage]
|
||||
) -> mirai.Event:
|
||||
import mirai.models.entities as mirai_entities
|
||||
|
||||
if type(event) == botpy_message.Message: # 频道内,转群聊事件
|
||||
permission = "MEMBER"
|
||||
|
||||
if '2' in event.member.roles:
|
||||
if "2" in event.member.roles:
|
||||
permission = "ADMINISTRATOR"
|
||||
elif '4' in event.member.roles:
|
||||
elif "4" in event.member.roles:
|
||||
permission = "OWNER"
|
||||
|
||||
return mirai.GroupMessage(
|
||||
@@ -219,15 +276,25 @@ class OfficialEventConverter(adapter_model.EventConverter):
|
||||
group=mirai_entities.Group(
|
||||
id=event.channel_id,
|
||||
name=event.author.username,
|
||||
permission=mirai_entities.Permission.Member
|
||||
permission=mirai_entities.Permission.Member,
|
||||
),
|
||||
special_title="",
|
||||
join_timestamp=int(
|
||||
datetime.datetime.strptime(
|
||||
event.member.joined_at, "%Y-%m-%dT%H:%M:%S%z"
|
||||
).timestamp()
|
||||
),
|
||||
special_title='',
|
||||
join_timestamp=int(datetime.datetime.strptime(event.member.joined_at, "%Y-%m-%dT%H:%M:%S%z").timestamp()),
|
||||
last_speak_timestamp=datetime.datetime.now().timestamp(),
|
||||
mute_time_remaining=0,
|
||||
),
|
||||
message_chain=OfficialMessageConverter.extract_message_chain_from_obj(event, event.id),
|
||||
time=int(datetime.datetime.strptime(event.timestamp, "%Y-%m-%dT%H:%M:%S%z").timestamp()),
|
||||
message_chain=OfficialMessageConverter.extract_message_chain_from_obj(
|
||||
event, event.id
|
||||
),
|
||||
time=int(
|
||||
datetime.datetime.strptime(
|
||||
event.timestamp, "%Y-%m-%dT%H:%M:%S%z"
|
||||
).timestamp()
|
||||
),
|
||||
)
|
||||
elif type(event) == botpy_message.DirectMessage: # 私聊,转私聊事件
|
||||
return mirai.FriendMessage(
|
||||
@@ -236,12 +303,18 @@ class OfficialEventConverter(adapter_model.EventConverter):
|
||||
nickname=event.author.username,
|
||||
remark=event.author.username,
|
||||
),
|
||||
message_chain=OfficialMessageConverter.extract_message_chain_from_obj(event, event.id),
|
||||
time=int(datetime.datetime.strptime(event.timestamp, "%Y-%m-%dT%H:%M:%S%z").timestamp()),
|
||||
message_chain=OfficialMessageConverter.extract_message_chain_from_obj(
|
||||
event, event.id
|
||||
),
|
||||
time=int(
|
||||
datetime.datetime.strptime(
|
||||
event.timestamp, "%Y-%m-%dT%H:%M:%S%z"
|
||||
).timestamp()
|
||||
),
|
||||
)
|
||||
elif type(event) == botpy_message.GroupMessage:
|
||||
|
||||
replacing_member_id = save_member_openid(event.author.member_openid)
|
||||
replacing_member_id = self.member_openid_mapping.save_openid(event.author.member_openid)
|
||||
|
||||
return OfficialGroupMessage(
|
||||
sender=mirai_entities.GroupMember(
|
||||
@@ -249,29 +322,36 @@ class OfficialEventConverter(adapter_model.EventConverter):
|
||||
member_name=replacing_member_id,
|
||||
permission="MEMBER",
|
||||
group=mirai_entities.Group(
|
||||
id=save_group_openid(event.group_openid),
|
||||
id=self.group_openid_mapping.save_openid(event.group_openid),
|
||||
name=replacing_member_id,
|
||||
permission=mirai_entities.Permission.Member
|
||||
permission=mirai_entities.Permission.Member,
|
||||
),
|
||||
special_title='',
|
||||
special_title="",
|
||||
join_timestamp=int(0),
|
||||
last_speak_timestamp=datetime.datetime.now().timestamp(),
|
||||
mute_time_remaining=0,
|
||||
),
|
||||
message_chain=OfficialMessageConverter.extract_message_chain_from_obj(event, event.id),
|
||||
time=int(datetime.datetime.strptime(event.timestamp, "%Y-%m-%dT%H:%M:%S%z").timestamp()),
|
||||
message_chain=OfficialMessageConverter.extract_message_chain_from_obj(
|
||||
event, event.id
|
||||
),
|
||||
time=int(
|
||||
datetime.datetime.strptime(
|
||||
event.timestamp, "%Y-%m-%dT%H:%M:%S%z"
|
||||
).timestamp()
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@adapter_model.adapter_class("qq-botpy")
|
||||
class OfficialAdapter(adapter_model.MessageSourceAdapter):
|
||||
"""QQ 官方消息适配器"""
|
||||
|
||||
bot: botpy.Client = None
|
||||
|
||||
bot_account_id: int = 0
|
||||
|
||||
message_converter: OfficialMessageConverter = OfficialMessageConverter()
|
||||
# event_handler: adapter_model.EventHandler = adapter_model.EventHandler()
|
||||
message_converter: OfficialMessageConverter
|
||||
event_converter: OfficialEventConverter
|
||||
|
||||
cfg: dict = None
|
||||
|
||||
@@ -283,6 +363,11 @@ class OfficialAdapter(adapter_model.MessageSourceAdapter):
|
||||
|
||||
ap: app.Application
|
||||
|
||||
metadata: cfg_mgr.ConfigManager = None
|
||||
|
||||
member_openid_mapping: OpenIDMapping[str, int] = None
|
||||
group_openid_mapping: OpenIDMapping[str, int] = None
|
||||
|
||||
def __init__(self, cfg: dict, ap: app.Application):
|
||||
"""初始化适配器"""
|
||||
self.cfg = cfg
|
||||
@@ -290,86 +375,119 @@ class OfficialAdapter(adapter_model.MessageSourceAdapter):
|
||||
|
||||
switchs = {}
|
||||
|
||||
for intent in cfg['intents']:
|
||||
for intent in cfg["intents"]:
|
||||
switchs[intent] = True
|
||||
|
||||
del cfg['intents']
|
||||
del cfg["intents"]
|
||||
|
||||
intents = botpy.Intents(**switchs)
|
||||
|
||||
self.bot = botpy.Client(intents=intents)
|
||||
|
||||
async def send_message(
|
||||
self,
|
||||
target_type: str,
|
||||
target_id: str,
|
||||
message: mirai.MessageChain
|
||||
self, target_type: str, target_id: str, message: mirai.MessageChain
|
||||
):
|
||||
pass
|
||||
message_list = self.message_converter.yiri2target(message)
|
||||
|
||||
for msg in message_list:
|
||||
args = {}
|
||||
|
||||
if msg["type"] == "text":
|
||||
args["content"] = msg["content"]
|
||||
elif msg["type"] == "image":
|
||||
args["image"] = msg["content"]
|
||||
elif msg["type"] == "file_image":
|
||||
args["file_image"] = msg["content"]
|
||||
else:
|
||||
continue
|
||||
|
||||
if target_type == "group":
|
||||
args["channel_id"] = str(target_id)
|
||||
|
||||
await self.bot.api.post_message(**args)
|
||||
elif target_type == "person":
|
||||
args["guild_id"] = str(target_id)
|
||||
|
||||
await self.bot.api.post_dms(**args)
|
||||
|
||||
async def reply_message(
|
||||
self,
|
||||
message_source: mirai.MessageEvent,
|
||||
message: mirai.MessageChain,
|
||||
quote_origin: bool = False
|
||||
quote_origin: bool = False,
|
||||
):
|
||||
|
||||
message_list = self.message_converter.yiri2target(message)
|
||||
tasks = []
|
||||
|
||||
msg_seq = 1
|
||||
|
||||
for msg in message_list:
|
||||
args = {}
|
||||
|
||||
if msg['type'] == 'text':
|
||||
args['content'] = msg['content']
|
||||
elif msg['type'] == 'image':
|
||||
args['image'] = msg['content']
|
||||
elif msg['type'] == 'file_image':
|
||||
args['file_image'] = msg["content"]
|
||||
if msg["type"] == "text":
|
||||
args["content"] = msg["content"]
|
||||
elif msg["type"] == "image":
|
||||
args["image"] = msg["content"]
|
||||
elif msg["type"] == "file_image":
|
||||
args["file_image"] = msg["content"]
|
||||
else:
|
||||
continue
|
||||
|
||||
if quote_origin:
|
||||
args['message_reference'] = botpy_message_type.Reference(message_id=cached_message_ids[str(message_source.message_chain.message_id)])
|
||||
|
||||
if type(message_source) == mirai.GroupMessage:
|
||||
args['channel_id'] = str(message_source.sender.group.id)
|
||||
args['msg_id'] = cached_message_ids[str(message_source.message_chain.message_id)]
|
||||
await self.bot.api.post_message(**args)
|
||||
elif type(message_source) == mirai.FriendMessage:
|
||||
args['guild_id'] = str(message_source.sender.id)
|
||||
args['msg_id'] = cached_message_ids[str(message_source.message_chain.message_id)]
|
||||
await self.bot.api.post_dms(**args)
|
||||
elif type(message_source) == OfficialGroupMessage:
|
||||
# args['guild_id'] = str(message_source.sender.group.id)
|
||||
# args['msg_id'] = cached_message_ids[str(message_source.message_chain.message_id)]
|
||||
# await self.bot.api.post_message(**args)
|
||||
if 'image' in args or 'file_image' in args:
|
||||
continue
|
||||
args['group_openid'] = cached_group_openids[str(message_source.sender.group.id)]
|
||||
args['msg_id'] = cached_message_ids[str(message_source.message_chain.message_id)]
|
||||
args['msg_seq'] = msg_seq
|
||||
msg_seq += 1
|
||||
await self.bot.api.post_group_message(
|
||||
**args
|
||||
args["message_reference"] = botpy_message_type.Reference(
|
||||
message_id=cached_message_ids[
|
||||
str(message_source.message_chain.message_id)
|
||||
]
|
||||
)
|
||||
|
||||
if type(message_source) == mirai.GroupMessage:
|
||||
args["channel_id"] = str(message_source.sender.group.id)
|
||||
args["msg_id"] = cached_message_ids[
|
||||
str(message_source.message_chain.message_id)
|
||||
]
|
||||
await self.bot.api.post_message(**args)
|
||||
elif type(message_source) == mirai.FriendMessage:
|
||||
args["guild_id"] = str(message_source.sender.id)
|
||||
args["msg_id"] = cached_message_ids[
|
||||
str(message_source.message_chain.message_id)
|
||||
]
|
||||
await self.bot.api.post_dms(**args)
|
||||
elif type(message_source) == OfficialGroupMessage:
|
||||
if "image" in args or "file_image" in args:
|
||||
continue
|
||||
args["group_openid"] = self.group_openid_mapping.getkey(
|
||||
message_source.sender.group.id
|
||||
)
|
||||
|
||||
args["msg_id"] = cached_message_ids[
|
||||
str(message_source.message_chain.message_id)
|
||||
]
|
||||
args["msg_seq"] = msg_seq
|
||||
msg_seq += 1
|
||||
await self.bot.api.post_group_message(**args)
|
||||
|
||||
async def is_muted(self, group_id: int) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
def register_listener(
|
||||
self,
|
||||
event_type: typing.Type[mirai.Event],
|
||||
callback: typing.Callable[[mirai.Event, adapter_model.MessageSourceAdapter], None]
|
||||
callback: typing.Callable[
|
||||
[mirai.Event, adapter_model.MessageSourceAdapter], None
|
||||
],
|
||||
):
|
||||
|
||||
|
||||
try:
|
||||
|
||||
async def wrapper(message: typing.Union[botpy_message.Message, botpy_message.DirectMessage, botpy_message.GroupMessage]):
|
||||
async def wrapper(
|
||||
message: typing.Union[
|
||||
botpy_message.Message,
|
||||
botpy_message.DirectMessage,
|
||||
botpy_message.GroupMessage,
|
||||
]
|
||||
):
|
||||
self.cached_official_messages[str(message.id)] = message
|
||||
await callback(OfficialEventConverter.target2yiri(message), self)
|
||||
await callback(self.event_converter.target2yiri(message), self)
|
||||
|
||||
for event_handler in event_handler_mapping[event_type]:
|
||||
setattr(self.bot, event_handler, wrapper)
|
||||
@@ -380,15 +498,33 @@ class OfficialAdapter(adapter_model.MessageSourceAdapter):
|
||||
def unregister_listener(
|
||||
self,
|
||||
event_type: typing.Type[mirai.Event],
|
||||
callback: typing.Callable[[mirai.Event, adapter_model.MessageSourceAdapter], None]
|
||||
callback: typing.Callable[
|
||||
[mirai.Event, adapter_model.MessageSourceAdapter], None
|
||||
],
|
||||
):
|
||||
delattr(self.bot, event_handler_mapping[event_type])
|
||||
|
||||
async def run_async(self):
|
||||
self.ap.logger.info("运行 QQ 官方适配器")
|
||||
await self.bot.start(
|
||||
**self.cfg
|
||||
|
||||
self.metadata = self.ap.adapter_qq_botpy_meta
|
||||
|
||||
self.member_openid_mapping = OpenIDMapping(
|
||||
map=self.metadata.data["mapping"]["members"],
|
||||
dump_func=self.metadata.dump_config_sync,
|
||||
)
|
||||
|
||||
self.group_openid_mapping = OpenIDMapping(
|
||||
map=self.metadata.data["mapping"]["groups"],
|
||||
dump_func=self.metadata.dump_config_sync,
|
||||
)
|
||||
|
||||
self.message_converter = OfficialMessageConverter()
|
||||
self.event_converter = OfficialEventConverter(
|
||||
self.member_openid_mapping, self.group_openid_mapping
|
||||
)
|
||||
|
||||
self.ap.logger.info("运行 QQ 官方适配器")
|
||||
await self.bot.start(**self.cfg)
|
||||
|
||||
def kill(self) -> bool:
|
||||
return False
|
||||
|
||||
@@ -9,10 +9,86 @@ from ..provider.tools import entities as tools_entities
|
||||
from ..core import app
|
||||
|
||||
|
||||
def register(
|
||||
name: str,
|
||||
description: str,
|
||||
version: str,
|
||||
author: str
|
||||
) -> typing.Callable[[typing.Type[BasePlugin]], typing.Type[BasePlugin]]:
|
||||
"""注册插件类
|
||||
|
||||
使用示例:
|
||||
|
||||
@register(
|
||||
name="插件名称",
|
||||
description="插件描述",
|
||||
version="插件版本",
|
||||
author="插件作者"
|
||||
)
|
||||
class MyPlugin(BasePlugin):
|
||||
pass
|
||||
"""
|
||||
pass
|
||||
|
||||
def handler(
|
||||
event: typing.Type[events.BaseEventModel]
|
||||
) -> typing.Callable[[typing.Callable], typing.Callable]:
|
||||
"""注册事件监听器
|
||||
|
||||
使用示例:
|
||||
|
||||
class MyPlugin(BasePlugin):
|
||||
|
||||
@handler(NormalMessageResponded)
|
||||
async def on_normal_message_responded(self, ctx: EventContext):
|
||||
pass
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
def llm_func(
|
||||
name: str=None,
|
||||
) -> typing.Callable:
|
||||
"""注册内容函数
|
||||
|
||||
使用示例:
|
||||
|
||||
class MyPlugin(BasePlugin):
|
||||
|
||||
@llm_func("access_the_web_page")
|
||||
async def _(self, query, url: str, brief_len: int):
|
||||
\"""Call this function to search about the question before you answer any questions.
|
||||
- Do not search through google.com at any time.
|
||||
- If you need to search somthing, visit https://www.sogou.com/web?query=<something>.
|
||||
- If user ask you to open a url (start with http:// or https://), visit it directly.
|
||||
- Summary the plain content result by yourself, DO NOT directly output anything in the result you got.
|
||||
|
||||
Args:
|
||||
url(str): url to visit
|
||||
brief_len(int): max length of the plain text content, recommend 1024-4096, prefer 4096
|
||||
|
||||
Returns:
|
||||
str: plain text content of the web page or error message(starts with 'error:')
|
||||
\"""
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class BasePlugin(metaclass=abc.ABCMeta):
|
||||
"""插件基类"""
|
||||
|
||||
host: APIHost
|
||||
"""API宿主"""
|
||||
|
||||
ap: app.Application
|
||||
"""应用程序对象"""
|
||||
|
||||
def __init__(self, host: APIHost):
|
||||
self.host = host
|
||||
|
||||
async def initialize(self):
|
||||
"""初始化插件"""
|
||||
pass
|
||||
|
||||
|
||||
class APIHost:
|
||||
@@ -61,8 +137,10 @@ class EventContext:
|
||||
"""事件编号"""
|
||||
|
||||
host: APIHost = None
|
||||
"""API宿主"""
|
||||
|
||||
event: events.BaseEventModel = None
|
||||
"""此次事件的对象,具体类型为handler注册时指定监听的类型,可查看events.py中的定义"""
|
||||
|
||||
__prevent_default__ = False
|
||||
"""是否阻止默认行为"""
|
||||
|
||||
@@ -10,8 +10,10 @@ from ..provider import entities as llm_entities
|
||||
|
||||
|
||||
class BaseEventModel(pydantic.BaseModel):
|
||||
"""事件模型基类"""
|
||||
|
||||
query: typing.Union[core_entities.Query, None]
|
||||
"""此次请求的query对象,非请求过程的事件时为None"""
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
@@ -1,3 +1,7 @@
|
||||
# 此模块已过时
|
||||
# 请从 pkg.plugin.context 引入 BasePlugin, EventContext 和 APIHost
|
||||
# 最早将于 v3.4 移除此模块
|
||||
|
||||
from . events import *
|
||||
from . context import EventContext, APIHost as PluginHost
|
||||
|
||||
|
||||
@@ -7,6 +7,7 @@ from ..core import app
|
||||
|
||||
|
||||
class PluginInstaller(metaclass=abc.ABCMeta):
|
||||
"""插件安装器抽象类"""
|
||||
|
||||
ap: app.Application
|
||||
|
||||
|
||||
@@ -12,6 +12,8 @@ from ...utils import pkgmgr
|
||||
|
||||
|
||||
class GitHubRepoInstaller(installer.PluginInstaller):
|
||||
"""GitHub仓库插件安装器
|
||||
"""
|
||||
|
||||
def get_github_plugin_repo_label(self, repo_url: str) -> list[str]:
|
||||
"""获取username, repo"""
|
||||
|
||||
@@ -9,7 +9,7 @@ from . import context, events
|
||||
|
||||
|
||||
class PluginLoader(metaclass=abc.ABCMeta):
|
||||
"""插件加载器"""
|
||||
"""插件加载器抽象类"""
|
||||
|
||||
ap: app.Application
|
||||
|
||||
|
||||
@@ -5,11 +5,10 @@ import pkgutil
|
||||
import importlib
|
||||
import traceback
|
||||
|
||||
from CallingGPT.entities.namespace import get_func_schema
|
||||
|
||||
from .. import loader, events, context, models, host
|
||||
from ...core import entities as core_entities
|
||||
from ...provider.tools import entities as tools_entities
|
||||
from ...utils import funcschema
|
||||
|
||||
|
||||
class PluginLoader(loader.PluginLoader):
|
||||
@@ -29,6 +28,10 @@ class PluginLoader(loader.PluginLoader):
|
||||
setattr(models, 'on', self.on)
|
||||
setattr(models, 'func', self.func)
|
||||
|
||||
setattr(context, 'register', self.register)
|
||||
setattr(context, 'handler', self.handler)
|
||||
setattr(context, 'llm_func', self.llm_func)
|
||||
|
||||
def register(
|
||||
self,
|
||||
name: str,
|
||||
@@ -57,6 +60,8 @@ class PluginLoader(loader.PluginLoader):
|
||||
|
||||
return wrapper
|
||||
|
||||
# 过时
|
||||
# 最早将于 v3.4 版本移除
|
||||
def on(
|
||||
self,
|
||||
event: typing.Type[events.BaseEventModel]
|
||||
@@ -83,6 +88,8 @@ class PluginLoader(loader.PluginLoader):
|
||||
|
||||
return wrapper
|
||||
|
||||
# 过时
|
||||
# 最早将于 v3.4 版本移除
|
||||
def func(
|
||||
self,
|
||||
name: str=None,
|
||||
@@ -91,10 +98,11 @@ class PluginLoader(loader.PluginLoader):
|
||||
self.ap.logger.debug(f'注册内容函数 {name}')
|
||||
def wrapper(func: typing.Callable) -> typing.Callable:
|
||||
|
||||
function_schema = get_func_schema(func)
|
||||
function_schema = funcschema.get_func_schema(func)
|
||||
function_name = self._current_container.plugin_name + '-' + (func.__name__ if name is None else name)
|
||||
|
||||
async def handler(
|
||||
plugin: context.BasePlugin,
|
||||
query: core_entities.Query,
|
||||
*args,
|
||||
**kwargs
|
||||
@@ -116,6 +124,46 @@ class PluginLoader(loader.PluginLoader):
|
||||
|
||||
return wrapper
|
||||
|
||||
def handler(
|
||||
self,
|
||||
event: typing.Type[events.BaseEventModel]
|
||||
) -> typing.Callable[[typing.Callable], typing.Callable]:
|
||||
"""注册事件处理器"""
|
||||
self.ap.logger.debug(f'注册事件处理器 {event.__name__}')
|
||||
def wrapper(func: typing.Callable) -> typing.Callable:
|
||||
|
||||
self._current_container.event_handlers[event] = func
|
||||
|
||||
return func
|
||||
|
||||
return wrapper
|
||||
|
||||
def llm_func(
|
||||
self,
|
||||
name: str=None,
|
||||
) -> typing.Callable:
|
||||
"""注册内容函数"""
|
||||
self.ap.logger.debug(f'注册内容函数 {name}')
|
||||
def wrapper(func: typing.Callable) -> typing.Callable:
|
||||
|
||||
function_schema = funcschema.get_func_schema(func)
|
||||
function_name = self._current_container.plugin_name + '-' + (func.__name__ if name is None else name)
|
||||
|
||||
llm_function = tools_entities.LLMFunction(
|
||||
name=function_name,
|
||||
human_desc='',
|
||||
description=function_schema['description'],
|
||||
enable=True,
|
||||
parameters=function_schema['parameters'],
|
||||
func=func,
|
||||
)
|
||||
|
||||
self._current_container.content_functions.append(llm_function)
|
||||
|
||||
return func
|
||||
|
||||
return wrapper
|
||||
|
||||
async def _walk_plugin_path(
|
||||
self,
|
||||
module,
|
||||
@@ -5,11 +5,12 @@ import traceback
|
||||
|
||||
from ..core import app
|
||||
from . import context, loader, events, installer, setting, models
|
||||
from .loaders import legacy
|
||||
from .loaders import classic
|
||||
from .installers import github
|
||||
|
||||
|
||||
class PluginManager:
|
||||
"""插件管理器"""
|
||||
|
||||
ap: app.Application
|
||||
|
||||
@@ -25,7 +26,7 @@ class PluginManager:
|
||||
|
||||
def __init__(self, ap: app.Application):
|
||||
self.ap = ap
|
||||
self.loader = legacy.PluginLoader(ap)
|
||||
self.loader = classic.PluginLoader(ap)
|
||||
self.installer = github.GitHubRepoInstaller(ap)
|
||||
self.setting = setting.SettingManager(ap)
|
||||
self.api_host = context.APIHost(ap)
|
||||
@@ -51,6 +52,9 @@ class PluginManager:
|
||||
for plugin in self.plugins:
|
||||
try:
|
||||
plugin.plugin_inst = plugin.plugin_class(self.api_host)
|
||||
plugin.plugin_inst.ap = self.ap
|
||||
plugin.plugin_inst.host = self.api_host
|
||||
await plugin.plugin_inst.initialize()
|
||||
except Exception as e:
|
||||
self.ap.logger.error(f'插件 {plugin.plugin_name} 初始化失败: {e}')
|
||||
self.ap.logger.exception(e)
|
||||
@@ -136,9 +140,8 @@ class PluginManager:
|
||||
for plugin in self.plugins:
|
||||
if plugin.enabled:
|
||||
if event.__class__ in plugin.event_handlers:
|
||||
self.ap.logger.debug(f'插件 {plugin.plugin_name} 触发事件 {event.__class__.__name__}')
|
||||
|
||||
emitted_plugins.append(plugin)
|
||||
|
||||
is_prevented_default_before_call = ctx.is_prevented_default()
|
||||
|
||||
try:
|
||||
@@ -150,6 +153,8 @@ class PluginManager:
|
||||
self.ap.logger.error(f'插件 {plugin.plugin_name} 触发事件 {event.__class__.__name__} 时发生错误: {e}')
|
||||
self.ap.logger.debug(f"Traceback: {traceback.format_exc()}")
|
||||
|
||||
emitted_plugins.append(plugin)
|
||||
|
||||
if not is_prevented_default_before_call and ctx.is_prevented_default():
|
||||
self.ap.logger.debug(f'插件 {plugin.plugin_name} 阻止了默认行为执行')
|
||||
|
||||
|
||||
@@ -1,3 +1,7 @@
|
||||
# 此模块已过时,请引入 pkg.plugin.context 中的 register, handler 和 llm_func 来注册插件、事件处理函数和内容函数
|
||||
# 各个事件模型请从 pkg.plugin.events 引入
|
||||
# 最早将于 v3.4 移除此模块
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import typing
|
||||
|
||||
@@ -6,6 +6,7 @@ from . import context
|
||||
|
||||
|
||||
class SettingManager:
|
||||
"""插件设置管理器"""
|
||||
|
||||
ap: app.Application
|
||||
|
||||
@@ -15,10 +16,7 @@ class SettingManager:
|
||||
self.ap = ap
|
||||
|
||||
async def initialize(self):
|
||||
self.settings = await cfg_mgr.load_json_config(
|
||||
'plugins/plugins.json',
|
||||
'templates/plugin-settings.json'
|
||||
)
|
||||
self.settings = self.ap.plugin_setting_meta
|
||||
|
||||
async def sync_setting(
|
||||
self,
|
||||
|
||||
@@ -20,14 +20,31 @@ class ToolCall(pydantic.BaseModel):
|
||||
|
||||
|
||||
class Message(pydantic.BaseModel):
|
||||
role: str # user, system, assistant, tool, command
|
||||
"""消息"""
|
||||
|
||||
role: str # user, system, assistant, tool, command, plugin
|
||||
"""消息的角色"""
|
||||
|
||||
name: typing.Optional[str] = None
|
||||
"""名称,仅函数调用返回时设置"""
|
||||
|
||||
content: typing.Optional[str] = None
|
||||
"""内容"""
|
||||
|
||||
function_call: typing.Optional[FunctionCall] = None
|
||||
"""函数调用,不再受支持,请使用tool_calls"""
|
||||
|
||||
tool_calls: typing.Optional[list[ToolCall]] = None
|
||||
"""工具调用"""
|
||||
|
||||
tool_call_id: typing.Optional[str] = None
|
||||
|
||||
def readable_str(self) -> str:
|
||||
if self.content is not None:
|
||||
return self.content
|
||||
elif self.function_call is not None:
|
||||
return f'{self.function_call.name}({self.function_call.arguments})'
|
||||
elif self.tool_calls is not None:
|
||||
return f'调用工具: {self.tool_calls[0].id}'
|
||||
else:
|
||||
return '未知消息'
|
||||
|
||||
52
pkg/provider/modelmgr/api.py
Normal file
52
pkg/provider/modelmgr/api.py
Normal file
@@ -0,0 +1,52 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import abc
|
||||
import typing
|
||||
|
||||
from ...core import app
|
||||
from ...core import entities as core_entities
|
||||
from .. import entities as llm_entities
|
||||
|
||||
|
||||
preregistered_requesters: list[typing.Type[LLMAPIRequester]] = []
|
||||
|
||||
def requester_class(name: str):
|
||||
|
||||
def decorator(cls: typing.Type[LLMAPIRequester]) -> typing.Type[LLMAPIRequester]:
|
||||
cls.name = name
|
||||
preregistered_requesters.append(cls)
|
||||
return cls
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
class LLMAPIRequester(metaclass=abc.ABCMeta):
|
||||
"""LLM API请求器
|
||||
"""
|
||||
name: str = None
|
||||
|
||||
ap: app.Application
|
||||
|
||||
def __init__(self, ap: app.Application):
|
||||
self.ap = ap
|
||||
|
||||
async def initialize(self):
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
async def request(
|
||||
self,
|
||||
query: core_entities.Query,
|
||||
) -> typing.AsyncGenerator[llm_entities.Message, None]:
|
||||
"""请求API
|
||||
|
||||
对话前文可以从 query 对象中获取。
|
||||
可以多次yield消息对象。
|
||||
|
||||
Args:
|
||||
query (core_entities.Query): 本次请求的上下文对象
|
||||
|
||||
Yields:
|
||||
pkg.provider.entities.Message: 返回消息对象
|
||||
"""
|
||||
raise NotImplementedError
|
||||
0
pkg/provider/modelmgr/apis/__init__.py
Normal file
0
pkg/provider/modelmgr/apis/__init__.py
Normal file
82
pkg/provider/modelmgr/apis/anthropicmsgs.py
Normal file
82
pkg/provider/modelmgr/apis/anthropicmsgs.py
Normal file
@@ -0,0 +1,82 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import typing
|
||||
import traceback
|
||||
|
||||
import anthropic
|
||||
|
||||
from .. import api, entities, errors
|
||||
|
||||
from .. import api, entities, errors
|
||||
from ....core import entities as core_entities
|
||||
from ... import entities as llm_entities
|
||||
from ...tools import entities as tools_entities
|
||||
|
||||
|
||||
@api.requester_class("anthropic-messages")
|
||||
class AnthropicMessages(api.LLMAPIRequester):
|
||||
"""Anthropic Messages API 请求器"""
|
||||
|
||||
client: anthropic.AsyncAnthropic
|
||||
|
||||
async def initialize(self):
|
||||
self.client = anthropic.AsyncAnthropic(
|
||||
api_key="",
|
||||
base_url=self.ap.provider_cfg.data['requester']['anthropic-messages']['base-url'],
|
||||
timeout=self.ap.provider_cfg.data['requester']['anthropic-messages']['timeout'],
|
||||
proxies=self.ap.proxy_mgr.get_forward_proxies()
|
||||
)
|
||||
|
||||
async def request(
|
||||
self,
|
||||
query: core_entities.Query,
|
||||
) -> typing.AsyncGenerator[llm_entities.Message, None]:
|
||||
self.client.api_key = query.use_model.token_mgr.get_token()
|
||||
|
||||
args = self.ap.provider_cfg.data['requester']['anthropic-messages']['args'].copy()
|
||||
args["model"] = query.use_model.name if query.use_model.model_name is None else query.use_model.model_name
|
||||
|
||||
req_messages = [ # req_messages 仅用于类内,外部同步由 query.messages 进行
|
||||
m.dict(exclude_none=True) for m in query.prompt.messages
|
||||
] + [m.dict(exclude_none=True) for m in query.messages]
|
||||
|
||||
# 删除所有 role=system & content='' 的消息
|
||||
req_messages = [
|
||||
m for m in req_messages if not (m["role"] == "system" and m["content"].strip() == "")
|
||||
]
|
||||
|
||||
# 检查是否有 role=system 的消息,若有,改为 role=user,并在后面加一个 role=assistant 的消息
|
||||
system_role_index = []
|
||||
for i, m in enumerate(req_messages):
|
||||
if m["role"] == "system":
|
||||
system_role_index.append(i)
|
||||
m["role"] = "user"
|
||||
|
||||
if system_role_index:
|
||||
for i in system_role_index[::-1]:
|
||||
req_messages.insert(i + 1, {"role": "assistant", "content": "Okay, I'll follow."})
|
||||
|
||||
# 忽略掉空消息,用户可能发送空消息,而上层未过滤
|
||||
req_messages = [
|
||||
m for m in req_messages if m["content"].strip() != ""
|
||||
]
|
||||
|
||||
args["messages"] = req_messages
|
||||
|
||||
try:
|
||||
|
||||
resp = await self.client.messages.create(**args)
|
||||
|
||||
yield llm_entities.Message(
|
||||
content=resp.content[0].text,
|
||||
role=resp.role
|
||||
)
|
||||
except anthropic.AuthenticationError as e:
|
||||
raise errors.RequesterError(f'api-key 无效: {e.message}')
|
||||
except anthropic.BadRequestError as e:
|
||||
raise errors.RequesterError(str(e.message))
|
||||
except anthropic.NotFoundError as e:
|
||||
if 'model: ' in str(e):
|
||||
raise errors.RequesterError(f'模型无效: {e.message}')
|
||||
else:
|
||||
raise errors.RequesterError(f'请求地址无效: {e.message}')
|
||||
@@ -9,22 +9,31 @@ import openai
|
||||
import openai.types.chat.chat_completion as chat_completion
|
||||
import httpx
|
||||
|
||||
from pkg.provider.entities import Message
|
||||
|
||||
from .. import api, entities, errors
|
||||
from ....core import entities as core_entities
|
||||
from ....core import entities as core_entities, app
|
||||
from ... import entities as llm_entities
|
||||
from ...tools import entities as tools_entities
|
||||
|
||||
|
||||
class OpenAIChatCompletion(api.LLMAPIRequester):
|
||||
@api.requester_class("openai-chat-completions")
|
||||
class OpenAIChatCompletions(api.LLMAPIRequester):
|
||||
"""OpenAI ChatCompletion API 请求器"""
|
||||
|
||||
client: openai.AsyncClient
|
||||
|
||||
requester_cfg: dict
|
||||
|
||||
def __init__(self, ap: app.Application):
|
||||
self.ap = ap
|
||||
|
||||
self.requester_cfg = self.ap.provider_cfg.data['requester']['openai-chat-completions']
|
||||
|
||||
async def initialize(self):
|
||||
|
||||
self.client = openai.AsyncClient(
|
||||
api_key="",
|
||||
base_url=self.ap.provider_cfg.data['openai-config']['base_url'],
|
||||
timeout=self.ap.provider_cfg.data['openai-config']['request-timeout'],
|
||||
base_url=self.requester_cfg['base-url'],
|
||||
timeout=self.requester_cfg['timeout'],
|
||||
http_client=httpx.AsyncClient(
|
||||
proxies=self.ap.proxy_mgr.get_forward_proxies()
|
||||
)
|
||||
@@ -55,7 +64,7 @@ class OpenAIChatCompletion(api.LLMAPIRequester):
|
||||
) -> llm_entities.Message:
|
||||
self.client.api_key = use_model.token_mgr.get_token()
|
||||
|
||||
args = self.ap.provider_cfg.data['openai-config']['chat-completions-params'].copy()
|
||||
args = self.requester_cfg['args'].copy()
|
||||
args["model"] = use_model.name if use_model.model_name is None else use_model.model_name
|
||||
|
||||
if use_model.tool_call_supported:
|
||||
@@ -124,14 +133,17 @@ class OpenAIChatCompletion(api.LLMAPIRequester):
|
||||
|
||||
req_messages.append(msg.dict(exclude_none=True))
|
||||
|
||||
async def request(self, query: core_entities.Query) -> AsyncGenerator[Message, None]:
|
||||
async def request(self, query: core_entities.Query) -> AsyncGenerator[llm_entities.Message, None]:
|
||||
try:
|
||||
async for msg in self._request(query):
|
||||
yield msg
|
||||
except asyncio.TimeoutError:
|
||||
raise errors.RequesterError('请求超时')
|
||||
except openai.BadRequestError as e:
|
||||
raise errors.RequesterError(f'请求错误: {e.message}')
|
||||
if 'context_length_exceeded' in e.message:
|
||||
raise errors.RequesterError(f'上文过长,请重置会话: {e.message}')
|
||||
else:
|
||||
raise errors.RequesterError(f'请求参数错误: {e.message}')
|
||||
except openai.AuthenticationError as e:
|
||||
raise errors.RequesterError(f'无效的 api-key: {e.message}')
|
||||
except openai.NotFoundError as e:
|
||||
15
pkg/provider/modelmgr/apis/moonshotchatcmpl.py
Normal file
15
pkg/provider/modelmgr/apis/moonshotchatcmpl.py
Normal file
@@ -0,0 +1,15 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from ....core import app
|
||||
|
||||
from . import chatcmpl
|
||||
from .. import api
|
||||
|
||||
|
||||
@api.requester_class("moonshot-chat-completions")
|
||||
class MoonshotChatCompletions(chatcmpl.OpenAIChatCompletions):
|
||||
"""Moonshot ChatCompletion API 请求器"""
|
||||
|
||||
def __init__(self, ap: app.Application):
|
||||
self.requester_cfg = ap.provider_cfg.data['requester']['moonshot-chat-completions']
|
||||
self.ap = ap
|
||||
@@ -5,7 +5,7 @@ import typing
|
||||
import pydantic
|
||||
|
||||
from . import api
|
||||
from . import token, tokenizer
|
||||
from . import token
|
||||
|
||||
|
||||
class LLMModelInfo(pydantic.BaseModel):
|
||||
@@ -19,11 +19,7 @@ class LLMModelInfo(pydantic.BaseModel):
|
||||
|
||||
requester: api.LLMAPIRequester
|
||||
|
||||
tokenizer: 'tokenizer.LLMTokenizer'
|
||||
|
||||
tool_call_supported: typing.Optional[bool] = False
|
||||
|
||||
max_tokens: typing.Optional[int] = 2048
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
@@ -2,4 +2,4 @@ class RequesterError(Exception):
|
||||
"""Base class for all Requester errors."""
|
||||
|
||||
def __init__(self, message: str):
|
||||
super().__init__("模型请求失败: "+message)
|
||||
super().__init__("模型请求失败: "+message)
|
||||
109
pkg/provider/modelmgr/modelmgr.py
Normal file
109
pkg/provider/modelmgr/modelmgr.py
Normal file
@@ -0,0 +1,109 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import aiohttp
|
||||
|
||||
from . import entities
|
||||
from ...core import app
|
||||
|
||||
from . import token, api
|
||||
from .apis import chatcmpl, anthropicmsgs, moonshotchatcmpl
|
||||
|
||||
FETCH_MODEL_LIST_URL = "https://api.qchatgpt.rockchin.top/api/v2/fetch/model_list"
|
||||
|
||||
|
||||
class ModelManager:
|
||||
"""模型管理器"""
|
||||
|
||||
ap: app.Application
|
||||
|
||||
model_list: list[entities.LLMModelInfo]
|
||||
|
||||
requesters: dict[str, api.LLMAPIRequester]
|
||||
|
||||
token_mgrs: dict[str, token.TokenManager]
|
||||
|
||||
def __init__(self, ap: app.Application):
|
||||
self.ap = ap
|
||||
self.model_list = []
|
||||
self.requesters = {}
|
||||
self.token_mgrs = {}
|
||||
|
||||
async def get_model_by_name(self, name: str) -> entities.LLMModelInfo:
|
||||
"""通过名称获取模型
|
||||
"""
|
||||
for model in self.model_list:
|
||||
if model.name == name:
|
||||
return model
|
||||
raise ValueError(f"无法确定模型 {name} 的信息,请在元数据中配置")
|
||||
|
||||
async def initialize(self):
|
||||
|
||||
# 初始化token_mgr, requester
|
||||
for k, v in self.ap.provider_cfg.data['keys'].items():
|
||||
self.token_mgrs[k] = token.TokenManager(k, v)
|
||||
|
||||
for api_cls in api.preregistered_requesters:
|
||||
api_inst = api_cls(self.ap)
|
||||
await api_inst.initialize()
|
||||
self.requesters[api_inst.name] = api_inst
|
||||
|
||||
# 尝试从api获取最新的模型信息
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.request(
|
||||
method="GET",
|
||||
url=FETCH_MODEL_LIST_URL,
|
||||
# 参数
|
||||
params={
|
||||
"version": self.ap.ver_mgr.get_current_version()
|
||||
},
|
||||
) as resp:
|
||||
model_list = (await resp.json())['data']['list']
|
||||
|
||||
for model in model_list:
|
||||
|
||||
for index, local_model in enumerate(self.ap.llm_models_meta.data['list']):
|
||||
if model['name'] == local_model['name']:
|
||||
self.ap.llm_models_meta.data['list'][index] = model
|
||||
break
|
||||
else:
|
||||
self.ap.llm_models_meta.data['list'].append(model)
|
||||
|
||||
await self.ap.llm_models_meta.dump_config()
|
||||
|
||||
except Exception as e:
|
||||
self.ap.logger.debug(f'获取最新模型列表失败: {e}')
|
||||
|
||||
default_model_info: entities.LLMModelInfo = None
|
||||
|
||||
for model in self.ap.llm_models_meta.data['list']:
|
||||
if model['name'] == 'default':
|
||||
default_model_info = entities.LLMModelInfo(
|
||||
name=model['name'],
|
||||
model_name=None,
|
||||
token_mgr=self.token_mgrs[model['token_mgr']],
|
||||
requester=self.requesters[model['requester']],
|
||||
tool_call_supported=model['tool_call_supported']
|
||||
)
|
||||
break
|
||||
|
||||
for model in self.ap.llm_models_meta.data['list']:
|
||||
|
||||
try:
|
||||
|
||||
model_name = model.get('model_name', default_model_info.model_name)
|
||||
token_mgr = self.token_mgrs[model['token_mgr']] if 'token_mgr' in model else default_model_info.token_mgr
|
||||
requester = self.requesters[model['requester']] if 'requester' in model else default_model_info.requester
|
||||
tool_call_supported = model.get('tool_call_supported', default_model_info.tool_call_supported)
|
||||
|
||||
model_info = entities.LLMModelInfo(
|
||||
name=model['name'],
|
||||
model_name=model_name,
|
||||
token_mgr=token_mgr,
|
||||
requester=requester,
|
||||
tool_call_supported=tool_call_supported
|
||||
)
|
||||
self.model_list.append(model_info)
|
||||
|
||||
except Exception as e:
|
||||
self.ap.logger.error(f"初始化模型 {model['name']} 失败: {e} ,请检查配置文件")
|
||||
@@ -6,6 +6,8 @@ import pydantic
|
||||
|
||||
|
||||
class TokenManager():
|
||||
"""鉴权 Token 管理器
|
||||
"""
|
||||
|
||||
provider: str
|
||||
|
||||
@@ -1,29 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import abc
|
||||
import typing
|
||||
|
||||
from ...core import app
|
||||
from ...core import entities as core_entities
|
||||
from .. import entities as llm_entities
|
||||
|
||||
class LLMAPIRequester(metaclass=abc.ABCMeta):
|
||||
"""LLM API请求器
|
||||
"""
|
||||
|
||||
ap: app.Application
|
||||
|
||||
def __init__(self, ap: app.Application):
|
||||
self.ap = ap
|
||||
|
||||
async def initialize(self):
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
async def request(
|
||||
self,
|
||||
query: core_entities.Query,
|
||||
) -> typing.AsyncGenerator[llm_entities.Message, None]:
|
||||
"""请求
|
||||
"""
|
||||
raise NotImplementedError
|
||||
@@ -1,242 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from . import entities
|
||||
from ...core import app
|
||||
|
||||
from .apis import chatcmpl
|
||||
from . import token
|
||||
from .tokenizers import tiktoken
|
||||
|
||||
|
||||
class ModelManager:
|
||||
|
||||
ap: app.Application
|
||||
|
||||
model_list: list[entities.LLMModelInfo]
|
||||
|
||||
def __init__(self, ap: app.Application):
|
||||
self.ap = ap
|
||||
self.model_list = []
|
||||
|
||||
async def get_model_by_name(self, name: str) -> entities.LLMModelInfo:
|
||||
"""通过名称获取模型
|
||||
"""
|
||||
for model in self.model_list:
|
||||
if model.name == name:
|
||||
return model
|
||||
raise ValueError(f"不支持模型: {name} , 请检查配置文件")
|
||||
|
||||
async def initialize(self):
|
||||
openai_chat_completion = chatcmpl.OpenAIChatCompletion(self.ap)
|
||||
await openai_chat_completion.initialize()
|
||||
openai_token_mgr = token.TokenManager(self.ap, list(self.ap.provider_cfg.data['openai-config']['api-keys']))
|
||||
|
||||
tiktoken_tokenizer = tiktoken.Tiktoken(self.ap)
|
||||
|
||||
model_list = [
|
||||
entities.LLMModelInfo(
|
||||
name="gpt-3.5-turbo",
|
||||
token_mgr=openai_token_mgr,
|
||||
requester=openai_chat_completion,
|
||||
tool_call_supported=True,
|
||||
tokenizer=tiktoken_tokenizer,
|
||||
max_tokens=4096
|
||||
),
|
||||
entities.LLMModelInfo(
|
||||
name="gpt-3.5-turbo-1106",
|
||||
token_mgr=openai_token_mgr,
|
||||
requester=openai_chat_completion,
|
||||
tool_call_supported=True,
|
||||
tokenizer=tiktoken_tokenizer,
|
||||
max_tokens=16385
|
||||
),
|
||||
entities.LLMModelInfo(
|
||||
name="gpt-3.5-turbo-16k",
|
||||
token_mgr=openai_token_mgr,
|
||||
requester=openai_chat_completion,
|
||||
tool_call_supported=True,
|
||||
tokenizer=tiktoken_tokenizer,
|
||||
max_tokens=16385
|
||||
),
|
||||
entities.LLMModelInfo(
|
||||
name="gpt-3.5-turbo-0613",
|
||||
token_mgr=openai_token_mgr,
|
||||
requester=openai_chat_completion,
|
||||
tool_call_supported=True,
|
||||
tokenizer=tiktoken_tokenizer,
|
||||
max_tokens=4096
|
||||
),
|
||||
entities.LLMModelInfo(
|
||||
name="gpt-3.5-turbo-16k-0613",
|
||||
token_mgr=openai_token_mgr,
|
||||
requester=openai_chat_completion,
|
||||
tool_call_supported=True,
|
||||
tokenizer=tiktoken_tokenizer,
|
||||
max_tokens=16385
|
||||
),
|
||||
entities.LLMModelInfo(
|
||||
name="gpt-3.5-turbo-0301",
|
||||
token_mgr=openai_token_mgr,
|
||||
requester=openai_chat_completion,
|
||||
tool_call_supported=True,
|
||||
tokenizer=tiktoken_tokenizer,
|
||||
max_tokens=4096
|
||||
)
|
||||
]
|
||||
|
||||
self.model_list.extend(model_list)
|
||||
|
||||
gpt4_model_list = [
|
||||
entities.LLMModelInfo(
|
||||
name="gpt-4-0125-preview",
|
||||
token_mgr=openai_token_mgr,
|
||||
requester=openai_chat_completion,
|
||||
tool_call_supported=True,
|
||||
tokenizer=tiktoken_tokenizer,
|
||||
max_tokens=128000
|
||||
),
|
||||
entities.LLMModelInfo(
|
||||
name="gpt-4-turbo-preview",
|
||||
token_mgr=openai_token_mgr,
|
||||
requester=openai_chat_completion,
|
||||
tool_call_supported=True,
|
||||
tokenizer=tiktoken_tokenizer,
|
||||
max_tokens=128000
|
||||
),
|
||||
entities.LLMModelInfo(
|
||||
name="gpt-4-1106-preview",
|
||||
token_mgr=openai_token_mgr,
|
||||
requester=openai_chat_completion,
|
||||
tool_call_supported=True,
|
||||
tokenizer=tiktoken_tokenizer,
|
||||
max_tokens=128000
|
||||
),
|
||||
entities.LLMModelInfo(
|
||||
name="gpt-4-vision-preview",
|
||||
token_mgr=openai_token_mgr,
|
||||
requester=openai_chat_completion,
|
||||
tool_call_supported=True,
|
||||
tokenizer=tiktoken_tokenizer,
|
||||
max_tokens=128000
|
||||
),
|
||||
entities.LLMModelInfo(
|
||||
name="gpt-4",
|
||||
token_mgr=openai_token_mgr,
|
||||
requester=openai_chat_completion,
|
||||
tool_call_supported=True,
|
||||
tokenizer=tiktoken_tokenizer,
|
||||
max_tokens=8192
|
||||
),
|
||||
entities.LLMModelInfo(
|
||||
name="gpt-4-0613",
|
||||
token_mgr=openai_token_mgr,
|
||||
requester=openai_chat_completion,
|
||||
tool_call_supported=True,
|
||||
tokenizer=tiktoken_tokenizer,
|
||||
max_tokens=8192
|
||||
),
|
||||
entities.LLMModelInfo(
|
||||
name="gpt-4-32k",
|
||||
token_mgr=openai_token_mgr,
|
||||
requester=openai_chat_completion,
|
||||
tool_call_supported=True,
|
||||
tokenizer=tiktoken_tokenizer,
|
||||
max_tokens=32768
|
||||
),
|
||||
entities.LLMModelInfo(
|
||||
name="gpt-4-32k-0613",
|
||||
token_mgr=openai_token_mgr,
|
||||
requester=openai_chat_completion,
|
||||
tool_call_supported=True,
|
||||
tokenizer=tiktoken_tokenizer,
|
||||
max_tokens=32768
|
||||
)
|
||||
]
|
||||
|
||||
self.model_list.extend(gpt4_model_list)
|
||||
|
||||
one_api_model_list = [
|
||||
entities.LLMModelInfo(
|
||||
name="OneAPI/SparkDesk",
|
||||
model_name='SparkDesk',
|
||||
token_mgr=openai_token_mgr,
|
||||
requester=openai_chat_completion,
|
||||
tool_call_supported=False,
|
||||
tokenizer=tiktoken_tokenizer,
|
||||
max_tokens=8192
|
||||
),
|
||||
entities.LLMModelInfo(
|
||||
name="OneAPI/chatglm_pro",
|
||||
model_name='chatglm_pro',
|
||||
token_mgr=openai_token_mgr,
|
||||
requester=openai_chat_completion,
|
||||
tool_call_supported=False,
|
||||
tokenizer=tiktoken_tokenizer,
|
||||
max_tokens=128000
|
||||
),
|
||||
entities.LLMModelInfo(
|
||||
name="OneAPI/chatglm_std",
|
||||
model_name='chatglm_std',
|
||||
token_mgr=openai_token_mgr,
|
||||
requester=openai_chat_completion,
|
||||
tool_call_supported=False,
|
||||
tokenizer=tiktoken_tokenizer,
|
||||
max_tokens=128000
|
||||
),
|
||||
entities.LLMModelInfo(
|
||||
name="OneAPI/chatglm_lite",
|
||||
model_name='chatglm_lite',
|
||||
token_mgr=openai_token_mgr,
|
||||
requester=openai_chat_completion,
|
||||
tool_call_supported=False,
|
||||
tokenizer=tiktoken_tokenizer,
|
||||
max_tokens=128000
|
||||
),
|
||||
entities.LLMModelInfo(
|
||||
name="OneAPI/qwen-v1",
|
||||
model_name='qwen-v1',
|
||||
token_mgr=openai_token_mgr,
|
||||
requester=openai_chat_completion,
|
||||
tool_call_supported=False,
|
||||
tokenizer=tiktoken_tokenizer,
|
||||
max_tokens=6000
|
||||
),
|
||||
entities.LLMModelInfo(
|
||||
name="OneAPI/qwen-plus-v1",
|
||||
model_name='qwen-plus-v1',
|
||||
token_mgr=openai_token_mgr,
|
||||
requester=openai_chat_completion,
|
||||
tool_call_supported=False,
|
||||
tokenizer=tiktoken_tokenizer,
|
||||
max_tokens=30000
|
||||
),
|
||||
entities.LLMModelInfo(
|
||||
name="OneAPI/ERNIE-Bot",
|
||||
model_name='ERNIE-Bot',
|
||||
token_mgr=openai_token_mgr,
|
||||
requester=openai_chat_completion,
|
||||
tool_call_supported=False,
|
||||
tokenizer=tiktoken_tokenizer,
|
||||
max_tokens=2000
|
||||
),
|
||||
entities.LLMModelInfo(
|
||||
name="OneAPI/ERNIE-Bot-turbo",
|
||||
model_name='ERNIE-Bot-turbo',
|
||||
token_mgr=openai_token_mgr,
|
||||
requester=openai_chat_completion,
|
||||
tool_call_supported=False,
|
||||
tokenizer=tiktoken_tokenizer,
|
||||
max_tokens=7000
|
||||
),
|
||||
entities.LLMModelInfo(
|
||||
name="OneAPI/gemini-pro",
|
||||
model_name='gemini-pro',
|
||||
token_mgr=openai_token_mgr,
|
||||
requester=openai_chat_completion,
|
||||
tool_call_supported=False,
|
||||
tokenizer=tiktoken_tokenizer,
|
||||
max_tokens=30720
|
||||
),
|
||||
]
|
||||
|
||||
self.model_list.extend(one_api_model_list)
|
||||
@@ -1,29 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import abc
|
||||
import typing
|
||||
|
||||
from ...core import app
|
||||
from .. import entities as llm_entities
|
||||
from . import entities
|
||||
|
||||
|
||||
class LLMTokenizer(metaclass=abc.ABCMeta):
|
||||
|
||||
ap: app.Application
|
||||
|
||||
def __init__(self, ap: app.Application):
|
||||
self.ap = ap
|
||||
|
||||
async def initialize(self):
|
||||
"""初始化分词器
|
||||
"""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
async def count_token(
|
||||
self,
|
||||
messages: list[llm_entities.Message],
|
||||
model: entities.LLMModelInfo
|
||||
) -> int:
|
||||
pass
|
||||
@@ -1,28 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import tiktoken
|
||||
|
||||
from .. import tokenizer
|
||||
from ... import entities as llm_entities
|
||||
from .. import entities
|
||||
|
||||
|
||||
class Tiktoken(tokenizer.LLMTokenizer):
|
||||
|
||||
async def count_token(
|
||||
self,
|
||||
messages: list[llm_entities.Message],
|
||||
model: entities.LLMModelInfo
|
||||
) -> int:
|
||||
try:
|
||||
encoding = tiktoken.encoding_for_model(model.name)
|
||||
except KeyError:
|
||||
# print("Warning: model not found. Using cl100k_base encoding.")
|
||||
encoding = tiktoken.get_encoding("cl100k_base")
|
||||
|
||||
num_tokens = 0
|
||||
for message in messages:
|
||||
num_tokens += len(encoding.encode(message.role))
|
||||
num_tokens += len(encoding.encode(message.content if message.content is not None else ''))
|
||||
num_tokens += 3 # every reply is primed with <|start|>assistant<|message|>
|
||||
return num_tokens
|
||||
@@ -6,6 +6,8 @@ from ...core import app, entities as core_entities
|
||||
|
||||
|
||||
class SessionManager:
|
||||
"""会话管理器
|
||||
"""
|
||||
|
||||
ap: app.Application
|
||||
|
||||
@@ -39,6 +41,8 @@ class SessionManager:
|
||||
return session
|
||||
|
||||
async def get_conversation(self, session: core_entities.Session) -> core_entities.Conversation:
|
||||
"""获取对话或创建对话"""
|
||||
|
||||
if not session.conversations:
|
||||
session.conversations = []
|
||||
|
||||
@@ -46,7 +50,7 @@ class SessionManager:
|
||||
conversation = core_entities.Conversation(
|
||||
prompt=await self.ap.prompt_mgr.get_prompt(session.use_prompt_name),
|
||||
messages=[],
|
||||
use_model=await self.ap.model_mgr.get_model_by_name(self.ap.provider_cfg.data['openai-config']['chat-completions-params']['model']),
|
||||
use_model=await self.ap.model_mgr.get_model_by_name(self.ap.provider_cfg.data['model']),
|
||||
use_funcs=await self.ap.tool_mgr.get_all_functions(),
|
||||
)
|
||||
session.conversations.append(conversation)
|
||||
|
||||
@@ -10,5 +10,7 @@ class Prompt(pydantic.BaseModel):
|
||||
"""供AI使用的Prompt"""
|
||||
|
||||
name: str
|
||||
"""名称"""
|
||||
|
||||
messages: list[entities.Message]
|
||||
"""消息列表"""
|
||||
|
||||
@@ -1,13 +1,27 @@
|
||||
from __future__ import annotations
|
||||
import abc
|
||||
import typing
|
||||
|
||||
from ...core import app
|
||||
from . import entities
|
||||
|
||||
|
||||
preregistered_loaders: list[typing.Type[PromptLoader]] = []
|
||||
|
||||
def loader_class(name: str):
|
||||
|
||||
def decorator(cls: typing.Type[PromptLoader]) -> typing.Type[PromptLoader]:
|
||||
cls.name = name
|
||||
preregistered_loaders.append(cls)
|
||||
return cls
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
class PromptLoader(metaclass=abc.ABCMeta):
|
||||
"""Prompt加载器抽象类
|
||||
"""
|
||||
name: str
|
||||
|
||||
ap: app.Application
|
||||
|
||||
@@ -22,7 +36,7 @@ class PromptLoader(metaclass=abc.ABCMeta):
|
||||
|
||||
@abc.abstractmethod
|
||||
async def load(self):
|
||||
"""加载Prompt
|
||||
"""加载Prompt,存放到prompts列表中
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@@ -8,14 +8,15 @@ from .. import entities
|
||||
from ....provider import entities as llm_entities
|
||||
|
||||
|
||||
@loader.loader_class("full-scenario")
|
||||
class ScenarioPromptLoader(loader.PromptLoader):
|
||||
"""加载scenario目录下的json"""
|
||||
|
||||
async def load(self):
|
||||
"""加载Prompt
|
||||
"""
|
||||
for file in os.listdir("data/scenarios"):
|
||||
with open("data/scenarios/{}".format(file), "r", encoding="utf-8") as f:
|
||||
for file in os.listdir("data/scenario"):
|
||||
with open("data/scenario/{}".format(file), "r", encoding="utf-8") as f:
|
||||
file_str = f.read()
|
||||
file_name = file.split(".")[0]
|
||||
file_json = json.loads(file_str)
|
||||
|
||||
@@ -6,6 +6,7 @@ from .. import entities
|
||||
from ....provider import entities as llm_entities
|
||||
|
||||
|
||||
@loader.loader_class("normal")
|
||||
class SingleSystemPromptLoader(loader.PromptLoader):
|
||||
"""配置文件中的单条system prompt的prompt加载器
|
||||
"""
|
||||
|
||||
@@ -6,6 +6,8 @@ from .loaders import single, scenario
|
||||
|
||||
|
||||
class PromptManager:
|
||||
"""Prompt管理器
|
||||
"""
|
||||
|
||||
ap: app.Application
|
||||
|
||||
@@ -18,14 +20,18 @@ class PromptManager:
|
||||
|
||||
async def initialize(self):
|
||||
|
||||
loader_map = {
|
||||
"normal": single.SingleSystemPromptLoader,
|
||||
"full_scenario": scenario.ScenarioPromptLoader
|
||||
}
|
||||
mode_name = self.ap.provider_cfg.data['prompt-mode']
|
||||
|
||||
loader_cls = loader_map[self.ap.provider_cfg.data['prompt-mode']]
|
||||
loader_class = None
|
||||
|
||||
self.loader_inst: loader.PromptLoader = loader_cls(self.ap)
|
||||
for loader_cls in loader.preregistered_loaders:
|
||||
if loader_cls.name == mode_name:
|
||||
loader_class = loader_cls
|
||||
break
|
||||
else:
|
||||
raise ValueError(f'未知的 Prompt 加载器: {mode_name}')
|
||||
|
||||
self.loader_inst: loader.PromptLoader = loader_class(self.ap)
|
||||
|
||||
await self.loader_inst.initialize()
|
||||
await self.loader_inst.load()
|
||||
|
||||
@@ -5,6 +5,7 @@ import traceback
|
||||
|
||||
from ...core import app, entities as core_entities
|
||||
from . import entities
|
||||
from ...plugin import context as plugin_context
|
||||
|
||||
|
||||
class ToolManager:
|
||||
@@ -28,6 +29,15 @@ class ToolManager:
|
||||
return function
|
||||
return None
|
||||
|
||||
async def get_function_and_plugin(self, name: str) -> typing.Tuple[entities.LLMFunction, plugin_context.BasePlugin]:
|
||||
"""获取函数和插件
|
||||
"""
|
||||
for plugin in self.ap.plugin_mgr.plugins:
|
||||
for function in plugin.content_functions:
|
||||
if function.name == name:
|
||||
return function, plugin
|
||||
return None, None
|
||||
|
||||
async def get_all_functions(self) -> list[entities.LLMFunction]:
|
||||
"""获取所有函数
|
||||
"""
|
||||
@@ -68,7 +78,7 @@ class ToolManager:
|
||||
|
||||
try:
|
||||
|
||||
function = await self.get_function(name)
|
||||
function, plugin = await self.get_function_and_plugin(name)
|
||||
if function is None:
|
||||
return None
|
||||
|
||||
@@ -79,7 +89,7 @@ class ToolManager:
|
||||
**parameters
|
||||
}
|
||||
|
||||
return await function.func(**parameters)
|
||||
return await function.func(plugin, **parameters)
|
||||
except Exception as e:
|
||||
self.ap.logger.error(f'执行函数 {name} 时发生错误: {e}')
|
||||
traceback.print_exc()
|
||||
|
||||
File diff suppressed because one or more lines are too long
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user